def test_forward_batch_stable(generate_examples_and_batch):
    """Check that CRF.forward scores agree batched vs. one example at a time.

    The fixture yields two single examples (inputs ``i1``/``i2`` with lengths
    ``l1``/``l2``) plus a batched input with its lengths. The per-example
    forward scores, concatenated along dim 0, must match the batched scores.
    """
    i1, _, l1, i2, _, l2, batch_inputs, _, batch_lengths = generate_examples_and_batch
    # Size of the third input dimension, used to size the CRF
    # (presumably the number of tags — TODO confirm).
    h = i1.size(2)
    crf = CRF(h, batch_first=False)
    # Overwrite the transition parameters with random values of shape (1, h, h).
    crf.transitions_p.data = torch.rand(1, h, h)
    fw1 = crf.forward((i1, l1))
    fw2 = crf.forward((i2, l2))
    # Assumes the fixture's batch holds example 1 then example 2 — TODO confirm.
    one_x_one = torch.cat([fw1, fw2], dim=0)
    batched = crf.forward((batch_inputs, batch_lengths))
    # rtol=1e-6 matches test_forward. The default rtol=1e-7 is below float32
    # epsilon and would require bit-identical batched/unbatched reductions.
    np.testing.assert_allclose(
        one_x_one.detach().numpy(),
        batched.detach().numpy(),
        rtol=1e-6,
    )
def test_forward(generate_batch):
    """Compare the batched CRF.forward against an explicit per-sequence reference.

    The CRF's transitions are set from a random ``trans`` matrix. Each sequence
    in the batch is truncated to its length and scored with ``explicit_forward``.
    The resulting gold scores must match ``crf.forward`` within rtol=1e-6.
    """
    unary, _, lengths = generate_batch
    num_tags = unary.size(2)
    crf = CRF(num_tags, batch_first=False)

    trans = torch.rand(num_tags, num_tags)
    # The CRF holds transitions with a leading singleton dimension: (1, h, h).
    crf.transitions_p.data = trans.unsqueeze(0)
    forward = crf.forward((unary, lengths))

    reference_trans = build_trans(trans)
    # Swap dims 0 and 1 so iteration walks dim 1 of ``unary`` (the batch),
    # pairing each example with its entry in ``lengths``.
    per_example = unary.transpose(0, 1)
    gold_scores = np.array([
        explicit_forward(
            build_emission(example[:length]),
            reference_trans,
            Offsets.GO,
            Offsets.EOS,
        )
        for example, length in zip(per_example, lengths)
    ])
    np.testing.assert_allclose(forward.detach().numpy(), gold_scores, rtol=1e-6)
def test_forward_shape(generate_batch):
    """CRF.forward should return exactly one score per example.

    With ``batch_first=False`` the batch is dim 1 of ``unary``, so the
    output must be a 1-D tensor whose length is ``unary.size(1)``.
    """
    unary, _, lengths = generate_batch
    crf = CRF(unary.size(2), batch_first=False)
    scores = crf.forward((unary, lengths))
    batch_size = unary.size(1)
    # torch.Size is a tuple subclass, so a plain tuple comparison is equivalent.
    assert tuple(scores.shape) == (batch_size,)