package IR2;

import java.util.Enumeration;
import java.util.Hashtable;
import java.util.Vector;

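/*
  This class implements the standard reaching definitions algorithm.
  The goal is to build, for each BasicBlock, the set of definitions
  which reach the beginning of the block (in) and the end of the block
  (out). A definition is a DelocalizedInstruction whose destination is a
  local scalar variable, so these sets contain DelocalizedInstructions
  (not DelocalizedLValues). optimize() additionally builds one DUChain per
  definition; the chains are returned by get_chains().
*/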

public class ReachingDefinitions extends FixedPointAlgorithm {
  // Every definition in the flow graph: each DelocalizedInstruction whose
  // destination proxies to a LocalScalarVarDescriptor. Built by
  // set_initial_conditions().
  private Set all;
  // Definition (DelocalizedInstruction) -> its DUChain. Filled by optimize().
  private Hashtable chain_map;
  // The DUChains of chain_map, copied into a Vector by vectorize_chains().
  private Vector resulting_chains;
  // When true, IN/OUT sets, block numbers and chains are printed to System.out.
  private boolean debug=false; 

  /* Constructor */
  public ReachingDefinitions(BasicBlock entry, BasicBlock exit) {
    super(entry, exit);
  }

  /* Implementation of abstract methods from FPA */

  // Reaching definitions flow forward (BACKWARD or FORWARD).
  protected int direction() {
    return FORWARD;
  }

  // A definition reaches a block if it reaches along ANY predecessor, so the
  // confluence operator is union (a constant from the Set class).
  protected int confluence() {
    return Set.UNION;
  }

  // Returns the set of all definitions (DelocalizedInstructions assigning to
  // a local scalar), as collected by set_initial_conditions().
  protected Set all_elements() {
    return all;
  }

  // Returns the set of definitions block b kills: every definition in `all`
  // whose target local is also assigned somewhere in b. Definitions made
  // inside b itself are included as well.
  protected Set kill(BasicBlock b) { 
    Set s = new Set();

    for (DelocalizedInstruction current = b.get_last(); current != null;
         current = b.prev_instr(current)) {
      DelocalizedLValue lval = local_def_target(current);
      if (lval == null) continue;   // not an assignment to an actual local

      // This assignment kills ALL assignments to the same local.
      for (Enumeration assgns = all.elements(); assgns.hasMoreElements(); ) {
        DelocalizedInstruction instr =
          (DelocalizedInstruction) assgns.nextElement();
        if (instr.destination().lproxy() == lval)
          s.add(instr);
      }
    }
    return s;
  }

  // Returns the set of definitions block b generates: for each local assigned
  // in b, only the LAST assignment to it in the block.
  protected Set generate(BasicBlock b) {
    Set s = new Set();
    Set lvalues = new Set();   // locals already defined later in the block

    // Walk backwards so the first definition seen of a local is its last one.
    for (DelocalizedInstruction foo = b.get_last(); foo != null;
         foo = b.prev_instr(foo)) {
      DelocalizedLValue dest = local_def_target(foo);
      if (dest != null && !lvalues.contains(dest)) {
        s.add(foo);
        lvalues.add(dest);
      }
    }
    return s;
  }

  // Postpass run with the IN/OUT sets available: builds one DUChain per
  // definition (the name "optimize" is a misnomer; nothing is mutated in the
  // blocks). Results are available from get_chains().
  protected void optimize() {
    if (debug) dump_in_out_sets();

    // One chain for each definition.
    chain_map = new Hashtable();
    for (Enumeration e = all.elements(); e.hasMoreElements(); ) {
      DelocalizedInstruction di = (DelocalizedInstruction) e.nextElement();
      chain_map.put(di, new DUChain(di));
    }

    // Replay every block from its IN set. Each instruction is first recorded
    // as reached by the current definitions (uses are handled BEFORE the
    // instruction's own assignment). If it is itself a definition, it then
    // replaces all other definitions of the same local.
    int bc = 0;
    for (Enumeration block_enum = entry.self_and_all_successors();
         block_enum.hasMoreElements(); ) {
      if (debug) System.out.println("** BLOCK " + bc++);
      BasicBlock curblock = (BasicBlock) block_enum.nextElement();

      // Work on a private copy: the set returned by in() is mutated below and
      // may be the live IN set held by FixedPointAlgorithm.
      Set reaches = copy_reaching_set(this.in(curblock));

      for (DelocalizedInstruction curinstr = curblock.get_first();
           curinstr != null; curinstr = curblock.next_instr(curinstr)) {
        record_reached_instr(reaches, curinstr);

        DelocalizedLValue dest = local_def_target(curinstr);
        if (dest != null) {
          kill_defs_of(reaches, dest);
          reaches.addElement(curinstr);   // this definition now reaches on
        }
      }

      // Replaying the block must reproduce the OUT set computed by the FPA.
      if (debug && !reaches.equals(out(curblock)))
        System.out.println("ASSERTION ERROR.\nREACHES: " + reaches + "\nOUT: "+
                           out(curblock));
    }

    vectorize_chains(); // Stuff in vector
  }

  // Collects `all`, the set of every definition in the graph. Initial OUT
  // conditions are left to FixedPointAlgorithm.
  protected void set_initial_conditions() {
    all = new Set();

    for (Enumeration blob = entry.self_and_all_successors();
         blob.hasMoreElements(); ) {
      BasicBlock curblock = (BasicBlock) blob.nextElement();
      for (DelocalizedInstruction foo = curblock.get_first(); foo != null;
           foo = curblock.next_instr(foo)) {
        // Add the entire instruction (so defs are unique), not just its target.
        if (local_def_target(foo) != null)
          all.add(foo);
      }
    }
  }

  // Copies the DUChains held in chain_map into resulting_chains.
  private void vectorize_chains() {
    resulting_chains = new Vector();

    for (Enumeration e = chain_map.elements(); e.hasMoreElements(); ) {
      Object o = e.nextElement();
      if (debug) System.out.println("DUChain: " + o + " \n");
      resulting_chains.addElement(o);
    }
  }

  // Retrieve all DUChains after the postpass; null until optimize() has run.
  public Vector get_chains() { return resulting_chains; }

  /* Private helpers */

  // Returns instr's destination proxy if it is an actual local scalar
  // (LocalScalarVarDescriptor), otherwise null.
  private static DelocalizedLValue local_def_target(DelocalizedInstruction instr) {
    DelocalizedLValue lval = instr.destination();
    if (lval == null) return null;
    lval = lval.lproxy();
    return (lval instanceof LocalScalarVarDescriptor) ? lval : null;
  }

  // Returns a new Set holding the same definitions as src.
  private static Set copy_reaching_set(Set src) {
    Set copy = new Set();
    for (Enumeration e = src.elements(); e.hasMoreElements(); )
      copy.add((DelocalizedInstruction) e.nextElement());
    return copy;
  }

  // Removes from reaches every definition whose target is dest. The victims
  // are gathered first, so reaches is never modified while it is enumerated
  // (doing so can skip elements, leaving stale definitions behind).
  private static void kill_defs_of(Set reaches, DelocalizedLValue dest) {
    Vector killed = new Vector();
    for (Enumeration d = reaches.elements(); d.hasMoreElements(); ) {
      DelocalizedInstruction defn = (DelocalizedInstruction) d.nextElement();
      if (defn.destination().lproxy() == dest)
        killed.addElement(defn);
    }
    for (Enumeration k = killed.elements(); k.hasMoreElements(); )
      reaches.removeElement((DelocalizedInstruction) k.nextElement());
  }

  // Appends instr to the DUChain of every definition in reaches.
  // NOTE(review): every reached instruction is added, not only those that
  // use the defined local; presumably DUChain.addElement filters out
  // non-uses -- confirm.
  private void record_reached_instr(Set reaches, DelocalizedInstruction instr) {
    for (Enumeration d = reaches.elements(); d.hasMoreElements(); ) {
      DelocalizedInstruction defn = (DelocalizedInstruction) d.nextElement();
      DUChain chain = (DUChain) chain_map.get(defn);
      if (chain == null)
        // Should not happen: every definition in reaches comes from `all`.
        System.out.println("Oh fuck. defn but no Chain for: " + defn +
                           "\n    Note that dest: " + 
                           defn.destination().lproxy().hashCode());
      else
        chain.addElement(instr);
    }
  }

  // Debug aid: prints the IN and OUT set of every reachable block.
  private void dump_in_out_sets() {
    for (Enumeration blob = entry.self_and_all_successors();
         blob.hasMoreElements(); ) {
      BasicBlock curblock = (BasicBlock) blob.nextElement();
      System.out.println("\nBlock: " + curblock + ": " +curblock.get_first()+
                         " -> " + curblock.get_last());
      System.out.println(in(curblock).toString());
      System.out.println("--->");
      System.out.println(out(curblock).toString());
    }
  }
}
