#include <ctype.h>
#include <errno.h>
#include <signal.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>

#include <des.h>

#include "Connect.h"
#include "inet-udp.h"
#include "Code.h"
#include "literal.h"
#include "slw.h"
#include "regex.h"

/* Printable source address of the packet being handled (set in main()). */
static char name[50];
/* Program name (basename of argv[0]) used as the prefix of every message. */
static char pname[100];
/* Receive time of the packet being handled, formatted for the log. */
static char timestamp[50];
/* Nonzero enables extra per-request tracing on stdout. */
static int verbose = 0;
/*
 * Flags set asynchronously by the SIGHUP / SIGUSR1 handlers and polled by
 * the main loop.  An object written from a signal handler must be
 * volatile sig_atomic_t for that access to be well defined.
 */
static volatile sig_atomic_t gothup = 0;
static volatile sig_atomic_t gotusr1 = 0;

/*
 * Fixed capacities for the restrictions table.  These are deliberately
 * static rather than dynamically sized; needing more than this would be
 * far outside the design spec.
 *
 * MAXLIST  - maximum number of patterns in each +ip/-ip/+name/-name list.
 *            Each list array has MAXLIST+1 slots so it always keeps a
 *            NULL terminator.
 * MAXTYPES - number of Restriction entries allocated for one table.  The
 *            walkers over the table stop at the first entry whose type is
 *            NULL, so the table must always keep one zeroed entry at the
 *            end.
 */
#define MAXLIST 50
#define MAXTYPES 50

/*
 * Offset(type, field) -- byte offset, as an int, of `field` inside the
 * structure that the POINTER type `type` points to
 * (e.g. Offset(Restriction *, failure)).
 *
 * It works by forming the field's address through a null pointer and
 * subtracting the null pointer.  NOTE(review): that is technically
 * undefined behaviour; the standard offsetof() from <stddef.h> is the
 * portable equivalent when the structure type itself is available.
 */
#define Offset(type,field) \
        ((int) (((char *) (&(((type)NULL)->field))) - ((char *) NULL)))

/*
 * One row of a keyword dispatch table: it tells the config parser which
 * structure field a keyword's argument is stored in, and how.
 */
struct _offtable {
  char *name;		/* keyword at the start of a config line */
  int indirect;		/* 1 = single char * field, 2 = NULL-terminated char * list */
  int offset;		/* byte offset of that field inside the target struct */
};
typedef struct _offtable offtable;

/*
 * One named restriction (a key type) read from the config file.
 *
 * A client may use this type when it matches at least one "plus" pattern
 * (by IP or by host name) and none of the "minus" patterns.  Each pattern
 * list is terminated by the first NULL slot; the extra slot guarantees that
 * a full list still ends in NULL.
 */
struct _restriction {
  char *type;			/* key type name this entry governs */
  char *failure;		/* message returned when access is denied; may be NULL */
  char *plusip[MAXLIST+1];	/* IP regexes that grant access */
  char *plusname[MAXLIST+1];	/* host-name regexes that grant access */
  char *minusip[MAXLIST+1];	/* IP regexes that deny access */
  char *minusname[MAXLIST+1];	/* host-name regexes that deny access */
};
typedef struct _restriction Restriction;

/*
 * roffset(field) -- byte offset of `field` within a Restriction, as the
 * int stored in offtable.offset.  This uses the standard offsetof() rather
 * than the null-pointer arithmetic trick, because that trick is undefined
 * behaviour and is not guaranteed to be a constant expression usable in a
 * static initializer such as restable.
 */
#define roffset(field) ((int) offsetof(Restriction, field))

/*
 * Keyword dispatch table for the restrictions config file, used by
 * addResInfo().  A config line that starts with `name` followed by
 * whitespace stores its argument into the Restriction field at `offset`:
 *   indirect == 1 : the field is a single char *; the value is stored there
 *                   (overwriting any earlier value);
 *   indirect == 2 : the field is a NULL-terminated char *[MAXLIST+1] list;
 *                   the value is appended.
 * The table ends with an entry whose name is NULL.
 */
offtable restable[] = {
  { "restriction", 1, roffset(type) },
  { "failure", 1, roffset(failure) },
  { "+ip", 2, roffset(plusip) },
  { "-ip", 2, roffset(minusip) },
  { "+name", 2, roffset(plusname) },
  { "-name", 2, roffset(minusname) },
  { NULL, 0, 0 }
};

/*
 * The active restrictions table: an array of Restriction entries that ends
 * with an entry whose type is NULL.  It is NULL until loadRestrictions()
 * first succeeds, and is replaced (the old table being freed) on each later
 * successful load.
 */
Restriction *restrictions = NULL;

/*
 * freeResList -- free every string in one Restriction pattern list.  The
 * list ends at its first NULL slot and holds at most MAXLIST strings.
 */
static void freeResList(list)
     char **list;
{
  int i;

  for (i = 0; i < MAXLIST && list[i] != NULL; i++)
    free(list[i]);
}

/*
 * freeResInfo -- release a restrictions table built by loadRestrictions().
 *
 * res points to an array of Restriction entries that ends with an entry
 * whose type is NULL.  For each entry before that one, the type, the
 * failure message and all four pattern lists are freed.  The array itself
 * is freed last.  A NULL res is ignored.
 *
 * Each list is walked by its own NULL terminator, so a +name or -name list
 * that is longer than +ip is still freed completely.
 */
void freeResInfo(res)
     Restriction *res;
{
  Restriction *ptr;

  if (res == NULL)
    return;

  for (ptr = res; ptr->type != NULL; ptr++)
    {
      free(ptr->type);
      free(ptr->failure);	/* optional; free(NULL) is a no-op */
      freeResList(ptr->plusip);
      freeResList(ptr->minusip);
      freeResList(ptr->plusname);
      freeResList(ptr->minusname);
    }
  free(res);
}

/*
 * addResInfo -- apply one config line to a Restriction entry.
 *
 * string is a raw line as read by fgets().  It must start with one of the
 * keywords in restable, followed by whitespace and a non-empty argument.
 * Trailing whitespace, including the newline and any '\r' from a CRLF
 * file, is stripped from the argument.  The argument is then copied into
 * the field that restable selects: it replaces a single-valued field
 * (freeing the previous value) or is appended to a pattern list.
 *
 * line is the 1-based line number, used only in diagnostics.
 *
 * Returns 0 on success.  On any error it returns -1 after printing a
 * message to stderr, and nothing in *res is changed.
 */
int addResInfo(res, string, line)
     Restriction *res;
     char *string;
     int line;
{
  offtable *off;
  char *copy, *ptr;
  char **slot, **array;
  size_t keylen, l;
  int i;

  for (off = restable; off->name != NULL; off++)
    {
      keylen = strlen(off->name);
      if (strncmp(string, off->name, keylen) != 0)
	continue;

      /* The keyword must be a whole word: followed by whitespace. */
      ptr = string + keylen;
      if (!isspace((unsigned char) *ptr))
	{
	  fprintf(stderr, "%s: Could not parse line %d\n", pname, line);
	  return -1;
	}
      while (isspace((unsigned char) *ptr))
	ptr++;

      if (*ptr == '\0')
	{
	  fprintf(stderr, "%s: Line %d ended prematurely.\n", pname, line);
	  return -1;
	}

      /* Drop trailing whitespace (the '\n' from fgets(), and a '\r' in
	 CRLF files) so it does not become part of the stored type name or
	 regular expression.  l >= 1 here because *ptr is not a space. */
      l = strlen(ptr);
      while (l > 0 && isspace((unsigned char) ptr[l - 1]))
	l--;

      copy = malloc(l + 1);
      if (copy == NULL)
	{
	  fprintf(stderr, "%s: Could not malloc memory for line %d.\n",
		  pname, line);
	  return -1;
	}
      memcpy(copy, ptr, l);
      copy[l] = '\0';

      switch (off->indirect)
	{
	case 1:
	  /* Single-valued field: a repeated keyword replaces the earlier
	     value, so release that value instead of leaking it. */
	  slot = (char **) ((char *) res + off->offset);
	  free(*slot);
	  *slot = copy;
	  return 0;

	case 2:
	  /* Append to the NULL-terminated list, keeping one NULL slot. */
	  array = (char **) ((char *) res + off->offset);
	  for (i = 0; i < MAXLIST && array[i] != NULL; i++)
	    ;
	  if (i == MAXLIST)
	    {
	      free(copy);
	      fprintf(stderr, "%s: List overflow at line %d.\n",
		      pname, line);
	      return -1;
	    }
	  array[i] = copy;
	  return 0;

	default:
	  free(copy);
	  fprintf(stderr, "%s: Bad offset table - code bug (line %d).\n",
		  pname, line);
	  return -1;
	}
    }

  fprintf(stderr, "%s: Could not parse line %d\n", pname, line);
  return -1;
}

/*
 * loadRestrictions -- (re)load the restrictions table from configFile.
 *
 * File format: one "keyword value" per line (see restable).  Lines that
 * start with '#' and lines that are blank or whitespace-only are ignored.
 * Each "restriction" line after the first starts a new entry.  At most
 * MAXTYPES-1 entries are accepted, because the last slot must stay zeroed
 * as the table terminator.
 *
 * On success the new table replaces the global `restrictions` (freeing the
 * old one) and 0 is returned.  On any error a message is printed, the
 * partially built table is freed, the current table is left untouched,
 * and -1 is returned.
 */
int loadRestrictions(configFile)
     char *configFile;
{
  FILE *res;
  char inbuffer[4096];
  char *s;
  Restriction *work, *ptr;
  int line = 0;

  res = fopen(configFile, "r");
  if (res == NULL)
    {
      fprintf(stderr, "%s: Could not open restrictions file.\n", pname);
      return -1;
    }

  /* calloc zero-fills, so every unused entry (in particular the final
     terminator entry) has type == NULL and empty pattern lists. */
  work = calloc(MAXTYPES, sizeof(Restriction));
  if (work == NULL)
    {
      fclose(res);
      fprintf(stderr, "%s: Could not malloc for restrictions.\n", pname);
      return -1;
    }

  ptr = work;

  while (fgets(inbuffer, sizeof(inbuffer), res))
    {
      line++;

      if (inbuffer[0] == '#')
	continue;

      /* Blank separator lines are not an error. */
      for (s = inbuffer; isspace((unsigned char) *s); s++)
	;
      if (*s == '\0')
	continue;

      if (!strncmp(inbuffer, "restriction", strlen("restriction")))
	{
	  if (ptr->type != NULL)
	    {
	      /* Moving on would consume the terminator slot
		 work[MAXTYPES - 1]; refuse instead of overrunning. */
	      if (ptr - work >= MAXTYPES - 2)
		{
		  fprintf(stderr,
			  "%s: Too many restrictions at line %d (limit %d).\n",
			  pname, line, MAXTYPES - 1);
		  fclose(res);
		  freeResInfo(work);
		  return -1;
		}
	      ptr++;
	    }
	}

      if (addResInfo(ptr, inbuffer, line))
	{
	  fclose(res);
	  freeResInfo(work);
	  return -1;
	}
    }

  fclose(res);
  if (restrictions)
    freeResInfo(restrictions);
  restrictions = work;
  return 0;
}

/*
 * match -- test whether string matches the POSIX extended... no: basic
 * regular expression regexp (compiled with REG_NOSUB | REG_NEWLINE, so
 * '^'/'$' also match at embedded newlines).
 *
 * Returns 1 on a match.  It returns 0 when there is no match, when string
 * is NULL, or when regexp fails to compile; a compile failure is reported
 * on stderr.  The compiled expression is released before returning, so
 * repeated calls do not leak.
 */
int match(regexp, string)
     char *regexp, *string;
{
  regex_t re;
  int err;
  char errbuf[256];

  if (string == NULL)
    return 0;

#ifdef DEBUG
  fprintf(stderr, "%s: match: %s:%s\n", pname, regexp, string);
#endif

  err = regcomp(&re, regexp, REG_NOSUB|REG_NEWLINE);
  if (err)
    {
      /* Nothing to regfree(): a failed regcomp() owns no resources. */
      regerror(err, &re, errbuf, sizeof(errbuf));
      fprintf(stderr, "%s: regcomp(\"%s\") - %s\n",
	      pname, regexp, errbuf);
      return 0;
    }

  /* With nmatch == 0 (and REG_NOSUB) no match array is needed. */
  err = regexec(&re, string, 0, NULL, 0);
  regfree(&re);

  return err == 0;
}

/* Results of checkRestrictions(). */
#define RES_OK 0		/* type known and client allowed */
#define RES_UNKNOWNTYPE 1	/* no restriction entry for this type */
#define RES_NOTALLOWED 2	/* type known but client rejected */

/*
 * matchesAny -- nonzero if string matches any regular expression in the
 * NULL-terminated list (checked in order, stopping at the first match).
 */
static int matchesAny(list, string)
     char **list;
     char *string;
{
  for (; *list != NULL; list++)
    if (match(*list, string))
      return 1;
  return 0;
}

/*
 * checkRestrictions -- decide whether a client may use key type `type`.
 *
 * ipstring is the client's dotted IP address and namestring its host name
 * (either may be NULL, which never matches).  The first table entry whose
 * type equals `type` decides:
 *   - the client must match some +ip or +name pattern, and
 *   - must match no -ip or -name pattern.
 *
 * Returns RES_OK, RES_UNKNOWNTYPE (no such entry, no table loaded, or a
 * NULL type) or RES_NOTALLOWED.  With RES_NOTALLOWED, *errret is set to the
 * entry's failure message, or to "" when the entry has none, so it is
 * always a valid string.
 */
int checkRestrictions(type, ipstring, namestring, errret)
     char *type, *ipstring, *namestring;
     char **errret;
{
  Restriction *res;

  if (restrictions == NULL || type == NULL)
    return RES_UNKNOWNTYPE;

  for (res = restrictions; res->type != NULL; res++)
    {
      if (strcmp(res->type, type) != 0)
	continue;

      /* Short-circuit order mirrors the rules: the positive ACLs first
	 (IP, then name); the negative ACLs only for clients that passed. */
      if ((!matchesAny(res->plusip, ipstring) &&
	   !matchesAny(res->plusname, namestring)) ||
	  matchesAny(res->minusip, ipstring) ||
	  matchesAny(res->minusname, namestring))
	{
	  *errret = (res->failure != NULL) ? res->failure : "";
	  return RES_NOTALLOWED;
	}

      return RES_OK;
    }

  return RES_UNKNOWNTYPE;
}

/*
 * printError -- report and clear every pending error.
 *
 * Each pending error is popped.  Warnings and fatal errors are printed to
 * stderr as "<pname>: <message>"; errors of other severities are discarded
 * silently.  When `die` is nonzero and at least one fatal error was seen,
 * the process exits with status 1.
 */
void printError(die)
     int die;
{
  int nfatal = 0;

  while (Error_Exists)
    {
      int isFatal = (Error_Severity == S_FATAL);

      if (isFatal || Error_Severity == S_WARNING)
	{
	  fprintf(stderr, "%s: ", pname);
	  /* The error's string is used as the format, with Error_Info as
	     its argument. */
	  fprintf(stderr, Error_String(Error), Error_Info);
	  fputc('\n', stderr);
	  nfatal += isFatal;
	}
      Error_Pop();
    }

  if (die && nfatal)
    exit(1);
}

/*
 * E(x)  -- evaluate x; if it is nonzero (failure), print the pending
 *          errors and exit if any of them was fatal.
 * ED(x) -- the same, but never exits.
 * The do { } while (0) wrapper makes each macro a single statement, so
 * "if (c) E(f()); else ..." binds the else to the caller's if rather than
 * to the if inside the macro.
 */
#define E(x) do { if (x) printError(1); } while (0)
#define ED(x) do { if (x) printError(0); } while (0)

/*
 * Signal handlers for SIGHUP and SIGUSR1.  Each one only raises a flag for
 * the main loop to act on, then installs itself again for the next
 * delivery (presumably because signal() may reset the disposition to
 * SIG_DFL on some systems).
 */
void takehup(int signo)
{
  gothup = 1;
  (void) signal(SIGHUP, takehup);
}

void takeusr1(int signo)
{
  gotusr1 = 1;
  (void) signal(SIGUSR1, takeusr1);
}

/*
 * usage -- print the command-line synopsis to stderr and exit with status 1.
 * It never returns, so it is declared void; implicit int return types are
 * not valid since C99.
 */
void usage()
{
  fprintf(stderr, "usage: %s [-logfile filename] [-config filename]\n", pname);
  exit(1);
}
	  
openLog(file)
     char *file;
{
  if (file)
    {
      freopen(file, "a", stdout);
      freopen(file, "a", stderr);
    }

  setbuf(stdout, NULL);
  setbuf(stderr, NULL);
}

main(argc, argv)
     char **argv;
{
  Addr a, b;
  Packet p;
  char from[50], to[50];
  Trap retval;
  char *logFile = NULL;
  char *configFile = "/var/ops/slw/restrictions";

  if (strrchr(argv[0], '/'))
    sprintf(pname, "%s", strrchr(argv[0], '/') + 1);
  else
    sprintf(pname, "%s", argv[0]);

  argv++;
  while (*argv)
    {
      if (!strcmp(*argv, "-logfile"))
	{
	  argv++;
	  if (*argv)
	    logFile = *argv;
	  else
	    usage();
	}
      else
	if (!strcmp(*argv, "-config"))
	  {
	    argv++;
	    if (*argv)
	      configFile = *argv;
	    else
	      usage();
	  }
	else
	  usage();

      argv++;
    }

  if (logFile)
    fprintf(stderr, "%s: Using logfile %s\n", pname, logFile);

  fprintf(stderr, "%s: Using config file %s\n", pname, configFile);

  openLog(logFile);

  E(Connect_Initialize());
  E(Connect_RegisterDomain(&inetudp));

  E(Code_Initialize());
  E(Code_RegisterDomain(&literal));

  sprintf(from, "%s:./%s", PORTDOMAIN, PORTNAME);
  if (Connect_NameToAddress(from, &a))
    {
      if (Error == IP_UNKNOWNSERVICE)
	{
	  fprintf(stderr,
	   "%s: Warning: Service %s not in /etc/services; using port %d\n",
		  pname, PORTNAME, PORTNUMBER);
	  Error_Pop();
	  sprintf(from, "%s:./%d", PORTDOMAIN, PORTNUMBER);
	  E(Connect_NameToAddress(from, &a));
	}
      else
	printError(1);
    }

  sprintf(to, "%s:*", PORTDOMAIN);
  E(Connect_NameToAddress(to, &b));

  E(Connect_OpenConnection(a, b));

  if (loadRestrictions(configFile))
    {
      fprintf(stderr, "%s: could not load restrictions file %s\n", pname,
	      configFile);
      exit(1);
    }

  signal(SIGHUP, takehup);
  signal(SIGUSR1, takeusr1);

  Ip_NameResolve(0);

  while (1)
    {
      if (retval = Connect_Wait(&p, NOTIMEOUT))
	{
	  if ((Error == CONNECT_SELECT) && ((int)Error_Info == EINTR) &&
	      (gothup || gotusr1))
	    Error_Pop();
	  else
	    E(retval);
	}

      if (gothup)
	{
	  gothup = 0;
	  fprintf(stderr, "%s: HUP received; reloading restrictions\n",
		  pname);
	  if (loadRestrictions(configFile))
	    fprintf(stderr,
		    "%s: reload of %s failed, continuing with no changes\n",
		    pname, configFile);
	  if (retval) /* not random signal; interrupted the select */
	    continue;
	}

      if (gotusr1)
	{
	  gotusr1 = 0;
	  fprintf(stderr, "%s: USR1 received; resetting log file\n",
		  pname);
	  openLog(logFile);
	  if (retval) /* not random signal; interrupted the select */
	    continue;
	}

      Connect_AddressToName(p.Source, name, sizeof(name));
#ifdef HUMAN
      strcpy(timestamp, ctime(&p.when));
      timestamp[strlen(timestamp)-1] = '\0';
#else
      sprintf(timestamp, "%d", p.when);
#endif
      if (verbose)
	fprintf(stdout, "From: %d - %s at %s\n", p.Source, name, timestamp);

      processRequest(&p);
    }
}

int processRequest(p)
     Packet *p;
{
  static CodeBlock rblock;
  static char rdata[512];
  static int rinitialized = 0;
  CodeBlock *block;
  Card16 proto, reqtype;
  Card32 keygen;
  char *system, *program, *vendor, *version, *restriction, *path;
  char scratch[160];
  int delta;
  des_cblock key;
  struct in_addr netip;
  char netipname[30];
  char hostname[50], *h, *i;
  Addr a;
  int reserror, err = 0;
  char *reserrorstring;

  if (!rinitialized)
    {
      E(Code_InitBlock(&rblock, CODEDOMAIN, rdata, sizeof(rdata)));
      rinitialized = 1;
    }
  else
    Code_ResetBlock(&rblock);

  E(Code_CreateBlock(&block, CODEDOMAIN, p->packet, p->length));

  if (Code_GetCard16(block, &proto))
    {
    badpacket:
      Code_PutCard16(&rblock, REPLY_ERROR);
      Code_PutCard16(&rblock, err = ERROR_BADPACKET);
      Code_PutString(&rblock, "");
      goto sendrblock;
    }

  if (proto != PROTOVERSION0)
    {
      Code_PutCard16(&rblock, REPLY_ERROR);
      Code_PutCard16(&rblock, err = ERROR_PROTOUNSUPPORTED);
      Code_PutString(&rblock, "");
      goto sendrblock;
    }

  if (Code_GetCard16(block, &reqtype))
    goto badpacket;

  switch(reqtype)
    {
    case REQUEST_REGISTRATION:
    case REQUEST_KEY:
      if (Code_GetString(block, &system) ||
	  Code_GetString(block, &program) ||
	  Code_GetString(block, &version) ||
	  Code_GetString(block, &vendor) ||
	  Code_GetString(block, &restriction) ||
	  Code_GetString(block, &path) ||
	  Code_GetCard32(block, &keygen))
	goto badpacket;

      if (verbose)
	{
	  if (reqtype == REQUEST_REGISTRATION)
	    fprintf(stdout, "Registration request for ");
	  else
	    fprintf(stdout, "Key request for ");

	  fprintf(stdout, "{%s,%s,%s,%s,%s,%s} %d\n", system, program,
		  version, vendor, restriction, path, keygen);
	}

      fprintf(stdout, "%s,%s,%s,%s,%s\n",
	      requests[reqtype], name, program, restriction, timestamp);

      Udp_IP(p->Source, &netip.s_addr);
      strcpy(netipname,inet_ntoa(netip));

      /*
       * Name returned is of form domain:iphostname/portnum; strip to
       * iphostname for comparison.
       */
      strcpy(hostname, name);
      h = strrchr(hostname, '/');
      if (h != NULL)
	*h = '\0';
      h = strchr(hostname, ':');
      if (h != NULL)
	h++;
      else
	h = hostname;
      for (i = h; *i != '\0'; i++)
	*i = tolower(*i);

      if (reserror = checkRestrictions(restriction,
				       netipname, h,
				       &reserrorstring))
	{
	  Code_PutCard16(&rblock, REPLY_ERROR);
	  if (reserror == RES_UNKNOWNTYPE)
	    {
	      Code_PutCard16(&rblock, err = ERROR_UNKNOWNKEYTYPE);
	      sprintf(scratch, "Key type passed: \"%s\"", restriction);
	      Code_PutString(&rblock, scratch);
	    }
	  if (reserror == RES_NOTALLOWED)
	    {
	      Code_PutCard16(&rblock, err = ERROR_PERMISSIONDENIED);
	      Code_PutString(&rblock, reserrorstring);
	    }
	  goto sendrblock;
	}

      if (abs(p->when - keygen) > TIMERANGE) /* XXX - int sizes? */
	{
	  Code_PutCard16(&rblock, REPLY_ERROR);
	  Code_PutCard16(&rblock, err = ERROR_BADKEYGENERATOR);
	  Code_PutString(&rblock, "");
	  goto sendrblock;
	}

      gimme_key_for_prog(key, keygen, system, program,
			 version, vendor, restriction);

      if (reqtype == REQUEST_REGISTRATION)
	Code_PutCard16(&rblock, REPLY_REGISTRATION);
      else
	Code_PutCard16(&rblock, REPLY_KEY);
      Code_PutMemory(&rblock, key, sizeof(key));
      goto sendrblock;
      break;

    case REQUEST_OPEN:
    case REQUEST_CLOSE:
      fprintf(stdout, "%s,%s,,,%s\n",
	      requests[reqtype], name, timestamp);
      if (reqtype == REQUEST_OPEN)
	Code_PutCard16(&rblock, REPLY_OPEN);
      else
	Code_PutCard16(&rblock, REPLY_CLOSE);
      goto sendrblock;
      break;

    case REQUEST_PING:
      Code_PutCard16(&rblock, REPLY_PING);
      goto sendrblock;
      break;

    case REQUEST_DEREGISTRATION:
      Code_PutCard16(&rblock, REPLY_DEREGRISTRATION);
      goto sendrblock;
      break;

    default:
      Code_PutCard16(&rblock, REPLY_ERROR);
      Code_PutCard16(&rblock, err = ERROR_BADREQUEST);
      Code_PutString(&rblock, "");
      goto sendrblock;
      break;
    }

 sendrblock:

  E(Code_FreeBlock(block));
  free(p->packet);

  a = p->Destination;
  p->Destination = p->Source;
  p->Source = a;
  p->packet = Code_BlockData(&rblock);
  p->length = Code_BlockLength(&rblock);
  E(Connect_SendPacket(p));
  E(Code_FreeBlock(block));
  E(Connect_CloseConnection(p->Source, p->Destination));
  if (err)
    fprintf(stdout, "error,%s,%s,%s\n",
	    name, errors[err], timestamp);
}
