Esempio n. 1
0
 def Init(self):
     """Move the session from OPEN to RUNNING and set up its runtime.

     Initializes the environment if it has not been initialized yet,
     completes ``self.config_proto`` and caches its ``resource`` field.
     When eager execution is disabled (lazy mode), every function
     description in ``job_name2function_desc_`` is compiled, the lazy
     global session is started and interface-op info is refreshed.
     When eager execution is enabled, only a logical config-proto context
     is created.

     Returns:
         self, so the call can be chained.
     """
     assert self.status_ is SessionStatus.OPEN
     self.status_ = SessionStatus.RUNNING
     if not oneflow._oneflow_internal.IsEnvInited():
         oneflow.env.init()
     _TryCompleteConfigProto(self.config_proto)
     self.resource_ = self.config_proto.resource

     if oneflow._oneflow_internal.EagerExecutionEnabled():
         # Eager mode: nothing is compiled up front; the config proto is only
         # wrapped in a logical context.
         self.eager_config_proto_ctx_ = oneflow._oneflow_internal.LogicalConfigProtoContext(
             str(self.config_proto)
         )
         return self

     # Lazy mode: compile every job function before starting the session.
     c_api_util.InitLazyGlobalSession(self.config_proto)
     for func_desc in self.job_name2function_desc_.values():
         compiler.Compile(self, func_desc, self.config_proto)
         # Reset after every Compile call (presumably so module names seen
         # while compiling one job do not leak into the next — confirm).
         self.existed_module_names_ = set()
     self.job_name2var_name2var_blob_ = dict()
     assert len(self.job_name2function_desc_) > 0, "no job function to compile"
     oneflow._oneflow_internal.StartLazyGlobalSession()
     self.inter_user_job_info_ = c_api_util.GetInterUserJobInfo()
     # Get latest op_attr and job_name after compiler.Compile
     self.UpdateInfo4InterfaceOp()
     if not config_util.api_legacy_model_io_enabled():
         check_point_v2.Init()
     return self
Esempio n. 2
0
 def __init__(self) -> None:
     """Print a deprecation notice for ``flow.train.CheckPoint``.

     The notice is printed to stdout, with ANSI bold/green formatting, only
     when ``config_util.api_legacy_model_io_enabled()`` is false. It maps
     each legacy method to its replacement API.
     """
     if config_util.api_legacy_model_io_enabled():
         return
     notice = (
         "\033[1mWARNING: 'flow.train.CheckPoint' is deprecated. Please use the new API:\033[0m\n"
         "flow.train.CheckPoint().save(path) => \033[1m\033[92mflow.checkpoint.save(path)\033[0m\n"
         "flow.train.CheckPoint().load(path) => \033[1m\033[92mflow.load_variables(flow.checkpoint.get(path))\033[0m\n"
         "flow.train.CheckPoint().init() is not needed any more.\n"
     )
     print(notice)
Esempio n. 3
0
    def load(self, path: str) -> None:
        r"""Load a checkpoint from `path` and initialize models.

        When the legacy model-IO API is disabled, the call is delegated to
        ``check_point_v2.LoadVariables(check_point_v2.GetCheckpoint(path))``
        and no type check is applied to ``path``. Otherwise ``path`` must be
        a ``str``, and the load is dispatched via ``enable_if.unique`` over
        the lazy and eager loaders (presumably whichever one's enable
        condition currently holds).

        Args:
            path: A `string` of path to load checkpoint.

        Raises:
            AssertionError: on the legacy path, if ``path`` is not a ``str``.
        """
        if not config_util.api_legacy_model_io_enabled():
            check_point_v2.LoadVariables(check_point_v2.GetCheckpoint(path))
            return
        # isinstance rather than an exact type() match, so str subclasses pass.
        assert isinstance(path, str), "path must be a str, got {}".format(
            type(path).__name__
        )
        enable_if.unique([lazy_checkpoint_load, eager_checkpoint_load])(path)
Esempio n. 4
0
    def save(self, path: str) -> None:
        r"""Save a checkpoint to `path`.

        When the legacy model-IO API is disabled, the call is delegated to
        ``check_point_v2.SaveVarDict(path)`` and no type check is applied to
        ``path``. Otherwise ``path`` must be a ``str``, and the save is
        dispatched via ``enable_if.unique`` over the lazy and eager savers
        (presumably whichever one's enable condition currently holds).

        Args:
            path: A `string` of path to save checkpoint.

        Raises:
            AssertionError: on the legacy path, if ``path`` is not a ``str``.
        """
        if not config_util.api_legacy_model_io_enabled():
            check_point_v2.SaveVarDict(path)
            return
        # isinstance rather than an exact type() match, so str subclasses pass.
        assert isinstance(path, str), "path must be a str, got {}".format(
            type(path).__name__
        )
        enable_if.unique([lazy_checkpoint_save, eager_checkpoint_save])(path)
Esempio n. 5
0
 def init(self) -> None:
     r"""Initialize models by the default initializer of each op or Job.

     Does work only when the legacy model-IO API is enabled. Otherwise the
     call returns immediately without doing anything.
     """
     if config_util.api_legacy_model_io_enabled():
         initialize = enable_if.unique([lazy_checkpoint_init, eager_checkpoint_init])
         initialize()