def testGrid2LSTMCellLegacy(self):
  """Checks Grid2LSTMCell with state_is_tuple=False and output_is_tuple=False.

  In this legacy configuration the cell reports a flat state size of 8, the
  state is a single (1, 8) tensor and the output a single (1, 2) tensor.
  """
  expected_output = [[0.95686918, 0.95686918]]
  expected_state = [[
      2.41515064, 2.41515064, 0.95686918, 0.95686918, 1.38917875,
      1.49043763, 0.83884692, 0.86036491
  ]]
  with self.test_session() as sess:
    with variable_scope.variable_scope(
        'root', initializer=init_ops.constant_initializer(0.5)):
      inp = array_ops.zeros([1, 3])
      prev_state = array_ops.zeros([1, 8])
      cell = grid_rnn_cell.Grid2LSTMCell(
          2,
          use_peepholes=True,
          state_is_tuple=False,
          output_is_tuple=False)
      self.assertEqual(cell.state_size, 8)

      output, new_state = cell(inp, prev_state)
      self.assertEqual(output.get_shape(), (1, 2))
      self.assertEqual(new_state.get_shape(), (1, 8))

      sess.run([variables.global_variables_initializer()])
      # The zero tensors built above are overridden with concrete values here.
      feeds = {
          inp: np.array([[1., 1., 1.]]),
          prev_state: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]]),
      }
      output_val, state_val = sess.run([output, new_state], feeds)
      self.assertEqual(output_val.shape, (1, 2))
      self.assertEqual(state_val.shape, (1, 8))
      self.assertAllClose(output_val, expected_output)
      self.assertAllClose(state_val, expected_state)
def testGrid2LSTMCellWithRelu(self):
  """Checks Grid2LSTMCell built with nn_ops.relu as its non-recurrent function.

  With the default tuple state/output, the state is a 1-tuple whose single
  entry exposes `c` and `h` tensors, and the output is indexed as `output[0]`.
  """
  with self.test_session() as sess:
    with variable_scope.variable_scope(
        'root', initializer=init_ops.constant_initializer(0.5)):
      inp = array_ops.zeros([1, 3])
      prev_state = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),)
      cell = grid_rnn_cell.Grid2LSTMCell(
          2, use_peepholes=True, non_recurrent_fn=nn_ops.relu)
      self.assertEqual(cell.state_size, ((2, 2),))

      output, new_state = cell(inp, prev_state)
      self.assertEqual(output[0].get_shape(), (1, 2))
      self.assertEqual(new_state[0].c.get_shape(), (1, 2))
      self.assertEqual(new_state[0].h.get_shape(), (1, 2))

      sess.run([variables.global_variables_initializer()])
      # The nested zero-state tuple is fed as a whole with matching structure.
      feeds = {
          inp: np.array([[1., 1., 1.]]),
          prev_state: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),),
      }
      output_val, state_val = sess.run([output, new_state], feeds)
      self.assertEqual(output_val[0].shape, (1, 2))
      self.assertAllClose(output_val[0], [[2.1831727, 2.1831727]])
      self.assertAllClose(
          state_val,
          (([[0.92270052, 1.02325559]], [[0.66159075, 0.70475441]]),))
def testGrid2LSTMCellTied(self):
  """Checks Grid2LSTMCell constructed with tied=True and peepholes enabled.

  The cell reports two (2, 2) state entries; each entry of the returned state
  exposes `c` and `h` tensors of shape (1, 2).
  """
  with self.test_session(use_gpu=False) as sess:
    with variable_scope.variable_scope(
        'root', initializer=init_ops.constant_initializer(0.5)):
      inp = array_ops.zeros([1, 3])
      # Two (zeros, zeros) pairs, created in the same order as a literal tuple.
      prev_state = tuple(
          (array_ops.zeros([1, 2]), array_ops.zeros([1, 2])) for _ in range(2))
      cell = grid_rnn_cell.Grid2LSTMCell(2, tied=True, use_peepholes=True)
      self.assertEqual(cell.state_size, ((2, 2), (2, 2)))

      output, new_state = cell(inp, prev_state)
      self.assertEqual(output[0].get_shape(), (1, 2))
      for entry in (new_state[0], new_state[1]):
        self.assertEqual(entry.c.get_shape(), (1, 2))
        self.assertEqual(entry.h.get_shape(), (1, 2))

      sess.run([variables.global_variables_initializer()])
      feeds = {
          inp: np.array([[1., 1., 1.]]),
          prev_state: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
                       (np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]]))),
      }
      output_val, state_val = sess.run([output, new_state], feeds)
      self.assertEqual(output_val[0].shape, (1, 2))
      for entry_val in (state_val[0], state_val[1]):
        self.assertEqual(entry_val.c.shape, (1, 2))
        self.assertEqual(entry_val.h.shape, (1, 2))
      self.assertAllClose(output_val[0], [[0.95686918, 0.95686918]])
      self.assertAllClose(
          state_val,
          (([[2.41515064, 2.41515064]], [[0.95686918, 0.95686918]]),
           ([[1.38917875, 1.49043763]], [[0.83884692, 0.86036491]])))
def testGrid2LSTMCellReLUWithRNN(self):
  """Unrolls a ReLU Grid2LSTMCell with static_rnn and checks shapes/values.

  The cell uses the default tuple state/output (same configuration as
  testGrid2LSTMCellWithRelu), so each step's output is a 1-tuple and the
  final state is a 1-tuple whose entry exposes `c` and `h` tensors. It is not
  a flat (batch_size, 4) tensor.
  """
  batch_size = 3
  input_size = 5
  max_length = 6  # unrolled up to this length
  num_units = 2

  with variable_scope.variable_scope(
      'root', initializer=init_ops.constant_initializer(0.5)):
    cell = grid_rnn_cell.Grid2LSTMCell(
        num_units=num_units, non_recurrent_fn=nn_ops.relu)

    # List multiplication repeats ONE placeholder object for every step.
    inputs = max_length * [
        array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
    ]

    outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)

  self.assertEqual(len(outputs), len(inputs))
  # The state is a tuple of per-entry (c, h) states, so check each component.
  self.assertEqual(state[0].c.get_shape(), (batch_size, num_units))
  self.assertEqual(state[0].h.get_shape(), (batch_size, num_units))

  for out, inp in zip(outputs, inputs):
    # Each step output is a 1-tuple of tensors, not a bare tensor.
    self.assertEqual(len(out), 1)
    self.assertEqual(out[0].get_shape()[0], inp.get_shape()[0])
    self.assertEqual(out[0].get_shape()[1], num_units)
    self.assertEqual(out[0].dtype, inp.dtype)

  with self.test_session() as sess:
    sess.run(variables.global_variables_initializer())

    input_value = np.ones((batch_size, input_size))
    # Feeding inputs[0] feeds every step because all entries are one object.
    values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value})
    for tp in values[:-1]:
      for v in tp:
        self.assertTrue(np.all(np.isfinite(v)))
    for tp in values[-1]:
      for st in tp:
        for v in st:
          self.assertTrue(np.all(np.isfinite(v)))
def testGrid2LSTMCellWithRNNAndDynamicBatchSize(self):
  """Regression test for #4296: static_rnn over inputs with a None batch dim.

  Step outputs must keep an unknown (None) leading dimension. Running the
  graph with a concrete batch of 3 must give finite outputs and state.
  """
  input_size = 5
  max_length = 6  # unrolled up to this length
  num_units = 2

  with variable_scope.variable_scope(
      'root', initializer=init_ops.constant_initializer(0.5)):
    cell = grid_rnn_cell.Grid2LSTMCell(num_units=num_units)

    # One placeholder with an unknown batch size, reused for every step.
    inputs = [
        array_ops.placeholder(dtypes.float32, shape=(None, input_size))
    ] * max_length

    outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)

  self.assertEqual(len(outputs), len(inputs))
  for step_out, step_in in zip(outputs, inputs):
    self.assertEqual(len(step_out), 1)
    out_shape = step_out[0].get_shape()
    self.assertIsNone(out_shape[0].value)
    self.assertEqual(out_shape[1], num_units)
    self.assertEqual(step_out[0].dtype, step_in.dtype)

  with self.test_session() as sess:
    sess.run(variables.global_variables_initializer())

    # Feeding inputs[0] feeds every step since all entries are one object.
    batch = np.ones((3, input_size))
    values = sess.run(outputs + [state], feed_dict={inputs[0]: batch})
    step_values, state_value = values[:-1], values[-1]
    for step_value in step_values:
      for arr in step_value:
        self.assertTrue(np.all(np.isfinite(arr)))
    for entry in state_value:
      for component in entry:
        for row in component:
          self.assertTrue(np.all(np.isfinite(row)))