def test_max_marginals():
    """
    Check max-marginals against the best-path score.

    For every hypergraph yielded by ``hypergraphs()`` with random potentials:
    no node or edge max-marginal may exceed the best-path score (allowing
    1e-4 slack), and every edge lying on the best path must have a
    max-marginal almost equal to that score.
    """
    for hypergraph in hypergraphs():
        potentials = utils.random_viterbi_potentials(hypergraph)
        print(potentials.show(hypergraph))

        best_path = ph.best_path(hypergraph, potentials)
        best_score = potentials.dot(best_path)

        # Debug dump of the best path: one "label : potential" row per edge.
        print("BEST")
        rows = ["%20s : %s" % (e.label, potentials[e]) for e in best_path.edges]
        print("\n".join(rows))
        print(best_score)
        nt.assert_not_equal(best_score, 0.0)

        marginals = ph.compute_marginals(hypergraph, potentials)
        upper = best_score + 1e-4  # small slack for floating-point error

        for node in hypergraph.nodes:
            nt.assert_less_equal(marginals[node], upper)

        for edge in hypergraph.edges:
            value = marginals[edge]
            nt.assert_less_equal(value, upper)
            if edge in best_path:
                nt.assert_almost_equal(value, best_score)
def test_variables():
    """
    Check the labels reported by ``variables.check`` for the best path.

    ``random_constraint_trans`` (defined elsewhere) supplies a constraint
    object and an edge; ``variables.check`` returns labels, presumably the
    violated constraints (TODO confirm).

    - If the best path uses the edge: "have" must not be reported, and
      exactly one label must be reported.
    - Otherwise: "have" must be reported and "not" must not.
    """
    for hypergraph in hypergraphs():
        potentials = utils.random_viterbi_potentials(hypergraph)
        variables, edge = random_constraint_trans(hypergraph)
        best = ph.best_path(hypergraph, potentials)
        reported = list(variables.check(best))

        if edge in best:
            print("Should have " + str(edge.id))
            assert "have" not in reported
            nt.assert_equal(len(reported), 1)
        else:
            print("Should not have " + str(edge.id))
            assert "have" in reported
            assert "not" not in reported
def test_outside():
    """
    Check inside/outside chart consistency against the best-path score.

    For every hypergraph yielded by ``hypergraphs()`` with random potentials:

    - for every node, ``inside[node] * outside[node]`` must not exceed the
      best-path score (allowing 1e-4 slack);
    - for every terminal node in the tail of an edge on the best path, that
      product must be almost equal to the best-path score.
    """
    for h in hypergraphs():
        w = utils.random_viterbi_potentials(h)
        path = ph.best_path(h, w)
        chart = ph.inside_values(h, w)
        best = w.dot(path)
        nt.assert_not_equal(best, 0.0)
        out_chart = ph.outside_values(h, w, chart)

        for node in h.nodes:
            other = chart[node] * out_chart[node]
            nt.assert_less_equal(other, best + 1e-4)

        for edge in path.edges:
            for node in edge.tail:
                if node.is_terminal:
                    # Recompute the product for *this* node. Previously `other`
                    # was the stale value from the last node of the loop above,
                    # so the path nodes were never actually checked.
                    other = chart[node] * out_chart[node]
                    nt.assert_almost_equal(other, best)
def test_pruning():
    """
    Check ``ph.prune_hypergraph`` at two thresholds.

    1. Threshold -0.99: every edge of the original best path must appear in
       the best path of the pruned hypergraph, that path must be valid, and
       its score must almost equal the original best score.
    2. Threshold 0.001: the pruned hypergraph must be non-empty, and every
       surviving edge (matched to the original by label) must keep its
       potential and have an original max-marginal greater than the
       threshold.
    """
    for h in hypergraphs():
        w = utils.random_viterbi_potentials(h)

        # Part 1: threshold -0.99; best path must survive with the same score.
        best = ph.best_path(h, w)
        pruned_h, pruned_w = ph.prune_hypergraph(h, w, -0.99)
        pruned_best = ph.best_path(pruned_h, pruned_w)
        assert len(best.edges) > 0
        for e in best.edges:
            assert e in pruned_best
        valid_path(pruned_h, pruned_best)
        score = w.dot(best)
        print(score)
        print(pruned_w.dot(pruned_best))
        nt.assert_almost_equal(score, pruned_w.dot(pruned_best))

        # Part 2: threshold 0.001; survivors must clear the threshold.
        threshold = 0.001
        marginals = ph.compute_marginals(h, w)
        pruned_h, pruned_w = ph.prune_hypergraph(h, w, threshold)
        assert (len(pruned_h.edges) > 0)

        # Edges are matched across the two hypergraphs by label.
        by_label = {e.label: e for e in h.edges}
        kept = {e.label: e for e in pruned_h.edges}
        for label in kept:
            original = by_label[label]
            nt.assert_almost_equal(w[original], pruned_w[kept[label]])
            nt.assert_greater(marginals[original], threshold)