Ejemplo n.º 1
0
def get_eager_variable(
    name,
    shape=None,
    dtype=None,
    initializer=None,
    regularizer=None,
    trainable=None,
    model_name=None,
    random_seed=None,
    distribute=oneflow_api.distribute.broadcast(),
    reuse=True,
):
    """Get or create an eager consistent variable blob for the current job.

    The variable name is prefixed with the current job's name-scope prefix.
    If a blob for this name is already stashed for the current job, it is
    checked against the session-level blob and reused. Otherwise a variable
    op conf is generated and added to the current job; a new eager blob is
    created and initialized only when the session stash returned no blob
    for this name, and the resulting blob is stashed for the current job.

    Args:
        name: Variable name (before name-scope prefixing). Must be a str.
        shape: List or tuple of dimensions.
        dtype, initializer, regularizer, trainable, model_name, random_seed,
        distribute: Forwarded unchanged to ``GenerateVariableOpConf``.
            NOTE(review): the ``distribute`` default is evaluated once, at
            function-definition time, and shared across calls.
        reuse: When ``False``, asserts that no blob for this variable is
            already stashed for the current job.

    Returns:
        The ``EagerConsistentBlob`` for the variable. It is also registered
        with the default backward blob register under its logical blob name.

    Raises:
        AssertionError: On a non-str ``name``, a non-list/tuple ``shape``,
            an existing variable when ``reuse is False``, or an unexpected
            blob type / mismatch between the job and session blobs.
    """
    assert isinstance(name, str)
    assert isinstance(
        shape, (list, tuple)
    ), "param shape should be a list or tuple of dimension"

    job_name = oneflow_api.JobBuildAndInferCtx_GetCurrentJobName()
    # Qualify the user-supplied name with the job's current name-scope prefix.
    name = name_scope.GetJobNameScopePrefix(job_name) + name
    sess = session_ctx.GetDefaultSession()
    # NOTE(review): from the usage below, var_blob appears to be the blob the
    # session already holds for `name` and job_var_blob the one stashed for
    # this specific job; either may be None — confirm against the session API.
    var_blob, job_var_blob = sess.TryGetVariableBlobOfJobFromStash(job_name, name)

    if reuse is False:
        assert job_var_blob is None, (
            "variable '{}' already exists, "
            "getting the same variable is not allowed "
            "when reuse is False".format(name)
        )

    if job_var_blob is None:
        op_conf = GenerateVariableOpConf(
            name=name,
            shape=shape,
            dtype=dtype,
            initializer=initializer,
            regularizer=regularizer,
            trainable=trainable,
            model_name=model_name,
            random_seed=random_seed,
            distribute=distribute,
        )
        # The op is always added to the current job, even when an existing
        # session blob is about to be reused.
        op_attribute = compile_context.CurJobAddConsistentOp(op_conf)
        if var_blob is None:
            # Only create and initialize storage when no blob exists yet.
            var_blob = CreateEagerVariableBlob(op_attribute)
            op_executor.EagerInitVariableBlob(sess, op_conf, var_blob)

        assert isinstance(var_blob, remote_blob_util.EagerConsistentBlob)
        sess.StashVariableBlob4Job(job_name, op_conf.name, var_blob)
    else:
        # Already stashed for this job: it must be the same blob as the
        # session-level one.
        assert isinstance(job_var_blob, remote_blob_util.EagerConsistentBlob)
        assert isinstance(var_blob, remote_blob_util.EagerConsistentBlob)
        assert var_blob.IdenticalTo(job_var_blob)

    bw_blob_register = gradient_util.GetDefaultBackwardBlobRegister()
    bw_blob_register.TrySetObject4BlobName(
        var_blob.logical_blob_name, var_blob.blob_object
    )
    return var_blob
Ejemplo n.º 2
0
 def EagerAddAndInferOp(self, op_conf: op_conf_util.OperatorConf) -> Any:
     """Add ``op_conf`` to the current job under an eager-friendly placement.

     The placement is ``gpu`` on machine 0 when the current scope's
     device-parallel descriptor is a single-device GPU placement whose only
     machine id is 0. Otherwise it falls back to ``cpu`` at ``"0:0"``.

     Returns:
         Whatever ``compile_context.CurJobAddConsistentOp`` returns for
         ``op_conf``.
     """
     desc = oneflow.current_scope().device_parallel_desc_symbol
     # Short-circuit order matters: the device-id map is only inspected
     # once the tag is known to be "gpu".
     single_gpu_on_machine0 = (
         desc.device_tag == "gpu"
         and list(dict(desc.machine_id2device_id_list)) == [0]
         and desc.parallel_num == 1
     )
     if single_gpu_on_machine0:
         tag, ids = "gpu", "0:%s" % (desc.machine_id2device_id_list[0][0])
     else:
         tag, ids = "cpu", "0:0"
     with oneflow.scope.placement(tag, ids):
         return compile_context.CurJobAddConsistentOp(op_conf)
Ejemplo n.º 3
0
def _CreateVariableBlob(op_conf):
    """Add a variable op to the current job and return a remote blob for it.

    The returned ``RemoteBlob`` refers to the op's output through a
    ``LogicalBlobId`` built from ``op_conf.name`` and
    ``op_conf.variable_conf.out``. The value returned by
    ``CurJobAddConsistentOp`` is not used.
    """
    compile_context.CurJobAddConsistentOp(op_conf)
    blob_id = logical_blob_id_util.LogicalBlobId()
    # Identify the variable's output: (producing op name, output blob name).
    blob_id.op_name = op_conf.name
    blob_id.blob_name = op_conf.variable_conf.out
    return remote_blob_util.RemoteBlob(blob_id)
Ejemplo n.º 4
0
 def AddAndInferOp(self, op_conf: op_conf_util.OperatorConf) -> Any:
     """Add ``op_conf`` to the current job under the ambient placement.

     Unlike an eager variant, no placement override is applied here.

     Returns:
         The result of ``compile_context.CurJobAddConsistentOp(op_conf)``.
     """
     added = compile_context.CurJobAddConsistentOp(op_conf)
     return added
Ejemplo n.º 5
0
 def InferAndTryRun(self):
     """Add this object's op conf to the current job and return ``self``.

     Asserts that ``hob.in_global_mode(None)`` holds before doing anything,
     so the call is only valid in that mode. Returning ``self`` allows
     chaining.
     """
     assert hob.in_global_mode(None)
     # Read the stored op conf only after the mode check, as before.
     conf = self.op_conf_
     compile_context.CurJobAddConsistentOp(conf)
     return self