// munge-trees.cc -- transformations on Penn Treebank style trees
//
// (c) Mark Johnson, 22nd December 2012

// usage is printed to stderr by main() for -h or an unrecognized option.
const char* usage =
"munge-trees [options] < trees > output\n"
"\n"
" (c) Mark Johnson, version of 22nd December, 2012\n"
"\n"
"This program can be freely used for any research or educational purposes.\n"
"However, I do request acknowledgement in any publications that mention data\n"
"produced using this program.\n"
"\n"
"This program munges trees.  It reads trees from stdin and passes\n"
"each one through a pipeline of operations.  These operations are specified\n"
"by the command-line options.  These are run on each tree in the order\n"
"they occur on the command line.\n"
"\n"
"Initialization options: (performed before any tree is read)\n"
"\n"
"  -h   -- print this message and exit\n"
"\n"
"  -A x -- set the parent annotation infix to x (default = \"^\")\n"
"\n"
"  -B x -- set the binarization infix to x (default = \"_\")\n"
"\n"
"  -T t -- set the top node label used in the -n option to t (default = \"TOP\")\n"
"\n"
"  -U u -- set the unknown word terminal symbol to u (default = \"*UNK*\")\n"
"\n"
"Tree processing options:\n"
"  (run after each tree is read, in the order specified on command line)\n"
"\n"
"  -a[C] -- add parent-annotations to nodes in the tree.  If C is absent\n"
"           then all nodes are parent annotated.  If C is present, it should\n"
"           be a list of categories separated by ':'; only these categories\n"
"           will be parent-annotated.  E.g., -aNP:S only annotates NP and S.\n"
"           Binarized nodes never count as parents; i.e., the parent node is\n"
"           the closest non-binarized node.\n"
"\n"
"  -b   -- left binarize the tree\n"
"\n"
"  -c   -- collect the counts of the local trees (i.e., the rules used)\n"
"          in this tree.  These will be printed to stdout when all the\n"
"          trees have been read.  Rules expanding the root node label are\n"
"          listed first.\n"
"\n"
"  -d   -- downcase characters in the terminal strings\n"
"\n"
"  -e   -- delete empty \"NONE\" subtrees in the tree\n"
"\n"
"  -f file -- file should contain a list of words.  The preterminal above\n"
"             these words will be annotated with that word.\n"
"\n"
"  -l n -- skip the remaining options for trees with more than n terminals\n"
"\n"
"  -n   -- normalize the tree by setting the root node label to the top\n"
"          label (see -T), trimming characters following a '-' or a '=' in\n"
"          a nonterminal label and appending the parent annotation infix\n"
"          (see -A, default '^') to punctuation POS tags\n"
"\n"
"  -p   -- pretty-print the tree to stdout (takes several lines)\n"
"\n"
"  -P   -- remove punctuation terminals and preterminals\n"
"\n"
"  -r   -- remove binarization nodes and parent annotation\n"
"\n"
"  -s   -- remove the terminal nodes (so the preterminals become the terminals)\n"
"\n"
"  -t c -- count occurrences of each terminal.  After all trees are read, then\n"
"          write to stdout all terminals that occurred at least c times, in order\n"
"          of frequency (high frequency words first).\n"
"\n"
"  -u file -- replace all terminals that don't appear in file (which should contain\n"
"             white-space separated words) with the unknown-word terminal symbol\n"
"\n"
"  -w   -- write the tree to stdout on one line\n"
"\n"
"  -y   -- write the yield (i.e., the words) of the tree to stdout on one line\n"
"\n"
"  -Y   -- write the preterminal yield (i.e., the POS tags) of the tree to stdout\n"
"             on one line\n"
"\n"
"  -z   -- write out yield (i.e., the words) followed by their preterminals\n"
"            (parts of speech) on one line\n"
"\n"
"Note: the -a, -f, -l, -t and -u options should occur at most once in the\n"
"option sequence.\n"
"\n";

#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <map>
#include <set>
#include <sstream>
#include <string>
#include <unistd.h>
#include <utility>
#include <vector>

#include "tree.h"

// Short type names used throughout this file.
typedef std::string S;                                       // a label or a word
typedef std::set<std::string> sS;                            // a set of labels/words
typedef std::map<std::string, unsigned> S_U;                 // label/word -> count
typedef std::vector<std::string> Ss;                         // a label sequence (e.g. a rule)
typedef std::map<std::vector<std::string>, unsigned> Ss_U;   // label sequence -> count

// A tree node with string labels, as provided by tree.h.
typedef tree::tree_type<std::string> T;

//! delete_empties() removes empty "-NONE-" subtrees from the sibling list
//! that starts at tp. It also removes any nonterminal whose children have
//! all been removed as a result. Returns the new head of the list, which
//! may be NULL. Terminal nodes are returned unchanged, and siblings that
//! follow a terminal are not examined.
//
T* delete_empties(T* tp) {
  if (tp == NULL)
    return NULL;
  if (tp->is_terminal())
    return tp;
  // Clean the children and the following siblings first, so a node whose
  // children have all disappeared can be detected below.
  tp->subtrees = delete_empties(tp->subtrees);
  tp->next = delete_empties(tp->next);
  const bool is_empty_category = (tp->label == "-NONE-");
  const bool has_no_children = (tp->subtrees == NULL);
  if (!is_empty_category && !has_no_children)
    return tp;
  // Unlink the surviving siblings before deleting tp, because they are
  // handed back to the caller. Presumably T's destructor also follows
  // `next` (see tree.h).
  T* const remaining = tp->next;
  tp->next = NULL;
  delete tp;
  return remaining;
}

//! delete_terminals() deletes the terminal nodes below every preterminal in
//! the sibling list starting at tp (recursively), so that the former
//! preterminals become the "new" terminals. Walking a sibling list stops at
//! the first terminal node encountered.
//
void delete_terminals(T* tp) {
  for (T* node = tp; node != NULL && !node->is_terminal(); node = node->next) {
    if (node->is_preterminal()) {
      // Drop the word(s) under this preterminal.
      delete node->subtrees;
      node->subtrees = NULL;
    }
    else
      delete_terminals(node->subtrees);
  }
}

//! is_punctuation() returns true if the label s is non-empty and starts
//! with a punctuation character. The empty-node tag "-NONE-" is explicitly
//! not treated as punctuation.
//
inline bool is_punctuation(const std::string& s) {
  if (s.empty() || s == "-NONE-")
    return false;
  // Convert to unsigned char first: passing a negative char (e.g. a byte of
  // a UTF-8 sequence) to ispunct() is undefined behavior.
  return std::ispunct(static_cast<unsigned char>(s[0])) != 0;
}

//! delete_punctuation() removes, from the sibling list starting at tp
//! (recursively), every non-terminal node whose label is_punctuation(),
//! together with everything below it. It also removes any node left with no
//! children as a result. Returns the new head of the list, which may be NULL.
//! Terminal nodes are returned unchanged, and siblings that follow a terminal
//! are not examined.
//
T* delete_punctuation(T* tp) {
  if (tp != NULL && !tp->is_terminal()) {
    tp->subtrees = delete_punctuation(tp->subtrees);
    tp->next = delete_punctuation(tp->next);
    if (is_punctuation(tp->label) || !tp->subtrees) {
      T* survivors = tp->next;
      tp->next = NULL;   // the survivors are returned, so detach them before delete
      delete tp;
      tp = survivors;
    }
  }
  return tp;
}

//! normalize_tree() normalizes a PTB tree in place and returns it.
//!
//! If top_labelp is non-NULL, tp's own label is replaced by *top_labelp,
//! whatever it was before. main() passes the top label here for the root.
//! Every other non-terminal node reached (children and following siblings)
//! is normalized as follows:
//!  - a preterminal whose label starts with a punctuation character gets
//!    parent_infix appended. The empty-node tag "-NONE-" is not punctuation,
//!    so it is left intact and a later -e can still find it.
//!  - any other label is cut at its first '-' or '=', unless that character
//!    is the first or the last character of the label.
//! Terminal nodes are never changed.
//
T* normalize_tree(T* tp, const std::string& parent_infix, const std::string* top_labelp=NULL) {
  if (tp == NULL || tp->is_terminal())
    return tp;
  // top_labelp applies to this node only, not to its children or siblings.
  tp->subtrees = normalize_tree(tp->subtrees, parent_infix);
  tp->next = normalize_tree(tp->next, parent_infix);
  if (top_labelp)
    tp->label = *top_labelp;
  else if (tp->is_preterminal() && !tp->label.empty() && tp->label != "-NONE-"
           // unsigned char cast: ispunct() on a negative char is undefined
           && std::ispunct(static_cast<unsigned char>(tp->label[0]))) {
    tp->label += parent_infix;  // mark punctuation POS tags
  }
  else {
    std::string::size_type pos = tp->label.find_first_of("=-");
    if (pos != std::string::npos && pos != 0 && pos+1 != tp->label.size())
      tp->label.resize(pos);  // trim the '-' or '=' and following chars
  }
  return tp;
}

//! binarize_tree() left-binarizes the sibling list starting at tp, and
//! recursively every list of children below it, so that no node has more
//! than two children. Siblings A B C ... are folded from the left into new
//! nodes labelled A<infix>B, A<infix>B<infix>C, ... (infix is
//! binarization_infix). Returns the new head of the list.
//
T* binarize_tree(T* tp, const std::string& binarization_infix) {
  if (tp == NULL)
    return NULL;
  tp->subtrees = binarize_tree(tp->subtrees, binarization_infix);
  T* left = tp;
  for (T* right = left->next; right != NULL; right = left->next) {
    right->subtrees = binarize_tree(right->subtrees, binarization_infix);
    if (right->next == NULL)
      break;   // only two siblings remain: nothing more to fold
    // New node whose children are `left` and `right`, followed by the rest.
    T* const rest = right->next;
    T* const joined = new T(left->label + binarization_infix + right->label, left, rest);
    right->next = NULL;
    left = joined;
  }
  return left;
}

//! parent_annotate_node_p() determines whether to add parent annotation to
//! a node with the given label. The label may already be binarized, i.e. of
//! the form A<infix>B<infix>...<infix>Z where infix is binarization_infix.
//! The node is annotated if some category in parentannotations is one of
//! the following:
//!  - equal to the whole label,
//!  - an interior component   (label contains infix+cat+infix),
//!  - the leftmost component  (label starts with cat+infix),
//!  - the rightmost component (label ends with infix+cat).
//! Returns false if parentannotations is empty.
//
bool parent_annotate_node_p(const std::set<std::string>& parentannotations, const std::string& label, const std::string& binarization_infix) {
  for (std::set<std::string>::const_iterator it = parentannotations.begin(); it != parentannotations.end(); ++it) {
    const std::string& plabel = *it;
    if (plabel == label)
      return true;
    if (label.find(binarization_infix+plabel+binarization_infix) != std::string::npos)
      return true;
    const std::string prefix = plabel + binarization_infix;
    if (label.size() >= prefix.size() && label.compare(0, prefix.size(), prefix) == 0)
      return true;
    // A true suffix test. The old find() only looked at the FIRST occurrence
    // of infix+plabel, so e.g. "NP_SBAR_S" was not seen to end in "_S".
    const std::string suffix = binarization_infix + plabel;
    if (label.size() >= suffix.size()
        && label.compare(label.size() - suffix.size(), suffix.size(), suffix) == 0)
      return true;
  }
  return false;
}

//! parent_annotate() appends parent_infix + P to the label of each
//! non-terminal node in the sibling list starting at tp (recursively).
//! P is the label of parentp, the nearest enclosing node whose label does
//! not contain binarization_infix (binarized nodes never act as parents).
//! Children are processed before their parent is annotated, so P never
//! carries an annotation itself. If parentannotations is non-empty, only
//! nodes accepted by parent_annotate_node_p() are annotated. Nodes with no
//! such parent (parentp == NULL, e.g. the root) are left unchanged.
//
T* parent_annotate(T* tp, const sS& parentannotations, const std::string& parent_infix, const std::string& binarization_infix, T* parentp=NULL) {
  if (tp == NULL || tp->is_terminal())
    return tp;
  const bool binarized = tp->label.find(binarization_infix) != std::string::npos;
  T* const childrens_parent = binarized ? parentp : tp;
  tp->subtrees = parent_annotate(tp->subtrees, parentannotations, parent_infix, binarization_infix, childrens_parent);
  tp->next = parent_annotate(tp->next, parentannotations, parent_infix, binarization_infix, parentp);
  if (parentp == NULL)
    return tp;
  const bool wanted = parentannotations.empty()
    || parent_annotate_node_p(parentannotations, tp->label, binarization_infix);
  if (wanted) {
    tp->label += parent_infix;
    tp->label += parentp->label;
  }
  return tp;
}
  
//! remove_annotations() removes binarization nodes and parent annotations.
//!
//! It processes the sibling list starting at tp and returns the new head of
//! that list. The already-processed list `next` is appended after the last
//! node of tp's list; this continuation is how the children of a deleted
//! binarization node get spliced in front of that node's following siblings.
//!  - A node whose label contains binarization_infix is a binarization node.
//!    It is deleted, and its (processed) children take its place.
//!  - Any other non-terminal node has its label truncated at the first
//!    occurrence of parent_infix, if there is one.
//! Terminal nodes are returned unchanged (their siblings are not visited).
//
T* remove_annotations(T* tp, const std::string& parent_infix, const std::string& binarization_infix, T* next=NULL) {
  if (!tp)
    return next;   // end of this sibling list: continue with `next`
  if (tp->is_terminal()) {
    // A terminal is not expected to have a pending continuation. If asserts
    // are compiled out (NDEBUG), a non-NULL `next` would be dropped here.
    assert(next == NULL);
    return tp;
  }
  if (tp->label.find(binarization_infix) == std::string::npos) {
    // Ordinary node: strip "<parent_infix>PARENT" from the label.
    std::string::size_type pos = tp->label.find(parent_infix);
    if (pos != std::string::npos)
      tp->label.resize(pos);
    // The children form a fresh list (no continuation); the continuation is
    // passed on to the end of this node's sibling list.
    tp->subtrees = remove_annotations(tp->subtrees, parent_infix, binarization_infix, NULL);
    tp->next = remove_annotations(tp->next, parent_infix, binarization_infix, next);
    return tp;
  }
  else {  // binarized node -- ignore it
    // Process the following siblings first, then use them as the
    // continuation after this node's children, which replace it in the list.
    T* tp1 = remove_annotations(tp->subtrees, parent_infix, binarization_infix, 
				remove_annotations(tp->next, parent_infix, binarization_infix,
						   next));
    // Detach everything that was relinked above before deleting the node.
    tp->subtrees = NULL;
    tp->next = NULL;
    delete tp;
    return tp1;
  }
}

//! replace_unks() rewrites to unknown_word the label of every terminal node
//! in the sibling list starting at tp (recursively) whose label is not in
//! vocabulary. The tree is modified in place and tp is returned.
//
T* replace_unks(T* tp, const sS& vocabulary, const std::string& unknown_word) {
  for (T* node = tp; node != NULL; node = node->next) {
    const bool unknown = node->is_terminal() && vocabulary.count(node->label) == 0;
    if (unknown)
      node->label = unknown_word;
    replace_unks(node->subtrees, vocabulary, unknown_word);
  }
  return tp;
}

//! decorate_preterms() appends parent_infix + w to the label of every
//! preterminal (in the sibling list starting at tp, recursively) whose
//! first child's label w is in wordannotations. The tree is modified in
//! place and tp is returned.
//
T* decorate_preterms(T* tp, const sS& wordannotations, const std::string& parent_infix) {
  for (T* node = tp; node != NULL; node = node->next) {
    if (node->is_preterminal()) {
      const std::string& word = node->subtrees->label;
      if (wordannotations.count(word) != 0) {
        node->label += parent_infix;
        node->label += word;
      }
    }
    decorate_preterms(node->subtrees, wordannotations, parent_infix);
  }
  return tp;
}

//! count_terminals() counts how often each word occurs in the tree
//
void count_terminals(const T* tp, S_U& w_c) {
  if (!tp)
    return;
  if (tp->is_terminal()) 
    ++w_c[tp->label];
  else {
    count_terminals(tp->subtrees, w_c);
    count_terminals(tp->next, w_c);
  }
}

//! Lexicographic less-than for vectors: element-wise comparison with X's
//! operator<, and a proper prefix is smaller than the longer vector.
//!
//! NOTE: std::vector already provides an equivalent operator< in namespace
//! std. std::less<Ss> (used by the Ss_U map) looks operator< up from inside
//! namespace std, so it generally uses that one. This overload is kept for
//! compatibility and simply delegates to the standard algorithm (the old
//! hand-rolled loop also ended in an unreachable return).
//
template <typename X>
bool operator< (const std::vector<X>& x1s, const std::vector<X>& x2s) {
  return std::lexicographical_compare(x1s.begin(), x1s.end(), x2s.begin(), x2s.end());
}

//! count_rules() counts, in r_c, each local tree (rule) of the tree rooted
//! at tp. A rule is the sequence [parent label, child label, ...]. Nodes
//! without children contribute no rule. Descendants of tp are counted too,
//! but tp's own following siblings are not.
//
void count_rules(T* tp, Ss_U& r_c) {
  if (tp == NULL || tp->subtrees == NULL)
    return;
  Ss rule(1, tp->label);   // the left-hand side comes first
  for (T* child = tp->subtrees; child != NULL; child = child->next) {
    rule.push_back(child->label);
    count_rules(child, r_c);
  }
  ++r_c[rule];
}

//! downcase() converts s to lower case in place (per the current C locale)
//! and returns a reference to s.
//
std::string& downcase(std::string& s) {
  for (std::string::iterator it = s.begin(); it != s.end(); ++it) {
    // Go through unsigned char: passing a negative char (e.g. a byte of a
    // UTF-8 sequence) to isupper()/tolower() is undefined behavior.
    const unsigned char c = static_cast<unsigned char>(*it);
    if (std::isupper(c))
      *it = static_cast<char>(std::tolower(c));
  }
  return s;
}

//! downcase() lower-cases the label of every terminal node in the sibling
//! list starting at tp (recursively), in place, and returns tp.
//
T* downcase(T* tp) {
  for (T* node = tp; node != NULL; node = node->next) {
    if (node->is_terminal())
      downcase(node->label);   // the std::string& overload works in place
    downcase(node->subtrees);
  }
  return tp;
}

//! second_greaterthan{} is a comparison function object for sorting pairs
//! by decreasing .second.
//!
//! operator() is const so that the comparator can also be used through a
//! const object or reference, as some standard algorithms and containers
//! require.
//
struct second_greaterthan {
  template <typename T1, typename T2>
  bool operator() (const T1& e1, const T2& e2) const {
    return e1.second > e2.second;
  }
};

int main(int argc, char** argv) {

  std::string binarization_infix = "_";
  std::string parent_infix = "^";
  std::string top_label = "TOP";
  std::string unknown_word = "*UNK*";
  unsigned nwords_limit = 0;
  unsigned vocabcount_limit = 0;
  std::string options;

  sS vocabulary;         // set of words NOT to map to *U*
  sS parentannotations;  // categories to parent annotate
  sS wordannotations;    // words whose preterm parents should be annotated

  S_U word_count;        // word -> count map
  Ss_U rule_count;       // rule -> count map

  int opt;
  while ((opt = getopt(argc, argv, "A:B:T:U:a::bcdef:hl:npPrst:u:wyYz")) != -1) 
    switch (opt) {
    case 'A':
      parent_infix = optarg;
      break;
    case 'B':
      binarization_infix = optarg;
      break;
    case 'T':
      top_label = optarg;
      break;
    case 'U':
      unknown_word = optarg;
      break;
    case 'a':
      if (optarg) {
	std::string arg(optarg);
	for (std::string::iterator it = arg.begin(); it != arg.end(); ++it)
	  if (*it == ':')   // replace ':' with ' '
	    *it = ' ';
	std::istringstream is(arg);
	std::string l;
	while (is >> l)
	  parentannotations.insert(l);
      }
      options.push_back(char(opt));
      break;     
    case 'f':  // read word annotation file
      {
	std::ifstream is(optarg);
	if (!is) {
	  std::cerr << "Fatal error in munge-trees: can't open word annotation file " << optarg << std::endl;
	  exit(EXIT_FAILURE);
	}
	std::string w;
	while (is >> w)
	  wordannotations.insert(w);
      options.push_back(char(opt));
      }
      break;
    case 'l':
      nwords_limit = atoi(optarg);
      options.push_back(char(opt));
      break;
    case 's':
      options.push_back(char(opt));
      break;
    case 't':
      vocabcount_limit = atoi(optarg);
      options.push_back(char(opt));
      break;
    case 'u':  // read vocabulary file
      {
	std::ifstream is(optarg);
	if (!is) {
	  std::cerr << "Fatal error in munge-trees: can't open vocabulary file " << optarg << std::endl;
	  exit(EXIT_FAILURE);
	}
	std::string w;
	while (is >> w)
	  vocabulary.insert(w);
      options.push_back(char(opt));
      }
      break;
    case 'h':
    case '?':
      std::cerr << usage << std::endl;
      exit(EXIT_FAILURE);
      break;
    default:
      options.push_back(char(opt));
      break;
    }
   
  T* t;
  std::string last_root_label;
  while (std::cin >> t) {
    for (std::string::size_type i = 0; i < options.size(); ++i)
      switch (options[i]) {
      case 'a':
	t = parent_annotate(t, parentannotations, parent_infix, binarization_infix);
	break;
      case 'b':
	t = binarize_tree(t, binarization_infix);
	break;
      case 'c':
	count_rules(t, rule_count);
	last_root_label = t->label;
	break;
      case 'd':
	downcase(t);
	break;
      case 'e':
	t = delete_empties(t);
	break;
      case 'f':
	t = decorate_preterms(t, wordannotations, parent_infix);
	break;
      case 'l':
	if (t->size() > nwords_limit)
	  i = options.size();  // don't process any more options
	break;
      case 'n':
	t = normalize_tree(t, parent_infix, &top_label);
	break;
      case 'p':
	t->prettyprint(std::cout, tree::IdentityFn()) << std::endl;
	break;
      case 'P':
	t = delete_punctuation(t);
	break;
      case 'r':
	t = remove_annotations(t, parent_infix, binarization_infix);
	break;
      case 's':
	delete_terminals(t);
	break;
      case 't':
	count_terminals(t, word_count);
	break;
      case 'u':
	t = replace_unks(t, vocabulary, unknown_word);
	break;
      case 'w':
	std::cout << t << std::endl;
	break;
      case 'y': 
	{
	  std::vector<std::string> ws;
	  t->yield(ws);
	  for (unsigned i = 0; i < ws.size(); ++i) 
	    std::cout << (i > 0 ? " " : "") << ws[i];
	  std::cout << std::endl;
	}
	break;
      case 'Y': 
	{
	  std::vector<std::string> pts, ws;
	  t->preterminals(pts, ws);
	  for (unsigned i = 0; i < pts.size(); ++i) 
	    std::cout << (i > 0 ? " " : "") << pts[i];
	  std::cout << std::endl;
	}
	break;
      case 'z': 
	{
	  std::vector<std::string> pts, ws;
	  t->preterminals(pts, ws);
	  for (unsigned i = 0; i < pts.size(); ++i) 
	    std::cout << (i > 0 ? " " : "") << ws[i] << " " << pts[i];
	  std::cout << std::endl;
	}
	break;
      default:
	break;
      }
    delete t;
  }

  if (!word_count.empty()) {   // write out words in order of frequency
    typedef std::pair<S,unsigned> SU;
    typedef std::vector<SU> SUs;
    SUs wordcounts;
    for (S_U::const_iterator it = word_count.begin(); it != word_count.end(); ++it)
      if (it->second >= vocabcount_limit)
	wordcounts.push_back(*it);
    std::sort(wordcounts.begin(), wordcounts.end(), second_greaterthan());

    for (SUs::const_iterator it = wordcounts.begin(); it != wordcounts.end(); ++it)
      std::cout << it->first << std::endl;
  }

  if (!rule_count.empty()) {  // write out rules and their counts
    // print out root rules first
    for (Ss_U::const_iterator it = rule_count.begin(); it != rule_count.end(); ++it)
      if (it->first[0] == last_root_label) {
	std::cout << it->second << ' ' << it->first[0] << " -->";
	for (unsigned i = 1; i < it->first.size(); ++i)
	  std::cout << ' ' << it->first[i];
	std::cout << std::endl;
      }
    // now print out nonroot rules
    for (Ss_U::const_iterator it = rule_count.begin(); it != rule_count.end(); ++it)
      if (it->first[0] != last_root_label) {
	std::cout << it->second << ' ' << it->first[0] << " -->";
	for (unsigned i = 1; i < it->first.size(); ++i)
	  std::cout << ' ' << it->first[i];
	std::cout << std::endl;
      }
  }

}
