def load_train_snapshot(executor, program, file_path):
    """Load persistable variables from `file_path` and return the training
    info stored in the companion `<file_path>.json` file, if it exists."""
    assert os.path.exists(file_path), "[%s] cannot be found." % file_path
    # Restore all persistable variables of `program` from the snapshot directory.
    io.load_persistables(
        executor=executor, dirname=file_path, main_program=program)
    # A sidecar JSON file, if present, carries the saved training metadata.
    if os.path.exists(file_path + ".json"):
        info = file_utils.read_file(
            file_name=file_path + ".json", file_format="json")
        return info
    return False
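
# Illustrative usage sketch (not part of the original module): resuming
# training from a previously written snapshot. The executor, program, the
# snapshot path "./train_snapshot", and the "epoch" key are assumptions
# made for the example only.
def _example_resume_from_snapshot():
    import paddle.fluid as fluid
    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())
    # load_train_snapshot restores persistables and returns the JSON info
    # dict when the sidecar file exists, otherwise False.
    info = load_train_snapshot(exe, fluid.default_main_program(),
                               "./train_snapshot")
    start_epoch = info.get("epoch", 0) if info else 0
    return start_epoch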
def init_server(self, model_dir=None):
    """
    `init_server` prepares the parameter server before it starts: first, it
    runs the executor to initialize the startup program; second, if
    `model_dir` is not empty, it loads parameters from that directory for
    incremental training.

    Args:
        model_dir(str): The directory to load parameters from. Default None.

    Returns:
        None
    """
    if not self.startup_program:
        raise ValueError(
            "startup_program is None, need invoke DistributedOptimizer.minimize first"
        )

    self._executor.run(self.startup_program)

    if model_dir:
        if not os.path.isdir(model_dir):
            raise ValueError("There is no directory named '%s'" % model_dir)

        io.load_persistables(self._executor, model_dir, self.startup_program)
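
# Minimal sketch of the parameter-server startup path that reaches
# `init_server` (an assumption-laden example, not the project's own script:
# it presumes the transpiler-based fleet API, a PaddleCloud role maker, and
# uses a placeholder model and save directory).
def _example_start_pserver():
    import paddle.fluid as fluid
    from paddle.fluid.incubate.fleet.base import role_maker
    from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet

    x = fluid.data(name="x", shape=[None, 13], dtype="float32")
    y = fluid.data(name="y", shape=[None, 1], dtype="float32")
    pred = fluid.layers.fc(input=x, size=1)
    avg_cost = fluid.layers.mean(fluid.layers.square_error_cost(pred, y))

    fleet.init(role_maker.PaddleCloudRoleMaker())
    optimizer = fleet.distributed_optimizer(fluid.optimizer.SGD(0.01))
    # minimize() builds the startup program; init_server() raises otherwise.
    optimizer.minimize(avg_cost)

    if fleet.is_server():
        # Passing a directory enables incremental training from saved params.
        fleet.init_server("./saved_model")
        fleet.run_server()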
def load_check_point(self,
                     executor,
                     path,
                     trainer_id,
                     main_program=None,
                     fs=LocalFS(),
                     local_cache_path=".cache",
                     ignore_empty=True):
    """
    Load persistables and the current epoch number from `path`.
    """
    # Find the newest checkpoint number under `path`.
    max_no = self._get_last_checkpoint_no(path, fs)

    if not ignore_empty:
        assert max_no >= 0, "Can't find checkpoint"

    if max_no < 0:
        return None

    local_fs = LocalFS()
    if fs.need_upload_download():
        # Remote checkpoints are staged through a per-trainer local cache.
        cache_path = "{}/{}.{}.load_cache.{}".format(
            local_cache_path, self._checkoint_prefix, max_no, trainer_id)

        if local_fs.stat(cache_path):
            local_fs.delete(cache_path)

    real_path = "{}/{}.{}".format(path, self._checkoint_prefix, max_no)
    load_path = real_path
    if fs.need_upload_download():
        fs.download(real_path, cache_path)
        load_path = cache_path

    if main_program is None:
        main_program = self._transpiled_program

    io.load_persistables(
        executor=executor,
        dirname=load_path,
        main_program=main_program,
        filename=self._param_file_name)

    return self._load_train_status(load_path)
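
# Hedged usage sketch for `load_check_point` (not from the original source):
# `saver` stands in for whatever object this method is defined on, and the
# checkpoint path "./fleet_checkpoints" is hypothetical.
def _example_load_latest_checkpoint(saver, exe, main_program):
    # With a LocalFS there is no download step; the newest checkpoint under
    # the path is located and its persistables restored into `main_program`.
    train_status = saver.load_check_point(
        executor=exe,
        path="./fleet_checkpoints",
        trainer_id=0,
        main_program=main_program,
        fs=LocalFS(),
        ignore_empty=True)
    # `train_status` is None when no checkpoint exists yet.
    return train_status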