
#include <pthread.h>
#include <resolv.h>
#include "resolv_internal.h"

/* Default for the retry field of the template resolver state (start.retry). */
#define DEFAULT_RETRIES 4

/* Guard for one-time global initialization (see _res_init()). */
static pthread_once_t init_once = PTHREAD_ONCE_INIT;
/* Thread-specific data key holding each thread's struct res_data. */
static pthread_key_t key;
/* Outcome of global initialization; negative means it failed. */
static int init_status;

static void init_answer_key(void);
static void set_options(const char *options, const char *source);
static unsigned long net_mask(struct in_addr in);
static int qcomp(void *arg1, void *arg2);

/* Template resolver state copied into each thread's data on first use. */
static struct __res_state start;

/* Parse a DNS reply into caller-supplied storage, following the classic
 * BSD getanswer() logic.
 *
 * answer, anslen: the raw reply and its length in bytes.
 * iquery: nonzero for an inverse (PTR) query -- the question name, then
 *     the first PTR target, becomes h_name.  Zero for a forward query --
 *     the owner name of the first accepted A record becomes h_name, and
 *     the data of A records with the same length and class is collected
 *     into h_addr_list.  CNAME owner names become aliases in both cases.
 * result: storage for the hostent, its alias and address pointer arrays,
 *     and the buffer that names and addresses are copied into.
 * data: per-thread resolver data; data->errval is set on failure.
 *
 * Returns &result->host, or NULL with data->errval set to NO_RECOVERY
 * (undecodable question section), HOST_NOT_FOUND or TRY_AGAIN.
 *
 * The reply is untrusted network input, so every fixed-size field and
 * rdata length is checked against the end of the reply before use.
 *
 * NOTE(review): result->hostbuf and result->h_addr_ptrs are assumed to be
 * arrays rather than pointers (the original code took sizeof(hostbuf));
 * confirm against struct hostent_answer. */
struct hostent *_res_parse_answer(querybuf *answer, int anslen, int iquery,
				  struct hostent_answer *result,
				  struct res_data *data)
{
    /* Address data is aligned to this union's size, as in BSD resolvers. */
    typedef union { long al; char ac; } rr_align;
    /* Wire sizes of 16- and 32-bit fields, and of the fixed part of a
     * resource record (type, class, TTL, rdata length). */
    enum { RR_INT16SZ = 2, RR_INT32SZ = 4,
	   RR_FIXEDSZ = 3 * RR_INT16SZ + RR_INT32SZ };

    struct hostent *host;
    HEADER *header;
    unsigned char *ans, *end;
    char *out, *bufend, **host_aliases, **h_addr_ptrs;
    int n, type, class, ancount, qdcount, haveanswer, getclass = C_ANY;
    int num_aliases, num_addr_ptrs, max_addr_ptrs;
    unsigned long pad;

    end = answer->buf + anslen;

    host = &result->host;
    host_aliases = result->host_aliases;
    h_addr_ptrs = result->h_addr_ptrs;
    out = result->hostbuf;
    bufend = result->hostbuf + sizeof(result->hostbuf);
    /* Leave room for the terminating NULL pointer. */
    max_addr_ptrs = (int) (sizeof(result->h_addr_ptrs)
			   / sizeof(result->h_addr_ptrs[0])) - 1;

    header = &answer->hdr;
    ancount = ntohs(header->ancount);
    qdcount = ntohs(header->qdcount);
    ans = answer->buf + sizeof(HEADER);

    if (qdcount) {
	if (iquery) {
	    /* The question name of an inverse query is the default h_name. */
	    n = dn_expand(answer->buf, end, ans, (unsigned char *) out,
			  (int) (bufend - out));
	    if (n < 0 || end - ans - n < QFIXEDSZ) {
		data->errval = NO_RECOVERY;
		return NULL;
	    }
	    ans += n + QFIXEDSZ;
	    host->h_name = out;
	    out += strlen(out) + 1;
	    qdcount--;
	}
	/* Skip the remaining questions. */
	while (qdcount-- > 0) {
	    n = _dn_skipname(ans, end);
	    if (n < 0 || end - ans - n < QFIXEDSZ) {
		data->errval = NO_RECOVERY;
		return NULL;
	    }
	    ans += n + QFIXEDSZ;
	}
    } else if (iquery) {
	data->errval = ((header->aa) ? HOST_NOT_FOUND : TRY_AGAIN);
	return NULL;
    }

    host->h_aliases = host_aliases;
    host->h_addr_list = h_addr_ptrs;
    num_aliases = 0;
    num_addr_ptrs = 0;
    haveanswer = 0;
    while (--ancount >= 0 && ans < end) {
	/* Owner name of this record, expanded into the free buffer space. */
	n = dn_expand(answer->buf, end, ans, (unsigned char *) out,
		      (int) (bufend - out));
	if (n < 0)
	    break;
	ans += n;

	/* Fixed part: type, class, TTL (ignored), rdata length. */
	if (end - ans < RR_FIXEDSZ)
	    break;
	type = _getshort(ans);
	ans += RR_INT16SZ;
	class = _getshort(ans);
	ans += RR_INT16SZ + RR_INT32SZ;
	n = _getshort(ans);
	ans += RR_INT16SZ;
	if (n > end - ans)
	    break;

	if (type == T_CNAME) {
	    /* The owner name just expanded is an alias; keep it. */
	    ans += n;
	    if (num_aliases >= MAXALIASES - 1)
		continue;
	    host_aliases[num_aliases++] = out;
	    out += strlen(out) + 1;
	    continue;
	}
	if (iquery && type == T_PTR) {
	    n = dn_expand(answer->buf, end, ans, (unsigned char *) out,
			  (int) (bufend - out));
	    if (n < 0)
		break;
	    ans += n;
	    host->h_name = out;
	    haveanswer = 1;
	    break;
	}
	if (iquery || type != T_A) {
	    ans += n;
	    continue;
	}

	/* From here on this is an A record of a forward query. */
	if (haveanswer) {
	    /* Only collect addresses that match the first one's shape. */
	    if (n != host->h_length || class != getclass) {
		ans += n;
		continue;
	    }
	} else {
	    host->h_length = n;
	    getclass = class;
	    host->h_addrtype = (class == C_IN) ? AF_INET : AF_UNSPEC;
	    host->h_name = out;
	    out += strlen(out) + 1;
	}

	if (num_addr_ptrs >= max_addr_ptrs)
	    break;

	/* Align the address copy; pad only when actually misaligned. */
	pad = (unsigned long) out % sizeof(rr_align);
	pad = pad ? sizeof(rr_align) - pad : 0;
	if ((unsigned long) (bufend - out) <= pad + (unsigned long) n)
	    break;
	out += pad;
	memcpy(out, ans, n);
	h_addr_ptrs[num_addr_ptrs++] = out;
	out += n;
	ans += n;
	haveanswer = 1;
    }

    if (!haveanswer) {
	data->errval = TRY_AGAIN;
	return NULL;
    }
    host_aliases[num_aliases] = NULL;
    h_addr_ptrs[num_addr_ptrs] = NULL;
    /* The array elements are address pointers; qcomp dereferences them. */
    if (_res.nsort && num_addr_ptrs > 1)
	qsort(h_addr_ptrs, num_addr_ptrs, sizeof(h_addr_ptrs[0]), qcomp);
    return host;
}

/* Performs global initialization and 
struct res_data *_res_init()
{
    struct res_data *data;

    /* Make sure the global initializations have been done. */
    pthread_once(_res_init_global, &init_once);
    if (init_status < 0)
	return -1;

    /* Initialize thread-specific data for this thread if it hasn't
     * been done already. */
    data = (struct res_data *) pthread_getspecific(key);
    if (!data) {
	data = (struct res_data *) malloc(sizeof(struct res_data));
	if (data == NULL)
	    return -1;
	if (pthread_setspecific(key, data) < 0) {
	    free(data);
	    return -1;
	}
	data->status = start;
	data->errval = NO_RECOVERY;
	data->socket = -1;
    }
    return data;
}

int _res_init_global()
{
    int result;
    char line[BUFSIZ], buf[BUFSIZ], *domain, *p, *net;
    int i, localdomain_set = 0, num_servers = 0, num_sorts = 0;
    FILE *fp;

    /* Assume an error state until we finish. */
    _res_init_status = -1;

    /* Initialize the error key. */
    result = pthread_key_create(&error_key, free);
    if (result < 0)
	return;

    /* Initialize the state key. */
    result = pthread_key_create(&state_key, free);
    if (result < 0)
	return;

    /* Initialize starting state. */
    start.retrans = RES_TIMEOUT;
    start.retry = DEFAULT_RETRIES;
    start.options = RES_DEFAULT;
    start.nscount = 1;
    start.nsaddr.sin_addr.s_addr = INADDR_ANY;
    start.nsaddr.sin_family = AF_INET;
    start.nsaddr.sin_port = htons(NAMESERVER_PORT);
    start.nscount = 1;
    start.ndots = 1;
    start.pfcode = 0;
    strncpy(start.lookups, "f", sizeof(start.looksup));

    /* Look for a LOCALDOMAIN definition. */
    domain = getenv("LOCALDOMAIN");
    if (domain != NULL) {
	strncpy(start.defdname, domain, sizeof(start.defdname));
	domain = start.defdname;
	localdomain_set = 1;

	/* Construct a search path from the LOCALDOMAIN value, which is
	 * a space-separated list of strings.  For backwards-compatibility,
	 * a newline terminates the list. */
	i = 0;
	while (*domain && i < MAXDNSRCH) {
	    start.dnsrch[i] = domain;
	    while (*domain && !isspace(*domain))
		domain++;
	    if (!*domain || *domain == '\n') {
		*domain = 0;
		break;
	    }
	    *domain++ = 0;
	    while (isspace(*domain))
		domain++;
	    i++;
	}
    }

    /* Look for a config file and read it in. */
    fp = fopen(_PATH_RESCONF, "r");
    if (fp != NULL) {
	strncpy(start.lookups, "bf", sizeof(start.lookups));

	/* Read in the configuration file. */
	while (fgets(line, sizeof(line), fp)) {

	    /* Ignore blank lines and comments. */
	    if (*line == ';' || *line == '#' || !*line)
		continue;

	    if (strncmp(line, "domain", 6) == 0) {

		/* Read in the default domain, and initialize a one-
		 * element search path.  Skip the domain line if we
		 * already got one from the LOCALDOMAIN environment
		 * variable. */
		if (localdomain_set)
		    continue;

		/* Look for the next word in the line. */
		p = line + 6;
		while (*p == ' ' || *p == '\t')
		    p++;
		if (!*p || *p == '\n')
		    continue;

		/* Copy in the domain, and null-terminate it at the
		 * first tab or newline. */
		strncpy(start.defdname, p, sizeof(start.defdname) - 1);
		p = strpbrk(start.defdname, "\t\n");
		if (p)
		    *p = 0;

		start.dnsrch[0] = start.defdname;
		start.dnsrch[1] = NULL;

	    } else if (strncmp(line, "lookup", 6) == 0) {

		/* Get a list of lookup types. */
		memset(start.lookups, 0, sizeof(start.lookups));

		/* Find the next word in the line. */
		p = line + 6;
		while (isspace(*p))
		    p++;

		i = 0;
		while (*p && i < MAXDNSLUS) {
		    /* Add a lookup type. */
		    if (*p == 'y' || *p == 'b' || *p == 'f')
			start.lookups[i++] = *p;

		    /* Find the next word. */
		    while (*p && !isspace(*p))
			p++;
		    while (isspace(*p))
			p++;
		}

	    } else if (strncmp(line, "search") == 0) {

		/* Read in a space-separated list of domains to search
		 * when a name is not fully-qualified.  Skip this line
		 * if the LOCALDOMAIN environment variable was set. */
		if (localdomain_set)
		    continue;

		/* Look for the next word on the line. */
		p = line + 6;
		while (*p == ' ' || *p == '\t')
		    p++;
		if (!*p || *p == '\n')
		    continue;

		/* Copy the rest of the line into start.defdname. */
		strncpy(start.defdname, p, sizeof(start.defdname) - 1);
		domain = start.defdname;
		p = strchr(domain, '\n');
		if (*p)
		    *p = 0;

		/* Construct a search path from the line, which is a
		 * space-separated list of strings. */
		i = 0;
		while (*domain && i < MAXDNSRCH) {
		    start.dnsrch[i] = domain;
		    while (*domain && !isspace(*domain))
			domain++;
		    if (!*domain || *domain == '\n') {
			*domain = 0;
			break;
		    }
		    *domain++ = 0;
		    while (isspace(*domain))
			domain++;
		    i++;
		}

	    } else if (strncmp(line, "nameserver", 10) == 0) {

		/* Add an address to the list of name servers we can
		 * connect to. */

		/* Look for the next word in the line. */
		p = line + 10;
		while (*p == ' ' || *p == '\t')
		    p++;
		if (*p && *p != '\n' && inet_aton(p, &addr)) {
		    start.nsaddr_list[num_servers].sin_addr = addr;
		    start.nsaddr_list[num_servers].sin_family = AF_INET;
		    start.nsaddr_list[num_servers].sin_port =
			htons(NAMESERVER_PORT);
		    num_servers++;
		}

	    } else if (strncmp(line, "sortlist", 8) == 0) {

		p = line + 8;
		while (num_sorts < MAXRESOLVSORT) {

		    /* Find the next word in the line. */
		    p = line + 8;
		    while (*p == ' ' || *p == '\t')
			p++;

		    /* Read in an IP address and netmask. */
		    if (sscanf(p, "%[0-9./]s", buf) != 1)
			break;
		    net = strchr(buf, '/');
		    if (net)
			*net = 0;

		    /* Translate the address into an IP address
		     * and netmask. */
		    if (inet_aton(buf, &addr)) {
			start.sort_list[num_sorts].addr = a;
			if (net && inet_aton(net + 1, &addr)) {
			    start.sort_list[num_sorts].mask = addr.s_addr;
			} else {
			    start.sort_list[num_sorts].mask =
				net_mask(start.sort_list[num_sorts].addr);
			}
			num_sorts++;
		    }

		    /* Skip past this word. */
		    if (net)
			*net = '/';
		    p += strlen(buf);
		}

	    }
	}
    }
 
    /* If we don't have a default domain, strip off the first
     * component of this machine's domain name, and make a one-
     * element search path consisting of the default domain. */
    if (*start.defdname == 0) {
	if (gethostname(buf, sizeof(start.defdname) - 1) == 0) {
	    p = strchr(buf, '.');
	    if (p)
		strcpy(start.defdname, p + 1);
	}
	start.dnsrch[0] = start.defdname;
	start.dnsrch[1] = NULL;
    }

    p = getenv("RES_OPTIONS");
    if (p)
	set_options(p, "env");

    start.options |= RES_INIT;
    _res_init_status = 0;
}

static void set_options(const char *options, const char *source)
{
    char *p = options;
    int i;

    while (*p) {

	/* Skip leading and inner runs of spaces. */
	while (*p == ' ' || *p == '\t')
	    p++;

	/* Search for and process individual options. */
	if (!strncmp(p, "ndots:", 6) {
	    i = atoi(p + 6);
	    if (i <= RES_MAXNDOTS)
		start.ndots = i;
	    else
		start.ndots = RES_MAXNDOTS;
	}

	/* Skip to next run of spaces */
	while (*p && *p != ' ' && *p != '\t')
	    p++;
    }
}

/* Return the classful (A/B/C) network mask for address IN, in network
 * byte order.  Addresses outside classes A and B get the class C mask. */
static unsigned long net_mask(struct in_addr in)
{
    unsigned long host_order = ntohl(in.s_addr);
    unsigned long mask;

    if (IN_CLASSA(host_order))
	mask = IN_CLASSA_NET;
    else if (IN_CLASSB(host_order))
	mask = IN_CLASSB_NET;
    else
	mask = IN_CLASSC_NET;

    return htonl(mask);
}
    
/* Return this thread's last resolver error code, or NO_RECOVERY when
 * _res_init() cannot provide per-thread data.
 *
 * Consider a resolver routine whose _res_init() call failed to allocate or
 * register the thread data, so it could not record an error.  If the call
 * below then succeeds, it creates fresh data.  _res_init() seeds errval in
 * new data with NO_RECOVERY, so that failure is still reported correctly. */
int _res_get_error()
{
    struct res_data *thread_data = _res_init();

    if (thread_data == NULL)
	return NO_RECOVERY;
    return thread_data->errval;
}

/* Return this thread's resolver state, or NULL when _res_init() cannot
 * provide per-thread data. */
struct __res_state *_res_status()
{
    struct res_data *thread_data = _res_init();

    if (thread_data == NULL)
	return NULL;
    return &thread_data->status;
}

/* Open DATA's hosts file, or rewind it if it is already open.  Also record
 * whether _enthtent() should leave it open.  If fopen() fails,
 * data->hostfp remains NULL. */
void _sethtent(struct res_data *data, int stay_open)
{
    if (data->hostfp != NULL)
	rewind(data->hostfp);
    else
	data->hostfp = fopen(_PATH_HOSTS, "r");

    data->keep_hostfp_open = stay_open;
}

/* Close DATA's hosts file, unless none is open or _sethtent() asked for it
 * to stay open. */
void _enthtent(struct res_data *data)
{
    if (data->hostfp == NULL || data->keep_hostfp_open)
	return;

    fclose(data->hostfp);
    data->hostfp = NULL;
}

/* Read the next usable entry from DATA's hosts file into RESULT, opening
 * the file first if needed.
 *
 * Each entry has the form "address name [alias ...]", and text after '#'
 * is ignored.  These lines are skipped:
 *   - lines without a newline (an unterminated last line, or an over-long
 *     line);
 *   - lines without a name;
 *   - lines whose address inet_aton() rejects.
 *
 * Returns &result->host, or NULL at end of file or if the hosts file cannot
 * be opened.  The names point into data->hostbuf, so the next call
 * overwrites them.
 *
 * NOTE(review): result->host_aliases is assumed to hold MAXALIASES
 * pointers, matching the bound used in _res_parse_answer() -- confirm. */
struct hostent *_gethtent(struct res_data *data, struct hostent_answer *result)
{
    struct hostent *host = &result->host;
    char *p, *line, *end, *separator, *name;
    int num_aliases;

    /* Open the hosts file if we don't have an already-open stream. */
    if (data->hostfp == NULL)
	data->hostfp = fopen(_PATH_HOSTS, "r");
    if (data->hostfp == NULL)
	return NULL;

    /* Look for a valid host entry. */
    while (1) {
	line = fgets(data->hostbuf, BUFSIZ, data->hostfp);
	if (line == NULL)
	    return NULL;
	if (*line == '#')
	    continue;

	/* Cut the line at a comment character or the newline. */
	end = strpbrk(line, "#\n");
	if (end == NULL)
	    continue;
	*end = '\0';

	/* Terminate the address at the first blank. */
	separator = strpbrk(line, " \t");
	if (separator == NULL)
	    continue;
	*separator = '\0';

	/* The canonical name is the next word. */
	name = separator + 1;
	while (*name == ' ' || *name == '\t')
	    name++;
	if (*name == '\0')
	    continue;

	/* Set the address (Internet-specific). */
	if (!inet_aton(line, &result->host_addr))
	    continue;
	result->h_addr_ptrs[0] = (char *) &result->host_addr;
	result->h_addr_ptrs[1] = NULL;
	host->h_addr_list = result->h_addr_ptrs;
	host->h_length = sizeof(result->host_addr);
	host->h_addrtype = AF_INET;
	host->h_aliases = result->host_aliases;
	host->h_name = name;

	/* Collect aliases, keeping one slot for the NULL terminator. */
	num_aliases = 0;
	p = strpbrk(name, " \t");
	if (p != NULL) {
	    *p++ = '\0';
	    while (num_aliases < MAXALIASES - 1) {
		while (*p == ' ' || *p == '\t')
		    p++;
		if (*p == '\0')
		    break;
		host->h_aliases[num_aliases++] = p;

		/* Terminate this alias; stop if it ends the line. */
		p = strpbrk(p, " \t");
		if (p == NULL)
		    break;
		*p++ = '\0';
	    }
	}
	host->h_aliases[num_aliases] = NULL;

	return host;
    }
}

/* Index of the first _res.sort_list entry whose network contains ADDR.
 * Returns _res.nsort when no entry matches. */
static int sort_rank(const struct in_addr *addr)
{
    int rank = 0;

    while (rank < _res.nsort
	   && _res.sort_list[rank].addr.s_addr
	      != (addr->s_addr & _res.sort_list[rank].mask))
	rank++;
    return rank;
}

/* qsort() comparator for address lists: each argument points at a pointer
 * to an address.  Addresses are ordered by the first sort_list entry they
 * match, and addresses matching no entry sort after those that do. */
static int qcomp(void *arg1, void *arg2)
{
    const struct in_addr *left = *(struct in_addr **) arg1;
    const struct in_addr *right = *(struct in_addr **) arg2;

    return sort_rank(left) - sort_rank(right);
}

