Exemplo n.º 1
0
def check_semirings(graph):
    """
    Compare pydecode results computed from two equal potential arrays.

    Every edge receives weight 10.0.  ``pydecode.marginals`` is first run
    with ``weight_type=pydecode.Viterbi`` and its result is discarded, so
    that call only checks that it succeeds.  Then the ``pydecode.inside``
    charts are compared, first node by node and then as whole arrays to
    4 decimal places.  Finally the default ``pydecode.marginals`` outputs
    are compared edge by edge.  Each compared pair is computed from two
    separately constructed arrays holding those same values.

    NOTE(review): both compared arrays hold identical values and no
    ``weight_type`` is passed to the compared calls, so the checks can only
    fail if these routines are nondeterministic.  The function name
    suggests different semirings were meant to be compared — confirm.  The
    array built from the 0.5 weights at the end is never used.
    """
    n_edges = len(graph.edges)
    base = [10.0] * n_edges
    alternate = [0.5] * n_edges

    # Result intentionally unused: only exercises the Viterbi code path.
    pydecode.marginals(graph, np.array(base), weight_type=pydecode.Viterbi)

    log_potentials = np.array(base)
    potentials = np.array(base)
    chart = pydecode.inside(graph, log_potentials)
    chart2 = pydecode.inside(graph, potentials)

    # Array-form: compare the charts entry by entry.
    for node_id in (node.id for node in graph.nodes):
        nt.assert_equal(chart[node_id], chart2[node_id])

    # Matrix-form: compare the whole charts in one call.
    numpy.testing.assert_array_almost_equal(chart, chart2, decimal=4)

    marg = pydecode.marginals(graph, log_potentials)
    marg2 = pydecode.marginals(graph, potentials)
    for edge_id in (edge.id for edge in graph.edges):
        nt.assert_almost_equal(marg[edge_id], marg2[edge_id])

    # Never read afterwards (see the NOTE in the docstring).
    potentials = np.array(alternate)
Exemplo n.º 2
0
def test_pruning():
    """
    Smoke-test the inputs to marginal-based pruning on each test hypergraph.

    For every hypergraph yielded by ``utils.hypergraphs()`` this:

    1. draws random edge weights;
    2. decodes the best path and computes ``pydecode.marginals``;
    3. forms ``best = w.T * original_path.v``, presumably the best path's
       score (confirm);
    4. prints ``marginals[1]``;
    5. builds a 0/1 mask of edges whose marginal exceeds ``0.99 * best``.

    NOTE(review): the mask ``a`` is never checked, so this test only fails
    if one of the calls raises.  A natural assertion would be that every
    edge of ``original_path`` survives pruning.  Confirm the semantics of
    ``pydecode.marginals`` before adding it.
    """
    for h in utils.hypergraphs():

        w = numpy.random.random(len(h.edges))

        original_path = pydecode.best_path(h, w)
        marginals = pydecode.marginals(h, w)
        best = w.T * original_path.v
        # Call form prints the same value on Python 2 and is also valid on
        # Python 3, where the old ``print x`` statement is a SyntaxError.
        print(marginals[1])
        a = np.array(marginals > 0.99 * best, dtype=np.uint8)
Exemplo n.º 3
0
def test_pruning():
    """
    Smoke-test the inputs to marginal-based pruning on each test hypergraph.

    For every hypergraph yielded by ``utils.hypergraphs()`` this:

    1. draws random edge weights;
    2. decodes the best path and computes ``pydecode.marginals``;
    3. forms ``best = w.T * original_path.v``, presumably the best path's
       score (confirm);
    4. prints ``marginals[1]``;
    5. builds a 0/1 mask of edges whose marginal exceeds ``0.99 * best``.

    NOTE(review): the mask ``a`` is never checked, so this test only fails
    if one of the calls raises.  A natural assertion would be that every
    edge of ``original_path`` survives pruning.  Confirm the semantics of
    ``pydecode.marginals`` before adding it.
    """
    for h in utils.hypergraphs():

        w = numpy.random.random(len(h.edges))

        original_path = pydecode.best_path(h, w)
        marginals = pydecode.marginals(h, w)
        best = w.T * original_path.v
        # Call form prints the same value on Python 2 and is also valid on
        # Python 3, where the old ``print x`` statement is a SyntaxError.
        print(marginals[1])
        a = np.array(marginals > 0.99 * best, dtype=np.uint8)
Exemplo n.º 4
0
def check_posteriors(graph, pot):
    """
    Accumulate edge scores by brute-force enumeration of all paths.

    ``pydecode.marginals(graph, pot)`` is evaluated first.  Then, for every
    path from ``utils.all_paths(graph)``:

    - the path score ``exp(log(pot.T) * path.v)`` is computed;
    - the score is added to a running total;
    - the score is added to a per-edge accumulator for each edge the path
      yields.

    NOTE(review): the enumerated totals are never compared with the
    ``pydecode.marginals`` result and nothing is returned, so this function
    currently asserts nothing.  The comparison looks missing — confirm.
    """
    node_marg = pydecode.marginals(graph, pot)

    edge_mass = defaultdict(float)
    total_score = 0.0
    for path in utils.all_paths(graph):
        # Presumably path.v is the path's edge-indicator vector, making this
        # the product of the path's edge potentials taken in log space.
        path_score = np.exp(np.log(pot.T) * path.v)
        total_score += path_score
        for edge in path:
            edge_mass[edge.id] += path_score
Exemplo n.º 5
0
def check_max_marginals(graph, pot):
    """
    Test that max-marginals are correct.

    The function:

    1. computes the best path under ``pot`` and ``best = pot.T * path.v``;
    2. requires ``best`` to be non-zero;
    3. checks that no value returned by ``pydecode.marginals(graph, pot)``
       exceeds ``best`` by more than a small tolerance.

    Step 3 runs twice: edge by edge, and over the whole array at once.
    Both forms use the same non-strict bound, so they accept exactly the
    same inputs.

    :param graph: hypergraph whose edges expose an ``id`` used for indexing.
    :param pot: edge potentials passed to pydecode (must support ``.T``).
    """
    # Slack for floating-point round-off in both comparisons below.
    tolerance = 1e-4

    path = pydecode.best_path(graph, pot)
    # NOTE(review): presumably path.v is the best path's edge-indicator
    # vector, so ``best`` is that path's score — confirm.
    best = pot.T * path.v
    nt.assert_not_equal(best, 0.0)
    max_marginals = pydecode.marginals(graph, pot)
    upper = best + tolerance

    # Array-form.
    for edge in graph.edges:
        nt.assert_less_equal(max_marginals[edge.id], upper)

    # Matrix-form: uses <= to match the array-form check (was a stricter <).
    assert (max_marginals <= upper).all()
Exemplo n.º 6
0
 def compute_marginals(self, label_scores):
     """Return ``pydecode.marginals`` for ``self.graph`` after updating weights.

     Calls ``self.compute_weights(label_scores)``, which presumably
     refreshes ``self.weights`` (confirm).  Then returns
     ``pydecode.marginals(self.graph, self.weights)``.
     """
     self.compute_weights(label_scores)
     return pydecode.marginals(self.graph, self.weights)