def test2dGatherAndScatterInvertibility(self):
  """Scattering the gathered 2d blocks back must reproduce the input.

  Gathers (2, 3) query blocks from a random 5-d input with
  `common_attention.gather_blocks_2d`, scatters them back into a tensor of
  the input's shape with `common_attention.scatter_blocks_2d`, and checks
  that the round trip returns the original values.
  """
  full_shape = [2, 2, 4, 6, 8]  # batch, heads, height, width, depth
  block_shape = (2, 3)
  inputs = np.random.rand(*full_shape)

  block_indices = common_attention.gather_indices_2d(
      inputs, block_shape, block_shape)
  blocks = common_attention.gather_blocks_2d(inputs, block_indices)
  round_trip = common_attention.scatter_blocks_2d(
      blocks, block_indices, tf.constant(full_shape))

  self.assertAllClose(inputs, self.evaluate(round_trip))
def test2dGatherAndScatterInvertibility(self):
  """2d gather and scatter invertibility test.

  Gathers (2, 3) query blocks from a random input, scatters them back into a
  tensor of the input's shape, and checks the round trip reproduces the
  input.
  """
  batch_size = 2
  num_heads = 2
  height = 4
  width = 6
  depth = 8
  query_shape = (2, 3)
  x = np.random.rand(batch_size, num_heads, height, width, depth)
  x_indices = common_attention.gather_indices_2d(x, query_shape, query_shape)
  gathered_x = common_attention.gather_blocks_2d(x, x_indices)
  x_shape = tf.constant([batch_size, num_heads, height, width, depth])
  scattered_x = common_attention.scatter_blocks_2d(
      gathered_x, x_indices, x_shape)
  # NOTE(review): no variable initializer is run; this assumes the
  # gather/scatter helpers create no tf.Variables.
  res = self.evaluate(scattered_x)
  self.assertAllClose(x, res)
def test2dGatherAndScatterInvertibility(self):
  """2d gather and scatter invertibility test.

  Gathers (2, 3) query blocks from a random input, scatters them back into a
  tensor of the input's shape, and checks the round trip reproduces the
  input.
  """
  batch_size = 2
  num_heads = 2
  height = 4
  width = 6
  depth = 8
  query_shape = (2, 3)
  x = np.random.rand(batch_size, num_heads, height, width, depth)
  x_indices = common_attention.gather_indices_2d(x, query_shape, query_shape)
  gathered_x = common_attention.gather_blocks_2d(x, x_indices)
  x_shape = tf.constant([batch_size, num_heads, height, width, depth])
  scattered_x = common_attention.scatter_blocks_2d(
      gathered_x, x_indices, x_shape)
  # NOTE(review): no variable initializer is run; this assumes the
  # gather/scatter helpers create no tf.Variables.
  res = self.evaluate(scattered_x)
  self.assertAllClose(x, res)
def test2dGather(self):
  """Testing 2d index gather and block gather functions.

  A (4, 6) spatial grid split into (2, 3) query blocks yields the four blocks
  whose row-major flat positions are listed in `correct_indices`. The
  gathered blocks must equal the flattened input indexed at those positions,
  for every batch element and head.
  """
  batch_size = 2
  num_heads = 2
  height = 4
  width = 6
  depth = 8
  query_shape = (2, 3)
  x = np.random.rand(batch_size, num_heads, height, width, depth)
  # Flatten the spatial dims (C order), so flat position == row * width + col.
  y = np.reshape(x, (batch_size, num_heads, -1, depth))
  correct_indices = [[0, 1, 2, 6, 7, 8],
                     [3, 4, 5, 9, 10, 11],
                     [12, 13, 14, 18, 19, 20],
                     [15, 16, 17, 21, 22, 23]]
  # Advanced indexing selects every block for every (batch, head) at once:
  # correct_gathered_x[b, h, k] == y[b, h, correct_indices[k]], i.e. shape
  # (batch_size, num_heads, num_blocks, block_size, depth).
  correct_gathered_x = y[:, :, np.array(correct_indices)]

  x_indices = common_attention.gather_indices_2d(x, query_shape, query_shape)
  gathered_x = common_attention.gather_blocks_2d(x, x_indices)
  x_indices, gathered_x = self.evaluate([x_indices, gathered_x])
  self.assertAllEqual(correct_indices, x_indices)
  self.assertAllClose(correct_gathered_x, gathered_x)
def test2dGather(self):
  """Testing 2d index gather and block gather functions.

  A (4, 6) spatial grid split into (2, 3) query blocks yields the four blocks
  whose row-major flat positions are listed in `correct_indices`. The
  gathered blocks must equal the flattened input indexed at those positions,
  for every batch element and head.
  """
  batch_size = 2
  num_heads = 2
  height = 4
  width = 6
  depth = 8
  query_shape = (2, 3)
  x = np.random.rand(batch_size, num_heads, height, width, depth)
  # Flatten the spatial dims (C order), so flat position == row * width + col.
  y = np.reshape(x, (batch_size, num_heads, -1, depth))
  correct_indices = [[0, 1, 2, 6, 7, 8],
                     [3, 4, 5, 9, 10, 11],
                     [12, 13, 14, 18, 19, 20],
                     [15, 16, 17, 21, 22, 23]]
  # Advanced indexing selects every block for every (batch, head) at once:
  # correct_gathered_x[b, h, k] == y[b, h, correct_indices[k]], i.e. shape
  # (batch_size, num_heads, num_blocks, block_size, depth).
  correct_gathered_x = y[:, :, np.array(correct_indices)]

  x_indices = common_attention.gather_indices_2d(x, query_shape, query_shape)
  gathered_x = common_attention.gather_blocks_2d(x, x_indices)
  x_indices, gathered_x = self.evaluate([x_indices, gathered_x])
  self.assertAllEqual(correct_indices, x_indices)
  self.assertAllClose(correct_gathered_x, gathered_x)
def testGetShiftedCenterBlocks(self):
  """Testing the function that shifts each gathered query block by one.

  For every (2, 3) query block, `get_shifted_center_blocks` is expected to
  return the block's flattened elements shifted right by one position: a
  zero row first, followed by all but the block's last element.
  """
  batch_size = 2
  num_heads = 2
  height = 4
  width = 6
  depth = 3
  query_shape = (2, 3)
  x = np.random.rand(batch_size, num_heads, height, width, depth)
  y = np.reshape(x, (batch_size, num_heads, -1, depth))
  # Row-major flat positions covered by each (2, 3) block of the (4, 6) grid.
  block_indices = np.array([[0, 1, 2, 6, 7, 8],
                            [3, 4, 5, 9, 10, 11],
                            [12, 13, 14, 18, 19, 20],
                            [15, 16, 17, 21, 22, 23]])
  # Drop each block's last element, then prepend one zero row along the
  # within-block axis (axis 3).
  correct_gathered_x = np.pad(
      y[:, :, block_indices[:, :-1]],
      [(0, 0), (0, 0), (0, 0), (1, 0), (0, 0)],
      mode="constant")

  x_indices = common_attention.gather_indices_2d(x, query_shape, query_shape)
  gathered_x = common_attention.get_shifted_center_blocks(
      tf.constant(x, dtype=tf.float32), x_indices)
  # NOTE(review): no variable initializer is run; this assumes the helper
  # creates no tf.Variables.
  self.assertAllClose(correct_gathered_x, self.evaluate(gathered_x))
def testGetMemoryRegion(self):
  """Testing the function that gathers the flanged memory region.

  Checks both outputs of `get_memory_region` for (2, 3) query blocks with a
  (1, 1) memory flange: the flange region (7 rows per block, with zero rows
  where the expected values are padding) and the center region (the block
  itself).
  """
  batch_size = 2
  num_heads = 2
  height = 4
  width = 6
  depth = 3
  query_shape = (2, 3)
  memory_flange = (1, 1)
  x = np.random.rand(batch_size, num_heads, height, width, depth)
  y = np.reshape(x, (batch_size, num_heads, -1, depth))
  # Append one all-zero row to the flattened spatial axis. Index `z` selects
  # it, so zero padding can be expressed as an ordinary gather index.
  z = height * width
  y_padded = np.concatenate(
      [y, np.zeros((batch_size, num_heads, 1, depth))], axis=2)
  flange_indices = np.array([[z, z, z, z, z, z, z],
                             [z, z, z, z, z, 2, 8],
                             [z, 6, 7, 8, 9, z, z],
                             [8, 9, 10, 11, z, 14, 20]])
  center_indices = np.array([[0, 1, 2, 6, 7, 8],
                             [3, 4, 5, 9, 10, 11],
                             [12, 13, 14, 18, 19, 20],
                             [15, 16, 17, 21, 22, 23]])
  # Advanced indexing applies the same per-block index table to every
  # (batch, head) pair.
  correct_x_flange = y_padded[:, :, flange_indices]
  correct_x_center = y[:, :, center_indices]

  x_indices = common_attention.gather_indices_2d(x, query_shape, query_shape)
  x_flange, x_center = common_attention.get_memory_region(
      tf.constant(x, dtype=tf.float32), query_shape, memory_flange, x_indices)
  # NOTE(review): no variable initializer is run; this assumes the helper
  # creates no tf.Variables.
  x_flange, x_center = self.evaluate([x_flange, x_center])
  self.assertAllClose(correct_x_flange, x_flange)
  self.assertAllClose(correct_x_center, x_center)
def testGetShiftedCenterBlocks(self):
  """Testing the function that shifts each gathered query block by one.

  For every (2, 3) query block, `get_shifted_center_blocks` is expected to
  return the block's flattened elements shifted right by one position: a
  zero row first, followed by all but the block's last element.
  """
  batch_size = 2
  num_heads = 2
  height = 4
  width = 6
  depth = 3
  query_shape = (2, 3)
  x = np.random.rand(batch_size, num_heads, height, width, depth)
  y = np.reshape(x, (batch_size, num_heads, -1, depth))
  # Row-major flat positions covered by each (2, 3) block of the (4, 6) grid.
  block_indices = np.array([[0, 1, 2, 6, 7, 8],
                            [3, 4, 5, 9, 10, 11],
                            [12, 13, 14, 18, 19, 20],
                            [15, 16, 17, 21, 22, 23]])
  # Drop each block's last element, then prepend one zero row along the
  # within-block axis (axis 3).
  correct_gathered_x = np.pad(
      y[:, :, block_indices[:, :-1]],
      [(0, 0), (0, 0), (0, 0), (1, 0), (0, 0)],
      mode="constant")

  x_indices = common_attention.gather_indices_2d(x, query_shape, query_shape)
  gathered_x = common_attention.get_shifted_center_blocks(
      tf.constant(x, dtype=tf.float32), x_indices)
  # NOTE(review): no variable initializer is run; this assumes the helper
  # creates no tf.Variables.
  self.assertAllClose(correct_gathered_x, self.evaluate(gathered_x))
def testGetMemoryRegion(self):
  """Testing the function that gathers the flanged memory region.

  Checks both outputs of `get_memory_region` for (2, 3) query blocks with a
  (1, 1) memory flange: the flange region (7 rows per block, with zero rows
  where the expected values are padding) and the center region (the block
  itself).
  """
  batch_size = 2
  num_heads = 2
  height = 4
  width = 6
  depth = 3
  query_shape = (2, 3)
  memory_flange = (1, 1)
  x = np.random.rand(batch_size, num_heads, height, width, depth)
  y = np.reshape(x, (batch_size, num_heads, -1, depth))
  # Append one all-zero row to the flattened spatial axis. Index `z` selects
  # it, so zero padding can be expressed as an ordinary gather index.
  z = height * width
  y_padded = np.concatenate(
      [y, np.zeros((batch_size, num_heads, 1, depth))], axis=2)
  flange_indices = np.array([[z, z, z, z, z, z, z],
                             [z, z, z, z, z, 2, 8],
                             [z, 6, 7, 8, 9, z, z],
                             [8, 9, 10, 11, z, 14, 20]])
  center_indices = np.array([[0, 1, 2, 6, 7, 8],
                             [3, 4, 5, 9, 10, 11],
                             [12, 13, 14, 18, 19, 20],
                             [15, 16, 17, 21, 22, 23]])
  # Advanced indexing applies the same per-block index table to every
  # (batch, head) pair.
  correct_x_flange = y_padded[:, :, flange_indices]
  correct_x_center = y[:, :, center_indices]

  x_indices = common_attention.gather_indices_2d(x, query_shape, query_shape)
  x_flange, x_center = common_attention.get_memory_region(
      tf.constant(x, dtype=tf.float32), query_shape, memory_flange, x_indices)
  # NOTE(review): no variable initializer is run; this assumes the helper
  # creates no tf.Variables.
  x_flange, x_center = self.evaluate([x_flange, x_center])
  self.assertAllClose(correct_x_flange, x_flange)
  self.assertAllClose(correct_x_center, x_center)