/*  PKE.c
 *    Implements all the routines needed to support the PKE.el
 *  routines (see pkeapi). It makes extensive use of the RSAREF
 *  implementation from RSA Laboratories.
 *
 *  The data structure is a linked list of pke_elt's, where each
 *  pke_elt holds:
 *      data    :   the encrypted item (random pad + plaintext).
 *      datalen :   the number of valid bytes in data.
 *      next    :   a pointer to the next element in the list,
 *                    or NULL for the last element.
 */


#include "incremental.h"

#include <stdio.h>
#include <sys/param.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/time.h>
#include <fcntl.h>
#include <unistd.h>
#include <time.h>
#include <string.h>
#include <strings.h>
#include <stdlib.h>
#include <pwd.h>

/* File-scope state shared by the command handlers below.  These are not
   declared static, and head, current_num, current_ptr and input all
   change as commands are processed. */
R_RSA_PUBLIC_KEY key;         /* public key read by set_key() */
int key_good = 0;             /* set to 1 by set_key() once a key file was
                                 opened; checked before encrypting */
R_RANDOM_STRUCT randomStruct; /* RSAREF random state: seeded in
                                 Init_setup(), consumed by PKE_encrypt() */
pke_elt *head = NULL; /* The head of the list */
int current_num = 0;         /* 1-based position of the element we just
				looked at; 0 means the cache is empty */
pke_elt *current_ptr; /* a cache pointer to the element we just
				dealt with (valid only when current_num != 0) */
char input[MAX_INPUT_SIZE];  /* shared line buffer for fgets(); fixed
                                size, so longer lines arrive in pieces */


/* Forward declarations of the handlers defined below. */
int main_loop(), get_input();
void set_key(), reset_structure(), load_file(), create(),
  modify(), delete(), insert(), save_file(), 
  decode(char *, char *), PKE_encrypt(char *, int, pke_elt *),
  pad_with_random(char *), get_home(char *), Init_setup();


int main(int argc, char *argv[]) {
  /* Program entry point.  All commands arrive on stdin, so any
     command-line argument is rejected.  Returns main_loop()'s status. */
  (void)argv;

  if (argc > 1) {
    fprintf(stderr, 
	    "This program has NO arguments. See the documentation!\n");
    exit(1);
  }
  /* Setup whatever is necessary... */
  Init_setup();

  return main_loop();
}

int main_loop() {
  /* Read commands with get_input() and dispatch each one to its handler
     until a quit command arrives or stdin is exhausted.  Returns 0. */

  int finished = 0;

  while (!finished) {
    switch (get_input()) {
    case INPUT_QUIT:
      finished = 1;
      break;
    case INPUT_KEY:
      reset_structure();
      set_key();
      break;
    case INPUT_CREATE:
      create();
      break;
    case INPUT_INSERT:
      insert();
      break;
    case INPUT_MODIFY:
      modify();
      break;
    case INPUT_DELETE:
      delete();
      break;
    case INPUT_SAVE:
      save_file();
      break;
    case INPUT_LOAD:
      load_file();
      break;
    case INPUT_ERROR:
      /* get_input() also returns INPUT_ERROR when fgets() fails.  At end
         of input, or on a read error, every later read fails the same
         way, so stop instead of spinning forever. */
      if (feof(stdin) || ferror(stdin)) {
        finished = 1;
        break;
      }
      fprintf(stderr, "Got something bad: %s\n", input);
      break;
    default:
      fprintf(stderr, "I'm confused : %s\n", input);
      exit(1);
    }
  }
  return 0;
}

int get_input() {
  /* Read one command line into the global input buffer and classify it.
     A line matches a command when it starts with that command's keyword
     (case-insensitive).  Returns the matching INPUT_* code, or INPUT_ERROR
     if the read fails or nothing matches. */
  const struct {
    const char *word;
    int code;
  } commands[] = {
    { "public-key", INPUT_KEY },
    { "create",     INPUT_CREATE },
    { "insert",     INPUT_INSERT },
    { "modify",     INPUT_MODIFY },
    { "delete",     INPUT_DELETE },
    { "save",       INPUT_SAVE },
    { "load",       INPUT_LOAD },
    { "quit",       INPUT_QUIT },
  };
  size_t idx;

  if (fgets(input, MAX_INPUT_SIZE, stdin) == NULL)
    return INPUT_ERROR;

  /* First keyword whose text is a prefix of the line wins */
  for (idx = 0; idx < sizeof commands / sizeof commands[0]; idx++) {
    if (strncasecmp(input, commands[idx].word,
                    strlen(commands[idx].word)) == 0)
      return commands[idx].code;
  }
  return INPUT_ERROR;
}

void decode(char *input, char *output){
  /* Decode a string of hex byte codes separated by single spaces, e.g.
     "42 6a", into raw bytes.  Each 3-character group ("hh ") yields one
     output byte; the trailing space of the last group is optional.
     Assumes output has room for every decoded byte.  No NUL terminator
     is written. */
  size_t in_len = strlen(input);
  size_t ref_i;
  size_t ref_o = 0;

  /* Visit every group start: a final group without its trailing space
     (the "6a" in "42 6a", or a lone "42") must still be decoded. */
  for (ref_i = 0; ref_i < in_len; ref_i += 3)
    output[ref_o++] = (char)(unsigned char)strtol(&input[ref_i], NULL, 16);
}

void set_key() {
  int fd;
  char keyname[MAXPATHLEN];
  
  if ((fgets(input, MAX_INPUT_SIZE, stdin)) == NULL){
    perror("fgets");
    return;
  }
  input[(strlen(input) - 1)] = '\0';  /* strip \n */
  get_home(keyname);
  strcat(keyname, PKE_PATH);
  strcat(keyname, input);
  strcat(keyname, ".pub");

  if ((fd = open(keyname, O_RDONLY)) < 0) {
    fprintf(stderr, "while trying to open %s \n", keyname);
    perror("open");
    return;
  }
  read(fd, &key.bits, sizeof(key.bits)); 
  read(fd, &key.modulus, sizeof(key.modulus)); 
  read(fd, &key.exponent, sizeof(key.exponent));

  key_good = 1;
}

void reset_structure() {
  /* This should reset the data structure to empty */
  
  pke_elt *next, *pos = head;

  if (pos != NULL)
    next = pos->next;
  while (pos != NULL){
    free(pos);
    pos = next;
    if (next != NULL)
      next = next->next;
  }
  head = NULL;
  current_num = 0;
}

void pad_with_random(char* string)
{ /* Fill the first RANDOM_PAD_SIZE bytes of string with pseudo-random
     bytes from rand().  No terminator is written.

     The generator is seeded from time() only on the first call.  Seeding
     on every call would give every pad produced within the same second
     (e.g. all elements of one Create: command) identical contents.

     NOTE(review): rand() is not a cryptographic generator.  If these pads
     must be unpredictable, consider drawing them from the RSAREF
     randomStruct instead (e.g. R_GenerateBytes) -- confirm intent. */
  static int seeded = 0;
  char *ptr = string;
  char *end = string + RANDOM_PAD_SIZE;

  if (!seeded) {
    srand((unsigned int)time(NULL));
    seeded = 1;
  }

  while (ptr < end)
    *ptr++ = (char)rand();
}


void create() {
  /* Create: operation.  Discards the current list, then reads lines from
     stdin until one starting with "end" (case-insensitive).  Each other
     line holds the decimal ASCII code of one plaintext byte.  That byte
     is placed after RANDOM_PAD_SIZE bytes of random padding, encrypted
     with the current key, and appended to the list as a new element.
     EOF before the terminating line is fatal. */
  pke_elt *ptr;
  pke_elt *prev = NULL;
  /* random pad + one plaintext byte.  Zero-filled so no uninitialized
     bytes reach the encryption. */
  char plain[DATA_LEN + 1] = {0};
  int have_key;
  int done = 0;

  /* Make sure we've cleared everything */
  reset_structure();

  /* We need a key value.  Without one, still consume the data lines and
     the terminating End, so they are not misread as commands. */
  have_key = key_good;
  if (!have_key)
    fprintf(stderr, "No key supplied!\n");

  while (!done) {
    if (fgets(input, MAX_INPUT_SIZE, stdin) == NULL) {
      perror("fgets");
      exit(1);
    }
    if (!strncasecmp(input, "end", sizeof("end") - 1)) {
      done = 1;
      continue;
    }
    if (!have_key)
      continue;

    if ((ptr = malloc(sizeof *ptr)) == NULL) {
      fprintf(stderr, "malloc failed.. out of memory?\n");
      exit(1);
    }
    ptr->next = NULL;
    if (prev == NULL)
      head = ptr;
    else
      prev->next = ptr;
    prev = ptr;

    /* Decode the byte into a data block with randomizer, then encrypt */
    pad_with_random(plain);
    plain[RANDOM_PAD_SIZE] = (char)strtol(input, NULL, 10);
    PKE_encrypt(plain, (int)(DATA_LEN), ptr);

    /* update our 'cached value' */
    current_num++;
  }
  /* cache now refers to the last element created (NULL if none) */
  current_ptr = prev;
}

void modify() {
  /* Modify: operation.  Reads two lines from stdin:
       1. the 1-based position of the element to replace (decimal);
       2. the new plaintext as hex byte codes, e.g. "42 6a" (see decode()).
     The plaintext is prefixed with random padding, encrypted with the
     current key, and stored in place of that element's data.  Errors are
     reported on stderr and leave the list unchanged. */
  int cur_pos = 1, searching;
  pke_elt *cur_ptr = head;
  /* Zero-filled so bytes decode() does not write are never encrypted
     as uninitialized stack data */
  char newdata[DATA_LEN + 1] = {0};

  if (fgets(input, MAX_INPUT_SIZE, stdin) == NULL) {
    perror("fgets");
    return;
  }
  if (sscanf(input, "%d", &searching) != 1) {
    fprintf(stderr, "sscanf failure\n");
    return;
  }
  if (fgets(input, MAX_INPUT_SIZE, stdin) == NULL) {
    perror("fgets");
    return;
  }
  input[strcspn(input, "\n")] = '\0';  /* strip \n, if present */

  /* decode() writes at most one byte per started 3-character group.
     Refuse input whose bytes would not fit after the random pad. */
  if ((strlen(input) + 2) / 3 > (size_t)(DATA_LEN - RANDOM_PAD_SIZE)) {
    fprintf(stderr, "modify: data too long\n");
    return;
  }

  /* use cache if can: start at the cached element when it is not past
     the target */
  if (current_num != 0 && current_ptr != NULL && current_num <= searching) {
    cur_pos = current_num;
    cur_ptr = current_ptr;
  }
  while (cur_ptr != NULL && cur_pos != searching) {
    cur_ptr = cur_ptr->next;
    cur_pos++;
  }

  /* The walk ends either on the target or off the end of the list */
  if (cur_ptr == NULL) {
    fprintf(stderr, "modify ran past end of structure\n");
    return;
  }

  pad_with_random(newdata);
  decode(input, &newdata[RANDOM_PAD_SIZE]);
  PKE_encrypt(newdata, DATA_LEN, cur_ptr);
  current_num = cur_pos;
  current_ptr = cur_ptr;
}

void delete() {
  /* Delete: operation.  Reads the 1-based position of the element to
     remove from stdin, unlinks and frees it, and points the cache at its
     predecessor.  Deleting the first element empties the cache.  An
     out-of-range position is reported on stderr and changes nothing. */
  int searching, cur_pos = 1;
  pke_elt *prev_ptr = NULL, *cur_ptr = head;

  if (fgets(input, MAX_INPUT_SIZE, stdin) == NULL) {
    perror("fgets");
    return;
  }
  if (sscanf(input, "%d", &searching) != 1) {
    fprintf(stderr, "sscanf failure\n");
    return;
  }

  /* use cache if can: start just after the cached element when it lies
     before the target */
  if (current_num != 0 && current_ptr != NULL && current_num < searching) {
    cur_pos = current_num + 1;
    cur_ptr = current_ptr->next;
    prev_ptr = current_ptr;
  }

  while (cur_ptr != NULL && cur_pos != searching) {
    prev_ptr = cur_ptr;
    cur_ptr = cur_ptr->next;
    cur_pos++;
  }

  if (cur_ptr == NULL || cur_pos != searching) {
    fprintf(stderr, "error deleting\n");
    return;
  }

  if (prev_ptr == NULL) {   /* we're deleting the first element */
    head = cur_ptr->next;
    current_num = 0;        /* blow away cache */
    current_ptr = NULL;
  } else {
    prev_ptr->next = cur_ptr->next;
    current_num = cur_pos - 1;
    current_ptr = prev_ptr;
  }
  free(cur_ptr);
}

void insert() {
  /* Insert: operation.  Reads two lines from stdin:
       1. the 1-based position the new element should occupy
          (1 .. length+1, where length+1 appends);
       2. its plaintext as hex byte codes, e.g. "42 6a" (see decode()).
     The plaintext is prefixed with random padding, encrypted, and linked
     in at that position.  The cache then points at the new element.
     Invalid input is reported on stderr and changes nothing. */
  int searching, cur_pos = 1;
  pke_elt *prev_ptr = NULL, *cur_ptr = head, *new_elt;
  /* Zero-filled so bytes decode() does not write are never encrypted
     as uninitialized stack data */
  char newdata[DATA_LEN + 1] = {0};

  if (fgets(input, MAX_INPUT_SIZE, stdin) == NULL) {
    perror("fgets");
    return;
  }
  if (sscanf(input, "%d", &searching) != 1) {
    fprintf(stderr, "sscanf failure\n");
    return;
  }
  if (fgets(input, MAX_INPUT_SIZE, stdin) == NULL) {
    perror("fgets");
    return;
  }
  input[strcspn(input, "\n")] = '\0';  /* strip \n, if present */

  if (searching < 1) {
    fprintf(stderr, "insert position %d out of range\n", searching);
    return;
  }
  /* decode() writes at most one byte per started 3-character group */
  if ((strlen(input) + 2) / 3 > (size_t)(DATA_LEN - RANDOM_PAD_SIZE)) {
    fprintf(stderr, "insert: data too long\n");
    return;
  }

  /* use cache if can: start just after the cached element when it lies
     before the target */
  if (current_num != 0 && current_ptr != NULL && current_num < searching) {
    cur_pos = current_num + 1;
    cur_ptr = current_ptr->next;
    prev_ptr = current_ptr;
  }

  while (cur_ptr != NULL && cur_pos != searching) {
    prev_ptr = cur_ptr;
    cur_ptr = cur_ptr->next;
    cur_pos++;
  }
  /* Appending at length+1 is valid: the walk then stops with
     cur_ptr == NULL and cur_pos == searching.  Anything further is not. */
  if (cur_pos != searching) {
    fprintf(stderr, "insert position %d out of range\n", searching);
    return;
  }

  if ((new_elt = malloc(sizeof *new_elt)) == NULL) {
    perror("malloc");
    exit(1);
  }
  pad_with_random(newdata);
  decode(input, &newdata[RANDOM_PAD_SIZE]);
  PKE_encrypt(newdata, DATA_LEN, new_elt);

  new_elt->next = cur_ptr;
  if (prev_ptr == NULL)     /* we're inserting at the head */
    head = new_elt;
  else
    prev_ptr->next = new_elt;

  /* The new element sits exactly at position `searching`, so it is
     always a valid cache entry (never NULL) */
  current_num = searching;
  current_ptr = new_elt;
}

void save_file() {
  /* Save: operation.  Reads a file name from stdin and writes the list
     to it.  Each element becomes one record: an int holding datalen,
     followed by datalen bytes of data.  Prints "Done" on success.  Failure
     to open the file is fatal; a write failure is reported and the save
     is abandoned. */
  int fd;
  int len;
  pke_elt *cur_ptr;

  if (fgets(input, MAX_INPUT_SIZE, stdin) == NULL) {
    perror("fgets");
    return;
  }
  input[strcspn(input, "\n")] = '\0';  /* strip \n, if present */

  /* O_TRUNC: without it, saving a shorter list over an existing file
     leaves stale records at the end, which a later load reads back.
     No group/other write bits: nobody else should be able to tamper with
     the saved ciphertext. */
  if ((fd = open(input, O_WRONLY | O_CREAT | O_TRUNC,
                 S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH)) < 0) {
    perror("failed to open file for writing");
    exit(1);
  }

  for (cur_ptr = head; cur_ptr != NULL; cur_ptr = cur_ptr->next) {
    len = cur_ptr->datalen;
    /* Compare against ssize_t so write()'s -1 error return is caught */
    if (len < 0 ||
        write(fd, &len, sizeof len) != (ssize_t)sizeof len ||
        write(fd, cur_ptr->data, (size_t)len) != (ssize_t)len) {
      perror("couldn't write");
      close(fd);
      return;
    }
  }
  if (close(fd) < 0) {
    perror("close");
    return;
  }
  printf("Done\n");
}

void load_file() {
  /* This function implements the Load: operation */
  int fd, done = 0;
  pke_elt *cur_ptr, *prev_ptr = NULL;

  reset_structure();

  if ((fgets(input, MAX_INPUT_SIZE, stdin)) == NULL) {
    perror("fgets");
    return;
  }
  input[(strlen(input) - 1)] = '\0';  /* strip \n */
  if ((fd = open(input, O_RDONLY)) <0) {
    perror("failed to open file");
    exit(1);
  }

  while (!done) {
    if ((cur_ptr = (struct pke_elt_struct *)
	 malloc(sizeof(struct pke_elt_struct))) == NULL){
      fprintf(stderr, "malloc failed to find enough memory");
      exit(1);
    }
    if ((read(fd, cur_ptr->data, MAX_ENC_DATA_LEN)) < 
	MAX_ENC_DATA_LEN){
      done = 1;
      free(cur_ptr);
      if (prev_ptr == NULL)
	head = NULL;
      else
	prev_ptr -> next = NULL;
    }
    else {
      if (prev_ptr == NULL){
	head = cur_ptr;
	prev_ptr = cur_ptr;
      }
      else{
	prev_ptr -> next = cur_ptr;
	prev_ptr = cur_ptr;
      }
    }
  }
}

/* PKE_encrypt: RSA-encrypt inlen bytes of input with the global public
   key and the global RSAREF random state.  The ciphertext goes into
   output->data and its length into output->datalen.  Assumes
   output->data has room for the encrypted block.  Exits the program if
   no key is loaded or RSAREF reports an error.
   (The original carried an unexplained "todo" marker here.) */
void PKE_encrypt (char * input, int inlen, pke_elt *output) {
  int status;
  unsigned int outlen = 0;   /* RSAREF reports the length as unsigned int */

  if (!key_good) {
    fprintf(stderr, "no key!!\n");
    exit(1);
  }
  status = RSAPublicEncrypt((unsigned char *)output->data, &outlen,
                            (unsigned char *)input, (unsigned int)inlen,
                            &key, &randomStruct);
  if (status != 0) {
    fprintf(stderr, "Having problems encrypting (RSAREF status %d)\n",
            status);
    exit(1);
  }
  output->datalen = outlen;
}

void get_home(char *result)
{
  /* Store the user's home directory in result: $HOME when it is set,
     otherwise the passwd entry for the real uid, otherwise "/".
     result must hold MAXPATHLEN bytes (as the caller's buffer in
     set_key() does).  Longer paths are truncated, not overflowed. */
  const char *home = getenv("HOME");
  struct passwd *pwd;

  /* getenv() returns NULL when HOME is unset; it must not reach a
     string copy */
  if (home == NULL) {
    pwd = getpwuid(getuid());
    if (pwd != NULL && pwd->pw_dir != NULL)
      home = pwd->pw_dir;
    else
      home = "/";
  }
  snprintf(result, MAXPATHLEN, "%s", home);
}

void Init_setup() { /* try to initialize things */
  /* Make stdout unbuffered and seed the global RSAREF random state until
     R_GetRandomBytesNeeded() reports that no more seed bytes are
     required.  Up to 256 bytes from /dev/urandom are mixed in first when
     that device can be read.  The remaining need, if any, is filled from
     the low byte of the microsecond clock, one byte per update. */
  unsigned char seed[256];
  unsigned char seedByte;
  unsigned int bytesNeeded;
  struct timeval tp;
  ssize_t got;
  int fd;

  setbuf(stdout, NULL);
  R_RandomInit(&randomStruct);

  /* Time of day alone is guessable, and this state feeds the RSA
     padding, so prefer OS entropy when available */
  if ((fd = open("/dev/urandom", O_RDONLY)) >= 0) {
    got = read(fd, seed, sizeof seed);
    if (got > 0)
      R_RandomUpdate(&randomStruct, seed, (unsigned int)got);
    close(fd);
  }

  /* Always mix in at least one clock byte, as before */
  gettimeofday(&tp, NULL);
  seedByte = (unsigned char)tp.tv_usec;
  R_RandomUpdate(&randomStruct, &seedByte, 1);

  R_GetRandomBytesNeeded(&bytesNeeded, &randomStruct);
  while (bytesNeeded > 0) {
    gettimeofday(&tp, NULL);
    seedByte = (unsigned char)tp.tv_usec;
    R_RandomUpdate(&randomStruct, &seedByte, 1);
    R_GetRandomBytesNeeded(&bytesNeeded, &randomStruct);
  }
}
