def __init__(self, *args, **kwargs):
    """
    Set up a mock two-class pixel classifier.

    Fills the label/colour slots with fixed values, wires a blocked
    training operator and a value cache for the resulting classifier,
    and precomputes a synthetic two-channel prediction volume in
    ``self.predictionData``.

    All positional and keyword arguments are passed on unchanged to the
    parent constructor.
    """
    super(OpMockPixelClassifier, self).__init__(*args, **kwargs)

    # Fixed label configuration: two classes, drawn in red and green.
    red, green = (255, 0, 0), (0, 255, 0)
    self.LabelNames.setValue(["Membrane", "Cytoplasm"])
    self.LabelColors.setValue([red, green])
    self.PmapColors.setValue([red, green])

    self._data = []
    self.dataShape = (1, 10, 100, 100, 1)
    # Same shape as the data, except the last axis holds the 2 hard-coded
    # classes.
    num_classes = 2
    self.prediction_shape = self.dataShape[:-1] + (num_classes,)

    self.FreezePredictions.setValue(False)

    # Training operator fed from this operator's own slots.
    trainer = OpTrainClassifierBlocked(graph=self.graph, parent=self)
    self.opClassifier = trainer
    trainer.ClassifierFactory.connect(self.ClassifierFactory)
    trainer.Labels.connect(self.LabelImages)
    trainer.nonzeroLabelBlocks.connect(self.NonzeroLabelBlocks)
    trainer.MaxLabel.setValue(num_classes)

    # Cache for the classifier produced by the training operator.
    cache = OpValueCache(graph=self.graph, parent=self)
    self.classifier_cache = cache
    cache.Input.connect(trainer.Classifier)

    # Synthetic predictions: the first channel is the sum of each voxel's
    # indices divided by 207.0 (the largest possible index sum for
    # dataShape: 0 + 9 + 99 + 99 + 0), the second is its complement.
    index_sum = numpy.indices(self.dataShape).sum(0)
    first_channel = index_sum / 207.0
    second_channel = 1 - first_channel
    self.predictionData = numpy.concatenate(
        (first_channel, second_channel), axis=4)
def __init__(self, *args, **kwargs):
    """
    Instantiate all internal operators and connect them together.

    The constructor:

    * sets defaults for ``FreezePredictions``, ``LabelNames``,
      ``LabelColors`` and ``PmapColors``;
    * builds the labeling pipeline, the training operator, the classifier
      cache, the prediction pipeline and the feature-matrix caches;
    * forwards the prediction pipeline outputs to the top-level output
      slots;
    * registers callbacks that react to label-name changes, input
      resizing, newly inserted input/feature/mask lanes, and that keep
      every multi-level input slot the same length.

    All positional and keyword arguments are passed on unchanged to the
    parent constructor.
    """
    super(OpPixelClassification, self).__init__(*args, **kwargs)

    # Default values for some input slots
    self.FreezePredictions.setValue(True)
    self.LabelNames.setValue([])
    self.LabelColors.setValue([])
    self.PmapColors.setValue([])

    # SPECIAL connection: The LabelInputs slot doesn't get its data
    # from the InputImages slot, but its shape must match.
    self.LabelInputs.connect(self.InputImages)

    # Hook up Labeling Pipeline
    self.opLabelPipeline = OpMultiLaneWrapper(
        OpLabelPipeline, parent=self, broadcastingSlotNames=['DeleteLabel'])
    self.opLabelPipeline.RawImage.connect(self.InputImages)
    self.opLabelPipeline.LabelInput.connect(self.LabelInputs)
    self.opLabelPipeline.DeleteLabel.setValue(-1)
    self.LabelImages.connect(self.opLabelPipeline.Output)
    self.NonzeroLabelBlocks.connect(self.opLabelPipeline.nonzeroBlocks)

    # Hook up the Training operator
    self.opTrain = OpTrainClassifierBlocked(parent=self)
    self.opTrain.ClassifierFactory.connect(self.ClassifierFactory)
    self.opTrain.Labels.connect(self.opLabelPipeline.Output)
    self.opTrain.Images.connect(self.FeatureImages)
    self.opTrain.nonzeroLabelBlocks.connect(
        self.opLabelPipeline.nonzeroBlocks)

    # Hook up the Classifier Cache
    # The classifier is cached here to allow serializers to force in
    # a pre-calculated classifier (loaded from disk)
    self.classifier_cache = OpValueCache(parent=self)
    self.classifier_cache.name = "OpPixelClassification.classifier_cache"
    self.classifier_cache.inputs["Input"].connect(
        self.opTrain.outputs['Classifier'])
    self.classifier_cache.inputs["fixAtCurrent"].connect(
        self.FreezePredictions)
    self.Classifier.connect(self.classifier_cache.Output)

    # Hook up the prediction pipeline inputs
    self.opPredictionPipeline = OpMultiLaneWrapper(OpPredictionPipeline,
                                                   parent=self)
    self.opPredictionPipeline.FeatureImages.connect(self.FeatureImages)
    self.opPredictionPipeline.CachedFeatureImages.connect(
        self.CachedFeatureImages)
    self.opPredictionPipeline.Classifier.connect(
        self.classifier_cache.Output)
    self.opPredictionPipeline.FreezePredictions.connect(
        self.FreezePredictions)
    self.opPredictionPipeline.PredictionsFromDisk.connect(
        self.PredictionsFromDisk)
    self.opPredictionPipeline.PredictionMask.connect(self.PredictionMasks)

    # Feature Selection Stuff
    self.opFeatureMatrixCaches = OpMultiLaneWrapper(OpFeatureMatrixCache,
                                                    parent=self)
    self.opFeatureMatrixCaches.LabelImage.connect(
        self.opLabelPipeline.Output)
    self.opFeatureMatrixCaches.FeatureImage.connect(self.FeatureImages)
    self.opFeatureMatrixCaches.LabelImage.setDirty(
    )  # do I still need this?

    def _updateNumClasses(*args):
        """
        When the number of labels changes, we MUST make sure that the
        prediction image changes its shape (the number of channels).
        Since setupOutputs is not called for mere dirty notifications,
        but is called in response to setValue(), we use this function
        to call setValue().
        """
        numClasses = len(self.LabelNames.value)
        self.opTrain.MaxLabel.setValue(numClasses)
        self.opPredictionPipeline.NumClasses.setValue(numClasses)
        self.NumClasses.setValue(numClasses)

    self.LabelNames.notifyDirty(_updateNumClasses)

    # Prediction pipeline outputs -> Top-level outputs
    self.PredictionProbabilities.connect(
        self.opPredictionPipeline.PredictionProbabilities)
    self.PredictionProbabilitiesUint8.connect(
        self.opPredictionPipeline.PredictionProbabilitiesUint8)
    self.CachedPredictionProbabilities.connect(
        self.opPredictionPipeline.CachedPredictionProbabilities)
    self.HeadlessPredictionProbabilities.connect(
        self.opPredictionPipeline.HeadlessPredictionProbabilities)
    self.HeadlessUint8PredictionProbabilities.connect(
        self.opPredictionPipeline.HeadlessUint8PredictionProbabilities)
    self.PredictionProbabilityChannels.connect(
        self.opPredictionPipeline.PredictionProbabilityChannels)
    self.SegmentationChannels.connect(
        self.opPredictionPipeline.SegmentationChannels)
    self.UncertaintyEstimate.connect(
        self.opPredictionPipeline.UncertaintyEstimate)
    self.SimpleSegmentation.connect(
        self.opPredictionPipeline.SimpleSegmentation)
    self.HeadlessUncertaintyEstimate.connect(
        self.opPredictionPipeline.HeadlessUncertaintyEstimate)

    def inputResizeHandler(slot, oldsize, newsize):
        # When the last input lane disappears, explicitly empty these
        # output multi-slots as well.
        if newsize == 0:
            self.LabelImages.resize(0)
            self.NonzeroLabelBlocks.resize(0)
            self.PredictionProbabilities.resize(0)
            self.CachedPredictionProbabilities.resize(0)

    self.InputImages.notifyResized(inputResizeHandler)

    # Debug assertions: Check to make sure the non-wrapped operators
    # stayed that way.
    assert self.opTrain.Images.operator == self.opTrain

    def handleNewInputImage(multislot, index, *args):
        def handleInputReady(slot):
            # NOTE(review): _checkConstraints gets the index captured at
            # insertion time, while setupCaches looks the lane up again
            # via multislot.index(slot) -- confirm this is intentional.
            self._checkConstraints(index)
            self.setupCaches(multislot.index(slot))

        multislot[index].notifyReady(handleInputReady)

    self.InputImages.notifyInserted(handleNewInputImage)

    # If any feature image changes shape, we need to verify that the
    # channels are consistent with the currently cached classifier
    # Otherwise, delete the currently cached classifier.
    def handleNewFeatureImage(multislot, index, *args):
        def handleFeatureImageReady(slot):
            def handleFeatureMetaChanged(slot):
                if (self.classifier_cache.fixAtCurrent.value
                        and self.classifier_cache.Output.ready()
                        and slot.meta.shape is not None):
                    classifier = self.classifier_cache.Output.value
                    channel_names = slot.meta.channel_names
                    if (classifier
                            and classifier.feature_names != channel_names):
                        self.classifier_cache.resetValue()

            slot.notifyMetaChanged(handleFeatureMetaChanged)

        multislot[index].notifyReady(handleFeatureImageReady)

    self.FeatureImages.notifyInserted(handleNewFeatureImage)

    def handleNewMaskImage(multislot, index, *args):
        def handleInputReady(slot):
            self._checkConstraints(index)

        multislot[index].notifyReady(handleInputReady)

    self.PredictionMasks.notifyInserted(handleNewMaskImage)

    # All input multi-slots should be kept in sync
    # Output multi-slots will auto-sync via the graph
    def insertSlot(a, b, position, finalsize):
        # 'a' is the peer slot bound via partial(); 'b' is the slot that
        # emitted the notification and is not used.
        a.insertSlot(position, finalsize)

    def removeSlot(a, b, position, finalsize):
        a.removeSlot(position, finalsize)

    # This must be a real list: it is iterated by two nested loops, and a
    # one-shot iterator (what filter() returns on Python 3) would be
    # exhausted by the first pass of the inner loop, leaving most slot
    # pairs unsynchronized.
    multiInputs = [s for s in self.inputs.values() if s.level >= 1]
    for s1 in multiInputs:
        for s2 in multiInputs:
            if s1 != s2:
                s1.notifyInserted(partial(insertSlot, s2))
                s1.notifyRemoved(partial(removeSlot, s2))