/* This is -*- C -*- */
/* vim: set sw=2: */
/* $Id$ */

/*
 * dictionary.c
 *
 * Copyright (C) 2003 The Free Software Foundation, Inc.
 *
 */

/*
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 2 of the
 * License, or (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
 * USA.
 */

#ifdef CONFIG_H
#include <config.h>
#endif
#include "dictionary.h"

#include <stdio.h>
#include <ctype.h>
#include <string.h>

/* Dictionary data files, as paths relative to a dict/ directory.  The
   loaders below try a few candidate dict/ directories (e.g. "../dict/")
   when no explicit filename is given. */
#define PRONUNCIATION_DICTNAME "cmudict/cmudict.0.6"
#define PARTOFSPEECH_DICTNAME  "pos/part-of-speech.txt"

/* Node of a phoneme trie used for rhyme lookup.  rhyme_index_word() walks a
   word's decomposition from index 0 up to and including its first stressed
   phoneme, creating nodes as needed, and stores the word in the list of the
   node where the walk stopped.
   NOTE(review): dictionary_get_word() synthesizes "-er" words by putting the
   ER phoneme at index 0, which suggests decompositions are stored
   last-phoneme-first, i.e. the walk starts at the word's tail -- confirm
   against the phoneme module. */
typedef struct _RhymeTreeNode RhymeTreeNode;
struct _RhymeTreeNode {
  PhonemeCode code;                     /* code of the branch leading here; 0 for the
                                           roots and for collapsed vowels in the
                                           slanted trie */
  RhymeTreeNode *branch[PHONEME_LAST];  /* children, indexed by PhonemeCode */
  GSList *list;                         /* DictionaryWord * whose walk stopped here */
};

/* Exact trie: one branch per phoneme code.  Created lazily by
   rhyme_index_word(). */
static RhymeTreeNode *rhyme_tree_root = NULL;
/* "Slanted" trie: like the exact one, except every vowel code is collapsed
   onto branch 0, so paths that differ only in their vowels coincide.  This
   is the trie searched by dictionary_foreach_by_tail(). */
static RhymeTreeNode *rhyme_tree_slanted_root = NULL;

/* Allocate a rhyme-trie node labelled with CODE.  The allocation is zeroed,
   so all child branches and the word list start out NULL. */
static RhymeTreeNode *
node_new (PhonemeCode code)
{
  RhymeTreeNode *fresh;

  fresh = g_new0 (RhymeTreeNode, 1);
  fresh->code = code;

  return fresh;
}

/* Insert DWORD into both rhyme tries.
 *
 * The decomposition is walked from index 0 up to and including the first
 * stressed phoneme.  The walk follows (and creates) the matching path in the
 * exact trie and, with vowels collapsed onto branch 0, in the slanted trie.
 * DWORD is then prepended to the word list of the final node in each trie.
 * Words without a decomposition are ignored.  Words without any stressed
 * phoneme are not added to any list, although the path nodes created during
 * the walk remain.
 */
static void
rhyme_index_word (DictionaryWord *dword)
{
  RhymeTreeNode *exact, *slanted;
  int pos;

  if (dword->decomp == NULL)
    return;

  /* Lazily create the two roots on first use. */
  if (rhyme_tree_root == NULL)
    rhyme_tree_root = node_new (0);
  if (rhyme_tree_slanted_root == NULL)
    rhyme_tree_slanted_root = node_new (0);

  exact = rhyme_tree_root;
  slanted = rhyme_tree_slanted_root;

  for (pos = 0; dword->decomp[pos] != 0; ++pos) {
    Phoneme phon = dword->decomp[pos];
    PhonemeCode code = PHONEME_TO_CODE (phon);
    PhonemeCode scode = PHONEME_IS_VOWEL (code) ? 0 : code;

    if (exact->branch[code] == NULL)
      exact->branch[code] = node_new (code);
    exact = exact->branch[code];

    if (slanted->branch[scode] == NULL)
      slanted->branch[scode] = node_new (scode);
    slanted = slanted->branch[scode];

    /* The first stressed phoneme ends the walk: record the word here. */
    if (PHONEME_IS_STRESSED (phon)) {
      exact->list = g_slist_prepend (exact->list, dword);
      slanted->list = g_slist_prepend (slanted->list, dword);
      return;
    }
  }
}

/* ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** */

/* All known words: maps the lower-cased spelling (the entry's own ->word
   string, which doubles as the key) to its DictionaryWord.  Created on first
   use by dictionary_add_word().  dictionary_get_word() treats NULL as "not
   loaded yet".  Nothing in this file frees the table or its entries. */
static GHashTable *dictionary = NULL;

/* Add WORD (stored lower-cased) to the dictionary, or merge into the
 * existing entry.
 *
 * DECOMP: phoneme decomposition or NULL.  Ownership passes to this
 * function.  It is kept if the entry had no decomposition yet (and the word
 * is then rhyme-indexed); otherwise it is freed.
 * POS_MASK: OR-ed into the entry's part-of-speech mask.
 *
 * Returns the dictionary-owned entry, or NULL if WORD is NULL.
 */
DictionaryWord *
dictionary_add_word (const char      *word,
		     Phoneme         *decomp,
		     PartOfSpeechMask pos_mask)
{
  DictionaryWord *entry;
  char *key;

  g_return_val_if_fail (word != NULL, NULL);

  if (dictionary == NULL)
    dictionary = g_hash_table_new (g_str_hash, g_str_equal);

  key = g_ascii_strdown (word, -1);
  entry = g_hash_table_lookup (dictionary, key);

  if (entry == NULL) {
    /* New word: the entry takes ownership of both KEY and DECOMP. */
    entry = g_new (DictionaryWord, 1);
    entry->word     = key;
    entry->decomp   = decomp;
    entry->pos_mask = pos_mask;

    g_hash_table_insert (dictionary, entry->word, entry);
    rhyme_index_word (entry);
    return entry;
  }

  /* Known word: the freshly built lookup key is no longer needed. */
  g_free (key);
  entry->pos_mask |= pos_mask;

  if (entry->decomp != NULL) {
    /* Keep the first pronunciation seen; drop the extra one. */
    g_free (decomp);
  } else if (decomp != NULL) {
    entry->decomp = decomp;
    rhyme_index_word (entry);
  }

  return entry;
}

/* Look up WORD, case-insensitively and ignoring surrounding whitespace.
 *
 * If no dictionary exists yet, the default pronunciation and part-of-speech
 * dictionaries are loaded first.
 *
 * If WORD is not found, is longer than four characters and ends in "er",
 * its stem (WORD without the "er") is looked up, recursively.  When the stem
 * has a decomposition, an entry for WORD is synthesized from it and added to
 * the dictionary.
 *
 * Returns the dictionary-owned entry, or NULL if there is none and none
 * could be synthesized (or no dictionary could be loaded).
 */
DictionaryWord *
dictionary_get_word (const char *word)
{
  char *downword;
  DictionaryWord *dword;
  size_t len;

  g_return_val_if_fail (word != NULL, NULL);

  if (dictionary == NULL) {
    dictionary_load_pronunciation (NULL); /* try default */
    dictionary_load_part_of_speech (NULL);
    if (dictionary == NULL)
      return NULL; /* nothing could be loaded; avoid looking up in NULL */
  }

  downword = g_ascii_strdown (word, -1);
  g_strstrip (downword);
  dword = g_hash_table_lookup (dictionary, downword);

  len = strlen (downword);

  /* If we have XXX, synthesize XXXer */
  if (dword == NULL && len > 4 && strcmp (downword + len - 2, "er") == 0) {
    char *stem = g_strndup (downword, len - 2);
    /* Separate variable: a stem without a decomposition must not be
       returned as the entry for WORD itself. */
    DictionaryWord *base = dictionary_get_word (stem);

    if (base != NULL && base->decomp != NULL) {
      Phoneme *newdecomp = g_new (Phoneme,
				  phoneme_decomp_length (base->decomp) + 2);
      int i;

      /* NOTE(review): the suffix goes at index 0, which implies
	 decompositions are stored last-phoneme-first -- confirm. */
      newdecomp[0] = PHONEME_JOIN (PHONEME_ER, PHONEME_NO_STRESS);
      for (i = 0; base->decomp[i]; ++i)
	newdecomp[i+1] = base->decomp[i];
      newdecomp[i+1] = 0;

      /* Register under the normalized (stripped) key so later lookups,
	 which strip whitespace too, find the synthesized entry. */
      dword = dictionary_add_word (downword, newdecomp, 0);
    }
    g_free (stem);
  }

  g_free (downword);
  return dword;
}

/* Normalize whitespace in STR in place.
 *
 * Every run of whitespace characters (as classified by isspace()) becomes a
 * single ' '.  Leading and trailing whitespace is removed entirely.  For
 * example, " a \t b\n" becomes "a b".  dictionary_get_decomp() relies on
 * this before splitting on " ", since uncollapsed runs would yield empty
 * tokens.
 */
static void
remove_excess_whitespace (char *str)
{
  char *r, *w = str;

  for (r = str; *r; ++r) {
    /* ctype functions need a value representable as unsigned char. */
    if (isspace ((unsigned char) *r)) {
      /* Emit one ' ' per run, and none before the first word.  Only our
	 own ' ' can precede W here, since other whitespace is never copied. */
      if (w > str && w[-1] != ' ')
	*w++ = ' ';
    } else {
      *w++ = *r;
    }
  }

  /* At most a single trailing ' ' can remain; drop it. */
  if (w > str && w[-1] == ' ')
    --w;
  *w = '\0';
}

/* Replace, in place, every character of STR for which isalpha() is false
   with a space.  The result is then passed to remove_excess_whitespace(). */
static void
non_alpha_to_whitespace (char *str)
{
  char *p;

  for (p = str; *p; ++p) {
    /* ctype functions need a value representable as unsigned char;
       passing a negative plain char is undefined behavior. */
    if (! isalpha ((unsigned char) *p))
      *p = ' ';
  }
  remove_excess_whitespace (str);
}

/* Return TRUE if STR contains at least one character for which isalpha()
   is false (spaces, digits, punctuation, ...).  Return FALSE otherwise,
   including for the empty string.  STR is not modified. */
static gboolean
contains_non_alpha (const char *str)
{
  for (; *str; ++str) {
    /* ctype functions need a value representable as unsigned char;
       passing a negative plain char is undefined behavior. */
    if (! isalpha ((unsigned char) *str))
      return TRUE;
  }
  return FALSE;
}

Phoneme *
dictionary_get_decomp (const char *phrase_in)
{
  char *phrase;
  char **phrasev = NULL;
  int i;
  GSList *decomps = NULL;
  Phoneme *decomp = NULL;

  if (! (phrase_in && *phrase_in))
    return NULL;

  phrase = g_strdup (phrase_in);
  remove_excess_whitespace (phrase);
  phrasev = g_strsplit (phrase, " ", -1);

  for (i = 0; phrasev[i]; ++i) {
    DictionaryWord *dword;
    Phoneme *this_decomp = NULL;

    dword = dictionary_get_word (phrasev[i]);
    if (dword != NULL) {
      this_decomp = dword->decomp;
    } else if (contains_non_alpha (phrasev[i])) {
      non_alpha_to_whitespace (phrasev[i]);
      this_decomp = dictionary_get_decomp (phrasev[i]);
    }
    
    if (this_decomp == NULL)
      goto finished;
    
    decomps = g_slist_prepend (decomps, this_decomp);
  }

  if (g_slist_length (decomps) == 1) {

    decomp = (Phoneme *) decomps->data;

  } else if (g_slist_length (decomps) > 1) {
    int len = 0, i = 0, j = 0;
    GSList *iter;

    for (iter = decomps; iter != NULL; iter = iter->next)
      len += phoneme_decomp_length (iter->data);

    decomp = g_new0 (Phoneme, len+1);
    i = 0;

    for (iter = decomps; iter != NULL; iter = iter->next) {
      Phoneme *d = iter->data;
      for (j = 0; d[j]; ++j) {
	decomp[i] = d[j];
	++i;
      }
    }
  }

 finished:
  g_free (phrase);
  g_strfreev (phrasev);
  g_slist_free (decomps);

  return decomp;
}

/* ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** */

/* Load a pronunciation dictionary and add every entry via
 * dictionary_add_word().
 *
 * Each line holds a word, whitespace, then the phoneme string parsed by
 * phoneme_decomp_from_string().  "##" starts a comment, and blank lines are
 * ignored.
 *
 * If FILENAME is NULL, a few default locations are tried, and g_assert
 * fires if none can be opened.  An explicit FILENAME that cannot be opened
 * is reported via g_return_if_fail.
 */
void
dictionary_load_pronunciation (const char *filename)
{
  FILE *in;
  char buffer[256];
  char *comment, *word, *t;

  /* If the filename is NULL, try the default positions */
  if (filename == NULL) {
    in = NULL;
#ifdef DICTPATH
    in = fopen (DICTPATH "/" PRONUNCIATION_DICTNAME, "r");
#endif
    if (in == NULL)
      in = fopen ("../dict/" PRONUNCIATION_DICTNAME, "r");
    if (in == NULL)
      in = fopen ("../../../dict/" PRONUNCIATION_DICTNAME, "r");
    g_assert (in != NULL);
  } else {
    in = fopen (filename, "r");
  }

  g_return_if_fail (in != NULL);

  while (fgets (buffer, sizeof buffer, in)) {

    /* Elide comments (##) */
    comment = strstr (buffer, "##");
    if (comment != NULL)
      *comment = '\0';

    /* Skip leading whitespace, looking for word.  ctype functions need
       a value representable as unsigned char. */
    for (word = buffer; *word && isspace ((unsigned char) *word); ++word)
      ;
    if (*word == '\0')
      continue; /* blank or comment-only line */

    /* Walk to first whitespace */
    for (t = word; *t && ! isspace ((unsigned char) *t); ++t)
      ;

    /* Clobber whitespace (terminating the word), looking for decomp */
    for (; *t && isspace ((unsigned char) *t); ++t)
      *t = '\0';

    if (*t)
      dictionary_add_word (word, phoneme_decomp_from_string (t), 0);
  }

  fclose (in);
}

/* Load a part-of-speech dictionary and merge each entry's mask into the
 * dictionary via dictionary_add_word().
 *
 * Each line is WORD, a tab, then POS code characters mapped by
 * pos_from_char().  '|' and whitespace among the codes are ignored.  Lines
 * without a tab and unknown code characters produce warnings.
 *
 * If FILENAME is NULL, a few default locations are tried.  A file that
 * cannot be opened is reported via g_return_if_fail.
 */
void
dictionary_load_part_of_speech (const char *filename)
{
  FILE *in;
  char buffer[256];
  char *s;
  PartOfSpeechMask pos_mask;

  /* If the filename is NULL, try the default positions */
  if (filename == NULL) {
    in = NULL;
    /* Was "#ifdef DICTNAME", which is never defined, so the installed
       location was never tried. */
#ifdef DICTPATH
    in = fopen (DICTPATH "/" PARTOFSPEECH_DICTNAME, "r");
#endif
    if (in == NULL)
      in = fopen ("../dict/" PARTOFSPEECH_DICTNAME, "r");
    if (in == NULL)
      in = fopen ("../../../dict/" PARTOFSPEECH_DICTNAME, "r");
  } else {
    in = fopen (filename, "r");
  }

  g_return_if_fail (in != NULL);

  while (fgets (buffer, sizeof buffer, in)) {
    s = strchr (buffer, '\t');
    if (s == NULL) {
      g_warning ("No tab in line: %s", buffer);
      continue;
    }

    *s++ = '\0'; /* terminate the word; S now points at the POS codes */
    pos_mask = 0;

    for (; *s; ++s) {
      PartOfSpeech pos;

      /* ctype functions need a value representable as unsigned char. */
      if (*s == '|' || isspace ((unsigned char) *s))
	continue;

      pos = pos_from_char (*s);
      if (pos == POS_UNKNOWN) {
	g_warning ("Unknown POS char: (%d) '%c' %s", *s, *s, buffer);
      } else {
	pos_mask |= pos_get_mask(pos);
      }
    }

    dictionary_add_word (buffer, NULL, pos_mask);
  }

  /* The original never closed the file, leaking the handle. */
  fclose (in);
}


/* ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** */


/* Call FN (with USER_DATA) for every word filed in the slanted rhyme trie
 * under the same path as DECOMP.
 *
 * DECOMP is walked from index 0, with vowels collapsed onto branch 0, up to
 * and including its first stressed phoneme.  The words stored at the node
 * reached are then visited.
 *
 * Nothing is called if the trie is empty, if the path does not exist, or if
 * DECOMP has no stressed phoneme.
 */
void
dictionary_foreach_by_tail (Phoneme     *decomp,
			    DictionaryFn fn,
			    gpointer     user_data)
{
  RhymeTreeNode *cursor, *hit = NULL;
  GSList *entry;
  int k;

  g_return_if_fail (decomp != NULL);
  g_return_if_fail (fn != NULL);

  cursor = rhyme_tree_slanted_root;

  for (k = 0; hit == NULL && decomp[k] && cursor; ++k) {
    Phoneme phon = decomp[k];
    PhonemeCode code = PHONEME_TO_CODE (phon);

    cursor = cursor->branch[PHONEME_IS_VOWEL (code) ? 0 : code];

    if (cursor != NULL && PHONEME_IS_STRESSED (phon))
      hit = cursor;
  }

  if (hit == NULL)
    return;

  for (entry = hit->list; entry != NULL; entry = entry->next)
    fn ((DictionaryWord *) entry->data, user_data);
}

/* ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** ** */

/* Python: load(filename) -- load a pronunciation dictionary from FILENAME.
   Returns None; on an argument error, returns NULL with the exception set
   by PyArg_ParseTuple. */
PyObject *
py_dictionary_load (PyObject *self, PyObject *args)
{
  char *path;

  if (PyArg_ParseTuple (args, "s", &path)) {
    dictionary_load_pronunciation (path);
    Py_INCREF (Py_None);
    return Py_None;
  }

  return NULL;
}

/* Python: get_word(word) -- return the word's phoneme decomposition as
   converted by phoneme_decomp_to_py(), or None if the word is unknown or
   has no decomposition. */
PyObject *
py_dictionary_get_word (PyObject *self, PyObject *args)
{
  char *text;
  DictionaryWord *entry;

  if (! PyArg_ParseTuple (args, "s", &text))
    return NULL;

  entry = dictionary_get_word (text);

  if (entry == NULL || entry->decomp == NULL) {
    Py_INCREF (Py_None);
    return Py_None;
  }

  return phoneme_decomp_to_py (entry->decomp);
}

/* Python: get_pos_mask(word) -- return the word's part-of-speech mask as an
   int, or None if the word is unknown. */
PyObject *
py_dictionary_get_pos_mask (PyObject *self, PyObject *args)
{
  char *text;
  DictionaryWord *entry;

  if (! PyArg_ParseTuple (args, "s", &text))
    return NULL;

  entry = dictionary_get_word (text);

  if (entry == NULL) {
    Py_INCREF (Py_None);
    return Py_None;
  }

  return Py_BuildValue ("i", entry->pos_mask);
}
