Esempio n. 1
0
 def test2dGatherAndScatterInvertibility(self):
   """Scattering gathered 2d blocks back must reproduce the input.

   A random array of shape (batch, heads, height, width, depth) is split
   into (2, 3) blocks with gather_indices_2d / gather_blocks_2d and written
   back with scatter_blocks_2d; the round trip must be lossless.
   """
   full_shape = (2, 2, 4, 6, 8)  # (batch, heads, height, width, depth)
   block_shape = (2, 3)
   original = np.random.rand(*full_shape)
   block_idx = common_attention.gather_indices_2d(
       original, block_shape, block_shape)
   blocks = common_attention.gather_blocks_2d(original, block_idx)
   restored = self.evaluate(
       common_attention.scatter_blocks_2d(
           blocks, block_idx, tf.constant(list(full_shape))))
   self.assertAllClose(original, restored)
 def test2dGatherAndScatterInvertibility(self):
   """Gathering then scattering 2d blocks must be an identity map."""
   batch_size, num_heads, height, width, depth = 2, 2, 4, 6, 8
   block = (2, 3)
   shape = [batch_size, num_heads, height, width, depth]
   inputs = np.random.rand(*shape)
   with self.test_session() as sess:
     indices = common_attention.gather_indices_2d(inputs, block, block)
     # Arguments are evaluated left to right, so the graph ops are created
     # in the same order as gather -> shape constant -> scatter.
     scattered = common_attention.scatter_blocks_2d(
         common_attention.gather_blocks_2d(inputs, indices),
         indices,
         tf.constant(shape))
     sess.run(tf.global_variables_initializer())
     output = sess.run(scattered)
   self.assertAllClose(inputs, output)
 def test2dGatherAndScatterInvertibility(self):
   """Checks that scatter_blocks_2d inverts gather_blocks_2d.

   A random (batch, heads, height, width, depth) array is cut into (2, 3)
   blocks and scattered back; the output must match the input.
   """
   query_shape = (2, 3)
   x_dims = (2, 2, 4, 6, 8)
   x = np.random.rand(*x_dims)
   with self.test_session() as session:
     indices = common_attention.gather_indices_2d(x, query_shape, query_shape)
     gathered = common_attention.gather_blocks_2d(x, indices)
     target_shape = tf.constant(list(x_dims))
     round_trip = common_attention.scatter_blocks_2d(
         gathered, indices, target_shape)
     session.run(tf.global_variables_initializer())
     result = session.run(round_trip)
   self.assertAllClose(x, result)
    def test2dGather(self):
        """Checks 2d block indices and the blocks gathered with them.

        On a 4x6 spatial grid with (2, 3) query blocks, the flattened
        positions of each block are listed explicitly, and every
        (batch, head) slice must gather exactly those rows.
        """
        batch_size, num_heads = 2, 2
        height, width, depth = 4, 6, 8
        query_shape = (2, 3)
        x = np.random.rand(batch_size, num_heads, height, width, depth)
        # Flatten the spatial dims so each expected block is a row lookup.
        flat_x = np.reshape(x, (batch_size, num_heads, -1, depth))
        expected_indices = [
            [0, 1, 2, 6, 7, 8],
            [3, 4, 5, 9, 10, 11],
            [12, 13, 14, 18, 19, 20],
            [15, 16, 17, 21, 22, 23],
        ]
        expected_blocks = [
            [[flat_x[b, h, block] for block in expected_indices]
             for h in range(num_heads)]
            for b in range(batch_size)
        ]

        with self.test_session() as session:
            indices_t = common_attention.gather_indices_2d(
                x, query_shape, query_shape)
            blocks_t = common_attention.gather_blocks_2d(x, indices_t)
            got_indices, got_blocks = session.run([indices_t, blocks_t])
        self.assertAllEqual(expected_indices, got_indices)
        self.assertAllClose(expected_blocks, got_blocks)
  def test2dGather(self):
    """Compares 2d block indices and gathered blocks to hand-built values."""
    batch_size = 2
    num_heads = 2
    height, width, depth = 4, 6, 8
    query_shape = (2, 3)
    x = np.random.rand(batch_size, num_heads, height, width, depth)
    # View the (height, width) grid as a single flattened position axis.
    x_flat = np.reshape(x, (batch_size, num_heads, -1, depth))
    # Flattened positions covered by each of the four (2, 3) blocks.
    want_indices = [[0, 1, 2, 6, 7, 8],
                    [3, 4, 5, 9, 10, 11],
                    [12, 13, 14, 18, 19, 20],
                    [15, 16, 17, 21, 22, 23]]
    want_gathered = []
    for b in range(batch_size):
      per_head = []
      for h in range(num_heads):
        per_head.append([x_flat[b, h, idx] for idx in want_indices])
      want_gathered.append(per_head)

    with self.test_session() as session:
      idx_op = common_attention.gather_indices_2d(x, query_shape, query_shape)
      blocks_op = common_attention.gather_blocks_2d(x, idx_op)
      got_idx, got_blocks = session.run([idx_op, blocks_op])
    self.assertAllEqual(want_indices, got_idx)
    self.assertAllClose(want_gathered, got_blocks)
    def testGetShiftedCenterBlocks(self):
        """Testing get_shifted_center_blocks on 2d query blocks.

        As encoded in the expected values below, each gathered block is the
        block's own flattened positions shifted right by one: a zero row
        first, followed by every position of the block except the last.
        """
        batch_size = 2
        num_heads = 2
        height = 4
        width = 6
        depth = 3
        query_shape = (2, 3)

        x = np.random.rand(batch_size, num_heads, height, width, depth)
        y = np.reshape(x, (batch_size, num_heads, -1, depth))
        zeros = np.zeros((1, depth), dtype=np.float32)
        # Flattened positions of each (2, 3) query block in the 4x6 grid.
        block_indices = [[0, 1, 2, 6, 7, 8],
                         [3, 4, 5, 9, 10, 11],
                         [12, 13, 14, 18, 19, 20],
                         [15, 16, 17, 21, 22, 23]]
        correct_gathered_x = np.array([
            [[np.concatenate((zeros, y[b, h, idx[:-1]]), axis=0)
              for idx in block_indices]
             for h in range(num_heads)]
            for b in range(batch_size)])
        with self.test_session() as session:
            x_indices = common_attention.gather_indices_2d(
                x, query_shape, query_shape)
            gathered_x = common_attention.get_shifted_center_blocks(
                tf.constant(x, dtype=tf.float32), x_indices)
            session.run(tf.global_variables_initializer())
            gathered_x = session.run(gathered_x)
        self.assertAllClose(correct_gathered_x, gathered_x)
    def testGetMemoryRegion(self):
        """Testing the function that gathers the flanged memory region.

        get_memory_region returns (flange, center) for each (2, 3) query
        block of a 4x6 grid with memory_flange=(1, 1). As encoded in the
        expected values below, each block's flange is the 5 positions of
        the row above it (widened by one column on each side) followed by
        the 2 positions of the column to its left, with zero rows for
        positions outside the grid. The center is the block's own 6
        positions.
        """
        batch_size = 2
        num_heads = 2
        height = 4
        width = 6
        depth = 3
        query_shape = (2, 3)
        memory_flange = (1, 1)

        x = np.random.rand(batch_size, num_heads, height, width, depth)
        y = np.reshape(x, (batch_size, num_heads, -1, depth))
        zeros = np.zeros((1, depth), dtype=np.float32)
        two_zeros = np.zeros((2, depth), dtype=np.float32)
        five_zeros = np.zeros((5, depth), dtype=np.float32)
        seven_zeros = np.zeros((7, depth), dtype=np.float32)

        def expected_flanges(b, h):
            """Expected flange rows of the four query blocks of y[b, h]."""
            yb = y[b, h]
            return [
                seven_zeros,
                np.concatenate((five_zeros, yb[[2, 8]]), axis=0),
                np.concatenate((zeros, yb[[6, 7, 8, 9]], two_zeros), axis=0),
                np.concatenate((yb[[8, 9, 10, 11]], zeros, yb[[14, 20]]),
                               axis=0),
            ]

        # Flattened positions of each query block, i.e. its center.
        block_indices = [[0, 1, 2, 6, 7, 8],
                         [3, 4, 5, 9, 10, 11],
                         [12, 13, 14, 18, 19, 20],
                         [15, 16, 17, 21, 22, 23]]
        correct_x_flange = np.array(
            [[expected_flanges(b, h) for h in range(num_heads)]
             for b in range(batch_size)])
        correct_x_center = np.array(
            [[[y[b, h, idx] for idx in block_indices]
              for h in range(num_heads)]
             for b in range(batch_size)])
        with self.test_session() as session:
            x_indices = common_attention.gather_indices_2d(
                x, query_shape, query_shape)
            x_flange, x_center = common_attention.get_memory_region(
                tf.constant(x, dtype=tf.float32), query_shape, memory_flange,
                x_indices)
            session.run(tf.global_variables_initializer())
            [x_flange, x_center] = session.run([x_flange, x_center])
        self.assertAllClose(correct_x_flange, x_flange)
        self.assertAllClose(correct_x_center, x_center)
  def testGetShiftedCenterBlocks(self):
    """Testing get_shifted_center_blocks on 2d query blocks.

    As encoded in the expected values below, each gathered block is the
    block's own flattened positions shifted right by one: a zero row first,
    followed by every position of the block except the last.
    """
    batch_size = 2
    num_heads = 2
    height = 4
    width = 6
    depth = 3
    query_shape = (2, 3)

    x = np.random.rand(batch_size, num_heads, height, width, depth)
    y = np.reshape(x, (batch_size, num_heads, -1, depth))
    zeros = np.zeros((1, depth), dtype=np.float32)
    # Flattened positions of each (2, 3) query block in the 4x6 grid.
    block_indices = [[0, 1, 2, 6, 7, 8],
                     [3, 4, 5, 9, 10, 11],
                     [12, 13, 14, 18, 19, 20],
                     [15, 16, 17, 21, 22, 23]]
    correct_gathered_x = np.array(
        [[[np.concatenate((zeros, y[b, h, idx[:-1]]), axis=0)
           for idx in block_indices]
          for h in range(num_heads)]
         for b in range(batch_size)])
    with self.test_session() as session:
      x_indices = common_attention.gather_indices_2d(
          x, query_shape, query_shape)
      gathered_x = common_attention.get_shifted_center_blocks(
          tf.constant(x, dtype=tf.float32),
          x_indices)
      session.run(tf.global_variables_initializer())
      gathered_x = session.run(gathered_x)
    self.assertAllClose(correct_gathered_x, gathered_x)
  def testGetMemoryRegion(self):
    """Testing the function that gathers the flanged memory region.

    get_memory_region returns (flange, center) for each (2, 3) query block
    of a 4x6 grid with memory_flange=(1, 1). As encoded in the expected
    values below, each block's flange is the 5 positions of the row above
    it (widened by one column on each side) followed by the 2 positions of
    the column to its left, with zero rows for positions outside the grid.
    The center is the block's own 6 positions.
    """
    batch_size = 2
    num_heads = 2
    height = 4
    width = 6
    depth = 3
    query_shape = (2, 3)
    memory_flange = (1, 1)

    x = np.random.rand(batch_size, num_heads, height, width, depth)
    y = np.reshape(x, (batch_size, num_heads, -1, depth))
    zeros = np.zeros((1, depth), dtype=np.float32)
    two_zeros = np.zeros((2, depth), dtype=np.float32)
    five_zeros = np.zeros((5, depth), dtype=np.float32)
    seven_zeros = np.zeros((7, depth), dtype=np.float32)

    def expected_flanges(b, h):
      """Expected flange rows of the four query blocks of y[b, h]."""
      yb = y[b, h]
      return [seven_zeros,
              np.concatenate((five_zeros, yb[[2, 8]]), axis=0),
              np.concatenate((zeros, yb[[6, 7, 8, 9]], two_zeros), axis=0),
              np.concatenate((yb[[8, 9, 10, 11]], zeros, yb[[14, 20]]),
                             axis=0)]

    # Flattened positions of each query block, i.e. its center.
    block_indices = [[0, 1, 2, 6, 7, 8],
                     [3, 4, 5, 9, 10, 11],
                     [12, 13, 14, 18, 19, 20],
                     [15, 16, 17, 21, 22, 23]]
    correct_x_flange = np.array(
        [[expected_flanges(b, h) for h in range(num_heads)]
         for b in range(batch_size)])
    correct_x_center = np.array(
        [[[y[b, h, idx] for idx in block_indices]
          for h in range(num_heads)]
         for b in range(batch_size)])
    with self.test_session() as session:
      x_indices = common_attention.gather_indices_2d(
          x, query_shape, query_shape)
      x_flange, x_center = common_attention.get_memory_region(
          tf.constant(x, dtype=tf.float32),
          query_shape,
          memory_flange,
          x_indices)
      session.run(tf.global_variables_initializer())
      [x_flange, x_center] = session.run([x_flange, x_center])
    self.assertAllClose(correct_x_flange, x_flange)
    self.assertAllClose(correct_x_center, x_center)