def batch_norm_fwd(N, C, H, W, dtype="float32"):
    dshape = (N, C, H, W)
    oshape = (C,)
    bshape = (1, C, 1, 1)
    sshape = (1,)
    data = te.placeholder(dshape, name="data", dtype=dtype)
    scale = te.placeholder(oshape, name="scale", dtype=dtype)
    bias = te.placeholder(oshape, name="bias", dtype=dtype)
    running_mean = te.placeholder(oshape, name="running_mean", dtype=dtype)
    running_var = te.placeholder(oshape, name="running_var", dtype=dtype)
    eps = te.placeholder(sshape, name="eps", dtype=dtype)
    momentum = te.placeholder(sshape, name="momentum", dtype=dtype)

    axis = (0, 2, 3)
    num_ele = dshape[0] * dshape[2] * dshape[3]
    frac_num_ele = 1.0 / num_ele

    # compute batch mean
    mean_sum = topi.sum(data, axis, keepdims=True)
    saved_mean = topi.multiply(mean_sum, frac_num_ele)

    # compute batch rvars
    var_sub = topi.subtract(data, saved_mean)
    var_mul = topi.multiply(var_sub, var_sub)
    var_sum = topi.sum(var_mul, axis, keepdims=True)
    var = topi.multiply(var_sum, frac_num_ele)
    output_add = topi.add(var, eps)
    saved_rvars = topi.sqrt(output_add)

    # compute output
    output_sub = topi.subtract(data, saved_mean)
    output_norm = topi.divide(output_sub, saved_rvars)
    scale_board = topi.reshape(scale, bshape)
    bias_board = topi.reshape(bias, bshape)
    output = topi.add(topi.multiply(output_norm, scale_board), bias_board)

    # reshape saved_rvars
    saved_rvars = topi.reshape(saved_rvars, oshape)

    # update running mean
    running_mean_mul1 = topi.multiply(running_mean, topi.subtract(1.0, momentum))
    running_mean_mul2 = topi.multiply(topi.reshape(saved_mean, oshape), momentum)
    running_mean_out = topi.add(running_mean_mul1, running_mean_mul2)

    # update running var
    saved_var_mul1 = topi.multiply(running_var, topi.subtract(1.0, momentum))
    saved_var_mul2 = topi.multiply(topi.reshape(var, oshape), momentum)
    running_var_out = topi.add(saved_var_mul1, saved_var_mul2)

    # reshape saved_mean
    saved_mean = topi.reshape(saved_mean, oshape)

    return [
        data, scale, bias, running_mean, running_var, momentum, eps, output,
        saved_mean, saved_rvars, running_mean_out, running_var_out
    ]
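# A minimal NumPy sketch of the same forward normalization, useful as a reference when
# checking the TE graph above. The helper name `batch_norm_fwd_np` and its argument
# order are hypothetical and not part of the original code.
import numpy as np

def batch_norm_fwd_np(x, scale, bias, eps):
    # x: (N, C, H, W); scale, bias: (C,)
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = ((x - mean) ** 2).mean(axis=(0, 2, 3), keepdims=True)
    x_hat = (x - mean) / np.sqrt(var + eps)
    return x_hat * scale.reshape(1, -1, 1, 1) + bias.reshape(1, -1, 1, 1)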
def test_reduce_map(target, dev, ref_data, in_shape, axis, keepdims, reduce_type, dtype):
    target = tvm.target.Target(target)
    if target.kind.name == "vulkan" and reduce_type in ["sum", "any", "all"]:
        pytest.xfail(f"Vulkan backend has known errors on {reduce_type}")

    in_npy, in_npy_map, out_npy = ref_data

    # Build the logic and compile the function
    A = te.placeholder(shape=in_shape, name="A", dtype=dtype)
    A1 = topi.sqrt(topi.exp(A))
    out_dtype = dtype
    if reduce_type == "sum":
        B = topi.sum(A1, axis=axis, keepdims=keepdims)
    elif reduce_type == "all":
        B = topi.all(A, axis=axis, keepdims=keepdims)
    elif reduce_type == "any":
        B = topi.any(A, axis=axis, keepdims=keepdims)
    elif reduce_type == "max":
        B = topi.max(A1, axis=axis, keepdims=keepdims)
    elif reduce_type == "min":
        B = topi.min(A1, axis=axis, keepdims=keepdims)
    elif reduce_type == "argmax":
        B = topi.argmax(A1, axis=axis, keepdims=keepdims)
        out_dtype = "int32"
    elif reduce_type == "argmin":
        B = topi.argmin(A1, axis=axis, keepdims=keepdims)
        out_dtype = "int32"
    else:
        raise NotImplementedError

    with tvm.target.Target(target):
        s = tvm.topi.testing.get_reduce_schedule(target)(B)

    foo = tvm.build(s, [A, B], target, name=reduce_type)

    data_tvm = tvm.nd.array(in_npy, device=dev)
    out_tvm = tvm.nd.empty(shape=out_npy.shape, device=dev, dtype=out_dtype)
    foo(data_tvm, out_tvm)

    if reduce_type == "argmax" or reduce_type == "argmin":
        out_tvm_indices = out_tvm.numpy()
        if keepdims:
            out_tvm_indices = np.take(out_tvm_indices, indices=0, axis=axis)
        if axis is None:
            out_tvm_val = in_npy_map.ravel()[out_tvm_indices]
        else:
            other_indices = tuple(np.indices(in_shape[0:axis] + in_shape[(axis + 1):]))
            sel_indices = other_indices[0:axis] + (out_tvm_indices,) + other_indices[axis:]
            out_tvm_val = in_npy_map[sel_indices]
        if reduce_type == "argmax":
            tvm.testing.assert_allclose(out_tvm_val, in_npy_map.max(axis=axis), 1e-3, 1e-3)
        elif reduce_type == "argmin":
            tvm.testing.assert_allclose(out_tvm_val, in_npy_map.min(axis=axis), 1e-3, 1e-3)
    else:
        tvm.testing.assert_allclose(out_tvm.numpy(), out_npy, 1e-3, 1e-3)
def check(device, dtype, m=32, n=32):
    ctx = tvm.context(device, 0)
    if not ctx.exist or not tvm.runtime.enabled(device):
        print("skip because", device, "is not enabled..")
        return
    if dtype == "float16" and not have_fp16(ctx.compute_version):
        print("Skip because gpu does not have fp16 support")
        return

    a = te.placeholder((m, n), name="a", dtype=dtype)
    b = te.placeholder((m, n), name="b", dtype=dtype)
    c = a + b
    d = a * b
    e = topi.elemwise_sum([c, d])
    g = topi.sum(e)
    with tvm.target.create(device):
        sg = topi.cuda.schedule_reduce(g)
        func = tvm.build(sg, [a, b, g], device)

    a_np = np.random.uniform(size=(m, n)).astype(a.dtype)
    b_np = np.random.uniform(size=(m, n)).astype(b.dtype)
    g_np = np.sum(np.add(a_np * b_np, a_np + b_np))
    a_nd = tvm.nd.array(a_np, ctx)
    b_nd = tvm.nd.array(b_np, ctx)
    g_nd = tvm.nd.array(np.zeros(g_np.shape, dtype=g_np.dtype), ctx)
    func(a_nd, b_nd, g_nd)
    tvm.testing.assert_allclose(g_nd.asnumpy(), g_np, rtol=1e-3)
def conv_bwd(N, CI, HI, WI, CO, HO, WO, KSIZE, stride, padding, dtype):
    strides = (stride, stride)
    shape_data = (N, CI, HI, WI)
    shape_weight = (CO, CI, KSIZE, KSIZE)
    shape_grad_output = (N, CO, HO, WO)

    # given tensors
    data = te.placeholder(shape_data, name="data", dtype=dtype)
    weight = te.placeholder(shape_weight, name="weight", dtype=dtype)
    grad_output = te.placeholder(shape_grad_output, name="grad_output", dtype=dtype)

    # grad_data
    out_h = (HO - 1) * strides[0] - 2 * padding + KSIZE
    out_w = (WO - 1) * strides[1] - 2 * padding + KSIZE
    output_padding = (HI - out_h, WI - out_w)
    grad_data = topi.nn.conv2d_transpose_nchw(grad_output, weight, strides, padding,
                                              dtype, output_padding)

    # grad_weight
    dilation_h, dilation_w = (1, 1)
    batch, in_channel, in_h, in_w = shape_data
    out_channel, _, filter_h, filter_w = shape_weight
    grad_output_tmp = topi.tile(grad_output, [1, in_channel, 1, 1])
    grad_output_tmp = topi.reshape(grad_output_tmp,
                                   [batch * in_channel * out_channel, 1, HO, WO])
    data_tmp = topi.reshape(data, [1, in_channel * batch, HI, WI])
    grad_weight = topi.nn.group_conv2d_nchw(data_tmp,
                                            grad_output_tmp,
                                            stride=(dilation_h, dilation_w),
                                            padding=padding,
                                            dilation=strides,
                                            groups=in_channel * batch,
                                            out_dtype=dtype)

    # infer shape of grad_weight
    _, _, grad_h, grad_w = shape_grad_output
    fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(padding, (filter_h, filter_w))
    padded_weight_grad_h = (in_h - (grad_h - 1) * strides[0] - 1 + fpad_top + fpad_bottom) // dilation_h + 1
    padded_weight_grad_w = (in_w - (grad_w - 1) * strides[1] - 1 + fpad_left + fpad_right) // dilation_w + 1
    grad_weight = topi.reshape(grad_weight, [
        batch, in_channel, out_channel, padded_weight_grad_h, padded_weight_grad_w
    ])
    grad_weight = topi.sum(grad_weight, axis=0)
    grad_weight = topi.transpose(grad_weight, [1, 0, 2, 3])

    if padded_weight_grad_h > filter_h or padded_weight_grad_w > filter_w:
        grad_weight = topi.strided_slice(grad_weight,
                                         begin=[0, 0, 0, 0],
                                         end=[out_channel, in_channel, filter_h, filter_w])

    return [data, weight, grad_output, grad_data, grad_weight]
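# Worked example of the grad_data shape arithmetic above (hypothetical numbers, not taken
# from the original code): with HI = WI = 32, KSIZE = 3, stride = 2, padding = 1 and the
# forward output HO = WO = 16, we get out_h = (16 - 1) * 2 - 2 * 1 + 3 = 31, so
# output_padding = (32 - 31, 32 - 31) = (1, 1), which restores the transposed convolution
# output to the original 32 x 32 input size.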
def test_reduce_map(in_shape, axis, keepdims, type="sum", test_id=0):
    global TASK
    # Build the logic and compile the function
    A = te.placeholder(shape=in_shape, name="A")
    if type == "sum":
        TASK = "sum_map_id%d" % test_id
        B = topi.sum(A, axis=axis, keepdims=keepdims)
    elif type == "max":
        TASK = "max_map_id%d" % test_id
        B = topi.max(A, axis=axis, keepdims=keepdims)
    elif type == "min":
        TASK = "min_map_id%d" % test_id
        B = topi.min(A, axis=axis, keepdims=keepdims)
    else:
        raise NotImplementedError
    s = topi.cuda.schedule_reduce(B)

    with tvm.transform.PassContext(config={"tir.UnrollLoop": {"auto_max_step": 16}}):
        fcuda = tvm.build(s, [A, B], "cuda", name="sum")

    # Test
    in_npy = np.random.normal(size=in_shape).astype(np.float32)
    if type == "sum":
        out_npy = in_npy.sum(axis=axis, keepdims=keepdims)
    elif type == "max":
        out_npy = in_npy.max(axis=axis, keepdims=keepdims)
    elif type == "min":
        out_npy = in_npy.min(axis=axis, keepdims=keepdims)
    else:
        raise NotImplementedError

    data_tvm = tvm.nd.array(in_npy, device=tvm.cuda())
    out_tvm = tvm.nd.empty(shape=out_npy.shape, device=tvm.cuda())
    for _ in range(2):
        fcuda(data_tvm, out_tvm)
    tvm.testing.assert_allclose(out_tvm.asnumpy(), out_npy, rtol=4e-4, atol=4e-4)
def check(device, dtype, m=32, n=32):
    if not tvm.testing.device_enabled(device):
        print("Skipping", device)
        return
    dev = tvm.device(device, 0)
    if dtype == "float16" and not have_fp16(dev.compute_version):
        print("Skip because gpu does not have fp16 support")
        return

    a = tvm.te.placeholder((m, n), name="a", dtype=dtype)
    b = topi.sum(a)
    with tvm.target.Target(device):
        sb = tvm.te.create_schedule(b.op)
        i, _ = b.op.reduce_axis
        sb[b].bind(i, tvm.te.thread_axis("threadIdx.x"))
        func = tvm.build(sb, [a, b], device)

    a_np = np.random.uniform(size=(m, n)).astype(a.dtype)
    b_np = np.sum(a_np)
    a_nd = tvm.nd.array(a_np, dev)
    b_nd = tvm.nd.array(np.zeros(b_np.shape, dtype=b_np.dtype), dev)
    func(a_nd, b_nd)
    tvm.testing.assert_allclose(b_nd.asnumpy(), b_np, rtol=1e-3)
def check(device, dtype, m=32, n=32):
    if not tvm.testing.device_enabled(device):
        print("Skipping", device)
        return
    dev = tvm.device(device, 0)

    a = te.placeholder((m, n), name="a", dtype=dtype)
    b = te.placeholder((m, n), name="b", dtype=dtype)
    c = a + b
    d = a * b
    e = topi.elemwise_sum([c, d])
    g = topi.sum(e)
    with tvm.target.Target(device):
        sg = topi.cuda.schedule_reduce(g)
        func = tvm.build(sg, [a, b, g], device)

    a_np = np.random.uniform(size=(m, n)).astype(a.dtype)
    b_np = np.random.uniform(size=(m, n)).astype(b.dtype)
    g_np = np.sum(np.add(a_np * b_np, a_np + b_np))
    a_nd = tvm.nd.array(a_np, dev)
    b_nd = tvm.nd.array(b_np, dev)
    g_nd = tvm.nd.array(np.zeros(g_np.shape, dtype=g_np.dtype), dev)
    func(a_nd, b_nd, g_nd)
    tvm.testing.assert_allclose(g_nd.asnumpy(), g_np, rtol=1e-3)
def test_complex_reduce():
    in_shape = (2, 3)
    dtype = "float32"
    axis = 0
    keepdims = False
    A = te.placeholder(shape=in_shape, name="A", dtype=dtype)
    B = topi.sum(A, axis=axis, keepdims=keepdims)
    C = topi.add(B, B)
    D = topi.multiply(B, B)
    E = topi.add(C, D)
    for device, ctx in tvm.testing.enabled_targets():
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            s = tvm.topi.testing.get_reduce_schedule(device)(E)
        foo = tvm.build(s, [A, E], device, name="sum")

        in_npy = np.random.uniform(-1, 1, size=in_shape).astype(dtype)
        sum_npy = in_npy.sum(axis=axis, keepdims=keepdims)
        out_npy = sum_npy * 2 + sum_npy * sum_npy
        data_tvm = tvm.nd.array(in_npy, ctx=ctx)
        out_tvm = tvm.nd.empty(shape=out_npy.shape, ctx=ctx, dtype=dtype)
        foo(data_tvm, out_tvm)
        tvm.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1e-3, 1e-3)
def batch_norm_bwd(N, C, H, W, dtype="float32"):
    dshape = (N, C, H, W)
    oshape = (C,)
    bshape = (1, C, 1, 1)
    sshape = (1,)
    data = te.placeholder(dshape, name="data", dtype=dtype)
    scale = te.placeholder(oshape, name="scale", dtype=dtype)
    saved_mean = te.placeholder(oshape, name="saved_mean", dtype=dtype)
    saved_var = te.placeholder(oshape, name="saved_var", dtype=dtype)
    eps = te.placeholder(sshape, name="eps", dtype=dtype)
    grad_output = te.placeholder(dshape, name="grad_output", dtype=dtype)

    axis = (0, 2, 3)
    num_ele = dshape[0] * dshape[2] * dshape[3]
    frac_num_ele = 1.0 / num_ele

    # compute grad_input
    mean_sum = topi.sum(data, axis, True)
    mean = topi.multiply(mean_sum, frac_num_ele)
    var_sub = topi.subtract(data, mean)
    var_mul = topi.multiply(var_sub, var_sub)
    var_sum = topi.sum(var_mul, axis, True)
    var = topi.multiply(var_sum, frac_num_ele)
    var_eps = topi.add(var, eps)
    output_sqrt = topi.sqrt(var_eps)
    x_norm = topi.subtract(data, mean)
    x_hat = topi.divide(x_norm, output_sqrt)
    dx_hat = topi.multiply(grad_output, topi.reshape(scale, bshape))
    grad_input_sum1 = topi.sum(dx_hat * x_hat, axis, True)
    grad_input_sum2 = topi.sum(dx_hat, axis, True)
    grad_input_left = topi.divide(frac_num_ele, topi.sqrt(var_eps))
    grad_input_right1 = topi.subtract(topi.multiply(dx_hat, num_ele), grad_input_sum2)
    grad_input_right2 = topi.multiply(x_hat, grad_input_sum1)
    grad_input = topi.multiply(grad_input_left,
                               topi.subtract(grad_input_right1, grad_input_right2))

    # compute grad_scale and grad_bias
    grad_scale = topi.sum(grad_output * x_hat, axis)
    grad_bias = topi.sum(grad_output, axis)

    return [
        data, scale, saved_mean, saved_var, eps, grad_output, grad_input,
        grad_scale, grad_bias
    ]
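# A hedged NumPy sketch of the gradients the TE graph above computes, following the
# standard batch-norm backward formula. The helper name `batch_norm_bwd_np` and its
# signature are hypothetical and not part of the original code.
import numpy as np

def batch_norm_bwd_np(x, scale, grad_out, eps):
    n = x.shape[0] * x.shape[2] * x.shape[3]
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = ((x - mean) ** 2).mean(axis=(0, 2, 3), keepdims=True)
    x_hat = (x - mean) / np.sqrt(var + eps)
    dx_hat = grad_out * scale.reshape(1, -1, 1, 1)
    # grad_input = (1/n) / sqrt(var + eps) * (n*dx_hat - sum(dx_hat) - x_hat*sum(dx_hat*x_hat))
    grad_input = (1.0 / n) / np.sqrt(var + eps) * (
        n * dx_hat
        - dx_hat.sum(axis=(0, 2, 3), keepdims=True)
        - x_hat * (dx_hat * x_hat).sum(axis=(0, 2, 3), keepdims=True)
    )
    grad_scale = (grad_out * x_hat).sum(axis=(0, 2, 3))
    grad_bias = grad_out.sum(axis=(0, 2, 3))
    return grad_input, grad_scale, grad_bias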
n = te.var("n")
m = te.var("m")
A = te.placeholder((n, m), name="A")
k = te.reduce_axis((0, m), "k")
B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")
s = te.create_schedule(B.op)

######################################################################
# and to examine the IR code in human readable format, we can do
#
print(tvm.lower(s, [A], simple_mode=True))

######################################################################
# However, for such a common operation we had to define the reduce axis ourselves as well
# as write the explicit computation with :code:`te.compute`. Imagine how many details we
# would need to provide for a more complicated operation. Fortunately, we can replace
# those two lines with a simple :code:`topi.sum`, much like :code:`numpy.sum`:
#
C = topi.sum(A, axis=1)
ts = te.create_schedule(C.op)
print(tvm.lower(ts, [A], simple_mode=True))

######################################################################
# Numpy-style operator overloading
# --------------------------------
# We can add two tensors with broadcastable shapes using :code:`topi.broadcast_add`.
# Even shorter, TOPI provides operator overloading for such common operations. For example,
#
x, y = 100, 10
a = te.placeholder((x, y, y), name="a")
b = te.placeholder((y, y), name="b")
c = a + b  # same as topi.broadcast_add
d = a * b  # same as topi.broadcast_mul
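######################################################################
# (Added note, not part of the original tutorial text.) The overloaded results are
# ordinary TE tensors, so they can be scheduled and lowered like any other compute op;
# a quick sketch:
#
sd = te.create_schedule(d.op)
print(tvm.lower(sd, [a, b, d], simple_mode=True))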
def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum", dtype="float32"):
    # Build the logic and compile the function
    A = te.placeholder(shape=in_shape, name="A", dtype=dtype)
    A1 = topi.sqrt(topi.exp(A))
    out_dtype = dtype
    if type == "sum":
        B = topi.sum(A1, axis=axis, keepdims=keepdims)
    elif type == "all":
        B = topi.all(A, axis=axis, keepdims=keepdims)
    elif type == "any":
        B = topi.any(A, axis=axis, keepdims=keepdims)
    elif type == "max":
        B = topi.max(A1, axis=axis, keepdims=keepdims)
    elif type == "min":
        B = topi.min(A1, axis=axis, keepdims=keepdims)
    elif type == "argmax":
        B = topi.argmax(A1, axis=axis, keepdims=keepdims)
        out_dtype = "int32"
    elif type == "argmin":
        B = topi.argmin(A1, axis=axis, keepdims=keepdims)
        out_dtype = "int32"
    else:
        raise NotImplementedError

    def check_device(device, ctx):
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            s = tvm.topi.testing.get_reduce_schedule(device)(B)
        foo = tvm.build(s, [A, B], device, name=type)

        # Test
        if dtype == "bool":
            in_npy_map = in_npy = np.random.choice([True, False], size=in_shape)
        else:
            in_npy = np.random.uniform(-1, 1, size=in_shape).astype(dtype)
            in_npy_map = np.sqrt(np.exp(in_npy)).astype(dtype)

        if type == "sum":
            out_npy = in_npy_map.sum(axis=axis, keepdims=keepdims)
        elif type == "all" and dtype == "bool":
            out_npy = in_npy_map.all(axis=axis, keepdims=keepdims)
        elif type == "any" and dtype == "bool":
            out_npy = in_npy_map.any(axis=axis, keepdims=keepdims)
        elif type == "max":
            out_npy = in_npy_map.max(axis=axis, keepdims=keepdims)
        elif type == "min":
            out_npy = in_npy_map.min(axis=axis, keepdims=keepdims)
        elif type == "argmax":
            out_npy = _my_npy_argmax(in_npy_map, axis=axis, keepdims=keepdims)
        elif type == "argmin":
            out_npy = _my_npy_argmin(in_npy_map, axis=axis, keepdims=keepdims)
        else:
            raise NotImplementedError

        data_tvm = tvm.nd.array(in_npy, ctx=ctx)
        out_tvm = tvm.nd.empty(shape=out_npy.shape, ctx=ctx, dtype=out_dtype)
        for _ in range(1):
            foo(data_tvm, out_tvm)

        if type == "argmax" or type == "argmin":
            out_tvm_indices = out_tvm.asnumpy()
            if keepdims:
                out_tvm_indices = np.take(out_tvm_indices, indices=0, axis=axis)
            if axis is None:
                out_tvm_val = in_npy_map.ravel()[out_tvm_indices]
            else:
                other_indices = tuple(np.indices(in_shape[0:axis] + in_shape[(axis + 1):]))
                sel_indices = other_indices[0:axis] + (out_tvm_indices,) + other_indices[axis:]
                out_tvm_val = in_npy_map[sel_indices]
            if type == "argmax":
                tvm.testing.assert_allclose(out_tvm_val, in_npy_map.max(axis=axis), 1e-3, 1e-3)
            elif type == "argmin":
                tvm.testing.assert_allclose(out_tvm_val, in_npy_map.min(axis=axis), 1e-3, 1e-3)
        else:
            tvm.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1e-3, 1e-3)

    for device, ctx in tvm.testing.enabled_targets():
        check_device(device, ctx)
def test_topi():
    X = te.placeholder((1, 2, 4, 4), name="X")
    W = te.placeholder((5, 2, 3, 3), name="W")
    W1 = te.placeholder((2, 5, 3, 3), name="W1")
    W2 = te.placeholder((1,), name="W2")

    R = topi.nn.conv2d(X, W, 1, 1, 1)
    check_grad(R, [X, W])

    R1 = topi.nn.conv2d(topi.nn.relu(R), W1, 1, 0, 1)
    check_grad(R1, [X, W, W1])

    R = topi.broadcast_to(W2, (5, 2, 3, 3))
    check_grad(R, [W2])

    R = topi.nn.conv2d(X, topi.broadcast_to(W2, (5, 2, 3, 3)), 1, 1, 1)
    check_grad(R, [X, W2])

    R = topi.nn.pool(X, [2, 2], [2, 2], [0, 0, 0, 0], "avg")
    check_grad(R, X)

    R = topi.nn.pool(X, [2, 2], [2, 2], [0, 0, 0, 0], "max")
    check_grad(R, X)

    # (1, 2, 4, 4) has 32 elements, matching the reshape target below
    X = te.placeholder((1, 2, 4, 4), name="X")
    R = topi.reshape(X, (1, 32))
    check_grad(R, [X])

    X = te.placeholder((1, 2, 5, 5), name="X")
    W = te.placeholder((2, 2, 3, 3), name="W")

    S = topi.reshape(X, (1, 50))
    check_grad(S, [X])

    R = X + topi.nn.conv2d(X + topi.nn.conv2d(X, W, 1, 1, 1), W, 1, 1, 1)
    check_grad(R, [X, W])

    S = topi.nn.softmax(topi.reshape(R, (1, 50)))
    check_grad(S, [X, W])

    S = topi.sigmoid(topi.reshape(R, (1, 50)))
    check_grad(S, [X, W])

    S = topi.tanh(topi.reshape(R, (1, 50)))
    check_grad(S, [X, W])

    S = topi.nn.log_softmax(topi.reshape(R, (1, 50)))
    check_grad(S, [X, W])
    check_grad(S, [W], [X])

    X = te.placeholder((1, 2, 3, 5), name="X")
    Y = te.placeholder((1, 2, 7, 5), name="Y")
    S = topi.concatenate((X, Y), 2)
    check_grad(S, [X, Y])

    X = te.placeholder((1, 2, 6, 5), name="X")
    (S, R) = topi.split(X, 2, 2)
    check_grad(S, [X])
    check_grad(R, [X])
    R1 = topi.concatenate((S, R), 2)
    check_grad(R1, [X])
    R2 = topi.concatenate((R, S), 2)
    check_grad(R2, [X])

    X = te.placeholder((4, 5), name="X")
    I = te.placeholder((100,), name="I", dtype="int32")
    R = topi.take(X, topi.abs(I))
    check_grad(R, [X], [I])

    W = te.placeholder((5, 5), name="W")
    exps = topi.exp(topi.nn.dense(X, W))
    sumexps = topi.sum(exps, axis=-1, keepdims=True)
    R = exps / sumexps
    check_grad(R, [X, W], data_range=(-1, 1))
def compute_cross_entropy_with_logits(attrs, inputs, out_dtype):
    x, y = inputs
    return [-topi.sum(x * y) / x.shape[0]]
def compute_cross_entropy(attrs, inputs, out_dtype):
    x, y = inputs
    return [-topi.sum(topi.log(x) * y) / x.shape[0]]
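# For reference, a NumPy sketch of what the two compute functions above evaluate (a mean
# over the batch dimension x.shape[0]); the helper names are hypothetical.
import numpy as np

def cross_entropy_np(x, y):
    # x holds probabilities, y holds one-hot (or soft) labels
    return -np.sum(np.log(x) * y) / x.shape[0]

def cross_entropy_with_logits_np(x, y):
    # x already holds log-probabilities, so no log is applied
    return -np.sum(x * y) / x.shape[0]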
def test_reduce_map(hexagon_session: Session, ref_data, in_shape, axis, keepdims, reduce_type, dtype):
    in_npy, in_npy_map, out_npy = ref_data

    # Build the logic and compile the function
    A = te.placeholder(shape=in_shape, name="A", dtype=dtype)
    A1 = topi.sqrt(topi.exp(A))
    out_dtype = dtype
    if reduce_type == "sum":
        B = topi.sum(A1, axis=axis, keepdims=keepdims)
    elif reduce_type == "all":
        B = topi.all(A, axis=axis, keepdims=keepdims)
    elif reduce_type == "any":
        B = topi.any(A, axis=axis, keepdims=keepdims)
    elif reduce_type == "max":
        B = topi.max(A1, axis=axis, keepdims=keepdims)
    elif reduce_type == "min":
        B = topi.min(A1, axis=axis, keepdims=keepdims)
    elif reduce_type == "argmax":
        B = topi.argmax(A1, axis=axis, keepdims=keepdims)
        out_dtype = "int32"
    elif reduce_type == "argmin":
        B = topi.argmin(A1, axis=axis, keepdims=keepdims)
        out_dtype = "int32"
    else:
        raise NotImplementedError

    target_hexagon = tvm.target.hexagon("v68")
    with tvm.target.Target(target_hexagon):
        fschedule = topi.hexagon.schedule_reduce
        s = fschedule(B)

    func = tvm.build(s, [A, B], tvm.target.Target(target_hexagon, host=target_hexagon),
                     name=reduce_type)
    mod = hexagon_session.load_module(func)

    dev = hexagon_session.device
    data_tvm = tvm.nd.array(in_npy, device=dev)
    out_tvm = tvm.nd.empty(shape=out_npy.shape, device=dev, dtype=out_dtype)

    mod[reduce_type](data_tvm, out_tvm)

    if reduce_type == "argmax" or reduce_type == "argmin":
        out_tvm_indices = out_tvm.numpy()
        if keepdims:
            out_tvm_indices = np.take(out_tvm_indices, indices=0, axis=axis)
        if axis is None:
            out_tvm_val = in_npy_map.ravel()[out_tvm_indices]
        else:
            other_indices = tuple(np.indices(in_shape[0:axis] + in_shape[(axis + 1):]))
            sel_indices = other_indices[0:axis] + (out_tvm_indices,) + other_indices[axis:]
            out_tvm_val = in_npy_map[sel_indices]
        if reduce_type == "argmax":
            tvm.testing.assert_allclose(out_tvm_val, in_npy_map.max(axis=axis), 1e-3, 1e-3)
        elif reduce_type == "argmin":
            tvm.testing.assert_allclose(out_tvm_val, in_npy_map.min(axis=axis), 1e-3, 1e-3)
    else:
        tvm.testing.assert_allclose(out_tvm.numpy(), out_npy, 1e-3, 1e-3)