Ejemplo n.º 1
0
def _GetPhysicalBlobObjects(logical_blob_object, lbi):
    """Unpack a logical blob object into its physical blob objects.

    Runs ``builder.UnpackLogicalBlobToPhysicalBlobs`` inside an
    ``oneflow_api.deprecated.LogicalRun`` instruction and returns whatever
    that call produced.

    Args:
        logical_blob_object: the logical blob object to unpack.
        lbi: not read by this function; kept so existing call sites keep
            working.

    Returns:
        The value produced by ``UnpackLogicalBlobToPhysicalBlobs``, or
        ``None`` if ``LogicalRun`` never invoked the builder callback.
    """
    physical_blob_objects = None

    def BuildLogical2PhysicalInstruction(builder):
        # The result is captured out of the callback; the return value of
        # LogicalRun is not used.
        nonlocal physical_blob_objects
        physical_blob_objects = builder.UnpackLogicalBlobToPhysicalBlobs(
            logical_blob_object
        )

    oneflow_api.deprecated.LogicalRun(BuildLogical2PhysicalInstruction)
    return physical_blob_objects
Ejemplo n.º 2
0
def MirroredCast(op_attribute_str, parallel_conf):
    """Execute a cast-to-mirrored or cast-from-mirrored op from its attribute text.

    Parses ``op_attribute_str`` into an ``OpAttribute`` message. It then
    hands the message to ``_MirroredCastAndAddOutputBlobReleaser`` together
    with the default blob register. Finally it calls
    ``gradient_util.TrySetBackwardUsedBlobObject`` with both the default and
    the backward blob registers.

    Args:
        op_attribute_str: text-format ``OpAttribute``; its ``op_conf`` must
            set ``cast_to_mirrored_conf`` or ``cast_from_mirrored_conf``.
        parallel_conf: not read by this function.

    Raises:
        AssertionError: if neither cast field is set on ``op_conf``.
    """
    op_attribute = text_format.Parse(
        op_attribute_str, op_attribute_pb.OpAttribute()
    )
    forward_register = oneflow_api.GetDefaultBlobRegister()
    op_conf = op_attribute.op_conf
    # Only the two mirrored-cast directions are accepted here.
    casts_to_mirrored = op_conf.HasField("cast_to_mirrored_conf")
    casts_from_mirrored = op_conf.HasField("cast_from_mirrored_conf")
    assert casts_to_mirrored or casts_from_mirrored
    _MirroredCastAndAddOutputBlobReleaser(op_attribute, forward_register)
    backward_register = gradient_util.GetDefaultBackwardBlobRegister()
    gradient_util.TrySetBackwardUsedBlobObject(
        op_attribute, forward_register, backward_register
    )
Ejemplo n.º 3
0
import oneflow
import oneflow.python.framework.input_blob_def as input_blob_def
import oneflow.python.framework.dtype as dtype_util
import oneflow.python.framework.python_callback as python_callback
import oneflow.python.framework.balanced_splitter as balanced_splitter
import oneflow.python.framework.remote_blob as remote_blob_util
import oneflow.python.framework.id_util as id_util
import oneflow.python.eager.boxing_util as boxing_util
import oneflow.core.operator.op_conf_pb2 as op_conf_util
import oneflow.core.register.logical_blob_id_pb2 as logical_blob_id_util
import oneflow_api.oneflow.core.register.logical_blob_id as lbi_util
import oneflow_api
import numpy
from functools import reduce

# Module-level handle to the default blob register, obtained once when this
# module is imported.
blob_register = oneflow_api.GetDefaultBlobRegister()


def AsyncPush(session, job_func, *arg):
    """Push each positional argument against the matching input blob def.

    Pairs the i-th entry of ``job_func.__oneflow_input_blob_defs__`` with the
    i-th value in ``arg``. Each pair is forwarded to ``_AsyncPushArg``
    together with ``session``.

    Args:
        session: passed through unchanged to every ``_AsyncPushArg`` call.
        job_func: object exposing ``__oneflow_input_blob_defs__``, with one
            entry per expected argument.
        *arg: values to push, in the same order as the input blob defs.

    Raises:
        AssertionError: if the number of values differs from the number of
            input blob defs.
    """
    input_blob_defs = job_func.__oneflow_input_blob_defs__
    assert len(arg) == len(input_blob_defs), "%s v.s. %s" % (
        len(arg),
        len(input_blob_defs),
    )
    # Lengths are checked above, so zip pairs every def with exactly one value.
    for blob_def, value in zip(input_blob_defs, arg):
        _AsyncPushArg(session, blob_def, value)


def _AsyncPushArg(session, arg_blob_def, arg_ndarray):
    if isinstance(arg_blob_def, (list, tuple)):
        assert isinstance(
            arg_ndarray,
            (list, tuple)), "type(arg_ndarray): %s" % (type(arg_ndarray))
        assert len(arg_blob_def) == len(arg_ndarray), "%s v.s. %s" % (
            len(arg_blob_def),
Ejemplo n.º 4
0
 def ForceReleaseEagerBlobs(self):
     """Call ``ForceReleaseAll`` on the default and the backward blob registers."""
     default_register = oneflow_api.GetDefaultBlobRegister()
     default_register.ForceReleaseAll()
     # The backward register is owned by this instance.
     self.backward_blob_register_.ForceReleaseAll()