コード例 #1
0
ファイル: test_modules.py プロジェクト: QUELUCIFER/jaxnet
def test_Rnn_shape():
    inputs = np.zeros((2, 5, 4))
    rnn = Rnn(*GRUCell(3, zeros))
    params = rnn.init_parameters(PRNGKey(0), inputs)

    assert len(params) == 1
    assert len(params.gru_cell) == 3
    assert np.array_equal(np.zeros((7, 3)), params.gru_cell.update_kernel)
    assert np.array_equal(np.zeros((7, 3)), params.gru_cell.reset_kernel)
    assert np.array_equal(np.zeros((7, 3)), params.gru_cell.compute_kernel)

    out = rnn.apply(params, inputs)
    assert np.array_equal(np.zeros((2, 5, 3)), out)
コード例 #2
0
ファイル: test_examples.py プロジェクト: tom-bird/jaxnet
    def rnn(): return Rnn(*GRUCell(carry_size, zeros))

    net = Sequential(
コード例 #3
0
 def rnn():
     return Rnn(*GRUCell(carry_size=carry_size,
                         param_init=lambda key, shape: random.normal(key, shape) * 0.01))
コード例 #4
0
ファイル: test_examples.py プロジェクト: juliuskunze/jaxnet
 def rnn():
     return Rnn(*GRUCell(carry_size, zeros))