def test_decode_batch_stable(generate_examples_and_batch):
    """Decoding a batch must agree with decoding each example on its own.

    Two examples are decoded one at a time. Their paths are zero-padded along
    dim 0 to a common length and concatenated along dim 1. Their scores are
    concatenated along dim 0. Both results are then compared against a single
    ``crf.decode`` call on the batched input.
    """
    i1, _, l1, i2, _, l2, i, _, l = generate_examples_and_batch
    h = i1.size(2)
    crf = CRF(h, batch_first=False)
    # Use non-trivial random transition scores for this comparison.
    crf.transitions_p.data = torch.rand(1, h, h)
    p1, s1 = crf.decode(i1, l1)
    p2, s2 = crf.decode(i2, l2)

    def pad_dim0(path, length):
        # Zero-pad `path` along dim 0 up to `length` rows.
        pad = torch.zeros(length - path.size(0), path.size(1), dtype=torch.long)
        return torch.cat([path, pad], dim=0)

    # Pad whichever single-example path is shorter, rather than assuming
    # the first one is always the longest.
    max_len = max(p1.size(0), p2.size(0))
    one_x_one_p = torch.cat([pad_dim0(p1, max_len), pad_dim0(p2, max_len)], dim=1)
    one_x_one_s = torch.cat([s1, s2], dim=0)

    batched_p, batched_s = crf.decode(i, l)
    np.testing.assert_allclose(one_x_one_s.detach().numpy(), batched_s.detach().numpy())
    # Check the shape explicitly and compare whole tensors. Zipping rows would
    # silently skip trailing rows if the two path tensors had different lengths.
    assert one_x_one_p.shape == batched_p.shape
    np.testing.assert_allclose(one_x_one_p.detach().numpy(), batched_p.detach().numpy())
def test_decode_shape_crf(generate_batch):
    """Check the shapes returned by ``CRF.decode`` with ``batch_first=False``.

    The paths must have the shape of the input's first two dims. The scores
    must have one entry per ``input.size(1)``.
    """
    emissions, _, seq_lengths = generate_batch
    dim0, dim1, last_dim = emissions.size(0), emissions.size(1), emissions.size(2)
    model = CRF(last_dim, batch_first=False)
    best_paths, best_scores = model.decode(emissions, seq_lengths)
    assert best_scores.shape == torch.Size([dim1])
    assert best_paths.shape == torch.Size([dim0, dim1])