
import string, sys, random
import fragment

class Template:
    """An ordered list of fragments that is filled in one step at a time.

    Every fragment is expected to provide ``is_complete()``,
    ``is_ready(fragments, index)``, ``get_priority()``,
    ``evolve(markov, fragments, index)`` and ``clone()``.  Fragments that
    are ``fragment.Token`` instances also provide ``get_metric_error()`` and
    ``get_token()``; ``pretty_print`` calls ``get_token()`` on every
    fragment that is not a ``fragment.Break``.
    """

    def __init__(self, fragments=None):
        """Create a template over *fragments* (a new empty list if omitted).

        A supplied list is stored as-is, not copied.
        """
        # A shared ``[]`` default would be aliased by every Template built
        # without arguments; create a fresh list per instance instead.
        if fragments is None:
            fragments = []
        self.__fragments = fragments

    def is_finished(self):
        """Return 1 if every fragment reports ``is_complete()``, else 0."""
        return 1 if all(f.is_complete() for f in self.__fragments) else 0

    def completed_syllables(self):
        """Return the summed ``syllables()`` of the tokens in all Token fragments."""
        # isinstance needs the class itself, not an instance of it.
        return sum(f.get_token().syllables()
                   for f in self.__fragments
                   if isinstance(f, fragment.Token))

    def clone(self):
        """Return a new Template holding a ``clone()`` of each fragment."""
        # Build a real list: on Python 3 ``map`` is a one-shot iterator and
        # would break ``length()``/``get_fragment()`` on the clone.
        return Template([f.clone() for f in self.__fragments])

    def length(self):
        """Return the number of fragments."""
        return len(self.__fragments)

    def get_fragment(self, i):
        """Return the fragment at index *i*."""
        return self.__fragments[i]

    def get_metric_error(self):
        """Return the summed ``get_metric_error()`` of all Token fragments."""
        return sum(f.get_metric_error()
                   for f in self.__fragments
                   if isinstance(f, fragment.Token))

    def get_sentence_lengths(self):
        """Return, for each stop token, the number of words counted before it.

        The count resets to zero at each start token and grows by one for
        every other non-stop token; start/stop tokens are not counted.
        Words seen before the first start token are counted from zero
        rather than raising.
        """
        lengths = []
        count = 0
        for f in self.__fragments:
            if not isinstance(f, fragment.Token):
                continue
            token = f.get_token()
            if token.is_start():
                count = 0
            elif token.is_stop():
                lengths.append(count)
            else:
                count += 1
        return lengths

    def pretty_print(self):
        """Print the template as text, starting a new line at each Break.

        The word after a start token is capitalized, a stop token emits
        ``". "``, each output line is stripped, and a blank line follows.
        """
        pieces = []
        capitalize_next = False
        for frag in self.__fragments:
            if isinstance(frag, fragment.Break):
                # A line break does not reset the pending capitalization.
                pieces.append("\n")
                continue
            token = frag.get_token()
            if token.is_start():
                capitalize_next = True
            elif token.is_stop():
                pieces.append(". ")
                capitalize_next = False
            else:
                word = token.to_string()
                if capitalize_next:
                    word = word.capitalize()
                pieces.append(" " + word)
                capitalize_next = False

        # Single-argument print(...) behaves identically on Python 2 and 3.
        for line in "".join(pieces).split("\n"):
            print(line.strip())
        print("")

    def step(self, markov):
        """Evolve the highest-priority fragment that is ready but incomplete.

        Ties go to the earliest such fragment.  Returns whatever that
        fragment's ``evolve`` returns, or None if no fragment is eligible.
        """
        chosen = -1
        best_priority = 0
        for i, f in enumerate(self.__fragments):
            if f.is_ready(self.__fragments, i) and not f.is_complete():
                priority = f.get_priority()
                if chosen < 0 or priority > best_priority:
                    chosen = i
                    best_priority = priority

        if chosen < 0:
            return None
        return self.__fragments[chosen].evolve(markov, self.__fragments,
                                               chosen)

    def _attempt(self, markov, bound):
        """Run one randomized search; return its history list or None.

        The attempt is abandoned (None) when the current partial result's
        metric error reaches *bound*, or when the very first step fails.
        """
        history = [self.clone()]
        while not history[-1].is_finished():
            # Prune: this branch can no longer beat the best result so far.
            if history[-1].get_metric_error() >= bound:
                return None

            candidate = history[-1].clone()
            if candidate.step(markov):
                history.append(candidate)
            elif len(history) == 1:
                # Nothing earlier to fall back to.
                return None
            else:
                # Backtrack to a random earlier point (always keeping the
                # initial clone) and retry from there.
                history = history[:random.randint(1, len(history) - 1)]
        return history

    def evolve(self, markov, iterations=50, full_history=0):
        """Search for a finished clone of this template with minimal metric error.

        Runs up to *iterations* randomized attempts (see ``_attempt``) and
        stops early once an attempt reaches zero metric error.

        Returns the best finished Template, or the full list of intermediate
        Templates leading to it when *full_history* is true.  Returns None
        when no attempt finished.
        """
        best = None
        best_metric_error = 1000000

        # A counted loop instead of xrange/range: lazy on both Python 2 and 3.
        attempt = 0
        while attempt < iterations:
            attempt += 1
            history = self._attempt(markov, best_metric_error)
            if history is None:
                continue
            error = history[-1].get_metric_error()
            if error < best_metric_error:
                best = history
                best_metric_error = error
                if best_metric_error == 0:
                    break

        if best is None:
            # Previously crashed with TypeError on ``None[-1]``.
            return None
        return best if full_history else best[-1]
                
            

def is_finished(template):
    """Report whether every fragment of *template* is complete.

    Returns 1 when all fragments report ``is_complete()`` (including the
    empty case), otherwise 0; scanning stops at the first incomplete one.
    """
    return 0 if any(not frag.is_complete() for frag in template) else 1

def complete_syllables(template):
    """Return the summed ``syllables()`` of every fragment.Token in *template*.

    Fragments of any other type are ignored.
    """
    tokens = (frag for frag in template if isinstance(frag, fragment.Token))
    return sum(tok.get_token().syllables() for tok in tokens)

def evolve_next(template, markov):
    """Evolve the highest-priority ready, incomplete fragment of *template*.

    Ties go to the earliest fragment.  Returns that fragment's ``evolve``
    result, or 0 when no fragment is eligible.  The falsy 0 lets ``solve``
    backtrack; the previous truthy 1 made it re-append an unchanged clone
    and loop forever.

    NOTE(review): the scan starts at index 1, so ``template[0]`` is never
    chosen (``Template.step`` starts at 0) -- confirm this is intentional.
    """
    chosen = -1
    best_priority = 0
    for i in range(1, len(template)):
        frag = template[i]
        if frag.is_ready(template, i) and not frag.is_complete():
            priority = frag.get_priority()
            if chosen < 0 or priority > best_priority:
                chosen = i
                best_priority = priority
    if chosen < 0:
        return 0
    return template[chosen].evolve(markov, template, chosen)

def clone(template):
    """Return a new list holding ``clone()`` of each fragment in *template*."""
    # A list comprehension, not ``map``: on Python 3 ``map`` is a one-shot
    # iterator, which breaks the repeated iteration and indexing in ``solve``.
    return [frag.clone() for frag in template]

def metric_error(template):
    """Return the summed ``get_metric_error()`` of the fragment.Token entries.

    Fragments of any other type contribute nothing.
    """
    tokens = (frag for frag in template if isinstance(frag, fragment.Token))
    return sum(tok.get_metric_error() for tok in tokens)

def _solve_attempt(template, markov, bound):
    """Run one randomized search from *template*; return a finished list or None.

    Clones and evolves step by step.  On a failed step it backtracks to a
    random earlier point.  It gives up (None) if the first step fails or if
    a partial result's metric error already reaches *bound*.
    """
    history = [template]
    while not is_finished(history[-1]):
        # Prune: this branch can no longer beat the best result so far.
        if metric_error(history[-1]) >= bound:
            return None

        candidate = clone(history[-1])
        if evolve_next(candidate, markov):
            history.append(candidate)
        elif len(history) == 1:
            # Nothing earlier to fall back to.
            return None
        else:
            # Go back to a random point (always keeping the start) and retry.
            history = history[:random.randint(1, len(history) - 1)]
    return history[-1]


def solve(template, markov, iterations=50):
    """Search for a completion of *template* with minimal metric error.

    Runs up to *iterations* randomized attempts and stops once an attempt
    reaches zero error.  It writes a "." to stderr every 1000 attempts and
    prints the best metric error found.  Returns the best finished fragment
    list, or None if no attempt finished.  If *template* is already
    finished, it is returned itself (not a copy).
    """
    best = None
    best_metric_error = 10000000

    # A counted loop instead of xrange/range: lazy on both Python 2 and 3.
    attempt = 0
    while attempt < iterations:
        attempt += 1
        if best_metric_error == 0:
            break
        if attempt % 1000 == 0:
            sys.stderr.write(".")

        result = _solve_attempt(template, markov, best_metric_error)
        if result is None:
            continue
        error = metric_error(result)
        if error < best_metric_error:
            best = result
            best_metric_error = error

    print("(metric error=%f)" % best_metric_error)

    return best


def spew(template):
    """Print a debug dump of *template*, one fragment per line.

    Each line shows the fragment's priority, its completion flag and its
    stripped ``to_string()``; a dashed separator line follows.
    """
    # Single-argument print(...) produces identical output on Python 2 and 3.
    for frag in template:
        print("%4d %d %s" % (frag.get_priority(),
                             frag.is_complete(),
                             frag.to_string().strip()))
    print("- " * 30)

