Esempio n. 1
0
def test_hmm():
    """Smoke-test ``LinearChain.hmm`` followed by ``LinearChain().sum``.

    Random transition, emission and initial parameters plus random integer
    observations are fed to ``LinearChain.hmm``; the result is then passed to
    ``sum`` on a default-constructed ``LinearChain``. Nothing is asserted, so
    the test passes as long as neither call raises.
    """
    # NOTE(review): presumably hidden-state count and vocabulary size -- confirm.
    num_classes, vocab = 5, 20
    batch_size, length = 2, 5

    # Draw order is kept stable so the RNG stream is consumed identically.
    transition = torch.rand(num_classes, num_classes)
    emission = torch.rand(vocab, num_classes)
    init = torch.rand(num_classes)
    # Integer observations in [0, vocab), one row of `length` per batch item.
    observations = torch.randint(0, vocab, (batch_size, length))

    potentials = LinearChain.hmm(transition, emission, init, observations)
    LinearChain().sum(potentials)
Esempio n. 2
0
def test_linear_chain_counting(batch, N, C):
    """Check ``LinearChain(StdSemiring).sum`` on all-ones potentials.

    Builds a ``(batch, N, C, C)`` tensor of ones and asserts that every entry
    of the resulting sum equals ``C ** (N + 1)``.

    Args:
        batch: Size of the leading dimension of the potentials.
        N: Size of the second dimension of the potentials.
        C: Size of each of the last two dimensions of the potentials.
    """
    potentials = torch.ones(batch, N, C, C)
    total = LinearChain(StdSemiring).sum(potentials)
    # NOTE(review): presumably the number of distinct label sequences -- confirm.
    expected = C ** (N + 1)
    assert (total == expected).all()
Esempio n. 3
0
def test_lc_custom():
    """Compare checkpointed semirings against plain ``LogSemiring``.

    Reference marginals and sum are computed with ``LinearChain(LogSemiring)``
    on potentials from ``LinearChain._rand()``. The same quantities are then
    recomputed with ``CheckpointSemiring(LogSemiring, 1)`` and
    ``CheckpointShardSemiring(LogSemiring, 1)``, and each must be element-wise
    close to the reference.
    """
    vals, _ = LinearChain._rand()

    reference = LinearChain(LogSemiring)
    expected_marginals = reference.marginals(vals)
    expected_sum = reference.sum(vals)

    # Each wrapper is instantiated only when its turn comes, matching the
    # construction order of the straight-line version of this test.
    for wrap in (CheckpointSemiring, CheckpointShardSemiring):
        checkpointed = LinearChain(wrap(LogSemiring, 1))
        got_marginals = checkpointed.marginals(vals)
        got_sum = checkpointed.sum(vals)
        assert torch.isclose(expected_sum, got_sum).all()
        assert torch.isclose(expected_marginals, got_marginals).all()
Esempio n. 4
0
def test_sparse_max2():
    """Print ``sum`` and ``marginals`` of ``LinearChain(SparseMaxSemiring)``.

    Each call uses its own freshly constructed struct and its own random
    ``(1, 8, 3, 3)`` tensor. Nothing is asserted; the test passes as long as
    neither call raises.
    """
    shape = (1, 8, 3, 3)

    # The struct is built before the random draw, as in the one-line form.
    sum_struct = LinearChain(SparseMaxSemiring)
    sum_input = torch.rand(*shape)
    print(sum_struct.sum(sum_input))

    marginal_struct = LinearChain(SparseMaxSemiring)
    marginal_input = torch.rand(*shape)
    print(marginal_struct.marginals(marginal_input))