def _get_data(self, mode, **kwargs):
    default_root = os.path.join(DATA_HOME, self.__class__.__name__)
    src_filename, tgt_filename, src_data_hash, tgt_data_hash = self.SPLITS[
        mode]
    src_fullname = os.path.join(default_root, src_filename)
    tgt_fullname = os.path.join(default_root, tgt_filename)
    (bpe_vocab_filename, bpe_vocab_hash), (
        sub_vocab_filename, sub_vocab_hash) = self.VOCAB_INFO
    bpe_vocab_fullname = os.path.join(default_root, bpe_vocab_filename)
    sub_vocab_fullname = os.path.join(default_root, sub_vocab_filename)

    # Re-download the archive if any of the four files is missing or fails
    # its md5 check.
    checks = [
        (src_fullname, src_data_hash),
        (tgt_fullname, tgt_data_hash),
        (bpe_vocab_fullname, bpe_vocab_hash),
        (sub_vocab_fullname, sub_vocab_hash),
    ]
    if any(not os.path.exists(fullname) or
           (data_hash and md5file(fullname) != data_hash)
           for fullname, data_hash in checks):
        get_path_from_url(self.URL, default_root, self.MD5)

    return src_fullname, tgt_fullname

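# A minimal sketch of the class attributes the method above assumes. The
# attribute names (URL, MD5, SPLITS, VOCAB_INFO) and the tuple shapes come
# from the code; the values below are hypothetical placeholders, not real
# dataset metadata.
class ExampleTranslationDataset:
    URL = "https://example.com/datasets/example-translation.tar.gz"  # hypothetical
    MD5 = "0123456789abcdef0123456789abcdef"  # hypothetical archive checksum
    # mode -> (source file, target file, source md5, target md5)
    SPLITS = {
        "train": ("train.src", "train.tgt", None, None),
        "dev": ("dev.src", "dev.tgt", None, None),
    }
    # ((bpe vocab file, its md5), (subword vocab file, its md5))
    VOCAB_INFO = (("vocab.bpe", None), ("vocab.sub", None))
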
def download(url, md5sum, target_dir):
    """Download file from url to target_dir, and check md5sum."""
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    filepath = os.path.join(target_dir, url.split("/")[-1])
    if not (os.path.exists(filepath) and md5file(filepath) == md5sum):
        print("Downloading %s ..." % url)
        os.system("wget -c " + url + " -P " + target_dir)
        print("\nMD5 Checksum %s ..." % filepath)
        if not md5file(filepath) == md5sum:
            raise RuntimeError("MD5 checksum failed.")
    else:
        print("File exists, skip downloading. (%s)" % filepath)
    return filepath

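# A usage sketch for download() above; the URL and md5 are hypothetical
# placeholders, not a real dataset location.
if __name__ == "__main__":
    path = download(
        url="https://example.com/data/corpus.tar.gz",  # hypothetical
        md5sum="0123456789abcdef0123456789abcdef",  # hypothetical
        target_dir="/tmp/corpus",
    )
    print("Downloaded to", path)
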
def _get_data(self, mode, **kwargs):
    default_root = os.path.join(DATA_HOME, self.__class__.__name__)
    filename, data_hash = self.SPLITS[mode]
    fullname = os.path.join(default_root, filename)
    vocab_filename, vocab_hash = self.VOCAB_INFO
    vocab_fullname = os.path.join(default_root, vocab_filename)
    if (not os.path.exists(fullname) or
        (data_hash and not md5file(fullname) == data_hash) or
            not os.path.exists(vocab_fullname) or
            (vocab_hash and not md5file(vocab_fullname) == vocab_hash)):
        get_path_from_url(self.URL, default_root, self.MD5)
    return fullname

def _get_data(self, mode):
    """Check and download Dataset"""
    dl_paths = {}
    version = self.config.get("version", "3.0.0")
    if version not in ["1.0.0", "2.0.0", "3.0.0"]:
        raise ValueError("Unsupported version: %s" % version)
    dl_paths["version"] = version
    default_root = os.path.join(DATA_HOME, self.__class__.__name__)
    for k, v in self.cnn_dailymail.items():
        dir_path = os.path.join(default_root, k)
        if not os.path.exists(dir_path):
            get_path_from_url(v["url"], default_root, v["md5"])
            # Only one worker per unique endpoint verifies the extraction
            # and, if the story count is wrong, re-extracts the archive.
            unique_endpoints = _get_unique_endpoints(
                ParallelEnv().trainer_endpoints[:])
            if ParallelEnv().current_endpoint in unique_endpoints:
                file_num = len(os.listdir(os.path.join(dir_path, "stories")))
                if file_num != v["file_num"]:
                    logger.warning(
                        "Number of %s stories is %d != %d, decompress again."
                        % (k, file_num, v["file_num"]))
                    shutil.rmtree(os.path.join(dir_path, "stories"))
                    _decompress(
                        os.path.join(default_root, os.path.basename(v["url"])))
        dl_paths[k] = dir_path
    filename, url, data_hash = self.SPLITS[mode]
    fullname = os.path.join(default_root, filename)
    if not os.path.exists(fullname) or (
            data_hash and not md5file(fullname) == data_hash):
        get_path_from_url(url, default_root, data_hash)
    dl_paths[mode] = fullname
    return dl_paths

def _download_data(cls, mode="train", root=None):
    """Download dataset"""
    default_root = os.path.join(DATA_HOME, 'machine_translation',
                                cls.__name__)
    src_filename, tgt_filename, src_data_hash, tgt_data_hash = cls.SPLITS[
        mode]
    filename_list = [
        src_filename, tgt_filename, cls.VOCAB_INFO[0], cls.VOCAB_INFO[1]
    ]
    fullname_list = []
    for filename in filename_list:
        fullname = os.path.join(
            default_root, filename) if root is None else os.path.join(
                os.path.expanduser(root), filename)
        fullname_list.append(fullname)

    data_hash_list = [
        src_data_hash, tgt_data_hash, cls.VOCAB_INFO[2], cls.VOCAB_INFO[3]
    ]
    for i, fullname in enumerate(fullname_list):
        if not os.path.exists(fullname) or (
                data_hash_list[i] and
                not md5file(fullname) == data_hash_list[i]):
            # Warn only when a user-specified root failed the check; fall
            # back to downloading into default_root.
            if root is not None:
                warnings.warn(
                    'md5 check failed for {}, download {} data to {}'.format(
                        filename_list[i], cls.__name__, default_root))
            get_path_from_url(cls.URL, default_root, cls.MD5)
            return default_root

    return root if root is not None else default_root

def _get_data(self, root, mode, **kwargs):
    default_root = os.path.join(DATA_HOME, self.__class__.__name__)
    if self.version_2_with_negative:
        filename, data_hash = self.SPLITS['2.0'][mode]
    else:
        filename, data_hash = self.SPLITS['1.1'][mode]
    fullname = os.path.join(
        default_root, filename) if root is None else os.path.join(
            os.path.expanduser(root), filename)
    if not os.path.exists(fullname) or (
            data_hash and not md5file(fullname) == data_hash):
        if root is not None:  # warn only when a user-specified root failed the check
            warnings.warn(
                'md5 check failed for {}, download {} data to {}'.format(
                    filename, self.__class__.__name__, default_root))
        if mode == 'train':
            if self.version_2_with_negative:
                fullname = get_path_from_url(
                    self.TRAIN_DATA_URL_V2, os.path.join(default_root, 'v2'))
            else:
                fullname = get_path_from_url(
                    self.TRAIN_DATA_URL_V1, os.path.join(default_root, 'v1'))
        elif mode == 'dev':
            if self.version_2_with_negative:
                fullname = get_path_from_url(
                    self.DEV_DATA_URL_V2, os.path.join(default_root, 'v2'))
            else:
                fullname = get_path_from_url(
                    self.DEV_DATA_URL_V1, os.path.join(default_root, 'v1'))
    self.full_path = fullname

def download_file(save_dir, filename, url, md5=None, task=None):
    """
    Download the file from the url to the specified directory.
    If the file already exists and its md5 value matches the expected one,
    the existing file is reused; otherwise the file is downloaded from the
    url.

    Args:
        save_dir(string): The directory to save the file in.
        filename(string): The filename to save the file as.
        url(string): The url to download the file from.
        md5(string, optional): The expected md5 value used to check the
            downloaded version.
        task(string, optional): The task name passed to the download
            checker.
    """
    logger.disable()
    global DOWNLOAD_CHECK
    if not DOWNLOAD_CHECK:
        DOWNLOAD_CHECK = True
        checker = DownloaderCheck(task)
        checker.start()
        checker.join()
    fullname = os.path.join(save_dir, filename)
    if os.path.exists(fullname):
        if md5 and (not md5file(fullname) == md5):
            get_path_from_url(url, save_dir, md5)
    else:
        get_path_from_url(url, save_dir, md5)
    logger.enable()
    return fullname

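# A usage sketch for download_file() above; the directory, filename, URL,
# md5, and task name are all hypothetical placeholders.
model_path = download_file(
    save_dir="/tmp/my_task",  # hypothetical
    filename="model_state.pdparams",  # hypothetical
    url="https://example.com/models/model_state.pdparams",  # hypothetical
    md5=None,  # with no md5, an existing file is reused without checking
    task="example_task",  # hypothetical task name for DownloaderCheck
)
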
def _get_data(self, mode, **kwargs): """Downloads dataset.""" default_root = os.path.join(DATA_HOME, self.__class__.__name__) filename, data_hash, url, zipfile_hash = self.SPLITS[mode] fullname = os.path.join(default_root, filename) if mode == 'train': if not os.path.exists(fullname): get_path_from_url(url, default_root, zipfile_hash) unique_endpoints = _get_unique_endpoints( ParallelEnv().trainer_endpoints[:]) if ParallelEnv().current_endpoint in unique_endpoints: file_num = len(os.listdir(fullname)) if file_num != len(ALL_LANGUAGES): logger.warning( "Number of train files is %d != %d, decompress again." % (file_num, len(ALL_LANGUAGES))) shutil.rmtree(fullname) _decompress( os.path.join(default_root, os.path.basename(url))) else: if not os.path.exists(fullname) or ( data_hash and not md5file(fullname) == data_hash): get_path_from_url(url, default_root, zipfile_hash) return fullname
def _get_data(self, mode, **kwargs):
    default_root = DATA_HOME
    filename, data_hash = self.SPLITS[mode]
    fullname = os.path.join(default_root, filename)
    if not os.path.exists(fullname) or (
            data_hash and not md5file(fullname) == data_hash):
        # The original body only recomputed fullname and never fetched
        # anything; download here instead, assuming the class defines URL
        # and MD5 like its sibling datasets.
        get_path_from_url(self.URL, default_root, self.MD5)
    return fullname

def _get_data(self, mode, **kwargs):
    builder_config = self.BUILDER_CONFIGS[self.name]
    default_root = os.path.join(DATA_HOME, self.__class__.__name__)
    filename, url, data_hash = builder_config['splits'][mode]
    fullname = os.path.join(default_root, filename)
    if not os.path.exists(fullname) or (
            data_hash and not md5file(fullname) == data_hash):
        get_path_from_url(url, default_root, data_hash)
    return fullname

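# A minimal sketch of the BUILDER_CONFIGS shape the method above reads;
# the config name, filenames, and URLs are hypothetical placeholders.
BUILDER_CONFIGS = {
    "example_config": {  # hypothetical sub-dataset name, i.e. self.name
        "splits": {
            # mode -> (filename, url, md5)
            "train": ("train.tsv",
                      "https://example.com/data/train.tsv", None),
            "dev": ("dev.tsv",
                    "https://example.com/data/dev.tsv", None),
        },
    },
}
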
def _get_data(self, mode, **kwargs):
    default_root = os.path.join(DATA_HOME, self.__class__.__name__)
    filename, data_hash = self.SPLITS[mode]
    fullname = os.path.join(default_root, filename)
    if not os.path.exists(fullname) or (
            data_hash and not md5file(fullname) == data_hash):
        get_path_from_url(self.URL, default_root)
    return fullname

def _get_data(self, mode, **kwargs): """Downloads dataset.""" default_root = DATA_HOME filename, data_hash = self.SPLITS[mode] fullname = os.path.join(default_root, filename) if not os.path.exists(fullname) or ( data_hash and not md5file(fullname) == data_hash): path = get_path_from_url(self.URL, default_root, self.MD5) fullname = os.path.join(default_root, filename) return fullname
def _get_data(self, mode, **kwargs):
    default_root = os.path.join(DATA_HOME, self.__class__.__name__, mode)
    meta_info_list = self.SPLITS[mode]
    fullnames = []
    for meta_info in meta_info_list:
        filename, data_hash, url = meta_info
        fullname = os.path.join(default_root, filename)
        if not os.path.exists(fullname) or (
                data_hash and not md5file(fullname) == data_hash):
            get_path_from_url(url, default_root)
        fullnames.append(fullname)
    return fullnames

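# A minimal sketch of the per-split file list the method above iterates,
# with one (filename, md5, url) entry per file to fetch; the filenames and
# URLs are hypothetical placeholders.
SPLITS = {
    "train": [
        ("train.json", None, "https://example.com/data/train.json"),
        ("labels.txt", None, "https://example.com/data/labels.txt"),
    ],
}
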
def _get_data(self, mode, **kwargs): """Downloads dataset.""" builder_config = self.BUILDER_CONFIGS[self.name] default_root = os.path.join(DATA_HOME, 'SE-ABSA16_PHNS') filename, data_hash, _, _ = builder_config['splits'][mode] fullname = os.path.join(default_root, filename) if not os.path.exists(fullname) or ( data_hash and not md5file(fullname) == data_hash): url = builder_config['url'] md5 = builder_config['md5'] get_path_from_url(url, DATA_HOME, md5) return fullname
def test_download_url(self):
    LABEL_URL = 'http://paddlemodels.bj.bcebos.com/flowers/imagelabels.mat'
    LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d'
    catch_exp = False
    try:
        download(LABEL_URL, 'flowers', LABEL_MD5)
    except Exception:
        catch_exp = True
    self.assertFalse(catch_exp)
    file_path = os.path.join(DATA_HOME, "flowers", "imagelabels.mat")
    self.assertTrue(os.path.exists(file_path))
    # assertTrue(md5file(file_path), LABEL_MD5) would pass for any
    # non-empty digest; compare against the expected md5 instead.
    self.assertEqual(md5file(file_path), LABEL_MD5)

def _get_data(self, root, mode, **kwargs):
    default_root = os.path.join(DATA_HOME, 'DuReader')
    filename, data_hash = self.SPLITS[mode]
    fullname = os.path.join(
        default_root, filename) if root is None else os.path.join(
            os.path.expanduser(root), filename)
    if not os.path.exists(fullname) or (
            data_hash and not md5file(fullname) == data_hash):
        if root is not None:  # warn only when a user-specified root failed the check
            warnings.warn(
                'md5 check failed for {}, download {} data to {}'.format(
                    filename, self.__class__.__name__, default_root))
        get_path_from_url(self.DATA_URL, default_root)
    self.full_path = fullname

def _get_data(self, root, mode, **kwargs):
    default_root = DATA_HOME
    filename, data_hash, field_indices, num_discard_samples = self.SPLITS[
        mode]
    fullname = os.path.join(
        default_root, filename) if root is None else os.path.join(
            os.path.expanduser(root), filename)
    if not os.path.exists(fullname) or (
            data_hash and not md5file(fullname) == data_hash):
        if root is not None:  # warn only when a user-specified root failed the check
            warnings.warn(
                'md5 check failed for {}, download {} data to {}'.format(
                    filename, self.__class__.__name__, default_root))
        get_path_from_url(self.URL, default_root, self.MD5)
        # Point fullname at the freshly downloaded copy in default_root.
        fullname = os.path.join(default_root, filename)
    super(LCQMC, self).__init__(
        fullname,
        field_indices=field_indices,
        num_discard_samples=num_discard_samples,
        **kwargs)

def download_file(save_dir, filename, url, md5=None):
    """
    Download the file from the url to the specified directory.
    If the file already exists and its md5 value matches the expected one,
    the existing file is reused; otherwise the file is downloaded from the
    url.

    Args:
        save_dir(string): The directory to save the file in.
        filename(string): The filename to save the file as.
        url(string): The url to download the file from.
        md5(string, optional): The expected md5 value used to check the
            downloaded version.
    """
    fullname = os.path.join(save_dir, filename)
    if os.path.exists(fullname):
        if md5 and (not md5file(fullname) == md5):
            logger.disable()
            get_path_from_url(url, save_dir, md5)
    else:
        logger.info("Downloading {} from {}".format(filename, url))
        logger.disable()
        get_path_from_url(url, save_dir, md5)
    logger.enable()
    return fullname

def _check_task_files(self):
    """
    Check files required by the task.
    """
    for file_id, file_name in self.resource_files_names.items():
        path = os.path.join(self._task_path, file_name)
        url = self.resource_files_urls[self.model][file_id][0]
        md5 = self.resource_files_urls[self.model][file_id][1]

        downloaded = True
        if not os.path.exists(path):
            downloaded = False
        elif not self._custom_model:
            # Check whether the existing file is up to date.
            if not md5file(path) == md5:
                downloaded = False
                if file_id == "model_state":
                    self._param_updated = True
        if not downloaded:
            download_file(self._task_path, file_name, url, md5)

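# A minimal sketch of the attribute shapes _check_task_files() reads; the
# model name, filenames, URLs, and md5 values are hypothetical placeholders.
resource_files_names = {
    "model_state": "model_state.pdparams",  # file_id -> local file name
    "vocab": "vocab.txt",
}
resource_files_urls = {
    "example-model": {  # hypothetical value of self.model
        # file_id -> [url, md5]
        "model_state": [
            "https://example.com/models/model_state.pdparams",
            "0123456789abcdef0123456789abcdef",
        ],
        "vocab": [
            "https://example.com/models/vocab.txt",
            "fedcba9876543210fedcba9876543210",
        ],
    },
}
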
def _get_data(self, mode, **kwargs):
    builder_config = self.BUILDER_CONFIGS[self.name]
    default_root = os.path.join(DATA_HOME, self.__class__.__name__)
    filename, data_hash, _, _ = builder_config['splits'][mode]
    fullname = os.path.join(default_root, filename)
    if self.name != 'mrpc':
        if not os.path.exists(fullname) or (
                data_hash and not md5file(fullname) == data_hash):
            get_path_from_url(builder_config['url'], default_root,
                              builder_config['md5'])
    else:
        # MRPC ships as raw train/test files plus a list of dev ids, so the
        # tsv splits have to be generated locally.
        if not os.path.exists(fullname) or (
                data_hash and not md5file(fullname) == data_hash):
            if mode in ('train', 'dev'):
                dev_id_path = get_path_from_url(
                    builder_config['url']['dev_id'],
                    os.path.join(default_root, 'MRPC'),
                    builder_config['md5']['dev_id'])
                train_data_path = get_path_from_url(
                    builder_config['url']['train_data'],
                    os.path.join(default_root, 'MRPC'),
                    builder_config['md5']['train_data'])
                # Read the ids of the official dev split.
                dev_ids = []
                with open(dev_id_path, encoding='utf-8') as ids_fh:
                    for row in ids_fh:
                        dev_ids.append(row.strip().split('\t'))
                # Generate the train and dev sets.
                train_path = os.path.join(default_root, 'MRPC', 'train.tsv')
                dev_path = os.path.join(default_root, 'MRPC', 'dev.tsv')
                with open(train_data_path, encoding='utf-8') as data_fh, \
                        open(train_path, 'w', encoding='utf-8') as train_fh, \
                        open(dev_path, 'w', encoding='utf-8') as dev_fh:
                    header = data_fh.readline()
                    train_fh.write(header)
                    dev_fh.write(header)
                    for row in data_fh:
                        label, id1, id2, s1, s2 = row.strip().split('\t')
                        example = '%s\t%s\t%s\t%s\t%s\n' % (
                            label, id1, id2, s1, s2)
                        if [id1, id2] in dev_ids:
                            dev_fh.write(example)
                        else:
                            train_fh.write(example)
            else:
                test_data_path = get_path_from_url(
                    builder_config['url']['test_data'],
                    os.path.join(default_root, 'MRPC'),
                    builder_config['md5']['test_data'])
                test_path = os.path.join(default_root, 'MRPC', 'test.tsv')
                with open(test_data_path, encoding='utf-8') as data_fh, \
                        open(test_path, 'w', encoding='utf-8') as test_fh:
                    header = data_fh.readline()
                    test_fh.write(
                        'index\t#1 ID\t#2 ID\t#1 String\t#2 String\n')
                    for idx, row in enumerate(data_fh):
                        label, id1, id2, s1, s2 = row.strip().split('\t')
                        test_fh.write('%d\t%s\t%s\t%s\t%s\n' %
                                      (idx, id1, id2, s1, s2))
    return fullname

def _get_data(self, root, segment, **kwargs):
    default_root = os.path.join(DATA_HOME, 'glue')
    filename, data_hash, field_indices, num_discard_samples = self.SEGMENTS[
        segment]
    fullname = os.path.join(
        default_root, filename) if root is None else os.path.join(
            os.path.expanduser(root), filename)
    if not os.path.exists(fullname) or (
            data_hash and not md5file(fullname) == data_hash):
        if root is not None:  # warn only when a user-specified root failed the check
            warnings.warn(
                'md5 check failed for {}, download {} data to {}'.format(
                    filename, self.__class__.__name__, default_root))
        if segment in ('train', 'dev'):
            dev_id_path = get_path_from_url(
                self.DEV_ID_URL, os.path.join(default_root, 'MRPC'),
                self.DEV_ID_MD5)
            train_data_path = get_path_from_url(
                self.TRAIN_DATA_URL, os.path.join(default_root, 'MRPC'),
                self.TRAIN_DATA_MD5)
            # Read the ids of the official dev split.
            dev_ids = []
            with io.open(dev_id_path, encoding='utf-8') as ids_fh:
                for row in ids_fh:
                    dev_ids.append(row.strip().split('\t'))
            # Generate the train and dev sets.
            train_path = os.path.join(default_root, 'MRPC', 'train.tsv')
            dev_path = os.path.join(default_root, 'MRPC', 'dev.tsv')
            with io.open(train_data_path, encoding='utf-8') as data_fh, \
                    io.open(train_path, 'w', encoding='utf-8') as train_fh, \
                    io.open(dev_path, 'w', encoding='utf-8') as dev_fh:
                header = data_fh.readline()
                train_fh.write(header)
                dev_fh.write(header)
                for row in data_fh:
                    label, id1, id2, s1, s2 = row.strip().split('\t')
                    example = '%s\t%s\t%s\t%s\t%s\n' % (
                        label, id1, id2, s1, s2)
                    if [id1, id2] in dev_ids:
                        dev_fh.write(example)
                    else:
                        train_fh.write(example)
        else:
            test_data_path = get_path_from_url(
                self.TEST_DATA_URL, os.path.join(default_root, 'MRPC'),
                self.TEST_DATA_MD5)
            test_path = os.path.join(default_root, 'MRPC', 'test.tsv')
            with io.open(test_data_path, encoding='utf-8') as data_fh, \
                    io.open(test_path, 'w', encoding='utf-8') as test_fh:
                header = data_fh.readline()
                test_fh.write('index\t#1 ID\t#2 ID\t#1 String\t#2 String\n')
                for idx, row in enumerate(data_fh):
                    label, id1, id2, s1, s2 = row.strip().split('\t')
                    test_fh.write('%d\t%s\t%s\t%s\t%s\n' %
                                  (idx, id1, id2, s1, s2))
        root = default_root
    super(GlueMRPC, self)._get_data(root, segment, **kwargs)