Example #1
0
class OpSingleBlockObjectPrediction(Operator):
    """
    Run object classification on one block of a larger dataset.

    ``block_roi`` (global coordinates) is expanded by ``halo_padding`` into a
    ``halo_roi``; the internal pipeline (sub-region -> object extraction ->
    prediction -> relabeling) runs on that halo region, and ``execute()``
    crops the halo away again so that ``PredictionImage`` and
    ``ProbabilityChannelImage`` cover exactly ``block_roi``.
    """
    RawImage = InputSlot()
    BinaryImage = InputSlot()

    SelectedFeatures = InputSlot(rtype=List, stype=Opaque)

    Classifier = InputSlot()
    LabelsCount = InputSlot()

    ObjectwisePredictions = OutputSlot(stype=Opaque, rtype=List)
    PredictionImage = OutputSlot()
    ProbabilityChannelImage = OutputSlot()
    BlockwiseRegionFeatures = OutputSlot()  # Indexed by (t,c)

    # Schematic:
    #
    # RawImage -----> opRawSubRegion ------                        _______________________
    #                                      \                      /                       \
    # BinaryImage --> opBinarySubRegion --> opExtract --(features)--> opPredict --(map)--> opPredictionImage --via execute()--> PredictionImage
    #                                      /         \               /                    /
    #                 SelectedFeatures-----           \   Classifier                     /
    #                                                  \                                /
    #                                                   (labels)---------------------------> opProbabilityChannelsToImage

    # +----------------------------------------------------------------+
    # | input_shape = RawImage.meta.shape                              |
    # |                                                                |
    # |                                                                |
    # |                                                                |
    # |                                                                |
    # |                                                                |
    # |                                                                |
    # |                    halo_shape = blockshape + 2*halo_padding    |
    # |                    +------------------------+                  |
    # |                    | halo_roi               |                  |
    # |                    | (for internal pipeline)|                  |
    # |                    |                        |                  |
    # |                    |  +------------------+  |                  |
    # |                    |  | block_roi        |  |                  |
    # |                    |  | (output shape)   |  |                  |
    # |                    |  |                  |  |                  |
    # |                    |  |                  |  |                  |
    # |                    |  |                  |  |                  |
    # |                    |  +------------------+  |                  |
    # |                    |                        |                  |
    # |                    |                        |                  |
    # |                    |                        |                  |
    # |                    +------------------------+                  |
    # |                                                                |
    # |                                                                |
    # |                                                                |
    # |                                                                |
    # |                                                                |
    # |                                                                |
    # |                                                                |
    # +----------------------------------------------------------------+

    def __init__(self, block_roi, halo_padding, *args, **kwargs):
        """
        :param block_roi: ``(start, stop)`` of this block in global coordinates.
        :param halo_padding: per-axis padding added around the block for the
            internal pipeline (clipped to the dataset bounds in setupOutputs).
        """
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as this operator is subclassed.
        super(OpSingleBlockObjectPrediction, self).__init__(*args, **kwargs)

        self.block_roi = block_roi  # In global coordinates
        self._halo_padding = halo_padding

        self._opBinarySubRegion = OpSubRegion(parent=self)
        self._opBinarySubRegion.Input.connect(self.BinaryImage)

        self._opRawSubRegion = OpSubRegion(parent=self)
        self._opRawSubRegion.Input.connect(self.RawImage)

        self._opExtract = OpObjectExtraction(parent=self)
        self._opExtract.BinaryImage.connect(self._opBinarySubRegion.Output)
        self._opExtract.RawImage.connect(self._opRawSubRegion.Output)
        self._opExtract.Features.connect(self.SelectedFeatures)
        self.BlockwiseRegionFeatures.connect(
            self._opExtract.BlockwiseRegionFeatures)

        self._opExtract._opRegFeats._opCache.name = "blockwise-regionfeats-cache"

        self._opPredict = OpObjectPredict(parent=self)
        self._opPredict.Features.connect(self._opExtract.RegionFeatures)
        self._opPredict.SelectedFeatures.connect(self.SelectedFeatures)
        self._opPredict.Classifier.connect(self.Classifier)
        self._opPredict.LabelsCount.connect(self.LabelsCount)
        self.ObjectwisePredictions.connect(self._opPredict.Predictions)

        # Paint each object's predicted label into the label image.
        self._opPredictionImage = OpRelabelSegmentation(parent=self)
        self._opPredictionImage.Image.connect(self._opExtract.LabelImage)
        self._opPredictionImage.Features.connect(
            self._opExtract.RegionFeatures)
        self._opPredictionImage.ObjectMap.connect(self._opPredict.Predictions)

        self._opPredictionCache = OpArrayCache(parent=self)
        self._opPredictionCache.Input.connect(self._opPredictionImage.Output)

        # One relabeled image per probability channel...
        self._opProbabilityChannelsToImage = OpMultiRelabelSegmentation(
            parent=self)
        self._opProbabilityChannelsToImage.Image.connect(
            self._opExtract.LabelImage)
        self._opProbabilityChannelsToImage.ObjectMaps.connect(
            self._opPredict.ProbabilityChannels)
        self._opProbabilityChannelsToImage.Features.connect(
            self._opExtract.RegionFeatures)

        # ...stacked along the channel axis into a single image.
        self._opProbabilityChannelStacker = OpMultiArrayStacker(parent=self)
        self._opProbabilityChannelStacker.Images.connect(
            self._opProbabilityChannelsToImage.Output)
        self._opProbabilityChannelStacker.AxisFlag.setValue('c')

        self._opProbabilityCache = OpArrayCache(parent=self)
        self._opProbabilityCache.Input.connect(
            self._opProbabilityChannelStacker.Output)

    def setupOutputs(self):
        """
        Compute the halo roi, configure the internal sub-region operators and
        give both image outputs the shape of ``block_roi``.
        """
        tagged_input_shape = self.RawImage.meta.getTaggedShape()
        self._halo_roi = self.computeHaloRoi(
            tagged_input_shape, self._halo_padding,
            self.block_roi)  # In global coordinates

        # Output roi in our own coordinates (i.e. relative to the halo start)
        self._output_roi = self.block_roi - self._halo_roi[0]

        halo_start, halo_stop = map(tuple, self._halo_roi)

        self._opRawSubRegion.Roi.setValue((halo_start, halo_stop))

        # Binary image has only 1 channel.  Adjust halo subregion.
        assert self.BinaryImage.meta.getTaggedShape()['c'] == 1
        c_index = self.BinaryImage.meta.axistags.channelIndex
        binary_halo_roi = numpy.array(self._halo_roi)
        binary_halo_roi[:, c_index] = (0, 1)  # Binary has only 1 channel.
        binary_halo_start, binary_halo_stop = map(tuple, binary_halo_roi)

        self._opBinarySubRegion.Roi.setValue(
            (binary_halo_start, binary_halo_stop))

        self.PredictionImage.meta.assignFrom(
            self._opPredictionImage.Output.meta)
        self.PredictionImage.meta.shape = tuple(
            numpy.subtract(self.block_roi[1], self.block_roi[0]))

        self.ProbabilityChannelImage.meta.assignFrom(
            self._opProbabilityChannelStacker.Output.meta)
        probability_shape = numpy.subtract(self.block_roi[1],
                                           self.block_roi[0])
        # NOTE(review): assumes the channel axis is the last axis; confirm
        # against the stacker's output axistags.
        probability_shape[
            -1] = self._opProbabilityChannelStacker.Output.meta.shape[-1]
        self.ProbabilityChannelImage.meta.shape = tuple(probability_shape)

        # Cache the entire block
        self._opPredictionCache.blockShape.setValue(
            self._opPredictionCache.Input.meta.shape)
        self._opProbabilityCache.blockShape.setValue(
            self._opProbabilityCache.Input.meta.shape)

        # Forward dirty regions to our own output.
        # NOTE(review): this runs on every setupOutputs() call; confirm the
        # slot's dirty signal does not register the same handler twice.
        self._opPredictionImage.Output.notifyDirty(self._handleDirtyPrediction)

    def execute(self, slot, subindex, roi, destination):
        """
        Serve a request for ``PredictionImage`` or ``ProbabilityChannelImage``.

        ``roi`` is relative to ``block_roi``. It is shifted by the offset of
        the block inside the halo and read from the matching whole-block cache
        into ``destination``.
        """
        assert slot is self.PredictionImage or slot is self.ProbabilityChannelImage, "Unknown output slot"
        assert (numpy.array(roi.stop) <=
                slot.meta.shape).all(), "Roi is out-of-bounds"

        # Extract from the output (discard halo)
        halo_offset = numpy.subtract(self.block_roi[0], self._halo_roi[0])
        adjusted_roi = (halo_offset + roi.start, halo_offset + roi.stop)
        if slot is self.PredictionImage:
            return self._opPredictionCache.Output(
                *adjusted_roi).writeInto(destination).wait()
        elif slot is self.ProbabilityChannelImage:
            return self._opProbabilityCache.Output(
                *adjusted_roi).writeInto(destination).wait()

    def propagateDirty(self, slot, subindex, roi):
        """
        Nothing to do here because dirty notifications are propagated
        through our internal pipeline and forwarded to our output via
        our notifyDirty handler.
        """
        pass

    def _handleDirtyPrediction(self, slot, roi):
        """
        Forward dirty notifications from our internal output slot to the external one,
        but first discard the halo and offset the roi to compensate for the halo.
        """
        # Discard halo.  dirtyRoi is in internal coordinates (i.e. relative to halo start)
        dirtyRoi = getIntersection((roi.start, roi.stop),
                                   self._output_roi,
                                   assertIntersect=False)
        if dirtyRoi is not None:
            halo_offset = numpy.subtract(self.block_roi[0], self._halo_roi[0])
            adjusted_roi = dirtyRoi - halo_offset  # adjusted_roi is in output coordinates (relative to output block start)
            self.PredictionImage.setDirty(*adjusted_roi)

            # Expand to all channels and set channel image dirty.
            # NOTE(review): like setupOutputs, assumes channel is the last axis.
            adjusted_roi[:,
                         -1] = (0, self.ProbabilityChannelImage.meta.shape[-1])
            self.ProbabilityChannelImage.setDirty(*adjusted_roi)

    @classmethod
    def computeHaloRoi(cls, tagged_dataset_shape, halo_padding, block_roi):
        """
        Expand ``block_roi`` by ``halo_padding`` and clip it to the dataset.

        :param tagged_dataset_shape: mapping of axis key -> extent, iterating
            in axis order (as produced by ``meta.getTaggedShape()``); must
            contain a 'c' axis.
        :param halo_padding: per-axis padding added on both sides of the block.
        :param block_roi: ``(start, stop)`` of the block in global coordinates.
        :returns: ``(halo_start, halo_stop)`` as numpy arrays. The channel axis
            always spans the full channel range ``[0, C)``.
        """
        block_roi = numpy.array(block_roi)  # copy: caller's roi is untouched
        block_start, block_stop = block_roi

        # Materialize as lists: on Python 3 keys()/values() are views, which
        # have no .index() and do not convert to a numeric array.
        axis_keys = list(tagged_dataset_shape.keys())
        dataset_shape = list(tagged_dataset_shape.values())

        channel_index = axis_keys.index('c')
        block_start[channel_index] = 0
        block_stop[channel_index] = tagged_dataset_shape['c']

        # Compute halo and clip to dataset bounds
        halo_start = block_start - halo_padding
        halo_start = numpy.maximum(halo_start, (0, ) * len(halo_start))

        halo_stop = block_stop + halo_padding
        halo_stop = numpy.minimum(halo_stop, dataset_shape)

        halo_roi = (halo_start, halo_stop)
        return halo_roi
class OpMaskedWatershed(Operator):
    """
    Performs a seeded watershed within a masked region.
    The masking is achieved using vigra's terminate=StopAtThreshold feature.
    """
    Input = InputSlot(
        optional=True
    )  # If no input is given, output is voronoi within the masked region.
    Mask = InputSlot(
    )  # Watershed will only be computed for pixels where mask=True
    Seeds = InputSlot()

    Output = OutputSlot()

    def __init__(self, *args, **kwargs):
        super(OpMaskedWatershed, self).__init__(*args, **kwargs)

        # Use an internal operator to prepare the data,
        #  for easy caching/parallelization.
        self._opPrepInput = _OpPrepWatershedInput(parent=self)
        self._opPrepInput.Input.connect(self.Input)
        self._opPrepInput.Mask.connect(self.Mask)

        self._opPreppedInputCache = OpArrayCache(parent=self)
        self._opPreppedInputCache.Input.connect(self._opPrepInput.Output)

    def setupOutputs(self):
        """
        Configure the prepared-input cache and give Output the mask's
        metadata with a uint32 dtype.
        """
        if self.Input.ready():
            assert self.Input.meta.drange is not None, "Masked watershed requires input drange to be specified"

        # Cache the prepared input in blocks of half the shape along each axis
        # (e.g. 8 blocks for an even-sized 3-D volume).  Floor division keeps
        # the block shape integral; '/' yields floats under Python 3.
        blockshape = numpy.array(self._opPrepInput.Output.meta.shape) // 2
        blockshape = numpy.maximum(1, blockshape)
        self._opPreppedInputCache.blockShape.setValue(tuple(blockshape))

        self.Output.meta.assignFrom(self.Mask.meta)
        self.Output.meta.dtype = numpy.uint32

    def execute(self, slot, subindex, roi, result):
        """
        Run the masked, seeded watershed for ``roi`` and write the labels
        into ``result`` (which is also returned).
        """
        # The input preparation involves converting to uint8 and combining
        #  the mask so we can use the StopAtThreshold mechanism
        with Timer() as prep_timer:
            input_data = self._opPreppedInputCache.Output(roi.start,
                                                          roi.stop).wait()
        logger.debug("Input prep took {} seconds".format(prep_timer.seconds()))

        input_axistags = self._opPrepInput.Output.meta.axistags
        max_input_value = self._opPrepInput.Output.meta.drange[1]

        seeds = self.Seeds(roi.start, roi.stop).wait()

        # The input_data has max value outside the mask area.
        # Discard seeds outside the mask
        seeds[input_data == max_input_value] = 0

        # Reduce to 3-D (keep order of xyz axes)
        tags = input_axistags
        axes3d = "".join([tag.key for tag in tags if tag.key in 'xyz'])

        input_view = vigra.taggedView(input_data, input_axistags)
        input_view = input_view.withAxes(*axes3d)
        input_view = vigra.taggedView(input_view, axes3d)

        seeds_view = vigra.taggedView(seeds, self.Seeds.meta.axistags)
        seeds_view = seeds_view.withAxes(*axes3d)
        seeds_view = seeds_view.astype(numpy.uint32)

        # Views onto `result`, so the watershed writes straight into it.
        result_view = vigra.taggedView(result, self.Output.meta.axistags)
        result_view = result_view.withAxes(*axes3d)
        result_view = vigra.taggedView(result_view, axes3d)

        with Timer() as watershed_timer:
            # The 'watershedsNew' function is faster and supports StopAtThreshold even in turbo mode
            # max_cost is one below the "outside mask" value, so growth stops
            # at masked-out pixels.
            _, maxLabel = vigra.analysis.watershedsNew(
                input_view,
                seeds=seeds_view,
                out=result_view,
                method='turbo',
                terminate=vigra.analysis.SRGType.StopAtThreshold,
                max_cost=max_input_value - 1)

        logger.debug("vigra.watershedsNew() took {} seconds ({} seeds)".format(
            watershed_timer.seconds(), maxLabel))
        return result

    def propagateDirty(self, slot, subindex, roi):
        """Any input change invalidates the whole output."""
        self.Output.setDirty()