def test_score_sentence_batch_stable(generate_examples_and_batch):
    """Scoring inputs separately must agree with scoring them as one batch.

    ``score_sentence`` is called on two inputs (``i1``/``t1``/``l1`` and
    ``i2``/``t2``/``l2``) and the scores are concatenated. The result is
    compared with a single call on the combined batch ``(i, t, l)`` supplied
    by the same fixture. Both paths use the same random transition matrix.
    """
    (
        unary1, tags1, lengths1,
        unary2, tags2, lengths2,
        unary_batch, tags_batch, lengths_batch,
    ) = generate_examples_and_batch
    h = unary1.size(2)
    crf = CRF(h)
    # transitions_p is assigned with a leading singleton dimension: (1, h, h).
    crf.transitions_p.data = torch.rand(1, h, h)
    score1 = crf.score_sentence(unary1, tags1, lengths1, lengths1.size(0))
    score2 = crf.score_sentence(unary2, tags2, lengths2, lengths2.size(0))
    one_by_one = torch.cat([score1, score2], dim=0)
    batched = crf.score_sentence(unary_batch, tags_batch, lengths_batch, lengths_batch.size(0))
    # The batched and per-example computations may round differently. The
    # default rtol (1e-7) is below float32 epsilon (~1.2e-7) when the scores
    # are float32 (presumably, since torch.rand defaults to float32), which
    # makes the test flaky. Use the same tolerance as test_score_sentence.
    np.testing.assert_allclose(
        one_by_one.detach().numpy(),
        batched.detach().numpy(),
        rtol=1e-6,
    )
def test_score_sentence(generate_batch):
    """Check ``CRF.score_sentence`` against an explicit per-sequence gold score.

    The batched score from ``CRF.score_sentence`` is compared, with
    ``rtol=1e-6``, to scores computed one sequence at a time by
    ``explicit_score_gold``. Each reference score uses:

    * the transition matrix produced by ``build_trans``;
    * emissions from ``build_emission``, truncated to that sequence's length;
    * ``Offsets.GO`` / ``Offsets.EOS`` as the boundary tags.

    Note: this test was previously defined twice in a row. The second
    definition silently shadowed the first, so only one was ever collected
    by pytest.
    """
    unary, tags, lengths = generate_batch
    num_tags = unary.size(2)
    crf = CRF(num_tags)
    trans = torch.rand(num_tags, num_tags)
    # transitions_p is assigned with a leading singleton dimension: (1, h, h).
    crf.transitions_p.data = trans.unsqueeze(0)
    sentence_score = crf.score_sentence(unary, tags, lengths, unary.size(1))

    new_trans = build_trans(trans)
    # Swap the first two axes so each iteration yields one sequence paired
    # with its entry in `lengths`. Presumably the inputs are
    # (seq_len, batch, ...) — confirm against the generate_batch fixture.
    unary = unary.transpose(0, 1)
    tags = tags.transpose(0, 1)
    scores = []
    for seq_unary, seq_tags, length in zip(unary, tags, lengths):
        emiss = build_emission(seq_unary[:length])
        golds = seq_tags[:length].tolist()
        scores.append(
            explicit_score_gold(emiss, new_trans, golds, Offsets.GO, Offsets.EOS)
        )
    gold_scores = np.array(scores)
    np.testing.assert_allclose(
        sentence_score.detach().numpy(), gold_scores, rtol=1e-6
    )
def test_score_sentence_shape(generate_batch):
    """``score_sentence`` returns a 1-D tensor of length ``unary.size(1)``.

    This is one score per entry along ``unary``'s second axis, presumably
    the batch axis — TODO confirm against the generate_batch fixture.
    """
    unary, tags, lengths = generate_batch
    num_tags = unary.size(2)
    crf = CRF(num_tags)
    batch_size = lengths.size(0)
    result = crf.score_sentence(unary, tags, lengths, batch_size)
    expected_shape = torch.Size([unary.size(1)])
    assert result.shape == expected_shape