Exemplo n.º 1
0
def accept_correction( meaning, correction, update_func='geometric', update_scale=10 ):
    """Shift counts away from a generated sentence and towards its correction.

    Every production and word recorded in `meaning` (the sentence that was
    generated) has its count decremented, and every production and word
    obtained by scoring `correction` against the same landmark/relation has
    its count incremented by the same magnitude.

    Arguments:
        meaning      -- object whose `args` attribute unpacks into the 22
                        values named in the tuple assignment below
        correction   -- corrected sentence, passed to
                        get_sentence_meaning_likelihood together with the
                        landmark and relation taken from `meaning`
        update_func  -- key into `update_funcs` selecting the function that
                        computes the update magnitude
        update_scale -- factor the computed update is multiplied by

    Returns None; the effect is the calls to update_expansion_counts and
    update_word_counts.
    """
    (lmk, lmk_prob, lmk_ent,
     rel, rel_prob, rel_ent,
     rel_exp_chain, rele_prob_chain, rele_ent_chain, rel_terminals, rel_landmarks,
     lmk_exp_chain, lmke_prob_chain, lmke_ent_chain, lmk_terminals, lmk_landmarks,
     rel_words, relw_prob, relw_ent,
     lmk_words, lmkw_prob, lmkw_ent) = meaning.args

    old_meaning_prob, old_meaning_entropy, lrpc, tps = get_sentence_meaning_likelihood( correction, lmk, rel )

    update = update_funcs[update_func](lmk_prob * rel_prob, old_meaning_prob, lmk_ent + rel_ent, old_meaning_entropy) * update_scale

    logger('Update functions is %s and update value is: %f' % (update_func, update))

    dec_update = -update

    # penalize the productions that generated the relation phrase
    for lhs,rhs,parent,_ in rel_exp_chain:
        update_expansion_counts( dec_update, lhs, rhs, parent, rel=rel )

    # penalize the productions that generated the landmark phrase; each chain
    # entry carries its own landmark, kept in a separate name so the meaning's
    # top-level `lmk` is not clobbered
    for lhs,rhs,parent,exp_lmk in lmk_exp_chain:
        update_expansion_counts( dec_update, lhs, rhs, parent, lmk_class=(exp_lmk.object_class if exp_lmk else None),
                                                               lmk_ori_rels=get_lmk_ori_rels_str(exp_lmk),
                                                               lmk_color=(exp_lmk.color if exp_lmk else None) )

    # penalize the words of the relation phrase
    for term,word in zip(rel_terminals,rel_words):
        update_word_counts( dec_update, term, word, rel=rel )

    # penalize the words of the landmark phrase
    for term,word,word_lmk in zip(lmk_terminals,lmk_words,lmk_landmarks):
        # the landmark may be None; guard object_class the same way color is
        # guarded instead of raising AttributeError
        update_word_counts( dec_update, term, word, lmk_class=(word_lmk.object_class if word_lmk else None),
                                                    lmk_ori_rels=get_lmk_ori_rels_str(word_lmk),
                                                    lmk_color=(word_lmk.color if word_lmk else None) )

    # reward new words with old meaning
    for lhs,rhs,parent,new_lmk,new_rel in lrpc:
        update_expansion_counts( update, lhs, rhs, parent, rel=new_rel,
                                                           lmk_class=(new_lmk.object_class if new_lmk else None),
                                                           lmk_ori_rels=get_lmk_ori_rels_str(new_lmk),
                                                           lmk_color=(new_lmk.color if new_lmk else None) )

    for lhs,rhs,new_lmk,new_rel in tps:
        update_word_counts( update, lhs, rhs, lmk_class=(new_lmk.object_class if new_lmk else None),
                                              rel=new_rel,
                                              lmk_ori_rels=get_lmk_ori_rels_str(new_lmk),
                                              lmk_color=(new_lmk.color if new_lmk else None) )
Exemplo n.º 2
0
def get_words(terminals, landmarks, rel=None):
    """Sample one word for every symbol in `terminals`.

    `terminals` and `landmarks` are walked in lockstep. For each part-of-speech
    symbol the word counts matching the landmark/relation context are fetched
    through CWord.get_word_counts, normalized, and a word is drawn with
    categorical_sample.

    A symbol that cannot be turned into a word (an unexpanded nonterminal, or a
    terminal with no counts in this context) is passed through unchanged with
    probability 1 and entropy 0, so the returned `words` always has one entry
    per (terminal, landmark) pair.

    Arguments:
        terminals -- sequence of symbols to realize (not modified)
        landmarks -- sequence of landmarks (or None), parallel to `terminals`
        rel       -- relation object, or None

    Returns:
        (words, p, H) -- the chosen words, the product of their probabilities
        and the sum of their entropies.
    """
    words = []
    probs = []
    entropy = []

    # the relation context is the same for every terminal
    rel_class = rel_type(rel)
    has_measurement = hasattr(rel, 'measurement')
    dist_class = (rel.measurement.best_distance_class if has_measurement else None)
    deg_class = (rel.measurement.best_degree_class if has_measurement else None)

    for n,lmk in zip(terminals, landmarks):
        # if we could not get an expansion for the LHS, we just pass down the unexpanded nonterminal symbol
        # it gets the probability of 1 and entropy of 0
        if n in NONTERMINALS:
            words.append(n)
            probs.append(1.0)
            entropy.append(0.0)
            continue

        lmk_class = (lmk.object_class if lmk else None)
        lmk_color = (lmk.color if lmk else None)

        cp_db = CWord.get_word_counts(pos=n,
                                      lmk_class=lmk_class,
                                      lmk_ori_rels=get_lmk_ori_rels_str(lmk),
                                      lmk_color=lmk_color,
                                      rel=rel_class,
                                      rel_dist_class=dist_class,
                                      rel_deg_class=deg_class)

        if cp_db.count() <= 0:
            logger( 'Could not expand %s (lmk_class: %s, lmk_color: %s, rel: %s, dist_class: %s, deg_class: %s)' % (n, lmk_class, lmk_color, rel_class, dist_class, deg_class) )
            # Previously this appended `n` to the caller's `terminals` list and
            # produced no word, leaving `words` shorter than (and shifted
            # against) `terminals`. Pass the symbol through instead, exactly
            # like an unexpanded nonterminal.
            words.append(n)
            probs.append(1.0)
            entropy.append(0.0)
            continue

        logger( 'Expanded %s (lmk_class: %s, lmk_color: %s, rel: %s, dist_class: %s, deg_class: %s)' % (n, lmk_class, lmk_color, rel_class, dist_class, deg_class) )

        # several rows may carry the same word; merge their counts
        ccounter = {}
        for cword in cp_db.all():
            ccounter[cword.word] = ccounter.get(cword.word, 0) + cword.count

        ckeys, ccounts = zip(*ccounter.items())

        # turn the counts into a probability distribution
        ccounts = np.array(ccounts, dtype=float)
        ccounts /= ccounts.sum()

        w, w_prob, w_entropy = categorical_sample(ckeys, ccounts)
        words.append(w)
        probs.append(w_prob)
        entropy.append(w_entropy)

    p, H = np.prod(probs), np.sum(entropy)
    return words, p, H
Exemplo n.º 3
0
def accept_object_correction( meaning, sentence, update, eval_lmk=True, printing=True ):
    """Reward the productions and words of `sentence` for the given meaning.

    Arguments:
        meaning  -- (landmark, relation) pair
        sentence -- sentence scored with get_sentence_meaning_likelihood
        update   -- amount added to every matching production and word count
        eval_lmk -- accepted but not used in this function
        printing -- forwarded to get_sentence_meaning_likelihood
    """
    landmark, relation = meaning
    _prob, _entropy, productions, word_prods = get_sentence_meaning_likelihood( sentence, landmark, relation, printing=printing )

    # reward new words with old meaning
    for lhs, rhs, parent, prod_lmk, prod_rel in productions:
        prod_class = prod_lmk.object_class if prod_lmk else None
        prod_ori_rels = get_lmk_ori_rels_str(prod_lmk)
        prod_color = prod_lmk.color if prod_lmk else None
        update_expansion_counts( update, lhs, rhs, parent,
                                 rel=prod_rel,
                                 lmk_class=prod_class,
                                 lmk_ori_rels=prod_ori_rels,
                                 lmk_color=prod_color )

    # each word is also counted together with the word that precedes it;
    # the first word has no predecessor
    previous = None
    for pos, word, word_lmk, word_rel in word_prods:
        word_class = word_lmk.object_class if word_lmk else None
        word_ori_rels = get_lmk_ori_rels_str(word_lmk)
        word_color = word_lmk.color if word_lmk else None
        update_word_counts( update, pos, word, previous,
                            lmk_class=word_class,
                            rel=word_rel,
                            lmk_ori_rels=word_ori_rels,
                            lmk_color=word_color )
        previous = word
Exemplo n.º 4
0
def train_rec( tree, parent=None, lmk=None, rel=None, prev_word='<no prev word>', update=1, printing=False):
    """Recursively add `update` to the counts of every node of a parse tree.

    Nonterminal nodes update production counts (update_expansion_counts);
    all other nodes update word counts (update_word_counts), each together
    with the word that precedes it in the sentence.

    Arguments:
        tree      -- parse (sub)tree; must expose `node` and `parent`
        parent    -- ignored: the parent label is always re-read from `tree`
        lmk       -- landmark the (sub)tree refers to, or None
        rel       -- relation the (sub)tree refers to, or None
        prev_word -- word immediately preceding this (sub)tree
        update    -- amount added to every count
        printing  -- only forwarded to the recursive calls

    Returns:
        The last word seen once this (sub)tree has been processed, so the
        caller can keep threading it through as `prev_word`.
    """
    lhs = tree.node

    if isinstance(tree[0], ParentedTree): rhs = ' '.join(n.node for n in tree)
    else: rhs = ' '.join(n for n in tree)

    # depending on the nltk version `parent` is either an attribute or a method
    tree_parent = tree.parent() if hasattr( tree.parent, '__call__' ) else tree.parent
    parent = tree_parent.node if tree_parent else None

    if lhs == 'RELATION':
        # everything under a RELATION node should ignore the landmark
        lmk = None

    if lhs == 'LANDMARK-PHRASE':
        # everything under a LANDMARK-PHRASE node should ignore the relation
        rel = None

    if lhs == parent == 'LANDMARK-PHRASE':
        # we need to move to the parent landmark
        lmk = parent_landmark(lmk)

    # a LOCATION-PHRASE node is counted without any landmark context
    lmk_class = (lmk.object_class if lmk and lhs != 'LOCATION-PHRASE' else None)
    lmk_ori_rels = get_lmk_ori_rels_str(lmk) if lhs != 'LOCATION-PHRASE' else None
    lmk_color = (lmk.color if lmk and lhs != 'LOCATION-PHRASE' else None)

    if lhs in NONTERMINALS:

        update_expansion_counts(update=update,
                                lhs=lhs,
                                rhs=rhs,
                                parent=parent,
                                lmk_class=lmk_class,
                                lmk_ori_rels=lmk_ori_rels,
                                lmk_color=lmk_color,
                                rel=rel)

        for subtree in tree:
            # `update` must be forwarded, otherwise every node below the root
            # silently falls back to the default of 1
            prev_word = train_rec(tree=subtree, parent=parent, lmk=lmk, rel=rel, prev_word=prev_word, update=update, printing=printing)

        # hand the last word back up; returning None here would reset the
        # previous-word chain after every nonterminal subtree
        return prev_word

    else:
        update_word_counts(update=update,
                           pos=lhs,
                           word=rhs,
                           prev_word=prev_word,
                           lmk_class=lmk_class,
                           lmk_ori_rels=lmk_ori_rels,
                           lmk_color=lmk_color,
                           rel=rel)
        return rhs
Exemplo n.º 5
0
def accept_new_words_meaning( lmk, rel, sentence, update_func='geometric', update_scale=10, num_meanings=10, printing=True ):
    """Reward the productions and words of `sentence` for the meaning (lmk, rel).

    The reward is fixed at 10 * update_scale. `update_func` and
    `num_meanings` are accepted but not used by this function; `printing` is
    forwarded positionally to get_sentence_meaning_likelihood.
    """
    _prob, _entropy, productions, word_prods = get_sentence_meaning_likelihood( sentence, lmk, rel, printing )

    reward = 10 * update_scale

    # reward new words with old meaning
    for lhs, rhs, parent, prod_lmk, prod_rel in productions:
        prod_class = prod_lmk.object_class if prod_lmk else None
        prod_ori_rels = get_lmk_ori_rels_str(prod_lmk)
        prod_color = prod_lmk.color if prod_lmk else None
        update_expansion_counts( reward, lhs, rhs, parent,
                                 rel=prod_rel,
                                 lmk_class=prod_class,
                                 lmk_ori_rels=prod_ori_rels,
                                 lmk_color=prod_color )

    for pos, word, word_lmk, word_rel in word_prods:
        word_class = word_lmk.object_class if word_lmk else None
        word_ori_rels = get_lmk_ori_rels_str(word_lmk)
        word_color = word_lmk.color if word_lmk else None
        update_word_counts( reward, pos, word,
                            lmk_class=word_class,
                            rel=word_rel,
                            lmk_ori_rels=word_ori_rels,
                            lmk_color=word_color )

    # bare print kept as in the original (a blank line under Python 2)
    print
Exemplo n.º 6
0
def save_tree(tree, loc, rel, lmk, parent=None):
    """Create Word/Production records for `tree` and all of its subtrees.

    Arguments:
        tree   -- parse (sub)tree exposing `node` and `productions()`
        loc    -- value stored in the `location` field of every record
        rel    -- relation used to annotate RELATION productions
        lmk    -- landmark used to annotate LANDMARK-PHRASE productions
        parent -- record created for the enclosing node, or None at the root

    The records are only instantiated and populated here; nothing in this
    function persists them explicitly (presumably the model layer takes care
    of that -- TODO confirm).
    """
    if len(tree.productions()) == 1:
        # a single production means the only child is a terminal (word)
        leaf = Word()
        leaf.word = tree[0]
        leaf.pos = tree.node
        leaf.parent = parent
        leaf.location = loc
        return

    node = Production()
    node.lhs = tree.node
    node.rhs = ' '.join(child.node for child in tree)
    node.parent = parent
    node.location = loc

    # some productions are related to semantic representation
    label = node.lhs
    if label == 'LANDMARK':
        # LANDMARK has the same landmark as its parent LANDMARK-PHRASE
        node.landmark = parent.landmark
        node.landmark_class = parent.landmark_class
        node.landmark_orientation_relations = parent.landmark_orientation_relations
        node.landmark_color = parent.landmark_color

    elif label == 'LANDMARK-PHRASE':
        node.landmark = lmk_id(lmk)
        node.landmark_class = lmk.object_class
        node.landmark_orientation_relations = get_lmk_ori_rels_str(lmk)
        node.landmark_color = lmk.color
        # next landmark phrase will need the parent landmark
        lmk = parent_landmark(lmk)

    elif label == 'RELATION':
        node.relation = rel_type(rel)
        if hasattr(rel, 'measurement'):
            measurement = rel.measurement
            node.relation_distance_class = measurement.best_distance_class
            node.relation_degree_class = measurement.best_degree_class

    # save subtrees, keeping track of parent
    for child in tree:
        save_tree(child, loc, rel, lmk, node)
Exemplo n.º 7
0
def save_tree(tree, loc, rel, lmk, parent=None):
    """Walk `tree` top-down, building one record per node.

    A node with more than one production becomes a Production record (with
    relation or landmark annotations where its label calls for them) and its
    children are visited with that record as their `parent`. A node with a
    single production becomes a Word record.

    Arguments:
        tree   -- parse (sub)tree exposing `node` and `productions()`
        loc    -- stored as `location` on every record
        rel    -- relation attached to RELATION productions
        lmk    -- landmark attached to LANDMARK-PHRASE productions
        parent -- record of the enclosing node (None for the root)
    """
    if len(tree.productions()) != 1:
        record = Production()
        record.lhs = tree.node
        record.rhs = ' '.join(sub.node for sub in tree)
        record.parent = parent
        record.location = loc

        # some productions are related to semantic representation
        kind = record.lhs
        if kind == 'RELATION':
            record.relation = rel_type(rel)
            if hasattr(rel, 'measurement'):
                record.relation_distance_class = rel.measurement.best_distance_class
                record.relation_degree_class = rel.measurement.best_degree_class
        elif kind == 'LANDMARK-PHRASE':
            record.landmark = lmk_id(lmk)
            record.landmark_class = lmk.object_class
            record.landmark_orientation_relations = get_lmk_ori_rels_str(lmk)
            record.landmark_color = lmk.color
            # the next, nested landmark phrase refers to the parent landmark
            lmk = parent_landmark(lmk)
        elif kind == 'LANDMARK':
            # a LANDMARK copies the landmark data of its parent LANDMARK-PHRASE
            record.landmark = parent.landmark
            record.landmark_class = parent.landmark_class
            record.landmark_orientation_relations = parent.landmark_orientation_relations
            record.landmark_color = parent.landmark_color

        # descend, passing this record down as the parent
        for sub in tree:
            save_tree(sub, loc, rel, lmk, record)
    else:
        # exactly one production: the child is a terminal (word)
        entry = Word()
        entry.word = tree[0]
        entry.pos = tree.node
        entry.parent = parent
        entry.location = loc
Exemplo n.º 8
0
def get_expansion(lhs, parent=None, lmk=None, rel=None):
    """Recursively expand the symbols in `lhs` by sampling productions.

    Each whitespace-separated symbol in `lhs` that is a nonterminal is
    expanded by sampling a right-hand side from the production counts that
    match its context (parent, landmark, relation); the sampled right-hand
    side is then expanded in turn. Any other symbol is collected as a terminal.

    Arguments:
        lhs    -- string of whitespace-separated symbols to expand
        parent -- label of the symbol these symbols were expanded from
        lmk    -- current landmark, or None
        rel    -- current relation, or None

    Returns a 5-tuple of lists:
        lhs_rhs_parent_chain -- (symbol, sampled rhs, parent, landmark) for
                                every expansion performed
        prob_chain           -- probability of each sampled expansion
        entropy_chain        -- entropy of each sampled expansion
        terminals            -- terminal symbols, plus any nonterminal that
                                could not be expanded
        landmarks            -- landmark for each entry of `terminals`
                                (always the same length as `terminals`)
    """
    lhs_rhs_parent_chain = []
    prob_chain = []
    entropy_chain = []
    terminals = []
    landmarks = []

    for n in lhs.split():
        if n not in NONTERMINALS:
            terminals.append( n )
            landmarks.append( lmk )
            continue

        if n == parent == 'LANDMARK-PHRASE':
            # we need to move to the parent landmark
            lmk = parent_landmark(lmk)

        lmk_class = (lmk.object_class if lmk else None)
        lmk_ori_rels = get_lmk_ori_rels_str(lmk)
        lmk_color = (lmk.color if lmk else None)
        rel_class = rel_type(rel)
        dist_class = (rel.measurement.best_distance_class if hasattr(rel, 'measurement') else None)
        deg_class = (rel.measurement.best_degree_class if hasattr(rel, 'measurement') else None)

        cp_db = CProduction.get_production_counts(lhs=n,
                                                  parent=parent,
                                                  lmk_class=lmk_class,
                                                  lmk_ori_rels=lmk_ori_rels,
                                                  lmk_color=lmk_color,
                                                  rel=rel_class,
                                                  dist_class=dist_class,
                                                  deg_class=deg_class)

        if cp_db.count() <= 0:
            logger('Could not expand %s (parent: %s, lmk_class: %s, lmk_ori_rels: %s, lmk_color: %s, rel: %s, dist_class: %s, deg_class: %s)' % (n, parent, lmk_class, lmk_ori_rels, lmk_color, rel_class, dist_class, deg_class))
            # pass the unexpanded symbol down as if it were a terminal; its
            # landmark has to be recorded too, otherwise `terminals` and
            # `landmarks` stop lining up for everything that follows
            terminals.append( n )
            landmarks.append( lmk )
            continue

        # several rows may share a right-hand side; merge their counts
        ccounter = {}
        for row in cp_db.all():
            ccounter[row.rhs] = ccounter.get(row.rhs, 0) + row.count

        ckeys, ccounts = zip(*ccounter.items())

        # turn the counts into a probability distribution
        ccounts = np.array(ccounts, dtype=float)
        ccounts /= ccounts.sum()

        cprod, cprod_prob, cprod_entropy = categorical_sample(ckeys, ccounts)

        lhs_rhs_parent_chain.append( ( n,cprod,parent,lmk ) )
        prob_chain.append( cprod_prob )
        entropy_chain.append( cprod_entropy )

        # expand the sampled right-hand side with `n` as the new parent
        lrpc, pc, ec, t, ls = get_expansion( lhs=cprod, parent=n, lmk=lmk, rel=rel )
        lhs_rhs_parent_chain.extend( lrpc )
        prob_chain.extend( pc )
        entropy_chain.extend( ec )
        terminals.extend( t )
        landmarks.extend( ls )

    return lhs_rhs_parent_chain, prob_chain, entropy_chain, terminals, landmarks
Exemplo n.º 9
0
def get_words(terminals, landmarks, rel=None, prevword=None):
    """Sample one word per terminal, mixing unigram and bigram counts.

    For every part-of-speech symbol two distributions over words are built
    from CWord counts for the landmark/relation context: one ignoring the
    previous word and one conditioned on it. They are interpolated with
    weight `alpha`, the ratio of the context count given the previous word to
    the context count overall, and a word is drawn from the result with
    categorical_sample.

    A symbol that cannot be turned into a word (an unexpanded nonterminal, or
    a terminal with no counts in this context) is passed through unchanged
    with probability 1, entropy 0 and alpha 0, so that `words` and `alphas`
    both have one entry per (terminal, landmark) pair.

    Arguments:
        terminals -- sequence of symbols to realize
        landmarks -- sequence of landmarks (or None), parallel to `terminals`
        rel       -- relation object, or None
        prevword  -- word preceding the first terminal, or None

    Returns:
        (words, p, H, alphas) -- the chosen words, the product of their
        probabilities, the sum of their entropies and the interpolation
        weight used for each word.
    """
    words = []
    probs = []
    alphas = []
    entropy = []
    C = CWord.get_count

    # the relation context is the same for every terminal
    rel_class = rel_type(rel)
    has_measurement = hasattr(rel, 'measurement')
    dist_class = (rel.measurement.best_distance_class if has_measurement else None)
    deg_class = (rel.measurement.best_degree_class if has_measurement else None)

    for n,lmk in zip(terminals, landmarks):
        # if we could not get an expansion for the LHS, we just pass down the unexpanded nonterminal symbol
        # it gets the probability of 1 and entropy of 0
        if n in NONTERMINALS:
            words.append(n)
            probs.append(1.0)
            entropy.append(0.0)
            # keep `alphas` index-aligned with `words`
            alphas.append(0.0)
            continue

        meaning = dict(pos=n,
                       lmk_class=(lmk.object_class if lmk else None),
                       lmk_ori_rels=get_lmk_ori_rels_str(lmk),
                       lmk_color=(lmk.color if lmk else None),
                       rel=rel_class,
                       rel_dist_class=dist_class,
                       rel_deg_class=deg_class)

        # unigram distribution: merge the counts of rows sharing a word
        ccounter = {}
        for c in CWord.get_word_counts(**meaning):
            ccounter[c.word] = ccounter.get(c.word, 0) + c.count

        if not ccounter:
            # nothing is known about this context; unpacking zip() of an empty
            # sequence below would raise ValueError, so pass the symbol through
            words.append(n)
            probs.append(1.0)
            entropy.append(0.0)
            alphas.append(0.0)
            continue

        ckeys, ccounts_uni = zip(*ccounter.items())
        ccounts_uni = np.array(ccounts_uni, dtype=float)
        ccounts_uni /= ccounts_uni.sum()

        prev_word = words[-1] if words else prevword
        bi_total = C(prev_word=prev_word, **meaning)
        uni_total = C(**meaning)
        # force true division (integer counts would truncate to 0 or 1 under
        # Python 2) and fall back to pure unigram when the total is zero
        alpha = float(bi_total) / float(uni_total) if uni_total else 0.0
        alphas.append(alpha)

        if alpha:
            # bigram distribution over the same keys as the unigram one
            ccounter = {}
            for c in CWord.get_word_counts(prev_word=prev_word, **meaning):
                ccounter[c.word] = ccounter.get(c.word, 0) + c.count
            ccounts_bi = np.array([ccounter.get(k,0) for k in ckeys], dtype=float)
            ccounts_bi /= ccounts_bi.sum()

            cprob = (alpha * ccounts_bi) + ((1-alpha) * ccounts_uni)

        else:
            cprob = ccounts_uni

        w, w_prob, w_entropy = categorical_sample(ckeys, cprob)
        words.append(w)
        probs.append(w_prob)
        entropy.append(w_entropy)

    p, H = np.prod(probs), np.sum(entropy)
    return words, p, H, alphas
Exemplo n.º 10
0
def _add_one_smoothed_stats(pairs, observed):
    """Return (probability of `observed`, entropy) under add-one smoothing.

    `pairs` is an iterable of (key, count). Counts of equal keys are summed,
    every distinct key gets one extra count, and `observed` is given a count
    of 1 if it does not occur among the keys.
    """
    counter = {}
    for key, count in pairs:
        if key in counter: counter[key] += count
        else: counter[key] = count + 1

    # we have never seen this key in this context before
    if observed not in counter: counter[observed] = 1

    counts = np.array(list(counter.values()), dtype=float)
    total = counts.sum()
    probs = counts / total
    return counter[observed] / total, -np.sum( probs * np.log(probs) )


def get_tree_probs(tree, lmk=None, rel=None):
    """Score a parse tree against the stored production and word counts.

    Walks `tree` top-down. For every node the add-one smoothed probability of
    the observed right-hand side (or word) in its landmark/relation context is
    computed, along with the entropy of the smoothed distribution. A node
    whose context has no counts at all contributes no probability or entropy
    but is still recorded in the returned chains.

    Arguments:
        tree -- parse (sub)tree exposing `node` and `parent`
        lmk  -- landmark the (sub)tree refers to, or None
        rel  -- relation the (sub)tree refers to, or None

    Returns a 4-tuple of lists:
        prob_chain           -- probabilities of the scored nodes
        entropy_chain        -- entropies of the scored nodes
        lhs_rhs_parent_chain -- (lhs, rhs, parent, lmk, rel) per nonterminal
        term_prods           -- (lhs, rhs, lmk, rel) per word-level node
    """
    lhs_rhs_parent_chain = []
    prob_chain = []
    entropy_chain = []
    term_prods = []

    lhs = tree.node

    if isinstance(tree[0], ParentedTree): rhs = ' '.join(n.node for n in tree)
    else: rhs = ' '.join(n for n in tree)

    # depending on the nltk version `parent` is either an attribute or a method
    tree_parent = tree.parent() if hasattr(tree.parent, '__call__') else tree.parent
    parent = tree_parent.node if tree_parent else None

    if lhs == 'RELATION':
        # everything under a RELATION node should ignore the landmark
        lmk = None

    if lhs == 'LANDMARK-PHRASE':
        # everything under a LANDMARK-PHRASE node should ignore the relation
        rel = None

    if lhs == parent == 'LANDMARK-PHRASE':
        # we need to move to the parent landmark
        lmk = parent_landmark(lmk)

    # a LOCATION-PHRASE node is looked up without landmark or relation context
    lmk_class = (lmk.object_class if lmk and lhs != 'LOCATION-PHRASE' else None)
    lmk_ori_rels = get_lmk_ori_rels_str(lmk) if lhs != 'LOCATION-PHRASE' else None
    lmk_color = (lmk.color if lmk and lhs != 'LOCATION-PHRASE' else None)
    rel_class = rel_type(rel) if lhs != 'LOCATION-PHRASE' else None
    dist_class = (rel.measurement.best_distance_class if hasattr(rel, 'measurement') and lhs != 'LOCATION-PHRASE' else None)
    deg_class = (rel.measurement.best_degree_class if hasattr(rel, 'measurement') and lhs != 'LOCATION-PHRASE' else None)

    if lhs in NONTERMINALS:
        cp_db = CProduction.get_production_counts(lhs=lhs,
                                                  parent=parent,
                                                  lmk_class=lmk_class,
                                                  lmk_ori_rels=lmk_ori_rels,
                                                  lmk_color=lmk_color,
                                                  rel=rel_class,
                                                  dist_class=dist_class,
                                                  deg_class=deg_class)

        if cp_db.count() <= 0:
            logger('Could not expand %s (parent: %s, lmk_class: %s, lmk_ori_rels: %s, lmk_color: %s, rel: %s, dist_class: %s, deg_class: %s)' % (lhs, parent, lmk_class, lmk_ori_rels, lmk_color, rel_class, dist_class, deg_class))
        else:
            cprod_prob, cprod_entropy = _add_one_smoothed_stats(
                ((cprod.rhs, cprod.count) for cprod in cp_db.all()), rhs)

            prob_chain.append( cprod_prob )
            entropy_chain.append( cprod_entropy )

        lhs_rhs_parent_chain.append( ( lhs, rhs, parent, lmk, rel ) )

        for subtree in tree:
            pc, ec, lrpc, tps = get_tree_probs(subtree, lmk, rel)
            prob_chain.extend( pc )
            entropy_chain.extend( ec )
            lhs_rhs_parent_chain.extend( lrpc )
            term_prods.extend( tps )

    else:
        cw_db = CWord.get_word_counts(pos=lhs,
                                      lmk_class=lmk_class,
                                      lmk_ori_rels=lmk_ori_rels,
                                      lmk_color=lmk_color,
                                      rel=rel_class,
                                      rel_dist_class=dist_class,
                                      rel_deg_class=deg_class)

        if cw_db.count() <= 0:
            # we don't know the probability or entropy values for the context we have never seen before
            # we just update the term_prods list
            logger('Could not expand %s (lmk_class: %s, lmk_ori_rels: %s, lmk_color: %s, rel: %s, dist_class: %s, deg_class: %s)' % (lhs, lmk_class, lmk_ori_rels, lmk_color, rel_class, dist_class, deg_class))
        else:
            w_prob, w_entropy = _add_one_smoothed_stats(
                ((cword.word, cword.count) for cword in cw_db.all()), rhs)

            prob_chain.append(w_prob)
            entropy_chain.append(w_entropy)

        term_prods.append( (lhs, rhs, lmk, rel) )

    return prob_chain, entropy_chain, lhs_rhs_parent_chain, term_prods
Exemplo n.º 11
0
def accept_correction(meaning,
                      correction,
                      update_func='geometric',
                      update_scale=10):
    """Shift counts away from a generated sentence and towards its correction.

    Every production and word recorded in `meaning` (the sentence that was
    generated) has its count decremented. For words the decrement is split
    between the count that ignores the previous word and the count that is
    conditioned on it, using the per-word weights in `meaning.rel_a` and
    `meaning.lmk_a`. Every production and word obtained by scoring
    `correction` against the same landmark/relation is then incremented.

    Arguments:
        meaning      -- object providing `args` (unpacked into the 22 values
                        named below), `rel_a` and `lmk_a` (weights indexed
                        like the relation / landmark words)
        correction   -- corrected sentence, passed to
                        get_sentence_meaning_likelihood
        update_func  -- key into `update_funcs` selecting the function that
                        computes the update magnitude
        update_scale -- factor the computed update is multiplied by

    Returns None; the effect is the calls to update_expansion_counts and
    update_word_counts.
    """
    (lmk, lmk_prob, lmk_ent, rel, rel_prob, rel_ent, rel_exp_chain,
     rele_prob_chain, rele_ent_chain, rel_terminals, rel_landmarks,
     lmk_exp_chain, lmke_prob_chain, lmke_ent_chain, lmk_terminals,
     lmk_landmarks, rel_words, relw_prob, relw_ent, lmk_words, lmkw_prob,
     lmkw_ent) = meaning.args
    rel_a = meaning.rel_a
    lmk_a = meaning.lmk_a

    old_meaning_prob, old_meaning_entropy, lrpc, tps = get_sentence_meaning_likelihood(
        correction, lmk, rel)

    update = update_funcs[update_func](lmk_prob * rel_prob, old_meaning_prob,
                                       lmk_ent + rel_ent,
                                       old_meaning_entropy) * update_scale

    logger('Update functions is %s and update value is: %f' %
           (update_func, update))

    dec_update = -update

    # penalize the productions that generated the relation phrase
    for lhs, rhs, parent, _ in rel_exp_chain:
        update_expansion_counts(dec_update, lhs, rhs, parent, rel=rel)

    # penalize the productions that generated the landmark phrase
    for lhs, rhs, parent, exp_lmk in lmk_exp_chain:
        update_expansion_counts(
            dec_update,
            lhs,
            rhs,
            parent,
            lmk_class=(exp_lmk.object_class if exp_lmk else None),
            lmk_ori_rels=get_lmk_ori_rels_str(exp_lmk),
            lmk_color=(exp_lmk.color if exp_lmk else None))

    # penalize the relation words; the first one has no predecessor
    prev_word = None
    for i, (term, word) in enumerate(zip(rel_terminals, rel_words)):
        a = rel_a[i]
        update_word_counts((1 - a) * dec_update, term, word, rel=rel)
        update_word_counts(a * dec_update,
                           term,
                           word,
                           rel=rel,
                           prev_word=prev_word)
        prev_word = word

    # penalize the landmark words; the first one follows the last relation
    # word (None when there are no relation words, instead of an IndexError)
    prev_word = rel_words[-1] if rel_words else None
    for i, (term, word, word_lmk) in enumerate(
            zip(lmk_terminals, lmk_words, lmk_landmarks)):
        a = lmk_a[i]
        # the landmark may be None; guard object_class the same way color is
        # guarded instead of raising AttributeError
        word_class = word_lmk.object_class if word_lmk else None
        word_color = word_lmk.color if word_lmk else None
        update_word_counts((1 - a) * dec_update,
                           term,
                           word,
                           lmk_class=word_class,
                           lmk_ori_rels=get_lmk_ori_rels_str(word_lmk),
                           lmk_color=word_color)
        update_word_counts(a * dec_update,
                           term,
                           word,
                           prev_word,
                           lmk_class=word_class,
                           lmk_ori_rels=get_lmk_ori_rels_str(word_lmk),
                           lmk_color=word_color)
        prev_word = word

    # reward new words with old meaning
    for lhs, rhs, parent, new_lmk, new_rel in lrpc:
        update_expansion_counts(
            update,
            lhs,
            rhs,
            parent,
            rel=new_rel,
            lmk_class=(new_lmk.object_class if new_lmk else None),
            lmk_ori_rels=get_lmk_ori_rels_str(new_lmk),
            lmk_color=(new_lmk.color if new_lmk else None))

    prev_word = None
    for lhs, rhs, new_lmk, new_rel in tps:
        update_word_counts(
            update,
            lhs,
            rhs,
            prev_word,
            lmk_class=(new_lmk.object_class if new_lmk else None),
            rel=new_rel,
            lmk_ori_rels=get_lmk_ori_rels_str(new_lmk),
            lmk_color=(new_lmk.color if new_lmk else None))
        prev_word = rhs
Exemplo n.º 12
0
def get_expansion(lhs, parent=None, lmk=None, rel=None):
    """Expand every symbol of `lhs`, sampling productions recursively.

    `lhs` is split on whitespace. A nonterminal symbol is expanded by drawing
    a right-hand side from the production counts stored for its context
    (parent label, landmark, relation) and then expanding that right-hand
    side recursively. Every other symbol is collected as a terminal.

    Arguments:
        lhs    -- string of whitespace-separated symbols to expand
        parent -- label of the symbol these symbols were expanded from
        lmk    -- current landmark, or None
        rel    -- current relation, or None

    Returns a 5-tuple of lists:
        lhs_rhs_parent_chain -- (symbol, sampled rhs, parent, landmark) for
                                every expansion performed
        prob_chain           -- probability of each sampled expansion
        entropy_chain        -- entropy of each sampled expansion
        terminals            -- terminal symbols, plus any nonterminal that
                                could not be expanded
        landmarks            -- landmark for each entry of `terminals`
                                (always the same length as `terminals`)
    """
    lhs_rhs_parent_chain = []
    prob_chain = []
    entropy_chain = []
    terminals = []
    landmarks = []

    for n in lhs.split():
        if n in NONTERMINALS:
            if n == parent == 'LANDMARK-PHRASE':
                # we need to move to the parent landmark
                lmk = parent_landmark(lmk)

            lmk_class = (lmk.object_class if lmk else None)
            lmk_ori_rels = get_lmk_ori_rels_str(lmk)
            lmk_color = (lmk.color if lmk else None)
            rel_class = rel_type(rel)
            dist_class = (rel.measurement.best_distance_class if hasattr(
                rel, 'measurement') else None)
            deg_class = (rel.measurement.best_degree_class if hasattr(
                rel, 'measurement') else None)

            cp_db = CProduction.get_production_counts(
                lhs=n,
                parent=parent,
                lmk_class=lmk_class,
                lmk_ori_rels=lmk_ori_rels,
                lmk_color=lmk_color,
                rel=rel_class,
                dist_class=dist_class,
                deg_class=deg_class)

            if cp_db.count() <= 0:
                logger(
                    'Could not expand %s (parent: %s, lmk_class: %s, lmk_ori_rels: %s, lmk_color: %s, rel: %s, dist_class: %s, deg_class: %s)'
                    % (n, parent, lmk_class, lmk_ori_rels, lmk_color,
                       rel_class, dist_class, deg_class))
                # the unexpanded symbol is handed down like a terminal, so it
                # needs a landmark entry as well; without it `terminals` and
                # `landmarks` no longer line up for the symbols that follow
                terminals.append(n)
                landmarks.append(lmk)
                continue

            # several rows may share a right-hand side; merge their counts
            ccounter = {}
            for row in cp_db.all():
                if row.rhs in ccounter: ccounter[row.rhs] += row.count
                else: ccounter[row.rhs] = row.count

            ckeys, ccounts = zip(*ccounter.items())

            # normalize the counts into a probability distribution
            ccounts = np.array(ccounts, dtype=float)
            ccounts /= ccounts.sum()

            cprod, cprod_prob, cprod_entropy = categorical_sample(
                ckeys, ccounts)

            lhs_rhs_parent_chain.append((n, cprod, parent, lmk))
            prob_chain.append(cprod_prob)
            entropy_chain.append(cprod_entropy)

            # expand the sampled right-hand side with `n` as the new parent
            lrpc, pc, ec, t, ls = get_expansion(lhs=cprod,
                                                parent=n,
                                                lmk=lmk,
                                                rel=rel)
            lhs_rhs_parent_chain.extend(lrpc)
            prob_chain.extend(pc)
            entropy_chain.extend(ec)
            terminals.extend(t)
            landmarks.extend(ls)
        else:
            terminals.append(n)
            landmarks.append(lmk)

    return lhs_rhs_parent_chain, prob_chain, entropy_chain, terminals, landmarks
Exemplo n.º 13
0
def get_words(terminals, landmarks, rel=None, prevword=None):
    """Sample one word per terminal from an interpolated bigram/unigram model.

    For each (terminal, landmark) pair the word distribution is

        alpha * P_bigram(word | previous word, meaning)
        + (1 - alpha) * P_unigram(word | meaning)

    where both distributions are the per-word counts returned by
    ``CWord.get_word_counts`` normalised to sum to one, and ``alpha`` is
    ``CWord.get_count(prev_word=..., **meaning) / CWord.get_count(**meaning)``.
    The word itself is drawn with ``categorical_sample``.

    Args:
        terminals: symbols to turn into words.  Entries that are members of
            ``NONTERMINALS`` (expansions that could not be completed) are
            passed through unchanged with probability 1 and entropy 0.
        landmarks: sequence iterated in step with ``terminals``; each item
            is a landmark object or a falsy value.
        rel: relation passed to ``rel_type``; when it has a ``measurement``
            attribute its best distance/degree classes are part of the
            lookup key.
        prevword: word treated as preceding the first entry.

    Returns:
        ``(words, p, H, alphas)``: the chosen words, the product of their
        probabilities, the sum of their entropies, and the interpolation
        weight used for each entry of ``words``.

    Raises:
        ValueError: if no word counts exist for a terminal's meaning.
    """
    words = []
    probs = []
    alphas = []
    entropy = []

    for n, lmk in zip(terminals, landmarks):
        # if we could not get an expansion for the LHS, we just pass down the unexpanded nonterminal symbol
        # it gets the probability of 1 and entropy of 0
        if n in NONTERMINALS:
            words.append(n)
            probs.append(1.0)
            entropy.append(0.0)
            # Keep `alphas` parallel to `words`: accept_correction indexes
            # its per-word weight lists by word position.
            # NOTE(review): confirm those lists are the ones returned here.
            alphas.append(0.0)
            continue

        lmk_class = (lmk.object_class if lmk else None)
        lmk_color = (lmk.color if lmk else None)
        rel_class = rel_type(rel)
        has_measurement = hasattr(rel, 'measurement')
        dist_class = (rel.measurement.best_distance_class
                      if has_measurement else None)
        deg_class = (rel.measurement.best_degree_class
                     if has_measurement else None)

        meaning = dict(pos=n,
                       lmk_class=lmk_class,
                       lmk_ori_rels=get_lmk_ori_rels_str(lmk),
                       lmk_color=lmk_color,
                       rel=rel_class,
                       rel_dist_class=dist_class,
                       rel_deg_class=deg_class)

        # Unigram distribution: total count per word for this meaning.
        uni_counter = {}
        for c in CWord.get_word_counts(**meaning):
            uni_counter[c.word] = uni_counter.get(c.word, 0) + c.count

        if not uni_counter:
            # Unpacking zip() of nothing would raise an opaque ValueError;
            # raise the same exception type with a usable message instead.
            raise ValueError(
                'Could not expand %s (lmk_class: %s, lmk_color: %s, rel: %s, dist_class: %s, deg_class: %s)'
                % (n, lmk_class, lmk_color, rel_class, dist_class, deg_class))

        ckeys, ccounts_uni = zip(*uni_counter.items())
        ccounts_uni = np.array(ccounts_uni, dtype=float)
        ccounts_uni /= ccounts_uni.sum()

        # Interpolation weight: share of this meaning's count mass that was
        # observed after `prev_word`.  float() guards against Python 2
        # integer truncation should the counts be integers, and a zero
        # denominator falls back to the pure unigram model.
        prev_word = words[-1] if words else prevword
        bigram_total = CWord.get_count(prev_word=prev_word, **meaning)
        unigram_total = CWord.get_count(**meaning)
        if unigram_total:
            alpha = float(bigram_total) / float(unigram_total)
        else:
            alpha = 0.0
        alphas.append(alpha)

        if alpha:
            bi_counter = {}
            for c in CWord.get_word_counts(prev_word=prev_word, **meaning):
                bi_counter[c.word] = bi_counter.get(c.word, 0) + c.count
            # Project the bigram counts onto the unigram key order so the
            # two probability vectors can be mixed element-wise.
            ccounts_bi = np.array([bi_counter.get(k, 0) for k in ckeys],
                                  dtype=float)
            ccounts_bi /= ccounts_bi.sum()

            cprob = (alpha * ccounts_bi) + ((1 - alpha) * ccounts_uni)
        else:
            cprob = ccounts_uni

        w, w_prob, w_entropy = categorical_sample(ckeys, cprob)
        words.append(w)
        probs.append(w_prob)
        entropy.append(w_entropy)

    p, H = np.prod(probs), np.sum(entropy)
    # print 'expanding %s to %s (p: %f, H: %f)' % (terminals, words, p, H)
    return words, p, H, alphas
Exemplo n.º 14
0
def accept_correction( meaning, correction, update_func='geometric', update_scale=10, eval_lmk=True, multiply=False, printing=True ):
    """Update expansion and word counts for a generated sentence given feedback.

    When ``correction`` is None, the productions and words that produced the
    generated sentence are updated by a positive amount (``update_scale``,
    or 1 when ``multiply`` is set).  Otherwise the correction is scored with
    ``get_sentence_meaning_likelihood``; the generated sentence is updated
    by the negated update value and the productions/words returned for the
    correction are updated by the positive update value.

    Args:
        meaning: object exposing ``args`` (the 22-tuple unpacked below) plus
            ``rel_a`` and ``lmk_a``, per-word weights indexed in step with
            the relation / landmark words.  Each weight ``a`` splits a word
            update into ``(1 - a)`` for the call without ``prev_word`` and
            ``a`` for the call with it.
        correction: the corrected sentence, or None if none was received.
        update_func: key into ``update_funcs``; used only when ``eval_lmk``.
        update_scale: factor applied to the update unless ``multiply``.
        eval_lmk: when false the update is ``rel_prob`` times ``0.1``
            (``multiply``) or ``update_scale``.
        multiply: switches the scale factors above and is forwarded to
            ``update_expansion_counts`` / ``update_word_counts``.
        printing: forwarded to ``get_sentence_meaning_likelihood``.

    Returns:
        None.  A ``ParseError`` raised inside the body is logged and
        swallowed.
    """
    (lmk, lmk_prob, lmk_ent,
     rel, rel_prob, rel_ent,
     rel_exp_chain, rele_prob_chain, rele_ent_chain, rel_terminals, rel_landmarks,
     lmk_exp_chain, lmke_prob_chain, lmke_ent_chain, lmk_terminals, lmk_landmarks,
     rel_words, relw_prob, relw_ent,
     lmk_words, lmkw_prob, lmkw_ent) = meaning.args
    rel_a = meaning.rel_a
    lmk_a = meaning.lmk_a

    def lmk_kwargs(landmark):
        # Landmark-derived keyword arguments shared by every count update.
        return dict(lmk_class=(landmark.object_class if landmark else None),
                    lmk_ori_rels=get_lmk_ori_rels_str(landmark),
                    lmk_color=(landmark.color if landmark else None))

    try:

        if correction is None:
            # If we received no correction, reward our generated sentence INSTEAD of decrementing
            dec_update = (1 if multiply else update_scale)
        else:
            # otherwise we decrement it
            old_meaning_prob, old_meaning_entropy, lrpc, tps = get_sentence_meaning_likelihood( correction, lmk, rel, printing=printing)

            if eval_lmk:
                update = update_funcs[update_func](lmk_prob * rel_prob, old_meaning_prob, lmk_ent + rel_ent, old_meaning_entropy) * (1 if multiply else update_scale)
            else:
                update = rel_prob * (0.1 if multiply else update_scale)

            logger('Update functions is %s and update value is: %f (using multiply: %s)' % (update_func, update, multiply))

            dec_update = -update

        # Productions that generated the relation phrase.
        for lhs, rhs, parent, _ in rel_exp_chain:
            update_expansion_counts( dec_update, lhs, rhs, parent, rel=rel, multiply=multiply )

        # Productions that generated the landmark phrase.  The loop variable
        # is deliberately not called `lmk` so the unpacked landmark above is
        # not shadowed.
        for lhs, rhs, parent, chain_lmk in lmk_exp_chain:
            update_expansion_counts( dec_update, lhs, rhs, parent, multiply=multiply, **lmk_kwargs(chain_lmk) )

        # Relation words: the first word has no predecessor.
        prev_word = None
        for i, (term, word) in enumerate(zip(rel_terminals, rel_words)):
            a = rel_a[i]
            update_word_counts( (1-a)*dec_update, term, word, rel=rel, multiply=multiply )
            update_word_counts( a*dec_update, term, word, rel=rel, prev_word=prev_word, multiply=multiply )
            prev_word = word

        # Landmark words follow the relation words, so the first one is
        # preceded by the last relation word -- or by nothing when there are
        # no relation words (previously an uncaught IndexError).
        prev_word = rel_words[-1] if rel_words else None
        for i, (term, word, word_lmk) in enumerate(zip(lmk_terminals, lmk_words, lmk_landmarks)):
            a = lmk_a[i]
            features = lmk_kwargs(word_lmk)
            update_word_counts( (1-a)*dec_update, term, word, multiply=multiply, **features )
            update_word_counts( a*dec_update, term, word, prev_word, multiply=multiply, **features )
            prev_word = word

        # Only enforce the correction if we got one
        if correction is not None:
            # reward new words with old meaning
            for lhs, rhs, parent, corr_lmk, corr_rel in lrpc:
                update_expansion_counts( update, lhs, rhs, parent, rel=corr_rel, multiply=multiply, **lmk_kwargs(corr_lmk) )

            prev_word = None
            for lhs, rhs, corr_lmk, corr_rel in tps:
                update_word_counts( update, lhs, rhs, prev_word, rel=corr_rel, multiply=multiply, **lmk_kwargs(corr_lmk) )
                prev_word = rhs
    except ParseError as pe:
        logger( pe )
        logger( 'No update performed' )