示例#1
0
 def test2dGatherAndScatterInvertibility(self):
   """Checks that scatter_blocks_2d undoes gather_blocks_2d.

   A random 5-D array is cut into (2, 3) query blocks, gathered, then
   scattered back into its original shape; the round trip must be lossless.
   """
   # Dimensions in the order (batch, heads, height, width, depth).
   full_shape = (2, 2, 4, 6, 8)
   block_shape = (2, 3)
   original = np.random.rand(*full_shape)

   block_indices = common_attention.gather_indices_2d(
       original, block_shape, block_shape)
   blocks = common_attention.gather_blocks_2d(original, block_indices)
   restored = common_attention.scatter_blocks_2d(
       blocks, block_indices, tf.constant(list(full_shape)))

   self.assertAllClose(original, self.evaluate(restored))
 def test2dGatherAndScatterInvertibility(self):
   """2d gather and scatter invertibility test."""
   batch_size = 2
   num_heads = 2
   height = 4
   width = 6
   depth = 8
   query_shape = (2, 3)
   x = np.random.rand(batch_size, num_heads, height, width, depth)
   with self.test_session() as session:
     x_indices = common_attention.gather_indices_2d(
         x, query_shape, query_shape)
     gathered_x = common_attention.gather_blocks_2d(x, x_indices)
     x_shape = tf.constant([batch_size, num_heads, height, width, depth])
     scattered_x = common_attention.scatter_blocks_2d(
         gathered_x, x_indices, x_shape)
     session.run(tf.global_variables_initializer())
     res = session.run(scattered_x)
   self.assertAllClose(x, res)
 def test2dGatherAndScatterInvertibility(self):
   """2d gather and scatter invertibility test."""
   batch_size = 2
   num_heads = 2
   height = 4
   width = 6
   depth = 8
   query_shape = (2, 3)
   x = np.random.rand(batch_size, num_heads, height, width, depth)
   with self.test_session() as session:
     x_indices = common_attention.gather_indices_2d(
         x, query_shape, query_shape)
     gathered_x = common_attention.gather_blocks_2d(x, x_indices)
     x_shape = tf.constant([batch_size, num_heads, height, width, depth])
     scattered_x = common_attention.scatter_blocks_2d(
         gathered_x, x_indices, x_shape)
     session.run(tf.global_variables_initializer())
     res = session.run(scattered_x)
   self.assertAllClose(x, res)
    def test2dGather(self):
        """Testing 2d index gather and block gather functions."""
        batch_size = 2
        num_heads = 2
        height = 4
        width = 6
        depth = 8
        query_shape = (2, 3)
        x = np.random.rand(batch_size, num_heads, height, width, depth)
        y = np.reshape(x, (batch_size, num_heads, -1, depth))
        correct_indices = [[0, 1, 2, 6, 7, 8], [3, 4, 5, 9, 10, 11],
                           [12, 13, 14, 18, 19, 20], [15, 16, 17, 21, 22, 23]]
        correct_gathered_x = [
            [[
                y[0, 0, correct_indices[0]], y[0, 0, correct_indices[1]],
                y[0, 0, correct_indices[2]], y[0, 0, correct_indices[3]]
            ],
             [
                 y[0, 1, correct_indices[0]], y[0, 1, correct_indices[1]],
                 y[0, 1, correct_indices[2]], y[0, 1, correct_indices[3]]
             ]],
            [[
                y[1, 0, correct_indices[0]], y[1, 0, correct_indices[1]],
                y[1, 0, correct_indices[2]], y[1, 0, correct_indices[3]]
            ],
             [
                 y[1, 1, correct_indices[0]], y[1, 1, correct_indices[1]],
                 y[1, 1, correct_indices[2]], y[1, 1, correct_indices[3]]
             ]]
        ]

        with self.test_session() as session:
            x_indices = common_attention.gather_indices_2d(
                x, query_shape, query_shape)
            gathered_x = common_attention.gather_blocks_2d(x, x_indices)
            x_indices, gathered_x = session.run([x_indices, gathered_x])
        self.assertAllEqual(correct_indices, x_indices)
        self.assertAllClose(correct_gathered_x, gathered_x)
  def test2dGather(self):
    """Testing 2d index gather and block gather functions.

    With height=4, width=6 and query_shape=(2, 3), the flattened 4x6 grid
    splits into four blocks; correct_indices lists each block's flattened
    positions, blocks taken in row-major order.
    """
    batch_size = 2
    num_heads = 2
    height = 4
    width = 6
    depth = 8
    query_shape = (2, 3)
    x = np.random.rand(batch_size, num_heads, height, width, depth)
    # Flatten the spatial dims so blocks can be picked by flat index.
    y = np.reshape(x, (batch_size, num_heads, -1, depth))
    correct_indices = [[0, 1, 2, 6, 7, 8],
                       [3, 4, 5, 9, 10, 11],
                       [12, 13, 14, 18, 19, 20],
                       [15, 16, 17, 21, 22, 23]]
    # Expected blocks for every (batch, head) pair; built from the loop bounds
    # instead of a hand-unrolled literal tied to 2x2.
    correct_gathered_x = [
        [[y[b, h, block] for block in correct_indices]
         for h in range(num_heads)]
        for b in range(batch_size)
    ]

    x_indices = common_attention.gather_indices_2d(
        x, query_shape, query_shape)
    gathered_x = common_attention.gather_blocks_2d(x, x_indices)
    # self.evaluate replaces the deprecated self.test_session() and works in
    # both graph and eager mode.
    x_indices, gathered_x = self.evaluate([x_indices, gathered_x])
    self.assertAllEqual(correct_indices, x_indices)
    self.assertAllClose(correct_gathered_x, gathered_x)