def testGetMemoryRegion(self):
        """Testing the function that gathers the flanged memory region.

        The 4x6 image is split into four 2x3 query blocks. With
        memory_flange=(1, 1), the expected flange of each block (as encoded
        by the indices below) holds 7 pixels: the 5 pixels of the row
        directly above the block (one column wider on each side), followed
        by the 2 pixels of the column directly to its left. Pixels that fall
        outside the image are expected to be zeros.
        """
        # NOTE(review): another testGetMemoryRegion definition follows this one;
        # if both end up in the same class, the later one silently replaces
        # this one and only one of them ever runs.
        # Disable numpy's array summarization (presumably so assertion failures
        # show whole arrays). set_printoptions mutates process-global state, so
        # restore the previous options when the test ends instead of leaking
        # them into later tests.
        self.addCleanup(np.set_printoptions, **np.get_printoptions())
        np.set_printoptions(threshold=np.inf)
        batch_size = 2
        num_heads = 2
        height = 4
        width = 6
        depth = 3
        query_shape = (2, 3)
        memory_flange = (1, 1)

        x = np.random.rand(batch_size, num_heads, height, width, depth)
        # Flatten the spatial dims: pixel (row, col) sits at row * width + col.
        y = np.reshape(x, (batch_size, num_heads, -1, depth))

        def padding(num_pixels):
            """Zero pixels expected wherever the flange leaves the image."""
            return np.zeros((num_pixels, depth), dtype=np.float32)

        def expected_flange(pixels):
            """Expected flange of the four query blocks for one (batch, head)."""
            return [
                # Block rows 0-1, cols 0-2: row above and left column are
                # both outside the image.
                padding(7),
                # Block rows 0-1, cols 3-5: row above is outside the image;
                # left column is col 2 of rows 0 and 1.
                np.concatenate((padding(5), pixels[[2, 8]]), axis=0),
                # Block rows 2-3, cols 0-2: row 1 from col -1 (outside) to
                # col 3; left column (col -1) is outside the image.
                np.concatenate((padding(1), pixels[[6, 7, 8, 9]], padding(2)),
                               axis=0),
                # Block rows 2-3, cols 3-5: row 1 from col 2 to col 6
                # (outside); left column is col 2 of rows 2 and 3.
                np.concatenate(
                    (pixels[[8, 9, 10, 11]], padding(1), pixels[[14, 20]]),
                    axis=0),
            ]

        def expected_center(pixels):
            """Expected contents of the four 2x3 query blocks, row-major."""
            return [
                pixels[[0, 1, 2, 6, 7, 8]],
                pixels[[3, 4, 5, 9, 10, 11]],
                pixels[[12, 13, 14, 18, 19, 20]],
                pixels[[15, 16, 17, 21, 22, 23]],
            ]

        correct_x_flange = np.array(
            [[expected_flange(y[b, h]) for h in range(num_heads)]
             for b in range(batch_size)])
        correct_x_center = np.array(
            [[expected_center(y[b, h]) for h in range(num_heads)]
             for b in range(batch_size)])
        with self.test_session() as session:
            x_indices = common_attention.gather_indices_2d(
                x, query_shape, query_shape)
            x_flange, x_center = common_attention.get_memory_region(
                tf.constant(x, dtype=tf.float32), query_shape, memory_flange,
                x_indices)
            session.run(tf.global_variables_initializer())
            [x_flange, x_center] = session.run([x_flange, x_center])
        self.assertAllClose(correct_x_flange, x_flange)
        self.assertAllClose(correct_x_center, x_center)
  def testGetMemoryRegion(self):
    """Testing the function that gathers the flanged memory region.

    The 4x6 image is split into four 2x3 query blocks. With
    memory_flange=(1, 1), the expected flange of each block (as encoded by
    the indices below) holds 7 pixels: the 5 pixels of the row directly above
    the block (one column wider on each side), followed by the 2 pixels of
    the column directly to its left. Pixels that fall outside the image are
    expected to be zeros.
    """
    # Disable numpy's array summarization (presumably so assertion failures
    # show whole arrays). set_printoptions mutates process-global state, so
    # restore the previous options when the test ends instead of leaking them
    # into later tests.
    self.addCleanup(np.set_printoptions, **np.get_printoptions())
    np.set_printoptions(threshold=np.inf)
    batch_size = 2
    num_heads = 2
    height = 4
    width = 6
    depth = 3
    query_shape = (2, 3)
    memory_flange = (1, 1)

    x = np.random.rand(batch_size, num_heads, height, width, depth)
    # Flatten the spatial dims: pixel (row, col) sits at row * width + col.
    y = np.reshape(x, (batch_size, num_heads, -1, depth))

    def padding(num_pixels):
      """Zero pixels expected wherever the flange leaves the image."""
      return np.zeros((num_pixels, depth), dtype=np.float32)

    def expected_flange(pixels):
      """Expected flange of the four query blocks for one (batch, head)."""
      return [
          # Block rows 0-1, cols 0-2: row above and left column are both
          # outside the image.
          padding(7),
          # Block rows 0-1, cols 3-5: row above is outside the image; left
          # column is col 2 of rows 0 and 1.
          np.concatenate((padding(5), pixels[[2, 8]]), axis=0),
          # Block rows 2-3, cols 0-2: row 1 from col -1 (outside) to col 3;
          # left column (col -1) is outside the image.
          np.concatenate((padding(1), pixels[[6, 7, 8, 9]], padding(2)),
                         axis=0),
          # Block rows 2-3, cols 3-5: row 1 from col 2 to col 6 (outside);
          # left column is col 2 of rows 2 and 3.
          np.concatenate((pixels[[8, 9, 10, 11]], padding(1),
                          pixels[[14, 20]]), axis=0),
      ]

    def expected_center(pixels):
      """Expected contents of the four 2x3 query blocks, row-major."""
      return [
          pixels[[0, 1, 2, 6, 7, 8]],
          pixels[[3, 4, 5, 9, 10, 11]],
          pixels[[12, 13, 14, 18, 19, 20]],
          pixels[[15, 16, 17, 21, 22, 23]],
      ]

    correct_x_flange = np.array(
        [[expected_flange(y[b, h]) for h in range(num_heads)]
         for b in range(batch_size)])
    correct_x_center = np.array(
        [[expected_center(y[b, h]) for h in range(num_heads)]
         for b in range(batch_size)])
    with self.test_session() as session:
      x_indices = common_attention.gather_indices_2d(
          x, query_shape, query_shape)
      x_flange, x_center = common_attention.get_memory_region(
          tf.constant(x, dtype=tf.float32),
          query_shape,
          memory_flange,
          x_indices)
      session.run(tf.global_variables_initializer())
      [x_flange, x_center] = session.run([x_flange, x_center])
    self.assertAllClose(correct_x_flange, x_flange)
    self.assertAllClose(correct_x_center, x_center)