extern "C" {
  #include <stdio.h>
  #include <time.h>
  #include <sys/time.h>
  #include <sys/param.h>
  #include <stdlib.h>
}
#include <Integer.h>
#include "utils.h"

/* All primes below 2000 in ascending order (303 entries), followed by a
 * 0 sentinel so the table can be walked without knowing its length:
 *     for (i = 0; small_primes[i]; i++) ...
 * NOTE(review): no live code in this file appears to read it; check other
 * translation units before removing or shrinking it.
 */
int small_primes[] = { 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41,
		       43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97,
		       101, 103, 107, 109, 113, 127, 131, 137, 139,
		       149, 151, 157, 163, 167, 173, 179, 181, 191,
		       193, 197, 199, 211, 223, 227, 229, 233, 239,
		       241, 251, 257, 263, 269, 271, 277, 281, 283,
		       293, 307, 311, 313, 317, 331, 337, 347, 349,
		       353, 359, 367, 373, 379, 383, 389, 397, 401,
		       409, 419, 421, 431, 433, 439, 443, 449, 457,
		       461, 463, 467, 479, 487, 491, 499, 503, 509,
		       521, 523, 541, 547, 557, 563, 569, 571, 577,
		       587, 593, 599, 601, 607, 613, 617, 619, 631,
		       641, 643, 647, 653, 659, 661, 673, 677, 683,
		       691, 701, 709, 719, 727, 733, 739, 743, 751,
		       757, 761, 769, 773, 787, 797, 809, 811, 821,
		       823, 827, 829, 839, 853, 857, 859, 863, 877,
		       881, 883, 887, 907, 911, 919, 929, 937, 941,
		       947, 953, 967, 971, 977, 983, 991, 997, 1009,
		       1013, 1019, 1021, 1031, 1033, 1039, 1049, 1051,
		       1061, 1063, 1069, 1087, 1091, 1093, 1097, 1103,
		       1109, 1117, 1123, 1129, 1151, 1153, 1163, 1171,
		       1181, 1187, 1193, 1201, 1213, 1217, 1223, 1229,
		       1231, 1237, 1249, 1259, 1277, 1279, 1283, 1289,
		       1291, 1297, 1301, 1303, 1307, 1319, 1321, 1327,
		       1361, 1367, 1373, 1381, 1399, 1409, 1423, 1427,
		       1429, 1433, 1439, 1447, 1451, 1453, 1459, 1471,
		       1481, 1483, 1487, 1489, 1493, 1499, 1511, 1523,
		       1531, 1543, 1549, 1553, 1559, 1567, 1571, 1579,
		       1583, 1597, 1601, 1607, 1609, 1613, 1619, 1621,
		       1627, 1637, 1657, 1663, 1667, 1669, 1693, 1697,
		       1699, 1709, 1721, 1723, 1733, 1741, 1747, 1753,
		       1759, 1777, 1783, 1787, 1789, 1801, 1811, 1823,
		       1831, 1847, 1861, 1867, 1871, 1873, 1877, 1879,
		       1889, 1901, 1907, 1913, 1931, 1933, 1949, 1951,
		       1973, 1979, 1987, 1993, 1997, 1999, 0 };

// Prototypes of functions that should only be accesses from here
Integer       sqrt_p          (Integer, Integer);
int           rabin_miller    (Integer, int);        
void          e_euclid        (Integer, Integer, Integer *, Integer *, Integer *);

void init()
/* One-time module setup: seeds random() from the current time of day.
 * random() drives generate_random(), generate_prime() and rabin_miller(),
 * so call this before anything else in this file.
 */
{
  struct timeval now;

  (void) gettimeofday(&now, NULL);     // failure is ignored, as before
  srandom(now.tv_sec + now.tv_usec);
}

Integer exp_mod(Integer a, Integer b, Integer m)
/* Computes a^b (mod m) by right-to-left square-and-multiply.
 * For each bit of b, starting at the least significant, the running power
 * of a is multiplied into the result if the bit is set, then squared.
 * Every product is reduced mod m, so the numbers never grow beyond about m^2.
 *
 * For m > 0 the result lies in [0, m), also for a negative base and for
 * b == 0 with m == 1 (x^0 is 0 mod 1).
 * b is expected to be non-negative.
 * NOTE(review): the loop only stops when b reaches 0, so a negative b
 * depends on Integer's >> semantics.
 */
{
  Integer s = 1;

  if (m == 1)              // everything, including x^0, is 0 mod 1
    s = 0;
  while (b != 0) {
    if (odd(b))
      s = (s*a) % m;
    b >>= 1;
    a = (a*a) % m;
  }
  if (s < 0)               // a negative base can leave a negative remainder
    s += m;
  return (s);
}

long exp_mod(long a, long b, long m)
/* Computes a^b (mod m) by right-to-left square-and-multiply.
 * Walk the bits of b from least to most significant. Square the running
 * power of a at every step, and multiply it into the result whenever the
 * bit is set.
 *
 * - For m > 0 the result is in [0, m), also for a negative base.
 * - x^0 is 1, except modulo 1, where it is 0.
 * - The base is reduced first. Products are formed in long long, so the
 *   routine is exact whenever m*m fits in a long long. In particular it is
 *   exact for any modulus that fits in a 32-bit long. The old long-only
 *   products limited it to ~15-bit moduli where long is 32 bits.
 * - A negative exponent is treated as 0. The old loop never terminated
 *   for b < 0 on the usual arithmetic right shift, where -1 >> 1 == -1.
 */
{
  long s = (m == 1) ? 0 : 1;

  if (b > 0)
    a %= m;                      // |a| < |m|: first product cannot overflow
  while (b > 0) {
    if (b & 1)
      s = (long)(((long long)s * a) % m);
    b >>= 1;
    a = (long)(((long long)a * a) % m);
  }
  if (s < 0)                     // % keeps the dividend's sign; map into [0, m)
    s += m;
  return (s);
}

Integer generate_random(int size)
/* Returns a random non-negative Integer of at most `size` bits.
 * The top bit is not forced; callers that need exactly `size` bits set it.
 *
 * random() returns values in 0 .. 2^31-1, i.e. 31 random bits per call, so
 * the value is assembled in 31-bit chunks. The old 32-bit chunking left
 * every 32nd bit permanently zero.
 * random() is not a cryptographic generator: fine for testing, not for real
 * keys. Seed it with init() first.
 */
{
  Integer value = 0;
  Integer mask;

  for (int filled = 0; filled < size; filled += 31) {
    value <<= 31;
    value += random();
  }
  mask = (1<<(Integer)size) - 1;     // ones in bits 0 .. size-1
  value &= mask;                     // drop the excess high bits

  return value;
}

Integer generate_prime(int size) 
/* this function generates a prime of the size passed in.
 * It relies on a random number generator, and uses the following
 * scheme:
 *  1) generate a size-bit number p
 *  2) set the size'th bit and bit 0 to 1 (odd and of set size).
 *  3) check to ensure it doesn't devide by small primes
 *  4) do a Rabin-Miller test five times.
 */
{
  int good = 0;
  Integer p, a, b, m;
  FILE *prime_file;
  char name[MAXPATHLEN];
  char value[LARGEST_PRIME_DIGITS];  // okay..this shouldn't be static but that's fast.
  
  sprintf(name, "primes.%d", size);
  if ((prime_file = fopen(name, "r")) != NULL) { // get the prime from the file
    good=(random() % PRIMES_IN_FILE);
    while (good--){
      fgets(value, LARGEST_PRIME_DIGITS-1, prime_file);
    }
    fgets(value, LARGEST_PRIME_DIGITS-1, prime_file);
    p = atoI(value);
    fclose(prime_file);
  }
  else{
    while (!good) {
    try_again:
      good = 1;
      
      p = generate_random(size);
      (setbit)(p, (size-1));      // make sure p is fully size bits
      (setbit)(p, 0);             // make it odd
      
      // p should now be a large random odd number of size bits.
      // test that it doesn't devide by small primes

      /*
	 for (int index = 0; small_primes[index]; index++) {
	 if ((p % small_primes[index]) == 0)
	 goto try_again;
	 }
	 */
      if ((p % 3) == 0)
	goto try_again;
      if ((p % 5) == 0)
	goto try_again;
      if ((p % 7) == 0)
	goto try_again;
      if ((p % 11) == 0)
	goto try_again;
      if ((p % 13) == 0)
	goto try_again;
      if ((p % 17) == 0)
	goto try_again;
      if ((p % 19) == 0)
	goto try_again;
      if ((p % 23) == 0)
	goto try_again;
      if ((p % 29) == 0)
	goto try_again;
      if ((p % 31) == 0)
	goto try_again;
      if ((p % 37) == 0)
	goto try_again;
      
      // okay.. so it doesn't have a small factor, now do a probabilistic
      // test, and cross your fingers
      // Miller-Rabin test:
      
      if (!(rabin_miller(p, TIMES_TO_REPEAT)))
	goto try_again;      
    }
  }
  return(p); 
}

int rabin_miller(Integer p, int repeat)
/* Miller-Rabin probabilistic primality test.
 * Runs `repeat` rounds, each with a fresh random witness a in [2, p-2].
 * Returns 0 as soon as a witness proves p composite. Returns 1 ("probably
 * prime") if every round passes; a composite p survives a single round with
 * probability at most 1/4.
 * p < 5 and even p are answered exactly without drawing witnesses.
 */
{
  Integer m, a, z;
  int j, b = 0;

  if (p < 2)
    return(0);
  if (p < 4)
    return(1);              // 2 and 3
  if (!odd(p))
    return(0);

  // Write p-1 = 2^b * m with m odd.
  m = p - 1;
  while (testbit(m, b) == 0)
    b++;
  m >>= b;

  while (repeat-- > 0) {
    // Witness in [2, p-2]: 0, 1 and p-1 prove nothing, and 0 (or a multiple
    // of p) would wrongly flag a prime as composite.
    a = generate_random(32) % (p - 3) + 2;
    z = exp_mod(a, m, p);
    if ((z == 1) || (z == p - 1))
      continue;             // round passed

    // Square up to b-1 times looking for -1 (mod p). Reaching 1 first means
    // a non-trivial square root of 1, i.e. p is composite.
    for (j = 1; j < b && z != p - 1; j++) {
      z = (z * z) % p;
      if (z == 1)
        return(0);
    }
    if (z != p - 1)
      return(0);            // never reached -1: composite

    // Round passed; continue with the remaining rounds. The old code
    // returned 1 here and silently skipped every remaining round.
  }
  return(1);
}

Integer crt(Integer m[], Integer c[], int num_elements)
/* Chinese Remainder Theorem.
 * Returns the n in [0, M), with M = m[0]*...*m[num_elements-1], such that
 *     n == c[i] (mod m[i])   for every i.
 * The moduli must be positive and pairwise coprime. Otherwise inverse()
 * reports "gcd is NOT 1" and the result is meaningless.
 * num_elements <= 0 yields 0.
 */
{
  int i;
  Integer modulus = 1, n = 0, a;

  for (i = 0; i < num_elements; i++)
    modulus *= m[i];

  for (i = 0; i < num_elements; i++) {
    a = modulus / m[i];                     // product of all the other moduli
    // a * inverse(a, m[i]) is 1 mod m[i] and 0 mod every other modulus.
    n += c[i] * a * (inverse(a, m[i]));
    n %= modulus;
  }
  if (n < 0)          // a negative c[i] can leave a negative remainder
    n += modulus;
  return n;
}

Integer crt2(Integer value1, Integer p, Integer value2, Integer q)
/* Two-modulus shorthand for crt(). Solves
 *     n == value1 (mod p)  and  n == value2 (mod q)
 * by packing the pairs into arrays and delegating to crt().
 */
{
  Integer moduli[2] = { p, q };
  Integer residues[2] = { value1, value2 };

  return crt(moduli, residues, 2);
}

Integer sqrt(Integer value, Integer modulus[], int k)
/* Square root of `value` modulo n = modulus[0]*modulus[1]*...*modulus[k-1].
 * Every modulus[i] must be a prime congruent to 3 (mod 4). The root is
 * computed modulo each prime with sqrt_p() and the pieces are combined
 * with crt().
 * `value` is assumed to be a quadratic residue modulo each prime. This is
 * not checked; otherwise the result is meaningless.
 * Returns 0, after a message on cerr, when k < 1 or a modulus is not
 * 3 mod 4.
 */
{
  if (k < 1) {
    cerr << "sqrt: no moduli given!";
    return(0); // error
  }
  for (int t = 0; t < k; t++) {
    if ((modulus[t] % 4) != 3) {
      cerr << "sqrt: not 3 mod 4 prime!";
      return(0); // error
    }
  }

  if (k == 1)
    return(sqrt_p(value, modulus[0]));

  // Heap array: a variable-length array of a class type (the old
  // `Integer roots[k]`) is a GNU extension, not standard C++.
  Integer *roots = new Integer[k];
  for (int t = 0; t < k; t++)
    roots[t] = sqrt_p(value, modulus[t]);
  Integer result = crt(modulus, roots, k);
  delete[] roots;
  return(result);
}

Integer sqrt2(Integer val, Integer p, Integer q)
/* Square root of `val` modulo n = p*q.
 * p and q must be distinct primes congruent to 3 (mod 4); crt2() needs
 * coprime moduli. `val` is assumed to be a quadratic residue modulo both
 * primes; this is not checked.
 * Returns 0, after a message on cerr, if either prime is not 3 mod 4.
 */
{
  if (((p % 4) != 3) || ((q % 4) != 3)) {
    cerr << "sqrt2: not 3 mod 4 prime!";
    return(0); // error
  }
  return(crt2(sqrt_p(val, p), p, sqrt_p(val, q), q));
}


Integer sqrt_p(Integer value, Integer prime)
/* Square root of `value` modulo a prime p == 3 (mod 4), computed as
 *     r = value^((p+1)/4) mod p.
 * If value is a quadratic residue, r^2 = value^((p-1)/2) * value == value
 * (mod p) by Euler's criterion.
 * For every p > 0, (p+1)/4 equals the former ((p-1)/2 + 1)/2, so results
 * are unchanged.
 * Neither p == 3 (mod 4) nor residuosity is checked here; sqrt() and
 * sqrt2() check the modulus before calling.
 */
{
  return(exp_mod(value, (prime + 1) / 4, prime));
}

Integer inverse(Integer a, Integer p)
/* Returns x in [0, p) with x*a == 1 (mod p), using extended Euclid.
 * p need not be prime, only positive and coprime to a. a may be negative
 * or larger than p.
 * Returns 0, after a message on cerr, if p <= 0 or gcd(a, p) != 1 (no
 * inverse exists).
 */
{
  Integer x, d, y;

  if (p <= 0) {
    cerr << "inverse: modulus must be positive\n";
    return(0);
  }

  // Reduce a into [0, p) first. For a negative a, e_euclid can report the
  // gcd as -1 and a perfectly good inverse would be rejected.
  a %= p;
  if (a < 0)
    a += p;

  e_euclid(a, p, &d, &x, &y);
  if (d != 1) {
    cerr << "inverse: gcd is NOT 1\n";
    return(0);
  }

  x %= p;                 // keep the result canonical
  if (x < 0)
    x += p;
  return(x);
}

void e_euclid(Integer a, Integer b, Integer *d, Integer *x, Integer *y)
/* Extended Euclid (after CLR).
 * Finds d = gcd(a, b) together with coefficients x, y such that
 * d = a*x + b*y, and stores them through d, x and y.
 * This is the iterative form of the textbook recursion: it walks the same
 * sequence of quotients and remainders, so it yields the same d, x and y.
 */
{
  Integer r0 = a, r1 = b;     // successive remainders
  Integer x0 = 1, x1 = 0;     // invariant: r0 == a*x0 + b*y0
  Integer y0 = 0, y1 = 1;     //            r1 == a*x1 + b*y1
  Integer q, next;

  while (r1 != 0) {
    q = r0 / r1;

    next = r0 % r1;      r0 = r1;  r1 = next;
    next = x0 - q * x1;  x0 = x1;  x1 = next;
    next = y0 - q * y1;  y0 = y1;  y1 = next;
  }

  *d = r0;
  *x = x0;
  *y = y0;
}

int legendre(Integer number, Integer prime)
/* Legendre symbol (number / prime), computed with Euler's criterion.
 * number^((prime-1)/2) mod prime is 0, 1 or prime-1 for an odd prime,
 * which gives 0, 1 or -1 respectively.
 * prime is assumed to be an odd prime (not checked). For any other modulus,
 * every power that is neither 0 nor 1 is reported as -1, as before.
 */
{
  Integer a = exp_mod(number, ((prime-1)/2), prime);

  // If % keeps the dividend's sign, a negative number can give a power in
  // (-prime, 0); e.g. 1-prime, which is really 1. Map it into [0, prime)
  // so it is not misread as -1.
  if (a < 0)
    a += prime;

  if (a == 1)
    return 1;
  if (a == 0)
    return 0;
  // a == prime-1, i.e. -1 (mod prime). The old code reached this return
  // unconditionally because of a stray ';' after its if-test.
  return -1;
}

int legendre(long number, long prime)
/* Legendre symbol (number / prime) for machine-word arguments, computed
 * with Euler's criterion.
 * number^((prime-1)/2) mod prime is 0, 1 or prime-1 for an odd prime,
 * which gives 0, 1 or -1 respectively.
 * prime is assumed to be an odd prime (not checked). For any other modulus,
 * every power that is neither 0 nor 1 is reported as -1, as before.
 */
{
  long a = exp_mod(number, ((prime-1)/2), prime);

  // C's % keeps the dividend's sign, so a negative number can give a power
  // in (-prime, 0); e.g. 1-prime, which is really 1. Map it into
  // [0, prime) so it is not misread as -1.
  if (a < 0)
    a += prime;

  if (a == 1)
    return 1;
  if (a == 0)
    return 0;
  // a == prime-1, i.e. -1 (mod prime). The old code reached this return
  // unconditionally because of a stray ';' after its if-test.
  return -1;
}
  
