def test_score_sentence_batch_stable(generate_examples_and_batch):
    """Check that scoring a batch matches scoring its examples one at a time.

    The fixture yields two single examples ``(i1, t1, l1)`` and ``(i2, t2, l2)``
    plus a batched input ``(i, t, l)``. That batch is assumed to hold the two
    examples in that order (TODO confirm against the fixture). Each example is
    scored on its own, the scores are concatenated, and the result is compared
    to the batched score.
    """
    i1, t1, l1, i2, t2, l2, i, t, l = generate_examples_and_batch
    h = i1.size(2)
    crf = CRF(h, batch_first=False)
    # Random transitions so the transition term actually contributes to the score.
    crf.transitions_p.data = torch.rand(1, h, h)
    score1 = crf.score_sentence(i1, t1, l1)
    score2 = crf.score_sentence(i2, t2, l2)
    one_x_one = torch.cat([score1, score2], dim=0)
    batched = crf.score_sentence(i, t, l)
    # Batched and per-example float32 reductions can differ in the last ULP.
    # The default rtol (1e-7) is below float32 epsilon, so use 1e-6 like the
    # other score tests.
    np.testing.assert_allclose(
        one_x_one.detach().numpy(), batched.detach().numpy(), rtol=1e-6
    )
# Example #2
# 0
def test_viterbi_score_equals_sentence_score(generate_batch):
    """Test that the scores from viterbi decoding are the same scores that you get when looking up those returned paths."""
    unary, _, lengths = generate_batch
    h = unary.size(2)
    # The fixture is time-major (batch on dim 1), as in the other score tests.
    crf = CRF(h, batch_first=False)
    # Install random transitions so the check covers the transition term,
    # not just the CRF's initial transition values.
    trans = torch.rand(h, h)
    crf.transitions_p.data = trans.unsqueeze(0)

    p, viterbi_scores = Viterbi(Offsets.GO, Offsets.EOS)(unary, crf.transitions, lengths)
    # Re-score the decoded paths; they must match the scores Viterbi reported.
    gold_scores = crf.score_sentence(unary, p, lengths)
    np.testing.assert_allclose(
        viterbi_scores.detach().numpy(), gold_scores.detach().numpy(), rtol=1e-6
    )
def test_score_sentence(generate_batch):
    """Compare ``CRF.score_sentence`` against an explicit per-sentence gold score.

    The CRF gets random transitions. Each sentence in the batch is then
    rescored with ``explicit_score_gold`` on its unpadded emissions and tags,
    and both sets of scores must agree to ``rtol=1e-6``.
    """
    unary, tags, lengths = generate_batch
    num_tags = unary.size(2)
    crf = CRF(num_tags, batch_first=False)
    raw_trans = torch.rand(num_tags, num_tags)
    crf.transitions_p.data = raw_trans.unsqueeze(0)
    sentence_score = crf.score_sentence(unary, tags, lengths)

    explicit_trans = build_trans(raw_trans)
    # unary/tags carry the batch on dim 1; transpose so zip walks one sentence at a time.
    per_sentence = zip(unary.transpose(0, 1), tags.transpose(0, 1), lengths)
    gold_scores = np.array([
        explicit_score_gold(
            build_emission(emissions[:length]),
            explicit_trans,
            gold_tags[:length].tolist(),
            Offsets.GO,
            Offsets.EOS,
        )
        for emissions, gold_tags, length in per_sentence
    ])
    np.testing.assert_allclose(sentence_score.detach().numpy(), gold_scores, rtol=1e-6)
# Example #4
# 0
def test_score_sentence_shape(generate_batch):
    """``score_sentence`` should return a 1-D tensor with one score per batch entry."""
    unary, tags, lengths = generate_batch
    crf = CRF(unary.size(2), batch_first=False)
    score = crf.score_sentence(unary, tags, lengths)
    # With batch_first=False the batch size lives on dim 1 of unary.
    batch_size = unary.size(1)
    assert score.shape == torch.Size([batch_size])