示例#1
0
    def test_connection_and_shapes(self):
        """Checks output/state shapes of a mixed DeepRNN and that the final relu fires."""
        batch_size = 4
        inputs = make_sequence([batch_size, 3])  # [B, F]
        core = recurrent.DeepRNN([
            recurrent.VanillaRNN(hidden_size=3),
            basic.Linear(2),
            jax.nn.relu,
            recurrent.VanillaRNN(hidden_size=5),
            jax.nn.relu,
        ])
        initial_state = core.initial_state(inputs.shape[0])
        out, next_state = core(inputs, initial_state)

        self.assertEqual(out.shape, (batch_size, 5))
        # The trailing relu must leave no negative entries in the output.
        self.assertTrue(np.all(out >= np.zeros([batch_size, 5])))

        # Only the two VanillaRNN layers carry state.
        self.assertLen(next_state, 2)
        self.assertEqual(initial_state[0].shape, (batch_size, 3))
        self.assertEqual(initial_state[1].shape, (batch_size, 5))

        self.assertLen(initial_state, 2)
        # Initial states are all-zeros for VanillaRNN.
        np.testing.assert_allclose(initial_state[0],
                                   jnp.zeros([batch_size, 3]))
        np.testing.assert_allclose(initial_state[1],
                                   jnp.zeros([batch_size, 5]))
示例#2
0
 def test_only_callables(self):
     """A DeepRNN of plain callables is stateless and composes them in order."""
     inputs = make_sequence([4, 3])  # [B, F]
     core = recurrent.DeepRNN([jnp.tanh, jnp.square])
     state_before = core.initial_state(inputs.shape[0])
     out, state_after = core(inputs, state_before)
     # Output is square(tanh(x)); both states are empty since no layer is recurrent.
     np.testing.assert_allclose(out, np.square(np.tanh(inputs)), rtol=1e-4)
     self.assertEmpty(state_after)
     self.assertEmpty(state_before)
示例#3
0
 def net(x):
     """Applies a stateless DeepRNN of callables to x (shape [B, F]).

     Returns a dict with the output, the next state, and the initial state.
     """
     core = recurrent.DeepRNN([jnp.tanh, jnp.square])
     state = core.initial_state(x.shape[0])
     out, next_state = core(x, state)
     return {"out": out, "next_state": next_state, "initial_state": state}
示例#4
0
        def net(x):
            """Builds a mixed recurrent/feed-forward DeepRNN and applies it to x ([B, F]).

            Returns a dict with the output, next state, and initial state.
            """
            layers = [
                recurrent.VanillaRNN(hidden_size=3),
                basic.Linear(2),
                jax.nn.relu,
                recurrent.VanillaRNN(hidden_size=5),
                jax.nn.relu,
            ]
            core = recurrent.DeepRNN(layers)
            state = core.initial_state(x.shape[0])
            out, next_state = core(x, state)
            return dict(out=out, next_state=next_state, initial_state=state)