Code Example #1
File: recurrent_test.py Project: tirkarthi/dm-haiku
    def test_connection_and_shapes(self):
        batch_size = 4
        x = make_sequence([batch_size, 3])  # [B, F]
        core = recurrent.DeepRNN([
            recurrent.VanillaRNN(hidden_size=3),
            basic.Linear(2),
            jax.nn.relu,
            recurrent.VanillaRNN(hidden_size=5),
            jax.nn.relu,
        ])
        initial_state = core.initial_state(x.shape[0])
        out, next_state = core(x, initial_state)

        self.assertEqual(out.shape, (batch_size, 5))
        # Verifies that at least the final relu layer is applied.
        self.assertTrue(np.all(out >= np.zeros([batch_size, 5])))

        self.assertLen(next_state, 2)
        self.assertEqual(initial_state[0].shape, (batch_size, 3))
        self.assertEqual(initial_state[1].shape, (batch_size, 5))

        self.assertLen(initial_state, 2)
        np.testing.assert_allclose(initial_state[0], jnp.zeros([batch_size,
                                                                3]))
        np.testing.assert_allclose(initial_state[1], jnp.zeros([batch_size,
                                                                5]))
Code Example #2
def net(x):
    # x is [B, F].
    core = recurrent.deep_rnn_with_skip_connections([
        recurrent.VanillaRNN(hidden_size=3),
        recurrent.VanillaRNN(hidden_size=5),
    ])
    initial_state = core.initial_state(x.shape[0])
    out, _ = core(x, initial_state)
    return out
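
The `net` function above only defines the computation; in Haiku it still has to be wrapped with `hk.transform` before it can be initialized and applied. Below is a minimal sketch of that wrapping, not part of the original example; the input shape, the PRNG seed, and the use of the public `hk.*` aliases for the `recurrent` module are illustrative assumptions.

import haiku as hk
import jax
import jax.numpy as jnp

def net(x):
    # x is [B, F]; hk.deep_rnn_with_skip_connections is the public alias
    # of recurrent.deep_rnn_with_skip_connections used above.
    core = hk.deep_rnn_with_skip_connections([
        hk.VanillaRNN(hidden_size=3),
        hk.VanillaRNN(hidden_size=5),
    ])
    initial_state = core.initial_state(x.shape[0])
    out, _ = core(x, initial_state)
    return out

forward = hk.transform(net)
x = jnp.zeros([4, 3])                                   # [B, F]
params = forward.init(jax.random.PRNGKey(42), x)        # create parameters
out = forward.apply(params, jax.random.PRNGKey(42), x)  # out is [B, 3 + 5]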
Code Example #3
File: recurrent_test.py Project: tirkarthi/dm-haiku
    def test_skip_connections(self):
        batch_size = 4
        x = make_sequence([batch_size, 3])  # [B, F]
        core = recurrent.deep_rnn_with_skip_connections([
            recurrent.VanillaRNN(hidden_size=3),
            recurrent.VanillaRNN(hidden_size=5),
        ])
        initial_state = core.initial_state(x.shape[0])
        out, _ = core(x, initial_state)
        self.assertEqual(out.shape, (batch_size, 8))
Code Example #4
File: recurrent_test.py Project: vinid/dm-haiku
    def test_double_bias_length_parameters(self):
        double_bias = recurrent.VanillaRNN(1, double_bias=True)
        double_bias(jnp.zeros([1]), double_bias.initial_state(None))
        double_bias_params = jax.tree_leaves(double_bias.params_dict())

        vanilla = recurrent.VanillaRNN(1, double_bias=False)
        vanilla(jnp.zeros([1]), vanilla.initial_state(None))
        vanilla_params = jax.tree_leaves(vanilla.params_dict())

        self.assertLen(double_bias_params, len(vanilla_params) + 1)
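
The test above runs inside a transform-and-run test harness that is not shown here. The following is a hedged standalone sketch of the same parameter-count comparison, using hk.transform and jax.tree_util.tree_leaves with a dummy scalar-feature input; those choices are assumptions for illustration only.

import haiku as hk
import jax
import jax.numpy as jnp

def count_params(double_bias):
    def fn(x):
        core = hk.VanillaRNN(1, double_bias=double_bias)
        return core(x, core.initial_state(None))
    # init builds the parameters; count the leaves of the params tree.
    params = hk.transform(fn).init(jax.random.PRNGKey(0), jnp.zeros([1]))
    return len(jax.tree_util.tree_leaves(params))

# double_bias=True adds one extra bias term, hence exactly one more leaf,
# matching the assertion in the test above.
assert count_params(True) == count_params(False) + 1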
Code Example #5
def net(x):
    # x is [B, F].
    core = recurrent.DeepRNN([
        recurrent.VanillaRNN(hidden_size=3),
        basic.Linear(2),
        jax.nn.relu,
        recurrent.VanillaRNN(hidden_size=5),
        jax.nn.relu,
    ])
    initial_state = core.initial_state(x.shape[0])
    out, next_state = core(x, initial_state)

    return dict(out=out,
                next_state=next_state,
                initial_state=initial_state)
Code Example #6
File: recurrent_test.py Project: tirkarthi/dm-haiku
    def test_core_unroll_nested(self, unroll):
        seqs = make_sequence([4, 8, 1])
        batch_size = seqs.shape[1]
        core = DuplicateCore(recurrent.VanillaRNN(hidden_size=4))
        outs, _ = unroll(core, seqs, core.initial_state(batch_size))
        self.assertLen(outs, 2)
        for out in outs:
            self.assertEqual(out.shape, (4, 8, 4))
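
The `unroll` argument above is supplied by the test's parameterization and is not shown. Below is a minimal sketch of unrolling a core over a time-major [T, B, F] sequence with Haiku's public hk.dynamic_unroll helper; the shapes, the seed, and the choice of dynamic rather than static unroll are illustrative assumptions.

import haiku as hk
import jax
import jax.numpy as jnp

def unrolled(seqs):
    # seqs is time-major: [T, B, F].
    core = hk.VanillaRNN(hidden_size=4)
    initial_state = core.initial_state(seqs.shape[1])
    outs, final_state = hk.dynamic_unroll(core, seqs, initial_state)
    return outs  # [T, B, hidden_size]

forward = hk.transform(unrolled)
seqs = jnp.zeros([4, 8, 1])                                 # [T, B, F]
params = forward.init(jax.random.PRNGKey(0), seqs)
outs = forward.apply(params, jax.random.PRNGKey(0), seqs)   # [4, 8, 4]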