def test_Rnn_shape():
    """Check parameter and output shapes of an Rnn built from a zero-initialised GRUCell.

    The GRU cell is created with carry size 3 and the ``zeros`` initializer.
    The test asserts the following:
      * the Rnn exposes exactly one parameter group, ``gru_cell``;
      * that group holds three kernels (update, reset, compute), each equal
        to a (7, 3) zero array (presumably 4 input features + 3 carry rows);
      * applying the net to a (2, 5, 4) zero input yields (2, 5, 3) zeros.
    """
    batch_of_sequences = np.zeros((2, 5, 4))
    model = Rnn(*GRUCell(3, zeros))
    params = model.init_parameters(PRNGKey(0), batch_of_sequences)

    assert len(params) == 1
    cell_params = params.gru_cell
    assert len(cell_params) == 3

    # Every gate kernel should be the same all-zero matrix.
    expected_kernel = np.zeros((7, 3))
    for kernel_name in ("update_kernel", "reset_kernel", "compute_kernel"):
        assert np.array_equal(expected_kernel, getattr(cell_params, kernel_name))

    output = model.apply(params, batch_of_sequences)
    assert np.array_equal(np.zeros((2, 5, 3)), output)
def rnn(): return Rnn(*GRUCell(carry_size, zeros)) net = Sequential(
def rnn():
    """Build an Rnn over a GRUCell whose parameters start as small random normals.

    ``carry_size`` is read from the enclosing scope (not visible here).
    Parameters are drawn with ``random.normal`` and scaled by 0.01.
    """
    def small_normal(key, shape):
        # Scaled-down standard normal draw used as the GRU parameter initializer.
        return random.normal(key, shape) * 0.01

    cell = GRUCell(carry_size=carry_size, param_init=small_normal)
    return Rnn(*cell)
def rnn():
    """Build an Rnn over a GRUCell whose parameters are initialised with ``zeros``.

    ``carry_size`` is read from the enclosing scope (not visible here).
    """
    cell_parts = GRUCell(carry_size, zeros)
    # GRUCell's result is unpacked into Rnn's positional arguments.
    return Rnn(*cell_parts)