Example #1
0
def test_decode_batch_stable(generate_examples_and_batch):
    """Decoding two examples one at a time must match decoding them as a batch.

    The fixture supplies three (input, <unused>, lengths) triples: the first
    example, the second example, and the combined batch.
    """
    (
        first_unary, _, first_lengths,
        second_unary, _, second_lengths,
        batch_unary, _, batch_lengths,
    ) = generate_examples_and_batch

    num_tags = first_unary.size(2)
    crf = CRF(num_tags, batch_first=False)
    # Replace the transition parameters with random values for this check.
    crf.transitions_p.data = torch.rand(1, num_tags, num_tags)

    first_path, first_score = crf.decode(first_unary, first_lengths)
    second_path, second_score = crf.decode(second_unary, second_lengths)

    # Zero-pad the second path along dim 0 up to the first path's length so the
    # two can sit side by side; this relies on the first path not being shorter
    # (a negative pad size would raise).
    missing_steps = first_path.size(0) - second_path.size(0)
    padding = torch.zeros(missing_steps, 1, dtype=torch.long)
    second_path_padded = torch.cat([second_path, padding], dim=0)

    separate_paths = torch.cat([first_path, second_path_padded], dim=1)
    separate_scores = torch.cat([first_score, second_score], dim=0)

    batched_paths, batched_scores = crf.decode(batch_unary, batch_lengths)

    np.testing.assert_allclose(
        separate_scores.detach().numpy(),
        batched_scores.detach().numpy(),
    )
    # Compare the paths one dim-0 slice at a time.
    for expected_row, actual_row in zip(separate_paths, batched_paths):
        np.testing.assert_allclose(
            expected_row.detach().numpy(),
            actual_row.detach().numpy(),
        )
Example #2
0
def test_decode_shape_crf(generate_batch):
    """Check the shapes of the tensors returned by ``CRF.decode``.

    Scores must be 1-D with ``unary.size(1)`` entries, and paths must be 2-D
    with shape ``(unary.size(0), unary.size(1))``.
    """
    unary, _, lengths = generate_batch
    # NOTE(review): presumably unary is (time, batch, tags) given
    # batch_first=False -- confirm against the fixture.
    dim0, dim1, num_tags = unary.size(0), unary.size(1), unary.size(2)

    crf = CRF(num_tags, batch_first=False)
    best_paths, best_scores = crf.decode(unary, lengths)

    assert best_scores.shape == (dim1,)
    assert best_paths.shape == (dim0, dim1)