Example 1
def reduce_any_d(x, axis=None, keepdims=False):
    """
    Reduce a tensor along the given axis with an "any" (logical or) reduction, implemented via max.

    Args:

        x (tvm.tensor.Tensor): The input tensor to reduce. Should be of type int8.
        axis (Union[list, tuple, int, None]): The dimensions to reduce. If None, all dimensions will be reduced.
                                              Each dim must be in the range [-len(data.shape), len(data.shape) - 1].
        keepdims (Union[bool, None]): If True, retains reduced dimensions with length 1, defaults to False.

    Returns:
        tvm.tensor.Tensor of same type as input tensor x.
    """
    # check type
    vc_util.ops_dtype_check(x.dtype, vc_util.DtypeForDavinci.INT8)
    vc_util.check_shape(x.shape)
    # check axis
    vc_util.reduce_axis_check(x.shape, axis)
    refined_axis = refine_reduce_axis(x, axis)
    if len(set(refined_axis)) == len(x.shape) and not keepdims:
        keepdims = True
    res = _reduce_any_d_compute(x, refined_axis, keepdims)
    if len(set(refined_axis)) == len(x.shape):
        res = topi.reshape(res, (1, ))
    return res
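A note on the trick above: for int8 inputs that encode boolean flags as 0/1 (as reduce_any_d expects), a max reduction is exactly a logical "any". A minimal NumPy sketch of that equivalence, purely illustrative and not part of akg:

import numpy as np

# Illustrative only: with 0/1 int8 flags, reducing with max is the
# same as a logical "any" along the axis.
x = np.array([[0, 1, 0], [0, 0, 0]], dtype=np.int8)
any_via_max = x.max(axis=1)                                  # [1, 0]
any_reference = x.astype(bool).any(axis=1).astype(np.int8)   # [1, 0]
assert np.array_equal(any_via_max, any_reference)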
Example 2
def reduce_prod(data, axis=None, keepdims=False):
    """
    Computes the product of elements along a specific axis.

    Args:
        data (tvm.tensor.Tensor): indicating the input tensor.
        axis (Union[list, tuple, int, None]): indicating the dimensions to reduce at. If it's None, all dimensions
                                               will be reduced.
        keepdims (Union[bool, None]): If true, keep the dimensions with length 1.

    Returns:
        Tensor, the product of elements of the input tensor.
    """
    shape = [x.value for x in data.shape]
    ops_dtype_check(data.dtype, [
        DtypeForDavinci.ALL_FLOAT, DtypeForDavinci.INT8, DtypeForDavinci.UINT8
    ])

    if axis is None and keepdims is False:
        raise ValueError("keepdims must be True when axis is None!")

    axis_new = ft_util.refine_reduce_axis(data, axis)

    check_shape(shape)
    dtype = data.dtype
    if dtype in ["int8", "uint8"]:
        data = akg.topi.cast(data, "float16")

    vlog_t = akg_log(data)
    res = akg.topi.sum(vlog_t, axis=axis_new, keepdims=keepdims)
    res = akg_exp(res)

    if dtype in ["int8", "uint8"]:
        res = akg.topi.cast(res, dtype)
    return res
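Here the product is obtained indirectly as exp(sum(log(x))), which is only valid for strictly positive inputs. A small NumPy sketch of the identity (illustrative only, not akg code):

import numpy as np

# prod(x) == exp(sum(log(x))) holds for strictly positive x.
x = np.array([[1.0, 2.0, 3.0], [0.5, 4.0, 2.0]], dtype=np.float32)
direct = x.prod(axis=1)                   # [6.0, 4.0]
via_log = np.exp(np.log(x).sum(axis=1))   # same values up to rounding
assert np.allclose(direct, via_log)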
Example 3
File: sum.py Project: zhuyawen/akg
def sum_value(inputs, axis=None, keepdims=False):
    """
    Computes the sum value of a tensor along the given axes.

    Args:
        inputs (tvm.tensor.Tensor): Tensor of type float16, float32.
        axis (Union[list, tuple, int, None]): Specifies which axis or axes to reduce.
        keepdims (bool): If true, the dimension specified by axis will be one.

    Returns:
        tvm.tensor.Tensor with same type as input tensor.
    """

    # Check types
    dtype = inputs.dtype
    vc_util.ops_dtype_check(dtype, vc_util.DtypeForDavinci.ALL_FLOAT)
    axis = ft_util.refine_reduce_axis(inputs, axis)
    vc_util.check_shape(inputs.shape)

    if not axis:
        output = akg.topi.identity(inputs)
    else:
        output = akg.topi.sum(inputs, axis=axis, keepdims=keepdims)
    attr_map = get_attrs()
    return output, attr_map
Example 4
def mean_v2(data, axis=None, keepdims=False, target=utils.CCE):
    """
    Simple implementation of mean.

    Supported Platforms:
        'Ascend'
    """
    # Check types
    utils.ops_dtype_check(data.dtype, utils.DtypeForDavinci.ALL_FLOAT)

    # Check shape
    shape = [x.value for x in data.shape]
    utils.reduce_axis_check(shape, axis)
    axis = ft_util.refine_reduce_axis(data, axis)

    dtype = data.dtype
    count = 1
    for i in axis:
        count *= shape[i]

    count_rec = 1 / count
    output = sum_v2(data, axis, keepdims, target=target)
    res = output * akg.tvm.const(count_rec, dtype)
    attrs = get_attrs(data)
    if shape_is_dynamic(data):
        attrs["custom_tiling"] = mean_dynamic_tiling_strategy(data, axis)
    return res, attrs
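mean_v2 folds the division into a single multiply by the reciprocal element count (count_rec). The same arithmetic in a short NumPy sketch (illustrative only):

import numpy as np

x = np.random.rand(4, 8).astype(np.float32)
axis = (1,)
count = 1
for i in axis:
    count *= x.shape[i]
# Multiplying the axis-sum by 1/count is equivalent to dividing by count,
# i.e. it reproduces the mean along the reduced axis.
mean_via_reciprocal = x.sum(axis=axis) * np.float32(1 / count)
assert np.allclose(mean_via_reciprocal, x.mean(axis=axis), rtol=1e-6)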
Example 5
def sum_v2(inputs, axis=None, keepdims=True, target=utils.CCE):
    """
    another implementation of sum with topi api.

    Supported Platforms:
        'Ascend'
    """
    if target != utils.CCE:
        raise RuntimeError('operator not supported on %s' %
                           utils.get_backend(target))

    dtype = inputs.dtype
    utils.ops_dtype_check(dtype, utils.DtypeForDavinci.ALL_FLOAT)
    axis = ft_util.refine_reduce_axis(inputs, axis)
    utils.check_shape(inputs.shape)
    if not axis:
        output = akg.topi.identity(inputs)
    else:
        if dtype == "float16":
            step_sum = Cast(inputs, "float32", target)
        else:
            step_sum = inputs

        step_sum = akg.topi.sum(step_sum, axis=axis, keepdims=keepdims)

        if dtype == "float16":
            output = Cast(step_sum, "float16", target)
        else:
            output = step_sum
    return output
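sum_v2 promotes float16 inputs to float32 before the reduction and casts the result back, which keeps the accumulation from losing precision. A small NumPy sketch of why that matters (illustrative only, not akg code):

import numpy as np

x = np.full(10000, 0.1, dtype=np.float16)

acc = np.float16(0)
for v in x:                      # accumulate directly in float16
    acc = np.float16(acc + v)

promoted = np.float16(x.astype(np.float32).sum())  # accumulate in float32

# The float16-only accumulation stalls far below the true sum (~1000),
# while the float32 accumulation stays close to it.
print(acc, promoted)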
Example 6
def reduce_sum(inputs, axis=None, keepdims=False, target=utils.CCE):
    """
    Compute the sum of elements across dimensions of a tensor.

    Args:
        inputs (tvm.tensor.Tensor): Tensor.
        axis (Union[list, tuple, int, None]): If the list or tuple is empty, the axis is treated as None.
        keepdims (bool): If keepdims is True, the result has the same rank (shape length) as the input.

    Returns:
        tvm.tensor.Tensor, has the same type as the input. If keepdims is True, all reduced dimensions are retained
        with length 1; otherwise the reduced axes are eliminated.

    Supported Platforms:
        'Ascend', 'GPU', 'CPU'
    """
    utils.check_supported_target(target)
    if target == utils.CCE:
        return ascend_sum(inputs, axis, keepdims)
    axis = refine_reduce_axis(inputs, axis)
    utils.check_shape(inputs.shape)

    in_dtype = inputs.dtype
    if in_dtype == 'float16':
        inputs = akg.topi.cast(inputs, 'float32')

    output = akg.topi.sum(inputs, axis=axis, keepdims=keepdims)

    if in_dtype == 'float16':
        output = akg.topi.cast(output, 'float16')

    return output
Example 7
def reduce_sum(inputs, axis=None, keepdims=False):
    """
    Compute the sum of elements across dimensions of a tensor.

    Args:
        inputs (tvm.tensor.Tensor): Tensor.
        axis (Union[list, tuple, int, None]): If the list or tuple is empty, the axis is treated as None.
        keepdims (bool): If keepdims is True, the result has the same rank (shape length) as the input.

    Returns:
        tvm.tensor.Tensor, has the same type as the input. If keepdims is True, all reduced dimensions are retained
        with length 1; otherwise the reduced axes are eliminated.
    """
    axis = ft_util.refine_reduce_axis(inputs, axis)
    vc_util.check_shape(inputs.shape)

    in_dtype = inputs.dtype
    if in_dtype == 'float16':
        inputs = akg.topi.cast(inputs, 'float32')

    output = akg.topi.sum(inputs, axis=axis, keepdims=keepdims)

    if in_dtype == 'float16':
        output = akg.topi.cast(output, 'float16')

    return output
Example 8
def sum(inputs, axis=None, keepdims=False, target=utils.CCE):
    """
    Compute the sum of elements across dimensions of a tensor.

    Args:
        inputs (tvm.tensor.Tensor): Tensor.
        axis (Union[list, tuple, int, None]): If the list or tuple is empty, the axis is treated as None.
        keepdims (bool): If keepdims is True, the result has the same rank (shape length) as the input.

    Returns:
        tvm.tensor.Tensor, has the same type as the input. If keepdims is True, all reduced dimensions are retained
        with length 1; otherwise the reduced axes are eliminated.

    Supported Platforms:
        'Ascend', 'GPU', 'CPU'
    """
    # Check types
    if target == utils.CCE:
        dtype = inputs.dtype
        utils.ops_dtype_check(dtype, utils.DtypeForDavinci.ALL_FLOAT)
    axis = ft_util.refine_reduce_axis(inputs, axis)
    utils.check_shape(inputs.shape)

    if not axis:
        output = akg.topi.identity(inputs)
    else:
        output = akg.topi.sum(inputs, axis=axis, keepdims=keepdims)
    return output
Example 9
def _reduce_min_max_ascend(data, axis=None, keepdims=False, method="min"):
    """
    Computes the maximum or minimum of elements over a given axis or a list of axes of a tensor.

    Args:
        data (tvm.tensor.Tensor): The input tensor to reduce. Should be of type float16, float32, int8, uint8, int32.
        axis (Union[list, tuple, int, None]): The dimensions to reduce.
                                      If None, all dimensions will be reduced.
                                      If int or list, must be in the range [-len(data.shape), len(data.shape) - 1].
        keepdims (bool): If True, retains reduced dimensions with length 1, default value is False.
        method (str): Specifies to compute maximum or minimum of input tensor, default value is min.

    Returns:
        tvm.tensor.Tensor of same type as input tensor data.
    """
    # check shape
    utils.check_shape(data.shape)

    # check type
    dtype = data.dtype
    utils.ops_dtype_check(dtype, utils.DtypeForDavinci.ALL_TYPES)

    # check axis
    shape_len = len(data.shape)
    if axis is None:
        axis = range(shape_len)
    if hasattr(axis, 'index'):
        axis = list(axis)
    if isinstance(axis, int):
        axis = [axis]
    utils.is_valid_reduce_axis(data, axis)
    refined_axis = refine_reduce_axis(data, axis)
    if len(set(refined_axis)) == len(data.shape) and not keepdims:
        raise ValueError(
            "When reducing on all axes of input, keepdims should be set to True."
        )
    # check method
    method_list = ["min", "max"]
    if method not in method_list:
        raise ValueError("supported method %s while given method is %s" %
                         (",".join(method_list), method))

    # In the emit_insn pass, for vmin and vmax, reduce_last_axis only support float16.
    if dtype != "float16":
        data = Cast(data, "float16", target="cce")

    if method == "min":
        res = akg.topi.min(data, axis=axis, keepdims=keepdims)
    else:
        res = akg.topi.max(data, axis=axis, keepdims=keepdims)

    if res.dtype != dtype:
        res = Cast(res, dtype, target="cce")

    return res
Example 10
def Softmax(data, axis, target=utils.CCE):
    """
    Map all element of data to (0,1) and sum to 1.

    Args:
        data (tvm.tensor.Tensor): input.
        axis (int): along which normalization is applied.

    Return:
        tvm.tensor.Tensor, output.
    
    Supported Platforms:
        'Ascend'
    """
    utils.check_shape(data.shape)
    shape = data.shape

    utils.ops_dtype_check(data.dtype, utils.DtypeForDavinci.ALL_FLOAT)
    utils.reduce_axis_check(shape, axis)
    axis = ft_util.refine_reduce_axis(data, axis)

    if isinstance(axis, (list, tuple)):
        if len(axis) != 1:
            raise RuntimeError(
                "Reduce axis for softmax op must be 1-dimension, while current is %d-dimension"
                % (len(axis)))
        axis = axis[0]
    output = softmax_op(data, axis, shape)
    attr_map = {}
    if ds.shape_is_dynamic(data):
        # For shifted loops, should have:
        #     dynamic_shape_bound mod tile_size_prime == 2
        # This aims to ensure that the shift constant is a multiple of tile_size_prime.
        # So the generated IR will not have complicated head and tail for shifted blocks.
        attr_map = {
            "pragma_modshift": 1,
            "pragma_outerband_need_split": 1,
            "enable_post_poly_loop_partition": False,
            "pragma_disable_whole_component": False,
            "dynamic_shape":
                ds.set_dynamic_shape_limit_for_tensor(output, 2048, axis) +
                ds.set_poly_upper_bound_for_tensor(output, 2048, axis),
            "custom_tiling":
                ct.create_constraint_on_tensor(
                    tensor=output,
                    values=[1 for i, _ in enumerate(shape) if i != axis],
                    constraints=ct.TileConstraint.FACTOR,
                    tensor_pos=[i for i, _ in enumerate(shape) if i != axis])
        }
    return output, attr_map
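softmax_op itself is not shown in this snippet; as a reference for what the operator computes, here is a minimal NumPy softmax along a given axis (an illustrative sketch, not the akg implementation):

import numpy as np

def softmax_reference(x, axis):
    # Subtract the per-axis max for numerical stability, then normalize
    # the exponentials so that they sum to 1 along `axis`.
    shifted = x - x.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)

x = np.random.rand(2, 5).astype(np.float32)
y = softmax_reference(x, axis=1)
assert np.allclose(y.sum(axis=1), 1.0, rtol=1e-6)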
Example 11
def _reduce_max(inputs, axis=None, keepdims=False):
    axis = refine_reduce_axis(inputs, axis)
    utils.check_shape(inputs.shape)

    in_dtype = inputs.dtype
    if in_dtype == 'float16':
        inputs = akg.topi.cast(inputs, 'float32')

    output = akg.topi.max(inputs, axis=axis, keepdims=keepdims)

    if in_dtype == 'float16':
        output = akg.topi.cast(output, 'float16')

    return output
Example 12
def mean(data, axis=None, keepdims=False, target=utils.CCE):
    """
    Computes the mean of the values of a Tensor over the whole dataset.

    Note:
        If the tuple's elements are unsorted, this function will call preprocess_axis first to sort them.
        If the tuple is empty, this function will compute the sum of all elements.
        If the data type is float16 and the whole reduced dim is not less than 65536, this function will compute the
        mean by dividing by 65535 first to avoid the whole dim being too large.

    Args:
        data (tvm.tensor.Tensor): Tensor of type float16, float32.
        axis (Union[list, tuple, int, None]): If the tuple is empty, the axis is treated as None.
        keepdims (bool): If keepdims is True, the result has the same rank (shape length) as the input.

    Returns:
            tvm.tensor.Tensor, has the same type as data. If keepdims is True, all reduced dimensions are
            retained with length 1; otherwise the reduced axes are eliminated.

    Supported Platforms:
        'Ascend'
    """
    # Check types
    utils.ops_dtype_check(data.dtype, utils.DtypeForDavinci.ALL_FLOAT)

    # Check shape
    shape = ft_util.get_shape(data)
    utils.reduce_axis_check(shape, axis)
    axis = ft_util.refine_reduce_axis(data, axis)

    count = 1
    for i in axis:
        count *= shape[i]
    output = sum(data, axis, keepdims, target=target)

    if shape_is_dynamic(data):
        res = akg.tvm.compute(
            output.shape,
            lambda *i: akg.lang.ascend.divide_var(output(*i), count),
            name="res")
    else:
        res = akg.topi.divide(output, count)

    attrs = get_attrs(data)
    if shape_is_dynamic(data):
        attrs["custom_tiling"] = mean_dynamic_tiling_strategy(data, axis)
    return res, attrs
Example 13
def reduce_all(data, axis=None, keepdims=False):
    """
    Computes logical and of the input tensor.

    Args:
        data(tvm.tensor.Tensor): Tensor of type Boolean.
        axis(Union[None, int, list, tuple]): Specifies which axes to reduce, if None, all dimensions of
            input tensor data will be reduced and the shape of output tensor will be (1,).
        keepdims(Union[None, bool]): if true, keep the dimensions with length 1.

    Returns:
        tvm.tensor.Tensor of same type as input tensor data.
    """

    shape = [x.value for x in data.shape]

    vc_util.ops_dtype_check(data.dtype, vc_util.DtypeForDavinci.BOOL)
    vc_util.check_shape(shape)

    if axis is None and keepdims is False:
        raise ValueError("keepdims must be True when axis is None!")

    axis_new = ft_util.refine_reduce_axis(data, axis)

    xx1 = akg.tvm.compute(shape,
                          lambda *indice: data(*indice).astype("float16"),
                          name='xx1')
    xx = (-xx1 + dc.one_const("float16"))
    yy = akg.topi.sum(xx, axis=axis_new, keepdims=keepdims)

    o_shape = list(yy.shape)

    zz = akg.tvm.compute(o_shape,
                         lambda *indice: yy(*indice).astype("bool"),
                         name='zz')

    y1 = akg.tvm.compute(
        o_shape,
        lambda *indice: akg.tvm.expr.Select(zz(
            *indice), dc.zero_const("float16"), dc.one_const("float16")),
        name="y1")

    y = akg.tvm.compute(o_shape,
                        lambda *indice: y1(*indice).astype("bool"),
                        name='y')

    return y
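The kernel above relies on the identity all(x) == not any(not x): it sums the negated flags and maps a zero sum back to True. The same identity in a short NumPy sketch (illustrative only, not akg code):

import numpy as np

data = np.array([[True, True, False], [True, True, True]])

# Count how many elements are False along the axis; the "all" result is
# True exactly where that count is zero.
negated_sum = (1.0 - data.astype(np.float16)).sum(axis=1)
all_via_sum = negated_sum == 0            # [False, True]
assert np.array_equal(all_via_sum, data.all(axis=1))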
Example 14
def reduce_prod(data, axis=None, keepdims=False, target=utils.CCE):
    """
    Computes the product of elements along a specific axis.

    Args:
        data (tvm.tensor.Tensor): indicating the input tensor.
        axis (Union[list, tuple, int, None]): indicating the dimensions to reduce at. If it's None, all dimensions
                                               will be reduced.
        keepdims (Union[bool, None]): If true, keep the dimensions with length 1.

    Returns:
        Tensor, the product of elements of input tensor.

    Supported Platforms:
        'Ascend', 'GPU'
    """
    utils.check_supported_target(target)
    shape = [x.value for x in data.shape]
    utils.ops_dtype_check(data.dtype, [utils.DtypeForDavinci.ALL_FLOAT,
        utils.DtypeForDavinci.INT8, utils.DtypeForDavinci.UINT8])

    if axis is None and keepdims is False:
        raise ValueError("keepdims must be True when axis is None!")

    axis_new = refine_reduce_axis(data, axis)

    if target == utils.CUDA:
        return akg.topi.prod(data, axis=axis, keepdims=keepdims)

    utils.check_shape(shape)
    dtype = data.dtype
    if dtype in ["int8", "uint8"]:
        data = akg.topi.cast(data, "float16")

    vlog_t = log(data, target)
    res = akg.topi.sum(vlog_t, axis=axis_new, keepdims=keepdims)
    res = Exp(res, target)

    if dtype in ["int8", "uint8"]:
        res = akg.topi.cast(res, dtype)
    return res
Example 15
def get_reduced_indices(*indices, axis, keepdims):
    """Get the adjoint for an arbitrary dimension input."""

    # get all indices
    indices_list = list(indices)
    # list of reduction axis: transform negative indices into positive
    # axis in this list wont exist after the reduction
    axis_list = ft_util.refine_reduce_axis(indices_list, list(axis))
    # get indices after reduction
    if keepdims:
        grad_indices_list = [
            index_i if i not in axis_list else 0
            for i, index_i in enumerate(indices_list)
        ]
    else:
        grad_indices_list = [
            index_i for i, index_i in enumerate(indices_list)
            if i not in axis_list
        ]
    grad_ind = tuple(grad_indices_list)
    return grad_ind
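The index mapping performed here can be shown with plain Python. The sketch below re-implements the logic standalone (a hypothetical helper; it does not call ft_util.refine_reduce_axis):

def reduced_indices(indices, axis, keepdims):
    # Normalize negative axes against the index rank, mirroring what the
    # refine step does for get_reduced_indices.
    rank = len(indices)
    axis = [a % rank for a in axis]
    if keepdims:
        return tuple(0 if i in axis else idx for i, idx in enumerate(indices))
    return tuple(idx for i, idx in enumerate(indices) if i not in axis)

# Index (1, 2, 3) of a rank-3 tensor reduced over axis 1:
assert reduced_indices((1, 2, 3), axis=[1], keepdims=True) == (1, 0, 3)
assert reduced_indices((1, 2, 3), axis=[-2], keepdims=False) == (1, 3)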
Example 16
File: sum.py Project: zhuyawen/akg
def sum_v2(inputs, axis=None, keepdims=True):
    """another implementation of sum with topi api."""
    dtype = inputs.dtype
    vc_util.ops_dtype_check(dtype, vc_util.DtypeForDavinci.ALL_FLOAT)
    axis = ft_util.refine_reduce_axis(inputs, axis)
    vc_util.check_shape(inputs.shape)
    if not axis:
        output = akg.topi.identity(inputs)
    else:
        if dtype == "float16":
            step_sum = cast(inputs, "float32")
        else:
            step_sum = inputs

        step_sum = akg.topi.sum(step_sum, axis=axis, keepdims=keepdims)

        if dtype == "float16":
            output = cast(step_sum, "float16")
        else:
            output = step_sum
    attr_map = get_attrs()
    return output, attr_map
Example 17
File: mean.py Project: zhuyawen/akg
def mean_v2(data, axis=None, keepdims=False):
    """Simple implementation of mean."""
    # Check types
    vc_util.ops_dtype_check(data.dtype, vc_util.DtypeForDavinci.ALL_FLOAT)

    # Check shape
    shape = [x.value for x in data.shape]
    vc_util.reduce_axis_check(shape, axis)
    axis = ft_util.refine_reduce_axis(data, axis)

    dtype = data.dtype
    count = 1
    for i in axis:
        count *= shape[i]

    count_rec = 1 / count
    output, _ = sum.sum_v2(data, axis, keepdims)
    res = output * akg.tvm.const(count_rec, dtype)
    attrs = get_attrs(data)
    if shape_is_dynamic(data):
        attrs["custom_tiling"] = mean_dynamic_tiling_strategy(data, axis)
    return res, attrs
Example 18
def logsoftmax(inputs, axis):
    """
    Activation function, computes log softmax.

    Args:
        inputs: Tensor.
        axis: On which dimension log softmax is performed.

    Return:
        Tensor, which has the same shape and type as input.
    """
    dtype = inputs.dtype
    vc_util.check_shape(inputs.shape)
    vc_util.ops_dtype_check(dtype, vc_util.DtypeForDavinci.ALL_FLOAT)
    axis = refine_reduce_axis(inputs, axis)
    if isinstance(axis, (list, tuple)):
        if len(axis) != 1:
            raise RuntimeError("Reduce axis for logsoftmax op must br 1-dimension, while current is %d-dimension"
                               % (len(axis)))
        axis = axis[0]
    out = logsoftmax_op(inputs, inputs.shape, axis)
    attr_map = {"pragma_modshift": 1, "disable_cse": 1}
    return out, attr_map
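Numerically, log-softmax reduces to x - logsumexp(x) along the chosen axis. A minimal NumPy reference (illustrative sketch, not the akg logsoftmax_op):

import numpy as np

def logsoftmax_reference(x, axis):
    # log softmax(x) = (x - max) - log(sum(exp(x - max))), with the max
    # subtracted for numerical stability.
    shifted = x - x.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))

x = np.random.rand(3, 4).astype(np.float32)
out = logsoftmax_reference(x, axis=1)
# Exponentiating the result recovers a proper softmax that sums to 1.
assert np.allclose(np.exp(out).sum(axis=1), 1.0, rtol=1e-5)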
Example 19
def reduce_and(inputs, axis=None, keepdims=False, target=utils.CUDA):
    """
    Compute the logical and of elements across dimensions of a tensor.

    Args:
        inputs (tvm.tensor.Tensor): Tensor.
        axis (Union[list, tuple, int, None]): If the list or tuple is empty, the axis is treated as None.
        keepdims (bool): If keepdims is True, the result has the same rank (shape length) as the input.

    Returns:
        tvm.tensor.Tensor, has the same type as the input. If keepdims is True, all reduced dimensions are retained
        with length 1; otherwise the reduced axes are eliminated.

    Supported Platforms:
        'GPU', 'CPU'
    """
    utils.check_supported_target(target)
    axis = ft_util.refine_reduce_axis(inputs, axis)
    utils.check_shape(inputs.shape)

    output = akg.topi.all(inputs, axis=axis, keepdims=keepdims)

    return output
Example 20
def mean(data, axis=None, keepdims=False):
    """
    Computes the mean of the values of a Tensor over the whole dataset.

    Args:
        data (tvm.tensor.Tensor): Tensor.
        axis (Union[list, tuple, int, None]): If the tuple is empty, the axis is treated as None.
        keepdims (bool): If keepdims is True, the result has the same rank (shape length) as the input.

    Returns:
            tvm.tensor.Tensor, has the same type as data. If keepdims is True, all reduced dimensions are
            retained with length 1; otherwise the reduced axes are eliminated.
    """
    shape = [x.value for x in data.shape]
    vc_util.reduce_axis_check(shape, axis)
    axis = ft_util.refine_reduce_axis(data, axis)

    count = 1
    for i in axis:
        count *= shape[i]
    output, _ = sum.sum_value(data, axis, keepdims)
    res = akg.topi.divide(output, count)

    return res
Example 21
def common(data, axis, method="min"):
    """
    Returns the index with the max or min value across axes of a tensor.

    Note:
        method can be "max" or "min" to get argmax or argmin.

    Args:
        data (tvm.tensor.Tensor): Tensor of type float16, float32, int8, int32.
        axis (int): Describe the axis of input tensor.
        method (str): Can be "max" or "min".

    Returns:
        tvm.tensor.Tensor, has type of int32.
    """
    shape = get_shape(data)
    dtype = data.dtype

    utils.ops_dtype_check(
        data.dtype,
        [utils.DtypeForDavinci.ALL_FLOAT, utils.DtypeForDavinci.ALL_INT])
    utils.reduce_axis_check(shape, axis)
    real_axis = refine_reduce_axis(shape, axis)[0]
    out_shape = get_reduce_out_shape(shape, axis=axis)
    attr_map = {}
    if shape_is_dynamic(data):
        attr_map["dynamic_shape"] = set_dynamic_shape_limit_for_tensor(
            data, 4096, real_axis)
    if dtype != "float16":
        data = akg.topi.cast(data, "float16")
    k = akg.tvm.reduce_axis((0, data.shape[real_axis]), "k")
    if axis in (len(shape) - 1, -1):
        if method == "min":
            reducer = akg.tvm.comm_reducer(lambda x, y: dav.fargmin(x, y),
                                           lambda t: akg.tvm.max_value(t))
        elif method == "max":
            reducer = akg.tvm.comm_reducer(lambda x, y: dav.fargmax(x, y),
                                           lambda t: akg.tvm.min_value(t))
        else:
            raise ValueError("not support {}".format(method))

        if len(data.shape) == 1:
            res = akg.tvm.compute((1, ), lambda i: reducer(data[k], axis=k))
        else:
            res = akg.tvm.compute(
                out_shape, lambda *indice: reducer(data(*indice, k), axis=k))

        res = akg.tvm.compute(out_shape,
                              lambda *indice: res(*indice).astype("int32"),
                              "argred_output")
    elif axis in (0, -len(shape)):
        tmp_idx = akg.tvm.compute(
            shape[1:],
            lambda *indice: akg.tvm.const(0.0, "float16"),
            name='tmp_index')
        local_data = akg.tvm.compute(shape[1:],
                                     lambda *indice: data(0, *indice),
                                     name="tmp_data")
        for idx in range(shape[axis] - 1):
            if method == 'min':
                tmp_idx = akg.tvm.compute(
                    shape[1:],
                    lambda *indice, ite_idx=idx: akg.tvm.expr.Select(
                        local_data(*indice) > data(ite_idx + 1, *indice),
                        akg.tvm.const(ite_idx + 1, "float16"), tmp_idx(*indice)
                    ))
                local_data = akg.tvm.compute(
                    shape[1:],
                    lambda *indice, ite_idx=idx: akg.tvm.expr.Select(
                        local_data(*indice) > data(ite_idx + 1, *indice),
                        data(ite_idx + 1, *indice), local_data(*indice)))
            elif method == "max":
                tmp_idx = akg.tvm.compute(
                    shape[1:],
                    lambda *indice, ite_idx=idx: akg.tvm.expr.Select(
                        local_data(*indice) < data(ite_idx + 1, *indice),
                        akg.tvm.const(ite_idx + 1, "float16"), tmp_idx(*indice)
                    ))
                local_data = akg.tvm.compute(
                    shape[1:],
                    lambda *indice, ite_idx=idx: akg.tvm.expr.Select(
                        local_data(*indice) < data(ite_idx + 1, *indice),
                        data(ite_idx + 1, *indice), local_data(*indice)))
            else:
                raise ValueError("not support " + method)

        res = akg.tvm.compute(out_shape,
                              lambda *indice: tmp_idx(*indice).astype("int32"),
                              "cast1")
    else:
        raise ValueError(
            "Argmax only supports the first axis and the last axis now!")

    larger = out_shape if len(out_shape) > len(shape) else shape
    strategy = argminmax_tiling_strategy(larger, real_axis)
    if strategy:
        attr_map["custom_tiling"] = strategy
    return res, attr_map
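Behaviourally, the int32 result of this kernel matches NumPy's argmin/argmax restricted to the first or last axis, which is all the kernel supports. A tiny reference sketch (illustrative only, not akg code):

import numpy as np

data = np.random.rand(4, 6).astype(np.float16)

# The kernel only reduces the first or the last axis; the expected outputs
# correspond to np.argmin / np.argmax cast to int32.
argmin_last = data.argmin(axis=-1).astype(np.int32)    # shape (4,)
argmax_first = data.argmax(axis=0).astype(np.int32)    # shape (6,)
print(argmin_last, argmax_first)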