def _MakeModelInitJobFunc():
    """Build the job instance for the global model-init job.

    The job name is taken from the default session's
    ``inter_user_job_info.global_model_init_job_name`` and converted with
    ``str``. Both callbacks handed to ``MakeJobInstance`` do nothing.

    Returns:
        The object returned by ``job_instance.MakeJobInstance``.
    """

    def _ignore_blob(blob):
        # The blob handed to the push callback is left untouched.
        return None

    def _noop():
        return None

    job_info = session_ctx.GetDefaultSession().inter_user_job_info
    job_name = str(job_info.global_model_init_job_name)
    return job_instance.MakeJobInstance(
        job_name, push_cb=_ignore_blob, finish_cb=_noop
    )
def _run_load_checkpoint_job(self):
    """Create and run the global model-load job for the configured checkpoint.

    The push callback copies ``self.checkpoint_path_`` (ASCII-encoded, viewed
    as an ``int8`` array) into the blob it receives. The resulting job
    instance is passed to ``self._run_job``.

    Raises:
        ValueError: If ``self.checkpoint_path_`` is ``None``.
    """
    if self.checkpoint_path_ is None:
        raise ValueError("checkpoint path not set")

    def _push_checkpoint_path(ofblob):
        # The path is read from ``self`` when the callback runs, not when it
        # is created.
        path_bytes = self.checkpoint_path_.encode("ascii")
        ofblob.CopyFromNdarray(np.frombuffer(path_bytes, dtype=np.int8))

    job_name = self.inter_user_job_info_.global_model_load_job_name
    load_job = job_instance_util.MakeJobInstance(
        job_name, push_cb=_push_checkpoint_path
    )
    self._run_job(load_job)
def _MakeModelSaveJobFunc(path):
    """Build the job instance for the global model-save job.

    The push callback copies ``path`` (ASCII-encoded, viewed as an ``int8``
    array) into the blob it receives; the finish callback does nothing. The
    job name is taken from the default session's
    ``inter_user_job_info.global_model_save_job_name`` and converted with
    ``str``.

    Args:
        path: Text path to encode; must support ``.encode("ascii")``.

    Returns:
        The object returned by ``job_instance.MakeJobInstance``.
    """

    def _push_path(blob):
        encoded_path = path.encode("ascii")
        blob.CopyFromNdarray(np.frombuffer(encoded_path, dtype=np.int8))

    def _noop():
        return None

    job_info = session_ctx.GetDefaultSession().inter_user_job_info
    job_name = str(job_info.global_model_save_job_name)
    return job_instance.MakeJobInstance(
        job_name, push_cb=_push_path, finish_cb=_noop
    )