def test_hmm():
    """Smoke-test ``LinearChain.hmm`` followed by a ``LinearChain().sum``.

    Random transition/emission/initial tensors and random integer
    observations are built, converted into chain potentials by
    ``LinearChain.hmm``, and then summed. The test only checks that these
    calls run without raising; no value is asserted.
    """
    n_states, vocab, batch, length = 5, 20, 2, 5

    # Random draws happen in this fixed order: transition, emission, init,
    # observations.
    transition = torch.rand(n_states, n_states)
    emission = torch.rand(vocab, n_states)
    init = torch.rand(n_states)
    observations = torch.randint(0, vocab, (batch, length))

    potentials = LinearChain.hmm(transition, emission, init, observations)
    chain = LinearChain()
    chain.sum(potentials)
def test_linear_chain_counting(batch, N, C):
    """Check that summing all-ones potentials under ``StdSemiring`` counts paths.

    Args:
        batch: batch size of the potential tensor.
        N: second dimension of the potential tensor (passed by the caller).
        C: number of labels; potentials have shape ``(batch, N, C, C)``.
    """
    potentials = torch.ones(batch, N, C, C)
    total = LinearChain(StdSemiring).sum(potentials)

    # With every potential equal to 1, the sum presumably counts the label
    # sequences; the test expects exactly C ** (N + 1) of them
    # (NOTE(review): i.e. N C x C potentials spanning N + 1 positions — confirm).
    expected = C ** (N + 1)
    assert (total == expected).all()
def test_lc_custom():
    """Check that checkpointed Log semirings match the plain ``LogSemiring``.

    Random potentials from ``LinearChain._rand()`` are used. The sum and the
    marginals computed with ``LogSemiring`` serve as the reference. The same
    quantities computed with ``CheckpointSemiring(LogSemiring, 1)`` and then
    ``CheckpointShardSemiring(LogSemiring, 1)`` must be element-wise close
    to that reference.
    """
    vals, _ = LinearChain._rand()

    reference = LinearChain(LogSemiring)
    ref_marginals = reference.marginals(vals)
    ref_sum = reference.sum(vals)

    for wrapper in (CheckpointSemiring, CheckpointShardSemiring):
        struct = LinearChain(wrapper(LogSemiring, 1))
        marginals = struct.marginals(vals)
        total = struct.sum(vals)
        assert torch.isclose(ref_sum, total).all()
        assert torch.isclose(ref_marginals, marginals).all()
def test_sparse_max2():
    """Smoke-test ``sum`` and ``marginals`` under ``SparseMaxSemiring``.

    Each call gets its own fresh random tensor of shape ``(1, 8, 3, 3)``.
    The results are only printed; nothing is asserted.
    """
    shape = (1, 8, 3, 3)

    chain_sum = LinearChain(SparseMaxSemiring).sum(torch.rand(shape))
    print(chain_sum)

    chain_marginals = LinearChain(SparseMaxSemiring).marginals(torch.rand(shape))
    print(chain_marginals)