import java.io.*;
import java.util.*;

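/**
 * Cluster-weighted prediction of a discrete State from a continuous
 * FeatureVector.  A fixed number of clusters is spread over feature space;
 * each cluster carries a weight, a diagonal Gaussian over the feature
 * dimensions, and a table of state probabilities.  predict() sums the joint
 * density density(s, x, c) over all clusters and returns the most probable
 * state, while updateClusters() re-estimates the cluster parameters from the
 * stored training pairs with an EM-style pass.
 *
 * A minimal usage sketch, assuming FeatureVector(double[]) and State(String)
 * constructors exist elsewhere in this codebase (adjust to the real ones),
 * given double[][] data, String[] labels, and double[] query:
 * <pre>
 *   ClusterWeightedStatePredictor p = new ClusterWeightedStatePredictor(8, 2);
 *   for (int i = 0; i &lt; data.length; i++)
 *     p.train(new FeatureVector(data[i]), new State(labels[i]));
 *   p.doneTraining();
 *   for (int pass = 0; pass &lt; 30; pass++)
 *     p.updateClusters(null);        // null: run without a ClusterViewer2D
 *   State guess = p.predict(new FeatureVector(query));
 * </pre>
 */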
public class ClusterWeightedStatePredictor implements Serializable {
  Vector states, feature, clusters;     // training states, training feature vectors, and the clusters
  int dim = 2;                          // number of dimensions of the feature vectors
  double varmin = 1;                    // floor added to every re-estimated variance
  double varmax = 5000;                 // intended ceiling on variances (currently unused)
  Hashtable trained;                    // maps each training feature vector to its state
  boolean clusterDeath, stuckClusters;  // prune weak clusters / pin one state per cluster

  ClusterWeightedStatePredictor(int ncl, int ndim) {
    Random rand = new Random();

    dim = ndim;
    trained = new Hashtable();
    clusters = new Vector();
    states = new Vector();
    feature = new Vector();
    // Seed ncl clusters with random means in [0, 100) per dimension and a
    // broad initial variance.
    for(int ct = 0; ct < ncl; ct++){
      double m1[] = new double[dim];
      double v1[] = new double[dim];
      for(int d = 0; d < dim; d++){
        m1[d] = rand.nextDouble()*100;
        v1[d] = 60;
      }
      clusters.addElement(new Cluster(m1, v1, 1.0));
    }
  }

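  /**
   * Returns the most probable state for feature vector f, i.e. the state s
   * maximizing the total joint density summed over all clusters.  All
   * clusters share the same state keys, so cluster 0 is used to enumerate
   * them.  Returns null if no state scores above zero.
   */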
  public State predict(FeatureVector f){
    double sum, summax = 0;
    State outputState = null;

    // Every cluster shares the same state keys, so cluster 0 is used to
    // enumerate the candidate states.
    Cluster c = (Cluster) clusters.elementAt(0);
    for(Enumeration e = c.stateDensities.keys(); e.hasMoreElements(); ){
      State mys = (State) e.nextElement();
      // Joint density of this state and the feature vector, summed over all clusters.
      sum = 0;
      for(int ct = 0; ct < clusters.size(); ct++){
        sum += density(mys, f, (Cluster) clusters.elementAt(ct));
      }
      System.out.println("Probability for "+mys+": "+sum);
      if(sum > summax){
        summax = sum;
        outputState = mys;
      }
    }
    return outputState;
  }

  public void setMinVar(double mv) {
    varmin = mv;
  }

  public void featureDimensions(int d){
    dim = d;
  }

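  /**
   * Stores one training pair (feature vector, state).  The pair is remembered
   * for the later updateClusters() passes, and the state is registered with
   * every cluster so that it appears in each cluster's state-density table.
   */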
  public void train(FeatureVector f, State s){
    trained.put(f, s);
    feature.addElement(f);
    states.addElement(s);
    // Occasionally (about 3% of the time) re-seed a randomly chosen cluster's
    // mean onto this feature vector, which keeps clusters near the data.
    Random rnd = new Random();
    if(rnd.nextFloat() > .97) {
      Cluster thc = (Cluster) clusters.elementAt((int) (clusters.size()*rnd.nextFloat()));
      System.arraycopy(f.vec, 0, thc.mean, 0, f.vec.length);
    }
    // Register this state with every cluster; doneTraining() normalizes the
    // state densities afterwards.
    for(int ct = 0; ct < clusters.size(); ct++){
      Cluster c = (Cluster) clusters.elementAt(ct);
      c.stateDensities.put(s, new Double(1));
    }
  }

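  /**
   * Call once after all train() calls: gives every cluster a uniform state
   * distribution and a uniform weight of 1/(number of clusters) so that the
   * updateClusters() iterations start from a neutral configuration.
   */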
  public void doneTraining(){
    // Normalize: give every cluster a uniform state distribution and a
    // uniform weight before the updateClusters() iterations begin.
    for(int ct = 0; ct < clusters.size(); ct++){
      Cluster c = (Cluster) clusters.elementAt(ct);
      for(Enumeration e = c.stateDensities.keys(); e.hasMoreElements(); ){
        State mys = (State) e.nextElement();
        c.stateDensities.put(mys,
            new Double(((Double) c.stateDensities.get(mys)).doubleValue()/c.stateDensities.size()));
      }
      c.weight = 1.0/clusters.size();
    }
  }

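  /**
   * Prints each cluster's weight, mean, variance, and per-state probabilities
   * to standard output, mainly for debugging.
   */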
  public void showClusters() {
    for(int ct = 0; ct < clusters.size(); ct++){
      Cluster c = (Cluster) clusters.elementAt(ct);
      System.out.println("Cluster #"+ct);
      System.out.println("Weight "+c.weight);
      System.out.print("Mean ");
      for(int ctd = 0; ctd < dim; ctd++){
        System.out.print(c.mean[ctd]+" ");
      }
      System.out.println();
      System.out.print("Variance ");
      for(int ctd = 0; ctd < dim; ctd++){
        System.out.print(c.variance[ctd]+" ");
      }
      System.out.println();
      for(Enumeration e = c.stateDensities.keys(); e.hasMoreElements(); ){
        State mys = (State) e.nextElement();
        System.out.println("State probability for state "+mys.name+": "+((Double) c.stateDensities.get(mys)).doubleValue());
      }
      System.out.println();
    }
  }

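  /**
   * One EM-style pass over the stored training data.  For each cluster, the
   * posterior responsibility of the cluster for every training example is
   * accumulated, then the cluster's weight, mean, variance, and state
   * densities are re-estimated.  Clusters whose weight collapses to zero (or,
   * when clusterDeath is set, below 1/(50*number of clusters)) are removed.
   * The optional ClusterViewer2D is repainted during the pass so the clusters
   * can be watched as they move; pass null to run without a viewer.
   */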
  public void updateClusters(ClusterViewer2D cv) {
    int n, m, clust, ct;
    double mres[], vres[];
    Hashtable stateTop;
    double bottom = 0.0;
    double res = 0.0;
    double posterior;

    for(clust = 0; clust < clusters.size(); clust++){
      // Repaint the (optional) viewer so the clusters can be watched moving.
      if(cv != null)
        cv.repaint();
      try { Thread.sleep(20); } catch (Exception e) { }
      Cluster newc = new Cluster(new double[dim], new double[dim], 1.0);
      stateTop = new Hashtable();    // per-state density accumulators for this cluster
      double stateBottom = 0.0;      // normalizer for the new state densities
      mres = new double[dim];        // posterior-weighted sum of feature vectors
      vres = new double[dim];        // posterior-weighted squared deviations
      res = 0.0;                     // sum of posteriors over all examples
      // E-step: accumulate responsibility-weighted statistics over every
      // stored training example.
      for(n = 0; n < feature.size(); n++){
        if(cv != null)
          cv.repaint();

        State sn = (State) states.elementAt(n);
        FeatureVector fn = (FeatureVector) feature.elementAt(n);
        Cluster oldc = (Cluster) clusters.elementAt(clust);

        // Total joint density of this example over all clusters.
        bottom = 0.0;
        for(m = 0; m < clusters.size(); m++){
          bottom += density(sn, fn, (Cluster) clusters.elementAt(m));
        }

        // Posterior responsibility of the current cluster for this example.
        double top = density(sn, fn, oldc);
        posterior = top/bottom;
        res += posterior;

        for(ct = 0; ct < dim; ct++){
          double tmp;
          mres[ct] += fn.vec[ct]*posterior;
          tmp = fn.vec[ct] - oldc.mean[ct];
          vres[ct] += posterior*tmp*tmp;
        }

        // Accumulate the unnormalized state densities for this cluster.
        Double prev = (Double) stateTop.get(sn);
        stateTop.put(sn, new Double(prev == null ? top : prev.doubleValue() + top));
        stateBottom += top;
      }
      // M-step: the new cluster weight is the average responsibility.
      newc.weight = res/feature.size();
      if(cv != null)
        cv.repaint();

      // New mean: responsibility-weighted average of the feature vectors,
      // with a small momentum term pushing it slightly past the old mean.
      for(ct = 0; ct < dim; ct++){
        newc.mean[ct] = mres[ct]/(newc.weight * feature.size());
        newc.mean[ct] += .05*(newc.mean[ct] - ((Cluster) clusters.elementAt(clust)).mean[ct]);
      }
      // New variance: varmin plus the responsibility-weighted spread.  The
      // stored value is a standard-deviation-like quantity; it is squared
      // again inside densityOfXGivenC().
      for(ct = 0; ct < dim; ct++){
        newc.variance[ct] = varmin + Math.sqrt(vres[ct]/(feature.size() * ((Cluster) clusters.elementAt(clust)).weight));
        // A clamp to varmax was tried here and is currently disabled:
        // if(newc.variance[ct] > varmax) newc.variance[ct] = varmax;
      }

      // New state densities: normalized accumulated densities, unless
      // stuckClusters is set, in which case exactly one state is pinned to
      // each cluster.
      int stct = 0;
      for(Enumeration e = stateTop.keys(); e.hasMoreElements(); ){
        State mys = (State) e.nextElement();
        if(stuckClusters){
          if(clust%stateTop.size() == stct){
            newc.stateDensities.put(mys, new Double(1.0));
          }
          else{
            newc.stateDensities.put(mys, new Double(0.0));
          }
          stct++;
        }
        else
          newc.stateDensities.put(mys, new Double(((Double) stateTop.get(mys)).doubleValue()/stateBottom));
      }

      System.out.println("Is "+newc.weight+" less than "+(1.0/(50.0*clusters.size()))+" ?");
      if(newc.weight == 0.0){
	System.out.println("Cluster "+clust+" is dead anyway");
	clusters.removeElementAt(clust);
	clust--;
      }
      else if(clusterDeath && (newc.weight < (1.0/(50.0*clusters.size())))){
	System.out.println("Cluster "+clust+" is a waste of time...");
	clusters.removeElementAt(clust);
	clust--;
      }
      else{
	clusters.setElementAt(newc, clust);
      }
    }
  }

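  /**
   * Joint density of state s, feature vector x, and cluster c:
   * density(s, x, c) = p(s | x, c) * p(x | c) * p(c).
   */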
  public double density(State s, FeatureVector x, Cluster c){
    double psxc, pxc;
    psxc = densityOfSGivenXandC(s, x, c);
    pxc = densityOfXGivenC(x, c);
    return (psxc*pxc*c.weight);
  }

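  /**
   * State probability stored in the cluster's table.  Despite the name, the
   * value depends only on the cluster, not on the feature vector x.
   */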
  public double densityOfSGivenXandC(State s, FeatureVector x, Cluster c){
    return ((Double) (c.stateDensities).get(s)).doubleValue();
  }

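  /**
   * Density of the feature vector under cluster c: a diagonal Gaussian,
   * i.e. a product of one-dimensional normals over the feature dimensions.
   */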
  public double densityOfXGivenC(FeatureVector x, Cluster c){
    // Product of independent one-dimensional Gaussians.  c.variance holds a
    // standard-deviation-like value, so it is squared here.
    double res = 1.0;
    for(int ct = 0; ct < dim; ct++){
      double vmd = c.variance[ct];
      double mmd = c.mean[ct];
      double xd = x.vec[ct];
      double prefactor = 1/Math.sqrt(2*Math.PI*vmd*vmd);
      res *= prefactor*Math.exp(-((xd-mmd)*(xd-mmd))/(2*vmd*vmd));
    }
    return res;
  }

}
