class OpSingleBlockObjectPrediction(Operator):
    """
    Runs object extraction and object prediction on a single block of the
    input, computed with a surrounding halo, and exposes the results for the
    block only (the halo is discarded on output).
    """
    RawImage = InputSlot()
    BinaryImage = InputSlot()

    SelectedFeatures = InputSlot(rtype=List, stype=Opaque)

    Classifier = InputSlot()
    LabelsCount = InputSlot()

    ObjectwisePredictions = OutputSlot(stype=Opaque, rtype=List)
    PredictionImage = OutputSlot()
    ProbabilityChannelImage = OutputSlot()
    BlockwiseRegionFeatures = OutputSlot()  # Indexed by (t,c)

    # Schematic:
    #
    # RawImage -----> opRawSubRegion ------                        _______________________
    #                                      \                      /                       \
    # BinaryImage --> opBinarySubRegion --> opExtract --(features)--> opPredict --(map)--> opPredictionImage --via execute()--> PredictionImage
    #                                      /         \               /                    /
    #                 SelectedFeatures-----           \   Classifier                     /
    #                                                  \                                /
    #                                                   (labels)---------------------------> opProbabilityChannelsToImage

    # +----------------------------------------------------------------+
    # | input_shape = RawImage.meta.shape                              |
    # |                                                                |
    # |                                                                |
    # |                                                                |
    # |                                                                |
    # |                                                                |
    # |                                                                |
    # |                    halo_shape = blockshape + 2*halo_padding    |
    # |                    +------------------------+                  |
    # |                    | halo_roi               |                  |
    # |                    | (for internal pipeline)|                  |
    # |                    |                        |                  |
    # |                    |  +------------------+  |                  |
    # |                    |  | block_roi        |  |                  |
    # |                    |  | (output shape)   |  |                  |
    # |                    |  |                  |  |                  |
    # |                    |  |                  |  |                  |
    # |                    |  |                  |  |                  |
    # |                    |  +------------------+  |                  |
    # |                    |                        |                  |
    # |                    |                        |                  |
    # |                    |                        |                  |
    # |                    +------------------------+                  |
    # |                                                                |
    # |                                                                |
    # |                                                                |
    # |                                                                |
    # |                                                                |
    # |                                                                |
    # |                                                                |
    # +----------------------------------------------------------------+

    def __init__(self, block_roi, halo_padding, *args, **kwargs):
        """
        Build the internal pipeline shown in the schematic above.

        block_roi: (start, stop) of the block to produce, in global
                   coordinates.  setupOutputs() subtracts an array from it,
                   so it is expected to support array arithmetic.
        halo_padding: Padding added on each side of the block before the
                      internal pipeline is run (clipped to the dataset).
        Remaining args/kwargs are forwarded to the Operator base class.
        """
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as this operator is subclassed.
        super(OpSingleBlockObjectPrediction, self).__init__(*args, **kwargs)

        self.block_roi = block_roi  # In global coordinates
        self._halo_padding = halo_padding

        self._opBinarySubRegion = OpSubRegion(parent=self)
        self._opBinarySubRegion.Input.connect(self.BinaryImage)

        self._opRawSubRegion = OpSubRegion(parent=self)
        self._opRawSubRegion.Input.connect(self.RawImage)

        self._opExtract = OpObjectExtraction(parent=self)
        self._opExtract.BinaryImage.connect(self._opBinarySubRegion.Output)
        self._opExtract.RawImage.connect(self._opRawSubRegion.Output)
        self._opExtract.Features.connect(self.SelectedFeatures)
        self.BlockwiseRegionFeatures.connect(
            self._opExtract.BlockwiseRegionFeatures)

        self._opExtract._opRegFeats._opCache.name = "blockwise-regionfeats-cache"

        self._opPredict = OpObjectPredict(parent=self)
        self._opPredict.Features.connect(self._opExtract.RegionFeatures)
        self._opPredict.SelectedFeatures.connect(self.SelectedFeatures)
        self._opPredict.Classifier.connect(self.Classifier)
        self._opPredict.LabelsCount.connect(self.LabelsCount)
        self.ObjectwisePredictions.connect(self._opPredict.Predictions)

        self._opPredictionImage = OpRelabelSegmentation(parent=self)
        self._opPredictionImage.Image.connect(self._opExtract.LabelImage)
        self._opPredictionImage.Features.connect(
            self._opExtract.RegionFeatures)
        self._opPredictionImage.ObjectMap.connect(self._opPredict.Predictions)

        self._opPredictionCache = OpArrayCache(parent=self)
        self._opPredictionCache.Input.connect(self._opPredictionImage.Output)

        self._opProbabilityChannelsToImage = OpMultiRelabelSegmentation(
            parent=self)
        self._opProbabilityChannelsToImage.Image.connect(
            self._opExtract.LabelImage)
        self._opProbabilityChannelsToImage.ObjectMaps.connect(
            self._opPredict.ProbabilityChannels)
        self._opProbabilityChannelsToImage.Features.connect(
            self._opExtract.RegionFeatures)

        self._opProbabilityChannelStacker = OpMultiArrayStacker(parent=self)
        self._opProbabilityChannelStacker.Images.connect(
            self._opProbabilityChannelsToImage.Output)
        self._opProbabilityChannelStacker.AxisFlag.setValue('c')

        self._opProbabilityCache = OpArrayCache(parent=self)
        self._opProbabilityCache.Input.connect(
            self._opProbabilityChannelStacker.Output)

    def setupOutputs(self):
        """
        Compute the halo roi for this block, configure the internal
        subregion operators and caches, and set the output metadata so that
        the outputs have the shape of the block (without halo).
        """
        tagged_input_shape = self.RawImage.meta.getTaggedShape()
        self._halo_roi = self.computeHaloRoi(
            tagged_input_shape, self._halo_padding,
            self.block_roi)  # In global coordinates

        # Output roi in our own coordinates (i.e. relative to the halo start)
        self._output_roi = self.block_roi - self._halo_roi[0]

        halo_start, halo_stop = map(tuple, self._halo_roi)

        self._opRawSubRegion.Roi.setValue((halo_start, halo_stop))

        # Binary image has only 1 channel.  Adjust halo subregion.
        assert self.BinaryImage.meta.getTaggedShape()['c'] == 1
        c_index = self.BinaryImage.meta.axistags.channelIndex
        binary_halo_roi = numpy.array(self._halo_roi)
        binary_halo_roi[:, c_index] = (0, 1)  # Binary has only 1 channel.
        binary_halo_start, binary_halo_stop = map(tuple, binary_halo_roi)

        self._opBinarySubRegion.Roi.setValue(
            (binary_halo_start, binary_halo_stop))

        self.PredictionImage.meta.assignFrom(
            self._opPredictionImage.Output.meta)
        self.PredictionImage.meta.shape = tuple(
            numpy.subtract(self.block_roi[1], self.block_roi[0]))

        self.ProbabilityChannelImage.meta.assignFrom(
            self._opProbabilityChannelStacker.Output.meta)
        probability_shape = numpy.subtract(self.block_roi[1],
                                           self.block_roi[0])
        # The last axis takes its size from the stacked probability channels.
        probability_shape[-1] = (
            self._opProbabilityChannelStacker.Output.meta.shape[-1])
        self.ProbabilityChannelImage.meta.shape = tuple(probability_shape)

        # Cache the entire block
        self._opPredictionCache.blockShape.setValue(
            self._opPredictionCache.Input.meta.shape)
        self._opProbabilityCache.blockShape.setValue(
            self._opProbabilityCache.Input.meta.shape)

        # Forward dirty regions to our own output
        # NOTE(review): this subscribes on every setupOutputs() call --
        # confirm that notifyDirty() de-duplicates repeated callbacks.
        self._opPredictionImage.Output.notifyDirty(self._handleDirtyPrediction)

    def execute(self, slot, subindex, roi, destination):
        """
        Serve PredictionImage / ProbabilityChannelImage requests from the
        internal caches.  The requested roi is relative to the block start,
        so it is shifted by the halo offset before the cache is queried.
        """
        assert slot is self.PredictionImage or slot is self.ProbabilityChannelImage, "Unknown input slot"
        assert (numpy.array(roi.stop) <= slot.meta.shape).all(), "Roi is out-of-bounds"

        # Extract from the output (discard halo)
        halo_offset = numpy.subtract(self.block_roi[0], self._halo_roi[0])
        adjusted_roi = (halo_offset + roi.start, halo_offset + roi.stop)
        if slot is self.PredictionImage:
            return self._opPredictionCache.Output(
                *adjusted_roi).writeInto(destination).wait()
        elif slot is self.ProbabilityChannelImage:
            return self._opProbabilityCache.Output(
                *adjusted_roi).writeInto(destination).wait()

    def propagateDirty(self, slot, subindex, roi):
        """
        Nothing to do here because dirty notifications are propagated
        through our internal pipeline and forwarded to our output via
        our notifyDirty handler.
        """
        pass

    def _handleDirtyPrediction(self, slot, roi):
        """
        Forward dirty notifications from our internal output slot to the
        external one, but first discard the halo and offset the roi to
        compensate for the halo.
        """
        # Discard halo.  dirtyRoi is in internal coordinates
        # (i.e. relative to halo start)
        dirtyRoi = getIntersection((roi.start, roi.stop),
                                   self._output_roi,
                                   assertIntersect=False)
        if dirtyRoi is not None:
            halo_offset = numpy.subtract(self.block_roi[0], self._halo_roi[0])
            # adjusted_roi is in output coordinates
            # (relative to output block start)
            adjusted_roi = dirtyRoi - halo_offset
            self.PredictionImage.setDirty(*adjusted_roi)

            # Expand to all channels and set channel image dirty
            adjusted_roi[:, -1] = (0,
                                   self.ProbabilityChannelImage.meta.shape[-1])
            self.ProbabilityChannelImage.setDirty(*adjusted_roi)

    @classmethod
    def computeHaloRoi(cls, tagged_dataset_shape, halo_padding, block_roi):
        """
        Return (halo_start, halo_stop) for the given block.

        tagged_dataset_shape: Ordered mapping of axis key -> dataset size;
                              must contain the key 'c'.
        halo_padding: Padding to add on each side of the block.
        block_roi: (start, stop) of the block.  It is copied, so the
                   caller's data is not modified.

        The channel axis of the result always spans the full channel range
        before padding is applied, and the padded roi is clipped to the
        dataset bounds.
        """
        block_roi = numpy.array(block_roi)
        block_start, block_stop = block_roi

        # keys()/values() are wrapped in list() so this also works where
        # they return views (Python 3) rather than lists.
        channel_index = list(tagged_dataset_shape.keys()).index('c')
        block_start[channel_index] = 0
        block_stop[channel_index] = tagged_dataset_shape['c']

        # Compute halo and clip to dataset bounds
        halo_start = block_start - halo_padding
        halo_start = numpy.maximum(halo_start, (0, ) * len(halo_start))

        halo_stop = block_stop + halo_padding
        halo_stop = numpy.minimum(halo_stop,
                                  list(tagged_dataset_shape.values()))

        halo_roi = (halo_start, halo_stop)
        return halo_roi
class OpMaskedWatershed(Operator):
    """
    Performs a seeded watershed within a masked region.
    The masking is achieved using vigra's terminate=StopAtThreshold feature.
    """
    Input = InputSlot(optional=True)  # If no input is given, output is voronoi within the masked region.
    Mask = InputSlot()  # Watershed will only be computed for pixels where mask=True
    Seeds = InputSlot()

    Output = OutputSlot()

    def __init__(self, *args, **kwargs):
        """
        Set up the internal input-preparation operator and its cache.
        All arguments are forwarded to the Operator base class.
        """
        super(OpMaskedWatershed, self).__init__(*args, **kwargs)

        # Use an internal operator to prepare the data,
        # for easy caching/parallelization.
        self._opPrepInput = _OpPrepWatershedInput(parent=self)
        self._opPrepInput.Input.connect(self.Input)
        self._opPrepInput.Mask.connect(self.Mask)

        self._opPreppedInputCache = OpArrayCache(parent=self)
        self._opPreppedInputCache.Input.connect(self._opPrepInput.Output)

    def setupOutputs(self):
        """
        Configure the prepared-input cache block shape and the output
        metadata (copied from Mask, with dtype uint32).
        """
        if self.Input.ready():
            assert self.Input.meta.drange is not None, "Masked watershed requires input drange to be specified"

        # Cache the prepared input in blocks of half the extent along each
        # axis (at least 1).  Floor division keeps the block shape integral
        # even when '/' is true division.
        blockshape = numpy.array(self._opPrepInput.Output.meta.shape) // 2
        blockshape = numpy.maximum(1, blockshape)
        self._opPreppedInputCache.blockShape.setValue(tuple(blockshape))

        self.Output.meta.assignFrom(self.Mask.meta)
        self.Output.meta.dtype = numpy.uint32

    def execute(self, slot, subindex, roi, result):
        """
        Run the seeded watershed for the requested roi, writing the labels
        into `result` and returning it.
        """
        # The input preparation involves converting to uint8 and combining
        # the mask so we can use the StopAtThreshold mechanism
        with Timer() as prep_timer:
            input_data = self._opPreppedInputCache.Output(roi.start,
                                                          roi.stop).wait()
        logger.debug("Input prep took {} seconds".format(prep_timer.seconds()))

        input_axistags = self._opPrepInput.Output.meta.axistags
        max_input_value = self._opPrepInput.Output.meta.drange[1]

        seeds = self.Seeds(roi.start, roi.stop).wait()

        # The input_data has max value outside the mask area.
        # Discard seeds outside the mask
        seeds[input_data == max_input_value] = 0

        # Reduce to 3-D (keep order of xyz axes)
        axes3d = "".join(
            [tag.key for tag in input_axistags if tag.key in 'xyz'])

        input_view = vigra.taggedView(input_data, input_axistags)
        input_view = input_view.withAxes(*axes3d)
        input_view = vigra.taggedView(input_view, axes3d)

        seeds_view = vigra.taggedView(seeds, self.Seeds.meta.axistags)
        seeds_view = seeds_view.withAxes(*axes3d)
        seeds_view = seeds_view.astype(numpy.uint32)

        result_view = vigra.taggedView(result, self.Output.meta.axistags)
        result_view = result_view.withAxes(*axes3d)
        result_view = vigra.taggedView(result_view, axes3d)

        with Timer() as watershed_timer:
            # The 'watershedsNew' function is faster and supports
            # StopAtThreshold even in turbo mode
            _, maxLabel = vigra.analysis.watershedsNew(
                input_view,
                seeds=seeds_view,
                out=result_view,
                method='turbo',
                terminate=vigra.analysis.SRGType.StopAtThreshold,
                max_cost=max_input_value - 1)

        logger.debug("vigra.watershedsNew() took {} seconds ({} seeds)".format(
            watershed_timer.seconds(), maxLabel))
        return result

    def propagateDirty(self, slot, subindex, roi):
        """
        Mark the output dirty whenever an input changes; setDirty() is
        called without a roi.
        """
        self.Output.setDirty()