# Example No. 1
# 0
    def __init__(self, shell, headless, workflow_cmdline_args, project_creation_args, appendBatchOperators=True, *args, **kwargs):
        """
        Create the pixel classification workflow and all of its applets.

        :param shell: forwarded to the Workflow base class.
        :param headless: forwarded to the Workflow base class.
        :param workflow_cmdline_args: workflow-specific command-line args for the current session.
            Args not recognized here are offered to the batch export and batch input applets
            (when they exist); anything still left over is logged as unused.
        :param project_creation_args: args saved in the project file when the project was created.
            Only ``--filter`` is taken from these: the filter implementation is fixed at creation.
        :param appendBatchOperators: when True, also create the batch input/results applets
            and connect the batch prediction pipeline.
        """
        # Create a graph to be shared by all operators
        graph = Graph()
        super( PixelClassificationWorkflow, self ).__init__( shell, headless, workflow_cmdline_args, project_creation_args, graph=graph, *args, **kwargs )
        self._applets = []
        self._workflow_cmdline_args = workflow_cmdline_args

        # Parse workflow-specific command-line args
        parser = argparse.ArgumentParser()
        parser.add_argument('--filter', help="pixel feature filter implementation.", choices=['Original', 'Refactored', 'Interpolated'], default='Original')
        parser.add_argument('--print-labels-by-slice', help="Print the number of labels for each Z-slice of each image.", action="store_true")
        parser.add_argument('--label-search-value', help="If provided, only this value is considered when using --print-labels-by-slice", default=0, type=int)
        parser.add_argument('--generate-random-labels', help="Add random labels to the project file.", action="store_true")
        parser.add_argument('--random-label-value', help="The label value to use injecting random labels", default=1, type=int)
        parser.add_argument('--random-label-count', help="The number of random labels to inject via --generate-random-labels", default=2000, type=int)
        parser.add_argument('--retrain', help="Re-train the classifier based on labels stored in project file, and re-save.", action="store_true")

        # Parse the creation args: These were saved to the project file when this project was first created.
        parsed_creation_args, unused_args = parser.parse_known_args(project_creation_args)
        self.filter_implementation = parsed_creation_args.filter

        # For the current session, --filter must default to None so we can tell whether the
        # user actually passed it.  With the 'Original' default, every project created with a
        # different filter would log the "Ignoring new --filter" error on each load.
        parser.set_defaults(filter=None)

        # Parse the cmdline args for the current session.
        parsed_args, unused_args = parser.parse_known_args(workflow_cmdline_args)
        self.print_labels_by_slice = parsed_args.print_labels_by_slice
        self.label_search_value = parsed_args.label_search_value
        self.generate_random_labels = parsed_args.generate_random_labels
        self.random_label_value = parsed_args.random_label_value
        self.random_label_count = parsed_args.random_label_count
        self.retrain = parsed_args.retrain

        if parsed_args.filter and parsed_args.filter != parsed_creation_args.filter:
            logger.error("Ignoring new --filter setting.  Filter implementation cannot be changed after initial project creation.")

        data_instructions = "Select your input data using the 'Raw Data' tab shown on the right.\n\n"\
                            "Power users: Optionally use the 'Prediction Mask' tab to supply a binary image that tells ilastik where it should avoid computations you don't need."

        # Applets for training (interactive) workflow
        self.projectMetadataApplet = ProjectMetadataApplet()
        self.dataSelectionApplet = DataSelectionApplet( self,
                                                        "Input Data",
                                                        "Input Data",
                                                        supportIlastik05Import=True,
                                                        batchDataGui=False,
                                                        instructionText=data_instructions )
        opDataSelection = self.dataSelectionApplet.topLevelOperator

        # see role constants, above
        opDataSelection.DatasetRoles.setValue( ['Raw Data', 'Prediction Mask'] )

        self.featureSelectionApplet = FeatureSelectionApplet(self, "Feature Selection", "FeatureSelections", self.filter_implementation)

        self.pcApplet = PixelClassificationApplet( self, "PixelClassification" )
        opClassify = self.pcApplet.topLevelOperator

        # The export applet gets its colors/names from the classifier and its working
        # directory from the data selection applet.
        self.dataExportApplet = PixelClassificationDataExportApplet(self, "Prediction Export")
        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.PmapColors.connect( opClassify.PmapColors )
        opDataExport.LabelNames.connect( opClassify.LabelNames )
        opDataExport.WorkingDirectory.connect( opDataSelection.WorkingDirectory )
        opDataExport.SelectionNames.setValue( self.EXPORT_NAMES )

        # Expose for shell
        self._applets.append(self.projectMetadataApplet)
        self._applets.append(self.dataSelectionApplet)
        self._applets.append(self.featureSelectionApplet)
        self._applets.append(self.pcApplet)
        self._applets.append(self.dataExportApplet)

        self._batch_input_args = None
        self._batch_export_args = None

        self.batchInputApplet = None
        self.batchResultsApplet = None
        if appendBatchOperators:
            # Create applets for batch workflow
            self.batchInputApplet = DataSelectionApplet(self, "Batch Prediction Input Selections", "Batch Inputs", supportIlastik05Import=False, batchDataGui=True)
            self.batchResultsApplet = PixelClassificationDataExportApplet(self, "Batch Prediction Output Locations", isBatch=True)

            # Expose in shell
            self._applets.append(self.batchInputApplet)
            self._applets.append(self.batchResultsApplet)

            # Connect batch workflow (NOT lane-based)
            self._initBatchWorkflow()

            if unused_args:
                # We parse the export setting args first.  All remaining args are considered input files by the input applet.
                self._batch_export_args, unused_args = self.batchResultsApplet.parse_known_cmdline_args( unused_args )
                self._batch_input_args, unused_args = self.batchInputApplet.parse_known_cmdline_args( unused_args )

        if unused_args:
            logger.warning("Unused command-line args: {}".format( unused_args ))
    def __init__(self, shell, headless, workflow_cmdline_args,
                 project_creation_args, *args, **kwargs):
        """
        Create the pixel classification workflow, its applets and the batch processing applet.

        :param shell: forwarded to the Workflow base class.
        :param headless: forwarded to the Workflow base class.
        :param workflow_cmdline_args: workflow-specific command-line args for the current session.
            Args not recognized here are offered to the data export applet and then to the
            batch processing applet; anything still left over is logged as unused.
        :param project_creation_args: args saved in the project file when the project was created.
            Only ``--filter`` is taken from these: the filter implementation is fixed at creation.
        """
        # Create a graph to be shared by all operators
        graph = Graph()
        super(PixelClassificationWorkflow,
              self).__init__(shell,
                             headless,
                             workflow_cmdline_args,
                             project_creation_args,
                             graph=graph,
                             *args,
                             **kwargs)
        self.stored_classifier = None
        self._applets = []
        self._workflow_cmdline_args = workflow_cmdline_args
        # Parse workflow-specific command-line args
        parser = argparse.ArgumentParser()
        parser.add_argument('--filter',
                            help="pixel feature filter implementation.",
                            choices=['Original', 'Refactored', 'Interpolated'],
                            default='Original')
        parser.add_argument(
            '--print-labels-by-slice',
            help="Print the number of labels for each Z-slice of each image.",
            action="store_true")
        parser.add_argument(
            '--label-search-value',
            help=
            "If provided, only this value is considered when using --print-labels-by-slice",
            default=0,
            type=int)
        parser.add_argument('--generate-random-labels',
                            help="Add random labels to the project file.",
                            action="store_true")
        parser.add_argument(
            '--random-label-value',
            help="The label value to use injecting random labels",
            default=1,
            type=int)
        parser.add_argument(
            '--random-label-count',
            help=
            "The number of random labels to inject via --generate-random-labels",
            default=2000,
            type=int)
        parser.add_argument(
            '--retrain',
            help=
            "Re-train the classifier based on labels stored in project file, and re-save.",
            action="store_true")
        parser.add_argument('--tree-count',
                            help='Number of trees for Vigra RF classifier.',
                            type=int)
        parser.add_argument('--variable-importance-path',
                            help='Location of variable-importance table.',
                            type=str)
        parser.add_argument(
            '--label-proportion',
            help='Proportion of feature-pixels used to train the classifier.',
            type=float)

        # Parse the creation args: These were saved to the project file when this project was first created.
        parsed_creation_args, unused_args = parser.parse_known_args(
            project_creation_args)
        self.filter_implementation = parsed_creation_args.filter

        # For the current session, --filter must default to None so we can tell whether the
        # user actually passed it.  With the 'Original' default, every project created with a
        # different filter would log the "Ignoring new --filter" error on each load.
        parser.set_defaults(filter=None)

        # Parse the cmdline args for the current session.
        parsed_args, unused_args = parser.parse_known_args(
            workflow_cmdline_args)
        self.print_labels_by_slice = parsed_args.print_labels_by_slice
        self.label_search_value = parsed_args.label_search_value
        self.generate_random_labels = parsed_args.generate_random_labels
        self.random_label_value = parsed_args.random_label_value
        self.random_label_count = parsed_args.random_label_count
        self.retrain = parsed_args.retrain
        self.tree_count = parsed_args.tree_count
        self.variable_importance_path = parsed_args.variable_importance_path
        self.label_proportion = parsed_args.label_proportion

        if parsed_args.filter and parsed_args.filter != parsed_creation_args.filter:
            logger.error(
                "Ignoring new --filter setting.  Filter implementation cannot be changed after initial project creation."
            )

        # Applets for training (interactive) workflow
        self.dataSelectionApplet = self.createDataSelectionApplet()
        opDataSelection = self.dataSelectionApplet.topLevelOperator

        # see role constants, above
        opDataSelection.DatasetRoles.setValue(
            PixelClassificationWorkflow.ROLE_NAMES)

        self.featureSelectionApplet = self.createFeatureSelectionApplet()

        self.pcApplet = self.createPixelClassificationApplet()
        opClassify = self.pcApplet.topLevelOperator

        # The export applet gets its colors/names from the classifier and its working
        # directory from the data selection applet.
        self.dataExportApplet = PixelClassificationDataExportApplet(
            self, "Prediction Export")
        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.PmapColors.connect(opClassify.PmapColors)
        opDataExport.LabelNames.connect(opClassify.LabelNames)
        opDataExport.WorkingDirectory.connect(opDataSelection.WorkingDirectory)
        opDataExport.SelectionNames.setValue(self.EXPORT_NAMES)

        # Expose for shell
        self._applets.append(self.dataSelectionApplet)
        self._applets.append(self.featureSelectionApplet)
        self._applets.append(self.pcApplet)
        self._applets.append(self.dataExportApplet)

        # Hook the workflow's whole-export callbacks into the export applet.
        self.dataExportApplet.prepare_for_entire_export = self.prepare_for_entire_export
        self.dataExportApplet.post_process_entire_export = self.post_process_entire_export

        self.batchProcessingApplet = BatchProcessingApplet(
            self, "Batch Processing", self.dataSelectionApplet,
            self.dataExportApplet)

        self._applets.append(self.batchProcessingApplet)
        if unused_args:
            # We parse the export setting args first.  All remaining args are considered input files by the input applet.
            self._batch_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args(
                unused_args)
            self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args(
                unused_args)
        else:
            self._batch_input_args = None
            self._batch_export_args = None

        if unused_args:
            logger.warning("Unused command-line args: {}".format(unused_args))
# Example No. 3
# 0
class PixelClassificationWorkflow(Workflow):
    
    workflowName = "Pixel Classification"
    workflowDescription = "This is obviously self-explanatory."
    defaultAppletIndex = 1 # show DataSelection by default
    
    DATA_ROLE_RAW = 0
    DATA_ROLE_PREDICTION_MASK = 1
    
    EXPORT_NAMES = ['Probabilities', 'Simple Segmentation', 'Uncertainty', 'Features']
    
    @property
    def applets(self):
        """The applets this workflow exposes to the shell (the list built up in ``__init__``)."""
        applet_list = self._applets
        return applet_list

    @property
    def imageNameListSlot(self):
        """The ``ImageName`` slot of the (training) data selection applet's top-level operator."""
        op_data_selection = self.dataSelectionApplet.topLevelOperator
        return op_data_selection.ImageName

    def __init__(self, shell, headless, workflow_cmdline_args, project_creation_args, appendBatchOperators=True, *args, **kwargs):
        """
        Create the pixel classification workflow and all of its applets.

        :param shell: forwarded to the Workflow base class.
        :param headless: forwarded to the Workflow base class.
        :param workflow_cmdline_args: workflow-specific command-line args for the current session.
            Args not recognized here are offered to the batch export and batch input applets
            (when they exist); anything still left over is logged as unused.
        :param project_creation_args: args saved in the project file when the project was created.
            Only ``--filter`` is taken from these: the filter implementation is fixed at creation.
        :param appendBatchOperators: when True, also create the batch input/results applets
            and connect the batch prediction pipeline (see ``_initBatchWorkflow``).
        """
        # Create a graph to be shared by all operators
        graph = Graph()
        super( PixelClassificationWorkflow, self ).__init__( shell, headless, workflow_cmdline_args, project_creation_args, graph=graph, *args, **kwargs )
        self._applets = []
        self._workflow_cmdline_args = workflow_cmdline_args

        # Parse workflow-specific command-line args
        parser = argparse.ArgumentParser()
        parser.add_argument('--filter', help="pixel feature filter implementation.", choices=['Original', 'Refactored', 'Interpolated'], default='Original')
        parser.add_argument('--print-labels-by-slice', help="Print the number of labels for each Z-slice of each image.", action="store_true")
        parser.add_argument('--label-search-value', help="If provided, only this value is considered when using --print-labels-by-slice", default=0, type=int)
        parser.add_argument('--generate-random-labels', help="Add random labels to the project file.", action="store_true")
        parser.add_argument('--random-label-value', help="The label value to use injecting random labels", default=1, type=int)
        parser.add_argument('--random-label-count', help="The number of random labels to inject via --generate-random-labels", default=2000, type=int)
        parser.add_argument('--retrain', help="Re-train the classifier based on labels stored in project file, and re-save.", action="store_true")

        # Parse the creation args: These were saved to the project file when this project was first created.
        parsed_creation_args, unused_args = parser.parse_known_args(project_creation_args)
        self.filter_implementation = parsed_creation_args.filter

        # For the current session, --filter must default to None so we can tell whether the
        # user actually passed it.  With the 'Original' default, every project created with a
        # different filter would log the "Ignoring new --filter" error on each load.
        parser.set_defaults(filter=None)

        # Parse the cmdline args for the current session.
        parsed_args, unused_args = parser.parse_known_args(workflow_cmdline_args)
        self.print_labels_by_slice = parsed_args.print_labels_by_slice
        self.label_search_value = parsed_args.label_search_value
        self.generate_random_labels = parsed_args.generate_random_labels
        self.random_label_value = parsed_args.random_label_value
        self.random_label_count = parsed_args.random_label_count
        self.retrain = parsed_args.retrain

        if parsed_args.filter and parsed_args.filter != parsed_creation_args.filter:
            logger.error("Ignoring new --filter setting.  Filter implementation cannot be changed after initial project creation.")

        data_instructions = "Select your input data using the 'Raw Data' tab shown on the right.\n\n"\
                            "Power users: Optionally use the 'Prediction Mask' tab to supply a binary image that tells ilastik where it should avoid computations you don't need."

        # Applets for training (interactive) workflow
        self.projectMetadataApplet = ProjectMetadataApplet()
        self.dataSelectionApplet = DataSelectionApplet( self,
                                                        "Input Data",
                                                        "Input Data",
                                                        supportIlastik05Import=True,
                                                        batchDataGui=False,
                                                        instructionText=data_instructions )
        opDataSelection = self.dataSelectionApplet.topLevelOperator

        # Role order must match DATA_ROLE_RAW (0) and DATA_ROLE_PREDICTION_MASK (1).
        opDataSelection.DatasetRoles.setValue( ['Raw Data', 'Prediction Mask'] )

        self.featureSelectionApplet = FeatureSelectionApplet(self, "Feature Selection", "FeatureSelections", self.filter_implementation)

        self.pcApplet = PixelClassificationApplet( self, "PixelClassification" )
        opClassify = self.pcApplet.topLevelOperator

        # The export applet gets its colors/names from the classifier and its working
        # directory from the data selection applet.
        self.dataExportApplet = PixelClassificationDataExportApplet(self, "Prediction Export")
        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.PmapColors.connect( opClassify.PmapColors )
        opDataExport.LabelNames.connect( opClassify.LabelNames )
        opDataExport.WorkingDirectory.connect( opDataSelection.WorkingDirectory )
        opDataExport.SelectionNames.setValue( self.EXPORT_NAMES )

        # Expose for shell
        self._applets.append(self.projectMetadataApplet)
        self._applets.append(self.dataSelectionApplet)
        self._applets.append(self.featureSelectionApplet)
        self._applets.append(self.pcApplet)
        self._applets.append(self.dataExportApplet)

        self._batch_input_args = None
        self._batch_export_args = None

        self.batchInputApplet = None
        self.batchResultsApplet = None
        if appendBatchOperators:
            # Create applets for batch workflow
            self.batchInputApplet = DataSelectionApplet(self, "Batch Prediction Input Selections", "Batch Inputs", supportIlastik05Import=False, batchDataGui=True)
            self.batchResultsApplet = PixelClassificationDataExportApplet(self, "Batch Prediction Output Locations", isBatch=True)

            # Expose in shell
            self._applets.append(self.batchInputApplet)
            self._applets.append(self.batchResultsApplet)

            # Connect batch workflow (NOT lane-based)
            self._initBatchWorkflow()

            if unused_args:
                # We parse the export setting args first.  All remaining args are considered input files by the input applet.
                self._batch_export_args, unused_args = self.batchResultsApplet.parse_known_cmdline_args( unused_args )
                self._batch_input_args, unused_args = self.batchInputApplet.parse_known_cmdline_args( unused_args )

        if unused_args:
            logger.warning("Unused command-line args: {}".format( unused_args ))

    def connectLane(self, laneIndex):
        """
        Wire the per-lane operators of the training workflow for lane ``laneIndex``.

        The lane's input image feeds the feature operator and the classifier; the
        features feed the classifier; the classifier outputs and the raw data feed
        the data export operator, one export input per entry of ``EXPORT_NAMES``.
        """
        lane_data = self.dataSelectionApplet.topLevelOperator.getLane(laneIndex)
        lane_features = self.featureSelectionApplet.topLevelOperator.getLane(laneIndex)
        lane_classify = self.pcApplet.topLevelOperator.getLane(laneIndex)
        lane_export = self.dataExportApplet.topLevelOperator.getLane(laneIndex)

        # The input image goes to the feature operator, and to the classifier for display.
        input_image = lane_data.Image
        lane_features.InputImage.connect(input_image)
        lane_classify.InputImages.connect(input_image)

        # Prediction masks are only connected when the 'debug' config flag is on.
        if ilastik_config.getboolean('ilastik', 'debug'):
            mask_role = self.DATA_ROLE_PREDICTION_MASK
            lane_classify.PredictionMasks.connect(lane_data.ImageGroup[mask_role])

        # Features (plain and cached) are what the classifier trains and predicts on.
        lane_classify.FeatureImages.connect(lane_features.OutputImage)
        lane_classify.CachedFeatureImages.connect(lane_features.CachedOutputImage)

        # Per-image "labels allowed" flags restrict what the GUI lets the user label.
        lane_classify.LabelsAllowedFlags.connect(lane_data.AllowLabels)

        # Export operator: raw-data context first ...
        raw_role = self.DATA_ROLE_RAW
        raw_image = lane_data.ImageGroup[raw_role]
        lane_export.RawData.connect(raw_image)
        lane_export.RawDatasetInfo.connect(lane_data.DatasetGroup[raw_role])
        lane_export.ConstraintDataset.connect(raw_image)

        # ... then one input per export selection, in the same order as EXPORT_NAMES.
        lane_export.Inputs.resize(len(self.EXPORT_NAMES))
        export_sources = (
            lane_classify.HeadlessPredictionProbabilities,
            lane_classify.SimpleSegmentation,
            lane_classify.HeadlessUncertaintyEstimate,
            lane_classify.FeatureImages,
        )
        for selection_index, source_slot in enumerate(export_sources):
            lane_export.Inputs[selection_index].connect(source_slot)

        # Every export selection must be connected to something.
        for export_slot in lane_export.Inputs:
            assert export_slot.partner is not None

    def _initBatchWorkflow(self):
        """
        Connect the batch-mode top-level operators to the training workflow and to each other.

        Called from ``__init__`` only when batch operators were requested, after the batch
        input/results applets have been created.  The batch pipeline is not lane-based in the
        training sense: it clones the feature selection (Scales, FeatureIds, SelectionMatrix)
        and the classifier (Classifier, NumClasses) from the interactive workflow.  The created
        operator wrappers are stored on ``self.opBatchFeatures`` and
        ``self.opBatchPredictionPipeline`` (the latter is used by ``getHeadlessOutputSlot``).
        """
        # Access applet operators from the training workflow
        opTrainingDataSelection = self.dataSelectionApplet.topLevelOperator
        opTrainingFeatures = self.featureSelectionApplet.topLevelOperator
        opClassify = self.pcApplet.topLevelOperator
        
        # Access the batch operators
        opBatchInputs = self.batchInputApplet.topLevelOperator
        opBatchResults = self.batchResultsApplet.topLevelOperator
        
        # Batch inputs use the same dataset roles as the training data selection.
        opBatchInputs.DatasetRoles.connect( opTrainingDataSelection.DatasetRoles )
        
        # Pick the raw-data image of training lane 0 and use it as the batch export's
        # ConstraintDataset (presumably so batch exports are constrained like the
        # training data -- TODO confirm).
        opSelectFirstLane = OperatorWrapper( OpSelectSubslot, parent=self )
        opSelectFirstLane.Inputs.connect( opTrainingDataSelection.ImageGroup )
        opSelectFirstLane.SubslotIndex.setValue(0)
        
        opSelectFirstRole = OpSelectSubslot( parent=self )
        opSelectFirstRole.Inputs.connect( opSelectFirstLane.Output )
        opSelectFirstRole.SubslotIndex.setValue(self.DATA_ROLE_RAW)
        
        opBatchResults.ConstraintDataset.connect( opSelectFirstRole.Output )
        
        ## Create additional batch workflow operators
        # Non-caching feature and prediction operators, wrapped via OperatorWrapper; the feature
        # operator uses the filter implementation chosen at project creation.
        opBatchFeatures = OperatorWrapper( OpFeatureSelectionNoCache, operator_kwargs={'filter_implementation': self.filter_implementation}, parent=self, promotedSlotNames=['InputImage'] )
        opBatchPredictionPipeline = OperatorWrapper( OpPredictionPipelineNoCache, parent=self )
        
        ## Connect Operators ##
        # Transpose the batch DatasetGroup so that Outputs can be indexed by role
        # (Outputs[self.DATA_ROLE_RAW] is used below).
        opTranspose = OpTransposeSlots( parent=self )
        opTranspose.OutputLength.setValue(2) # There are 2 roles
        opTranspose.Inputs.connect( opBatchInputs.DatasetGroup )
        opTranspose.name = "batchTransposeInputs"
        
        # Provide dataset paths from data selection applet to the batch export applet
        opBatchResults.RawDatasetInfo.connect( opTranspose.Outputs[self.DATA_ROLE_RAW] )
        opBatchResults.WorkingDirectory.connect( opBatchInputs.WorkingDirectory )
        
        # Connect (clone) the feature operator inputs from
        #  the interactive workflow's features operator (which gets them from the GUI)
        opBatchFeatures.Scales.connect( opTrainingFeatures.Scales )
        opBatchFeatures.FeatureIds.connect( opTrainingFeatures.FeatureIds )
        opBatchFeatures.SelectionMatrix.connect( opTrainingFeatures.SelectionMatrix )
        
        # Classifier and NumClasses are provided by the interactive workflow
        opBatchPredictionPipeline.Classifier.connect( opClassify.Classifier )
        opBatchPredictionPipeline.NumClasses.connect( opClassify.NumClasses )
        
        # Provide these for the gui
        opBatchResults.RawData.connect( opBatchInputs.Image )
        opBatchResults.PmapColors.connect( opClassify.PmapColors )
        opBatchResults.LabelNames.connect( opClassify.LabelNames )
        
        # Connect Image pathway:
        # Input Image -> Features Op -> Prediction Op -> Export
        opBatchFeatures.InputImage.connect( opBatchInputs.Image )
        # NOTE(review): Image1 is presumably the image of the second dataset role
        # ('Prediction Mask', i.e. DATA_ROLE_PREDICTION_MASK) -- confirm.
        opBatchPredictionPipeline.PredictionMask.connect( opBatchInputs.Image1 )
        opBatchPredictionPipeline.FeatureImages.connect( opBatchFeatures.OutputImage )

        opBatchResults.SelectionNames.setValue( self.EXPORT_NAMES )        
        # opBatchResults.Inputs is indexed by [lane][selection],
        # Use OpTranspose to allow connection.
        # NOTE(review): OutputLength is set to 0 here, unlike the 2 used for the role
        # transpose above; presumably the output length then follows the lane count -- confirm.
        opTransposeBatchInputs = OpTransposeSlots( parent=self )
        opTransposeBatchInputs.name = "opTransposeBatchInputs"
        opTransposeBatchInputs.OutputLength.setValue(0)
        # One input per export selection, in the same order as EXPORT_NAMES.
        opTransposeBatchInputs.Inputs.resize( len(self.EXPORT_NAMES) )
        opTransposeBatchInputs.Inputs[0].connect( opBatchPredictionPipeline.HeadlessPredictionProbabilities ) # selection 0
        opTransposeBatchInputs.Inputs[1].connect( opBatchPredictionPipeline.SimpleSegmentation ) # selection 1
        opTransposeBatchInputs.Inputs[2].connect( opBatchPredictionPipeline.HeadlessUncertaintyEstimate ) # selection 2
        opTransposeBatchInputs.Inputs[3].connect( opBatchPredictionPipeline.FeatureImages ) # selection 3
        for slot in opTransposeBatchInputs.Inputs:
            assert slot.partner is not None
        
        # Now opTransposeBatchInputs.Outputs is level-2 indexed by [lane][selection]
        opBatchResults.Inputs.connect( opTransposeBatchInputs.Outputs )

        # We don't actually need the cached path in the batch pipeline.
        # NOTE(review): the connection below is disabled, so CachedFeatureImages is left
        # unconnected in the batch pipeline.  If the operator ever requires it, re-enable the
        # line to feed it the uncached features.
        #opBatchPredictionPipeline.CachedFeatureImages.connect( opBatchFeatures.OutputImage )

        # Keep handles to the batch operators (used e.g. for headless batch output slots).
        self.opBatchFeatures = opBatchFeatures
        self.opBatchPredictionPipeline = opBatchPredictionPipeline

    def handleAppletStateUpdateRequested(self):
        """
        Overridden from Workflow base class
        Called when an applet has fired the :py:attr:`Applet.appletStateUpdateRequested`

        Enables each applet only once the pipeline stage before it is usable
        (input -> features -> predictions -> export, then the batch applets), locks
        the input/feature applets while live update is on, and blocks project
        changes while any of the main applets reports itself busy.
        """
        shell = self._shell

        # Stage 1: at least one input image, and the data selection applet is idle.
        have_images = len(self.dataSelectionApplet.topLevelOperator.ImageGroup) > 0
        input_ready = have_images and not self.dataSelectionApplet.busy

        # Stage 2: the first feature image is ready and non-empty along every axis.
        feature_slot = self.featureSelectionApplet.topLevelOperator.OutputImage
        features_ready = (input_ready
                          and len(feature_slot) > 0
                          and feature_slot[0].ready()
                          and (TinyVector(feature_slot[0].meta.shape) > 0).all())

        export_inputs = self.dataExportApplet.topLevelOperator.Inputs
        op_classify = self.pcApplet.topLevelOperator

        # The classifier counts as invalid when the classifier cache is fixed
        # (fixAtCurrent) and its ready output holds None.
        classifier_cache = op_classify.classifier_cache
        invalid_classifier = (classifier_cache.fixAtCurrent.value
                              and classifier_cache.Output.ready()
                              and classifier_cache.Output.value is None)

        # Stage 3: a valid classifier and a ready, non-empty first export input.
        predictions_ready = (features_ready
                             and not invalid_classifier
                             and len(export_inputs) > 0
                             and export_inputs[0][0].ready()
                             and (TinyVector(export_inputs[0][0].meta.shape) > 0).all())

        # Changing features or input data while live update is running causes problems,
        # so those applets are locked until predictions are frozen again.
        live_update_active = not op_classify.FreezePredictions.value
        inputs_unlocked = not live_update_active

        shell.setAppletEnabled(self.dataSelectionApplet, inputs_unlocked)
        shell.setAppletEnabled(self.featureSelectionApplet, input_ready and inputs_unlocked)
        shell.setAppletEnabled(self.pcApplet, features_ready)
        shell.setAppletEnabled(self.dataExportApplet, predictions_ready)

        if self.batchInputApplet is not None:
            # Batch processing requires a fully configured training workflow.
            shell.setAppletEnabled(self.batchInputApplet, predictions_ready)

            batch_images = self.batchInputApplet.topLevelOperator.ImageGroup
            batch_input_ready = predictions_ready and len(batch_images) > 0
            shell.setAppletEnabled(self.batchResultsApplet, batch_input_ready)

        # Finally, prevent the shell from closing/changing the project while busy.
        busy = False
        for applet in (self.dataSelectionApplet,
                       self.featureSelectionApplet,
                       self.dataExportApplet):
            busy |= applet.busy
        shell.enableProjectChanges(not busy)

    def getHeadlessOutputSlot(self, slotId):
        """
        Return the output slot identified by ``slotId`` for headless use.

        "Predictions" / "PredictionsUint8" refer to the images the user selected as
        input data; "BatchPredictions" / "BatchPredictionsUint8" refer to the batch
        prediction pipeline, which only exists when the workflow was built with
        ``appendBatchOperators=True``.

        :raises Exception: if ``slotId`` is not recognized, or if a batch slot is
                           requested but the batch pipeline was never created.
        """
        # "Regular" (i.e. with the images that the user selected as input data)
        if slotId == "Predictions":
            return self.pcApplet.topLevelOperator.HeadlessPredictionProbabilities
        elif slotId == "PredictionsUint8":
            return self.pcApplet.topLevelOperator.HeadlessUint8PredictionProbabilities
        # "Batch" (i.e. with the images that the user selected as batch inputs).
        elif slotId in ("BatchPredictions", "BatchPredictionsUint8"):
            # opBatchPredictionPipeline is only created by _initBatchWorkflow(), which is
            # skipped when batch operators were not requested; fail with a clear message
            # instead of an AttributeError.
            batch_pipeline = getattr(self, 'opBatchPredictionPipeline', None)
            if batch_pipeline is None:
                raise RuntimeError(
                    "Headless output slot '{}' requires the batch workflow, "
                    "which was not created".format(slotId))
            if slotId == "BatchPredictions":
                return batch_pipeline.HeadlessPredictionProbabilities
            return batch_pipeline.HeadlessUint8PredictionProbabilities

        raise Exception("Unknown headless output slot: {}".format(slotId))


    def onProjectLoaded(self, projectManager):
        """
        Overridden from Workflow base class.  Called by the Project Manager.

        If the user provided command-line arguments, use them to configure
        the workflow for batch mode and export all results.
        (This workflow's headless mode supports only batch mode for now.)

        Steps, in order, each driven by the parsed command-line options:
          * --generate-random-labels: inject random labels, then save the project.
          * --print-labels-by-slice: log the label count of each Z-slice.
          * batch input / export args: configure the batch applets.
          * --retrain: force the classifier to retrain and save it to the project file.
          * headless with both batch input and export args: run every batch export,
            writing per-result progress to stdout.

        :param projectManager: the project manager that loaded the project; used for saving.
        """
        if self.generate_random_labels:
            self._generate_random_labels(self.random_label_count, self.random_label_value)
            logger.info("Saving project...")
            # Save through the manager we were handed (as the --retrain branch does)
            # instead of reaching through the shell for it.
            projectManager.saveProject()
            logger.info("Done.")

        if self.print_labels_by_slice:
            self._print_labels_by_slice( self.label_search_value )

        # Configure the batch data selection operator.
        if self._batch_input_args and (self._batch_input_args.input_files or self._batch_input_args.raw_data):
            self.batchInputApplet.configure_operator_with_parsed_args( self._batch_input_args )

        # Configure the data export operator.
        if self._batch_export_args:
            self.batchResultsApplet.configure_operator_with_parsed_args( self._batch_export_args )

        if self._batch_input_args and self.pcApplet.topLevelOperator.classifier_cache._dirty:
            logger.warning("Your project file has no classifier.  A new classifier will be trained for this run.")

        if self._headless:
            # In headless mode, let's see the messages from the training operator.
            logging.getLogger("lazyflow.operators.classifierOperators").setLevel(logging.DEBUG)

        if self.retrain:
            # Cause the classifier to be dirty so it is forced to retrain.
            # (useful if the stored labels were changed outside ilastik)
            self.pcApplet.topLevelOperator.opTrain.ClassifierFactory.setDirty()

            # Request the classifier, which forces training
            self.pcApplet.topLevelOperator.FreezePredictions.setValue(False)
            _ = self.pcApplet.topLevelOperator.Classifier.value

            # store new classifier to project file
            projectManager.saveProject(force_all_save=False)

        if self._headless and self._batch_input_args and self._batch_export_args:
            # Make sure we're using the up-to-date classifier.
            self.pcApplet.topLevelOperator.FreezePredictions.setValue(False)

            def print_progress( progress ):
                sys.stdout.write( "{} ".format( progress ) )
                sys.stdout.flush()

            # Now run the batch export and report progress....
            opBatchDataExport = self.batchResultsApplet.topLevelOperator
            num_results = len( opBatchDataExport )
            for i, opExportDataLaneView in enumerate(opBatchDataExport):
                # Number results from 1 so the last one reads "N/N" rather than "N-1/N".
                result_number = i + 1
                logger.info( "Exporting result {} to {}".format(result_number, opExportDataLaneView.ExportPath.value) )

                sys.stdout.write( "Result {}/{} Progress: ".format( result_number, num_results ) )
                sys.stdout.flush()

                # If the operator provides a progress signal, use it.
                opExportDataLaneView.progressSignal.subscribe( print_progress )
                opExportDataLaneView.run_export()

                # Finished.
                sys.stdout.write("\n")
                sys.stdout.flush()


    def _print_labels_by_slice(self, search_value):
        """
        Iterate over each label image in the project and print the number of labels present on each Z-slice of the image.
        (This is a special feature requested by the FlyEM proofreaders.)

        :param search_value: If nonzero, only pixels equal to this value are counted.
                             If zero, every nonzero pixel is counted.
        """
        opTopLevelClassify = self.pcApplet.topLevelOperator
        project_label_count = 0
        for image_index, label_slot in enumerate(opTopLevelClassify.LabelImages):
            tagged_shape = label_slot.meta.getTaggedShape()
            if 'z' not in tagged_shape:
                logger.error("Can't print label counts by Z-slices.  Image #{} has no Z-dimension.".format(image_index))
                continue

            logger.info("Label counts in Z-slices of Image #{}:".format( image_index ))
            # list() makes the lookup work whether keys() returns a list or a view object.
            z_axis = list(tagged_shape.keys()).index('z')
            slicing = [slice(None)] * len(tagged_shape)
            blank_slices = []
            image_label_count = 0
            for z in range(tagged_shape['z']):
                # Request a single Z-slice (full extent along every other axis).
                slicing[z_axis] = slice(z, z+1)
                label_slice = label_slot[slicing].wait()
                if search_value:
                    count = (label_slice == search_value).sum()
                else:
                    count = (label_slice != 0).sum()
                if count > 0:
                    logger.info("Z={}: {}".format( z, count ))
                    image_label_count += count
                else:
                    blank_slices.append( z )
            project_label_count += image_label_count
            if len(blank_slices) > 20:
                # Don't list the blank slices if there were a lot of them.
                logger.info("Image #{} has {} blank slices.".format( image_index, len(blank_slices) ))
            elif len(blank_slices) > 0:
                logger.info( "Image #{} has {} blank slices: {}".format( image_index, len(blank_slices), blank_slices ) )
            else:
                logger.info( "Image #{} has no blank slices.".format( image_index ) )
            logger.info( "Total labels for Image #{}: {}".format( image_index, image_label_count ) )
        logger.info( "Total labels for project: {}".format( project_label_count ) )

    
    def _generate_random_labels(self, labels_per_image, label_value):
        """
        Inject random labels into the project file.
        (This is a special feature requested by the FlyEM proofreaders.)

        For every image, ``labels_per_image`` distinct pixels are set to ``label_value``.
        The result is written into that image's ``LabelInputs`` slot.  Placement is
        reproducible because each image uses its own generator seeded with ``[1, image_index]``.

        :param labels_per_image: Number of pixels to label per image.  It is clamped to the
                                 number of pixels in the image, since there cannot be more
                                 distinct pixels than that.
        :param label_value: Label value to write.  Label names are appended until one
                            exists for this value.
        """
        logger.info( "Injecting {} labels of value {} into all images.".format( labels_per_image, label_value ) )
        opTopLevelClassify = self.pcApplet.topLevelOperator

        # Ensure a label name exists for every value up to label_value.
        label_names = copy.copy(opTopLevelClassify.LabelNames.value)
        while len(label_names) < label_value:
            label_names.append( "Label {}".format( len(label_names)+1 ) )

        opTopLevelClassify.LabelNames.setValue( label_names )

        # For reproducibility of label generation
        SEED = 1
        for image_index in range(len(opTopLevelClassify.LabelImages)):
            logger.info( "Injecting labels into image #{}".format( image_index ) )
            # A private generator yields the same sequence that numpy.random.seed() +
            # numpy.random.randint() would, without clobbering the global numpy RNG state.
            rng = numpy.random.RandomState([SEED, image_index])

            label_input_slot = opTopLevelClassify.LabelInputs[image_index]
            label_output_slot = opTopLevelClassify.LabelImages[image_index]

            shape = label_output_slot.meta.shape
            # NOTE(review): uint8 buffer, so label_value is assumed to be < 256.
            random_labels = numpy.zeros( shape=shape, dtype=numpy.uint8 )
            num_pixels = random_labels.size

            num_samples = labels_per_image
            if num_samples > num_pixels:
                # Without this clamp, the search for a blank pixel below would never terminate.
                logger.warning( "Image #{} has only {} pixels; injecting {} labels instead of {}."
                                .format( image_index, num_pixels, num_pixels, labels_per_image ) )
                num_samples = num_pixels

            current_progress = -1
            for sample_index in range(num_samples):
                flat_index = rng.randint(0, num_pixels)
                # Don't overwrite existing labels
                # Keep looking until we find a blank pixel
                while random_labels.flat[flat_index]:
                    flat_index = rng.randint(0, num_pixels)
                random_labels.flat[flat_index] = label_value

                # Print progress every 10% (// keeps this integer division on all Python versions)
                progress = float(sample_index) / num_samples
                progress = 10 * (int(100*progress) // 10)
                if progress != current_progress:
                    current_progress = progress
                    sys.stdout.write( "{}% ".format( current_progress ) )
                    sys.stdout.flush()

            sys.stdout.write( "100%\n" )
            # Write into the operator
            label_input_slot[fullSlicing(shape)] = random_labels

        logger.info( "Done injecting labels" )
class PixelClassificationWorkflow(Workflow):
    """
    Pixel classification workflow.

    Applets, in the order they are exposed to the shell: project metadata, data selection,
    feature selection, pixel classification, prediction export and batch processing.

    Workflow-specific command-line options are parsed in ``__init__``.  ``--filter`` is read
    from the arguments saved at project creation; all other options come from the current
    session's arguments.  Arguments the workflow does not recognize are passed on to the
    export and batch-processing applets.  When running headless with both batch input and
    export arguments, ``onProjectLoaded`` runs the batch export.
    """

    workflowName = "Pixel Classification"
    workflowDescription = "This is obviously self-explanatory."
    defaultAppletIndex = 1 # show DataSelection by default

    DATA_ROLE_RAW = 0
    DATA_ROLE_PREDICTION_MASK = 1
    ROLE_NAMES = ['Raw Data', 'Prediction Mask']
    EXPORT_NAMES = ['Probabilities', 'Simple Segmentation', 'Uncertainty', 'Features']

    @property
    def applets(self):
        """The list of applets exposed to the shell."""
        return self._applets

    @property
    def imageNameListSlot(self):
        """The slot that provides the name of each image lane."""
        return self.dataSelectionApplet.topLevelOperator.ImageName

    def __init__(self, shell, headless, workflow_cmdline_args, project_creation_args, *args, **kwargs):
        """
        Parse workflow-specific command-line arguments and build all applets.

        :param shell: The shell this workflow runs in.
        :param headless: True when running without a GUI.
        :param workflow_cmdline_args: Command-line arguments for the current session.
        :param project_creation_args: Arguments stored in the project file when it was
                                      created.  Only ``--filter`` is taken from these.
        """
        # Create a graph to be shared by all operators
        graph = Graph()
        super( PixelClassificationWorkflow, self ).__init__( shell, headless, workflow_cmdline_args, project_creation_args, graph=graph, *args, **kwargs )
        # Classifier preserved across lane additions (see prepareForNewLane / handleNewLanesAdded).
        # (The misspelled attribute name is kept for compatibility.)
        self.stored_classifer = None
        self._applets = []
        self._workflow_cmdline_args = workflow_cmdline_args
        # Parse workflow-specific command-line args
        parser = argparse.ArgumentParser()
        parser.add_argument('--filter', help="pixel feature filter implementation.", choices=['Original', 'Refactored', 'Interpolated'], default='Original')
        parser.add_argument('--print-labels-by-slice', help="Print the number of labels for each Z-slice of each image.", action="store_true")
        parser.add_argument('--label-search-value', help="If provided, only this value is considered when using --print-labels-by-slice", default=0, type=int)
        parser.add_argument('--generate-random-labels', help="Add random labels to the project file.", action="store_true")
        parser.add_argument('--random-label-value', help="The label value to use injecting random labels", default=1, type=int)
        parser.add_argument('--random-label-count', help="The number of random labels to inject via --generate-random-labels", default=2000, type=int)
        parser.add_argument('--retrain', help="Re-train the classifier based on labels stored in project file, and re-save.", action="store_true")
        parser.add_argument('--tree-count', help='Number of trees for Vigra RF classifier.', type=int)
        parser.add_argument('--variable-importance-path', help='Location of variable-importance table.', type=str)
        parser.add_argument('--label-proportion', help='Proportion of feature-pixels used to train the classifier.', type=float)

        # Parse the creation args: These were saved to the project file when this project was first created.
        parsed_creation_args, unused_args = parser.parse_known_args(project_creation_args)
        self.filter_implementation = parsed_creation_args.filter

        # Parse the cmdline args for the current session.
        # Drop the --filter default first, so parsed_args.filter is only set when the user
        # actually passed --filter.  Otherwise the 'Original' default would look like a
        # request to change the filter of a project created with another implementation.
        parser.set_defaults(filter=None)
        parsed_args, unused_args = parser.parse_known_args(workflow_cmdline_args)
        self.print_labels_by_slice = parsed_args.print_labels_by_slice
        self.label_search_value = parsed_args.label_search_value
        self.generate_random_labels = parsed_args.generate_random_labels
        self.random_label_value = parsed_args.random_label_value
        self.random_label_count = parsed_args.random_label_count
        self.retrain = parsed_args.retrain
        self.tree_count = parsed_args.tree_count
        self.variable_importance_path = parsed_args.variable_importance_path
        self.label_proportion = parsed_args.label_proportion

        if parsed_args.filter and parsed_args.filter != parsed_creation_args.filter:
            logger.error("Ignoring new --filter setting.  Filter implementation cannot be changed after initial project creation.")

        # Applets for training (interactive) workflow 
        self.projectMetadataApplet = ProjectMetadataApplet()

        self.dataSelectionApplet = self.createDataSelectionApplet()
        opDataSelection = self.dataSelectionApplet.topLevelOperator

        # see role constants, above
        opDataSelection.DatasetRoles.setValue( PixelClassificationWorkflow.ROLE_NAMES )

        self.featureSelectionApplet = self.createFeatureSelectionApplet()

        self.pcApplet = self.createPixelClassificationApplet()
        opClassify = self.pcApplet.topLevelOperator

        self.dataExportApplet = PixelClassificationDataExportApplet(self, "Prediction Export")
        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.PmapColors.connect( opClassify.PmapColors )
        opDataExport.LabelNames.connect( opClassify.LabelNames )
        opDataExport.WorkingDirectory.connect( opDataSelection.WorkingDirectory )
        opDataExport.SelectionNames.setValue( self.EXPORT_NAMES )

        # Expose for shell
        self._applets.append(self.projectMetadataApplet)
        self._applets.append(self.dataSelectionApplet)
        self._applets.append(self.featureSelectionApplet)
        self._applets.append(self.pcApplet)
        self._applets.append(self.dataExportApplet)

        # Hooks that unfreeze predictions for the duration of a full export.
        self.dataExportApplet.prepare_for_entire_export = self.prepare_for_entire_export
        self.dataExportApplet.post_process_entire_export = self.post_process_entire_export

        self.batchProcessingApplet = BatchProcessingApplet(self, 
                                                           "Batch Processing", 
                                                           self.dataSelectionApplet, 
                                                           self.dataExportApplet)

        self._applets.append(self.batchProcessingApplet)
        if unused_args:
            # We parse the export setting args first.  All remaining args are considered input files by the input applet.
            self._batch_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args( unused_args )
            self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args( unused_args )
        else:
            self._batch_input_args = None
            self._batch_export_args = None

        if unused_args:
            logger.warning("Unused command-line args: {}".format( unused_args ))

    def createDataSelectionApplet(self):
        """
        Can be overridden by subclasses, if they want to use 
        special parameters to initialize the DataSelectionApplet.
        """
        data_instructions = "Select your input data using the 'Raw Data' tab shown on the right"
        return DataSelectionApplet( self,
                                    "Input Data",
                                    "Input Data",
                                    supportIlastik05Import=True,
                                    instructionText=data_instructions )


    def createFeatureSelectionApplet(self):
        """
        Can be overridden by subclasses, if they want to return their own type of FeatureSelectionApplet.
        NOTE: The applet returned here must have the same interface as the regular FeatureSelectionApplet.
              (If it looks like a duck...)
        """
        return FeatureSelectionApplet(self, "Feature Selection", "FeatureSelections", self.filter_implementation)

    def createPixelClassificationApplet(self):
        """
        Can be overridden by subclasses, if they want to return their own type of PixelClassificationApplet.
        NOTE: The applet returned here must have the same interface as the regular PixelClassificationApplet.
              (If it looks like a duck...)
        """
        return PixelClassificationApplet( self, "PixelClassification" )

    def prepareForNewLane(self, laneIndex):
        """
        Overridden from Workflow base class.
        Called immediately before a new lane is added to the workflow.
        """
        # When the new lane is added, dirty notifications will propagate throughout the entire graph.
        # This means the classifier will be marked 'dirty' even though it is still usable.
        # Before that happens, let's store the classifier, so we can restore it in handleNewLanesAdded(), below.
        opPixelClassification = self.pcApplet.topLevelOperator
        if opPixelClassification.classifier_cache.Output.ready() and \
           not opPixelClassification.classifier_cache._dirty:
            self.stored_classifer = opPixelClassification.classifier_cache.Output.value
        else:
            self.stored_classifer = None

    def handleNewLanesAdded(self):
        """
        Overridden from Workflow base class.
        Called immediately after a new lane is added to the workflow and initialized.
        """
        # Restore classifier we saved in prepareForNewLane() (if any)
        if self.stored_classifer:
            self.pcApplet.topLevelOperator.classifier_cache.forceValue(self.stored_classifer)
            # Release reference
            self.stored_classifer = None

    def connectLane(self, laneIndex):
        """
        Wire the per-lane views of each applet's operator together for lane ``laneIndex``.
        """
        # Get a handle to each operator
        opData = self.dataSelectionApplet.topLevelOperator.getLane(laneIndex)
        opTrainingFeatures = self.featureSelectionApplet.topLevelOperator.getLane(laneIndex)
        opClassify = self.pcApplet.topLevelOperator.getLane(laneIndex)
        opDataExport = self.dataExportApplet.topLevelOperator.getLane(laneIndex)

        # Input Image -> Feature Op
        #         and -> Classification Op (for display)
        opTrainingFeatures.InputImage.connect( opData.Image )
        opClassify.InputImages.connect( opData.Image )

        # The prediction mask role is only wired up when the ilastik 'debug' config flag is set.
        if ilastik_config.getboolean('ilastik', 'debug'):
            opClassify.PredictionMasks.connect( opData.ImageGroup[self.DATA_ROLE_PREDICTION_MASK] )

        # Feature Images -> Classification Op (for training, prediction)
        opClassify.FeatureImages.connect( opTrainingFeatures.OutputImage )
        opClassify.CachedFeatureImages.connect( opTrainingFeatures.CachedOutputImage )

        # Training flags -> Classification Op (for GUI restrictions)
        opClassify.LabelsAllowedFlags.connect( opData.AllowLabels )

        # Data Export connections
        opDataExport.RawData.connect( opData.ImageGroup[self.DATA_ROLE_RAW] )
        opDataExport.RawDatasetInfo.connect( opData.DatasetGroup[self.DATA_ROLE_RAW] )
        opDataExport.ConstraintDataset.connect( opData.ImageGroup[self.DATA_ROLE_RAW] )
        # One export input per entry of EXPORT_NAMES, in the same order.
        opDataExport.Inputs.resize( len(self.EXPORT_NAMES) )
        opDataExport.Inputs[0].connect( opClassify.HeadlessPredictionProbabilities )
        opDataExport.Inputs[1].connect( opClassify.SimpleSegmentation )
        opDataExport.Inputs[2].connect( opClassify.HeadlessUncertaintyEstimate )
        opDataExport.Inputs[3].connect( opClassify.FeatureImages )
        for slot in opDataExport.Inputs:
            assert slot.partner is not None

    def handleAppletStateUpdateRequested(self):
        """
        Overridden from Workflow base class
        Called when an applet has fired the :py:attr:`Applet.appletStateUpdateRequested`

        Enables/disables each applet according to which upstream results are available, and
        blocks project changes while any applet reports itself busy.
        """
        # If no data, nothing else is ready.
        opDataSelection = self.dataSelectionApplet.topLevelOperator
        input_ready = len(opDataSelection.ImageGroup) > 0 and not self.dataSelectionApplet.busy

        opFeatureSelection = self.featureSelectionApplet.topLevelOperator
        featureOutput = opFeatureSelection.OutputImage
        features_ready = input_ready and \
                         len(featureOutput) > 0 and  \
                         featureOutput[0].ready() and \
                         (TinyVector(featureOutput[0].meta.shape) > 0).all()

        opDataExport = self.dataExportApplet.topLevelOperator
        opPixelClassification = self.pcApplet.topLevelOperator

        # A frozen cache holding None means no usable classifier is available.
        invalid_classifier = opPixelClassification.classifier_cache.fixAtCurrent.value and \
                             opPixelClassification.classifier_cache.Output.ready() and\
                             opPixelClassification.classifier_cache.Output.value is None

        predictions_ready = features_ready and \
                            not invalid_classifier and \
                            len(opDataExport.Inputs) > 0 and \
                            opDataExport.Inputs[0][0].ready() and \
                            (TinyVector(opDataExport.Inputs[0][0].meta.shape) > 0).all()

        # Problems can occur if the features or input data are changed during live update mode.
        # Don't let the user do that.
        live_update_active = not opPixelClassification.FreezePredictions.value

        # The user isn't allowed to touch anything while batch processing is running.
        batch_processing_busy = self.batchProcessingApplet.busy

        self._shell.setAppletEnabled(self.dataSelectionApplet, not live_update_active and not batch_processing_busy)
        self._shell.setAppletEnabled(self.featureSelectionApplet, input_ready and not live_update_active and not batch_processing_busy)
        self._shell.setAppletEnabled(self.pcApplet, features_ready and not batch_processing_busy)
        self._shell.setAppletEnabled(self.dataExportApplet, predictions_ready and not batch_processing_busy)

        if self.batchProcessingApplet is not None:
            self._shell.setAppletEnabled(self.batchProcessingApplet, predictions_ready and not batch_processing_busy)

        # Lastly, check for certain "busy" conditions, during which we 
        #  should prevent the shell from closing the project.
        busy = False
        busy |= self.dataSelectionApplet.busy
        busy |= self.featureSelectionApplet.busy
        busy |= self.dataExportApplet.busy
        busy |= self.batchProcessingApplet.busy
        self._shell.enableProjectChanges( not busy )

    def onProjectLoaded(self, projectManager):
        """
        Overridden from Workflow base class.  Called by the Project Manager.

        Applies the session's command-line options (random label injection, label
        statistics, classifier parameters, retraining).  If the user provided batch
        arguments and we are headless, runs the batch export.
        (This workflow's headless mode supports only batch mode for now.)
        """
        if self.generate_random_labels:
            self._generate_random_labels(self.random_label_count, self.random_label_value)
            logger.info("Saving project...")
            self._shell.projectManager.saveProject()
            logger.info("Done.")

        if self.print_labels_by_slice:
            self._print_labels_by_slice( self.label_search_value )

        if self._headless:
            # In headless mode, let's see the messages from the training operator.
            logging.getLogger("lazyflow.operators.classifierOperators").setLevel(logging.DEBUG)

        if self.variable_importance_path: 
            classifier_factory = self.pcApplet.topLevelOperator.opTrain.ClassifierFactory.value
            classifier_factory.set_variable_importance_path( self.variable_importance_path )

        if self.tree_count:
            classifier_factory = self.pcApplet.topLevelOperator.opTrain.ClassifierFactory.value
            classifier_factory.set_num_trees( self.tree_count )

        if self.label_proportion:
            classifier_factory = self.pcApplet.topLevelOperator.opTrain.ClassifierFactory.value
            classifier_factory.set_label_proportion( self.label_proportion )

        # The factory object was modified in place, so downstream operators must be told explicitly.
        if self.tree_count or self.label_proportion:
            self.pcApplet.topLevelOperator.ClassifierFactory.setDirty()

        if self.retrain:
            # Cause the classifier to be dirty so it is forced to retrain.
            # (useful if the stored labels were changed outside ilastik)
            self.pcApplet.topLevelOperator.opTrain.ClassifierFactory.setDirty()

            # Request the classifier, which forces training
            self.pcApplet.topLevelOperator.FreezePredictions.setValue(False)
            _ = self.pcApplet.topLevelOperator.Classifier.value

            # store new classifier to project file
            projectManager.saveProject(force_all_save=False)

        # Configure the data export operator.
        if self._batch_export_args:
            self.dataExportApplet.configure_operator_with_parsed_args( self._batch_export_args )

        if self._batch_input_args and self.pcApplet.topLevelOperator.classifier_cache._dirty:
            logger.warning("Your project file has no classifier.  A new classifier will be trained for this run.")

        if self._headless and self._batch_input_args and self._batch_export_args:
            logger.info("Beginning Batch Processing")
            self.batchProcessingApplet.run_export_from_parsed_args(self._batch_input_args)
            logger.info("Completed Batch Processing")

    def prepare_for_entire_export(self):
        """
        Export hook: remember the current FreezePredictions setting, then unfreeze
        predictions so the export computes them.
        """
        self.freeze_status = self.pcApplet.topLevelOperator.FreezePredictions.value
        self.pcApplet.topLevelOperator.FreezePredictions.setValue(False)

    def post_process_entire_export(self):
        """
        Export hook: restore the FreezePredictions setting saved by prepare_for_entire_export().
        """
        self.pcApplet.topLevelOperator.FreezePredictions.setValue(self.freeze_status)

    def _print_labels_by_slice(self, search_value):
        """
        Iterate over each label image in the project and print the number of labels present on each Z-slice of the image.
        (This is a special feature requested by the FlyEM proofreaders.)

        :param search_value: If nonzero, only pixels equal to this value are counted.
                             If zero, every nonzero pixel is counted.
        """
        opTopLevelClassify = self.pcApplet.topLevelOperator
        project_label_count = 0
        for image_index, label_slot in enumerate(opTopLevelClassify.LabelImages):
            tagged_shape = label_slot.meta.getTaggedShape()
            if 'z' not in tagged_shape:
                logger.error("Can't print label counts by Z-slices.  Image #{} has no Z-dimension.".format(image_index))
                continue

            logger.info("Label counts in Z-slices of Image #{}:".format( image_index ))
            # list() makes the lookup work whether keys() returns a list or a view object.
            z_axis = list(tagged_shape.keys()).index('z')
            slicing = [slice(None)] * len(tagged_shape)
            blank_slices = []
            image_label_count = 0
            for z in range(tagged_shape['z']):
                # Request a single Z-slice (full extent along every other axis).
                slicing[z_axis] = slice(z, z+1)
                label_slice = label_slot[slicing].wait()
                if search_value:
                    count = (label_slice == search_value).sum()
                else:
                    count = (label_slice != 0).sum()
                if count > 0:
                    logger.info("Z={}: {}".format( z, count ))
                    image_label_count += count
                else:
                    blank_slices.append( z )
            project_label_count += image_label_count
            if len(blank_slices) > 20:
                # Don't list the blank slices if there were a lot of them.
                logger.info("Image #{} has {} blank slices.".format( image_index, len(blank_slices) ))
            elif len(blank_slices) > 0:
                logger.info( "Image #{} has {} blank slices: {}".format( image_index, len(blank_slices), blank_slices ) )
            else:
                logger.info( "Image #{} has no blank slices.".format( image_index ) )
            logger.info( "Total labels for Image #{}: {}".format( image_index, image_label_count ) )
        logger.info( "Total labels for project: {}".format( project_label_count ) )


    def _generate_random_labels(self, labels_per_image, label_value):
        """
        Inject random labels into the project file.
        (This is a special feature requested by the FlyEM proofreaders.)

        For every image, ``labels_per_image`` distinct pixels are set to ``label_value``.
        The result is written into that image's ``LabelInputs`` slot.  Placement is
        reproducible because each image uses its own generator seeded with ``[1, image_index]``.

        :param labels_per_image: Number of pixels to label per image.  It is clamped to the
                                 number of pixels in the image, since there cannot be more
                                 distinct pixels than that.
        :param label_value: Label value to write.  Label names are appended until one
                            exists for this value.
        """
        logger.info( "Injecting {} labels of value {} into all images.".format( labels_per_image, label_value ) )
        opTopLevelClassify = self.pcApplet.topLevelOperator

        # Ensure a label name exists for every value up to label_value.
        label_names = copy.copy(opTopLevelClassify.LabelNames.value)
        while len(label_names) < label_value:
            label_names.append( "Label {}".format( len(label_names)+1 ) )

        opTopLevelClassify.LabelNames.setValue( label_names )

        # For reproducibility of label generation
        SEED = 1
        for image_index in range(len(opTopLevelClassify.LabelImages)):
            logger.info( "Injecting labels into image #{}".format( image_index ) )
            # A private generator yields the same sequence that numpy.random.seed() +
            # numpy.random.randint() would, without clobbering the global numpy RNG state.
            rng = numpy.random.RandomState([SEED, image_index])

            label_input_slot = opTopLevelClassify.LabelInputs[image_index]
            label_output_slot = opTopLevelClassify.LabelImages[image_index]

            shape = label_output_slot.meta.shape
            # NOTE(review): uint8 buffer, so label_value is assumed to be < 256.
            random_labels = numpy.zeros( shape=shape, dtype=numpy.uint8 )
            num_pixels = random_labels.size

            num_samples = labels_per_image
            if num_samples > num_pixels:
                # Without this clamp, the search for a blank pixel below would never terminate.
                logger.warning( "Image #{} has only {} pixels; injecting {} labels instead of {}."
                                .format( image_index, num_pixels, num_pixels, labels_per_image ) )
                num_samples = num_pixels

            current_progress = -1
            for sample_index in range(num_samples):
                flat_index = rng.randint(0, num_pixels)
                # Don't overwrite existing labels
                # Keep looking until we find a blank pixel
                while random_labels.flat[flat_index]:
                    flat_index = rng.randint(0, num_pixels)
                random_labels.flat[flat_index] = label_value

                # Print progress every 10% (// keeps this integer division on all Python versions)
                progress = float(sample_index) / num_samples
                progress = 10 * (int(100*progress) // 10)
                if progress != current_progress:
                    current_progress = progress
                    sys.stdout.write( "{}% ".format( current_progress ) )
                    sys.stdout.flush()

            sys.stdout.write( "100%\n" )
            # Write into the operator
            label_input_slot[fullSlicing(shape)] = random_labels

        logger.info( "Done injecting labels" )


    def getHeadlessOutputSlot(self, slotId):
        """
        Not used by the regular app.
        Only used for special cluster scripts.

        :param slotId: One of "Predictions", "PredictionsUint8", "BatchPredictions",
                       "BatchPredictionsUint8".
        :raises Exception: for any other slotId.

        NOTE(review): this class never assigns ``self.opBatchPredictionPipeline``, so the
        "Batch*" ids only work if something else (e.g. a subclass) sets that attribute.
        Otherwise they raise AttributeError.
        """
        # "Regular" (i.e. with the images that the user selected as input data)
        if slotId == "Predictions":
            return self.pcApplet.topLevelOperator.HeadlessPredictionProbabilities
        elif slotId == "PredictionsUint8":
            return self.pcApplet.topLevelOperator.HeadlessUint8PredictionProbabilities
        # "Batch" (i.e. with the images that the user selected as batch inputs).
        elif slotId == "BatchPredictions":
            return self.opBatchPredictionPipeline.HeadlessPredictionProbabilities
        elif slotId == "BatchPredictionsUint8":
            return self.opBatchPredictionPipeline.HeadlessUint8PredictionProbabilities

        raise Exception("Unknown headless output slot")
class PixelClassificationWorkflow(Workflow):
    """
    Pixel classification workflow.

    Applets, in the order they are exposed to the shell: data selection,
    feature selection, pixel classification, prediction export and batch
    processing.  Workflow-specific command-line options are parsed in
    ``__init__`` and acted upon in ``onProjectLoaded``.
    """

    workflowName = "Pixel Classification"
    workflowDescription = "This is obviously self-explanatory."
    defaultAppletIndex = 0  # show DataSelection by default

    # Dataset roles: each index is a position in ROLE_NAMES.
    DATA_ROLE_RAW = 0
    DATA_ROLE_PREDICTION_MASK = 1
    ROLE_NAMES = ['Raw Data', 'Prediction Mask']
    # Export selections.  The order must match the opDataExport.Inputs
    # connections made in connectLane().
    EXPORT_NAMES = [
        'Probabilities', 'Simple Segmentation', 'Uncertainty', 'Features',
        'Labels'
    ]

    # Filter implementation used when the project was created without --filter.
    _DEFAULT_FILTER_IMPLEMENTATION = 'Original'

    @property
    def applets(self):
        return self._applets

    @property
    def imageNameListSlot(self):
        return self.dataSelectionApplet.topLevelOperator.ImageName

    def __init__(self, shell, headless, workflow_cmdline_args,
                 project_creation_args, *args, **kwargs):
        """
        :param shell: forwarded to the Workflow base class.
        :param headless: forwarded to the Workflow base class.
        :param workflow_cmdline_args: workflow-specific args for the current session.
        :param project_creation_args: args saved to the project file when it was first
            created.  Only ``--filter`` is read from them.
        """
        # Create a graph to be shared by all operators
        graph = Graph()
        super(PixelClassificationWorkflow,
              self).__init__(shell,
                             headless,
                             workflow_cmdline_args,
                             project_creation_args,
                             graph=graph,
                             *args,
                             **kwargs)
        # Classifier stashed by prepareForNewLane() and restored by handleNewLanesAdded()
        self.stored_classifier = None
        self._applets = []
        self._workflow_cmdline_args = workflow_cmdline_args

        # Parse workflow-specific command-line args
        parser = argparse.ArgumentParser()
        # default=None (instead of 'Original') lets us tell whether --filter was
        # explicitly given.  The effective default is applied further below.
        parser.add_argument('--filter',
                            help="pixel feature filter implementation.",
                            choices=['Original', 'Refactored', 'Interpolated'],
                            default=None)
        parser.add_argument(
            '--print-labels-by-slice',
            help="Print the number of labels for each Z-slice of each image.",
            action="store_true")
        parser.add_argument(
            '--label-search-value',
            help=
            "If provided, only this value is considered when using --print-labels-by-slice",
            default=0,
            type=int)
        parser.add_argument('--generate-random-labels',
                            help="Add random labels to the project file.",
                            action="store_true")
        parser.add_argument(
            '--random-label-value',
            help="The label value to use injecting random labels",
            default=1,
            type=int)
        parser.add_argument(
            '--random-label-count',
            help=
            "The number of random labels to inject via --generate-random-labels",
            default=2000,
            type=int)
        parser.add_argument(
            '--retrain',
            help=
            "Re-train the classifier based on labels stored in project file, and re-save.",
            action="store_true")
        parser.add_argument('--tree-count',
                            help='Number of trees for Vigra RF classifier.',
                            type=int)
        parser.add_argument('--variable-importance-path',
                            help='Location of variable-importance table.',
                            type=str)
        parser.add_argument(
            '--label-proportion',
            help='Proportion of feature-pixels used to train the classifier.',
            type=float)

        # Parse the creation args: These were saved to the project file when this project was first created.
        parsed_creation_args, unused_args = parser.parse_known_args(
            project_creation_args)
        self.filter_implementation = (parsed_creation_args.filter
                                      or self._DEFAULT_FILTER_IMPLEMENTATION)

        # Parse the cmdline args for the current session.
        parsed_args, unused_args = parser.parse_known_args(
            workflow_cmdline_args)
        self.print_labels_by_slice = parsed_args.print_labels_by_slice
        self.label_search_value = parsed_args.label_search_value
        self.generate_random_labels = parsed_args.generate_random_labels
        self.random_label_value = parsed_args.random_label_value
        self.random_label_count = parsed_args.random_label_count
        self.retrain = parsed_args.retrain
        self.tree_count = parsed_args.tree_count
        self.variable_importance_path = parsed_args.variable_importance_path
        self.label_proportion = parsed_args.label_proportion

        # The filter implementation is fixed at project creation.  Complain only if
        # the user explicitly asked for a different one in this session.  (With a
        # non-None parser default this condition fired on every load of a project
        # created with a non-default filter.)
        if (parsed_args.filter is not None
                and parsed_args.filter != self.filter_implementation):
            logger.error(
                "Ignoring new --filter setting.  Filter implementation cannot be changed after initial project creation."
            )

        # Applets for training (interactive) workflow
        self.dataSelectionApplet = self.createDataSelectionApplet()
        opDataSelection = self.dataSelectionApplet.topLevelOperator

        # see role constants, above
        opDataSelection.DatasetRoles.setValue(
            PixelClassificationWorkflow.ROLE_NAMES)

        self.featureSelectionApplet = self.createFeatureSelectionApplet()

        self.pcApplet = self.createPixelClassificationApplet()
        opClassify = self.pcApplet.topLevelOperator

        self.dataExportApplet = PixelClassificationDataExportApplet(
            self, "Prediction Export")
        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.PmapColors.connect(opClassify.PmapColors)
        opDataExport.LabelNames.connect(opClassify.LabelNames)
        opDataExport.WorkingDirectory.connect(opDataSelection.WorkingDirectory)
        opDataExport.SelectionNames.setValue(self.EXPORT_NAMES)

        # Expose for shell
        self._applets.append(self.dataSelectionApplet)
        self._applets.append(self.featureSelectionApplet)
        self._applets.append(self.pcApplet)
        self._applets.append(self.dataExportApplet)

        self.dataExportApplet.prepare_for_entire_export = self.prepare_for_entire_export
        self.dataExportApplet.post_process_entire_export = self.post_process_entire_export

        self.batchProcessingApplet = BatchProcessingApplet(
            self, "Batch Processing", self.dataSelectionApplet,
            self.dataExportApplet)

        self._applets.append(self.batchProcessingApplet)
        if unused_args:
            # We parse the export setting args first.  All remaining args are considered input files by the input applet.
            self._batch_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args(
                unused_args)
            self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args(
                unused_args)
        else:
            self._batch_input_args = None
            self._batch_export_args = None

        if unused_args:
            logger.warning("Unused command-line args: {}".format(unused_args))

    def createDataSelectionApplet(self):
        """
        Can be overridden by subclasses, if they want to use 
        special parameters to initialize the DataSelectionApplet.
        """
        data_instructions = "Select your input data using the 'Raw Data' tab shown on the right"
        return DataSelectionApplet(self,
                                   "Input Data",
                                   "Input Data",
                                   supportIlastik05Import=True,
                                   instructionText=data_instructions)

    def createFeatureSelectionApplet(self):
        """
        Can be overridden by subclasses, if they want to return their own type of FeatureSelectionApplet.
        NOTE: The applet returned here must have the same interface as the regular FeatureSelectionApplet.
              (If it looks like a duck...)
        """
        return FeatureSelectionApplet(self, "Feature Selection",
                                      "FeatureSelections",
                                      self.filter_implementation)

    def createPixelClassificationApplet(self):
        """
        Can be overridden by subclasses, if they want to return their own type of PixelClassificationApplet.
        NOTE: The applet returned here must have the same interface as the regular PixelClassificationApplet.
              (If it looks like a duck...)
        """
        return PixelClassificationApplet(self, "PixelClassification")

    def prepareForNewLane(self, laneIndex):
        """
        Overridden from Workflow base class.
        Called immediately before a new lane is added to the workflow.
        """
        # When the new lane is added, dirty notifications will propagate throughout the entire graph.
        # This means the classifier will be marked 'dirty' even though it is still usable.
        # Before that happens, let's store the classifier, so we can restore it in handleNewLanesAdded(), below.
        classifier_cache = self.pcApplet.topLevelOperator.classifier_cache
        if classifier_cache.Output.ready() and not classifier_cache._dirty:
            self.stored_classifier = classifier_cache.Output.value
        else:
            self.stored_classifier = None

    def handleNewLanesAdded(self):
        """
        Overridden from Workflow base class.
        Called immediately after a new lane is added to the workflow and initialized.
        """
        # Restore classifier we saved in prepareForNewLane() (if any)
        if self.stored_classifier is not None:
            self.pcApplet.topLevelOperator.classifier_cache.forceValue(
                self.stored_classifier)
            # Release reference
            self.stored_classifier = None

    def connectLane(self, laneIndex):
        """Wire the operators of lane ``laneIndex`` across all applets."""
        # Get a handle to each operator
        opData = self.dataSelectionApplet.topLevelOperator.getLane(laneIndex)
        opTrainingFeatures = self.featureSelectionApplet.topLevelOperator.getLane(
            laneIndex)
        opClassify = self.pcApplet.topLevelOperator.getLane(laneIndex)
        opDataExport = self.dataExportApplet.topLevelOperator.getLane(
            laneIndex)

        # Input Image -> Feature Op
        #         and -> Classification Op (for display)
        opTrainingFeatures.InputImage.connect(opData.Image)
        opClassify.InputImages.connect(opData.Image)

        # Prediction masks are only wired up when ilastik's debug flag is set.
        if ilastik_config.getboolean('ilastik', 'debug'):
            opClassify.PredictionMasks.connect(
                opData.ImageGroup[self.DATA_ROLE_PREDICTION_MASK])

        # Feature Images -> Classification Op (for training, prediction)
        opClassify.FeatureImages.connect(opTrainingFeatures.OutputImage)
        opClassify.CachedFeatureImages.connect(
            opTrainingFeatures.CachedOutputImage)

        # Data Export connections
        opDataExport.RawData.connect(opData.ImageGroup[self.DATA_ROLE_RAW])
        opDataExport.RawDatasetInfo.connect(
            opData.DatasetGroup[self.DATA_ROLE_RAW])
        opDataExport.ConstraintDataset.connect(
            opData.ImageGroup[self.DATA_ROLE_RAW])
        # One input per entry of EXPORT_NAMES, in the same order.
        opDataExport.Inputs.resize(len(self.EXPORT_NAMES))
        opDataExport.Inputs[0].connect(
            opClassify.HeadlessPredictionProbabilities)
        opDataExport.Inputs[1].connect(opClassify.SimpleSegmentation)
        opDataExport.Inputs[2].connect(opClassify.HeadlessUncertaintyEstimate)
        opDataExport.Inputs[3].connect(opClassify.FeatureImages)
        opDataExport.Inputs[4].connect(opClassify.LabelImages)
        for slot in opDataExport.Inputs:
            assert slot.partner is not None

    def handleAppletStateUpdateRequested(self):
        """
        Overridden from Workflow base class
        Called when an applet has fired the :py:attr:`Applet.appletStateUpdateRequested`
        """
        # If no data, nothing else is ready.
        opDataSelection = self.dataSelectionApplet.topLevelOperator
        input_ready = len(opDataSelection.ImageGroup
                          ) > 0 and not self.dataSelectionApplet.busy

        opFeatureSelection = self.featureSelectionApplet.topLevelOperator
        featureOutput = opFeatureSelection.OutputImage
        features_ready = input_ready and \
                         len(featureOutput) > 0 and  \
                         featureOutput[0].ready() and \
                         (TinyVector(featureOutput[0].meta.shape) > 0).all()

        opDataExport = self.dataExportApplet.topLevelOperator
        opPixelClassification = self.pcApplet.topLevelOperator

        # A frozen cache holding None means there is no usable classifier.
        invalid_classifier = opPixelClassification.classifier_cache.fixAtCurrent.value and \
                             opPixelClassification.classifier_cache.Output.ready() and\
                             opPixelClassification.classifier_cache.Output.value is None

        predictions_ready = features_ready and \
                            not invalid_classifier and \
                            len(opDataExport.Inputs) > 0 and \
                            opDataExport.Inputs[0][0].ready() and \
                            (TinyVector(opDataExport.Inputs[0][0].meta.shape) > 0).all()

        # Problems can occur if the features or input data are changed during live update mode.
        # Don't let the user do that.
        live_update_active = not opPixelClassification.FreezePredictions.value

        # The user isn't allowed to touch anything while batch processing is running.
        batch_processing_busy = self.batchProcessingApplet.busy

        self._shell.setAppletEnabled(
            self.dataSelectionApplet, not live_update_active
            and not batch_processing_busy)
        self._shell.setAppletEnabled(
            self.featureSelectionApplet, input_ready and not live_update_active
            and not batch_processing_busy)
        self._shell.setAppletEnabled(
            self.pcApplet, features_ready and not batch_processing_busy)
        self._shell.setAppletEnabled(
            self.dataExportApplet, predictions_ready
            and not batch_processing_busy)

        if self.batchProcessingApplet is not None:
            self._shell.setAppletEnabled(
                self.batchProcessingApplet, predictions_ready
                and not batch_processing_busy)

        # Lastly, check for certain "busy" conditions, during which we
        #  should prevent the shell from closing the project.
        busy = False
        busy |= self.dataSelectionApplet.busy
        busy |= self.featureSelectionApplet.busy
        busy |= self.dataExportApplet.busy
        busy |= self.batchProcessingApplet.busy
        self._shell.enableProjectChanges(not busy)

    def onProjectLoaded(self, projectManager):
        """
        Overridden from Workflow base class.  Called by the Project Manager.
        
        If the user provided command-line arguments, use them to configure 
        the workflow for batch mode and export all results.
        (This workflow's headless mode supports only batch mode for now.)
        """
        if self.generate_random_labels:
            self._generate_random_labels(self.random_label_count,
                                         self.random_label_value)
            logger.info("Saving project...")
            self._shell.projectManager.saveProject()
            logger.info("Done.")

        if self.print_labels_by_slice:
            self._print_labels_by_slice(self.label_search_value)

        if self._headless:
            # In headless mode, let's see the messages from the training operator.
            logging.getLogger(
                "lazyflow.operators.classifierOperators").setLevel(
                    logging.DEBUG)

        # Apply classifier overrides given on the command line.
        if self.variable_importance_path or self.tree_count or self.label_proportion:
            classifier_factory = self.pcApplet.topLevelOperator.opTrain.ClassifierFactory.value
            if self.variable_importance_path:
                classifier_factory.set_variable_importance_path(
                    self.variable_importance_path)
            if self.tree_count:
                classifier_factory.set_num_trees(self.tree_count)
            if self.label_proportion:
                classifier_factory.set_label_proportion(self.label_proportion)

        if self.tree_count or self.label_proportion:
            self.pcApplet.topLevelOperator.ClassifierFactory.setDirty()

        if self.retrain:
            self._force_retrain_classifier(projectManager)

        # Configure the data export operator.
        if self._batch_export_args:
            self.dataExportApplet.configure_operator_with_parsed_args(
                self._batch_export_args)

        if self._batch_input_args and self.pcApplet.topLevelOperator.classifier_cache._dirty:
            logger.warning(
                "Your project file has no classifier.  A new classifier will be trained for this run."
            )

        if self._headless and self._batch_input_args and self._batch_export_args:
            logger.info("Beginning Batch Processing")
            self.batchProcessingApplet.run_export_from_parsed_args(
                self._batch_input_args)
            logger.info("Completed Batch Processing")

    def prepare_for_entire_export(self):
        """
        Assigned to DataExportApplet.prepare_for_entire_export
        (See above.)

        Remembers the current FreezePredictions value, then unfreezes predictions.
        """
        self.freeze_status = self.pcApplet.topLevelOperator.FreezePredictions.value
        self.pcApplet.topLevelOperator.FreezePredictions.setValue(False)

    def post_process_entire_export(self):
        """
        Assigned to DataExportApplet.post_process_entire_export
        (See above.)

        Restores the FreezePredictions value saved by prepare_for_entire_export().
        """
        self.pcApplet.topLevelOperator.FreezePredictions.setValue(
            self.freeze_status)

    def _force_retrain_classifier(self, projectManager):
        """Retrain the classifier from the stored labels and save the project."""
        # Cause the classifier to be dirty so it is forced to retrain.
        # (useful if the stored labels were changed outside ilastik)
        self.pcApplet.topLevelOperator.opTrain.ClassifierFactory.setDirty()

        # Request the classifier, which forces training
        self.pcApplet.topLevelOperator.FreezePredictions.setValue(False)
        _ = self.pcApplet.topLevelOperator.Classifier.value

        # store new classifier to project file
        projectManager.saveProject(force_all_save=False)

    def _print_labels_by_slice(self, search_value):
        """
        Iterate over each label image in the project and print the number of labels present on each Z-slice of the image.
        (This is a special feature requested by the FlyEM proofreaders.)

        :param search_value: if non-zero, count only pixels equal to this value;
            otherwise count all non-zero pixels.
        """
        opTopLevelClassify = self.pcApplet.topLevelOperator
        project_label_count = 0
        for image_index, label_slot in enumerate(
                opTopLevelClassify.LabelImages):
            tagged_shape = label_slot.meta.getTaggedShape()
            if 'z' not in tagged_shape:
                logger.error(
                    "Can't print label counts by Z-slices.  Image #{} has no Z-dimension."
                    .format(image_index))
            else:
                logger.info("Label counts in Z-slices of Image #{}:".format(
                    image_index))
                slicing = [slice(None)] * len(tagged_shape)
                # dict key views have no .index() on Python 3; materialize once.
                z_axis = list(tagged_shape.keys()).index('z')
                blank_slices = []
                image_label_count = 0
                for z in range(tagged_shape['z']):
                    slicing[z_axis] = slice(z, z + 1)
                    label_slice = label_slot[slicing].wait()
                    if search_value:
                        count = (label_slice == search_value).sum()
                    else:
                        count = (label_slice != 0).sum()
                    if count > 0:
                        logger.info("Z={}: {}".format(z, count))
                        image_label_count += count
                    else:
                        blank_slices.append(z)
                project_label_count += image_label_count
                if len(blank_slices) > 20:
                    # Don't list the blank slices if there were a lot of them.
                    logger.info("Image #{} has {} blank slices.".format(
                        image_index, len(blank_slices)))
                elif len(blank_slices) > 0:
                    logger.info("Image #{} has {} blank slices: {}".format(
                        image_index, len(blank_slices), blank_slices))
                else:
                    logger.info(
                        "Image #{} has no blank slices.".format(image_index))
                logger.info("Total labels for Image #{}: {}".format(
                    image_index, image_label_count))
        logger.info("Total labels for project: {}".format(project_label_count))

    def _generate_random_labels(self, labels_per_image, label_value):
        """
        Inject random labels into the project file.
        (This is a special feature requested by the FlyEM proofreaders.)

        :param labels_per_image: number of distinct pixels to label in each image.
            It is capped at the image's pixel count.
        :param label_value: value written into the label images.  Missing label
            names up to this value are added.
        """
        logger.info("Injecting {} labels of value {} into all images.".format(
            labels_per_image, label_value))
        opTopLevelClassify = self.pcApplet.topLevelOperator

        label_names = copy.copy(opTopLevelClassify.LabelNames.value)
        while len(label_names) < label_value:
            label_names.append("Label {}".format(len(label_names) + 1))

        opTopLevelClassify.LabelNames.setValue(label_names)

        for image_index in range(len(opTopLevelClassify.LabelImages)):
            logger.info("Injecting labels into image #{}".format(image_index))
            # For reproducibility of label generation
            SEED = 1
            numpy.random.seed([SEED, image_index])

            label_input_slot = opTopLevelClassify.LabelInputs[image_index]
            label_output_slot = opTopLevelClassify.LabelImages[image_index]

            shape = label_output_slot.meta.shape
            random_labels = numpy.zeros(shape=shape, dtype=numpy.uint8)
            num_pixels = random_labels.size

            # Every label goes onto a distinct blank pixel.  Asking for more labels
            # than pixels would make the search loop below never terminate.
            num_labels = labels_per_image
            if num_labels > num_pixels:
                logger.warning(
                    "Image #{} has only {} pixels; injecting {} labels instead of {}."
                    .format(image_index, num_pixels, num_pixels,
                            labels_per_image))
                num_labels = num_pixels

            current_progress = -1
            for sample_index in range(num_labels):
                flat_index = numpy.random.randint(0, num_pixels)
                # Don't overwrite existing labels
                # Keep looking until we find a blank pixel
                while random_labels.flat[flat_index]:
                    flat_index = numpy.random.randint(0, num_pixels)
                random_labels.flat[flat_index] = label_value

                # Print progress every 10%.  True division is needed here:
                # floor division would always yield 0 and never advance.
                progress = float(sample_index) / num_labels
                progress = 10 * (int(100 * progress) // 10)
                if progress != current_progress:
                    current_progress = progress
                    sys.stdout.write("{}% ".format(current_progress))
                    sys.stdout.flush()

            sys.stdout.write("100%\n")
            # Write into the operator
            label_input_slot[fullSlicing(shape)] = random_labels

        logger.info("Done injecting labels")

    def getHeadlessOutputSlot(self, slotId):
        """
        Not used by the regular app.
        Only used for special cluster scripts.

        :raises Exception: for an unknown ``slotId``.
        """
        # "Regular" (i.e. with the images that the user selected as input data)
        if slotId == "Predictions":
            return self.pcApplet.topLevelOperator.HeadlessPredictionProbabilities
        elif slotId == "PredictionsUint8":
            return self.pcApplet.topLevelOperator.HeadlessUint8PredictionProbabilities
        # "Batch" (i.e. with the images that the user selected as batch inputs).
        # NOTE(review): opBatchPredictionPipeline is never assigned in this class;
        # these ids only work if something else sets it — confirm.
        elif slotId == "BatchPredictions":
            return self.opBatchPredictionPipeline.HeadlessPredictionProbabilities
        elif slotId == "BatchPredictionsUint8":
            return self.opBatchPredictionPipeline.HeadlessUint8PredictionProbabilities

        raise Exception("Unknown headless output slot")
# Example No. 6
# 0
class NewAutocontextWorkflowBase(Workflow):

    workflowName = "New Autocontext Base"
    defaultAppletIndex = 0  # show DataSelection by default

    DATA_ROLE_RAW = 0
    DATA_ROLE_PREDICTION_MASK = 1

    # First export names must match these for the export GUI, because we re-use the ordinary PC gui
    # (See PixelClassificationDataExportGui.)
    EXPORT_NAMES_PER_STAGE = [
        "Probabilities", "Simple Segmentation", "Uncertainty", "Features",
        "Labels", "Input"
    ]

    @property
    def applets(self):
        """The list of applets this workflow exposes to the shell."""
        applet_list = self._applets
        return applet_list

    @property
    def imageNameListSlot(self):
        """The ImageName slot of the data selection applet's top-level operator."""
        op_data_selection = self.dataSelectionApplet.topLevelOperator
        return op_data_selection.ImageName

    def __init__(self, shell, headless, workflow_cmdline_args,
                 project_creation_args, n_stages, *args, **kwargs):
        """
        n_stages: How many iterations of feature selection and pixel classification should be inserted into the workflow.
                  Must be at least 1.

        All other params are just as in PixelClassificationWorkflow

        :raises ValueError: if n_stages is less than 1 (the final stage's
            classifier operator is required below).
        """
        if n_stages < 1:
            raise ValueError(
                "n_stages must be at least 1, got {}".format(n_stages))

        # Create a graph to be shared by all operators
        graph = Graph()
        super(NewAutocontextWorkflowBase, self).__init__(shell,
                                                         headless,
                                                         workflow_cmdline_args,
                                                         project_creation_args,
                                                         graph=graph,
                                                         *args,
                                                         **kwargs)
        # Classifiers stashed by prepareForNewLane(), one per stage (or empty).
        self.stored_classifers = []
        self._applets = []
        self._workflow_cmdline_args = workflow_cmdline_args

        # Parse workflow-specific command-line args
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "--retrain",
            help=
            "Re-train the classifier based on labels stored in project file, and re-save.",
            action="store_true",
        )

        # Parse the creation args: These were saved to the project file when this project was first created.
        # No creation-time option is used by this workflow yet; the result is discarded.
        parser.parse_known_args(project_creation_args)

        # Parse the cmdline args for the current session.
        parsed_args, unused_args = parser.parse_known_args(
            workflow_cmdline_args)
        self.retrain = parsed_args.retrain

        self.dataSelectionApplet = self.createDataSelectionApplet()
        opDataSelection = self.dataSelectionApplet.topLevelOperator

        # see role constants, above
        role_names = ["Raw Data", "Prediction Mask"]
        opDataSelection.DatasetRoles.setValue(role_names)

        self.featureSelectionApplets = []
        self.pcApplets = []
        for i in range(n_stages):
            self.featureSelectionApplets.append(
                self.createFeatureSelectionApplet(i))
            self.pcApplets.append(self.createPixelClassificationApplet(i))
        opFinalClassify = self.pcApplets[-1].topLevelOperator

        # If *any* stage enters 'live update' mode, make sure they all enter live update mode.
        def sync_freeze_predictions_settings(slot, *args):
            freeze_predictions = slot.value
            for pcApplet in self.pcApplets:
                pcApplet.topLevelOperator.FreezePredictions.setValue(
                    freeze_predictions)

        for pcApplet in self.pcApplets:
            pcApplet.topLevelOperator.FreezePredictions.notifyDirty(
                sync_freeze_predictions_settings)

        self.dataExportApplet = PixelClassificationDataExportApplet(
            self, "Prediction Export")
        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.PmapColors.connect(opFinalClassify.PmapColors)
        opDataExport.LabelNames.connect(opFinalClassify.LabelNames)
        opDataExport.WorkingDirectory.connect(opDataSelection.WorkingDirectory)

        # Export names are grouped per stage, last stage first.
        self.EXPORT_NAMES = []
        for stage_index in reversed(range(n_stages)):
            self.EXPORT_NAMES += [
                "{} Stage {}".format(name, stage_index + 1)
                for name in self.EXPORT_NAMES_PER_STAGE
            ]

        # And finally, one last item for *all* probabilities from all stages.
        self.EXPORT_NAMES += ["Probabilities All Stages"]
        opDataExport.SelectionNames.setValue(self.EXPORT_NAMES)

        # Expose for shell: data selection, then (features, classification) per stage, then export.
        self._applets.append(self.dataSelectionApplet)
        self._applets += itertools.chain.from_iterable(
            zip(self.featureSelectionApplets, self.pcApplets))
        self._applets.append(self.dataExportApplet)

        self.dataExportApplet.prepare_for_entire_export = self.prepare_for_entire_export
        self.dataExportApplet.post_process_entire_export = self.post_process_entire_export

        self.batchProcessingApplet = BatchProcessingApplet(
            self, "Batch Processing", self.dataSelectionApplet,
            self.dataExportApplet)

        self._applets.append(self.batchProcessingApplet)
        if unused_args:
            # We parse the export setting args first.  All remaining args are considered input files by the input applet.
            self._batch_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args(
                unused_args)
            self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args(
                unused_args)
        else:
            self._batch_input_args = None
            self._batch_export_args = None

        if unused_args:
            logger.warning("Unused command-line args: {}".format(unused_args))

    def createDataSelectionApplet(self):
        """
        Build the DataSelectionApplet for this workflow.

        Subclasses may override this to use special parameters. The applet is
        given ``forceAxisOrder`` = every channel-last axis order: "yxc", "xyc",
        then each 3- and 4-axis permutation of "tzyx" followed by "c".
        """
        instructions = "Select your input data using the 'Raw Data' tab shown on the right"

        channel_last_orders = ["yxc", "xyc"] + [
            "".join(axes) + "c"
            for axis_count in (3, 4)
            for axes in itertools.permutations("tzyx", axis_count)
        ]

        return DataSelectionApplet(
            self,
            "Input Data",
            "Input Data",
            supportIlastik05Import=False,
            instructionText=instructions,
            forceAxisOrder=channel_last_orders,
        )

    def createFeatureSelectionApplet(self, index):
        """
        Create the FeatureSelectionApplet for stage ``index``.

        Can be overridden by subclasses, if they want to return their own type of FeatureSelectionApplet.
        NOTE: The applet returned here must have the same interface as the regular FeatureSelectionApplet.
              (If it looks like a duck...)
        """
        # Stage 0 keeps the plain group name so the project stays compatible
        # with the pixel classification workflow (e.g. for "Import Project").
        group_name = (
            "FeatureSelections{index:02d}".format(index=index)
            if index > 0 else "FeatureSelections"
        )
        feature_applet = FeatureSelectionApplet(self, "Feature Selection", group_name)
        # Suffix the operator name with the stage index.
        feature_applet.topLevelOperator.name += "{}".format(index)
        return feature_applet

    def createPixelClassificationApplet(self, index=0):
        """
        Create the PixelClassificationApplet for stage ``index``.

        Can be overridden by subclasses, if they want to return their own type of PixelClassificationApplet.
        NOTE: The applet returned here must have the same interface as the regular PixelClassificationApplet.
              (If it looks like a duck...)
        """
        # Stage 0 keeps the plain group name so the project stays compatible
        # with the pixel classification workflow (e.g. for "Import Project").
        group_name = (
            "PixelClassification{index:02d}".format(index=index)
            if index > 0 else "PixelClassification"
        )
        classification_applet = PixelClassificationApplet(self, group_name)
        # Suffix the operator name with the stage index.
        classification_applet.topLevelOperator.name += "{}".format(index)
        return classification_applet

    def prepareForNewLane(self, laneIndex):
        """
        Overridden from Workflow base class.
        Called immediately before a new lane is added to the workflow.

        When the new lane is added, dirty notifications propagate throughout
        the entire graph, so each classifier is marked 'dirty' even though it
        is still usable. Snapshot every stage's classifier here so that
        handleNewLanesAdded() can restore them afterwards.

        The snapshot is all-or-nothing. ``self.stored_classifers`` is either
        one classifier per stage (in stage order) or empty. It is later zipped
        positionally against ``self.pcApplets``, so a partial list would
        restore classifiers into the wrong stages.

        :param laneIndex: index of the lane about to be added (not used here).
        """
        stored = []
        for pcApplet in self.pcApplets:
            classifier_cache = pcApplet.topLevelOperator.classifier_cache
            if not classifier_cache.Output.ready() or classifier_cache._dirty:
                # At least one stage has no usable classifier: store nothing,
                # and stop so later stages can't be appended out of position.
                stored = []
                break
            stored.append(classifier_cache.Output.value)
        self.stored_classifers = stored

    def handleNewLanesAdded(self):
        """
        Overridden from Workflow base class.
        Called immediately after a new lane is added to the workflow and initialized.

        Forces each stage's classifier cache back to the classifier that
        prepareForNewLane() snapshotted (pairing them positionally with
        ``self.pcApplets``). It then clears the snapshot so its classifier
        objects are no longer referenced here. Does nothing if no snapshot
        was taken.
        """
        saved_classifiers = self.stored_classifers
        if not saved_classifiers:
            return

        for stage_applet, saved_classifier in zip(self.pcApplets, saved_classifiers):
            stage_applet.topLevelOperator.classifier_cache.forceValue(saved_classifier)

        # Release references
        self.stored_classifers = []

    def connectLane(self, laneIndex):
        """
        Wire up all operators belonging to lane ``laneIndex``.

        - Stage 0's feature and classification operators receive the raw
          input image directly.
        - Each later stage receives the raw image stacked along "c" with the
          previous stage's PredictionProbabilitiesAutocontext output.
        - The export operator gets, per stage (last stage first), the six
          outputs listed below. It then gets one final slot holding the
          channel-stacked probabilities of all stages.
        """
        # Get a handle to each operator
        opData = self.dataSelectionApplet.topLevelOperator.getLane(laneIndex)
        opFirstFeatures = self.featureSelectionApplets[0].topLevelOperator.getLane(laneIndex)
        opFirstClassify = self.pcApplets[0].topLevelOperator.getLane(laneIndex)
        opDataExport = self.dataExportApplet.topLevelOperator.getLane(laneIndex)

        def checkConstraints(*_):
            # No input constraints are currently enforced.
            # (A check restricting input to UINT8/UINT16 images used to live
            # here and was disabled.)
            pass

        opData.Image.notifyReady(checkConstraints)

        # Input Image -> Feature Op
        #         and -> Classification Op (for display)
        opFirstFeatures.InputImage.connect(opData.Image)
        opFirstClassify.InputImages.connect(opData.Image)

        # Feature Images -> Classification Op (for training, prediction)
        opFirstClassify.FeatureImages.connect(opFirstFeatures.OutputImage)
        opFirstClassify.CachedFeatureImages.connect(opFirstFeatures.CachedOutputImage)

        # Chain the stages: stage i's predictions (stacked with the raw data)
        # become the input of stage i+1.
        for upstreamPcApplet, downstreamFeaturesApplet, downstreamPcApplet in zip(
                self.pcApplets[:-1],
                self.featureSelectionApplets[1:],
                self.pcApplets[1:]):
            opUpstreamClassify = upstreamPcApplet.topLevelOperator.getLane(laneIndex)
            opDownstreamFeatures = downstreamFeaturesApplet.topLevelOperator.getLane(laneIndex)
            opDownstreamClassify = downstreamPcApplet.topLevelOperator.getLane(laneIndex)

            # Note: the stages' LabelInputs are not connected to each other here.

            # The raw image and upstream probabilities are stacked into a
            # single image, so their dtypes must agree.
            assert opData.Image.meta.dtype == opUpstreamClassify.PredictionProbabilitiesAutocontext.meta.dtype, (
                "Probability dtype needs to match up with input image dtype, got: "
                f"input: {opData.Image.meta.dtype} "
                f"probabilities: {opUpstreamClassify.PredictionProbabilitiesAutocontext.meta.dtype}"
            )
            opStacker = OpMultiArrayStacker(parent=self)
            opStacker.Images.resize(2)
            opStacker.Images[0].connect(opData.Image)
            opStacker.Images[1].connect(opUpstreamClassify.PredictionProbabilitiesAutocontext)
            opStacker.AxisFlag.setValue("c")

            opDownstreamFeatures.InputImage.connect(opStacker.Output)
            opDownstreamClassify.InputImages.connect(opStacker.Output)
            opDownstreamClassify.FeatureImages.connect(opDownstreamFeatures.OutputImage)
            opDownstreamClassify.CachedFeatureImages.connect(opDownstreamFeatures.CachedOutputImage)

        # Data Export connections
        opDataExport.RawData.connect(opData.ImageGroup[self.DATA_ROLE_RAW])
        opDataExport.RawDatasetInfo.connect(opData.DatasetGroup[self.DATA_ROLE_RAW])
        opDataExport.ConstraintDataset.connect(opData.ImageGroup[self.DATA_ROLE_RAW])

        opDataExport.Inputs.resize(len(self.EXPORT_NAMES))
        num_items_per_stage = len(self.EXPORT_NAMES_PER_STAGE)
        # Export slots are grouped per stage, last stage first.
        for reverse_stage_index, pcApplet in enumerate(reversed(self.pcApplets)):
            opPc = pcApplet.topLevelOperator.getLane(laneIndex)
            stage_outputs = (
                opPc.HeadlessPredictionProbabilities,
                opPc.SimpleSegmentation,
                opPc.HeadlessUncertaintyEstimate,
                opPc.FeatureImages,
                opPc.LabelImages,
                # Input must come last due to an assumption in PixelClassificationDataExportGui
                opPc.InputImages,
            )
            first_slot_index = num_items_per_stage * reverse_stage_index
            for offset, output_slot in enumerate(stage_outputs):
                opDataExport.Inputs[first_slot_index + offset].connect(output_slot)

        # One last export slot for all probabilities, all stages
        opAllStageStacker = OpMultiArrayStacker(parent=self)
        opAllStageStacker.Images.resize(len(self.pcApplets))
        for stage_index, pcApplet in enumerate(self.pcApplets):
            opPc = pcApplet.topLevelOperator.getLane(laneIndex)
            opAllStageStacker.Images[stage_index].connect(opPc.HeadlessPredictionProbabilities)
        # Stack all stages along the channel axis (set once, after the loop).
        opAllStageStacker.AxisFlag.setValue("c")

        # The ideal_blockshape metadata field will be bogus, so just eliminate it
        # (Otherwise, the channels could be split up in an unfortunate way.)
        opMetadataOverride = OpMetadataInjector(parent=self)
        opMetadataOverride.Input.connect(opAllStageStacker.Output)
        opMetadataOverride.Metadata.setValue({"ideal_blockshape": None})

        opDataExport.Inputs[-1].connect(opMetadataOverride.Output)

        # Sanity check: every export slot must now be connected upstream.
        for slot in opDataExport.Inputs:
            assert slot.upstream_slot is not None

    def handleAppletStateUpdateRequested(self):
        """
        Overridden from Workflow base class
        Called when an applet has fired the :py:attr:`Applet.appletStateUpdateRequested`

        Computes readiness flags for every stage (feature selection + pixel
        classification) and enables or disables each applet accordingly. It
        then tells the shell whether project changes are currently allowed.
        """
        opDataSelection = self.dataSelectionApplet.topLevelOperator

        # First, determine various 'ready' states for each pixel classification stage (features+prediction)
        StageFlags = collections.namedtuple(
            "StageFlags",
            "input_ready features_ready invalid_classifier predictions_ready live_update_active"
        )
        stage_flags = []
        for stage_index, (featureSelectionApplet, pcApplet) in enumerate(
                zip(self.featureSelectionApplets, self.pcApplets)):
            if stage_index == 0:
                # If no data, nothing else is ready.
                input_ready = (len(opDataSelection.ImageGroup) > 0
                               and not self.dataSelectionApplet.busy)
            else:
                # Later stages are fed by the previous stage's predictions.
                input_ready = stage_flags[stage_index - 1].predictions_ready

            featureOutput = featureSelectionApplet.topLevelOperator.OutputImage
            features_ready = (
                input_ready and len(featureOutput) > 0
                and featureOutput[0].ready()
                and (TinyVector(featureOutput[0].meta.shape) > 0).all())

            opPixelClassification = pcApplet.topLevelOperator
            classifier_cache = opPixelClassification.classifier_cache
            # The cache is pinned (fixAtCurrent) but holds no classifier.
            invalid_classifier = (
                classifier_cache.fixAtCurrent.value
                and classifier_cache.Output.ready()
                and classifier_cache.Output.value is None)

            predictionOutput = opPixelClassification.HeadlessPredictionProbabilities
            predictions_ready = (
                features_ready and not invalid_classifier
                and len(predictionOutput) > 0
                and predictionOutput[0].ready()
                and (TinyVector(predictionOutput[0].meta.shape) > 0).all())

            live_update_active = not opPixelClassification.FreezePredictions.value

            stage_flags.append(
                StageFlags(input_ready, features_ready, invalid_classifier,
                           predictions_ready, live_update_active))

        # Problems can occur if the features or input data are changed during live update mode.
        # Don't let the user do that.
        any_live_update = any(flags.live_update_active for flags in stage_flags)

        # The user isn't allowed to touch anything while batch processing is running.
        batch_processing_busy = self.batchProcessingApplet.busy

        self._shell.setAppletEnabled(
            self.dataSelectionApplet,
            not any_live_update and not batch_processing_busy)

        for stage_index, (featureSelectionApplet, pcApplet) in enumerate(
                zip(self.featureSelectionApplets, self.pcApplets)):
            flags = stage_flags[stage_index]
            downstream_live_update = any(
                later.live_update_active for later in stage_flags[stage_index + 1:])

            # Features may only change when neither this stage nor any later
            # stage is in live update mode.
            self._shell.setAppletEnabled(
                featureSelectionApplet,
                flags.input_ready
                and not flags.live_update_active and not downstream_live_update
                and not batch_processing_busy,
            )

            # Downstream live update need not be checked here: live update
            # modes are synced across stages, and labels can't be added in live update.
            self._shell.setAppletEnabled(
                pcApplet,
                flags.features_ready and not batch_processing_busy,
            )

        # Export and batch processing both depend on the FINAL stage's predictions.
        final_predictions_ready = stage_flags[-1].predictions_ready
        self._shell.setAppletEnabled(
            self.dataExportApplet,
            final_predictions_ready and not batch_processing_busy)
        self._shell.setAppletEnabled(
            self.batchProcessingApplet,
            final_predictions_ready and not batch_processing_busy)

        # Lastly, check for certain "busy" conditions, during which we
        #  should prevent the shell from closing the project.
        busy = False
        busy |= self.dataSelectionApplet.busy
        busy |= any(applet.busy for applet in self.featureSelectionApplets)
        busy |= self.dataExportApplet.busy
        busy |= self.batchProcessingApplet.busy
        self._shell.enableProjectChanges(not busy)

    def onProjectLoaded(self, projectManager):
        """
        Overridden from Workflow base class.  Called by the Project Manager.

        Steps, in order:
          1. In headless mode, raise the "lazyflow.operators.classifierOperators"
             logger to DEBUG so the training operator's messages are visible.
          2. If --retrain was requested, retrain and re-save the classifiers.
          3. If export args were parsed, configure the data export operator.
          4. If batch input args were parsed, warn once if any stage's
             classifier cache is dirty (a new classifier will be trained).
          5. In headless mode with both batch input and export args, run the
             batch export.
        (This workflow's headless mode supports only batch mode for now.)
        """
        if self._headless:
            # In headless mode, let's see the messages from the training operator.
            training_logger = logging.getLogger("lazyflow.operators.classifierOperators")
            training_logger.setLevel(logging.DEBUG)

        if self.retrain:
            self._force_retrain_classifiers(projectManager)

        export_args = self._batch_export_args
        input_args = self._batch_input_args

        # Configure the data export operator.
        if export_args:
            self.dataExportApplet.configure_operator_with_parsed_args(export_args)

        if input_args:
            untrained = any(
                applet.topLevelOperator.classifier_cache._dirty
                for applet in self.pcApplets)
            if untrained:
                logger.warning(
                    "At least one of your classifiers is not yet trained.  "
                    "A new classifier will be trained for this run.")

        run_batch = self._headless and input_args and export_args
        if run_batch:
            logger.info("Beginning Batch Processing")
            self.batchProcessingApplet.run_export_from_parsed_args(input_args)
            logger.info("Completed Batch Processing")

    def prepare_for_entire_export(self):
        """
        Set up the workflow before the data export applet exports everything.

        - If the selected export name contains "all stages" (case-insensitive),
          only the LAST stage's feature cache is bypassed. The earlier caches
          help avoid recomputing those stages.
        - Otherwise every stage's feature cache is bypassed.
        - Each stage's FreezePredictions value is recorded in
          ``self.freeze_statuses``, then set to False so the export uses
          up-to-date labels.
        """
        opExport = self.dataExportApplet.topLevelOperator
        selected_index = opExport.InputSelection.value
        selected_name = opExport.SelectionNames.value[selected_index]

        exporting_all_stages = "all stages" in selected_name.lower()
        if not exporting_all_stages:
            # While exporting, we don't want to cache any data.
            for feature_applet in self.featureSelectionApplets:
                feature_applet.topLevelOperator.BypassCache.setValue(True)
        else:
            # Multi-stage export: keep caches, except for the last stage.
            self.featureSelectionApplets[-1].topLevelOperator.BypassCache.setValue(True)

        # Unfreeze the classifier caches (ensure that we're exporting based on up-to-date labels).
        # Read and reset each stage in turn; the order matters because
        # changing one stage's FreezePredictions may affect the others.
        self.freeze_statuses = []
        for stage_applet in self.pcApplets:
            freeze_slot = stage_applet.topLevelOperator.FreezePredictions
            self.freeze_statuses.append(freeze_slot.value)
            freeze_slot.setValue(False)

    def post_process_entire_export(self):
        """
        Undo prepare_for_entire_export() after an export of everything.

        Re-enables feature caching on every stage. Then restores each stage's
        FreezePredictions value from ``self.freeze_statuses``, paired with
        ``self.pcApplets`` in order.
        """
        # While exporting, we disabled caches, but now we can enable them again.
        for feature_applet in self.featureSelectionApplets:
            feature_applet.topLevelOperator.BypassCache.setValue(False)

        # Re-freeze classifier caches (if necessary)
        restore_pairs = zip(self.pcApplets, self.freeze_statuses)
        for stage_applet, was_frozen in restore_pairs:
            stage_applet.topLevelOperator.FreezePredictions.setValue(was_frozen)

    def _force_retrain_classifiers(self, projectManager):
        # Cause the FIRST classifier to be dirty so it is forced to retrain.
        # (useful if the stored labels were changed outside ilastik)
        self.pcApplets[0].topLevelOperator.opTrain.ClassifierFactory.setDirty()

        # Unfreeze all classifier caches.
        for pcApplet in self.pcApplets:
            pcApplet.topLevelOperator.FreezePredictions.setValue(False)

        # Request the LAST classifier, which forces training
        _ = self.pcApplets[-1].topLevelOperator.Classifier.value

        # store new classifiers to project file
        projectManager.saveProject(force_all_save=False)

    def menus(self):
        """
        Overridden from Workflow base class

        Returns a one-element list holding the "Autocontext Utilities" menu.
        Its "Distribute Labels..." action triggers
        distribute_labels_from_current_stage().
        """
        from PyQt5.QtWidgets import QMenu

        menu = QMenu("Autocontext Utilities")
        distribute_action = menu.addAction("Distribute Labels...")
        distribute_action.triggered.connect(self.distribute_labels_from_current_stage)

        # Must retain here as a member or else reference disappears and the menu is deleted.
        self._autocontext_menu = menu
        return [menu]

    def distribute_labels_from_current_stage(self):
        """
        Distribute labels from the currently viewed stage across other stages.

        The user chooses destination stages and a mode via
        get_label_distribution_settings():

        - Each destination stage first receives the source stage's label
          colors, pmap colors and label names (for the source stage's classes).
        - Then, for every lane, the source stage's non-zero label blocks are
          either copied to every destination stage, or (partition mode) each
          labeled pixel goes to one randomly chosen destination stage.
        - In partition mode, if the source stage is itself a destination, its
          existing labels are cleared first.

        Shows an error box if the active page is not a Training page. Does
        nothing if the user cancels or selects no destination stage.
        """
        # Late import.
        # (Don't import PyQt in headless mode.)
        from PyQt5.QtWidgets import QMessageBox

        current_applet = self._applets[self.shell.currentAppletIndex]
        if current_applet not in self.pcApplets:
            QMessageBox.critical(
                self.shell, "Wrong page selected",
                "The currently active page isn't a Training page.")
            return

        current_stage_index = self.pcApplets.index(current_applet)
        destination_stage_indexes, partition = self.get_label_distribution_settings(
            current_stage_index, num_stages=len(self.pcApplets))
        if not destination_stage_indexes:
            # None: user cancelled. []: no destination stage selected, so there
            # is nothing to do (and np.random.choice() would raise on an
            # empty population in partition mode).
            return

        opCurrentPixelClassification = current_applet.topLevelOperator
        num_current_stage_classes = len(opCurrentPixelClassification.LabelNames.value)

        # The source stage's class metadata does not depend on the destination.
        current_stage_label_colors = opCurrentPixelClassification.LabelColors.value
        current_stage_pmap_colors = opCurrentPixelClassification.PmapColors.value
        current_stage_label_names = opCurrentPixelClassification.LabelNames.value

        def overwrite_leading_classes(slot, source_values):
            # Replace the first num_current_stage_classes entries of slot's
            # list value with the source stage's entries.
            new_values = list(slot.value)
            new_values[:num_current_stage_classes] = source_values[:num_current_stage_classes]
            slot.setValue(new_values)

        # Before we get started, make sure the destination stages have the necessary label classes
        for stage_index in destination_stage_indexes:
            opPc = self.pcApplets[stage_index].topLevelOperator
            overwrite_leading_classes(opPc.LabelColors, current_stage_label_colors)
            overwrite_leading_classes(opPc.PmapColors, current_stage_pmap_colors)
            overwrite_leading_classes(opPc.LabelNames, current_stage_label_names)

        # For each lane, copy over the labels from the source stage to the destination stages
        for lane_index in range(len(opCurrentPixelClassification.InputImages)):
            opPcLane = opCurrentPixelClassification.getLane(lane_index)

            # Gather all the labels for this lane
            blockwise_labels = {}
            for block_slicing in opPcLane.NonzeroLabelBlocks.value:
                # Convert from slicing to roi-tuple so we can hash it in a dict key
                block_roi = sliceToRoi(block_slicing, opPcLane.InputImages.meta.shape)
                block_roi = tuple(map(tuple, block_roi))
                blockwise_labels[block_roi] = opPcLane.LabelImages[block_slicing].wait()

            if partition and current_stage_index in destination_stage_indexes:
                # Clear all labels in the current lane, since we'll be overwriting it with a subset of labels
                # FIXME: We could implement a fast function for this in OpCompressedUserLabelArray...
                for label_value in range(1, num_current_stage_classes + 1):
                    opPcLane.opLabelPipeline.opLabelArray.clearLabel(label_value)

            # Now redistribute those labels across the destination stages
            for block_roi, block_labels in blockwise_labels.items():
                block_slicing = roiToSlice(*block_roi)
                nonzero_coords = block_labels.nonzero()

                if partition:
                    # One (num_labels, ndim) coordinate array per block, plus a
                    # random destination stage for every labeled pixel.
                    coords_per_label = np.transpose(nonzero_coords)
                    num_labels = len(nonzero_coords[0])
                    destination_stage_map = np.random.choice(
                        destination_stage_indexes, (num_labels,))

                for stage_index in destination_stage_indexes:
                    if not partition:
                        this_stage_block_labels = block_labels
                    else:
                        # Divide into disjoint partitions:
                        # keep only the labels destined for this stage.
                        this_stage_coords = tuple(
                            coords_per_label[destination_stage_map == stage_index].transpose())
                        this_stage_block_labels = np.zeros_like(block_labels)
                        this_stage_block_labels[this_stage_coords] = block_labels[this_stage_coords]

                    # Get the current lane's view of this stage's OpPixelClassification, and inject
                    opPc = self.pcApplets[stage_index].topLevelOperator.getLane(lane_index)
                    opPc.LabelInputs[block_slicing] = this_stage_block_labels

    @staticmethod
    def get_label_distribution_settings(source_stage_index, num_stages):
        """
        Ask the user how labels from ``source_stage_index`` should be distributed.

        Shows a modal dialog with:
        - one checkbox per stage (labelled 1-based; the source stage is pre-checked);
        - a Copy / Partition choice (Partition pre-selected).

        :returns: ``(destination_stage_indexes, partition)``.
            ``destination_stage_indexes`` is the list of checked 0-based stage
            indexes and ``partition`` is True when "Partition" was chosen.
            Returns ``(None, None)`` if the dialog was rejected.
        """
        # Late import.
        # (Don't import PyQt in headless mode.)
        from PyQt5.QtWidgets import QDialog, QVBoxLayout

        class LabelDistributionOptionsDlg(QDialog):
            """
            A little dialog to let the user specify how the labels should be
            distributed from the current stages to the other stages.
            """

            def __init__(self, source_stage_index, num_stages, *args, **kwargs):
                super(LabelDistributionOptionsDlg, self).__init__(*args, **kwargs)

                from PyQt5.QtCore import Qt
                from PyQt5.QtWidgets import QGroupBox, QCheckBox, QRadioButton, QDialogButtonBox

                # Stages are presented to the user 1-based.
                source_stage_number = source_stage_index + 1
                self.setWindowTitle("Distributing from Stage {}".format(source_stage_number))

                self.stage_checkboxes = [
                    QCheckBox("Stage {}".format(stage_number))
                    for stage_number in range(1, num_stages + 1)
                ]
                # By default, send labels back into the current stage, at least.
                self.stage_checkboxes[source_stage_index].setChecked(True)

                stage_selection_layout = QVBoxLayout()
                for stage_checkbox in self.stage_checkboxes:
                    stage_selection_layout.addWidget(stage_checkbox)

                stage_selection_groupbox = QGroupBox(
                    "Send labels from Stage {} to:".format(source_stage_number), self)
                stage_selection_groupbox.setLayout(stage_selection_layout)

                self.copy_button = QRadioButton("Copy", self)
                self.partition_button = QRadioButton("Partition", self)
                self.partition_button.setChecked(True)
                distribution_mode_layout = QVBoxLayout()
                for mode_button in (self.copy_button, self.partition_button):
                    distribution_mode_layout.addWidget(mode_button)

                distribution_mode_group = QGroupBox("Distribution Mode", self)
                distribution_mode_group.setLayout(distribution_mode_layout)

                buttonbox = QDialogButtonBox(Qt.Horizontal, parent=self)
                buttonbox.setStandardButtons(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
                buttonbox.accepted.connect(self.accept)
                buttonbox.rejected.connect(self.reject)

                dlg_layout = QVBoxLayout()
                for section in (stage_selection_groupbox, distribution_mode_group, buttonbox):
                    dlg_layout.addWidget(section)
                self.setLayout(dlg_layout)

            def distribution_mode(self):
                """Return "copy" or "partition", depending on the checked radio button."""
                if self.copy_button.isChecked():
                    return "copy"
                if self.partition_button.isChecked():
                    return "partition"
                assert False, "Shouldn't get here."

            def destination_stages(self):
                """
                Return the list of stage_indexes (0-based) that the user checked.
                """
                return [
                    stage_index
                    for stage_index, stage_checkbox in enumerate(self.stage_checkboxes)
                    if stage_checkbox.isChecked()
                ]

        dlg = LabelDistributionOptionsDlg(source_stage_index, num_stages)
        if dlg.exec_() == QDialog.Rejected:
            return (None, None)

        chosen_stages = dlg.destination_stages()
        use_partition = dlg.distribution_mode() == "partition"
        return (chosen_stages, use_partition)
# Exemplo n.º 7 (Example No. 7)
# 0
    def __init__(self, shell, headless, workflow_cmdline_args,
                 project_creation_args, n_stages, *args, **kwargs):
        """
        Build an autocontext workflow with ``n_stages`` chained stages.

        n_stages: How many iterations of feature selection and pixel classification should be inserted into the workflow.
                  Must be at least 1; a ValueError is raised otherwise.

        All other params are just as in PixelClassificationWorkflow
        """
        # Fail early with a clear message instead of an IndexError further down
        # (at self.pcApplets[-1]) after the base class is already initialized.
        if n_stages < 1:
            raise ValueError("n_stages must be at least 1, got {}".format(n_stages))

        # Create a graph to be shared by all operators
        graph = Graph()
        super(NewAutocontextWorkflowBase, self).__init__(shell,
                                                         headless,
                                                         workflow_cmdline_args,
                                                         project_creation_args,
                                                         graph=graph,
                                                         *args,
                                                         **kwargs)
        self.stored_classifers = []
        self._applets = []
        self._workflow_cmdline_args = workflow_cmdline_args

        # Parse workflow-specific command-line args.
        # The only option (--retrain) is a per-session option, so the
        # project creation args are not parsed here.
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "--retrain",
            help=
            "Re-train the classifier based on labels stored in project file, and re-save.",
            action="store_true",
        )

        # Parse the cmdline args for the current session.
        parsed_args, unused_args = parser.parse_known_args(workflow_cmdline_args)
        self.retrain = parsed_args.retrain

        self.dataSelectionApplet = self.createDataSelectionApplet()
        opDataSelection = self.dataSelectionApplet.topLevelOperator

        # see role constants, above
        role_names = ["Raw Data", "Prediction Mask"]
        opDataSelection.DatasetRoles.setValue(role_names)

        # One feature-selection + pixel-classification applet pair per stage.
        self.featureSelectionApplets = []
        self.pcApplets = []
        for i in range(n_stages):
            self.featureSelectionApplets.append(self.createFeatureSelectionApplet(i))
            self.pcApplets.append(self.createPixelClassificationApplet(i))
        opFinalClassify = self.pcApplets[-1].topLevelOperator

        # If *any* stage enters 'live update' mode, make sure they all enter live update mode.
        def sync_freeze_predictions_settings(slot, *args):
            freeze_predictions = slot.value
            for pcApplet in self.pcApplets:
                pcApplet.topLevelOperator.FreezePredictions.setValue(freeze_predictions)

        for pcApplet in self.pcApplets:
            pcApplet.topLevelOperator.FreezePredictions.notifyDirty(
                sync_freeze_predictions_settings)

        self.dataExportApplet = PixelClassificationDataExportApplet(self, "Prediction Export")
        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.PmapColors.connect(opFinalClassify.PmapColors)
        opDataExport.LabelNames.connect(opFinalClassify.LabelNames)
        opDataExport.WorkingDirectory.connect(opDataSelection.WorkingDirectory)

        # Per-stage export names, last stage first.
        self.EXPORT_NAMES = [
            "{} Stage {}".format(name, stage_index + 1)
            for stage_index in reversed(range(n_stages))
            for name in self.EXPORT_NAMES_PER_STAGE
        ]

        # And finally, one last item for *all* probabilities from all stages.
        self.EXPORT_NAMES.append("Probabilities All Stages")
        opDataExport.SelectionNames.setValue(self.EXPORT_NAMES)

        # Expose for shell: data selection, then (features, training) per stage, then export.
        self._applets.append(self.dataSelectionApplet)
        self._applets.extend(
            itertools.chain.from_iterable(zip(self.featureSelectionApplets, self.pcApplets)))
        self._applets.append(self.dataExportApplet)

        self.dataExportApplet.prepare_for_entire_export = self.prepare_for_entire_export
        self.dataExportApplet.post_process_entire_export = self.post_process_entire_export

        self.batchProcessingApplet = BatchProcessingApplet(
            self, "Batch Processing", self.dataSelectionApplet, self.dataExportApplet)

        self._applets.append(self.batchProcessingApplet)
        if unused_args:
            # We parse the export setting args first.  All remaining args are considered input files by the input applet.
            self._batch_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args(
                unused_args)
            self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args(
                unused_args)
        else:
            self._batch_input_args = None
            self._batch_export_args = None

        if unused_args:
            logger.warning("Unused command-line args: {}".format(unused_args))
class PixelClassificationWorkflow(Workflow):
    """
    Interactive pixel classification workflow.

    Applets, in the order they are exposed to the shell: project metadata,
    input data selection, feature selection, pixel classification and
    prediction export.  When ``appendBatchOperators`` is true, a batch input
    selection applet and a batch result export applet are appended as well.
    """

    workflowName = "Pixel Classification"
    workflowDescription = "This is obviously self-explanatory."
    defaultAppletIndex = 1  # show DataSelection by default

    # Indexes into the data selection applet's DatasetRoles.
    DATA_ROLE_RAW = 0
    DATA_ROLE_PREDICTION_MASK = 1

    # Export selection names.  The order must match the export Inputs that are
    # connected in connectLane() and _initBatchWorkflow().
    EXPORT_NAMES = [
        'Probabilities', 'Simple Segmentation', 'Uncertainty', 'Features'
    ]

    @property
    def applets(self):
        return self._applets

    @property
    def imageNameListSlot(self):
        return self.dataSelectionApplet.topLevelOperator.ImageName

    def __init__(self,
                 shell,
                 headless,
                 workflow_cmdline_args,
                 project_creation_args,
                 appendBatchOperators=True,
                 *args,
                 **kwargs):
        """
        Build the applets and wire up the (non lane-based) batch pipeline.

        :param workflow_cmdline_args: command-line args for this session.
            Args not understood here are handed to the batch export applet
            first, then to the batch input applet.
        :param project_creation_args: args stored in the project file when it
            was created; only ``--filter`` is taken from them.
        :param appendBatchOperators: if true, also create the batch
            prediction applets.
        """
        # Create a graph to be shared by all operators
        graph = Graph()
        super(PixelClassificationWorkflow,
              self).__init__(shell,
                             headless,
                             workflow_cmdline_args,
                             project_creation_args,
                             graph=graph,
                             *args,
                             **kwargs)
        self._applets = []
        self._workflow_cmdline_args = workflow_cmdline_args

        data_instructions = "Select your input data using the 'Raw Data' tab shown on the right"

        # Parse workflow-specific command-line args
        parser = argparse.ArgumentParser()
        parser.add_argument('--filter',
                            help="pixel feature filter implementation.",
                            choices=['Original', 'Refactored', 'Interpolated'],
                            default='Original')
        parser.add_argument(
            '--print-labels-by-slice',
            help="Print the number of labels for each Z-slice of each image.",
            action="store_true")
        parser.add_argument(
            '--label-search-value',
            help=
            "If provided, only this value is considered when using --print-labels-by-slice",
            default=0,
            type=int)
        parser.add_argument('--generate-random-labels',
                            help="Add random labels to the project file.",
                            action="store_true")
        parser.add_argument(
            '--random-label-value',
            help="The label value to use injecting random labels",
            default=1,
            type=int)
        parser.add_argument(
            '--random-label-count',
            help=
            "The number of random labels to inject via --generate-random-labels",
            default=2000,
            type=int)
        parser.add_argument(
            '--retrain',
            help=
            "Re-train the classifier based on labels stored in project file, and re-save.",
            action="store_true")

        # Parse the creation args: These were saved to the project file when this project was first created.
        parsed_creation_args, unused_args = parser.parse_known_args(
            project_creation_args)
        self.filter_implementation = parsed_creation_args.filter

        # Parse the cmdline args for the current session.
        # For this session we only need to know whether --filter was passed
        # explicitly.  Otherwise the 'Original' default would be compared with
        # the stored implementation and wrongly reported as a change.
        # (Parser-level defaults override argument-level defaults.)
        parser.set_defaults(filter=None)
        parsed_args, unused_args = parser.parse_known_args(
            workflow_cmdline_args)
        self.print_labels_by_slice = parsed_args.print_labels_by_slice
        self.label_search_value = parsed_args.label_search_value
        self.generate_random_labels = parsed_args.generate_random_labels
        self.random_label_value = parsed_args.random_label_value
        self.random_label_count = parsed_args.random_label_count
        self.retrain = parsed_args.retrain

        if parsed_args.filter and parsed_args.filter != self.filter_implementation:
            logger.error(
                "Ignoring new --filter setting.  Filter implementation cannot be changed after initial project creation."
            )

        # Applets for training (interactive) workflow
        self.projectMetadataApplet = ProjectMetadataApplet()
        self.dataSelectionApplet = DataSelectionApplet(
            self,
            "Input Data",
            "Input Data",
            supportIlastik05Import=True,
            batchDataGui=False,
            instructionText=data_instructions)
        opDataSelection = self.dataSelectionApplet.topLevelOperator

        # see role constants, above
        role_names = ['Raw Data']
        if ilastik_config.getboolean('ilastik', 'debug'):
            role_names.append('Prediction Mask')
        opDataSelection.DatasetRoles.setValue(role_names)

        self.featureSelectionApplet = FeatureSelectionApplet(
            self, "Feature Selection", "FeatureSelections",
            self.filter_implementation)

        self.pcApplet = PixelClassificationApplet(self, "PixelClassification")
        opClassify = self.pcApplet.topLevelOperator

        self.dataExportApplet = PixelClassificationDataExportApplet(
            self, "Prediction Export")
        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.PmapColors.connect(opClassify.PmapColors)
        opDataExport.LabelNames.connect(opClassify.LabelNames)
        opDataExport.WorkingDirectory.connect(opDataSelection.WorkingDirectory)
        opDataExport.SelectionNames.setValue(self.EXPORT_NAMES)

        # Expose for shell
        self._applets.append(self.projectMetadataApplet)
        self._applets.append(self.dataSelectionApplet)
        self._applets.append(self.featureSelectionApplet)
        self._applets.append(self.pcApplet)
        self._applets.append(self.dataExportApplet)

        self._batch_input_args = None
        self._batch_export_args = None

        self.batchInputApplet = None
        self.batchResultsApplet = None
        if appendBatchOperators:
            # Create applets for batch workflow
            self.batchInputApplet = DataSelectionApplet(
                self,
                "Batch Prediction Input Selections",
                "Batch Inputs",
                supportIlastik05Import=False,
                batchDataGui=True)
            self.batchResultsApplet = PixelClassificationDataExportApplet(
                self, "Batch Prediction Output Locations", isBatch=True)

            # Expose in shell
            self._applets.append(self.batchInputApplet)
            self._applets.append(self.batchResultsApplet)

            # Connect batch workflow (NOT lane-based)
            self._initBatchWorkflow()

            if unused_args:
                # We parse the export setting args first.  All remaining args are considered input files by the input applet.
                self._batch_export_args, unused_args = self.batchResultsApplet.parse_known_cmdline_args(
                    unused_args)
                self._batch_input_args, unused_args = self.batchInputApplet.parse_known_cmdline_args(
                    unused_args)

        if unused_args:
            logger.warning("Unused command-line args: {}".format(unused_args))

    def connectLane(self, laneIndex):
        """
        Connect the per-lane operators of the interactive workflow.

        :param laneIndex: index of the image lane being connected.
        """
        # Get a handle to each operator
        opData = self.dataSelectionApplet.topLevelOperator.getLane(laneIndex)
        opTrainingFeatures = self.featureSelectionApplet.topLevelOperator.getLane(
            laneIndex)
        opClassify = self.pcApplet.topLevelOperator.getLane(laneIndex)
        opDataExport = self.dataExportApplet.topLevelOperator.getLane(
            laneIndex)

        # Input Image -> Feature Op
        #         and -> Classification Op (for display)
        opTrainingFeatures.InputImage.connect(opData.Image)
        opClassify.InputImages.connect(opData.Image)

        if ilastik_config.getboolean('ilastik', 'debug'):
            # The prediction-mask role only exists in debug mode (see __init__).
            opClassify.PredictionMasks.connect(
                opData.ImageGroup[self.DATA_ROLE_PREDICTION_MASK])

        # Feature Images -> Classification Op (for training, prediction)
        opClassify.FeatureImages.connect(opTrainingFeatures.OutputImage)
        opClassify.CachedFeatureImages.connect(
            opTrainingFeatures.CachedOutputImage)

        # Training flags -> Classification Op (for GUI restrictions)
        opClassify.LabelsAllowedFlags.connect(opData.AllowLabels)

        # Data Export connections
        opDataExport.RawData.connect(opData.ImageGroup[self.DATA_ROLE_RAW])
        opDataExport.RawDatasetInfo.connect(
            opData.DatasetGroup[self.DATA_ROLE_RAW])
        opDataExport.ConstraintDataset.connect(
            opData.ImageGroup[self.DATA_ROLE_RAW])

        # One export input per entry of EXPORT_NAMES, in the same order.
        opDataExport.Inputs.resize(len(self.EXPORT_NAMES))
        opDataExport.Inputs[0].connect(
            opClassify.HeadlessPredictionProbabilities)
        opDataExport.Inputs[1].connect(opClassify.SimpleSegmentation)
        opDataExport.Inputs[2].connect(opClassify.HeadlessUncertaintyEstimate)
        opDataExport.Inputs[3].connect(opClassify.FeatureImages)
        for slot in opDataExport.Inputs:
            assert slot.partner is not None

    def _initBatchWorkflow(self):
        """
        Connect the batch-mode top-level operators to the training workflow and to each other.
        """
        # Access applet operators from the training workflow
        opTrainingDataSelection = self.dataSelectionApplet.topLevelOperator
        opTrainingFeatures = self.featureSelectionApplet.topLevelOperator
        opClassify = self.pcApplet.topLevelOperator

        # Access the batch operators
        opBatchInputs = self.batchInputApplet.topLevelOperator
        opBatchResults = self.batchResultsApplet.topLevelOperator

        opBatchInputs.DatasetRoles.connect(
            opTrainingDataSelection.DatasetRoles)

        # The raw image of the first training lane is used as the export
        # constraint dataset for batch results.
        opSelectFirstLane = OperatorWrapper(OpSelectSubslot, parent=self)
        opSelectFirstLane.Inputs.connect(opTrainingDataSelection.ImageGroup)
        opSelectFirstLane.SubslotIndex.setValue(0)

        opSelectFirstRole = OpSelectSubslot(parent=self)
        opSelectFirstRole.Inputs.connect(opSelectFirstLane.Output)
        opSelectFirstRole.SubslotIndex.setValue(self.DATA_ROLE_RAW)

        opBatchResults.ConstraintDataset.connect(opSelectFirstRole.Output)

        ## Create additional batch workflow operators
        opBatchFeatures = OperatorWrapper(OpFeatureSelectionNoCache,
                                          operator_kwargs={
                                              'filter_implementation':
                                              self.filter_implementation
                                          },
                                          parent=self,
                                          promotedSlotNames=['InputImage'])
        opBatchPredictionPipeline = OperatorWrapper(
            OpPredictionPipelineNoCache, parent=self)

        ## Connect Operators ##
        opTranspose = OpTransposeSlots(parent=self)
        opTranspose.OutputLength.setValue(2)  # There are 2 roles
        opTranspose.Inputs.connect(opBatchInputs.DatasetGroup)
        opTranspose.name = "batchTransposeInputs"

        # Provide dataset paths from data selection applet to the batch export applet
        opBatchResults.RawDatasetInfo.connect(
            opTranspose.Outputs[self.DATA_ROLE_RAW])
        opBatchResults.WorkingDirectory.connect(opBatchInputs.WorkingDirectory)

        # Connect (clone) the feature operator inputs from
        #  the interactive workflow's features operator (which gets them from the GUI)
        opBatchFeatures.Scales.connect(opTrainingFeatures.Scales)
        opBatchFeatures.FeatureIds.connect(opTrainingFeatures.FeatureIds)
        opBatchFeatures.SelectionMatrix.connect(
            opTrainingFeatures.SelectionMatrix)

        # Classifier and NumClasses are provided by the interactive workflow
        opBatchPredictionPipeline.Classifier.connect(opClassify.Classifier)
        opBatchPredictionPipeline.NumClasses.connect(opClassify.NumClasses)

        # Provide these for the gui
        opBatchResults.RawData.connect(opBatchInputs.Image)
        opBatchResults.PmapColors.connect(opClassify.PmapColors)
        opBatchResults.LabelNames.connect(opClassify.LabelNames)

        # Connect Image pathway:
        # Input Image -> Features Op -> Prediction Op -> Export
        opBatchFeatures.InputImage.connect(opBatchInputs.Image)
        opBatchPredictionPipeline.PredictionMask.connect(opBatchInputs.Image1)
        opBatchPredictionPipeline.FeatureImages.connect(
            opBatchFeatures.OutputImage)

        opBatchResults.SelectionNames.setValue(self.EXPORT_NAMES)
        # opBatchResults.Inputs is indexed by [lane][selection],
        # Use OpTranspose to allow connection.
        opTransposeBatchInputs = OpTransposeSlots(parent=self)
        opTransposeBatchInputs.name = "opTransposeBatchInputs"
        opTransposeBatchInputs.OutputLength.setValue(0)
        opTransposeBatchInputs.Inputs.resize(len(self.EXPORT_NAMES))
        opTransposeBatchInputs.Inputs[0].connect(
            opBatchPredictionPipeline.HeadlessPredictionProbabilities
        )  # selection 0
        opTransposeBatchInputs.Inputs[1].connect(
            opBatchPredictionPipeline.SimpleSegmentation)  # selection 1
        opTransposeBatchInputs.Inputs[2].connect(
            opBatchPredictionPipeline.HeadlessUncertaintyEstimate
        )  # selection 2
        opTransposeBatchInputs.Inputs[3].connect(
            opBatchPredictionPipeline.FeatureImages)  # selection 3
        for slot in opTransposeBatchInputs.Inputs:
            assert slot.partner is not None

        # Now opTransposeBatchInputs.Outputs is level-2 indexed by [lane][selection]
        opBatchResults.Inputs.connect(opTransposeBatchInputs.Outputs)

        # NOTE: the batch pipeline does not need the cached feature path, so
        # opBatchPredictionPipeline.CachedFeatureImages is intentionally left
        # unconnected here.

        self.opBatchFeatures = opBatchFeatures
        self.opBatchPredictionPipeline = opBatchPredictionPipeline

    def handleAppletStateUpdateRequested(self):
        """
        Overridden from Workflow base class
        Called when an applet has fired the :py:attr:`Applet.appletStateUpdateRequested`

        Enables/disables each applet according to how far the workflow has
        been configured, and blocks project changes while an applet is busy.
        """
        # If no data, nothing else is ready.
        opDataSelection = self.dataSelectionApplet.topLevelOperator
        input_ready = (len(opDataSelection.ImageGroup) > 0
                       and not self.dataSelectionApplet.busy)

        opFeatureSelection = self.featureSelectionApplet.topLevelOperator
        featureOutput = opFeatureSelection.OutputImage
        features_ready = (input_ready
                          and len(featureOutput) > 0
                          and featureOutput[0].ready()
                          and (TinyVector(featureOutput[0].meta.shape) > 0).all())

        opDataExport = self.dataExportApplet.topLevelOperator
        opPixelClassification = self.pcApplet.topLevelOperator

        # A frozen classifier cache that holds no classifier cannot predict.
        classifier_cache = opPixelClassification.classifier_cache
        invalid_classifier = (classifier_cache.fixAtCurrent.value
                              and classifier_cache.Output.ready()
                              and classifier_cache.Output.value is None)

        predictions_ready = (features_ready
                             and not invalid_classifier
                             and len(opDataExport.Inputs) > 0
                             and opDataExport.Inputs[0][0].ready()
                             and (TinyVector(opDataExport.Inputs[0][0].meta.shape) > 0).all())

        # Problems can occur if the features or input data are changed during live update mode.
        # Don't let the user do that.
        live_update_active = not opPixelClassification.FreezePredictions.value

        self._shell.setAppletEnabled(self.dataSelectionApplet,
                                     not live_update_active)
        self._shell.setAppletEnabled(self.featureSelectionApplet, input_ready
                                     and not live_update_active)
        self._shell.setAppletEnabled(self.pcApplet, features_ready)
        self._shell.setAppletEnabled(self.dataExportApplet, predictions_ready)

        if self.batchInputApplet is not None:
            # Training workflow must be fully configured before batch can be used
            self._shell.setAppletEnabled(self.batchInputApplet,
                                         predictions_ready)

            opBatchDataSelection = self.batchInputApplet.topLevelOperator
            batch_input_ready = (predictions_ready
                                 and len(opBatchDataSelection.ImageGroup) > 0)
            self._shell.setAppletEnabled(self.batchResultsApplet,
                                         batch_input_ready)

        # Lastly, check for certain "busy" conditions, during which we
        #  should prevent the shell from closing the project.
        busy = False
        busy |= self.dataSelectionApplet.busy
        busy |= self.featureSelectionApplet.busy
        busy |= self.dataExportApplet.busy
        self._shell.enableProjectChanges(not busy)

    def getHeadlessOutputSlot(self, slotId):
        """
        Return the output slot named by ``slotId``.

        "Predictions" / "PredictionsUint8" refer to the interactive workflow;
        "BatchPredictions" / "BatchPredictionsUint8" refer to the batch
        pipeline, which only exists if the workflow was built with
        ``appendBatchOperators=True``.

        :raises Exception: for any other ``slotId``.
        """
        # "Regular" (i.e. with the images that the user selected as input data)
        if slotId == "Predictions":
            return self.pcApplet.topLevelOperator.HeadlessPredictionProbabilities
        elif slotId == "PredictionsUint8":
            return self.pcApplet.topLevelOperator.HeadlessUint8PredictionProbabilities
        # "Batch" (i.e. with the images that the user selected as batch inputs).
        elif slotId == "BatchPredictions":
            return self.opBatchPredictionPipeline.HeadlessPredictionProbabilities
        elif slotId == "BatchPredictionsUint8":
            return self.opBatchPredictionPipeline.HeadlessUint8PredictionProbabilities

        raise Exception("Unknown headless output slot")

    def onProjectLoaded(self, projectManager):
        """
        Overridden from Workflow base class.  Called by the Project Manager.

        If the user provided command-line arguments, use them to configure
        the workflow for batch mode and export all results.
        (This workflow's headless mode supports only batch mode for now.)

        Depending on the session's command-line flags this may also inject
        random labels (and save), print per-slice label counts, and force
        the classifier to be retrained (and save).
        """
        if self.generate_random_labels:
            self._generate_random_labels(self.random_label_count,
                                         self.random_label_value)
            logger.info("Saving project...")
            self._shell.projectManager.saveProject()
            logger.info("Done.")

        if self.print_labels_by_slice:
            self._print_labels_by_slice(self.label_search_value)

        # Configure the batch data selection operator.
        if self._batch_input_args and (self._batch_input_args.input_files
                                       or self._batch_input_args.raw_data):
            self.batchInputApplet.configure_operator_with_parsed_args(
                self._batch_input_args)

        # Configure the data export operator.
        if self._batch_export_args:
            self.batchResultsApplet.configure_operator_with_parsed_args(
                self._batch_export_args)

        if self._batch_input_args and self.pcApplet.topLevelOperator.classifier_cache._dirty:
            logger.warning(
                "Your project file has no classifier.  A new classifier will be trained for this run."
            )

        # Let's see the messages from the training operator.
        logging.getLogger("lazyflow.operators.classifierOperators").setLevel(
            logging.DEBUG)

        if self.retrain:
            # Cause the classifier to be dirty so it is forced to retrain.
            # (useful if the stored labels were changed outside ilastik)
            self.pcApplet.topLevelOperator.opTrain.ClassifierFactory.setDirty()

            # Request the classifier, which forces training
            self.pcApplet.topLevelOperator.FreezePredictions.setValue(False)
            _ = self.pcApplet.topLevelOperator.Classifier.value

            # store new classifier to project file
            projectManager.saveProject(force_all_save=False)

        if self._headless and self._batch_input_args and self._batch_export_args:
            # Make sure we're using the up-to-date classifier.
            self.pcApplet.topLevelOperator.FreezePredictions.setValue(False)

            def print_progress(progress):
                sys.stdout.write("{} ".format(progress))
                sys.stdout.flush()

            # Now run the batch export and report progress....
            opBatchDataExport = self.batchResultsApplet.topLevelOperator
            for i, opExportDataLaneView in enumerate(opBatchDataExport):
                logger.info("Exporting result {} to {}".format(
                    i, opExportDataLaneView.ExportPath.value))

                sys.stdout.write("Result {}/{} Progress: ".format(
                    i, len(opBatchDataExport)))
                sys.stdout.flush()

                # If the operator provides a progress signal, use it.
                slotProgressSignal = opExportDataLaneView.progressSignal
                slotProgressSignal.subscribe(print_progress)
                opExportDataLaneView.run_export()

                # Finished.
                sys.stdout.write("\n")

    def _print_labels_by_slice(self, search_value):
        """
        Iterate over each label image in the project and print the number of labels present on each Z-slice of the image.
        (This is a special feature requested by the FlyEM proofreaders.)

        :param search_value: if non-zero, only pixels equal to this value are
            counted; otherwise every non-zero label pixel is counted.
        """
        opTopLevelClassify = self.pcApplet.topLevelOperator
        project_label_count = 0
        for image_index, label_slot in enumerate(
                opTopLevelClassify.LabelImages):
            tagged_shape = label_slot.meta.getTaggedShape()
            if 'z' not in tagged_shape:
                logger.error(
                    "Can't print label counts by Z-slices.  Image #{} has no Z-dimension."
                    .format(image_index))
                continue

            logger.info("Label counts in Z-slices of Image #{}:".format(
                image_index))
            # dict key views have no .index() in Python 3; materialize once.
            z_axis = list(tagged_shape.keys()).index('z')
            slicing = [slice(None)] * len(tagged_shape)
            blank_slices = []
            image_label_count = 0
            for z in range(tagged_shape['z']):
                slicing[z_axis] = slice(z, z + 1)
                label_slice = label_slot[slicing].wait()
                if search_value:
                    count = (label_slice == search_value).sum()
                else:
                    count = (label_slice != 0).sum()
                if count > 0:
                    logger.info("Z={}: {}".format(z, count))
                    image_label_count += count
                else:
                    blank_slices.append(z)
            project_label_count += image_label_count
            if len(blank_slices) > 20:
                # Don't list the blank slices if there were a lot of them.
                logger.info("Image #{} has {} blank slices.".format(
                    image_index, len(blank_slices)))
            elif len(blank_slices) > 0:
                logger.info("Image #{} has {} blank slices: {}".format(
                    image_index, len(blank_slices), blank_slices))
            else:
                logger.info(
                    "Image #{} has no blank slices.".format(image_index))
            logger.info("Total labels for Image #{}: {}".format(
                image_index, image_label_count))
        logger.info("Total labels for project: {}".format(project_label_count))

    def _generate_random_labels(self, labels_per_image, label_value):
        """
        Inject random labels into the project file.
        (This is a special feature requested by the FlyEM proofreaders.)

        Label names are extended (as "Label N") until ``label_value`` has a
        name.  Each image gets its own reproducible seed.

        :param labels_per_image: number of distinct pixels to label per image;
            clamped to the image's pixel count.
        :param label_value: value written into each chosen pixel.
        """
        logger.info("Injecting {} labels of value {} into all images.".format(
            labels_per_image, label_value))
        opTopLevelClassify = self.pcApplet.topLevelOperator

        label_names = copy.copy(opTopLevelClassify.LabelNames.value)
        while len(label_names) < label_value:
            label_names.append("Label {}".format(len(label_names) + 1))

        opTopLevelClassify.LabelNames.setValue(label_names)

        # For reproducibility of label generation
        SEED = 1
        for image_index in range(len(opTopLevelClassify.LabelImages)):
            logger.info("Injecting labels into image #{}".format(image_index))
            numpy.random.seed([SEED, image_index])

            label_input_slot = opTopLevelClassify.LabelInputs[image_index]
            label_output_slot = opTopLevelClassify.LabelImages[image_index]

            shape = label_output_slot.meta.shape
            random_labels = numpy.zeros(shape=shape, dtype=numpy.uint8)
            num_pixels = random_labels.size

            # Only blank pixels are labelled, so asking for more labels than
            # there are pixels would make the search loop below spin forever.
            num_labels = labels_per_image
            if num_labels > num_pixels:
                logger.warning(
                    "Image #{} has only {} pixels; injecting {} labels instead of {}."
                    .format(image_index, num_pixels, num_pixels, num_labels))
                num_labels = num_pixels

            current_progress = -1
            for sample_index in range(num_labels):
                flat_index = numpy.random.randint(0, num_pixels)
                # Don't overwrite existing labels
                # Keep looking until we find a blank pixel
                while random_labels.flat[flat_index]:
                    flat_index = numpy.random.randint(0, num_pixels)
                random_labels.flat[flat_index] = label_value

                # Print progress every 10% (floor division keeps whole tens).
                progress = float(sample_index) / num_labels
                progress = 10 * (int(100 * progress) // 10)
                if progress != current_progress:
                    current_progress = progress
                    sys.stdout.write("{}% ".format(current_progress))
                    sys.stdout.flush()

            sys.stdout.write("100%\n")
            # Write into the operator
            label_input_slot[fullSlicing(shape)] = random_labels

        logger.info("Done injecting labels")
Exemplo n.º 10
0
    def __init__(self, shell, headless, workflow_cmdline_args, project_creation_args, *args, **kwargs):
        """
        Build the pixel classification applets (data selection, feature selection,
        pixel classification, prediction export, batch processing) and parse the
        workflow-specific command-line options.

        :param shell: passed through to the Workflow base class.
        :param headless: passed through to the Workflow base class.
        :param workflow_cmdline_args: command-line args for this session. Workflow
            options (``--retrain``, ``--tree-count``, ...) are consumed here; any
            leftover args are offered first to the export applet and then to the
            batch processing applet.
        :param project_creation_args: args stored in the project file when it was
            created. They are parsed (so argparse validates them) but no option is
            currently read back from them.
        """
        # Create a graph to be shared by all operators
        graph = Graph()
        super( PixelClassificationWorkflow, self ).__init__( shell, headless, workflow_cmdline_args, project_creation_args, graph=graph, *args, **kwargs )
        self.stored_classifier = None
        self._applets = []
        self._workflow_cmdline_args = workflow_cmdline_args
        # Parse workflow-specific command-line args
        parser = argparse.ArgumentParser()
        parser.add_argument('--print-labels-by-slice', help="Print the number of labels for each Z-slice of each image.", action="store_true")
        parser.add_argument('--label-search-value', help="If provided, only this value is considered when using --print-labels-by-slice", default=0, type=int)
        parser.add_argument('--generate-random-labels', help="Add random labels to the project file.", action="store_true")
        parser.add_argument('--random-label-value', help="The label value to use injecting random labels", default=1, type=int)
        parser.add_argument('--random-label-count', help="The number of random labels to inject via --generate-random-labels", default=2000, type=int)
        parser.add_argument('--retrain', help="Re-train the classifier based on labels stored in project file, and re-save.", action="store_true")
        parser.add_argument('--tree-count', help='Number of trees for Vigra RF classifier.', type=int)
        parser.add_argument('--variable-importance-path', help='Location of variable-importance table.', type=str)
        parser.add_argument('--label-proportion', help='Proportion of feature-pixels used to train the classifier.', type=float)

        # Parse the creation args: These were saved to the project file when this project was first created.
        # Nothing is read back from them yet; the parse is kept so argparse still validates them.
        parser.parse_known_args(project_creation_args)

        # Parse the cmdline args for the current session.
        parsed_args, unused_args = parser.parse_known_args(workflow_cmdline_args)
        self.print_labels_by_slice = parsed_args.print_labels_by_slice
        self.label_search_value = parsed_args.label_search_value
        self.generate_random_labels = parsed_args.generate_random_labels
        self.random_label_value = parsed_args.random_label_value
        self.random_label_count = parsed_args.random_label_count
        self.retrain = parsed_args.retrain
        self.tree_count = parsed_args.tree_count
        self.variable_importance_path = parsed_args.variable_importance_path
        self.label_proportion = parsed_args.label_proportion

        # Applets for training (interactive) workflow
        self.dataSelectionApplet = self.createDataSelectionApplet()
        opDataSelection = self.dataSelectionApplet.topLevelOperator

        self.featureSelectionApplet = self.createFeatureSelectionApplet()

        self.pcApplet = self.createPixelClassificationApplet()
        opClassify = self.pcApplet.topLevelOperator

        self.dataExportApplet = PixelClassificationDataExportApplet(self, "Prediction Export")
        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.PmapColors.connect( opClassify.PmapColors )
        opDataExport.LabelNames.connect( opClassify.LabelNames )
        opDataExport.WorkingDirectory.connect( opDataSelection.WorkingDirectory )
        opDataExport.SelectionNames.setValue( self.ExportNames.asDisplayNameList() )

        # Expose for shell
        self._applets.append(self.dataSelectionApplet)
        self._applets.append(self.featureSelectionApplet)
        self._applets.append(self.pcApplet)
        self._applets.append(self.dataExportApplet)

        # Hooks invoked by the export applet around a full export run.
        self.dataExportApplet.prepare_for_entire_export = self.prepare_for_entire_export
        self.dataExportApplet.post_process_entire_export = self.post_process_entire_export

        self.batchProcessingApplet = BatchProcessingApplet(self,
                                                           "Batch Processing",
                                                           self.dataSelectionApplet,
                                                           self.dataExportApplet)

        self._applets.append(self.batchProcessingApplet)
        if unused_args:
            # We parse the export setting args first.  All remaining args are considered input files by the input applet.
            self._batch_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args( unused_args )
            self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args( unused_args )
        else:
            self._batch_input_args = None
            self._batch_export_args = None

        if unused_args:
            logger.warning("Unused command-line args: {}".format( unused_args ))
Exemplo n.º 11
0
    def __init__(self, shell, headless, workflow_cmdline_args, appendBatchOperators=True, *args, **kwargs):
        """
        Build the training applets and, if ``appendBatchOperators`` is true, the
        batch prediction applets.

        :param shell: passed through to the Workflow base class.
        :param headless: passed through to the Workflow base class.
        :param workflow_cmdline_args: command-line args for this session.
            ``--filter`` selects the feature filter implementation. When batch
            applets are enabled, leftover args go first to the batch export
            applet and then to the batch input applet.
        :param appendBatchOperators: create and wire the batch applets if true.
        """
        # Create a graph to be shared by all operators
        graph = Graph()
        super( PixelClassificationWorkflow, self ).__init__( shell, headless, graph=graph, *args, **kwargs )
        self._applets = []
        self._workflow_cmdline_args = workflow_cmdline_args

        data_instructions = "Select your input data using the 'Raw Data' tab shown on the right"

        # Parse workflow-specific command-line args
        parser = argparse.ArgumentParser()
        parser.add_argument('--filter', help="pixel feature filter implementation.", choices=['Original', 'Refactored', 'Interpolated'], default='Original')
        parsed_args, unused_args = parser.parse_known_args(workflow_cmdline_args)
        self.filter_implementation = parsed_args.filter

        # Applets for training (interactive) workflow
        self.projectMetadataApplet = ProjectMetadataApplet()
        self.dataSelectionApplet = DataSelectionApplet( self,
                                                        "Input Data",
                                                        "Input Data",
                                                        supportIlastik05Import=True,
                                                        batchDataGui=False,
                                                        instructionText=data_instructions )
        opDataSelection = self.dataSelectionApplet.topLevelOperator
        opDataSelection.DatasetRoles.setValue( ['Raw Data'] )

        self.featureSelectionApplet = FeatureSelectionApplet(self, "Feature Selection", "FeatureSelections", self.filter_implementation)

        self.pcApplet = PixelClassificationApplet(self, "PixelClassification")
        opClassify = self.pcApplet.topLevelOperator

        self.dataExportApplet = PixelClassificationDataExportApplet(self, "Prediction Export")
        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.PmapColors.connect( opClassify.PmapColors )
        opDataExport.LabelNames.connect( opClassify.LabelNames )
        opDataExport.WorkingDirectory.connect( opDataSelection.WorkingDirectory )

        # Expose for shell
        self._applets.append(self.projectMetadataApplet)
        self._applets.append(self.dataSelectionApplet)
        self._applets.append(self.featureSelectionApplet)
        self._applets.append(self.pcApplet)
        self._applets.append(self.dataExportApplet)

        self._batch_input_args = None
        self._batch_export_args = None

        self.batchInputApplet = None
        self.batchResultsApplet = None
        if appendBatchOperators:
            # Create applets for batch workflow
            self.batchInputApplet = DataSelectionApplet(self, "Batch Prediction Input Selections", "Batch Inputs", supportIlastik05Import=False, batchDataGui=True)
            self.batchResultsApplet = PixelClassificationDataExportApplet(self, "Batch Prediction Output Locations", isBatch=True)

            # Expose in shell
            self._applets.append(self.batchInputApplet)
            self._applets.append(self.batchResultsApplet)

            # Connect batch workflow (NOT lane-based)
            self._initBatchWorkflow()

            if unused_args:
                # We parse the export setting args first.  All remaining args are considered input files by the input applet.
                self._batch_export_args, unused_args = self.batchResultsApplet.parse_known_cmdline_args( unused_args )
                self._batch_input_args, unused_args = self.batchInputApplet.parse_known_cmdline_args( unused_args )

        if unused_args:
            # Logger.warn is a deprecated alias; use warning().
            logger.warning("Unused command-line args: {}".format( unused_args ))
Exemplo n.º 12
0
class PixelClassificationWorkflow(Workflow):
    """
    Pixel classification workflow.

    The training (interactive) pipeline runs project metadata, then data
    selection, feature selection, pixel classification and prediction export.
    When constructed with ``appendBatchOperators=True``, a separate,
    non-lane-based batch pipeline is added. It has its own input and export
    applets and reuses the trained classifier and feature settings.
    """

    workflowName = "Pixel Classification"
    workflowDescription = "This is obviously self-explanoratory."
    # Index 0 is the project metadata applet, so index 1 is data selection.
    defaultAppletIndex = 1 # show DataSelection by default

    @property
    def applets(self):
        """The applets shown by the shell, in display order."""
        return self._applets

    @property
    def imageNameListSlot(self):
        """Slot holding the names of the training input images."""
        return self.dataSelectionApplet.topLevelOperator.ImageName

    def __init__(self, shell, headless, workflow_cmdline_args, appendBatchOperators=True, *args, **kwargs):
        """
        Build the training applets and, if ``appendBatchOperators`` is true, the
        batch prediction applets.

        :param shell: passed through to the Workflow base class.
        :param headless: passed through to the Workflow base class.
        :param workflow_cmdline_args: command-line args for this session.
            ``--filter`` selects the feature filter implementation. When batch
            applets are enabled, leftover args go first to the batch export
            applet and then to the batch input applet.
        :param appendBatchOperators: create and wire the batch applets if true.
        """
        # Create a graph to be shared by all operators
        graph = Graph()
        super( PixelClassificationWorkflow, self ).__init__( shell, headless, graph=graph, *args, **kwargs )
        self._applets = []
        self._workflow_cmdline_args = workflow_cmdline_args

        data_instructions = "Select your input data using the 'Raw Data' tab shown on the right"

        # Parse workflow-specific command-line args
        parser = argparse.ArgumentParser()
        parser.add_argument('--filter', help="pixel feature filter implementation.", choices=['Original', 'Refactored', 'Interpolated'], default='Original')
        parsed_args, unused_args = parser.parse_known_args(workflow_cmdline_args)
        self.filter_implementation = parsed_args.filter

        # Applets for training (interactive) workflow
        self.projectMetadataApplet = ProjectMetadataApplet()
        self.dataSelectionApplet = DataSelectionApplet( self,
                                                        "Input Data",
                                                        "Input Data",
                                                        supportIlastik05Import=True,
                                                        batchDataGui=False,
                                                        instructionText=data_instructions )
        opDataSelection = self.dataSelectionApplet.topLevelOperator
        opDataSelection.DatasetRoles.setValue( ['Raw Data'] )

        self.featureSelectionApplet = FeatureSelectionApplet(self, "Feature Selection", "FeatureSelections", self.filter_implementation)

        self.pcApplet = PixelClassificationApplet(self, "PixelClassification")
        opClassify = self.pcApplet.topLevelOperator

        self.dataExportApplet = PixelClassificationDataExportApplet(self, "Prediction Export")
        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.PmapColors.connect( opClassify.PmapColors )
        opDataExport.LabelNames.connect( opClassify.LabelNames )
        opDataExport.WorkingDirectory.connect( opDataSelection.WorkingDirectory )

        # Expose for shell
        self._applets.append(self.projectMetadataApplet)
        self._applets.append(self.dataSelectionApplet)
        self._applets.append(self.featureSelectionApplet)
        self._applets.append(self.pcApplet)
        self._applets.append(self.dataExportApplet)

        self._batch_input_args = None
        self._batch_export_args = None

        self.batchInputApplet = None
        self.batchResultsApplet = None
        if appendBatchOperators:
            # Create applets for batch workflow
            self.batchInputApplet = DataSelectionApplet(self, "Batch Prediction Input Selections", "Batch Inputs", supportIlastik05Import=False, batchDataGui=True)
            self.batchResultsApplet = PixelClassificationDataExportApplet(self, "Batch Prediction Output Locations", isBatch=True)

            # Expose in shell
            self._applets.append(self.batchInputApplet)
            self._applets.append(self.batchResultsApplet)

            # Connect batch workflow (NOT lane-based)
            self._initBatchWorkflow()

            if unused_args:
                # We parse the export setting args first.  All remaining args are considered input files by the input applet.
                self._batch_export_args, unused_args = self.batchResultsApplet.parse_known_cmdline_args( unused_args )
                self._batch_input_args, unused_args = self.batchInputApplet.parse_known_cmdline_args( unused_args )

        if unused_args:
            # Logger.warn is a deprecated alias; use warning().
            logger.warning("Unused command-line args: {}".format( unused_args ))

    def connectLane(self, laneIndex):
        """
        Wire lane ``laneIndex`` of the training applets together.

        The pipeline is input image -> features -> classifier -> prediction export.
        """
        # Get a handle to each operator
        opData = self.dataSelectionApplet.topLevelOperator.getLane(laneIndex)
        opTrainingFeatures = self.featureSelectionApplet.topLevelOperator.getLane(laneIndex)
        opClassify = self.pcApplet.topLevelOperator.getLane(laneIndex)
        opDataExport = self.dataExportApplet.topLevelOperator.getLane(laneIndex)

        # Input Image -> Feature Op
        #         and -> Classification Op (for display)
        opTrainingFeatures.InputImage.connect( opData.Image )
        opClassify.InputImages.connect( opData.Image )

        # Feature Images -> Classification Op (for training, prediction)
        opClassify.FeatureImages.connect( opTrainingFeatures.OutputImage )
        opClassify.CachedFeatureImages.connect( opTrainingFeatures.CachedOutputImage )

        # Training flags -> Classification Op (for GUI restrictions)
        opClassify.LabelsAllowedFlags.connect( opData.AllowLabels )

        # Data Export connections
        opDataExport.RawData.connect( opData.ImageGroup[0] )
        opDataExport.Input.connect( opClassify.HeadlessPredictionProbabilities )
        opDataExport.RawDatasetInfo.connect( opData.DatasetGroup[0] )
        opDataExport.ConstraintDataset.connect( opData.ImageGroup[0] )

    def _initBatchWorkflow(self):
        """
        Connect the batch-mode top-level operators to the training workflow and to each other.

        The batch features operator clones its feature settings from the training
        feature operator. The batch prediction pipeline reuses the training
        classifier. The result is stored in ``self.opBatchPredictionPipeline``.
        """
        # Access applet operators from the training workflow
        opTrainingDataSelection = self.dataSelectionApplet.topLevelOperator
        opTrainingFeatures = self.featureSelectionApplet.topLevelOperator
        opClassify = self.pcApplet.topLevelOperator

        # Access the batch operators
        opBatchInputs = self.batchInputApplet.topLevelOperator
        opBatchResults = self.batchResultsApplet.topLevelOperator

        opBatchInputs.DatasetRoles.connect( opTrainingDataSelection.DatasetRoles )

        # Pick role 0 of lane 0 of the training data. It serves as the export constraint dataset.
        opSelectFirstLane = OperatorWrapper( OpSelectSubslot, parent=self )
        opSelectFirstLane.Inputs.connect( opTrainingDataSelection.ImageGroup )
        opSelectFirstLane.SubslotIndex.setValue(0)

        opSelectFirstRole = OpSelectSubslot( parent=self )
        opSelectFirstRole.Inputs.connect( opSelectFirstLane.Output )
        opSelectFirstRole.SubslotIndex.setValue(0)

        opBatchResults.ConstraintDataset.connect( opSelectFirstRole.Output )

        ## Create additional batch workflow operators
        opBatchFeatures = OperatorWrapper( OpFeatureSelectionNoCache, operator_kwargs={'filter_implementation': self.filter_implementation}, parent=self, promotedSlotNames=['InputImage'] )
        opBatchPredictionPipeline = OperatorWrapper( OpPredictionPipelineNoCache, parent=self )

        ## Connect Operators ##
        opTranspose = OpTransposeSlots( parent=self )
        opTranspose.OutputLength.setValue(1)
        opTranspose.Inputs.connect( opBatchInputs.DatasetGroup )

        # Provide dataset paths from data selection applet to the batch export applet
        opBatchResults.RawDatasetInfo.connect( opTranspose.Outputs[0] )
        opBatchResults.WorkingDirectory.connect( opBatchInputs.WorkingDirectory )

        # Connect (clone) the feature operator inputs from
        #  the interactive workflow's features operator (which gets them from the GUI)
        opBatchFeatures.Scales.connect( opTrainingFeatures.Scales )
        opBatchFeatures.FeatureIds.connect( opTrainingFeatures.FeatureIds )
        opBatchFeatures.SelectionMatrix.connect( opTrainingFeatures.SelectionMatrix )

        # Classifier and NumClasses are provided by the interactive workflow
        opBatchPredictionPipeline.Classifier.connect( opClassify.Classifier )
        opBatchPredictionPipeline.FreezePredictions.setValue( False )
        opBatchPredictionPipeline.NumClasses.connect( opClassify.NumClasses )

        # Provide these for the gui
        opBatchResults.RawData.connect( opBatchInputs.Image )
        opBatchResults.PmapColors.connect( opClassify.PmapColors )
        opBatchResults.LabelNames.connect( opClassify.LabelNames )

        # Connect Image pathway:
        # Input Image -> Features Op -> Prediction Op -> Export
        opBatchFeatures.InputImage.connect( opBatchInputs.Image )
        opBatchPredictionPipeline.FeatureImages.connect( opBatchFeatures.OutputImage )
        opBatchResults.Input.connect( opBatchPredictionPipeline.HeadlessPredictionProbabilities )

        # We don't actually need the cached path in the batch pipeline.
        # NOTE(review): CachedFeatureImages is intentionally left unconnected here;
        # presumably OpPredictionPipelineNoCache does not require it -- confirm.

        self.opBatchPredictionPipeline = opBatchPredictionPipeline

    def handleAppletStateUpdateRequested(self):
        """
        Overridden from Workflow base class
        Called when an applet has fired the :py:attr:`Applet.appletStateUpdateRequested`

        Enables each applet only when its upstream stage is ready. Blocks
        project changes while any applet reports itself busy.
        """
        # If no data, nothing else is ready.
        opDataSelection = self.dataSelectionApplet.topLevelOperator
        input_ready = len(opDataSelection.ImageGroup) > 0 and not self.dataSelectionApplet.busy

        opFeatureSelection = self.featureSelectionApplet.topLevelOperator
        featureOutput = opFeatureSelection.OutputImage
        features_ready = input_ready and \
                         len(featureOutput) > 0 and  \
                         featureOutput[0].ready() and \
                         (TinyVector(featureOutput[0].meta.shape) > 0).all()

        opDataExport = self.dataExportApplet.topLevelOperator
        predictions_ready = features_ready and \
                            len(opDataExport.Input) > 0 and \
                            opDataExport.Input[0].ready() and \
                            (TinyVector(opDataExport.Input[0].meta.shape) > 0).all()

        # Problems can occur if the features or input data are changed during live update mode.
        # Don't let the user do that.
        opPixelClassification = self.pcApplet.topLevelOperator
        live_update_active = not opPixelClassification.FreezePredictions.value

        self._shell.setAppletEnabled(self.dataSelectionApplet, not live_update_active)
        self._shell.setAppletEnabled(self.featureSelectionApplet, input_ready and not live_update_active)
        self._shell.setAppletEnabled(self.pcApplet, features_ready)
        self._shell.setAppletEnabled(self.dataExportApplet, predictions_ready)

        if self.batchInputApplet is not None:
            # Training workflow must be fully configured before batch can be used
            self._shell.setAppletEnabled(self.batchInputApplet, predictions_ready)

            opBatchDataSelection = self.batchInputApplet.topLevelOperator
            batch_input_ready = predictions_ready and \
                                len(opBatchDataSelection.ImageGroup) > 0
            self._shell.setAppletEnabled(self.batchResultsApplet, batch_input_ready)

        # Lastly, check for certain "busy" conditions, during which we
        #  should prevent the shell from closing the project.
        busy = False
        busy |= self.dataSelectionApplet.busy
        busy |= self.featureSelectionApplet.busy
        busy |= self.dataExportApplet.busy
        if self.batchInputApplet is not None:
            # The batch applets have the same types as the training data-selection
            # and export applets. A running batch import/export must also block
            # closing the project.
            busy |= self.batchInputApplet.busy
            busy |= self.batchResultsApplet.busy
        self._shell.enableProjectChanges( not busy )

    def getHeadlessOutputSlot(self, slotId):
        """
        Return the output slot registered under ``slotId``.

        The "Batch*" ids read ``self.opBatchPredictionPipeline``. That attribute
        exists only when the batch operators were created (appendBatchOperators=True).

        :raises Exception: if ``slotId`` is not a known id.
        """
        # "Regular" (i.e. with the images that the user selected as input data)
        if slotId == "Predictions":
            return self.pcApplet.topLevelOperator.HeadlessPredictionProbabilities
        elif slotId == "PredictionsUint8":
            return self.pcApplet.topLevelOperator.HeadlessUint8PredictionProbabilities
        # "Batch" (i.e. with the images that the user selected as batch inputs).
        elif slotId == "BatchPredictions":
            return self.opBatchPredictionPipeline.HeadlessPredictionProbabilities
        elif slotId == "BatchPredictionsUint8":
            return self.opBatchPredictionPipeline.HeadlessUint8PredictionProbabilities

        raise Exception("Unknown headless output slot")

    def onProjectLoaded(self, projectManager):
        """
        Overridden from Workflow base class.  Called by the Project Manager.

        If the user provided command-line arguments, use them to configure
        the workflow for batch mode and export all results.
        (This workflow's headless mode supports only batch mode for now.)
        """
        # Configure the batch data selection operator.
        if self._batch_input_args and self._batch_input_args.input_files:
            self.batchInputApplet.configure_operator_with_parsed_args( self._batch_input_args )

        # Configure the data export operator.
        if self._batch_export_args:
            self.batchResultsApplet.configure_operator_with_parsed_args( self._batch_export_args )

        if self._headless and self._batch_input_args and self._batch_export_args:

            # Make sure we're using the up-to-date classifier.
            self.pcApplet.topLevelOperator.FreezePredictions.setValue(False)

            # Now run the batch export and report progress....
            opBatchDataExport = self.batchResultsApplet.topLevelOperator
            for i, opExportDataLaneView in enumerate(opBatchDataExport):
                # Single-argument print(...) behaves the same on Python 2 and 3.
                print("Exporting result {} to {}".format(i, opExportDataLaneView.ExportPath.value))

                sys.stdout.write( "Result {}/{} Progress: ".format( i, len( opBatchDataExport ) ) )
                def print_progress( progress ):
                    sys.stdout.write( "{} ".format( progress ) )
                    # Progress is written without a newline; flush so it shows up incrementally.
                    sys.stdout.flush()

                # If the operator provides a progress signal, use it.
                slotProgressSignal = opExportDataLaneView.progressSignal
                slotProgressSignal.subscribe( print_progress )
                opExportDataLaneView.run_export()

                # Finished.
                sys.stdout.write("\n")
Exemplo n.º 13
0
    def __init__(self, shell, headless, workflow_cmdline_args, project_creation_args, n_stages, *args, **kwargs):
        """
        n_stages: How many iterations of feature selection and pixel classification should be inserted into the workflow.

        All other params are just as in PixelClassificationWorkflow

        Builds the data selection applet, then ``n_stages`` pairs of feature
        selection and pixel classification applets, then the prediction export
        and batch processing applets. Leftover command-line args go first to the
        export applet and then to the batch processing applet.
        """
        # Create a graph to be shared by all operators
        graph = Graph()
        super( NewAutocontextWorkflowBase, self ).__init__( shell, headless, workflow_cmdline_args, project_creation_args, graph=graph, *args, **kwargs )
        self.stored_classifers = []
        self._applets = []
        self._workflow_cmdline_args = workflow_cmdline_args

        # Parse workflow-specific command-line args
        parser = argparse.ArgumentParser()
        parser.add_argument('--retrain', help="Re-train the classifier based on labels stored in project file, and re-save.", action="store_true")

        # Parse the creation args: These were saved to the project file when this project was first created.
        # Nothing is read back from them yet; the parse is kept so argparse still validates them.
        parser.parse_known_args(project_creation_args)

        # Parse the cmdline args for the current session.
        parsed_args, unused_args = parser.parse_known_args(workflow_cmdline_args)
        self.retrain = parsed_args.retrain

        self.dataSelectionApplet = self.createDataSelectionApplet()
        opDataSelection = self.dataSelectionApplet.topLevelOperator

        # Role order must match the DATA_ROLE_* class constants.
        role_names = ['Raw Data', 'Prediction Mask']
        opDataSelection.DatasetRoles.setValue( role_names )

        self.featureSelectionApplets = []
        self.pcApplets = []
        for i in range(n_stages):
            self.featureSelectionApplets.append( self.createFeatureSelectionApplet(i) )
            self.pcApplets.append( self.createPixelClassificationApplet(i) )
        opFinalClassify = self.pcApplets[-1].topLevelOperator

        # If *any* stage enters 'live update' mode, make sure they all enter live update mode.
        # NOTE(review): this relies on setValue() not re-firing notifyDirty when the
        # value is unchanged; otherwise these handlers would recurse -- confirm.
        def sync_freeze_predictions_settings( slot, *args ):
            freeze_predictions = slot.value
            for pcApplet in self.pcApplets:
                pcApplet.topLevelOperator.FreezePredictions.setValue( freeze_predictions )
        for pcApplet in self.pcApplets:
            pcApplet.topLevelOperator.FreezePredictions.notifyDirty( sync_freeze_predictions_settings )

        self.dataExportApplet = PixelClassificationDataExportApplet(self, "Prediction Export")
        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.PmapColors.connect( opFinalClassify.PmapColors )
        opDataExport.LabelNames.connect( opFinalClassify.LabelNames )
        opDataExport.WorkingDirectory.connect( opDataSelection.WorkingDirectory )

        # Export names are listed last stage first, e.g. "Probabilities Stage 2", ..., "Probabilities Stage 1", ...
        self.EXPORT_NAMES = []
        for stage_index in reversed(range(n_stages)):
            self.EXPORT_NAMES += ["{} Stage {}".format( name, stage_index+1 ) for name in self.EXPORT_NAMES_PER_STAGE]

        # And finally, one last item for *all* probabilities from all stages.
        self.EXPORT_NAMES += ["Probabilities All Stages"]
        opDataExport.SelectionNames.setValue( self.EXPORT_NAMES )

        # Expose for shell: data selection, then (features, classifier) for each stage, then export.
        self._applets.append(self.dataSelectionApplet)
        self._applets += itertools.chain.from_iterable(zip(self.featureSelectionApplets, self.pcApplets))
        self._applets.append(self.dataExportApplet)

        # Hooks invoked by the export applet around a full export run.
        self.dataExportApplet.prepare_for_entire_export = self.prepare_for_entire_export
        self.dataExportApplet.post_process_entire_export = self.post_process_entire_export

        self.batchProcessingApplet = BatchProcessingApplet(self,
                                                           "Batch Processing",
                                                           self.dataSelectionApplet,
                                                           self.dataExportApplet)

        self._applets.append(self.batchProcessingApplet)
        if unused_args:
            # We parse the export setting args first.  All remaining args are considered input files by the input applet.
            self._batch_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args( unused_args )
            self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args( unused_args )
        else:
            self._batch_input_args = None
            self._batch_export_args = None

        if unused_args:
            logger.warning("Unused command-line args: {}".format( unused_args ))
Exemplo n.º 14
0
class NewAutocontextWorkflowBase(Workflow):
    
    workflowName = "New Autocontext Base"
    defaultAppletIndex = 0 # show DataSelection by default
    
    DATA_ROLE_RAW = 0
    DATA_ROLE_PREDICTION_MASK = 1
    
    # First export names must match these for the export GUI, because we re-use the ordinary PC gui
    # (See PixelClassificationDataExportGui.)
    EXPORT_NAMES_PER_STAGE = ['Probabilities', 'Simple Segmentation', 'Uncertainty', 'Features', 'Labels', 'Input']
    
    @property
    def applets(self):
        """Return the list of applets that make up this workflow (the shell displays them)."""
        applet_list = self._applets
        return applet_list

    @property
    def imageNameListSlot(self):
        """Return the ``ImageName`` slot of the data selection applet's top-level operator."""
        opDataSelection = self.dataSelectionApplet.topLevelOperator
        return opDataSelection.ImageName

    def __init__(self, shell, headless, workflow_cmdline_args, project_creation_args, n_stages, *args, **kwargs):
        """
        Build an autocontext workflow made of ``n_stages`` (feature selection, pixel classification) stages,
        followed by a data export applet and a batch processing applet.

        n_stages: How many iterations of feature selection and pixel classification should be inserted
                  into the workflow.  Must be at least 1.

        All other params are just as in PixelClassificationWorkflow

        Raises ValueError if ``n_stages`` is less than 1.
        """
        if n_stages < 1:
            # The export applet is wired to the final stage's classifier, so at least one stage is required.
            raise ValueError("NewAutocontextWorkflowBase needs at least one stage (n_stages={})".format(n_stages))

        # Create a graph to be shared by all operators
        graph = Graph()
        super( NewAutocontextWorkflowBase, self ).__init__( shell, headless, workflow_cmdline_args, project_creation_args, graph=graph, *args, **kwargs )
        # Classifiers stashed by prepareForNewLane() and restored by handleNewLanesAdded().
        # (Those methods use this exact attribute spelling, so don't rename it.)
        self.stored_classifers = []
        self._applets = []
        self._workflow_cmdline_args = workflow_cmdline_args

        # Parse workflow-specific command-line args
        parser = argparse.ArgumentParser()
        parser.add_argument('--retrain', help="Re-train the classifier based on labels stored in project file, and re-save.", action="store_true")

        # Parse the creation args: These were saved to the project file when this project was first created.
        # No creation-time option is currently used, but parsing still lets argparse reject malformed ones.
        parser.parse_known_args(project_creation_args)

        # Parse the cmdline args for the current session.
        # Unrecognized args are handed to the export / batch processing applets below.
        parsed_args, unused_args = parser.parse_known_args(workflow_cmdline_args)
        self.retrain = parsed_args.retrain

        self.dataSelectionApplet = self.createDataSelectionApplet()
        opDataSelection = self.dataSelectionApplet.topLevelOperator

        # Order must match the DATA_ROLE_RAW / DATA_ROLE_PREDICTION_MASK class constants.
        role_names = ['Raw Data', 'Prediction Mask']
        opDataSelection.DatasetRoles.setValue( role_names )

        # One (feature selection, pixel classification) applet pair per stage.
        self.featureSelectionApplets = []
        self.pcApplets = []
        for i in range(n_stages):
            self.featureSelectionApplets.append( self.createFeatureSelectionApplet(i) )
            self.pcApplets.append( self.createPixelClassificationApplet(i) )
        opFinalClassify = self.pcApplets[-1].topLevelOperator

        # Keep FreezePredictions in sync: when it changes on any stage, apply the new value to all stages.
        # (So if *any* stage enters 'live update' mode, they all do.)
        def sync_freeze_predictions_settings( slot, *args ):
            freeze_predictions = slot.value
            for pcApplet in self.pcApplets:
                pcApplet.topLevelOperator.FreezePredictions.setValue( freeze_predictions )
        for pcApplet in self.pcApplets:
            pcApplet.topLevelOperator.FreezePredictions.notifyDirty( sync_freeze_predictions_settings )

        self.dataExportApplet = PixelClassificationDataExportApplet(self, "Prediction Export")
        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.PmapColors.connect( opFinalClassify.PmapColors )
        opDataExport.LabelNames.connect( opFinalClassify.LabelNames )
        opDataExport.WorkingDirectory.connect( opDataSelection.WorkingDirectory )

        # Export selections: the per-stage items for each stage, LAST stage first.
        # connectLane() connects opDataExport.Inputs in exactly this order.
        self.EXPORT_NAMES = []
        for stage_index in reversed(range(n_stages)):
            self.EXPORT_NAMES += ["{} Stage {}".format( name, stage_index+1 ) for name in self.EXPORT_NAMES_PER_STAGE]

        # And finally, one last item for *all* probabilities from all stages.
        self.EXPORT_NAMES += ["Probabilities All Stages"]
        opDataExport.SelectionNames.setValue( self.EXPORT_NAMES )

        # Expose for shell: data selection, then (features, classification) for each stage, then export.
        self._applets.append(self.dataSelectionApplet)
        self._applets.extend(itertools.chain.from_iterable(zip(self.featureSelectionApplets, self.pcApplets)))
        self._applets.append(self.dataExportApplet)

        self.dataExportApplet.prepare_for_entire_export = self.prepare_for_entire_export
        self.dataExportApplet.post_process_entire_export = self.post_process_entire_export

        self.batchProcessingApplet = BatchProcessingApplet(self,
                                                           "Batch Processing",
                                                           self.dataSelectionApplet,
                                                           self.dataExportApplet)

        self._applets.append(self.batchProcessingApplet)
        if unused_args:
            # We parse the export setting args first.  All remaining args are considered input files by the input applet.
            self._batch_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args( unused_args )
            self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args( unused_args )
        else:
            self._batch_input_args = None
            self._batch_export_args = None

        if unused_args:
            logger.warning("Unused command-line args: {}".format( unused_args ))

    def createDataSelectionApplet(self):
        """
        Build the DataSelectionApplet used by this workflow.

        Subclasses may override this to construct the applet with different
        parameters.  The default applet forces a channel-last axis order:
        'yxc', 'xyc', or any ordered selection of 3 or 4 letters from 'tzyx'
        followed by 'c'.
        """
        instructions = "Select your input data using the 'Raw Data' tab shown on the right"

        # The two 2D orders come first, then every 3- and 4-letter arrangement of 'tzyx', each with 'c' appended.
        channel_last_orders = ['yxc', 'xyc'] + [
            ''.join(axes) + 'c'
            for length in (3, 4)
            for axes in itertools.permutations('tzyx', length)
        ]

        return DataSelectionApplet(
            self,
            "Input Data",
            "Input Data",
            supportIlastik05Import=False,
            instructionText=instructions,
            forceAxisOrder=channel_last_orders,
        )

    def createFeatureSelectionApplet(self, index):
        """
        Create the FeatureSelectionApplet for stage ``index`` (0-based).

        Subclasses may override this to return their own applet type, as long as
        it exposes the same interface as the regular FeatureSelectionApplet.

        Stage 0 serializes to the 'FeatureSelections' group (same as the plain pixel
        classification workflow, so "Import Project" keeps working); later stages
        use 'FeatureSelectionsNN' with a two-digit index.
        """
        if index > 0:
            group_name = "FeatureSelections{index:02d}".format(index=index)
        else:
            group_name = 'FeatureSelections'

        feature_applet = FeatureSelectionApplet(self, "Feature Selection", group_name)
        # Make the operator name unique per stage.
        feature_applet.topLevelOperator.name += '{}'.format(index)
        return feature_applet

    def createPixelClassificationApplet(self, index=0):
        """
        Create the PixelClassificationApplet for stage ``index`` (0-based).

        Subclasses may override this to return their own applet type, as long as
        it exposes the same interface as the regular PixelClassificationApplet.

        Stage 0 serializes to the 'PixelClassification' group (same as the plain
        pixel classification workflow, so "Import Project" keeps working); later
        stages use 'PixelClassificationNN' with a two-digit index.
        """
        if index > 0:
            group_name = "PixelClassification{index:02d}".format(index=index)
        else:
            group_name = 'PixelClassification'

        classification_applet = PixelClassificationApplet( self, group_name )
        # Make the operator name unique per stage.
        classification_applet.topLevelOperator.name += '{}'.format(index)
        return classification_applet


    def prepareForNewLane(self, laneIndex):
        """
        Overridden from Workflow base class.
        Called immediately before a new lane is added to the workflow.

        When the new lane is added, dirty notifications propagate throughout the
        entire graph, so the classifiers get marked 'dirty' even though they are
        still usable.  To avoid needless retraining, store every stage's current
        classifier here; handleNewLanesAdded() restores them afterwards.

        Classifiers are stored all-or-nothing.  If any stage lacks a ready, clean
        classifier, nothing is stored.  handleNewLanesAdded() pairs the stored
        classifiers with stages by position, so a partial list would force one
        stage's classifier into a different stage.
        """
        self.stored_classifers = []
        for pcApplet in self.pcApplets:
            classifier_cache = pcApplet.topLevelOperator.classifier_cache
            if not classifier_cache.Output.ready() or classifier_cache._dirty:
                # At least one stage has no usable classifier: store none of them.
                self.stored_classifers = []
                break
            self.stored_classifers.append(classifier_cache.Output.value)
        
    def handleNewLanesAdded(self):
        """
        Overridden from Workflow base class.
        Called immediately after a new lane is added to the workflow and initialized.

        Forces the classifiers saved by prepareForNewLane() (if any) back into
        each stage's classifier cache, pairing them with stages in order, and
        then drops the saved references.
        """
        saved_classifiers = self.stored_classifers
        if not saved_classifiers:
            return

        for stage_applet, saved in zip(self.pcApplets, saved_classifiers):
            stage_applet.topLevelOperator.classifier_cache.forceValue(saved)

        # Release references
        self.stored_classifers = []

    def connectLane(self, laneIndex):
        """
        Wire up every operator for lane ``laneIndex``.

        - Stage 0 features and classification read the raw input image directly.
        - Every later stage reads the raw image stacked along 'c' with the previous
          stage's ``PredictionProbabilitiesAutocontext`` output.
        - The export operator's ``Inputs`` are connected in the order of
          ``self.EXPORT_NAMES``: the per-stage items for each stage (last stage
          first), then one stacked "all stages" probabilities slot.
        """
        # Get a handle to each operator
        opData = self.dataSelectionApplet.topLevelOperator.getLane(laneIndex)
        opFirstFeatures = self.featureSelectionApplets[0].topLevelOperator.getLane(laneIndex)
        opFirstClassify = self.pcApplets[0].topLevelOperator.getLane(laneIndex)
        opFinalClassify = self.pcApplets[-1].topLevelOperator.getLane(laneIndex)  # NOTE(review): not used below
        opDataExport = self.dataExportApplet.topLevelOperator.getLane(laneIndex)
        
        def checkConstraints(*_):
            # Currently a no-op: the dtype restriction below is disabled.
            # if (opData.Image.meta.dtype in [np.uint8, np.uint16]) == False:
            #    msg = "The Autocontext Workflow only supports 8-bit images (UINT8 pixel type)\n"\
            #          "or 16-bit images (UINT16 pixel type)\n"\
            #          "Your image has a pixel type of {}.  Please convert your data to UINT8 and try again."\
            #          .format( str(np.dtype(opData.Image.meta.dtype)) )
            #    raise DatasetConstraintError( "Autocontext Workflow", msg, unfixable=True )
            pass

        opData.Image.notifyReady( checkConstraints )
        
        # Input Image -> Feature Op
        #         and -> Classification Op (for display)
        opFirstFeatures.InputImage.connect( opData.Image )
        opFirstClassify.InputImages.connect( opData.Image )

        # Feature Images -> Classification Op (for training, prediction)
        opFirstClassify.FeatureImages.connect( opFirstFeatures.OutputImage )
        opFirstClassify.CachedFeatureImages.connect( opFirstFeatures.CachedOutputImage )

        # Pair each stage k (k >= 1) with the classifier of stage k-1.
        upstreamPcApplets = self.pcApplets[0:-1]
        downstreamFeatureApplets = self.featureSelectionApplets[1:]
        downstreamPcApplets = self.pcApplets[1:]

        for ( upstreamPcApplet,
              downstreamFeaturesApplet,
              downstreamPcApplet ) in zip( upstreamPcApplets, 
                                           downstreamFeatureApplets, 
                                           downstreamPcApplets ):
            
            opUpstreamClassify = upstreamPcApplet.topLevelOperator.getLane(laneIndex)
            opDownstreamFeatures = downstreamFeaturesApplet.topLevelOperator.getLane(laneIndex)
            opDownstreamClassify = downstreamPcApplet.topLevelOperator.getLane(laneIndex)

            # Connect label inputs (all are connected together).
            #opDownstreamClassify.LabelInputs.connect( opUpstreamClassify.LabelInputs )
            
            # Connect data path.
            # Raw image and upstream probabilities are stacked into one array, so their dtypes must agree.
            # (NOTE: an assert is skipped when Python runs with -O.)
            assert opData.Image.meta.dtype == opUpstreamClassify.PredictionProbabilitiesAutocontext.meta.dtype, (
                "Probability dtype needs to match up with input image dtype, got: "
                f"input: {opData.Image.meta.dtype} "
                f"probabilities: {opUpstreamClassify.PredictionProbabilitiesAutocontext.meta.dtype}"
            )
            # Downstream input = [raw image channels, upstream probability channels] along 'c'.
            opStacker = OpMultiArrayStacker(parent=self)
            opStacker.Images.resize(2)
            opStacker.Images[0].connect( opData.Image )
            opStacker.Images[1].connect( opUpstreamClassify.PredictionProbabilitiesAutocontext )
            opStacker.AxisFlag.setValue('c')
            
            opDownstreamFeatures.InputImage.connect( opStacker.Output )
            opDownstreamClassify.InputImages.connect( opStacker.Output )
            opDownstreamClassify.FeatureImages.connect( opDownstreamFeatures.OutputImage )
            opDownstreamClassify.CachedFeatureImages.connect( opDownstreamFeatures.CachedOutputImage )

        # Data Export connections
        opDataExport.RawData.connect( opData.ImageGroup[self.DATA_ROLE_RAW] )
        opDataExport.RawDatasetInfo.connect( opData.DatasetGroup[self.DATA_ROLE_RAW] )
        opDataExport.ConstraintDataset.connect( opData.ImageGroup[self.DATA_ROLE_RAW] )

        # One export slot per entry of EXPORT_NAMES.  Stages are visited last-to-first to match
        # that list, and within a stage the slot order follows EXPORT_NAMES_PER_STAGE.
        opDataExport.Inputs.resize( len(self.EXPORT_NAMES) )
        for reverse_stage_index, (stage_index, pcApplet) in enumerate(reversed(list(enumerate(self.pcApplets)))):
            opPc = pcApplet.topLevelOperator.getLane(laneIndex)
            num_items_per_stage = len(self.EXPORT_NAMES_PER_STAGE)
            opDataExport.Inputs[num_items_per_stage*reverse_stage_index+0].connect( opPc.HeadlessPredictionProbabilities )
            opDataExport.Inputs[num_items_per_stage*reverse_stage_index+1].connect( opPc.SimpleSegmentation )
            opDataExport.Inputs[num_items_per_stage*reverse_stage_index+2].connect( opPc.HeadlessUncertaintyEstimate )
            opDataExport.Inputs[num_items_per_stage*reverse_stage_index+3].connect( opPc.FeatureImages )
            opDataExport.Inputs[num_items_per_stage*reverse_stage_index+4].connect( opPc.LabelImages )
            opDataExport.Inputs[num_items_per_stage*reverse_stage_index+5].connect( opPc.InputImages ) # Input must come last due to an assumption in PixelClassificationDataExportGui

        # One last export slot for all probabilities, all stages (stage 0 channels first)
        opAllStageStacker = OpMultiArrayStacker(parent=self)
        opAllStageStacker.Images.resize( len(self.pcApplets) )
        for stage_index, pcApplet in enumerate(self.pcApplets):
            opPc = pcApplet.topLevelOperator.getLane(laneIndex)
            opAllStageStacker.Images[stage_index].connect(opPc.HeadlessPredictionProbabilities)
            opAllStageStacker.AxisFlag.setValue('c')  # loop-invariant; setting it repeatedly is harmless

        # The ideal_blockshape metadata field will be bogus, so just eliminate it
        # (Otherwise, the channels could be split up in an unfortunate way.)
        opMetadataOverride = OpMetadataInjector(parent=self)
        opMetadataOverride.Input.connect( opAllStageStacker.Output )
        opMetadataOverride.Metadata.setValue( {'ideal_blockshape' : None } )
        
        opDataExport.Inputs[-1].connect( opMetadataOverride.Output )
        
        # Sanity check: every export input must have been connected above.
        for slot in opDataExport.Inputs:
            assert slot.upstream_slot is not None

    def handleAppletStateUpdateRequested(self):
        """
        Overridden from Workflow base class
        Called when an applet has fired the :py:attr:`Applet.appletStateUpdateRequested`

        Recomputes readiness flags for every stage, enables/disables each applet
        accordingly, and tells the shell whether project changes are allowed
        (they are blocked while any of the data selection, feature selection,
        export, or batch processing applets is busy).
        """
        opDataSelection = self.dataSelectionApplet.topLevelOperator

        # First, determine various 'ready' states for each pixel classification stage (features+prediction)
        StageFlags = collections.namedtuple("StageFlags", 'input_ready features_ready invalid_classifier predictions_ready live_update_active')
        stage_flags = []
        for stage_index, (featureSelectionApplet, pcApplet) in enumerate(zip(self.featureSelectionApplets, self.pcApplets)):
            if stage_index == 0:
                # If no data, nothing else is ready.
                input_ready = len(opDataSelection.ImageGroup) > 0 and not self.dataSelectionApplet.busy
            else:
                # Later stages consume the previous stage's predictions.
                input_ready = stage_flags[stage_index-1].predictions_ready

            opFeatureSelection = featureSelectionApplet.topLevelOperator
            featureOutput = opFeatureSelection.OutputImage
            features_ready = input_ready and \
                             len(featureOutput) > 0 and \
                             featureOutput[0].ready() and \
                             (TinyVector(featureOutput[0].meta.shape) > 0).all()

            opPixelClassification = pcApplet.topLevelOperator
            # The classifier cache is fixed, but it holds no classifier.
            invalid_classifier = opPixelClassification.classifier_cache.fixAtCurrent.value and \
                                 opPixelClassification.classifier_cache.Output.ready() and \
                                 opPixelClassification.classifier_cache.Output.value is None

            predictions_ready = features_ready and \
                                not invalid_classifier and \
                                len(opPixelClassification.HeadlessPredictionProbabilities) > 0 and \
                                opPixelClassification.HeadlessPredictionProbabilities[0].ready() and \
                                (TinyVector(opPixelClassification.HeadlessPredictionProbabilities[0].meta.shape) > 0).all()

            live_update_active = not opPixelClassification.FreezePredictions.value

            stage_flags.append( StageFlags(input_ready, features_ready, invalid_classifier, predictions_ready, live_update_active) )

        # Export and batch processing both depend on the FINAL stage's predictions.
        final_predictions_ready = stage_flags[-1].predictions_ready

        # Problems can occur if the features or input data are changed during live update mode.
        # Don't let the user do that.
        any_live_update = any(flags.live_update_active for flags in stage_flags)

        # The user isn't allowed to touch anything while batch processing is running.
        batch_processing_busy = self.batchProcessingApplet.busy

        self._shell.setAppletEnabled(self.dataSelectionApplet, not any_live_update and not batch_processing_busy)

        for stage_index, (featureSelectionApplet, pcApplet) in enumerate(zip(self.featureSelectionApplets, self.pcApplets)):
            this_stage_live_update = stage_flags[stage_index].live_update_active
            downstream_live_update = any(flags.live_update_active for flags in stage_flags[stage_index+1:])

            self._shell.setAppletEnabled(featureSelectionApplet, stage_flags[stage_index].input_ready
                                                                 and not this_stage_live_update
                                                                 and not downstream_live_update
                                                                 and not batch_processing_busy)

            # Downstream live update need not be checked here: live update modes are synced
            # across stages, and labels can't be added in live update.
            self._shell.setAppletEnabled(pcApplet, stage_flags[stage_index].features_ready
                                                   and not batch_processing_busy)

        self._shell.setAppletEnabled(self.dataExportApplet, final_predictions_ready and not batch_processing_busy)
        self._shell.setAppletEnabled(self.batchProcessingApplet, final_predictions_ready and not batch_processing_busy)

        # Lastly, check for certain "busy" conditions, during which we
        #  should prevent the shell from closing the project.
        busy = False
        busy |= self.dataSelectionApplet.busy
        busy |= any(applet.busy for applet in self.featureSelectionApplets)
        busy |= self.dataExportApplet.busy
        busy |= self.batchProcessingApplet.busy
        self._shell.enableProjectChanges( not busy )

    def onProjectLoaded(self, projectManager):
        """
        Overridden from Workflow base class.  Called by the Project Manager.

        Applies the session's command-line options once the project is open:
        optionally retrains all classifiers (``--retrain``), configures the export
        applet from any parsed export args, and -- only in headless mode, when both
        batch input and export args were supplied -- runs batch processing.
        (This workflow's headless mode supports only batch mode for now.)
        """
        if self._headless:
            # Without a GUI, surface the training operator's log messages.
            logging.getLogger("lazyflow.operators.classifierOperators").setLevel(logging.DEBUG)

        if self.retrain:
            self._force_retrain_classifiers(projectManager)

        # Configure the data export operator.
        if self._batch_export_args:
            self.dataExportApplet.configure_operator_with_parsed_args( self._batch_export_args )

        if self._batch_input_args:
            some_classifier_dirty = any(
                applet.topLevelOperator.classifier_cache._dirty for applet in self.pcApplets
            )
            if some_classifier_dirty:
                logger.warning("At least one of your classifiers is not yet trained.  "
                               "A new classifier will be trained for this run.")

        run_batch = self._headless and self._batch_input_args and self._batch_export_args
        if not run_batch:
            return

        logger.info("Beginning Batch Processing")
        self.batchProcessingApplet.run_export_from_parsed_args(self._batch_input_args)
        logger.info("Completed Batch Processing")

    def prepare_for_entire_export(self):
        """
        Adjust workflow settings before a full export; undone by post_process_entire_export().

        - Feature caches are bypassed while exporting.  The exception is an
          "all stages" selection: there the earlier stages' caches help avoid
          unnecessary work, so only the last stage's cache is bypassed.
        - Every stage's classifier is unfrozen, so the export is based on
          up-to-date labels.  The previous FreezePredictions values are saved in
          ``self.freeze_statuses`` (one entry per stage, in stage order).
        """
        opDataExport = self.dataExportApplet.topLevelOperator
        export_selection_index = opDataExport.InputSelection.value
        export_selection_name = opDataExport.SelectionNames.value[ export_selection_index ]
        if 'all stages' in export_selection_name.lower():
            # Exporting from more than one stage at a time: keep the intermediate caches.
            self.featureSelectionApplets[-1].topLevelOperator.BypassCache.setValue(True)
        else:
            for featureSelectionApplet in self.featureSelectionApplets:
                featureSelectionApplet.topLevelOperator.BypassCache.setValue(True)

        # Record ALL freeze statuses BEFORE unfreezing anything.  __init__ registers a
        # notifyDirty callback on every stage's FreezePredictions slot that copies a new
        # value to all stages.  Unfreezing stage 0 first would therefore also unfreeze the
        # later stages before their original statuses were recorded.
        self.freeze_statuses = [pcApplet.topLevelOperator.FreezePredictions.value for pcApplet in self.pcApplets]
        for pcApplet in self.pcApplets:
            pcApplet.topLevelOperator.FreezePredictions.setValue(False)

    def post_process_entire_export(self):
        """
        Undo the temporary export settings: re-enable feature caches for every stage
        and restore each stage's FreezePredictions value saved in ``self.freeze_statuses``.
        """
        # Caches were bypassed during the export; turn them back on everywhere.
        for feature_applet in self.featureSelectionApplets:
            feature_applet.topLevelOperator.BypassCache.setValue(False)

        # Re-freeze classifier caches (if necessary), pairing statuses with stages in order.
        for stage_applet, was_frozen in zip(self.pcApplets, self.freeze_statuses):
            stage_applet.topLevelOperator.FreezePredictions.setValue(was_frozen)

    def _force_retrain_classifiers(self, projectManager):
        """
        Retrain the classifiers from the labels stored in the project, then save the project.

        Useful if the stored labels were changed outside ilastik.
        """
        stage_operators = [applet.topLevelOperator for applet in self.pcApplets]

        # Mark the FIRST stage's classifier factory dirty so that stage is forced to retrain.
        stage_operators[0].opTrain.ClassifierFactory.setDirty()

        # Unfreeze every stage's classifier cache.
        for op in stage_operators:
            op.FreezePredictions.setValue(False)

        # Requesting the LAST stage's classifier forces training.
        _ = stage_operators[-1].Classifier.value

        # Store the new classifiers in the project file.
        projectManager.saveProject(force_all_save=False)

    def menus(self):
        """
        Overridden from Workflow base class

        Returns a single "Autocontext Utilities" menu containing a
        "Distribute Labels..." action.
        """
        # Late import: PyQt is only needed when a GUI asks for menus.
        from PyQt5.QtWidgets import QMenu

        menu = QMenu("Autocontext Utilities")
        menu.addAction("Distribute Labels...").triggered.connect( self.distribute_labels_from_current_stage )

        # Keep a reference on self; otherwise the menu is garbage-collected and deleted.
        self._autocontext_menu = menu
        return [menu]
        
    def distribute_labels_from_current_stage(self):
        """
        Distribute labels from the currently viewed stage across other stages.

        A dialog asks which stages receive the labels and in which mode:

        - copy: every destination stage receives all of the current stage's labels.
        - partition: each labeled pixel goes to exactly one destination stage, picked
          at random with ``np.random.choice``.  If the current stage is itself a
          destination, its labels are cleared first and replaced by its share.

        Before any pixels are copied, each destination stage receives the current
        stage's label names, label colors and probability-map colors for the first
        ``num_current_stage_classes`` classes.

        Nothing happens if the active page is not a Training page, if the user
        cancels the dialog, or if no destination stage is selected.
        """
        # Late import.
        # (Don't import PyQt in headless mode.)
        from PyQt5.QtWidgets import QMessageBox
        current_applet = self._applets[self.shell.currentAppletIndex]
        if current_applet not in self.pcApplets:
            QMessageBox.critical(self.shell, "Wrong page selected", "The currently active page isn't a Training page.")
            return

        current_stage_index = self.pcApplets.index(current_applet)
        destination_stage_indexes, partition = self.get_label_distribution_settings( current_stage_index,
                                                                                     num_stages=len(self.pcApplets))
        if not destination_stage_indexes:
            # None: the user cancelled.  []: no stage was checked, so there is nothing to do
            # (and in partition mode np.random.choice() would fail on an empty choice list).
            return

        opCurrentPixelClassification = current_applet.topLevelOperator
        current_stage_label_names = opCurrentPixelClassification.LabelNames.value
        current_stage_label_colors = opCurrentPixelClassification.LabelColors.value
        current_stage_pmap_colors = opCurrentPixelClassification.PmapColors.value
        num_current_stage_classes = len(current_stage_label_names)

        def with_current_stage_prefix(existing_values, current_values):
            # Replace the first num_current_stage_classes entries with the current stage's.
            # The list grows if the destination has fewer classes.
            merged = list(existing_values)
            merged[:num_current_stage_classes] = current_values[:num_current_stage_classes]
            return merged

        # Before we get started, make sure the destination stages have the necessary label classes
        for stage_index in destination_stage_indexes:
            opPc = self.pcApplets[stage_index].topLevelOperator
            opPc.LabelColors.setValue( with_current_stage_prefix(opPc.LabelColors.value, current_stage_label_colors) )
            opPc.PmapColors.setValue( with_current_stage_prefix(opPc.PmapColors.value, current_stage_pmap_colors) )
            opPc.LabelNames.setValue( with_current_stage_prefix(opPc.LabelNames.value, current_stage_label_names) )

        # For each lane, copy over the labels from the source stage to the destination stages
        for lane_index in range(len(opCurrentPixelClassification.InputImages)):
            opPcLane = opCurrentPixelClassification.getLane(lane_index)

            # Gather all the labels for this lane
            blockwise_labels = {}
            for block_slicing in opPcLane.NonzeroLabelBlocks.value:
                # Convert from slicing to roi-tuple so we can hash it in a dict key
                block_roi = sliceToRoi( block_slicing, opPcLane.InputImages.meta.shape )
                block_roi = tuple(map(tuple, block_roi))
                blockwise_labels[block_roi] = opPcLane.LabelImages[block_slicing].wait()

            if partition and current_stage_index in destination_stage_indexes:
                # Clear all labels in the current lane, since we'll be overwriting it with a subset of labels
                # FIXME: We could implement a fast function for this in OpCompressedUserLabelArray...
                for label_value in range(1, num_current_stage_classes+1):
                    opPcLane.opLabelPipeline.opLabelArray.clearLabel(label_value)

            # Now redistribute those labels across the destination stages
            for block_roi, block_labels in blockwise_labels.items():
                nonzero_coords = block_labels.nonzero()

                if partition:
                    # Assign every labeled pixel in this block to one destination stage.
                    destination_stage_map = np.random.choice(destination_stage_indexes, (len(nonzero_coords[0]),))

                for stage_index in destination_stage_indexes:
                    if not partition:
                        this_stage_block_labels = block_labels
                    else:
                        # Divide into disjoint partitions: keep only the pixels assigned to this stage.
                        in_this_stage = (destination_stage_map == stage_index)
                        this_stage_coords = tuple(axis_coords[in_this_stage] for axis_coords in nonzero_coords)
                        this_stage_block_labels = np.zeros_like(block_labels)
                        this_stage_block_labels[this_stage_coords] = block_labels[this_stage_coords]

                    # Get the current lane's view of this stage's OpPixelClassification
                    opPc = self.pcApplets[stage_index].topLevelOperator.getLane(lane_index)

                    # Inject
                    opPc.LabelInputs[roiToSlice(*block_roi)] = this_stage_block_labels
    
    @staticmethod
    def get_label_distribution_settings(source_stage_index, num_stages):
        """
        Ask the user where, and how, labels from one stage should be distributed.

        Shows a modal dialog with one checkbox per stage (the source stage is
        pre-checked) and a Copy / Partition choice (Partition is the default).

        :param source_stage_index: 0-based index of the stage the labels come from.
        :param num_stages: number of stages offered as destinations.
        :returns: ``(destination_stage_indexes, partition)``.  The first item is the
                  list of checked 0-based stage indexes; ``partition`` is True when
                  "Partition" was chosen.  Returns ``(None, None)`` if the dialog is
                  rejected.
        """
        # Late import.
        # (Don't import PyQt in headless mode.)
        from PyQt5.QtWidgets import QDialog, QVBoxLayout

        class LabelDistributionOptionsDlg( QDialog ):
            """
            Small dialog for choosing the destination stages and the distribution
            mode for labels coming from a single source stage.
            """
            def __init__(self, source_stage_index, num_stages, *args, **kwargs):
                super(LabelDistributionOptionsDlg, self).__init__(*args, **kwargs)

                from PyQt5.QtCore import Qt
                from PyQt5.QtWidgets import QGroupBox, QCheckBox, QRadioButton, QDialogButtonBox

                self.setWindowTitle("Distributing from Stage {}".format(source_stage_index+1))

                # One checkbox per stage (displayed 1-based).
                self.stage_checkboxes = [QCheckBox("Stage {}".format(number)) for number in range(1, num_stages+1)]

                # By default, send labels back into the source stage, at least.
                self.stage_checkboxes[source_stage_index].setChecked(True)

                stages_layout = QVBoxLayout()
                for box in self.stage_checkboxes:
                    stages_layout.addWidget(box)

                stages_group = QGroupBox("Send labels from Stage {} to:".format(source_stage_index+1), self)
                stages_group.setLayout(stages_layout)

                self.copy_button = QRadioButton("Copy", self)
                self.partition_button = QRadioButton("Partition", self)
                self.partition_button.setChecked(True)
                mode_layout = QVBoxLayout()
                for radio in (self.copy_button, self.partition_button):
                    mode_layout.addWidget(radio)

                mode_group = QGroupBox("Distribution Mode", self)
                mode_group.setLayout(mode_layout)

                buttons = QDialogButtonBox(Qt.Horizontal, parent=self)
                buttons.setStandardButtons(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
                buttons.accepted.connect(self.accept)
                buttons.rejected.connect(self.reject)

                outer_layout = QVBoxLayout()
                for widget in (stages_group, mode_group, buttons):
                    outer_layout.addWidget(widget)
                self.setLayout(outer_layout)

            def distribution_mode(self):
                """Return "copy" or "partition", depending on the selected radio button."""
                if self.copy_button.isChecked():
                    return "copy"
                if self.partition_button.isChecked():
                    return "partition"
                assert False, "Shouldn't get here."

            def destination_stages(self):
                """
                Return the list of stage_indexes (0-based) that the user checked.
                """
                return [index for index, box in enumerate(self.stage_checkboxes) if box.isChecked()]

        dialog = LabelDistributionOptionsDlg(source_stage_index, num_stages)
        if dialog.exec_() == QDialog.Rejected:
            return (None, None)

        return (dialog.destination_stages(), dialog.distribution_mode() == "partition")