/* Modular multiplication.

  Copyright 2002, 2003, 2004, 2005 Paul Zimmermann and Alexander Kruppa.

  This file is part of the ECM Library.

  The ECM Library is free software; you can redistribute it and/or modify
  it under the terms of the GNU Lesser General Public License as published by
  the Free Software Foundation; either version 2.1 of the License, or (at your
  option) any later version.

  The ECM Library is distributed in the hope that it will be useful, but
  WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
  or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
  License for more details.

  You should have received a copy of the GNU Lesser General Public License
  along with the ECM Library; see the file COPYING.LIB.  If not, write to
  the Free Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
  MA 02111-1307, USA.
*/

#include <stdlib.h>
#include <stdio.h>
#include "gmp.h"

/* For ECM_ERROR */
#include "ecm.h"

#include "ecm-gmp.h"
#include "ecm-impl.h"

/* File-scope stream handles; this file holds their single definition. */
FILE *ECM_STDOUT, *ECM_STDERR; /* define them here since needed in tune.c */

/* define WANT_ASSERT to check normalization of residues */
/* #define WANT_ASSERT 1 */
/* #define DEBUG */

/* Asserts that x uses no more limbs than the original modulus.  The check
   only applies to the MOD_MODMULN and MOD_REDC representations.  The macro
   refers to a variable named `modulus', which must be in scope wherever it
   is used. */
#define ASSERT_NORMALIZED(x) ASSERT ((modulus->repr != MOD_MODMULN && \
				      modulus->repr != MOD_REDC) || \
			     mpz_size (x) <= mpz_size (modulus->orig_modulus))
/* Asserts that the most significant limb of x is non-zero.
   NOTE(review): this indexes limb ABSIZ(x)-1, so it looks unsuitable for
   x == 0 -- confirm that callers never pass zero. */
#define MPZ_NORMALIZED(x)    ASSERT (PTR(x)[ABSIZ(x)-1] != 0)

/* Thresholds used by mpmod_init() to pick a representation; they fall back
   to the *_DEFAULT values unless already defined. */
#ifndef MPZMOD_THRESHOLD
#define MPZMOD_THRESHOLD MPZMOD_THRESHOLD_DEFAULT
#endif

#ifndef REDC_THRESHOLD
#define REDC_THRESHOLD REDC_THRESHOLD_DEFAULT
#endif


/* Fallback when the GMP header in use does not define GMP_NUMB_BITS. */
#ifndef GMP_NUMB_BITS
#define GMP_NUMB_BITS __GMP_BITS_PER_MP_LIMB
#endif

/* Forward declarations of helpers that are defined further down in this
   file. */
void base2mod (mpres_t, mpres_t, mpres_t, mpmod_t);
void base2mod_1 (mpres_t, mpres_t, mpmod_t);
void REDC (mpres_t, mpres_t, mpz_t, mpmod_t);
void mod_mul2exp (mpz_t, unsigned int, mpmod_t);
void mod_div2exp (mpz_t, unsigned int, mpmod_t);

/* Returns +/-l if n is a factor of N = 2^l +/- 1 with N <= n^threshold,
   0 otherwise.  The sign of the result tells which form was found:
   positive for 2^l + 1, negative for 2^l - 1.

   Method: with lo = floor(log2(n)), compute w = 2^(2*lo) mod n and test
   whether w or -w (mod n) is an exact power of two.

   The two temporaries are released on every return path.
*/
int 
isbase2 (mpz_t n, double threshold)
{
  unsigned int k, lo; 
  int res = 0; 
  mpz_t u, w;

  mpz_init (u);
  mpz_init (w);
  lo = mpz_sizeinbase (n, 2) - 1; /* 2^lo <= n < 2^(lo+1) */  
  mpz_set_ui (u, 1);
  mpz_mul_2exp (u, u, 2 * lo);
  mpz_mod (w, u, n); /* 2^(2lo) mod n = -/+2^(2lo-l) if m*n = 2^l+/-1 */
  if (mpz_cmp_ui (w, 1) == 0) /* if 2^(2lo) mod n = 1, then n divides 2^lo+1, 
				 since n has lo+1 bits. */
    {
      /* Early exit: free the temporaries here too, otherwise they leak. */
      mpz_clear (u);
      mpz_clear (w);
      return lo;
    }
  k = mpz_sizeinbase (w, 2) - 1;
  /* if w = 2^k then n divides 2^(2*lo-k)-1 */
  mpz_set_ui (u, 1);
  mpz_mul_2exp (u, u, k);
  if (mpz_cmp(w, u) == 0) 
    res = k - 2 * lo;
  else /* if w = -2^k then n divides 2^(2*lo-k)+1 */
    {
      mpz_neg (w, w);
      mpz_mod (w, w, n);
      k = mpz_sizeinbase (w, 2) - 1;
      mpz_set_ui (u, 1);
      mpz_mul_2exp (u, u, k);
      if (mpz_cmp (w, u) == 0) 
        res = 2 * lo - k;
    }
  mpz_clear (u);
  mpz_clear (w);
  /* Reject exponents that are too large compared with the size of n. */
  if (abs (res) > (int) (threshold * (double) lo)) 
    res = 0;

  return res;
}

/* Base-2 reduction of S into R for a modulus dividing 2^|bits| +/- 1,
   where the sign of modulus->bits selects the form (negative: 2^|bits| - 1).
   t is scratch space.  R must not equal S or t. */
void
base2mod (mpres_t R, mpres_t S, mpres_t t, mpmod_t modulus)
{
  const unsigned long nbits = abs (modulus->bits);
  const int minus_form = (modulus->bits < 0);

  ASSERT (R != S && R != t);

  /* Split S at bit position nbits: high part -> R, low part -> t.
     The quotient is taken first because callers in this file pass the
     same variable for S and t. */
  mpz_tdiv_q_2exp (R, S, nbits);
  mpz_tdiv_r_2exp (t, S, nbits);

  /* Fold the high part onto the low part: add when 2^nbits == 1,
     subtract when 2^nbits == -1. */
  if (minus_form)
    mpz_add (R, R, t);
  else
    mpz_sub (R, t, R);

  /* Keep folding until R fits in nbits bits
     (cheaper stand-in for mpz_mod (R, R, modulus->orig_modulus)). */
  for (;;)
    {
      if (mpz_sizeinbase (R, 2) <= nbits)
        break;

      mpz_tdiv_q_2exp (t, R, nbits);
      mpz_tdiv_r_2exp (R, R, nbits);
      if (minus_form)
        mpz_add (R, R, t);
      else
        mpz_sub (R, R, t);
    }
}

/* In-place variant of the base-2 reduction: RS is both source and result,
   t is scratch space and must be a different variable. */
void
base2mod_1 (mpres_t RS, mpres_t t, mpmod_t modulus)
{
  const unsigned long nbits = abs (modulus->bits);
  const int minus_form = (modulus->bits < 0);

  ASSERT (RS != t);

  for (;;)
    {
      if (mpz_sizeinbase (RS, 2) <= nbits)
        break;

      /* t gets the part above bit nbits, RS keeps the low nbits bits
         (a plain truncation), then the high part is folded back in. */
      mpz_tdiv_q_2exp (t, RS, nbits);
      mpz_tdiv_r_2exp (RS, RS, nbits);
      if (minus_form)
        mpz_add (RS, RS, t);
      else
        mpz_sub (RS, RS, t);
    }
}

/* Fermat-mod: partial in-place reduction of RS so that it uses at most
   n + 1 limbs.

   RS      value to reduce; modified in place.
   n       limb count; mpres_mul passes Fermat / bits-per-limb.
   modulus integer used for the mpz_mod fallback.

   Nothing is done when RS already has at most n limbs.
   NOTE(review): the fast path presumably relies on 2^(n*limb_bits) == -1
   modulo the Fermat number -- confirm against the caller. */
static void
base2mod_2 (mpres_t RS, mp_size_t n, mpz_t modulus)
{
  mp_size_t s;

  s = ABSIZ(RS);
  if (s > n)
    {
      if (s == n + 1)
        {
          mp_ptr rp = PTR(RS);

          /* Subtract the top limb from the low n limbs; the borrow is
             stored in limb n.  If there was a borrow, add it back to the
             low n limbs and keep the resulting carry in limb n. */
          if ((rp[n] = mpn_sub_1 (rp, rp, n, rp[n])))
            rp[n] = mpn_add_1 (rp, rp, n, rp[n]);
          /* Strip high zero limbs and store the new size with the
             original sign. */
          MPN_NORMALIZE(rp, s);
          SIZ(RS) = (SIZ(RS) > 0) ? s : -s;
        }
      else /* should happen rarely */
        mpz_mod (RS, RS, modulus);
    }
}

/* subquadratic REDC, at mpn level.
   {orig,n} is the original modulus.
   {aux,n} is the auxiliary modulus.
   Requires ABSIZ(x) = 2n and ABSIZ(orig_modulus)=ABSIZ(aux_modulus)=n.

   {rp,n} receives the result, {xp,2n} is the input.
   Scratch space of 4n limbs is taken with TMP_ALLOC_LIMBS and released
   before returning.
 */
static void
ecm_redc_n (mp_ptr rp, mp_srcptr xp, mp_srcptr orig, mp_srcptr aux, mp_size_t n)
{
  mp_ptr tp, up;
  mp_size_t nn = n + n;
  mp_limb_t cy;
  TMP_DECL(marker);

  TMP_MARK(marker);
  up = TMP_ALLOC_LIMBS(nn + nn);
  /* up = low part of x * aux, computed by ecm_mul_lo_n */
  ecm_mul_lo_n (up, xp, aux, n);
  tp = up + nn;
  /* {tp, 2n} = {up, n} * {orig, n} */
  mpn_mul_n (tp, up, orig, n);
  /* add {x, 2n} and {tp, 2n}. We know that {tp, n} + {xp, n} will give
     either 0, or a carry out. If xp[n-1] <> 0, then there is a carry. */
#ifdef HAVE_NATIVE_mpn_add_nc
  cy = __gmpn_add_nc (rp, tp + n, xp + n, n, (mp_limb_t) ((xp[n - 1]) ? 1 : 0));
#else
  cy = mpn_add_n (rp, tp + n, xp + n, n);
  cy += mpn_add_1 (rp, rp, n, (mp_limb_t) ((xp[n - 1]) ? 1 : 0));
#endif
  /* a single conditional subtraction of the modulus corrects the result */
  if (cy || mpn_cmp (rp, orig, n) > 0)
    cy -= mpn_sub_n (rp, rp, orig, n);
  /* ASSERT ((cy == 0) && (mpn_cmp (rp, orig, n) < 0)); */
  TMP_FREE(marker);
}

/* REDC. x and t must not be identical, t has limb growth */
/* subquadratic REDC, at mpz level.

   r       result; it gets the sign of x on the fast path.
   x       input, at most 2n limbs where n = modulus->bits / GMP_NUMB_BITS.
   t       scratch variable, used by the generic path only.
   modulus supplies orig_modulus, aux_modulus, multiple and bits.

   When x has exactly 2n limbs the mpn-level ecm_redc_n is used, otherwise
   the reduction is carried out with mpz calls. */
void 
REDC (mpres_t r, mpres_t x, mpz_t t, mpmod_t modulus)
{
  mp_size_t n = modulus->bits / GMP_NUMB_BITS;

  ASSERT (ABSIZ(x) <= 2 * n);
  if (ABSIZ(x) == 2 * n)
    {
      mp_ptr rp;
      /* grow r to n limbs only if its allocation is smaller */
      if (ALLOC(r) < n)
	_mpz_realloc (r, n);
      rp = PTR(r);
      ecm_redc_n (rp, PTR(x), PTR(modulus->orig_modulus), 
		PTR(modulus->aux_modulus), n);
      /* strip high zero limbs, then store the size with the sign of x */
      MPN_NORMALIZE(rp, n);
      SIZ(r) = (SIZ(x) > 0) ? n : -n;
      MPZ_NORMALIZED (r);
    }
  else
    {
      mpz_tdiv_r_2exp (t, x, modulus->bits);
      mpz_mul (t, t, modulus->aux_modulus);
      mpz_tdiv_r_2exp (t, t, modulus->bits);  /* t = (x % R) * 1/N (mod R) */
      mpz_mul (t, t, modulus->orig_modulus);
      mpz_add (t, t, x);
      mpz_tdiv_q_2exp (r, t, modulus->bits);  /* r = (x + m*N) / R */
      /* bring r back to at most n limbs by subtracting modulus->multiple */
      if (ABSIZ (r) > n)
	mpz_sub (r, r, modulus->multiple);
    }
  ASSERT (ABSIZ(r) <= n);
}

/* c <- c * R^k mod n, where R = 2^(bits per limb) and n is the original
   modulus.  n is supposed odd.  Clobbers modulus->temp1.
   Does not need to be efficient. */
void 
mod_mul2exp (mpz_t c, unsigned int k, mpmod_t modulus)
{
  mpz_ptr scratch = modulus->temp1;
  unsigned long shift = k * __GMP_BITS_PER_MP_LIMB;

  /* shift left by k limbs, then reduce */
  mpz_mul_2exp (scratch, c, shift);
  mpz_mod (c, scratch, modulus->orig_modulus);
}

/* c <- c / R^k mod n, where R = 2^(bits per limb) and n is the original
   modulus.  n is supposed odd, so that R^k is invertible.
   Clobbers modulus->temp1 and modulus->temp2.
   Does not need to be efficient. */
void 
mod_div2exp (mpz_t c, unsigned int k, mpmod_t modulus)
{
  mpz_ptr t1 = modulus->temp1;
  mpz_ptr t2 = modulus->temp2;
  unsigned long shift = k * __GMP_BITS_PER_MP_LIMB;

  /* t1 = R^k */
  mpz_set_ui (t2, 1);
  mpz_mul_2exp (t1, t2, shift);

  /* t2 = R^(-k) (mod n) */
  mpz_invert (t2, t1, modulus->orig_modulus);

  /* c = c * R^(-k) (mod n) */
  mpz_mul (t1, t2, c);
  mpz_mod (c, t1, modulus->orig_modulus);
}

/* r <- c/R^nn mod n, where n has nn limbs, and R=2^GMP_NUMB_BITS.
   n must be odd.
   c must have space for at least 2*nn limbs.
   r must have space for at least nn limbs.
   c and r can be the same variable.
   The data in c is clobbered.

   n is modulus->orig_modulus and nn is modulus->bits divided by the limb
   size; modulus->Nprim is the per-limb multiplier of the reduction.
   The result gets the sign of c.
*/
static void 
ecm_redc_basecase (mpz_ptr r, mpz_ptr c, mpmod_t modulus)
{
  mp_ptr rp;
  mp_ptr cp;
  mp_srcptr np;
  mp_limb_t cy;
  mp_size_t j, nn = modulus->bits / __GMP_BITS_PER_MP_LIMB;

  ASSERT(ABSIZ(c) <= 2 * nn);
  ASSERT(ALLOC(r) >= nn);
  cp = PTR(c);
  rp = PTR(r);
  np = PTR(modulus->orig_modulus);
  /* zero-pad c up to 2*nn limbs */
  for (j = ABSIZ(c); j < 2 * nn; j++) 
    cp[j] = 0;
#ifndef NATIVE_REDC
  /* One step per low limb: add a multiple of n and store the carry-out of
     mpn_addmul_1 in the limb just processed, then advance cp by one limb. */
  for (j = 0; j < nn; j++)
    {
      cp[0] = mpn_addmul_1 (cp, np, nn, cp[0] * modulus->Nprim);
      cp++;
    }
  /* add vector of carries and shift */
  cy = mpn_add_n (rp, cp, cp - nn, nn);
#else
  ecm_redc3 (cp, np, nn, modulus->Nprim);
  /* add vector of carries and shift */
  cy = mpn_add_n (rp, cp + nn, cp, nn);
#endif
  /* the result of Montgomery's REDC is less than 2^Nbits + N,
     thus at most one correction is enough */
  if (cy != 0)
    mpn_sub_n (rp, rp, np, nn); /* a borrow should always occur here */
  /* strip high zero limbs, then store the size with the sign of c */
  MPN_NORMALIZE (rp, nn);
  SIZ(r) = SIZ(c) < 0 ? -nn : nn;
}

/* Initialize modulus for arithmetic modulo N, choosing the representation:
     - special base-2 division if isbase2() recognizes N as a factor of
       2^l +/- 1 (skipped if repr == -1, i.e. -nobase2),
     - MODMULN if mpz_size (N) < MPZMOD_THRESHOLD,
     - plain mpz_mod if mpz_sizeinbase (N, 2) < REDC_THRESHOLD,
     - REDC otherwise. */
void 
mpmod_init (mpmod_t modulus, mpz_t N, int repr)
{
  int base2 = 0;

  if (repr != -1)
    base2 = isbase2 (N, BASE2_THRESHOLD);

  if (base2 != 0)
    {
      int r;

      outputf (OUTPUT_VERBOSE,
               "Using special division for factor of 2^%d%c1\n",
               abs (base2), (base2 < 0) ? '-' : '+');
      r = mpmod_init_BASE2 (modulus, base2, N);
      ASSERT (r == 0); /* error should not happen if isbase2 is correct */
      return;
    }

  if (mpz_size (N) < MPZMOD_THRESHOLD)
    {
      outputf (OUTPUT_VERBOSE, "Using MODMULN\n");
      mpmod_init_MODMULN (modulus, N);
      return;
    }

  if (mpz_sizeinbase (N, 2) < REDC_THRESHOLD)
    {
      outputf (OUTPUT_VERBOSE, "Using mpz_mod\n");
      mpmod_init_MPZ (modulus, N);
      return;
    }

  outputf (OUTPUT_VERBOSE, "Using REDC\n");
  mpmod_init_REDC (modulus, N);
}

/* Set up modulus for the MOD_PLAIN representation (reduction by mpz_mod).
   Copies N and allocates the two scratch variables. */
void 
mpmod_init_MPZ (mpmod_t modulus, mpz_t N)
{
  /* bit size of N, rounded up to a whole number of limbs */
  const int limb_bits = mpz_size (N) * __GMP_BITS_PER_MP_LIMB;

  modulus->repr = MOD_PLAIN;
  mpz_init_set (modulus->orig_modulus, N);

  /* temp1 holds double-length products plus one spare limb */
  mpz_init2 (modulus->temp1, 2 * limb_bits + __GMP_BITS_PER_MP_LIMB);
  mpz_init2 (modulus->temp2, limb_bits);
}

/* Set up modulus for the MOD_BASE2 representation, where N divides
   2^|base2| + 1 (base2 > 0) or 2^|base2| - 1 (base2 < 0).
   Returns 0 on success.  If N does not divide that number, an error is
   printed, everything initialized so far is cleared, and ECM_ERROR is
   returned.
   modulus->Fermat is set to base2 when base2 is a positive power of two,
   and to 0 otherwise. */
int 
mpmod_init_BASE2 (mpmod_t modulus, int base2, mpz_t N)
{
  const int exponent = abs (base2);
  /* bit size of N, rounded up to a whole number of limbs */
  const int limb_bits = mpz_size (N) * __GMP_BITS_PER_MP_LIMB;
  mpz_ptr t1 = modulus->temp1;

  mpz_init_set (modulus->orig_modulus, N);
  modulus->repr = MOD_BASE2;
  modulus->bits = base2;

  mpz_init2 (modulus->temp1, 2 * limb_bits + __GMP_BITS_PER_MP_LIMB);
  mpz_init2 (modulus->temp2, limb_bits);

  /* t1 = 2^|base2| +/- 1, the number N is claimed to divide */
  mpz_set_ui (t1, 1);
  mpz_mul_2exp (t1, t1, exponent);
  if (base2 < 0)
    mpz_sub_ui (t1, t1, 1);
  else
    mpz_add_ui (t1, t1, 1);

  if (!mpz_divisible_p (t1, N))
    {
      outputf (OUTPUT_ERROR, "mpmod_init_BASE2: n does not divide 2^%d%c1\n",
               exponent, base2 < 0 ? '-' : '+');
      mpz_clear (modulus->temp2);
      mpz_clear (modulus->temp1);
      mpz_clear (modulus->orig_modulus);
      return ECM_ERROR;
    }

  /* Fermat case: 2^base2 + 1 with base2 a power of two */
  modulus->Fermat = 0;
  if (base2 > 0
      && ((unsigned long) base2 & ((unsigned long) base2 - 1)) == 0)
    {
      modulus->Fermat = base2;
#ifdef HAVE_GWNUM
      if (modulus->Fermat >= GWTHRESHOLD)
        Fgwinit (modulus->Fermat);
#endif
    }

  return 0;
}

/* Set up modulus for the MOD_MODMULN representation (Montgomery form with
   limb-by-limb reduction, see ecm_redc_basecase).  Besides the scratch
   variables this computes, with B = 2^bits and bits the size of N rounded
   up to whole limbs:
     Nprim    = -1/N mod 2^(bits per limb)
     R2       = B^2 mod N
     R3       = B^3 mod N
     multiple = ceil(B / N) * N  */
void
mpmod_init_MODMULN (mpmod_t modulus, mpz_t N)
{
  const int nbits = mpz_size (N) * __GMP_BITS_PER_MP_LIMB;
  mpz_ptr t1 = modulus->temp1;
  mpz_ptr t2 = modulus->temp2;

  modulus->repr = MOD_MODMULN;
  modulus->bits = nbits;

  mpz_init_set (modulus->orig_modulus, N);
  mpz_init2 (modulus->temp1, 2 * nbits + __GMP_BITS_PER_MP_LIMB);
  mpz_init2 (modulus->temp2, nbits);
  mpz_init (modulus->R2);
  mpz_init (modulus->R3);
  mpz_init (modulus->multiple);

  /* t1 = 2^(bits per limb), t2 = lowest limb of N */
  mpz_set_ui (t1, 1);
  mpz_mul_2exp (t1, t1, __GMP_BITS_PER_MP_LIMB);
  mpz_tdiv_r_2exp (t2, modulus->orig_modulus, __GMP_BITS_PER_MP_LIMB);
  /* t2 = 1/N (mod 2^(bits per limb)) */
  mpz_invert (t2, t2, t1);
  /* Nprim = -1/N (mod 2^(bits per limb)) */
  mpz_sub (t2, t1, t2);
  modulus->Nprim = mpz_getlimbn (t2, 0);

  /* R2 = (2^bits)^2 (mod N) */
  mpz_set_ui (t1, 1);
  mpz_mul_2exp (t1, t1, 2 * nbits);
  mpz_mod (modulus->R2, t1, modulus->orig_modulus);

  /* R3 = (2^bits)^3 (mod N) */
  mpz_mul_2exp (t1, modulus->R2, nbits);
  mpz_mod (modulus->R3, t1, modulus->orig_modulus);

  /* multiple = ceil(2^bits / N) * N, i.e. the smallest multiple of N that
     is >= 2^bits */
  mpz_set_ui (t1, 1);
  mpz_mul_2exp (t1, t1, nbits);
  mpz_cdiv_q (t1, t1, modulus->orig_modulus);
  mpz_mul (modulus->multiple, t1, modulus->orig_modulus);
}

/* Set up modulus for the MOD_REDC representation (Montgomery form with
   full-width reduction, see REDC).  Besides the scratch variables this
   computes, with B = 2^bits and bits the size of N rounded up to whole
   limbs:
     aux_modulus = -1/N mod B
     R2          = B^2 mod N
     R3          = B^3 mod N
     multiple    = ceil(B / N) * N  */
void 
mpmod_init_REDC (mpmod_t modulus, mpz_t N)
{
  const mp_size_t n = mpz_size (N);
  const int nbits = n * __GMP_BITS_PER_MP_LIMB;
  mpz_ptr t1 = modulus->temp1;
  mpz_ptr aux = modulus->aux_modulus;

  mpz_init_set (modulus->orig_modulus, N);
  modulus->repr = MOD_REDC;
  modulus->bits = nbits;

  mpz_init2 (modulus->temp1, 2 * nbits + __GMP_BITS_PER_MP_LIMB);
  mpz_init2 (modulus->temp2, nbits);
  mpz_init (modulus->aux_modulus);
  mpz_init (modulus->R2);
  mpz_init (modulus->R3);
  mpz_init (modulus->multiple);

  /* t1 = 2^bits */
  mpz_set_ui (t1, 1);
  mpz_mul_2exp (t1, t1, nbits);
  /* since we directly check even modulus in ecm/pm1/pp1,
     N is odd here, thus 1/N mod 2^Nbits always exist */
  mpz_invert (aux, N, t1);
  /* aux = -1/N (mod 2^bits) */
  mpz_sub (aux, t1, aux);

  /* ensure aux_modulus has n allocated limbs, for ecm_redc_n: the limbs
     above its current size are zero-filled */
  if (ABSIZ(aux) < n)
    {
      _mpz_realloc (aux, n);
      MPN_ZERO (PTR(aux) + ABSIZ(aux), n - ABSIZ(aux));
    }

  /* R2 = (2^bits)^2 (mod N) */
  mpz_set_ui (t1, 1);
  mpz_mul_2exp (t1, t1, 2 * nbits);
  mpz_mod (modulus->R2, t1, modulus->orig_modulus);

  /* R3 = (2^bits)^3 (mod N) */
  mpz_mul_2exp (t1, modulus->R2, nbits);
  mpz_mod (modulus->R3, t1, modulus->orig_modulus);

  /* multiple = ceil(2^bits / N) * N, i.e. the smallest multiple of N that
     is >= 2^bits */
  mpz_set_ui (t1, 1);
  mpz_mul_2exp (t1, t1, nbits);
  mpz_cdiv_q (t1, t1, modulus->orig_modulus);
  mpz_mul (modulus->multiple, t1, modulus->orig_modulus);
}

/* Release everything allocated by the mpmod_init* functions for modulus.
   orig_modulus, temp1 and temp2 exist for every representation; R2, R3 and
   multiple only for MOD_MODMULN and MOD_REDC; aux_modulus only for
   MOD_REDC. */
void 
mpmod_clear (mpmod_t modulus)
{
  mpz_clear (modulus->orig_modulus);
  mpz_clear (modulus->temp1);
  mpz_clear (modulus->temp2);
  if (modulus->repr == MOD_MODMULN || modulus->repr == MOD_REDC)
    {
      mpz_clear (modulus->R2);
      mpz_clear (modulus->R3);
      mpz_clear (modulus->multiple);
      if (modulus->repr == MOD_REDC)
        mpz_clear (modulus->aux_modulus);
    }
#ifdef HAVE_GWNUM
  /* modulus->Fermat is only set by mpmod_init_BASE2, so it must not be
     read for the other representations (it would be uninitialized). */
  if (modulus->repr == MOD_BASE2 && modulus->Fermat >= GWTHRESHOLD)
    Fgwclear ();
#endif
  
  return;
}

/* Initialize the residue R with room for as many bits as the modulus has. */
void 
mpres_init (mpres_t R, mpmod_t modulus)
{
  /* modulus->bits may not be initialized yet, so the size is taken from
     the modulus itself */
  const size_t modbits = mpz_sizeinbase (modulus->orig_modulus, 2);

  mpz_init2 (R, modbits);
}

/* Make sure R has at least as many limbs allocated as the modulus.
   Only the MOD_MODMULN representation needs this; nothing is done for the
   others.  The allocation is only ever grown: _mpz_realloc sets the
   allocation to exactly the requested size, so calling it unconditionally
   could shrink R (and zero its value if it no longer fits). */
void
mpres_realloc (mpres_t R, mpmod_t modulus)
{
  if (modulus->repr == MOD_MODMULN)
    {
      mp_size_t n = modulus->bits / GMP_NUMB_BITS;

      if (ALLOC(R) < n)
        _mpz_realloc (R, n);
    }
}

/* Helper for mpres_pow: reduce the product held in modulus->temp1 into
   modulus->temp2 with the reduction matching the representation
   (MOD_BASE2, MOD_MODMULN, or REDC otherwise). */
static void
mpres_pow_reduce (mpmod_t modulus)
{
  if (modulus->repr == MOD_BASE2)
    base2mod (modulus->temp2, modulus->temp1, modulus->temp1, modulus);
  else if (modulus->repr == MOD_MODMULN)
    ecm_redc_basecase (modulus->temp2, modulus->temp1, modulus);
  else
    REDC (modulus->temp2, modulus->temp1, modulus->temp2, modulus);
}

/* R <- BASE^EXP mod modulus.
   Assume EXP >= 0.
   For MOD_PLAIN this is mpz_powm; for the other representations a
   left-to-right binary exponentiation is done in modulus->temp2, using
   modulus->temp1 for the products.
 */
void 
mpres_pow (mpres_t R, mpres_t BASE, mpres_t EXP, mpmod_t modulus)
{
  ASSERT_NORMALIZED (BASE);
  if (modulus->repr == MOD_PLAIN)
    {
      mpz_powm (R, BASE, EXP, modulus->orig_modulus);
    }
  else if (modulus->repr == MOD_BASE2 || modulus->repr == MOD_MODMULN ||
           modulus->repr == MOD_REDC)
    {
      const mp_limb_t topbit = ((mp_limb_t) 1) << (GMP_NUMB_BITS - 1);
      unsigned int limbidx;   /* index of the exponent limb being scanned */
      mp_limb_t mask, limb;   /* current bit mask and exponent limb */

      /* case EXP=0 */
      if (mpz_cmp_ui (EXP, 0) == 0)
        {
          mpres_set_ui (R, 1, modulus); /* set result to 1 */
          ASSERT_NORMALIZED (R);
          return;
        }

      /* start at the most significant limb and look for its highest
         set bit */
      limbidx = mpz_size (EXP) - 1;
      limb = mpz_getlimbn (EXP, limbidx);
      mask = topbit;

      while ((mask & limb) == 0)
        {
          mask >>= 1;
          if (mask != 0)
            continue;

          /* no set bits in this limb */
          if (limbidx == 0)
            {
              /* no more limbs -> exp was 0 */
              mpres_set_ui (R, 1, modulus); /* set result to 1 */
              ASSERT_NORMALIZED (R);
              return;
            }
          limbidx--;
          limb = mpz_getlimbn (EXP, limbidx);
          mask = topbit;
        }

      /* here limb is the most significant limb with any set bit, and mask
         selects its highest set bit; that bit accounts for the initial
         value BASE, so scanning continues with the next lower bit */
      mpz_set (modulus->temp2, BASE);
      mask >>= 1;

      for (;;)
        {
          while (mask != 0)
            {
              /* r = r^2 */
              mpz_mul (modulus->temp1, modulus->temp2, modulus->temp2);
              mpres_pow_reduce (modulus);

              if (limb & mask)
                {
                  /* r = r * BASE */
                  mpz_mul (modulus->temp1, modulus->temp2, BASE);
                  mpres_pow_reduce (modulus);
                }
              mask >>= 1;
            }

          /* if we just processed the least significant limb, we are done */
          if (limbidx == 0)
            break;
          limbidx--;
          limb = mpz_getlimbn (EXP, limbidx);
          mask = topbit;
        }

      /* TODO: isn't it possible to use R instead of modulus->temp2 above
         to avoid this copy? */
      mpz_set (R, modulus->temp2);
    } /* if (modulus->repr == MOD_BASE2 || ... ) */
  ASSERT_NORMALIZED (R);
}


/* Returns 1 if S == 0 (mod modulus), 0 otherwise.
   Clobbers modulus->temp1. */

int
mpres_is_zero (mpres_t S, mpmod_t modulus)
{
  mpz_ptr reduced = modulus->temp1;

  /* For all currently implemented representations, a zero residue has zero
     integer representation, so it suffices to reduce S and look at the
     sign of the result */
  mpz_mod (reduced, S, modulus->orig_modulus);
  if (mpz_sgn (reduced) == 0)
    return 1;
  return 0;
}

/* R <- BASE^EXP mod modulus, for a small unsigned BASE.
   For MOD_PLAIN this is mpz_powm; for the other representations a
   left-to-right binary exponentiation is done in modulus->temp2, using
   modulus->temp1 for the products.  Multiplications by BASE are done with
   mpz_mul_ui (or a doubling when BASE == 2) followed by a plain reduction.
   EXP == 0 yields 1, as in mpres_pow. */ 
void 
mpres_ui_pow (mpres_t R, unsigned int BASE, mpres_t EXP, mpmod_t modulus)
{
  if (modulus->repr == MOD_PLAIN)
    {
      mpz_set_ui (modulus->temp1, BASE);
      mpz_powm (R, modulus->temp1, EXP, modulus->orig_modulus);
    }
  else if (modulus->repr == MOD_BASE2 || modulus->repr == MOD_MODMULN ||
           modulus->repr == MOD_REDC)
    {
      unsigned int expidx;
      mp_limb_t bitmask, expbits;

      /* case EXP=0: without this test mpz_size (EXP) - 1 wraps around to
         UINT_MAX and the scan below walks over billions of non-existent
         limbs before it reaches the "exp was 0" exit */
      if (mpz_cmp_ui (EXP, 0) == 0)
        {
          mpres_set_ui (R, 1, modulus); /* set result to 1 */
          ASSERT_NORMALIZED (R);
          return;
        }

      expidx = mpz_size (EXP) -1;           /* point at most significant limb */
      expbits = mpz_getlimbn (EXP, expidx); /* get most significant limb */
      bitmask = (mp_limb_t) 1 << (GMP_NUMB_BITS - 1);

      while ((bitmask & expbits) == 0)
        {
          bitmask >>= 1;
          if (bitmask == 0)                 /* no set bits in this limb */
            {
              if (expidx == 0)              /* no more limbs -> exp was 0 */
                {
                  mpres_set_ui (R, 1, modulus); /* set result to 1 */
		  ASSERT_NORMALIZED (R);
                  return;
                }
              expidx --;
              expbits = mpz_getlimbn (EXP, expidx);
              bitmask = (mp_limb_t) 1 << (GMP_NUMB_BITS - 1);
            }
        }

    /* here the most significant limb with any set bits is in expbits, */
    /* bitmask is set to mask in the msb of expbits */
    
      mpz_set_ui (modulus->temp2, BASE); /* temp2 = BASE */
      if (modulus->repr == MOD_MODMULN || modulus->repr == MOD_REDC)
        {
          /* convert BASE to Montgomery form: temp2 = BASE * 2^bits mod N */
          mpz_mul_2exp (modulus->temp1, modulus->temp2, modulus->bits);
          mpz_mod (modulus->temp2, modulus->temp1, modulus->orig_modulus);
        }
      /* the top set bit is accounted for by the initial value */
      bitmask >>= 1;

      while (1) 
        {
          for ( ; bitmask != 0; bitmask >>= 1) 
            {
              mpz_mul (modulus->temp1, modulus->temp2, modulus->temp2); /* r = r^2 */

              if (modulus->repr == MOD_BASE2)
                base2mod (modulus->temp2 , modulus->temp1, modulus->temp1, modulus);
              else if (modulus->repr == MOD_MODMULN)
                {
                  ecm_redc_basecase (modulus->temp2, modulus->temp1, modulus);
                }
              else
                REDC (modulus->temp2, modulus->temp1, modulus->temp2, modulus);

              if (expbits & bitmask)
                {
                  /* r = r * BASE; multiplying by a plain integer keeps the
                     representation, so a plain reduction is enough */
                  if (BASE == 2)
                    {
                      mpz_mul_2exp (modulus->temp2, modulus->temp2, 1);
                      if (mpz_cmp (modulus->temp2, modulus->orig_modulus) >= 0)
                        mpz_sub (modulus->temp2, modulus->temp2, modulus->orig_modulus);
                    }
                  else
                    {
                      mpz_mul_ui (modulus->temp1, modulus->temp2, BASE);
                      mpz_mod (modulus->temp2, modulus->temp1, modulus->orig_modulus);
                    }
                }
            }
          if (expidx == 0)		/* if we just processed the least */
            break;			/* significant limb, we are done */
          expidx--;
          expbits = mpz_getlimbn (EXP, expidx);
          bitmask = (mp_limb_t) 1 << (GMP_NUMB_BITS - 1);
        }
      mpz_set (R, modulus->temp2); /* TODO: use R instead of modulus->temp2
				      above to avoid this copy? */
    } /* if (modulus->repr == MOD_BASE2 || ... ) */
  ASSERT_NORMALIZED (R);
}

/* R <- S1 * S2 mod modulus.
   The generic path forms the product in modulus->temp1 and reduces it
   with the method matching the representation (base2mod,
   ecm_redc_basecase, REDC, or mpz_mod).  For large Fermat moduli, and
   depending on the build, the product is delegated to gwnum (HAVE_GWNUM)
   or to mpn_mul_fft (HAVE_FFT); those two paths reduce S1 and S2 in place
   beforehand. */
void 
mpres_mul (mpres_t R, mpres_t S1, mpres_t S2, mpmod_t modulus)
{
  ASSERT_NORMALIZED (S1);
  ASSERT_NORMALIZED (S2);

#ifdef HAVE_GWNUM
  /* Use the same threshold as mpmod_init_BASE2 and mpmod_clear, which only
     set up and tear down gwnum when Fermat >= GWTHRESHOLD. */
  if (modulus->repr == MOD_BASE2 && modulus->Fermat >= GWTHRESHOLD)
    {
      base2mod_1 (S1, modulus->temp1, modulus);
      base2mod_1 (S2, modulus->temp1, modulus);
      ASSERT (mpz_sizeinbase (S1, 2) <= (unsigned) abs(modulus->bits));
      ASSERT (mpz_sizeinbase (S2, 2) <= (unsigned) abs(modulus->bits));
      Fgwmul (R, S1, S2);
      return;
    }
#elif defined(HAVE_FFT)
  if (modulus->repr == MOD_BASE2 && modulus->Fermat >= 32768)
    {
      mp_size_t n = modulus->Fermat / __GMP_BITS_PER_MP_LIMB;
      unsigned long k;
      
      /* Only grow R: _mpz_realloc sets the allocation to exactly the
         requested size, so an unconditional call could shrink R and, when
         R is also S1 or S2, zero an operand that no longer fits. */
      if (ALLOC(R) < n + 1)
        _mpz_realloc (R, n + 1);
      k = mpn_fft_best_k (n, S1 == S2);
      ASSERT(mpn_fft_next_size (n, k) == n);
      base2mod_2 (S1, n, modulus->orig_modulus);
      base2mod_2 (S2, n, modulus->orig_modulus);
      mpn_mul_fft (PTR(R), n, PTR(S1), ABSIZ(S1), PTR(S2), ABSIZ(S2), k);
      /* the result has up to n + 1 limbs; its sign is the product of the
         signs of S1 and S2 */
      n ++;
      MPN_NORMALIZE(PTR(R), n);
      SIZ(R) = ((SIZ(S1) ^ SIZ(S2)) >= 0) ? n : -n;
      return;
    }
#endif

  mpz_mul (modulus->temp1, S1, S2);

  switch (modulus->repr)
    {
    case MOD_BASE2:
      base2mod (R, modulus->temp1, modulus->temp1, modulus);
      break;
    case MOD_MODMULN:
      ecm_redc_basecase (R, modulus->temp1, modulus);
      break;
    case MOD_REDC:
      REDC (R, modulus->temp1, modulus->temp2, modulus);
      break;
    default:
      mpz_mod (R, modulus->temp1, modulus->orig_modulus);
      break;
    }
  ASSERT_NORMALIZED (R);
}

/* R <- S * n mod modulus, for a small unsigned n.
   Clobbers modulus->temp1. */
void 
mpres_mul_ui (mpres_t R, mpres_t S, unsigned int n, mpmod_t modulus)
{
  mpz_ptr product = modulus->temp1;

  ASSERT_NORMALIZED (S);

  /* Same for all representations: multiply, then reduce with the
     original modulus */
  mpz_mul_ui (product, S, n);
  mpz_mod (R, product, modulus->orig_modulus);

  ASSERT_NORMALIZED (R);
}

/* R <- S / 2^n mod modulus. Does not need to be fast.
   The value is halved n times; an odd value first gets the (odd) modulus
   added so that the division by 2 is exact. */
void 
mpres_div_2exp (mpres_t R, mpres_t S, unsigned int n, mpmod_t modulus)
{
  mpz_srcptr src = S;   /* S for the first halving, R afterwards */

  ASSERT_NORMALIZED (S);
  if (n == 0)
    {
      mpres_set (R, S, modulus);
      ASSERT_NORMALIZED (R);
      return;
    }

  do
    {
      if (mpz_odd_p (src))
        {
          mpz_add (R, src, modulus->orig_modulus);
          mpz_tdiv_q_2exp (R, R, 1);
        }
      else
        mpz_tdiv_q_2exp (R, src, 1);
      src = R;
    }
  while (--n > 0);

  ASSERT_NORMALIZED (R);
}

/* R <- S + n mod modulus, for a small unsigned n.
   The Montgomery representations clobber modulus->temp1. */
void
mpres_add_ui (mpres_t R, mpres_t S, unsigned int n, mpmod_t modulus)
{
  const int montgomery = (modulus->repr == MOD_MODMULN
                          || modulus->repr == MOD_REDC);

  ASSERT_NORMALIZED (S);
  if (modulus->repr == MOD_PLAIN || modulus->repr == MOD_BASE2)
    {
      /* add directly; one subtraction of the modulus if the sum exceeds it */
      mpz_add_ui (R, S, n);
      if (mpz_cmp (R, modulus->orig_modulus) > 0)
        mpz_sub (R, R, modulus->orig_modulus);
    }
  else if (montgomery)
    {
      mpz_ptr acc = modulus->temp1;

      /* n has to be scaled by 2^bits before it can be added to S */
      mpz_set_ui (acc, n);
      mpz_mul_2exp (acc, acc, modulus->bits);
      mpz_add (acc, acc, S);
      mpz_mod (R, acc, modulus->orig_modulus);
    }
  ASSERT_NORMALIZED (R);
}

/* R <- S1 + S2 mod modulus.
   The sum is not fully reduced.  For MOD_MODMULN and MOD_REDC it is only
   brought back to the limb size of the modulus when it has outgrown it. */
void 
mpres_add (mpres_t R, mpres_t S1, mpres_t S2, mpmod_t modulus)
{
  int montgomery;

  ASSERT_NORMALIZED (S1);
  ASSERT_NORMALIZED (S2);

  mpz_add (R, S1, S2);

  montgomery = (modulus->repr == MOD_MODMULN || modulus->repr == MOD_REDC);
  if (montgomery && ABSIZ(R) > ABSIZ(modulus->orig_modulus))
    {
      /* R has one limb too many: move it toward zero by modulus->multiple
         (a multiple of N set up by the init functions), which leaves the
         residue class unchanged */
      if (SIZ(R) > 0)
        mpz_sub (R, R, modulus->multiple);
      else
        mpz_add (R, R, modulus->multiple);
    }
  ASSERT_NORMALIZED (R);
}

/* R <- S - n mod modulus, for a small unsigned n.
   The Montgomery representations clobber modulus->temp1. */
void
mpres_sub_ui (mpres_t R, mpres_t S, unsigned int n, mpmod_t modulus)
{
  const int montgomery = (modulus->repr == MOD_MODMULN
                          || modulus->repr == MOD_REDC);

  ASSERT_NORMALIZED (S);
  if (modulus->repr == MOD_PLAIN || modulus->repr == MOD_BASE2)
    {
      /* subtract directly; one addition of the modulus if negative */
      mpz_sub_ui (R, S, n);
      if (mpz_sgn (R) < 0)
        mpz_add (R, R, modulus->orig_modulus);
    }
  else if (montgomery)
    {
      mpz_ptr acc = modulus->temp1;

      /* n has to be scaled by 2^bits before it can be subtracted from S */
      mpz_set_ui (acc, n);
      mpz_mul_2exp (acc, acc, modulus->bits);
      mpz_sub (acc, S, acc);
      mpz_mod (R, acc, modulus->orig_modulus);
    }
  ASSERT_NORMALIZED (R);
}

/* R <- S1 - S2 mod modulus.
   The difference is not fully reduced.  For MOD_MODMULN and MOD_REDC it is
   only brought back to the limb size of the modulus when it has outgrown
   it. */
void 
mpres_sub (mpres_t R, mpres_t S1, mpres_t S2, mpmod_t modulus)
{
  int montgomery;

  ASSERT_NORMALIZED (S1);
  ASSERT_NORMALIZED (S2);

  mpz_sub (R, S1, S2);

  montgomery = (modulus->repr == MOD_MODMULN || modulus->repr == MOD_REDC);
  if (montgomery && ABSIZ(R) > ABSIZ(modulus->orig_modulus))
    {
      /* R has one limb too many: move it toward zero by modulus->multiple
         (a multiple of N set up by the init functions), which leaves the
         residue class unchanged */
      if (SIZ(R) > 0)
        mpz_sub (R, R, modulus->multiple);
      else
        mpz_add (R, R, modulus->multiple);
    }
  ASSERT_NORMALIZED (R);
}

/* R <- S mod modulus, converting the integer S into the residue
   representation of modulus.  The Montgomery representations clobber
   modulus->temp1 and modulus->temp2. */
void 
mpres_set_z (mpres_t R, mpz_t S, mpmod_t modulus)
{
  if (modulus->repr == MOD_PLAIN || modulus->repr == MOD_BASE2)
    {
      mpz_mod (R, S, modulus->orig_modulus);
    }
  else if (modulus->repr == MOD_MODMULN || modulus->repr == MOD_REDC)
    {
      /* Reduce S, multiply by R2 = (2^bits)^2 mod N, then run one
         Montgomery reduction, which removes one factor 2^bits */
      mpz_mod (modulus->temp2, S, modulus->orig_modulus);
      mpz_mul (modulus->temp1, modulus->temp2, modulus->R2);
      if (modulus->repr == MOD_MODMULN)
        ecm_redc_basecase (R, modulus->temp1, modulus);
      else
        REDC (R, modulus->temp1, modulus->temp2, modulus);
    }
  ASSERT_NORMALIZED (R);
}

/* R <- integer value of the residue S, converting out of the residue
   representation of modulus.
   R and S must not be modulus->temp1, which is used as scratch space by
   the Montgomery representations.
   With DEBUG defined, an unknown representation prints a message and
   terminates the program. */
void 
mpres_get_z (mpz_t R, mpres_t S, mpmod_t modulus)
{
  ASSERT_NORMALIZED (S);
  if (modulus->repr == MOD_PLAIN || modulus->repr == MOD_BASE2)
    {
      mpz_mod (R, S, modulus->orig_modulus);
    }
  else if (modulus->repr == MOD_MODMULN)
    {
      mp_size_t n = modulus->bits / GMP_NUMB_BITS;

      /* work on a copy since ecm_redc_basecase clobbers its input */
      mpz_set (modulus->temp1, S);
      /* ecm_redc_basecase needs n limbs in R.  Only grow the allocation:
         _mpz_realloc sets it to exactly the requested size, so an
         unconditional call could shrink a larger R. */
      if (ALLOC(R) < n)
        _mpz_realloc (R, n);
      ecm_redc_basecase (R, modulus->temp1, modulus);
    }
  else if (modulus->repr == MOD_REDC)
    {
      REDC (R, S, modulus->temp1, modulus);
    }
#ifdef DEBUG
  else
    {
      fprintf (ECM_STDERR, "mpres_get_z: Unexpected representation %d\n", 
               modulus->repr);
      exit (EXIT_FAILURE);
    }
#endif
}

/* R <- n mod modulus, converting the small unsigned n into the residue
   representation of modulus.  The Montgomery representations clobber
   modulus->temp1. */
void 
mpres_set_ui (mpres_t R, unsigned int n, mpmod_t modulus)
{
  const int montgomery = (modulus->repr == MOD_MODMULN
                          || modulus->repr == MOD_REDC);

  if (modulus->repr == MOD_PLAIN || modulus->repr == MOD_BASE2)
    {
      mpz_set_ui (R, n);
      mpz_mod (R, R, modulus->orig_modulus);
    }
  else if (montgomery)
    {
      mpz_ptr scaled = modulus->temp1;

      /* R = n * 2^bits mod N */
      mpz_set_ui (scaled, n);
      mpz_mul_2exp (scaled, scaled, modulus->bits);
      mpz_mod (R, scaled, modulus->orig_modulus);
    }
  ASSERT_NORMALIZED (R);
}

/* R <- -S mod modulus. Does not need to be efficient.
   Only the sign is flipped, no reduction is performed; modulus is only
   referenced by the normalization assertions. */
void
mpres_neg (mpres_t R, mpres_t S, ATTRIBUTE_UNUSED mpmod_t modulus)
{
  ASSERT_NORMALIZED (S);

  mpz_neg (R, S);

  ASSERT_NORMALIZED (R);
}

/* R <- 1/S mod modulus.
   Returns the result of mpz_invert for MOD_PLAIN and MOD_BASE2.  For the
   Montgomery representations it returns 1 if the inverse exists and 0
   otherwise, in which case R is left untouched.
   The Montgomery representations clobber modulus->temp1 and
   modulus->temp2.
   With DEBUG defined, an unknown representation prints a message and
   terminates the program; otherwise 0 is returned. */
int 
mpres_invert (mpres_t R, mpres_t S, mpmod_t modulus)
{
  ASSERT_NORMALIZED (S);

  if (modulus->repr == MOD_PLAIN || modulus->repr == MOD_BASE2)
    {
      int res = mpz_invert (R, S, modulus->orig_modulus);
      ASSERT_NORMALIZED (R);
      return res;
    }

  if (modulus->repr == MOD_MODMULN || modulus->repr == MOD_REDC)
    {
      if (modulus->repr == MOD_REDC)
        {
          MPZ_NORMALIZED (S);
        }

      if (!mpz_invert (modulus->temp2, S, modulus->orig_modulus))
        return 0;

      /* temp2 is the inverse of S * 2^bits.  Multiplying by
         R3 = (2^bits)^3 and running one Montgomery reduction leaves
         (1/S) * 2^bits, the Montgomery form of the inverse. */
      mpz_mul (modulus->temp1, modulus->temp2, modulus->R3);
      if (modulus->repr == MOD_MODMULN)
        ecm_redc_basecase (R, modulus->temp1, modulus);
      else
        REDC (R, modulus->temp1, modulus->temp2, modulus);
      ASSERT_NORMALIZED (R);
      return 1;
    }

#ifdef DEBUG
  fprintf (ECM_STDERR, "mpres_invert: Unexpected representation %d\n", 
           modulus->repr);
  exit (EXIT_FAILURE);
#endif
  return 0;
}

/* R <- gcd (S, N), where N is the original modulus. */
void 
mpres_gcd (mpz_t R, mpres_t S, mpmod_t modulus)
{
  ASSERT_NORMALIZED (S);

  /* No conversion out of Montgomery form is needed: in MODMULN and REDC
     form, M(x) = x*R with gcd(R, modulus) = 1, therefore
     gcd(M(x), modulus) = gcd(x, modulus). */
  mpz_gcd (R, S, modulus->orig_modulus);
}

/* Write the integer value of the residue S to fd in the given base.
   Clobbers modulus->temp2, which receives the converted value. */
void 
mpres_out_str (FILE *fd, unsigned int base, mpres_t S, mpmod_t modulus)
{
  mpz_ptr plain = modulus->temp2;

  mpres_get_z (plain, S, modulus);
  mpz_out_str (fd, base, plain);
}
