示例#1
0
    def test_toplevel_submodule_adoption_pytree_transform(self):
        """Submodules held in a pytree attribute are adopted under nn.scan.

        ``B`` receives a dict of ``A`` instances through its ``A`` attribute
        and calls them by key. Each ``A`` call increments a ``counter``
        variable. ``init`` calls each submodule once. The scanned ``B`` then
        runs over the 10 rows of ``x`` with ``counter`` carried between steps.
        Each adopted submodule's counter is therefore expected to read 11,
        under the names ``A_bar`` and ``A_foo``.
        """
        class A(nn.Module):
            @nn.compact
            def __call__(self, c, x):
                counter = self.variable('counter', 'i', jnp.zeros, ())
                counter.value += 1
                x = nn.Dense(1)(x)
                return c, x

        class B(nn.Module):
            A: Any

            @nn.compact
            def __call__(self, c, x):
                # 'bar' runs first; its (carry, output) pair is fed to 'foo'.
                return self.A['foo'](*self.A['bar'](c, x))

        As = {'foo': A(), 'bar': A()}
        b = nn.scan(B,
                    in_axes=0,
                    variable_carry='counter',
                    variable_broadcast='params',
                    split_rngs={'params': False})(As)

        key = random.PRNGKey(0)
        x = jnp.ones((10, 2))

        # Initialize with the unscanned B. The scanned module is configured
        # with variable_broadcast='params', so it can consume these params.
        p = B(As).init(key, x, x)
        _, cntrs = b.apply(p, x, x, mutable='counter')
        ref_cntrs = freeze({
            'counter': {
                'A_bar': {
                    'i': jnp.array(11.0),
                },
                'A_foo': {
                    'i': jnp.array(11.0),
                },
            },
        })
        # assert_allclose raises on mismatch and returns None. Wrapping this
        # in assertTrue(tree_all(...)) therefore checked nothing extra, since
        # a tree of None has no leaves. tree_map itself raises if the two
        # trees differ in structure.
        jax.tree_util.tree_map(
            lambda actual, expected: np.testing.assert_allclose(
                actual, expected, atol=1e-7),
            cntrs, ref_cntrs)
示例#2
0
 def __call__(self, c, xs):
     """Scan an ``nn.LSTMCell`` over ``xs``, starting from carry ``c``.

     The 'params' collection is broadcast across scan steps, and the
     'params' RNG stream is not split per step.
     """
     scanned_cell_cls = nn.scan(
         nn.LSTMCell,
         variable_broadcast='params',
         split_rngs={'params': False},
     )
     cell = scanned_cell_cls(name="lstm_cell")
     return cell(c, xs)
示例#3
0
 def __call__(self, c, xs):
     """Scan an ``nn.LSTMCell`` over ``xs``, starting from carry ``c``.

     Variables in the 'param' collection use ``nn.broadcast`` on both the
     input and output side of the scan. The 'param' RNG stream is not split
     per step.
     """
     # NOTE(review): the singular 'param' and the variable_in_axes /
     # variable_out_axes arguments look like an older Flax API. Other code
     # here uses the 'params' collection with variable_broadcast='params'.
     # Confirm against the installed Flax version before changing either.
     scanned_cell_cls = nn.scan(
         nn.LSTMCell,
         variable_in_axes={'param': nn.broadcast},
         variable_out_axes={'param': nn.broadcast},
         split_rngs={'param': False},
     )
     cell = scanned_cell_cls(name="lstm_cell")
     return cell(c, xs)