예제 #1
0
def test_module_api_hooks():
    """Check forward pre-hooks and forward hooks registered on every module.

    A pre-hook (adds 1 to every input) and a post-hook (adds 1 to the
    output in place) are attached to each module visited by ``net.apply``.
    The test then verifies:

    * each hook kind fires 4 times during one forward pass;
    * the batch-norm running means and the final output match a reference
      computation built directly from ``F.batch_norm``;
    * 8 handles were collected, and once all are removed a further forward
      pass leaves both call counters unchanged.
    """
    net = MyModule()
    pre_calls = 0
    post_calls = 0
    handles = []

    def bump_inputs(module, inputs):
        # Forward pre-hook: count the call and shift every input by one.
        nonlocal pre_calls
        pre_calls += 1
        return tuple(inp + 1 for inp in inputs)

    def bump_outputs(module, inputs, outputs):
        # Forward hook: count the call and shift the output by one, in place.
        nonlocal post_calls
        post_calls += 1
        outputs += 1
        return outputs

    def attach_pre(module):
        handles.append(module.register_forward_pre_hook(bump_inputs))

    def attach_post(module):
        handles.append(module.register_forward_hook(bump_outputs))

    # All pre-hooks are registered first, then all post-hooks.
    net.apply(attach_pre)
    net.apply(attach_post)

    shape = (1, 4, 1, 1)
    x = tensor(np.zeros(shape, dtype=np.float32))
    y = net(x)

    assert pre_calls == 4
    assert post_calls == 4

    # Reference computation for the inner batch norm.
    # NOTE(review): the +3 / +2 offsets presumably mirror how the hook
    # increments accumulate through MyModule's nesting -- confirm against
    # the MyModule definition.
    expected_inner_mean = Parameter(np.zeros(shape), dtype=np.float32)
    inner_var = Parameter(np.ones(shape), dtype=np.float32)
    inner_out = F.batch_norm(
        x + 3, expected_inner_mean, inner_var, training=True
    )
    np.testing.assert_allclose(
        net.i.bn.running_mean.numpy(), expected_inner_mean.numpy()
    )

    # Reference computation for the outer batch norm.
    expected_outer_mean = Parameter(np.zeros(shape), dtype=np.float32)
    outer_var = Parameter(np.ones(shape), dtype=np.float32)
    outer_out = F.batch_norm(
        inner_out + 3, expected_outer_mean, outer_var, training=True
    )
    np.testing.assert_allclose(
        net.bn.running_mean.numpy(), expected_outer_mean.numpy()
    )
    np.testing.assert_allclose((outer_out + 2).numpy(), y.numpy())

    # One pre-hook handle and one post-hook handle per visited module.
    assert len(handles) == 8
    for handle in handles:
        handle.remove()

    # With every hook removed, another forward pass must not touch the
    # counters.
    net(x)
    assert pre_calls == 4
    assert post_calls == 4
예제 #2
0
def test_set_value():
    """Check that ``Parameter.set_value`` replaces the parameter's data.

    A parameter is created from one random float32 array and then
    overwritten with a second array of the same shape; ``param.numpy()``
    must match the second array within ``atol=5e-6``.
    """
    same_shape = (2, 3)
    initial = np.random.random(same_shape).astype(np.float32)
    replacement = np.random.random(same_shape).astype(np.float32)

    param = Parameter(initial)
    param.set_value(replacement)

    def assert_holds_replacement():
        np.testing.assert_allclose(param.numpy(), replacement, atol=5e-6)

    assert_holds_replacement()

    # Differently shaped value reserved for the pending negative test below.
    mismatched = np.random.random((3, 3)).astype(np.float32)
    # TODO: add this
    # with pytest.raises(ValueError):
    #     param.set_value(mismatched)
    assert_holds_replacement()