def test_rnn_module(caplog):
    """Unit test of RNN Module.

    Checks the output shapes of ``RNN`` in two configurations:

    * unidirectional, with attention and ``num_classes=n_class``: the output is
      ``(batch_size, n_class)``;
    * bidirectional, without attention and ``num_classes=0``: the output is
      ``(batch_size, 2 * lstm_hidden)``.

    Each configuration is called both without and with an input mask.
    """
    caplog.set_level(logging.INFO)

    n_class = 2
    emb_size = 10
    lstm_hidden = 20
    batch_size = 3
    seq_len = 4

    # Single direction RNN
    rnn = RNN(
        num_classes=n_class,
        emb_size=emb_size,
        lstm_hidden=lstm_hidden,
        attention=True,
        dropout=0.2,
        bidirectional=False,
    )
    # NOTE(review): a dense (batch_size, seq_len) tensor is passed to pad_batch,
    # so every row presumably has full length and the mask marks no padding.
    # Confirm whether variable-length inputs should also be covered.
    _, input_mask = pad_batch(torch.randn(batch_size, seq_len))
    # Use batch_size rather than a literal so the expected shape follows the
    # configured batch size.
    assert rnn(torch.randn(batch_size, seq_len, emb_size)).size() == (
        batch_size,
        n_class,
    )
    assert rnn(torch.randn(batch_size, seq_len, emb_size), input_mask).size() == (
        batch_size,
        n_class,
    )

    # Bi-direction RNN
    rnn = RNN(
        num_classes=0,
        emb_size=emb_size,
        lstm_hidden=lstm_hidden,
        attention=False,
        dropout=0.2,
        bidirectional=True,
    )
    _, input_mask = pad_batch(torch.randn(batch_size, seq_len))
    # With num_classes=0 the assertions below expect the concatenated forward
    # and backward hidden size, i.e. 2 * lstm_hidden.
    assert rnn(torch.randn(batch_size, seq_len, emb_size)).size() == (
        batch_size,
        2 * lstm_hidden,
    )
    assert rnn(torch.randn(batch_size, seq_len, emb_size), input_mask).size() == (
        batch_size,
        2 * lstm_hidden,
    )
def test_pad_batch(caplog):
    """Unit test of pad batch.

    Pads a list of three 1-D tensors of lengths 2, 1 and 3. Five option
    combinations are checked: default, ``max_len``, ``pad_value``,
    ``left_padded``, and ``max_len`` with ``left_padded``. For each one, the
    padded tensor and the mask (1 at padded positions, 0 at real data) are
    compared with hand-written expectations.
    """
    caplog.set_level(logging.INFO)

    sequences = [torch.Tensor([1, 2]), torch.Tensor([3]), torch.Tensor([4, 5, 6])]

    # Each case is (pad_batch kwargs, expected padded values, expected mask).
    cases = [
        (
            {},
            [[1, 2, 0], [3, 0, 0], [4, 5, 6]],
            [[0, 0, 1], [0, 1, 1], [0, 0, 0]],
        ),
        (
            {"max_len": 2},
            [[1, 2], [3, 0], [4, 5]],
            [[0, 0], [0, 1], [0, 0]],
        ),
        (
            {"pad_value": -1},
            [[1, 2, -1], [3, -1, -1], [4, 5, 6]],
            [[0, 0, 1], [0, 1, 1], [0, 0, 0]],
        ),
        (
            {"left_padded": True},
            [[0, 1, 2], [0, 0, 3], [4, 5, 6]],
            [[1, 0, 0], [1, 1, 0], [0, 0, 0]],
        ),
        (
            {"max_len": 2, "left_padded": True},
            [[1, 2], [0, 3], [5, 6]],
            [[0, 0], [1, 0], [0, 0]],
        ),
    ]

    for options, want_padded, want_mask in cases:
        got_padded, got_mask = pad_batch(sequences, **options)
        assert torch.equal(got_padded, torch.Tensor(want_padded))
        # Build the expected mask with new_tensor so its dtype and device match
        # whatever pad_batch returned.
        assert torch.equal(got_mask, got_mask.new_tensor(want_mask))