예제 #1
0
def test_neg_log_loss_batch_stable(generate_examples_and_batch):
    """Per-example NLLs, concatenated, must equal the NLL of the whole batch.

    The fixture is expected to yield two individual examples (unary scores,
    tags, lengths for each) followed by a batch built from them; presumably
    the batch is the two examples stacked in the same order — confirm
    against the fixture definition.
    """
    (unary_a, tags_a, lens_a,
     unary_b, tags_b, lens_b,
     unary_ab, tags_ab, lens_ab) = generate_examples_and_batch
    num_tags = unary_a.size(2)
    crf = CRF(num_tags)
    # Overwrite the transitions with random values so the comparison
    # exercises non-trivial transition scores.
    crf.transitions_p.data = torch.rand(1, num_tags, num_tags)
    separate = torch.cat(
        [
            crf.neg_log_loss(unary_a, tags_a, lens_a),
            crf.neg_log_loss(unary_b, tags_b, lens_b),
        ],
        dim=0,
    )
    together = crf.neg_log_loss(unary_ab, tags_ab, lens_ab)
    np.testing.assert_allclose(
        separate.detach().numpy(), together.detach().numpy()
    )
예제 #2
0
def test_neg_log_loss_batch_stable(generate_examples_and_batch):
    """Scoring two examples separately must agree with scoring them batched.

    The fixture supplies (unary, tags, lengths) for a first example, a second
    example, and a batch — presumably the two examples combined in order;
    verify against the fixture.
    """
    fixture = generate_examples_and_batch
    first_unary, first_tags, first_lens = fixture[0], fixture[1], fixture[2]
    second_unary, second_tags, second_lens = fixture[3], fixture[4], fixture[5]
    batch_unary, batch_tags, batch_lens = fixture[6], fixture[7], fixture[8]

    tag_count = first_unary.size(2)
    model = CRF(tag_count)
    # Random transition scores instead of whatever CRF initialises with.
    model.transitions_p.data = torch.rand(1, tag_count, tag_count)

    first_loss = model.neg_log_loss(first_unary, first_tags, first_lens)
    second_loss = model.neg_log_loss(second_unary, second_tags, second_lens)
    stacked = torch.cat([first_loss, second_loss], dim=0)
    batch_loss = model.neg_log_loss(batch_unary, batch_tags, batch_lens)

    expected = stacked.detach().numpy()
    actual = batch_loss.detach().numpy()
    np.testing.assert_allclose(expected, actual)
예제 #3
0
def test_mask_same_after_update(generate_batch):
    """An SGD step must move the transitions but leave the constraint mask intact.

    Builds a CRF with a random boolean constraint, takes one large-lr SGD step
    on the mean NLL, then checks that the stored constraint is unchanged while
    the transition parameters did change.
    """
    from torch.optim import SGD
    unary, tags, lengths = generate_batch
    h = unary.size(2)
    constraint = torch.rand(h, h) < 0.5
    crf = CRF(h, constraint=constraint)
    opt = SGD(crf.parameters(), lr=10)
    # Snapshot with clone(): Tensor.numpy() shares memory with the tensor, so
    # an un-cloned array would follow any in-place mutation of the mask and
    # the equality check below could never fail.
    mask_before = crf.constraint.detach().clone().numpy()
    trans_before = crf.transitions_p.detach().clone().numpy()
    loss = torch.mean(crf.neg_log_loss(unary, tags, lengths))
    loss.backward()
    opt.step()
    mask_after = crf.constraint.detach().numpy()
    trans_after = crf.transitions_p.detach().numpy()
    np.testing.assert_allclose(mask_before, mask_after)
    # The transitions are expected to have moved after the step, so the
    # closeness check must fail.
    with pytest.raises(AssertionError):
        np.testing.assert_allclose(trans_before, trans_after)
예제 #4
0
def test_neg_log_loss(generate_batch):
    """CRF.neg_log_loss must match a per-sequence reference NLL.

    The reference scores each batch element with ``explicit_nll``, using the
    emissions truncated to that element's length and the transition matrix
    expanded by ``build_trans``, with ``Offsets.GO`` / ``Offsets.EOS`` as the
    boundary tags.
    """
    unary, tags, lengths = generate_batch
    num_tags = unary.size(2)
    crf = CRF(num_tags)
    raw_trans = torch.rand(num_tags, num_tags)
    crf.transitions_p.data = raw_trans.unsqueeze(0)
    predicted = crf.neg_log_loss(unary, tags, lengths)

    expanded_trans = build_trans(raw_trans)
    # transpose(0, 1) so that zip walks dim 1 — assumed to be the batch axis
    # (inputs look like (seq, batch, ...)); confirm against the fixture.
    expected = np.array([
        explicit_nll(
            build_emission(seq_unary[:length]),
            expanded_trans,
            seq_tags[:length].tolist(),
            Offsets.GO,
            Offsets.EOS,
        )
        for seq_unary, seq_tags, length in zip(
            unary.transpose(0, 1), tags.transpose(0, 1), lengths
        )
    ])
    np.testing.assert_allclose(predicted.detach().numpy(), expected, rtol=1e-6)
예제 #5
0
def test_mask_same_after_update(generate_batch):
    """An SGD step must move the transitions but leave the constraint mask intact.

    Builds a CRF with a random boolean constraint, takes one large-lr SGD step
    on the mean NLL, then checks that the stored constraint is unchanged while
    the transition parameters did change.
    """
    from torch.optim import SGD
    unary, tags, lengths = generate_batch
    h = unary.size(2)
    constraint = torch.rand(h, h) < 0.5
    crf = CRF(h, constraint=constraint)
    opt = SGD(crf.parameters(), lr=10)
    # Snapshot with clone(): Tensor.numpy() shares memory with the tensor, so
    # an un-cloned array would follow any in-place mutation of the mask and
    # the equality check below could never fail.
    mask_before = crf.constraint.detach().clone().numpy()
    trans_before = crf.transitions_p.detach().clone().numpy()
    loss = torch.mean(crf.neg_log_loss(unary, tags, lengths))
    loss.backward()
    opt.step()
    mask_after = crf.constraint.detach().numpy()
    trans_after = crf.transitions_p.detach().numpy()
    np.testing.assert_allclose(mask_before, mask_after)
    # The transitions are expected to have moved after the step, so the
    # closeness check must fail.
    with pytest.raises(AssertionError):
        np.testing.assert_allclose(trans_before, trans_after)
예제 #6
0
def test_neg_log_loss(generate_batch):
    """Compare CRF.neg_log_loss with ``explicit_nll`` computed one sequence at a time."""
    emissions, gold_tags, seq_lens = generate_batch
    tag_dim = emissions.size(2)
    model = CRF(tag_dim)
    base_trans = torch.rand(tag_dim, tag_dim)
    model.transitions_p.data = base_trans.unsqueeze(0)
    model_nll = model.neg_log_loss(emissions, gold_tags, seq_lens)

    full_trans = build_trans(base_trans)
    # Swap dims 0 and 1 so iteration yields one sequence per step
    # (presumably dim 1 is the batch axis — confirm with the fixture).
    per_seq_emissions = emissions.transpose(0, 1)
    per_seq_tags = gold_tags.transpose(0, 1)
    reference = []
    for seq_emissions, seq_gold, seq_len in zip(per_seq_emissions, per_seq_tags, seq_lens):
        trimmed = build_emission(seq_emissions[:seq_len])
        gold_path = seq_gold[:seq_len].tolist()
        reference.append(
            explicit_nll(trimmed, full_trans, gold_path, Offsets.GO, Offsets.EOS)
        )
    np.testing.assert_allclose(
        model_nll.detach().numpy(), np.array(reference), rtol=1e-6
    )
예제 #7
0
def test_neg_log_loss_shape(generate_batch):
    """neg_log_loss returns a 1-D tensor with one entry per element of dim 1 of unary."""
    unary, tags, lengths = generate_batch
    crf = CRF(unary.size(2))
    losses = crf.neg_log_loss(unary, tags, lengths)
    expected_shape = torch.Size([unary.size(1)])
    assert losses.shape == expected_shape
예제 #8
0
def test_neg_log_loss_shape(generate_batch):
    """The NLL output must be shaped [unary.size(1)], i.e. one loss per sequence."""
    batch_unary, batch_tags, batch_lengths = generate_batch
    n_tags = batch_unary.size(2)
    model = CRF(n_tags)
    result = model.neg_log_loss(batch_unary, batch_tags, batch_lengths)
    assert result.shape == torch.Size([batch_unary.size(1)])