def fused_relu_grad_bn_double_update_grad(data_1, data_2, data_3, data_4, data_5, data_6, data_7, layout='NHWC'):
    """
    fused_relu_grad_bn_double_update_grad.

    Masks the summed incoming gradient (data_5 + data_6) with a ReLU mask taken
    from data_7 (positions where data_7 > 0), casts it to float32, and computes
    three reductions over the N, H and W axes.

    Args:
        data_1, data_3: tvm.tensor.Tensor. Each is multiplied by 1 / (N*H*W) and
            broadcast to the 4-D shape. They must be broadcastable against the
            per-channel sum of shape (C,) (presumably per-channel sums, so the
            scaling yields means -- confirm against the caller).
        data_2, data_4, data_5, data_6, data_7: tvm.tensor.Tensor. 4-D tensors
            in the layout given by `layout`.
        layout: input layout, only 'NCHW', 'NHWC' supported. NCHW inputs are
            transposed to NHWC before computing.

    Returns:
        list of three tvm.tensor.Tensor, each reduced over N, H, W:
            [sum(g),
             sum(g * (float32(data_2) - data_1 / (N*H*W))),
             sum(g * (float32(data_4) - data_3 / (N*H*W)))]
        where g = float32(where(data_7 > 0, data_5 + data_6, 0)).

    Raises:
        NotImplementedError: if `layout` is neither 'NCHW' nor 'NHWC'.
    """
    if layout == "NCHW":
        # Rebind the tensors themselves; reassigning a `for` loop variable
        # would silently discard the transposed result.
        data_2, data_4, data_5, data_6, data_7 = [
            topi.transpose(t, axes=(0, 2, 3, 1))
            for t in (data_2, data_4, data_5, data_6, data_7)]
    elif layout != "NHWC":
        raise NotImplementedError(
            'Layout not supported {} '.format(layout))

    # ReLU-masked gradient: keep data_5 + data_6 where data_7 > 0, else 0.
    zero = topi.full_like(data_7, 0.0)
    relu_mask = topi.greater(data_7, zero)
    grad = topi.where(relu_mask, topi.add(data_5, data_6), zero)
    grad = topi.cast(grad, 'float32')
    # Layout is NHWC here, so axes (0, 1, 2) are N, H, W.
    grad_sum = topi.sum(grad, axis=(0, 1, 2))

    n, h, w, _ = data_7.shape
    inv_count = topi.full_like(grad_sum, 1.0 / (n * h * w))

    def centered_grad_sum(x, stat):
        # sum over N, H, W of grad * (float32(x) - stat / (N*H*W))
        x_f32 = topi.cast(x, 'float32')
        stat_scaled = topi.broadcast_to(topi.multiply(stat, inv_count), x_f32.shape)
        centered = topi.subtract(x_f32, stat_scaled)
        return topi.sum(topi.multiply(grad, centered), axis=(0, 1, 2))

    return [grad_sum,
            centered_grad_sum(data_2, data_1),
            centered_grad_sum(data_4, data_3)]
def fused_relu_grad_bn_reduce_grad(data_1, data_2, data_3, data_4, data_5, data_6, data_7, data_8, data_9, layout='NHWC', target=utils.CUDA):
    """
    fused_relu_grad_bn_reduce_grad.

    Applies a ReLU mask (data_9 > 0) to data_8 and combines the result with
    the remaining operands into a float16 tensor:

        count = N * H * W
        scale = data_4 * data_5 / count
        g     = float32(where(data_9 > 0, data_8, 0))
        out   = float16(scale * (g * count - data_3
                                 - data_2 * (float32(data_7) - data_6 / count) / data_1))

    Args:
        data_1 ~ data_6: tvm.tensor.Tensor. Combined with the 4-D tensors by
            broadcasting; presumably per-channel (shape (C,)) batch-norm
            statistics/parameters, which requires the channel axis to be last
            -- confirm against the caller.
        data_7, data_8, data_9: tvm.tensor.Tensor. 4-D tensors in `layout`.
        layout: input layout, only 'NCHW', 'NHWC' supported. NCHW inputs are
            transposed to NHWC for the computation and the result is
            transposed back, so the output uses the same layout as the inputs.
        target: not used in this function body.

    Returns:
        tvm.tensor.Tensor: float16 4-D tensor in the input layout.

    Raises:
        NotImplementedError: if `layout` is neither 'NCHW' nor 'NHWC'.
    """
    if layout == "NCHW":
        # Rebind the tensors themselves; reassigning a `for` loop variable
        # would silently discard the transposed result and leave the channel
        # axis in position 1, breaking the channel-last broadcasting below.
        data_7, data_8, data_9 = [
            topi.transpose(t, axes=(0, 2, 3, 1))
            for t in (data_7, data_8, data_9)]
    elif layout != "NHWC":
        raise NotImplementedError(
            'Layout not supported {} '.format(layout))

    n, h, w, _ = data_9.shape
    count = n * h * w

    # scale = data_4 * data_5 / count
    data_tmp1 = topi.multiply(data_4, data_5)
    inv_count = topi.full_like(data_tmp1, 1.0 / count)
    scale = topi.multiply(data_tmp1, inv_count)

    # ReLU-masked gradient: keep data_8 where data_9 > 0, else 0.
    zero = topi.full_like(data_9, 0.0)
    relu_mask = topi.greater(data_9, zero)
    grad = topi.cast(topi.where(relu_mask, data_8, zero), 'float32')

    # g * count - data_3
    grad_scaled = topi.multiply(grad, topi.full_like(grad, count))
    term_a = topi.subtract(grad_scaled, data_3)

    # data_2 * (float32(data_7) - data_6 / count) / data_1
    x_f32 = topi.cast(data_7, 'float32')
    centered = topi.subtract(x_f32, topi.multiply(data_6, inv_count))
    term_b = topi.divide(topi.multiply(data_2, centered), data_1)

    result = topi.multiply(scale, topi.subtract(term_a, term_b))
    data_out = topi.cast(result, 'float16')
    if layout == "NCHW":
        # Return the result in the caller's layout.
        data_out = topi.transpose(data_out, axes=(0, 3, 1, 2))
    return data_out
def fused_relu_grad(input1, input2, input3, c1, target=utils.CUDA):
    """
    fused_relu_grad.

    Computes, element-wise, where(input3 > c1, input1 + input2, c1).

    Args:
        input1 ~ input3: tvm.tensor.Tensor.
        c1: const. Used both as the comparison threshold and as the value
            written where input3 > c1 does not hold, so the result is a plain
            ReLU-gradient mask only when c1 == 0.
        target: not used in this function body.

    Returns:
        tvm.tensor.Tensor: a single tensor (not a list).
    """
    # Constant tensor with input3's shape/dtype, filled with c1.
    fill = topi.full_like(input3, c1)
    keep = topi.greater(input3, fill)
    grad_sum = topi.add(input1, input2)
    return topi.where(keep, grad_sum, fill)