def test_max_marginals():
    """
    Test that max-marginals are correct.
    """
    for h in hypergraphs():
        w = utils.random_viterbi_potentials(h)
        print w.show(h)

        path = ph.best_path(h, w)
        best = w.dot(path)
        print "BEST"

        print "\n".join(["%20s : %s"%(edge.label, w[edge]) for edge in path.edges])
        print best
        nt.assert_not_equal(best, 0.0)
        max_marginals = ph.compute_marginals(h, w)
        for node in h.nodes:
            other = max_marginals[node]
            nt.assert_less_equal(other, best + 1e-4)

        for edge in h.edges:
            other = max_marginals[edge]
            nt.assert_less_equal(other, best + 1e-4)
            if edge in path:
                nt.assert_almost_equal(other, best)
def test_variables():
    """
    Test variable constraint checking.
    """
    for h in hypergraphs():
        w = utils.random_viterbi_potentials(h)
        variables, edge = random_constraint_trans(h)
        path = ph.best_path(h, w)
        match = list(variables.check(path))
        if edge not in path:
            print "Should not have", edge.id
            assert "have" in match
            assert "not" not in match
        else:
            print "Should have", edge.id
            assert "have" not in match

        nt.assert_equal(len(match), 1)
def test_outside():
    """
    Test outside chart properties.

    For each test hypergraph with random Viterbi potentials:

    * the product of a node's inside and outside values never exceeds
      the best path score (within a tolerance of 1e-4);
    * for every terminal node in the tail of an edge on the best path,
      that product equals the best path score.
    """
    for h in hypergraphs():
        w = utils.random_viterbi_potentials(h)
        path = ph.best_path(h, w)
        chart = ph.inside_values(h, w)
        best = w.dot(path)
        # A zero score would make the comparisons below uninformative.
        nt.assert_not_equal(best, 0.0)
        out_chart = ph.outside_values(h, w, chart)
        for node in h.nodes:
            other = chart[node] * out_chart[node]
            nt.assert_less_equal(other, best + 1e-4)
        for edge in path.edges:
            for node in edge.tail:
                if node.is_terminal:
                    # Recompute the value for *this* node.  Previously the
                    # assertion reused `other` left over from the loop
                    # above, so the terminal node itself was never checked.
                    other = chart[node] * out_chart[node]
                    nt.assert_almost_equal(other, best)
def test_pruning():
    for h in hypergraphs():
        w = utils.random_viterbi_potentials(h)

        original_path = ph.best_path(h, w)
        new_hyper, new_potentials = ph.prune_hypergraph(h, w, -0.99)
        prune_path = ph.best_path(new_hyper, new_potentials)
        assert len(original_path.edges) > 0
        for edge in original_path.edges:
            assert edge in prune_path
        valid_path(new_hyper, prune_path)

        original_score = w.dot(original_path)
        print original_score
        print new_potentials.dot(prune_path)
        nt.assert_almost_equal(original_score,
                               new_potentials.dot(prune_path))

        # Test pruning amount.
        prune = 0.001
        max_marginals = ph.compute_marginals(h, w)
        new_hyper, new_potentials = ph.prune_hypergraph(h, w, prune)

        assert (len(new_hyper.edges) > 0)
        original_edges = {}
        for edge in h.edges:
            original_edges[edge.label] = edge

        new_edges = {}
        for edge in new_hyper.edges:
            new_edges[edge.label] = edge

        for name, edge in new_edges.iteritems():

            orig = original_edges[name]
            nt.assert_almost_equal(w[orig], new_potentials[edge])
            m = max_marginals[orig]
            nt.assert_greater(m, prune)