def _scalar_dequant_v100_v2(x_l0c, deq_ub, align_shape, x_shape, relu_flag,
                            sqrt_mode):
    """
    dequant for scalar scale in v100
    """
    # cast the L0C result to fp16 and multiply by the scalar scale held in UB
    res = tvm.compute(
        align_shape,
        lambda i, j, k, l: (x_l0c(i, j, k, l).astype("float16") *
                            deq_ub(0, 0, 0, 0)),
        name='dequant_to_fp16')

    # sqrt mode applies the scale a second time
    if sqrt_mode:
        res = tvm.compute(x_shape,
                          lambda i, j, k, l: (res(i, j, k, l) *
                                              deq_ub(0, 0, 0, 0)),
                          name='dequant_sqrt')

    if relu_flag:
        res = tvm.compute(x_shape,
                          lambda *indices: tvm.relu(res(*indices)),
                          name="dequant_relu")

    # final copy carrying the dequant attributes for the schedule
    res = tvm.compute(x_shape,
                      lambda *indice: res(*indice),
                      name="res",
                      tag='dequant_res',
                      attrs={'sqrt_mode': sqrt_mode,
                             'relu_mode': relu_flag,
                             'is_scalar': 1})
    return res


def _scalar_dequant_v100(x, x_shape, align_shape, deq_scale, relu_flag,
                         sqrt_mode):
    """
    dequant for scalar scale in v100
    """
    # cast to fp16 and multiply by the scalar scale over the aligned shape
    res_f16 = tvm.compute(
        align_shape,
        lambda i, j, k, l: (x(i, j, k, l).astype("float16") *
                            deq_scale(0, 0, 0, 0, 0)),
        name='dequant1',
        tag="dequant1_scale")

    # drop the alignment padding back to the original shape
    res = tvm.compute(x_shape,
                      lambda *indice: res_f16(*indice),
                      name='dequant_remove_pad',
                      tag="dequant_remove_pad")

    if relu_flag:
        res = tvm.compute(x_shape,
                          lambda *indices: tvm.relu(res(*indices)),
                          name="dequant_relu",
                          tag="dequant_relu")

    # sqrt mode applies the scale a second time
    if sqrt_mode:
        res = tvm.compute(
            x_shape,
            lambda i, j, k, l: (res(i, j, k, l) * deq_scale(0, 0, 0, 0, 0)),
            name='dequant2',
            tag='dequant2_scale')
    return res
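

# A minimal usage sketch of the scalar-scale path, for illustration only: the
# placeholder shapes and dtypes below (an NC1HWC0 int32 input whose H*W axis
# is already 16-aligned, and a (1, 1, 1, 1, 1) fp16 scale) are assumptions for
# the example, not requirements taken from this module, and this helper is not
# called by the operator itself.
def _example_scalar_dequant_v100():
    """illustrative sketch only; shapes and dtypes are assumptions"""
    x = tvm.placeholder((1, 2, 256, 16), name="x", dtype="int32")
    deq_scale = tvm.placeholder((1, 1, 1, 1, 1), name="deq_scale",
                                dtype="float16")
    return _scalar_dequant_v100(x, list(x.shape), list(x.shape), deq_scale,
                                relu_flag=False, sqrt_mode=True)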


def _s16_to_s8_normal_compute(x, x1, req_scale, x_shape, align_shape,
                              c1_index, tensor_flag, relu_flag):
    """
    generate the s16 -> s8 requant compute
    """
    # optionally fuse an element-wise add with x1 and/or a relu on the s16 result
    if x1 is not None:
        if relu_flag:
            res_s16 = tvm.compute(
                x_shape,
                lambda *indices: tvm.relu(x(*indices) + x1(*indices)),
                name="res_s16",
                tag="requant_s16_vaddrelu")
        else:
            res_s16 = tvm.compute(x_shape,
                                  lambda *indices: x(*indices) + x1(*indices),
                                  name="res_s16",
                                  tag="requant_s16_vadd")
    else:
        if relu_flag:
            res_s16 = tvm.compute(x_shape,
                                  lambda *indices: tvm.relu(x(*indices)),
                                  name="res_s16",
                                  tag="requant_s16_relu")
        else:
            res_s16 = tvm.compute(x_shape,
                                  lambda *indices: x(*indices),
                                  name="res_s16",
                                  tag="requant_s16")

    x_shape_list = te.lang.cce.util.shape_to_list(x_shape)

    # cast the s16 result to s8 with a per-channel (vector) or scalar requant scale
    if tensor_flag:
        res_ub = tvm.compute(align_shape,
                             _deq_cast_compute(res_s16, req_scale, align_shape,
                                               c1_index, tensor_flag,
                                               x_shape_list),
                             name='s16_to_s8',
                             tag="requant_s16_vector")
    else:
        res_ub = tvm.compute(align_shape,
                             _deq_cast_compute(res_s16, req_scale, align_shape,
                                               c1_index, tensor_flag,
                                               x_shape_list),
                             name='s16_to_s8',
                             tag="requant_s16_scale")
    return res_s16, res_ub


def _vector_depthwise_fused_v100(x, x_shape, align_shape, deq_scale, relu_flag,
                                 sqrt_mode):
    """
    dequant for vector scale fused with depthwise in v100
    """
    # cast to fp16 and multiply by the per-channel scale; relu is fused into
    # the first multiplication when requested
    if relu_flag:
        res_f16 = tvm.compute(align_shape,
                              lambda i, j, a, k, l: tvm.relu(
                                  x(i, j // 2, j % 2, k, l).astype("float16") *
                                  deq_scale(0, j, 0, 0, l)),
                              name='dequant1',
                              tag="dequant1_vector",
                              attrs={"relu_flag": 1})
    else:
        res_f16 = tvm.compute(align_shape,
                              lambda i, j, a, k, l:
                              x(i, j // 2, j % 2, k, l).astype("float16") *
                              deq_scale(0, j, a, 0, l),
                              name='dequant1',
                              tag="dequant1_vector",
                              attrs={"relu_flag": 0})

    # restore the original H*W extent before removing the padding
    align_shape[3] = x_shape[3].value

    if not sqrt_mode:
        res = tvm.compute(align_shape,
                          lambda *indice: res_f16(*indice),
                          name='dequant_remove_pad',
                          tag="dequant_remove_pad",
                          attrs={"sqrt_flag": 0})
    else:
        # sqrt mode applies the per-channel scale a second time
        res_sqrt = tvm.compute(
            align_shape,
            lambda i, j, a, k, l: (res_f16(i, j, a, k, l) *
                                   deq_scale(0, j, a, 0, l)),
            name='dequant2',
            tag='dequant2_vector')

        res = tvm.compute(align_shape,
                          lambda *indice: res_sqrt(*indice),
                          name='dequant2_remove_pad',
                          tag="dequant2_remove_pad",
                          attrs={"sqrt_flag": 1})
    return res


def _vector_dequant_v100(x, x_shape, align_shape, deq_scale, relu_flag,
                         sqrt_mode):
    """
    dequant for vector scale in v100
    """
    # cast to fp16 and multiply by the per-channel scale; relu is fused into
    # the first multiplication when requested
    if relu_flag:
        res_f16 = tvm.compute(
            align_shape,
            lambda i, j, k, l: tvm.relu(
                x(i, j, k, l).astype("float16") * deq_scale(0, j, 0, 0, l)),
            name='dequant1',
            tag="dequant1_vector",
            attrs={"relu_flag": 1})
    else:
        res_f16 = tvm.compute(align_shape,
                              lambda i, j, k, l: x(i, j, k, l).astype(
                                  "float16") * deq_scale(0, j, 0, 0, l),
                              name='dequant1',
                              tag="dequant1_vector",
                              attrs={"relu_flag": 0})

    # drop the alignment padding back to the original shape
    res = tvm.compute(x_shape,
                      lambda *indice: res_f16(*indice),
                      name='dequant_remove_pad',
                      tag="dequant_remove_pad")

    # sqrt mode applies the per-channel scale a second time
    if sqrt_mode:
        res = tvm.compute(x_shape,
                          lambda i, j, k, l: (res(i, j, k, l) *
                                              deq_scale(0, j, 0, 0, l)),
                          name='dequant2',
                          tag='dequant2_vector')
    return res
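

# A minimal usage sketch of the per-channel (vector) path, for illustration
# only: here the scale is laid out as (1, C1, 1, 1, 16) so that
# deq_scale(0, j, 0, 0, l) selects one fp16 factor per channel. The concrete
# shapes and dtypes are assumptions for the example; this helper is not called
# by the operator itself.
def _example_vector_dequant_v100():
    """illustrative sketch only; shapes and dtypes are assumptions"""
    x = tvm.placeholder((1, 2, 256, 16), name="x", dtype="int32")
    deq_scale = tvm.placeholder((1, 2, 1, 1, 16), name="deq_scale",
                                dtype="float16")
    return _vector_dequant_v100(x, list(x.shape), list(x.shape), deq_scale,
                                relu_flag=True, sqrt_mode=False)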


def _matmul_compute(x, x_shape, deq_scale, sqrt_mode, relu_flag,
                    shape_matmul_origin, c1_index, tensor_flag):
    """
    dequant for matmul
    """
    # first multiplication by the dequant scale, using the v200 vdeq
    # instruction when the platform supports it
    if _is_support_v200_instruction():
        if tensor_flag:
            res_f16 = tvm.compute(x_shape,
                                  _matmul_vdeq_cast_compute(
                                      x, deq_scale, x_shape, c1_index,
                                      tensor_flag, relu_flag, True),
                                  name='dequant',
                                  tag="dequant_vector")
        else:
            res_f16 = tvm.compute(x_shape,
                                  _matmul_vdeq_cast_compute(
                                      x, deq_scale, x_shape, c1_index,
                                      tensor_flag, relu_flag, True),
                                  name='dequant',
                                  tag="dequant_scale")
    else:
        if tensor_flag:
            res_f16 = tvm.compute(
                x_shape,
                _matmul_vdeq_cast_compute(x, deq_scale, x_shape, c1_index,
                                          tensor_flag, relu_flag, False),
                name='dequant',
                tag="dequant_vector")
        else:
            res_f16 = tvm.compute(
                x_shape,
                _matmul_vdeq_cast_compute(x, deq_scale, x_shape, c1_index,
                                          tensor_flag, relu_flag, False),
                name='dequant',
                tag="dequant")

    # sqrt mode applies the scale a second time
    if sqrt_mode:
        if tensor_flag:
            res_f16 = tvm.compute(x_shape,
                                  _matmul_vdeq_cast_compute(
                                      res_f16, deq_scale, x_shape, c1_index,
                                      tensor_flag, relu_flag, False),
                                  name='dequant_sqrt',
                                  tag="dequant_vector_sqrt")
        else:
            res_f16 = tvm.compute(x_shape,
                                  _matmul_vdeq_cast_compute(
                                      res_f16, deq_scale, x_shape, c1_index,
                                      tensor_flag, relu_flag, False),
                                  name='dequant_sqrt',
                                  tag="dequant_sqrt")

    if relu_flag:
        res_f16 = tvm.compute(x_shape,
                              lambda *indices: tvm.relu(res_f16[indices]),
                              name="dequant_relu",
                              tag="dequant_relu")

    if not _is_nz_format(x):
        # convert fractal_nz to ND
        res_out = tvm.compute(
            shape_matmul_origin,
            lambda i, j: res_f16[j // 16, i // 16, i % 16, j % 16],
            name='dequant_ND',
            tag='dequant_ND',
            attrs={'format': 'NC1HWC0'})
    else:
        # keep the nz format
        res_out = tvm.compute(x_shape,
                              lambda *i: res_f16[i],
                              name='dequant_NZ',
                              tag='dequant_NZ',
                              attrs={'format': 'FRACTAL_NZ'})
    return res_out
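

# A standalone sketch of the FRACTAL_NZ -> ND index remapping used at the end
# of _matmul_compute, for illustration only: an ND result of shape (M, N) =
# (32, 64) is stored as (N // 16, M // 16, 16, 16) fractal blocks. The shapes
# and dtype are assumptions for the example; this helper is not called by the
# operator itself.
def _example_nz_to_nd_remap():
    """illustrative sketch only; shapes and dtype are assumptions"""
    nz = tvm.placeholder((4, 2, 16, 16), name="nz", dtype="float16")
    return tvm.compute((32, 64),
                       lambda i, j: nz[j // 16, i // 16, i % 16, j % 16],
                       name="nz_to_nd_example")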