def net(seqs, should_reset):
    """Unroll an LSTM and its ResetCore wrapper over the same sequences.

    Each core is unrolled twice: once through `static_unroll_with_states`
    and once through `unroll`. Both helpers are resolved from the
    surrounding scope, which is not visible from this function.

    Args:
        seqs: Input sequences laid out as [T, B, F].
        should_reset: Reset signal passed to the ResetCore together with
            `seqs`.

    Returns:
        A dict with the outputs and states produced by all four unrolls.
    """
    lstm = recurrent.LSTM(hidden_size=4)
    resettable = recurrent.ResetCore(lstm)
    # Axis 1 of the [T, B, F] layout is the batch axis.
    num_batch = seqs.shape[1]
    reset_inputs = (seqs, should_reset)

    # First pass: the static helper also hands back the collected states.
    static_lstm_outs, static_lstm_states = static_unroll_with_states(
        lstm, seqs, lstm.initial_state(num_batch))
    static_reset_outs, static_reset_states = static_unroll_with_states(
        resettable, reset_inputs, resettable.initial_state(num_batch))

    # Second pass: the supplied `unroll`, with no access to intermediate
    # states.
    unrolled_lstm_outs, unrolled_lstm_state = unroll(
        lstm, seqs, lstm.initial_state(num_batch))
    unrolled_reset_outs, unrolled_reset_state = unroll(
        resettable, reset_inputs, resettable.initial_state(num_batch))

    return {
        "core_outs": static_lstm_outs,
        "core_states": static_lstm_states,
        "reset_outs": static_reset_outs,
        "reset_states": static_reset_states,
        "dynamic_core_outs": unrolled_lstm_outs,
        "dynamic_core_state": unrolled_lstm_state,
        "dynamic_reset_outs": unrolled_reset_outs,
        "dynamic_reset_state": unrolled_reset_state,
    }
def test_invalid_input(self):
    """Check that ResetCore rejects this input/reset/state combination.

    The wrapped LSTM is called with rank-1 arrays for the input, the
    reset signal and the state; a ValueError matching the expected
    message must be raised.
    """
    wrapped = recurrent.ResetCore(recurrent.LSTM(hidden_size=4))
    expected_message = "should_reset must have rank-1 of state."

    step_input = jnp.array([1, 2, 3])
    reset_signal = jnp.array([2, 3, 4])
    state = jnp.array([2, 3, 4])

    with self.assertRaisesRegex(ValueError, expected_message):
        wrapped((step_input, reset_signal), state)
def test_allow_batched_only_cores(self):
    """Check that a batched-only core can be wrapped with ResetCore.

    There is no explicit assertion: the test passes when a single call
    to the wrapped core completes without raising.
    """
    num_batch = 5
    wrapped = recurrent.ResetCore(_BatchedOnlyCore())

    step_input = jnp.ones((num_batch, 4))
    state = wrapped.initial_state(num_batch)
    # Multiplying the state by zero yields the reset signal for this call.
    reset_signal = 0 * state

    wrapped((step_input, reset_signal), state)
def test_unbatched(self, unroll):
    """Unroll a ResetCore without a batch dimension and check the reset.

    Args:
        unroll: Unroll function under test; called as
            `unroll(core, inputs, initial_state)` and expected to return
            an `(outputs, state)` pair.
    """
    reset_time = 2
    seq_len = 5
    state_size = 4

    wrapped = recurrent.ResetCore(
        _IncrementByOneCore(state_size=state_size))
    steps = jnp.arange(0, seq_len)
    # Raise the reset flag only at the step whose index equals reset_time.
    reset_signal = steps == reset_time
    # Passing None as the batch size requests an unbatched initial state.
    start_state = wrapped.initial_state(None)

    outputs, _ = unroll(wrapped, (steps, reset_signal), start_state)

    # One value per time step, repeated across the state dimension; the
    # count starts over at index reset_time (= 2).
    expected_per_step = (1.0, 2.0, 1.0, 2.0, 3.0)
    expected = np.array(
        [[value] * state_size for value in expected_per_step])

    np.testing.assert_allclose(outputs, expected, rtol=1e-6, atol=1e-6)
def test_reversed_dynamic_unroll(self, batch_size):
    """Compare `reverse=True` against a forward unroll of flipped inputs.

    A forward `dynamic_unroll` over time-reversed inputs, flipped back
    along the time axis, must match a `reverse=True` unroll over the
    original inputs.

    Args:
        batch_size: Batch size for the inputs and initial state, or None
            to run unbatched.
    """
    reset_time = 2
    seq_len = 7
    state_size = 4

    wrapped = recurrent.ResetCore(
        _IncrementByOneCore(state_size=state_size))
    start_state = wrapped.initial_state(batch_size)

    steps = jnp.arange(0, seq_len)  # seq_len
    if batch_size is not None:
        # Replicate along a new axis 1: seq_len x batch_size.
        steps = jnp.stack([steps] * batch_size, axis=1)
    reset_signal = steps == reset_time

    flipped_inputs = (steps[::-1], reset_signal[::-1])
    forward_outputs, _ = recurrent.dynamic_unroll(
        wrapped, flipped_inputs, start_state, reverse=False)

    reverse_outputs, _ = recurrent.dynamic_unroll(
        wrapped, (steps, reset_signal), start_state, reverse=True)

    np.testing.assert_allclose(forward_outputs[::-1], reverse_outputs)
def test_input_conform_fails(self, reset, state):
    """Check that this `reset`/`state` pairing makes ResetCore raise.

    Args:
        reset: Reset signal passed as the second element of the input.
        state: Used as the dummy core's state, as the first element of
            the input, and as the previous state.
    """
    wrapped = recurrent.ResetCore(core=_DummyCore(state=state))
    step_inputs = (state, reset)
    with self.assertRaises(ValueError):
        wrapped(step_inputs, state)
def test_input_conform(self, reset, state):
    """Check that ResetCore accepts this `reset`/`state` pairing.

    There is no explicit assertion: the test passes when the call
    completes without raising.

    Args:
        reset: Reset signal passed as the second element of the input.
        state: Used as the dummy core's state, as the first element of
            the input, and as the previous state.
    """
    wrapped = recurrent.ResetCore(core=_DummyCore(state=state))
    step_inputs = (state, reset)
    wrapped(step_inputs, state)