def test_cell_shapes(self):
    """Check that all the NeuralDequeCell tensor shapes are correct.

    Builds a NeuralDequeCell and checks the shape of both read masks. It then
    runs a single step from the zero state and checks two things: the output
    shape matches the input shape, and the next state's tensor shapes agree
    with the zero state's (via assert_cell_shapes).
    """
    batch_size = 5
    embedding_size = 4
    memory_size = 12
    num_units = 8

    deque = neural_stack.NeuralDequeCell(num_units, memory_size,
                                         embedding_size)
    deque.build(None)

    # Read masks 0 and 1 (presumably one per end of the deque -- confirm
    # against NeuralDequeCell.get_read_mask) must both be
    # [1, 1, memory_size, memory_size].
    self.assertEqual([1, 1, memory_size, memory_size],
                     deque.get_read_mask(0).shape)
    self.assertEqual([1, 1, memory_size, memory_size],
                     deque.get_read_mask(1).shape)

    # A single (time length 1) step of all-zero input from the zero state.
    deque_input = tf.zeros([batch_size, 1, embedding_size],
                           dtype=tf.float32)
    zero_state = deque.zero_state(batch_size, tf.float32)
    outputs, deque_next_state = deque.call(deque_input, zero_state)

    # Make sure that deque output shapes match deque input shapes.
    self.assertEqual(outputs.shape, deque_input.shape)

    # The state produced by one step must have the same structure/shapes as
    # the zero state it started from.
    assert_cell_shapes(self, deque_next_state, zero_state)

  def test_enqueue_dequeue(self):
    """Test enqueueing and dequeueing from a NeuralDequeCell.

    The sequence of operations is:
      enqueue_bottom([1.0, 0.0, 0.0, 0.0])
      enqueue_bottom([0.0, 1.0, 0.0, 0.0])
      enqueue_bottom([0.0, 0.0, 1.0, 0.0])
      enqueue_top([0.0, 0.0, 0.0, 1.0])
      dequeue_top()
      dequeue_top()

    The cell is unrolled with tf.nn.dynamic_rnn over the six input steps.
    The final state's memory, write strengths, read strengths and read values
    are then compared against hand-computed expectations.
    """
    # Shape (batch=1, time=6, 1, embedding=4). The last two (dequeue) steps
    # feed all-zero vectors.
    input_values = np.array([[[[1.0, 0.0, 0.0, 0.0]],
                              [[0.0, 1.0, 0.0, 0.0]],
                              [[0.0, 0.0, 1.0, 0.0]],
                              [[0.0, 0.0, 0.0, 1.0]],
                              [[0.0, 0.0, 0.0, 0.0]],
                              [[0.0, 0.0, 0.0, 0.0]]]])

    # Expected final memory contents: shape (1, memory_size=12, 4).
    expected_values = np.array([[[0.0, 0.0, 0.0, 0.0],
                                 [0.0, 0.0, 0.0, 0.0],
                                 [0.0, 0.0, 0.0, 0.0],
                                 [0.0, 0.0, 1.0, 0.0],
                                 [0.0, 1.0, 0.0, 0.0],
                                 [1.0, 0.0, 0.0, 0.0],
                                 [0.0, 0.0, 0.0, 0.0],
                                 [0.0, 0.0, 0.0, 0.0],
                                 [0.0, 0.0, 0.0, 0.0],
                                 [0.0, 0.0, 0.0, 1.0],
                                 [0.0, 0.0, 0.0, 0.0],
                                 [0.0, 0.0, 0.0, 0.0]]])

    # Shape (1, 1, 12, 1).
    expected_read_strengths = np.array([[[[0.0], [0.0], [0.0], [1.0], [1.0],
                                          [0.0], [0.0], [0.0], [0.0], [0.0],
                                          [0.0], [0.0]]]])

    # Shape (1, 2, 12, 1). The two rows are presumably one per write head
    # (top/bottom) -- confirm against NeuralDequeCell.
    expected_write_strengths = np.array([[[[0.0], [0.0], [0.0], [0.0], [0.0],
                                           [0.0], [0.0], [0.0], [0.0], [0.0],
                                           [0.0], [1.0]],
                                          [[1.0], [0.0], [0.0], [0.0], [0.0],
                                           [0.0], [0.0], [0.0], [0.0], [0.0],
                                           [0.0], [0.0]]]])

    # Shape (1, 2, 4).
    expected_read_values = np.array([[[0.0, 0.0, 1.0, 0.0],
                                      [0.0, 1.0, 0.0, 0.0]]])

    batch_size = input_values.shape[0]
    # Two memory slots per time step.
    memory_size = input_values.shape[1] * 2
    embedding_size = input_values.shape[3]
    num_units = 8

    deque = neural_stack.NeuralDequeCell(num_units, memory_size, embedding_size)
    rnn_input = tf.constant(input_values, dtype=tf.float32)

    # Sanity-check the controller's output shapes against the shapes the cell
    # advertises via get_controller_shape before running the full unroll.
    deque_zero_state = tf.zeros([batch_size, num_units])
    controller_outputs = deque.call_controller(None, None,
                                               deque_zero_state,
                                               batch_size)
    assert_controller_shapes(self, controller_outputs,
                             deque.get_controller_shape(batch_size))

    (outputs, state) = tf.nn.dynamic_rnn(cell=deque,
                                         inputs=rnn_input,
                                         time_major=False,
                                         dtype=tf.float32)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      _, state_vals = sess.run([outputs, state])
      # The first state element is not checked here.
      (_, read_values,
       memory_values,
       read_strengths,
       write_strengths) = state_vals

      self.assertAllClose(expected_values, memory_values)
      self.assertAllClose(expected_write_strengths, write_strengths)
      self.assertAllClose(expected_read_strengths, read_strengths)
      self.assertAllClose(expected_read_values, read_values)