Пример #1
0
 def net(x):
     """Apply a two-layer VanillaRNN stack with skip connections to ``x``.

     Args:
       x: input batch, described by the original author as [B, F]
         (batch, features). Only ``x.shape[0]`` is read directly, as the
         batch size for the initial recurrent state.

     Returns:
       The stacked core's output for this input. The updated recurrent state
       that the core also returns is discarded.
     """
     # Layers are constructed in the same order as before (hidden 3, then 5).
     layers = [
         recurrent.VanillaRNN(hidden_size=3),
         recurrent.VanillaRNN(hidden_size=5),
     ]
     core = recurrent.deep_rnn_with_skip_connections(layers)
     batch_size = x.shape[0]
     # The initial state is built before the core is applied, starting from a
     # fresh state for this batch.
     output, _ = core(x, core.initial_state(batch_size))
     return output
Пример #2
0
 def test_skip_connections(self):
     """Output width of a skip-connected VanillaRNN stack for a [B, F] input."""
     batch_size = 4
     hidden_sizes = (3, 5)
     x = make_sequence([batch_size, 3])  # [B, F]
     # One VanillaRNN per hidden size, built in order (3, then 5).
     core = recurrent.deep_rnn_with_skip_connections(
         [recurrent.VanillaRNN(hidden_size=size) for size in hidden_sizes])
     state = core.initial_state(x.shape[0])
     out, _ = core(x, state)
     # The expected feature width is the sum of the layers' hidden sizes
     # (3 + 5 == 8). Presumably this is because skip connections concatenate
     # every layer's output; verify against the
     # deep_rnn_with_skip_connections docs.
     self.assertEqual(out.shape, (batch_size, sum(hidden_sizes)))
Пример #3
0
 def test_skip_validation(self):
     """A non-RNN layer (a bare activation function) must be rejected."""
     # Callable form of assertRaisesRegex: the function and its positional
     # argument are passed in, and the call must raise a matching ValueError.
     self.assertRaisesRegex(
         ValueError,
         "skip_connections requires",
         recurrent.deep_rnn_with_skip_connections,
         [jax.nn.relu],
     )