def test_update_in_combinator(self):
   """Checks `state.update` on a module built from an `nn.Serial` combinator.

   The template chains the same `increment` layer twice; the test asserts the
   module output for an input of one is 1 before the update and 3 after it.
   """
   def template(x, init_key=None):
     def increment(x, init_key=None):
       counter = Counter(jnp.zeros(()))
       return counter(x, init_key=init_key, name='counter')
     # The same layer function appears twice in the combinator.
     serial = nn.Serial([increment, increment])
     return serial(x, init_key=init_key, name='increment')

   one = jnp.ones(())
   module = state.init(template)(self._seed, one)
   self.assertEqual(module(one), 1.)
   # Rebind to the module returned by the update before re-evaluating.
   module = state.update(module, one)
   self.assertEqual(module(one), 3.)
  def test_update(self):
    """Checks both the function and method spellings of a state update.

    A single-counter module is initialized, then updated once through
    `state.update` and once through the module's own `update` method; each
    updated module is asserted to output 2 for an input of one.
    """
    def template(x, init_key=None):
      counter = Counter(jnp.zeros(()))
      return counter(x, init_key=init_key, name='counter')

    one = jnp.ones(())
    net = state.init(template)(self._seed, one)
    self.assertEqual(net(one), 1.)

    # Function form of the update.
    via_function = state.update(net, one)
    self.assertEqual(via_function(one), 2.)

    # Method form, applied to the original module rather than `via_function`.
    via_method = net.update(one)
    self.assertEqual(via_method(one), 2.)