def test_rnn_var_node(self):
    """Verify that ``RnnVarNode(step, seq)`` exposes the slice ``seq[:, :, step]``.

    The input is a (3, 1, 2) array holding two time steps along the last
    axis: step 0 is the column [1, 2, 1] and step 1 is [-1, 0, -0.5].
    """
    seq = np.array([[1, 2, 1], [-1, 0, -.5]]).T
    seq = seq.reshape((3, 1, 2))
    # One variable node per time step, created before any comparison is made.
    nodes = [rnn.RnnVarNode(step, seq) for step in range(2)]
    for step, node in enumerate(nodes):
        np.testing.assert_equal(node.value(), seq[:, :, step])
    first_node = nodes[0]
    debug("[SimpleRnnCellTests.test_rnn_var_node()] x0_var.value() = np.{}".format(repr(first_node.value())))
def test_2_seq_rnn(self):
    """Run a two-step RNN forward and backward and sanity-check the results.

    Builds ``cell1`` for time step 0 (no previous cell, initial hidden state
    ``self.h``) and ``cell2`` for time step 1 (fed by ``cell1``). Both cells
    share the same ``w``/``wb``/``u``/``ub`` parameter nodes. The test then
    runs the forward pass from both input nodes, back-propagates ``0.1 * y``
    from ``cell2``, and reads the gradient accumulated on ``self.w_param``.

    Earlier revisions only logged the values and asserted nothing. The checks
    below make the test fail if the forward outputs or the weight gradient
    are missing or non-finite.
    """
    x = np.array([[1, 2, 1], [-1, 0, -.5]]).T
    x = x.reshape((3, 1, 2))
    x0_var = rnn.RnnVarNode(0, x)
    x1_var = rnn.RnnVarNode(1, x)
    cell1 = rnn.RnnCell(x0_var, None, self.w_param, self.wb_param, self.u_param, self.ub_param, self.h)
    cell2 = rnn.RnnCell(x1_var, cell1, self.w_param, self.wb_param, self.u_param, self.ub_param)
    x0_var.forward(self.var_map)
    x1_var.forward(self.var_map)
    y, h = cell2.value()
    debug("[SimpleRnnCellTests.test_2_seq_rnn()] y = np.{}".format(repr(y)))
    debug("[SimpleRnnCellTests.test_2_seq_rnn()] h = np.{}".format(repr(h)))
    assert y is not None and h is not None, "forward pass produced no output"
    assert np.all(np.isfinite(y)), "non-finite y from forward pass"
    assert np.all(np.isfinite(h)), "non-finite h from forward pass"

    dely = y * .1
    # The hidden-state gradient is passed as None. Presumably this is because
    # nothing downstream of cell2 consumes h (confirm against
    # RnnCell.backward). The test case itself is passed as the downstream
    # node, matching the original call.
    cell2.backward((dely, None), self, var_map=self.var_map)
    wgrad = self.w_param._total_incoming_gradient()
    debug("[SimpleRnnCellTests.test_2_seq_rnn()] wgrad = np.{}".format(repr(wgrad)))
    assert wgrad is not None, "no gradient reached w_param"
    assert np.all(np.isfinite(wgrad)), "non-finite gradient on w_param"