コード例 #1
0
def test_text_process_lower():
    """Lower-casing maps differently-cased spellings onto one vocab index."""
    sentence = "justo Praesent luctus justo praesent"
    lowered_field = TextField(lower=True)

    # Before any data is seen, every word resolves to index 1
    # (presumably the default unknown-token slot).
    before = list(lowered_field.process(sentence))
    assert before == [1] * 5

    lowered_field.setup([sentence])
    # "Praesent" and "praesent" collapse onto the same index after lowering.
    after = list(lowered_field.process(sentence))
    assert after == [2, 3, 4, 2, 3]
コード例 #2
0
def test_build_vocab_lower():
    """With ``lower=True`` and no special tokens, vocab keys are lower-cased."""
    field = TextField(lower=True, pad_token=None, unk_token=None)
    field.setup(["justo Praesent luctus", "luctus praesent"])

    expected = dict(justo=0, praesent=1, luctus=2)
    assert field.vocab == expected
コード例 #3
0
def test_build_vocab_empty():
    """Without special tokens the vocab starts empty and is case-sensitive."""
    field = TextField(pad_token=None, unk_token=None)
    assert field.vocab == {}

    corpus = ["justo Praesent luctus", "luctus praesent"]
    field.setup(corpus)

    # Tokens are indexed in first-seen order; case variants stay distinct.
    tokens = ['justo', 'Praesent', 'luctus', 'praesent']
    assert field.vocab == {tok: idx for idx, tok in enumerate(tokens)}
コード例 #4
0
def test_build_vocab_setup_all_embeddings():
    """
    With ``setup_all_embeddings=True`` every word in the embeddings model is
    added to the vocab, not only the words that appear in the data.

    In embeddings and data:
        blue, green, yellow
    In embeddings only:
        purple, gold
    In data only:
        white

    Expected vocab entries (after ``<pad>`` and ``<unk>``), in order:
        blue, green, yellow, white, purple, gold

    ``white`` has no embedding, so it is mapped to the ``<unk>`` index and
    does not get a row of its own in the embedding matrix.
    """
    dim = 10
    model = KeyedVectors(dim)
    # Insertion order (and thus random-draw order) matches the original setup.
    for word in ['purple', 'gold', '<unk>', 'blue', 'green', '<pad>', 'yellow']:
        model.add(word, np.random.rand(dim))

    field = TextField(model=model, setup_all_embeddings=True)
    field.setup(["blue green", "yellow", 'white'])

    # Compared against an OrderedDict, so entry order is checked whenever
    # field.vocab is itself an OrderedDict.
    expected_vocab = odict([
        ('<pad>', 0), ('<unk>', 1),
        ('blue', 2), ('green', 3), ('yellow', 4),
        ('white', 1),
        ('purple', 5), ('gold', 6),
    ])
    assert field.vocab == expected_vocab

    # One row per distinct index, in index order; 'white' shares '<unk>'.
    row_order = ['<pad>', '<unk>', 'blue', 'green', 'yellow', 'purple', 'gold']
    expected_matrix = torch.stack([torch.tensor(model[w]) for w in row_order])
    assert torch.equal(field.embedding_matrix, expected_matrix)
コード例 #5
0
def test_load_embeddings():
    """Vectors for words present in the embeddings file end up in the matrix."""
    field = TextField(
        pad_token=None,
        unk_init_all=False,
        embeddings="tests/data/dummy_embeddings/test.txt",
    )
    field.setup(["a test !"])

    # Presumably the vectors stored in the dummy file for two of the tokens;
    # they are expected at rows 1 and 2.
    expected_rows = torch.tensor([
        [0.9, 0.1, 0.2, 0.3],
        [0.4, 0.5, 0.6, 0.7],
    ])
    assert len(field.embedding_matrix) == 3
    assert torch.all(torch.eq(field.embedding_matrix[1:3], expected_rows))
コード例 #6
0
def test_text_process_list():
    """Nested lists of strings are processed element-wise, keeping structure."""
    field = TextField(lower=True)
    field.setup()
    batch = [
        ["justo Praesent luctus", "luctus praesent"],
        ["justo Praesent luctus", "luctus praesent est"],
    ]

    # Before the data is seen, every token maps to index 1.
    unseen = [[[1, 1, 1], [1, 1]], [[1, 1, 1], [1, 1, 1]]]
    assert recursive_tensor_to_list(field.process(batch)) == unseen

    field.setup(batch)
    seen = [[[2, 3, 4], [4, 3]], [[2, 3, 4], [4, 3, 5]]]
    assert recursive_tensor_to_list(field.process(batch)) == seen
コード例 #7
0
def test_build_vocab():
    """Special tokens come first, then data tokens in first-seen order."""
    field = TextField(pad_token='<pad>', unk_token='<unk>')
    specials = {'<pad>': 0, '<unk>': 1}
    assert field.vocab == specials

    field.setup(["justo Praesent luctus", "luctus praesent"])

    words = ['justo', 'Praesent', 'luctus', 'praesent']
    expected = dict(specials)
    expected.update((w, i) for i, w in enumerate(words, start=len(specials)))
    assert field.vocab == expected
コード例 #8
0
def test_text_process_nested_list_in_dict():
    """A list of dicts whose values are lists of strings keeps its shape."""
    field = TextField(lower=True)
    field.setup()
    dummy = [{
        'text1': ["justo Praesent luctus", "luctus praesent"],
        'text2': ["justo Praesent luctus", "luctus praesent est"],
    }]

    def processed():
        return recursive_tensor_to_list(field.process(dummy))

    # No data seen yet: every token maps to index 1.
    assert processed() == [{'text1': [[1] * 3, [1] * 2],
                            'text2': [[1] * 3, [1] * 3]}]

    field.setup(dummy)
    assert processed() == [{'text1': [[2, 3, 4], [4, 3]],
                            'text2': [[2, 3, 4], [4, 3, 5]]}]
コード例 #9
0
def test_text_process_dict():
    """A dict of strings is processed value by value, keeping its keys."""
    field = TextField(lower=True)
    field.setup()
    sample = {
        'text1': "justo Praesent luctus luctus praesent",
        'text2': "justo Praesent luctus luctus praesent est",
    }

    # Before the data is seen, every token maps to index 1.
    unseen = recursive_tensor_to_list(field.process(sample))
    assert unseen == {'text1': [1] * 5, 'text2': [1] * 6}

    field.setup([sample])
    seen = recursive_tensor_to_list(field.process(sample))
    assert seen == {'text1': [2, 3, 4, 4, 3], 'text2': [2, 3, 4, 4, 3, 5]}
コード例 #10
0
def test_load_embeddings_empty_voc():
    """Check the matrix size when no data token has a vector in the file.

    With ``unk_init_all=True`` each distinct data token still gets a row:
    presumably one row for the unknown token plus the 4 case-sensitive
    distinct words, for 5 in total. With ``unk_init_all=False`` only a
    single row remains.
    """
    sentence = "justo Praesent luctus justo praesent"
    for init_all, expected_rows in ((True, 5), (False, 1)):
        field = TextField(pad_token=None,
                          unk_init_all=init_all,
                          embeddings="tests/data/dummy_embeddings/test.txt")
        field.setup([sentence])
        assert len(field.embedding_matrix) == expected_rows
コード例 #11
0
def test_build_vocab_decorators():
    """``<sos>``/``<eos>`` follow any pad/unk tokens and precede data tokens."""
    words = ['justo', 'Praesent', 'luctus', 'praesent']
    scenarios = [
        (dict(pad_token=None, unk_token=None),
         ['<sos>', '<eos>']),
        (dict(pad_token='<pad>', unk_token='<unk>'),
         ['<pad>', '<unk>', '<sos>', '<eos>']),
    ]

    for special_kwargs, specials in scenarios:
        field = TextField(sos_token='<sos>', eos_token='<eos>',
                          **special_kwargs)
        assert field.vocab == {tok: i for i, tok in enumerate(specials)}

        field.setup(["justo Praesent luctus", "luctus praesent"])
        full_order = specials + words
        assert field.vocab == {tok: i for i, tok in enumerate(full_order)}