Ejemplo n.º 1
0
  def test_bind_stateful(self):
    """`Module.bind` with a mutable collection agrees with init/apply.

    Builds a module containing a BatchNorm layer, then checks that:
      * calling the bound module gives the same output as
        `nn.init_with_output` and `nn.apply`, and
      * the 'batch_stats' written through the bound module
        (`bind(..., mutable='batch_stats')`) match the 'batch_stats'
        returned by `nn.apply(..., mutable='batch_stats')`.
    """
    class Foo(nn.Module):
      def setup(self):
        self.a = nn.Dense(3)
        self.bn = nn.BatchNorm()
        self.b = nn.Dense(1)

    def f(foo, x):
      x = foo.a(x)
      # With use_running_average=False BatchNorm updates the 'batch_stats'
      # collection, which is read back and compared below.
      x = foo.bn(x, use_running_average=False)
      return foo.b(x)

    foo = Foo()
    x = jnp.ones((4,))
    f_init = nn.init_with_output(f, foo)
    y1, variables = f_init(random.PRNGKey(0), x)
    foo_b = foo.bind(variables, mutable='batch_stats')
    y2 = f(foo_b, x)
    y3, new_state = nn.apply(f, foo, mutable='batch_stats')(variables, x)
    # assert_array_equal keeps the exact-equality check of assertEqual but,
    # unlike assertEqual, is well-defined for arrays of any size.
    np.testing.assert_array_equal(y1, y2)
    np.testing.assert_array_equal(y2, y3)
    bs_1 = new_state['batch_stats']
    bs_2 = foo_b.variables['batch_stats']
    # `jax.tree_leaves` is deprecated/removed in recent JAX releases;
    # `jax.tree_util.tree_leaves` is the stable spelling.
    leaves_1 = jax.tree_util.tree_leaves(bs_1)
    leaves_2 = jax.tree_util.tree_leaves(bs_2)
    # zip() silently truncates, and an empty tree would make the loop pass
    # vacuously, so check the leaf counts explicitly first.
    self.assertGreater(len(leaves_1), 0)
    self.assertEqual(len(leaves_1), len(leaves_2))
    for leaf_1, leaf_2 in zip(leaves_1, leaves_2):
      np.testing.assert_allclose(leaf_1, leaf_2)
Ejemplo n.º 2
0
        def eval(params, images, z, z_rng):
            """Apply the model to one batch and gather evaluation outputs.

            Args:
              params: weights, wrapped as the {'params': ...} collection.
              images: input batch passed to the model's forward call.
              z: value passed to `generate` (presumably latent codes).
              z_rng: PRNG key forwarded to the model's forward call.

            Returns:
              (metrics, comparison, generate_images), where `comparison` is
              the first 8 inputs followed by their first 8 reconstructions
              along axis 0, and images are reshaped to
              (-1, image_size, image_size, 3).
            """
            def _run(model):
                recon, mu, log_var = model(images, z_rng)
                side = self.image_size

                def to_images(batch):
                    return batch.reshape(-1, side, side, 3)

                comparison = jnp.concatenate(
                    [to_images(images[:8]), to_images(recon[:8])])
                samples = to_images(model.generate(z))
                metrics = self.compute_metrics(recon, images, mu, log_var)
                return metrics, comparison, samples

            run_bound = nn.apply(_run, self.model())
            return run_bound({'params': params})
Ejemplo n.º 3
0
  def test_functional_apply(self):
    """`nn.apply` reproduces the output produced by `nn.init_with_output`."""
    class Foo(nn.Module):
      def setup(self):
        self.a = nn.Dense(3)
        self.b = nn.Dense(1)

    def forward(module, inputs):
      return module.b(module.a(inputs))

    module = Foo()
    inputs = jnp.ones((4,))
    init_out, variables = nn.init_with_output(forward, module)(
        random.PRNGKey(0), inputs)
    apply_out = nn.apply(forward, module)(variables, inputs)
    self.assertEqual(init_out, apply_out)