def test_extend_vectors_1(self):
    """Check that ``MatchingVocab.extend_vectors`` grows the vocab and its vectors.

    Loads the ``fasttext_sample.vec`` fixture through a ``file:`` URL, builds a
    vocab from an empty ``Counter`` and extends it with the tokens ``hello`` and
    ``world``. The assertions expect four entries in ``itos``, a ``4 x 300``
    vector table, and all-zero leading components for rows 2 and 3.

    NOTE(review): the zero rows presumably come from ``unk_init`` because the
    two tokens are missing from the sample embeddings -- confirm against the
    fixture file.
    """
    vectors_cache_dir = '.cache'
    # Start from a clean cache so artifacts of earlier runs cannot leak in.
    if os.path.exists(vectors_cache_dir):
        shutil.rmtree(vectors_cache_dir)

    try:
        pathdir = os.path.abspath(os.path.join(test_dir_path, 'test_datasets'))
        filename = 'fasttext_sample.vec'
        vec_path = os.path.join(pathdir, filename)
        # Vectors expects a URL, so point it at the local fixture file.
        url_base = urljoin('file:', pathname2url(vec_path))
        vecs = Vectors(name=filename, cache=vectors_cache_dir, url=url_base)
        self.assertIsInstance(vecs, Vectors)

        vec_data = MatchingField._get_vector_data(vecs, vectors_cache_dir)
        v = MatchingVocab(Counter())
        # Seed table with a single row; its contents are never asserted below.
        v.vectors = torch.Tensor(1, vec_data[0].dim)
        v.unk_init = torch.Tensor.zero_
        tokens = {'hello', 'world'}
        v.extend_vectors(tokens, vec_data)

        self.assertEqual(len(v.itos), 4)
        self.assertEqual(v.vectors.size(), torch.Size([4, 300]))
        self.assertEqual(list(v.vectors[2][0:10]), [0.0] * 10)
        self.assertEqual(list(v.vectors[3][0:10]), [0.0] * 10)
    finally:
        # Runs even when an assertion above fails, so the cache directory
        # never survives into the next test.
        if os.path.exists(vectors_cache_dir):
            shutil.rmtree(vectors_cache_dir)
def test_get_vector_data(self):
    """Check that ``MatchingField._get_vector_data`` returns one entry for one input.

    Loads the ``fasttext_sample.vec`` fixture through a ``file:`` URL, passes
    the resulting ``Vectors`` object to ``_get_vector_data`` and asserts that
    the returned collection has length 1.
    """
    vectors_cache_dir = '.cache'
    # Start from a clean cache so artifacts of earlier runs cannot leak in.
    if os.path.exists(vectors_cache_dir):
        shutil.rmtree(vectors_cache_dir)

    try:
        # Resolve the fixture relative to the test directory rather than the
        # current working directory, so the test does not depend on where the
        # test runner was launched from.
        pathdir = os.path.abspath(os.path.join(test_dir_path, 'test_datasets'))
        filename = 'fasttext_sample.vec'
        vec_path = os.path.join(pathdir, filename)
        # Vectors expects a URL, so point it at the local fixture file.
        url_base = urljoin('file:', pathname2url(vec_path))
        vecs = Vectors(name=filename, cache=vectors_cache_dir, url=url_base)
        self.assertIsInstance(vecs, Vectors)

        vec_data = MatchingField._get_vector_data(vecs, vectors_cache_dir)
        self.assertEqual(len(vec_data), 1)
    finally:
        # Runs even when an assertion above fails, so the cache directory
        # never survives into the next test.
        if os.path.exists(vectors_cache_dir):
            shutil.rmtree(vectors_cache_dir)