def max_pool(data, kernel, stride, padding=[[0, 0], [0, 0]], name="max_pool"):
    """Max-pool a 4-D tensor indexed as (batch, channel, height, width).

    Parameters
    ----------
    data : tensor
        4-D input; ``data.shape`` is unpacked as
        (batch, channel, height, width).
    kernel : two-element sequence
        Pooling window as (kernel_height, kernel_width).
    stride : two-element sequence
        Window step as (stride_height, stride_width).
    padding : [[pad_top, pad_left], [pad_down, pad_right]]
        Padding applied before/after the two spatial axes. The default list
        is only read here, never mutated.
    name : str
        Name forwarded to ``hcl.compute``.

    Returns
    -------
    The ``hcl.compute`` result with shape
    (batch, channel, out_height, out_width).
    """
    assert len(data.shape) == 4, "only support 4-dim pooling"
    assert len(stride) == 2, "only support 2-dim stride"
    kernel_height, kernel_width = kernel
    stride_height, stride_width = stride
    batch, channel, height, width = data.shape
    [pad_top, pad_left], [pad_down, pad_right] = padding

    # Batch and channel axes are never padded.
    pad_before = [0, 0, pad_top, pad_left]
    pad_after = [0, 0, pad_down, pad_right]
    if padding != [[0, 0], [0, 0]]:
        # NOTE(review): the fill value is hard-coded to the float32 minimum
        # regardless of data.dtype -- confirm this is intended for
        # non-float32 inputs.
        data = pad(data, pad_before, pad_after,
                   pad_value=tvm.min_value("float32"))

    out_height = simplify(
        (height - kernel_height + pad_top + pad_down) // stride_height + 1)
    out_width = simplify(
        (width - kernel_width + pad_left + pad_right) // stride_width + 1)

    # Reduction axes spanning one pooling window.
    dheight = hcl.reduce_axis(0, kernel_height)
    dwidth = hcl.reduce_axis(0, kernel_width)

    # NOTE(review): `max` is not defined in this function; it is called with
    # an `axis=` keyword, so it is presumably a reducer defined elsewhere in
    # this module rather than the Python builtin -- confirm.
    return hcl.compute(
        (batch, channel, out_height, out_width),
        lambda i, c, h, w: max(
            data[i, c,
                 h * stride_height + dheight,
                 w * stride_width + dwidth],
            axis=[dheight, dwidth]),
        name=name,
        attrs=OrderedDict([
            ('out_img_w', out_width),
            ('out_img_h', out_height),
            ('in_num', channel),
            # The *_h entries carry the height components (index 0) and the
            # *_w entries the width components (index 1), matching the
            # unpacking of `kernel` and `stride` above.
            ('kernel_h', kernel_height),
            ('kernel_w', kernel_width),
            ('stride_h', stride_height),
            ('stride_w', stride_width),
            ('app_name', tvm.make.StringImm('max_pool'))]))
def max_pool2d_nchw(data, pooling, stride, padding, name='max_pool2d'):
    """Max-pool a 4-D tensor indexed as (batch, channel, height, width).

    Parameters
    ----------
    data : tensor
        4-D input; ``data.shape`` is unpacked as
        (batch, channel, height, width).
    pooling : two-element sequence
        Pooling window as (pooling_h, pooling_w).
    stride : two-element sequence
        Window step as (stride_h, stride_w).
    padding
        Padding specification handed to ``get_pad_tuple`` together with the
        window size; that helper returns
        (pad_top, pad_left, pad_bottom, pad_right).
    name : str
        Name forwarded to ``hcl.compute``.

    Returns
    -------
    The ``hcl.compute`` result with shape
    (batch, channel, out_height, out_width).
    """
    assert len(data.shape) == 4, "only support 4-dim pooling"
    assert len(stride) == 2, "only support 2-dim stride"
    pooling_h, pooling_w = pooling
    stride_h, stride_w = stride
    batch, channel, height, width = data.shape
    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(
        padding, (pooling_h, pooling_w))

    # Batch and channel axes are never padded.
    pad_before = [0, 0, pad_top, pad_left]
    pad_after = [0, 0, pad_bottom, pad_right]
    if padding != [0, 0]:
        # NOTE(review): the fill value is hard-coded to the float32 minimum
        # regardless of data.dtype -- confirm this is intended for
        # non-float32 inputs.
        data = pad(data, pad_before, pad_after,
                   pad_value=tvm.min_value("float32"))

    out_height = simplify(
        (height - pooling_h + pad_top + pad_bottom) // stride_h + 1)
    out_width = simplify(
        (width - pooling_w + pad_left + pad_right) // stride_w + 1)

    # Reduction axes spanning one pooling window.
    dheight = hcl.reduce_axis(0, pooling_h)
    dwidth = hcl.reduce_axis(0, pooling_w)

    # NOTE(review): `max` is not defined in this function; it is called with
    # an `axis=` keyword, so it is presumably a reducer defined elsewhere in
    # this module rather than the Python builtin -- confirm.
    return hcl.compute(
        (batch, channel, out_height, out_width),
        lambda i, c, h, w: max(
            data[i, c, h * stride_h + dheight, w * stride_w + dwidth],
            axis=[dheight, dwidth]),
        name=name,
        attrs=OrderedDict([
            ('out_img_w', out_width),
            ('out_img_h', out_height),
            ('in_num', channel),
            # The *_h entries carry the height components (index 0) and the
            # *_w entries the width components (index 1), matching the
            # unpacking of `pooling` and `stride` above.
            ('kernel_h', pooling_h),
            ('kernel_w', pooling_w),
            ('stride_h', stride_h),
            ('stride_w', stride_w),
            ('app_name', tvm.make.StringImm('max_pool'))]))
def max_pool2d_nhwc(data, pooling, stride=[1, 1], padding=[0, 0],
                    name='max_pool2d'):
    """Max-pool a 4-D tensor indexed as (batch, height, width, channel).

    Parameters
    ----------
    data : tensor
        4-D input; ``data.shape`` is unpacked as
        (batch, height, width, channel).
    pooling : two-element sequence
        Pooling window as (pooling_h, pooling_w).
    stride : two-element sequence
        Window step as (stride_h, stride_w).
    padding
        Either four values read as
        (pad_top, pad_left, pad_bottom, pad_right), or any other
        specification, which is handed to ``get_pad_tuple`` together with
        the window size.
    name : str
        Name forwarded to ``hcl.compute``.

    Returns
    -------
    The ``hcl.compute`` result with shape
    (batch, out_height, out_width, channel).
    """
    assert len(data.shape) == 4, "only support 4-dim pooling"
    assert len(stride) == 2, "only support 2-dim stride"

    # Reducer seeded with the minimum of the input dtype. Kept under a
    # local name that does not shadow the builtin `max`.
    max_reducer = hcl.reducer(
        tvm.min_value(data.dtype),
        lambda x, y: tvm.make.Max(x, y),
        data.dtype)

    pooling_h, pooling_w = pooling
    stride_h, stride_w = stride
    batch, height, width, channel = data.shape

    if len(padding) == 4:
        pad_top = padding[0]
        pad_left = padding[1]
        pad_bottom = padding[2]
        pad_right = padding[3]
    else:
        pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(
            padding, (pooling_h, pooling_w))

    # Batch and channel axes are never padded. Unlike a conditional pad,
    # this call is made unconditionally, even when all pads are zero.
    pad_before = [0, pad_top, pad_left, 0]
    pad_after = [0, pad_bottom, pad_right, 0]
    data = pad(data, pad_before, pad_after,
               pad_value=tvm.min_value(data.dtype))

    out_height = simplify(
        (height - pooling_h + pad_top + pad_bottom) // stride_h + 1)
    out_width = simplify(
        (width - pooling_w + pad_left + pad_right) // stride_w + 1)

    # Reduction axes spanning one pooling window.
    dheight = hcl.reduce_axis(0, pooling_h)
    dwidth = hcl.reduce_axis(0, pooling_w)

    return hcl.compute(
        (batch, out_height, out_width, channel),
        lambda i, h, w, c: max_reducer(
            data[i, h * stride_h + dheight, w * stride_w + dwidth, c],
            axis=[dheight, dwidth]),
        name=name,
        attrs=OrderedDict([
            ('out_img_w', out_width),
            ('out_img_h', out_height),
            ('in_num', channel),
            # The *_h entries carry the height components (index 0) and the
            # *_w entries the width components (index 1), matching the
            # unpacking of `pooling` and `stride` above.
            ('kernel_h', pooling_h),
            ('kernel_w', pooling_w),
            ('stride_h', stride_h),
            ('stride_w', stride_w),
            ('app_name', tvm.make.StringImm('max_pool'))]))