Exemplo n.º 1
0
def test_Rnn_shape():
    """Check Rnn/GRUCell parameter shapes and output when all inits are zero.

    Input is an all-zero array of shape ``(2, 5, 4)``. Axis meaning is
    presumably (batch, steps, features) — TODO confirm. The GRU cell has
    carry size 3 and uses the ``zeros`` initializer. The test expects:
    - a single parameter group, ``gru_cell``, holding three kernels;
    - each kernel to be an all-zero ``(features + carry, carry)`` matrix;
    - the output to be an all-zero ``(2, 5, carry)`` array.
    """
    batch, steps, features, carry = 2, 5, 4, 3
    xs = np.zeros((batch, steps, features))
    layer = Rnn(*GRUCell(carry, zeros))
    params = layer.init_parameters(PRNGKey(0), xs)

    assert len(params) == 1
    cell_params = params.gru_cell
    assert len(cell_params) == 3

    # (features + carry, carry) == (7, 3) here. Presumably the cell acts on
    # the input concatenated with the carry — verify against GRUCell.
    expected_kernel = np.zeros((features + carry, carry))
    for kernel_name in ('update_kernel', 'reset_kernel', 'compute_kernel'):
        assert np.array_equal(expected_kernel, getattr(cell_params, kernel_name))

    output = layer.apply(params, xs)
    assert np.array_equal(np.zeros((batch, steps, carry)), output)
Exemplo n.º 2
0
    def rnn():
        """Return a new Rnn layer built from ``GRUCell(carry_size, zeros)``."""
        cell_parts = GRUCell(carry_size, zeros)
        return Rnn(*cell_parts)

    net = Sequential(
Exemplo n.º 3
0
 def rnn():
     """Build an Rnn layer over a GRU cell with a small random-normal init.

     Parameters are drawn with ``random.normal(key, shape)`` and scaled by
     0.01. ``carry_size`` is read from an outer scope not visible here.
     """
     def scaled_normal(key, shape):
         return random.normal(key, shape) * 0.01

     gru = GRUCell(carry_size=carry_size, param_init=scaled_normal)
     return Rnn(*gru)
Exemplo n.º 4
0
 def rnn():
     # Factory: each call builds a fresh GRU cell (carry_size comes from an
     # outer scope, `zeros` is the initializer). The pieces GRUCell returns
     # are unpacked as positional arguments to Rnn.
     components = GRUCell(carry_size, zeros)
     return Rnn(*components)