// vwordtuples.cc
//
// (c) Mark Johnson, 28th March 2001
//
// This program identifies n-tuples of words that occur significantly
// more frequently than their subtuples would suggest.  This program
// only collects tuples whose first word is a verb.

#include <algorithm>
#include <bitset>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <ext/hash_map>
#include <iostream>
#include <utility>         // definition of pair structure is in here
#include <vector>

#include "interaction.h"
#include "morph.h"
#include "read-files.h"
#include "sym.h"
#include "utility.h"

// Largest tuple size accepted by main(); also the width of the bitset that
// underlies mask.
static const size_t nmax = 16;
// Upper bound on the number of tuples written by the count_words printer.
static const size_t n_print_max = 250;

typedef double         Float;   // floating-point type used for the significance level
typedef unsigned long  U;       // count type
typedef symbol         S;       // a single word
typedef vector<S>      Ss;      // a tuple of words
typedef hash_map<Ss,U> Ss_U;    // word tuple -> count

//////////////////////////////////////////////////////////////////////////////
//                                                                          //
//                             Word counting                                //
//                                                                          //
//////////////////////////////////////////////////////////////////////////////


// Returns true iff tag is one of VB, VBD, VBN, VBP or VBZ.
// The tags AUX, AUXG, MD and TO are not counted as verb tags.
inline bool is_verbtag(symbol tag)
{
  static const symbol tag_vb("VB");
  static const symbol tag_vbd("VBD");
  static const symbol tag_vbn("VBN");
  static const symbol tag_vbp("VBP");
  static const symbol tag_vbz("VBZ");

  if (tag == tag_vb)
    return true;
  if (tag == tag_vbd)
    return true;
  if (tag == tag_vbn)
    return true;
  if (tag == tag_vbp)
    return true;
  return tag == tag_vbz;
}

// Returns true iff tag is NN or NNS.
inline bool is_nountag(symbol tag)
{
  static const symbol tag_nn("NN");
  static const symbol tag_nns("NNS");

  if (tag == tag_nn)
    return true;
  return tag == tag_nns;
}

// Returns word with its first character lowercased when the word looks like
// an ordinary capitalized word, i.e. it has at least two characters, the
// first is uppercase and the second is lowercase.
//
// The word is returned unchanged when tag is NNP or NNPS, or when the word
// is "I".  The word "A" tagged DT is mapped to "a".  Everything else (single
// characters, all-caps words, words not starting with an uppercase letter)
// is returned unchanged.
inline symbol lowercase_initial(symbol word, symbol tag) 
{
  static const symbol NNP("NNP"), NNPS("NNPS"), I("I"), A("A"), a("a"), DT("DT");
  if (tag == NNP || tag == NNPS || word == I) 
    return word;
  if (word == A && tag == DT)
    return a;
  const string& s0 = word.string_reference();
  if (s0.size() > 1) {
    // The <cctype> functions have undefined behaviour when handed a negative
    // value other than EOF, which a plain char holding a non-ASCII byte can
    // be; go through unsigned char first.
    const unsigned char c0 = static_cast<unsigned char>(s0[0]);
    const unsigned char c1 = static_cast<unsigned char>(s0[1]);
    if (isupper(c0) && islower(c1)) {
      string s(s0);
      s[0] = static_cast<char>(tolower(c0));
      return symbol(s);
    }
  }
  return word;
}

// Function object that accumulates counts of n-word tuples whose first word
// carries a verb tag.
//
// Each call to operator() handles one sequence of words together with the
// parallel sequence of tags.  The resulting counts are kept in words_count.
struct count_words
{
  Ss_U words_count;   // word tuple -> number of times it was seen
  size_t n;           // tuple size
  size_t min_count;   // tuples seen fewer times than this are not printed
  Float sig;          // significance level used when printing

  count_words(size_t n_, size_t min_count_, Float sig_) 
    : n(n_), min_count(min_count_), sig(sig_) { }

  // Counts every window of n consecutive words that starts at a word with a
  // verb tag and whose remaining n-1 words all begin with an alphabetic
  // character.  The first word of the tuple is replaced by
  // morph_base(word, VERB).
  //
  // NOTE: words is modified in place -- every word is passed through
  // lowercase_initial() before counting.
  // tags must supply at least one tag per word.
  void operator() (vector<symbol>& words, const vector<symbol>& tags) {
    assert(words.size() > 0);
    assert(tags.size() >= words.size());

    for (size_t i = 0; i < words.size(); ++i)
      words[i] = lowercase_initial(words[i], tags[i]);

    // A zero-length tuple has no first word to test, and a sequence shorter
    // than n contains no complete tuple.  Returning here also keeps the
    // unsigned subtraction below from wrapping around.
    if (n == 0 || words.size() < n)
      return;

    Ss ws(n);
    const size_t last_start = words.size() - n;
    for (size_t i = 0; i <= last_start; ++i) {
      if (!is_verbtag(tags[i]))
        continue;
      ws[0] = morph_base(words[i], VERB);
      size_t j;
      for (j = 1; j < n; ++j) {
        symbol w = words[i+j];
        // Go through unsigned char: isalpha() is undefined for negative
        // arguments, which plain char can produce for non-ASCII bytes.
        if (!isalpha(static_cast<unsigned char>(w.c_str()[0])))
          break;
        ws[j] = w;
      }
      // j == n only if no word in the window broke out of the loop above.
      if (j == n)
        ++words_count[ws];
    }
  }
};


//////////////////////////////////////////////////////////////////////////////
//                                                                          //
//                             mask and context                             //
//                                                                          //
//////////////////////////////////////////////////////////////////////////////

// A set of tuple positions, stored as one bit per position in a bitset of
// nmax bits.
class mask : public bitset<nmax> {
public:

  // Equality comparison is the one provided by bitset.

  mask() : bitset<nmax>() { }
  mask(unsigned long val) : bitset<nmax>(val) { }

  // Masks are ordered by their value as an unsigned long.
  //  Clearing bits in a mask therefore always yields a mask that sorts
  //  earlier, which the rest of the program relies on.
  //
  bool operator< (const mask m2) const {
    const unsigned long lhs = to_ulong();
    const unsigned long rhs = m2.to_ulong();
    return lhs < rhs;
  }

  // True iff every bit that is set in m0 is also set in this mask.
  bool contains(const mask m0) const {
    bitset<nmax> missing(m0);   // bits of m0 ...
    missing &= ~(*this);        // ... that are not set in this mask
    return !missing.any();
  }
};

// Hash functor for masks: the hash value is the mask's bit pattern read as
// an unsigned long.
template<> struct hash<mask>
{
  size_t operator()(const mask m) const
  {
    const unsigned long bits = m.to_ulong();
    return bits;
  }
};

// A word tuple viewed through a mask: only the positions whose bit is set in
// m take part in comparisons.
//
// The tuple itself is not copied; a context stores a pointer to it, so the
// tuple must stay alive for as long as the context is in use.  A context
// whose mask has no bits set is "empty"; all empty contexts compare equal and
// s may be NULL for them.
struct context {
  mask  m;
  const Ss* s;  // only hold a pointer to the context

  context() : m(), s(NULL) { }   // empty context
  context(mask m_, const Ss& s_) : m(m_), s(&s_) { }
  context(mask m_, const Ss* s_) : m(m_), s(s_) { assert(s_); }
  context(const context& c) : m(c.m), s(c.s) { }

  // Two contexts are equal iff their masks are equal and their tuples agree
  // at every position selected by the mask.
  bool operator== (const context& c2) const {
    if (m.none())
      return c2.m.none();
    assert(s);
    if (m != c2.m)
      return false;
    assert(c2.s);
    size_t n = s->size();
    assert(n == c2.s->size());
    assert(n <= nmax);
    for (size_t i = 0; i < n; ++i)
      if (m.test(i) && (*s)[i] != (*c2.s)[i])
        return false;
    return true;
  }

  // Orders contexts first by mask and then, for equal masks, by comparing
  // the words at the selected positions from left to right.
  bool operator< (const context& c2) const {
    if (m < c2.m)
      return true;
    if (c2.m < m)
      return false;
    // Neither mask sorts before the other, so both have the same value.
    // Empty contexts are all equivalent and may have s == NULL, so they are
    // dealt with before the tuple pointers are checked -- the same order of
    // checks that operator== uses.
    if (m.none() && c2.m.none())
      return false;
    assert(s);
    assert(c2.s);
    size_t n = s->size();
    assert(n == c2.s->size());
    assert(n <= nmax);
    for (size_t i = 0; i < n; ++i)
      if (m.test(i)) {
        symbol s1((*s)[i]), s2((*c2.s)[i]);
        if (s1 < s2)
          return true;
        else if (s1 > s2)
          return false;
      }
    return false;
  }
};


// hash function for context
//
struct hash<context> 
{ //  This is the fn hashpjw of Aho, Sethi and Ullman, p 436.
  size_t operator()(const context& c) const 
  {
    if (c.m.none())
      return 0;

    typedef Ss::const_iterator CI;

    unsigned long h = hash<mask>()(c.m); 
    unsigned long g;
    assert(c.s);
    CI p = c.s->begin();
    CI end = c.s->end();
    size_t i = 0;
      
    while (p!=end) {
      if (c.m.test(i++)) {  // ignore non-identical contexts
	h = (h << 5) + hash<symbol>()(*p);
	if ((g = h&0xf0000000)) {
	  h = h ^ (g >> 24);
	  h = h ^ g;
	}}
      ++p;
    }
    return size_t(h);
  }
};


// Binary predicate that compares two pairs by applying a default-constructed
// Cmp to their second components; the first components are ignored.
//
// The nested typedefs are spelled out here instead of being inherited from
// binary_function, which was deprecated in C++11 and removed in C++17.
template <class Pair, class Cmp>
struct second_cmp {
  typedef const Pair& first_argument_type;
  typedef const Pair& second_argument_type;
  typedef bool        result_type;

  // const so that the predicate can also be invoked through a const object.
  bool operator() (const Pair& p1, const Pair& p2) const {
    return Cmp()(p1.second, p2.second);
  }
};


// The print routine for count_words objects
//
// For every counted tuple and every proper subset of its positions, the
// tuple's count is added to the count of the corresponding context.  Then,
// for each tuple seen at least cw.min_count times, the context counts for all
// subsets of its positions are handed to interaction(), and the tuple is kept
// if lambda - alpha*ase is positive, where alpha = standard_errors(cw.sig).
// The kept tuples are written to os in decreasing order of that value, one
// per line as "w1 w2 ... wn \t value", stopping after n_print_max lines.
//
ostream& operator<< (ostream& os, const count_words& cw)
{
  typedef context         C;
  typedef hash_map<C,U>   C_U;
  typedef const Ss*       Ssp;
  typedef pair<Ssp,float> SspF;
  typedef vector<SspF>    SspFs;

  const size_t n = cw.n;
  assert(n <= nmax);                   // a mask cannot represent more positions
  const size_t nn = size_t(1) << n;    // number of subsets of the n positions
  const size_t full = nn - 1;          // the subset containing every position

  C_U context_count;

  // Accumulate counts for every proper subset of positions of every tuple.
  cforeach (Ss_U, wsci, cw.words_count) {
    const Ss& words = wsci->first;
    size_t count = wsci->second;
    for (unsigned long m = 0; m < full; ++m)
      context_count[context(m,words)] += count;
  }

  float alpha = standard_errors(cw.sig);
  SspFs ws_r;
  // Scratch table of counts indexed by mask value.  A vector releases its
  // storage automatically, even if something below throws.
  vector<float> s(nn);

  cforeach (Ss_U, wsci, cw.words_count) {
    if (wsci->second < cw.min_count)
      continue;
    const Ss& words = wsci->first;
    for (unsigned long m = 0; m < full; ++m)
      s[m] = afind(context_count, context(m,words));
    s[full] = wsci->second;            // the full tuple's own count
    float ase;
    float lambda = interaction(n, &s[0], &ase);
    float lambda_alpha_ase = lambda-alpha*ase;
    if (lambda_alpha_ase > 0)
      // The pointer refers to a key stored in cw.words_count.
      ws_r.push_back(make_pair(&words, lambda_alpha_ase));
  }

  sort(ws_r.begin(), ws_r.end(), second_cmp<SspF, greater<Float> >());

  size_t n_printed = 0;
  cforeach (SspFs, it, ws_r) {
    const Ss& words = *(it->first);
    cforeach (Ss, wi, words)
      os << *wi << " ";
    os << "\t" << it->second << endl;
    if (++n_printed >= n_print_max)
      break;
  }

  return os;
}

// Entry point.
//
// Usage: vwordtuples tuple_size min_count significance_level filename
//
// Parses and validates the three numeric arguments, hands the file name and
// a count_words object to process_files(), and then prints the count_words
// object to standard output.  Any invalid argument produces a message on
// standard error and exit status EXIT_FAILURE.
int main(int argc, char **argv)
{
  // Check that the program is called with exactly four arguments
  //
  if (argc != 5) {
    cerr << "Usage: " << argv[0] 
         << " tuple_size min_count significance_level filename" << endl;
    exit(EXIT_FAILURE);
  }
  
  char *remainder;

  // remainder == argv[i] means no characters were converted at all (for
  // instance an empty argument), which must not be read as the value 0.
  size_t tuple_size = strtoul(argv[1], &remainder, 10);
  if (remainder == argv[1] || *remainder != '\0') {
    cerr << argv[0] << ": Couldn't parse tuple_size argument: " 
         << argv[1] << endl;
    exit(EXIT_FAILURE);
  }

  // A tuple needs at least its leading verb.
  if (tuple_size < 1) {
    cerr << argv[0] << ": tuple_size must be at least 1" << endl;
    exit(EXIT_FAILURE);
  }

  if (tuple_size > nmax) {
    cerr << argv[0] << ": tuple_size = " << tuple_size 
         << " > nmax = " << nmax << endl;
    exit(EXIT_FAILURE);
  }

  size_t min_count = strtoul(argv[2], &remainder, 10);
  if (remainder == argv[2] || *remainder != '\0') {
    cerr << argv[0] << ": Couldn't parse min_count argument: " 
         << argv[2] << endl;
    exit(EXIT_FAILURE);
  }

  Float sig = strtod(argv[3], &remainder);
  if (remainder == argv[3] || *remainder != '\0') {
    cerr << argv[0] << ": Couldn't parse significance_level argument: " 
         << argv[3] << endl;
    exit(EXIT_FAILURE);
  }

  // Written as a negated in-range test so that NaN is rejected as well.
  if (!(sig > 0.0 && sig <= 1.0)) {
    cerr << argv[0] << ": significance_level not between 0.0 and 1.0: " 
         << sig << endl;
    exit(EXIT_FAILURE);
  }

  count_words cw(tuple_size, min_count, sig);
  process_files(argv[4], cw);
  cout << cw;
  return 0;
}
