def test_connection_and_shapes(self):
    """Checks output/state shapes and zero initial state of a mixed DeepRNN.

    The core interleaves recurrent layers (VanillaRNN) with stateless ones
    (Linear, relu); only the recurrent layers should contribute state.
    """
    batch = 4
    inputs = make_sequence([batch, 3])  # [B, F]
    core = recurrent.DeepRNN([
        recurrent.VanillaRNN(hidden_size=3),
        basic.Linear(2),
        jax.nn.relu,
        recurrent.VanillaRNN(hidden_size=5),
        jax.nn.relu,
    ])
    state0 = core.initial_state(inputs.shape[0])
    output, state1 = core(inputs, state0)

    self.assertEqual(output.shape, (batch, 5))
    # The final relu guarantees every output entry is non-negative,
    # verifying that at least the last activation was applied.
    self.assertTrue(np.all(output >= np.zeros([batch, 5])))
    # Only the two VanillaRNN layers carry state.
    self.assertLen(state1, 2)
    self.assertLen(state0, 2)
    self.assertEqual(state0[0].shape, (batch, 3))
    self.assertEqual(state0[1].shape, (batch, 5))
    # VanillaRNN initial states are all-zeros.
    np.testing.assert_allclose(state0[0], jnp.zeros([batch, 3]))
    np.testing.assert_allclose(state0[1], jnp.zeros([batch, 5]))
def test_only_callables(self):
    """A DeepRNN built purely from stateless callables carries no state."""
    inputs = make_sequence([4, 3])  # [B, F]
    core = recurrent.DeepRNN([jnp.tanh, jnp.square])
    state0 = core.initial_state(inputs.shape[0])
    output, state1 = core(inputs, state0)

    # The callables are composed in order: square(tanh(x)), elementwise.
    np.testing.assert_allclose(output, np.square(np.tanh(inputs)), rtol=1e-4)
    # Stateless callables contribute no state entries at all.
    self.assertEmpty(state1)
    self.assertEmpty(state0)
def net(x):
    """Applies a stateless DeepRNN (tanh then square) to a [B, F] batch.

    Returns a dict with the core output, the next state, and the initial
    state so callers can inspect all three.
    """
    core = recurrent.DeepRNN([jnp.tanh, jnp.square])
    state0 = core.initial_state(x.shape[0])
    output, state1 = core(x, state0)
    return dict(out=output, next_state=state1, initial_state=state0)
def net(x):
    """Runs a mixed recurrent/feed-forward DeepRNN over a [B, F] batch.

    Returns a dict with the core output, the next state, and the initial
    state so callers can inspect all three.
    """
    layers = [
        recurrent.VanillaRNN(hidden_size=3),
        basic.Linear(2),
        jax.nn.relu,
        recurrent.VanillaRNN(hidden_size=5),
        jax.nn.relu,
    ]
    core = recurrent.DeepRNN(layers)
    state0 = core.initial_state(x.shape[0])
    output, state1 = core(x, state0)
    return dict(out=output, next_state=state1, initial_state=state0)