def test_setup_with_extra_tokens():
    """Additional special tokens get their own indices and resolve at process time."""
    field = TextField.from_embeddings(
        embeddings="tests/data/dummy_embeddings/test.txt",
        pad_token=None,
        unk_init_all=False,
        additional_special_tokens=['<a>', '<b>', '<c>'],
    )
    plain_text = "this is a test"
    field.setup([plain_text])
    # The four regular tokens occupy indices 4-7 (after the special tokens).
    assert recursive_tensor_to_list(field.process(plain_text)) == [4, 5, 6, 7]
    # Special tokens <a> and <c> map to their reserved indices (1 and 3).
    text_with_specials = "this is a test <a> <c>"
    assert recursive_tensor_to_list(field.process(text_with_specials)) == [4, 5, 6, 7, 1, 3]
def test_load_embeddings_with_extra_tokens():
    """Additional special tokens appear in the vocab and receive embedding rows."""
    field = TextField.from_embeddings(
        embeddings="tests/data/dummy_embeddings/test.txt",
        pad_token=None,
        unk_init_all=False,
        additional_special_tokens=['<a>', '<b>', '<c>'],
    )
    field.setup(["a test ! <a> <b> "])
    # All three extra tokens are registered, even <c>, which never occurs in the data.
    for token in ('<a>', '<b>', '<c>'):
        assert token in field.vocab
    # Each special token gets an embedding row of the same width as the file (4 dims).
    assert field.embedding_matrix[field.vocab['<a>']].size(-1) == 4
    assert field.embedding_matrix[field.vocab['<b>']].size(-1) == 4
    # Rows are initialized independently, so two special tokens should not collide.
    row_b = field.embedding_matrix[field.vocab['<b>']]
    row_c = field.embedding_matrix[field.vocab['<c>']]
    assert all(row_b != row_c)
def test_load_embeddings():
    """Embedding vectors from the text file are loaded for in-vocabulary tokens."""
    field = TextField.from_embeddings(
        embeddings="tests/data/dummy_embeddings/test.txt",
        pad_token=None,
        unk_init_all=False,
    )
    field.setup(["a test !"])
    # Expected rows for the two tokens that have pretrained vectors.
    expected = torch.tensor([
        [0.9, 0.1, 0.2, 0.3],
        [0.4, 0.5, 0.6, 0.7],
    ])
    # Matrix holds <unk> plus the two covered tokens.
    assert len(field.embedding_matrix) == 3
    assert torch.all(torch.eq(field.embedding_matrix[1:3], expected))
def test_load_embeddings_empty_voc():
    """When no data token has a pretrained vector, matrix size depends on unk_init_all."""
    oov_text = "justo Praesent luctus justo praesent"

    # unk_init_all=True: every unique token still gets a (zero-initialized) row.
    field = TextField.from_embeddings(
        embeddings="tests/data/dummy_embeddings/test.txt",
        pad_token=None,
        unk_init_all=True,
    )
    field.setup([oov_text])
    assert len(field.embedding_matrix) == 5

    # unk_init_all=False: only the <unk> row remains.
    field = TextField.from_embeddings(
        embeddings="tests/data/dummy_embeddings/test.txt",
        pad_token=None,
        unk_init_all=False,
    )
    field.setup([oov_text])
    assert len(field.embedding_matrix) == 1
def test_build_vocab_build_vocab_from_embeddings():
    """With build_vocab_from_embeddings=True, every embedding token joins the vocab.

    In embeddings and data: blue green yellow
    In embeddings only:     purple gold
    In data only:           white (falls back to the <unk> index)
    """
    # Build a small gensim model with specials plus a handful of colors.
    model = KeyedVectors(10)
    for word in ('purple', 'gold', '<unk>', 'blue', 'green', '<pad>', 'yellow'):
        model.add(word, np.random.rand(10))

    with tempfile.NamedTemporaryFile() as tmpfile:
        model.save(tmpfile.name)
        field = TextField.from_embeddings(
            embeddings=tmpfile.name,
            embeddings_format='gensim',
            build_vocab_from_embeddings=True,
        )
        field.setup(["blue green", "yellow", 'white'])

        # Data tokens come first (in order of appearance), then embedding-only
        # tokens; 'white' has no vector, so it shares the <unk> index.
        assert field.vocab == odict([
            ('<pad>', 0),
            ('<unk>', 1),
            ('blue', 2),
            ('green', 3),
            ('yellow', 4),
            ('white', 1),
            ('purple', 5),
            ('gold', 6),
        ])

        # Embedding rows line up with the vocab indices above ('white' has none).
        expected_matrix = torch.stack([
            torch.tensor(model['<pad>']),
            torch.tensor(model['<unk>']),
            torch.tensor(model['blue']),
            torch.tensor(model['green']),
            torch.tensor(model['yellow']),
            torch.tensor(model['purple']),
            torch.tensor(model['gold']),
        ])
        assert torch.equal(field.embedding_matrix, expected_matrix)