def clean_redundant_check_points(self,
                                 root_path,
                                 fs=None,
                                 checkpoint_num=1):
    """
    Remove old checkpoint directories under root_path.

    A directory entry is treated as a checkpoint when its name has the
    form "<self._checkoint_prefix>.<integer>". Every such entry whose
    number is <= (last checkpoint number - checkpoint_num) is removed
    with fs.rmr; the remaining ones are left untouched.

    Args:
        root_path: directory that contains the checkpoint directories.
        fs: filesystem object providing list_dirs() and rmr(). When None,
            a new LocalFS() is created for this call.
        checkpoint_num: how many of the highest-numbered checkpoints to
            keep. Values below 1 are treated as 1.

    Returns:
        None. Does nothing when self._get_last_checkpoint_no() reports a
        negative number.
    """
    # Build the default per call instead of sharing one instance that was
    # created when the function was defined.
    if fs is None:
        fs = LocalFS()

    max_no = self._get_last_checkpoint_no(root_path, fs)
    if max_no < 0:
        return

    if checkpoint_num < 1:
        checkpoint_num = 1
    # Checkpoints numbered at or below this value are removed.
    last_removable_no = max_no - checkpoint_num

    for dir_name in fs.list_dirs(root_path):
        parts = dir_name.split(".")
        if len(parts) != 2 or parts[0] != self._checkoint_prefix:
            continue

        try:
            no = int(parts[1])
        except ValueError:
            # Suffix is not a checkpoint number; not ours to clean up.
            continue

        if no > last_removable_no:
            continue

        # Remove the entry that was actually listed rather than a name
        # rebuilt from the parsed integer (they differ for e.g. "x.007").
        stale_path = "{}/{}".format(root_path, dir_name)
        try:
            fs.rmr(stale_path)
        except Exception as e:
            # Best effort: report the failure and keep cleaning the rest.
            print(e)
def save_check_point(self,
                     executor,
                     path,
                     train_status,
                     main_program=None,
                     fs=None,
                     local_cache_path=".cache",
                     remain_all_checkpoint=True):
    """
    Save persistables and the train status as a new numbered checkpoint.

    The checkpoint is written to "<path>/<prefix>.<last_no + 1>.tmp" and
    then moved to the same name without ".tmp". When
    fs.need_upload_download() is true, the data is first written to a
    local cache directory under local_cache_path and uploaded to the
    ".tmp" location before the move.

    Args:
        executor: passed through to self.save_persistables.
        path: root directory of the checkpoints; created if fs.stat(path)
            is false.
        train_status: passed through to self._save_train_status.
        main_program: program to save; self._transpiled_program when None.
        fs: filesystem object for the checkpoint root. When None, a new
            LocalFS() is created for this call.
        local_cache_path: local directory used for the upload cache.
        remain_all_checkpoint: when False, older checkpoints are removed
            through self.clean_redundant_check_points after saving.
    """
    # Build the default per call instead of sharing one instance that was
    # created when the function was defined.
    if fs is None:
        fs = LocalFS()

    if main_program is None:
        main_program = self._transpiled_program

    if not fs.stat(path):
        fs.mkdir(path)

    max_no = self._get_last_checkpoint_no(path, fs=fs)
    if max_no < 0:
        max_no = -1
    next_no = max_no + 1

    real_path = "{}/{}.{}".format(path, self._checkoint_prefix, next_no)
    tmp_path = "{}.tmp".format(real_path)
    saved_path = tmp_path

    # Ask once so the cache setup and the upload step always agree.
    use_cache = fs.need_upload_download()
    cache_path = None
    if use_cache:
        local_fs = LocalFS()
        cache_path = "{}/{}.{}.saved_cache".format(
            local_cache_path, self._checkoint_prefix, next_no)
        if not local_fs.stat(cache_path):
            local_fs.mkdir(cache_path)
        saved_path = cache_path

    self.save_persistables(
        executor=executor,
        dirname=saved_path,
        main_program=main_program,
        filename=self._param_file_name)
    self._save_train_status(path=saved_path, train_status=train_status)

    if use_cache:
        fs.delete(tmp_path)
        fs.upload(cache_path, tmp_path)
    # Publish under the final name only after everything has been written.
    fs.mv(tmp_path, real_path)

    if not remain_all_checkpoint:
        # Clean up on the same filesystem the checkpoint was written to.
        self.clean_redundant_check_points(path, fs=fs)
def load_check_point(self,
                     executor,
                     path,
                     trainer_id,
                     main_program=None,
                     fs=None,
                     local_cache_path=".cache",
                     ignore_empty=True):
    """
    Load persistables and the train status from the latest checkpoint.

    The latest checkpoint is "<path>/<prefix>.<last_no>". When
    fs.need_upload_download() is true, it is first downloaded into a
    per-trainer cache directory under local_cache_path (any existing
    cache directory of the same name is deleted beforehand) and loaded
    from there.

    Args:
        executor: passed through to io.load_persistables.
        path: root directory of the checkpoints.
        trainer_id: used to build the per-trainer cache directory name.
        main_program: program to load into; self._transpiled_program when
            None.
        fs: filesystem object for the checkpoint root. When None, a new
            LocalFS() is created for this call.
        local_cache_path: local directory used for the download cache.
        ignore_empty: when True, a missing checkpoint yields None instead
            of an error.

    Returns:
        The result of self._load_train_status for the loaded directory,
        or None when no checkpoint was found and ignore_empty is True.

    Raises:
        AssertionError: no checkpoint was found and ignore_empty is False.
    """
    # Build the default per call instead of sharing one instance that was
    # created when the function was defined.
    if fs is None:
        fs = LocalFS()

    max_no = self._get_last_checkpoint_no(path, fs)
    if max_no < 0:
        if not ignore_empty:
            # Raised explicitly so the check survives `python -O`, which
            # strips assert statements.
            raise AssertionError("Can't find checkpoint")
        return None

    real_path = "{}/{}.{}".format(path, self._checkoint_prefix, max_no)
    load_path = real_path

    if fs.need_upload_download():
        local_fs = LocalFS()
        cache_path = "{}/{}.{}.load_cache.{}".format(
            local_cache_path, self._checkoint_prefix, max_no, trainer_id)
        # Delete any existing cache directory before downloading into it.
        if local_fs.stat(cache_path):
            local_fs.delete(cache_path)
        fs.download(real_path, cache_path)
        load_path = cache_path

    if main_program is None:
        main_program = self._transpiled_program

    io.load_persistables(
        executor=executor,
        dirname=load_path,
        main_program=main_program,
        filename=self._param_file_name)

    return self._load_train_status(load_path)
def test_local_check_point(self):
    """Call self._test_check_point with a LocalFS and "./my_paddle_model"."""
    local_fs = LocalFS()
    model_dir = "./my_paddle_model"
    self._test_check_point(local_fs, model_dir)