Ejemplo n.º 1
0
def test_viterbi_batch_stable(generate_examples_and_batch):
    """Check that batched decoding agrees with decoding examples one by one.

    The fixture supplies three (unary, _, lengths) triples: two that are
    decoded on their own and a third that is decoded in a single call
    (presumably the batch built from the first two -- confirm against the
    fixture). All three decodes share the same random transition matrix.
    """
    (
        first_unary, _, first_lengths,
        second_unary, _, second_lengths,
        batch_unary, _, batch_lengths,
    ) = generate_examples_and_batch
    num_states = first_unary.size(2)
    transitions = torch.rand(1, num_states, num_states)

    def decode(unary, lengths):
        # A fresh decoder per call, exactly as each decode is independent.
        return Viterbi(Offsets.GO, Offsets.EOS)(unary, transitions, lengths)

    first_path, first_score = decode(first_unary, first_lengths)
    second_path, second_score = decode(second_unary, second_lengths)

    # Zero-pad the second path along dim 0 up to the first path's size so the
    # two can sit side by side along dim 1. This assumes the first path is at
    # least as long as the second.
    missing_steps = first_path.size(0) - second_path.size(0)
    padding = torch.zeros(missing_steps, 1, dtype=torch.long)
    padded_second_path = torch.cat([second_path, padding], dim=0)
    expected_paths = torch.cat([first_path, padded_second_path], dim=1)
    expected_scores = torch.cat([first_score, second_score], dim=0)

    batched_paths, batched_scores = decode(batch_unary, batch_lengths)

    np.testing.assert_allclose(expected_scores.detach().numpy(), batched_scores.detach().numpy())
    np.testing.assert_allclose(expected_paths.detach().numpy(), batched_paths.detach().numpy())
Ejemplo n.º 2
0
def test_decode_shape(generate_batch):
    """Check the shapes of the paths and scores returned by Viterbi decoding.

    The scores must have one entry per index of the unary tensor's dim 1, and
    the paths must match the unary tensor's first two dimensions.
    """
    emissions, _, seq_lengths = generate_batch
    num_states = emissions.size(2)
    transitions = torch.rand(1, num_states, num_states)
    decoder = Viterbi(Offsets.GO, Offsets.EOS)
    paths, scores = decoder(emissions, transitions, seq_lengths)

    # Slicing a torch.Size yields a torch.Size, so these compare like-for-like.
    leading_dims = emissions.shape[:2]
    assert scores.shape == leading_dims[1:]
    assert paths.shape == leading_dims
Ejemplo n.º 3
0
def test_viterbi_score_equals_sentence_score(generate_batch):
    """Test that the scores from viterbi decoding are the same scores that you get when looking up those returned paths.

    The decoder is driven by the CRF's own transition parameters, and the
    decoded paths are then re-scored with ``crf.score_sentence``; both scores
    must agree to within a relative tolerance of 1e-6.
    """
    unary, _, lengths = generate_batch
    h = unary.size(2)
    # The transitions come from the CRF itself so that decoding and re-scoring
    # use the same parameters; no separate random matrix is needed.
    crf = CRF(h)

    p, viterbi_scores = Viterbi(Offsets.GO, Offsets.EOS)(unary, crf.transitions, lengths)
    gold_scores = crf.score_sentence(unary, p, lengths)
    np.testing.assert_allclose(viterbi_scores.detach().numpy(), gold_scores.detach().numpy(), rtol=1e-6)
Ejemplo n.º 4
0
def test_viterbi(generate_batch):
    """Compare the torch Viterbi decoder against the explicit reference decoder.

    Each example is decoded individually with ``explicit_viterbi`` on its
    unpadded emissions; the resulting scores and paths must match what the
    batched torch decoder produced for the same random transitions.
    """
    unary, _, lengths = generate_batch
    num_states = unary.size(2)
    trans = torch.rand(num_states, num_states)
    decoder = Viterbi(Offsets.GO, Offsets.EOS)
    pyt_path, pyt_scores = decoder(unary, trans.unsqueeze(0), lengths)

    reference_trans = build_trans(trans)
    # Swap the first two dims so that iterating yields one example at a time.
    per_example_unary = unary.transpose(0, 1)
    decoded = [
        explicit_viterbi(build_emission(example[:length]), reference_trans, Offsets.GO, Offsets.EOS)
        for example, length in zip(per_example_unary, lengths)
    ]
    gold_paths = [path for path, _ in decoded]
    gold_scores = np.array([score for _, score in decoded])

    np.testing.assert_allclose(pyt_scores.detach().numpy(), gold_scores, rtol=1e-6)

    # Only the first `length` steps of each decoded path are compared; the
    # rest lies beyond the example's length.
    per_example_paths = pyt_path.transpose(0, 1)
    for decoded_path, length, gold_path in zip(per_example_paths, lengths, gold_paths):
        assert decoded_path[:length].tolist() == gold_path
Ejemplo n.º 5
0
def test_viterbi_degenerates_to_argmax(generate_batch):
    """Check that Viterbi with all-zero transitions matches a per-step argmax."""
    emissions, _, seq_lengths = generate_batch
    num_states = emissions.size(2)
    # When the transitions are all zeros, Viterbi just greedily selects the
    # best state at each emission, which is the same as doing an argmax.
    zero_transitions = torch.zeros((1, num_states, num_states))
    decoder = Viterbi(Offsets.GO, Offsets.EOS)
    decoded_paths, decoded_scores = decoder(emissions, zero_transitions, seq_lengths)

    best_scores, best_states = emissions.max(2)
    # Zero out the argmax results that fall past each example's length.
    for column, length in enumerate(seq_lengths):
        best_scores[length:, column] = 0
        best_states[length:, column] = 0
    expected_scores = best_scores.sum(0)

    np.testing.assert_allclose(decoded_paths.detach().numpy(), best_states.detach().numpy())
    np.testing.assert_allclose(decoded_scores.detach().numpy(), expected_scores.detach().numpy())