def net(x):
    """Run one step of a two-layer, skip-connected VanillaRNN stack on ``x``.

    The stack has VanillaRNN layers with hidden sizes 3 and 5, combined via
    ``recurrent.deep_rnn_with_skip_connections``. A fresh initial state sized
    to the batch dimension of ``x`` is used.

    Args:
      x: Input batch, shaped [B, F] per the original author's note.

    Returns:
      The stack's output for this single step. The updated recurrent state is
      discarded. For this configuration, ``test_skip_connections`` asserts an
      output shape of [B, 8] (presumably 3 + 5, the layer outputs combined;
      confirm against the ``deep_rnn_with_skip_connections`` docs).
    """
    layers = [
        recurrent.VanillaRNN(hidden_size=3),
        recurrent.VanillaRNN(hidden_size=5),
    ]
    stacked = recurrent.deep_rnn_with_skip_connections(layers)
    # The leading dimension of x is the batch size.
    batch = x.shape[0]
    output, _unused_state = stacked(x, stacked.initial_state(batch))
    return output
def test_skip_connections(self):
    """Check the output shape of a skip-connected two-layer VanillaRNN stack.

    A [batch_size, 3] input is fed through VanillaRNN layers of hidden sizes
    3 and 5. The single-step output is expected to be [batch_size, 8]
    (presumably the two layer outputs combined: 3 + 5).
    """
    batch_size = 4
    inputs = make_sequence([batch_size, 3])  # [B, F]
    hidden_sizes = (3, 5)
    # Layers are constructed in the same order (3 first, then 5).
    rnn = recurrent.deep_rnn_with_skip_connections(
        [recurrent.VanillaRNN(hidden_size=size) for size in hidden_sizes])
    state = rnn.initial_state(inputs.shape[0])
    outputs, _ = rnn(inputs, state)
    self.assertEqual(outputs.shape, (batch_size, 8))
def test_skip_validation(self):
    """Check that a plain function layer (``jax.nn.relu``) is rejected.

    ``deep_rnn_with_skip_connections`` must raise ``ValueError`` whose message
    matches "skip_connections requires".
    """
    bad_layers = [jax.nn.relu]
    # Callable form of assertRaisesRegex: it invokes the function with
    # bad_layers and asserts the matching exception is raised.
    self.assertRaisesRegex(
        ValueError,
        "skip_connections requires",
        recurrent.deep_rnn_with_skip_connections,
        bad_layers,
    )