def argmin(data, axis=None, keepdims=False):
    """Return the indices of the minimum values along the given axis or axes.

    Parameters
    ----------
    data : tvm.Tensor
        The input tvm tensor.

    axis : None or int or tuple of int
        Axis or axes along which the argmin operation is performed.
        The default, axis=None, finds the index of the minimum over all
        elements of the input array. A negative axis counts from the last
        axis towards the first.

    keepdims : bool
        If True, the reduced axes are kept in the result as dimensions of
        size one, so that the result broadcasts correctly against the input.

    Returns
    -------
    ret : tvm.Tensor
    """
    return comm_reduce(
        data,
        axis=axis,
        keepdims=keepdims,
        func=tvm.comm_reducer(fcombine=_argmin_comp,
                              fidentity=_argmin_init,
                              name='argmin'),
        is_idx_reduce=True,
    )
def test_tensor_comm_reducer():
    # A user-defined sum reducer applied along a reduction axis of a 2-D tensor.
    m, n = tvm.var('m'), tvm.var('n')
    A = tvm.placeholder((m, n), name='A')
    k = tvm.reduce_axis((0, n), "k")

    def combine(x, y):
        return x + y

    def identity(t):
        return tvm.const(0, dtype=t)

    mysum = tvm.comm_reducer(combine, identity)
    C = tvm.compute((m,), lambda i: mysum(A[i, k], axis=k))
def f(n):
    # Reduce X over [0, n) with a custom reducer named 'sum' whose identity
    # and combiner both depend on n. `X` comes from the enclosing scope.
    r = tvm.reduce_axis((0, n))

    def identity(dtype):
        return tvm.expr.Select(n > 1, tvm.const(0, dtype), n.astype(dtype))

    def combine(x, y):
        return tvm.max(x + y, n.astype('float32'))

    # Local renamed so it no longer shadows the builtin `sum`.
    clipped_sum = tvm.comm_reducer(combine, identity, name='sum')
    return clipped_sum(X[r], axis=r)
def test_rfactor_argmax():
    def fcombine(x, y):
        # Operands are (index, value) pairs; keep the larger value, x on ties.
        x_wins = x[1] >= y[1]
        return tvm.make.Select(x_wins, x[0], y[0]), tvm.make.Select(x_wins, x[1], y[1])

    def fidentity(idx_type, val_type):
        return tvm.const(-1, idx_type), tvm.min_value(val_type)

    argmax = tvm.comm_reducer(fcombine, fidentity, name='argmax')

    nn, mm = 1027, 10
    n = tvm.convert(nn)
    m = tvm.convert(mm)
    A0 = tvm.placeholder((m, n), name='A0', dtype='int32')
    A1 = tvm.placeholder((m, n), name='A1', dtype='float32')
    k = tvm.reduce_axis((0, n))
    B0, B1 = tvm.compute((m,), lambda i: argmax((A0[i, k], A1[i, k]), axis=k), name='B')

    # schedule: rfactor the inner split of k, then reduce across threadIdx.x
    # and let only thread 0 store the result.
    s = tvm.create_schedule(B0.op)
    nthread = 16
    _, kf = s[B0].split(k, factor=nthread)
    BF0, BF1 = s.rfactor(B0, kf)
    bx, ty = s[B0].split(s[B0].op.axis[0], factor=nthread)
    s[B0].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[B0].bind(ty, tvm.thread_axis("threadIdx.y"))
    tx = s[B0].op.reduce_axis[0]
    thread_x = tvm.thread_axis("threadIdx.x")
    s[B0].bind(tx, thread_x)
    s[BF0.op].compute_at(s[B0], tx)
    s[B0].set_store_predicate(thread_x.var.equal(0))

    def check_target(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("skip because %s is not enabled.." % device)
            return
        fargmax = tvm.build(tvm.lower(s, args=[A0, A1, B0, B1]),
                            target=device, name="argmax")

        # Every row of the index input is 0..nn-1.
        np_idx = np.tile(np.arange(nn, dtype='int32'), (mm, 1))
        np_val = np.random.uniform(size=(mm, nn)).astype('float32')
        expected = np.argmax(np_val, axis=1)

        nd_idx = tvm.nd.array(np_idx, ctx)
        nd_val = tvm.nd.array(np_val, ctx)
        nd_res0 = tvm.nd.array(np.zeros(mm, dtype='int32'), ctx)
        nd_res1 = tvm.nd.array(np.zeros(mm, dtype='float32'), ctx)
        fargmax(nd_idx, nd_val, nd_res0, nd_res1)
        tvm.testing.assert_allclose(expected, nd_res0.asnumpy())

    for device in ("cuda", "vulkan"):
        check_target(device)
def test_rfactor_argmax():
    def fcombine(x, y):
        # (index, value) pairs: the pair with the larger value survives; x on ties.
        take_x = x[1] >= y[1]
        lhs = tvm.make.Select(take_x, x[0], y[0])
        rhs = tvm.make.Select(take_x, x[1], y[1])
        return lhs, rhs

    def fidentity(t_idx, t_val):
        return tvm.const(-1, t_idx), tvm.min_value(t_val)

    argmax = tvm.comm_reducer(fcombine, fidentity, name='argmax')

    nn = 1027
    mm = 10
    n, m = tvm.convert(nn), tvm.convert(mm)
    A0 = tvm.placeholder((m, n), name='A0', dtype='int32')
    A1 = tvm.placeholder((m, n), name='A1', dtype='float32')
    k = tvm.reduce_axis((0, n))
    B0, B1 = tvm.compute((m, ), lambda i: argmax((A0[i, k], A1[i, k]), axis=k), name='B')

    # schedule
    s = tvm.create_schedule(B0.op)
    nthread = 16
    _outer, kf = s[B0].split(k, factor=nthread)
    BF0, BF1 = s.rfactor(B0, kf)
    bx, ty = s[B0].split(s[B0].op.axis[0], factor=nthread)
    s[B0].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[B0].bind(ty, tvm.thread_axis("threadIdx.y"))
    tx = s[B0].op.reduce_axis[0]
    thread_x = tvm.thread_axis("threadIdx.x")
    s[B0].bind(tx, thread_x)
    s[BF0.op].compute_at(s[B0], tx)
    # Only lane 0 of the cross-thread reduction writes the output.
    s[B0].set_store_predicate(thread_x.var.equal(0))

    def check_target(device):
        ctx = tvm.context(device, 0)
        if ctx.exist:
            fapi = tvm.lower(s, args=[A0, A1, B0, B1])
            fargmax = tvm.build(fapi, target=device, name="argmax")
            indices = np.tile(np.arange(nn, dtype='int32'), (mm, 1))
            values = np.random.uniform(size=(mm, nn)).astype('float32')
            reference = np.argmax(values, axis=1)
            out_idx = tvm.nd.array(np.zeros(mm, dtype='int32'), ctx)
            out_val = tvm.nd.array(np.zeros(mm, dtype='float32'), ctx)
            fargmax(tvm.nd.array(indices, ctx), tvm.nd.array(values, ctx), out_idx, out_val)
            tvm.testing.assert_allclose(reference, out_idx.asnumpy())
        else:
            print("skip because %s is not enabled.." % device)

    check_target("cuda")
    check_target("vulkan")
def common_reduce(name, args=(0, )):
    """Build a commutative reducer whose combine step is a pure extern call.

    Parameters
    ----------
    name : str
        Name of the pure extern function called as
        ``name(x, y, *args[1:])`` to combine two elements; also used as the
        reducer's name.
    args : scalar, tuple or list
        ``args[0]`` is the identity value of the reduction and the remaining
        entries are forwarded as extra arguments to the extern call. A
        non-sequence value is wrapped into a one-element tuple.

    Returns
    -------
    The reducer created by ``tvm.comm_reducer``.

    Raises
    ------
    ValueError
        If ``args`` is an empty tuple/list, so no identity value exists.
    """
    if not isinstance(args, (tuple, list)):
        args = (args, )
    if not args:
        # Without this check the failure surfaces later as an IndexError
        # when the reducer asks for its identity element.
        raise ValueError("common_reduce requires an identity value as args[0]")

    def reduce_op(x, y):
        assert x.dtype == y.dtype, "Reducing elements that don't have same data type: %s v.s. %s" % (
            x.dtype, y.dtype)
        return tvm.call_pure_extern(x.dtype, name, x, y, *args[1:])

    return tvm.comm_reducer(reduce_op, lambda t: tvm.const(args[0], dtype=t), name=name)
def compute_backward_vadd(dtype, ndim, reduce1st, req):
    # The backward of a broadcast op is a reduction over the broadcast axes.
    # Labelling each reduced axis 1 and each kept axis 0 yields a bit string,
    # and each bit string corresponds to a kernel, i.e. 2^n kernels in total.
    # Compressing runs of equal consecutive bits leaves an alternating string
    # that is fixed by its length and its first bit, so only 2 * n kernels are
    # needed. The compressed bit string is stored in `axes`; `reduce1st` is its
    # first bit. Credit to @junrushao1994 and @yzhliu.
    axes = [reduce1st if pos % 2 == 0 else 1 - reduce1st for pos in range(ndim)]
    shape = [tvm.var() for _ in range(ndim)]
    X = tvm.placeholder(shape, name='X', dtype=dtype)
    reducer = tvm.comm_reducer(lambda x, y: x + y,
                               lambda t: tvm.const(0, dtype=t),
                               name="sum")
    ret = reduce_axes(X, axes, reducer)
    in_grad_a, in_grad = assign_by_req(ret, req)
    schedule = tvm.create_schedule(in_grad.op)
    return schedule, X, in_grad_a, in_grad, [ret, in_grad]
def test_argmax():
    def fcombine(x, y):
        # (index, value) pairs; keep the larger value, preferring x on ties.
        pick_x = x[1] >= y[1]
        lhs = tvm.make.Select(pick_x, x[0], y[0])
        rhs = tvm.make.Select(pick_x, x[1], y[1])
        return lhs, rhs

    def fidentity(t0, t1):
        return tvm.const(-1, t0), tvm.min_value(t1)

    argmax = tvm.comm_reducer(fcombine, fidentity, name='argmax')
    m, n = tvm.var('m'), tvm.var('n')
    idx = tvm.placeholder((m, n), name='idx', dtype='int32')
    val = tvm.placeholder((m, n), name='val', dtype='float32')
    k = tvm.reduce_axis((0, n), 'k')
    T0, T1 = tvm.compute((m, ), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name='T')
    s = tvm.create_schedule(T0.op)

    def check_target():
        device = 'cpu'
        if not tvm.module.enabled(device):
            print("skip because %s is not enabled.." % device)
            return
        ctx = tvm.context(device, 0)
        fargmax = tvm.build(tvm.lower(s, args=[idx, val, T0, T1]),
                            target='llvm', name="argmax")

        rows, cols = 12, 16
        np_idx = np.tile(np.arange(cols, dtype='int32'), (rows, 1))
        np_val = np.random.uniform(size=(rows, cols)).astype('float32')
        np_res = np.argmax(np_val, axis=1)

        nd_idx, nd_val = tvm.nd.array(np_idx, ctx), tvm.nd.array(np_val, ctx)
        nd_res0 = tvm.nd.array(np.zeros(rows, dtype='int32'), ctx)
        nd_res1 = tvm.nd.array(np.zeros(rows, dtype='float32'), ctx)
        fargmax(nd_idx, nd_val, nd_res0, nd_res1)
        np.testing.assert_allclose(np_res, nd_res0.asnumpy())

    check_target()
def test_argmax():
    def fcombine(x, y):
        # Keep the (index, value) pair with the larger value; ties keep x.
        x_ge_y = x[1] >= y[1]
        return tvm.make.Select(x_ge_y, x[0], y[0]), tvm.make.Select(x_ge_y, x[1], y[1])

    def fidentity(idx_dtype, val_dtype):
        return tvm.const(-1, idx_dtype), tvm.min_value(val_dtype)

    argmax = tvm.comm_reducer(fcombine, fidentity, name='argmax')
    m = tvm.var('m')
    n = tvm.var('n')
    idx = tvm.placeholder((m, n), name='idx', dtype='int32')
    val = tvm.placeholder((m, n), name='val', dtype='float32')
    k = tvm.reduce_axis((0, n), 'k')
    T0, T1 = tvm.compute((m,), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name='T')
    s = tvm.create_schedule(T0.op)

    def check_target():
        device = 'cpu'
        if tvm.module.enabled(device):
            ctx = tvm.context(device, 0)
            fapi = tvm.lower(s, args=[idx, val, T0, T1])
            fargmax = tvm.build(fapi, target='llvm', name="argmax")
            mm, nn = 12, 16
            indices = np.tile(np.arange(nn, dtype='int32'), (mm, 1))
            values = np.random.uniform(size=(mm, nn)).astype('float32')
            reference = np.argmax(values, axis=1)
            out_idx = tvm.nd.array(np.zeros(mm, dtype='int32'), ctx)
            out_val = tvm.nd.array(np.zeros(mm, dtype='float32'), ctx)
            fargmax(tvm.nd.array(indices, ctx), tvm.nd.array(values, ctx), out_idx, out_val)
            tvm.testing.assert_allclose(reference, out_idx.asnumpy())
        else:
            print("skip because %s is not enabled.." % device)

    check_target()
def _sample(i, c, ph, pw):
    """Compute one pooled bin as the max over bilinearly sampled points.

    Reads ROI ``i`` from ``rois``. ``roi[0]`` is the batch index and
    ``roi[1:5]`` are (start_w, start_h, end_w, end_h), scaled by
    ``spatial_scale``. The ROI is split into a ``pooled_size_h`` x
    ``pooled_size_w`` grid. For bin ``(ph, pw)`` of channel ``c``, this returns
    the maximum of ``_bilinear`` over a 2 x 2 set of sample points inside the
    bin. An empty bin yields 0.

    ``rois``, ``spatial_scale``, ``pooled_size_h``, ``pooled_size_w``,
    ``height``, ``width`` and ``_bilinear`` come from the enclosing scope.
    """
    roi = rois[i]
    batch_index = roi[0].astype('int32')
    roi_start_w = roi[1] * spatial_scale
    roi_start_h = roi[2] * spatial_scale
    roi_end_w = roi[3] * spatial_scale
    roi_end_h = roi[4] * spatial_scale

    roi_h = roi_end_h - roi_start_h
    roi_w = roi_end_w - roi_start_w
    bin_h = roi_h / pooled_size_h
    bin_w = roi_w / pooled_size_w

    # Bin bounds relative to the ROI origin.
    hstart = ph * bin_h
    wstart = pw * bin_w
    hend = (ph + 1) * bin_h
    wend = (pw + 1) * bin_w

    # Shift into image coordinates and clamp to [0, size - 1].
    hstart = tvm.min(tvm.max(hstart + roi_start_h, 0), height - 1)
    wstart = tvm.min(tvm.max(wstart + roi_start_w, 0), width - 1)
    hend = tvm.min(tvm.max(hend + roi_start_h, 0), height - 1)
    wend = tvm.min(tvm.max(wend + roi_start_w, 0), width - 1)

    non_empty = tvm.all(hstart < hend, wstart < wend)

    def min_value(dtype):
        # Identity of the max reduction: the lowest value for a non-empty bin,
        # 0 for an empty bin (whose reduce axes below have zero extent).
        return tvm.expr.Select(non_empty, tvm.min_value(dtype), tvm.const(0.0, dtype))

    # The first sample point sits 1/3 of the way into the bin. The step to the
    # second one is the same 1/3 stride, floored at 0.01.
    stride_h = (hend - hstart) / 3.0
    stride_w = (wend - wstart) / 3.0
    hstart += stride_h
    wstart += stride_w
    stride_h = tvm.max(0.01, stride_h)
    stride_w = tvm.max(0.01, stride_w)

    # NOTE(review): kept as a lambda like the original; tvm.comm_reducer appears
    # to inspect the combiner's Python arguments, so confirm before passing
    # tvm.make._OpMax directly.
    _max = tvm.comm_reducer(lambda x, y: tvm.make._OpMax(x, y), min_value, name='max')  # pylint: disable=unnecessary-lambda
    rh = tvm.reduce_axis((0, tvm.expr.Select(non_empty, 2, 0)), 'rh')
    rw = tvm.reduce_axis((0, tvm.expr.Select(non_empty, 2, 0)), 'rw')
    return _max(_bilinear(batch_index, c, hstart + rh * stride_h, wstart + rw * stride_w),
                axis=[rh, rw])
def test_inline_multi_reduce():
    def argmax_comp(x, y):
        # (index, value) pairs; the larger value wins, x on ties.
        x_wins = x[1] >= y[1]
        return tvm.select(x_wins, x[0], y[0]), tvm.select(x_wins, x[1], y[1])

    def argmax_init(idx_typ, val_typ):
        return tvm.const(-1, idx_typ), tvm.min_value(val_typ)

    argmax = tvm.comm_reducer(argmax_comp, argmax_init, name='argmax')
    m, n = tvm.var('m'), tvm.var('n')
    val = tvm.placeholder((m, n), name='val', dtype='float32')
    val1 = tvm.compute((m, n), lambda i, j: val[i, j] + 1, name='val1')
    val2 = tvm.compute((m, n), lambda i, j: tvm.exp(val1[i, j]), name='val2')
    k = tvm.reduce_axis((0, n), 'k')
    T_idx, T_val = tvm.compute((m, ), lambda i: argmax((k.var, val2[i, k]), axis=k), name='T')

    # Inline val1 into its consumer, then check the two-output reduction
    # still goes through bound inference and schedule lowering.
    sched = tvm.create_schedule(T_idx.op)
    sched[val1].compute_inline()
    sched = sched.normalize()
    bounds = tvm.schedule.InferBound(sched)
    stmt = tvm.schedule.ScheduleOps(sched, bounds)
def test_inline_multi_reduce():
    def argmax_comp(x, y):
        # Select the (index, value) pair with the larger value; x on ties.
        cond = x[1] >= y[1]
        best_idx = tvm.expr.Select(cond, x[0], y[0])
        best_val = tvm.expr.Select(cond, x[1], y[1])
        return best_idx, best_val

    def argmax_init(idx_typ, val_typ):
        return tvm.const(-1, idx_typ), tvm.min_value(val_typ)

    argmax = tvm.comm_reducer(argmax_comp, argmax_init, name='argmax')
    m = tvm.var('m')
    n = tvm.var('n')
    val = tvm.placeholder((m, n), name='val', dtype='float32')
    val1 = tvm.compute((m, n), lambda i, j: val[i, j] + 1, name='val1')
    val2 = tvm.compute((m, n), lambda i, j: tvm.exp(val1[i, j]), name='val2')
    k = tvm.reduce_axis((0, n), 'k')
    T_idx, T_val = tvm.compute((m, ), lambda i: argmax((k.var, val2[i, k]), axis=k), name='T')

    # Lowering must succeed after inlining an intermediate stage.
    s = tvm.create_schedule(T_idx.op)
    s[val1].compute_inline()
    s = s.normalize()
    stmt = tvm.schedule.ScheduleOps(s, tvm.schedule.InferBound(s))
def _pool(i, c, ph, pw):
    # Max-pool bin (ph, pw) of channel c for ROI i over integer pixel bounds.
    # rois, spatial_scale, pooled_size_h/w, height, width, dtype and data come
    # from the enclosing scope.
    roi = rois[i]
    batch_index = roi[0].astype('int32')

    def _scale(coord):
        return tvm.round(coord * spatial_scale).astype('int32')

    # roi[1:5] holds (start_w, start_h, end_w, end_h).
    roi_start_w = _scale(roi[1])
    roi_start_h = _scale(roi[2])
    roi_end_w = _scale(roi[3])
    roi_end_h = _scale(roi[4])

    # force malformed ROIs to be 1x1
    roi_h = tvm.max(roi_end_h - roi_start_h + 1, tvm.const(1, 'int32'))
    roi_w = tvm.max(roi_end_w - roi_start_w + 1, tvm.const(1, 'int32'))

    bin_h = roi_h.astype(dtype) / pooled_size_h
    bin_w = roi_w.astype(dtype) / pooled_size_w

    # epsilon protects floor/ceil from floating-point precision loss
    epsilon = tvm.const(0.00001, dtype)
    hstart = tvm.floor(ph * bin_h + epsilon).astype('int32')
    wstart = tvm.floor(pw * bin_w + epsilon).astype('int32')
    hend = tvm.ceil((ph + 1) * bin_h - epsilon).astype('int32')
    wend = tvm.ceil((pw + 1) * bin_w - epsilon).astype('int32')

    def _clamp(coord, offset, upper):
        return tvm.min(tvm.max(coord + offset, 0), upper)

    hstart = _clamp(hstart, roi_start_h, height)
    wstart = _clamp(wstart, roi_start_w, width)
    hend = _clamp(hend, roi_start_h, height)
    wend = _clamp(wend, roi_start_w, width)

    non_empty = tvm.all(hstart < hend, wstart < wend)

    def min_value(dtype):
        # Identity: lowest value for a non-empty bin, 0 for an empty one.
        return tvm.if_then_else(non_empty, tvm.min_value(dtype), tvm.const(0.0, dtype))

    # pylint: disable=unnecessary-lambda
    _max = tvm.comm_reducer(lambda x, y: tvm.make._OpMax(x, y), min_value, name='max')
    rh = tvm.reduce_axis((0, hend - hstart), 'rh')
    rw = tvm.reduce_axis((0, wend - wstart), 'rw')
    return _max(data[batch_index, c, hstart + rh, wstart + rw], axis=[rh, rw])
def test_tensor_comm_reducer_overload():
    # Call a user-defined reducer on two plain expressions (no reduce axis).
    m, n = tvm.var('m'), tvm.var('n')

    def identity(t):
        return tvm.const(0, dtype=t)

    mysum = tvm.comm_reducer(lambda x, y: x + y, identity)
    sum_res = mysum(m, n)
(data[n, rv] - center[c, rv]), axis=rv), name='dis')

# Determine the center
# For every point i, find the index of the center with the smallest
# distance dis[i, rc], using a tuple (index, distance) argmin reducer.
def argmin_combine(x, y):
    # Keep the pair with the smaller distance; on ties keep x.
    lhs = tvm.select((x[1] <= y[1]), x[0], y[0])
    rhs = tvm.select((x[1] <= y[1]), x[1], y[1])
    return lhs, rhs

def argmin_identity(t0, t1):
    # Index -1 and the maximum representable distance, so any real pair wins.
    return tvm.const(-1, t0), tvm.max_value(t1)

argmin = tvm.comm_reducer(argmin_combine, argmin_identity, name='argmin')
rc = tvm.reduce_axis((0, C), name='rc')
# dummy_idx[c] == c supplies the center index as the first tuple component.
dummy_idx = tvm.compute((C, ), lambda c: c, name='dummy_idx')
idx, mdis = tvm.compute((N, ), lambda i: argmin((dummy_idx[rc], dis[i, rc]), axis=rc), name='idx_w_dis')

# Update the center
# center_cnt[c] counts the points whose assigned center idx[rn2] equals c.
rn2 = tvm.reduce_axis((0, N), name='rn2')
center_cnt = tvm.compute((C, ), lambda c: tvm.sum(1, rn2, idx[rn2] == c), name='center_cnt')
rn1 = tvm.reduce_axis((0, N), name='rn1')
new_center = tvm.compute(
    (C, V),
print(tvm.lower(s, [Input, Filter, Output], simple_mode=True)) ###################################################################### # .. _general-reduction: # # Define General Commutative Reduction Operation # ---------------------------------------------- # Besides the built-in reduction operations like :any:`tvm.sum`, # :any:`tvm.min` and :any:`tvm.max`, you can also define your # commutative reduction operation by :any:`tvm.comm_reducer`. # n = tvm.var('n') m = tvm.var('m') product = tvm.comm_reducer(lambda x, y: x * y, lambda t: tvm.const(1, dtype=t), name="product") A = tvm.placeholder((n, m), name='A') k = tvm.reduce_axis((0, m), name='k') B = tvm.compute((n, ), lambda i: product(A[i, k], axis=k), name='B') ###################################################################### # .. note:: # # Sometimes we would like to perform reduction that involves multiple # values like :code:`argmax`, which can be done by tuple inputs. # See :ref:`reduction-with-tuple-inputs` for more detail. ###################################################################### # Summary # -------
def test_simplify_combiner():
    """SimplifyCombiner uses vranges, drops unneeded components, and keeps side effects."""
    dummy = tvm.var('dummy')

    prod = comm_reducer(lambda x, y: x * y, lambda t0: tvm.const(1, t0))

    # Sum when dummy < 0, product otherwise; resolvable only with a vrange.
    sum_or_prod = comm_reducer(
        lambda x, y: tvm.expr.Select(dummy < 0, x + y, x * y),
        lambda t0: tvm.expr.Select(dummy < 0, tvm.const(0, t0), tvm.const(1, t0)))

    # Two independent components: a sum and a product (identity 5 - 4).
    sum_and_prod = comm_reducer(
        lambda x, y: (x[0] + y[0], x[1] * y[1]),
        lambda t0, t1: (tvm.const(0, t0), tvm.const(5, t0) - tvm.const(4, t0)))

    # Same pair, padded with terms that should simplify away.
    sum_and_prod2 = comm_reducer(
        lambda x, y: (x[0] + y[0], x[1] * y[1] + 0 * x[0] + y[0] - y[0]),
        lambda t0, t1: (tvm.const(5, t0) - tvm.const(5, t0), tvm.const(1, t1)))

    # Five components: 1 depends on 0; 2 depends on 0; 3 depends on 1 and 2;
    # 4 is a constant.
    five_way = comm_reducer(
        lambda x, y: (x[0] + y[0],
                      x[0] + y[0] + x[1] + y[1],
                      x[0] * y[2] + y[0] * x[2],
                      x[1] + y[2],
                      4.0),
        lambda t0, t1, t2, t3, t4: (tvm.const(0, t0), tvm.const(1, t1), tvm.const(2, t2),
                                    tvm.const(3, t3), tvm.const(4, t4)))

    k = tvm.reduce_axis((0, 10), name="k")
    A = tvm.placeholder((10, ), name='A')

    # Test that SimplifyCombiner makes use of vranges
    negative_range = {dummy: tvm.Range(-10, -5)}
    assert Equal(Simplify(sum_or_prod(A[k], k), negative_range), tvm.sum(A[k], k))
    positive_range = {dummy: tvm.Range(5, 10)}
    assert Equal(Simplify(sum_or_prod(A[k], k), positive_range), prod(A[k], k))

    for pair_reducer in (sum_and_prod, sum_and_prod2):
        assert Equal(Simplify(pair_reducer((A[k], A[10 - k]), k)[0]), tvm.sum(A[k], k))
        assert Equal(Simplify(pair_reducer((A[k], A[10 - k]), k)[1]), prod(A[10 - k], k))

    expected_sources = [[A[0]], [A[0], A[1]], [A[0], A[2]],
                        [A[0], A[1], A[2], A[3]], [A[4]]]
    for comp, expected in enumerate(expected_sources):
        # Only the selected component and the components it depends on
        # should remain after simplification.
        simplified = Simplify(five_way((A[0], A[1], A[2], A[3], A[4]), k)[comp])
        for lhs, rhs in zip(simplified.source, expected):
            assert Equal(lhs, rhs)

    # Test that components with side effects are not removed
    def side_effect(*xs):
        return tvm.make.Call("int32", "dummy", xs, tvm.expr.Call.Intrinsic, None, 0)

    assert Equal(Simplify(sum_and_prod((A[k], side_effect(A[10 - k])), k)[0]),
                 sum_and_prod((A[k], side_effect(A[10 - k])), k)[0])
    assert Equal(Simplify(sum_and_prod((side_effect(A[k]), A[10 - k]), k)[0]),
                 tvm.sum(side_effect(A[k]), k))
# x and y are the operands of reduction, both of them is a tuple of index # and value. def fcombine(x, y): lhs = tvm.select((x[1] >= y[1]), x[0], y[0]) rhs = tvm.select((x[1] >= y[1]), x[1], y[1]) return lhs, rhs # our identity element also need to be a tuple, so `fidentity` accepts # two types as inputs. def fidentity(t0, t1): return tvm.const(-1, t0), tvm.min_value(t1) argmax = tvm.comm_reducer(fcombine, fidentity, name='argmax') # describe the reduction computation m = tvm.var('m') n = tvm.var('n') idx = tvm.placeholder((m, n), name='idx', dtype='int32') val = tvm.placeholder((m, n), name='val', dtype='int32') k = tvm.reduce_axis((0, n), 'k') T0, T1 = tvm.compute((m, ), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name='T') # the generated IR code would be: s = tvm.create_schedule(T0.op) print(tvm.lower(s, [idx, val, T0, T1], simple_mode=True))
s = tvm.create_schedule(Output.op)
print(tvm.lower(s, [Input, Filter, Output], simple_mode=True))

######################################################################
# .. _general-reduction:
#
# Define General Commutative Reduction Operation
# ----------------------------------------------
# Besides the built-in reduction operations like :any:`tvm.sum`,
# :any:`tvm.min` and :any:`tvm.max`, you can also define your own
# commutative reduction operation with :any:`tvm.comm_reducer`.
# It takes a combiner of two operands and an identity function mapping a
# dtype to the identity value; here the combiner multiplies and the
# identity is 1.
#

n = tvm.var('n')
m = tvm.var('m')
product = tvm.comm_reducer(lambda x, y: x*y,
                           lambda t: tvm.const(1, dtype=t), name="product")
A = tvm.placeholder((n, m), name='A')
k = tvm.reduce_axis((0, m), name='k')
B = tvm.compute((n,), lambda i: product(A[i, k], axis=k), name='B')

######################################################################
# .. note::
#
#   Sometimes we would like to perform a reduction that involves multiple
#   values, like :code:`argmax`, which can be done with tuple inputs.
#   See :ref:`reduction-with-tuple-inputs` for more detail.

######################################################################
# Summary
# -------
# This tutorial provides a walkthrough of reduction scheduling.
def measure_bandwidth_sum(total_item, item_per_thread, stride,
                          base_type, bits, lanes,
                          target, target_host, remote, ctx, n_times):
    """Measure GPU memory bandwidth with a product reduction over a given type.

    Each thread runs the following loop::

        for i in 1..num_per_thread:
            y[global_id] = y[global_id] * x[base + i * stride]

    Parameters
    ----------
    total_item: int
        number of elements in the input array, counted in scalar lanes
        (the kernel works on total_item // lanes vector elements)
    item_per_thread: int
        number of elements each thread accumulates
    stride: int
        stride in memory access
    base_type: str
        can be "int", "float"
    bits: int
        can be 16, 32
    lanes: int
        lane count of the vector type, can be 1, 2, 4, 8, 16
    target: :any:`tvm.target.Target`
        the target and option of the compilation.
    target_host : str or :any:`tvm.target.Target`
        host compilation target
    remote: tvm.rpc.RPCSession
        remote rpc session
    ctx: TVMcontext
        the context of the arrays
    n_times: int
        number of runs for taking the mean

    Returns
    -------
    GBPS: float
         gigabytes per second, or -1 when building/running raises TVMError
    """
    num_vec = total_item // lanes
    acc_len = item_per_thread

    scalar_type = str(base_type) + str(bits)
    dtype = scalar_type if lanes == 1 else scalar_type + "x" + str(lanes)

    k = tvm.reduce_axis((0, acc_len), name="k")
    x = tvm.placeholder((num_vec,), dtype=dtype, name="x")
    op = tvm.comm_reducer(lambda x, y: x * y, lambda t: tvm.const(1, dtype=t), name="sum")

    def strided_product(i):
        # Outputs in one group of `stride` consecutive indices read interleaved
        # elements from one window of stride * item_per_thread inputs.
        base = i // stride * stride * acc_len + i % stride
        return op(x[base + k * stride], axis=k)

    y = tvm.compute((num_vec // acc_len,), strided_product)

    s = tvm.create_schedule(y.op)
    yo, yi = s[y].split(y.op.axis[0], target.max_num_threads)
    s[y].bind(yo, tvm.thread_axis("blockIdx.x"))
    s[y].bind(yi, tvm.thread_axis("threadIdx.x"))
    s[y].unroll(k)

    try:
        func = tvm.build(s, [x, y], target, target_host=target_host)

        x_nd = tvm.nd.empty((num_vec,), dtype=dtype, ctx=ctx)
        y_nd = tvm.nd.empty((num_vec // acc_len,), dtype=dtype, ctx=ctx)

        func = _convert_to_remote(func, remote)
        time_f = func.time_evaluator(func.entry_name, ctx, number=n_times)
        elapsed = time_f(x_nd, y_nd).mean
    except tvm._ffi.base.TVMError:
        # build error (occur when device does not support half)
        return -1

    return 1.0 * (total_item * bits / 8) / 1e9 / elapsed
def measure_bandwidth_sum(total_item, item_per_thread, stride,
                          base_type, bits, lanes,
                          target, target_host, remote, ctx, n_times):
    """Measure GPU memory bandwidth via a product reduction for one data type.

    The measured kernel, per thread::

        for i in 1..num_per_thread:
            y[global_id] = y[global_id] * x[base + i * stride]

    Parameters
    ----------
    total_item: int
        number of elements in the input array, in scalar lanes
    item_per_thread: int
        number of elements each thread accumulates
    stride: int
        stride in memory access
    base_type: str
        can be "int", "float"
    bits: int
        can be 16, 32
    lanes: int
        lanes of the vector type, can be 1, 2, 4, 8, 16
    target: :any:`tvm.target.Target`
        the target and option of the compilation.
    target_host : str or :any:`tvm.target.Target`
        host compilation target
    remote: tvm.rpc.RPCSession
        remote rpc session
    ctx: TVMcontext
        the context of the arrays
    n_times: int
        number of runs for taking the mean

    Returns
    -------
    GBPS: float
         gigabytes per second; -1 if a TVMError is raised
    """
    n = total_item // lanes
    m = item_per_thread

    base_type = str(base_type) + str(bits)
    if lanes == 1:
        dtype = base_type
    else:
        dtype = base_type + "x" + str(lanes)

    k = tvm.reduce_axis((0, m), name="k")
    x = tvm.placeholder((n,), dtype=dtype, name="x")
    op = tvm.comm_reducer(lambda x, y: x * y,
                          lambda t: tvm.const(1, dtype=t),
                          name="sum")
    out_len = n // m
    y = tvm.compute(
        (out_len,),
        lambda i: op(x[i // stride * stride * m + i % stride + k * stride], axis=k))

    sch = tvm.create_schedule(y.op)
    block, thread = sch[y].split(y.op.axis[0], target.max_num_threads)
    sch[y].bind(block, tvm.thread_axis("blockIdx.x"))
    sch[y].bind(thread, tvm.thread_axis("threadIdx.x"))
    sch[y].unroll(k)

    try:
        func = tvm.build(sch, [x, y], target, target_host=target_host)
        x_dev = tvm.nd.empty((n,), dtype=dtype, ctx=ctx)
        y_dev = tvm.nd.empty((out_len,), dtype=dtype, ctx=ctx)
        func = _convert_to_remote(func, remote)
        evaluator = func.time_evaluator(func.entry_name, ctx, number=n_times)
        mean_time = evaluator(x_dev, y_dev).mean
    except tvm._ffi.base.TVMError:
        # build error (occur when device does not support half)
        return -1

    return 1.0 * (total_item * bits / 8) / 1e9 / mean_time
# dot[n, l] = sum over rd of weight[l, rd] * data_expand[n, rd]
dot = tvm.compute((N, L), lambda n, l: tvm.sum(weight[l, rd] * data_expand[n, rd], axis=rd), name='dot')
# factor = 1 / (1 + exp(-dot)), i.e. the logistic sigmoid of dot.
factor = tvm.compute((N, L), lambda n, l: 1 / (1 + tvm.exp(-dot[n, l])), name='factor')

def argmax_combine(x, y):
    # (index, value) pairs; strictly larger x[1] wins, otherwise y is kept.
    lhs = tvm.select((x[1] > y[1]), x[0], y[0])
    rhs = tvm.select((x[1] > y[1]), x[1], y[1])
    return lhs, rhs

def argmax_identity(t0, t1):
    return tvm.const(-1, t0), tvm.min_value(t1)

argmax = tvm.comm_reducer(argmax_combine, argmax_identity, name='argmax')
# dummy_idx[l] == l supplies the class index as the first tuple component.
dummy_idx = tvm.compute((L, ), lambda l: l, name='dummy_idx')
rl = tvm.reduce_axis((0, L), name='rl')
# pred_idx[n] is the index l maximizing factor[n, l].
pred_idx, mdis = tvm.compute((N, ), lambda n: argmax((dummy_idx[rl], factor[n, rl]), axis=rl), name='pred_idx')
rn = tvm.reduce_axis((0, N), name='rn')
# err[0] counts samples whose label at the predicted index is below 0.5.
err = tvm.compute((1, ), lambda i: tvm.sum(1, rn, label[rn, pred_idx[rn]] < 0.5), name='err')
# === End computation

# Scheduling
s = tvm.create_schedule([pred_idx.op, err.op])
# operands, also need to keep the index of operand. It can be expressed # with :any:`comm_reducer` as below: # x and y are the operands of reduction, both of them is a tuple of index # and value. def fcombine(x, y): lhs = tvm.select((x[1] >= y[1]), x[0], y[0]) rhs = tvm.select((x[1] >= y[1]), x[1], y[1]) return lhs, rhs # our identity element also need to be a tuple, so `fidentity` accepts # two types as inputs. def fidentity(t0, t1): return tvm.const(-1, t0), tvm.min_value(t1) argmax = tvm.comm_reducer(fcombine, fidentity, name='argmax') # describe the reduction computation m = tvm.var('m') n = tvm.var('n') idx = tvm.placeholder((m, n), name='idx', dtype='int32') val = tvm.placeholder((m, n), name='val', dtype='int32') k = tvm.reduce_axis((0, n), 'k') T0, T1 = tvm.compute((m, ), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name='T') # the generated IR code would be: s = tvm.create_schedule(T0.op) print(tvm.lower(s, [idx, val, T0, T1], simple_mode=True)) ###################################################################### # .. note::
# Sum over both axes of A into a 0-d tensor (j and n are defined earlier).
i = tvm.reduce_axis((0, n), name='i')
B = tvm.compute((), lambda: tvm.sum(A[i, j], axis=(i, j)), name='b')
s = tvm.create_schedule(B.op)
print(tvm.lower(s, [A, B], simple_mode=True))
mod = tvm.build(s, [A, B])
c = tvm.nd.array(np.empty((), dtype='float32'))
mod(tvm.nd.array(a), c)
np.testing.assert_allclose(a.sum(), c.asnumpy(), atol=1e-5)

# Commutative Reduction
# f(a, b) = f(b, a)
# prod(axis=1)
# The combiner multiplies two operands and the identity is 1 in the given dtype.
comp = lambda a, b: a * b
init = lambda dtype: tvm.const(1, dtype=dtype)
product = tvm.comm_reducer(comp, init)
n, m = tvm.var('n'), tvm.var('m')
A = tvm.placeholder((n, m), name='a')
k = tvm.reduce_axis((0, m), name='k')
B = tvm.compute((n, ), lambda i: product(A[i, k], axis=k), name='b')
s = tvm.create_schedule(B.op)
print(tvm.lower(s, [A, B], simple_mode=True))
mod = tvm.build(s, [A, B])
# One output per row of `a`; sized from `a` instead of a hard-coded 3 so the
# buffer always matches a.prod(axis=1).
b = tvm.nd.array(np.empty((a.shape[0], ), dtype='float32'))
mod(tvm.nd.array(a), b)
np.testing.assert_allclose(a.prod(axis=1), b.asnumpy(), atol=1e-5)