예제 #1
0
        def net(seqs, should_reset):
            """Runs an LSTM and its `ResetCore` wrapper through both unroll paths.

            Args:
              seqs: Input sequences laid out as [T, B, F]; the batch size is
                read from axis 1.
              should_reset: Reset signal fed to the `ResetCore` together with
                `seqs`.

            Returns:
              A dict with the outputs and states of the two static unrolls
              and the outputs and state returned by the two `unroll` calls,
              for the plain core and for the reset-wrapped core.
            """
            lstm = recurrent.LSTM(hidden_size=4)
            resettable = recurrent.ResetCore(lstm)
            num_batch = seqs.shape[1]
            # The reset core consumes the sequence and the reset signal as a pair.
            reset_inputs = (seqs, should_reset)

            # Static unrolls: every intermediate state is collected.
            static_outs, static_states = static_unroll_with_states(
                lstm, seqs, lstm.initial_state(num_batch))
            static_reset_outs, static_reset_states = static_unroll_with_states(
                resettable, reset_inputs, resettable.initial_state(num_batch))

            # `unroll` does not expose the intermediate states.
            dyn_outs, dyn_state = unroll(
                lstm, seqs, lstm.initial_state(num_batch))
            dyn_reset_outs, dyn_reset_state = unroll(
                resettable, reset_inputs, resettable.initial_state(num_batch))

            return {
                "core_outs": static_outs,
                "core_states": static_states,
                "reset_outs": static_reset_outs,
                "reset_states": static_reset_states,
                "dynamic_core_outs": dyn_outs,
                "dynamic_core_state": dyn_state,
                "dynamic_reset_outs": dyn_reset_outs,
                "dynamic_reset_state": dyn_reset_state,
            }
예제 #2
0
 def test_invalid_input(self):
     """Checks that `ResetCore` raises `ValueError` for the arguments below.

     The wrapped core is called with a pair of 3-element arrays as inputs
     and a 3-element array as state; the error message must match the
     expected pattern.
     """
     lstm = recurrent.LSTM(hidden_size=4)
     wrapped = recurrent.ResetCore(lstm)
     expected_message = "should_reset must have rank-1 of state."
     with self.assertRaisesRegex(ValueError, expected_message):
         inputs_and_reset = (jnp.array([1, 2, 3]), jnp.array([2, 3, 4]))
         state = jnp.array([2, 3, 4])
         wrapped(inputs_and_reset, state)
예제 #3
0
    def test_lstm_raises(self):
        """Checks that the LSTM raises `ValueError` for unsupported input ranks.

        Two calls are expected to fail with a message matching
        "rank-1 or rank-2": one with a rank-0 input, and one with a rank-3
        input paired with a state whose leaves gained a leading axis.
        """
        core = recurrent.LSTM(4)
        error_pattern = "rank-1 or rank-2"

        # Arguments are built outside the assertRaises blocks so that only
        # the call under test can satisfy the expected-exception check; a
        # ValueError raised while preparing inputs now fails the test.
        scalar_input = jnp.zeros([])
        state_without_batch = core.initial_state(None)
        with self.assertRaisesRegex(ValueError, error_pattern):
            core(scalar_input, state_without_batch)

        rank3_input = jnp.zeros([1, 1, 1])
        expanded_state = tree.map_structure(
            lambda x: jnp.expand_dims(x, 0), core.initial_state(1))
        with self.assertRaisesRegex(ValueError, error_pattern):
            core(rank3_input, expanded_state)
예제 #4
0
  def test_batch_major(self, unroll):
    """Checks that `unroll` agrees for time-major and batch-major inputs.

    The same random sequence is unrolled twice from the same initial state:
    once as-is with `time_major=True`, and once with its first two axes
    swapped and `time_major=False`. The returned states must be equal, and
    the outputs must be equal after swapping the first two axes back.

    Args:
      unroll: The unroll function under test, called as
        `unroll(core, inputs, initial_state, time_major=...)`.
    """
    core = recurrent.LSTM(4)
    sequence_len, batch_size = 10, 5

    inputs = np.random.randn(sequence_len, batch_size, 2)
    batch_major_inputs = jnp.swapaxes(inputs, 0, 1)

    initial_state = core.initial_state(batch_size)
    time_major_outputs, time_major_unroll_state_out = unroll(
        core, inputs, initial_state, time_major=True)
    batch_major_outputs, batch_major_unroll_state_out = unroll(
        core, batch_major_inputs, initial_state, time_major=False)

    # `jax.tree_multimap` was deprecated and then removed from JAX;
    # `jax.tree_util.tree_map` accepts several trees and replaces it.
    jax.tree_util.tree_map(np.testing.assert_array_equal,
                           time_major_unroll_state_out,
                           batch_major_unroll_state_out)
    jax.tree_util.tree_map(
        lambda x, y: np.testing.assert_array_equal(x, jnp.swapaxes(y, 0, 1)),
        time_major_outputs, batch_major_outputs)