Ejemplo n.º 1
0
def sample(forest, B, v=None):
    """Draw a random derivation tree from a weighted parse forest.

    Starting at ``v`` (default: ``forest.root``), pick one incoming hyperedge
    ``e`` of ``v`` with probability proportional to
    ``e.weight * prod(B[b] for b in e.body)`` and recurse into every node of
    the chosen edge's body.

    Args:
        forest: hypergraph exposing ``root`` and ``incoming[node]`` (a
            sequence of edges, each with ``weight`` and ``body``).
        B: mapping from node to its inside score, in the same value type as
            the edge weights (e.g. ``LogVal``).
        v: node to start sampling from; ``None`` means ``forest.root``.

    Returns:
        ``v`` itself when it has no incoming edges (a leaf); otherwise a
        ``Tree`` labelled ``v`` whose children are the recursively sampled
        subtrees of the chosen edge's body.
    """
    if v is None:
        v = forest.root
    edges = forest.incoming[v]
    if not edges:
        # Base case (leaf): nothing to sample.
        return v
    # p(e | head) \propto e.weight * (\prod_{b in e.body} B[b]).
    # Non-augmented operators are used on purpose: if the score type mutates
    # in place on `*=`/`+=`, `p *= ...` would corrupt `e.weight`, and storing
    # an in-place-updated `Z` would alias every running total.
    Z = LogVal.Zero()
    cumulative = []
    for e in edges:
        p = e.weight
        for y in e.body:
            p = p * B[y]
        Z = Z + p
        cumulative.append(Z)
    # Normalize while still in the log domain, then convert to reals.
    # Converting the raw running totals directly underflows to 0 (or
    # overflows to inf) for extreme scores, making the choice degenerate.
    cs = [(c / Z).to_real() for c in cumulative]
    if cs[-1] > 0:
        i = np.array(cs).searchsorted(uniform(0, cs[-1]))
    else:
        # Zero (or undefined, e.g. NaN) total mass: the distribution is
        # undefined; keep the original behavior of taking the first edge.
        i = 0
    e = edges[i]
    return Tree(v, [sample(forest, B, y) for y in e.body])
Ejemplo n.º 2
0
def _test_sample_tree(example, grammar, N, threshold=1e-4, tol=0.05):
    """Check that sampled-tree node frequencies match inside-outside marginals.

    Builds a ``LogVal``-weighted copy of the parse forest for ``example``,
    draws ``N`` trees with ``sample``, and compares each node's empirical
    per-sample occurrence frequency against its inside-outside marginal
    ``B[x] * A[x] / Z``.

    Args:
        example: example passed to ``parse_forest``.
        grammar: grammar passed to ``parse_forest``.
        N: number of trees to sample.
        threshold: only nodes whose exact or estimated marginal exceeds this
            value are printed and checked.
        tol: maximum allowed absolute difference between the two values.

    Raises:
        AssertionError: if a checked node's two values differ by ``tol`` or
            more. The message carries ``(node, exact, estimate)``.
    """
    print()
    _forest = parse_forest(example, grammar)
    # Copy the forest, lifting each edge weight into a LogVal via ``logeq``.
    # (Presumably the raw weights are log-scores -- TODO confirm.)
    forest = Hypergraph()
    forest.root = _forest.root
    for e in _forest.edges:
        c = LogVal.Zero()
        c.logeq(e.weight)
        forest.edge(c, e.head, *e.body)
    # Inside (B) and outside (A) scores; Z is the inside score of the root.
    B, A = sum_product(forest)
    Z = B[forest.root]
    # Monte Carlo estimate: per-sample occurrence frequency of each node label.
    m = defaultdict(float)
    for _ in iterview(range(N)):
        t = sample(forest, B)
        for s in t.subtrees():
            m[s.label()] += 1.0 / N
    # Exact marginals from inside-outside.
    IO = defaultdict(float)
    for x in forest.incoming:
        IO[x] += (B[x] * A[x] / Z).to_real()
    # Compare only nodes with K - I > 1 (presumably spans longer than one
    # token); nodes are assumed to be (I, K, X, T) tuples.
    for x in IO:
        (I, K, X, T) = x
        if K - I > 1:
            a = IO[x]
            b = m[x]
            if a > threshold or b > threshold:
                print('[%s %s %8s, %s] %7.3f %7.3f' % (I, K, X, T, a, b))
                assert abs(a - b) < tol, (x, a, b)
Ejemplo n.º 3
0
 def One(cls):
     """Build the ``One`` element of this pair type.

     The first component is ``LogVal.One()`` and the second is
     ``LogVal.Zero()`` (presumably the multiplicative identity).
     """
     one, zero = LogVal.One(), LogVal.Zero()
     return cls(one, zero)
Ejemplo n.º 4
0
 def Zero(cls):
     """Build the ``Zero`` element of this pair type.

     Both components are fresh ``LogVal.Zero()`` instances.
     """
     first = LogVal.Zero()
     second = LogVal.Zero()
     return cls(first, second)
Ejemplo n.º 5
0
 def One():
     """Build the ``Semiring1`` One element.

     The value is ``(LogVal.One(), LogVal.Zero())``.
     """
     parts = (LogVal.One(), LogVal.Zero())
     return Semiring1(*parts)
Ejemplo n.º 6
0
 def One():
     """Build the ``Semiring2`` One element.

     The scalar components are ``LogVal.One()`` and ``LogVal.Zero()``. The
     two vector components are freshly default-constructed ``LogValVector``
     objects.
     """
     scalar_one = LogVal.One()
     scalar_zero = LogVal.Zero()
     vec_a = LogValVector()
     vec_b = LogValVector()
     return Semiring2(scalar_one, scalar_zero, vec_a, vec_b)