def get_eliminated():
     a = te.placeholder((100,), dtype="bfloat16")
     b = te.placeholder((100,), dtype="bfloat16")
     c = te.compute(
         (100,),
         lambda i: to16(
             topi.add(
                 to32(
                     to16(
                         topi.add(
                             to32(a[i]),
                             to32(b[i]),
                         )
                     )
                 ),
                 to32(
                     to16(
                         topi.add(
                             to32(a[i]),
                             to32(b[i]),
                         )
                     )
                 ),
             )
         ),
     )
     s = te.create_schedule(c.op)
     stmt = lower_stmt(s, [a, b, c], tvm.tir.transform.BF16CastElimination)
     return stmt
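Both to16/to32 and lower_stmt are helpers defined elsewhere in the test file; a minimal sketch of the cast helpers, assuming they simply wrap topi.cast between float32 and bfloat16:

def to32(v):
    # assumed helper: widen bfloat16 to float32
    return topi.cast(v, "float32")

def to16(v):
    # assumed helper: narrow float32 back to bfloat16
    return topi.cast(v, "bfloat16")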
Example #2
def batch_norm_fwd(N, C, H, W, dtype="float32"):
    dshape = (N, C, H, W)
    oshape = (C, )
    bshape = (1, C, 1, 1)
    sshape = (1, )
    data = te.placeholder(dshape, name="data", dtype=dtype)
    scale = te.placeholder(oshape, name="scale", dtype=dtype)
    bias = te.placeholder(oshape, name="bias", dtype=dtype)
    running_mean = te.placeholder(oshape, name="running_mean", dtype=dtype)
    running_var = te.placeholder(oshape, name="running_var", dtype=dtype)
    eps = te.placeholder(sshape, name="eps", dtype=dtype)
    momentum = te.placeholder(sshape, name="momentum", dtype=dtype)

    axis = (0, 2, 3)
    num_ele = dshape[0] * dshape[2] * dshape[3]
    frac_num_ele = 1.0 / num_ele
    # compute batch mean
    mean_sum = topi.sum(data, axis, keepdims=True)
    saved_mean = topi.multiply(mean_sum, frac_num_ele)
    # compute batch rvars
    var_sub = topi.subtract(data, saved_mean)
    var_mul = topi.multiply(var_sub, var_sub)
    var_sum = topi.sum(var_mul, axis, keepdims=True)
    var = topi.multiply(var_sum, frac_num_ele)
    output_add = topi.add(var, eps)
    saved_rvars = topi.sqrt(output_add)
    # compute output
    output_sub = topi.subtract(data, saved_mean)
    output_norm = topi.divide(output_sub, saved_rvars)
    scale_board = topi.reshape(scale, bshape)
    bias_board = topi.reshape(bias, bshape)
    output = topi.add(topi.multiply(output_norm, scale_board), bias_board)
    # reshape saved_rvars
    saved_rvars = topi.reshape(saved_rvars, oshape)
    # update running mean
    running_mean_mul1 = topi.multiply(running_mean,
                                      topi.subtract(1.0, momentum))
    running_mean_mul2 = topi.multiply(topi.reshape(saved_mean, oshape),
                                      momentum)
    running_mean_out = topi.add(running_mean_mul1, running_mean_mul2)
    # update running var
    saved_var_mul1 = topi.multiply(running_var, topi.subtract(1.0, momentum))
    saved_var_mul2 = topi.multiply(topi.reshape(var, oshape), momentum)
    running_var_out = topi.add(saved_var_mul1, saved_var_mul2)
    # reshape saved_mean
    saved_mean = topi.reshape(saved_mean, oshape)

    return [
        data, scale, bias, running_mean, running_var, momentum, eps, output,
        saved_mean, saved_rvars, running_mean_out, running_var_out
    ]
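A minimal usage sketch for the function above (assuming tvm, te, and topi are imported as in the snippet): the returned tensors can be scheduled with a default schedule and built for a CPU target.

tensors = batch_norm_fwd(2, 4, 8, 8)
outputs = tensors[7:]  # output, saved_mean, saved_rvars, running_mean_out, running_var_out
s = te.create_schedule([t.op for t in outputs])
mod = tvm.build(s, tensors, target="llvm")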
Example #3
def get_target():
    a = te.placeholder((100, ), dtype='bfloat16')
    b = te.placeholder((100, ), dtype='bfloat16')
    c = te.compute((100, ), lambda i: to16(
        topi.add(topi.add(
            to32(a[i]),
            to32(b[i]),
        ), topi.add(
            to32(a[i]),
            to32(b[i]),
        ))))
    s = te.create_schedule(c.op)
    func = tvm.driver.build_module.form_irmodule(s, [a, b, c], "main", None)["main"]
    return func.body
Example #4
    def check_target(target):
        dev = tvm.device(target, 0)
        if not tvm.testing.device_enabled(target):
            print("Skip because %s is not enabled" % target)
            return
        print("Running on target: %s" % target)
        with tvm.target.Target(target):
            fcompute, fschedule = tvm.topi.testing.dispatch(
                target, _conv2d_hwcn_implement)
            t_conv = fcompute(A, W, stride, padding, dilation)
            t_bias = topi.add(t_conv, B)
            t_relu = topi.nn.relu(t_bias)
            s1 = fschedule([t_conv])
            s2 = fschedule([t_bias])
            s3 = fschedule([t_relu])
        a = tvm.nd.array(a_np, dev)
        w = tvm.nd.array(w_np, dev)
        b = tvm.nd.array(b_np, dev)

        conv_out = tvm.nd.array(
            np.zeros(get_const_tuple(t_conv.shape), dtype=t_conv.dtype), dev)
        bias_out = tvm.nd.array(
            np.zeros(get_const_tuple(t_bias.shape), dtype=t_bias.dtype), dev)
        relu_out = tvm.nd.array(
            np.zeros(get_const_tuple(t_relu.shape), dtype=t_relu.dtype), dev)
        func1 = tvm.build(s1, [A, W, t_conv], target)
        func2 = tvm.build(s2, [A, W, B, t_bias], target)
        func3 = tvm.build(s3, [A, W, B, t_relu], target)
        func1(a, w, conv_out)
        func2(a, w, b, bias_out)
        func3(a, w, b, relu_out)
        tvm.testing.assert_allclose(conv_out.asnumpy(), c1_np, rtol=1e-5)
        tvm.testing.assert_allclose(bias_out.asnumpy(), c2_np, rtol=1e-5)
        tvm.testing.assert_allclose(relu_out.asnumpy(), c3_np, rtol=1e-5)
Example #5
def group_conv2d(N, CI, H, W, CO, KH, KW, strides, padding, dilation, groups):

    CI_G = CI // groups
    data_shape = (N // env.BATCH, CI // env.BLOCK_IN, H, W, env.BATCH,
                  env.BLOCK_IN)
    kernel_shape = (CO // env.BLOCK_OUT, CI_G // env.BLOCK_IN, KH, KW,
                    env.BLOCK_OUT, env.BLOCK_IN)
    bias_shape = (N // env.BATCH, CO // env.BLOCK_OUT, 1, 1, env.BATCH,
                  env.BLOCK_OUT)

    data = te.placeholder(data_shape, name="data", dtype=env.inp_dtype)
    kernel = te.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
    bias = te.placeholder(bias_shape, name="bias", dtype=env.acc_dtype)

    with tvm.target.vta():
        res = topi.nn.group_conv2d_nchw(data, kernel, strides, padding,
                                        dilation, groups, env.acc_dtype)
        res = topi.right_shift(res, env.WGT_WIDTH)
        res = topi.add(res, bias)
        res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
        res = topi.cast(res, env.out_dtype)

    if tvm.target.Target.current().device_name == "vta":
        s = topi.generic.schedule_group_conv2d_nchw([res])
    else:
        s = te.create_schedule([res.op])

    return s, [data, kernel, bias, res]
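The my_clip helper used above is not shown; a sketch under the assumption that it is the usual VTA test utility which splits clipping into separate min and max stages:

def my_clip(x, a_min, a_max):
    """Assumed helper: unlike topi.clip, apply min and max in two stages."""
    const_min = tvm.tir.const(a_min, x.dtype)
    const_max = tvm.tir.const(a_max, x.dtype)
    x = te.compute(x.shape, lambda *i: tvm.te.min(x(*i), const_max), name="clipA")
    x = te.compute(x.shape, lambda *i: tvm.te.max(x(*i), const_min), name="clipB")
    return x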
Example #6
    def check_device(device, ctx):
        print("Running on target: %s" % device)
        fcompute, fschedule = tvm.topi.testing.dispatch(
            device, _conv3d_ncdhw_implement)
        with tvm.target.Target(device):
            C = fcompute(A, W, (stride, stride, stride), padding,
                         (dilation, dilation, dilation), dtype)
            if add_bias:
                C = topi.add(C, bias)
            if add_relu:
                C = topi.nn.relu(C)
            s = fschedule([C])

        a = tvm.nd.array(a_np, ctx)
        w = tvm.nd.array(w_np, ctx)
        b = tvm.nd.array(b_np, ctx)
        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype),
                         ctx)
        if add_bias:
            func = tvm.build(s, [A, W, bias, C],
                             device,
                             name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
                             (batch, in_channel, in_size, num_filter, kernel,
                              stride, padding_sum, dilation))
            func(a, w, b, c)
        else:
            func = tvm.build(s, [A, W, C],
                             device,
                             name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
                             (batch, in_channel, in_size, num_filter, kernel,
                              stride, padding_sum, dilation))
            func(a, w, c)
        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-4)
Example #7
    def check_device(device):
        ctx = tvm.context(device, 0)
        if not tvm.testing.device_enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        if not nvcc.have_tensorcore(ctx.compute_version):
            print("skip because gpu does not support Tensor Cores")
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            fcompute, fschedule = tvm.topi.testing.dispatch(device, _conv2d_nhwc_tensorcore_implement)
            C = fcompute(A, W, stride, padding, dilation, 'float32')
            if add_bias:
                C = topi.add(C, bias)
            if add_relu:
                C = topi.nn.relu(C)
            s = fschedule([C])

        a = tvm.nd.array(a_np, ctx)
        w = tvm.nd.array(w_np, ctx)
        b = tvm.nd.array(b_np, ctx)
        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
        if add_bias:
            func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (
                batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
            func(a, w, b, c)
        else:
            func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (
                batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
            func(a, w, c)

        rtol = 1e-3
        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=rtol)
Example #8
    def check_device(device):
        ctx = tvm.context(device, 0)
        if not tvm.testing.device_enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        if device == "cuda" and not tvm.contrib.nvcc.have_int8(
                ctx.compute_version):
            print("Skip because int8 intrinsics are not available")
            return

        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            C = topi.cuda.group_conv2d_NCHWc_int8(A, W, stride, padding,
                                                  dilation, groups, dtype)
            if add_bias:
                C = topi.add(C, bias)
            if add_relu:
                C = topi.nn.relu(C)
            s = topi.cuda.schedule_group_conv2d_NCHWc_int8([C])

        a = tvm.nd.array(a_np, ctx)
        w = tvm.nd.array(w_np, ctx)
        b = tvm.nd.array(b_np, ctx)
        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype),
                         ctx)
        if add_bias:
            func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d" %\
                (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups))
            func(a, w, b, c)
        else:
            func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d" % \
            (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups))
            func(a, w, c)
        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
Example #9
def conv2d(N, CI, H, W, CO, KH, KW, strides, padding, dilation):
    data_shape = (N // env.BATCH, CI // env.BLOCK_IN, H, W, env.BATCH,
                  env.BLOCK_IN)
    kernel_shape = (CO // env.BLOCK_OUT, CI // env.BLOCK_IN, KH, KW,
                    env.BLOCK_OUT, env.BLOCK_IN)
    bias_shape = (N // env.BATCH, CO // env.BLOCK_OUT, 1, 1, env.BATCH,
                  env.BLOCK_OUT)

    data = te.placeholder(data_shape, name="data", dtype=env.inp_dtype)
    kernel = te.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
    bias = te.placeholder(bias_shape, name="bias", dtype=env.acc_dtype)

    with tvm.target.vta():
        res = topi.nn.conv2d(input=data,
                             filter=kernel,
                             padding=padding,
                             strides=strides,
                             dilation=dilation,
                             layout='NCHW%dn%dc' % (env.BATCH, env.BLOCK_IN),
                             out_dtype=env.acc_dtype)
        res = topi.right_shift(res, env.WGT_WIDTH)
        res = topi.add(res, bias)
        res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
        res = topi.cast(res, env.out_dtype)

    if tvm.target.Target.current().device_name == 'vta':
        s = topi.generic.schedule_conv2d_nchw([res])
    else:
        s = te.create_schedule([res.op])

    return s, [data, kernel, bias, res]
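The env object that supplies the packed shapes and dtypes above is assumed to be the VTA hardware environment, typically obtained once at module level:

import vta
env = vta.get_env()  # provides BATCH, BLOCK_IN, BLOCK_OUT, WGT_WIDTH, OUT_WIDTH and the inp/wgt/acc/out dtypes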
Example #10
    def check_device(device):
        ctx = tvm.context(device, 0)
        if not tvm.testing.device_enabled(device):
            print("Skip because %s is not enabled" % device)
            return

        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            fcompute, fschedule = tvm.topi.testing.dispatch(
                device, _group_conv2d_nchw_implement)
            C = fcompute(A, W, stride, padding, dilation, groups, dtype)
            if add_bias:
                C = topi.add(C, bias)
            if add_relu:
                C = topi.nn.relu(C)
            s = fschedule([C])

        a = tvm.nd.array(a_np, ctx)
        w = tvm.nd.array(w_np, ctx)
        b = tvm.nd.array(b_np, ctx)
        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype),
                         ctx)
        if add_bias:
            func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d" %\
                (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups))
            func(a, w, b, c)
        else:
            func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d" % \
            (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups))
            func(a, w, c)
        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
Example #11
def compile_conv2d_NHWC_gemm_int8_arm(batch, in_channel, in_size, num_filter, kernel, stride, padding,
                                 dilation=1, add_bias=False, add_relu=False):
    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
    padding_sum = pad_top + pad_left + pad_bottom + pad_right
    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter,
                                                          kernel, stride, padding_sum, dilation))

    in_height = in_width = in_size
    A = te.placeholder((batch, in_height, in_width, in_channel), name='A', dtype='int8')
    W = te.placeholder((kernel, kernel, in_channel, num_filter), name='W', dtype='int8')
    bias = te.placeholder((num_filter,), name='bias', dtype='int8')
    dtype = 'int32'
    device = "llvm --device arm_cpu --mtriple aarch64-linux-gnu"

    ctx = tvm.context(device, 0)
    if not ctx.exist:
        print("Skip because %s is not enabled" % device)
        return
    print("Compiling on arm AArch64 target: %s" % device)
    with tvm.target.create(device):
        assert is_aarch64_arm(), "AArch64 target not recognized"

        C = topi.arm_cpu.compute_conv2d_NHWC_quantized(A, W, (stride, stride), padding,
                                                       (dilation, dilation), dtype)
        if add_bias:
            C = topi.add(C, bias)
        if add_relu:
            C = topi.nn.relu(C)
        s = topi.arm_cpu.schedule_conv2d_NHWC_quantized([C])

    if add_bias:
        tvm.build(s, [A, W, bias, C], device,
                  name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch,
                                                         in_channel,
                                                         in_size,
                                                         num_filter,
                                                         kernel,
                                                         stride,
                                                         padding_sum,
                                                         dilation))
        func = tvm.build(s, [A, W, bias, C], device,
                         name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch,
                                                                in_channel,
                                                                in_size,
                                                                num_filter,
                                                                kernel,
                                                                stride,
                                                                padding_sum,
                                                                dilation))
    else:
        func = tvm.build(s, [A, W, C], device,
                         name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch,
                                                                in_channel,
                                                                in_size,
                                                                num_filter,
                                                                kernel,
                                                                stride,
                                                                padding_sum,
                                                                dilation))
Example #12
    def check_target(target):
        dev = tvm.device(target, 0)
        if not tvm.testing.device_enabled(target):
            print("Skip because %s is not enabled" % target)
            return
        print("Running on target: %s" % target)

        if "cudnn" in target:
            fcompute, fschedule = topi.cuda.conv2d_cudnn, topi.cuda.schedule_conv2d_cudnn
        else:
            fcompute, fschedule = tvm.topi.testing.get_conv2d_nchw_implement(
                target)

        with tvm.target.Target(target):
            if "cudnn" in target:
                C = fcompute(A, W, (stride, stride), padding,
                             (dilation, dilation), 1, "NCHW", dtype)
            else:
                C = fcompute(A, W, (stride, stride), padding,
                             (dilation, dilation), dtype)
            if add_bias:
                C = topi.add(C, bias)
            if add_relu:
                C = topi.nn.relu(C)
            s = fschedule([C])

            if "llvm" in target:
                verify_workload_padding()

        a = tvm.nd.array(a_np, dev)
        w = tvm.nd.array(w_np, dev)
        b = tvm.nd.array(b_np, dev)

        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype),
                         dev)
        if add_bias:
            func = tvm.build(
                s,
                [A, W, bias, C],
                target,
                name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
                (batch, in_channel, in_size, num_filter, kernel, stride,
                 padding_sum, dilation),
            )
            func(a, w, b, c)
        else:
            func = tvm.build(
                s,
                [A, W, C],
                target,
                name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
                (batch, in_channel, in_size, num_filter, kernel, stride,
                 padding_sum, dilation),
            )
            func(a, w, c)
        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-4)
Example #13
    def check_target(target):
        dev = tvm.device(target, 0)
        if not tvm.testing.device_enabled(target):
            print("Skip because %s is not enabled" % target)
            return
        if target == "cuda" and not tvm.contrib.nvcc.have_int8(
                dev.compute_version):
            print("Skip because int8 intrinsics are not available")
            return

        print("Running on target: %s" % target)
        with tvm.target.Target(target):
            C = topi.cuda.conv2d_NCHWc_int8(A, W, (stride, stride), padding,
                                            (dilation, dilation), "NCHW",
                                            dtype)
            if add_bias:
                C = topi.add(C, bias)
            if add_relu:
                C = topi.nn.relu(C)
            s = topi.cuda.schedule_conv2d_NCHWc_int8([C])

        a = tvm.nd.array(a_np, dev)
        w = tvm.nd.array(w_np, dev)
        b = tvm.nd.array(b_np, dev)
        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype),
                         dev)
        if add_bias:
            tvm.build(
                s,
                [A, W, bias, C],
                target,
                name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
                (batch, in_channel, in_size, num_filter, kernel, stride,
                 padding_sum, dilation),
            )
            func = tvm.build(
                s,
                [A, W, bias, C],
                target,
                name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
                (batch, in_channel, in_size, num_filter, kernel, stride,
                 padding_sum, dilation),
            )
            func(a, w, b, c)
        else:
            func = tvm.build(
                s,
                [A, W, C],
                target,
                name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
                (batch, in_channel, in_size, num_filter, kernel, stride,
                 padding_sum, dilation),
            )
            func(a, w, c)
        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
Example #14
def test_schedule_cache_reads():
    a = te.placeholder((12, 12), dtype="uint8", name="a")
    b = te.placeholder((12, 12), dtype="uint8", name="b")
    add = topi.add(a, b)
    sch = te.create_schedule([add.op])
    cr = sch.cache_read(b, "global", [add])
    schedule_cache_reads(sch)
    assert len(sch.stages) == 4
    assert len(sch[cr].leaf_iter_vars) == 1
    iv = sch[cr].leaf_iter_vars[0]
    assert list(sch[cr].iter_var_attrs[iv].pragma_keys) == ["op"]
    assert list(sch[cr].iter_var_attrs[iv].pragma_values) == ["ethosu_copy"]
Example #15
def test_complex_reduce():
    in_shape = (2, 3)
    dtype = "float32"
    axis = 0
    keepdims = False
    A = te.placeholder(shape=in_shape, name="A", dtype=dtype)
    B = topi.sum(A, axis=axis, keepdims=keepdims)
    C = topi.add(B, B)
    D = topi.multiply(B, B)
    E = topi.add(C, D)
    for device, ctx in tvm.testing.enabled_targets():
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            s = tvm.topi.testing.get_reduce_schedule(device)(E)
        foo = tvm.build(s, [A, E], device, name="sum")
        in_npy = np.random.uniform(-1, 1, size=in_shape).astype(dtype)
        sum_npy = in_npy.sum(axis=axis, keepdims=keepdims)
        out_npy = sum_npy * 2 + sum_npy * sum_npy
        data_tvm = tvm.nd.array(in_npy, ctx=ctx)
        out_tvm = tvm.nd.empty(shape=out_npy.shape, ctx=ctx, dtype=dtype)
        foo(data_tvm, out_tvm)
        tvm.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1e-3, 1e-3)
Example #16
    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            C = topi.arm_cpu.compute_conv2d_NHWC_quantized(A, W, (stride, stride), padding,
                                                           (dilation, dilation), dtype)
            if add_bias:
                C = topi.add(C, bias)
            if add_relu:
                C = topi.nn.relu(C)
            s = topi.arm_cpu.schedule_conv2d_NHWC_quantized([C])

        a = tvm.nd.array(a_np, ctx)
        w = tvm.nd.array(w_np, ctx)
        b = tvm.nd.array(b_np, ctx)
        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
        if add_bias:
            tvm.build(s, [A, W, bias, C], device,
                      name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch,
                                                             in_channel,
                                                             in_size,
                                                             num_filter,
                                                             kernel,
                                                             stride,
                                                             padding_sum,
                                                             dilation))
            func = tvm.build(s, [A, W, bias, C], device,
                             name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch,
                                                                    in_channel,
                                                                    in_size,
                                                                    num_filter,
                                                                    kernel,
                                                                    stride,
                                                                    padding_sum,
                                                                    dilation))
            func(a, w, b, c)
        else:
            func = tvm.build(s, [A, W, C], device,
                             name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch,
                                                                    in_channel,
                                                                    in_size,
                                                                    num_filter,
                                                                    kernel,
                                                                    stride,
                                                                    padding_sum,
                                                                    dilation))
            func(a, w, c)
        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
Example #17
    def check_device(device):
        dev = tvm.device(device, 0)
        if not tvm.testing.device_enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            C = topi.x86.conv2d_NCHWc(
                A,
                W,
                (stride, stride),
                padding,
                (dilation, dilation),
                "NCHW%dc" % ic_block,
                "NCHW%dc" % oc_block,
                dtype,
            )
            if add_bias:
                C = topi.add(C, bias)
            if add_relu:
                C = topi.nn.relu(C)
            s = topi.x86.schedule_conv2d_NCHWc([C])

        a = tvm.nd.array(a_np, dev)
        w = tvm.nd.array(w_np, dev)
        b = tvm.nd.array(b_np, dev)
        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype),
                         dev)
        if add_bias:
            func = tvm.build(
                s,
                [A, W, bias, C],
                device,
                name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
                (batch, in_channel, in_size, num_filter, kernel, stride,
                 padding_sum, dilation),
            )
            func(a, w, b, c)
        else:
            func = tvm.build(
                s,
                [A, W, C],
                device,
                name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
                (batch, in_channel, in_size, num_filter, kernel, stride,
                 padding_sum, dilation),
            )
            func(a, w, c)
        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
Example #18
    def check_device(device):
        dev = tvm.device(device, 0)
        if not tvm.testing.device_enabled(device):
            print("Skipping %s becuase it is not enabled" % device)
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            result_c = topi.nn.conv2d(
                placholder_a,
                placeholder_w,
                stride,
                padding,
                dilation,
                data_layout="NCHW",
                out_dtype=dtype,
            )
            if add_bias:
                result_c = topi.add(result_c, bias)
            if add_relu:
                result_c = topi.nn.relu(result_c)
            schedule = topi.generic.schedule_conv2d_nchw([result_c])

        buff_a = tvm.nd.array(a_np, dev)
        buff_w = tvm.nd.array(w_np, dev)
        buff_b = tvm.nd.array(b_np, dev)
        buff_c = tvm.nd.array(
            np.zeros(get_const_tuple(result_c.shape), dtype=result_c.dtype),
            dev)
        if add_bias:
            func = tvm.build(
                schedule,
                [placholder_a, placeholder_w, bias, result_c],
                device,
                name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
                (batch, in_channel, in_size, num_filter, kernel, stride,
                 padding, dilation),
            )
            func(buff_a, buff_w, buff_b, buff_c)
        else:
            func = tvm.build(
                schedule,
                [placholder_a, placeholder_w, result_c],
                device,
                name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
                (batch, in_channel, in_size, num_filter, kernel, stride,
                 padding, dilation),
            )
            func(buff_a, buff_w, buff_c)
        tvm.testing.assert_allclose(buff_c.numpy(), c_np, rtol=1e-4)
Example #19
def check_device(device, ctx):
    with tvm.target.Target(device):
        print("Running on target: %s" % device)
        conv2d_compute, conv2d_schedule = tvm.topi.testing.get_conv2d_nchw_implement(device)
        data = te.placeholder((2, 1, 2, 4), "int8", "data")
        w = te.placeholder((3, 1, 2, 2), "int8", "w")
        conv1 = conv2d_compute(data, w, 1, 0, 1, "int32")
        zeros = topi.full((2, 3, 1, 3), "int32", tvm.tir.const(0, dtype="int32"))
        gt = topi.greater_equal(conv1, zeros)
        one = topi.full((2, 3, 1, 3), "int32", tvm.tir.const(1, dtype="int32"))
        two = topi.full((2, 3, 1, 3), "int32", tvm.tir.const(2, dtype="int32"))
        where = topi.where(gt, one, two)
        add = topi.add(conv1, where)
        outs = [add]
        s = conv2d_schedule(outs)
        tvm.build(s, [data, w, add], target=backend)
Example #20
    def check_device(device):
        dev = tvm.device(device, 0)
        if not tvm.testing.device_enabled(device):
            print("Skipping %s becuase it is not enabled" % device)
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            C = topi.nn.conv2d(A,
                               W,
                               stride,
                               padding,
                               dilation,
                               layout="NCHW",
                               out_dtype=dtype)
            if add_bias:
                C = topi.add(C, bias)
            if add_relu:
                C = topi.nn.relu(C)
            s = topi.generic.schedule_conv2d_nchw([C])

        a = tvm.nd.array(a_np, dev)
        w = tvm.nd.array(w_np, dev)
        b = tvm.nd.array(b_np, dev)
        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype),
                         dev)
        if add_bias:
            func = tvm.build(
                s,
                [A, W, bias, C],
                device,
                name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
                (batch, in_channel, in_size, num_filter, kernel, stride,
                 padding, dilation),
            )
            func(a, w, b, c)
        else:
            func = tvm.build(
                s,
                [A, W, C],
                device,
                name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
                (batch, in_channel, in_size, num_filter, kernel, stride,
                 padding, dilation),
            )
            func(a, w, c)
        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-4)
Example #21
def test_tiled_added_poolings(tile_shape):
    A = te.placeholder((1, 12, 12, 16), name="A", dtype="int8")
    B = te.placeholder((1, 14, 14, 16), name="B", dtype="int8")
    pool_a = topi.nn.pool2d(A, (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC")
    pool_b = topi.nn.pool2d(B, (5, 5), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC")
    add = topi.add(pool_a, pool_b)
    pool_c = topi.nn.pool2d(add, (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC")

    sch = tvm.te.create_schedule([pool_c.op])
    oi, ii = _tile_nd(sch, pool_c, tile_shape)
    sch[add].compute_at(sch[pool_c], oi[-1])
    sch[add].rolling_buffer()
    sch[pool_b].compute_at(sch[pool_c], oi[-1])
    sch[pool_b].rolling_buffer()
    sch[pool_a].compute_at(sch[pool_c], oi[-1])
    sch[pool_a].rolling_buffer()

    _verify_schedule(sch, [A, B], pool_c)
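The _tile_nd helper is not part of the snippet; a minimal sketch, assuming it splits each axis of the output tensor by the given tile shape and reorders all outer axes before the inner ones:

def _tile_nd(s, tensor, tile_shape):
    # assumed helper: split every axis by its tile size, then reorder
    outer_indices = []
    inner_indices = []
    for i, size in enumerate(tile_shape):
        outer, inner = s[tensor].split(tensor.op.axis[i], size)
        outer_indices.append(outer)
        inner_indices.append(inner)
    s[tensor].reorder(*outer_indices, *inner_indices)
    return outer_indices, inner_indices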
Example #22
def test_bfloat_add_and_cast_FloatImm():
    X = te.placeholder((3, ), name="X")
    Z = topi.cast(topi.add(topi.cast(X, dtype="custom[bfloat]16"),
                           tvm.tir.FloatImm("custom[bfloat]16", 1.5)),
                  dtype="float")

    s = te.create_schedule([Z.op])
    built_cast = lower_datatypes_and_build(s, [X, Z])

    ctx = tvm.context(tgt, 0)

    x = tvm.nd.array(np.array([0.0, 1.0, 1.5]).astype("float32"), ctx=ctx)
    z_expected = np.array([1.5, 2.5, 3.0]).astype("float32")
    z = tvm.nd.empty(Z.shape, dtype=Z.dtype, ctx=ctx)

    built_cast(x, z)

    assert np.array_equal(z_expected, z.asnumpy())
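Using "custom[bfloat]16" requires the custom datatype to be registered beforehand, together with lowering rules for the Cast and Add operations (via tvm.target.datatype.register_op); only the registration call is sketched here:

tvm.target.datatype.register("bfloat", 129)  # custom datatype type codes start at 129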
Example #23
    def check_device(device):
        ctx = tvm.context(device, 0)
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            if bgemm == "direct":
                fcompute, fschedule = tvm.topi.testing.dispatch(
                    device, _conv2d_nhwc_winograd_direct)
            elif bgemm == "tensorcore":
                fcompute, fschedule = tvm.topi.testing.dispatch(
                    device, _conv2d_nhwc_winograd_tensorcore)
            C = fcompute(A, W, stride, padding, dilation, "float32")
            if add_bias:
                C = topi.add(C, bias)
            if add_relu:
                C = topi.nn.relu(C)
            s = fschedule([C])

        a = tvm.nd.array(a_np, ctx)
        w = tvm.nd.array(w_np, ctx)
        b = tvm.nd.array(b_np, ctx)
        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype),
                         ctx)
        if add_bias:
            func = tvm.build(
                s,
                [A, W, bias, C],
                device,
                name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
                (batch, in_channel, in_size, num_filter, kernel, stride,
                 padding_sum, dilation),
            )
            func(a, w, b, c)
        else:
            func = tvm.build(
                s,
                [A, W, C],
                device,
                name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
                (batch, in_channel, in_size, num_filter, kernel, stride,
                 padding_sum, dilation),
            )
            func(a, w, c)

        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=2e-3)
Example #24
    def check_device(device):
        dev = tvm.device(device, 0)
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            fcompute, fschedule = tvm.topi.testing.dispatch(
                device, _conv3d_ndhwc_tensorcore_implement)
            C = fcompute(A, W, stride, padding, dilation, 1, "float16")
            if add_bias:
                C = topi.add(C, bias)
            if add_relu:
                C = topi.nn.relu(C)
            s = fschedule([C])

        a = tvm.nd.array(a_np, dev)
        w = tvm.nd.array(w_np, dev)
        b = tvm.nd.array(b_np, dev)
        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype),
                         dev)
        if add_bias:
            func = tvm.build(
                s,
                [A, W, bias, C],
                device,
                name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
                (batch, in_channel, in_size, num_filter, kernel, stride,
                 padding_sum, dilation),
            )
            func(a, w, b, c)
        else:
            func = tvm.build(
                s,
                [A, W, C],
                device,
                name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
                (batch, in_channel, in_size, num_filter, kernel, stride,
                 padding_sum, dilation),
            )
            func(a, w, c)

        # Tensor Cores are quite inaccurate: with large shapes the accumulation
        # error grows, especially for values far from 1. atol is effectively
        # disabled here because it would otherwise need to be very large.
        tvm.testing.assert_allclose(c.numpy(), c_np, atol=1e200, rtol=0.01)
Example #25
def batch_norm_bwd(N, C, H, W, dtype="float32"):
    dshape = (N, C, H, W)
    oshape = (C, )
    bshape = (1, C, 1, 1)
    sshape = (1, )
    data = te.placeholder(dshape, name="data", dtype=dtype)
    scale = te.placeholder(oshape, name="scale", dtype=dtype)
    saved_mean = te.placeholder(oshape, name="saved_mean", dtype=dtype)
    saved_var = te.placeholder(oshape, name="saved_var", dtype=dtype)
    eps = te.placeholder(sshape, name="eps", dtype=dtype)
    grad_output = te.placeholder(dshape, name="data", dtype=dtype)

    axis = (0, 2, 3)
    num_ele = dshape[0] * dshape[2] * dshape[3]
    frac_num_ele = 1.0 / num_ele
    # compute grad_input
    mean_sum = topi.sum(data, axis, True)
    mean = topi.multiply(mean_sum, frac_num_ele)
    var_sub = topi.subtract(data, mean)
    var_mul = topi.multiply(var_sub, var_sub)
    var_sum = topi.sum(var_mul, axis, True)
    var = topi.multiply(var_sum, frac_num_ele)
    var_eps = topi.add(var, eps)
    output_sqrt = topi.sqrt(var_eps)
    x_norm = topi.subtract(data, mean)
    x_hat = topi.divide(x_norm, output_sqrt)
    dx_hat = topi.multiply(grad_output, topi.reshape(scale, bshape))
    grad_input_sum1 = topi.sum(dx_hat * x_hat, axis, True)
    grad_input_sum2 = topi.sum(dx_hat, axis, True)
    grad_input_left = topi.divide(frac_num_ele, topi.sqrt(var_eps))
    grad_input_right1 = topi.subtract(topi.multiply(dx_hat, num_ele),
                                      grad_input_sum2)
    grad_input_right2 = topi.multiply(x_hat, grad_input_sum1)
    grad_input = topi.multiply(
        grad_input_left, topi.subtract(grad_input_right1, grad_input_right2))
    # compute grad_scale and grad_bias
    grad_scale = topi.sum(grad_output * x_hat, axis)
    grad_bias = topi.sum(grad_output, axis)

    return [
        data, scale, saved_mean, saved_var, eps, grad_output, grad_input,
        grad_scale, grad_bias
    ]
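As with the forward pass above, a minimal sketch for scheduling and building the backward graph:

tensors = batch_norm_bwd(2, 4, 8, 8)
grads = tensors[-3:]  # grad_input, grad_scale, grad_bias
s = te.create_schedule([t.op for t in grads])
mod = tvm.build(s, tensors, target="llvm")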
Example #26
def check_device(device):
    with tvm.target.create(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        conv2d_compute, conv2d_schedule = tvm.topi.testing.get_conv2d_nchw_implement(device)
        data = te.placeholder((2, 1, 2, 4), 'int8', 'data')
        w = te.placeholder((3, 1, 2, 2), 'int8', 'w')
        conv1 = conv2d_compute(data, w, 1, 0, 1, 'int32')
        zeros = topi.full((2, 3, 1, 3), 'int32', tvm.tir.const(0, dtype='int32'))
        gt = topi.greater_equal(conv1, zeros)
        one = topi.full((2, 3, 1, 3), 'int32', tvm.tir.const(1, dtype='int32'))
        two = topi.full((2, 3, 1, 3), 'int32', tvm.tir.const(2, dtype='int32'))
        where = topi.where(gt, one, two)
        add = topi.add(conv1, where)
        outs = [add]
        s = conv2d_schedule(outs)
        tvm.build(s, [data, w, add], target=backend)
Example #27
    def check_device(device):
        dev = tvm.device(device, 0)
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            fcompute, fschedule = tvm.topi.testing.dispatch(
                device, _conv3d_ndhwc_tensorcore_implement)
            C = fcompute(A, W, stride, padding, dilation, "float32")
            if add_bias:
                C = topi.add(C, bias)
            if add_relu:
                C = topi.nn.relu(C)
            s = fschedule([C])

        a = tvm.nd.array(a_np, dev)
        w = tvm.nd.array(w_np, dev)
        b = tvm.nd.array(b_np, dev)
        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype),
                         dev)
        if add_bias:
            func = tvm.build(
                s,
                [A, W, bias, C],
                device,
                name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
                (batch, in_channel, in_size, num_filter, kernel, stride,
                 padding_sum, dilation),
            )
            func(a, w, b, c)
        else:
            func = tvm.build(
                s,
                [A, W, C],
                device,
                name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
                (batch, in_channel, in_size, num_filter, kernel, stride,
                 padding_sum, dilation),
            )
            func(a, w, c)

        rtol = 1e-3
        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=rtol)
Example #28
batch = 1
in_channel = 256
in_size = 32
num_filter = 256
kernel = 3
stride = 1
padding = "SAME"
dilation = 1

A = te.placeholder((in_size, in_size, in_channel, batch), name="A")
W = te.placeholder((kernel, kernel, in_channel, num_filter), name="W")
B = te.placeholder((1, num_filter, 1), name="bias")

with tvm.target.Target("llvm"):
    t_conv = topi.nn.conv2d_hwcn(A, W, stride, padding, dilation)
    t_bias = topi.add(t_conv, B)
    t_relu = topi.nn.relu(t_bias)
    s = topi.generic.schedule_conv2d_hwcn([t_relu])

######################################################################
# Render Graphs with TEDD
# -----------------------
# We render graphs to see the computation
# and how it is scheduled.
# If you run the tutorial in a Jupyter notebook, you can use the following
# commented lines to render SVG figures directly in the notebook.
#

tedd.viz_dataflow_graph(s, dot_file_path="/tmp/dfg.dot")
# tedd.viz_dataflow_graph(s, show_svg = True)
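TEDD provides two more views of the same schedule; a sketch of the analogous calls (the schedule-tree view is usually rendered after normalizing the schedule):

tedd.viz_itervar_relationship_graph(s, dot_file_path="/tmp/itervar.dot")
# tedd.viz_itervar_relationship_graph(s, show_svg = True)

s = s.normalize()
tedd.viz_schedule_tree(s, dot_file_path="/tmp/scheduletree.dot")
# tedd.viz_schedule_tree(s, show_svg = True)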
Example #29
def run_group_conv2d(env,
                     remote,
                     wl,
                     target,
                     check_correctness=True,
                     print_ir=False,
                     samples=4):

    # Workload assertions
    assert wl.hpad == wl.wpad

    # Perform packing only if we are targeting the accelerator
    if "arm_cpu" in target.keys:
        data_pack = False
        layout = "NCHW"
        fcompute = topi.nn.group_conv2d_nchw
        fschedule = topi.generic.schedule_group_conv2d_nchw
    elif "vta" in target.keys:
        data_pack = True
        layout = "NCHW%dn%dc" % (env.BATCH, env.BLOCK_IN)
        fcompute = vta.top.group_conv2d_packed
        fschedule = vta.top.schedule_group_conv2d_packed

    # Derive shapes depending upon packing
    CI_G = wl.in_filter // wl.groups
    a_shape = (wl.batch, wl.in_filter, wl.height, wl.width)
    w_shape = (wl.out_filter, CI_G, wl.hkernel, wl.wkernel)
    b_shape = (wl.batch, wl.out_filter, 1, 1)
    if data_pack:
        data_shape = (wl.batch // env.BATCH, wl.in_filter // env.BLOCK_IN,
                      wl.height, wl.width, env.BATCH, env.BLOCK_IN)
        kernel_shape = (wl.out_filter // env.BLOCK_OUT, CI_G // env.BLOCK_IN,
                        wl.hkernel, wl.wkernel, env.BLOCK_OUT, env.BLOCK_IN)
        bias_shape = (wl.batch // env.BATCH, wl.out_filter // env.BLOCK_OUT, 1,
                      1, env.BATCH, env.BLOCK_OUT)
    else:
        data_shape = a_shape
        kernel_shape = w_shape
        bias_shape = b_shape
    data = te.placeholder(data_shape, name="data", dtype=env.inp_dtype)
    kernel = te.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
    bias = te.placeholder(bias_shape, name="bias", dtype=env.acc_dtype)
    padding = relay.nn.get_pad_tuple2d((wl.hpad, wl.wpad))

    # Define base computation schedule
    with target:
        res = fcompute(data, kernel, (wl.hstride, wl.wstride), padding, (1, 1),
                       wl.groups, env.acc_dtype)
        res = topi.right_shift(res, 8)
        res = topi.add(res, bias)
        res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
        res = topi.cast(res, env.out_dtype)
        # Derive base schedule
        s = fschedule([res])
        if print_ir:
            print(vta.lower(s, [data, kernel, bias, res], simple_mode=True))

    # Derive number of ops
    fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1
    fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1
    num_ops = 2 * wl.batch * fout_height * fout_width * wl.hkernel * wl.wkernel * \
        wl.out_filter * wl.in_filter // wl.groups

    def get_ref_data():
        # derive min max for act, wgt, and bias types (max non inclusive)
        a_min, a_max = 0 - (1 << (env.INP_WIDTH - 1)), (1 <<
                                                        (env.INP_WIDTH - 1))
        w_min, w_max = 0 - (1 << (env.WGT_WIDTH - 1)), (1 <<
                                                        (env.WGT_WIDTH - 1))
        b_min, b_max = 0 - 1 << (env.INP_WIDTH + env.WGT_WIDTH -
                                 2), 1 << (env.INP_WIDTH + env.WGT_WIDTH - 2)
        a_np = np.random.randint(a_min, a_max, size=a_shape).astype(data.dtype)
        w_np = np.random.randint(w_min, w_max,
                                 size=w_shape).astype(kernel.dtype)
        b_np = np.random.randint(b_min, b_max,
                                 size=b_shape).astype(env.acc_dtype)
        r_np = tvm.topi.testing.conv2d_nchw_python(
            a_np.astype(env.acc_dtype), w_np.astype(env.acc_dtype),
            (wl.hstride, wl.wstride), wl.hpad, wl.groups).astype(env.acc_dtype)
        return a_np, w_np, b_np, r_np

    # Data in original format
    data_np, kernel_np, bias_np, res_ref = get_ref_data()
    if data_pack:
        data_np = data_np.reshape(wl.batch // env.BATCH, env.BATCH,
                                  wl.in_filter // env.BLOCK_IN, env.BLOCK_IN,
                                  wl.height, wl.width).transpose(
                                      (0, 2, 4, 5, 1, 3))
        kernel_np = kernel_np.reshape(wl.out_filter // env.BLOCK_OUT,
                                      env.BLOCK_OUT, CI_G // env.BLOCK_IN,
                                      env.BLOCK_IN, wl.hkernel,
                                      wl.wkernel).transpose((0, 2, 4, 5, 1, 3))
        bias_np = bias_np.reshape(wl.batch // env.BATCH,
                                  wl.out_filter // env.BLOCK_OUT, 1, 1,
                                  env.BATCH, env.BLOCK_OUT)

    # Build
    if "vta" in target.keys:
        mod = vta.build(s, [data, kernel, bias, res],
                        target=target,
                        target_host=env.target_host,
                        name="conv2d")
    else:
        mod = tvm.build(s, [data, kernel, bias, res],
                        target=target,
                        target_host=env.target_host,
                        name="conv2d")
    temp = util.tempdir()
    mod.save(temp.relpath("conv2d.o"))
    remote.upload(temp.relpath("conv2d.o"))
    f = remote.load_module("conv2d.o")
    ctx = remote.context(str(target))

    res_np = np.zeros(topi.util.get_const_tuple(res.shape)).astype(res.dtype)
    data_arr = tvm.nd.array(data_np, ctx)
    kernel_arr = tvm.nd.array(kernel_np, ctx)
    bias_arr = tvm.nd.array(bias_np, ctx)
    res_arr = tvm.nd.array(res_np, ctx)
    time_f = f.time_evaluator("conv2d", ctx, number=samples)

    # In vta sim mode, collect simulator runtime statistics
    stats = {}
    cost = None
    if env.TARGET in ["sim", "tsim"]:
        # Check if we're in local RPC mode (allows us to rebuild the
        # runtime on the fly when varying the VTA designs)
        local_rpc = int(os.environ.get("VTA_LOCAL_SIM_RPC", "0"))
        if local_rpc:
            if env.TARGET == "sim":
                remote.get_function("vta.simulator.profiler_clear")()
            else:
                remote.get_function("vta.tsim.profiler_clear")()
            cost = time_f(data_arr, kernel_arr, bias_arr, res_arr)
            if env.TARGET == "sim":
                stats = json.loads(
                    remote.get_function("vta.simulator.profiler_status")())
            else:
                stats = json.loads(
                    remote.get_function("vta.tsim.profiler_status")())
        else:
            simulator.clear_stats()
            cost = time_f(data_arr, kernel_arr, bias_arr, res_arr)
            stats = simulator.stats()
    else:
        cost = time_f(data_arr, kernel_arr, bias_arr, res_arr)

    # Check correctness
    correct = False
    if check_correctness:
        res_orig = res_arr.asnumpy()
        if data_pack:
            res_orig = res_orig.transpose(
                (0, 4, 1, 5, 2, 3)).reshape(wl.batch, wl.out_filter,
                                            fout_height, fout_width)
            bias_np = bias_np.transpose(
                (0, 4, 1, 5, 2, 3)).reshape(wl.batch, wl.out_filter, 1, 1)
        res_ref = res_ref >> env.WGT_WIDTH
        res_ref += bias_np
        res_ref = np.clip(res_ref, 0, (1 << env.OUT_WIDTH - 1) - 1)
        res_ref = res_ref.astype(env.out_dtype)
        correct = np.allclose(res_orig, res_ref)

    gops = (num_ops / cost.mean) / float(10**9)
    status = "PASSED" if correct else "FAILED"
    if "arm_cpu" in target.keys:
        device = "CPU"
    elif "vta" in target.keys:
        device = "VTA"
    print("%s GROUP CONV2D TEST %s: Time cost = %g sec/op, %g GOPS" %
          (device, status, cost.mean, gops))

    return correct, cost, stats
Example #30
def test_cache_read_write():
    N, H, W, CO, CI, KH, KW, strides, padding = 4, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)

    data = te.placeholder((N, CI, H, W), name="Data")
    kernel_data = te.placeholder((CO, CI, KH, KW), name="Kernel_data")
    k0, k1 = te.compute(
        kernel_data.shape,
        lambda *i: (kernel_data(*i) + 1, kernel_data(*i) / 2),
        name="Kernel_split",
    )
    kernel = te.compute(kernel_data.shape, lambda *i: k0(*i) + k1(*i), name="Kernel")
    conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation=1)
    relu = topi.nn.relu(conv)
    add = topi.add(data, relu)

    dag = auto_scheduler.ComputeDAG([data, kernel_data, add])
    s0 = dag.get_init_state()

    pad_temp = s0.stage_ops[1]
    kernel_split = s0.stage_ops[3]

    # 0: init state
    ori_its = s0[add].iters
    its = s0.split(add, s0[add].iters[0], [2])
    s0.reorder(add, [its[0], ori_its[1], its[1], ori_its[2], ori_its[3]])
    s0.compute_inline(relu)

    # 1: simple cache_write with compute_at
    conv_global = s0.cache_write(conv, "global")
    s0.compute_at(conv_global, conv, s0[conv].iters[3])

    # 2: simple cache_read with compute_at
    kernel_global = s0.cache_read(kernel, "global", [conv_global])
    s0.compute_at(kernel_global, conv_global, s0[conv_global].iters[4])
    """
        Placeholder: Data, Kernel_data
        for i0 (0,4)
          for i1 (0,512)
            for i2 (0,9)
              for i3 (0,9)
                pad_temp = ...
        for i0 (0,512)
          for i1 (0,512)
            for i2 (0,3)
              for i3 (0,3)
                Kernel_split = ...
        for i0 (0,512)
          for i1 (0,512)
            for i2 (0,3)
              for i3 (0,3)
                Kernel = ...
        for nn (0,4)
          for ff (0,512)
            for yy (0,7)
              for xx (0,7)
                for nn_c (None)
                  for ff_c (None)
                    for yy_c (None)
                      for xx_c (None)
                        for rc (None)
                          for ax0 (None)
                            for ax1 (None)
                              for ax2 (None)
                                for ax3 (None)
                                  Kernel.global = ...
                          for ry (None)
                            for rx (None)
                              compute.global = ...
                compute = ...
        for ax0.0 (0,2)
          for ax1 (0,512)
            for ax0.1 (0,2)
              for ax2 (0,7)
                for ax3 (0,7)
                  T_add = ...
    """
    s1 = dag.infer_bound_from_state(s0)
    assert s1[conv].iters[0].range.extent == 4
    assert s1[conv].iters[1].range.extent == 512
    assert s1[conv].iters[2].range.extent == 7
    assert s1[conv].iters[3].range.extent == 7
    assert s1[kernel_global].iters[0].range.extent == 1
    assert s1[kernel_global].iters[1].range.extent == 1
    assert s1[kernel_global].iters[2].range.extent == 3
    assert s1[kernel_global].iters[3].range.extent == 3
    assert s1[conv_global].iters[0].range.extent == 1
    assert s1[conv_global].iters[1].range.extent == 1
    assert s1[conv_global].iters[2].range.extent == 1
    assert s1[conv_global].iters[3].range.extent == 1
    assert s1[conv_global].iters[4].range.extent == 512
    assert s1[conv_global].iters[5].range.extent == 3
    assert s1[conv_global].iters[6].range.extent == 3

    # 3: two level cache_read with compute_at
    #    preparing for GPU's shared memory & local memory
    pad_temp_global = s0.cache_read(pad_temp, "global", [conv_global])
    pad_temp_shared = s0.cache_read(pad_temp_global, "shared", [conv_global])
    s0.compute_at(pad_temp_global, conv_global, s0[conv_global].iters[2])
    s0.compute_at(pad_temp_shared, conv_global, s0[conv_global].iters[4])

    # 4: cache_read with multi readers
    #    This stage cannot be compute at to its consumer
    s0.cache_read(data, "global", [pad_temp, add])
    """
        Placeholder: Data, Kernel_data
        for ax0 (0,4)
          for ax1 (0,512)
            for ax2 (0,7)
              for ax3 (0,7)
                Data.global = ...
        for i0 (0,4)
          for i1 (0,512)
            for i2 (0,9)
              for i3 (0,9)
                pad_temp = ...
        for i0 (0,512)
          for i1 (0,512)
            for i2 (0,3)
              for i3 (0,3)
                Kernel_split = ...
        for i0 (0,512)
          for i1 (0,512)
            for i2 (0,3)
              for i3 (0,3)
                Kernel = ...
        for nn (0,4)
          for ff (0,512)
            for yy (0,7)
              for xx (0,7)
                for nn_c (None)
                  for ff_c (None)
                    for yy_c (None)
                      for ax0 (None)
                        for ax1 (None)
                          for ax2 (None)
                            for ax3 (None)
                              pad_temp.global = ...
                      for xx_c (None)
                        for rc (None)
                          for ax0 (None)
                            for ax1 (None)
                              for ax2 (None)
                                for ax3 (None)
                                  Kernel.global = ...
                          for ax0 (None)
                            for ax1 (None)
                              for ax2 (None)
                                for ax3 (None)
                                  pad_temp.global.shared = ...
                          for ry (None)
                            for rx (None)
                              compute.global = ...
                compute = ...
        for ax0.0 (0,2)
          for ax1 (0,512)
            for ax0.1 (0,2)
              for ax2 (0,7)
                for ax3 (0,7)
                  T_add = ...
    """
    s1 = dag.infer_bound_from_state(s0)
    assert s1[conv].iters[0].range.extent == 4
    assert s1[conv].iters[1].range.extent == 512
    assert s1[conv].iters[2].range.extent == 7
    assert s1[conv].iters[3].range.extent == 7
    assert s1[kernel_global].iters[0].range.extent == 1
    assert s1[kernel_global].iters[1].range.extent == 1
    assert s1[kernel_global].iters[2].range.extent == 3
    assert s1[kernel_global].iters[3].range.extent == 3
    assert s1[conv_global].iters[0].range.extent == 1
    assert s1[conv_global].iters[1].range.extent == 1
    assert s1[conv_global].iters[2].range.extent == 1
    assert s1[conv_global].iters[3].range.extent == 1
    assert s1[conv_global].iters[4].range.extent == 512
    assert s1[conv_global].iters[5].range.extent == 3
    assert s1[conv_global].iters[6].range.extent == 3
    assert s1[pad_temp_global].iters[0].range.extent == 1
    assert s1[pad_temp_global].iters[1].range.extent == 512
    assert s1[pad_temp_global].iters[2].range.extent == 3
    assert s1[pad_temp_global].iters[3].range.extent == 3
    assert s1[pad_temp_shared].iters[0].range.extent == 1
    assert s1[pad_temp_shared].iters[1].range.extent == 1
    assert s1[pad_temp_shared].iters[2].range.extent == 3
    assert s1[pad_temp_shared].iters[3].range.extent == 3

    # 5: cache_write with multi outputs
    # TVM's cache_write has a bug with this case:
    #
    # After schedule.cache_write, TVM generates one new stage:
    #   From: kernel_data -> kernel_split -> kernel
    #   To:   kernel_data -> kernel_split_global -> kernel_split -> kernel
    #
    # But with topo-sort analysis, we get:
    #  //   kernel_data -> kernel_split_global -> kernel_split -> kernel
    #         \                                                /
    #          ----------------> kernel_split ---------------->
    #
    # TODO(jcf94): Seems there's a bug with the input/output tensors. Such a
    # multi-output case should be unusual, so we apply a hack in DoCacheWrite.
    # This should be fixed later.
    kernel_split_global = s0.cache_write(kernel_split, "global")
    """
        Placeholder: Data, Kernel_data
        for ax0 (0,4)
          for ax1 (0,512)
            for ax2 (0,7)
              for ax3 (0,7)
                Data.global = ...
        for i0 (0,4)
          for i1 (0,512)
            for i2 (0,9)
              for i3 (0,9)
                pad_temp = ...
        for i0_c (0,512)
          for i1_c (0,512)
            for i2_c (0,3)
              for i3_c (0,3)
                Kernel_split.global = ...
        for i0 (0,512)
          for i1 (0,512)
            for i2 (0,3)
              for i3 (0,3)
                Kernel_split = ...
        (******* Bug here, there should not be two kernel_split stage *******)
        for i0 (0,512)
          for i1 (0,512)
            for i2 (0,3)
              for i3 (0,3)
                Kernel_split = ...
        (******* Bug here, there should not be two kernel_split stage *******)
        for i0 (0,512)
          for i1 (0,512)
            for i2 (0,3)
              for i3 (0,3)
                Kernel = ...
        for nn (0,4)
          for ff (0,512)
            for yy (0,7)
              for xx (0,7)
                for nn_c (None)
                  for ff_c (None)
                    for yy_c (None)
                      for ax0 (None)
                        for ax1 (None)
                          for ax2 (None)
                            for ax3 (None)
                              pad_temp.global = ...
                      for xx_c (None)
                        for rc (None)
                          for ax0 (None)
                            for ax1 (None)
                              for ax2 (None)
                                for ax3 (None)
                                  Kernel.global = ...
                          for ax0 (None)
                            for ax1 (None)
                              for ax2 (None)
                                for ax3 (None)
                                  pad_temp.global.shared = ...
                          for ry (None)
                            for rx (None)
                              compute.global = ...
                compute = ...
        for ax0.0 (0,2)
          for ax1 (0,512)
            for ax0.1 (0,2)
              for ax2 (0,7)
                for ax3 (0,7)
                  T_add = ...
    """
    assert len(s0[kernel_split].iters) == len(s0[kernel_split_global].iters)
    for it0, it1 in zip(s0[kernel_split].iters, s0[kernel_split_global].iters):
        assert it0.range == it1.range