def argminmax_tiling_strategy(out_shape, axis):
    """Custom tiling strategy for argminmax op.

    The reduced axis is kept whole (``"FULL"`` with ``TileConstraint.MAX``)
    and every other non-unit axis is tiled with factor 1.

    Args:
        out_shape: sequence of dimension sizes; ``axis`` indexes the reduced
            dimension. The caller's sequence is not modified.
        axis: index of the reduced dimension; negative values count from
            the end.

    Returns:
        list of tiling constraints built with
        ``ct_util.create_constraint_on_axis``; empty when the reduced
        dimension has size 1.
    """
    strategy = list()
    # Work on a list copy so tuples are accepted (the axis-0 transpose below
    # needs list concatenation) and the caller's sequence is never touched.
    out_shape = list(out_shape)
    # Normalize negative axes; otherwise ``i == axis`` below never matches
    # and the reduce axis would silently get no MAX constraint.
    if axis < 0:
        axis += len(out_shape)
    # when reduce axis is one, we do not need any strategy
    if out_shape[axis] == 1:
        return strategy
    # if reduce first axis, it will transpose to last axis
    # so here we adapt to this change
    if axis == 0:
        out_shape = out_shape[1:] + out_shape[:1]
        axis = len(out_shape) - 1
    # Eliminate unit axes, which automatically disappear in halide ir, and
    # shift the reduce axis left by the number of unit axes before it.
    # Compare against the ORIGINAL axis position: comparing against ``axis``
    # while decrementing it under-counts consecutive leading unit axes
    # (shape [1, 1, 5] with axis 2 must end at axis 0, not 1).
    orig_axis = axis
    kept_dims = 0
    for i, shp in enumerate(out_shape):
        if shp == 1:
            if i < orig_axis:
                axis -= 1
        else:
            kept_dims += 1
    for i in range(kept_dims):
        if i == axis:
            strategy.append(
                ct_util.create_constraint_on_axis(
                    values="FULL",
                    constraints=ct_util.TileConstraint.MAX,
                    axis=i)[0])
        else:
            strategy.append(
                ct_util.create_constraint_on_axis(
                    values=1,
                    constraints=ct_util.TileConstraint.FACTOR,
                    axis=i)[0])
    return strategy
def quantized_avgpool_tiling_strategy(data, kernel, stride, pad, quant_algo):
    """Custom tiling for quantized avgpool.

    A strategy is produced only when the last of the five shape dimensions
    of ``data`` equals 16; otherwise an empty list is returned.

    Args:
        data: input whose ``get_shape`` unpacks into five dimensions (named
            batch, c_1, fm_h, fm_w, c_0 here).
        kernel: ``kernel[0]`` and ``kernel[1]`` become tile factors on two
            of the trailing axes.
        stride, pad: forwarded with ``kernel`` to
            ``cal_pad_shapes_by_strategy`` to get the output height/width.
        quant_algo: when not None, the two kernel-factor axes come before the
            16-factor axis; when None, the 16-factor axis comes first.

    Returns:
        list of constraints from ``ct_util.create_constraint_on_axis``.
    """
    batch, c_1, fm_h, fm_w, c_0 = get_shape(data)
    _, [out_h, out_w] = cal_pad_shapes_by_strategy(
        get_shape(data), kernel, stride, pad)
    strategy = []
    if c_0 != 16:
        return strategy

    factor = ct_util.TileConstraint.FACTOR
    # Large feature maps get a tile factor of 3 on the H axis instead of the
    # whole output height.
    h_cut = 3 if (fm_h >= 50 and fm_w >= 50) else out_h

    # Leading batch / C1 axes are only present in the axis numbering when
    # they are larger than 1; each such axis is tiled by 1.
    base = len([dim for dim in (batch, c_1) if dim > 1])
    if quant_algo is not None:
        kh_axis, kw_axis, c0_axis = base + 2, base + 3, base + 4
    else:
        kh_axis, kw_axis, c0_axis = base + 3, base + 4, base + 2

    params = [(1, factor, lead_axis) for lead_axis in range(base)]
    params.extend([
        (h_cut, factor, base),
        ("H", ct_util.TileConstraint.SET_AXIS_INFO, base),
        (out_w, factor, base + 1),
        (kernel[0], factor, kh_axis),
        (kernel[1], factor, kw_axis),
        (16, factor, c0_axis),
    ])
    for value, constraint, target_axis in params:
        strategy += ct_util.create_constraint_on_axis(
            values=value, constraints=constraint, axis=target_axis)
    return strategy
def maxpool_with_argmax_tiling_strategy(data, kernel, stride, pad):
    """Custom tiling for maxpool with argmax version.

    Only 5-D inputs whose last dimension is the constant 16 get a strategy;
    for anything else an empty list is returned.

    Args:
        data: input tensor; its ``ndim`` and ``shape`` are read. Presumably
            laid out as (N, C1, H, W, C0) -- TODO confirm against callers.
        kernel, stride, pad: pooling parameters forwarded to
            ``cal_pad_shapes_by_strategy`` to obtain the output height.

    Returns:
        list of tiling constraints from ``ct_util.create_constraint_on_axis``.
    """
    strategy = list()
    # Check the rank BEFORE unpacking the shape: unpacking five names from a
    # shape of any other length raises ValueError, which made the original
    # ``data.ndim == 5`` test unreachable for the inputs it should filter.
    if data.ndim != 5:
        return strategy
    var_type = akg.tvm.expr.Var
    batch, c1, fm_h, fm_w, c0 = data.shape
    _, [out_h, _] = \
        cal_pad_shapes_by_strategy(get_shape(data), kernel, stride, pad)
    # A symbolic C0 has no ``.value`` and cannot be proven to equal 16.
    if isinstance(c0, var_type) or c0.value != 16:
        return strategy

    h_cut = out_h
    # Symbolic H or W is treated as "large". Checking W for Var as well
    # avoids AttributeError on ``fm_w.value`` when H is a constant >= 50.
    if isinstance(fm_h, var_type) or (
            fm_h.value >= 50
            and (isinstance(fm_w, var_type) or fm_w.value >= 50)):
        h_cut = 3

    dim_ind = 0
    # Leading batch / C1 axes (symbolic, or larger than 1) are tiled by 1.
    for lead_dim in (batch, c1):
        if isinstance(lead_dim, var_type) or lead_dim.value > 1:
            strategy += ct_util.create_constraint_on_axis(
                values=1, constraints=ct_util.TileConstraint.FACTOR,
                axis=dim_ind)
            dim_ind = dim_ind + 1
    strategy += ct_util.create_constraint_on_axis(
        values=h_cut, constraints=ct_util.TileConstraint.FACTOR,
        axis=dim_ind)
    strategy += ct_util.create_constraint_on_axis(
        values="H", constraints=ct_util.TileConstraint.SET_AXIS_INFO,
        axis=dim_ind)
    strategy += ct_util.create_constraint_on_axis(
        values="FULL", constraints=ct_util.TileConstraint.MAX,
        axis=dim_ind + 1)
    strategy += ct_util.create_constraint_on_axis(
        values=5, constraints=ct_util.TileConstraint.FACTOR,
        axis=dim_ind + 2)
    strategy += ct_util.create_constraint_on_axis(
        values=16, constraints=ct_util.TileConstraint.FACTOR,
        axis=dim_ind + 3)
    return strategy
def maxpool_with_argmax_custom_tiling_strategy(data):
    """Custom tiling for maxpool with argmax version.

    Emits constraints for two bands, band 1 first and then band 0. Only 5-D
    inputs whose last dimension is the constant 16 get a strategy; for
    anything else an empty list is returned.

    Args:
        data: input tensor; its ``ndim`` and ``shape`` are read. Presumably
            laid out as (N, C1, H, W, C0) -- TODO confirm against callers.

    Returns:
        list of tiling constraints from ``ct_util.create_constraint_on_axis``.
    """
    strategy = list()
    # Check the rank BEFORE unpacking the shape: unpacking five names from a
    # shape of any other length raises ValueError, which made the original
    # ``data.ndim == 5`` test unreachable for the inputs it should filter.
    if data.ndim != 5:
        return strategy
    var_type = akg.tvm.expr.Var
    batch, c1, _, _, c0 = data.shape
    # A symbolic C0 has no ``.value`` and cannot be proven to equal 16.
    if isinstance(c0, var_type) or c0.value != 16:
        return strategy

    one = (1, ct_util.TileConstraint.FACTOR)
    full = ("FULL", ct_util.TileConstraint.MAX)

    # Band 1: leading batch / C1 axes (only when symbolic or larger than 1)
    # and the following axis are tiled by 1; the next four are kept whole.
    band1 = [one for lead_dim in (batch, c1)
             if isinstance(lead_dim, var_type) or lead_dim.value > 1]
    band1 += [one, full, full, full, full]
    # Band 0: fixed eight-axis pattern.
    band0 = [one, one, full, full, one, full, full, full]

    for band, specs in ((1, band1), (0, band0)):
        for dim_ind, (value, constraint) in enumerate(specs):
            strategy += ct_util.create_constraint_on_axis(
                values=value, constraints=constraint, band=band,
                axis=dim_ind)
    return strategy