def get_ref_data():
            """Create random inputs and the NumPy/SciPy reference input gradient.

            Every size, stride and shape used here (``batch``, ``in_h``,
            ``stride_h``, ``out_channel``, ``out_grad_shape``, ...) is read from
            the enclosing scope.

            Returns
            -------
            tuple
                ``(out_grad_np, filter_np, in_grad_np)``.
            """
            out_grad_np = np.random.uniform(size=out_grad_shape).astype(dtype)
            filter_np = np.random.uniform(size=filter_shape).astype(dtype)
            # dilate the output gradient by the forward stride along H and W
            dilated_out_grad_np = topi.testing.dilate_python(
                out_grad_np, [1, stride_h, stride_w, 1])
            # padding params in forward propagation
            fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(
                [padding_h, padding_w], (filter_h, filter_w))
            # padding params in backward propagation
            bpad_top = filter_h - 1 - fpad_top
            bpad_bottom = (filter_h - 1 - fpad_bottom) + (stride_h - 1)
            bpad_left = filter_w - 1 - fpad_left
            bpad_right = (filter_w - 1 - fpad_right) + (stride_w - 1)

            dilated_h = dilated_out_grad_np.shape[1]
            dilated_w = dilated_out_grad_np.shape[2]
            # Allocate the zero-padded buffer, then copy the dilated gradient
            # into its interior.  These are two separate statements: the
            # allocation must be closed before the slice assignment starts.
            padded_out_grad = np.zeros(
                (batch, dilated_h + bpad_top + bpad_bottom,
                 dilated_w + bpad_left + bpad_right, out_channel))
            padded_out_grad[:,
                            bpad_top:dilated_h + bpad_top,
                            bpad_left:dilated_w + bpad_left, :] = dilated_out_grad_np

            in_grad_np = np.zeros((batch, in_h, in_w, in_channel))
            for b in range(batch):
                for c in range(in_channel):
                    for m in range(channel_multiplier):
                        # output-gradient channel paired with input channel c
                        # and multiplier index m
                        grad_plane = padded_out_grad[b, :, :, c * channel_multiplier + m]
                        full = signal.convolve2d(grad_plane, filter_np[:, :, c, m],
                                                 mode='valid')
                        # crop to the input size; the extra (stride - 1)
                        # bottom/right padding can make `full` larger
                        in_grad_np[b, :, :, c] += full[0:in_h, 0:in_w]
            return (out_grad_np, filter_np, in_grad_np)
def np_conv(na, nw, padding, stride=1):
    """NumPy/SciPy reference implementation of a 2-D convolution.

    Parameters
    ----------
    na : numpy.ndarray
        4-D input, unpacked as (batch, in_channel, in_height, in_width).

    nw : numpy.ndarray
        4-D weight.  It is indexed as ``nw[f, c]`` with ``f`` an output channel
        and ``c`` an input channel.

    padding : int, str or list/tuple
        Forwarded unchanged to ``get_pad_tuple``.

    stride : int or a list/tuple of two ints
        Stride size, or [stride_height, stride_width].

    Returns
    -------
    nb : numpy.ndarray
        4-D array of shape (batch, out_channel, out_height, out_width).
    """
    batch, in_channel, in_height, in_width = na.shape
    # NOTE(review): num_filter is taken from axis 1 here while the loop below
    # indexes output channels on axis 0 (nw[f, c]); the two only agree when
    # both leading weight axes have the same size -- confirm intended layout.
    _, num_filter, kernel_h, kernel_w = nw.shape
    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride

    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(
        padding, (kernel_h, kernel_w))
    pad_h = pad_top + pad_bottom
    pad_w = pad_left + pad_right

    out_channel = num_filter
    out_height = (in_height - kernel_h + pad_h) // stride_h + 1
    out_width = (in_width - kernel_w + pad_w) // stride_w + 1
    nb = np.zeros((batch, out_channel, out_height, out_width))
    for n in range(batch):
        for f in range(out_channel):
            for c in range(in_channel):
                if pad_h > 0 or pad_w > 0:
                    apad = np.zeros((in_height + pad_h, in_width + pad_w))
                    apad[pad_top:pad_top + in_height,
                         pad_left:pad_left + in_width] = na[n, c]
                else:
                    apad = na[n, c]
                # convolve2d flips its kernel; rotating by 180 degrees first
                # turns the operation into a cross-correlation
                out = scipy.signal.convolve2d(apad,
                                              np.rot90(np.rot90(nw[f, c])),
                                              mode='valid')
                # subsample with the per-axis strides (a tuple `stride` cannot
                # be used as a slice step)
                nb[n, f] += out[::stride_h, ::stride_w]
    return nb
def conv2d_grad(orig, grad):
    """Gradient of conv2d

    Builds the data gradient with ``_nn.conv2d_transpose`` and the weight
    gradient with a grouped ``_nn.conv2d``, and returns them as
    ``[backward_data, backward_weight]``.  Only NCHW data, OIHW kernel and
    NCHW (or unspecified) output layouts are accepted, see the asserts below.
    """
    attrs = orig.attrs
    data, weight = orig.args
    data_shape = get_const_tuple(data.checked_type.shape)
    weight_shape = get_const_tuple(weight.checked_type.shape)
    _, _, grad_h, grad_w = get_const_tuple(orig.checked_type.shape)
    batch, in_channel, in_h, in_w = data_shape
    out_channel, _, filter_h, filter_w = weight_shape

    # infer output_padding
    fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(
        get_const_tuple(attrs.padding), (filter_h, filter_w))
    stride_h, stride_w = get_const_tuple(attrs.strides)
    dilation_h, dilation_w = get_const_tuple(attrs.dilation)
    # out_h/out_w: spatial size a transposed convolution of the gradient would
    # produce without output padding; the difference to the real input size is
    # passed on as output_padding
    out_h = (grad_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h
    out_w = (grad_w - 1) * stride_w - fpad_left - fpad_right + filter_w
    output_padding = (in_h - out_h, in_w - out_w)

    assert attrs.data_layout == 'NCHW', 'only support NCHW data layout'
    assert attrs.kernel_layout == 'OIHW', 'only support OIHW kernel layout'
    assert attrs.out_layout in ['', 'NCHW'], 'only support NCHW output layout'

    # NOTE(review): the conv2d_transpose call below is truncated in this copy;
    # its remaining arguments and closing parenthesis are missing and must be
    # restored before this function can be parsed.
    backward_data = _nn.conv2d_transpose(grad,
    grad = tile(grad, [1, in_channel // attrs.groups, 1, 1])
    grad = reshape(grad, [-1, 1, 0, 0])  # batch * oc * ic // groups, 1, oh, ow
    data = reshape(data, [1, -1, 0, 0])  # 1, batch * ic, ih, iw

    # NOTE(review): the conv2d call below is truncated as well; only its first
    # and last arguments are visible.
    backward_weight = _nn.conv2d(data,
                                 groups=in_channel * batch)
    # infer shape of backward_weight
    padded_weight_grad_h = (in_h - (grad_h - 1) * stride_h - 1 + fpad_top + fpad_bottom) \
                           // dilation_h + 1
    padded_weight_grad_w = (in_w - (grad_w - 1) * stride_w - 1 + fpad_left + fpad_right) \
                           // dilation_w + 1
    # NOTE(review): the target shape list of this reshape is cut off after
    # padded_weight_grad_h; the list and the call are never closed.
    backward_weight = reshape(backward_weight, [
        batch, in_channel // attrs.groups, out_channel, padded_weight_grad_h,
    # sum over axis 0, then swap the first two axes
    backward_weight = _sum(backward_weight, axis=0)
    backward_weight = transpose(backward_weight, [1, 0, 2, 3])

    assert padded_weight_grad_h >= filter_h
    assert padded_weight_grad_w >= filter_w
    # the computed weight gradient can be spatially larger than the filter;
    # keep only the leading filter_h x filter_w window
    if padded_weight_grad_h > filter_h or padded_weight_grad_w > filter_w:
        backward_weight = strided_slice(backward_weight,
                                        begin=[0, 0, 0, 0],
                                        end=[None, None, filter_h, filter_w])

    return [backward_data, backward_weight]
def compile_conv2d_NHWC_gemm_int8_arm(batch, in_channel, in_size, num_filter, kernel, stride, padding,
                                 dilation=1, add_bias=False, add_relu=False):
    """Declare and schedule a quantized NHWC conv2d for an AArch64 target.

    Prints the workload, declares int8 placeholders for the data ``A``, the
    weight ``W`` and the ``bias``, and builds the compute/schedule through
    ``topi.arm_cpu``.  ``add_bias`` and ``add_relu`` append an add and a relu
    to the computation.

    NOTE(review): several lines are missing from this copy (flagged inline), so
    the function does not parse as shown.
    """
    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
    # only used for the workload printout and the generated function name
    padding_sum = pad_top + pad_left + pad_bottom + pad_right
    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter,
                                                          kernel, stride, padding_sum, dilation))

    in_height = in_width = in_size
    A = te.placeholder((batch, in_height, in_width, in_channel), name='A', dtype='int8')
    W = te.placeholder((kernel, kernel, in_channel, num_filter), name='W', dtype='int8')
    bias = te.placeholder((num_filter,), name='bias', dtype='int8')
    # output dtype handed to the compute declaration
    dtype = 'int32'
    device = "llvm --device arm_cpu --mtriple aarch64-linux-gnu"

    ctx = tvm.context(device, 0)
    # NOTE(review): as shown, nothing leaves the function after the "Skip"
    # message -- presumably a `return` was lost here; confirm.
    if not ctx.exist:
        print("Skip because %s is not enabled" % device)
    print("Compiling on arm AArch64 target: %s" % device)
    # NOTE(review): the following lines are indented one level deeper with no
    # visible block header; the enclosing statement appears to be missing.
        assert is_aarch64_arm(), "AArch64 target not recognized"

        C = topi.arm_cpu.compute_conv2d_NHWC_quantized(A, W, (stride, stride), padding,
                                                       (dilation, dilation), dtype)
        if add_bias:
            C = topi.add(C, bias)
        if add_relu:
            C = topi.nn.relu(C)
        s = topi.arm_cpu.schedule_conv2d_NHWC_quantized([C])

    # NOTE(review): the build calls below are garbled -- the callee and its
    # first argument, the `else:` branch and the tail of each name-format tuple
    # are missing.
    if add_bias:, [A, W, bias, C], device,
                  name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch,
        func =, [A, W, bias, C], device,
                         name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch,
        func =, [A, W, C], device,
                         name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch,
def _conv2d_nhwc_python(a_np, w_np, stride, padding):
    """Convolution operator in NHWC layout.

    Parameters
    ----------
    a_np : numpy.ndarray
        4-D with shape [batch, in_height, in_width, in_channel]

    w_np : numpy.ndarray
        4-D with shape [filter_height, filter_width, in_channel, num_filter]

    stride : int or a list/tuple of two ints
        Stride size, or [stride_height, stride_width]

    padding : int or str or a list/tuple of 2 or 4 ints
        Padding size, or ['VALID', 'SAME'], or
        [pad_height, pad_width] for 2 ints, or
        [pad_top, pad_left, pad_bottom, pad_right] for 4 ints

    Returns
    -------
    b_np : np.ndarray
        4-D with shape [batch, out_height, out_width, out_channel]
    """
    batch, in_height, in_width, in_channel = a_np.shape
    kernel_h, kernel_w, _, num_filter = w_np.shape
    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride

    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(
        padding, (kernel_h, kernel_w))
    pad_h = pad_top + pad_bottom
    pad_w = pad_left + pad_right

    # compute the output shape
    out_channel = num_filter
    out_height = (in_height - kernel_h + pad_h) // stride_h + 1
    out_width = (in_width - kernel_w + pad_w) // stride_w + 1
    # change the layout from NHWC to NCHW
    at = a_np.transpose((0, 3, 1, 2))
    wt = w_np.transpose((3, 2, 0, 1))
    bt = np.zeros((batch, out_channel, out_height, out_width))
    # computation
    for n in range(batch):
        for f in range(out_channel):
            for c in range(in_channel):
                if pad_h > 0 or pad_w > 0:
                    apad = np.zeros((in_height + pad_h, in_width + pad_w))
                    apad[pad_top:pad_top + in_height,
                         pad_left:pad_left + in_width] = at[n, c]
                else:
                    apad = at[n, c]
                # convolve2d flips its kernel; rotating by 180 degrees first
                # turns the operation into a cross-correlation
                out = scipy.signal.convolve2d(apad,
                                              np.rot90(np.rot90(wt[f, c])),
                                              mode='valid')
                bt[n, f] += out[::stride_h, ::stride_w]
    # back from NCHW to NHWC
    return bt.transpose((0, 2, 3, 1))
def conv2d_transpose_nchw_python(a_np, w_np, stride, padding,
                                 output_padding=(0, 0)):
    """Transposed convolution operator in NCHW layout.

    Parameters
    ----------
    a_np : numpy.ndarray
        4-D with shape [batch, in_channel, in_height, in_width]

    w_np : numpy.ndarray
        4-D with shape [in_channel, num_filter, filter_height, filter_width]

    stride : int or a list/tuple of two ints
        Stride size, or [stride_height, stride_width]

    padding : int or str
        Padding size, or ['VALID', 'SAME']

    output_padding : tuple of two ints
        Extra (height, width) added to the bottom/right padding and to the
        output size.

    Returns
    -------
    b_np : np.ndarray
        4-D with shape [batch, out_channel, out_height, out_width]
    """
    batch, in_c, in_h, in_w = a_np.shape
    _, out_c, filter_h, filter_w = w_np.shape
    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride
    opad_h, opad_w = output_padding
    # dilate stage
    dilated_a_np = topi.testing.dilate_python(a_np, [1, 1, stride_h, stride_w])
    # padding stage
    fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(
        padding, (filter_h, filter_w))
    bpad_top = filter_h - 1 - fpad_top
    bpad_bottom = filter_h - 1 - fpad_bottom + opad_h
    bpad_left = filter_w - 1 - fpad_left
    bpad_right = filter_w - 1 - fpad_right + opad_w
    padded_a_np = np.zeros((batch, in_c,
                            dilated_a_np.shape[2] + bpad_top + bpad_bottom,
                            dilated_a_np.shape[3] + bpad_left + bpad_right))
    padded_a_np[:, :, bpad_top:dilated_a_np.shape[2] + bpad_top,
                bpad_left:dilated_a_np.shape[3] + bpad_left] = dilated_a_np
    # convolution stage
    out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h + opad_h
    out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w + opad_w
    b_np = np.zeros((batch, out_c, out_h, out_w))
    for n in range(batch):
        for f in range(out_c):
            for c in range(in_c):
                out = scipy.signal.convolve2d(padded_a_np[n, c],
                                              w_np[c, f],
                                              mode='valid')
                b_np[n, f] += out
    return b_np
def conv2d_transpose_packed(cfg,
                            output_padding=(0, 0)):
    """Packed conv2d_transpose compute

    Dilates ``data`` by the strides, pads it, and declares a ``te.compute``
    that reduces over input channels, kernel height/width and the inner
    channel block.  Returns the declared output tensor.
    """
    # NOTE(review): the body reads `data`, `kernel`, `strides`, `padding` and
    # `out_dtype`, none of which appear in the parameter list as shown -- the
    # signature looks truncated; confirm against callers.
    ishape = get_const_tuple(data.shape)
    kshape = get_const_tuple(kernel.shape)
    # both shapes are 6-D; t_ci is assigned by both unpackings, the kernel's
    # value wins
    b, c_i, i_h, i_w, t_b, t_ci = ishape
    c_o, _, k_h, k_w, t_co, t_ci = kshape
    stride_h, stride_w = strides
    opad_h, opad_w = output_padding
    # FIXME(tmoreau89): currently IR pass breaks when output padding != (0,0)
    assert opad_h == 0 and opad_w == 0, "VTA does not support output padding for now"

    # derive padding parameters
    fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(
        padding, (k_h, k_w))
    bpad_top = k_h - 1 - fpad_top
    bpad_bottom = k_h - 1 - fpad_bottom + opad_h
    bpad_left = k_w - 1 - fpad_left
    bpad_right = k_w - 1 - fpad_right + opad_w

    # padding stage
    dilated_input = topi.nn.dilate(data, [1, 1, stride_h, stride_w, 1, 1])
    data_pad = topi.nn.pad(dilated_input, [0, 0, bpad_top, bpad_left, 0, 0],
                           [0, 0, bpad_bottom, bpad_right, 0, 0])

    # convolution transpose stage
    out_h = (i_h - 1) * stride_h - fpad_top - fpad_bottom + k_h + opad_h
    out_w = (i_w - 1) * stride_w - fpad_left - fpad_right + k_w + opad_w
    oshape = (b, c_o, out_h, out_w, t_b, t_co)
    d_c = te.reduce_axis((0, c_i), name='d_c')
    d_h = te.reduce_axis((0, k_h), name='d_h')
    d_w = te.reduce_axis((0, k_w), name='d_w')
    d_ci = te.reduce_axis((0, t_ci), name='d_ci')

    # NOTE(review): the compute declaration below is truncated -- the kernel
    # index expression is never closed and the trailing arguments of
    # te.compute are missing.
    out = te.compute(oshape,
                     lambda i_n, i_c, i_h, i_w, j_n, j_c: te.
                     sum(data_pad(i_n, d_c, i_h + d_h, i_w + d_w, j_n, d_ci).
                         astype(out_dtype) * kernel[i_c, d_c, d_h, d_w, j_c,
                         axis=[d_c, d_h, d_w, d_ci]),

    # NOTE(review): an operand is missing between the two `*` below.
    cfg.add_flop(2 * * kshape[2] *
                 kshape[3] * ishape[1] * ishape[-1])

    return out
def conv2d_transpose_nchw_python(a_np, w_np, stride, padding):
    """Transposed convolution operator in NCHW layout.

    Parameters
    ----------
    a_np : numpy.ndarray
        4-D with shape [batch, in_channel, in_height, in_width]

    w_np : numpy.ndarray
        4-D with shape [in_channel, num_filter, filter_height, filter_width]

    stride : int or a list/tuple of two ints
        Stride size, or [stride_height, stride_width]

    padding : int or str
        Padding size, or ['VALID', 'SAME']

    Returns
    -------
    b_np : np.ndarray
        4-D with shape [batch, out_channel, out_height, out_width]
    """
    batch, in_c, in_h, in_w = a_np.shape
    _, out_c, filter_h, filter_w = w_np.shape
    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride
    # dilate stage
    dilated_a_np = topi.testing.dilate_python(a_np, [1, 1, stride_h, stride_w])
    # padding stage
    fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(padding, (filter_h, filter_w))
    bpad_top = filter_h - 1 - fpad_top
    bpad_bottom = filter_h - 1 - fpad_bottom
    bpad_left = filter_w - 1 - fpad_left
    bpad_right = filter_w - 1 - fpad_right
    padded_a_np = np.zeros((batch, in_c,
                            dilated_a_np.shape[2] + bpad_top + bpad_bottom,
                            dilated_a_np.shape[3] + bpad_left + bpad_right))
    padded_a_np[:, :, bpad_top:dilated_a_np.shape[2] + bpad_top,
                bpad_left:dilated_a_np.shape[3] + bpad_left] = dilated_a_np
    # convolution stage
    out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h
    out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w
    b_np = np.zeros((batch, out_c, out_h, out_w))
    for n in range(batch):
        for f in range(out_c):
            for c in range(in_c):
                out = scipy.signal.convolve2d(
                    padded_a_np[n, c], w_np[c, f], mode='valid')
                b_np[n, f] += out
    return b_np
        def get_ref_data():
            """Create random inputs and the NumPy/SciPy reference weight gradient.

            All sizes and shapes are read from the enclosing scope.  Returns
            ``(out_grad_np, input_np, weight_grad_np)``.
            """
            out_grad_np = np.random.uniform(size=out_grad_shape).astype(dtype)
            input_np = np.random.uniform(size=in_shape).astype(dtype)
            grad_dilated = topi.testing.dilate_python(
                out_grad_np, [1, stride_h, stride_w, 1])

            # zero-pad the input on H and W with the forward padding
            top, left, bottom, right = get_pad_tuple(
                [padding_h, padding_w], (filter_h, filter_w))
            padded = np.zeros(
                (batch, in_h + top + bottom, in_w + left + right, in_channel))
            padded[:, top:in_h + top, left:in_w + left, :] = input_np

            weight_grad_np = np.zeros(
                (filter_h, filter_w, in_channel, channel_multiplier))
            for c in range(in_channel):
                for m in range(channel_multiplier):
                    # output-gradient channel paired with (c, m)
                    grad_channel = c * channel_multiplier + m % channel_multiplier
                    for b in range(batch):
                        # rotate by 180 degrees so convolve2d acts as a
                        # cross-correlation
                        flipped = np.rot90(grad_dilated[b, :, :, grad_channel], 2)
                        corr = signal.convolve2d(
                            padded[b, :, :, c], flipped, mode='valid')
                        weight_grad_np[:, :, c, m] += corr[0:filter_h, 0:filter_w]
            return (out_grad_np, input_np, weight_grad_np)
        def get_ref_data():
            """Create random inputs and the NumPy/SciPy reference weight gradient.

            All sizes and shapes are read from the enclosing scope.  Returns
            ``(out_grad_np, input_np, weight_grad_np)``.
            """
            out_grad_np = np.random.uniform(size=out_grad_shape).astype(dtype)
            input_np = np.random.uniform(size=in_shape).astype(dtype)
            grad_dilated = topi.testing.dilate_python(
                out_grad_np, [1, stride_h, stride_w, 1])

            # zero-pad the input on H and W with the forward padding
            top, left, bottom, right = get_pad_tuple(
                [padding_h, padding_w], (filter_h, filter_w))
            padded = np.zeros(
                (batch, in_h + top + bottom, in_w + left + right, in_channel))
            padded[:, top:in_h + top, left:in_w + left, :] = input_np

            weight_grad_np = np.zeros(
                (filter_h, filter_w, in_channel, channel_multiplier))
            for c in range(in_channel):
                for m in range(channel_multiplier):
                    # output-gradient channel paired with (c, m)
                    grad_channel = c * channel_multiplier + m % channel_multiplier
                    for b in range(batch):
                        # rotate by 180 degrees so convolve2d acts as a
                        # cross-correlation
                        flipped = np.rot90(grad_dilated[b, :, :, grad_channel], 2)
                        corr = signal.convolve2d(
                            padded[b, :, :, c], flipped, mode='valid')
                        weight_grad_np[:, :, c, m] += corr[0:filter_h, 0:filter_w]
            return (out_grad_np, input_np, weight_grad_np)
def conv2d_transpose_nchw_python(a_np, w_np, stride, padding):
    """Transposed convolution operator in NCHW layout.

    Parameters
    ----------
    a_np : numpy.ndarray
        4-D with shape [batch, in_channel, in_height, in_width]

    w_np : numpy.ndarray
        4-D with shape [num_filter, in_channel, filter_height, filter_width]

    stride : int or a list/tuple of two ints
        Stride size, or [stride_height, stride_width]

    padding : int or str
        Padding size, or ['VALID', 'SAME']

    Returns
    -------
    b_np : np.ndarray
        4-D with shape [batch, out_channel, out_height, out_width]
    """
    # unpacking also checks that both arrays are 4-D
    batch, in_c, _, _ = a_np.shape
    _, _, filter_h, filter_w = w_np.shape
    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride
    # dilate stage
    dilated_a_np = topi.testing.dilate_python(a_np, [1, 1, stride_h, stride_w])
    # padding stage
    fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(padding, (filter_h, filter_w))
    bpad_top = filter_h - 1 - fpad_top
    bpad_bottom = filter_h - 1 - fpad_bottom
    bpad_left = filter_w - 1 - fpad_left
    bpad_right = filter_w - 1 - fpad_right
    padded_a_np = np.zeros((batch, in_c,
                            dilated_a_np.shape[2] + bpad_top + bpad_bottom,
                            dilated_a_np.shape[3] + bpad_left + bpad_right))
    padded_a_np[:, :, bpad_top:dilated_a_np.shape[2] + bpad_top,
                bpad_left:dilated_a_np.shape[3] + bpad_left] = dilated_a_np
    # convolution stage: rotate the kernel by 180 degrees on its spatial axes,
    # then run a stride-1 'VALID' convolution over the padded input
    rotated_w_np = np.rot90(w_np, k=2, axes=(2, 3))
    b_np = topi.testing.conv2d_nchw_python(padded_a_np, rotated_w_np, stride=1, padding='VALID')
    return b_np
# Declares the packed conv2d_transpose compute: dilates `data` by the strides,
# pads it, then reduces over input channels, kernel height/width and the inner
# channel block.  Returns the declared output tensor.
# NOTE(review): the parameter list below is never closed in this copy, and the
# body reads `out_dtype`, which is not among the visible parameters -- the
# signature looks truncated.
def _declatation_conv2d_transpose(cfg, data, kernel, strides, padding,
    ishape = get_const_tuple(data.shape)
    kshape = get_const_tuple(kernel.shape)
    # both shapes are 6-D; t_ci is assigned by both unpackings, the kernel's
    # value wins
    b, c_i, i_h, i_w, t_b, t_ci = ishape
    c_o, _, k_h, k_w, t_co, t_ci = kshape
    stride_h, stride_w = strides

    # derive padding parameters
    fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(
        padding, (k_h, k_w))
    bpad_top = k_h - 1 - fpad_top
    bpad_bottom = k_h - 1 - fpad_bottom
    bpad_left = k_w - 1 - fpad_left
    bpad_right = k_w - 1 - fpad_right

    # padding stage
    dilated_input = topi.nn.dilate(data, [1, 1, stride_h, stride_w, 1, 1])
    data_pad = topi.nn.pad(dilated_input, [0, 0, bpad_top, bpad_left, 0, 0],
                           [0, 0, bpad_bottom, bpad_right, 0, 0])

    # convolution transpose stage
    out_h = (i_h - 1) * stride_h - fpad_top - fpad_bottom + k_h
    out_w = (i_w - 1) * stride_w - fpad_left - fpad_right + k_w
    oshape = (b, c_o, out_h, out_w, t_b, t_co)
    d_c = tvm.reduce_axis((0, c_i), name='d_c')
    d_h = tvm.reduce_axis((0, k_h), name='d_h')
    d_w = tvm.reduce_axis((0, k_w), name='d_w')
    d_ci = tvm.reduce_axis((0, t_ci), name='d_ci')

    # NOTE(review): the compute declaration below is truncated -- the kernel
    # index expression is never closed and the trailing arguments of
    # tvm.compute are missing.
    out = tvm.compute(oshape,
                      lambda i_n, i_c, i_h, i_w, j_n, j_c: tvm.
                      sum(data_pad(i_n, d_c, i_h + d_h, i_w + d_w, j_n, d_ci).
                          astype(out_dtype) * kernel[i_c, d_c, d_h, d_w, j_c,
                          axis=[d_c, d_h, d_w, d_ci]),

    # NOTE(review): an operand is missing between the two `*` below.
    cfg.add_flop(2 * * kshape[2] *
                 kshape[3] * ishape[1] * ishape[-1])

    return out
        def get_ref_data():
            """Create random inputs and the NumPy/SciPy reference input gradient.

            All sizes and shapes are read from the enclosing scope.  Returns
            ``(out_grad_np, filter_np, in_grad_np)``.
            """
            out_grad_np = np.random.uniform(size=out_grad_shape).astype(dtype)
            filter_np = np.random.uniform(size=filter_shape).astype(dtype)
            grad_dilated = topi.testing.dilate_python(
                out_grad_np, [1, stride_h, stride_w, 1])

            # forward-pass padding ...
            fwd_top, fwd_left, fwd_bottom, fwd_right = get_pad_tuple(
                [padding_h, padding_w], (filter_h, filter_w))
            # ... and the padding derived from it for the backward pass
            top = filter_h - 1 - fwd_top
            left = filter_w - 1 - fwd_left
            bottom = (filter_h - 1 - fwd_bottom) + (stride_h - 1)
            right = (filter_w - 1 - fwd_right) + (stride_w - 1)

            dil_h = grad_dilated.shape[1]
            dil_w = grad_dilated.shape[2]
            padded_out_grad = np.zeros(
                (batch, dil_h + top + bottom, dil_w + left + right, out_channel))
            padded_out_grad[:, top:dil_h + top, left:dil_w + left, :] = grad_dilated

            in_grad_np = np.zeros((batch, in_h, in_w, in_channel))
            for b in range(batch):
                for c in range(in_channel):
                    for m in range(channel_multiplier):
                        grad_plane = padded_out_grad[b, :, :, c * channel_multiplier + m]
                        full = signal.convolve2d(
                            grad_plane, filter_np[:, :, c, m], mode='valid')
                        # crop to the input size before accumulating
                        in_grad_np[b, :, :, c] += full[0:in_h, 0:in_w]
            return (out_grad_np, filter_np, in_grad_np)
def verify_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False):
    """Check an int8 conv2d against the NumPy reference.

    Builds int8 placeholders, computes a reference result with
    ``topi.testing.conv2d_nchw_python`` (optionally adding a bias and a relu)
    and compares it with the device output via ``assert_allclose``.

    NOTE(review): several lines are missing from this copy (flagged inline), so
    the function does not parse as shown.
    """
    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
    # only used for the workload printout and the generated function name
    padding_sum = pad_top + pad_left + pad_bottom + pad_right
    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))

    in_height = in_width = in_size

    A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A', dtype='int8')
    W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W', dtype='int8')
    # NOTE(review): `oc_block_factor` is not defined in this function
    # (presumably a module-level constant), and the placeholder call below is
    # cut off after `name='bias',`.
    bias = tvm.placeholder((num_filter // oc_block_factor, 1, 1, oc_block_factor), name='bias',

    a_shape = get_const_tuple(A.shape)
    w_shape = get_const_tuple(W.shape)
    bias_shape = get_const_tuple(bias.shape)
    dtype = A.dtype

    def get_ref_data():
        # NOTE(review): randint's `high` is exclusive, so a_np never contains
        # 127 while w_np can; confirm whether the asymmetry is intended.
        a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
        w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
        # NOTE(review): dtype is A.dtype ('int8'); uniform samples from [0, 1)
        # cast to int8 all become 0, so this bias is all zeros.
        b_np = np.random.uniform(size=bias_shape).astype(dtype)
        dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
        c_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(dtype)

        # convert to NCHWc
        _, _, out_height, out_width = c_np.shape
        c_np = c_np.reshape((batch, num_filter // oc_block_factor, oc_block_factor, \
                out_height, out_width)).transpose(0, 1, 3, 4, 2)

        if add_bias:
            # b_np is regenerated here, replacing the value drawn above
            b_np = np.random.uniform(size=bias_shape).astype(dtype)
            c_np += b_np
        if add_relu:
            c_np = np.maximum(c_np, 0)

        return a_np, w_np, b_np, c_np

    a_np, w_np, b_np, c_np = get_ref_data()

    def check_device(device):
        ctx = tvm.context(device, 0)
        # NOTE(review): as shown, neither "Skip" branch leaves the function --
        # presumably a `return` was lost after each print; confirm.
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
        if device == "cuda" and not tvm.contrib.nvcc.have_int8(ctx.compute_version):
            print("Skip because int8 intrinsics are not available")

        print("Running on target: %s" % device)
        # NOTE(review): the following lines are indented one level deeper with
        # no visible block header; the enclosing statement appears to be
        # missing.
            C = topi.nn.conv2d(A, W, (stride, stride), padding, (dilation, dilation),
                               layout='NCHW', out_dtype=dtype)
            if add_bias:
                C = topi.add(C, bias)
            if add_relu:
                C = topi.nn.relu(C)
            s = topi.generic.schedule_conv2d_nchw([C])

        a = tvm.nd.array(a_np, ctx)
        w = tvm.nd.array(w_np, ctx)
        b = tvm.nd.array(b_np, ctx)
        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
        # NOTE(review): the build calls below are garbled -- the callee and its
        # first argument are missing, and so is the `else:` branch.
        if add_bias:
  , [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
            func =, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
            func(a, w, b, c)
            func =, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
            func(a, w, c)
        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)

    # NOTE(review): the body of this loop is missing in this copy.
    for device in ["cuda"]:
# Checks an NCHW conv2d (optionally followed by bias-add and relu) against the
# NumPy reference topi.testing.conv2d_nchw_python on every backend returned by
# get_all_backend().
# NOTE(review): the signature below is cut off after the line-continuation
# backslash; the body reads `use_cudnn`, which is not among the visible
# parameters.  Several further lines are missing (flagged inline), so the
# function does not parse as shown.
def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False,\

    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
    # only used for the workload printout and the generated function name
    padding_sum = pad_top + pad_left + pad_bottom + pad_right
    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))

    in_height = in_width = in_size

    A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
    W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
    bias = tvm.placeholder((num_filter, 1, 1), name='bias')

    a_shape = get_const_tuple(A.shape)
    w_shape = get_const_tuple(W.shape)
    bias_shape = get_const_tuple(bias.shape)
    dtype = A.dtype

    def get_ref_data():
        a_np = np.random.uniform(size=a_shape).astype(dtype)
        w_np = np.random.uniform(size=w_shape).astype(dtype)
        b_np = np.random.uniform(size=bias_shape).astype(dtype)
        # dilate the weights so the reference convolution needs no dilation
        # argument
        dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
        c_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding)
        if add_bias:
            c_np += b_np
        if add_relu:
            c_np = np.maximum(c_np, 0)
        return a_np, w_np, b_np, c_np

    a_np, w_np, b_np, c_np = get_ref_data()

    def check_device(device):
        ctx = tvm.context(device, 0)
        # NOTE(review): as shown, nothing leaves the function after the "Skip"
        # message -- presumably a `return` was lost here; confirm.
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
        print("Running on target: %s" % device)
        # NOTE(review): the following lines are indented one level deeper with
        # no visible block header; the enclosing statement appears to be
        # missing.
            C = topi.nn.conv2d(A, W, (stride, stride), padding,
                               (dilation, dilation), layout='NCHW', out_dtype=dtype)
            if add_bias:
                C = topi.add(C, bias)
            if add_relu:
                C = topi.nn.relu(C)
            s = topi.generic.schedule_conv2d_nchw([C])

        a = tvm.nd.array(a_np, ctx)
        w = tvm.nd.array(w_np, ctx)
        b = tvm.nd.array(b_np, ctx)
        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
        # NOTE(review): the build calls below are garbled -- the callee and its
        # first argument are missing, and so is the `else:` branch.
        if add_bias:
            func =, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
            func(a, w, b, c)
            func =, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
            func(a, w, c)
        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-4)

    # NOTE(review): the body of the `with` statement inside this loop is
    # missing in this copy.
    for device in get_all_backend():
        with autotvm.tophub.context(device):  # load tophub pre-tuned parameters

    if use_cudnn:
        check_device("cuda -model=unknown -libs=cudnn")
def _gen_cfg(cfg, data, kernel, strides, padding, dilation, num_tile):
    """Define the tiling knobs on ``cfg`` for a conv2d workload.

    Derives the output height/width from the data, kernel, stride, padding and
    dilation, defines the ``tile_co`` / ``tile_oh`` / ``tile_ow`` splits, and
    returns ``(is_var, vh_, vw_, vc_)`` where the last three are the innermost
    sizes of those splits.

    NOTE(review): several lines are missing from this copy (flagged inline), so
    the function does not parse as shown.
    """
    if len(kernel.shape) == 4:
        co_, _, kh_, kw_ = get_const_tuple(kernel.shape)
    else:  # kernel tensor is pre packed
        co_, _, kh_, kw_, vc_ = get_const_tuple(kernel.shape)
        co_ = co_ * vc_

    # NOTE(review): without an `else:` the tuple unpacking below also runs for
    # an int dilation -- the branch keyword appears to be missing.
    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
        dilation_h, dilation_w = dilation

    n_, ci_, ih_, iw_ = get_const_tuple(data.shape)

    # spatial extent covered by the kernel once dilation is applied
    dilated_kernel_h = (kh_ - 1) * dilation_h + 1
    dilated_kernel_w = (kw_ - 1) * dilation_w + 1
    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w))
    hstr, wstr = strides if isinstance(strides,
                                       (tuple, list)) else (strides, strides)
    oh_ = (ih_ + pad_top + pad_bottom - dilated_kernel_h) // hstr + 1
    ow_ = (iw_ + pad_left + pad_right - dilated_kernel_w) // wstr + 1

    n, co, oh, ow = cfg.axis(n_), cfg.axis(co_), cfg.axis(oh_), cfg.axis(ow_)
    # NOTE(review): the third reduce_axis call below is cut off.
    ci, kh, kw = cfg.reduce_axis(ci_), cfg.reduce_axis(kh_), cfg.reduce_axis(

    if num_tile == 2:  # for arm cpu
        # candidate [outer, inner] factorisations of co_ with inner >= 3, plus
        # the unsplit [1, co_]
        candidate_vc = []
        for iv in range(3, co_):
            if co_ % iv == 0:
                candidate_vc.append([co_ // iv, iv])
        candidate_vc.append([1, co_])
        # NOTE(review): the define_split call for "tile_co" is cut off after
        # its first argument.
        co, vc = cfg.define_split("tile_co",
        oh, vh = cfg.define_split("tile_oh", oh, num_outputs=2)
        ow, vw = cfg.define_split("tile_ow", ow, num_outputs=2)
    elif num_tile == 3:  # for mali gpu
        co, _, vc = cfg.define_split("tile_co", co, num_outputs=3)
        oh, _, vh = cfg.define_split("tile_oh", oh, num_outputs=3)
        ow, _, vw = cfg.define_split("tile_ow", ow, num_outputs=3)
        # NOTE(review): as shown this raise sits inside the num_tile == 3
        # branch; an `else:` header appears to be missing.
        raise RuntimeError("Invalid num_tile")

    # NOTE(review): the two axis lists below are fragments of calls whose
    # heads are missing in this copy.
        [n, co, oh, ow, ci, kh, kw, vh, vw, vc],
            [n, co, oh, ow, ci, kh, kw, vh, vw, vc],

    # innermost factor of each split
    vc_ = cfg["tile_co"].size[-1]
    vh_ = cfg["tile_oh"].size[-1]
    vw_ = cfg["tile_ow"].size[-1]
    is_var = False
    return (is_var, vh_, vw_, vc_)
def depthwise_conv2d_with_workload_NCHWc(batch, in_channel, in_height, channel_multiplier, filter_height, stride, padding, dilation=1):
    """Check ``topi.nn.depthwise_conv2d_NCHWc`` against the NumPy reference.

    Uses square inputs and filters (width = height), picks channel block sizes,
    declares the depthwise convolution with and without a trailing relu, and
    compares both device outputs with
    ``topi.testing.depthwise_conv2d_python_nchw`` via ``assert_allclose``.

    NOTE(review): several lines are missing from this copy (flagged inline), so
    the function does not parse as shown.
    """
    in_width = in_height
    filter_channel = in_channel
    filter_width = filter_height
    stride_h = stride_w = stride

    assert dilation == 1, "depthwise_conv2d_NCHWc currently does not support dilation."
    pad_h, pad_w, _, _ = get_pad_tuple(padding, (filter_height, filter_width))
    # NOTE(review): padding_args is not used anywhere in the visible code.
    padding_args = (pad_h, pad_w)

    out_channel = filter_channel * channel_multiplier
    # for testing functionality,
    # we choose arbitrary block size that can divide the channel,
    # regardless of the performance.
    # NOTE(review): neither loop below has a `break`, so as written each one
    # runs down to bn == 1 and leaves the block size at 1 -- presumably a
    # `break` was lost after each assignment; confirm.
    oc_block = 1
    for bn in range(16, 0, -1):
        if out_channel % bn == 0:
            oc_block = bn

    ic_block = 1
    for bn in range(oc_block, 0, -1):
        if in_channel % bn == 0:
            ic_block = bn

    # placeholder
    Input = tvm.placeholder((batch, in_channel//ic_block, in_height, in_width, ic_block), name='Input')
    Filter = tvm.placeholder((out_channel//oc_block, filter_height, filter_width, oc_block), name='Filter')
    # NOTE(review): in_layout is not used anywhere in the visible code.
    in_layout = "NCHW%dc" % ic_block
    out_layout = "NCHW%dc" % oc_block
    dtype = 'float32'

    def check_device(device):
        ctx = tvm.context(device, 0)
        # NOTE(review): as shown, nothing leaves the function after the "Skip"
        # message -- presumably a `return` was lost here; confirm.
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
        print("Running on target: %s" % device)
        # NOTE(review): the following lines are indented one level deeper with
        # no visible block header; the enclosing statement appears to be
        # missing.
            # declare
            DepthwiseConv2d = topi.nn.depthwise_conv2d_NCHWc(Input, Filter,
                                                             (stride_h, stride_w),
                                                             (dilation, dilation),
                                                             out_layout, dtype)
            # TODO: add scale_shift implement for NCHWc and add test here
            Relu = topi.nn.relu(DepthwiseConv2d)
            # schedule
            s1 = topi.generic.schedule_depthwise_conv2d_nchw(DepthwiseConv2d)
            s2 = topi.generic.schedule_depthwise_conv2d_nchw(Relu)
        # build the kernels
        # NOTE(review): the callee and first argument of both build calls are
        # missing in this copy.
        f1 =, [Input, Filter, DepthwiseConv2d], device)
        f2 =, [Input, Filter, Relu], device)

        # Prepare pod type for test data closure
        input_shape = (batch, in_channel, in_height, in_width)
        filter_shape = (filter_channel, channel_multiplier, filter_height, filter_width)

        # Use memoize, pickle the test data for next time use.
        def get_ref_data():
            input_np = np.random.uniform(size=input_shape).astype(dtype)
            filter_np = np.random.uniform(size=filter_shape).astype(dtype)
            # correctness with scipy
            depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nchw(
                input_np, filter_np, stride, padding)
            relu_scipy = np.maximum(depthwise_conv2d_scipy, 0)
            # repack inputs and references with the helper transforms so they
            # can be fed to / compared with the declared tensors
            return (_transform_data(input_np, ic_block),
                    _transform_kernel(filter_np, oc_block),
                    _transform_data(depthwise_conv2d_scipy, oc_block),
                    _transform_data(relu_scipy, oc_block))

        # Get the test data
        (input_np, filter_np, depthwise_conv2d_scipy, relu_scipy) = get_ref_data()

        input_tvm = tvm.nd.array(input_np, ctx)
        filter_tvm = tvm.nd.array(filter_np, ctx)
        depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape),
                                                     dtype=DepthwiseConv2d.dtype), ctx)
        relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx)
        # launch kernel 1 (depthwise_conv2d)
        f1(input_tvm, filter_tvm, depthwise_conv2d_tvm)
        # launch kernel 2 (depthwise_conv2d + relu)
        f2(input_tvm, filter_tvm, relu_tvm)
        tvm.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5)
        tvm.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)

    # test llvm only for now since depthwise_conv2d_NCHWc implement is missing in other backend.
    # NOTE(review): the body of the `with` statement inside this loop is
    # missing in this copy.
    for device in ["llvm"]:
        with autotvm.tophub.context(device):  # load tophub pre-tuned parameters
def fused_convs(input_data, filters, resnet_block=False):
	"""Chain a sequence of NHWC convolution / depthwise-convolution stages.

	Parameters
	----------
	input_data : tensor
		Input of the first stage; its shape is unpacked as
		(batch, height, width, channel).
	filters : iterable
		One descriptor per stage.  Each descriptor must expose ``placeholder``
		(the filter tensor), ``depthwise``, ``kernel``, ``stride``,
		``padding``, ``dilation`` and ``NHWC_transpose`` (truthy when the
		filter is laid out HWOI instead of HWIO).
	resnet_block : bool
		When True, an element-wise sum of the first and the last node is
		appended after the last stage.

	Returns
	-------
	nodes : list
		``input_data`` followed by the output of every stage (and the
		residual sum when ``resnet_block`` is set).
	params : list
		``input_data``, every filter placeholder in stage order, and the
		final output tensor.
	"""
	out_dtype = input_data.dtype

	nodes = [input_data]
	params = [input_data]

	# The counters live outside the loop so each stage gets a distinct name.
	padded_count = 0
	conv_count = 0
	depthwise_count = 0

	for f in filters:
		Input = nodes[-1]
		Filter = f.placeholder
		depthwise = f.depthwise
		kernel = f.kernel
		stride = f.stride
		padding = f.padding
		dilation = f.dilation

		assert not (depthwise and kernel == 1) # Don't consider 1by1 depthwise

		if isinstance(stride, int):
			stride_h = stride_w = stride
		else:
			stride_h, stride_w = stride

		if isinstance(dilation, int):
			dilation_h = dilation_w = dilation
		else:
			dilation_h, dilation_w = dilation

		batch, in_height, in_width, in_channel = Input.shape
		if f.NHWC_transpose: # HWOI
			kernel_h, kernel_w, tmp, _ = Filter.shape
		else: # HWIO
			kernel_h, kernel_w, _, tmp = Filter.shape
		# The "output" axis of the filter is the channel multiplier for a
		# depthwise stage and the number of filters for a normal one.
		if depthwise:
			channel_multiplier = tmp
		else:
			num_filter = tmp

		# compute the output shape
		dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
		dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
		pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
			padding, (dilated_kernel_h, dilated_kernel_w))

		out_channel = simplify(in_channel * channel_multiplier) if depthwise else num_filter
		out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
		out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)

		if kernel > 1:
			print("Padding is needed!")

			pad_before = [0, pad_top, pad_left, 0]
			pad_after = [0, pad_down, pad_right, 0]

			PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput_{}".format(padded_count))
			padded_count += 1

			# The convolution below reads from the padded tensor.
			Input = PaddedInput
			batch, in_height, in_width, in_channel = Input.shape

		if not depthwise:
			rc = tvm.reduce_axis((0, in_channel), name='rc')
		if kernel > 1:
			ry = tvm.reduce_axis((0, kernel_h), name='ry')
			rx = tvm.reduce_axis((0, kernel_w), name='rx')

		if not depthwise: # Normal convolution
			if kernel > 1:
				Output = tvm.compute(
					(batch, out_height, out_width, out_channel),
					lambda nn, yy, xx, ff: tvm.sum(
						Input[nn, yy * stride_h + ry * dilation_h,
							xx * stride_w + rx * dilation_w, rc].astype(out_dtype) *
						(Filter[ry, rx, ff, rc] if f.NHWC_transpose else Filter[ry, rx, rc, ff]).astype(out_dtype),
						axis=[ry, rx, rc]),
					name="Conv2dOutput_{}".format(conv_count), tag="conv2d_nhwc")
			else: # Only reduce rc axis
				Output = tvm.compute(
					(batch, out_height, out_width, out_channel),
					lambda nn, yy, xx, ff: tvm.sum(
						Input[nn, yy * stride_h, xx * stride_w, rc].astype(out_dtype) *
						(Filter[0, 0, ff, rc] if f.NHWC_transpose else Filter[0, 0, rc, ff]).astype(out_dtype),
						axis=[rc]),
					name="Conv2dOutput_{}".format(conv_count), tag="conv2d_nhwc")
			conv_count += 1
		else: # Depthwise convolution (kernel > 1)
			Output = tvm.compute(
				(batch, out_height, out_width, out_channel),
				lambda b, i, j, c: tvm.sum(
					(Input[b, i*stride_h + ry*dilation_h, j*stride_w + rx*dilation_w,
						tvm.indexdiv(c, channel_multiplier)].astype(out_dtype) *
					(Filter[ry, rx, tvm.indexmod(c, channel_multiplier), tvm.indexdiv(c, channel_multiplier)] if f.NHWC_transpose else Filter[ry, rx, tvm.indexdiv(c, channel_multiplier), tvm.indexmod(c, channel_multiplier)]).astype(out_dtype)),
					axis=[ry, rx]),
				name='DepthwiseConv2dOutput_{}'.format(depthwise_count), tag="depthwise_nhwc")
			depthwise_count += 1

		# Feed this stage's result to the next one and record its filter.
		nodes.append(Output)
		params.append(Filter)

	if resnet_block:
		First = nodes[0]
		Last = nodes[-1]
		# NOTE(review): depending on the TVM version, `==` on shape arrays may
		# compare object identity rather than contents -- confirm this assert
		# can pass for two distinct tensors.
		assert (First.shape == Last.shape)
		# Plain element-wise addition: there is no reduction axis here, so
		# wrapping the expression in tvm.sum would be an error.
		Output = tvm.compute(
			(batch, out_height, out_width, out_channel),
			lambda b, i, j, c: (First[b, i, j, c].astype(out_dtype) +
				Last[b, i, j, c].astype(out_dtype)),
			name='ElementwiseAddOutput_{}'.format(depthwise_count), tag="elem_nhwc")
		nodes.append(Output)

	params.append(nodes[-1]) # Final output
	return nodes, params
def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_multiplier, filter_height, stride, padding, dilation=1):
    """Check NCHW depthwise conv2d (+ scale_shift, + relu) against a NumPy reference.

    The input and the filter are square (``in_width = in_height``,
    ``filter_width = filter_height``) and the same ``stride`` is used along
    both spatial axes.  Three kernels are built and run on every backend
    returned by ``get_all_backend()``: the convolution alone, convolution +
    scale/shift, and convolution + scale/shift + relu.  Each result is
    compared with the reference using ``rtol=1e-5``.

    Parameters
    ----------
    batch, in_channel, in_height, channel_multiplier, filter_height : int
        Workload dimensions.
    stride : int
        Stride along height and width.
    padding : int or str
        Padding specification understood by ``get_pad_tuple``.
    dilation : int
        Filter dilation along height and width.
    """
    in_width = in_height
    filter_channel = in_channel
    filter_width = filter_height
    stride_h = stride_w = stride

    if dilation == 1:
        # here we transform the padding argument from 'str' to  'tuple' ,
        # because we need this to match the "workload" tuple to the records in TopHub
        pad_h, pad_w, _, _ = get_pad_tuple(padding, (filter_height, filter_width))
        padding_args = (pad_h, pad_w)
    else:
        padding_args = padding

    # placeholder
    Input = tvm.placeholder((batch, in_channel, in_height, in_width), name='Input')
    Filter = tvm.placeholder((filter_channel, channel_multiplier, filter_height, filter_width), name='Filter')
    Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
    Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')

    dtype = 'float32'

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            # declare
            DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, Filter,
                (stride_h, stride_w), padding_args, dilation, dtype)
            ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift)
            Relu = topi.nn.relu(ScaleShift)
            # schedule
            s1 = topi.generic.schedule_depthwise_conv2d_nchw(DepthwiseConv2d)
            s2 = topi.generic.schedule_depthwise_conv2d_nchw(ScaleShift)
            s3 = topi.generic.schedule_depthwise_conv2d_nchw(Relu)
        # build the kernels
        f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
        f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device)
        f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], device)

        # Prepare pod type for test data closure
        input_shape = get_const_tuple(Input.shape)
        filter_shape = get_const_tuple(Filter.shape)
        scale_shape = get_const_tuple(Scale.shape)
        shift_shape = get_const_tuple(Shift.shape)
        scale_shift_shape = get_const_tuple(ScaleShift.shape)

        # Random inputs plus the NumPy reference result of every stage.
        def get_ref_data():
            input_np = np.random.uniform(size=input_shape).astype(dtype)
            filter_np = np.random.uniform(size=filter_shape).astype(dtype)
            # The reference implementation has no dilation argument, so it is
            # fed an explicitly dilated filter instead.
            dilated_filter_np = topi.testing.dilate_python(filter_np, (1, 1, dilation, dilation))
            scale_np = np.random.uniform(size=scale_shape).astype(dtype)
            shift_np = np.random.uniform(size=shift_shape).astype(dtype)
            # correctness with scipy
            depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nchw(
                input_np, dilated_filter_np, stride, padding)
            scale_shift_scipy = np.zeros(shape=scale_shift_shape)
            for c in range(in_channel * channel_multiplier):
                scale_shift_scipy[:,c,:,:] = depthwise_conv2d_scipy[:,c,:,:] * scale_np[c] + shift_np[c]
            # relu only depends on the finished scale/shift result, so it is
            # computed once after the per-channel loop.
            relu_scipy = np.maximum(scale_shift_scipy, 0)
            return (input_np, filter_np, scale_np, shift_np,
                    depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy)
        # Get the test data
        (input_np, filter_np, scale_np, shift_np,
         depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy) = get_ref_data()

        input_tvm = tvm.nd.array(input_np, ctx)
        filter_tvm = tvm.nd.array(filter_np, ctx)
        scale_tvm = tvm.nd.array(scale_np, ctx)
        shift_tvm = tvm.nd.array(shift_np, ctx)
        depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape), dtype=DepthwiseConv2d.dtype), ctx)
        scale_shift_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(ScaleShift.shape), dtype=ScaleShift.dtype), ctx)
        relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx)
        # launch kernel 1 (depthwise_conv2d)
        timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=1)
        tcost_1 = timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm).mean
        # launch kernel 2 (depthwise_conv2d + scale_shift)
        timer_2 = f2.time_evaluator(f2.entry_name, ctx, number=1)
        tcost_2 = timer_2(input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm).mean
        # launch kernel 3 (depthwise_conv2d + scale_shift + relu)
        timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=1)
        tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean
        tvm.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5)
        tvm.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5)
        tvm.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)

    for device in get_all_backend():
        with autotvm.tophub.context(device):  # load tophub pre-tuned parameters
            check_device(device)
def deformable_conv2d_nchw_python(a_np, offset_np, w_np, stride, padding, dilation,
                                  deformable_groups, groups):
    """Deformable convolution operator in NCHW layout.

    Parameters
    ----------
    a_np : numpy.ndarray
        4-D with shape [batch, in_channel, in_height, in_width]

    offset_np : numpy.ndarray
        4-D with shape [batch, deformable_groups * filter_height * filter_width * 2,
                        out_height, out_width]

    w_np : numpy.ndarray
        4-D with shape [num_filter, in_channel, filter_height, filter_width]

    stride : int or a list/tuple of two ints
        Stride size, or [stride_height, stride_width]

    padding : int or str or a list/tuple of 2 or 4 ints
        Padding size, or ['VALID', 'SAME'], or
        [pad_height, pad_width] for 2 ints, or
        [pad_top, pad_left, pad_bottom, pad_right] for 4 ints

    dilation : int or a list/tuple of two ints
        Dilation size, or [dilate_height, dilate_width]

    deformable_groups : int
        Number of deformable groups

    groups : int
        Number of groups (only 1 is supported)

    Returns
    -------
    b_np : np.ndarray
        4-D with shape [batch, out_channel, out_height, out_width]
    """
    batch, in_channel, in_height, in_width = a_np.shape
    out_channel, _, kernel_h, kernel_w = w_np.shape
    out_height, out_width = offset_np.shape[-2:]
    dtype = a_np.dtype
    ic_per_dgroup = in_channel // deformable_groups
    assert groups == 1, "deformable_conv2d_nchw_python does not support groups > 1"

    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride

    pad_top, pad_left, _, _ = get_pad_tuple(padding, (kernel_h, kernel_w))

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    def _bilinear(n, c, h, w):
        """Bilinearly interpolate a_np[n, c] at the fractional position (h, w)."""
        low_h, low_w = int(h), int(w)
        # Clamp the upper neighbours so sampling on the last row/column
        # stays inside the input.
        high_h = min(low_h + 1, in_height - 1)
        high_w = min(low_w + 1, in_width - 1)
        y_lerp = h - low_h
        x_lerp = w - low_w

        bottom = (1 - x_lerp) * a_np[n, c, low_h, low_w] + x_lerp * a_np[n, c, low_h, high_w]
        top = (1 - x_lerp) * a_np[n, c, high_h, low_w] + x_lerp * a_np[n, c, high_h, high_w]
        return (1 - y_lerp) * bottom + y_lerp * top

    # a_deform[n, c, h, w, kh, kw] is the input value sampled for kernel tap
    # (kh, kw) of output pixel (h, w); taps falling outside the input stay 0.
    a_deform = np.zeros((batch, in_channel, out_height, out_width, kernel_h, kernel_w), dtype=dtype)
    for n, h, w in itertools.product(range(batch), range(out_height), range(out_width)):
        offset = offset_np[n, :, h, w].reshape(deformable_groups, kernel_h, kernel_w, 2)
        in_h = h * stride_h - pad_top
        in_w = w * stride_w - pad_left

        # 'ij' indexing keeps the first axis as the kernel row (kh) and the
        # second as the kernel column (kw), matching the offset layout.
        index_h_base, index_w_base = np.meshgrid(
            np.arange(in_h, in_h + kernel_h * dilation_h, dilation_h, dtype=offset_np.dtype),
            np.arange(in_w, in_w + kernel_w * dilation_w, dilation_w, dtype=offset_np.dtype),
            indexing='ij')

        for c in range(in_channel):
            # The sampling grid only depends on the channel's deformable
            # group, so it is built once per channel, not once per tap.
            dg = c // ic_per_dgroup
            index_h = index_h_base + offset[dg, ..., 0]
            index_w = index_w_base + offset[dg, ..., 1]

            for kh, kw in itertools.product(range(kernel_h), range(kernel_w)):
                y, x = index_h[kh, kw], index_w[kh, kw]
                if y < 0 or y >= in_height or x < 0 or x >= in_width:
                    continue
                a_deform[n, c, h, w, kh, kw] = _bilinear(n, c, y, x)

    b_np = np.zeros((batch, out_channel, out_height, out_width), dtype=dtype)
    for n, c, f, h, w in itertools.product(range(batch), range(in_channel), range(out_channel),
                                           range(out_height), range(out_width)):
        b_np[n, f, h, w] += np.tensordot(a_deform[n, c, h, w], w_np[f, c])

    return b_np
# Builds a depthwise conv2d compute in which data and kernel are first packed
# into tiles (sizes u_vh, u_vw, u_vc taken from `args`) and the tiled result
# is unpacked into a 4-D tensor at the end.
# NOTE(review): this copy of the function is visibly truncated -- the
# parameter list stops after `dilation,`, several `else:` lines are missing
# and most pad()/tvm.compute()/get_const_tuple() calls have lost their
# trailing arguments.  Restore it from the original file before relying on it.
def _depthwise_spatial_pack(args, data, kernel, strides, padding, dilation,
    """depthwise_conv2d_arm_cpu's inner implement"""
    # args carries a flag plus the three tile sizes used for packing below.
    is_var, u_vh, u_vw, u_vc = args
    # NOTE(review): out_dtype is read here, so it is presumably one of the
    # parameters missing from the signature -- confirm.
    out_dtype = out_dtype or data.dtype

    # When is_var is set the shapes are used as-is, otherwise they are
    # converted with get_const_tuple.
    u_n, u_c, ih, iw = data.shape if is_var else get_const_tuple(data.shape)

    # NOTE(review): an `else:` appears to be missing before the tuple unpack.
    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
        dilation_h, dilation_w = dilation

    # A 4-D kernel is still unpacked; any other rank is treated as a kernel
    # that already carries the trailing u_vc axis.
    if len(kernel.shape) == 4:
        pre_packed = False
        u_c, um, ukh, ukw = kernel.shape if is_var else get_const_tuple(
    else:  # kernel tensor is pre packed
        pre_packed = True
        u_c, um, ukh, ukw, u_vc = kernel.shape if is_var else get_const_tuple(
        u_c = u_c * u_vc

    # Extent covered by the kernel once dilation is applied.
    dilated_kernel_h = (ukh - 1) * dilation_h + 1
    dilated_kernel_w = (ukw - 1) * dilation_w + 1

    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w))
    hstr, wstr = strides if isinstance(strides,
                                       (tuple, list)) else (strides, strides)
    u_oh = (ih + pad_top + pad_down - dilated_kernel_h) // hstr + 1
    u_ow = (iw + pad_left + pad_right - dilated_kernel_w) // wstr + 1
    # pack data
    hpad = pad_top + pad_down
    wpad = pad_left + pad_right
    dopad = hpad != 0 or wpad != 0
    # NOTE(review): the pad() call has lost its first argument and closing
    # part, and the `else:` before `data_pad = data` is missing.
    if dopad:
        data_pad = pad(
            (0, 0, pad_top, pad_left),
            (0, 0, pad_down, pad_right),
        data_pad = data

    # Number of whole tiles along the output height / width.
    oh_div = u_oh // u_vh
    ow_div = u_ow // u_vw
    kvshape = (u_c // u_vc, um, ukh, ukw, u_vc)
    ovshape = (u_n, u_c * um // u_vc, oh_div, u_ow // u_vw, u_vh, u_vw, u_vc)
    oshape = (u_n, u_c * um, oh_div * u_vh, ow_div * u_vw)

    # With dilation the packed data gets explicit (kh, kw) axes; without it a
    # contiguous window of size (u_vh * hstr + ukh - 1, u_vw * wstr + ukw - 1)
    # is stored per tile.
    # NOTE(review): the `else:` between the two variants and the shape/name
    # arguments of both tvm.compute calls are missing.
    if dilation_h != 1 or dilation_w != 1:
        # undilate input data
        dvshape = (u_n, oh_div, ow_div, u_c, ukh, ukw, u_vh, u_vw)
        data_vec = tvm.compute(
            lambda n, h, w, c, kh, kw, vh, vw: data_pad[n][c][
                (h * u_vh + vh) * hstr + kh * dilation_h][
                    (w * u_vw + vw) * wstr + kw * dilation_w],
        dvshape = (u_n, oh_div, ow_div, u_c, u_vh * hstr + ukh - 1,
                   u_vw * wstr + ukw - 1)
        data_vec = tvm.compute(
            lambda n, h, w, c, vh, vw: data_pad[n][c][h * u_vh * hstr + vh][
                w * u_vw * wstr + vw],

    # A pre-packed kernel is used directly, otherwise it is re-laid-out so
    # that the channel index splits into (co, vc).
    # NOTE(review): `else:` and the tvm.compute shape/name arguments are missing.
    if pre_packed:
        kernel_vec = kernel
        kernel_vec = tvm.compute(
            lambda co, m, kh, kw, vc: kernel[co * u_vc + vc][m][kh][kw],

    kh = tvm.reduce_axis((0, ukh), name="kh")
    kw = tvm.reduce_axis((0, ukw), name="kw")

    # Reduction over the kernel window in the packed layout; the two branches
    # mirror the two data_vec layouts above.
    # NOTE(review): again `else:` and the remaining call arguments are missing.
    if dilation_h != 1 or dilation_w != 1:
        conv = tvm.compute(
            lambda n, co, h, w, vh, vw, vc: tvm.sum(
                data_vec[n, h, w, (co * u_vc + vc) // um, kh, kw, vh, vw].
                astype(out_dtype) * kernel_vec[co // um, co % um, kh, kw, vc
                axis=[kh, kw],
        conv = tvm.compute(
            lambda n, co, h, w, vh, vw, vc: tvm.sum(
                data_vec[n, h, w, (co * u_vc + vc) // um, vh * hstr + kh, vw *
                         wstr + kw].astype(out_dtype) * kernel_vec[
                             co // um, co % um, kh, kw, vc].astype(out_dtype),
                axis=[kh, kw],

    # Map the tiled result back to (n, co, h, w) indexing.
    output = tvm.compute(
        lambda n, co, h, w: conv[n][co // u_vc][h // u_vh][w // u_vw][h % u_vh]
        [w % u_vw][co % u_vc],
    return output
# Runs topi.arm_cpu.compute_conv2d_NHWC_quantized (optionally followed by a
# bias add and relu) and compares the result with
# topi.testing.conv2d_nhwc_python using rtol=1e-5.
# NOTE(review): this copy is truncated.  The parameter list stops after
# `batch,` although the body also reads in_channel, in_size, num_filter,
# kernel, stride, padding, dilation, add_bias and add_relu; several calls
# have lost their trailing arguments; and nothing visible calls check_device.
def verify_conv2d_NHWC_gemm_int8(batch,
    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(
        padding, (kernel, kernel))
    # padding_sum is only used for the workload printout and kernel names.
    padding_sum = pad_top + pad_left + pad_bottom + pad_right
    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" %
          (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum,

    in_height = in_width = in_size

    # Data is NHWC; the weight placeholder is (kernel, kernel, in_channel, num_filter).
    A = te.placeholder((batch, in_height, in_width, in_channel),
    W = te.placeholder((kernel, kernel, in_channel, num_filter),
    bias = te.placeholder((num_filter, ), name='bias', dtype='int8')

    a_shape = get_const_tuple(A.shape)
    w_shape = get_const_tuple(W.shape)
    bias_shape = get_const_tuple(bias.shape)
    dtype = A.dtype

    def get_ref_data():
        # NOTE(review): numpy.random.randint excludes `high`, so with the
        # bounds below a_np never contains 127 while w_np can -- confirm the
        # asymmetry is intended.
        a_np = np.random.randint(low=-128, high=127,
        w_np = np.random.randint(low=-128, high=128,
        b_np = np.random.uniform(size=bias_shape).astype(dtype)
        # The weights are dilated explicitly along their two spatial axes
        # before being handed to the reference convolution.
        dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1))
        c_np = topi.testing.conv2d_nhwc_python(a_np, dw_np, stride,

        if add_bias:
            # b_np is drawn a second time here, replacing the value above.
            b_np = np.random.uniform(size=bias_shape).astype(dtype)
            c_np += b_np
        if add_relu:
            c_np = np.maximum(c_np, 0)

        return a_np, w_np, b_np, c_np

    a_np, w_np, b_np, c_np = get_ref_data()

    def check_device(device):
        ctx = tvm.context(device, 0)
        # NOTE(review): a `return` after the skip message and the enclosing
        # target context for the indented block below appear to be missing.
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
        print("Running on target: %s" % device)
            C = topi.arm_cpu.compute_conv2d_NHWC_quantized(
                A, W, (stride, stride), padding, (dilation, dilation), dtype)
            if add_bias:
                C = topi.add(C, bias)
            if add_relu:
                C = topi.nn.relu(C)
            s = topi.arm_cpu.schedule_conv2d_NHWC_quantized([C])

        a = tvm.nd.array(a_np, ctx)
        w = tvm.nd.array(w_np, ctx)
        b = tvm.nd.array(b_np, ctx)
        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype),
        # With a bias the kernel takes (a, w, b, c), otherwise (a, w, c).
        # NOTE(review): the callee names in front of `, [A, W, ...]` and the
        # `else:` between the two variants are missing.
        if add_bias:
  , [A, W, bias, C],
                      name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
                      (batch, in_channel, in_size, num_filter, kernel, stride,
                       padding_sum, dilation))
            func =, [A, W, bias, C],
                             name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
                             (batch, in_channel, in_size, num_filter, kernel,
                              stride, padding_sum, dilation))
            func(a, w, b, c)
            func =, [A, W, C],
                             name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
                             (batch, in_channel, in_size, num_filter, kernel,
                              stride, padding_sum, dilation))
            func(a, w, c)
        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)

# Runs the compute/schedule pair selected by topi.testing.dispatch from
# _conv2d_nchw_winograd_implement (optionally followed by a bias add and
# relu) and compares the result with topi.testing.conv2d_nchw_python using
# rtol=1e-3.
# NOTE(review): this copy is truncated.  Only the `devices` parameter of the
# signature survives although the body also reads batch, in_channel, in_size,
# num_filter, kernel, stride, padding, dilation, add_bias and add_relu;
# several calls have lost their trailing arguments; and the final device
# loop has no body.
def verify_conv2d_nchw(
        devices=['cuda', 'llvm -device=arm_cpu', 'opencl -device=mali']):
    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(
        padding, (kernel, kernel))
    # padding_sum is only used for the workload printout and kernel names.
    padding_sum = pad_top + pad_left + pad_bottom + pad_right
    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" %
          (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum,

    in_height = in_width = in_size

    A = te.placeholder((batch, in_channel, in_height, in_width), name='A')
    W = te.placeholder((num_filter, in_channel, kernel, kernel), name='W')
    # Bias is (num_filter, 1, 1) so it broadcasts over the spatial axes.
    bias = te.placeholder((num_filter, 1, 1), name='bias')

    a_shape = get_const_tuple(A.shape)
    w_shape = get_const_tuple(W.shape)
    bias_shape = get_const_tuple(bias.shape)
    dtype = A.dtype

    def get_ref_data():
        a_np = np.random.uniform(size=a_shape).astype(dtype)
        w_np = np.random.uniform(size=w_shape).astype(dtype)
        b_np = np.random.uniform(size=bias_shape).astype(dtype)
        # The weights are dilated explicitly along their two spatial axes
        # before being handed to the reference convolution.
        dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
        c_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding)
        if add_bias:
            # b_np is drawn a second time here, replacing the value above.
            b_np = np.random.uniform(size=bias_shape).astype(dtype)
            c_np += b_np
        if add_relu:
            c_np = np.maximum(c_np, 0)
        return a_np, w_np, b_np, c_np

    a_np, w_np, b_np, c_np = get_ref_data()

    def check_device(device):
        ctx = tvm.context(device, 0)
        # NOTE(review): a `return` after the skip message and the enclosing
        # target context for the indented block below appear to be missing.
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
        print("Running on target: %s" % device)
            fcompute, fschedule = topi.testing.dispatch(
                device, _conv2d_nchw_winograd_implement)
            C = fcompute(A, W, stride, padding, dilation, dtype)
            if add_bias:
                C = topi.add(C, bias)
            if add_relu:
                C = topi.nn.relu(C)
            s = fschedule([C])

        a = tvm.nd.array(a_np, ctx)
        w = tvm.nd.array(w_np, ctx)
        b = tvm.nd.array(b_np, ctx)
        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype),
        # With a bias the kernel takes (a, w, b, c), otherwise (a, w, c).
        # NOTE(review): the callee names in front of `, [A, W, ...]` and the
        # `else:` between the two variants are missing.
        if add_bias:
            func =, [A, W, bias, C],
                             name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
                             (batch, in_channel, in_size, num_filter, kernel,
                              stride, padding_sum, dilation))
            func(a, w, b, c)
            func =, [A, W, C],
                             name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
                             (batch, in_channel, in_size, num_filter, kernel,
                              stride, padding_sum, dilation))
            func(a, w, c)

        rtol = 1e-3
        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=rtol)

    # NOTE(review): the loop body (presumably a check_device call) is missing.
    for device in devices:
def conv2d_direct_simd_compute(cfg, data, kernel, strides, padding, dilation, out_dtype):
    """Compute function for Cortex-M7 SIMD implementation of conv2d.

    Parameters
    ----------
    cfg : config space object
        Receives the split, reorder and unroll knobs defined below.
    data : tensor
        Input, unpacked as (batch, in_height, in_width, in_channels).
    kernel : tensor
        Weights, indexed as [kernel_h, kernel_w, out_channel, in_channel].
    strides : int or sequence of two ints
        Stride, or (stride_h, stride_w).
    padding : int, str or sequence
        Padding specification understood by ``get_pad_tuple``.
    dilation : int or sequence of two ints
        Dilation, or (dilation_h, dilation_w).
    out_dtype : str
        Type both operands are cast to before multiplication.

    Returns
    -------
    conv : tensor
        Convolution result with shape
        (batch, out_height, out_width, out_channels).
    """
    assert isinstance(strides, int) or len(strides) == 2
    assert isinstance(dilation, int) or len(dilation) == 2

    if isinstance(strides, int):
        stride_h = stride_w = strides
    else:
        stride_h, stride_w = strides

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    batch_size, in_height, in_width, in_channels = data.shape
    kernel_h, kernel_w, out_channels, _ = kernel.shape

    # compute the output shape
    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w))
    out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
    out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)

    pad_before = [0, pad_top, pad_left, 0]
    pad_after = [0, pad_down, pad_right, 0]
    padded_data = pad(data, pad_before, pad_after, name='padded_data')

    rc = te.reduce_axis((0, in_channels), name='rc')
    ry = te.reduce_axis((0, kernel_h), name='ry')
    rx = te.reduce_axis((0, kernel_w), name='rx')

    conv = te.compute(
        (batch_size, out_height, out_width, out_channels),
        lambda nn, yy, xx, ff: te.sum(
            padded_data[nn, yy * stride_h + ry * dilation_h,
                        xx * stride_w + rx * dilation_w, rc].astype(out_dtype) *
            kernel[ry, rx, ff, rc].astype(out_dtype), axis=[ry, rx, rc]),
        name='conv2d', tag='conv2d_nhwc')

    # Config Space Definition #
    # One config axis per output dimension and per reduction dimension; the
    # `.value` accesses require all of these extents to be constants.
    n, oh, ow, co = (cfg.axis(batch_size.value),
                     cfg.axis(out_height.value),
                     cfg.axis(out_width.value),
                     cfg.axis(out_channels.value))
    kh, kw, ci = (cfg.reduce_axis(kernel_h.value),
                  cfg.reduce_axis(kernel_w.value),
                  cfg.reduce_axis(in_channels.value))

    assert in_channels.value % 4 == 0
    owo, owi = cfg.define_split('tile_ow', ow, policy='factors', num_outputs=2)
    # Only keep input-channel splits whose inner factor is a multiple of 4.
    cio, cii = cfg.define_split('tile_ci', ci, policy='factors', num_outputs=2,
                                filter=lambda x: x.size[-1] % 4 == 0)
    coo, coi = cfg.define_split('tile_co', co, policy='factors', num_outputs=2)

    # NOTE(review): the knob name must match the one the schedule applies --
    # confirm 'reorder_0_simd' against the schedule function.
    cfg.define_reorder('reorder_0_simd',
                       [n, oh, owo, owi, coo, coi, kh, kw, cio, cii],
                       policy='candidate', candidate=[
                           [n, oh, kh, kw, owo, coo, cio, owi, coi, cii],
                           [n, oh, kh, kw, coo, owo, cio, owi, coi, cii],
                           [n, kh, kw, oh, owo, coo, cio, owi, coi, cii],
                           [n, kh, kw, oh, coo, owo, cio, owi, coi, cii]])

    cfg.define_knob('auto_unroll_max_step', [0, 2, 4, 8, 16, 32])
    cfg.define_knob('unroll_explicit', [0, 1])

    return conv
def depthwise_conv2d_with_workload_NCHWc(batch, in_channel, in_height, channel_multiplier, filter_height, stride, padding, dilation=1):
    """Check depthwise_conv2d_NCHWc (and + relu) against a NumPy reference.

    The input and the filter are square (``in_width = in_height``,
    ``filter_width = filter_height``) and the same ``stride`` is used along
    both spatial axes.  The reference is computed in plain NCHW and then
    converted to the blocked layout with ``_transform_data`` /
    ``_transform_kernel``.  Results are compared using ``rtol=1e-5``; only
    the "llvm" backend is exercised.

    Parameters
    ----------
    batch, in_channel, in_height, channel_multiplier, filter_height : int
        Workload dimensions.
    stride : int
        Stride along height and width.
    padding : int or str
        Padding specification understood by ``get_pad_tuple``.
    dilation : int
        Must be 1; dilation is not supported by this operator.
    """
    in_width = in_height
    filter_channel = in_channel
    filter_width = filter_height
    stride_h = stride_w = stride

    assert dilation == 1, "depthwise_conv2d_NCHWc currently does not support dilation."
    pad_h, pad_w, _, _ = get_pad_tuple(padding, (filter_height, filter_width))
    padding_args = (pad_h, pad_w)

    out_channel = filter_channel * channel_multiplier
    # for testing functionality,
    # we choose arbitrary block size that can divide the channel,
    # regardless of the performance.
    # Largest divisor not exceeding 16; the search always succeeds because 1
    # divides everything.
    oc_block = next(bn for bn in range(16, 0, -1) if out_channel % bn == 0)
    # Largest divisor of in_channel that does not exceed oc_block.
    ic_block = next(bn for bn in range(oc_block, 0, -1) if in_channel % bn == 0)

    # placeholder
    Input = tvm.placeholder((batch, in_channel//ic_block, in_height, in_width, ic_block), name='Input')
    Filter = tvm.placeholder((out_channel//oc_block, 1, filter_height, filter_width, 1, oc_block), name='Filter')
    in_layout = "NCHW%dc" % ic_block
    out_layout = "NCHW%dc" % oc_block
    dtype = 'float32'

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            # declare
            DepthwiseConv2d = topi.nn.depthwise_conv2d_NCHWc(Input, Filter,
                                                             (stride_h, stride_w),
                                                             padding_args,
                                                             (dilation, dilation),
                                                             in_layout,
                                                             out_layout, dtype)
            # TODO: add scale_shift implement for NCHWc and add test here
            Relu = topi.nn.relu(DepthwiseConv2d)
            # schedule
            s1 = topi.generic.schedule_depthwise_conv2d_nchw(DepthwiseConv2d)
            s2 = topi.generic.schedule_depthwise_conv2d_nchw(Relu)
        # build the kernels
        f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
        f2 = tvm.build(s2, [Input, Filter, Relu], device)

        # Prepare pod type for test data closure
        input_shape = (batch, in_channel, in_height, in_width)
        filter_shape = (filter_channel, channel_multiplier, filter_height, filter_width)

        # Random NCHW inputs and reference results, converted to the blocked
        # layout the kernels operate on.
        def get_ref_data():
            input_np = np.random.uniform(size=input_shape).astype(dtype)
            filter_np = np.random.uniform(size=filter_shape).astype(dtype)
            # correctness with scipy
            depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nchw(
                input_np, filter_np, stride, padding)
            relu_scipy = np.maximum(depthwise_conv2d_scipy, 0)
            return (_transform_data(input_np, ic_block),
                    _transform_kernel(filter_np, oc_block),
                    _transform_data(depthwise_conv2d_scipy, oc_block),
                    _transform_data(relu_scipy, oc_block))

        # Get the test data
        (input_np, filter_np, depthwise_conv2d_scipy, relu_scipy) = get_ref_data()

        input_tvm = tvm.nd.array(input_np, ctx)
        filter_tvm = tvm.nd.array(filter_np, ctx)

        depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape),
                                                     dtype=DepthwiseConv2d.dtype), ctx)
        relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx)
        # launch kernel 1 (depthwise_conv2d)
        f1(input_tvm, filter_tvm, depthwise_conv2d_tvm)
        # launch kernel 2 (depthwise_conv2d + relu)
        f2(input_tvm, filter_tvm, relu_tvm)
        tvm.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5)
        tvm.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)

    # test llvm only for now since depthwise_conv2d_NCHWc implement is missing in other backend.
    for device in ["llvm"]:
        with autotvm.tophub.context(device):  # load tophub pre-tuned parameters
            check_device(device)
def _depthwise_conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None):
    """Depthwise convolution nchw forward operator.

    Parameters
    ----------
    Input : tvm.Tensor
        4-D with shape [batch, in_channel, in_height, in_width]

    Filter : tvm.Tensor
        4-D with shape [in_channel, channel_multiplier, filter_height, filter_width]

    stride : int or tuple of two ints
        The spatial stride along height and width

    padding : int or str
        Padding size, or ['VALID', 'SAME']

    dilation: int or a list/tuple of two ints
        dilation size, or [dilation_height, dilation_width]

    out_dtype: str, optional
        Output data type; defaults to the dtype of ``Input``

    Returns
    -------
    Output : tvm.Tensor
        4-D with shape [batch, out_channel, out_height, out_width]
    """
    out_dtype = Input.dtype if out_dtype is None else out_dtype

    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    batch, in_channel, in_height, in_width = Input.shape
    _, channel_multiplier, filter_height, filter_width = Filter.shape

    # shape of dilated kernel
    dilated_kernel_h = (filter_height - 1) * dilation_h + 1
    dilated_kernel_w = (filter_width - 1) * dilation_w + 1
    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w))
    out_channel = simplify(in_channel * channel_multiplier)
    out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
    out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)

    # padding stage
    pad_before = [0, 0, pad_top, pad_left]
    pad_after = [0, 0, pad_down, pad_right]
    PaddedInput = topi.nn.pad(Input, pad_before, pad_after, name="PaddedInput")
    # depthconv stage
    # Output channel c reads input channel c // channel_multiplier and filter
    # slot c % channel_multiplier.  Explicit index division is used because
    # `/` is ambiguous for integer TVM expressions.
    idxdiv = tvm.tir.indexdiv
    idxmod = tvm.tir.indexmod
    di = tvm.te.reduce_axis((0, filter_height), name='di')
    dj = tvm.te.reduce_axis((0, filter_width), name='dj')
    Output = tvm.te.compute(
        (batch, out_channel, out_height, out_width),
        lambda b, c, i, j: tvm.te.sum(
            (PaddedInput[b, idxdiv(c, channel_multiplier), i*stride_h+di*dilation_h,
                         j*stride_w+dj*dilation_w].astype(out_dtype) *
             Filter[idxdiv(c, channel_multiplier),
                    idxmod(c, channel_multiplier), di, dj].astype(out_dtype)),
            axis=[di, dj]),
        name='DepthwiseConv2d', tag="depthwise_conv2d_nchw")
    return Output
# Declare the tensor-expression stages of a Winograd-style conv2d:
# B_T_dot_X and V (transforms of `input_tile`), M (the "batch gemm" stage),
# A_T_dot_M and Y (transforms of M), and `output`, which unpacks Y into an
# (N, CO, H, W) tensor.  Only NCHW layout with a 3x3 kernel, stride 1 and
# padding 1 passes the asserts below.
# Returns the tuple (Y, input_tile, U, output).
# NOTE(review): this definition looks truncated in several places (the
# parameter list, unclosed calls, missing `else:` headers, and the
# `now = == ii, nu == jj),` statements).  The names data, kernel, strides,
# padding, layout, out_dtype, VK and VP are used below but are not visible
# in the header - restore them from the upstream source before use.
def decl_winograd(cfg,
    # return _baseline_winograd(cfg, data, kernel, strides, padding, layout, out_dtype)
    N, CI, IH, IW = get_const_tuple(data.shape)
    CO, _, KH, KW = get_const_tuple(kernel.shape)
    # Accept either a (h, w) pair or a single scalar stride.
    HSTR, WSTR = strides if isinstance(strides,
                                       (tuple, list)) else (strides, strides)
    # NOTE(review): the kernel tensor itself is passed as the second argument;
    # other call sites in this file pass a (kernel_h, kernel_w) tuple - confirm.
    HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)

    assert layout == 'NCHW'
    assert KH == 3 and KW == 3 and HPAD == 1 and WPAD == 1 and HSTR == 1 and WSTR == 1
    # NOTE(review): data_pad is only referenced by the commented-out
    # input_tile computation further down.
    data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")

    # Constant matrices A_data, G_data and B_data.
    # NOTE(review): each np.array(...) call below is missing its closing
    # arguments/parenthesis (presumably a dtype and possibly a transpose).
    A_data = np.array(
        [[1, 1, 1, 1, 1, 32, 32, 0], [0, 1, -1, 2, -2, 16, -16, 0],
         [0, 1, 1, 4, 4, 8, 8, 0], [0, 1, -1, 8, -8, 4, -4, 0],
         [0, 1, 1, 16, 16, 2, 2, 0], [0, 1, -1, 32, -32, 1, -1, 1]],
    G_data = np.array(
        [[1, 0, 0], [-2 / 9, -2 / 9, -2 / 9], [-2 / 9, 2 / 9, -2 / 9],
         [1 / 90, 1 / 45, 2 / 45], [1 / 90, -1 / 45, 2 / 45],
         [1 / 45, 1 / 90, 1 / 180], [1 / 45, -1 / 90, 1 / 180], [0, 0, 1]],
    B_data = np.array([[1, 0, -21 / 4, 0, 21 / 4, 0, -1, 0],
                       [0, 1, 1, -17 / 4, -17 / 4, 1, 1, 0],
                       [0, -1, 1, 17 / 4, -17 / 4, -1, 1, 0],
                       [0, 1 / 2, 1 / 4, -5 / 2, -5 / 4, 2, 1, 0],
                       [0, -1 / 2, 1 / 4, 5 / 2, -5 / 4, -2, 1, 0],
                       [0, 2, 4, -5 / 2, -5, 1 / 2, 1, 0],
                       [0, -2, 4, 5 / 2, -5, -1 / 2, 1, 0],
                       [0, -1, 0, 21 / 4, 0, -21 / 4, 0, 1]],

    # m comes from A_data, r is the kernel size, alpha = m + r - 1.
    # NOTE(review): the hard-coded indices 0..7 (and temp_expr rows 0..5 in the
    # A-side transforms) used below presumably assume alpha == 8 and m == 6.
    m = A_data.shape[1]
    r = 3
    alpha = m + r - 1

    C = CI

    # Output spatial size and the number of m x m tiles along each axis
    # (ceil division).
    H = (IH + 2 * HPAD - 3) // HSTR + 1
    W = (IW + 2 * WPAD - 3) // WSTR + 1
    nH, nW = (H + m - 1) // m, (W + m - 1) // m

    def round_up(a, b):
        # Smallest multiple of b that is >= a.
        return ((a + b - 1) // b) * b

    # K and P are CO and N*nH*nW rounded up to multiples of VK and VP.
    K = round_up(CO, VK)
    P = round_up(N * nH * nW, VP)

    assert K % VK == 0
    assert P % VP == 0

    # NOTE(review): G, r_kh, r_kw and kernel_pad are only referenced by the
    # commented-out computation of U below.
    G = const_matrix(G_data, 'G')
    r_kh = tvm.reduce_axis((0, KH), 'r_kh')
    r_kw = tvm.reduce_axis((0, KW), 'r_kw')
    assert K >= CO
    # NOTE(review): the pad(...) call is unclosed and an `else:` header appears
    # to be missing before the second kernel_pad assignment.
    if K > CO:
        kernel_pad = pad(kernel, (0, 0, 0, 0), (K - CO, 0, 0, 0),
        kernel_pad = kernel
    # input_tile and U are declared as placeholders here rather than computed.
    # NOTE(review): both placeholder(...) calls are missing their trailing
    # arguments and closing parenthesis.
    input_tile = tvm.placeholder(shape=(P // VP, C, alpha, alpha, VP),
    U = tvm.placeholder(shape=(K // VK, alpha, alpha, C, VK),

    #U = tvm.compute(
    #    (K // VK, alpha, alpha, C, VK), lambda k, eps, nu, c, kk:
    #    tvm.sum(kernel_pad[k * VK + kk][c][r_kh][r_kw].astype(out_dtype) *
    #            G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]), name='U')

    ## pack input tile
    #input_tile = tvm.compute((P // VP, C, alpha, alpha, VP),
    #                         lambda b, c, eps, nu, bb:
    #                         data_pad[(b*VP+bb) // (nH*nW)][c][(b*VP+bb) // nW % nH * m + eps]
    #                         [(b*VP+bb) % nW * m + nu],
    #                         name='d')

    def compute_B_T_dot_X(b, c, eps, nu, bb):
        # For every column j, combine the eight rows of input_tile[b][c][:, j][bb]
        # into eight expressions stored in temp_expr[(row, j)].
        temp_expr = {}
        for j in range(alpha):
            wd0 = input_tile[b][c][0][j][bb] - input_tile[b][c][6][j][bb]
            d4_sub_d2 = input_tile[b][c][4][j][bb] - input_tile[b][c][2][j][bb]
            wd7 = input_tile[b][c][7][j][bb] - input_tile[b][c][1][j][bb]
            d3_sub_d5 = input_tile[b][c][3][j][bb] - input_tile[b][c][5][j][bb]
            wd1 = input_tile[b][c][2][j][bb] + input_tile[b][c][6][j][bb]
            wd2 = input_tile[b][c][1][j][bb] + input_tile[b][c][5][j][bb]
            wd4 = input_tile[b][c][5][j][bb] + input_tile[b][c][1][j][bb] * 0.25
            wd5 = input_tile[b][c][6][j][bb] - input_tile[b][c][4][j][bb] * 5
            wd3 = input_tile[b][c][6][j][bb] + input_tile[b][c][2][j][bb] * 0.25
            wd6 = input_tile[b][c][1][j][bb] + input_tile[b][c][5][j][bb] * 0.25

            wd0 = wd0 + d4_sub_d2 * 5.25
            wd7 = wd7 + d3_sub_d5 * 5.25

            wd1 = wd1 - input_tile[b][c][4][j][bb] * 4.25
            wd2 = wd2 - input_tile[b][c][3][j][bb] * 4.25

            wd3 = wd3 - input_tile[b][c][4][j][bb] * 1.25
            wd5 = wd5 + input_tile[b][c][2][j][bb] * 4
            wd4 = wd4 - input_tile[b][c][3][j][bb] * 1.25
            wd6 = wd6 - input_tile[b][c][3][j][bb] * 1.25

            temp_expr[(0, j)] = wd0
            temp_expr[(1, j)] = wd1 + wd2
            temp_expr[(2, j)] = wd1 - wd2
            temp_expr[(3, j)] = wd3 + wd4 * 2
            temp_expr[(4, j)] = wd3 - wd4 * 2
            temp_expr[(5, j)] = wd5 + wd6 * 2
            temp_expr[(6, j)] = wd5 - wd6 * 2
            temp_expr[(7, j)] = wd7

        # Fold temp_expr into one expression keyed on (eps, nu), starting from
        # a float32 zero constant.
        # NOTE(review): the call that should precede `== ii, nu == jj)` is
        # missing (presumably a select on eps == ii and nu == jj); the
        # statement is not valid Python as written.
        now = tvm.const(0.0, "float32")
        for ii in range(alpha):
            for jj in range(alpha):
                now = == ii, nu == jj),
                                 temp_expr[(ii, jj)], now)
        return now

    # NOTE(review): call truncated - the compute function and name arguments
    # are missing.
    B_T_dot_X = tvm.compute((P // VP, C, alpha, alpha, VP),

    def compute_X_dot_B(b, eps, nu, c, bb):
        # Same combination as compute_B_T_dot_X, applied along the last tile
        # axis of B_T_dot_X for every row i.
        temp_expr = {}

        for i in range(alpha):
            wd0 = B_T_dot_X[b][c][i][0][bb] - B_T_dot_X[b][c][i][6][bb]
            d4_sub_d2 = B_T_dot_X[b][c][i][4][bb] - B_T_dot_X[b][c][i][2][bb]
            wd7 = B_T_dot_X[b][c][i][7][bb] - B_T_dot_X[b][c][i][1][bb]
            d3_sub_d5 = B_T_dot_X[b][c][i][3][bb] - B_T_dot_X[b][c][i][5][bb]
            wd1 = B_T_dot_X[b][c][i][2][bb] + B_T_dot_X[b][c][i][6][bb]
            wd2 = B_T_dot_X[b][c][i][1][bb] + B_T_dot_X[b][c][i][5][bb]
            wd4 = B_T_dot_X[b][c][i][5][bb] + B_T_dot_X[b][c][i][1][bb] * 0.25
            wd5 = B_T_dot_X[b][c][i][6][bb] - B_T_dot_X[b][c][i][4][bb] * 5
            wd3 = B_T_dot_X[b][c][i][6][bb] + B_T_dot_X[b][c][i][2][bb] * 0.25
            wd6 = B_T_dot_X[b][c][i][1][bb] + B_T_dot_X[b][c][i][5][bb] * 0.25

            wd0 = wd0 + d4_sub_d2 * 5.25
            wd7 = wd7 + d3_sub_d5 * 5.25

            wd1 = wd1 - B_T_dot_X[b][c][i][4][bb] * 4.25
            wd2 = wd2 - B_T_dot_X[b][c][i][3][bb] * 4.25

            wd3 = wd3 - B_T_dot_X[b][c][i][4][bb] * 1.25
            wd5 = wd5 + B_T_dot_X[b][c][i][2][bb] * 4
            wd4 = wd4 - B_T_dot_X[b][c][i][3][bb] * 1.25
            wd6 = wd6 - B_T_dot_X[b][c][i][3][bb] * 1.25

            temp_expr[(i, 0)] = wd0
            temp_expr[(i, 1)] = wd1 + wd2
            temp_expr[(i, 2)] = wd1 - wd2
            temp_expr[(i, 3)] = wd3 + wd4 * 2
            temp_expr[(i, 4)] = wd3 - wd4 * 2
            temp_expr[(i, 5)] = wd5 + wd6 * 2
            temp_expr[(i, 6)] = wd5 - wd6 * 2
            temp_expr[(i, 7)] = wd7

        # NOTE(review): same truncated selection statement as in
        # compute_B_T_dot_X.
        now = tvm.const(0.0, "float32")
        for ii in range(alpha):
            for jj in range(alpha):
                now = == ii, nu == jj),
                                 temp_expr[(ii, jj)], now)
        return now

    # V has the channel axis moved behind the two alpha axes compared with
    # B_T_dot_X.
    V = tvm.compute((P // VP, alpha, alpha, C, VP), compute_X_dot_B, name="V")

    # batch gemm
    # M[k, b, eps, nu, kk, bb] = sum over c of U[k, eps, nu, c, kk] * V[b, eps, nu, c, bb]
    # NOTE(review): the compute call is missing its trailing arguments and
    # closing parenthesis.
    c = tvm.reduce_axis((0, C), name='c')
    M = tvm.compute((K // VK, P // VP, alpha, alpha, VK, VP),
                    lambda k, b, eps, nu, kk, bb: tvm.sum(
                        U[k][eps][nu][c][kk] * V[b][eps][nu][c][bb], axis=c),

    def compute_A_T_dot_M(k, b, eps, nu, kk, bb):
        # For every column j, reduce the eight rows of M[k][b][:, j][kk][bb]
        # to the six expressions s0..s5 stored in temp_expr[(0..5, j)].
        temp_expr = {}

        for j in range(alpha):
            m1_add_m2 = M[k][b][1][j][kk][bb] + M[k][b][2][j][kk][bb]
            m1_sub_m2 = M[k][b][1][j][kk][bb] - M[k][b][2][j][kk][bb]
            m3_add_m4 = M[k][b][3][j][kk][bb] + M[k][b][4][j][kk][bb]
            m3_sub_m4 = M[k][b][3][j][kk][bb] - M[k][b][4][j][kk][bb]
            m5_add_m6 = M[k][b][5][j][kk][bb] + M[k][b][6][j][kk][bb]
            m5_sub_m6 = M[k][b][5][j][kk][bb] - M[k][b][6][j][kk][bb]
            s0 = M[k][b][0][j][kk][bb] + m1_add_m2
            s5 = M[k][b][7][j][kk][bb] + m1_sub_m2
            s1 = m1_sub_m2 + m5_sub_m6 * 16
            s4 = m1_add_m2 + m3_add_m4 * 16
            s2 = m1_add_m2 + 8 * m5_add_m6
            s3 = m1_sub_m2 + 8 * m3_sub_m4
            s0 = s0 + m5_add_m6 * 32
            s5 = s5 + m3_sub_m4 * 32
            s1 = s1 + m3_sub_m4 * 2
            s4 = s4 + m5_add_m6 * 2
            s0 = s0 + m3_add_m4
            s5 = s5 + m5_sub_m6
            s2 = s2 + m3_add_m4 * 4
            s3 = s3 + m5_sub_m6 * 4
            temp_expr[(0, j)] = s0
            temp_expr[(1, j)] = s1
            temp_expr[(2, j)] = s2
            temp_expr[(3, j)] = s3
            temp_expr[(4, j)] = s4
            temp_expr[(5, j)] = s5
        # NOTE(review): truncated selection statement (see compute_B_T_dot_X).
        now = tvm.const(0.0, "float32")
        for ii in range(m):
            for jj in range(alpha):
                now = == ii, nu == jj),
                                 temp_expr[(ii, jj)], now)
        return now

    # NOTE(review): call truncated - the compute function and name arguments
    # are missing.
    A_T_dot_M = tvm.compute((K // VK, P // VP, m, alpha, VK, VP),

    def compute_X_dot_A(k, b, eps, nu, kk, bb):
        # Same reduction as compute_A_T_dot_M, applied along the alpha axis of
        # A_T_dot_M for every row i.
        temp_expr = {}

        for i in range(m):
            # NOTE(review): the next six statements are cut off after the
            # final `[` - the `kk][bb]` continuation lines are missing.
            m1_add_m2 = A_T_dot_M[k][b][i][1][kk][bb] + A_T_dot_M[k][b][i][2][
            m1_sub_m2 = A_T_dot_M[k][b][i][1][kk][bb] - A_T_dot_M[k][b][i][2][
            m3_add_m4 = A_T_dot_M[k][b][i][3][kk][bb] + A_T_dot_M[k][b][i][4][
            m3_sub_m4 = A_T_dot_M[k][b][i][3][kk][bb] - A_T_dot_M[k][b][i][4][
            m5_add_m6 = A_T_dot_M[k][b][i][5][kk][bb] + A_T_dot_M[k][b][i][6][
            m5_sub_m6 = A_T_dot_M[k][b][i][5][kk][bb] - A_T_dot_M[k][b][i][6][
            s0 = A_T_dot_M[k][b][i][0][kk][bb] + m1_add_m2
            s5 = A_T_dot_M[k][b][i][7][kk][bb] + m1_sub_m2
            s1 = m1_sub_m2 + m5_sub_m6 * 16
            s4 = m1_add_m2 + m3_add_m4 * 16
            s2 = m1_add_m2 + 8 * m5_add_m6
            s3 = m1_sub_m2 + 8 * m3_sub_m4
            s0 = s0 + m5_add_m6 * 32
            s5 = s5 + m3_sub_m4 * 32
            s1 = s1 + m3_sub_m4 * 2
            s4 = s4 + m5_add_m6 * 2
            s0 = s0 + m3_add_m4
            s5 = s5 + m5_sub_m6
            s2 = s2 + m3_add_m4 * 4
            s3 = s3 + m5_sub_m6 * 4
            temp_expr[(i, 0)] = s0
            temp_expr[(i, 1)] = s1
            temp_expr[(i, 2)] = s2
            temp_expr[(i, 3)] = s3
            temp_expr[(i, 4)] = s4
            temp_expr[(i, 5)] = s5
        # NOTE(review): truncated selection statement (see compute_B_T_dot_X).
        now = tvm.const(0.0, "float32")
        for ii in range(m):
            for jj in range(m):
                now = == ii, nu == jj),
                                 temp_expr[(ii, jj)], now)
        return now

    # NOTE(review): call truncated - the compute function and name arguments
    # are missing.
    Y = tvm.compute((K // VK, P // VP, m, m, VK, VP),

    # unpack output
    def _output(n, k_, h, w):
        # Map (n, k_, h, w) to the packed index of Y: b_idx is the flat tile
        # index, split into (b, bb) by VP; k_ is split into (k, kk) by VK;
        # h % m and w % m select the element inside the tile.
        b_idx = n * nH * nW + (h // m) * nW + w // m
        b = b_idx // VP
        bb = b_idx % VP
        k = k_ // VK
        kk = k_ % VK
        return Y[k][b][h % m][w % m][kk][bb]

    # NOTE(review): call truncated - presumably `_output` and a name/tag follow.
    output = tvm.compute((N, CO, H, W),

    # Record a FLOP count on cfg when one is supplied; the formula uses the
    # rounded-up channel count K rather than CO.
    if cfg:
        cfg.add_flop(2 * N * K * H * W * KH * KW * C)

    return Y, input_tile, U, output
# Declares a single compute stage that multiplies the padded input by a
# depthwise filter (Filter_d) and a 1x1 filter (Filter_1) and sums over the
# filter window and the channel axis, for either NCHW or NHWC layout.
# NOTE(review): the parameter list is truncated; Filter_d, Filter_1, stride_d,
# padding_d, dilation_d, layout and out_dtype are used below but are not
# visible in the header.
def depth_1by1_fused(Input,
    """Fused depthwise convolution + 1x1 convolution forward operator (NCHW & NHWC).

    Input : tvm.Tensor
        4-D with shape [batch, in_channel, in_height, in_width] (NCHW)
                    or [batch, in_height, in_width, in_channel] (NHWC)

    Filter_d : tvm.Tensor
        4-D with shape [in_channel, in_channel * channel_multiplier, filter_height, filter_width]
                    or [filter_height, filter_width, in_channel, in_channel * channel_multiplier]

    Filter_1 : tvm.Tensor
        4-D with shape [out_channel, in_channel * channel_multiplier, 0, 0]
                    or [0, 0, out_channel, in_channel * channel_multiplier]

    stride_d : tuple of two ints
        The spatial stride along height and width

    padding_d : int or str
        Padding size, or ['VALID', 'SAME']

    dilation_d: int or a list/tuple of two ints
        dilation size, or [dilation_height, dilation_width]

    out_dtype: str, optional
        Output data type

    output : tvm.Tensor
        4-D with shape [batch, out_height, out_width, out_channel]

    # NOTE(review): the docstring above is never closed (the terminating
    # triple quote and the section headers appear to be missing).
    assert layout in ["NCHW", "NHWC"]

    # Default the output dtype to the input dtype.
    out_dtype = Input.dtype if out_dtype is None else out_dtype

    # NOTE(review): an `else:` header appears to be missing before the tuple
    # unpacking in both blocks below; as written the unpacking runs on the int.
    if isinstance(stride_d, int):
        stride_h_d = stride_w_d = stride_d
        stride_h_d, stride_w_d = stride_d

    if isinstance(dilation_d, int):
        dilation_h_d = dilation_w_d = dilation_d
        dilation_h_d, dilation_w_d = dilation_d

    # Dilate the depthwise filter when needed and read the layout-dependent
    # shapes of the input and of both filters.
    if layout == "NCHW":
        if dilation_h_d != 1 or dilation_w_d != 1:
            Filter_d = dilate(Filter_d, (1, 1, dilation_h_d, dilation_w_d))
        batch, in_channel_d, in_height_d, in_width_d = Input.shape
        filter_channel, _, filter_height, filter_width = Filter_d.shape
        num_filter, channel, _, _ = Filter_1.shape
    else:  # NHWC
        if dilation_h_d != 1 or dilation_w_d != 1:
            Filter_d = dilate(Filter_d, (dilation_h_d, dilation_w_d, 1, 1))
        batch, in_height_d, in_width_d, in_channel_d = Input.shape
        filter_height, filter_width, filter_channel, _ = Filter_d.shape
        _, _, num_filter, channel = Filter_1.shape

    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding_d, (filter_height, filter_width))
    # NOTE(review): out_channel is assigned here and then overwritten with
    # num_filter a few lines below, so this first value is never used.
    out_channel = simplify(in_channel_d)
    out_height = simplify((in_height_d - filter_height + pad_top + pad_down) //
                          stride_h_d + 1)
    out_width = simplify((in_width_d - filter_width + pad_left + pad_right) //
                         stride_w_d + 1)
    out_channel = num_filter

    # padding stage
    if layout == "NCHW":
        pad_before = [0, 0, pad_top, pad_left]
        pad_after = [0, 0, pad_down, pad_right]
    else:  # NHWC
        pad_before = [0, pad_top, pad_left, 0]
        pad_after = [0, pad_down, pad_right, 0]

    PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")

    # depthconv stage
    di = tvm.reduce_axis((0, filter_height), name='di')
    dj = tvm.reduce_axis((0, filter_width), name='dj')
    # 1by1 stage
    # NOTE(review): the reduction axis c indexes the channel dimension of
    # PaddedInput, Filter_d and Filter_1, yet its extent is out_channel
    # (== num_filter) rather than `channel` / in_channel_d - confirm this is
    # intended; the two only agree when num_filter equals the channel count.
    c = tvm.reduce_axis((0, out_channel), name='c')

    # NOTE(review): both tvm.compute calls below are missing their trailing
    # arguments and closing parenthesis.  Filter_d is always read at index 0
    # on its multiplier axis.
    if layout == "NCHW":
        Output = tvm.compute(
            (batch, out_channel, out_height, out_width),
            lambda b, f, i, j: tvm.sum(
                (PaddedInput[b, c, i * stride_h_d + di, j * stride_w_d + dj].
                 astype(out_dtype) * Filter_d[c, 0, di, dj].astype(
                     out_dtype) * Filter_1[f, c, 0, 0].astype(out_dtype)),
                axis=[di, dj, c]),
    else:  # NHWC
        Output = tvm.compute(
            (batch, out_height, out_width, out_channel),
            lambda b, i, j, f: tvm.sum(
                (PaddedInput[b, i * stride_h_d + di, j * stride_w_d + dj, c].
                 astype(out_dtype) * Filter_d[di, dj, c, 0].astype(
                     out_dtype) * Filter_1[0, 0, c, f].astype(out_dtype)),
                axis=[di, dj, c]),
    return Output
# Declare a spatially packed conv2d: pad the data, repack data and kernel
# into tiled layouts (data_vec, kernel_vec), compute the convolution on the
# packed layout (conv) and unpack it into `output`, which is returned.
# `args` supplies (is_var, vh_, vw_, vc_): the tile sizes along output height,
# output width and output channel.
# NOTE(review): the parameter list is truncated (out_dtype is used below but
# is not visible), and several calls and dict literals in the body are
# unclosed.
def _conv_spatial_pack_asm(args, data, kernel, strides, padding, dilation,
    is_var, vh_, vw_, vc_ = args

    # create workload according to raw arguments
    out_dtype = out_dtype or data.dtype
    # With is_var set, shapes are used as-is; otherwise they go through
    # get_const_tuple.
    n_, ci_, ih_, iw_ = data.shape if is_var else get_const_tuple(data.shape)

    # NOTE(review): an `else:` header appears to be missing before the tuple
    # unpacking.
    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
        dilation_h, dilation_w = dilation

    # A 4-D kernel is unpacked; otherwise it is treated as pre-packed with a
    # trailing vc_ axis, and co_ is scaled back to the full channel count.
    # NOTE(review): both get_const_tuple( calls are cut off.
    if len(kernel.shape) == 4:
        pre_packed = False
        co_, _, kh_, kw_ = kernel.shape if is_var else get_const_tuple(
    else:  # kernel tensor is pre packed
        pre_packed = True
        co_, _, kh_, kw_, vc_ = kernel.shape if is_var else get_const_tuple(
        co_ = co_ * vc_

    # Padding and output size are derived from the dilated kernel extent.
    dilated_kernel_h = (kh_ - 1) * dilation_h + 1
    dilated_kernel_w = (kw_ - 1) * dilation_w + 1
    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w))
    hstr, wstr = strides if isinstance(strides,
                                       (tuple, list)) else (strides, strides)
    oh_ = (ih_ + pad_top + pad_bottom - dilated_kernel_h) // hstr + 1
    ow_ = (iw_ + pad_left + pad_right - dilated_kernel_w) // wstr + 1
    data_pad = pad(data, [0, 0, pad_top, pad_left],
                   [0, 0, pad_bottom, pad_right])

    # oshape uses oh_div * vh_ and ow_div * vw_, i.e. the output spatial size
    # rounded down to whole tiles.
    oh_div = oh_ // vh_
    ow_div = ow_ // vw_
    kvshape = (co_ // vc_, ci_, kh_, kw_, vc_)
    ovshape = (n_, co_ // vc_, oh_div, ow_div, vh_, vw_, vc_)
    oshape = (n_, co_, oh_div * vh_, ow_div * vw_)

    # NOTE(review): the shape argument and closing arguments of the
    # tvm.compute calls in this section are missing, as is the `else:` header
    # before the second dvshape assignment (whose tuple is also unclosed).
    if dilation_h != 1 or dilation_w != 1:
        # undilate input data
        dvshape = (n_, oh_ // vh_, ow_ // vw_, kh_, kw_, vh_, vw_, ci_)
        data_vec = tvm.compute(
            lambda n, h, w, kh, kw, vh, vw, ci: data_pad[n][ci][
                (h * vh_ + vh) * hstr + kh * dilation_h][
                    (w * vw_ + vw) * wstr + kw * dilation_w],
        dvshape = (
            oh_ // vh_,
            ow_ // vw_,
            (vh_ - 1) * hstr + kh_,
            (vw_ - 1) * wstr + kw_,
        data_vec = tvm.compute(
            lambda n, h, w, vh, vw, ci: data_pad[n][ci][h * vh_ * hstr + vh][
                w * vw_ * wstr + vw],

    # Reuse a pre-packed kernel; otherwise repack it so output channel
    # co * vc_ + vc is addressed as [co, ..., vc].
    # NOTE(review): `else:` header and the compute call's shape/closing
    # arguments are missing.
    if pre_packed:
        kernel_vec = kernel
        kernel_vec = tvm.compute(
            lambda co, ci, kh, kw, vc: kernel[co * vc_ + vc][ci][kh][kw],

    ci = tvm.reduce_axis((0, ci_), name="ci")
    kh = tvm.reduce_axis((0, kh_), name="kh")
    kw = tvm.reduce_axis((0, kw_), name="kw")

    # asm begin----
    # Map the input dtype to an accumulation dtype and collect stride,
    # padding, dilation and tile sizes in `attrs`.
    # NOTE(review): both dict literals are missing their closing brace, and
    # `attrs` is not referenced in the visible code (presumably it was passed
    # to the truncated compute calls below).
    type_map = {
        "int8": "int32",
        "uint8": "uint32",
        "float32": "float32",
        "float16": "float16",
    acum_dtype = type_map[data.dtype]
    attrs = {
        "SH": hstr,
        "SW": wstr,
        "PH": pad_top,
        "PW": pad_left,
        "DILA_H": dilation_h,
        "DILA_W": dilation_w,
        "VH": vh_,
        "VW": vw_,
        "VC": vc_,
        "ACUM_DTYPE": acum_dtype,
    # asm end----

    # Convolution on the packed layout, reducing over ci, kh and kw.  The
    # dilated variant indexes data_vec with explicit kh/kw axes; the other
    # variant indexes a sliding window (vh * hstr + kh, vw * wstr + kw).
    # NOTE(review): shape arguments, closing parentheses and the `else:`
    # header are missing.
    if dilation_h != 1 or dilation_w != 1:
        conv = tvm.compute(
            lambda n, co, h, w, vh, vw, vc: tvm.sum(
                data_vec[n, h, w, kh, kw, vh, vw, ci].astype(out_dtype) *
                kernel_vec[co, ci, kh, kw, vc].astype(out_dtype),
                axis=[ci, kh, kw],
        conv = tvm.compute(
            lambda n, co, h, w, vh, vw, vc: tvm.sum(
                data_vec[n, h, w, vh * hstr + kh, vw * wstr + kw, ci].astype(
                    out_dtype) * kernel_vec[co, ci, kh, kw, vc].astype(
                axis=[ci, kh, kw],

    # Unpack: split co, h and w into (tile index, offset inside the tile).
    # NOTE(review): shape argument and closing arguments are missing.
    output = tvm.compute(
        lambda n, co, h, w: conv[n][co // vc_][h // vh_][w // vw_][h % vh_][
            w % vw_][co % vc_],

    return output
# Test helper: declares depthwise_conv2d_nchw, depthwise + scale_shift and
# depthwise + scale_shift + relu for a square input/filter workload, runs the
# three kernels on a device and compares them with NumPy reference results.
# NOTE(review): several lines of this helper appear to be missing (`else:`,
# `return`, `with` headers and the call in the `f1 =,` statements); see the
# notes inline.
def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_multiplier, filter_height, stride, padding, dilation=1):
    # Square input, square filter and equal strides are assumed.
    in_width = in_height
    filter_channel = in_channel
    filter_width = filter_height
    stride_h = stride_w = stride

    # NOTE(review): an `else:` header appears to be missing before
    # `padding_args = padding`; as written the tuple is always overwritten.
    if dilation == 1:
        # here we transform the padding argument from 'str' to  'tuple' ,
        # because we need this to match the "workload" tuple to the records in TopHub
        pad_h, pad_w, _, _ = get_pad_tuple(padding, (filter_height, filter_width))
        padding_args = (pad_h, pad_w)
        padding_args = padding

    # placeholder
    Input = tvm.placeholder((batch, in_channel, in_height, in_width), name='Input')
    Filter = tvm.placeholder((filter_channel, channel_multiplier, filter_height, filter_width), name='Filter')
    Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
    Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')

    dtype = 'float32'

    def check_device(device):
        # Build and run the three kernels on `device` and check the results.
        ctx = tvm.context(device, 0)
        # NOTE(review): presumably a `return` after the skip message and an
        # enclosing `with` statement for the indented declarations below were
        # lost; the indentation is inconsistent as written.
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
        print("Running on target: %s" % device)
            # declare
            DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, Filter,
                (stride_h, stride_w), padding_args, dilation, dtype)
            ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift)
            Relu = topi.nn.relu(ScaleShift)
            # schedule
            s1 = topi.generic.schedule_depthwise_conv2d_nchw(DepthwiseConv2d)
            s2 = topi.generic.schedule_depthwise_conv2d_nchw(ScaleShift)
            s3 = topi.generic.schedule_depthwise_conv2d_nchw(Relu)
        # build the kernels
        # NOTE(review): the callee and first argument are missing from the
        # three statements below (presumably a build call taking s1/s2/s3).
        f1 =, [Input, Filter, DepthwiseConv2d], device)
        f2 =, [Input, Filter, Scale, Shift, ScaleShift], device)
        f3 =, [Input, Filter, Scale, Shift, Relu], device)

        # Prepare pod type for test data closure
        input_shape = get_const_tuple(Input.shape)
        filter_shape = get_const_tuple(Filter.shape)
        scale_shape = get_const_tuple(Scale.shape)
        shift_shape = get_const_tuple(Shift.shape)
        scale_shift_shape = get_const_tuple(ScaleShift.shape)

        # Use memoize, pickle the test data for next time use.
        # NOTE(review): no memoize decorator is visible on get_ref_data here.
        def get_ref_data():
            # Random inputs plus reference outputs for all three kernels.
            input_np = np.random.uniform(size=input_shape).astype(dtype)
            filter_np = np.random.uniform(size=filter_shape).astype(dtype)
            # The reference convolution runs on an explicitly dilated filter.
            dilated_filter_np = topi.testing.dilate_python(filter_np, (1, 1, dilation, dilation))
            scale_np = np.random.uniform(size=scale_shape).astype(dtype)
            shift_np = np.random.uniform(size=shift_shape).astype(dtype)
            # correctness with scipy
            depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nchw(
                input_np, dilated_filter_np, stride, padding)
            scale_shift_scipy = np.zeros(shape=scale_shift_shape)
            # Per-channel scale and shift; relu_scipy is recomputed on every
            # iteration, so only the value from the last iteration is returned.
            for c in range(in_channel * channel_multiplier):
                scale_shift_scipy[:,c,:,:] = depthwise_conv2d_scipy[:,c,:,:] * scale_np[c] + shift_np[c]
                relu_scipy = np.maximum(scale_shift_scipy, 0)
            return (input_np, filter_np, scale_np, shift_np,
                    depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy)
        # Get the test data
        (input_np, filter_np, scale_np, shift_np,
         depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy) = get_ref_data()

        # Device copies of the inputs and zero-initialised output buffers.
        input_tvm = tvm.nd.array(input_np, ctx)
        filter_tvm = tvm.nd.array(filter_np, ctx)
        scale_tvm = tvm.nd.array(scale_np, ctx)
        shift_tvm = tvm.nd.array(shift_np, ctx)
        depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape), dtype=DepthwiseConv2d.dtype), ctx)
        scale_shift_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(ScaleShift.shape), dtype=ScaleShift.dtype), ctx)
        relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx)
        # launch kernel 1 (depthwise_conv2d)
        timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=1)
        tcost_1 = timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm).mean
        # launch kernel 2 (depthwise_conv2d + scale_shift)
        timer_2 = f2.time_evaluator(f2.entry_name, ctx, number=1)
        tcost_2 = timer_2(input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm).mean
        # launch kernel 3 (depthwise_conv2d + scale_shift + relu)
        timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=1)
        tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean
        # The measured times (tcost_*) are not used; only the outputs are checked.
        tvm.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5)
        tvm.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5)
        tvm.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)

    # NOTE(review): the body of the `with` block (presumably the call to
    # check_device(device)) is missing.
    for device in get_all_backend():
        with autotvm.tophub.context(device):  # load tophub pre-tuned parameters
# Test helper: builds an NHWC conv2d (optionally followed by a bias add and
# relu) through a dispatched compute/schedule pair and compares the device
# result against topi.testing.conv2d_nhwc_python with rtol=2e-3.
def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride,
                       padding, dilation=1, add_bias=False, add_relu=False,
                       devices='cuda', bgemm="direct"):
    """Test the conv2d with winograd for nhwc layout"""
    # padding_sum is only used in the workload printout and kernel names.
    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
    padding_sum = pad_top + pad_left + pad_bottom + pad_right
    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (
        batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))

    in_height = in_width = in_size

    # Placeholders: NHWC data, HWIO weights, and a broadcastable bias.
    A = te.placeholder((batch, in_height, in_width, in_channel), name='A')
    W = te.placeholder((kernel, kernel, in_channel, num_filter), name='W')
    bias = te.placeholder((1, 1, 1, num_filter), name='bias')

    a_shape = get_const_tuple(A.shape)
    w_shape = get_const_tuple(W.shape)
    bias_shape = get_const_tuple(bias.shape)
    dtype = A.dtype

    def get_ref_data():
        # Random inputs and the NumPy reference output; the reference
        # convolution runs on an explicitly dilated copy of the weights.
        a_np = np.random.uniform(size=a_shape).astype(dtype)
        w_np = np.random.uniform(size=w_shape).astype(dtype)
        b_np = np.random.uniform(size=bias_shape).astype(dtype)
        dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1))
        c_np = topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding)
        # With add_bias, b_np is drawn a second time and that second draw is
        # both added to the reference and returned.
        if add_bias:
            b_np = np.random.uniform(size=bias_shape).astype(dtype)
            c_np += b_np
        if add_relu:
            c_np = np.maximum(c_np, 0)
        return a_np, w_np, b_np, c_np

    a_np, w_np, b_np, c_np = get_ref_data()

    def check_device(device):
        # Build, run and check the operator on `device`.
        ctx = tvm.context(device, 0)
        # NOTE(review): presumably a `return` after the skip message and an
        # enclosing `with` statement for the indented block below were lost;
        # the dispatch(...) calls are also cut off after `device,`.
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
        print("Running on target: %s" % device)
            if bgemm == "direct":
                fcompute, fschedule = topi.testing.dispatch(device,
            elif bgemm == "tensorcore":
                fcompute, fschedule = topi.testing.dispatch(device,
            C = fcompute(A, W, stride, padding, dilation, 'float32')
            if add_bias:
                C = topi.add(C, bias)
            if add_relu:
                C = topi.nn.relu(C)
            s = fschedule([C])

        a = tvm.nd.array(a_np, ctx)
        w = tvm.nd.array(w_np, ctx)
        b = tvm.nd.array(b_np, ctx)
        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
        # NOTE(review): the callee and first argument of both `func =,`
        # statements are missing (presumably a build call taking `s`), as is
        # the `else:` header before the second one.
        if add_bias:
            func =, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (
                batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
            func(a, w, b, c)
            func =, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (
                batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
            func(a, w, c)

        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=2e-3)
    # NOTE(review): nothing visible here calls check_device or uses `devices`.

# Test helper: builds topi.x86.conv2d_NCHWc (optionally followed by a bias add
# and relu) on blocked-layout placeholders and compares the result with a
# conv2d_nchw_python reference converted to the same blocked layout.
# NOTE(review): the parameter list is truncated; in_channel, in_size,
# num_filter, kernel, stride, padding, dilation, add_bias, add_relu and dtype
# are used below but are not visible in the header or body.
def verify_conv2d_NCHWc(batch,
    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(
        padding, (kernel, kernel))
    padding_sum = pad_top + pad_left + pad_bottom + pad_right
    in_height = in_width = in_size
    # NOTE(review): the opening of this statement (presumably `print(`) is
    # missing.
        "Workload: (%d, %d, %d, %d, %d, %d, %d)" %
        (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum))

    # for testing functionality,
    # we choose arbitrary block size that can divide the channel,
    # regardless of the performance.
    # NOTE(review): neither loop has a `break`, so as written each keeps the
    # last (smallest) divisor it visits, which is always 1 - presumably a
    # `break` after each assignment was lost.
    oc_block = 1
    for bn in range(16, 0, -1):
        if num_filter % bn == 0:
            oc_block = bn

    ic_block = 1
    for bn in range(oc_block, 0, -1):
        if in_channel % bn == 0:
            ic_block = bn

    # Blocked-layout placeholders for data, weights and bias.
    # NOTE(review): the three placeholder(...) calls are missing their
    # trailing arguments and closing parenthesis.
    A = tvm.placeholder(
        (batch, in_channel // ic_block, in_height, in_width, ic_block),
    W = tvm.placeholder((num_filter // oc_block, in_channel // ic_block,
                         kernel, kernel, ic_block, oc_block),
    bias = tvm.placeholder((num_filter // oc_block, 1, 1, oc_block),

    def get_ref_data():
        # Reference data is generated in plain NCHW form, then converted by
        # the _transform_* helpers before being returned.
        # NOTE(review): the a_np and w_np statements are cut off mid-call.
        a_np = np.random.uniform(size=(batch, in_channel, in_height,
        w_np = np.random.uniform(size=(num_filter, in_channel, kernel,
        b_np = np.random.uniform(size=(num_filter, 1, 1)).astype(dtype)
        dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
        c_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding)
        if add_bias:
            c_np += b_np
        if add_relu:
            c_np = np.maximum(c_np, 0)
        return _transform_data(a_np, ic_block), _transform_kernel(w_np, ic_block, oc_block), \
               _transform_bias(b_np, oc_block), _transform_data(c_np, oc_block)

    a_np, w_np, b_np, c_np = get_ref_data()

    def check_device(device):
        # Build, run and check the operator on `device`.
        ctx = tvm.context(device, 0)
        # NOTE(review): presumably a `return` after the skip message and an
        # enclosing `with` statement for the indented block below were lost.
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
        print("Running on target: %s" % device)
            C = topi.x86.conv2d_NCHWc(A, W, (stride, stride), padding,
                                      (dilation, dilation),
                                      'NCHW%dc' % ic_block,
                                      "NCHW%dc" % oc_block, dtype)
            if add_bias:
                C = topi.add(C, bias)
            if add_relu:
                C = topi.nn.relu(C)
            s = topi.x86.schedule_conv2d_NCHWc([C])

        a = tvm.nd.array(a_np, ctx)
        w = tvm.nd.array(w_np, ctx)
        b = tvm.nd.array(b_np, ctx)
        # NOTE(review): this call is missing its trailing `ctx)` line; the
        # `func =,` statements below are missing their callee and first
        # argument, and the `else:` header before the second one.
        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype),
        if add_bias:
            func =, [A, W, bias, C],
                             name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
                             (batch, in_channel, in_size, num_filter, kernel,
                              stride, padding_sum, dilation))
            func(a, w, b, c)
            func =, [A, W, C],
                             name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
                             (batch, in_channel, in_size, num_filter, kernel,
                              stride, padding_sum, dilation))
            func(a, w, c)
        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-3)

    # test llvm only for now since conv2d_NCHWc implement is missing in other backend.
    # NOTE(review): the body of the `with` block (presumably the call to
    # check_device(device)) is not visible.
    for device in ["llvm"]:
        with autotvm.tophub.context(
                device):  # load tophub pre-tuned parameters