import pytest
import numpy as np
import torch

# NOTE: the import paths below are assumptions about where these utilities
# live in this repo; adjust them to match the actual module layout. The
# generate_batch / generate_examples_and_batch fixtures are expected to come
# from the test suite's conftest.
from eight_mile.utils import Offsets
from eight_mile.pytorch.layers import CRF, Viterbi
from tagger_decode_utils import (
    build_trans,
    build_emission,
    explicit_forward,
    explicit_score_gold,
)


def test_decode_shape_crf(generate_batch):
    unary, _, lengths = generate_batch
    h = unary.size(2)
    crf = CRF(h, batch_first=False)
    paths, scores = crf.decode(unary, lengths)
    assert scores.shape == torch.Size([unary.size(1)])
    assert paths.shape == torch.Size([unary.size(0), unary.size(1)])

def test_score_sentence_batch_stable(generate_examples_and_batch):
    i1, t1, l1, i2, t2, l2, i, t, l = generate_examples_and_batch
    h = i1.size(2)
    crf = CRF(h, batch_first=False)
    crf.transitions_p.data = torch.rand(1, h, h)
    score1 = crf.score_sentence(i1, t1, l1)
    score2 = crf.score_sentence(i2, t2, l2)
    one_x_one = torch.cat([score1, score2], dim=0)
    batched = crf.score_sentence(i, t, l)
    np.testing.assert_allclose(one_x_one.detach().numpy(), batched.detach().numpy())

def test_forward_batch_stable(generate_examples_and_batch):
    i1, _, l1, i2, _, l2, i, _, l = generate_examples_and_batch
    h = i1.size(2)
    crf = CRF(h, batch_first=False)
    crf.transitions_p.data = torch.rand(1, h, h)
    fw1 = crf.forward((i1, l1))
    fw2 = crf.forward((i2, l2))
    one_x_one = torch.cat([fw1, fw2], dim=0)
    batched = crf.forward((i, l))
    np.testing.assert_allclose(one_x_one.detach().numpy(), batched.detach().numpy())

def test_viterbi_score_equals_sentence_score(generate_batch):
    """Test that the scores from viterbi decoding are the same scores that
    you get when looking up those returned paths with score_sentence."""
    unary, _, lengths = generate_batch
    h = unary.size(2)
    crf = CRF(h, batch_first=False)
    trans = torch.rand(h, h)
    crf.transitions_p.data = trans.unsqueeze(0)
    p, viterbi_scores = Viterbi(Offsets.GO, Offsets.EOS)(unary, crf.transitions, lengths)
    gold_scores = crf.score_sentence(unary, p, lengths)
    np.testing.assert_allclose(viterbi_scores.detach().numpy(), gold_scores.detach().numpy(), rtol=1e-6)

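# Illustrative only, not the library's Viterbi: a from-scratch numpy sketch of
# what max-product decoding computes for one sentence. It assumes a (T, H)
# array of emission scores and an (H, H) transition matrix where trans[i, j]
# scores moving from tag i to tag j; the name and signature are hypothetical.
def _sketch_viterbi_decode(emissions, trans, go, eos):
    scores = trans[go] + emissions[0]  # transition out of GO plus first unary
    backpointers = []
    for emiss in emissions[1:]:
        # all_scores[i, j] = best score ending in tag i, then moving to tag j
        all_scores = scores[:, None] + trans
        backpointers.append(np.argmax(all_scores, axis=0))
        scores = np.max(all_scores, axis=0) + emiss
    scores = scores + trans[:, eos]  # close the path with the EOS transition
    best = int(np.argmax(scores))
    path = [best]
    for ptrs in reversed(backpointers):
        best = int(ptrs[best])
        path.append(best)
    return list(reversed(path)), float(np.max(scores))
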
def test_neg_log_loss_batch_stable(generate_examples_and_batch):
    i1, t1, l1, i2, t2, l2, i, t, l = generate_examples_and_batch
    h = i1.size(2)
    crf = CRF(h, batch_first=False)
    crf.transitions_p.data = torch.rand(1, h, h)
    nll1 = crf.neg_log_loss(i1, t1, l1)
    nll2 = crf.neg_log_loss(i2, t2, l2)
    # neg_log_loss averages over the batch, so the mean of the two
    # single-example losses should match the two-example batched loss.
    one_x_one = (nll1 + nll2) / 2
    batched = crf.neg_log_loss(i, t, l)
    np.testing.assert_allclose(one_x_one.detach().numpy(), batched.detach().numpy())

def test_decode_batch_stable(generate_examples_and_batch):
    i1, _, l1, i2, _, l2, i, _, l = generate_examples_and_batch
    h = i1.size(2)
    crf = CRF(h, batch_first=False)
    crf.transitions_p.data = torch.rand(1, h, h)
    p1, s1 = crf.decode(i1, l1)
    p2, s2 = crf.decode(i2, l2)
    # Pad the shorter path so the two single-example decodes can be stacked
    # into a batch that lines up with the batched decode.
    pad = torch.zeros(p1.size(0) - p2.size(0), 1, dtype=torch.long)
    one_x_one_p = torch.cat([p1, torch.cat([p2, pad], dim=0)], dim=1)
    one_x_one_s = torch.cat([s1, s2], dim=0)
    batched_p, batched_s = crf.decode(i, l)
    np.testing.assert_allclose(one_x_one_s.detach().numpy(), batched_s.detach().numpy())
    for single, batch in zip(one_x_one_p, batched_p):
        np.testing.assert_allclose(single.detach().numpy(), batch.detach().numpy())

def test_forward(generate_batch):
    unary, _, lengths = generate_batch
    h = unary.size(2)
    crf = CRF(h, batch_first=False)
    trans = torch.rand(h, h)
    crf.transitions_p.data = trans.unsqueeze(0)
    forward = crf.forward((unary, lengths))
    new_trans = build_trans(trans)
    unary = unary.transpose(0, 1)
    scores = []
    for u, l in zip(unary, lengths):
        emiss = build_emission(u[:l])
        scores.append(explicit_forward(emiss, new_trans, Offsets.GO, Offsets.EOS))
    gold_scores = np.array(scores)
    np.testing.assert_allclose(forward.detach().numpy(), gold_scores, rtol=1e-6)

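# Illustrative only: `explicit_forward` above is a shared test helper defined
# elsewhere (its emission/transition representation comes from
# build_emission/build_trans and may differ). This hypothetical numpy sketch
# shows the log-space forward recursion such an oracle computes:
#   alpha_1(j) = trans(GO, j) + emiss_1(j)
#   alpha_t(j) = logsumexp_i(alpha_{t-1}(i) + trans(i, j)) + emiss_t(j)
#   Z = logsumexp_j(alpha_T(j) + trans(j, EOS))
def _sketch_forward_score(emissions, trans, go, eos):
    alphas = trans[go] + emissions[0]
    for emiss in emissions[1:]:
        # stable logsumexp over the previous tag for every current tag
        alphas = np.logaddexp.reduce(alphas[:, None] + trans, axis=0) + emiss
    return float(np.logaddexp.reduce(alphas + trans[:, eos]))
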
def test_mask_is_applied():
    h = np.random.randint(22, 41)
    loc = np.random.randint(h)
    constraint = torch.zeros(h, h, dtype=torch.uint8)
    constraint[Offsets.GO, loc] = 1
    crf = CRF(h, constraint_mask=constraint)
    t = crf.transitions.detach().numpy()
    # A masked transition should be pinned to the large negative score.
    assert t[0, Offsets.GO, loc] == -1e4

def test_score_sentence(generate_batch):
    unary, tags, lengths = generate_batch
    h = unary.size(2)
    crf = CRF(h, batch_first=False)
    trans = torch.rand(h, h)
    crf.transitions_p.data = trans.unsqueeze(0)
    sentence_score = crf.score_sentence(unary, tags, lengths)
    new_trans = build_trans(trans)
    unary = unary.transpose(0, 1)
    tags = tags.transpose(0, 1)
    scores = []
    for u, t, l in zip(unary, tags, lengths):
        emiss = build_emission(u[:l])
        golds = t[:l].tolist()
        scores.append(explicit_score_gold(emiss, new_trans, golds, Offsets.GO, Offsets.EOS))
    gold_scores = np.array(scores)
    np.testing.assert_allclose(sentence_score.detach().numpy(), gold_scores, rtol=1e-6)

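# Illustrative only: a hypothetical numpy counterpart to the shared
# `explicit_score_gold` helper. Scoring a known tag path is just the sum of
# the unaries along the path plus every transition taken, including GO into
# the first tag and the last tag into EOS.
def _sketch_gold_score(emissions, trans, tags, go, eos):
    score = trans[go, tags[0]] + emissions[0, tags[0]]
    for t, (prev, curr) in enumerate(zip(tags, tags[1:]), 1):
        score += trans[prev, curr] + emissions[t, curr]
    return score + trans[tags[-1], eos]
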
def test_mask_same_after_update(generate_batch):
    from torch.optim import SGD
    unary, tags, lengths = generate_batch
    h = unary.size(2)
    constraint = torch.rand(h, h) < 0.5
    crf = CRF(h, constraint_mask=constraint, batch_first=False)
    opt = SGD(crf.parameters(), lr=10)
    # Snapshot with clone(): .numpy() alone shares memory with the tensor,
    # which would make the before/after comparison vacuous.
    m1 = crf.constraint_mask.clone().numpy()
    t1 = crf.transitions_p.detach().clone().numpy()
    loss = torch.mean(crf.neg_log_loss(unary, tags, lengths))
    loss.backward()
    opt.step()
    m2 = crf.constraint_mask.numpy()
    t2 = crf.transitions_p.detach().numpy()
    # The constraint mask is a buffer, not a parameter: the update must move
    # the transitions but leave the mask untouched.
    np.testing.assert_allclose(m1, m2)
    with pytest.raises(AssertionError):
        np.testing.assert_allclose(t1, t2)

def test_mask_not_applied():
    h = np.random.randint(22, 41)
    crf = CRF(h)
    t = crf.transitions.detach().numpy()
    assert t[0, Offsets.GO, np.random.randint(h)] != -1e4

def test_forward_shape(generate_batch):
    unary, _, lengths = generate_batch
    h = unary.size(2)
    crf = CRF(h, batch_first=False)
    fwd = crf.forward((unary, lengths))
    assert fwd.shape == torch.Size([unary.size(1)])

def test_score_sentence_shape(generate_batch):
    unary, tags, lengths = generate_batch
    h = unary.size(2)
    crf = CRF(h, batch_first=False)
    score = crf.score_sentence(unary, tags, lengths)
    assert score.shape == torch.Size([unary.size(1)])