示例#1
0
def api_eager_execution_enabled() -> bool:
    """Report whether eager execution mode is currently enabled.

    Thin wrapper that forwards directly to
    ``c_api_util.EagerExecutionEnabled()``.

    Returns:
        bool: True when eager execution is enabled, False otherwise (the
        value reported by ``c_api_util.EagerExecutionEnabled``).
    """
    eager_enabled = c_api_util.EagerExecutionEnabled()
    return eager_enabled
示例#2
0
def is_trainable(ctx):
    """Look up the function descriptor of the current global function.

    In lazy (non-eager) mode, the current job name is read from the
    job-build context and used to fetch the descriptor through the default
    session's ``GetFunctionDesc``. In eager mode, the default session's
    ``CurrentEagerGlobalFunctionDesc()`` is returned instead.

    NOTE(review): despite the ``is_`` prefix, this returns the descriptor
    object itself, not a bool. Presumably a decorator or caller not visible
    here converts it into a truth value; confirm.

    Args:
        ctx: context object; must satisfy ``in_global_mode(ctx)``.

    Returns:
        The function descriptor obtained from the default session.
    """
    # Kept as an assert to preserve the original contract; note that
    # asserts are stripped when Python runs with -O.
    assert in_global_mode(ctx)
    if not c_api_util.EagerExecutionEnabled():
        current_job = c_api_util.JobBuildAndInferCtx_GetCurrentJobName()
        return session_ctx.GetDefaultSession().GetFunctionDesc(current_job)
    return session_ctx.GetDefaultSession().CurrentEagerGlobalFunctionDesc()
示例#3
0
 def Init(self):
     """Move the session from OPEN to RUNNING and initialize it.

     Initializes the environment if needed, completes ``self.config_proto``
     and records its ``resource`` on ``self.resource_``. In lazy (non-eager)
     mode this initializes the lazy global session, compiles every
     function descriptor in ``self.job_name2function_desc_``, starts the
     lazy global session, fetches inter-user job info and refreshes the
     lazy interface-op info. When ``config_util.api_legacy_model_io_enabled()``
     is false, it also calls ``check_point_v2.Init()``. In eager mode it only
     builds a ``LogicalConfigProtoContext`` from the serialized config proto.

     Returns:
         self, so the call can be chained.
     """
     assert self.status_ is SessionStatus.OPEN
     self.status_ = SessionStatus.RUNNING
     if not c_api_util.IsEnvInited():
         oneflow.env.init()
     _TryCompleteConfigProto(self.config_proto)
     self.resource_ = self.config_proto.resource
     if not c_api_util.EagerExecutionEnabled():
         c_api_util.InitLazyGlobalSession(self.config_proto)
         # Only the descriptors are needed here; the job-name key was unused.
         for func_desc in self.job_name2function_desc_.values():
             compiler.Compile(self, func_desc, self.config_proto)
             # Reset after every compile so the next job starts with an empty
             # set. NOTE(review): presumably per-job module-name bookkeeping
             # consumed by compiler.Compile; confirm.
             self.existed_module_names_ = set()
         self.job_name2var_name2var_blob_ = dict()
         # Lazy mode requires at least one function descriptor to have been
         # present (and therefore compiled above).
         assert len(self.job_name2function_desc_) > 0
         c_api_util.StartLazyGlobalSession()
         self.inter_user_job_info_ = c_api_util.GetInterUserJobInfo()
         # Get latest op_attr and job_name after compiler.Compile
         self._UpdateInfo4LazyInterfaceOp()
         if not config_util.api_legacy_model_io_enabled():
             check_point_v2.Init()
     else:
         self.eager_config_proto_ctx_ = oneflow_api.LogicalConfigProtoContext(
             str(self.config_proto)
         )
     return self
示例#4
0
 def Init(self):
     """Move the session from OPEN to RUNNING and initialize it.

     Initializes the environment if needed, completes ``self.config_proto``
     and initializes the global session with it. In lazy (non-eager) mode
     it then compiles every function descriptor in
     ``self.job_name2function_desc_``, starts the global session and fetches
     inter-user job info. In eager mode nothing further happens here.

     Returns:
         self, so the call can be chained.
     """
     assert self.status_ is SessionStatus.OPEN
     self.status_ = SessionStatus.RUNNING
     if not c_api_util.IsEnvInited():
         oneflow.env.init()
     _TryCompleteConfigProto(self.config_proto)
     c_api_util.InitGlobalSession(self.config_proto)
     if not c_api_util.EagerExecutionEnabled():
         # Only the descriptors are needed here; the job-name key was unused.
         for func_desc in self.job_name2function_desc_.values():
             compiler.Compile(self, func_desc, self.config_proto)
             # Reset after every compile so the next job starts with an empty
             # set. NOTE(review): presumably per-job module-name bookkeeping
             # consumed by compiler.Compile; confirm.
             self.existed_module_names_ = set()
         self.job_name2var_name2var_blob_ = dict()
         # Lazy mode requires at least one function descriptor to have been
         # present (and therefore compiled above).
         assert len(self.job_name2function_desc_) > 0
         c_api_util.StartGlobalSession()
         self.inter_user_job_info_ = c_api_util.GetInterUserJobInfo()
     return self
示例#5
0
def eager_execution_enabled(ctx):
    """Return whatever ``c_api_util.EagerExecutionEnabled()`` reports.

    Args:
        ctx: accepted but never read by this function. NOTE(review):
            presumably required by the caller's predicate signature; confirm.

    Returns:
        The eager-execution flag from ``c_api_util.EagerExecutionEnabled``.
    """
    is_eager = c_api_util.EagerExecutionEnabled()
    return is_eager
示例#6
0
def _GetInterfaceBlobObject(builder, op_name):
    """Return the blob object backing the interface op ``op_name``.

    In lazy (non-eager) mode a lazy reference blob object is created via
    ``builder.MakeLazyRefBlobObject(op_name)``. In eager mode the
    ``blob_object`` of the default session's ``var_name2var_blob[op_name]``
    entry is returned; a missing ``op_name`` propagates the lookup error.

    Args:
        builder: object providing ``MakeLazyRefBlobObject``.
        op_name: name of the interface op / variable to resolve.

    Returns:
        The resolved blob object.
    """
    if not c_api_util.EagerExecutionEnabled():
        return builder.MakeLazyRefBlobObject(op_name)
    var_blob = session_ctx.GetDefaultSession().var_name2var_blob[op_name]
    return var_blob.blob_object