def download(self):
    # Already a valid local file: nothing to do
    if is_file_correct(self.embedding_file):
        logger.info("embedding file location: {}".format(self.embedding_file))
        return self.embedding_file
    dcache_path = os.path.join(self.data_download_cache, DATA_CACHE_CONF)
    dcache = read_json(dcache_path)
    if self.embedding_file in dcache and not self.cache_ignore:
        download_loc = dcache[self.embedding_file]
        logger.info("files for {} found in cache".format(self.embedding_file))
        return self._get_embedding_file(download_loc, self.embedding_key)
    # Try to download the bundle and unzip it
    url = self.embedding_file
    if not validate_url(url):
        raise RuntimeError("cannot download from the given URL")
    cache_dir = self.data_download_cache
    temp_file = web_downloader(url)
    unzip_fn = Downloader.ZIPD.get(mime_type(temp_file)) if self.unzip_file else None
    download_loc = extractor(filepath=temp_file, cache_dir=cache_dir, extractor_func=unzip_fn)
    # The extractor names its output by sha1, so the basename doubles as the checksum
    if self.sha1 is not None and os.path.split(download_loc)[-1] != self.sha1:
        raise RuntimeError("The sha1 of the downloaded file does not match the provided one")
    dcache.update({url: download_loc})
    write_json(dcache, dcache_path)
    return self._get_embedding_file(download_loc, self.embedding_key)
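
# A minimal usage sketch for the method above. The class name
# EmbeddingDownloader and its constructor arguments are assumptions inferred
# from the attributes the method reads (embedding_file, embedding_key, sha1,
# unzip_file, cache_ignore, data_download_cache), not a confirmed API.
downloader = EmbeddingDownloader(
    embedding_file="https://example.com/glove.6B.zip",   # local path or URL
    embedding_key="glove.6B.100d.txt",                   # which file to pull out of the bundle
    data_download_cache=os.path.expanduser("~/.bl-data"),
    unzip_file=True,
    sha1=None,            # supply an expected sha1 to verify the download
    cache_ignore=False,   # True bypasses the JSON cache and re-downloads
)
embedding_path = downloader.download()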
def load(cls, embeddings, **kwargs):
    c = cls("tlm-words-embed-pooled-output", **kwargs)
    if embeddings.endswith('.bin'):
        # HuggingFace checkpoint: convert on the fly
        from eight_mile.pytorch.serialize import load_tlm_transformers_bin, BERT_HF_FT_LAYER_MAP
        unmatch = load_tlm_transformers_bin(c, embeddings, replace_layers=BERT_HF_FT_LAYER_MAP)
        if unmatch['missing'] or unmatch['unexpected']:
            raise Exception("Unable to load the HuggingFace checkpoint")
        # The conversion is complete; don't fall through and load a second time
        return c
    if mime_type(embeddings) == 'application/zip' and not embeddings.endswith("pth"):
        # npz archive: restore everything except keys marked skippable
        keys_to_restore = set(c.embeddings.keys())
        if not keys_to_restore:
            raise Exception("No keys to restore!")
        filtered_keys = keys_to_restore.difference(c.skippable)
        if len(filtered_keys) < len(keys_to_restore):
            logger.warning("Restoring only keys [%s]", ' '.join(filtered_keys))
        load_tlm_output_npz(c, embeddings, filtered_keys)
    else:
        # Raw PyTorch state dict (e.g. a .pth file)
        map_location = 'cpu' if kwargs.get('cpu_placement') else None
        tlm_load_state_dict(
            c,
            embeddings,
            str_map={'model.embeddings.embeddings.0.': '', 'model.output_layer': 'output_layer'},
            map_location=map_location)
    return c
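
# Usage sketch for the loader above; the owning class name is hypothetical.
# The dispatch covers three checkpoint formats:
#   *.bin           -> HuggingFace checkpoint, converted via load_tlm_transformers_bin
#   zip/npz archive -> restored key-by-key, minus any keys listed in c.skippable
#   anything else   -> treated as a raw PyTorch state dict (e.g. *.pth)
pooled = TransformerLMPooledEmbeddings.load("bert-base-uncased.npz", cpu_placement=True)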
def download(self):
    file_loc = self.dataset_file
    if is_file_correct(file_loc):
        return file_loc
    if validate_url(file_loc):  # a web URL: check whether it is already in the cache
        url = file_loc
        dcache_path = os.path.join(self.data_download_cache, DATA_CACHE_CONF)
        dcache = read_json(dcache_path)
        if url in dcache and is_file_correct(dcache[url], self.data_download_cache, url) and not self.cache_ignore:
            logger.info("file for {} found in cache, not downloading".format(url))
            return dcache[url]
        # Download the file into the cache and update the JSON index
        cache_dir = self.data_download_cache
        logger.info("using {} as data/embeddings cache".format(cache_dir))
        temp_file = web_downloader(url)
        dload_file = extractor(filepath=temp_file, cache_dir=cache_dir,
                               extractor_func=Downloader.ZIPD.get(mime_type(temp_file), None))
        dcache.update({url: dload_file})
        write_json(dcache, dcache_path)
        return dload_file
    raise RuntimeError("the file [{}] is not in the cache and cannot be downloaded".format(file_loc))
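
# Usage sketch: judging by the call in the dataset downloader further down,
# this method lives on SingleFileDownloader, constructed from a file location
# (local path or URL) and the cache directory. The URL here is illustrative.
train_file = SingleFileDownloader("https://example.com/data/train.txt.gz",
                                  os.path.expanduser("~/.bl-data")).download()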
def _verify_file(file_loc):
    # Dropbox doesn't return a 404 when a file is missing; it serves an HTML error
    # page instead. The actual data files are never HTML, so an HTML mime type
    # means the download failed.
    if not os.path.exists(file_loc):
        return False
    if os.path.isfile(file_loc) and mime_type(file_loc) == "text/html":
        return False
    return True
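
# Sketch of the failure mode this guards against: a "successful" Dropbox
# download that actually returned an HTML error page is rejected, as is a
# path that was never written at all. The paths are illustrative.
assert _verify_file("/no/such/file") is False             # missing file
# assert _verify_file("dropbox_error_page") is False      # mime type text/html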
def extract_gzip(file_loc):
    temp_file = delete_old_copy("{}.1".format(file_loc))
    with gzip.open(file_loc, 'rb') as f_in, open(temp_file, 'wb') as f_out:
        shutil.copyfileobj(f_in, f_out)
    if mime_type(temp_file) == "application/x-tar":
        # The gzip wrapped a tarball (.tar.gz): hand off to the tar extractor
        return extract_tar(temp_file)
    # Plain .gz: replace the archive with its decompressed contents
    shutil.move(temp_file, file_loc)
    return file_loc
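
# Standalone illustration of the decompression idiom used above: stream the
# gzip member into a sibling file with shutil.copyfileobj, which copies in
# chunks instead of reading the whole archive into memory. File names are
# illustrative.
import gzip
import shutil

with gzip.open("data.txt.gz", "rb") as f_in, open("data.txt", "wb") as f_out:
    shutil.copyfileobj(f_in, f_out)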
def _read_vectors(self, filename, idx, known_vocab, keep_unused, **kwargs):
    # Choose a reader based on the file format (GloVe text vs word2vec binary)
    # and whether the caller wants a memory-mapped read
    use_mmap = bool(kwargs.get("use_mmap", False))
    is_glove_file = mime_type(filename) == "text/plain"
    if use_mmap:
        read_fn = self._read_text_mmap if is_glove_file else self._read_word2vec_mmap
    else:
        read_fn = self._read_text_file if is_glove_file else self._read_word2vec_file
    return read_fn(filename, idx, known_vocab, keep_unused)
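
# The same dispatch, viewed as a table; the usage line is a sketch assuming
# this method sits on a word-vector model class (name hypothetical):
#
#   use_mmap | text/plain (GloVe) | binary (word2vec)
#   ---------+--------------------+---------------------
#   True     | _read_text_mmap    | _read_word2vec_mmap
#   False    | _read_text_file    | _read_word2vec_file
#
# weights = model._read_vectors("glove.6B.100d.txt", idx, known_vocab,
#                               keep_unused=False, use_mmap=True)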
def download(self):
    dload_bundle = self.dataset_desc.get("download", None)
    if dload_bundle is not None:
        # A single zip/tar/tar.gz bundle: download it and look for the
        # train/dev/test files inside
        dcache_path = os.path.join(self.data_download_cache, DATA_CACHE_CONF)
        dcache = read_json(dcache_path)
        if dload_bundle in dcache and \
                is_dir_correct(dcache[dload_bundle], self.dataset_desc, self.data_download_cache,
                               dload_bundle, self.enc_dec) and not self.cache_ignore:
            download_dir = dcache[dload_bundle]
            logger.info("files for {} found in cache, not downloading".format(dload_bundle))
            return _update_md(self.dataset_desc, download_dir)
        # Try to download the bundle and unzip it
        if not validate_url(dload_bundle):
            raise RuntimeError("cannot download from the given URL")
        cache_dir = self.data_download_cache
        temp_file = web_downloader(dload_bundle)
        download_dir = extractor(filepath=temp_file, cache_dir=cache_dir,
                                 extractor_func=Downloader.ZIPD.get(mime_type(temp_file), None))
        # The extractor names its output by sha1, so the basename doubles as the checksum
        if "sha1" in self.dataset_desc and os.path.split(download_dir)[-1] != self.dataset_desc["sha1"]:
            raise RuntimeError("The sha1 of the downloaded file does not match the provided one")
        dcache.update({dload_bundle: download_dir})
        write_json(dcache, dcache_path)
        return _update_md(self.dataset_desc, download_dir)
    # No bundle: either each file has its own download link or it already exists
    updated = _update_md(self.dataset_desc, None)
    if not self.enc_dec:
        updated.update({
            k: SingleFileDownloader(self.dataset_desc[k], self.data_download_cache).download()
            for k in self.dataset_desc if k.endswith("_file") and self.dataset_desc[k]
        })
    return updated
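
# Usage sketch, assuming a DataDownloader-style class built from a dataset
# description dict of the shape this method reads (a "download" bundle URL,
# *_file entries, and optionally a "sha1"); the class name and URL are
# illustrative, not confirmed.
dataset_desc = {
    "download": "https://example.com/conll.tar.gz",
    "train_file": "eng.train",
    "valid_file": "eng.testa",
    "test_file": "eng.testb",
}
files = DataDownloader(dataset_desc, os.path.expanduser("~/.bl-data")).download()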
def load(cls, embeddings, **kwargs):
    c = cls("tlm-words-embed", **kwargs)
    if embeddings.endswith('.bin'):
        # HuggingFace checkpoint: convert on the fly
        from eight_mile.pytorch.serialize import load_tlm_transformers_bin, BERT_HF_FT_LAYER_MAP
        unmatch = load_tlm_transformers_bin(c, embeddings, replace_layers=BERT_HF_FT_LAYER_MAP)
        if unmatch['missing'] or unmatch['unexpected']:
            raise Exception("Unable to load the HuggingFace checkpoint")
        # The conversion is complete; don't fall through and load a second time
        return c
    if mime_type(embeddings) == 'application/zip':
        load_tlm_npz(c, embeddings)
    else:
        tlm_load_state_dict(c, embeddings)
    return c
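
# Usage sketch mirroring the pooled variant above, with the same three-way
# format dispatch but a whole-model npz restore (load_tlm_npz) and no key
# filtering. The class name is hypothetical.
embed = TransformerLMEmbeddings.load("tlm-checkpoint.pth")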
def unzip_files(zip_path):
    if os.path.isdir(zip_path):
        return zip_path
    from eight_mile.utils import mime_type
    if mime_type(zip_path) == 'application/zip':
        with open(zip_path, 'rb') as f:
            sha1 = hashlib.sha1(f.read()).hexdigest()
        # Extract into a directory keyed by the archive's sha1 so repeated
        # calls on the same file reuse the previous extraction
        temp_dir = os.path.join("/tmp/", sha1)
        if not os.path.exists(temp_dir):
            logger.info("unzipping model")
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall(temp_dir)
        contents = os.listdir(temp_dir)
        if len(contents) == 1:  # a whole directory was zipped rather than individual files
            temp_dir = os.path.join(temp_dir, contents[0])
        return temp_dir
    return zip_path
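
# Usage sketch: round-tripping a model archive through unzip_files. The
# extraction lands under /tmp/<sha1-of-the-zip> and is reused on later calls
# because that directory already exists; file names are illustrative.
model_dir = unzip_files("model.zip")   # -> /tmp/<sha1>, or its single subdirectory
model_dir = unzip_files(model_dir)     # already a directory: returned unchanged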