def net(seqs, should_reset):
  """Unrolls a plain LSTM and a ResetCore-wrapped LSTM over the same inputs.

  Each core is unrolled twice: once with `static_unroll_with_states`, which
  also yields the per-step states, and once with `unroll`, which yields only
  the final state.

  Args:
    seqs: Time-major input sequences, [T, B, F].
    should_reset: Reset signal passed to the ResetCore together with `seqs`
      (presumably one flag per time step and batch entry; confirm with the
      caller).

  Returns:
    A dict holding the outputs and states of all four unrolls.
  """
  lstm = recurrent.LSTM(hidden_size=4)
  resettable_lstm = recurrent.ResetCore(lstm)
  num_batch = seqs.shape[1]
  reset_inputs = (seqs, should_reset)

  # Static unrolls: intermediate states are collected for every step.
  static_lstm_outs, static_lstm_states = static_unroll_with_states(
      lstm, seqs, lstm.initial_state(num_batch))
  static_reset_outs, static_reset_states = static_unroll_with_states(
      resettable_lstm, reset_inputs, resettable_lstm.initial_state(num_batch))

  # Regular unrolls: only the final state is exposed.
  unrolled_lstm_outs, final_lstm_state = unroll(
      lstm, seqs, lstm.initial_state(num_batch))
  unrolled_reset_outs, final_reset_state = unroll(
      resettable_lstm, reset_inputs, resettable_lstm.initial_state(num_batch))

  return {
      "core_outs": static_lstm_outs,
      "core_states": static_lstm_states,
      "reset_outs": static_reset_outs,
      "reset_states": static_reset_states,
      "dynamic_core_outs": unrolled_lstm_outs,
      "dynamic_core_state": final_lstm_state,
      "dynamic_reset_outs": unrolled_reset_outs,
      "dynamic_reset_state": final_reset_state,
  }
def test_invalid_input(self):
  """ResetCore rejects a `should_reset` whose rank does not match the state."""
  reset_core = recurrent.ResetCore(recurrent.LSTM(hidden_size=4))
  # The call signature is ((inputs, should_reset), state); all three values
  # here are rank-1 arrays, which the core is expected to reject.
  inputs = jnp.array([1, 2, 3])
  should_reset = jnp.array([2, 3, 4])
  state = jnp.array([2, 3, 4])
  with self.assertRaisesRegex(ValueError,
                              "should_reset must have rank-1 of state."):
    reset_core((inputs, should_reset), state)
def test_lstm_raises(self):
  """LSTM raises ValueError for inputs that are neither rank-1 nor rank-2.

  Two cases are checked: a rank-0 input with the unbatched initial state, and
  a rank-3 input with a state that has an extra leading axis.
  """
  core = recurrent.LSTM(4)

  with self.assertRaisesRegex(ValueError, "rank-1 or rank-2"):
    core(jnp.zeros([]), core.initial_state(None))

  # Build the rank-expanded state outside the assertion context. Only the
  # core call should be able to satisfy `assertRaisesRegex`, not the setup.
  expanded_state = tree.map_structure(
      lambda x: jnp.expand_dims(x, 0), core.initial_state(1))
  with self.assertRaisesRegex(ValueError, "rank-1 or rank-2"):
    core(jnp.zeros([1, 1, 1]), expanded_state)
def test_batch_major(self, unroll):
  """Time-major and batch-major unrolls produce the same results.

  The same random inputs are fed once as [T, B, F] (time_major=True) and once
  transposed to [B, T, F] (time_major=False). The final states must match
  exactly. The outputs must match after swapping the first two axes of the
  batch-major outputs.

  Args:
    unroll: The unroll function under test (supplied by test parameterization).
  """
  core = recurrent.LSTM(4)
  sequence_len, batch_size = 10, 5
  inputs = np.random.randn(sequence_len, batch_size, 2)
  batch_major_inputs = jnp.swapaxes(inputs, 0, 1)
  initial_state = core.initial_state(batch_size)

  time_major_outputs, time_major_unroll_state_out = unroll(
      core, inputs, initial_state, time_major=True)
  batch_major_outputs, batch_major_unroll_state_out = unroll(
      core, batch_major_inputs, initial_state, time_major=False)

  # `jax.tree_multimap` was deprecated and later removed from JAX.
  # `jax.tree_util.tree_map` accepts multiple trees and is the replacement.
  jax.tree_util.tree_map(np.testing.assert_array_equal,
                         time_major_unroll_state_out,
                         batch_major_unroll_state_out)
  jax.tree_util.tree_map(
      lambda x, y: np.testing.assert_array_equal(x, jnp.swapaxes(y, 0, 1)),
      time_major_outputs, batch_major_outputs)