def test_weight_standardization_forward_backward(rng, w_shape, channel_axis, output_stat):
    """Check F.weight_standardization's forward values against a NumPy reference.

    Both forward and backward are executed, so the backward pass must run
    without raising. Only the forward values (``.d``) are compared against
    ``ref_weight_standardization``; gradients are not checked here.

    Args:
        rng: random generator exposing ``randn`` (presumably a
            ``np.random.RandomState`` fixture -- confirm against conftest).
        w_shape: shape of the random weight tensor to standardize.
        channel_axis: axis forwarded unchanged to both the function under
            test and the reference implementation.
        output_stat: when truthy, the function returns several outputs; each
            one is compared with the matching reference output.
    """
    # astype() already returns a fresh array; no extra np.array() copy needed.
    # Named x_np rather than `input` to avoid shadowing the builtin.
    x_np = rng.randn(*w_shape).astype(np.float32)
    eps = 1e-05
    x = nn.Variable.from_numpy_array(x_np)
    output = F.weight_standardization(x, channel_axis, eps, output_stat)
    ref = ref_weight_standardization(x_np, channel_axis, eps, output_stat)

    if output_stat:
        # Several outputs: sink them into a single node so that one
        # forward/backward call drives every branch of the graph.
        tmp = F.sink(*output)
        tmp.forward()
        tmp.backward()
        # zip() would silently drop unmatched entries; require equal counts.
        assert len(output) == len(ref)
        for o, r in zip(output, ref):
            assert o.shape == r.shape
            assert np.allclose(o.d, r, atol=1e-2, rtol=1e-5)
    else:
        output.forward()
        output.backward()
        # Check shape explicitly, as the multi-output branch does; np.allclose
        # alone would accept broadcast-compatible but different shapes.
        assert output.shape == ref.shape
        assert np.allclose(output.d, ref, atol=1e-2, rtol=1e-5)
def ws_callback(w):
    """Return ``w`` passed through ``F.weight_standardization``.

    ``channel_axis``, ``eps`` and ``output_stat`` are not parameters of this
    callback; they are free names resolved from the surrounding scope when
    the callback is invoked.
    """
    standardized = F.weight_standardization(
        w,
        channel_axis=channel_axis,
        eps=eps,
        output_stat=output_stat,
    )
    return standardized
def callback(x):
    """Return ``x`` passed through ``F.weight_standardization`` along ``dim``.

    ``dim`` is a free name resolved from the surrounding scope at call time.
    No ``eps``/``output_stat`` are passed, so the function's own defaults
    for those arguments apply.
    """
    result = F.weight_standardization(x, channel_axis=dim)
    return result
def ws_callback(w): return F.weight_standardization( w, channel_axis, eps=eps, output_stat=output_stat) y = function(x, apply_w=ws_callback, **kwargs)