/* $Id: serial.C,v 1.28 2000/02/19 21:26:13 dm Exp $ */

/*
 *
 * Copyright (C) 1998 David Mazieres (dm@uun.org)
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 2, or (at
 * your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
 * USA
 *
 */

#include "crypt.h"
#include "blowfish.h"
#include "password.h"
#include "rxx.h"
#include "srp.h"

// XXX - need explicit instantiation for KCC
// Explicitly instantiate strbuf_cat for strbufcatobj<bigint, int>.  This is
// presumably the object produced by cat (skp->n, 16) in import_rabin_priv
// below (bigint printed in a given base); some compilers (KCC) did not
// instantiate it implicitly.
template const strbuf &strbuf_cat (const strbuf &,
				   const strbufcatobj<bigint, int> &);

/*
 * Serialize the secret half of a Rabin key as an armored string.
 *
 * The blob is XDR(p, q, H) where H = SHA-1 over XDR(SK_RABIN_EKSBF, n),
 * letting the importer check that the factors match the public modulus.
 * When EKSB is non-NULL the blob is padded to a multiple of 8 bytes and
 * CBC-enciphered with it; when EKSB is NULL it is stored in the clear.
 *
 * Returns NULL if any XDR encoding step fails.
 */
str
export_rabin_encrypt_sec (const rabin_priv &sk, const eksblowfish *eksb)
{
  // The second constructor argument presumably requests that the buffer
  // be scrubbed, since it will hold p and q -- TODO confirm.
  xdrsuio enc (XDR_ENCODE, true);
  u_char pubhash[sha1::hashsize];

  // First encode only the public part, and hash it.
  const bool pub_ok = xdr_putint (&enc, SK_RABIN_EKSBF)
    && xdr_putbigint (&enc, sk.n);
  if (!pub_ok)
    return NULL;
  sha1_hashv (pubhash, enc.iov (), enc.iovcnt ());

  // Throw the public part away and encode the secret factors plus hash.
  enc.uio ()->clear ();
  if (!xdr_putbigint (&enc, sk.p))
    return NULL;
  if (!xdr_putbigint (&enc, sk.q))
    return NULL;
  if (!xdr_putpadbytes (&enc, pubhash, sizeof (pubhash)))
    return NULL;
  // The cipher works on 8-byte units: add one zero word whenever the
  // encoded length is 4 mod 8.
  if (eksb && (enc.uio ()->resid () & 4) && !xdr_putint (&enc, 0))
    return NULL;

  wmstr blob (enc.iov (), enc.iovcnt ());
  if (eksb) {
    cbc64iv cipher (*eksb);
    cipher.encipher_bytes (blob, blob.len ());
  }
  return armor64 (blob);
}

/*
 * Inverse of export_rabin_encrypt_sec: dearmor RAW, decipher it with EKSB
 * (if non-NULL), decode p, q and the stored hash, rebuild the key, and
 * check the stored hash against SHA-1 of XDR(SK_RABIN_EKSBF, n).
 *
 * Returns the reconstructed private key, or NULL if RAW is NULL or not
 * valid armor, has a bad length, fails to decode, has out-of-range
 * factors, or fails the hash check (e.g. a wrong password).
 */
ptr<rabin_priv>
import_rabin_decrypt_sec (const str raw, const eksblowfish *eksb)
{
  if (!raw)
    return NULL;
  str seckey = dearmor64 (raw);
  // dearmor64 can fail on malformed input; never call len () on NULL.
  if (!seckey)
    return NULL;
  // Encrypted blobs are whole 8-byte cipher units; plain ones are whole
  // 4-byte XDR words.
  if (seckey.len () & (eksb ? 7 : 3))
    return NULL;

  // Copy into a writable buffer so it can be deciphered in place.
  wmstr m (seckey.len ());
  memcpy (m, seckey, m.len ());
  if (eksb) {
    cbc64iv iv (*eksb);
    iv.decipher_bytes (m, m.len ());
  }
  seckey = m;

  bigint p, q;
  u_char hv1[sha1::hashsize];
  u_char hv2[sha1::hashsize];

  {
    xdrmem x1 (seckey, seckey.len ());
    if (!xdr_getbigint (&x1, p) || !xdr_getbigint (&x1, q)
	|| !xdr_getpadbytes (&x1, hv1, sizeof (hv1))
	|| p >= q || p <= 1 || q <= 1)
      return NULL;
  }

  ref<rabin_priv> skp = new refcounted<rabin_priv> (p, q);

  // Recompute the public-part hash exactly as the exporter did.
  {
    xdrsuio x2;
    if (!xdr_putint (&x2, SK_RABIN_EKSBF)
	|| !xdr_putbigint (&x2, skp->n))
      return NULL;
    sha1_hashv (hv2, x2.iov (), x2.iovcnt ());
  }

  if (memcmp (hv1, hv2, sizeof (hv1)))
    return NULL;

  return skp;
}

/*
 * Export a Rabin private key as a single text line:
 *
 *   SK<SK_RABIN_EKSBF>,<salt>,<armored secret>,0x<hex n>,<comment>
 *
 * If PWD is non-NULL, a fresh salt with cost ROUNDS is generated and the
 * secret part is encrypted under a key derived from PWD and that salt;
 * otherwise the salt field is empty and the secret part is unencrypted.
 * A NULL COMMENT yields an empty comment field.
 *
 * Returns NULL if salt generation or secret-key encoding fails.
 */
str
export_rabin_priv (const rabin_priv &sk, str pwd, str comment, u_int rounds)
{
  str salt = "";
  str seckey;

  if (pwd) {
    salt = pw_gensalt (rounds);
    if (!salt)
      return NULL;

    eksblowfish eksb;
    pw_crypt (pwd, salt, SK_RABIN_SALTBITS, &eksb);
    seckey = export_rabin_encrypt_sec (sk, &eksb);
  }
  else
    seckey = export_rabin_encrypt_sec (sk, NULL);

  // export_rabin_encrypt_sec returns NULL on encoding failure; propagate
  // that instead of formatting a NULL str into the output.
  if (!seckey)
    return NULL;

  str c = "";
  if (comment)
    c = comment;

  return strbuf ("SK%d,", SK_RABIN_EKSBF)
    << salt << "," << seckey << ",0x" << sk.n.getstr (16)
    << "," << c;
}

// One run of base-64 armor characters with up to two '=' pad characters.
// (The regex does not enforce a length that is a multiple of 4.)
#define A64STR "[A-Za-z0-9+/]+={0,2}"
// Parser for the line written by export_rabin_priv.  Capture groups:
//   [1] optional salt, "<digits>$<armor>$" (absent => key is unencrypted)
//   [2] armored secret-key blob
//   [3] public modulus n as "0x<lowercase hex>"
//   [4] free-form comment (rest of the line)
// NOTE(review): the tag is hard-coded as "SK1" while the exporter writes
// "SK%d" with SK_RABIN_EKSBF -- presumably SK_RABIN_EKSBF == 1; confirm.
const rxx rabin_import_format ("^SK1,(\\d+\\$" A64STR "\\$)?,"
			       "(" A64STR "),(0x[0-9a-f]+),(.*)$", "");

/*
 * Parse a line produced by export_rabin_priv and recover the private key.
 *
 * RAW:      the exported key line.
 * PWD:      password; required when the line carries a salt, and ignored
 *           otherwise.
 * COMMENTP: if non-NULL, receives the comment field.  It is set as soon as
 *           the line parses, even if decryption later fails.
 *
 * Returns the key, or NULL if the line does not parse, a password is
 * needed but missing, the secret part fails to decrypt or verify, or the
 * recovered modulus disagrees with the public one in the line.
 */
ptr<rabin_priv>
import_rabin_priv (str raw, str pwd, str *commentp)
{
  rxx r (rabin_import_format);
  if (!r.search (raw))
    return NULL;

  str salt = r[1];
  str seckey = r[2];
  bigint n (r[3]);
  if (commentp)
    *commentp = r[4];

  if (salt && !pwd)
    return NULL;

  ptr<rabin_priv> skp;

  if (salt) {
    eksblowfish eksb;
    pw_crypt (pwd, salt, SK_RABIN_SALTBITS, &eksb);
    skp = import_rabin_decrypt_sec (seckey, &eksb);
  }
  else
    skp = import_rabin_decrypt_sec (seckey, NULL);

  // Decoding can fail on either path (corrupt blob, bad hash, wrong
  // password); the unencrypted path previously fell through and
  // dereferenced a NULL key below.
  if (!skp)
    return NULL;

  if (n != skp->n) {
    warn << "Error:  Public and private keys do not match!\n";
    warn << "(Public key should be 0x" << cat (skp->n, 16) << ")\n";
    return NULL;
  }

  return skp;
}

/*
 * Extract only the public key (modulus n) from an exported key line.
 * Returns NULL if ASC does not match the export format.
 */
ptr<rabin_pub>
import_rabin_pub (str asc)
{
  rxx fmt (rabin_import_format);
  if (!fmt.search (asc))
    return NULL;
  return new refcounted<rabin_pub> (bigint (fmt[3]));
}

/*
 * Report whether the exported key line ASC is password-protected.
 *
 * Returns true only if ASC parses, carries a salt field, and that salt is
 * accepted by pw_dearmorsalt (which also stores the salt's cost through
 * COSTP).  Returns false otherwise.
 */
bool
import_rabin_priv_need_pwd (str asc, u_int *costp)
{
  rxx fmt (rabin_import_format);
  if (!fmt.search (asc))
    return false;

  str salt = fmt[1];
  // No salt field: the secret part was exported without a password.
  if (!salt)
    return false;

  if (pw_dearmorsalt (costp, NULL, NULL, salt))
    return true;
  return false;
}

/*
 * Continuation for import_rabin_priv_askpwd: once the password PWD has
 * been read, decode ASC with it and pass the result (NULL on failure) to
 * CB.  COMMENTP is forwarded to import_rabin_priv unchanged.
 */
static void
import_rabin_priv_askpwd_cb (str asc,
			     callback<void, ptr<rabin_priv> >::ref cb,
			     str *commentp, str pwd)
{
  ptr<rabin_priv> key = import_rabin_priv (asc, pwd, commentp);
  (*cb) (key);
}

/*
 * Import a private key, prompting for a password only if the key needs one.
 *
 * If ASC is not password-protected, CB is invoked immediately with the
 * decoded key (or NULL) and true is returned.  Otherwise getkbdpwd is asked
 * to prompt with PWDPROMPT; CB runs later from its callback, and this
 * function returns whatever getkbdpwd returns.
 *
 * COMMENTP is captured by the deferred callback and written when it runs,
 * so in the prompting case it must remain valid until CB fires.
 */
bool
import_rabin_priv_askpwd (str asc, str pwdprompt,
			  callback<void, ptr<rabin_priv> >::ref cb,
			  str *commentp)
{
  if (!import_rabin_priv_need_pwd (asc)) {
    (*cb) (import_rabin_priv (asc, NULL, commentp));
    return true;
  }
  return getkbdpwd (pwdprompt, &rnd_input,
		    wrap (import_rabin_priv_askpwd_cb, asc, cb, commentp));
}

// Parser for the line written by export_srp_params:
//   "N=0x<lowercase hex>,g=0x<lowercase hex>"
// Capture groups: [1] N, [2] g (both including the "0x" prefix).
const rxx srp_import_format ("^N=(0x[0-9a-f]+),g=(0x[0-9a-f]+)$");

/*
 * Parse SRP group parameters written by export_srp_params.
 * On success stores the modulus in *NP and the generator in *GP and returns
 * true.  Returns false, leaving both outputs untouched, if RAW is NULL or
 * malformed.
 */
bool
import_srp_params (str raw, bigint *Np, bigint *gp)
{
  rxx fmt (srp_import_format);
  const bool ok = raw && fmt.search (raw);
  if (!ok)
    return false;

  *Np = fmt[1];
  *gp = fmt[2];
  return true;
}

/*
 * Format SRP group parameters as "N=0x<hex>,g=0x<hex>" (lowercase hex),
 * the form accepted by import_srp_params.
 */
str
export_srp_params (const bigint &N, const bigint &g)
{
  strbuf out ("N=0x");
  out << N.getstr (16);
  out << ",g=0x" << g.getstr (16);
  return out;
}
