#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <malloc.h>
#include "table.h"
#include "util.h"

unsigned long hash (const char *string);

/*
 * Find the slot for string key V in table T using linear probing that
 * starts at hash(V) % t->size and wraps around at the end of the table.
 *
 * If V is already stored, its slot index is returned and T is unchanged.
 * Otherwise the first empty slot on the probe path is reserved: a free
 * entry is consumed (t->available is decremented) and, if that slot is not
 * the home slot, t->col is incremented.  The caller must then fill
 * t->data at the returned index.
 *
 * The process exits if fewer than two free entries remain; this keeps at
 * least one empty slot in the table so probe loops always terminate.  A
 * warning is printed when exactly 100 free entries remain.
 */
static unsigned long
string_table_get_index (table *t, char *v)
{
  unsigned long home = hash (v) % t->size;
  unsigned long slot;

  /* Walk the probe chain until we meet V itself or an empty cell.  */
  for (slot = home; t->data[slot]; )
  {
    if (strcmp (t->data[slot], v) == 0)
      return slot;
    if (++slot == t->size)
      slot = 0;
  }

  /* V is absent; SLOT is the empty cell it is about to occupy.  */
  if (t->available < 2)
  {
    fprintf (stderr, "full %s table! (size %ld)\n", t->name, t->size);
    exit (1);
  }
  if (t->available == 100)
    fprintf (stderr, "down to 100 free entries of %ld in %s table\n",
	     t->size, t->name);

  if (slot != home)
    t->col++;			/* displaced from its home slot */
  t->available--;
  return slot;
}

/*
 * Add AMT to value column 0 of the entry keyed by string V in T.
 * If V is not yet present, a new entry is created: t->data receives
 * mystrdup(V) (presumably a heap copy — see util.h) and the counter
 * starts at AMT.
 */
void
single_table_inc (table *t, char *v, unsigned long amt)
{
  unsigned long slot = string_table_get_index (t, v);

  if (!t->data[slot])
  {
    /* First sighting of this key: store it and initialise its counter.  */
    t->data[slot] = mystrdup (v);
    t->values[0][slot] = amt;
    return;
  }

  /* Key already present: accumulate.  */
  t->values[0][slot] += amt;
}

/*
 * Add AMT1 to value column 0 and AMT2 to value column 1 of the entry keyed
 * by string V in T.  If V is not yet present, a new entry is created:
 * t->data receives mystrdup(V) (presumably a heap copy — see util.h) and
 * the two counters start at AMT1 and AMT2.
 */
void
double_table_inc (table *t, char *v, unsigned long amt1, unsigned long amt2)
{
  unsigned long slot = string_table_get_index (t, v);

  if (!t->data[slot])
  {
    /* First sighting of this key: store it and initialise both counters.  */
    t->data[slot] = mystrdup (v);
    t->values[0][slot] = amt1;
    t->values[1][slot] = amt2;
    return;
  }

  /* Key already present: accumulate into both columns.  */
  t->values[0][slot] += amt1;
  t->values[1][slot] += amt2;
}

/*
 * Add AMT to value column 0 of the entry whose integer key is V in T,
 * inserting the key with value AMT if it is not yet present.
 *
 * Integer keys are stored directly in t->data (cast to char *) and found
 * by linear probing from ((V>>16)+(V>>8)+V) % t->size.  Because an empty
 * slot is represented by a NULL pointer, key 0 cannot be stored.  Such
 * calls are reported once per process on stderr and otherwise ignored.
 *
 * On insertion a free entry is consumed (t->available is decremented) and
 * t->col is incremented if the key did not land in its home slot.  The
 * process exits if fewer than two free entries remain, which keeps one
 * empty slot so probing always terminates.
 */
void
long_table_inc (table *t, unsigned long v, unsigned long amt)
{
  static int warned_zero_key = 0;
  unsigned long oval, hashval;

  /* (char *)0 is a null pointer in practice, i.e. an "empty" slot.
     Inserting it would consume t->available without ever occupying a
     slot, eventually triggering a spurious "full table" exit.  */
  if (v == 0)
  {
    if (!warned_zero_key)
    {
      fprintf(stderr, "ignoring key 0 in %s table (cannot be stored)\n",
	      t->name);
      warned_zero_key = 1;
    }
    return;
  }

  oval = hashval = ((v>>16)+(v>>8)+v) % (unsigned long)t->size;
  while (t->data[hashval])
  {
    if ((unsigned long)t->data[hashval] == v)
    {
      t->values[0][hashval] += amt;
      return;
    }
    hashval++;
    if (hashval == t->size)
      hashval = 0;
  }
  /* Keep at least one empty slot after this insertion.  */
  if (t->available < 2)
  {
    fprintf(stderr, "full %s table! (size %ld)\n", t->name, t->size);
    exit(1);
  }
  if (t->available == 100)
  {
    fprintf(stderr, "down to 100 free entries of %ld in %s table\n",
	    t->size, t->name);
  }
  if (oval != hashval)
    t->col++;
  t->available--;
  t->data[hashval] = (char*)v;
  t->values[0][hashval] = amt;
}

/*
 * Initialise T as an empty table with S slots and DIMS value columns.
 *
 * NAME is stored by pointer (not copied) and appears in diagnostics.
 * Only the key array t->data is zeroed.  Value cells are written when
 * their key is first inserted.
 *
 * Returns 0 on success and 1 on failure.  Failure covers:
 *   - S == 0 (the probing code reduces hashes modulo the size);
 *   - negative DIMS;
 *   - a byte count that would overflow size_t;
 *   - an allocation failure from mymalloc.
 */
int
table_init (table *t, unsigned long s, char *name, int dims)
{
  const size_t size_max = (size_t)-1;
  int i;

  t->name = name;
  t->available = t->size = s;
  t->col = 0;

  /* A zero-slot table would make every later "% t->size" divide by zero.  */
  if (s == 0 || dims < 0)
    return 1;
  /* Refuse sizes whose byte counts would wrap around and silently yield
     undersized buffers.  */
  if (s > size_max / sizeof(char*) ||
      s > size_max / sizeof(unsigned long) ||
      (size_t)dims > size_max / sizeof(unsigned long*))
    return 1;

  /* NOTE(review): on the failure paths below, arrays already obtained from
     mymalloc are not released; confirm whether mymalloc memory may be
     passed to free() before adding cleanup.  */
  if (!(t->data = (char**)mymalloc(s*sizeof(char*))) ||
      !(t->values = (unsigned long**)mymalloc(dims*sizeof(unsigned long*))))
    return 1;
  for (i = 0; i < dims; i++)
    if (!(t->values[i] = (unsigned long*)mymalloc(s*sizeof(unsigned long))))
      return 1;
  memset(t->data, 0, s*sizeof(char*));
  return 0;
}

/*
 * Print a one-line summary of T to stderr: its name, the number of
 * occupied slots (t->size - t->available), the total slot count, and the
 * collision counter t->col.
 */
void
table_info (table *t)
{
  fprintf (stderr, "%s table: %ld/%ld used, %ld collisions\n",
	   t->name,
	   t->size - t->available,	/* occupied slots */
	   t->size,
	   t->col);
}

/*
 * Select up to N occupied entries of T with the largest values in value
 * column DIM.
 *
 * On return, INDEXES[0..k-1] hold the slot indexes of the selected entries
 * and SCRATCH[0..k-1] their values, in descending order of value.
 * k = min(N, number of occupied slots) is returned.  Each array must have
 * room for N elements; nothing at index N or beyond is written.  Returns 0
 * (writing nothing) when N <= 0 or the table is empty.
 */
int
table_select (table *t, int n, int dim,
	      unsigned long *indexes, unsigned long *scratch)
{
  long i, have, min, max, cur;

  if (n <= 0)
    return 0;

  /* Seed the result with the first occupied slot, if any.  */
  for (i = 0; i < t->size; i++)
  {
    if (t->data[i])
      break;
  }
  if (i >= t->size)
    return 0;			/* empty table: nothing to select */
  indexes[0] = i;
  scratch[0] = t->values[dim][i];
  i++;
  for (have = 1; i < t->size; i++)
  {
    unsigned long val;

    if (!t->data[i])
      continue;
    val = t->values[dim][i];
    /* A full buffer only admits values >= its current smallest entry.  */
    if (have >= n && val < scratch[have-1])
      continue;

    /* Binary search for the insertion point in the descending buffer.  */
    min = 0;
    max = have;
    cur = have >> 1;
    while ((max - min) > 1)
    {
      if (val > scratch[cur])
	max = cur;
      else if (val < scratch[cur])
	min = cur;
      else
	break;
      cur = min + ((max - min) >> 1);
    }
    if (val <= scratch[cur])
      cur++;

    if (have < n)
      have++;			/* buffer grows by one */
    else if (cur >= have)
      continue;			/* would land past slot N-1: drop it */
    /* Shift [cur, have-1) right by one.  When the buffer was already full,
       the old last entry falls off instead of being written past N-1.  */
    memmove(indexes+cur+1, indexes+cur, (have-1-cur)*sizeof *indexes);
    memmove(scratch+cur+1, scratch+cur, (have-1-cur)*sizeof *scratch);
    indexes[cur] = i;
    scratch[cur] = val;
  }

  return have;
}

/*
 * Select up to N occupied entries of T with the smallest values in value
 * column DIM.
 *
 * On return, INDEXES[0..k-1] hold the slot indexes of the selected entries
 * and SCRATCH[0..k-1] their values, in ascending order of value.
 * k = min(N, number of occupied slots) is returned.  Each array must have
 * room for N elements; nothing at index N or beyond is written.  Returns 0
 * (writing nothing) when N <= 0 or the table is empty.
 */
int
table_select_min (table *t, int n, int dim,
		  unsigned long *indexes, unsigned long *scratch)
{
  long i, have, min, max, cur;

  if (n <= 0)
    return 0;

  /* Seed the result with the first occupied slot, if any.  */
  for (i = 0; i < t->size; i++)
  {
    if (t->data[i])
      break;
  }
  if (i >= t->size)
    return 0;			/* empty table: nothing to select */
  indexes[0] = i;
  scratch[0] = t->values[dim][i];
  i++;
  for (have = 1; i < t->size; i++)
  {
    unsigned long val;

    if (!t->data[i])
      continue;
    val = t->values[dim][i];
    /* A full buffer only admits values <= its current largest entry.  */
    if (have >= n && val > scratch[have-1])
      continue;

    /* Binary search for the insertion point in the ascending buffer.  */
    min = 0;
    max = have;
    cur = have >> 1;
    while ((max - min) > 1)
    {
      if (val < scratch[cur])
	max = cur;
      else if (val > scratch[cur])
	min = cur;
      else
	break;
      cur = min + ((max - min) >> 1);
    }
    if (val > scratch[cur])
      cur++;

    if (have < n)
      have++;			/* buffer grows by one */
    else if (cur >= have)
      continue;			/* would land past slot N-1: drop it */
    /* Shift [cur, have-1) right by one.  When the buffer was already full,
       the old last entry falls off instead of being written past N-1.  */
    memmove(indexes+cur+1, indexes+cur, (have-1-cur)*sizeof *indexes);
    memmove(scratch+cur+1, scratch+cur, (have-1-cur)*sizeof *scratch);
    indexes[cur] = i;
    scratch[cur] = val;
  }

  return have;
}
