class SegmentationPrefetcher:
    """
    SegmentationPrefetcher will prefetch a bunch of segmentation
    images using a multiprocessing pool, so you do not have to wait
    around while the files get opened and decoded. Just request
    batches of images and segmentations calling fetch_batch().
    """

    def __init__(
        self,
        segmentation,
        split=None,
        randomize=False,
        segmentation_shape=None,
        categories=None,
        once=False,
        start=None,
        end=None,
        batch_size=4,
        ahead=4,
        thread=False,
    ):
        """
        Constructor arguments:
        segmentation: The AbstractSegmentation to load.
        split: None for no filtering, or 'train' or 'val' etc.
        randomize: True to randomly shuffle order, or a random seed.
        segmentation_shape: target shape passed through to workers.
        categories: a list of categories to include in each batch.
        once: True to stop after one full pass over the data.
        start, end: half-open index range to serve (defaults to all).
        batch_size: number of data items for each batch.
        ahead: the number of batches to prefetch ahead.
        thread: True to use a thread pool instead of processes.
        """
        self.segmentation = segmentation
        self.split = split
        self.randomize = randomize
        self.random = random.Random()
        if randomize is not True:
            # Any non-True randomize value doubles as an explicit seed
            # so shuffled order is reproducible.
            self.random.seed(randomize)
        self.categories = categories
        self.once = once
        self.batch_size = batch_size
        self.ahead = ahead
        # Initialize the multiprocessing pool
        n_procs = cpu_count()
        if thread:
            self.pool = ThreadPool(processes=n_procs)
        else:
            # Workers are set up to ignore SIGINT so Ctrl-C is handled
            # by the parent process (setup_sigint/restore_sigint are
            # defined elsewhere in this module).
            original_sigint_handler = setup_sigint()
            self.pool = Pool(processes=n_procs, initializer=setup_sigint)
            restore_sigint(original_sigint_handler)
        # Prefilter the image indexes of interest
        if start is None:
            start = 0
        if end is None:
            end = segmentation.size()
        # BUG FIX: materialize the range into a list.  random.shuffle
        # needs a mutable sequence and raises TypeError on a range
        # object (both here and on the reshuffle in next_job).
        self.indexes = list(range(start, end))
        if split:
            self.indexes = [
                i for i in self.indexes if segmentation.split(i) == split
            ]
        if self.randomize:
            self.random.shuffle(self.indexes)
        self.index = 0
        self.result_queue = []
        self.segmentation_shape = segmentation_shape
        # Get dense catmaps, one per requested category; the "image"
        # pseudo-category carries no catmap.  categories=None (the
        # default) is treated as an empty list so construction does
        # not crash before any fetch is attempted.
        self.catmaps = [
            segmentation.category_index_map(cat) if cat != "image" else None
            for cat in (categories or [])
        ]

    def next_job(self):
        """Return the next worker job tuple and advance the cursor.

        A job packages everything a worker needs to load one record:
        (index, segmentation class, metadata, filename, categories,
        segmentation_shape).  Returns None once the data is exhausted
        and once=True; otherwise wraps around (reshuffling when
        randomize is set).
        """
        if self.index < 0:
            # Exhausted in once-mode.
            return None
        j = self.indexes[self.index]
        result = (
            j,
            self.segmentation.__class__,
            self.segmentation.metadata(j),
            self.segmentation.filename(j),
            self.categories,
            self.segmentation_shape,
        )
        self.index += 1
        if self.index >= len(self.indexes):
            if self.once:
                self.index = -1
            else:
                self.index = 0
                if self.randomize:
                    # Reshuffle every time through
                    self.random.shuffle(self.indexes)
        return result

    def batches(self):
        """Iterator for all batches"""
        while True:
            batch = self.fetch_batch()
            if batch is None:
                return
            yield batch

    def fetch_batch(self):
        """Returns a single batch as an array of dictionaries."""
        try:
            self.refill_tasks()
            if len(self.result_queue) == 0:
                return None
            result = self.result_queue.pop(0)
            # The one-year timeout keeps get() interruptible: a
            # no-timeout get() would block KeyboardInterrupt delivery
            # on some Python versions.
            return result.get(31536000)
        except KeyboardInterrupt:
            print("Caught KeyboardInterrupt, terminating workers")
            self.pool.terminate()
            raise

    def fetch_tensor_batch(self, bgr_mean=None, global_labels=False):
        """Returns a single batch as an array of tensors, one per category."""
        # (Docstring fixed: this returns one batch; tensor_batches is
        # the iterator.  The originals were swapped.)
        batch = self.fetch_batch()
        return self.form_caffe_tensors(batch, bgr_mean, global_labels)

    def tensor_batches(self, bgr_mean=None, global_labels=False):
        """Iterator for batches as arrays of tensors."""
        while True:
            batch = self.fetch_tensor_batch(bgr_mean=bgr_mean,
                                            global_labels=global_labels)
            if batch is None:
                return
            yield batch

    def form_caffe_tensors(self, batch, bgr_mean=None, global_labels=False):
        """Assemble a batch in [{'cat': data,..},..] format into an
        array of batch tensors: the first for the image, then one for
        each category in self.categories in order, plus a trailing
        'scene' entry.  Returns None for a None batch (end of data).
        """
        # NOTE(review): when global_labels is False, the trailing
        # 'scene' slot indexes self.catmaps[c] one past its end --
        # presumably this path is only used with global_labels=True;
        # confirm against callers before relying on it.
        if batch is None:
            return None
        cats = [*self.categories, "scene"]
        batches = [[] for c in cats]
        for record in batch:
            # Per-record fallback label shape: 1 x height x width.
            default_shape = (1, record["sh"], record["sw"])
            for c, cat in enumerate(cats):
                if cat == "image":
                    # Normalize image with right RGB order and mean
                    batches[c].append(normalize_image(record[cat], bgr_mean))
                elif global_labels:
                    if cat == "scene":
                        if not record[cat]:
                            # No scene annotation: -1 sentinel label.
                            batches[c].append(np.array([-1]))
                        elif len(record[cat]) > 1:
                            # Multiple scene labels: warn and keep the
                            # first one.
                            print(
                                f"Multiple scenes: {record['fn']} {record[cat]}"
                            )
                            batches[c].append(np.array(record[cat][0]))
                        else:
                            batches[c].append(np.array(record[cat]))
                    else:
                        batches[c].append(
                            normalize_label(record[cat], default_shape,
                                            flatten=True))
                else:
                    # Map per-category label indexes into the dense
                    # global index space for this category.
                    catmap = self.catmaps[c]
                    batches[c].append(catmap[normalize_label(record[cat],
                                                             default_shape,
                                                             flatten=True)])
        # Stack each per-category list into a single batch tensor.
        return [
            numpy.concatenate(tuple(m[numpy.newaxis] for m in b))
            for b in batches
        ]

    def refill_tasks(self):
        """Keep up to self.ahead async batch jobs queued on the pool."""
        # Ask the sequencer (next_job) for batch_size jobs at a time
        # and submit each group with map_async.
        while len(self.result_queue) < self.ahead:
            data = []
            while len(data) < self.batch_size:
                job = self.next_job()
                if job is None:
                    break
                data.append(job)
            if len(data) == 0:
                return
            self.result_queue.append(self.pool.map_async(
                prefetch_worker, data))

    def close(self):
        """Drain pending results and shut down the worker pool."""
        while len(self.result_queue):
            result = self.result_queue.pop(0)
            if result is not None:
                # Best-effort: give each pending result a moment, but
                # never hang on unfinished prefetch work.
                result.wait(0.001)
        self.pool.close()
        # BUG FIX: multiprocessing.Pool has no cancel_join_thread()
        # (that is a multiprocessing.Queue method), so the original
        # line always raised AttributeError.  terminate()+join()
        # matches the intent of shutting down without waiting for
        # outstanding tasks.
        self.pool.terminate()
        self.pool.join()