Code example #1
0
File: interpret_util.py  Project: zjureel/oneflow
def EagerOpKernelForward(add_and_infer, op_conf, opkernel_object):
    """Run an eager op-kernel forward step and return its inferred attribute.

    Infers the op attribute with *add_and_infer* under the op-kernel
    object's scope, executes the kernel, then records any blob objects
    the backward pass will need.
    """
    # Inference happens under the scope carried by the op-kernel object.
    op_attribute = add_and_infer(op_conf, opkernel_object.scope_symbol)
    # Execute the kernel against the module-level default blob register.
    op_executor.OpKernelCall(opkernel_object, op_attribute, blob_register)
    # Keep blob objects referenced by the backward pass alive.
    backward_register = gradient_util.GetDefaultBackwardBlobRegister()
    gradient_util.TrySetBackwardUsedBlobObject(
        op_attribute, blob_register, backward_register
    )
    return op_attribute
Code example #2
0
File: interpret_util.py  Project: zjureel/oneflow
def EagerForward(add_and_infer, op_conf, scope_symbol=None):
    """Interpret an op eagerly under *scope_symbol* and return its attribute.

    NOTE(review): although ``scope_symbol`` defaults to None, it is
    dereferenced unconditionally below — presumably callers always pass a
    real scope symbol; confirm against call sites.
    """
    op_attribute = add_and_infer(op_conf, scope_symbol)
    # The parallel configuration comes from the scope's device placement.
    parallel_conf = scope_symbol.device_parallel_desc_symbol.parallel_conf
    op_executor.Interpret(op_attribute, parallel_conf, blob_register)
    # Keep blob objects referenced by the backward pass alive.
    backward_register = gradient_util.GetDefaultBackwardBlobRegister()
    gradient_util.TrySetBackwardUsedBlobObject(
        op_attribute, blob_register, backward_register
    )
    return op_attribute
Code example #3
0
def get_eager_variable(
    name,
    shape=None,
    dtype=None,
    initializer=None,
    regularizer=None,
    trainable=None,
    model_name=None,
    random_seed=None,
    distribute=oneflow_api.distribute.broadcast(),
    reuse=True,
):
    """Get or create an eager consistent variable blob for the current job.

    Looks the variable up in the default session's per-job stash; if it is
    not there yet, builds the variable op, creates and initializes the
    blob, and stashes it. The blob is also registered with the backward
    blob register so gradients can reference it.

    Args:
        name: variable name; prefixed with the current job's name scope.
        shape: required list/tuple of dimensions (asserted below despite
            the None default).
        dtype, initializer, regularizer, trainable, model_name,
        random_seed, distribute: forwarded to GenerateVariableOpConf.
        reuse: when False, an already-existing variable of the same name
            in this job is an error.

    Returns:
        remote_blob_util.EagerConsistentBlob for the variable.
    """
    assert isinstance(name, str)
    assert isinstance(
        shape, (list, tuple)
    ), "param shape should be a list or tuple of dimension"

    job_name = oneflow_api.JobBuildAndInferCtx_GetCurrentJobName()
    name = name_scope.GetJobNameScopePrefix(job_name) + name
    sess = session_ctx.GetDefaultSession()
    var_blob, job_var_blob = sess.TryGetVariableBlobOfJobFromStash(job_name, name)

    if reuse is False:
        # Fixed typo in the user-facing message ("varaible" -> "variable").
        assert job_var_blob is None, (
            "variable '{}' already exists, "
            "getting the same variable is not allowed "
            "when reuse is False".format(name)
        )

    if job_var_blob is None:
        op_conf = GenerateVariableOpConf(
            name=name,
            shape=shape,
            dtype=dtype,
            initializer=initializer,
            regularizer=regularizer,
            trainable=trainable,
            model_name=model_name,
            random_seed=random_seed,
            distribute=distribute,
        )
        op_attribute = compile_context.CurJobAddConsistentOp(op_conf)
        if var_blob is None:
            # First time this variable is seen in any job: create + init it.
            var_blob = CreateEagerVariableBlob(op_attribute)
            op_executor.EagerInitVariableBlob(sess, op_conf, var_blob)

        assert isinstance(var_blob, remote_blob_util.EagerConsistentBlob)
        sess.StashVariableBlob4Job(job_name, op_conf.name, var_blob)
    else:
        # Already stashed for this job: sanity-check it is the same blob.
        assert isinstance(job_var_blob, remote_blob_util.EagerConsistentBlob)
        assert isinstance(var_blob, remote_blob_util.EagerConsistentBlob)
        assert var_blob.IdenticalTo(job_var_blob)

    # Make the variable reachable from the backward pass.
    bw_blob_register = gradient_util.GetDefaultBackwardBlobRegister()
    bw_blob_register.TrySetObject4BlobName(
        var_blob.logical_blob_name, var_blob.blob_object
    )
    return var_blob
Code example #4
0
def MirroredCast(op_attribute_str, parallel_conf):
    """Perform a mirrored cast described by a serialized OpAttribute.

    Parses *op_attribute_str*, verifies the op is one of the two mirrored
    cast variants, runs the cast, and records blob objects used by the
    backward pass.
    """
    op_attribute = text_format.Parse(op_attribute_str, op_attribute_pb.OpAttribute())
    blob_register = oneflow_api.GetDefaultBlobRegister()
    conf = op_attribute.op_conf
    # Only cast_to_mirrored / cast_from_mirrored ops are valid here.
    assert conf.HasField("cast_to_mirrored_conf") or conf.HasField(
        "cast_from_mirrored_conf"
    )
    _MirroredCastAndAddOutputBlobReleaser(op_attribute, blob_register)
    backward_register = gradient_util.GetDefaultBackwardBlobRegister()
    gradient_util.TrySetBackwardUsedBlobObject(
        op_attribute, blob_register, backward_register
    )
Code example #5
0
    def RemoteBlobList(self):
        """Build and return the tuple of this op's output remote blobs.

        Validates that every output declared in the op conf was registered
        with the builder, then materializes one remote blob per output
        slot, registering each with the backward blob register when eager
        execution is enabled.
        """
        # Every declared output must have been set via the python builder.
        for output_name in self.op_conf_.user_conf.output:
            if output_name not in self.output_arg_key_list_:
                raise ValueError(
                    "output_arg_name {} of {} op is not set in python op builder"
                    .format(output_name, self.op_conf_.name))

        blobs = []
        for output_arg_name in self.output_arg_key_list_:
            assert output_arg_name in self.op_conf_.user_conf.output
            slot_count = len(self.op_conf_.user_conf.output[output_arg_name].s)
            for slot in range(slot_count):
                lbi = logical_blob_id_util.LogicalBlobId()
                lbi.op_name = self.op_conf_.name
                lbi.blob_name = "{}_{}".format(output_arg_name, slot)
                blob = self.MakeRemoteBlob(lbi)
                blobs.append(blob)
                if flow.eager_execution_enabled():
                    # Eager mode: make the blob visible to the backward pass.
                    gradient_util.GetDefaultBackwardBlobRegister(
                    ).TrySetObject4BlobName(blob.logical_blob_name,
                                            blob.blob_object)

        return tuple(blobs)
Code example #6
0
def InterpretCompletedOp(op_attribute_str, parallel_conf):
    """Interpret a completed op from its serialized OpAttribute.

    NOTE(review): this uses the *backward* blob register as the working
    register, unlike sibling code that uses the default blob register —
    presumably intentional here; confirm against the caller.
    """
    op_attribute = text_format.Parse(
        op_attribute_str, op_attribute_pb.OpAttribute()
    )
    blob_register = gradient_util.GetDefaultBackwardBlobRegister()
    _InterpretCompletedOp(op_attribute, parallel_conf, blob_register)
    # Drop blob objects the backward pass will not need.
    gradient_util.ReleaseUnusedBlobObject(op_attribute, blob_register)
Code example #7
0
def eager_add_loss(loss):
    """Register *loss* as a loss blob of the current job (eager mode).

    Declares the loss to the build-and-infer context and keeps its blob
    object reachable from the backward blob register.
    """
    c_api_util.CurJobBuildAndInferCtx_AddLossLogicalBlobName(loss.unique_name)
    backward_register = gradient_util.GetDefaultBackwardBlobRegister()
    backward_register.TrySetObject4BlobName(
        loss.logical_blob_name, loss.blob_object
    )