import numpy as np

import tvm
import tvm.testing
import tvm.topi.testing
from tvm import te, topi


def batch_norm_fwd(N, C, H, W, dtype="float32"):
    dshape = (N, C, H, W)
    oshape = (C,)
    bshape = (1, C, 1, 1)
    sshape = (1,)
    data = te.placeholder(dshape, name="data", dtype=dtype)
    scale = te.placeholder(oshape, name="scale", dtype=dtype)
    bias = te.placeholder(oshape, name="bias", dtype=dtype)
    running_mean = te.placeholder(oshape, name="running_mean", dtype=dtype)
    running_var = te.placeholder(oshape, name="running_var", dtype=dtype)
    eps = te.placeholder(sshape, name="eps", dtype=dtype)
    momentum = te.placeholder(sshape, name="momentum", dtype=dtype)
    axis = (0, 2, 3)
    num_ele = dshape[0] * dshape[2] * dshape[3]
    frac_num_ele = 1.0 / num_ele
    # compute batch mean
    mean_sum = topi.sum(data, axis, keepdims=True)
    saved_mean = topi.multiply(mean_sum, frac_num_ele)
    # compute batch root variance: sqrt(var + eps)
    var_sub = topi.subtract(data, saved_mean)
    var_mul = topi.multiply(var_sub, var_sub)
    var_sum = topi.sum(var_mul, axis, keepdims=True)
    var = topi.multiply(var_sum, frac_num_ele)
    output_add = topi.add(var, eps)
    saved_rvars = topi.sqrt(output_add)
    # compute output: scale * (x - mean) / sqrt(var + eps) + bias
    output_sub = topi.subtract(data, saved_mean)
    output_norm = topi.divide(output_sub, saved_rvars)
    scale_board = topi.reshape(scale, bshape)
    bias_board = topi.reshape(bias, bshape)
    output = topi.add(topi.multiply(output_norm, scale_board), bias_board)
    # reshape saved_rvars
    saved_rvars = topi.reshape(saved_rvars, oshape)
    # update running mean: (1 - momentum) * running_mean + momentum * batch_mean
    running_mean_mul1 = topi.multiply(running_mean, topi.subtract(1.0, momentum))
    running_mean_mul2 = topi.multiply(topi.reshape(saved_mean, oshape), momentum)
    running_mean_out = topi.add(running_mean_mul1, running_mean_mul2)
    # update running var: (1 - momentum) * running_var + momentum * batch_var
    saved_var_mul1 = topi.multiply(running_var, topi.subtract(1.0, momentum))
    saved_var_mul2 = topi.multiply(topi.reshape(var, oshape), momentum)
    running_var_out = topi.add(saved_var_mul1, saved_var_mul2)
    # reshape saved_mean
    saved_mean = topi.reshape(saved_mean, oshape)
    return [
        data, scale, bias, running_mean, running_var, momentum, eps, output,
        saved_mean, saved_rvars, running_mean_out, running_var_out
    ]
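# The sketch below is an assumed usage example, not part of the original
# snippet: it lowers batch_norm_fwd with a default TE schedule on the "llvm"
# target and feeds random buffers, just to show the calling convention
# (the first 7 returned tensors are inputs, the last 5 are outputs).
def run_batch_norm_fwd_sketch():
    tensors = batch_norm_fwd(2, 3, 4, 4)
    inputs, outputs = tensors[:7], tensors[7:]
    s = te.create_schedule([t.op for t in outputs])
    func = tvm.build(s, tensors, "llvm")
    ctx = tvm.cpu(0)
    args = [
        tvm.nd.array(
            np.random.uniform(0.1, 1.0, [int(d) for d in t.shape]).astype("float32"),
            ctx)
        for t in inputs
    ]
    args += [tvm.nd.empty([int(d) for d in t.shape], "float32", ctx) for t in outputs]
    # output, saved_mean, saved_rvars, and both running stats are filled in place
    func(*args)
    return args[7:]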
from tvm.relay.quantize import QAnnotateKind


def simulated_quantize_compute(attrs, inputs, out_type):
    """Compiler for simulated_quantize."""
    assert len(inputs) == 4
    assert attrs.sign
    assert attrs.rounding == "round"

    data, scale, clip_min, clip_max = inputs

    if attrs.kind == QAnnotateKind.IDENTITY:
        return [topi.identity(data)]

    # simulate rounding error
    scaled_data = topi.divide(data, scale)
    clipped_data = topi.maximum(topi.minimum(scaled_data, clip_max), clip_min)
    round_data = topi.round(clipped_data)

    # recover data
    rdata = topi.multiply(round_data, scale)
    return [rdata]
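# A NumPy illustration (an assumption for this document, not from the
# original code) of the round trip simulated above: divide by scale, clip to
# the integer range, round, and multiply back, so the float result carries
# exactly the error a signed 8-bit tensor would have.
def simulated_quantize_numpy_demo():
    data = np.array([0.034, -1.3, 0.515], dtype="float32")
    scale, clip_min, clip_max = np.float32(0.01), -127.0, 127.0
    scaled = data / scale                       # map onto the integer grid
    clipped = np.clip(scaled, clip_min, clip_max)
    rounded = np.round(clipped)                 # this step loses information
    rdata = rounded * scale                     # back to the float domain
    return rdata                                # ~[0.03, -1.27, 0.52]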
def batch_norm_bwd(N, C, H, W, dtype="float32"):
    dshape = (N, C, H, W)
    oshape = (C,)
    bshape = (1, C, 1, 1)
    sshape = (1,)
    data = te.placeholder(dshape, name="data", dtype=dtype)
    scale = te.placeholder(oshape, name="scale", dtype=dtype)
    saved_mean = te.placeholder(oshape, name="saved_mean", dtype=dtype)
    saved_var = te.placeholder(oshape, name="saved_var", dtype=dtype)
    eps = te.placeholder(sshape, name="eps", dtype=dtype)
    grad_output = te.placeholder(dshape, name="grad_output", dtype=dtype)
    axis = (0, 2, 3)
    num_ele = dshape[0] * dshape[2] * dshape[3]
    frac_num_ele = 1.0 / num_ele
    # recompute batch mean and variance from data
    mean_sum = topi.sum(data, axis, keepdims=True)
    mean = topi.multiply(mean_sum, frac_num_ele)
    var_sub = topi.subtract(data, mean)
    var_mul = topi.multiply(var_sub, var_sub)
    var_sum = topi.sum(var_mul, axis, keepdims=True)
    var = topi.multiply(var_sum, frac_num_ele)
    var_eps = topi.add(var, eps)
    output_sqrt = topi.sqrt(var_eps)
    # normalized input and upstream gradient scaled by gamma
    x_norm = topi.subtract(data, mean)
    x_hat = topi.divide(x_norm, output_sqrt)
    dx_hat = topi.multiply(grad_output, topi.reshape(scale, bshape))
    # grad_input = (1/N) / sqrt(var + eps)
    #     * (N * dx_hat - sum(dx_hat) - x_hat * sum(dx_hat * x_hat))
    grad_input_sum1 = topi.sum(dx_hat * x_hat, axis, keepdims=True)
    grad_input_sum2 = topi.sum(dx_hat, axis, keepdims=True)
    grad_input_left = topi.divide(frac_num_ele, topi.sqrt(var_eps))
    grad_input_right1 = topi.subtract(topi.multiply(dx_hat, num_ele), grad_input_sum2)
    grad_input_right2 = topi.multiply(x_hat, grad_input_sum1)
    grad_input = topi.multiply(
        grad_input_left, topi.subtract(grad_input_right1, grad_input_right2))
    # compute grad_scale and grad_bias
    grad_scale = topi.sum(grad_output * x_hat, axis)
    grad_bias = topi.sum(grad_output, axis)
    return [
        data, scale, saved_mean, saved_var, eps, grad_output, grad_input,
        grad_scale, grad_bias
    ]
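# A hedged NumPy reference (assumed, not from the original file) for the
# same closed-form gradient that batch_norm_bwd builds in TE; useful as a
# cross-check when lowering the TE version.
def batch_norm_bwd_ref(data, scale, eps, grad_output):
    axis = (0, 2, 3)
    n = data.shape[0] * data.shape[2] * data.shape[3]
    mean = data.mean(axis=axis, keepdims=True)
    var = data.var(axis=axis, keepdims=True)
    x_hat = (data - mean) / np.sqrt(var + eps)
    dx_hat = grad_output * scale.reshape(1, -1, 1, 1)
    grad_input = (1.0 / n) / np.sqrt(var + eps) * (
        n * dx_hat
        - dx_hat.sum(axis=axis, keepdims=True)
        - x_hat * (dx_hat * x_hat).sum(axis=axis, keepdims=True))
    grad_scale = (grad_output * x_hat).sum(axis=axis)
    grad_bias = grad_output.sum(axis=axis)
    return grad_input, grad_scale, grad_bias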
def test_complex_reduce():
    in_shape = (2, 3)
    dtype = "float32"
    axis = 0
    keepdims = False
    A = te.placeholder(shape=in_shape, name="A", dtype=dtype)
    B = topi.sum(A, axis=axis, keepdims=keepdims)
    C = topi.add(B, B)
    D = topi.multiply(B, B)
    E = topi.add(C, D)
    for device, ctx in tvm.testing.enabled_targets():
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            s = tvm.topi.testing.get_reduce_schedule(device)(E)
        foo = tvm.build(s, [A, E], device, name="sum")
        in_npy = np.random.uniform(-1, 1, size=in_shape).astype(dtype)
        sum_npy = in_npy.sum(axis=axis, keepdims=keepdims)
        out_npy = sum_npy * 2 + sum_npy * sum_npy
        data_tvm = tvm.nd.array(in_npy, ctx=ctx)
        out_tvm = tvm.nd.empty(shape=out_npy.shape, ctx=ctx, dtype=dtype)
        foo(data_tvm, out_tvm)
        tvm.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1e-3, 1e-3)
def multiply_broadcast_compute(input_a, input_b):
    """Call the multiply op from topi"""
    return topi.multiply(input_a, input_b)
def multiply_packed(cfg, lhs, rhs):
    # cfg (e.g. an autotvm config) is accepted for template compatibility
    # but unused by the compute itself
    return topi.multiply(lhs, rhs)
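# A minimal end-to-end check (shapes, target, and tolerance are assumptions
# for illustration) that exercises the broadcast multiply wrapper above.
def run_multiply_broadcast_sketch():
    a = te.placeholder((4, 1), name="a")
    b = te.placeholder((1, 8), name="b")
    c = multiply_broadcast_compute(a, b)    # broadcasts to (4, 8)
    s = te.create_schedule(c.op)
    func = tvm.build(s, [a, b, c], "llvm")
    ctx = tvm.cpu(0)
    a_np = np.random.uniform(size=(4, 1)).astype("float32")
    b_np = np.random.uniform(size=(1, 8)).astype("float32")
    c_tvm = tvm.nd.empty((4, 8), "float32", ctx)
    func(tvm.nd.array(a_np, ctx), tvm.nd.array(b_np, ctx), c_tvm)
    tvm.testing.assert_allclose(c_tvm.asnumpy(), a_np * b_np, rtol=1e-5)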