package mkgray.security.dh;

import Acme.Crypto.*;
import java.math.*;
import java.nio.charset.StandardCharsets;
import java.security.*;

/**
 * A two-party channel that agrees on a shared secret via Diffie-Hellman key exchange and then
 * encrypts/decrypts data with a caller-supplied {@link Acme.Crypto.BlockCipher} keyed from that
 * secret.
 *
 * <p>Typical use: call {@link #generateKeyPair(int)}, send {@link #publicKey} to the peer, pass the
 * peer's public key to {@link #generateSharedKey(BigInteger)}, then install a cipher with
 * {@link #setCipher(Acme.Crypto.BlockCipher)} before encrypting or decrypting.
 *
 * <p>NOTE(review): this is not a secure channel by modern standards:
 * <ul>
 *   <li>{@link #mod} is only about 256 bits long; deployed finite-field DH groups are 2048+ bits.
 *   <li>The cipher key is the raw bytes of the shared secret with no key-derivation function.
 *   <li>Blocks are encrypted independently (ECB), so equal plaintext blocks give equal ciphertext
 *       blocks, and nothing authenticates the ciphertext.
 * </ul>
 * These are left as-is because both peers must agree on them; changing them is a protocol change.
 *
 * <p>Instances hold mutable state with no synchronization; confine each instance to one thread or
 * synchronize externally.
 */
public class DiffieHellmanChannel {
  /** Diffie-Hellman modulus shared by all instances (about 256 bits; see class note). */
  public static BigInteger mod = new BigInteger("60532748047142398241743914746781539761521964315763078807935134362794697295107");

  /** Diffie-Hellman base used to derive public keys. */
  public static BigInteger generator = new BigInteger("2");

  /**
   * Block size in bytes assumed by the padding and block loops.
   * NOTE(review): assumes the installed cipher uses 8-byte blocks (e.g. DES/Blowfish) — confirm.
   */
  private static final int BLOCK_SIZE = 8;

  private static final BigInteger TWO = BigInteger.valueOf(2);

  /** Randomness for private keys; reused instead of re-seeding a new instance per key pair. */
  private static final SecureRandom RANDOM = new SecureRandom();

  /** This side's secret exponent; set by {@link #generateKeyPair(int)}. */
  public BigInteger privateKey;

  /** {@code generator^privateKey mod mod}; send this to the peer. */
  public BigInteger publicKey;

  /** Shared secret computed by {@link #generateSharedKey(BigInteger)}. */
  public BigInteger sharedKey;

  /** Block cipher used by the encrypt/decrypt methods; set by {@link #setCipher}. */
  public Acme.Crypto.BlockCipher cipher;

  /**
   * Generates a fresh random private key and the matching public key.
   *
   * @param size number of random bits in the private key; must be at least 2
   * @throws IllegalArgumentException if {@code size} is less than 2
   */
  public void generateKeyPair(int size) {
    if (size < 2) {
      throw new IllegalArgumentException("size must be at least 2 bits, got " + size);
    }
    BigInteger candidate;
    do {
      // Private key 0 yields public key 1 and private key 1 yields public key g,
      // both of which expose the secret; draw again until we get something >= 2.
      candidate = new BigInteger(size, RANDOM);
    } while (candidate.compareTo(TWO) < 0);
    privateKey = candidate;
    publicKey = generator.modPow(privateKey, mod);
  }

  /**
   * Computes the shared secret from the peer's public key and this side's private key.
   *
   * @param otherPublicKey the peer's public key; must lie in {@code [2, mod - 2]}
   * @throws NullPointerException if {@code otherPublicKey} is null
   * @throws IllegalStateException if {@link #generateKeyPair(int)} has not been called
   * @throws IllegalArgumentException if {@code otherPublicKey} is outside {@code [2, mod - 2]}
   */
  public void generateSharedKey(BigInteger otherPublicKey) {
    if (otherPublicKey == null) {
      throw new NullPointerException("otherPublicKey");
    }
    if (privateKey == null) {
      throw new IllegalStateException("generateKeyPair must be called before generateSharedKey");
    }
    // 0, 1 and mod-1 force the shared secret to 0, 1 or +/-1 regardless of our private key,
    // and values outside [0, mod) are not canonical encodings; reject all of them.
    BigInteger upper = mod.subtract(TWO);
    if (otherPublicKey.compareTo(TWO) < 0 || otherPublicKey.compareTo(upper) > 0) {
      throw new IllegalArgumentException("peer public key is outside [2, mod - 2]");
    }
    sharedKey = otherPublicKey.modPow(privateKey, mod);
  }

  /**
   * Installs the block cipher used for encryption and keys it with the shared secret.
   *
   * <p>The key bytes are {@code sharedKey.toByteArray()} (two's-complement, big-endian), so their
   * length varies with the secret. If no shared key exists yet, the cipher is stored unkeyed and a
   * warning is printed to standard output; compute the shared key and call this method again
   * before encrypting.
   *
   * @param c the cipher to install
   */
  public void setCipher(Acme.Crypto.BlockCipher c) {
    cipher = c;
    if (sharedKey == null) {
      System.out.println("Warning!!! No Shared Key.");
    } else {
      cipher.setKey(sharedKey.toByteArray());
    }
  }

  /**
   * Encrypts a text string.
   *
   * <p>The text is encoded as UTF-8. The ciphertext bytes are returned as an ISO-8859-1 string,
   * which maps each byte to exactly one char (U+0000..U+00FF). That makes the string a lossless
   * carrier for arbitrary bytes, unlike the platform default charset.
   *
   * @param s text to encrypt
   * @return ciphertext, one char per byte
   * @throws IllegalStateException if no cipher has been installed
   */
  public String encrypt(String s) {
    byte[] cipherText = encrypt(s.getBytes(StandardCharsets.UTF_8));
    return new String(cipherText, StandardCharsets.ISO_8859_1);
  }

  /**
   * Decrypts a string produced by {@link #encrypt(String)}.
   *
   * <p>Trailing NUL bytes, i.e. the zero padding added by {@link #encrypt(byte[])}, are removed
   * before UTF-8 decoding. A plaintext that itself ended in NUL characters loses them.
   *
   * @param s ciphertext string whose chars are all in U+0000..U+00FF
   * @return the decrypted text
   * @throws IllegalStateException if no cipher has been installed
   * @throws IllegalArgumentException if the ciphertext length is not a multiple of 8
   */
  public String decrypt(String s) {
    byte[] plainText = decrypt(s.getBytes(StandardCharsets.ISO_8859_1));
    int end = plainText.length;
    while (end > 0 && plainText[end - 1] == 0) {
      end--;
    }
    return new String(plainText, 0, end, StandardCharsets.UTF_8);
  }

  /**
   * Encrypts {@code plainText} one block at a time with the installed cipher.
   *
   * <p>The input is zero-padded up to the next multiple of 8 bytes. At least one padding byte is
   * always added, so an input whose length is already a multiple of 8 gains a full zero block.
   * Each block is encrypted independently (ECB).
   *
   * @param plainText bytes to encrypt; may be empty
   * @return ciphertext of length {@code (plainText.length / 8 + 1) * 8}
   * @throws IllegalStateException if no cipher has been installed
   */
  public byte[] encrypt(byte[] plainText) {
    Acme.Crypto.BlockCipher c = requireCipher();
    int fullBlockBytes = plainText.length - plainText.length % BLOCK_SIZE;
    byte[] cipherText = new byte[fullBlockBytes + BLOCK_SIZE];

    for (int off = 0; off < fullBlockBytes; off += BLOCK_SIZE) {
      c.encrypt(plainText, off, cipherText, off);
    }

    // Final block: leftover bytes followed by zeros (a new Java array is zero-filled).
    byte[] last = new byte[BLOCK_SIZE];
    System.arraycopy(plainText, fullBlockBytes, last, 0, plainText.length - fullBlockBytes);
    c.encrypt(last, 0, cipherText, fullBlockBytes);
    return cipherText;
  }

  /**
   * Decrypts {@code cipherText} one block at a time with the installed cipher.
   *
   * <p>The zero padding added by {@link #encrypt(byte[])} is NOT removed: the result has the same
   * length as the input. Zero padding cannot be told apart from trailing zero data bytes.
   *
   * @param cipherText ciphertext whose length is a multiple of 8
   * @return the decrypted bytes, including padding
   * @throws IllegalStateException if no cipher has been installed
   * @throws IllegalArgumentException if the length is not a multiple of 8
   */
  public byte[] decrypt(byte[] cipherText) {
    Acme.Crypto.BlockCipher c = requireCipher();
    if (cipherText.length % BLOCK_SIZE != 0) {
      throw new IllegalArgumentException(
          "ciphertext length " + cipherText.length + " is not a multiple of " + BLOCK_SIZE);
    }
    byte[] plainText = new byte[cipherText.length];
    for (int off = 0; off < cipherText.length; off += BLOCK_SIZE) {
      c.decrypt(cipherText, off, plainText, off);
    }
    return plainText;
  }

  /** Returns the installed cipher, failing fast if {@link #setCipher} was never called. */
  private Acme.Crypto.BlockCipher requireCipher() {
    if (cipher == null) {
      throw new IllegalStateException("no cipher installed; call setCipher first");
    }
    return cipher;
  }
}

