def test_push_pop(self):
    """Test pushing and popping from a NeuralStackCell.

    Feeds a three-step input sequence (step t carries row t of the 3x3
    identity) through the cell with ``tf.nn.dynamic_rnn`` and compares the
    final state's stack top, memory values, read strengths and write
    strengths against fixed expectations.
    """
    # NOTE(review): another ``test_push_pop`` is defined right after this one
    # in this chunk. If both live in the same test class, the later
    # definition replaces this one and this body is never collected; consider
    # renaming or removing one of them.
    num_units, memory_size, embedding_size = 8, 6, 3

    def one_hot_strengths(index):
        # A [1, 1, memory_size, 1] array that is 0.0 everywhere except 1.0
        # at memory slot `index`.
        strengths = np.zeros([1, 1, memory_size, 1])
        strengths[0, 0, index, 0] = 1.0
        return strengths

    # Layout [batch=1, time=3, 1, embedding_size] (time_major=False below).
    input_values = np.eye(embedding_size).reshape([1, 3, 1, embedding_size])
    expected_values = np.array([[
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 1.0],
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0],
    ]])
    expected_read_strengths = one_hot_strengths(0)
    expected_write_strengths = one_hot_strengths(3)
    expected_top = np.array([[[1.0, 0.0, 0.0]]])

    stack = neural_stack.NeuralStackCell(num_units, memory_size, embedding_size)
    outputs, state = tf.nn.dynamic_rnn(
        cell=stack,
        inputs=tf.constant(input_values, dtype=tf.float32),
        time_major=False,
        dtype=tf.float32)

    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        _, final_state = sess.run([outputs, state])
        # The first element of the final state is not checked here.
        _, stack_top, values, read_strengths, write_strengths = final_state
        self.assertAllClose(expected_top, stack_top)
        self.assertAllClose(expected_values, values)
        self.assertAllClose(expected_read_strengths, read_strengths)
        self.assertAllClose(expected_write_strengths, write_strengths)
def test_push_pop(self):
    """Test pushing and popping from a NeuralStackCell.

    Intended sequence of operations (per the original description):
      push([1.0, 0.0, 0.0])
      push([0.0, 1.0, 0.0])
      pop()

    First calls ``call_controller`` on an all-zero controller state and
    checks its outputs with ``assert_controller_shapes`` against
    ``get_controller_shape``. It then runs the input sequence through
    ``tf.nn.dynamic_rnn`` and compares the final state with fixed
    expectations.
    """
    # NOTE(review): the third input step is [0.0, 0.0, 1.0], and the expected
    # memory below holds all three input rows with the write strength at
    # slot 3. That reads more like three pushes than push/push/pop; confirm
    # the intended sequence against NeuralStackCell.
    batch_size = 1
    embedding_size = 3
    memory_size = 6
    num_units = 8

    def one_hot_strengths(index):
        # A [1, 1, memory_size, 1] array that is 0.0 everywhere except 1.0
        # at memory slot `index`.
        strengths = np.zeros([1, 1, memory_size, 1])
        strengths[0, 0, index, 0] = 1.0
        return strengths

    # Layout [batch, time=3, 1, embedding_size]; step t is row t of I_3.
    input_values = np.eye(embedding_size).reshape(
        [batch_size, 3, 1, embedding_size])
    expected_values = np.array([[
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 1.0],
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0],
    ]])
    expected_read_strengths = one_hot_strengths(0)
    expected_write_strengths = one_hot_strengths(3)
    expected_top = np.array([[[1.0, 0.0, 0.0]]])

    stack = neural_stack.NeuralStackCell(num_units, memory_size, embedding_size)
    stack_input = tf.constant(input_values, dtype=tf.float32)

    # Run the controller once on an all-zero controller state.
    controller_outputs = stack.call_controller(
        None, None, tf.zeros([batch_size, num_units]), batch_size)
    assert_controller_shapes(
        self, controller_outputs, stack.get_controller_shape(batch_size))

    outputs, state = tf.nn.dynamic_rnn(
        cell=stack, inputs=stack_input, time_major=False, dtype=tf.float32)

    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        _, (_, stack_top, values, read_strengths, write_strengths) = sess.run(
            [outputs, state])
        self.assertAllClose(expected_values, values)
        self.assertAllClose(expected_write_strengths, write_strengths)
        self.assertAllClose(expected_read_strengths, read_strengths)
        self.assertAllClose(expected_top, stack_top)
def test_cell_shapes(self):
    """Check that NeuralStackCell produces correctly shaped tensors.

    Builds the cell and checks the read-mask shape. It then runs one
    ``call`` on an all-zero input with the zero state, and verifies that:
    - the step output has the same shape as the input;
    - the next state passes ``assert_cell_shapes`` against the zero state.
    """
    batch_size, embedding_size, memory_size, num_units = 5, 3, 6, 8

    stack = neural_stack.NeuralStackCell(num_units, memory_size, embedding_size)
    stack.build(None)
    read_mask_shape = stack.get_read_mask(0).shape
    self.assertEqual([1, 1, memory_size, memory_size], read_mask_shape)

    step_input = tf.zeros([batch_size, 1, embedding_size], dtype=tf.float32)
    initial_state = stack.zero_state(batch_size, tf.float32)
    step_output, next_state = stack.call(step_input, initial_state)

    # The per-step output must have the same shape as the per-step input.
    self.assertEqual(step_output.shape, step_input.shape)
    assert_cell_shapes(self, next_state, initial_state)
def test_controller_shapes(self):
    """Check tensor shapes of the NeuralStackCell controller and single step.

    Verifies the shapes of:
    - the built cell attributes (``output_size``, ``read_mask``,
      ``write_shift_convolution``);
    - each zero-state component;
    - the controller input assembled from that state;
    - the five outputs of ``call_controller``;
    - the output and next state of one ``call``.
    Finally checks that every next-state component has the same shape as
    the corresponding zero-state component.
    """
    batch_size = 5
    embedding_size = 3
    memory_size = 6
    num_units = 8
    controller_shape = [batch_size, num_units]
    slot_shape = [batch_size, 1, embedding_size]
    memory_shape = [batch_size, memory_size, embedding_size]
    strengths_shape = [batch_size, 1, memory_size, 1]
    scalar_gate_shape = [batch_size, 1, 1, 1]

    def assert_shapes(pairs):
        # Assert each (expected_shape, tensor) pair, in order.
        for expected_shape, tensor in pairs:
            self.assertEqual(expected_shape, tensor.shape)

    stack = neural_stack.NeuralStackCell(num_units, memory_size, embedding_size)
    stack.build(None)
    self.assertEqual([1, embedding_size], stack.output_size)
    self.assertEqual([1, memory_size, memory_size], stack.read_mask.shape)
    self.assertEqual([3, 3, 1, 1], stack.write_shift_convolution.shape)

    stack_input = tf.zeros(slot_shape, dtype=tf.float32)
    zero_state = stack.zero_state(batch_size, tf.float32)
    (controller_state, previous_values, memory_values,
     read_strengths, write_strengths) = zero_state
    assert_shapes([
        (controller_shape, controller_state),
        (slot_shape, previous_values),
        (memory_shape, memory_values),
        (strengths_shape, read_strengths),
        (strengths_shape, write_strengths),
    ])

    # Controller input: previous read values and the current input, each
    # flattened to [batch_size, embedding_size], concatenated on axis 1.
    flat_shape = [batch_size, embedding_size]
    rnn_input = tf.concat(
        [tf.reshape(t, shape=flat_shape) for t in (previous_values, stack_input)],
        axis=1)
    self.assertEqual([batch_size, 2 * embedding_size], rnn_input.shape)

    (push_strengths, pop_strengths, new_values, controller_outputs,
     controller_next_state) = stack.call_controller(
         rnn_input, controller_state, batch_size)
    assert_shapes([
        (scalar_gate_shape, push_strengths),
        (scalar_gate_shape, pop_strengths),
        (slot_shape, new_values),
        (slot_shape, controller_outputs),
        (controller_shape, controller_next_state),
    ])

    step_outputs, next_state = stack.call(stack_input, zero_state)
    (next_controller_state, read_values, next_memory_values,
     next_read_strengths, next_write_strengths) = next_state
    assert_shapes([
        (slot_shape, step_outputs),
        (controller_shape, next_controller_state),
        (slot_shape, read_values),
        (memory_shape, next_memory_values),
        (strengths_shape, next_read_strengths),
        (strengths_shape, next_write_strengths),
    ])

    # Every component of the state returned by one step must have the same
    # shape as the corresponding component of the state it was given.
    for next_tensor, initial_tensor in (
        (next_controller_state, controller_state),
        (read_values, previous_values),
        (next_memory_values, memory_values),
        (next_read_strengths, read_strengths),
        (next_write_strengths, write_strengths),
    ):
        self.assertEqual(next_tensor.shape, initial_tensor.shape)