def test_update_in_combinator(self):
    """`state.update` refreshes state held inside an `nn.Serial` combinator.

    The network chains two `Counter` layers (each built by ``increment``)
    with `nn.Serial`. On an input of ones the test expects the initialized
    network to return 1., and the network returned by one `state.update`
    call to return 3.
    """
    def template(x, init_key=None):
        def increment(x, init_key=None):
            counter_layer = Counter(jnp.zeros(()))
            return counter_layer(x, init_key=init_key, name='counter')

        chain = nn.Serial([increment, increment])
        return chain(x, init_key=init_key, name='increment')

    # A single scalar input reused everywhere; JAX arrays are immutable, so
    # sharing one object is equivalent to building a fresh one per call.
    ones = jnp.ones(())

    net = state.init(template)(self._seed, ones)
    self.assertEqual(net(ones), 1.)

    net = state.update(net, ones)
    self.assertEqual(net(ones), 3.)
def test_update(self):
    """`state.update(net, x)` and `net.update(x)` give the same result.

    A single `Counter` layer is initialized and the test expects it to
    return 1. on an input of ones. Both update entry points are then applied
    to that same initialized ``net``, and each updated network is expected
    to return 2.
    """
    def template(x, init_key=None):
        counter_layer = Counter(jnp.zeros(()))
        return counter_layer(x, init_key=init_key, name='counter')

    ones = jnp.ones(())

    net = state.init(template)(self._seed, ones)
    self.assertEqual(net(ones), 1.)

    # Functional update entry point.
    updated_net = state.update(net, ones)
    self.assertEqual(updated_net(ones), 2.)

    # Method update entry point, applied to the original `net` rather than
    # to `updated_net`, so it is expected to produce the same value.
    updated_net = net.update(ones)
    self.assertEqual(updated_net(ones), 2.)