def _declaration_conv(data, kernel, stride, padding, layout, out_dtype):
    """Declare a NCHW conv2d whose inner compute runs on NCHWc-packed data.

    Packs the 4-D input into (n, C, h, w, c) blocks of ``sch.ic_bn``, performs the
    convolution over the packed layout, and unpacks the result back to NCHW.
    The kernel is expected pre-packed (6-D); its last axis is the oc block.

    Parameters mirror the generic topi conv2d declaration; returns the 4-D
    unpacked output tensor tagged 'conv2d_nchw' so the matching schedule can
    find it.
    """
    assert layout == 'NCHW', "only support NCHW convolution for AVX"
    wkl = get_workload(data, kernel, stride, padding, out_dtype)
    sch = _get_schedule(wkl)
    HPAD, WPAD = wkl.hpad, wkl.wpad
    HSTR, WSTR = wkl.hstride, wkl.wstride

    batch_size, in_channel, in_height, in_width = get_const_tuple(data.shape)
    # pre-packed kernel: (oc_chunk, ic_chunk, kh, kw, ic_block, oc_block)
    num_filter, _, kernel_height, kernel_width, _, co = get_const_tuple(
        kernel.shape)
    num_filter *= co

    pad_height = in_height + 2 * HPAD
    pad_width = in_width + 2 * WPAD
    out_height = (in_height + 2 * HPAD - kernel_height) // HSTR + 1
    out_width = (in_width + 2 * WPAD - kernel_width) // WSTR + 1

    # pack data
    # BUGFIX: pad whenever EITHER dimension needs it (was `and`, which skipped
    # padding for asymmetric pads such as (1, 0)).
    DOPAD = (HPAD != 0 or WPAD != 0)
    if DOPAD:
        data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
    else:
        data_pad = data

    shape = (batch_size, in_channel // sch.ic_bn, pad_height, pad_width,
             sch.ic_bn)
    data_vec = tvm.compute(
        shape,
        lambda n, C, h, w, c: data_pad[n, C * sch.ic_bn + c, h, w],
        name='data_vec')

    kernel_vec = kernel

    # convolution over the packed layout
    oshape = (batch_size, num_filter // sch.oc_bn, out_height, out_width,
              sch.oc_bn)
    unpack_shape = (batch_size, num_filter, out_height, out_width)

    ic = tvm.reduce_axis((0, in_channel), name='ic')
    kh = tvm.reduce_axis((0, kernel_height), name='kh')
    kw = tvm.reduce_axis((0, kernel_width), name='kw')

    conv = tvm.compute(
        oshape,
        lambda n, oc_chunk, oh, ow, oc_block: tvm.sum(
            data_vec[n, ic // sch.ic_bn, oh * HSTR + kh, ow * WSTR + kw,
                     ic % sch.ic_bn] *
            kernel_vec[oc_chunk, ic // sch.ic_bn, kh, kw, ic % sch.ic_bn,
                       oc_block],
            axis=[ic, kh, kw]),
        name='conv')

    # unpack back to NCHW; tag is what the traverse/schedule code matches on
    unpack = tvm.compute(
        unpack_shape,
        lambda n, c, h, w: conv[n, c // sch.oc_bn, h, w, c % sch.oc_bn],
        name='output_unpack',
        tag='conv2d_nchw')
    return unpack
def _declaration_conv(wkl, data, kernel):
    """Declare a 1x1 conv2d directly on NCHWc-packed 5-D input.

    ``data`` is already (n, ic_chunk, h, w, ic_block); ``kernel`` is pre-packed
    with a trailing (0, 0) spatial index, i.e. this declaration assumes a 1x1
    kernel (no kh/kw reduce axes). Returns the 5-D NCHWc output tagged
    'conv2d_NCHWc'.
    """
    sch = _get_schedule(wkl)
    out_dtype = wkl.out_dtype
    HPAD, WPAD = wkl.hpad, wkl.wpad
    HSTR, WSTR = wkl.hstride, wkl.wstride

    batch_size = data.shape[0]
    out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
    out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1

    # BUGFIX: pad whenever EITHER dimension needs it (was `and`, which skipped
    # padding when only one of HPAD/WPAD was non-zero).
    DOPAD = (HPAD != 0 or WPAD != 0)
    if DOPAD:
        data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad")
    else:
        data_pad = data

    oshape = (batch_size, wkl.out_filter // sch.oc_bn, out_height, out_width,
              sch.oc_bn)
    ic = tvm.reduce_axis((0, wkl.in_filter), name='ic')
    conv = tvm.compute(
        oshape,
        lambda n, oc_chunk, oh, ow, oc_block: tvm.sum(
            data_pad[n, ic // sch.ic_bn, oh * HSTR, ow * WSTR,
                     ic % sch.ic_bn].astype(out_dtype) *
            kernel[oc_chunk, ic // sch.ic_bn, ic % sch.ic_bn, oc_block, 0, 0],
            axis=[ic]),
        name='conv2d_NCHWc',
        tag='conv2d_NCHWc')
    return conv
def traverse(op): """Traverse operators from computation graph""" # inline all one-to-one-mapping operators except the last stage (output) if tag.is_broadcast(op.tag): if op not in s.outputs: s[op].compute_inline() for tensor in op.input_tensors: if tensor.op.input_tensors: traverse(tensor.op) if 'conv2d_nchw' in op.tag: # print('Run in x86-rasp schedule') output = op.output(0) conv_out = op.input_tensors[0] kernel_vec = conv_out.op.input_tensors[1] kernel = kernel_vec.op.input_tensors[0] data_vec = conv_out.op.input_tensors[0] data = data_vec.op.input_tensors[0] data_pad = None if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag: data_pad = data data = data_pad.op.input_tensors[0] padding = infer_pad(data, data_pad) if data_pad is None: stride = infer_stride(data, kernel, output) else: stride = infer_stride(data_pad, kernel, output) wkl = _get_workload(data, kernel, stride, padding, output.dtype) sch = _get_schedule(wkl) return _SCH_TO_SCH_FUNC[type(sch)](s, data, data_pad, data_vec, kernel, kernel_vec, conv_out, output, outs[0])
def traverse(op):
    """Traverse operators from computation graph and schedule conv2d_nChwc.

    Closure: relies on `s`, `outs`, `kernel_size`, `tag`, and the dispatch
    helpers from the enclosing scope (not visible in this chunk).
    """
    # inline all one-to-one-mapping operators except the last stage (output)
    if tag.is_broadcast(op.tag):
        if op not in s.outputs:
            s[op].compute_inline()
        for tensor in op.input_tensors:
            if tensor.op.input_tensors:
                traverse(tensor.op)

    if 'conv2d_nChwc' in op.tag:
        # BUGFIX: removed leftover debug print of the matched tag.
        output = op.output(0)
        # conv_out = op.input_tensors[0]
        conv_out = output
        kernel = conv_out.op.input_tensors[1]
        # kernel = kernel_vec.op.input_tensors[0]
        data_vec = conv_out.op.input_tensors[0]
        # Step past the data-packing stage unless it is itself a pad stage.
        data = data_vec.op.input_tensors[0] \
            if isinstance(data_vec.op, tvm.tensor.ComputeOp) and len(data_vec.op.input_tensors) > 0 and "pad" not in data_vec.op.tag \
            else data_vec
        data_pad = None
        if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
            data_pad = data
            data = data_pad.op.input_tensors[0]

        # Rebuild flat-NCHW placeholders so the generic workload/stride/pad
        # helpers (which expect 4-D tensors) can be reused on 5-D packed data.
        # NOTE(review): assumes `data` is 5-D here — confirm for the 4-D path.
        n, ic_chunk, h, w, ic_block = [x.value for x in data.shape]
        ic = ic_chunk * ic_block
        original_data = tvm.placeholder((n, ic, h, w), dtype=output.dtype)

        if data_pad is not None:
            n, _, pad_h, pad_w, _ = [x.value for x in data_pad.shape]
            original_data_pad = tvm.placeholder((n, ic, pad_h, pad_w),
                                                dtype=output.dtype)
            padding = infer_pad(original_data, original_data_pad)
        else:
            padding = (0, 0)

        oc, kh, kw = kernel_size
        original_kernel = tvm.placeholder((oc, ic, kh, kw),
                                          dtype=output.dtype)

        n, oc_chunk, oh, ow, oc_block = [x.value for x in output.shape]
        original_output = tvm.placeholder((n, oc_chunk * oc_block, oh, ow),
                                          dtype=output.dtype)

        if data_pad is None:
            stride = infer_stride(original_data, original_kernel,
                                  original_output)
        else:
            stride = infer_stride(original_data_pad, original_kernel,
                                  original_output)
        wkl = _get_workload(original_data, original_kernel, stride, padding,
                            output.dtype)
        sch = _get_schedule(wkl)
        _SCH_TO_SCH_FUNC[type(sch)](s, data, data_pad, data_vec, kernel,
                                    conv_out, output, outs[0])
def _declaration_conv(data, kernel, stride, padding, layout, out_dtype):
    """Dispatch the conv2d declaration to the implementation registered for
    this workload's schedule type."""
    assert layout == 'NCHW', "only support NCHW convolution on rasp"
    assert data.shape[0].value == 1, "only support batch size=1 convolution on rasp"
    workload = _get_workload(data, kernel, stride, padding, out_dtype)
    schedule = _get_schedule(workload)
    decl_func = _SCH_TO_DECL_FUNC[type(schedule)]
    return decl_func(data, kernel, stride, padding, layout, out_dtype)
def _declaration_conv(data, kernel, kernel_size, stride, padding, layout,
                      out_dtype):
    """Dispatch the conv2d declaration (pre-packed kernel variant).

    The real kernel tensor is already packed, so a flat OIHW placeholder is
    synthesized from ``kernel_size`` purely for workload extraction.
    """
    assert layout == 'NCHW', "only support NCHW convolution on avx"
    assert data.shape[0].value == 1, "only support batch size=1 convolution on avx"
    _, ic, _, _ = [x.value for x in data.shape]
    oc, kh, kw = kernel_size
    # BUGFIX: give the dummy kernel the same dtype as the data (it previously
    # defaulted to float32), consistent with the other dispatcher in this file
    # that passes an explicit dtype to its placeholders.
    wkl = _get_workload(data,
                        tvm.placeholder((oc, ic, kh, kw), dtype=data.dtype),
                        stride, padding, out_dtype)
    sch = _get_schedule(wkl)
    return _SCH_TO_DECL_FUNC[type(sch)](data, kernel, stride, padding, layout,
                                        out_dtype)
def _schedule_conv(s, data, data_pad, data_vec, kernel, conv_out, output, last):
    """Schedule the packed conv2d: inline pad, parallelize the data-pack stage,
    cache-write the conv, and split/vectorize the unpacked output."""
    # no stride and padding info here
    # Recover stride/padding from tensor shapes so the workload (and thus the
    # tuned split factors) can be re-derived.
    padding = infer_pad(data, data_pad)
    if data_pad is None:
        stride = infer_stride(data, kernel, output)
    else:
        stride = infer_stride(data_pad, kernel, output)
    wkl = get_workload(data, kernel, stride, padding, output.dtype)
    sch = _get_schedule(wkl)

    # A, W = data, kernel_pack
    A0, A1 = data_pad, data_vec
    # schedule data
    if A0 is not None:
        s[A0].compute_inline()
    # NOTE(review): data_vec is declared as (n, C, h, w, c); the names
    # `ic_block` and `iw` here appear swapped relative to the real axis order,
    # but only `ic_chunk`/`ih` (positions 1 and 2) are used below, so the fuse
    # is still over the channel-chunk and height axes.
    batch, ic_chunk, ih, ic_block, iw = s[A1].op.axis
    parallel_axis = s[A1].fuse(ic_chunk, ih)
    s[A1].parallel(parallel_axis)

    C, O0, O = conv_out, output, last
    CC = s.cache_write(C, 'global')

    batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
    oh_outer, oh_inner = s[C].split(oh, factor=sch.oh_factor)
    s[C].vectorize(oc_block)

    # Compute the cached accumulation inside the outer-height loop.
    s[CC].compute_at(s[C], oh_outer)
    _, oc_chunk, oh, ow, oc_block = s[CC].op.axis
    ic, = s[CC].op.reduce_axis
    ic_chunk, ic_block = s[CC].split(ic, factor=sch.ic_bn)
    oh_outer, oh_inner = s[CC].split(oh, factor=sch.oh_factor)
    ow_outer, ow_inner = s[CC].split(ow, factor=sch.ow_factor)
    s[CC].reorder(oc_chunk, oh_outer, ow_outer, ic_chunk, ic_block, oh_inner,
                  ow_inner, oc_block)
    s[CC].vectorize(oc_block)
    s[CC].unroll(ow_inner)
    s[CC].unroll(oh_inner)

    # If there is a fused op after the unpack, inline the unpack stage.
    if O0 != O:
        s[O0].compute_inline()
    batch, oc, oh, ow = s[O].op.axis
    oc_chunk, oc_block = s[O].split(oc, factor=sch.oc_bn)
    oh_outer, oh_inner = s[O].split(oh, factor=sch.oh_factor)
    ow_outer, ow_inner = s[O].split(ow, factor=sch.ow_factor)
    s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
    parallel_axis = s[O].fuse(oc_chunk, oh_outer)
    s[C].compute_at(s[O], parallel_axis)
    s[O].vectorize(oc_block)
    s[O].parallel(parallel_axis)
    return s
def _declaration_conv(data, kernel, stride, padding, layout, out_dtype):
    """Declare a conv2d that packs both data (NCHWc) and kernel, convolves in
    the packed layout, and unpacks back to NCHW.

    NOTE(review): the conv reads data at (oh*HSTR, ow*WSTR) with no kh/kw
    reduce axes and the packed kernel carries a (1, 1) spatial extent — this
    declaration appears specialized to 1x1 kernels; confirm against callers.
    Returns the 4-D unpacked tensor tagged 'conv2d_nchw'.
    """
    assert layout == 'NCHW', "only support NCHW convolution on rasp"
    assert data.shape[0].value == 1, "only support batch size=1 convolution on rasp"
    wkl = _get_workload(data, kernel, stride, padding, out_dtype)
    sch = _get_schedule(wkl)
    HPAD, WPAD = wkl.hpad, wkl.wpad
    HSTR, WSTR = wkl.hstride, wkl.wstride

    batch_size, in_channel, in_height, in_width = get_const_tuple(data.shape)
    num_filter, _, kernel_height, kernel_width = get_const_tuple(kernel.shape)

    pad_height = in_height + 2 * HPAD
    pad_width = in_width + 2 * WPAD
    out_height = (in_height + 2 * HPAD - kernel_height) // HSTR + 1
    out_width = (in_width + 2 * WPAD - kernel_width) // WSTR + 1

    # input: c, h, w
    # BUGFIX: pad whenever EITHER dimension needs it (was `and`, which skipped
    # padding for asymmetric pads such as (1, 0)).
    DOPAD = (HPAD != 0 or WPAD != 0)
    if DOPAD:
        data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
    else:
        data_pad = data

    # pack data into (n, ic_chunk, h, w, ic_block)
    shape = (batch_size, in_channel // sch.ic_bn, pad_height, pad_width,
             sch.ic_bn)
    data_vec = tvm.compute(
        shape, lambda n, C, h, w, c: data_pad[n, C * sch.ic_bn + c, h, w])

    # pack kernel into (oc_chunk, ic_chunk, ic_block, oc_block, 1, 1)
    shape = (num_filter // sch.oc_bn, in_channel // sch.ic_bn, sch.ic_bn,
             sch.oc_bn, 1, 1)
    kernel_pack = tvm.compute(
        shape, lambda CO, CI, ci, co, h, w: kernel[CO * sch.oc_bn + co,
                                                   CI * sch.ic_bn + ci, h, w])

    oshape = (batch_size, num_filter // sch.oc_bn, out_height, out_width,
              sch.oc_bn)
    ic = tvm.reduce_axis((0, in_channel), name='ic')
    conv = tvm.compute(
        oshape,
        lambda n, oc_chunk, oh, ow, oc_block: tvm.sum(
            data_vec[n, ic // sch.ic_bn, oh * HSTR, ow * WSTR,
                     ic % sch.ic_bn].astype(out_dtype) *
            kernel_pack[oc_chunk, ic // sch.ic_bn, ic % sch.ic_bn, oc_block,
                        0, 0],
            axis=[ic]),
        name='conv')

    # unpack back to NCHW; tag is what the traverse/schedule code matches on
    oshape = (batch_size, num_filter, out_height, out_width)
    unpack = tvm.compute(
        oshape,
        lambda n, oc, oh, ow: conv[n, oc // sch.oc_bn, oh, ow,
                                   oc % sch.oc_bn],
        tag='conv2d_nchw')
    return unpack
def _schedule_conv(s, wkl, data, kernel, conv_out, last): sch = _get_schedule(wkl) # schedule data A = data if isinstance(s[A].op, tvm.tensor.ComputeOp): batch, ic_chunk, ih, iw, ic_block = s[A].op.axis parallel_axis = s[A].fuse(ic_chunk, ih) s[A].parallel(parallel_axis) C, O = conv_out, last CC = s.cache_write(C, 'global') batch, oc_chunk, oh, ow, oc_block = s[C].op.axis oh_outer, oh_inner = s[C].split(oh, factor=sch.oh_factor) ow_outer, ow_inner = s[C].split(ow, factor=sch.ow_factor) s[C].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) s[C].vectorize(oc_block) parallel_axis = s[C].fuse(oc_chunk, oh_outer) s[CC].compute_at(s[C], parallel_axis) if C == O: s[C].parallel(parallel_axis) _, oc_chunk, oh, ow, oc_block = s[CC].op.axis ic, = s[CC].op.reduce_axis ic_chunk, ic_block = s[CC].split(ic, factor=sch.ic_bn) oh_outer, oh_inner = s[CC].split(oh, factor=sch.oh_factor) ow_outer, ow_inner = s[CC].split(ow, factor=sch.ow_factor) s[CC].reorder(oc_chunk, oh_outer, ow_outer, ic_chunk, ic_block, oh_inner, ow_inner, oc_block) s[CC].fuse(oc_chunk, oh_outer) s[CC].vectorize(oc_block) s[CC].unroll(ow_inner) s[CC].unroll(oh_inner) if C != O: batch, oc_chunk, oh, ow, oc_block = s[O].op.axis oh_outer, oh_inner = s[O].split(oh, factor=sch.oh_factor) ow_outer, ow_inner = s[O].split(ow, factor=sch.ow_factor) s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) parallel_axis = s[O].fuse(oc_chunk, oh_outer) s[C].compute_at(s[O], parallel_axis) s[O].vectorize(oc_block) s[O].parallel(parallel_axis) return s
def _declaration_conv(data, kernel, num_filter, kernel_size, stride, padding,
                      out_dtype):
    """Dispatch the NCHWc conv2d declaration for 5-D packed input.

    Flat NCHW/OIHW placeholders are synthesized from the packed input shape
    and ``num_filter``/``kernel_size`` solely to derive the workload key.
    """
    assert data.shape[0].value == 1, "only support batch size=1 convolution on avx"
    batch, chunks, height, width, block = [dim.value for dim in data.shape]
    in_channels = chunks * block
    out_channels = num_filter
    k_h, k_w = kernel_size
    flat_data = tvm.placeholder((batch, in_channels, height, width),
                                dtype=out_dtype)
    flat_kernel = tvm.placeholder((out_channels, in_channels, k_h, k_w),
                                  dtype=out_dtype)
    workload = _get_workload(flat_data, flat_kernel, stride, padding,
                             out_dtype)
    schedule = _get_schedule(workload)
    return _SCH_TO_DECL_FUNC[type(schedule)](workload, data, kernel)
def _schedule_conv(s, wkl, data, kernel, conv_out, last):
    """Schedule a 5-D NCHWc conv (register-blocking variant): split ow by the
    tuned register count, block the ic reduction, optionally unroll kw."""
    sch = _get_schedule(wkl)

    A = data
    # schedule data
    # Input may be a placeholder (already packed by the caller) — only a
    # produced tensor can be scheduled.
    if isinstance(s[A].op, tvm.tensor.ComputeOp):
        batch, ic_chunk, ih, iw, ic_block = s[A].op.axis
        parallel_axis = s[A].fuse(ic_chunk, ih)
        s[A].parallel(parallel_axis)

    # schedule 5-D conv
    C, O = conv_out, last
    CC = s.cache_write(C, 'global')

    _, oc_chunk, oh, ow, oc_block = s[C].op.axis
    ow_chunk, ow_block = s[C].split(ow, factor=sch.reg_n)
    s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
    parallel_axis = s[C].fuse(oc_chunk, oh)
    s[C].vectorize(oc_block)
    # Only parallelize here when conv is the final stage; otherwise the final
    # stage below owns the parallel loop.
    if C == O:
        s[C].parallel(parallel_axis)

    s[CC].compute_at(s[C], ow_chunk)
    _, oc_chunk, oh, ow, oc_block = s[CC].op.axis
    ic, kh, kw = s[CC].op.reduce_axis

    ow_chunk, ow_block = s[CC].split(ow, factor=sch.reg_n)
    ic_chunk, ic_block = s[CC].split(ic, factor=sch.ic_bn)

    # Two loop orders depending on whether the kernel-width loop is unrolled.
    if sch.unroll_kw:
        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kh, ic_block, kw,
                      ow_block, oc_block)
        s[CC].unroll(kw)
    else:
        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kh, kw, ic_block,
                      ow_block, oc_block)

    s[CC].vectorize(oc_block)
    s[CC].unroll(ow_block)

    # Extra (fused) output stage: mirror the conv splits and attach conv there.
    if C != O:
        batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
        ow_chunk, ow_block = s[O].split(ow, factor=sch.reg_n)
        s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
        parallel_axis = s[O].fuse(oc_chunk, oh)
        s[C].compute_at(s[O], parallel_axis)
        s[O].vectorize(oc_block)
        s[O].parallel(parallel_axis)
    return s
def _declaration_conv(data, kernel, stride, padding, layout, out_dtype):
    """Declare a 1x1 conv2d on 5-D NCHWc input, re-packing the data only when
    the input channel block size differs from the tuned ``sch.ic_bn``.

    Returns the 5-D NCHWc output tagged 'conv2d_nChwc'.
    """
    # BUGFIX: error message previously said "NCHW" while asserting NCHWc.
    assert layout == 'NCHWc', "only support NCHWc convolution on rasp"
    assert data.shape[0].value == 1, "only support batch size=1 convolution on rasp"
    wkl = get_workload(data, kernel, stride, padding, out_dtype)
    sch = _get_schedule(wkl)
    HPAD, WPAD = wkl.hpad, wkl.wpad
    HSTR, WSTR = wkl.hstride, wkl.wstride

    batch_size, in_channel_chunk, in_height, in_width, in_channel_block = \
        get_const_tuple(data.shape)
    # pre-packed kernel: (oc_chunk, ic_chunk, ic_block, oc_block, kh, kw)
    num_filter, _, _, co, kernel_height, kernel_width = \
        get_const_tuple(kernel.shape)
    num_filter *= co

    pad_height = in_height + 2 * HPAD
    pad_width = in_width + 2 * WPAD
    out_height = (in_height + 2 * HPAD - kernel_height) // HSTR + 1
    out_width = (in_width + 2 * WPAD - kernel_width) // WSTR + 1

    # input: c, h, w
    # BUGFIX: pad whenever EITHER dimension needs it (was `and`, which skipped
    # padding when only one of HPAD/WPAD was non-zero).
    DOPAD = (HPAD != 0 or WPAD != 0)
    if DOPAD:
        data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad")
    else:
        data_pad = data

    in_channel = in_channel_block * in_channel_chunk
    # Re-pack only when the incoming block size does not match the schedule.
    if in_channel_block != sch.ic_bn:
        print('WARNING!!! (1x1) in_channel_block=%d vs sch.ic_bn=%d' %
              (in_channel_block, sch.ic_bn))
        shape = (batch_size, in_channel // sch.ic_bn, pad_height, pad_width,
                 sch.ic_bn)
        data_vec = tvm.compute(
            shape,
            lambda n, C, h, w, c: data_pad[
                n, (C * sch.ic_bn + c) // in_channel_block, h, w,
                (C * sch.ic_bn + c) % in_channel_block],
            tag='conv2d_data_pack')
    else:
        data_vec = data_pad

    kernel_pack = kernel

    oshape = (batch_size, num_filter // sch.oc_bn, out_height, out_width,
              sch.oc_bn)
    ic = tvm.reduce_axis((0, in_channel), name='ic')
    conv = tvm.compute(
        oshape,
        lambda n, oc_chunk, oh, ow, oc_block: tvm.sum(
            data_vec[n, ic // sch.ic_bn, oh * HSTR, ow * WSTR,
                     ic % sch.ic_bn].astype(out_dtype) *
            kernel_pack[oc_chunk, ic // sch.ic_bn, ic % sch.ic_bn, oc_block,
                        0, 0],
            axis=[ic]),
        name='conv2d_nChwc',
        tag='conv2d_nChwc')
    return conv
def traverse(op): """Traverse operators from computation graph""" # inline all one-to-one-mapping operators except the last stage (output) if tag.is_broadcast(op.tag): if op not in s.outputs: s[op].compute_inline() for tensor in op.input_tensors: if tensor.op.input_tensors: traverse(tensor.op) if 'conv2d_nChwc' in op.tag: output = op.output(0) # conv_out = op.input_tensors[0] conv_out = op.input_tensors[0] if 'conv2d_nChwc_unpack' in op.tag else output kernel = conv_out.op.input_tensors[1] # kernel = kernel_vec.op.input_tensors[0] data_vec = conv_out.op.input_tensors[0] data = data_vec.op.input_tensors[0] \ if isinstance(data_vec.op, tvm.tensor.ComputeOp) and len(data_vec.op.input_tensors) > 0 and "pad" not in data_vec.op.tag \ else data_vec data_pad = None if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag: data_pad = data data = data_pad.op.input_tensors[0] ndim_input = len(data.shape) if ndim_input == 5: n, ic_chunk, h, w, ic_block = [x.value for x in data.shape] ic = ic_chunk * ic_block else: n, ic, h, w = [x.value for x in data.shape] original_data = tvm.placeholder((n, ic, h, w), dtype=output.dtype) oc = num_filter kh, kw = kernel_size original_kernel = tvm.placeholder((oc, ic, kh, kw), dtype=output.dtype) wkl = _get_workload(original_data, original_kernel, stride, padding, output.dtype) sch = _get_schedule(wkl) _SCH_TO_SCH_FUNC[type(sch)](s, wkl, data, data_pad, data_vec, kernel, conv_out, output, outs[0])
def traverse(op): """Traverse operators from computation graph""" # inline all one-to-one-mapping operators except the last stage (output) if tag.is_broadcast(op.tag): if op not in s.outputs: s[op].compute_inline() for tensor in op.input_tensors: if tensor.op.input_tensors: traverse(tensor.op) if 'conv2d_nchw' in op.tag: output = op.output(0) conv_out = op.input_tensors[0] kernel = conv_out.op.input_tensors[1] # kernel = kernel_vec.op.input_tensors[0] data_vec = conv_out.op.input_tensors[0] data = data_vec.op.input_tensors[0] data_pad = None if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag: data_pad = data data = data_pad.op.input_tensors[0] padding = infer_pad(data, data_pad) _, ic, _, _ = [x.value for x in data.shape] oc = num_filter kh, kw = kernel_size original_kernel = tvm.placeholder((oc, ic, kh, kw)) if data_pad is None: stride = infer_stride(data, original_kernel, output) else: stride = infer_stride(data_pad, original_kernel, output) wkl = _get_workload(data, original_kernel, stride, padding, output.dtype) sch = _get_schedule(wkl) _SCH_TO_SCH_FUNC[type(sch)](s, data, data_pad, data_vec, kernel, conv_out, output, outs[0])
def _schedule_conv(s, data, data_pad, data_vec, kernel, kernel_pack, conv_out,
                   output, last):
    """AVX-512 common conv schedule: inline pad, parallelize the data-pack
    stage, cache-write the conv with register blocking on ow, and vectorize
    the unpacked output.

    Cleanups (behavior unchanged): removed a dead bare-string block that was a
    commented-out early return, a never-executed ``if False:`` kernel-pack
    section (and with it the now-unused ``A``/``W`` aliases), and corrected
    the mislabeled data_vec axis names (axes are n, C, h, w, c; the unused
    trailing names were previously swapped).
    """
    # print('Run in avx512_conv_common sch')
    # no stride and padding info here — recover them from tensor shapes.
    padding = infer_pad(data, data_pad)
    if data_pad is None:
        stride = infer_stride(data, kernel, output)
    else:
        stride = infer_stride(data_pad, kernel, output)
    wkl = _get_workload(data, kernel, stride, padding, output.dtype)
    sch = _get_schedule(wkl)

    HPAD, WPAD = wkl.hpad, wkl.wpad
    # NOTE(review): `and` skips the inline for asymmetric padding; kept as-is
    # to stay in sync with the matching declaration's condition.
    DOPAD = (HPAD != 0 and WPAD != 0)

    A0, A1 = data_pad, data_vec
    # schedule data
    if DOPAD:
        s[A0].compute_inline()
    batch, ic_chunk, ih, iw, ic_block = s[A1].op.axis
    parallel_axis = s[A1].fuse(ic_chunk, ih)
    s[A1].parallel(parallel_axis)
    s[A1].pragma(batch, "parallel_launch_point")
    s[A1].pragma(parallel_axis, "parallel_stride_pattern")
    s[A1].pragma(batch, "parallel_barrier_when_finish")

    # schedule conv
    C, O0, O = conv_out, output, last
    CC = s.cache_write(C, 'global')

    _, oc_chunk, oh, ow, oc_block = s[C].op.axis
    ow_chunk, ow_block = s[C].split(ow, factor=sch.ur_w)
    s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
    s[C].fuse(oc_chunk, oh)
    s[C].vectorize(oc_block)

    s[CC].compute_at(s[C], ow_chunk)
    _, oc_chunk, oh, ow, oc_block = s[CC].op.axis
    ic, kh, kw = s[CC].op.reduce_axis

    ow_chunk, ow_block = s[CC].split(ow, factor=sch.ur_w)
    ic_chunk, ic_block = s[CC].split(ic, factor=sch.ic_bn)

    # Two loop orders depending on whether the kernel-width loop is unrolled.
    if sch.unroll_kw:
        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kh, ic_block, kw,
                      ow_block, oc_block)
        s[CC].unroll(kw)
    else:
        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kh, kw, ic_block,
                      ow_block, oc_block)

    s[CC].fuse(oc_chunk, oh)
    s[CC].vectorize(oc_block)
    s[CC].unroll(ow_block)

    # If there is a fused op after the unpack, inline the unpack stage.
    if O0 != O:
        s[O0].compute_inline()
    batch, oc, oh, ow = s[O].op.axis

    ow_chunk, ow_block = s[O].split(ow, factor=sch.ur_w)
    oc_chunk, oc_block = s[O].split(oc, factor=sch.oc_bn)
    s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
    parallel_axis = s[O].fuse(oc_chunk, oh)
    s[C].compute_at(s[O], parallel_axis)
    s[O].vectorize(oc_block)
    s[O].parallel(parallel_axis)
    s[O].pragma(batch, "parallel_launch_point")
    s[O].pragma(parallel_axis, "parallel_stride_pattern")
    s[O].pragma(batch, "parallel_barrier_when_finish")
    return s
def _schedule_conv(s, data, data_pad, data_vec, kernel, kernel_pack, conv_out,
                   output, last):
    """AVX-512 1x1 conv schedule: inline pad, parallelize data-pack and
    kernel-pack stages, cache-write the conv, split on oh/ow factors."""
    # print('Run in avx512_conv_1x1 sch')
    # no stride and padding info here — recover them from tensor shapes.
    padding = infer_pad(data, data_pad)
    if data_pad is None:
        stride = infer_stride(data, kernel, output)
    else:
        stride = infer_stride(data_pad, kernel, output)
    wkl = _get_workload(data, kernel, stride, padding, output.dtype)
    sch = _get_schedule(wkl)

    A, W = data, kernel_pack
    A0, A1 = data_pad, data_vec
    # schedule data
    if A0 is not None:
        s[A0].compute_inline()
    # NOTE(review): data_vec axes are (n, C, h, w, c); the names `ic_block`
    # and `iw` appear swapped here, but only `ic_chunk`/`ih` are used below,
    # so the fuse is still over the channel-chunk and height axes.
    batch, ic_chunk, ih, ic_block, iw = s[A1].op.axis
    parallel_axis = s[A1].fuse(ic_chunk, ih)
    s[A1].parallel(parallel_axis)
    s[A1].pragma(batch, "parallel_launch_point")
    s[A1].pragma(parallel_axis, "parallel_stride_pattern")
    s[A1].pragma(batch, "parallel_barrier_when_finish")

    # schedule kernel pack
    oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[W].op.axis
    s[W].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block)
    if sch.oc_bn > 1:
        s[W].vectorize(oc_block)
    parallel_axis = s[W].fuse(oc_chunk, oh)
    s[W].parallel(parallel_axis)
    s[W].pragma(parallel_axis, "parallel_launch_point")
    s[W].pragma(parallel_axis, "parallel_stride_pattern")
    s[W].pragma(parallel_axis, "parallel_barrier_when_finish")

    C, O0, O = conv_out, output, last
    CC = s.cache_write(C, 'global')

    batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
    oh_outer, oh_inner = s[C].split(oh, factor=sch.oh_factor)
    s[C].vectorize(oc_block)

    s[CC].compute_at(s[C], oh_outer)
    _, oc_chunk, oh, ow, oc_block = s[CC].op.axis
    # 1x1 kernel: the only reduce axis is the input channel.
    ic, = s[CC].op.reduce_axis

    ic_chunk, ic_block = s[CC].split(ic, factor=sch.ic_bn)
    oh_outer, oh_inner = s[CC].split(oh, factor=sch.oh_factor)
    ow_outer, ow_inner = s[CC].split(ow, factor=sch.ow_factor)

    s[CC].reorder(oc_chunk, oh_outer, ow_outer, ic_chunk, ic_block, oh_inner,
                  ow_inner, oc_block)
    s[CC].vectorize(oc_block)

    s[CC].unroll(ow_inner)
    s[CC].unroll(oh_inner)

    # If there is a fused op after the unpack, inline the unpack stage.
    if O0 != O:
        s[O0].compute_inline()
    batch, oc, oh, ow = s[O].op.axis

    oc_chunk, oc_block = s[O].split(oc, factor=sch.oc_bn)
    oh_outer, oh_inner = s[O].split(oh, factor=sch.oh_factor)
    ow_outer, ow_inner = s[O].split(ow, factor=sch.ow_factor)
    s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)

    parallel_axis = s[O].fuse(oc_chunk, oh_outer)
    s[C].compute_at(s[O], parallel_axis)
    s[O].vectorize(oc_block)
    s[O].parallel(parallel_axis)
    s[O].pragma(batch, "parallel_launch_point")
    s[O].pragma(parallel_axis, "parallel_stride_pattern")
    s[O].pragma(batch, "parallel_barrier_when_finish")
    return s
def _declaration_conv(wkl, data, kernel):
    """Common conv2d declaration accepting either 5-D NCHWc or 4-D NCHW input.

    Packs the input to blocks of ``sch.ic_bn`` if needed, convolves in the
    packed layout, and chooses the output layout from ``sch.layout_out``:
    no digits in the layout string -> unpack to flat NCHW; digits equal to
    ``sch.oc_bn`` -> return packed output directly; other digits -> re-pack
    the output to that block size.
    """
    sch = _get_schedule(wkl)
    HPAD, WPAD = wkl.hpad, wkl.wpad
    HSTR, WSTR = wkl.hstride, wkl.wstride

    ndim_input = len(data.shape)
    if ndim_input == 5:
        batch_size, in_channel_chunk, in_height, in_width, in_channel_block = get_const_tuple(
            data.shape)
        in_channel = in_channel_block * in_channel_chunk
    else:
        assert ndim_input == 4
        # Sentinel: 0 never equals sch.ic_bn, so 4-D input always re-packs.
        in_channel_block = 0
        batch_size, in_channel, in_height, in_width = get_const_tuple(
            data.shape)

    # pre-packed kernel: (oc_chunk, ic_chunk, kh, kw, ic_block, oc_block)
    num_filter, _, kernel_height, kernel_width, _, co = get_const_tuple(
        kernel.shape)
    num_filter *= co

    pad_height = in_height + 2 * HPAD
    pad_width = in_width + 2 * WPAD
    out_height = (in_height + 2 * HPAD - kernel_height) // HSTR + 1
    out_width = (in_width + 2 * WPAD - kernel_width) // WSTR + 1

    # pack data
    # NOTE(review): `and` skips padding when only one of HPAD/WPAD is
    # non-zero — likely intended to be `or`; confirm before changing, the
    # schedule side tests the same condition.
    DOPAD = (HPAD != 0 and WPAD != 0)
    if DOPAD:
        if ndim_input == 5:
            data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad")
        else:
            assert ndim_input == 4
            data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
    else:
        data_pad = data

    # Re-pack only when the incoming block size does not match the schedule.
    if in_channel_block != sch.ic_bn:
        print('WARNING!!! (common) in_channel_block=%d vs sch.ic_bn=%d' %
              (in_channel_block, sch.ic_bn))
        shape = (batch_size, in_channel // sch.ic_bn, pad_height, pad_width,
                 sch.ic_bn)
        if ndim_input == 5:
            data_vec = tvm.compute(
                shape,
                lambda n, C, h, w, c: data_pad[
                    n, (C * sch.ic_bn + c) // in_channel_block, h, w,
                    (C * sch.ic_bn + c) % in_channel_block],
                name='data_vec',
                tag="conv2d_data_pack")
        else:
            assert ndim_input == 4
            data_vec = tvm.compute(
                shape,
                lambda n, C, h, w, c: data_pad[n, (C * sch.ic_bn + c), h, w],
                name='data_vec',
                tag="conv2d_data_pack")
        # data_pad = data_vec
    else:
        data_vec = data_pad

    kernel_vec = kernel

    # convolution
    oshape = (batch_size, num_filter // sch.oc_bn, out_height, out_width,
              sch.oc_bn)

    ic = tvm.reduce_axis((0, in_channel), name='ic')
    kh = tvm.reduce_axis((0, kernel_height), name='kh')
    kw = tvm.reduce_axis((0, kernel_width), name='kw')

    # Output block size is encoded as digits in the layout string
    # (e.g. "NCHW16c" -> 16; plain "NCHW" -> no digits -> flat unpack).
    import re
    unpack_channel_block = re.findall(r'\d+', sch.layout_out)
    if len(unpack_channel_block) == 0:
        conv = tvm.compute(
            oshape,
            lambda n, oc_chunk, oh, ow, oc_block: tvm.sum(
                data_vec[n, ic // sch.ic_bn, oh * HSTR + kh, ow * WSTR + kw,
                         ic % sch.ic_bn] *
                kernel_vec[oc_chunk, ic // sch.ic_bn, kh, kw,
                           ic % sch.ic_bn, oc_block],
                axis=[ic, kh, kw]),
            name='conv2d')  # , tag="conv2d_nChwc")
        unpack_shape = (batch_size, num_filter, out_height, out_width)
        unpack = tvm.compute(
            unpack_shape,
            lambda n, c, h, w: conv[n, c // sch.oc_bn, h, w, c % sch.oc_bn],
            name='output_unpack',
            tag='conv2d_nChwc_unpack')
    else:
        assert len(unpack_channel_block) == 1
        unpack_channel_block = int(unpack_channel_block[0])
        if unpack_channel_block == sch.oc_bn:
            # Output block matches the compute block: no unpack stage needed.
            return tvm.compute(
                oshape,
                lambda n, oc_chunk, oh, ow, oc_block: tvm.sum(
                    data_vec[n, ic // sch.ic_bn, oh * HSTR + kh,
                             ow * WSTR + kw, ic % sch.ic_bn] *
                    kernel_vec[oc_chunk, ic // sch.ic_bn, kh, kw,
                               ic % sch.ic_bn, oc_block],
                    axis=[ic, kh, kw]),
                name='conv2d',
                tag="conv2d_nChwc")
        else:
            conv = tvm.compute(
                oshape,
                lambda n, oc_chunk, oh, ow, oc_block: tvm.sum(
                    data_vec[n, ic // sch.ic_bn, oh * HSTR + kh,
                             ow * WSTR + kw, ic % sch.ic_bn] *
                    kernel_vec[oc_chunk, ic // sch.ic_bn, kh, kw,
                               ic % sch.ic_bn, oc_block],
                    axis=[ic, kh, kw]),
                name='conv2d')
            # Re-pack the output to the requested channel block size.
            unpack_shape = (batch_size, num_filter // unpack_channel_block,
                            out_height, out_width, unpack_channel_block)
            unpack = tvm.compute(
                unpack_shape,
                lambda n, C, h, w, c: conv[
                    n, (C * unpack_channel_block + c) // sch.oc_bn, h, w,
                    (C * unpack_channel_block + c) % sch.oc_bn],
                name='output_unpack',
                tag='conv2d_nChwc_unpack')
    return unpack
def _schedule_conv(s, wkl, data, data_pad, data_vec, kernel, conv_out, output,
                   last):
    """Common NCHWc conv schedule handling both 5-D (packed) and 4-D (flat)
    final outputs, with register blocking on ow."""
    sch = _get_schedule(wkl)

    HPAD, WPAD = wkl.hpad, wkl.wpad
    # NOTE(review): `and` mirrors the declaration's padding condition; both
    # would need to change together to support asymmetric padding.
    DOPAD = (HPAD != 0 and WPAD != 0)

    # A, W = data, kernel_vec
    A0, A1 = data_pad, data_vec
    # schedule data
    # Inline the pad stage only when the data was re-packed afterwards.
    if DOPAD and "conv2d_data_pack" in s[A1].op.tag:
        s[A0].compute_inline()
    if isinstance(
            s[A1].op,
            tvm.tensor.ComputeOp):  #and "conv2d_data_pack" in s[A1].op.tag:
        batch, ic_chunk, ih, iw, ic_block = s[A1].op.axis
        parallel_axis = s[A1].fuse(ic_chunk, ih)
        s[A1].parallel(parallel_axis)

    # schedule conv
    C, O0, O = conv_out, output, last
    CC = s.cache_write(C, 'global')

    _, oc_chunk, oh, ow, oc_block = s[C].op.axis
    ow_chunk, ow_block = s[C].split(ow, factor=sch.reg_n)
    s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
    parallel_axis = s[C].fuse(oc_chunk, oh)
    s[C].vectorize(oc_block)
    # Only parallelize here when conv is the final stage.
    if C == O:
        s[C].parallel(parallel_axis)

    s[CC].compute_at(s[C], ow_chunk)
    _, oc_chunk, oh, ow, oc_block = s[CC].op.axis
    ic, kh, kw = s[CC].op.reduce_axis

    ow_chunk, ow_block = s[CC].split(ow, factor=sch.reg_n)
    ic_chunk, ic_block = s[CC].split(ic, factor=sch.ic_bn)

    # Two loop orders depending on whether the kernel-width loop is unrolled.
    if sch.unroll_kw:
        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kh, ic_block, kw,
                      ow_block, oc_block)
        s[CC].unroll(kw)
    else:
        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kh, kw, ic_block,
                      ow_block, oc_block)

    s[CC].vectorize(oc_block)
    s[CC].unroll(ow_block)

    if O0 != O:
        s[O0].compute_inline()

    # Final stage layout differs: 5 axes = still packed NCHWc, 4 = flat NCHW.
    if C != O:
        if len(s[O].op.axis) == 5:
            batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
            ow_chunk, ow_block = s[O].split(ow, factor=sch.reg_n)
            s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
            parallel_axis = s[O].fuse(oc_chunk, oh)
            s[C].compute_at(s[O], parallel_axis)
            # The output block may differ from sch.oc_bn; vectorize the inner
            # sch.oc_bn-sized piece of it.
            _, oc_block = s[O].split(oc_block, factor=sch.oc_bn)
            s[O].vectorize(oc_block)
            s[O].parallel(parallel_axis)
        else:
            assert len(s[O].op.axis) == 4
            batch, oc, oh, ow = s[O].op.axis
            ow_chunk, ow_block = s[O].split(ow, factor=sch.reg_n)
            oc_chunk, oc_block = s[O].split(oc, factor=sch.oc_bn)
            s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
            parallel_axis = s[O].fuse(oc_chunk, oh)
            s[C].compute_at(s[O], parallel_axis)
            s[O].vectorize(oc_block)
            s[O].parallel(parallel_axis)
    return s
def _schedule_conv(s, data, data_pad, data_vec, kernel, conv_out, output,
                   last):
    """Prepack common conv schedule: inline pad, parallelize the data-pack
    stage, cache-write the conv with register blocking on ow, and vectorize
    the flat NCHW output.

    BUGFIX: removed a leftover debug print that ran on every schedule call.
    """
    # no stride and padding info here — recover them from tensor shapes.
    padding = infer_pad(data, data_pad)
    if data_pad is None:
        stride = infer_stride(data, kernel, output)
    else:
        stride = infer_stride(data_pad, kernel, output)
    wkl = get_workload(data, kernel, stride, padding, output.dtype)
    sch = _get_schedule(wkl)

    HPAD, WPAD = wkl.hpad, wkl.wpad
    # NOTE(review): `and` mirrors the matching declaration's padding
    # condition; kept as-is so the two stay in sync.
    DOPAD = (HPAD != 0 and WPAD != 0)

    # A, W = data, kernel_vec
    A0, A1 = data_pad, data_vec
    # schedule data
    if DOPAD:
        s[A0].compute_inline()
    batch, ic_chunk, ih, ic_block, iw = s[A1].op.axis
    parallel_axis = s[A1].fuse(ic_chunk, ih)
    s[A1].parallel(parallel_axis)

    # schedule conv
    C, O0, O = conv_out, output, last
    CC = s.cache_write(C, 'global')

    _, oc_chunk, oh, ow, oc_block = s[C].op.axis
    ow_chunk, ow_block = s[C].split(ow, factor=sch.reg_n)
    s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
    s[C].fuse(oc_chunk, oh)
    s[C].vectorize(oc_block)

    s[CC].compute_at(s[C], ow_chunk)
    _, oc_chunk, oh, ow, oc_block = s[CC].op.axis
    ic, kh, kw = s[CC].op.reduce_axis

    ow_chunk, ow_block = s[CC].split(ow, factor=sch.reg_n)
    ic_chunk, ic_block = s[CC].split(ic, factor=sch.ic_bn)

    # Two loop orders depending on whether the kernel-width loop is unrolled.
    if sch.unroll_kw:
        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kh, ic_block, kw,
                      ow_block, oc_block)
        s[CC].unroll(kw)
    else:
        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kh, kw, ic_block,
                      ow_block, oc_block)

    s[CC].fuse(oc_chunk, oh)
    s[CC].vectorize(oc_block)
    s[CC].unroll(ow_block)

    # If there is a fused op after the unpack, inline the unpack stage.
    if O0 != O:
        s[O0].compute_inline()
    batch, oc, oh, ow = s[O].op.axis

    ow_chunk, ow_block = s[O].split(ow, factor=sch.reg_n)
    oc_chunk, oc_block = s[O].split(oc, factor=sch.oc_bn)
    s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
    parallel_axis = s[O].fuse(oc_chunk, oh)
    s[C].compute_at(s[O], parallel_axis)
    s[O].vectorize(oc_block)
    s[O].parallel(parallel_axis)
    return s
def _schedule_spatial_conv2d(s, data, data_pad, data_vec, kernel, kernel_vec,
                             conv_out, output, last):
    """Schedule the spatial-pack conv2d (vh/vw/vc micro-tiled layout).

    Tiles the final output by (VH, VW, VC), vectorizes the channel
    micro-tile, and parallelizes the data-packing, kernel-packing and
    output stages with TVM parallel pragmas.

    Parameters
    ----------
    s : tvm.Schedule
        Schedule being mutated in place.
    data, data_pad, data_vec : tvm.Tensor
        Input, padded input and vectorized (packed) input.
    kernel, kernel_vec : tvm.Tensor
        Original and packed kernels.
    conv_out, output, last : tvm.Tensor
        Tiled conv result, unpacked output, and the final (possibly fused)
        stage.

    Returns
    -------
    s : tvm.Schedule
        The same schedule object, for chaining.
    """
    # Stride/padding are not passed in; re-infer them to look up the
    # tuned schedule parameters for this workload.
    padding = infer_pad(data, data_pad)
    if data_pad is None:
        stride = infer_stride(data, kernel, output)
    else:
        stride = infer_stride(data_pad, kernel, output)
    wkl = _get_workload(data, kernel, stride, padding, output.dtype)
    sch = _get_schedule(wkl)

    HPAD, WPAD = wkl.hpad, wkl.wpad
    # NOTE(review): pad stage only exists when BOTH pads are non-zero —
    # TODO confirm asymmetric padding is intentionally excluded.
    DOPAD = (HPAD != 0 and WPAD != 0)

    # Micro-tile sizes from the tuned schedule.
    VH = sch.vh
    VW = sch.vw
    VC = sch.vc
    UNROLL = sch.unroll

    C = last
    A0, A1 = data_pad, data_vec
    B0 = kernel_vec
    C0, C1 = conv_out, output

    # Accumulate the reduction in a global cache-write stage.
    CC = s.cache_write(C0, "global")
    _, co, oh, ow, vh, vw, vc = s[C0].op.axis
    if UNROLL:
        s[C0].unroll(vw)
    s[C0].vectorize(vc)

    s[CC].compute_at(s[C0], ow)
    _, co, oh, ow, vh, vw, vc = s[CC].op.axis
    ci, dh, dw = s[CC].op.reduce_axis
    s[CC].reorder(ci, dh, vh, dw, vw, vc)
    if UNROLL:
        s[CC].unroll(vw)
    s[CC].vectorize(vc)

    ##### Schedule A: data packing — inline the pad, parallelize over h
    # (optionally split by sch.ba to control the parallel grain).
    if DOPAD:
        s[A0].compute_inline()
    _, h, _, _, _, _ = s[A1].op.axis
    if sch.ba == 1:
        oaxis = h
        paxis = h
    else:
        oh, ih = s[A1].split(h, sch.ba)
        oaxis = oh
        paxis = ih
    s[A1].parallel(paxis)
    s[A1].pragma(oaxis, "parallel_launch_point")
    s[A1].pragma(paxis, "parallel_stride_pattern")
    s[A1].pragma(oaxis, "parallel_barrier_when_finish")

    ##### Schedule B: kernel packing — parallelize over output channels
    # (optionally split by sch.bc).
    co, _, _, _, _ = s[B0].op.axis
    if sch.bc == 1:
        oaxis = co
        paxis = co
    else:
        oco, ico = s[B0].split(co, sch.bc)
        oaxis = oco
        paxis = ico
    s[B0].parallel(paxis)
    s[B0].pragma(oaxis, "parallel_launch_point")
    s[B0].pragma(paxis, "parallel_stride_pattern")
    s[B0].pragma(oaxis, "parallel_barrier_when_finish")

    ##### Schedule C: final output — tile to (VH, VW, VC) and attach the
    # tiled conv result; inline the unpack stage when fused.
    n, co, h, w = s[C].op.axis
    co, vc = s[C].split(co, VC)
    oh, ow, vh, vw = s[C].tile(h, w, VH, VW)
    s[C].reorder(n, co, oh, ow, vh, vw, vc)
    if C != C1:
        s[C1].compute_inline()
    s[C0].compute_at(s[C], ow)

    if sch.bc == 1:
        oaxis = co
        paxis = co
    else:
        oco, ico = s[C].split(co, sch.bc)
        oaxis = oco
        paxis = ico
    s[C].parallel(paxis)
    s[C].pragma(oaxis, "parallel_launch_point")
    s[C].pragma(paxis, "parallel_stride_pattern")
    s[C].pragma(oaxis, "parallel_barrier_when_finish")
    return s
def _declaration_conv(data, kernel, stride, padding, layout, out_dtype):
    """Declare a direct NCHW conv2d that packs data to a blocked layout,
    convolves against a pre-packed kernel, and unpacks back to NCHW.

    Parameters
    ----------
    data : tvm.Tensor
        4-D input in NCHW layout; batch size must be 1.
    kernel : tvm.Tensor
        Either a 4-D OIHW kernel or a 6-D pre-packed kernel whose last
        two axes are the (ic, oc) blocks.
    stride, padding : int or tuple
        Convolution stride and padding, forwarded to the workload lookup.
    layout : str
        Must be 'NCHW'.
    out_dtype : str
        Output accumulation dtype.

    Returns
    -------
    unpack : tvm.Tensor
        4-D NCHW convolution output, tagged 'conv2d_nchw'.
    """
    # NOTE(review): assert messages say "on rasp" though this path appears
    # to target AVX — message likely copied from the rasp backend; confirm.
    assert layout == 'NCHW', "only support NCHW convolution on rasp"
    assert data.shape[
        0].value == 1, "only support batch size=1 convolution on rasp"
    wkl = _get_workload(data, kernel, stride, padding, out_dtype)
    sch = _get_schedule(wkl)

    HPAD, WPAD = wkl.hpad, wkl.wpad
    HSTR, WSTR = wkl.hstride, wkl.wstride

    batch_size, in_channel, in_height, in_width = get_const_tuple(data.shape)
    if len(kernel.shape) == 4:
        # Plain OIHW kernel.
        num_filter, _, kernel_height, kernel_width = get_const_tuple(
            kernel.shape)
    else:
        # Pre-packed 6-D kernel: fold the oc block back into num_filter.
        num_filter, _, kernel_height, kernel_width, _, oc = get_const_tuple(
            kernel.shape)
        num_filter *= oc

    pad_height = in_height + 2 * HPAD
    pad_width = in_width + 2 * WPAD
    out_height = (in_height + 2 * HPAD - kernel_height) // HSTR + 1
    out_width = (in_width + 2 * WPAD - kernel_width) // WSTR + 1

    # Pad once up front so the conv lambda needs no boundary select.
    # NOTE(review): pads only when BOTH pads are non-zero — asymmetric
    # padding falls through unpadded; TODO confirm intended.
    DOPAD = (HPAD != 0 and WPAD != 0)
    if DOPAD:
        data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
    else:
        data_pad = data

    # Pack data to [n, ic_chunk, h, ic_block, w].
    shape = (batch_size, in_channel // sch.ic_bn, pad_height, sch.ic_bn,
             pad_width)
    data_vec = tvm.compute(
        shape,
        lambda n, C, h, c, w: data_pad[n, C * sch.ic_bn + c, h, w],
        name='data_vec')

    # Kernel is expected already packed as [oc_chunk, ic_chunk, kh, kw,
    # ic_block, oc_block]; no repack stage is emitted here.
    kernel_pack = kernel

    # Blocked conv output and its NCHW unpacked shape.
    oshape = (batch_size, num_filter // sch.oc_bn, out_height, out_width,
              sch.oc_bn)
    unpack_shape = (batch_size, num_filter, out_height, out_width)

    ic = tvm.reduce_axis((0, in_channel), name='ic')
    kh = tvm.reduce_axis((0, kernel_height), name='kh')
    kw = tvm.reduce_axis((0, kernel_width), name='kw')

    conv = tvm.compute(
        oshape,
        lambda n, oc_chunk, oh, ow, oc_block: tvm.sum(data_vec[
            n, ic // sch.ic_bn, oh * HSTR + kh, ic % sch.ic_bn, ow * WSTR + kw
        ].astype(out_dtype) * kernel_pack[oc_chunk, ic // sch.ic_bn, kh, kw,
                                          ic % sch.ic_bn, oc_block],
                                                      axis=[ic, kh, kw]),
        name='conv')

    # Unpack the blocked result back to plain NCHW.
    unpack = tvm.compute(
        unpack_shape,
        lambda n, c, h, w: conv[n, c // sch.oc_bn, h, w, c % sch.oc_bn],
        name='output_unpack',
        tag='conv2d_nchw')
    return unpack
def _declaration_conv(data, kernel, stride, padding, layout, out_dtype):
    """Declare a conv2d over NCHWc-blocked input, producing NCHWc output.

    Parameters: 5-D blocked `data` (n, ic_chunk, h, w, ic_block), 6-D
    pre-packed `kernel` (oc_chunk, ic_chunk, kh, kw, ic_block, oc_block),
    `stride`/`padding` forwarded to the workload lookup, `layout` which
    must be 'NCHWc', and `out_dtype` (recorded in the workload only).
    Returns the 5-D blocked conv tensor tagged "conv2d_nChwc".
    """
    assert layout == 'NCHWc', "only support NCHWc convolution for AVX"
    wkl = get_workload(data, kernel, stride, padding, out_dtype)
    sch = _get_schedule(wkl)

    HPAD, WPAD = wkl.hpad, wkl.wpad
    HSTR, WSTR = wkl.hstride, wkl.wstride

    batch_size, in_channel_chunk, in_height, in_width, in_channel_block = get_const_tuple(
        data.shape)
    # Fold the kernel's oc block back into the total filter count.
    num_filter, _, kernel_height, kernel_width, _, co = get_const_tuple(
        kernel.shape)
    num_filter *= co

    pad_height = in_height + 2 * HPAD
    pad_width = in_width + 2 * WPAD
    out_height = (in_height + 2 * HPAD - kernel_height) // HSTR + 1
    out_width = (in_width + 2 * WPAD - kernel_width) // WSTR + 1

    # pack data
    # NOTE(review): pads only when BOTH pads are non-zero; asymmetric
    # padding (one of HPAD/WPAD zero) is skipped — TODO confirm intended.
    DOPAD = (HPAD != 0 and WPAD != 0)
    if DOPAD:
        data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad")
    else:
        data_pad = data

    in_channel = in_channel_block * in_channel_chunk
    if in_channel_block != sch.ic_bn:
        # Incoming block size differs from the tuned ic_bn: re-block the
        # data on the fly (extra pass, hence the loud warning).
        print('WARNING!!! (common) in_channel_block=%d vs sch.ic_bn=%d' %
              (in_channel_block, sch.ic_bn))
        shape = (batch_size, in_channel // sch.ic_bn, pad_height, pad_width,
                 sch.ic_bn)
        data_vec = tvm.compute(
            shape,
            lambda n, C, h, w, c: data_pad[
                n, (C * sch.ic_bn + c) // in_channel_block, h, w,
                (C * sch.ic_bn + c) % in_channel_block],
            name='data_vec',
            tag="conv2d_data_pack")
    else:
        # Blocking already matches the schedule: use the (padded) data as-is.
        data_vec = data_pad

    # Kernel arrives pre-packed; no repack stage.
    kernel_vec = kernel

    # convolution
    oshape = (batch_size, num_filter // sch.oc_bn, out_height, out_width,
              sch.oc_bn)

    ic = tvm.reduce_axis((0, in_channel), name='ic')
    kh = tvm.reduce_axis((0, kernel_height), name='kh')
    kw = tvm.reduce_axis((0, kernel_width), name='kw')

    conv = tvm.compute(
        oshape,
        lambda n, oc_chunk, oh, ow, oc_block: tvm.sum(
            data_vec[n, ic // sch.ic_bn, oh * HSTR + kh, ow * WSTR + kw,
                     ic % sch.ic_bn] *
            kernel_vec[oc_chunk, ic // sch.ic_bn, kh, kw, ic % sch.ic_bn,
                       oc_block],
            axis=[ic, kh, kw]),
        name='conv2d_nChwc',
        tag="conv2d_nChwc")

    return conv
def _schedule_conv(s, data, data_pad, data_vec, kernel, conv_out, output, last):
    """Schedule the pre-packed 1x1 convolution (single `ic` reduction).

    Splits output height/width by ``sch.oh_factor``/``sch.ow_factor``,
    vectorizes the oc block, unrolls the inner spatial tile, and
    parallelizes the fused (oc_chunk, oh_outer) axis of the final stage.

    Parameters
    ----------
    s : tvm.Schedule
        Schedule being mutated in place.
    data, data_pad, data_vec : tvm.Tensor
        Input, padded input, and (possibly re-blocked) input stage.
    kernel : tvm.Tensor
        Pre-packed 1x1 kernel.
    conv_out, output, last : tvm.Tensor
        Blocked conv result, output, and the final (possibly fused) stage.

    Returns
    -------
    s : tvm.Schedule
        The same schedule object, for chaining.
    """
    # Re-infer stride/padding from shapes to look up the tuned schedule.
    padding = infer_pad(data, data_pad)
    if data_pad is None:
        stride = infer_stride(data, kernel, output)
    else:
        stride = infer_stride(data_pad, kernel, output)
    wkl = get_workload(data, kernel, stride, padding, output.dtype)
    sch = _get_schedule(wkl)

    HPAD, WPAD = wkl.hpad, wkl.wpad
    DOPAD = (HPAD != 0 and WPAD != 0)

    A0, A1 = data_pad, data_vec
    # Schedule the data stage only when it is a real repack compute (the
    # declaration may pass the input straight through, in which case A1
    # has no schedulable axes here).
    if isinstance(s[A1].op, tvm.tensor.ComputeOp):
        if DOPAD and "conv2d_data_pack" in s[A1].op.tag:
            s[A0].compute_inline()
        batch, ic_chunk, ih, iw, ic_block = s[A1].op.axis
        parallel_axis = s[A1].fuse(ic_chunk, ih)
        s[A1].parallel(parallel_axis)

    C, O0, O = conv_out, output, last
    CC = s.cache_write(C, 'global')

    batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
    oh_outer, oh_inner = s[C].split(oh, factor=sch.oh_factor)
    s[C].vectorize(oc_block)

    # Accumulate in the cached stage; reduction is a single channel axis
    # (1x1 kernel), split by the tuned ic block.
    s[CC].compute_at(s[C], oh_outer)
    _, oc_chunk, oh, ow, oc_block = s[CC].op.axis
    ic, = s[CC].op.reduce_axis

    ic_chunk, ic_block = s[CC].split(ic, factor=sch.ic_bn)

    oh_outer, oh_inner = s[CC].split(oh, factor=sch.oh_factor)
    ow_outer, ow_inner = s[CC].split(ow, factor=sch.ow_factor)

    s[CC].reorder(oc_chunk, oh_outer, ow_outer, ic_chunk, ic_block, oh_inner,
                  ow_inner, oc_block)
    s[CC].vectorize(oc_block)

    s[CC].unroll(ow_inner)
    s[CC].unroll(oh_inner)

    # Schedule the final (possibly fused) output stage.
    if O0 != O:
        s[O0].compute_inline()
    batch, oc_chunk, oh, ow, oc_block = s[O].op.axis

    oh_outer, oh_inner = s[O].split(oh, factor=sch.oh_factor)
    ow_outer, ow_inner = s[O].split(ow, factor=sch.ow_factor)
    s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)

    parallel_axis = s[O].fuse(oc_chunk, oh_outer)
    s[C].compute_at(s[O], parallel_axis)
    s[O].vectorize(oc_block)
    s[O].parallel(parallel_axis)
    return s