def __init__(self, *args, **kwargs):
        """
        Set up a mock pixel classifier with two fixed label classes.

        Assigns default label names/colors, wires a blocked training operator
        and a classifier value cache to this operator's slots, and builds a
        synthetic two-channel prediction volume (``self.predictionData``).

        All positional and keyword arguments are forwarded unchanged to the
        parent class constructor.
        """
        super(OpMockPixelClassifier, self).__init__(*args, **kwargs)

        self.LabelNames.setValue( ["Membrane", "Cytoplasm"] )
        self.LabelColors.setValue( [(255,0,0), (0,255,0)] ) # Red, Green
        self.PmapColors.setValue( [(255,0,0), (0,255,0)] ) # Red, Green

        self._data = []
        self.dataShape = (1,10,100,100,1)
        self.prediction_shape = self.dataShape[:-1] + (2,) # Hard-coded to provide 2 classes

        self.FreezePredictions.setValue(False)

        # Training operator, fed from this operator's own slots.
        self.opClassifier = OpTrainClassifierBlocked(graph=self.graph, parent=self)
        self.opClassifier.ClassifierFactory.connect( self.ClassifierFactory )
        self.opClassifier.Labels.connect(self.LabelImages)
        self.opClassifier.nonzeroLabelBlocks.connect(self.NonzeroLabelBlocks)
        self.opClassifier.MaxLabel.setValue(2)

        # Cache for the trained classifier.
        self.classifier_cache = OpValueCache(graph=self.graph, parent=self)
        self.classifier_cache.Input.connect( self.opClassifier.Classifier )

        # Synthetic "probabilities": p1 is the sum of each voxel's coordinates,
        # divided by the largest possible coordinate sum so that it spans [0, 1].
        # For the shape above that maximum is 0+9+99+99+0 == 207; it is derived
        # from dataShape here so the normalization stays correct if the shape is
        # edited. max(..., 1) avoids a zero divisor for an all-singleton shape.
        max_index_sum = sum(extent - 1 for extent in self.dataShape)
        p1 = numpy.indices(self.dataShape).sum(0) / float(max(max_index_sum, 1))
        p2 = 1 - p1

        # Stack the two complementary maps along the last (5th) axis.
        self.predictionData = numpy.concatenate((p1,p2), axis=4)
# Example #2
# 0
    def __init__(self, *args, **kwargs):
        """
        Instantiate all internal operators and connect them together.

        The constructor:
          * sets default values for several input slots,
          * wires the labeling pipeline, the training operator, the classifier
            cache, the prediction pipeline and the feature-matrix caches,
          * forwards the prediction pipeline's outputs to the top-level outputs,
          * registers callbacks that react to inserted/resized/ready slots, and
          * keeps all multi-level input slots the same length.

        All positional and keyword arguments are forwarded unchanged to the
        parent class constructor.
        """
        super(OpPixelClassification, self).__init__(*args, **kwargs)

        # Default values for some input slots
        self.FreezePredictions.setValue(True)
        self.LabelNames.setValue([])
        self.LabelColors.setValue([])
        self.PmapColors.setValue([])

        # SPECIAL connection: The LabelInputs slot doesn't get its data
        #  from the InputImages slot, but its shape must match.
        self.LabelInputs.connect(self.InputImages)

        # Hook up Labeling Pipeline
        self.opLabelPipeline = OpMultiLaneWrapper(
            OpLabelPipeline,
            parent=self,
            broadcastingSlotNames=['DeleteLabel'])
        self.opLabelPipeline.RawImage.connect(self.InputImages)
        self.opLabelPipeline.LabelInput.connect(self.LabelInputs)
        self.opLabelPipeline.DeleteLabel.setValue(-1)
        self.LabelImages.connect(self.opLabelPipeline.Output)
        self.NonzeroLabelBlocks.connect(self.opLabelPipeline.nonzeroBlocks)

        # Hook up the Training operator
        self.opTrain = OpTrainClassifierBlocked(parent=self)
        self.opTrain.ClassifierFactory.connect(self.ClassifierFactory)
        self.opTrain.Labels.connect(self.opLabelPipeline.Output)
        self.opTrain.Images.connect(self.FeatureImages)
        self.opTrain.nonzeroLabelBlocks.connect(
            self.opLabelPipeline.nonzeroBlocks)

        # Hook up the Classifier Cache
        # The classifier is cached here to allow serializers to force in
        #   a pre-calculated classifier (loaded from disk)
        self.classifier_cache = OpValueCache(parent=self)
        self.classifier_cache.name = "OpPixelClassification.classifier_cache"
        self.classifier_cache.inputs["Input"].connect(
            self.opTrain.outputs['Classifier'])
        self.classifier_cache.inputs["fixAtCurrent"].connect(
            self.FreezePredictions)
        self.Classifier.connect(self.classifier_cache.Output)

        # Hook up the prediction pipeline inputs
        self.opPredictionPipeline = OpMultiLaneWrapper(OpPredictionPipeline,
                                                       parent=self)
        self.opPredictionPipeline.FeatureImages.connect(self.FeatureImages)
        self.opPredictionPipeline.CachedFeatureImages.connect(
            self.CachedFeatureImages)
        self.opPredictionPipeline.Classifier.connect(
            self.classifier_cache.Output)
        self.opPredictionPipeline.FreezePredictions.connect(
            self.FreezePredictions)
        self.opPredictionPipeline.PredictionsFromDisk.connect(
            self.PredictionsFromDisk)
        self.opPredictionPipeline.PredictionMask.connect(self.PredictionMasks)

        # Feature Selection Stuff
        self.opFeatureMatrixCaches = OpMultiLaneWrapper(OpFeatureMatrixCache,
                                                        parent=self)
        self.opFeatureMatrixCaches.LabelImage.connect(
            self.opLabelPipeline.Output)
        self.opFeatureMatrixCaches.FeatureImage.connect(self.FeatureImages)
        self.opFeatureMatrixCaches.LabelImage.setDirty(
        )  # do I still need this?

        def _updateNumClasses(*args):
            """
            When the number of labels changes, we MUST make sure that the prediction image changes its shape (the number of channels).
            Since setupOutputs is not called for mere dirty notifications, but is called in response to setValue(),
            we use this function to call setValue().
            """
            numClasses = len(self.LabelNames.value)
            self.opTrain.MaxLabel.setValue(numClasses)
            self.opPredictionPipeline.NumClasses.setValue(numClasses)
            self.NumClasses.setValue(numClasses)

        self.LabelNames.notifyDirty(_updateNumClasses)

        # Prediction pipeline outputs -> Top-level outputs
        self.PredictionProbabilities.connect(
            self.opPredictionPipeline.PredictionProbabilities)
        self.PredictionProbabilitiesUint8.connect(
            self.opPredictionPipeline.PredictionProbabilitiesUint8)
        self.CachedPredictionProbabilities.connect(
            self.opPredictionPipeline.CachedPredictionProbabilities)
        self.HeadlessPredictionProbabilities.connect(
            self.opPredictionPipeline.HeadlessPredictionProbabilities)
        self.HeadlessUint8PredictionProbabilities.connect(
            self.opPredictionPipeline.HeadlessUint8PredictionProbabilities)
        self.PredictionProbabilityChannels.connect(
            self.opPredictionPipeline.PredictionProbabilityChannels)
        self.SegmentationChannels.connect(
            self.opPredictionPipeline.SegmentationChannels)
        self.UncertaintyEstimate.connect(
            self.opPredictionPipeline.UncertaintyEstimate)
        self.SimpleSegmentation.connect(
            self.opPredictionPipeline.SimpleSegmentation)
        self.HeadlessUncertaintyEstimate.connect(
            self.opPredictionPipeline.HeadlessUncertaintyEstimate)

        def inputResizeHandler(slot, oldsize, newsize):
            # When the last input image is removed, explicitly shrink these
            # output multi-slots to zero length as well.
            if newsize == 0:
                self.LabelImages.resize(0)
                self.NonzeroLabelBlocks.resize(0)
                self.PredictionProbabilities.resize(0)
                self.CachedPredictionProbabilities.resize(0)

        self.InputImages.notifyResized(inputResizeHandler)

        # Debug assertions: Check to make sure the non-wrapped operators stayed that way.
        assert self.opTrain.Images.operator == self.opTrain

        def handleNewInputImage(multislot, index, *args):
            def handleInputReady(slot):
                # NOTE(review): `index` is the position captured at insertion
                # time, while setupCaches() below looks the position up again
                # via multislot.index(slot). If lanes can be inserted/removed
                # before this one, the two may differ -- confirm this is intended.
                self._checkConstraints(index)
                self.setupCaches(multislot.index(slot))

            multislot[index].notifyReady(handleInputReady)

        self.InputImages.notifyInserted(handleNewInputImage)

        # If any feature image changes shape, we need to verify that the
        #  channels are consistent with the currently cached classifier
        # Otherwise, delete the currently cached classifier.
        def handleNewFeatureImage(multislot, index, *args):
            def handleFeatureImageReady(slot):
                def handleFeatureMetaChanged(slot):
                    # Only act while the cache is frozen and actually holds a value.
                    if (self.classifier_cache.fixAtCurrent.value
                            and self.classifier_cache.Output.ready()
                            and slot.meta.shape is not None):
                        classifier = self.classifier_cache.Output.value
                        channel_names = slot.meta.channel_names
                        if classifier and classifier.feature_names != channel_names:
                            self.classifier_cache.resetValue()

                slot.notifyMetaChanged(handleFeatureMetaChanged)

            multislot[index].notifyReady(handleFeatureImageReady)

        self.FeatureImages.notifyInserted(handleNewFeatureImage)

        def handleNewMaskImage(multislot, index, *args):
            def handleInputReady(slot):
                self._checkConstraints(index)

            multislot[index].notifyReady(handleInputReady)

        self.PredictionMasks.notifyInserted(handleNewMaskImage)

        # All input multi-slots should be kept in sync
        # Output multi-slots will auto-sync via the graph
        #
        # This must be a real list: on Python 3, filter() returns a one-shot
        # iterator, which the inner loop below would exhaust during the first
        # outer iteration, so only the first multi-slot would get its handlers.
        multiInputs = [s for s in self.inputs.values() if s.level >= 1]

        # `a` is the slot to update (bound via partial); `b` is the slot that
        # emitted the notification and is not used.
        def insertSlot(a, b, position, finalsize):
            a.insertSlot(position, finalsize)

        def removeSlot(a, b, position, finalsize):
            a.removeSlot(position, finalsize)

        for s1 in multiInputs:
            for s2 in multiInputs:
                if s1 != s2:
                    s1.notifyInserted(partial(insertSlot, s2))
                    s1.notifyRemoved(partial(removeSlot, s2))