/* $Id: proxy.C,v 1.10 2001/11/07 20:18:36 kaminsky Exp $ */

/*
 *
 * Copyright (C) 2000-2001 Eric Peterson (ericp@lcs.mit.edu)
 * Copyright (C) 2000 David Mazieres (dm@uun.org)
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 2, or (at
 * your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
 * USA
 *
 */

#include "proxy.h"
#include "qhash.h"
#include "axprt_crypt.h"
#include "sfsmisc.h"
/*
 * State for one client connection served by this proxy.
 *
 * Each instance serves rex_prog_1 on its transport (requests are routed to
 * dispatch()) and holds an RPC client for rexcb_prog_1 that is handed to
 * newly created channels.  Live instances are counted in nclnt; the
 * destructor exits the process when that count drops to zero.
 *
 * Instances are heap-allocated and delete themselves when dispatch() sees
 * EOF on the transport.
 */
struct rexclnt {
  static u_int nclnt;		// number of live rexclnt objects

  ref<aclnt> c;			// rexcb_prog_1 client, passed to mkchannel_prog
  ref<asrv> s;			// rex_prog_1 server; requests go to dispatch()

  qhash<int, ref<chanbase> > chantab;	// open channels, keyed by channel number

  // c is declared (and therefore initialized) before s; keep that order.
  explicit rexclnt (ref<axprt> x)
    : c (aclnt::alloc (x, rexcb_prog_1)),
      s (asrv::alloc (x, rex_prog_1, wrap (this, &rexclnt::dispatch)))
    { nclnt++; }
  ~rexclnt ();

  // Returns the smallest channel number >= 1 not present in chantab.
  int chanalloc ();
  void dispatch (svccb *);

private:
  // Non-copyable: the asrv callback above is bound to `this`, and the
  // destructor decrements nclnt (possibly exiting) while a copy would never
  // have incremented it.  Declared but intentionally not defined (C++98).
  rexclnt (const rexclnt &);
  rexclnt &operator= (const rexclnt &);
};

// Live rexclnt count; starts at zero (incremented in the ctor, decremented in the dtor).
u_int rexclnt::nclnt = 0;
// Transport and RPC server for the control connection from rexd on fd 0.
ptr<axprt_unix> rxprt;
ptr<asrv> rsrv;

/*
 * Drop this client from the live count.  When the last client goes away,
 * the proxy has nothing left to serve and the process exits.
 */
rexclnt::~rexclnt ()
{
  nclnt--;
  if (nclnt == 0)
    exit (0);
}

/*
 * Pick a channel number for a new channel: the smallest value >= 1 with
 * no entry in chantab.  Channel 0 is never handed out.
 */
int
rexclnt::chanalloc ()
{
  int cn = 1;
  while (chantab[cn])
    cn++;
  return cn;
}

/*
 * Handle one rex_prog_1 request from this client.
 *
 * A NULL svccb signals EOF on the client's transport; the object then
 * deletes itself (the destructor may exit the process if this was the
 * last client).  Every other path either replies to or rejects sbp
 * exactly once, or hands it to a channel (REX_DATA / REX_CLOSE /
 * REX_KILL), which is then responsible for the reply.
 */
void
rexclnt::dispatch (svccb *sbp)
{
  if (!sbp) {
    delete this;
    return;
  }
  switch (sbp->proc ()) {
  case REX_NULL:
    sbp->reply (NULL);
    break;
  case REX_DATA:
    {
      rex_payload *argp = sbp->template getarg<rex_payload> ();
      if (argp->fd < 0) {
	// A negative fd removes the channel from the table, dropping our
	// reference to it.
	chantab.remove (argp->channel);
	sbp->replyref (false);
      }
      else if (chanbase *c = chantab[argp->channel])
	c->data (sbp);
      else
	sbp->replyref (false);	// no such channel
      break;
    }
  case REX_CLOSE:
  case REX_KILL:
    {
      rex_int_arg *argp = sbp->template getarg<rex_int_arg> ();
      if (chanbase *c = chantab[argp->channel]) {
	if (sbp->proc () == REX_KILL)
	  c->kill (sbp);
	else
	  c->close (sbp);
      }
      else
	sbp->replyref (false);	// no such channel
      break;
    }
  case REX_MKCHANNEL:
    {
      rex_mkchannel_arg *argp = sbp->template getarg<rex_mkchannel_arg> ();
      ptr<chanbase> cb;

      int cn = chanalloc ();
      cb = mkchannel_prog (c, cn, argp);

      // Report SFS_TEMPERR unless the channel was actually created.
      rex_mkchannel_res res (SFS_TEMPERR);
      if (cb) {
	chantab.insert (cn, cb);
	res.set_err (SFS_OK);
	res.resok->channel = cn;
      }
      sbp->reply (&res);
      break;
    }
  case REX_SETENV:
    {
      rex_setenv_arg *arg = sbp->template getarg<rex_setenv_arg> ();
      if (!arg->name.len ()) {
	warn ("received REX_SETENV with null name\n");
	sbp->replyref (false);
	break;
      }
      if (setenv (arg->name.cstr(), arg->value.cstr(), 1)) {
	warn ("dispatch (REX_SETENV call) setenv failed for (%m)\n");
	sbp->replyref (false);
	break;
      }
      sbp->replyref (true);
      break;
    }
  case REX_UNSETENV:
    {
      // Previously this case neither replied nor broke out of the switch,
      // so it fell through to `default` and always rejected the call with
      // PROC_UNAVAIL, even after unsetting the variable.
      // NOTE(review): assumes REX_UNSETENV returns bool like REX_SETENV --
      // confirm against the rex protocol definition.
      rex_unsetenv_arg *arg = sbp->template getarg<rex_unsetenv_arg> ();
      if (!arg->len ()) {
        warn ("received unsetenv on null variable name\n");
	sbp->replyref (false);
	break;
      }
      // unsetenv's return type differs across older C libraries (void vs
      // int), so its result is deliberately not inspected.
      unsetenv (arg->cstr ());
      sbp->replyref (true);
      break;
    }
  default:
    sbp->reject (PROC_UNAVAIL);
    break;
  }
}

/*
 * Dispatch routine for rexctl_prog_1 requests from rexd on the control
 * socket.
 *
 * A NULL svccb signals EOF from rexd; the control transport and server
 * are then released.
 *
 * REXCTL_CONNECT reads a descriptor from rxprt.  It wraps the descriptor
 * in an axprt_crypt keyed with the session keys carried in the argument
 * and starts a new rexclnt on it.  The key buffers are wiped whether or
 * not a descriptor arrived.
 */
static void
ctldispatch (svccb *sbp)
{
  if (sbp == NULL) {
    warn ("EOF from rexd\n");
    rxprt = NULL;
    rsrv = NULL;
    return;
  }

  switch (sbp->proc ()) {
  case REXCTL_NULL:
    sbp->reply (NULL);
    return;

  case REXCTL_CONNECT:
    {
      sfs_sessinfo *si = sbp->template getarg<sfs_sessinfo> ();
      int cfd = rxprt->recvfd ();
      if (cfd < 0)
	warn ("could not receive descriptor from rexd\n");
      else {
	ref<axprt_crypt> cx (axprt_crypt::alloc (cfd));
	cx->encrypt (si->ksc.base (), si->ksc.size (),
		     si->kcs.base (), si->kcs.size ());
	// The client owns itself; it is deleted on EOF in rexclnt::dispatch.
	vNew rexclnt (cx);
      }
      // Scrub the session keys now that they have been consumed.
      // XXX - more stuff needs to be bzeroed
      bzero (si->ksc.base (), si->ksc.size ());
      bzero (si->kcs.base (), si->kcs.size ());
      sbp->reply (NULL);
      return;
    }

  default:
    sbp->reject (PROC_UNAVAIL);
    return;
  }
}

/*
 * Startup watchdog: if no client is alive when the timer fires, there is
 * nothing for the proxy to do, so exit.
 */
static void
timeout ()
{
  if (rexclnt::nclnt == 0)
    exit (0);
}

/*
 * Entry point.  Takes no arguments; expects fd 0 to be a unix-domain socket
 * from rexd, serves rexctl_prog_1 on it, and then runs the event loop.
 */
int
main (int argc, char **argv)
{
  setprogname (argv[0]);
  sfsconst_init ();

  // No command-line arguments are accepted.
  if (argc >= 2)
    fatal << "usage: " << progname << "\n";

  // Control connection from rexd arrives on stdin.
  if (!isunixsocket (0))
    fatal ("stdin must be a unix domain socket.\n");
  rxprt = axprt_unix::alloc (0);
  rsrv = asrv::alloc (rxprt, rexctl_prog_1, wrap (ctldispatch));

  // In 120 seconds, exit unless at least one client is live by then.
  time_t deadline = time (NULL) + 120;
  timecb (deadline, wrap (timeout));

  amain ();
}
