def test_connection_and_shapes(self):
    """Run one DeepRNN step and check output and state shapes.

    The stack is VanillaRNN(3) -> Linear(2) -> relu -> VanillaRNN(5) -> relu.
    The test expects a state of length two with shapes (B, 3) and (B, 5),
    and an all-zero initial state.
    """
    batch_size = 4
    hidden_sizes = (3, 5)
    inputs = make_sequence([batch_size, 3])  # [B, F]

    layers = [
        recurrent.VanillaRNN(hidden_size=3),
        basic.Linear(2),
        jax.nn.relu,
        recurrent.VanillaRNN(hidden_size=5),
        jax.nn.relu,
    ]
    core = recurrent.DeepRNN(layers)
    initial_state = core.initial_state(inputs.shape[0])
    out, next_state = core(inputs, initial_state)

    self.assertEqual(out.shape, (batch_size, 5))
    # The final relu must leave no negative entries in the output.
    self.assertTrue(np.all(out >= np.zeros([batch_size, 5])))
    self.assertLen(next_state, 2)

    # Shapes of the per-layer initial states, indexed in layer order.
    for index, hidden_size in enumerate(hidden_sizes):
        self.assertEqual(
            initial_state[index].shape, (batch_size, hidden_size))
    self.assertLen(initial_state, 2)

    # Every initial state entry is expected to be all zeros.
    for index, hidden_size in enumerate(hidden_sizes):
        np.testing.assert_allclose(
            initial_state[index], jnp.zeros([batch_size, hidden_size]))
def net(x):
    """Apply one step of a two-layer skip-connection RNN and return its output.

    Args:
        x: Input batch; x is [B, F].

    Returns:
        The first element of the pair returned by the core; the next state
        is discarded.
    """
    layers = [
        recurrent.VanillaRNN(hidden_size=3),
        recurrent.VanillaRNN(hidden_size=5),
    ]
    core = recurrent.deep_rnn_with_skip_connections(layers)
    batch_size = x.shape[0]
    start_state = core.initial_state(batch_size)
    output, _ = core(x, start_state)
    return output
def test_skip_connections(self):
    """Check the output width of a two-layer skip-connection RNN.

    With hidden sizes 3 and 5 the test expects an output of shape (B, 8).
    """
    batch_size = 4
    inputs = make_sequence([batch_size, 3])  # [B, F]

    layers = [
        recurrent.VanillaRNN(hidden_size=3),
        recurrent.VanillaRNN(hidden_size=5),
    ]
    core = recurrent.deep_rnn_with_skip_connections(layers)
    start_state = core.initial_state(inputs.shape[0])
    out, _ = core(inputs, start_state)

    # 8 == 3 + 5, the sum of the two hidden sizes.
    expected_shape = (batch_size, 3 + 5)
    self.assertEqual(out.shape, expected_shape)
def test_double_bias_length_parameters(self):
    """Check that `double_bias=True` adds exactly one parameter leaf.

    Builds two single-unit VanillaRNN cores that differ only in the
    `double_bias` flag, applies each once, and compares the number of
    leaves in their parameter dictionaries.
    """
    # Each core is applied once before its parameters are read
    # (presumably parameters are created on first call -- confirm).
    double_bias = recurrent.VanillaRNN(1, double_bias=True)
    double_bias(jnp.zeros([1]), double_bias.initial_state(None))
    # Use jax.tree_util.tree_leaves: the top-level `jax.tree_leaves` alias
    # was deprecated and is no longer available in current JAX releases.
    double_bias_params = jax.tree_util.tree_leaves(double_bias.params_dict())

    vanilla = recurrent.VanillaRNN(1, double_bias=False)
    vanilla(jnp.zeros([1]), vanilla.initial_state(None))
    vanilla_params = jax.tree_util.tree_leaves(vanilla.params_dict())

    self.assertLen(double_bias_params, len(vanilla_params) + 1)
def net(x):
    """Apply one DeepRNN step and return output, next state and initial state.

    The stack is VanillaRNN(3) -> Linear(2) -> relu -> VanillaRNN(5) -> relu.

    Args:
        x: Input batch; x is [B, F].

    Returns:
        A dict with keys "out", "next_state" and "initial_state".
    """
    layers = [
        recurrent.VanillaRNN(hidden_size=3),
        basic.Linear(2),
        jax.nn.relu,
        recurrent.VanillaRNN(hidden_size=5),
        jax.nn.relu,
    ]
    core = recurrent.DeepRNN(layers)
    batch_size = x.shape[0]
    initial_state = core.initial_state(batch_size)
    out, next_state = core(x, initial_state)
    return {
        "out": out,
        "next_state": next_state,
        "initial_state": initial_state,
    }
def test_core_unroll_nested(self, unroll):
    """Unroll a core with nested outputs and check each output's shape.

    Args:
        unroll: Callable invoked as `unroll(core, sequences, initial_state)`
            and returning an `(outputs, state)` pair.
    """
    seqs = make_sequence([4, 8, 1])
    # Axis 1 of the sequence is used as the batch dimension.
    batch_size = seqs.shape[1]

    inner_core = recurrent.VanillaRNN(hidden_size=4)
    core = DuplicateCore(inner_core)
    start_state = core.initial_state(batch_size)
    outs, _ = unroll(core, seqs, start_state)

    # The wrapped core is expected to yield two outputs per step.
    self.assertLen(outs, 2)
    expected_shape = (4, 8, 4)
    for out in outs:
        self.assertEqual(out.shape, expected_shape)