Example #1
0
def _scalar_dequant_v100_v2(x_l0c, deq_ub, align_shape, x_shape, relu_flag,
                            sqrt_mode):
    """
    dequant with a single scalar scale in v100 (v2 variant)

    Parameters
    ----------
    x_l0c : input tensor, read with four indices and cast to float16
    deq_ub : scale tensor, only element (0, 0, 0, 0) is read
    align_shape : shape of the first stage 'dequant_to_fp16'
    x_shape : shape of every following stage and of the result
    relu_flag : if truthy, a 'dequant_relu' stage is added
    sqrt_mode : if truthy, the scale is multiplied in a second time

    Returns
    -------
    the stage named "res" (tag 'dequant_res'); its attrs carry
    'sqrt_mode', 'relu_mode' (the relu_flag value) and 'is_scalar' = 1
    """
    # NOTE(review): the index parameter names (i, j, k, l) are deliberately
    # the same as before; tvm.compute presumably derives axis names from
    # them - confirm before renaming.

    def _cast_and_scale(i, j, k, l):
        return x_l0c(i, j, k, l).astype("float16") * deq_ub(0, 0, 0, 0)

    stage = tvm.compute(align_shape, _cast_and_scale, name='dequant_to_fp16')

    if sqrt_mode:
        scaled_once = stage

        def _scale_again(i, j, k, l):
            return scaled_once(i, j, k, l) * deq_ub(0, 0, 0, 0)

        stage = tvm.compute(x_shape, _scale_again, name='dequant_sqrt')

    if relu_flag:
        before_relu = stage

        def _apply_relu(*idx):
            return tvm.relu(before_relu(*idx))

        stage = tvm.compute(x_shape, _apply_relu, name="dequant_relu")

    # Final pass-through stage that carries the tag and the mode attributes.
    last_stage = stage
    result_attrs = {
        'sqrt_mode': sqrt_mode,
        'relu_mode': relu_flag,
        'is_scalar': 1
    }
    return tvm.compute(x_shape,
                       lambda *idx: last_stage(*idx),
                       name="res",
                       tag='dequant_res',
                       attrs=result_attrs)
Example #2
0
def _scalar_dequant_v100(x, x_shape, align_shape, deq_scale, relu_flag,
                         sqrt_mode):
    """
    dequant with a single scalar scale in v100

    Parameters
    ----------
    x : input tensor, read with four indices and cast to float16
    x_shape : shape of the 'dequant_remove_pad' stage and all later stages
    align_shape : shape of the first stage 'dequant1'
    deq_scale : scale tensor, only element (0, 0, 0, 0, 0) is read
    relu_flag : if truthy, a 'dequant_relu' stage is added
    sqrt_mode : if truthy, the scale is multiplied in a second time
        (after the optional relu)

    Returns
    -------
    the last stage that was built
    """
    # NOTE(review): the index parameter names (i, j, k, l) are deliberately
    # the same as before; tvm.compute presumably derives axis names from
    # them - confirm before renaming.

    def _cast_and_scale(i, j, k, l):
        return x(i, j, k, l).astype("float16") * deq_scale(0, 0, 0, 0, 0)

    padded = tvm.compute(align_shape,
                         _cast_and_scale,
                         name='dequant1',
                         tag="dequant1_scale")

    # Re-read the align_shape stage over x_shape.
    current = tvm.compute(x_shape,
                          lambda *idx: padded(*idx),
                          name='dequant_remove_pad',
                          tag="dequant_remove_pad")

    if relu_flag:
        relu_src = current
        current = tvm.compute(x_shape,
                              lambda *idx: tvm.relu(relu_src(*idx)),
                              name="dequant_relu",
                              tag="dequant_relu")

    if not sqrt_mode:
        return current

    sqrt_src = current

    def _scale_again(i, j, k, l):
        return sqrt_src(i, j, k, l) * deq_scale(0, 0, 0, 0, 0)

    return tvm.compute(x_shape,
                       _scale_again,
                       name='dequant2',
                       tag='dequant2_scale')
Example #3
0
def _s16_to_s8_normal_compute(x, x1, req_scale, x_shape, align_shape, c1_index,
                              tensor_flag, relu_flag):
    """
    generate s16_to_s8 compute

    Parameters
    ----------
    x : input tensor
    x1 : optional second tensor; when not None it is added to x element-wise
    req_scale : scale, forwarded unchanged to _deq_cast_compute
    x_shape : shape of the intermediate "res_s16" stage
    align_shape : shape of the 's16_to_s8' stage
    c1_index : forwarded unchanged to _deq_cast_compute
    tensor_flag : forwarded to _deq_cast_compute; also selects the tag of
        the 's16_to_s8' stage (vector when truthy, scale otherwise)
    relu_flag : if truthy, relu is applied inside the "res_s16" stage

    Returns
    -------
    (res_s16, res_ub) : the intermediate stage and the 's16_to_s8' stage
    """
    # Choose the element-wise body and the matching tag for the s16 stage.
    if x1 is not None and relu_flag:
        s16_tag = "requant_s16_vaddrelu"

        def _s16_body(*idx):
            return tvm.relu(x(*idx) + x1(*idx))
    elif x1 is not None:
        s16_tag = "requant_s16_vadd"

        def _s16_body(*idx):
            return x(*idx) + x1(*idx)
    elif relu_flag:
        s16_tag = "requant_s16_relu"

        def _s16_body(*idx):
            return tvm.relu(x(*idx))
    else:
        s16_tag = "requant_s16"

        def _s16_body(*idx):
            return x(*idx)

    res_s16 = tvm.compute(x_shape, _s16_body, name="res_s16", tag=s16_tag)

    x_shape_list = te.lang.cce.util.shape_to_list(x_shape)
    cast_body = _deq_cast_compute(res_s16, req_scale, align_shape, c1_index,
                                  tensor_flag, x_shape_list)
    # Both variants share the same body; only the tag differs.
    s8_tag = "requant_s16_vector" if tensor_flag else "requant_s16_scale"
    res_ub = tvm.compute(align_shape,
                         cast_body,
                         name='s16_to_s8',
                         tag=s8_tag)
    return res_s16, res_ub
Example #4
0
def _vector_depthwise_fused_v100(x, x_shape, align_shape, deq_scale, relu_flag,
                                 sqrt_mode):
    """
    dequant with a vector scale in v100, depthwise fused variant

    Parameters
    ----------
    x : input tensor, read as x(i, j // 2, j % 2, k, l) and cast to float16
    x_shape : only x_shape[3].value is read
    align_shape : mutable 5-element shape; used as-is for the first stage,
        then element 3 is overwritten IN PLACE with x_shape[3].value before
        the remaining stages are built (the caller's object is modified)
    deq_scale : scale tensor, read with five indices
    relu_flag : if truthy, relu is applied inside the first stage
    sqrt_mode : if truthy, the scale is multiplied in a second time

    Returns
    -------
    the 'dequant_remove_pad' stage (sqrt_flag attr 0) or the
    'dequant2_remove_pad' stage (sqrt_flag attr 1)
    """
    # NOTE(review): the index parameter names (i, j, a, k, l) are deliberately
    # the same as before; tvm.compute presumably derives axis names from
    # them - confirm before renaming.

    # NOTE(review): the relu body reads deq_scale(0, j, 0, 0, l) while the
    # plain body reads deq_scale(0, j, a, 0, l); kept exactly as found -
    # confirm whether the difference is intentional.
    def _scaled_with_relu(i, j, a, k, l):
        return tvm.relu(
            x(i, j // 2, j % 2, k, l).astype("float16") *
            deq_scale(0, j, 0, 0, l))

    def _scaled_plain(i, j, a, k, l):
        return (x(i, j // 2, j % 2, k, l).astype("float16") *
                deq_scale(0, j, a, 0, l))

    res_f16 = tvm.compute(align_shape,
                          _scaled_with_relu if relu_flag else _scaled_plain,
                          name='dequant1',
                          tag="dequant1_vector",
                          attrs={"relu_flag": 1 if relu_flag else 0})

    # In-place update of the caller's shape; later stages use the new value.
    align_shape[3] = x_shape[3].value

    if not sqrt_mode:
        return tvm.compute(align_shape,
                           lambda *idx: res_f16(*idx),
                           name='dequant_remove_pad',
                           tag="dequant_remove_pad",
                           attrs={"sqrt_flag": 0})

    def _scale_again(i, j, a, k, l):
        return res_f16(i, j, a, k, l) * deq_scale(0, j, a, 0, l)

    res_sqrt = tvm.compute(align_shape,
                           _scale_again,
                           name='dequant2',
                           tag='dequant2_vector')

    return tvm.compute(align_shape,
                       lambda *idx: res_sqrt(*idx),
                       name='dequant2_remove_pad',
                       tag="dequant2_remove_pad",
                       attrs={"sqrt_flag": 1})
Example #5
0
def _vector_dequant_v100(x, x_shape, align_shape, deq_scale, relu_flag,
                         sqrt_mode):
    """
    dequant with a vector scale in v100

    Parameters
    ----------
    x : input tensor, read with four indices and cast to float16
    x_shape : shape of the 'dequant_remove_pad' stage and the sqrt stage
    align_shape : shape of the first stage 'dequant1'
    deq_scale : scale tensor, read as deq_scale(0, j, 0, 0, l)
    relu_flag : if truthy, relu is applied inside the first stage
    sqrt_mode : if truthy, the scale is multiplied in a second time

    Returns
    -------
    the 'dequant_remove_pad' stage, or the 'dequant2' stage when sqrt_mode
    is truthy
    """
    # NOTE(review): the index parameter names (i, j, k, l) are deliberately
    # the same as before; tvm.compute presumably derives axis names from
    # them - confirm before renaming.

    def _scaled_with_relu(i, j, k, l):
        return tvm.relu(
            x(i, j, k, l).astype("float16") * deq_scale(0, j, 0, 0, l))

    def _scaled_plain(i, j, k, l):
        return x(i, j, k, l).astype("float16") * deq_scale(0, j, 0, 0, l)

    res_f16 = tvm.compute(align_shape,
                          _scaled_with_relu if relu_flag else _scaled_plain,
                          name='dequant1',
                          tag="dequant1_vector",
                          attrs={"relu_flag": 1 if relu_flag else 0})

    # Re-read the align_shape stage over x_shape.
    unpadded = tvm.compute(x_shape,
                           lambda *idx: res_f16(*idx),
                           name='dequant_remove_pad',
                           tag="dequant_remove_pad")

    if not sqrt_mode:
        return unpadded

    def _scale_again(i, j, k, l):
        return unpadded(i, j, k, l) * deq_scale(0, j, 0, 0, l)

    return tvm.compute(x_shape,
                       _scale_again,
                       name='dequant2',
                       tag='dequant2_vector')
Example #6
0
def _matmul_compute(x, x_shape, deq_scale, sqrt_mode, relu_flag,
                    shape_matmul_origin, c1_index, tensor_flag):
    """
    dequant for matmul

    Parameters
    ----------
    x : input tensor; also passed to _is_nz_format to pick the output layout
    x_shape : shape of every dequant stage and of the NZ output
    deq_scale : scale, forwarded unchanged to _matmul_vdeq_cast_compute
    sqrt_mode : if truthy (non-v200 path only), a second cast stage
        'dequant_sqrt' is built on top of the first one
    relu_flag : forwarded to _matmul_vdeq_cast_compute; on the non-v200
        path it also adds a separate 'dequant_relu' stage
    shape_matmul_origin : shape of the ND output
    c1_index : forwarded unchanged to _matmul_vdeq_cast_compute
    tensor_flag : forwarded to _matmul_vdeq_cast_compute; also selects
        the vector or scalar tag of each cast stage

    Returns
    -------
    the 'dequant_NZ' stage when _is_nz_format(x) is truthy, otherwise the
    'dequant_ND' stage
    """

    def _cast_stage(src, v200_flag, stage_name, stage_tag):
        # One dequant stage over x_shape built from the shared cast helper.
        body = _matmul_vdeq_cast_compute(src, deq_scale, x_shape, c1_index,
                                         tensor_flag, relu_flag, v200_flag)
        return tvm.compute(x_shape, body, name=stage_name, tag=stage_tag)

    if _is_support_v200_instruction():
        # v200 path: a single stage; no separate sqrt / relu stage is built.
        first_tag = "dequant_vector" if tensor_flag else "dequant_scale"
        res_f16 = _cast_stage(x, True, 'dequant', first_tag)
    else:
        first_tag = "dequant_vector" if tensor_flag else "dequant"
        res_f16 = _cast_stage(x, False, 'dequant', first_tag)

        if sqrt_mode:
            sqrt_tag = "dequant_vector_sqrt" if tensor_flag else "dequant_sqrt"
            res_f16 = _cast_stage(res_f16, False, 'dequant_sqrt', sqrt_tag)

        if relu_flag:
            relu_src = res_f16
            res_f16 = tvm.compute(
                x_shape,
                lambda *indices: tvm.relu(relu_src[indices]),
                name="dequant_relu",
                tag="dequant_relu")

    dequantized = res_f16

    if _is_nz_format(x):
        # nz format
        return tvm.compute(x_shape,
                           lambda *i: dequantized[i],
                           name='dequant_NZ',
                           tag='dequant_NZ',
                           attrs={'format': 'FRACTAL_NZ'})

    # convert fractal_z to ND
    return tvm.compute(
        shape_matmul_origin,
        lambda i, j: dequantized[j // 16, i // 16, i % 16, j % 16],
        name='dequant_ND',
        tag='dequant_ND',
        attrs={'format': 'NC1HWC0'})