def test2dGatherAndScatterInvertibility(self):
  """Checks that scatter_blocks_2d undoes gather_blocks_2d.

  A random 5-d input of shape (batch_size, num_heads, height, width, depth)
  is split into blocks via gather_indices_2d / gather_blocks_2d (query_shape
  passed as both block arguments), scattered back into the original shape
  with scatter_blocks_2d, and compared against the input.

  NOTE(review): another method with this exact name is defined right after
  this one. If both are in the same class, Python keeps only the later
  definition, so this body never runs -- rename or remove one of them.
  """
  # batch_size, num_heads, height, width, depth
  dims = [2, 2, 4, 6, 8]
  block_shape = (2, 3)
  original = np.random.rand(*dims)

  indices = common_attention.gather_indices_2d(
      original, block_shape, block_shape)
  blocks = common_attention.gather_blocks_2d(original, indices)
  restored = common_attention.scatter_blocks_2d(
      blocks, indices, tf.constant(dims))

  self.assertAllClose(original, self.evaluate(restored))
def test2dGatherAndScatterInvertibility(self):
  """2d gather and scatter invertibility test.

  Gathers blocks of a random (batch_size, num_heads, height, width, depth)
  input with query_shape passed as both block arguments of
  gather_indices_2d, scatters them back with scatter_blocks_2d into the
  original shape, and asserts the round trip reproduces the input.

  NOTE(review): a method with this exact name is also defined just before
  this one. If both are in the same class, this later definition is the one
  Python keeps; consider renaming or deleting the other.
  """
  batch_size = 2
  num_heads = 2
  height = 4
  width = 6
  depth = 8
  query_shape = (2, 3)
  x = np.random.rand(batch_size, num_heads, height, width, depth)
  x_indices = common_attention.gather_indices_2d(
      x, query_shape, query_shape)
  gathered_x = common_attention.gather_blocks_2d(x, x_indices)
  x_shape = tf.constant([batch_size, num_heads, height, width, depth])
  scattered_x = common_attention.scatter_blocks_2d(
      gathered_x, x_indices, x_shape)
  # self.evaluate runs in both graph and eager mode and replaces the
  # deprecated self.test_session(). No variable initializer is run: the
  # ops built above are evaluated directly.
  res = self.evaluate(scattered_x)
  self.assertAllClose(x, res)
def test2dGather(self):
  """Testing 2d index gather and block gather functions.

  With query_shape (2, 3) passed as both block arguments of
  gather_indices_2d, the 4x6 spatial grid is expected to split into four
  2x3 blocks. The test checks the flat index list of every block and that
  gather_blocks_2d returns the corresponding rows of the flattened input.

  NOTE(review): a method with this exact name is defined again right after
  this one. If both are in the same class, Python keeps only the later
  definition, so this body never runs -- rename or remove one of them.
  """
  batch_size = 2
  num_heads = 2
  height = 4
  width = 6
  depth = 8
  query_shape = (2, 3)
  x = np.random.rand(batch_size, num_heads, height, width, depth)
  # Merge height and width (row-major), so spatial (row, col) lands at flat
  # index row * width + col.
  y = np.reshape(x, (batch_size, num_heads, -1, depth))
  # Flat indices of each 2x3 block of the 4x6 grid, blocks in raster order.
  correct_indices = [[0, 1, 2, 6, 7, 8],
                     [3, 4, 5, 9, 10, 11],
                     [12, 13, 14, 18, 19, 20],
                     [15, 16, 17, 21, 22, 23]]
  # Integer-array indexing on the flattened axis picks, for every
  # (batch, head), the rows of each block: shape
  # (batch_size, num_heads, num_blocks, block_length, depth). This equals the
  # hand-enumerated nested list it replaces, for any batch_size / num_heads.
  correct_gathered_x = y[:, :, np.array(correct_indices)]

  x_indices = common_attention.gather_indices_2d(
      x, query_shape, query_shape)
  gathered_x = common_attention.gather_blocks_2d(x, x_indices)
  # self.evaluate replaces the deprecated self.test_session() and works in
  # both graph and eager mode.
  x_indices, gathered_x = self.evaluate([x_indices, gathered_x])
  self.assertAllEqual(correct_indices, x_indices)
  self.assertAllClose(correct_gathered_x, gathered_x)
def test2dGather(self):
  """Testing 2d index gather and block gather functions.

  With query_shape (2, 3) passed as both block arguments of
  gather_indices_2d, the 4x6 spatial grid is expected to split into four
  2x3 blocks. The test checks the flat index list of every block and that
  gather_blocks_2d returns the corresponding rows of the flattened input.

  NOTE(review): a method with this exact name is also defined just before
  this one. If both are in the same class, this later definition is the one
  Python keeps; consider renaming or deleting the other.
  """
  batch_size = 2
  num_heads = 2
  height = 4
  width = 6
  depth = 8
  query_shape = (2, 3)
  x = np.random.rand(batch_size, num_heads, height, width, depth)
  # Merge height and width (row-major), so spatial (row, col) lands at flat
  # index row * width + col.
  y = np.reshape(x, (batch_size, num_heads, -1, depth))
  # Flat indices of each 2x3 block of the 4x6 grid, blocks in raster order.
  correct_indices = [[0, 1, 2, 6, 7, 8],
                     [3, 4, 5, 9, 10, 11],
                     [12, 13, 14, 18, 19, 20],
                     [15, 16, 17, 21, 22, 23]]
  # Integer-array indexing on the flattened axis picks, for every
  # (batch, head), the rows of each block: shape
  # (batch_size, num_heads, num_blocks, block_length, depth). This equals the
  # hand-enumerated nested list it replaces, for any batch_size / num_heads.
  correct_gathered_x = y[:, :, np.array(correct_indices)]

  x_indices = common_attention.gather_indices_2d(
      x, query_shape, query_shape)
  gathered_x = common_attention.gather_blocks_2d(x, x_indices)
  # self.evaluate replaces the deprecated self.test_session() and works in
  # both graph and eager mode.
  x_indices, gathered_x = self.evaluate([x_indices, gathered_x])
  self.assertAllEqual(correct_indices, x_indices)
  self.assertAllClose(correct_gathered_x, gathered_x)