#include <sys/types.h>
#include <sys/param.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <netdb.h>
#include <unistd.h>

#include "substdio.h"
#include "stralloc.h"
#include "str.h"
#include "byte.h"
#include "error.h"
#include "sig.h"
#include "subfd.h"
#include "fd.h"
#include "ip.h"
#include "ipalloc.h"
#include "case.h"
#include "sgetopt.h"
#include "exit.h"
#include "scan.h"
#include "fmt.h"
#include "env.h"
#include "dns.h"
#include "remoteinfo.h"

/* Diagnostic level: 0 = print nothing (-q), 1 = default (-Q),
   2 = additionally announce successful connects (-v). */
int verbosity = 1;
/* Scratch text buffer for ip_fmt()/fmt_ulong() output plus a NUL;
   sized FMT_ULONG + IPFMT. */
char temp[FMT_ULONG + IPFMT];

/* Write s to stderr unless diagnostics are silenced (-q). */
void out(char *s)
{
  if (!verbosity) return;
  substdio_puts(subfderr,s);
}
/* Push buffered stderr output, unless diagnostics are silenced (-q). */
void flush(void)
{
  if (verbosity == 0) return;
  substdio_flush(subfderr);
}

/*
 * Print the command-line synopsis to stderr (suppressed by -q) and
 * exit 100.  The synopsis lists every option main() accepts, including
 * -h/-H (TCPREMOTEHOST lookup on/off).
 */
void usage(void)
{
  out("tcpclient: usage: "
      "tcpclient [ -hHqQvdDrRU ] [ -plocport ] [ -llocalname ] [ -ttimeoutinfo ] "
      "[ -Ttimeoutconn ] host port program\n");
  flush();
  _exit(100);
}
/* Report a fatal error message s and exit 111. */
void die(char *s)
{
  out("tcpclient: fatal: ");
  out(s);
  out("\n");
  flush();
  _exit(111);
}
/* Like die(), but append the description of the current errno; exits 111. */
void diep(char *s)
{
  char *reason;

  /* take the errno text first, before any output call may disturb errno */
  reason = error_str(errno);
  out("tcpclient: fatal: ");
  out(s);
  out(": ");
  out(reason);
  out("\n");
  flush();
  _exit(111);
}
/* Allocation failure: fatal, via die(). */
void nomem(void)
{
  die("out of memory");
}
/* Non-fatal: report that connecting to *ipr failed, with the errno text. */
void warnconnect(struct ip_address *ipr)
{
  char *reason = error_str(errno);
  unsigned int n;

  n = ip_fmt(temp,ipr);
  temp[n] = 0;
  out("tcpclient: unable to connect to ");
  out(temp);
  out(": ");
  out(reason);
  out("\n");
  flush();
}
/* With -v (verbosity 2 or more), announce the address that was reached. */
void infoconnected(struct ip_address *ipr)
{
  if (verbosity >= 2) {
    temp[ip_fmt(temp,ipr)] = 0;
    out("tcpclient: connected to ");
    out(temp);
    out("\n");
    flush();
  }
}

/* Option settings; defaults below, changed by the getopt loop in main(). */
int flagdelay = 1;            /* -d/-D: 0 means set TCP_NODELAY on the socket */
int flagremoteinfo = 1;       /* -r/-R: query remoteinfo_get() for TCPREMOTEINFO */
int flagremotehost = 1;       /* -h/-H: reverse-resolve the peer for TCPREMOTEHOST */
char *forcelocal = 0;         /* -l: TCPLOCALHOST value used instead of a DNS lookup */
unsigned long timeout = 26;   /* -t: timeout handed to remoteinfo_get() (units presumably seconds) */
unsigned long timeout2 = 60;  /* -T: timeout handed to timeoutconn() (units presumably seconds) */

/* DNS work areas: tmp holds the name to resolve or a PTR answer,
   ia receives the addresses returned by dns_ip(). */
stralloc tmp = {0};
ipalloc ia = {0};

/* The two connection endpoints, filled in by main(). */
struct sockaddr_in salocal;
struct sockaddr_in saremote;
struct ip_address iplocal;
struct ip_address ipremote;
unsigned long portlocal;
unsigned long portremote;

/*
 * tcpclient [ options ] host port program [ arg ... ]
 *
 * Connects to host:port, describes the connection in the environment
 * (PROTO, TCPLOCALPORT, TCPLOCALIP, TCPLOCALHOST, TCPREMOTEPORT,
 * TCPREMOTEIP, TCPREMOTEHOST, TCPREMOTEINFO), and execs program with
 * the socket on descriptor 6 and a duplicate on descriptor 7.
 * Does not return: usage errors exit 100, other failures exit 111.
 */
int main(int argc, char **argv)
{
  int s;
  socklen_t salen;   /* getsockname()/getpeername() take a socklen_t *, not int * */
  int opt;
  char *hostname;
  char *portname;
  struct servent *se;
  unsigned int len;
  int j;

  portlocal = 0;

  /* Every accepted option letter has a case; anything else shows usage. */
  while ((opt = getopt(argc,argv,"dDvqQhHrRp:Ut:T:l:")) != opteof)
    switch(opt) {
      case 'd': flagdelay = 1; break;
      case 'D': flagdelay = 0; break;
      case 'v': verbosity = 2; break;
      case 'q': verbosity = 0; break;
      case 'Q': verbosity = 1; break;
      case 'l': forcelocal = optarg; break;
      case 'H': flagremotehost = 0; break;
      case 'h': flagremotehost = 1; break;
      case 'R': flagremoteinfo = 0; break;
      case 'r': flagremoteinfo = 1; break;
      case 'p': scan_ulong(optarg,&portlocal); break;
      case 't': scan_ulong(optarg,&timeout); break;
      case 'T': scan_ulong(optarg,&timeout2); break;
      case 'U': /* fall through */
      default: usage();
    }
  argv += optind;

  /* htons() below would otherwise silently truncate an oversized -p value */
  if (portlocal > 65535) usage();

  hostname = *argv++;
  if (!hostname) usage();
  portname = *argv++;
  if (!portname) usage();
  if (!*argv) usage();

  /* free descriptors 6 and 7; the connection is installed there before exec */
  close(6);
  close(7);
  sig_pipeignore();

  dns_init(1);

  byte_zero(&salocal,sizeof(salocal));
  salocal.sin_family = AF_INET;
  salocal.sin_addr.s_addr = INADDR_ANY;
  salocal.sin_port = htons((unsigned short) portlocal);

  byte_zero(&saremote,sizeof(saremote));
  saremote.sin_family = AF_INET;

  /* Numeric only if non-empty and all digits; otherwise (including "")
     look the port up as a TCP service name. */
  len = scan_ulong(portname,&portremote);
  if (len && !portname[len]) {
    if (portremote > 65535) die("unable to figure out port number");
  }
  else {
    se = getservbyname(portname,"tcp");
    if (!se) die("unable to figure out port number");
    portremote = ntohs(se->s_port);  /* s_port is stored in network byte order */
  }

  if (!str_diff(hostname,"0")) hostname = "0.0.0.0";

  /* NOTE(review): timeoutconn() does not appear to be declared by the
     headers included in this file; confirm and add its declaration. */
  if (!hostname[ip_scan(hostname,&ipremote)]) {
    /* literal dotted-quad address: a single connection attempt */
    s = socket(AF_INET,SOCK_STREAM,0);
    if (s == -1) diep("unable to create socket");
    if (bind(s,(struct sockaddr *) &salocal,sizeof(salocal)) == -1)
      diep("unable to bind");
    if (timeoutconn(s,&ipremote,portremote,(int) timeout2) == -1) {
      warnconnect(&ipremote);
      _exit(111);
    }
  }
  else {
    /* host name: try each address dns_ip() returns until one connects */
    if (!stralloc_copys(&tmp,hostname)) nomem();
    switch(dns_ip(&ia,&tmp)) {
      case DNS_SOFT: die("temporarily unable to figure out host address");
      case DNS_HARD: die("unable to figure out host address");
      case DNS_MEM: nomem();
    }
    if (ia.len == 0) die("no IP addresses for that host");
    for (j = 0;j < ia.len;++j) {
      s = socket(AF_INET,SOCK_STREAM,0);
      if (s == -1) diep("unable to create socket");
      if (bind(s,(struct sockaddr *) &salocal,sizeof(salocal)) == -1)
        diep("unable to bind");
      byte_copy(&ipremote,4,&ia.ix[j].ip);
      if (timeoutconn(s,&ipremote,portremote,(int) timeout2) == 0) break;
      warnconnect(&ipremote);
      close(s);
    }
    if (j == ia.len) _exit(111);   /* every address failed; each was reported */
  }

  if (!flagdelay) {
    int on = 1;
    /* best effort: if it fails, the connection still works */
    setsockopt(s,IPPROTO_TCP,TCP_NODELAY,&on,sizeof(on));
  }

  if (!env_init()) nomem();
  if (!env_put("PROTO=TCP")) nomem();
  /* remove inherited values; they are set below only when available */
  if (!env_unset("TCPLOCALHOST")) nomem();
  if (!env_unset("TCPREMOTEHOST")) nomem();
  if (!env_unset("TCPREMOTEINFO")) nomem();

  salen = sizeof(salocal);
  if (getsockname(s,(struct sockaddr *) &salocal,&salen) == -1)
    diep("unable to get local address");
  portlocal = ntohs(salocal.sin_port);
  byte_copy(&iplocal,4,&salocal.sin_addr);

  temp[fmt_ulong(temp,portlocal)] = 0;
  if (!env_put2("TCPLOCALPORT",temp)) nomem();
  temp[ip_fmt(temp,&iplocal)] = 0;
  if (!env_put2("TCPLOCALIP",temp)) nomem();
  if (forcelocal) {
    if (!env_put2("TCPLOCALHOST",forcelocal)) nomem();
  }
  else
    switch(dns_ptr(&tmp,&iplocal)) {
      case DNS_MEM: nomem();   /* does not return */
      case 0:
        if (!stralloc_0(&tmp)) nomem();
        case_lowers(tmp.s);
        if (!env_put2("TCPLOCALHOST",tmp.s)) nomem();
        break;
      default: break;          /* lookup failed: leave TCPLOCALHOST unset */
    }

  salen = sizeof(saremote);
  if (getpeername(s,(struct sockaddr *) &saremote,&salen) == -1)
    diep("unable to get remote address");
  portremote = ntohs(saremote.sin_port);
  byte_copy(&ipremote,4,&saremote.sin_addr);

  infoconnected(&ipremote);

  temp[fmt_ulong(temp,portremote)] = 0;
  if (!env_put2("TCPREMOTEPORT",temp)) nomem();
  temp[ip_fmt(temp,&ipremote)] = 0;
  if (!env_put2("TCPREMOTEIP",temp)) nomem();
  if (flagremotehost)
    switch(dns_ptr(&tmp,&ipremote)) {
      case DNS_MEM: nomem();   /* does not return */
      case 0:
        if (!stralloc_0(&tmp)) nomem();
        case_lowers(tmp.s);
        if (!env_put2("TCPREMOTEHOST",tmp.s)) nomem();
        break;
      default: break;          /* lookup failed: leave TCPREMOTEHOST unset */
    }
  if (flagremoteinfo) {
    char *rinfo;
    rinfo = remoteinfo_get(&ipremote,portremote,&iplocal,portlocal,(int) timeout);
    if (rinfo)
      if (!env_put2("TCPREMOTEINFO",rinfo)) nomem();
  }

  /* hand the connection to program on descriptors 6 and 7 */
  if (fd_move(6,s) == -1) diep("unable to set up descriptor 6");
  if (fd_copy(7,6) == -1) diep("unable to set up descriptor 7");
  sig_pipedefault();   /* program gets default SIGPIPE handling back */

  execvp(*argv,argv);
  diep("unable to execute");
  return 111;   /* not reached: diep() exits */
}
