#include "rabin.h"

void gen_rabin_key(rabin_private_key *k)
/* Generates a secret Rabin key in *k, from which the public info
 * (the modulus n) can be derived.
 *
 *   k->p, k->q : two distinct values obtained from
 *                generate_prime(SIZE_OF_KEYS), each congruent to 3 mod 4
 *   k->n       : the product p * q
 *
 * Candidates are simply regenerated until they satisfy the conditions.
 */
{
  do
    k->p = generate_prime(SIZE_OF_KEYS);
  while ((k->p % 4) != 3);
  // q must also be 3 mod 4, and it must differ from p: with p == q the
  // modulus would be the perfect square p^2, which is trivially factored
  // and is not a product of two distinct primes.
  do
    k->q = generate_prime(SIZE_OF_KEYS);
  while (((k->q % 4) != 3) || (k->q == k->p));
  k->n = k->p * k->q;
}

void rabin_sign(rabin_private_key k, signed_message *message)
/* Standard Rabin signing. On success message->sig is overwritten with
 * the value sqrt2(message->msg, k.p, k.q), i.e. a square root of the
 * message computed from the secret primes.
 *
 * The message is required to be a quadratic residue with respect to both
 * k.p and k.q (legendre() == 1 for each). Callers should ensure this;
 * when it does not hold, a diagnostic is written to cerr and
 * message->sig is left untouched.
 */
{
  // A square root only exists when the message is a residue for both primes.
  const bool is_residue = (legendre(message->msg, k.p) == 1) &&
                          (legendre(message->msg, k.q) == 1);
  if (!is_residue) {
    cerr << "sign_message: got a non quadratic residue message!!\n";
    return;
  }

  message->sig = sqrt2(message->msg, k.p, k.q);
}

int rabin_check(rabin_public_key k, signed_message m)
/* Checks a standard Rabin signature against the public key k.
 * Returns 1 when exp_mod(m.sig, 2, k.n) -- presumably the signature
 * squared modulo n -- equals m.msg, and 0 otherwise.
 */
{
  return (exp_mod(m.sig, 2, k.n) == m.msg) ? 1 : 0;
}

void mod_rabin_sign(rabin_private_key k, signed_message *message)
/* Modified Rabin signing. With s = sqrt2(message->msg, k.p, k.q) taken
 * as the square root of the message, the signature stored in
 * message->sig is the quotient c in
 *        s^2 = m + c*n
 * rather than s itself.
 *
 * As with the standard variant, the message must be a quadratic residue
 * with respect to both k.p and k.q; otherwise a diagnostic goes to cerr
 * and message->sig is not modified.
 */
{
  // Refuse to sign anything that is not a residue for both primes.
  if (!((legendre(message->msg, k.p) == 1) &&
        (legendre(message->msg, k.q) == 1))) {
    cerr << "sign_message: got a non quadratic residue message!!\n";
    return;
  }

  Integer root = sqrt2(message->msg, k.p, k.q);  // the sqrt, which exists
  Integer root_squared = pow(root, 2);

  // c = (s^2 - m) / n
  message->sig = (root_squared - message->msg) / k.n;
}

int mod_rabin_check(rabin_public_key k, signed_message m, int times_to_try)
/* Probabilistically checks a modified Rabin signature.
 *
 * For a valid signature c = m.sig, the value m.msg + k.n * c is a perfect
 * square. Each pass draws a fresh prime from generate_prime(TEMP_PRIME_SIZE),
 * reduces that value modulo the prime, and uses legendre() to test whether
 * the residue is a square modulo the prime.
 *
 *   times_to_try : number of passes; values <= 0 perform no passes.
 *
 * Returns 0 as soon as one pass finds a non-square residue (the signature
 * is definitely invalid), and 1 if no pass could refute the signature.
 */
{
  for (int trial = 0; trial < times_to_try; trial++) {
    long temp_prime = generate_prime(TEMP_PRIME_SIZE).as_long();
    // (m + n*c) mod temp_prime, computed on reduced operands.
    long number = ((m.msg%temp_prime + ((k.n%temp_prime) * (m.sig%temp_prime))) % temp_prime).as_long();

    // A genuine perfect square whose root is divisible by temp_prime
    // reduces to 0. Under the standard Legendre convention that case is 0
    // rather than 1, so it would wrongly be reported as invalid. A zero
    // residue refutes nothing; move on to the next prime.
    if (number == 0)
      continue;

    if (legendre(number, temp_prime) != 1)
      return 0;  // not a square mod temp_prime: not a valid signature
  }
  return 1;  // couldn't find an invalid case.
}
