def testGetMemoryRegion(self):
  """Check get_memory_region against hand-built flange and center blocks.

  A random input of shape (batch, heads, height, width, depth) =
  (2, 2, 4, 6, 3) is flattened over its height/width grid so that expected
  rows can be selected by flat position index. The expected flange and
  center tensors are assembled per (batch, head) pair and compared with
  the values produced by common_attention.get_memory_region.
  """
  np.set_printoptions(threshold=np.inf)
  batch_size, num_heads = 2, 2
  height, width, depth = 4, 6, 3
  query_shape = (2, 3)
  memory_flange = (1, 1)

  x = np.random.rand(batch_size, num_heads, height, width, depth)
  # Collapse the height x width grid into a single axis of flat positions.
  y = np.reshape(x, (batch_size, num_heads, -1, depth))

  def zero_rows(count):
    # `count` all-zero float32 rows, each of length `depth`.
    return np.array([np.zeros((depth), dtype=np.float32)] * count)

  def picked(b, h, positions):
    # Rows of the flattened input at the given flat positions.
    return y[b, h, positions]

  def expected_flange(b, h):
    # Four expected flange entries for one (batch, head) pair; each entry
    # has seven rows, mixing zero rows with rows taken from the input.
    return [
        zero_rows(7),
        np.concatenate((zero_rows(5), picked(b, h, [2, 8])), axis=0),
        np.concatenate(
            (zero_rows(1), picked(b, h, [6, 7, 8, 9]), zero_rows(2)), axis=0),
        np.concatenate(
            (picked(b, h, [8, 9, 10, 11]), zero_rows(1),
             picked(b, h, [14, 20])),
            axis=0),
    ]

  # Flat positions making up each of the four expected center entries.
  center_positions = (
      [0, 1, 2, 6, 7, 8],
      [3, 4, 5, 9, 10, 11],
      [12, 13, 14, 18, 19, 20],
      [15, 16, 17, 21, 22, 23],
  )

  correct_x_flange = np.array([
      [expected_flange(b, h) for h in range(num_heads)]
      for b in range(batch_size)
  ])
  correct_x_center = np.array([
      [[picked(b, h, positions) for positions in center_positions]
       for h in range(num_heads)]
      for b in range(batch_size)
  ])

  with self.test_session() as session:
    x_indices = common_attention.gather_indices_2d(
        x, query_shape, query_shape)
    flange_op, center_op = common_attention.get_memory_region(
        tf.constant(x, dtype=tf.float32),
        query_shape,
        memory_flange,
        x_indices)
    session.run(tf.global_variables_initializer())
    x_flange, x_center = session.run([flange_op, center_op])
  self.assertAllClose(correct_x_flange, x_flange)
  self.assertAllClose(correct_x_center, x_center)
def testGetMemoryRegion(self):
  """Verify the flange and center tensors returned by get_memory_region.

  Builds a random (2, 2, 4, 6, 3) input, derives the expected flange and
  center arrays from its flattened rows, and asserts that
  common_attention.get_memory_region returns matching values.
  """
  np.set_printoptions(threshold=np.inf)
  batch_size = 2
  num_heads = 2
  height = 4
  width = 6
  depth = 3
  query_shape = (2, 3)
  memory_flange = (1, 1)

  x = np.random.rand(batch_size, num_heads, height, width, depth)
  # The 4 x 6 grid becomes 24 flat positions, indexed below.
  y = np.reshape(x, (batch_size, num_heads, -1, depth))

  # Zero padding blocks of 1, 2, 5 and 7 rows, built from one zero row.
  zero_row = np.zeros((depth), dtype=np.float32)
  pad_one = np.array([zero_row])
  pad_two = np.array([zero_row] * 2)
  pad_five = np.array([zero_row] * 5)
  pad_seven = np.array([zero_row] * 7)

  # Flat positions for each of the four expected center entries.
  center_index_table = [
      [0, 1, 2, 6, 7, 8],
      [3, 4, 5, 9, 10, 11],
      [12, 13, 14, 18, 19, 20],
      [15, 16, 17, 21, 22, 23],
  ]

  flange_by_batch = []
  center_by_batch = []
  for b in range(batch_size):
    flange_by_head = []
    center_by_head = []
    for h in range(num_heads):
      flat = y[b, h]
      # Four seven-row entries: zero padding combined with input rows.
      flange_by_head.append([
          pad_seven,
          np.concatenate((pad_five, flat[[2, 8]]), axis=0),
          np.concatenate((pad_one, flat[[6, 7, 8, 9]], pad_two), axis=0),
          np.concatenate(
              (flat[[8, 9, 10, 11]], pad_one, flat[[14, 20]]), axis=0),
      ])
      center_by_head.append([flat[idx] for idx in center_index_table])
    flange_by_batch.append(flange_by_head)
    center_by_batch.append(center_by_head)

  correct_x_flange = np.array(flange_by_batch)
  correct_x_center = np.array(center_by_batch)

  with self.test_session() as session:
    x_indices = common_attention.gather_indices_2d(
        x, query_shape, query_shape)
    region = common_attention.get_memory_region(
        tf.constant(x, dtype=tf.float32),
        query_shape,
        memory_flange,
        x_indices)
    x_flange, x_center = region
    session.run(tf.global_variables_initializer())
    [x_flange, x_center] = session.run([x_flange, x_center])
  self.assertAllClose(correct_x_flange, x_flange)
  self.assertAllClose(correct_x_center, x_center)