Пример #1
0
def argminmax_tiling_strategy(out_shape, axis):
    """Build a custom tiling strategy for the argmin/argmax op.

    The reduce axis receives a ``"FULL"`` value with a ``MAX`` constraint.
    Every other axis that survives unit-axis elimination receives a
    ``FACTOR`` constraint of 1.

    Args:
        out_shape: Sequence (list or tuple) of axis extents. It is copied,
            never modified.
        axis: Index of the reduce axis in ``out_shape``. Negative values are
            accepted and count from the end, as in Python indexing.

    Returns:
        list: One entry per non-unit axis, each taken from
        ``ct_util.create_constraint_on_axis(...)[0]``. The list is empty
        when the reduce axis has extent 1.
    """
    strategy = []
    # when the reduce axis has extent one, we do not need any strategy
    if out_shape[axis] == 1:
        return strategy

    # Copy into a list so tuple shapes work (tuples have no ``append``).
    # Normalize a negative axis as well: the index comparisons below only
    # ever see non-negative positions, so e.g. -1 would never match.
    shape = list(out_shape)
    axis %= len(shape)

    # if the first axis is reduced, it will be transposed to the last axis,
    # so rotate the shape to adapt to this change
    if axis == 0:
        shape = shape[1:] + shape[:1]
        axis = len(shape) - 1

    # Eliminate unit axes, which automatically disappear in halide ir.
    # Shift the reduce axis left once for each unit axis that precedes it.
    rank = 0
    for i, shp in enumerate(shape):
        if shp == 1:
            if i < axis:
                axis -= 1
        else:
            rank += 1

    for i in range(rank):
        if i == axis:
            strategy.append(
                ct_util.create_constraint_on_axis(
                    values="FULL",
                    constraints=ct_util.TileConstraint.MAX,
                    axis=i)[0])
        else:
            strategy.append(
                ct_util.create_constraint_on_axis(
                    values=1,
                    constraints=ct_util.TileConstraint.FACTOR,
                    axis=i)[0])
    return strategy
Пример #2
0
def quantized_avgpool_tiling_strategy(data, kernel, stride, pad, quant_algo):
    """Return the custom tiling constraints for quantized avgpool.

    Constraints are only produced when the fifth shape component equals 16.
    For any other value the strategy is empty.

    Args:
        data: Input tensor. ``get_shape(data)`` is unpacked into five values
            (named batch, c_1, fm_h, fm_w, c_0 here).
        kernel: Pooling kernel. ``kernel[0]`` and ``kernel[1]`` are used as
            tile factors.
        stride: Forwarded to ``cal_pad_shapes_by_strategy``.
        pad: Forwarded to ``cal_pad_shapes_by_strategy``.
        quant_algo: Only compared against ``None``. It selects which axis
            positions the kernel factors and the 16-wide factor land on.

    Returns:
        list: Constraints from ``ct_util.create_constraint_on_axis``.
    """
    batch, c_1, fm_h, fm_w, c_0 = get_shape(data)
    _, [out_h, out_w] = cal_pad_shapes_by_strategy(
        get_shape(data), kernel, stride, pad)

    if c_0 != 16:
        return []

    factor = ct_util.TileConstraint.FACTOR

    # Large feature maps are cut into H tiles of 3; otherwise keep whole H.
    h_cut = 3 if (fm_h >= 50 and fm_w >= 50) else out_h

    # batch and c_1 each occupy a leading axis only when larger than one.
    lead = sum(1 for extent in (batch, c_1) if extent > 1)
    params = [(1, factor, leading_axis) for leading_axis in range(lead)]
    params += [
        (h_cut, factor, lead),
        ("H", ct_util.TileConstraint.SET_AXIS_INFO, lead),
        (out_w, factor, lead + 1),
    ]

    # The trailing axis order differs with and without quantization.
    if quant_algo is None:
        kh_axis, kw_axis, block_axis = lead + 3, lead + 4, lead + 2
    else:
        kh_axis, kw_axis, block_axis = lead + 2, lead + 3, lead + 4
    params += [
        (kernel[0], factor, kh_axis),
        (kernel[1], factor, kw_axis),
        (16, factor, block_axis),
    ]

    strategy = []
    for value, constraint, target_axis in params:
        strategy += ct_util.create_constraint_on_axis(values=value,
                                                      constraints=constraint,
                                                      axis=target_axis)
    return strategy
Пример #3
0
def maxpool_with_argmax_tiling_strategy(data, kernel, stride, pad):
    """Custom tiling for the maxpool-with-argmax op.

    A strategy is produced only for 5-D inputs whose last dimension is the
    constant 16. Any other input yields an empty list.

    Args:
        data: Input tensor. Its ``ndim`` and ``shape`` are read, and the
            shape is unpacked as (batch, c1, fm_h, fm_w, c0).
        kernel: Forwarded to ``cal_pad_shapes_by_strategy``.
        stride: Forwarded to ``cal_pad_shapes_by_strategy``.
        pad: Forwarded to ``cal_pad_shapes_by_strategy``.

    Returns:
        list: Constraints from ``ct_util.create_constraint_on_axis``.
    """
    strategy = []
    # Check the rank before unpacking. The five-way unpack below raises for
    # any other rank, so a rank check placed after it could never fail.
    if data.ndim != 5:
        return strategy

    def _is_dynamic(dim):
        """Return True when ``dim`` is a symbolic (``Var``) extent."""
        return isinstance(dim, akg.tvm.expr.Var)

    batch, c1, fm_h, fm_w, c0 = data.shape
    # A symbolic c0 has no ``.value`` and cannot be proven equal to 16.
    if _is_dynamic(c0) or c0.value != 16:
        return strategy

    _, [out_h, _] = \
        cal_pad_shapes_by_strategy(get_shape(data), kernel, stride, pad)

    # Cut H into tiles of 3 for large (or symbolic, hence possibly large)
    # feature maps. fm_w is guarded too: reading ``.value`` on a Var raises.
    if _is_dynamic(fm_h) or (fm_h.value >= 50 and
                             (_is_dynamic(fm_w) or fm_w.value >= 50)):
        h_cut = 3
    else:
        h_cut = out_h

    factor = ct_util.TileConstraint.FACTOR
    # batch and c1 each occupy a leading axis only when they may exceed one.
    lead = sum(1 for dim in (batch, c1) if _is_dynamic(dim) or dim.value > 1)
    rules = [(1, factor, leading_axis) for leading_axis in range(lead)]
    rules += [
        (h_cut, factor, lead),
        ("H", ct_util.TileConstraint.SET_AXIS_INFO, lead),
        ("FULL", ct_util.TileConstraint.MAX, lead + 1),
        # NOTE(review): the meaning of factor 5 on this axis is not evident
        # from this function; kept as in the original.
        (5, factor, lead + 2),
        (16, factor, lead + 3),
    ]
    for value, constraint, target_axis in rules:
        strategy += ct_util.create_constraint_on_axis(
            values=value,
            constraints=constraint,
            axis=target_axis)
    return strategy
Пример #4
0
def maxpool_with_argmax_custom_tiling_strategy(data):
    """Custom per-band tiling for the maxpool-with-argmax op.

    Constraints are emitted for band 1 first and then for band 0. A strategy
    is produced only for 5-D inputs whose last dimension is the constant 16.
    Any other input yields an empty list.

    Args:
        data: Input tensor. Its ``ndim`` and ``shape`` are read, and the
            shape is unpacked as (batch, c1, _, _, c0).

    Returns:
        list: Constraints from ``ct_util.create_constraint_on_axis``.
    """
    strategy = []
    # Check the rank before unpacking. The five-way unpack below raises for
    # any other rank, so a rank check placed after it could never fail.
    if data.ndim != 5:
        return strategy

    def _is_dynamic(dim):
        """Return True when ``dim`` is a symbolic (``Var``) extent."""
        return isinstance(dim, akg.tvm.expr.Var)

    batch, c1, _, _, c0 = data.shape
    # A symbolic c0 has no ``.value`` and cannot be proven equal to 16.
    if _is_dynamic(c0) or c0.value != 16:
        return strategy

    factor = (1, ct_util.TileConstraint.FACTOR)
    full = ("FULL", ct_util.TileConstraint.MAX)

    # Band 1: batch and c1 each get a leading axis only when they may exceed
    # one. They are followed by one factor-1 axis and four full axes.
    band1 = [factor for dim in (batch, c1)
             if _is_dynamic(dim) or dim.value > 1]
    band1 += [factor, full, full, full, full]
    # Band 0: a fixed layout of eight axis constraints.
    band0 = [factor, factor, full, full, factor, full, full, full]

    # Axes are numbered from 0 within each band, in list order.
    for band, rules in ((1, band1), (0, band0)):
        for target_axis, (value, constraint) in enumerate(rules):
            strategy += ct_util.create_constraint_on_axis(
                values=value,
                constraints=constraint,
                band=band,
                axis=target_axis)
    return strategy