Beispiel #1
0
    def kernel_ir(dst, data):
        """Build IR that copies the leading (kept) corner of ``data`` into ``dst``.

        For every leading/batch index, rows ``[0, m0 - m1)`` and columns
        ``[0, n0 - n1)`` of ``data`` are stored to the same position in
        ``dst``. ``(m0, n0)`` are the last two dims of ``data``; ``(m1, n1)``
        are the last two entries of ``unpad_after`` from the enclosing scope,
        i.e. the number of trailing rows/columns that are dropped.
        NOTE(review): assumes ``dst`` is shaped ``(..., m0 - m1, n0 - n1)`` --
        confirm against the caller.

        Returns:
            The statement produced by ``ir_builder.get()``.
        """
        builder = tvm.ir_builder.create()
        full_rows, full_cols = [get_const(dim) for dim in data.shape][-2:]
        cut_rows, cut_cols = [get_const(dim) for dim in unpad_after][-2:]
        leading = data.shape[:-2]
        keep_rows = full_rows - cut_rows
        keep_cols = full_cols - cut_cols

        with builder.for_range_n(leading, "bs") as batch_idx:
            with builder.for_range(0, keep_rows) as row:
                with builder.for_range(0, keep_cols) as col:
                    # Source and destination share the same coordinates; two
                    # separate index lists are kept, as in the original code.
                    dst_index = batch_idx + [row, col]
                    src_index = batch_idx + [row, col]
                    builder.store(dst, dst_index, builder.load(data, src_index))
        return builder.get()
Beispiel #2
0
    def _default2zn(data):
        """Lay out a row-major ``(..., m, n)`` tensor in zN fractal order.

        The result has shape ``(..., ceil(n / cs), ceil(m / cs), cs, cs)``.
        Element ``[..., n1, m1, m0, n0]`` equals
        ``data[..., m1 * cs + m0, n1 * cs + n0]``, or ``0`` (in ``data.dtype``)
        when that row is ``>= m`` or that column is ``>= n``. ``cs`` and
        ``output_name`` come from the enclosing scope.

        Raises:
            ValueError: if ``data`` has fewer than two dimensions.
        """
        dims = [get_const(d) for d in data.shape]
        dtype = data.dtype
        if len(dims) < 2:
            raise ValueError(
                "length of shape of input_data should be greater than or equal to 2, but got %d"
                % len(dims))
        m, n = dims[-2:]
        row_blocks = (m + cs - 1) // cs
        col_blocks = (n + cs - 1) // cs
        # Leading (batch) dims pass through unchanged.
        output_shape = dims[:-2] + [col_blocks, row_blocks, cs, cs]

        def fcompute(*output_indices):
            n_lead = len(output_indices) - 4
            col_blk, row_blk, row_off, col_off = output_indices[n_lead:]
            row = row_blk * cs + row_off
            col = col_blk * cs + col_off
            src = list(output_indices[:n_lead]) + [row, col]
            # Positions that fall into the round-up padding read as zero.
            return tvm.if_then_else(tvm.any(row >= m, col >= n),
                                    tvm.const(0, dtype), data(*src))

        return tvm.compute(output_shape, fcompute, name=output_name)
Beispiel #3
0
def fractalzz2two(data, out_dtype, shape_original):
    """Convert a zZ fractal tensor back to its row-major 2-D, 3-D or 4-D form.

    ``data`` has shape ``(..., m1, n1, m0, n0)`` with zero, one (``b``) or two
    (``b, s``) leading dims. Element ``[..., m_i1, n_i1, m_i0, n_i0]`` is
    written to ``[..., m_i1 * m0 + m_i0, n_i1 * n0 + n_i0]`` of a tensor of
    shape ``(..., m1 * m0, n1 * n0)``.

    Args:
        data: input tensor in zZ layout (rank 4, 5 or 6).
        out_dtype: expected output dtype; must equal ``data.dtype`` because
            no cast is performed.
        shape_original: expected output shape (list, tuple or other
            sequence); must equal ``data.shape[:-4] + [m1 * m0, n1 * n0]``
            because slicing padding off is not performed.

    Returns:
        The tensor produced by the hybrid-script kernel matching the rank.

    Raises:
        AssertionError: if ``data`` has rank < 4, or if ``shape_original`` or
            ``out_dtype`` disagree with what ``data`` implies.
        ValueError: if ``data`` has rank greater than 6.
    """
    shape = [get_const(x) for x in data.shape]
    assert len(shape) >= 4

    m1, n1, m0, n0 = shape[-4:]
    if len(shape) == 5:
        b = shape[0]
    elif len(shape) == 6:
        b, s = shape[:2]
    m, n = m1 * m0, n1 * n0

    # Validate before building any kernel. Slicing padding off (a target shape
    # smaller than the data-implied one) and dtype casting are not
    # implemented. Checking first also guarantees the rank used for dispatch
    # below matches the leading dims bound to b/s above.
    final_shape = shape[:-4] + [m, n]
    assert final_shape == list(shape_original)
    assert out_dtype == data.dtype

    # One hybrid-script kernel per rank. capture=locals() exposes m, n, m0,
    # n0, m1, n1 (and b, s when bound) to the script bodies. Row/column
    # positions use the real block sizes m0/n0 instead of a fixed 16, so they
    # stay consistent with the allocated (m1 * m0, n1 * n0) output.
    @script(capture=locals())
    def reshape_zz_2d(inputs):
        output = allocate((m, n), inputs.dtype, 'local')
        for m_i1 in range(m1):
            for n_i1 in range(n1):
                for m_i0 in range(m0):
                    for n_i0 in range(n0):
                        output[m_i1 * m0 + m_i0, n_i1 * n0 + n_i0] = inputs[m_i1, n_i1, m_i0, n_i0]
        return output

    @script(capture=locals())
    def reshape_zz_3d(inputs):
        output = allocate((b, m, n), inputs.dtype, 'local')
        for b_i in range(b):
            for m_i1 in range(m1):
                for n_i1 in range(n1):
                    for m_i0 in range(m0):
                        for n_i0 in range(n0):
                            output[b_i, m_i1 * m0 + m_i0, n_i1 * n0 + n_i0] = inputs[b_i, m_i1, n_i1, m_i0, n_i0]
        return output

    @script(capture=locals())
    def reshape_zz_4d(inputs):
        output = allocate((b, s, m, n), inputs.dtype, 'local')
        for b_i in range(b):
            for s_i in range(s):
                for m_i1 in range(m1):
                    for n_i1 in range(n1):
                        for m_i0 in range(m0):
                            for n_i0 in range(n0):
                                output[b_i, s_i, m_i1 * m0 + m_i0, n_i1 * n0 + n_i0] = \
                                inputs[b_i, s_i, m_i1, n_i1, m_i0, n_i0]
        return output

    # Dispatch on the input rank (equal to len(shape_original) + 2 after the
    # shape assert above).
    rank = len(shape)
    if rank == 4:
        return reshape_zz_2d(data)
    if rank == 5:
        return reshape_zz_3d(data)
    if rank == 6:
        return reshape_zz_4d(data)
    raise ValueError(
        "fractalzz2two supports inputs of rank 4, 5 or 6, but got %d" % rank)
Beispiel #4
0
def auto_pad(data, block_size=16):
    """Pad the last two dimensions of ``data`` up to a multiple of ``block_size``.

    Leading dimensions are left as-is. Padding on the last two axes is added
    only after the existing elements (``[0, extra]``), using
    ``pad(data, paddings, 'constant')``.

    Args:
        data: tensor with at least two dimensions.
        block_size: alignment for the last two dims. Defaults to 16, the
            value that was previously hard-coded.

    Returns:
        The tensor returned by ``pad``.

    Raises:
        ValueError: if ``block_size`` is not positive.
        AssertionError: if ``data`` has fewer than two dimensions.
    """
    if block_size <= 0:
        raise ValueError("block_size must be positive, got %r" % (block_size,))
    shape = [get_const(x) for x in data.shape]
    assert len(shape) >= 2
    paddings = [[0, 0] for _ in shape]
    # Only the two trailing axes are rounded up; the rest get no padding.
    for axis in (-2, -1):
        extent = shape[axis]
        aligned = (extent + block_size - 1) // block_size * block_size
        paddings[axis] = [0, aligned - extent]
    return pad(data, paddings, 'constant')
Beispiel #5
0
        def kernel_ir(input_, output):
            """Build IR that scatters a zN fractal tensor back to row-major order.

            ``input_`` is indexed as ``[..., n1, m1, m0, n0]``. Each element whose
            target row ``m1 * cs + m0`` is below ``m`` and whose target column
            ``n1 * cs + n0`` is below ``n`` is stored at
            ``output[..., row, col]``. ``(m, n)`` are the last two entries of
            ``original_shape`` from the enclosing scope. Elements that land in
            the padding region are skipped.

            Returns:
                The statement produced by ``ir_builder.get()``.
            """
            builder = tvm.ir_builder.create()
            in_dims = [get_const(dim) for dim in input_.shape]
            n1, m1, m0, n0 = in_dims[-4:]
            m, n = [get_const(dim) for dim in original_shape][-2:]
            leading = in_dims[:-4]

            def flat(block, inner):
                # Build a fresh "block * cs + inner" expression on every call.
                return block * cs + inner

            with builder.for_range_n(leading, "bs") as batch_idx:
                with builder.for_range(0, n1) as i_n1:
                    with builder.for_range(0, m1) as i_m1:
                        with builder.for_range(0, m0) as i_m0:
                            with builder.for_range(0, n0) as i_n0:
                                in_bounds = tvm.all(flat(i_m1, i_m0) < m,
                                                    flat(i_n1, i_n0) < n)
                                with builder.if_scope(in_bounds):
                                    dst_index = batch_idx + [
                                        flat(i_m1, i_m0), flat(i_n1, i_n0)
                                    ]
                                    src_index = batch_idx + [i_n1, i_m1, i_m0, i_n0]
                                    builder.store(output, dst_index,
                                                  builder.load(input_, src_index))
            return builder.get()
Beispiel #6
0
def two2fractal(data, format_):
    """Reshape a row-major 2-D, 3-D or 4-D tensor into a 16x16 fractal layout.

    The last two dims ``(m, n)`` are split into ``(m1, m0)`` and ``(n1, n0)``
    with ``m0 = n0 = 16`` and ``m1``/``n1`` rounded up. A leading ``b``
    (rank 3) or ``b, s`` (rank 4) is kept in front. The trailing layout is:

    * ``'zN'``: ``(n1, m1, m0, n0)``
    * ``'zZ'``: ``(m1, n1, m0, n0)``
    * ``'nZ'``: ``(m1, n1, n0, m0)``

    Rows at or beyond ``m`` are written as ``zero`` inside the kernels.
    Columns beyond ``n`` come from padding the input's last axis with
    ``pad.pad(..., 'constant')`` before the kernel runs.

    Args:
        data: input tensor of rank 2, 3 or 4.
        format_: one of ``'zN'``, ``'zZ'``, ``'nZ'``.

    Returns:
        The fractal tensor. A ``float32`` input is first cast to ``float16``,
        and each kernel allocates its output with the dtype of that (possibly
        padded) data.

    Raises:
        AssertionError: on an unsupported ``format_`` or rank.
    """
    support_formats = ['zN', 'zZ', 'nZ']
    shape = [get_const(x) for x in data.shape]

    # Only the three formats above and ranks 2..4 are supported.
    assert format_ in support_formats
    assert len(shape) >= 2 and len(shape) <= 4

    m, n = shape[-2:]
    # Leading dims, referenced by the rank-3 / rank-4 kernels below.
    if len(shape) == 3:
        b = shape[0]
    if len(shape) == 4:
        b, s = shape[:2]
    # Round m and n up to the next multiple of the 16x16 block size.
    pad_m, pad_n = m, n
    if m % 16 != 0:
        pad_m = (m + 15) // 16 * 16
    if n % 16 != 0:
        pad_n = (n + 15) // 16 * 16
    m1, n1 = pad_m // 16, pad_n // 16
    m0, n0 = 16, 16

    # zN kernels: output[..., n_i, m_i, m_i0, n_i0] =
    #     inputs[..., m_i * 16 + m_i0, n_i * 16 + n_i0], or `zero` for rows >= m.
    # capture=locals() passes the current locals (m, m0, m1, n0, n1, and b/s
    # when bound) so the hybrid-script bodies can reference them.
    @script(capture=locals())
    def reshape_zn_2d(inputs, zero):
        output = allocate((n1, m1, m0, n0), inputs.dtype, 'local')
        for n_i in range(n1):
            for m_i in range(m1):
                for m_i0 in range(m0):
                    for n_i0 in range(n0):
                        if (m_i * 16 + m_i0 >= m):
                            output[n_i, m_i, m_i0, n_i0] = zero
                        else:
                            output[n_i, m_i, m_i0,
                                   n_i0] = inputs[m_i * 16 + m_i0,
                                                  n_i * 16 + n_i0]
        return output

    @script(capture=locals())
    def reshape_zn_3d(inputs, zero):
        output = allocate((b, n1, m1, m0, n0), inputs.dtype, 'local')
        for b_i in range(b):
            for n_i in range(n1):
                for m_i in range(m1):
                    for m_i0 in range(m0):
                        for n_i0 in range(n0):
                            if (m_i * 16 + m_i0 >= m):
                                output[b_i, n_i, m_i, m_i0, n_i0] = zero
                            else:
                                output[b_i, n_i, m_i, m_i0,
                                       n_i0] = inputs[b_i, m_i * 16 + m_i0,
                                                      n_i * 16 + n_i0]
        return output

    @script(capture=locals())
    def reshape_zn_4d(inputs, zero):
        output = allocate((b, s, n1, m1, m0, n0), inputs.dtype, 'local')
        for b_i in range(b):
            for s_i in range(s):
                for n_i in range(n1):
                    for m_i in range(m1):
                        for m_i0 in range(m0):
                            for n_i0 in range(n0):
                                if (m_i * 16 + m_i0 >= m):
                                    output[b_i, s_i, n_i, m_i, m_i0,
                                           n_i0] = zero
                                else:
                                    output[b_i, s_i, n_i, m_i, m_i0,
                                           n_i0] = inputs[b_i, s_i,
                                                          m_i * 16 + m_i0,
                                                          n_i * 16 + n_i0]
        return output

    # nZ kernels: output[..., m_i, n_i, n_i0, m_i0] =
    #     inputs[..., m_i * 16 + m_i0, n_i * 16 + n_i0], or `zero` for rows >= m.
    @script(capture=locals())
    def reshape_nz_2d(inputs, zero):
        output = allocate((m1, n1, n0, m0), inputs.dtype, 'local')
        for m_i in range(m1):
            for n_i in range(n1):
                for n_i0 in range(n0):
                    for m_i0 in range(m0):
                        if (m_i * 16 + m_i0 >= m):
                            output[m_i, n_i, n_i0, m_i0] = zero
                        else:
                            output[m_i, n_i, n_i0,
                                   m_i0] = inputs[m_i * 16 + m_i0,
                                                  n_i * 16 + n_i0]
        return output

    @script(capture=locals())
    def reshape_nz_3d(inputs, zero):
        output = allocate((b, m1, n1, n0, m0), inputs.dtype, 'local')
        for b_i in range(b):
            for m_i in range(m1):
                for n_i in range(n1):
                    for n_i0 in range(n0):
                        for m_i0 in range(m0):
                            if (m_i * 16 + m_i0 >= m):
                                output[b_i, m_i, n_i, n_i0, m_i0] = zero
                            else:
                                output[b_i, m_i, n_i, n_i0,
                                       m_i0] = inputs[b_i, m_i * 16 + m_i0,
                                                      n_i * 16 + n_i0]
        return output

    @script(capture=locals())
    def reshape_nz_4d(inputs, zero):
        output = allocate((b, s, m1, n1, n0, m0), inputs.dtype, 'local')
        for b_i in range(b):
            for s_i in range(s):
                for m_i in range(m1):
                    for n_i in range(n1):
                        for n_i0 in range(n0):
                            for m_i0 in range(m0):
                                if (m_i * 16 + m_i0 >= m):
                                    output[b_i, s_i, m_i, n_i, n_i0,
                                           m_i0] = zero
                                else:
                                    output[b_i, s_i, m_i, n_i, n_i0,
                                           m_i0] = inputs[b_i, s_i,
                                                          m_i * 16 + m_i0,
                                                          n_i * 16 + n_i0]
        return output

    # zZ kernels: output[..., m_i, n_i, m_i0, n_i0] =
    #     inputs[..., m_i * 16 + m_i0, n_i * 16 + n_i0], or `zero` for rows >= m.
    @script(capture=locals())
    def reshape_zz_2d(inputs, zero):
        output = allocate((m1, n1, m0, n0), inputs.dtype, 'local')
        for m_i in range(m1):
            for n_i in range(n1):
                for m_i0 in range(m0):
                    for n_i0 in range(n0):
                        if (m_i * 16 + m_i0 >= m):
                            output[m_i, n_i, m_i0, n_i0] = zero
                        else:
                            output[m_i, n_i, m_i0,
                                   n_i0] = inputs[m_i * 16 + m_i0,
                                                  n_i * 16 + n_i0]
        return output

    @script(capture=locals())
    def reshape_zz_3d(inputs, zero):
        output = allocate((b, m1, n1, m0, n0), inputs.dtype, 'local')
        for b_i in range(b):
            for m_i in range(m1):
                for n_i in range(n1):
                    for m_i0 in range(m0):
                        for n_i0 in range(n0):
                            if (m_i * 16 + m_i0 >= m):
                                output[b_i, m_i, n_i, m_i0, n_i0] = zero
                            else:
                                output[b_i, m_i, n_i, m_i0,
                                       n_i0] = inputs[b_i, m_i * 16 + m_i0,
                                                      n_i * 16 + n_i0]
        return output

    @script(capture=locals())
    def reshape_zz_4d(inputs, zero):
        output = allocate((b, s, m1, n1, m0, n0), inputs.dtype, 'local')
        for b_i in range(b):
            for s_i in range(s):
                for m_i in range(m1):
                    for n_i in range(n1):
                        for m_i0 in range(m0):
                            for n_i0 in range(n0):
                                if (m_i * 16 + m_i0 >= m):
                                    output[b_i, s_i, m_i, n_i, m_i0,
                                           n_i0] = zero
                                else:
                                    output[b_i, s_i, m_i, n_i, m_i0,
                                           n_i0] = inputs[b_i, s_i,
                                                          m_i * 16 + m_i0,
                                                          n_i * 16 + n_i0]
        return output

    # float32 input is cast to float16 before reshaping; `zero` matches the
    # dtype the kernels will see.
    cast_data = data
    if data.dtype == 'float32':
        cast_data = akg.lang.ascend.cast_to(data, 'float16')
    zero = akg.tvm.const(0.0, cast_data.dtype)
    pad_data = cast_data
    # Columns (n) are not zero-filled inside the kernels above; they only guard
    # rows against m (the original note cites an alignment issue). So when n
    # is not a multiple of 16, the last axis is padded explicitly up to pad_n
    # here.
    if n % 16 != 0:
        paddings = [[0, 0] for _ in range(len(shape))]
        paddings[-1] = [0, pad_n - n]
        pad_data = pad.pad(cast_data, paddings, 'constant')
    # Dispatch on format and rank. The asserts at the top guarantee that
    # exactly one kernel runs, so `output` is always bound.
    if format_ == 'zN':
        if len(shape) == 2:
            output = reshape_zn_2d(pad_data, zero)
        if len(shape) == 3:
            output = reshape_zn_3d(pad_data, zero)
        if len(shape) == 4:
            output = reshape_zn_4d(pad_data, zero)
    elif format_ == 'zZ':
        if len(shape) == 2:
            output = reshape_zz_2d(pad_data, zero)
        if len(shape) == 3:
            output = reshape_zz_3d(pad_data, zero)
        if len(shape) == 4:
            output = reshape_zz_4d(pad_data, zero)
    elif format_ == 'nZ':
        if len(shape) == 2:
            output = reshape_nz_2d(pad_data, zero)
        if len(shape) == 3:
            output = reshape_nz_3d(pad_data, zero)
        if len(shape) == 4:
            output = reshape_nz_4d(pad_data, zero)

    return output