コード例 #1
0
ファイル: test_crf_pytorch.py プロジェクト: dpressel/baseline
def test_decode_batch_stable(generate_examples_and_batch):
    """Check that batched CRF decoding matches decoding each example alone.

    Two examples are decoded one at a time. Their paths and scores are
    stitched together into a batch, and the result is compared with
    ``crf.decode`` run on the pre-built batch from the same fixture.

    NOTE(review): the fixture is assumed to yield unaries whose dim 2 is the
    tag count and whose dim 0 is the sequence axis, with the first example at
    least as long as the second (otherwise the padding size below goes
    negative). TODO confirm against ``generate_examples_and_batch``.
    """
    i1, _, l1, i2, _, l2, i, _, l = generate_examples_and_batch
    h = i1.size(2)
    crf = CRF(h)
    # Randomize transitions (presumably so the comparison exercises
    # non-trivial transition scores rather than the initial values).
    crf.transitions_p.data = torch.rand(1, h, h)
    p1, s1 = crf.decode(i1, l1)
    p2, s2 = crf.decode(i2, l2)
    # Zero-pad the second path along dim 0 up to the first path's length so
    # the two can be concatenated side by side along dim 1.
    pad = torch.zeros(p1.size(0) - p2.size(0), 1, dtype=torch.long)
    one_x_one_p = torch.cat([p1, torch.cat([p2, pad], dim=0)], dim=1)
    one_x_one_s = torch.cat([s1, s2], dim=0)
    batched_p, batched_s = crf.decode(i, l)
    np.testing.assert_allclose(one_x_one_s.detach().numpy(), batched_s.detach().numpy())
    # zip() stops at the shorter input, so check the row counts first;
    # otherwise extra or missing rows would pass silently.
    assert len(one_x_one_p) == len(batched_p)
    # Distinct loop names avoid shadowing the per-example paths p1/p2 above.
    for expected, actual in zip(one_x_one_p, batched_p):
        np.testing.assert_allclose(expected.detach().numpy(), actual.detach().numpy())
コード例 #2
0
def test_decode_batch_stable(generate_examples_and_batch):
    """Check that batched CRF decoding matches decoding each example alone.

    Two examples are decoded one at a time. Their paths and scores are
    stitched together into a batch, and the result is compared with
    ``crf.decode`` run on the pre-built batch from the same fixture.

    NOTE(review): the fixture is assumed to yield unaries whose dim 2 is the
    tag count and whose dim 0 is the sequence axis, with the first example at
    least as long as the second (otherwise the padding size below goes
    negative). TODO confirm against ``generate_examples_and_batch``.
    """
    i1, _, l1, i2, _, l2, i, _, l = generate_examples_and_batch
    h = i1.size(2)
    crf = CRF(h)
    # Randomize transitions (presumably so the comparison exercises
    # non-trivial transition scores rather than the initial values).
    crf.transitions_p.data = torch.rand(1, h, h)
    p1, s1 = crf.decode(i1, l1)
    p2, s2 = crf.decode(i2, l2)
    # Zero-pad the second path along dim 0 up to the first path's length so
    # the two can be concatenated side by side along dim 1.
    pad = torch.zeros(p1.size(0) - p2.size(0), 1, dtype=torch.long)
    one_x_one_p = torch.cat([p1, torch.cat([p2, pad], dim=0)], dim=1)
    one_x_one_s = torch.cat([s1, s2], dim=0)
    batched_p, batched_s = crf.decode(i, l)
    np.testing.assert_allclose(one_x_one_s.detach().numpy(),
                               batched_s.detach().numpy())
    # zip() stops at the shorter input, so check the row counts first;
    # otherwise extra or missing rows would pass silently.
    assert len(one_x_one_p) == len(batched_p)
    # Distinct loop names avoid shadowing the per-example paths p1/p2 above.
    for expected, actual in zip(one_x_one_p, batched_p):
        np.testing.assert_allclose(expected.detach().numpy(),
                                   actual.detach().numpy())
コード例 #3
0
ファイル: test_crf_pytorch.py プロジェクト: dpressel/baseline
def test_decode_shape(generate_batch):
    """Verify the shapes returned by ``CRF.decode``.

    Scores must have shape ``(unary.size(1),)`` and paths must have shape
    ``(unary.size(0), unary.size(1))``. The tag count given to the CRF is
    taken from ``unary.size(2)``.
    """
    unary, _, lengths = generate_batch
    crf = CRF(unary.size(2))
    paths, scores = crf.decode(unary, lengths)
    dim0, dim1 = unary.size(0), unary.size(1)
    expected_scores = torch.Size([dim1])
    expected_paths = torch.Size([dim0, dim1])
    assert scores.shape == expected_scores
    assert paths.shape == expected_paths
コード例 #4
0
def test_decode_shape(generate_batch):
    """Assert ``CRF.decode`` output shapes against the unary input's shape.

    The paths tensor must match the first two dimensions of ``unary``. The
    scores tensor must match only its second dimension.
    """
    unary, _, lengths = generate_batch
    num_tags = unary.size(2)
    paths, scores = CRF(num_tags).decode(unary, lengths)
    leading_dims = list(unary.shape[:2])
    assert scores.shape == torch.Size(leading_dims[1:])
    assert paths.shape == torch.Size(leading_dims)
コード例 #5
0
ファイル: tagger_decoders.py プロジェクト: dpressel/baseline
    return torch.stack(new_path[1:]), path_score


if __name__ == '__main__':
    # Smoke test: compare CRF.decode with InferenceCRF.decode, then trace
    # InferenceCRF.decode with TorchScript, save/reload it, and compare the
    # reloaded trace against CRF.decode on an input of a different length.
    from baseline.pytorch.crf import CRF, transition_mask
    vocab = ["<GO>", "<EOS>", "B-X", "I-X", "E-X", "S-X", "O", "B-Y", "I-Y", "E-Y", "S-Y"]
    vocab = {k: i for i, k in enumerate(vocab)}
    # NOTE(review): `mask` is never used below, and the vocab has 11 entries
    # while the CRF is built with 10 tags -- confirm whether the mask was
    # meant to be passed to CRF.
    mask = transition_mask(vocab, "IOBES", 0, 1)
    crf = CRF(10, (0, 1), batch_first=False)
    trans = crf.transitions

    # Hand the CRF's transitions to the inference-only decoder; squeeze(0)
    # drops a leading size-1 dimension if there is one.
    icrf = InferenceCRF(torch.nn.Parameter(trans.squeeze(0)), 0, 1, False)

    # batch_first=False: 20 steps, batch of 1, 10 tag scores per step.
    u = torch.rand(20, 1, 10)
    l = torch.full((1,), 20, dtype=torch.long)
    print(crf.decode(u, l))
    print(icrf.decode(u, l))

    # Trace on a 15-step example. The lengths must describe this input, so
    # rebuild them instead of reusing the stale length of 20 from above.
    u = torch.rand(15, 1, 10)
    l = torch.full((1,), 15, dtype=torch.long)
    traced_model = torch.jit.trace(icrf.decode, (u, l))
    traced_model.save('crf.pt')  # writes crf.pt into the current directory
    traced_model = torch.jit.load('crf.pt')

    # Check that the reloaded trace generalizes to a different sequence length.
    u = torch.rand(8, 1, 10)
    l = torch.full((1,), 8, dtype=torch.long)

    print(crf.decode(u, l))
    print(traced_model(u, l))

    # NOTE(review): u/l below are not used within this view; the snippet
    # appears to be truncated here.
    u = torch.rand(22, 1, 10)
    l = torch.full((1,), 22, dtype=torch.long)
コード例 #6
0
if __name__ == '__main__':
    # Smoke test: compare CRF.decode with InferenceCRF.decode, then trace
    # InferenceCRF.decode with TorchScript, save/reload it, and compare the
    # reloaded trace against CRF.decode on an input of a different length.
    # NOTE(review): InferenceCRF is not defined in this view; presumably it
    # is defined earlier in this module -- confirm.
    from baseline.pytorch.crf import CRF, transition_mask
    vocab = [
        "<GO>", "<EOS>", "B-X", "I-X", "E-X", "S-X", "O", "B-Y", "I-Y", "E-Y",
        "S-Y"
    ]
    vocab = {k: i for i, k in enumerate(vocab)}
    # NOTE(review): `mask` is never used in this view, and the vocab has 11
    # entries while the CRF is built with 10 tags -- confirm whether the mask
    # was meant to be passed to CRF.
    mask = transition_mask(vocab, "IOBES", 0, 1)
    crf = CRF(10, (0, 1), batch_first=False)
    trans = crf.transitions

    # Hand the CRF's transitions to the inference-only decoder; squeeze(0)
    # drops a leading size-1 dimension if there is one.
    icrf = InferenceCRF(torch.nn.Parameter(trans.squeeze(0)), 0, 1, False)

    # batch_first=False: 20 steps, batch of 1, 10 tag scores per step.
    u = torch.rand(20, 1, 10)
    l = torch.full((1, ), 20, dtype=torch.long)
    print(crf.decode(u, l))
    print(icrf.decode(u, l))

    # NOTE(review): the trace below runs on a 15-step input but still passes
    # the lengths tensor built for 20 steps above -- `l` should probably be
    # rebuilt with 15 before tracing.
    u = torch.rand(15, 1, 10)
    traced_model = torch.jit.trace(icrf.decode, (u, l))
    traced_model.save('crf.pt')  # writes crf.pt into the current directory
    traced_model = torch.jit.load('crf.pt')

    # Check that the reloaded trace generalizes to a different sequence length.
    u = torch.rand(8, 1, 10)
    l = torch.full((1, ), 8, dtype=torch.long)

    print(crf.decode(u, l))
    print(traced_model(u, l))

    # NOTE(review): u/l below are not used within this view; the code
    # presumably continues past this point.
    u = torch.rand(22, 1, 10)
    l = torch.full((1, ), 22, dtype=torch.long)