#include <sys/types.h>
#include <sys/select.h>
#include <sys/socket.h>

#include <netinet/in.h>
#include <netdb.h>

#include <errno.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include <unistd.h>

/*
 * errno comes from <errno.h>.  It may be a macro, so declaring
 * "extern int errno" here is undefined behavior (C11 7.5p2).
 */

/* NB: these evaluate the selected argument twice; pass no side effects. */
#define min(a,b) ((a) < (b) ? (a) : (b))
#define max(a,b) ((a) > (b) ? (a) : (b))

/*
 * NOTE(review): doxmit, readwrite and sigchld are neither defined nor used
 * in the code visible in this file; confirm before removing them.
 */
void usage(), doxmit(), readwrite();

void sigchld();

/*
 * Copies buffered bytes from one socket's read buffer into another's write
 * buffer.  Declared here so that callers preceding its definition see a
 * prototype instead of relying on an implicit declaration.
 */
int copyfromto(int from, int to);

/* Address (host and port from the command line) that connections are relayed to. */
struct sockaddr_in forward_to;

/*
 * Per-descriptor relay state; the global rwbuf table holds one of these for
 * each fd number.  Bytes read from this fd wait in readbuffer until they are
 * copied into the peer's writebuffer; bytes destined for this fd wait in
 * writebuffer until they are written.  The code only ever uses the first
 * BUFSIZ bytes of each buffer; the extra 5 bytes are never touched.
 */
struct rwbuf {
    char readbuffer[BUFSIZ+5];
    char writebuffer[BUFSIZ+5];
    int rbytes;				/* #bytes active in readbuffer */
    int wbytes;				/* #bytes active in writebuffer */
    int connto;				/* fd this one should be connected to; -1 = none */
    int wclose;				/* if set, close after finished writing */
};

/* Master interest sets; each select() pass works on copies of these. */
fd_set winit;
fd_set rinit;

/* Per-fd relay state table, allocated in main() and indexed by descriptor. */
struct rwbuf *rwbuf;

main(argc, argv)
int argc;
char *argv[];
{
    struct hostent *hp;
    struct sockaddr_in local, dummy;
    unsigned short port;
    register int i;
    int lsock;
    int dummylen = sizeof(dummy);
    int one = 1;
    int nready, nfds;
    fd_set readable, writable;
    int dtsize;

    if (argc != 4)
	usage();
    hp = gethostbyname(argv[1]);
    if (!hp) {
	fprintf(stderr, "no such host %s\n",argv[1]);
	exit(1);
    }
    bzero(&local, sizeof(local));
    bzero(&forward_to, sizeof(forward_to));

    port = atoi(argv[2]);
    if (!port) {
	fprintf(stderr, "port must be a positive integer!\n");
	exit(2);
    }
#if 0
    if (port < 6000 || port > 6010) {
	fprintf(stderr, "ports must be >= X11.0 and <= X11.10\n");
	exit(3);
    }
#endif
    local.sin_port = htons(port);
    local.sin_family = AF_INET;

    port = atoi(argv[3]);
    if (!port) {
	fprintf(stderr, "port must be a positive integer!\n");
	exit(3);
    }
#if 0
    if (port < 6000 || port > 6010) {
	fprintf(stderr, "ports must be >= X11.0 and <= X11.10\n");
	exit(3);
    }
#endif
	
    forward_to.sin_port = htons(port);
    forward_to.sin_family = AF_INET;
    bcopy(hp->h_addr, &forward_to.sin_addr, sizeof(forward_to.sin_addr));

    if ((lsock = socket(PF_INET, SOCK_STREAM, 0)) < 0) {
	perror("socket");
	exit(1);
    }
    if (setsockopt(lsock, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)) < 0) {
	perror("SO_REUSEADDR:");
	exit(1);
    }
    if (bind(lsock, (struct sockaddr *)&local, sizeof(local)) < 0) {
	perror("bind");
	exit(1);
    }
    if (listen(lsock, SOMAXCONN) < 0) {
	perror("listen");
	exit(1);
    }
    /* clean up to conserve descriptors */
    close(0);
    close(1);
    /* close(2); -- leave open for stderr */

    signal(SIGPIPE, SIG_IGN);

    dtsize = getdtablesize();

    /* set up buffers & sizes */
    rwbuf = (struct rwbuf *)calloc(dtsize, sizeof *rwbuf);
    if (!rwbuf) {
	fprintf(stderr, "can't allocate buffers\n");
	exit(1);
    }
    /* -1 means not connected/invalid */
    while (dtsize--) {
	rwbuf[dtsize].connto = -1;
    }

    FD_ZERO(&rinit);
    FD_ZERO(&winit);
    FD_SET(lsock, &rinit);
    
    nfds = lsock + 1;
    while (1) {
	for (i = 0; i < nfds; i++)
	    if (rwbuf[i].wclose)
		FD_SET(i, &winit);

	readable = rinit;
	writable = winit;
	if ((nready = select(nfds, &readable, &writable, 0, 0)) == -1) {
	    if (errno == EINTR)
		continue;
	    perror("select");
	    exit(1);
	}
	if (!nready) {
	    fprintf(stderr, "select not ready?!?\n");
	    exit(1);
	}
	for (i = 0; i < nfds && nready; i++) {
	    /* loop through descriptors */
	    if (FD_ISSET(i, &writable)) {
		int cc, leftover, connto;
		connto = rwbuf[i].connto;
		nready--;
		/* write what we can */
		if (rwbuf[i].wbytes) {
		    cc = write(i, rwbuf[i].writebuffer, rwbuf[i].wbytes);
		    if (cc == -1) {
			if (errno != EPIPE) {
			    fprintf(stderr, "fd %d:", i);
			    perror("write");
			}
			/* asynchrony on close probs? */
			FD_CLR(i, &winit);
			FD_CLR(i, &rinit);
			close(i);
			if (connto != -1) {
			    rwbuf[connto].connto = -1;
			    FD_CLR(connto, &rinit);
			    FD_CLR(connto, &winit);
			    close(connto);
			}
			rwbuf[i].connto = -1;
			continue;
		    } else {
			leftover = rwbuf[i].wbytes - cc;
			if (leftover) {
			    /* didn't write it all */
			    /* copy down */
			    bcopy(rwbuf[i].writebuffer + cc,
				  rwbuf[i].writebuffer,
				  leftover);
			    rwbuf[i].wbytes = leftover;
			} else
			    rwbuf[i].wbytes = 0; /* buffer empty */
			if (connto != -1 && rwbuf[connto].rbytes) {
			    /* more stuff to copy in */
			    copyfromto(connto, i);
			}
			if (!rwbuf[i].wbytes) {
			    /* nothing left to write */
			    if (rwbuf[i].wclose) {
				/* close after flushing */
				FD_CLR(i, &rinit);
				close(i);
				rwbuf[i].connto = -1;
				rwbuf[i].wclose = 0;
			    }
			    FD_CLR(i, &winit);
			}
			if (connto != -1)
			/* since we wrote some, go look for more */
			    FD_SET(connto, &rinit);
		    }
		} else {
		    if (connto != -1 && rwbuf[connto].rbytes) {
			    /* more stuff to copy in */
			    copyfromto(connto, i);
			    if (rwbuf[i].wbytes)
				continue;
		    }
		    /* nothing to write at the moment */
		    if (rwbuf[i].wclose) {
			/* close after flushing */
			FD_CLR(i, &rinit);
			close(i);
			rwbuf[i].connto = -1;
			rwbuf[i].wclose = 0;
		    }
		    /* nothing to write, so clear */
		    FD_CLR(i, &winit);
		}
	    }
	    if (FD_ISSET(i, &readable)) {
		nready--;
		/* something is readable */
		if (i == lsock) {
		    int newoutgoing, newsock;
		    /* new connection ready */
		    if ((newsock = accept(lsock, (struct sockaddr *)&dummy,
					  &dummylen)) < 0) {
			if (errno == EINTR)
			    continue;
			perror("accept");
			exit(1);
		    }
		    if ((newoutgoing = socket(PF_INET, SOCK_STREAM, 0)) < 0) {
			perror("socket");
			close(newsock);
			continue;
		    }
		    if (connect(newoutgoing, (struct sockaddr *)&forward_to,
				sizeof(forward_to)) < 0) {
			perror("connect");
			close(newoutgoing);
			close(newsock);
			continue;
		    }

		    nfds = max(nfds, newsock+1);
		    nfds = max(nfds, newoutgoing+1);

		    rwbuf[newsock].connto = newoutgoing;
		    rwbuf[newoutgoing].connto = newsock;

		    rwbuf[newoutgoing].rbytes = rwbuf[newsock].rbytes = 0;
		    rwbuf[newoutgoing].wbytes = rwbuf[newsock].wbytes = 0;
		    FD_SET(newsock, &rinit);
		    FD_SET(newsock, &winit);
		    FD_SET(newoutgoing, &rinit);
		    FD_SET(newoutgoing, &winit);
		} else {
		    int cc, connto, ncopy;
		    connto = rwbuf[i].connto;
		    /* normal fd is readable */
		    if (rwbuf[i].rbytes < BUFSIZ) {
			/* read what we have room for */
			cc = read(i, rwbuf[i].readbuffer + rwbuf[i].rbytes,
				  BUFSIZ-rwbuf[i].rbytes);
			if (cc == -1) {
			    fprintf(stderr, "fd %d:", i);
			    perror("read");
			    /* asynchrony on close probs? */
			    FD_CLR(i, &winit);
			    FD_CLR(i, &rinit);
			    close(i);
			    if (connto != -1) {
				rwbuf[connto].connto = -1;
			        FD_CLR(connto, &rinit);
				FD_CLR(connto, &winit);
				close(connto);
			    }
			    rwbuf[i].connto = -1;
			    continue;
			} else if (cc == 0) {
			    /* closedown */
			    FD_CLR(i, &rinit);
			    FD_CLR(i, &winit);
			    close(i);
			    if (connto != -1) {
				/* set close after finishing write */
				rwbuf[connto].wclose = 1;
				rwbuf[connto].connto = -1;
				/* force a write cycle to clean up */
				FD_SET(connto, &winit);
			    }
			    rwbuf[i].rbytes = 0;
			    rwbuf[i].wbytes = 0;
			    rwbuf[i].connto = -1;
			    /* XXX what else */
			} else {
			    rwbuf[i].rbytes += cc;
			    /* try to put onto write buffer */

			    if (connto != -1)
				copyfromto(i, connto);
			    if (rwbuf[i].rbytes >= BUFSIZ) {
				/* buffer is full */
				FD_CLR(i, &rinit);
			    }
			}
		    }
		}
	    }
	}
    }
}

/*
 * Print the command-line synopsis on stderr and terminate with status 1.
 */
void
usage()
{
    static const char synopsis[] =
	"usage: tcpforward hostname localportnum foreignportnum\n";

    fputs(synopsis, stderr);
    exit(1);
}

/*
 * Move as much buffered input as fits from rwbuf[from].readbuffer to the
 * tail of rwbuf[to].writebuffer (each buffer holds at most BUFSIZ bytes).
 *
 * from - fd whose pending input is consumed from the front
 * to   - fd whose pending output is appended to
 *
 * When `to' has room, `to' is armed for writing (winit) and `from', which
 * now has room again, is re-armed for reading (rinit).  It must be `from',
 * not `to': `to' may have been removed from rinit because its own read
 * buffer is full, and re-arming it would make select() keep reporting it
 * readable with nowhere to put the data.
 *
 * Returns the number of bytes moved (0 if `to' was full or `from' had
 * nothing pending); callers may ignore it.
 */
int
copyfromto(int from, int to)
{
    int ncopy;

    if (rwbuf[to].wbytes >= BUFSIZ)
	return 0;			/* no room until the next write drains it */

    ncopy = min(rwbuf[from].rbytes, BUFSIZ - rwbuf[to].wbytes);
    memcpy(rwbuf[to].writebuffer + rwbuf[to].wbytes,
	   rwbuf[from].readbuffer, ncopy);
    rwbuf[to].wbytes += ncopy;
    FD_SET(to, &winit);

    /* slide the unconsumed remainder to the front (overlapping: memmove) */
    rwbuf[from].rbytes -= ncopy;
    if (rwbuf[from].rbytes)
	memmove(rwbuf[from].readbuffer, rwbuf[from].readbuffer + ncopy,
		rwbuf[from].rbytes);

    FD_SET(from, &rinit);
    return ncopy;
}
