コード例 #1
0
def test_weight_standardization_forward_backward(rng, w_shape, channel_axis, output_stat):
    """Compare F.weight_standardization with the reference implementation.

    Builds a random float32 input, runs a forward and a backward pass through
    ``F.weight_standardization`` and checks the forward outputs against
    ``ref_weight_standardization``. The backward pass is only executed (to
    make sure it runs); gradients are not compared here.

    Args:
        rng: random generator exposing ``randn`` (presumably a
            ``numpy.random.RandomState`` fixture -- TODO confirm with caller).
        w_shape: shape of the weight tensor to standardize.
        channel_axis: axis forwarded unchanged to both implementations.
        output_stat: when true, both implementations return a sequence of
            outputs (presumably the standardized weights plus statistics)
            instead of a single result.
    """
    # astype already returns a new array, so no extra np.array copy is needed.
    x_np = rng.randn(*w_shape).astype(np.float32)
    eps = 1e-05

    x = nn.Variable.from_numpy_array(x_np)
    output = F.weight_standardization(x, channel_axis, eps, output_stat)
    ref = ref_weight_standardization(x_np, channel_axis, eps, output_stat)

    if output_stat:
        # Sink all outputs so one forward/backward call drives every one of them.
        tmp = F.sink(*output)
        tmp.forward()
        tmp.backward()

        # zip() silently truncates; make a missing/extra output a failure.
        assert len(output) == len(ref)
        for o, r in zip(output, ref):
            assert o.shape == r.shape
            assert np.allclose(o.d, r, atol=1e-2, rtol=1e-5)
    else:
        output.forward()
        output.backward()

        # Same shape check as the output_stat branch performs per output.
        assert output.shape == ref.shape
        assert np.allclose(output.d, ref, atol=1e-2, rtol=1e-5)
コード例 #2
0
 def ws_callback(w):
     """Return ``w`` passed through ``F.weight_standardization``.

     ``channel_axis``, ``eps`` and ``output_stat`` are free variables resolved
     from the surrounding scope (not visible in this snippet).
     """
     standardized = F.weight_standardization(
         w, channel_axis, eps=eps, output_stat=output_stat)
     return standardized
コード例 #3
0
 def callback(x):
     """Return ``x`` weight-standardized with ``channel_axis`` set to ``dim``.

     ``dim`` is a free variable resolved from the surrounding scope.
     """
     result = F.weight_standardization(x, channel_axis=dim)
     return result
コード例 #4
0
 def ws_callback(w):
     """Weight-standardize ``w``; passed as the ``apply_w`` hook below.

     ``channel_axis``, ``eps`` and ``output_stat`` come from the surrounding
     scope (not visible in this snippet).
     """
     return F.weight_standardization(w, channel_axis,
                                     eps=eps, output_stat=output_stat)

 y = function(x, apply_w=ws_callback, **kwargs)