Example #1
import random
from multiprocessing import Pool, cpu_count
from multiprocessing.pool import ThreadPool

import numpy

# Helpers such as setup_sigint, restore_sigint, prefetch_worker,
# normalize_image, and normalize_label are assumed to be defined
# elsewhere in this module.

class SegmentationPrefetcher:
    """
    SegmentationPrefetcher will prefetch a bunch of segmentation
    images using a multiprocessing pool, so you do not have to wait
    around while the files get opened and decoded.  Just request
    batches of images and segmentations calling fetch_batch().
    """
    def __init__(
        self,
        segmentation,
        split=None,
        randomize=False,
        segmentation_shape=None,
        categories=None,
        once=False,
        start=None,
        end=None,
        batch_size=4,
        ahead=4,
        thread=False,
    ):
        """
        Constructor arguments:
        segmentation: The AbstractSegmentation to load.
        split: None for no filtering, or 'train' or 'val' etc.
        randomize: True to randomly shuffle order, or a random seed.
        categories: a list of categories to include in each batch.
        batch_size: number of data items for each batch.
        ahead: the number of data items to prefetch ahead.
        """
        self.segmentation = segmentation
        self.split = split
        self.randomize = randomize
        self.random = random.Random()
        if randomize is not True:
            self.random.seed(randomize)
        self.categories = categories
        self.once = once
        self.batch_size = batch_size
        self.ahead = ahead
        # Initialize the multiprocessing pool
        n_procs = cpu_count()
        if thread:
            self.pool = ThreadPool(processes=n_procs)
        else:
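            # setup_sigint/restore_sigint (defined elsewhere in this
            # module) presumably implement the usual trick of having
            # workers ignore SIGINT so Ctrl-C reaches only the parent.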
            original_sigint_handler = setup_sigint()
            self.pool = Pool(processes=n_procs, initializer=setup_sigint)
            restore_sigint(original_sigint_handler)
        # Prefilter the image indexes of interest
        if start is None:
            start = 0
        if end is None:
            end = segmentation.size()
        self.indexes = list(range(start, end))
        if split:
            self.indexes = [
                i for i in self.indexes if segmentation.split(i) == split
            ]
        if self.randomize:
            self.random.shuffle(self.indexes)
        self.index = 0
        self.result_queue = []
        self.segmentation_shape = segmentation_shape
        # Get dense catmaps
        self.catmaps = [
            segmentation.category_index_map(cat) if cat != "image" else None
            for cat in categories
        ]

    def next_job(self):
        # Returns the argument tuple for the next prefetch job, or None
        # once a single pass has been exhausted (when once=True).
        if self.index < 0:
            return None
        j = self.indexes[self.index]
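        # Bundle everything a worker needs to load one record: its index,
        # the segmentation class, per-record metadata and filename, plus
        # the requested categories and target shape.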
        result = (
            j,
            self.segmentation.__class__,
            self.segmentation.metadata(j),
            self.segmentation.filename(j),
            self.categories,
            self.segmentation_shape,
        )
        self.index += 1
        if self.index >= len(self.indexes):
            if self.once:
                self.index = -1
            else:
                self.index = 0
                if self.randomize:
                    # Reshuffle every time through
                    self.random.shuffle(self.indexes)
        return result

    def batches(self):
        """Iterator for all batches"""
        while True:
            batch = self.fetch_batch()
            if batch is None:
                return
            yield batch

    def fetch_batch(self):
        """Returns a single batch as an array of dictionaries."""
        try:
            self.refill_tasks()
            if len(self.result_queue) == 0:
                return None
            result = self.result_queue.pop(0)
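            # Use a very long (one-year) timeout rather than a bare
            # get(): without a timeout, get() can block in a way that
            # delays delivery of KeyboardInterrupt.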
            return result.get(31536000)
        except KeyboardInterrupt:
            print("Caught KeyboardInterrupt, terminating workers")
            self.pool.terminate()
            raise

    def fetch_tensor_batch(self, bgr_mean=None, global_labels=False):
        """Iterator for batches as arrays of tensors."""
        batch = self.fetch_batch()
        return self.form_caffe_tensors(batch, bgr_mean, global_labels)

    def tensor_batches(self, bgr_mean=None, global_labels=False):
        """Returns a single batch as an array of tensors, one per category."""
        while True:
            batch = self.fetch_tensor_batch(bgr_mean=bgr_mean,
                                            global_labels=global_labels)
            if batch is None:
                return
            yield batch

    def form_caffe_tensors(self, batch, bgr_mean=None, global_labels=False):
        # Assemble a batch in [{'cat': data, ...}, ...] format into a list
        # of batch tensors, one per category in self.categories (in order,
        # with the image first when present), plus a final scene tensor.
        if batch is None:
            return None
        cats = [*self.categories, "scene"]
        batches = [[] for c in cats]
        for record in batch:
            default_shape = (1, record["sh"], record["sw"])
            for c, cat in enumerate(cats):
                if cat == "image":
                    # Normalize image with right RGB order and mean
                    batches[c].append(normalize_image(record[cat], bgr_mean))
                elif global_labels:
                    if cat == "scene":
                        if not record[cat]:
                            # No scene label for this record: use -1.
                            batches[c].append(numpy.array([-1]))
                        elif len(record[cat]) > 1:
                            print(
                                f"Multiple scenes: {record['fn']} {record[cat]}"
                            )
                            # Keep only the first label so every record
                            # contributes a length-1 array.
                            batches[c].append(numpy.array(record[cat][:1]))
                        else:
                            batches[c].append(numpy.array(record[cat]))
                    else:
                        batches[c].append(
                            normalize_label(record[cat],
                                            default_shape,
                                            flatten=True))
                else:
                    catmap = self.catmaps[c]
                    batches[c].append(catmap[normalize_label(record[cat],
                                                             default_shape,
                                                             flatten=True)])
        return [
            numpy.concatenate(tuple(m[numpy.newaxis] for m in b))
            for b in batches
        ]

    def refill_tasks(self):
        # Keep up to `ahead` async batch results in flight: collect up to
        # batch_size jobs from next_job(), then submit each group to the
        # pool with map_async.
        while len(self.result_queue) < self.ahead:
            data = []
            while len(data) < self.batch_size:
                job = self.next_job()
                if job is None:
                    break
                data.append(job)
            if len(data) == 0:
                return
            self.result_queue.append(self.pool.map_async(
                prefetch_worker, data))

    def close(self):
        while len(self.result_queue):
            result = self.result_queue.pop(0)
            if result is not None:
                result.wait(0.001)
        self.pool.close()
        # Wait for the worker processes to exit.
        self.pool.join()
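
Example usage (a minimal sketch, not from the original source;
MyDatasetSegmentation stands in for a concrete AbstractSegmentation
subclass, and the bgr_mean values are illustrative):

segmentation = MyDatasetSegmentation('/path/to/dataset')  # hypothetical
pf = SegmentationPrefetcher(
    segmentation,
    split='train',
    randomize=1,                      # seed for a reproducible shuffle
    categories=['image', 'object'],
    once=True,                        # stop after one pass over the data
    batch_size=8,
    ahead=4,
)
for tensors in pf.tensor_batches(bgr_mean=[104.0, 117.0, 123.0],
                                 global_labels=True):
    images, objects, scenes = tensors  # one tensor per category, plus scene
    ...  # run the model on the batch here
pf.close()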