def Init(self):
    """Move the session from OPEN to RUNNING and set up its runtime state.

    Initializes the environment if it is not yet initialized and completes
    ``self.config_proto``. Then, when eager execution is NOT enabled:
    initializes the lazy global session, compiles every function in
    ``self.job_name2function_desc_``, and starts the lazy global session.
    Otherwise it creates ``self.eager_config_proto_ctx_`` from the config
    proto.

    Raises:
        AssertionError: if the session status is not ``SessionStatus.OPEN``,
            or (lazy mode) if no job function has been registered. Note that
            asserts are stripped when Python runs with ``-O``.

    Returns:
        self
    """
    assert self.status_ is SessionStatus.OPEN
    # Status is flipped before any setup work runs.
    self.status_ = SessionStatus.RUNNING
    if not oneflow._oneflow_internal.IsEnvInited():
        oneflow.env.init()
    _TryCompleteConfigProto(self.config_proto)
    self.resource_ = self.config_proto.resource
    if not oneflow._oneflow_internal.EagerExecutionEnabled():
        # Lazy mode: the global session is initialized, every registered job
        # is compiled, and only then is the session started.
        c_api_util.InitLazyGlobalSession(self.config_proto)
        for job_name, func_desc in self.job_name2function_desc_.items():
            compiler.Compile(self, func_desc, self.config_proto)
            # NOTE(review): placed so the module-name set is reset after each
            # compiled job; confirm it should not instead run once after the loop.
            self.existed_module_names_ = set()
        self.job_name2var_name2var_blob_ = dict()
        # At least one job function must have been registered (and compiled).
        assert len(self.job_name2function_desc_.items()) > 0
        oneflow._oneflow_internal.StartLazyGlobalSession()
        self.inter_user_job_info_ = c_api_util.GetInterUserJobInfo()
        # Get latest op_attr and job_name after compiler.Compile
        self.UpdateInfo4InterfaceOp()
        if not config_util.api_legacy_model_io_enabled():
            check_point_v2.Init()
    else:
        # Eager mode: only a logical config-proto context is created.
        self.eager_config_proto_ctx_ = oneflow._oneflow_internal.LogicalConfigProtoContext(
            str(self.config_proto)
        )
    return self
def __init__(self) -> None:
    """Create a checkpoint handle.

    When legacy model IO is disabled, print a notice to stdout explaining
    that ``flow.train.CheckPoint`` is deprecated and which replacement API
    corresponds to each of its methods. With legacy model IO enabled,
    nothing is printed.
    """
    if config_util.api_legacy_model_io_enabled():
        return
    # The \033[...m sequences are ANSI SGR escapes (1 = bold, 92 = bright
    # green, 0 = reset) used to highlight the replacement calls.
    deprecation_notice = (
        "\033[1mWARNING: 'flow.train.CheckPoint' is deprecated. Please use the new API:\033[0m\n"
        "flow.train.CheckPoint().save(path) => \033[1m\033[92mflow.checkpoint.save(path)\033[0m\n"
        "flow.train.CheckPoint().load(path) => \033[1m\033[92mflow.load_variables(flow.checkpoint.get(path))\033[0m\n"
        "flow.train.CheckPoint().init() is not needed any more.\n"
    )
    print(deprecation_notice)
def load(self, path: str) -> None:
    r"""Restore model variables from the checkpoint stored at `path`.

    With legacy model IO disabled, the checkpoint is read with
    ``check_point_v2.GetCheckpoint`` and applied with
    ``check_point_v2.LoadVariables``; no type check is done on `path`.
    With legacy model IO enabled, `path` must be exactly a ``str`` and the
    loader chosen by ``enable_if.unique`` from ``lazy_checkpoint_load`` /
    ``eager_checkpoint_load`` is called with it.

    Args:
        path: A `string` of path to load checkpoint.
    """
    if config_util.api_legacy_model_io_enabled():
        assert type(path) is str
        # Presumably picks whichever of the lazy/eager loaders is enabled in
        # the current execution mode — see enable_if.unique.
        legacy_loader = enable_if.unique([lazy_checkpoint_load, eager_checkpoint_load])
        legacy_loader(path)
    else:
        checkpoint = check_point_v2.GetCheckpoint(path)
        check_point_v2.LoadVariables(checkpoint)
def save(self, path: str) -> None:
    r"""Write the current model variables to a checkpoint at `path`.

    With legacy model IO disabled, this delegates to
    ``check_point_v2.SaveVarDict(path)``; no type check is done on `path`.
    With legacy model IO enabled, `path` must be exactly a ``str`` and the
    saver chosen by ``enable_if.unique`` from ``lazy_checkpoint_save`` /
    ``eager_checkpoint_save`` is called with it.

    Args:
        path: A `string` of path to save checkpoint.
    """
    if config_util.api_legacy_model_io_enabled():
        assert type(path) is str
        # Presumably picks whichever of the lazy/eager savers is enabled in
        # the current execution mode — see enable_if.unique.
        legacy_saver = enable_if.unique([lazy_checkpoint_save, eager_checkpoint_save])
        legacy_saver(path)
    else:
        check_point_v2.SaveVarDict(path)
def init(self) -> None:
    r"""Initialize model variables using the default initializer of their op or Job.

    This only does work when legacy model IO is enabled. In that case it
    calls the initializer chosen by ``enable_if.unique`` from
    ``lazy_checkpoint_init`` / ``eager_checkpoint_init``. Otherwise it is a
    no-op.
    """
    if config_util.api_legacy_model_io_enabled():
        legacy_initializer = enable_if.unique([lazy_checkpoint_init, eager_checkpoint_init])
        legacy_initializer()