/*
 *	$Source: /afs/net.mit.edu/user/tytso/src/tcp_forward/RCS/tcp_forwarder.c,v $
 *	$Author: tytso $
 *	$Header: /afs/net.mit.edu/user/tytso/src/tcp_forward/RCS/tcp_forwarder.c,v 1.2 91/02/10 18:49:52 tytso Exp Locker: tytso $
 *
 * tcp_forwarder --- a daemon which accepts TCP connections on a port
 * and forwards them to another port on (possibly) another machine.
 *
 * Known bugs:  Does _not_ handle out of band data.
 */

#ifndef lint
/* RCS revision identifier for this file; left out when lint is defined. */
static char *rcsid_tcp_forwarder_c = "$Header: /afs/net.mit.edu/user/tytso/src/tcp_forward/RCS/tcp_forwarder.c,v 1.2 91/02/10 18:49:52 tytso Exp Locker: tytso $";
#endif /* lint */

#include <stdio.h>
#include <ctype.h>
#include <signal.h>
#include <errno.h>
#include <syslog.h>
#include <sgtty.h>
#include <netdb.h>
#include <sys/types.h>
#include <sys/file.h>
#include <sys/wait.h>
#include <sys/socket.h>
#include <sys/param.h>
#include <netinet/in.h>

/*
 * SERVER_NAME and SYSLOG_CLASS are the ident and facility handed to
 * openlog() in PRS().  PID_FILE, when defined, is where detach()
 * writes the daemon's process id.
 */
#define SERVER_NAME	"tcp_forwarder"
#define SYSLOG_CLASS	LOG_DAEMON
#define PID_FILE	"/tmp/tcp_forwarder.pid"

/*
 * Run-time configuration.  Everything here except listen_host is
 * filled in by PRS() from the command line; listen_host is set by
 * doit() to the name (or dotted address) of the connecting peer.
 */
/* argv[0]; printed by usage() */
char	*progname;
int	standalone = 0;		/* 1 = Stand alone mode */
int	nofork = 0;		/* 1 = Don't fork in stand-alone mode */
/* port accepted on in standalone mode (host byte order) */
int	listen_port;
/* port connected to on forward_host (host byte order) */
int	forward_port;
/* forwarding target exactly as given on the command line */
char	*forward_host;
char	*listen_host;
/* set by -d; makes doit() print the byte count of each read */
int	debug = 0;
/*
 * Non-zero once the relay should wind down.  Incremented by doit()
 * and by the shutdown() signal handler; doit() exits when it is set
 * and both relay buffers are empty.
 */
int	do_shutdown = 0;

/* Outgoing connection to the forwarding target (created in doit()). */
int	forward_sock;
/* Incoming connection: the accept()ed socket, or 0 when run from inetd. */
int	listen_sock;

/* Address of the forwarding target, resolved once in PRS(). */
struct	sockaddr_in	forward_addr;

char	*inet_ntoa();

extern int	errno;
/* Signal handlers: reapchild is installed for SIGCHLD, shutdown for SIGPIPE. */
int		reapchild();
int		shutdown();

/*
 * main --- parse the command line, then relay connections.
 *
 * Without -S the connection to relay is the socket already open on
 * descriptor 0 (as handed over by inetd) and exactly one connection is
 * served.  With -S the program listens on listen_port itself and forks
 * one child per accepted connection; each child runs doit().
 *
 * Exits 1 if the listening socket cannot be created, bound or put
 * into the listening state.
 */
int main (argc, argv)
	int argc;
	char **argv;
{
	int	finet;
	struct sockaddr_in sin;

	PRS(argc, argv);
	if (!standalone) {
		/* 0 is socket passed off from inetd */
		listen_sock = 0;
		doit();
		exit(0);
	}

	if (!nofork)
		detach();
	(void) signal(SIGCHLD, reapchild);

	if ((finet = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
		syslog(LOG_ERR, "socket: finet: %m");
		exit(1);
	}
	/* Zeroing leaves sin_addr as the wildcard address. */
	bzero(&sin, sizeof sin);
	sin.sin_family = AF_INET;
	sin.sin_port = htons(listen_port);
	if (bind(finet, (struct sockaddr *) &sin, sizeof(sin)) < 0) {
		syslog(LOG_ERR, "bind: %m");
		exit(1);
	}
	if (listen(finet, 5) < 0) {
		syslog(LOG_ERR, "listen: %m");
		exit(1);
	}
	syslog(LOG_INFO, "starting standalone mode: port %d --> %s %d",
	       listen_port, forward_host, forward_port);
	for (;;) {
		struct sockaddr_in	frominet;
		int	fromlen;
		int	pid;

		/* accept() reads fromlen as the size of the address buffer */
		fromlen = sizeof(frominet);
		listen_sock = accept(finet, (struct sockaddr *) &frominet,
				     &fromlen);
		if (listen_sock < 0) {
			/* EINTR is expected whenever a SIGCHLD arrives */
			if (errno != EINTR)
				syslog(LOG_WARNING, "accept: %m");
			continue;
		}
		pid = fork();
		if (pid == 0) {
			/* child */
			(void) signal(SIGCHLD, SIG_IGN);
			(void) close(finet);
			doit();
			/*
			 * doit() only leaves through exit(); make sure
			 * a child can never fall back into the accept
			 * loop on the descriptor it just closed.
			 */
			exit(0);
		}
		if (pid < 0)
			syslog(LOG_WARNING, "fork: %m");
		/* parent (or failed fork): drop our copy of the connection */
		(void) close(listen_sock);
	}
}

/*
 * PRS --- parse the command line and resolve the forwarding target.
 *
 * Options:  -S  standalone mode
 *           -n  do not fork into the background in standalone mode
 *           -d  debug output
 * Operands: local-port remote-host [remote-port]
 *           (remote-port defaults to local-port)
 *
 * Fills in progname, standalone, nofork, debug, listen_port,
 * forward_host, forward_port and forward_addr, and opens the syslog
 * connection.  Calls usage() (which exits) on a malformed command
 * line; logs and exits 1 if remote-host cannot be resolved to an
 * address that fits in forward_addr.
 */
int PRS(argc, argv)
	int	argc;
	char	**argv;
{
	struct hostent	*host;
	int	getopt();
	extern int	optind;
	int	c;
	int	operands;

	progname = argv[0];
	forward_sock = listen_sock = -1;
	standalone = nofork = 0;

	while ((c = getopt(argc, argv, "dnS")) != EOF) {
		switch (c) {
		case 'n':	/* Don't fork on standalone */
			nofork++;
			break;
		case 'S':
			standalone++; /* Turn on standalone mode */
			break;
		case 'd':
			debug++;
			break;
		default:
			usage();
		}
	}
	operands = argc - optind;
	if (operands != 2 && operands != 3)
		usage();
	/* getport() returns 0 for anything it cannot turn into a port */
	if ((listen_port = getport(argv[optind++])) == 0)
		usage();
	forward_host = argv[optind++];
	if (optind < argc) {
		if ((forward_port = getport(argv[optind++])) == 0)
			usage();
	} else
		forward_port = listen_port;

	openlog(SERVER_NAME, LOG_PID | LOG_ODELAY, SYSLOG_CLASS);

	/* Try a dotted-quad address first, then fall back to a name lookup. */
	if ((forward_addr.sin_addr.s_addr = inet_addr(forward_host)) != -1) {
		forward_addr.sin_family = AF_INET;
	} else {
		if (!(host = gethostbyname(forward_host))) {
			syslog(LOG_ERR, "%s: unknown forwarding host",
			       forward_host);
			exit(1);
		}
		/*
		 * forward_addr is a sockaddr_in; copying a longer
		 * address into sin_addr would run past the structure.
		 */
		if (host->h_addrtype != AF_INET ||
		    host->h_length > sizeof(forward_addr.sin_addr)) {
			syslog(LOG_ERR,
			       "%s: forwarding host has no usable IP address",
			       forward_host);
			exit(1);
		}
		forward_addr.sin_family = host->h_addrtype;
		bcopy(host->h_addr_list[0], (caddr_t) &forward_addr.sin_addr,
		      host->h_length);
	}
	forward_addr.sin_port = htons(forward_port);
	return(0);
}

/*
 * usage --- print the command-line synopsis on stderr and exit with
 * status 1.  Never returns.
 */
int usage()
{
	/* -d is accepted by the option parser, so list it here too. */
	fprintf(stderr,
		"Usage: %s [-dnS] local-port remote-host [remote-port]\n",
		progname);
	exit(1);
}

/*
 * Set up standard environment by detaching from the parent.
 */
detach()
{
	int	f;
#ifdef PID_FILE
	FILE	*pidfile;
#endif
	
	if (fork())
		exit(0);
	for (f=0; f < 5; f++)
		(void) close(f);
	(void) open("/dev/null", O_RDONLY);
	(void) open("/dev/null", O_WRONLY);
	(void) dup(1);
	if ((f = open("/dev/tty", O_RDWR)) > 0) {
		ioctl(f, TIOCNOTTY, 0);
		(void) close(f);
	}
#ifdef PID_FILE
	if ((pidfile = fopen(PID_FILE, "w")) != NULL) {
		fprintf(pidfile, "%d\n", getpid());
		fclose(pidfile);
	} else
		syslog(LOG_WARNING, "cannot write pid file %s", PID_FILE);
#endif
}	


/*
 * doit --- relay one connection between listen_sock and forward_addr.
 *
 * Logs the peer of listen_sock, connects forward_sock to forward_addr,
 * switches both sockets to non-blocking mode and then shuttles bytes
 * in both directions with select() until the connection is finished.
 * Never returns: it exits 0 once do_shutdown is set and both buffers
 * are empty, and exits 1 if setup or select() fails.
 *
 * Each direction has one BUFSIZ buffer; start/end delimit the bytes
 * read but not yet written.  The pair is reset to the buffer base
 * whenever the buffer drains, so free space is always measured from
 * the end pointer to the end of the buffer itself.
 *
 * do_shutdown is set on end-of-file or a fatal error on either socket
 * (and by the SIGPIPE handler).  A direction whose writes have failed
 * has its pending data discarded, since it can never be delivered and
 * would otherwise keep the exit condition from ever becoming true.
 */
int doit()
{
	int	fromlen;
	struct sockaddr_in	from;
	struct hostent		*hp;
	int	on = 1;
	static char	buff1[BUFSIZ]; /* listener -> forwarded port */
	static char	buff2[BUFSIZ]; /* forwarded port -> listener */
	char	*start1, *end1, *start2, *end2;
	int	eof1 = 0;	/* 1 = stop reading from listen_sock */
	int	eof2 = 0;	/* 1 = stop reading from forward_sock */
	int	maxfds;		/* For select */

	signal(SIGPIPE, shutdown);
	fromlen = sizeof(from);
	if (getpeername(listen_sock, (struct sockaddr *) &from, &fromlen) < 0) {
		syslog(LOG_ERR, "getpeername: %m");
		exit(1);
	}
	hp = gethostbyaddr((char *) &from.sin_addr, sizeof (struct in_addr),
			   from.sin_family);
	if (hp)
		listen_host = hp->h_name;
	else
		listen_host = inet_ntoa(from.sin_addr);
	syslog(LOG_INFO, "Connection accepted from %s, to %s %d", listen_host,
	       forward_host, forward_port);
	if (setsockopt(listen_sock, SOL_SOCKET, SO_KEEPALIVE,
		       (char *) &on, sizeof(on)) < 0) {
		syslog(LOG_WARNING, "setsockopt(SO_KEEPALIVE): %m");
	}
	if ((forward_sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
		syslog(LOG_ERR, "socket: forward_sock: %m");
		exit(1);
	}
	if (connect(forward_sock, (struct sockaddr *) &forward_addr,
		    sizeof(forward_addr)) < 0) {
		syslog(LOG_ERR, "Couldn't connect to address %s: %m",
		       inet_ntoa(forward_addr.sin_addr));
		exit(1);
	}
	if (ioctl(listen_sock, FIONBIO, &on) < 0)
		syslog(LOG_WARNING, "ioctl(FIONBIO): listen_sock: %m");
	if (ioctl(forward_sock, FIONBIO, &on) < 0)
		syslog(LOG_WARNING, "ioctl(FIONBIO): forward_sock: %m");
	start1 = end1 = buff1;		/* Set up pointers */
	start2 = end2 = buff2;
	if (listen_sock > forward_sock)
		maxfds = listen_sock+1;
	else
		maxfds = forward_sock+1;
	while (1) {
		fd_set	readfds, writefds;
		int	nfound, count, err;

		if (do_shutdown && start1 == end1 && start2 == end2) {
			(void) close(forward_sock);
			(void) close(listen_sock);
			syslog(LOG_INFO, "Connection closed.");
			exit(0);
		}
		FD_ZERO(&readfds);
		FD_ZERO(&writefds);
		if (end1 > start1)
			FD_SET(forward_sock, &writefds);
		/*
		 * Room is what lies between end and the end of the
		 * buffer, regardless of how far start has advanced.  A
		 * socket that already reported end-of-file is left out:
		 * it would be "readable" forever and spin the loop.
		 */
		if (!eof1 && end1 < buff1 + sizeof(buff1))
			FD_SET(listen_sock, &readfds);
		if (end2 > start2)
			FD_SET(listen_sock, &writefds);
		if (!eof2 && end2 < buff2 + sizeof(buff2))
			FD_SET(forward_sock, &readfds);

		nfound = select(maxfds, &readfds, &writefds,
				(fd_set *) 0, (struct timeval *) 0);
		if (nfound < 0) {
			if (errno == EINTR)
				continue;
			/* The fd sets are undefined now; do not use them. */
			syslog(LOG_ERR, "select: %m");
			exit(1);
		}
		if (nfound == 0)
			continue;	/* not expected without a timeout */

		if (FD_ISSET(listen_sock, &readfds)) {
			count = read(listen_sock, end1,
				     buff1 + sizeof(buff1) - end1);
			err = errno;	/* printf below may disturb errno */
			if (debug)
				printf("listen count: %d\n", count);
			if (count == 0) {
				do_shutdown = 1;
				eof1 = 1;
			} else if (count < 0) {
				if (err != EINTR && err != EWOULDBLOCK) {
					errno = err;
					syslog(LOG_ERR, "read: s: %m");
					do_shutdown = 1;
					eof1 = 1;
				}
				count = 0;
			}
			end1 += count;
		}
		if (FD_ISSET(forward_sock, &readfds)) {
			count = read(forward_sock, end2,
				     buff2 + sizeof(buff2) - end2);
			err = errno;
			if (debug)
				printf("forward count: %d\n", count);
			if (count == 0) {
				do_shutdown = 1;
				eof2 = 1;
			} else if (count < 0) {
				if (err != EINTR && err != EWOULDBLOCK) {
					errno = err;
					syslog(LOG_ERR,
					       "read: forward_sock: %m");
					do_shutdown = 1;
					eof2 = 1;
				}
				count = 0;
			}
			end2 += count;
		}
		if (FD_ISSET(forward_sock, &writefds)) {
			count = write(forward_sock, start1, end1 - start1);
			if (count < 0) {
				count = 0;
				if (errno != EINTR && errno != EWOULDBLOCK
				    && errno != ENOBUFS) {
					syslog(LOG_ERR,
					       "write: forward_sock: %m");
					do_shutdown = 1;
					/* undeliverable: drop it, read no more */
					start1 = end1 = buff1;
					eof1 = 1;
				}
			}
			start1 += count;
			if (start1 == end1)
				start1 = end1 = buff1;
		}
		if (FD_ISSET(listen_sock, &writefds)) {
			count = write(listen_sock, start2, end2 - start2);
			if (count < 0) {
				count = 0;
				if (errno != EINTR && errno != EWOULDBLOCK
				    && errno != ENOBUFS) {
					syslog(LOG_ERR, "write: s: %m");
					do_shutdown = 1;
					/* undeliverable: drop it, read no more */
					start2 = end2 = buff2;
					eof2 = 1;
				}
			}
			start2 += count;
			if (start2 == end2)
				start2 = end2 = buff2;
		}
	}
}

/*
 * reapchild --- SIGCHLD handler: collect every child that has already
 * exited, without blocking.
 *
 * errno is saved and restored because the final wait3() call fails
 * (typically with ECHILD) and the interrupted code may be about to
 * examine errno from its own system call.
 */
int reapchild()
{
	/* Exit status is not examined; int is what POSIX wait3() takes. */
	int	status;
	int	saved_errno = errno;

	while (wait3(&status, WNOHANG, (struct rusage *) 0) > 0)
		;
	errno = saved_errno;
	return(0);
}


/*
 * shutdown --- signal handler (installed for SIGPIPE by doit()):
 * bump do_shutdown so the relay loop knows it should wind down.
 */
int shutdown()
{
	do_shutdown += 1;
}

/*
 * getport --- turn a service name or decimal port number into a port.
 *
 * port is first looked up as a TCP service name; if that fails it
 * must consist solely of decimal digits naming a port in 1..65535.
 *
 * Returns the port in host byte order, or 0 (after printing a message
 * on stderr) when port is neither a known service nor a valid number.
 */
int getport(port)
	char	*port;
{
	struct servent *sp;
	char	*p;
	long	value;

	sp = getservbyname(port, "tcp");
	if (sp != NULL)
		return(ntohs(sp->s_port));

	/*
	 * Parse by hand so that trailing junk ("12abc") and values
	 * that do not fit in 16 bits are rejected rather than
	 * silently truncated.
	 */
	value = 0;
	for (p = port; isdigit((unsigned char) *p); p++) {
		value = value * 10 + (*p - '0');
		if (value > 65535)
			break;		/* also keeps value from overflowing */
	}
	if (p == port || *p != '\0' || value < 1 || value > 65535) {
		fprintf(stderr, "%s is not a valid port!\n", port);
		return(0);
	}
	return((int) value);
}
