예제 #1
0
def get_word2vec(lang: str = "en"):
    """Download (once) and load a pretrained word2vec model for ``lang``.

    Only "en" and "ja" are supported. Any other value raises ``KeyError``
    from the URL lookup, before anything is downloaded.

    - "en": the downloaded ``.bin.gz`` archive is decompressed into
      ``word2vec.gensim.model`` inside the directory returned by
      ``download.get_cache_directory`` for ``word2vec/en``. It is then read
      with ``gensim.models.KeyedVectors.load_word2vec_format(..., binary=True)``.
    - "ja": the downloaded zip is extracted into ``<dir>/ja``, where
      ``<dir>`` is the directory returned by ``download.get_cache_directory``
      for ``word2vec``. The ``word2vec.gensim.model`` file inside is read
      with ``gensim.models.Word2Vec.load``.

    Args:
        lang: Language code selecting the pretrained model.

    Returns:
        A ``KeyedVectors`` object for "en", or a ``Word2Vec`` model for "ja".
    """
    model_urls = {
        "en": "https://s3.amazonaws.com/dl4j-distribution/GoogleNews-vectors-negative300.bin.gz",
        "ja": "http://public.shiroyagi.s3.amazonaws.com/latest-ja-word2vec-gensim-model.zip",
    }
    # Fetch (or reuse the cached copy of) the raw archive for this language.
    archive_path = Path(download.cached_download(model_urls[lang]))
    model_name = "word2vec.gensim.model"

    print("Loading model...")

    if lang == "en":
        # The English vectors are a gzipped word2vec *binary* file.
        en_dir = Path(download.get_cache_directory(str(Path("word2vec") / "en")))
        en_model_path = en_dir / model_name
        download.cached_decompress_gzip(archive_path, en_model_path)
        return gensim.models.KeyedVectors.load_word2vec_format(
            str(en_model_path), binary=True)

    # The URL lookup above already rejected unknown languages, so only
    # "ja" can reach this point.
    word2vec_dir = Path(download.get_cache_directory(str(Path("word2vec"))))
    ja_dir = word2vec_dir / lang
    download.cached_unzip(archive_path, ja_dir)
    return gensim.models.Word2Vec.load(str(ja_dir / model_name))
예제 #2
0
 def test_cache_exists(self, f: Callable):
     # `f` is presumably a mock injected by a patch decorator outside this
     # view (TODO confirm what it patches). Given the test name, returning
     # True likely simulates an already-present cache entry.
     f.return_value = True
     target_url = 'https://example.com'
     # Expected location: <temp_dir>/_dl_cache/<md5 hex digest of the URL>.
     cache_key = hashlib.md5(target_url.encode('utf-8')).hexdigest()
     expected_path = os.path.join(self.temp_dir, '_dl_cache', cache_key)

     self.assertEqual(download.cached_download(target_url), expected_path)
예제 #3
0
def get_fasttext(lang: str = "en"):
    """Download (once), extract, and load a pretrained fastText wiki model.

    The zip for ``lang`` is fetched with ``download.cached_download``. It is
    then extracted into a ``fasttext/<lang>`` directory next to the
    downloaded file. The ``<archive stem>.bin`` file inside is passed to
    ``load_model``.

    Args:
        lang: One of "en" (mapped to ``wiki.simple.zip``), "ja", or "fr".
            Any other value raises ``KeyError`` before anything is
            downloaded.

    Returns:
        Whatever ``load_model`` returns for the extracted ``.bin`` file.
        This is presumably a fastText model object; confirm against the
        ``load_model`` import.
    """
    url_by_lang = {
        "en": "https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.simple.zip",
        "ja": "https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.ja.zip",
        "fr": "https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.fr.zip",
    }
    url = url_by_lang[lang]
    archive = Path(download.cached_download(url))
    extract_dir = archive.parent.joinpath('fasttext', lang)
    download.cached_unzip(archive, extract_dir)

    print("Loading model...")
    # For example, ".../wiki.ja.zip" becomes "wiki.ja.bin".
    bin_name = Path(url).stem + '.bin'
    return load_model(str(extract_dir / bin_name))
예제 #4
0
    def test_cached_download(self, f: Callable):
        # `f` is presumably a mock patched over the fetch function that
        # cached_download calls (TODO confirm against the patch decorator).
        # The fake below writes "test" to whatever path it is handed.
        def fake_urlretrieve(url, path, progress_hook=None):
            with open(path, "w") as out:
                out.write("test")

        f.side_effect = fake_urlretrieve

        cache_path = download.cached_download("https://example.com")

        # The fetcher must be called exactly once, with three positional
        # arguments and no keyword arguments.
        self.assertEqual(f.call_count, 1)
        call_args, call_kwargs = f.call_args
        self.assertEqual(call_kwargs, {})
        self.assertEqual(len(call_args), 3)
        # Only the URL is checked. The second argument is a temporary path
        # that is removed afterwards, so it is not asserted on.
        self.assertEqual(call_args[0], "https://example.com")

        # The content written by the fake must end up at the cache path.
        self.assertTrue(os.path.exists(cache_path))
        with open(cache_path) as handle:
            contents = handle.read()
        self.assertEqual(contents, "test")
예제 #5
0
 def test_file_exists(self):
     # An empty regular file occupying the cache directory's name
     # ("_dl_cache" under temp_dir) must make cached_download raise OSError.
     blocker_path = os.path.join(self.temp_dir, "_dl_cache")
     with open(blocker_path, "w"):
         pass
     self.assertRaises(
         OSError, download.cached_download, "https://example.com")
예제 #6
0
 def test_fails_to_make_directory(self, f: Callable):
     # `f` is presumably a mock patched in by a decorator outside this view.
     # Judging by the test name, it likely stands in for the
     # directory-creation call (TODO confirm). When it raises OSError,
     # cached_download is expected to raise OSError as well.
     f.side_effect = OSError()
     self.assertRaises(
         OSError, download.cached_download, "https://example.com")