def __init__(self, *args, **kwargs):
            """
            Build the internal multicut pipeline and expose its results.

            Data flow:
              Rag + EdgeProbabilities (+ Beta, SolverName)
                  -> OpMulticutAgglomerator.NodeLabels
                  -> node-label value cache
                  -> OpProjectNodeLabeling (with Superpixels)
                  -> blocked segmentation cache -> Output
              node-label cache + Rag + EdgeProbabilities
                  -> OpEdgeLabelDisagreementDict -> EdgeLabelDisagreementDict

            Both caches are frozen/unfrozen through the FreezeCache input,
            which drives their fixAtCurrent slots.
            """
            super(OpMulticut, self).__init__(*args, **kwargs)

            # Agglomerator: consumes the RAG and edge probabilities and
            # produces the NodeLabels consumed below.
            agglomerator = OpMulticutAgglomerator(parent=self)
            self.opMulticutAgglomerator = agglomerator
            agglomerator.Beta.connect(self.Beta)
            agglomerator.SolverName.connect(self.SolverName)
            agglomerator.Rag.connect(self.Rag)
            agglomerator.EdgeProbabilities.connect(self.EdgeProbabilities)

            # Cache the node labels so downstream consumers share one result.
            node_label_cache = OpValueCache(parent=self)
            self.opNodeLabelsCache = node_label_cache
            node_label_cache.fixAtCurrent.connect(self.FreezeCache)
            node_label_cache.Input.connect(agglomerator.NodeLabels)
            node_label_cache.name = 'opNodeLabelCache'

            # Paint the cached node labels back onto the superpixel image.
            relabel = OpProjectNodeLabeling(parent=self)
            self.opRelabel = relabel
            relabel.Superpixels.connect(self.Superpixels)
            relabel.NodeLabels.connect(node_label_cache.Output)

            # Compare the node labeling against the edge probabilities.
            disagreement = OpEdgeLabelDisagreementDict(parent=self)
            self.opDisagreement = disagreement
            disagreement.Rag.connect(self.Rag)
            disagreement.NodeLabels.connect(node_label_cache.Output)
            disagreement.EdgeProbabilities.connect(self.EdgeProbabilities)
            self.EdgeLabelDisagreementDict.connect(
                disagreement.EdgeLabelDisagreementDict)

            # Blocked cache in front of the final segmentation output.
            segmentation_cache = OpBlockedArrayCache(parent=self)
            self.opSegmentationCache = segmentation_cache
            segmentation_cache.fixAtCurrent.connect(self.FreezeCache)
            segmentation_cache.Input.connect(relabel.Output)
            self.Output.connect(segmentation_cache.Output)
    def __init__(self, *args, **kwargs):
        """
        Set up a mock pixel classifier with two hard-coded classes.

        The label names, label colors and probability-map colors are fixed
        to "Membrane"/"Cytoplasm" in red/green. A real
        OpTrainClassifierBlocked is wired to the label inputs, and its
        classifier is held in a value cache. The predictions served by the
        mock are synthetic: a linear ramp over the pixel indices and its
        complement.
        """
        super(OpMockPixelClassifier, self).__init__(*args, **kwargs)

        self.LabelNames.setValue( ["Membrane", "Cytoplasm"] )
        self.LabelColors.setValue( [(255,0,0), (0,255,0)] ) # Red, Green
        self.PmapColors.setValue( [(255,0,0), (0,255,0)] ) # Red, Green

        self._data = []
        self.dataShape = (1,10,100,100,1)
        self.prediction_shape = self.dataShape[:-1] + (2,) # Hard-coded to provide 2 classes

        self.FreezePredictions.setValue(False)

        self.opClassifier = OpTrainClassifierBlocked(graph=self.graph, parent=self)
        self.opClassifier.ClassifierFactory.connect( self.ClassifierFactory )
        self.opClassifier.Labels.connect(self.LabelImages)
        self.opClassifier.nonzeroLabelBlocks.connect(self.NonzeroLabelBlocks)
        self.opClassifier.MaxLabel.setValue(2)

        self.classifier_cache = OpValueCache(graph=self.graph, parent=self)
        self.classifier_cache.Input.connect( self.opClassifier.Classifier )

        # Synthetic class-1 probability: the sum of a pixel's indices, scaled
        # into [0, 1] by the largest possible index sum.
        # The divisor is derived from dataShape rather than hard-coded
        # (207 for the default shape), so the ramp stays in range if the
        # shape is changed. max(..., 1) guards the degenerate all-ones shape.
        max_index_sum = float(max(sum(dim - 1 for dim in self.dataShape), 1))
        p1 = numpy.indices(self.dataShape).sum(0) / max_index_sum
        p2 = 1 - p1

        # Stack the two classes along the last (channel) axis, which matches
        # prediction_shape = dataShape[:-1] + (2,).
        self.predictionData = numpy.concatenate((p1,p2), axis=-1)
Example #3
0
    def test_basic(self):
        # A fresh cache starts clean; giving it an input marks it dirty, and
        # reading the output pulls the input value through.
        cache_op = OpValueCache(graph=lazyflow.graph.Graph())
        assert not cache_op._dirty
        cache_op.Input.setValue('Hello')
        assert cache_op._dirty
        assert cache_op.Output.value == 'Hello'

        # Record every dirty notification emitted by the output slot.
        dirty_events = []
        cache_op.Output.notifyDirty(lambda slot, roi: dirty_events.append(roi))

        cache_op.forceValue('Goodbye')
        # The forced value is stored directly, so the cache itself is not
        # dirty (it will not ask its input for a value) ...
        assert not cache_op._dirty
        assert cache_op.Output.value == 'Goodbye'

        # ... but downstream slots were told exactly once that its value changed.
        assert len(dirty_events) == 1
    def __init__(self, *args, **kwargs):
        """
        Set up a mock pixel classifier that serves synthetic 2-class predictions.

        A real OpTrainRandomForestBlocked is wired to the label inputs (with
        fixClassifier off), and its classifier is held in a value cache. The
        prediction data is a linear ramp over the pixel indices and its
        complement.
        """
        super(OpMockPixelClassifier, self).__init__(*args, **kwargs)
        self._data = []
        self.dataShape = (1,10,100,100,1)
        self.prediction_shape = self.dataShape[:-1] + (2,) # Hard-coded to provide 2 classes

        self.FreezePredictions.setValue(False)

        self.opClassifier = OpTrainRandomForestBlocked(graph=self.graph, parent=self)
        self.opClassifier.Labels.connect(self.LabelImages)
        self.opClassifier.nonzeroLabelBlocks.connect(self.NonzeroLabelBlocks)
        self.opClassifier.fixClassifier.setValue(False)

        self.classifier_cache = OpValueCache(graph=self.graph, parent=self)
        self.classifier_cache.Input.connect( self.opClassifier.Classifier )

        # Synthetic class-1 probability: the sum of a pixel's indices, scaled
        # into [0, 1] by the largest possible index sum.
        # The divisor is derived from dataShape rather than hard-coded
        # (207 for the default shape), so the ramp stays in range if the
        # shape is changed. max(..., 1) guards the degenerate all-ones shape.
        max_index_sum = float(max(sum(dim - 1 for dim in self.dataShape), 1))
        p1 = numpy.indices(self.dataShape).sum(0) / max_index_sum
        p2 = 1 - p1

        # Stack the two classes along the last (channel) axis, which matches
        # prediction_shape = dataShape[:-1] + (2,).
        self.predictionData = numpy.concatenate((p1,p2), axis=-1)
    def __init__( self, *args, **kwargs ):
        """
        Instantiate all internal operators and connect them together.

        Besides the static wiring, this installs notification handlers:

        * LabelNames dirty -> push the new number of classes into the
          training operator, the prediction pipeline and NumClasses.
        * InputImages resized to 0 -> shrink the per-lane outputs to 0.
        * Input/mask lane becomes ready -> validate it (_checkConstraints);
          for input lanes also size the label input (setupCaches).
        * Feature-image meta changed while predictions are frozen -> drop the
          cached classifier if its feature names no longer match.
        * Any level-1 input inserted/removed -> mirror the change on every
          other level-1 input so that all lanes stay in sync.
        """
        super(OpPixelClassification, self).__init__(*args, **kwargs)

        # Default values for some input slots
        self.FreezePredictions.setValue(True)
        self.LabelNames.setValue( [] )
        self.LabelColors.setValue( [] )
        self.PmapColors.setValue( [] )

        # SPECIAL connection: The LabelInputs slot doesn't get its data
        #  from the InputImages slot, but its shape must match.
        self.LabelInputs.connect( self.InputImages )

        # Hook up Labeling Pipeline
        self.opLabelPipeline = OpMultiLaneWrapper( OpLabelPipeline, parent=self, broadcastingSlotNames=['DeleteLabel'] )
        self.opLabelPipeline.RawImage.connect( self.InputImages )
        self.opLabelPipeline.LabelInput.connect( self.LabelInputs )
        # NOTE(review): -1 presumably means "no label pending deletion" — confirm in OpLabelPipeline.
        self.opLabelPipeline.DeleteLabel.setValue( -1 )
        self.LabelImages.connect( self.opLabelPipeline.Output )
        self.NonzeroLabelBlocks.connect( self.opLabelPipeline.nonzeroBlocks )

        # Hook up the Training operator
        self.opTrain = OpTrainClassifierBlocked( parent=self )
        self.opTrain.ClassifierFactory.connect( self.ClassifierFactory )
        self.opTrain.Labels.connect( self.opLabelPipeline.Output )
        self.opTrain.Images.connect( self.CachedFeatureImages )
        self.opTrain.nonzeroLabelBlocks.connect( self.opLabelPipeline.nonzeroBlocks )

        # Hook up the Classifier Cache
        # The classifier is cached here to allow serializers to force in
        #   a pre-calculated classifier (loaded from disk)
        self.classifier_cache = OpValueCache( parent=self )
        self.classifier_cache.name = "OpPixelClassification.classifier_cache"
        self.classifier_cache.inputs["Input"].connect(self.opTrain.outputs['Classifier'])
        self.classifier_cache.inputs["fixAtCurrent"].connect( self.FreezePredictions )
        self.Classifier.connect( self.classifier_cache.Output )

        # Hook up the prediction pipeline inputs
        self.opPredictionPipeline = OpMultiLaneWrapper( OpPredictionPipeline, parent=self )
        self.opPredictionPipeline.FeatureImages.connect( self.FeatureImages )
        self.opPredictionPipeline.CachedFeatureImages.connect( self.CachedFeatureImages )
        self.opPredictionPipeline.Classifier.connect( self.classifier_cache.Output )
        self.opPredictionPipeline.FreezePredictions.connect( self.FreezePredictions )
        self.opPredictionPipeline.PredictionsFromDisk.connect( self.PredictionsFromDisk )
        self.opPredictionPipeline.PredictionMask.connect( self.PredictionMasks )

        def _updateNumClasses(*args):
            """
            When the number of labels changes, we MUST make sure that the prediction image changes its shape (the number of channels).
            Since setupOutputs is not called for mere dirty notifications, but is called in response to setValue(),
            we use this function to call setValue().
            """
            numClasses = len(self.LabelNames.value)
            self.opTrain.MaxLabel.setValue( numClasses )
            self.opPredictionPipeline.NumClasses.setValue( numClasses )
            self.NumClasses.setValue( numClasses )
        self.LabelNames.notifyDirty( _updateNumClasses )

        # Prediction pipeline outputs -> Top-level outputs
        self.PredictionProbabilities.connect( self.opPredictionPipeline.PredictionProbabilities )
        self.CachedPredictionProbabilities.connect( self.opPredictionPipeline.CachedPredictionProbabilities )
        self.HeadlessPredictionProbabilities.connect( self.opPredictionPipeline.HeadlessPredictionProbabilities )
        self.HeadlessUint8PredictionProbabilities.connect( self.opPredictionPipeline.HeadlessUint8PredictionProbabilities )
        self.PredictionProbabilityChannels.connect( self.opPredictionPipeline.PredictionProbabilityChannels )
        self.SegmentationChannels.connect( self.opPredictionPipeline.SegmentationChannels )
        self.UncertaintyEstimate.connect( self.opPredictionPipeline.UncertaintyEstimate )
        self.SimpleSegmentation.connect( self.opPredictionPipeline.SimpleSegmentation )
        self.HeadlessUncertaintyEstimate.connect( self.opPredictionPipeline.HeadlessUncertaintyEstimate )

        def inputResizeHandler( slot, oldsize, newsize ):
            # When the last input lane disappears, shrink the per-lane outputs too.
            if ( newsize == 0 ):
                self.LabelImages.resize(0)
                self.NonzeroLabelBlocks.resize(0)
                self.PredictionProbabilities.resize(0)
                self.CachedPredictionProbabilities.resize(0)
        self.InputImages.notifyResized( inputResizeHandler )

        # Debug assertions: Check to make sure the non-wrapped operators stayed that way.
        assert self.opTrain.Images.operator == self.opTrain

        def handleNewInputImage( multislot, index, *args ):
            def handleInputReady(slot):
                # Look the lane index up when the slot becomes ready instead
                # of reusing the index captured at insertion time: lanes
                # inserted or removed in between would make it stale.
                laneIndex = multislot.index(slot)
                self._checkConstraints( laneIndex )
                self.setupCaches( laneIndex )
            multislot[index].notifyReady(handleInputReady)

        self.InputImages.notifyInserted( handleNewInputImage )

        # If any feature image changes shape, we need to verify that the
        #  channels are consistent with the currently cached classifier
        # Otherwise, delete the currently cached classifier.
        def handleNewFeatureImage( multislot, index, *args ):
            def handleFeatureImageReady(slot):
                def handleFeatureMetaChanged(slot):
                    # Only relevant while the cache is frozen (fixAtCurrent);
                    # otherwise the classifier is retrained from its input anyway.
                    if ( self.classifier_cache.fixAtCurrent.value and
                         self.classifier_cache.Output.ready() and
                         slot.meta.shape is not None ):
                        classifier = self.classifier_cache.Output.value
                        channel_names = slot.meta.channel_names
                        if classifier and classifier.feature_names != channel_names:
                            self.classifier_cache.resetValue()
                slot.notifyMetaChanged(handleFeatureMetaChanged)
            multislot[index].notifyReady(handleFeatureImageReady)

        self.FeatureImages.notifyInserted( handleNewFeatureImage )

        def handleNewMaskImage( multislot, index, *args ):
            def handleMaskReady(slot):
                # Resolve the lane index at ready time (see handleNewInputImage).
                self._checkConstraints( multislot.index(slot) )
            multislot[index].notifyReady(handleMaskReady)
        self.PredictionMasks.notifyInserted( handleNewMaskImage )

        # All input multi-slots should be kept in sync
        # Output multi-slots will auto-sync via the graph
        def insertSlot( a, b, position, finalsize ):
            # a: the slot to update (bound via partial); b: the slot that changed.
            a.insertSlot(position, finalsize)

        def removeSlot( a, b, position, finalsize ):
            # a: the slot to update (bound via partial); b: the slot that changed.
            a.removeSlot(position, finalsize)

        # Build a real list: it is iterated twice in the nested loop below,
        # and on Python 3 filter() returns a one-shot iterator that the inner
        # loop would exhaust during the first outer iteration.
        multiInputs = [s for s in self.inputs.values() if s.level >= 1]
        for s1 in multiInputs:
            for s2 in multiInputs:
                if s1 != s2:
                    s1.notifyInserted( partial(insertSlot, s2) )
                    s1.notifyRemoved( partial(removeSlot, s2) )
class OpPixelClassification( Operator ):
    """
    Top-level operator for pixel classification.

    Wires together a lane-wrapped labeling pipeline (OpLabelPipeline), a
    single classifier-training operator (OpTrainClassifierBlocked), a value
    cache holding the trained classifier, and a lane-wrapped prediction
    pipeline (OpPredictionPipeline). Each level-1 input slot holds one entry
    per image lane; __init__ keeps the lengths of all level-1 inputs in sync.
    """
    name="OpPixelClassification"
    category = "Top-level"

    # Graph inputs

    InputImages = InputSlot(level=1) # Original input data.  Used for display only.
    PredictionMasks = InputSlot(level=1, optional=True) # Routed to OpClassifierPredict.PredictionMask.  See there for details.

    LabelInputs = InputSlot(optional = True, level=1) # Input for providing label data from an external source
    LabelsAllowedFlags = InputSlot(stype='bool', level=1) # Specifies which images are permitted to be labeled

    FeatureImages = InputSlot(level=1) # Computed feature images (each channel is a different feature)
    CachedFeatureImages = InputSlot(level=1) # Cached feature data.

    FreezePredictions = InputSlot(stype='bool')
    ClassifierFactory = InputSlot(value=ParallelVigraRfLazyflowClassifierFactory(100))

    PredictionsFromDisk = InputSlot(optional=True, level=1)

    PredictionProbabilities = OutputSlot(level=1) # Classification predictions (via feature cache for interactive speed)

    PredictionProbabilityChannels = OutputSlot(level=2) # Classification predictions, enumerated by channel
    SegmentationChannels = OutputSlot(level=2) # Binary image of the final selections.

    LabelImages = OutputSlot(level=1) # Labels from the user
    NonzeroLabelBlocks = OutputSlot(level=1) # A list of slices that contain non-zero label values
    Classifier = OutputSlot() # We provide the classifier as an external output for other applets to use

    CachedPredictionProbabilities = OutputSlot(level=1) # Classification predictions (via feature cache AND prediction cache)

    HeadlessPredictionProbabilities = OutputSlot(level=1) # Classification predictions ( via no image caches (except for the classifier itself )
    HeadlessUint8PredictionProbabilities = OutputSlot(level=1) # Same as above, but 0-255 uint8 instead of 0.0-1.0 float32
    HeadlessUncertaintyEstimate = OutputSlot(level=1) # Same as uncertainty estimate, but does not rely on cached data.

    UncertaintyEstimate = OutputSlot(level=1)

    SimpleSegmentation = OutputSlot(level=1) # For debug, for now

    # GUI-only (not part of the pipeline, but saved to the project)
    LabelNames = OutputSlot()
    LabelColors = OutputSlot()
    PmapColors = OutputSlot()

    NumClasses = OutputSlot()

    def setupOutputs(self):
        # The GUI-only list outputs hold arbitrary Python objects, so they
        # are described as a single object-typed element.
        self.LabelNames.meta.dtype = object
        self.LabelNames.meta.shape = (1,)
        self.LabelColors.meta.dtype = object
        self.LabelColors.meta.shape = (1,)
        self.PmapColors.meta.dtype = object
        self.PmapColors.meta.shape = (1,)

    def __init__( self, *args, **kwargs ):
        """
        Instantiate all internal operators and connect them together.

        Besides the static wiring, this installs notification handlers:

        * LabelNames dirty -> push the new number of classes into the
          training operator, the prediction pipeline and NumClasses.
        * InputImages resized to 0 -> shrink the per-lane outputs to 0.
        * Input/mask lane becomes ready -> validate it (_checkConstraints);
          for input lanes also size the label input (setupCaches).
        * Feature-image meta changed while predictions are frozen -> drop the
          cached classifier if its feature names no longer match.
        * Any level-1 input inserted/removed -> mirror the change on every
          other level-1 input so that all lanes stay in sync.
        """
        super(OpPixelClassification, self).__init__(*args, **kwargs)

        # Default values for some input slots
        self.FreezePredictions.setValue(True)
        self.LabelNames.setValue( [] )
        self.LabelColors.setValue( [] )
        self.PmapColors.setValue( [] )

        # SPECIAL connection: The LabelInputs slot doesn't get its data
        #  from the InputImages slot, but its shape must match.
        self.LabelInputs.connect( self.InputImages )

        # Hook up Labeling Pipeline
        self.opLabelPipeline = OpMultiLaneWrapper( OpLabelPipeline, parent=self, broadcastingSlotNames=['DeleteLabel'] )
        self.opLabelPipeline.RawImage.connect( self.InputImages )
        self.opLabelPipeline.LabelInput.connect( self.LabelInputs )
        # NOTE(review): -1 presumably means "no label pending deletion" — confirm in OpLabelPipeline.
        self.opLabelPipeline.DeleteLabel.setValue( -1 )
        self.LabelImages.connect( self.opLabelPipeline.Output )
        self.NonzeroLabelBlocks.connect( self.opLabelPipeline.nonzeroBlocks )

        # Hook up the Training operator
        self.opTrain = OpTrainClassifierBlocked( parent=self )
        self.opTrain.ClassifierFactory.connect( self.ClassifierFactory )
        self.opTrain.Labels.connect( self.opLabelPipeline.Output )
        self.opTrain.Images.connect( self.CachedFeatureImages )
        self.opTrain.nonzeroLabelBlocks.connect( self.opLabelPipeline.nonzeroBlocks )

        # Hook up the Classifier Cache
        # The classifier is cached here to allow serializers to force in
        #   a pre-calculated classifier (loaded from disk)
        self.classifier_cache = OpValueCache( parent=self )
        self.classifier_cache.name = "OpPixelClassification.classifier_cache"
        self.classifier_cache.inputs["Input"].connect(self.opTrain.outputs['Classifier'])
        self.classifier_cache.inputs["fixAtCurrent"].connect( self.FreezePredictions )
        self.Classifier.connect( self.classifier_cache.Output )

        # Hook up the prediction pipeline inputs
        self.opPredictionPipeline = OpMultiLaneWrapper( OpPredictionPipeline, parent=self )
        self.opPredictionPipeline.FeatureImages.connect( self.FeatureImages )
        self.opPredictionPipeline.CachedFeatureImages.connect( self.CachedFeatureImages )
        self.opPredictionPipeline.Classifier.connect( self.classifier_cache.Output )
        self.opPredictionPipeline.FreezePredictions.connect( self.FreezePredictions )
        self.opPredictionPipeline.PredictionsFromDisk.connect( self.PredictionsFromDisk )
        self.opPredictionPipeline.PredictionMask.connect( self.PredictionMasks )

        def _updateNumClasses(*args):
            """
            When the number of labels changes, we MUST make sure that the prediction image changes its shape (the number of channels).
            Since setupOutputs is not called for mere dirty notifications, but is called in response to setValue(),
            we use this function to call setValue().
            """
            numClasses = len(self.LabelNames.value)
            self.opTrain.MaxLabel.setValue( numClasses )
            self.opPredictionPipeline.NumClasses.setValue( numClasses )
            self.NumClasses.setValue( numClasses )
        self.LabelNames.notifyDirty( _updateNumClasses )

        # Prediction pipeline outputs -> Top-level outputs
        self.PredictionProbabilities.connect( self.opPredictionPipeline.PredictionProbabilities )
        self.CachedPredictionProbabilities.connect( self.opPredictionPipeline.CachedPredictionProbabilities )
        self.HeadlessPredictionProbabilities.connect( self.opPredictionPipeline.HeadlessPredictionProbabilities )
        self.HeadlessUint8PredictionProbabilities.connect( self.opPredictionPipeline.HeadlessUint8PredictionProbabilities )
        self.PredictionProbabilityChannels.connect( self.opPredictionPipeline.PredictionProbabilityChannels )
        self.SegmentationChannels.connect( self.opPredictionPipeline.SegmentationChannels )
        self.UncertaintyEstimate.connect( self.opPredictionPipeline.UncertaintyEstimate )
        self.SimpleSegmentation.connect( self.opPredictionPipeline.SimpleSegmentation )
        self.HeadlessUncertaintyEstimate.connect( self.opPredictionPipeline.HeadlessUncertaintyEstimate )

        def inputResizeHandler( slot, oldsize, newsize ):
            # When the last input lane disappears, shrink the per-lane outputs too.
            if ( newsize == 0 ):
                self.LabelImages.resize(0)
                self.NonzeroLabelBlocks.resize(0)
                self.PredictionProbabilities.resize(0)
                self.CachedPredictionProbabilities.resize(0)
        self.InputImages.notifyResized( inputResizeHandler )

        # Debug assertions: Check to make sure the non-wrapped operators stayed that way.
        assert self.opTrain.Images.operator == self.opTrain

        def handleNewInputImage( multislot, index, *args ):
            def handleInputReady(slot):
                # Look the lane index up when the slot becomes ready instead
                # of reusing the index captured at insertion time: lanes
                # inserted or removed in between would make it stale.
                laneIndex = multislot.index(slot)
                self._checkConstraints( laneIndex )
                self.setupCaches( laneIndex )
            multislot[index].notifyReady(handleInputReady)

        self.InputImages.notifyInserted( handleNewInputImage )

        # If any feature image changes shape, we need to verify that the
        #  channels are consistent with the currently cached classifier
        # Otherwise, delete the currently cached classifier.
        def handleNewFeatureImage( multislot, index, *args ):
            def handleFeatureImageReady(slot):
                def handleFeatureMetaChanged(slot):
                    # Only relevant while the cache is frozen (fixAtCurrent);
                    # otherwise the classifier is retrained from its input anyway.
                    if ( self.classifier_cache.fixAtCurrent.value and
                         self.classifier_cache.Output.ready() and
                         slot.meta.shape is not None ):
                        classifier = self.classifier_cache.Output.value
                        channel_names = slot.meta.channel_names
                        if classifier and classifier.feature_names != channel_names:
                            self.classifier_cache.resetValue()
                slot.notifyMetaChanged(handleFeatureMetaChanged)
            multislot[index].notifyReady(handleFeatureImageReady)

        self.FeatureImages.notifyInserted( handleNewFeatureImage )

        def handleNewMaskImage( multislot, index, *args ):
            def handleMaskReady(slot):
                # Resolve the lane index at ready time (see handleNewInputImage).
                self._checkConstraints( multislot.index(slot) )
            multislot[index].notifyReady(handleMaskReady)
        self.PredictionMasks.notifyInserted( handleNewMaskImage )

        # All input multi-slots should be kept in sync
        # Output multi-slots will auto-sync via the graph
        def insertSlot( a, b, position, finalsize ):
            # a: the slot to update (bound via partial); b: the slot that changed.
            a.insertSlot(position, finalsize)

        def removeSlot( a, b, position, finalsize ):
            # a: the slot to update (bound via partial); b: the slot that changed.
            a.removeSlot(position, finalsize)

        # Build a real list: it is iterated twice in the nested loop below,
        # and on Python 3 filter() returns a one-shot iterator that the inner
        # loop would exhaust during the first outer iteration.
        multiInputs = [s for s in self.inputs.values() if s.level >= 1]
        for s1 in multiInputs:
            for s2 in multiInputs:
                if s1 != s2:
                    s1.notifyInserted( partial(insertSlot, s2) )
                    s1.notifyRemoved( partial(removeSlot, s2) )

    def setupCaches(self, imageIndex):
        """
        Size the label input of lane ``imageIndex`` to match its input image.

        LabelInputs is resized to the number of input lanes. The lane's label
        input gets the input image's shape with the channel axis (if found)
        set to 1, plus the input image's axistags.
        """
        numImages = len(self.InputImages)
        inputSlot = self.InputImages[imageIndex]
        self.LabelInputs.resize(numImages)

        # Special case: We have to set up the shape of our label *input* according to our image input shape
        shapeList = list(inputSlot.meta.shape)
        try:
            channelIndex = inputSlot.meta.axistags.index('c')
            shapeList[channelIndex] = 1
        except Exception:
            # Best effort: keep the full shape when the channel axis cannot be
            # located (e.g. missing axistags or no 'c' axis). Unlike a bare
            # except, this does not swallow KeyboardInterrupt/SystemExit.
            pass
        self.LabelInputs[imageIndex].meta.shape = tuple(shapeList)
        self.LabelInputs[imageIndex].meta.axistags = inputSlot.meta.axistags

    def _checkConstraints(self, laneIndex):
        """
        Validate lane ``laneIndex`` against the other lanes.

        Checks that the lane's input image has the same number of channels and
        the same dimensionality as another ready lane. If a prediction mask is
        ready, it also checks that the mask matches the image on all but the
        last axis.

        :raises DatasetConstraintError: if any check fails.
        """
        if not self.InputImages[laneIndex].ready():
            return

        thisLaneTaggedShape = self.InputImages[laneIndex].meta.getTaggedShape()

        # Find a different lane and use it for comparison
        validShape = thisLaneTaggedShape
        for i, slot in enumerate(self.InputImages):
            if slot.ready() and i != laneIndex:
                validShape = slot.meta.getTaggedShape()
                break

        if validShape['c'] != thisLaneTaggedShape['c']:
            raise DatasetConstraintError(
                 "Pixel Classification",
                 "All input images must have the same number of channels.  "\
                 "Your new image has {} channel(s), but your other images have {} channel(s)."\
                 .format( thisLaneTaggedShape['c'], validShape['c'] ) )

        if len(validShape) != len(thisLaneTaggedShape):
            raise DatasetConstraintError(
                 "Pixel Classification",
                 "All input images must have the same dimensionality.  "\
                 "Your new image has {} dimensions (including channel), but your other images have {} dimensions."\
                 .format( len(thisLaneTaggedShape), len(validShape) ) )

        # The last axis is excluded from the comparison; presumably it is the
        # channel axis, which may legitimately differ — TODO confirm.
        mask_slot = self.PredictionMasks[laneIndex]
        input_shape = tuple(thisLaneTaggedShape.values())
        if mask_slot.ready() and mask_slot.meta.shape[:-1] != input_shape[:-1]:
            raise DatasetConstraintError(
                 "Pixel Classification",
                 "If you supply a prediction mask, it must have the same shape as the input image.  "\
                 "Your input image has shape {}, but your mask has shape {}."\
                 .format( input_shape, mask_slot.meta.shape ) )

    def setInSlot(self, slot, subindex, roi, value):
        # Nothing to do here: All inputs that support __setitem__
        #   are directly connected to internal operators.
        pass

    def propagateDirty(self, slot, subindex, roi):
        # Nothing to do here: All outputs are directly connected to
        #  internal operators that handle their own dirty propagation.
        pass

    def addLane(self, laneIndex):
        """Append a new image lane; lanes may only be added at the end."""
        numLanes = len(self.InputImages)
        assert numLanes == laneIndex, "Image lanes must be appended."
        self.InputImages.resize(numLanes+1)

    def removeLane(self, laneIndex, finalLength):
        """Remove lane ``laneIndex``; the other inputs follow via the sync handlers."""
        self.InputImages.removeSlot(laneIndex, finalLength)

    def getLane(self, laneIndex):
        """Return a single-lane view of this operator."""
        return OperatorSubView(self, laneIndex)

    def importLabels(self, laneIndex, slot):
        """
        Ingest label data from ``slot`` into the label array of lane ``laneIndex``.

        If the highest label value reported by the ingest exceeds the number of
        currently defined labels, default names ("Label N") and default
        label/probability-map colors are appended for the new labels.
        """
        # Load the data into the cache
        new_max = self.getLane( laneIndex ).opLabelPipeline.opLabelArray.ingestData( slot )

        # Add to the list of label names if there's a new max label
        old_names = self.LabelNames.value
        old_max = len(old_names)
        if new_max > old_max:
            # A list comprehension instead of map(): list + map() raises
            # TypeError on Python 3.
            new_names = old_names + ["Label {}".format(x)
                                     for x in range(old_max+1, new_max+1)]
            self.LabelNames.setValue(new_names)

            # Make some default colors, too
            default_colors = [(255,0,0),
                              (0,255,0),
                              (0,0,255),
                              (255,255,0),
                              (255,0,255),
                              (0,255,255),
                              (128,128,128),
                              (255, 105, 180),
                              (255, 165, 0),
                              (240, 230, 140) ]
            # Cycle through the palette so that every new label gets a color.
            # A plain slice would yield fewer colors than names once there are
            # more labels than palette entries.
            new_colors = [default_colors[i % len(default_colors)]
                          for i in range(old_max, new_max)]
            label_colors = self.LabelColors.value
            pmap_colors = self.PmapColors.value

            self.LabelColors.setValue( label_colors + new_colors )
            self.PmapColors.setValue( pmap_colors + new_colors )

    def mergeLabels(self, from_label, into_label):
        """Merge label ``from_label`` into ``into_label`` in every lane's label array."""
        for laneIndex in range(len(self.InputImages)):
            self.getLane( laneIndex ).opLabelPipeline.opLabelArray.mergeLabels(from_label, into_label)
Example #7
0
    def __init__(self, *args, **kwargs):
        """
        Instantiate all internal operators and connect them together.

        Per-lane chain (each stage wrapped in OpMultiLaneWrapper):
          Superpixels -> OpCreateRag -> RAG cache
          RAG + VoxelData + FeatureNames -> OpComputeEdgeFeatures -> edge-feature cache
          edge features + classifier -> OpPredictEdgeProbabilities -> edge-probability cache
          RAG + edge probabilities -> OpEdgeProbabilitiesDict -> dict cache
          Superpixels + RAG + edge probabilities -> OpNaiveSegmentation -> blocked cache

        The edge classifier is produced by a single, non-lane-wrapped
        OpTrainEdgeClassifier and held in a value cache whose fixAtCurrent
        follows FreezeClassifier.
        """
        super(OpEdgeTraining, self).__init__(*args, **kwargs)

        self.opCreateRag = OpMultiLaneWrapper(OpCreateRag, parent=self)
        self.opCreateRag.Superpixels.connect(self.Superpixels)

        self.opRagCache = OpMultiLaneWrapper(
            OpValueCache, parent=self, broadcastingSlotNames=['fixAtCurrent'])
        self.opRagCache.Input.connect(self.opCreateRag.Rag)
        self.opRagCache.name = 'opRagCache'

        self.opComputeEdgeFeatures = OpMultiLaneWrapper(
            OpComputeEdgeFeatures,
            parent=self,
            broadcastingSlotNames=['FeatureNames'])
        self.opComputeEdgeFeatures.FeatureNames.connect(self.FeatureNames)
        self.opComputeEdgeFeatures.VoxelData.connect(self.VoxelData)
        self.opComputeEdgeFeatures.Rag.connect(self.opRagCache.Output)

        self.opEdgeFeaturesCache = OpMultiLaneWrapper(
            OpValueCache, parent=self, broadcastingSlotNames=['fixAtCurrent'])
        self.opEdgeFeaturesCache.Input.connect(
            self.opComputeEdgeFeatures.EdgeFeaturesDataFrame)
        self.opEdgeFeaturesCache.name = 'opEdgeFeaturesCache'

        self.opTrainEdgeClassifier = OpTrainEdgeClassifier(parent=self)
        self.opTrainEdgeClassifier.EdgeLabelsDict.connect(self.EdgeLabelsDict)
        self.opTrainEdgeClassifier.EdgeFeaturesDataFrame.connect(
            self.opEdgeFeaturesCache.Output)

        # classifier cache input is set after training.
        self.opClassifierCache = OpValueCache(parent=self)
        self.opClassifierCache.Input.connect(
            self.opTrainEdgeClassifier.EdgeClassifier)
        self.opClassifierCache.fixAtCurrent.connect(self.FreezeClassifier)
        self.opClassifierCache.name = 'opClassifierCache'

        self.opPredictEdgeProbabilities = OpMultiLaneWrapper(
            OpPredictEdgeProbabilities,
            parent=self,
            broadcastingSlotNames=['EdgeClassifier'])
        self.opPredictEdgeProbabilities.EdgeClassifier.connect(
            self.opClassifierCache.Output)
        self.opPredictEdgeProbabilities.EdgeFeaturesDataFrame.connect(
            self.opEdgeFeaturesCache.Output)

        self.opEdgeProbabilitiesCache = OpMultiLaneWrapper(
            OpValueCache, parent=self, broadcastingSlotNames=['fixAtCurrent'])
        self.opEdgeProbabilitiesCache.Input.connect(
            self.opPredictEdgeProbabilities.EdgeProbabilities)
        self.opEdgeProbabilitiesCache.name = 'opEdgeProbabilitiesCache'
        self.opEdgeProbabilitiesCache.fixAtCurrent.connect(
            self.FreezeClassifier)

        self.opEdgeProbabilitiesDict = OpMultiLaneWrapper(
            OpEdgeProbabilitiesDict, parent=self)
        self.opEdgeProbabilitiesDict.Rag.connect(self.opRagCache.Output)
        self.opEdgeProbabilitiesDict.EdgeProbabilities.connect(
            self.opEdgeProbabilitiesCache.Output)

        self.opEdgeProbabilitiesDictCache = OpMultiLaneWrapper(
            OpValueCache, parent=self, broadcastingSlotNames=['fixAtCurrent'])
        self.opEdgeProbabilitiesDictCache.Input.connect(
            self.opEdgeProbabilitiesDict.EdgeProbabilitiesDict)
        self.opEdgeProbabilitiesDictCache.name = 'opEdgeProbabilitiesDictCache'

        self.opNaiveSegmentation = OpMultiLaneWrapper(OpNaiveSegmentation,
                                                      parent=self)
        self.opNaiveSegmentation.Superpixels.connect(self.Superpixels)
        self.opNaiveSegmentation.Rag.connect(self.opRagCache.Output)
        self.opNaiveSegmentation.EdgeProbabilities.connect(
            self.opEdgeProbabilitiesCache.Output)

        self.opNaiveSegmentationCache = OpMultiLaneWrapper(
            OpBlockedArrayCache,
            parent=self,
            broadcastingSlotNames=[
                'CompressionEnabled', 'fixAtCurrent', 'BypassModeEnabled'
            ])
        self.opNaiveSegmentationCache.CompressionEnabled.setValue(True)
        self.opNaiveSegmentationCache.Input.connect(
            self.opNaiveSegmentation.Output)
        self.opNaiveSegmentationCache.name = 'opNaiveSegmentationCache'

        self.Rag.connect(self.opRagCache.Output)
        self.EdgeProbabilities.connect(self.opEdgeProbabilitiesCache.Output)
        self.EdgeProbabilitiesDict.connect(
            self.opEdgeProbabilitiesDictCache.Output)
        self.NaiveSegmentation.connect(self.opNaiveSegmentationCache.Output)

        # All input multi-slots should be kept in sync
        # Output multi-slots will auto-sync via the graph
        def insertSlot(a, b, position, finalsize):
            # a: the slot to update (bound via partial); b: the slot that changed.
            a.insertSlot(position, finalsize)

        def removeSlot(a, b, position, finalsize):
            # a: the slot to update (bound via partial); b: the slot that changed.
            a.removeSlot(position, finalsize)

        # Build a real list: it is iterated twice in the nested loop below,
        # and on Python 3 filter() returns a one-shot iterator that the inner
        # loop would exhaust during the first outer iteration.
        multiInputs = [s for s in self.inputs.values() if s.level >= 1]
        for s1 in multiInputs:
            for s2 in multiInputs:
                if s1 != s2:
                    s1.notifyInserted(partial(insertSlot, s2))
                    s1.notifyRemoved(partial(removeSlot, s2))

        # If superpixels change, we have to delete our edge labels.
        # Since we're dealing with multi-lane slot, setting up dirty handlers is a two-stage process.
        # (1) React to lane insertion by subscribing to dirty signals for the new lane.
        # (2) React to each lane's dirty signal by deleting the labels for that lane.

        def subscribe_to_dirty_sp(slot, position, finalsize):
            # A new lane was added.  Subscribe to its dirty, ready and unready signals.
            assert slot is self.Superpixels
            self.Superpixels[position].notifyDirty(
                self.handle_dirty_superpixels)
            self.Superpixels[position].notifyReady(
                self.handle_dirty_superpixels)
            self.Superpixels[position].notifyUnready(
                self.handle_dirty_superpixels)

        # When a new lane is added, set up the listener for dirtyness.
        self.Superpixels.notifyInserted(subscribe_to_dirty_sp)
Example #8
0
    def __init__(self, *args, **kwargs):
        """
        Instantiate all internal operators and connect them together.

        Builds the counting pipeline:
        - the label pipeline (dot labels and box labels);
        - a foreground-only label previewer;
        - the counter training operator and its classifier cache;
        - the per-lane prediction pipeline, whose outputs are forwarded to this
          operator's top-level slots.

        Finally installs the handlers that keep all input multi-slots
        (level >= 1) at the same number of lanes.
        """
        super(OpCounting, self).__init__(*args, **kwargs)

        # Default values for some input slots
        self.FreezePredictions.setValue(True)
        self.LabelNames.setValue(["Foreground", "Background"])
        self.LabelColors.setValue([(255, 0, 0), (0, 255, 0)])
        self.PmapColors.setValue([(255, 0, 0), (0, 255, 0)])

        # SPECIAL connection: The LabelInputs slot doesn't get its data
        #  from the InputImages slot, but its shape must match.
        self.LabelInputs.connect(self.InputImages)
        self.BoxLabelInputs.connect(self.InputImages)

        # Hook up Labeling Pipeline
        self.opLabelPipeline = OpMultiLaneWrapper(OpLabelPipeline, parent=self)
        self.opLabelPipeline.RawImage.connect(self.InputImages)
        self.opLabelPipeline.LabelInput.connect(self.LabelInputs)
        self.opLabelPipeline.BoxLabelInput.connect(self.BoxLabelInputs)
        self.LabelImages.connect(self.opLabelPipeline.Output)
        self.NonzeroLabelBlocks.connect(self.opLabelPipeline.nonzeroBlocks)

        self.BoxLabelImages.connect(self.opLabelPipeline.BoxOutput)

        # Derive a foreground-only label image.
        # Label value 2 is zeroed out (presumably "Background", the second
        # entry of LabelNames). All other values are kept, and the result is
        # cast to float64.
        self.GetFore = OpMultiLaneWrapper(OpPixelOperator, parent=self)

        def conv(arr):
            # Build a new array rather than overwriting `arr` in place, so the
            # label data handed to us by the upstream operator is not mutated.
            # `numpy.float64` replaces the `numpy.float` alias, which meant the
            # same dtype and was removed in NumPy 1.24.
            return numpy.where(arr == 2, 0, arr).astype(numpy.float64)

        self.GetFore.Function.setValue(conv)
        self.GetFore.Input.connect(self.opLabelPipeline.Output)

        self.LabelPreviewer = OpMultiLaneWrapper(OpLabelPreviewer, parent=self)
        self.LabelPreviewer.Input.connect(self.GetFore.Output)

        self.LabelPreview.connect(self.LabelPreviewer.Output)

        # Hook up the Training operator
        self.opUpperBound = OpUpperBound(parent=self, graph=self.graph)
        self.UpperBound.connect(self.opUpperBound.UpperBound)

        self.boxViewer = OpBoxViewer(parent=self, graph=self.graph)

        self.opTrain = OpTrainCounter(parent=self, graph=self.graph)
        self.opTrain.inputs['ForegroundLabels'].connect(self.GetFore.Output)
        self.opTrain.inputs['BackgroundLabels'].connect(
            self.opLabelPipeline.Output)
        self.opTrain.inputs['Images'].connect(self.CachedFeatureImages)
        self.opTrain.inputs["nonzeroLabelBlocks"].connect(
            self.opLabelPipeline.nonzeroBlocks)
        self.opTrain.inputs['fixClassifier'].setValue(True)
        self.opTrain.inputs["UpperBound"].connect(self.opUpperBound.UpperBound)

        # Hook up the Classifier Cache
        # The classifier is cached here to allow serializers to force in
        #   a pre-calculated classifier (loaded from disk)
        self.classifier_cache = OpValueCache(parent=self, graph=self.graph)
        self.classifier_cache.inputs["Input"].connect(
            self.opTrain.outputs['Classifier'])
        self.Classifier.connect(self.classifier_cache.Output)

        # Hook up the prediction pipeline inputs
        self.opPredictionPipeline = OpMultiLaneWrapper(OpPredictionPipeline,
                                                       parent=self)
        self.opPredictionPipeline.FeatureImages.connect(self.FeatureImages)
        self.opPredictionPipeline.CachedFeatureImages.connect(
            self.CachedFeatureImages)
        self.opPredictionPipeline.MaxLabel.setValue(2)
        self.opPredictionPipeline.Classifier.connect(
            self.classifier_cache.Output)
        self.opPredictionPipeline.FreezePredictions.connect(
            self.FreezePredictions)
        self.opPredictionPipeline.PredictionsFromDisk.connect(
            self.PredictionsFromDisk)

        # Prediction pipeline outputs -> Top-level outputs
        self.PredictionProbabilities.connect(
            self.opPredictionPipeline.PredictionProbabilities)
        self.CachedPredictionProbabilities.connect(
            self.opPredictionPipeline.CachedPredictionProbabilities)
        self.HeadlessPredictionProbabilities.connect(
            self.opPredictionPipeline.HeadlessPredictionProbabilities)
        #self.HeadlessUint8PredictionProbabilities.connect( self.opPredictionPipeline.HeadlessUint8PredictionProbabilities )
        #self.PredictionProbabilityChannels.connect( self.opPredictionPipeline.PredictionProbabilityChannels )
        #self.SegmentationChannels.connect( self.opPredictionPipeline.SegmentationChannels )
        self.UncertaintyEstimate.connect(
            self.opPredictionPipeline.UncertaintyEstimate)
        # The density output is the cached prediction output of the pipeline.
        self.Density.connect(
            self.opPredictionPipeline.CachedPredictionProbabilities)
        self.OutputSum.connect(self.opPredictionPipeline.OutputSum)

        def inputResizeHandler(slot, oldsize, newsize):
            # When all input images are removed, empty the lane-wise outputs too.
            if (newsize == 0):
                self.LabelImages.resize(0)
                self.NonzeroLabelBlocks.resize(0)
                self.PredictionProbabilities.resize(0)
                self.CachedPredictionProbabilities.resize(0)

        self.InputImages.notifyResized(inputResizeHandler)

        # Debug assertions: Check to make sure the non-wrapped operators stayed that way.
        assert self.opTrain.Images.operator == self.opTrain

        def handleNewInputImage(multislot, index, *args):
            def handleInputReady(slot):
                # NOTE(review): `index` is the lane position at insertion time.
                # It may be stale if an earlier lane is removed before this
                # fires. setupCaches re-resolves the lane via multislot.index(slot).
                self._checkConstraints(index)
                self.setupCaches(multislot.index(slot))

            multislot[index].notifyReady(handleInputReady)

        self.InputImages.notifyInserted(handleNewInputImage)

        # All input multi-slots should be kept in sync
        # Output multi-slots will auto-sync via the graph
        # Materialize as a list: under Python 3 `filter` returns a one-shot
        # iterator. The inner loop would exhaust it during the first outer
        # iteration, leaving most slot pairs unsynchronized.
        multiInputs = [s for s in self.inputs.values() if s.level >= 1]

        def insertSlot(a, b, position, finalsize):
            # `a` is the slot to mirror into. `b` is the multi-slot that
            # emitted the insertion signal (unused).
            a.insertSlot(position, finalsize)

        def removeSlot(a, b, position, finalsize):
            # Same argument convention as insertSlot.
            a.removeSlot(position, finalsize)

        for s1 in multiInputs:
            for s2 in multiInputs:
                if s1 != s2:
                    s1.notifyInserted(partial(insertSlot, s2))
                    s1.notifyRemoved(partial(removeSlot, s2))

        self.options = self.opTrain.options
    def __init__(self, *args, **kwargs):
        """
        Instantiate all internal operators and connect them together.

        Builds the pipeline:
        - the label pipeline;
        - a max-label finder across all label images;
        - the random-forest training operator and its classifier cache;
        - the per-lane prediction pipeline, whose outputs are forwarded to this
          operator's top-level slots.

        Also installs the handlers that keep all input multi-slots
        (level >= 1) at the same number of lanes.
        """
        super(OpPixelClassification, self).__init__(*args, **kwargs)

        # Default values for some input slots
        self.FreezePredictions.setValue(True)
        self.LabelNames.setValue([])
        self.LabelColors.setValue([])
        self.PmapColors.setValue([])

        # SPECIAL connection: The LabelInputs slot doesn't get its data
        #  from the InputImages slot, but its shape must match.
        self.LabelInputs.connect(self.InputImages)

        # Hook up Labeling Pipeline
        self.opLabelPipeline = OpMultiLaneWrapper(OpLabelPipeline, parent=self)
        self.opLabelPipeline.RawImage.connect(self.InputImages)
        self.opLabelPipeline.LabelInput.connect(self.LabelInputs)
        self.LabelImages.connect(self.opLabelPipeline.Output)
        self.NonzeroLabelBlocks.connect(self.opLabelPipeline.nonzeroBlocks)

        # Find the highest label in all the label images
        self.opMaxLabel = OpMaxValue(parent=self)
        self.opMaxLabel.Inputs.connect(self.opLabelPipeline.MaxLabel)
        self.MaxLabelValue.connect(self.opMaxLabel.Output)

        # Hook up the Training operator
        self.opTrain = OpTrainRandomForestBlocked(parent=self)
        self.opTrain.inputs['Labels'].connect(self.opLabelPipeline.Output)
        self.opTrain.inputs['Images'].connect(self.CachedFeatureImages)
        self.opTrain.inputs['MaxLabel'].connect(self.opMaxLabel.Output)
        self.opTrain.inputs["nonzeroLabelBlocks"].connect(
            self.opLabelPipeline.nonzeroBlocks)
        self.opTrain.inputs['fixClassifier'].setValue(False)

        # Hook up the Classifier Cache
        # The classifier is cached here to allow serializers to force in
        #   a pre-calculated classifier (loaded from disk)
        self.classifier_cache = OpValueCache(parent=self)
        self.classifier_cache.inputs["Input"].connect(
            self.opTrain.outputs['Classifier'])
        self.Classifier.connect(self.classifier_cache.Output)

        # Hook up the prediction pipeline inputs
        self.opPredictionPipeline = OpMultiLaneWrapper(OpPredictionPipeline,
                                                       parent=self)
        self.opPredictionPipeline.FeatureImages.connect(self.FeatureImages)
        self.opPredictionPipeline.CachedFeatureImages.connect(
            self.CachedFeatureImages)
        self.opPredictionPipeline.MaxLabel.connect(self.opMaxLabel.Output)
        self.opPredictionPipeline.Classifier.connect(
            self.classifier_cache.Output)
        self.opPredictionPipeline.FreezePredictions.connect(
            self.FreezePredictions)
        self.opPredictionPipeline.PredictionsFromDisk.connect(
            self.PredictionsFromDisk)

        # Prediction pipeline outputs -> Top-level outputs
        self.PredictionProbabilities.connect(
            self.opPredictionPipeline.PredictionProbabilities)
        self.CachedPredictionProbabilities.connect(
            self.opPredictionPipeline.CachedPredictionProbabilities)
        self.HeadlessPredictionProbabilities.connect(
            self.opPredictionPipeline.HeadlessPredictionProbabilities)
        self.HeadlessUint8PredictionProbabilities.connect(
            self.opPredictionPipeline.HeadlessUint8PredictionProbabilities)
        self.PredictionProbabilityChannels.connect(
            self.opPredictionPipeline.PredictionProbabilityChannels)
        self.SegmentationChannels.connect(
            self.opPredictionPipeline.SegmentationChannels)
        self.UncertaintyEstimate.connect(
            self.opPredictionPipeline.UncertaintyEstimate)

        def inputResizeHandler(slot, oldsize, newsize):
            # When all input images are removed, empty the lane-wise outputs too.
            if (newsize == 0):
                self.LabelImages.resize(0)
                self.NonzeroLabelBlocks.resize(0)
                self.PredictionProbabilities.resize(0)
                self.CachedPredictionProbabilities.resize(0)

        self.InputImages.notifyResized(inputResizeHandler)

        # Debug assertions: Check to make sure the non-wrapped operators stayed that way.
        assert self.opMaxLabel.Inputs.operator == self.opMaxLabel
        assert self.opTrain.Images.operator == self.opTrain

        def handleNewInputImage(multislot, index, *args):
            def handleInputReady(slot):
                # NOTE(review): `index` is the lane position at insertion time.
                # It may be stale if an earlier lane is removed before this
                # fires. setupCaches re-resolves the lane via multislot.index(slot).
                self._checkConstraints(index)
                self.setupCaches(multislot.index(slot))

            multislot[index].notifyReady(handleInputReady)

        self.InputImages.notifyInserted(handleNewInputImage)

        # All input multi-slots should be kept in sync
        # Output multi-slots will auto-sync via the graph
        # Materialize as a list: under Python 3 `filter` returns a one-shot
        # iterator. The inner loop would exhaust it during the first outer
        # iteration, leaving most slot pairs unsynchronized.
        multiInputs = [s for s in self.inputs.values() if s.level >= 1]

        def insertSlot(a, b, position, finalsize):
            # `a` is the slot to mirror into. `b` is the multi-slot that
            # emitted the insertion signal (unused).
            a.insertSlot(position, finalsize)

        def removeSlot(a, b, position, finalsize):
            # Same argument convention as insertSlot.
            a.removeSlot(position, finalsize)

        for s1 in multiInputs:
            for s2 in multiInputs:
                if s1 != s2:
                    s1.notifyInserted(partial(insertSlot, s2))
                    s1.notifyRemoved(partial(removeSlot, s2))
Example #10
0
class OpEdgeTraining(Operator):
    """
    Top-level operator for edge classification on a superpixel region adjacency graph.

    For each lane it:
    - builds a RAG from the superpixels;
    - computes edge features;
    - predicts edge probabilities with a single classifier trained from the
      edge labels of all lanes;
    - derives a naive segmentation from the RAG and those probabilities.

    Intermediate results (RAG, features, probabilities, segmentation) are
    cached per lane.
    """

    # Shared across lanes
    DEFAULT_FEATURES = {"Grayscale": ["standard_edge_mean"]}
    FeatureNames = InputSlot(value=DEFAULT_FEATURES)
    FreezeClassifier = InputSlot(value=True)
    TrainRandomForest = InputSlot(value=False)

    # Lane-wise
    WatershedSelectedInput = InputSlot(level=1)
    EdgeLabelsDict = InputSlot(level=1, value={})
    VoxelData = InputSlot(level=1)  # stacked input with edge probabilities
    Superpixels = InputSlot(level=1)
    GroundtruthSegmentation = InputSlot(level=1, optional=True)
    RawData = InputSlot(level=1, optional=True)  # Used by the GUI for display only

    Rag = OutputSlot(level=1)
    EdgeProbabilities = OutputSlot(level=1)
    EdgeProbabilitiesDict = OutputSlot(level=1)  # A dict of id_pair -> probabilities
    NaiveSegmentation = OutputSlot(level=1)

    def __init__(self, *args, **kwargs):
        """Instantiate the internal operators, wire them together, and install lane-sync handlers."""
        super(OpEdgeTraining, self).__init__(*args, **kwargs)

        self.opCreateRag = OpMultiLaneWrapper(OpCreateRag, parent=self)
        self.opCreateRag.Superpixels.connect(self.Superpixels)

        self.opRagCache = OpMultiLaneWrapper(OpValueCache, parent=self, broadcastingSlotNames=["fixAtCurrent"])
        self.opRagCache.Input.connect(self.opCreateRag.Rag)
        self.opRagCache.name = "opRagCache"

        self.opComputeEdgeFeatures = OpMultiLaneWrapper(
            OpComputeEdgeFeatures, parent=self, broadcastingSlotNames=["FeatureNames", "TrainRandomForest"]
        )
        self.opComputeEdgeFeatures.FeatureNames.connect(self.FeatureNames)
        self.opComputeEdgeFeatures.VoxelData.connect(self.VoxelData)
        self.opComputeEdgeFeatures.Rag.connect(self.opRagCache.Output)
        self.opComputeEdgeFeatures.TrainRandomForest.connect(self.TrainRandomForest)
        self.opComputeEdgeFeatures.WatershedSelectedInput.connect(self.WatershedSelectedInput)

        self.opEdgeFeaturesCache = OpMultiLaneWrapper(OpValueCache, parent=self, broadcastingSlotNames=["fixAtCurrent"])
        self.opEdgeFeaturesCache.Input.connect(self.opComputeEdgeFeatures.EdgeFeaturesDataFrame)
        self.opEdgeFeaturesCache.name = "opEdgeFeaturesCache"

        # A single (non-wrapped) trainer consumes the labels and features of every lane.
        self.opTrainEdgeClassifier = OpTrainEdgeClassifier(parent=self)
        self.opTrainEdgeClassifier.EdgeLabelsDict.connect(self.EdgeLabelsDict)
        self.opTrainEdgeClassifier.EdgeFeaturesDataFrame.connect(self.opEdgeFeaturesCache.Output)

        # classifier cache input is set after training.
        self.opClassifierCache = OpValueCache(parent=self)
        self.opClassifierCache.Input.connect(self.opTrainEdgeClassifier.EdgeClassifier)
        self.opClassifierCache.fixAtCurrent.connect(self.FreezeClassifier)
        self.opClassifierCache.name = "opClassifierCache"

        self.opPredictEdgeProbabilities = OpMultiLaneWrapper(
            OpPredictEdgeProbabilities, parent=self, broadcastingSlotNames=["EdgeClassifier", "TrainRandomForest"]
        )
        self.opPredictEdgeProbabilities.EdgeClassifier.connect(self.opClassifierCache.Output)
        self.opPredictEdgeProbabilities.EdgeFeaturesDataFrame.connect(self.opEdgeFeaturesCache.Output)
        self.opPredictEdgeProbabilities.TrainRandomForest.connect(self.TrainRandomForest)

        self.opEdgeProbabilitiesCache = OpMultiLaneWrapper(
            OpValueCache, parent=self, broadcastingSlotNames=["fixAtCurrent"]
        )
        self.opEdgeProbabilitiesCache.Input.connect(self.opPredictEdgeProbabilities.EdgeProbabilities)
        self.opEdgeProbabilitiesCache.name = "opEdgeProbabilitiesCache"
        self.opEdgeProbabilitiesCache.fixAtCurrent.connect(self.FreezeClassifier)

        self.opEdgeProbabilitiesDict = OpMultiLaneWrapper(OpEdgeProbabilitiesDict, parent=self)
        self.opEdgeProbabilitiesDict.Rag.connect(self.opRagCache.Output)
        self.opEdgeProbabilitiesDict.EdgeProbabilities.connect(self.opEdgeProbabilitiesCache.Output)

        self.opEdgeProbabilitiesDictCache = OpMultiLaneWrapper(
            OpValueCache, parent=self, broadcastingSlotNames=["fixAtCurrent"]
        )
        self.opEdgeProbabilitiesDictCache.Input.connect(self.opEdgeProbabilitiesDict.EdgeProbabilitiesDict)
        self.opEdgeProbabilitiesDictCache.name = "opEdgeProbabilitiesDictCache"

        self.opNaiveSegmentation = OpMultiLaneWrapper(OpNaiveSegmentation, parent=self)
        self.opNaiveSegmentation.Superpixels.connect(self.Superpixels)
        self.opNaiveSegmentation.Rag.connect(self.opRagCache.Output)
        self.opNaiveSegmentation.EdgeProbabilities.connect(self.opEdgeProbabilitiesCache.Output)

        self.opNaiveSegmentationCache = OpMultiLaneWrapper(
            OpBlockedArrayCache,
            parent=self,
            broadcastingSlotNames=["CompressionEnabled", "fixAtCurrent", "BypassModeEnabled"],
        )
        self.opNaiveSegmentationCache.CompressionEnabled.setValue(True)
        self.opNaiveSegmentationCache.Input.connect(self.opNaiveSegmentation.Output)
        self.opNaiveSegmentationCache.name = "opNaiveSegmentationCache"

        self.Rag.connect(self.opRagCache.Output)
        self.EdgeProbabilities.connect(self.opEdgeProbabilitiesCache.Output)
        self.EdgeProbabilitiesDict.connect(self.opEdgeProbabilitiesDictCache.Output)
        self.NaiveSegmentation.connect(self.opNaiveSegmentationCache.Output)

        # All input multi-slots should be kept in sync
        # Output multi-slots will auto-sync via the graph
        multiInputs = [s for s in self.inputs.values() if s.level >= 1]

        def insertSlot(a, b, position, finalsize):
            # `a` is the slot to mirror into. `b` is the multi-slot that
            # emitted the insertion signal (unused).
            a.insertSlot(position, finalsize)

        def removeSlot(a, b, position, finalsize):
            # Same argument convention as insertSlot.
            a.removeSlot(position, finalsize)

        for s1 in multiInputs:
            for s2 in multiInputs:
                if s1 != s2:
                    s1.notifyInserted(partial(insertSlot, s2))
                    s1.notifyRemoved(partial(removeSlot, s2))

        # If superpixels change, we have to delete our edge labels.
        # Since we're dealing with multi-lane slot, setting up dirty handlers is a two-stage process.
        # (1) React to lane insertion by subscribing to dirty signals for the new lane.
        # (2) React to each lane's dirty signal by deleting the labels for that lane.

        def subscribe_to_dirty_sp(slot, position, finalsize):
            # A new lane was added.  Subscribe to its dirty, ready and unready signals.
            assert slot is self.Superpixels
            self.Superpixels[position].notifyDirty(self.handle_dirty_superpixels)
            self.Superpixels[position].notifyReady(self.handle_dirty_superpixels)
            self.Superpixels[position].notifyUnready(self.handle_dirty_superpixels)

        # When a new lane is added, set up the listener for dirtyness.
        self.Superpixels.notifyInserted(subscribe_to_dirty_sp)

    def handle_dirty_superpixels(self, subslot, *args):
        """
        Discards the labels for a given lane.

        :param subslot: the per-lane Superpixels slot that changed. Its lane
            index is looked up in ``self.Superpixels``.

        NOTE: In addition to callers in this file, this function is also called from multicutWorkflow.py
        """
        # Determine which lane triggered this and delete its labels
        lane_index = self.Superpixels.index(subslot)
        old_labels = self.EdgeLabelsDict[lane_index].value
        if old_labels:
            # Lazy %-style args: the (possibly large) label dict is only
            # formatted if the INFO level is actually enabled.
            logger.warning("Superpixels changed.  Deleting all labels in lane %s.", lane_index)
            logger.info("Old labels were: %s", old_labels)
            self.EdgeLabelsDict[lane_index].setValue({})

    def setupOutputs(self):
        """Validate each lane's superpixels and size its segmentation-cache block to the full volume."""
        for sp_slot, seg_cache_blockshape_slot in zip(self.Superpixels, self.opNaiveSegmentationCache.BlockShape):
            assert sp_slot.meta.dtype == np.uint32
            assert sp_slot.meta.getAxisKeys()[-1] == "c"
            # One block spans the whole superpixel volume.
            seg_cache_blockshape_slot.setValue(sp_slot.meta.shape)

    def execute(self, slot, subindex, roi, result):
        """All outputs are connected to internal operators, so this should never be called."""
        # Explicit raise instead of `assert False`, which is stripped under `python -O`
        # and would make this silently return None.
        raise AssertionError("Shouldn't get here, but requesting slot: {}".format(slot))

    def propagateDirty(self, slot, subindex, roi):
        """Nothing to do: outputs are connected to internal operators, which propagate dirtiness themselves."""
        pass

    def setEdgeLabelsFromGroundtruth(self, lane_index):
        """
        For the given lane, read the ground truth volume and
        automatically determine edge label values.

        :param lane_index: index of the lane to label.
        :raises RuntimeError: if the lane's GroundtruthSegmentation slot is not ready.
        """
        op_view = self.getLane(lane_index)

        if not op_view.GroundtruthSegmentation.ready():
            raise RuntimeError("There is no Ground Truth data available for lane: {}".format(lane_index))

        logger.info("Loading groundtruth for lane {}...".format(lane_index))
        gt_vol = op_view.GroundtruthSegmentation[:].wait()
        # Reorder the groundtruth axes to match the superpixels, then drop the channel axis.
        gt_vol = vigra.taggedView(gt_vol, op_view.GroundtruthSegmentation.meta.axistags)
        gt_vol = gt_vol.withAxes("".join(tag.key for tag in op_view.Superpixels.meta.axistags))
        gt_vol = gt_vol.dropChannelAxis()

        rag = op_view.opRagCache.Output.value

        logger.info("Computing edge decisions from groundtruth...")
        decisions = rag.edge_decisions_from_groundtruth(gt_vol, asdict=False)
        # Assumes `decisions` is a boolean array (TODO confirm), so the uint8
        # view is 0/1 and the stored labels become 1/2.
        edge_labels = decisions.view(np.uint8) + 1
        edge_ids = list(map(tuple, rag.edge_ids))
        edge_labels_dict = dict(zip(edge_ids, edge_labels))
        op_view.EdgeLabelsDict.setValue(edge_labels_dict)

    def addLane(self, laneIndex):
        """Append a lane. Other input multi-slots follow via the sync handlers installed in __init__."""
        numLanes = len(self.VoxelData)
        assert numLanes == laneIndex, "Image lanes must be appended."
        self.VoxelData.resize(numLanes + 1)

    def removeLane(self, laneIndex, finalLength):
        """Remove a lane. Other input multi-slots follow via the sync handlers installed in __init__."""
        self.VoxelData.removeSlot(laneIndex, finalLength)

    def getLane(self, laneIndex):
        """Return a single-lane view of this operator."""
        return OperatorSubView(self, laneIndex)

    def clear_caches(self, lane_index):
        """Reset the shared classifier cache and the per-lane value caches of ``lane_index``."""
        self.opClassifierCache.resetValue()
        for cache in [
            self.opRagCache,
            self.opEdgeProbabilitiesCache,
            self.opEdgeProbabilitiesDictCache,
            self.opEdgeFeaturesCache,
        ]:
            c = cache.getLane(lane_index)
            c.resetValue()
Example #11
0
class OpPixelClassification(Operator):
    """
    Top-level operator for pixel classification
    """
    name = "OpPixelClassification"
    category = "Top-level"

    # Graph inputs

    InputImages = InputSlot(
        level=1)  # Original input data.  Used for display only.
    PredictionMasks = InputSlot(
        level=1, optional=True
    )  # Routed to OpClassifierPredict.PredictionMask.  See there for details.

    LabelInputs = InputSlot(
        optional=True,
        level=1)  # Input for providing label data from an external source

    FeatureImages = InputSlot(
        level=1
    )  # Computed feature images (each channel is a different feature)
    CachedFeatureImages = InputSlot(level=1)  # Cached feature data.

    FreezePredictions = InputSlot(stype='bool')
    ClassifierFactory = InputSlot(
        value=ParallelVigraRfLazyflowClassifierFactory(100))

    PredictionsFromDisk = InputSlot(optional=True, level=1)

    PredictionProbabilities = OutputSlot(
        level=1
    )  # Classification predictions (via feature cache for interactive speed)
    PredictionProbabilitiesUint8 = OutputSlot(
        level=1)  # Same thing, but converted to uint8 first

    PredictionProbabilityChannels = OutputSlot(
        level=2)  # Classification predictions, enumerated by channel
    SegmentationChannels = OutputSlot(
        level=2)  # Binary image of the final selections.

    LabelImages = OutputSlot(level=1)  # Labels from the user
    NonzeroLabelBlocks = OutputSlot(
        level=1)  # A list if slices that contain non-zero label values
    Classifier = OutputSlot(
    )  # We provide the classifier as an external output for other applets to use

    CachedPredictionProbabilities = OutputSlot(
        level=1
    )  # Classification predictions (via feature cache AND prediction cache)

    HeadlessPredictionProbabilities = OutputSlot(
        level=1
    )  # Classification predictions ( via no image caches (except for the classifier itself )
    HeadlessUint8PredictionProbabilities = OutputSlot(
        level=1)  # Same as above, but 0-255 uint8 instead of 0.0-1.0 float32
    HeadlessUncertaintyEstimate = OutputSlot(
        level=1
    )  # Same as uncertaintly estimate, but does not rely on cached data.

    UncertaintyEstimate = OutputSlot(level=1)

    SimpleSegmentation = OutputSlot(level=1)  # For debug, for now

    # GUI-only (not part of the pipeline, but saved to the project)
    LabelNames = OutputSlot()
    LabelColors = OutputSlot()
    PmapColors = OutputSlot()

    NumClasses = OutputSlot()

    def setupOutputs(self):
        """
        Give the GUI-only metadata slots their meta information.

        LabelNames, LabelColors and PmapColors each carry one Python object
        (e.g. a list, as set in __init__). They are therefore all described
        as an object dtype with a one-element shape.
        """
        for gui_slot in (self.LabelNames, self.LabelColors, self.PmapColors):
            gui_slot.meta.dtype = object
            gui_slot.meta.shape = (1, )

    def __init__(self, *args, **kwargs):
        """
        Instantiate all internal operators and connect them together.

        Builds the pipeline:
        - the label pipeline;
        - the blocked classifier trainer and its (freezable) classifier cache;
        - the per-lane prediction pipeline;
        - per-lane feature-matrix caches.

        It then forwards the prediction outputs to this operator's top-level
        slots and installs handlers that:
        - keep NumClasses in step with LabelNames;
        - react to new input, feature and mask lanes;
        - keep all input multi-slots (level >= 1) at the same number of lanes.
        """
        super(OpPixelClassification, self).__init__(*args, **kwargs)

        # Default values for some input slots
        self.FreezePredictions.setValue(True)
        self.LabelNames.setValue([])
        self.LabelColors.setValue([])
        self.PmapColors.setValue([])

        # SPECIAL connection: The LabelInputs slot doesn't get its data
        #  from the InputImages slot, but its shape must match.
        self.LabelInputs.connect(self.InputImages)

        # Hook up Labeling Pipeline
        self.opLabelPipeline = OpMultiLaneWrapper(
            OpLabelPipeline,
            parent=self,
            broadcastingSlotNames=['DeleteLabel'])
        self.opLabelPipeline.RawImage.connect(self.InputImages)
        self.opLabelPipeline.LabelInput.connect(self.LabelInputs)
        self.opLabelPipeline.DeleteLabel.setValue(-1)
        self.LabelImages.connect(self.opLabelPipeline.Output)
        self.NonzeroLabelBlocks.connect(self.opLabelPipeline.nonzeroBlocks)

        # Hook up the Training operator
        self.opTrain = OpTrainClassifierBlocked(parent=self)
        self.opTrain.ClassifierFactory.connect(self.ClassifierFactory)
        self.opTrain.Labels.connect(self.opLabelPipeline.Output)
        self.opTrain.Images.connect(self.FeatureImages)
        self.opTrain.nonzeroLabelBlocks.connect(
            self.opLabelPipeline.nonzeroBlocks)

        # Hook up the Classifier Cache
        # The classifier is cached here to allow serializers to force in
        #   a pre-calculated classifier (loaded from disk)
        self.classifier_cache = OpValueCache(parent=self)
        self.classifier_cache.name = "OpPixelClassification.classifier_cache"
        self.classifier_cache.inputs["Input"].connect(
            self.opTrain.outputs['Classifier'])
        self.classifier_cache.inputs["fixAtCurrent"].connect(
            self.FreezePredictions)
        self.Classifier.connect(self.classifier_cache.Output)

        # Hook up the prediction pipeline inputs
        self.opPredictionPipeline = OpMultiLaneWrapper(OpPredictionPipeline,
                                                       parent=self)
        self.opPredictionPipeline.FeatureImages.connect(self.FeatureImages)
        self.opPredictionPipeline.CachedFeatureImages.connect(
            self.CachedFeatureImages)
        self.opPredictionPipeline.Classifier.connect(
            self.classifier_cache.Output)
        self.opPredictionPipeline.FreezePredictions.connect(
            self.FreezePredictions)
        self.opPredictionPipeline.PredictionsFromDisk.connect(
            self.PredictionsFromDisk)
        self.opPredictionPipeline.PredictionMask.connect(self.PredictionMasks)

        # Feature Selection Stuff
        self.opFeatureMatrixCaches = OpMultiLaneWrapper(OpFeatureMatrixCache,
                                                        parent=self)
        self.opFeatureMatrixCaches.LabelImage.connect(
            self.opLabelPipeline.Output)
        self.opFeatureMatrixCaches.FeatureImage.connect(self.FeatureImages)
        # NOTE(review): the original author questioned whether this explicit
        # setDirty() is still needed.
        self.opFeatureMatrixCaches.LabelImage.setDirty()

        def _updateNumClasses(*args):
            """
            When the number of labels changes, we MUST make sure that the prediction image changes its shape (the number of channels).
            Since setupOutputs is not called for mere dirty notifications, but is called in response to setValue(),
            we use this function to call setValue().
            """
            numClasses = len(self.LabelNames.value)
            self.opTrain.MaxLabel.setValue(numClasses)
            self.opPredictionPipeline.NumClasses.setValue(numClasses)
            self.NumClasses.setValue(numClasses)

        self.LabelNames.notifyDirty(_updateNumClasses)

        # Prediction pipeline outputs -> Top-level outputs
        self.PredictionProbabilities.connect(
            self.opPredictionPipeline.PredictionProbabilities)
        self.PredictionProbabilitiesUint8.connect(
            self.opPredictionPipeline.PredictionProbabilitiesUint8)
        self.CachedPredictionProbabilities.connect(
            self.opPredictionPipeline.CachedPredictionProbabilities)
        self.HeadlessPredictionProbabilities.connect(
            self.opPredictionPipeline.HeadlessPredictionProbabilities)
        self.HeadlessUint8PredictionProbabilities.connect(
            self.opPredictionPipeline.HeadlessUint8PredictionProbabilities)
        self.PredictionProbabilityChannels.connect(
            self.opPredictionPipeline.PredictionProbabilityChannels)
        self.SegmentationChannels.connect(
            self.opPredictionPipeline.SegmentationChannels)
        self.UncertaintyEstimate.connect(
            self.opPredictionPipeline.UncertaintyEstimate)
        self.SimpleSegmentation.connect(
            self.opPredictionPipeline.SimpleSegmentation)
        self.HeadlessUncertaintyEstimate.connect(
            self.opPredictionPipeline.HeadlessUncertaintyEstimate)

        def inputResizeHandler(slot, oldsize, newsize):
            # When all input images are removed, empty the lane-wise outputs too.
            if (newsize == 0):
                self.LabelImages.resize(0)
                self.NonzeroLabelBlocks.resize(0)
                self.PredictionProbabilities.resize(0)
                self.CachedPredictionProbabilities.resize(0)

        self.InputImages.notifyResized(inputResizeHandler)

        # Debug assertions: Check to make sure the non-wrapped operators stayed that way.
        assert self.opTrain.Images.operator == self.opTrain

        def handleNewInputImage(multislot, index, *args):
            def handleInputReady(slot):
                # NOTE(review): `index` is the lane position at insertion time.
                # It may be stale if an earlier lane is removed before this
                # fires. setupCaches re-resolves the lane via multislot.index(slot).
                self._checkConstraints(index)
                self.setupCaches(multislot.index(slot))

            multislot[index].notifyReady(handleInputReady)

        self.InputImages.notifyInserted(handleNewInputImage)

        # If any feature image changes shape, we need to verify that the
        #  channels are consistent with the currently cached classifier
        # Otherwise, delete the currently cached classifier.
        def handleNewFeatureImage(multislot, index, *args):
            def handleFeatureImageReady(slot):
                def handleFeatureMetaChanged(slot):
                    if (self.classifier_cache.fixAtCurrent.value
                            and self.classifier_cache.Output.ready()
                            and slot.meta.shape is not None):
                        classifier = self.classifier_cache.Output.value
                        channel_names = slot.meta.channel_names
                        if classifier and classifier.feature_names != channel_names:
                            self.classifier_cache.resetValue()

                slot.notifyMetaChanged(handleFeatureMetaChanged)

            multislot[index].notifyReady(handleFeatureImageReady)

        self.FeatureImages.notifyInserted(handleNewFeatureImage)

        def handleNewMaskImage(multislot, index, *args):
            def handleInputReady(slot):
                self._checkConstraints(index)

            multislot[index].notifyReady(handleInputReady)

        self.PredictionMasks.notifyInserted(handleNewMaskImage)

        # All input multi-slots should be kept in sync
        # Output multi-slots will auto-sync via the graph
        # Materialize as a list: under Python 3 `filter` returns a one-shot
        # iterator. The inner loop would exhaust it during the first outer
        # iteration, leaving most slot pairs unsynchronized.
        multiInputs = [s for s in self.inputs.values() if s.level >= 1]

        def insertSlot(a, b, position, finalsize):
            # `a` is the slot to mirror into. `b` is the multi-slot that
            # emitted the insertion signal (unused).
            a.insertSlot(position, finalsize)

        def removeSlot(a, b, position, finalsize):
            # Same argument convention as insertSlot.
            a.removeSlot(position, finalsize)

        for s1 in multiInputs:
            for s2 in multiInputs:
                if s1 != s2:
                    s1.notifyInserted(partial(insertSlot, s2))
                    s1.notifyRemoved(partial(removeSlot, s2))

    def setupCaches(self, imageIndex):
        """
        Prepare the label-input slot of one lane to match that lane's input image.

        The LabelInputs multi-slot is resized to the current number of input
        images.  The label input of lane ``imageIndex`` then gets the input
        image's shape with the channel axis (if it can be located) set to 1,
        and the input image's axistags.

        :param imageIndex: index of the lane whose input image became ready.
        """
        numImages = len(self.InputImages)
        inputSlot = self.InputImages[imageIndex]
        self.LabelInputs.resize(numImages)

        # Special case: We have to set up the shape of our label *input*
        # according to our image input shape, but with a single channel.
        shapeList = list(inputSlot.meta.shape)
        try:
            channelIndex = inputSlot.meta.axistags.index('c')
            shapeList[channelIndex] = 1
        except Exception:
            # Best effort: if the channel axis cannot be located, keep the
            # shape unchanged.  (Deliberately not a bare ``except`` so that
            # KeyboardInterrupt/SystemExit still propagate.)
            pass

        labelSlot = self.LabelInputs[imageIndex]
        labelSlot.meta.shape = tuple(shapeList)
        labelSlot.meta.axistags = inputSlot.meta.axistags

    def _checkConstraints(self, laneIndex):
        """
        Ensure that all input images have the same number of channels.
        """
        if not self.InputImages[laneIndex].ready():
            return

        thisLaneTaggedShape = self.InputImages[laneIndex].meta.getTaggedShape()

        # Find a different lane and use it for comparison
        validShape = thisLaneTaggedShape
        for i, slot in enumerate(self.InputImages):
            if slot.ready() and i != laneIndex:
                validShape = slot.meta.getTaggedShape()
                break

        if 't' in thisLaneTaggedShape:
            del thisLaneTaggedShape['t']
        if 't' in validShape:
            del validShape['t']

        if validShape['c'] != thisLaneTaggedShape['c']:
            raise DatasetConstraintError(
                 "Pixel Classification",
                 "All input images must have the same number of channels.  "\
                 "Your new image has {} channel(s), but your other images have {} channel(s)."\
                 .format( thisLaneTaggedShape['c'], validShape['c'] ) )

        if len(validShape) != len(thisLaneTaggedShape):
            raise DatasetConstraintError(
                 "Pixel Classification",
                 "All input images must have the same dimensionality.  "\
                 "Your new image has {} dimensions (including channel), but your other images have {} dimensions."\
                 .format( len(thisLaneTaggedShape), len(validShape) ) )

        mask_slot = self.PredictionMasks[laneIndex]
        input_shape = self.InputImages[laneIndex].meta.shape
        if mask_slot.ready() and mask_slot.meta.shape[:-1] != input_shape[:-1]:
            raise DatasetConstraintError(
                 "Pixel Classification",
                 "If you supply a prediction mask, it must have the same shape as the input image."\
                 "Your input image has shape {}, but your mask has shape {}."\
                 .format( input_shape, mask_slot.meta.shape ) )

    def setInSlot(self, slot, subindex, roi, value):
        """
        Intentionally a no-op.

        Every input slot that supports ``__setitem__`` is wired straight to an
        internal operator, which receives the written data itself.
        """

    def propagateDirty(self, slot, subindex, roi):
        """
        Intentionally a no-op.

        All outputs of this operator are connected directly to internal
        operators, and those operators propagate dirtiness on their own.
        """

    def addLane(self, laneIndex):
        """
        Append one image lane by growing the InputImages multi-slot by one.

        Lanes can only be appended, so ``laneIndex`` must equal the current
        number of lanes.
        """
        currentLaneCount = len(self.InputImages)
        assert laneIndex == currentLaneCount, "Image lanes must be appended."
        grownSize = currentLaneCount + 1
        self.InputImages.resize(grownSize)

    def removeLane(self, laneIndex, finalLength):
        """Drop lane ``laneIndex``; InputImages is left with ``finalLength`` lanes."""
        imageSlots = self.InputImages
        imageSlots.removeSlot(laneIndex, finalLength)

    def getLane(self, laneIndex):
        """Return an ``OperatorSubView`` of this operator restricted to lane ``laneIndex``."""
        laneView = OperatorSubView(self, laneIndex)
        return laneView

    def importLabels(self, laneIndex, slot):
        """
        Ingest label data from ``slot`` into the label array of one lane.

        The label array's ``ingestData`` returns the new maximum label value.
        If that exceeds the number of currently known label names, default
        names ("Label N") are appended for the new labels.  Matching default
        label and prediction-map colors are appended as well.

        :param laneIndex: lane whose label array receives the data.
        :param slot: slot providing the label data to import.
        """
        # Load the data into the cache
        new_max = self.getLane(
            laneIndex).opLabelPipeline.opLabelArray.ingestData(slot)

        # Add to the list of label names if there's a new max label
        old_names = self.LabelNames.value
        old_max = len(old_names)
        if new_max > old_max:
            # Build a real list: on Python 3 ``map`` returns an iterator and
            # ``list + map(...)`` raises TypeError.
            new_names = old_names + ["Label {}".format(x)
                                     for x in range(old_max + 1, new_max + 1)]
            self.LabelNames.setValue(new_names)

            # Make some default colors, too
            default_colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255),
                              (255, 255, 0), (255, 0, 255), (0, 255, 255),
                              (128, 128, 128), (255, 105, 180), (255, 165, 0),
                              (240, 230, 140)]
            # Cycle through the palette so that every new label gets a color,
            # even when there are more labels than palette entries.  (A plain
            # slice would silently return too few colors in that case.)
            new_colors = [default_colors[i % len(default_colors)]
                          for i in range(old_max, new_max)]
            label_colors = self.LabelColors.value
            pmap_colors = self.PmapColors.value

            self.LabelColors.setValue(label_colors + new_colors)
            self.PmapColors.setValue(pmap_colors + new_colors)

    def mergeLabels(self, from_label, into_label):
        """Merge label ``from_label`` into ``into_label`` in every lane's label array."""
        laneCount = len(self.InputImages)
        labelArrays = (self.getLane(lane).opLabelPipeline.opLabelArray
                       for lane in range(laneCount))
        for labelArray in labelArrays:
            labelArray.mergeLabels(from_label, into_label)

    def clearLabel(self, label_value):
        """Erase every occurrence of ``label_value`` from every lane's label array."""
        laneCount = len(self.InputImages)
        labelArrays = (self.getLane(lane).opLabelPipeline.opLabelArray
                       for lane in range(laneCount))
        for labelArray in labelArrays:
            labelArray.clearLabel(label_value)
Example #12
0
    def __init__(self, *args, **kwargs):
        """
        Instantiate all internal operators and connect them together.

        Wiring done here:

        * the labeling pipeline,
        * the blocked training operator,
        * a value cache holding the trained classifier,
        * the per-lane prediction pipeline and the feature-matrix caches,
          whose outputs feed this operator's top-level output slots.

        It then registers callbacks that validate newly inserted lanes and
        keep all multi-lane input slots the same length.
        """
        super(OpPixelClassification, self).__init__(*args, **kwargs)

        # Default values for some input slots
        self.FreezePredictions.setValue(True)
        self.LabelNames.setValue([])
        self.LabelColors.setValue([])
        self.PmapColors.setValue([])

        # SPECIAL connection: The LabelInputs slot doesn't get its data
        #  from the InputImages slot, but its shape must match.
        self.LabelInputs.connect(self.InputImages)

        # Hook up Labeling Pipeline
        self.opLabelPipeline = OpMultiLaneWrapper(
            OpLabelPipeline,
            parent=self,
            broadcastingSlotNames=['DeleteLabel'])
        self.opLabelPipeline.RawImage.connect(self.InputImages)
        self.opLabelPipeline.LabelInput.connect(self.LabelInputs)
        self.opLabelPipeline.DeleteLabel.setValue(-1)
        self.LabelImages.connect(self.opLabelPipeline.Output)
        self.NonzeroLabelBlocks.connect(self.opLabelPipeline.nonzeroBlocks)

        # Hook up the Training operator
        self.opTrain = OpTrainClassifierBlocked(parent=self)
        self.opTrain.ClassifierFactory.connect(self.ClassifierFactory)
        self.opTrain.Labels.connect(self.opLabelPipeline.Output)
        self.opTrain.Images.connect(self.FeatureImages)
        self.opTrain.nonzeroLabelBlocks.connect(
            self.opLabelPipeline.nonzeroBlocks)

        # Hook up the Classifier Cache
        # The classifier is cached here to allow serializers to force in
        #   a pre-calculated classifier (loaded from disk)
        self.classifier_cache = OpValueCache(parent=self)
        self.classifier_cache.name = "OpPixelClassification.classifier_cache"
        self.classifier_cache.inputs["Input"].connect(
            self.opTrain.outputs['Classifier'])
        self.classifier_cache.inputs["fixAtCurrent"].connect(
            self.FreezePredictions)
        self.Classifier.connect(self.classifier_cache.Output)

        # Hook up the prediction pipeline inputs
        self.opPredictionPipeline = OpMultiLaneWrapper(OpPredictionPipeline,
                                                       parent=self)
        self.opPredictionPipeline.FeatureImages.connect(self.FeatureImages)
        self.opPredictionPipeline.CachedFeatureImages.connect(
            self.CachedFeatureImages)
        self.opPredictionPipeline.Classifier.connect(
            self.classifier_cache.Output)
        self.opPredictionPipeline.FreezePredictions.connect(
            self.FreezePredictions)
        self.opPredictionPipeline.PredictionsFromDisk.connect(
            self.PredictionsFromDisk)
        self.opPredictionPipeline.PredictionMask.connect(self.PredictionMasks)

        # Feature Selection Stuff
        self.opFeatureMatrixCaches = OpMultiLaneWrapper(OpFeatureMatrixCache,
                                                        parent=self)
        self.opFeatureMatrixCaches.LabelImage.connect(
            self.opLabelPipeline.Output)
        self.opFeatureMatrixCaches.FeatureImage.connect(self.FeatureImages)
        # NOTE(review): the original author questioned whether this explicit
        # setDirty() is still needed; kept to preserve behavior.
        self.opFeatureMatrixCaches.LabelImage.setDirty()

        def _updateNumClasses(*args):
            """
            When the number of labels changes, we MUST make sure that the prediction image changes its shape (the number of channels).
            Since setupOutputs is not called for mere dirty notifications, but is called in response to setValue(),
            we use this function to call setValue().
            """
            numClasses = len(self.LabelNames.value)
            self.opTrain.MaxLabel.setValue(numClasses)
            self.opPredictionPipeline.NumClasses.setValue(numClasses)
            self.NumClasses.setValue(numClasses)

        self.LabelNames.notifyDirty(_updateNumClasses)

        # Prediction pipeline outputs -> Top-level outputs
        self.PredictionProbabilities.connect(
            self.opPredictionPipeline.PredictionProbabilities)
        self.PredictionProbabilitiesUint8.connect(
            self.opPredictionPipeline.PredictionProbabilitiesUint8)
        self.CachedPredictionProbabilities.connect(
            self.opPredictionPipeline.CachedPredictionProbabilities)
        self.HeadlessPredictionProbabilities.connect(
            self.opPredictionPipeline.HeadlessPredictionProbabilities)
        self.HeadlessUint8PredictionProbabilities.connect(
            self.opPredictionPipeline.HeadlessUint8PredictionProbabilities)
        self.PredictionProbabilityChannels.connect(
            self.opPredictionPipeline.PredictionProbabilityChannels)
        self.SegmentationChannels.connect(
            self.opPredictionPipeline.SegmentationChannels)
        self.UncertaintyEstimate.connect(
            self.opPredictionPipeline.UncertaintyEstimate)
        self.SimpleSegmentation.connect(
            self.opPredictionPipeline.SimpleSegmentation)
        self.HeadlessUncertaintyEstimate.connect(
            self.opPredictionPipeline.HeadlessUncertaintyEstimate)

        def inputResizeHandler(slot, oldsize, newsize):
            # When all input images are removed, explicitly shrink these
            # multi-lane outputs to zero lanes as well.
            if (newsize == 0):
                self.LabelImages.resize(0)
                self.NonzeroLabelBlocks.resize(0)
                self.PredictionProbabilities.resize(0)
                self.CachedPredictionProbabilities.resize(0)

        self.InputImages.notifyResized(inputResizeHandler)

        # Debug assertions: Check to make sure the non-wrapped operators stayed that way.
        assert self.opTrain.Images.operator == self.opTrain

        def handleNewInputImage(multislot, index, *args):
            def handleInputReady(slot):
                # Resolve the lane when the slot becomes ready instead of
                # using the insertion-time ``index``: lanes may have been
                # inserted or removed in between, making ``index`` stale.
                laneIndex = multislot.index(slot)
                self._checkConstraints(laneIndex)
                self.setupCaches(laneIndex)

            multislot[index].notifyReady(handleInputReady)

        self.InputImages.notifyInserted(handleNewInputImage)

        # If any feature image changes shape, we need to verify that the
        #  channels are consistent with the currently cached classifier
        # Otherwise, delete the currently cached classifier.
        def handleNewFeatureImage(multislot, index, *args):
            def handleFeatureImageReady(slot):
                def handleFeatureMetaChanged(slot):
                    if (self.classifier_cache.fixAtCurrent.value
                            and self.classifier_cache.Output.ready()
                            and slot.meta.shape is not None):
                        classifier = self.classifier_cache.Output.value
                        channel_names = slot.meta.channel_names
                        if classifier and classifier.feature_names != channel_names:
                            self.classifier_cache.resetValue()

                slot.notifyMetaChanged(handleFeatureMetaChanged)

            multislot[index].notifyReady(handleFeatureImageReady)

        self.FeatureImages.notifyInserted(handleNewFeatureImage)

        def handleNewMaskImage(multislot, index, *args):
            def handleInputReady(slot):
                # Same stale-index concern as in handleNewInputImage.
                self._checkConstraints(multislot.index(slot))

            multislot[index].notifyReady(handleInputReady)

        self.PredictionMasks.notifyInserted(handleNewMaskImage)

        # All input multi-slots should be kept in sync
        # Output multi-slots will auto-sync via the graph
        # This must be a real list: the nested loops below iterate it twice,
        # and on Python 3 a ``filter`` iterator is exhausted after one pass,
        # so only the first slot would get synced.
        multiInputs = [s for s in self.inputs.values() if s.level >= 1]

        # ``a`` is the slot to update (bound via partial); ``b`` is the
        # multi-slot that emitted the notification.
        def insertSlot(a, b, position, finalsize):
            a.insertSlot(position, finalsize)

        def removeSlot(a, b, position, finalsize):
            a.removeSlot(position, finalsize)

        for s1 in multiInputs:
            for s2 in multiInputs:
                if s1 != s2:
                    s1.notifyInserted(partial(insertSlot, s2))
                    s1.notifyRemoved(partial(removeSlot, s2))
Example #13
0
    def __init__(self, *args, connectionFactory, **kwargs):
        """
        Instantiate all internal operators and connect them together.

        Wiring done here:

        * the per-lane block-shape operator,
        * the labeling pipeline,
        * the TikTorch training operator,
        * a value cache holding the training operator's updated model session,
        * the per-lane prediction pipeline.

        It then registers callbacks that validate newly inserted lanes and
        keep all multi-lane input slots the same length.

        :param connectionFactory: keyword-only; stored as
            ``self._connectionFactory``.
        """
        super(OpNNClassification, self).__init__(*args, **kwargs)
        self._connectionFactory = connectionFactory
        #
        # Default values for some input slots
        self.FreezePredictions.setValue(True)
        self.LabelNames.setValue([])
        self.LabelColors.setValue([])
        self.PmapColors.setValue([])

        self.Checkpoints.setValue([])
        self._binary_model = None

        # SPECIAL connection: the LabelInputs slot doesn't get its data
        # from the InputImages slot, but its shape must match.
        self.LabelInputs.connect(self.InputImages)

        self.opBlockShape = OpMultiLaneWrapper(OpBlockShape, parent=self)
        self.opBlockShape.RawImage.connect(self.InputImages)
        self.opBlockShape.ModelSession.connect(self.ModelSession)

        # self.opModel = OpModel(parent=self.parent, connectionFactory=connectionFactory)
        # self.opModel.ServerConfig.connect(self.ServerConfig)
        # self.opModel.ModelBinary.connect(self.ModelBinary)

        # self.ModelSession.connect(self.opModel.TiktorchModel)
        # self.NumClasses.connect(self.opModel.NumClasses)

        # Hook up Labeling Pipeline
        self.opLabelPipeline = OpMultiLaneWrapper(
            OpLabelPipeline,
            parent=self,
            broadcastingSlotNames=["DeleteLabel"])
        self.opLabelPipeline.RawImage.connect(self.InputImages)
        self.opLabelPipeline.LabelInput.connect(self.LabelInputs)
        self.opLabelPipeline.DeleteLabel.setValue(-1)
        self.LabelImages.connect(self.opLabelPipeline.Output)
        self.NonzeroLabelBlocks.connect(self.opLabelPipeline.nonzeroBlocks)

        # TRAINING OPERATOR
        self.opTrain = OpTikTorchTrainClassifierBlocked(parent=self)
        self.opTrain.ModelSession.connect(self.ModelSession)
        self.opTrain.Labels.connect(self.opLabelPipeline.Output)
        self.opTrain.Images.connect(self.InputImages)
        self.opTrain.BlockShape.connect(self.opBlockShape.BlockShapeTrain)
        self.opTrain.nonzeroLabelBlocks.connect(
            self.opLabelPipeline.nonzeroBlocks)
        self.opTrain.MaxLabel.connect(self.NumClasses)

        # CLASSIFIER CACHE
        # This cache stores exactly one object: the classifier itself.
        self.classifier_cache = OpValueCache(parent=self)
        self.classifier_cache.name = "OpNetworkClassification.classifier_cache"
        self.classifier_cache.inputs["Input"].connect(
            self.opTrain.UpdatedModelSession)
        self.classifier_cache.inputs["fixAtCurrent"].connect(
            self.FreezePredictions)
        self.Classifier.connect(self.classifier_cache.Output)

        # Hook up the prediction pipeline inputs
        self.opPredictionPipeline = OpMultiLaneWrapper(OpPredictionPipeline,
                                                       parent=self)
        self.opPredictionPipeline.RawImage.connect(self.InputImages)
        # self.opPredictionPipeline.Classifier.connect(self.classifier_cache.Output)
        # Predictions read the model session directly, not the cached classifier.
        self.opPredictionPipeline.Classifier.connect(self.ModelSession)
        self.opPredictionPipeline.NumClasses.connect(self.NumClasses)
        self.opPredictionPipeline.FreezePredictions.connect(
            self.FreezePredictions)

        self.PredictionProbabilities.connect(
            self.opPredictionPipeline.PredictionProbabilities)
        self.CachedPredictionProbabilities.connect(
            self.opPredictionPipeline.CachedPredictionProbabilities)
        self.PredictionProbabilityChannels.connect(
            self.opPredictionPipeline.PredictionProbabilityChannels)

        def inputResizeHandler(slot, oldsize, newsize):
            # When all input images are removed, explicitly shrink these
            # multi-lane outputs to zero lanes as well.
            if newsize == 0:
                self.LabelImages.resize(0)
                self.NonzeroLabelBlocks.resize(0)
                self.PredictionProbabilities.resize(0)
                self.CachedPredictionProbabilities.resize(0)

        self.InputImages.notifyResized(inputResizeHandler)

        # Debug assertions: Check to make sure the non-wrapped operators stayed that way.
        assert self.opTrain.Images.operator == self.opTrain

        def handleNewInputImage(multislot, index, *args):
            def handleInputReady(slot):
                # Resolve the lane when the slot becomes ready instead of
                # using the insertion-time ``index``: lanes may have been
                # inserted or removed in between, making ``index`` stale.
                laneIndex = multislot.index(slot)
                self._checkConstraints(laneIndex)
                self.setupCaches(laneIndex)

            multislot[index].notifyReady(handleInputReady)

        self.InputImages.notifyInserted(handleNewInputImage)

        # All input multi-slots should be kept in sync
        # Output multi-slots will auto-sync via the graph
        multiInputs = [s for s in list(self.inputs.values()) if s.level >= 1]

        # ``a`` is the slot to update (bound via partial); ``b`` is the
        # multi-slot that emitted the notification.
        def insertSlot(a, b, position, finalsize):
            a.insertSlot(position, finalsize)

        def removeSlot(a, b, position, finalsize):
            a.removeSlot(position, finalsize)

        for s1 in multiInputs:
            for s2 in multiInputs:
                if s1 != s2:
                    s1.notifyInserted(partial(insertSlot, s2))
                    s1.notifyRemoved(partial(removeSlot, s2))
Example #14
0
    def __init__(self, graph):
        """
        Instantiate all internal operators and connect them together.

        This variant passes ``graph=self.graph`` to every internal operator.
        It wires up, in order:

        * a shape reader and a sparse label array per lane (label storage),
        * an ``OpMaxValue`` operator fed by the label arrays' ``maxLabel``,
        * a blocked random-forest trainer and an ``OpValueCache`` holding the
          trained classifier,
        * per-lane random-forest prediction feeding two prediction caches:
          one whose ``fixAtCurrent`` is always False and one (for the GUI)
          whose ``fixAtCurrent`` follows ``FreezePredictions``.  Each cache is
          fronted by an ``OpPrecomputedInput`` that also receives
          ``PredictionsFromDisk``,
        * per-channel slicers of the GUI predictions and of a segmentation
          thresholded at 0.5,
        * an uncertainty estimate with its own cache.

        :param graph: graph object forwarded to the base-class constructor.
        """
        super(OpPixelClassification, self).__init__(graph=graph)

        self.FreezePredictions.setValue(True)  # Default

        # Create internal operators
        # Explicitly wrapped:
        self.opInputShapeReader = OperatorWrapper(OpShapeReader,
                                                  parent=self,
                                                  graph=self.graph)
        self.opLabelArray = OperatorWrapper(OpBlockedSparseLabelArray,
                                            parent=self,
                                            graph=self.graph)
        self.predict = OperatorWrapper(OpPredictRandomForest,
                                       parent=self,
                                       graph=self.graph)
        self.prediction_cache = OperatorWrapper(OpSlicedBlockedArrayCache,
                                                parent=self,
                                                graph=self.graph)
        # Freshly wrapped operators start with zero lanes.
        assert len(self.prediction_cache.Input) == 0
        self.prediction_cache_gui = OperatorWrapper(OpSlicedBlockedArrayCache,
                                                    parent=self,
                                                    graph=self.graph)
        assert len(self.prediction_cache_gui.Input) == 0
        self.precomputed_predictions = OperatorWrapper(OpPrecomputedInput,
                                                       parent=self,
                                                       graph=self.graph)
        self.precomputed_predictions_gui = OperatorWrapper(OpPrecomputedInput,
                                                           parent=self,
                                                           graph=self.graph)

        # NOT wrapped
        self.opMaxLabel = OpMaxValue(parent=self, graph=self.graph)
        self.opTrain = OpTrainRandomForestBlocked(parent=self,
                                                  graph=self.graph)

        # Set up label cache shape input
        self.opInputShapeReader.Input.connect(self.InputImages)
        self.opLabelArray.inputs["shape"].connect(
            self.opInputShapeReader.OutputShape)

        # Set up other label cache inputs
        # SPECIAL connection: LabelInputs mirrors InputImages only so that its
        # shape/lane count matches; the label data itself is written in.
        self.LabelInputs.connect(self.InputImages)
        self.opLabelArray.inputs["Input"].connect(self.LabelInputs)
        # NOTE(review): 100 is presumably the label value used as the eraser — confirm.
        self.opLabelArray.inputs["eraser"].setValue(100)

        # Initialize the delete input to -1, which means "no label".
        # Now changing this input to a positive value will cause label deletions.
        # (The deleteLabel input is monitored for changes.)
        self.opLabelArray.inputs["deleteLabel"].setValue(-1)

        # Find the highest label in all the label images
        self.opMaxLabel.Inputs.connect(self.opLabelArray.outputs['maxLabel'])

        ##
        # training
        ##

        self.opTrain.inputs['Labels'].connect(
            self.opLabelArray.outputs["Output"])
        self.opTrain.inputs['Images'].connect(self.CachedFeatureImages)
        self.opTrain.inputs["nonzeroLabelBlocks"].connect(
            self.opLabelArray.outputs["nonzeroBlocks"])
        self.opTrain.inputs['fixClassifier'].setValue(False)

        # The classifier is cached here to allow serializers to force in a pre-calculated classifier...
        self.classifier_cache = OpValueCache(parent=self, graph=self.graph)
        self.classifier_cache.inputs["Input"].connect(
            self.opTrain.outputs['Classifier'])

        ##
        # prediction: classifier + cached features -> per-lane probability maps
        ##
        self.predict.inputs['Classifier'].connect(
            self.classifier_cache.outputs['Output'])
        self.predict.inputs['Image'].connect(self.CachedFeatureImages)
        self.predict.inputs['LabelsCount'].connect(self.opMaxLabel.Output)

        # prediction cache for downstream operators (if they want it)
        self.prediction_cache.name = "PredictionCache"
        self.prediction_cache.inputs["fixAtCurrent"].setValue(False)
        self.prediction_cache.inputs["Input"].connect(self.predict.PMaps)

        # The serializer uses these operators to provide prediction data directly from the project file
        # if the predictions haven't become dirty since the project file was opened.
        self.precomputed_predictions.SlowInput.connect(
            self.prediction_cache.Output)
        self.precomputed_predictions.PrecomputedInput.connect(
            self.PredictionsFromDisk)

        # Prediction cache for the GUI
        # NOTE(review): this reuses the name of the non-GUI cache above;
        # possibly intended to be distinct — confirm.
        self.prediction_cache_gui.name = "PredictionCache"
        self.prediction_cache_gui.inputs["fixAtCurrent"].connect(
            self.FreezePredictions)
        self.prediction_cache_gui.inputs["Input"].connect(self.predict.PMaps)

        self.precomputed_predictions_gui.SlowInput.connect(
            self.prediction_cache_gui.Output)
        self.precomputed_predictions_gui.PrecomputedInput.connect(
            self.PredictionsFromDisk)

        # Connect our internal outputs to our external outputs
        self.LabelImages.connect(self.opLabelArray.Output)
        self.MaxLabelValue.connect(self.opMaxLabel.Output)
        self.NonzeroLabelBlocks.connect(self.opLabelArray.nonzeroBlocks)
        self.PredictionProbabilities.connect(self.predict.PMaps)
        self.CachedPredictionProbabilities.connect(
            self.precomputed_predictions.Output)
        self.Classifier.connect(self.classifier_cache.Output)

        def inputResizeHandler(slot, oldsize, newsize):
            # When all input images are removed, explicitly shrink these
            # multi-lane outputs to zero lanes as well.
            if (newsize == 0):
                self.LabelImages.resize(0)
                self.NonzeroLabelBlocks.resize(0)
                self.PredictionProbabilities.resize(0)
                self.CachedPredictionProbabilities.resize(0)

        self.InputImages.notifyResized(inputResizeHandler)

        # Check to make sure the non-wrapped operators stayed that way.
        assert self.opMaxLabel.Inputs.operator == self.opMaxLabel
        assert self.opTrain.Images.operator == self.opTrain

        # Also provide each prediction channel as a separate layer (for the GUI)
        self.opPredictionSlicer = OperatorWrapper(OpMultiArraySlicer2,
                                                  parent=self,
                                                  graph=self.graph)
        self.opPredictionSlicer.Input.connect(
            self.precomputed_predictions_gui.Output)
        self.opPredictionSlicer.AxisFlag.setValue('c')
        self.PredictionProbabilityChannels.connect(
            self.opPredictionSlicer.Slices)

        # Threshold the GUI predictions at 0.5 into a 0/1 segmentation.
        # (The attribute name 'opSegementor' is kept as spelled; other code
        # may refer to it.)
        self.opSegementor = OperatorWrapper(OpPixelOperator,
                                            parent=self,
                                            graph=self.graph)
        self.opSegementor.Input.connect(
            self.precomputed_predictions_gui.Output)
        self.opSegementor.Function.setValue(
            lambda x: numpy.where(x < 0.5, 0, 1))

        # Provide each segmentation channel as a separate layer as well.
        self.opSegmentationSlicer = OperatorWrapper(OpMultiArraySlicer2,
                                                    parent=self,
                                                    graph=self.graph)
        self.opSegmentationSlicer.Input.connect(self.opSegementor.Output)
        self.opSegmentationSlicer.AxisFlag.setValue('c')
        self.SegmentationChannels.connect(self.opSegmentationSlicer.Slices)

        # Create a layer for uncertainty estimate
        self.opUncertaintyEstimator = OperatorWrapper(OpEnsembleMargin,
                                                      parent=self,
                                                      graph=self.graph)
        self.opUncertaintyEstimator.Input.connect(
            self.precomputed_predictions_gui.Output)

        # Cache the uncertainty so we get zeros for uncomputed points
        self.opUncertaintyCache = OperatorWrapper(OpSlicedBlockedArrayCache,
                                                  parent=self,
                                                  graph=self.graph)
        self.opUncertaintyCache.Input.connect(
            self.opUncertaintyEstimator.Output)
        self.opUncertaintyCache.fixAtCurrent.connect(self.FreezePredictions)

        self.UncertaintyEstimate.connect(self.opUncertaintyCache.Output)

        # Once a newly inserted lane's input image is ready, configure that
        # lane's label-input slot.  The lane index is looked up at ready time.
        def handleNewInputImage(multislot, index, *args):
            def handleInputReady(slot):
                self.setupCaches(multislot.index(slot))

            multislot[index].notifyReady(handleInputReady)

        self.InputImages.notifyInserted(handleNewInputImage)
    def setupOperators(self, *args, **kwargs):
        self.FreezePredictions.setValue(True)  # Default

        # Create internal operators
        # Explicitly wrapped:
        self.opInputShapeReader = OperatorWrapper(OpShapeReader, parent=self)
        self.opLabelArray = OperatorWrapper(OpBlockedSparseLabelArray,
                                            parent=self)

        self.predictors = []
        self.prediction_caches = []
        self.prediction_caches_gui = []

        #FIXME: we should take it from the input slot
        niter = self.AutocontextIterations.value

        for i in range(niter):
            predict = OperatorWrapper(OpPredictRandomForest, parent=self)
            prediction_cache = OperatorWrapper(OpSlicedBlockedArrayCache,
                                               parent=self)
            prediction_cache_gui = OperatorWrapper(OpSlicedBlockedArrayCache,
                                                   parent=self)

            self.predictors.append(predict)
            self.prediction_caches.append(prediction_cache)
            self.prediction_caches_gui.append(prediction_cache_gui)

        #We only display the last prediction layer

        self.precomputed_predictions = OperatorWrapper(OpPrecomputedInput,
                                                       parent=self)
        self.precomputed_predictions_gui = OperatorWrapper(OpPrecomputedInput,
                                                           parent=self)

        #Display pixel-only predictions to compare
        self.precomputed_predictions_pixel = OperatorWrapper(
            OpPrecomputedInput, parent=self)
        self.precomputed_predictions_pixel_gui = OperatorWrapper(
            OpPrecomputedInput, parent=self)

        # NOT wrapped
        self.opMaxLabel = OpMaxValue(parent=self)
        self.trainers = []
        for i in range(niter):
            opTrain = OpTrainRandomForestBlocked(parent=self)
            self.trainers.append(opTrain)

        # Set up label cache shape input
        self.opInputShapeReader.Input.connect(self.InputImages)
        self.opLabelArray.inputs["shape"].connect(
            self.opInputShapeReader.OutputShape)

        # Set up other label cache inputs
        self.LabelInputs.connect(self.InputImages)
        self.opLabelArray.inputs["Input"].connect(self.LabelInputs)
        self.opLabelArray.inputs["eraser"].setValue(100)

        # Initialize the delete input to -1, which means "no label".
        # Now changing this input to a positive value will cause label deletions.
        # (The deleteLabel input is monitored for changes.)
        self.opLabelArray.inputs["deleteLabel"].setValue(-1)

        # Find the highest label in all the label images
        self.opMaxLabel.Inputs.connect(self.opLabelArray.outputs['maxLabel'])

        # Setup autocontext features
        self.autocontextFeatures = []
        self.autocontextFeaturesMulti = []
        self.autocontext_caches = []
        self.featureStackers = []

        for i in range(niter - 1):
            features = createAutocontextFeatureOperators(self, True)
            self.autocontextFeatures.append(features)
            opMulti = OperatorWrapper(Op50ToMulti, parent=self)
            self.autocontextFeaturesMulti.append(opMulti)
            opStacker = OperatorWrapper(OpMultiArrayStacker, parent=self)
            opStacker.inputs["AxisFlag"].setValue("c")
            opStacker.inputs["AxisIndex"].setValue(3)
            self.featureStackers.append(opStacker)
            autocontext_cache = OperatorWrapper(OpSlicedBlockedArrayCache,
                                                parent=self)
            self.autocontext_caches.append(autocontext_cache)

        # connect the features to predictors
        for i in range(niter - 1):
            for ifeat, feat in enumerate(self.autocontextFeatures[i]):
                feat.inputs['Input'].connect(self.prediction_caches[i].Output)
                print "Multi: Connecting an output", "Input%.2d" % (ifeat)
                self.autocontextFeaturesMulti[i].inputs[
                    "Input%.2d" % (ifeat)].connect(feat.outputs["Output"])
            # connect the pixel features to the same multislot
            print "Multi: Connecting an output", "Input%.2d" % (len(
                self.autocontextFeatures[i]))
            self.autocontextFeaturesMulti[i].inputs["Input%.2d" % (len(
                self.autocontextFeatures[i]))].connect(
                    self.CachedFeatureImages)
            # stack the autocontext features with pixel features
            self.featureStackers[i].inputs["Images"].connect(
                self.autocontextFeaturesMulti[i].outputs["Outputs"])
            # cache the stacks
            self.autocontext_caches[i].inputs["Input"].connect(
                self.featureStackers[i].outputs["Output"])
            self.autocontext_caches[i].inputs["fixAtCurrent"].setValue(False)

        ##
        # training
        ##

        for op in self.trainers:
            op.inputs['Labels'].connect(self.opLabelArray.outputs["Output"])
            op.inputs["nonzeroLabelBlocks"].connect(
                self.opLabelArray.outputs["nonzeroBlocks"])
            op.inputs['fixClassifier'].setValue(False)
        # Connect the first training operator - just pixel features
        self.trainers[0].inputs['Images'].connect(self.CachedFeatureImages)
        # Connect other training operators - stacked pixel and autocontext features
        for i in range(1, niter):
            self.trainers[i].inputs["Images"].connect(
                self.featureStackers[i - 1].outputs["Output"])

        ##
        # prediction
        ##

        # The classifier is cached here to allow serializers to force in a pre-calculated classifier...
        self.classifiers = []
        self.classifier_caches = []

        for i in range(niter):
            self.classifiers.append(self.trainers[i].outputs['Classifier'])
            cache = OpValueCache(parent=self)
            cache.inputs["Input"].connect(
                self.trainers[i].outputs['Classifier'])
            self.classifier_caches.append(cache)

        for i in range(niter):
            self.predictors[i].inputs['Classifier'].connect(
                self.classifier_caches[i].outputs["Output"])
            self.predictors[i].inputs['LabelsCount'].connect(
                self.opMaxLabel.Output)

            self.prediction_caches[i].inputs["fixAtCurrent"].setValue(False)
            self.prediction_caches[i].inputs["Input"].connect(
                self.predictors[i].PMaps)

            self.prediction_caches_gui[i].name = "PredictionCache"
            self.prediction_caches_gui[i].inputs["fixAtCurrent"].connect(
                self.FreezePredictions)
            self.prediction_caches_gui[i].inputs["Input"].connect(
                self.predictors[i].PMaps)

        self.predictors[0].inputs['Image'].connect(self.CachedFeatureImages)
        for i in range(1, niter):
            self.predictors[i].inputs['Image'].connect(
                self.autocontext_caches[i - 1].outputs["Output"])

        # The serializer uses these operators to provide prediction data directly from the project file
        # if the predictions haven't become dirty since the project file was opened.
        self.precomputed_predictions.SlowInput.connect(
            self.prediction_caches[-1].Output)
        self.precomputed_predictions.PrecomputedInput.connect(
            self.PredictionsFromDisk)

        self.precomputed_predictions_pixel.SlowInput.connect(
            self.prediction_caches[0].Output)
        self.precomputed_predictions_pixel.PrecomputedInput.connect(
            self.PredictionsFromDisk)

        # !!! here we can change which prediction step we show:
        self.precomputed_predictions_gui.SlowInput.connect(
            self.prediction_caches_gui[-1].Output)
        self.precomputed_predictions_gui.PrecomputedInput.connect(
            self.PredictionsFromDisk)
        self.precomputed_predictions_pixel_gui.SlowInput.connect(
            self.prediction_caches_gui[0].Output)
        self.precomputed_predictions_pixel_gui.PrecomputedInput.connect(
            self.PredictionsFromDisk)

        # Connect our internal outputs to our external outputs
        self.LabelImages.connect(self.opLabelArray.Output)
        self.MaxLabelValue.connect(self.opMaxLabel.Output)
        self.NonzeroLabelBlocks.connect(self.opLabelArray.nonzeroBlocks)
        self.PixelOnlyPredictions.connect(self.predictors[0].PMaps)
        self.PredictionProbabilities.connect(self.predictors[-1].PMaps)
        self.CachedPredictionProbabilities.connect(
            self.precomputed_predictions.Output)
        self.CachedPixelPredictionProbabilities.connect(
            self.precomputed_predictions_pixel.Output)

        self.multi = Op50ToMulti(parent=self)
        for i in range(niter):
            self.multi.inputs["Input%.2d" % i].connect(
                self.classifier_caches[i].outputs["Output"])

        self.Classifiers.connect(self.multi.outputs["Outputs"])

        def inputResizeHandler(slot, oldsize, newsize):
            if (newsize == 0):
                self.LabelImages.resize(0)
                self.NonzeroLabelBlocks.resize(0)
                self.PixelOnlyPredictions.resize(0)
                self.PredictionProbabilities.resize(0)
                self.CachedPredictionProbabilities.resize(0)
                self.CachedPixelPredictionProbabilities.resize(0)

        self.InputImages.notifyResized(inputResizeHandler)

        # Check to make sure the non-wrapped operators stayed that way.
        assert self.opMaxLabel.Inputs.operator == self.opMaxLabel
        for i in range(niter):
            assert self.trainers[0].Images.operator == self.trainers[0]
        #assert self.opTrain.Images.operator == self.opTrain

        # Also provide each prediction channel as a separate layer (for the GUI)
        self.opPredictionSlicer = OperatorWrapper(OpMultiArraySlicer2,
                                                  parent=self)
        self.opPredictionSlicer.Input.connect(
            self.precomputed_predictions_gui.Output)
        self.opPredictionSlicer.AxisFlag.setValue('c')
        self.PredictionProbabilityChannels.connect(
            self.opPredictionSlicer.Slices)

        self.opPixelPredictionSlicer = OperatorWrapper(OpMultiArraySlicer2,
                                                       parent=self)
        self.opPixelPredictionSlicer.Input.connect(
            self.precomputed_predictions_pixel_gui.Output)
        self.opPixelPredictionSlicer.AxisFlag.setValue('c')
        self.PixelOnlyPredictionChannels.connect(
            self.opPixelPredictionSlicer.Slices)

        def handleNewInputImage(multislot, index, *args):
            def handleInputReady(slot):
                self.setupCaches(multislot.index(slot))

            multislot[index].notifyReady(handleInputReady)

        self.InputImages.notifyInserted(handleNewInputImage)

        # All input multi-slots should be kept in sync
        # Output multi-slots will auto-sync via the graph
        multiInputs = filter(lambda s: s.level >= 1, self.inputs.values())
        for s1 in multiInputs:
            for s2 in multiInputs:
                if s1 != s2:

                    def insertSlot(a, b, position, finalsize):
                        a.insertSlot(position, finalsize)

                    s1.notifyInserted(partial(insertSlot, s2))

                    def removeSlot(a, b, position, finalsize):
                        a.removeSlot(position, finalsize)

                    s1.notifyRemoved(partial(removeSlot, s2))
    def __init__(self, *args, **kwargs):
        """Build and wire the internal operator graph of OpObjectClassification.

        Creates the training, prediction, relabeling, caching and warning
        sub-operators, connects this operator's input slots to them, and
        forwards their outputs to this operator's output slots. It also:

        - sets the label metadata slots (LabelNames, LabelColors,
          PmapColors) to empty lists;
        - sets Eraser to 100 and DeleteLabel to -1;
        - initializes the label bookkeeping attributes;
        - registers a callback that calls setupCaches() for each new lane
          of SegmentationImages once that lane becomes ready.

        *args and **kwargs are forwarded unchanged to the base-class
        constructor.

        NOTE(review): the connect/setValue calls below are order-sensitive
        graph wiring (presumably a connection can trigger downstream setup
        as soon as a slot becomes ready); do not reorder without testing.
        """
        super(OpObjectClassification, self).__init__(*args, **kwargs)

        # internal operators
        # opkwargs is shared by the lane-wrapped (OpMultiLaneWrapper)
        # sub-operators so they are all parented to this operator.
        opkwargs = dict(parent=self)
        # Training is NOT lane-wrapped: a single OpObjectTrain instance gets
        # ObjectFeatures/LabelInputs directly (presumably all lanes at once).
        self.opTrain = OpObjectTrain(parent=self)
        self.opPredict = OpMultiLaneWrapper(OpObjectPredict, **opkwargs)
        # opLabelsToImage relabels SegmentationImages with the per-object
        # LabelInputs map; opPredictionsToImage does the same with
        # opPredict.Predictions (see the connections below).
        self.opLabelsToImage = OpMultiLaneWrapper(OpRelabelSegmentation,
                                                  **opkwargs)
        self.opPredictionsToImage = OpMultiLaneWrapper(OpRelabelSegmentation,
                                                       **opkwargs)
        self.opPredictionImageCache = OpMultiLaneWrapper(
            OpSlicedBlockedArrayCache, **opkwargs)

        # Same relabeling idea for per-channel probabilities and for the
        # objects opPredict reports as BadObjects.
        self.opProbabilityChannelsToImage = OpMultiLaneWrapper(
            OpMultiRelabelSegmentation, **opkwargs)
        self.opBadObjectsToImage = OpMultiLaneWrapper(OpRelabelSegmentation,
                                                      **opkwargs)
        self.opBadObjectsToWarningMessage = OpBadObjectsToWarningMessage(
            parent=self)

        # Holds the trained classifier. opPredict and the Classifier output
        # both read from this cache rather than from opTrain directly.
        self.classifier_cache = OpValueCache(parent=self)

        # connect inputs
        self.opTrain.Features.connect(self.ObjectFeatures)
        self.opTrain.Labels.connect(self.LabelInputs)
        self.opTrain.FixClassifier.setValue(False)
        self.opTrain.SelectedFeatures.connect(self.SelectedFeatures)

        self.classifier_cache.Input.connect(self.opTrain.Classifier)

        # Find the highest label in all the label images
        # (feeds opPredict.LabelsCount and the NumLabels output below).
        self.opMaxLabel = OpMaxLabel(parent=self)
        self.opMaxLabel.Inputs.connect(self.LabelInputs)

        self.opPredict.Features.connect(self.ObjectFeatures)
        self.opPredict.Classifier.connect(self.classifier_cache.Output)
        self.opPredict.LabelsCount.connect(self.opMaxLabel.Output)
        self.opPredict.SelectedFeatures.connect(self.SelectedFeatures)

        self.opLabelsToImage.Image.connect(self.SegmentationImages)
        self.opLabelsToImage.ObjectMap.connect(self.LabelInputs)
        self.opLabelsToImage.Features.connect(self.ObjectFeatures)

        self.opPredictionsToImage.Image.connect(self.SegmentationImages)
        self.opPredictionsToImage.ObjectMap.connect(self.opPredict.Predictions)
        self.opPredictionsToImage.Features.connect(self.ObjectFeatures)

        #self.opPredictionImageCache.name = "prediction_image_cache"
        # The cache's fixAtCurrent follows FreezePredictions (presumably:
        # while True, the cache does not recompute) — confirm.
        self.opPredictionImageCache.fixAtCurrent.connect(
            self.FreezePredictions)
        self.opPredictionImageCache.Input.connect(
            self.opPredictionsToImage.Output)

        self.opProbabilityChannelsToImage.Image.connect(
            self.SegmentationImages)
        self.opProbabilityChannelsToImage.ObjectMaps.connect(
            self.opPredict.ProbabilityChannels)
        self.opProbabilityChannelsToImage.Features.connect(self.ObjectFeatures)

        class OpWrappedCache(Operator):
            """
            This quick hack is necessary because there's not currently a way to wrap an OperatorWrapper.
            We need to double-wrap the cache, so we need this operator to provide the first level of wrapping.

            Input and Output are level-1 slots. Internally, an OperatorWrapper
            around OpSlicedBlockedArrayCache caches each sub-slot of Input.
            Output is connected straight to that inner operator's Output, so
            execute() and propagateDirty() here are not expected to do any
            work.
            """
            Input = InputSlot(level=1)
            innerBlockShape = InputSlot()
            outerBlockShape = InputSlot()
            fixAtCurrent = InputSlot(value=False)

            Output = OutputSlot(level=1)

            def __init__(self, *args, **kwargs):
                super(OpWrappedCache, self).__init__(*args, **kwargs)
                # One OpSlicedBlockedArrayCache per sub-slot of Input. The
                # block shapes and freeze flag come from this operator's
                # level-0 inputs, presumably shared by every inner cache.
                self._innerOperator = OperatorWrapper(
                    OpSlicedBlockedArrayCache, parent=self)
                self._innerOperator.Input.connect(self.Input)
                self._innerOperator.fixAtCurrent.connect(self.fixAtCurrent)
                self._innerOperator.innerBlockShape.connect(
                    self.innerBlockShape)
                self._innerOperator.outerBlockShape.connect(
                    self.outerBlockShape)
                self.Output.connect(self._innerOperator.Output)

            def execute(self, slot, subindex, roi, destination):
                # Output is forwarded from the inner operator, so requests
                # should never be served by this method.
                # NOTE(review): `assert` is stripped under `python -O`; there
                # this would silently return None instead of failing loudly.
                assert False, "Shouldn't get here."

            def propagateDirty(self, slot, subindex, roi):
                # Output is forwarded from the inner operator, which
                # presumably handles dirty propagation itself.
                pass  # Nothing to do...

        # Wrap the cache for probability channels twice: OpMultiLaneWrapper
        # adds the per-lane level, and OpWrappedCache (level-1 Input) adds the
        # inner level (presumably one image per probability channel).
        # NOTE(review): innerBlockShape/outerBlockShape are not set in this
        # method; presumably setupCaches() configures them — confirm.
        self.opProbChannelsImageCache = OpMultiLaneWrapper(OpWrappedCache,
                                                           parent=self)
        self.opProbChannelsImageCache.Input.connect(
            self.opProbabilityChannelsToImage.Output)
        self.opProbChannelsImageCache.fixAtCurrent.connect(
            self.FreezePredictions)

        self.opBadObjectsToImage.Image.connect(self.SegmentationImages)
        self.opBadObjectsToImage.ObjectMap.connect(self.opPredict.BadObjects)
        self.opBadObjectsToImage.Features.connect(self.ObjectFeatures)

        # The warning text is derived from opTrain.BadObjects (training side).
        # BadObjectImages below is derived from opPredict.BadObjects instead.
        self.opBadObjectsToWarningMessage.BadObjects.connect(
            self.opTrain.BadObjects)

        self.opPredict.InputProbabilities.connect(self.InputProbabilities)

        # Start with no label classes defined.
        self.LabelNames.setValue([])
        self.LabelColors.setValue([])
        self.PmapColors.setValue([])

        # connect outputs
        self.NumLabels.connect(self.opMaxLabel.Output)
        self.LabelImages.connect(self.opLabelsToImage.Output)
        self.Predictions.connect(self.opPredict.Predictions)
        self.Probabilities.connect(self.opPredict.Probabilities)
        self.CachedProbabilities.connect(self.opPredict.CachedProbabilities)
        # Cached vs. uncached views of the same relabeled prediction image.
        self.PredictionImages.connect(self.opPredictionImageCache.Output)
        self.UncachedPredictionImages.connect(self.opPredictionsToImage.Output)
        self.PredictionProbabilityChannels.connect(
            self.opProbChannelsImageCache.Output)
        self.BadObjects.connect(self.opPredict.BadObjects)
        self.BadObjectImages.connect(self.opBadObjectsToImage.Output)
        self.Warnings.connect(self.opBadObjectsToWarningMessage.WarningMessage)

        self.Classifier.connect(self.classifier_cache.Output)

        # Pass the segmentation input straight through as an output.
        self.SegmentationImagesOut.connect(self.SegmentationImages)

        # NOTE(review): 100 / -1 presumably follow the usual label-array
        # convention (100 = eraser value, -1 = "no label to delete") — confirm.
        self.Eraser.setValue(100)
        self.DeleteLabel.setValue(-1)

        # Label bookkeeping state, initialized empty/False here. It is not
        # otherwise used in this method.
        self._labelBBoxes = []
        self._ambiguousLabels = []
        self._needLabelTransfer = False

        # When a lane is inserted into SegmentationImages, wait until that
        # lane's slot is ready, then call setupCaches() with its index. The
        # index is looked up at ready-time via multislot.index(slot),
        # presumably so it stays correct if lanes shift in between.
        def handleNewInputImage(multislot, index, *args):
            def handleInputReady(slot):
                self.setupCaches(multislot.index(slot))

            multislot[index].notifyReady(handleInputReady)

        self.SegmentationImages.notifyInserted(handleNewInputImage)