Example #1
0
 def __init__(self):
     """Initialize a session in the OPEN state with empty bookkeeping tables.

     The session id is obtained from ``oneflow_api.NewSessionId()``. Nearly
     every other attribute starts out as an empty container or ``None``; the
     two ``_Update*DefaultVal()`` calls then populate the default-value maps.
     """
     self.id_ = oneflow_api.NewSessionId()
     self.job_name2function_desc_ = {}
     self.status_ = SessionStatus.OPEN
     # NOTE(review): presumably paired with running_job_cnt_ to wait for
     # in-flight jobs; confirm against the methods that use cond_var_.
     self.cond_var_ = threading.Condition()
     self.running_job_cnt_ = 0
     self.inter_user_job_info_ = None
     self.uuid2watch_handler_ = {}
     self.config_proto_ = None
     self.resource_ = None
     self.is_mirrored_strategy_enabled_stack_ = []
     self.job_name2var_name2var_blob_ = {}
     self.job_name2module_name2module_ = {}
     self.existed_module_names_ = set()
     self.var_name2var_blob_ = {}
     # The parallel desc symbol id stored in an op attribute is not always
     # correct for lazy ops, because the parallel conf may be updated by some
     # passes (e.g. optimizer_placement_optimization_pass). Presumably this is
     # why lazy_interface_op_name2parallel_conf_ is tracked separately below.
     self.interface_op_name2op_attr_ = {}
     self.interface_op_name2job_name_ = {}
     self.lazy_interface_op_name2parallel_conf_ = {}
     self.op_name2lazy_blob_cache_ = {}
     self.job_name2name_scope_stack_ = {}
     self.eager_global_function_desc_stack_ = []
     # Order matters: the update call presumably fills the dict created on the
     # line just before it, so keep each assignment ahead of its update call.
     self.function_flag_name2default_val_ = {}
     self._UpdateFunctionFlagName2DefaultVal()
     self.scope_attr_name2default_val_ = {}
     self._UpdateScopeAttrName2DefaultVal()
     self.instruction_list_ = instr_cfg.InstructionListProto()
     self.eager_symbol_list_ = eager_symbol_util.EagerSymbolList()
     self.backward_blob_register_ = blob_register_util.BlobRegister()
     self.snapshot_mgr_ = SnapshotManager()
     self.eager_config_proto_ctx_ = None
Example #2
0
 def __init__(self):
     """Initialize a session in the OPEN state with empty bookkeeping tables.

     The session id is obtained from ``oneflow_api.NewSessionId()``. Most
     attributes start out as an empty container or ``None``; the
     ``_UpdateFunctionFlagName2DefaultVal()`` call then populates the
     function-flag default-value map.
     """
     self.id_ = oneflow_api.NewSessionId()
     self.job_name2function_desc_ = {}
     self.status_ = SessionStatus.OPEN
     # NOTE(review): presumably paired with running_job_cnt_ to wait for
     # in-flight jobs; confirm against the methods that use cond_var_.
     self.cond_var_ = threading.Condition()
     self.running_job_cnt_ = 0
     self.inter_user_job_info_ = None
     self.uuid2watch_handler_ = {}
     self.config_proto_ = None
     self.resource_ = None
     self.is_mirrored_strategy_enabled_stack_ = []
     # Must exist before _UpdateFunctionFlagName2DefaultVal() runs below, which
     # presumably fills it in.
     self.function_flag_name2default_val_ = {}
     self.job_name2var_name2var_blob_ = {}
     self.job_name2module_name2module_ = {}
     self.existed_module_names_ = set()
     self.var_name2var_blob_ = {}
     self.interface_op_name2op_attr_ = {}
     self.interface_op_name2job_name_ = {}
     self.job_name2name_scope_stack_ = {}
     self.eager_global_function_desc_stack_ = []
     self._UpdateFunctionFlagName2DefaultVal()
     self.instruction_list_ = instr_util.InstructionListProto()
     self.eager_symbol_list_ = eager_symbol_util.EagerSymbolList()
     self.backward_blob_register_ = blob_register_util.BlobRegister()
     self.snapshot_mgr_ = SnapshotManager()
     self.eager_config_proto_ctx_ = None
Example #3
0
def clear_default_session():
    """Replace the current default session with a brand-new one.

    First asks ``session_ctx`` to close any existing default session, then
    installs a new ``Session`` whose id comes from
    ``oneflow_api.NewSessionId()``.
    """
    session_ctx.TryCloseDefaultSession()
    fresh_session_id = oneflow_api.NewSessionId()
    replacement = Session(fresh_session_id)
    session_ctx.OpenDefaultSession(replacement)
Example #4
0
    func = enable_if.unique([sync_default_session])
    return func()


@enable_if.condition(hob.in_normal_mode)
def sync_default_session() -> None:
    """Call ``Sync()`` on the current default session.

    Registered through ``enable_if`` under the ``hob.in_normal_mode``
    condition.
    """
    current = session_ctx.GetDefaultSession()
    current.Sync()


def _TryCompleteConfigProto(config_proto):
    if config_proto.resource.machine_num == 0:
        config_proto.resource.machine_num = len(
            env_util.default_env_proto.machine)


def _GetDefaultConfigProto():
    """Build a ``ConfigProto`` populated with default settings.

    - ``resource.machine_num`` is set to 0 (``_TryCompleteConfigProto``
      replaces a 0 with the machine count from ``env_util``).
    - If ``oneflow_api.flags.with_cuda()`` is truthy, one GPU device is
      requested. Otherwise one CPU device is requested and the GPU count is
      set to 0.
    - Both data and snapshot IO use the local filesystem config.
    - ``session_id`` is taken from the current default session.

    Returns:
        The newly created ``ConfigProto``.
    """
    proto = job_set_util.ConfigProto()
    resource = proto.resource
    resource.machine_num = 0
    if not oneflow_api.flags.with_cuda():
        resource.cpu_device_num = 1
        resource.gpu_device_num = 0
    else:
        resource.gpu_device_num = 1
    io_conf = proto.io_conf
    io_conf.data_fs_conf.localfs_conf.SetInParent()
    io_conf.snapshot_fs_conf.localfs_conf.SetInParent()
    proto.session_id = session_ctx.GetDefaultSession().id
    return proto


# Runs at import time: installs a default Session (id from
# oneflow_api.NewSessionId()) so a default session exists as soon as this
# module has been loaded.
session_ctx.OpenDefaultSession(Session(oneflow_api.NewSessionId()))