/* This is -*- C -*- */
/* vim: set sw=2: */
/* $Id$ */

/*
 * markov.c
 *
 * Copyright (C) 2003 The Free Software Foundation, Inc.
 *
 */

/*
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 2 of the
 * License, or (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
 * USA.
 */

#ifdef HAVE_CONFIG_H
#include <config.h>
#endif
#include "markov.h"

#include "time.h"

Markov *
markov_new (void)
{
  Markov *markov = g_new0 (Markov, 1);

  markov->all_hash  = g_hash_table_new (NULL, NULL);
  markov->all = g_ptr_array_new ();

  markov->next  = g_hash_table_new (NULL, NULL);
  markov->prev  = g_hash_table_new (NULL, NULL);

  return markov;
}

/* ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** */

/*
 * qsort() comparator for arrays of Token pointers: orders the elements
 * by the pointer values themselves.  Returns -1, 0 or 1.
 */
static int
markov_sort_array_cmp (gconstpointer a, gconstpointer b)
{
  Token *lhs = *(Token **)a;
  Token *rhs = *(Token **)b;

  if (lhs < rhs)
    return -1;
  if (lhs > rhs)
    return 1;
  return 0;
}

/*
 * g_hash_table_foreach() callback: VAL is a GPtrArray of Token pointers,
 * which is sorted in place by pointer value.  KEY and USER_DATA are
 * not used.
 */
static void
markov_sort_array_cb (gpointer key, gpointer val, gpointer user_data)
{
  GPtrArray *tokens = (GPtrArray *) val;

  qsort (tokens->pdata,
         tokens->len,
         sizeof (gpointer),
         markov_sort_array_cmp);
}

/*
 * Sort every successor and predecessor array of MARKOV, unless the
 * model is already marked as sorted.  Sets markov->sorted afterwards.
 */
static void
markov_sort_arrays (Markov *markov)
{
  if (! markov->sorted) {
    g_hash_table_foreach (markov->next, markov_sort_array_cb, NULL);
    g_hash_table_foreach (markov->prev, markov_sort_array_cb, NULL);
    markov->sorted = TRUE;
  }
}

/* ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** */


/*
 * Record token T in the model's list of all tokens, once.  The
 * all_hash table is used to detect tokens that were already recorded.
 */
static void
markov_add_to_all (Markov *markov, Token *t)
{
  /* Already registered: nothing to do. */
  if (g_hash_table_lookup (markov->all_hash, t) != NULL)
    return;

  g_hash_table_insert (markov->all_hash, t, t);
  g_ptr_array_add (markov->all, t);
}

void
markov_add_pair (Markov *markov, Token *t1, Token *t2)
{
  GPtrArray *next_array, *prev_array;

  g_return_if_fail (markov != NULL);
  g_return_if_fail (t1 != NULL);
  g_return_if_fail (t2 != NULL);

  markov_add_to_all (markov, t1);
  markov_add_to_all (markov, t2);

  next_array = g_hash_table_lookup (markov->next, t1);
  if (next_array == NULL) {
    next_array = g_ptr_array_new ();
    g_hash_table_insert (markov->next, t1, next_array);
  }

  prev_array = g_hash_table_lookup (markov->prev, t2);
  if (prev_array == NULL) {
    prev_array = g_ptr_array_new ();
    g_hash_table_insert (markov->prev, t2, prev_array);
  }

  g_ptr_array_add (next_array, t2);
  g_ptr_array_add (prev_array, t1);

  markov->sorted = FALSE;
}

/*
 * Add every pair of adjacent tokens of TXT to MARKOV, then re-sort the
 * model's successor/predecessor arrays.
 *
 * While walking the text (starting at the second token) a small
 * statistic is gathered and printed with g_print():
 *   - "sentences" is the number of stop tokens encountered;
 *   - a sentence is "complete" when no token since the previous stop
 *     token was "unknown", i.e. neither a stop token, nor a start
 *     token, nor a token with a non-zero pos_mask.
 *
 * TXT must contain at least one token (checked with assert).
 */
void
markov_add_text (Markov *markov, Text *txt)
{
  int i, N;
  Token *prev;

  int total=0, complete=0;
  gboolean seen_unknown = FALSE;
  double percent;

  g_return_if_fail (markov != NULL);
  g_return_if_fail (txt != NULL);

  N = text_length (txt);
  assert (N > 0);

  prev = text_get_token (txt, 0);
  for (i = 1; i < N; ++i) {
    Token *curr = text_get_token (txt, i);

    if (token_is_stop (curr)) {
      /* End of a sentence: account for it and reset the tracker. */
      ++total;
      if (! seen_unknown)
        ++complete;
      seen_unknown = FALSE;
    } else if (! (token_is_start (curr) || curr->pos_mask)) {
      seen_unknown = TRUE;
    }

    markov_add_pair (markov, prev, curr);
    prev = curr;
  }

  /* A text without any stop token has total == 0; report 0% instead
     of dividing zero by zero and printing "nan". */
  percent = (total > 0) ? 100.0 * complete / total : 0.0;

  g_print ("%d sentences, %d complete: %.2f%%\n",
           total, complete, percent);

  markov_sort_arrays (markov);
}

/* ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** */

/*
 * Return a pseudo-random integer in the range [0, N).
 *
 * The generator is seeded from the current time on the first call.
 * N must be positive; otherwise a GLib warning is emitted and 0 is
 * returned (a zero modulus would otherwise be a division by zero).
 */
int
markov_random (gint N)
{
  static gboolean inited = FALSE;

  g_return_val_if_fail (N > 0, 0);

  if (! inited) {
    srandom (time (NULL));
    inited = TRUE;
  }

  return random () % N;
}

/*
 * Choose a token using FILTER with its prev_is field set to T.
 *
 * If FILTER is NULL a local default-initialised filter is used;
 * otherwise the caller's filter is modified (prev_is is overwritten).
 * The two counters are passed straight through to markov_choose().
 * Returns whatever markov_choose() returns, or NULL on bad arguments.
 */
Token *
markov_choose_next (Markov *markov, Token *t,
                    MarkovFilter *filter,
                    int *unfiltered_choices,
                    int *filtered_choices)
{
  MarkovFilter scratch = MARKOV_FILTER_INIT;
  MarkovFilter *active;

  g_return_val_if_fail (markov != NULL, NULL);
  g_return_val_if_fail (t != NULL, NULL);

  active = (filter != NULL) ? filter : &scratch;
  active->prev_is = t;

  return markov_choose (markov, active, FALSE,
                        unfiltered_choices, filtered_choices);
}

/*
 * Choose a token using FILTER with its next_is field set to T.
 *
 * If FILTER is NULL a local default-initialised filter is used;
 * otherwise the caller's filter is modified (next_is is overwritten).
 * The two counters are passed straight through to markov_choose().
 * Returns whatever markov_choose() returns, or NULL on bad arguments.
 */
Token *
markov_choose_prev (Markov *markov, Token *t,
                    MarkovFilter *filter,
                    int *unfiltered_choices,
                    int *filtered_choices)
{
  MarkovFilter scratch = MARKOV_FILTER_INIT;
  MarkovFilter *active;

  g_return_val_if_fail (markov != NULL, NULL);
  g_return_val_if_fail (t != NULL, NULL);

  active = (filter != NULL) ? filter : &scratch;
  active->next_is = t;

  return markov_choose (markov, active, FALSE,
                        unfiltered_choices, filtered_choices);
}

/*
 * Token callback used by markov_choose(): USER_DATA points at a
 * Token pointer, which is overwritten with T.
 */
static void
markov_choose_cb (Token *t, gpointer user_data)
{
  Token **slot = (Token **) user_data;

  *slot = t;
}

/*
 * Pick a single token matching FILTER by asking markov_choose_many()
 * for one choice.
 *
 * When COUNT_CHOICES_ONLY is true no callback is supplied, so only the
 * counters are filled in and NULL is returned.  BASIC_CHOICES and
 * FILTERED_CHOICES are forwarded unchanged.  Returns the chosen token,
 * or NULL if none was delivered.
 */
Token *
markov_choose (Markov *markov,
               MarkovFilter *filter,
               gboolean count_choices_only,
               int *basic_choices,
               int *filtered_choices)
{
  Token *picked = NULL;
  TokenFn visitor = markov_choose_cb;

  g_return_val_if_fail (markov != NULL, NULL);
  g_return_val_if_fail (filter != NULL, NULL);

  if (count_choices_only)
    visitor = NULL;

  markov_choose_many (markov, filter, 1,
                      visitor, &picked,
                      basic_choices, filtered_choices);

  return picked;
}

/* ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** */

/*
 * Pointer tagging helpers: the lowest bit of a pointer lvalue P is used
 * as a temporary marker.
 *
 *   FLAG_PTR (p)    sets the low bit of P in place
 *   UNFLAG_PTR (p)  clears the low bit of P in place
 *   TEST_PTR (p)    is non-zero when the low bit of P is set
 *
 * The conversion goes through gsize rather than guint: the guint
 * macros only preserve 32 bits, which would mangle the pointer on
 * platforms where pointers are wider than that.
 */
#define FLAG_PTR(p)   ((p) = GSIZE_TO_POINTER (GPOINTER_TO_SIZE (p) | (gsize) 1))
#define UNFLAG_PTR(p) ((p) = GSIZE_TO_POINTER (GPOINTER_TO_SIZE (p) & ~(gsize) 1))
#define TEST_PTR(p)   (GPOINTER_TO_SIZE (p) & (gsize) 1)

/*
 * Binary-search ARRAY, which must be sorted by pointer value, for the
 * token T.  Elements are compared with their flag bit cleared, so the
 * search also works while entries are temporarily flagged.
 *
 * Returns TRUE if T is present; FALSE if it is not, or if ARRAY is
 * NULL or empty.
 */
static gboolean
sorted_array_contains (GPtrArray *array, Token *t)
{
  int a, b, i;
  Token *ta, *tb, *ti;

  /* An empty array has no last element to look at. */
  if (array == NULL || array->len == 0)
    return FALSE;

  a = 0;
  b = array->len - 1;

  ta = g_ptr_array_index (array, a);
  tb = g_ptr_array_index (array, b);

  UNFLAG_PTR (ta);
  UNFLAG_PTR (tb);

  if (t == ta || t == tb)
    return TRUE;

  if (t < ta || t > tb)
    return FALSE;

  /* Invariant: array[a] < t < array[b].  Once a and b are adjacent
     there is no slot left in which t could be. */
  while (b - a > 1) {

    i = a + (b - a) / 2;
    ti = g_ptr_array_index (array, i);
    UNFLAG_PTR (ti);

    if (t == ti)
      return TRUE;
    else if (t < ti) {
      b = i;
    } else { /* t > ti */
      a = i;
    }
  }

  return FALSE;
}

/*
 * qsort() comparator for arrays of int, ascending order.
 * Returns -1, 0 or 1.
 */
static int
int_cmp (gconstpointer a, gconstpointer b)
{
  int lhs = *(int *) a;
  int rhs = *(int *) b;

  if (lhs < rhs)
    return -1;
  if (lhs > rhs)
    return 1;
  return 0;
}

/*
 * Count the tokens that satisfy FILTER and optionally deliver up to N
 * randomly drawn matches to ITER_FN.
 *
 * Candidates are taken from:
 *   - markov->all, when the anchor token (prev_is if left_rooted,
 *     next_is otherwise) is NULL;
 *   - the successor array of prev_is, when left_rooted;
 *   - the predecessor array of next_is, otherwise.
 *
 * markov          the model; its arrays are sorted first if needed
 * filter          selection criteria; rhymes_with may be reset to NULL
 *                 by this function (see the rhyme handling below)
 * N               maximum number of tokens to deliver; N <= 0 means
 *                 "only count"
 * iter_fn         called as iter_fn (token, user_data) for each draw,
 *                 in candidate order; NULL means "only count".  Draws
 *                 are independent, so the same array entry may be
 *                 delivered more than once.
 * basic_choices   if non-NULL, receives the number of candidates
 * filtered_choices if non-NULL, receives the number of candidates
 *                 that passed the filter
 *
 * Both counters are left at 0 when the function returns early (bad
 * arguments, no candidates, or minimum_rhyme_type above RHYME_TRUE).
 *
 * While choosing, matching entries of the candidate array are marked
 * by setting the low bit of the stored pointer; all marks are cleared
 * again before returning.
 */
void
markov_choose_many (Markov *markov,
                    MarkovFilter *filter,
                    int N,
                    TokenFn iter_fn, gpointer user_data,
                    int *basic_choices,
                    int *filtered_choices)
{
  GPtrArray *candidates;
  int i, count=0;
  gboolean count_choices_only;
  /* Whether the most recently evaluated distinct token passed the
     filter.  Kept in a local rather than read back from the flag bit,
     because flags are not written in count-only mode; relying on them
     made duplicates of a valid token go uncounted in that mode. */
  gboolean last_valid = FALSE;

  if (basic_choices)
    *basic_choices = 0;
  if (filtered_choices)
    *filtered_choices = 0;

  g_return_if_fail (markov != NULL);
  g_return_if_fail (filter != NULL);

  count_choices_only = (N <= 0 || iter_fn == NULL);

  if (! markov->sorted)
    markov_sort_arrays (markov);

  if ( (filter->left_rooted ? filter->prev_is : filter->next_is) == NULL)
    candidates = markov->all;
  else if (filter->left_rooted)
    candidates = g_hash_table_lookup (markov->next, filter->prev_is);
  else
    candidates = g_hash_table_lookup (markov->prev, filter->next_is);

  if (candidates == NULL || candidates->len == 0)
    return;

  /* Handle some pathological rhyme situations. */
  if (filter->minimum_rhyme_type > RHYME_TRUE)
    return;
  if (filter->minimum_rhyme_type <= RHYME_NONE
      || (filter->rhymes_with && filter->rhymes_with->decomp == NULL))
    filter->rhymes_with = NULL;

  for (i = 0; i < candidates->len; ++i) {
    Token *t = g_ptr_array_index (candidates, i);
    int syl;

    /* If this token is the same as the last, we don't need to
       re-do the whole computation: it inherits the verdict of the
       previous entry. */
    if (i > 0) {
      Token *tc = t;
      Token *prev = g_ptr_array_index (candidates, i-1);
      UNFLAG_PTR (prev);
      UNFLAG_PTR (tc);
      if (prev == tc) {
        if (last_valid)
          goto possible_choice;
        else
          continue;
      }
    }

    /* A new distinct token: assume it fails until every test passes. */
    last_valid = FALSE;

    if ((token_is_start (t) && filter->no_start)
        || (token_is_stop (t) && filter->no_stop))
      continue;

    syl = token_syllables (t);

    /* Don't let stop follow start. */
    if (filter->prev_is != NULL
        && token_is_start (filter->prev_is)
        && token_is_stop (t))
      continue;
    if (filter->next_is != NULL
        && token_is_stop (filter->next_is)
        && token_is_start (t))
      continue;


    /* Check syllables */

    if (syl < filter->min_syllables)
      continue;

    if (filter->max_syllables > 0 && syl > filter->max_syllables)
      continue;


    /* Check rhymes */

    if (filter->rhymes_with != NULL) {
      if (t->decomp == NULL
          || rhyme_get_type (filter->rhymes_with->decomp,
                             t->decomp) < filter->minimum_rhyme_type)
        continue;
    } else if (filter->rhyme_type_exists) {
      if (t->decomp == NULL
          || ! rhyme_exists (t->decomp, filter->rhyme_type_exists))
        continue;
    }


    /* Check pinch: with both neighbours fixed and the syllable budget
       exactly used up, the token must also connect to the neighbour
       on the far side. */

    if (filter->next_is != NULL
        && filter->prev_is != NULL
        && syl == filter->max_syllables) {

      GPtrArray *target;

      /* Explicitly disallow pinches that would result in one-word
         sentences. */
      if (token_is_start (filter->prev_is)
          && token_is_stop (filter->next_is))
        continue;

      if (filter->left_rooted)
        target = g_hash_table_lookup (markov->prev, filter->next_is);
      else
        target = g_hash_table_lookup (markov->next, filter->prev_is);

      if (target == NULL || ! sorted_array_contains (target, t))
        continue;
    }

    /* Check arbitrary user function */

    if (filter->filter_fn
        && ! filter->filter_fn (t, filter->user_data))
      continue;

    last_valid = TRUE;

  possible_choice:

    /* Mark the array entry so the selection pass below can find it. */
    if (! count_choices_only)
      FLAG_PTR (g_ptr_array_index (candidates, i));
    ++count;
  }

  if (count > 0 && ! count_choices_only) {
    int *choices;
    int j, k, p;

    if (N > count)
      N = count;

    /* Generate N sorted choices */
    choices = g_newa (int, N);
    for (j = 0; j < N; ++j)
      choices[j] = markov_random (count);
    if (N > 1) {
      qsort (choices, N, sizeof (int), int_cmp);
    }

    k = 0; /* what number match is this? */
    p = 0; /* which of our N choices are we looking for? */
    for (i = 0; i < candidates->len; ++i) {
      Token *t = g_ptr_array_index (candidates, i);
      if (p < N && TEST_PTR (t)) {
        UNFLAG_PTR (t);
        while (p < N && choices[p] == k) {
          iter_fn (t, user_data);
          ++p;
        }
        ++k;
      }
      /* Always restore the stored pointer, chosen or not. */
      UNFLAG_PTR (g_ptr_array_index (candidates, i));
    }
  }

  if (basic_choices)
    *basic_choices = candidates->len;
  if (filtered_choices)
    *filtered_choices = count;
}

/* ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** */

/*
 * Copy the recognised entries of the Python dictionary PY_DICT into
 * FILTER.  Keys that are absent leave the corresponding field
 * untouched.
 *
 *   no_start, no_stop, left_rooted       truth value of the entry
 *   min_syllables, max_syllables,
 *   minimum_rhyme_type, rhyme_type_exists  must be Python ints
 *                                          (enforced with g_assert)
 *   next_is, prev_is, rhymes_with        converted with token_from_py(),
 *                                        only when the entry is true
 */
void
markov_filter_from_py_dict (MarkovFilter *filter,
                            PyObject *py_dict)
{
  PyObject *val;

  g_return_if_fail (filter != NULL);
  g_return_if_fail (py_dict != NULL);
  g_return_if_fail (PyDict_Check (py_dict));

  /* --- boolean switches --- */

  val = PyDict_GetItemString (py_dict, "no_start");
  if (val != NULL)
    filter->no_start = PyObject_IsTrue (val);

  val = PyDict_GetItemString (py_dict, "no_stop");
  if (val != NULL)
    filter->no_stop = PyObject_IsTrue (val);

  val = PyDict_GetItemString (py_dict, "left_rooted");
  if (val != NULL)
    filter->left_rooted = PyObject_IsTrue (val);

  /* --- syllable limits --- */

  val = PyDict_GetItemString (py_dict, "min_syllables");
  if (val != NULL) {
    g_assert (PyInt_Check (val));
    filter->min_syllables = PyInt_AsLong (val);
  }

  val = PyDict_GetItemString (py_dict, "max_syllables");
  if (val != NULL) {
    g_assert (PyInt_Check (val));
    filter->max_syllables = PyInt_AsLong (val);
  }

  /* --- token-valued entries; false values are ignored --- */

  val = PyDict_GetItemString (py_dict, "next_is");
  if (val != NULL && PyObject_IsTrue (val))
    filter->next_is = token_from_py (val);

  val = PyDict_GetItemString (py_dict, "prev_is");
  if (val != NULL && PyObject_IsTrue (val))
    filter->prev_is = token_from_py (val);

  val = PyDict_GetItemString (py_dict, "rhymes_with");
  if (val != NULL && PyObject_IsTrue (val))
    filter->rhymes_with = token_from_py (val);

  /* --- rhyme settings --- */

  val = PyDict_GetItemString (py_dict, "minimum_rhyme_type");
  if (val != NULL) {
    g_assert (PyInt_Check (val));
    filter->minimum_rhyme_type = PyInt_AsLong (val);
  }

  val = PyDict_GetItemString (py_dict, "rhyme_type_exists");
  if (val != NULL) {
    g_assert (PyInt_Check (val));
    filter->rhyme_type_exists = PyInt_AsLong (val);
  }
}

/*
 * Counterpart of markov_filter_from_py_dict().  It only validates that
 * FILTER is non-NULL; there is currently nothing to release.
 */
void
markov_filter_py_dict_cleanup (MarkovFilter *filter)
{
  g_return_if_fail (filter != NULL);

  /* these days, this does nothing */
}

/* ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** */

/* Python Type Magic */

typedef struct _PyMarkov PyMarkov;
struct _PyMarkov {
  PyObject_HEAD;
  Markov *markov;
};

/*
 * Python method add_text(text): convert the argument with
 * text_from_py() and feed it to markov_add_text().  Returns None, or
 * NULL if the arguments could not be parsed.
 */
static PyObject *
py_markov_add_text (PyObject *self, PyObject *args)
{
  Markov *model = markov_from_py (self);
  PyObject *py_text;

  if (! PyArg_ParseTuple (args, "O", &py_text))
    return NULL;

  markov_add_text (model, text_from_py (py_text));

  Py_INCREF (Py_None);
  return Py_None;
}

/*
 * Token callback: USER_DATA is a Python list, to which the Python
 * wrapper of T is appended.  The reference obtained from token_to_py()
 * is released afterwards.
 */
static void
py_append_token_iter_fn (Token *t, gpointer user_data)
{
  PyObject *list = (PyObject *) user_data;
  PyObject *item = token_to_py (t);

  PyList_Append (list, item);
  Py_DECREF (item);
}

/*
 * Python method choose_many(N [, filter_dict]).
 *
 * Builds a filter from the optional dictionary, collects up to N
 * matching tokens via markov_choose_many(), shuffles them and returns
 * the tuple (choices, basic_choices, filtered_choices), where the two
 * integers are the counters reported by markov_choose_many().
 * Returns NULL with a Python exception set on argument or allocation
 * failure.
 */
static PyObject *
py_markov_choose_many (PyObject *self, PyObject *args)
{
  int i, N;
  PyObject *py_filter = NULL;
  MarkovFilter filter = MARKOV_FILTER_INIT;
  PyObject *py_choices;
  int basic_choices = 0;
  int filtered_choices = 0;

  if (! PyArg_ParseTuple (args, "i|O", &N, &py_filter))
    return NULL;

  if (py_filter != NULL)
    markov_filter_from_py_dict (&filter, py_filter);

  py_choices = PyList_New (0);
  if (py_choices == NULL) {
    markov_filter_py_dict_cleanup (&filter);
    return NULL;
  }

  markov_choose_many (markov_from_py (self),
                      &filter,
                      N,
                      py_append_token_iter_fn, py_choices,
                      &basic_choices, &filtered_choices);

  markov_filter_py_dict_cleanup (&filter);

  /* Scramble our choices.  The swap partner is drawn from i..N-1
     inclusive, so an element may stay where it is; drawing from
     i+1..N-1 would only ever produce cyclic permutations (e.g. two
     choices would always come out reversed). */
  N = PyList_Size (py_choices);
  for (i = 0; i < N-1; ++i) {
    int j = i + markov_random (N-i);
    PyObject *tmp;
    /* The GET/SET macros move the list's own references around
       without touching any reference count. */
    tmp = PyList_GET_ITEM (py_choices, j);
    PyList_SET_ITEM (py_choices, j, PyList_GET_ITEM (py_choices, i));
    PyList_SET_ITEM (py_choices, i, tmp);
  }

  /* "N" hands our reference to the list over to the result tuple. */
  return Py_BuildValue ("Nii", py_choices, basic_choices, filtered_choices);
}

/*
 * Shared implementation of the Python methods choose_next and
 * choose_prev, called as method(token [, filter_dict]).
 *
 * CHOOSE_NEXT selects markov_choose_next() (true) or
 * markov_choose_prev() (false).  When no filter dictionary is given,
 * NULL is passed so the callee uses its own default filter.
 *
 * Returns the tuple (token, basic_choices, filtered_choices): the
 * Python wrapper of the chosen token followed by the two counters in
 * the order the chooser reports them.  Returns NULL if the arguments
 * could not be parsed.
 *
 * NOTE(review): the chooser returns NULL when nothing matches; what
 * token_to_py() does with a NULL token is not visible here -- confirm.
 */
static PyObject *
py_markov_choose_nextprev (PyObject *self, PyObject *args,
                           gboolean choose_next)
{
  Markov *markov = markov_from_py (self);
  PyObject *py_token;
  PyObject *py_filter = NULL;
  MarkovFilter filter = MARKOV_FILTER_INIT;
  Token *t;
  int basic_choices = 0;
  int filtered_choices = 0;
  Token *(*fn) (Markov *, Token *, MarkovFilter *, int *, int *);

  if (! PyArg_ParseTuple (args, "O|O", &py_token, &py_filter))
    return NULL;

  if (py_filter != NULL)
    markov_filter_from_py_dict (&filter, py_filter);

  fn = choose_next ? markov_choose_next : markov_choose_prev;

  t = token_from_py (py_token);
  t = fn (markov, t,
          py_filter ? &filter : NULL,
          &basic_choices, &filtered_choices);

  markov_filter_py_dict_cleanup (&filter);

  /* token_to_py() hands back a reference we own (other callers in
     this file release it after use), so it is passed with "N", which
     transfers it to the tuple.  "O" would add a second reference that
     nobody drops, leaking the wrapper on every call. */
  return Py_BuildValue ("(Nii)", token_to_py (t),
                        basic_choices, filtered_choices);
}

/* Python method choose_next: see py_markov_choose_nextprev(). */
static PyObject *
py_markov_choose_next (PyObject *self, PyObject *args)
{
  return py_markov_choose_nextprev (self, args, TRUE);
}

/* Python method choose_prev: see py_markov_choose_nextprev(). */
static PyObject *
py_markov_choose_prev (PyObject *self, PyObject *args)
{
  return py_markov_choose_nextprev (self, args, FALSE);
}

/*
 * Method table of the Python "Markov" type: name, C implementation,
 * calling convention and docstring of each method.  The all-NULL
 * entry terminates the table.
 */
static PyMethodDef py_markov_methods[] = {
  { "add_text", py_markov_add_text, METH_VARARGS,
    "Extend the Markov model using the contents of a text." },
  { "choose_many", py_markov_choose_many, METH_VARARGS,
    "Choose multiple matching elements from our Markov model." },
  { "choose_next", py_markov_choose_next, METH_VARARGS,
    "Pick a successor using our Markov model." },
  { "choose_prev", py_markov_choose_prev, METH_VARARGS,
    "Pick a predecessor using our Markov model." },
  {NULL, NULL, 0, NULL}
};

/*
 * tp_getattr slot: resolve attribute NAME of OBJ by looking it up in
 * the method table.
 */
static PyObject *
py_markov_getattr(PyObject *obj, char *name)
{
    PyObject *attr;

    attr = Py_FindMethod(py_markov_methods, obj, name);
    return attr;
}

/*
 * tp_dealloc slot: release the memory of the Python wrapper object.
 *
 * NOTE(review): the wrapped Markov model is not freed here; whether
 * that is intentional (the model may be owned elsewhere) cannot be
 * told from this file -- confirm.
 */
static void
py_markov_dealloc(PyObject *self)
{
    PyObject_Del(self);
}

/*
 * Type object of the Python "Markov" type.  Only deallocation and
 * attribute lookup are provided; every other slot listed is NULL and
 * the remaining slots are left zero-initialised.
 *
 * NOTE(review): the object header is initialised with a NULL type
 * pointer and nothing visible here fills it in afterwards -- confirm
 * that the module initialisation code takes care of it.
 */
static PyTypeObject py_markov_type_info = {
  PyObject_HEAD_INIT(NULL)
  0,                /*ob_size*/
  "Markov",         /*tp_name*/
  sizeof(PyMarkov), /*tp_basicsize*/
  0,                /*tp_itemsize*/
  py_markov_dealloc,/*tp_dealloc*/
  NULL,             /*tp_print*/
  py_markov_getattr, /*tp_getattr*/
  NULL,             /*tp_setattr*/
  NULL,             /*tp_compare*/
  NULL,             /*tp_repr*/
  NULL,             /*tp_as_number*/
  NULL,             /*tp_as_sequence*/
  NULL,             /*tp_as_mapping*/
  NULL,             /*tp_hash */
};

/* ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** */

/*
 * Wrap MARKOV in a new Python object of the "Markov" type.
 *
 * Returns a new reference, or NULL (with the Python exception set by
 * the allocator) if the wrapper could not be allocated.  The wrapper
 * only stores the pointer; it does not copy the model.
 */
PyObject *
markov_to_py (Markov *markov)
{
  PyMarkov *py_markov;

  py_markov = PyObject_New(PyMarkov, &py_markov_type_info);
  if (py_markov == NULL)
    return NULL;

  py_markov->markov = markov;
  return (PyObject *) py_markov;
}

/*
 * Return the Markov model stored in the Python wrapper OBJ.  OBJ is
 * cast to PyMarkov without any type check.
 */
Markov *
markov_from_py (PyObject *obj)
{
  PyMarkov *wrapper = (PyMarkov *) obj;

  return wrapper->markov;
}

/* ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** */

/*
 * Python-callable constructor: create a fresh, empty Markov model and
 * return its Python wrapper.  SELF and ARGS are not used.
 */
PyObject *
py_markov_new (PyObject *self, PyObject *args)
{
  Markov *fresh = markov_new ();

  return markov_to_py (fresh);
}
