# Example No. 1
# 0
def test_rnn_module(caplog):
    """Unit test of the RNN module.

    Checks output shapes for two configurations, each run both without and
    with an input mask:

    * a unidirectional RNN with attention and ``num_classes=n_class``,
      expected to return ``(batch_size, n_class)``;
    * a bidirectional RNN without attention and ``num_classes=0``,
      expected to return ``(batch_size, 2 * lstm_hidden)``.
    """
    caplog.set_level(logging.INFO)

    n_class = 2
    emb_size = 10
    lstm_hidden = 20
    batch_size = 3
    seq_len = 4

    # pad_batch iterates over the rows of this (batch_size, seq_len) tensor.
    # All rows have the same length, so the mask marks no padded positions.
    # It depends only on lengths, so one mask serves both configurations.
    _, input_mask = pad_batch(torch.randn(batch_size, seq_len))

    # Single-direction RNN with attention.
    rnn = RNN(
        num_classes=n_class,
        emb_size=emb_size,
        lstm_hidden=lstm_hidden,
        attention=True,
        dropout=0.2,
        bidirectional=False,
    )
    expected = (batch_size, n_class)
    assert rnn(torch.randn(batch_size, seq_len, emb_size)).size() == expected
    assert (rnn(torch.randn(batch_size, seq_len, emb_size),
                input_mask).size() == expected)

    # Bidirectional RNN without attention. With num_classes=0 the test
    # expects the output width to be 2 * lstm_hidden.
    rnn = RNN(
        num_classes=0,
        emb_size=emb_size,
        lstm_hidden=lstm_hidden,
        attention=False,
        dropout=0.2,
        bidirectional=True,
    )
    expected = (batch_size, 2 * lstm_hidden)
    assert rnn(torch.randn(batch_size, seq_len, emb_size)).size() == expected
    assert (rnn(torch.randn(batch_size, seq_len, emb_size),
                input_mask).size() == expected)
# Example No. 2
# 0
def test_pad_batch(caplog):
    """Unit test of ``pad_batch``.

    Covers these cases:

    * default right-padding;
    * truncation via ``max_len``;
    * a custom ``pad_value``;
    * left-padding, with and without ``max_len``.

    Each case checks both the padded tensor and the mask. In the expected
    masks, 1 marks a padded position and 0 marks a real element.
    """
    caplog.set_level(logging.INFO)

    sequences = [
        torch.Tensor([1, 2]),
        torch.Tensor([3]),
        torch.Tensor([4, 5, 6]),
    ]

    # Each entry: (keyword arguments for pad_batch,
    #              expected padded rows, expected mask rows).
    cases = [
        # Default: right-pad with 0 up to the longest sequence.
        (
            {},
            [[1, 2, 0], [3, 0, 0], [4, 5, 6]],
            [[0, 0, 1], [0, 1, 1], [0, 0, 0]],
        ),
        # max_len=2 truncates the longest sequence, keeping its leading
        # elements.
        (
            {"max_len": 2},
            [[1, 2], [3, 0], [4, 5]],
            [[0, 0], [0, 1], [0, 0]],
        ),
        # A custom pad value fills the padded positions.
        (
            {"pad_value": -1},
            [[1, 2, -1], [3, -1, -1], [4, 5, 6]],
            [[0, 0, 1], [0, 1, 1], [0, 0, 0]],
        ),
        # Left-padding places the padding before the real elements.
        (
            {"left_padded": True},
            [[0, 1, 2], [0, 0, 3], [4, 5, 6]],
            [[1, 0, 0], [1, 1, 0], [0, 0, 0]],
        ),
        # Left-padding plus max_len keeps the trailing elements.
        (
            {"max_len": 2, "left_padded": True},
            [[1, 2], [0, 3], [5, 6]],
            [[0, 0], [1, 0], [0, 0]],
        ),
    ]

    for kwargs, want_padded, want_mask in cases:
        padded, mask = pad_batch(sequences, **kwargs)
        assert torch.equal(padded, torch.Tensor(want_padded))
        # Build the expected mask with the mask's own dtype/device.
        assert torch.equal(mask, mask.new_tensor(want_mask))