def test_module_api_hooks():
    """Check that forward pre/post hooks fire on every submodule and can be removed.

    A forward pre-hook and a forward post-hook are attached to every module
    reached by ``apply``:
    - The pre-hook adds 1 to each input.
    - The post-hook adds 1 to the output.

    The hooked forward pass is then replayed by hand with ``F.batch_norm`` to
    check both batch-norm running means and the final output. After every
    handle is removed, another forward pass must not call any hook.
    """
    model = MyModule()
    calls = {"pre": 0, "post": 0}
    handles = []

    def bump_inputs(module, inputs):
        # Record the call, then feed every input forward shifted by +1.
        calls["pre"] += 1
        return tuple(inp + 1 for inp in inputs)

    def bump_output(module, inputs, outputs):
        # Record the call, then shift the module's output by +1.
        calls["post"] += 1
        outputs += 1
        return outputs

    # All pre-hooks are registered first, then all post-hooks.
    model.apply(lambda m: handles.append(m.register_forward_pre_hook(bump_inputs)))
    model.apply(lambda m: handles.append(m.register_forward_hook(bump_output)))

    shape = (1, 4, 1, 1)
    x = tensor(np.zeros(shape, dtype=np.float32))
    out = model(x)

    # Each hook should have fired once per module in the tree (4 modules).
    assert calls["pre"] == 4
    assert calls["post"] == 4

    def reference_bn(inp):
        # Fresh running stats (zeros mean, ones var), training-mode batch norm.
        # Returns the normalized output and the (updated) running mean.
        running_mean = Parameter(np.zeros(shape), dtype=np.float32)
        running_var = Parameter(np.ones(shape), dtype=np.float32)
        result = F.batch_norm(inp, running_mean, running_var, training=True)
        return result, running_mean

    # NOTE(review): the +3 / +2 offsets below presumably count the hooks
    # crossed on the way in/out. This assumes MyModule is
    # model -> i -> i.bn plus model.bn; confirm against MyModule's definition.
    inner_out, inner_mean = reference_bn(x + 3)
    np.testing.assert_allclose(
        model.i.bn.running_mean.numpy(),
        inner_mean.numpy(),
    )

    outer_out, outer_mean = reference_bn(inner_out + 3)
    np.testing.assert_allclose(
        model.bn.running_mean.numpy(),
        outer_mean.numpy(),
    )
    np.testing.assert_allclose((outer_out + 2).numpy(), out.numpy())

    # Two handles (pre + post) per module.
    assert len(handles) == 8
    for handle in handles:
        handle.remove()

    # With every hook detached, the counters must stay where they were.
    model(x)
    assert calls["pre"] == 4
    assert calls["post"] == 4
def test_set_value():
    """Check that ``Parameter.set_value`` replaces the stored data with a same-shaped array."""
    v0 = np.random.random((2, 3)).astype(np.float32)
    param = Parameter(v0)

    v1 = np.random.random((2, 3)).astype(np.float32)
    param.set_value(v1)
    np.testing.assert_allclose(param.numpy(), v1, atol=5e-6)

    # TODO: once set_value rejects shape mismatches, assert that calling it
    # with a (3, 3) array raises ValueError and that param still holds v1.