def compute_depthwise_conv2d_NHWC_HWOI(
    Input, Filter, stride, padding, dilation, out_dtype=None, args=None
):
    """Depthwise convolution operator in NHWC layout with an HWOI filter.

    The input channels are repacked into blocks of 4 (NHWC4c). The reduction
    runs in ``args["accumulator"]`` precision, and the result is cast to
    ``out_dtype`` and unpacked back to NHWC.

    Parameters
    ----------
    Input : tvm.te.Tensor
        4-D input ``[batch, in_height, in_width, channels]``.
    Filter : tvm.te.Tensor
        4-D filter. Its first two axes are the kernel height and width, axis 2
        is indexed by channel, and axis 3 is only ever read at index 0.
    stride : int or pair of ints
    padding : int, str, tuple or list
        Forwarded to ``nn.get_pad_tuple``.
    dilation : int or pair of ints
    out_dtype : str, optional
        Output dtype; defaults to ``Input.dtype``.
    args : dict
        Must provide ``"accumulator"``, the dtype used for the reduction.

    Returns
    -------
    tvm.te.Tensor
        4-D output ``[batch, out_height, out_width, channels]``.
    """
    if out_dtype is None:
        out_dtype = Input.dtype
    if args is None:
        args = {}
    assert isinstance(stride, int) or len(stride) == 2
    assert isinstance(dilation, int) or len(dilation) == 2
    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    batch, in_height, in_width, channels = Input.shape
    kernel_h, kernel_w, _, _ = Filter.shape

    # compute the output shape
    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
    pad_top, pad_left, pad_down, pad_right = nn.get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w)
    )
    out_height_orig = out_height = simplify(
        (in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1
    )
    out_width_orig = out_width = simplify(
        (in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1
    )

    channel_block = 4
    # NOTE(review): assumes `channels` is a multiple of `channel_block` (TODO confirm).
    # A remainder is not packed here, but the final stage still reads it.
    channel_chunk = channels // channel_block
    num_filter_chunk = 1

    # --- pack input NHWC -> NHWC4c and filter -> [kh, kw, chunk, 1, block]
    input_pack = te.compute(
        [batch, in_height, in_width, channel_chunk, channel_block],
        lambda nn, yy, xx, icc, icb: Input[nn, yy, xx, icc * channel_block + icb],
        name="input_pack",
        tag="input_pack",
    )
    filter_pack = te.compute(
        [kernel_h, kernel_w, channel_chunk, num_filter_chunk, channel_block],
        lambda kh, kw, ifc, nfc, cb: Filter[kh, kw, ifc * channel_block + cb, nfc],
        name="filter_pack",
        tag="filter_pack",
    )

    # Round the computed spatial extents up to a multiple of 4 so the schedule can
    # split them evenly. The extra rows/columns are dropped by `dummy_cast`, which
    # only covers the original extents. The input is still padded just enough for
    # the original extents (below). Per the original author, padding more would
    # make intermediate buffers disagree in size with what the compute writes, and
    # out-of-bound reads are tolerated when the input is a texture.
    if out_height % 2 != 0:
        out_height += 1
    if out_width % 2 != 0:
        out_width += 1
    if out_height % 4 != 0:
        out_height += 2
    if out_width % 4 != 0:
        out_width += 2

    # compute graph
    # NHWC4c: axis 1 is height, axis 2 is width (axis 3 is the channel chunk).
    y_axis, x_axis = 1, 2
    pad_before = [0, pad_top, pad_left, 0, 0]
    pad_after = [0, pad_down, pad_right, 0, 0]
    # Trim the trailing padding so the padded input is exactly as large as the
    # original output extents need (this can also drop unused trailing input).
    input_latest_w = (out_width_orig - 1) * stride_w + (kernel_w - 1) * dilation_w + 1
    input_latest_h = (out_height_orig - 1) * stride_h + (kernel_h - 1) * dilation_h + 1
    padded_w = in_width + pad_before[x_axis] + pad_after[x_axis]
    if input_latest_w < padded_w:
        pad_after[x_axis] -= padded_w - input_latest_w
    padded_h = in_height + pad_before[y_axis] + pad_after[y_axis]
    if input_latest_h < padded_h:
        pad_after[y_axis] -= padded_h - input_latest_h

    temp = nn.pad(input_pack, pad_before, pad_after, name="pad_temp")

    ry = te.reduce_axis((0, kernel_h), name="ry")
    rx = te.reduce_axis((0, kernel_w), name="rx")
    conv = te.compute(
        (batch, out_height, out_width, channel_chunk, channel_block),
        lambda nn, yy, xx, ffc, ffb: te.sum(
            (
                temp[nn, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, ffc, ffb]
                * filter_pack[ry, rx, ffc, 0, ffb]
            ).astype(args["accumulator"]),
            axis=[ry, rx],
        ),
        tag="depthwise_conv2d_nhwc",
    )
    # Separate cast stage over the original extents only.
    dummy_cast = te.compute(
        (batch, out_height_orig, out_width_orig, channel_chunk, channel_block),
        lambda n, y, x, fc, fb: conv[n, y, x, fc, fb].astype(out_dtype),
        tag="dummy_cast",
    )
    # Unpack NHWC4c -> NHWC.
    return te.compute(
        (batch, out_height_orig, out_width_orig, channels),
        lambda n, y, x, c: dummy_cast[n, y, x, c // channel_block, c % channel_block],
        tag="cast_from_acc" + args["accumulator"][-2:],
    )
def compute_conv2d_gemm_without_weight_transform(
    cfg, data, B_interleaved_t, strides, padding, dilation, out_dtype, kernel_size, output_channels
):
    """Compute conv2d by transforming the input, executing GEMM and transforming the output back.

    Steps:

    1. The NHWC input is padded (if needed).
    2. The input is flattened with im2col into ``A[batches, OH * OW, IC * KH * KW]``.
    3. ``A`` is padded to a multiple of 4 rows and 16 columns and interleaved into
       4x16 tiles.
    4. The tiles are multiplied with the pre-transformed weights ``B_interleaved_t``.
    5. The result is unpacked into ``[batches, OH, OW, OC]``.

    Parameters
    ----------
    cfg : AutoTVM config object
        The knobs defined below are registered on it and given fallback values
        when ``cfg.is_fallback`` is set.
    data : tvm.te.Tensor
        NHWC input.
    B_interleaved_t : tvm.te.Tensor
        Pre-interleaved weights, indexed as ``[n_tile, k // 16, lane, k % 16]``
        with 4 lanes per tile in this compute.
    strides : int or pair of ints
    padding : forwarded to ``get_pad_tuple``.
    dilation : int or pair of ints
    out_dtype : str
        Dtype the GEMM operands are cast to and accumulated in.
    kernel_size : pair of ints
        ``(KH, KW)``.
    output_channels : int
        Number of output channels ``OC``.

    Returns
    -------
    tvm.te.Tensor
        ``[batches, OH, OW, OC]``.
    """
    batches, IH, IW, IC = get_const_tuple(data.shape)
    KH, KW = get_const_tuple(kernel_size)
    OC = get_const_int(output_channels)
    K_AREA = KH * KW

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = get_const_tuple(dilation)

    dilated_kernel_h = (KH - 1) * dilation_h + 1
    dilated_kernel_w = (KW - 1) * dilation_w + 1

    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w)
    )
    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)

    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
    # OH/OW include padding on all four sides, so the input must be padded when
    # any side is non-zero. Padding only on top/left would let the im2col below
    # read past the bottom/right edge of `data`.
    if pad_top or pad_left or pad_down or pad_right:
        data_pad = nn.pad(
            data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0], name="data_pad"
        )
    else:
        data_pad = data

    # --- Im2col
    M = OH * OW
    K = IC * K_AREA
    N = OC

    A_shape = (batches, M, K)
    if K_AREA == 1:
        A = te.compute(
            A_shape,
            lambda n, x, y: data_pad[n, HSTR * (x // OW), WSTR * (x % OW), y],
            name="data_flatten",
        )
    else:
        A = te.compute(
            A_shape,
            lambda n, x, y: data_pad[
                n,
                HSTR * (x // OW) + dilation_h * ((y // IC) // KW),
                WSTR * (x % OW) + dilation_w * ((y // IC) % KW),
                y % IC,
            ],
            name="data_im2col",
        )
    N_transformed = B_interleaved_t.shape[0]

    # --- Pad A so that M is a multiple of 4 and K a multiple of 16 (the tile sizes)
    idxm = tvm.tir.indexmod
    pad_m = 0
    pad_k = 0
    if M % 4 != 0:
        pad_m = 4 - (M % 4)
    if K % 16 != 0:
        pad_k = 16 - (K % 16)
    M_padded = M + pad_m
    K_padded = K + pad_k

    pad_before = (0, 0, 0)
    pad_after = (0, pad_m, pad_k)
    if pad_m != 0 or pad_k != 0:
        A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded")

    # --- GEMM: A*B'
    k = te.reduce_axis((0, K_padded), "k")
    A_interleaved = te.compute(
        (batches, M_padded // 4, K_padded // 16, 4, 16),
        lambda b, x, y, z, w: A[b, z + 4 * x, w + 16 * y],
        name="A_interleaved",
    )
    C_interleaved = te.compute(
        (batches, M_padded // 4, N_transformed, 4, 4),
        lambda b, x, y, w, z: te.sum(
            A_interleaved[b, x, k // 16, w, idxm(k, 16)].astype(out_dtype)
            * B_interleaved_t[y, k // 16, z, idxm(k, 16)].astype(out_dtype),
            axis=k,
        ),
        name="C_interleaved",
    )

    # --- Unpack C
    C = te.compute(
        (batches, M, N),
        lambda b, x, y: C_interleaved[b, x // 4, y // 4, idxm(x, 4), idxm(y, 4)],
        name="C",
    )

    # --- Produce the conv output
    out_shape = (batches, OH, OW, OC)
    out = te.compute(out_shape, lambda b, x, y, z: C(b, y + OW * x, z), name="conv2d_gemm_output")

    # Configuration space
    x, y = cfg.axis(M_padded // 4), cfg.axis(K_padded // 16)
    cfg.define_reorder("reorder_gemm", [x, y], policy="candidate", candidate=[[x, y], [y, x]])

    outer_loop, inner_loop = cfg.axis(4), cfg.axis(16)
    cfg.define_annotate(
        "A_interleaved_unroll_vec", [outer_loop, inner_loop], policy="try_unroll_vec"
    )
    cfg.define_knob("gemm_quantized_unroll", [True, False])
    cfg.define_knob("gemm_quantized_interleave", [True, False])

    # Fallback configuration
    if cfg.is_fallback:
        cfg["reorder_gemm"] = ReorderEntity([0, 1])
        cfg["A_interleaved_unroll_vec"] = AnnotateEntity(["unroll", "vec"])
        cfg["gemm_quantized_unroll"] = OtherOptionEntity(False)
        cfg["gemm_quantized_interleave"] = OtherOptionEntity(True)

    return out
def compute_depthwise_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, dilation, out_dtype=None, args={}):
    """Depthwise convolution operator in NCHWc layout.

    Output channel chunk ``ffc`` reads input chunk ``ffc // channel_multiplier``
    and filter multiplier ``ffc % channel_multiplier``. The reduction runs in
    ``args["accumulator"]`` precision and the result is cast to ``out_dtype``.

    Parameters
    ----------
    Input : tvm.te.Tensor
        ``[batch, channel_chunk, in_height, in_width, channel_block]``.
    Filter : tvm.te.Tensor
        5-D filter ``[_, channel_multiplier, kernel_h, kernel_w, _]``. Axis 0 is
        indexed by the input channel chunk and axis 4 by the channel lane.
    stride, dilation : int or pair of ints
    padding : forwarded to ``nn.get_pad_tuple``.
    out_dtype : str, optional
        Output dtype; defaults to ``Input.dtype``.
    args : dict
        Must provide ``"accumulator"``, the reduction dtype.

    Returns
    -------
    tvm.te.Tensor
        ``[batch, channel_chunk * channel_multiplier, out_height, out_width, channel_block]``.
    """
    if out_dtype is None:
        out_dtype = Input.dtype
    assert isinstance(stride, int) or len(stride) == 2
    assert isinstance(dilation, int) or len(dilation) == 2
    stride_h, stride_w = (stride, stride) if isinstance(stride, int) else stride
    dilation_h, dilation_w = (dilation, dilation) if isinstance(dilation, int) else dilation

    batch, channel_chunk, in_height, in_width, channel_block = Input.shape
    _, channel_multiplier, kernel_h, kernel_w, _ = Filter.shape

    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
    pad_top, pad_left, pad_down, pad_right = nn.get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w)
    )

    out_channel_chunk = simplify(channel_chunk * channel_multiplier)
    out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
    out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)

    # Only the spatial axes (2 = height, 3 = width) receive padding.
    temp = nn.pad(
        Input, [0, 0, pad_top, pad_left, 0], [0, 0, pad_down, pad_right, 0], name="pad_temp"
    )
    ry = te.reduce_axis((0, kernel_h), name="ry")
    rx = te.reduce_axis((0, kernel_w), name="rx")

    if autotvm.GLOBAL_SCOPE.in_tuning:
        # While tuning, the filter is first repacked to
        # [channel_chunk, multiplier * kh * kw, channel_block]. The original notes
        # call this the texture layout C|MRS|c, next to the input texture NCH|W|c.
        Filter_tx = te.compute(
            (channel_chunk, channel_multiplier * kernel_h * kernel_w, channel_block),
            lambda ffc, mrs, ffb: Filter[
                ffc,
                mrs // (kernel_h * kernel_w),
                (mrs // kernel_w) % kernel_h,
                mrs % kernel_w,
                ffb,
            ],
            name="packed_filter",
        )

        def _weight(ffc, ffb):
            # Flattened (multiplier, ry, rx) position inside the packed filter.
            flat_mrs = ((ffc % channel_multiplier) * kernel_h + ry) * kernel_w + rx
            return Filter_tx[ffc // channel_multiplier, flat_mrs, ffb]

        conv_tag = "depthwise_conv2d_nchwc_kcrsk_texture"
    else:

        def _weight(ffc, ffb):
            return Filter[ffc // channel_multiplier, ffc % channel_multiplier, ry, rx, ffb]

        conv_tag = "depthwise_conv2d_nchwc_kcrsk"

    conv = te.compute(
        (batch, out_channel_chunk, out_height, out_width, channel_block),
        lambda nn, ffc, yy, xx, ffb: te.sum(
            (
                temp[
                    nn,
                    ffc // channel_multiplier,
                    yy * stride_h + ry * dilation_h,
                    xx * stride_w + rx * dilation_w,
                    ffb,
                ]
                * _weight(ffc, ffb)
            ).astype(args["accumulator"]),
            axis=[ry, rx],
        ),
        tag=conv_tag,
    )

    # Final stage: cast the accumulator-typed result to the requested dtype.
    return te.compute(
        conv.shape,
        lambda n, ffc, y, x, ffb: conv[n, ffc, y, x, ffb].astype(out_dtype),
        tag="cast_from_acc" + args["accumulator"][-2:],
    )
def compute_conv2d_gemm_without_weight_transform(
    cfg, data, B_interleaved_t, strides, padding, dilation, out_dtype, kernel_size, output_channels
):
    """Compute conv2d by transforming the input, executing GEMM and transforming the output back.

    Steps:

    1. The NHWC input is padded (if needed).
    2. The input is flattened with im2col into ``A[batches, OH * OW, IC * KH * KW]``.
    3. ``A`` is padded to a multiple of 4 rows and 16 columns and interleaved into
       4x16 tiles.
    4. The tiles are multiplied with ``B_interleaved_t``.
    5. The result is unpacked into ``[batches, OH, OW, OC]``.

    Parameters
    ----------
    cfg : AutoTVM config object
        Not used by this compute.
    data : tvm.te.Tensor
        NHWC input.
    B_interleaved_t : tvm.te.Tensor
        Pre-interleaved weights, indexed as ``[n_tile, k // 16, lane, k % 16]``
        with 4 lanes per tile in this compute.
    strides : int or pair of ints
    padding : forwarded to ``get_pad_tuple``.
    dilation : int or pair of ints
    out_dtype : str
        Dtype the GEMM operands are cast to and accumulated in.
    kernel_size : pair
        ``(KH, KW)``.
    output_channels : int
        Number of output channels ``OC``.

    Returns
    -------
    tvm.te.Tensor
        ``[batches, OH, OW, OC]``.
    """
    batches, IH, IW, IC = get_const_tuple(data.shape)
    KH, KW = kernel_size
    OC = output_channels

    K_AREA = KH * KW

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    dilated_kernel_h = (KH - 1) * dilation_h + 1
    dilated_kernel_w = (KW - 1) * dilation_w + 1

    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w)
    )
    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)

    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
    # OH/OW include padding on all four sides, so the input must be padded when
    # any side is non-zero. Padding only on top/left would let the im2col below
    # read past the bottom/right edge of `data`.
    if pad_top or pad_left or pad_down or pad_right:
        data_pad = nn.pad(
            data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0], name="data_pad"
        )
    else:
        data_pad = data

    # --- Im2col
    M = OH * OW
    K = IC * K_AREA
    N = OC

    A_shape = (batches, M, K)
    if K_AREA == 1:
        A = te.compute(
            A_shape,
            lambda n, x, y: data_pad[n, HSTR * (x // OW), WSTR * (x % OW), y],
            name="data_flatten",
        )
    else:
        A = te.compute(
            A_shape,
            lambda n, x, y: data_pad[
                n,
                HSTR * (x // OW) + dilation_h * ((y // IC) // KW),
                WSTR * (x % OW) + dilation_w * ((y // IC) % KW),
                y % IC,
            ],
            name="data_im2col",
        )
    N_transformed = B_interleaved_t.shape[0]

    # --- Pad A so that M is a multiple of 4 and K a multiple of 16 (the tile sizes)
    idxm = tvm.tir.indexmod
    pad_m = 0
    pad_k = 0
    if M % 4 != 0:
        pad_m = 4 - (M % 4)
    if K % 16 != 0:
        pad_k = 16 - (K % 16)
    M_padded = M + pad_m
    K_padded = K + pad_k

    pad_before = (0, 0, 0)
    pad_after = (0, pad_m, pad_k)
    if pad_m != 0 or pad_k != 0:
        A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded")

    # --- GEMM: A*B'
    k = te.reduce_axis((0, K_padded), "k")
    A_interleaved = te.compute(
        (batches, M_padded // 4, K_padded // 16, 4, 16),
        lambda b, x, y, z, w: A[b, z + 4 * x, w + 16 * y],
        name="A_interleaved",
    )
    C_interleaved = te.compute(
        (batches, M_padded // 4, N_transformed, 4, 4),
        lambda b, x, y, w, z: te.sum(
            A_interleaved[b, x, k // 16, w, idxm(k, 16)].astype(out_dtype)
            * B_interleaved_t[y, k // 16, z, idxm(k, 16)].astype(out_dtype),
            axis=k,
        ),
        name="C_interleaved",
    )

    # --- Unpack C
    C = te.compute(
        (batches, M, N),
        lambda b, x, y: C_interleaved[b, x // 4, y // 4, idxm(x, 4), idxm(y, 4)],
        name="C",
    )

    # --- Produce the conv output
    out_shape = (batches, OH, OW, OC)
    out = te.compute(out_shape, lambda b, x, y, z: C(b, y + OW * x, z), name="conv2d_gemm_output")
    return out
def compute_conv2d_NHWC_HWIO(Input, Filter, stride, padding, dilation, out_dtype=None, args=None):
    """Convolution operator in NHWC layout with an HWIO filter.

    Input channels and filter output channels are repacked into blocks of 4.
    When a channel count is not a multiple of 4, the last block is zero-filled.
    The reduction runs in ``args["accumulator"]`` precision, and the result is
    cast to ``out_dtype`` and unpacked back to NHWC.

    Parameters
    ----------
    Input : tvm.te.Tensor
        4-D input ``[batch, in_height, in_width, in_channel]``.
    Filter : tvm.te.Tensor
        4-D filter ``[kernel_h, kernel_w, in_channel, out_channels]``.
    stride : int or pair of ints
    padding : int, str, tuple or list
        Forwarded to ``nn.get_pad_tuple``.
    dilation : int or pair of ints
    out_dtype : str, optional
        Output dtype; defaults to ``Input.dtype``.
    args : dict
        Must provide ``"accumulator"``, the dtype used for the reduction.

    Returns
    -------
    tvm.te.Tensor
        4-D output ``[batch, out_height, out_width, out_channels]``.
    """
    if out_dtype is None:
        out_dtype = Input.dtype
    if args is None:
        args = {}
    assert isinstance(stride, int) or len(stride) == 2
    assert isinstance(dilation, int) or len(dilation) == 2
    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    batch, in_height, in_width, in_channel = Input.shape
    kernel_h, kernel_w, _, out_channels = Filter.shape

    # compute the output shape
    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
    pad_top, pad_left, pad_down, pad_right = nn.get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w)
    )
    out_height_orig = out_height = simplify(
        (in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1
    )
    out_width_orig = out_width = simplify(
        (in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1
    )

    in_channel_block = 4
    in_channel_tail = in_channel % in_channel_block
    in_channel_chunk = in_channel // in_channel_block
    num_filter_block = 4
    num_filter_tail = out_channels % num_filter_block
    num_filter_chunk = out_channels // num_filter_block
    # A partial last block becomes one extra, zero-filled chunk. Each axis is
    # rounded up independently, only when that axis actually has a remainder.
    if in_channel_tail != 0:
        in_channel_chunk += 1
    if num_filter_tail != 0:
        num_filter_chunk += 1
    # Input-channel extent the reduction below iterates over (>= in_channel).
    packed_in_channel = in_channel_chunk * in_channel_block

    input_pad_value = tvm.tir.const(0, Input.dtype)
    filter_pad_value = tvm.tir.const(0, Filter.dtype)

    # --- pack input: NHWC -> [batch, h, w, in_channel_chunk, in_channel_block]
    if in_channel_tail == 0:
        input_pack = te.compute(
            [batch, in_height, in_width, in_channel_chunk, in_channel_block],
            lambda nn, yy, xx, icc, icb: Input[nn, yy, xx, icc * in_channel_block + icb],
            name="input_pack",
            tag="input_pack",
        )
    else:

        def _reorder_data(*indices):
            # Lanes past the real channel count (only in the last chunk) become zero.
            channel = indices[3] * in_channel_block + indices[4]
            return tvm.tir.if_then_else(
                channel >= in_channel,
                input_pad_value,
                Input[indices[0], indices[1], indices[2], channel],
            )

        input_pack = te.compute(
            [batch, in_height, in_width, in_channel_chunk, in_channel_block],
            _reorder_data,
            name="input_pack",
            tag="input_pack_expanded",
        )

    # --- pack filter: HWIO -> [kh, kw, packed_in_channel, num_filter_chunk, num_filter_block]
    if num_filter_tail == 0 and in_channel_tail == 0:
        filter_pack = te.compute(
            [kernel_h, kernel_w, in_channel, num_filter_chunk, num_filter_block],
            lambda kh, kw, ic, nfc, nfb: Filter[kh, kw, ic, nfc * num_filter_block + nfb],
            name="filter_pack",
            tag="filter_pack",
        )
    else:

        def _reorder_weights(*indices):
            # Zero-fill input channels past `in_channel` (matching the zero lanes
            # of the packed input) and output lanes past `out_channels`. This keeps
            # every read of `Filter` in bounds.
            in_ch = indices[2]
            out_ch = indices[3] * num_filter_block + indices[4]
            out_of_range = tvm.tir.any(in_ch >= in_channel, out_ch >= out_channels)
            return tvm.tir.if_then_else(
                out_of_range,
                filter_pad_value,
                Filter[indices[0], indices[1], in_ch, out_ch],
            )

        filter_pack = te.compute(
            [kernel_h, kernel_w, packed_in_channel, num_filter_chunk, num_filter_block],
            _reorder_weights,
            name="filter_pack",
            tag="filter_pack_expanded",
        )

    # Round the computed spatial extents up to a multiple of 4 so the schedule can
    # split them evenly. The extra rows/columns are dropped by `dummy_cast`, which
    # only covers the original extents. The input is still padded just enough for
    # the original extents (below). Per the original author, padding more would
    # make intermediate buffers disagree in size with what the compute writes, and
    # out-of-bound reads are tolerated when the input is a texture.
    if out_height % 2 != 0:
        out_height += 1
    if out_width % 2 != 0:
        out_width += 1
    if out_height % 4 != 0:
        out_height += 2
    if out_width % 4 != 0:
        out_width += 2

    # compute graph
    # Packed NHWC4c: axis 1 is height, axis 2 is width (axis 3 is the channel chunk).
    y_axis, x_axis = 1, 2
    pad_before = [0, pad_top, pad_left, 0, 0]
    pad_after = [0, pad_down, pad_right, 0, 0]
    # Trim the trailing padding so the padded input is exactly as large as the
    # original output extents need (this can also drop unused trailing input).
    input_latest_w = (out_width_orig - 1) * stride_w + (kernel_w - 1) * dilation_w + 1
    input_latest_h = (out_height_orig - 1) * stride_h + (kernel_h - 1) * dilation_h + 1
    padded_w = in_width + pad_before[x_axis] + pad_after[x_axis]
    if input_latest_w < padded_w:
        pad_after[x_axis] -= padded_w - input_latest_w
    padded_h = in_height + pad_before[y_axis] + pad_after[y_axis]
    if input_latest_h < padded_h:
        pad_after[y_axis] -= padded_h - input_latest_h

    temp = nn.pad(input_pack, pad_before, pad_after, name="pad_temp")

    rcc = te.reduce_axis((0, in_channel_chunk), name="rcc")
    rcb = te.reduce_axis((0, in_channel_block), name="rcb")
    ry = te.reduce_axis((0, kernel_h), name="ry")
    rx = te.reduce_axis((0, kernel_w), name="rx")
    conv = te.compute(
        (batch, out_height, out_width, num_filter_chunk, num_filter_block),
        lambda nn, yy, xx, fc, fb: te.sum(
            (
                temp[nn, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rcc, rcb]
                * filter_pack[ry, rx, rcc * in_channel_block + rcb, fc, fb]
            ).astype(args["accumulator"]),
            axis=[ry, rx, rcc, rcb],
        ),
        tag="conv2d_nhwc",
    )
    # Separate cast stage over the original extents only.
    dummy_cast = te.compute(
        (batch, out_height_orig, out_width_orig, num_filter_chunk, num_filter_block),
        lambda n, y, x, fc, fb: conv[n, y, x, fc, fb].astype(out_dtype),
        tag="dummy_cast",
    )
    # Unpack the filter blocks back to NHWC.
    return te.compute(
        (batch, out_height_orig, out_width_orig, out_channels),
        lambda n, y, x, c: dummy_cast[n, y, x, c // num_filter_block, c % num_filter_block],
        tag="cast_from_acc" + args["accumulator"][-2:],
    )
def compute_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, dilation, out_dtype=None, args={}):
    """Convolution operator in NCHWc layout.

    Convolves a 5-D NCHWc input with a 5-D KCRSk filter. The reduction runs in
    ``args["accumulator"]`` precision and the result is cast to ``out_dtype``.

    Parameters
    ----------
    Input : tvm.te.Tensor
        ``[batch, in_channel_chunk, in_height, in_width, in_channel_block]``.
    Filter : tvm.te.Tensor
        ``[num_filter_chunk, channel, kernel_h, kernel_w, num_filter_block]``.
        Axis 1 is indexed by the flattened input channel
        ``chunk * in_channel_block + lane``.
    stride, dilation : int or pair of ints
    padding : forwarded to ``nn.get_pad_tuple``.
    out_dtype : str, optional
        Output dtype; defaults to ``Input.dtype``.
    args : dict
        Must provide ``"accumulator"``, the reduction dtype.

    Returns
    -------
    tvm.te.Tensor
        ``[batch, num_filter_chunk, out_height, out_width, num_filter_block]``.
    """
    if out_dtype is None:
        out_dtype = Input.dtype
    assert isinstance(stride, int) or len(stride) == 2
    assert isinstance(dilation, int) or len(dilation) == 2
    stride_h, stride_w = (stride, stride) if isinstance(stride, int) else stride
    dilation_h, dilation_w = (dilation, dilation) if isinstance(dilation, int) else dilation

    batch, in_channel_chunk, in_height, in_width, in_channel_block = Input.shape
    num_filter_chunk, channel, kernel_h, kernel_w, num_filter_block = Filter.shape

    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
    pad_top, pad_left, pad_down, pad_right = nn.get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w)
    )
    out_height_orig = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
    out_width_orig = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)

    def _extend_for_split(extent):
        """Make `extent` divisible by 2, then bump it by 2 if it is still not divisible by 4."""
        if extent % 2 != 0:
            extent += 1
        if extent % 4 != 0:
            extent += 2
        return extent

    # The conv is computed over extents rounded up to a multiple of 4 so the
    # schedule can split them evenly. The final stage keeps only the original
    # extents. The input is padded only as far as the original extents need
    # (below). Per the original author, extra padding would make intermediate
    # buffers disagree in size with what the compute writes, while out-of-bound
    # texture reads are harmless.
    out_height = _extend_for_split(out_height_orig)
    out_width = _extend_for_split(out_width_orig)

    pad_before = [0, 0, pad_top, pad_left, 0]
    pad_after = [0, 0, pad_down, pad_right, 0]
    input_latest_w = (out_width_orig - 1) * stride_w + (kernel_w - 1) * dilation_w + 1
    input_latest_h = (out_height_orig - 1) * stride_h + (kernel_h - 1) * dilation_h + 1
    # Shrink the trailing padding of width (axis 3) and then height (axis 2) down
    # to the extent actually read for the original output size.
    for axis, in_extent, latest in ((3, in_width, input_latest_w), (2, in_height, input_latest_h)):
        padded_extent = in_extent + pad_before[axis] + pad_after[axis]
        if latest < padded_extent:
            pad_after[axis] -= padded_extent - latest

    temp = nn.pad(Input, pad_before, pad_after, name="pad_temp")

    rcc = te.reduce_axis((0, in_channel_chunk), name="rc")
    rcb = te.reduce_axis((0, in_channel_block), name="rc")
    ry = te.reduce_axis((0, kernel_h), name="ry")
    rx = te.reduce_axis((0, kernel_w), name="rx")

    # While tuning, a separate filter repack stage is added so that kernels reading
    # texture inputs can be measured realistically. Per the original notes, the
    # graph runtime can pass textures directly, so this stage is only needed until
    # AutoTVM task extraction captures that runtime information, or until
    # tir.TextureFlatten can cancel cache_read when padding is used.
    if autotvm.GLOBAL_SCOPE.in_tuning:
        # Filter repacked to K|CRS|k: [num_filter_chunk, channel * kh * kw, num_filter_block].
        Filter_tx = te.compute(
            (num_filter_chunk, channel * kernel_h * kernel_w, num_filter_block),
            lambda ffc, crs, ffb: Filter[
                ffc,
                crs // (kernel_h * kernel_w),
                (crs // kernel_w) % kernel_h,
                crs % kernel_w,
                ffb,
            ],
            name="packed_filter",
        )

        def _weight(ffc, ffb):
            # Flattened (input channel, ry, rx) index into the repacked filter.
            flat_crs = ((rcc * in_channel_block + rcb) * kernel_h + ry) * kernel_w + rx
            return Filter_tx[ffc, flat_crs, ffb]

    else:

        def _weight(ffc, ffb):
            return Filter[ffc, rcc * in_channel_block + rcb, ry, rx, ffb]

    conv = te.compute(
        (batch, num_filter_chunk, out_height, out_width, num_filter_block),
        lambda nn, ffc, yy, xx, ffb: te.sum(
            (
                temp[nn, rcc, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rcb]
                * _weight(ffc, ffb)
            ).astype(args["accumulator"]),
            axis=[rcc, rcb, ry, rx],
        ),
        tag="conv2d_nchwc",
    )

    # Cast to the output dtype while cropping back to the original spatial extents.
    return te.compute(
        (batch, num_filter_chunk, out_height_orig, out_width_orig, num_filter_block),
        lambda n, fc, y, x, fb: conv[n, fc, y, x, fb].astype(out_dtype),
        tag="cast_from_acc" + args["accumulator"][-2:],
    )
def add_pad(
    data,
    layout,
    out_height,
    out_width,
    kernel_h,
    kernel_w,
    dilation_h,
    dilation_w,
    padding,
    stride_h,
    stride_w,
):
    """Compute the conv2d padding for `data` and add a compute stage that pads it.

    The trailing (bottom/right) padding is reduced so that the padded tensor is
    only as large as needed to produce an ``out_height x out_width`` output. The
    reduced value may go negative, presumably cropping input rows/columns no
    output reads (this relies on ``nn.pad`` accepting a negative ``pad_after`` —
    confirm).

    Parameters
    ----------
    data: tvm.te.Tensor
        4-D or 5-D tensor whose height/width axes are located according to
        `layout`. Axes other than height and width are not padded.
    layout: string
        Layout of the original 4-D tensor: "NCHW" or "NHWC".
    out_height: int
        Height of the output feature map.
    out_width: int
        Width of the output feature map.
    kernel_h: int
        Height of the conv2d kernel.
    kernel_w: int
        Width of the conv2d kernel.
    dilation_h: int
        Height dilation value from conv2d attributes.
    dilation_w: int
        Width dilation value from conv2d attributes.
    padding: list / tuple of n ints
        Padding values from conv2d attributes.
    stride_h: int
        Height stride value from conv2d attributes.
    stride_w: int
        Width stride value from conv2d attributes.

    Returns
    -------
    Output : tvm.te.Tensor
        Tensor with the same rank and layout as `data`.

    Raises
    ------
    ValueError
        If `layout` is neither "NCHW" nor "NHWC".
    """
    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
    pad_top, pad_left, pad_down, pad_right = nn.get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w)
    )

    # Locate the spatial axes; a trailing block axis (5-D case) does not move them.
    if layout == "NCHW":
        y_axis, x_axis = 2, 3
    elif layout == "NHWC":
        y_axis, x_axis = 1, 2
    else:
        raise ValueError("not supported layout in adreno util add_pad: " + str(layout))

    ndim = len(data.shape)
    in_height = data.shape[y_axis]
    in_width = data.shape[x_axis]

    # nn.pad needs exactly one pad entry per dimension of `data` (4-D or 5-D).
    pad_before = [0] * ndim
    pad_after = [0] * ndim
    pad_before[y_axis] = pad_top
    pad_before[x_axis] = pad_left
    pad_after[y_axis] = pad_down
    pad_after[x_axis] = pad_right

    # Extent of the padded input actually read for the requested output size.
    input_latest_w = (out_width - 1) * stride_w + (kernel_w - 1) * dilation_w + 1
    input_latest_h = (out_height - 1) * stride_h + (kernel_h - 1) * dilation_h + 1
    padded_w = in_width + pad_before[x_axis] + pad_after[x_axis]
    if input_latest_w < padded_w:
        pad_after[x_axis] -= padded_w - input_latest_w
    padded_h = in_height + pad_before[y_axis] + pad_after[y_axis]
    if input_latest_h < padded_h:
        pad_after[y_axis] -= padded_h - input_latest_h
    return nn.pad(data, pad_before, pad_after, name="pad_temp")
def compute_conv2d_NCHWc_tpack(Input, Filter, stride, padding, dilation, out_dtype=None, args=None):
    """Convolution of an NCHW input with an OIHW filter, computed in 4-channel blocks.

    Input channels and filter output channels are packed into blocks of 4. When
    a channel count is not a multiple of 4, the last block is zero-filled. The
    reduction runs in ``args["accumulator"]`` precision, and the result is cast
    to ``out_dtype`` and unpacked back to NCHW.

    Parameters
    ----------
    Input : tvm.te.Tensor
        4-D input ``[batch, in_channels, in_height, in_width]``.
    Filter : tvm.te.Tensor
        4-D filter ``[out_channels, in_channels, kernel_h, kernel_w]``.
    stride : int or pair of ints
    padding : int, str, tuple or list
        Forwarded to ``nn.get_pad_tuple``.
    dilation : int or pair of ints
    out_dtype : str, optional
        Output dtype; defaults to ``Input.dtype``.
    args : dict
        Must provide ``"accumulator"``, the dtype used for the reduction.

    Returns
    -------
    tvm.te.Tensor
        4-D output ``[batch, out_channels, out_height, out_width]``.
    """
    if out_dtype is None:
        out_dtype = Input.dtype
    if args is None:
        args = {}
    assert isinstance(stride, int) or len(stride) == 2
    assert isinstance(dilation, int) or len(dilation) == 2
    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    batch, in_channels, in_height, in_width = Input.shape
    out_channels, _, kernel_h, kernel_w = Filter.shape
    channel_block = 4

    # Split channels into chunks of `channel_block`. `*_tail` is the number of
    # valid lanes in the last chunk (a full block when the split is exact).
    in_channel_tail = in_channels % channel_block
    in_channel_chunk = in_channels // channel_block
    if in_channel_tail == 0:
        in_channel_tail = channel_block
    else:
        in_channel_chunk += 1

    num_filter_tail = out_channels % channel_block
    num_filter_chunk = out_channels // channel_block
    if num_filter_tail == 0:
        num_filter_tail = channel_block
    else:
        num_filter_chunk += 1

    data_pad_value = tvm.tir.const(0, Input.dtype)
    filter_pad_value = tvm.tir.const(0, Filter.dtype)

    def _reorder_data(*indices):
        # Zero-fill the lanes of the last input chunk past the real channels.
        condition = tvm.tir.all(
            indices[1] == in_channel_chunk - 1, indices[4] >= in_channel_tail
        )
        return tvm.tir.if_then_else(
            condition,
            data_pad_value,
            Input[indices[0], indices[1] * channel_block + indices[4], indices[2], indices[3]],
        )

    # NCHW -> [batch, in_channel_chunk, h, w, 4]
    reordered_data = te.compute(
        [batch, in_channel_chunk, in_height, in_width, channel_block],
        _reorder_data,
        name="input_pack",
        tag="input_pack",
    )

    def _reorder_weights(*indices):
        # Zero-fill the lanes of the last output chunk past the real filters.
        out_tail = tvm.tir.all(
            indices[0] == num_filter_chunk - 1, indices[4] >= num_filter_tail
        )
        # The packed filter has in_channel_chunk * 4 input channels; those at or
        # past `in_channels` do not exist in `Filter` and must read as zero.
        condition = tvm.tir.any(out_tail, indices[1] >= in_channels)
        return tvm.tir.if_then_else(
            condition,
            filter_pad_value,
            Filter[indices[0] * channel_block + indices[4], indices[1], indices[2], indices[3]],
        )

    # OIHW -> [num_filter_chunk, in_channel_chunk * 4, kh, kw, 4]
    reordered_filter = te.compute(
        [num_filter_chunk, in_channel_chunk * channel_block, kernel_h, kernel_w, channel_block],
        _reorder_weights,
        name="filter_pack",
        tag="filter_pack",
    )

    # compute the output shape
    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
    pad_top, pad_left, pad_down, pad_right = nn.get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w)
    )
    out_height_orig = out_height = simplify(
        (in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1
    )
    out_width_orig = out_width = simplify(
        (in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1
    )

    # Round the computed spatial extents up to a multiple of 4 so the schedule can
    # split them evenly. The extra rows/columns are dropped by `dummy_cast`, which
    # only covers the original extents. The input is still padded just enough for
    # the original extents (below). Per the original author, padding more would
    # make intermediate buffers disagree in size with what the compute writes, and
    # out-of-bound reads are tolerated when the input is a texture.
    if out_height % 2 != 0:
        out_height += 1
    if out_width % 2 != 0:
        out_width += 1
    if out_height % 4 != 0:
        out_height += 2
    if out_width % 4 != 0:
        out_width += 2

    # compute graph (packed layout: axis 2 is height, axis 3 is width)
    pad_before = [0, 0, pad_top, pad_left, 0]
    pad_after = [0, 0, pad_down, pad_right, 0]
    # Trim the trailing padding so the padded input is exactly as large as the
    # original output extents need.
    input_latest_w = (out_width_orig - 1) * stride_w + (kernel_w - 1) * dilation_w + 1
    input_latest_h = (out_height_orig - 1) * stride_h + (kernel_h - 1) * dilation_h + 1
    if input_latest_w < in_width + pad_before[3] + pad_after[3]:
        pad_after[3] -= in_width + pad_before[3] + pad_after[3] - input_latest_w
    if input_latest_h < in_height + pad_before[2] + pad_after[2]:
        pad_after[2] -= in_height + pad_before[2] + pad_after[2] - input_latest_h

    temp = nn.pad(reordered_data, pad_before, pad_after, name="pad_temp")

    rcc = te.reduce_axis((0, in_channel_chunk), name="rcc")
    rcb = te.reduce_axis((0, channel_block), name="rcb")
    ry = te.reduce_axis((0, kernel_h), name="ry")
    rx = te.reduce_axis((0, kernel_w), name="rx")

    conv = te.compute(
        (batch, num_filter_chunk, out_height, out_width, channel_block),
        lambda nn, ffc, yy, xx, ffb: te.sum(
            (
                temp[nn, rcc, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rcb]
                * reordered_filter[ffc, rcc * channel_block + rcb, ry, rx, ffb]
            ).astype(args["accumulator"]),
            axis=[rcc, rcb, ry, rx],
        ),
        tag="conv2d_nchwc_tpack",
    )

    # Per the original author, `s.cache_write(conv, "local")` did not create an
    # intermediate buffer. It kept accumulating in the global tensor and crashed
    # at runtime. This separate cast stage (combined with compute_at in the
    # schedule) is the workaround that gives a local-scope accumulator.
    dummy_cast = te.compute(
        (batch, num_filter_chunk, out_height_orig, out_width_orig, channel_block),
        lambda n, fc, y, x, fb: conv[n, fc, y, x, fb].astype(out_dtype),
        tag="dummy_cast",
    )
    # Unpack [batch, chunk, h, w, 4] back to NCHW.
    return te.compute(
        (batch, out_channels, out_height_orig, out_width_orig),
        lambda n, c, y, x: dummy_cast[n, c // channel_block, y, x, c % channel_block],
        tag="cast_from_acc" + args["accumulator"][-2:],
    )
def compute_conv2d_gemm_without_weight_transform(
    cfg,
    data,
    B_interleaved_t,
    strides,
    padding,
    dilation,
    out_dtype,
    kernel_size,
    output_channels,
    interleave_A,
):
    """Compute conv2d by transforming the input, executing GEMM and transforming the output back.

    The NHWC input is padded, unrolled (im2col) into a matrix ``A`` of shape
    (batches, OH * OW, IC * KH * KW), padded up to the A tile sizes, multiplied
    against the already-transformed weights ``B_interleaved_t`` and finally
    reshaped to (batches, OH, OW, OC).

    Parameters
    ----------
    cfg : ConfigEntity
        Tuning config; passed to ``configure_knobs`` when ``interleave_A`` is set.
    data : tvm.te.Tensor
        4-D input tensor unpacked as (batches, IH, IW, IC).
    B_interleaved_t : tvm.te.Tensor
        4-D pre-transformed weights, indexed as
        [n_tile, k_tile, row_in_tile, col_in_tile]; dims 0, 2 and 3 give
        N_transformed, tile_rows_B and tile_cols_B.
    strides : int or tuple/list of two ints
        Stride, or (stride_h, stride_w).
    padding : int or tuple/list
        Padding spec accepted by ``get_pad_tuple``.
    dilation : int or tuple/list of two ints
        Dilation, or (dilation_h, dilation_w).
    out_dtype : str
        Data type of the result.
    kernel_size : tuple of two ints
        (KH, KW) of the original kernel.
    output_channels : int
        Number of output channels (OC).
    interleave_A : bool
        If True, pack ``A`` into tiles and run an interleaved GEMM; otherwise
        run the GEMM directly on the (padded) im2col matrix.

    Returns
    -------
    out : tvm.te.Tensor
        4-D tensor of shape (batches, OH, OW, OC) and dtype ``out_dtype``.
    """
    batches, IH, IW, IC = get_const_tuple(data.shape)
    KH, KW = get_const_tuple(kernel_size)
    OC = get_const_int(output_channels)
    kernel_area = KH * KW

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = get_const_tuple(dilation)

    dilated_kernel_h = (KH - 1) * dilation_h + 1
    dilated_kernel_w = (KW - 1) * dilation_w + 1

    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w)
    )
    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)

    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1

    # Pad whenever ANY side is non-zero: OH/OW above already include
    # pad_down/pad_right, so skipping the pad for bottom/right-only padding
    # would make the im2col below index past the end of ``data``.
    if pad_top or pad_left or pad_down or pad_right:
        data_pad = nn.pad(
            data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0], name="data_pad"
        )
    else:
        data_pad = data

    # Im2col
    M = OH * OW
    K = IC * kernel_area
    N = OC

    A_shape = (batches, M, K)
    # A plain reshape is a valid im2col only for a 1x1 kernel with unit stride;
    # with a larger stride OH * OW no longer matches the padded spatial size.
    if kernel_area == 1 and HSTR == 1 and WSTR == 1:
        A = tvm.topi.reshape(data_pad, A_shape)
    else:
        A = te.compute(
            A_shape,
            lambda n, x, y: data_pad[
                n,
                HSTR * (x // OW) + dilation_h * ((y // IC) // KW),
                WSTR * (x % OW) + dilation_w * ((y // IC) % KW),
                y % IC,
            ],
            name="data_im2col",
        )

    N_transformed = B_interleaved_t.shape[0]
    tile_rows_B = B_interleaved_t.shape[2]
    tile_cols_B = B_interleaved_t.shape[3]

    # Select the tiling strategy for A.
    # The tiling information is chosen to maximize register usage during
    # the tile computation.
    #
    # Please refer to:
    # - https://discuss.tvm.apache.org/t/rfc-accelerate-quantized-convolution-through-dot-product
    # - Conv2DGemmWeightTransformRel in src/relay/op/nn/convolution.h
    # for more information.
    if is_dotprod_available() and interleave_A:
        # If dot product has been enabled, and we are interleaving A
        # tile size should be 8x4
        tile_rows_A = 8
        tile_cols_A = 4
    else:
        # If either there is no dot product or if we are using a native strategy
        # tile size should be 4x16
        tile_rows_A = 4
        tile_cols_A = 16

    # Pad A if necessary so that M and K are multiples of the A tile sizes.
    pad_M = 0
    pad_K = 0

    if M % tile_rows_A != 0:
        pad_M = tile_rows_A - (M % tile_rows_A)

    if K % tile_cols_A != 0:
        pad_K = tile_cols_A - (K % tile_cols_A)

    M_padded = M + pad_M
    K_padded = K + pad_K
    N_padded = N_transformed * tile_rows_B

    pad_before = (0, 0, 0)
    pad_after = (0, pad_M, pad_K)

    if pad_M != 0 or pad_K != 0:
        A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded")

    idxm = tvm.tir.indexmod
    k = te.reduce_axis((0, K_padded), "k")

    if interleave_A:
        # Configuration space
        configure_knobs(cfg, M_padded, K_padded)

        # Pack the input data
        A_interleaved = te.compute(
            (batches, M_padded // tile_rows_A, K_padded // tile_cols_A, tile_rows_A, tile_cols_A),
            lambda b, x, y, z, w: A[b, z + tile_rows_A * x, w + tile_cols_A * y],
            name="A_interleaved",
        )
        # Execute GEMM
        C_interleaved = te.compute(
            (batches, M_padded // tile_rows_A, N_transformed, tile_rows_A, tile_rows_B),
            lambda b, x, y, w, z: te.sum(
                A_interleaved[b, x, k // tile_cols_A, w, idxm(k, tile_cols_A)].astype("int32")
                * B_interleaved_t[y, k // tile_cols_B, z, idxm(k, tile_cols_B)].astype("int32"),
                axis=k,
            ),
            name="C_interleaved",
        )
        # Unpack the result
        C = te.compute(
            (batches, M, N),
            lambda b, x, y: C_interleaved[
                b, x // tile_rows_A, y // tile_rows_B, idxm(x, tile_rows_A), idxm(y, tile_rows_B)
            ].astype(out_dtype),
            name="C",
        )
        zero = tvm.tir.const(0)
    else:
        # No need to pack/unpack, execute GEMM directly
        C = te.compute(
            (batches, M_padded, N_padded),
            lambda b, x, y: te.sum(
                A[b, x, k].astype("int32")
                * B_interleaved_t[
                    y // tile_rows_B, k // tile_cols_B, idxm(y, tile_rows_B), idxm(k, tile_cols_B)
                ].astype("int32"),
                axis=k,
            ),
            name="C",
        )

        # We need to ensure that infer bound pass does not remove the padding
        # which is necessary for the tensorizations to work. So we need to
        # add a dummy reference to the padding area of the result
        zero = (
            tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1]
            - tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1]
        )

    # Reshape the result into a convolution output
    out_shape = (batches, OH, OW, OC)
    out = te.compute(
        out_shape,
        lambda b, x, y, z: (C(b, y + OW * x, z) + zero).astype(out_dtype),
        name="conv2d_gemm_output",
    )
    return out
def conv2d_winograd_comp(
    cfg, data, kernel, strides, padding, dilation, out_dtype, args, pre_computed, layout
):
    """Compute declaration for winograd

    Parameters
    ----------
    cfg: ConfigEntity
        The config for this template

    data: tvm.te.Tensor
        4-D or 5-D Data tensor with shape NCHW or NCHW4c

    kernel: tvm.te.Tensor
        4-D or 5-D tensor with shape OIHW or OIHW4o

    strides: int or a list/tuple of two ints
        stride size, or [stride_height, stride_width]

    padding: int or a list/tuple of 2 or 4 ints
        padding size, or
        [pad_height, pad_width] for 2 ints, or
        [pad_top, pad_left, pad_bottom, pad_right] for 4 ints

    dilation: int or a list/tuple of two ints
        dilation size, or [dilation_height, dilation_width]; only 1 is supported

    out_dtype: str
        The output type. This is used for mixed precision.

    args: dict
        Dictionary with additional arguments, e.g. accumulator type

    pre_computed: bool
        True if the weights were already winograd-transformed, False if the
        transform should be computed at runtime

    layout: str
        NHWC or NCHW values are accepted

    Returns
    -------
    output: tvm.te.Tensor
        4-D or 5-D with shape NCHW or NCHW4c
    """
    assert layout in ("NCHW", "NHWC")
    tile_size = infer_tile_size(data, layout)

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation
    HSTR, WSTR = (strides, strides) if isinstance(strides, int) else strides

    convert_from4d = False
    if len(data.shape) == 4:
        convert_from4d = True
        if layout == "NCHW":
            N, DCI, H, W = get_const_tuple(data.shape)
        else:
            N, H, W, DCI = get_const_tuple(data.shape)
        if not pre_computed:
            if layout == "NCHW":
                out_channels, CI, KH, KW = get_const_tuple(kernel.shape)
            else:
                KH, KW, CI, out_channels = get_const_tuple(kernel.shape)
        else:
            alpha, _, CI, out_channels = get_const_tuple(kernel.shape)
            KH = KW = alpha + 1 - tile_size

        in_channel_chunks, in_channel_block, in_channel_tail = split_to_chunks(CI, 4)
        out_channel_chunks, out_channel_block, out_channel_tail = split_to_chunks(out_channels, 4)
        if autotvm.GLOBAL_SCOPE.in_tuning is True:
            # While tuning, substitute already-packed 5-D placeholders for the
            # 4-D inputs instead of emitting the packing computes.
            if layout == "NCHW":
                dshape = (N, in_channel_chunks, H, W, in_channel_block)
            else:
                dshape = (N, H, W, in_channel_chunks, in_channel_block)
            if not pre_computed:  # kernel tensor is raw tensor, do strict check
                if layout == "NCHW":
                    kshape = (out_channel_chunks, CI, KH, KW, out_channel_block)
                else:
                    kshape = (KH, KW, CI, out_channel_chunks, out_channel_block)
            else:
                kshape = (alpha, alpha, CI, out_channel_chunks, out_channel_block)
            data = tvm.te.placeholder(dshape, data.dtype, name="data_placeholder")
            kernel = tvm.te.placeholder(kshape, kernel.dtype, name="kernel_placeholder")
        else:
            data = pack_input(
                data, layout, N, in_channel_chunks, in_channel_block, in_channel_tail, H, W
            )
            kernel_layout = "OIHW" if layout == "NCHW" else "HWIO"
            if not pre_computed:  # kernel tensor is raw tensor, do strict check
                kernel = pack_filter(
                    kernel,
                    kernel_layout,
                    out_channel_chunks,
                    out_channel_block,
                    out_channel_tail,
                    CI,
                    in_channel_chunks,
                    in_channel_block,
                    in_channel_tail,
                    KH,
                    KW,
                )
            else:
                kernel = pack_filter(
                    kernel,
                    "HWIO",
                    out_channel_chunks,
                    out_channel_block,
                    out_channel_tail,
                    CI,
                    in_channel_chunks,
                    in_channel_block,
                    in_channel_tail,
                    alpha,
                    alpha,
                )
    if layout == "NCHW":
        N, DCI, H, W, CB = get_const_tuple(data.shape)
    else:
        N, H, W, DCI, CB = get_const_tuple(data.shape)
    if not pre_computed:  # kernel tensor is raw tensor, do strict check
        if layout == "NCHW":
            CO, CI, KH, KW, COB = get_const_tuple(kernel.shape)
        else:
            KH, KW, CI, CO, COB = get_const_tuple(kernel.shape)
        alpha = KW + tile_size - 1
        # Dilation is never applied to the kernel below, so reject it rather
        # than silently computing an undilated convolution.
        assert HSTR == 1 and WSTR == 1 and KH == KW and dilation_h == 1 and dilation_w == 1
    else:
        alpha, _, CI, CO, COB = get_const_tuple(kernel.shape)
        KH = KW = alpha + 1 - tile_size
        assert HSTR == 1 and WSTR == 1 and dilation_h == 1 and dilation_w == 1

    if isinstance(N, tvm.tir.Any):
        N = tvm.te.size_var("n")

    if not isinstance(H, int) or not isinstance(W, int):
        raise RuntimeError(
            "adreno winograd conv2d doesn't support dynamic input height or width."
        )

    pt, pl, pb, pr = nn.get_pad_tuple(padding, (KH, KW))
    if layout == "NCHW":
        data_pad = nn.pad(data, (0, 0, pt, pl, 0), (0, 0, pb, pr, 0), name="data_pad")
    else:
        data_pad = nn.pad(data, (0, pt, pl, 0, 0), (0, pb, pr, 0, 0), name="data_pad")

    r = KW
    m = tile_size
    A, B, G = winograd_transform_matrices(m, r, out_dtype)

    H = (H + pt + pb - KH) // HSTR + 1
    W = (W + pl + pr - KW) // WSTR + 1
    nH, nW = (H + m - 1) // m, (W + m - 1) // m

    P = N * nH * nW if isinstance(N, int) else nH * nW

    # transform kernel
    if not pre_computed:
        r_kh = te.reduce_axis((0, KH), name="r_kh")
        r_kw = te.reduce_axis((0, KW), name="r_kw")
        if layout == "NCHW":
            kernel_pack = te.compute(
                (alpha, alpha, CI, CO, COB),
                lambda eps, nu, ci, co, cob: te.sum(
                    kernel[co][ci][r_kh][r_kw][cob] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]
                ),
                name="kernel_pack",
            )
        else:
            kernel_pack = te.compute(
                (alpha, alpha, CI, CO, COB),
                lambda eps, nu, ci, co, cob: te.sum(
                    kernel[r_kh][r_kw][ci][co][cob] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]
                ),
                name="kernel_pack",
            )
    else:
        kernel_pack = kernel

    idxdiv = tvm.tir.indexdiv
    idxmod = tvm.tir.indexmod
    # From here on CI is the number of input channel chunks of the packed data
    # (kernel_pack above was declared with the kernel's CI).
    # NOTE(review): this also re-reads N from data.shape, which undoes the
    # size_var substitution above for a dynamic batch — confirm intent.
    if layout == "NCHW":
        N, CI, _, _, CB = get_const_tuple(data.shape)
    else:
        N, _, _, CI, CB = get_const_tuple(data.shape)

    # pack input tile
    if layout == "NCHW":
        input_tile = te.compute(
            (alpha, alpha, CI, P, CB),
            lambda eps, nu, c, p, cb: data_pad[idxdiv(p, (nH * nW))][c][
                idxmod(idxdiv(p, nW), nH) * m + eps
            ][idxmod(p, nW) * m + nu][cb],
            name="d",
        )
    else:
        input_tile = te.compute(
            (alpha, alpha, CI, P, CB),
            lambda eps, nu, c, p, cb: data_pad[idxdiv(p, (nH * nW))][
                idxmod(idxdiv(p, nW), nH) * m + eps
            ][idxmod(p, nW) * m + nu][c][cb],
            name="d",
        )

    # transform data
    r_a = te.reduce_axis((0, alpha), "r_a")
    r_b = te.reduce_axis((0, alpha), "r_b")
    data_pack = te.compute(
        (P, CI, alpha, alpha, CB),
        lambda p, ci, eps, nu, cb: te.sum(
            input_tile[r_a][r_b][ci][p][cb] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b]
        ),
        name="data_pack",
    )

    # repack transformed data
    data_pack_trans = te.compute(
        (alpha, alpha, CI, P, CB),
        lambda eps, nu, c, p, cb: data_pack[p][c][eps][nu][cb],
        name="data_pack_trans",
    )

    # do batch gemm
    ci = te.reduce_axis((0, CI), name="ci")
    cb = te.reduce_axis((0, CB), name="cb")
    bgemm = te.compute(
        (alpha, alpha, CO, P, COB),
        lambda eps, nu, co, p, cob: te.sum(
            (
                kernel_pack[eps][nu][ci * CB + cb][co][cob]
                * data_pack_trans[eps][nu][ci][p][cb]
            ).astype(args["accumulator"]),
            axis=[ci, cb],
        ),
        name="bgemm",
    )

    # inverse transform
    r_a = te.reduce_axis((0, alpha), "r_a")
    r_b = te.reduce_axis((0, alpha), "r_b")
    inverse = te.compute(
        (CO, P, m, m, COB),
        lambda co, p, vh, vw, cob: te.sum(
            bgemm[r_a][r_b][co][p][cob] * (A[r_a][vh] * A[r_b][vw]).astype(args["accumulator"]),
            axis=[r_a, r_b],
        ),
        name="inverse",
    )

    # output
    # ``inverse`` is laid out as (CO, P, m, m, COB), so a flat output channel c
    # maps to chunk c // COB and lane c % COB.
    if layout == "NCHW":
        if convert_from4d and autotvm.GLOBAL_SCOPE.in_tuning is False:
            output = te.compute(
                (N, out_channels, H, W),
                lambda n, c, h, w: inverse[c // COB][
                    n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m)
                ][idxmod(h, m)][idxmod(w, m)][c % COB].astype(out_dtype),
                name="output",
                tag="cast_from_acc" + args["accumulator"][-2:],
            )
        else:
            output = te.compute(
                (N, CO, H, W, COB),
                lambda n, co, h, w, cob: inverse[co][
                    n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m)
                ][idxmod(h, m)][idxmod(w, m)][cob].astype(out_dtype),
                name="output",
                tag="cast_from_acc" + args["accumulator"][-2:],
            )
    else:
        if convert_from4d and autotvm.GLOBAL_SCOPE.in_tuning is False:
            output = te.compute(
                (N, H, W, out_channels),
                lambda n, h, w, c: inverse[c // COB][
                    n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m)
                ][idxmod(h, m)][idxmod(w, m)][c % COB].astype(out_dtype),
                name="output",
                tag="cast_from_acc" + args["accumulator"][-2:],
            )
        else:
            output = te.compute(
                (N, H, W, CO, COB),
                lambda n, h, w, co, cob: inverse[co][
                    n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m)
                ][idxmod(h, m)][idxmod(w, m)][cob].astype(out_dtype),
                name="output",
                tag="cast_from_acc" + args["accumulator"][-2:],
            )

    if isinstance(N, int):
        cfg.add_flop(2 * N * CO * COB * H * W * CI * CB * KH * KW)

    return output
def compute_conv2d_NCHWc_KCRSk_acc32(Input, Filter, stride, padding, dilation, out_dtype=None):
    """Convolution operator in NCHWc layout.

    The 5-D filter is first repacked into a 3-D (K, C*R*S, k) tensor, so that a
    flat index ``crs`` addresses input channel, kernel row and kernel column
    together. The convolution then reduces over input-channel chunks, the
    input-channel block, kernel rows and kernel columns.

    Parameters
    ----------
    Input : tvm.te.Tensor
        5-D tensor unpacked as
        (batch, in_channel_chunk, in_height, in_width, in_channel_block).
    Filter : tvm.te.Tensor
        5-D tensor unpacked as
        (num_filter_chunk, channel, kernel_h, kernel_w, num_filter_block).
    stride : int or a list/tuple of two ints
        Stride, or (stride_h, stride_w).
    padding : int or a list/tuple
        Padding spec accepted by ``nn.get_pad_tuple``.
    dilation : int or a list/tuple of two ints
        Dilation, or (dilation_h, dilation_w).
    out_dtype : str, optional
        Dtype each product is cast to before summation, i.e. the accumulation
        dtype. Defaults to ``Input.dtype``.

    Returns
    -------
    output : tvm.te.Tensor
        5-D tensor
        (batch, num_filter_chunk, out_height, out_width, num_filter_block),
        cast to "float32".
    """
    if out_dtype is None:
        out_dtype = Input.dtype
    assert isinstance(stride, int) or len(stride) == 2
    assert isinstance(dilation, int) or len(dilation) == 2
    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    batch, in_channel_chunk, in_height, in_width, in_channel_block = Input.shape
    num_filter_chunk, channel, kernel_h, kernel_w, num_filter_block = Filter.shape

    # compute the output shape
    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
    pad_top, pad_left, pad_down, pad_right = nn.get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w)
    )

    out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
    out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)

    # compute graph
    pad_before = [0, 0, pad_top, pad_left, 0]
    pad_after = [0, 0, pad_down, pad_right, 0]
    temp = nn.pad(Input, pad_before, pad_after, name="pad_temp")

    # Give every reduction axis its own name so the lowered IR stays readable.
    rcc = te.reduce_axis((0, in_channel_chunk), name="rcc")
    rcb = te.reduce_axis((0, in_channel_block), name="rcb")
    ry = te.reduce_axis((0, kernel_h), name="ry")
    rx = te.reduce_axis((0, kernel_w), name="rx")

    # NCHWc x KCRSk
    # texture: NCH|W|c
    # texture: K|CRS|k
    # c = crs // RS
    # rs = crs % RS
    # r = rs // S == (crs // S) % R
    # s = rs % S == crs % S
    Filter = te.compute(
        (num_filter_chunk, channel * kernel_h * kernel_w, num_filter_block),
        lambda ffc, crs, ffb: Filter[
            ffc, crs // (kernel_h * kernel_w), (crs // kernel_w) % kernel_h, crs % kernel_w, ffb
        ],
        name="packed_filter",
    )

    conv = te.compute(
        (batch, num_filter_chunk, out_height, out_width, num_filter_block),
        lambda nn, ffc, yy, xx, ffb: te.sum(
            (
                temp[nn, rcc, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rcb]
                * Filter[ffc, ((rcc * in_channel_block + rcb) * kernel_h + ry) * kernel_w + rx, ffb]
            ).astype(out_dtype),
            axis=[rcc, rcb, ry, rx],
        ),
        tag="conv2d_nchwc_kcrsk_texture",
    )
    output = te.compute(conv.shape, lambda n, fc, y, x, fb: conv[n, fc, y, x, fb].astype("float32"))
    return output