// hastings-pcfg.cc
//
// PCFG estimator that integrates out the rule probabilities

// usage is the help text that main() prints to std::cerr when it cannot
// interpret an option or when the number of positional arguments is wrong.
const char usage[] =
"hastings-pcfg -- Unsupervised PCFG estimator using collapsed Gibbs sampler\n"
"\n"
" (c) Mark Johnson, version of 5th April 2008\n"
"\n"
"Usage: hastings-pcfg [-d debug] [-F trace-file]\n"
"         [-A parsefile] [-G grammarfile] [-H] [-I]\n"
"         [-r rand-init] [-n niterations] [-w weight]\n"
"         [-T anneal-temp-start] [-t anneal-temp-stop]\n"
"         [-m anneal-its] [-Z z-temp] [-z z-its]\n"
"         [-X eval-cmd] [-x eval-every]\n"
"         grammar.lt < train.yld\n"
"\n"
" -d debug       -- debug level\n"
" -F trace-file  -- file to write trace output to\n"
" -r rand-init   -- initializer for random number generator (integer)\n"
" -n niterations -- number of iterations\n"
" -w weight      -- default rule weight\n"
" -A parsefile   -- print analyses of training data to parsefile at termination\n"
" -G grammarfile -- print out grammar to grammarfile at termination\n"
" -H             -- skip Hastings correction of tree probabilities\n"
" -I             -- parse sentences in order (default is random order)\n"
" -T             -- start at this annealing temperature\n"
" -t             -- stop with this annealing temperature\n"
" -m             -- anneal for this many iterations\n"
" -Z z-temp      -- set Z-temp\n"
" -z z-its       -- perform z-its at temperature Z-temp at end\n"
" -X eval-cmd    -- pipe sampled parses into eval-cmd\n"
" -x eval-every  -- pipe sampled parses every eval-every iterations\n"
"\n";
             
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <ext/hash_map>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <unistd.h>
#include <utility>
#include <vector>

#include "mt19937ar.h"
#include "pstream.h"
#include "sym.h"
#include "tree.h"
#include "trie.h"
#include "utility.h"

// debug is the global trace level, set from the -d option in main().
// Code throughout this file compares it against fixed thresholds
// (10, 100, 1000, 10000, ...) to decide how much to write to std::cerr.
int debug = 0;

// power(base, exponent) returns base raised to exponent.  There is one
// overload per floating-point type so that the computation is carried out
// in the precision of the arguments (F below may be any of the three).
inline float power(float base, float exponent) {
  return std::pow(base, exponent);
}

inline double power(double base, double exponent) {
  return std::pow(base, exponent);
}

inline long double power(long double base, long double exponent) {
  return std::pow(base, exponent);
}

// Short type names used throughout this file.
typedef unsigned U;     // counts and indices
typedef long double F;  // slower than double, but underflows less
typedef symbol S;       // a grammar symbol (terminal or nonterminal)
typedef std::vector<S> Ss;  // a sequence of symbols, e.g. a sentence or a rule's right-hand side

// S_F maps a symbol to a floating-point value (a weight or a probability).
typedef std::map<S,F> S_F;
// typedef ext::hash_map<S,F> S_F;

// Postreamp points to a pstream::ostream; main() creates one per -X
// eval-cmd and gibbs_estimate() writes the sampled parses to each of them.
typedef pstream::ostream* Postreamp;
typedef std::vector<Postreamp> Postreamps;

//! readline_symbols() replaces the contents of syms with the
//! whitespace-separated symbols on the next line of is.  If no line
//! can be read, syms is left empty.  The stream is returned so that
//! the call can be used directly as a loop condition.
//
std::istream& readline_symbols(std::istream& is, Ss& syms) {
  syms.clear();
  std::string line;
  if (!std::getline(is, line))
    return is;                       // nothing read: report the stream state
  std::istringstream tokens(line);
  for (std::string token; tokens >> token; )
    syms.push_back(token);
  return is;
}  // readline_symbols()


//! A default_value_type{} pairs a reference to the object that is to be
//! read from a stream with the value it should receive if the read
//! fails.  It is the argument type of the operator>>() below; callers
//! normally obtain one from default_value() rather than constructing
//! it themselves.
//
template <typename object_type, typename default_type>
struct default_value_type {
  object_type& object;               //!< destination of the read
  const default_type defaultvalue;   //!< value assigned when the read fails

  default_value_type(object_type& target, const default_type fallback)
    : object(target), defaultvalue(fallback)
  { }
};

//! default_value() builds the default_value_type{} helper that lets
//!
//!   is >> default_value(x, d)
//!
//! read x from is, storing d in x when the read fails.  The template
//! arguments are deduced from the call, so the helper type never has to
//! be spelled out.
//
template <typename object_type, typename default_type>
default_value_type<object_type,default_type>
default_value(object_type& object, const default_type defaultvalue=default_type()) {
  typedef default_value_type<object_type,default_type> reader_type;
  return reader_type(object, defaultvalue);
}

//! operator>>() for default_value_type{}: tries to read dv.object from is.
//! If that read fails, the stream's failbit is cleared (the other state
//! bits are left as they are) and dv.object is set to dv.defaultvalue.
//! A stream that is already in a failed state on entry is returned
//! untouched, and dv.object is not modified.
//
template <typename object_type, typename default_type>
std::istream& operator>> (std::istream& is,
                          default_value_type<object_type, default_type> dv) {
  if (!is)
    return is;
  if (is >> dv.object)
    return is;
  is.clear(is.rdstate() & ~std::ios::failbit);  // clear failbit
  dv.object = dv.defaultvalue;
  return is;
}

// inline F random1() { return rand()/(RAND_MAX+1.0); }
//! random1() returns the next pseudo-random number from the Mersenne
//! Twister generator (mt_genrand_res53()), converted to F.  The sampler
//! scales it by a total weight to pick a threshold, so it is presumably
//! uniform on [0,1) -- confirm against mt19937ar.h.
inline F random1() {
  const F sample = mt_genrand_res53();
  return sample;
}


//! A pcfg_type holds a weighted PCFG.  The current weight of every rule
//! (the weight read from the grammar file plus everything added or
//! removed since by increment()) is indexed by right-hand side so that
//! the parser (cky_type, below) can look rules up bottom-up.  The weights
//! read from the grammar file are additionally kept, unchanged, as the
//! "prior" weights that log2prob_corpus() and loglikelihood() need.
//
struct pcfg_type {

  //! default_weight is the weight given to rules that appear in the
  //! grammar file without an explicit weight.
  //
  pcfg_type(F default_weight=1) : default_weight(default_weight) { }

  typedef unsigned int U;

  //! S_S_F maps a symbol to a (symbol -> weight) table
  //
  typedef ext::hash_map<S,S_F> S_S_F;

  //! St_S_F is a trie keyed on symbol sequences whose nodes carry a
  //! (symbol -> weight) table; Stit is a position (node) in that trie.
  //
  typedef trie<S, S_F> St_S_F;
  typedef St_S_F::const_iterator Stit;

  //! start is the start symbol of the grammar
  //
  S start;

  //! rhs_parent_weight maps the right-hand sides of non-unary rules
  //! to rule parent and rule weight
  //
  St_S_F rhs_parent_weight;

  //! unarychild_parent_weight maps each unary child to a table
  //! of parent-weight pairs
  //
  S_S_F unarychild_parent_weight;

  //! parent_weight maps parents to the sum of their rule weights
  //
  S_F parent_weight;

  //! default_weight is the default weight for rules with no explicit
  //! weight
  //
  F default_weight;

  typedef std::pair<S,Ss> SSs;
  typedef ext::hash_map<SSs,F> SSs_F;

  //! rule_priorweight maps rule = [parent|rhs] to its prior weight
  //
  SSs_F rule_priorweight;

  //! parent_priorweight maps parents to their prior weight
  //
  S_F parent_priorweight;

  //! rule_weight() returns the current weight of rule parent --> rhs,
  //! or 0 if no rule with this right-hand side is stored.  rhs must be
  //! non-empty.
  //
  template <typename rhs_type>
  F rule_weight(S parent, const rhs_type& rhs) const {
    assert(!rhs.empty());
    if (rhs.size() == 1) {
      // unary rules are indexed by their single child
      S_S_F::const_iterator it = unarychild_parent_weight.find(rhs[0]);
      if (it == unarychild_parent_weight.end())
        return 0;
      else
        return dfind(it->second, parent);
    }
    else {  // rhs.size() > 1
      Stit it = rhs_parent_weight.find(rhs);
      if (it == rhs_parent_weight.end())
        return 0;
      else
        return dfind(it->data, parent);
    }
  }  // pcfg_type::rule_weight()

  //! rule_prob() returns the probability of rule parent --> rhs, i.e.,
  //! its weight divided by the total weight of all rules expanding
  //! parent.  The rule must have positive weight.
  //
  template <typename rhs_type>
  F rule_prob(S parent, const rhs_type& rhs) const {
    assert(!rhs.empty());
    F parentweight = afind(parent_weight, parent);
    F ruleweight = rule_weight(parent, rhs);
    assert(ruleweight > 0);
    return ruleweight/parentweight;
  }  // pcfg_type::rule_prob()

  //! tree_prob() returns the probability of the tree under the current
  //! model: the product of rule_prob() over all of its local trees.
  //! A node without children contributes probability 1.
  //
  F tree_prob(const tree* tp) const {
    F prob = 1;
    if (!tp->child)
      return prob;
    Ss children;
    for (const tree* child = tp->child; child != NULL; child = child->next) {
      children.push_back(child->label.cat);
      prob *= tree_prob(child);
    }
    prob *= rule_prob(tp->label.cat, children);
    if (prob <= 0)
      std::cerr << "## pcfg_type::tree_prob(" << tp << ") = " << prob << std::endl;
    return prob;
  }  // pcfg_type::tree_prob()

  //! increment() adds weight (which may be negative) to the rule
  //! parent --> rhs and to parent's total weight, removing table entries
  //! whose weight drops to exactly zero.  It returns the probability of
  //! the rule under the incrementally modified grammar: computed from
  //! the counts as they were before the update if weight > 0, and from
  //! the counts after the update otherwise.
  //
  template <typename rhs_type>
  F increment(S parent, const rhs_type& rhs, F weight = 1) {
    assert(!rhs.empty());
    F weight1;
    F parentweight = (parent_weight[parent] += weight);
    assert(parentweight >= 0);
    if (parentweight == 0)
      parent_weight.erase(parent);
    if (rhs.size() == 1) {
      S_F& parent1_weight = unarychild_parent_weight[rhs[0]];
      weight1 = (parent1_weight[parent] += weight);
      assert(weight1 >= 0);
      if (weight1 == 0) {
        parent1_weight.erase(parent);
        // parent1_weight must not be used after the enclosing entry is erased
        if (parent1_weight.empty())
          unarychild_parent_weight.erase(rhs[0]);
      }
    }
    else {  // non-unary rule
      S_F& parent1_weight = rhs_parent_weight[rhs];
      weight1 = (parent1_weight[parent] += weight);
      assert(weight1 >= 0);   // same invariant as for unary rules
      if (weight1 == 0) {
        parent1_weight.erase(parent);
        if (parent1_weight.empty())
          rhs_parent_weight.erase(rhs);
      }
    }
    // if weight > 0, then return rule probability under old counts,
    // otherwise return rule probability under new counts.
    return (weight > 0) ? (weight1-weight)/(parentweight-weight) : weight1/parentweight;
  }  // pcfg_type::increment()

  //! increment() adds weight to every local tree in tp (parent rule
  //! first, then each subtree from left to right), and returns the
  //! product of the probabilities returned by the rule-level
  //! increment() calls, i.e., the conditional probability of tp.
  //
  F increment(const tree* tp, F weight = 1) {
    F prob = 1;
    if (tp->child) {
      {
        Ss children;
        for (const tree* child = tp->child; child; child = child->next)
          children.push_back(child->label.cat);
        prob *= increment(tp->label.cat, children, weight);
      }
      for (const tree* child = tp->child; child; child = child->next)
        prob *= increment(child, weight);
    }
    return prob;
  }  // pcfg_type::increment()

  //! read() reads a grammar from an input stream (implements >> ).
  //! Each rule has the form
  //!
  //!   [weight] parent --> child1 child2 ...
  //!
  //! where a missing weight defaults to default_weight.  The parent of
  //! the first rule becomes the start symbol.  Every rule's weight is
  //! added both to the current weights and to the prior weights.
  //
  std::istream& read(std::istream& is) {
    start = symbol::undefined();
    F weight;
    std::string parent;
    while (is >> default_value(weight, default_weight) >> parent >> " -->") {
      if (start.is_undefined())
        start = parent;
      Ss rhs;
      readline_symbols(is, rhs);
      increment(parent, rhs, weight);
      rule_priorweight[SSs(parent,rhs)] += weight;
      parent_priorweight[parent] += weight;
    }
    return is;
  }  // pcfg_type::read()

  //! write() writes a grammar (implements << ).  The rules expanding
  //! the start symbol are written first, so that a grammar written by
  //! write() gets the same start symbol when it is read back by read().
  //
  std::ostream& write(std::ostream& os) const {
    assert(start.is_defined());
    write_rules(os, start);
    cforeach (S_F, it, parent_weight)
      if (it->first != start)
        write_rules(os, it->first);
    return os;
  }  // pcfg_type::write()

  //! write_rules() writes all rules expanding parent, with their
  //! current weights: first the rules stored in rhs_parent_weight,
  //! then the unary rules.
  //
  std::ostream& write_rules(std::ostream& os, S parent) const {
    rhs_parent_weight.for_each(write_rule(os, parent));
    cforeach (S_S_F, it0, unarychild_parent_weight) {
      S child = it0->first;
      // (named parent1_weight so as not to shadow the parent_weight member)
      const S_F& parent1_weight = it0->second;
      cforeach (S_F, it1, parent1_weight)
        if (it1->first == parent)
          os << it1->second << '\t' << parent
             << " --> " << child << std::endl;
    }
    return os;
  }  // pcfg_type::write_rules()

  //! write_rule is the function object write_rules() passes to
  //! rhs_parent_weight.for_each().  For each right-hand side it is
  //! called with, it writes the rule parent --> rhs if the table stored
  //! with that right-hand side has an entry for parent.
  //
  struct write_rule {
    std::ostream& os;
    S parent;

    write_rule(std::ostream& os, symbol parent) : os(os), parent(parent) { }

    template <typename Keys, typename Value>
    void operator() (const Keys& rhs, const Value& parentweights) {
      cforeach (typename Value, pwit, parentweights)
        if (pwit->first == parent) {
          os << pwit->second << '\t' << parent << " -->";
          cforeach (typename Keys, rhsit, rhs)
            os << ' ' << *rhsit;
          os << std::endl;
        }
    }  // pcfg_type::write_rule::operator()

  };  // pcfg_type::write_rule{}

  //! log2prob_corpus() returns, in bits,
  //!
  //!   sum_rules   ( lgamma(weight) - lgamma(priorweight) )
  //! - sum_parents ( lgamma(weight) - lgamma(priorweight) )
  //!
  //! where the sums run over the rules and parents read from the
  //! grammar file.  gibbs_estimate() reports the negation of this value
  //! as -log2 P(corpus).
  //
  F log2prob_corpus() const {
    F logprob = 0;
    cforeach (SSs_F, it, rule_priorweight) {
      const S parent = it->first.first;
      const Ss& rhs = it->first.second;
      F priorweight = it->second;
      F weight = rule_weight(parent, rhs);
      logprob += lgamma(weight);
      logprob -= lgamma(priorweight);
    }
    cforeach (S_F, it, parent_priorweight) {
      S parent = it->first;
      F priorweight = it->second;
      F weight = dfind(parent_weight, parent);
      logprob -= lgamma(weight);
      logprob += lgamma(priorweight);
    }
    return logprob / log(2.0);   // convert from nats to bits
  }  // pcfg_type::log2prob_corpus()

  //! loglikelihood() returns the probability of the parse trees
  //! (or equivalently, their counts) given theta estimated from
  //! them.  This is what EM optimizes.
  //!
  //! A rule's count is its current weight minus its prior weight; the
  //! result is sum_rules count*log(count) - sum_parents count*log(count)
  //! (natural logarithms), where rules with non-positive count are skipped.
  //
  F loglikelihood() const {
    S_F parent_count;
    F logprob = 0;
    cforeach (SSs_F, it, rule_priorweight) {
      const S parent = it->first.first;
      const Ss& rhs = it->first.second;
      F priorweight = it->second;
      F weight = rule_weight(parent, rhs);
      F count = weight-priorweight;
      if (count > 0) {
        logprob += count*log(count);
        parent_count[parent] += count;
      }
    }

    cforeach (S_F, it, parent_count)
      logprob -= it->second*log(it->second);
    return logprob;
  } // pcfg_type::loglikelihood()

};  // pcfg_type{}


//! operator>> (pcfg_type&) reads a grammar from is into g by calling
//! g.read(), which sets g.start to the parent of the first rule read.
//
std::istream& operator>> (std::istream& is, pcfg_type& g) {
  g.read(is);   // read() returns the stream it was given
  return is;
}  // operator>> (pcfg_type&)


//! operator<< (pcfg_type&) writes the grammar g to os by calling g.write().
//
std::ostream& operator<< (std::ostream& os, const pcfg_type& g) {
  g.write(os);  // write() returns the stream it was given
  return os;
}  // operator<< (pcfg_type&)

namespace EXT_NAMESPACE {
  //! Hash function for trie iterators, so that they can be used as
  //! hash_map keys (see cky_type::Stit_F).  The hash value is the
  //! address of the object the iterator dereferences to.
  //
  template <> struct hash<pcfg_type::Stit> {
    size_t operator()(const pcfg_type::Stit t) const
    {
      return reinterpret_cast<size_t>(&(*t));
    }  // ext::hash<pcfg_type::Stit>::operator()
  };  // ext::hash<pcfg_type::Stit>{}
}  // namespace EXT_NAMESPACE

// unaryclosetolerance is the convergence threshold of
// cky_type::inside_unaryclose(): the unary closure stops iterating once the
// largest relative increase it made to any inside probability in one pass
// is no greater than this value.
static const F unaryclosetolerance = 1e-7;

//! A cky_type is a chart parser for the grammar g.  inside() fills the
//! chart with (annealed) inside probabilities for one sentence, and
//! random_tree() then samples a parse tree from that chart.
//
class cky_type {

public:

  const pcfg_type& g;
  F anneal;         // annealing factor (1 = no annealing)

  cky_type(const pcfg_type& g, F anneal=1) : g(g), anneal(anneal) { }

  typedef pcfg_type::U U;
  typedef pcfg_type::S_S_F S_S_F;
  typedef pcfg_type::St_S_F St_S_F;
  typedef pcfg_type::Stit Stit;

  typedef std::vector<S_F> S_Fs;
  typedef ext::hash_map<Stit,F> Stit_F;
  typedef std::vector<Stit_F> Stit_Fs;

  //! index() returns the location of cell in cells[]; the cell (i,j)
  //! covers the words from position i up to (not including) position j,
  //! so 0 <= i < j.
  //
  static U index(U i, U j) { return j*(j-1)/2+i; }

  //! ncells() returns the number of cells required for sentence of length n
  //
  static U ncells(U n) { return n*(n+1)/2; }

  //! terminals is the sentence most recently passed to inside()
  //
  Ss terminals;

  //! inactives[index(i,j)] maps each category found over (i,j) to its
  //! inside probability
  //
  S_Fs inactives;

  //! actives[index(i,j)] maps positions in g.rhs_parent_weight (i.e.,
  //! partially matched right-hand sides) whose symbols span (i,j) to
  //! their inside probability
  //
  Stit_Fs actives;

  //! inside() constructs the inside table, and returns the probability
  //! of the start symbol rewriting to the terminals.  Rule probabilities
  //! are raised to the power anneal.  An empty sentence has no cells and
  //! therefore no parse; 0 is returned for it.
  //
  template <typename terminals_type>
  F inside(const terminals_type& terminals0) {

    terminals = terminals0;

    if (debug >= 10000)
      std::cerr << "# cky::inside() terminals = " << terminals << std::endl;

    U n = terminals.size();

    inactives.clear();
    inactives.resize(ncells(n));
    actives.clear();
    actives.resize(ncells(n));

    // With n == 0 both tables are empty, so looking up the root cell
    // below would index past the end of inactives.
    if (n == 0)
      return 0;

    // cells of width 1: the terminals themselves, closed under unary rules
    for (U i = 0; i < n; ++i) {
      inactives[index(i,i+1)][terminals[i]] = 1;
      inside_unaryclose(inactives[index(i,i+1)], actives[index(i,i+1)]);

      if (debug >= 20000)
        std::cerr << "# cky::inside() inactives[" << i << "," << i+1 << "] = "
                  << inactives[index(i,i+1)] << std::endl;
      if (debug >= 20100)
        std::cerr << "# cky::inside() actives[" << i << "," << i+1 << "] = "
                  << actives[index(i,i+1)] << std::endl;
    }

    // wider cells, narrowest first: extend each active edge over
    // (left,mid) with each inactive edge over (mid,right)
    for (U gap = 2; gap <= n; ++gap)
      for (U left = 0; left + gap <= n; ++left) {
        U right = left + gap;
        S_F& parentinactives = inactives[index(left,right)];
        Stit_F& parentactives = actives[index(left,right)];
        for (U mid = left+1; mid < right; ++mid) {
          Stit_F& leftactives = actives[index(left,mid)];
          const S_F& rightinactives = inactives[index(mid,right)];
          cforeach (Stit_F, itleft, leftactives) {
            const Stit leftactive = itleft->first;
            const F leftprob = itleft->second;
            cforeach (S_F, itright, rightinactives) {
              S rightinactive = itright->first;
              const F rightprob = itright->second;
              const Stit parentactive = leftactive->find1(rightinactive);
              if (parentactive != leftactive->end()) {
                F leftrightprob = leftprob * rightprob;
                // every rule whose right-hand side ends here yields an inactive edge
                cforeach (S_F, itparent, parentactive->data) {
                  S parent = itparent->first;
                  parentinactives[parent] += leftrightprob
                    * power(itparent->second/afind(g.parent_weight, parent), anneal);
                }
                // longer right-hand sides continue from here as an active edge
                if (!parentactive->key_trie.empty())
                  parentactives[parentactive] += leftrightprob;
              }
            }
          }
        }
        inside_unaryclose(parentinactives, parentactives);
        if (debug >= 20000)
          std::cerr << "# cky::inside() inactives[" << left << "," << right
                    << "] = " << parentinactives << std::endl;
        if (debug >= 20100)
          std::cerr << "# cky::inside() actives[" << left << "," << right << "] = "
                    << parentactives << std::endl;
      }

    return dfind(inactives[index(0,n)], g.start);
  }  // cky_type::inside()

  //! inside_unaryclose() closes the inactive edges of one cell under the
  //! unary rules of g, and then enters into actives every position of
  //! g.rhs_parent_weight that is reached by a single inactive edge of
  //! the cell.
  //!
  //! The closure is computed iteratively: each pass propagates only the
  //! probability mass added by the previous pass, and the iteration stops
  //! when the largest relative increase in a pass is no greater than
  //! unaryclosetolerance.
  //
  void inside_unaryclose(S_F& inactives, Stit_F& actives) {
    F delta = 1;
    S_F delta_prob1 = inactives;   // mass added in the latest pass
    S_F delta_prob0;               // mass added in the pass before that
    while (delta > unaryclosetolerance) {
      delta = 0;
      delta_prob0.swap(delta_prob1);
      delta_prob1.clear();
      cforeach (S_F, it0, delta_prob0) {
        S child = it0->first;
        S_S_F::const_iterator it = g.unarychild_parent_weight.find(child);
        if (it != g.unarychild_parent_weight.end()) {
          const S_F& parent_weight = it->second;
          cforeach (S_F, it1, parent_weight) {
            S parent = it1->first;
            F prob = it0->second * power(it1->second/afind(g.parent_weight, parent), anneal);
            delta_prob1[parent] += prob;
            delta = std::max(delta, prob/(inactives[parent] += prob));
          }
        }
      }
    }
    cforeach (S_F, it0, inactives) {
      Stit it1 = g.rhs_parent_weight.find1(it0->first);
      if (it1 != g.rhs_parent_weight.end())
        actives[it1] += it0->second;
    }
  } // cky_type::inside_unaryclose()


  //! random_tree() returns a random parse tree for terminals, sampled
  //! from the chart built by the most recent call to inside().  The
  //! start symbol must have an entry in the root cell.  The caller owns
  //! the returned tree.
  //
  tree* random_tree() {
    U n = terminals.size();
    return random_inactive(g.start, afind(inactives[index(0, n)], g.start), 0, n);
  }  // cky_type::random_tree

  //! random_inactive() returns a random expansion for an inactive edge,
  //! i.e., a tree with root label parent spanning (left,right).
  //! parentprob is the edge's inside probability, and next becomes the
  //! new node's right sibling.
  //!
  //! The expansion is chosen by drawing a threshold uniformly below
  //! parentprob and accumulating the weights of the possible expansions
  //! until the threshold is reached.  If the threshold is never reached
  //! an error is reported and a childless node is returned.
  //
  tree* random_inactive(const S parent, const F parentprob, const U left, const U right,
                        tree* next = NULL) const {

    tree* tp = new tree(parent, NULL, next);

    // a terminal: nothing to expand
    if (left+1 == right && parent == terminals[left])
      return tp;

    const S_F& parentinactives = inactives[index(left, right)];
    F weightthreshold = random1() * parentprob;
    F parentweight = afind(g.parent_weight, parent);
    F weightsofar = 0;

    // try unary rules

    cforeach (S_F, it0, parentinactives) {
      S child = it0->first;
      F childprob = it0->second;
      S_S_F::const_iterator it1 = g.unarychild_parent_weight.find(child);
      if (it1 != g.unarychild_parent_weight.end()) {
        const S_F& parent1_weight = it1->second;
        weightsofar += childprob
          * power(dfind(parent1_weight, parent)/parentweight, anneal);
        if (weightsofar >= weightthreshold) {
          tp->child = random_inactive(child, childprob, left, right);
          return tp;
        }
      }
    }

    // try binary rules

    for (U mid = left+1; mid < right; ++mid) {
      const Stit_F& leftactives = actives[index(left,mid)];
      const S_F& rightinactives = inactives[index(mid,right)];
      cforeach (Stit_F, itleft, leftactives) {
        const Stit leftactive = itleft->first;
        const F leftprob = itleft->second;
        cforeach (S_F, itright, rightinactives) {
          S rightinactive = itright->first;
          const F rightprob = itright->second;
          const Stit parentactive = leftactive->find1(rightinactive);
          if (parentactive != leftactive->end()) {
            S_F::const_iterator it = parentactive->data.find(parent);
            if (it != parentactive->data.end()) {
              weightsofar += leftprob * rightprob
                * power(it->second/parentweight, anneal);
              if (weightsofar >= weightthreshold) {
                // the rightmost child is built first and handed to
                // random_active() as the sibling of the children to its left
                tp->child = random_active(leftactive, leftprob, left, mid,
                                          random_inactive(rightinactive, rightprob, mid, right));
                return tp;
              }
            }
          }
        }
      }
    }

    std::cerr << "## Error in cky_type::random_inactive(), parent = " << parent
              << ", left = " << left << ", right = " << right
              << ", weightsofar = " << weightsofar << ", weightthreshold = " << weightthreshold
              << std::endl;
    return tp;
  }  // cky_type::random_inactive()

  //! random_active() returns a random expansion for an active edge: the
  //! sequence of sibling trees spanning (left,right) whose root labels
  //! lead to the position parent in g.rhs_parent_weight.  parentprob is
  //! the edge's inside probability, and next is attached as the right
  //! sibling of the last tree of the sequence.  Returns NULL, after
  //! reporting an error, if no expansion reaches the sampled threshold.
  //
  tree* random_active(const Stit parent, F parentprob, const U left, const U right,
                      tree* next = NULL) const {
    F probthreshold = random1() * parentprob;
    F probsofar = 0;

    // unary rule

    const S_F& parentinactives = inactives[index(left, right)];
    cforeach (S_F, it, parentinactives)
      if (g.rhs_parent_weight.find1(it->first) == parent) {
        probsofar += it->second;
        if (probsofar >= probthreshold)
          return random_inactive(it->first, it->second, left, right, next);
        break;  // only one unary child can possibly generate this parent
      }

    // binary rules

    for (U mid = left + 1; mid < right; ++mid) {
      const Stit_F& leftactives = actives[index(left,mid)];
      const S_F& rightinactives = inactives[index(mid,right)];
      cforeach (Stit_F, itleft, leftactives) {
        const Stit leftactive = itleft->first;
        const F leftprob = itleft->second;
        cforeach (S_F, itright, rightinactives) {
          S rightinactive = itright->first;
          const F rightprob = itright->second;
          if (parent == leftactive->find1(rightinactive)) {
            probsofar += leftprob * rightprob;
            if (probsofar >= probthreshold) {
              return random_active(leftactive, leftprob, left, mid,
                                   random_inactive(rightinactive, rightprob, mid, right, next));
            }
          }
        }
      }
    }

    std::cerr << "## Error in cky_type::random_active(), parent = " << parent
              << ", left = " << left << ", right = " << right
              << ", probsofar = " << probsofar << ", probthreshold = " << probthreshold
              << std::endl;
    return NULL;
  }  // cky_type::random_active()

}; // cky_type{}

struct S_F_incrementer {
  const F increment;
  S_F_incrementer(F increment) : increment(increment) { }

  template <typename arg_type>
  void operator() (const arg_type& arg, S_F& parent_weights) const
  {
    foreach (S_F, it, parent_weights)
      it->second += increment;
  }
};

// Sss is a sequence of symbol sequences; main() stores the training
// sentences (one per input line) in one.
typedef std::vector<Ss> Sss;

//! gibbs_estimate() resamples a parse tree for each training sentence
//! niterations times, updating the rule weights of g as it goes, and
//! returns g.log2prob_corpus() for the final set of trees.
//!
//!  g                   -- the grammar; its rule weights are updated in place
//!  trains              -- the training sentences
//!  evalcmds            -- streams to which all current parses are written
//!                         every eval_every iterations and at termination
//!  eval_every          -- period of trace and evalcmds output (0 disables
//!                         the periodic output; the final output is still
//!                         produced)
//!  niterations         -- number of passes through the training data
//!  anneal_start, anneal_stop, anneal_its
//!                      -- the annealing factor (inverse temperature) moves
//!                         geometrically from anneal_start to anneal_stop
//!                         during the first anneal_its iterations, and is
//!                         anneal_stop afterwards
//!  z_temp, z_its       -- the annealing factor is 1/z_temp for iterations
//!                         with iteration + z_its > niterations
//!  hastings_correction -- if true, a newly sampled tree is accepted or
//!                         rejected with a Metropolis-Hastings test;
//!                         otherwise it is always accepted
//!  random_order        -- if true, sentences are visited in a freshly
//!                         shuffled order in each iteration
//!  analyses_stream_ptr -- if non-NULL, the final parses are written here
//!  trace_stream_ptr    -- if non-NULL, one line of statistics is written
//!                         here every eval_every iterations and at termination
//
F gibbs_estimate(pcfg_type& g, const Sss& trains,
                 Postreamps& evalcmds, U eval_every,
                 U niterations = 100,
                 F anneal_start = 1, F anneal_stop = 1, U anneal_its = 0,
                 F z_temp = 1.0, U z_its = 0,
                 bool hastings_correction = true, bool random_order = true,
                 std::ostream* analyses_stream_ptr = NULL,
                 std::ostream* trace_stream_ptr = NULL) {

  U n = trains.size();
  U nwords = 0;
  typedef std::vector<tree*> tps_type;
  tps_type tps(n);
  cky_type p(g, anneal_start);

  // initialize tps with trees; don't learn

  for (unsigned i = 0; i < n; ++i) {
    if (debug >= 1000)
      std::cerr << "# trains[" << i << "] = " << trains[i];

    nwords += trains[i].size();
    F tprob = p.inside(trains[i]);

    if (debug >= 1000)
      std::cerr << ", tprob = " << tprob;
    if (tprob <= 0)
      std::cerr << "## Error in gibbs_estimate(): tprob = " << tprob
                << ", trains[" << i << "] = " << trains[i] << std::endl;

    assert(tprob > 0);
    tps[i] = p.random_tree();

    if (debug >= 1000)
      std::cerr << ", tps[" << i << "] = " << tps[i] << std::endl;
  }

  // collect statistics from the random trees
  typedef std::vector<U> Us;
  Us index(n);
  U unchanged = 0, rejected = 0;

  for (unsigned i = 0; i < n; ++i) {
    g.increment(tps[i]);
    index[i] = i;
  }

  for (U iteration = 0; iteration < niterations; ++iteration) {

    if (random_order)
      std::random_shuffle(index.begin(), index.end());

    if (iteration < anneal_its) {
      // Geometric interpolation from anneal_start (iteration 0) to
      // anneal_stop (iteration anneal_its-1).  With a single annealing
      // iteration the exponent would be 0/0, so use anneal_start directly.
      if (anneal_its > 1)
        p.anneal = anneal_start*power(anneal_stop/anneal_start,F(iteration)/F(anneal_its-1));
      else
        p.anneal = anneal_start;
    }
    else if (iteration + z_its > niterations)
      p.anneal = 1.0/z_temp;
    else
      p.anneal = anneal_stop;

    assert(finite(p.anneal));

    F log2prob_corpus = g.log2prob_corpus();
    F loglikelihood = g.loglikelihood();

    if (debug >= 100) {
      std::cerr << "# Iteration " << iteration << ", "
                << -log2prob_corpus/nwords << " bits per word, "
                << "log likelihood = " << -loglikelihood << ", "
                << unchanged << '/' << n << " parses did not change";
      if (hastings_correction)
        std::cerr << ", " << rejected << '/' << n-unchanged
                  << " parses rejected";
      if (p.anneal != 1)
        std::cerr << ", temperature = " << 1/p.anneal;
      std::cerr << '.' << std::endl;
    }

    // eval_every == 0 must not reach the modulo below (division by zero)
    const bool eval_now = eval_every > 0 && iteration % eval_every == 0;

    if (trace_stream_ptr && eval_now)
      *trace_stream_ptr << iteration << '\t'            // iteration
                        << 1.0/p.anneal << '\t'         // temperature
                        << -log2prob_corpus << '\t'     // - log2 P(corpus)
                        << -loglikelihood << '\t'       // - log likelihood
                        << unchanged << '\t'            // # unchanged parses
                        << n-unchanged << '\t'          // # changed parses
                        << rejected << std::endl;       // # parses rejected

    if (eval_now)
      foreach (Postreamps, ecit, evalcmds) {
        pstream::ostream& ec = **ecit;
        for (U i = 0; i < n; ++i)
          ec << tps[i] << std::endl;
        ec << std::endl;
      }

    if (debug >= 1000)
      std::cerr << g;

    unchanged = 0;
    rejected = 0;

    for (U i0 = 0; i0 < n; ++i0) {
      U i = index[i0];
      if (debug >= 1000)
        std::cerr << "# trains[" << i << "] = " << trains[i];

      // Remove the current tree's counts from g.  pi0 is its (annealed)
      // conditional probability given all the other trees, and r0 its
      // (annealed) probability under the grammar that the new tree is
      // about to be sampled from.
      tree* tp0 = tps[i];
      F pi0 = power(g.increment(tp0, -1), p.anneal);
      F r0 = power(g.tree_prob(tp0), p.anneal);

      F tprob = p.inside(trains[i]);       // parse string
      if (tprob <= 0)
        std::cerr << "## Error in gibbs_estimate(): tprob = " << tprob
                  << ", iteration = " << iteration
                  << ", trains[" << i << "] = " << trains[i] << std::endl
                  << "## g = " << g << std::endl;
      assert(tprob > 0);
      if (debug >= 1000)
        std::cerr << ", tprob = " << tprob;

      // Sample a new tree and add its counts to g; r1 and pi1 are the
      // counterparts of r0 and pi0 for the new tree.
      tree* tp1 = p.random_tree();
      F r1 = power(g.tree_prob(tp1), p.anneal);
      F pi1 = power(g.increment(tp1, 1), p.anneal);

      if (debug >= 1000)
        std::cerr << ", r0 = " << r0 << ", pi0 = " << pi0
                  << ", r1 = " << r1 << ", pi1 = " << pi1 << std::flush;

      if (*tp0 == *tp1) {
        // same tree resampled: the counts just added are those removed above
        ++unchanged;
        delete tp1;
        if (debug >= 1000)
          std::cerr << ", tp0 == tp1" << std::flush;
      }
      else if (hastings_correction) {
        F accept = (pi1 * r0) / (pi0 * r1);
        if (debug >= 1000)
          std::cerr << ", accept = " << accept << std::flush;
        if (random1() <= accept) {
          if (debug >= 1000)
            std::cerr << ", accepted" << std::flush;
          tps[i] = tp1;
          delete tp0;
        }
        else {
          if (debug >= 1000)
            std::cerr << ", rejected" << std::flush;
          // restore the counts of the old tree
          g.increment(tp1, -1);
          g.increment(tp0, 1);
          delete tp1;
          ++rejected;
        }
      }
      else {  // no hastings correction
        tps[i] = tp1;
        delete tp0;
      }

      if (debug >= 1000)
        std::cerr << ", tps[" << i << "] = " << tps[i] << std::endl;
    }
  }

  F log2prob_corpus = g.log2prob_corpus();
  F loglikelihood = g.loglikelihood();

  if (debug >= 10) {
    std::cerr << "# After " << niterations << " iterations, "
              << -log2prob_corpus/nwords << " bits per word, "
              << "log likelihood = " << -loglikelihood << ", "
              << unchanged << '/' << n << " parses did not change";
    if (hastings_correction)
      std::cerr << ", " << rejected << '/' << n-unchanged
                << " parses rejected";
    std::cerr << '.' << std::endl;
  }

  if (trace_stream_ptr)
    *trace_stream_ptr << niterations << '\t'            // iteration
                      << 1.0/p.anneal << '\t'         // temperature
                      << -log2prob_corpus << '\t'     // - log2 P(corpus)
                      << -loglikelihood << '\t'       // - log likelihood
                      << unchanged << '\t'            // # unchanged parses
                      << n-unchanged << '\t'          // # changed parses
                      << rejected << std::endl;       // # parses rejected

  if (analyses_stream_ptr)
    for (U i = 0; i < n; ++i)
      *analyses_stream_ptr << tps[i] << std::endl;

  foreach (Postreamps, ecit, evalcmds) {
    pstream::ostream& ec = **ecit;
    for (U i = 0; i < n; ++i)
      ec << tps[i] << std::endl;
    ec << std::endl;
  }

  for (U i = 0; i < n; ++i)
    delete tps[i];

  return log2prob_corpus;
}  // gibbs_estimate()



//! main() parses the command line (see usage), reads the grammar from the
//! file named by the single positional argument and the training
//! sentences from standard input, runs gibbs_estimate(), and finally
//! writes the grammar to the -G file if one was given.
//
int main(int argc, char** argv) {

  pcfg_type g;
  bool hastings_correction = true;
  bool random_order = true;
  U niterations = 100;
  F anneal_start = 1;     // stored as an inverse temperature
  F anneal_stop = 1;      // stored as an inverse temperature
  U anneal_its = 100;
  F z_temp = 1;
  U z_its = 0;
  unsigned long rand_init = 0;
  std::ostream* grammar_stream_ptr = NULL;
  std::ostream* analyses_stream_ptr = NULL;
  std::ostream* trace_stream_ptr = NULL;
  Postreamps evalcmds;
  U eval_every = 1;

  int chr;
  while ((chr = getopt(argc, argv, "A:F:G:HIT:X:Z:a:d:m:n:r:t:w:x:z:")) != -1)
    switch (chr) {
    case 'A':
      analyses_stream_ptr = new std::ofstream(optarg);
      break;
    case 'F':
      trace_stream_ptr = new std::ofstream(optarg);
      break;
    case 'G':
      grammar_stream_ptr = new std::ofstream(optarg);
      break;
    case 'H':
      hastings_correction = false;
      break;
    case 'I':
      random_order = false;
      break;
    case 'T':
      anneal_start = 1/atof(optarg);
      break;
    case 'X':
      evalcmds.push_back(new pstream::ostream(optarg));
      break;
    case 'Z':
      z_temp = atof(optarg);
      break;
    case 'd':
      debug = atoi(optarg);
      break;
    case 'm':
      anneal_its = atoi(optarg);
      break;
    case 'n':
      niterations = atoi(optarg);
      break;
    case 'r':
      rand_init = strtoul(optarg, NULL, 10);
      break;
    case 't':
      anneal_stop = 1/atof(optarg);
      break;
    case 'w':
      g.default_weight = atof(optarg);
      break;
    case 'x':
      eval_every = atoi(optarg);
      break;
    case 'z':
      z_its = atoi(optarg);
      break;
    default:
      std::cerr << "# Error in " << argv[0]
                << ": can't interpret argument -" << char(chr) << std::endl;
      std::cerr << usage << std::endl;
      exit(EXIT_FAILURE);
    }

  // Exactly one positional argument (the grammar file) is required.
  // Without it argv[optind] may be NULL, so do not go on.
  if (argc - optind != 1) {
    std::cerr << "# Error in " << argv[0] << '\n' << usage << std::endl;
    exit(EXIT_FAILURE);
  }

  {
    std::ifstream is(argv[optind]);
    if (!is) {
      std::cerr << "# Error in " << argv[0] << ": can't open grammar file "
                << argv[optind] << std::endl;
      exit(EXIT_FAILURE);
    }
    is >> g;
  }

  // a seed of 0 (the default) means "seed from the clock"
  if (rand_init == 0)
    rand_init = time(NULL);

  mt_init_genrand(rand_init);

  // T and t are both reported as temperatures, as given on the command line
  if (trace_stream_ptr)
    *trace_stream_ptr << "# I = " << random_order
                      << ", n = " << niterations
                      << ", w = " << g.default_weight
                      << ", m = " << anneal_its
                      << ", Z = " << z_temp
                      << ", z = " << z_its
                      << ", T = " << 1.0/anneal_start
                      << ", t = " << 1.0/anneal_stop
                      << ", r = " << rand_init
                      << std::endl
                      << "# iteration temperature -logP -logL unchanged changed rejected"
                      << std::endl;

  if (debug >= 1000)
    std::cerr << "# gibbs-pcfg Initial grammar = " << g << std::endl;

  Sss trains;

  {
    Ss terminals;

    while (readline_symbols(std::cin, terminals))
      trains.push_back(terminals);

    if (debug >= 1000)
      std::cerr << "# trains.size() = " << trains.size() << std::endl;
  }

  gibbs_estimate(g, trains, evalcmds, eval_every, niterations,
                 anneal_start, anneal_stop, anneal_its, z_temp, z_its,
                 hastings_correction, random_order, analyses_stream_ptr,
                 trace_stream_ptr);

  if (grammar_stream_ptr)
    *grammar_stream_ptr << g << std::flush;

  // release the streams created while parsing the options
  for (Postreamps::iterator it = evalcmds.begin(); it != evalcmds.end(); ++it)
    delete *it;
  delete trace_stream_ptr;
  delete analyses_stream_ptr;
  delete grammar_stream_ptr;
}

