Beispiel #1
0
def get_eager_variable(
    name,
    shape=None,
    dtype=None,
    initializer=None,
    regularizer=None,
    trainable=None,
    model_name=None,
    random_seed=None,
    distribute=oneflow_api.distribute.broadcast(),  # evaluated once, when the def runs
    reuse=True,
):
    """Get or create an eager consistent variable blob for the current job.

    The variable name is prefixed with the current job's name scope. If the
    session stash already holds a job-level blob under that name, it is
    reused after checking that it is identical to the session-level blob.
    Otherwise a variable op is added to the current job. If no session-level
    blob exists yet either, a new eager blob is created and initialized, and
    the result is stashed for the job.

    Args:
        name: variable name, before the job name-scope prefix is applied.
        shape: list or tuple of dimensions.
        dtype, initializer, regularizer, trainable, model_name, random_seed,
        distribute: forwarded unchanged to ``GenerateVariableOpConf``.
        reuse: when ``False``, asserts that the variable does not already
            exist in the current job.

    Returns:
        The ``remote_blob_util.EagerConsistentBlob`` for the variable. Its blob
        object is also registered with the default backward blob register.
    """
    assert isinstance(name, str)
    assert isinstance(
        shape, (list, tuple)
    ), "param shape should be a list or tuple of dimension"

    job_name = oneflow_api.JobBuildAndInferCtx_GetCurrentJobName()
    name = name_scope.GetJobNameScopePrefix(job_name) + name
    sess = session_ctx.GetDefaultSession()
    # var_blob: session-level blob; job_var_blob: blob already stashed for this job.
    var_blob, job_var_blob = sess.TryGetVariableBlobOfJobFromStash(job_name, name)

    if reuse is False:
        assert job_var_blob is None, (
            "variable '{}' already exists, "
            "getting the same variable is not allowed "
            "when reuse is False".format(name)
        )

    if job_var_blob is None:
        op_conf = GenerateVariableOpConf(
            name=name,
            shape=shape,
            dtype=dtype,
            initializer=initializer,
            regularizer=regularizer,
            trainable=trainable,
            model_name=model_name,
            random_seed=random_seed,
            distribute=distribute,
        )
        op_attribute = compile_context.CurJobAddConsistentOp(op_conf)
        # Only materialize and initialize a new eager blob when the session has none.
        if var_blob is None:
            var_blob = CreateEagerVariableBlob(op_attribute)
            op_executor.EagerInitVariableBlob(sess, op_conf, var_blob)

        assert isinstance(var_blob, remote_blob_util.EagerConsistentBlob)
        sess.StashVariableBlob4Job(job_name, op_conf.name, var_blob)
    else:
        assert isinstance(job_var_blob, remote_blob_util.EagerConsistentBlob)
        assert isinstance(var_blob, remote_blob_util.EagerConsistentBlob)
        assert var_blob.IdenticalTo(job_var_blob)

    bw_blob_register = gradient_util.GetDefaultBackwardBlobRegister()
    bw_blob_register.TrySetObject4BlobName(
        var_blob.logical_blob_name, var_blob.blob_object
    )
    return var_blob
Beispiel #2
0
def get_lazy_variable(
    name,
    shape=None,
    dtype=None,
    initializer=None,
    regularizer=None,
    trainable=None,
    model_name=None,
    random_seed=None,
    parallel_distribution=None,
    reuse=True,
):
    """Get or create a lazy consistent variable blob for the current job.

    The name is prefixed with the current job's name scope. If a blob for
    that name is already stashed for the job, it is checked against the
    session-level blob and returned. Otherwise a new variable blob is
    created from a generated op conf and stashed for the job. In that case,
    any pre-existing session-level blob must be identical to the new one.

    Args:
        name: variable name, before the name-scope prefix is applied.
        shape: list or tuple of dimensions.
        dtype, initializer, regularizer, trainable, model_name, random_seed,
        parallel_distribution: forwarded unchanged to ``GenerateVariableOpConf``.
        reuse: when ``False``, asserts the variable is not already in the job.

    Returns:
        The job-level ``LazyConsistentBlob`` for the variable.
    """
    assert isinstance(name, str)
    assert isinstance(shape, (list, tuple)), (
        "param shape should be a list or tuple of dimension"
    )

    internal = oneflow._oneflow_internal
    current_job = internal.JobBuildAndInferCtx_GetCurrentJobName()
    scoped_name = name_scope.GetJobNameScopePrefix(current_job) + name
    session = session_ctx.GetDefaultSession()
    session_blob, job_blob = session.TryGetVariableBlobOfJobFromStash(
        current_job, scoped_name
    )

    if reuse is False:
        assert job_blob is None, (
            "variable '{}' already exists, "
            "getting the same variable is not allowed "
            "when param reuse is False".format(scoped_name)
        )

    # Fast path: the job already owns this variable.
    if job_blob is not None:
        assert isinstance(job_blob, internal.LazyConsistentBlob)
        assert isinstance(session_blob, internal.LazyConsistentBlob)
        assert session_blob.IdenticalTo(job_blob)
        return job_blob

    variable_conf = GenerateVariableOpConf(
        name=scoped_name,
        shape=shape,
        dtype=dtype,
        initializer=initializer,
        regularizer=regularizer,
        trainable=trainable,
        model_name=model_name,
        random_seed=random_seed,
        parallel_distribution=parallel_distribution,
    )
    job_blob = _CreateVariableBlob(variable_conf)
    assert isinstance(job_blob, internal.LazyConsistentBlob)
    session.StashVariableBlob4Job(current_job, variable_conf.name, job_blob)
    if session_blob is not None:
        assert isinstance(session_blob, internal.LazyConsistentBlob)
        assert session_blob.IdenticalTo(job_blob)
    return job_blob
Beispiel #3
0
    def OpName(self, op_name):
        """Set this op's name and rename all of its output blob names.

        The given name is prefixed with the current job's name scope. Each
        output entry ``i`` of output ``output_name`` becomes
        ``"<scoped op name>/<output_name>_<i>"``. Returns ``self`` so calls
        can be chained.
        """
        current_job = c_api_util.JobBuildAndInferCtx_GetCurrentJobName()
        scoped_name = name_scope.GetJobNameScopePrefix(current_job) + op_name

        op_conf = self.user_op_.op_conf_
        op_conf.name = scoped_name
        outputs = op_conf.user_conf.output

        for output_name, output in outputs.items():
            output.s[:] = [
                "{}/{}_{}".format(scoped_name, output_name, index)
                for index in range(len(output.s))
            ]
        return self
Beispiel #4
0
def _FindOrCreateVarBlobObject(op_attribute, parallel_conf, blob_register):
    """Register the eager variable blob for ``op_attribute``, creating it if needed.

    If the session stash already holds a blob for this variable in the
    current job, that blob's object is registered in ``blob_register``.
    Otherwise the op is interpreted, a new eager logical blob is built from
    its ``"out"`` output, the blob is initialized, and it is stashed for the
    job.

    Returns:
        The eager variable blob, on both the lookup and the creation path.
    """
    job_name = oneflow_api.JobBuildAndInferCtx_GetCurrentJobName()
    # NOTE(review): the lookup key is name-scope prefixed, but the stash below
    # uses the raw op_conf.name; confirm the two keys agree for variables
    # created under a non-empty name scope.
    name = name_scope.GetJobNameScopePrefix(job_name) + op_attribute.op_conf.name
    sess = session_ctx.GetDefaultSession()
    var_blob, _ = sess.TryGetVariableBlobOfJobFromStash(job_name, name)
    if var_blob is not None:
        blob_register.SetObject4BlobName(
            var_blob.logical_blob_name, var_blob.blob_object
        )
        # Return the found blob so both paths have the same return contract
        # (this path used to return None).
        return var_blob
    _NaiveInterpret(op_attribute, parallel_conf, blob_register)
    var_blob = _MakeEagerLogicalBlob(op_attribute, "out", blob_register=blob_register)
    EagerInitVariableBlob(sess, op_attribute.op_conf, var_blob)
    sess.StashVariableBlob4Job(job_name, op_attribute.op_conf.name, var_blob)
    return var_blob
Beispiel #5
0
def eager_consistent_user_op_module_builder(op_type_name):
    """Return a ``UserOpModuleBuilder`` for an ``EagerConsistentUserOpModule``.

    The op is named ``op_type_name`` prefixed with the current job's name scope.
    """
    current_job = oneflow_api.JobBuildAndInferCtx_GetCurrentJobName()
    scope_prefix = name_scope.GetJobNameScopePrefix(current_job)
    return UserOpModuleBuilder(
        EagerConsistentUserOpModule, scope_prefix + op_type_name, op_type_name
    )
Beispiel #6
0
def api_consistent_user_op_builder(op_name):
    """Return a ``UserOpConfBuilder`` for a ``ConsistentUserOp`` in the current job.

    ``op_name`` is prefixed with the current job's name scope.
    """
    current_job = oneflow_api.JobBuildAndInferCtx_GetCurrentJobName()
    scoped_name = name_scope.GetJobNameScopePrefix(current_job) + op_name
    return UserOpConfBuilder(ConsistentUserOp, scoped_name, None)
def lazy_consistent_user_op_module_builder(op_type_name):
    """Return a ``UserOpModuleBuilder`` for a ``LazyConsistentUserOpModule``.

    The op is named ``op_type_name`` prefixed with the current job's name scope.
    """
    current_job = c_api_util.JobBuildAndInferCtx_GetCurrentJobName()
    scope_prefix = name_scope.GetJobNameScopePrefix(current_job)
    return UserOpModuleBuilder(
        LazyConsistentUserOpModule, scope_prefix + op_type_name, op_type_name
    )
def eager_logical_user_op_module_builder(op_type_name):
    """Return a ``UserOpModuleBuilder`` for an ``EagerLogicalUserOpModule``.

    The op is named ``op_type_name`` prefixed with the current job's name scope.
    """
    current_job = c_api_util.JobBuildAndInferCtx_GetCurrentJobName()
    scoped_name = name_scope.GetJobNameScopePrefix(current_job) + op_type_name
    return UserOpModuleBuilder(EagerLogicalUserOpModule, scoped_name, op_type_name)
def eager_user_op_builder(op_name):
    """Return a ``UserOpConfBuilder`` for an ``EagerUserOp`` in the current job.

    ``op_name`` is prefixed with the current job's name scope.
    """
    current_job = c_api_util.JobBuildAndInferCtx_GetCurrentJobName()
    scope_prefix = name_scope.GetJobNameScopePrefix(current_job)
    return UserOpConfBuilder(EagerUserOp, scope_prefix + op_name, None)
Beispiel #10
0
 def __init__(self, job_name, op_name, user_op_class):
     """Build ``self.user_op_`` from ``user_op_class`` and a name-scoped op name.

     ``op_name`` is prefixed with the name-scope prefix of ``job_name``.
     """
     scoped_op_name = name_scope.GetJobNameScopePrefix(job_name) + op_name
     self.user_op_ = user_op_class(scoped_op_name)
Beispiel #11
0
def lazy_user_op_module_builder(op_type_name):
    """Return a ``UserOpModuleBuilder`` for a ``LazyUserOpModule``.

    The op is named ``op_type_name`` prefixed with the current job's name scope.
    """
    internal = oneflow._oneflow_internal
    current_job = internal.JobBuildAndInferCtx_GetCurrentJobName()
    scope_prefix = name_scope.GetJobNameScopePrefix(current_job)
    return UserOpModuleBuilder(
        LazyUserOpModule, scope_prefix + op_type_name, op_type_name
    )