def _GetInterfaceBlobObject(builder, op_name):
    """Return the BlobObject backing interface op `op_name`.

    In eager mode the blob object comes straight from the session's
    variable stash; otherwise a lazy-ref blob object is built from the
    op's inferred attribute and its parallel configuration.
    """
    # Fix: the session was fetched twice; one lookup suffices.
    sess = session_ctx.GetDefaultSession()
    if oneflow._oneflow_internal.EagerExecutionEnabled():
        return sess.var_name2var_blob[op_name].blob_object
    op_attribute = sess.OpAttribute4InterfaceOpName(op_name)
    cfg_op_attribute = oneflow._oneflow_internal.deprecated.MakeOpAttributeByString(
        str(op_attribute))
    parallel_conf = sess.ParallelConf4LazyInterfaceOpName(op_name)
    if not isinstance(
            parallel_conf,
            oneflow._oneflow_internal.oneflow.core.job.placement.ParallelConf):
        # Convert the protobuf ParallelConf into its cfg counterpart.
        parallel_conf_cfg = placement_cfg.ParallelConf()
        parallel_conf_cfg.set_device_tag(parallel_conf.device_tag)
        for device_name in parallel_conf.device_name:
            parallel_conf_cfg.add_device_name(device_name)
        if parallel_conf.HasField("hierarchy"):
            hierarchy = shape_proto_cfg.ShapeProto()
            for dim in parallel_conf.hierarchy.dim:
                hierarchy.add_dim(dim)
            assert hierarchy.dim_size() > 0
            parallel_conf_cfg.mutable_hierarchy().CopyFrom(hierarchy)
        parallel_conf = parallel_conf_cfg
    blob_object = builder.MakeLazyRefBlobObject(op_name, cfg_op_attribute,
                                                parallel_conf)
    return blob_object
def is_trainable(ctx):
    """Return the FunctionDesc of the currently-running global function.

    Requires global mode. Eager execution reads the session's current
    eager function desc; lazy execution looks the desc up by job name.
    """
    assert in_global_mode(ctx)
    if oneflow._oneflow_internal.EagerExecutionEnabled():
        return session_ctx.GetDefaultSession().CurrentEagerGlobalFunctionDesc()
    cur_job_name = oneflow._oneflow_internal.JobBuildAndInferCtx_GetCurrentJobName()
    return session_ctx.GetDefaultSession().GetFunctionDesc(cur_job_name)
def HasAttr(self, attr_name):
    """Return True if `attr_name` is set on this job config.

    "flag_name2flag_value" itself is never reported as an attribute;
    flags present in the flag map win over the proto's `has_*` accessors.
    """
    # Fix: removed `name2default`, which was computed but never used.
    if attr_name == "flag_name2flag_value":
        return False
    if attr_name in self.job_config_proto.flag_name2flag_value():
        return True
    return getattr(self.job_config_proto, "has_" + attr_name)()
def GetLazyCurJobConfigProto():
    """Return the job config proto of the lazy job currently being built."""
    cur_job_name = oneflow._oneflow_internal.JobBuildAndInferCtx_GetCurrentJobName()
    desc = session_ctx.GetDefaultSession().GetLazyFunctionDesc(cur_job_name)
    assert desc is not None
    return desc.job_config_proto
def GetJobNameScopePrefix(job_name):
    """Return the "-"-joined name-scope prefix for `job_name`.

    Yields "" when the job has no (or an empty) name-scope stack;
    otherwise the joined scopes followed by a trailing "-".
    """
    stack = session_context.GetDefaultSession().job_name2name_scope_stack.get(
        job_name)
    if not stack:
        return ""
    return "-".join(stack) + "-"
def name_scope_stack_pop():
    """Pop and return the innermost name scope of the current job."""
    cur_job_name = oneflow._oneflow_internal.JobBuildAndInferCtx_GetCurrentJobName()
    stacks = session_context.GetDefaultSession().job_name2name_scope_stack
    assert cur_job_name in stacks
    assert len(stacks[cur_job_name]) > 0
    return stacks[cur_job_name].pop()
def name_scope_stack_push(name):
    """Push `name` onto the current job's name-scope stack, creating the
    stack on first use."""
    cur_job_name = oneflow._oneflow_internal.JobBuildAndInferCtx_GetCurrentJobName()
    stacks = session_context.GetDefaultSession().job_name2name_scope_stack
    stacks.setdefault(cur_job_name, []).append(name)
def name_scope(name: str) -> None:
    """Create a namespace. All variables within the namespace will have a prefix
    `[SCOPE NAME]-`. This is for convenience only and has no other effect on the
    system. Usage::

        with oneflow.compatible.single_client.scope.namespace("scope1"):
            ...
            with oneflow.compatible.single_client.scope.namespace("scope2"):
                ...

    Args:
        name: Name of this namespace
    """
    # Fix: removed the unused `sess = session_context.GetDefaultSession()`.
    assert isinstance(name, str)
    name_scope_stack_push(name)

    def BuildScope(old_scope, builder):
        return builder.BuildScopeWithNewScopeName(old_scope, name)

    try:
        with scope_util.ScopeContext(scope_util.MakeScope(BuildScope)):
            yield
    finally:
        # Always unwind the scope stack, even if the body raised.
        name_scope_stack_pop()
def GetEagerInterfaceBlob(op_name):
    # Fetch (or lazily create and cache) the eager remote blob for interface
    # op `op_name` in the default session.
    sync_default_session_if_normal()
    sess = session_ctx.GetDefaultSession()

    def CreateBlob():
        job_name = sess.JobName4InterfaceOpName(op_name)

        def Build(builder, Yield):
            # Wrap the op's blob object into an eager remote blob; mirrored
            # vs consistent is chosen from the blob's parallel attribute.
            blob_object = _GetInterfaceBlobObject(builder, op_name)
            lbi = lbi_util.LogicalBlobId()
            lbi.set_op_name(op_name)
            op_attribute = sess.OpAttribute4InterfaceOpName(op_name)
            # Interface ops are expected to have exactly one output.
            assert len(op_attribute.output_bns) == 1
            lbi.set_blob_name(op_attribute.output_bns[0])
            if blob_object.op_arg_parallel_attr.is_mirrored():
                remote_blob = oneflow._oneflow_internal.EagerMirroredBlob(
                    lbi, blob_object, blob_register, job_name)
            else:
                remote_blob = oneflow._oneflow_internal.EagerConsistentBlob(
                    lbi, blob_object, blob_register, job_name)
            Yield(remote_blob)

        def AsyncGetInterfaceBlob(Yield):
            oneflow._oneflow_internal.deprecated.LogicalRun(
                lambda builder: Build(builder, Yield))

        # Run the instruction and wait for the single yielded blob.
        blob = async_util.Await(1, AsyncGetInterfaceBlob)[0]
        return blob

    # Cached per op name so repeated lookups reuse the same blob.
    return sess.FindOrCreateLazyBlob(op_name, CreateBlob)
def GetInterfaceBlobValue(op_name):
    """Return the numpy value of interface op `op_name`'s output blob.

    Runs a logical instruction that wraps the op's blob object in an eager
    remote blob (mirrored or consistent per its parallel attribute) and
    reads its value.
    """
    sync_default_session_if_normal()
    sess = session_ctx.GetDefaultSession()
    job_name = sess.JobName4InterfaceOpName(op_name)

    def AsyncGetInterfaceBlobValue(Yield):
        def build(builder):
            blob_object = GetEagerInterfaceBlob(op_name).blob_object
            lbi = lbi_util.LogicalBlobId()
            lbi.set_op_name(op_name)
            op_attribute = sess.OpAttribute4InterfaceOpName(op_name)
            # Interface ops are expected to have exactly one output.
            assert len(op_attribute.output_bns) == 1
            lbi.set_blob_name(op_attribute.output_bns[0])
            # Fix: removed a dead branch that re-wrapped `lbi` when it was
            # not a lbi_util.LogicalBlobId — `lbi` is constructed as exactly
            # that type two statements above, so the test was always False.
            if blob_object.op_arg_parallel_attr.is_mirrored():
                remote_blob = oneflow._oneflow_internal.EagerMirroredBlob(
                    lbi, blob_object, blob_register, job_name)
            else:
                remote_blob = oneflow._oneflow_internal.EagerConsistentBlob(
                    lbi, blob_object, blob_register, job_name)
            value = remote_blob.numpy()
            Yield(value)

        oneflow._oneflow_internal.deprecated.LogicalRun(build)

    return async_util.Await(1, AsyncGetInterfaceBlobValue)[0]
def __getattr__(
        self, attr_name: str
) -> Callable[[Optional[Union[bool, int, float, str]]], None]:
    """Return a setter closure for the function-config flag `attr_name`.

    The flag must exist in the session's default-value registry; the
    returned setter type-checks its argument against the default's type.
    """
    name2default = session_ctx.GetDefaultSession().function_flag_name2default_val
    assert attr_name in name2default
    flag_name2flag_value = (
        self.function_desc.job_config_proto.mutable_flag_name2flag_value())
    default_val = name2default[attr_name]

    def FunctionConfigSetter(
            attr_value: Optional[Union[bool, int, float, str]] = None) -> None:
        # Dispatch on the declared type of the flag's default value.
        if default_val.HasField("at_bool"):
            # A bool flag set with no explicit value defaults to True.
            value = True if attr_value is None else attr_value
            assert type(value) is bool
            flag_name2flag_value[attr_name].set_at_bool(value)
        elif default_val.HasField("at_int64"):
            assert type(attr_value) is int
            flag_name2flag_value[attr_name].set_at_int64(attr_value)
        elif default_val.HasField("at_double"):
            assert type(attr_value) is float
            flag_name2flag_value[attr_name].set_at_double(attr_value)
        elif default_val.HasField("at_string"):
            assert type(attr_value) is str
            flag_name2flag_value[attr_name].set_at_string(attr_value)
        else:
            raise NotImplementedError(
                "config_flag `%s' with type %s is not supported"
                % (attr_name, type(attr_value)))

    return FunctionConfigSetter
def _WatcherHandler(handler_uuid, of_blob_ptr):
    """Invoke the watch callback registered under `handler_uuid` with a
    LocalBlob built from the raw OfBlob pointer."""
    registry = session_ctx.GetDefaultSession().uuid2watch_handler
    assert handler_uuid in registry
    (blob_watched, handler) = registry[handler_uuid]
    assert callable(handler)
    ndarray = ofblob.OfBlob(of_blob_ptr).CopyToNdarray()
    local_blob = local_blob_util.LocalBlob(ndarray, blob_watched.is_dynamic)
    handler(oft_util.TransformWatchedBlob(local_blob, handler))
def get_eager_variable(
    name,
    shape=None,
    dtype=None,
    initializer=None,
    regularizer=None,
    trainable=None,
    model_name=None,
    random_seed=None,
    nd_sbp=None,
    reuse=True,
):
    # Create (or fetch from the session stash) an eager consistent variable.
    # `name` is prefixed with the current job's name-scope prefix; with
    # reuse=False, a variable already stashed for this job is an error.
    assert isinstance(name, str)
    assert isinstance(
        shape,
        (list, tuple)), "param shape should be a list or tuple of dimension"
    job_name = oneflow._oneflow_internal.JobBuildAndInferCtx_GetCurrentJobName(
    )
    name = name_scope.GetJobNameScopePrefix(job_name) + name
    sess = session_ctx.GetDefaultSession()
    # var_blob: blob known to the session; job_var_blob: this job's stash entry.
    (var_blob,
     job_var_blob) = sess.TryGetVariableBlobOfJobFromStash(job_name, name)
    if reuse is False:
        assert (
            job_var_blob is None
        ), "variable '{}' already exists, getting the same variable is not allowed when reuse is False".format(
            name)
    if job_var_blob is None:
        op_conf = GenerateVariableOpConf(
            name=name,
            shape=shape,
            dtype=dtype,
            initializer=initializer,
            regularizer=regularizer,
            trainable=trainable,
            model_name=model_name,
            random_seed=random_seed,
            nd_sbp=nd_sbp,
        )
        op_attribute = compile_context.CurJobAddConsistentOp(op_conf)
        if var_blob is None:
            # First sighting anywhere: create the blob and run eager init.
            var_blob = CreateEagerVariableBlob(op_attribute)
            op_executor.EagerInitVariableBlob(sess, op_conf, var_blob)
        assert isinstance(var_blob,
                          oneflow._oneflow_internal.EagerConsistentBlob)
        sess.StashVariableBlob4Job(job_name, op_conf.name, var_blob)
    else:
        # Variable already stashed for this job: verify identity.
        assert isinstance(job_var_blob,
                          oneflow._oneflow_internal.EagerConsistentBlob)
        assert isinstance(var_blob,
                          oneflow._oneflow_internal.EagerConsistentBlob)
        assert var_blob.IdenticalTo(job_var_blob)
    # Register the blob object for backward pass bookkeeping.
    bw_blob_register = gradient_util.GetDefaultBackwardBlobRegister()
    bw_blob_register.TrySetObject4BlobName(var_blob.logical_blob_name,
                                           var_blob.blob_object)
    return var_blob
def _GetDefaultConfigProto():
    """Build the default session ConfigProto.

    Uses one GPU when the binary was built with CUDA support, otherwise a
    single CPU device, and tags the proto with the default session's id.
    """
    proto = job_set_util.ConfigProto()
    proto.resource.machine_num = 0
    cuda_enabled = oneflow._oneflow_internal.flags.with_cuda()
    if cuda_enabled:
        proto.resource.gpu_device_num = 1
    else:
        proto.resource.cpu_device_num = 1
        proto.resource.gpu_device_num = 0
    proto.session_id = session_ctx.GetDefaultSession().id
    return proto
def get_lazy_variable(
    name,
    shape=None,
    dtype=None,
    initializer=None,
    regularizer=None,
    trainable=None,
    model_name=None,
    random_seed=None,
    nd_sbp=None,
    reuse=True,
):
    # Create (or fetch from the session stash) a lazy consistent variable.
    # `name` is prefixed with the current job's name-scope prefix; with
    # reuse=False, a variable already stashed for this job is an error.
    assert isinstance(name, str)
    assert isinstance(
        shape,
        (list, tuple)), "param shape should be a list or tuple of dimension"
    job_name = oneflow._oneflow_internal.JobBuildAndInferCtx_GetCurrentJobName(
    )
    name = name_scope.GetJobNameScopePrefix(job_name) + name
    sess = session_ctx.GetDefaultSession()
    # var_blob: blob known to the session; job_var_blob: this job's stash entry.
    (var_blob,
     job_var_blob) = sess.TryGetVariableBlobOfJobFromStash(job_name, name)
    if reuse is False:
        assert (
            job_var_blob is None
        ), "variable '{}' already exists, getting the same variable is not allowed when param reuse is False".format(
            name)
    if job_var_blob is None:
        op_conf = GenerateVariableOpConf(
            name=name,
            shape=shape,
            dtype=dtype,
            initializer=initializer,
            regularizer=regularizer,
            trainable=trainable,
            model_name=model_name,
            random_seed=random_seed,
            nd_sbp=nd_sbp,
        )
        job_var_blob = _CreateVariableBlob(op_conf)
        assert isinstance(job_var_blob,
                          oneflow._oneflow_internal.LazyConsistentBlob)
        sess.StashVariableBlob4Job(job_name, op_conf.name, job_var_blob)
        if var_blob is not None:
            # The variable exists in another job: the new blob must match it.
            assert isinstance(var_blob,
                              oneflow._oneflow_internal.LazyConsistentBlob)
            assert var_blob.IdenticalTo(job_var_blob)
    else:
        # NOTE(review): this branch asserts `var_blob` is a LazyConsistentBlob,
        # which assumes var_blob is non-None whenever job_var_blob exists —
        # confirm against TryGetVariableBlobOfJobFromStash's contract.
        assert isinstance(job_var_blob,
                          oneflow._oneflow_internal.LazyConsistentBlob)
        assert isinstance(var_blob,
                          oneflow._oneflow_internal.LazyConsistentBlob)
        assert var_blob.IdenticalTo(job_var_blob)
    return job_var_blob
def _MakeModelInitJobFunc():
    """Create a JobInstance that runs the session's global model-init job.

    Both callbacks are no-ops: the init job needs no input and produces no
    output to consume.
    """
    def _noop_push(blob):
        pass

    def _noop_finish():
        pass

    init_job_name = session_ctx.GetDefaultSession(
    ).inter_user_job_info.global_model_init_job_name
    return job_instance.MakeJobInstance(
        str(init_job_name),
        push_cb=_noop_push,
        finish_cb=_noop_finish,
    )
def __init__(self, is_mirrored):
    """Record the mirrored-view flag; when the session is already running
    with a non-empty mirrored-strategy stack, prepare a scope context that
    rebuilds the scope with the new is_mirrored setting."""
    self.is_mirrored_ = is_mirrored
    self.scope_context_ = None
    sess = session_ctx.GetDefaultSession()
    needs_scope_switch = sess.is_running and (
        not sess.has_empty_is_mirrored_strategy_enabled_stack())
    if needs_scope_switch:
        def BuildScope(old_scope, builder):
            return builder.BuildScopeWithNewIsMirrored(old_scope, is_mirrored)

        self.scope_context_ = scope_util.ScopeContext(
            scope_util.MakeScope(BuildScope))
def CurJobAddMirroredOp(op_conf, scope_symbol=None):
    """Add a mirrored op to the job being built and return its inferred
    attribute; interface ops are additionally registered on the session."""
    assert not hob.consistent_view_enabled(None)
    if scope_symbol is None:
        scope_symbol = flow.current_scope()
    op_conf.scope_symbol_id = scope_symbol.symbol_id
    if not op_conf.HasField("device_tag"):
        # Inherit the device tag from the scope's parallel desc.
        op_conf.device_tag = scope_symbol.device_parallel_desc_symbol.device_tag
    op_attr = c_api_util.CurJobBuildAndInferCtx_AddAndInferMirroredOp(op_conf)
    if c_api_util.IsInterfaceOpConf(op_conf):
        session_ctx.GetDefaultSession().AddInfo4InterfaceOpName(
            op_conf.name, op_attr)
    return op_attr
def GetGlobalModePlacementScope(device_tag, machine_device_ids, hierarchy=None):
    """Build a global-mode placement scope for the given devices.

    Args:
        device_tag: device tag string passed to the scope builder.
        machine_device_ids: one device spec or a list/tuple of specs.
        hierarchy: optional iterable converted to an internal Size.
    """
    # Fix: `isinstance(...) == False` replaced with `not isinstance(...)`,
    # and the unused `sess = session_ctx.GetDefaultSession()` removed.
    if not isinstance(machine_device_ids, (list, tuple)):
        machine_device_ids = [machine_device_ids]
    if hierarchy is not None:
        hierarchy = oneflow._oneflow_internal.Size(tuple(hierarchy))

    def BuildScope(old_scope, builder):
        return builder.BuildScopeWithNewParallelDesc(
            old_scope, device_tag, machine_device_ids, hierarchy
        )

    scope_ctx = scope_util.ScopeContext(scope_util.MakeScope(BuildScope))
    return placement_ctx.GlobalModePlacementScope(scope_ctx)
def _MakeModelSaveJobFunc(path):
    """Create a JobInstance that runs the global model-save job.

    The destination `path` is pushed to the job as an int8 ndarray of its
    ASCII bytes.
    """
    def push_cb(blob):
        encoded_path = np.frombuffer(path.encode("ascii"), dtype=np.int8)
        blob.CopyFromNdarray(encoded_path)

    def finish_cb():
        pass

    save_job_name = session_ctx.GetDefaultSession(
    ).inter_user_job_info.global_model_save_job_name
    return job_instance.MakeJobInstance(
        str(save_job_name),
        push_cb=push_cb,
        finish_cb=finish_cb,
    )
def GetNormalModePlacementScope(device_tag, machine_device_ids, hierarchy=None):
    """Build a normal-mode placement scope for the given devices.

    Args:
        device_tag: device tag string passed to the scope builder.
        machine_device_ids: one device spec or a list/tuple of specs
            (normalized to a list).
        hierarchy: optional iterable converted to an internal Size.
    """
    # Fix: removed the unused `sess = session_ctx.GetDefaultSession()`.
    if isinstance(machine_device_ids, tuple):
        machine_device_ids = list(machine_device_ids)
    if not isinstance(machine_device_ids, list):
        machine_device_ids = [machine_device_ids]
    if hierarchy is not None:
        hierarchy = oneflow._oneflow_internal.Size(tuple(hierarchy))
    scope = scope_util.MakeScope(
        lambda old_scope, builder: builder.BuildScopeWithNewParallelDesc(
            old_scope, device_tag, machine_device_ids, hierarchy
        )
    )
    return scope_util.ScopeContext(scope)
def GetAllVariables(
) -> Dict[str, oneflow._oneflow_internal.EagerConsistentBlob]:
    """
    Get all variables of all jobs as a dict.
    """
    sync_default_session_if_normal()
    sess = session_ctx.GetDefaultSession()
    variables = {}
    for op_name in sess.interface_ops:
        attr = sess.OpAttribute4InterfaceOpName(op_name)
        # Only variable ops are collected; other interface ops are skipped.
        if attr.op_conf.WhichOneof("op_type") == "variable_conf":
            variables[op_name] = interface_op_read_and_write.GetEagerInterfaceBlob(
                op_name)
    return variables
def Decorator(job_func):
    """Wrap `job_func` as a lazy global function registered on the default
    session; the wrapper runs the lazy job when called."""
    if not hasattr(job_func, "__oneflow_function_signature__"):
        job_func.__oneflow_function_signature__ = inspect.signature(job_func)
    oft_util.CheckGlobalFunctionAnnotation(
        job_func.__oneflow_function_signature__)
    sess = session_ctx.GetDefaultSession()

    @functools.wraps(job_func)
    def Func(*args, **kwargs):
        return _RunLazyJob(sess, job_func, *args, **kwargs)

    sess.AddJob(_CloneFunctionDesc(function_config.function_desc, job_func))
    # Mirror oneflow-specific attributes from the job function onto the wrapper.
    for attr in dir(job_func):
        if attr.startswith("__oneflow_"):
            setattr(Func, attr, getattr(job_func, attr))
    return Func
def _FindOrCreateVarBlobObject(op_attribute, parallel_conf, blob_register):
    # Eagerly interpret a variable op: reuse the variable blob stashed on the
    # session when present, otherwise create, initialize, and stash a new one.
    job_name = oneflow._oneflow_internal.JobBuildAndInferCtx_GetCurrentJobName(
    )
    name = name_scope.GetJobNameScopePrefix(
        job_name) + op_attribute.op_conf.name
    sess = session_ctx.GetDefaultSession()
    (var_blob, _) = sess.TryGetVariableBlobOfJobFromStash(job_name, name)
    if var_blob is not None:
        # Already created: bind its object under the blob name, then let the
        # naive interpreter run the op.
        blob_register.SetObject4BlobName(var_blob.logical_blob_name,
                                         var_blob.blob_object)
        return _NaiveInterpret(op_attribute, parallel_conf, blob_register)
    # First use: build the eager blob for output "out", init it, and stash it.
    var_blob = _MakeEagerLogicalBlob(op_attribute, "out",
                                     blob_register=blob_register)
    EagerInitVariableBlob(sess, op_attribute.op_conf, var_blob)
    sess.StashVariableBlob4Job(job_name, op_attribute.op_conf.name, var_blob)
    return var_blob
def find_or_create_module(module_name, create, reuse=False):
    """Look up `module_name` in the current job's module cache, creating it
    via `create()` on first use.

    With reuse=False, referencing the same module name twice within one
    global function is an error.
    """
    assert callable(create)
    sess = session_ctx.GetDefaultSession()
    job_name = flow.current_global_function_desc().job_config_proto.job_name()
    module_map = sess.job_name2module_name2module_.setdefault(job_name, {})
    if module_name not in module_map:
        module = create()
        assert isinstance(module, module_util.Module)
        module_map[module_name] = module
    elif not reuse:
        assert module_name not in sess.existed_module_names_, (
            "duplicated module_name `%s' in global_function `%s'"
            % (module_name, job_name))
    sess.existed_module_names_.add(module_name)
    return module_map[module_name]
def _GetReturnOpConfAndOutLbiAndScope(remote_blob, allow_cpu_return_op=True):
    """Build a return-op conf, its output lbi, and a scope matching the
    blob's placement.

    Args:
        remote_blob: blob whose value the return op forwards.
        allow_cpu_return_op: pin the return op to CPU when True.

    Returns:
        (op_conf, lbi, scope) tuple for the new return op.
    """
    # Fix: removed the unused `sess = session_ctx.GetDefaultSession()`.
    op_conf = op_conf_util.OperatorConf()
    op_conf.name = id_util.UniqueStr("Return_")
    # `in` is a Python keyword, so the proto field is set via setattr.
    setattr(op_conf.return_conf, "in", remote_blob.unique_name)
    op_conf.return_conf.out = "out"
    if allow_cpu_return_op:
        op_conf.device_tag = "cpu"
    lbi = logical_blob_id_util.LogicalBlobId()
    lbi.op_name = op_conf.name
    lbi.blob_name = "out"
    parallel_conf = placement_cfg.ParallelConf()
    parallel_conf.CopyFrom(remote_blob.parallel_conf)

    def BuildScope(old_scope, builder):
        return builder.BuildScopeWithNewParallelConf(old_scope, parallel_conf)

    scope = scope_util.MakeScope(BuildScope)
    return (op_conf, lbi, scope)
def _EagerRunModelLoad(var_op_conf, snapshot_path):
    """Eagerly run the model-load job for one variable.

    `snapshot_path` must point at `<snapshot_root>/<var_op_conf.name>/out`;
    it is stripped back to the snapshot root before being fed to the
    path-input op. Returns the loaded blob object ("out_0").
    """
    # Fix: removed the unused `sess = session_ctx.GetDefaultSession()`.
    assert isinstance(snapshot_path, str)
    assert os.path.basename(snapshot_path) == "out"
    snapshot_path = os.path.dirname(snapshot_path)
    assert os.path.basename(snapshot_path) == var_op_conf.name
    snapshot_path = os.path.dirname(snapshot_path)
    (path_input_op_conf, path_lbi) = _GenModelIOPathInputOpConfAndRetLbi()
    path_input_blob_objects = {}
    (
        BuildModelIOPathInputInstruction,
        BuildFeedPathInstruction,
    ) = _MakeModelIOPathInputBuilds(path_input_op_conf, snapshot_path,
                                    path_input_blob_objects)
    (model_load_op_conf, _) = _GenModelLoadOpConfAndRetLbi(var_op_conf,
                                                           path_lbi)
    model_load_blob_objects = oneflow._oneflow_internal.deprecated.BnInOp2BlobObject(
    )

    def BuildModelLoadInstruction(builder):
        # The load op consumes the path blob produced by the path-input op.
        path_blob_object = path_input_blob_objects["out"]
        model_load_blob_objects["path"] = path_blob_object
        op_attribute = op_infer_util.Infer(
            model_load_op_conf, ibn2blob_object=model_load_blob_objects)
        parallel_conf = path_blob_object.parallel_desc_symbol.parallel_conf
        cfg_op_attribute = oneflow._oneflow_internal.deprecated.MakeOpAttributeByString(
            str(op_attribute))
        builder.StatelessCall(
            cfg_op_attribute,
            parallel_conf,
            model_load_blob_objects,
            boxing_util.BoxingTo,
        )

    with scope_util.ScopeContext(scope_util.MakeScope(_BuildNotMirroredScope)):
        # Instruction order matters: create the path input, feed the path,
        # then run the load op that consumes it.
        oneflow._oneflow_internal.deprecated.LogicalRun(
            BuildModelIOPathInputInstruction)
        oneflow._oneflow_internal.deprecated.LogicalRun(
            BuildFeedPathInstruction)
        oneflow._oneflow_internal.deprecated.LogicalRun(
            BuildModelLoadInstruction)
    return model_load_blob_objects["out_0"]
def __getattr__(self, attr_name):
    """Read a function-config flag, falling back to its registered default;
    non-flag attributes are read straight from the job config proto."""
    assert attr_name != "flag_name2flag_value"
    flag_values = self.job_config_proto.flag_name2flag_value()
    name2default = session_ctx.GetDefaultSession().function_flag_name2default_val
    if attr_name not in name2default:
        # Not a flag: must be a plain field set on the proto itself.
        assert getattr(self.job_config_proto, "has_" + attr_name)()
        return getattr(self.job_config_proto, attr_name)()
    # Prefer the explicitly-set flag value over the registered default.
    value = flag_values[attr_name] if attr_name in flag_values else name2default[
        attr_name]
    if value.HasField("at_bool"):
        return value.at_bool
    if value.HasField("at_int64"):
        return value.at_int64
    if value.HasField("at_double"):
        return value.at_double
    if value.HasField("at_string"):
        return value.at_string
    raise NotImplementedError()
def _Watch(op_attribute, parallel_conf, blob_register):
    """Pop the watch handler registered for this foreign_watch op and call it
    with an eager blob wrapping the watched input."""
    lbi = op_attribute.arg_signature.bn_in_op2lbi["in"]
    uuid = op_attribute.op_conf.foreign_watch_conf.handler_uuid
    lbn = "%s/%s" % (lbi.op_name, lbi.blob_name)
    in_blob_object = blob_register.GetObject4BlobName(lbn)
    if not isinstance(lbi, lbi_util.LogicalBlobId):
        # Convert a protobuf lbi into its cfg counterpart.
        cfg_lbi = lbi_util.LogicalBlobId()
        cfg_lbi.set_op_name(lbi.op_name)
        cfg_lbi.set_blob_name(lbi.blob_name)
        lbi = cfg_lbi
    if in_blob_object.op_arg_parallel_attr.is_mirrored():
        watched_blob = oneflow._oneflow_internal.EagerMirroredBlob(
            lbi, in_blob_object, default_blob_register)
    else:
        watched_blob = oneflow._oneflow_internal.EagerConsistentBlob(
            lbi, in_blob_object, default_blob_register)
    handlers = session_ctx.GetDefaultSession().uuid2watch_handler
    assert uuid in handlers
    handlers[uuid](watched_blob)
    # Watch handlers are one-shot: drop the registration after firing.
    del handlers[uuid]
def _EagerRunModelInit(var_op_conf):
    """Eagerly run the model-init op for one variable and return the
    initialized blob object ("out_0")."""
    # Fix: removed the unused `sess = session_ctx.GetDefaultSession()`.
    (op_conf, _) = _GenModelInitOpConfAndRetLbi(var_op_conf)
    bn_in_op2blob_object = oneflow._oneflow_internal.deprecated.BnInOp2BlobObject(
    )

    def BuildModelInitInstruction(builder):
        upstream_signature = op_node_signature_pb.OpNodeSignature()
        op_conf.scope_symbol_id = flow.current_scope().symbol_id
        op_attribute = c_api_util.InferOpConf(op_conf, upstream_signature)
        parallel_conf = flow.current_scope(
        ).device_parallel_desc_symbol.parallel_conf
        cfg_op_attribute = oneflow._oneflow_internal.deprecated.MakeOpAttributeByString(
            str(op_attribute))
        builder.StatelessCall(cfg_op_attribute, parallel_conf,
                              bn_in_op2blob_object, boxing_util.BoxingTo)

    with scope_util.ScopeContext(scope_util.MakeScope(_BuildNotMirroredScope)):
        oneflow._oneflow_internal.deprecated.LogicalRun(
            BuildModelInitInstruction)
    return bn_in_op2blob_object["out_0"]