def save_checkpoint(self,
                    path,
                    slists,
                    trainer_id=None,
                    local_cache_path=".cache"):
    """
    Serialize every object in ``slists`` into a new numbered checkpoint.

    The data is first written under ``<path>/<prefix>.<no>.tmp`` and that
    directory is moved to ``<path>/<prefix>.<no>`` as the very last step.
    When the underlying file system reports ``need_upload_download()``,
    the objects are serialized into a local cache directory, which is then
    uploaded to the ``.tmp`` location and removed again.

    Args:
        path: Checkpoint root directory on ``self._fs``; created if missing.
        slists: Iterable of objects exposing ``serialize(dir_path)``.
        trainer_id: Optional suffix appended to the local cache directory
            name; ignored when ``None``.
        local_cache_path: Local directory under which the temporary cache
            directory is created (only used when upload is required).

    Returns:
        Tuple ``(real_path, checkpoint_no)`` of the final checkpoint path
        and the number assigned to it.

    Raises:
        AssertionError: If ``path`` or the local cache path exists but is
            not a directory.
    """
    fs = self._fs
    if not fs.is_exist(path):
        fs.mkdirs(path)
    else:
        assert fs.is_dir(path), "path:{} must be a directory".format(path)

    # Any negative "last number" is normalised to -1, so the first
    # checkpoint written into an empty directory gets number 0.
    checkpoint_no = max(self._get_last_checkpoint_no(path), -1) + 1

    real_path = "{}/{}.{}".format(path, self._checkpoint_prefix,
                                  checkpoint_no)
    tmp_path = "{}.tmp".format(real_path)

    if not fs.need_upload_download():
        # The file system is directly writable: serialize straight into
        # the temporary checkpoint directory.
        for s in slists:
            s.serialize(tmp_path)
    else:
        # Imported lazily; the local file system is only needed when the
        # checkpoint has to be staged locally before being uploaded.
        from paddle.distributed.fleet.utils.fs import LocalFS
        local_fs = LocalFS()

        cache_path = "{}/{}.{}.saved_cache".format(
            local_cache_path, self._checkpoint_prefix, checkpoint_no)
        if trainer_id is not None:
            cache_path = "{}.{}".format(cache_path, trainer_id)

        if not local_fs.is_exist(cache_path):
            local_fs.mkdirs(cache_path)
        else:
            assert local_fs.is_dir(cache_path), \
                "cache path:{} must be a directory".format(cache_path)

        try:
            for s in slists:
                s.serialize(cache_path)
            # Drop any leftover temporary directory before uploading.
            fs.delete(tmp_path)
            fs.upload(cache_path, tmp_path)
        finally:
            # Always remove the local staging directory, even when
            # serialization or upload fails, so a later save with the same
            # number does not pick up stale files.
            local_fs.delete(cache_path)

    fs.mv(tmp_path, real_path)
    return real_path, checkpoint_no
def load_checkpoint(self,
                    path,
                    slists,
                    trainer_id=None,
                    local_cache_path=".cache",
                    checkpoint_no=None,
                    ignore_empty=True):
    """
    Deserialize every object in ``slists`` from a numbered checkpoint.

    When the underlying file system reports ``need_upload_download()``,
    the checkpoint is first downloaded into a local cache directory, the
    objects are deserialized from there, and the cache directory is
    removed afterwards.

    Args:
        path: Checkpoint root directory on ``self._fs``.
        slists: Iterable of objects exposing ``deserialize(dir_path)``.
        trainer_id: Optional suffix appended to the local cache directory
            name; ignored when ``None``.
        local_cache_path: Local directory under which the temporary cache
            directory is created (only used when download is required).
        checkpoint_no: Number of the checkpoint to load. When ``None`` the
            value returned by ``self._get_last_checkpoint_no(path)`` is
            used.
        ignore_empty: Only consulted when ``checkpoint_no`` is ``None``.
            If true, a missing checkpoint yields ``None``; otherwise it
            trips an assertion.

    Returns:
        The checkpoint path ``<path>/<prefix>.<no>`` that was loaded, or
        ``None`` when no checkpoint exists and ``ignore_empty`` is true.

    Raises:
        AssertionError: If no checkpoint is found while ``ignore_empty``
            is false, or if an explicit ``checkpoint_no`` is not a
            non-negative ``int``.
    """
    if checkpoint_no is None:
        last_no = self._get_last_checkpoint_no(path)

        if not ignore_empty:
            assert last_no >= 0, "Can't find checkpoint"

        if last_no < 0:
            return None

        checkpoint_no = last_no
    else:
        assert isinstance(checkpoint_no, int)
        assert checkpoint_no >= 0

    fs = self._fs
    real_path = "{}/{}.{}".format(path, self._checkpoint_prefix,
                                  checkpoint_no)

    if not fs.need_upload_download():
        # The checkpoint is directly readable: no local staging needed.
        for s in slists:
            s.deserialize(real_path)
        return real_path

    # Imported lazily; the local file system is only needed when the
    # checkpoint has to be downloaded before it can be read.
    from paddle.distributed.fleet.utils.fs import LocalFS
    local_fs = LocalFS()

    cache_path = "{}/{}.{}.load_cache".format(
        local_cache_path, self._checkpoint_prefix, checkpoint_no)
    if trainer_id is not None:
        cache_path = "{}.{}".format(cache_path, trainer_id)

    if not local_fs.is_exist(local_cache_path):
        local_fs.mkdirs(local_cache_path)
    # Start from a clean slate so leftovers from an earlier load cannot
    # mix with the freshly downloaded checkpoint.
    if local_fs.is_exist(cache_path):
        local_fs.delete(cache_path)

    try:
        fs.download(real_path, cache_path)
        for s in slists:
            s.deserialize(cache_path)
    finally:
        # Remove the local copy even when download or deserialization
        # fails; the guard covers a download that never created it.
        if local_fs.is_exist(cache_path):
            local_fs.delete(cache_path)

    return real_path
def _test_upload_dir(self, fs):
    """
    Check that ``fs.upload`` can upload a whole local directory.

    A local directory holding two empty files is created, uploaded through
    ``fs`` and the destination is asserted to exist. Both the uploaded
    copy and the local source directory are removed afterwards, also when
    the upload or the assertion fails.

    Args:
        fs: File system client under test; must provide ``upload``,
            ``is_exist`` and ``delete``.
    """
    src_dir = os.path.abspath("./test_upload_dir")
    # NOTE(review): the destination name is spelled differently from the
    # source; it is kept as-is so the two paths stay distinct.
    dst_dir = os.path.abspath("./test_uolpad_dir")
    file1 = os.path.abspath("./test_upload_dir/file1")
    file2 = os.path.abspath("./test_upload_dir/file2")

    local = LocalFS()
    local.mkdirs(src_dir)
    try:
        local.touch(file1)
        local.touch(file2)
        fs.upload(src_dir, dst_dir)

        self.assertTrue(fs.is_exist(dst_dir))
    finally:
        # Clean up in all cases so a failing run does not leave
        # directories behind for later tests.
        try:
            if fs.is_exist(dst_dir):
                fs.delete(dst_dir)
        finally:
            local.delete(src_dir)