Exemplo n.º 1
0
 def download(self):
     """Return a local path for the embeddings, downloading them when necessary.

     Resolution order:
       1. ``self.embedding_file`` passes ``is_file_correct`` -> it is returned as-is.
       2. It is a key of the download-cache index (and ``self.cache_ignore`` is
          false) -> the cached location is resolved via ``_get_embedding_file``.
       3. Otherwise it must be a URL: it is downloaded, optionally extracted
          (only when ``self.unzip_file`` is set), checked against ``self.sha1``
          when one is given, recorded in the cache index, and resolved via
          ``_get_embedding_file``.

     :raises RuntimeError: if the value is not a valid URL, or the sha1 check fails.
     """
     embedding_file = self.embedding_file
     if is_file_correct(embedding_file):
         logger.info("embedding file location: {}".format(embedding_file))
         return embedding_file

     index_path = os.path.join(self.data_download_cache, DATA_CACHE_CONF)
     cache_index = read_json(index_path)
     if embedding_file in cache_index and not self.cache_ignore:
         cached_loc = cache_index[embedding_file]
         logger.info("files for {} found in cache".format(embedding_file))
         return self._get_embedding_file(cached_loc, self.embedding_key)

     # Neither a usable local file nor a cache hit: it has to be downloadable.
     if not validate_url(embedding_file):
         raise RuntimeError("can not download from the given url")

     cache_dir = self.data_download_cache
     downloaded = web_downloader(embedding_file)
     unpack = None
     if self.unzip_file:
         unpack = Downloader.ZIPD.get(mime_type(downloaded))
     location = extractor(filepath=downloaded,
                          cache_dir=cache_dir,
                          extractor_func=unpack)

     # The expected sha1 is compared with the basename of the extracted location
     # (presumably extractor names its output by content hash -- confirm).
     if self.sha1 is not None and os.path.split(location)[-1] != self.sha1:
         raise RuntimeError(
             "The sha1 of the downloaded file does not match with the provided one"
         )

     cache_index.update({embedding_file: location})
     write_json(cache_index, index_path)
     return self._get_embedding_file(location, self.embedding_key)
Exemplo n.º 2
0
    def load(cls, embeddings, **kwargs):
        """Build a ``tlm-words-embed-pooled-output`` model and load weights into it.

        The checkpoint loader is chosen from ``embeddings``; exactly one is used:

        * path ending in ``.bin``: a HuggingFace checkpoint, converted on the fly
          by ``load_tlm_transformers_bin``;
        * zip archive (by MIME type) whose name does not end in ``pth``: loaded
          with ``load_tlm_output_npz``, restoring only the embedding keys that
          are not listed in ``c.skippable``;
        * anything else: loaded with ``tlm_load_state_dict`` using a key-renaming
          map; ``map_location`` is ``'cpu'`` when ``kwargs['cpu_placement']`` is
          truthy, else ``None``.

        :param embeddings: path to the checkpoint file.
        :param kwargs: forwarded to the constructor (``cpu_placement`` is also read here).
        :raises Exception: if the HuggingFace checkpoint has missing/unexpected keys,
            or there are no embedding keys to restore from an npz archive.
        :return: the constructed, loaded model.
        """
        c = cls("tlm-words-embed-pooled-output", **kwargs)

        if embeddings.endswith('.bin'):
            # HuggingFace checkpoint, convert on the fly
            from eight_mile.pytorch.serialize import load_tlm_transformers_bin, BERT_HF_FT_LAYER_MAP
            unmatch = load_tlm_transformers_bin(
                c, embeddings, replace_layers=BERT_HF_FT_LAYER_MAP)
            if unmatch['missing'] or unmatch['unexpected']:
                raise Exception("Unable to load the HuggingFace checkpoint")
        # `elif`: the HF checkpoint is fully handled above. Falling through would
        # load the same file a second time with a loader for a different format.
        elif mime_type(embeddings) == 'application/zip' and not embeddings.endswith("pth"):
            keys_to_restore = set(c.embeddings.keys())
            filtered_keys = keys_to_restore.difference(c.skippable)
            if not keys_to_restore:
                raise Exception("No keys to restore!")
            if len(filtered_keys) < len(keys_to_restore):
                logger.warning("Restoring only key [%s]",
                               ' '.join(filtered_keys))
            load_tlm_output_npz(c, embeddings, filtered_keys)
        else:
            map_location = 'cpu' if kwargs.get('cpu_placement') else None
            tlm_load_state_dict(c,
                                embeddings,
                                str_map={
                                    'model.embeddings.embeddings.0.': '',
                                    'model.output_layer': 'output_layer'
                                },
                                map_location=map_location)
        return c
Exemplo n.º 3
0
 def download(self):
     """Return a local path for ``self.dataset_file``, downloading it if needed.

     * If the value already passes ``is_file_correct`` it is returned unchanged.
     * If it is a URL with an entry in the download-cache index that passes
       ``is_file_correct`` (and ``self.cache_ignore`` is false), the cached
       path is returned.
     * Otherwise the URL is downloaded into ``self.data_download_cache``,
       extracted with the extractor registered for its MIME type (if any),
       recorded in the cache index, and the resulting path is returned.

     :raises RuntimeError: if the value is neither a correct local file nor a valid URL.
     """
     location = self.dataset_file
     if is_file_correct(location):
         return location
     if not validate_url(location):
         raise RuntimeError(
             "the file [{}] is not in cache and can not be downloaded".format(
                 location))

     url = location
     index_path = os.path.join(self.data_download_cache, DATA_CACHE_CONF)
     cache_index = read_json(index_path)
     # NOTE: is_file_correct receives the cache dir and key as well and is
     # evaluated before cache_ignore, so keep this exact evaluation order.
     hit = (url in cache_index
            and is_file_correct(cache_index[url], self.data_download_cache, url)
            and not self.cache_ignore)
     if hit:
         logger.info(
             "file for {} found in cache, not downloading".format(url))
         return cache_index[url]

     # Cache miss: download into the cache directory and update the index.
     cache_dir = self.data_download_cache
     logger.info("using {} as data/embeddings cache".format(cache_dir))
     downloaded = web_downloader(url)
     unpack = Downloader.ZIPD.get(mime_type(downloaded), None)
     local_file = extractor(filepath=downloaded,
                            cache_dir=cache_dir,
                            extractor_func=unpack)
     cache_index.update({url: local_file})
     write_json(cache_index, index_path)
     return local_file
Exemplo n.º 4
0
def _verify_file(file_loc):
    # dropbox doesn't give 404 in case the file does not exist, produces an HTML. The actual files are never HTMLs.
    if not os.path.exists(file_loc):
        return False

    if os.path.isfile(file_loc) and mime_type(file_loc) == "text/html":
        return False

    return True
Exemplo n.º 5
0
def extract_gzip(file_loc):
    """Decompress the gzip file at ``file_loc``.

    The decompressed bytes are written to the path returned by
    ``delete_old_copy("<file_loc>.1")`` (presumably after clearing any stale
    copy -- confirm). If that output is a tar archive, the result of
    ``extract_tar`` on it is returned. Otherwise the output is moved over
    ``file_loc`` and ``file_loc`` is returned.
    """
    decompressed = delete_old_copy("{}.1".format(file_loc))
    with gzip.open(file_loc, 'rb') as src, open(decompressed, 'wb') as dst:
        shutil.copyfileobj(src, dst)
    if mime_type(decompressed) != "application/x-tar":
        # Plain gzip payload: replace the compressed file with its contents.
        shutil.move(decompressed, file_loc)
        return file_loc
    return extract_tar(decompressed)
Exemplo n.º 6
0
    def _read_vectors(self, filename, idx, known_vocab, keep_unused, **kwargs):
        """Read embedding vectors from ``filename`` using the matching reader.

        A ``text/plain`` MIME type marks a GloVe-style text file; anything else
        is read as word2vec. The ``use_mmap`` keyword (default False) selects
        the memory-mapped variant of either reader. ``idx``, ``known_vocab`` and
        ``keep_unused`` are passed straight through to the chosen reader, whose
        result is returned.
        """
        mmap_requested = bool(kwargs.get("use_mmap", False))
        text_format = mime_type(filename) == "text/plain"
        if text_format:
            reader = self._read_text_mmap if mmap_requested else self._read_text_file
        else:
            reader = self._read_word2vec_mmap if mmap_requested else self._read_word2vec_file
        return reader(filename, idx, known_vocab, keep_unused)
Exemplo n.º 7
0
    def download(self):
        """Make the dataset files available locally; return the updated description.

        With a ``download`` entry in ``self.dataset_desc``, that entry is one
        bundle (zip/tar/tar.gz) holding the train/dev/test files:

        * a cached copy accepted by ``is_dir_correct`` is reused, unless
          ``self.cache_ignore`` is set;
        * otherwise the bundle is downloaded and extracted. It is checked against
          ``dataset_desc["sha1"]`` when that key is present, and then recorded
          in the download-cache index.

        In both cases the result is ``_update_md`` applied with the bundle directory.

        Without a ``download`` entry, ``_update_md`` is applied with ``None``.
        Unless ``self.enc_dec`` is set, every non-empty ``*_file`` entry is also
        resolved on its own through ``SingleFileDownloader``.

        :raises RuntimeError: if the bundle URL is invalid or the sha1 check fails.
        """
        bundle = self.dataset_desc.get("download", None)
        if bundle is None:
            # Each file has its own download link or already exists locally.
            updated = _update_md(self.dataset_desc, None)
            if not self.enc_dec:
                resolved = {}
                for key in self.dataset_desc:
                    if key.endswith("_file") and self.dataset_desc[key]:
                        resolved[key] = SingleFileDownloader(
                            self.dataset_desc[key],
                            self.data_download_cache).download()
                updated.update(resolved)
            return updated

        index_path = os.path.join(self.data_download_cache, DATA_CACHE_CONF)
        cache_index = read_json(index_path)
        cached = (bundle in cache_index
                  and is_dir_correct(cache_index[bundle], self.dataset_desc,
                                     self.data_download_cache, bundle,
                                     self.enc_dec)
                  and not self.cache_ignore)
        if cached:
            target_dir = cache_index[bundle]
            logger.info(
                "files for {} found in cache, not downloading".format(bundle))
            return _update_md(self.dataset_desc, target_dir)

        # Cache miss: fetch the bundle and unpack it into the cache directory.
        if not validate_url(bundle):
            raise RuntimeError("can not download from the given url")
        cache_dir = self.data_download_cache
        archive = web_downloader(bundle)
        target_dir = extractor(filepath=archive,
                               cache_dir=cache_dir,
                               extractor_func=Downloader.ZIPD.get(
                                   mime_type(archive), None))
        if "sha1" in self.dataset_desc and \
                os.path.split(target_dir)[-1] != self.dataset_desc["sha1"]:
            raise RuntimeError(
                "The sha1 of the downloaded file does not match with the provided one"
            )
        cache_index.update({bundle: target_dir})
        write_json(cache_index, index_path)
        return _update_md(self.dataset_desc, target_dir)
Exemplo n.º 8
0
    def load(cls, embeddings, **kwargs):
        """Build a ``tlm-words-embed`` model and load checkpoint weights into it.

        Exactly one loader is used, chosen from ``embeddings``:

        * path ending in ``.bin``: a HuggingFace checkpoint, converted on the fly
          by ``load_tlm_transformers_bin``;
        * zip archive (by MIME type): loaded with ``load_tlm_npz``;
        * anything else: loaded with ``tlm_load_state_dict``.

        :param embeddings: path to the checkpoint file.
        :param kwargs: forwarded to the constructor.
        :raises Exception: if the HuggingFace checkpoint has missing/unexpected keys.
        :return: the constructed, loaded model.
        """
        c = cls("tlm-words-embed", **kwargs)

        if embeddings.endswith('.bin'):
            # HuggingFace checkpoint, convert on the fly
            from eight_mile.pytorch.serialize import load_tlm_transformers_bin, BERT_HF_FT_LAYER_MAP
            unmatch = load_tlm_transformers_bin(
                c, embeddings, replace_layers=BERT_HF_FT_LAYER_MAP)
            if unmatch['missing'] or unmatch['unexpected']:
                raise Exception("Unable to load the HuggingFace checkpoint")
        # `elif`: the HF checkpoint is fully handled above. Falling through would
        # load the same file a second time with a loader for a different format.
        elif mime_type(embeddings) == 'application/zip':
            load_tlm_npz(c, embeddings)
        else:
            tlm_load_state_dict(c, embeddings)
        return c
Exemplo n.º 9
0
def unzip_files(zip_path):
    """Return a path to load a model from, extracting ``zip_path`` when it is a zip.

    * A directory is returned unchanged.
    * A zip archive (by MIME type) is extracted once into
      ``/tmp/<sha1 of the archive bytes>``. That directory is reused when it
      already exists. If the extracted tree consists of a single directory
      (a directory, rather than loose files, was zipped), that inner
      directory is returned.
    * Any other path is returned unchanged.

    :param zip_path: path to a directory, zip archive, or other file.
    :return: the directory (or original path) to load from.
    """
    if os.path.isdir(zip_path):
        return zip_path
    from eight_mile.utils import mime_type
    if mime_type(zip_path) != 'application/zip':
        return zip_path

    # Hash in fixed-size chunks so large model archives are not read into
    # memory all at once; the digest is identical to hashing the whole file.
    digest = hashlib.sha1()
    with open(zip_path, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            digest.update(chunk)
    temp_dir = os.path.join("/tmp/", digest.hexdigest())
    if not os.path.exists(temp_dir):
        logger.info("unzipping model")
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(temp_dir)

    entries = os.listdir(temp_dir)
    # Descend only when the single entry is a directory (a directory was zipped
    # rather than loose files); a lone file is not a model directory.
    if len(entries) == 1 and os.path.isdir(os.path.join(temp_dir, entries[0])):
        temp_dir = os.path.join(temp_dir, entries[0])
    return temp_dir
    return zip_path