Example 1
def _decompress_dist(fname):
    env = os.environ
    if 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env:
        trainer_id = int(env['PADDLE_TRAINER_ID'])
        num_trainers = int(env['PADDLE_TRAINERS_NUM'])
        if num_trainers <= 1:
            _decompress(fname)
        else:
            lock_path = fname + '.decompress.lock'
            from paddle.distributed import ParallelEnv
            unique_endpoints = _get_unique_endpoints(ParallelEnv()
                                                     .trainer_endpoints[:])
            # NOTE(dkp): _decompress_dist is always performed after
            # _download_dist. In _download_dist, sub-trainers wait for the
            # download lock file to be released by sleeping; if decompression
            # is very fast and finishes within that sleeping gap (e.g. for
            # tiny datasets such as coco_ce or spine_coco), the main trainer
            # may have already finished decompressing and removed the lock
            # file before the sub-trainers check for it. So the lock file is
            # created only in the main trainer, and all sub-trainers wait 1s
            # (twice the sleeping gap) for the main trainer to create it,
            # which keeps the trainer pipeline in order.
            # **change this if you have a more elegant method**
            if ParallelEnv().current_endpoint in unique_endpoints:
                with open(lock_path, 'w'):  # touch
                    os.utime(lock_path, None)
                _decompress(fname)
                os.remove(lock_path)
            else:
                time.sleep(1)
                while os.path.exists(lock_path):
                    time.sleep(0.5)
    else:
        _decompress(fname)
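
The handshake in the branch above boils down to: the main trainer touches a lock file, decompresses, and removes the lock, while every other trainer first sleeps long enough for the lock to appear and then polls until it disappears. Below is a minimal, framework-free sketch of that pattern; the `decompress_once` helper and the `is_lead` flag are illustrative stand-ins for the ParallelEnv endpoint check, not part of Paddle.

import os
import time


def decompress_once(fname, is_lead, work):
    # `work` is the actual decompression callable, e.g. _decompress.
    lock_path = fname + '.decompress.lock'
    if is_lead:
        # Lead process: create the lock, do the work, release the lock.
        with open(lock_path, 'w'):  # touch
            os.utime(lock_path, None)
        work(fname)
        os.remove(lock_path)
    else:
        # Followers: give the lead time to create the lock, then poll
        # until it is removed.
        time.sleep(1)
        while os.path.exists(lock_path):
            time.sleep(0.5)

The initial one-second sleep matters: a follower that polled before the lock was created would see no lock at all and fall through immediately, which is exactly the race the original comment describes.
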
Example 2
    def _get_data(self, mode, **kwargs):
        """Downloads dataset."""
        default_root = os.path.join(DATA_HOME, self.__class__.__name__)
        filename, data_hash, url, zipfile_hash = self.SPLITS[mode]
        fullname = os.path.join(default_root, filename)
        if mode == 'train':
            if not os.path.exists(fullname):
                get_path_from_url(url, default_root, zipfile_hash)
            unique_endpoints = _get_unique_endpoints(
                ParallelEnv().trainer_endpoints[:])
            if ParallelEnv().current_endpoint in unique_endpoints:
                file_num = len(os.listdir(fullname))
                if file_num != len(ALL_LANGUAGES):
                    logger.warning(
                        "Number of train files is %d != %d, decompress again."
                        % (file_num, len(ALL_LANGUAGES)))
                    shutil.rmtree(fullname)
                    _decompress(
                        os.path.join(default_root, os.path.basename(url)))
        else:
            if not os.path.exists(fullname) or (
                    data_hash and not md5file(fullname) == data_hash):
                get_path_from_url(url, default_root, zipfile_hash)

        return fullname
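
In the non-train branch, the cached file is re-downloaded whenever its MD5 digest no longer matches `data_hash`. `md5file` comes from the Paddle utilities; if you need the same check with only the standard library, a chunked hashlib digest is enough. The helper name `md5_of` below is made up for this sketch.

import hashlib


def md5_of(path, chunk_size=1 << 20):
    # Read the file in 1 MiB chunks so large archives are not loaded
    # into memory at once.
    md5 = hashlib.md5()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            md5.update(chunk)
    return md5.hexdigest()
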
Example 3
def _download_dist(url, path, md5sum=None):
    env = os.environ
    if 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env:
        trainer_id = int(env['PADDLE_TRAINER_ID'])
        num_trainers = int(env['PADDLE_TRAINERS_NUM'])
        if num_trainers <= 1:
            return _download(url, path, md5sum)
        else:
            fname = osp.split(url)[-1]
            fullname = osp.join(path, fname)
            lock_path = fullname + '.download.lock'

            if not osp.isdir(path):
                os.makedirs(path)

            if not osp.exists(fullname):
                from paddle.distributed import ParallelEnv
                unique_endpoints = _get_unique_endpoints(ParallelEnv()
                                                         .trainer_endpoints[:])
                with open(lock_path, 'w'):  # touch
                    os.utime(lock_path, None)
                if ParallelEnv().current_endpoint in unique_endpoints:
                    _download(url, path, md5sum)
                    os.remove(lock_path)
                else:
                    while os.path.exists(lock_path):
                        time.sleep(0.5)
            return fullname
    else:
        return _download(url, path, md5sum)
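
Both `_download_dist` and `_decompress_dist` decide how to behave from the same two launcher-provided environment variables. Pulling that check into one place gives a helper like the following; the name `dist_info` is illustrative, not a Paddle API.

import os


def dist_info():
    # PADDLE_TRAINER_ID / PADDLE_TRAINERS_NUM are set by the distributed
    # launcher; without them, fall back to single-trainer behaviour.
    env = os.environ
    if 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env:
        return int(env['PADDLE_TRAINER_ID']), int(env['PADDLE_TRAINERS_NUM'])
    return 0, 1

With such a helper, the top-level branching of both functions reduces to reading `trainer_id, num_trainers = dist_info()` and taking the `num_trainers <= 1` shortcut.
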
Example 4
    def _get_data(self, mode):
        """Check and download the dataset."""
        dl_paths = {}
        version = self.config.get("version", "3.0.0")
        if version not in ["1.0.0", "2.0.0", "3.0.0"]:
            raise ValueError("Unsupported version: %s" % version)
        dl_paths["version"] = version
        default_root = os.path.join(DATA_HOME, self.__class__.__name__)
        for k, v in self.cnn_dailymail.items():
            dir_path = os.path.join(default_root, k)
            if not os.path.exists(dir_path):
                get_path_from_url(v["url"], default_root, v["md5"])
            unique_endpoints = _get_unique_endpoints(
                ParallelEnv().trainer_endpoints[:])
            if ParallelEnv().current_endpoint in unique_endpoints:
                file_num = len(os.listdir(os.path.join(dir_path, "stories")))
                if file_num != v["file_num"]:
                    logger.warning(
                        "Number of %s stories is %d != %d, decompress again." %
                        (k, file_num, v["file_num"]))
                    shutil.rmtree(os.path.join(dir_path, "stories"))
                    _decompress(
                        os.path.join(default_root, os.path.basename(v["url"])))
            dl_paths[k] = dir_path
        filename, url, data_hash = self.SPLITS[mode]
        fullname = os.path.join(default_root, filename)
        if not os.path.exists(fullname) or (
                data_hash and not md5file(fullname) == data_hash):
            get_path_from_url(url, default_root, data_hash)
        dl_paths[mode] = fullname
        return dl_paths
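
The `file_num != v["file_num"]` guard protects against a partially extracted archive left behind by an interrupted run. Reduced to its essentials, the check is "count the extracted entries and re-extract on mismatch"; a rough standalone sketch follows, where the `needs_reextract` name and the logger wiring are assumptions of this example.

import logging
import os

logger = logging.getLogger(__name__)


def needs_reextract(stories_dir, expected_num):
    # A missing directory or wrong entry count means the archive was not
    # (fully) extracted and should be decompressed again.
    if not os.path.isdir(stories_dir):
        return True
    file_num = len(os.listdir(stories_dir))
    if file_num != expected_num:
        logger.warning("Number of stories is %d != %d, decompress again.",
                       file_num, expected_num)
        return True
    return False
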
Example 5
    def read_datasets(self, splits=None, data_files=None):
        def remove_if_exit(filepath):
            if isinstance(filepath, (list, tuple)):
                for file in filepath:
                    try:
                        os.remove(file)
                    except OSError:
                        pass
            else:
                try:
                    os.remove(filepath)
                except OSError:
                    pass

        if data_files is None:
            if splits is None:
                splits = list(self.BUILDER_CONFIGS[
                    self.name]['splits'].keys()) if hasattr(
                        self, "BUILDER_CONFIGS") else list(self.SPLITS.keys())

            assert isinstance(splits, str) or (
                isinstance(splits, (list, tuple)) and
                isinstance(splits[0], str)
            ), "`splits` should be a string, a list of strings or a tuple of strings."

            if isinstance(splits, str):
                splits = [splits]
            datasets = DatasetTuple(splits)
            parallel_env = dist.ParallelEnv()
            unique_endpoints = _get_unique_endpoints(
                parallel_env.trainer_endpoints[:])
            # Build all lock file paths first and register the cleanup hook
            # for them together.
            lock_files = []
            for split in splits:
                lock_file = os.path.join(DATA_HOME, self.__class__.__name__)
                if self.name is not None:
                    lock_file = lock_file + "." + self.name
                lock_file += "." + split + ".done" + "." + str(os.getppid())
                lock_files.append(lock_file)
            # Must register in all procs so that the lock files can be removed
            # when any proc breaks. Otherwise, the single registered proc may
            # not receive the proper signal sent by the parent proc to exit.
            atexit.register(lambda: remove_if_exit(lock_files))
            for split in splits:
                filename = self._get_data(split)
                lock_file = os.path.join(DATA_HOME, self.__class__.__name__)
                if self.name is not None:
                    lock_file = lock_file + "." + self.name
                lock_file += "." + split + ".done" + "." + str(os.getppid())
                # `lock_file` indicates the finished status of `_get_data`.
                # `_get_data` only runs in the proc specified by
                # `unique_endpoints`, since `get_path_from_url` only works
                # there. The other procs wait for `_get_data` to finish.
                if parallel_env.current_endpoint in unique_endpoints:
                    f = open(lock_file, "w")
                    f.close()
                else:
                    while not os.path.exists(lock_file):
                        time.sleep(1)
                datasets[split] = self.read(filename=filename, split=split)
        else:
            assert isinstance(
                data_files, (str, tuple, list)
            ), "`data_files` should be a string, a tuple or a list of strings."
            if isinstance(data_files, str):
                data_files = [data_files]
            default_split = 'train'
            if splits:
                if isinstance(splits, str):
                    splits = [splits]
                datasets = DatasetTuple(splits)
                assert len(splits) == len(
                    data_files
                ), "Number of `splits` and number of `data_files` should be the same if you want to specify the split of local data files."
                for i in range(len(data_files)):
                    datasets[splits[i]] = self.read(filename=data_files[i],
                                                    split=splits[i])
            else:
                datasets = DatasetTuple(
                    ["split" + str(i) for i in range(len(data_files))])
                for i in range(len(data_files)):
                    datasets["split" + str(i)] = self.read(
                        filename=data_files[i], split=default_split)

        return datasets if len(datasets) > 1 else datasets[0]
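
The key detail in `read_datasets` is that the `atexit` cleanup is registered in every process before any `.done` marker is written, so a crash in the lead process cannot leave stale markers that would make followers believe `_get_data` succeeded. The same pattern in isolation, using only the standard library (`cleanup_markers` and the marker paths are hypothetical):

import atexit
import os


def cleanup_markers(paths):
    # Best-effort removal; a marker that is already gone is not an error.
    for path in paths:
        try:
            os.remove(path)
        except OSError:
            pass


marker_files = ['/tmp/train.done', '/tmp/dev.done']  # hypothetical paths
# Passing the list as an argument to atexit.register is equivalent to the
# lambda form used above; either way the handler sees the full list of
# markers that were created during the run.
atexit.register(cleanup_markers, marker_files)
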