def fused_bn_update(input1, input2, input3, input4, dtype, c1, c2, c3, c4):
    """
    Fused elementwise update that produces three tensors.

    With ``m = input1 * c1`` and ``v = input2 * c1 - m * m`` the outputs are
    ``rsqrt(v + c2)``, ``c4 * (input3 - v * c3)`` and ``c4 * (input4 - m)``.

    NOTE(review): the name suggests a batch-norm statistics update (m as the
    mean, v as the variance) -- confirm against the callers.

    Args:
        input1 ~ input4: tvm.tensor.Tensor.
        dtype: dtype used to build the four constants.
        c1 ~ c4: scalar values wrapped into tvm constants of ``dtype``.

    Returns:
        Tuple of three tvm.tensor.Tensor, in the order listed above.
    """
    scale = tvm.const(c1, dtype)
    shift = tvm.const(c2, dtype)
    factor = tvm.const(c3, dtype)
    rate = tvm.const(c4, dtype)

    # v = input2 * c1 - (input1 * c1) ** 2
    scaled_in2 = topi.multiply(input2, scale)
    scaled_in1 = topi.multiply(input1, scale)
    scaled_in1_sq = topi.multiply(scaled_in1, scaled_in1)
    spread = topi.subtract(scaled_in2, scaled_in1_sq)

    # First output: rsqrt(v + c2).
    shifted = topi.add(spread, shift)
    inv_root = topi.rsqrt(shifted)

    # Second output: c4 * (input3 - v * c3).
    delta3 = topi.subtract(input3, topi.multiply(spread, factor))
    out_a = topi.multiply(rate, delta3)

    # Third output: c4 * (input4 - input1 * c1).
    delta4 = topi.subtract(input4, scaled_in1)
    out_b = topi.multiply(rate, delta4)

    return (inv_root, out_a, out_b)
def fused_mul_div_rsqrt_mul_isfinite_red(input1, input2, out_dtype):
    """
    Fused multiply / divide / rsqrt / multiply with an isfinite reduction.

    Builds ``d = 1 / (input2 * input2)``, ``r = rsqrt(d)`` and
    ``p = input1 * r``, then reduces ``isfinite(p)`` with ``topi.all``.

    Args:
        input1: tvm.tensor.Tensor.
        input2: tvm.tensor.Tensor.
        out_dtype: dtype the three value outputs are cast to when it differs
            from the dtype of ``p``.

    Returns:
        List of four tvm.tensor.Tensor: ``[all(isfinite(p)), p, r, d]``.
        The reduction is always taken on ``p`` before any cast.
    """
    squared = topi.multiply(input2, input2)
    reciprocal = topi.divide(1, squared)
    inv_root = topi.rsqrt(reciprocal)
    product = topi.multiply(input1, inv_root)

    finite_mask = topi.isfinite(product)
    all_finite = topi.all(finite_mask)

    # Nothing to convert when the computed dtype already matches.
    if product.dtype == out_dtype:
        return [all_finite, product, inv_root, reciprocal]

    converted = [
        topi.cast(tensor, out_dtype)
        for tensor in (product, inv_root, reciprocal)
    ]
    return [all_finite] + converted
def _sqrt(data):
    """
    Compute sqrt(data), choosing the path by product type.

    When ``utils.product_is_mini()`` is true, the result is built as
    ``data * r`` where ``r`` starts from ``topi.rsqrt(data)`` and is passed
    through ``_newton_iter`` three times. Otherwise ``topi.sqrt`` is used
    directly.
    """
    if not utils.product_is_mini():
        return topi.sqrt(data)

    refine_steps = 3
    inv_root = topi.rsqrt(data)
    for _ in range(refine_steps):
        inv_root = _newton_iter(data, inv_root)

    # data * rsqrt(data) == sqrt(data)
    return topi.multiply(data, inv_root)
def _vrsqrt_newton(num_to_vrsqrt):
    """
    Compute rsqrt(num_to_vrsqrt) refined with Newton's method.

    The initial value comes from ``topi.rsqrt``; it is then passed through
    ``_newton`` three times, each time together with the original input.
    """
    newton_rounds = 3
    estimate = topi.rsqrt(num_to_vrsqrt)
    for _ in range(newton_rounds):
        estimate = _newton(estimate, num_to_vrsqrt)
    return estimate
def sqrt_mini_newton_iter_impl(x):
    """
    Compute sqrt(x) on mini using Newton's iteration.

    The starting estimate is ``1 / rsqrt(x)``; it is then refined three times
    with ``x(n+1) = 1/2 * (x(n) + a / x(n))``, each step being its own
    ``tvm.compute`` stage named ``x_sqrt_<step>``.
    """
    # mini supports the rsqrt instruction, but not the sqrt instruction,
    # so the first estimate is derived from rsqrt.
    estimate = topi.divide(1, topi.rsqrt(x))

    one_half = tvm.const(0.5, x.dtype)
    out_shape = x.shape

    def _refine(prev, step):
        """Return one Newton step built on top of the estimate ``prev``."""
        return tvm.compute(
            out_shape,
            lambda *idx: one_half * (prev(*idx) + x(*idx) / prev(*idx)),
            name="x_sqrt_%s" % step,
        )

    for step in range(3):
        estimate = _refine(estimate, step)
    return estimate