Example #1
0
def test_check_equal():
    """
    Make sure our helper function for checking tensor equality actually works.

    Checks both directions: identical tensors must be accepted, and tensors
    that differ in a single element must be rejected. Without the negative
    case, a helper that never raises would also pass this test.
    """
    a = torch.tensor([[1, 2], [3, 4]])
    b = torch.tensor([[1, 2], [3, 4]])
    utils.assert_equal(a, b)

    # NOTE(review): assumes `utils.assert_equal` signals a mismatch by raising
    # AssertionError (e.g. via a plain `assert`) -- confirm against utils.
    c = torch.tensor([[1, 2], [3, 5]])
    try:
        utils.assert_equal(a, c)
    except AssertionError:
        pass
    else:
        # Explicit raise rather than `assert False`, which `-O` would strip.
        raise AssertionError("utils.assert_equal accepted unequal tensors")
Example #2
0
def test_target(dataset):
    """Check that each entry of `dataset.target` holds the expected labels."""
    expected_targets = [
        [0, 0],
        [0, 0, 0, 1],
        [0, 0, 1, 2],
        [0],
    ]
    assert len(dataset.target) == 4
    for position, labels in enumerate(expected_targets):
        assert_equal(dataset.target[position], torch.tensor(labels))
Example #3
0
def test_source(dataset):
    """Check the contents of `dataset.source`.

    Only the first item is inspected in detail. It is built from the sentence
    [("hi", "O"), ("there", "O")], so it should describe two words whose
    lengths, longest first, are 5 and 2.
    """
    assert len(dataset.source) == 4

    first_item = dataset.source[0]
    words, word_lens, idxs, word_idxs, context = first_item

    # Every component except the context must be a tensor.
    for component in (words, word_lens, idxs, word_idxs):
        assert isinstance(component, torch.Tensor)
    assert context is None

    assert list(words.size()) == [2, 5]
    assert_equal(word_lens, torch.tensor([5, 2]))
    assert_equal(idxs, torch.tensor([1, 0]))
    assert list(word_idxs.size()) == [2]
Example #4
0
def test_sent2tensor(vocab, sent):
    """Check the output format of `Vocab.sent2tensor`.

    The per-word outputs are expected in longest-first order (ties keep
    their original order, matching Python's stable sort). Each word's
    character ids are right-padded with 0 up to the longest word's length.
    """
    (char_tensors, word_lengths, word_idxs,
     word_tensors, context) = vocab.sent2tensor(sent)

    # Build the expected ordering independently: word lengths sorted in
    # descending order, plus the original word positions in that order.
    lengths = [len(word) for word in sent]
    sorted_lengths = sorted(lengths, reverse=True)
    positions = list(range(len(sent)))
    positions.sort(key=lambda pos: lengths[pos], reverse=True)

    # Shapes: one padded row of characters per word, one entry per word.
    assert isinstance(char_tensors, torch.Tensor)
    longest = max(lengths)
    assert list(char_tensors.size()) == [len(sent), longest]
    assert isinstance(word_tensors, torch.Tensor)
    assert list(word_tensors.size()) == [len(sent)]

    # Returned lengths and positions must follow the same ordering.
    assert_equal(word_lengths, torch.tensor(sorted_lengths))
    assert_equal(word_idxs, torch.tensor(positions))

    # Each character row must match its word, padded with 0. The size check
    # above guarantees exactly len(sent) rows, so zip does not truncate.
    for row, pos in zip(char_tensors, positions):
        word = sent[pos]
        char_ids = [vocab.chars_stoi[ch] for ch in word]
        padding = [0] * (longest - len(word))
        assert_equal(row, torch.tensor(char_ids + padding))
Example #5
0
def test_labs2tensor(vocab, labs, is_test, check):
    """Check that `Vocab.labs2tensor` converts `labs` into the tensor `check`."""
    assert_equal(vocab.labs2tensor(labs, test=is_test), check)
Example #6
0
def test_unsort(inputs, check):
    """Check `utils.unsort` applied to the unpacked `inputs` against `check`."""
    restored = utils.unsort(*inputs)
    utils.assert_equal(restored, check)
Example #7
0
def test_sort_and_pad(inputs, check):
    """Check each of the three outputs of `utils.sort_and_pad`.

    `check[0]`, `check[1]` and `check[2]` hold the expected padded tensor,
    lengths and indices respectively.
    """
    padded, lens, idx = utils.sort_and_pad(*inputs)
    for position, actual in enumerate((padded, lens, idx)):
        utils.assert_equal(actual, check[position])
Example #8
0
def test_pad(inputs, check):
    """Check that `utils.pad` applied to the unpacked `inputs` yields `check`."""
    padded = utils.pad(*inputs)
    utils.assert_equal(padded, check)
Example #9
0
def test_sequence_mask(inp, chk, max_len):
    """Check `utils.sequence_mask()` output against the expected mask `chk`."""
    utils.assert_equal(utils.sequence_mask(inp, max_len=max_len), chk)