/*
 * This file is based on a contribution of David Tolpin (dvd@pizza.msk.su)
 * It is an implementation of BSD-INET sockets and is known to run on 
 * Solaris 1 and Linux.
 *
 * Bug fixes (conversion between host and network byte order) by
 * Marc Furrer (Marc.Furrer@di.epfl.ch)
 *
 * Reworked  by Erick Gallesio for 2.2 release. Some additions and simplifications
 * (I hope)
 * 
 * Last file update: 26-Apr-1995 23:53
 */

#include "stk.h"
#include <errno.h>
#include <stdio.h>
#include <string.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netdb.h>
#include <memory.h>

/*
 * Per-socket data hung off a tc_socket cell (reached through SOCKET()).
 */
struct socket_type {
  int portnum;     /* port number, in host byte order */
  SCM hostname;    /* Scheme string naming the host */
  int descr;       /* listening descriptor for server sockets; -1 for
                      sockets built by make-client-socket */
  SCM input;       /* input port, or Ntruth while none is attached */
  SCM output;      /* output port, or Ntruth while none is attached */
};


/* Type tag of socket cells; assigned in STk_init_socket(). */
static int tc_socket;


/*
 * Accessors for socket cells.  The argument is parenthesized in every
 * expansion so that any pointer expression (not only a plain identifier)
 * can be passed safely.
 *
 *   SOCKET(x)   - the cell's extension data viewed as a struct socket_type*
 *   LSOCKET(x)  - the same slot as an assignable lvalue
 *   SOCKETP(x)  - true when x carries the socket type tag
 *   NSOCKETP(x) - negation of SOCKETP
 */
#define SOCKET(x)   ((struct socket_type*)((x)->storage_as.extension.data))
#define LSOCKET(x)  ((x)->storage_as.extension.data)
#define SOCKETP(x)  (TYPEP((x),tc_socket))
#define NSOCKETP(x) (NTYPEP((x),tc_socket))


/*
 * Report a failed system call through Err().
 *
 * The message has the form "<who>: <strerror(errno)>".
 *
 *   who - name of the primitive reporting the failure (message prefix).
 *
 * errno is sampled on entry, so call this before anything that could
 * overwrite it.
 */
static void system_error(char *who)
{
  char buffer[512];
  int saved_errno = errno;  /* capture before any library call can clobber it */

  /* Bounded formatting: an overlong message is truncated rather than
     written past the end of the stack buffer. */
  snprintf(buffer, sizeof(buffer), "%s: %s", who, strerror(saved_errno));
  Err(buffer, NIL);
}

/*
 * Attach a pair of Scheme ports to the socket descriptor `s` and store them
 * in the input/output fields of `sock`.
 *
 *   s    - socket descriptor to wrap; it becomes the input stream
 *   sock - socket object whose input/output fields are filled in
 *   who  - primitive name used as prefix of the error message
 *
 * The descriptor is duplicated so that each port owns its own descriptor:
 * the input stream is opened on `s`, the output stream on the duplicate.
 * Both ports are named "<hostname>:<port>".
 *
 * On failure every descriptor obtained so far is closed, the interrupt
 * state saved on entry is restored, and the error is reported via Err().
 */
static void set_socket_io_ports(int s, SCM sock, char *who)
{
  long flag = No_interrupt(1);
  int t, len, port;
  char *hostname;
  FILE *fs = NULL, *ft = NULL;
  SCM tmp;

  t = dup(s); /* duplicate handles so that we are able to access one
		 socket channel via two scheme ports */
  if (t >= 0) {
    fs = fdopen(s, "r");
    if (fs) ft = fdopen(t, "w");   /* write side uses the duplicate */
  }

  if (!(fs && ft)) {
    char buffer[200];

    /* Release whatever was acquired; fclose() also closes s. */
    if (fs) fclose(fs); else close(s);
    if (t >= 0) close(t);

    No_interrupt(flag);
    snprintf(buffer, sizeof(buffer), "%s: cannot create socket io ports", who);
    Err(buffer, NIL);
    return;   /* only reached if Err() returns */
  }

  port     = SOCKET(sock)->portnum;
  hostname = CHARS(SOCKET(sock)->hostname);
  len      = strlen(hostname) + 16;  /* room for ':', the port digits and NUL */

  /* Create input port */
  NEWCELL(tmp, tc_iport);
  tmp->storage_as.port.f = fs;
  tmp->storage_as.port.name = (char*) must_malloc(len);
  snprintf(tmp->storage_as.port.name, len, "%s:%d", hostname, port);
  SOCKET(sock)->input = tmp;

  /* Create output port */
  NEWCELL(tmp, tc_oport);
  tmp->storage_as.port.f = ft;
  tmp->storage_as.port.name = (char*) must_malloc(len);
  snprintf(tmp->storage_as.port.name, len, "%s:%d", hostname, port);
  SOCKET(sock)->output = tmp;

  No_interrupt(flag);
}

/*
 * (make-server-socket [port])
 *
 * Create a TCP socket bound to INADDR_ANY and put it in listening state.
 *
 *   port - port number to bind; when omitted (UNBOUND) 0 is used and the
 *          port actually assigned is read back with getsockname().
 *
 * Returns a new socket object whose hostname is "localhost", whose descr is
 * the listening descriptor and whose input/output fields are Ntruth.
 *
 * Errors (through Err()/system_error()): port outside 0..65535, socket
 * creation failure, or failure of bind/getsockname/listen.  The descriptor
 * is closed before a system call failure is reported.
 */
static PRIMITIVE make_server_socket(SCM port)
{
  struct sockaddr_in sin;
  socklen_t len;
  int s, portnum, saved_errno;
  SCM z;
  char msg[] = "make-server-socket";

  portnum = (port == UNBOUND) ? 0 : STk_integer_value(port);
  if (portnum < 0 || portnum > 65535)
    Err("make-server-socket: bad port number", port);

  if ((s = socket(AF_INET, SOCK_STREAM, 0)) < 0) Err("Cannot create socket", NIL);

  memset(&sin, 0, sizeof(sin));
  sin.sin_family      = AF_INET;
  sin.sin_port        = htons((unsigned short) portnum);
  sin.sin_addr.s_addr = htonl(INADDR_ANY);

  /* Bind, query the socket name (this yields the real port number when 0
     was given), then start listening. */
  len = sizeof(sin);
  if (bind(s, (struct sockaddr *) &sin, sizeof(sin)) < 0 ||
      getsockname(s, (struct sockaddr *) &sin, &len) < 0 ||
      listen(s, 5) < 0) {
    /* Do not leak the descriptor; keep errno intact for the message. */
    saved_errno = errno;
    close(s);
    errno = saved_errno;
    system_error(msg);
    return UNDEFINED;   /* only reached if the error reporter returns */
  }

  /* Now we can create the socket object */
  NEWCELL(z, tc_socket);
  LSOCKET(z) = (struct socket_type*)
    				must_malloc(sizeof (struct socket_type));

  SOCKET(z)->portnum  = ntohs(sin.sin_port);
  SOCKET(z)->hostname = STk_makestring("localhost");
  SOCKET(z)->descr    = s;
  SOCKET(z)->input    = Ntruth;
  SOCKET(z)->output   = Ntruth;

  return z;
}

/* (socket? obj): Truth when obj carries the socket type tag, Ntruth otherwise. */
static PRIMITIVE socketp(SCM sock)
{
  if (SOCKETP(sock))
    return Truth;
  return Ntruth;
}

/* (socket-port-number sock): the socket's port number as a Scheme integer. */
static PRIMITIVE socket_port_number(SCM sock)
{
  if (NSOCKETP(sock)) {
    Err("socket-port-number: bad socket", sock);
  }

  return STk_makeinteger(SOCKET(sock)->portnum);
}

/* (socket-input sock): the value stored in the socket's input field. */
static PRIMITIVE socket_input(SCM sock)
{
  if (NSOCKETP(sock)) {
    Err("socket-input: bad socket", sock);
  }

  return SOCKET(sock)->input;
}

/* (socket-output sock): the value stored in the socket's output field. */
static PRIMITIVE socket_output(SCM sock)
{
  if (NSOCKETP(sock)) {
    Err("socket-output: bad socket", sock);
  }

  return SOCKET(sock)->output;
}

/* (socket-hostname sock): the host name recorded in the socket object. */
static PRIMITIVE socket_hostname(SCM sock)
{
  if (NSOCKETP(sock)) {
    Err("socket-hostname: bad socket", sock);
  }

  return SOCKET(sock)->hostname;
}


/*
 * (socket-accept-connection sock)
 *
 * Calls accept() on the descriptor stored in sock and attaches the
 * resulting connection to sock's input/output ports.  Returns UNDEFINED.
 */
static PRIMITIVE socket_accept_connection(SCM sock)
{
  char who[] = "socket-accept-connection";
  int conn;

  if (NSOCKETP(sock)) {
    Err("socket-accept-connection: Bad socket", sock);
  }

  conn = accept(SOCKET(sock)->descr, NULL, NULL);
  if (conn < 0) {
    system_error(who);
  }

  set_socket_io_ports(conn, sock, who);
  return UNDEFINED;
}


/*
 * (make-client-socket hostname port)
 *
 * Resolve `hostname`, open a TCP connection to `port` on it and return a
 * socket object whose input/output ports are attached to the connection.
 *
 *   hostname - Scheme string naming the remote host
 *   port     - integer port number, 0..65535
 *
 * The returned object has descr set to -1; the connection's descriptor is
 * handed to set_socket_io_ports().
 *
 * Errors (through Err()/system_error()): bad argument types, port out of
 * range, unknown host, a resolved address that is not AF_INET, socket
 * creation failure, or connect() failure (the descriptor is closed before
 * the failure is reported).
 */
static PRIMITIVE make_client_socket(SCM hostname, SCM port)
{
  char str[] = "make-client-socket";
  struct hostent *hp;
  struct sockaddr_in server;
  long portnum;
  int s, saved_errno;
  SCM z;

  if(NSTRINGP(hostname)) Err("make-client-socket: bad hostname", hostname);
  if(NINTEGERP(port))    Err("make-client-socket: bad port number", port);

  /* htons() would silently truncate anything that does not fit 16 bits */
  portnum = INTEGER(port);
  if (portnum < 0 || portnum > 65535)
    Err("make-client-socket: bad port number", port);

  if ((hp=gethostbyname(CHARS(hostname))) == NULL)
    Err("make-client-socket: unknown or misspelled host name", hostname);

  /* The socket below is AF_INET; refuse an address that would not fit
     sin_addr instead of overrunning it in memcpy(). */
  if (hp->h_addrtype != AF_INET ||
      hp->h_length < 0 ||
      (size_t) hp->h_length > sizeof(server.sin_addr))
    Err("make-client-socket: unsupported address family", hostname);

  memset(&server, 0, sizeof(server));
  memcpy((char*)&server.sin_addr, hp->h_addr, hp->h_length);

  server.sin_family = hp->h_addrtype;
  server.sin_port   = htons((unsigned short) portnum);

  if ((s=socket(AF_INET,SOCK_STREAM,0)) < 0)  Err("Cannot create socket", NIL);

  if (connect(s, (struct sockaddr *) &server, sizeof(server)) < 0) {
    /* Do not leak the descriptor; keep errno intact for the message. */
    saved_errno = errno;
    close(s);
    errno = saved_errno;
    system_error(str);
    return UNDEFINED;   /* only reached if the error reporter returns */
  }

  NEWCELL(z, tc_socket);
  LSOCKET(z) = (struct socket_type*)
    				must_malloc(sizeof (struct socket_type));

  SOCKET(z)->portnum  = ntohs(server.sin_port);
  SOCKET(z)->hostname = hostname;
  SOCKET(z)->descr    = -1;
  SOCKET(z)->input    = Ntruth;
  SOCKET(z)->output   = Ntruth;

  set_socket_io_ports(s, z, str);
  return z;
}

/*
 * (socket-shutdown sock [close-socket])
 *
 * Shut down the connection attached to `sock`.
 *
 *   sock         - socket object
 *   close_socket - boolean, defaults to Truth when omitted; when Truth the
 *                  descriptor stored in descr is closed as well
 *
 * Returns UNDEFINED.  Safe to call more than once on the same socket: a
 * closed descriptor is forgotten (descr is reset to -1), and port fields
 * still holding Ntruth are skipped.
 */
static PRIMITIVE socket_shutdown(SCM sock, SCM close_socket)
{
  SCM tmp;

  if (close_socket == UNBOUND) close_socket = Truth;

  if (NSOCKETP(sock)) 	       Err("socket-shutdown: bad socket", sock);
  if (NBOOLEANP(close_socket)) Err("socket-shutdown: bad boolean", close_socket);

  /* -1 is the "no descriptor" marker; 0 is a valid descriptor. */
  if (close_socket == Truth && SOCKET(sock)->descr >= 0) {
    close(SOCKET(sock)->descr);
    /* Forget it, so a later call (e.g. from free_socket) cannot close a
       descriptor number that has since been reused for something else. */
    SOCKET(sock)->descr = -1;
  }

  /* input/output stay Ntruth until ports are attached; only a real port
     may be looked at through storage_as.port. */
  tmp = SOCKET(sock)->input;
  if (tmp != Ntruth && tmp->storage_as.port.f) /* not already closed */
    shutdown(fileno(tmp->storage_as.port.f), 2);  /* 2: both directions */

  tmp = SOCKET(sock)->output;
  if (tmp != Ntruth && tmp->storage_as.port.f) /* not already closed */
    shutdown(fileno(tmp->storage_as.port.f), 2);  /* 2: both directions */

  return UNDEFINED;
}


/******************************************************************************/


/* GC mark hook: marks the three Scheme objects referenced by a socket. */
static void mark_socket(SCM sock)
{
  struct socket_type *data = SOCKET(sock);

  STk_gc_mark(data->hostname);
  STk_gc_mark(data->input);
  STk_gc_mark(data->output);
}


/* GC free hook: shuts the socket down (closing its descriptor), then
   releases the socket_type record attached to the cell. */
static void free_socket(SCM sock)
{
  struct socket_type *data;

  socket_shutdown(sock, Truth);

  data = SOCKET(sock);
  free(data);
}

/*
 * Display hook: writes "#[socket <hostname> <port>]" to `port`.
 * The text is formatted into STk_tkbuffer first; `mode` is not used.
 */
static void displ_socket(SCM sock, SCM port, int mode)
{
  sprintf(STk_tkbuffer,
          "#[socket %s %d]",
          CHARS(SOCKET(sock)->hostname),
          SOCKET(sock)->portnum);

  Puts(STk_tkbuffer, port->storage_as.port.f);
}

/*
 * Descriptor of the socket extended type, registered with
 * STk_add_new_type() in STk_init_socket().  It supplies the type's name
 * and the hooks used to mark, free and display socket objects; sockets
 * are not applicable, so no apply function is given.
 */
static STk_extended_scheme_type socket_type = {
  "socket",		/* name */
  0, 			/* is_procp */
  mark_socket, 		/* gc_mark_fct */
  free_socket,		/* gc_free_fct */
  NULL,			/* apply_fct */
  displ_socket		/* display_fct */
};

/******************************************************************************/

/*
 * Module initialisation: registers the socket type (storing its tag in
 * tc_socket) and then every socket primitive.  Returns UNDEFINED.
 */
PRIMITIVE STk_init_socket(void)
{
  /* The type tag must exist before any primitive can build or test sockets. */
  tc_socket = STk_add_new_type(&socket_type);

  /* Server-side constructor */
  STk_add_new_primitive("make-server-socket", tc_subr_0_or_1, make_server_socket);

  /* Predicate and accessors */
  STk_add_new_primitive("socket?", tc_subr_1, socketp);
  STk_add_new_primitive("socket-port-number", tc_subr_1, socket_port_number);
  STk_add_new_primitive("socket-input", tc_subr_1, socket_input);
  STk_add_new_primitive("socket-output", tc_subr_1, socket_output);
  STk_add_new_primitive("socket-hostname", tc_subr_1, socket_hostname);

  /* Connection management */
  STk_add_new_primitive("socket-shutdown", tc_subr_1_or_2, socket_shutdown);
  STk_add_new_primitive("make-client-socket", tc_subr_2, make_client_socket);
  STk_add_new_primitive("socket-accept-connection", tc_subr_1,
                        socket_accept_connection);

  return UNDEFINED;
}
