Пример #1
0
  def test_bind_stateful(self):
    """Bound module with mutable 'batch_stats' agrees with init and apply.

    Initializes ``Foo`` through ``f`` with ``nn.init_with_output``, then runs
    ``f`` twice more: once on a module bound via ``Module.bind`` with
    ``mutable='batch_stats'``, and once through ``nn.apply`` with the same
    mutable collection. Checks that all three outputs are identical and that
    the batch statistics held by the bound module equal those returned by
    ``nn.apply``.
    """
    class Foo(nn.Module):
      def setup(self):
        self.a = nn.Dense(3)
        self.bn = nn.BatchNorm()
        self.b = nn.Dense(1)

    def f(foo, x):
      x = foo.a(x)
      # use_running_average=False makes BatchNorm update its statistics,
      # which is why 'batch_stats' is passed as mutable to bind/apply below.
      x = foo.bn(x, use_running_average=False)
      return foo.b(x)

    foo = Foo()
    x = jnp.ones((4,))
    f_init = nn.init_with_output(f, foo)
    y1, variables = f_init(random.PRNGKey(0), x)
    foo_b = foo.bind(variables, mutable='batch_stats')
    y2 = f(foo_b, x)
    y3, new_state = nn.apply(f, foo, mutable='batch_stats')(variables, x)
    # Exact element-wise comparison; assertEqual on arrays only works when
    # the result happens to hold a single element.
    np.testing.assert_array_equal(y1, y2)
    np.testing.assert_array_equal(y2, y3)
    bs_1 = new_state['batch_stats']
    bs_2 = foo_b.variables['batch_stats']
    # jax.tree_util.tree_leaves replaces the removed jax.tree_leaves alias.
    # Check leaf counts first so zip cannot silently hide a mismatch.
    leaves_1 = jax.tree_util.tree_leaves(bs_1)
    leaves_2 = jax.tree_util.tree_leaves(bs_2)
    self.assertEqual(len(leaves_1), len(leaves_2))
    for leaf_1, leaf_2 in zip(leaves_1, leaves_2):
      np.testing.assert_allclose(leaf_1, leaf_2)
Пример #2
0
  def test_bind(self):
    """A module bound to its init variables reproduces the init output.

    Initializes ``Foo`` through ``f`` with ``nn.init_with_output``, binds the
    resulting variables with ``Module.bind``, calls ``f`` directly on the
    bound module, and checks the two outputs are identical.
    """
    class Foo(nn.Module):
      def setup(self):
        self.a = nn.Dense(3)
        self.b = nn.Dense(1)

    def f(foo, x):
      x = foo.a(x)
      return foo.b(x)

    foo = Foo()
    x = jnp.ones((4,))
    f_init = nn.init_with_output(f, foo)
    y1, variables = f_init(random.PRNGKey(0), x)
    y2 = f(foo.bind(variables), x)
    # Exact element-wise comparison; assertEqual on arrays only works when
    # the result happens to hold a single element.
    np.testing.assert_array_equal(y1, y2)