Exemplo n.º 1
0
def FeedValueToEagerBlob(blob_object, blob_def, ndarray):
    """Write ``ndarray`` into every physical blob that backs ``blob_object``.

    A single ``FeedContext`` is built from the blob's ``op_arg_parallel_attr``
    and the array. For each physical blob object returned by
    ``_GetPhysicalBlobObjects``, the context's rank is set to that object's
    position in the sequence before ``_FeedValueToInputPhysicalBlob`` is
    invoked. Finally ``oneflow_api.TryDisableBlobCache`` is called on the
    logical ``blob_object``.

    Args:
        blob_object: the logical eager blob object to feed.
        blob_def: blob definition forwarded unchanged to each per-rank feed.
        ndarray: the value to feed; handed to ``FeedContext``.
    """
    per_rank_objects = _GetPhysicalBlobObjects(blob_object, None)
    context = FeedContext(blob_object.op_arg_parallel_attr, ndarray)
    for rank, per_rank_object in enumerate(per_rank_objects):
        # The context must know which rank it is serving before each feed.
        context.set_rank(rank)
        _FeedValueToInputPhysicalBlob(context, blob_def, per_rank_object)
    # NOTE(review): presumably invalidates any cached view of the blob now
    # that its contents changed — confirm against TryDisableBlobCache.
    oneflow_api.TryDisableBlobCache(blob_object)
Exemplo n.º 2
0
def _LogicalSliceAssign(
    ref_blob: EagerBlobTrait,
    value_blob: EagerBlobTrait,
    start: Sequence[int],
    stop: Sequence[int],
) -> None:
    """Assign ``value_blob`` into the region of ``ref_blob`` given by start/stop.

    A ``logical_slice_assign`` user op is built with a step of 1 on every
    axis (one step entry per element of ``start``). It is inferred against
    the two blobs' objects and executed through
    ``oneflow_api.deprecated.LogicalRun``. Afterwards
    ``oneflow_api.TryDisableBlobCache`` is called on ``ref_blob``'s blob object.

    Args:
        ref_blob: eager blob that is written to (bound as input ``ref``).
        value_blob: eager blob supplying the data (bound as input ``value``).
        start: per-axis start indices, copied into the ``start`` attr.
        stop: per-axis stop indices, copied into the ``stop`` attr.
    """
    ref_object = ref_blob.blob_object
    value_object = value_blob.blob_object

    def _emit_slice_assign(builder):
        conf = op_conf_pb.OperatorConf()
        # The device tag is irrelevant to logical_slice_assign; the current
        # scope's tag is used as-is.
        conf.device_tag = oneflow.current_scope(
        ).device_parallel_desc_symbol.device_tag
        unique_name = id_util.UniqueStr(OP_PREFIX)
        conf.name = unique_name

        user_conf = conf.user_conf
        user_conf.op_type_name = "logical_slice_assign"
        user_conf.input["value"].s.append("{}/value_0".format(unique_name))
        user_conf.input["ref"].s.append("{}/ref_0".format(unique_name))

        # The op runs on the placement of the blob being written to.
        placement = ref_object.parallel_desc_symbol.parallel_conf
        attrs = user_conf.attr
        attrs["parallel_conf"].at_string = str(placement)
        attrs["start"].at_list_int64.val[:] = start
        attrs["stop"].at_list_int64.val[:] = stop
        attrs["step"].at_list_int64.val[:] = [1] * len(start)

        blob_map = oneflow_api.deprecated.BnInOp2BlobObject()
        blob_map["ref_0"] = ref_object
        blob_map["value_0"] = value_object
        inferred = op_infer_util.Infer(
            conf, blob_map, _GetScopeSymbolIdFromEagerBlob(ref_blob))
        builder.StatelessCall(
            oneflow_api.deprecated.MakeOpAttributeByString(str(inferred)),
            placement,
            blob_map,
            boxing_util.BoxingTo,
        )

    oneflow_api.deprecated.LogicalRun(_emit_slice_assign)
    oneflow_api.TryDisableBlobCache(ref_object)
Exemplo n.º 3
0
def BuildAssignInstruction(builder, ref_blob_object, value_blob_object,
                           op_conf):
    """Emit an instruction on ``builder`` that runs ``op_conf`` to assign
    ``value_blob_object`` into ``ref_blob_object``.

    ``oneflow_api.TryDisableBlobCache`` is called on the ref blob first.
    Both blobs must report the same ``machine_id2device_id_list``; this is
    checked with ``assert``, so it is skipped under ``python -O``.
    ``op_conf`` is inferred with the two blobs bound as inputs ``ref`` and
    ``value``. The call issued then depends on the two device tags:

    * same tag: ``NoBoxingStatelessCall`` on the ref placement;
    * ref ``cpu`` / value ``gpu``: ``NoBoxingCudaD2HStatelessCall`` on the
      value placement, with ``TryReplaceDeviceTag``;
    * ref ``gpu`` / value ``cpu``: ``NoBoxingCudaH2DStatelessCall`` on the ref
      placement, issued inside ``_CudaHostPinBlob(builder, value_blob_object)``
      (presumably pins the host-side value buffer — confirm).

    Raises:
        AssertionError: device lists differ (only when asserts are enabled).
        NotImplementedError: any other combination of device tags.
    """
    oneflow_api.TryDisableBlobCache(ref_blob_object)
    ref_symbol = ref_blob_object.parallel_desc_symbol
    value_symbol = value_blob_object.parallel_desc_symbol
    ref_parallel_conf = ref_symbol.parallel_conf

    ref_devices = ref_symbol.machine_id2device_id_list
    value_devices = value_symbol.machine_id2device_id_list
    assert ref_devices == value_devices, "\nref_devices: %s\nvalue_devices: %s" % (
        ref_devices,
        value_devices,
    )

    tags = (ref_symbol.device_tag, value_symbol.device_tag)
    blob_map = oneflow_api.deprecated.BnInOp2BlobObject()
    blob_map["ref"] = ref_blob_object
    blob_map["value"] = value_blob_object
    inferred = op_infer_util.Infer(op_conf, blob_map)
    cfg_op_attribute = oneflow_api.deprecated.MakeOpAttributeByString(
        str(inferred))

    ref_tag, value_tag = tags
    if ref_tag == value_tag:
        # Same device kind: no host/device transfer needed.
        builder.NoBoxingStatelessCall(cfg_op_attribute, ref_parallel_conf,
                                      blob_map)
        return
    if tags == ("cpu", "gpu"):
        # Device-to-host: the call is issued on the value (gpu) placement.
        builder.NoBoxingCudaD2HStatelessCall(
            cfg_op_attribute,
            value_symbol.parallel_conf,
            blob_map,
            TryReplaceDeviceTag,
        )
        return
    if tags == ("gpu", "cpu"):
        # Host-to-device: issued on the ref (gpu) placement.
        with _CudaHostPinBlob(builder, value_blob_object):
            builder.NoBoxingCudaH2DStatelessCall(cfg_op_attribute,
                                                 ref_parallel_conf, blob_map)
        return
    raise NotImplementedError(
        "invalid device found. ref_device_tag: %s, value_device_tag: %s" %
        (ref_tag, value_tag))