def test_inside_outside2():
    """Parse "c d d" with a small fixed-weight grammar and check the results.

    Verifies the tree recovered from the backtrace and the values of the
    first two ``alphas`` charts returned by ``parse``.
    """
    from pprint import pprint

    grammar = FixedPCFG(
        "x",
        terminals=["c", "d"],
        nonterminals=["x"],
        preterminals=["b"],
        productions=[("b", "b"), ("b", "x")],
        binary_weights=np.array([[0.25, 0.75]]),
        unary_weights=np.array([[0.5, 0.5]]),
    )
    tokens = "c d d".split()

    alphas, betas, backtrace = parse(grammar, tokens)

    # Dump both charts, labelled by nonterminal, to aid debugging on failure.
    for chart in (alphas, betas):
        pprint(list(zip(grammar.nonterminals, chart)))

    tree_from_backtrace(grammar, tokens, backtrace).pretty_print()

    expected_tree = Tree.fromstring("(x (b c) (x (b d) (b d)))")
    assert_equal(tree_from_backtrace(grammar, tokens, backtrace),
                 expected_tree)

    # Expected alpha[x] chart.
    expected_alpha_x = [[0, 0.0625, 0.023438],
                        [0, 0, 0.0625],
                        [0, 0, 0]]
    np.testing.assert_allclose(alphas[0], expected_alpha_x, atol=1e-5)

    # Expected alpha[b] chart (preterminals).
    expected_alpha_b = [[0.5, 0, 0],
                        [0, 0.5, 0],
                        [0, 0, 0.5]]
    np.testing.assert_allclose(alphas[1], expected_alpha_b)
def test_inside_outside_em_update():
    """Check that repeated EM updates never lower the sentence probability.

    Runs 20 rounds of ``parse`` followed by ``update_em`` on the sentence
    "c d d" and asserts that the start symbol's chart entry covering the
    whole sentence does not decrease from one round to the next, up to a
    small absolute tolerance for floating-point rounding.
    """
    pcfg = FixedPCFG("x",
                     terminals=["c", "d"],
                     nonterminals=["x"],
                     preterminals=["b"],
                     productions=[("b", "b"), ("b", "x")])
    sentence = "c d d".split()

    # Absolute slack allowed between consecutive probabilities. Without it,
    # a rounding-level dip near convergence would fail the test spuriously.
    tolerance = 1e-8

    prev_total_prob = 0
    for i in range(20):
        alphas, betas, backtrace = parse(pcfg, sentence)
        # Chart entry for the start symbol over the span 0 .. len(sentence) - 1.
        total_prob = alphas[pcfg.nonterm2idx[pcfg.start], 0, len(sentence) - 1]
        tree_from_backtrace(pcfg, sentence, backtrace).pretty_print()
        print("%d\t%f" % (i, total_prob))

        # NB include small tolerance due to float imprecision
        assert total_prob - prev_total_prob >= -tolerance, \
            "Total prob should never decrease: %f -> %f (iter %d)" % \
            (prev_total_prob, total_prob, i)

        prev_total_prob = total_prob
        pcfg = update_em(pcfg, sentence)
# Grammar to be trained with EM over the corpus in `sentences`.
pcfg = P.FixedPCFG(
    "S",
    terminals=list(vocabulary),
    nonterminals=["S", "NP", "VP", "VP$"],
    preterminals=["N", "V", "D"],
    productions=[
        ("NP", "VP"),
        ("V", "NP"),
        ("V", "VP$"),
        ("NP", "NP"),
        ("D", "N"),
    ],
)

prev_ll = -np.inf
for e in trange(40, desc="Epoch"):
    # One EM update per sentence.
    for sentence in tqdm(sentences):
        pcfg = I.update_em(pcfg, sentence)

    # Corpus log-likelihood: sum of the log of each sentence's chart entry
    # for the start symbol over the full span.
    ll = 0
    for sentence in tqdm(sentences):
        alphas, betas, _ = I.parse(pcfg, sentence)
        start_idx = pcfg.nonterm2idx[pcfg.start]
        total_prob = alphas[start_idx, 0, len(sentence) - 1]
        ll += np.log(total_prob)

    tqdm.write("%i ll: %f" % (e, ll))

    # Stop once the log-likelihood still improves, but by no more than 1e-3.
    improvement = ll - prev_ll
    if 0 < improvement <= 1e-3:
        break

    prev_ll = ll

# Print each sentence followed by the tree recovered from its backtrace.
for sentence in sentences:
    print(" ".join(sentence))
    alphas, betas, backtrace = I.parse(pcfg, sentence)
    tree = P.tree_from_backtrace(pcfg, sentence, backtrace)
    tree.pretty_print()