
#include <pthread.h>
#include <resolv.h>
#include <arpa/nameser.h>
#include "res_internal.h"

/* Status codes returned by send_datagram() and send_circuit().  All are
 * negative so they can never be confused with a reply length (>= 0), and
 * each has a distinct value so res_send() can tell them apart (in
 * particular, a timeout must not look like "try the next server"). */
enum { SEND_GIVE_UP = -1, SEND_TRY_NEXT = -2, SEND_TRY_SAME = -3,
	   SEND_TRUNCATED = -4, SEND_TIMEOUT = -5 };

/* Forward declarations of the file-local helpers used by res_send().
 * send_datagram() and send_circuit() return a reply length (>= 0) on
 * success or one of the negative SEND_* codes on failure;
 * close_save_errno() closes a socket while preserving errno. */
static int send_datagram(int server, int sock, const char *buf, int buflen,
			 char *answer, int anslen, struct res_data *data);
static int send_circuit(int server, const char *buf, int buflen, char *answer,
			int anslen, struct res_data *data);
static int close_save_errno(int sock);

int res_send(const char *buf, int buflen, char *answer, int anslen)
{
    struct res_data *data;
    struct sockaddr_in local;
    int use_virtual_circuit, result, udp_sock, have_seen_same, terrno = 0;

    data = _res_init();
    if (!data)
	return -1;

    try = 0;
    server = 0;

    /* Try doing connectionless queries if appropriate. */
    if (!(data->status.options & RES_USEVC) && buflen <= PACKETSZ) {
	/* Create and bind a local UDP socket. */
	udp_sock = socket(AF_INET, SOCK_DGRAM, 0);
	if (udp_sock < 0)
	    return -1;
	local.sin_family = AF_INET;
	local.sin_addr.s_addr = htonl(INADDR_ANY);
	local.sin_port = htons(0);
	if (bind(udp_sock, (struct sockaddr *) &local, sizeof(local)) < 0) {
	    close(udp_sock);
	    return -1;
	}

	/* Cycle through the retries and servers, sending off queries and
	 * waiting for responses. */
	for (; try < data->status.retry; try++) {
	    for (; server < data->status.nscount; server++) {
		result = send_datagram(server, udp_sock, buf, buflen, answer,
				       anslen, data);
		if (result == SEND_TIMEOUT)
		    terrno = ETIMEDOUT;
		else if (result != SEND_TRY_NEXT)
		    break;
	    }
	}

	close(udp_sock);
	errno = (terrno == ETIMEDOUT) ? ETIMEDOUT : ECONNREFUSED;
	if (result != SEND_TRUNCATED)
	    return (result >= 0) ? result : -1;
    }

    /* Either we have to use the virtual circuit, or the server couldn't
     * fit its response in a UDP packet.  Cycle through the retries and
     * servers, sending off queries and waiting for responses.  Allow a
     * response of SEND_TRY_SAME to cause an extra retry once. */
    for (; try < data->status.retry; try++) {
	for (; server < data->status.nscount; server++) {
	    result = send_circuit(server, buf, buflen, answer, anslen, data);
	    terrno = errno;
	    if (result == SEND_TRY_SAME) {
		if (!have_seen_same)
		    server--;
		have_seen_same = 1;
	    } else if (result != SEND_TRY_NEXT) {
		break;
	    }
	}
    }

    errno = ternno;
    return (result >= 0) ? result : -1;
}

static int send_datagram(int server, int sock, const char *buf, int buflen,
			 char *answer, int anslen, struct res_data *data)
{
    int count;
    struct sockaddr_in local_addr;
    HEADER *request = (HEADER *) buf, *response = (HEADER *) answer;

    /* Send a packet to the server. */
    count = sendto(sock, buf, buflen,
		   (struct sockaddr *) &data->status.nsaddr_list[server],
		   sizeof(struct sockaddr_in));
    if (count != buflen)
	return SEND_TRY_NEXT;

    /* Await a reply with the correct ID.  (REPLACE WITH TIMED READ.) */
    while (1) {
	count = recvfrom(sock, answer, anslen, 0, NULL, 0);
	if (count < 0)
	    return SEND_TRY_NEXT;
	/* If the ID is wrong, it's from an old query; ignore it. */
	if (response->id == request->id)
	    break;
    }

    /* Report a truncated response unless RES_IGNTC is set.  This will
     * cause the res_send() loop to fall back to TCP. */
    if (data->tc && !(data->status.options & RES_IGNTC))
	return SEND_TRUNCATED;

    return count;
}

/* Send the query in BUF (BUFLEN bytes) to name server number SERVER over
 * TCP and read the reply into ANSWER (at most ANSLEN bytes).
 *
 * Messages on the stream are preceded by a two-byte length in network
 * byte order.  A reply longer than ANSLEN is cut to ANSLEN bytes, its TC
 * bit is set, and the excess is read and discarded so the connection
 * stays in sync.  The connection is kept in data->sock for reuse only
 * for server 0 when both RES_USEVC and RES_STAYOPEN are set; otherwise
 * it is closed.
 *
 * Returns the number of bytes stored in ANSWER, or:
 *   SEND_GIVE_UP   - no socket could be created, or ANSLEN cannot hold a
 *                    DNS header;
 *   SEND_TRY_SAME  - the connection was reset while waiting for the reply
 *                    (the server may have restarted);
 *   SEND_TRY_NEXT  - any other failure, or a reply shorter than a header. */
static int send_circuit(int server, const char *buf, int buflen, char *answer,
			int anslen, struct res_data *data)
{
    HEADER *response = (HEADER *) answer;
    int sock = -1, n, count, response_len;
    unsigned short len;
    struct iovec iov[2];
    char *p, junk[512];

    /* We may have to set the TC bit, so ANSWER must hold a full header. */
    if (anslen < (int) sizeof(HEADER))
	return SEND_GIVE_UP;

    /* If data->sock is valid, then it's an open connection to the
     * first server.  Grab it if it's appropriate; close it if not.
     * -1 is the "no saved connection" marker stored below.
     * NOTE(review): 0 is also treated as "none", presumably because a
     * fresh res_data may hold 0 here -- confirm against _res_init(). */
    if (data->sock > 0) {
	if (server == 0)
	    sock = data->sock;
	else
	    close(data->sock);
	data->sock = -1;
    }

    /* Initialize our socket if we didn't grab it from data. */
    if (sock == -1) {
	sock = socket(AF_INET, SOCK_STREAM, 0);
	if (sock < 0)
	    return SEND_GIVE_UP;
	if (connect(sock,
		    (struct sockaddr *) &data->status.nsaddr_list[server],
		    sizeof(struct sockaddr_in)) < 0) {
	    close_save_errno(sock);
	    return SEND_TRY_NEXT;
	}
    }

    /* Send length and message in a single writev(). */
    len = htons((unsigned short) buflen);
    iov[0].iov_base = (void *) &len;
    iov[0].iov_len = sizeof(len);
    iov[1].iov_base = (void *) buf;
    iov[1].iov_len = buflen;
    if (writev(sock, iov, 2) != (int) sizeof(len) + buflen) {
	close_save_errno(sock);
	return SEND_TRY_NEXT;
    }

    /* Receive length. */
    p = (char *) &len;
    n = sizeof(len);
    while (n > 0) {
	count = read(sock, p, n);
	if (count <= 0) {
	    /* If read() failed with ECONNRESET, the remote server may have
	     * restarted, and we report SEND_TRY_SAME.  (The main loop will
	     * only allow one of these, so we can't loop indefinitely.)  On
	     * EOF errno is not set by read(), so it must not be consulted. */
	    int reset = (count < 0 && errno == ECONNRESET);

	    close_save_errno(sock);
	    return reset ? SEND_TRY_SAME : SEND_TRY_NEXT;
	}
	p += count;
	n -= count;
    }
    len = ntohs(len);

    /* Every DNS message starts with a fixed header; anything shorter
     * (including an empty message) cannot be a valid reply. */
    if (len < sizeof(HEADER)) {
	close_save_errno(sock);
	return SEND_TRY_NEXT;
    }
    response_len = (len > anslen) ? anslen : len;
    len -= response_len;	/* bytes that won't fit in ANSWER */

    /* Receive message. */
    p = answer;
    n = response_len;
    while (n > 0) {
	count = read(sock, p, n);
	if (count <= 0) {
	    close_save_errno(sock);
	    return SEND_TRY_NEXT;
	}
	p += count;
	n -= count;
    }

    /* If the reply is longer than our answer buffer, set the truncated
     * bit and flush the rest of the reply, to keep the connection in
     * sync. */
    if (len) {
	response->tc = 1;
	while (len) {
	    n = (len > sizeof(junk)) ? (int) sizeof(junk) : len;
	    count = read(sock, junk, n);
	    if (count <= 0) {
		/* ANSWER already holds a usable (truncated) reply. */
		close_save_errno(sock);
		return response_len;
	    }
	    len -= count;
	}
    }

    /* If this is the first server, and RES_USEVC and RES_STAYOPEN are
     * both set, save the connection.  Otherwise, close it. */
    if (server == 0 && (data->status.options & RES_USEVC)
	&& (data->status.options & RES_STAYOPEN))
	data->sock = sock;
    else
	close_save_errno(sock);

    return response_len;
}

/* Close SOCK without disturbing errno, so a caller that is abandoning a
 * socket after a failed call still reports that call's error rather than
 * anything close() might set.  Returns close()'s result (0 or -1). */
static int close_save_errno(int sock)
{
    int saved_errno = errno;
    int result = close(sock);

    errno = saved_errno;
    return result;
}

