// wordtuples.cc
//
// (c) Mark Johnson, 28th March 2001, updated 3rd May 2012
//
// This program identifies n-tuples of words that occur significantly
// more frequently than their subtuples would suggest.

#include <algorithm>
#include <bitset>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <iostream>
#include <utility>
#include <vector>

#include "interaction.h"
#include "sym.h"
#include "util.h"

// Largest tuple size the program supports.  It is also the number of bits
// in the bitset that `mask` is built on, so every tuple position must have
// an index below nmax.
static const size_t nmax = 16;

// Short aliases used throughout this file.
//   Float - floating-point type of the significance level
//   U     - occurrence count
//   S     - a single word (`symbol` is presumably declared in sym.h)
//   Ss    - a sequence of words, i.e. a tuple
//   Ss_U  - map from a tuple to the number of times it was seen
typedef double                       Float;
typedef unsigned long                U;
typedef symbol                       S;
typedef std::vector<S>               Ss;
typedef stdhash::unordered_map<Ss,U> Ss_U;

////////////////////////////////////////////////////////////////////////////////
//                                                                            //
//                             Word counting                                  //
//                                                                            //
////////////////////////////////////////////////////////////////////////////////


struct count_words
{
  Ss_U words_count;   // Words -> count
  size_t n;
  size_t min_count;
  Float sig;

  count_words(size_t n_, size_t min_count_, Float sig_) 
    : n(n_), min_count(min_count_), sig(sig_) { }  
};

// The input routine for count_words objects
//
std::istream& operator>> (std::istream& is, count_words& cw) {
  std::string line;
  Ss words, ws(cw.n);
  while (std::getline(is, line)) {
    words.clear();
    util::split(line, words);
    if (words.size() >= cw.n)
      for (size_t i = 0; i <= words.size()-cw.n; ++i) {
	for (size_t j = 0; j < cw.n; j++)
	  ws[j] = words[i+j];
	++cw.words_count[ws];
      }
  }
  return is;
}
    
////////////////////////////////////////////////////////////////////////////////
//                                                                            //
//                             mask and context                               //
//                                                                            //
////////////////////////////////////////////////////////////////////////////////

class mask : public std::bitset<nmax> {
public:

  // masks inherit equality from bitset

  // The ordering imposed on masks
  //  It's important that turning off bits in a mask reduces its position in
  //  the ordering
  //
  bool operator< (const mask m2) const { return to_ulong() < m2.to_ulong(); }

  bool contains(const mask m0) const { return (m0 & operator~()).none(); }

  mask() : std::bitset<nmax>() { }
  mask(unsigned long val) : std::bitset<nmax>(val) { }
};

UTIL_BEGIN_STDHASHNAMESPACE {
  // Hash support for mask: the hash value is the mask's bit pattern read
  // as an unsigned integer.
  template <> struct hash<mask>
  {
    size_t operator()(const mask m) const {
      const unsigned long bits = m.to_ulong();
      return bits;
    }
  };
} UTIL_END_STDHASHNAMESPACE // namespace stdhash

// A (mask, tuple) pair naming the words of a tuple at the positions
// selected by the mask.  Words at unselected positions are ignored by both
// comparison operators.
//
// A context does not own its tuple: `s` points at a vector that must
// outlive the context.  A context whose mask has no bits set is "empty";
// it never reads the tuple, so `s` may be NULL in that case (this is what
// the default constructor produces).
//
// Copying is left to the compiler-generated members, which copy the mask
// and the pointer.
struct context {
  mask  m;      // positions of *s that are significant
  const Ss* s;  // only hold a pointer to the tuple; NULL only if m is empty

  context() : m(), s(NULL) { }   // empty context
  context(mask m_, const Ss& s_) : m(m_), s(&s_) { }
  context(mask m_, const Ss* s_) : m(m_), s(s_) { assert(s_); }

  // Two contexts are equal when their masks are equal and their tuples
  // hold the same word at every position selected by the mask.  All empty
  // contexts are equal to each other.
  bool operator== (const context& c2) const {
    if (m.none())
      return c2.m.none();
    assert(s);
    if (m != c2.m)
      return false;
    assert(c2.s);
    size_t n = s->size();
    assert(n == c2.s->size());
    assert(n <= nmax);   // keeps m.test(i) in range
    for (size_t i = 0; i < n; ++i)
      if (m.test(i) && (*s)[i] != (*c2.s)[i])
        return false;
    return true;
  }

  // Strict weak ordering consistent with operator==: contexts are ordered
  // first by mask, then lexicographically by the words at the selected
  // positions.
  bool operator< (const context& c2) const {
    // Only a non-empty context needs a tuple.  Requiring a non-null
    // pointer unconditionally would make ordering an empty context abort
    // in debug builds even though the tuple is never dereferenced.
    assert(s || m.none());
    assert(c2.s || c2.m.none());
    if (m < c2.m)
      return true;
    if (c2.m < m)
      return false;
    // The masks are equal from here on.
    if (m.none())
      return false;      // two empty contexts are equivalent
    size_t n = s->size();
    assert(n == c2.s->size());
    assert(n <= nmax);   // keeps m.test(i) in range
    for (size_t i = 0; i < n; ++i)
      if (m.test(i)) {
        symbol s1((*s)[i]), s2((*c2.s)[i]);
        if (s1 < s2)
          return true;
        else if (s1 > s2)
          return false;
      }
    return false;
  }
};


// Hash support for context.  Only the mask and the words at the positions
// selected by the mask contribute, which keeps the hash consistent with
// context::operator==.
//
UTIL_BEGIN_STDHASHNAMESPACE {
template<> struct hash<context>
{ //  The mixing step is the fn hashpjw of Aho, Sethi and Ullman, p 436.
  size_t operator()(const context& c) const
  {
    // Every empty context hashes to the same value; its tuple pointer is
    // never looked at.
    if (c.m.none())
      return 0;

    assert(c.s);
    const Ss& tuple = *c.s;

    unsigned long acc = hash<mask>()(c.m);   // seed with the mask itself
    for (size_t pos = 0; pos < tuple.size(); ++pos) {
      if (!c.m.test(pos))
        continue;   // words outside the mask must not affect the hash
      acc = (acc << 5) + hash<symbol>()(tuple[pos]);
      const unsigned long top = acc & 0xf0000000;
      if (top) {
        acc = acc ^ (top >> 24);
        acc = acc ^ top;
      }
    }
    return size_t(acc);
  }
};
} UTIL_END_STDHASHNAMESPACE // namespace stdhash

// Binary predicate that orders pairs by applying Cmp to their `.second`
// members; `.first` is ignored.  A fresh Cmp is default-constructed for
// every call.
//
// The argument/result typedefs are declared directly instead of being
// inherited from std::binary_function, which was deprecated in C++11 and
// removed in C++17.
template <class Pair, class Cmp>
struct second_cmp {
  typedef const Pair& first_argument_type;
  typedef const Pair& second_argument_type;
  typedef bool        result_type;

  // const, so the predicate can be invoked through a const object or a
  // const reference.
  bool operator() (const Pair& p1, const Pair& p2) const {
    return Cmp()(p1.second, p2.second);
  }
};


// The print routine for count_words objects
//
// Works in three steps:
//  1. For every counted tuple and every mask that selects a proper subset
//     of its n positions (all bit patterns except "all n bits set"), add
//     the tuple's count to the count of the resulting context.
//  2. For every tuple whose count is at least cw.min_count, collect the
//     counts of all its contexts into s[] (indexed by mask, with the
//     tuple's own count in the last slot) and pass them to interaction().
//     The score is lambda - alpha*ase, where lambda and ase come from
//     interaction() and alpha from standard_errors(cw.sig); only tuples
//     with a positive score are kept.
//     NOTE(review): lambda/ase are presumably an interaction estimate and
//     its asymptotic standard error -- confirm against interaction.h.
//  3. Write the kept tuples in decreasing order of score, one per line, as
//     the score followed by the tuple's words separated by single spaces.
//
// The stream is flushed once, after the last line.
//
std::ostream& operator<< (std::ostream& os, const count_words& cw)
{
  typedef context                     C;
  typedef stdhash::unordered_map<C,U> C_U;
  typedef const Ss*                   Ssp;
  typedef std::pair<Ssp,float>        SspF;
  typedef std::vector<SspF>           SspFs;

  const size_t n = cw.n;
  const size_t nn = size_t(1) << n;   // number of distinct masks over n positions

  // Step 1: context counts.  The contexts point into the keys of
  // cw.words_count, which is not modified while they are in use.
  C_U context_count;

  cforeach (Ss_U, wsci, cw.words_count) {
    const Ss& words = wsci->first;
    size_t count = wsci->second;
    for (unsigned long m = 0; m < nn-1; ++m)
      context_count[context(m,words)] += count;
  }

  // Step 2: score the sufficiently frequent tuples.
  float alpha = standard_errors(cw.sig);
  SspFs ws_r;
  // Owned by a vector so the buffer is released on every exit path,
  // including an exception thrown by push_back or a helper.
  std::vector<float> s(nn);   // nn >= 1, so &s[0] below is always valid

  cforeach (Ss_U, wsci, cw.words_count) {
    if (wsci->second >= cw.min_count) {
      const Ss& words = wsci->first;
      for (unsigned long m = 0; m < nn-1; ++m)
        s[m] = util::afind(context_count, context(m,words));
      s[nn-1] = wsci->second;   // the full tuple itself
      float ase = 0;
      float lambda = interaction(n, &s[0], &ase);
      float lambda_alpha_ase = lambda-alpha*ase;
      if (lambda_alpha_ase > 0)
        ws_r.push_back(std::make_pair(&words, lambda_alpha_ase));
      // Alternative scoring, kept for reference:
      // if (lambda/ase > alpha)
      //   ws_r.push_back(make_pair(&words, lambda/ase));
    }
  }

  // Step 3: highest score first.
  std::sort(ws_r.begin(), ws_r.end(), second_cmp<SspF, std::greater<Float> >());

  cforeach (SspFs, it, ws_r) {
    os << it->second;
    const Ss& words = *(it->first);
    cforeach (Ss, wi, words)
      os << ' ' << *wi;
    os << '\n';   // no per-line flush; see the single flush below
  }
  os << std::flush;
  return os;
}

// Entry point.
//
// Usage: <program> tuple_size min_count significance_level
//
//   tuple_size          length n of the word tuples to count; at most nmax
//   min_count           tuples seen fewer times than this are not scored
//   significance_level  strictly between 0.0 and 1.0
//
// Counts tuples read from standard input and writes the scored tuples to
// standard output.
//
// NOTE(review): the error paths rely on util::abort / util::exit_failure
// terminating the program when streamed -- confirm against util.h.
int main(int argc, char **argv)
{
  if (argc != 4)
    std::cerr << "Usage: " << argv[0]
              << " tuple_size min_count significance_level"
              << util::abort;

  char *remainder;

  // For each numeric argument, `remainder == argv[i]` means no characters
  // were converted (e.g. an empty string).  Without that check an empty
  // argument would be read as 0, because *remainder is then already '\0'.
  size_t tuple_size = strtoul(argv[1], &remainder, 10);
  if (remainder == argv[1] || *remainder != '\0')
    std::cerr << argv[0] << ": Couldn't parse tuple_size argument: "
              << argv[1] << util::exit_failure;

  if (tuple_size > nmax)
    std::cerr << argv[0] << ": tuple_size = " << tuple_size
              << " > nmax = " << nmax << util::exit_failure;

  size_t min_count = strtoul(argv[2], &remainder, 10);
  if (remainder == argv[2] || *remainder != '\0')
    std::cerr << argv[0] << ": Couldn't parse min_count argument: "
              << argv[2] << util::exit_failure;

  Float sig = strtod(argv[3], &remainder);
  if (remainder == argv[3] || *remainder != '\0')
    std::cerr << argv[0] << ": Couldn't parse significance_level argument: "
              << argv[3] << util::exit_failure;

  // Written as a negated in-range test so that NaN, for which every
  // comparison is false, is rejected as well.
  if (!(sig > 0.0 && sig < 1.0))
    std::cerr << argv[0] << ": significance_level not between 0.0 and 1.0: "
              << sig << util::exit_failure;

  count_words cw(tuple_size, min_count, sig);
  std::cin >> cw;
  std::cout << cw;
  return 0;
}
