def get_eliminated():
    """Lower a bf16 add whose intermediate sums round-trip through bf16.

    Each operand of the outer add is ``to32(to16(to32(a) + to32(b)))``
    (``to32``/``to16`` presumably cast to float32/bfloat16), and the final
    sum is passed through ``to16`` once more. The schedule is lowered with
    ``tvm.tir.transform.BF16CastElimination`` and the lowered statement is
    returned.
    """
    lhs = te.placeholder((100,), dtype="bfloat16")
    rhs = te.placeholder((100,), dtype="bfloat16")

    def _narrowed_sum(i):
        # Wide add of both inputs, then to16 followed by to32 again.
        return to32(to16(topi.add(to32(lhs[i]), to32(rhs[i]))))

    result = te.compute(
        (100,),
        lambda i: to16(topi.add(_narrowed_sum(i), _narrowed_sum(i))),
    )
    schedule = te.create_schedule(result.op)
    return lower_stmt(schedule, [lhs, rhs, result], tvm.tir.transform.BF16CastElimination)
def batch_norm_fwd(N, C, H, W, dtype="float32"):
    """Construct the TE graph for a training-mode batch-norm forward pass.

    Statistics are reduced over axes (0, 2, 3) of the ``(N, C, H, W)`` input.

    Returns
    -------
    list
        ``[data, scale, bias, running_mean, running_var, momentum, eps,
        output, saved_mean, saved_rvars, running_mean_out, running_var_out]``
        -- seven input placeholders followed by five computed tensors.
    """
    dshape = (N, C, H, W)
    chan_shape = (C,)
    bcast_shape = (1, C, 1, 1)
    scalar_shape = (1,)

    data = te.placeholder(dshape, name="data", dtype=dtype)
    scale, bias, running_mean, running_var = (
        te.placeholder(chan_shape, name=nm, dtype=dtype)
        for nm in ("scale", "bias", "running_mean", "running_var")
    )
    eps, momentum = (
        te.placeholder(scalar_shape, name=nm, dtype=dtype) for nm in ("eps", "momentum")
    )

    reduce_axes = (0, 2, 3)
    inv_count = 1.0 / (N * H * W)

    # Per-channel batch mean, kept 4-D so it broadcasts against ``data``.
    batch_mean = topi.multiply(topi.sum(data, reduce_axes, keepdims=True), inv_count)
    # Biased batch variance: mean of squared deviations from the batch mean.
    centered = topi.subtract(data, batch_mean)
    batch_var = topi.multiply(
        topi.sum(topi.multiply(centered, centered), reduce_axes, keepdims=True), inv_count
    )
    # Despite the returned name "saved_rvars", this is sqrt(var + eps).
    std = topi.sqrt(topi.add(batch_var, eps))

    normalized = topi.divide(topi.subtract(data, batch_mean), std)
    output = topi.add(
        topi.multiply(normalized, topi.reshape(scale, bcast_shape)),
        topi.reshape(bias, bcast_shape),
    )

    def _blend(running, batch_stat):
        # running * (1 - momentum) + batch_stat * momentum
        kept = topi.multiply(running, topi.subtract(1.0, momentum))
        fresh = topi.multiply(topi.reshape(batch_stat, chan_shape), momentum)
        return topi.add(kept, fresh)

    return [
        data, scale, bias, running_mean, running_var, momentum, eps,
        output,
        topi.reshape(batch_mean, chan_shape),
        topi.reshape(std, chan_shape),
        _blend(running_mean, batch_mean),
        _blend(running_var, batch_var),
    ]
def get_target():
    """Return the body of ``main`` for a bf16 add with no intermediate round trips.

    Both wide sums ``to32(a) + to32(b)`` feed the outer add directly, and only
    the final result goes through ``to16``. Presumably this is the expected
    output of the bf16 cast-elimination pass; confirm against the caller.
    """
    x = te.placeholder((100,), dtype="bfloat16")
    y = te.placeholder((100,), dtype="bfloat16")

    def _wide_sum(i):
        return topi.add(to32(x[i]), to32(y[i]))

    z = te.compute((100,), lambda i: to16(topi.add(_wide_sum(i), _wide_sum(i))))
    mod = tvm.driver.build_module.form_irmodule(
        te.create_schedule(z.op), [x, y, z], "main", None
    )
    return mod["main"].body
def check_target(target):
    """Build and check conv2d_hwcn, conv+bias and conv+bias+relu on ``target``.

    Prints a skip message and returns when the target is not enabled.
    Otherwise each of the three stages gets its own schedule and module, and
    the outputs are compared with ``c1_np``, ``c2_np`` and ``c3_np``.
    """
    dev = tvm.device(target, 0)
    if not tvm.testing.device_enabled(target):
        print("Skip because %s is not enabled" % target)
        return
    print("Running on target: %s" % target)

    with tvm.target.Target(target):
        fcompute, fschedule = tvm.topi.testing.dispatch(target, _conv2d_hwcn_implement)
        conv = fcompute(A, W, stride, padding, dilation)
        biased = topi.add(conv, B)
        activated = topi.nn.relu(biased)
        schedules = [fschedule([t]) for t in (conv, biased, activated)]

    a = tvm.nd.array(a_np, dev)
    w = tvm.nd.array(w_np, dev)
    b = tvm.nd.array(b_np, dev)
    outputs = [
        tvm.nd.array(np.zeros(get_const_tuple(t.shape), dtype=t.dtype), dev)
        for t in (conv, biased, activated)
    ]

    signatures = ([A, W, conv], [A, W, B, biased], [A, W, B, activated])
    funcs = [tvm.build(sch, sig, target) for sch, sig in zip(schedules, signatures)]

    funcs[0](a, w, outputs[0])
    funcs[1](a, w, b, outputs[1])
    funcs[2](a, w, b, outputs[2])

    for got, expected in zip(outputs, (c1_np, c2_np, c3_np)):
        tvm.testing.assert_allclose(got.asnumpy(), expected, rtol=1e-5)
def group_conv2d(N, CI, H, W, CO, KH, KW, strides, padding, dilation, group):
    """Create a schedule for a packed grouped conv2d workload on VTA.

    The computation is ``cast(clip((group_conv2d >> WGT_WIDTH) + bias))``.
    Every tensor uses VTA's blocked layout derived from ``env``, and the
    convolution is split into ``group`` groups.

    Returns
    -------
    tuple
        ``(schedule, [data, kernel, bias, res])``. The schedule is the generic
        grouped-conv2d schedule when the current target's device name is
        ``"vta"``, and a default TE schedule otherwise.
    """
    # Input channels per group. This must come from the ``group`` argument,
    # not from any same-named variable in an enclosing scope.
    CI_G = CI // group
    data_shape = (N // env.BATCH, CI // env.BLOCK_IN, H, W, env.BATCH, env.BLOCK_IN)
    kernel_shape = (CO // env.BLOCK_OUT, CI_G // env.BLOCK_IN, KH, KW, env.BLOCK_OUT, env.BLOCK_IN)
    bias_shape = (N // env.BATCH, CO // env.BLOCK_OUT, 1, 1, env.BATCH, env.BLOCK_OUT)
    data = te.placeholder(data_shape, name="data", dtype=env.inp_dtype)
    kernel = te.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
    bias = te.placeholder(bias_shape, name="bias", dtype=env.acc_dtype)

    with tvm.target.vta():
        res = topi.nn.group_conv2d_nchw(data, kernel, strides, padding, dilation, group, env.acc_dtype)
        res = topi.right_shift(res, env.WGT_WIDTH)
        res = topi.add(res, bias)
        # ``-`` binds tighter than ``<<``: the bound is (1 << (OUT_WIDTH - 1)) - 1.
        res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
        res = topi.cast(res, env.out_dtype)

    if tvm.target.Target.current().device_name == "vta":
        s = topi.generic.schedule_group_conv2d_nchw([res])
    else:
        s = te.create_schedule([res.op])

    return s, [data, kernel, bias, res]
def check_device(device, ctx):
    """Build conv3d (NCDHW), with optional bias and ReLU, for ``device`` and check it.

    The output is compared with ``c_np`` at ``rtol=1e-4``.
    """
    print("Running on target: %s" % device)
    fcompute, fschedule = tvm.topi.testing.dispatch(device, _conv3d_ncdhw_implement)
    with tvm.target.Target(device):
        out = fcompute(A, W, (stride,) * 3, padding, (dilation,) * 3, dtype)
        if add_bias:
            out = topi.add(out, bias)
        if add_relu:
            out = topi.nn.relu(out)
        sched = fschedule([out])

    a = tvm.nd.array(a_np, ctx)
    w = tvm.nd.array(w_np, ctx)
    b = tvm.nd.array(b_np, ctx)
    c = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=out.dtype), ctx)

    kernel_name = "relu_%d_%d_%d_%d_%d_%d_%d_%d" % (
        batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation
    )
    # The bias tensor is part of the signature only when it was added.
    if add_bias:
        params, buffers = [A, W, bias, out], (a, w, b, c)
    else:
        params, buffers = [A, W, out], (a, w, c)
    tvm.build(sched, params, device, name=kernel_name)(*buffers)
    tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-4)
def check_device(device):
    """Run the tensor-core NHWC conv2d, with optional bias and ReLU, and compare with ``c_np``.

    Skips with a message when the device is disabled, or when
    ``nvcc.have_tensorcore`` reports no Tensor Core support.
    """
    ctx = tvm.context(device, 0)
    if not tvm.testing.device_enabled(device):
        print("Skip because %s is not enabled" % device)
        return
    # Only query the compute version once the device is known to be enabled.
    if not nvcc.have_tensorcore(ctx.compute_version):
        print("skip because gpu does not support Tensor Cores")
        return
    print("Running on target: %s" % device)

    with tvm.target.create(device):
        fcompute, fschedule = tvm.topi.testing.dispatch(device, _conv2d_nhwc_tensorcore_implement)
        out = fcompute(A, W, stride, padding, dilation, 'float32')
        if add_bias:
            out = topi.add(out, bias)
        if add_relu:
            out = topi.nn.relu(out)
        sched = fschedule([out])

    a = tvm.nd.array(a_np, ctx)
    w = tvm.nd.array(w_np, ctx)
    b = tvm.nd.array(b_np, ctx)
    c = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=out.dtype), ctx)

    kernel_name = "relu_%d_%d_%d_%d_%d_%d_%d_%d" % (
        batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
    params = [A, W, bias, out] if add_bias else [A, W, out]
    buffers = (a, w, b, c) if add_bias else (a, w, c)
    tvm.build(sched, params, device, name=kernel_name)(*buffers)

    tolerance = 1e-3
    tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=tolerance)
def check_device(device):
    """Run the CUDA grouped int8 NCHWc conv2d, with optional bias and ReLU, and compare with ``c_np``.

    Skips when the device is disabled, or when it is ``"cuda"`` and
    ``have_int8`` reports no int8 intrinsics.
    """
    ctx = tvm.context(device, 0)
    if not tvm.testing.device_enabled(device):
        print("Skip because %s is not enabled" % device)
        return
    if device == "cuda" and not tvm.contrib.nvcc.have_int8(ctx.compute_version):
        print("Skip because int8 intrinsics are not available")
        return
    print("Running on target: %s" % device)

    with tvm.target.Target(device):
        out = topi.cuda.group_conv2d_NCHWc_int8(A, W, stride, padding, dilation, groups, dtype)
        if add_bias:
            out = topi.add(out, bias)
        if add_relu:
            out = topi.nn.relu(out)
        sched = topi.cuda.schedule_group_conv2d_NCHWc_int8([out])

    a = tvm.nd.array(a_np, ctx)
    w = tvm.nd.array(w_np, ctx)
    b = tvm.nd.array(b_np, ctx)
    c = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=out.dtype), ctx)

    kernel_name = "relu_%d_%d_%d_%d_%d_%d_%d_%d_%d" % (
        batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups)
    if add_bias:
        params, buffers = [A, W, bias, out], (a, w, b, c)
    else:
        params, buffers = [A, W, out], (a, w, c)
    tvm.build(sched, params, device, name=kernel_name)(*buffers)
    tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
def conv2d(N, CI, H, W, CO, KH, KW, strides, padding, dilation):
    """Create a schedule for a packed conv2d workload on VTA.

    Builds ``cast(clip((conv2d >> WGT_WIDTH) + bias))`` with every tensor in
    VTA's blocked layout from ``env``. The generic conv2d schedule is used
    when the current target's device name is ``"vta"``, and a default TE
    schedule otherwise.

    Returns ``(schedule, [data, kernel, bias, res])``.
    """
    batch_blocks = N // env.BATCH
    data = te.placeholder(
        (batch_blocks, CI // env.BLOCK_IN, H, W, env.BATCH, env.BLOCK_IN),
        name="data", dtype=env.inp_dtype)
    kernel = te.placeholder(
        (CO // env.BLOCK_OUT, CI // env.BLOCK_IN, KH, KW, env.BLOCK_OUT, env.BLOCK_IN),
        name="kernel", dtype=env.wgt_dtype)
    bias = te.placeholder(
        (batch_blocks, CO // env.BLOCK_OUT, 1, 1, env.BATCH, env.BLOCK_OUT),
        name="bias", dtype=env.acc_dtype)

    with tvm.target.vta():
        packed_layout = 'NCHW%dn%dc' % (env.BATCH, env.BLOCK_IN)
        res = topi.nn.conv2d(input=data, filter=kernel, padding=padding, strides=strides,
                             dilation=dilation, layout=packed_layout, out_dtype=env.acc_dtype)
        res = topi.right_shift(res, env.WGT_WIDTH)
        res = topi.add(res, bias)
        # Upper clip bound: the largest signed OUT_WIDTH-bit value.
        res = my_clip(res, 0, (1 << (env.OUT_WIDTH - 1)) - 1)
        res = topi.cast(res, env.out_dtype)

    on_vta = tvm.target.Target.current().device_name == 'vta'
    s = topi.generic.schedule_conv2d_nchw([res]) if on_vta else te.create_schedule([res.op])
    return s, [data, kernel, bias, res]
def check_device(device):
    """Run grouped NCHW conv2d, with optional bias and ReLU, on ``device`` and compare with ``c_np``.

    Prints a skip message and returns when the device is not enabled.
    """
    ctx = tvm.context(device, 0)
    if not tvm.testing.device_enabled(device):
        print("Skip because %s is not enabled" % device)
        return
    print("Running on target: %s" % device)

    with tvm.target.Target(device):
        fcompute, fschedule = tvm.topi.testing.dispatch(device, _group_conv2d_nchw_implement)
        out = fcompute(A, W, stride, padding, dilation, groups, dtype)
        if add_bias:
            out = topi.add(out, bias)
        if add_relu:
            out = topi.nn.relu(out)
        sched = fschedule([out])

    a = tvm.nd.array(a_np, ctx)
    w = tvm.nd.array(w_np, ctx)
    b = tvm.nd.array(b_np, ctx)
    c = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=out.dtype), ctx)

    kernel_name = "relu_%d_%d_%d_%d_%d_%d_%d_%d_%d" % (
        batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups)
    params = [A, W, bias, out] if add_bias else [A, W, out]
    buffers = (a, w, b, c) if add_bias else (a, w, c)
    tvm.build(sched, params, device, name=kernel_name)(*buffers)
    tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
def compile_conv2d_NHWC_gemm_int8_arm(batch, in_channel, in_size, num_filter, kernel, stride, padding,
                                      dilation=1, add_bias=False, add_relu=False):
    """Compile (without running) the quantized int8 NHWC conv2d for AArch64.

    Builds ``compute_conv2d_NHWC_quantized``, optionally followed by a bias
    add and a ReLU, for ``llvm --device arm_cpu --mtriple aarch64-linux-gnu``.
    Prints a skip message and returns early when that context does not exist.
    Always returns None.
    """
    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
    padding_sum = pad_top + pad_left + pad_bottom + pad_right
    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter,
                                                          kernel, stride, padding_sum, dilation))

    in_height = in_width = in_size
    A = te.placeholder((batch, in_height, in_width, in_channel), name='A', dtype='int8')
    W = te.placeholder((kernel, kernel, in_channel, num_filter), name='W', dtype='int8')
    bias = te.placeholder((num_filter,), name='bias', dtype='int8')
    dtype = 'int32'
    device = "llvm --device arm_cpu --mtriple aarch64-linux-gnu"

    ctx = tvm.context(device, 0)
    if not ctx.exist:
        print("Skip because %s is not enabled" % device)
        return
    print("Compiling on arm AArch64 target: %s" % device)
    with tvm.target.create(device):
        assert is_aarch64_arm(), "AArch64 target not recognized"

        C = topi.arm_cpu.compute_conv2d_NHWC_quantized(A, W, (stride, stride), padding,
                                                       (dilation, dilation), dtype)
        if add_bias:
            C = topi.add(C, bias)
        if add_relu:
            C = topi.nn.relu(C)
        s = topi.arm_cpu.schedule_conv2d_NHWC_quantized([C])

    # Compile exactly once. Building the identical module a second time in
    # the bias case only doubled the compile time.
    args = [A, W, bias, C] if add_bias else [A, W, C]
    tvm.build(s, args, device,
              name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter,
                                                     kernel, stride, padding_sum, dilation))
def check_target(target):
    """Run NCHW conv2d, with optional bias and ReLU, on ``target`` and compare with ``c_np``.

    Targets whose string contains ``"cudnn"`` use the cuDNN compute/schedule
    pair. Other targets use ``tvm.topi.testing.get_conv2d_nchw_implement``.
    On LLVM targets ``verify_workload_padding`` is also called.
    """
    dev = tvm.device(target, 0)
    if not tvm.testing.device_enabled(target):
        print("Skip because %s is not enabled" % target)
        return
    print("Running on target: %s" % target)

    use_cudnn = "cudnn" in target
    if use_cudnn:
        fcompute, fschedule = topi.cuda.conv2d_cudnn, topi.cuda.schedule_conv2d_cudnn
    else:
        fcompute, fschedule = tvm.topi.testing.get_conv2d_nchw_implement(target)

    strides, dilations = (stride, stride), (dilation, dilation)
    with tvm.target.Target(target):
        if use_cudnn:
            # The cuDNN compute takes an extra ``1`` (presumably groups) and the layout.
            out = fcompute(A, W, strides, padding, dilations, 1, "NCHW", dtype)
        else:
            out = fcompute(A, W, strides, padding, dilations, dtype)
        if add_bias:
            out = topi.add(out, bias)
        if add_relu:
            out = topi.nn.relu(out)
        sched = fschedule([out])

        if "llvm" in target:
            verify_workload_padding()

    a = tvm.nd.array(a_np, dev)
    w = tvm.nd.array(w_np, dev)
    b = tvm.nd.array(b_np, dev)
    c = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=out.dtype), dev)

    kernel_name = "relu_%d_%d_%d_%d_%d_%d_%d_%d" % (
        batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
    if add_bias:
        params, buffers = [A, W, bias, out], (a, w, b, c)
    else:
        params, buffers = [A, W, out], (a, w, c)
    tvm.build(sched, params, target, name=kernel_name)(*buffers)
    tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-4)
def check_target(target):
    """Run the CUDA int8 NCHWc conv2d, with optional bias and ReLU, and compare with ``c_np``.

    Skips, printing a message, when ``target`` is disabled, or when it is
    ``"cuda"`` and ``have_int8`` reports no int8 intrinsics.
    """
    dev = tvm.device(target, 0)
    if not tvm.testing.device_enabled(target):
        print("Skip because %s is not enabled" % target)
        return
    if target == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version):
        print("Skip because int8 intrinsics are not available")
        return

    print("Running on target: %s" % target)
    with tvm.target.Target(target):
        C = topi.cuda.conv2d_NCHWc_int8(A, W, (stride, stride), padding, (dilation, dilation), "NCHW", dtype)
        if add_bias:
            C = topi.add(C, bias)
        if add_relu:
            C = topi.nn.relu(C)
        s = topi.cuda.schedule_conv2d_NCHWc_int8([C])

    a = tvm.nd.array(a_np, dev)
    w = tvm.nd.array(w_np, dev)
    b = tvm.nd.array(b_np, dev)
    c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)

    func_name = "relu_%d_%d_%d_%d_%d_%d_%d_%d" % (
        batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
    # One build per variant: compiling the identical bias module twice and
    # discarding the first result only wasted compile time.
    if add_bias:
        func = tvm.build(s, [A, W, bias, C], target, name=func_name)
        func(a, w, b, c)
    else:
        func = tvm.build(s, [A, W, C], target, name=func_name)
        func(a, w, c)
    tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
def test_schedule_cache_reads():
    """Check that ``schedule_cache_reads`` tags a cache_read stage with an ethosu_copy pragma.

    After the pass, the cache stage must have a single leaf loop whose only
    pragma is ``op = "ethosu_copy"``, and the schedule must have 4 stages.
    """
    lhs = te.placeholder((12, 12), dtype="uint8", name="a")
    rhs = te.placeholder((12, 12), dtype="uint8", name="b")
    total = topi.add(lhs, rhs)
    sch = te.create_schedule([total.op])
    cached = sch.cache_read(rhs, "global", [total])

    schedule_cache_reads(sch)

    assert len(sch.stages) == 4
    cache_stage = sch[cached]
    assert len(cache_stage.leaf_iter_vars) == 1
    only_iv = cache_stage.leaf_iter_vars[0]
    attrs = cache_stage.iter_var_attrs[only_iv]
    assert list(attrs.pragma_keys) == ["op"]
    assert list(attrs.pragma_values) == ["ethosu_copy"]
def test_complex_reduce():
    """Check a sum reduction fused with elementwise ops on every enabled target.

    Computes ``E = (B + B) + B * B`` with ``B = A.sum(axis=0)`` for a (2, 3)
    float32 input, and compares the result with NumPy at rtol = atol = 1e-3.
    """
    in_shape = (2, 3)
    dtype = "float32"
    axis = 0
    keepdims = False
    A = te.placeholder(shape=in_shape, name="A", dtype=dtype)
    B = topi.sum(A, axis=axis, keepdims=keepdims)
    E = topi.add(topi.add(B, B), topi.multiply(B, B))

    for device, ctx in tvm.testing.enabled_targets():
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            sched = tvm.topi.testing.get_reduce_schedule(device)(E)
        reduce_fn = tvm.build(sched, [A, E], device, name="sum")

        in_npy = np.random.uniform(-1, 1, size=in_shape).astype(dtype)
        summed = in_npy.sum(axis=axis, keepdims=keepdims)
        expected = summed * 2 + summed * summed

        data_tvm = tvm.nd.array(in_npy, ctx=ctx)
        out_tvm = tvm.nd.empty(shape=expected.shape, ctx=ctx, dtype=dtype)
        reduce_fn(data_tvm, out_tvm)
        tvm.testing.assert_allclose(out_tvm.asnumpy(), expected, 1e-3, 1e-3)
def check_device(device):
    """Run the arm_cpu quantized NHWC conv2d, with optional bias and ReLU, and compare with ``c_np``.

    Prints a skip message and returns when the device context does not exist.
    """
    ctx = tvm.context(device, 0)
    if not ctx.exist:
        print("Skip because %s is not enabled" % device)
        return
    print("Running on target: %s" % device)
    with tvm.target.create(device):
        C = topi.arm_cpu.compute_conv2d_NHWC_quantized(A, W, (stride, stride), padding,
                                                       (dilation, dilation), dtype)
        if add_bias:
            C = topi.add(C, bias)
        if add_relu:
            C = topi.nn.relu(C)
        s = topi.arm_cpu.schedule_conv2d_NHWC_quantized([C])

    a = tvm.nd.array(a_np, ctx)
    w = tvm.nd.array(w_np, ctx)
    b = tvm.nd.array(b_np, ctx)
    c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)

    func_name = "relu_%d_%d_%d_%d_%d_%d_%d_%d" % (
        batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
    # One build per variant: compiling the identical bias module twice and
    # discarding the first result only wasted compile time.
    if add_bias:
        func = tvm.build(s, [A, W, bias, C], device, name=func_name)
        func(a, w, b, c)
    else:
        func = tvm.build(s, [A, W, C], device, name=func_name)
        func(a, w, c)
    tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
def check_device(device):
    """Run the x86 NCHWc conv2d, with optional bias and ReLU, on ``device`` and compare with ``c_np``.

    Input and output layouts are ``NCHW{ic_block}c`` and ``NCHW{oc_block}c``.
    Prints a skip message and returns when the device is not enabled.
    """
    dev = tvm.device(device, 0)
    if not tvm.testing.device_enabled(device):
        print("Skip because %s is not enabled" % device)
        return
    print("Running on target: %s" % device)

    in_layout = "NCHW%dc" % ic_block
    out_layout = "NCHW%dc" % oc_block
    with tvm.target.Target(device):
        out = topi.x86.conv2d_NCHWc(
            A, W, (stride, stride), padding, (dilation, dilation), in_layout, out_layout, dtype
        )
        if add_bias:
            out = topi.add(out, bias)
        if add_relu:
            out = topi.nn.relu(out)
        sched = topi.x86.schedule_conv2d_NCHWc([out])

    a = tvm.nd.array(a_np, dev)
    w = tvm.nd.array(w_np, dev)
    b = tvm.nd.array(b_np, dev)
    c = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=out.dtype), dev)

    kernel_name = "relu_%d_%d_%d_%d_%d_%d_%d_%d" % (
        batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
    if add_bias:
        params, buffers = [A, W, bias, out], (a, w, b, c)
    else:
        params, buffers = [A, W, out], (a, w, c)
    tvm.build(sched, params, device, name=kernel_name)(*buffers)
    tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
def check_device(device):
    """Run NCHW conv2d, with optional bias and ReLU, on ``device`` and compare with ``c_np``.

    Prints a message and returns early when the device is not enabled.
    """
    dev = tvm.device(device, 0)
    if not tvm.testing.device_enabled(device):
        print("Skipping %s becuase it is not enabled" % device)
        # Stop here; continuing would allocate buffers and run on a device
        # that was just reported as unavailable.
        return
    print("Running on target: %s" % device)
    with tvm.target.Target(device):
        result_c = topi.nn.conv2d(
            placholder_a,
            placeholder_w,
            stride,
            padding,
            dilation,
            data_layout="NCHW",
            out_dtype=dtype,
        )
        if add_bias:
            result_c = topi.add(result_c, bias)
        if add_relu:
            result_c = topi.nn.relu(result_c)
        schedule = topi.generic.schedule_conv2d_nchw([result_c])

    buff_a = tvm.nd.array(a_np, dev)
    buff_w = tvm.nd.array(w_np, dev)
    buff_b = tvm.nd.array(b_np, dev)
    buff_c = tvm.nd.array(
        np.zeros(get_const_tuple(result_c.shape), dtype=result_c.dtype), dev)

    func_name = "relu_%d_%d_%d_%d_%d_%d_%d_%d" % (
        batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)
    if add_bias:
        func = tvm.build(
            schedule, [placholder_a, placeholder_w, bias, result_c], device, name=func_name)
        func(buff_a, buff_w, buff_b, buff_c)
    else:
        func = tvm.build(
            schedule, [placholder_a, placeholder_w, result_c], device, name=func_name)
        func(buff_a, buff_w, buff_c)
    tvm.testing.assert_allclose(buff_c.numpy(), c_np, rtol=1e-4)
def check_device(device, ctx):
    """Build an int8 conv2d followed by a ``where``-selected offset, without running it.

    A (2, 1, 2, 4) int8 input is convolved with a (3, 1, 2, 2) weight
    (stride 1, padding 0, dilation 1) into int32. Then 1 is added where the
    result is >= 0 and 2 elsewhere, using constants of shape (2, 3, 1, 3).
    The module is built for ``backend``. ``ctx`` is not used.
    """
    with tvm.target.Target(device):
        print("Running on target: %s" % device)
        compute_fn, schedule_fn = tvm.topi.testing.get_conv2d_nchw_implement(device)
        data = te.placeholder((2, 1, 2, 4), "int8", "data")
        w = te.placeholder((3, 1, 2, 2), "int8", "w")
        conv1 = compute_fn(data, w, 1, 0, 1, "int32")
        out_shape = (2, 3, 1, 3)

        def _filled(value):
            # int32 constant tensor shaped like the conv output.
            return topi.full(out_shape, "int32", tvm.tir.const(value, dtype="int32"))

        non_negative = topi.greater_equal(conv1, _filled(0))
        offset = topi.where(non_negative, _filled(1), _filled(2))
        add = topi.add(conv1, offset)
        s = schedule_fn([add])
        tvm.build(s, [data, w, add], target=backend)
def check_device(device):
    """Run NCHW conv2d, with optional bias and ReLU, on ``device`` and compare with ``c_np``.

    Prints a message and returns early when the device is not enabled.
    """
    dev = tvm.device(device, 0)
    if not tvm.testing.device_enabled(device):
        print("Skipping %s becuase it is not enabled" % device)
        # Stop here; continuing would allocate buffers and run on a device
        # that was just reported as unavailable.
        return
    print("Running on target: %s" % device)
    with tvm.target.Target(device):
        C = topi.nn.conv2d(A, W, stride, padding, dilation, layout="NCHW", out_dtype=dtype)
        if add_bias:
            C = topi.add(C, bias)
        if add_relu:
            C = topi.nn.relu(C)
        s = topi.generic.schedule_conv2d_nchw([C])

    a = tvm.nd.array(a_np, dev)
    w = tvm.nd.array(w_np, dev)
    b = tvm.nd.array(b_np, dev)
    c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)

    func_name = "relu_%d_%d_%d_%d_%d_%d_%d_%d" % (
        batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)
    if add_bias:
        func = tvm.build(s, [A, W, bias, C], device, name=func_name)
        func(a, w, b, c)
    else:
        func = tvm.build(s, [A, W, C], device, name=func_name)
        func(a, w, c)
    tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-4)
def test_tiled_added_poolings(tile_shape):
    """Tile a pool -> add -> pool chain and verify rolling buffers on every producer.

    A 3x3 max-pool of a 12x12 input and a 5x5 max-pool of a 14x14 input
    (stride 1, no padding) both yield 10x10 maps. These are added and pooled
    again with a 3x3 window. ``add``, ``pool_b`` and ``pool_a`` are computed
    at ``oi[-1]`` of ``pool_c`` (presumably the innermost outer tile loop
    from ``_tile_nd``) and marked as rolling buffers.
    """
    A = te.placeholder((1, 12, 12, 16), name="A", dtype="int8")
    # Give B its own name; it was copy-pasted as "A", making the two inputs
    # indistinguishable in lowered IR and error messages.
    B = te.placeholder((1, 14, 14, 16), name="B", dtype="int8")
    pool_a = topi.nn.pool2d(A, (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC")
    pool_b = topi.nn.pool2d(B, (5, 5), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC")
    add = topi.add(pool_a, pool_b)
    pool_c = topi.nn.pool2d(add, (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), "max", layout="NHWC")

    sch = tvm.te.create_schedule([pool_c.op])
    oi, ii = _tile_nd(sch, pool_c, tile_shape)
    sch[add].compute_at(sch[pool_c], oi[-1])
    sch[add].rolling_buffer()
    sch[pool_b].compute_at(sch[pool_c], oi[-1])
    sch[pool_b].rolling_buffer()
    sch[pool_a].compute_at(sch[pool_c], oi[-1])
    sch[pool_a].rolling_buffer()

    _verify_schedule(sch, [A, B], pool_c)
def test_bfloat_add_and_cast_FloatImm():
    """Add a custom-bfloat FloatImm to a cast input, then cast back to float.

    Inputs [0.0, 1.0, 1.5] plus the constant 1.5 must give exactly
    [1.5, 2.5, 3.0].
    """
    X = te.placeholder((3, ), name="X")
    as_bfloat = topi.cast(X, dtype="custom[bfloat]16")
    offset = tvm.tir.FloatImm("custom[bfloat]16", 1.5)
    Z = topi.cast(topi.add(as_bfloat, offset), dtype="float")
    built_cast = lower_datatypes_and_build(te.create_schedule([Z.op]), [X, Z])

    ctx = tvm.context(tgt, 0)
    x = tvm.nd.array(np.array([0.0, 1.0, 1.5]).astype("float32"), ctx=ctx)
    z = tvm.nd.empty(Z.shape, dtype=Z.dtype, ctx=ctx)
    built_cast(x, z)

    expected = np.array([1.5, 2.5, 3.0]).astype("float32")
    assert np.array_equal(expected, z.asnumpy())
def check_device(device):
    """Run the winograd NHWC conv2d, with optional bias and ReLU, and compare with ``c_np``.

    ``bgemm`` selects the implementation: ``"direct"`` or ``"tensorcore"``.
    NOTE(review): any other value leaves ``fcompute`` unbound.
    """
    ctx = tvm.context(device, 0)
    print("Running on target: %s" % device)
    with tvm.target.Target(device):
        if bgemm == "direct":
            fcompute, fschedule = tvm.topi.testing.dispatch(device, _conv2d_nhwc_winograd_direct)
        elif bgemm == "tensorcore":
            fcompute, fschedule = tvm.topi.testing.dispatch(device, _conv2d_nhwc_winograd_tensorcore)
        out = fcompute(A, W, stride, padding, dilation, "float32")
        if add_bias:
            out = topi.add(out, bias)
        if add_relu:
            out = topi.nn.relu(out)
        sched = fschedule([out])

    a = tvm.nd.array(a_np, ctx)
    w = tvm.nd.array(w_np, ctx)
    b = tvm.nd.array(b_np, ctx)
    c = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=out.dtype), ctx)

    kernel_name = "relu_%d_%d_%d_%d_%d_%d_%d_%d" % (
        batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
    params = [A, W, bias, out] if add_bias else [A, W, out]
    buffers = (a, w, b, c) if add_bias else (a, w, c)
    tvm.build(sched, params, device, name=kernel_name)(*buffers)
    tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=2e-3)
def check_device(device):
    """Run the tensor-core NDHWC conv3d (float16), with optional bias and ReLU, and compare with ``c_np``.

    Only ``rtol=0.01`` is effectively enforced. ``atol`` is set to 1e200,
    which disables it.
    """
    dev = tvm.device(device, 0)
    print("Running on target: %s" % device)
    with tvm.target.Target(device):
        fcompute, fschedule = tvm.topi.testing.dispatch(device, _conv3d_ndhwc_tensorcore_implement)
        out = fcompute(A, W, stride, padding, dilation, 1, "float16")
        if add_bias:
            out = topi.add(out, bias)
        if add_relu:
            out = topi.nn.relu(out)
        sched = fschedule([out])

    a = tvm.nd.array(a_np, dev)
    w = tvm.nd.array(w_np, dev)
    b = tvm.nd.array(b_np, dev)
    c = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=out.dtype), dev)

    kernel_name = "relu_%d_%d_%d_%d_%d_%d_%d_%d" % (
        batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
    if add_bias:
        params, buffers = [A, W, bias, out], (a, w, b, c)
    else:
        params, buffers = [A, W, out], (a, w, c)
    tvm.build(sched, params, device, name=kernel_name)(*buffers)

    # Tensor cores are imprecise: with large shapes the accumulated error is
    # high, especially for values far from 1, so atol is effectively
    # disabled and only the relative tolerance is checked.
    tvm.testing.assert_allclose(c.numpy(), c_np, atol=1e200, rtol=0.01)
def batch_norm_bwd(N, C, H, W, dtype="float32"):
    """Construct the TE graph for the batch-norm backward pass.

    Reductions run over axes (0, 2, 3) of the ``(N, C, H, W)`` input.

    Returns
    -------
    list
        ``[data, scale, saved_mean, saved_var, eps, grad_output,
        grad_input, grad_scale, grad_bias]`` -- six input placeholders
        followed by three computed gradients.

    Notes
    -----
    ``saved_mean`` and ``saved_var`` are part of the returned signature but
    are never read. The batch statistics are recomputed from ``data``.
    """
    dshape = (N, C, H, W)
    oshape = (C, )
    bshape = (1, C, 1, 1)
    sshape = (1, )
    data = te.placeholder(dshape, name="data", dtype=dtype)
    scale = te.placeholder(oshape, name="scale", dtype=dtype)
    saved_mean = te.placeholder(oshape, name="saved_mean", dtype=dtype)
    saved_var = te.placeholder(oshape, name="saved_var", dtype=dtype)
    eps = te.placeholder(sshape, name="eps", dtype=dtype)
    # Named distinctly from ``data`` so the two same-shaped inputs can be told
    # apart in lowered IR and error messages.
    grad_output = te.placeholder(dshape, name="grad_output", dtype=dtype)

    axis = (0, 2, 3)
    num_ele = dshape[0] * dshape[2] * dshape[3]
    frac_num_ele = 1.0 / num_ele

    # Recompute batch mean and biased variance from ``data``.
    mean_sum = topi.sum(data, axis, True)
    mean = topi.multiply(mean_sum, frac_num_ele)
    var_sub = topi.subtract(data, mean)
    var_mul = topi.multiply(var_sub, var_sub)
    var_sum = topi.sum(var_mul, axis, True)
    var = topi.multiply(var_sum, frac_num_ele)
    var_eps = topi.add(var, eps)
    output_sqrt = topi.sqrt(var_eps)

    # Normalized input and its upstream gradient.
    x_norm = topi.subtract(data, mean)
    x_hat = topi.divide(x_norm, output_sqrt)
    dx_hat = topi.multiply(grad_output, topi.reshape(scale, bshape))

    # grad_input = 1/(m*sqrt(var+eps)) * (m*dx_hat - sum(dx_hat) - x_hat*sum(dx_hat*x_hat))
    grad_input_sum1 = topi.sum(dx_hat * x_hat, axis, True)
    grad_input_sum2 = topi.sum(dx_hat, axis, True)
    # Reuse output_sqrt instead of building a second, identical sqrt op.
    grad_input_left = topi.divide(frac_num_ele, output_sqrt)
    grad_input_right1 = topi.subtract(topi.multiply(dx_hat, num_ele), grad_input_sum2)
    grad_input_right2 = topi.multiply(x_hat, grad_input_sum1)
    grad_input = topi.multiply(
        grad_input_left, topi.subtract(grad_input_right1, grad_input_right2))

    # Per-channel parameter gradients.
    grad_scale = topi.sum(grad_output * x_hat, axis)
    grad_bias = topi.sum(grad_output, axis)

    return [
        data, scale, saved_mean, saved_var, eps, grad_output, grad_input,
        grad_scale, grad_bias
    ]
def check_device(device):
    """Build an int8 conv2d followed by a ``where``-selected offset, without running it.

    Prints a skip message and returns when the device context does not exist.
    Otherwise a (2, 1, 2, 4) int8 input is convolved with a (3, 1, 2, 2)
    weight into int32. Then 1 is added where the result is >= 0 and 2
    elsewhere, using (2, 3, 1, 3) constants, and the module is built for
    ``backend``.
    """
    with tvm.target.create(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)

        compute_fn, schedule_fn = tvm.topi.testing.get_conv2d_nchw_implement(device)
        data = te.placeholder((2, 1, 2, 4), 'int8', 'data')
        w = te.placeholder((3, 1, 2, 2), 'int8', 'w')
        conv1 = compute_fn(data, w, 1, 0, 1, 'int32')
        out_shape = (2, 3, 1, 3)

        def _filled(value):
            # int32 constant tensor shaped like the conv output.
            return topi.full(out_shape, 'int32', tvm.tir.const(value, dtype='int32'))

        non_negative = topi.greater_equal(conv1, _filled(0))
        offset = topi.where(non_negative, _filled(1), _filled(2))
        add = topi.add(conv1, offset)
        s = schedule_fn([add])
        tvm.build(s, [data, w, add], target=backend)
def check_device(device):
    """Run the tensor-core NDHWC conv3d (float32), with optional bias and ReLU, and compare with ``c_np``.

    The comparison uses ``rtol=1e-3``.
    """
    dev = tvm.device(device, 0)
    print("Running on target: %s" % device)
    with tvm.target.Target(device):
        fcompute, fschedule = tvm.topi.testing.dispatch(device, _conv3d_ndhwc_tensorcore_implement)
        out = fcompute(A, W, stride, padding, dilation, "float32")
        if add_bias:
            out = topi.add(out, bias)
        if add_relu:
            out = topi.nn.relu(out)
        sched = fschedule([out])

    a = tvm.nd.array(a_np, dev)
    w = tvm.nd.array(w_np, dev)
    b = tvm.nd.array(b_np, dev)
    c = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=out.dtype), dev)

    kernel_name = "relu_%d_%d_%d_%d_%d_%d_%d_%d" % (
        batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
    params = [A, W, bias, out] if add_bias else [A, W, out]
    buffers = (a, w, b, c) if add_bias else (a, w, c)
    tvm.build(sched, params, device, name=kernel_name)(*buffers)

    tolerance = 1e-3
    tvm.testing.assert_allclose(c.numpy(), c_np, rtol=tolerance)
# Problem size for a single 3x3 conv2d layer.
batch = 1
in_channel = 256
in_size = 32
num_filter = 256
kernel = 3
stride = 1
padding = "SAME"
dilation = 1

# Input is (height, width, in_channel, batch) and the weight is
# (kernel_h, kernel_w, in_channel, num_filter), as used by conv2d_hwcn.
A = te.placeholder((in_size, in_size, in_channel, batch), name="A")
W = te.placeholder((kernel, kernel, in_channel, num_filter), name="W")
# NOTE(review): the (1, num_filter, 1) bias presumably broadcasts along the
# channel axis of the HWCN conv output -- confirm.
B = te.placeholder((1, num_filter, 1), name="bias")

# Build conv -> bias add -> ReLU under an LLVM target and schedule the
# final ReLU with the generic conv2d_hwcn schedule.
with tvm.target.Target("llvm"):
    t_conv = topi.nn.conv2d_hwcn(A, W, stride, padding, dilation)
    t_bias = topi.add(t_conv, B)
    t_relu = topi.nn.relu(t_bias)
    s = topi.generic.schedule_conv2d_hwcn([t_relu])

######################################################################
# Render Graphs with TEDD
# -----------------------
# We render graphs to see the computation
# and how it is scheduled.
# If you run the tutorial in a Jupyter notebook, you can uncomment the
# following lines to render the SVG figures directly in the notebook.
#

# tedd.viz_dataflow_graph(s, dot_file_path="/tmp/dfg.dot")
# tedd.viz_dataflow_graph(s, show_svg = True)
def run_group_conv2d(env, remote, wl, target, check_correctness=True, print_ir=False, samples=4):
    """Build, upload, run and time one grouped conv2d workload on a remote device.

    The device computation is: grouped conv2d, right shift by 8, bias add,
    clip to ``[0, (1 << (env.OUT_WIDTH - 1)) - 1]``, then cast to
    ``env.out_dtype``. For ``vta`` targets every tensor is packed into VTA's
    blocked layout. For ``arm_cpu`` targets plain NCHW is used.

    Parameters
    ----------
    env : VTA environment supplying bit widths, block sizes, dtypes and TARGET.
    remote : RPC session used to upload, load and time the compiled module.
    wl : workload with batch, in/out filters, height, width, kernel sizes,
        pads, strides and groups. ``hpad`` must equal ``wpad``.
    target : target whose ``keys`` contain ``"arm_cpu"`` or ``"vta"``.
        NOTE(review): any other target leaves ``data_pack``, ``fcompute``
        and ``device`` unbound, and the function fails with UnboundLocalError.
    check_correctness : compare the device output with a NumPy reference.
    print_ir : print the lowered IR of the schedule.
    samples : ``number`` argument passed to ``time_evaluator``.

    Returns
    -------
    (correct, cost, stats)
        ``correct`` is whether outputs matched (always False when
        ``check_correctness`` is off). ``cost`` is the time_evaluator result.
        ``stats`` holds simulator profiler statistics (an empty dict when
        ``env.TARGET`` is not sim/tsim).
    """

    # Workload assertions
    assert wl.hpad == wl.wpad

    # Perform packing only if we are targeting the accelerator
    if "arm_cpu" in target.keys:
        data_pack = False
        layout = "NCHW"
        fcompute = topi.nn.group_conv2d_nchw
        fschedule = topi.generic.schedule_group_conv2d_nchw
    elif "vta" in target.keys:
        data_pack = True
        layout = "NCHW%dn%dc" % (env.BATCH, env.BLOCK_IN)
        fcompute = vta.top.group_conv2d_packed
        fschedule = vta.top.schedule_group_conv2d_packed

    # Derive shapes depending upon packing
    CI_G = wl.in_filter // wl.groups  # input channels per group
    a_shape = (wl.batch, wl.in_filter, wl.height, wl.width)
    w_shape = (wl.out_filter, CI_G, wl.hkernel, wl.wkernel)
    b_shape = (wl.batch, wl.out_filter, 1, 1)
    if data_pack:
        data_shape = (wl.batch // env.BATCH, wl.in_filter // env.BLOCK_IN,
                      wl.height, wl.width, env.BATCH, env.BLOCK_IN)
        kernel_shape = (wl.out_filter // env.BLOCK_OUT, CI_G // env.BLOCK_IN,
                        wl.hkernel, wl.wkernel, env.BLOCK_OUT, env.BLOCK_IN)
        bias_shape = (wl.batch // env.BATCH, wl.out_filter // env.BLOCK_OUT,
                      1, 1, env.BATCH, env.BLOCK_OUT)
    else:
        data_shape = a_shape
        kernel_shape = w_shape
        bias_shape = b_shape
    data = te.placeholder(data_shape, name="data", dtype=env.inp_dtype)
    kernel = te.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
    bias = te.placeholder(bias_shape, name="bias", dtype=env.acc_dtype)
    padding = relay.nn.get_pad_tuple2d((wl.hpad, wl.wpad))

    # Define base computation schedule
    with target:
        res = fcompute(data, kernel, (wl.hstride, wl.wstride), padding, (1, 1),
                       wl.groups, env.acc_dtype)
        # NOTE(review): the device graph shifts by a literal 8 while the NumPy
        # reference below shifts by env.WGT_WIDTH; these agree only when
        # WGT_WIDTH == 8 -- confirm.
        res = topi.right_shift(res, 8)
        res = topi.add(res, bias)
        # ``-`` binds tighter than ``<<``: the bound is (1 << (OUT_WIDTH - 1)) - 1.
        res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
        res = topi.cast(res, env.out_dtype)
        # Derive base schedule
        s = fschedule([res])
        if print_ir:
            print(vta.lower(s, [data, kernel, bias, res], simple_mode=True))

    # Derive number of ops
    fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1
    fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1
    # 2 ops (multiply + add) per MAC; each group sees only in_filter/groups inputs.
    num_ops = 2 * wl.batch * fout_height * fout_width * wl.hkernel * wl.wkernel * \
        wl.out_filter * wl.in_filter // wl.groups

    def get_ref_data():
        """Return random NCHW inputs and the unshifted int reference conv output."""
        # derive min max for act, wgt, and bias types (max non inclusive)
        a_min, a_max = 0 - (1 << (env.INP_WIDTH - 1)), (1 << (env.INP_WIDTH - 1))
        w_min, w_max = 0 - (1 << (env.WGT_WIDTH - 1)), (1 << (env.WGT_WIDTH - 1))
        # ``0 - 1 << k`` parses as ``(0 - 1) << k``, i.e. -(1 << k).
        b_min, b_max = 0 - 1 << (env.INP_WIDTH + env.WGT_WIDTH - 2), 1 << (env.INP_WIDTH + env.WGT_WIDTH - 2)
        a_np = np.random.randint(a_min, a_max, size=a_shape).astype(data.dtype)
        w_np = np.random.randint(w_min, w_max, size=w_shape).astype(kernel.dtype)
        b_np = np.random.randint(b_min, b_max, size=b_shape).astype(env.acc_dtype)
        r_np = tvm.topi.testing.conv2d_nchw_python(
            a_np.astype(env.acc_dtype), w_np.astype(env.acc_dtype),
            (wl.hstride, wl.wstride), wl.hpad, wl.groups).astype(env.acc_dtype)
        return a_np, w_np, b_np, r_np

    # Data in original format
    data_np, kernel_np, bias_np, res_ref = get_ref_data()
    if data_pack:
        # Reorder NCHW host arrays into the same blocked layouts as the placeholders.
        data_np = data_np.reshape(
            wl.batch // env.BATCH, env.BATCH,
            wl.in_filter // env.BLOCK_IN, env.BLOCK_IN,
            wl.height, wl.width).transpose((0, 2, 4, 5, 1, 3))
        kernel_np = kernel_np.reshape(
            wl.out_filter // env.BLOCK_OUT, env.BLOCK_OUT,
            CI_G // env.BLOCK_IN, env.BLOCK_IN,
            wl.hkernel, wl.wkernel).transpose((0, 2, 4, 5, 1, 3))
        bias_np = bias_np.reshape(
            wl.batch // env.BATCH, wl.out_filter // env.BLOCK_OUT,
            1, 1, env.BATCH, env.BLOCK_OUT)

    # Build
    if "vta" in target.keys:
        mod = vta.build(s, [data, kernel, bias, res],
                        target=target,
                        target_host=env.target_host,
                        name="conv2d")
    else:
        mod = tvm.build(s, [data, kernel, bias, res],
                        target=target,
                        target_host=env.target_host,
                        name="conv2d")
    # Save locally, upload over RPC and load the module on the remote side.
    temp = util.tempdir()
    mod.save(temp.relpath("conv2d.o"))
    remote.upload(temp.relpath("conv2d.o"))
    f = remote.load_module("conv2d.o")
    ctx = remote.context(str(target))

    res_np = np.zeros(topi.util.get_const_tuple(res.shape)).astype(res.dtype)
    data_arr = tvm.nd.array(data_np, ctx)
    kernel_arr = tvm.nd.array(kernel_np, ctx)
    bias_arr = tvm.nd.array(bias_np, ctx)
    res_arr = tvm.nd.array(res_np, ctx)
    time_f = f.time_evaluator("conv2d", ctx, number=samples)

    # In vta sim mode, collect simulator runtime statistics
    stats = {}
    cost = None
    if env.TARGET in ["sim", "tsim"]:
        # Check if we're in local RPC mode (allows us to rebuild the
        # runtime on the fly when varying the VTA designs)
        local_rpc = int(os.environ.get("VTA_LOCAL_SIM_RPC", "0"))
        if local_rpc:
            if env.TARGET == "sim":
                remote.get_function("vta.simulator.profiler_clear")()
            else:
                remote.get_function("vta.tsim.profiler_clear")()
            cost = time_f(data_arr, kernel_arr, bias_arr, res_arr)
            if env.TARGET == "sim":
                stats = json.loads(
                    remote.get_function("vta.simulator.profiler_status")())
            else:
                stats = json.loads(
                    remote.get_function("vta.tsim.profiler_status")())
        else:
            simulator.clear_stats()
            cost = time_f(data_arr, kernel_arr, bias_arr, res_arr)
            stats = simulator.stats()
    else:
        cost = time_f(data_arr, kernel_arr, bias_arr, res_arr)

    # Check correctness
    correct = False
    if check_correctness:
        res_orig = res_arr.asnumpy()
        if data_pack:
            # Unpack blocked layouts back to NCHW for comparison.
            res_orig = res_orig.transpose(
                (0, 4, 1, 5, 2, 3)).reshape(wl.batch, wl.out_filter, fout_height, fout_width)
            bias_np = bias_np.transpose(
                (0, 4, 1, 5, 2, 3)).reshape(wl.batch, wl.out_filter, 1, 1)
        # Apply the same shift / bias / clip / cast on the host reference.
        res_ref = res_ref >> env.WGT_WIDTH
        res_ref += bias_np
        res_ref = np.clip(res_ref, 0, (1 << env.OUT_WIDTH - 1) - 1)
        res_ref = res_ref.astype(env.out_dtype)
        correct = np.allclose(res_orig, res_ref)

    gops = (num_ops / cost.mean) / float(10**9)
    status = "PASSED" if correct else "FAILED"
    if "arm_cpu" in target.keys:
        device = "CPU"
    elif "vta" in target.keys:
        device = "VTA"
    print("%s GROUP CONV2D TEST %s: Time cost = %g sec/op, %g GOPS" %
          (device, status, cost.mean, gops))

    return correct, cost, stats
def test_cache_read_write():
    """Exercise cache_read / cache_write steps on an auto_scheduler State.

    Builds a conv2d -> relu -> add compute DAG whose kernel comes from a
    two-output ``Kernel_split`` stage. It then applies split, reorder,
    compute_inline, compute_at, cache_read and cache_write steps to a single
    State. After the cache steps it checks the loop extents reported by
    ``ComputeDAG.infer_bound_from_state``. For the multi-output cache_write
    it checks that the original and cached stages have matching iterator
    ranges.
    """
    N, H, W, CO, CI, KH, KW, strides, padding = 4, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)

    def expect_extents(bounded, stage, extents):
        # Check that the leading iterators of `stage` in the bound-inferred
        # state `bounded` have exactly the listed extents, in order.
        stage_iters = bounded[stage].iters
        for pos, want in enumerate(extents):
            assert stage_iters[pos].range.extent == want

    data = te.placeholder((N, CI, H, W), name="Data")
    kernel_data = te.placeholder((CO, CI, KH, KW), name="Kernel_data")
    k0, k1 = te.compute(
        kernel_data.shape,
        lambda *i: (kernel_data(*i) + 1, kernel_data(*i) / 2),
        name="Kernel_split",
    )
    kernel = te.compute(kernel_data.shape, lambda *i: k0(*i) + k1(*i), name="Kernel")
    conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation=1)
    relu = topi.nn.relu(conv)
    add = topi.add(data, relu)

    dag = auto_scheduler.ComputeDAG([data, kernel_data, add])
    state = dag.get_init_state()
    # Two intermediate stages are selected by index in the initial state.
    # Going by the names used below, index 1 is pad_temp and index 3 is
    # Kernel_split.
    pad_temp = state.stage_ops[1]
    kernel_split = state.stage_ops[3]

    # 0: init state. Split the outer T_add axis with lengths [2], interleave
    # the split parts with the remaining original axes, and inline relu.
    orig_its = state[add].iters
    split_its = state.split(add, state[add].iters[0], [2])
    state.reorder(add, [split_its[0], orig_its[1], split_its[1], orig_its[2], orig_its[3]])
    state.compute_inline(relu)

    # 1: simple cache_write with compute_at
    conv_global = state.cache_write(conv, "global")
    state.compute_at(conv_global, conv, state[conv].iters[3])

    # 2: simple cache_read with compute_at
    kernel_global = state.cache_read(kernel, "global", [conv_global])
    state.compute_at(kernel_global, conv_global, state[conv_global].iters[4])

    # Loop structure after steps 0-2:
    #
    # Placeholder: Data, Kernel_data
    # for i0 (0,4)
    #   for i1 (0,512)
    #     for i2 (0,9)
    #       for i3 (0,9)
    #         pad_temp = ...
    # for i0 (0,512)
    #   for i1 (0,512)
    #     for i2 (0,3)
    #       for i3 (0,3)
    #         Kernel_split = ...
    # for i0 (0,512)
    #   for i1 (0,512)
    #     for i2 (0,3)
    #       for i3 (0,3)
    #         Kernel = ...
    # for nn (0,4)
    #   for ff (0,512)
    #     for yy (0,7)
    #       for xx (0,7)
    #         for nn_c (None)
    #           for ff_c (None)
    #             for yy_c (None)
    #               for xx_c (None)
    #                 for rc (None)
    #                   for ax0 (None)
    #                     for ax1 (None)
    #                       for ax2 (None)
    #                         for ax3 (None)
    #                           Kernel.global = ...
    #                   for ry (None)
    #                     for rx (None)
    #                       compute.global = ...
    #         compute = ...
    # for ax0.0 (0,2)
    #   for ax1 (0,512)
    #     for ax0.1 (0,2)
    #       for ax2 (0,7)
    #         for ax3 (0,7)
    #           T_add = ...
    bounded = dag.infer_bound_from_state(state)
    expect_extents(bounded, conv, [4, 512, 7, 7])
    expect_extents(bounded, kernel_global, [1, 1, 3, 3])
    expect_extents(bounded, conv_global, [1, 1, 1, 1, 512, 3, 3])

    # 3: two-level cache_read with compute_at, in preparation for GPU
    # shared memory and local memory.
    pad_temp_global = state.cache_read(pad_temp, "global", [conv_global])
    pad_temp_shared = state.cache_read(pad_temp_global, "shared", [conv_global])
    state.compute_at(pad_temp_global, conv_global, state[conv_global].iters[2])
    state.compute_at(pad_temp_shared, conv_global, state[conv_global].iters[4])

    # 4: cache_read with multiple readers. The new Data.global stage feeds
    # both pad_temp and T_add, so it cannot be computed at either consumer.
    state.cache_read(data, "global", [pad_temp, add])

    # Loop structure after steps 3-4:
    #
    # Placeholder: Data, Kernel_data
    # for ax0 (0,4)
    #   for ax1 (0,512)
    #     for ax2 (0,7)
    #       for ax3 (0,7)
    #         Data.global = ...
    # for i0 (0,4)
    #   for i1 (0,512)
    #     for i2 (0,9)
    #       for i3 (0,9)
    #         pad_temp = ...
    # for i0 (0,512)
    #   for i1 (0,512)
    #     for i2 (0,3)
    #       for i3 (0,3)
    #         Kernel_split = ...
    # for i0 (0,512)
    #   for i1 (0,512)
    #     for i2 (0,3)
    #       for i3 (0,3)
    #         Kernel = ...
    # for nn (0,4)
    #   for ff (0,512)
    #     for yy (0,7)
    #       for xx (0,7)
    #         for nn_c (None)
    #           for ff_c (None)
    #             for yy_c (None)
    #               for ax0 (None)
    #                 for ax1 (None)
    #                   for ax2 (None)
    #                     for ax3 (None)
    #                       pad_temp.global = ...
    #               for xx_c (None)
    #                 for rc (None)
    #                   for ax0 (None)
    #                     for ax1 (None)
    #                       for ax2 (None)
    #                         for ax3 (None)
    #                           Kernel.global = ...
    #                   for ax0 (None)
    #                     for ax1 (None)
    #                       for ax2 (None)
    #                         for ax3 (None)
    #                           pad_temp.global.shared = ...
    #                   for ry (None)
    #                     for rx (None)
    #                       compute.global = ...
    #         compute = ...
    # for ax0.0 (0,2)
    #   for ax1 (0,512)
    #     for ax0.1 (0,2)
    #       for ax2 (0,7)
    #         for ax3 (0,7)
    #           T_add = ...
    bounded = dag.infer_bound_from_state(state)
    expect_extents(bounded, conv, [4, 512, 7, 7])
    expect_extents(bounded, kernel_global, [1, 1, 3, 3])
    expect_extents(bounded, conv_global, [1, 1, 1, 1, 512, 3, 3])
    expect_extents(bounded, pad_temp_global, [1, 512, 3, 3])
    expect_extents(bounded, pad_temp_shared, [1, 1, 3, 3])

    # 5: cache_write on a stage with multiple outputs.
    # TVM's schedule.cache_write has a bug in this case. After cache_write,
    # TVM generates one new stage, so the chain
    #     kernel_data -> kernel_split -> kernel
    # becomes
    #     kernel_data -> kernel_split_global -> kernel_split -> kernel
    # A topological-sort analysis, however, additionally sees a second
    # kernel_split stage fed by kernel_data that also feeds kernel.
    #
    # TODO(jcf94): Seems there's bug with the input/output tensor. Such multi outputs case
    # should be unusual, so we make some hack on DoCacheWrite. This should be fixed later.
    kernel_split_global = state.cache_write(kernel_split, "global")

    # Loop structure after step 5 (top part shown):
    #
    # Placeholder: Data, Kernel_data
    # for ax0 (0,4)
    #   for ax1 (0,512)
    #     for ax2 (0,7)
    #       for ax3 (0,7)
    #         Data.global = ...
    # for i0 (0,4)
    #   for i1 (0,512)
    #     for i2 (0,9)
    #       for i3 (0,9)
    #         pad_temp = ...
    # for i0_c (0,512)
    #   for i1_c (0,512)
    #     for i2_c (0,3)
    #       for i3_c (0,3)
    #         Kernel_split.global = ...
    # for i0 (0,512)
    #   for i1 (0,512)
    #     for i2 (0,3)
    #       for i3 (0,3)
    #         Kernel_split = ...  (bug: there should not be two Kernel_split stages)
    # for i0 (0,512)
    #   for i1 (0,512)
    #     for i2 (0,3)
    #       for i3 (0,3)
    #         Kernel_split = ...  (bug: there should not be two Kernel_split stages)
    # Kernel and every later stage are unchanged from the dump after step 4.
    split_iters = state[kernel_split].iters
    split_global_iters = state[kernel_split_global].iters
    assert len(split_iters) == len(split_global_iters)
    for it_a, it_b in zip(split_iters, split_global_iters):
        assert it_a.range == it_b.range