#include <sys/types.h>
#include <sys/param.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netdb.h>

#include "strerr.h"
#include "substdio.h"
#include "stralloc.h"
#include "alloc.h"
#include "readwrite.h"
#include "fd.h"
#include "sig.h"
#include "wait.h"
#include "ip.h"
#include "ipalloc.h"
#include "dns.h"
#include "str.h"
#include "case.h"
#include "byte.h"
#include "sgetopt.h"
#include "remoteinfo.h"
#include "exit.h"
#include "open.h"
#include "scan.h"
#include "fmt.h"
#include "env.h"
#include "cdb.h"

/* Message prefixes: FATAL for errors that stop the server during setup,
   DROP for errors that abandon a single connection. */
#define FATAL "tcpserver: fatal: "
#define DROP "tcpserver: warning: dropping connection, "
/* Message level: 0 = silent (-q), 1 = errors only (-Q, default),
   2 = also log pids, connections and child exits (-v). */
int verbosity = 1;

/* Out of memory during setup: report it (unless -q) and exit 111. */
void die_nomem()
{
  if (!verbosity) _exit(111);
  strerr_warn2(FATAL,"out of memory",0);
  _exit(111);
}
/* Out of memory while handling one connection: report it (unless -q)
   and exit 111, dropping the connection. */
void drop_nomem()
{
  if (!verbosity) _exit(111);
  strerr_warn2(DROP,"out of memory",0);
  _exit(111);
}
/* Print the command-line synopsis (unless -q) and exit 100. */
void usage()
{
  if (verbosity)
    strerr_warn1("tcpserver: usage: "
                 "tcpserver [ -qQvdDoOpPhHrR1 ] "
                 "[ -xrules.cdb ] "
                 "[ -bbacklog ] [ -climit ] [ -ttimeout ] [ -llocalname ] [ -ggid ] [ -uuid ] "
                 "host port program",0);
  _exit(100);
}

/*
 * Append the NUL-terminated string s to sa for logging.
 *
 * Every byte outside printable ASCII (33..126) is replaced with '?', and
 * so are '%' and ':'.  ':' is masked because printenv() uses it as a field
 * separator.  '%' is masked because of log processors that mishandle it
 * (the original note said "logger stupidity").
 *
 * Exits via drop_nomem() if sa cannot grow.
 */
void safeappend(sa,s)
stralloc *sa;
char *s;
{
  char c;

  for (;;) {
    c = *s++;
    if (!c) break;
    /* high-bit bytes fail the < 33 test when char is signed and the
       > 126 test when char is unsigned */
    if ((c < 33) || (c > 126) || (c == '%') || (c == ':')) c = '?';
    if (!stralloc_append(sa,&c)) drop_nomem();
  }
}

/* fmt_ulong() output buffers shared by printpid(), printenv() and
   sigchld(); printpid() and sigchld() NUL-terminate what they write */
char strnum[FMT_ULONG];
char strnum2[FMT_ULONG];
/* general scratch string: DNS names, rules lookup keys, the log line */
stralloc tmp = {0};
/* address list filled by dns_ip() (listen host, and the -p check) */
ipalloc ia = {0};

/* children forked and not yet reaped by sigchld() */
unsigned long numchildren = 0;

/* remote end of the current connection (filled in by the child) */
char tcpremoteip[IPFMT];    /* dotted IP text, exported as TCPREMOTEIP */
char tcpremoteport[IPFMT];  /* decimal port text, exported as TCPREMOTEPORT */
char *tcpremoteinfo;        /* remoteinfo_get() result, or 0 if none */
struct sockaddr_in saremote;
struct ip_address ipremote;
unsigned long portremote;

/* local end: the listening address in the parent, the connection's
   local address in the child */
char tcplocalip[IPFMT];     /* dotted IP text, exported as TCPLOCALIP */
char tcplocalport[IPFMT];   /* decimal port text, exported as TCPLOCALPORT */
struct sockaddr_in salocal;
struct ip_address iplocal;
unsigned long portlocal;

/* -x rules: open cdb descriptor, file name (0 = no rules), and the
   deny flag set by a 'D' entry */
int fdrules;
char *fnrules = 0;
int flagdeny = 0;

void printenv()
{
  char *tcplocalhost;
  char *tcpremotehost;

  if (verbosity < 2) return;

  tcplocalhost = env_get("TCPLOCALHOST");
  tcpremotehost = env_get("TCPREMOTEHOST");

  if (!tcplocalhost) tcplocalhost = "";
  if (!tcpremotehost) tcpremotehost = "";

  if (!stralloc_copys(&tmp,"tcpserver: ")) drop_nomem();
  if (!stralloc_cats(&tmp,flagdeny ? "deny " : "ok ")) drop_nomem();
  if (!stralloc_catb(&tmp,strnum,fmt_ulong(strnum,getpid()))) drop_nomem();
  if (!stralloc_cats(&tmp," ")) drop_nomem();
  safeappend(&tmp,tcplocalhost);
  if (!stralloc_cats(&tmp,":")) drop_nomem();
  safeappend(&tmp,tcplocalip);
  if (!stralloc_cats(&tmp,":")) drop_nomem();
  safeappend(&tmp,tcplocalport);
  if (!stralloc_cats(&tmp," ")) drop_nomem();
  safeappend(&tmp,tcpremotehost);
  if (!stralloc_cats(&tmp,":")) drop_nomem();
  safeappend(&tmp,tcpremoteip);
  if (!stralloc_cats(&tmp,":")) drop_nomem();
  safeappend(&tmp,tcpremoteinfo ? tcpremoteinfo : "");
  if (!stralloc_cats(&tmp,":")) drop_nomem();
  safeappend(&tmp,tcpremoteport);
  if (!stralloc_0(&tmp)) drop_nomem();

  strerr_warn1(tmp.s,0);
}

/* With -v, log this child's pid, the child count it inherited from the
   parent at fork time, and the remote IP. */
void printpid()
{
  if (verbosity >= 2) {
    strnum[fmt_ulong(strnum,(unsigned long) getpid())] = 0;
    strnum2[fmt_ulong(strnum2,numchildren)] = 0;
    strerr_warn6("tcpserver: pid ",strnum," num ",strnum2," from ",tcpremoteip,0);
  }
}

char printbuf[16];
struct substdio print = SUBSTDIO_FDBUF(write,1,printbuf,sizeof(printbuf));

/* -1 option: write the bound local port number and a newline to fd 1.
   Write errors are ignored. */
void printlocalport()
{
  substdio_puts(&print,tcplocalport);
  substdio_put(&print,"\n",1);
  substdio_flush(&print);
}

/* SIGTERM handler: exit immediately with status 0. */
void sigterm()
{
  _exit(0);
}

/*
 * SIGCHLD handler: reap every child that has already exited.
 *
 * Each reaped child decrements numchildren (never below 0).  With -v, the
 * child's pid and raw wait status are logged.
 */
void sigchld()
{
  int status;
  int child;

  for (;;) {
    child = wait_nohang(&status);
    if (child <= 0) break;
    if (numchildren > 0) numchildren -= 1;
    if (verbosity < 2) continue;
    strnum[fmt_ulong(strnum,(unsigned long) child)] = 0;
    strnum2[fmt_ulong(strnum2,(unsigned long) status)] = 0;
    strerr_warn4("tcpserver: end ",strnum," status ",strnum2,0);
  }
}

/* Cannot open or read the -x rules file: report it with the system error
   (unless -q) and exit 111, dropping the connection. */
void drop_rules()
{
  if (!verbosity) _exit(111);
  strerr_warn4(DROP,"unable to read ",fnrules,": ",&strerr_sys);
  _exit(111);
}

/*
 * Look up the key in tmp (tmp.s, tmp.len bytes) in the open rules file
 * fdrules.
 *
 * Returns 0 when the key is absent.  Otherwise the record is read and its
 * NUL-terminated entries are applied:
 *   "D..."  sets flagdeny;
 *   "+..."  passes the text after '+' to env_put() (presumably NAME=value);
 * Other entries, and trailing bytes with no final NUL, are ignored.
 * Returns 1 in that case.
 *
 * Exits via drop_rules() on a lookup/read error and via drop_nomem() when
 * out of memory.
 *
 * NOTE(review): the record buffer is never freed.  rules() stops at the
 * first match, so this leaks once per child.  Freeing it would first
 * need confirmation that env_put() copies its argument.
 */
int dorule()
{
  char *entry;
  uint32 len32;
  unsigned int left;
  unsigned int end;
  int found;

  found = cdb_seek(fdrules,tmp.s,tmp.len,&len32);
  if (found == -1) drop_rules();   /* does not return */
  if (found == 0) return 0;

  left = len32;
  entry = alloc(left);
  if (!entry) drop_nomem();
  if (cdb_bread(fdrules,entry,left) != 0) drop_rules();

  for (;;) {
    end = byte_chr(entry,left,0);
    if (end >= left) break;          /* no complete entry left */
    if (entry[0] == 'D')
      flagdeny = 1;
    else if (entry[0] == '+') {
      if (!env_put(entry + 1)) drop_nomem();
    }
    entry += end + 1;
    left -= end + 1;
  }
  return 1;
}

/*
 * Apply the -x rules database, if one was given, to the current connection.
 *
 * Keys are tried in this order, stopping at the first one that is present:
 *   1. REMOTEINFO@REMOTEIP   (only when remoteinfo was obtained)
 *   2. REMOTEIP, e.g. "1.2.3.4"
 *   3. each dotted prefix of REMOTEIP, longest first: "1.2.3.", "1.2.", "1."
 *   4. the empty key
 * A matching record is applied by dorule().  The rules file is closed
 * afterwards.  Exits via drop_rules() if the file cannot be opened.
 */
void rules()
{
  int found = 0;

  if (!fnrules) return;

  fdrules = open_read(fnrules);
  if (fdrules == -1) drop_rules();

  if (tcpremoteinfo) {
    if (!stralloc_copys(&tmp,tcpremoteinfo)) drop_nomem();
    if (!stralloc_cats(&tmp,"@")) drop_nomem();
    if (!stralloc_cats(&tmp,tcpremoteip)) drop_nomem();
    found = dorule();
  }

  if (!found) {
    if (!stralloc_copys(&tmp,tcpremoteip)) drop_nomem();
    found = dorule();
  }

  /* shrink the key one byte at a time and look up every prefix that
     ends with a dot */
  while (!found && (tmp.len > 0)) {
    if (tcpremoteip[tmp.len - 1] == '.') found = dorule();
    if (!found) --tmp.len;
  }

  /* tmp.len is 0 here: the empty key is the catch-all */
  if (!found) dorule();

  close(fdrules);
}

/* Command-line settings, set by the getopt switch in main(). */
int flagkillopts = 1;   /* -O (default) / -o: reset IP_OPTIONS on each accepted socket */
int flagdelay = 1;      /* -d (default) / -D: -D sets TCP_NODELAY on each accepted socket */
int flagremoteinfo = 1; /* -r (default) / -R: look up TCPREMOTEINFO */
int flagremotehost = 1; /* -h (default) / -H: look up TCPREMOTEHOST */
int flagparanoid = 0;   /* -p / -P (default): -p keeps TCPREMOTEHOST only if the name resolves back to the remote IP */
int flag1 = 0;          /* -1: print the bound local port on fd 1 */
unsigned long backlog = 20;  /* -b: listen() backlog */
unsigned long timeout = 26;  /* -t: timeout passed to remoteinfo_get() (units per that function, presumably seconds) */
unsigned long uid = 0;       /* -u: setuid() after binding; 0 = leave unchanged */
unsigned long gid = 0;       /* -g: setgid() after binding; 0 = leave unchanged */
char *forcelocal = 0;        /* -l: fixed TCPLOCALHOST instead of a DNS lookup */
unsigned long limit = 40;    /* -c: maximum number of simultaneous children */

/*
 * tcpserver: accept TCP connections and run a program for each one.
 *
 *   tcpserver [ options ] host port program [ arg ... ]
 *
 * Setup:
 *   - Parse the options.
 *   - Resolve host: a dotted IP, "" or "0" for 0.0.0.0, or a name looked up
 *     with dns_ip(), using the first address.
 *   - Resolve port: a decimal number 0..65535 or a service name known to
 *     getservbyname().
 *   - socket/bind/listen, then switch to the -g gid and -u uid if given.
 *   - Remove TCPLOCALHOST, TCPREMOTEHOST and TCPREMOTEINFO from the
 *     environment, then set PROTO=TCP, TCPLOCALPORT and (with -l)
 *     TCPLOCALHOST.
 *
 * Main loop: wait while `limit' children are running, accept a connection
 * and fork.  The child:
 *   - exports TCPREMOTEIP, TCPLOCALIP and TCPREMOTEPORT, and, when the
 *     lookups succeed, TCPLOCALHOST, TCPREMOTEHOST and TCPREMOTEINFO;
 *   - applies the -x rules, exiting 100 if they deny the connection;
 *   - otherwise execs program with the connection on descriptors 0 and 1.
 *
 * Setup failures exit 111 (100 for usage errors and gid/uid failures).
 * main never returns; it is declared int because standard C requires it.
 */
int main(argc,argv)
int argc;
char **argv;
{
  int s;
  int t;
  int dummy;
  int opt;
  char *hostname;
  char *portname;
  struct servent *se;
  int j;

  while ((opt = getopt(argc,argv,"dDvqQhHrR1x:t:u:g:l:b:c:pPoO")) != opteof)
    switch(opt) {
      case 'b': scan_ulong(optarg,&backlog); break;
      case 'c': scan_ulong(optarg,&limit); break;
      case 'x': fnrules = optarg; break;
      case 'd': flagdelay = 1; break;
      case 'D': flagdelay = 0; break;
      case 'v': verbosity = 2; break;
      case 'q': verbosity = 0; break;
      case 'Q': verbosity = 1; break;
      case 'P': flagparanoid = 0; break;
      case 'p': flagparanoid = 1; break;
      case 'O': flagkillopts = 1; break;
      case 'o': flagkillopts = 0; break;
      case 'H': flagremotehost = 0; break;
      case 'h': flagremotehost = 1; break;
      case 'R': flagremoteinfo = 0; break;
      case 'r': flagremoteinfo = 1; break;
      case 't': scan_ulong(optarg,&timeout); break;
      case 'g': scan_ulong(optarg,&gid); break;
      case 'u': scan_ulong(optarg,&uid); break;
      case '1': flag1 = 1; break;
      case 'l': forcelocal = optarg; break;
      default: usage();
    }
  argc -= optind;
  argv += optind;

  /* remaining arguments: host port program [ arg ... ] */
  hostname = *argv++;
  if (!hostname) usage();
  portname = *argv++;
  if (!portname) usage();
  if (!*argv) usage();

  sig_pipeignore();
  sig_termcatch(sigterm);
  sig_childcatch(sigchld);

  dns_init(1);

  byte_zero(&salocal,sizeof(salocal));
  salocal.sin_family = AF_INET;

  /* port: all digits means a port number, anything else a service name */
  if (!portname[scan_ulong(portname,&portlocal)]) {
    /* the cast below would silently truncate, binding some other port */
    if (portlocal > 65535) {
      if (verbosity) strerr_warn3(FATAL,"port number out of range: ",portname,0);
      _exit(111);
    }
    salocal.sin_port = htons((unsigned short) portlocal);
  }
  else {
    se = getservbyname(portname,"tcp");
    if (!se) {
      if (verbosity) strerr_warn3(FATAL,"unable to figure out number for port ",portname,0);
      _exit(111);
    }
    salocal.sin_port = se->s_port;  /* already in network byte order */
  }

  /* "" and "0" both mean any local address */
  if (str_equal(hostname,"")) hostname = "0.0.0.0";
  if (str_equal(hostname,"0")) hostname = "0.0.0.0";

  /* not a dotted IP: resolve the name and use its first address */
  if (hostname[ip_scan(hostname,&iplocal)]) {
    if (!stralloc_copys(&tmp,hostname)) die_nomem();
    switch(dns_ip(&ia,&tmp)) {
      case DNS_MEM:
        die_nomem();
      case DNS_HARD:
        if (verbosity) strerr_warn3(FATAL,"unable to figure out address for host ",hostname,0);
        _exit(111);
      case DNS_SOFT:
        if (verbosity) strerr_warn3(FATAL,"temporarily unable to figure out address for host ",hostname,0);
        _exit(111);
    }
    if (!ia.len) {
      if (verbosity) strerr_warn3(FATAL,"no IP addresses for host ",hostname,0);
      _exit(111);
    }
    byte_copy(&iplocal,4,&ia.ix[0].ip);
  }

  byte_copy(&salocal.sin_addr,4,&iplocal);
  s = socket(AF_INET,SOCK_STREAM,0);
  if (s == -1) {
    if (verbosity) strerr_warn2(FATAL,"unable to create socket: ",&strerr_sys);
    _exit(111);
  }

  {
    int on = 1;
    setsockopt(s,SOL_SOCKET,SO_REUSEADDR,&on,sizeof(on)); /* best effort */
  }

  if (bind(s,(struct sockaddr *) &salocal,sizeof(salocal)) == -1) {
    if (verbosity) strerr_warn2(FATAL,"unable to bind: ",&strerr_sys);
    _exit(111);
  }
  /* re-read the bound address: with port 0 the kernel chooses the port */
  dummy = sizeof(salocal);
  if (getsockname(s,(struct sockaddr *) &salocal,&dummy) == -1) {
    if (verbosity) strerr_warn2(FATAL,"unable to get local address: ",&strerr_sys);
    _exit(111);
  }
  if (listen(s,backlog) == -1) {
    if (verbosity) strerr_warn2(FATAL,"unable to listen: ",&strerr_sys);
    _exit(111);
  }

  /* set the gid first: after setuid() to a non-root uid it may be refused */
  if (gid) if (setgid(gid) == -1) {
    if (verbosity) strerr_warn2(FATAL,"unable to set gid: ",&strerr_sys);
    _exit(100);
  }
  if (uid) if (setuid(uid) == -1) {
    if (verbosity) strerr_warn2(FATAL,"unable to set uid: ",&strerr_sys);
    _exit(100);
  }

  if (!env_init()) die_nomem();
  if (!env_put("PROTO=TCP")) die_nomem();
  if (!env_unset("TCPLOCALHOST")) die_nomem();
  if (!env_unset("TCPREMOTEHOST")) die_nomem();
  if (!env_unset("TCPREMOTEINFO")) die_nomem();

  if (forcelocal)
    if (!env_put2("TCPLOCALHOST",forcelocal)) die_nomem();

  portlocal = ntohs(salocal.sin_port);
  tcplocalport[fmt_ulong(tcplocalport,portlocal)] = 0;
  if (!env_put2("TCPLOCALPORT",tcplocalport)) die_nomem();
  if (flag1) printlocalport();

  /* the parent does not use fds 0 and 1 after this point (the -1 report
     has already been flushed); each child gets its connection there */
  close(0);
  close(1);

  /* SIGCHLD is unblocked only while waiting for a free slot and around
     accept(), so sigchld() never runs while the parent forks and
     increments numchildren */
  sig_childblock();

  for (;;) {
    while (numchildren >= limit) sig_pause();  /* presumably waits with SIGCHLD deliverable */
    sig_childunblock();

    dummy = sizeof(saremote);
    t = accept(s,(struct sockaddr *) &saremote,&dummy);
    if (t == -1) continue;  /* e.g. interrupted by SIGCHLD: try again */
    portremote = ntohs(saremote.sin_port);
    byte_copy(&ipremote,4,&saremote.sin_addr);

    sig_childblock();

    switch(fork()) {
      default:
        ++numchildren;
        break;
      case -1:
        if (verbosity) strerr_warn2(DROP,"unable to fork: ",&strerr_sys);
        break;
      case 0:
        tcpremoteip[ip_fmt(tcpremoteip,&ipremote)] = 0;
        printpid();

        close(s);
        if (flagkillopts) {
          setsockopt(t,IPPROTO_IP,1,(char *) 0,0); /* 1 == IP_OPTIONS */
          /* if it fails, bummer */
        }
        if (!flagdelay) {
          int on = 1;
          setsockopt(t,IPPROTO_TCP,1,&on,sizeof(on)); /* 1 == TCP_NODELAY */
          /* if it fails, bummer */
        }

        /* the local address of this connection, which may differ from
           the listening address when bound to 0.0.0.0 */
        dummy = sizeof(salocal);
        if (getsockname(t,(struct sockaddr *) &salocal,&dummy) == -1) {
          if (verbosity) strerr_warn2(DROP,"unable to get local address: ",&strerr_sys);
          _exit(111);
        }
        byte_copy(&iplocal,4,&salocal.sin_addr);

        tcplocalip[ip_fmt(tcplocalip,&iplocal)] = 0;
        tcpremoteport[fmt_ulong(tcpremoteport,portremote)] = 0;

        if (!env_put2("TCPREMOTEIP",tcpremoteip)) drop_nomem();
        if (!env_put2("TCPLOCALIP",tcplocalip)) drop_nomem();
        if (!env_put2("TCPREMOTEPORT",tcpremoteport)) drop_nomem();

        if (!forcelocal)
          switch(dns_ptr(&tmp,&iplocal)) {
            case DNS_MEM: drop_nomem();  /* does not return */
            case 0:
              if (!stralloc_0(&tmp)) drop_nomem();
              case_lowers(tmp.s);
              if (!env_put2("TCPLOCALHOST",tmp.s)) drop_nomem();
          }

        if (flagremotehost)
          switch(dns_ptr(&tmp,&ipremote)) {
            case DNS_MEM: drop_nomem();  /* does not return */
            case 0:
              /* -p: keep the name only if it resolves back to the remote IP */
              if (flagparanoid) {
                if (dns_ip(&ia,&tmp) != 0) break;
                for (j = 0;j < ia.len;++j)
                  if (!byte_diff(&ipremote,4,&ia.ix[j].ip))
                    break;
                if (j == ia.len)
                  break;
              }
              if (!stralloc_0(&tmp)) drop_nomem();
              case_lowers(tmp.s);
              if (!env_put2("TCPREMOTEHOST",tmp.s)) drop_nomem();
          }
        if (flagremoteinfo) {
          tcpremoteinfo = remoteinfo_get(&ipremote,portremote,&iplocal,portlocal,(int) timeout);
          if (tcpremoteinfo)
            if (!env_put2("TCPREMOTEINFO",tcpremoteinfo)) drop_nomem();
        }

        rules();
        printenv();
        if (flagdeny) _exit(100);

        /* connection on fd 0 (read) and fd 1 (write) for the program */
        if ((fd_move(0,t) == -1) || (fd_copy(1,0) == -1)) {
          if (verbosity) strerr_warn2(DROP,"unable to set up descriptors: ",&strerr_sys);
          _exit(111);
        }
        sig_childunblock();
        sig_childdefault();
        sig_termdefault();
        sig_pipedefault();
        execvp(*argv,argv);
        if (verbosity) strerr_warn4(DROP,"unable to run ",*argv,": ",&strerr_sys);
        _exit(111);
    }
    close(t);  /* the parent's copy; the child has its own */
  }
}
