Example #1
def _GetInterfaceBlobObject(builder, op_name):
    sess = session_ctx.GetDefaultSession()
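    # Eager mode: the session already holds the variable's blob object, keyed by op name.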
    if oneflow._oneflow_internal.EagerExecutionEnabled():
        return sess.var_name2var_blob[op_name].blob_object
    op_attribute = sess.OpAttribute4InterfaceOpName(op_name)
    cfg_op_attribute = oneflow._oneflow_internal.deprecated.MakeOpAttributeByString(
        str(op_attribute))
    parallel_conf = sess.ParallelConf4LazyInterfaceOpName(op_name)
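    # Lazy mode: if parallel_conf is still a protobuf ParallelConf, convert it to its
    # cfg counterpart before handing it to MakeLazyRefBlobObject.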
    if not isinstance(
            parallel_conf,
            oneflow._oneflow_internal.oneflow.core.job.placement.ParallelConf):
        parallel_conf_cfg = placement_cfg.ParallelConf()
        parallel_conf_cfg.set_device_tag(parallel_conf.device_tag)
        for device_name in parallel_conf.device_name:
            parallel_conf_cfg.add_device_name(device_name)
        if parallel_conf.HasField("hierarchy"):
            hierarchy = shape_proto_cfg.ShapeProto()
            for dim in parallel_conf.hierarchy.dim:
                hierarchy.add_dim(dim)
            assert hierarchy.dim_size() > 0
            parallel_conf_cfg.mutable_hierarchy().CopyFrom(hierarchy)
        parallel_conf = parallel_conf_cfg
    blob_object = builder.MakeLazyRefBlobObject(op_name, cfg_op_attribute,
                                                parallel_conf)
    return blob_object
Example #2
File: hob.py  Project: zzk0/oneflow
def is_trainable(ctx):
    assert in_global_mode(ctx)
    if oneflow._oneflow_internal.EagerExecutionEnabled():
        return session_ctx.GetDefaultSession().CurrentEagerGlobalFunctionDesc()
    else:
        job_name = oneflow._oneflow_internal.JobBuildAndInferCtx_GetCurrentJobName(
        )
        return session_ctx.GetDefaultSession().GetFunctionDesc(job_name)
Example #3
    def HasAttr(self, attr_name):
        if attr_name == "flag_name2flag_value":
            return False
        name2default = session_ctx.GetDefaultSession().function_flag_name2default_val
        if attr_name in self.job_config_proto.flag_name2flag_value():
            return True
        return getattr(self.job_config_proto, "has_" + attr_name)()
Example #4
def GetLazyCurJobConfigProto():
    job_name = oneflow._oneflow_internal.JobBuildAndInferCtx_GetCurrentJobName(
    )
    function_desc = session_ctx.GetDefaultSession().GetLazyFunctionDesc(
        job_name)
    assert function_desc is not None
    return function_desc.job_config_proto
Example #5
def GetJobNameScopePrefix(job_name):
    sess = session_context.GetDefaultSession()
    if job_name not in sess.job_name2name_scope_stack:
        return ""
    if len(sess.job_name2name_scope_stack[job_name]) == 0:
        return ""
    return "-".join(sess.job_name2name_scope_stack[job_name]) + "-"
Example #6
def name_scope_stack_pop():
    job_name = oneflow._oneflow_internal.JobBuildAndInferCtx_GetCurrentJobName(
    )
    sess = session_context.GetDefaultSession()
    assert job_name in sess.job_name2name_scope_stack
    assert len(sess.job_name2name_scope_stack[job_name]) > 0
    return sess.job_name2name_scope_stack[job_name].pop()
Example #7
def name_scope_stack_push(name):
    job_name = oneflow._oneflow_internal.JobBuildAndInferCtx_GetCurrentJobName(
    )
    sess = session_context.GetDefaultSession()
    if job_name not in sess.job_name2name_scope_stack:
        sess.job_name2name_scope_stack[job_name] = []
    sess.job_name2name_scope_stack[job_name].append(name)
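
A minimal standalone sketch (hypothetical stand-ins for the session's job_name2name_scope_stack, no OneFlow dependency) of how Examples #5-#7 cooperate: push appends a scope name for the current job, GetJobNameScopePrefix joins the stack with "-" and adds a trailing "-", and pop removes the innermost name.

job_name2name_scope_stack = {}

def push(job_name, name):
    job_name2name_scope_stack.setdefault(job_name, []).append(name)

def prefix(job_name):
    stack = job_name2name_scope_stack.get(job_name, [])
    return "-".join(stack) + "-" if stack else ""

def pop(job_name):
    return job_name2name_scope_stack[job_name].pop()

push("train_job", "scope1")
push("train_job", "scope2")
assert prefix("train_job") == "scope1-scope2-"
pop("train_job")
assert prefix("train_job") == "scope1-"
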
Example #8
def name_scope(name: str) -> None:
    """Create a namespace. All variables within the namespace will have a prefix `[SCOPE NAME]-`. This is for convenience only and has no other effect on the system.
    Usage::

        with oneflow.compatible.single_client.scope.namespace("scope1"):
            ...
            with oneflow.compatible.single_client.scope.namespace("scope2"):
                ...

    Args:
        name: Name of this namespace

    """
    assert isinstance(name, str)
    name_scope_stack_push(name)

    def BuildScope(old_scope, builder):
        return builder.BuildScopeWithNewScopeName(old_scope, name)

    sess = session_context.GetDefaultSession()
    try:
        with scope_util.ScopeContext(scope_util.MakeScope(BuildScope)):
            yield
    finally:
        name_scope_stack_pop()
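
A hedged usage sketch for the function above. The yield implies it is registered as a context manager (e.g. via contextlib.contextmanager) and exposed as scope.namespace per the docstring; the job, variable name, and shape below are illustrative assumptions, not taken from the source.

from oneflow.compatible import single_client as flow

@flow.global_function()
def make_weights():
    with flow.scope.namespace("outer"):
        with flow.scope.namespace("inner"):
            # With GetJobNameScopePrefix from Example #5, this variable's full name
            # inside the job becomes "outer-inner-w".
            flow.get_variable("w", shape=(2, 2), dtype=flow.float,
                              initializer=flow.zeros_initializer())
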
Example #9
def GetEagerInterfaceBlob(op_name):
    sync_default_session_if_normal()
    sess = session_ctx.GetDefaultSession()

    def CreateBlob():
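        # This body runs only the first time op_name is requested; the session caches
        # the result via FindOrCreateLazyBlob below.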
        job_name = sess.JobName4InterfaceOpName(op_name)

        def Build(builder, Yield):
            blob_object = _GetInterfaceBlobObject(builder, op_name)
            lbi = lbi_util.LogicalBlobId()
            lbi.set_op_name(op_name)
            op_attribute = sess.OpAttribute4InterfaceOpName(op_name)
            assert len(op_attribute.output_bns) == 1
            lbi.set_blob_name(op_attribute.output_bns[0])
            if blob_object.op_arg_parallel_attr.is_mirrored():
                remote_blob = oneflow._oneflow_internal.EagerMirroredBlob(
                    lbi, blob_object, blob_register, job_name)
            else:
                remote_blob = oneflow._oneflow_internal.EagerConsistentBlob(
                    lbi, blob_object, blob_register, job_name)
            Yield(remote_blob)

        def AsyncGetInterfaceBlob(Yield):
            oneflow._oneflow_internal.deprecated.LogicalRun(
                lambda builder: Build(builder, Yield))

        blob = async_util.Await(1, AsyncGetInterfaceBlob)[0]
        return blob

    return sess.FindOrCreateLazyBlob(op_name, CreateBlob)
Example #10
def GetInterfaceBlobValue(op_name):
    sync_default_session_if_normal()
    sess = session_ctx.GetDefaultSession()
    job_name = sess.JobName4InterfaceOpName(op_name)

    def AsyncGetInterfaceBlobValue(Yield):
        def build(builder):
            blob_object = GetEagerInterfaceBlob(op_name).blob_object
            lbi = lbi_util.LogicalBlobId()
            lbi.set_op_name(op_name)
            op_attribute = sess.OpAttribute4InterfaceOpName(op_name)
            assert len(op_attribute.output_bns) == 1
            lbi.set_blob_name(op_attribute.output_bns[0])
            if not isinstance(lbi, lbi_util.LogicalBlobId):
                cfg_lbi = lbi_util.LogicalBlobId()
                cfg_lbi.set_op_name(lbi.op_name)
                cfg_lbi.set_blob_name(lbi.blob_name)
                lbi = cfg_lbi
            if blob_object.op_arg_parallel_attr.is_mirrored():
                remote_blob = oneflow._oneflow_internal.EagerMirroredBlob(
                    lbi, blob_object, blob_register, job_name)
            else:
                remote_blob = oneflow._oneflow_internal.EagerConsistentBlob(
                    lbi, blob_object, blob_register, job_name)
            value = remote_blob.numpy()
            Yield(value)

        oneflow._oneflow_internal.deprecated.LogicalRun(build)

    return async_util.Await(1, AsyncGetInterfaceBlobValue)[0]
Example #11
    def __getattr__(
        self, attr_name: str
    ) -> Callable[[Optional[Union[bool, int, float, str]]], None]:
        name2default = session_ctx.GetDefaultSession(
        ).function_flag_name2default_val
        assert attr_name in name2default
        flag_name2flag_value = (
            self.function_desc.job_config_proto.mutable_flag_name2flag_value())
        default_val = name2default[attr_name]

        def FunctionConfigSetter(
                attr_value: Optional[Union[bool, int, float,
                                           str]] = None) -> None:
            if default_val.HasField("at_bool"):
                if attr_value is None:
                    attr_value = True
                assert type(attr_value) is bool
                flag_name2flag_value[attr_name].set_at_bool(attr_value)
            elif default_val.HasField("at_int64"):
                assert type(attr_value) is int
                flag_name2flag_value[attr_name].set_at_int64(attr_value)
            elif default_val.HasField("at_double"):
                assert type(attr_value) is float
                flag_name2flag_value[attr_name].set_at_double(attr_value)
            elif default_val.HasField("at_string"):
                assert type(attr_value) is str
                flag_name2flag_value[attr_name].set_at_string(attr_value)
            else:
                raise NotImplementedError(
                    "config_flag `%s' with type %s is not supported" %
                    (attr_name, type(attr_value)))

        return FunctionConfigSetter
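
Hedged usage of the setter returned by this __getattr__ (presumably the single-client FunctionConfig): accessing a registered flag name yields a callable that writes the typed value into flag_name2flag_value. The flag name below is assumed to be registered in function_flag_name2default_val and is illustrative only.

from oneflow.compatible import single_client as flow

func_config = flow.FunctionConfig()
# For a bool flag, omitting the argument sets it to True (see the at_bool branch above).
func_config.enable_auto_mixed_precision()
func_config.enable_auto_mixed_precision(False)

@flow.global_function(function_config=func_config)
def job():
    pass
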
Example #12
File: watcher.py  Project: zzk0/oneflow
def _WatcherHandler(handler_uuid, of_blob_ptr):
    uuid2handler = session_ctx.GetDefaultSession().uuid2watch_handler
    assert handler_uuid in uuid2handler
    (blob_watched, handler) = uuid2handler[handler_uuid]
    assert callable(handler)
    ndarray = ofblob.OfBlob(of_blob_ptr).CopyToNdarray()
    local_blob = local_blob_util.LocalBlob(ndarray, blob_watched.is_dynamic)
    handler(oft_util.TransformWatchedBlob(local_blob, handler))
Example #13
def get_eager_variable(
    name,
    shape=None,
    dtype=None,
    initializer=None,
    regularizer=None,
    trainable=None,
    model_name=None,
    random_seed=None,
    nd_sbp=None,
    reuse=True,
):
    assert isinstance(name, str)
    assert isinstance(
        shape,
        (list, tuple)), "param shape should be a list or tuple of dimension"
    job_name = oneflow._oneflow_internal.JobBuildAndInferCtx_GetCurrentJobName(
    )
    name = name_scope.GetJobNameScopePrefix(job_name) + name
    sess = session_ctx.GetDefaultSession()
    (var_blob,
     job_var_blob) = sess.TryGetVariableBlobOfJobFromStash(job_name, name)
    if reuse is False:
        assert (
            job_var_blob is None
        ), "variable '{}' already exists, getting the same variable is not allowed when reuse is False".format(
            name)
    if job_var_blob is None:
        op_conf = GenerateVariableOpConf(
            name=name,
            shape=shape,
            dtype=dtype,
            initializer=initializer,
            regularizer=regularizer,
            trainable=trainable,
            model_name=model_name,
            random_seed=random_seed,
            nd_sbp=nd_sbp,
        )
        op_attribute = compile_context.CurJobAddConsistentOp(op_conf)
        if var_blob is None:
            var_blob = CreateEagerVariableBlob(op_attribute)
            op_executor.EagerInitVariableBlob(sess, op_conf, var_blob)
        assert isinstance(var_blob,
                          oneflow._oneflow_internal.EagerConsistentBlob)
        sess.StashVariableBlob4Job(job_name, op_conf.name, var_blob)
    else:
        assert isinstance(job_var_blob,
                          oneflow._oneflow_internal.EagerConsistentBlob)
        assert isinstance(var_blob,
                          oneflow._oneflow_internal.EagerConsistentBlob)
        assert var_blob.IdenticalTo(job_var_blob)
    bw_blob_register = gradient_util.GetDefaultBackwardBlobRegister()
    bw_blob_register.TrySetObject4BlobName(var_blob.logical_blob_name,
                                           var_blob.blob_object)
    return var_blob
Example #14
def _GetDefaultConfigProto():
    config_proto = job_set_util.ConfigProto()
    config_proto.resource.machine_num = 0
    if oneflow._oneflow_internal.flags.with_cuda():
        config_proto.resource.gpu_device_num = 1
    else:
        config_proto.resource.cpu_device_num = 1
        config_proto.resource.gpu_device_num = 0
    config_proto.session_id = session_ctx.GetDefaultSession().id
    return config_proto
Example #15
def get_lazy_variable(
    name,
    shape=None,
    dtype=None,
    initializer=None,
    regularizer=None,
    trainable=None,
    model_name=None,
    random_seed=None,
    nd_sbp=None,
    reuse=True,
):
    assert isinstance(name, str)
    assert isinstance(
        shape,
        (list, tuple)), "param shape should be a list or tuple of dimension"
    job_name = oneflow._oneflow_internal.JobBuildAndInferCtx_GetCurrentJobName(
    )
    name = name_scope.GetJobNameScopePrefix(job_name) + name
    sess = session_ctx.GetDefaultSession()
    (var_blob,
     job_var_blob) = sess.TryGetVariableBlobOfJobFromStash(job_name, name)
    if reuse is False:
        assert (
            job_var_blob is None
        ), "variable '{}' already exists, getting the same variable is not allowed when param reuse is False".format(
            name)
    if job_var_blob is None:
        op_conf = GenerateVariableOpConf(
            name=name,
            shape=shape,
            dtype=dtype,
            initializer=initializer,
            regularizer=regularizer,
            trainable=trainable,
            model_name=model_name,
            random_seed=random_seed,
            nd_sbp=nd_sbp,
        )
        job_var_blob = _CreateVariableBlob(op_conf)
        assert isinstance(job_var_blob,
                          oneflow._oneflow_internal.LazyConsistentBlob)
        sess.StashVariableBlob4Job(job_name, op_conf.name, job_var_blob)
        if var_blob is not None:
            assert isinstance(var_blob,
                              oneflow._oneflow_internal.LazyConsistentBlob)
            assert var_blob.IdenticalTo(job_var_blob)
    else:
        assert isinstance(job_var_blob,
                          oneflow._oneflow_internal.LazyConsistentBlob)
        assert isinstance(var_blob,
                          oneflow._oneflow_internal.LazyConsistentBlob)
        assert var_blob.IdenticalTo(job_var_blob)
    return job_var_blob
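
get_eager_variable (Example #13) and get_lazy_variable above share the same per-job stash pattern. A standalone sketch of just that logic, with hypothetical names and no OneFlow dependency:

_stash = {}  # (job_name, variable name) -> blob

def get_variable_like(job_name, name, create, reuse=True):
    key = (job_name, name)
    if key in _stash:
        # Mirrors the assert in the examples: re-getting an existing variable
        # is only allowed when reuse is True.
        assert reuse, "variable '%s' already exists" % name
        return _stash[key]
    _stash[key] = create()
    return _stash[key]

v1 = get_variable_like("job0", "w", lambda: object())
v2 = get_variable_like("job0", "w", lambda: object())
assert v1 is v2  # the stashed blob is reused within the same job
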
Example #16
def _MakeModelInitJobFunc():
    def push_cb(blob):
        pass

    def finish_cb():
        pass

    sess = session_ctx.GetDefaultSession()
    return job_instance.MakeJobInstance(
        str(sess.inter_user_job_info.global_model_init_job_name),
        push_cb=push_cb,
        finish_cb=finish_cb,
    )
Example #17
    def __init__(self, is_mirrored):
        self.is_mirrored_ = is_mirrored
        self.scope_context_ = None
        sess = session_ctx.GetDefaultSession()
        if sess.is_running and (
                not sess.has_empty_is_mirrored_strategy_enabled_stack()):

            def BuildScope(old_scope, builder):
                return builder.BuildScopeWithNewIsMirrored(
                    old_scope, is_mirrored)

            self.scope_context_ = scope_util.ScopeContext(
                scope_util.MakeScope(BuildScope))
Example #18
def CurJobAddMirroredOp(op_conf, scope_symbol=None):
    assert not hob.consistent_view_enabled(None)
    if scope_symbol is None:
        scope_symbol = flow.current_scope()
    op_conf.scope_symbol_id = scope_symbol.symbol_id
    if not op_conf.HasField("device_tag"):
        device_tag = scope_symbol.device_parallel_desc_symbol.device_tag
        op_conf.device_tag = device_tag
    op_attr = c_api_util.CurJobBuildAndInferCtx_AddAndInferMirroredOp(op_conf)
    if c_api_util.IsInterfaceOpConf(op_conf):
        sess = session_ctx.GetDefaultSession()
        sess.AddInfo4InterfaceOpName(op_conf.name, op_attr)
    return op_attr
Example #19
def GetGlobalModePlacementScope(device_tag, machine_device_ids, hierarchy=None):
    if not isinstance(machine_device_ids, (list, tuple)):
        machine_device_ids = [machine_device_ids]
    sess = session_ctx.GetDefaultSession()
    if hierarchy is not None:
        hierarchy = oneflow._oneflow_internal.Size(tuple(hierarchy))

    def BuildScope(old_scope, builder):
        return builder.BuildScopeWithNewParallelDesc(
            old_scope, device_tag, machine_device_ids, hierarchy
        )

    scope_ctx = scope_util.ScopeContext(scope_util.MakeScope(BuildScope))
    return placement_ctx.GlobalModePlacementScope(scope_ctx)
Example #20
def _MakeModelSaveJobFunc(path):
    def push_cb(blob):
        blob.CopyFromNdarray(np.frombuffer(path.encode("ascii"),
                                           dtype=np.int8))

    def finish_cb():
        pass

    sess = session_ctx.GetDefaultSession()
    return job_instance.MakeJobInstance(
        str(sess.inter_user_job_info.global_model_save_job_name),
        push_cb=push_cb,
        finish_cb=finish_cb,
    )
Example #21
def GetNormalModePlacementScope(device_tag, machine_device_ids, hierarchy=None):
    if isinstance(machine_device_ids, tuple):
        machine_device_ids = list(machine_device_ids)
    if not isinstance(machine_device_ids, list):
        machine_device_ids = [machine_device_ids]
    sess = session_ctx.GetDefaultSession()
    if hierarchy is not None:
        hierarchy = oneflow._oneflow_internal.Size(tuple(hierarchy))
    scope = scope_util.MakeScope(
        lambda old_scope, builder: builder.BuildScopeWithNewParallelDesc(
            old_scope, device_tag, machine_device_ids, hierarchy
        )
    )
    return scope_util.ScopeContext(scope)
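
A hedged note: this helper presumably backs the single-client placement scope (flow.scope.placement); machine_device_ids uses the "machine_id:device_ids" string format. A minimal sketch, assuming that public API:

from oneflow.compatible import single_client as flow

@flow.global_function()
def job():
    # Ops built inside this block are placed on CPU devices 0-3 of machine 0; the
    # optional hierarchy argument would arrange those devices into an nd layout.
    with flow.scope.placement("cpu", "0:0-3"):
        pass
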
Example #22
def GetAllVariables(
) -> Dict[str, oneflow._oneflow_internal.EagerConsistentBlob]:
    """
    Get all variables of all jobs as a dict.
    """
    sync_default_session_if_normal()
    sess = session_ctx.GetDefaultSession()
    interface_ops = sess.interface_ops
    variables = {}
    for op in interface_ops:
        op_attr = sess.OpAttribute4InterfaceOpName(op)
        if op_attr.op_conf.WhichOneof("op_type") != "variable_conf":
            continue
        variables[op] = interface_op_read_and_write.GetEagerInterfaceBlob(op)
    return variables
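
A short hedged sketch combining GetAllVariables with the blob .numpy() accessor seen in Example #10 to inspect every variable of the default session (assuming eager mode and that the variables have already been created):

for name, blob in GetAllVariables().items():
    # Each value is an EagerConsistentBlob; numpy() pulls its current contents.
    arr = blob.numpy()
    print(name, arr.shape, arr.dtype)
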
Example #23
    def Decorator(job_func):
        if not hasattr(job_func, "__oneflow_function_signature__"):
            job_func.__oneflow_function_signature__ = inspect.signature(
                job_func)
        oft_util.CheckGlobalFunctionAnnotation(
            job_func.__oneflow_function_signature__)
        sess = session_ctx.GetDefaultSession()

        @functools.wraps(job_func)
        def Func(*args, **kwargs):
            return _RunLazyJob(sess, job_func, *args, **kwargs)

        sess.AddJob(_CloneFunctionDesc(function_config.function_desc,
                                       job_func))
        for x in dir(job_func):
            if x.startswith("__oneflow_"):
                setattr(Func, x, getattr(job_func, x))
        return Func
Example #24
def _FindOrCreateVarBlobObject(op_attribute, parallel_conf, blob_register):
    job_name = oneflow._oneflow_internal.JobBuildAndInferCtx_GetCurrentJobName(
    )
    name = name_scope.GetJobNameScopePrefix(
        job_name) + op_attribute.op_conf.name
    sess = session_ctx.GetDefaultSession()
    (var_blob, _) = sess.TryGetVariableBlobOfJobFromStash(job_name, name)
    if var_blob is not None:
        blob_register.SetObject4BlobName(var_blob.logical_blob_name,
                                         var_blob.blob_object)
        return
    _NaiveInterpret(op_attribute, parallel_conf, blob_register)
    var_blob = _MakeEagerLogicalBlob(op_attribute,
                                     "out",
                                     blob_register=blob_register)
    EagerInitVariableBlob(sess, op_attribute.op_conf, var_blob)
    sess.StashVariableBlob4Job(job_name, op_attribute.op_conf.name, var_blob)
    return var_blob
Example #25
def find_or_create_module(module_name, create, reuse=False):
    assert callable(create)
    sess = session_ctx.GetDefaultSession()
    job_name = flow.current_global_function_desc().job_config_proto.job_name()
    if job_name not in sess.job_name2module_name2module_:
        sess.job_name2module_name2module_[job_name] = {}
    module_name2module = sess.job_name2module_name2module_[job_name]
    if module_name not in module_name2module:
        module = create()
        assert isinstance(module, module_util.Module)
        module_name2module[module_name] = module
    elif not reuse:
        assert module_name not in sess.existed_module_names_, (
            "duplicated module_name `%s' in global_function `%s'" %
            (module_name, job_name))
    else:
        pass
    sess.existed_module_names_.add(module_name)
    return module_name2module[module_name]
Example #26
def _GetReturnOpConfAndOutLbiAndScope(remote_blob, allow_cpu_return_op=True):
    op_conf = op_conf_util.OperatorConf()
    op_conf.name = id_util.UniqueStr("Return_")
    setattr(op_conf.return_conf, "in", remote_blob.unique_name)
    op_conf.return_conf.out = "out"
    if allow_cpu_return_op:
        op_conf.device_tag = "cpu"
    lbi = logical_blob_id_util.LogicalBlobId()
    lbi.op_name = op_conf.name
    lbi.blob_name = "out"
    parallel_conf = placement_cfg.ParallelConf()
    parallel_conf.CopyFrom(remote_blob.parallel_conf)

    def BuildScope(old_scope, builder):
        return builder.BuildScopeWithNewParallelConf(old_scope, parallel_conf)

    sess = session_ctx.GetDefaultSession()
    scope = scope_util.MakeScope(BuildScope)
    return (op_conf, lbi, scope)
Example #27
def _EagerRunModelLoad(var_op_conf, snapshot_path):
    assert isinstance(snapshot_path, str)
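    # snapshot_path points at ".../<variable op name>/out"; strip those two trailing
    # components to recover the snapshot root directory.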
    assert os.path.basename(snapshot_path) == "out"
    snapshot_path = os.path.dirname(snapshot_path)
    assert os.path.basename(snapshot_path) == var_op_conf.name
    snapshot_path = os.path.dirname(snapshot_path)
    (path_input_op_conf, path_lbi) = _GenModelIOPathInputOpConfAndRetLbi()
    path_input_blob_objects = {}
    (
        BuildModelIOPathInputInstruction,
        BuildFeedPathInstruction,
    ) = _MakeModelIOPathInputBuilds(path_input_op_conf, snapshot_path,
                                    path_input_blob_objects)
    (model_load_op_conf,
     _) = _GenModelLoadOpConfAndRetLbi(var_op_conf, path_lbi)
    model_load_blob_objects = oneflow._oneflow_internal.deprecated.BnInOp2BlobObject(
    )

    def BuildModelLoadInstruction(builder):
        path_blob_object = path_input_blob_objects["out"]
        model_load_blob_objects["path"] = path_blob_object
        op_attribute = op_infer_util.Infer(
            model_load_op_conf, ibn2blob_object=model_load_blob_objects)
        parallel_conf = path_blob_object.parallel_desc_symbol.parallel_conf
        cfg_op_attribute = oneflow._oneflow_internal.deprecated.MakeOpAttributeByString(
            str(op_attribute))
        builder.StatelessCall(
            cfg_op_attribute,
            parallel_conf,
            model_load_blob_objects,
            boxing_util.BoxingTo,
        )

    sess = session_ctx.GetDefaultSession()
    with scope_util.ScopeContext(scope_util.MakeScope(_BuildNotMirroredScope)):
        oneflow._oneflow_internal.deprecated.LogicalRun(
            BuildModelIOPathInputInstruction)
        oneflow._oneflow_internal.deprecated.LogicalRun(
            BuildFeedPathInstruction)
        oneflow._oneflow_internal.deprecated.LogicalRun(
            BuildModelLoadInstruction)
    return model_load_blob_objects["out_0"]
Example #28
    def __getattr__(self, attr_name):
        assert attr_name != "flag_name2flag_value"
        flag_name2flag_value = self.job_config_proto.flag_name2flag_value()
        name2default = session_ctx.GetDefaultSession().function_flag_name2default_val
        if attr_name not in name2default:
            assert getattr(self.job_config_proto, "has_" + attr_name)()
            return getattr(self.job_config_proto, attr_name)()
        attr_value = name2default[attr_name]
        if attr_name in flag_name2flag_value:
            attr_value = flag_name2flag_value[attr_name]
        if attr_value.HasField("at_bool"):
            return attr_value.at_bool
        elif attr_value.HasField("at_int64"):
            return attr_value.at_int64
        elif attr_value.HasField("at_double"):
            return attr_value.at_double
        elif attr_value.HasField("at_string"):
            return attr_value.at_string
        else:
            raise NotImplementedError()
Example #29
def _Watch(op_attribute, parallel_conf, blob_register):
    lbi = op_attribute.arg_signature.bn_in_op2lbi["in"]
    uuid = op_attribute.op_conf.foreign_watch_conf.handler_uuid
    lbn = "%s/%s" % (lbi.op_name, lbi.blob_name)
    in_blob_object = blob_register.GetObject4BlobName(lbn)
    if not isinstance(lbi, lbi_util.LogicalBlobId):
        cfg_lbi = lbi_util.LogicalBlobId()
        cfg_lbi.set_op_name(lbi.op_name)
        cfg_lbi.set_blob_name(lbi.blob_name)
        lbi = cfg_lbi
    if in_blob_object.op_arg_parallel_attr.is_mirrored():
        blob = oneflow._oneflow_internal.EagerMirroredBlob(
            lbi, in_blob_object, default_blob_register)
    else:
        blob = oneflow._oneflow_internal.EagerConsistentBlob(
            lbi, in_blob_object, default_blob_register)
    uuid2watch_handler = session_ctx.GetDefaultSession().uuid2watch_handler
    assert uuid in uuid2watch_handler
    uuid2watch_handler[uuid](blob)
    del uuid2watch_handler[uuid]
Example #30
def _EagerRunModelInit(var_op_conf):
    (op_conf, _) = _GenModelInitOpConfAndRetLbi(var_op_conf)
    bn_in_op2blob_object = oneflow._oneflow_internal.deprecated.BnInOp2BlobObject(
    )

    def BuildModelInitInstruction(builder):
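        # Infer the model-init op's attribute under the current scope, then launch it
        # as a stateless call that fills bn_in_op2blob_object.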
        upstream_signature = op_node_signature_pb.OpNodeSignature()
        op_conf.scope_symbol_id = flow.current_scope().symbol_id
        op_attribute = c_api_util.InferOpConf(op_conf, upstream_signature)
        parallel_conf = flow.current_scope(
        ).device_parallel_desc_symbol.parallel_conf
        cfg_op_attribute = oneflow._oneflow_internal.deprecated.MakeOpAttributeByString(
            str(op_attribute))
        builder.StatelessCall(cfg_op_attribute, parallel_conf,
                              bn_in_op2blob_object, boxing_util.BoxingTo)

    sess = session_ctx.GetDefaultSession()
    with scope_util.ScopeContext(scope_util.MakeScope(_BuildNotMirroredScope)):
        oneflow._oneflow_internal.deprecated.LogicalRun(
            BuildModelInitInstruction)
    return bn_in_op2blob_object["out_0"]