/*
 * Positive memoization routines.
 * 
 * Copyright 2000 KUN.
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Library General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, write to the Free Software
 *  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
 */

/* $Id: posmemo.c,v 1.7 2001/10/17 10:39:27 ejv Exp $ */

#ifdef HAVE_CONFIG_H
#include <config.h>
#endif /* HAVE_CONFIG_H */

#ifdef HAVE_MALLOC_H
#include <malloc.h>
#else /* HAVE_MALLOC_H */
#include <stdlib.h>
#endif /* HAVE_MALLOC_H */

#include <string.h>

#include "rtscode.h"
#include "rtsutil.h"
#include "rtslex.h"
#include "posmemo.h"

/*
 * Allocate nbytes of raw storage for memo data.
 * Currently a plain malloc(); returns NULL when the allocation fails.
 */
static inline void*
posmemo_getmem(size_t nbytes)
{
    void* block = malloc(nbytes);

    return block;
}

/*
 * Release storage obtained from posmemo_getmem().
 * Passing NULL is harmless (free(NULL) is a no-op).
 */
static inline void
posmemo_freemem(void* block)
{
    free(block);
}

void
posmemo_init()
{
    /* Set up memo memory management.  Nothing to do yet: allocations go
     * directly through posmemo_getmem(). */
    return;
}

void
posmemo_done()
{
    /* Tear down memo memory management.  Nothing to do yet: memory is
     * released per entry through posmemo_freemem(). */
    return;
}

/*
 * Return the address of the memo slot for nonterminal n in state i.
 * The state must not be NULL (checked with assert).
 */
static inline PosMemo*
posmemo_fetch(StateIndicator i, unsigned n)
{
    PosMemo* slots;

    assert(i);
    slots = i->pos_memos;
    return slots + n;
}

/*
 * Initialise a memo slot to the "unknown" state.
 *
 * An unknown slot is encoded as a slot that points at itself.  NULL means
 * "blocked", and anything else is the head of a production list.
 *
 * The original `(void*)*entry = ...` used a cast as an lvalue.  That is not
 * valid C (GCC dropped the extension in 4.0), so convert the right-hand
 * side instead.  The stored pointer is only ever compared by address and
 * never dereferenced.
 */
void
posmemo_init_table_entry(PosMemo* entry)
{
    *entry = (PosMemo)(void*)entry;
}

/*
 * Nonzero when the memo slot for nont_nr in input_state still carries the
 * "unknown" marker, i.e. the slot points at itself.
 */
int
posmemo_is_unknown(StateIndicator input_state, unsigned nont_nr)
{
    PosMemo* slot = posmemo_fetch(input_state, nont_nr);

    if (slot == NULL) {
        return 0;
    }
    return (void*)(*slot) == (void*)slot;
}

/*
 * Nonzero when the memo slot is either blocked (NULL) or holds a
 * production list, i.e. anything other than the self-pointing "unknown"
 * marker.
 */
int
posmemo_is_known(StateIndicator input_state, unsigned nont_nr)
{
    PosMemo* slot = posmemo_fetch(input_state, nont_nr);
    PosMemo content = *slot;

    if (content == NULL) {
        return 1;   /* blocked counts as known */
    }
    return (void*)content != (void*)slot;
}

/*
 * Nonzero when the memo slot has been marked blocked (holds NULL).
 */
int
posmemo_is_blocked(StateIndicator input_state, unsigned nont_nr)
{
    PosMemo head = *posmemo_fetch(input_state, nont_nr);

    return head == NULL ? 1 : 0;
}

/*
 * Nonzero when the memo slot is blocked, or when its first production has
 * a penalty above penlevel.
 *
 * Only the head of the list is inspected, so this relies on the list being
 * kept in ascending penalty order.  The slot must not be unknown; that is
 * checked with assert.
 */
int
posmemo_is_blocked_for_penlevel(StateIndicator input_state, unsigned nont_nr,
                                long penlevel)
{
    PosMemo* x = posmemo_fetch(input_state, nont_nr);
    PosMemo best;

    assert((void*)x != (void*)(*x)); /* if unknown, explode */

    best = *x;
    return (best == NULL) || (best->penalty > penlevel);
}

/*
 * Nonzero when the memo slot holds an actual production list: neither
 * blocked (NULL) nor unknown (self-pointing).
 */
int
posmemo_is_unblocked(StateIndicator input_state, unsigned nont_nr)
{
    PosMemo* slot = posmemo_fetch(input_state, nont_nr);
    PosMemo content = *slot;

    if (content == NULL || (void*)content == (void*)slot) {
        return 0;
    }
    return 1;
}

/*
 * Reset the memo slot for nont_nr in input_state to "unknown" by making it
 * point at itself.
 *
 * The original `(void*)*x = ...` used a cast as an lvalue.  That is invalid
 * C and rejected by modern compilers, so convert the right-hand side
 * instead.
 */
void
posmemo_set_unknown(StateIndicator input_state, unsigned nont_nr)
{
    PosMemo* x = posmemo_fetch(input_state, nont_nr);
    *x = (PosMemo)(void*)x;
}

/*
 * Mark the memo slot for nont_nr in input_state as blocked (NULL).
 * The slot must currently be unknown (self-pointing); checked with assert.
 */
void
posmemo_set_blocked(StateIndicator input_state, unsigned nont_nr)
{
    PosMemo* x;

    x = posmemo_fetch(input_state, nont_nr);
    assert((void*)x == (void*)(*x));
    x[0] = NULL;
}

/*
 * Scan `list` for productions that represent the same parse as `new`: same
 * next_state, same number of formals and byte-identical formals.
 *
 * Returns the matching production with the highest penalty (the first one
 * wins on ties), but only when at least max_parses matches exist.
 * Otherwise returns NULL, meaning there is still room for another parse.
 */
static PosMemo posmemo_get_worst_same_parse(PosMemo list, PosMemo new)
{
    PosMemo worst = NULL;
    long nr_parses = 0;
    PosMemo iter;

    for (iter = list; iter != NULL; iter = iter->next) {
        if ((iter->next_state != new->next_state)
            || (iter->nr_formals != new->nr_formals)) {
            continue;
        }
        /* Productions without formals store NULL there, and memcmp() must not
         * be handed NULL even for a zero length (C11 7.24.1p2). */
        if ((iter->nr_formals != 0)
            && (memcmp(iter->formals, new->formals,
                       sizeof(VALUE) * (iter->nr_formals)) != 0)) {
            continue;
        }

        /* found a matching entry */
        nr_parses++;
        /* No sentinel penalty: matches with penalty <= -1 must be eligible too,
         * otherwise the at-most-k limit is silently bypassed. */
        if ((worst == NULL) || (iter->penalty > worst->penalty)) {
            worst = iter;
        }
    }

    return (nr_parses >= max_parses) ? worst : NULL;
}

/*
 * Insert new_prod into the production list *pms, keeping the list in
 * ascending penalty order.  A new entry goes after existing entries with
 * the same penalty.
 *
 * If an entry with the same penalty, next_state and formals already
 * exists, new_prod is a duplicate: it (and its formals) is freed instead
 * of being linked in.  Either way the list takes ownership of new_prod.
 *
 * This is iterative (the original recursed once per list node), and memcmp()
 * is skipped when there are no formals: those productions store NULL there,
 * and memcmp(NULL, NULL, 0) is undefined behaviour.
 */
static void posmemo_add_sorted(PosMemo* pms, PosMemo new_prod)
{
    PosMemo* link = pms;

    while ((*link != NULL) && ((*link)->penalty <= new_prod->penalty)) {
        PosMemo cur = *link;

        if ((cur->penalty == new_prod->penalty)
            && (cur->nr_formals == new_prod->nr_formals)
            && (cur->next_state == new_prod->next_state)
            && ((new_prod->nr_formals == 0)
                || !memcmp(cur->formals, new_prod->formals,
                           sizeof(VALUE) * (new_prod->nr_formals)))) {
            /* same entries, forget it */
            posmemo_freemem(new_prod->formals);
            posmemo_freemem(new_prod);
            return;
        }
        link = &(cur->next);
    }

    /* *link is either NULL or the first entry with a strictly higher penalty */
    new_prod->next = *link;
    *link = new_prod;
}

/*
 * Add new_prod to the (blocked or non-empty) list *pms.  At most
 * max_parses productions with the same next_state and formals are kept.
 *
 * When that limit is already reached, new_prod is not added.  Instead, the
 * worst matching entry inherits new_prod's penalty if that is lower, and
 * new_prod is freed.
 *
 * The original lowered worst->penalty in place and left the entry where it
 * was.  That breaks the ascending penalty order that posmemo_add_sorted()
 * maintains and that posmemo_is_blocked_for_penlevel() relies on when it
 * inspects only the head.  The entry is therefore relinked at its sorted
 * position.  List membership is unchanged; only the order is repaired.
 */
static void posmemo_add_sorted_atmost_k(PosMemo* pms, PosMemo new_prod)
{
    PosMemo worst;

    assert((void*)*pms != (void*)pms); /* if unknown, explode -> should be blocked */

    if (*pms == NULL) {
        posmemo_add_sorted(pms, new_prod);
        return;
    }

    /* get production number max_parses (or NULL if it doesn't exist) with the
     * same next_state, the same formals and the worst penalty level */
    worst = posmemo_get_worst_same_parse(*pms, new_prod);

    if (!worst) {
        posmemo_add_sorted(pms, new_prod);
        return;
    }

    if (worst->penalty > new_prod->penalty) {
        PosMemo* link;

        /* unlink worst; it is known to be in the list */
        for (link = pms; *link != worst; link = &((*link)->next)) {
            /* advance */
        }
        *link = worst->next;

        worst->penalty = new_prod->penalty;

        /* relink after all entries whose penalty is <= the new one, matching
         * where posmemo_add_sorted() would place it */
        for (link = pms;
             (*link != NULL) && ((*link)->penalty <= worst->penalty);
             link = &((*link)->next)) {
            /* advance */
        }
        worst->next = *link;
        *link = worst;
    }

    posmemo_freemem(new_prod->formals);
    posmemo_freemem(new_prod);
}

/*
 * Record a production for nonterminal nont_nr starting at input_state.
 *
 * The production ends at target_state and carries the given penalty and
 * nr_formals formal values.  The formals are copied into a fresh buffer in
 * reverse order.  The new record is then handed to
 * posmemo_add_sorted_atmost_k(), which either links it into the slot's
 * list or frees it.
 */
void posmemo_add_production(StateIndicator input_state, unsigned nont_nr,
                            long penalty, unsigned nr_formals,
                            VALUE* formals, StateIndicator target_state)
{
    PosMemo* slot = posmemo_fetch(input_state, nont_nr);
    PMPROD* prod = posmemo_getmem(sizeof(PMPROD));
    unsigned k;

    prod->nont_nr = nont_nr;
    prod->nr_formals = nr_formals;
    prod->formals = nr_formals
                    ? (VALUE*)posmemo_getmem(nr_formals * sizeof(VALUE))
                    : NULL;

    /* store the formals back-to-front */
    for (k = 0; k < nr_formals; k++) {
        prod->formals[nr_formals - 1 - k] = formals[k];
    }

    prod->penalty = penalty;
    prod->failcont = NULL;
    prod->next_state = target_state;

    posmemo_add_sorted_atmost_k(slot, prod);
}

/*
 * Free every production (and its formals) in the list held by *entry.
 *
 * Blocked (NULL) and unknown (self-pointing) slots own nothing and are
 * left alone.  *entry itself is not reset.
 */
void
posmemo_free_vec(PosMemo* entry)
{
    PMPROD* prod;

    if ((*entry == NULL) || ((void*)*entry == (void*)entry)) {
        return;
    }

    for (prod = *entry; prod != NULL; ) {
        PMPROD* following = prod->next;

        if (prod->nr_formals) {
            posmemo_freemem(prod->formals);
        }
        posmemo_freemem(prod);
        prod = following;
    }
}

/*
 * Return the number of productions memoized for nont_nr at input_state.
 *
 * Blocked (NULL) slots yield 0.  Unknown slots also yield 0: they point at
 * themselves, and walking them as a list would reinterpret the slot array
 * as a PMPROD and read unrelated memory.
 */
int posmemo_count_prod(StateIndicator input_state, unsigned nont_nr)
{
    PosMemo* slot = posmemo_fetch(input_state, nont_nr);
    PMPROD* prod;
    int count = 0;

    if ((void*)*slot == (void*)slot) {
        return 0;
    }

    for (prod = *slot; prod != NULL; prod = prod->next) {
        count++;
    }

    return count;
}

/* Store pc as the failure continuation of production prod (must be non-NULL). */
void posmemo_set_failcont(PosMemo prod, void* pc)
{
    assert(prod);
    (*prod).failcont = pc;
}

/* Return the failure continuation stored in production prod (must be non-NULL). */
void* posmemo_get_failcont(PosMemo prod)
{
    assert(prod);
    return (*prod).failcont;
}

/*
 * Return the raw content of the memo slot for nont_nr at input_state:
 * NULL when blocked, the slot's own address when unknown, and otherwise
 * the first production of the list.
 */
PosMemo posmemo_get_prod_ptr(StateIndicator input_state, unsigned nont_nr)
{
    PosMemo* slot = posmemo_fetch(input_state, nont_nr);

    return slot[0];
}

/* Return the formals buffer of a production (NULL when it has none). */
void* posmemo_get_formal_ptr(PosMemo state)
{
    assert(state);
    return (void*)(*state).formals;
}

/* Return the penalty recorded for a production (must be non-NULL). */
long posmemo_get_penalty(PosMemo state)
{
    assert(state);
    return (*state).penalty;
}

/* Return the production following curr in its list (NULL at the end). */
PosMemo posmemo_get_next_prod(PosMemo curr)
{
    assert(curr);
    return (*curr).next;
}

/*
 * Return the state at which production curr ends (its next_state field),
 * i.e. the input state from which parsing continues after it.
 */
StateIndicator posmemo_get_input_state(PosMemo curr)
{
    assert(curr);
    return (*curr).next_state;
}

/*
 * Debug helper: print one memoized production, tagged with node_nr.
 *
 * The original printed node_nr in the "nont_nr" position, nont_nr in the
 * "pen" position, and never printed the penalty.  Its specifiers did not
 * match the argument types either.  The casts below pin each argument to
 * the type its specifier expects.
 */
void
posmemo_dump_pmprod(int node_nr, PMPROD* pmprod)
{
    unsigned i;

    printf("%d: ", node_nr);
    printf("nont_nr = %u, pen = %ld, nr_formals = %u\n",
           (unsigned)pmprod->nont_nr, (long)pmprod->penalty,
           (unsigned)pmprod->nr_formals);
    for (i = 0; i < pmprod->nr_formals; ++i) {
        printf("\tformal %u set value: %lu\n", i, (pmprod->formals)[i].set_par);
    }
}

/*
 * Debug helper: print an overview of the positive memo table of a trellis.
 *
 * Rows are syntax nonterminals (numbered from 1), columns are trellis
 * positions.  A cell shows 'b' for a blocked entry, 'u' for an unknown
 * entry, or the number of memoized productions.
 *
 * Rows and columns that contain no blocked entry and no non-zero count are
 * suppressed.  Each row ends with the nonterminal's name (or '?').
 *
 * Changes from the original:
 * - empty_node is now freed (it was leaked);
 * - the nonterminal count is fetched once instead of in every loop test;
 * - commented-out debug code is removed.
 */
void
posmemo_dump_table(Trellis* trellis)
{
    StateNode** state_row = GET_TRELLIS_STATE_ROW(trellis);
    int nr_nonterminals = get_nr_syntax_nonterminals();
    int node_nr;
    int rule_nr;
    gboolean* empty_rule;
    gboolean* empty_node;

    int** overview = (int**) GetMem(sizeof(int*) * trellis->length, "overview[]");

    /* Build the table: -2 = blocked, -1 = unknown, >= 0 = #productions.
     * Nodes without a state or without memos get a NULL row. */
    for (node_nr = 0; node_nr < trellis->length; node_nr++) {
        StateNode* state = *state_row++;

        overview[node_nr] = NULL;
        if (!state || !state->pos_memos) {
            continue;
        }

        overview[node_nr] = (int*) GetMem(sizeof(int) * nr_nonterminals, "overview[][]");

        for (rule_nr = 1; rule_nr < nr_nonterminals; rule_nr++) {
            if (posmemo_is_blocked(state, rule_nr)) {
                overview[node_nr][rule_nr] = -2;
            } else if (posmemo_is_unknown(state, rule_nr)) {
                overview[node_nr][rule_nr] = -1;
            } else {
                PosMemo plijst = posmemo_get_prod_ptr(state, rule_nr);
                int nr_ptrs = 0;

                while (plijst) {
                    nr_ptrs++;
                    plijst = plijst->next;
                }

                overview[node_nr][rule_nr] = nr_ptrs;
            }
        }
    }

    /* printed table compression: a rule is empty when no node has it
     * blocked or with a non-zero production count */
    empty_rule = (gboolean*) GetMem(sizeof(gboolean) * nr_nonterminals, "empty_rule");
    for (rule_nr = 1; rule_nr < nr_nonterminals; rule_nr++) {
        empty_rule[rule_nr] = TRUE;
        node_nr = 0;

        while ((node_nr < trellis->length) && (empty_rule[rule_nr])) {
            if (overview[node_nr]) {
                switch (overview[node_nr][rule_nr]) {
                    case -1:
                        break;
                    case -2:
                        empty_rule[rule_nr] = FALSE;
                        break;
                    default:
                        empty_rule[rule_nr] = !overview[node_nr][rule_nr];
                }
            }

            node_nr++;
        }
    }

    /* likewise for nodes; a node without a table row is always empty */
    empty_node = (gboolean*) GetMem(sizeof(gboolean) * trellis->length, "empty_node");
    for (node_nr = 0; node_nr < trellis->length; node_nr++) {
        empty_node[node_nr] = TRUE;
        rule_nr = 1;

        while ((rule_nr < nr_nonterminals)
               && (empty_node[node_nr])
               && (overview[node_nr])) {
            switch (overview[node_nr][rule_nr]) {
                case -1:
                    break;
                case -2:
                    empty_node[node_nr] = FALSE;
                    break;
                default:
                    empty_node[node_nr] = !overview[node_nr][rule_nr];
            }

            rule_nr++;
        }
    }

    /* actually show it: */
    /* first the table */
    for (rule_nr = 1; rule_nr < nr_nonterminals; rule_nr++) {
        if (!empty_rule[rule_nr]) {
            printf("%3d|", rule_nr);

            for (node_nr = 0; node_nr < trellis->length; node_nr++) {
                if (!empty_node[node_nr]) {
                    switch (overview[node_nr][rule_nr]) {
                        case -1:
                            printf("   u");
                            break;
                        case -2:
                            printf("   b");
                            break;
                        default:
                            printf(" %3d", overview[node_nr][rule_nr]);
                    }
                }
            }

            if (nonterm_names[rule_nr]) {
                printf(" | %s\n", nonterm_names[rule_nr]);
            } else {
                printf(" | ?\n");
            }
        }
    }
    /* then a neat line below it */
    printf("---+");
    for (node_nr = 0; node_nr < trellis->length; node_nr++) {
        if (!empty_node[node_nr]) {
            printf("----");
        }
    }
    /* and of course the numbers */
    printf("\n   |");
    for (node_nr = 0; node_nr < trellis->length; node_nr++) {
        if (!empty_node[node_nr]) {
            printf(" %3d", node_nr);
        }
    }
    printf("\n");

    /* free the space: */
    for (node_nr = 0; node_nr < trellis->length; node_nr++) {
        if (overview[node_nr]) {
            FreeMem(overview[node_nr], "overview[][]");
        }
    }
    FreeMem(overview, "overview[]");
    FreeMem(empty_rule, "empty_rule");
    FreeMem(empty_node, "empty_node");
}
