示例#1
0
    def test_optimized_lstm_cell_matches_regular(self):
        """Check that OptimizedLSTMCell reproduces LSTMCell exactly.

        Both cells are built from the same PRNG keys, the same (2, 3) input
        and a (2, 4) initial carry. The test asserts that their outputs agree
        to rtol=1e-6 and that their initialized parameter trees are equal.
        """
        base_key = random.PRNGKey(0)
        input_key, init_key = random.split(base_key)
        # JAX random functions are pure, so one shared input is equivalent to
        # regenerating it from the same key for each cell.
        inputs = random.normal(input_key, (2, 3))

        def run_cell(cell_cls):
            """Initialize `cell_cls` on the shared inputs; return (y, params)."""
            c_state, h_state = cell_cls.initialize_carry(base_key, (2, ), 4)
            self.assertEqual(c_state.shape, (2, 4))
            self.assertEqual(h_state.shape, (2, 4))
            cell = cell_cls()
            (_, out), params = cell.init_with_output(
                init_key, (c_state, h_state), inputs)
            return out, params

        y, lstm_params = run_cell(nn.LSTMCell)
        y_opt, lstm_opt_params = run_cell(nn.OptimizedLSTMCell)

        np.testing.assert_allclose(y, y_opt, rtol=1e-6)
        jtu.check_eq(lstm_params, lstm_opt_params)
示例#2
0
 def __call__(self, carry, x):
     """Apply one newly constructed OptimizedLSTMCell to `(carry, x)`.

     Returns whatever the cell returns for this step (it is unpacked as
     `carry, x` elsewhere in this file).
     """
     cell = nn.OptimizedLSTMCell()
     return cell(carry, x)
示例#3
0
文件: lstm.py 项目: google/deluca
 def __call__(self, carry, x):
     """Run `self.n_layers` OptimizedLSTMCell layers one after another.

     Each layer is built with `activation_fn=self.activation_fn`. The carry
     and output produced by one layer are the carry and input of the next.
     If `n_layers` is 0, `(carry, x)` is returned unchanged.

     NOTE(review): a single carry is threaded through every layer instead of
     one carry per layer, as a conventional stacked LSTM would use. Confirm
     this is intended.
     """
     state, out = carry, x
     for _ in range(self.n_layers):
         layer = nn.OptimizedLSTMCell(activation_fn=self.activation_fn)
         state, out = layer(state, out)
     return state, out