Example #1
0
        def net(seqs, should_reset):
            """Unroll an LSTM and its ResetCore wrapper, gathering every result.

            Args:
              seqs: input sequences laid out as [T, B, F].
              should_reset: reset signal paired with `seqs` for the wrapped core.

            Returns:
              A dict with the outputs and states produced by
              `static_unroll_with_states` and by `unroll`, for both the plain
              LSTM and the ResetCore-wrapped LSTM.
            """
            lstm = recurrent.LSTM(hidden_size=4)
            wrapped = recurrent.ResetCore(lstm)
            # Axis 1 of the [T, B, F] input is the batch dimension.
            num_batch = seqs.shape[1]
            results = {}

            # Static unrolls: per-step states are collected.
            results["core_outs"], results["core_states"] = (
                static_unroll_with_states(
                    lstm, seqs, lstm.initial_state(num_batch)))
            results["reset_outs"], results["reset_states"] = (
                static_unroll_with_states(
                    wrapped, (seqs, should_reset),
                    wrapped.initial_state(num_batch)))

            # `unroll` gives no access to the intermediate states.
            results["dynamic_core_outs"], results["dynamic_core_state"] = (
                unroll(lstm, seqs, lstm.initial_state(num_batch)))
            results["dynamic_reset_outs"], results["dynamic_reset_state"] = (
                unroll(wrapped, (seqs, should_reset),
                       wrapped.initial_state(num_batch)))

            return results
Example #2
0
 def test_invalid_input(self):
     """Rank-1 should_reset paired with a rank-1 state must raise ValueError."""
     wrapped = recurrent.ResetCore(recurrent.LSTM(hidden_size=4))
     expected_message = "should_reset must have rank-1 of state."
     with self.assertRaisesRegex(ValueError, expected_message):
         step_inputs = jnp.array([1, 2, 3])
         reset_flags = jnp.array([2, 3, 4])
         prev_state = jnp.array([2, 3, 4])
         wrapped((step_inputs, reset_flags), prev_state)
Example #3
0
 def test_allow_batched_only_cores(self):
   """Checks that a batched-only core can be wrapped with ResetCore."""
   num_batch = 5
   wrapped = recurrent.ResetCore(_BatchedOnlyCore())
   ones_input = jnp.ones((num_batch, 4))
   start_state = wrapped.initial_state(num_batch)
   # Scaling the state by zero gives a reset signal shaped like the state.
   no_reset = 0 * start_state
   wrapped((ones_input, no_reset), start_state)
Example #4
0
  def test_unbatched(self, unroll):
    """Unrolls an unbatched ResetCore and checks the count restarts on reset."""
    reset_step, num_steps, width = 2, 5, 4

    wrapped = recurrent.ResetCore(_IncrementByOneCore(state_size=width))
    steps = jnp.arange(0, num_steps)
    reset_mask = steps == reset_step
    # A batch size of None requests an unbatched initial state.
    start_state = wrapped.initial_state(None)
    outputs, _ = unroll(wrapped, (steps, reset_mask), start_state)

    # One value per time step, repeated across the state dimension; the
    # count drops back to 1.0 at index reset_step.
    per_step_counts = [1.0, 2.0, 1.0, 2.0, 3.0]
    expected = np.array([[count] * width for count in per_step_counts])
    np.testing.assert_allclose(outputs, expected, rtol=1e-6, atol=1e-6)
Example #5
0
    def test_reversed_dynamic_unroll(self, batch_size):
        """reverse=True must match a forward unroll over time-flipped inputs."""
        reset_step, num_steps, width = 2, 7, 4

        wrapped = recurrent.ResetCore(_IncrementByOneCore(state_size=width))
        start_state = wrapped.initial_state(batch_size)

        steps = jnp.arange(0, num_steps)  # [num_steps]
        if batch_size is not None:
            # Replicate along a new axis 1: [num_steps, batch_size].
            steps = jnp.stack([steps] * batch_size, axis=1)
        reset_mask = steps == reset_step

        # Forward pass over inputs flipped along the leading (time) axis.
        flipped_inputs = (steps[::-1], reset_mask[::-1])
        forward_outs, _ = recurrent.dynamic_unroll(
            wrapped, flipped_inputs, start_state, reverse=False)

        # Reverse pass over the inputs in their original order.
        reverse_outs, _ = recurrent.dynamic_unroll(
            wrapped, (steps, reset_mask), start_state, reverse=True)

        np.testing.assert_allclose(forward_outs[::-1], reverse_outs)
Example #6
0
 def test_input_conform_fails(self, reset, state):
     """Calling ResetCore with this (reset, state) pairing must raise ValueError."""
     wrapped = recurrent.ResetCore(core=_DummyCore(state=state))
     # Callable form: assertRaises invokes wrapped((state, reset), state).
     self.assertRaises(ValueError, wrapped, (state, reset), state)
Example #7
0
 def test_input_conform(self, reset, state):
     """Calling ResetCore with this (reset, state) pairing must not raise."""
     wrapped = recurrent.ResetCore(core=_DummyCore(state=state))
     step_inputs = (state, reset)
     wrapped(step_inputs, state)