Example 1
def less_equal(input1, input2):
    """
    Check whether input1 is less than or equal to input2.

    Args:
        input1 (tvm.tensor.Tensor): Tensor.
        input2 (tvm.tensor.Tensor): Tensor.

    Returns:
        tvm.tensor.Tensor. Element is True where input1 <= input2, False otherwise.
    """
    shape1 = [x.value for x in input1.shape]
    shape2 = [x.value for x in input2.shape]
    vc_util.check_shape(shape1)
    vc_util.check_shape(shape2)

    shape1, shape2, shape = produce_shapes(shape1, shape2)

    vc_util.elemwise_dtype_check(input1.dtype, input2.dtype)
    dtype = input1.dtype

    # get lessequal compute
    t_value = akg.tvm.compute(shape, lambda *indice: akg.tvm.const(1, dtype), "T")
    f_value = akg.tvm.compute(shape, lambda *indice: akg.tvm.const(0, dtype), "F")

    input1_bro = akg.topi.broadcast_to(input1, shape)
    input2_bro = akg.topi.broadcast_to(input2, shape)
    c_out = akg.tvm.compute(shape, lambda *indice: akg.tvm.expr.Select(input1_bro[indice] <= input2_bro[indice],
                                                                         t_value[indice], f_value[indice]), name="C")
    res = akg.tvm.compute(shape, lambda *indice: c_out(*indice).astype("bool"), name="res")

    return res
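A minimal NumPy sketch of the semantics the kernel above implements, assuming two numeric arrays that broadcast to a common shape:

import numpy as np

# Reference for less_equal: broadcast both inputs, then compare element-wise.
x = np.array([[1.0], [3.0]])       # shape (2, 1)
y = np.array([1.0, 2.0, 4.0])      # shape (3,)
expect = np.less_equal(x, y)       # shape (2, 3), dtype bool
# [[ True  True  True]
#  [False False  True]]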
Example 2
def bitwise_or(x1, x2, target=utils.CCE):
    """
    Computes the bitwise or of `x1` and `x2`.

    Args:
        x1 (tvm.tensor.Tensor): Tensor of type int16, uint16.
        x2 (tvm.tensor.Tensor): Tensor of type int16, uint16.

    Returns:
        tvm.tensor.Tensor, has the same type as x1.
    """
    # check shape
    utils.check_shape(x1)
    utils.check_shape(x2)
    _, _, output_shape = produce_shapes(get_shape(x1), get_shape(x2))

    # check input tensor data_type
    utils.ops_dtype_check(
        [x1.dtype, x2.dtype],
        [utils.DtypeForDavinci.INT16, utils.DtypeForDavinci.UINT16])
    dtype = x1.dtype
    if dtype != x2.dtype:
        raise RuntimeError("input type must be same, but got %s  vs %s", dtype,
                           x2.dtype)

    x1 = akg.topi.broadcast_to(x1, output_shape)
    x2 = akg.topi.broadcast_to(x2, output_shape)
    res = akg.tvm.compute(output_shape,
                          lambda *indice: x1(*indice) | x2(*indice))
    return res
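For reference, the same semantics in NumPy (np.bitwise_and covers Example 3 below in the same way), assuming int16 inputs that broadcast:

import numpy as np

a = np.array([5, 0, -1], dtype=np.int16)
b = np.array([3], dtype=np.int16)     # broadcasts against a
print(np.bitwise_or(a, b))            # [ 7  3 -1]
print(np.bitwise_and(a, b))           # [ 1  0  3]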
Example 3
def bitwise_and(x1, x2):
    """
    Computes the bitwise and of `x1` and `x2`.

    Args:
        x1 (tvm.tensor.Tensor): Tensor x1; only int16 and uint16 are supported.
        x2 (tvm.tensor.Tensor): Tensor x2; only int16 and uint16 are supported.

    Returns:
        tvm.tensor.Tensor, the result of the bitwise and.
    """
    _check_parameters(x1, x2)

    shape_x = get_shape(x1)
    shape_y = get_shape(x2)
    _, _, shape_max = produce_shapes(shape_x, shape_y)

    data_x = topi.broadcast_to(x1, shape_max)
    data_y = topi.broadcast_to(x2, shape_max)

    res = tvm.compute(data_x.shape,
                      lambda *i: data_x(*i) & data_y(*i),
                      name="and_res")

    return res
Example 4
def _div_ascend(data1, data2):
    """
    Calculates x/y, and returns an integer when inputs are all integers.

    When both arguments are integers, use integer division (also known as "floor division").
    When arguments are float numbers, use normal floating-point division.

    Note:
        div supports broadcasting.

    Args:
        data1 (tvm.tensor.Tensor): Tensor of type float16, float32, int32, int8 and uint8.
        data2 (tvm.tensor.Tensor): Tensor of type float16, float32, int32, int8 and uint8.

    Returns:
        tvm.tensor.Tensor, has the same type as data1 and data2.
    """

    utils.ops_dtype_check([data1.dtype, data2.dtype],
                          utils.DtypeForDavinci.ALL_TYPES)
    utils.elemwise_dtype_check(data1.dtype, data2.dtype)
    dtype = data1.dtype

    shape1 = [x.value for x in data1.shape]
    shape2 = [x.value for x in data2.shape]
    utils.check_shape(shape1)
    utils.check_shape(shape2)

    utils.auto_broadcast_check(shape1, shape2)
    n_shape1, n_shape2, out_shape = produce_shapes(shape1, shape2)
    if n_shape1 != out_shape:
        input1_cast = akg.topi.broadcast_to(data1, out_shape)
    else:
        input1_cast = data1
    if n_shape2 != out_shape:
        input2_cast = akg.topi.broadcast_to(data2, out_shape)
    else:
        input2_cast = data2

    if dtype in ("int32", "int8", "uint8"):
        input1p = Cast(input1_cast, "float16", utils.CCE)
        input2p = Cast(input2_cast, "float16", utils.CCE)
    else:
        input1p = input1_cast
        input2p = input2_cast

    if product_is_mini():
        input2p_rec = reciprocal(input2p, target=utils.CCE)
        res = akg.topi.multiply(input1p, input2p_rec)
    else:
        res = akg.topi.divide(input1p, input2p)

    if dtype in ("int8", "uint8"):
        res = floor(res, utils.CCE)
        res = Cast(res, "float16", utils.CCE)
    if dtype in ("int32", "int8", "uint8"):
        res = Cast(res, dtype, utils.CCE)

    return res
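A sketch of the behavior the docstring describes, floor division for integer inputs and ordinary floating-point division otherwise; this mirrors the intent, not necessarily the kernel's exact rounding on every dtype:

import numpy as np

def div_ref(a, b):
    # Integer inputs: floor division; float inputs: true division.
    if np.issubdtype(np.asarray(a).dtype, np.integer):
        return np.floor_divide(a, b)
    return np.true_divide(a, b)

print(div_ref(np.array([7, -7], dtype=np.int32), np.array([2], dtype=np.int32)))  # [ 3 -4]
print(div_ref(np.array([7.0, -7.0], dtype=np.float32), np.float32(2.0)))          # [ 3.5 -3.5]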
Example 5
def xdivy_grad(x1, x2, grad):
    """
    Returns gradient of xdivy(x1, x2) with respect to x1 and x2.

    Args:
        x1 (tvm.tensor.Tensor): Tensor of dtype "float16" or "float32".
        x2 (tvm.tensor.Tensor): Tensor of dtype "float16" or "float32".
        grad (tvm.tensor.Tensor): Gradient tensor of dtype "float16" or "float32".

    Returns:
        Two tvm.tensor.Tensor as gradients for x1 and x2.
    """
    shape_x1 = get_shape(x1)
    dtype_x1 = x1.dtype
    shape_x2 = get_shape(x2)
    dtype_x2 = x2.dtype
    shape_grad = get_shape(grad)
    dtype_grad = grad.dtype
    if dtype_x1 != dtype_x2 or dtype_x2 != dtype_grad or dtype_grad != dtype_x1:
        raise RuntimeError(
            "the type of x1, x2 and grad must be the same,"
            "while dtype_x1 = %s, dtype_x2 = %s, dtype_grad = %s" %
            (dtype_x1, dtype_x2, dtype_grad))

    vc_util.check_shape(shape_x1)
    vc_util.check_shape(shape_x2)
    vc_util.check_shape(shape_grad)

    vc_util.ops_dtype_check(dtype_x1, vc_util.DtypeForDavinci.ALL_FLOAT)
    shape_x1, shape_x2, shape_max_x1x2 = produce_shapes(shape_x1, shape_x2)
    if len(shape_max_x1x2) < len(shape_grad):
        raise RuntimeError(
            "the length of shape_grad cannot exceed the maximum length of "
            "x1 and x2, while shape_grad = %s, shape_max = %s" %
            (list(shape_grad), shape_max_x1x2))

    shape_grad, _, shape_max = produce_shapes(shape_grad, shape_max_x1x2)
    for (x, y) in zip(shape_max_x1x2, shape_grad):
        if x < y:
            raise RuntimeError(
                "Unsupported shape: shape_max = %s, shape_grad = %s" %
                (shape_max_x1x2, list(shape_grad)))

    rx, ry = broadcast_gradient_args(shape_x1, shape_x2)
    return xdivy_grad_compute([x1, x2, grad], shape_max, dtype_x1, rx, ry)
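Here rx and ry are the axes each gradient must be summed over to undo broadcasting (Example 12 below uses them exactly that way). A minimal sketch of that convention, assuming standard NumPy-style broadcast rules; the real helper is the module's broadcast_gradient_args:

# Hypothetical illustration only; not the library implementation.
def broadcast_grad_axes(shape1, shape2):
    out_rank = max(len(shape1), len(shape2))
    s1 = [1] * (out_rank - len(shape1)) + list(shape1)
    s2 = [1] * (out_rank - len(shape2)) + list(shape2)
    rx = [i for i, (a, b) in enumerate(zip(s1, s2)) if a == 1 and b > 1]
    ry = [i for i, (a, b) in enumerate(zip(s1, s2)) if b == 1 and a > 1]
    return rx, ry

# broadcast_grad_axes([1, 3], [2, 3]) -> ([0], []): dx1 is summed over axis 0.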
Example 6
def gen_data(shape1, shape2, dtype):
    x1 = random_gaussian(shape1, miu=1, sigma=0.3).astype(dtype)
    x2 = random_gaussian(shape2, miu=1, sigma=0.3).astype(dtype)

    _, _, out_shape = produce_shapes(shape1, shape2)
    expect = np.where(np.equal(x1, 0.), np.zeros_like(np.multiply(x1, x2)),
                      np.divide(x1, x2))
    output = np.full(out_shape, np.nan, dtype)
    return expect, (x1, x2), output
Example 7
def gen_expect(input1, input2):
    _, _, out_shape = produce_shapes(input1.shape, input2.shape)
    n_input1 = np.broadcast_to(input1, out_shape)
    n_input2 = np.broadcast_to(input2, out_shape)

    sign2 = np.sign(n_input2)
    n_input2 = np.multiply(np.add(np.abs(n_input2), 1), sign2)
    expect = np.divide(n_input1, n_input2)
    return expect
Example 8
def equal_count(x, y, target=utils.CCE):
    """
    Compute the number of equal elements of x and y.

    Args:
        x (tvm.tensor.Tensor): Tensor of type int32.
        y (tvm.tensor.Tensor): Tensor of type int32.

    Returns:
        tvm.tensor.Tensor, equal num, type is int32.

    Supported Platforms:
        'Ascend'
    """
    # check shapes
    shape1 = get_shape(x)
    shape2 = get_shape(y)
    shapes = [shape1, shape2]
    for shape_ in shapes:
        utils.check_shape(shape_)
    if len(shape1) != 1 or len(shape2) != 1:
        raise RuntimeError("Both inputs must be one-dimensional!")

    # check types
    dtype = x.dtype
    utils.ops_dtype_check([x.dtype, y.dtype], utils.DtypeForDavinci.INT32)

    # Due to instruction limitations, the int32 data needs to be converted to
    # float16 or float32.
    # Casting int32 to float16 may overflow, so the data is cast to float32
    # whenever possible.
    orig_dtype = dtype
    if product_is_mini():
        dtype = "float16"
    else:
        dtype = "float32"
    x = Cast(x, dtype, target)
    y = Cast(y, dtype, target)

    shape1, shape2, shape = produce_shapes(shape1, shape2)
    t = akg.tvm.compute(shape, lambda *indice: akg.tvm.const(1, dtype), "t")
    f = akg.tvm.compute(shape, lambda *indice: akg.tvm.const(0, dtype), "f")
    x = akg.topi.broadcast_to(x, shape)
    y = akg.topi.broadcast_to(y, shape)
    z = akg.tvm.compute(shape,
                        lambda *indice: akg.tvm.expr.Select(
                            x[indice] == y[indice], t[indice], f[indice]),
                        name="z")
    res = sum(z, target=target)
    if res.dtype != orig_dtype:
        res = Cast(res, orig_dtype, target)
    return res
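The whole kernel amounts to counting matching positions; a NumPy reference, assuming two 1-D int32 vectors of equal length:

import numpy as np

x = np.array([1, 2, 3, 4], dtype=np.int32)
y = np.array([1, 0, 3, 4], dtype=np.int32)
count = np.sum(x == y).astype(np.int32)   # 3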
Example 9
def xdivy_compute(input_x, input_y):
    """xdivy compute"""
    _, _, shape_res = produce_shapes(get_shape(input_x), get_shape(input_y))
    vc_util.check_shape(shape_res)

    dtype = input_x.dtype

    broadcast_x = akg.lang.cce.broadcast(input_x, shape_res)
    broadcast_y = akg.lang.cce.broadcast(input_y, shape_res)
    broadcast_one = akg.lang.cce.broadcast(tvm.const(SCALAR_ONE, dtype),
                                           shape_res, dtype)

    abs_x = akg.lang.cce.vabs(broadcast_x)
    abs_y = akg.lang.cce.vabs(broadcast_y)
    add_x_y = akg.lang.cce.vadd(abs_x, abs_y)

    if dtype == "float32":
        data_min = akg.lang.cce.broadcast(
            tvm.const(MININUM_NUM_FLOAT, dtype=dtype), shape_res, dtype)
    elif dtype == "float16":
        data_min = akg.lang.cce.broadcast(
            tvm.const(MININUM_NUM_HALF, dtype=dtype), shape_res, dtype)

    zero_x_y = akg.lang.cce.vmin(add_x_y, data_min)

    if dtype == "float32":
        data_mul1 = akg.lang.cce.vmuls(
            zero_x_y, tvm.const(MAX_ONE_CONST_FLOAT, dtype=dtype))
        data_mul2 = akg.lang.cce.vmuls(
            data_mul1, tvm.const(MAX_ONE_CONST_FLOAT, dtype=dtype))
        mul_data = akg.lang.cce.vmuls(
            data_mul2, tvm.const(MAX_TWO_CONST_FLOAT, dtype=dtype))
    elif dtype == "float16":
        data_mul1 = akg.lang.cce.vmuls(zero_x_y,
                                       tvm.const(MAX_CONST_HALF, dtype=dtype))
        mul_data = akg.lang.cce.vmuls(data_mul1,
                                      tvm.const(MAX_CONST_HALF, dtype=dtype))

    sub_x_y_zero = akg.lang.cce.vsub(mul_data, broadcast_one)
    abs_x_y_zero = akg.lang.cce.vabs(sub_x_y_zero)
    input_y_revised = akg.lang.cce.vadd(broadcast_y, abs_x_y_zero)

    if dtype == "float16":
        broadcast_x = akg.lang.cce.cast_to(broadcast_x, "float32")
        input_y_revised = akg.lang.cce.cast_to(input_y_revised, "float32")

    res = div(broadcast_x, input_y_revised)

    if dtype == "float16":
        res = akg.lang.cce.cast_to(res, dtype)

    return res
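The kernel avoids an explicit Select by revising the divisor so that it becomes nonzero exactly where both x and y are zero, the only case where x / y does not already give the desired 0. A rough NumPy sketch of that trick, with placeholder constants standing in for the module's MININUM_NUM_* and MAX_*_CONST values (chosen here only so that TINY * BIG == 1.0):

import numpy as np

TINY, BIG = 2.0**-126, 2.0**126   # placeholders, not the module's constants

def xdivy_ref(x, y):
    # 1.0 where |x| + |y| == 0 (i.e. x == 0 and y == 0), 0.0 elsewhere
    is_both_zero = np.abs(np.minimum(np.abs(x) + np.abs(y), TINY) * BIG - 1.0)
    y_revised = y + is_both_zero      # nonzero wherever 0 / 0 would have occurred
    return x / y_revised              # 0 / 1 == 0 there; plain x / y elsewhere

print(xdivy_ref(np.array([0.0, 0.0, 2.0]), np.array([0.0, 3.0, 4.0])))  # [0.  0.  0.5]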
Example 10
def _truncatemod_compute_mini(x, y):
    """
    Computes truncatemod value of x and y on the mini device.

    Args:
        x (tvm.tensor.Tensor): Tensor of type float16.
        y (tvm.tensor.Tensor): Tensor with the same type as x.

    Returns:
        tvm.tensor.Tensor of the same type as x.
    """
    def truncatemod_positive(x_abs, y_abs):
        """Computes truncatemod value for positive number"""
        x_abs_fp32 = akg.topi.cast(x_abs, "float32")
        y_abs_fp32 = akg.topi.cast(y_abs, "float32")

        def truncatemod_func(a, b):
            """function for truncatemod formula"""
            # For positive numbers, floor and trunc are equivalent
            return akg.topi.subtract(
                a,
                akg.topi.multiply(
                    b,
                    Cast(floor(Divide(a, b, utils.CCE)),
                         b.dtype,
                         target=utils.CCE)))

        mod_value = truncatemod_func(x_abs_fp32, y_abs_fp32)

        # Because there are precision errors in division on mini, etc.,
        # the calculation results need to be corrected
        mod_value = truncatemod_func(mod_value, y_abs_fp32)
        mod_value = akg.topi.cast(mod_value, "float16")
        mod_value = akg.tvm.compute(
            mod_value.shape,
            lambda *indice: akg.tvm.expr.Select(
                mod_value(*indice) >= y_abs(*indice),
                mod_value(*indice) - y_abs(*indice), mod_value(*indice)),
            name="mod_value")
        return mod_value

    _, _, out_shape = produce_shapes(utils.get_shape(x), utils.get_shape(y))
    x = akg.topi.broadcast_to(x, out_shape)
    y = akg.topi.broadcast_to(y, out_shape)

    # Scenarios for correcting calculation results are complex,
    # using absolute values can simplify the scenario:
    # truncatemod(x,y) = Sign(x) * truncatemod(|x|, |y|)
    mod_abs = truncatemod_positive(akg.topi.abs(x), akg.topi.abs(y))
    mod = akg.topi.multiply(akg.topi.sign(x), mod_abs)
    return mod
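For reference, truncated modulo is x - y * trunc(x / y) (np.fmod in NumPy); its sign follows x, which is why the kernel can work on absolute values and multiply the result by sign(x):

import numpy as np

x = np.array([ 7.0, -7.0,  7.0, -7.0], dtype=np.float16)
y = np.array([ 3.0,  3.0, -3.0, -3.0], dtype=np.float16)
trunc_mod = x - y * np.trunc(x / y)      # [ 1. -1.  1. -1.]
assert np.allclose(trunc_mod, np.fmod(x, y))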
Example 11
def xlogy_grad_run(shape1, shape2, dtype, attrs):
    _, _, grad_shape = produce_shapes(shape1, shape2)
    mod = utils.op_build_test(xlogy_grad.xlogy_grad,
                              [shape1, shape2, grad_shape],
                              [dtype, dtype, dtype],
                              kernel_name="xlogy_grad", attrs=attrs)
    expects, inputs, outputs = gen_data(shape1, shape2, dtype)
    reses = utils.mod_launch(
        mod, (*inputs, *outputs), expect=expects,
        outputs=(-2, -1))

    rtol, atol = get_rtol_atol("xlogy_grad", dtype)
    TestCase_Results = list(map(lambda x, y: compare_tensor(
        x, y, rtol=rtol, atol=atol, equal_nan=True), reses, expects))

    return inputs, reses, expects, all(TestCase_Results)
Example 12
def gen_data(shape1, shape2, dtype):
    x1 = random_gaussian(shape1, miu=1, sigma=0.3).astype(dtype)
    x2 = random_gaussian(shape2, miu=1, sigma=0.3).astype(dtype)
    shape1, shape2, fout_shape = produce_shapes(shape1, shape2)
    dy = random_gaussian(fout_shape, miu=1, sigma=0.3).astype(dtype)
    rx, ry = xdivy_grad.broadcast_gradient_args(shape1, shape2)
    dx1_bc = np.where(np.equal(x1, 0.), np.zeros_like(np.multiply(x1, x2)),
                      np.divide(1., x2)) * dy
    dx2_bc = np.where(np.equal(x1, 0.), np.zeros_like(np.multiply(x1, x2)),
                      -1. * np.divide(x1, np.square(x2))) * dy
    dx1 = (np.sum(dx1_bc.astype("float64"), axis=tuple(rx), keepdims=True)
           if len(rx) > 0 else dx1_bc).astype(dtype)
    dx2 = (np.sum(dx2_bc.astype("float64"), axis=tuple(ry), keepdims=True)
           if len(ry) > 0 else dx2_bc).astype(dtype)
    output1 = np.full(shape1, np.nan, dtype)
    output2 = np.full(shape2, np.nan, dtype)
    return (dx1, dx2), (x1, x2, dy), (output1, output2)
Example 13
def softplus_grad_compute(input_gradients, input_features):
    """compute for calculations of softplus gradients"""
    shape_dy = get_shape(input_gradients)
    shape_x = get_shape(input_features)
    dtype = input_gradients.dtype

    if list(shape_dy) != list(shape_x):
        shape_dy, shape_x, shape_max = produce_shapes(shape_dy, shape_x)
        input_gradients = akg.lang.cce.broadcast(input_gradients, shape_max,
                                                 dtype)
        input_features = akg.lang.cce.broadcast(input_features, shape_max,
                                                dtype)
    else:
        shape_max = shape_dy

    if dtype != "float32":
        input_gradients = akg.lang.cce.cast_to(input_gradients, "float32")
        input_features = akg.lang.cce.cast_to(
            input_features,
            "float16" if utils.product_is_mini() else "float32")

    data_exp_tmp = akg.lang.cce.vexp(input_features)
    data_add_tmp = akg.lang.cce.vadds(data_exp_tmp, SCALAR_ONE)
    data_div_tmp = div(data_exp_tmp, data_add_tmp)
    res_tmp = akg.lang.cce.vmul(input_gradients, data_div_tmp)

    if dtype == "float16":
        res = akg.lang.cce.cast_to(res_tmp, "float16")
    elif dtype == "int32" or dtype == "int8" or dtype == "uint8":
        data_zero = akg.lang.cce.broadcast(tvm.const(0, "float16"), shape_max,
                                           "float16")
        res_min = akg.lang.cce.vmin(res_tmp, data_zero)
        res_max = akg.lang.cce.vmax(res_tmp, data_zero)
        res_max_int = akg.lang.cce.floor(res_max)
        res_min_int = akg.lang.cce.ceil(res_min)
        res = akg.lang.cce.vadd(res_max_int, res_min_int)
    else:
        res = res_tmp

    if dtype == "int8":
        res = akg.lang.cce.cast_to(res, "int8")
    elif dtype == "uint8":
        res = akg.lang.cce.cast_to(res, "uint8")

    return res
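Since softplus(x) = log(1 + exp(x)), its derivative is exp(x) / (1 + exp(x)) = sigmoid(x), and the kernel multiplies the incoming gradient by that factor; the int32/int8/uint8 branch then rounds the float result toward zero (floor of the positive part plus ceil of the negative part) before casting back. A minimal NumPy sketch of the float path:

import numpy as np

def softplus_grad_ref(dy, x):
    # d/dx log(1 + exp(x)) = exp(x) / (1 + exp(x))
    e = np.exp(x.astype(np.float32))
    return (dy.astype(np.float32) * e / (1.0 + e)).astype(dy.dtype)

dy = np.ones(3, dtype=np.float16)
x = np.array([-2.0, 0.0, 2.0], dtype=np.float16)
print(softplus_grad_ref(dy, x))   # approx. [0.119 0.5 0.881]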
Example 14
def make_input_and_value(data1, data2):
    shape1 = [x.value for x in data1.shape]
    shape2 = [x.value for x in data2.shape]
    utils.check_shape(shape1)
    utils.check_shape(shape2)

    shape1, shape2, shape = produce_shapes(shape1, shape2)

    utils.elemwise_dtype_check(data1.dtype, data2.dtype)
    dtype = data1.dtype

    t_value = akg.tvm.compute(shape, lambda *indice: akg.tvm.const(1, dtype), "T")
    f_value = akg.tvm.compute(shape, lambda *indice: akg.tvm.const(0, dtype), "F")
    
    input1_bro = akg.topi.broadcast_to(data1, shape)
    input2_bro = akg.topi.broadcast_to(data2, shape)
    res = (t_value, f_value, input1_bro, input2_bro, shape)
    return res
Example 15
def RealDiv(input1, input2, target=utils.CCE):
    """
    Returns input1 / input2 element-wise for real types.

    Note:
        RealDiv supports broadcasting.

    Args:
        input1 (tvm.tensor.Tensor): Tensor of type float16, float32.
        input2 (tvm.tensor.Tensor): Tensor of type float16, float32.

    Returns:
        tvm.tensor.Tensor, has the same type as input1 and the broadcast shape.
    
    Supported Platforms:
        'Ascend'
    """
    utils.ops_dtype_check([input1.dtype, input2.dtype],
                          utils.DtypeForDavinci.ALL_FLOAT)
    utils.elemwise_dtype_check(input1.dtype, input2.dtype)

    shape1 = [x.value for x in input1.shape]
    shape2 = [x.value for x in input2.shape]
    utils.check_shape(shape1)
    utils.check_shape(shape2)

    utils.auto_broadcast_check(shape1, shape2)
    n_shape1, n_shape2, out_shape = produce_shapes(shape1, shape2)

    if n_shape1 != out_shape:
        input1_cast = akg.topi.broadcast_to(input1, out_shape)
    else:
        input1_cast = input1
    if n_shape2 != out_shape:
        input2_cast = akg.topi.broadcast_to(input2, out_shape)
    else:
        input2_cast = input2

    res = akg.topi.divide(input1_cast, input2_cast)
    return res
Example 16
def _equal(input1, input2):

    shape1 = [x.value for x in input1.shape]
    shape2 = [x.value for x in input2.shape]
    utils.check_shape(shape1)
    utils.check_shape(shape2)

    shape1, shape2, shape = produce_shapes(shape1, shape2)

    utils.elemwise_dtype_check(input1.dtype, input2.dtype)
    dtype = input1.dtype

    # get equal compute
    t_value = akg.tvm.compute(shape, lambda *indice: akg.tvm.const(1, dtype), "T")
    f_value = akg.tvm.compute(shape, lambda *indice: akg.tvm.const(0, dtype), "F")

    input1_bro = akg.topi.broadcast_to(input1, shape)
    input2_bro = akg.topi.broadcast_to(input2, shape)
    c_out = akg.tvm.compute(shape, lambda *indice: akg.tvm.expr.Select(input1_bro[indice] == input2_bro[indice],
                                                                         t_value[indice], f_value[indice]), name="C")
    res = akg.tvm.compute(shape, lambda *indice: c_out(*indice).astype("bool"), name="res")

    return res
Example 17
File: add.py  Project: zhuyawen/akg
def add(first_input, second_input, scale=1.0, polyhedral=True, attrs=None):
    """
    Computes first_input + second_input * scale elementwise.

    Args:
        first_input (tvm.tensor.Tensor): Tensor of type float16, float32, int32.
        second_input (tvm.tensor.Tensor): Tensor with same type as first_input.
                                      Broadcast will happen if shapes of input tensors are different.
        scale (float): scale factor applied on second_input, default value is 1.0.
        polyhedral (bool): If True, use auto-schedule, else use manual-schedule, default value is True.
        attrs (dict): Specifies parameters used in manual-schedule.

    Returns:
        tvm.tensor.Tensor of the same type as the input tensors, with the broadcast shape of the inputs.
    """
    vc_util.check_shape(first_input.shape)
    vc_util.check_shape(second_input.shape)
    attr_map = {}

    first_input_shape = get_shape(first_input)
    second_input_shape = get_shape(second_input)

    if shape_is_dynamic([first_input, second_input]):
        if first_input_shape != second_input_shape:
            raise RuntimeError(
                "Input tensors have different shapes, broadcast is not supported for dynamic."
            )
        first_broadcast = first_input
        second_broadcast = second_input
    else:
        if first_input_shape != second_input_shape:
            _, _, out_shape = produce_shapes(first_input_shape,
                                             second_input_shape)
        else:
            out_shape = first_input_shape
        first_broadcast = akg.topi.broadcast_to(first_input, out_shape)
        second_broadcast = akg.topi.broadcast_to(second_input, out_shape)

    first_input_type = first_input.dtype
    second_input_type = second_input.dtype
    if first_input_type != second_input_type:
        raise TypeError("Input tensors have different data types.")
    vc_util.ops_dtype_check(first_input_type,
                            vc_util.DtypeForDavinci.ALL_TYPES)

    temp = vmuls(second_broadcast, scale)
    res = vadd(first_broadcast, temp)
    res_cast = res.astype(first_input_type)
    if polyhedral:
        return res_cast, attr_map

    def comp_func(s):
        first_ub = s.cache_read(first_input, "local.UB", [first_broadcast])
        second_ub = s.cache_read(second_input, "local.UB", [second_broadcast])
        res_cast_ub = s.cache_write(res_cast, "local.UB")

        s[first_broadcast].set_scope("local.UB")
        s[second_broadcast].set_scope("local.UB")
        s[temp].set_scope("local.UB")
        s[res].set_scope("local.UB")

        split_axis = []
        for i in range(len(attrs["tile"])):
            outer, inner = s[res_cast].split(res_cast.op.axis[i],
                                             attrs["tile"][i])
            axis_dict = {"outer": outer, "inner": inner}
            split_axis.append(axis_dict)

        s[first_ub].compute_at(s[res], res.op.axis[0])
        s[second_ub].compute_at(s[res], res.op.axis[0])

        s[first_broadcast].compute_at(s[res], res.op.axis[0])
        s[second_broadcast].compute_at(s[res], res.op.axis[0])

        s[temp].compute_at(s[res], res.op.axis[0])
        s[res].compute_at(s[res_cast_ub], res_cast_ub.op.axis[0])

        s[res_cast_ub].compute_at(s[res_cast], split_axis[-1]['outer'])

        # no scaling needed
        if scale == 1:
            s[temp].compute_inline()

        # no broadcast needed
        if first_input_shape == second_input_shape:
            s[first_broadcast].compute_inline()
            s[second_broadcast].compute_inline()

    return res_cast, comp_func, attr_map
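Numerically the op is just first_input + second_input * scale over the broadcast shape, kept in first_input's dtype; a minimal NumPy reference:

import numpy as np

first = np.array([[1.0, 2.0, 3.0]], dtype=np.float32)    # shape (1, 3)
second = np.array([[10.0], [20.0]], dtype=np.float32)    # shape (2, 1)
expect = (first + second * 0.5).astype(first.dtype)      # shape (2, 3)

Note that the kernel returns (res_cast, attr_map) when polyhedral=True and (res_cast, comp_func, attr_map) otherwise, so callers must unpack accordingly.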