#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <malloc.h>
#include "ctable.h"
#include "util.h"

/* chaining hash table */

unsigned long hash (const char *string);

/*
 * single_table_inc: add AMT to the counter stored under string key V in T.
 *
 * If V is not present yet, a new entry initialized to AMT is allocated
 * (with a private copy of V stored inline in the entry) and pushed onto the
 * front of its bucket's chain.  The table statistics are maintained:
 * `used` counts non-empty buckets, `ents` counts entries and `longest` is
 * the longest chain length seen so far.  Exits the process if the new
 * entry cannot be allocated.
 */
void
single_table_inc (table *t, char *v, unsigned long amt)
{
  unsigned long len, i, hashval;
  size_t need;
  single_table_ent *ent;

  hashval = hash(v) % t->size;
  for (ent = (single_table_ent*)t->data[hashval], i = 0;
       ent; ent = ent->next, i++)
  {
    if (!strcmp(ent->name, v))
    {
      ent->val += amt;
      return;
    }
  }

  /* Key not found: i now holds the current length of this bucket's chain. */
  len = strlen(v);
  /* The key is stored inline starting at `name`.  Size the allocation from
     the member's offset so it is correct whatever length `name` is declared
     with, but never smaller than the struct itself. */
  need = offsetof(single_table_ent, name) + len + 1;
  if (need < sizeof(*ent))
    need = sizeof(*ent);
  if (!(ent = mymalloc(need)))
  {
    fprintf(stderr, "out of memory for %s table\n", t->name);
    exit(1);
  }
  ent->val = amt;
  strcpy(ent->name, v);
  ent->next = (single_table_ent*)t->data[hashval];
  if (!ent->next)
    t->used++;
  /* The chain now holds i + 1 entries; this also counts the first entry
     placed in a previously empty bucket. */
  if (i >= t->longest)
    t->longest = i + 1;
  t->ents++;
  t->data[hashval] = (generic_table_ent*)ent;
}

/*
 * double_table_inc: add AMT1 and AMT2 to the two counters stored under
 * string key V in T.
 *
 * If V is not present yet, a new entry initialized to (AMT1, AMT2) is
 * allocated (with a private copy of V stored inline in the entry) and
 * pushed onto the front of its bucket's chain.  The table statistics
 * `used`, `ents` and `longest` are kept up to date.  Exits the process if
 * the new entry cannot be allocated.
 */
void
double_table_inc (table *t, char *v, unsigned long amt1, unsigned long amt2)
{
  unsigned long len, i, hashval;
  size_t need;
  double_table_ent *ent;

  hashval = hash(v) % t->size;
  for (ent = (double_table_ent*)t->data[hashval], i = 0;
       ent; ent = ent->next, i++)
  {
    if (!strcmp(ent->name, v))
    {
      ent->val1 += amt1;
      ent->val2 += amt2;
      return;
    }
  }

  /* Key not found: i now holds the current length of this bucket's chain. */
  len = strlen(v);
  /* The key is stored inline starting at `name`.  Size the allocation from
     the member's offset so it is correct whatever length `name` is declared
     with, but never smaller than the struct itself. */
  need = offsetof(double_table_ent, name) + len + 1;
  if (need < sizeof(*ent))
    need = sizeof(*ent);
  if (!(ent = mymalloc(need)))
  {
    fprintf(stderr, "out of memory for %s table\n", t->name);
    exit(1);
  }
  ent->val1 = amt1;
  ent->val2 = amt2;
  strcpy(ent->name, v);
  ent->next = (double_table_ent*)t->data[hashval];
  if (!ent->next)
    t->used++;
  /* The chain now holds i + 1 entries; this also counts the first entry
     placed in a previously empty bucket. */
  if (i >= t->longest)
    t->longest = i + 1;
  t->ents++;
  t->data[hashval] = (generic_table_ent*)ent;
}

/*
 * long_table_inc: add AMT to the counter stored under integer key V in T.
 *
 * If V is not present yet, a new entry initialized to AMT is allocated and
 * pushed onto the front of its bucket's chain.  The table statistics
 * `used`, `ents` and `longest` are kept up to date.  Exits the process if
 * the new entry cannot be allocated.
 */
void
long_table_inc (table *t, unsigned long v, unsigned long amt)
{
  unsigned long hashval, i;
  long_table_ent *ent;

  /* Fold higher bits of the key into the low ones before reducing modulo
     the bucket count. */
  hashval = ((v>>16)+(v>>8)+v) % (unsigned long)t->size;

  for (ent = (long_table_ent*)t->data[hashval], i = 0;
       ent; ent = ent->next, i++)
  {
    if (ent->name == v)
    {
      ent->val += amt;
      return;
    }
  }

  /* Key not found: i now holds the current length of this bucket's chain. */
  if (!(ent = mymalloc(sizeof(*ent))))
  {
    fprintf(stderr, "out of memory for %s table\n", t->name);
    exit(1);
  }
  ent->val = amt;
  ent->name = v;
  ent->next = (long_table_ent*)t->data[hashval];
  if (!ent->next)
    t->used++;
  /* The chain now holds i + 1 entries; this also counts the first entry
     placed in a previously empty bucket. */
  if (i >= t->longest)
    t->longest = i + 1;
  t->ents++;
  t->data[hashval] = (generic_table_ent*)ent;
}

/*
 * table_init: prepare T as an empty chaining hash table with S buckets.
 *
 * NAME is stored by pointer (not copied) and is used in diagnostics.
 * Returns 0 on success.  Returns 1 if S is zero, if the bucket array's
 * byte size would overflow, or if the allocation fails; in every failure
 * case t->data is left NULL.
 */
int
table_init (table *t, unsigned long s, char *name)
{
  t->name = name;
  t->size = s;
  t->used = t->ents = t->longest = 0;
  t->data = NULL;

  /* The *_inc functions reduce hashes modulo t->size, so a zero-bucket
     table would divide by zero; also refuse sizes whose byte count would
     wrap around in the multiplication below. */
  if (s == 0 || s > SIZE_MAX / sizeof(*t->data))
    return 1;
  if (!(t->data = mymalloc(s * sizeof(*t->data))))
    return 1;
  /* All buckets start out as empty chains. */
  memset(t->data, 0, s * sizeof(*t->data));
  return 0;
}

/*
 * table_info: write a one-line summary of T's occupancy statistics
 * (buckets in use, bucket count, entry count, longest chain) to stderr.
 */
void
table_info (table *t)
{
  const table *tab = t;

  fprintf(stderr,
          "%s table: %lu/%lu used, %lu entries, longest chain %lu\n",
          tab->name,
          tab->used,
          tab->size,
          tab->ents,
          tab->longest);
}

/*
 * table_select: collect up to N entries of T with the largest `val`,
 * ordered from largest to smallest.
 *
 * INDEXES receives the entry pointers and SCRATCH the matching values;
 * each must have room for at least N elements.  Returns the number of
 * entries stored: at most N, and 0 for an empty table or when N <= 0.
 * Which of several entries tied at the cut-off value is kept depends on
 * the order in which the table is traversed.
 */
int
table_select (table *t, int n,
              generic_table_ent **indexes, unsigned long long *scratch)
{
  unsigned long i, have, min, max, cur, moved, cap;
  generic_table_ent *ent;

  /* Nothing can be stored.  Without this guard, slot 0 would be written
     unconditionally, and a negative n would wrap to a huge unsigned limit
     in the comparisons below. */
  if (n <= 0)
    return 0;
  cap = (unsigned long)n;

  for (i = 0; i < t->size; i++)
  {
    if (t->data[i])
      break;
  }
  if (i == t->size)
    return 0;
  ent = t->data[i];
  indexes[0] = ent;
  scratch[0] = ent->val;

  for (have = 1;;)
  {
    /* Advance to the next entry, moving on to later buckets as needed. */
    ent = ent->next;
    while (!ent)
    {
      i++;
      if (i == t->size)
        break;
      ent = t->data[i];
    }
    if (i == t->size)
      break;

    if (ent->val >= scratch[have-1] || have < cap)
    {
      /* Binary search for the insertion point in the descending list. */
      min = 0;
      max = have;
      cur = have >> 1;
      while ((max - min) > 1)
      {
        if (ent->val > scratch[cur])
          max = cur;
        else if (ent->val < scratch[cur])
          min = cur;
        else
          break;
        cur = min + ((max - min) >> 1);
      }
      if (ent->val <= scratch[cur])
        cur++;

      /* A full list has no slot past index cap - 1, so an entry that
         belongs there is dropped. */
      if (cur >= cap)
        continue;
      /* Shift the tail down one slot.  When the list is full, the last
         element falls off instead of being written out of bounds. */
      moved = have - cur;
      if (have == cap)
        moved--;
      memmove(indexes + cur + 1, indexes + cur, moved * sizeof(*indexes));
      memmove(scratch + cur + 1, scratch + cur, moved * sizeof(*scratch));
      indexes[cur] = ent;
      scratch[cur] = ent->val;
      if (have < cap)
        have++;
    }
  }

  return have;
}

/*
 * table_select_min: collect up to N entries of T with the smallest `val`,
 * ordered from smallest to largest.
 *
 * INDEXES receives the entry pointers and SCRATCH the matching values;
 * each must have room for at least N elements.  Returns the number of
 * entries stored: at most N, and 0 for an empty table or when N <= 0.
 * Which of several entries tied at the cut-off value is kept depends on
 * the order in which the table is traversed.
 */
int
table_select_min (table *t, int n,
                  generic_table_ent **indexes, unsigned long long *scratch)
{
  unsigned long i, have, min, max, cur, moved, cap;
  generic_table_ent *ent;

  /* Nothing can be stored.  Without this guard, slot 0 would be written
     unconditionally, and a negative n would wrap to a huge unsigned limit
     in the comparisons below. */
  if (n <= 0)
    return 0;
  cap = (unsigned long)n;

  for (i = 0; i < t->size; i++)
  {
    if (t->data[i])
      break;
  }
  if (i == t->size)
    return 0;
  ent = t->data[i];
  indexes[0] = ent;
  scratch[0] = ent->val;

  for (have = 1;;)
  {
    /* Advance to the next entry, moving on to later buckets as needed. */
    ent = ent->next;
    while (!ent)
    {
      i++;
      if (i == t->size)
        break;
      ent = t->data[i];
    }
    if (i == t->size)
      break;

    if (ent->val <= scratch[have-1] || have < cap)
    {
      /* Binary search for the insertion point in the ascending list. */
      min = 0;
      max = have;
      cur = have >> 1;
      while ((max - min) > 1)
      {
        if (ent->val < scratch[cur])
          max = cur;
        else if (ent->val > scratch[cur])
          min = cur;
        else
          break;
        cur = min + ((max - min) >> 1);
      }
      if (ent->val > scratch[cur])
        cur++;

      /* A full list has no slot past index cap - 1, so an entry that
         belongs there is dropped. */
      if (cur >= cap)
        continue;
      /* Shift the tail down one slot.  When the list is full, the last
         element falls off instead of being written out of bounds. */
      moved = have - cur;
      if (have == cap)
        moved--;
      memmove(indexes + cur + 1, indexes + cur, moved * sizeof(*indexes));
      memmove(scratch + cur + 1, scratch + cur, moved * sizeof(*scratch));
      indexes[cur] = ent;
      scratch[cur] = ent->val;
      if (have < cap)
        have++;
    }
  }

  return have;
}
