def fused_relu_grad_bn_double_update_grad(data_1, data_2, data_3, data_4, data_5, data_6, data_7, layout='NHWC'):
    """
    fused_relu_grad_bn_double_update_grad.

    Args:
        data_1~data_7: tvm.tensor.Tensor.
        layout: input layout, only 'NCHW', 'NHWC' supported

    Returns:
        List of tvm.tensor.Tensor.
    """
    if layout == "NCHW":
        # Transpose the NCHW feature-map inputs to NHWC so the reductions below
        # can always run over axes (0, 1, 2).
        data_2, data_4, data_5, data_6, data_7 = [
            topi.transpose(t, axes=(0, 2, 3, 1)) for t in (data_2, data_4, data_5, data_6, data_7)]
    elif layout != "NHWC":
        raise NotImplementedError('Layout not supported {}'.format(layout))

    # ReLU backward: pass (data_5 + data_6) through where the forward output data_7 was positive.
    data_tmp1 = topi.full_like(data_7, 0.0)
    data_tmp2 = topi.greater(data_7, data_tmp1)
    data_tmp3 = topi.add(data_5, data_6)
    data_tmp4 = topi.where(data_tmp2, data_tmp3, data_tmp1)
    data_tmp5 = topi.cast(data_tmp4, 'float32')

    # Per-channel sum of the gradient.
    data_tmp7 = topi.sum(data_tmp5, axis=(0, 1, 2))

    n, h, w, c = data_7.shape
    data_tmp8 = topi.cast(data_2, 'float32')
    data_tmp9 = topi.full_like(data_tmp7, 1.0 / (n * h * w))

    # Per-channel sum of gradient * (data_2 - data_1 / (n*h*w)).
    data_tmp10 = topi.multiply(data_1, data_tmp9)
    data_tmp11 = topi.broadcast_to(data_tmp10, data_tmp8.shape)
    data_tmp12 = topi.subtract(data_tmp8, data_tmp11)
    data_tmp13 = topi.multiply(data_tmp5, data_tmp12)
    data_tmp15 = topi.sum(data_tmp13, axis=(0, 1, 2))

    # Per-channel sum of gradient * (data_4 - data_3 / (n*h*w)).
    data_tmp16 = topi.cast(data_4, 'float32')
    data_tmp17 = topi.multiply(data_3, data_tmp9)
    data_tmp18 = topi.broadcast_to(data_tmp17, data_tmp16.shape)
    data_tmp19 = topi.subtract(data_tmp16, data_tmp18)
    data_tmp20 = topi.multiply(data_tmp5, data_tmp19)
    data_tmp22 = topi.sum(data_tmp20, axis=(0, 1, 2))

    return [data_tmp7, data_tmp15, data_tmp22]
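# A NumPy reference for the fused op above -- a minimal sketch for checking the math, not part
# of the kernel. It assumes data_1/data_3 are per-channel float32 tensors of shape (c,) and
# data_2, data_4, data_5, data_6, data_7 are NHWC float16 tensors; the helper name is
# illustrative only.
def _np_relu_grad_bn_double_update_grad(d1, d2, d3, d4, d5, d6, d7):
    import numpy as np  # local import so this optional reference stays self-contained
    n, h, w, _ = d7.shape
    # ReLU backward followed by the three per-channel reductions computed above.
    dy = np.where(d7 > 0, d5 + d6, 0).astype(np.float32)
    inv_m = 1.0 / (n * h * w)
    out1 = dy.sum(axis=(0, 1, 2))
    out2 = (dy * (d2.astype(np.float32) - d1 * inv_m)).sum(axis=(0, 1, 2))
    out3 = (dy * (d4.astype(np.float32) - d3 * inv_m)).sum(axis=(0, 1, 2))
    return [out1, out2, out3]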
def fused_relu_grad_bn_reduce_grad(data_1, data_2, data_3, data_4, data_5, data_6, data_7, data_8, data_9,
                                   layout='NHWC', target=utils.CUDA):
    """
    fused_relu_grad_bn_reduce_grad.

    Args:
        data_1~data_9: tvm.tensor.Tensor.
        layout: input layout, only 'NCHW', 'NHWC' supported

    Returns:
        tvm.tensor.Tensor.
    """
    if layout == "NCHW":
        # Transpose the NCHW feature-map inputs to NHWC so the per-channel tensors
        # broadcast along the last axis in the element-wise math below.
        data_7, data_8, data_9 = [
            topi.transpose(t, axes=(0, 2, 3, 1)) for t in (data_7, data_8, data_9)]
    elif layout != "NHWC":
        raise NotImplementedError('Layout not supported {}'.format(layout))

    # Per-channel scale: (data_4 * data_5) / (n * h * w).
    data_tmp1 = topi.multiply(data_4, data_5)
    n, h, w, c = data_9.shape
    data_tmp2 = topi.full_like(data_tmp1, 1.0 / (n * h * w))
    data_tmp3 = topi.multiply(data_tmp1, data_tmp2)

    # ReLU backward: pass data_8 through where the forward output data_9 was positive.
    data_tmp5 = topi.full_like(data_9, 0.0)
    data_tmp6 = topi.greater(data_9, data_tmp5)
    data_tmp7 = topi.where(data_tmp6, data_8, data_tmp5)
    data_tmp8 = topi.cast(data_tmp7, 'float32')

    # (n * h * w) * gradient - data_3.
    data_tmp9 = topi.full_like(data_tmp8, n * h * w)
    data_tmp10 = topi.multiply(data_tmp8, data_tmp9)
    data_tmp12 = topi.subtract(data_tmp10, data_3)

    # data_2 * (data_7 - data_6 / (n * h * w)) / data_1.
    data_tmp14 = topi.cast(data_7, 'float32')
    data_tmp15 = topi.multiply(data_6, data_tmp2)
    data_tmp17 = topi.subtract(data_tmp14, data_tmp15)
    data_tmp18 = topi.multiply(data_2, data_tmp17)
    data_tmp20 = topi.divide(data_tmp18, data_1)

    data_tmp21 = topi.subtract(data_tmp12, data_tmp20)
    data_tmp22 = topi.multiply(data_tmp3, data_tmp21)
    data_out = topi.cast(data_tmp22, 'float16')
    return data_out
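# A NumPy reference for the fused op above -- a sketch of the same arithmetic, assuming
# data_1~data_6 are per-channel float32 tensors broadcastable against the NHWC float16
# tensors data_7~data_9; the helper name is illustrative only.
def _np_relu_grad_bn_reduce_grad(d1, d2, d3, d4, d5, d6, d7, d8, d9):
    import numpy as np  # local import so this optional reference stays self-contained
    n, h, w, _ = d9.shape
    m = n * h * w
    dy = np.where(d9 > 0, d8, 0).astype(np.float32)
    # scale * (m * dy - d3 - d2 * (d7 - d6 / m) / d1), with scale = d4 * d5 / m.
    scale = d4 * d5 / m
    out = scale * (dy * m - d3 - d2 * (d7.astype(np.float32) - d6 / m) / d1)
    return out.astype(np.float16)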
def select_compute(condition, x1, x2):
    """select compute implementation"""
    shape = get_shape(x1)
    con_shape = get_shape(condition)
    num_dtype = x1.dtype
    bool_dtype = condition.dtype

    if num_dtype in ("int8", "uint8"):
        # int8/uint8 are computed in float32 and cast back at the end.
        x1_dtype = "float32"
        ones = akg.lang.cce.broadcast(tvm.const(VALUE_ONE, dtype="float32"), shape,
                                      output_dtype="float32")
        x1 = akg.lang.cce.cast_to(x1, "float32")
        x2 = akg.lang.cce.cast_to(x2, "float32")
    else:
        x1_dtype = num_dtype
        ones = akg.lang.cce.broadcast(tvm.const(VALUE_ONE, dtype=num_dtype), shape,
                                      output_dtype=num_dtype)

    # Convert the condition to the compute dtype of x1.
    if bool_dtype == "int8":
        if x1_dtype == "int32":
            condition_dtype = akg.lang.cce.ceil(condition)
        else:
            condition_dtype = akg.lang.cce.cast_to(condition, x1_dtype)
    else:
        if x1_dtype == "int32":
            condition_dtype = condition
        else:
            condition_dtype = akg.lang.cce.cast_to(condition, x1_dtype)

    if list(con_shape) != list(shape):
        condition_dtype = akg.lang.cce.broadcast(condition_dtype, shape)

    vinsn_support_dtype = ("float16", "float32")
    if utils.product_is_mini():
        vinsn_support_dtype = ("float16",)

    if num_dtype in vinsn_support_dtype:
        res = topi.where(condition_dtype, x1, x2)
    else:
        # For data types that are not supported by the vector instructions (vcmp and vsel),
        # using `topi.where` directly generates scalar instructions such as `cond ? x1 : x2`
        # in the .cce file, which is very inefficient. An equivalent formulation is used
        # instead: res = x1 * cond + x2 * (1 - cond).
        condition_opp = akg.lang.cce.vsub(ones, condition_dtype)
        temp_x = akg.lang.cce.vmul(x1, condition_dtype)
        temp_y = akg.lang.cce.vmul(x2, condition_opp)
        res = akg.lang.cce.vadd(temp_x, temp_y)

    if num_dtype in ("int8", "uint8"):
        res = akg.lang.cce.cast_to(res, num_dtype)
    return res
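# A NumPy sketch of the fallback used above for dtypes without vcmp/vsel support: select is
# rewritten as x1 * cond + x2 * (1 - cond), which matches np.where(cond, x1, x2) when cond
# contains only 0s and 1s. Helper name is illustrative only.
def _np_select(condition, x1, x2):
    import numpy as np  # local import so this optional reference stays self-contained
    cond = condition.astype(x1.dtype)
    return x1 * cond + x2 * (1 - cond)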
def fused_relu_grad(input1, input2, input3, c1, target=utils.CUDA):
    """
    fused_relu_grad.

    Args:
        input1 ~ input3: tvm.tensor.Tensor.
        c1: const.

    Returns:
        tvm.tensor.Tensor.
    """
    # Pass (input1 + input2) through where input3 exceeds c1, otherwise output c1.
    data_zero = topi.full_like(input3, c1)
    cmp_zero = topi.greater(input3, data_zero)
    data_add = topi.add(input1, input2)
    return topi.where(cmp_zero, data_add, data_zero)
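# A NumPy reference for fused_relu_grad -- a minimal sketch. In the op above c1 is both the
# comparison threshold and the fill value, so with c1 = 0.0 this is a plain ReLU backward on
# (input1 + input2). Helper name is illustrative only.
def _np_fused_relu_grad(x1, x2, x3, c1=0.0):
    import numpy as np  # local import so this optional reference stays self-contained
    zeros = np.full_like(x3, c1)
    return np.where(x3 > zeros, x1 + x2, zeros)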