def test_vanilla_rnn_forward(self):
    """Check ``rnn_forward`` against precomputed reference hidden states.

    Builds small deterministic inputs with ``np.linspace`` (N=2, T=3, D=4,
    H=5), runs the forward pass, and requires the returned hidden states to
    match the reference values with a relative error below 1e-5.
    """
    print('Checking Vanilla RNN: forward')
    N, T, D, H = 2, 3, 4, 5

    # Evenly spaced (not random) inputs so the reference values below are
    # reproducible without seeding.
    x = np.linspace(-0.1, 0.3, num=N * T * D).reshape(N, T, D)
    h0 = np.linspace(-0.3, 0.1, num=N * H).reshape(N, H)
    Wx = np.linspace(-0.2, 0.4, num=D * H).reshape(D, H)
    Wh = np.linspace(-0.4, 0.1, num=H * H).reshape(H, H)
    b = np.linspace(-0.7, 0.1, num=H)

    h, _ = rnn_forward(x, h0, Wx, Wh, b)

    # Reference hidden states; the literal has shape (N, T, H) = (2, 3, 5).
    expected_h = np.asarray([
        [
            [-0.42070749, -0.27279261, -0.11074945, 0.05740409, 0.22236251],
            [-0.39525808, -0.22554661, -0.0409454, 0.14649412, 0.32397316],
            [-0.42305111, -0.24223728, -0.04287027, 0.15997045, 0.35014525],
        ],
        [
            [-0.55857474, -0.39065825, -0.19198182, 0.02378408, 0.23735671],
            [-0.27150199, -0.07088804, 0.13562939, 0.33099728, 0.50158768],
            [-0.51014825, -0.30524429, -0.06755202, 0.17806392, 0.40333043],
        ],
    ])

    # assertLess (instead of assertTrue(err < tol)) includes the measured
    # relative error in the failure output, not just the custom message.
    self.assertLess(
        self.rel_error(expected_h, h), 1e-5,
        "rnn_forward relative error should be less than 1e-5")
def test_vanilla_rnn_backward(self):
    """Gradient-check ``rnn_backward`` against numerical gradients.

    Runs ``rnn_forward`` on seeded random inputs, backpropagates a random
    upstream gradient ``dout`` through ``rnn_backward``, and compares each
    analytic gradient (dx, dh0, dWx, dWh, db) with a numerical estimate from
    ``self.eval_numerical_gradient_array``. Every relative error must be
    below 1e-5.
    """
    print('Checking Vanilla RNN: backward')
    # Fixed seed keeps every random draw below (inputs and dout) reproducible.
    np.random.seed(231)
    N, D, T, H = 2, 3, 10, 5

    x = np.random.randn(N, T, D)
    h0 = np.random.randn(N, H)
    Wx = np.random.randn(D, H)
    Wh = np.random.randn(H, H)
    b = np.random.randn(H)

    out, cache = rnn_forward(x, h0, Wx, Wh, b)
    dout = np.random.randn(*out.shape)
    dx, dh0, dWx, dWh, db = rnn_backward(dout, cache)

    # Each closure re-runs the forward pass with exactly one input supplied
    # by the caller, so the numerical-gradient helper can perturb that input
    # alone while the others stay fixed.
    def fx(x):
        return rnn_forward(x, h0, Wx, Wh, b)[0]

    def fh0(h0):
        return rnn_forward(x, h0, Wx, Wh, b)[0]

    def fWx(Wx):
        return rnn_forward(x, h0, Wx, Wh, b)[0]

    def fWh(Wh):
        return rnn_forward(x, h0, Wx, Wh, b)[0]

    def fb(b):
        return rnn_forward(x, h0, Wx, Wh, b)[0]

    # (input name, forward pass as a function of that input, the input array,
    # analytic gradient) -- checked in the same order as before.
    checks = [
        ('x', fx, x, dx),
        ('h0', fh0, h0, dh0),
        ('Wx', fWx, Wx, dWx),
        ('Wh', fWh, Wh, dWh),
        ('b', fb, b, db),
    ]
    for name, forward, param, analytic in checks:
        numerical = self.eval_numerical_gradient_array(forward, param, dout)
        # assertLess reports the measured relative error on failure.
        self.assertLess(
            self.rel_error(numerical, analytic), 1e-5,
            f"d{name} relative error should be less than 1e-5")
def fb(b):
    """Run ``rnn_forward`` with the given ``b`` and return element 0 of its result.

    NOTE(review): ``x``, ``h0``, ``Wx``, ``Wh`` and ``rnn_forward`` are read
    from module scope and are not defined at module level in the reviewed
    chunk. This function is identical to the closure of the same name inside
    ``test_vanilla_rnn_backward``; confirm whether this module-level copy is
    still needed.
    """
    forward_result = rnn_forward(x, h0, Wx, Wh, b)
    return forward_result[0]
def fWh(Wh):
    """Run ``rnn_forward`` with the given ``Wh`` and return element 0 of its result.

    NOTE(review): ``x``, ``h0``, ``Wx``, ``b`` and ``rnn_forward`` are read
    from module scope and are not defined at module level in the reviewed
    chunk. This function is identical to the closure of the same name inside
    ``test_vanilla_rnn_backward``; confirm whether this module-level copy is
    still needed.
    """
    forward_result = rnn_forward(x, h0, Wx, Wh, b)
    return forward_result[0]
def fWx(Wx):
    """Run ``rnn_forward`` with the given ``Wx`` and return element 0 of its result.

    NOTE(review): ``x``, ``h0``, ``Wh``, ``b`` and ``rnn_forward`` are read
    from module scope and are not defined at module level in the reviewed
    chunk. This function is identical to the closure of the same name inside
    ``test_vanilla_rnn_backward``; confirm whether this module-level copy is
    still needed.
    """
    forward_result = rnn_forward(x, h0, Wx, Wh, b)
    return forward_result[0]
def fh0(h0):
    """Run ``rnn_forward`` with the given ``h0`` and return element 0 of its result.

    NOTE(review): ``x``, ``Wx``, ``Wh``, ``b`` and ``rnn_forward`` are read
    from module scope and are not defined at module level in the reviewed
    chunk. This function is identical to the closure of the same name inside
    ``test_vanilla_rnn_backward``; confirm whether this module-level copy is
    still needed.
    """
    forward_result = rnn_forward(x, h0, Wx, Wh, b)
    return forward_result[0]
def fx(x):
    """Run ``rnn_forward`` with the given ``x`` and return element 0 of its result.

    NOTE(review): ``h0``, ``Wx``, ``Wh``, ``b`` and ``rnn_forward`` are read
    from module scope and are not defined at module level in the reviewed
    chunk. This function is identical to the closure of the same name inside
    ``test_vanilla_rnn_backward``; confirm whether this module-level copy is
    still needed.
    """
    forward_result = rnn_forward(x, h0, Wx, Wh, b)
    return forward_result[0]