Example #1
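All of the snippets below appear to be excerpts from a single test module, so the shared imports and helpers are not shown. A plausible preamble is sketched here; the TextField import path and the recursive_tensor_to_list helper are assumptions inferred from how they are used (the API matches flambe's flambe.field.TextField).

import tempfile
from collections import OrderedDict as odict

import numpy as np
import torch
from gensim.models import KeyedVectors

from flambe.field import TextField  # assumed import path


def recursive_tensor_to_list(data):
    # assumed test helper: turn a (possibly nested) tensor structure into plain Python lists
    if torch.is_tensor(data):
        return data.tolist()
    return [recursive_tensor_to_list(item) for item in data]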
def test_setup_with_extra_tokens():
    field = TextField.from_embeddings(
        embeddings="tests/data/dummy_embeddings/test.txt",
        pad_token=None,
        unk_init_all=False,
        additional_special_tokens=['<a>', '<b>', '<c>'])

    dummy = "this is a test"
    field.setup([dummy])
    assert recursive_tensor_to_list(field.process(dummy)) == [4, 5, 6, 7]

    dummy = "this is a test <a> <c>"
    assert recursive_tensor_to_list(field.process(dummy)) == [4, 5, 6, 7, 1, 3]
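For context, the vocabulary that setup builds here can be read off the two assertions above; only the index-0 slot (presumably '<unk>', since pad_token is None) and '<b>' at index 2 (following the declared order of additional_special_tokens) are inferences:

# inferred vocabulary for Example #1
expected_vocab = {
    '<unk>': 0, '<a>': 1, '<b>': 2, '<c>': 3,
    'this': 4, 'is': 5, 'a': 6, 'test': 7,
}
# so "this is a test" maps to [4, 5, 6, 7] and the trailing "<a> <c>" to [1, 3]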
Example #2
def test_load_embeddings_with_extra_tokens():
    field = TextField.from_embeddings(
        embeddings="tests/data/dummy_embeddings/test.txt",
        pad_token=None,
        unk_init_all=False,
        additional_special_tokens=['<a>', '<b>', '<c>'])
    dummy = "a test ! <a> <b> "
    field.setup([dummy])
    assert '<a>' in field.vocab and '<b>' in field.vocab and '<c>' in field.vocab
    # the special tokens get embedding rows with the same dimensionality (4) as the pretrained vectors
    assert field.embedding_matrix[field.vocab['<a>']].size(-1) == 4
    assert field.embedding_matrix[field.vocab['<b>']].size(-1) == 4
    # and '<b>' and '<c>' do not share a vector
    assert all(field.embedding_matrix[field.vocab['<b>']] !=
               field.embedding_matrix[field.vocab['<c>']])
Example #3
def test_load_embeddings():
    field = TextField.from_embeddings(
        embeddings="tests/data/dummy_embeddings/test.txt",
        pad_token=None,
        unk_init_all=False,
    )
    dummy = "a test !"
    field.setup([dummy])

    # Now we have embeddings to check against
    true_embeddings = torch.tensor([[0.9, 0.1, 0.2, 0.3], [0.4, 0.5, 0.6, 0.7]])
    assert len(field.embedding_matrix) == 3
    assert torch.all(torch.eq(field.embedding_matrix[1:3], true_embeddings))
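The assertions above pin down part of the fixture: vectors are 4-dimensional, and 'a' and 'test' are the only tokens in "a test !" with pretrained rows. The file tests/data/dummy_embeddings/test.txt therefore plausibly contains (at least) GloVe-style lines such as the following; this is an inference, not the actual fixture contents:

a 0.9 0.1 0.2 0.3
test 0.4 0.5 0.6 0.7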
Example #4
def test_load_embeddings_empty_voc():
    field = TextField.from_embeddings(
        embeddings="tests/data/dummy_embeddings/test.txt",
        pad_token=None,
        unk_init_all=True,
    )

    dummy = "justo Praesent luctus justo praesent"
    field.setup([dummy])

    # None of these tokens appear in the embeddings file; with unk_init_all=True
    # each unique token still gets its own row (4 unique tokens plus <unk> = 5 rows)
    assert len(field.embedding_matrix) == 5

    field = TextField.from_embeddings(
        embeddings="tests/data/dummy_embeddings/test.txt",
        pad_token=None,
        unk_init_all=False,
    )

    dummy = "justo Praesent luctus justo praesent"
    field.setup([dummy])

    # With unk_init_all=False, tokens without pretrained vectors all collapse onto <unk>,
    # leaving a single row
    assert len(field.embedding_matrix) == 1
Example #5
def test_build_vocab_build_vocab_from_embeddings():
    """
    This test shows that all tokens in the embeddings will be included.

    In embeddings and data:
        blue
        green
        yellow
    In embeddings only:
        purple
        gold
    In data only:
        white

    Expected vocab:
        blue
        green
        yellow
        purple
        gold
        white (ends up sharing the <unk> index, as it has no pretrained vector)
    """

    model = KeyedVectors(10)
    model.add('purple', np.random.rand(10))
    model.add('gold', np.random.rand(10))
    model.add('<unk>', np.random.rand(10))
    model.add('blue', np.random.rand(10))
    model.add('green', np.random.rand(10))
    model.add('<pad>', np.random.rand(10))
    model.add('yellow', np.random.rand(10))

    with tempfile.NamedTemporaryFile() as tmpfile:
        model.save(tmpfile.name)

        field = TextField.from_embeddings(
            embeddings=tmpfile.name,
            embeddings_format='gensim',
            build_vocab_from_embeddings=True,
        )

        dummy = ["blue green", "yellow", 'white']

        field.setup(dummy)

    # assert vocab setup in expected order
    assert field.vocab == odict([
        ('<pad>', 0),
        ('<unk>', 1),
        ('blue', 2),
        ('green', 3),
        ('yellow', 4),
        ('white', 1),
        ('purple', 5),
        ('gold', 6),
    ])

    # assert embedding matrix organized in expected order
    assert torch.equal(
        field.embedding_matrix,
        torch.stack([
            torch.tensor(model['<pad>']),
            torch.tensor(model['<unk>']),
            torch.tensor(model['blue']),
            torch.tensor(model['green']),
            torch.tensor(model['yellow']),
            torch.tensor(model['purple']),
            torch.tensor(model['gold'])
        ]),
    )
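Taken together, these tests exercise one small workflow: build a field from pretrained embeddings, set it up on a corpus, then map text to indices that line up with rows of embedding_matrix. A minimal usage sketch under those assumptions (the corpus and the final lookup step are illustrative, not taken from the tests):

field = TextField.from_embeddings(
    embeddings="tests/data/dummy_embeddings/test.txt",  # path reused from the tests above
    pad_token=None,
    unk_init_all=False,
)
field.setup(["a test !"])                   # builds field.vocab and field.embedding_matrix
indices = field.process("a test")           # tensor of vocab indices for the tokens
vectors = field.embedding_matrix[indices]   # embedding rows aligned with those indices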