def test_toplevel_submodule_adoption_pytree_transform(self):
    """A scanned module adopts a dict (pytree) of top-level submodules.

    ``B`` receives a dict of two unbound ``A`` instances through its ``A``
    field and calls ``A['bar']`` then ``A['foo']``.  Each ``A`` increments its
    own ``'counter'`` variable every time it is called.  ``B`` is lifted with
    ``nn.scan`` over axis 0 of its inputs, carrying the ``'counter'``
    collection and broadcasting ``'params'``.

    The expected counter value of 11 for both adopted submodules
    (``A_bar`` / ``A_foo``) corresponds to one increment during ``init`` plus
    one per scan step (``x`` has a leading dimension of 10).
    """

    class A(nn.Module):

        @nn.compact
        def __call__(self, c, x):
            counter = self.variable('counter', 'i', jnp.zeros, ())
            counter.value += 1
            x = nn.Dense(1)(x)
            return c, x

    class B(nn.Module):
        A: Any

        @nn.compact
        def __call__(self, c, x):
            # The inner call returns (c, x), which is splatted into the outer.
            return self.A['foo'](*self.A['bar'](c, x))

    As = {'foo': A(), 'bar': A()}
    b = nn.scan(B,
                in_axes=0,
                variable_carry='counter',
                variable_broadcast='params',
                split_rngs={'params': False})(As)

    key = random.PRNGKey(0)
    x = jnp.ones((10, 2))

    # Initialise with the unscanned module, then apply the scanned one with
    # the resulting variables, allowing only the counters to be mutated.
    p = B(As).init(key, x, x)
    _, cntrs = b.apply(p, x, x, mutable='counter')

    ref_cntrs = freeze({
        'counter': {
            'A_bar': {
                'i': jnp.array(11.0),
            },
            'A_foo': {
                'i': jnp.array(11.0),
            },
        },
    })

    # Same tree layout first, so the leaf-wise zip below cannot silently
    # truncate or misalign.
    self.assertEqual(jax.tree_util.tree_structure(cntrs),
                     jax.tree_util.tree_structure(ref_cntrs))
    # assert_allclose raises on mismatch and returns None; call it directly
    # per leaf rather than feeding its None results to tree_all (which would
    # always be True).
    for actual, expected in zip(jax.tree_util.tree_leaves(cntrs),
                                jax.tree_util.tree_leaves(ref_cntrs)):
        np.testing.assert_allclose(actual, expected, atol=1e-7)
def __call__(self, c, xs):
    """Run an ``nn.LSTMCell`` over ``xs`` with ``nn.scan``.

    The cell's ``'params'`` collection is broadcast across scan steps and the
    ``'params'`` RNG stream is not split.  ``c`` is passed as the initial
    carry; whatever the scanned cell returns is returned unchanged.
    """
    scanned_cell_cls = nn.scan(
        nn.LSTMCell,
        split_rngs={'params': False},
        variable_broadcast='params',
    )
    cell = scanned_cell_cls(name="lstm_cell")
    return cell(c, xs)
def __call__(self, c, xs):
    """Run an ``nn.LSTMCell`` over ``xs`` with ``nn.scan``.

    Uses the ``variable_in_axes`` / ``variable_out_axes`` keywords to mark
    the ``'param'`` collection as ``nn.broadcast`` on both the input and
    output side, and disables splitting of the ``'param'`` RNG stream.
    ``c`` is passed as the initial carry; the scanned cell's result is
    returned unchanged.

    NOTE(review): ``'param'`` (singular) looks like the collection name used
    by an older Flax release (current Flax uses ``'params'``) -- confirm
    against the Flax version this code targets.
    """
    in_axes_spec = {'param': nn.broadcast}
    out_axes_spec = {'param': nn.broadcast}
    rng_split_spec = {'param': False}
    scanned_cell_cls = nn.scan(
        nn.LSTMCell,
        variable_in_axes=in_axes_spec,
        variable_out_axes=out_axes_spec,
        split_rngs=rng_split_spec,
    )
    cell = scanned_cell_cls(name="lstm_cell")
    return cell(c, xs)