示例#1
0
def test_decode_shape(generate_batch):
    """Check the shapes of the paths and scores returned by ``CRF.decode``.

    Scores must have one entry per index along dim 1 of ``unary``; paths must
    span dims 0 and 1 of ``unary``.
    """
    unary, _, lengths = generate_batch
    crf = CRF(unary.size(2))
    paths, scores = crf.decode(unary, lengths)
    dim0, dim1 = unary.size(0), unary.size(1)
    assert scores.shape == torch.Size([dim1])
    assert paths.shape == torch.Size([dim0, dim1])
示例#2
0
def test_decode_shape(generate_batch):
    """Verify ``CRF.decode`` output shapes against the input ``unary`` tensor."""
    unary, _, lengths = generate_batch
    n_labels = unary.size(2)
    model = CRF(n_labels)
    decoded = model.decode(unary, lengths)
    paths, scores = decoded
    # One score per index along dim 1 of the input.
    expected_scores_shape = torch.Size([unary.size(1)])
    assert scores.shape == expected_scores_shape
    # Paths cover the first two dims of the input.
    expected_paths_shape = torch.Size([unary.size(0), unary.size(1)])
    assert paths.shape == expected_paths_shape
示例#3
0
def test_score_sentence_batch_stable(generate_examples_and_batch):
    """Check ``CRF.score_sentence`` gives the same result singly and batched.

    The two individual scores are concatenated along dim 0 and compared with
    the score of the third (input, tags, lengths) triple from the fixture
    (presumably the batch built from the first two examples).
    """
    (in_a, tags_a, len_a,
     in_b, tags_b, len_b,
     in_all, tags_all, len_all) = generate_examples_and_batch
    n_labels = in_a.size(2)
    crf = CRF(n_labels)
    # Replace the transitions with random values so the comparison is non-trivial.
    crf.transitions_p.data = torch.rand(1, n_labels, n_labels)
    separate = torch.cat(
        [
            crf.score_sentence(in_a, tags_a, len_a, len_a.size(0)),
            crf.score_sentence(in_b, tags_b, len_b, len_b.size(0)),
        ],
        dim=0,
    )
    together = crf.score_sentence(in_all, tags_all, len_all, len_all.size(0))
    np.testing.assert_allclose(separate.detach().numpy(), together.detach().numpy())
示例#4
0
def test_neg_log_loss_batch_stable(generate_examples_and_batch):
    """Check ``CRF.neg_log_loss`` gives the same values singly and batched.

    Losses for the two individual examples are concatenated along dim 0 and
    compared with the loss of the fixture's third triple (presumably the batch
    built from those two examples).
    """
    (in_a, tags_a, len_a,
     in_b, tags_b, len_b,
     in_all, tags_all, len_all) = generate_examples_and_batch
    n_labels = in_a.size(2)
    crf = CRF(n_labels)
    # Random transitions make the comparison non-trivial.
    crf.transitions_p.data = torch.rand(1, n_labels, n_labels)
    loss_a = crf.neg_log_loss(in_a, tags_a, len_a)
    loss_b = crf.neg_log_loss(in_b, tags_b, len_b)
    separate = torch.cat([loss_a, loss_b], dim=0)
    together = crf.neg_log_loss(in_all, tags_all, len_all)
    np.testing.assert_allclose(separate.detach().numpy(), together.detach().numpy())
示例#5
0
def test_forward_batch_stable(generate_examples_and_batch):
    """Check ``CRF.forward`` gives the same values singly and batched.

    The tag entries of the fixture are not needed here and are discarded.
    """
    in_a, _, len_a, in_b, _, len_b, in_all, _, len_all = generate_examples_and_batch
    n_labels = in_a.size(2)
    crf = CRF(n_labels)
    # Random transitions make the comparison non-trivial.
    crf.transitions_p.data = torch.rand(1, n_labels, n_labels)
    separate = torch.cat(
        [
            crf.forward(in_a, len_a, len_a.size(0)),
            crf.forward(in_b, len_b, len_b.size(0)),
        ],
        dim=0,
    )
    together = crf.forward(in_all, len_all, len_all.size(0))
    np.testing.assert_allclose(separate.detach().numpy(), together.detach().numpy())
示例#6
0
def test_forward_batch_stable(generate_examples_and_batch):
    """Verify ``CRF.forward`` is unaffected by batching.

    Results for two separately processed examples, concatenated along dim 0,
    must match the result for the fixture's final (input, lengths) pair
    (presumably the two examples batched together).
    """
    in_a, _, len_a, in_b, _, len_b, in_all, _, len_all = generate_examples_and_batch
    n_labels = in_a.size(2)
    model = CRF(n_labels)
    model.transitions_p.data = torch.rand(1, n_labels, n_labels)
    fwd_a = model.forward(in_a, len_a, len_a.size(0))
    fwd_b = model.forward(in_b, len_b, len_b.size(0))
    expected = torch.cat([fwd_a, fwd_b], dim=0).detach().numpy()
    actual = model.forward(in_all, len_all, len_all.size(0)).detach().numpy()
    np.testing.assert_allclose(expected, actual)
示例#7
0
def test_neg_log_loss_batch_stable(generate_examples_and_batch):
    """Verify ``CRF.neg_log_loss`` is unaffected by batching.

    Per-example losses, concatenated along dim 0, must match the loss computed
    on the fixture's final triple (presumably the two examples batched).
    """
    (in_a, tags_a, len_a,
     in_b, tags_b, len_b,
     in_all, tags_all, len_all) = generate_examples_and_batch
    n_labels = in_a.size(2)
    model = CRF(n_labels)
    model.transitions_p.data = torch.rand(1, n_labels, n_labels)
    per_example = [
        model.neg_log_loss(in_a, tags_a, len_a),
        model.neg_log_loss(in_b, tags_b, len_b),
    ]
    expected = torch.cat(per_example, dim=0).detach().numpy()
    actual = model.neg_log_loss(in_all, tags_all, len_all).detach().numpy()
    np.testing.assert_allclose(expected, actual)
示例#8
0
def test_score_sentence_batch_stable(generate_examples_and_batch):
    """Verify ``CRF.score_sentence`` is unaffected by batching.

    Per-example scores, concatenated along dim 0, must match the score computed
    on the fixture's final triple (presumably the two examples batched).
    """
    (in_a, tags_a, len_a,
     in_b, tags_b, len_b,
     in_all, tags_all, len_all) = generate_examples_and_batch
    n_labels = in_a.size(2)
    model = CRF(n_labels)
    model.transitions_p.data = torch.rand(1, n_labels, n_labels)
    score_a = model.score_sentence(in_a, tags_a, len_a, len_a.size(0))
    score_b = model.score_sentence(in_b, tags_b, len_b, len_b.size(0))
    expected = torch.cat([score_a, score_b], dim=0).detach().numpy()
    actual = model.score_sentence(
        in_all, tags_all, len_all, len_all.size(0)
    ).detach().numpy()
    np.testing.assert_allclose(expected, actual)
示例#9
0
def test_decode_batch_stable(generate_examples_and_batch):
    """Check ``CRF.decode`` gives the same paths and scores singly and batched.

    The second example's path is padded with zeros along dim 0 up to the
    length of the first so the two can be joined along dim 1 and compared
    with the batched paths.
    """
    in_a, _, len_a, in_b, _, len_b, in_all, _, len_all = generate_examples_and_batch
    n_labels = in_a.size(2)
    crf = CRF(n_labels)
    # Random transitions make the comparison non-trivial.
    crf.transitions_p.data = torch.rand(1, n_labels, n_labels)
    path_a, score_a = crf.decode(in_a, len_a)
    path_b, score_b = crf.decode(in_b, len_b)
    # Zero-fill the tail of the second path to the first path's length.
    filler = torch.zeros(path_a.size(0) - path_b.size(0), 1, dtype=torch.long)
    path_b_padded = torch.cat([path_b, filler], dim=0)
    separate_paths = torch.cat([path_a, path_b_padded], dim=1)
    separate_scores = torch.cat([score_a, score_b], dim=0)
    batched_paths, batched_scores = crf.decode(in_all, len_all)
    np.testing.assert_allclose(
        separate_scores.detach().numpy(), batched_scores.detach().numpy()
    )
    # Compare the paths slice by slice along dim 0.
    for expected_row, actual_row in zip(separate_paths, batched_paths):
        np.testing.assert_allclose(
            expected_row.detach().numpy(), actual_row.detach().numpy()
        )
示例#10
0
def test_decode_batch_stable(generate_examples_and_batch):
    """Verify ``CRF.decode`` is unaffected by batching.

    Both examples are decoded separately; the second path is zero-padded along
    dim 0 to the first path's length and the two are joined along dim 1. The
    joined paths and concatenated scores must match the batched decode.
    """
    in_a, _, len_a, in_b, _, len_b, in_all, _, len_all = generate_examples_and_batch
    n_labels = in_a.size(2)
    model = CRF(n_labels)
    model.transitions_p.data = torch.rand(1, n_labels, n_labels)
    path_a, score_a = model.decode(in_a, len_a)
    path_b, score_b = model.decode(in_b, len_b)
    missing_steps = path_a.size(0) - path_b.size(0)
    zeros = torch.zeros(missing_steps, 1, dtype=torch.long)
    expected_paths = torch.cat([path_a, torch.cat([path_b, zeros], dim=0)], dim=1)
    expected_scores = torch.cat([score_a, score_b], dim=0)
    actual_paths, actual_scores = model.decode(in_all, len_all)
    np.testing.assert_allclose(expected_scores.detach().numpy(),
                               actual_scores.detach().numpy())
    for want, got in zip(expected_paths, actual_paths):
        np.testing.assert_allclose(want.detach().numpy(), got.detach().numpy())
示例#11
0
def test_forward(generate_batch):
    """Compare ``CRF.forward`` with ``explicit_forward`` run on each example.

    The same random transition matrix is given to the CRF and, converted by
    ``build_trans``, to the explicit reference implementation.
    """
    unary, _, lengths = generate_batch
    n_labels = unary.size(2)
    crf = CRF(n_labels)
    trans = torch.rand(n_labels, n_labels)
    crf.transitions_p.data = trans.unsqueeze(0)
    forward = crf.forward(unary, lengths, unary.size(1))

    ref_trans = build_trans(trans)
    # Swap dims 0 and 1 so that zipping with `lengths` walks one example at a time.
    per_example = unary.transpose(0, 1)
    gold_scores = np.array([
        explicit_forward(build_emission(u[:l]), ref_trans, Offsets.GO, Offsets.EOS)
        for u, l in zip(per_example, lengths)
    ])
    np.testing.assert_allclose(forward.detach().numpy(), gold_scores, rtol=1e-6)
示例#12
0
def test_mask_skipped(label_vocab):
    """Check that a CRF built without a mask leaves ``<GO>`` -> ``O`` unmasked.

    The transition entry indexed by ``<GO>`` and ``O`` must not hold the
    ``-1e4`` value.
    """
    from baseline.pytorch.crf import CRF
    n_labels = len(label_vocab)
    idxs = (label_vocab[S], label_vocab[E])
    model = CRF(n_labels, idxs)
    transitions = model.transitions.detach().numpy()
    go_idx = label_vocab['<GO>']
    o_idx = label_vocab['O']
    assert transitions[0, go_idx, o_idx] != -1e4
示例#13
0
def test_mask_is_applied():
    """Check that a constrained transition is set to ``-1e4``.

    A single randomly chosen entry in the ``Offsets.GO`` row of the constraint
    is flagged, and the matching entry of ``crf.transitions`` is inspected.
    """
    n_labels = np.random.randint(22, 41)
    target = np.random.randint(n_labels)
    mask = torch.zeros(n_labels, n_labels, dtype=torch.uint8)
    mask[Offsets.GO, target] = 1
    model = CRF(n_labels, constraint=mask)
    transitions = model.transitions.detach().numpy()
    masked_value = transitions[0, Offsets.GO, target]
    assert masked_value == -1e4
示例#14
0
def test_mask_same_after_update(generate_batch):
    """Check that an optimizer step changes the transitions but not the mask.

    A CRF is built with a random constraint, one SGD step is taken on the mean
    negative log loss, and the constraint and transition parameters are
    compared before and after the step.
    """
    from torch.optim import SGD
    unary, tags, lengths = generate_batch
    h = unary.size(2)
    constraint = torch.rand(h, h) < 0.5
    crf = CRF(h, constraint=constraint)
    opt = SGD(crf.parameters(), lr=10)
    # Tensor.numpy() shares storage with the tensor, so the "before" snapshots
    # must be cloned; otherwise an in-place change would alter them too and
    # the comparison below could never fail.
    m1 = crf.constraint.detach().clone().numpy()
    t1 = crf.transitions_p.detach().clone().numpy()
    loss = torch.mean(crf.neg_log_loss(unary, tags, lengths))
    loss.backward()
    opt.step()
    m2 = crf.constraint.numpy()
    t2 = crf.transitions_p.detach().numpy()
    # The mask must be unchanged by the update...
    np.testing.assert_allclose(m1, m2)
    # ...while the transition parameters must have moved.
    with pytest.raises(AssertionError):
        np.testing.assert_allclose(t1, t2)
示例#15
0
def test_neg_log_loss(generate_batch):
    """Compare ``CRF.neg_log_loss`` with ``explicit_nll`` run on each example.

    The same random transition matrix is given to the CRF and, converted by
    ``build_trans``, to the explicit reference implementation.
    """
    unary, tags, lengths = generate_batch
    n_labels = unary.size(2)
    crf = CRF(n_labels)
    trans = torch.rand(n_labels, n_labels)
    crf.transitions_p.data = trans.unsqueeze(0)
    nll = crf.neg_log_loss(unary, tags, lengths)

    ref_trans = build_trans(trans)
    # Swap dims 0 and 1 so that zipping with `lengths` walks one example at a time.
    unary_by_example = unary.transpose(0, 1)
    tags_by_example = tags.transpose(0, 1)
    gold_scores = np.array([
        explicit_nll(
            build_emission(u[:l]), ref_trans, t[:l].tolist(), Offsets.GO, Offsets.EOS
        )
        for u, t, l in zip(unary_by_example, tags_by_example, lengths)
    ])
    np.testing.assert_allclose(nll.detach().numpy(), gold_scores, rtol=1e-6)
示例#16
0
def test_mask_same_after_update(generate_batch):
    """Verify one SGD step leaves the constraint mask intact but moves the transitions.

    The constraint and ``transitions_p`` are captured before and after an
    optimizer step on the mean of ``neg_log_loss``. The mask snapshots must
    match, and the transition snapshots must differ.
    """
    from torch.optim import SGD
    unary, tags, lengths = generate_batch
    h = unary.size(2)
    constraint = torch.rand(h, h) < 0.5
    crf = CRF(h, constraint=constraint)
    opt = SGD(crf.parameters(), lr=10)
    # Clone before converting: Tensor.numpy() returns an array that shares
    # memory with the tensor, so an uncloned "before" snapshot would track any
    # in-place change and make the mask check pass trivially.
    m1 = crf.constraint.detach().clone().numpy()
    t1 = crf.transitions_p.detach().clone().numpy()
    loss = crf.neg_log_loss(unary, tags, lengths)
    loss = torch.mean(loss)
    loss.backward()
    opt.step()
    m2 = crf.constraint.numpy()
    t2 = crf.transitions_p.detach().numpy()
    np.testing.assert_allclose(m1, m2)
    # assert_allclose raising here is the expected outcome: the parameters changed.
    with pytest.raises(AssertionError):
        np.testing.assert_allclose(t1, t2)
示例#17
0
def test_neg_log_loss(generate_batch):
    """Check ``CRF.neg_log_loss`` against the per-example ``explicit_nll`` reference."""
    unary, tags, lengths = generate_batch
    n_labels = unary.size(2)
    model = CRF(n_labels)
    trans = torch.rand(n_labels, n_labels)
    model.transitions_p.data = trans.unsqueeze(0)
    nll = model.neg_log_loss(unary, tags, lengths)

    ref_trans = build_trans(trans)
    expected = []
    # Transposing dims 0 and 1 lets the zip walk one example at a time.
    examples = zip(unary.transpose(0, 1), tags.transpose(0, 1), lengths)
    for emissions, gold_tags, length in examples:
        emiss = build_emission(emissions[:length])
        golds = gold_tags[:length].tolist()
        expected.append(
            explicit_nll(emiss, ref_trans, golds, Offsets.GO, Offsets.EOS))
    np.testing.assert_allclose(nll.detach().numpy(), np.array(expected), rtol=1e-6)
示例#18
0
def test_forward(generate_batch):
    """Check ``CRF.forward`` against the per-example ``explicit_forward`` reference."""
    unary, _, lengths = generate_batch
    n_labels = unary.size(2)
    model = CRF(n_labels)
    trans = torch.rand(n_labels, n_labels)
    model.transitions_p.data = trans.unsqueeze(0)
    forward = model.forward(unary, lengths, unary.size(1))

    ref_trans = build_trans(trans)
    expected = []
    # Transposing dims 0 and 1 lets the zip walk one example at a time.
    for emissions, length in zip(unary.transpose(0, 1), lengths):
        emiss = build_emission(emissions[:length])
        expected.append(
            explicit_forward(emiss, ref_trans, Offsets.GO, Offsets.EOS))
    np.testing.assert_allclose(forward.detach().numpy(),
                               np.array(expected),
                               rtol=1e-6)
示例#19
0
def crf(label_vocab):
    """Build a ``CRF`` over ``label_vocab`` with a transition mask.

    The mask comes from ``transition_mask`` using ``SPAN_TYPE`` and the
    vocabulary indices of ``S``, ``E`` and ``P``. ``True`` is passed as the
    CRF's third positional argument (presumably ``batch_first``; confirm
    against the CRF signature).
    """
    from baseline.pytorch.crf import CRF, transition_mask
    mask_args = (
        label_vocab,
        SPAN_TYPE,
        label_vocab[S],
        label_vocab[E],
        label_vocab[P],
    )
    mask = transition_mask(*mask_args)
    n_labels = len(label_vocab)
    idxs = (label_vocab[S], label_vocab[E])
    return CRF(n_labels, idxs, True, mask)
示例#20
0
    new_path = []
    for i in range(len(best_path)):
        i = len(best_path) - i - 1
        new_path.append(best_path[i])
    return torch.stack(new_path[1:]), path_score


# Ad-hoc script: decode the same random input with the training CRF and with
# InferenceCRF, then trace InferenceCRF.decode, save it and load it back.
if __name__ == '__main__':
    from baseline.pytorch.crf import CRF, transition_mask
    vocab = [
        "<GO>", "<EOS>", "B-X", "I-X", "E-X", "S-X", "O", "B-Y", "I-Y", "E-Y",
        "S-Y"
    ]
    # Map each label to its position in the list.
    vocab = {k: i for i, k in enumerate(vocab)}
    # NOTE(review): `mask` is not used anywhere in the visible part of this
    # block. Confirm whether it was meant to be passed to CRF.
    mask = transition_mask(vocab, "IOBES", 0, 1)
    # NOTE(review): vocab has 11 labels but the CRF is built with 10. Confirm
    # this is intended.
    crf = CRF(10, (0, 1), batch_first=False)
    trans = crf.transitions

    # Share the CRF's transitions (leading dim squeezed away) with the inference model.
    icrf = InferenceCRF(torch.nn.Parameter(trans.squeeze(0)), 0, 1, False)

    u = torch.rand(20, 1, 10)
    l = torch.full((1, ), 20, dtype=torch.long)
    # Print both decodes so they can be compared by eye.
    print(crf.decode(u, l))
    print(icrf.decode(u, l))

    # NOTE(review): `l` still holds 20 while `u` now has size 15 along dim 0.
    # Confirm this mismatch is intended for the trace.
    u = torch.rand(15, 1, 10)
    traced_model = torch.jit.trace(icrf.decode, (u, l))
    # Save the traced model to disk and load it back.
    traced_model.save('crf.pt')
    traced_model = torch.jit.load('crf.pt')

    u = torch.rand(8, 1, 10)
示例#21
0
        best_tag_id = backpointers[i][best_tag_id]
        best_path.append(best_tag_id)

    new_path = []
    for i in range(len(best_path)):
        i = len(best_path) - i - 1
        new_path.append(best_path[i])
    return torch.stack(new_path[1:]), path_score


# Ad-hoc script: decode the same random input with the training CRF and with
# InferenceCRF, then trace InferenceCRF.decode, save it and load it back.
if __name__ == '__main__':
    from baseline.pytorch.crf import CRF, transition_mask
    vocab = ["<GO>", "<EOS>", "B-X", "I-X", "E-X", "S-X", "O", "B-Y", "I-Y", "E-Y", "S-Y"]
    # Map each label to its position in the list.
    vocab = {k: i for i, k in enumerate(vocab)}
    # NOTE(review): `mask` is not used anywhere in the visible part of this
    # block. Confirm whether it was meant to be passed to CRF.
    mask = transition_mask(vocab, "IOBES", 0, 1)
    # NOTE(review): vocab has 11 labels but the CRF is built with 10. Confirm
    # this is intended.
    crf = CRF(10, (0, 1), batch_first=False)
    trans = crf.transitions

    # Share the CRF's transitions (leading dim squeezed away) with the inference model.
    icrf = InferenceCRF(torch.nn.Parameter(trans.squeeze(0)), 0, 1, False)

    u = torch.rand(20, 1, 10)
    l = torch.full((1,), 20, dtype=torch.long)
    # Print both decodes so they can be compared by eye.
    print(crf.decode(u, l))
    print(icrf.decode(u, l))

    # NOTE(review): `l` still holds 20 while `u` now has size 15 along dim 0.
    # Confirm this mismatch is intended for the trace.
    u = torch.rand(15, 1, 10)
    traced_model = torch.jit.trace(icrf.decode, (u, l))
    # Save the traced model to disk and load it back.
    traced_model.save('crf.pt')
    traced_model = torch.jit.load('crf.pt')

    u = torch.rand(8, 1, 10)
示例#22
0
def test_forward_shape(generate_batch):
    """Check ``CRF.forward`` returns one value per index along dim 1 of ``unary``."""
    unary, _, lengths = generate_batch
    model = CRF(unary.size(2))
    result = model.forward(unary, lengths, lengths.size(0))
    expected_shape = torch.Size([unary.size(1)])
    assert result.shape == expected_shape
示例#23
0
def test_score_sentence_shape(generate_batch):
    """Check ``CRF.score_sentence`` returns one value per index along dim 1 of ``unary``."""
    unary, tags, lengths = generate_batch
    model = CRF(unary.size(2))
    result = model.score_sentence(unary, tags, lengths, lengths.size(0))
    expected_shape = torch.Size([unary.size(1)])
    assert result.shape == expected_shape
示例#24
0
def test_neg_log_loss_shape(generate_batch):
    """Check ``CRF.neg_log_loss`` returns one value per index along dim 1 of ``unary``."""
    unary, tags, lengths = generate_batch
    model = CRF(unary.size(2))
    loss = model.neg_log_loss(unary, tags, lengths)
    expected_shape = torch.Size([unary.size(1)])
    assert loss.shape == expected_shape
示例#25
0
def test_neg_log_loss_shape(generate_batch):
    """Verify the shape of the loss returned by ``CRF.neg_log_loss``.

    The result must be 1-D with length ``unary.size(1)``.
    """
    unary, tags, lengths = generate_batch
    n_labels = unary.size(2)
    crf = CRF(n_labels)
    n_items = unary.size(1)
    assert crf.neg_log_loss(unary, tags, lengths).shape == torch.Size([n_items])
示例#26
0
def test_score_sentence_shape(generate_batch):
    """Verify the shape of the scores returned by ``CRF.score_sentence``.

    The result must be 1-D with length ``unary.size(1)``.
    """
    unary, tags, lengths = generate_batch
    n_labels = unary.size(2)
    crf = CRF(n_labels)
    batch_arg = lengths.size(0)
    scores = crf.score_sentence(unary, tags, lengths, batch_arg)
    n_items = unary.size(1)
    assert scores.shape == torch.Size([n_items])
示例#27
0
def test_mask_not_applied():
    """Check that a CRF built without a constraint has no ``-1e4`` entry.

    One randomly chosen entry in the ``Offsets.GO`` row of ``crf.transitions``
    is inspected.
    """
    n_labels = np.random.randint(22, 41)
    model = CRF(n_labels)
    transitions = model.transitions.detach().numpy()
    # Pick the column only after the CRF is built, as the original expression did.
    column = np.random.randint(n_labels)
    assert transitions[0, Offsets.GO, column] != -1e4
示例#28
0
def test_forward_shape(generate_batch):
    """Verify the shape of the value returned by ``CRF.forward``.

    The result must be 1-D with length ``unary.size(1)``.
    """
    unary, _, lengths = generate_batch
    n_labels = unary.size(2)
    crf = CRF(n_labels)
    batch_arg = lengths.size(0)
    n_items = unary.size(1)
    assert crf.forward(unary, lengths, batch_arg).shape == torch.Size([n_items])