def test_errors_vectors_python(self):
    """Check the Python-side validation done by ``vectors`` and ``GloVe``.

    Expected failures:
      * empty tokens/vectors and no ``unk_tensor``   -> ``ValueError``
      * vector tensor whose dtype is not torch.float -> ``TypeError``
      * unknown GloVe ``name`` or unsupported ``dim``  -> ``ValueError``
    """
    # No rows to derive an unk vector from, and the caller supplies none.
    empty_vecs = torch.empty(0, dtype=torch.float)
    with self.assertRaises(ValueError):
        vectors([], empty_vecs)

    # A single int8 row: the vectors tensor must be torch.float.
    int_vecs = torch.tensor([1, 0, 0], dtype=torch.int8).unsqueeze(0)
    with self.assertRaises(TypeError):
        vectors(['a'], int_vecs)

    with tempfile.TemporaryDirectory() as root_dir:
        # Stage the bundled GloVe test asset inside the temporary root.
        archive = 'glove.6B.zip'
        shutil.copy(get_asset_path(archive), os.path.join(root_dir, archive))

        # Unknown pretrained name.
        with self.assertRaises(ValueError):
            GloVe(name='UNK', dim=50, root=root_dir, validate_file=False)

        # Known name with an unsupported dimension.
        with self.assertRaises(ValueError):
            GloVe(name='6B', dim=500, root=root_dir, validate_file=False)
def test_errors_vectors_cpp(self):
    """Check that duplicate tokens are rejected with ``RuntimeError``."""
    # Three one-hot float rows (identity matrix); token 'a' is listed twice.
    one_hot_rows = torch.eye(3, dtype=torch.float)
    duplicated_tokens = ['a', 'a', 'c']

    # TODO: use self.assertRaisesRegex() to check the key of the duplicate
    # token in the error message.
    with self.assertRaises(RuntimeError):
        vectors(duplicated_tokens, one_hot_rows)
def test_empty_vectors(self):
    """An empty ``vectors`` object with an explicit ``unk_tensor`` returns it for OOV lookups."""
    fallback = torch.tensor([0], dtype=torch.float)
    vectors_obj = vectors([], torch.empty(0, dtype=torch.float), fallback)
    self.assertEqual(vectors_obj['not_in_it'], fallback)
def test_empty_unk(self):
    """Without an explicit ``unk_tensor``, an OOV lookup is expected to return zeros.

    The expected zero vector has the same width (2) as the stored row.
    """
    only_row = torch.tensor([1, 0], dtype=torch.float)
    vectors_obj = vectors(['a'], only_row.unsqueeze(0))
    self.assertEqual(vectors_obj['not_in_it'], torch.zeros(2, dtype=torch.float))
def test_errors(self):
    """Exercise the validation errors raised by ``vectors`` and ``GloVe``.

    Expected failures:
      * empty tokens/vectors with no ``unk_tensor``     -> ``ValueError``
      * token count differs from number of vector rows  -> ``RuntimeError``
      * duplicate tokens                                -> ``RuntimeError``
      * vector tensor whose dtype is not torch.float    -> ``TypeError``
      * unknown GloVe ``name`` / unsupported ``dim``    -> ``ValueError``
    """
    empty_vecs = torch.empty(0, dtype=torch.float)
    with self.assertRaises(ValueError):
        vectors([], empty_vecs)

    one_hot = [
        torch.tensor([1, 0, 0], dtype=torch.float),
        torch.tensor([0, 1, 0], dtype=torch.float),
        torch.tensor([0, 0, 1], dtype=torch.float),
    ]

    # Three tokens but only two vector rows.
    two_rows = torch.stack(one_hot[:2], 0)
    with self.assertRaises(RuntimeError):
        vectors(['a', 'b', 'c'], two_rows)

    # Token 'a' appears twice.
    three_rows = torch.stack(one_hot, 0)
    # TODO (Nayef211): use self.assertRaisesRegex() to check
    # the key of the duplicate token in the error message
    with self.assertRaises(RuntimeError):
        vectors(['a', 'a', 'c'], three_rows)

    # Integer dtype instead of torch.float.
    int_rows = torch.tensor([0, 0, 1], dtype=torch.int8).unsqueeze(0)
    with self.assertRaises(TypeError):
        vectors(['a'], int_rows)

    with tempfile.TemporaryDirectory() as root_dir:
        # Stage the bundled GloVe test asset inside the temporary root.
        archive = 'glove.6B.zip'
        shutil.copy(get_asset_path(archive), os.path.join(root_dir, archive))

        # Unknown pretrained name.
        with self.assertRaises(ValueError):
            GloVe(name='UNK', dim=50, root=root_dir, validate_file=False)

        # Known name with an unsupported dimension.
        with self.assertRaises(ValueError):
            GloVe(name='6B', dim=500, root=root_dir, validate_file=False)
def test_vectors_basic(self):
    """Known tokens map to their rows; unknown tokens map to the supplied ``unk_tensor``."""
    row_a = torch.tensor([1, 0], dtype=torch.float)
    row_b = torch.tensor([0, 1], dtype=torch.float)
    unk_tensor = torch.tensor([0, 0], dtype=torch.float)
    vectors_obj = vectors(['a', 'b'], torch.stack((row_a, row_b), 0), unk_tensor=unk_tensor)

    expectations = {'a': row_a, 'b': row_b, 'not_in_it': unk_tensor}
    for token, expected in expectations.items():
        self.assertEqual(vectors_obj[token], expected)
def test_vectors_add_item(self):
    """Item assignment adds a new token while existing and OOV lookups keep working."""
    unk_tensor = torch.tensor([0, 0], dtype=torch.float)
    row_a = torch.tensor([1, 0], dtype=torch.float)
    vectors_obj = vectors(['a'], row_a.unsqueeze(0), unk_tensor=unk_tensor)

    row_b = torch.tensor([0, 1], dtype=torch.float)
    vectors_obj['b'] = row_b  # 'b' was not among the constructor tokens

    for token, expected in (('a', row_a), ('b', row_b), ('not_in_it', unk_tensor)):
        self.assertEqual(vectors_obj[token], expected)
def test_vectors_lookup_vectors(self):
    """``lookup_vectors`` stacks per-token rows and uses ``unk_tensor`` for OOV tokens."""
    row_a = torch.tensor([1, 0], dtype=torch.float)
    row_b = torch.tensor([0, 1], dtype=torch.float)
    unk_tensor = torch.tensor([0, 0], dtype=torch.float)
    vectors_obj = vectors(['a', 'b'], torch.stack((row_a, row_b), 0), unk_tensor=unk_tensor)

    # 'c' is out of vocabulary, so its row is expected to be unk_tensor.
    expected = torch.stack((row_a, row_b, unk_tensor), 0)
    actual = vectors_obj.lookup_vectors(['a', 'b', 'c'])
    self.assertEqual(expected, actual)
def test_vectors_forward(self):
    """Calling the eager object and its scripted ``to_ivalue()`` copy yields the same rows."""
    row_a = torch.tensor([1, 0], dtype=torch.float)
    row_b = torch.tensor([0, 1], dtype=torch.float)
    unk_tensor = torch.tensor([0, 0], dtype=torch.float)
    vectors_obj = vectors(['a', 'b'], torch.stack((row_a, row_b), 0), unk_tensor=unk_tensor)
    scripted = torch.jit.script(vectors_obj.to_ivalue())

    query = ['a', 'b', 'c']
    expected = torch.stack((row_a, row_b, unk_tensor), 0)
    eager_out = vectors_obj(query)
    scripted_out = scripted(query)
    for actual in (eager_out, scripted_out):
        self.assertEqual(expected, actual)
def test_vectors_jit(self):
    """A scripted ``to_ivalue()`` copy returns the same vectors as the eager object.

    Also checks that only the ``to_ivalue()`` form reports ``is_jitable``.
    """
    tensorA = torch.tensor([1, 0], dtype=torch.float)
    tensorB = torch.tensor([0, 1], dtype=torch.float)
    unk_tensor = torch.tensor([0, 0], dtype=torch.float)
    tokens = ['a', 'b']
    vecs = torch.stack((tensorA, tensorB), 0)

    vectors_obj = vectors(tokens, vecs, unk_tensor=unk_tensor)
    jit_vectors_obj = torch.jit.script(vectors_obj.to_ivalue())

    # Use unittest assertions rather than bare ``assert``: bare asserts are
    # removed under ``python -O``, which would silently skip these checks.
    self.assertFalse(vectors_obj.is_jitable)
    self.assertTrue(vectors_obj.to_ivalue().is_jitable)

    # Known tokens and an OOV token must agree between eager and scripted.
    for token in ('a', 'b', 'not_in_it'):
        self.assertEqual(vectors_obj[token], jit_vectors_obj[token])
def test_vectors_load_and_save(self):
    """Round-trip a mutated ``vectors`` object through ``torch.save``/``torch.load``."""
    row_a = torch.tensor([1, 0], dtype=torch.float)
    row_b = torch.tensor([0, 1], dtype=torch.float)
    vectors_obj = vectors(['a', 'b'], torch.stack((row_a, row_b), 0))

    # Overwrite 'b' so the saved state differs from the constructor input.
    replacement_b = torch.tensor([1, 1], dtype=torch.float)
    vectors_obj['b'] = replacement_b

    vector_path = os.path.join(self.test_dir, 'vectors.pt')
    torch.save(vectors_obj, vector_path)
    loaded = torch.load(vector_path)

    self.assertEqual(loaded['a'], row_a)
    self.assertEqual(loaded['b'], replacement_b)
    # No unk_tensor was passed, so the OOV result is expected to be zeros.
    self.assertEqual(loaded['not_in_it'], torch.zeros(2, dtype=torch.float))