#include <ctype.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/ioctl.h>
#ifdef SOLARIS
#include <sys/filio.h>
#endif
#include <sys/param.h>
#include <unistd.h>

#include <netinet/in.h>
#include <arpa/inet.h> /* hmmm... */
#include <netdb.h>

#include "Connect.h"
#include "ip.h"

/*
 * Two hash tables index the same inetaddr entries: addrhash is keyed by
 * Connect address (chains linked through nextaddr) and inethash is keyed
 * by socket address (chains linked through nextinet).
 */
inetaddr *addrhash[INETTABLESIZE], *inethash[INETTABLESIZE];
/* Non-zero: ip_aton() calls gethostbyaddr() for entries with no cached
   hostname.  Set to 1 by ip_init(), changed by Ip_NameResolve(). */
int nameResolve;

/*
 * Bucket index for a Connect address.
 * NOTE(review): assumes Addr is a non-negative integer; a negative value
 * would give a negative bucket index -- confirm against Connect.h.
 */
#define hash_addr(a) ((a)%INETTABLESIZE)
/* Bucket index for a struct sockaddr_in *: port plus IPv4 address, both
   taken as stored (network byte order). */
#define hash_inet(h) (((h)->sin_port + (h)->sin_addr.s_addr)%INETTABLESIZE)

/*
 * True when two struct sockaddr_in * have the same family, port and
 * IPv4 address.  The family check was added when this code was
 * generalized; whether it is the right test has not been verified.
 */
#define ip_eq(a, b) (((a)->sin_family == (b)->sin_family) && \
		       ((a)->sin_port == (b)->sin_port) && \
		       ((a)->sin_addr.s_addr == (b)->sin_addr.s_addr))

/*
 * ip_init: one-time initialisation of the IP address tables.
 *
 * Empties both hash tables and enables hostname resolution in ip_aton().
 * Only the first call does any work.  Later calls are no-ops, so a second
 * call no longer discards (and leaks) every registered entry, and no
 * longer undoes a setting made through Ip_NameResolve().
 *
 * Returns OK.
 */
Trap ip_init()
{
  static int init = 0;
  int i;

  if (!init)
    {
      for (i = 0; i < INETTABLESIZE; i++)
	{
	  addrhash[i] = NULL;
	  inethash[i] = NULL;
	}

      nameResolve = 1;

      /* Previously never set, so the guard above had no effect. */
      init = 1;
    }

  return OK;
}

/*
 * ip_addrlookup: find the table entry registered under a Connect address.
 *
 * Walks the addrhash bucket chain selected by `addr`.  Returns the entry
 * whose address field equals addr, or NULL if there is none.
 */
inetaddr *ip_addrlookup(addr)
     Addr addr;
{
  inetaddr *entry;

  for (entry = addrhash[hash_addr(addr)];
       entry != NULL && entry->address != addr;
       entry = entry->nextaddr)
    ;				/* search only; the loop condition does the work */

  return entry;
}

/*
 * ip_inetlookup: find the table entry for a socket address.
 *
 * Scans the inethash bucket selected by `addr`.  Returns the first entry
 * whose family, port and IPv4 address all match addr (see ip_eq), or
 * NULL if no entry matches.
 */
inetaddr *ip_inetlookup(addr)
     struct sockaddr_in *addr;
{
  inetaddr *entry = inethash[hash_inet(addr)];

  while (entry != NULL && !ip_eq(addr, &(entry->netaddr)))
    entry = entry->nextinet;

  return entry;
}

Trap ip_newaddress(domain, new, addr)
     int domain;
     inetaddr **new, *addr;
{
  inetaddr *newaddr;
  int ahash, ihash;
  Addr a;

  newaddr = (inetaddr *)malloc(sizeof(inetaddr));
  if (newaddr == NULL)
    Return(GENERAL_MALLOC, sizeof(inetaddr));
  memcpy(newaddr, addr, sizeof(inetaddr));

  if (Connect_NewAddress(domain, &a))
    {
      free(newaddr);
      return CHECK;
    }

  ahash = hash_addr(a);
  ihash = hash_inet(&(addr->netaddr));

  newaddr->address = a;
  newaddr->nextaddr = addrhash[ahash];
  newaddr->nextinet = inethash[ihash];
  addrhash[ahash] = newaddr;
  inethash[ihash] = newaddr;

  *new = newaddr;
  return OK;
}

void ip_freeaddr(freeaddr)
     inetaddr *freeaddr;
{
  int ahash, ihash;
  inetaddr *find;

  ahash = hash_addr(freeaddr->address);
  ihash = hash_inet(&(freeaddr->netaddr));

  /*
   * Remove from address bucket...
   */
  if (addrhash[ahash] == freeaddr)
    addrhash[ahash] = freeaddr->nextaddr;
  else
    {
      find = addrhash[ahash];
      while (find != NULL && find->nextaddr != freeaddr)
	find = find->nextaddr;
      if (find->nextaddr == freeaddr)
	find->nextaddr = freeaddr->nextaddr;
    }

  /*
   * Remove from inet bucket...
   */
  if (inethash[ihash] == freeaddr)
    inethash[ihash] = freeaddr->nextinet;
  else
    {
      find = inethash[ihash];
      while (find != NULL && find->nextinet != freeaddr)
	find = find->nextinet;
      if (find->nextinet == freeaddr)
	find->nextinet = freeaddr->nextinet;
    }

  if (freeaddr->hostname != NULL)
    free(freeaddr->hostname);
  free(freeaddr);
}

/*
 * ERRORS:
 *
 * CONNECT_BADADDRESSFORMAT: FATAL
 * The address string %s is ill-formed.
 *
 * CONNECT_UNKNOWNNAME: FATAL
 * The hostname %s was not resolvable.
 *
 * CONNECT_DOMAINMISMATCH: FATAL
 * The hostname given did not match the domain it was paired with.
 *
 * CONNECT_NEWADDRESS: INFO
 * Operation generated a new address (%d).
 *
 * IP_UNKNOWNSERVICE: FATAL
 * Service name requested for address (%s) was unknown.
 */
#ifndef INADDR_NONE
#define INADDR_NONE 0xffffffff	/* inet_addr() failure value, for systems lacking the macro */
#endif

/*
 * ip_ntoa: translate a textual "host[/port]" specification into a Connect
 * address, registering a new table entry if the endpoint is not yet known.
 *
 * domain   - Connect domain of the address.
 * protocol - protocol name passed to getservbyname() when the port part
 *            is a service name.
 * name     - "*" for the wildcard address.  Otherwise a host, optionally
 *            followed by '/' and a numeric port or a service name.  The
 *            host may be a dotted quad (leading digit), a hostname, or a
 *            string starting with '.', which means "this host".
 * address  - receives the resulting Connect address.
 *
 * Returns OK for the wildcard or when an existing entry matches.  Reports
 * CONNECT_NEWADDRESS (through Return()) after registering a new entry.
 * Reports GENERAL_MALLOC, IP_UNKNOWNSERVICE, CONNECT_BADADDRESSFORMAT,
 * CONNECT_UNKNOWNNAME or CONNECT_DOMAINMISMATCH on failure.  Returns
 * CHECK if ip_newaddress() fails.
 */
Trap ip_ntoa(domain, protocol, name, address)
     int domain;
     char *protocol, *name;
     Addr *address;
{
  struct hostent *hp;
  struct servent *sp;
  char *n, *p;
  size_t nsize;
  inetaddr addr, *s;

  if (name[0] == '*' && name[1] == '\0')
    {
      *address = Connect_AnyAddress(domain); /* wild */
      return OK;
    }

  /*
   * Working copy of name.  The extra strlen(protocol) bytes leave room to
   * rewrite it as "service/protocol" for the unknown-service error.
   */
  nsize = strlen(name) + 1 + strlen(protocol);
  n = malloc(nsize);
  if (n == NULL)
    Return(GENERAL_MALLOC, nsize);	/* report the size actually requested */
  strcpy(n, name);

  memset(&addr, 0, sizeof(addr));
  addr.netaddr.sin_family = AF_INET;

  /*
   * NOTE(review): n is not freed before the ReturnF() calls below;
   * presumably ReturnF takes ownership of its string argument -- confirm
   * in Connect.h.
   */
  p = strchr(n, '/');
  if (p == NULL)
    addr.netaddr.sin_port = 0;
  else
    {
      *p = '\0';		/* terminate the host part in place */
      p++;
      /* <ctype.h> functions need a value representable as unsigned char. */
      if (isdigit((unsigned char) *p))
	addr.netaddr.sin_port = htons(atoi(p));
      else
	{
	  sp = getservbyname(p, protocol);
	  if (sp == NULL)
	    {
	      /* Reuse n to build the message text "service/protocol". */
	      memmove(n, p, strlen(p) + 1);
	      strcat(n, "/");
	      strcat(n, protocol);
	      ReturnF(IP_UNKNOWNSERVICE, n);
	    }
	  addr.netaddr.sin_port = sp->s_port;
	  /* Note: network routines seem to return values in net byte order */
	}
    }

  if (isdigit((unsigned char) *n))
    {
      addr.netaddr.sin_addr.s_addr = inet_addr(n);
      if (addr.netaddr.sin_addr.s_addr == INADDR_NONE)
	ReturnF(CONNECT_BADADDRESSFORMAT, n);
      free(n);

      /*
       * If they passed the address in dot notation, it's
       * unlikely they'll care to get the name later. If
       * they do, it'll be handled there.
       */
      addr.hostname = NULL;
    }
  else
    {
      char hnam[MAXHOSTNAMELEN + 1];

      if (*n == '.')
	{
	  /*
	   * "This host".  gethostname() need not NUL-terminate a truncated
	   * name, and leaves the buffer unspecified on failure, so force
	   * a valid string either way.
	   */
	  if (gethostname(hnam, MAXHOSTNAMELEN) != 0)
	    hnam[0] = '\0';
	  hnam[MAXHOSTNAMELEN] = '\0';
	  hp = gethostbyname(hnam);
	}
      else
	hp = gethostbyname(n);

      if (hp == NULL)
	ReturnF(CONNECT_UNKNOWNNAME, n);

      free(n);

      if (hp->h_addrtype != AF_INET)
	Return(CONNECT_DOMAINMISMATCH, name);

      /* Keep the resolved name; a failed malloc just leaves it unset. */
      addr.hostname = malloc(strlen(hp->h_name) + 1);
      if (addr.hostname != NULL)
	strcpy(addr.hostname, hp->h_name);

      memcpy(&addr.netaddr.sin_addr, hp->h_addr_list[0],
	     sizeof(addr.netaddr.sin_addr));
    }

  /*
   * At this point, we have inetaddr filled up.
   * Now we look for an existing address which matches it.
   */
  s = ip_inetlookup(&addr.netaddr);
  if (s != NULL)
    {
      free(addr.hostname);	/* the existing entry keeps its own copy */
      *address = s->address;
      return OK;
    }

  /* On success the new entry holds addr.hostname, so it is not freed. */
  if (!ip_newaddress(domain, &s, &addr))
    {
      *address = s->address;
      Return(CONNECT_NEWADDRESS, s->address);
    }

  free(addr.hostname);
  return CHECK;
}

/*
 * ip_aton_append: append s to the NUL-terminated string in name, whose
 * buffer is `length` bytes, truncating so the result always fits.
 * Does nothing when the buffer is already full or length is not positive.
 * In those cases the old inline bound "length - strlen(name) - 1" wrapped
 * to a huge size_t and let strncat() write past the end of the buffer.
 */
static void ip_aton_append(char *name, int length, const char *s)
{
  size_t used;

  if (length <= 0)
    return;

  used = strlen(name);
  if (used + 1 >= (size_t) length)
    return;

  strncat(name, s, (size_t) length - used - 1);
}

/*
 * ip_aton: append the printable form of a Connect address to name.
 *
 * address - Connect address to describe.
 * name    - NUL-terminated buffer to append to; existing text is kept.
 * length  - total size of the name buffer in bytes; output is truncated
 *           to fit.
 *
 * The wildcard address is written as "*".  Otherwise "host/port" is
 * written, with the host string chosen as follows:
 *   - the entry's cached hostname, if it has one;
 *   - else, when nameResolve is set, the name from gethostbyaddr();
 *   - else the dotted-quad address.
 * The chosen host string is cached in the entry for later calls.
 * Returns OK, or reports CONNECT_BADADDRESS if address is not in the table.
 */
Trap ip_aton(address, name, length)
     Addr address;
     char *name;
     int length;
{
  inetaddr *a;
  struct hostent *hp;
  char *host;
  char port[10];		/* "/" + at most 5 digits + NUL */

  if (address == Connect_AnyAddress(Connect_Domain(address))) /* wild */
    {
      ip_aton_append(name, length, "*");
      return OK;
    }

  a = ip_addrlookup(address);
  if (a == NULL)
    Return(CONNECT_BADADDRESS, address);

  if (a->hostname == NULL)
    {
      if (nameResolve)
	hp = gethostbyaddr((char *) &(a->netaddr.sin_addr),
			   sizeof(a->netaddr.sin_addr), AF_INET);
      else
	hp = NULL;

      /* Fall back to dotted-quad when lookup is disabled or fails. */
      host = (hp != NULL) ? hp->h_name : inet_ntoa(a->netaddr.sin_addr);
      ip_aton_append(name, length, host);

      /* Cache whichever form was produced; a failed malloc just leaves
	 the entry uncached. */
      a->hostname = malloc(strlen(host) + 1);
      if (a->hostname != NULL)
	strcpy(a->hostname, host);
    }
  else
    ip_aton_append(name, length, a->hostname);

  sprintf(port, "/%d", ntohs(a->netaddr.sin_port));
  ip_aton_append(name, length, port);
  return OK;
}

/*
 * Ip_NameResolve: switch reverse hostname lookups in ip_aton() on
 * (val non-zero) or off (val zero).  The setting only matters for entries
 * that have no cached hostname yet.  Always returns OK.
 *
 * TODO: needs a generic form callable via Connect; probably the same for
 * all the Udp_* routines.
 */
Trap Ip_NameResolve(val)
     int val;
{
  /* Consulted by ip_aton() before calling gethostbyaddr(). */
  nameResolve = val;
  return OK;
}

