// gibbs-pcfg.cc
//
// gibbs-pcfg uses gibbs sampling to estimate a PCFG

// usage[] is the help text for the program: a synopsis of the command
// line, one line per option, and a description of the grammar file format.
// Operand names in the synopsis and in the option table are kept identical,
// and every placeholder in the rule format uses the <Name> notation that
// the legend below it explains.
const char usage[] =
"gibbs-pcfg -- Unsupervised PCFG estimator using (uncollapsed) Gibbs sampler\n"
"\n"
" (c) Mark Johnson, version of 22nd April, 2013\n"
"\n"
"gibbs-pcfg [-d debug] [-F trace-file] [-S sample-rate]\n"
"       [-A parsefile] [-G grammarfile] [-P ruleprobfile] [-U rulecountfile]\n"
"       [-I] [-r rand-init] [-C cflag] [-c closetol] [-n niterations] [-w weight]\n"
"       [-T tstart] [-t tstop] [-m anneal-its]\n"
"       [-a alpha] [-R] [-Z z-temp] [-z z-its] [-X eval-cmd] [-x eval-every]\n"
"       grammar.lt < train.yld\n"
"\n"
" -d debug         -- debug level\n"
" -F trace-file    -- file to write trace output to (default is stderr)\n"
" -S sample-rate   -- resample theta every sample-rate parses\n"
" -A parsefile     -- print analyses of training data to parsefile at termination\n"
" -G grammarfile   -- print out grammar to grammarfile at termination\n"
" -P ruleprobfile  -- print out rule probabilities during sampling\n"
" -U rulecountfile -- print out rule counts during sampling\n"
" -I               -- parse sentences in order (default is random order)\n"
" -r rand-init     -- initializer for random number generator (integer)\n"
" -C cflag         -- consistency handling (0 = sink state, 1 = only tight, 2 = renormalise)\n"
" -c closetol      -- tolerance for closure operations (including partition function)\n"
" -n niterations   -- number of iterations\n"
" -w weight        -- default rule weight\n"
" -T tstart        -- start at this annealing temperature\n"
" -t tstop         -- stop with this annealing temperature\n"
" -m anneal-its    -- anneal for this many iterations\n"
" -a alpha         -- default pseudo-count alpha\n"
" -R               -- train.yld contains initial trees rather than strings\n"
" -Z z-temp        -- set Z-temp\n"
" -z z-its         -- perform z-its at temperature Z-temp at end\n"
" -X eval-cmd      -- pipe sampled parses into eval-cmd\n"
" -x eval-every    -- pipe sampled parses every eval-every iterations\n"
"\n"
"The file grammar.lt should contain a list of rules.\n"
"The start category is the parent of the first rule.\n"
"Each rule has the following format:\n"
"\n"
"   [<Alpha> [<Theta>]] <Category> --> <Category>+\n"
"\n"
"where:\n"
"\n"
"    <Alpha> is an optional Dirichlet prior parameter associated with\n"
"            this rule (default specified on command line)\n"
"    <Theta> is the rule's initial probability\n"
"            (default is a uniform distribution over rules)\n"
"    <Category> is a category in the grammar\n"
"\n";
    
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <ext/hash_map>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <unistd.h>
#include <utility>
#include <vector>

#include "gammavariate.h"
#include "mt19937ar.h"
#include "pstream.h"
#include "sym.h"
#include "tree.h"
#include "trie.h"
#include "utility.h"


typedef unsigned U;          // unsigned type used for counters, sizes and indices

typedef double F;            // floating-point type used for probabilities, priors and counts
typedef std::vector<F> Fs;   // vector of F values

typedef symbol S;            // grammar category / terminal label (symbol comes from sym.h)
typedef std::vector<S> Ss;   // sequence of symbols, e.g. a sentence or a rule's categories

typedef std::map<S,F> S_F;   // ordered map from a symbol to an F value

int debug = 0;               // global verbosity level; diagnostics are gated on "debug >= N" tests

//! power(x, y) returns x raised to the power y.  An exponent that
//! compares equal to 1 returns x itself without calling the math library.
//
inline float power(float x, float y) {
  if (y == 1)
    return x;
  return powf(x, y);
}

inline double power(double x, double y) {
  if (y == 1)
    return x;
  return pow(x, y);
}

inline long double power(long double x, long double y) {
  if (y == 1)
    return x;
  return powl(x, y);
}

typedef pstream::ostream* Postreamp;        // pointer to a pstream::ostream (declared in pstream.h)
typedef std::vector<Postreamp> Postreamps;  // list of such pointers

//! readline_symbols() consumes one line from is and stores its
//! whitespace-separated tokens, in order, in syms.  syms is always
//! emptied first, so it is left empty when no line could be read.
//! Returns is.
//
std::istream& readline_symbols(std::istream& is, Ss& syms) {
  syms.clear();
  std::string text;
  if (!std::getline(is, text))
    return is;
  std::istringstream tokens(text);
  for (std::string word; tokens >> word; )
    syms.push_back(word);
  return is;
}  // readline_symbols()


//! default_value_type{} bundles a reference to the object that a stream
//! read should fill in with the fallback value to store in it if that
//! read fails.  It is normally created via default_value() rather than
//! constructed directly.
//
template <typename object_type, typename default_type>
struct default_value_type {
  object_type& object;              //!< destination of the read
  const default_type defaultvalue;  //!< value stored in object when the read fails

  default_value_type(object_type& target, const default_type fallback)
    : object(target), defaultvalue(fallback)
  { }
};

//! default_value() builds the default_value_type{} helper for object,
//! so that "is >> default_value(x, d)" reads x from the stream and
//! falls back to d when the read fails.  When no default is supplied a
//! value-initialised default_type is used.
//
template <typename object_type, typename default_type>
default_value_type<object_type,default_type>
default_value(object_type& object, const default_type defaultvalue=default_type()) {
  typedef default_value_type<object_type,default_type> reader_type;
  return reader_type(object, defaultvalue);
}

//! operator>>() for default_value_type{}: attempts to read dv.object
//! from is.  If the stream was already in a failed state nothing is
//! touched.  If the read itself fails, only the failbit is cleared
//! (other state bits such as eofbit are kept) and dv.object is set to
//! dv.defaultvalue.
//
template <typename object_type, typename default_type>
std::istream& operator>> (std::istream& is,
                          default_value_type<object_type, default_type> dv) {
  if (!is)
    return is;
  if (!(is >> dv.object)) {
    const std::ios::iostate without_fail = is.rdstate() & ~std::ios::failbit;
    is.clear(without_fail);
    dv.object = dv.defaultvalue;
  }
  return is;
}

//! random1() returns the next value from mt_genrand_res53() (declared
//! in mt19937ar.h); presumably uniform on [0,1) -- confirm against that
//! header.  The rand()-based variant is kept below for reference.
//
// inline F random1() { return rand()/(RAND_MAX+1.0); }
inline F random1() {
  return mt_genrand_res53();
}

//! R{} holds a single PCFG rule, together with associated
//! information, i.e., its prior and its count
//
struct R {
  Ss cats;      // parent is cats[0], first child is cats[1], etc.
  F  theta;     // rule probability
  F  old_theta; // use to hold old theta in MH
  F  alpha;     // bayesian prior
  F  count;     // count
  F  sum_theta; // sum of theta
  F  sum_count; // sum of counts

  R() : theta(0), old_theta(0), alpha(0), count(0), sum_theta(0), sum_count(0) { }

  R(S parent, const Ss& rhs, F theta, F alpha) 
    : theta(theta), old_theta(0), alpha(alpha), count(0), sum_theta(0), sum_count(0) {
    cats.reserve(rhs.size()+1);
    cats.push_back(parent);
    cats.insert(cats.end(), rhs.begin(), rhs.end());
  }  // R::R()

  S parent() const { assert(!cats.empty()); return cats[0]; }

  F prob() const { return theta; }

  void sum_theta_count() { sum_theta += theta; sum_count += count; }

  std::ostream& 
  write(std::ostream& os, const char* spacestr=" ", const char* rulestr="-->") const{
    assert(!cats.empty());
    os << cats[0] << spacestr << rulestr;
    for (Ss::const_iterator it = cats.begin()+1; it != cats.end(); ++it)
      os << spacestr << *it;
    return os;
  }  // R::write()

};  // R{}

//! Prints a rule using R::write() with its default separators.
std::ostream& operator<< (std::ostream& os, const R& r) {
  r.write(os);
  return os;
}

typedef std::vector<R> Rs;                 // list of rules

typedef R* Rp;                             // pointer to a rule (the indices below point into an Rs)

typedef ext::hash_map<S, U> S_U;           // symbol -> unsigned
typedef ext::hash_map<S, Rp> S_Rp;         // symbol -> rule pointer
typedef ext::hash_map<S, S_Rp> S_S_Rp;     // symbol -> (symbol -> rule pointer)

typedef ext::hash_map<Ss,Rp> Ss_Rp;        // full category sequence -> rule pointer

typedef trie<S, S_Rp> St_S_Rp;             // trie keyed by symbol sequences, each node carrying an S_Rp
typedef St_S_Rp::const_iterator Stit;      // const iterator over that trie

// Forward declarations so that pcfg_type's own members can print *this.
struct pcfg_type;
std::ostream& operator<< (std::ostream& os, const pcfg_type& g);

//! pcfg_type{} holds a PCFG: its rules (each with theta, prior and
//! count), the defaults used while reading a grammar, and the indices
//! over the rules that the sampler and the CKY parser consult.
//! index() must be called after rules changes, because the indices
//! store pointers into rules.
//
struct pcfg_type {
  Rs      rules;                       //!< grammar rules
  F       default_theta;               //!< used when reading a grammar
  F       default_alpha;               //!< default prior, used when reading a grammar
  F       closetolerance;              //!< tolerance for closure operations
  U       nnonterms;                   //!< number of nonterminals in grammar
  U       nsamples;                    //!< number of samples that were taken
  U       nrejects;                    //!< number of times that a sample was rejected
  F       Z;                           //!< partition function

  // indices 

  S_U     nonterm_id;                  //!< nonterminal -> integer map
  Ss_Rp   cats_rulep;                  //!< rule pointer for this rule
  S_S_Rp  child_parent_urulep;         //!< unary rules, indexed by child
  St_S_Rp rhs_parent_rulep;            //!< rhs -> rules with this rhs
  
  //! Constructs an empty grammar.  Z starts at -1, which sample_theta()
  //! treats as "no partition function computed yet"; nnonterms starts
  //! at 0 and is set by index().
  //
  pcfg_type(F t=1, F a=1, F tol=1e-7) 
    : default_theta(t), default_alpha(a), closetolerance(tol), 
      nnonterms(0), nsamples(0), nrejects(0), Z(-1) { }

  //! start() returns the start category of the grammar, i.e. the
  //! parent of the first rule
  //
  S start() const { 
    assert(!rules.empty()); 
    assert(!rules[0].cats.empty()); 
    return rules[0].cats[0]; 
  }  // pcfg_type::start()


  //! sample_theta() sets theta to a sample from the Dirichlet distribution
  //! defined by the rule counts and alpha: each rule draws a gamma
  //! variate with parameter count+alpha, and the draws are then
  //! normalised over rules sharing a parent.
  //
  void sample_theta () {
    ++nsamples;   // increment number of theta samples
    S_F parent_sum;
    foreach (Rs, it, rules) {
      F alpha = it->count + it->alpha;
      assert(alpha > 0);
      parent_sum[it->parent()] += (it->theta = gammavariate(alpha));
      assert(std::isfinite(it->theta));
      assert(it->theta >= 0);
    }
    foreach (Rs, it, rules) 
      if (it->theta > 0)
        it->theta /= afind(parent_sum, it->parent());
  }  // pcfg_type::sample_theta()

  //! sample_theta(nsentences, consistency_flag) resamples theta under
  //! one of three consistency policies:
  //!   0 -- take the sample as it is;
  //!   1 -- resample until inconsistent() is false, counting each
  //!        discarded sample in nrejects;
  //!   2 -- Metropolis-Hastings step: propose a new theta, and keep it
  //!        if its partition function is smaller than the current Z or
  //!        with probability (Z/new_Z)^nsentences; otherwise restore
  //!        the old theta and count a reject.
  //! Any other flag value aborts the program.
  //
  void sample_theta(U nsentences, U consistency_flag) {
    if (debug >= 900)
      std::cerr << "# starting sample_theta(), old Z = " << Z << ", " << std::flush;
    if (consistency_flag == 0)         // sink element => ignore consistency
      sample_theta();
    else if (consistency_flag == 1) {  // only tight => sample until consistent
      sample_theta();
      while (inconsistent()) {
        sample_theta();
        ++nrejects;
      }
    }
    else if (consistency_flag == 2) { // renormalise => MH procedure
      foreach (Rs, it, rules)
        it->old_theta = it->theta;  // save old rule probs
      if (Z < 0) {                  // Z is initialised to -1, so this is first sample
        sample_theta();             // sample new grammar
        Z = partition_function();   // save its partition function
      }
      else {                        // second or later sample
        sample_theta();             // propose new theta
        F new_Z = partition_function(); 
        if ((Z > new_Z) || (random1() < power(Z/new_Z, nsentences))) {
          Z = new_Z; // accept; save new partition function
        }
        else {       // reject; replace theta proposal with old theta
          foreach (Rs, it, rules)
            it->theta = it->old_theta;
          ++nrejects;
        }
      }
    }
    else {
      std::cerr << "Error: unknown consistency_flag = " << consistency_flag << std::endl;
      std::abort();
    }
    if (debug >= 900)
      std::cerr << " sample_theta() done, Z = " << Z << ", nrejects = " << nrejects << std::endl;
    if (debug >= 10000)
      std::cerr << "# resampled thetas:\n" << *this << std::endl;
  }  // pcfg_type::sample_theta()

  //! sum_theta_count() updates the sums of rule counts and rule thetas.
  //!  This is just for tracing; it doesn't influence the sampler.
  //
  void sum_theta_count() {
    foreach (Rs, it, rules)
      it->sum_theta_count();
  }  // pcfg_type::sum_theta_count()
      
  //! increment() adds count (which may be negative) to the count of
  //! every local tree in tp, and returns the product of the current
  //! probabilities of those rules.  Every local tree must correspond
  //! to a rule in cats_rulep.
  //
  F increment(const tree* tp, F count = 1) {
    F prob = 1;
    if (tp->child) {
      { 
        // the rule for this local tree is its label followed by its children's labels
        Ss cats;
        cats.push_back(tp->label.cat);
        for (const tree* child = tp->child; child; child = child->next)
          cats.push_back(child->label.cat);
        Rp rp = afind(cats_rulep, cats);
        prob *= rp->prob();
        rp->count += count;
        if (rp->count < 0) 
          std::cerr << "## rp->count = " << rp->count << ", cats = " << cats << ", tp = " << tp << std::endl;
        assert(rp->count >= 0);
      }
      for (const tree* child = tp->child; child; child = child->next)
        prob *= increment(child, count);
    }
    return prob;
  }  // pcfg_type::increment()

  //! index() constructs the indices into rules needed by the sampler.
  //! Rules with a single child go into child_parent_urulep; rules with
  //! two or more children go into the rhs_parent_rulep trie.  A rule
  //! with no children at all is reported on stderr and left out of
  //! both (it has no child to index by), mirroring how duplicate rules
  //! are reported.
  //
  void index() {
    nonterm_id.clear();
    cats_rulep.clear();
    child_parent_urulep.clear();
    rhs_parent_rulep.clear();
    S_Rp nullsrp;

    foreach (Rs, rit, rules) {

      // insert() leaves an existing id alone, so each parent keeps the
      // id it received the first time it was seen
      nonterm_id.insert(S_U::value_type(rit->cats.front(), nonterm_id.size()));

      bool inserted = cats_rulep.insert(Ss_Rp::value_type(rit->cats, &*rit)).second;
      if (!inserted)
        std::cerr << "## Error in gibbs-pcfg::index() Duplicate rule " << *rit << std::endl;

      if (rit->cats.size() < 2)
        // cats[1] does not exist for such a rule; indexing it would read past the end
        std::cerr << "## Error in gibbs-pcfg::index() Rule with empty right-hand side " 
                  << *rit << std::endl;
      else if (rit->cats.size() == 2)
        child_parent_urulep[rit->cats[1]][rit->cats[0]] = &*rit;
      else {
        S_Rp& parent_rulep = rhs_parent_rulep.insert(rit->cats.begin()+1, 
                                                     rit->cats.end(), 
                                                     nullsrp).first->data;
        parent_rulep[rit->cats[0]] = &*rit;
      }
    }

    nnonterms = nonterm_id.size();

  }  // pcfg_type::index()

  //! log2prob_corpus() returns the log (base 2) marginal probability of
  //! the corpus counts given the prior
  //
  F log2prob_corpus() const {
    typedef std::pair<F,F> FF;
    typedef ext::hash_map<S,FF> S_FF;
    S_FF parent_alphacount;     // parent -> (sum of alpha, sum of count) over its rules
    F logprob = 0;
    cforeach (Rs, it, rules) {
      logprob += lgamma(it->count+it->alpha);
      logprob -= lgamma(it->alpha);
      FF& alphacount = parent_alphacount[it->parent()];
      alphacount.first += it->alpha;
      alphacount.second += it->count;
    }
    cforeach (S_FF, it, parent_alphacount) {
      logprob -= lgamma(it->second.first+it->second.second);
      logprob += lgamma(it->second.first);
    }
    return logprob / log(2.0);   // convert natural log to log base 2
  }  // pcfg_type::log2prob_corpus()

  //! loglikelihood() returns the (natural) log probability of the parse
  //! trees (or equivalently, their counts) given theta estimated from
  //! them.  This is what EM optimizes.
  //
  F loglikelihood() const {
    S_F parent_count;
    F logprob = 0;
    cforeach (Rs, it, rules) 
      if (it->count > 0) {
        logprob += it->count*log(it->count);
        parent_count[it->parent()] += it->count;
      }
    
    cforeach (S_F, it, parent_count)
      logprob -= it->second*log(it->second);
    return logprob;
  } // pcfg_type::loglikelihood()


  //! partition_function() computes the partition function
  //!  for each nonterminal in the PCFG, by fixed-point iteration
  //!  starting from zero until no value changes by more than
  //!  closetolerance.  Symbols that are never a parent count as 1.
  //
  void partition_function(S_F& parent_Z) const {
    parent_Z.clear();
    S_F parent_Z0;              // values from the previous iteration
    cforeach (Rs, it, rules)
      parent_Z[it->cats[0]] = 0.0;
    F delta = 1;
    U niterations = 0;
    while (delta > closetolerance) {
      ++niterations;
      parent_Z.swap(parent_Z0);
      foreach (S_F, it, parent_Z)
        it->second = 0.0;
      cforeach (Rs, it, rules) {
        F p = it->theta;
        const Ss& cats = it->cats;
        for (Ss::const_iterator cit = cats.begin()+1; cit != cats.end(); ++cit) 
          p *= dfind(parent_Z0, *cit, 1.0);
        parent_Z[cats.front()] += p;
      }
      delta = 0;
      cforeach (S_F, it, parent_Z) {
        F delta0 = it->second - parent_Z0[it->first];
        assert(delta0 >= 0.0);
        delta = std::max(delta, delta0);
      }
    }
    if (debug > 900)
      std::cerr << "partition_function()::niterations = " << niterations << ", " << std::flush;
  }  // pcfg_type::partition_function()

  //! partition_function() computes the partition function of the
  //!  start symbol.  The iteration is only run when inconsistent()
  //!  says the grammar is inconsistent; otherwise 1.0 is returned.
  //
  F partition_function() const {
    if (inconsistent()) {
      S_F parent_Z;
      partition_function(parent_Z);
      return afind(parent_Z, start());
    }
    else
      return 1.0;
  }  // pcfg_type::partition_function()

  //! mindex() returns the location of the pair (i,j) in a matrix
  //! of nnonterms x nnonterms entries stored row by row
  //
  U mindex(U i, U j) const { return nnonterms*i + j; }

  //! rewrite_matrix() returns the expected number of times each 
  //!  nonterminal rewrites to another.  m is replaced by an
  //!  nnonterms x nnonterms matrix addressed with mindex(parent,child).
  //
  void rewrite_matrix(Fs& m) const {
    m.clear();
    m.resize(nnonterms*nnonterms, 0.0);
    cforeach (Rs, it, rules) {
      F p = it->theta;
      U parent = afind(nonterm_id, it->cats.front());
      const Ss& cats = it->cats;
      for (Ss::const_iterator cit = cats.begin()+1; cit != cats.end(); ++cit) {
        S_U::const_iterator nit = nonterm_id.find(*cit);
        if (nit != nonterm_id.end())     // children without an id are not nonterminals
          m[mindex(parent,nit->second)] += p;
      }
    }
  } // pcfg_type::rewrite_matrix()

  //! inconsistent() returns true if the grammar is inconsistent
  //
  bool inconsistent() const {
    // PCFG is inconsistent if partition function is less than 1
    // return partition_function() + closetolerance < 1.0;
    
    // PCFG is consistent if principal eigenvalue of rewrite matrix is less than 1.0.
    // The eigenvalue is estimated by power iteration started from the start symbol.
    Fs rwm;                  // rewrite_matrix() sizes and fills this itself
    rewrite_matrix(rwm);
    Fs rw0(nnonterms, 0.0);
    Fs rw1(nnonterms);
    rw1[afind(nonterm_id, start())] = 1.0;
    F sqrtsumsq = 0;  // norm of the latest iterate; converges to principal eigenvalue
    for (U niters = 0; niters <= 2*nnonterms; ++niters) {
      rw0.swap(rw1);
      rw1.clear();
      rw1.resize(nnonterms, 0.0);
      F sumsq = 0.0;
      for (U child = 0; child < nnonterms; ++child) {
        for (U parent = 0; parent < nnonterms; ++parent)
          rw1[child] += rw0[parent] * rwm[mindex(parent,child)];
        sumsq += rw1[child]*rw1[child];
      }
      if (sumsq < closetolerance)   // the iterate has died out
        return false;
      sqrtsumsq = sqrt(sumsq);
      for (U child = 0; child < nnonterms; ++child)
        rw1[child] /= sqrtsumsq;    // renormalise for the next step
    }
    return (sqrtsumsq >= 1.0);
  }  // pcfg_type::inconsistent()

  //! write_averages() writes out average rule thetas and counts,
  //! i.e. the sums kept by sum_theta_count() divided by nsamples.
  //! The count column appears only when debug >= 10.
  //
  std::ostream& write_averages(std::ostream& os) const {
    os << "#AvTheta";
    if (debug >= 10)
      os << "\tAvCount";
    os << "\tRule" << std::endl;
    cforeach (Rs, it, rules) {
      os << it->sum_theta/nsamples << '\t';
      if (debug >= 10)
        os << it->sum_count/nsamples << '\t';
      os << *it << std::endl;
    }
    return os;
  }  // pcfg_type::write_averages()

};  // pcfg_type{}
  

std::istream& operator>> (std::istream& is, pcfg_type& g) {
  g.rules.clear();
  S_F parent_sum;   // used to normalize theta
  std::string parent;
  F theta, alpha;
  while (is 
	 >> default_value(alpha, g.default_alpha)
	 >> default_value(theta, g.default_theta) 
	 >> parent >> " -->") {
    Ss rhs;
    readline_symbols(is, rhs);
    g.rules.push_back(R(parent, rhs, theta, alpha));
    parent_sum[parent] += theta;
  }

  foreach (Rs, it, g.rules)
    it->theta /= afind(parent_sum, it->parent());  // normalize theta;

  g.index();  // construct indices

  return is;
} 

//! Writes the rules of g whose theta is positive, one per line, as
//! theta, then (only when debug >= 10) the rule's count, then the rule
//! itself, tab separated and preceded by a header line.  The header's
//! count column is gated on the same debug threshold as the data
//! column so that header and rows always have the same number of
//! fields.
//
//  NOTE(review): the header labels say "AvTheta"/"AvCount" although
//  the current theta and count are printed, not averages; the labels
//  are left as they are in case anything reads this output.
//
std::ostream& operator<< (std::ostream& os, const pcfg_type& g) {
  const bool show_counts = (debug >= 10);
  os << "#AvTheta";
  if (show_counts)
    os << "\tAvCount";
  os << "\tRule" << std::endl;
  cforeach (Rs, it, g.rules)
    if (it->theta > 0.0) {
      os << it->theta << '\t';
      if (show_counts)
        os << it->count << '\t';
      os << *it << std::endl;
    }
  return os;
}

namespace EXT_NAMESPACE {
  //! Hash for trie iterators: the hash value is the address of the
  //! node the iterator refers to.
  //
  template <> struct hash<Stit> {
    size_t operator()(const Stit t) const
    {
      return reinterpret_cast<size_t>(&*t);
    }  // ext::hash<Stit>::operator()
  };  // ext::hash<Stit>{}
}  // namespace EXT_NAMESPACE


//! cky_type{} is a chart parser over the grammar g.  inside() fills
//! the chart for a sentence; random_tree() then samples a parse from
//! that chart.  Rule probabilities are raised to the power anneal
//! wherever they enter the chart or the sampler.
//
//! The chart has one cell per span (left,right) with left < right.
//! inactives[cell] maps a category to the inside probability of that
//! category over the span; actives[cell] maps a node of g's
//! rhs_parent_rulep trie (a prefix of some rule's right-hand side) to
//! the probability of that prefix over the span.
//
struct cky_type {

  pcfg_type& g;
  F anneal;         // annealing factor (1 = no annealing)
  
  cky_type(pcfg_type& g, F anneal=1) : g(g), anneal(anneal) { }

  //! index() returns the location of cell in cells[]
  //! (requires i < j)
  //
  static U index(U i, U j) { return j*(j-1)/2+i; }

  //! ncells() returns the number of cells required for sentence of length n
  //
  static U ncells(U n) { return n*(n+1)/2; }

  typedef std::vector<S_F> S_Fs;
  typedef ext::hash_map<Stit,F> Stit_F;
  typedef std::vector<Stit_F> Stit_Fs;
  
  Ss terminals;      // the sentence most recently given to inside()
  S_Fs inactives;    // per-cell category probabilities
  Stit_Fs actives;   // per-cell right-hand-side-prefix probabilities

  //! inside() constructs the inside table, and returns the probability
  //! of the start symbol rewriting to the terminals.  An empty
  //! sentence has no chart cells at all, so it yields probability 0
  //! instead of looking up a cell that does not exist.
  //
  template <typename terminals_type>
  F inside(const terminals_type& terminals0) {

    terminals = terminals0;

    if (debug >= 100000)
      std::cerr << "# cky::inside() terminals = " << terminals << std::endl;

    U n = terminals.size();

    inactives.clear();
    inactives.resize(ncells(n));
    actives.clear();
    actives.resize(ncells(n));

    if (n == 0)
      return 0;   // nothing to parse; inactives[index(0,0)] would be out of range

    // width-1 spans: each terminal, closed under unary rules
    for (U i = 0; i < n; ++i) {
      inactives[index(i,i+1)][terminals[i]] = 1;
      inside_unaryclose(inactives[index(i,i+1)], actives[index(i,i+1)]);
      
      if (debug >= 200000)
        std::cerr << "# cky::inside() inactives[" << i << "," << i+1 << "] = " 
                  << inactives[index(i,i+1)] << std::endl;
      if (debug >= 201000)
        std::cerr << "# cky::inside() actives[" << i << "," << i+1 << "] = " 
                  << actives[index(i,i+1)] << std::endl;
    }

    // wider spans, narrowest first: extend a prefix over (left,mid)
    // with a category over (mid,right)
    for (U gap = 2; gap <= n; ++gap)
      for (U left = 0; left + gap <= n; ++left) {
        U right = left + gap;
        S_F& parentinactives = inactives[index(left,right)];
        Stit_F& parentactives = actives[index(left,right)];
        for (U mid = left+1; mid < right; ++mid) {
          Stit_F& leftactives = actives[index(left,mid)];
          const S_F& rightinactives = inactives[index(mid,right)];
          cforeach (Stit_F, itleft, leftactives) {
            const Stit leftactive = itleft->first;
            const F leftprob = itleft->second;
            cforeach (S_F, itright, rightinactives) {
              S rightinactive = itright->first;
              const F rightprob = itright->second;
              F leftrightprob = leftprob * rightprob;
              if (leftrightprob > 0) {
                const Stit parentactive = leftactive->find1(rightinactive);
                if (parentactive != leftactive->end()) {
                  // rules whose right-hand side ends here complete a category
                  cforeach (S_Rp, itparent, parentactive->data) {
                    S parent = itparent->first;
                    F leftrightruleprob = leftrightprob 
                      * power(itparent->second->prob(), anneal);
                    if (leftrightruleprob > 0)
                      parentinactives[parent] += leftrightruleprob;
                  }
                  // longer right-hand sides continue from this prefix
                  if (!parentactive->key_trie.empty())
                    parentactives[parentactive] += leftrightprob;
                }
              }
            }
          }
        }
        inside_unaryclose(parentinactives, parentactives);
        if (debug >= 200000)
          std::cerr << "# cky::inside() inactives[" << left << "," << right 
                    << "] = " << parentinactives << std::endl;
        if (debug >= 201000)
          std::cerr << "# cky::inside() actives[" << left << "," << right << "] = " 
                    << parentactives << std::endl;
      }

    return dfind(inactives[index(0,n)], g.start());
  }  // cky_type::inside()

  //! inside_unaryclose() closes the categories of one cell under the
  //! grammar's unary rules, repeating until the largest relative
  //! increase in a pass is no more than g.closetolerance, and then
  //! seeds the cell's actives with every category that can begin a
  //! longer right-hand side.
  //
  void inside_unaryclose(S_F& inactives, Stit_F& actives) {
    F delta = 1;
    S_F delta_prob1 = inactives;   // probability added in the latest pass
    S_F delta_prob0;               // probability added in the pass before
    while (delta > g.closetolerance) {
      delta = 0;
      delta_prob0.swap(delta_prob1);
      delta_prob1.clear();
      cforeach (S_F, it0, delta_prob0) {
        S child = it0->first;
        S_S_Rp::const_iterator it = g.child_parent_urulep.find(child);
        if (it != g.child_parent_urulep.end()) {
          const S_Rp& parent_urulep = it->second;
          cforeach (S_Rp, it1, parent_urulep) {
            S parent = it1->first;
            F prob = it0->second * power(it1->second->prob(), anneal);
            if (prob > 0) {
              delta_prob1[parent] += prob;
              delta = std::max(delta, prob/(inactives[parent] += prob));
            }
          }
        }
      }
    }
    cforeach (S_F, it0, inactives) {
      Stit it1 = g.rhs_parent_rulep.find1(it0->first);
      if (it1 != g.rhs_parent_rulep.end())
        actives[it1] += it0->second;
    }
  } // cky_type::inside_unaryclose()

 
  //! random_tree() returns a random parse tree for terminals.
  //! inside() must have been called first on a non-empty sentence
  //! that the grammar can parse.
  //
  tree* random_tree() {
    U n = terminals.size();
    assert(n > 0);    // there is no chart cell for an empty sentence
    return random_inactive(g.start(), afind(inactives[index(0, n)], g.start()), 0, n);
  }  // cky_type::random_tree

  //! random_inactive() returns a random expansion for an inactive edge:
  //! a newly allocated tree node labelled parent over (left,right),
  //! with next as its right sibling.  parentprob is the chart
  //! probability of that edge.  If no expansion reaches the sampled
  //! threshold an error is printed and the node is returned childless.
  //
  tree* random_inactive(const S parent, const F parentprob, const U left, const U right,
                        tree* next = NULL) const {

    tree* tp = new tree(parent, NULL, next);

    if (left+1 == right && parent == terminals[left])
      return tp;

    const S_F& parentinactives = inactives[index(left, right)];
    F probthreshold = random1() * parentprob;
    F probsofar = 0;

    // try unary rules

    cforeach (S_F, it0, parentinactives) {
      S child = it0->first;
      F childprob = it0->second;
      S_S_Rp::const_iterator it1 = g.child_parent_urulep.find(child);
      if (it1 != g.child_parent_urulep.end()) {
        const S_Rp& parent1_urulep = it1->second;
        S_Rp::const_iterator it2 = parent1_urulep.find(parent);
        if (it2 != parent1_urulep.end()) {
          probsofar += childprob * power(it2->second->prob(), anneal);
          if (probsofar >= probthreshold) {
            tp->child = random_inactive(child, childprob, left, right);
            return tp;
          }
        }
      }
    }

    // try binary rules

    for (U mid = left+1; mid < right; ++mid) {
      const Stit_F& leftactives = actives[index(left,mid)];
      const S_F& rightinactives = inactives[index(mid,right)];
      cforeach (Stit_F, itleft, leftactives) {
        const Stit leftactive = itleft->first;
        const F leftprob = itleft->second;
        cforeach (S_F, itright, rightinactives) {
          S rightinactive = itright->first;
          const F rightprob = itright->second;
          const Stit parentactive = leftactive->find1(rightinactive);
          if (parentactive != leftactive->end()) {
            S_Rp::const_iterator it = parentactive->data.find(parent);
            if (it != parentactive->data.end()) {
              probsofar += leftprob * rightprob 
                * power(it->second->prob(), anneal);
              if (probsofar >= probthreshold) {
                // the rightmost child is built first and handed on as the
                // sibling that ends the chain of children
                tp->child = random_active(leftactive, leftprob, left, mid,
                                          random_inactive(rightinactive, rightprob, mid, right));
                return tp;
              }
            }
          }
        }
      }
    }

    std::cerr << "## Error in cky_type::random_inactive(), parent = " << parent
              << ", left = " << left << ", right = " << right 
              << ", probsofar = " << probsofar << ", probthreshold = " << probthreshold 
              << std::endl;
    return tp;
  }  // cky_type::random_inactive()

  //! random_active() returns a random expansion for an active edge:
  //! the chain of sibling trees covering (left,right) that spells out
  //! the right-hand-side prefix parent, ending in next.  parentprob is
  //! the chart probability of that edge.  If no expansion reaches the
  //! sampled threshold an error is printed and NULL is returned.
  //
  tree* random_active(const Stit parent, F parentprob, const U left, const U right, 
                      tree* next = NULL) const {
    F probthreshold = random1() * parentprob;
    F probsofar = 0;

    // unary rule
    
    const S_F& parentinactives = inactives[index(left, right)];
    cforeach (S_F, it, parentinactives)
      if (g.rhs_parent_rulep.find1(it->first) == parent) {
        probsofar += it->second;
        if (probsofar >= probthreshold)
          return random_inactive(it->first, it->second, left, right, next);
        break;  // only one unary child can possibly generate this parent
      }

    // binary rules

    for (U mid = left + 1; mid < right; ++mid) {
      const Stit_F& leftactives = actives[index(left,mid)];
      const S_F& rightinactives = inactives[index(mid,right)];
      cforeach (Stit_F, itleft, leftactives) {
        const Stit leftactive = itleft->first;
        const F leftprob = itleft->second;
        cforeach (S_F, itright, rightinactives) {
          S rightinactive = itright->first;
          const F rightprob = itright->second;
          if (parent == leftactive->find1(rightinactive)) {
            probsofar += leftprob * rightprob;
            if (probsofar >= probthreshold) {
              return random_active(leftactive, leftprob, left, mid,
                                   random_inactive(rightinactive, rightprob, mid, right, next));
            }
          }
        }
      }
    }

    std::cerr << "## Error in cky_type::random_active(), parent = " << parent
              << ", left = " << left << ", right = " << right 
              << ", probsofar = " << probsofar << ", probthreshold = " << probthreshold 
              << std::endl;
    return NULL;
  }  // cky_type::random_active()

}; // cky_type{}


typedef std::vector<Ss> Sss;          // list of symbol sequences (the training sentences)
typedef std::vector<tree*> tps_type;  // list of tree pointers (one parse per training sentence)

F gibbs_estimate(pcfg_type& g, const Sss& trains, tps_type& tps, 
		 Postreamps& evalcmds, U eval_every,
		 U niterations, 
		 F anneal_start, F anneal_stop, U anneal_its,
		 F z_temp, U z_its,
		 U theta_sample_rate,
		 bool random_order, 
		 U consistency_flag,
		 std::ostream* ruleprob_stream_ptr,
		 std::ostream* rulecount_stream_ptr,
		 std::ostream* analyses_stream_ptr,
		 std::ostream* trace_stream_ptr) {

  if (ruleprob_stream_ptr) {
    *ruleprob_stream_ptr << "Iteration";
    cforeach (Rs, it, g.rules)
      it->write(*ruleprob_stream_ptr << ',');
    *ruleprob_stream_ptr << std::endl;
  }

  if (rulecount_stream_ptr) {
    *rulecount_stream_ptr << "Iteration";
    cforeach (Rs, it, g.rules)
      it->write(*rulecount_stream_ptr << ',');
    *rulecount_stream_ptr << std::endl;
  }

  U n = trains.size();
  assert(tps.size() == n);
  if (theta_sample_rate > n)
    theta_sample_rate = n;      // no point in sampling less than once per iteration

  U nwords = 0;
  cky_type p(g, anneal_start);
  F sum_log2prob = 0;

  // initialize tps with trees; don't learn

  for (unsigned i = 0; i < n; ++i) {
    if (debug >= 1000)
      std::cerr << "# trains[" << i << "] = " << trains[i];

    nwords += trains[i].size();

    if (!tps[i]) {
      F tprob = p.inside(trains[i]);

      if (debug >= 1000)
	std::cerr << ", tprob = " << tprob;
      if (tprob <= 0) {
	std::cerr << "## Error in gibbs_estimate(): parse failure, tprob = " << tprob
		  << ", trains[" << i << "] = " << trains[i] << std::endl;
	std::abort();
      }
      assert(tprob > 0);
      sum_log2prob += log2(tprob);
      tps[i] = p.random_tree();
    }

    g.increment(tps[i], 1);
    
    if (debug >= 1000)
      std::cerr << ", tps[" << i << "] = " << tps[i] << std::endl;
  }

  // collect statistics from the random trees
  typedef std::vector<U> Us;
  Us index(n);
  U unchanged = 0;
  
  for (unsigned i = 0; i < n; ++i) 
    index[i] = i;

  for (U iteration = 0; iteration < niterations; ++iteration) {

    if (random_order)
      std::random_shuffle(index.begin(), index.end());

    if (iteration < anneal_its) 
      p.anneal = anneal_start*power(anneal_stop/anneal_start,F(iteration)/F(anneal_its-1));
    else if (iteration + z_its > niterations) 
      p.anneal = 1.0/z_temp;
    else
      p.anneal = anneal_stop;

    assert(finite(p.anneal));

    F log2prob_corpus = g.log2prob_corpus();
    F loglikelihood = g.loglikelihood();

    if (debug >= 100) {
      std::cerr << "# Iteration " << iteration << ", " 
		<< -log2prob_corpus/(nwords+1e-100) << " bits per word, " 
		<< "log likelihood = " << -loglikelihood << ", "
		<< unchanged << '/' << n << " parses did not change";
      if (p.anneal != 1)
	std::cerr << ", temperature = " << 1/p.anneal;
      std::cerr << '.' << std::endl;
    }

    if (iteration % eval_every == 0) {

      if (trace_stream_ptr) 
	*trace_stream_ptr << iteration << '\t'          // iteration
			  << 1.0/p.anneal << '\t'       // temperature
			  << -log2prob_corpus << '\t'   // - log2 P(corpus)
			  << -loglikelihood << '\t'     // - log likelihood
			  << unchanged << '\t'          // # unchanged parses 
			  << n-unchanged << '\t'        // # changed
			  << g.nsamples << '\t'         // # theta samples
			  << g.nrejects                 // # theta samples rejected
			  << std::endl;
      
      foreach (Postreamps, ecit, evalcmds) {
	pstream::ostream& ec = **ecit;
	for (U i = 0; i < n; ++i) 
	  ec << tps[i] << std::endl;
	ec << std::endl;
      }
    }

    sum_log2prob = 0;
    unchanged = 0;

    if (theta_sample_rate == 0) 
      g.sample_theta(n, consistency_flag);

    for (U i0 = 0; i0 < n; ++i0) {
      
      if (theta_sample_rate == 1 
	  || ( theta_sample_rate > 0 && ((iteration*n+i0) % theta_sample_rate) == 0))
	g.sample_theta(n, consistency_flag);
      
      U i = index[i0];
      if (debug >= 1000)
	std::cerr << "\n# trains[" << i << "] = " << trains[i];

      tree* tp0 = tps[i];

      F tprob = p.inside(trains[i]);       // parse string
      if (tprob <= 0) 
	std::cerr << "## Error in gibbs_estimate(): tprob = " << tprob
		  << ", iteration = " << iteration 
		  << ", trains[" << i << "] = " << trains[i] << std::endl
		  << "## g = " << g << std::endl;
      assert(tprob > 0);
      if (debug >= 1000)
	std::cerr << ", tprob = " << tprob;
      sum_log2prob += log2(tprob);
      
      tree* tp1 = p.random_tree();

      if (*tp0 == *tp1) {
	++unchanged;
	delete tp1;
	if (debug >= 1000)
	  std::cerr << ", tp0 == tp1" << std::flush;
      }
      else { 
	tps[i] = tp1;
	g.increment(tp0, -1);
	g.increment(tp1, 1);
	delete tp0;
      }

      if (debug >= 1000)
	std::cerr << ", tps[" << i << "] = " << tps[i] << std::endl;
    }

    if (ruleprob_stream_ptr && iteration >= anneal_its && iteration%eval_every == 0) {
      *ruleprob_stream_ptr << iteration;
      cforeach (Rs, it, g.rules)
	*ruleprob_stream_ptr << ',' << it->theta;
      *ruleprob_stream_ptr << std::endl;
    }

    if (rulecount_stream_ptr && iteration >= anneal_its && iteration%eval_every == 0) {
      *rulecount_stream_ptr << iteration;
      cforeach (Rs, it, g.rules)
	*rulecount_stream_ptr << ',' << it->count;
      *rulecount_stream_ptr << std::endl;
    }

    if (iteration >= anneal_its) 
      g.sum_theta_count();
  }
  
  F log2prob_corpus = g.log2prob_corpus();
  F loglikelihood = g.loglikelihood();

  if (debug >= 10) 
    std::cerr << "# After " << niterations << " iterations, " 
	      << -log2prob_corpus/(nwords+1e-100) << " bits per word, " 
	      << "log likelihood = " << -loglikelihood << ", "
	      << unchanged << '/' << n << " parses did not change, "
	      << g.nsamples << " theta samples, " 
	      << g.nrejects << " rejected theta samples."
	      << std::endl;

  if (analyses_stream_ptr)
    for (U i = 0; i < n; ++i)
      *analyses_stream_ptr << tps[i] << std::endl;

  if (niterations % eval_every == 0)  {

    if (trace_stream_ptr)
      *trace_stream_ptr << niterations << '\t'          // iteration
			<< 1.0/p.anneal << '\t'         // temperature
			<< -log2prob_corpus << '\t'     // - log2 P(corpus)
			<< -loglikelihood << '\t'       // - log likelihood
			<< unchanged << '\t'            // # unchanged parses 
			<< n-unchanged << '\t'          // # changed
			<< g.nsamples << '\t'           // # theta samples
			<< g.nrejects                   // # theta samples rejected
			<< std::endl;

    foreach (Postreamps, ecit, evalcmds) {
      pstream::ostream& ec = **ecit;
      for (U i = 0; i < n; ++i) 
	ec << tps[i] << std::endl;
      ec << std::endl;
    }
  }

  for (U i = 0; i < n; ++i) 
    delete tps[i];

  return log2prob_corpus;
}  // gibbs_estimate()



// main() parses the command-line options (see usage[] above), reads the
// grammar from the single positional argument, reads the training data
// (strings, or trees when -R is given) from standard input, runs
// gibbs_estimate(), and finally writes the averaged grammar if requested.
//
// Exits with EXIT_FAILURE on an unrecognised option, a wrong number of
// positional arguments, an unreadable grammar file, or -x 0.
//
int main(int argc, char** argv) {

  pcfg_type g;
  bool random_order = true;
  U niterations = 100;
  U theta_sample_rate = 0;
  F anneal_start = 1;          // stored as 1/temperature (see -T below)
  F anneal_stop = 1;           // stored as 1/temperature (see -t below)
  U anneal_its = 0;
  U consistency_flag = 0;
  bool tree_initialize = false;
  F z_temp = 1;
  U z_its = 0;
  unsigned long rand_init = 0; // 0 means "seed from the clock"
  std::ostream* grammar_stream_ptr = NULL;
  std::ostream* ruleprob_stream_ptr = NULL;
  std::ostream* rulecount_stream_ptr = NULL;
  std::ostream* analyses_stream_ptr = NULL;
  std::ostream* trace_stream_ptr = NULL;
  Postreamps evalcmds;
  U eval_every = 1;
  
  int chr;
  while ((chr = getopt(argc, argv, "A:C:F:G:IP:RS:T:U:X:Z:a:c:d:m:n:r:t:w:x:z:")) != -1)
    switch (chr) {
    case 'A':
      analyses_stream_ptr = new std::ofstream(optarg);
      break;
    case 'C':
      consistency_flag = strtoul(optarg, NULL, 10);
      break;
    case 'F':
      trace_stream_ptr = new std::ofstream(optarg);
      break;
    case 'G':
      grammar_stream_ptr = new std::ofstream(optarg);
      break;
    case 'I':
      random_order = false;
      break;
    case 'P':
      ruleprob_stream_ptr = new std::ofstream(optarg);
      break;
    case 'R':
      tree_initialize = true;
      break;
    case 'S':
      theta_sample_rate = strtoul(optarg, NULL, 10);
      break;
    case 'T':
      anneal_start = 1/atof(optarg);   // user supplies a temperature; keep its inverse
      break;
    case 'U':
      rulecount_stream_ptr = new std::ofstream(optarg);
      break;
    case 'X':
      evalcmds.push_back(new pstream::ostream(optarg));
      break;
    case 'Z':
      z_temp = atof(optarg);
      break;
    case 'a':
      g.default_alpha = atof(optarg);
      break;
    case 'c':
      g.closetolerance = atof(optarg);
      break;
    case 'd':
      debug = atoi(optarg);
      break;
    case 'm':
      anneal_its = atoi(optarg);
      break;
    case 'n':
      niterations = atoi(optarg);
      break;
    case 'r':
      rand_init = strtoul(optarg, NULL, 10);
      break;
    case 't':
      anneal_stop = 1/atof(optarg);    // user supplies a temperature; keep its inverse
      break;
    case 'w':
      g.default_theta = atof(optarg);
      break;
    case 'x':
      eval_every = atoi(optarg);
      break;
    case 'z':
      z_its = atoi(optarg);
      break;
    default:
      std::cerr << "# Error in " << argv[0] 
                << ": can't interpret argument -" << char(chr) << std::endl;
      std::cerr << usage << std::endl;
      exit(EXIT_FAILURE);
    }

  if (argc - optind != 1) {
    std::cerr << "# Error in " << argv[0] 
              << ":\n# expected a single grammar file as command-line argument,\n# found "
              << argc - optind << " command-line arguments instead.\n" 
              << usage << std::endl;
    // A usage error is not a program bug, so exit cleanly (as the bad-option
    // case above does) instead of aborting with a core dump.
    exit(EXIT_FAILURE);
  }

  // eval_every is used as a modulus inside gibbs_estimate(), so zero would
  // cause a division by zero there.
  if (eval_every == 0) {
    std::cerr << "# Error in " << argv[0] 
              << ": -x eval-every must be a positive integer" << std::endl;
    exit(EXIT_FAILURE);
  }

  {
    std::ifstream is(argv[optind]);
    if (!is) {
      // Without this check an unreadable file would silently leave g as
      // whatever a failed read produces.
      std::cerr << "# Error in " << argv[0] 
                << ": can't open grammar file " << argv[optind] << std::endl;
      exit(EXIT_FAILURE);
    }
    is >> g;
  }

  if (rand_init == 0)
    rand_init = time(NULL);

  mt_init_genrand(rand_init);
      
  // Trace header: echo the effective settings, then name the columns that
  // gibbs_estimate() writes.  T and t are both reported as temperatures,
  // i.e. the inverse of the stored anneal values.
  if (trace_stream_ptr)
    *trace_stream_ptr << "# I = " << random_order 
                      << ", n = " << niterations
                      << ", C = " << consistency_flag
                      << ", S = " << theta_sample_rate
                      << ", a = " << g.default_alpha
                      << ", w = " << g.default_theta
                      << ", c = " << g.closetolerance
                      << ", m = " << anneal_its
                      << ", Z = " << z_temp
                      << ", z = " << z_its
                      << ", T = " << 1.0/anneal_start
                      << ", t = " << 1.0/anneal_stop
                      << ", r = " << rand_init
                      << std::endl
                      << "#iteration temperature -logP -logL unchanged changed nsamples nrejects" 
                      << std::endl;

  assert(anneal_its < niterations);
  assert(consistency_flag <= 2);

  if (debug >= 1000)
    std::cerr << "# gibbs-pcfg\n# Initial grammar:\n" << g << std::endl;

  Sss trains;       // terminal strings of the training items
  tps_type trees;   // initial tree per training item (NULL when none given)
  
  if (tree_initialize) {
    // -R: stdin holds trees; keep each tree and extract its terminal yield.
    Ss terminals;
    tree* tp;
    while (std::cin >> tp) {
      trees.push_back(tp);
      terminals.clear();
      tp->terminals(terminals, true);
      trains.push_back(terminals);
    }
  }
  else { 
    // stdin holds one string per line; no initial trees.
    Ss terminals;
    while (readline_symbols(std::cin, terminals)) {
      trains.push_back(terminals);
      trees.push_back(NULL);
    }
  }
  
  if (debug >= 1000)
    std::cerr << "# trains.size() = " << trains.size() << std::endl;

  // NOTE(review): parser is not referenced again in main(); confirm whether
  // its construction is needed for side effects before removing it.
  cky_type parser(g);

  gibbs_estimate(g, trains, trees, evalcmds, eval_every, niterations, 
                 anneal_start, anneal_stop, anneal_its, z_temp, z_its,
                 theta_sample_rate, random_order, consistency_flag,
                 ruleprob_stream_ptr, rulecount_stream_ptr, 
                 analyses_stream_ptr, trace_stream_ptr);

  if (grammar_stream_ptr)
    g.write_averages(*grammar_stream_ptr);

  if (debug >= 20)
    g.write_averages(std::cout) << std::flush;

  // Release every output stream opened during option parsing (deleting a
  // null pointer is a no-op).
  // NOTE(review): the pstream::ostream objects held in evalcmds are not
  // deleted here -- confirm whether they need to be closed explicitly.
  delete ruleprob_stream_ptr;
  delete rulecount_stream_ptr;
  delete trace_stream_ptr;
  delete analyses_stream_ptr;
  delete grammar_stream_ptr;
}
