def test_push_pop_rnn_state(self):
    """Test pushing and popping from a NeuralStackCell unrolled by dynamic_rnn.

    Feeds one batch element with three timesteps (each a [1, 3] one-hot row,
    batch-major since time_major=False) through tf.nn.dynamic_rnn and checks
    the final stack top, memory values, read strengths and write strengths.

    NOTE: this method was previously also named `test_push_pop`. A later
    `test_push_pop` definition in this class rebinds that name, so under the
    old name this test was silently never collected or run.
    """
    # Cell hyper-parameters, in NeuralStackCell's positional argument order.
    num_units = 8
    memory_size = 6
    embedding_size = 3

    # Shape (batch=1, time=3, 1, embedding_size).
    input_values = np.array([[[[1.0, 0.0, 0.0]],
                              [[0.0, 1.0, 0.0]],
                              [[0.0, 0.0, 1.0]]]])

    expected_values = np.array([[[1.0, 0.0, 0.0],
                                 [0.0, 1.0, 0.0],
                                 [0.0, 0.0, 1.0],
                                 [0.0, 0.0, 0.0],
                                 [0.0, 0.0, 0.0],
                                 [0.0, 0.0, 0.0]]])
    expected_read_strengths = np.array([
        [[[1.0], [0.0], [0.0], [0.0], [0.0], [0.0]]]])
    expected_write_strengths = np.array([
        [[[0.0], [0.0], [0.0], [1.0], [0.0], [0.0]]]])
    expected_top = np.array([[[1.0, 0.0, 0.0]]])

    stack = neural_stack.NeuralStackCell(num_units, memory_size, embedding_size)
    stack_input = tf.constant(input_values, dtype=tf.float32)
    (outputs, state) = tf.nn.dynamic_rnn(cell=stack,
                                         inputs=stack_input,
                                         time_major=False,
                                         dtype=tf.float32)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      _, state_vals = sess.run([outputs, state])
      # Final state is a 5-tuple; its first element is not checked here.
      (_, stack_top, values, read_strengths, write_strengths) = state_vals

      self.assertAllClose(expected_top, stack_top)
      self.assertAllClose(expected_values, values)
      self.assertAllClose(expected_read_strengths, read_strengths)
      self.assertAllClose(expected_write_strengths, write_strengths)
  def test_push_pop(self):
    """Test pushing and popping from a NeuralStackCell.

    The sequence of operations is:
      push([1.0, 0.0, 0.0])
      push([0.0, 1.0, 0.0])
      pop()

    Before unrolling, the controller's outputs are also passed to
    assert_controller_shapes together with get_controller_shape(batch_size).
    """
    batch_size = 1
    embedding_size = 3
    memory_size = 6
    num_units = 8

    # Shape (batch=1, time=3, 1, embedding_size).
    sequence = np.array([[[[1.0, 0.0, 0.0]],
                          [[0.0, 1.0, 0.0]],
                          [[0.0, 0.0, 1.0]]]])

    want_memory = np.array([[[1.0, 0.0, 0.0],
                             [0.0, 1.0, 0.0],
                             [0.0, 0.0, 1.0],
                             [0.0, 0.0, 0.0],
                             [0.0, 0.0, 0.0],
                             [0.0, 0.0, 0.0]]])
    want_read = np.array([[[[1.0], [0.0], [0.0], [0.0], [0.0], [0.0]]]])
    want_write = np.array([[[[0.0], [0.0], [0.0], [1.0], [0.0], [0.0]]]])
    want_top = np.array([[[1.0, 0.0, 0.0]]])

    cell = neural_stack.NeuralStackCell(num_units, memory_size, embedding_size)
    cell_input = tf.constant(sequence, dtype=tf.float32)

    # Shape-check the controller on an all-zero controller state first.
    controller_outputs = cell.call_controller(
        None, None, tf.zeros([batch_size, num_units]), batch_size)
    assert_controller_shapes(self, controller_outputs,
                             cell.get_controller_shape(batch_size))

    outputs, state = tf.nn.dynamic_rnn(cell=cell, inputs=cell_input,
                                       time_major=False, dtype=tf.float32)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      _, final_state = sess.run([outputs, state])
      _, top, memory, read_strengths, write_strengths = final_state

      for expected, actual in ((want_memory, memory),
                               (want_write, write_strengths),
                               (want_read, read_strengths),
                               (want_top, top)):
        self.assertAllClose(expected, actual)
  def test_cell_shapes(self):
    """Verify the tensor shapes produced by a single NeuralStackCell step.

    Checks the read-mask shape after build(), that one call() returns an
    output shaped like its input, and hands the resulting next state and the
    zero state to assert_cell_shapes.
    """
    batch_size, embedding_size, memory_size, num_units = 5, 3, 6, 8

    cell = neural_stack.NeuralStackCell(num_units, memory_size, embedding_size)
    cell.build(None)

    read_mask_shape = cell.get_read_mask(0).shape
    self.assertEqual([1, 1, memory_size, memory_size], read_mask_shape)

    step_input = tf.zeros([batch_size, 1, embedding_size], dtype=tf.float32)
    initial_state = cell.zero_state(batch_size, tf.float32)
    step_output, next_state = cell.call(step_input, initial_state)

    # The cell's per-step output must have the same shape as its input.
    self.assertEqual(step_output.shape, step_input.shape)

    assert_cell_shapes(self, next_state, initial_state)
  def test_controller_shapes(self):
    """Verify NeuralStackCell controller and step tensor shapes.

    Checks the shapes of the built cell's attributes, of each zero-state
    component, of every call_controller() output, and of call()'s output and
    next state; finally each next-state component must be shaped like the
    matching zero-state component.
    """
    batch_size = 5
    embedding_size = 3
    memory_size = 6
    num_units = 8

    cell = neural_stack.NeuralStackCell(num_units, memory_size, embedding_size)
    cell.build(None)

    self.assertEqual([1, embedding_size], cell.output_size)
    self.assertEqual([1, memory_size, memory_size], cell.read_mask.shape)
    self.assertEqual([3, 3, 1, 1], cell.write_shift_convolution.shape)

    step_input = tf.zeros([batch_size, 1, embedding_size], dtype=tf.float32)

    initial_state = cell.zero_state(batch_size, tf.float32)
    (init_controller, init_values, init_memory,
     init_read_strengths, init_write_strengths) = initial_state

    controller_shape = [batch_size, num_units]
    value_shape = [batch_size, 1, embedding_size]
    memory_shape = [batch_size, memory_size, embedding_size]
    strength_shape = [batch_size, 1, memory_size, 1]

    for want_shape, tensor in ((controller_shape, init_controller),
                               (value_shape, init_values),
                               (memory_shape, init_memory),
                               (strength_shape, init_read_strengths),
                               (strength_shape, init_write_strengths)):
      self.assertEqual(want_shape, tensor.shape)

    # Controller input: the zero state's previous values alongside the step
    # input, each flattened to [batch_size, embedding_size].
    flat_shape = [batch_size, embedding_size]
    rnn_input = tf.concat([tf.reshape(init_values, shape=flat_shape),
                           tf.reshape(step_input, shape=flat_shape)], axis=1)
    self.assertEqual([batch_size, 2 * embedding_size], rnn_input.shape)

    (push_strengths, pop_strengths, new_values, controller_outputs,
     controller_next) = cell.call_controller(rnn_input, init_controller,
                                             batch_size)

    scalar_shape = [batch_size, 1, 1, 1]
    for want_shape, tensor in ((scalar_shape, push_strengths),
                               (scalar_shape, pop_strengths),
                               (value_shape, new_values),
                               (value_shape, controller_outputs),
                               (controller_shape, controller_next)):
      self.assertEqual(want_shape, tensor.shape)

    (step_output, (next_controller, read_values, next_memory,
                   next_read_strengths, next_write_strengths)) = cell.call(
                       step_input, initial_state)

    for want_shape, tensor in ((value_shape, step_output),
                               (controller_shape, next_controller),
                               (value_shape, read_values),
                               (memory_shape, next_memory),
                               (strength_shape, next_read_strengths),
                               (strength_shape, next_write_strengths)):
      self.assertEqual(want_shape, tensor.shape)

    # Each next-state component must be shaped like its zero-state twin.
    state_pairs = ((next_controller, init_controller),
                   (read_values, init_values),
                   (next_memory, init_memory),
                   (next_read_strengths, init_read_strengths),
                   (next_write_strengths, init_write_strengths))
    for next_tensor, init_tensor in state_pairs:
      self.assertEqual(next_tensor.shape, init_tensor.shape)