def test_bind_stateful(self):
    """Bound modules and ``nn.apply`` must agree on outputs and batch stats.

    Builds a module with a ``BatchNorm`` between two ``Dense`` layers and
    checks three things:

    * the output of ``nn.init_with_output``, the output of a module bound
      with mutable ``batch_stats``, and the output of ``nn.apply`` with
      mutable ``batch_stats`` are identical;
    * the ``batch_stats`` collection returned by ``nn.apply`` has the same
      leaves as the one held by the bound module.
    """

    class Foo(nn.Module):

        def setup(self):
            self.a = nn.Dense(3)
            self.bn = nn.BatchNorm()
            self.b = nn.Dense(1)

    def f(foo, x):
        x = foo.a(x)
        # use_running_average=False: normalize with statistics of this call,
        # which is what makes the 'batch_stats' collection change.
        x = foo.bn(x, use_running_average=False)
        return foo.b(x)

    foo = Foo()
    x = jnp.ones((4,))
    f_init = nn.init_with_output(f, foo)
    y1, variables = f_init(random.PRNGKey(0), x)
    foo_b = foo.bind(variables, mutable='batch_stats')
    y2 = f(foo_b, x)
    y3, new_state = nn.apply(f, foo, mutable='batch_stats')(variables, x)
    # Element-wise exact comparison; assertEqual on arrays only works when
    # the array happens to hold a single element.
    np.testing.assert_array_equal(y1, y2)
    np.testing.assert_array_equal(y2, y3)
    bs_1 = new_state['batch_stats']
    bs_2 = foo_b.variables['batch_stats']
    # jax.tree_util.tree_leaves is the stable API; the top-level
    # jax.tree_leaves alias is deprecated / removed in recent JAX releases.
    leaves_1 = jax.tree_util.tree_leaves(bs_1)
    leaves_2 = jax.tree_util.tree_leaves(bs_2)
    # zip() would silently drop extra leaves, so check the counts first.
    self.assertEqual(len(leaves_1), len(leaves_2))
    for leaf_1, leaf_2 in zip(leaves_1, leaves_2):
        np.testing.assert_allclose(leaf_1, leaf_2)
def eval(params, images, z, z_rng):
    """Evaluate the VAE on a batch and collect metrics plus image grids.

    This is a closure: it reads ``self.image_size``, ``self.model`` and
    ``self.compute_metrics`` from the enclosing scope.

    Args:
      params: value placed under the ``'params'`` key of the variables
        passed to ``nn.apply``.
      images: input batch; its first 8 entries are reshaped to
        ``(-1, self.image_size, self.image_size, 3)``.
      z: latent inputs handed to ``vae.generate``.
      z_rng: forwarded as the second argument of the VAE call (presumably
        a PRNG key -- confirm against the model).

    Returns:
      ``(metrics, comparison, generate_images)``: the result of
      ``self.compute_metrics``, the first 8 inputs followed (along axis 0)
      by their first 8 reconstructions, and the generated samples reshaped
      to image form.
    """

    def eval_model(vae):
        image_shape = (-1, self.image_size, self.image_size, 3)
        recon_images, mean, logvar = vae(images, z_rng)
        originals = images[:8].reshape(image_shape)
        reconstructions = recon_images[:8].reshape(image_shape)
        comparison = jnp.concatenate([originals, reconstructions])
        samples = vae.generate(z).reshape(image_shape)
        metrics = self.compute_metrics(recon_images, images, mean, logvar)
        return metrics, comparison, samples

    variables = {'params': params}
    return nn.apply(eval_model, self.model())(variables)
def test_functional_apply(self):
    """``nn.apply`` with initialized variables reproduces the init output.

    Initializes a two-layer ``Dense`` module through
    ``nn.init_with_output`` and checks that calling the same function via
    ``nn.apply`` on the returned variables yields an identical output.
    """

    class Foo(nn.Module):

        def setup(self):
            self.a = nn.Dense(3)
            self.b = nn.Dense(1)

    def f(foo, x):
        x = foo.a(x)
        return foo.b(x)

    foo = Foo()
    x = jnp.ones((4,))
    f_init = nn.init_with_output(f, foo)
    f_apply = nn.apply(f, foo)
    y1, variables = f_init(random.PRNGKey(0), x)
    y2 = f_apply(variables, x)
    # Exact element-wise comparison; assertEqual on arrays relies on the
    # output having exactly one element.
    np.testing.assert_array_equal(y1, y2)