def _decompress_dist(fname):
    env = os.environ
    if 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env:
        trainer_id = int(env['PADDLE_TRAINER_ID'])
        num_trainers = int(env['PADDLE_TRAINERS_NUM'])
        if num_trainers <= 1:
            _decompress(fname)
        else:
            lock_path = fname + '.decompress.lock'
            from paddle.distributed import ParallelEnv
            unique_endpoints = _get_unique_endpoints(ParallelEnv()
                                                     .trainer_endpoints[:])
            # NOTE(dkp): _decompress_dist is always performed after
            # _download_dist, where sub-trainers wait for the download lock
            # file to be released by sleeping. If decompression is very fast
            # and finishes within that sleeping gap (e.g. for tiny datasets
            # such as coco_ce or spine_coco), the main trainer may finish
            # decompressing and release the lock file before the sub-trainers
            # check it. Therefore we only create the lock file in the main
            # trainer, and all sub-trainers wait 1s for the main trainer to
            # create it; 1s is twice the sleeping gap, which keeps all
            # trainer pipelines in order.
            # **change this if you have a more elegant method**
            if ParallelEnv().current_endpoint in unique_endpoints:
                with open(lock_path, 'w'):  # touch
                    os.utime(lock_path, None)
                _decompress(fname)
                os.remove(lock_path)
            else:
                time.sleep(1)
                while os.path.exists(lock_path):
                    time.sleep(0.5)
    else:
        _decompress(fname)
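# A minimal standalone sketch of the lock-file handshake used above
# (illustrative only, not part of the original module; it assumes the same
# module-level `os` and `time` imports the functions here rely on): one
# designated worker creates the lock, does the work, and removes the lock;
# the others wait briefly for the lock to appear and then poll until it is gone.
def _run_once_with_lock(lock_path, is_main, work_fn, poll_interval=0.5):
    if is_main:
        with open(lock_path, 'w'):  # touch the lock file
            os.utime(lock_path, None)
        work_fn()
        os.remove(lock_path)  # removal signals completion to the waiters
    else:
        time.sleep(1)  # give the main worker time to create the lock file
        while os.path.exists(lock_path):
            time.sleep(poll_interval)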
def _get_data(self, mode, **kwargs):
    """Downloads dataset."""
    default_root = os.path.join(DATA_HOME, self.__class__.__name__)
    filename, data_hash, url, zipfile_hash = self.SPLITS[mode]
    fullname = os.path.join(default_root, filename)
    if mode == 'train':
        if not os.path.exists(fullname):
            get_path_from_url(url, default_root, zipfile_hash)
        unique_endpoints = _get_unique_endpoints(
            ParallelEnv().trainer_endpoints[:])
        if ParallelEnv().current_endpoint in unique_endpoints:
            file_num = len(os.listdir(fullname))
            if file_num != len(ALL_LANGUAGES):
                logger.warning(
                    "Number of train files is %d != %d, decompress again." %
                    (file_num, len(ALL_LANGUAGES)))
                shutil.rmtree(fullname)
                _decompress(
                    os.path.join(default_root, os.path.basename(url)))
    else:
        if not os.path.exists(fullname) or (
                data_hash and not md5file(fullname) == data_hash):
            get_path_from_url(url, default_root, zipfile_hash)
    return fullname
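# For reference, _get_data above unpacks SPLITS[mode] as
# (filename, data_hash, url, zipfile_hash). A hypothetical layout (the real
# filenames, URLs, and hashes differ; everything below is a placeholder):
SPLITS_EXAMPLE = {
    'train': ('train', None, 'https://example.com/train.tar.gz',
              '<archive-md5>'),
    'dev': ('dev.tsv', '<file-md5>', 'https://example.com/dev.tar.gz',
            '<archive-md5>'),
}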
def _download_dist(url, path, md5sum=None):
    env = os.environ
    if 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env:
        trainer_id = int(env['PADDLE_TRAINER_ID'])
        num_trainers = int(env['PADDLE_TRAINERS_NUM'])
        if num_trainers <= 1:
            return _download(url, path, md5sum)
        else:
            fname = osp.split(url)[-1]
            fullname = osp.join(path, fname)
            lock_path = fullname + '.download.lock'

            if not osp.isdir(path):
                os.makedirs(path)

            if not osp.exists(fullname):
                from paddle.distributed import ParallelEnv
                unique_endpoints = _get_unique_endpoints(ParallelEnv()
                                                         .trainer_endpoints[:])
                with open(lock_path, 'w'):  # touch
                    os.utime(lock_path, None)
                if ParallelEnv().current_endpoint in unique_endpoints:
                    _download(url, path, md5sum)
                    os.remove(lock_path)
                else:
                    while os.path.exists(lock_path):
                        time.sleep(0.5)
            return fullname
    else:
        return _download(url, path, md5sum)
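# Usage sketch (an assumption for illustration; the URL and target directory
# are placeholders): under `python -m paddle.distributed.launch`, the
# PADDLE_TRAINERS_NUM / PADDLE_TRAINER_ID variables are set, so only the
# unique endpoint downloads and unpacks while the other trainers wait on the
# lock files. Without those variables, both helpers run single-process.
if __name__ == '__main__':
    archive = _download_dist('https://example.com/data.tar.gz', '/tmp/data')
    _decompress_dist(archive)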
def _get_data(self, mode):
    """Check and download Dataset"""
    dl_paths = {}
    version = self.config.get("version", "3.0.0")
    if version not in ["1.0.0", "2.0.0", "3.0.0"]:
        raise ValueError("Unsupported version: %s" % version)
    dl_paths["version"] = version
    default_root = os.path.join(DATA_HOME, self.__class__.__name__)
    for k, v in self.cnn_dailymail.items():
        dir_path = os.path.join(default_root, k)
        if not os.path.exists(dir_path):
            get_path_from_url(v["url"], default_root, v["md5"])
        unique_endpoints = _get_unique_endpoints(ParallelEnv()
                                                 .trainer_endpoints[:])
        if ParallelEnv().current_endpoint in unique_endpoints:
            file_num = len(os.listdir(os.path.join(dir_path, "stories")))
            if file_num != v["file_num"]:
                logger.warning(
                    "Number of %s stories is %d != %d, decompress again." %
                    (k, file_num, v["file_num"]))
                shutil.rmtree(os.path.join(dir_path, "stories"))
                _decompress(
                    os.path.join(default_root, os.path.basename(v["url"])))
        dl_paths[k] = dir_path
    filename, url, data_hash = self.SPLITS[mode]
    fullname = os.path.join(default_root, filename)
    if not os.path.exists(fullname) or (
            data_hash and not md5file(fullname) == data_hash):
        get_path_from_url(url, default_root, data_hash)
    dl_paths[mode] = fullname
    return dl_paths
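# _get_data above expects `self.cnn_dailymail` to map each corpus name to its
# download metadata, where `file_num` is the expected number of extracted
# story files. A hypothetical shape (URLs, hashes, and counts are
# placeholders, not the real dataset metadata):
cnn_dailymail_example = {
    "cnn": {"url": "https://example.com/cnn_stories.tgz",
            "md5": "<archive-md5>",
            "file_num": 92579},
    "dailymail": {"url": "https://example.com/dailymail_stories.tgz",
                  "md5": "<archive-md5>",
                  "file_num": 219506},
}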
def read_datasets(self, splits=None, data_files=None):
    def remove_if_exit(filepath):
        if isinstance(filepath, (list, tuple)):
            for file in filepath:
                try:
                    os.remove(file)
                except OSError:
                    pass
        else:
            try:
                os.remove(filepath)
            except OSError:
                pass

    if data_files is None:
        if splits is None:
            splits = list(self.BUILDER_CONFIGS[self.name]['splits'].keys(
            )) if hasattr(self, "BUILDER_CONFIGS") else list(
                self.SPLITS.keys())

        assert isinstance(splits, str) or (
            isinstance(splits, list) and isinstance(splits[0], str)
        ) or (
            isinstance(splits, tuple) and isinstance(splits[0], str)
        ), "`splits` should be a string, a list of strings or a tuple of strings."

        if isinstance(splits, str):
            splits = [splits]
        datasets = DatasetTuple(splits)
        parallel_env = dist.ParallelEnv()
        unique_endpoints = _get_unique_endpoints(
            parallel_env.trainer_endpoints[:])
        # Register the cleanup hook first, for all lock files together.
        lock_files = []
        for split in splits:
            lock_file = os.path.join(DATA_HOME, self.__class__.__name__)
            if self.name is not None:
                lock_file = lock_file + "." + self.name
            lock_file += "." + split + ".done" + "." + str(os.getppid())
            lock_files.append(lock_file)
        # Must register in all procs so that the lock files can be removed
        # when any proc breaks. Otherwise, the single registered proc may
        # not receive the proper signal sent by the parent proc to exit.
        atexit.register(lambda: remove_if_exit(lock_files))
        for split in splits:
            filename = self._get_data(split)
            lock_file = os.path.join(DATA_HOME, self.__class__.__name__)
            if self.name is not None:
                lock_file = lock_file + "." + self.name
            lock_file += "." + split + ".done" + "." + str(os.getppid())
            # `lock_file` indicates the finished status of `_get_data`.
            # `_get_data` only works in the proc specified by
            # `unique_endpoints` since `get_path_from_url` only works for it.
            # The other procs wait for `_get_data` to finish.
            if parallel_env.current_endpoint in unique_endpoints:
                f = open(lock_file, "w")
                f.close()
            else:
                while not os.path.exists(lock_file):
                    time.sleep(1)
            datasets[split] = self.read(filename=filename, split=split)
    else:
        assert isinstance(data_files, str) or isinstance(
            data_files, tuple) or isinstance(
                data_files, list
            ), "`data_files` should be a string, a tuple or a list of strings."
        if isinstance(data_files, str):
            data_files = [data_files]
        default_split = 'train'
        if splits:
            if isinstance(splits, str):
                splits = [splits]
            datasets = DatasetTuple(splits)
            assert len(splits) == len(
                data_files
            ), "Number of `splits` and number of `data_files` should be the same if you want to specify the splits of local data files."
            for i in range(len(data_files)):
                datasets[splits[i]] = self.read(
                    filename=data_files[i], split=splits[i])
        else:
            datasets = DatasetTuple(
                ["split" + str(i) for i in range(len(data_files))])
            for i in range(len(data_files)):
                datasets["split" + str(i)] = self.read(
                    filename=data_files[i], split=default_split)

    return datasets if len(datasets) > 1 else datasets[0]
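# A minimal usage sketch (assumption: `MyDataset` stands for any subclass that
# implements `_get_data` and `read`; the library's public entry point may wrap
# this method differently):
#
#   builder = MyDataset()
#   train_ds, dev_ds = builder.read_datasets(splits=['train', 'dev'])
#   local_ds = builder.read_datasets(data_files='/path/to/local/file.tsv')
#
# With `data_files` given and no `splits`, the result is keyed as "split0",
# "split1", ... and each file is read with the default 'train' split.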