def test_2_seq_rnn(self):
    """Smoke-test two chained RnnCells: forward through both, then backward.

    ``cell1`` is built with no predecessor cell and the initial hidden state
    ``self.h``. ``cell2`` is built with ``cell1`` as its predecessor. Both
    cells share the same ``w``/``wb``/``u``/``ub`` parameters. After the
    forward pass, ``0.1 * y`` from ``cell2`` is backpropagated and the
    gradient accumulated on ``self.w_param`` is logged.

    NOTE(review): this test only logs values via ``debug``. It makes no
    assertions, so it fails only if an exception is raised.
    """
    # Shape (2, 3) -> transpose (3, 2) -> reshape (3, 1, 2). Only indices 0
    # and 1 are wired into the graph below.
    # presumably RnnVarNode(i, x) selects x[i] as one time step -- TODO confirm
    x = np.array([[1, 2, 1], [-1, 0, -.5]]).T
    x = x.reshape((3, 1, 2))
    x0_var = rnn.RnnVarNode(0, x)
    x1_var = rnn.RnnVarNode(1, x)
    cell1 = rnn.RnnCell(x0_var, None, self.w_param, self.wb_param,
                        self.u_param, self.ub_param, self.h)
    # cell2 receives cell1 in place of an explicit initial hidden state.
    cell2 = rnn.RnnCell(x1_var, cell1, self.w_param, self.wb_param,
                        self.u_param, self.ub_param)
    x0_var.forward(self.var_map)
    x1_var.forward(self.var_map)
    y, h = cell2.value()
    debug("[SimpleRnnCellTests.test_2_seq_rnn()] y = np.{}".format(repr(y)))
    debug("[SimpleRnnCellTests.test_2_seq_rnn()] h = np.{}".format(repr(h)))
    # Nothing in this graph consumes cell2's hidden state, so only y gets an
    # upstream gradient. None is passed for the h-gradient slot.
    dely = y * .1
    cell2.backward((dely, None), self, var_map=self.var_map)
    wgrad = self.w_param._total_incoming_gradient()
    debug("[SimpleRnnCellTests.test_2_seq_rnn()] wgrad = np.{}".format(repr(wgrad)))
def test_forward(self):
    """Run a single RnnCell forward and backward on input 'x' and log results.

    The cell is built from a VarNode named 'x', with no predecessor cell and
    the initial hidden state ``self.h``. The input is forwarded through
    ``self.var_map``. Then 0.1 times each output (y and h) is
    backpropagated, and the gradient that reaches the input node is logged.

    NOTE(review): despite its name, this test also exercises ``backward``.
    It makes no assertions and only logs via ``debug``.
    """
    x_node = n.VarNode('x')
    cell = rnn.RnnCell(
        x_node, None,
        self.w_param, self.wb_param,
        self.u_param, self.ub_param,
        self.h,
    )
    x_node.forward(self.var_map)
    out_y, out_h = cell.value()
    debug("[SimpleRnnCellTests.test_forward()] y = np.{}".format(repr(out_y)))
    debug("[SimpleRnnCellTests.test_forward()] h = np.{}".format(repr(out_h)))
    # Upstream gradients for (y, h): each output scaled by 0.1.
    upstream = (out_y * .1, out_h * .1)
    cell.backward(upstream, self, self.var_map)
    debug("[SimpleRnnCellTests.test_forward()] grad_x = np.{}".format(
        repr(x_node.total_incoming_gradient())))