def test_check_equal():
    """Sanity-check the tensor-equality helper.

    Two independently constructed tensors holding the same 2x2 values must
    be accepted by `utils.assert_equal`.
    """
    values = [[1, 2], [3, 4]]
    utils.assert_equal(torch.tensor(values), torch.tensor(values))
def test_target(dataset):
    """Check that `dataset.target` holds the expected label-index tensors."""
    expected_targets = [
        [0, 0],
        [0, 0, 0, 1],
        [0, 0, 1, 2],
        [0],
    ]
    assert len(dataset.target) == 4
    # Compare entry by entry, in order, so the first mismatch is reported.
    for position, labels in enumerate(expected_targets):
        assert_equal(dataset.target[position], torch.tensor(labels))
def test_source(dataset):
    """Check the structure and contents of `dataset.source`.

    Only the first entry is inspected in detail. It comes from the sentence
    [("hi", "O"), ("there", "O")], and the expected values below list the
    word lengths longest-first: "there" (5 chars) before "hi" (2 chars).
    """
    assert len(dataset.source) == 4

    first_entry = dataset.source[0]
    words, word_lens, idxs, word_idxs, context = first_entry

    # The first four components must be tensors; the context slot is empty.
    for component in (words, word_lens, idxs, word_idxs):
        assert isinstance(component, torch.Tensor)
    assert context is None

    # 2 words; second dimension is 5, the length of "there".
    assert list(words.size()) == [2, 5]
    assert_equal(word_lens, torch.tensor([5, 2]))
    # "there" (original position 1) is listed ahead of "hi" (position 0).
    assert_equal(idxs, torch.tensor([1, 0]))
    assert list(word_idxs.size()) == [2]
def test_sent2tensor(vocab, sent):
    """Check that `Vocab.sent2tensor` returns correctly shaped, ordered output.

    The expected values are recomputed here from `sent`:
    - words are ordered longest-first; ties keep their original order,
      because `sorted` is stable even with ``reverse=True``;
    - each row of character ids is right-padded with 0 up to the longest
      word's length.
    """
    char_tensors, word_lengths, word_idxs, word_tensors, context = \
        vocab.sent2tensor(sent)

    lengths = [len(word) for word in sent]
    sorted_lengths = sorted(lengths, reverse=True)
    order = sorted(range(len(sent)), key=lambda i: lengths[i], reverse=True)

    # Shapes: one row per word, padded to the longest word.
    assert isinstance(char_tensors, torch.Tensor)
    longest = max(lengths)
    assert list(char_tensors.size()) == [len(sent), longest]
    assert isinstance(word_tensors, torch.Tensor)
    assert list(word_tensors.size()) == [len(sent)]

    # Lengths and permutation indices follow the longest-first order.
    assert_equal(word_lengths, torch.tensor(sorted_lengths))
    assert_equal(word_idxs, torch.tensor(order))

    # Each character row must match the corresponding (reordered) word.
    # The size check above guarantees as many rows as entries in `order`.
    for row, source_pos in zip(char_tensors, order):
        word = sent[source_pos]
        padding = [0] * (longest - len(word))
        expected_row = torch.tensor(
            [vocab.chars_stoi[ch] for ch in word] + padding)
        assert_equal(row, expected_row)
def test_labs2tensor(vocab, labs, is_test, check):
    """Check that `Vocab.labs2tensor` converts `labs` into the tensor `check`.

    `is_test` is forwarded as the `test` keyword argument.
    """
    assert_equal(vocab.labs2tensor(labs, test=is_test), check)
def test_unsort(inputs, check):
    """Check `utils.unsort` against the expected result `check`.

    `inputs` is unpacked as positional arguments to `utils.unsort`.
    """
    result = utils.unsort(*inputs)
    utils.assert_equal(result, check)
def test_sort_and_pad(inputs, check):
    """Check the three outputs of `utils.sort_and_pad`.

    `check` supplies, in order, the expected padded tensor, lengths and
    index permutation.
    """
    padded, lens, idx = utils.sort_and_pad(*inputs)
    # Index `check` explicitly (rather than zipping) so a short `check`
    # still raises instead of silently skipping comparisons.
    for position, actual in enumerate((padded, lens, idx)):
        utils.assert_equal(actual, check[position])
def test_pad(inputs, check):
    """Check `utils.pad` against the expected padded result `check`.

    `inputs` is unpacked as positional arguments to `utils.pad`.
    """
    padded = utils.pad(*inputs)
    utils.assert_equal(padded, check)
def test_sequence_mask(inp, chk, max_len):
    """Check `utils.sequence_mask` for input `inp` and the given `max_len`.

    The result must equal the expected mask `chk`.
    """
    mask = utils.sequence_mask(inp, max_len=max_len)
    utils.assert_equal(mask, chk)