def test_optimized_lstm_cell_matches_regular(self):
    """OptimizedLSTMCell should match LSTMCell in outputs and initialized params.

    Each cell class is set up from the same PRNGKey(0)-derived input, carry and
    init key. The forward outputs are then compared with rtol=1e-6, and the
    initialized parameter trees are compared with ``jtu.check_eq``.
    """

    def run_cell(cell_cls):
        # Re-derive everything from PRNGKey(0) per cell so both cells are
        # initialized and applied on identical inputs.
        seed = random.PRNGKey(0)
        input_key, init_key = random.split(seed)
        inputs = random.normal(input_key, (2, 3))
        # The unsplit seed is used for the carry, matching the original setup.
        c_init, h_init = cell_cls.initialize_carry(seed, (2, ), 4)
        self.assertEqual(c_init.shape, (2, 4))
        self.assertEqual(h_init.shape, (2, 4))
        cell = cell_cls()
        (_, out), params = cell.init_with_output(
            init_key, (c_init, h_init), inputs)
        return out, params

    y, lstm_params = run_cell(nn.LSTMCell)
    y_opt, lstm_opt_params = run_cell(nn.OptimizedLSTMCell)
    np.testing.assert_allclose(y, y_opt, rtol=1e-6)
    jtu.check_eq(lstm_params, lstm_opt_params)
def __call__(self, carry, x):
    """Apply one OptimizedLSTMCell built with default arguments.

    Args:
      carry: recurrent state, forwarded unchanged to the cell.
      x: input, forwarded unchanged to the cell.

    Returns:
      The cell's result for ``(carry, x)``, presumably a ``(new_carry, output)``
      pair (callers elsewhere unpack it that way). TODO: confirm.
    """
    cell = nn.OptimizedLSTMCell()
    return cell(carry, x)
def __call__(self, carry, x):
    """Run ``self.n_layers`` OptimizedLSTMCells one after another.

    Each iteration builds a fresh cell with ``self.activation_fn``. The
    ``(carry, output)`` pair returned by one cell becomes the ``(carry, input)``
    of the next.

    Returns:
      The ``(carry, x)`` pair produced by the last cell. If ``self.n_layers`` is
      <= 0, the arguments are returned unchanged.
    """
    state, features = carry, x
    for _layer in range(self.n_layers):
        cell = nn.OptimizedLSTMCell(activation_fn=self.activation_fn)
        state, features = cell(state, features)
    return state, features