Example #1
0
def get_eager_variable(
    name,
    shape=None,
    dtype=None,
    initializer=None,
    regularizer=None,
    trainable=None,
    model_name=None,
    random_seed=None,
    nd_sbp=None,
    reuse=True,
):
    """Return the eager consistent variable blob called ``name`` for the current job.

    The name is first qualified with the current job's name-scope prefix.

    - If the session already has a blob stashed for this job under that name,
      identity checks are run and that blob is used.
    - Otherwise a variable op conf is generated and passed to
      ``compile_context.CurJobAddConsistentOp``. A new eager blob is created
      and initialized only when the session has no blob for the name at all.
      The blob is then stashed for this job.

    On both paths, the blob's ``blob_object`` is registered under its
    ``logical_blob_name`` with the default backward blob register.

    Args:
        name: Unqualified variable name; must be a ``str``.
        shape: Must be a list or tuple. The ``None`` default fails the
            assertion, so callers have to supply it.
        dtype, initializer, regularizer, trainable, model_name, random_seed,
        nd_sbp: Forwarded unchanged to ``GenerateVariableOpConf``.
        reuse: When ``False``, asserts that nothing is stashed for this job
            under the qualified name yet.

    Returns:
        The variable's ``EagerConsistentBlob``.
    """
    assert isinstance(name, str)
    assert isinstance(
        shape, (list, tuple)
    ), "param shape should be a list or tuple of dimension"

    current_job = oneflow._oneflow_internal.JobBuildAndInferCtx_GetCurrentJobName()
    scoped_name = name_scope.GetJobNameScopePrefix(current_job) + name
    session = session_ctx.GetDefaultSession()
    blob, stashed_for_job = session.TryGetVariableBlobOfJobFromStash(
        current_job, scoped_name
    )

    if reuse is False:
        assert (
            stashed_for_job is None
        ), "variable '{}' already exists, getting the same variable is not allowed when reuse is False".format(
            scoped_name
        )

    eager_blob_type = oneflow._oneflow_internal.EagerConsistentBlob
    if stashed_for_job is not None:
        # Already stashed for this job: it must be the same object as the
        # session-wide blob.
        assert isinstance(stashed_for_job, eager_blob_type)
        assert isinstance(blob, eager_blob_type)
        assert blob.IdenticalTo(stashed_for_job)
    else:
        op_conf = GenerateVariableOpConf(
            name=scoped_name,
            shape=shape,
            dtype=dtype,
            initializer=initializer,
            regularizer=regularizer,
            trainable=trainable,
            model_name=model_name,
            random_seed=random_seed,
            nd_sbp=nd_sbp,
        )
        op_attribute = compile_context.CurJobAddConsistentOp(op_conf)
        # Only build and initialize a new blob when the session has none yet;
        # otherwise the existing blob is reused for this job.
        if blob is None:
            blob = CreateEagerVariableBlob(op_attribute)
            op_executor.EagerInitVariableBlob(session, op_conf, blob)
        assert isinstance(blob, eager_blob_type)
        session.StashVariableBlob4Job(current_job, op_conf.name, blob)

    backward_register = gradient_util.GetDefaultBackwardBlobRegister()
    backward_register.TrySetObject4BlobName(blob.logical_blob_name, blob.blob_object)
    return blob
Example #2
0
def get_lazy_variable(
    name,
    shape=None,
    dtype=None,
    initializer=None,
    regularizer=None,
    trainable=None,
    model_name=None,
    random_seed=None,
    nd_sbp=None,
    reuse=True,
):
    """Return the lazy consistent variable blob called ``name`` for the current job.

    The name is first qualified with the current job's name-scope prefix.

    - If the session already has a blob stashed for this job under that name,
      identity checks are run and the stashed blob is returned.
    - Otherwise a variable op conf is generated, a blob is built from it with
      ``_CreateVariableBlob``, and that blob is stashed for this job. If the
      session already held a blob for the name, it must be identical to the
      new one.

    Args:
        name: Unqualified variable name; must be a ``str``.
        shape: Must be a list or tuple. The ``None`` default fails the
            assertion, so callers have to supply it.
        dtype, initializer, regularizer, trainable, model_name, random_seed,
        nd_sbp: Forwarded unchanged to ``GenerateVariableOpConf``.
        reuse: When ``False``, asserts that nothing is stashed for this job
            under the qualified name yet.

    Returns:
        The job's ``LazyConsistentBlob`` for the variable.
    """
    assert isinstance(name, str)
    assert isinstance(
        shape, (list, tuple)
    ), "param shape should be a list or tuple of dimension"

    current_job = oneflow._oneflow_internal.JobBuildAndInferCtx_GetCurrentJobName()
    scoped_name = name_scope.GetJobNameScopePrefix(current_job) + name
    session = session_ctx.GetDefaultSession()
    blob, stashed_for_job = session.TryGetVariableBlobOfJobFromStash(
        current_job, scoped_name
    )

    if reuse is False:
        assert (
            stashed_for_job is None
        ), "variable '{}' already exists, getting the same variable is not allowed when param reuse is False".format(
            scoped_name
        )

    lazy_blob_type = oneflow._oneflow_internal.LazyConsistentBlob
    if stashed_for_job is not None:
        # Already stashed for this job: verify it matches the session blob.
        assert isinstance(stashed_for_job, lazy_blob_type)
        assert isinstance(blob, lazy_blob_type)
        assert blob.IdenticalTo(stashed_for_job)
        return stashed_for_job

    op_conf = GenerateVariableOpConf(
        name=scoped_name,
        shape=shape,
        dtype=dtype,
        initializer=initializer,
        regularizer=regularizer,
        trainable=trainable,
        model_name=model_name,
        random_seed=random_seed,
        nd_sbp=nd_sbp,
    )
    created = _CreateVariableBlob(op_conf)
    assert isinstance(created, lazy_blob_type)
    session.StashVariableBlob4Job(current_job, op_conf.name, created)
    # A session-wide blob from elsewhere must agree with the one just created.
    if blob is not None:
        assert isinstance(blob, lazy_blob_type)
        assert blob.IdenticalTo(created)
    return created
Example #3
0
    def OpName(self, op_name):
        """Set the op's name and rename all of its output blob names to match.

        The stored name is ``op_name`` prefixed with the current job's
        name-scope prefix. Each entry ``i`` of every output ``output_name`` in
        the op's ``user_conf`` is rewritten to
        ``"<scoped op name>/<output_name>_<i>"``.

        Returns:
            ``self``, so calls can be chained.
        """
        current_job = oneflow._oneflow_internal.JobBuildAndInferCtx_GetCurrentJobName()
        scoped = name_scope.GetJobNameScopePrefix(current_job) + op_name
        self.user_op_.op_conf_.name = scoped

        for output_name, output in self.user_op_.op_conf_.user_conf.output.items():
            # Rebuild the whole list so its length is preserved; only the
            # names change.
            output.s[:] = [
                "{}/{}_{}".format(scoped, output_name, idx)
                for idx, _ in enumerate(output.s)
            ]
        return self
Example #4
0
def _FindOrCreateVarBlobObject(op_attribute, parallel_conf, blob_register):
    """Make the variable blob described by ``op_attribute`` available in ``blob_register``.

    If the session already has a stashed blob for the variable, that blob's
    ``blob_object`` is registered under its ``logical_blob_name``. Otherwise:

    1. The op is interpreted with ``_NaiveInterpret``.
    2. The ``"out"`` blob is wrapped as an eager logical blob.
    3. The blob is initialized with ``EagerInitVariableBlob``.
    4. The blob is stashed for the current job.

    Args:
        op_attribute: Attribute of the variable op; ``op_attribute.op_conf.name``
            is used to build the stash key.
        parallel_conf: Forwarded to ``_NaiveInterpret``.
        blob_register: Register that receives the variable's blob object.

    Returns:
        The variable's blob on both paths, whether it was found in the stash or
        newly created.
    """
    job_name = oneflow._oneflow_internal.JobBuildAndInferCtx_GetCurrentJobName()
    # NOTE(review): the stash is queried with the name-scope-prefixed name, but
    # StashVariableBlob4Job below writes under the bare op_conf.name. The two
    # keys only coincide when the prefix is empty. Confirm which key is
    # intended before relying on stash hits under a non-empty name scope.
    name = name_scope.GetJobNameScopePrefix(job_name) + op_attribute.op_conf.name
    sess = session_ctx.GetDefaultSession()
    var_blob, _ = sess.TryGetVariableBlobOfJobFromStash(job_name, name)
    if var_blob is not None:
        blob_register.SetObject4BlobName(
            var_blob.logical_blob_name, var_blob.blob_object
        )
        # Return the found blob too; previously this path returned None,
        # unlike the create path below.
        return var_blob

    _NaiveInterpret(op_attribute, parallel_conf, blob_register)
    var_blob = _MakeEagerLogicalBlob(op_attribute, "out", blob_register=blob_register)
    EagerInitVariableBlob(sess, op_attribute.op_conf, var_blob)
    sess.StashVariableBlob4Job(job_name, op_attribute.op_conf.name, var_blob)
    return var_blob
Example #5
0
def eager_consistent_user_op_module_builder(op_type_name):
    """Return a ``UserOpModuleBuilder`` for ``EagerConsistentUserOpModule``.

    The op name is ``op_type_name`` prefixed with the current job's name-scope
    prefix. The unprefixed ``op_type_name`` is forwarded as the builder's
    third argument.
    """
    scope_prefix = name_scope.GetJobNameScopePrefix(
        oneflow._oneflow_internal.JobBuildAndInferCtx_GetCurrentJobName()
    )
    scoped_op_name = scope_prefix + op_type_name
    return UserOpModuleBuilder(
        EagerConsistentUserOpModule, scoped_op_name, op_type_name
    )
Example #6
0
def api_consistent_user_op_builder(op_name):
    """Return a ``UserOpConfBuilder`` for ``ConsistentUserOp``.

    The op name is ``op_name`` prefixed with the current job's name-scope
    prefix. ``None`` is passed as the builder's third argument.
    """
    scope_prefix = name_scope.GetJobNameScopePrefix(
        oneflow._oneflow_internal.JobBuildAndInferCtx_GetCurrentJobName()
    )
    scoped_op_name = scope_prefix + op_name
    return UserOpConfBuilder(ConsistentUserOp, scoped_op_name, None)