def get_model_file(model_name, local_model_store_dir_path=os.path.join('~', '.mxnet', 'models')):
    """Return location for the pretrained on local file system.

    This function will download from online model zoo when model cannot be found
    or has mismatch. The root directory will be created if it doesn't exist.

    Parameters
    ----------
    model_name : str
        Name of the model.
    local_model_store_dir_path : str, default $MXNET_HOME/models
        Location for keeping the model parameters.

    Returns
    -------
    file_path
        Path to the requested pretrained model file.
    """
    error, sha1_hash, repo_release_tag = get_model_name_suffix_data(model_name)
    file_name = '{name}-{error}-{short_sha1}.params'.format(
        name=model_name, error=error, short_sha1=sha1_hash[:8])
    local_model_store_dir_path = os.path.expanduser(local_model_store_dir_path)
    file_path = os.path.join(local_model_store_dir_path, file_name)

    # Reuse a cached copy when its checksum is intact.
    if os.path.exists(file_path):
        if check_sha1(file_path, sha1_hash):
            return file_path
        logging.warning('Mismatch in the content of model file detected. Downloading again.')
    else:
        logging.info('Model file not found. Downloading to {}.'.format(file_path))

    if not os.path.exists(local_model_store_dir_path):
        os.makedirs(local_model_store_dir_path)

    # Fetch the zipped parameters from the release page, unpack, then clean up.
    zip_file_path = file_path + '.zip'
    download(
        url='{repo_url}/releases/download/{repo_release_tag}/{file_name}.zip'.format(
            repo_url=imgclsmob_repo_url,
            repo_release_tag=repo_release_tag,
            file_name=file_name),
        path=zip_file_path,
        overwrite=True)
    with zipfile.ZipFile(zip_file_path) as zf:
        zf.extractall(local_model_store_dir_path)
    os.remove(zip_file_path)

    if not check_sha1(file_path, sha1_hash):
        raise ValueError('Downloaded file has different hash. Please try again.')
    return file_path
def _generate(self, segment):
    """Partition MRPC dataset into train, dev and test.

    Adapted from https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e

    The raw TSV for *segment* is downloaded first; train/dev rows are then
    separated using the officially published dev-id list, while the test
    segment is rewritten with an index column and no labels.
    """
    # download raw data
    data_name = segment + '.tsv'
    raw_name, raw_hash, data_hash = self._data_file[segment]
    raw_path = os.path.join(self._root, raw_name)
    download(self._repo_dir() + raw_name, path=raw_path, sha1_hash=raw_hash)
    data_path = os.path.join(self._root, data_name)

    if segment in ('train', 'dev'):
        # Skip regeneration when a checksum-valid output already exists.
        if os.path.isfile(data_path) and check_sha1(data_path, data_hash):
            return

        # retrieve dev ids for train and dev set
        DEV_ID_URL = 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc'
        DEV_ID_HASH = '506c7a1a5e0dd551ceec2f84070fa1a8c2bc4b41'
        dev_id_name = 'dev_ids.tsv'
        dev_id_path = os.path.join(self._root, dev_id_name)
        download(DEV_ID_URL, path=dev_id_path, sha1_hash=DEV_ID_HASH)

        # read dev data ids
        dev_ids = []
        with io.open(dev_id_path, encoding='utf8') as ids_fh:
            for row in ids_fh:
                dev_ids.append(row.strip().split('\t'))

        # generate train and dev set
        train_path = os.path.join(self._root, 'train.tsv')
        dev_path = os.path.join(self._root, 'dev.tsv')
        with io.open(raw_path, encoding='utf8') as data_fh:
            with io.open(train_path, 'w', encoding='utf8') as train_fh:
                with io.open(dev_path, 'w', encoding='utf8') as dev_fh:
                    header = data_fh.readline()
                    train_fh.write(header)
                    dev_fh.write(header)
                    for row in data_fh:
                        label, id1, id2, s1, s2 = row.strip().split('\t')
                        example = '%s\t%s\t%s\t%s\t%s\n'%(label, id1, id2, s1, s2)
                        # Rows whose sentence-id pair appears in the published
                        # dev-id list go to dev.tsv; everything else to train.tsv.
                        if [id1, id2] in dev_ids:
                            dev_fh.write(example)
                        else:
                            train_fh.write(example)
    else:
        # generate test set
        if os.path.isfile(data_path) and check_sha1(data_path, data_hash):
            return
        with io.open(raw_path, encoding='utf8') as data_fh:
            with io.open(data_path, 'w', encoding='utf8') as test_fh:
                header = data_fh.readline()
                # Test output drops the gold label and adds a running index.
                test_fh.write('index\t#1 ID\t#2 ID\t#1 String\t#2 String\n')
                for idx, row in enumerate(data_fh):
                    label, id1, id2, s1, s2 = row.strip().split('\t')
                    test_fh.write('%d\t%s\t%s\t%s\t%s\n'%(idx, id1, id2, s1, s2))
def download(url, path=None, overwrite=False, sha1_hash=None, verify=True):
    """Download an given URL

    Parameters
    ----------
    url : str
        URL to download
    path : str, optional
        Destination path to store downloaded file. By default stores to the
        current directory with same name as in url.
    overwrite : bool, optional
        Whether to overwrite destination file if already exists.
    sha1_hash : str, optional
        Expected sha1 hash in hexadecimal digits. Will ignore existing file
        when hash is specified but doesn't match.
    verify : bool
        Toggle verification of SSL certificates.

    Returns
    -------
    str
        The file path of the downloaded file.
    """
    # Resolve the destination filename from url/path.
    if path is None:
        fname = url.split('/')[-1]
    else:
        path = os.path.expanduser(path)
        fname = os.path.join(path, url.split('/')[-1]) if os.path.isdir(path) else path

    needs_fetch = overwrite or not os.path.exists(fname) \
        or (sha1_hash and not check_sha1(fname, sha1_hash))
    if needs_fetch:
        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
        if not os.path.exists(dirname):
            os.makedirs(dirname)
        print('Downloading %s from %s...' % (fname, url))
        r = requests.get(url, stream=True, verify=verify)
        if r.status_code != 200:
            raise RuntimeError('Failed downloading url %s' % url)
        with open(fname, 'wb') as f:
            for chunk in r.iter_content(chunk_size=1024):
                if chunk:  # filter out keep-alive new chunks
                    f.write(chunk)
        # Verify the fresh download when a hash was supplied.
        if sha1_hash and not check_sha1(fname, sha1_hash):
            raise UserWarning('File {} is downloaded but the content hash does not match. '
                              'The repo may be outdated or download may be incomplete. '
                              'If the "repo_url" is overridden, consider switching to '
                              'the default repo.'.format(fname))
    return fname
def get_model_file(name, root=os.path.join('~', '.mxnet', 'models')):
    r"""Return location for the pretrained on local file system.

    This function will download from online model zoo when model cannot be found
    or has mismatch. The root directory will be created if it doesn't exist.

    Parameters
    ----------
    name : str
        Name of the model.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.

    Returns
    -------
    file_path
        Path to the requested pretrained model file.

    Raises
    ------
    ValueError
        If the freshly downloaded file still fails the SHA-1 check.
    """
    file_name = '{name}-{short_hash}'.format(name=name, short_hash=short_hash(name))
    root = os.path.expanduser(root)
    file_path = os.path.join(root, file_name + '.params')
    sha1_hash = _model_sha1[name]
    # FIX: removed a leftover debug print of the resolved file path; the
    # messages below already report every meaningful event.
    if os.path.exists(file_path):
        if check_sha1(file_path, sha1_hash):
            return file_path
        print('Mismatch in the content of model file detected. Downloading again.')
    else:
        print('Model file is not found. Downloading.')

    if not os.path.exists(root):
        os.makedirs(root)

    # Download the zipped parameter file, extract, and remove the archive.
    zip_file_path = os.path.join(root, file_name + '.zip')
    repo_url = os.environ.get('MXNET_GLUON_REPO', apache_repo_url)
    if repo_url[-1] != '/':
        repo_url = repo_url + '/'
    download(_url_format.format(repo_url=repo_url, file_name=file_name),
             path=zip_file_path, overwrite=True)
    with zipfile.ZipFile(zip_file_path) as zf:
        zf.extractall(root)
    os.remove(zip_file_path)

    if check_sha1(file_path, sha1_hash):
        return file_path
    raise ValueError('Downloaded file has different hash. Please try again.')
def __init__(self, root=os.path.join(get_home_dir(), 'models')):
    # GPT-2 BPE needs the third-party `regex` package for Unicode-category
    # classes (\p{L}, \p{N}) which the stdlib `re` module does not support.
    try:
        import regex  # pylint: disable=import-outside-toplevel
        self._regex = regex
    except ImportError:
        raise ImportError('GPT2BPETokenizer requires regex. '
                          'To install regex, use pip install -U regex')
    super(GPT2BPETokenizer, self).__init__()
    root = os.path.expanduser(root)
    file_name, sha1_hash = self.bpe_ranks_file_hash
    file_path = os.path.join(root, file_name)
    # (Re-)download the BPE rank file when it is absent or its checksum fails.
    if not os.path.exists(file_path) or not check_sha1(file_path, sha1_hash):
        if os.path.exists(file_path):
            print('Detected mismatch in the content of BPE rank file. Downloading again.')
        else:
            print('BPE rank file is not found. Downloading.')
        os.makedirs(root, exist_ok=True)

        # Timestamp prefix gives each process its own temporary archive so
        # concurrent downloads do not clobber one another.
        prefix = str(time.time())
        zip_file_path = os.path.join(root, prefix + file_name)
        repo_url = _get_repo_url()
        if repo_url[-1] != '/':
            repo_url = repo_url + '/'
        archive_name, archive_hash = self.bpe_ranks_archive_hash
        _url_format = '{repo_url}gluon/dataset/vocab/{file_name}'
        download(_url_format.format(repo_url=repo_url, file_name=archive_name),
                 path=zip_file_path, sha1_hash=archive_hash, overwrite=True)
        with zipfile.ZipFile(zip_file_path) as zf:
            # Another process may have extracted the file in the meantime;
            # only extract when the target is still missing.
            if not os.path.exists(file_path):
                zf.extractall(root)
        try:
            os.remove(zip_file_path)
        except OSError as e:
            # file has already been removed.
            if e.errno == 2:
                pass
            else:
                raise e
        if not check_sha1(file_path, sha1_hash):
            raise ValueError('Downloaded file has different hash. Please try again.')
    self._read_bpe_ranks(file_path)
    self._cache = {}
    # Token pattern from the original GPT-2 encoder: contractions, letter
    # runs, digit runs, other symbol runs, and whitespace handling.
    self._token_pattern = self._regex.compile(
        r'\'s|\'t|\'re|\'ve|\'m|\'ll|\'d| ?\p{L}+'
        r'| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+')
def _load_pretrained_vocab(name, root=os.path.join(get_home_dir(), 'models'), cls=None):
    """Load the accompanying vocabulary object for pre-trained model.

    Parameters
    ----------
    name : str
        Name of the vocabulary, usually the name of the dataset.
    root : str, default '$MXNET_HOME/models'
        Location for keeping the model parameters. MXNET_HOME defaults to '~/.mxnet'.
    cls : nlp.Vocab or nlp.vocab.BERTVocab, default nlp.Vocab

    Returns
    -------
    Vocab or nlp.vocab.BERTVocab
        Loaded vocabulary object for the pre-trained model.
    """
    file_name = '{name}-{short_hash}'.format(name=name, short_hash=short_hash(name))
    root = os.path.expanduser(root)
    file_path = os.path.join(root, file_name + '.vocab')
    sha1_hash = _vocab_sha1[name]

    # Serve the cached vocab file when its checksum is intact.
    if os.path.exists(file_path):
        if check_sha1(file_path, sha1_hash):
            return _load_vocab_file(file_path, cls)
        print('Detected mismatch in the content of model vocab file. Downloading again.')
    else:
        print('Vocab file is not found. Downloading.')

    if not os.path.exists(root):
        os.makedirs(root)

    # Fetch the zipped vocab from the repo and unpack it next to file_path.
    zip_file_path = os.path.join(root, file_name + '.zip')
    repo_url = _get_repo_url()
    if repo_url[-1] != '/':
        repo_url = repo_url + '/'
    download(_url_format.format(repo_url=repo_url, file_name=file_name),
             path=zip_file_path, overwrite=True)
    with zipfile.ZipFile(zip_file_path) as zf:
        zf.extractall(root)
    os.remove(zip_file_path)

    if not check_sha1(file_path, sha1_hash):
        raise ValueError('Downloaded file has different hash. Please try again.')
    return _load_vocab_file(file_path, cls)
def _load_pretrained_vocab(name, root, cls=None):
    """Load the accompanying vocabulary object for pre-trained model.

    Parameters
    ----------
    name : str
        Name of the vocabulary, usually the name of the dataset.
    root : str
        Location for keeping the model vocabulary.
    cls : nlp.Vocab or nlp.vocab.BERTVocab, default nlp.Vocab

    Returns
    -------
    Vocab or nlp.vocab.BERTVocab
        Loaded vocabulary object for the pre-trained model.
    """
    file_name = '{name}-{short_hash}'.format(name=name, short_hash=short_hash(name))
    root = os.path.expanduser(root)
    file_path = os.path.join(root, file_name + '.vocab')
    sha1_hash = _vocab_sha1[name]

    # Uniquely named temporaries keep concurrent processes from trampling
    # each other; the final vocab file is moved into place atomically.
    temp_num = str(random.Random().randint(1, sys.maxsize))
    temp_root = os.path.join(root, temp_num)
    temp_file_path = os.path.join(temp_root, file_name + '.vocab')
    temp_zip_file_path = os.path.join(root, temp_num + file_name + '.zip')

    if os.path.exists(file_path):
        if check_sha1(file_path, sha1_hash):
            return _load_vocab_file(file_path, cls)
        print('Detected mismatch in the content of model vocab file. Downloading again.')
    else:
        print('Vocab file is not found. Downloading.')

    utils.mkdir(root)
    repo_url = _get_repo_url()
    if repo_url[-1] != '/':
        repo_url = repo_url + '/'
    download(_url_format.format(repo_url=repo_url, file_name=file_name),
             path=temp_zip_file_path, overwrite=True)
    with zipfile.ZipFile(temp_zip_file_path) as zf:
        if not os.path.exists(file_path):
            utils.mkdir(temp_root)
            zf.extractall(temp_root)
            # os.replace is atomic on the same filesystem.
            os.replace(temp_file_path, file_path)
            shutil.rmtree(temp_root)
    # FIX: the uniquely named temporary archive was never deleted, leaking
    # one .zip per download into `root`; remove it (best-effort) here.
    try:
        os.remove(temp_zip_file_path)
    except OSError:
        pass

    if check_sha1(file_path, sha1_hash):
        return _load_vocab_file(file_path, cls)
    raise ValueError('Downloaded file has different hash. Please try again.')
def _get_data(shape):
    """Fetch the test-image array and the inception-v3 activation dump for
    *shape* (a (h, w) tuple), verifying each download's SHA-1."""
    hash_test_img = "355e15800642286e7fe607d87c38aeeab085b0cc"
    hash_inception_v3 = "91807dfdbd336eb3b265dd62c2408882462752b9"
    img_name = "test_images_%d_%d.npy" % (shape)
    downloads = (
        ("http://data.mxnet.io/data/" + img_name,
         "data/" + img_name, hash_test_img, img_name),
        ("http://data.mxnet.io/data/inception-v3-dump.npz",
         'data/inception-v3-dump.npz', hash_inception_v3, "inception-v3-dump.npz"),
    )
    for url, dest, sha1, display_name in downloads:
        fetched = utils.download(url, path=dest, sha1_hash=sha1)
        # Double-check the hash of whatever ended up on disk.
        if not utils.check_sha1(fetched, sha1):
            raise RuntimeError("File %s not downloaded completely" % (display_name))
def __init__(self, segmenter_root=os.path.join(_get_home_dir(), 'stanford-segmenter'),
             slf4j_root=os.path.join(_get_home_dir(), 'slf4j'),
             java_class='edu.stanford.nlp.ie.crf.CRFClassifier'):
    """Set up the NLTK Stanford segmenter, downloading its jars on demand.

    Parameters
    ----------
    segmenter_root : str
        Directory for the Stanford segmenter distribution.
    slf4j_root : str
        Directory for the slf4j jar.
    java_class : str
        Fully qualified Java class driving the segmenter.
    """
    is_java_exist = os.system('java -version')
    # FIX: the concatenated message parts were missing a separating space
    # ("Java 8.0in order ...").
    assert is_java_exist == 0, 'Java is not installed. You must install Java 8.0 ' \
                               'in order to use the NLTKStanfordSegmenter'
    try:
        from nltk.tokenize import StanfordSegmenter
    except ImportError:
        raise ImportError('NLTK or relevant packages are not installed. You must install NLTK '
                          'in order to use the NLTKStanfordSegmenter. You can refer to the '
                          'official installation guide in https://www.nltk.org/install.html.')
    path_to_jar = os.path.join(segmenter_root, 'stanford-segmenter-2018-02-27',
                               'stanford-segmenter-3.9.1.jar')
    path_to_model = os.path.join(segmenter_root, 'stanford-segmenter-2018-02-27',
                                 'data', 'pku.gz')
    path_to_dict = os.path.join(segmenter_root, 'stanford-segmenter-2018-02-27',
                                'data', 'dict-chris6.ser.gz')
    path_to_sihan_corpora_dict = os.path.join(segmenter_root,
                                              'stanford-segmenter-2018-02-27', 'data')
    segmenter_url = 'https://nlp.stanford.edu/software/stanford-segmenter-2018-02-27.zip'
    segmenter_sha1 = 'aa27a6433704b7b4c6a14be1c126cb4b14b8f57b'
    stanford_segmenter = os.path.join(segmenter_root, 'stanford-segmenter-2018-02-27.zip')
    # FIX: check_sha1 used to run even when the zip no longer existed (e.g.
    # deleted after extraction), raising instead of reusing the extracted files.
    segmenter_files_missing = not (os.path.exists(path_to_jar)
                                   and os.path.exists(path_to_model)
                                   and os.path.exists(path_to_dict)
                                   and os.path.exists(path_to_sihan_corpora_dict))
    segmenter_archive_corrupt = os.path.exists(stanford_segmenter) and \
        not check_sha1(filename=stanford_segmenter, sha1_hash=segmenter_sha1)
    if segmenter_files_missing or segmenter_archive_corrupt:
        # automatically download the files from the website and place them to stanford_root
        if not os.path.exists(segmenter_root):
            os.mkdir(segmenter_root)
        download(url=segmenter_url, path=segmenter_root, sha1_hash=segmenter_sha1)
        _extract_archive(file=stanford_segmenter, target_dir=segmenter_root)

    path_to_slf4j = os.path.join(slf4j_root, 'slf4j-1.7.25', 'slf4j-api-1.7.25.jar')
    slf4j_url = 'https://www.slf4j.org/dist/slf4j-1.7.25.zip'
    slf4j_sha1 = '89ea41ad6ebe1b190139421bb7c8d981e9df1625'
    slf4j = os.path.join(slf4j_root, 'slf4j-1.7.25.zip')
    # Same guard for the slf4j archive.
    slf4j_archive_corrupt = os.path.exists(slf4j) and \
        not check_sha1(filename=slf4j, sha1_hash=slf4j_sha1)
    if not os.path.exists(path_to_slf4j) or slf4j_archive_corrupt:
        # automatically download the files from the website and place them to slf4j_root
        if not os.path.exists(slf4j_root):
            os.mkdir(slf4j_root)
        download(url=slf4j_url, path=slf4j_root, sha1_hash=slf4j_sha1)
        _extract_archive(file=slf4j, target_dir=slf4j_root)
    self._tokenizer = StanfordSegmenter(java_class=java_class,
                                        path_to_jar=path_to_jar,
                                        path_to_slf4j=path_to_slf4j,
                                        path_to_dict=path_to_dict,
                                        path_to_sihan_corpora_dict=path_to_sihan_corpora_dict,
                                        path_to_model=path_to_model)
def __init__(self, segmenter_root=os.path.join(_get_home_dir(), 'stanford-segmenter'),
             slf4j_root=os.path.join(_get_home_dir(), 'slf4j'),
             java_class='edu.stanford.nlp.ie.crf.CRFClassifier'):
    """Initialize the NLTK Stanford segmenter wrapper, fetching the
    segmenter and slf4j archives when they are not already available.

    Parameters
    ----------
    segmenter_root : str
        Directory for the Stanford segmenter distribution.
    slf4j_root : str
        Directory for the slf4j jar.
    java_class : str
        Fully qualified Java class driving the segmenter.
    """
    is_java_exist = os.system('java -version')
    # FIX: add the missing space between the two message fragments.
    assert is_java_exist == 0, 'Java is not installed. You must install Java 8.0 ' \
                               'in order to use the NLTKStanfordSegmenter'
    try:
        from nltk.tokenize import StanfordSegmenter
    except ImportError:
        raise ImportError('NLTK or relevant packages are not installed. You must install NLTK '
                          'in order to use the NLTKStanfordSegmenter. You can refer to the '
                          'official installation guide in https://www.nltk.org/install.html.')
    dist_dir = os.path.join(segmenter_root, 'stanford-segmenter-2018-02-27')
    path_to_jar = os.path.join(dist_dir, 'stanford-segmenter-3.9.1.jar')
    path_to_model = os.path.join(dist_dir, 'data', 'pku.gz')
    path_to_dict = os.path.join(dist_dir, 'data', 'dict-chris6.ser.gz')
    path_to_sihan_corpora_dict = os.path.join(dist_dir, 'data')
    segmenter_url = 'https://nlp.stanford.edu/software/stanford-segmenter-2018-02-27.zip'
    segmenter_sha1 = 'aa27a6433704b7b4c6a14be1c126cb4b14b8f57b'
    stanford_segmenter = os.path.join(segmenter_root, 'stanford-segmenter-2018-02-27.zip')
    # FIX: only hash the archive when it actually exists — previously
    # check_sha1 crashed on a zip that had been removed after extraction.
    required = (path_to_jar, path_to_model, path_to_dict, path_to_sihan_corpora_dict)
    need_segmenter = (not all(os.path.exists(p) for p in required)
                      or (os.path.exists(stanford_segmenter)
                          and not check_sha1(filename=stanford_segmenter,
                                             sha1_hash=segmenter_sha1)))
    if need_segmenter:
        # automatically download the files from the website and place them to stanford_root
        if not os.path.exists(segmenter_root):
            os.mkdir(segmenter_root)
        download(url=segmenter_url, path=segmenter_root, sha1_hash=segmenter_sha1)
        _extract_archive(file=stanford_segmenter, target_dir=segmenter_root)

    path_to_slf4j = os.path.join(slf4j_root, 'slf4j-1.7.25', 'slf4j-api-1.7.25.jar')
    slf4j_url = 'https://www.slf4j.org/dist/slf4j-1.7.25.zip'
    slf4j_sha1 = '89ea41ad6ebe1b190139421bb7c8d981e9df1625'
    slf4j = os.path.join(slf4j_root, 'slf4j-1.7.25.zip')
    need_slf4j = (not os.path.exists(path_to_slf4j)
                  or (os.path.exists(slf4j)
                      and not check_sha1(filename=slf4j, sha1_hash=slf4j_sha1)))
    if need_slf4j:
        # automatically download the files from the website and place them to slf4j_root
        if not os.path.exists(slf4j_root):
            os.mkdir(slf4j_root)
        download(url=slf4j_url, path=slf4j_root, sha1_hash=slf4j_sha1)
        _extract_archive(file=slf4j, target_dir=slf4j_root)
    self._tokenizer = StanfordSegmenter(java_class=java_class,
                                        path_to_jar=path_to_jar,
                                        path_to_slf4j=path_to_slf4j,
                                        path_to_dict=path_to_dict,
                                        path_to_sihan_corpora_dict=path_to_sihan_corpora_dict,
                                        path_to_model=path_to_model)
def _get_xlnet_tokenizer(dataset_name, root):
    """Return an XLNetTokenizer backed by the cached sentencepiece model,
    downloading the model only when it is missing or corrupt.

    Parameters
    ----------
    dataset_name : str
        Must be '126gb' (case-insensitive); the only published vocabulary.
    root : str
        Directory used to cache the tokenizer file.

    Returns
    -------
    XLNetTokenizer

    Raises
    ------
    ValueError
        If the freshly downloaded file still fails the SHA-1 check.
    """
    assert dataset_name.lower() == '126gb'
    root = os.path.expanduser(root)
    file_path = os.path.join(root, 'xlnet_126gb-871f0b3c.spiece')
    sha1_hash = '871f0b3c13b92fc5aea8fba054a214c420e302fd'
    # FIX: the archive was previously downloaded on every call even when a
    # checksum-valid copy existed; now the whole download path is guarded,
    # matching the sibling tokenizer loaders.
    if not os.path.exists(file_path) or not check_sha1(file_path, sha1_hash):
        if os.path.exists(file_path):
            print('Detected mismatch in the content of model tokenizer. Downloading again.')
        else:
            print('Tokenizer file is not found. Downloading.')
        if not os.path.exists(root):
            try:
                os.makedirs(root)
            except OSError as e:
                # A concurrent process may have created the directory already.
                if e.errno == errno.EEXIST and os.path.isdir(root):
                    pass
                else:
                    raise e
        repo_url = _get_repo_url()
        # Timestamp prefix keeps concurrent downloads from clobbering each other.
        prefix = str(time.time())
        zip_file_path = os.path.join(root, prefix + 'xlnet_126gb-871f0b3c.zip')
        if repo_url[-1] != '/':
            repo_url = repo_url + '/'
        download(_url_format.format(repo_url=repo_url, file_name='xlnet_126gb-871f0b3c'),
                 path=zip_file_path, overwrite=True)
        with zipfile.ZipFile(zip_file_path) as zf:
            if not os.path.exists(file_path):
                zf.extractall(root)
        try:
            os.remove(zip_file_path)
        except OSError as e:
            # file has already been removed.
            if e.errno == 2:
                pass
            else:
                raise e
        if not check_sha1(file_path, sha1_hash):
            raise ValueError('Downloaded file has different hash. Please try again.')
    tokenizer = XLNetTokenizer(file_path)
    return tokenizer
def _download_data(self):
    """Download and unpack the dataset archive unless every expected file
    is already present with a matching checksum.

    Iterates ``self._checksums`` (relative path -> sha1); on the first
    missing/corrupt file the archive is fetched (from the repo namespace
    when ``self._namespace`` is set, otherwise from ``self._url``) and
    extracted into ``self.root``.
    """
    _, archive_hash = self._archive_file
    for name, checksum in self._checksums.items():
        name = name.split('/')
        path = os.path.join(self.root, *name)
        if not os.path.exists(path) or not check_sha1(path, checksum):
            if self._namespace is not None:
                url = _get_repo_file_url(self._namespace, self._archive_file[0])
            else:
                url = self._url
            downloaded_file_path = download(url, path=self.root,
                                            sha1_hash=archive_hash,
                                            verify_ssl=self._verify_ssl)
            if downloaded_file_path.lower().endswith('zip'):
                with zipfile.ZipFile(downloaded_file_path, 'r') as zf:
                    zf.extractall(path=self.root)
            elif downloaded_file_path.lower().endswith('tar.gz'):
                # mode 'r' auto-detects the gzip compression
                with tarfile.open(downloaded_file_path, 'r') as tf:
                    tf.extractall(path=self.root)
            elif len(self._checksums) > 1:
                # Multiple files were expected but the download is not an
                # archive format we know how to unpack.
                err = 'Failed retrieving {clsname}.'.format(
                    clsname=self.__class__.__name__)
                err += (' Expecting multiple files, '
                        'but could not detect archive format.')
                raise RuntimeError(err)
def _get_data(self):
    """Download this segment's data file into ``self._root`` unless a
    checksum-valid copy is already present."""
    file_name, expected_sha1 = self._data_file()[self._segment]
    target = os.path.join(self._root, file_name)
    needs_fetch = not os.path.exists(target) or not check_sha1(target, expected_sha1)
    if needs_fetch:
        download(_get_repo_file_url(self._repo_dir(), file_name),
                 path=self._root, sha1_hash=expected_sha1)
def _get_data(self):
    """Ensure the raw batches exist locally (downloading/extracting the
    archive when any file is missing or corrupt), then load one split.

    The train data is shuffled with a fixed seed so the 10k-example split
    selected by ``self._split_id`` is deterministic across runs.
    """
    # Re-fetch the archive if any expected file is absent or fails its hash.
    if any(not os.path.exists(path) or not check_sha1(path, sha1)
           for path, sha1 in ((os.path.join(self._root, name), sha1)
                              for name, sha1 in self._train_data + self._test_data)):
        namespace = 'gluon/dataset/' + self._namespace
        filename = download(_get_repo_file_url(namespace, self._archive_file[0]),
                            path=self._root, sha1_hash=self._archive_file[1])
        with tarfile.open(filename) as tar:
            tar.extractall(self._root)
    if self._train:
        data_files = self._train_data
    else:
        data_files = self._test_data
    data, label = zip(*(self._read_batch(os.path.join(self._root, name))
                        for name, _ in data_files))
    data = np.concatenate(data)
    label = np.concatenate(label)
    if self._train:
        # Fixed seed keeps the 50k-example permutation — and therefore the
        # split boundaries — reproducible across instances.
        npr.seed(0)
        rand_inds = npr.permutation(50000)
        data = data[rand_inds]
        label = label[rand_inds]
        # NOTE(review): the split-id slicing appears to apply only to the
        # shuffled train data — confirm against the original layout.
        data = data[self._split_id * 10000:(self._split_id + 1) * 10000]
        label = label[self._split_id * 10000:(self._split_id + 1) * 10000]
    self._data = nd.array(data, dtype=data.dtype)
    self._label = label
def check_file(filename, checksum, sha1):
    """Validate that *filename* exists and, when *checksum* is truthy, that
    its SHA-1 digest equals *sha1*.

    Raises
    ------
    ValueError
        When the file is missing or its hash does not match.
    """
    from mxnet.gluon.utils import check_sha1
    file_present = os.path.exists(filename)
    if not file_present:
        raise ValueError('File not found: ' + filename)
    # Hash verification is optional and only performed when requested.
    if checksum:
        if not check_sha1(filename, sha1):
            raise ValueError('Corrupted file: ' + filename)
def _get_data(self): archive_file_name, archive_hash = self._archive_file archive_file_path = os.path.join(self._root, archive_file_name) exists = False if os.path.exists(self._dir) and os.path.exists(self._subdir): # verify sha1 for all files in the subdir sha1 = hashlib.sha1() filenames = sorted(glob.glob(self._file_pattern)) for filename in filenames: with open(filename, 'rb') as f: while True: data = f.read(1048576) if not data: break sha1.update(data) if sha1.hexdigest() == self._data_hash: exists = True if not exists: # download archive if not os.path.exists(archive_file_path) or \ not check_sha1(archive_file_path, archive_hash): download(_get_repo_file_url(self._namespace, archive_file_name), path=self._root, sha1_hash=archive_hash) # extract archive with tarfile.open(archive_file_path, 'r:gz') as tf: tf.extractall(path=self._root)
def get_model_file(name, root=os.path.join('~', '.mxnet', 'models')):
    r"""Return location for the pretrained on local file system.

    This variant never downloads anything: it only validates an existing
    local copy and raises when it is absent or corrupt.

    Parameters
    ----------
    name : str
        Name of the model.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.

    Returns
    -------
    file_path
        Path to the requested pretrained model file.
    """
    file_name = '{name}-{short_hash}'.format(name=name, short_hash=short_hash(name))
    root = os.path.expanduser(root)
    file_path = os.path.join(root, file_name + '.params')
    sha1_hash = _model_sha1[name]
    # Guard clauses: missing file, then corrupted file, then success.
    if not os.path.exists(file_path):
        raise AssertionError(
            'Model file: %s is not found. Please download before use it.' % file_path)
    if not check_sha1(file_path, sha1_hash):
        raise AssertionError(
            'Mismatch in the content of model file detected. Please download it again.')
    return file_path
def _load_pretrained_vocab(name, root=os.path.join('~', '.mxnet', 'models')):
    """Load the accompanying vocabulary object for pre-trained model.

    Parameters
    ----------
    name : str
        Name of the vocabulary, usually the name of the dataset.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.

    Returns
    -------
    Vocab
        Loaded vocabulary object for the pre-trained model.
    """
    vocab_base = '{name}-{short_hash}'.format(name=name, short_hash=short_hash(name))
    root = os.path.expanduser(root)
    vocab_path = os.path.join(root, vocab_base + '.vocab')
    expected_hash = _vocab_sha1[name]

    cached = os.path.exists(vocab_path)
    if cached and check_sha1(vocab_path, expected_hash):
        return _load_vocab_file(vocab_path)
    if cached:
        print('Detected mismatch in the content of model vocab file. Downloading again.')
    else:
        print('Vocab file is not found. Downloading.')

    if not os.path.exists(root):
        os.makedirs(root)

    # Pull the zipped vocab from the configured repo and unpack it.
    archive_path = os.path.join(root, vocab_base + '.zip')
    repo_url = _get_repo_url()
    if repo_url[-1] != '/':
        repo_url = repo_url + '/'
    download(_url_format.format(repo_url=repo_url, file_name=vocab_base),
             path=archive_path, overwrite=True)
    with zipfile.ZipFile(archive_path) as zf:
        zf.extractall(root)
    os.remove(archive_path)

    if not check_sha1(vocab_path, expected_hash):
        raise ValueError('Downloaded file has different hash. Please try again.')
    return _load_vocab_file(vocab_path)
def _get_data(self):
    """Download the dataset zip (when absent or corrupt) and unpack it
    into ``self._root``."""
    filename_format, sha1_hash = self._download_info
    archive_name = filename_format.format(sha1_hash[:8])
    archive_path = os.path.join(self._root, archive_name)
    url = _get_repo_file_url('gluon/dataset', archive_name)
    cached_ok = os.path.exists(archive_path) and check_sha1(archive_path, sha1_hash)
    if not cached_ok:
        download(url, path=archive_path, sha1_hash=sha1_hash)
    with zipfile.ZipFile(archive_path, 'r') as zf:
        zf.extractall(self._root)
def _get_file_path(cls_name, file_name, file_hash):
    """Return the local path of an embedding file, downloading it from the
    Gluon repository when it is missing or fails its checksum."""
    root_path = os.path.expanduser(os.path.join(get_home_dir(), 'embedding'))
    embedding_dir = os.path.join(root_path, cls_name)
    url = _get_file_url(cls_name, file_name)
    file_path = os.path.join(embedding_dir, file_name)
    cached_ok = os.path.exists(file_path) and check_sha1(file_path, file_hash)
    if not cached_ok:
        logging.info(
            'Embedding file {} is not found. Downloading from Gluon Repository. '
            'This may take some time.'.format(file_name))
        download(url, file_path, sha1_hash=file_hash)
    return file_path
def _get_data(self):
    """Fetch the dataset archive when absent or corrupt, then extract it
    into ``self._root``."""
    filename, url, sha1_hash = self._download_info
    archive = os.path.join(self._root, filename)
    have_valid_copy = os.path.exists(archive) and check_sha1(archive, sha1_hash)
    if not have_valid_copy:
        download(url, path=archive, sha1_hash=sha1_hash, verify_ssl=True)
    with zipfile.ZipFile(archive, 'r') as zf:
        zf.extractall(self._root)
def _load_pretrained_sentencepiece_tokenizer(name, root, **kwargs):
    """Return a SentencepieceTokenizer for *name*, downloading the .spiece
    model when the cached copy is missing or fails its checksum."""
    from ..data import SentencepieceTokenizer  # pylint: disable=import-outside-toplevel
    file_name, file_ext, sha1_hash, _ = _get_vocab_tokenizer_info(name, root)
    file_path = os.path.join(root, file_name + file_ext)

    def _build(path):
        # Sanity-check the extension before handing the file to sentencepiece.
        assert path.endswith('.spiece')
        return SentencepieceTokenizer(path, **kwargs)

    if os.path.exists(file_path):
        if check_sha1(file_path, sha1_hash):
            return _build(file_path)
        print('Detected mismatch in the content of model tokenizer file. Downloading again.')
    else:
        print('tokenizer file is not found. Downloading.')
    _download_vocab_tokenizer(root, file_name, file_ext, file_path)
    if not check_sha1(file_path, sha1_hash):
        raise ValueError('Downloaded file has different hash. Please try again.')
    return _build(file_path)
def _get_data(self, force_download=False):
    """Download the raw archive for ``self._name`` (unless cached and valid,
    or *force_download* is set) and unzip it under ``self._root``."""
    if not os.path.exists(self._root):
        os.makedirs(self._root)
    data_name, url, data_hash = self._raw_data_urls[self._name]
    path = os.path.join(self._root, data_name)
    needs_fetch = (not os.path.exists(path)
                   or force_download
                   or (data_hash and not check_sha1(path, data_hash)))
    if needs_fetch:
        print("\n\n=====================> Download dataset ...")
        self.download(url, path=path, sha1_hash=data_hash)
        print("\n\n=====================> Unzip the file ...")
        with ZipFile(path, 'r') as zf:
            zf.extractall(path=self._root)
def _get_data(self):
    """Return the path to this segment's data file, first downloading and
    extracting the archive when the file is missing or corrupt."""
    archive_name, archive_sha1 = self.archive_file
    file_name, file_sha1 = self.data_file[self._segment]
    target = os.path.join(self._root, file_name)
    cached_ok = os.path.exists(target) and check_sha1(target, file_sha1)
    if not cached_ok:
        fetched = download(self.url + archive_name, path=self._root,
                           sha1_hash=archive_sha1)
        with zipfile.ZipFile(fetched, 'r') as zf:
            zf.extractall(self._root)
    return target
def _get_data(self):
    """Verify each expected data file, downloading and extracting the
    archive when any is missing or corrupt; return the file paths."""
    archive_file_name, archive_hash = self._get_data_archive_hash()
    paths = []
    for data_file_name, data_hash in self._get_data_file_hash():
        path = os.path.join(self._root, data_file_name)
        stale = not os.path.exists(path) or not check_sha1(path, data_hash)
        if stale:
            download(self.base_url + archive_file_name,
                     path=self._root, sha1_hash=archive_hash)
            self._extract_archive()
        paths.append(path)
    return paths
def _get_data(archive_file, data_file, segment, root, namespace):
    """Return the path of *segment*'s data file under *root*, downloading
    and unzipping the repo archive when the file is missing or corrupt."""
    archive_name, archive_sha1 = archive_file
    file_name, file_sha1 = data_file[segment]
    target = os.path.join(root, file_name)
    cached_ok = os.path.exists(target) and check_sha1(target, file_sha1)
    if not cached_ok:
        archive_path = download(_get_repo_file_url(namespace, archive_name),
                                path=root, sha1_hash=archive_sha1)
        with zipfile.ZipFile(archive_path, 'r') as zf:
            zf.extractall(root)
    return target
def _get_vocab(self):
    """Return the path to the vocabulary file, fetching and unpacking the
    vocab archive from the gluon repo when required."""
    archive_name, archive_sha1 = self._archive_vocab
    vocab_name, vocab_sha1 = self._vocab_file
    vocab_path = os.path.join(self._root, vocab_name)
    missing_or_bad = not os.path.exists(vocab_path) \
        or not check_sha1(vocab_path, vocab_sha1)
    if missing_or_bad:
        fetched = download(_get_repo_file_url('gluon/dataset/vocab', archive_name),
                           path=self._root, sha1_hash=archive_sha1)
        with zipfile.ZipFile(fetched, 'r') as zf:
            zf.extractall(path=self._root)
    return vocab_path
def _get_data(self):
    """Check every expected data file and re-fetch the shared archive if
    any file is absent or has a bad checksum; return all file paths."""
    archive_file_name, archive_hash = self._get_data_archive_hash()
    collected = []
    for data_file_name, data_hash in self._get_data_file_hash():
        candidate = os.path.join(self._root, data_file_name)
        if not os.path.exists(candidate) or not check_sha1(candidate, data_hash):
            # One archive provides all files; re-download and re-extract it.
            download(self.base_url + archive_file_name, path=self._root,
                     sha1_hash=archive_hash)
            self._extract_archive()
        collected.append(candidate)
    return collected
def _get_data(self, segment, zip_hash, data_hash, filename):
    """Download this segment's zip (when *filename* is missing or corrupt)
    and extract its members flat into ``self._root``."""
    zip_name = '%s-%s.zip' % (segment, data_hash[:8])
    if not os.path.exists(filename) or not check_sha1(filename, data_hash):
        download(_get_repo_file_url(self._repo_dir(), zip_name),
                 path=self._root, sha1_hash=zip_hash)
        # unzip
        zip_path = os.path.join(self._root, zip_name)
        with zipfile.ZipFile(zip_path, 'r') as zf:
            # skip dir structures in the zip
            for entry in zf.infolist():
                if entry.filename.endswith('/'):
                    continue
                entry.filename = os.path.basename(entry.filename)
                zf.extract(entry, self._root)
def _get_data(self):
    """Validate each expected file; on any miss/corruption download the
    archive (repo namespace when defined, else base_url) and extract it."""
    archive_file_name, archive_hash = self._get_data_archive_hash()
    # The archive URL does not depend on the loop variable, so resolve it once.
    if hasattr(self, 'namespace'):
        archive_url = _get_repo_file_url(self.namespace, archive_file_name)
    else:
        archive_url = self.base_url + archive_file_name
    paths = []
    for data_file_name, data_hash in self._get_data_file_hash():
        path = os.path.join(self._root, data_file_name)
        if not os.path.exists(path) or not check_sha1(path, data_hash):
            download(archive_url, path=self._root, sha1_hash=archive_hash)
            self._extract_archive()
        paths.append(path)
    return paths
def download_city(path, overwrite=False):
    """Verify and extract the manually obtained Cityscapes archives.

    The two zips are expected inside ``<path>/downloads``; each is
    checksum-verified and then extracted into *path*.

    Parameters
    ----------
    path : str
        Dataset root; archives are expected in ``<path>/downloads``.
    overwrite : bool, optional
        Unused; kept for interface compatibility.

    Raises
    ------
    UserWarning
        If an archive's SHA-1 does not match the expected value.
    """
    _CITY_DOWNLOAD_URLS = [
        ('gtFine_trainvaltest.zip', '99f532cb1af174f5fcc4c5bc8feea8c66246ddbc'),
        ('leftImg8bit_trainvaltest.zip', '2c0b77ce9933cc635adda307fbba5566f5d9d404')]
    download_dir = os.path.join(path, 'downloads')
    makedirs(download_dir)
    for filename, checksum in _CITY_DOWNLOAD_URLS:
        # FIX: the archives live in download_dir, but the bare filename
        # (relative to the CWD) was being hashed and extracted before,
        # leaving download_dir created yet unused.
        filename = os.path.join(download_dir, filename)
        if not check_sha1(filename, checksum):
            raise UserWarning('File {} is downloaded but the content hash does not match. '
                              'The repo may be outdated or download may be incomplete. '
                              'If the "repo_url" is overridden, consider switching to '
                              'the default repo.'.format(filename))
        # extract
        with zipfile.ZipFile(filename, "r") as zip_ref:
            zip_ref.extractall(path=path)
        print("Extracted", filename)
def _get_file_path(cls, source_file_hash, embedding_root, source):
    """Return the local path of the pretrained embedding file for *source*,
    downloading it from the Gluon repository when absent or corrupt."""
    cls_name = cls.__name__.lower()
    embedding_root = os.path.expanduser(embedding_root)
    url = cls._get_file_url(cls_name, source_file_hash, source)
    file_name, file_hash = source_file_hash[source]
    file_path = os.path.join(embedding_root, cls_name, file_name)
    cached_ok = os.path.exists(file_path) and check_sha1(file_path, file_hash)
    if not cached_ok:
        print('Embedding file {} is not found. Downloading from Gluon Repository. '
              'This may take some time.'.format(file_name))
        download(url, file_path, sha1_hash=file_hash)
    return file_path
def _get_file_path(cls, source_file_hash, embedding_root, source):
    """Locate (and fetch on demand) the pretrained embedding file for
    *source*; return its local path."""
    cls_name = cls.__name__.lower()
    embedding_root = os.path.expanduser(embedding_root)
    url = cls._get_file_url(cls_name, source_file_hash, source)
    embedding_dir = os.path.join(embedding_root, cls_name)
    pretrained_file_name, expected_file_hash = source_file_hash[source]
    pretrained_file_path = os.path.join(embedding_dir, pretrained_file_name)
    # Early return when a checksum-valid copy is already on disk.
    if os.path.exists(pretrained_file_path) \
            and check_sha1(pretrained_file_path, expected_file_hash):
        return pretrained_file_path
    print('Embedding file {} is not found. Downloading from Gluon Repository. '
          'This may take some time.'.format(pretrained_file_name))
    download(url, pretrained_file_path, sha1_hash=expected_file_hash)
    return pretrained_file_path
def download_city(path, overwrite=False):
    """Checksum-verify and extract the Cityscapes archives placed under
    ``<path>/downloads``.

    Parameters
    ----------
    path : str
        Dataset root; the archives are expected in ``<path>/downloads``.
    overwrite : bool, optional
        Unused; kept for interface compatibility.

    Raises
    ------
    UserWarning
        If an archive's SHA-1 does not match the expected value.
    """
    _CITY_DOWNLOAD_URLS = [('gtFine_trainvaltest.zip',
                            '99f532cb1af174f5fcc4c5bc8feea8c66246ddbc'),
                           ('leftImg8bit_trainvaltest.zip',
                            '2c0b77ce9933cc635adda307fbba5566f5d9d404')]
    download_dir = os.path.join(path, 'downloads')
    makedirs(download_dir)
    for filename, checksum in _CITY_DOWNLOAD_URLS:
        # FIX: resolve the archive inside download_dir — previously the bare
        # filename (CWD-relative) was hashed/extracted and download_dir,
        # although created, was never used.
        archive_path = os.path.join(download_dir, filename)
        if not check_sha1(archive_path, checksum):
            raise UserWarning('File {} is downloaded but the content hash does not match. '
                              'The repo may be outdated or download may be incomplete. '
                              'If the "repo_url" is overridden, consider switching to '
                              'the default repo.'.format(archive_path))
        # extract
        with zipfile.ZipFile(archive_path, "r") as zip_ref:
            zip_ref.extractall(path=path)
        print("Extracted", archive_path)
def _get_data(self):
    """Load data from the file. Do nothing if data was loaded before.

    Downloads the segment archive when the data file is missing or fails
    its checksum, then flat-extracts every member into ``self._root``.
    """
    (data_archive_name, archive_hash), (data_name, data_hash) \
        = self._data_file()[self._segment]
    data_path = os.path.join(self._root, data_name)
    if os.path.exists(data_path) and check_sha1(data_path, data_hash):
        return
    archive_path = download(_get_repo_file_url(self._repo_dir(), data_archive_name),
                            path=self._root, sha1_hash=archive_hash)
    with zipfile.ZipFile(archive_path, 'r') as zf:
        for member in zf.namelist():
            leaf = os.path.basename(member)
            if not leaf:
                continue  # directory entry
            dest = os.path.join(self._root, leaf)
            with zf.open(member) as source, open(dest, 'wb') as target:
                shutil.copyfileobj(source, target)
def _get_data(self):
    """Load data from the file. Does nothing if data was loaded before"""
    data_archive_name, _, data_hash = self._data_file[self._segment]
    local_path = os.path.join(self._root, data_archive_name)
    if os.path.exists(local_path) and check_sha1(local_path, data_hash):
        return
    fetched = download(_get_repo_file_url('gluon/dataset/squad', data_archive_name),
                       path=self._root, sha1_hash=data_hash)
    with zipfile.ZipFile(fetched, 'r') as zf:
        for member in zf.namelist():
            leaf = os.path.basename(member)
            if leaf:
                # Flatten: drop any directory components inside the zip.
                dest = os.path.join(self._root, leaf)
                with zf.open(member) as source, open(dest, 'wb') as target:
                    shutil.copyfileobj(source, target)
def _get_data(self):
    """Return the path to the segment data file, downloading and
    flat-extracting the namespace archive when the file is missing or
    fails its checksum."""
    archive_name, archive_sha1 = self._archive_file
    file_name, file_sha1 = self._data_file[self._segment]
    target = os.path.join(self._root, file_name)
    if not os.path.exists(target) or not check_sha1(target, file_sha1):
        archive_path = download(_get_repo_file_url(self._namespace, archive_name),
                                path=self._root, sha1_hash=archive_sha1)
        with zipfile.ZipFile(archive_path, 'r') as zf:
            for member in zf.namelist():
                leaf = os.path.basename(member)
                if leaf:
                    with zf.open(member) as source, \
                            open(os.path.join(self._root, leaf), 'wb') as destination:
                        shutil.copyfileobj(source, destination)
    return target
def check_file(filename, checksum, sha1):
    """Raise ValueError when *filename* is missing or, if *checksum* is
    truthy, when its SHA-1 digest differs from *sha1*."""
    missing = not os.path.exists(filename)
    if missing:
        raise ValueError('File not found: ' + filename)
    # The hash is only verified when verification was requested.
    if checksum:
        if not check_sha1(filename, sha1):
            raise ValueError('Corrupted file: ' + filename)