Example #1
0
    def test_extend_vectors_1(self):
        """Extending a vocab with new tokens grows it and zero-fills the new rows.

        Loads the ``fasttext_sample.vec`` fixture through a ``file:`` URL into a
        throwaway cache directory. It then builds an empty ``MatchingVocab`` whose
        vector matrix has one row and extends it with two tokens. Finally it checks:

        * the vocab and vector matrix both end up with 4 entries;
        * rows 2 and 3 are zero, because ``unk_init`` is ``Tensor.zero_``.

        The cache directory is removed before the test runs. It is removed again
        in ``finally``, so it is cleaned up even when an assertion fails.
        """
        vectors_cache_dir = '.cache'
        if os.path.exists(vectors_cache_dir):
            shutil.rmtree(vectors_cache_dir)

        try:
            pathdir = os.path.abspath(
                os.path.join(test_dir_path, 'test_datasets'))
            filename = 'fasttext_sample.vec'
            vec_file = os.path.join(pathdir, filename)
            # Vectors is given a url, so point it at the local fixture file.
            url_base = urljoin('file:', pathname2url(vec_file))
            vecs = Vectors(name=filename, cache=vectors_cache_dir, url=url_base)
            self.assertIsInstance(vecs, Vectors)

            vec_data = MatchingField._get_vector_data(vecs, vectors_cache_dir)
            v = MatchingVocab(Counter())
            # torch.Tensor(1, dim) leaves row 0 uninitialised; only the rows
            # appended by extend_vectors are inspected below.
            v.vectors = torch.Tensor(1, vec_data[0].dim)
            v.unk_init = torch.Tensor.zero_
            tokens = {'hello', 'world'}
            v.extend_vectors(tokens, vec_data)

            self.assertEqual(len(v.itos), 4)
            # The fixture is expected to hold 300-dimensional vectors.
            self.assertEqual(v.vectors.size(), torch.Size([4, 300]))
            # .tolist() compares plain floats instead of 0-d tensors.
            self.assertEqual(v.vectors[2][0:10].tolist(), [0.0] * 10)
            self.assertEqual(v.vectors[3][0:10].tolist(), [0.0] * 10)
        finally:
            # Always remove the cache so a failed run does not leak it.
            if os.path.exists(vectors_cache_dir):
                shutil.rmtree(vectors_cache_dir)
Example #2
0
    def test_get_vector_data(self):
        """``MatchingField._get_vector_data`` returns one entry for one Vectors.

        Loads the ``fasttext_sample.vec`` fixture through a ``file:`` URL into a
        throwaway cache directory. It then checks that a single ``Vectors``
        object yields vector data of length 1.

        The cache directory is removed before the test runs. It is removed again
        in ``finally``, so it is cleaned up even when an assertion fails.
        """
        vectors_cache_dir = '.cache'
        if os.path.exists(vectors_cache_dir):
            shutil.rmtree(vectors_cache_dir)

        try:
            # NOTE(review): this resolves 'test_datasets' against the current
            # working directory, so the test only passes when run from the
            # directory containing it. A sibling test uses test_dir_path
            # instead; consider the same here if it is importable.
            pathdir = os.path.abspath(os.path.join('.', 'test_datasets'))
            filename = 'fasttext_sample.vec'
            vec_file = os.path.join(pathdir, filename)
            # Vectors is given a url, so point it at the local fixture file.
            url_base = urljoin('file:', pathname2url(vec_file))
            vecs = Vectors(name=filename, cache=vectors_cache_dir, url=url_base)
            self.assertIsInstance(vecs, Vectors)

            vec_data = MatchingField._get_vector_data(vecs, vectors_cache_dir)
            self.assertEqual(len(vec_data), 1)
        finally:
            # Always remove the cache so a failed run does not leak it.
            if os.path.exists(vectors_cache_dir):
                shutil.rmtree(vectors_cache_dir)