def testGrid2BasicRNNCellTied(self):
    """Check a Grid2BasicRNNCell built with ``tied=True`` against fixed values.

    The cell has 2 units and is created under the 'root' variable scope
    with a constant initializer of 0.5. The test verifies the reported
    state size, the static shapes of the output and both state entries,
    and then the run-time shapes and numeric values for a two-row batch.
    """
    # In the reference values the first state entry is identical to the
    # cell output, so a single literal serves both comparisons.
    expected_output = [[0.94685763, 0.94685763], [0.99480951, 0.99480951]]
    expected_second_state = [[0.80049908, 0.80049908],
                             [0.97574311, 0.97574311]]
    batch_shape = (2, 2)

    with self.test_session() as sess, variable_scope.variable_scope(
            'root', initializer=init_ops.constant_initializer(0.5)):
        inputs = array_ops.zeros([2, 2])
        prev_state = (array_ops.zeros([2, 2]), array_ops.zeros([2, 2]))

        cell = grid_rnn_cell.Grid2BasicRNNCell(2, tied=True)
        self.assertEqual(cell.state_size, (2, 2))

        outputs, next_state = cell(inputs, prev_state)
        for tensor in (outputs[0], next_state[0], next_state[1]):
            self.assertEqual(tensor.get_shape(), batch_shape)

        sess.run([variables.global_variables_initializer()])

        # Both state entries are fed the same values; the tuple key is
        # matched against a tuple of arrays of the same structure.
        state_value = np.array([[0.1, 0.1], [0.2, 0.2]])
        feed = {
            inputs: np.array([[1., 1.], [2., 2.]]),
            prev_state: (state_value, state_value),
        }
        out_vals, state_vals = sess.run([outputs, next_state], feed)

        for array in (out_vals[0], state_vals[0], state_vals[1]):
            self.assertEqual(array.shape, batch_shape)

        self.assertAllClose(out_vals, (expected_output, ))
        self.assertAllClose(
            state_vals, (expected_output, expected_second_state))
def testGrid2BasicRNNCellWithRelu(self):
    """Check a Grid2BasicRNNCell that uses ReLU as ``non_recurrent_fn``.

    The cell has 2 units and is created under the 'root' variable scope
    with a constant initializer of 0.5. The test verifies the reported
    state size (a single entry), the static shapes of the output and the
    state, and then the run-time shapes and numeric values for one row.
    """
    expected_output = [[1.80049896, 1.80049896]]
    expected_state = [[0.80049896, 0.80049896]]
    row_shape = (1, 2)

    with self.test_session() as sess, variable_scope.variable_scope(
            'root', initializer=init_ops.constant_initializer(0.5)):
        inputs = array_ops.zeros([1, 2])
        prev_state = (array_ops.zeros([1, 2]), )

        cell = grid_rnn_cell.Grid2BasicRNNCell(
            2, non_recurrent_fn=nn_ops.relu)
        self.assertEqual(cell.state_size, (2, ))

        outputs, next_state = cell(inputs, prev_state)
        for tensor in (outputs[0], next_state[0]):
            self.assertEqual(tensor.get_shape(), row_shape)

        sess.run([variables.global_variables_initializer()])

        # NOTE(review): the one-element state tuple is fed a bare array
        # rather than a one-element tuple; this presumably relies on how
        # the session flattens nested feed keys -- confirm before changing.
        feed = {
            inputs: np.array([[1., 1.]]),
            prev_state: np.array([[0.1, 0.1]]),
        }
        out_vals, state_vals = sess.run([outputs, next_state], feed)

        for array in (out_vals[0], state_vals[0]):
            self.assertEqual(array.shape, row_shape)

        self.assertAllClose(out_vals, (expected_output, ))
        self.assertAllClose(state_vals, (expected_state, ))