# Shared imports for the code below (assert_print and is_power_of_x are
# helpers from this repository, defined elsewhere).
import math
from functools import reduce

import tvm
from tvm import te, topi
import torch.nn as nn


def zero_pad2d(inputs, padding=0):
    """Zero padding for 2d tensor

    Args:
    -----------------------------
    inputs : tvm.te.tensor.Tensor
        shape [batch, channel, height, width]
    padding: (optional:0) int or tuple
        expected: (h_pad_up, h_pad_down, w_pad_up, w_pad_down)
    -----------------------------

    Returns:
    -----------------------------
    tvm.te.tensor.Tensor
        shape [batch, channel, padded_height, padded_width]
    -----------------------------
    """
    padding = (padding, padding, padding, padding) if isinstance(padding, (int, tvm.tir.IntImm)) else padding
    assert_print(isinstance(padding, tuple), "type(padding)={}".format(type(padding)))
    if len(padding) == 2:
        padding = (padding[0], padding[0], padding[1], padding[1])
    assert_print(len(padding) == 4)

    padding_zero = tvm.tir.expr.const(0, inputs.dtype)

    batch_size, in_channel, height, width = inputs.shape
    return tvm.te.compute(
        (batch_size, in_channel, height + padding[0] + padding[1], width + padding[2] + padding[3]),
        lambda b, c, h, w: tvm.te.if_then_else(
            tvm.te.all(h >= padding[0], h < height + padding[0], w >= padding[2], w < width + padding[2]),
            inputs[b, c, h - padding[0], w - padding[2]],
            padding_zero
        ),
        name='Padding',
        requires_grad=True
    )
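# Usage sketch for zero_pad2d. Assumption: this repository's patched TVM,
# where tvm.te.compute accepts a requires_grad keyword (stock TVM does not).
# Illustrative only.
A = tvm.te.placeholder((1, 3, 32, 32), dtype="float32", name="A")
P = zero_pad2d(A, padding=(1, 1, 2, 2))  # pad H by 1 on each side, W by 2
print(P.shape)  # [1, 3, 34, 36]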
def cross_entropy(inputs, targets):
    '''
    https://pytorch.org/docs/stable/nn.html#torch.nn.CrossEntropyLoss
    loss(x, class) = -x[class] + log(sum_j exp(x[j]))
    -x[class]: since targets is one-hot, we use an inner product to compute x[class]
    log(sum_j exp(x[j])) may overflow when computing exp(x[j])
    >>>> Trick: maxval + log(sum_j exp(x[j] - maxval))
    Finally, compute the average over batches
    '''
    assert_print(inputs.shape[0].value == targets.shape[0].value)
    assert_print(inputs.shape[1].value == targets.shape[1].value)
    N, C = inputs.shape
    c = tvm.te.reduce_axis([0, C], "c")
    k1 = tvm.te.reduce_axis([0, C], name="k1")
    # First compute the maximum for each batch
    max_val = tvm.te.compute([N], lambda n: tvm.te.max(inputs[n, k1], axis=[k1]), name="max_val")
    # Use the log-softmax trick to avoid overflow
    sum_val = tvm.te.compute(
        [N], lambda i: tvm.te.sum(tvm.tir.exp(inputs[i, c] - max_val[i]), axis=[c]), "sum_val")
    rrn = tvm.te.reduce_axis([0, N], "rrn")
    rrc = tvm.te.reduce_axis([0, C], "rrc")
    x_class = tvm.te.compute(
        [N], lambda i: tvm.te.sum(inputs[i, rrc] * targets[i, rrc], axis=[rrc]), name="x_class")
    return tvm.te.compute([1], lambda i: tvm.te.sum(
        (tvm.tir.log(sum_val[i + rrn]) + max_val[i + rrn] - x_class[i + rrn]) / N,
        axis=[rrn]), name="cross_entropy")
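# For intuition, a NumPy reference of the same log-sum-exp computation
# (hypothetical sanity check, not part of the original kernels).
import numpy as np

x = np.random.randn(4, 10).astype("float32")                   # logits [N, C]
t = np.eye(10, dtype="float32")[np.random.randint(0, 10, 4)]   # one-hot targets
m = x.max(axis=1)                                              # maxval per row
lse = m + np.log(np.exp(x - m[:, None]).sum(axis=1))           # stable log-sum-exp
loss = np.mean(lse - (x * t).sum(axis=1))                      # average over batch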
def add_subspace(self, name, subspace, type_key, override=False):
    if name in self.subspaces and not override:
        raise RuntimeError("Subspace name '{}' already exists".format(name))
    assert_print(type_key in self.valid_type_keys)
    self.subspaces[name] = subspace
    self.types[type_key].append(name)
    self.dim += subspace.dim
def __init__(self, in_channel, out_channel, kernel_size, bias=False,
             padding=0, stride=1, dilation=1, groups=1):
    super(Conv2dLayer, self).__init__()
    if isinstance(kernel_size, (int, tvm.tir.IntImm)):
        kernel_size = (kernel_size, kernel_size)
    assert_print(isinstance(kernel_size, tuple) and len(kernel_size) == 2)
    self.weight = tvm.te.placeholder((out_channel, in_channel, *kernel_size), dtype="float32")
    self.params["weight"] = self.weight
    if bias:
        self.bias = tvm.te.placeholder((out_channel,), dtype="float32")
        self.params["bias"] = self.bias
    else:
        self.bias = None

    def forward_func(inputs):
        return conv2d_nchw(inputs, self.weight, self.bias, stride, padding, dilation, groups)

    self.forward_func = forward_func
def forward(self, batch_size, policy="random"):
    assert_print(policy in ["random", "best"])
    ret = dict()
    for name, walker in self.walkers.items():
        if policy == "random":
            ret_entities, ret_p_values = walker.random_batch(batch_size)
        elif policy == "best":
            ret_entities, ret_p_values = walker.best_batch(batch_size)
        ret[name] = (ret_entities, ret_p_values)
    return ret
def __init__(self, input_len, width, depth, output_len):
    super(Judger, self).__init__()
    assert_print(isinstance(width, int) and width > 0)
    assert_print(isinstance(depth, int) and depth > 1)
    self.net = nn.Sequential()
    self.net.add_module("input", nn.Linear(input_len, width))
    self.net.add_module("input_activate", nn.ReLU())
    for count in range(depth - 2):
        name = "hidden_{}".format(count)
        self.net.add_module(name, nn.Linear(width, width))
        self.net.add_module(name + "_activate", nn.ReLU())
    self.net.add_module("output", nn.Linear(width, output_len))
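# Usage sketch, assuming Judger subclasses torch.nn.Module and its forward
# simply applies self.net (illustrative shapes only).
import torch

judger = Judger(input_len=128, width=64, depth=3, output_len=1)
features = torch.randn(32, 128)   # a batch of 32 feature vectors
scores = judger.net(features)     # -> torch.Size([32, 1])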
def gemm(A, B, transposeA=False, transposeB=False):
    """Matrix multiplies matrix

    Args:
    -----------------------------
    A: tvm.te.tensor.Tensor
        shape [height, width] (shown for the non-transposed case)
    B: tvm.te.tensor.Tensor
        shape [width, length] (shown for the non-transposed case)
    transposeA: (optional:False) bool
    transposeB: (optional:False) bool
    -----------------------------

    Returns:
    -----------------------------
    tvm.te.tensor.Tensor
        shape [height, length]
    -----------------------------
    """
    if transposeA and transposeB:
        k = tvm.te.reduce_axis((0, B.shape[1]))
        assert_print(A.shape[0].value == B.shape[1].value)
        return tvm.te.compute((A.shape[1], B.shape[0]),
                              lambda i, j: tvm.te.sum(A[k, i] * B[j, k], axis=k),
                              requires_grad=True)
    elif transposeA and not transposeB:
        k = tvm.te.reduce_axis((0, B.shape[0]))
        assert_print(A.shape[0].value == B.shape[0].value)
        return tvm.te.compute((A.shape[1], B.shape[1]),
                              lambda i, j: tvm.te.sum(A[k, i] * B[k, j], axis=k),
                              requires_grad=True)
    elif not transposeA and transposeB:
        k = tvm.te.reduce_axis((0, B.shape[1]))
        assert_print(A.shape[1].value == B.shape[1].value)
        return tvm.te.compute((A.shape[0], B.shape[0]),
                              lambda i, j: tvm.te.sum(A[i, k] * B[j, k], axis=k),
                              requires_grad=True)
    else:
        k = tvm.te.reduce_axis((0, B.shape[0]))
        assert_print(A.shape[1].value == B.shape[0].value)
        return tvm.te.compute((A.shape[0], B.shape[1]),
                              lambda i, j: tvm.te.sum(A[i, k] * B[k, j], axis=k),
                              requires_grad=True)
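# Usage sketch (same patched-TVM assumption as above). Here A is stored
# transposed, so transposeA=True computes A^T @ B.
A = tvm.te.placeholder((64, 32), dtype="float32", name="A")   # [k, m]
B = tvm.te.placeholder((64, 128), dtype="float32", name="B")  # [k, n]
C = gemm(A, B, transposeA=True)
print(C.shape)  # [32, 128]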
def lltm(input, targets, weight_for_classify, bias_for_classify,
         weight_for_gate, bias_for_gate, old_h, old_c):
    '''
    input: [batch_size, 28*28]
    new/old_h & new/old_c: [batch_size, state_size]
    weight_for_classify: [10, state_size]
    bias_for_classify: [10]
    result: [batch_size, 10]
    targets: [batch_size, 10], one-hot
    '''
    new_h, new_c = internel_lltm(input, weight_for_gate, bias_for_gate, old_h, old_c)
    assert_print(new_h.shape[1].value == weight_for_classify.shape[1].value)
    result = topi.nn.dense(new_h, weight_for_classify, bias_for_classify)
    loss = cross_entropy(result, targets)
    return loss, result, new_h, new_c
def conv2d_nchw(inputs, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    batch_size, in_channel, in_h, in_w = inputs.shape
    out_channel, channel_per_group, k_h, k_w = weight.shape
    assert_print((channel_per_group * groups).value == in_channel.value)
    out_channel_per_group = out_channel // groups
    assert_print((out_channel_per_group * groups).value == out_channel.value)

    stride = (stride, stride) if isinstance(stride, (int, tvm.tir.IntImm)) else stride
    padding = (padding, padding) if isinstance(padding, (int, tvm.tir.IntImm)) else padding
    dilation = (dilation, dilation) if isinstance(dilation, (int, tvm.tir.IntImm)) else dilation
    assert_print(isinstance(stride, tuple) and len(stride) == 2)
    assert_print(isinstance(padding, tuple) and len(padding) == 2)
    assert_print(isinstance(dilation, tuple) and len(dilation) == 2)

    out_h = (in_h + 2 * padding[0] - dilation[0] * (k_h - 1) - 1) // stride[0] + 1
    out_w = (in_w + 2 * padding[1] - dilation[1] * (k_w - 1) - 1) // stride[1] + 1
    rc = tvm.te.reduce_axis((0, channel_per_group), name="rc")
    rh = tvm.te.reduce_axis((0, k_h), name="rh")
    rw = tvm.te.reduce_axis((0, k_w), name="rw")

    padded = zero_pad2d(inputs, padding=padding)
    output = tvm.te.compute(
        (batch_size, out_channel, out_h, out_w),
        lambda b, c, h, w: tvm.te.sum(
            (padded[b, c // out_channel_per_group * channel_per_group + rc,
                    h * stride[0] + rh * dilation[0],
                    w * stride[1] + rw * dilation[1]]
             * weight[c, rc, rh, rw]),
            axis=[rc, rw, rh]),
        requires_grad=True)
    if bias is not None:
        output = tvm.te.compute(
            (batch_size, out_channel, out_h, out_w),
            lambda b, c, h, w: output[b, c, h, w] + bias[c])
    return output
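# Shape sanity check (same patched-TVM assumption). With a 3x3 kernel,
# stride 1, padding 1, dilation 1, the formula above gives "same" size:
# out_h = (28 + 2*1 - 1*(3-1) - 1) // 1 + 1 = 28.
data = tvm.te.placeholder((1, 16, 28, 28), dtype="float32", name="data")
kern = tvm.te.placeholder((32, 16, 3, 3), dtype="float32", name="kern")
out = conv2d_nchw(data, kern, stride=1, padding=1)
print(out.shape)  # [1, 32, 28, 28]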
def next_entity(self, pos, d):
    # d is a tuple
    if len(d) == 1:
        next_pos = (pos + d[0]) % self.size
        return next_pos
    elif len(d) == 2:
        asc_pos, dec_pos = d[0], d[1]
        assert_print(0 <= asc_pos < self.dim)
        assert_print(0 <= dec_pos < self.dim)
        assert_print(asc_pos != dec_pos)
        current = self.static_entities[pos]
        ret = current.copy()
        left = current[asc_pos] * current[dec_pos]
        canout = False
        next_pos = -1
        while not canout:
            tmp = ret[asc_pos] + 1
            while tmp <= left:
                if self.allow_non_divisible == 'continuous':
                    break
                elif self.allow_non_divisible == 'power2' and is_power_of_x(2, tmp):
                    break
                elif left % tmp == 0:
                    break
                tmp += 1
            tmp = min(tmp, left)
            ret[asc_pos] = tmp
            ret[dec_pos] = math.ceil(left / tmp)
            try:
                next_pos = self.static_entities.index(ret)
                canout = True
            except ValueError:
                canout = False
        return next_pos
    else:
        raise RuntimeError(
            "Not supported for directions with more than two dims: {}".format(d))
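# is_power_of_x is referenced above but not defined in this section; a minimal
# sketch consistent with its usage (the repository's actual helper may differ).
def is_power_of_x(x, val):
    # True iff val == x**k for some integer k >= 0.
    if val < 1:
        return False
    while val % x == 0:
        val //= x
    return val == 1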
def cpu_schedule_simple(config, s, op, op_state):
    # always cache write here
    # if op.num_outputs > 1:
    #     raise RuntimeWarning("Too many outputs in one operation!")
    write_cache = s.cache_write(op.output(0), "global")

    # spatial split
    spatial_axes = s[op].op.axis
    splited_spatial_axes = []
    if "spatial" in config and len(config["spatial"]) > 0:
        # to align each axis
        assert_print(len(config["spatial"]) == len(spatial_axes), "align failed")
        for axis, nparts in zip(spatial_axes, config["spatial"]):
            nfactors = [1]
            count = len(nparts) - 1
            while count >= 0:
                nfactors.append(nparts[count] * nfactors[-1])
                count -= 1
            tmp_buffer = []
            num_factors = len(nfactors)
            for i in range(num_factors - 2):
                factor = nfactors[num_factors - 2 - i]
                part = nparts[i]
                if factor == 1:
                    tmp_buffer.append(axis)
                    axis = None
                elif part == 1:
                    tmp_buffer.append(None)
                else:
                    outer, axis = s[op].split(axis, factor=factor)
                    tmp_buffer.append(outer)
            tmp_buffer.append(axis)
            splited_spatial_axes.append(tmp_buffer)
    else:
        for axis in spatial_axes:
            splited_spatial_axes.append([axis])
    assert_print(len(splited_spatial_axes) > 0, "empty spatial axes")  # must be non-empty

    # always reorder and fuse here
    # this part actually supposes there is "spatial" in config, which is avoidable
    spatial_fuse_lsts = []
    spatial_fuse_extents = []
    reorder_lst = []
    fused_spatial_axes = []
    spatial_split_num_parts = len(splited_spatial_axes[0])
    for count in range(spatial_split_num_parts):
        tmp_buffer = [x[count] for x in splited_spatial_axes]
        tmp_extent = reduce(lambda a, b: a * b, [x[count] for x in config["spatial"]])
        spatial_fuse_lsts.append(tmp_buffer)
        spatial_fuse_extents.append(tmp_extent)
        reorder_lst.extend(tmp_buffer)
    reorder_lst_without_none = list(filter(lambda x: x is not None, reorder_lst))
    # print("reorder op", reorder_lst_without_none)
    s[op].reorder(*reorder_lst_without_none)
    for fuse_lst in spatial_fuse_lsts[:1]:
        tmp_buffer = list(filter(lambda x: x is not None, fuse_lst))
        # print("fuse op", tmp_buffer)
        fused = s[op].fuse(*tmp_buffer)
        fused_spatial_axes.append(fused)
    kernel_scope = fused_spatial_axes[0]
    if len(spatial_fuse_lsts) > 1:
        count = 0
        while count < len(config["spatial"]) and config["spatial"][count][1] == 1:
            count += 1
        if count == len(config["spatial"]):
            count -= 1
        next_pos_for_compute_at = spatial_fuse_lsts[1][count]
    else:
        next_pos_for_compute_at = kernel_scope

    # always parallel here
    s[op].parallel(kernel_scope)
    # vectorize
    if len(spatial_fuse_lsts) == 2:
        count = len(spatial_fuse_lsts[1]) - 1
        while count >= 1:
            if spatial_fuse_lsts[1][count] is not None and config["spatial"][1][count] > 1:
                # print("vectorize op", spatial_fuse_lsts[1][count])
                s[op].vectorize(spatial_fuse_lsts[1][count])
                break
            count -= 1
    elif len(spatial_fuse_lsts) > 2:
        count = len(spatial_fuse_lsts[-1]) - 1
        while count >= 0:
            if spatial_fuse_lsts[-1][count] is not None and config["spatial"][count][-1] > 1:
                # print("vectorize op", spatial_fuse_lsts[-1][count])
                s[op].vectorize(spatial_fuse_lsts[-1][count])
                break
            count -= 1
    # always compute at here
    # print("compute at", next_pos_for_compute_at)
    s[write_cache].compute_at(s[op], next_pos_for_compute_at)

    # spatial split for write cache
    spatial_axes = s[write_cache].op.axis
    num_spatial_axes = len(spatial_axes)
    splited_spatial_axes = []
    if "spatial" in config and len(config["spatial"]) > 0:
        # to align each axis
        assert_print(len(config["spatial"]) == len(spatial_axes), "align failed")
        for axis, nparts in zip(spatial_axes, config["spatial"]):
            nfactors = [1]
            count = len(nparts) - 1
            while count >= 0:
                nfactors.append(nparts[count] * nfactors[-1])
                count -= 1
            tmp_buffer = []
            num_factors = len(nfactors)
            for i in range(num_factors - 2):
                factor = nfactors[num_factors - 2 - i]
                part = nparts[i]
                if factor == 1:
                    tmp_buffer.append(axis)
                    axis = None
                elif part == 1:
                    tmp_buffer.append(None)
                else:
                    outer, axis = s[write_cache].split(axis, factor=factor)
                    tmp_buffer.append(outer)
            tmp_buffer.append(axis)
            splited_spatial_axes.append(tmp_buffer)
    else:
        for axis in spatial_axes:
            splited_spatial_axes.append([axis])
    assert_print(len(splited_spatial_axes) > 0, "empty spatial axes")  # must be non-empty

    # reduce split for write cache
    reduced_axes = s[write_cache].op.reduce_axis
    num_reduce_axes = len(reduced_axes)
    splited_reduced_axes = []
    if "reduce" in config and len(config["reduce"]) > 0:
        # to align each axis
        assert_print(len(config["reduce"]) == len(reduced_axes), "align reduce failed")
        for axis, nparts in zip(reduced_axes, config["reduce"]):
            nfactors = [1]
            count = len(nparts) - 1
            while count >= 0:
                nfactors.append(nparts[count] * nfactors[-1])
                count -= 1
            tmp_buffer = []
            num_factors = len(nfactors)
            for i in range(num_factors - 2):
                factor = nfactors[num_factors - 2 - i]
                part = nparts[i]
                if factor == 1:
                    tmp_buffer.append(axis)
                    axis = None
                elif part == 1:
                    tmp_buffer.append(None)
                else:
                    outer, axis = s[write_cache].split(axis, factor=factor)
                    tmp_buffer.append(outer)
            tmp_buffer.append(axis)
            splited_reduced_axes.append(tmp_buffer)
    else:
        for axis in reduced_axes:
            splited_reduced_axes.append([axis])
    # for easy align
    # reduce_split_num_parts = len(splited_reduced_axes[0])
    # assert reduce_split_num_parts == spatial_split_num_parts

    # reorder hybrid for spatial and reduce
    hybrid_axes = splited_spatial_axes + splited_reduced_axes
    hybrid_fuse_lsts = []
    hybrid_reorder_lst = []
    for count in range(spatial_split_num_parts):
        tmp_buffer = [x[count] for x in hybrid_axes]
        hybrid_fuse_lsts.append(tmp_buffer)
        hybrid_reorder_lst.extend(tmp_buffer)
    if len(hybrid_fuse_lsts) > 1:
        last_parts = hybrid_reorder_lst[-num_spatial_axes - num_reduce_axes:]
        hybrid_reorder_lst = hybrid_reorder_lst[:-num_spatial_axes - num_reduce_axes]
        tmp_buffer = last_parts[-num_reduce_axes:]
        tmp_buffer.extend(last_parts[:-num_reduce_axes])
        hybrid_reorder_lst.extend(tmp_buffer)
    hybrid_reorder_lst_without_none = list(filter(lambda x: x is not None, hybrid_reorder_lst))
    # print("reorder cache write", hybrid_reorder_lst_without_none)
    s[write_cache].reorder(*hybrid_reorder_lst_without_none)

    # fuse without reduce axes
    # assert len(hybrid_fuse_lsts) > 0
    # s[write_cache].fuse(*hybrid_fuse_lsts[0][:-num_reduce_axes])

    # unroll and vectorize without reduce axes
    if len(hybrid_fuse_lsts) > 1:
        rcount = num_spatial_axes - 1
        while rcount >= 0 and config["spatial"][rcount][-1] == 1:
            rcount -= 1
        if rcount >= 0:
            # print("vectorize cache write", hybrid_fuse_lsts[-1][rcount])
            s[write_cache].vectorize(hybrid_fuse_lsts[-1][rcount])
        for count in range(rcount):
            if config["spatial"][count][-1] > 1:
                # print("unroll cache write", hybrid_fuse_lsts[-1][count])
                s[write_cache].unroll(hybrid_fuse_lsts[-1][count])
    if len(hybrid_fuse_lsts) > 2:
        for count in range(num_spatial_axes):
            if config["spatial"][count][-2] > 1:
                # print("unroll cache write", hybrid_fuse_lsts[-2][count])
                s[write_cache].unroll(hybrid_fuse_lsts[-2][count])
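# A worked example of the suffix-product split used above (illustrative
# values; not part of the original code). For nparts = [4, 2, 8]:
nparts = [4, 2, 8]            # desired extents, outermost first
nfactors = [1]
for p in reversed(nparts):
    nfactors.append(p * nfactors[-1])
print(nfactors)               # [1, 8, 16, 64]
# The axis is first split by factor 16 (outer extent 4), then the inner
# axis by factor 8 (outer extent 2), leaving an innermost extent of 8.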
def cpu_schedule_split_reorder_fuse(config, s, op, op_state):
    # assert_print(op in s)

    loop_idx = []
    loop_lst = []

    # always cache write here
    # if op.num_outputs > 1:
    #     raise RuntimeWarning("Too many outputs in one operation!")
    write_cache = s.cache_write(op.output(0), "local")

    # spatial split
    spatial_axes = [axis for axis in s[op].op.axis]
    assert len(spatial_axes) > 0, "empty spatial axes"  # must be non-empty
    n = spatial_axes[0]
    kernel_scope, n = s[op].split(n, nparts=1)
    spatial_axes[0] = n

    splited_spatial_axes = []
    splited_spatial_extents = []
    if "spatial" in config and len(config["spatial"]) > 0:
        # to align each axis
        assert len(config["spatial"]) == len(spatial_axes), "align failed"
        for axis, nparts in zip(spatial_axes, config["spatial"]):
            tmp_buffer = []
            tmp_extents = []
            for count in range(len(nparts) - 1):
                outer, axis = s[op].split(axis, nparts=nparts[count])
                tmp_buffer.append(outer)
                tmp_extents.append(nparts[count])
            tmp_buffer.append(axis)
            tmp_extents.append(nparts[-1])
            splited_spatial_axes.append(tmp_buffer)
            splited_spatial_extents.append(tmp_extents)
    else:
        for axis in spatial_axes:
            splited_spatial_axes.append([axis])
            splited_spatial_extents.append([axis.dom.extent.value])

    # always reorder here
    reorder_lst = []
    reorder_parts = []
    reorder_part_extents = []
    for count in range(len(splited_spatial_axes[0])):
        tmp_buffer = [x[count] for x in splited_spatial_axes]
        tmp_extents = [x[count] for x in splited_spatial_extents]
        reorder_lst.extend(tmp_buffer)
        reorder_parts.append(tmp_buffer)
        reorder_part_extents.append(tmp_extents)
    s[op].reorder(*reorder_lst)

    # handle fuse request
    fused_parts = []
    fused_part_extents = []
    fused_part_idx = []
    if "fuse" in config and len(config["fuse"]) > 0:
        base_id = 0
        for part, extents in zip(reorder_parts, reorder_part_extents):
            tmp_part = []
            tmp_extents = []
            tmp_idx = []
            idx = 0
            beg = 0
            for end in config["fuse"][0]:
                if end - beg > 1:
                    fuse_lst = part[beg:end]
                    fused = s[op].fuse(*fuse_lst)
                    tmp_part.append(fused)
                    extent = reduce(lambda x, y: x * y, extents[beg:end], 1)
                    tmp_idx.extend([idx] * (end - beg))
                    idx += 1
                    tmp_extents.append(extent)
                elif end - beg == 1:
                    tmp_part.append(part[beg])
                    tmp_extents.append(extents[beg])
                    tmp_idx.append(idx)
                    idx += 1
                beg = end
            fused_parts.append(tmp_part)
            fused_part_extents.append(tmp_extents)
            fused_part_idx.append(tmp_idx)
            # for op state
            loop_lst.extend(tmp_part)
            loop_idx.extend([x + base_id for x in tmp_idx])
            base_id += len(tmp_part)
    else:
        fused_parts = reorder_parts
        fused_part_extents = reorder_part_extents
        fused_part_idx = [list(range(len(x))) for x in reorder_parts]
        # for op state
        loop_lst = reorder_lst
        loop_idx = list(range(len(reorder_lst)))

    # record op state
    op_state.loop_lst = loop_lst
    op_state.loop_idx = loop_idx

    # parallel
    fused = s[op].fuse(*fused_parts[0])
    s[op].parallel(fused)

    # compute at
    num_parts = len(fused_parts)
    if num_parts == 1:
        local_pos = fused
    elif num_parts == 2:
        local_pos = fused_parts[num_parts - 1][0]
    else:
        local_pos = fused_parts[num_parts - 2][-1]

    if "local_pos" in config and len(config["local_pos"]) > 0:
        local_at_part = config["local_pos"][0][0]
        local_at_idx = config["local_pos"][0][1]
        # index changed because of fusion
        cur_idx = fused_part_idx[local_at_part][local_at_idx]
        local_pos = fused_parts[local_at_part][cur_idx]

    # always compute at here
    s[write_cache].compute_at(s[op], local_pos)

    # reduce_split
    reduced_axes = s[write_cache].op.reduce_axis
    splited_reduced_axes = []
    if "reduce" in config and len(config["reduce"]) > 0:
        # to align each axis
        assert_print(len(config["reduce"]) == len(reduced_axes), "align reduce failed")
        for axis, nparts in zip(reduced_axes, config["reduce"]):
            tmp_buffer = []
            for count in range(len(nparts) - 1):
                outer, axis = s[write_cache].split(axis, nparts=nparts[count])
                tmp_buffer.append(outer)
            tmp_buffer.append(axis)
            splited_reduced_axes.append(tmp_buffer)
    else:
        for axis in reduced_axes:
            splited_reduced_axes.append([axis])

    # if has reduce axes
    if len(splited_reduced_axes) > 0:
        # always reorder here
        reduced_nonfuse_lsts = []
        reorder_lst = []
        length = len(splited_reduced_axes[0])
        # leave the last part
        for count in range(length - 1):
            tmp_buffer = [x[count] for x in splited_reduced_axes]
            reduced_nonfuse_lsts.append(tmp_buffer)
            reorder_lst.extend(tmp_buffer)
        # the last part
        last_part = [x[length - 1] for x in splited_reduced_axes]
        spatial_remainder = s[write_cache].op.axis
        # change the order of reduce axes and spatial axes
        if "reorder" in config and len(config["reorder"]) > 0:
            pos = config["reorder"][0][0]
            assert pos < len(spatial_remainder)
            tmp_buffer = []
            count = len(spatial_remainder) - 1
            while count > pos:
                tmp_buffer.append(spatial_remainder[count])
                count -= 1
            p = pos
            q = len(last_part) - 1
            while p >= 0 and q >= 0:
                tmp_buffer.append(spatial_remainder[p])
                tmp_buffer.append(last_part[q])
                p -= 1
                q -= 1
            while p >= 0:
                tmp_buffer.append(spatial_remainder[p])
                p -= 1
            while q >= 0:
                tmp_buffer.append(last_part[q])
                q -= 1
            tmp_buffer = list(reversed(tmp_buffer))
            reorder_lst.extend(tmp_buffer)
        else:
            reorder_lst.extend(last_part)
            reorder_lst.extend(spatial_remainder)
        s[write_cache].reorder(*reorder_lst)

    # unroll
    if "unroll" in config and len(config["unroll"]) > 0:
        step = config["unroll"][0][0]
        explicit = config["unroll"][0][1]
        s[op].pragma(kernel_scope, 'auto_unroll_max_step', step)
        s[op].pragma(kernel_scope, 'unroll_explicit', explicit)
def gemm(A, B):
    k = tvm.te.reduce_axis((0, B.shape[0]))
    assert_print(A.shape[1].value == B.shape[0].value)
    return tvm.te.compute([A.shape[0], B.shape[1]],
                          lambda i, j: tvm.te.sum(A[i, k] * B[k, j], axis=k),
                          requires_grad=True)
def cuda_schedule_split_fuse(config, s, op, op_state):
    # assert_print(op in s)

    # always cache write here
    # if op.num_outputs > 1:
    #     raise RuntimeWarning("Too many outputs in one operation!")
    write_cache = s.cache_write(op.output(0), "local")

    # always cache read here
    read_cache_share_lst = []
    read_cache_local_lst = []
    for t in op.input_tensors:
        share = s.cache_read(t, "shared", [write_cache])
        read_cache_share_lst.append(share)
        local = s.cache_read(share, "local", [write_cache])
        read_cache_local_lst.append(local)

    # spatial split
    spatial_axes = s[op].op.axis
    splited_spatial_axes = []
    if "spatial" in config and len(config["spatial"]) > 0:
        # to align each axis
        assert_print(len(config["spatial"]) == len(spatial_axes), "align failed")
        for axis, nparts in zip(spatial_axes, config["spatial"]):
            tmp_buffer = []
            for count in range(len(nparts) - 1):
                outer, axis = s[op].split(axis, nparts=nparts[count])
                tmp_buffer.append(outer)
            tmp_buffer.append(axis)
            splited_spatial_axes.append(tmp_buffer)
    else:
        for axis in spatial_axes:
            splited_spatial_axes.append([axis])
    assert_print(len(splited_spatial_axes) > 0, "empty spatial axes")  # must be non-empty

    # always reorder and fuse here
    spatial_fuse_lsts = []
    spatial_fuse_extents = []
    reorder_lst = []
    fused_spatial_axes = []
    for count in range(len(splited_spatial_axes[0])):
        tmp_buffer = [x[count] for x in splited_spatial_axes]
        tmp_extent = reduce(lambda a, b: a * b, [x[count] for x in config["spatial"]])
        spatial_fuse_lsts.append(tmp_buffer)
        spatial_fuse_extents.append(tmp_extent)
        reorder_lst.extend(tmp_buffer)
    s[op].reorder(*reorder_lst)
    for fuse_lst in spatial_fuse_lsts:
        fused = s[op].fuse(*fuse_lst)
        fused_spatial_axes.append(fused)
    kernel_scope = fused_spatial_axes[0]

    # always bind here
    length = len(fused_spatial_axes)
    thread_extents = 1
    assert_print(length > 1, "fused axes length <= 1")
    if 2 <= length <= 3:
        s[op].bind(fused_spatial_axes[0], te.thread_axis("blockIdx.x"))
        s[op].bind(fused_spatial_axes[1], te.thread_axis("threadIdx.x"))
        thread_pos = fused_spatial_axes[1]
        thread_extents = spatial_fuse_extents[1]
    else:
        s[op].bind(fused_spatial_axes[0], te.thread_axis("blockIdx.x"))
        s[op].bind(fused_spatial_axes[1], te.thread_axis("vthread"))
        s[op].bind(fused_spatial_axes[2], te.thread_axis("threadIdx.x"))
        thread_pos = fused_spatial_axes[2]
        thread_extents = spatial_fuse_extents[2]

    # always compute at here
    s[write_cache].compute_at(s[op], thread_pos)

    # reduce_split
    reduced_axes = s[write_cache].op.reduce_axis
    splited_reduced_axes = []
    if "reduce" in config and len(config["reduce"]) > 0:
        # to align each axis
        assert_print(len(config["reduce"]) == len(reduced_axes), "align reduce failed")
        for axis, nparts in zip(reduced_axes, config["reduce"]):
            tmp_buffer = []
            for count in range(len(nparts) - 1):
                outer, axis = s[write_cache].split(axis, nparts=nparts[count])
                tmp_buffer.append(outer)
            tmp_buffer.append(axis)
            splited_reduced_axes.append(tmp_buffer)
    else:
        for axis in reduced_axes:
            splited_reduced_axes.append([axis])
    share_pos = None
    local_pos = None
    # if has reduce axes
    if len(splited_reduced_axes) > 0:
        # always reorder here
        reduced_nonfuse_lsts = []
        reorder_lst = []
        length = len(splited_reduced_axes[0])
        for count in range(length):
            tmp_buffer = [x[count] for x in splited_reduced_axes]
            reduced_nonfuse_lsts.append(tmp_buffer)
            reorder_lst.extend(tmp_buffer)
        # change the order of reduce axes and spatial axes
        reorder_lst.extend(s[write_cache].op.axis)
        s[write_cache].reorder(*reorder_lst)
        if length == 1:
            share_pos = reduced_nonfuse_lsts[-1][0]
        else:
            share_pos = reduced_nonfuse_lsts[-2][0]
            local_pos = reduced_nonfuse_lsts[-1][-1]

    # always cache read here
    if share_pos is not None:
        for share in read_cache_share_lst:
            s[share].compute_at(s[write_cache], share_pos)
    else:
        for share in read_cache_share_lst:
            s[share].compute_inline()
    if local_pos is not None:
        for local in read_cache_local_lst:
            s[local].compute_at(s[write_cache], local_pos)
    else:
        for local in read_cache_local_lst:
            s[local].compute_inline()

    # always cooperative fetching
    if share_pos is not None:
        for share in read_cache_share_lst:
            fuse_lst = s[share].op.axis
            fused = s[share].fuse(*fuse_lst)
            outer, inner = s[share].split(fused, nparts=thread_extents)
            s[share].bind(outer, te.thread_axis("threadIdx.x"))

    # unroll
    if "unroll" in config and len(config["unroll"]) > 0:
        step = config["unroll"][0][0]
        explicit = config["unroll"][0][1]
        s[op].pragma(kernel_scope, 'auto_unroll_max_step', step)
        s[op].pragma(kernel_scope, 'unroll_explicit', explicit)
def cuda_schedule_fuse_split(config, s, op, op_state):
    # assert_print(op in s)

    # always cache write here
    # if op.num_outputs > 1:
    #     raise RuntimeWarning("Too many outputs in one operation!")
    write_cache = s.cache_write(op.output(0), "local")

    # always cache read here
    read_cache_share_lst = []
    # read_cache_local_lst = []
    for t in op.input_tensors:
        share = s.cache_read(t, "shared", [write_cache])
        read_cache_share_lst.append(share)
        # local = s.cache_read(share, "local", [write_cache])
        # read_cache_local_lst.append(local)

    # spatial fuse
    spatial_axes = s[op].op.axis
    fused_spatial_axes = []
    if "fuse" in config and len(config["fuse"]) > 0:
        # fuse redundant axes
        beg = 0
        for end in config["fuse"][0]:
            fuse_lst = spatial_axes[beg:end]
            beg = end
            if len(fuse_lst) > 0:
                fused = s[op].fuse(*fuse_lst)
                fused_spatial_axes.append(fused)
    else:
        fused_spatial_axes = spatial_axes

    # spatial split
    split_factor_lst = []
    splited_spatial_axes = []
    if "spatial" in config and len(config["spatial"]) > 0:
        # to align each axis
        assert len(config["spatial"]) == len(spatial_axes), "align failed"
        # compute split factors
        if "fuse" in config and len(config["fuse"]) > 0:
            beg = 0
            for end in config["fuse"][0]:
                tmp_lst = [1] * len(config["spatial"][0])
                for i in range(beg, end):
                    for j in range(len(config["spatial"][i])):
                        tmp_lst[j] *= config["spatial"][i][j]
                if beg < end:
                    split_factor_lst.append(tmp_lst)
                beg = end
        else:
            split_factor_lst = config["spatial"]
        assert len(fused_spatial_axes) == len(split_factor_lst), "align failed"
        for axis, nparts in zip(fused_spatial_axes, split_factor_lst):
            tmp_buffer = []
            for count in range(len(nparts) - 1):
                outer, axis = s[op].split(axis, nparts=nparts[count])
                tmp_buffer.append(outer)
            tmp_buffer.append(axis)
            splited_spatial_axes.append(tmp_buffer)
    else:
        for axis in fused_spatial_axes:
            splited_spatial_axes.append([axis])
    assert len(splited_spatial_axes) > 0, "empty spatial axes"  # must be non-empty

    # always reorder here
    reorder_lst = []
    for count in range(len(splited_spatial_axes[0])):
        tmp_buffer = [x[count] for x in splited_spatial_axes]
        reorder_lst.extend(tmp_buffer)
    s[op].reorder(*reorder_lst)

    # fix kernel scope
    kernel_scope = reorder_lst[0]

    # always bind here
    # - prepare thread axis
    bx = te.thread_axis("blockIdx.x")
    by = te.thread_axis("blockIdx.y")
    bz = te.thread_axis("blockIdx.z")
    vx = te.thread_axis("vthread")
    vy = te.thread_axis("vthread")
    vz = te.thread_axis("vthread")
    tx = te.thread_axis("threadIdx.x")
    ty = te.thread_axis("threadIdx.y")
    tz = te.thread_axis("threadIdx.z")

    blocks = [bz, by, bx]
    threads = [tz, ty, tx]
    vthreads = [vz, vy, vx]

    block_extents = [-1, -1, -1]    # z, y, x
    virtual_extents = [-1, -1, -1]
    thread_extents = [-1, -1, -1]

    length = len(splited_spatial_axes)
    assert length >= 1
    # - bind
    count = min(length, len(blocks)) - 1
    while count >= 0:
        parts = len(splited_spatial_axes[count])
        assert parts > 0
        if parts == 1:
            s[op].bind(splited_spatial_axes[count][0], blocks[count])
            block_extents[count] = split_factor_lst[count][0]
        elif parts == 2:
            s[op].bind(splited_spatial_axes[count][0], blocks[count])
            block_extents[count] = split_factor_lst[count][0]
            s[op].bind(splited_spatial_axes[count][1], threads[count])
            thread_extents[count] = split_factor_lst[count][1]
        else:
            s[op].bind(splited_spatial_axes[count][0], blocks[count])
            block_extents[count] = split_factor_lst[count][0]
            s[op].bind(splited_spatial_axes[count][1], vthreads[count])
            virtual_extents[count] = split_factor_lst[count][1]
            s[op].bind(splited_spatial_axes[count][2], threads[count])
            thread_extents[count] = split_factor_lst[count][2]
        count -= 1

    # - compute at pos
    count = min(length, len(blocks)) - 1
    parts = len(splited_spatial_axes[count])
    thread_pos = splited_spatial_axes[count][min(parts - 1, 2)]

    # always compute at here
    s[write_cache].compute_at(s[op], thread_pos)

    # reduce_split
    reduced_axes = s[write_cache].op.reduce_axis
    splited_reduced_axes = []
    if "reduce" in config and len(config["reduce"]) > 0:
        # to align each axis
        assert_print(len(config["reduce"]) == len(reduced_axes), "align reduce failed")
        for axis, nparts in zip(reduced_axes, config["reduce"]):
            tmp_buffer = []
            for count in range(len(nparts) - 1):
                outer, axis = s[write_cache].split(axis, nparts=nparts[count])
                tmp_buffer.append(outer)
            tmp_buffer.append(axis)
            splited_reduced_axes.append(tmp_buffer)
    else:
        for axis in reduced_axes:
            splited_reduced_axes.append([axis])
    share_pos = None
    # local_pos = None
    # if has reduce axes
    if len(splited_reduced_axes) > 0:
        # always reorder here
        reduced_nonfuse_lsts = []
        reorder_lst = []
        length = len(splited_reduced_axes[0])
        # leave the last part
        for count in range(length - 1):
            tmp_buffer = [x[count] for x in splited_reduced_axes]
            reduced_nonfuse_lsts.append(tmp_buffer)
            reorder_lst.extend(tmp_buffer)
        # the last part
        last_part = [x[length - 1] for x in splited_reduced_axes]
        spatial_remainder = s[write_cache].op.axis
        # change the order of reduce axes and spatial axes
        if "reorder" in config and len(config["reorder"]) > 0:
            pos = config["reorder"][0][0]
            assert pos < len(spatial_remainder)
            tmp_buffer = []
            count = len(spatial_remainder) - 1
            while count > pos:
                tmp_buffer.append(spatial_remainder[count])
                count -= 1
            p = pos
            q = len(last_part) - 1
            while p >= 0 and q >= 0:
                tmp_buffer.append(spatial_remainder[p])
                tmp_buffer.append(last_part[q])
                p -= 1
                q -= 1
            while p >= 0:
                tmp_buffer.append(spatial_remainder[p])
                p -= 1
            while q >= 0:
                tmp_buffer.append(last_part[q])
                q -= 1
            tmp_buffer = list(reversed(tmp_buffer))
            reorder_lst.extend(tmp_buffer)
        else:
            reorder_lst.extend(last_part)
            reorder_lst.extend(spatial_remainder)
        s[write_cache].reorder(*reorder_lst)
        # decide where to compute at
        if length == 1:
            share_pos = last_part[-1]
        else:
            mid = math.ceil(length / 2.0) - 1
            share_pos = reduced_nonfuse_lsts[mid][-1]
            # local_pos = last_part[-1]

    # always cache read here
    if share_pos is not None:
        for share in read_cache_share_lst:
            s[share].compute_at(s[write_cache], share_pos)
    else:
        for share in read_cache_share_lst:
            s[share].compute_inline()
    # if local_pos is not None:
    #     for local in read_cache_local_lst:
    #         s[local].compute_at(s[write_cache], local_pos)
    # else:
    #     for local in read_cache_local_lst:
    #         s[local].compute_inline()

    # always cooperative fetching
    if share_pos is not None:
        for share in read_cache_share_lst:
            fuse_lst = s[share].op.axis
            fused = s[share].fuse(*fuse_lst)
            count = 2
            cur = 1
            limit = 1024
            while count >= 0:
                factor = thread_extents[count]
                if factor < 0:
                    defined = False
                    factor = 16
                else:
                    defined = True
                cur *= factor
                if not defined and cur > limit:
                    break
                fused, inner = s[share].split(fused, factor=factor)
                s[share].bind(inner, threads[count])
                count -= 1

    # unroll
    if "unroll" in config and len(config["unroll"]) > 0:
        step = config["unroll"][0][0]
        explicit = config["unroll"][0][1]
        s[op].pragma(kernel_scope, 'auto_unroll_max_step', step)
        s[op].pragma(kernel_scope, 'unroll_explicit', explicit)
def cuda_schedule_split_reorder_fuse(config, s, op, op_state):
    # assert_print(op in s)

    loop_lst = []
    loop_idx = []

    # always cache write here
    # if op.num_outputs > 1:
    #     raise RuntimeWarning("Too many outputs in one operation!")
    write_cache = s.cache_write(op.output(0), "local")

    # always cache read here
    read_cache_share_lst = []
    # read_cache_local_lst = []
    for t in op.input_tensors:
        share = s.cache_read(t, "shared", [write_cache])
        read_cache_share_lst.append(share)
        # local = s.cache_read(share, "local", [write_cache])
        # read_cache_local_lst.append(local)

    # spatial split
    spatial_axes = [axis for axis in s[op].op.axis]
    assert len(spatial_axes) > 0, "empty spatial axes"  # must be non-empty
    n = spatial_axes[0]
    kernel_scope, n = s[op].split(n, nparts=1)
    spatial_axes[0] = n

    splited_spatial_axes = []
    splited_spatial_extents = []
    if "spatial" in config and len(config["spatial"]) > 0:
        # to align each axis
        assert len(config["spatial"]) == len(spatial_axes), "align failed"
        for axis, nparts in zip(spatial_axes, config["spatial"]):
            tmp_buffer = []
            tmp_extents = []
            for count in range(len(nparts) - 1):
                outer, axis = s[op].split(axis, nparts=nparts[count])
                tmp_buffer.append(outer)
                tmp_extents.append(nparts[count])
            tmp_buffer.append(axis)
            tmp_extents.append(nparts[-1])
            splited_spatial_axes.append(tmp_buffer)
            splited_spatial_extents.append(tmp_extents)
    else:
        for axis in spatial_axes:
            splited_spatial_axes.append([axis])
            splited_spatial_extents.append([axis.dom.extent.value])

    # always reorder here
    reorder_lst = []
    reorder_parts = []
    reorder_part_extents = []
    for count in range(len(splited_spatial_axes[0])):
        tmp_buffer = [x[count] for x in splited_spatial_axes]
        tmp_extents = [x[count] for x in splited_spatial_extents]
        reorder_lst.extend(tmp_buffer)
        reorder_parts.append(tmp_buffer)
        reorder_part_extents.append(tmp_extents)
    s[op].reorder(*reorder_lst)

    # handle fuse request
    fused_parts = []
    fused_part_extents = []
    fused_part_idx = []
    if "fuse" in config and len(config["fuse"]) > 0:
        base_id = 0
        for part, extents in zip(reorder_parts, reorder_part_extents):
            tmp_part = []
            tmp_extents = []
            tmp_idx = []
            idx = 0
            beg = 0
            for end in config["fuse"][0]:
                if end - beg > 1:
                    fuse_lst = part[beg:end]
                    fused = s[op].fuse(*fuse_lst)
                    tmp_part.append(fused)
                    extent = reduce(lambda x, y: x * y, extents[beg:end], 1)
                    tmp_idx.extend([idx] * (end - beg))
                    idx += 1
                    tmp_extents.append(extent)
                elif end - beg == 1:
                    tmp_part.append(part[beg])
                    tmp_extents.append(extents[beg])
                    tmp_idx.append(idx)
                    idx += 1
                beg = end
            fused_parts.append(tmp_part)
            fused_part_extents.append(tmp_extents)
            fused_part_idx.append(tmp_idx)
            loop_lst.extend(tmp_part)
            loop_idx.extend([x + base_id for x in tmp_idx])
            base_id += len(tmp_part)
    else:
        fused_parts = reorder_parts
        fused_part_extents = reorder_part_extents
        fused_part_idx = [list(range(len(x))) for x in reorder_parts]
        loop_lst = reorder_lst
        loop_idx = list(range(len(reorder_lst)))

    # record op state
    op_state.loop_lst = loop_lst
    op_state.loop_idx = loop_idx

    # always bind here
    # - prepare thread axis
    bx = te.thread_axis("blockIdx.x")
    by = te.thread_axis("blockIdx.y")
    bz = te.thread_axis("blockIdx.z")
    vx = te.thread_axis("vthread")
    vy = te.thread_axis("vthread")
    vz = te.thread_axis("vthread")
    tx = te.thread_axis("threadIdx.x")
    ty = te.thread_axis("threadIdx.y")
    tz = te.thread_axis("threadIdx.z")

    blocks = [bz, by, bx]
    threads = [tz, ty, tx]
    vthreads = [vz, vy, vx]

    block_extents = [-1, -1, -1]    # z, y, x
    virtual_extents = [-1, -1, -1]
    thread_extents = [-1, -1, -1]

    bind_option = [None, None, None]
    bind_candidate = [blocks, vthreads, threads]
    candidate_extents = [block_extents, virtual_extents, thread_extents]

    # - bind
    num_parts = len(fused_parts)
    if num_parts == 1:
        bind_option[0] = (fused_parts[0], fused_part_extents[0])
        local_pos = fused_parts[0][:len(bind_candidate[0])][-1]
    elif num_parts == 2:
        bind_option[0] = (fused_parts[0], fused_part_extents[0])
        bind_option[2] = (fused_parts[1], fused_part_extents[1])
        local_pos = fused_parts[1][:len(bind_candidate[2])][-1]
    else:
        bind_option[0] = (fused_parts[0], fused_part_extents[0])
        bind_option[1] = (fused_parts[1], fused_part_extents[1])
        bind_option[2] = (fused_parts[2], fused_part_extents[2])
        local_pos = fused_parts[2][:len(bind_candidate[2])][-1]
    for option, candidate, extents in zip(bind_option, bind_candidate, candidate_extents):
        if option is not None:
            for i, axis in enumerate(option[0][:len(candidate)]):
                s[op].bind(axis, candidate[i])
                extents[i] = option[1][i]

    # compute at
    if "local_pos" in config and len(config["local_pos"]) > 0:
        local_at_part = config["local_pos"][0][0]
        local_at_idx = config["local_pos"][0][1]
        # index changed because of fusion
        cur_idx = fused_part_idx[local_at_part][local_at_idx]
        local_pos = fused_parts[local_at_part][cur_idx]

    # always compute at here
    s[write_cache].compute_at(s[op], local_pos)

    # reduce_split
    reduced_axes = s[write_cache].op.reduce_axis
    splited_reduced_axes = []
    if "reduce" in config and len(config["reduce"]) > 0:
        # to align each axis
        assert_print(len(config["reduce"]) == len(reduced_axes), "align reduce failed")
        for axis, nparts in zip(reduced_axes, config["reduce"]):
            tmp_buffer = []
            for count in range(len(nparts) - 1):
                outer, axis = s[write_cache].split(axis, nparts=nparts[count])
                tmp_buffer.append(outer)
            tmp_buffer.append(axis)
            splited_reduced_axes.append(tmp_buffer)
    else:
        for axis in reduced_axes:
            splited_reduced_axes.append([axis])
    share_pos = None
    # local_pos = None
    # if has reduce axes
    if len(splited_reduced_axes) > 0:
        # always reorder here
        reduced_nonfuse_lsts = []
        reorder_lst = []
        length = len(splited_reduced_axes[0])
        # leave the last part
        for count in range(length - 1):
            tmp_buffer = [x[count] for x in splited_reduced_axes]
            reduced_nonfuse_lsts.append(tmp_buffer)
            reorder_lst.extend(tmp_buffer)
        # the last part
        last_part = [x[length - 1] for x in splited_reduced_axes]
        spatial_remainder = s[write_cache].op.axis
        # change the order of reduce axes and spatial axes
        if "reorder" in config and len(config["reorder"]) > 0:
            pos = config["reorder"][0][0]
            assert pos < len(spatial_remainder)
            tmp_buffer = []
            count = len(spatial_remainder) - 1
            while count > pos:
                tmp_buffer.append(spatial_remainder[count])
                count -= 1
            p = pos
            q = len(last_part) - 1
            while p >= 0 and q >= 0:
                tmp_buffer.append(spatial_remainder[p])
                tmp_buffer.append(last_part[q])
                p -= 1
                q -= 1
            while p >= 0:
                tmp_buffer.append(spatial_remainder[p])
                p -= 1
            while q >= 0:
                tmp_buffer.append(last_part[q])
                q -= 1
            tmp_buffer = list(reversed(tmp_buffer))
            reorder_lst.extend(tmp_buffer)
        else:
            reorder_lst.extend(last_part)
            reorder_lst.extend(spatial_remainder)
        s[write_cache].reorder(*reorder_lst)
        # decide where to compute at
        if "share_pos" in config and len(config["share_pos"]) > 0:
            share_at = config["share_pos"][0][0]
            share_idx = config["share_pos"][0][1]
            reduced_nonfuse_lsts.append(last_part)
            share_pos = reduced_nonfuse_lsts[share_at][share_idx]
        else:
            if length == 1:
                share_pos = last_part[-1]
            else:
                mid = math.ceil(length / 2.0) - 1
                share_pos = reduced_nonfuse_lsts[mid][-1]
                # local_pos = last_part[-1]

    # always cache read here
    if share_pos is not None:
        for share in read_cache_share_lst:
            s[share].compute_at(s[write_cache], share_pos)
    else:
        for share in read_cache_share_lst:
            s[share].compute_inline()
    # if local_pos is not None:
    #     for local in read_cache_local_lst:
    #         s[local].compute_at(s[write_cache], local_pos)
    # else:
    #     for local in read_cache_local_lst:
    #         s[local].compute_inline()

    # always cooperative fetching
    if share_pos is not None:
        for share in read_cache_share_lst:
            fuse_lst = s[share].op.axis
            fused = s[share].fuse(*fuse_lst)
            count = 2
            cur = 1
            limit = 1024
            while count >= 0:
                factor = thread_extents[count]
                if factor < 0:
                    defined = False
                    factor = 16
                else:
                    defined = True
                cur *= factor
                if not defined and cur > limit:
                    break
                fused, inner = s[share].split(fused, factor=factor)
                s[share].bind(inner, threads[count])
                count -= 1

    # unroll
    if "unroll" in config and len(config["unroll"]) > 0:
        step = config["unroll"][0][0]
        explicit = config["unroll"][0][1]
        s[op].pragma(kernel_scope, 'auto_unroll_max_step', step)
        s[op].pragma(kernel_scope, 'unroll_explicit', explicit)
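# The tuning `config` consumed by these schedule functions is not defined in
# this section. A hypothetical shape, inferred from the key accesses above
# (values are illustrative only):
example_config = {
    "spatial": [[2, 1, 16, 4], [8, 2, 4, 1]],  # split parts per spatial axis
    "reduce": [[8, 4]],                        # split parts per reduce axis
    "fuse": [[1, 2]],                          # fuse boundaries over each part
    "reorder": [[1]],                          # interleave position for reduce/spatial
    "local_pos": [[1, 0]],                     # (part, axis) for the local cache compute_at
    "share_pos": [[0, 0]],                     # (part, axis) for the shared cache compute_at
    "unroll": [[1500, 1]],                     # (auto_unroll_max_step, unroll_explicit)
}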
def cpu_schedule_split_fuse(config, s, op, op_state):
    # always cache write here
    # if op.num_outputs > 1:
    #     raise RuntimeWarning("Too many outputs in one operation!")
    write_cache = s.cache_write(op.output(0), "global")

    # spatial split
    spatial_axes = s[op].op.axis
    splited_spatial_axes = []
    if "spatial" in config and len(config["spatial"]) > 0:
        # to align each axis
        assert_print(len(config["spatial"]) == len(spatial_axes), "align failed")
        for axis, nparts in zip(spatial_axes, config["spatial"]):
            tmp_buffer = []
            for count in range(len(nparts) - 1):
                outer, axis = s[op].split(axis, nparts=nparts[count])
                tmp_buffer.append(outer)
            tmp_buffer.append(axis)
            splited_spatial_axes.append(tmp_buffer)
    else:
        for axis in spatial_axes:
            splited_spatial_axes.append([axis])
    assert_print(len(splited_spatial_axes) > 0, "empty spatial axes")  # must be non-empty

    # always reorder and fuse here
    spatial_fuse_lsts = []
    spatial_fuse_extents = []
    reorder_lst = []
    fused_spatial_axes = []
    for count in range(len(splited_spatial_axes[0])):
        tmp_buffer = [x[count] for x in splited_spatial_axes]
        tmp_extent = reduce(lambda a, b: a * b, [x[count] for x in config["spatial"]])
        spatial_fuse_lsts.append(tmp_buffer)
        spatial_fuse_extents.append(tmp_extent)
        reorder_lst.extend(tmp_buffer)
    s[op].reorder(*reorder_lst)
    for fuse_lst in spatial_fuse_lsts:
        fused = s[op].fuse(*fuse_lst)
        fused_spatial_axes.append(fused)
    kernel_scope = fused_spatial_axes[0]

    # always parallel here
    length = len(fused_spatial_axes)
    assert_print(length > 0, "empty spatial axes!")
    s[op].parallel(fused_spatial_axes[0])
    if length == 1:
        thread_pos = fused_spatial_axes[0]
    elif 2 <= length <= 3:
        thread_pos = fused_spatial_axes[1]
    else:
        thread_pos = fused_spatial_axes[2]

    # always compute at here
    s[write_cache].compute_at(s[op], thread_pos)

    # reduce_split
    reduced_axes = s[write_cache].op.reduce_axis
    splited_reduced_axes = []
    if "reduce" in config and len(config["reduce"]) > 0:
        # to align each axis
        assert_print(len(config["reduce"]) == len(reduced_axes), "align reduce failed")
        for axis, nparts in zip(reduced_axes, config["reduce"]):
            tmp_buffer = []
            for count in range(len(nparts) - 1):
                outer, axis = s[write_cache].split(axis, nparts=nparts[count])
                tmp_buffer.append(outer)
            tmp_buffer.append(axis)
            splited_reduced_axes.append(tmp_buffer)
    else:
        for axis in reduced_axes:
            splited_reduced_axes.append([axis])

    # if has reduce axes
    if len(splited_reduced_axes) > 0:
        # always reorder here
        reduced_nonfuse_lsts = []
        reorder_lst = []
        length = len(splited_reduced_axes[0])
        for count in range(length):
            tmp_buffer = [x[count] for x in splited_reduced_axes]
            reduced_nonfuse_lsts.append(tmp_buffer)
            reorder_lst.extend(tmp_buffer)
        # change the order of reduce axes and spatial axes
        rlength = len(splited_reduced_axes)
        if rlength > 1:
            reorder_lst.extend(s[write_cache].op.axis)
        elif rlength == 1:
            # in this case, we have to interleave, otherwise the reorder has no effect
            tmp_order = []
            p_spatial = len(s[write_cache].op.axis) - 1
            p_reduce = len(reorder_lst) - 1
            while p_spatial >= 0 and p_reduce >= 0:
                tmp_order.append(s[write_cache].op.axis[p_spatial])
                tmp_order.append(reorder_lst[p_reduce])
                p_spatial -= 1
                p_reduce -= 1
            while p_spatial >= 0:
                tmp_order.append(s[write_cache].op.axis[p_spatial])
                p_spatial -= 1
            while p_reduce >= 0:
                tmp_order.append(reorder_lst[p_reduce])
                p_reduce -= 1
            tmp_order = list(reversed(tmp_order))
            reorder_lst = tmp_order
        s[write_cache].reorder(*reorder_lst)

    # unroll
    if "unroll" in config and len(config["unroll"]) > 0:
        step = config["unroll"][0][0]
        s[op].pragma(kernel_scope, 'auto_unroll_max_step', step)

    # always vectorize here
    s[write_cache].vectorize(s[write_cache].op.axis[-1])