// cky.h
//
// Mark Johnson, 4th January 2006

#ifndef CKY_H
#define CKY_H

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <ext/hash_map>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "mt19937ar.h"
#include "sym.h"
#include "tree.h"
#include "trie.h"
#include "utility.h"

// debug is the global verbosity level; it is only declared here and must be
// defined in exactly one translation unit of the program.  Within this
// header, cky_type::inside() writes trace output to std::cerr when
// debug >= 10000 (terminals), >= 20000 (inactive cells) and
// >= 20100 (active cells).
extern int debug;

// power(x, y) returns x raised to the power y.  There is one overload per
// floating-point type, each forwarding to the C library routine of the same
// precision, so the result type always matches the argument type.
inline float power(float x, float y)
{
  const float result = powf(x, y);
  return result;
}

inline double power(double x, double y)
{
  const double result = pow(x, y);
  return result;
}

inline long double power(long double x, long double y)
{
  const long double result = powl(x, y);
  return result;
}

// F is the floating-point type used for all weights and probabilities
// in this file.
typedef long double F;
// S is a grammar symbol (terminal or nonterminal); Ss is a sequence of
// symbols, e.g. the right-hand side of a rule or a sentence.
typedef symbol S;
typedef std::vector<S> Ss;

// S_F maps a symbol to a weight or probability.  The ordered std::map is the
// version in use; the hash_map alternative below is kept commented out.
typedef std::map<S,F> S_F;
// typedef ext::hash_map<S,F> S_F;

//! readline_symbols() reads the next line from is, splits it on
//! whitespace, and stores the resulting symbols in syms (replacing
//! whatever syms held before).  syms is left empty if no line could be
//! read.  Returns is, so the caller can test the stream state.
//!
//! The function is declared inline because it is defined in a header:
//! without inline, including cky.h from more than one translation unit
//! produces duplicate definitions at link time.
//
inline std::istream& readline_symbols(std::istream& is, Ss& syms) {
  syms.clear();
  std::string line;
  if (std::getline(is, line)) {
    // tokenise the line on whitespace; each token becomes one symbol
    std::istringstream iss(line);
    std::string s;
    while (iss >> s)
      syms.push_back(s);
  }
  return is;
}  // readline_symbols()


//! default_value_type{} bundles a reference to the object to be read from
//! a stream with the value that object should receive if the read fails.
//! It is an implementation detail of default_value(); call that function
//! rather than constructing one of these directly.
//
template <typename object_type, typename default_type>
struct default_value_type {
  object_type& object;              // destination of the read
  const default_type defaultvalue;  // value assigned to object if the read fails

  default_value_type(object_type& target, const default_type fallback)
    : object(target),
      defaultvalue(fallback)
  { }
};

//! default_value() builds the helper used to read object from a stream
//! with a fallback: `is >> default_value(x, d)` reads into x and, if
//! that read fails, assigns d to x instead.  The returned
//! default_value_type{} object holds a reference to object and a copy of
//! defaultvalue; the reading itself is done by its operator>>.
//
template <typename object_type, typename default_type>
default_value_type<object_type,default_type>
default_value(object_type& object, const default_type defaultvalue=default_type()) {
  typedef default_value_type<object_type,default_type> reader_type;
  return reader_type(object, defaultvalue);
}

//! operator>>() for default_value_type{}: tries to read dv.object from is.
//! If the stream is already in a failed state nothing is read and
//! dv.object is left untouched.  If the read itself fails, failbit is
//! cleared (all other state bits are preserved) and dv.object is set to
//! dv.defaultvalue.
//
template <typename object_type, typename default_type>
std::istream& operator>> (std::istream& is, 
			  default_value_type<object_type, default_type> dv) {
  if (!is)
    return is;

  if (!(is >> dv.object)) {
    // the read failed: drop failbit only, then fall back to the default
    is.clear(is.rdstate() & ~std::ios::failbit);
    dv.object = dv.defaultvalue;
  }
  return is;
}

//! random1() returns the next value produced by mt_genrand_res53(),
//! converted to F.  (Presumably a uniform deviate in [0,1) with 53-bit
//! resolution -- confirm against mt19937ar.h.)
//
// Earlier rand()-based version, kept for reference:
// inline F random1() { return rand()/(RAND_MAX+1.0); }
inline F random1()
{
  const F sample = mt_genrand_res53();
  return sample;
}


//! A pcfg_type holds a weighted context-free grammar: rule weights indexed
//! by right-hand side, the total weight of each parent, and the prior
//! weights recorded when the grammar was read.  Rule probabilities are
//! rule weight divided by parent weight.  (The CKY parser that uses this
//! grammar is cky_type.)
//
struct pcfg_type {

  //! default_weight is the weight read() gives to rules that carry no
  //! explicit weight in the input.
  //
  pcfg_type(F default_weight=1) : default_weight(default_weight) { }

  typedef unsigned int U;

  //! S_S_F maps a symbol to a symbol->weight map
  //
  typedef ext::hash_map<S,S_F> S_S_F;

  //! St_S_F is a trie keyed on symbol sequences whose nodes carry a
  //! parent->weight map; Stit is a const iterator into that trie.
  //
  typedef trie<S, S_F> St_S_F;
  typedef St_S_F::const_iterator Stit;

  //! start is the start symbol of the grammar
  //
  S start;

  //! rhs_parent_weight maps the right-hand sides of non-unary rules
  //! to rule parent and rule weight
  //
  St_S_F rhs_parent_weight;

  //! unarychild_parent_weight maps the child of each unary rule to a
  //! map from parent to rule weight
  //
  S_S_F unarychild_parent_weight;

  //! parent_weight maps parents to the sum of their rule weights
  //
  S_F parent_weight;

  //! default_weight is the default weight for rules with no explicit
  //! weight
  //
  F default_weight;

  typedef std::pair<S,Ss> SSs;
  typedef ext::hash_map<SSs,F> SSs_F;

  //! rule_priorweight maps rule = [parent|rhs] to its prior weight
  //
  SSs_F rule_priorweight;

  //! parent_priorweight maps parents to their prior weight
  //
  S_F parent_priorweight;

  //! rule_weight() returns the weight of rule parent --> rhs.
  //! rhs must be non-empty.  Unary rules are looked up in
  //! unarychild_parent_weight, longer rules in rhs_parent_weight.
  //! Returns 0 if the right-hand side is unknown; otherwise the result
  //! of dfind() on the parent (presumably 0 if the parent is absent --
  //! confirm against utility.h).
  //
  template <typename rhs_type>
  F rule_weight(S parent, const rhs_type& rhs) const {
    assert(!rhs.empty());
    if (rhs.size() == 1) {
      S_S_F::const_iterator it = unarychild_parent_weight.find(rhs[0]);
      if (it == unarychild_parent_weight.end())
        return 0;
      else
        return dfind(it->second, parent);
    }
    else {  // rhs.size() > 1
      Stit it = rhs_parent_weight.find(rhs);
      if (it == rhs_parent_weight.end())
        return 0;
      else
        return dfind(it->data, parent);
    }
  }  // pcfg_type::rule_weight()

  //! rule_prob() returns the probability of rule parent --> rhs, i.e.,
  //! its weight divided by the total weight of parent.  The rule is
  //! asserted to have positive weight.
  //
  template <typename rhs_type>
  F rule_prob(S parent, const rhs_type& rhs) const {
    assert(!rhs.empty());
    F parentweight = afind(parent_weight, parent);
    F ruleweight = rule_weight(parent, rhs);
    assert(ruleweight > 0);
    return ruleweight/parentweight;
  }  // pcfg_type::rule_prob()

  //! tree_prob() returns the probability of the tree under the current
  //! model: the product of rule_prob() over every local tree in tp.
  //! A node without children contributes probability 1.  A
  //! non-positive result is reported on std::cerr.
  //
  F tree_prob(const tree* tp) const {
    F prob = 1;
    if (!tp->child) 
      return prob;
    Ss children;
    for (const tree* child = tp->child; child != NULL; child = child->next) {
      children.push_back(child->label.cat);
      prob *= tree_prob(child);
    }
    prob *= rule_prob(tp->label.cat, children);
    if (prob <= 0)
      std::cerr << "## pcfg_type::tree_prob(" << tp << ") = " << prob << std::endl;
    return prob;
  }  // pcfg_type::tree_prob()

  //! increment() adds weight (which may be negative) to the rule
  //! parent --> rhs and to the total weight of parent.  Entries whose
  //! weight drops to exactly 0 are erased from the tables.
  //!
  //! Returns the probability of the rule: under the counts before the
  //! update if weight > 0, under the counts after the update otherwise.
  //! If the relevant parent total is 0 the quotient is 0/0 and the
  //! returned value is not meaningful.
  //
  template <typename rhs_type>
  F increment(S parent, const rhs_type& rhs, F weight = 1) {
    assert(!rhs.empty());
    F weight1;
    F parentweight = (parent_weight[parent] += weight);
    assert(parentweight >= 0);
    if (parentweight == 0)
      parent_weight.erase(parent);
    if (rhs.size() == 1) {
      S_F& parent1_weight = unarychild_parent_weight[rhs[0]];
      weight1 = (parent1_weight[parent] += weight);
      assert(weight1 >= 0);
      if (weight1 == 0) {
        parent1_weight.erase(parent);
        if (parent1_weight.empty())
          unarychild_parent_weight.erase(rhs[0]);
      }
    }
    else {  // non-unary rule
      S_F& parent1_weight = rhs_parent_weight[rhs];
      weight1 = (parent1_weight[parent] += weight);
      assert(weight1 >= 0);  // same invariant as in the unary case
      if (weight1 == 0) {
        parent1_weight.erase(parent);
        if (parent1_weight.empty())
          rhs_parent_weight.erase(rhs);
      }
    }
    // if weight > 0, then return rule probability under old counts,
    // otherwise return rule probability under new counts.
    return (weight > 0) ? (weight1-weight)/(parentweight-weight) : weight1/parentweight;
  }  // pcfg_type::increment()

  //! increment() increments the weight of all of the local trees
  //! in tp by weight, and returns the product of the probabilities
  //! returned by the per-rule increment() calls.  The rule at the root
  //! of tp is updated before the rules of its subtrees.
  //
  F increment(const tree* tp, F weight = 1) {
    F prob = 1;
    if (tp->child) {
      { 
        Ss children;
        for (const tree* child = tp->child; child; child = child->next)
          children.push_back(child->label.cat);
        prob *= increment(tp->label.cat, children, weight);
      }
      for (const tree* child = tp->child; child; child = child->next)
        prob *= increment(child, weight);
    }
    return prob;
  }  // pcfg_type::increment()

  //! read() reads a grammar from an input stream (implements >> ).
  //! Each rule occupies one line of the form
  //!
  //!   [weight] parent --> child1 child2 ...
  //!
  //! where a missing weight is replaced by default_weight.  The parent
  //! of the first rule read becomes the start symbol.  Each rule's
  //! weight is added both to the current weights (via increment()) and
  //! to the prior-weight tables.  Reading stops at the first line that
  //! does not match this form.
  //
  std::istream& read(std::istream& is) {
    start = symbol::undefined();
    F weight;
    std::string parent;
    while (is >> default_value(weight, default_weight) >> parent >> " -->") {
      if (start.is_undefined())
        start = parent;
      Ss rhs;
      readline_symbols(is, rhs);
      increment(parent, rhs, weight);
      rule_priorweight[SSs(parent,rhs)] += weight;
      parent_priorweight[parent] += weight;
    }
    return is;
  }  // pcfg_type::read()

  //! write() writes a grammar (implements << ).  The rules of the start
  //! symbol are written first, followed by the rules of every other
  //! parent in parent_weight.
  //
  std::ostream& write(std::ostream& os) const {
    assert(start.is_defined());
    write_rules(os, start);
    cforeach (S_F, it, parent_weight)
      if (it->first != start)
        write_rules(os, it->first);
    return os;
  }  // pcfg_type::write()

  //! write_rules() writes every rule expanding parent, one per line, as
  //! "weight<TAB>parent --> rhs": first the rules stored in
  //! rhs_parent_weight, then the unary rules.
  //
  std::ostream& write_rules(std::ostream& os, S parent) const {
    rhs_parent_weight.for_each(write_rule(os, parent));
    cforeach (S_S_F, it0, unarychild_parent_weight) {
      S child = it0->first;
      // the parents of this unary child (named so as not to shadow the
      // member parent_weight)
      const S_F& parentweights = it0->second;
      cforeach (S_F, it1, parentweights)
        if (it1->first == parent)
          os << it1->second << '\t' << parent 
             << " --> " << child << std::endl;
    }
    return os;
  }  // pcfg_type::write_rules()

  //! write_rule is the function object write_rules() passes to
  //! trie::for_each(); for each right-hand side it writes the rule
  //! whose parent is the stored parent, if there is one.
  //
  struct write_rule {
    std::ostream& os;
    S parent;

    write_rule(std::ostream& os, symbol parent) : os(os), parent(parent) { }

    template <typename Keys, typename Value>
    void operator() (const Keys& rhs, const Value& parentweights) {
      cforeach (typename Value, pwit, parentweights) 
        if (pwit->first == parent) {
          os << pwit->second << '\t' << parent << " -->";
          cforeach (typename Keys, rhsit, rhs)
            os << ' ' << *rhsit;
          os << std::endl;
        }
    }  // pcfg_type::write_rule::operator()

  };  // pcfg_type::write_rule{}
  
  //! log2prob_corpus() returns, in bits (log base 2),
  //!
  //!   sum over rules   [ lgamma(weight) - lgamma(priorweight) ]
  //! - sum over parents [ lgamma(weight) - lgamma(priorweight) ]
  //!
  //! where the sums range over the rules and parents that have a prior
  //! weight.  The long double library routines are called explicitly
  //! (as power() does with powl): an unqualified lgamma() or log() may
  //! bind to the double version and silently drop precision.
  //
  F log2prob_corpus() const {
    F logprob = 0;
    cforeach (SSs_F, it, rule_priorweight) {
      const S parent = it->first.first;
      const Ss& rhs = it->first.second;
      F priorweight = it->second;
      F weight = rule_weight(parent, rhs);
      logprob += lgammal(weight);
      logprob -= lgammal(priorweight);
    }
    cforeach (S_F, it, parent_priorweight) {
      S parent = it->first;
      F priorweight = it->second;
      F weight = dfind(parent_weight, parent);
      logprob -= lgammal(weight);
      logprob += lgammal(priorweight);
    }
    return logprob / logl(2.0L);
  }  // pcfg_type::log2prob_corpus()

};  // pcfg_type{}


//! operator>> (pcfg_type&) reads a pcfg_type g, setting g.start
//! to the parent of the first rule read.  It simply forwards to
//! g.read().
//!
//! Declared inline because it is defined in a header: without inline,
//! including cky.h from more than one translation unit produces
//! duplicate definitions at link time.
//
inline std::istream& operator>> (std::istream& is, pcfg_type& g) {
  return g.read(is);
}  // operator>> (pcfg_type&)


//! operator<< (pcfg_type&) writes the grammar g to os by forwarding to
//! g.write().
//!
//! Declared inline because it is defined in a header: without inline,
//! including cky.h from more than one translation unit produces
//! duplicate definitions at link time.
//
inline std::ostream& operator<< (std::ostream& os, const pcfg_type& g) {
  return g.write(os);
}  // operator<< (pcfg_type&)

namespace EXT_NAMESPACE {
  //! Hash function for trie iterators, so that they can be used as
  //! hash_map keys: the hash value is the address of the object the
  //! iterator refers to.
  //
  template <> struct hash<pcfg_type::Stit> {
    size_t operator()(const pcfg_type::Stit t) const
    {
      const void* address = &(*t);
      return reinterpret_cast<size_t>(address);
    }  // ext::hash<pcfg_type::Stit>::operator()
  };  // ext::hash<pcfg_type::Stit>{}
}  // namespace EXT_NAMESPACE

//! unaryclosetolerance is the convergence threshold of
//! cky_type::inside_unaryclose(): the unary closure stops iterating once
//! the largest relative increase of any cell entry in a pass
//! (increment / updated value) is no greater than this value.
//
static const F unaryclosetolerance = 1e-7;

//! A cky_type is a CKY chart parser for a pcfg_type grammar.  inside()
//! fills the chart with inside probabilities for a sentence;
//! random_tree() then samples a parse tree from that chart.
//
class cky_type {

public:

  const pcfg_type& g;  // the grammar; held by reference, must outlive this object
  F anneal;         // annealing factor (1 = no annealing)
  
  cky_type(const pcfg_type& g, F anneal=1) : g(g), anneal(anneal) { }

  typedef pcfg_type::U U;
  typedef pcfg_type::S_S_F S_S_F;
  typedef pcfg_type::St_S_F St_S_F;
  typedef pcfg_type::Stit Stit;

  typedef std::vector<S_F> S_Fs;
  typedef ext::hash_map<Stit,F> Stit_F;
  typedef std::vector<Stit_F> Stit_Fs;

  //! index() returns the location of the cell for the span from i to j
  //! in the chart vectors.  All uses in this class have i < j.
  //
  static U index(U i, U j) { return j*(j-1)/2+i; }

  //! ncells() returns the number of cells required for sentence of length n
  //
  static U ncells(U n) { return n*(n+1)/2; }
  
  //! terminals is the sentence most recently passed to inside().
  //! inactives[index(i,j)] maps each symbol to its inside probability
  //! over the span (i,j); actives[index(i,j)] does the same for
  //! positions in the trie of rule right-hand sides, i.e., partially
  //! matched right-hand sides.
  //
  Ss terminals;
  S_Fs inactives;
  Stit_Fs actives;

  //! inside() constructs the inside table, and returns the probability
  //! of the start symbol rewriting to the terminals.  Rule
  //! probabilities are raised to the power anneal.  Returns 0 for an
  //! empty sentence: the chart then has no cells at all.
  //
  template <typename terminals_type>
  F inside(const terminals_type& terminals0) {

    terminals = terminals0;

    if (debug >= 10000)
      std::cerr << "# cky::inside() terminals = " << terminals << std::endl;

    U n = terminals.size();

    inactives.clear();
    inactives.resize(ncells(n));
    actives.clear();
    actives.resize(ncells(n));

    // With no terminals the chart is empty, so the final lookup of
    // cell (0,n) below would index past the end of inactives.
    if (n == 0)
      return 0;

    // spans of length 1: each terminal, closed under unary rules
    for (U i = 0; i < n; ++i) {
      inactives[index(i,i+1)][terminals[i]] = 1;
      inside_unaryclose(inactives[index(i,i+1)], actives[index(i,i+1)]);
      
      if (debug >= 20000)
        std::cerr << "# cky::inside() inactives[" << i << "," << i+1 << "] = " 
                  << inactives[index(i,i+1)] << std::endl;
      if (debug >= 20100)
        std::cerr << "# cky::inside() actives[" << i << "," << i+1 << "] = " 
                  << actives[index(i,i+1)] << std::endl;
    }

    // longer spans, shortest first: extend an active edge over
    // (left,mid) with an inactive edge over (mid,right)
    for (U gap = 2; gap <= n; ++gap)
      for (U left = 0; left + gap <= n; ++left) {
        U right = left + gap;
        S_F& parentinactives = inactives[index(left,right)];
        Stit_F& parentactives = actives[index(left,right)];
        for (U mid = left+1; mid < right; ++mid) {
          Stit_F& leftactives = actives[index(left,mid)];
          const S_F& rightinactives = inactives[index(mid,right)];
          cforeach (Stit_F, itleft, leftactives) {
            const Stit leftactive = itleft->first;
            const F leftprob = itleft->second;
            cforeach (S_F, itright, rightinactives) {
              S rightinactive = itright->first;
              const F rightprob = itright->second;
              const Stit parentactive = leftactive->find1(rightinactive);
              if (parentactive != leftactive->end()) {
                F leftrightprob = leftprob * rightprob;
                // every rule whose right-hand side ends here yields an
                // inactive edge for its parent
                cforeach (S_F, itparent, parentactive->data) {
                  S parent = itparent->first;
                  parentinactives[parent] += leftrightprob 
                    * power(itparent->second/afind(g.parent_weight, parent), anneal);
                }
                // a longer right-hand side may still continue from here
                if (!parentactive->key_trie.empty())
                  parentactives[parentactive] += leftrightprob;
              }
            }
          }
        }
        inside_unaryclose(parentinactives, parentactives);
        if (debug >= 20000)
          std::cerr << "# cky::inside() inactives[" << left << "," << right 
                    << "] = " << parentinactives << std::endl;
        if (debug >= 20100)
          std::cerr << "# cky::inside() actives[" << left << "," << right << "] = " 
                    << parentactives << std::endl;
      }

    return dfind(inactives[index(0,n)], g.start);
  }  // cky_type::inside()

  //! inside_unaryclose() closes the inactive edges of one cell under
  //! the unary rules of the grammar, then adds to actives the active
  //! edge started by each inactive edge in the cell.
  //!
  //! The closure is iterative: each pass propagates the probability
  //! added in the previous pass one unary rule further, and the loop
  //! ends once the largest relative increase in a pass is no greater
  //! than unaryclosetolerance.
  //
  void inside_unaryclose(S_F& inactives, Stit_F& actives) {
    F delta = 1;
    S_F delta_prob1 = inactives;
    S_F delta_prob0;
    while (delta > unaryclosetolerance) {
      delta = 0;
      delta_prob0 = delta_prob1;
      delta_prob1.clear();
      cforeach (S_F, it0, delta_prob0) {
        S child = it0->first;
        S_S_F::const_iterator it = g.unarychild_parent_weight.find(child);
        if (it != g.unarychild_parent_weight.end()) {
          const S_F& parent_weight = it->second;
          cforeach (S_F, it1, parent_weight) {
            S parent = it1->first;
            F prob = it0->second * power(it1->second/afind(g.parent_weight, parent), anneal);
            delta_prob1[parent] += prob;
            delta = std::max(delta, prob/(inactives[parent] += prob));
          }
        }
      }
    }
    // each inactive edge that is the first symbol of some stored
    // right-hand side starts an active edge
    cforeach (S_F, it0, inactives) {
      Stit it1 = g.rhs_parent_weight.find1(it0->first);
      if (it1 != g.rhs_parent_weight.end())
        actives[it1] += it0->second;
    }
  } // cky_type::inside_unaryclose()

 
  //! random_tree() returns a random parse tree for terminals, sampled
  //! from the chart built by the most recent call to inside().  The
  //! tree is allocated with new; the caller owns it.  It must only be
  //! called after inside() has found a parse: the start symbol is
  //! looked up in the top cell with afind() (presumably an error if it
  //! is absent -- confirm against utility.h).
  //
  tree* random_tree() {
    U n = terminals.size();
    return random_inactive(g.start, afind(inactives[index(0, n)], g.start), 0, n);
  }  // cky_type::random_tree

  //! random_inactive() returns a random expansion for an inactive edge:
  //! a new tree node labelled parent, with sibling link next, covering
  //! the span (left,right), whose children are sampled in proportion
  //! to their contribution to parentprob.
  //!
  //! Unary expansions are considered first, then expansions by longer
  //! rules at each split point.  If no expansion reaches the sampled
  //! threshold an error is reported on std::cerr and the childless
  //! node is returned.
  //
  tree* random_inactive(const S parent, const F parentprob, const U left, const U right,
                        tree* next = NULL) const {

    tree* tp = new tree(parent, NULL, next);

    // a terminal covering its own position needs no expansion
    if (left+1 == right && parent == terminals[left])
      return tp;

    const S_F& parentinactives = inactives[index(left, right)];
    F weightthreshold = random1() * parentprob;
    F parentweight = afind(g.parent_weight, parent);
    F weightsofar = 0;

    // try unary rules

    cforeach (S_F, it0, parentinactives) {
      S child = it0->first;
      F childprob = it0->second;
      S_S_F::const_iterator it1 = g.unarychild_parent_weight.find(child);
      if (it1 != g.unarychild_parent_weight.end()) {
        const S_F& parent1_weight = it1->second;
        // Only children that actually rewrite from parent are
        // candidates.  Testing the threshold for a child with no such
        // rule could select it when the threshold is 0, or when
        // anneal is 0 (pow(0,0) is 1); inside() never counts such a
        // child either.
        S_F::const_iterator it2 = parent1_weight.find(parent);
        if (it2 != parent1_weight.end()) {
          weightsofar += childprob 
            * power(it2->second/parentweight, anneal);
          if (weightsofar >= weightthreshold) {
            tp->child = random_inactive(child, childprob, left, right);
            return tp;
          }
        }
      }
    }

    // try binary rules

    for (U mid = left+1; mid < right; ++mid) {
      const Stit_F& leftactives = actives[index(left,mid)];
      const S_F& rightinactives = inactives[index(mid,right)];
      cforeach (Stit_F, itleft, leftactives) {
        const Stit leftactive = itleft->first;
        const F leftprob = itleft->second;
        cforeach (S_F, itright, rightinactives) {
          S rightinactive = itright->first;
          const F rightprob = itright->second;
          const Stit parentactive = leftactive->find1(rightinactive);
          if (parentactive != leftactive->end()) {
            S_F::const_iterator it = parentactive->data.find(parent);
            if (it != parentactive->data.end()) {
              weightsofar += leftprob * rightprob 
                * power(it->second/parentweight, anneal);
              if (weightsofar >= weightthreshold) {
                // the rightmost child is built first and handed to
                // random_active() as the tail of the sibling list
                tp->child = random_active(leftactive, leftprob, left, mid,
                                          random_inactive(rightinactive, rightprob, mid, right));
                return tp;
              }
            }
          }
        }
      }
    }

    std::cerr << "## Error in cky_type::random_inactive(), parent = " << parent
              << ", left = " << left << ", right = " << right 
              << ", weightsofar = " << weightsofar << ", weightthreshold = " << weightthreshold 
              << std::endl;
    return tp;
  }  // cky_type::random_inactive()

  //! random_active() returns a random expansion for an active edge:
  //! the list of sibling trees covering the span (left,right) that
  //! spell out the partial right-hand side parent, followed by next.
  //! Alternatives are sampled in proportion to their contribution to
  //! parentprob.  If none reaches the sampled threshold an error is
  //! reported on std::cerr and NULL is returned.
  //
  tree* random_active(const Stit parent, F parentprob, const U left, const U right, 
                      tree* next = NULL) const {
    F probthreshold = random1() * parentprob;
    F probsofar = 0;

    // unary rule
    
    const S_F& parentinactives = inactives[index(left, right)];
    cforeach (S_F, it, parentinactives)
      if (g.rhs_parent_weight.find1(it->first) == parent) {
        probsofar += it->second;
        if (probsofar >= probthreshold)
          return random_inactive(it->first, it->second, left, right, next);
        break;  // only one unary child can possibly generate this parent
      }

    // binary rules

    for (U mid = left + 1; mid < right; ++mid) {
      const Stit_F& leftactives = actives[index(left,mid)];
      const S_F& rightinactives = inactives[index(mid,right)];
      cforeach (Stit_F, itleft, leftactives) {
        const Stit leftactive = itleft->first;
        const F leftprob = itleft->second;
        cforeach (S_F, itright, rightinactives) {
          S rightinactive = itright->first;
          const F rightprob = itright->second;
          if (parent == leftactive->find1(rightinactive)) {
            probsofar += leftprob * rightprob;
            if (probsofar >= probthreshold) {
              return random_active(leftactive, leftprob, left, mid,
                                   random_inactive(rightinactive, rightprob, mid, right, next));
            }
          }
        }
      }
    }

    std::cerr << "## Error in cky_type::random_active(), parent = " << parent
              << ", left = " << left << ", right = " << right 
              << ", probsofar = " << probsofar << ", probthreshold = " << probthreshold 
              << std::endl;
    return NULL;
  }  // cky_type::random_active()

}; // cky_type{}

#endif // CKY_H
