def test_viterbi_batch_stable(generate_examples_and_batch):
    """Decoding two examples individually matches decoding them as one batch."""
    i1, _, l1, i2, _, l2, i, _, l = generate_examples_and_batch
    n_tags = i1.size(2)
    trans = torch.rand(1, n_tags, n_tags)
    # Decode the full batch and each example on its own.
    batched_p, batched_s = Viterbi(Offsets.GO, Offsets.EOS)(i, trans, l)
    p1, s1 = Viterbi(Offsets.GO, Offsets.EOS)(i1, trans, l1)
    p2, s2 = Viterbi(Offsets.GO, Offsets.EOS)(i2, trans, l2)
    # The second example is shorter, so pad its path along the time axis
    # before stacking the two paths along the batch axis.
    pad = torch.zeros(p1.size(0) - p2.size(0), 1, dtype=torch.long)
    one_x_one_p = torch.cat([p1, torch.cat([p2, pad], dim=0)], dim=1)
    one_x_one_s = torch.cat([s1, s2], dim=0)
    np.testing.assert_allclose(one_x_one_s.detach().numpy(), batched_s.detach().numpy())
    np.testing.assert_allclose(one_x_one_p.detach().numpy(), batched_p.detach().numpy())
def test_decode_shape(generate_batch):
    """Decoded scores are shaped ``(batch,)`` and paths ``(seq, batch)``."""
    unary, _, lengths = generate_batch
    n_tags = unary.size(2)
    transition = torch.rand(1, n_tags, n_tags)
    paths, scores = Viterbi(Offsets.GO, Offsets.EOS)(unary, transition, lengths)
    seq_len, batch_sz = unary.size(0), unary.size(1)
    assert scores.shape == torch.Size([batch_sz])
    assert paths.shape == torch.Size([seq_len, batch_sz])
def test_viterbi_score_equals_sentence_score(generate_batch):
    """Test that the scores from viterbi decoding are the same scores that you
    get when looking up those returned paths.

    Fix: the original body created ``trans = torch.rand(h, h)`` and never used
    it (decoding runs against ``crf.transitions``); the dead local is removed.
    """
    unary, _, lengths = generate_batch
    h = unary.size(2)
    crf = CRF(h)
    # Decode with the CRF's own transition matrix, then re-score the decoded
    # paths: both routes must yield the same numbers.
    p, viterbi_scores = Viterbi(Offsets.GO, Offsets.EOS)(unary, crf.transitions, lengths)
    gold_scores = crf.score_sentence(unary, p, lengths)
    np.testing.assert_allclose(viterbi_scores.detach().numpy(), gold_scores.detach().numpy(), rtol=1e-6)
def test_viterbi(generate_batch):
    """Batched Viterbi decoding agrees with an explicit per-sentence reference."""
    unary, _, lengths = generate_batch
    n_tags = unary.size(2)
    trans = torch.rand(n_tags, n_tags)
    pyt_path, pyt_scores = Viterbi(Offsets.GO, Offsets.EOS)(unary, trans.unsqueeze(0), lengths)
    ref_trans = build_trans(trans)
    # Reference implementation works sentence-by-sentence in batch-major order.
    unary = unary.transpose(0, 1)
    results = [
        explicit_viterbi(build_emission(u[:l]), ref_trans, Offsets.GO, Offsets.EOS)
        for u, l in zip(unary, lengths)
    ]
    paths = [p for p, _ in results]
    gold_scores = np.array([s for _, s in results])
    np.testing.assert_allclose(pyt_scores.detach().numpy(), gold_scores, rtol=1e-6)
    # Compare paths only up to each sentence's true length.
    pyt_path = pyt_path.transpose(0, 1)
    for pp, l, p in zip(pyt_path, lengths, paths):
        assert pp[:l].tolist() == p
def test_viterbi_degenerates_to_argmax(generate_batch):
    """With all-zero transitions, Viterbi reduces to a per-step argmax."""
    emissions, _, lens = generate_batch
    n_tags = emissions.size(2)
    # Zero transitions contribute nothing, so the decoder just greedily
    # picks the best state at each emission — exactly argmax over tags.
    zero_trans = torch.zeros((1, n_tags, n_tags))
    p, s = Viterbi(Offsets.GO, Offsets.EOS)(emissions, zero_trans, lens)
    s_gold, p_gold = torch.max(emissions, 2)
    # Zero out argmax results beyond each sequence's true length.
    for b, true_len in enumerate(lens):
        s_gold[true_len:, b] = 0
        p_gold[true_len:, b] = 0
    s_gold = torch.sum(s_gold, 0)
    np.testing.assert_allclose(p.detach().numpy(), p_gold.detach().numpy())
    np.testing.assert_allclose(s.detach().numpy(), s_gold.detach().numpy())