예제 #1
0
    def __init__(self, shell, headless, workflow_cmdline_args, project_creation_args, *args, **kwargs):
        """
        Assemble the data-conversion workflow and pre-parse command-line arguments.

        Three applets are created (input data, data export, batch processing)
        and registered in ``self._applets``.  Parsed command-line arguments are
        only stored here, as ``self._data_export_args`` and
        ``self._batch_input_args``; both are ``None`` when no arguments were given.
        """
        # One Graph instance is handed to the base class, to be shared by all operators.
        shared_graph = Graph()
        super(DataConversionWorkflow, self).__init__(
            shell,
            headless,
            workflow_cmdline_args,
            project_creation_args,
            graph=shared_graph,
            *args,
            **kwargs
        )
        self._applets = []

        # --- Input applet: this workflow has exactly one dataset role ---
        input_roles = ["Input Data"]
        self.dataSelectionApplet = DataSelectionApplet(
            self,
            "Input Data",
            "Input Data",
            supportIlastik05Import=True,
            forceAxisOrder=None,
        )
        selection_op = self.dataSelectionApplet.topLevelOperator
        selection_op.DatasetRoles.setValue(input_roles)

        # --- Export applet: its working directory follows the input applet's ---
        self.dataExportApplet = DataExportApplet(self, "Data Export")
        export_op = self.dataExportApplet.topLevelOperator
        export_op.WorkingDirectory.connect(selection_op.WorkingDirectory)
        export_op.SelectionNames.setValue(["Input"])

        # No special pre/post processing is needed around exports in this
        # workflow, so none of the export applet's callbacks are replaced here.

        # --- Batch-processing applet: built from the two applets above ---
        self.batchProcessingApplet = BatchProcessingApplet(
            self, "Batch Processing", self.dataSelectionApplet, self.dataExportApplet
        )

        # The shell uses this list of applets.
        self._applets.extend(
            [self.dataSelectionApplet, self.dataExportApplet, self.batchProcessingApplet]
        )

        # Command-line arguments are only parsed here, not applied.
        # Start from the "nothing given" defaults and overwrite them if needed.
        self._data_export_args = None
        self._batch_input_args = None
        leftover_args = None
        if workflow_cmdline_args:
            self._data_export_args, leftover_args = self.dataExportApplet.parse_known_cmdline_args(
                workflow_cmdline_args
            )
            self._batch_input_args, leftover_args = self.dataSelectionApplet.parse_known_cmdline_args(
                leftover_args, input_roles
            )

        if leftover_args:
            logger.warning("Unused command-line args: {}".format(leftover_args))
    def __init__(self, shell, headless, workflow_cmdline_args, project_creation_args, *args, **kwargs):
        """
        Build the applets of the NN classification workflow and pre-parse the
        command-line arguments.

        Creates the data-selection, NN-classification, data-export and
        batch-processing applets, registers them in ``self._applets`` and stores
        the parsed batch arguments in ``self._batch_export_args`` and
        ``self._batch_input_args`` (both ``None`` when there is nothing to parse).
        The raw ``workflow_cmdline_args`` are kept in ``self._workflow_cmdline_args``.
        """
        # Create a graph to be shared by all operators
        graph = Graph()
        super(NNClassificationWorkflow, self).__init__(
            shell, headless, workflow_cmdline_args, project_creation_args, graph=graph, *args, **kwargs
        )
        self._applets = []
        self._workflow_cmdline_args = workflow_cmdline_args

        # Parse workflow-specific command-line args.
        # The parser currently declares no workflow-specific options, so every
        # argument it is given ends up in the "unknown" list.
        parser = argparse.ArgumentParser()

        # Parse the creation args: these were saved to the project file when this
        # project was first created.  Nothing is read from the result yet.
        _parsed_creation_args, _ = parser.parse_known_args(project_creation_args)

        # Parse the cmdline args for the current session.
        _parsed_args, unused_args = parser.parse_known_args(workflow_cmdline_args)

        # Applets for training (interactive) workflow
        self.dataSelectionApplet = self.createDataSelectionApplet()
        opDataSelection = self.dataSelectionApplet.topLevelOperator

        # Dataset roles come from the class-level ROLE_NAMES constant.
        opDataSelection.DatasetRoles.setValue(NNClassificationWorkflow.ROLE_NAMES)

        self.nnClassificationApplet = NNClassApplet(self, "NNClassApplet")

        self.dataExportApplet = NNClassificationDataExportApplet(self, 'Data Export')

        # Configure global DataExport settings
        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.WorkingDirectory.connect(opDataSelection.WorkingDirectory)
        opDataExport.SelectionNames.setValue(self.EXPORT_NAMES)

        self.batchProcessingApplet = BatchProcessingApplet(
            self, "Batch Processing", self.dataSelectionApplet, self.dataExportApplet
        )

        # Expose for shell
        self._applets.append(self.dataSelectionApplet)
        self._applets.append(self.nnClassificationApplet)
        self._applets.append(self.dataExportApplet)
        self._applets.append(self.batchProcessingApplet)

        if unused_args:
            # We parse the export setting args first.  All remaining args are
            # considered input files by the input applet.
            self._batch_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args(unused_args)
            self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args(unused_args)
        else:
            self._batch_input_args = None
            self._batch_export_args = None

        if unused_args:
            # logger.warning() is used because Logger.warn() is a deprecated alias.
            logger.warning("Unused command-line args: {}".format(unused_args))
예제 #3
0
    def __init__(self, shell, headless, workflow_cmdline_args, project_creation_args, *args, **kwargs):
        """
        Set up the three applets of the data-conversion workflow (input data,
        data export, batch processing) and keep any parsed command-line
        arguments on the instance.

        ``self._data_export_args`` and ``self._batch_input_args`` hold the parsed
        arguments, or ``None`` when ``workflow_cmdline_args`` is empty.
        """
        # A single Graph is created here and passed to the base class, so that
        # it can be shared by all operators.
        workflow_graph = Graph()
        super(DataConversionWorkflow, self).__init__(
            shell, headless, workflow_cmdline_args, project_creation_args,
            graph=workflow_graph, *args, **kwargs
        )
        self._applets = []

        # Input applet, configured with this workflow's single dataset role.
        self.dataSelectionApplet = DataSelectionApplet(
            self, "Input Data", "Input Data", supportIlastik05Import=True, forceAxisOrder=None
        )
        dataset_roles = ["Input Data"]
        op_select = self.dataSelectionApplet.topLevelOperator
        op_select.DatasetRoles.setValue(dataset_roles)

        # Export applet; its working directory is connected to the input applet's.
        self.dataExportApplet = DataExportApplet(self, "Data Export")
        op_export = self.dataExportApplet.topLevelOperator
        op_export.WorkingDirectory.connect(op_select.WorkingDirectory)
        op_export.SelectionNames.setValue(["Input"])

        # No export pre/post-processing callbacks are installed: this workflow
        # does not need any.

        # Batch-processing applet, built on top of the two applets above.
        self.batchProcessingApplet = BatchProcessingApplet(
            self, "Batch Processing", self.dataSelectionApplet, self.dataExportApplet
        )

        # Publish the applets, in display order, for the shell to use.
        for applet in (self.dataSelectionApplet, self.dataExportApplet, self.batchProcessingApplet):
            self._applets.append(applet)

        # Without command-line arguments there is nothing to parse.
        if not workflow_cmdline_args:
            self._batch_input_args = None
            self._data_export_args = None
            return

        # Export settings are parsed first; whatever they leave over is offered
        # to the input applet.  The arguments are only parsed here, not applied.
        self._data_export_args, remaining = self.dataExportApplet.parse_known_cmdline_args(
            workflow_cmdline_args
        )
        self._batch_input_args, remaining = self.dataSelectionApplet.parse_known_cmdline_args(
            remaining, dataset_roles
        )
        if remaining:
            logger.warning("Unused command-line args: {}".format(remaining))
예제 #4
0
    def __init__(self, shell, headless, workflow_cmdline_args, project_creation_workflow, *args, **kwargs):
        """
        Build the applets of the WSDT workflow and pre-parse command-line arguments.

        Creates the data-selection, watershed, data-export and batch-processing
        applets (in that order), registers them in ``self._applets`` and stores
        the parsed command-line arguments in ``self._data_export_args`` and
        ``self._batch_input_args`` (both ``None`` when no arguments were given).
        All arguments are forwarded to the base-class constructor.
        """
        # Create a graph to be shared by all operators
        graph = Graph()

        super(WsdtWorkflow, self).__init__(
            shell, headless, workflow_cmdline_args, project_creation_workflow, graph=graph, *args, **kwargs
        )
        self._applets = []

        # -- DataSelection applet
        #
        self.dataSelectionApplet = DataSelectionApplet(self, "Input Data", "Input Data")

        # Dataset inputs
        opDataSelection = self.dataSelectionApplet.topLevelOperator
        opDataSelection.DatasetRoles.setValue(self.ROLE_NAMES)

        # -- Wsdt applet
        #
        self.wsdtApplet = WsdtApplet(self, "Watershed", "Wsdt Watershed")

        # -- DataExport applet
        #
        self.dataExportApplet = DataExportApplet(self, "Data Export")

        # Configure global DataExport settings
        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.WorkingDirectory.connect(opDataSelection.WorkingDirectory)
        opDataExport.SelectionNames.setValue(self.EXPORT_NAMES)

        # -- BatchProcessing applet
        #
        self.batchProcessingApplet = BatchProcessingApplet(
            self, "Batch Processing", self.dataSelectionApplet, self.dataExportApplet
        )

        # -- Expose applets to shell
        self._applets.append(self.dataSelectionApplet)
        self._applets.append(self.wsdtApplet)
        self._applets.append(self.dataExportApplet)
        self._applets.append(self.batchProcessingApplet)

        # -- Parse command-line arguments
        #    (They are only parsed and stored here, not applied.)
        if workflow_cmdline_args:
            self._data_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args(
                workflow_cmdline_args
            )
            # The input arguments are parsed against the dataset roles.  These
            # must be self.ROLE_NAMES (the same list given to DatasetRoles above):
            # this constructor has no local `role_names`, so using that name
            # raised a NameError whenever command-line args were supplied.
            self._batch_input_args, unused_args = self.dataSelectionApplet.parse_known_cmdline_args(
                unused_args, self.ROLE_NAMES
            )
        else:
            unused_args = None
            self._batch_input_args = None
            self._data_export_args = None

        if unused_args:
            logger.warning("Unused command-line args: {}".format(unused_args))
예제 #5
0
class WsdtWorkflow(Workflow):
    """
    A bare-bones workflow for using the WSDT applet.

    The workflow chains four applets: data selection, watershed (WSDT),
    data export and batch processing.
    """
    workflowName = "Watershed Over Distance Transform"
    workflowDescription = "A bare-bones workflow for using the WSDT applet"
    defaultAppletIndex = 0  # show DataSelection by default

    # Indexes into ROLE_NAMES (and into each lane's ImageGroup / DatasetGroup).
    DATA_ROLE_RAW = 0
    DATA_ROLE_PROBABILITIES = 1
    ROLE_NAMES = ['Raw Data', 'Probabilities']
    EXPORT_NAMES = ['Watershed']

    @property
    def applets(self):
        """The list of applets created in ``__init__``."""
        return self._applets

    @property
    def imageNameListSlot(self):
        """The ``ImageName`` slot of the data-selection applet's top-level operator."""
        return self.dataSelectionApplet.topLevelOperator.ImageName

    def __init__(self, shell, headless, workflow_cmdline_args, project_creation_workflow, *args, **kwargs):
        """
        Build the applets and pre-parse command-line arguments.

        The parsed arguments are stored in ``self._data_export_args`` and
        ``self._batch_input_args`` (both ``None`` when no arguments were given)
        and are applied in onProjectLoaded(), below.
        All arguments are forwarded to the base-class constructor.
        """
        # Create a graph to be shared by all operators
        graph = Graph()

        super(WsdtWorkflow, self).__init__(
            shell, headless, workflow_cmdline_args, project_creation_workflow, graph=graph, *args, **kwargs
        )
        self._applets = []

        # -- DataSelection applet
        #
        self.dataSelectionApplet = DataSelectionApplet(self, "Input Data", "Input Data")

        # Dataset inputs
        opDataSelection = self.dataSelectionApplet.topLevelOperator
        opDataSelection.DatasetRoles.setValue(self.ROLE_NAMES)

        # -- Wsdt applet
        #
        self.wsdtApplet = WsdtApplet(self, "Watershed", "Wsdt Watershed")

        # -- DataExport applet
        #
        self.dataExportApplet = DataExportApplet(self, "Data Export")

        # Configure global DataExport settings
        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.WorkingDirectory.connect(opDataSelection.WorkingDirectory)
        opDataExport.SelectionNames.setValue(self.EXPORT_NAMES)

        # -- BatchProcessing applet
        #
        self.batchProcessingApplet = BatchProcessingApplet(
            self, "Batch Processing", self.dataSelectionApplet, self.dataExportApplet
        )

        # -- Expose applets to shell
        self._applets.append(self.dataSelectionApplet)
        self._applets.append(self.wsdtApplet)
        self._applets.append(self.dataExportApplet)
        self._applets.append(self.batchProcessingApplet)

        # -- Parse command-line arguments
        #    (Command-line args are applied in onProjectLoaded(), below.)
        if workflow_cmdline_args:
            self._data_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args(
                workflow_cmdline_args
            )
            # The input arguments are parsed against the dataset roles.  These
            # must be self.ROLE_NAMES (the same list given to DatasetRoles above):
            # this constructor has no local `role_names`, so using that name
            # raised a NameError whenever command-line args were supplied.
            self._batch_input_args, unused_args = self.dataSelectionApplet.parse_known_cmdline_args(
                unused_args, self.ROLE_NAMES
            )
        else:
            unused_args = None
            self._batch_input_args = None
            self._data_export_args = None

        if unused_args:
            logger.warning("Unused command-line args: {}".format(unused_args))

    def connectLane(self, laneIndex):
        """
        Override from base class.

        Connects, for the lane ``laneIndex``, the data-selection outputs to the
        watershed operator and the watershed output to the data-export operator.
        """
        opDataSelection = self.dataSelectionApplet.topLevelOperator.getLane(laneIndex)
        opWsdt = self.wsdtApplet.topLevelOperator.getLane(laneIndex)
        opDataExport = self.dataExportApplet.topLevelOperator.getLane(laneIndex)

        # watershed inputs
        opWsdt.RawData.connect(opDataSelection.ImageGroup[self.DATA_ROLE_RAW])
        opWsdt.Input.connect(opDataSelection.ImageGroup[self.DATA_ROLE_PROBABILITIES])

        # DataExport inputs
        opDataExport.RawData.connect(opDataSelection.ImageGroup[self.DATA_ROLE_RAW])
        opDataExport.RawDatasetInfo.connect(opDataSelection.DatasetGroup[self.DATA_ROLE_RAW])
        opDataExport.Inputs.resize(len(self.EXPORT_NAMES))
        opDataExport.Inputs[0].connect(opWsdt.Superpixels)

        # Sanity check: every export input (one per entry in EXPORT_NAMES)
        # must have been connected above.
        for slot in opDataExport.Inputs:
            assert slot.partner is not None

    def onProjectLoaded(self, projectManager):
        """
        Overridden from Workflow base class.  Called by the Project Manager.

        If the user provided command-line arguments, use them to configure
        the workflow inputs and output settings.
        """
        # Configure the data export operator.
        if self._data_export_args:
            self.dataExportApplet.configure_operator_with_parsed_args(self._data_export_args)

        # A batch export only runs in headless mode, and only when both the
        # input and the export arguments were given.
        if self._headless and self._batch_input_args and self._data_export_args:
            logger.info("Beginning Batch Processing")
            self.batchProcessingApplet.run_export_from_parsed_args(self._batch_input_args)
            logger.info("Completed Batch Processing")

    def handleAppletStateUpdateRequested(self):
        """
        Overridden from Workflow base class
        Called when an applet has fired the :py:attr:`Applet.appletStateUpdateRequested`

        Enables or disables each applet in the shell according to the current
        state, and blocks project changes while any applet is busy.
        """
        opDataSelection = self.dataSelectionApplet.topLevelOperator
        opWsdt = self.wsdtApplet.topLevelOperator

        # If no data, nothing else is ready.
        input_ready = len(opDataSelection.ImageGroup) > 0 and not self.dataSelectionApplet.busy

        # The user isn't allowed to touch anything while batch processing is running.
        batch_processing_busy = self.batchProcessingApplet.busy

        self._shell.setAppletEnabled(self.dataSelectionApplet, not batch_processing_busy)
        self._shell.setAppletEnabled(self.wsdtApplet, not batch_processing_busy and input_ready)
        self._shell.setAppletEnabled(
            self.dataExportApplet,
            not batch_processing_busy and input_ready and opWsdt.Superpixels.ready(),
        )
        self._shell.setAppletEnabled(self.batchProcessingApplet, not batch_processing_busy and input_ready)

        # Lastly, check for certain "busy" conditions, during which we
        #  should prevent the shell from closing the project.
        busy = False
        busy |= self.dataSelectionApplet.busy
        busy |= self.wsdtApplet.busy
        busy |= self.dataExportApplet.busy
        busy |= self.batchProcessingApplet.busy
        self._shell.enableProjectChanges(not busy)
예제 #6
0
class EdgeTrainingWithMulticutWorkflow(Workflow):
    workflowName = "Edge Training With Multicut"
    workflowDisplayName = "(BETA) Edge Training With Multicut"

    workflowDescription = "A workflow based around training a classifier for merging superpixels and joining them via multicut."
    defaultAppletIndex = 0  # show DataSelection by default

    DATA_ROLE_RAW = 0
    DATA_ROLE_PROBABILITIES = 1
    DATA_ROLE_SUPERPIXELS = 2
    DATA_ROLE_GROUNDTRUTH = 3
    ROLE_NAMES = ['Raw Data', 'Probabilities', 'Superpixels', 'Groundtruth']
    EXPORT_NAMES = ['Multicut Segmentation']

    @property
    def applets(self):
        return self._applets

    @property
    def imageNameListSlot(self):
        return self.dataSelectionApplet.topLevelOperator.ImageName

    def __init__(self, shell, headless, workflow_cmdline_args,
                 project_creation_workflow, *args, **kwargs):
        self.stored_classifier = None

        # Create a graph to be shared by all operators
        graph = Graph()

        super(EdgeTrainingWithMulticutWorkflow,
              self).__init__(shell,
                             headless,
                             workflow_cmdline_args,
                             project_creation_workflow,
                             graph=graph,
                             *args,
                             **kwargs)
        self._applets = []

        # -- DataSelection applet
        #
        self.dataSelectionApplet = DataSelectionApplet(self, "Input Data",
                                                       "Input Data")

        # Dataset inputs
        opDataSelection = self.dataSelectionApplet.topLevelOperator
        opDataSelection.DatasetRoles.setValue(self.ROLE_NAMES)

        # -- Watershed applet
        #
        self.wsdtApplet = WsdtApplet(self, "DT Watershed", "DT Watershed")

        # -- Edge training AND Multicut applet
        #
        self.edgeTrainingWithMulticutApplet = EdgeTrainingWithMulticutApplet(
            self, "Training and Multicut", "Training and Multicut")
        opEdgeTrainingWithMulticut = self.edgeTrainingWithMulticutApplet.topLevelOperator
        DEFAULT_FEATURES = {
            self.ROLE_NAMES[self.DATA_ROLE_RAW]: ['standard_edge_mean']
        }
        opEdgeTrainingWithMulticut.FeatureNames.setValue(DEFAULT_FEATURES)

        # -- DataExport applet
        #
        self.dataExportApplet = DataExportApplet(self, "Data Export")
        self.dataExportApplet.prepare_for_entire_export = self.prepare_for_entire_export
        self.dataExportApplet.post_process_entire_export = self.post_process_entire_export

        # Configure global DataExport settings
        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.WorkingDirectory.connect(opDataSelection.WorkingDirectory)
        opDataExport.SelectionNames.setValue(self.EXPORT_NAMES)

        # -- BatchProcessing applet
        #
        self.batchProcessingApplet = BatchProcessingApplet(
            self, "Batch Processing", self.dataSelectionApplet,
            self.dataExportApplet)

        # -- Expose applets to shell
        self._applets.append(self.dataSelectionApplet)
        self._applets.append(self.wsdtApplet)
        self._applets.append(self.edgeTrainingWithMulticutApplet)
        self._applets.append(self.dataExportApplet)
        self._applets.append(self.batchProcessingApplet)

        # -- Parse command-line arguments
        #    (Command-line args are applied in onProjectLoaded(), below.)
        # Parse workflow-specific command-line args
        parser = argparse.ArgumentParser()
        parser.add_argument(
            '--retrain',
            help=
            "Re-train the classifier based on labels stored in the project file, and re-save.",
            action="store_true")
        self.parsed_workflow_args, unused_args = parser.parse_known_args(
            workflow_cmdline_args)
        if unused_args:
            # Parse batch export/input args.
            self._data_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args(
                unused_args)
            self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args(
                unused_args)
        else:
            unused_args = None
            self._batch_input_args = None
            self._data_export_args = None

        if unused_args:
            logger.warn("Unused command-line args: {}".format(unused_args))

        if not self._headless:
            shell.currentAppletChanged.connect(self.handle_applet_changed)

    def prepareForNewLane(self, laneIndex):
        """
        Overridden from Workflow base class.
        Called immediately before a new lane is added to the workflow.
        """
        opEdgeTrainingWithMulticut = self.edgeTrainingWithMulticutApplet.topLevelOperator
        opClassifierCache = opEdgeTrainingWithMulticut.opEdgeTraining.opClassifierCache

        # When the new lane is added, dirty notifications will propagate throughout the entire graph.
        # This means the classifier will be marked 'dirty' even though it is still usable.
        # Before that happens, let's store the classifier, so we can restore it in handleNewLanesAdded(), below.
        if opClassifierCache.Output.ready() and \
           not opClassifierCache._dirty:
            self.stored_classifier = opClassifierCache.Output.value
        else:
            self.stored_classifier = None

    def handleNewLanesAdded(self):
        """
        Overridden from Workflow base class.
        Called immediately after a new lane is added to the workflow and initialized.
        """
        opEdgeTrainingWithMulticut = self.edgeTrainingWithMulticutApplet.topLevelOperator
        opClassifierCache = opEdgeTrainingWithMulticut.opEdgeTraining.opClassifierCache

        # Restore classifier we saved in prepareForNewLane() (if any)
        if self.stored_classifier:
            opClassifierCache.forceValue(self.stored_classifier)
            # Release reference
            self.stored_classifier = None

    def connectLane(self, laneIndex):
        """
        Override from base class.
        """
        opDataSelection = self.dataSelectionApplet.topLevelOperator.getLane(
            laneIndex)
        opWsdt = self.wsdtApplet.topLevelOperator.getLane(laneIndex)
        opEdgeTrainingWithMulticut = self.edgeTrainingWithMulticutApplet.topLevelOperator.getLane(
            laneIndex)
        opDataExport = self.dataExportApplet.topLevelOperator.getLane(
            laneIndex)

        # RAW DATA: Convert to float32
        opConvertRaw = OpConvertDtype(parent=self)
        opConvertRaw.ConversionDtype.setValue(np.float32)
        opConvertRaw.Input.connect(
            opDataSelection.ImageGroup[self.DATA_ROLE_RAW])

        # PROBABILITIES: Convert to float32
        opConvertProbabilities = OpConvertDtype(parent=self)
        opConvertProbabilities.ConversionDtype.setValue(np.float32)
        opConvertProbabilities.Input.connect(
            opDataSelection.ImageGroup[self.DATA_ROLE_PROBABILITIES])

        # GROUNDTRUTH: Convert to uint32, relabel, and cache
        opConvertGroundtruth = OpConvertDtype(parent=self)
        opConvertGroundtruth.ConversionDtype.setValue(np.uint32)
        opConvertGroundtruth.Input.connect(
            opDataSelection.ImageGroup[self.DATA_ROLE_GROUNDTRUTH])

        opRelabelGroundtruth = OpRelabelConsecutive(parent=self)
        opRelabelGroundtruth.Input.connect(opConvertGroundtruth.Output)

        opGroundtruthCache = OpBlockedArrayCache(parent=self)
        opGroundtruthCache.CompressionEnabled.setValue(True)
        opGroundtruthCache.Input.connect(opRelabelGroundtruth.Output)

        # watershed inputs
        opWsdt.RawData.connect(opDataSelection.ImageGroup[self.DATA_ROLE_RAW])
        opWsdt.Input.connect(
            opDataSelection.ImageGroup[self.DATA_ROLE_PROBABILITIES])

        # Actual computation is done with both RawData and Probabilities
        opStackRawAndVoxels = OpSimpleStacker(parent=self)
        opStackRawAndVoxels.Images.resize(2)
        opStackRawAndVoxels.Images[0].connect(opConvertRaw.Output)
        opStackRawAndVoxels.Images[1].connect(opConvertProbabilities.Output)
        opStackRawAndVoxels.AxisFlag.setValue('c')

        # If superpixels are available from a file, use it.
        opSuperpixelsSelect = OpPrecomputedInput(ignore_dirty_input=True,
                                                 parent=self)
        opSuperpixelsSelect.PrecomputedInput.connect(
            opDataSelection.ImageGroup[self.DATA_ROLE_SUPERPIXELS])
        opSuperpixelsSelect.SlowInput.connect(opWsdt.Superpixels)

        # If the superpixel file changes, then we have to remove the training labels from the image
        opEdgeTraining = opEdgeTrainingWithMulticut.opEdgeTraining

        def handle_new_superpixels(*args):
            opEdgeTraining.handle_dirty_superpixels(opEdgeTraining.Superpixels)

        opDataSelection.ImageGroup[self.DATA_ROLE_SUPERPIXELS].notifyReady(
            handle_new_superpixels)
        opDataSelection.ImageGroup[self.DATA_ROLE_SUPERPIXELS].notifyUnready(
            handle_new_superpixels)

        # edge training inputs
        opEdgeTrainingWithMulticut.RawData.connect(opDataSelection.ImageGroup[
            self.DATA_ROLE_RAW])  # Used for visualization only
        opEdgeTrainingWithMulticut.VoxelData.connect(
            opStackRawAndVoxels.Output)
        opEdgeTrainingWithMulticut.Superpixels.connect(
            opSuperpixelsSelect.Output)
        opEdgeTrainingWithMulticut.GroundtruthSegmentation.connect(
            opGroundtruthCache.Output)

        # DataExport inputs
        opDataExport.RawData.connect(
            opDataSelection.ImageGroup[self.DATA_ROLE_RAW])
        opDataExport.RawDatasetInfo.connect(
            opDataSelection.DatasetGroup[self.DATA_ROLE_RAW])
        opDataExport.Inputs.resize(len(self.EXPORT_NAMES))
        opDataExport.Inputs[0].connect(opEdgeTrainingWithMulticut.Output)
        for slot in opDataExport.Inputs:
            assert slot.partner is not None

    def onProjectLoaded(self, projectManager):
        """
        Overridden from Workflow base class.  Called by the Project Manager.
        
        If the user provided command-line arguments, use them to configure 
        the workflow inputs and output settings.
        """
        # Configure the data export operator.
        if self._data_export_args:
            self.dataExportApplet.configure_operator_with_parsed_args(
                self._data_export_args)

        # Retrain the classifier?
        if self.parsed_workflow_args.retrain:
            self._force_retrain_classifier(projectManager)

        if self._headless and self._batch_input_args and self._data_export_args:
            # Make sure the watershed can be computed if necessary.
            opWsdt = self.wsdtApplet.topLevelOperator
            opWsdt.FreezeCache.setValue(False)

            # Error checks
            if (self._batch_input_args.raw_data
                    and len(self._batch_input_args.probabilities) != len(
                        self._batch_input_args.raw_data)):
                msg = "Error: Your input file lists are malformed.\n"
                msg += "Usage: run_ilastik.sh --headless --raw_data <file1> <file2>... --probabilities <file1> <file2>..."
                sys.exit(msg)

            if (self._batch_input_args.superpixels
                    and (not self._batch_input_args.raw_data
                         or len(self._batch_input_args.superpixels) != len(
                             self._batch_input_args.raw_data))):
                msg = "Error: Wrong number of superpixel file inputs."
                sys.exit(msg)

            logger.info("Beginning Batch Processing")
            self.batchProcessingApplet.run_export_from_parsed_args(
                self._batch_input_args)
            logger.info("Completed Batch Processing")

    def _force_retrain_classifier(self, projectManager):
        logger.info("Retraining edge classifier...")
        op = self.edgeTrainingWithMulticutApplet.topLevelOperator

        # Cause the classifier to be dirty so it is forced to retrain.
        # (useful if the stored labels or features were changed outside ilastik)
        op.FeatureNames.setDirty()

        # Request the classifier, which forces training
        new_classifier = op.opEdgeTraining.opClassifierCache.Output.value
        if new_classifier is None:
            raise RuntimeError(
                "Classifier could not be trained! Check your labels and features."
            )

        # store new classifier to project file
        projectManager.saveProject(force_all_save=False)

    def prepare_for_entire_export(self):
        """
        Assigned to DataExportApplet.prepare_for_entire_export
        (See above.)
        """
        # While exporting results, the segmentation cache should not be "frozen"
        self.freeze_status = self.edgeTrainingWithMulticutApplet.topLevelOperator.FreezeCache.value
        self.edgeTrainingWithMulticutApplet.topLevelOperator.FreezeCache.setValue(
            False)

    def post_process_entire_export(self):
        """
        Assigned to DataExportApplet.post_process_entire_export
        (See above.)
        """
        # After export is finished, re-freeze the segmentation cache.
        self.edgeTrainingWithMulticutApplet.topLevelOperator.FreezeCache.setValue(
            self.freeze_status)

    def handleAppletStateUpdateRequested(self):
        """
        Overridden from Workflow base class
        Called when an applet has fired the :py:attr:`Applet.appletStateUpdateRequested`

        Enables or disables every applet in the shell depending on whether
        input data is present, superpixels are ready, the multicut output is
        ready and batch processing is running. Finally, project changes are
        blocked while any applet reports that it is busy.
        """
        opDataSelection = self.dataSelectionApplet.topLevelOperator
        opWsdt = self.wsdtApplet.topLevelOperator
        opEdgeTrainingWithMulticut = self.edgeTrainingWithMulticutApplet.topLevelOperator

        # If no data, nothing else is ready.
        input_ready = (len(opDataSelection.ImageGroup) > 0
                       and not self.dataSelectionApplet.busy)

        # Superpixels loaded from file make the WSDT applet unnecessary
        # for the currently displayed lane.
        superpixels_available_from_file = False
        lane_index = self._shell.currentImageIndex
        if lane_index != -1:
            superpixels_slot = opDataSelection.ImageGroup[lane_index][self.DATA_ROLE_SUPERPIXELS]
            superpixels_available_from_file = superpixels_slot.ready()

        superpixels_ready = opWsdt.Superpixels.ready()

        # The user isn't allowed to touch anything while batch processing is running.
        interactive = not self.batchProcessingApplet.busy

        self._shell.setAppletEnabled(self.dataSelectionApplet, interactive)
        self._shell.setAppletEnabled(
            self.wsdtApplet,
            interactive and input_ready and not superpixels_available_from_file)
        self._shell.setAppletEnabled(
            self.edgeTrainingWithMulticutApplet,
            interactive and input_ready and superpixels_ready)
        self._shell.setAppletEnabled(
            self.dataExportApplet,
            interactive and input_ready and opEdgeTrainingWithMulticut.Output.ready())
        self._shell.setAppletEnabled(self.batchProcessingApplet,
                                     interactive and input_ready)

        # Lastly, check for certain "busy" conditions, during which we
        #  should prevent the shell from closing the project.
        # (A list, not a generator: every applet's busy flag is read.)
        busy = any([
            self.dataSelectionApplet.busy,
            self.wsdtApplet.busy,
            self.edgeTrainingWithMulticutApplet.busy,
            self.dataExportApplet.busy,
            self.batchProcessingApplet.busy,
        ])
        self._shell.enableProjectChanges(not busy)

    def handle_applet_changed(self, prev_index, current_index):
        """
        Update the FreezeCache slots after the user switched applets.

        Nothing happens when ``prev_index`` equals ``current_index``.
        Otherwise the WSDT cache and then the multicut cache are frozen
        exactly when the shell's current applet index is not past the
        respective applet, so applets downstream of them always see
        up-to-date superpixels / segmentation.
        """
        if prev_index == current_index:
            return

        # Order matters: WSDT first, then the multicut segmentation.
        for applet in (self.wsdtApplet, self.edgeTrainingWithMulticutApplet):
            operator = applet.topLevelOperator
            operator.FreezeCache.setValue(
                self._shell.currentAppletIndex <= self.applets.index(applet))
    def __init__(self, shell, headless, workflow_cmdline_args,
                 project_creation_args, *args, **kwargs):
        """
        Create the applets of the workflow and parse its command-line args.

        ``--fillmissing`` is taken from ``project_creation_args``; a differing
        value in ``workflow_cmdline_args`` is only logged as an error.
        ``--nobatch`` is taken from ``workflow_cmdline_args``. Whatever
        ``workflow_cmdline_args`` entries remain after all parsers have run
        are reported in a warning.

        :param shell: forwarded to the base class constructor.
        :param headless: forwarded to the base class constructor.
        :param workflow_cmdline_args: args given when the workflow is started.
        :param project_creation_args: args given when the project was created.
        """
        # Use the caller's graph if one was passed in, otherwise create one.
        graph = kwargs.pop("graph") if "graph" in kwargs else Graph()
        super().__init__(shell,
                         headless,
                         workflow_cmdline_args,
                         project_creation_args,
                         graph=graph,
                         *args,
                         **kwargs)
        self.stored_object_classifier = None

        # Parse workflow-specific command-line args
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "--fillmissing",
            help="use 'fill missing' applet with chosen detection method",
            choices=["classic", "svm", "none"],
            default="none",
        )
        parser.add_argument("--nobatch",
                            help="do not append batch applets",
                            action="store_true",
                            default=False)

        # The leftovers of this first parse are discarded: unused_args is
        # overwritten by the parse of workflow_cmdline_args below.
        parsed_creation_args, unused_args = parser.parse_known_args(
            project_creation_args)

        self.fillMissing = parsed_creation_args.fillmissing

        parsed_args, unused_args = parser.parse_known_args(
            workflow_cmdline_args)
        if parsed_args.fillmissing != "none" and parsed_creation_args.fillmissing != parsed_args.fillmissing:
            logger.error(
                "Ignoring --fillmissing cmdline arg.  Can't specify a different fillmissing setting after the project has already been created."
            )

        self.batch = not parsed_args.nobatch

        # Applets are appended in the order in which they are exposed.
        self._applets = []

        self.createInputApplets()

        if self.fillMissing != "none":
            self.fillMissingSlicesApplet = FillMissingSlicesApplet(
                self, "Fill Missing Slices", "Fill Missing Slices",
                self.fillMissing)
            self._applets.append(self.fillMissingSlicesApplet)

        # our main applets
        self.objectExtractionApplet = ObjectExtractionApplet(
            workflow=self, name="Object Feature Selection")
        self.objectClassificationApplet = ObjectClassificationApplet(
            workflow=self)
        self._tableExporter = TableExporter(
            self.objectClassificationApplet.topLevelOperator)
        self.dataExportApplet = ObjectClassificationDataExportApplet(
            self,
            "Object Information Export",
            table_exporter=self._tableExporter)

        # Customization hooks
        self.dataExportApplet.prepare_for_entire_export = self.prepare_for_entire_export
        self.dataExportApplet.post_process_entire_export = self.post_process_entire_export

        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.WorkingDirectory.connect(
            self.dataSelectionApplet.topLevelOperator.WorkingDirectory)

        opDataExport.SelectionNames.setValue(
            self.ExportNames.asDisplayNameList())

        # These stay None unless batch mode is on and extra args were given.
        self._batch_export_args = None
        self._batch_input_args = None
        self._export_args = None
        self.batchProcessingApplet = None

        self._applets.append(self.objectExtractionApplet)
        self._applets.append(self.objectClassificationApplet)
        self._applets.append(self.dataExportApplet)

        if self.batch:
            self.batchProcessingApplet = BatchProcessingApplet(
                self, "Batch Processing", self.dataSelectionApplet,
                self.dataExportApplet)
            self._applets.append(self.batchProcessingApplet)

            if unused_args:
                # Each parser consumes what it knows and hands the rest on.
                exportsArgParser, _ = self.exportsArgParser
                self._export_args, unused_args = exportsArgParser.parse_known_args(
                    unused_args)

                # We parse the export setting args first.  All remaining args are considered input files by the input applet.
                self._batch_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args(
                    unused_args)
                self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args(
                    unused_args)

                # For backwards compatibility, translate these special args into the standard syntax
                self._batch_input_args.export_source = self._export_args.export_source

        self.blockwiseObjectClassificationApplet = BlockwiseObjectClassificationApplet(
            self, "Blockwise Object Classification",
            "Blockwise Object Classification")
        self._applets.append(self.blockwiseObjectClassificationApplet)

        if unused_args:
            logger.warning("Unused command-line args: {}".format(unused_args))
class ObjectClassificationWorkflow(Workflow):
    workflowName = "Object Classification Workflow Base"
    defaultAppletIndex = 1 # show DataSelection by default

    def __init__(self, shell, headless,
                 workflow_cmdline_args,
                 project_creation_args,
                 *args, **kwargs):
        """
        Create the applets shared by the object classification workflows and
        parse the workflow-specific command-line arguments.

        ``--fillmissing`` and ``--filter`` are taken from
        ``project_creation_args``; differing values in
        ``workflow_cmdline_args`` are only logged as errors. ``--nobatch`` is
        taken from ``workflow_cmdline_args``.

        :param shell: forwarded to the Workflow base class.
        :param headless: forwarded to the Workflow base class.
        :param workflow_cmdline_args: args given when the workflow is started.
        :param project_creation_args: args given when the project was created.
        :raises RuntimeError: if ``--export_pixel_probability_img`` is used
            with an input type other than 'raw', or if more than one of the
            ``--export_*_img`` flags is given.
        """
        # Use the caller's graph if one was passed in, otherwise create one.
        graph = kwargs.pop('graph') if 'graph' in kwargs else Graph()
        super(ObjectClassificationWorkflow, self).__init__(
            shell, headless, workflow_cmdline_args, project_creation_args, graph=graph, *args, **kwargs)
        self.stored_pixel_classifier = None
        self.stored_object_classifier = None

        # Parse workflow-specific command-line args
        parser = argparse.ArgumentParser()
        parser.add_argument('--fillmissing', help="use 'fill missing' applet with chosen detection method", choices=['classic', 'svm', 'none'], default='none')
        parser.add_argument('--filter', help="pixel feature filter implementation.", choices=['Original', 'Refactored', 'Interpolated'], default='Original')
        parser.add_argument('--nobatch', help="do not append batch applets", action='store_true', default=False)

        # The leftovers of this first parse are discarded: unused_args is
        # overwritten by the parse of workflow_cmdline_args below.
        parsed_creation_args, unused_args = parser.parse_known_args(project_creation_args)

        self.fillMissing = parsed_creation_args.fillmissing
        self.filter_implementation = parsed_creation_args.filter

        parsed_args, unused_args = parser.parse_known_args(workflow_cmdline_args)
        if parsed_args.fillmissing != 'none' and parsed_creation_args.fillmissing != parsed_args.fillmissing:
            logger.error( "Ignoring --fillmissing cmdline arg.  Can't specify a different fillmissing setting after the project has already been created." )

        if parsed_args.filter != 'Original' and parsed_creation_args.filter != parsed_args.filter:
            logger.error( "Ignoring --filter cmdline arg.  Can't specify a different filter setting after the project has already been created." )

        self.batch = not parsed_args.nobatch

        # Applets are exposed in the order in which they are appended here.
        self._applets = []

        self.pcApplet = None
        self.projectMetadataApplet = ProjectMetadataApplet()
        self._applets.append(self.projectMetadataApplet)

        self.setupInputs()

        if self.fillMissing != 'none':
            self.fillMissingSlicesApplet = FillMissingSlicesApplet(
                self, "Fill Missing Slices", "Fill Missing Slices", self.fillMissing)
            self._applets.append(self.fillMissingSlicesApplet)

        # NOTE(review): input_types stays unset if self is none of these
        # subclasses; the code below would then fail with AttributeError.
        if isinstance(self, ObjectClassificationWorkflowPixel):
            self.input_types = 'raw'
        elif isinstance(self, ObjectClassificationWorkflowBinary):
            self.input_types = 'raw+binary'
        elif isinstance(self, ObjectClassificationWorkflowPrediction):
            self.input_types = 'raw+pmaps'

        # our main applets
        self.objectExtractionApplet = ObjectExtractionApplet(workflow=self, name="Object Feature Selection")
        self.objectClassificationApplet = ObjectClassificationApplet(workflow=self)
        self.dataExportApplet = ObjectClassificationDataExportApplet(self, "Object Information Export")
        self.dataExportApplet.set_exporting_operator(self.objectClassificationApplet.topLevelOperator)

        # Customization hooks
        self.dataExportApplet.prepare_for_entire_export = self.prepare_for_entire_export
        #self.dataExportApplet.prepare_lane_for_export = self.prepare_lane_for_export
        self.dataExportApplet.post_process_lane_export = self.post_process_lane_export
        self.dataExportApplet.post_process_entire_export = self.post_process_entire_export

        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.WorkingDirectory.connect( self.dataSelectionApplet.topLevelOperator.WorkingDirectory )

        # See EXPORT_SELECTION_PREDICTIONS and EXPORT_SELECTION_PROBABILITIES, above
        export_selection_names = ['Object Predictions',
                                  'Object Probabilities',
                                  'Blockwise Object Predictions',
                                  'Blockwise Object Probabilities']
        if self.input_types == 'raw':
            # Re-configure to add the pixel probabilities option
            # See EXPORT_SELECTION_PIXEL_PROBABILITIES, above
            export_selection_names.append( 'Pixel Probabilities' )
        opDataExport.SelectionNames.setValue( export_selection_names )

        # These stay None unless batch mode is on and extra args were given.
        self._batch_export_args = None
        self._batch_input_args = None
        self._export_args = None
        self.batchProcessingApplet = None
        if self.batch:
            self.batchProcessingApplet = BatchProcessingApplet(self,
                                                               "Batch Processing",
                                                               self.dataSelectionApplet,
                                                               self.dataExportApplet)

            if unused_args:
                # Additional export args (specific to the object classification workflow)
                export_arg_parser = argparse.ArgumentParser()
                export_arg_parser.add_argument( "--table_filename", help="The location to export the object feature/prediction CSV file.", required=False )
                export_arg_parser.add_argument( "--export_object_prediction_img", action="store_true" )
                export_arg_parser.add_argument( "--export_object_probability_img", action="store_true" )
                export_arg_parser.add_argument( "--export_pixel_probability_img", action="store_true" )

                # TODO: Support this, too, someday?
                #export_arg_parser.add_argument( "--export_object_label_img", action="store_true" )

                self._export_args, unused_args = export_arg_parser.parse_known_args(unused_args)
                if self.input_types != 'raw' and self._export_args.export_pixel_probability_img:
                    raise RuntimeError("Invalid command-line argument: \n"
                                       "'--export_pixel_probability_img' can only be used with the combined "
                                       "'Pixel Classification + Object Classification' workflow.")

                if sum([self._export_args.export_object_prediction_img,
                        self._export_args.export_object_probability_img,
                        self._export_args.export_pixel_probability_img]) > 1:
                    raise RuntimeError("Invalid command-line arguments: Only one type classification output can be exported at a time.")

                # We parse the export setting args first.  All remaining args are considered input files by the input applet.
                self._batch_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args( unused_args )
                self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args( unused_args )

                # For backwards compatibility, translate these special args into the standard syntax
                if self._export_args.export_object_prediction_img:
                    self._batch_input_args.export_source = "Object Predictions"
                if self._export_args.export_object_probability_img:
                    self._batch_input_args.export_source = "Object Probabilities"
                if self._export_args.export_pixel_probability_img:
                    self._batch_input_args.export_source = "Pixel Probabilities"

        self.blockwiseObjectClassificationApplet = BlockwiseObjectClassificationApplet(
            self, "Blockwise Object Classification", "Blockwise Object Classification")

        self._applets.append(self.objectExtractionApplet)
        self._applets.append(self.objectClassificationApplet)
        self._applets.append(self.dataExportApplet)
        if self.batchProcessingApplet:
            self._applets.append(self.batchProcessingApplet)
        self._applets.append(self.blockwiseObjectClassificationApplet)

        if unused_args:
            # Logger.warn is a deprecated alias; use warning().
            logger.warning("Unused command-line args: {}".format( unused_args ))

    @property
    def applets(self):
        """The workflow's applet list (``self._applets``), returned as-is."""
        applet_list = self._applets
        return applet_list

    @property
    def imageNameListSlot(self):
        """The ``ImageName`` slot of the data selection applet's top-level operator."""
        data_selection_op = self.dataSelectionApplet.topLevelOperator
        return data_selection_op.ImageName

    def prepareForNewLane(self, laneIndex):
        """
        Remember the current classifiers so handleNewLanesAdded can restore them.

        For the pixel classifier (only if ``self.pcApplet`` is set) and for
        the object classifier: the cached classifier is stored when the
        cache output is ready and the cache is not dirty; otherwise None is
        stored. ``laneIndex`` is not used.
        """
        if self.pcApplet:
            pixel_cache = self.pcApplet.topLevelOperator.classifier_cache
            pixel_cache_clean = pixel_cache.Output.ready() and not pixel_cache._dirty
            self.stored_pixel_classifier = pixel_cache.Output.value if pixel_cache_clean else None

        object_cache = self.objectClassificationApplet.topLevelOperator.classifier_cache
        object_cache_clean = object_cache.Output.ready() and not object_cache._dirty
        self.stored_object_classifier = object_cache.Output.value if object_cache_clean else None

    def handleNewLanesAdded(self):
        """
        If new lanes were added, then we invalidated our classifiers unnecessarily.
        Here, we can restore the classifiers so they don't need to be retrained.

        Each stored classifier (pixel first, then object) is forced back into
        its classifier cache and the stored reference is released afterwards.
        """
        pixel_classifier = self.stored_pixel_classifier
        if pixel_classifier:
            pixel_cache = self.pcApplet.topLevelOperator.classifier_cache
            pixel_cache.forceValue(pixel_classifier)
            self.stored_pixel_classifier = None  # release reference

        object_classifier = self.stored_object_classifier
        if object_classifier:
            object_cache = self.objectClassificationApplet.topLevelOperator.classifier_cache
            object_cache.forceValue(object_classifier)
            self.stored_object_classifier = None  # release reference

    def connectLane(self, laneIndex):
        """
        Wire up the operators of one lane.

        The raw and binary slots returned by ``self.connectInputs`` feed the
        object extraction and object classification operators; the data
        export operator receives the (blockwise) prediction and probability
        slots, plus the pixel probabilities when ``self.input_types`` is
        'raw'; the blockwise operator is fed from object classification.

        :param laneIndex: index of the lane to connect.
        """
        rawslot, binaryslot = self.connectInputs(laneIndex)

        opData = self.dataSelectionApplet.topLevelOperator.getLane(laneIndex)

        opObjExtraction = self.objectExtractionApplet.topLevelOperator.getLane(laneIndex)
        opObjClassification = self.objectClassificationApplet.topLevelOperator.getLane(laneIndex)
        opDataExport = self.dataExportApplet.topLevelOperator.getLane(laneIndex)
        opBlockwiseObjectClassification = self.blockwiseObjectClassificationApplet.topLevelOperator.getLane(laneIndex)

        opObjExtraction.RawImage.connect(rawslot)
        opObjExtraction.BinaryImage.connect(binaryslot)

        opObjClassification.RawImages.connect(rawslot)
        opObjClassification.BinaryImages.connect(binaryslot)

        # Object classification consumes the outputs of object extraction.
        opObjClassification.SegmentationImages.connect(opObjExtraction.LabelImage)
        opObjClassification.ObjectFeatures.connect(opObjExtraction.RegionFeatures)
        opObjClassification.ComputedFeatureNames.connect(opObjExtraction.Features)

        # Data Export connections
        opDataExport.RawData.connect( opData.ImageGroup[0] )
        opDataExport.RawDatasetInfo.connect( opData.DatasetGroup[0] )
        # One export input per EXPORT_SELECTION_* index used below.
        opDataExport.Inputs.resize(4)
        opDataExport.Inputs[EXPORT_SELECTION_PREDICTIONS].connect( opObjClassification.UncachedPredictionImages )
        opDataExport.Inputs[EXPORT_SELECTION_PROBABILITIES].connect( opObjClassification.ProbabilityChannelImage )
        opDataExport.Inputs[EXPORT_SELECTION_BLOCKWISE_PREDICTIONS].connect( opBlockwiseObjectClassification.PredictionImage )
        opDataExport.Inputs[EXPORT_SELECTION_BLOCKWISE_PROBABILITIES].connect( opBlockwiseObjectClassification.ProbabilityChannelImage )

        if self.input_types == 'raw':
            # Append the prediction probabilities to the list of slots that can be exported.
            opDataExport.Inputs.resize(5)
            # Pull from this slot since the data has already been through the Op5 operator
            # (All data in the export operator must have matching spatial dimensions.)
            opThreshold = self.thresholdingApplet.topLevelOperator.getLane(laneIndex)
            opDataExport.Inputs[EXPORT_SELECTION_PIXEL_PROBABILITIES].connect( opThreshold.InputImage )

        # NOTE(review): both lane operators were already fetched at the top of
        # this method; these second lookups look redundant -- confirm that
        # getLane() has no side effects before removing them.
        opObjClassification = self.objectClassificationApplet.topLevelOperator.getLane(laneIndex)
        opBlockwiseObjectClassification = self.blockwiseObjectClassificationApplet.topLevelOperator.getLane(laneIndex)

        # The blockwise operator takes all of its inputs from object classification.
        opBlockwiseObjectClassification.RawImage.connect(opObjClassification.RawImages)
        opBlockwiseObjectClassification.BinaryImage.connect(opObjClassification.BinaryImages)
        opBlockwiseObjectClassification.Classifier.connect(opObjClassification.Classifier)
        opBlockwiseObjectClassification.LabelsCount.connect(opObjClassification.NumLabels)
        opBlockwiseObjectClassification.SelectedFeatures.connect(opObjClassification.SelectedFeatures)
        
    def onProjectLoaded(self, projectManager):
        """
        Run the batch export after a project was loaded in headless mode.

        Does nothing outside headless mode. In headless mode both the batch
        input and batch export args must have been parsed; the project must
        provide a ready classifier, otherwise an error is logged and the
        method returns. An optional ``--table_filename`` overrides the CSV
        export path before batch processing is started.

        :param projectManager: only ``currentProjectPath`` is read, for the
            error message.
        :raises RuntimeError: if batch args are missing in headless mode, or
            if a CSV filename is given but no table export settings exist.
        """
        if not self._headless:
            return

        if not self._batch_input_args or not self._batch_export_args:
            raise RuntimeError("Currently, this workflow has no batch mode and headless mode support")

        # Check for problems: Is the project file ready to use?
        object_classification_op = self.objectClassificationApplet.topLevelOperator
        if not object_classification_op.Classifier.ready():
            message = (
                "Can't run batch prediction.\n"
                "Couldn't obtain a classifier from your project file: {}.\n"
                "Please make sure your project is fully configured with a trained classifier."
            ).format(projectManager.currentProjectPath)
            logger.error(message)
            return

        # Configure the data export operator.
        if self._batch_export_args:
            self.dataExportApplet.configure_operator_with_parsed_args(self._batch_export_args)

        export_args = self._export_args
        csv_filename = export_args.table_filename if export_args else None
        if csv_filename:
            # The user wants to override the csv export location via
            #  the command-line arguments. Apply the new setting to the operator.
            table_op = self.objectClassificationApplet.topLevelOperator
            settings, selected_features = table_op.get_table_export_settings()
            if settings is None:
                raise RuntimeError("You can't export the CSV object table unless you configure it in the GUI first.")
            assert 'file path' in settings, "Expected settings dict to contain a 'file path' key.  Did you rename that key?"
            settings['file path'] = csv_filename
            self.objectClassificationApplet.topLevelOperator.configure_table_export_settings(settings, selected_features)

        # Configure the batch data selection operator.
        batch_input_args = self._batch_input_args
        if batch_input_args and batch_input_args.raw_data:
            logger.info("Beginning Batch Processing")
            self.batchProcessingApplet.run_export_from_parsed_args(self._batch_input_args)
            logger.info("Completed Batch Processing")

    def prepare_for_entire_export(self):
        """
        Un-freeze the prediction caches before an export.

        The current FreezePredictions values are remembered in
        ``self.pc_freeze_status`` (only if ``self.pcApplet`` is set) and
        ``self.oc_freeze_status``, then both slots are set to False so we
        don't just get a bunch of zeros from the caches when asking for results.
        """
        if self.pcApplet:
            pixel_freeze_slot = self.pcApplet.topLevelOperator.FreezePredictions
            self.pc_freeze_status = pixel_freeze_slot.value
            pixel_freeze_slot.setValue(False)

        object_freeze_slot = self.objectClassificationApplet.topLevelOperator.FreezePredictions
        self.oc_freeze_status = object_freeze_slot.value
        object_freeze_slot.setValue(False)

    def post_process_entire_export(self):
        """
        Restore the FreezePredictions values saved by prepare_for_entire_export.

        The pixel classification slot is only restored if ``self.pcApplet``
        is set; the object classification slot is always restored.
        """
        if self.pcApplet:
            pixel_op = self.pcApplet.topLevelOperator
            pixel_op.FreezePredictions.setValue(self.pc_freeze_status)

        object_op = self.objectClassificationApplet.topLevelOperator
        object_op.FreezePredictions.setValue(self.oc_freeze_status)

    def post_process_lane_export(self, lane_index):
        """
        Export the object table of one lane, if table export is configured.

        Nothing happens when the object classification operator reports no
        table export settings. Otherwise the export file gets the dataset
        nickname as suffix for datasets located on the file system, and the
        lane index otherwise; the call blocks until the export request is done.

        :param lane_index: index of the lane that was just exported.
        """
        # FIXME: This probably only works for the non-blockwise export slot.
        #        We should assert that the user isn't using the blockwise slot.
        settings, _ = self.objectClassificationApplet.topLevelOperator.get_table_export_settings()
        if not settings:
            return

        dataset_info = self.dataSelectionApplet.topLevelOperator.DatasetGroup[lane_index][0].value
        on_filesystem = dataset_info.location == DatasetInfo.Location.FileSystem
        suffix = dataset_info.nickname if on_filesystem else str(lane_index)

        # FIXME: Even in non-headless mode, we can't show the gui because we're running in a non-main thread.
        #        That's not a huge deal, because there's still a progress bar for the overall export.
        export_request = self.objectClassificationApplet.topLevelOperator.export_object_data(
            lane_index,
            show_gui=False,
            filename_suffix=suffix)
        export_request.wait()
         
    def getHeadlessOutputSlot(self, slotId):
        """
        Return the output slot registered under ``slotId``.

        Only "BatchPredictionImage" is known; it maps to
        ``self.opBatchClassify.PredictionImage``.

        :raises Exception: for any other ``slotId``.
        """
        if slotId != "BatchPredictionImage":
            raise Exception("Unknown headless output slot")
        return self.opBatchClassify.PredictionImage

    def handleAppletStateUpdateRequested(self, upstream_ready=False):
        """
        Overridden from Workflow base class
        Called when an applet has fired the :py:attr:`Applet.appletStateUpdateRequested`

        This method will be called by the child classes with the result of their
        own applet readiness findings as keyword argument.

        Applets are enabled cumulatively: each one requires everything
        upstream of it to be ready. Works both with and without the batch
        processing applet (``--nobatch`` leaves it as None).

        :param upstream_ready: whether the applets upstream of object
            feature selection are ready.
        """
        # all workflows have these applets in common:

        # object feature selection
        # object classification
        # object prediction export
        # blockwise classification
        # batch input
        # batch prediction export

        # With --nobatch there is no batch applet, hence nothing can be busy.
        batch_applet = self.batchProcessingApplet
        batch_busy = batch_applet is not None and batch_applet.busy

        self._shell.setAppletEnabled(self.dataSelectionApplet, not batch_busy)

        cumulated_readyness = upstream_ready
        cumulated_readyness &= not batch_busy  # Nothing can be touched while batch mode is executing.

        self._shell.setAppletEnabled(self.objectExtractionApplet, cumulated_readyness)

        features_slot = self.objectExtractionApplet.topLevelOperator.Features
        object_features_ready = features_slot.ready() and len(features_slot.value) > 0
        cumulated_readyness = cumulated_readyness and object_features_ready
        self._shell.setAppletEnabled(self.objectClassificationApplet, cumulated_readyness)

        opObjectClassification = self.objectClassificationApplet.topLevelOperator
        # A cache that is fixed at its current value but holds no classifier
        # cannot produce predictions.
        invalid_classifier = (opObjectClassification.classifier_cache.fixAtCurrent.value
                              and opObjectClassification.classifier_cache.Output.ready()
                              and opObjectClassification.classifier_cache.Output.value is None)

        invalid_classifier |= (not opObjectClassification.NumLabels.ready()
                               or opObjectClassification.NumLabels.value < 2)

        object_classification_ready = object_features_ready and not invalid_classifier

        cumulated_readyness = cumulated_readyness and object_classification_ready
        self._shell.setAppletEnabled(self.dataExportApplet, cumulated_readyness)

        # TODO: object predictions are treated as always ready here -- is that so?
        # The blockwise applet exists regardless of batch mode, so its state
        # is always updated.
        self._shell.setAppletEnabled(self.blockwiseObjectClassificationApplet, cumulated_readyness)
        if batch_applet is not None:
            self._shell.setAppletEnabled(batch_applet, cumulated_readyness)

        # Lastly, check for certain "busy" conditions, during which we
        # should prevent the shell from closing the project.
        #TODO implement
        busy = False
        self._shell.enableProjectChanges( not busy )

    def _inputReady(self, nRoles):
        slot = self.dataSelectionApplet.topLevelOperator.ImageGroup
        if len(slot) > 0:
            input_ready = True
            for sub in slot:
                input_ready = input_ready and \
                    all([sub[i].ready() for i in range(nRoles)])
        else:
            input_ready = False

        return input_ready

    def postprocessClusterSubResult(self, roi, result, blockwise_fileset):
        """
        This function is only used by special cluster scripts.

        When the batch-processing mechanism was rewritten, this function broke.
        It could probably be fixed with minor changes.

        Filters the region features of one sub-block down to the objects that
        lie inside the block and carry the positive label, translates their
        coordinates to file coordinates, and writes them both pickled into
        the block's hdf5 file and as a "point cloud" csv next to it.

        :param roi: roi of the sub-block; only its start ``roi[0]`` is used.
        :param result: not used.
        :param blockwise_fileset: provides the block description, the hdf5
            file and the dataset directory of the block.
        """
        # TODO: Here, we hard-code to select from the first lane only.
        opBatchClassify = self.opBatchClassify[0]

        # NOTE(review): 'io_uti' looks like a typo for the package name
        # 'io_util' -- confirm against the installed lazyflow.
        from lazyflow.utility.io_uti.blockwiseFileset import vectorized_pickle_dumps
        # Assume that roi always starts as a multiple of the blockshape
        block_shape = opBatchClassify.get_blockshape()
        assert all(block_shape == blockwise_fileset.description.sub_block_shape), "block shapes don't match"
        assert all((roi[0] % block_shape) == 0), "Sub-blocks must exactly correspond to the blockwise object classification blockshape"
        # Floor division: the start is a multiple of the block shape (asserted
        # above), and the block index must stay integral under Python 3.
        sub_block_index = roi[0] // blockwise_fileset.description.sub_block_shape

        sub_block_start = sub_block_index
        sub_block_stop = sub_block_start + 1
        sub_block_roi = (sub_block_start, sub_block_stop)

        # FIRST, remove all objects that lie outside the block (i.e. remove the ones in the halo)
        region_features = opBatchClassify.BlockwiseRegionFeatures( *sub_block_roi ).wait()
        region_features_dict = region_features.flat[0]
        region_centers = region_features_dict['Default features']['RegionCenter']

        opBlockPipeline = opBatchClassify._blockPipelines[ tuple(roi[0]) ]

        # Compute the block offset within the image coordinates
        halo_roi = opBlockPipeline._halo_roi

        translated_region_centers = region_centers + halo_roi[0][1:-1]

        # TODO: If this is too slow, vectorize this
        mask = numpy.zeros( region_centers.shape[0], dtype=numpy.bool_ )
        for index, translated_region_center in enumerate(translated_region_centers):
            # FIXME: Here we assume t=0 and c=0
            mask[index] = opBatchClassify.is_in_block( roi[0], (0,) + tuple(translated_region_center) + (0,) )

        # Always exclude the first object (it's the background??)
        mask[0] = False

        # Remove all 'negative' predictions, emit only 'positive' predictions
        # FIXME: Don't hardcode this?
        POSITIVE_LABEL = 2
        objectwise_predictions = opBlockPipeline.ObjectwisePredictions([]).wait()[0]
        assert objectwise_predictions.shape == mask.shape
        mask[objectwise_predictions != POSITIVE_LABEL] = False

        # Apply the mask to every feature array of every feature group.
        filtered_features = {}
        for feature_group, feature_dict in region_features_dict.items():
            filtered_group = filtered_features[feature_group] = {}
            for feature_name, feature_array in feature_dict.items():
                filtered_group[feature_name] = feature_array[mask]

        # SECOND, translate from block-local coordinates to global (file) coordinates.
        # Unfortunately, we've got multiple translations to perform here:
        # Coordinates in the region features are relative to their own block INCLUDING HALO,
        #  so we need to add the start of the block-with-halo as an offset.
        # BUT the image itself may be offset relative to the BlockwiseFileset coordinates
        #  (due to the view_origin setting), so we also need to add an offset for that, too

        # Get the image offset relative to the file coordinates
        image_offset = blockwise_fileset.description.view_origin

        total_offset_5d = halo_roi[0] + image_offset
        total_offset_3d = total_offset_5d[1:-1]

        filtered_features["Default features"]["RegionCenter"] += total_offset_3d
        filtered_features["Default features"]["Coord<Minimum>"] += total_offset_3d
        filtered_features["Default features"]["Coord<Maximum>"] += total_offset_3d

        # Finally, write the features to hdf5
        h5File = blockwise_fileset.getOpenHdf5FileForBlock( roi[0] )
        if 'pickled_region_features' in h5File:
            del h5File['pickled_region_features']

        # Must use str dtype
        # (special_dtype replaces the new_vlen helper removed from h5py.)
        dtype = h5py.special_dtype(vlen=str)
        dataset = h5File.create_dataset( 'pickled_region_features', shape=(1,), dtype=dtype )
        pickled_features = vectorized_pickle_dumps(numpy.array((filtered_features,)))
        dataset[0] = pickled_features

        object_centers_xyz = filtered_features["Default features"]["RegionCenter"].astype(int)
        object_min_coords_xyz = filtered_features["Default features"]["Coord<Minimum>"].astype(int)
        object_max_coords_xyz = filtered_features["Default features"]["Coord<Maximum>"].astype(int)
        object_sizes = filtered_features["Default features"]["Count"][:,0].astype(int)

        # Also, write out selected features as a 'point cloud' csv file.
        # (Store the csv file next to this block's h5 file.)
        dataset_directory = blockwise_fileset.getDatasetDirectory(roi[0])
        pointcloud_path = os.path.join( dataset_directory, "block-pointcloud.csv" )

        logger.info("Writing to csv: {}".format( pointcloud_path ))
        with open(pointcloud_path, "w") as fout:
            csv_writer = csv.DictWriter(fout, OUTPUT_COLUMNS, **CSV_FORMAT)
            csv_writer.writeheader()

            # One csv row per remaining object.
            for obj_id in range(len(object_sizes)):
                fields = {}
                fields["x_px"], fields["y_px"], fields["z_px"], = object_centers_xyz[obj_id]
                fields["min_x_px"], fields["min_y_px"], fields["min_z_px"], = object_min_coords_xyz[obj_id]
                fields["max_x_px"], fields["max_y_px"], fields["max_z_px"], = object_max_coords_xyz[obj_id]
                fields["size_px"] = object_sizes[obj_id]

                csv_writer.writerow( fields )

        logger.info("FINISHED csv export")
    def __init__(self, shell, headless, workflow_cmdline_args,
                 project_creation_args, *args, **kwargs):
        """
        Set up the workflow: parse the workflow-specific options, create the
        applets, connect the export operator, and split the remaining
        command-line arguments between the export and batch-processing applets.

        ``shell``, ``headless``, ``workflow_cmdline_args`` and
        ``project_creation_args`` are forwarded unchanged to the base-class
        constructor, together with a freshly created ``Graph`` and any extra
        ``*args`` / ``**kwargs``.

        ``self.filter_implementation`` is always taken from
        ``project_creation_args``; every other option is taken from
        ``workflow_cmdline_args``.
        """
        # Create a graph to be shared by all operators
        graph = Graph()
        super(PixelClassificationWorkflow, self).__init__(
            shell,
            headless,
            workflow_cmdline_args,
            project_creation_args,
            graph=graph,
            *args,
            **kwargs)
        self.stored_classifier = None
        self._applets = []
        self._workflow_cmdline_args = workflow_cmdline_args

        # Parse workflow-specific command-line args
        parser = argparse.ArgumentParser()
        parser.add_argument(
            '--filter',
            help="pixel feature filter implementation.",
            choices=['Original', 'Refactored', 'Interpolated'],
            default='Original')
        parser.add_argument(
            '--print-labels-by-slice',
            help="Print the number of labels for each Z-slice of each image.",
            action="store_true")
        parser.add_argument(
            '--label-search-value',
            help="If provided, only this value is considered when using --print-labels-by-slice",
            default=0,
            type=int)
        parser.add_argument(
            '--generate-random-labels',
            help="Add random labels to the project file.",
            action="store_true")
        parser.add_argument(
            '--random-label-value',
            help="The label value to use injecting random labels",
            default=1,
            type=int)
        parser.add_argument(
            '--random-label-count',
            help="The number of random labels to inject via --generate-random-labels",
            default=2000,
            type=int)
        parser.add_argument(
            '--retrain',
            help="Re-train the classifier based on labels stored in project file, and re-save.",
            action="store_true")
        parser.add_argument(
            '--tree-count',
            help='Number of trees for Vigra RF classifier.',
            type=int)
        parser.add_argument(
            '--variable-importance-path',
            help='Location of variable-importance table.',
            type=str)
        parser.add_argument(
            '--label-proportion',
            help='Proportion of feature-pixels used to train the classifier.',
            type=float)

        # Parse the creation args: These were saved to the project file when
        # this project was first created.  Only the filter choice is kept.
        parsed_creation_args, _ = parser.parse_known_args(
            project_creation_args)
        self.filter_implementation = parsed_creation_args.filter

        # Parse the cmdline args for the current session.
        parsed_args, unused_args = parser.parse_known_args(
            workflow_cmdline_args)
        self.print_labels_by_slice = parsed_args.print_labels_by_slice
        self.label_search_value = parsed_args.label_search_value
        self.generate_random_labels = parsed_args.generate_random_labels
        self.random_label_value = parsed_args.random_label_value
        self.random_label_count = parsed_args.random_label_count
        self.retrain = parsed_args.retrain
        self.tree_count = parsed_args.tree_count
        self.variable_importance_path = parsed_args.variable_importance_path
        self.label_proportion = parsed_args.label_proportion

        # The filter stored at creation time wins; a differing value given
        # for this session is only reported, never applied.
        if parsed_args.filter and parsed_args.filter != parsed_creation_args.filter:
            logger.error(
                "Ignoring new --filter setting.  Filter implementation cannot be changed after initial project creation."
            )

        # Applets for training (interactive) workflow
        self.dataSelectionApplet = self.createDataSelectionApplet()
        opDataSelection = self.dataSelectionApplet.topLevelOperator

        # see role constants, above
        opDataSelection.DatasetRoles.setValue(
            PixelClassificationWorkflow.ROLE_NAMES)

        self.featureSelectionApplet = self.createFeatureSelectionApplet()

        self.pcApplet = self.createPixelClassificationApplet()
        opClassify = self.pcApplet.topLevelOperator

        self.dataExportApplet = PixelClassificationDataExportApplet(
            self, "Prediction Export")
        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.PmapColors.connect(opClassify.PmapColors)
        opDataExport.LabelNames.connect(opClassify.LabelNames)
        opDataExport.WorkingDirectory.connect(opDataSelection.WorkingDirectory)
        opDataExport.SelectionNames.setValue(self.EXPORT_NAMES)

        # Expose for shell
        self._applets.append(self.dataSelectionApplet)
        self._applets.append(self.featureSelectionApplet)
        self._applets.append(self.pcApplet)
        self._applets.append(self.dataExportApplet)

        # Hook our own pre/post-processing into the export applet.
        self.dataExportApplet.prepare_for_entire_export = self.prepare_for_entire_export
        self.dataExportApplet.post_process_entire_export = self.post_process_entire_export

        self.batchProcessingApplet = BatchProcessingApplet(
            self, "Batch Processing", self.dataSelectionApplet,
            self.dataExportApplet)

        self._applets.append(self.batchProcessingApplet)
        if unused_args:
            # We parse the export setting args first.  All remaining args are
            # considered input files by the input applet.
            self._batch_export_args, unused_args = \
                self.dataExportApplet.parse_known_cmdline_args(unused_args)
            self._batch_input_args, unused_args = \
                self.batchProcessingApplet.parse_known_cmdline_args(unused_args)
        else:
            self._batch_input_args = None
            self._batch_export_args = None

        if unused_args:
            # Logger.warn is a deprecated alias; warning() is the documented name.
            logger.warning("Unused command-line args: {}".format(unused_args))
class ConservationTrackingWorkflowBase(Workflow):
    """
    Base workflow for automatic (conservation) tracking.

    It chains data selection, optional thresholding, optional optical
    translation, object feature extraction, optional division detection,
    object count classification, tracking, export and batch processing.

    The methods below read ``self.fromBinary`` and ``self.withOptTrans`` but
    this class never assigns them -- presumably subclasses define them
    (TODO confirm).
    """
    workflowName = "Automatic Tracking Workflow (Conservation Tracking) BASE"

    def __init__(self, shell, headless, workflow_cmdline_args,
                 project_creation_args, *args, **kwargs):
        """
        Create all applets, configure their top-level operators and parse the
        export / batch command-line arguments.

        A ``graph`` keyword argument, if given, is used as the operator graph;
        otherwise a new ``Graph`` is created.  All other arguments are
        forwarded to the base-class constructor.
        """
        # Take 'graph' out of kwargs so it is not passed to the base class twice.
        graph = kwargs.pop('graph') if 'graph' in kwargs else Graph()
        super(ConservationTrackingWorkflowBase, self).__init__(
            shell,
            headless,
            workflow_cmdline_args,
            project_creation_args,
            graph=graph,
            *args,
            **kwargs)

        data_instructions = 'Use the "Raw Data" tab to load your intensity image(s).\n\n'
        if self.fromBinary:
            data_instructions += 'Use the "Binary Image" tab to load your segmentation image(s).'
        else:
            data_instructions += 'Use the "Prediction Maps" tab to load your pixel-wise probability image(s).'

        # Variables to store division and cell classifiers to prevent
        # retraining every time batch processing runs
        self.stored_division_classifier = None
        self.stored_cell_classifier = None

        ## Create applets
        self.dataSelectionApplet = DataSelectionApplet(
            self,
            "Input Data",
            "Input Data",
            forceAxisOrder=['txyzc'],
            instructionText=data_instructions,
            max_lanes=None)

        opDataSelection = self.dataSelectionApplet.topLevelOperator
        if self.fromBinary:
            opDataSelection.DatasetRoles.setValue(
                ['Raw Data', 'Segmentation Image'])
        else:
            opDataSelection.DatasetRoles.setValue(
                ['Raw Data', 'Prediction Maps'])

        # Thresholding is only needed when the input is not already binary.
        if not self.fromBinary:
            self.thresholdTwoLevelsApplet = ThresholdTwoLevelsApplet(
                self, "Threshold and Size Filter", "ThresholdTwoLevels")
        if self.withOptTrans:
            self.opticalTranslationApplet = OpticalTranslationApplet(
                workflow=self)

        self.objectExtractionApplet = TrackingFeatureExtractionApplet(
            workflow=self,
            interactive=False,
            name="Object Feature Computation")

        opObjectExtraction = self.objectExtractionApplet.topLevelOperator
        opObjectExtraction.FeatureNamesVigra.setValue(
            configConservation.allFeaturesObjectCount)

        self.divisionDetectionApplet = self._createDivisionDetectionApplet(
            configConservation.selectedFeaturesDiv)  # Might be None

        if self.divisionDetectionApplet:
            feature_dict_division = {
                config.features_division_name: {
                    name: {}
                    for name in config.division_features
                }
            }
            opObjectExtraction.FeatureNamesDivision.setValue(
                feature_dict_division)

            # NOTE(review): selected_features_div is built here but never
            # read again in this method; the operator below is configured
            # from configConservation.selectedFeaturesDiv instead.
            selected_features_div = {}
            for plugin_name in config.selected_features_division.keys():
                selected_features_div[plugin_name] = {
                    name: {}
                    for name in config.selected_features_division[plugin_name]
                }
            # FIXME: do not hard code this
            for i in range(config.n_best_successors):
                name = 'SquaredDistances_' + str(i)
                selected_features_div[config.features_division_name][name] = {}

            opDivisionDetection = self.divisionDetectionApplet.topLevelOperator
            opDivisionDetection.SelectedFeatures.setValue(
                configConservation.selectedFeaturesDiv)
            opDivisionDetection.LabelNames.setValue(
                ['Not Dividing', 'Dividing'])
            opDivisionDetection.AllowDeleteLabels.setValue(False)
            opDivisionDetection.AllowAddLabel.setValue(False)
            opDivisionDetection.EnableLabelTransfer.setValue(False)

        self.cellClassificationApplet = ObjectClassificationApplet(
            workflow=self,
            name="Object Count Classification",
            projectFileGroupName="CountClassification",
            selectedFeatures=configConservation.selectedFeaturesObjectCount)

        # NOTE(review): like selected_features_div above, this dict is built
        # but not used afterwards in this method.
        selected_features_objectcount = {}
        for plugin_name in config.selected_features_objectcount.keys():
            selected_features_objectcount[plugin_name] = {
                name: {}
                for name in config.selected_features_objectcount[plugin_name]
            }

        opCellClassification = self.cellClassificationApplet.topLevelOperator
        opCellClassification.SelectedFeatures.setValue(
            configConservation.selectedFeaturesObjectCount)
        # 'False Detection', '1 Object', '2 Objects', ..., '9 Objects'
        opCellClassification.SuggestedLabelNames.setValue(
            ['False Detection', str(1) + ' Object'] +
            [str(i) + ' Objects' for i in range(2, 10)])
        opCellClassification.AllowDeleteLastLabelOnly.setValue(True)
        opCellClassification.EnableLabelTransfer.setValue(False)

        self.trackingApplet = ConservationTrackingApplet(workflow=self)

        self.default_export_filename = '{dataset_dir}/{nickname}-exported_data.csv'
        self.dataExportApplet = TrackingBaseDataExportApplet(
            self,
            "Tracking Result Export",
            default_export_filename=self.default_export_filename)

        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.SelectionNames.setValue(
            ['Object-Identities', 'Tracking-Result', 'Merger-Result'])
        opDataExport.WorkingDirectory.connect(opDataSelection.WorkingDirectory)

        # Extra configuration for object export table (as CSV table or HDF5 table)
        opTracking = self.trackingApplet.topLevelOperator
        self.dataExportApplet.set_exporting_operator(opTracking)
        self.dataExportApplet.prepare_lane_for_export = self.prepare_lane_for_export
        self.dataExportApplet.post_process_lane_export = self.post_process_lane_export

        # table only export is just available for the pgmlink backend,
        # hytra uses the CSV plugin instead
        try:
            import hytra
        except ImportError:
            # Export table only, without volumes
            self.dataExportApplet.includeTableOnlyOption()

        # configure export settings
        settings = {
            'file path': self.default_export_filename,
            'compression': {},
            'file type': 'csv'
        }
        selected_features = [
            'Count', 'RegionCenter', 'RegionRadii', 'RegionAxes'
        ]
        opTracking.ExportSettings.setValue((settings, selected_features))

        # Applets in the order they are exposed through the `applets` property.
        self._applets = []
        self._applets.append(self.dataSelectionApplet)
        if not self.fromBinary:
            self._applets.append(self.thresholdTwoLevelsApplet)
        if self.withOptTrans:
            self._applets.append(self.opticalTranslationApplet)
        self._applets.append(self.objectExtractionApplet)

        if self.divisionDetectionApplet:
            self._applets.append(self.divisionDetectionApplet)

        self.batchProcessingApplet = BatchProcessingApplet(
            self, "Batch Processing", self.dataSelectionApplet,
            self.dataExportApplet)

        self._applets.append(self.cellClassificationApplet)
        self._applets.append(self.trackingApplet)
        self._applets.append(self.dataExportApplet)
        self._applets.append(self.batchProcessingApplet)

        # Parse export and batch command-line arguments for headless mode
        if workflow_cmdline_args:
            # Parse the export settings first and hand only the leftovers to
            # the batch parser, so that arguments already consumed by the
            # export parser are neither re-parsed as batch inputs nor
            # reported as unused below.
            self._data_export_args, unused_args = \
                self.dataExportApplet.parse_known_cmdline_args(
                    workflow_cmdline_args)
            self._batch_input_args, unused_args = \
                self.batchProcessingApplet.parse_known_cmdline_args(
                    unused_args)
        else:
            unused_args = None
            self._data_export_args = None
            self._batch_input_args = None

        if unused_args:
            # Logger.warn is a deprecated alias; warning() is the documented name.
            logger.warning("Unused command-line args: {}".format(unused_args))

    @property
    def applets(self):
        """Return the list of applets assembled in ``__init__``."""
        return self._applets

    def _createDivisionDetectionApplet(self, selectedFeatures=None):
        """
        Create the object classification applet used for division detection.

        ``selectedFeatures`` is passed to the applet; when omitted, a new
        empty dict is used for each call (rather than one dict object shared
        between all calls).
        """
        if selectedFeatures is None:
            selectedFeatures = {}
        return ObjectClassificationApplet(
            workflow=self,
            name="Division Detection (optional)",
            projectFileGroupName="DivisionDetection",
            selectedFeatures=selectedFeatures)

    @property
    def imageNameListSlot(self):
        """Return the ``ImageName`` slot of the data selection operator."""
        return self.dataSelectionApplet.topLevelOperator.ImageName

    def prepareForNewLane(self, laneIndex):
        """
        Remember the current division and cell classifiers.

        A classifier is stored only if its cache output is ready and the
        cache is not dirty; otherwise the stored value is reset to None.
        ``laneIndex`` is not used here.
        """
        # Store division and cell classifiers
        if self.divisionDetectionApplet:
            opDivisionClassification = self.divisionDetectionApplet.topLevelOperator
            division_cache = opDivisionClassification.classifier_cache
            if division_cache.Output.ready() and not division_cache._dirty:
                self.stored_division_classifier = division_cache.Output.value
            else:
                self.stored_division_classifier = None

        opCellClassification = self.cellClassificationApplet.topLevelOperator
        cell_cache = opCellClassification.classifier_cache
        if cell_cache.Output.ready() and not cell_cache._dirty:
            self.stored_cell_classifier = cell_cache.Output.value
        else:
            self.stored_cell_classifier = None

    def handleNewLanesAdded(self):
        """
        If new lanes were added, then we invalidated our classifiers unnecessarily.
        Here, we can restore the classifier so it doesn't need to be retrained.
        """
        # If we have a stored division classifier, restore it into the workflow now.
        if self.stored_division_classifier:
            opDivisionClassification = self.divisionDetectionApplet.topLevelOperator
            opDivisionClassification.classifier_cache.forceValue(
                self.stored_division_classifier)
            # Release reference
            self.stored_division_classifier = None

        # If we have a stored cell classifier, restore it into the workflow now.
        if self.stored_cell_classifier:
            opCellClassification = self.cellClassificationApplet.topLevelOperator
            opCellClassification.classifier_cache.forceValue(
                self.stored_cell_classifier)
            # Release reference
            self.stored_cell_classifier = None

    def connectLane(self, laneIndex):
        """
        Connect the per-lane operators of all applets for lane ``laneIndex``.

        Raw and binary inputs are each routed through an ``OpReorderAxes``
        set to "txyzc" before they are handed to the downstream operators.
        """
        opData = self.dataSelectionApplet.topLevelOperator.getLane(laneIndex)
        if not self.fromBinary:
            opTwoLevelThreshold = self.thresholdTwoLevelsApplet.topLevelOperator.getLane(
                laneIndex)
        if self.withOptTrans:
            opOptTranslation = self.opticalTranslationApplet.topLevelOperator.getLane(
                laneIndex)
        opObjExtraction = self.objectExtractionApplet.topLevelOperator.getLane(
            laneIndex)

        if self.divisionDetectionApplet:
            opDivDetection = self.divisionDetectionApplet.topLevelOperator.getLane(
                laneIndex)

        opCellClassification = self.cellClassificationApplet.topLevelOperator.getLane(
            laneIndex)
        opTracking = self.trackingApplet.topLevelOperator.getLane(laneIndex)
        opDataExport = self.dataExportApplet.topLevelOperator.getLane(
            laneIndex)

        op5Raw = OpReorderAxes(parent=self)
        op5Raw.AxisOrder.setValue("txyzc")
        op5Raw.Input.connect(opData.ImageGroup[0])

        # The binary source is either the thresholding output or, when the
        # input is already binary, the second input image itself.
        if not self.fromBinary:
            opTwoLevelThreshold.InputImage.connect(opData.ImageGroup[1])
            opTwoLevelThreshold.RawInput.connect(
                opData.ImageGroup[0])  # Used for display only
            binarySrc = opTwoLevelThreshold.CachedOutput
        else:
            binarySrc = opData.ImageGroup[1]

        # Use Op5ifyers for both input datasets such that they are guaranteed to
        # have the same axis order after thresholding
        op5Binary = OpReorderAxes(parent=self)
        op5Binary.AxisOrder.setValue("txyzc")
        op5Binary.Input.connect(binarySrc)

        if self.withOptTrans:
            opOptTranslation.RawImage.connect(op5Raw.Output)
            opOptTranslation.BinaryImage.connect(op5Binary.Output)

        # # Connect operators ##
        opObjExtraction.RawImage.connect(op5Raw.Output)
        opObjExtraction.BinaryImage.connect(op5Binary.Output)
        if self.withOptTrans:
            opObjExtraction.TranslationVectors.connect(
                opOptTranslation.TranslationVectors)

        if self.divisionDetectionApplet:
            opDivDetection.BinaryImages.connect(op5Binary.Output)
            opDivDetection.RawImages.connect(op5Raw.Output)
            opDivDetection.SegmentationImages.connect(
                opObjExtraction.LabelImage)
            opDivDetection.ObjectFeatures.connect(
                opObjExtraction.RegionFeaturesAll)
            opDivDetection.ComputedFeatureNames.connect(
                opObjExtraction.ComputedFeatureNamesAll)

        opCellClassification.BinaryImages.connect(op5Binary.Output)
        opCellClassification.RawImages.connect(op5Raw.Output)
        opCellClassification.SegmentationImages.connect(
            opObjExtraction.LabelImage)
        opCellClassification.ObjectFeatures.connect(
            opObjExtraction.RegionFeaturesVigra)
        opCellClassification.ComputedFeatureNames.connect(
            opObjExtraction.FeatureNamesVigra)

        if self.divisionDetectionApplet:
            opTracking.ObjectFeaturesWithDivFeatures.connect(
                opObjExtraction.RegionFeaturesAll)
            opTracking.ComputedFeatureNamesWithDivFeatures.connect(
                opObjExtraction.ComputedFeatureNamesAll)
            opTracking.DivisionProbabilities.connect(
                opDivDetection.Probabilities)

        opTracking.RawImage.connect(op5Raw.Output)
        opTracking.LabelImage.connect(opObjExtraction.LabelImage)
        opTracking.ObjectFeatures.connect(opObjExtraction.RegionFeaturesVigra)
        opTracking.ComputedFeatureNames.connect(
            opObjExtraction.FeatureNamesVigra)
        opTracking.DetectionProbabilities.connect(
            opCellClassification.Probabilities)
        opTracking.NumLabels.connect(opCellClassification.NumLabels)

        # Export inputs follow the order of the selection names set in __init__.
        opDataExport.Inputs.resize(3)
        opDataExport.Inputs[0].connect(opTracking.RelabeledImage)
        opDataExport.Inputs[1].connect(opTracking.Output)
        opDataExport.Inputs[2].connect(opTracking.MergerOutput)
        opDataExport.RawData.connect(op5Raw.Output)
        opDataExport.RawDatasetInfo.connect(opData.DatasetGroup[0])

    def prepare_lane_for_export(self, lane_index):
        """
        Run tracking over the full extent of lane ``lane_index`` before export.

        Enables cache bypass on the object extraction operator, remembers the
        axis ranges currently stored in the tracking parameters (or the full
        ranges when none are stored) in ``self.prev_*_range``, and calls
        ``track()`` with the full time/x/y/z ranges and the remaining stored
        parameters.
        """
        # Bypass cache on headless mode and batch processing mode
        self.objectExtractionApplet.topLevelOperator[
            lane_index].BypassModeEnabled.setValue(True)

        opTrackingLane = self.trackingApplet.topLevelOperator[lane_index]

        # Get axes info (shape is indexed as t, x, y, z)
        raw_shape = opTrackingLane.RawImage.meta.shape
        maxt = raw_shape[0]
        maxx = raw_shape[1]
        maxy = raw_shape[2]
        maxz = raw_shape[3]
        time_enum = range(maxt)
        x_range = (0, maxx)
        y_range = (0, maxy)
        z_range = (0, maxz)

        # More than one z-slice means the data is treated as 3D.
        ndim = 3 if (z_range[1] - z_range[0]) > 1 else 2

        parameters = self.trackingApplet.topLevelOperator.Parameters.value

        # Save state of axis ranges
        self.prev_time_range = (parameters['time_range']
                                if 'time_range' in parameters else time_enum)
        self.prev_x_range = (parameters['x_range']
                             if 'x_range' in parameters else x_range)
        self.prev_y_range = (parameters['y_range']
                             if 'y_range' in parameters else y_range)
        self.prev_z_range = (parameters['z_range']
                             if 'z_range' in parameters else z_range)

        opTrackingLane.track(
            time_range=time_enum,
            x_range=x_range,
            y_range=y_range,
            z_range=z_range,
            size_range=parameters['size_range'],
            x_scale=parameters['scales'][0],
            y_scale=parameters['scales'][1],
            z_scale=parameters['scales'][2],
            maxDist=parameters['maxDist'],
            maxObj=parameters['maxObj'],
            divThreshold=parameters['divThreshold'],
            avgSize=parameters['avgSize'],
            withTracklets=parameters['withTracklets'],
            sizeDependent=parameters['sizeDependent'],
            divWeight=parameters['divWeight'],
            transWeight=parameters['transWeight'],
            withDivisions=parameters['withDivisions'],
            withOpticalCorrection=parameters['withOpticalCorrection'],
            withClassifierPrior=parameters['withClassifierPrior'],
            ndim=ndim,
            withMergerResolution=parameters['withMergerResolution'],
            borderAwareWidth=parameters['borderAwareWidth'],
            withArmaCoordinates=parameters['withArmaCoordinates'],
            cplex_timeout=parameters['cplex_timeout'],
            appearance_cost=parameters['appearanceCost'],
            disappearance_cost=parameters['disappearanceCost'],
            max_nearest_neighbors=parameters['max_nearest_neighbors'],
            numFramesPerSplit=parameters['numFramesPerSplit'],
            force_build_hypotheses_graph=False,
            withBatchProcessing=True)

    def post_process_lane_export(self, lane_index, checkOverwriteFiles=False):
        """
        Export the tracking result of lane ``lane_index``.

        If the selected export source is the plugin-only source, the selected
        plugin is looked up and run; ``checkOverwriteFiles`` is forwarded to
        the tracking operator's ``exportPlugin()``.  Otherwise the object
        table is exported using the lane's table export settings, after which
        cache bypass is disabled again and the axis ranges saved by
        ``prepare_lane_for_export`` are written back into the parameters.

        Returns False when the plugin export reports a falsy status; returns
        None in every other case.
        """
        # NOTE(review): the original comment here described a `time`
        # parameter (0 = only check for files that would be overwritten,
        # otherwise really export).  No such parameter exists any more;
        # presumably `checkOverwriteFiles` took over that role -- confirm.

        # FIXME: This probably only works for the non-blockwise export slot.
        #        We should assert that the user isn't using the blockwise slot.

        # Plugin export if selected
        logger.info(
            "Export source is: " +
            self.dataExportApplet.topLevelOperator.SelectedExportSource.value)

        if self.dataExportApplet.topLevelOperator.SelectedExportSource.value == OpTrackingBaseDataExport.PluginOnlyName:
            logger.info("Export source plugin selected!")
            selectedPlugin = self.dataExportApplet.topLevelOperator.SelectedPlugin.value

            exportPluginInfo = pluginManager.getPluginByName(
                selectedPlugin, category="TrackingExportFormats")
            if exportPluginInfo is None:
                # Report the plugin name that was requested; exportPluginInfo
                # itself is always None on this branch.
                logger.error("Could not find selected plugin %s",
                             selectedPlugin)
            else:
                exportPlugin = exportPluginInfo.plugin_object
                logger.info("Exporting tracking result using %s" %
                            selectedPlugin)
                name_format = self.dataExportApplet.topLevelOperator.getLane(
                    lane_index).OutputFilenameFormat.value
                partially_formatted_name = self.getPartiallyFormattedName(
                    lane_index, name_format)

                if exportPlugin.exportsToFile:
                    filename = partially_formatted_name
                    # A path without a file name gets a default one appended.
                    if os.path.basename(filename) == '':
                        filename = os.path.join(filename, 'pluginExport.txt')
                else:
                    # The plugin gets the containing directory instead.
                    filename = os.path.dirname(partially_formatted_name)

                if filename is None or len(str(filename)) == 0:
                    logger.error(
                        "Cannot export from plugin with empty output filename")
                    return

                exportStatus = self.trackingApplet.topLevelOperator.getLane(
                    lane_index).exportPlugin(filename, exportPlugin,
                                             checkOverwriteFiles)
                if not exportStatus:
                    return False
                logger.info("Export done")

            return

        # CSV Table export (only if plugin was not selected)
        settings, selected_features = self.trackingApplet.topLevelOperator.getLane(
            lane_index).get_table_export_settings()
        if settings:
            self.dataExportApplet.progressSignal.emit(0)
            name_format = settings['file path']
            partially_formatted_name = self.getPartiallyFormattedName(
                lane_index, name_format)
            settings['file path'] = partially_formatted_name

            # FIXME: Even in non-headless mode, we can't show the gui because
            #        we're running in a non-main thread.  That's not a huge
            #        deal, because there's still a progress bar for the
            #        overall export.
            req = self.trackingApplet.topLevelOperator.getLane(
                lane_index).export_object_data(lane_index, show_gui=False)

            req.wait()
            self.dataExportApplet.progressSignal.emit(100)

            # Restore option to bypass cache to false
            self.objectExtractionApplet.topLevelOperator[
                lane_index].BypassModeEnabled.setValue(False)

            # Restore state of axis ranges
            parameters = self.trackingApplet.topLevelOperator.Parameters.value
            parameters['time_range'] = self.prev_time_range
            parameters['x_range'] = self.prev_x_range
            parameters['y_range'] = self.prev_y_range
            parameters['z_range'] = self.prev_z_range

    def getPartiallyFormattedName(self, lane_index, path_format_string):
        """
        Take the format string for the output file, fill in the
        ``dataset_dir``, ``nickname`` and ``result_type`` placeholders, and
        return the result.
        """
        raw_dataset_info = self.dataSelectionApplet.topLevelOperator.DatasetGroup[
            lane_index][0].value
        project_path = self.shell.projectManager.currentProjectPath
        project_dir = os.path.dirname(project_path)
        dataset_dir = PathComponents(
            raw_dataset_info.filePath).externalDirectory
        # The dataset directory is made absolute with the project directory as cwd.
        abs_dataset_dir = make_absolute(dataset_dir, cwd=project_dir)
        known_keys = {}
        known_keys['dataset_dir'] = abs_dataset_dir
        nickname = raw_dataset_info.nickname.replace('*', '')
        # If the nickname contains os.path.pathsep, only the file name base
        # of the part before the first separator is kept.
        if os.path.pathsep in nickname:
            nickname = PathComponents(nickname.split(
                os.path.pathsep)[0]).fileNameBase
        known_keys['nickname'] = nickname
        # NOTE(review): the lane operator fetched here is not used below.
        opDataExport = self.dataExportApplet.topLevelOperator.getLane(
            lane_index)
        known_keys[
            'result_type'] = self.dataExportApplet.topLevelOperator.SelectedPlugin._value
        # use partial formatting to fill in non-coordinate name fields
        partially_formatted_name = format_known_keys(path_format_string,
                                                     known_keys)
        return partially_formatted_name

    def _inputReady(self, nRoles):
        """
        Return True if there is at least one lane and, in every lane, the
        first ``nRoles`` image slots are ready; otherwise return False.
        """
        slot = self.dataSelectionApplet.topLevelOperator.ImageGroup
        if len(slot) > 0:
            input_ready = True
            for sub in slot:
                input_ready = input_ready and \
                    all([sub[i].ready() for i in range(nRoles)])
        else:
            input_ready = False

        return input_ready

    def onProjectLoaded(self, projectManager):
        """
        Overridden from Workflow base class.  Called by the Project Manager.

        If the user provided command-line arguments, use them to configure
        the workflow inputs and output settings.
        """

        # Configure the data export operator.
        if self._data_export_args:
            self.dataExportApplet.configure_operator_with_parsed_args(
                self._data_export_args)

        # Configure headless mode.
        if self._headless and self._batch_input_args and self._data_export_args:
            logger.info("Beginning Batch Processing")
            self.batchProcessingApplet.run_export_from_parsed_args(
                self._batch_input_args)
            logger.info("Completed Batch Processing")

    def handleAppletStateUpdateRequested(self):
        """
        Overridden from Workflow base class
        Called when an applet has fired the :py:attr:`Applet.statusUpdateSignal`

        Enables or disables each applet in the shell depending on whether the
        stages before it are ready and whether any applet is busy.
        """
        # If no data, nothing else is ready.
        input_ready = self._inputReady(2) and not self.dataSelectionApplet.busy

        if not self.fromBinary:
            opThresholding = self.thresholdTwoLevelsApplet.topLevelOperator
            thresholdingOutput = opThresholding.CachedOutput
            thresholding_ready = input_ready and \
                len(thresholdingOutput) > 0
        else:
            # Without a thresholding stage, readiness follows the input.
            thresholding_ready = input_ready

        opObjectExtraction = self.objectExtractionApplet.topLevelOperator
        objectExtractionOutput = opObjectExtraction.ComputedFeatureNamesAll
        features_ready = thresholding_ready and \
            len(objectExtractionOutput) > 0

        objectCountClassifier_ready = features_ready
        tracking_ready = objectCountClassifier_ready

        busy = False
        busy |= self.dataSelectionApplet.busy
        busy |= self.trackingApplet.busy
        busy |= self.dataExportApplet.busy
        busy |= self.batchProcessingApplet.busy
        self._shell.enableProjectChanges(not busy)

        self._shell.setAppletEnabled(self.dataSelectionApplet, not busy)
        if not self.fromBinary:
            self._shell.setAppletEnabled(self.thresholdTwoLevelsApplet,
                                         input_ready and not busy)

        if self.divisionDetectionApplet:
            self._shell.setAppletEnabled(self.divisionDetectionApplet,
                                         features_ready and not busy)

        self._shell.setAppletEnabled(self.objectExtractionApplet,
                                     thresholding_ready and not busy)
        self._shell.setAppletEnabled(self.cellClassificationApplet,
                                     features_ready and not busy)
        self._shell.setAppletEnabled(self.trackingApplet,
                                     objectCountClassifier_ready and not busy)
        # Export and batch processing additionally require the first export
        # input of the first lane to be ready.
        self._shell.setAppletEnabled(
            self.dataExportApplet,
            tracking_ready and not busy and
            self.dataExportApplet.topLevelOperator.Inputs[0][0].ready())
        self._shell.setAppletEnabled(
            self.batchProcessingApplet,
            tracking_ready and not busy and
            self.dataExportApplet.topLevelOperator.Inputs[0][0].ready())
예제 #11
0
class NNClassificationWorkflow(Workflow):
    """
    Workflow for the Neural Network Classification Applet.

    The workflow owns four applets, exposed to the shell in this order:
    data selection, NN classification, data export and batch processing.
    """
    workflowName = "Neural Network Classification"
    workflowDescription = "This is obviously self-explanatory."
    defaultAppletIndex = 0 # show DataSelection by default

    # Index of the raw-data role within ROLE_NAMES.
    DATA_ROLE_RAW = 0
    ROLE_NAMES = ['Raw Data']
    EXPORT_NAMES = ['Probabilities']

    @property
    def applets(self):
        """
        Return the list of applets that are owned by this workflow
        """
        return self._applets

    @property
    def imageNameListSlot(self):
        """
        Return the "image name list" slot, which lists the names of
        all image lanes (i.e. files) currently loaded by the workflow
        """
        return self.dataSelectionApplet.topLevelOperator.ImageName

    def __init__(self, shell, headless, workflow_cmdline_args, project_creation_args, *args, **kwargs):
        """
        Create the applets of this workflow and parse batch-mode arguments.

        ``shell``, ``headless``, ``workflow_cmdline_args``,
        ``project_creation_args``, ``*args`` and ``**kwargs`` are forwarded
        unchanged to the base-class constructor, together with a freshly
        created ``Graph``.

        Command-line args left over after workflow-level parsing are offered
        first to the data export applet and then to the batch processing
        applet; the results are stored in ``_batch_export_args`` and
        ``_batch_input_args`` (both ``None`` when there were no leftovers).
        """
        # Create a graph to be shared by all operators
        graph = Graph()
        super(NNClassificationWorkflow, self).__init__(shell, headless, workflow_cmdline_args, project_creation_args, graph=graph, *args, **kwargs)
        self._applets = []
        self._workflow_cmdline_args = workflow_cmdline_args

        # Parse workflow-specific command-line args.
        # No workflow-specific options are registered at the moment.
        parser = argparse.ArgumentParser()

        # Parse the creation args: These were saved to the project file when this project was first created.
        # The result is not used; the call is kept so the args still pass through the parser.
        parser.parse_known_args(project_creation_args)

        # Parse the cmdline args for the current session.
        _, unused_args = parser.parse_known_args(workflow_cmdline_args)

        # Applets for training (interactive) workflow
        self.dataSelectionApplet = self.createDataSelectionApplet()
        opDataSelection = self.dataSelectionApplet.topLevelOperator

        # see role constants, above
        opDataSelection.DatasetRoles.setValue(NNClassificationWorkflow.ROLE_NAMES)

        self.nnClassificationApplet = NNClassApplet(self, "NNClassApplet")

        self.dataExportApplet = NNClassificationDataExportApplet(self, 'Data Export')

        # Configure global DataExport settings
        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.WorkingDirectory.connect(opDataSelection.WorkingDirectory)
        opDataExport.SelectionNames.setValue(self.EXPORT_NAMES)

        self.batchProcessingApplet = BatchProcessingApplet(self,
                                                           "Batch Processing",
                                                           self.dataSelectionApplet,
                                                           self.dataExportApplet)

        # Expose for shell
        self._applets.append(self.dataSelectionApplet)
        self._applets.append(self.nnClassificationApplet)
        self._applets.append(self.dataExportApplet)
        self._applets.append(self.batchProcessingApplet)

        if unused_args:
            # We parse the export setting args first.  All remaining args are considered input files by the input applet.
            self._batch_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args(unused_args)
            self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args(unused_args)
        else:
            self._batch_input_args = None
            self._batch_export_args = None

        if unused_args:
            # Logger.warn() is a deprecated alias that newer Python versions no longer provide.
            logger.warning("Unused command-line args: {}".format(unused_args))

    def createDataSelectionApplet(self):
        """
        Can be overridden by subclasses, if they want to use
        special parameters to initialize the DataSelectionApplet.
        """
        data_instructions = "Select your input data using the 'Raw Data' tab shown on the right"
        return DataSelectionApplet(self,
                                   "Input Data",
                                   "Input Data",
                                   supportIlastik05Import=True,
                                   instructionText=data_instructions)


    def connectLane(self, laneIndex):
        """
        connects the operators for different lanes, each lane has a laneIndex starting at 0
        """
        opData = self.dataSelectionApplet.topLevelOperator.getLane(laneIndex)
        opNNclassify = self.nnClassificationApplet.topLevelOperator.getLane(laneIndex)
        opDataExport = self.dataExportApplet.topLevelOperator.getLane(laneIndex)

        # Input Image -> Classification Op
        opNNclassify.InputImage.connect(opData.Image)

        # Data Export connections
        opDataExport.RawData.connect(opData.ImageGroup[self.DATA_ROLE_RAW])
        opDataExport.RawDatasetInfo.connect(opData.DatasetGroup[self.DATA_ROLE_RAW])
        # One export input per entry of EXPORT_NAMES; only index 0 is wired up here.
        opDataExport.Inputs.resize(len(self.EXPORT_NAMES))
        opDataExport.Inputs[0].connect(opNNclassify.CachedPredictionProbabilities)

    def handleAppletStateUpdateRequested(self):
        """
        Overridden from Workflow base class
        Called when an applet has fired the :py:attr:`Applet.appletStateUpdateRequested`

        Enables or disables each applet in the shell according to whether
        input data is present and whether batch processing is running, then
        tells the shell whether project changes are currently allowed.
        """
        # If no data, nothing else is ready.
        opDataSelection = self.dataSelectionApplet.topLevelOperator
        input_ready = len(opDataSelection.ImageGroup) > 0 and not self.dataSelectionApplet.busy

        opNNClassification = self.nnClassificationApplet.topLevelOperator

        opDataExport = self.dataExportApplet.topLevelOperator

        # Readiness of the export slot itself is deliberately not checked here;
        # only the presence of at least one export lane is.
        predictions_ready = input_ready and len(opDataExport.Inputs) > 0

        # Problems can occur if the features or input data are changed during live update mode.
        # Don't let the user do that.
        live_update_active = not opNNClassification.FreezePredictions.value

        # The user isn't allowed to touch anything while batch processing is running.
        batch_processing_busy = self.batchProcessingApplet.busy

        self._shell.setAppletEnabled(self.dataSelectionApplet, not batch_processing_busy)
        self._shell.setAppletEnabled(self.nnClassificationApplet, input_ready and not batch_processing_busy)
        self._shell.setAppletEnabled(self.dataExportApplet, predictions_ready and not batch_processing_busy and not live_update_active)

        # The batch applet is always created in __init__ and was already
        # dereferenced above, so no None-check is needed here.
        self._shell.setAppletEnabled(self.batchProcessingApplet, predictions_ready and not batch_processing_busy)

        # Lastly, check for certain "busy" conditions, during which we
        #  should prevent the shell from closing the project.
        busy = False
        busy |= self.dataSelectionApplet.busy
        busy |= self.nnClassificationApplet.busy
        busy |= self.dataExportApplet.busy
        busy |= self.batchProcessingApplet.busy
        self._shell.enableProjectChanges(not busy)
예제 #12
0
    def __init__(self, shell, headless, workflow_cmdline_args, project_creation_args, appendBatchOperators=True, *args, **kwargs):
        """
        Create the applets of the counting workflow.

        ``shell``, ``headless``, ``workflow_cmdline_args``,
        ``project_creation_args``, ``*args`` and the remaining ``**kwargs``
        are forwarded to the base-class constructor.

        graph: optional keyword argument; when given it is used as the
            operator graph instead of a newly created ``Graph``.
        appendBatchOperators: when true, a batch processing applet is created,
            appended to the applet list, and leftover command-line args are
            parsed into ``_batch_export_args`` / ``_batch_input_args``.
            When false, ``self.batchProcessingApplet`` is never assigned.
        """
        # Take the caller's graph out of kwargs (it is passed explicitly below),
        # or create a new one if none was supplied.
        graph = kwargs.pop('graph') if 'graph' in kwargs else Graph()
        super( CountingWorkflow, self ).__init__( shell, headless, workflow_cmdline_args, project_creation_args, graph=graph, *args, **kwargs )
        self.stored_classifer = None

        # Parse workflow-specific command-line args
        parser = argparse.ArgumentParser()
        parser.add_argument("--csv-export-file", help="Instead of exporting prediction density images, export total counts to the given csv path.")
        self.parsed_counting_workflow_args, unused_args = parser.parse_known_args(workflow_cmdline_args)

        ######################
        # Interactive workflow
        ######################

        self.projectMetadataApplet = ProjectMetadataApplet()

        self.dataSelectionApplet = DataSelectionApplet(self,
                                                       "Input Data",
                                                       "Input Data" )
        opDataSelection = self.dataSelectionApplet.topLevelOperator
        role_names = ['Raw Data']
        opDataSelection.DatasetRoles.setValue( role_names )

        self.featureSelectionApplet = FeatureSelectionApplet(self,
                                                             "Feature Selection",
                                                             "FeatureSelections")

        self.countingApplet = CountingApplet(workflow=self)
        opCounting = self.countingApplet.topLevelOperator

        self.dataExportApplet = CountingDataExportApplet(self, "Density Export", opCounting)

        # Customization hooks
        self.dataExportApplet.prepare_for_entire_export = self.prepare_for_entire_export
        self.dataExportApplet.post_process_lane_export = self.post_process_lane_export
        self.dataExportApplet.post_process_entire_export = self.post_process_entire_export

        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.PmapColors.connect(opCounting.PmapColors)
        opDataExport.LabelNames.connect(opCounting.LabelNames)
        opDataExport.UpperBound.connect(opCounting.UpperBound)
        opDataExport.WorkingDirectory.connect(opDataSelection.WorkingDirectory)
        opDataExport.SelectionNames.setValue( ['Probabilities'] )

        self._applets = []
        self._applets.append(self.projectMetadataApplet)
        self._applets.append(self.dataSelectionApplet)
        self._applets.append(self.featureSelectionApplet)
        self._applets.append(self.countingApplet)
        self._applets.append(self.dataExportApplet)

        self._batch_input_args = None
        self._batch_export_args = None
        if appendBatchOperators:
            self.batchProcessingApplet = BatchProcessingApplet( self,
                                                                "Batch Processing",
                                                                self.dataSelectionApplet,
                                                                self.dataExportApplet )
            self._applets.append(self.batchProcessingApplet)
            if unused_args:
                # We parse the export setting args first.  All remaining args are considered input files by the input applet.
                self._batch_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args( unused_args )
                self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args( unused_args )

        if unused_args:
            # Logger.warn() is a deprecated alias that newer Python versions no longer provide.
            logger.warning("Unused command-line args: {}".format( unused_args ))
예제 #13
0
class CountingWorkflow(Workflow):
    """
    Workflow for cell density counting.

    Owns project metadata, data selection, feature selection, counting and
    density export applets, plus an optional batch processing applet.
    """
    workflowName = "Cell Density Counting"
    workflowDescription = "This is obviously self-explanatory."
    # Index 0 of the applet list is the project metadata applet.
    defaultAppletIndex = 1 # show DataSelection by default

    def __init__(self, shell, headless, workflow_cmdline_args, project_creation_args, appendBatchOperators=True, *args, **kwargs):
        """
        Create the applets of the counting workflow.

        ``shell``, ``headless``, ``workflow_cmdline_args``,
        ``project_creation_args``, ``*args`` and the remaining ``**kwargs``
        are forwarded to the base-class constructor.

        graph: optional keyword argument; when given it is used as the
            operator graph instead of a newly created ``Graph``.
        appendBatchOperators: when true, a batch processing applet is created,
            appended to the applet list, and leftover command-line args are
            parsed into ``_batch_export_args`` / ``_batch_input_args``.
            When false, ``self.batchProcessingApplet`` is never assigned.
        """
        # Take the caller's graph out of kwargs (it is passed explicitly below),
        # or create a new one if none was supplied.
        graph = kwargs.pop('graph') if 'graph' in kwargs else Graph()
        super( CountingWorkflow, self ).__init__( shell, headless, workflow_cmdline_args, project_creation_args, graph=graph, *args, **kwargs )
        self.stored_classifer = None

        # Parse workflow-specific command-line args
        parser = argparse.ArgumentParser()
        parser.add_argument("--csv-export-file", help="Instead of exporting prediction density images, export total counts to the given csv path.")
        self.parsed_counting_workflow_args, unused_args = parser.parse_known_args(workflow_cmdline_args)

        ######################
        # Interactive workflow
        ######################

        self.projectMetadataApplet = ProjectMetadataApplet()

        self.dataSelectionApplet = DataSelectionApplet(self,
                                                       "Input Data",
                                                       "Input Data" )
        opDataSelection = self.dataSelectionApplet.topLevelOperator
        role_names = ['Raw Data']
        opDataSelection.DatasetRoles.setValue( role_names )

        self.featureSelectionApplet = FeatureSelectionApplet(self,
                                                             "Feature Selection",
                                                             "FeatureSelections")

        self.countingApplet = CountingApplet(workflow=self)
        opCounting = self.countingApplet.topLevelOperator

        self.dataExportApplet = CountingDataExportApplet(self, "Density Export", opCounting)

        # Customization hooks
        self.dataExportApplet.prepare_for_entire_export = self.prepare_for_entire_export
        self.dataExportApplet.post_process_lane_export = self.post_process_lane_export
        self.dataExportApplet.post_process_entire_export = self.post_process_entire_export

        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.PmapColors.connect(opCounting.PmapColors)
        opDataExport.LabelNames.connect(opCounting.LabelNames)
        opDataExport.UpperBound.connect(opCounting.UpperBound)
        opDataExport.WorkingDirectory.connect(opDataSelection.WorkingDirectory)
        opDataExport.SelectionNames.setValue( ['Probabilities'] )

        self._applets = []
        self._applets.append(self.projectMetadataApplet)
        self._applets.append(self.dataSelectionApplet)
        self._applets.append(self.featureSelectionApplet)
        self._applets.append(self.countingApplet)
        self._applets.append(self.dataExportApplet)

        self._batch_input_args = None
        self._batch_export_args = None
        if appendBatchOperators:
            self.batchProcessingApplet = BatchProcessingApplet( self,
                                                                "Batch Processing",
                                                                self.dataSelectionApplet,
                                                                self.dataExportApplet )
            self._applets.append(self.batchProcessingApplet)
            if unused_args:
                # We parse the export setting args first.  All remaining args are considered input files by the input applet.
                self._batch_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args( unused_args )
                self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args( unused_args )

        if unused_args:
            # Logger.warn() is a deprecated alias that newer Python versions no longer provide.
            logger.warning("Unused command-line args: {}".format( unused_args ))


    @property
    def applets(self):
        """Return the list of applets owned by this workflow."""
        return self._applets

    @property
    def imageNameListSlot(self):
        """Return the ``ImageName`` slot of the data selection operator."""
        return self.dataSelectionApplet.topLevelOperator.ImageName

    def prepareForNewLane(self, laneIndex):
        """
        Overridden from Workflow base class.
        Called immediately before a new lane is added to the workflow.
        """
        # When the new lane is added, dirty notifications will propagate throughout the entire graph.
        # This means the classifier will be marked 'dirty' even though it is still usable.
        # Before that happens, let's store the classifier, so we can restore it in handleNewLanesAdded(), below.
        opCounting = self.countingApplet.topLevelOperator
        if opCounting.classifier_cache.Output.ready() and \
           not opCounting.classifier_cache._dirty:
            self.stored_classifer = opCounting.classifier_cache.Output.value
        else:
            self.stored_classifer = None

    def handleNewLanesAdded(self):
        """
        Overridden from Workflow base class.
        Called immediately after a new lane is added to the workflow and initialized.
        """
        # Restore classifier we saved in prepareForNewLane() (if any)
        if self.stored_classifer is not None:
            self.countingApplet.topLevelOperator.classifier_cache.forceValue(self.stored_classifer)
            # Release reference
            self.stored_classifer = None

    def connectLane(self, laneIndex):
        """
        Connect the per-lane operators of all applets for lane ``laneIndex``.
        """
        ## Access applet operators
        opData = self.dataSelectionApplet.topLevelOperator.getLane(laneIndex)
        opTrainingFeatures = self.featureSelectionApplet.topLevelOperator.getLane(laneIndex)
        opCounting = self.countingApplet.topLevelOperator.getLane(laneIndex)
        opDataExport = self.dataExportApplet.topLevelOperator.getLane(laneIndex)


        #### connect input image
        opTrainingFeatures.InputImage.connect(opData.Image)

        opCounting.InputImages.connect(opData.Image)
        opCounting.FeatureImages.connect(opTrainingFeatures.OutputImage)
        opCounting.LabelsAllowedFlags.connect(opData.AllowLabels)
        opCounting.CachedFeatureImages.connect( opTrainingFeatures.CachedOutputImage )

        #### connect the single export input and the raw data (role 0)
        opDataExport.Inputs.resize(1)
        opDataExport.Inputs[0].connect( opCounting.HeadlessPredictionProbabilities )
        opDataExport.RawData.connect( opData.ImageGroup[0] )
        opDataExport.RawDatasetInfo.connect( opData.DatasetGroup[0] )

    def onProjectLoaded(self, projectManager):
        """
        Overridden from Workflow base class.  Called by the Project Manager.

        If the user provided command-line arguments, use them to configure
        the workflow for batch mode and export all results.
        (This workflow's headless mode supports only batch mode for now.)
        """
        # Headless batch mode.
        # Both batch arg sets are only ever non-None when the batch applet was created in __init__.
        if self._headless and self._batch_input_args and self._batch_export_args:
            self.dataExportApplet.configure_operator_with_parsed_args( self._batch_export_args )

            # If the user provided a csv_path via the command line,
            # overwrite the setting in the counting export operator.
            csv_path = self.parsed_counting_workflow_args.csv_export_file
            if csv_path:
                self.dataExportApplet.topLevelOperator.CsvFilepath.setValue(csv_path)

            if self.countingApplet.topLevelOperator.classifier_cache._dirty:
                logger.warning("Your project file has no classifier.  "
                               "A new classifier will be trained for this run.")

            logger.info("Beginning Batch Processing")
            self.batchProcessingApplet.run_export_from_parsed_args(self._batch_input_args)
            logger.info("Completed Batch Processing")

    def prepare_for_entire_export(self):
        """
        Customization hook for data export (including batch mode).

        Remembers the current ``FreezePredictions`` value, un-freezes
        predictions, and opens the CSV output file if a path is configured.
        """
        self.freeze_status = self.countingApplet.topLevelOperator.FreezePredictions.value
        self.countingApplet.topLevelOperator.FreezePredictions.setValue(False)
        # Create a new CSV file to write object counts into.
        # The handle stays open across lanes and is closed in post_process_entire_export().
        self.csv_export_file = None
        if self.dataExportApplet.topLevelOperator.CsvFilepath.ready():
            csv_path = self.dataExportApplet.topLevelOperator.CsvFilepath.value
            logger.info("Exporting object counts to CSV: " + csv_path)
            self.csv_export_file = open(csv_path, 'w')

    def post_process_lane_export(self, lane_index):
        """
        Customization hook for data export (including batch mode).
        """
        # Write the object counts for this lane as a line in the CSV file.
        if self.csv_export_file:
            self.dataExportApplet.write_csv_results(self.csv_export_file, lane_index)

    def post_process_entire_export(self):
        """
        Customization hook for data export (including batch mode).

        Restores the ``FreezePredictions`` value saved by
        prepare_for_entire_export() and closes the CSV file, if one was opened.
        """
        try:
            self.countingApplet.topLevelOperator.FreezePredictions.setValue(self.freeze_status)
        finally:
            # Close the file even if restoring the freeze state failed, and
            # drop the reference so a closed handle is never written to later.
            if self.csv_export_file:
                self.csv_export_file.close()
                self.csv_export_file = None

    def handleAppletStateUpdateRequested(self):
        """
        Overridden from Workflow base class.

        Enables or disables each applet in the shell according to the
        readiness of the input, feature and prediction slots, then tells the
        shell whether project changes are currently allowed.
        """
        # If no data, nothing else is ready.
        opDataSelection = self.dataSelectionApplet.topLevelOperator
        input_ready = len(opDataSelection.ImageGroup) > 0 and not self.dataSelectionApplet.busy

        opFeatureSelection = self.featureSelectionApplet.topLevelOperator
        featureOutput = opFeatureSelection.OutputImage
        features_ready = input_ready and \
                         len(featureOutput) > 0 and  \
                         featureOutput[0].ready() and \
                         (TinyVector(featureOutput[0].meta.shape) > 0).all()

        opDataExport = self.dataExportApplet.topLevelOperator
        predictions_ready = features_ready and \
                            len(opDataExport.Inputs) > 0 and \
                            opDataExport.Inputs[0][0].ready() and \
                            (TinyVector(opDataExport.Inputs[0][0].meta.shape) > 0).all()

        # The batch applet only exists when the workflow was built with
        # appendBatchOperators=True; tolerate its absence.
        batch_applet = getattr(self, 'batchProcessingApplet', None)

        self._shell.setAppletEnabled(self.featureSelectionApplet, input_ready)
        self._shell.setAppletEnabled(self.countingApplet, features_ready)
        self._shell.setAppletEnabled(self.dataExportApplet, predictions_ready and not self.dataExportApplet.busy)
        if batch_applet is not None:
            self._shell.setAppletEnabled(batch_applet, predictions_ready and not batch_applet.busy)

        # Lastly, check for certain "busy" conditions, during which we
        #  should prevent the shell from closing the project.
        busy = False
        busy |= self.dataSelectionApplet.busy
        busy |= self.featureSelectionApplet.busy
        busy |= self.dataExportApplet.busy
        if batch_applet is not None:
            busy |= batch_applet.busy
        self._shell.enableProjectChanges( not busy )
예제 #14
0
    def __init__(self, shell, headless, workflow_cmdline_args, project_creation_args, n_stages, *args, **kwargs):
        """
        n_stages: How many iterations of feature selection and pixel classification should be inserted into the workflow.
                  Must be at least 1: the last stage's classification operator is looked up with index -1.

        All other params are just as in PixelClassificationWorkflow

        Command-line args left over after workflow-level parsing are offered
        first to the data export applet and then to the batch processing
        applet; the results are stored in ``_batch_export_args`` and
        ``_batch_input_args`` (both ``None`` when there were no leftovers).
        """
        # Create a graph to be shared by all operators
        graph = Graph()
        super( NewAutocontextWorkflowBase, self ).__init__( shell, headless, workflow_cmdline_args, project_creation_args, graph=graph, *args, **kwargs )
        self.stored_classifers = []
        self._applets = []
        self._workflow_cmdline_args = workflow_cmdline_args

        # Parse workflow-specific command-line args
        parser = argparse.ArgumentParser()
        parser.add_argument('--retrain', help="Re-train the classifier based on labels stored in project file, and re-save.", action="store_true")

        # Parse the creation args: These were saved to the project file when this project was first created.
        # The result is not used; the call is kept so the args still pass through the parser.
        parser.parse_known_args(project_creation_args)

        # Parse the cmdline args for the current session.
        parsed_args, unused_args = parser.parse_known_args(workflow_cmdline_args)
        self.retrain = parsed_args.retrain

        self.dataSelectionApplet = self.createDataSelectionApplet()
        opDataSelection = self.dataSelectionApplet.topLevelOperator

        # see role constants, above
        role_names = ['Raw Data', 'Prediction Mask']
        opDataSelection.DatasetRoles.setValue( role_names )

        # One (feature selection, pixel classification) applet pair per stage.
        self.featureSelectionApplets = []
        self.pcApplets = []
        for i in range(n_stages):
            self.featureSelectionApplets.append( self.createFeatureSelectionApplet(i) )
            self.pcApplets.append( self.createPixelClassificationApplet(i) )
        opFinalClassify = self.pcApplets[-1].topLevelOperator

        # If *any* stage enters 'live update' mode, make sure they all enter live update mode.
        def sync_freeze_predictions_settings( slot, *args ):
            freeze_predictions = slot.value
            for pcApplet in self.pcApplets:
                pcApplet.topLevelOperator.FreezePredictions.setValue( freeze_predictions )
        for pcApplet in self.pcApplets:
            pcApplet.topLevelOperator.FreezePredictions.notifyDirty( sync_freeze_predictions_settings )

        self.dataExportApplet = PixelClassificationDataExportApplet(self, "Prediction Export")
        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.PmapColors.connect( opFinalClassify.PmapColors )
        opDataExport.LabelNames.connect( opFinalClassify.LabelNames )
        opDataExport.WorkingDirectory.connect( opDataSelection.WorkingDirectory )

        # Export names are listed for the last stage first, then earlier stages.
        self.EXPORT_NAMES = []
        for stage_index in reversed(range(n_stages)):
            self.EXPORT_NAMES += ["{} Stage {}".format( name, stage_index+1 ) for name in self.EXPORT_NAMES_PER_STAGE]

        # And finally, one last item for *all* probabilities from all stages.
        self.EXPORT_NAMES += ["Probabilities All Stages"]
        opDataExport.SelectionNames.setValue( self.EXPORT_NAMES )

        # Expose for shell
        self._applets.append(self.dataSelectionApplet)
        # Interleave the per-stage applets: features 0, classification 0, features 1, ...
        self._applets.extend(itertools.chain.from_iterable(zip(self.featureSelectionApplets, self.pcApplets)))
        self._applets.append(self.dataExportApplet)

        self.dataExportApplet.prepare_for_entire_export = self.prepare_for_entire_export
        self.dataExportApplet.post_process_entire_export = self.post_process_entire_export

        self.batchProcessingApplet = BatchProcessingApplet(self,
                                                           "Batch Processing",
                                                           self.dataSelectionApplet,
                                                           self.dataExportApplet)

        self._applets.append(self.batchProcessingApplet)
        if unused_args:
            # We parse the export setting args first.  All remaining args are considered input files by the input applet.
            self._batch_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args( unused_args )
            self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args( unused_args )
        else:
            self._batch_input_args = None
            self._batch_export_args = None

        if unused_args:
            logger.warning("Unused command-line args: {}".format( unused_args ))
예제 #15
0
class NewAutocontextWorkflowBase(Workflow):
    """
    Base class for autocontext workflows.

    The workflow consists of one data selection applet, followed by ``n_stages``
    pairs of (feature selection, pixel classification) applets, a data export
    applet and a batch processing applet (see ``__init__``).  In ``connectLane``
    each stage after the first receives the raw image stacked along the channel
    axis with the previous stage's prediction probabilities.
    """
    
    workflowName = "New Autocontext Base"
    defaultAppletIndex = 0 # show DataSelection by default
    
    # Indexes into the dataset roles configured in __init__ (['Raw Data', 'Prediction Mask']).
    DATA_ROLE_RAW = 0
    DATA_ROLE_PREDICTION_MASK = 1
    
    # First export names must match these for the export GUI, because we re-use the ordinary PC gui
    # (See PixelClassificationDataExportGui.)
    # NOTE: connectLane() connects the export inputs by position (offsets 0..5 per stage),
    #       so the order of this list has to stay in sync with those connections.
    EXPORT_NAMES_PER_STAGE = ['Probabilities', 'Simple Segmentation', 'Uncertainty', 'Features', 'Labels', 'Input']
    
    @property
    def applets(self):
        """Return the list of applets exposed by this workflow (filled in ``__init__``)."""
        return self._applets

    @property
    def imageNameListSlot(self):
        """Return the ``ImageName`` slot of the data selection applet's top-level operator."""
        return self.dataSelectionApplet.topLevelOperator.ImageName

    def __init__(self, shell, headless, workflow_cmdline_args, project_creation_args, n_stages, *args, **kwargs):
        """
        Create all applets of the workflow and register them for the shell.

        n_stages: How many iterations of feature selection and pixel classification should be inserted into the workflow.
                  Must be at least 1.

        All other params are just as in PixelClassificationWorkflow

        Raises ValueError if n_stages is smaller than 1.
        """
        # Fail early with a clear message.  Without at least one stage there is no
        # final classifier to export from (self.pcApplets[-1] below would raise a bare IndexError).
        if n_stages < 1:
            raise ValueError("n_stages must be at least 1, got {}".format(n_stages))

        # Create a graph to be shared by all operators
        graph = Graph()
        super( NewAutocontextWorkflowBase, self ).__init__( shell, headless, workflow_cmdline_args, project_creation_args, graph=graph, *args, **kwargs )
        self.stored_classifers = []
        self._applets = []
        self._workflow_cmdline_args = workflow_cmdline_args

        # Parse workflow-specific command-line args
        parser = argparse.ArgumentParser()
        parser.add_argument('--retrain', help="Re-train the classifier based on labels stored in project file, and re-save.", action="store_true")

        # Parse the creation args: These were saved to the project file when this project was first created.
        # The parsed result is not used by this workflow; the call is kept so that
        # the arguments are still checked by the parser exactly as before.
        parser.parse_known_args(project_creation_args)

        # Parse the cmdline args for the current session.
        parsed_args, unused_args = parser.parse_known_args(workflow_cmdline_args)
        self.retrain = parsed_args.retrain

        self.dataSelectionApplet = self.createDataSelectionApplet()
        opDataSelection = self.dataSelectionApplet.topLevelOperator

        # see role constants, above
        role_names = ['Raw Data', 'Prediction Mask']
        opDataSelection.DatasetRoles.setValue( role_names )

        self.featureSelectionApplets = []
        self.pcApplets = []
        for i in range(n_stages):
            self.featureSelectionApplets.append( self.createFeatureSelectionApplet(i) )
            self.pcApplets.append( self.createPixelClassificationApplet(i) )
        opFinalClassify = self.pcApplets[-1].topLevelOperator

        # If *any* stage enters 'live update' mode, make sure they all enter live update mode.
        def sync_freeze_predictions_settings( slot, *args ):
            freeze_predictions = slot.value
            for pcApplet in self.pcApplets:
                pcApplet.topLevelOperator.FreezePredictions.setValue( freeze_predictions )
        for pcApplet in self.pcApplets:
            pcApplet.topLevelOperator.FreezePredictions.notifyDirty( sync_freeze_predictions_settings )

        self.dataExportApplet = PixelClassificationDataExportApplet(self, "Prediction Export")
        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.PmapColors.connect( opFinalClassify.PmapColors )
        opDataExport.LabelNames.connect( opFinalClassify.LabelNames )
        opDataExport.WorkingDirectory.connect( opDataSelection.WorkingDirectory )

        # Export names are listed last stage first, one group of names per stage.
        self.EXPORT_NAMES = []
        for stage_index in reversed(range(n_stages)):
            self.EXPORT_NAMES.extend(
                "{} Stage {}".format( name, stage_index+1 ) for name in self.EXPORT_NAMES_PER_STAGE
            )

        # And finally, one last item for *all* probabilities from all stages.
        self.EXPORT_NAMES.append("Probabilities All Stages")
        opDataExport.SelectionNames.setValue( self.EXPORT_NAMES )

        # Expose for shell: data selection, then (features, classification) per stage, then export.
        self._applets.append(self.dataSelectionApplet)
        self._applets.extend(itertools.chain.from_iterable(zip(self.featureSelectionApplets, self.pcApplets)))
        self._applets.append(self.dataExportApplet)

        self.dataExportApplet.prepare_for_entire_export = self.prepare_for_entire_export
        self.dataExportApplet.post_process_entire_export = self.post_process_entire_export

        self.batchProcessingApplet = BatchProcessingApplet(self,
                                                           "Batch Processing",
                                                           self.dataSelectionApplet,
                                                           self.dataExportApplet)

        self._applets.append(self.batchProcessingApplet)
        if unused_args:
            # We parse the export setting args first.  All remaining args are considered input files by the input applet.
            self._batch_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args( unused_args )
            self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args( unused_args )
        else:
            self._batch_input_args = None
            self._batch_export_args = None

        if unused_args:
            logger.warning("Unused command-line args: {}".format( unused_args ))

    def createDataSelectionApplet(self):
        """
        Build the DataSelectionApplet used by this workflow.

        Subclasses may override this if they need to construct the
        applet with different parameters.
        """
        # Every accepted axis order has the channel axis 'c' at the end:
        # the two plain 2D orders, then all 3- and 4-axis permutations of 'tzyx'.
        allowed_axis_orders = ['yxc', 'xyc']
        for n_axes in (3, 4):
            allowed_axis_orders.extend(
                ''.join(axes) + 'c' for axes in itertools.permutations('tzyx', n_axes)
            )

        instructions = "Select your input data using the 'Raw Data' tab shown on the right"
        return DataSelectionApplet(
            self,
            "Input Data",
            "Input Data",
            supportIlastik05Import=False,
            instructionText=instructions,
            forceAxisOrder=allowed_axis_orders,
        )

    def createFeatureSelectionApplet(self, index):
        """
        Build the FeatureSelectionApplet for the stage with the given index.

        Subclasses may override this to return their own applet type, as long as
        it offers the same interface as the regular FeatureSelectionApplet.
        """
        # Stage 0 keeps the plain group name so that it stays compatible with the
        # pixel classification workflow (relevant when the user uses "Import Project").
        # Later stages get a zero-padded numeric suffix.
        if index > 0:
            group_name = "FeatureSelections{index:02d}".format(index=index)
        else:
            group_name = 'FeatureSelections'

        feature_applet = FeatureSelectionApplet(self, "Feature Selection", group_name)

        # Append the stage index to the operator name.
        top_level_op = feature_applet.topLevelOperator
        top_level_op.name += '{}'.format(index)
        return feature_applet

    def createPixelClassificationApplet(self, index=0):
        """
        Build the PixelClassificationApplet for the stage with the given index.

        Subclasses may override this to return their own applet type, as long as
        it offers the same interface as the regular PixelClassificationApplet.
        """
        # Stage 0 keeps the plain group name so that it stays compatible with the
        # pixel classification workflow (relevant when the user uses "Import Project").
        # Later stages get a zero-padded numeric suffix.
        if index > 0:
            group_name = "PixelClassification{index:02d}".format(index=index)
        else:
            group_name = 'PixelClassification'

        classification_applet = PixelClassificationApplet(self, group_name)

        # Append the stage index to the operator name.
        top_level_op = classification_applet.topLevelOperator
        top_level_op.name += '{}'.format(index)
        return classification_applet


    def prepareForNewLane(self, laneIndex):
        """
        Overridden from Workflow base class.
        Called immediately before a new lane is added to the workflow.

        Stores the current classifier of every stage in ``self.stored_classifers``
        so that handleNewLanesAdded() can restore them.  The classifiers are only
        stored if *all* stages have a ready, non-dirty classifier; otherwise the
        list is left empty and nothing will be restored.

        laneIndex: index of the lane about to be added (not used here).
        """
        # When the new lane is added, dirty notifications will propagate throughout the entire graph.
        # This means the classifier will be marked 'dirty' even though it is still usable.
        # Before that happens, let's store the classifier, so we can restore it in handleNewLanesAdded(), below.
        self.stored_classifers = []
        for pcApplet in self.pcApplets:
            classifier_cache = pcApplet.topLevelOperator.classifier_cache
            if not classifier_cache.Output.ready() or classifier_cache._dirty:
                # All or nothing: handleNewLanesAdded() pairs the stored classifiers with
                # self.pcApplets by position, so a partial list would hand classifiers
                # to the wrong stages.  Stop here instead of collecting later stages.
                self.stored_classifers = []
                break
            self.stored_classifers.append(classifier_cache.Output.value)
        
    def handleNewLanesAdded(self):
        """
        Overridden from Workflow base class.
        Called immediately after a new lane is added to the workflow and initialized.

        Puts the classifiers saved by prepareForNewLane() (if any) back into
        the classifier caches of the corresponding stages.
        """
        saved_classifiers = self.stored_classifers
        if not saved_classifiers:
            # Nothing was saved, so there is nothing to restore.
            return

        for stage_applet, saved_classifier in zip(self.pcApplets, saved_classifiers):
            stage_applet.topLevelOperator.classifier_cache.forceValue(saved_classifier)

        # Drop our references to the classifiers.
        self.stored_classifers = []

    def connectLane(self, laneIndex):
        """
        Connect the operators of all applets for the lane with index ``laneIndex``.

        - The first stage's feature and classification operators are fed with the
          image from the data selection operator.
        - Every later stage is fed with that image stacked (along 'c') with the
          ``PredictionProbabilitiesAutocontext`` output of the stage before it.
        - The export operator's inputs are connected in the order of ``self.EXPORT_NAMES``:
          one group of slots per stage (last stage first), plus one final slot that
          stacks the probabilities of all stages.
        """
        # Get a handle to each operator
        opData = self.dataSelectionApplet.topLevelOperator.getLane(laneIndex)
        opFirstFeatures = self.featureSelectionApplets[0].topLevelOperator.getLane(laneIndex)
        opFirstClassify = self.pcApplets[0].topLevelOperator.getLane(laneIndex)
        opFinalClassify = self.pcApplets[-1].topLevelOperator.getLane(laneIndex)
        opDataExport = self.dataExportApplet.topLevelOperator.getLane(laneIndex)

        # NOTE: The dtype check below is commented out, so this callback currently does nothing.
        def checkConstraints(*_):
            # if (opData.Image.meta.dtype in [np.uint8, np.uint16]) == False:
            #    msg = "The Autocontext Workflow only supports 8-bit images (UINT8 pixel type)\n"\
            #          "or 16-bit images (UINT16 pixel type)\n"\
            #          "Your image has a pixel type of {}.  Please convert your data to UINT8 and try again."\
            #          .format( str(np.dtype(opData.Image.meta.dtype)) )
            #    raise DatasetConstraintError( "Autocontext Workflow", msg, unfixable=True )
            pass

        opData.Image.notifyReady( checkConstraints )

        # Input Image -> Feature Op
        #         and -> Classification Op (for display)
        opFirstFeatures.InputImage.connect( opData.Image )
        opFirstClassify.InputImages.connect( opData.Image )

        # Feature Images -> Classification Op (for training, prediction)
        opFirstClassify.FeatureImages.connect( opFirstFeatures.OutputImage )
        opFirstClassify.CachedFeatureImages.connect( opFirstFeatures.CachedOutputImage )

        # Pair every stage (except the last) with the feature/classification applets of the stage after it.
        upstreamPcApplets = self.pcApplets[0:-1]
        downstreamFeatureApplets = self.featureSelectionApplets[1:]
        downstreamPcApplets = self.pcApplets[1:]

        for ( upstreamPcApplet,
              downstreamFeaturesApplet,
              downstreamPcApplet ) in zip( upstreamPcApplets, 
                                           downstreamFeatureApplets, 
                                           downstreamPcApplets ):

            opUpstreamClassify = upstreamPcApplet.topLevelOperator.getLane(laneIndex)
            opDownstreamFeatures = downstreamFeaturesApplet.topLevelOperator.getLane(laneIndex)
            opDownstreamClassify = downstreamPcApplet.topLevelOperator.getLane(laneIndex)

            # Connect label inputs (all are connected together).
            #opDownstreamClassify.LabelInputs.connect( opUpstreamClassify.LabelInputs )

            # Connect data path
            # NOTE(review): this check is an assert, so it is skipped when Python runs with -O.
            assert opData.Image.meta.dtype == opUpstreamClassify.PredictionProbabilitiesAutocontext.meta.dtype, (
                "Probability dtype needs to match up with input image dtype, got: "
                f"input: {opData.Image.meta.dtype} "
                f"probabilities: {opUpstreamClassify.PredictionProbabilitiesAutocontext.meta.dtype}"
            )
            # Stack the raw image and the upstream probabilities along the channel axis.
            opStacker = OpMultiArrayStacker(parent=self)
            opStacker.Images.resize(2)
            opStacker.Images[0].connect( opData.Image )
            opStacker.Images[1].connect( opUpstreamClassify.PredictionProbabilitiesAutocontext )
            opStacker.AxisFlag.setValue('c')

            # The stacked image is the input of the downstream stage.
            opDownstreamFeatures.InputImage.connect( opStacker.Output )
            opDownstreamClassify.InputImages.connect( opStacker.Output )
            opDownstreamClassify.FeatureImages.connect( opDownstreamFeatures.OutputImage )
            opDownstreamClassify.CachedFeatureImages.connect( opDownstreamFeatures.CachedOutputImage )

        # Data Export connections
        opDataExport.RawData.connect( opData.ImageGroup[self.DATA_ROLE_RAW] )
        opDataExport.RawDatasetInfo.connect( opData.DatasetGroup[self.DATA_ROLE_RAW] )
        opDataExport.ConstraintDataset.connect( opData.ImageGroup[self.DATA_ROLE_RAW] )

        opDataExport.Inputs.resize( len(self.EXPORT_NAMES) )
        # Stages are visited last-to-first, matching the order in which __init__ builds EXPORT_NAMES.
        # The offsets +0 .. +5 follow the order of EXPORT_NAMES_PER_STAGE.
        for reverse_stage_index, (stage_index, pcApplet) in enumerate(reversed(list(enumerate(self.pcApplets)))):
            opPc = pcApplet.topLevelOperator.getLane(laneIndex)
            num_items_per_stage = len(self.EXPORT_NAMES_PER_STAGE)
            opDataExport.Inputs[num_items_per_stage*reverse_stage_index+0].connect( opPc.HeadlessPredictionProbabilities )
            opDataExport.Inputs[num_items_per_stage*reverse_stage_index+1].connect( opPc.SimpleSegmentation )
            opDataExport.Inputs[num_items_per_stage*reverse_stage_index+2].connect( opPc.HeadlessUncertaintyEstimate )
            opDataExport.Inputs[num_items_per_stage*reverse_stage_index+3].connect( opPc.FeatureImages )
            opDataExport.Inputs[num_items_per_stage*reverse_stage_index+4].connect( opPc.LabelImages )
            opDataExport.Inputs[num_items_per_stage*reverse_stage_index+5].connect( opPc.InputImages ) # Input must come last due to an assumption in PixelClassificationDataExportGui

        # One last export slot for all probabilities, all stages
        opAllStageStacker = OpMultiArrayStacker(parent=self)
        opAllStageStacker.Images.resize( len(self.pcApplets) )
        for stage_index, pcApplet in enumerate(self.pcApplets):
            opPc = pcApplet.topLevelOperator.getLane(laneIndex)
            opAllStageStacker.Images[stage_index].connect(opPc.HeadlessPredictionProbabilities)
            # NOTE(review): the same value is set in every iteration; this looks loop-invariant.
            opAllStageStacker.AxisFlag.setValue('c')

        # The ideal_blockshape metadata field will be bogus, so just eliminate it
        # (Otherwise, the channels could be split up in an unfortunate way.)
        opMetadataOverride = OpMetadataInjector(parent=self)
        opMetadataOverride.Input.connect( opAllStageStacker.Output )
        opMetadataOverride.Metadata.setValue( {'ideal_blockshape' : None } )

        opDataExport.Inputs[-1].connect( opMetadataOverride.Output )

        # Sanity check: every export input must have been connected above.
        for slot in opDataExport.Inputs:
            assert slot.upstream_slot is not None

    def handleAppletStateUpdateRequested(self):
        """
        Overridden from Workflow base class
        Called when an applet has fired the :py:attr:`Applet.appletStateUpdateRequested`

        Computes a set of readiness flags for every stage, enables or disables
        each applet in the shell accordingly, and finally tells the shell
        whether project changes are currently allowed.
        """
        opDataSelection = self.dataSelectionApplet.topLevelOperator

        # First, determine various 'ready' states for each pixel classification stage (features+prediction)
        StageFlags = collections.namedtuple("StageFlags", 'input_ready features_ready invalid_classifier predictions_ready live_update_active')
        stage_flags = []
        for stage_index, (featureSelectionApplet, pcApplet) in enumerate(zip(self.featureSelectionApplets, self.pcApplets)):
            if stage_index == 0:
                # If no data, nothing else is ready.
                input_ready = len(opDataSelection.ImageGroup) > 0 and not self.dataSelectionApplet.busy
            else:
                # Later stages consume the predictions of the stage before them.
                input_ready = stage_flags[stage_index-1].predictions_ready

            featureOutput = featureSelectionApplet.topLevelOperator.OutputImage
            features_ready = (
                input_ready
                and len(featureOutput) > 0
                and featureOutput[0].ready()
                and (TinyVector(featureOutput[0].meta.shape) > 0).all()
            )

            opPixelClassification = pcApplet.topLevelOperator
            classifier_cache = opPixelClassification.classifier_cache
            invalid_classifier = (
                classifier_cache.fixAtCurrent.value
                and classifier_cache.Output.ready()
                and classifier_cache.Output.value is None
            )

            predictions = opPixelClassification.HeadlessPredictionProbabilities
            predictions_ready = (
                features_ready
                and not invalid_classifier
                and len(predictions) > 0
                and predictions[0].ready()
                and (TinyVector(predictions[0].meta.shape) > 0).all()
            )

            live_update_active = not opPixelClassification.FreezePredictions.value

            stage_flags.append(
                StageFlags(input_ready, features_ready, invalid_classifier, predictions_ready, live_update_active)
            )

        # Problems can occur if the features or input data are changed during live update mode.
        # Don't let the user do that.
        any_live_update = any(flags.live_update_active for flags in stage_flags)

        # The user isn't allowed to touch anything while batch processing is running.
        batch_processing_busy = self.batchProcessingApplet.busy

        self._shell.setAppletEnabled(self.dataSelectionApplet, not any_live_update and not batch_processing_busy)

        for stage_index, (featureSelectionApplet, pcApplet) in enumerate(zip(self.featureSelectionApplets, self.pcApplets)):
            flags = stage_flags[stage_index]
            downstream_live_update = any(later.live_update_active for later in stage_flags[stage_index+1:])

            self._shell.setAppletEnabled(
                featureSelectionApplet,
                flags.input_ready
                and not flags.live_update_active
                and not downstream_live_update
                and not batch_processing_busy,
            )

            # Checking downstream_live_update is not necessary here, because live update
            # modes are synced -- labels can't be added in live update.
            self._shell.setAppletEnabled(pcApplet, flags.features_ready and not batch_processing_busy)

        # Export and batch processing both depend on the predictions of the LAST stage.
        # (Read it from stage_flags instead of relying on the variable left over from the loop above.)
        final_predictions_ready = stage_flags[-1].predictions_ready
        self._shell.setAppletEnabled(self.dataExportApplet, final_predictions_ready and not batch_processing_busy)
        self._shell.setAppletEnabled(self.batchProcessingApplet, final_predictions_ready and not batch_processing_busy)

        # Lastly, check for certain "busy" conditions, during which we
        #  should prevent the shell from closing the project.
        busy = False
        busy |= self.dataSelectionApplet.busy
        busy |= any(applet.busy for applet in self.featureSelectionApplets)
        busy |= self.dataExportApplet.busy
        busy |= self.batchProcessingApplet.busy
        self._shell.enableProjectChanges( not busy )

    def onProjectLoaded(self, projectManager):
        """
        Overridden from Workflow base class.  Called by the Project Manager.

        Applies the command-line arguments parsed in ``__init__``: optionally
        retrains the classifiers, configures the export operator and, in
        headless mode, runs the batch export.
        (This workflow's headless mode supports only batch mode for now.)
        """
        if self._headless:
            # In headless mode, let's see the messages from the training operator.
            training_logger = logging.getLogger("lazyflow.operators.classifierOperators")
            training_logger.setLevel(logging.DEBUG)

        if self.retrain:
            self._force_retrain_classifiers(projectManager)

        export_args = self._batch_export_args
        input_args = self._batch_input_args

        # Configure the data export operator.
        if export_args:
            self.dataExportApplet.configure_operator_with_parsed_args(export_args)

        # Warn (once) if a batch run was requested while some classifier cache is dirty.
        if input_args and any(
            stage_applet.topLevelOperator.classifier_cache._dirty for stage_applet in self.pcApplets
        ):
            logger.warning("At least one of your classifiers is not yet trained.  "
                        "A new classifier will be trained for this run.")

        # The batch export itself only runs headless, and only with both kinds of arguments.
        if not (self._headless and input_args and export_args):
            return

        logger.info("Beginning Batch Processing")
        self.batchProcessingApplet.run_export_from_parsed_args(input_args)
        logger.info("Completed Batch Processing")

    def prepare_for_entire_export(self):
        """
        Prepare all stages for an export run.

        Bypasses the feature caches (for the last stage only when exporting the
        "all stages" item, otherwise for every stage), remembers the current
        FreezePredictions value of every stage in ``self.freeze_statuses`` and
        then unfreezes all of them.  post_process_entire_export() undoes this.
        """
        # While exporting, we don't want to cache any data.
        opDataExport = self.dataExportApplet.topLevelOperator
        export_selection_name = opDataExport.SelectionNames.value[opDataExport.InputSelection.value]
        if 'all stages' in export_selection_name.lower():
            # UNLESS we're exporting from more than one stage at a time.
            # In that case, the caches help avoid unnecessary work (except for the last stage)
            self.featureSelectionApplets[-1].topLevelOperator.BypassCache.setValue(True)
        else:
            for featureSelectionApplet in self.featureSelectionApplets:
                featureSelectionApplet.topLevelOperator.BypassCache.setValue(True)

        # Record the freeze state of ALL stages before changing any of them.
        # __init__ registers a dirty callback that copies a changed FreezePredictions value
        # to every stage, so unfreezing one stage first would make the following stages
        # report False here, and the original state could not be restored after the export.
        # (Presumably that callback fires as soon as setValue() changes the value.)
        self.freeze_statuses = [
            pcApplet.topLevelOperator.FreezePredictions.value for pcApplet in self.pcApplets
        ]

        # Unfreeze the classifier caches (ensure that we're exporting based on up-to-date labels)
        for pcApplet in self.pcApplets:
            pcApplet.topLevelOperator.FreezePredictions.setValue(False)

    def post_process_entire_export(self):
        """
        Undo the changes made by prepare_for_entire_export().

        Switches the feature caches back on and gives every stage the
        FreezePredictions value that was recorded in ``self.freeze_statuses``.
        """
        # The caches were bypassed for the export; stop bypassing them.
        for feature_applet in self.featureSelectionApplets:
            feature_applet.topLevelOperator.BypassCache.setValue(False)

        # Put back the recorded freeze state of each stage.
        for stage_applet, was_frozen in zip(self.pcApplets, self.freeze_statuses):
            stage_applet.topLevelOperator.FreezePredictions.setValue(was_frozen)

    def _force_retrain_classifiers(self, projectManager):
        """
        Retrain the classifiers and save the project.

        projectManager: used to save the project once training was requested.
        """
        # Marking the FIRST stage's classifier factory dirty forces a retrain
        # (useful if the stored labels were changed outside ilastik).
        first_stage_op = self.pcApplets[0].topLevelOperator
        first_stage_op.opTrain.ClassifierFactory.setDirty()

        # Every classifier cache has to be unfrozen.
        for stage_applet in self.pcApplets:
            stage_applet.topLevelOperator.FreezePredictions.setValue(False)

        # Asking for the LAST stage's classifier is what forces the training;
        # the returned value itself is not needed.
        _ = self.pcApplets[-1].topLevelOperator.Classifier.value

        # Write the new classifiers to the project file.
        projectManager.saveProject(force_all_save=False)

    def menus(self):
        """
        Overridden from Workflow base class

        Returns a one-element list containing the "Autocontext Utilities" menu,
        whose single action starts distribute_labels_from_current_stage().
        """
        from PyQt5.QtWidgets import QMenu

        # The menu must be retained as a member, or else the reference disappears and the menu is deleted.
        self._autocontext_menu = QMenu("Autocontext Utilities")
        self._autocontext_menu.addAction("Distribute Labels...").triggered.connect(
            self.distribute_labels_from_current_stage
        )
        return [self._autocontext_menu]
        
    def distribute_labels_from_current_stage(self):
        """
        Distribute labels from the currently viewed stage across all other stages.

        Asks the user (via get_label_distribution_settings()) for the destination
        stages and the mode.  In "copy" mode every destination stage receives all
        labels of the current stage.  In "partition" mode each labeled pixel is
        assigned to exactly one randomly chosen destination stage.

        Does nothing if the active applet is not one of the pixel classification
        applets (an error box is shown), if the user cancels the dialog, or if
        no destination stage was selected.
        """
        # Late import.
        # (Don't import PyQt in headless mode.)
        from PyQt5.QtWidgets import QMessageBox
        current_applet = self._applets[self.shell.currentAppletIndex]
        if current_applet not in self.pcApplets:
            QMessageBox.critical(self.shell, "Wrong page selected", "The currently active page isn't a Training page.")
            return

        current_stage_index = self.pcApplets.index(current_applet)
        destination_stage_indexes, partition = self.get_label_distribution_settings( current_stage_index,
                                                                                     num_stages=len(self.pcApplets))
        if not destination_stage_indexes:
            # None: the user cancelled.
            # Empty list: no stage was checked, so there is nowhere to send labels.
            # (In partition mode, np.random.choice() would raise ValueError for an empty list.)
            return

        opCurrentPixelClassification = current_applet.topLevelOperator
        num_current_stage_classes = len(opCurrentPixelClassification.LabelNames.value)

        # Before we get started, make sure the destination stages have the necessary label classes:
        # overwrite the first N entries of the destination's colors and names with the source's.
        for stage_index in destination_stage_indexes:
            # Get this stage's OpPixelClassification
            opPc = self.pcApplets[stage_index].topLevelOperator

            for slot_name in ('LabelColors', 'PmapColors', 'LabelNames'):
                source_values = getattr(opCurrentPixelClassification, slot_name).value
                destination_slot = getattr(opPc, slot_name)
                new_values = list(destination_slot.value)
                new_values[:num_current_stage_classes] = source_values[:num_current_stage_classes]
                destination_slot.setValue(new_values)

        # For each lane, copy over the labels from the source stage to the destination stages
        for lane_index in range(len(opCurrentPixelClassification.InputImages)):
            opPcLane = opCurrentPixelClassification.getLane(lane_index)

            # Gather all the labels for this lane
            blockwise_labels = {}
            for block_slicing in opPcLane.NonzeroLabelBlocks.value:
                # Convert from slicing to roi-tuple so we can hash it in a dict key
                block_roi = sliceToRoi( block_slicing, opPcLane.InputImages.meta.shape )
                block_roi = tuple(map(tuple, block_roi))
                blockwise_labels[block_roi] = opPcLane.LabelImages[block_slicing].wait()

            if partition and current_stage_index in destination_stage_indexes:
                # Clear all labels in the current lane, since we'll be overwriting it with a subset of labels
                # FIXME: We could implement a fast function for this in OpCompressedUserLabelArray...
                for label_value in range(1, num_current_stage_classes+1):
                    opPcLane.opLabelPipeline.opLabelArray.clearLabel(label_value)

            # Now redistribute those labels across all destination stages
            for block_roi, block_labels in blockwise_labels.items():
                nonzero_coords = block_labels.nonzero()

                if partition:
                    # Draw one destination stage per labeled pixel of this block.
                    num_labels = len(nonzero_coords[0])
                    destination_stage_map = np.random.choice(destination_stage_indexes, (num_labels,))

                for stage_index in destination_stage_indexes:
                    if not partition:
                        this_stage_block_labels = block_labels
                    else:
                        # Divide into disjoint partitions
                        # Find the coordinates of the labels destined for this stage
                        this_stage_coords = np.transpose(nonzero_coords)[destination_stage_map == stage_index]
                        this_stage_coords = tuple(this_stage_coords.transpose())

                        # Extract only the labels destined for this stage
                        this_stage_block_labels = np.zeros_like(block_labels)
                        this_stage_block_labels[this_stage_coords] = block_labels[this_stage_coords]

                    # Get the current lane's view of this stage's OpPixelClassification
                    opPc = self.pcApplets[stage_index].topLevelOperator.getLane(lane_index)

                    # Inject
                    opPc.LabelInputs[roiToSlice(*block_roi)] = this_stage_block_labels
    
    @staticmethod
    def get_label_distribution_settings(source_stage_index, num_stages):
        """
        Show a modal dialog asking where the labels of a stage should go.

        source_stage_index: 0-based index of the stage the labels come from.
        num_stages: total number of stages to offer as destinations.

        Returns a tuple (destination_stage_indexes, partition): the list of
        0-based stage indexes the user checked, and True if "Partition" mode
        was chosen (False for "Copy").  Returns (None, None) if the dialog
        was rejected.
        """
        # Late import.
        # (Don't import PyQt in headless mode.)
        from PyQt5.QtWidgets import QDialog, QVBoxLayout

        class LabelDistributionOptionsDlg( QDialog ):
            """
            Small dialog with one checkbox per stage (the destinations) and
            two radio buttons (copy or partition the labels).
            """
            def __init__(self, source_stage_index, num_stages, *args, **kwargs):
                super(LabelDistributionOptionsDlg, self).__init__(*args, **kwargs)

                from PyQt5.QtCore import Qt
                from PyQt5.QtWidgets import QGroupBox, QCheckBox, QRadioButton, QDialogButtonBox

                source_stage_number = source_stage_index + 1
                self.setWindowTitle("Distributing from Stage {}".format(source_stage_number))

                # One checkbox per stage, labeled with the 1-based stage number.
                self.stage_checkboxes = [
                    QCheckBox("Stage {}".format(stage_number))
                    for stage_number in range(1, num_stages + 1)
                ]

                # The source stage itself is pre-selected as a destination.
                self.stage_checkboxes[source_stage_index].setChecked(True)

                checkbox_column = QVBoxLayout()
                for stage_checkbox in self.stage_checkboxes:
                    checkbox_column.addWidget(stage_checkbox)

                destinations_box = QGroupBox("Send labels from Stage {} to:".format(source_stage_number), self)
                destinations_box.setLayout(checkbox_column)

                # Distribution mode; "Partition" is the default.
                self.copy_button = QRadioButton("Copy", self)
                self.partition_button = QRadioButton("Partition", self)
                self.partition_button.setChecked(True)

                mode_column = QVBoxLayout()
                for mode_button in (self.copy_button, self.partition_button):
                    mode_column.addWidget(mode_button)

                mode_box = QGroupBox("Distribution Mode", self)
                mode_box.setLayout(mode_column)

                ok_cancel_buttons = QDialogButtonBox( Qt.Horizontal, parent=self )
                ok_cancel_buttons.setStandardButtons( QDialogButtonBox.Ok | QDialogButtonBox.Cancel )
                ok_cancel_buttons.accepted.connect( self.accept )
                ok_cancel_buttons.rejected.connect( self.reject )

                main_column = QVBoxLayout()
                for section in (destinations_box, mode_box, ok_cancel_buttons):
                    main_column.addWidget(section)
                self.setLayout(main_column)

            def distribution_mode(self):
                """Return "copy" or "partition", depending on the checked radio button."""
                for mode_button, mode_name in ((self.copy_button, "copy"),
                                               (self.partition_button, "partition")):
                    if mode_button.isChecked():
                        return mode_name
                assert False, "Shouldn't get here."

            def destination_stages(self):
                """
                Return the list of stage_indexes (0-based) that the user checked.
                """
                checked_indexes = []
                for stage_index, stage_checkbox in enumerate(self.stage_checkboxes):
                    if stage_checkbox.isChecked():
                        checked_indexes.append(stage_index)
                return checked_indexes

        options_dialog = LabelDistributionOptionsDlg( source_stage_index, num_stages )
        if options_dialog.exec_() == QDialog.Rejected:
            return None, None

        chosen_stage_indexes = options_dialog.destination_stages()
        use_partition = options_dialog.distribution_mode() == "partition"
        return chosen_stage_indexes, use_partition
class PixelClassificationWorkflow(Workflow):
    """
    Workflow that combines the project metadata, data selection, feature selection,
    pixel classification, prediction export and batch processing applets.
    """
    
    workflowName = "Pixel Classification"
    workflowDescription = "This is obviously self-explanatory."
    defaultAppletIndex = 1 # show DataSelection by default
    
    # Indexes of the dataset roles.  They are used to index the data selection
    # operator's ImageGroup/DatasetGroup slots and correspond to the entries of ROLE_NAMES.
    DATA_ROLE_RAW = 0
    DATA_ROLE_PREDICTION_MASK = 1
    ROLE_NAMES = ['Raw Data', 'Prediction Mask']
    # Names of the export sources.  connectLane() connects one export input per entry, in this order.
    EXPORT_NAMES = ['Probabilities', 'Simple Segmentation', 'Uncertainty', 'Features']
    
    @property
    def applets(self):
        """The list of applets this workflow exposes to the shell."""
        applet_list = self._applets
        return applet_list

    @property
    def imageNameListSlot(self):
        """The ImageName slot of the data selection applet's top-level operator."""
        op_data_selection = self.dataSelectionApplet.topLevelOperator
        return op_data_selection.ImageName

    def __init__(self, shell, headless, workflow_cmdline_args, project_creation_args, *args, **kwargs):
        """
        Parse the workflow-specific command-line options and instantiate all applets.

        shell, headless: forwarded unchanged to the Workflow base class.
        workflow_cmdline_args: the command-line args for the current session.
        project_creation_args: the args that were saved to the project file when the
                               project was first created.  The filter implementation is
                               always taken from these.
        Remaining positional/keyword args are forwarded to the Workflow base class.

        Session args that are not recognized here are offered to the data export
        applet and then to the batch processing applet; anything left over is logged.
        """
        # Create a graph to be shared by all operators
        graph = Graph()
        super( PixelClassificationWorkflow, self ).__init__( shell, headless, workflow_cmdline_args, project_creation_args, graph=graph, *args, **kwargs )
        self.stored_classifer = None
        self._applets = []
        self._workflow_cmdline_args = workflow_cmdline_args
        # Parse workflow-specific command-line args
        parser = argparse.ArgumentParser()
        parser.add_argument('--filter', help="pixel feature filter implementation.", choices=['Original', 'Refactored', 'Interpolated'], default='Original')
        parser.add_argument('--print-labels-by-slice', help="Print the number of labels for each Z-slice of each image.", action="store_true")
        parser.add_argument('--label-search-value', help="If provided, only this value is considered when using --print-labels-by-slice", default=0, type=int)
        parser.add_argument('--generate-random-labels', help="Add random labels to the project file.", action="store_true")
        parser.add_argument('--random-label-value', help="The label value to use injecting random labels", default=1, type=int)
        parser.add_argument('--random-label-count', help="The number of random labels to inject via --generate-random-labels", default=2000, type=int)
        parser.add_argument('--retrain', help="Re-train the classifier based on labels stored in project file, and re-save.", action="store_true")
        parser.add_argument('--tree-count', help='Number of trees for Vigra RF classifier.', type=int)
        parser.add_argument('--variable-importance-path', help='Location of variable-importance table.', type=str)
        parser.add_argument('--label-proportion', help='Proportion of feature-pixels used to train the classifier.', type=float)

        # Parse the creation args: These were saved to the project file when this project was first created.
        parsed_creation_args, _ = parser.parse_known_args(project_creation_args)
        self.filter_implementation = parsed_creation_args.filter

        # For the session args, --filter must only count if the user explicitly passed it.
        # With the 'Original' default still in place, every session on a project created
        # with a different filter would trigger the "Ignoring new --filter" error below.
        parser.set_defaults(filter=None)

        # Parse the cmdline args for the current session.
        parsed_args, unused_args = parser.parse_known_args(workflow_cmdline_args)
        self.print_labels_by_slice = parsed_args.print_labels_by_slice
        self.label_search_value = parsed_args.label_search_value
        self.generate_random_labels = parsed_args.generate_random_labels
        self.random_label_value = parsed_args.random_label_value
        self.random_label_count = parsed_args.random_label_count
        self.retrain = parsed_args.retrain
        self.tree_count = parsed_args.tree_count
        self.variable_importance_path = parsed_args.variable_importance_path
        self.label_proportion = parsed_args.label_proportion

        if parsed_args.filter and parsed_args.filter != parsed_creation_args.filter:
            logger.error("Ignoring new --filter setting.  Filter implementation cannot be changed after initial project creation.")

        # Applets for training (interactive) workflow
        self.projectMetadataApplet = ProjectMetadataApplet()

        self.dataSelectionApplet = self.createDataSelectionApplet()
        opDataSelection = self.dataSelectionApplet.topLevelOperator

        # see role constants, above
        opDataSelection.DatasetRoles.setValue( PixelClassificationWorkflow.ROLE_NAMES )

        self.featureSelectionApplet = self.createFeatureSelectionApplet()

        self.pcApplet = self.createPixelClassificationApplet()
        opClassify = self.pcApplet.topLevelOperator

        self.dataExportApplet = PixelClassificationDataExportApplet(self, "Prediction Export")
        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.PmapColors.connect( opClassify.PmapColors )
        opDataExport.LabelNames.connect( opClassify.LabelNames )
        opDataExport.WorkingDirectory.connect( opDataSelection.WorkingDirectory )
        opDataExport.SelectionNames.setValue( self.EXPORT_NAMES )

        # Expose for shell
        self._applets.append(self.projectMetadataApplet)
        self._applets.append(self.dataSelectionApplet)
        self._applets.append(self.featureSelectionApplet)
        self._applets.append(self.pcApplet)
        self._applets.append(self.dataExportApplet)

        # Hooks that un-freeze predictions for the duration of an export.
        self.dataExportApplet.prepare_for_entire_export = self.prepare_for_entire_export
        self.dataExportApplet.post_process_entire_export = self.post_process_entire_export

        self.batchProcessingApplet = BatchProcessingApplet(self,
                                                           "Batch Processing",
                                                           self.dataSelectionApplet,
                                                           self.dataExportApplet)

        self._applets.append(self.batchProcessingApplet)
        if unused_args:
            # We parse the export setting args first.  All remaining args are considered input files by the input applet.
            self._batch_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args( unused_args )
            self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args( unused_args )
        else:
            self._batch_input_args = None
            self._batch_export_args = None

        if unused_args:
            # Logger.warn() is a deprecated alias; use warning().
            logger.warning("Unused command-line args: {}".format( unused_args ))

    def createDataSelectionApplet(self):
        """
        Build the DataSelectionApplet used by this workflow.
        Subclasses may override this if they need to construct
        the applet with different parameters.
        """
        instructions = "Select your input data using the 'Raw Data' tab shown on the right"
        applet_options = dict(supportIlastik05Import=True, instructionText=instructions)
        return DataSelectionApplet(self, "Input Data", "Input Data", **applet_options)


    def createFeatureSelectionApplet(self):
        """
        Build the feature selection applet, configured with this workflow's filter implementation.
        Subclasses may override this to supply their own applet type, as long as it offers
        the same interface as the regular FeatureSelectionApplet.
        """
        feature_applet = FeatureSelectionApplet(
            self, "Feature Selection", "FeatureSelections", self.filter_implementation
        )
        return feature_applet

    def createPixelClassificationApplet(self):
        """
        Build the pixel classification applet.
        Subclasses may override this to supply their own applet type, as long as it offers
        the same interface as the regular PixelClassificationApplet.
        """
        classification_applet = PixelClassificationApplet(self, "PixelClassification")
        return classification_applet

    def prepareForNewLane(self, laneIndex):
        """
        Overridden from Workflow base class.
        Called immediately before a new lane is added to the workflow.

        Adding a lane propagates dirty notifications through the whole graph, which marks
        the classifier 'dirty' even though it is still usable.  Remember the current
        classifier (if it is ready and clean) so handleNewLanesAdded() can restore it.
        """
        classifier_cache = self.pcApplet.topLevelOperator.classifier_cache
        classifier_is_usable = classifier_cache.Output.ready() and not classifier_cache._dirty
        self.stored_classifer = classifier_cache.Output.value if classifier_is_usable else None
        
    def handleNewLanesAdded(self):
        """
        Overridden from Workflow base class.
        Called immediately after a new lane is added to the workflow and initialized.

        Puts the classifier remembered by prepareForNewLane() (if any) back into the
        classifier cache, then drops the stored reference.
        """
        remembered_classifier = self.stored_classifer
        if not remembered_classifier:
            return
        self.pcApplet.topLevelOperator.classifier_cache.forceValue(remembered_classifier)
        # Release reference
        self.stored_classifer = None

    def connectLane(self, laneIndex):
        """
        Connect the per-lane operators of all applets for the lane with the given index:
        input data -> features -> classification -> export.
        """
        # Per-lane views of each applet's top-level operator
        data_lane = self.dataSelectionApplet.topLevelOperator.getLane(laneIndex)
        features_lane = self.featureSelectionApplet.topLevelOperator.getLane(laneIndex)
        classify_lane = self.pcApplet.topLevelOperator.getLane(laneIndex)
        export_lane = self.dataExportApplet.topLevelOperator.getLane(laneIndex)

        # The input image feeds the feature operator and (for display) the classification operator.
        features_lane.InputImage.connect(data_lane.Image)
        classify_lane.InputImages.connect(data_lane.Image)

        # The prediction mask is only hooked up in debug mode.
        if ilastik_config.getboolean('ilastik', 'debug'):
            classify_lane.PredictionMasks.connect(data_lane.ImageGroup[self.DATA_ROLE_PREDICTION_MASK])

        # Feature images feed the classification operator (training and prediction).
        classify_lane.FeatureImages.connect(features_lane.OutputImage)
        classify_lane.CachedFeatureImages.connect(features_lane.CachedOutputImage)

        # Training flags (used for GUI restrictions)
        classify_lane.LabelsAllowedFlags.connect(data_lane.AllowLabels)

        # Export: raw data plus one input per entry in EXPORT_NAMES
        raw_role = self.DATA_ROLE_RAW
        export_lane.RawData.connect(data_lane.ImageGroup[raw_role])
        export_lane.RawDatasetInfo.connect(data_lane.DatasetGroup[raw_role])
        export_lane.ConstraintDataset.connect(data_lane.ImageGroup[raw_role])
        export_lane.Inputs.resize(len(self.EXPORT_NAMES))
        export_lane.Inputs[0].connect(classify_lane.HeadlessPredictionProbabilities)
        export_lane.Inputs[1].connect(classify_lane.SimpleSegmentation)
        export_lane.Inputs[2].connect(classify_lane.HeadlessUncertaintyEstimate)
        export_lane.Inputs[3].connect(classify_lane.FeatureImages)

        # Every export input must have been connected above.
        for export_input in export_lane.Inputs:
            assert export_input.partner is not None

    def handleAppletStateUpdateRequested(self):
        """
        Overridden from Workflow base class.
        Invoked after an applet fired its appletStateUpdateRequested signal.

        Enables/disables each applet in the shell depending on which upstream results
        are available, and blocks project changes while any applet is busy.
        """
        def slot_has_data(slot):
            # Ready, and every dimension of the slot's shape is non-zero.
            return slot.ready() and (TinyVector(slot.meta.shape) > 0).all()

        # Without input data, nothing downstream can be ready.
        data_op = self.dataSelectionApplet.topLevelOperator
        input_ready = len(data_op.ImageGroup) > 0 and not self.dataSelectionApplet.busy

        feature_images = self.featureSelectionApplet.topLevelOperator.OutputImage
        features_ready = input_ready and len(feature_images) > 0 and slot_has_data(feature_images[0])

        export_op = self.dataExportApplet.topLevelOperator
        classify_op = self.pcApplet.topLevelOperator

        # A cache that is fixed at its current value but holds no classifier cannot predict.
        classifier_cache = classify_op.classifier_cache
        invalid_classifier = (
            classifier_cache.fixAtCurrent.value
            and classifier_cache.Output.ready()
            and classifier_cache.Output.value is None
        )

        predictions_ready = (
            features_ready
            and not invalid_classifier
            and len(export_op.Inputs) > 0
            and slot_has_data(export_op.Inputs[0][0])
        )

        # Changing the input data or the features during live update causes problems,
        # so those applets are locked while live update is on.
        live_update_active = not classify_op.FreezePredictions.value

        # Nothing may be touched while batch processing is running.
        batch_processing_busy = self.batchProcessingApplet.busy

        set_enabled = self._shell.setAppletEnabled
        set_enabled(self.dataSelectionApplet, not live_update_active and not batch_processing_busy)
        set_enabled(self.featureSelectionApplet, input_ready and not live_update_active and not batch_processing_busy)
        set_enabled(self.pcApplet, features_ready and not batch_processing_busy)
        set_enabled(self.dataExportApplet, predictions_ready and not batch_processing_busy)

        if self.batchProcessingApplet is not None:
            set_enabled(self.batchProcessingApplet, predictions_ready and not batch_processing_busy)

        # Finally, keep the shell from closing the project while any of these applets is busy.
        project_busy = (
            False
            | self.dataSelectionApplet.busy
            | self.featureSelectionApplet.busy
            | self.dataExportApplet.busy
            | self.batchProcessingApplet.busy
        )
        self._shell.enableProjectChanges(not project_busy)

    def onProjectLoaded(self, projectManager):
        """
        Overridden from Workflow base class.  Called by the Project Manager.

        Applies the actions requested via command-line options, in this order:
        random label injection, label printing, classifier factory settings,
        retraining, export configuration and (in headless mode) batch export.

        If the user provided command-line arguments, use them to configure
        the workflow for batch mode and export all results.
        (This workflow's headless mode supports only batch mode for now.)
        """
        opPixelClassification = self.pcApplet.topLevelOperator

        if self.generate_random_labels:
            self._generate_random_labels(self.random_label_count, self.random_label_value)
            logger.info("Saving project...")
            self._shell.projectManager.saveProject()
            logger.info("Done.")

        if self.print_labels_by_slice:
            self._print_labels_by_slice( self.label_search_value )

        if self._headless:
            # In headless mode, let's see the messages from the training operator.
            logging.getLogger("lazyflow.operators.classifierOperators").setLevel(logging.DEBUG)

        # Apply classifier factory overrides; the factory is only looked up if needed.
        if self.variable_importance_path or self.tree_count or self.label_proportion:
            classifier_factory = opPixelClassification.opTrain.ClassifierFactory.value
            if self.variable_importance_path:
                classifier_factory.set_variable_importance_path( self.variable_importance_path )
            if self.tree_count:
                classifier_factory.set_num_trees( self.tree_count )
            if self.label_proportion:
                classifier_factory.set_label_proportion( self.label_proportion )

        if self.tree_count or self.label_proportion:
            opPixelClassification.ClassifierFactory.setDirty()

        if self.retrain:
            # Cause the classifier to be dirty so it is forced to retrain.
            # (useful if the stored labels were changed outside ilastik)
            opPixelClassification.opTrain.ClassifierFactory.setDirty()

            # Request the classifier, which forces training
            opPixelClassification.FreezePredictions.setValue(False)
            _ = opPixelClassification.Classifier.value

            # store new classifier to project file
            projectManager.saveProject(force_all_save=False)

        # Configure the data export operator.
        if self._batch_export_args:
            self.dataExportApplet.configure_operator_with_parsed_args( self._batch_export_args )

        if self._batch_input_args and opPixelClassification.classifier_cache._dirty:
            # Logger.warn() is a deprecated alias; use warning().
            logger.warning("Your project file has no classifier.  A new classifier will be trained for this run.")

        if self._headless and self._batch_input_args and self._batch_export_args:
            logger.info("Beginning Batch Processing")
            self.batchProcessingApplet.run_export_from_parsed_args(self._batch_input_args)
            logger.info("Completed Batch Processing")

    def prepare_for_entire_export(self):
        """
        Export hook: remember the current FreezePredictions setting, then un-freeze predictions.
        """
        freeze_slot = self.pcApplet.topLevelOperator.FreezePredictions
        self.freeze_status = freeze_slot.value
        freeze_slot.setValue(False)

    def post_process_entire_export(self):
        """
        Export hook: restore the FreezePredictions setting saved by prepare_for_entire_export().
        """
        freeze_slot = self.pcApplet.topLevelOperator.FreezePredictions
        freeze_slot.setValue(self.freeze_status)

    def _print_labels_by_slice(self, search_value):
        """
        Iterate over each label image in the project and log the number of labels present on each Z-slice of the image.
        (This is a special feature requested by the FlyEM proofreaders.)

        search_value: If non-zero, only pixels with exactly this label value are counted.
                      Otherwise every non-zero pixel is counted.

        Images without a 'z' axis are reported as an error and skipped.
        """
        opTopLevelClassify = self.pcApplet.topLevelOperator
        project_label_count = 0
        for image_index, label_slot in enumerate(opTopLevelClassify.LabelImages):
            tagged_shape = label_slot.meta.getTaggedShape()
            if 'z' not in tagged_shape:
                logger.error("Can't print label counts by Z-slices.  Image #{} has no Z-dimension.".format(image_index))
                continue

            logger.info("Label counts in Z-slices of Image #{}:".format( image_index ))

            # Position of the 'z' axis.  keys() is wrapped in list() because on Python 3
            # it returns a view object, which has no index() method.
            z_axis = list(tagged_shape.keys()).index('z')

            slicing = [slice(None)] * len(tagged_shape)
            blank_slices = []
            image_label_count = 0
            for z in range(tagged_shape['z']):
                slicing[z_axis] = slice(z, z+1)
                label_slice = label_slot[slicing].wait()
                if search_value:
                    count = (label_slice == search_value).sum()
                else:
                    count = (label_slice != 0).sum()
                if count > 0:
                    logger.info("Z={}: {}".format( z, count ))
                    image_label_count += count
                else:
                    blank_slices.append( z )
            project_label_count += image_label_count
            if len(blank_slices) > 20:
                # Don't list the blank slices if there were a lot of them.
                logger.info("Image #{} has {} blank slices.".format( image_index, len(blank_slices) ))
            elif len(blank_slices) > 0:
                logger.info( "Image #{} has {} blank slices: {}".format( image_index, len(blank_slices), blank_slices ) )
            else:
                logger.info( "Image #{} has no blank slices.".format( image_index ) )
            logger.info( "Total labels for Image #{}: {}".format( image_index, image_label_count ) )
        logger.info( "Total labels for project: {}".format( project_label_count ) )

    
    def _generate_random_labels(self, labels_per_image, label_value):
        """
        Inject random labels into every label image of the project.
        (This is a special feature requested by the FlyEM proofreaders.)

        labels_per_image: Number of distinct pixels to label in each image.
        label_value: The label value written to each chosen pixel.  The list of label
                     names is extended so that it has at least this many entries.

        Raises ValueError if an image has fewer pixels than labels_per_image.

        The random generator is seeded per image, so the result is reproducible.
        """
        logger.info( "Injecting {} labels of value {} into all images.".format( labels_per_image, label_value ) )
        opTopLevelClassify = self.pcApplet.topLevelOperator

        label_names = copy.copy(opTopLevelClassify.LabelNames.value)
        while len(label_names) < label_value:
            label_names.append( "Label {}".format( len(label_names)+1 ) )

        opTopLevelClassify.LabelNames.setValue( label_names )

        for image_index in range(len(opTopLevelClassify.LabelImages)):
            logger.info( "Injecting labels into image #{}".format( image_index ) )
            # For reproducibility of label generation
            SEED = 1
            numpy.random.seed([SEED, image_index])

            label_input_slot = opTopLevelClassify.LabelInputs[image_index]
            label_output_slot = opTopLevelClassify.LabelImages[image_index]

            shape = label_output_slot.meta.shape
            random_labels = numpy.zeros( shape=shape, dtype=numpy.uint8 )
            num_pixels = len(random_labels.flat)
            if labels_per_image > num_pixels:
                # Without this check, the search for a blank pixel below would never terminate.
                raise ValueError( "Can't inject {} labels into image #{}: it has only {} pixels."
                                  .format( labels_per_image, image_index, num_pixels ) )

            current_progress = -1
            for sample_index in range(labels_per_image):
                flat_index = numpy.random.randint(0,num_pixels)
                # Don't overwrite labels placed earlier in this loop.
                # Keep looking until we find a blank pixel
                while random_labels.flat[flat_index]:
                    flat_index = numpy.random.randint(0,num_pixels)
                random_labels.flat[flat_index] = label_value

                # Print progress every 10%.
                # Integer (floor) division keeps this a multiple of 10 on Python 2 and 3 alike;
                # with '/' Python 3 would yield a float that changes on every percent.
                progress = 10 * (10 * sample_index // labels_per_image)
                if progress != current_progress:
                    current_progress = progress
                    sys.stdout.write( "{}% ".format( current_progress ) )
                    sys.stdout.flush()

            sys.stdout.write( "100%\n" )
            # Write into the operator
            label_input_slot[fullSlicing(shape)] = random_labels

        logger.info( "Done injecting labels" )


    def getHeadlessOutputSlot(self, slotId):
        """
        Look up a headless output slot by its name.
        (Only special cluster scripts use this; the regular app does not.)

        Raises Exception if slotId is not one of the known names.
        """
        # "Regular" slots: computed on the images the user selected as input data
        if slotId == "Predictions":
            return self.pcApplet.topLevelOperator.HeadlessPredictionProbabilities
        if slotId == "PredictionsUint8":
            return self.pcApplet.topLevelOperator.HeadlessUint8PredictionProbabilities

        # "Batch" slots: computed on the images the user selected as batch inputs
        if slotId == "BatchPredictions":
            return self.opBatchPredictionPipeline.HeadlessPredictionProbabilities
        if slotId == "BatchPredictionsUint8":
            return self.opBatchPredictionPipeline.HeadlessUint8PredictionProbabilities

        raise Exception("Unknown headless output slot")
class ObjectClassificationWorkflow(Workflow):
    # Common base for the object classification workflows.  __init__ derives
    # `input_types` from which concrete subclass (ObjectClassificationWorkflowPixel,
    # ...Binary or ...Prediction) is being instantiated.
    workflowName = "Object Classification Workflow Base"
    defaultAppletIndex = 1 # show DataSelection by default

    def __init__(self, shell, headless,
                 workflow_cmdline_args,
                 project_creation_args,
                 *args, **kwargs):
        """
        Parse the workflow-specific command-line options and instantiate the applets.

        shell, headless: forwarded unchanged to the Workflow base class.
        workflow_cmdline_args: the command-line args for the current session.
        project_creation_args: the args the project was created with.  The
                               'fillmissing' and 'filter' settings are always taken from these.
        kwargs: may contain 'graph', a graph to use instead of creating a new one.
                Everything else is forwarded to the Workflow base class.

        Input applets are created via self.setupInputs().
        """
        # Use the caller's graph if one was given; only create a new one otherwise.
        graph = kwargs.pop('graph') if 'graph' in kwargs else Graph()
        super(ObjectClassificationWorkflow, self).__init__(shell, headless, workflow_cmdline_args, project_creation_args, graph=graph, *args, **kwargs)
        self.stored_pixel_classifier = None
        self.stored_object_classifier = None
        # prepareForNewLane() and handleNewLanesAdded() store the classifiers under these
        # (misspelled) names, so they must exist even if prepareForNewLane() never set them.
        self.stored_pixel_classifer = None
        self.stored_object_classifer = None

        # Parse workflow-specific command-line args
        parser = argparse.ArgumentParser()
        parser.add_argument('--fillmissing', help="use 'fill missing' applet with chosen detection method", choices=['classic', 'svm', 'none'], default='none')
        parser.add_argument('--filter', help="pixel feature filter implementation.", choices=['Original', 'Refactored', 'Interpolated'], default='Original')
        parser.add_argument('--nobatch', help="do not append batch applets", action='store_true', default=False)

        parsed_creation_args, _ = parser.parse_known_args(project_creation_args)

        self.fillMissing = parsed_creation_args.fillmissing
        self.filter_implementation = parsed_creation_args.filter

        parsed_args, unused_args = parser.parse_known_args(workflow_cmdline_args)
        if parsed_args.fillmissing != 'none' and parsed_creation_args.fillmissing != parsed_args.fillmissing:
            logger.error( "Ignoring --fillmissing cmdline arg.  Can't specify a different fillmissing setting after the project has already been created." )

        if parsed_args.filter != 'Original' and parsed_creation_args.filter != parsed_args.filter:
            logger.error( "Ignoring --filter cmdline arg.  Can't specify a different filter setting after the project has already been created." )

        self.batch = not parsed_args.nobatch

        self._applets = []

        self.pcApplet = None
        self.projectMetadataApplet = ProjectMetadataApplet()
        self._applets.append(self.projectMetadataApplet)

        self.setupInputs()

        if self.fillMissing != 'none':
            self.fillMissingSlicesApplet = FillMissingSlicesApplet(
                self, "Fill Missing Slices", "Fill Missing Slices", self.fillMissing)
            self._applets.append(self.fillMissingSlicesApplet)

        # The kind of input data depends on the concrete subclass.
        if isinstance(self, ObjectClassificationWorkflowPixel):
            self.input_types = 'raw'
        elif isinstance(self, ObjectClassificationWorkflowBinary):
            self.input_types = 'raw+binary'
        elif isinstance( self, ObjectClassificationWorkflowPrediction ):
            self.input_types = 'raw+pmaps'

        # our main applets
        self.objectExtractionApplet = ObjectExtractionApplet(workflow=self, name = "Object Feature Selection")
        self.objectClassificationApplet = ObjectClassificationApplet(workflow=self)
        self.dataExportApplet = ObjectClassificationDataExportApplet(self, "Object Information Export")
        self.dataExportApplet.set_exporting_operator(self.objectClassificationApplet.topLevelOperator)

        # Customization hooks
        self.dataExportApplet.prepare_for_entire_export = self.prepare_for_entire_export
        #self.dataExportApplet.prepare_lane_for_export = self.prepare_lane_for_export
        self.dataExportApplet.post_process_lane_export = self.post_process_lane_export
        self.dataExportApplet.post_process_entire_export = self.post_process_entire_export

        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.WorkingDirectory.connect( self.dataSelectionApplet.topLevelOperator.WorkingDirectory )

        # See EXPORT_SELECTION_PREDICTIONS and EXPORT_SELECTION_PROBABILITIES, above
        export_selection_names = ['Object Predictions',
                                  'Object Probabilities',
                                  'Blockwise Object Predictions',
                                  'Blockwise Object Probabilities']
        if self.input_types == 'raw':
            # Re-configure to add the pixel probabilities option
            # See EXPORT_SELECTION_PIXEL_PROBABILITIES, above
            export_selection_names.append( 'Pixel Probabilities' )
        opDataExport.SelectionNames.setValue( export_selection_names )

        self._batch_export_args = None
        self._batch_input_args = None
        self._export_args = None
        self.batchProcessingApplet = None
        if self.batch:
            self.batchProcessingApplet = BatchProcessingApplet(self,
                                                               "Batch Processing",
                                                               self.dataSelectionApplet,
                                                               self.dataExportApplet)

            if unused_args:
                # Additional export args (specific to the object classification workflow)
                export_arg_parser = argparse.ArgumentParser()
                export_arg_parser.add_argument( "--table_filename", help="The location to export the object feature/prediction CSV file.", required=False )
                export_arg_parser.add_argument( "--export_object_prediction_img", action="store_true" )
                export_arg_parser.add_argument( "--export_object_probability_img", action="store_true" )

                # TODO: Support this, too, someday?
                #export_arg_parser.add_argument( "--export_object_label_img", action="store_true" )

                if self.input_types == 'raw':
                    export_arg_parser.add_argument( "--export_pixel_probability_img", action="store_true" )
                self._export_args, unused_args = export_arg_parser.parse_known_args(unused_args)
                # The option is only registered for 'raw' inputs, so the attribute may be
                # missing from the namespace; getattr() avoids an AttributeError in that case.
                self._export_args.export_pixel_probability_img = getattr(self._export_args, 'export_pixel_probability_img', None) or None

                # We parse the export setting args first.  All remaining args are considered input files by the input applet.
                self._batch_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args( unused_args )
                self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args( unused_args )

                # For backwards compatibility, translate these special args into the standard syntax
                if self._export_args.export_object_prediction_img:
                    self._batch_input_args.export_source = "Object Predictions"
                if self._export_args.export_object_probability_img:
                    self._batch_input_args.export_source = "Object Probabilities"
                if self._export_args.export_pixel_probability_img:
                    self._batch_input_args.export_source = "Pixel Probabilities"


        self.blockwiseObjectClassificationApplet = BlockwiseObjectClassificationApplet(
            self, "Blockwise Object Classification", "Blockwise Object Classification")

        self._applets.append(self.objectExtractionApplet)
        self._applets.append(self.objectClassificationApplet)
        self._applets.append(self.dataExportApplet)
        if self.batchProcessingApplet:
            self._applets.append(self.batchProcessingApplet)
        self._applets.append(self.blockwiseObjectClassificationApplet)

        if unused_args:
            # Logger.warn() is a deprecated alias; use warning().
            logger.warning("Unused command-line args: {}".format( unused_args ))

    @property
    def applets(self):
        """All applets of this workflow, in the order they were registered."""
        registered_applets = self._applets
        return registered_applets

    @property
    def imageNameListSlot(self):
        """The ImageName slot of the data selection applet's top-level operator."""
        data_selection_op = self.dataSelectionApplet.topLevelOperator
        return data_selection_op.ImageName

    def prepareForNewLane(self, laneIndex):
        """
        Remember the current pixel and object classifiers before a new lane is added,
        so that handleNewLanesAdded() can restore them afterwards.

        A classifier is only remembered if its cache output is ready and not dirty;
        otherwise None is stored.  Without a pixel classification applet, the stored
        pixel classifier is always None.
        """
        # Always assign the attribute, even without a pixel classification applet:
        # handleNewLanesAdded() reads it unconditionally.
        self.stored_pixel_classifer = None
        if self.pcApplet:
            opPixelClassification = self.pcApplet.topLevelOperator
            if opPixelClassification.classifier_cache.Output.ready() and \
               not opPixelClassification.classifier_cache._dirty:
                self.stored_pixel_classifer = opPixelClassification.classifier_cache.Output.value

        opObjectClassification = self.objectClassificationApplet.topLevelOperator
        if opObjectClassification.classifier_cache.Output.ready() and \
           not opObjectClassification.classifier_cache._dirty:
            self.stored_object_classifer = opObjectClassification.classifier_cache.Output.value
        else:
            self.stored_object_classifer = None

    def handleNewLanesAdded(self):
        """
        If new lanes were added, then we invalidated our classifiers unnecessarily.
        Here, we restore the classifiers that prepareForNewLane() stored (if any),
        so they don't need to be retrained.
        """
        # The stored_* attributes are only guaranteed to exist once prepareForNewLane() has
        # assigned them, so read them with a default instead of risking an AttributeError.
        stored_pixel_classifier = getattr(self, 'stored_pixel_classifer', None)
        if stored_pixel_classifier:
            opPixelClassification = self.pcApplet.topLevelOperator
            opPixelClassification.classifier_cache.forceValue(stored_pixel_classifier)
            # Release reference
            self.stored_pixel_classifer = None

        stored_object_classifier = getattr(self, 'stored_object_classifer', None)
        if stored_object_classifier:
            opObjectClassification = self.objectClassificationApplet.topLevelOperator
            opObjectClassification.classifier_cache.forceValue(stored_object_classifier)
            # Release reference
            self.stored_object_classifer = None

    def connectLane(self, laneIndex):
        """
        Wire up the per-lane operators of every applet for lane ``laneIndex``.

        The raw and binary slots returned by ``self.connectInputs`` feed object
        extraction and object classification.  The interactive and blockwise
        classification results become the inputs of the data export operator,
        and the blockwise classifier takes its settings from the interactive
        object classification operator.
        """
        raw_slot, binary_slot = self.connectInputs(laneIndex)

        data_lane = self.dataSelectionApplet.topLevelOperator.getLane(laneIndex)

        extraction_lane = self.objectExtractionApplet.topLevelOperator.getLane(laneIndex)
        classification_lane = self.objectClassificationApplet.topLevelOperator.getLane(laneIndex)
        export_lane = self.dataExportApplet.topLevelOperator.getLane(laneIndex)
        blockwise_lane = self.blockwiseObjectClassificationApplet.topLevelOperator.getLane(laneIndex)

        # Object extraction works directly on the two input images.
        extraction_lane.RawImage.connect(raw_slot)
        extraction_lane.BinaryImage.connect(binary_slot)

        # Object classification: the input images ...
        classification_lane.RawImages.connect(raw_slot)
        classification_lane.LabelsAllowedFlags.connect(data_lane.AllowLabels)
        classification_lane.BinaryImages.connect(binary_slot)

        # ... plus the results of object extraction.
        classification_lane.SegmentationImages.connect(extraction_lane.LabelImage)
        classification_lane.ObjectFeatures.connect(extraction_lane.RegionFeatures)
        classification_lane.ComputedFeatureNames.connect(extraction_lane.Features)

        # Data export: four export sources are always available.
        export_lane.RawData.connect(data_lane.ImageGroup[0])
        export_lane.RawDatasetInfo.connect(data_lane.DatasetGroup[0])
        export_lane.Inputs.resize(4)
        export_lane.Inputs[EXPORT_SELECTION_PREDICTIONS].connect(classification_lane.UncachedPredictionImages)
        export_lane.Inputs[EXPORT_SELECTION_PROBABILITIES].connect(classification_lane.ProbabilityChannelImage)
        export_lane.Inputs[EXPORT_SELECTION_BLOCKWISE_PREDICTIONS].connect(blockwise_lane.PredictionImage)
        export_lane.Inputs[EXPORT_SELECTION_BLOCKWISE_PROBABILITIES].connect(blockwise_lane.ProbabilityChannelImage)

        if self.input_types == 'raw':
            # With raw input, the pixel prediction probabilities are exportable as well.
            export_lane.Inputs.resize(5)
            # Take them from the thresholding operator's input: that data has already
            # been through the Op5 operator, and everything in the export operator
            # must have matching spatial dimensions.
            threshold_lane = self.thresholdingApplet.topLevelOperator.getLane(laneIndex)
            export_lane.Inputs[EXPORT_SELECTION_PIXEL_PROBABILITIES].connect(threshold_lane.InputImage)

        classification_lane = self.objectClassificationApplet.topLevelOperator.getLane(laneIndex)
        blockwise_lane = self.blockwiseObjectClassificationApplet.topLevelOperator.getLane(laneIndex)

        # Blockwise classification reuses the images, classifier and feature
        # selection of the interactive object classification operator.
        blockwise_lane.RawImage.connect(classification_lane.RawImages)
        blockwise_lane.BinaryImage.connect(classification_lane.BinaryImages)
        blockwise_lane.Classifier.connect(classification_lane.Classifier)
        blockwise_lane.LabelsCount.connect(classification_lane.NumLabels)
        blockwise_lane.SelectedFeatures.connect(classification_lane.SelectedFeatures)

        
    def onProjectLoaded(self, projectManager):
        """
        Run batch processing from the command-line arguments in headless mode.

        Does nothing unless the workflow is headless.  In headless mode both
        the batch input and the batch export arguments must be present,
        otherwise a ``RuntimeError`` is raised.  If the project holds no ready
        classifier, an error is logged and the method returns.

        :param projectManager: only its ``currentProjectPath`` is read, for the
            error message.
        :raises RuntimeError: if batch arguments are missing, or if a csv table
            filename was given but table export was never configured.
        """
        if not self._headless:
            return

        have_batch_args = self._batch_input_args and self._batch_export_args
        if not have_batch_args:
            raise RuntimeError("Currently, this workflow has no batch mode and headless mode support")

        # Check for problems: is the project file ready to use?
        classification_op = self.objectClassificationApplet.topLevelOperator
        if not classification_op.Classifier.ready():
            message = ( "Can't run batch prediction.\n"
                        "Couldn't obtain a classifier from your project file: {}.\n"
                        "Please make sure your project is fully configured with a trained classifier."
                        .format(projectManager.currentProjectPath) )
            logger.error(message)
            return

        # Configure the data export operator.
        batch_export_args = self._batch_export_args
        if batch_export_args:
            self.dataExportApplet.configure_operator_with_parsed_args(batch_export_args)

        export_args = self._export_args
        csv_filename = export_args.table_filename if export_args else None
        if csv_filename:
            # The csv export location is overridden on the command line:
            # apply the new file path to the operator's table export settings.
            settings, selected_features = classification_op.get_table_export_settings()
            if settings is None:
                raise RuntimeError("You can't export the CSV object table unless you configure it in the GUI first.")
            assert 'file path' in settings, "Expected settings dict to contain a 'file path' key.  Did you rename that key?"
            settings['file path'] = csv_filename
            classification_op.configure_table_export_settings(settings, selected_features)

        # Configure the batch data selection operator and run the export.
        batch_input_args = self._batch_input_args
        if batch_input_args and batch_input_args.raw_data:
            logger.info("Beginning Batch Processing")
            self.batchProcessingApplet.run_export_from_parsed_args(batch_input_args)
            logger.info("Completed Batch Processing")


    def prepare_for_entire_export(self):
        """
        Remember the current ``FreezePredictions`` settings and switch them off.

        With frozen predictions the caches would just hand out zeros when
        asked for results.  The previous settings are kept in
        ``self.pc_freeze_status`` (only if ``self.pcApplet`` is set) and
        ``self.oc_freeze_status``.
        """
        if self.pcApplet:
            pixel_op = self.pcApplet.topLevelOperator
            self.pc_freeze_status = pixel_op.FreezePredictions.value
            pixel_op.FreezePredictions.setValue(False)

        object_op = self.objectClassificationApplet.topLevelOperator
        self.oc_freeze_status = object_op.FreezePredictions.value
        object_op.FreezePredictions.setValue(False)


    def post_process_entire_export(self):
        """
        Restore the ``FreezePredictions`` settings remembered in
        ``self.pc_freeze_status`` (only if ``self.pcApplet`` is set) and
        ``self.oc_freeze_status``.
        """
        if self.pcApplet:
            pixel_freeze_slot = self.pcApplet.topLevelOperator.FreezePredictions
            pixel_freeze_slot.setValue(self.pc_freeze_status)

        object_freeze_slot = self.objectClassificationApplet.topLevelOperator.FreezePredictions
        object_freeze_slot.setValue(self.oc_freeze_status)


    def post_process_lane_export(self, lane_index):
        """
        Export the object table of lane ``lane_index`` if table export is configured.

        The export file name gets the dataset nickname as suffix when the raw
        dataset lives on the file system, and the lane index otherwise.  The
        call blocks until the export request has finished.
        """
        # FIXME: This probably only works for the non-blockwise export slot.
        #        We should assert that the user isn't using the blockwise slot.
        classification_op = self.objectClassificationApplet.topLevelOperator
        settings, selected_features = classification_op.get_table_export_settings()
        if not settings:
            return

        raw_dataset_info = self.dataSelectionApplet.topLevelOperator.DatasetGroup[lane_index][0].value
        stored_on_filesystem = raw_dataset_info.location == DatasetInfo.Location.FileSystem
        filename_suffix = raw_dataset_info.nickname if stored_on_filesystem else str(lane_index)

        # FIXME: Even in non-headless mode, we can't show the gui because we're running in a non-main thread.
        #        That's not a huge deal, because there's still a progress bar for the overall export.
        export_request = classification_op.export_object_data(
            lane_index,
            show_gui=False,
            filename_suffix=filename_suffix)
        export_request.wait()

         
    def getHeadlessOutputSlot(self, slotId):
        """
        Return the output slot registered under ``slotId`` for headless use.

        Only ``"BatchPredictionImage"`` is known; it maps to
        ``self.opBatchClassify.PredictionImage``.

        :raises Exception: for any other ``slotId``.
        """
        if slotId != "BatchPredictionImage":
            raise Exception("Unknown headless output slot")
        return self.opBatchClassify.PredictionImage


    def handleAppletStateUpdateRequested(self, upstream_ready=False):
        """
        Overridden from Workflow base class.
        Called when an applet has fired the :py:attr:`Applet.appletStateUpdateRequested`

        Child classes call this method with the result of their own applet
        readiness findings as the ``upstream_ready`` keyword argument.  The
        applets are then enabled one after the other, each one only if
        everything before it is ready.
        """
        # All workflows have these applets in common:
        #   object feature selection, object classification,
        #   object prediction export, blockwise classification,
        #   batch input, batch prediction export
        shell = self._shell

        shell.setAppletEnabled(self.dataSelectionApplet, not self.batchProcessingApplet.busy)

        # Nothing can be touched while batch mode is executing.
        ready_so_far = upstream_ready
        ready_so_far &= not self.batchProcessingApplet.busy
        shell.setAppletEnabled(self.objectExtractionApplet, ready_so_far)

        # Object classification needs at least one selected object feature.
        features_slot = self.objectExtractionApplet.topLevelOperator.Features
        object_features_ready = features_slot.ready() and len(features_slot.value) > 0
        ready_so_far = ready_so_far and object_features_ready
        shell.setAppletEnabled(self.objectClassificationApplet, ready_so_far)

        # Export needs a usable classifier: a cache that is fixed at a ``None``
        # classifier, or fewer than two labels, rules that out.
        classification_op = self.objectClassificationApplet.topLevelOperator
        classifier_cache = classification_op.classifier_cache
        invalid_classifier = ( classifier_cache.fixAtCurrent.value
                               and classifier_cache.Output.ready()
                               and classifier_cache.Output.value is None )

        num_labels_slot = classification_op.NumLabels
        too_few_labels = not num_labels_slot.ready() or num_labels_slot.value < 2
        invalid_classifier = invalid_classifier | too_few_labels

        object_classification_ready = object_features_ready and not invalid_classifier
        ready_so_far = ready_so_far and object_classification_ready
        shell.setAppletEnabled(self.dataExportApplet, ready_so_far)

        if self.batch:
            object_prediction_ready = True  # TODO is that so?
            ready_so_far = ready_so_far and object_prediction_ready

            shell.setAppletEnabled(self.blockwiseObjectClassificationApplet, ready_so_far)
            shell.setAppletEnabled(self.batchProcessingApplet, ready_so_far)

        # Lastly, check for certain "busy" conditions, during which we
        # should prevent the shell from closing the project.
        # TODO implement
        busy = False
        shell.enableProjectChanges(not busy)


    def _inputReady(self, nRoles):
        """
        Tell whether the input data of every lane is ready.

        :param nRoles: number of dataset roles to check in each lane.
        :return: ``True`` if there is at least one lane and, in every lane, the
            image slots of the first ``nRoles`` roles are ready; ``False`` otherwise.
        """
        image_groups = self.dataSelectionApplet.topLevelOperator.ImageGroup
        if len(image_groups) == 0:
            return False

        # Within a lane, every role slot is queried (the inner list is built in
        # full); lanes after the first non-ready one are not queried any more.
        return all(
            all([lane_slots[role].ready() for role in range(nRoles)])
            for lane_slots in image_groups
        )


    def postprocessClusterSubResult(self, roi, result, blockwise_fileset):
        """
        This function is only used by special cluster scripts.

        When the batch-processing mechanism was rewritten, this function broke.
        It could probably be fixed with minor changes.

        For the block that starts at ``roi[0]`` this method

        1. filters the block's region features: objects whose region center is
           not inside the block (i.e. objects in the halo), the first object,
           and objects not predicted as ``POSITIVE_LABEL`` are dropped;
        2. translates the remaining coordinates into file coordinates;
        3. stores the filtered features (pickled) in the block's hdf5 file and
           writes a ``block-pointcloud.csv`` file next to it.

        :param roi: ``(start, stop)`` of the block; only ``roi[0]`` is used.
            Presumably a 5D (t, x, y, z, c) numpy coordinate -- verify against the caller.
        :param result: not used by this implementation.
        :param blockwise_fileset: fileset providing the block description, the
            hdf5 file and the dataset directory of the block.
        """
        # TODO: Here, we hard-code to select from the first lane only.
        opBatchClassify = self.opBatchClassify[0]

        from lazyflow.utility.io.blockwiseFileset import vectorized_pickle_dumps
        # Assume that roi always starts as a multiple of the blockshape
        block_shape = opBatchClassify.get_blockshape()
        assert all(block_shape == blockwise_fileset.description.sub_block_shape), "block shapes don't match"
        assert all((roi[0] % block_shape) == 0), "Sub-blocks must exactly correspond to the blockwise object classification blockshape"
        # Floor division: the block index must stay integral.  (True division
        # would turn it into floats on Python 3; the assertions above guarantee
        # that the division is exact.)
        sub_block_index = roi[0] // blockwise_fileset.description.sub_block_shape

        sub_block_start = sub_block_index
        sub_block_stop = sub_block_start + 1
        sub_block_roi = (sub_block_start, sub_block_stop)

        # FIRST, remove all objects that lie outside the block (i.e. remove the ones in the halo)
        region_features = opBatchClassify.BlockwiseRegionFeatures( *sub_block_roi ).wait()
        region_features_dict = region_features.flat[0]
        region_centers = region_features_dict['Default features']['RegionCenter']

        opBlockPipeline = opBatchClassify._blockPipelines[ tuple(roi[0]) ]

        # Compute the block offset within the image coordinates
        halo_roi = opBlockPipeline._halo_roi

        translated_region_centers = region_centers + halo_roi[0][1:-1]

        # TODO: If this is too slow, vectorize this
        mask = numpy.zeros( region_centers.shape[0], dtype=numpy.bool_ )
        for index, translated_region_center in enumerate(translated_region_centers):
            # FIXME: Here we assume t=0 and c=0
            mask[index] = opBatchClassify.is_in_block( roi[0], (0,) + tuple(translated_region_center) + (0,) )

        # Always exclude the first object (it's the background??)
        mask[0] = False

        # Remove all 'negative' predictions, emit only 'positive' predictions
        # FIXME: Don't hardcode this?
        POSITIVE_LABEL = 2
        objectwise_predictions = opBlockPipeline.ObjectwisePredictions([]).wait()[0]
        assert objectwise_predictions.shape == mask.shape
        mask[objectwise_predictions != POSITIVE_LABEL] = False

        filtered_features = {
            feature_group: {
                feature_name: feature_array[mask]
                for feature_name, feature_array in feature_dict.items()
            }
            for feature_group, feature_dict in region_features_dict.items()
        }

        # SECOND, translate from block-local coordinates to global (file) coordinates.
        # Unfortunately, we've got multiple translations to perform here:
        # Coordinates in the region features are relative to their own block INCLUDING HALO,
        #  so we need to add the start of the block-with-halo as an offset.
        # BUT the image itself may be offset relative to the BlockwiseFileset coordinates
        #  (due to the view_origin setting), so we also need to add an offset for that, too

        # Get the image offset relative to the file coordinates
        image_offset = blockwise_fileset.description.view_origin

        total_offset_5d = halo_roi[0] + image_offset
        total_offset_3d = total_offset_5d[1:-1]

        default_features = filtered_features["Default features"]
        default_features["RegionCenter"] += total_offset_3d
        default_features["Coord<Minimum>"] += total_offset_3d
        default_features["Coord<Maximum>"] += total_offset_3d

        # Finally, write the features to hdf5
        h5File = blockwise_fileset.getOpenHdf5FileForBlock( roi[0] )
        if 'pickled_region_features' in h5File:
            del h5File['pickled_region_features']

        # Must use str dtype.
        # (special_dtype(vlen=...) replaces the deprecated h5py.new_vlen().)
        dtype = h5py.special_dtype(vlen=str)
        dataset = h5File.create_dataset( 'pickled_region_features', shape=(1,), dtype=dtype )
        pickled_features = vectorized_pickle_dumps(numpy.array((filtered_features,)))
        dataset[0] = pickled_features

        object_centers_xyz = default_features["RegionCenter"].astype(int)
        object_min_coords_xyz = default_features["Coord<Minimum>"].astype(int)
        object_max_coords_xyz = default_features["Coord<Maximum>"].astype(int)
        object_sizes = default_features["Count"][:,0].astype(int)

        # Also, write out selected features as a 'point cloud' csv file.
        # (Store the csv file next to this block's h5 file.)
        dataset_directory = blockwise_fileset.getDatasetDirectory(roi[0])
        pointcloud_path = os.path.join( dataset_directory, "block-pointcloud.csv" )

        logger.info("Writing to csv: {}".format( pointcloud_path ))
        with open(pointcloud_path, "w") as fout:
            csv_writer = csv.DictWriter(fout, OUTPUT_COLUMNS, **CSV_FORMAT)
            csv_writer.writeheader()

            # One row per remaining object.
            for center, min_coord, max_coord, size in zip(
                    object_centers_xyz, object_min_coords_xyz, object_max_coords_xyz, object_sizes):
                fields = {}
                fields["x_px"], fields["y_px"], fields["z_px"], = center
                fields["min_x_px"], fields["min_y_px"], fields["min_z_px"], = min_coord
                fields["max_x_px"], fields["max_y_px"], fields["max_z_px"], = max_coord
                fields["size_px"] = size

                csv_writer.writerow( fields )

        logger.info("FINISHED csv export")
예제 #18
0
class StructuredTrackingWorkflowBase( Workflow ):
    workflowName = "Structured Learning Tracking Workflow BASE"

    @property
    def applets(self):
        """Return the list of applets that this workflow has registered."""
        registered_applets = self._applets
        return registered_applets

    @property
    def imageNameListSlot(self):
        """Return the ``ImageName`` slot of the data selection applet's top-level operator."""
        data_selection_op = self.dataSelectionApplet.topLevelOperator
        return data_selection_op.ImageName

    def __init__( self, shell, headless, workflow_cmdline_args, project_creation_args, *args, **kwargs ):
        """
        Create all applets of the structured learning tracking workflow and
        parse the workflow's command-line arguments.

        :param shell: passed on to the ``Workflow`` base class.
        :param headless: passed on to the ``Workflow`` base class.
        :param workflow_cmdline_args: command-line arguments for this workflow
            (may be empty or ``None``).  ``--testFullAnnotations`` is handled
            here; the rest is offered to the tracking export applet and then to
            the batch processing applet.
        :param project_creation_args: passed on to the ``Workflow`` base class.
        :param kwargs: an optional ``graph`` entry is used as the operator
            graph; without it a new ``Graph`` is created.

        ``self.fromBinary`` is read but not set here -- presumably subclasses
        define it as a class attribute (verify against the subclasses).
        """
        # Use the caller's graph if one was given, otherwise create our own.
        graph = kwargs.pop('graph') if 'graph' in kwargs else Graph()

        super(StructuredTrackingWorkflowBase, self).__init__(shell, headless, workflow_cmdline_args, project_creation_args, graph=graph, *args, **kwargs)

        data_instructions = 'Use the "Raw Data" tab to load your intensity image(s).\n\n'
        if self.fromBinary:
            data_instructions += 'Use the "Binary Image" tab to load your segmentation image(s).'
        else:
            data_instructions += 'Use the "Prediction Maps" tab to load your pixel-wise probability image(s).'

        # Create applets
        self.dataSelectionApplet = DataSelectionApplet(self,
            "Input Data",
            "Input Data",
            batchDataGui=False,
            forceAxisOrder=['txyzc'],
            instructionText=data_instructions,
            max_lanes=1)

        opDataSelection = self.dataSelectionApplet.topLevelOperator
        if self.fromBinary:
            opDataSelection.DatasetRoles.setValue( ['Raw Data', 'Binary Image'] )
        else:
            opDataSelection.DatasetRoles.setValue( ['Raw Data', 'Prediction Maps'] )

        # Thresholding is only needed when we start from prediction maps.
        if not self.fromBinary:
            self.thresholdTwoLevelsApplet = ThresholdTwoLevelsApplet( self,"Threshold and Size Filter","ThresholdTwoLevels" )

        self.divisionDetectionApplet = ObjectClassificationApplet(workflow=self,
                                                                     name="Division Detection (optional)",
                                                                     projectFileGroupName="DivisionDetection",
                                                                     selectedFeatures=configConservation.selectedFeaturesDiv)

        self.cellClassificationApplet = ObjectClassificationApplet(workflow=self,
                                                                     name="Object Count Classification",
                                                                     projectFileGroupName="CountClassification",
                                                                     selectedFeatures=configConservation.selectedFeaturesObjectCount)

        self.trackingFeatureExtractionApplet = TrackingFeatureExtractionApplet(name="Object Feature Computation",workflow=self, interactive=False)

        self.objectExtractionApplet = ObjectExtractionApplet(name="Object Feature Computation",workflow=self, interactive=False)

        self.annotationsApplet = AnnotationsApplet( name="Training", workflow=self )
        opAnnotations = self.annotationsApplet.topLevelOperator

        self.trackingApplet = StructuredTrackingApplet( name="Tracking - Structured Learning", workflow=self )
        opStructuredTracking = self.trackingApplet.topLevelOperator

        # Translate the configured SOLVER into the solver name handed to the tracking operator.
        if SOLVER=="CPLEX" or SOLVER=="GUROBI":
            self._solver="ILP"
        elif SOLVER=="DPCT":
            self._solver="Flow-based"
        else:
            self._solver=None
        opStructuredTracking._solver = self._solver

        self.default_tracking_export_filename = '{dataset_dir}/{nickname}-tracking_exported_data.csv'
        self.dataExportTrackingApplet = TrackingBaseDataExportApplet(
            self,
            "Tracking Result Export",
            default_export_filename=self.default_tracking_export_filename,
            pluginExportFunc=self._pluginExportFunc
        )
        opDataExportTracking = self.dataExportTrackingApplet.topLevelOperator
        opDataExportTracking.SelectionNames.setValue( ['Tracking-Result', 'Merger-Result', 'Object-Identities'] )
        opDataExportTracking.WorkingDirectory.connect( opDataSelection.WorkingDirectory )
        self.dataExportTrackingApplet.set_exporting_operator(opStructuredTracking)
        self.dataExportTrackingApplet.prepare_lane_for_export = self.prepare_lane_for_export

        # configure export settings
        settings = {'file path': self.default_tracking_export_filename, 'compression': {}, 'file type': 'h5'}
        selected_features = ['Count', 'RegionCenter', 'RegionRadii', 'RegionAxes']
        opStructuredTracking.ExportSettings.setValue( (settings, selected_features) )

        # The order of this list is the order in which the applets are exposed.
        self._applets = []
        self._applets.append(self.dataSelectionApplet)
        if not self.fromBinary:
            self._applets.append(self.thresholdTwoLevelsApplet)
        self._applets.append(self.trackingFeatureExtractionApplet)
        self._applets.append(self.divisionDetectionApplet)

        # Note: the batch processing applet is created but not added to the applet list.
        self.batchProcessingApplet = BatchProcessingApplet(self, "Batch Processing", self.dataSelectionApplet, self.dataExportTrackingApplet)

        self._applets.append(self.cellClassificationApplet)
        self._applets.append(self.objectExtractionApplet)
        self._applets.append(self.annotationsApplet)
        self._applets.append(self.trackingApplet)
        self._applets.append(self.dataExportTrackingApplet)

        if self.divisionDetectionApplet:
            opDivDetection = self.divisionDetectionApplet.topLevelOperator
            opDivDetection.SelectedFeatures.setValue(configConservation.selectedFeaturesDiv)
            opDivDetection.LabelNames.setValue(['Not Dividing', 'Dividing'])
            opDivDetection.AllowDeleteLabels.setValue(False)
            opDivDetection.AllowAddLabel.setValue(False)
            opDivDetection.EnableLabelTransfer.setValue(False)

        opCellClassification = self.cellClassificationApplet.topLevelOperator
        opCellClassification.SelectedFeatures.setValue(configConservation.selectedFeaturesObjectCount )
        opCellClassification.SuggestedLabelNames.setValue( ['False Detection',] + [str(1) + ' Object'] + [str(i) + ' Objects' for i in range(2,10) ] )
        opCellClassification.AllowDeleteLastLabelOnly.setValue(True)
        opCellClassification.EnableLabelTransfer.setValue(False)

        if workflow_cmdline_args:
            self.testFullAnnotations = '--testFullAnnotations' in workflow_cmdline_args

            # Each parser gets what the previous one left over, so that in the
            # end ``unused_args`` holds only arguments nobody consumed.
            self._data_export_args, unused_args = self.dataExportTrackingApplet.parse_known_cmdline_args( workflow_cmdline_args )
            self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args( unused_args )
            # '--testFullAnnotations' was handled above, so it is not unused either.
            unused_args = [arg for arg in (unused_args or []) if arg != '--testFullAnnotations']
        else:
            unused_args = None
            self._data_export_args = None
            self._batch_input_args = None
            self.testFullAnnotations = False

        if unused_args:
            logger.warning("Unused command-line args: {}".format( unused_args ))

    def _pluginExportFunc(self, lane_index, filename, exportPlugin, checkOverwriteFiles, plugArgsSlot) -> int:
        """
        Run a plugin export on the tracking operator of one lane.

        :param lane_index: index of the lane whose tracking operator does the export.
        :param filename: forwarded to the operator's ``exportPlugin``.
        :param exportPlugin: forwarded to the operator's ``exportPlugin``.
        :param checkOverwriteFiles: forwarded to the operator's ``exportPlugin``.
        :param plugArgsSlot: forwarded to the operator's ``exportPlugin`` as its
            last positional argument.
        :return: whatever the lane operator's ``exportPlugin`` returns.
        """
        tracking_lane = self.trackingApplet.topLevelOperator.getLane(lane_index)
        # Pass on the ``plugArgsSlot`` parameter; the name used here before
        # (``additionalPluginArgumentsSlot``) was undefined in this scope.
        return tracking_lane.exportPlugin(
            filename,
            exportPlugin,
            checkOverwriteFiles,
            plugArgsSlot
        )

    def connectLane(self, laneIndex):
        """
        Wire up the per-lane operators of all applets for lane ``laneIndex``.

        Raw and binary input are both brought into "txyzc" axis order and then
        fed to object extraction, tracking feature extraction, division
        detection, object count classification, annotations and structured
        tracking.  The tracking results are connected to the tracking export
        operator.
        """
        data_lane = self.dataSelectionApplet.topLevelOperator.getLane(laneIndex)
        extraction_lane = self.objectExtractionApplet.topLevelOperator.getLane(laneIndex)
        tracking_features_lane = self.trackingFeatureExtractionApplet.topLevelOperator.getLane(laneIndex)

        annotations_lane = self.annotationsApplet.topLevelOperator.getLane(laneIndex)
        if not self.fromBinary:
            threshold_lane = self.thresholdTwoLevelsApplet.topLevelOperator.getLane(laneIndex)

        tracking_lane = self.trackingApplet.topLevelOperator.getLane(laneIndex)
        export_lane = self.dataExportTrackingApplet.topLevelOperator.getLane(laneIndex)

        ## Connect operators ##
        raw_txyzc = OpReorderAxes(parent=self)
        raw_txyzc.AxisOrder.setValue("txyzc")
        raw_txyzc.Input.connect(data_lane.ImageGroup[0])

        division_lane = self.divisionDetectionApplet.topLevelOperator.getLane(laneIndex)
        cell_lane = self.cellClassificationApplet.topLevelOperator.getLane(laneIndex)

        # The binary image is either loaded directly or produced by thresholding
        # the prediction maps.
        if not self.fromBinary:
            threshold_lane.InputImage.connect(data_lane.ImageGroup[1])
            threshold_lane.RawInput.connect(data_lane.ImageGroup[0])  # Used for display only
            binary_source = threshold_lane.CachedOutput
        else:
            binary_source = data_lane.ImageGroup[1]

        # Both input datasets go through an axis reordering operator so that they
        # are guaranteed to have the same axis order after thresholding.
        binary_txyzc = OpReorderAxes(parent=self)
        binary_txyzc.AxisOrder.setValue("txyzc")
        binary_txyzc.Input.connect(binary_source)

        extraction_lane.RawImage.connect(raw_txyzc.Output)
        extraction_lane.BinaryImage.connect(binary_txyzc.Output)

        tracking_features_lane.RawImage.connect(raw_txyzc.Output)
        tracking_features_lane.BinaryImage.connect(binary_txyzc.Output)

        tracking_features_lane.setDefaultFeatures(configConservation.allFeaturesObjectCount)
        tracking_features_lane.FeatureNamesVigra.setValue(configConservation.allFeaturesObjectCount)
        division_feature_names = {
            config.features_division_name: {name: {} for name in config.division_features}
        }
        tracking_features_lane.FeatureNamesDivision.setValue(division_feature_names)

        # Division detection
        if self.divisionDetectionApplet:
            division_lane.BinaryImages.connect(binary_txyzc.Output)
            division_lane.RawImages.connect(raw_txyzc.Output)
            division_lane.SegmentationImages.connect(tracking_features_lane.LabelImage)
            division_lane.ObjectFeatures.connect(tracking_features_lane.RegionFeaturesAll)
            division_lane.ComputedFeatureNames.connect(tracking_features_lane.ComputedFeatureNamesAll)

        # Object count classification
        cell_lane.BinaryImages.connect(binary_txyzc.Output)
        cell_lane.RawImages.connect(raw_txyzc.Output)
        cell_lane.SegmentationImages.connect(tracking_features_lane.LabelImage)
        cell_lane.ObjectFeatures.connect(tracking_features_lane.RegionFeaturesAll)
        cell_lane.ComputedFeatureNames.connect(tracking_features_lane.ComputedFeatureNamesNoDivisions)

        # Annotations (training)
        annotations_lane.RawImage.connect(raw_txyzc.Output)
        annotations_lane.BinaryImage.connect(binary_txyzc.Output)
        annotations_lane.LabelImage.connect(extraction_lane.LabelImage)
        annotations_lane.ObjectFeatures.connect(extraction_lane.RegionFeatures)
        annotations_lane.ComputedFeatureNames.connect(extraction_lane.Features)
        annotations_lane.DivisionProbabilities.connect(division_lane.Probabilities)
        annotations_lane.DetectionProbabilities.connect(cell_lane.Probabilities)
        annotations_lane.MaxNumObj.connect(cell_lane.MaxNumObj)

        # Structured tracking
        tracking_lane.RawImage.connect(raw_txyzc.Output)
        tracking_lane.LabelImage.connect(tracking_features_lane.LabelImage)
        tracking_lane.ObjectFeatures.connect(tracking_features_lane.RegionFeaturesVigra)
        tracking_lane.ComputedFeatureNames.connect(tracking_features_lane.FeatureNamesVigra)

        if self.divisionDetectionApplet:
            tracking_lane.ObjectFeaturesWithDivFeatures.connect(tracking_features_lane.RegionFeaturesAll)
            tracking_lane.ComputedFeatureNamesWithDivFeatures.connect(tracking_features_lane.ComputedFeatureNamesAll)
            tracking_lane.DivisionProbabilities.connect(division_lane.Probabilities)

        tracking_lane.DetectionProbabilities.connect(cell_lane.Probabilities)
        tracking_lane.NumLabels.connect(cell_lane.NumLabels)
        tracking_lane.Annotations.connect(annotations_lane.Annotations)
        tracking_lane.Labels.connect(annotations_lane.Labels)
        tracking_lane.Divisions.connect(annotations_lane.Divisions)
        tracking_lane.Appearances.connect(annotations_lane.Appearances)
        tracking_lane.Disappearances.connect(annotations_lane.Disappearances)
        tracking_lane.MaxNumObj.connect(cell_lane.MaxNumObj)

        # Tracking result export: three exportable results plus the raw data.
        export_lane.Inputs.resize(3)
        export_lane.Inputs[0].connect(tracking_lane.RelabeledImage)
        export_lane.Inputs[1].connect(tracking_lane.MergerOutput)
        export_lane.Inputs[2].connect(tracking_lane.LabelImage)
        export_lane.RawData.connect(raw_txyzc.Output)
        export_lane.RawDatasetInfo.connect(data_lane.DatasetGroup[0])

    def prepare_lane_for_export(self, lane_index):
        import logging
        logger = logging.getLogger(__name__)

        maxt = self.trackingApplet.topLevelOperator[lane_index].RawImage.meta.shape[0]
        maxx = self.trackingApplet.topLevelOperator[lane_index].RawImage.meta.shape[1]
        maxy = self.trackingApplet.topLevelOperator[lane_index].RawImage.meta.shape[2]
        maxz = self.trackingApplet.topLevelOperator[lane_index].RawImage.meta.shape[3]
        time_enum = list(range(maxt))
        x_range = (0, maxx)
        y_range = (0, maxy)
        z_range = (0, maxz)

        ndim = 2
        if ( z_range[1] - z_range[0] ) > 1:
            ndim = 3

        parameters = self.trackingApplet.topLevelOperator.Parameters.value
        # Save state of axis ranges
        if 'time_range' in parameters:
            self.prev_time_range = parameters['time_range']
        else:
            self.prev_time_range = time_enum

        if 'x_range' in parameters:
            self.prev_x_range = parameters['x_range']
        else:
            self.prev_x_range = x_range

        if 'y_range' in parameters:
            self.prev_y_range = parameters['y_range']
        else:
            self.prev_y_range = y_range

        if 'z_range' in parameters:
            self.prev_z_range = parameters['z_range']
        else:
            self.prev_z_range = z_range

        # batch processing starts a new lane, so training data needs to be copied from the lane that loaded the project
        loaded_project_lane_index=0
        self.annotationsApplet.topLevelOperator[lane_index].Annotations.setValue(
            self.trackingApplet.topLevelOperator[loaded_project_lane_index].Annotations.value)

        def runLearningAndTracking(withMergerResolution=True):
            if self.testFullAnnotations:
                logger.info("Test: Structured Learning")
                weights = self.trackingApplet.topLevelOperator[lane_index]._runStructuredLearning(
                    z_range,
                    parameters['maxObj'],
                    parameters['max_nearest_neighbors'],
                    parameters['maxDist'],
                    parameters['divThreshold'],
                    [parameters['scales'][0],parameters['scales'][1],parameters['scales'][2]],
                    parameters['size_range'],
                    parameters['withDivisions'],
                    parameters['borderAwareWidth'],
                    parameters['withClassifierPrior'],
                    withBatchProcessing=True)
                logger.info("weights: {}".format(weights))

            logger.info("Test: Tracking")
            result = self.trackingApplet.topLevelOperator[lane_index].track(
                time_range = time_enum,
                x_range = x_range,
                y_range = y_range,
                z_range = z_range,
                size_range = parameters['size_range'],
                x_scale = parameters['scales'][0],
                y_scale = parameters['scales'][1],
                z_scale = parameters['scales'][2],
                maxDist=parameters['maxDist'],
                maxObj = parameters['maxObj'],
                divThreshold=parameters['divThreshold'],
                avgSize=parameters['avgSize'],
                withTracklets=parameters['withTracklets'],
                sizeDependent=parameters['sizeDependent'],
                detWeight=parameters['detWeight'],
                divWeight=parameters['divWeight'],
                transWeight=parameters['transWeight'],
                withDivisions=parameters['withDivisions'],
                withOpticalCorrection=parameters['withOpticalCorrection'],
                withClassifierPrior=parameters['withClassifierPrior'],
                ndim=ndim,
                withMergerResolution=withMergerResolution,
                borderAwareWidth = parameters['borderAwareWidth'],
                withArmaCoordinates = parameters['withArmaCoordinates'],
                cplex_timeout = parameters['cplex_timeout'],
                appearance_cost = parameters['appearanceCost'],
                disappearance_cost = parameters['disappearanceCost'],
                force_build_hypotheses_graph = False,
                withBatchProcessing = True
            )

            return result

        if self.testFullAnnotations:

            self.result = runLearningAndTracking(withMergerResolution=False)

            hypothesesGraph = self.trackingApplet.topLevelOperator[lane_index].LearningHypothesesGraph.value
            hypothesesGraph.insertSolution(self.result)
            hypothesesGraph.computeLineage()
            solution = hypothesesGraph.getSolutionDictionary()
            annotations = self.trackingApplet.topLevelOperator[lane_index].Annotations.value

            self.trackingApplet.topLevelOperator[lane_index].insertAnnotationsToHypothesesGraph(hypothesesGraph,annotations,misdetectionLabel=-1)
            hypothesesGraph.computeLineage()
            solutionFromAnnotations = hypothesesGraph.getSolutionDictionary()

            for key in list(solution.keys()):
                if key == 'detectionResults':
                    detectionFlag = True
                    for i in range(len(solution[key])):
                        flag = False
                        for j in range(len(solutionFromAnnotations[key])):
                            if solution[key][i]['id'] == solutionFromAnnotations[key][j]['id'] and \
                                solution[key][i]['value'] == solutionFromAnnotations[key][j]['value']:
                                flag = True
                                break
                        detectionFlag &= flag
                elif key == 'divisionResults':
                    divisionFlag = True
                    for i in range(len(solution[key])):
                        flag = False
                        for j in range(len(solutionFromAnnotations[key])):
                            if solution[key][i]['id'] == solutionFromAnnotations[key][j]['id'] and \
                                solution[key][i]['value'] == solutionFromAnnotations[key][j]['value']:
                                flag = True
                                break
                        divisionFlag &= flag
                elif key == 'linkingResults':
                    linkingFlag = True
                    for i in range(len(solution[key])):
                        flag = False
                        for j in range(len(solutionFromAnnotations[key])):
                            if solution[key][i]['dest'] == solutionFromAnnotations[key][j]['dest'] and \
                                solution[key][i]['src'] == solutionFromAnnotations[key][j]['src']:
                                if solution[key][i]['gap'] == solutionFromAnnotations[key][j]['gap'] and \
                                    solution[key][i]['value'] == solutionFromAnnotations[key][j]['value']:
                                    flag = True
                                    break
                        linkingFlag &= flag

            assert detectionFlag, "Detection results are NOT correct. They differ from your annotated detections."
            logger.info("Detection results are correct.")
            assert divisionFlag, "Division results are NOT correct. They differ from your annotated divisions."
            logger.info("Division results are correct.")
            assert linkingFlag, "Transition results are NOT correct. They differ from your annotated transitions."
            logger.info("Transition results are correct.")
        self.result = runLearningAndTracking(withMergerResolution=parameters['withMergerResolution'])

    def _inputReady(self, nRoles):
        slot = self.dataSelectionApplet.topLevelOperator.ImageGroup
        if len(slot) > 0:
            input_ready = True
            for sub in slot:
                input_ready = input_ready and \
                    all([sub[i].ready() for i in range(nRoles)])
        else:
            input_ready = False
        return input_ready

    def onProjectLoaded(self, projectManager):
        """
        Overridden from Workflow base class.  Called by the Project Manager.

        Applies parsed data-export command-line arguments (if any) to the
        tracking export applet.  When running headless with both batch-input
        and data-export arguments present, it then runs the batch export.
        """
        export_args = self._data_export_args
        if export_args:
            # Export settings were supplied on the command line: apply them.
            self.dataExportTrackingApplet.configure_operator_with_parsed_args(export_args)

        # Batch processing needs headless mode plus both kinds of arguments.
        if not (self._headless and self._batch_input_args and self._data_export_args):
            return

        logger.info("Beginning Batch Processing")
        self.batchProcessingApplet.run_export_from_parsed_args(self._batch_input_args)
        logger.info("Completed Batch Processing")

    def handleAppletStateUpdateRequested(self):
        """
        Overridden from Workflow base class.
        Called when an applet has fired the :py:attr:`Applet.statusUpdateSignal`.

        Works out which pipeline stages currently have output and passes the
        resulting enabled-flag for each applet to the shell.  While any of
        the polled applets is busy, project changes are disabled and every
        applet receives a false flag.
        """
        # If no data, nothing else is ready.
        input_ready = self._inputReady(2) and not self.dataSelectionApplet.busy

        # Thresholding is only part of the pipeline when the workflow does
        # not start from a binary image.
        if self.fromBinary:
            thresholding_ready = input_ready
        else:
            cached_threshold = self.thresholdTwoLevelsApplet.topLevelOperator.CachedOutput
            thresholding_ready = input_ready and len(cached_threshold) > 0

        computed_feature_names = self.trackingFeatureExtractionApplet.topLevelOperator.ComputedFeatureNamesAll
        tracking_features_ready = thresholding_ready and len(computed_feature_names) > 0

        region_features = self.objectExtractionApplet.topLevelOperator.RegionFeatures
        features_ready = thresholding_ready and len(region_features) > 0

        # The following values are still evaluated, but no active statement
        # below consumes them; only the commented-out conditions refer to
        # annotations_ready and withIlpSolver.
        annotations_op = self.annotationsApplet.topLevelOperator
        annotations_ready = (
            features_ready
            and len(annotations_op.Labels) > 0
            and annotations_op.Labels.ready()
            and annotations_op.TrackImage.ready()
        )
        tracking_op = self.trackingApplet.topLevelOperator
        withIlpSolver = self._solver == "ILP"

        # Tracking and its export become available together with the
        # tracking features.
        tracking_ready = tracking_features_ready
        export_ready = tracking_features_ready

        # Accumulate the busy state of the polled applets.
        # (self.dataExportAnnotationsApplet.busy is deliberately not polled.)
        busy = False
        for polled_applet in (
            self.dataSelectionApplet,
            self.annotationsApplet,
            self.trackingApplet,
            self.dataExportTrackingApplet,
        ):
            busy |= polled_applet.busy
        idle = not busy

        shell = self._shell
        shell.enableProjectChanges(idle)

        shell.setAppletEnabled(self.dataSelectionApplet, idle)
        if not self.fromBinary:
            shell.setAppletEnabled(self.thresholdTwoLevelsApplet, input_ready and idle)
        shell.setAppletEnabled(self.trackingFeatureExtractionApplet, thresholding_ready and idle)
        shell.setAppletEnabled(self.cellClassificationApplet, tracking_features_ready and idle)
        shell.setAppletEnabled(self.divisionDetectionApplet, tracking_features_ready and idle)
        shell.setAppletEnabled(self.objectExtractionApplet, idle)
        shell.setAppletEnabled(self.annotationsApplet, features_ready and idle)  # and withIlpSolver)
        # shell.setAppletEnabled(self.dataExportAnnotationsApplet, annotations_ready and not busy and \
        #                                 self.dataExportAnnotationsApplet.topLevelOperator.Inputs[0][0].ready() )
        shell.setAppletEnabled(self.trackingApplet, tracking_ready and idle)
        shell.setAppletEnabled(
            self.dataExportTrackingApplet,
            export_ready and idle and self.dataExportTrackingApplet.topLevelOperator.Inputs[0][0].ready(),
        )
예제 #19
0
    def __init__(self, shell, headless, workflow_cmdline_args, project_creation_args, *args, **kwargs):
        """
        Set up the counting workflow: create its applets, connect their
        top-level operators and parse the command-line arguments.

        A ``graph`` keyword argument, when given, is used as the operator
        graph; otherwise a new ``Graph`` is created.  Command-line arguments
        not recognised by this workflow are offered to the export applet and
        then to the batch processing applet; what remains is logged.
        """
        # Take a caller-supplied graph out of kwargs so it is forwarded to
        # the base class exactly once.
        shared_graph = kwargs.pop("graph") if "graph" in kwargs else Graph()
        super(CountingWorkflow, self).__init__(
            shell, headless, workflow_cmdline_args, project_creation_args, graph=shared_graph, *args, **kwargs
        )
        self.stored_classifier = None

        # Parse workflow-specific command-line args
        cmdline_parser = argparse.ArgumentParser()
        cmdline_parser.add_argument(
            "--csv-export-file",
            help="Instead of exporting prediction density images, export total counts to the given csv path.",
        )
        self.parsed_counting_workflow_args, leftover_args = cmdline_parser.parse_known_args(workflow_cmdline_args)

        ######################
        # Interactive workflow
        ######################

        # Each ordered pair of distinct axes out of x, y, z, followed by "c".
        axis_orders = ["".join(axis_pair) + "c" for axis_pair in itertools.permutations("xyz", 2)]

        self.dataSelectionApplet = DataSelectionApplet(
            self, "Input Data", "Input Data", forceAxisOrder=axis_orders
        )
        data_selection_op = self.dataSelectionApplet.topLevelOperator
        data_selection_op.DatasetRoles.setValue(["Raw Data"])

        self.featureSelectionApplet = FeatureSelectionApplet(self, "Feature Selection", "FeatureSelections")

        self.countingApplet = CountingApplet(workflow=self)
        counting_op = self.countingApplet.topLevelOperator
        counting_op.WorkingDirectory.connect(data_selection_op.WorkingDirectory)

        self.dataExportApplet = CountingDataExportApplet(self, "Density Export", counting_op)

        # Customization hooks
        export_applet = self.dataExportApplet
        export_applet.prepare_for_entire_export = self.prepare_for_entire_export
        export_applet.post_process_lane_export = self.post_process_lane_export
        export_applet.post_process_entire_export = self.post_process_entire_export

        # Feed the export operator from the counting operator.
        export_op = self.dataExportApplet.topLevelOperator
        export_op.PmapColors.connect(counting_op.PmapColors)
        export_op.LabelNames.connect(counting_op.LabelNames)
        export_op.UpperBound.connect(counting_op.UpperBound)
        export_op.WorkingDirectory.connect(data_selection_op.WorkingDirectory)
        export_op.SelectionNames.setValue(["Probabilities"])

        self._applets = [
            self.dataSelectionApplet,
            self.featureSelectionApplet,
            self.countingApplet,
            self.dataExportApplet,
        ]

        self.batchProcessingApplet = BatchProcessingApplet(
            self, "Batch Processing", self.dataSelectionApplet, self.dataExportApplet
        )
        self._applets.append(self.batchProcessingApplet)

        if not leftover_args:
            self._batch_input_args = None
            self._batch_export_args = None
        else:
            # We parse the export setting args first.  All remaining args are considered input files by the input applet.
            self._batch_export_args, leftover_args = self.dataExportApplet.parse_known_cmdline_args(leftover_args)
            self._batch_input_args, leftover_args = self.batchProcessingApplet.parse_known_cmdline_args(leftover_args)

        if leftover_args:
            logger.warning("Unused command-line args: {}".format(leftover_args))
예제 #20
0
    def __init__( self, shell, headless, workflow_cmdline_args, project_creation_args, *args, **kwargs ):
        """
        Create the applets of the structured-learning tracking workflow,
        configure their top-level operators and parse the command-line
        arguments.

        ``self.fromBinary`` is read here but not assigned in this method, so
        it has to be available before this constructor runs.

        A ``graph`` keyword argument, when given, is used as the operator
        graph; otherwise a new ``Graph`` is created.

        Command-line handling, when ``workflow_cmdline_args`` is non-empty:
        ``--testFullAnnotations`` sets ``self.testFullAnnotations``; the
        tracking export applet parses its options first and the batch
        processing applet then parses only what the export applet did not
        recognise.  Arguments that nobody recognised are logged as a warning.
        """
        # Take a caller-supplied graph out of kwargs so it is forwarded to
        # the base class exactly once.
        graph = kwargs.pop('graph') if 'graph' in kwargs else Graph()

        super(StructuredTrackingWorkflowBase, self).__init__(shell, headless, workflow_cmdline_args, project_creation_args, graph=graph, *args, **kwargs)

        data_instructions = 'Use the "Raw Data" tab to load your intensity image(s).\n\n'
        if self.fromBinary:
            data_instructions += 'Use the "Binary Image" tab to load your segmentation image(s).'
        else:
            data_instructions += 'Use the "Prediction Maps" tab to load your pixel-wise probability image(s).'

        # Create applets
        self.dataSelectionApplet = DataSelectionApplet(self,
            "Input Data",
            "Input Data",
            batchDataGui=False,
            forceAxisOrder=['txyzc'],
            instructionText=data_instructions,
            max_lanes=1)

        # The second dataset role depends on the kind of input the workflow starts from.
        opDataSelection = self.dataSelectionApplet.topLevelOperator
        if self.fromBinary:
            opDataSelection.DatasetRoles.setValue( ['Raw Data', 'Binary Image'] )
        else:
            opDataSelection.DatasetRoles.setValue( ['Raw Data', 'Prediction Maps'] )

        # Thresholding is only needed when the input is not already binary.
        if not self.fromBinary:
            self.thresholdTwoLevelsApplet = ThresholdTwoLevelsApplet( self,"Threshold and Size Filter","ThresholdTwoLevels" )

        self.divisionDetectionApplet = ObjectClassificationApplet(workflow=self,
                                                                     name="Division Detection (optional)",
                                                                     projectFileGroupName="DivisionDetection",
                                                                     selectedFeatures=configConservation.selectedFeaturesDiv)

        self.cellClassificationApplet = ObjectClassificationApplet(workflow=self,
                                                                     name="Object Count Classification",
                                                                     projectFileGroupName="CountClassification",
                                                                     selectedFeatures=configConservation.selectedFeaturesObjectCount)

        self.trackingFeatureExtractionApplet = TrackingFeatureExtractionApplet(name="Object Feature Computation",workflow=self, interactive=False)

        self.objectExtractionApplet = ObjectExtractionApplet(name="Object Feature Computation",workflow=self, interactive=False)

        self.annotationsApplet = AnnotationsApplet( name="Training", workflow=self )
        # NOTE(review): this operator is fetched but not used further in this
        # method; the access is kept in case fetching it has an effect of
        # its own -- confirm before removing.
        opAnnotations = self.annotationsApplet.topLevelOperator

        self.trackingApplet = StructuredTrackingApplet( name="Tracking - Structured Learning", workflow=self )
        opStructuredTracking = self.trackingApplet.topLevelOperator

        # Map the configured SOLVER name onto the solver label stored on the
        # workflow and on the tracking operator.
        if SOLVER=="CPLEX" or SOLVER=="GUROBI":
            self._solver="ILP"
        elif SOLVER=="DPCT":
            self._solver="Flow-based"
        else:
            self._solver=None
        opStructuredTracking._solver = self._solver

        self.default_tracking_export_filename = '{dataset_dir}/{nickname}-tracking_exported_data.csv'
        self.dataExportTrackingApplet = TrackingBaseDataExportApplet(
            self,
            "Tracking Result Export",
            default_export_filename=self.default_tracking_export_filename,
            pluginExportFunc=self._pluginExportFunc
        )
        opDataExportTracking = self.dataExportTrackingApplet.topLevelOperator
        opDataExportTracking.SelectionNames.setValue( ['Tracking-Result', 'Merger-Result', 'Object-Identities'] )
        opDataExportTracking.WorkingDirectory.connect( opDataSelection.WorkingDirectory )
        self.dataExportTrackingApplet.set_exporting_operator(opStructuredTracking)
        self.dataExportTrackingApplet.prepare_lane_for_export = self.prepare_lane_for_export

        # configure export settings
        settings = {'file path': self.default_tracking_export_filename, 'compression': {}, 'file type': 'h5'}
        selected_features = ['Count', 'RegionCenter', 'RegionRadii', 'RegionAxes']
        opStructuredTracking.ExportSettings.setValue( (settings, selected_features) )

        # Applets exposed through self._applets, in order.
        self._applets = []
        self._applets.append(self.dataSelectionApplet)
        if not self.fromBinary:
            self._applets.append(self.thresholdTwoLevelsApplet)
        self._applets.append(self.trackingFeatureExtractionApplet)
        self._applets.append(self.divisionDetectionApplet)

        # The batch processing applet is created here but is not added to
        # self._applets in this method.
        self.batchProcessingApplet = BatchProcessingApplet(self, "Batch Processing", self.dataSelectionApplet, self.dataExportTrackingApplet)

        self._applets.append(self.cellClassificationApplet)
        self._applets.append(self.objectExtractionApplet)
        self._applets.append(self.annotationsApplet)
        self._applets.append(self.trackingApplet)
        self._applets.append(self.dataExportTrackingApplet)

        if self.divisionDetectionApplet:
            opDivDetection = self.divisionDetectionApplet.topLevelOperator
            opDivDetection.SelectedFeatures.setValue(configConservation.selectedFeaturesDiv)
            opDivDetection.LabelNames.setValue(['Not Dividing', 'Dividing'])
            opDivDetection.AllowDeleteLabels.setValue(False)
            opDivDetection.AllowAddLabel.setValue(False)
            opDivDetection.EnableLabelTransfer.setValue(False)

        opCellClassification = self.cellClassificationApplet.topLevelOperator
        opCellClassification.SelectedFeatures.setValue(configConservation.selectedFeaturesObjectCount )
        opCellClassification.SuggestedLabelNames.setValue( ['False Detection',] + [str(1) + ' Object'] + [str(i) + ' Objects' for i in range(2,10) ] )
        opCellClassification.AllowDeleteLastLabelOnly.setValue(True)
        opCellClassification.EnableLabelTransfer.setValue(False)

        if workflow_cmdline_args:
            test_flag = '--testFullAnnotations'
            self.testFullAnnotations = test_flag in workflow_cmdline_args

            # Parse the export settings first and hand only the leftovers to
            # the batch-input parser.  Passing the full argument list to the
            # second parser would throw away the first parser's leftovers, so
            # every export option would be reported as unused below.
            # (Presumably it also keeps export option values from being taken
            # for input files -- verify against the batch parser.)
            self._data_export_args, unused_args = self.dataExportTrackingApplet.parse_known_cmdline_args( workflow_cmdline_args )
            self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args( unused_args )

            # The test flag was handled above, so it is not "unused".
            unused_args = [arg for arg in unused_args if arg != test_flag]
        else:
            unused_args = None
            self._data_export_args = None
            self._batch_input_args = None
            self.testFullAnnotations = False

        if unused_args:
            logger.warning("Unused command-line args: {}".format( unused_args ))
예제 #21
0
    def __init__(self, shell, headless, workflow_cmdline_args,
                 project_creation_args, *args, **kwargs):
        """
        Create the applets of the conservation tracking workflow, configure
        their top-level operators and parse the command-line arguments.

        ``self.fromBinary`` is read here but not assigned in this method, so
        it has to be available before this constructor runs.

        A ``graph`` keyword argument, when given, is used as the operator
        graph; otherwise a new ``Graph`` is created.

        Command-line handling, when ``workflow_cmdline_args`` is non-empty:
        the export applet parses its options first and the batch processing
        applet then parses only what the export applet did not recognise.
        Arguments that nobody recognised are logged as a warning.
        """
        # Take a caller-supplied graph out of kwargs so it is forwarded to
        # the base class exactly once.
        graph = kwargs.pop("graph") if "graph" in kwargs else Graph()
        # if 'withOptTrans' in kwargs:
        #     self.withOptTrans = kwargs['withOptTrans']
        # if 'fromBinary' in kwargs:
        #     self.fromBinary = kwargs['fromBinary']
        super(ConservationTrackingWorkflowBase,
              self).__init__(shell,
                             headless,
                             workflow_cmdline_args,
                             project_creation_args,
                             graph=graph,
                             *args,
                             **kwargs)

        data_instructions = 'Use the "Raw Data" tab to load your intensity image(s).\n\n'
        if self.fromBinary:
            data_instructions += 'Use the "Binary Image" tab to load your segmentation image(s).'
        else:
            data_instructions += 'Use the "Prediction Maps" tab to load your pixel-wise probability image(s).'

        # Variables to store division and cell classifiers to prevent retraining every-time batch processing runs
        self.stored_division_classifier = None
        self.stored_cell_classifier = None

        ## Create applets
        self.dataSelectionApplet = DataSelectionApplet(
            self,
            "Input Data",
            "Input Data",
            forceAxisOrder=["txyzc"],
            instructionText=data_instructions,
            max_lanes=None,
        )

        # The second dataset role depends on the kind of input the workflow starts from.
        opDataSelection = self.dataSelectionApplet.topLevelOperator
        if self.fromBinary:
            opDataSelection.DatasetRoles.setValue(
                ["Raw Data", "Segmentation Image"])
        else:
            opDataSelection.DatasetRoles.setValue(
                ["Raw Data", "Prediction Maps"])

        # Thresholding is only needed when the input is not already binary.
        if not self.fromBinary:
            self.thresholdTwoLevelsApplet = ThresholdTwoLevelsApplet(
                self, "Threshold and Size Filter", "ThresholdTwoLevels")

        self.objectExtractionApplet = TrackingFeatureExtractionApplet(
            workflow=self,
            interactive=False,
            name="Object Feature Computation")

        opObjectExtraction = self.objectExtractionApplet.topLevelOperator

        self.divisionDetectionApplet = self._createDivisionDetectionApplet(
            configConservation.selectedFeaturesDiv)  # Might be None

        if self.divisionDetectionApplet:
            # Tell the feature extraction which division features to compute.
            feature_dict_division = {}
            feature_dict_division[config.features_division_name] = {
                name: {}
                for name in config.division_features
            }
            opObjectExtraction.FeatureNamesDivision.setValue(
                feature_dict_division)

            # NOTE(review): selected_features_div is built here but never
            # read afterwards in this method -- confirm whether it is still
            # needed.
            selected_features_div = {
                plugin_name: {name: {} for name in feature_names}
                for plugin_name, feature_names in config.selected_features_division.items()
            }
            # FIXME: do not hard code this
            for name in [
                    "SquaredDistances_" + str(i)
                    for i in range(config.n_best_successors)
            ]:
                selected_features_div[config.features_division_name][name] = {}

            opDivisionDetection = self.divisionDetectionApplet.topLevelOperator
            opDivisionDetection.SelectedFeatures.setValue(
                configConservation.selectedFeaturesDiv)
            opDivisionDetection.LabelNames.setValue(
                ["Not Dividing", "Dividing"])
            opDivisionDetection.AllowDeleteLabels.setValue(False)
            opDivisionDetection.AllowAddLabel.setValue(False)
            opDivisionDetection.EnableLabelTransfer.setValue(False)

        self.cellClassificationApplet = ObjectClassificationApplet(
            workflow=self,
            name="Object Count Classification",
            projectFileGroupName="CountClassification",
            selectedFeatures=configConservation.selectedFeaturesObjectCount,
        )

        # NOTE(review): selected_features_objectcount is built here but never
        # read afterwards in this method -- confirm whether it is still
        # needed.
        selected_features_objectcount = {
            plugin_name: {name: {} for name in feature_names}
            for plugin_name, feature_names in config.selected_features_objectcount.items()
        }

        opCellClassification = self.cellClassificationApplet.topLevelOperator
        opCellClassification.SelectedFeatures.setValue(
            configConservation.selectedFeaturesObjectCount)
        opCellClassification.SuggestedLabelNames.setValue(
            ["False Detection"] + [str(1) + " Object"] +
            [str(i) + " Objects" for i in range(2, 10)])
        opCellClassification.AllowDeleteLastLabelOnly.setValue(True)
        opCellClassification.EnableLabelTransfer.setValue(False)

        self.trackingApplet = ConservationTrackingApplet(workflow=self)

        self.default_export_filename = "{dataset_dir}/{nickname}-exported_data.csv"
        self.dataExportApplet = TrackingBaseDataExportApplet(
            self,
            "Tracking Result Export",
            default_export_filename=self.default_export_filename,
            pluginExportFunc=self._pluginExportFunc,
        )

        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.SelectionNames.setValue(
            ["Object-Identities", "Tracking-Result", "Merger-Result"])
        opDataExport.WorkingDirectory.connect(opDataSelection.WorkingDirectory)

        # Extra configuration for object export table (as CSV table or HDF5 table)
        opTracking = self.trackingApplet.topLevelOperator
        self.dataExportApplet.set_exporting_operator(opTracking)
        self.dataExportApplet.prepare_lane_for_export = self.prepare_lane_for_export

        # configure export settings
        # settings = {'file path': self.default_export_filename, 'compression': {}, 'file type': 'csv'}
        # selected_features = ['Count', 'RegionCenter', 'RegionRadii', 'RegionAxes']
        # opTracking.ExportSettings.setValue( (settings, selected_features) )

        # Applets exposed through self._applets, in order.
        self._applets = []
        self._applets.append(self.dataSelectionApplet)
        if not self.fromBinary:
            self._applets.append(self.thresholdTwoLevelsApplet)
        self._applets.append(self.objectExtractionApplet)

        if self.divisionDetectionApplet:
            self._applets.append(self.divisionDetectionApplet)

        self.batchProcessingApplet = BatchProcessingApplet(
            self, "Batch Processing", self.dataSelectionApplet,
            self.dataExportApplet)

        self._applets.append(self.cellClassificationApplet)
        self._applets.append(self.trackingApplet)
        self._applets.append(self.dataExportApplet)
        self._applets.append(self.batchProcessingApplet)

        # Parse export and batch command-line arguments for headless mode
        if workflow_cmdline_args:
            # Parse the export settings first and hand only the leftovers to
            # the batch-input parser.  Passing the full argument list to the
            # second parser would throw away the first parser's leftovers, so
            # every export option would be reported as unused below.
            # (Presumably it also keeps export option values from being taken
            # for input files -- verify against the batch parser.)
            self._data_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args(
                workflow_cmdline_args)
            self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args(
                unused_args)

        else:
            unused_args = None
            self._data_export_args = None
            self._batch_input_args = None

        if unused_args:
            logger.warning("Unused command-line args: {}".format(unused_args))
예제 #22
0
class WsdtWorkflow(Workflow):
    """
    Workflow made of four applets: data selection, WSDT watershed,
    data export and batch processing.

    Every lane takes two input roles (see ``ROLE_NAMES``) and exports a
    single result (see ``EXPORT_NAMES``), which is the watershed applet's
    ``Superpixels`` slot.
    """
    workflowName = "Watershed Over Distance Transform"
    workflowDescription = "A bare-bones workflow for using the WSDT applet"
    defaultAppletIndex = 0  # show DataSelection by default

    # Positions of the two input roles; used to index the per-lane
    # ImageGroup / DatasetGroup slots in connectLane().
    DATA_ROLE_RAW = 0
    DATA_ROLE_PROBABILITIES = 1
    ROLE_NAMES = ['Raw Data', 'Probabilities']
    EXPORT_NAMES = ['Watershed']

    @property
    def applets(self):
        """Return the list of applets assembled in ``__init__``."""
        return self._applets

    @property
    def imageNameListSlot(self):
        """Return the ``ImageName`` slot of the data selection top-level operator."""
        return self.dataSelectionApplet.topLevelOperator.ImageName

    def __init__(self, shell, headless, workflow_cmdline_args,
                 project_creation_workflow, *args, **kwargs):
        """
        Create the applets, configure their global settings and parse the
        workflow command-line arguments.

        :param shell: forwarded unchanged to the ``Workflow`` constructor.
        :param headless: forwarded unchanged to the ``Workflow`` constructor.
        :param workflow_cmdline_args: command-line arguments for this
            workflow; when falsy, nothing is parsed and the stored parsed
            arguments are ``None``.
        :param project_creation_workflow: forwarded unchanged to the
            ``Workflow`` constructor.
        """
        # Create a graph to be shared by all operators
        graph = Graph()

        super(WsdtWorkflow, self).__init__(shell,
                                           headless,
                                           workflow_cmdline_args,
                                           project_creation_workflow,
                                           graph=graph,
                                           *args,
                                           **kwargs)
        self._applets = []

        # -- DataSelection applet
        #
        self.dataSelectionApplet = DataSelectionApplet(self, "Input Data",
                                                       "Input Data")

        # Dataset inputs
        opDataSelection = self.dataSelectionApplet.topLevelOperator
        opDataSelection.DatasetRoles.setValue(self.ROLE_NAMES)

        # -- Wsdt applet
        #
        self.wsdtApplet = WsdtApplet(self, "Watershed", "Wsdt Watershed")

        # -- DataExport applet
        #
        self.dataExportApplet = DataExportApplet(self, "Data Export")

        # Configure global DataExport settings
        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.WorkingDirectory.connect(opDataSelection.WorkingDirectory)
        opDataExport.SelectionNames.setValue(self.EXPORT_NAMES)

        # -- BatchProcessing applet
        #
        self.batchProcessingApplet = BatchProcessingApplet(
            self, "Batch Processing", self.dataSelectionApplet,
            self.dataExportApplet)

        # -- Expose applets to shell
        self._applets.append(self.dataSelectionApplet)
        self._applets.append(self.wsdtApplet)
        self._applets.append(self.dataExportApplet)
        self._applets.append(self.batchProcessingApplet)

        # -- Parse command-line arguments
        #    (Command-line args are applied in onProjectLoaded(), below.)
        if workflow_cmdline_args:
            # The export parser runs first; whatever it does not recognise
            # is handed on to the data selection parser.
            self._data_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args(
                workflow_cmdline_args)
            # Use the class-level ROLE_NAMES here: no local called
            # `role_names` exists in this method, so referring to one
            # raises NameError as soon as any argument is supplied.
            self._batch_input_args, unused_args = self.dataSelectionApplet.parse_known_cmdline_args(
                unused_args, self.ROLE_NAMES)
        else:
            unused_args = None
            self._batch_input_args = None
            self._data_export_args = None

        if unused_args:
            logger.warning("Unused command-line args: {}".format(unused_args))

    def connectLane(self, laneIndex):
        """
        Override from base class.

        Connect the data selection, watershed and data export operators
        of lane ``laneIndex`` to each other.
        """
        opDataSelection = self.dataSelectionApplet.topLevelOperator.getLane(
            laneIndex)
        opWsdt = self.wsdtApplet.topLevelOperator.getLane(laneIndex)
        opDataExport = self.dataExportApplet.topLevelOperator.getLane(
            laneIndex)

        # watershed inputs
        opWsdt.RawData.connect(opDataSelection.ImageGroup[self.DATA_ROLE_RAW])
        opWsdt.Input.connect(
            opDataSelection.ImageGroup[self.DATA_ROLE_PROBABILITIES])

        # DataExport inputs
        opDataExport.RawData.connect(
            opDataSelection.ImageGroup[self.DATA_ROLE_RAW])
        opDataExport.RawDatasetInfo.connect(
            opDataSelection.DatasetGroup[self.DATA_ROLE_RAW])
        # One export input per entry of EXPORT_NAMES
        opDataExport.Inputs.resize(len(self.EXPORT_NAMES))
        opDataExport.Inputs[0].connect(opWsdt.Superpixels)
        # Sanity check: every export input must have been connected above.
        for slot in opDataExport.Inputs:
            assert slot.upstream_slot is not None

    def onProjectLoaded(self, projectManager):
        """
        Overridden from Workflow base class.  Called by the Project Manager.

        If the user provided command-line arguments, use them to configure
        the workflow inputs and output settings.
        """
        # Configure the data export operator.
        if self._data_export_args:
            self.dataExportApplet.configure_operator_with_parsed_args(
                self._data_export_args)

        # A batch export is only started in headless mode, and only when
        # both batch input and export arguments were given.
        if self._headless and self._batch_input_args and self._data_export_args:
            logger.info("Beginning Batch Processing")
            self.batchProcessingApplet.run_export_from_parsed_args(
                self._batch_input_args)
            logger.info("Completed Batch Processing")

    def handleAppletStateUpdateRequested(self):
        """
        Overridden from Workflow base class
        Called when an applet has fired the :py:attr:`Applet.appletStateUpdateRequested`

        Enables or disables each applet in the shell depending on whether
        input data is present and whether batch processing is running, then
        tells the shell whether project changes are currently allowed.
        """
        opDataSelection = self.dataSelectionApplet.topLevelOperator
        opDataExport = self.dataExportApplet.topLevelOperator
        opWsdt = self.wsdtApplet.topLevelOperator

        # If no data, nothing else is ready.
        input_ready = len(opDataSelection.ImageGroup
                          ) > 0 and not self.dataSelectionApplet.busy

        # The user isn't allowed to touch anything while batch processing is running.
        batch_processing_busy = self.batchProcessingApplet.busy

        self._shell.setAppletEnabled(self.dataSelectionApplet,
                                     not batch_processing_busy)
        self._shell.setAppletEnabled(self.wsdtApplet, not batch_processing_busy
                                     and input_ready)
        # Export additionally requires the watershed result to be ready.
        self._shell.setAppletEnabled(
            self.dataExportApplet, not batch_processing_busy and input_ready
            and opWsdt.Superpixels.ready())
        self._shell.setAppletEnabled(self.batchProcessingApplet,
                                     not batch_processing_busy and input_ready)

        # Lastly, check for certain "busy" conditions, during which we
        #  should prevent the shell from closing the project.
        busy = False
        busy |= self.dataSelectionApplet.busy
        busy |= self.wsdtApplet.busy
        busy |= self.dataExportApplet.busy
        busy |= self.batchProcessingApplet.busy
        self._shell.enableProjectChanges(not busy)
예제 #23
0
    def __init__(self, shell, headless, workflow_cmdline_args, project_creation_args, *args, **kwargs):
        """
        Build the applets of the structured tracking workflow, configure
        their top-level operators and parse the command-line arguments.

        A ``graph`` keyword argument, when present, is reused as the
        operator graph; otherwise a new ``Graph`` is created.  All other
        arguments are forwarded to the base class constructor.
        """
        if 'graph' in kwargs:
            graph = kwargs.pop('graph')
        else:
            graph = Graph()

        super(StructuredTrackingWorkflowBase, self).__init__(
            shell, headless, workflow_cmdline_args, project_creation_args,
            graph=graph, *args, **kwargs)

        # Instruction text shown for the input data; the second sentence
        # depends on which kind of second input this workflow expects.
        if self.fromBinary:
            second_input_hint = 'Use the "Binary Image" tab to load your segmentation image(s).'
        else:
            second_input_hint = 'Use the "Prediction Maps" tab to load your pixel-wise probability image(s).'
        data_instructions = 'Use the "Raw Data" tab to load your intensity image(s).\n\n' + second_input_hint

        # ---- Create applets ----
        self.dataSelectionApplet = DataSelectionApplet(
            self,
            "Input Data",
            "Input Data",
            batchDataGui=False,
            forceAxisOrder=['txyzc'],
            instructionText=data_instructions,
            max_lanes=1)

        op_data_selection = self.dataSelectionApplet.topLevelOperator
        second_role = 'Binary Image' if self.fromBinary else 'Prediction Maps'
        op_data_selection.DatasetRoles.setValue(['Raw Data', second_role])

        # Thresholding is only needed when the input is not already binary.
        if not self.fromBinary:
            self.thresholdTwoLevelsApplet = ThresholdTwoLevelsApplet(
                self, "Threshold and Size Filter", "ThresholdTwoLevels")

        self.divisionDetectionApplet = ObjectClassificationApplet(
            workflow=self,
            name="Division Detection (optional)",
            projectFileGroupName="DivisionDetection",
            selectedFeatures=configConservation.selectedFeaturesDiv)

        self.cellClassificationApplet = ObjectClassificationApplet(
            workflow=self,
            name="Object Count Classification",
            projectFileGroupName="CountClassification",
            selectedFeatures=configConservation.selectedFeaturesObjectCount)

        self.trackingFeatureExtractionApplet = TrackingFeatureExtractionApplet(
            name="Object Feature Computation", workflow=self, interactive=False)

        self.objectExtractionApplet = ObjectExtractionApplet(
            name="Object Feature Computation", workflow=self, interactive=False)

        self.annotationsApplet = AnnotationsApplet(name="Training", workflow=self)
        # Value is not used below; the attribute access is kept as in the original.
        _ = self.annotationsApplet.topLevelOperator

        self.trackingApplet = StructuredTrackingApplet(
            name="Tracking - Structured Learning", workflow=self)
        op_tracking = self.trackingApplet.topLevelOperator

        # Translate the module-level SOLVER name into the label stored on the operator.
        if SOLVER in ("CPLEX", "GUROBI"):
            self._solver = "ILP"
        elif SOLVER == "DPCT":
            self._solver = "Flow-based"
        else:
            self._solver = None
        op_tracking._solver = self._solver

        # ---- Tracking result export ----
        self.default_tracking_export_filename = '{dataset_dir}/{nickname}-tracking_exported_data.csv'
        self.dataExportTrackingApplet = TrackingBaseDataExportApplet(
            self,
            "Tracking Result Export",
            default_export_filename=self.default_tracking_export_filename)
        op_export = self.dataExportTrackingApplet.topLevelOperator
        op_export.SelectionNames.setValue(['Tracking-Result', 'Merger-Result', 'Object-Identities'])
        op_export.WorkingDirectory.connect(op_data_selection.WorkingDirectory)
        self.dataExportTrackingApplet.set_exporting_operator(op_tracking)
        self.dataExportTrackingApplet.prepare_lane_for_export = self.prepare_lane_for_export
        self.dataExportTrackingApplet.post_process_lane_export = self.post_process_lane_export

        # Default table export settings
        export_settings = {
            'file path': self.default_tracking_export_filename,
            'compression': {},
            'file type': 'h5',
        }
        export_features = ['Count', 'RegionCenter', 'RegionRadii', 'RegionAxes']
        op_tracking.ExportSettings.setValue((export_settings, export_features))

        # ---- Applet list exposed through the workflow ----
        self._applets = [self.dataSelectionApplet]
        if not self.fromBinary:
            self._applets.append(self.thresholdTwoLevelsApplet)
        self._applets.extend([
            self.trackingFeatureExtractionApplet,
            self.divisionDetectionApplet,
        ])

        # Created here but not added to the applet list.
        self.batchProcessingApplet = BatchProcessingApplet(
            self, "Batch Processing", self.dataSelectionApplet, self.dataExportTrackingApplet)

        self._applets.extend([
            self.cellClassificationApplet,
            self.objectExtractionApplet,
            self.annotationsApplet,
            self.trackingApplet,
            self.dataExportTrackingApplet,
        ])

        # ---- Classifier configuration ----
        if self.divisionDetectionApplet:
            op_division = self.divisionDetectionApplet.topLevelOperator
            op_division.SelectedFeatures.setValue(configConservation.selectedFeaturesDiv)
            op_division.LabelNames.setValue(['Not Dividing', 'Dividing'])
            op_division.AllowDeleteLabels.setValue(False)
            op_division.AllowAddLabel.setValue(False)
            op_division.EnableLabelTransfer.setValue(False)

        op_count = self.cellClassificationApplet.topLevelOperator
        op_count.SelectedFeatures.setValue(configConservation.selectedFeaturesObjectCount)
        count_labels = ['False Detection', '1 Object']
        count_labels += ['{} Objects'.format(count) for count in range(2, 10)]
        op_count.SuggestedLabelNames.setValue(count_labels)
        op_count.AllowDeleteLastLabelOnly.setValue(True)
        op_count.EnableLabelTransfer.setValue(False)

        # ---- Command-line arguments ----
        if not workflow_cmdline_args:
            unused_args = None
            self._data_export_args = None
            self._batch_input_args = None
            self.testFullAnnotations = False
        else:
            self.testFullAnnotations = '--testFullAnnotations' in workflow_cmdline_args

            # Both parsers receive the complete argument list, so
            # `unused_args` ends up holding the leftovers of the second parse.
            self._data_export_args, unused_args = (
                self.dataExportTrackingApplet.parse_known_cmdline_args(workflow_cmdline_args))
            self._batch_input_args, unused_args = (
                self.batchProcessingApplet.parse_known_cmdline_args(workflow_cmdline_args))

        if unused_args:
            logger.warning("Unused command-line args: {}".format(unused_args))
예제 #24
0
    def __init__(self, shell, headless, workflow_cmdline_args,
                 project_creation_workflow, *args, **kwargs):
        """
        Create the applets, configure their global settings and parse the
        workflow command-line arguments.

        :param shell: forwarded unchanged to the base class constructor.
        :param headless: forwarded unchanged to the base class constructor.
        :param workflow_cmdline_args: command-line arguments for this
            workflow; when falsy, nothing is parsed and the stored parsed
            arguments are ``None``.
        :param project_creation_workflow: forwarded unchanged to the base
            class constructor.
        """
        # Create a graph to be shared by all operators
        graph = Graph()

        super(WsdtWorkflow, self).__init__(shell,
                                           headless,
                                           workflow_cmdline_args,
                                           project_creation_workflow,
                                           graph=graph,
                                           *args,
                                           **kwargs)
        self._applets = []

        # -- DataSelection applet
        #
        self.dataSelectionApplet = DataSelectionApplet(self, "Input Data",
                                                       "Input Data")

        # Dataset inputs
        opDataSelection = self.dataSelectionApplet.topLevelOperator
        opDataSelection.DatasetRoles.setValue(self.ROLE_NAMES)

        # -- Wsdt applet
        #
        self.wsdtApplet = WsdtApplet(self, "Watershed", "Wsdt Watershed")

        # -- DataExport applet
        #
        self.dataExportApplet = DataExportApplet(self, "Data Export")

        # Configure global DataExport settings
        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.WorkingDirectory.connect(opDataSelection.WorkingDirectory)
        opDataExport.SelectionNames.setValue(self.EXPORT_NAMES)

        # -- BatchProcessing applet
        #
        self.batchProcessingApplet = BatchProcessingApplet(
            self, "Batch Processing", self.dataSelectionApplet,
            self.dataExportApplet)

        # -- Expose applets to shell
        self._applets.append(self.dataSelectionApplet)
        self._applets.append(self.wsdtApplet)
        self._applets.append(self.dataExportApplet)
        self._applets.append(self.batchProcessingApplet)

        # -- Parse command-line arguments
        #    (Command-line args are applied later, in onProjectLoaded().)
        if workflow_cmdline_args:
            # The export parser runs first; whatever it does not recognise
            # is handed on to the data selection parser.
            self._data_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args(
                workflow_cmdline_args)
            # Use self.ROLE_NAMES (the same value given to DatasetRoles
            # above): no local called `role_names` exists in this method,
            # so referring to one raises NameError as soon as any argument
            # is supplied.
            self._batch_input_args, unused_args = self.dataSelectionApplet.parse_known_cmdline_args(
                unused_args, self.ROLE_NAMES)
        else:
            unused_args = None
            self._batch_input_args = None
            self._data_export_args = None

        if unused_args:
            logger.warning("Unused command-line args: {}".format(unused_args))
예제 #25
0
class StructuredTrackingWorkflowBase( Workflow ):
    workflowName = "Structured Learning Tracking Workflow BASE"

    @property
    def applets(self):
        """Return the list of applets that ``__init__`` stored in ``self._applets``."""
        return self._applets

    @property
    def imageNameListSlot(self):
        """Return the ``ImageName`` slot of the data selection applet's top-level operator."""
        return self.dataSelectionApplet.topLevelOperator.ImageName

    def __init__( self, shell, headless, workflow_cmdline_args, project_creation_args, *args, **kwargs ):
        """
        Build the applets of the structured learning tracking workflow,
        configure their top-level operators and parse the command-line
        arguments.

        :param shell: forwarded unchanged to the ``Workflow`` constructor.
        :param headless: forwarded unchanged to the ``Workflow`` constructor.
        :param workflow_cmdline_args: command-line arguments for this
            workflow; when falsy, nothing is parsed and the stored parsed
            arguments are ``None``.
        :param project_creation_args: forwarded unchanged to the
            ``Workflow`` constructor.
        :param kwargs: may contain ``graph``; if so it is reused as the
            operator graph, otherwise a new ``Graph`` is created.
        """
        # Take `graph` out of kwargs so that it is not passed twice to the
        # base class constructor below.
        graph = kwargs.pop('graph') if 'graph' in kwargs else Graph()

        super(StructuredTrackingWorkflowBase, self).__init__(
            shell, headless, workflow_cmdline_args, project_creation_args,
            graph=graph, *args, **kwargs)

        # NOTE(review): `self.fromBinary` is read here but never assigned in
        # this class as far as visible -- presumably provided by a subclass.
        data_instructions = 'Use the "Raw Data" tab to load your intensity image(s).\n\n'
        if self.fromBinary:
            data_instructions += 'Use the "Binary Image" tab to load your segmentation image(s).'
        else:
            data_instructions += 'Use the "Prediction Maps" tab to load your pixel-wise probability image(s).'

        # Create applets
        self.dataSelectionApplet = DataSelectionApplet(self,
            "Input Data",
            "Input Data",
            batchDataGui=False,
            forceAxisOrder=['txyzc'],
            instructionText=data_instructions,
            max_lanes=1)

        opDataSelection = self.dataSelectionApplet.topLevelOperator
        if self.fromBinary:
            opDataSelection.DatasetRoles.setValue( ['Raw Data', 'Binary Image'] )
        else:
            opDataSelection.DatasetRoles.setValue( ['Raw Data', 'Prediction Maps'] )

        # Thresholding is only needed when the input is not already binary.
        if not self.fromBinary:
            self.thresholdTwoLevelsApplet = ThresholdTwoLevelsApplet( self,"Threshold and Size Filter","ThresholdTwoLevels" )

        self.divisionDetectionApplet = ObjectClassificationApplet(workflow=self,
                                                                     name="Division Detection (optional)",
                                                                     projectFileGroupName="DivisionDetection",
                                                                     selectedFeatures=configStructured.selectedFeaturesDiv)

        self.cellClassificationApplet = ObjectClassificationApplet(workflow=self,
                                                                     name="Object Count Classification",
                                                                     projectFileGroupName="CountClassification",
                                                                     selectedFeatures=configStructured.selectedFeaturesObjectCount)

        self.cropSelectionApplet = CropSelectionApplet(self,"Crop Selection","CropSelection")

        self.trackingFeatureExtractionApplet = TrackingFeatureExtractionApplet(name="Object Feature Computation",workflow=self, interactive=False)

        self.objectExtractionApplet = ObjectExtractionApplet(name="Object Feature Computation",workflow=self, interactive=False)

        self.annotationsApplet = AnnotationsApplet( name="Training", workflow=self )
        # Only referenced by the disabled annotations export below.
        opAnnotations = self.annotationsApplet.topLevelOperator

        # Disabled: separate export applet for the training annotations.
        # self.default_training_export_filename = '{dataset_dir}/{nickname}-training_exported_data.csv'
        # self.dataExportAnnotationsApplet = TrackingBaseDataExportApplet(self, "Training Export",default_export_filename=self.default_training_export_filename)
        # opDataExportAnnotations = self.dataExportAnnotationsApplet.topLevelOperator
        # opDataExportAnnotations.SelectionNames.setValue( ['User Training for Tracking', 'Object Identities'] )
        # opDataExportAnnotations.WorkingDirectory.connect( opDataSelection.WorkingDirectory )
        # self.dataExportAnnotationsApplet.set_exporting_operator(opAnnotations)

        self.trackingApplet = StructuredTrackingApplet( name="Tracking - Structured Learning", workflow=self )
        opStructuredTracking = self.trackingApplet.topLevelOperator

        self.default_tracking_export_filename = '{dataset_dir}/{nickname}-tracking_exported_data.csv'
        self.dataExportTrackingApplet = TrackingBaseDataExportApplet(self, "Tracking Result Export",default_export_filename=self.default_tracking_export_filename)
        opDataExportTracking = self.dataExportTrackingApplet.topLevelOperator
        opDataExportTracking.SelectionNames.setValue( ['Tracking-Result', 'Merger-Result', 'Object-Identities'] )
        opDataExportTracking.WorkingDirectory.connect( opDataSelection.WorkingDirectory )
        self.dataExportTrackingApplet.set_exporting_operator(opStructuredTracking)
        # Hooks the export applet calls around the export of each lane.
        self.dataExportTrackingApplet.prepare_lane_for_export = self.prepare_lane_for_export
        self.dataExportTrackingApplet.post_process_lane_export = self.post_process_lane_export

        # configure export settings
        settings = {'file path': self.default_tracking_export_filename, 'compression': {}, 'file type': 'h5'}
        selected_features = ['Count', 'RegionCenter', 'RegionRadii', 'RegionAxes']
        opStructuredTracking.ExportSettings.setValue( (settings, selected_features) )

        self._applets = []
        self._applets.append(self.dataSelectionApplet)
        if not self.fromBinary:
            self._applets.append(self.thresholdTwoLevelsApplet)
        self._applets.append(self.trackingFeatureExtractionApplet)
        self._applets.append(self.divisionDetectionApplet)

        # The batch processing applet is created but not added to the applet list.
        self.batchProcessingApplet = BatchProcessingApplet(self, "Batch Processing", self.dataSelectionApplet, self.dataExportTrackingApplet)

        self._applets.append(self.cellClassificationApplet)
        self._applets.append(self.cropSelectionApplet)
        self._applets.append(self.objectExtractionApplet)
        self._applets.append(self.annotationsApplet)
        # self._applets.append(self.dataExportAnnotationsApplet)
        self._applets.append(self.trackingApplet)
        self._applets.append(self.dataExportTrackingApplet)

        # NOTE(review): the applets above were constructed with features from
        # `configStructured`, while the operators below are given features
        # from `configConservation` -- confirm that this mix is intended.
        if self.divisionDetectionApplet:
            opDivDetection = self.divisionDetectionApplet.topLevelOperator
            opDivDetection.SelectedFeatures.setValue(configConservation.selectedFeaturesDiv)
            opDivDetection.LabelNames.setValue(['Not Dividing', 'Dividing'])
            opDivDetection.AllowDeleteLabels.setValue(False)
            opDivDetection.AllowAddLabel.setValue(False)
            opDivDetection.EnableLabelTransfer.setValue(False)

        opCellClassification = self.cellClassificationApplet.topLevelOperator
        opCellClassification.SelectedFeatures.setValue(configConservation.selectedFeaturesObjectCount )
        # Labels: 'False Detection', '1 Object', '2 Objects', ..., '9 Objects'
        opCellClassification.SuggestedLabelNames.setValue( ['False Detection',] + [str(1) + ' Object'] + [str(i) + ' Objects' for i in range(2,10) ] )
        opCellClassification.AllowDeleteLastLabelOnly.setValue(True)
        opCellClassification.EnableLabelTransfer.setValue(False)

        if workflow_cmdline_args:
            # NOTE(review): both parsers receive the complete argument list,
            # so `unused_args` only reflects the second parse -- confirm
            # whether the leftovers of the first parse should be chained in.
            self._data_export_args, unused_args = self.dataExportTrackingApplet.parse_known_cmdline_args( workflow_cmdline_args )
            self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args( workflow_cmdline_args )
        else:
            unused_args = None
            self._data_export_args = None
            self._batch_input_args = None

        if unused_args:
            # Logger.warn is a deprecated alias; use Logger.warning.
            logger.warning("Unused command-line args: {}".format( unused_args ))

    def connectLane(self, laneIndex):
        """
        Connect the per-lane operators of all applets for lane ``laneIndex``.

        The raw image and the binary image are both passed through an
        ``OpReorderAxes`` set to "txyzc" before being fed to the feature
        extraction, classification, annotation and tracking operators.
        Finally the three tracking outputs are connected to the tracking
        export operator.

        :param laneIndex: index of the lane whose operators are connected.
        """
        # Fetch this lane's view of every applet's top-level operator.
        opData = self.dataSelectionApplet.topLevelOperator.getLane(laneIndex)
        opObjExtraction = self.objectExtractionApplet.topLevelOperator.getLane(laneIndex)
        opTrackingFeatureExtraction = self.trackingFeatureExtractionApplet.topLevelOperator.getLane(laneIndex)

        opAnnotations = self.annotationsApplet.topLevelOperator.getLane(laneIndex)
        # The threshold applet only exists when the input is not already binary.
        if not self.fromBinary:
            opTwoLevelThreshold = self.thresholdTwoLevelsApplet.topLevelOperator.getLane(laneIndex)
        # opDataAnnotationsExport = self.dataExportAnnotationsApplet.topLevelOperator.getLane(laneIndex)

        opCropSelection = self.cropSelectionApplet.topLevelOperator.getLane(laneIndex)
        opStructuredTracking = self.trackingApplet.topLevelOperator.getLane(laneIndex)
        opDataTrackingExport = self.dataExportTrackingApplet.topLevelOperator.getLane(laneIndex)

        ## Connect operators ##
        # ImageGroup[0] is the first dataset role ('Raw Data' in __init__).
        op5Raw = OpReorderAxes(parent=self)
        op5Raw.AxisOrder.setValue("txyzc")
        op5Raw.Input.connect(opData.ImageGroup[0])

        opDivDetection = self.divisionDetectionApplet.topLevelOperator.getLane(laneIndex)
        opCellClassification = self.cellClassificationApplet.topLevelOperator.getLane(laneIndex)

        # The binary source is either the thresholded second input or the
        # second input itself (ImageGroup[1]).
        if not self.fromBinary:
            opTwoLevelThreshold.InputImage.connect( opData.ImageGroup[1] )
            opTwoLevelThreshold.RawInput.connect( opData.ImageGroup[0] ) # Used for display only
            binarySrc = opTwoLevelThreshold.CachedOutput
        else:
            binarySrc = opData.ImageGroup[1]
        # Use Op5ifyers for both input datasets such that they are guaranteed to
        # have the same axis order after thresholding
        op5Binary = OpReorderAxes(parent=self)
        op5Binary.AxisOrder.setValue("txyzc")
        op5Binary.Input.connect(binarySrc)

        # Crop selection is connected to the inputs directly, not to the
        # reordered images.
        opCropSelection.InputImage.connect( opData.ImageGroup[0] )
        opCropSelection.PredictionImage.connect( opData.ImageGroup[1] )

        opObjExtraction.RawImage.connect( op5Raw.Output )
        opObjExtraction.BinaryImage.connect( op5Binary.Output )

        opTrackingFeatureExtraction.RawImage.connect( op5Raw.Output )
        opTrackingFeatureExtraction.BinaryImage.connect( op5Binary.Output )

        # vigra_features = list((set(config.vigra_features)).union(config.selected_features_objectcount[config.features_vigra_name]))
        # feature_names_vigra = {}
        # feature_names_vigra[config.features_vigra_name] = { name: {} for name in vigra_features }

        opTrackingFeatureExtraction.FeatureNamesVigra.setValue(configConservation.allFeaturesObjectCount)
        # Division features: one entry per name in config.division_features,
        # each mapped to an empty dict, under the key config.features_division_name.
        feature_dict_division = {}
        feature_dict_division[config.features_division_name] = { name: {} for name in config.division_features }
        opTrackingFeatureExtraction.FeatureNamesDivision.setValue(feature_dict_division)

        if self.divisionDetectionApplet:
            opDivDetection.BinaryImages.connect( op5Binary.Output )
            opDivDetection.RawImages.connect( op5Raw.Output )
            opDivDetection.SegmentationImages.connect(opTrackingFeatureExtraction.LabelImage)
            opDivDetection.ObjectFeatures.connect(opTrackingFeatureExtraction.RegionFeaturesAll)
            opDivDetection.ComputedFeatureNames.connect(opTrackingFeatureExtraction.ComputedFeatureNamesAll)

        # Same inputs as division detection, except for the feature names,
        # which come from ComputedFeatureNamesNoDivisions.
        opCellClassification.BinaryImages.connect( op5Binary.Output )
        opCellClassification.RawImages.connect( op5Raw.Output )
        opCellClassification.SegmentationImages.connect(opTrackingFeatureExtraction.LabelImage)
        opCellClassification.ObjectFeatures.connect(opTrackingFeatureExtraction.RegionFeaturesAll)
        opCellClassification.ComputedFeatureNames.connect(opTrackingFeatureExtraction.ComputedFeatureNamesNoDivisions)

        # The annotations operator takes label image and features from the
        # object extraction operator, not from the tracking feature extraction.
        opAnnotations.RawImage.connect( op5Raw.Output )
        opAnnotations.BinaryImage.connect( op5Binary.Output )
        opAnnotations.LabelImage.connect( opObjExtraction.LabelImage )
        opAnnotations.ObjectFeatures.connect( opObjExtraction.RegionFeatures )
        opAnnotations.ComputedFeatureNames.connect(opObjExtraction.Features)
        opAnnotations.Crops.connect( opCropSelection.Crops)
        opAnnotations.DivisionProbabilities.connect( opDivDetection.Probabilities )
        opAnnotations.DetectionProbabilities.connect( opCellClassification.Probabilities )
        opAnnotations.MaxNumObj.connect (opCellClassification.MaxNumObj)

        # opDataAnnotationsExport.Inputs.resize(2)
        # opDataAnnotationsExport.Inputs[0].connect( opAnnotations.TrackImage )
        # opDataAnnotationsExport.Inputs[1].connect( opAnnotations.LabelImage )
        # opDataAnnotationsExport.RawData.connect( op5Raw.Output )
        # opDataAnnotationsExport.RawDatasetInfo.connect( opData.DatasetGroup[0] )

        opStructuredTracking.RawImage.connect( op5Raw.Output )
        opStructuredTracking.LabelImage.connect( opTrackingFeatureExtraction.LabelImage )
        opStructuredTracking.ObjectFeatures.connect( opTrackingFeatureExtraction.RegionFeaturesVigra )
        opStructuredTracking.ComputedFeatureNames.connect( opTrackingFeatureExtraction.FeatureNamesVigra )

        if self.divisionDetectionApplet:
            opStructuredTracking.ObjectFeaturesWithDivFeatures.connect( opTrackingFeatureExtraction.RegionFeaturesAll)
            opStructuredTracking.ComputedFeatureNamesWithDivFeatures.connect( opTrackingFeatureExtraction.ComputedFeatureNamesAll )
            opStructuredTracking.DivisionProbabilities.connect( opDivDetection.Probabilities )

        # configure tracking export settings
        # NOTE(review): __init__ sets ExportSettings with file type 'h5' and four
        # features; here 'csv' and two features are used -- confirm which applies.
        settings = {'file path': self.default_tracking_export_filename, 'compression': {}, 'file type': 'csv'}
        selected_features = ['Count', 'RegionCenter']
        opStructuredTracking.configure_table_export_settings(settings, selected_features)

        opStructuredTracking.DetectionProbabilities.connect( opCellClassification.Probabilities )
        opStructuredTracking.NumLabels.connect( opCellClassification.NumLabels )
        opStructuredTracking.Crops.connect (opCropSelection.Crops)
        opStructuredTracking.Annotations.connect (opAnnotations.Annotations)
        opStructuredTracking.Labels.connect (opAnnotations.Labels)
        opStructuredTracking.Divisions.connect (opAnnotations.Divisions)
        opStructuredTracking.MaxNumObj.connect (opCellClassification.MaxNumObj)

        # Three export inputs, in the same order as the SelectionNames set in
        # __init__: 'Tracking-Result', 'Merger-Result', 'Object-Identities'.
        opDataTrackingExport.Inputs.resize(3)
        opDataTrackingExport.Inputs[0].connect( opStructuredTracking.RelabeledImage )
        opDataTrackingExport.Inputs[1].connect( opStructuredTracking.MergerOutput )
        opDataTrackingExport.Inputs[2].connect( opStructuredTracking.LabelImage )
        opDataTrackingExport.RawData.connect( op5Raw.Output )
        opDataTrackingExport.RawDatasetInfo.connect( opData.DatasetGroup[0] )

    def prepare_lane_for_export(self, lane_index):
        """
        Run structured learning and tracking on a lane before it is exported.

        The tracked region spans the full extent of the lane's raw image
        (the first four entries of its shape are read as t, x, y, z); all
        other settings come from the tracking operator's ``Parameters`` slot.
        Annotations and crops are first copied from lane 0 into the lane
        being exported.

        Side effects: sets ``self.prev_time_range``, ``self.prev_x_range``,
        ``self.prev_y_range`` and ``self.prev_z_range``.

        :param lane_index: index of the lane that is about to be exported.
        """
        import logging
        logger = logging.getLogger(__name__)

        tracking_op = self.trackingApplet.topLevelOperator
        lane_op = tracking_op[lane_index]

        shape = lane_op.RawImage.meta.shape
        maxt, maxx, maxy, maxz = shape[0], shape[1], shape[2], shape[3]
        # Materialise the time points as a list (not a lazy ``range``) so the
        # value handed to ``track`` and stored in ``prev_time_range`` is a
        # plain sequence under Python 3 as well.
        time_enum = list(range(maxt))
        x_range = (0, maxx)
        y_range = (0, maxy)
        z_range = (0, maxz)

        # More than one z-slice means the data is treated as 3D.
        ndim = 3 if (z_range[1] - z_range[0]) > 1 else 2

        parameters = tracking_op.Parameters.value
        # Save state of axis ranges
        self.prev_time_range = parameters['time_range'] if 'time_range' in parameters else time_enum
        self.prev_x_range = parameters['x_range'] if 'x_range' in parameters else x_range
        self.prev_y_range = parameters['y_range'] if 'y_range' in parameters else y_range
        self.prev_z_range = parameters['z_range'] if 'z_range' in parameters else z_range

        # batch processing starts a new lane, so training data needs to be copied from the lane that loaded the project
        loaded_project_lane_index = 0
        self.annotationsApplet.topLevelOperator[lane_index].Annotations.setValue(
            tracking_op[loaded_project_lane_index].Annotations.value)

        self.cropSelectionApplet.topLevelOperator[lane_index].Crops.setValue(
            tracking_op[loaded_project_lane_index].Crops.value)

        scales = parameters['scales']

        logger.info("Test: Structured Learning")
        weights = lane_op._runStructuredLearning(
            z_range,
            parameters['maxObj'],
            parameters['max_nearest_neighbors'],
            parameters['maxDist'],
            parameters['divThreshold'],
            [scales[0], scales[1], scales[2]],
            parameters['size_range'],
            parameters['withDivisions'],
            parameters['borderAwareWidth'],
            parameters['withClassifierPrior'],
            withBatchProcessing=True)
        logger.info("weights: {}".format(weights))

        logger.info("Test: Tracking")
        lane_op.track(
            time_range=time_enum,
            x_range=x_range,
            y_range=y_range,
            z_range=z_range,
            size_range=parameters['size_range'],
            x_scale=scales[0],
            y_scale=scales[1],
            z_scale=scales[2],
            maxDist=parameters['maxDist'],
            maxObj=parameters['maxObj'],
            divThreshold=parameters['divThreshold'],
            avgSize=parameters['avgSize'],
            withTracklets=parameters['withTracklets'],
            sizeDependent=parameters['sizeDependent'],
            detWeight=parameters['detWeight'],
            divWeight=parameters['divWeight'],
            transWeight=parameters['transWeight'],
            withDivisions=parameters['withDivisions'],
            withOpticalCorrection=parameters['withOpticalCorrection'],
            withClassifierPrior=parameters['withClassifierPrior'],
            ndim=ndim,
            withMergerResolution=parameters['withMergerResolution'],
            borderAwareWidth=parameters['borderAwareWidth'],
            withArmaCoordinates=parameters['withArmaCoordinates'],
            cplex_timeout=parameters['cplex_timeout'],
            appearance_cost=parameters['appearanceCost'],
            disappearance_cost=parameters['disappearanceCost'],
            force_build_hypotheses_graph=False,
            withBatchProcessing=True
        )

    def post_process_lane_export(self, lane_index):
        """
        Export the object data of a lane, using the lane's table export settings.

        Nothing happens when the lane has no table export settings.  Otherwise
        the ``dataset_dir`` and ``nickname`` keys of the configured file path
        are filled in, the export request is started and waited for, and the
        export applet's progress signal is emitted with -1 before and 100
        after the export.

        :param lane_index: index of the lane whose image export has finished.
        """
        # FIXME: This probably only works for the non-blockwise export slot.
        #        We should assert that the user isn't using the blockwise slot.
        tracking_op = self.trackingApplet.topLevelOperator
        settings, _selected_features = tracking_op.getLane(lane_index).get_table_export_settings()
        from lazyflow.utility import PathComponents, make_absolute, format_known_keys

        if not settings:
            return

        progress = self.dataExportTrackingApplet.progressSignal
        progress.emit(-1)
        dataset_info = self.dataSelectionApplet.topLevelOperator.DatasetGroup[lane_index][0].value

        # Resolve the dataset directory relative to the project file's directory.
        project_dir = os.path.dirname(self.shell.projectManager.currentProjectPath)
        relative_dataset_dir = PathComponents(dataset_info.filePath).externalDirectory
        absolute_dataset_dir = make_absolute(relative_dataset_dir, cwd=project_dir)

        # A nickname made of several paths is reduced to the base name of its
        # first component.
        dataset_nickname = dataset_info.nickname.replace('*', '')
        if os.path.pathsep in dataset_nickname:
            first_path = dataset_nickname.split(os.path.pathsep)[0]
            dataset_nickname = PathComponents(first_path).fileNameBase

        substitutions = {
            'dataset_dir': absolute_dataset_dir,
            'nickname': dataset_nickname,
        }

        # use partial formatting to fill in non-coordinate name fields
        settings['file path'] = format_known_keys(settings['file path'], substitutions)

        # FIXME: Even in non-headless mode, we can't show the gui because we're running in a non-main thread.
        #        That's not a huge deal, because there's still a progress bar for the overall export.
        export_request = tracking_op.getLane(lane_index).export_object_data(
            lane_index,
            show_gui=False)

        export_request.wait()
        progress.emit(100)

    def _inputReady(self, nRoles):
        """
        Tell whether every lane has its first ``nRoles`` input subslots ready.

        :param nRoles: number of leading subslots to check in each lane.
        :returns: ``False`` when the image group has no lanes; otherwise
            ``True`` only if ``ready()`` holds for every checked subslot.
        """
        lanes = self.dataSelectionApplet.topLevelOperator.ImageGroup
        if len(lanes) == 0:
            return False

        for lane_roles in lanes:
            # Query every role of this lane before deciding, then stop at the
            # first lane that is not fully ready.
            role_states = [lane_roles[role].ready() for role in range(nRoles)]
            if not all(role_states):
                return False
        return True

    def onProjectLoaded(self, projectManager):
        """
        Overridden from Workflow base class.  Called by the Project Manager.

        Applies parsed command-line export settings to the tracking export
        applet and, in headless mode with both batch-input and export
        arguments present, runs the batch export.

        :param projectManager: not used by this implementation.
        """
        # Configure the data export operator.
        if self._data_export_args:
            export_applet = self.dataExportTrackingApplet
            export_applet.configure_operator_with_parsed_args(self._data_export_args)

        # Configure headless mode.
        batch_requested = self._headless and self._batch_input_args and self._data_export_args
        if not batch_requested:
            return

        logger.info("Beginning Batch Processing")
        self.batchProcessingApplet.run_export_from_parsed_args(self._batch_input_args)
        logger.info("Completed Batch Processing")

    def handleAppletStateUpdateRequested(self):
        """
        Overridden from Workflow base class
        Called when an applet has fired the :py:attr:`Applet.statusUpdateSignal`

        Recomputes which applets can be used and tells the shell to enable or
        disable each of them; everything is disabled while any of the
        monitored applets is busy.
        """
        # If no data, nothing else is ready.
        have_input = self._inputReady(2) and not self.dataSelectionApplet.busy

        if self.fromBinary:
            segmentation_ok = have_input
        else:
            threshold_cache = self.thresholdTwoLevelsApplet.topLevelOperator.CachedOutput
            segmentation_ok = have_input and len(threshold_cache) > 0

        all_feature_names = self.trackingFeatureExtractionApplet.topLevelOperator.ComputedFeatureNamesAll
        tracking_features_ok = segmentation_ok and len(all_feature_names) > 0

        # Evaluated as before, but not consulted when enabling applets below.
        crops = self.cropSelectionApplet.topLevelOperator.Crops
        cropping_ready = segmentation_ok and len(crops) > 0

        classifier_ok = tracking_features_ok

        region_features = self.objectExtractionApplet.topLevelOperator.RegionFeatures
        features_ok = segmentation_ok and len(region_features) > 0

        # Evaluated as before; only the commented-out annotations export used it.
        annotations_op = self.annotationsApplet.topLevelOperator
        annotations_ready = (
            features_ok
            and len(annotations_op.Labels) > 0
            and annotations_op.Labels.ready()
            and annotations_op.TrackImage.ready()
        )

        tracking_op = self.trackingApplet.topLevelOperator
        tracking_ok = classifier_ok and len(tracking_op.EventsVector) > 0

        busy = any([
            self.dataSelectionApplet.busy,
            self.annotationsApplet.busy,
            # self.dataExportAnnotationsApplet.busy,
            self.trackingApplet.busy,
            self.dataExportTrackingApplet.busy,
        ])
        idle = not busy

        shell = self._shell
        shell.enableProjectChanges(idle)

        shell.setAppletEnabled(self.dataSelectionApplet, idle)
        if not self.fromBinary:
            shell.setAppletEnabled(self.thresholdTwoLevelsApplet, have_input and idle)

        # Applets that share the "<prerequisite> and idle" rule, in shell order.
        for applet, prerequisite in (
            (self.trackingFeatureExtractionApplet, segmentation_ok),
            (self.cellClassificationApplet, tracking_features_ok),
            (self.divisionDetectionApplet, tracking_features_ok),
            (self.cropSelectionApplet, segmentation_ok),
        ):
            shell.setAppletEnabled(applet, prerequisite and idle)

        shell.setAppletEnabled(self.objectExtractionApplet, idle)
        shell.setAppletEnabled(self.annotationsApplet, features_ok and idle)
        # self._shell.setAppletEnabled(self.dataExportAnnotationsApplet, annotations_ready and not busy and \
        #                                 self.dataExportAnnotationsApplet.topLevelOperator.Inputs[0][0].ready() )
        shell.setAppletEnabled(self.trackingApplet, classifier_ok and idle)

        # The export input is only queried once tracking is ready and nothing is busy.
        export_applet = self.dataExportTrackingApplet
        shell.setAppletEnabled(
            export_applet,
            tracking_ok and idle and export_applet.topLevelOperator.Inputs[0][0].ready())
예제 #26
0
class StructuredTrackingWorkflowBase(Workflow):
    workflowName = "Structured Learning Tracking Workflow BASE"

    @property
    def applets(self):
        """Return the list of applets stored in ``self._applets``."""
        registered_applets = self._applets
        return registered_applets

    @property
    def imageNameListSlot(self):
        """Return the ``ImageName`` slot of the data selection top-level operator."""
        data_selection_op = self.dataSelectionApplet.topLevelOperator
        return data_selection_op.ImageName

    def __init__(self, shell, headless, workflow_cmdline_args,
                 project_creation_args, *args, **kwargs):
        """
        Create the applets of the structured-learning tracking workflow.

        :param shell: forwarded unchanged to the ``Workflow`` constructor.
        :param headless: forwarded unchanged to the ``Workflow`` constructor.
        :param workflow_cmdline_args: optional sequence of command-line
            arguments.  ``--testFullAnnotations`` is handled here; the
            remaining arguments are offered to the data-export parser and
            whatever that parser leaves over is offered to the
            batch-processing parser.
        :param project_creation_args: forwarded unchanged to the ``Workflow``
            constructor.
        :param kwargs: an optional ``graph`` entry is used as the operator
            graph; a new ``Graph`` is created when it is absent.

        NOTE(review): ``self.fromBinary`` is read here but never assigned in
        this class -- presumably subclasses define it; confirm.
        """
        graph = kwargs.pop("graph") if "graph" in kwargs else Graph()

        super(StructuredTrackingWorkflowBase,
              self).__init__(shell,
                             headless,
                             workflow_cmdline_args,
                             project_creation_args,
                             graph=graph,
                             *args,
                             **kwargs)

        data_instructions = 'Use the "Raw Data" tab to load your intensity image(s).\n\n'
        if self.fromBinary:
            data_instructions += 'Use the "Binary Image" tab to load your segmentation image(s).'
        else:
            data_instructions += 'Use the "Prediction Maps" tab to load your pixel-wise probability image(s).'

        # Create applets
        self.dataSelectionApplet = DataSelectionApplet(
            self,
            "Input Data",
            "Input Data",
            batchDataGui=False,
            forceAxisOrder=["txyzc"],
            instructionText=data_instructions,
            max_lanes=1,
        )

        # The second input role depends on whether a segmentation is supplied
        # directly or has to be derived from prediction maps.
        opDataSelection = self.dataSelectionApplet.topLevelOperator
        if self.fromBinary:
            opDataSelection.DatasetRoles.setValue(["Raw Data", "Binary Image"])
        else:
            opDataSelection.DatasetRoles.setValue(
                ["Raw Data", "Prediction Maps"])

        # Thresholding is only needed when the input is not already binary.
        if not self.fromBinary:
            self.thresholdTwoLevelsApplet = ThresholdTwoLevelsApplet(
                self, "Threshold and Size Filter", "ThresholdTwoLevels")

        self.divisionDetectionApplet = ObjectClassificationApplet(
            workflow=self,
            name="Division Detection (optional)",
            projectFileGroupName="DivisionDetection",
            selectedFeatures=configConservation.selectedFeaturesDiv,
        )

        self.cellClassificationApplet = ObjectClassificationApplet(
            workflow=self,
            name="Object Count Classification",
            projectFileGroupName="CountClassification",
            selectedFeatures=configConservation.selectedFeaturesObjectCount,
        )

        self.trackingFeatureExtractionApplet = TrackingFeatureExtractionApplet(
            name="Object Feature Computation",
            workflow=self,
            interactive=False)

        self.objectExtractionApplet = ObjectExtractionApplet(
            name="Object Feature Computation",
            workflow=self,
            interactive=False)

        self.annotationsApplet = AnnotationsApplet(name="Training",
                                                   workflow=self)
        # NOTE(review): this local is not used below; the attribute access is
        # kept in case reading topLevelOperator has side effects -- confirm
        # before removing.
        opAnnotations = self.annotationsApplet.topLevelOperator

        self.trackingApplet = StructuredTrackingApplet(
            name="Tracking - Structured Learning", workflow=self)
        opStructuredTracking = self.trackingApplet.topLevelOperator

        # Map the configured solver backend to the solver family label that
        # is stored on the workflow and on the tracking operator.
        if SOLVER == "CPLEX" or SOLVER == "GUROBI":
            self._solver = "ILP"
        elif SOLVER == "DPCT":
            self._solver = "Flow-based"
        else:
            self._solver = None
        opStructuredTracking._solver = self._solver

        self.default_tracking_export_filename = "{dataset_dir}/{nickname}-tracking_exported_data.csv"
        self.dataExportTrackingApplet = TrackingBaseDataExportApplet(
            self,
            "Tracking Result Export",
            default_export_filename=self.default_tracking_export_filename,
            pluginExportFunc=self._pluginExportFunc,
        )
        opDataExportTracking = self.dataExportTrackingApplet.topLevelOperator
        opDataExportTracking.SelectionNames.setValue(
            ["Tracking-Result", "Merger-Result", "Object-Identities"])
        opDataExportTracking.WorkingDirectory.connect(
            opDataSelection.WorkingDirectory)
        self.dataExportTrackingApplet.set_exporting_operator(
            opStructuredTracking)
        self.dataExportTrackingApplet.prepare_lane_for_export = self.prepare_lane_for_export

        # configure export settings
        settings = {
            "file path": self.default_tracking_export_filename,
            "compression": {},
            "file type": "h5"
        }
        selected_features = [
            "Count", "RegionCenter", "RegionRadii", "RegionAxes"
        ]
        opStructuredTracking.ExportSettings.setValue(
            (settings, selected_features))

        self._applets = []
        self._applets.append(self.dataSelectionApplet)
        if not self.fromBinary:
            self._applets.append(self.thresholdTwoLevelsApplet)
        self._applets.append(self.trackingFeatureExtractionApplet)
        self._applets.append(self.divisionDetectionApplet)

        # The batch-processing applet is created here but is not appended to
        # ``self._applets``.
        self.batchProcessingApplet = BatchProcessingApplet(
            self, "Batch Processing", self.dataSelectionApplet,
            self.dataExportTrackingApplet)

        self._applets.append(self.cellClassificationApplet)
        self._applets.append(self.objectExtractionApplet)
        self._applets.append(self.annotationsApplet)
        self._applets.append(self.trackingApplet)
        self._applets.append(self.dataExportTrackingApplet)

        if self.divisionDetectionApplet:
            opDivDetection = self.divisionDetectionApplet.topLevelOperator
            opDivDetection.SelectedFeatures.setValue(
                configConservation.selectedFeaturesDiv)
            opDivDetection.LabelNames.setValue(["Not Dividing", "Dividing"])
            opDivDetection.AllowDeleteLabels.setValue(False)
            opDivDetection.AllowAddLabel.setValue(False)
            opDivDetection.EnableLabelTransfer.setValue(False)

        opCellClassification = self.cellClassificationApplet.topLevelOperator
        opCellClassification.SelectedFeatures.setValue(
            configConservation.selectedFeaturesObjectCount)
        opCellClassification.SuggestedLabelNames.setValue(
            ["False Detection"] + [str(1) + " Object"] +
            [str(i) + " Objects" for i in range(2, 10)])
        opCellClassification.AllowDeleteLastLabelOnly.setValue(True)
        opCellClassification.EnableLabelTransfer.setValue(False)

        if workflow_cmdline_args:
            self.testFullAnnotations = "--testFullAnnotations" in workflow_cmdline_args

            # The flag above is handled by this workflow itself, so it is
            # removed before the applet parsers see the arguments.
            remaining_args = [
                arg for arg in workflow_cmdline_args
                if arg != "--testFullAnnotations"
            ]

            # Chain the parsers: each one only receives what the previous one
            # did not recognise.  Passing the full argument list to both made
            # the batch parser see export options, and made the warning below
            # report arguments that had in fact been consumed.
            self._data_export_args, remaining_args = self.dataExportTrackingApplet.parse_known_cmdline_args(
                remaining_args)
            self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args(
                remaining_args)
        else:
            unused_args = None
            self._data_export_args = None
            self._batch_input_args = None
            self.testFullAnnotations = False

        if unused_args:
            logger.warning("Unused command-line args: {}".format(unused_args))

    def _pluginExportFunc(self, lane_index, filename, exportPlugin,
                          checkOverwriteFiles, plugArgsSlot) -> int:
        """
        Delegate a plugin export to the tracking operator of one lane.

        :param lane_index: lane whose tracking operator performs the export.
        :param filename: passed through to the lane's ``exportPlugin``.
        :param exportPlugin: passed through to the lane's ``exportPlugin``.
        :param checkOverwriteFiles: passed through to the lane's ``exportPlugin``.
        :param plugArgsSlot: passed through to the lane's ``exportPlugin``.
        :returns: whatever the lane's ``exportPlugin`` returns.
        """
        lane_op = self.trackingApplet.topLevelOperator.getLane(lane_index)
        return lane_op.exportPlugin(
            filename, exportPlugin, checkOverwriteFiles, plugArgsSlot)

    def connectLane(self, laneIndex):
        """
        Wire the per-lane operators of all applets together for one lane.

        Both input images are passed through an ``OpReorderAxes`` set to
        ``"txyzc"`` and then fed to feature extraction, the two object
        classifiers, the annotations operator, the structured tracking
        operator and finally the tracking export operator.

        :param laneIndex: index of the lane to connect.
        """
        # Look up this lane's view of every applet's top-level operator.
        opData = self.dataSelectionApplet.topLevelOperator.getLane(laneIndex)
        opObjExtraction = self.objectExtractionApplet.topLevelOperator.getLane(
            laneIndex)
        opTrackingFeatureExtraction = self.trackingFeatureExtractionApplet.topLevelOperator.getLane(
            laneIndex)

        opAnnotations = self.annotationsApplet.topLevelOperator.getLane(
            laneIndex)
        # The thresholding applet only exists when the input is not binary.
        if not self.fromBinary:
            opTwoLevelThreshold = self.thresholdTwoLevelsApplet.topLevelOperator.getLane(
                laneIndex)

        opStructuredTracking = self.trackingApplet.topLevelOperator.getLane(
            laneIndex)
        opDataTrackingExport = self.dataExportTrackingApplet.topLevelOperator.getLane(
            laneIndex)

        ## Connect operators ##
        # Input role 0 (the raw data) with its axes reordered to txyzc.
        op5Raw = OpReorderAxes(parent=self)
        op5Raw.AxisOrder.setValue("txyzc")
        op5Raw.Input.connect(opData.ImageGroup[0])

        opDivDetection = self.divisionDetectionApplet.topLevelOperator.getLane(
            laneIndex)
        opCellClassification = self.cellClassificationApplet.topLevelOperator.getLane(
            laneIndex)

        # Input role 1 is either thresholded first or used directly as the
        # binary image, depending on ``self.fromBinary``.
        if not self.fromBinary:
            opTwoLevelThreshold.InputImage.connect(opData.ImageGroup[1])
            opTwoLevelThreshold.RawInput.connect(
                opData.ImageGroup[0])  # Used for display only
            binarySrc = opTwoLevelThreshold.CachedOutput
        else:
            binarySrc = opData.ImageGroup[1]
        # Use Op5ifyers for both input datasets such that they are guaranteed to
        # have the same axis order after thresholding
        op5Binary = OpReorderAxes(parent=self)
        op5Binary.AxisOrder.setValue("txyzc")
        op5Binary.Input.connect(binarySrc)

        # Both feature extraction operators get the same reordered images.
        opObjExtraction.RawImage.connect(op5Raw.Output)
        opObjExtraction.BinaryImage.connect(op5Binary.Output)

        opTrackingFeatureExtraction.RawImage.connect(op5Raw.Output)
        opTrackingFeatureExtraction.BinaryImage.connect(op5Binary.Output)

        # Feature selection for tracking: the object-count feature set plus
        # every configured division feature (each with an empty settings dict).
        opTrackingFeatureExtraction.setDefaultFeatures(
            configConservation.allFeaturesObjectCount)
        opTrackingFeatureExtraction.FeatureNamesVigra.setValue(
            configConservation.allFeaturesObjectCount)
        feature_dict_division = {}
        feature_dict_division[config.features_division_name] = {
            name: {}
            for name in config.division_features
        }
        opTrackingFeatureExtraction.FeatureNamesDivision.setValue(
            feature_dict_division)

        # Division detection is fed from the tracking feature extraction
        # (all features, including the division features).
        if self.divisionDetectionApplet:
            opDivDetection.BinaryImages.connect(op5Binary.Output)
            opDivDetection.RawImages.connect(op5Raw.Output)
            opDivDetection.SegmentationImages.connect(
                opTrackingFeatureExtraction.LabelImage)
            opDivDetection.ObjectFeatures.connect(
                opTrackingFeatureExtraction.RegionFeaturesAll)
            opDivDetection.ComputedFeatureNames.connect(
                opTrackingFeatureExtraction.ComputedFeatureNamesAll)

        # Object-count classification uses the same features but the feature
        # name list without the division features.
        opCellClassification.BinaryImages.connect(op5Binary.Output)
        opCellClassification.RawImages.connect(op5Raw.Output)
        opCellClassification.SegmentationImages.connect(
            opTrackingFeatureExtraction.LabelImage)
        opCellClassification.ObjectFeatures.connect(
            opTrackingFeatureExtraction.RegionFeaturesAll)
        opCellClassification.ComputedFeatureNames.connect(
            opTrackingFeatureExtraction.ComputedFeatureNamesNoDivisions)

        # The annotations operator takes its label image and features from the
        # plain object extraction operator, and probabilities from both
        # classifiers.
        # NOTE(review): opDivDetection is used here without the
        # ``if self.divisionDetectionApplet`` guard applied elsewhere.
        opAnnotations.RawImage.connect(op5Raw.Output)
        opAnnotations.BinaryImage.connect(op5Binary.Output)
        opAnnotations.LabelImage.connect(opObjExtraction.LabelImage)
        opAnnotations.ObjectFeatures.connect(opObjExtraction.RegionFeatures)
        opAnnotations.ComputedFeatureNames.connect(opObjExtraction.Features)
        opAnnotations.DivisionProbabilities.connect(
            opDivDetection.Probabilities)
        opAnnotations.DetectionProbabilities.connect(
            opCellClassification.Probabilities)
        opAnnotations.MaxNumObj.connect(opCellClassification.MaxNumObj)

        # Structured tracking: images and features from the tracking feature
        # extraction operator.
        opStructuredTracking.RawImage.connect(op5Raw.Output)
        opStructuredTracking.LabelImage.connect(
            opTrackingFeatureExtraction.LabelImage)
        opStructuredTracking.ObjectFeatures.connect(
            opTrackingFeatureExtraction.RegionFeaturesVigra)
        opStructuredTracking.ComputedFeatureNames.connect(
            opTrackingFeatureExtraction.FeatureNamesVigra)

        if self.divisionDetectionApplet:
            opStructuredTracking.ObjectFeaturesWithDivFeatures.connect(
                opTrackingFeatureExtraction.RegionFeaturesAll)
            opStructuredTracking.ComputedFeatureNamesWithDivFeatures.connect(
                opTrackingFeatureExtraction.ComputedFeatureNamesAll)
            opStructuredTracking.DivisionProbabilities.connect(
                opDivDetection.Probabilities)

        # Classifier output and the user's annotations feed the tracker.
        opStructuredTracking.DetectionProbabilities.connect(
            opCellClassification.Probabilities)
        opStructuredTracking.NumLabels.connect(opCellClassification.NumLabels)
        opStructuredTracking.Annotations.connect(opAnnotations.Annotations)
        opStructuredTracking.Labels.connect(opAnnotations.Labels)
        opStructuredTracking.Divisions.connect(opAnnotations.Divisions)
        opStructuredTracking.Appearances.connect(opAnnotations.Appearances)
        opStructuredTracking.Disappearances.connect(
            opAnnotations.Disappearances)
        opStructuredTracking.MaxNumObj.connect(opCellClassification.MaxNumObj)

        # Three export inputs, one per tracking output; presumably in the same
        # order as the export SelectionNames set in __init__ -- confirm.
        opDataTrackingExport.Inputs.resize(3)
        opDataTrackingExport.Inputs[0].connect(opStructuredTracking.Output)
        opDataTrackingExport.Inputs[1].connect(
            opStructuredTracking.MergerOutput)
        opDataTrackingExport.Inputs[2].connect(
            opStructuredTracking.RelabeledImage)
        opDataTrackingExport.RawData.connect(op5Raw.Output)
        opDataTrackingExport.RawDatasetInfo.connect(opData.DatasetGroup[0])

    def prepare_lane_for_export(self, lane_index):
        """
        Run tracking (and optionally structured learning) before a lane is exported.

        The tracked region spans the full extent of the lane's raw image
        (the first four entries of its shape are read as t, x, y, z); all
        other settings come from the tracking operator's ``Parameters`` slot.
        Annotations are first copied from lane 0 into the lane being exported.

        When ``self.testFullAnnotations`` is set, tracking is first run
        without merger resolution and every detection, division and linking
        entry of the resulting solution must also be present in the solution
        derived from the annotations.  Tracking is then run again with the
        configured merger-resolution setting.

        Side effects: sets ``self.prev_time_range``, ``self.prev_x_range``,
        ``self.prev_y_range``, ``self.prev_z_range`` and ``self.result``.

        :param lane_index: index of the lane that is about to be exported.
        :raises AssertionError: in ``testFullAnnotations`` mode, when the
            tracking result differs from the annotations.
        """
        import logging

        logger = logging.getLogger(__name__)

        tracking_op = self.trackingApplet.topLevelOperator
        lane_op = tracking_op[lane_index]

        shape = lane_op.RawImage.meta.shape
        maxt, maxx, maxy, maxz = shape[0], shape[1], shape[2], shape[3]
        time_enum = list(range(maxt))
        x_range = (0, maxx)
        y_range = (0, maxy)
        z_range = (0, maxz)

        # More than one z-slice means the data is treated as 3D.
        ndim = 3 if (z_range[1] - z_range[0]) > 1 else 2

        parameters = tracking_op.Parameters.value
        # Save state of axis ranges
        self.prev_time_range = parameters["time_range"] if "time_range" in parameters else time_enum
        self.prev_x_range = parameters["x_range"] if "x_range" in parameters else x_range
        self.prev_y_range = parameters["y_range"] if "y_range" in parameters else y_range
        self.prev_z_range = parameters["z_range"] if "z_range" in parameters else z_range

        # batch processing starts a new lane, so training data needs to be copied from the lane that loaded the project
        loaded_project_lane_index = 0
        self.annotationsApplet.topLevelOperator[lane_index].Annotations.setValue(
            tracking_op[loaded_project_lane_index].Annotations.value)

        def runLearningAndTracking(withMergerResolution=True):
            """Optionally learn weights, then track the lane and return the result."""
            scales = parameters["scales"]
            if self.testFullAnnotations:
                logger.info("Test: Structured Learning")
                weights = lane_op._runStructuredLearning(
                    z_range,
                    parameters["maxObj"],
                    parameters["max_nearest_neighbors"],
                    parameters["maxDist"],
                    parameters["divThreshold"],
                    [scales[0], scales[1], scales[2]],
                    parameters["size_range"],
                    parameters["withDivisions"],
                    parameters["borderAwareWidth"],
                    parameters["withClassifierPrior"],
                    withBatchProcessing=True,
                )
                logger.info("weights: {}".format(weights))

            logger.info("Test: Tracking")
            return lane_op.track(
                time_range=time_enum,
                x_range=x_range,
                y_range=y_range,
                z_range=z_range,
                size_range=parameters["size_range"],
                x_scale=scales[0],
                y_scale=scales[1],
                z_scale=scales[2],
                maxDist=parameters["maxDist"],
                maxObj=parameters["maxObj"],
                divThreshold=parameters["divThreshold"],
                avgSize=parameters["avgSize"],
                withTracklets=parameters["withTracklets"],
                sizeDependent=parameters["sizeDependent"],
                detWeight=parameters["detWeight"],
                divWeight=parameters["divWeight"],
                transWeight=parameters["transWeight"],
                withDivisions=parameters["withDivisions"],
                withOpticalCorrection=parameters["withOpticalCorrection"],
                withClassifierPrior=parameters["withClassifierPrior"],
                ndim=ndim,
                withMergerResolution=withMergerResolution,
                borderAwareWidth=parameters["borderAwareWidth"],
                withArmaCoordinates=parameters["withArmaCoordinates"],
                cplex_timeout=parameters["cplex_timeout"],
                appearance_cost=parameters["appearanceCost"],
                disappearance_cost=parameters["disappearanceCost"],
                force_build_hypotheses_graph=False,
                withBatchProcessing=True,
            )

        def everyEntryAnnotated(entries, annotatedEntries, fields):
            """Return True if each entry equals some annotated entry on all ``fields``."""
            return all(
                any(
                    all(entry[field] == annotated[field] for field in fields)
                    for annotated in annotatedEntries)
                for entry in entries)

        if self.testFullAnnotations:

            self.result = runLearningAndTracking(withMergerResolution=False)

            hypothesesGraph = lane_op.LearningHypothesesGraph.value
            hypothesesGraph.insertSolution(self.result)
            hypothesesGraph.computeLineage()
            solution = hypothesesGraph.getSolutionDictionary()
            annotations = lane_op.Annotations.value

            lane_op.insertAnnotationsToHypothesesGraph(
                hypothesesGraph, annotations, misdetectionLabel=-1)
            hypothesesGraph.computeLineage()
            solutionFromAnnotations = hypothesesGraph.getSolutionDictionary()

            # (solution key, fields that must agree, failure text, success text)
            checks = (
                ("detectionResults", ("id", "value"),
                 "Detection results are NOT correct. They differ from your annotated detections.",
                 "Detection results are correct."),
                ("divisionResults", ("id", "value"),
                 "Division results are NOT correct. They differ from your annotated divisions.",
                 "Division results are correct."),
                ("linkingResults", ("dest", "src", "gap", "value"),
                 "Transition results are NOT correct. They differ from your annotated transitions.",
                 "Transition results are correct."),
            )

            # A category missing from the tracking solution has nothing to
            # compare, so it counts as matching instead of leaving its flag
            # undefined.
            outcomes = [
                key not in solution or everyEntryAnnotated(
                    solution[key], solutionFromAnnotations[key], fields)
                for key, fields, _failure, _success in checks
            ]

            for matched, (_key, _fields, failure, success) in zip(outcomes, checks):
                if not matched:
                    # Raised explicitly so the check survives ``python -O``.
                    raise AssertionError(failure)
                logger.info(success)

        self.result = runLearningAndTracking(
            withMergerResolution=parameters["withMergerResolution"])

    def _inputReady(self, nRoles):
        """
        Tell whether every lane has its first ``nRoles`` input subslots ready.

        :param nRoles: number of leading subslots to check in each lane.
        :returns: ``False`` when the image group has no lanes; otherwise
            ``True`` only if ``ready()`` holds for every checked subslot.
        """
        image_group = self.dataSelectionApplet.topLevelOperator.ImageGroup
        if len(image_group) == 0:
            return False

        for lane_slot in image_group:
            # All roles of the lane are queried before the verdict for that
            # lane is taken; later lanes are skipped after the first failure.
            states = [lane_slot[role_index].ready() for role_index in range(nRoles)]
            if not all(states):
                return False
        return True

    def onProjectLoaded(self, projectManager):
        """
        Overridden from Workflow base class.  Called by the Project Manager.

        Applies parsed command-line export settings to the tracking export
        applet and, in headless mode with both batch-input and export
        arguments present, runs the batch export.

        :param projectManager: not used by this implementation.
        """
        # Configure the data export operator.
        if self._data_export_args:
            exporter = self.dataExportTrackingApplet
            exporter.configure_operator_with_parsed_args(self._data_export_args)

        # Configure headless mode.
        run_batch = self._headless and self._batch_input_args and self._data_export_args
        if not run_batch:
            return

        logger.info("Beginning Batch Processing")
        self.batchProcessingApplet.run_export_from_parsed_args(
            self._batch_input_args)
        logger.info("Completed Batch Processing")

    def handleAppletStateUpdateRequested(self):
        """
        Overridden from Workflow base class
        Called when an applet has fired the :py:attr:`Applet.statusUpdateSignal`

        Recomputes which applets can be used and tells the shell to enable or
        disable each of them; everything is disabled while any of the
        monitored applets is busy.
        """
        # If no data, nothing else is ready.
        have_input = self._inputReady(2) and not self.dataSelectionApplet.busy

        if self.fromBinary:
            segmentation_ok = have_input
        else:
            threshold_cache = self.thresholdTwoLevelsApplet.topLevelOperator.CachedOutput
            segmentation_ok = have_input and len(threshold_cache) > 0

        all_feature_names = self.trackingFeatureExtractionApplet.topLevelOperator.ComputedFeatureNamesAll
        tracking_features_ok = segmentation_ok and len(all_feature_names) > 0

        classifier_ok = tracking_features_ok

        region_features = self.objectExtractionApplet.topLevelOperator.RegionFeatures
        features_ok = segmentation_ok and len(region_features) > 0

        # Evaluated as before; only the commented-out annotations export used it.
        annotations_op = self.annotationsApplet.topLevelOperator
        annotations_ready = (
            features_ok
            and len(annotations_op.Labels) > 0
            and annotations_op.Labels.ready()
            and annotations_op.TrackImage.ready()
        )

        # The two values below are read as before but not consulted afterwards.
        tracking_op = self.trackingApplet.topLevelOperator
        tracking_ok = classifier_ok

        withIlpSolver = self._solver == "ILP"

        busy = any([
            self.dataSelectionApplet.busy,
            self.annotationsApplet.busy,
            # self.dataExportAnnotationsApplet.busy,
            self.trackingApplet.busy,
            self.dataExportTrackingApplet.busy,
        ])
        idle = not busy

        shell = self._shell
        shell.enableProjectChanges(idle)

        shell.setAppletEnabled(self.dataSelectionApplet, idle)
        if not self.fromBinary:
            shell.setAppletEnabled(self.thresholdTwoLevelsApplet, have_input and idle)

        # Applets that share the "<prerequisite> and idle" rule, in shell order.
        for applet, prerequisite in (
            (self.trackingFeatureExtractionApplet, segmentation_ok),
            (self.cellClassificationApplet, tracking_features_ok),
            (self.divisionDetectionApplet, tracking_features_ok),
        ):
            shell.setAppletEnabled(applet, prerequisite and idle)

        shell.setAppletEnabled(self.objectExtractionApplet, idle)
        shell.setAppletEnabled(self.annotationsApplet, features_ok and idle)  # and withIlpSolver)
        # self._shell.setAppletEnabled(self.dataExportAnnotationsApplet, annotations_ready and not busy and \
        #                                 self.dataExportAnnotationsApplet.topLevelOperator.Inputs[0][0].ready() )
        shell.setAppletEnabled(self.trackingApplet, classifier_ok and idle)

        # The export input is only queried once tracking is ready and nothing is busy.
        export_applet = self.dataExportTrackingApplet
        shell.setAppletEnabled(
            export_applet,
            tracking_ok and idle and export_applet.topLevelOperator.Inputs[0][0].ready(),
        )
class ObjectClassificationWorkflow(Workflow):
    """Base workflow for object classification.

    Wires the data selection, object extraction, object classification,
    data export, blockwise classification and (optionally) batch processing
    applets together.  Subclasses must implement ``connectInputs()``, which
    supplies the raw and binary image slots for each lane.
    """
    workflowName = "Object Classification Workflow Base"
    defaultAppletIndex = 0  # show DataSelection by default

    @property
    def ExportNames(self):
        """Return the enum of export sources offered by this workflow.

        The display names of the members are published as the data export
        applet's ``SelectionNames`` in ``__init__`` and the members themselves
        are used to index ``opDataExport.Inputs`` in ``connectLane``.
        """
        # NOTE: the enum class is re-created on every access of this property.
        @enum.unique
        class ExportNames(SlotNameEnum):
            OBJECT_PREDICTIONS = enum.auto()
            OBJECT_PROBABILITIES = enum.auto()
            BLOCKWISE_OBJECT_PREDICTIONS = enum.auto()
            BLOCKWISE_OBJECT_PROBABILITIES = enum.auto()
            OBJECT_IDENTITIES = enum.auto()

        return ExportNames

    # Dataset roles registered with the data selection applet
    # (see ``createInputApplets``, which passes their display names to ``DatasetRoles``).
    class InputImageRoles(SlotNameEnum):
        RAW_DATA = enum.auto()
        ATLAS = enum.auto()  # optional input; ``_inputReady`` does not require it

    @property
    def data_instructions(self):
        return (
            f'Use the "{self.InputImageRoles.RAW_DATA.displayName}" tab to load your intensity image(s).\n\n'
            f'Use the (optional) "{self.InputImageRoles.ATLAS.displayName}" tab if you want to map your objects to colors in an Atlas image.\n\n'
        )

    def __init__(self, shell, headless, workflow_cmdline_args,
                 project_creation_args, *args, **kwargs):
        """Build the applet chain and parse the workflow-specific command line.

        ``project_creation_args`` decides the ``--fillmissing`` mode of the
        project; ``workflow_cmdline_args`` supplies ``--nobatch`` and, when
        batch processing is enabled, the export/batch arguments that are
        later applied in ``onProjectLoaded``.  A ``graph`` keyword argument is
        reused when given, otherwise a fresh ``Graph`` is created.
        """
        if "graph" in kwargs:
            graph = kwargs.pop("graph")
        else:
            graph = Graph()
        super().__init__(shell, headless, workflow_cmdline_args,
                         project_creation_args, graph=graph, *args, **kwargs)
        self.stored_object_classifier = None

        # Options understood by this workflow itself
        arg_parser = argparse.ArgumentParser()
        arg_parser.add_argument(
            "--fillmissing",
            help="use 'fill missing' applet with chosen detection method",
            choices=["classic", "svm", "none"],
            default="none",
        )
        arg_parser.add_argument(
            "--nobatch",
            help="do not append batch applets",
            action="store_true",
            default=False,
        )

        # The fill-missing mode is fixed when the project is created ...
        creation_options, _ = arg_parser.parse_known_args(project_creation_args)
        self.fillMissing = creation_options.fillmissing

        # ... so a conflicting value on the current command line is only reported.
        cmdline_options, unused_args = arg_parser.parse_known_args(workflow_cmdline_args)
        requested_fill = cmdline_options.fillmissing
        if requested_fill != "none" and creation_options.fillmissing != requested_fill:
            logger.error(
                "Ignoring --fillmissing cmdline arg.  Can't specify a different fillmissing setting after the project has already been created."
            )

        self.batch = not cmdline_options.nobatch

        self._applets = []

        self.createInputApplets()

        if self.fillMissing != "none":
            self.fillMissingSlicesApplet = FillMissingSlicesApplet(
                self, "Fill Missing Slices", "Fill Missing Slices", self.fillMissing
            )
            self._applets.append(self.fillMissingSlicesApplet)

        # Main applets
        self.objectExtractionApplet = ObjectExtractionApplet(
            workflow=self, name="Object Feature Selection"
        )
        self.objectClassificationApplet = ObjectClassificationApplet(workflow=self)
        self._tableExporter = TableExporter(
            self.objectClassificationApplet.topLevelOperator
        )
        self.dataExportApplet = ObjectClassificationDataExportApplet(
            self, "Object Information Export", table_exporter=self._tableExporter
        )

        # Export hooks: (un)freeze predictions around a full export
        self.dataExportApplet.prepare_for_entire_export = self.prepare_for_entire_export
        self.dataExportApplet.post_process_entire_export = self.post_process_entire_export

        export_op = self.dataExportApplet.topLevelOperator
        export_op.WorkingDirectory.connect(
            self.dataSelectionApplet.topLevelOperator.WorkingDirectory
        )
        export_op.SelectionNames.setValue(self.ExportNames.asDisplayNameList())

        # Filled in below only when batch processing is enabled and args were given
        self._batch_export_args = None
        self._batch_input_args = None
        self._export_args = None
        self.batchProcessingApplet = None

        self._applets.extend(
            [
                self.objectExtractionApplet,
                self.objectClassificationApplet,
                self.dataExportApplet,
            ]
        )

        if self.batch:
            self.batchProcessingApplet = BatchProcessingApplet(
                self, "Batch Processing", self.dataSelectionApplet, self.dataExportApplet
            )
            self._applets.append(self.batchProcessingApplet)

            if unused_args:
                exports_parser, _ = self.exportsArgParser
                self._export_args, unused_args = exports_parser.parse_known_args(unused_args)

                # Export settings are parsed first; whatever remains is offered
                # to the batch applet as input files.
                self._batch_export_args, unused_args = (
                    self.dataExportApplet.parse_known_cmdline_args(unused_args)
                )
                self._batch_input_args, unused_args = (
                    self.batchProcessingApplet.parse_known_cmdline_args(unused_args)
                )

                # Backwards compatibility: map the special export flags onto the standard syntax
                self._batch_input_args.export_source = self._export_args.export_source

        self.blockwiseObjectClassificationApplet = BlockwiseObjectClassificationApplet(
            self, "Blockwise Object Classification", "Blockwise Object Classification"
        )
        self._applets.append(self.blockwiseObjectClassificationApplet)

        if unused_args:
            logger.warning("Unused command-line args: {}".format(unused_args))

    def createInputApplets(self):
        """Create the data selection applet, register the input roles and expose it to the shell."""
        applet = DataSelectionApplet(
            self,
            "Input Data",
            "Input Data",
            batchDataGui=False,
            forceAxisOrder=None,
            instructionText=self.data_instructions,
        )
        self.dataSelectionApplet = applet

        data_op = applet.topLevelOperator
        data_op.DatasetRoles.setValue(self.InputImageRoles.asDisplayNameList())
        self._applets.append(applet)

    @property
    def exportsArgParser(self):
        """Build the parser for the export-related command-line flags.

        Returns a ``(parser, group)`` tuple: the parser and the mutually
        exclusive group that holds the ``export_source`` flags.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "--table_filename",
            help="The location to export the object feature/prediction CSV file.",
        )

        image_flags = parser.add_mutually_exclusive_group()
        flag_to_source = (
            ("--export_object_prediction_img", self.ExportNames.OBJECT_PREDICTIONS),
            ("--export_object_probability_img", self.ExportNames.OBJECT_PROBABILITIES),
        )
        for flag, source in flag_to_source:
            image_flags.add_argument(
                flag,
                dest="export_source",
                action=_DeprecatedStoreConstAction,
                const=source.displayName,
            )
        return parser, image_flags

    @property
    def applets(self):
        return self._applets

    @property
    def imageNameListSlot(self):
        return self.dataSelectionApplet.topLevelOperator.ImageName

    def prepareForNewLane(self, laneIndex):
        opObjectClassification = self.objectClassificationApplet.topLevelOperator
        if (opObjectClassification.classifier_cache.Output.ready()
                and not opObjectClassification.classifier_cache._dirty):
            self.stored_object_classifier = opObjectClassification.classifier_cache.Output.value
        else:
            self.stored_object_classifier = None

    def handleNewLanesAdded(self):
        """
        If new lanes were added, then we invalidated our classifiers unecessarily.
        Here, we can restore the classifier so it doesn't need to be retrained.
        """
        if self.stored_object_classifier:
            opObjectClassification = self.objectClassificationApplet.topLevelOperator
            opObjectClassification.classifier_cache.forceValue(
                self.stored_object_classifier, set_dirty=False)
            # Release reference
            self.stored_object_classifier = None

    def getImageSlot(self, input_role, laneIndex) -> OutputSlot:
        opData = self.dataSelectionApplet.topLevelOperator.getLane(laneIndex)
        return opData.ImageGroup[input_role]

    def toDefaultAxisOrder(self, slot):
        """Feed ``slot`` through a new ``OpReorderAxes`` ("txyzc") and return that operator's output."""
        reorder_op = OpReorderAxes(parent=self, AxisOrder="txyzc", Input=slot)
        return reorder_op.Output

    def createRawDataSourceSlot(self, laneIndex, canonicalOrder=True):
        rawslot = self.getImageSlot(self.InputImageRoles.RAW_DATA, laneIndex)
        if self.fillMissing != "none":
            opFillMissingSlices = self.fillMissingSlicesApplet.topLevelOperator.getLane(
                laneIndex)
            opFillMissingSlices.Input.connect(rawslot)
            rawslot = opFillMissingSlices.Output

        if canonicalOrder:
            rawslot = self.toDefaultAxisOrder(rawslot)

        return rawslot

    def createAtlasSourceSlot(self, laneIndex):
        rawAtlasSlot = self.getImageSlot(self.InputImageRoles.ATLAS, laneIndex)
        return self.toDefaultAxisOrder(rawAtlasSlot)

    @abstractmethod
    def connectInputs(self, laneIndex):
        pass

    def connectLane(self, laneIndex):
        rawslot, binaryslot = self.connectInputs(laneIndex)
        atlas_slot = self.createAtlasSourceSlot(laneIndex)

        opData = self.dataSelectionApplet.topLevelOperator.getLane(laneIndex)

        opObjExtraction = self.objectExtractionApplet.topLevelOperator.getLane(
            laneIndex)
        opObjClassification = self.objectClassificationApplet.topLevelOperator.getLane(
            laneIndex)
        opDataExport = self.dataExportApplet.topLevelOperator.getLane(
            laneIndex)
        opBlockwiseObjectClassification = self.blockwiseObjectClassificationApplet.topLevelOperator.getLane(
            laneIndex)

        opObjExtraction.RawImage.connect(rawslot)
        opObjExtraction.BinaryImage.connect(binaryslot)
        opObjExtraction.Atlas.connect(atlas_slot)

        opObjClassification.RawImages.connect(rawslot)
        opObjClassification.BinaryImages.connect(binaryslot)
        opObjClassification.Atlas.connect(atlas_slot)

        opObjClassification.SegmentationImages.connect(
            opObjExtraction.LabelImage)
        opObjClassification.ObjectFeatures.connect(
            opObjExtraction.RegionFeatures)
        opObjClassification.ComputedFeatureNames.connect(
            opObjExtraction.Features)

        # Data Export connections
        opDataExport.RawData.connect(
            opData.ImageGroup[self.InputImageRoles.RAW_DATA])
        opDataExport.RawDatasetInfo.connect(
            opData.DatasetGroup[self.InputImageRoles.RAW_DATA])
        opDataExport.Inputs.resize(len(self.ExportNames))
        opDataExport.Inputs[self.ExportNames.OBJECT_PREDICTIONS].connect(
            opObjClassification.UncachedPredictionImages)
        opDataExport.Inputs[self.ExportNames.OBJECT_PROBABILITIES].connect(
            opObjClassification.ProbabilityChannelImage)
        opDataExport.Inputs[
            self.ExportNames.BLOCKWISE_OBJECT_PREDICTIONS].connect(
                opBlockwiseObjectClassification.PredictionImage)
        opDataExport.Inputs[
            self.ExportNames.BLOCKWISE_OBJECT_PROBABILITIES].connect(
                opBlockwiseObjectClassification.ProbabilityChannelImage)
        opDataExport.Inputs[self.ExportNames.OBJECT_IDENTITIES].connect(
            opObjClassification.SegmentationImagesOut)

        opObjClassification = self.objectClassificationApplet.topLevelOperator.getLane(
            laneIndex)
        opBlockwiseObjectClassification = self.blockwiseObjectClassificationApplet.topLevelOperator.getLane(
            laneIndex)

        opBlockwiseObjectClassification.RawImage.connect(
            opObjClassification.RawImages)
        opBlockwiseObjectClassification.BinaryImage.connect(
            opObjClassification.BinaryImages)
        opBlockwiseObjectClassification.Classifier.connect(
            opObjClassification.Classifier)
        opBlockwiseObjectClassification.LabelsCount.connect(
            opObjClassification.NumLabels)
        opBlockwiseObjectClassification.SelectedFeatures.connect(
            opObjClassification.SelectedFeatures)

    def onProjectLoaded(self, projectManager):
        if not self._headless:
            return

        if not (self._batch_input_args and self._batch_export_args):
            logger.warning(
                "Was not able to understand the batch mode command-line arguments."
            )

        # Check for problems: Is the project file ready to use?
        opObjClassification = self.objectClassificationApplet.topLevelOperator
        if not opObjClassification.Classifier.ready():
            logger.error(
                "Can't run batch prediction.\n"
                "Couldn't obtain a classifier from your project file: {}.\n"
                "Please make sure your project is fully configured with a trained classifier."
                .format(projectManager.currentProjectPath))
            return

        # Configure the data export operator.
        if self._batch_export_args:
            self.dataExportApplet.configure_operator_with_parsed_args(
                self._batch_export_args)

        if self._export_args:
            csv_filename = self._export_args.table_filename
            if csv_filename:
                # The user wants to override the csv export location via
                #  the command-line arguments. Apply the new setting to the operator.
                self._tableExporter.override_file_path(csv_filename)

        # Configure the batch data selection operator.
        if self._batch_input_args and self._batch_input_args.raw_data:
            logger.info("Beginning Batch Processing")
            self.batchProcessingApplet.run_export_from_parsed_args(
                self._batch_input_args)
            logger.info("Completed Batch Processing")

    def prepare_for_entire_export(self):
        # Un-freeze the workflow so we don't just get a bunch of zeros from the caches when we ask for results
        self.oc_freeze_status = self.objectClassificationApplet.topLevelOperator.FreezePredictions.value
        self.objectClassificationApplet.topLevelOperator.FreezePredictions.setValue(
            False)

    def post_process_entire_export(self):
        # Unfreeze.
        self.objectClassificationApplet.topLevelOperator.FreezePredictions.setValue(
            self.oc_freeze_status)

    def getHeadlessOutputSlot(self, slotId):
        if slotId == "BatchPredictionImage":
            return self.opBatchClassify.PredictionImage
        raise Exception("Unknown headless output slot")

    def handleAppletStateUpdateRequested(self, upstream_ready=False):
        """
        Overridden from Workflow base class
        Called when an applet has fired the :py:attr:`Applet.appletStateUpdateRequested`

        This method will be called by the child classes with the result of their
        own applet readyness findings as keyword argument.
        """

        # all workflows have these applets in common:

        # object feature selection
        # object classification
        # object prediction export
        # blockwise classification
        # batch input
        # batch prediction export

        self._shell.setAppletEnabled(self.dataSelectionApplet,
                                     not self.batchProcessingApplet.busy)

        cumulated_readyness = upstream_ready
        cumulated_readyness &= (
            not self.batchProcessingApplet.busy
        )  # Nothing can be touched while batch mode is executing.

        self._shell.setAppletEnabled(self.objectExtractionApplet,
                                     cumulated_readyness)

        object_features_ready = (
            self.objectExtractionApplet.topLevelOperator.Features.ready() and
            len(self.objectExtractionApplet.topLevelOperator.Features.value) >
            0)
        cumulated_readyness = cumulated_readyness and object_features_ready
        self._shell.setAppletEnabled(self.objectClassificationApplet,
                                     cumulated_readyness)

        opObjectClassification = self.objectClassificationApplet.topLevelOperator
        invalid_classifier = (
            opObjectClassification.classifier_cache.fixAtCurrent.value
            and opObjectClassification.classifier_cache.Output.ready()
            and opObjectClassification.classifier_cache.Output.value is None)

        invalid_classifier |= not opObjectClassification.NumLabels.ready(
        ) or opObjectClassification.NumLabels.value < 2

        object_classification_ready = object_features_ready and not invalid_classifier

        cumulated_readyness = cumulated_readyness and object_classification_ready
        self._shell.setAppletEnabled(self.dataExportApplet,
                                     cumulated_readyness)

        if self.batch:
            object_prediction_ready = True  # TODO is that so?
            cumulated_readyness = cumulated_readyness and object_prediction_ready

            self._shell.setAppletEnabled(
                self.blockwiseObjectClassificationApplet, cumulated_readyness)
            self._shell.setAppletEnabled(self.batchProcessingApplet,
                                         cumulated_readyness)

        # Lastly, check for certain "busy" conditions, during which we
        # should prevent the shell from closing the project.
        # TODO implement
        busy = False
        self._shell.enableProjectChanges(not busy)

    def _inputReady(self):
        image_group_slot = self.dataSelectionApplet.topLevelOperator.ImageGroup
        for input_lane_slot in image_group_slot:
            for role in self.InputImageRoles:
                if role == self.InputImageRoles.ATLAS:
                    continue
                if not input_lane_slot[role].ready():
                    return False
        return bool(len(image_group_slot))

    def postprocessClusterSubResult(self, roi, result, blockwise_fileset):
        """
        This function is only used by special cluster scripts.

        When the batch-processing mechanism was rewritten, this function broke.
        It could probably be fixed with minor changes.

        For the block starting at ``roi[0]``, filters the blockwise region
        features down to positively predicted objects that lie inside the
        block, shifts their coordinates by the halo start plus the fileset's
        ``view_origin``, and writes them both as a pickled hdf5 dataset and as
        a "block-pointcloud.csv" file next to the block's h5 file.

        :param roi: (start, stop) pair; only ``roi[0]`` is used here.
        :param result: not used by this function.
        :param blockwise_fileset: provides ``description.sub_block_shape``,
            ``description.view_origin``, ``getOpenHdf5FileForBlock`` and
            ``getDatasetDirectory``.
        """
        assert sys.version_info.major == 2, (
            "Alert! This function has not been "
            "tested under python 3. Please remove this assertion, and be wary of any "
            "strange behavior you encounter")

        # TODO: Here, we hard-code to select from the first lane only.
        # NOTE(review): ``opBatchClassify`` is not assigned anywhere in the visible
        # part of this class -- presumably a leftover from before the batch rewrite.
        opBatchClassify = self.opBatchClassify[0]

        # NOTE(review): 'io_uti' looks like a typo for 'io_util' -- confirm against lazyflow.
        from lazyflow.utility.io_uti.blockwiseFileset import vectorized_pickle_dumps

        # Assume that roi always starts as a multiple of the blockshape
        block_shape = opBatchClassify.get_blockshape()
        assert all(block_shape == blockwise_fileset.description.sub_block_shape
                   ), "block shapes don't match"
        assert all(
            (roi[0] % block_shape) == 0
        ), "Sub-blocks must exactly correspond to the blockwise object classification blockshape"
        sub_block_index = roi[
            0] // blockwise_fileset.description.sub_block_shape

        # The features are requested by block index: exactly one block per request
        sub_block_start = sub_block_index
        sub_block_stop = sub_block_start + 1
        sub_block_roi = (sub_block_start, sub_block_stop)

        # FIRST, remove all objects that lie outside the block (i.e. remove the ones in the halo)
        region_features = opBatchClassify.BlockwiseRegionFeatures(
            *sub_block_roi).wait()
        region_features_dict = region_features.flat[0]
        region_centers = region_features_dict[default_features_key][
            "RegionCenter"]

        opBlockPipeline = opBatchClassify._blockPipelines[tuple(roi[0])]

        # Compute the block offset within the image coordinates
        halo_roi = opBlockPipeline._halo_roi

        # Drop the first and last axis of the halo start (t and c, going by the FIXME below)
        translated_region_centers = region_centers + halo_roi[0][1:-1]

        # TODO: If this is too slow, vectorize this
        mask = numpy.zeros(region_centers.shape[0], dtype=numpy.bool_)
        for index, translated_region_center in enumerate(
                translated_region_centers):
            # FIXME: Here we assume t=0 and c=0
            mask[index] = opBatchClassify.is_in_block(
                roi[0], (0, ) + tuple(translated_region_center) + (0, ))

        # Always exclude the first object (it's the background??)
        mask[0] = False

        # Remove all 'negative' predictions, emit only 'positive' predictions
        # FIXME: Don't hardcode this?
        POSITIVE_LABEL = 2
        objectwise_predictions = opBlockPipeline.ObjectwisePredictions(
            []).wait()[0]
        assert objectwise_predictions.shape == mask.shape
        mask[objectwise_predictions != POSITIVE_LABEL] = False

        # Apply the mask to every feature array of every feature group
        filtered_features = {}
        for feature_group, feature_dict in list(region_features_dict.items()):
            filtered_group = filtered_features[feature_group] = {}
            for feature_name, feature_array in list(feature_dict.items()):
                filtered_group[feature_name] = feature_array[mask]

        # SECOND, translate from block-local coordinates to global (file) coordinates.
        # Unfortunately, we've got multiple translations to perform here:
        # Coordinates in the region features are relative to their own block INCLUDING HALO,
        #  so we need to add the start of the block-with-halo as an offset.
        # BUT the image itself may be offset relative to the BlockwiseFileset coordinates
        #  (due to the view_origin setting), so we also need to add an offset for that, too

        # Get the image offset relative to the file coordinates
        image_offset = blockwise_fileset.description.view_origin

        total_offset_5d = halo_roi[0] + image_offset
        total_offset_3d = total_offset_5d[1:-1]

        # In-place shift of all coordinate features
        filtered_features[default_features_key][
            "RegionCenter"] += total_offset_3d
        filtered_features[default_features_key][
            "Coord<Minimum>"] += total_offset_3d
        filtered_features[default_features_key][
            "Coord<Maximum>"] += total_offset_3d

        # Finally, write the features to hdf5
        h5File = blockwise_fileset.getOpenHdf5FileForBlock(roi[0])
        if "pickled_region_features" in h5File:
            del h5File["pickled_region_features"]

        # Must use str dtype
        # NOTE(review): ``h5py.new_vlen`` is a legacy h5py API -- confirm it exists
        # in the installed h5py version before reviving this function.
        dtype = h5py.new_vlen(str)
        dataset = h5File.create_dataset("pickled_region_features",
                                        shape=(1, ),
                                        dtype=dtype)
        pickled_features = vectorized_pickle_dumps(
            numpy.array((filtered_features, )))
        dataset[0] = pickled_features

        object_centers_xyz = filtered_features[default_features_key][
            "RegionCenter"].astype(int)
        object_min_coords_xyz = filtered_features[default_features_key][
            "Coord<Minimum>"].astype(int)
        object_max_coords_xyz = filtered_features[default_features_key][
            "Coord<Maximum>"].astype(int)
        object_sizes = filtered_features[default_features_key][
            "Count"][:, 0].astype(int)

        # Also, write out selected features as a 'point cloud' csv file.
        # (Store the csv file next to this block's h5 file.)
        dataset_directory = blockwise_fileset.getDatasetDirectory(roi[0])
        pointcloud_path = os.path.join(dataset_directory,
                                       "block-pointcloud.csv")

        logger.info("Writing to csv: {}".format(pointcloud_path))
        with open(pointcloud_path, "w") as fout:
            csv_writer = csv.DictWriter(fout, OUTPUT_COLUMNS, **CSV_FORMAT)
            csv_writer.writeheader()

            # One csv row per remaining object
            for obj_id in range(len(object_sizes)):
                fields = {}
                (
                    fields["x_px"],
                    fields["y_px"],
                    fields["z_px"],
                ) = object_centers_xyz[obj_id]
                (
                    fields["min_x_px"],
                    fields["min_y_px"],
                    fields["min_z_px"],
                ) = object_min_coords_xyz[obj_id]
                (
                    fields["max_x_px"],
                    fields["max_y_px"],
                    fields["max_z_px"],
                ) = object_max_coords_xyz[obj_id]
                fields["size_px"] = object_sizes[obj_id]

                csv_writer.writerow(fields)
                # fout.flush()

        logger.info("FINISHED csv export")
class EdgeTrainingWithMulticutWorkflow(Workflow):
    """Workflow that trains a classifier for merging superpixels and joins them via multicut."""
    workflowName = "Edge Training With Multicut"
    workflowDisplayName = "(BETA) Edge Training With Multicut"

    workflowDescription = "A workflow based around training a classifier for merging superpixels and joining them via multicut."
    defaultAppletIndex = 0 # show DataSelection by default

    # Indices of the dataset roles; each one is the position of its name in ROLE_NAMES
    DATA_ROLE_RAW = 0
    DATA_ROLE_PROBABILITIES = 1
    DATA_ROLE_SUPERPIXELS = 2
    DATA_ROLE_GROUNDTRUTH = 3
    ROLE_NAMES = ['Raw Data', 'Probabilities', 'Superpixels', 'Groundtruth']
    # Names published as the data export applet's SelectionNames
    EXPORT_NAMES = ['Multicut Segmentation']

    @property
    def applets(self):
        return self._applets

    @property
    def imageNameListSlot(self):
        return self.dataSelectionApplet.topLevelOperator.ImageName

    def __init__(self, shell, headless, workflow_cmdline_args, project_creation_workflow, *args, **kwargs):
        """Create the applets of this workflow and parse its command-line arguments.

        Applets are exposed in this order: data selection, DT watershed,
        edge training with multicut, data export, batch processing.
        ``--retrain`` is parsed from ``workflow_cmdline_args``; any remaining
        arguments are offered to the data export and batch processing applets,
        and whatever is still left over is reported in a warning.
        """
        self.stored_classifier = None

        # Create a graph to be shared by all operators
        graph = Graph()

        super(EdgeTrainingWithMulticutWorkflow, self).__init__(
            shell, headless, workflow_cmdline_args, project_creation_workflow, graph=graph, *args, **kwargs
        )
        self._applets = []

        # -- DataSelection applet
        #
        self.dataSelectionApplet = DataSelectionApplet(self, "Input Data", "Input Data", forceAxisOrder=['zyxc', 'yxc'])

        # Dataset inputs
        opDataSelection = self.dataSelectionApplet.topLevelOperator
        opDataSelection.DatasetRoles.setValue(self.ROLE_NAMES)

        # -- Watershed applet
        #
        self.wsdtApplet = WsdtApplet(self, "DT Watershed", "DT Watershed")

        # -- Edge training AND Multicut applet
        #
        self.edgeTrainingWithMulticutApplet = EdgeTrainingWithMulticutApplet(self, "Training and Multicut", "Training and Multicut")
        opEdgeTrainingWithMulticut = self.edgeTrainingWithMulticutApplet.topLevelOperator
        # Default feature selection, keyed by role name
        DEFAULT_FEATURES = {self.ROLE_NAMES[self.DATA_ROLE_RAW]: ['standard_edge_mean']}
        opEdgeTrainingWithMulticut.FeatureNames.setValue(DEFAULT_FEATURES)

        # -- DataExport applet
        #
        self.dataExportApplet = DataExportApplet(self, "Data Export")
        self.dataExportApplet.prepare_for_entire_export = self.prepare_for_entire_export
        self.dataExportApplet.post_process_entire_export = self.post_process_entire_export

        # Configure global DataExport settings
        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.WorkingDirectory.connect(opDataSelection.WorkingDirectory)
        opDataExport.SelectionNames.setValue(self.EXPORT_NAMES)

        # -- BatchProcessing applet
        #
        self.batchProcessingApplet = BatchProcessingApplet(self,
                                                           "Batch Processing",
                                                           self.dataSelectionApplet,
                                                           self.dataExportApplet)

        # -- Expose applets to shell
        self._applets.append(self.dataSelectionApplet)
        self._applets.append(self.wsdtApplet)
        self._applets.append(self.edgeTrainingWithMulticutApplet)
        self._applets.append(self.dataExportApplet)
        self._applets.append(self.batchProcessingApplet)

        # -- Parse command-line arguments
        #    (Command-line args are applied in onProjectLoaded(), below.)
        # Parse workflow-specific command-line args
        parser = argparse.ArgumentParser()
        parser.add_argument('--retrain', help="Re-train the classifier based on labels stored in the project file, and re-save.", action="store_true")
        self.parsed_workflow_args, unused_args = parser.parse_known_args(workflow_cmdline_args)
        if unused_args:
            # Parse batch export/input args.
            self._data_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args(unused_args)
            self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args(unused_args)
        else:
            unused_args = None
            self._batch_input_args = None
            self._data_export_args = None

        if unused_args:
            # Logger.warn is a deprecated alias; use Logger.warning
            logger.warning("Unused command-line args: {}".format(unused_args))

        if not self._headless:
            shell.currentAppletChanged.connect(self.handle_applet_changed)

    def prepareForNewLane(self, laneIndex):
        """
        Overridden from Workflow base class.
        Called immediately before a new lane is added to the workflow.
        """
        opEdgeTrainingWithMulticut = self.edgeTrainingWithMulticutApplet.topLevelOperator
        opClassifierCache = opEdgeTrainingWithMulticut.opEdgeTraining.opClassifierCache

        # When the new lane is added, dirty notifications will propagate throughout the entire graph.
        # This means the classifier will be marked 'dirty' even though it is still usable.
        # Before that happens, let's store the classifier, so we can restore it in handleNewLanesAdded(), below.
        if opClassifierCache.Output.ready() and \
           not opClassifierCache._dirty:
            self.stored_classifier = opClassifierCache.Output.value
        else:
            self.stored_classifier = None
        
    def handleNewLanesAdded(self):
        """
        Overridden from Workflow base class.
        Called immediately after a new lane is added to the workflow and initialized.
        """
        opEdgeTrainingWithMulticut = self.edgeTrainingWithMulticutApplet.topLevelOperator
        opClassifierCache = opEdgeTrainingWithMulticut.opEdgeTraining.opClassifierCache

        # Restore classifier we saved in prepareForNewLane() (if any)
        if self.stored_classifier:
            opClassifierCache.forceValue(self.stored_classifier)
            # Release reference
            self.stored_classifier = None

    def connectLane(self, laneIndex):
        """
        Override from base class.

        Wires up the operators for the image lane ``laneIndex``: obtains the lane view of
        each applet's top-level operator, creates the helper operators for this lane
        (dtype conversion, probability normalization, groundtruth relabeling/caching,
        raw+probability stacking, superpixel source selection) and connects everything
        to the watershed, edge-training/multicut and data-export operators.

        :param laneIndex: index passed to ``getLane()`` of every top-level operator.
        """
        # Lane-specific views of the applets' top-level operators.
        opDataSelection = self.dataSelectionApplet.topLevelOperator.getLane(laneIndex)
        opWsdt = self.wsdtApplet.topLevelOperator.getLane(laneIndex)
        opEdgeTrainingWithMulticut = self.edgeTrainingWithMulticutApplet.topLevelOperator.getLane(laneIndex)
        opDataExport = self.dataExportApplet.topLevelOperator.getLane(laneIndex)

        # RAW DATA: Convert to float32
        opConvertRaw = OpConvertDtype( parent=self )
        opConvertRaw.ConversionDtype.setValue( np.float32 )
        opConvertRaw.Input.connect( opDataSelection.ImageGroup[self.DATA_ROLE_RAW] )

        # PROBABILITIES: Convert to float32
        opConvertProbabilities = OpConvertDtype( parent=self )
        opConvertProbabilities.ConversionDtype.setValue( np.float32 )
        opConvertProbabilities.Input.connect( opDataSelection.ImageGroup[self.DATA_ROLE_PROBABILITIES] )

        # PROBABILITIES: Normalize drange to [0.0, 1.0]
        opNormalizeProbabilities = OpPixelOperator( parent=self )
        def normalize_inplace(a):
            # The drange is looked up from the operator's input metadata each time
            # the function is called (the closure captures the operator, not the value).
            drange = opNormalizeProbabilities.Input.meta.drange
            # No drange recorded, or it already is (0.0, 1.0): return the data unchanged.
            if drange is None or (drange[0] == 0.0 and drange[1] == 1.0):
                return a
            # Shift and scale in place; the same array object is returned.
            a[:] -= drange[0]
            a[:] /= ( drange[1] - drange[0] )
            return a
        opNormalizeProbabilities.Input.connect( opConvertProbabilities.Output )
        opNormalizeProbabilities.Function.setValue( normalize_inplace )

        # GROUNDTRUTH: Convert to uint32, relabel, and cache
        opConvertGroundtruth = OpConvertDtype( parent=self )
        opConvertGroundtruth.ConversionDtype.setValue( np.uint32 )
        opConvertGroundtruth.Input.connect( opDataSelection.ImageGroup[self.DATA_ROLE_GROUNDTRUTH] )

        opRelabelGroundtruth = OpRelabelConsecutive( parent=self )
        opRelabelGroundtruth.Input.connect( opConvertGroundtruth.Output )

        opGroundtruthCache = OpBlockedArrayCache( parent=self )
        opGroundtruthCache.CompressionEnabled.setValue(True)
        opGroundtruthCache.Input.connect( opRelabelGroundtruth.Output )

        # watershed inputs
        # (RawData gets the unconverted raw image; Input gets the normalized probabilities.)
        opWsdt.RawData.connect( opDataSelection.ImageGroup[self.DATA_ROLE_RAW] )
        opWsdt.Input.connect( opNormalizeProbabilities.Output )

        # Actual computation is done with both RawData and Probabilities
        # (float32 raw data and normalized probabilities, stacked along axis 'c').
        opStackRawAndVoxels = OpSimpleStacker( parent=self )
        opStackRawAndVoxels.Images.resize(2)
        opStackRawAndVoxels.Images[0].connect( opConvertRaw.Output )
        opStackRawAndVoxels.Images[1].connect( opNormalizeProbabilities.Output )
        opStackRawAndVoxels.AxisFlag.setValue('c')

        # If superpixels are available from a file, use it.
        # Otherwise the watershed applet's superpixels serve as the (slow) fallback input.
        opSuperpixelsSelect = OpPrecomputedInput( ignore_dirty_input=True, parent=self )
        opSuperpixelsSelect.PrecomputedInput.connect( opDataSelection.ImageGroup[self.DATA_ROLE_SUPERPIXELS] )
        opSuperpixelsSelect.SlowInput.connect( opWsdt.Superpixels )

        # If the superpixel file changes, then we have to remove the training labels from the image
        opEdgeTraining = opEdgeTrainingWithMulticut.opEdgeTraining
        def handle_new_superpixels( *args ):
            # Callback arguments are ignored; the operator is told that its Superpixels slot is dirty.
            opEdgeTraining.handle_dirty_superpixels( opEdgeTraining.Superpixels )
        # Fire the callback both when the superpixel file slot becomes ready and when it becomes unready.
        opDataSelection.ImageGroup[self.DATA_ROLE_SUPERPIXELS].notifyReady( handle_new_superpixels )
        opDataSelection.ImageGroup[self.DATA_ROLE_SUPERPIXELS].notifyUnready( handle_new_superpixels )

        # edge training inputs
        opEdgeTrainingWithMulticut.RawData.connect( opDataSelection.ImageGroup[self.DATA_ROLE_RAW] ) # Used for visualization only
        opEdgeTrainingWithMulticut.VoxelData.connect( opStackRawAndVoxels.Output )
        opEdgeTrainingWithMulticut.Superpixels.connect( opSuperpixelsSelect.Output )
        opEdgeTrainingWithMulticut.GroundtruthSegmentation.connect( opGroundtruthCache.Output )

        # DataExport inputs
        opDataExport.RawData.connect( opDataSelection.ImageGroup[self.DATA_ROLE_RAW] )
        opDataExport.RawDatasetInfo.connect( opDataSelection.DatasetGroup[self.DATA_ROLE_RAW] )
        opDataExport.Inputs.resize( len(self.EXPORT_NAMES) )
        opDataExport.Inputs[0].connect( opEdgeTrainingWithMulticut.Output )
        # Sanity check: every export input created by the resize above must have been connected.
        for slot in opDataExport.Inputs:
            assert slot.partner is not None
        
    def onProjectLoaded(self, projectManager):
        """
        Overridden from Workflow base class.  Called by the Project Manager.

        If the user provided command-line arguments, use them to configure
        the workflow inputs and output settings.  In headless mode with both
        batch-input and export arguments present, the batch export is run as well.

        :param projectManager: forwarded to ``_force_retrain_classifier()`` when
                               ``--retrain`` was requested.
        :raises SystemExit: when the batch input file lists are inconsistent
                            (the error text is passed to ``sys.exit``).
        """
        # Configure the data export operator.
        if self._data_export_args:
            self.dataExportApplet.configure_operator_with_parsed_args(self._data_export_args)

        # Retrain the classifier?
        if self.parsed_workflow_args.retrain:
            self._force_retrain_classifier(projectManager)

        # Batch export only happens in headless mode with both argument sets present.
        if not (self._headless and self._batch_input_args and self._data_export_args):
            return

        # Make sure the watershed can be computed if necessary.
        opWsdt = self.wsdtApplet.topLevelOperator
        opWsdt.FreezeCache.setValue(False)

        batch_args = self._batch_input_args
        raw_data = batch_args.raw_data

        # Error checks
        # A role that was not given on the command line may be None instead of a list.
        # Treat it as empty so the user sees the usage message below rather than
        # a TypeError raised by len(None).
        if raw_data and len(batch_args.probabilities or []) != len(raw_data):
            msg = "Error: Your input file lists are malformed.\n"
            msg += "Usage: run_ilastik.sh --headless --raw_data <file1> <file2>... --probabilities <file1> <file2>..."
            sys.exit(msg)

        superpixels = batch_args.superpixels
        if superpixels and (not raw_data or len(superpixels) != len(raw_data)):
            msg = "Error: Wrong number of superpixel file inputs."
            sys.exit(msg)

        logger.info("Beginning Batch Processing")
        self.batchProcessingApplet.run_export_from_parsed_args(batch_args)
        logger.info("Completed Batch Processing")

    def _force_retrain_classifier(self, projectManager):
        """
        Invalidate the edge classifier, request it again (which retrains it),
        and save the project.

        :param projectManager: its ``saveProject()`` is called once training succeeded.
        :raises RuntimeError: if the classifier cache yields ``None``.
        """
        logger.info("Retraining edge classifier...")
        top_level_op = self.edgeTrainingWithMulticutApplet.topLevelOperator

        # Marking the feature names dirty makes the classifier dirty, so it must be retrained.
        # (useful if the stored labels or features were changed outside ilastik)
        top_level_op.FeatureNames.setDirty()

        # Reading the cache output is what forces the training to run.
        classifier_output = top_level_op.opEdgeTraining.opClassifierCache.Output
        if classifier_output.value is None:
            raise RuntimeError("Classifier could not be trained! Check your labels and features.")

        # Persist the new classifier in the project file.
        projectManager.saveProject(force_all_save=False)

    def prepare_for_entire_export(self):
        """
        Hook assigned to DataExportApplet.prepare_for_entire_export.

        Remembers the current FreezeCache value of the edge-training/multicut operator
        in ``self.freeze_status`` and then unfreezes the cache, because the segmentation
        cache should not be "frozen" while results are exported.
        """
        freeze_slot = self.edgeTrainingWithMulticutApplet.topLevelOperator.FreezeCache
        # Keep the previous state so post_process_entire_export() can restore it.
        self.freeze_status = freeze_slot.value
        freeze_slot.setValue(False)

    def post_process_entire_export(self):
        """
        Hook assigned to DataExportApplet.post_process_entire_export.

        Writes the value remembered in ``self.freeze_status`` back into the FreezeCache
        slot of the edge-training/multicut operator, i.e. re-freezes the segmentation
        cache if it was frozen before the export.
        """
        freeze_slot = self.edgeTrainingWithMulticutApplet.topLevelOperator.FreezeCache
        remembered_status = self.freeze_status
        freeze_slot.setValue(remembered_status)


    def handleAppletStateUpdateRequested(self):
        """
        Overridden from Workflow base class
        Called when an applet has fired the :py:attr:`Applet.appletStateUpdateRequested`

        Recomputes which applets the shell should enable and whether project
        changes are currently allowed.
        """
        data_selection_op = self.dataSelectionApplet.topLevelOperator
        wsdt_op = self.wsdtApplet.topLevelOperator
        multicut_op = self.edgeTrainingWithMulticutApplet.topLevelOperator
        opDataExport = self.dataExportApplet.topLevelOperator

        # Without input data (or while it is still loading) nothing downstream is ready.
        input_ready = len(data_selection_op.ImageGroup) > 0 and not self.dataSelectionApplet.busy

        # Did the currently shown lane get its superpixels from a file?
        current_lane = self._shell.currentImageIndex
        if current_lane != -1:
            superpixels_slot = data_selection_op.ImageGroup[current_lane][self.DATA_ROLE_SUPERPIXELS]
            superpixels_available_from_file = superpixels_slot.ready()
        else:
            superpixels_available_from_file = False

        superpixels_ready = wsdt_op.Superpixels.ready()

        # The user isn't allowed to touch anything while batch processing is running.
        batch_processing_busy = self.batchProcessingApplet.busy
        interactive = not batch_processing_busy
        usable = interactive and input_ready

        enable = self._shell.setAppletEnabled
        enable(self.dataSelectionApplet, interactive)
        enable(self.wsdtApplet, usable and not superpixels_available_from_file)
        enable(self.edgeTrainingWithMulticutApplet, usable and superpixels_ready)
        enable(self.dataExportApplet, usable and multicut_op.Output.ready())
        enable(self.batchProcessingApplet, usable)

        # Finally, while any applet is busy the shell must not close the project.
        busy = False
        for applet in (
            self.dataSelectionApplet,
            self.wsdtApplet,
            self.edgeTrainingWithMulticutApplet,
            self.dataExportApplet,
            self.batchProcessingApplet,
        ):
            busy |= applet.busy
        self._shell.enableProjectChanges(not busy)

    def handle_applet_changed(self, prev_index, current_index):
        """
        React to the shell switching applets.

        When the index actually changed, the FreezeCache slot of the watershed operator
        and of the multicut operator is set to True while the shell's current applet
        index is at or before the respective applet, and to False once the user views
        an applet downstream of it (so superpixels / segmentation stay up-to-date there).

        :param prev_index: index of the applet shown before the switch.
        :param current_index: index of the applet shown after the switch.
        """
        if prev_index == current_index:
            return

        for applet in (self.wsdtApplet, self.edgeTrainingWithMulticutApplet):
            top_level_op = applet.topLevelOperator
            keep_frozen = self._shell.currentAppletIndex <= self.applets.index(applet)
            top_level_op.FreezeCache.setValue(keep_frozen)
            
예제 #29
0
    def __init__(self, shell, headless, workflow_cmdline_args,
                 project_creation_workflow, *args, **kwargs):
        """
        Create the applets of the edge-training/multicut workflow, expose them to the
        shell, and parse the workflow-specific command-line arguments.

        :param shell: shell hosting the workflow.  Unless running headless, its
                      ``currentAppletChanged`` signal is connected to
                      ``handle_applet_changed``.
        :param headless: forwarded to the base class constructor.
        :param workflow_cmdline_args: command-line arguments of this session.  ``--retrain``
                                      is consumed here; whatever remains is offered to the
                                      data export and batch processing applets.
        :param project_creation_workflow: forwarded to the base class constructor.
        """
        self.stored_classifier = None

        # Create a graph to be shared by all operators
        graph = Graph()

        super(EdgeTrainingWithMulticutWorkflow,
              self).__init__(shell,
                             headless,
                             workflow_cmdline_args,
                             project_creation_workflow,
                             graph=graph,
                             *args,
                             **kwargs)
        self._applets = []

        # -- DataSelection applet
        #
        self.dataSelectionApplet = DataSelectionApplet(self, "Input Data",
                                                       "Input Data")

        # Dataset inputs
        opDataSelection = self.dataSelectionApplet.topLevelOperator
        opDataSelection.DatasetRoles.setValue(self.ROLE_NAMES)

        # -- Watershed applet
        #
        self.wsdtApplet = WsdtApplet(self, "DT Watershed", "DT Watershed")

        # -- Edge training AND Multicut applet
        #
        self.edgeTrainingWithMulticutApplet = EdgeTrainingWithMulticutApplet(
            self, "Training and Multicut", "Training and Multicut")
        opEdgeTrainingWithMulticut = self.edgeTrainingWithMulticutApplet.topLevelOperator
        # Initial feature selection: one edge feature, keyed by the raw-data role name.
        DEFAULT_FEATURES = {
            self.ROLE_NAMES[self.DATA_ROLE_RAW]: ['standard_edge_mean']
        }
        opEdgeTrainingWithMulticut.FeatureNames.setValue(DEFAULT_FEATURES)

        # -- DataExport applet
        #
        self.dataExportApplet = DataExportApplet(self, "Data Export")
        # Let the export applet un-freeze / re-freeze the segmentation cache around an export.
        self.dataExportApplet.prepare_for_entire_export = self.prepare_for_entire_export
        self.dataExportApplet.post_process_entire_export = self.post_process_entire_export

        # Configure global DataExport settings
        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.WorkingDirectory.connect(opDataSelection.WorkingDirectory)
        opDataExport.SelectionNames.setValue(self.EXPORT_NAMES)

        # -- BatchProcessing applet
        #
        self.batchProcessingApplet = BatchProcessingApplet(
            self, "Batch Processing", self.dataSelectionApplet,
            self.dataExportApplet)

        # -- Expose applets to shell
        self._applets.append(self.dataSelectionApplet)
        self._applets.append(self.wsdtApplet)
        self._applets.append(self.edgeTrainingWithMulticutApplet)
        self._applets.append(self.dataExportApplet)
        self._applets.append(self.batchProcessingApplet)

        # -- Parse command-line arguments
        #    (Command-line args are applied in onProjectLoaded(), below.)
        # Parse workflow-specific command-line args
        parser = argparse.ArgumentParser()
        parser.add_argument(
            '--retrain',
            help=
            "Re-train the classifier based on labels stored in the project file, and re-save.",
            action="store_true")
        self.parsed_workflow_args, unused_args = parser.parse_known_args(
            workflow_cmdline_args)
        if unused_args:
            # Parse batch export/input args.
            self._data_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args(
                unused_args)
            self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args(
                unused_args)
        else:
            unused_args = None
            self._batch_input_args = None
            self._data_export_args = None

        if unused_args:
            # Logger.warn() is a deprecated alias; warning() is the supported spelling.
            logger.warning("Unused command-line args: %s", unused_args)

        # NOTE(review): self._headless is not assigned in this method; presumably the
        # base class constructor sets it from `headless` -- confirm.
        if not self._headless:
            shell.currentAppletChanged.connect(self.handle_applet_changed)
    def __init__(self, shell, headless, workflow_cmdline_args, project_creation_workflow, *args, **kwargs):
        """
        Create the applets of the edge-training/multicut workflow, expose them to the
        shell, and parse the workflow-specific command-line arguments.

        :param shell: shell hosting the workflow.  Unless running headless, its
                      ``currentAppletChanged`` signal is connected to
                      ``handle_applet_changed``.
        :param headless: forwarded to the base class constructor.
        :param workflow_cmdline_args: command-line arguments of this session.  ``--retrain``
                                      is consumed here; whatever remains is offered to the
                                      data export and batch processing applets.
        :param project_creation_workflow: forwarded to the base class constructor.
        """
        self.stored_classifier = None

        # Create a graph to be shared by all operators
        graph = Graph()

        super(EdgeTrainingWithMulticutWorkflow, self).__init__( shell, headless, workflow_cmdline_args, project_creation_workflow, graph=graph, *args, **kwargs)
        self._applets = []

        # -- DataSelection applet
        #
        self.dataSelectionApplet = DataSelectionApplet(self, "Input Data", "Input Data", forceAxisOrder=['zyxc', 'yxc'])

        # Dataset inputs
        opDataSelection = self.dataSelectionApplet.topLevelOperator
        opDataSelection.DatasetRoles.setValue( self.ROLE_NAMES )

        # -- Watershed applet
        #
        self.wsdtApplet = WsdtApplet(self, "DT Watershed", "DT Watershed")

        # -- Edge training AND Multicut applet
        #
        self.edgeTrainingWithMulticutApplet = EdgeTrainingWithMulticutApplet(self, "Training and Multicut", "Training and Multicut")
        opEdgeTrainingWithMulticut = self.edgeTrainingWithMulticutApplet.topLevelOperator
        # Initial feature selection: one edge feature, keyed by the raw-data role name.
        DEFAULT_FEATURES = { self.ROLE_NAMES[self.DATA_ROLE_RAW]: ['standard_edge_mean'] }
        opEdgeTrainingWithMulticut.FeatureNames.setValue( DEFAULT_FEATURES )

        # -- DataExport applet
        #
        self.dataExportApplet = DataExportApplet(self, "Data Export")
        # Let the export applet un-freeze / re-freeze the segmentation cache around an export.
        self.dataExportApplet.prepare_for_entire_export = self.prepare_for_entire_export
        self.dataExportApplet.post_process_entire_export = self.post_process_entire_export

        # Configure global DataExport settings
        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.WorkingDirectory.connect( opDataSelection.WorkingDirectory )
        opDataExport.SelectionNames.setValue( self.EXPORT_NAMES )

        # -- BatchProcessing applet
        #
        self.batchProcessingApplet = BatchProcessingApplet(self,
                                                           "Batch Processing",
                                                           self.dataSelectionApplet,
                                                           self.dataExportApplet)

        # -- Expose applets to shell
        self._applets.append(self.dataSelectionApplet)
        self._applets.append(self.wsdtApplet)
        self._applets.append(self.edgeTrainingWithMulticutApplet)
        self._applets.append(self.dataExportApplet)
        self._applets.append(self.batchProcessingApplet)

        # -- Parse command-line arguments
        #    (Command-line args are applied in onProjectLoaded(), below.)
        # Parse workflow-specific command-line args
        parser = argparse.ArgumentParser()
        parser.add_argument('--retrain', help="Re-train the classifier based on labels stored in the project file, and re-save.", action="store_true")
        self.parsed_workflow_args, unused_args = parser.parse_known_args(workflow_cmdline_args)
        if unused_args:
            # Parse batch export/input args.
            self._data_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args( unused_args )
            self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args( unused_args )
        else:
            unused_args = None
            self._batch_input_args = None
            self._data_export_args = None

        if unused_args:
            # Logger.warn() is a deprecated alias; warning() is the supported spelling.
            logger.warning("Unused command-line args: %s", unused_args)

        # NOTE(review): self._headless is not assigned in this method; presumably the
        # base class constructor sets it from `headless` -- confirm.
        if not self._headless:
            shell.currentAppletChanged.connect( self.handle_applet_changed )
예제 #31
0
class DataConversionWorkflow(Workflow):
    """
    Simple workflow for converting data between formats.
    Has only two 'interactive' applets (Data Selection and Data Export), plus the BatchProcessing applet.

    Supports headless mode. For example:

    .. code-block::

        python ilastik.py --headless
                          --new_project=NewTemporaryProject.ilp
                          --workflow=DataConversionWorkflow
                          --output_format="png sequence"
                          ~/input1.h5
                          ~/input2.h5

    .. note:: Beware of issues related to absolute vs. relative paths.
              Relative links are stored relative to the project file.

              To avoid this issue entirely, either
                 (1) use only absolute filepaths
              or (2) cd into your project file's directory before launching ilastik.

    """
    def __init__(self, shell, headless, workflow_cmdline_args, project_creation_args, *args, **kwargs):
        """
        Create the three applets of the workflow and parse the command-line arguments.

        :param shell: forwarded to the base class constructor.
        :param headless: forwarded to the base class constructor.
        :param workflow_cmdline_args: command-line arguments of this session; they are parsed
                                      here and applied later in ``onProjectLoaded()``.
        :param project_creation_args: forwarded to the base class constructor.
        """
        # Create a graph to be shared by all operators
        graph = Graph()
        super(DataConversionWorkflow, self).__init__(shell, headless, workflow_cmdline_args, project_creation_args, graph=graph, *args, **kwargs)
        self._applets = []

        # Instantiate DataSelection applet
        self.dataSelectionApplet = DataSelectionApplet(self,
                                                       "Input Data",
                                                       "Input Data",
                                                       supportIlastik05Import=True)

        # Configure global DataSelection settings
        role_names = ["Input Data"]
        opDataSelection = self.dataSelectionApplet.topLevelOperator
        opDataSelection.DatasetRoles.setValue( role_names )

        # Instantiate DataExport applet
        self.dataExportApplet = DataExportApplet(self, "Data Export")

        # Configure global DataExport settings
        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.WorkingDirectory.connect( opDataSelection.WorkingDirectory )
        opDataExport.SelectionNames.setValue( ["Input"] )

        # No special data pre/post processing necessary in this workflow,
        #   but this is where we'd hook it up if we needed it.
        #
        #self.dataExportApplet.prepare_for_entire_export = self.prepare_for_entire_export
        #self.dataExportApplet.prepare_lane_for_export = self.prepare_lane_for_export
        #self.dataExportApplet.post_process_lane_export = self.post_process_lane_export
        #self.dataExportApplet.post_process_entire_export = self.post_process_entire_export

        # Instantiate BatchProcessing applet
        self.batchProcessingApplet = BatchProcessingApplet(self,
                                                           "Batch Processing",
                                                           self.dataSelectionApplet,
                                                           self.dataExportApplet)

        # Expose our applets in a list (for the shell to use)
        self._applets.append( self.dataSelectionApplet )
        self._applets.append( self.dataExportApplet )
        self._applets.append(self.batchProcessingApplet)

        # Parse command-line arguments
        # Command-line args are applied in onProjectLoaded(), below.
        if workflow_cmdline_args:
            # Export settings are parsed first; what is left over is handed to the input parser.
            self._data_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args( workflow_cmdline_args )
            self._batch_input_args, unused_args = self.dataSelectionApplet.parse_known_cmdline_args( unused_args, role_names )
        else:
            unused_args = None
            self._batch_input_args = None
            self._data_export_args = None

        if unused_args:
            # Logger.warn() is a deprecated alias; warning() is the supported spelling.
            logger.warning("Unused command-line args: %s", unused_args)

    @property
    def applets(self):
        """
        Overridden from Workflow base class.

        Returns the list of applets created in ``__init__``.
        """
        return self._applets

    @property
    def imageNameListSlot(self):
        """
        Overridden from Workflow base class.

        Returns the ``ImageName`` slot of the data selection operator.
        """
        return self.dataSelectionApplet.topLevelOperator.ImageName

    def prepareForNewLane(self, laneIndex):
        """
        Overridden from Workflow base class.
        Called immediately before connectLane()
        """
        # No preparation necessary.
        pass

    def connectLane(self, laneIndex):
        """
        Overridden from Workflow base class.

        Connects the data selection output of lane ``laneIndex`` directly to the
        data export operator of the same lane.
        """
        # Get a *view* of each top-level operator, specific to the current lane.
        opDataSelectionView = self.dataSelectionApplet.topLevelOperator.getLane(laneIndex)
        opDataExportView = self.dataExportApplet.topLevelOperator.getLane(laneIndex)

        # Now connect the operators together for this lane.
        # Most workflows would have more to do here, but this workflow is super simple:
        # We just connect input to export
        opDataExportView.RawDatasetInfo.connect( opDataSelectionView.DatasetGroup[RAW_DATA_ROLE_INDEX] )
        opDataExportView.Inputs.resize( 1 )
        opDataExportView.Inputs[RAW_DATA_ROLE_INDEX].connect( opDataSelectionView.ImageGroup[RAW_DATA_ROLE_INDEX] )

        # There is no special "raw" display layer in this workflow.
        #opDataExportView.RawData.connect( opDataSelectionView.ImageGroup[0] )

    def handleNewLanesAdded(self):
        """
        Overridden from Workflow base class.
        Called immediately AFTER connectLane() and the dataset is loaded into the workflow.
        """
        # No special handling required.
        pass

    def onProjectLoaded(self, projectManager):
        """
        Overridden from Workflow base class.  Called by the Project Manager.

        If the user provided command-line arguments, use them to configure
        the workflow inputs and output settings.

        :param projectManager: not used by this implementation.
        """
        # Configure the data export operator.
        if self._data_export_args:
            self.dataExportApplet.configure_operator_with_parsed_args( self._data_export_args )

        # Batch export runs only in headless mode with both argument sets present.
        if self._headless and self._batch_input_args and self._data_export_args:
            logger.info("Beginning Batch Processing")
            self.batchProcessingApplet.run_export_from_parsed_args(self._batch_input_args)
            logger.info("Completed Batch Processing")

    def handleAppletStateUpdateRequested(self):
        """
        Overridden from Workflow base class
        Called when an applet has fired the :py:attr:`Applet.appletStateUpdateRequested`

        Recomputes which applets the shell should enable and whether project
        changes are currently allowed.
        """
        opDataSelection = self.dataSelectionApplet.topLevelOperator
        input_ready = len(opDataSelection.ImageGroup) > 0

        opDataExport = self.dataExportApplet.topLevelOperator
        # Export is possible once the first lane has a ready input whose shape
        # has no zero-length dimension.
        export_data_ready = (
            input_ready
            and len(opDataExport.Inputs[0]) > 0
            and opDataExport.Inputs[0][0].ready()
            and (TinyVector(opDataExport.Inputs[0][0].meta.shape) > 0).all()
        )

        self._shell.setAppletEnabled(self.dataSelectionApplet, not self.batchProcessingApplet.busy)
        self._shell.setAppletEnabled(self.dataExportApplet, export_data_ready and not self.batchProcessingApplet.busy)
        self._shell.setAppletEnabled(self.batchProcessingApplet, export_data_ready)

        # Lastly, check for certain "busy" conditions, during which we
        #  should prevent the shell from closing the project.
        busy = False
        busy |= self.dataSelectionApplet.busy
        busy |= self.dataExportApplet.busy
        busy |= self.batchProcessingApplet.busy
        self._shell.enableProjectChanges( not busy )
예제 #32
0
class _NNWorkflowBase(Workflow):
    """
    This class provides workflow for a remote tiktorch server
    It has special server configuration applets allowing user to
    connect to remotely running tiktorch server managed by user
    """

    auto_register = False
    workflowName = "Neural Network Classification (BASE)"
    workflowDescription = "Base class for NN Classification workflows"

    DATA_ROLE_RAW = 0
    ROLE_NAMES = ["Raw Data"]
    EXPORT_NAMES = ["Probabilities", "Labels"]

    @property
    def applets(self):
        """
        Return the list of applets stored in ``self._applets``.
        """
        return self._applets

    @property
    def imageNameListSlot(self):
        """
        The "image name list" slot: it lists the names of all image lanes
        (i.e. files) currently loaded by the workflow.
        """
        data_selection_op = self.dataSelectionApplet.topLevelOperator
        return data_selection_op.ImageName

    def __init__(self, shell, headless, workflow_cmdline_args,
                 project_creation_args, *args, **kwargs):
        """
        Set up the applets shared by the NN classification workflows and parse the
        command-line arguments.

        :param shell: forwarded to the base class constructor.
        :param headless: forwarded to the base class constructor.
        :param workflow_cmdline_args: command-line arguments of the current session
                                      (``--batch-size``, ``--model-path``, export and batch options).
        :param project_creation_args: arguments that were saved to the project file when the
                                      project was first created.
        """
        operator_graph = Graph()
        super().__init__(shell, headless, workflow_cmdline_args, project_creation_args,
                         *args, graph=operator_graph, **kwargs)

        self._applets = []
        self._workflow_cmdline_args = workflow_cmdline_args

        # Options specific to this workflow.
        arg_parser = argparse.ArgumentParser()
        arg_parser.add_argument("--batch-size", help="choose the preferred batch size", type=int)
        arg_parser.add_argument("--model-path", help="the neural network model for prediction")

        # The creation args (saved to the project file when the project was first created)
        # are parsed as well, but the result is not used afterwards.
        arg_parser.parse_known_args(project_creation_args)

        # Args of the current session; whatever the parser does not know is kept for the applets.
        self.parsed_args, leftover_args = arg_parser.parse_known_args(workflow_cmdline_args)

        # Both helpers are expected to register their applets in self._applets.
        self._createInputAndConfigApplets()
        self._createClassifierApplet()

        self.dataExportApplet = NNClassificationDataExportApplet(self, "Data Export")

        # Global DataExport settings
        data_selection_op = self.dataSelectionApplet.topLevelOperator
        classify_op = self.nnClassificationApplet.topLevelOperator
        export_op = self.dataExportApplet.topLevelOperator
        export_op.WorkingDirectory.connect(data_selection_op.WorkingDirectory)
        export_op.SelectionNames.setValue(self.EXPORT_NAMES)
        export_op.PmapColors.connect(classify_op.PmapColors)
        export_op.LabelNames.connect(classify_op.LabelNames)

        self.batchProcessingApplet = BatchProcessingApplet(
            self, "Batch Processing", self.dataSelectionApplet, self.dataExportApplet)

        # Make the remaining applets visible to the shell.
        self._applets.extend([self.dataExportApplet, self.batchProcessingApplet])

        if not leftover_args:
            self._batch_input_args = None
            self._batch_export_args = None
        else:
            # Export settings are parsed first; everything left after that is treated
            # as input files by the batch processing applet.
            self._batch_export_args, leftover_args = self.dataExportApplet.parse_known_cmdline_args(
                leftover_args)
            self._batch_input_args, leftover_args = self.batchProcessingApplet.parse_known_cmdline_args(
                leftover_args)

        if leftover_args:
            logger.warning("Unused command-line args: {}".format(leftover_args))

    def _createClassifierApplet(self):
        """
        Hook for subclasses: create the classifier applet.

        ``__init__`` calls this hook and afterwards reads ``self.nnClassificationApplet``,
        so an override is presumably expected to set that attribute (and to add the
        applet to ``self._applets``).

        :raises NotImplementedError: always, in this base implementation.
        """
        # Override in child class.
        # NotImplementedError is the exception; raising the NotImplemented constant
        # would fail with a TypeError instead.
        raise NotImplementedError

    def _createInputAndConfigApplets(self):
        """
        Create the data selection applet, configure its dataset roles from
        ``self.ROLE_NAMES`` and register it in ``self._applets``.
        """
        selection_applet = DataSelectionApplet(
            self,
            "Input Data",
            "Input Data",
            instructionText="Select your input data using the 'Raw Data' tab shown on the right",
        )
        self.dataSelectionApplet = selection_applet

        selection_applet.topLevelOperator.DatasetRoles.setValue(self.ROLE_NAMES)
        self._applets.append(selection_applet)

    def connectLane(self, laneIndex):
        """
        Connect the operators of one image lane.  Every lane has its own
        ``laneIndex``, starting at 0.
        """
        lane_applets = (
            self.dataSelectionApplet,
            self.nnClassificationApplet,
            self.dataExportApplet,
        )
        data_op, classify_op, export_op = [
            applet.topLevelOperator.getLane(laneIndex) for applet in lane_applets
        ]

        # The classification operator gets the input image (for display).
        classify_op.InputImages.connect(data_op.Image)

        # Raw data and its dataset info go to the export operator.
        export_op.RawData.connect(data_op.ImageGroup[self.DATA_ROLE_RAW])
        export_op.RawDatasetInfo.connect(data_op.DatasetGroup[self.DATA_ROLE_RAW])

        # One export input per export name: probabilities first, then labels.
        export_op.Inputs.resize(len(self.EXPORT_NAMES))
        export_sources = (classify_op.PredictionProbabilities, classify_op.LabelImages)
        for position, source_slot in enumerate(export_sources):
            export_op.Inputs[position].connect(source_slot)

    def handleAppletStateUpdateRequested(self, upstream_ready=True):
        """
        Overridden from Workflow base class
        Called when an applet has fired the :py:attr:`Applet.appletStateUpdateRequested`

        Recomputes which applets the shell should enable and whether project
        changes are currently allowed.

        :param upstream_ready: extra condition that is and-ed into the enabled state
                               of every applet.
        """
        # Without input data nothing else can be ready.
        data_selection_op = self.dataSelectionApplet.topLevelOperator
        input_ready = len(data_selection_op.ImageGroup) > 0 and not self.dataSelectionApplet.busy

        classification_op = self.nnClassificationApplet.topLevelOperator
        export_op = self.dataExportApplet.topLevelOperator

        predictions_ready = input_ready and len(export_op.Inputs) > 0

        # Changing features or input data during live update causes problems,
        # so the export applet stays disabled while live update is active.
        live_update_active = not classification_op.FreezePredictions.value

        # While batch processing runs, the user must not touch anything.
        batch_processing_busy = self.batchProcessingApplet.busy
        unlocked = not batch_processing_busy and upstream_ready

        enable = self._shell.setAppletEnabled
        enable(self.dataSelectionApplet, unlocked)
        enable(self.nnClassificationApplet, input_ready and unlocked)
        enable(self.dataExportApplet, predictions_ready and not live_update_active and unlocked)

        if self.batchProcessingApplet is not None:
            enable(self.batchProcessingApplet, predictions_ready and unlocked)

        # Finally, while any applet is busy the shell must not close the project.
        busy = False
        for applet in (
            self.dataSelectionApplet,
            self.nnClassificationApplet,
            self.dataExportApplet,
            self.batchProcessingApplet,
        ):
            busy |= applet.busy
        self._shell.enableProjectChanges(not busy)

    def onProjectLoaded(self, projectManager):
        """
        Overridden from Workflow base class.  Called by the Project Manager.

        If the user provided command-line arguments, use them to configure
        the workflow for batch mode and export all results.
        (This workflow's headless mode supports only batch mode for now.)

        NOTE: the headless branch raises ``NotImplementedError`` as its very first
        statement, so everything after that ``raise`` is currently unreachable.

        :param projectManager: not used by this implementation.
        :raises NotImplementedError: when running headless with both batch input and
                                     batch export arguments present.
        """
        # Headless batch mode.
        if self._headless and self._batch_input_args and self._batch_export_args:
            raise NotImplementedError(
                "headless networkclassification not implemented yet!")
            # ---- Unreachable from here on (kept as a draft of the intended implementation) ----
            self.dataExportApplet.configure_operator_with_parsed_args(
                self._batch_export_args)

            batch_size = self.parsed_args.batch_size
            # NOTE(review): the parser built in this class's __init__ registers only
            # --batch-size and --model-path; halo_size looks like it would be missing
            # from parsed_args -- confirm before enabling this branch.
            halo_size = self.parsed_args.halo_size
            model_path = self.parsed_args.model_path

            if batch_size and model_path:

                model = TikTorchLazyflowClassifier(None, model_path, halo_size,
                                                   batch_size)

                input_shape = self.getBlockShape(model, halo_size)

                self.nnClassificationApplet.topLevelOperator.BlockShape.setValue(
                    input_shape)
                self.nnClassificationApplet.topLevelOperator.NumClasses.setValue(
                    model._tiktorch_net.get("num_output_channels"))

                self.nnClassificationApplet.topLevelOperator.Classifier.setValue(
                    model)

            logger.info("Beginning Batch Processing")
            self.batchProcessingApplet.run_export_from_parsed_args(
                self._batch_input_args)
            logger.info("Completed Batch Processing")

    def getBlockShape(self, model, halo_size):
        """
        Calculate the input block shape for the given model.

        The halo is taken from ``halo_size`` when one is given.  Otherwise it
        is derived from the model configuration: if the configured
        ``output_size`` differs from the expected input shape, the halo is half
        of the difference along axis 1 and is also stored on
        ``model.HALO_SIZE``.  If neither source yields a halo, a previously
        stored ``self.halo_size`` is reused, falling back to 0.  The halo that
        was used is stored in ``self.halo_size``.

        :param model: classifier exposing ``_tiktorch_net`` (which provides
            ``expected_input_shape``, ``_configuration`` and ``get``).
        :param halo_size: halo to use, or a falsy value to derive it.
        :return: object-dtype numpy array holding the block shape followed by a
            trailing ``None``; entries 1 and 2 are reduced by twice the halo.
        """
        net = model._tiktorch_net
        input_shape = numpy.array(net.expected_input_shape)

        if halo_size:
            halo = halo_size
        else:
            halo = getattr(self, "halo_size", 0)
            if "output_size" in net._configuration:
                # If the output size of the model is smaller than the expected
                # input shape, the halo needs to be changed.
                output_shape = numpy.array(net.get("output_size"))
                # array_equal gives a single bool; `!=` on arrays is element-wise
                # and cannot be used directly as an `if` condition.
                if not numpy.array_equal(output_shape, input_shape):
                    halo = int((input_shape[1] - output_shape[1]) / 2)
                    model.HALO_SIZE = halo
        self.halo_size = halo

        # Unless the window is 2-dimensional, the leading entry is dropped.
        if len(net.get("window_size")) != 2:
            input_shape = input_shape[1:]
        # Appending None turns the shape into an object array with an open last entry.
        input_shape = numpy.append(input_shape, None)

        input_shape[1:3] -= 2 * halo

        return input_shape
    def __init__(self, shell, headless, workflow_cmdline_args, project_creation_args, *args, **kwargs):
        """
        Build the pixel classification workflow and all of its applets.

        :param shell: forwarded to the base class constructor.
        :param headless: forwarded to the base class constructor.
        :param workflow_cmdline_args: command-line args of the current session.
            Workflow options are parsed here; the remaining args are offered to
            the data export applet and then to the batch processing applet.
        :param project_creation_args: args that were saved to the project file
            when the project was first created; only ``--filter`` is read from them.
        """
        # Create a graph to be shared by all operators
        graph = Graph()
        super(PixelClassificationWorkflow, self).__init__(
            shell, headless, workflow_cmdline_args, project_creation_args, graph=graph, *args, **kwargs
        )
        # NOTE(review): the attribute name is spelled "classifer"; it is left
        # untouched because code outside this block may refer to it.
        self.stored_classifer = None
        self._applets = []
        self._workflow_cmdline_args = workflow_cmdline_args

        # Parse workflow-specific command-line args
        parser = argparse.ArgumentParser()
        parser.add_argument('--filter', help="pixel feature filter implementation.", choices=['Original', 'Refactored', 'Interpolated'], default='Original')
        parser.add_argument('--print-labels-by-slice', help="Print the number of labels for each Z-slice of each image.", action="store_true")
        parser.add_argument('--label-search-value', help="If provided, only this value is considered when using --print-labels-by-slice", default=0, type=int)
        parser.add_argument('--generate-random-labels', help="Add random labels to the project file.", action="store_true")
        parser.add_argument('--random-label-value', help="The label value to use injecting random labels", default=1, type=int)
        parser.add_argument('--random-label-count', help="The number of random labels to inject via --generate-random-labels", default=2000, type=int)
        parser.add_argument('--retrain', help="Re-train the classifier based on labels stored in project file, and re-save.", action="store_true")
        parser.add_argument('--tree-count', help='Number of trees for Vigra RF classifier.', type=int)
        parser.add_argument('--variable-importance-path', help='Location of variable-importance table.', type=str)
        parser.add_argument('--label-proportion', help='Proportion of feature-pixels used to train the classifier.', type=float)

        # Parse the creation args: These were saved to the project file when this project was first created.
        parsed_creation_args, unused_args = parser.parse_known_args(project_creation_args)
        self.filter_implementation = parsed_creation_args.filter

        # Parse the cmdline args for the current session.
        parsed_args, unused_args = parser.parse_known_args(workflow_cmdline_args)
        self.print_labels_by_slice = parsed_args.print_labels_by_slice
        self.label_search_value = parsed_args.label_search_value
        self.generate_random_labels = parsed_args.generate_random_labels
        self.random_label_value = parsed_args.random_label_value
        self.random_label_count = parsed_args.random_label_count
        self.retrain = parsed_args.retrain
        self.tree_count = parsed_args.tree_count
        self.variable_importance_path = parsed_args.variable_importance_path
        self.label_proportion = parsed_args.label_proportion

        # The filter stored at creation time always wins; a differing value on
        # the current command line is only reported.
        if parsed_args.filter and parsed_args.filter != parsed_creation_args.filter:
            logger.error("Ignoring new --filter setting.  Filter implementation cannot be changed after initial project creation.")

        # Applets for training (interactive) workflow
        self.projectMetadataApplet = ProjectMetadataApplet()

        self.dataSelectionApplet = self.createDataSelectionApplet()
        opDataSelection = self.dataSelectionApplet.topLevelOperator

        # see role constants (ROLE_NAMES) on the class
        opDataSelection.DatasetRoles.setValue(PixelClassificationWorkflow.ROLE_NAMES)

        self.featureSelectionApplet = self.createFeatureSelectionApplet()

        self.pcApplet = self.createPixelClassificationApplet()
        opClassify = self.pcApplet.topLevelOperator

        self.dataExportApplet = PixelClassificationDataExportApplet(self, "Prediction Export")
        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.PmapColors.connect(opClassify.PmapColors)
        opDataExport.LabelNames.connect(opClassify.LabelNames)
        opDataExport.WorkingDirectory.connect(opDataSelection.WorkingDirectory)
        opDataExport.SelectionNames.setValue(self.EXPORT_NAMES)

        # Expose for shell
        self._applets.append(self.projectMetadataApplet)
        self._applets.append(self.dataSelectionApplet)
        self._applets.append(self.featureSelectionApplet)
        self._applets.append(self.pcApplet)
        self._applets.append(self.dataExportApplet)

        # Export customization hooks
        self.dataExportApplet.prepare_for_entire_export = self.prepare_for_entire_export
        self.dataExportApplet.post_process_entire_export = self.post_process_entire_export

        self.batchProcessingApplet = BatchProcessingApplet(self,
                                                           "Batch Processing",
                                                           self.dataSelectionApplet,
                                                           self.dataExportApplet)

        self._applets.append(self.batchProcessingApplet)
        if unused_args:
            # We parse the export setting args first.  All remaining args are considered input files by the input applet.
            self._batch_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args(unused_args)
            self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args(unused_args)
        else:
            self._batch_input_args = None
            self._batch_export_args = None

        if unused_args:
            # Logger.warn is a deprecated alias; use Logger.warning.
            logger.warning("Unused command-line args: {}".format(unused_args))
예제 #34
0
    def __init__(self, shell, headless, workflow_cmdline_args,
                 project_creation_args, *args, **kwargs):
        """
        Set up the workflow: parse its command-line options, create the
        input/config and classifier applets, then wire up the data export and
        batch processing applets.

        Command-line args that no parser recognises are reported with a warning.
        """
        super().__init__(shell, headless, workflow_cmdline_args,
                         project_creation_args, graph=Graph(), *args, **kwargs)

        self._applets = []
        self._workflow_cmdline_args = workflow_cmdline_args

        # Workflow-specific command-line options
        cli = argparse.ArgumentParser()
        cli.add_argument("--batch-size", type=int,
                         help="choose the preferred batch size")
        cli.add_argument("--model-path",
                         help="the neural network model for prediction")

        # The creation args (saved to the project file when the project was
        # first created) are run through the parser, but their values are not
        # used here.
        cli.parse_known_args(project_creation_args)

        # Args of the current session; anything the parser does not know is
        # kept for the export / batch applets further down.
        self.parsed_args, leftover = cli.parse_known_args(workflow_cmdline_args)

        # These helpers are expected to add their applets to self._applets
        self._createInputAndConfigApplets()
        self._createClassifierApplet()

        self.dataExportApplet = NNClassificationDataExportApplet(self, "Data Export")

        # Global DataExport settings
        selection_op = self.dataSelectionApplet.topLevelOperator
        classify_op = self.nnClassificationApplet.topLevelOperator
        export_op = self.dataExportApplet.topLevelOperator
        export_op.WorkingDirectory.connect(selection_op.WorkingDirectory)
        export_op.SelectionNames.setValue(self.EXPORT_NAMES)
        export_op.PmapColors.connect(classify_op.PmapColors)
        export_op.LabelNames.connect(classify_op.LabelNames)

        self.batchProcessingApplet = BatchProcessingApplet(
            self, "Batch Processing", self.dataSelectionApplet, self.dataExportApplet)

        # Make the remaining applets visible to the shell
        for applet in (self.dataExportApplet, self.batchProcessingApplet):
            self._applets.append(applet)

        if not leftover:
            self._batch_input_args = None
            self._batch_export_args = None
        else:
            # Export settings are parsed first; what is left after that is
            # treated as input files by the batch processing applet.
            self._batch_export_args, leftover = (
                self.dataExportApplet.parse_known_cmdline_args(leftover))
            self._batch_input_args, leftover = (
                self.batchProcessingApplet.parse_known_cmdline_args(leftover))

        if leftover:
            logger.warning("Unused command-line args: {}".format(leftover))
    def __init__(self, shell, headless,
                 workflow_cmdline_args,
                 project_creation_args,
                 *args, **kwargs):
        """
        Build the object classification workflow and its applets.

        :param shell: forwarded to the base class constructor.
        :param headless: forwarded to the base class constructor.
        :param workflow_cmdline_args: command-line args of the current session.
            ``--nobatch`` is read from them; args not recognised here are
            offered to the export/batch parsers when batch mode is enabled.
        :param project_creation_args: args saved when the project was created;
            ``--fillmissing`` and ``--filter`` are read from them.
        :param kwargs: may contain ``graph``; if absent a new Graph is created.
        """
        # Use the caller's graph if one was passed, otherwise create our own.
        # Either way 'graph' is removed from kwargs so it is not passed twice.
        graph = kwargs.pop('graph') if 'graph' in kwargs else Graph()
        super(ObjectClassificationWorkflow, self).__init__(shell, headless, workflow_cmdline_args, project_creation_args, graph=graph, *args, **kwargs)
        self.stored_pixel_classifier = None
        self.stored_object_classifier = None

        # Parse workflow-specific command-line args
        parser = argparse.ArgumentParser()
        parser.add_argument('--fillmissing', help="use 'fill missing' applet with chosen detection method", choices=['classic', 'svm', 'none'], default='none')
        parser.add_argument('--filter', help="pixel feature filter implementation.", choices=['Original', 'Refactored', 'Interpolated'], default='Original')
        parser.add_argument('--nobatch', help="do not append batch applets", action='store_true', default=False)

        parsed_creation_args, unused_args = parser.parse_known_args(project_creation_args)

        self.fillMissing = parsed_creation_args.fillmissing
        self.filter_implementation = parsed_creation_args.filter

        # The settings stored at creation time win; differing values on the
        # current command line are only reported.
        parsed_args, unused_args = parser.parse_known_args(workflow_cmdline_args)
        if parsed_args.fillmissing != 'none' and parsed_creation_args.fillmissing != parsed_args.fillmissing:
            logger.error( "Ignoring --fillmissing cmdline arg.  Can't specify a different fillmissing setting after the project has already been created." )

        if parsed_args.filter != 'Original' and parsed_creation_args.filter != parsed_args.filter:
            logger.error( "Ignoring --filter cmdline arg.  Can't specify a different filter setting after the project has already been created." )

        self.batch = not parsed_args.nobatch

        self._applets = []

        self.pcApplet = None
        self.projectMetadataApplet = ProjectMetadataApplet()
        self._applets.append(self.projectMetadataApplet)

        self.setupInputs()

        if self.fillMissing != 'none':
            self.fillMissingSlicesApplet = FillMissingSlicesApplet(
                self, "Fill Missing Slices", "Fill Missing Slices", self.fillMissing)
            self._applets.append(self.fillMissingSlicesApplet)

        # The kind of input depends on which concrete subclass is being built.
        if isinstance(self, ObjectClassificationWorkflowPixel):
            self.input_types = 'raw'
        elif isinstance(self, ObjectClassificationWorkflowBinary):
            self.input_types = 'raw+binary'
        elif isinstance(self, ObjectClassificationWorkflowPrediction):
            self.input_types = 'raw+pmaps'

        # our main applets
        self.objectExtractionApplet = ObjectExtractionApplet(workflow=self, name = "Object Feature Selection")
        self.objectClassificationApplet = ObjectClassificationApplet(workflow=self)
        self.dataExportApplet = ObjectClassificationDataExportApplet(self, "Object Information Export")
        self.dataExportApplet.set_exporting_operator(self.objectClassificationApplet.topLevelOperator)

        # Customization hooks
        self.dataExportApplet.prepare_for_entire_export = self.prepare_for_entire_export
        self.dataExportApplet.post_process_lane_export = self.post_process_lane_export
        self.dataExportApplet.post_process_entire_export = self.post_process_entire_export

        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.WorkingDirectory.connect( self.dataSelectionApplet.topLevelOperator.WorkingDirectory )

        export_selection_names = ['Object Predictions',
                                  'Object Probabilities',
                                  'Blockwise Object Predictions',
                                  'Blockwise Object Probabilities']
        if self.input_types == 'raw':
            # Re-configure to add the pixel probabilities option
            export_selection_names.append( 'Pixel Probabilities' )
        opDataExport.SelectionNames.setValue( export_selection_names )

        self._batch_export_args = None
        self._batch_input_args = None
        self._export_args = None
        self.batchProcessingApplet = None
        if self.batch:
            self.batchProcessingApplet = BatchProcessingApplet(self,
                                                               "Batch Processing",
                                                               self.dataSelectionApplet,
                                                               self.dataExportApplet)

            if unused_args:
                # Additional export args (specific to the object classification workflow)
                export_arg_parser = argparse.ArgumentParser()
                export_arg_parser.add_argument( "--table_filename", help="The location to export the object feature/prediction CSV file.", required=False )
                export_arg_parser.add_argument( "--export_object_prediction_img", action="store_true" )
                export_arg_parser.add_argument( "--export_object_probability_img", action="store_true" )

                # TODO: Support --export_object_label_img, too, someday?

                if self.input_types == 'raw':
                    export_arg_parser.add_argument( "--export_pixel_probability_img", action="store_true" )
                self._export_args, unused_args = export_arg_parser.parse_known_args(unused_args)
                # The option is only registered for 'raw' input, so the
                # attribute may be missing from the namespace; read it with
                # getattr and normalise "not requested" to None.
                self._export_args.export_pixel_probability_img = (
                    getattr(self._export_args, 'export_pixel_probability_img', None) or None)

                # We parse the export setting args first.  All remaining args are considered input files by the input applet.
                self._batch_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args( unused_args )
                self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args( unused_args )

                # For backwards compatibility, translate these special args into the standard syntax.
                # If several flags are given, the last assignment below wins.
                if self._export_args.export_object_prediction_img:
                    self._batch_input_args.export_source = "Object Predictions"
                if self._export_args.export_object_probability_img:
                    self._batch_input_args.export_source = "Object Probabilities"
                if self._export_args.export_pixel_probability_img:
                    self._batch_input_args.export_source = "Pixel Probabilities"

        self.blockwiseObjectClassificationApplet = BlockwiseObjectClassificationApplet(
            self, "Blockwise Object Classification", "Blockwise Object Classification")

        self._applets.append(self.objectExtractionApplet)
        self._applets.append(self.objectClassificationApplet)
        self._applets.append(self.dataExportApplet)
        if self.batchProcessingApplet:
            self._applets.append(self.batchProcessingApplet)
        self._applets.append(self.blockwiseObjectClassificationApplet)

        if unused_args:
            # Logger.warn is a deprecated alias; use Logger.warning.
            logger.warning("Unused command-line args: {}".format( unused_args ))
예제 #36
0
class NNClassificationWorkflow(Workflow):
    """
    Workflow for the Neural Network Classification Applet
    """

    workflowName = "Neural Network Classification (Beta)"
    workflowDescription = "This is obviously self-explanatory."
    # Index into the applet list; the server configuration applet is appended
    # first in __init__, so that is the applet at index 0.
    defaultAppletIndex = 0

    DATA_ROLE_RAW = 0
    ROLE_NAMES = ["Raw Data"]
    EXPORT_NAMES = ["Probabilities", "Labels"]

    @property
    def applets(self):
        """
        Return the list of applets that are owned by this workflow
        """
        return self._applets

    @property
    def imageNameListSlot(self):
        """
        Return the "image name list" slot, which lists the names of
        all image lanes (i.e. files) currently loaded by the workflow
        """
        return self.dataSelectionApplet.topLevelOperator.ImageName

    def __init__(self, shell, headless, workflow_cmdline_args,
                 project_creation_args, *args, **kwargs):
        """
        Build the workflow: parse its command-line options and create the
        server config, data selection, classification, data export and batch
        processing applets.

        :param shell: forwarded to the base class constructor.
        :param headless: forwarded to the base class constructor.
        :param workflow_cmdline_args: command-line args of the current session.
        :param project_creation_args: args saved when the project was created;
            they are run through the parser but their values are not used here.
        """
        # Create a graph to be shared by all operators
        graph = Graph()
        super(NNClassificationWorkflow, self).__init__(shell,
                                                       headless,
                                                       workflow_cmdline_args,
                                                       project_creation_args,
                                                       graph=graph,
                                                       *args,
                                                       **kwargs)
        self._applets = []
        self._workflow_cmdline_args = workflow_cmdline_args

        # Parse workflow-specific command-line args
        parser = argparse.ArgumentParser()
        parser.add_argument("--batch-size",
                            help="choose the prefered batch size",
                            type=int)
        parser.add_argument("--halo-size",
                            help="choose the prefered halo size",
                            type=int)
        parser.add_argument("--model-path",
                            help="the neural network model for prediction")

        # Parse the creation args: These were saved to the project file when this project was first created.
        parser.parse_known_args(project_creation_args)

        # Parse the cmdline args for the current session.
        self.parsed_args, unused_args = parser.parse_known_args(
            workflow_cmdline_args)

        ######################
        # Interactive workflow
        ######################

        # Applets for training (interactive) workflow
        self.dataSelectionApplet = self.createDataSelectionApplet()
        opDataSelection = self.dataSelectionApplet.topLevelOperator

        # see role constants, above
        opDataSelection.DatasetRoles.setValue(
            NNClassificationWorkflow.ROLE_NAMES)

        connFactory = tiktorch.TiktorchConnectionFactory()

        self.serverConfigApplet = ServerConfigApplet(
            self, connectionFactory=connFactory)
        self.nnClassificationApplet = NNClassApplet(
            self,
            "NNClassApplet",
            connectionFactory=self.serverConfigApplet.connectionFactory)

        opClassify = self.nnClassificationApplet.topLevelOperator

        self.dataExportApplet = NNClassificationDataExportApplet(
            self, "Data Export")

        # Configure global DataExport settings
        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.WorkingDirectory.connect(opDataSelection.WorkingDirectory)
        opDataExport.SelectionNames.setValue(self.EXPORT_NAMES)
        opDataExport.PmapColors.connect(opClassify.PmapColors)
        opDataExport.LabelNames.connect(opClassify.LabelNames)

        self.batchProcessingApplet = BatchProcessingApplet(
            self, "Batch Processing", self.dataSelectionApplet,
            self.dataExportApplet)

        # Expose for shell
        self._applets.append(self.serverConfigApplet)
        self._applets.append(self.dataSelectionApplet)
        self._applets.append(self.nnClassificationApplet)
        self._applets.append(self.dataExportApplet)
        self._applets.append(self.batchProcessingApplet)

        if unused_args:
            # We parse the export setting args first.  All remaining args are considered input files by the input applet.
            self._batch_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args(
                unused_args)
            self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args(
                unused_args)
        else:
            self._batch_input_args = None
            self._batch_export_args = None

        if unused_args:
            # Logger.warn is a deprecated alias; use Logger.warning.
            logger.warning("Unused command-line args: {}".format(unused_args))

    def createDataSelectionApplet(self):
        """
        Can be overridden by subclasses, if they want to use
        special parameters to initialize the DataSelectionApplet.
        """
        data_instructions = "Select your input data using the 'Raw Data' tab shown on the right"
        return DataSelectionApplet(self,
                                   "Input Data",
                                   "Input Data",
                                   supportIlastik05Import=True,
                                   instructionText=data_instructions)

    def connectLane(self, laneIndex):
        """
        connects the operators for different lanes, each lane has a laneIndex starting at 0
        """
        opData = self.dataSelectionApplet.topLevelOperator.getLane(laneIndex)
        opServerConfig = self.serverConfigApplet.topLevelOperator.getLane(
            laneIndex)
        opNNclassify = self.nnClassificationApplet.topLevelOperator.getLane(
            laneIndex)
        opDataExport = self.dataExportApplet.topLevelOperator.getLane(
            laneIndex)

        # Input Image ->  Classification Op (for display)
        opNNclassify.InputImages.connect(opData.Image)
        opNNclassify.ServerConfig.connect(opServerConfig.ServerConfig)

        # Data Export connections; the export inputs follow the order of EXPORT_NAMES
        opDataExport.RawData.connect(opData.ImageGroup[self.DATA_ROLE_RAW])
        opDataExport.RawDatasetInfo.connect(
            opData.DatasetGroup[self.DATA_ROLE_RAW])
        opDataExport.Inputs.resize(len(self.EXPORT_NAMES))
        opDataExport.Inputs[0].connect(opNNclassify.PredictionProbabilities)
        opDataExport.Inputs[1].connect(opNNclassify.LabelImages)

    def handleAppletStateUpdateRequested(self):
        """
        Overridden from Workflow base class
        Called when an applet has fired the :py:attr:`Applet.appletStateUpdateRequested`
        """
        # If no data, nothing else is ready.
        opDataSelection = self.dataSelectionApplet.topLevelOperator
        input_ready = len(opDataSelection.ImageGroup
                          ) > 0 and not self.dataSelectionApplet.busy

        opNNClassification = self.nnClassificationApplet.topLevelOperator
        serverConfig_finished = self.serverConfigApplet.topLevelOperator.ServerConfig.ready(
        )

        opDataExport = self.dataExportApplet.topLevelOperator

        predictions_ready = input_ready and len(opDataExport.Inputs) > 0

        # Problems can occur if the features or input data are changed during live update mode.
        # Don't let the user do that.
        live_update_active = not opNNClassification.FreezePredictions.value

        # The user isn't allowed to touch anything while batch processing is running.
        batch_processing_busy = self.batchProcessingApplet.busy

        self._shell.setAppletEnabled(
            self.serverConfigApplet, not batch_processing_busy
            and not live_update_active)
        self._shell.setAppletEnabled(
            self.dataSelectionApplet, serverConfig_finished
            and not batch_processing_busy)

        self._shell.setAppletEnabled(
            self.nnClassificationApplet, input_ready and serverConfig_finished
            and not batch_processing_busy)
        self._shell.setAppletEnabled(
            self.dataExportApplet,
            serverConfig_finished and predictions_ready
            and not batch_processing_busy and not live_update_active,
        )

        if self.batchProcessingApplet is not None:
            self._shell.setAppletEnabled(
                self.batchProcessingApplet, serverConfig_finished
                and predictions_ready and not batch_processing_busy)

        # Lastly, check for certain "busy" conditions, during which we
        #  should prevent the shell from closing the project.
        busy = False
        busy |= self.dataSelectionApplet.busy
        busy |= self.nnClassificationApplet.busy
        busy |= self.dataExportApplet.busy
        busy |= self.batchProcessingApplet.busy
        self._shell.enableProjectChanges(not busy)

    def onProjectLoaded(self, projectManager):
        """
        Overridden from Workflow base class.  Called by the Project Manager.

        Headless batch mode is not implemented for this workflow yet: when the
        workflow runs headless and both batch input args and batch export args
        were parsed from the command line, this raises instead of exporting.
        In every other case the method does nothing.

        :param projectManager: the project manager that loaded the project
            (not used by this method).
        :raises NotImplementedError: if headless batch processing is requested.
        """
        # Headless batch mode.
        if self._headless and self._batch_input_args and self._batch_export_args:
            # TODO: once headless mode is supported, configure the export
            # operator from self._batch_export_args, build the classifier from
            # self.parsed_args (see getBlockShape) and run the batch export.
            raise NotImplementedError(
                "headless networkclassification not implemented yet!")

    def getBlockShape(self, model, halo_size):
        """
        Calculate the input block shape for the given model.

        The halo is taken from ``halo_size`` when one is given.  Otherwise it
        is derived from the model configuration: if the configured
        ``output_size`` differs from the expected input shape, the halo is half
        of the difference along axis 1 and is also stored on
        ``model.HALO_SIZE``.  If neither source yields a halo, a previously
        stored ``self.halo_size`` is reused, falling back to 0.  The halo that
        was used is stored in ``self.halo_size``.

        :param model: classifier exposing ``_tiktorch_net`` (which provides
            ``expected_input_shape``, ``_configuration`` and ``get``).
        :param halo_size: halo to use, or a falsy value to derive it.
        :return: object-dtype numpy array holding the block shape followed by a
            trailing ``None``; entries 1 and 2 are reduced by twice the halo.
        """
        net = model._tiktorch_net
        input_shape = numpy.array(net.expected_input_shape)

        if halo_size:
            halo = halo_size
        else:
            halo = getattr(self, "halo_size", 0)
            if "output_size" in net._configuration:
                # If the output size of the model is smaller than the expected
                # input shape, the halo needs to be changed.
                output_shape = numpy.array(net.get("output_size"))
                # array_equal gives a single bool; `!=` on arrays is element-wise
                # and cannot be used directly as an `if` condition.
                if not numpy.array_equal(output_shape, input_shape):
                    halo = int((input_shape[1] - output_shape[1]) / 2)
                    model.HALO_SIZE = halo
        self.halo_size = halo

        # Unless the window is 2-dimensional, the leading entry is dropped.
        if len(net.get("window_size")) != 2:
            input_shape = input_shape[1:]
        # Appending None turns the shape into an object array with an open last entry.
        input_shape = numpy.append(input_shape, None)

        input_shape[1:3] -= 2 * halo

        return input_shape

    def cleanUp(self):
        """Delegate clean-up to the classification applet."""
        self.nnClassificationApplet.cleanUp()
예제 #37
0
class CountingWorkflow(Workflow):
    """
    Workflow that chains data selection, feature selection, counting, density export
    and batch processing applets.

    The constructor also parses the workflow-specific ``--csv-export-file`` option and the
    export / batch-input command-line arguments that ``onProjectLoaded()`` applies in
    headless mode.
    """

    workflowName = "Cell Density Counting"
    workflowDescription = "This is obviously self-explanatory."
    defaultAppletIndex = 0  # show DataSelection by default

    def __init__(self, shell, headless, workflow_cmdline_args, project_creation_args, *args, **kwargs):
        """
        Create all applets, connect their top-level operators and parse command-line arguments.

        A ``graph`` keyword argument, if given, is used as the operator graph;
        otherwise a new ``Graph`` is created.
        """
        # Pop the graph out of kwargs so it is not handed to the base class twice.
        graph = kwargs.pop("graph") if "graph" in kwargs else Graph()
        super(CountingWorkflow, self).__init__(
            shell, headless, workflow_cmdline_args, project_creation_args, graph=graph, *args, **kwargs
        )
        self.stored_classifier = None

        # Handle of the CSV file the export hooks write to.  Initialised here so that
        # post_process_lane_export() / post_process_entire_export() can be called safely
        # even if prepare_for_entire_export() has not run yet.
        self.csv_export_file = None

        # Parse workflow-specific command-line args
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "--csv-export-file",
            help="Instead of exporting prediction density images, export total counts to the given csv path.",
        )
        self.parsed_counting_workflow_args, unused_args = parser.parse_known_args(workflow_cmdline_args)

        ######################
        # Interactive workflow
        ######################

        # Every ordered pair of two distinct spatial axes, followed by the channel axis.
        allowed_axis_orders = ["".join(space) + "c" for space in itertools.permutations("xyz", 2)]

        self.dataSelectionApplet = DataSelectionApplet(
            self, "Input Data", "Input Data", forceAxisOrder=allowed_axis_orders
        )
        opDataSelection = self.dataSelectionApplet.topLevelOperator
        role_names = ["Raw Data"]
        opDataSelection.DatasetRoles.setValue(role_names)

        self.featureSelectionApplet = FeatureSelectionApplet(self, "Feature Selection", "FeatureSelections")

        self.countingApplet = CountingApplet(workflow=self)
        opCounting = self.countingApplet.topLevelOperator
        opCounting.WorkingDirectory.connect(opDataSelection.WorkingDirectory)

        self.dataExportApplet = CountingDataExportApplet(self, "Density Export", opCounting)

        # Customization hooks
        self.dataExportApplet.prepare_for_entire_export = self.prepare_for_entire_export
        self.dataExportApplet.post_process_lane_export = self.post_process_lane_export
        self.dataExportApplet.post_process_entire_export = self.post_process_entire_export

        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.PmapColors.connect(opCounting.PmapColors)
        opDataExport.LabelNames.connect(opCounting.LabelNames)
        opDataExport.UpperBound.connect(opCounting.UpperBound)
        opDataExport.WorkingDirectory.connect(opDataSelection.WorkingDirectory)
        opDataExport.SelectionNames.setValue(["Probabilities"])

        self._applets = []
        self._applets.append(self.dataSelectionApplet)
        self._applets.append(self.featureSelectionApplet)
        self._applets.append(self.countingApplet)
        self._applets.append(self.dataExportApplet)

        self.batchProcessingApplet = BatchProcessingApplet(
            self, "Batch Processing", self.dataSelectionApplet, self.dataExportApplet
        )
        self._applets.append(self.batchProcessingApplet)
        if unused_args:
            # We parse the export setting args first.  All remaining args are considered input files by the input applet.
            self._batch_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args(unused_args)
            self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args(unused_args)
        else:
            self._batch_input_args = None
            self._batch_export_args = None

        if unused_args:
            logger.warning("Unused command-line args: {}".format(unused_args))

    @property
    def applets(self):
        """Return the list of applets assembled in ``__init__``."""
        return self._applets

    @property
    def imageNameListSlot(self):
        """Return the ``ImageName`` slot of the data selection applet's top-level operator."""
        return self.dataSelectionApplet.topLevelOperator.ImageName

    def prepareForNewLane(self, laneIndex):
        """
        Overridden from Workflow base class.
        Called immediately before a new lane is added to the workflow.
        """
        # When the new lane is added, dirty notifications will propagate throughout the entire graph.
        # This means the classifier will be marked 'dirty' even though it is still usable.
        # Before that happens, let's store the classifier, so we can restore it in handleNewLanesAdded(), below.
        opCounting = self.countingApplet.topLevelOperator
        if opCounting.classifier_cache.Output.ready() and not opCounting.classifier_cache._dirty:
            self.stored_classifier = opCounting.classifier_cache.Output.value
        else:
            self.stored_classifier = None

    def handleNewLanesAdded(self):
        """
        Overridden from Workflow base class.
        Called immediately after a new lane is added to the workflow and initialized.
        """
        # Restore classifier we saved in prepareForNewLane() (if any)
        if self.stored_classifier is not None:
            self.countingApplet.topLevelOperator.classifier_cache.forceValue(self.stored_classifier)
            # Release reference
            self.stored_classifier = None

    def connectLane(self, laneIndex):
        """
        Connect the per-lane views of the applet operators for lane ``laneIndex``.
        """
        ## Access applet operators
        opData = self.dataSelectionApplet.topLevelOperator.getLane(laneIndex)
        opTrainingFeatures = self.featureSelectionApplet.topLevelOperator.getLane(laneIndex)
        opCounting = self.countingApplet.topLevelOperator.getLane(laneIndex)
        opDataExport = self.dataExportApplet.topLevelOperator.getLane(laneIndex)

        #### connect input image
        opTrainingFeatures.InputImage.connect(opData.Image)

        opCounting.InputImages.connect(opData.Image)
        opCounting.FeatureImages.connect(opTrainingFeatures.OutputImage)
        opCounting.CachedFeatureImages.connect(opTrainingFeatures.CachedOutputImage)
        # opCounting.UserLabels.connect(opClassify.LabelImages)
        # opCounting.ForegroundLabels.connect(opObjExtraction.LabelImage)
        opDataExport.Inputs.resize(1)
        opDataExport.Inputs[0].connect(opCounting.HeadlessPredictionProbabilities)
        opDataExport.RawData.connect(opData.ImageGroup[0])
        opDataExport.RawDatasetInfo.connect(opData.DatasetGroup[0])

    def onProjectLoaded(self, projectManager):
        """
        Overridden from Workflow base class.  Called by the Project Manager.

        If the user provided command-line arguments, use them to configure
        the workflow for batch mode and export all results.
        (This workflow's headless mode supports only batch mode for now.)
        """
        # Headless batch mode.
        if self._headless and self._batch_input_args and self._batch_export_args:
            self.dataExportApplet.configure_operator_with_parsed_args(self._batch_export_args)

            # If the user provided a csv_path via the command line,
            # overwrite the setting in the counting export operator.
            csv_path = self.parsed_counting_workflow_args.csv_export_file
            if csv_path:
                self.dataExportApplet.topLevelOperator.CsvFilepath.setValue(csv_path)

            if self.countingApplet.topLevelOperator.classifier_cache._dirty:
                logger.warning(
                    "Your project file has no classifier.  A new classifier will be trained for this run."
                )

            logger.info("Beginning Batch Processing")
            self.batchProcessingApplet.run_export_from_parsed_args(self._batch_input_args)
            logger.info("Completed Batch Processing")

    def prepare_for_entire_export(self):
        """
        Customization hook for data export (including batch mode).

        Remembers the current FreezePredictions value, unfreezes predictions and,
        if a CSV path is configured, opens that file for writing.
        """
        self.freeze_status = self.countingApplet.topLevelOperator.FreezePredictions.value
        self.countingApplet.topLevelOperator.FreezePredictions.setValue(False)
        # Create a new CSV file to write object counts into.
        self.csv_export_file = None
        if self.dataExportApplet.topLevelOperator.CsvFilepath.ready():
            csv_path = self.dataExportApplet.topLevelOperator.CsvFilepath.value
            logger.info("Exporting object counts to CSV: " + csv_path)
            # Kept open across hooks; closed in post_process_entire_export().
            self.csv_export_file = open(csv_path, "w")

    def post_process_lane_export(self, lane_index):
        """
        Customization hook for data export (including batch mode).
        """
        # Write the object counts for this lane as a line in the CSV file.
        if self.csv_export_file:
            self.dataExportApplet.write_csv_results(self.csv_export_file, lane_index)

    def post_process_entire_export(self):
        """
        Customization hook for data export (including batch mode).

        Restores the FreezePredictions value saved by prepare_for_entire_export()
        and closes the CSV file, if one was opened.
        """
        self.countingApplet.topLevelOperator.FreezePredictions.setValue(self.freeze_status)
        if self.csv_export_file:
            self.csv_export_file.close()
            # Drop the closed handle so later hooks do not try to write to it.
            self.csv_export_file = None

    def handleAppletStateUpdateRequested(self):
        """
        Overridden from Workflow base class
        Called when an applet has fired the :py:attr:`Applet.statusUpdateSignal`
        """
        # If no data, nothing else is ready.
        opDataSelection = self.dataSelectionApplet.topLevelOperator
        input_ready = len(opDataSelection.ImageGroup) > 0 and not self.dataSelectionApplet.busy

        opFeatureSelection = self.featureSelectionApplet.topLevelOperator
        featureOutput = opFeatureSelection.OutputImage
        features_ready = (
            input_ready
            and len(featureOutput) > 0
            and featureOutput[0].ready()
            and (TinyVector(featureOutput[0].meta.shape) > 0).all()
        )

        opDataExport = self.dataExportApplet.topLevelOperator
        predictions_ready = (
            features_ready
            and len(opDataExport.Inputs) > 0
            and opDataExport.Inputs[0][0].ready()
            and (TinyVector(opDataExport.Inputs[0][0].meta.shape) > 0).all()
        )

        self._shell.setAppletEnabled(self.featureSelectionApplet, input_ready)
        self._shell.setAppletEnabled(self.countingApplet, features_ready)
        self._shell.setAppletEnabled(self.dataExportApplet, predictions_ready and not self.dataExportApplet.busy)
        self._shell.setAppletEnabled(
            self.batchProcessingApplet, predictions_ready and not self.batchProcessingApplet.busy
        )

        # Lastly, check for certain "busy" conditions, during which we
        #  should prevent the shell from closing the project.
        busy = False
        busy |= self.dataSelectionApplet.busy
        busy |= self.featureSelectionApplet.busy
        busy |= self.dataExportApplet.busy
        busy |= self.batchProcessingApplet.busy
        self._shell.enableProjectChanges(not busy)
예제 #38
0
    def __init__(self, shell, headless, workflow_cmdline_args,
                 project_creation_args, *args, **kwargs):
        """
        Build the applets of this workflow and parse its command-line arguments.

        Recognised workflow options are ``--batch-size`` and ``--halo-size`` (both int) and
        ``--model-path``.  Arguments not consumed by this parser are offered first to the
        data export applet and then to the batch processing applet; anything still left
        over is logged as a warning.
        """
        # Create a graph to be shared by all operators
        graph = Graph()
        super(NNClassificationWorkflow, self).__init__(shell,
                                                       headless,
                                                       workflow_cmdline_args,
                                                       project_creation_args,
                                                       graph=graph,
                                                       *args,
                                                       **kwargs)
        self._applets = []
        self._workflow_cmdline_args = workflow_cmdline_args

        # Parse workflow-specific command-line args
        parser = argparse.ArgumentParser()
        parser.add_argument("--batch-size",
                            help="choose the prefered batch size",
                            type=int)
        parser.add_argument("--halo-size",
                            help="choose the prefered halo size",
                            type=int)
        parser.add_argument("--model-path",
                            help="the neural network model for prediction")

        # Parse the creation args: these were saved to the project file when this project
        # was first created.  The parsed values are not used here; the call is kept so the
        # creation args are still checked by the parser.
        parser.parse_known_args(project_creation_args)

        # Parse the cmdline args for the current session.
        self.parsed_args, unused_args = parser.parse_known_args(
            workflow_cmdline_args)

        ######################
        # Interactive workflow
        ######################

        # Applets for training (interactive) workflow
        self.dataSelectionApplet = self.createDataSelectionApplet()
        opDataSelection = self.dataSelectionApplet.topLevelOperator

        # see role constants, above
        opDataSelection.DatasetRoles.setValue(
            NNClassificationWorkflow.ROLE_NAMES)

        connFactory = tiktorch.TiktorchConnectionFactory()

        self.serverConfigApplet = ServerConfigApplet(
            self, connectionFactory=connFactory)
        self.nnClassificationApplet = NNClassApplet(
            self,
            "NNClassApplet",
            connectionFactory=self.serverConfigApplet.connectionFactory)

        opClassify = self.nnClassificationApplet.topLevelOperator

        self.dataExportApplet = NNClassificationDataExportApplet(
            self, "Data Export")

        # Configure global DataExport settings
        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.WorkingDirectory.connect(opDataSelection.WorkingDirectory)
        opDataExport.SelectionNames.setValue(self.EXPORT_NAMES)
        opDataExport.PmapColors.connect(opClassify.PmapColors)
        opDataExport.LabelNames.connect(opClassify.LabelNames)

        # self.dataExportApplet.prepare_for_entire_export = self.prepare_for_entire_export
        # self.dataExportApplet.post_process_entire_export = self.post_process_entire_export

        self.batchProcessingApplet = BatchProcessingApplet(
            self, "Batch Processing", self.dataSelectionApplet,
            self.dataExportApplet)

        # Expose for shell
        self._applets.append(self.serverConfigApplet)
        self._applets.append(self.dataSelectionApplet)
        self._applets.append(self.nnClassificationApplet)
        self._applets.append(self.dataExportApplet)
        self._applets.append(self.batchProcessingApplet)

        if unused_args:
            # We parse the export setting args first.  All remaining args are considered input files by the input applet.
            self._batch_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args(
                unused_args)
            self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args(
                unused_args)
        else:
            self._batch_input_args = None
            self._batch_export_args = None

        if unused_args:
            # logger.warn() is a deprecated alias; use logger.warning().
            logger.warning("Unused command-line args: {}".format(unused_args))
예제 #39
0
class DataConversionWorkflow(Workflow):
    """
    Minimal workflow whose only purpose is converting datasets from one format to another.

    It offers two 'interactive' applets (Data Selection and Data Export) and,
    in addition, the BatchProcessing applet.

    Headless mode is supported, e.g.:

    .. code-block::

        python ilastik.py --headless
                          --new_project=NewTemporaryProject.ilp
                          --workflow=DataConversionWorkflow
                          --output_format="png sequence"
                          ~/input1.h5
                          ~/input2.h5

    .. note:: Watch out for absolute vs. relative path issues.
              Relative links are stored relative to the project file.

              To sidestep the problem, either
                 (1) pass absolute filepaths only
              or (2) cd into the project file's directory before starting ilastik.

    """

    def __init__(self, shell, headless, workflow_cmdline_args, project_creation_args, *args, **kwargs):
        """
        Instantiate the three applets, connect their global settings and parse
        the command-line arguments (which are applied later, in onProjectLoaded()).
        """
        # A single graph instance is shared by every operator of this workflow.
        shared_graph = Graph()
        super(DataConversionWorkflow, self).__init__(
            shell, headless, workflow_cmdline_args, project_creation_args, graph=shared_graph, *args, **kwargs
        )
        self._applets = []

        # --- Data selection applet and its global settings ---
        input_roles = ["Input Data"]
        self.dataSelectionApplet = DataSelectionApplet(
            self, "Input Data", "Input Data", supportIlastik05Import=True, forceAxisOrder=None
        )
        selection_op = self.dataSelectionApplet.topLevelOperator
        selection_op.DatasetRoles.setValue(input_roles)

        # --- Data export applet and its global settings ---
        self.dataExportApplet = DataExportApplet(self, "Data Export")
        export_op = self.dataExportApplet.topLevelOperator
        export_op.WorkingDirectory.connect(selection_op.WorkingDirectory)
        export_op.SelectionNames.setValue(["Input"])

        # This workflow needs no pre/post processing around the export.
        # If it ever does, the hooks would be installed here:
        #
        # self.dataExportApplet.prepare_for_entire_export = self.prepare_for_entire_export
        # self.dataExportApplet.prepare_lane_for_export = self.prepare_lane_for_export
        # self.dataExportApplet.post_process_lane_export = self.post_process_lane_export
        # self.dataExportApplet.post_process_entire_export = self.post_process_entire_export

        # --- Batch processing applet ---
        self.batchProcessingApplet = BatchProcessingApplet(
            self, "Batch Processing", self.dataSelectionApplet, self.dataExportApplet
        )

        # The shell reads the applets from this list.
        self._applets.extend([self.dataSelectionApplet, self.dataExportApplet, self.batchProcessingApplet])

        # --- Command-line arguments (applied in onProjectLoaded(), below) ---
        if not workflow_cmdline_args:
            leftover_args = None
            self._batch_input_args = None
            self._data_export_args = None
        else:
            self._data_export_args, leftover_args = self.dataExportApplet.parse_known_cmdline_args(
                workflow_cmdline_args
            )
            self._batch_input_args, leftover_args = self.dataSelectionApplet.parse_known_cmdline_args(
                leftover_args, input_roles
            )

        if leftover_args:
            logger.warning("Unused command-line args: {}".format(leftover_args))

    @property
    def applets(self):
        """
        Overridden from Workflow base class: the list of applets built in ``__init__``.
        """
        return self._applets

    @property
    def imageNameListSlot(self):
        """
        Overridden from Workflow base class: the data selection operator's ``ImageName`` slot.
        """
        return self.dataSelectionApplet.topLevelOperator.ImageName

    def prepareForNewLane(self, laneIndex):
        """
        Overridden from Workflow base class.
        Called immediately before connectLane().  Nothing needs to be prepared here.
        """

    def connectLane(self, laneIndex):
        """
        Overridden from Workflow base class.
        Wires the selected input of lane ``laneIndex`` straight into the export operator.
        """
        # Lane-specific *views* of the two top-level operators.
        selection_lane = self.dataSelectionApplet.topLevelOperator.getLane(laneIndex)
        export_lane = self.dataExportApplet.topLevelOperator.getLane(laneIndex)

        # This workflow is as simple as it gets: the input goes directly to the export.
        export_lane.RawDatasetInfo.connect(selection_lane.DatasetGroup[RAW_DATA_ROLE_INDEX])
        export_lane.Inputs.resize(1)
        export_lane.Inputs[RAW_DATA_ROLE_INDEX].connect(selection_lane.ImageGroup[RAW_DATA_ROLE_INDEX])

        # There is no special "raw" display layer in this workflow.
        # export_lane.RawData.connect( selection_lane.ImageGroup[0] )

    def handleNewLanesAdded(self):
        """
        Overridden from Workflow base class.
        Called immediately AFTER connectLane() and the dataset is loaded into the workflow.
        No special handling is required.
        """

    def onProjectLoaded(self, projectManager):
        """
        Overridden from Workflow base class.  Called by the Project Manager.

        Applies any export settings given on the command line and, in headless mode
        with both input and export arguments present, runs the batch export.
        """
        export_args = self._data_export_args

        # Configure the data export operator.
        if export_args:
            self.dataExportApplet.configure_operator_with_parsed_args(export_args)

        if not (self._headless and self._batch_input_args and export_args):
            return

        logger.info("Beginning Batch Processing")
        self.batchProcessingApplet.run_export_from_parsed_args(self._batch_input_args)
        logger.info("Completed Batch Processing")

    def handleAppletStateUpdateRequested(self):
        """
        Overridden from Workflow base class
        Called when an applet has fired the :py:attr:`Applet.statusUpdateSignal`
        """
        batch_applet = self.batchProcessingApplet
        export_op = self.dataExportApplet.topLevelOperator

        input_ready = len(self.dataSelectionApplet.topLevelOperator.ImageGroup) > 0
        export_data_ready = (
            input_ready
            and len(export_op.Inputs[0]) > 0
            and export_op.Inputs[0][0].ready()
            and (TinyVector(export_op.Inputs[0][0].meta.shape) > 0).all()
        )

        self._shell.setAppletEnabled(self.dataSelectionApplet, not batch_applet.busy)
        self._shell.setAppletEnabled(self.dataExportApplet, export_data_ready and not batch_applet.busy)
        self._shell.setAppletEnabled(batch_applet, export_data_ready)

        # While any of these applets is busy, the shell must not be
        # allowed to close (or otherwise change) the project.
        busy = False
        for applet in (self.dataSelectionApplet, self.dataExportApplet, batch_applet):
            busy |= applet.busy
        self._shell.enableProjectChanges(not busy)
class ConservationTrackingWorkflowBase( Workflow ):
    workflowName = "Automatic Tracking Workflow (Conservation Tracking) BASE"

    def __init__( self, shell, headless, workflow_cmdline_args, project_creation_args, *args, **kwargs ):
        """
        Create the applets of the conservation tracking workflow, configure their
        top-level operators and parse the export / batch command-line arguments.

        A ``graph`` keyword argument, if given, is used as the operator graph;
        otherwise a new ``Graph`` is created.

        NOTE(review): ``self.fromBinary`` is read here but never assigned in this method;
        presumably it is provided by subclasses as a class attribute -- confirm.
        """
        # Pop the graph out of kwargs so it is not handed to the base class twice.
        graph = kwargs.pop('graph') if 'graph' in kwargs else Graph()
        # if 'withOptTrans' in kwargs:
        #     self.withOptTrans = kwargs['withOptTrans']
        # if 'fromBinary' in kwargs:
        #     self.fromBinary = kwargs['fromBinary']
        super(ConservationTrackingWorkflowBase, self).__init__(shell, headless, workflow_cmdline_args, project_creation_args, graph=graph, *args, **kwargs)

        data_instructions = 'Use the "Raw Data" tab to load your intensity image(s).\n\n'
        if self.fromBinary:
            data_instructions += 'Use the "Binary Image" tab to load your segmentation image(s).'
        else:
            data_instructions += 'Use the "Prediction Maps" tab to load your pixel-wise probability image(s).'

        # Variables to store division and cell classifiers to prevent retraining every-time batch processing runs
        self.stored_division_classifier = None
        self.stored_cell_classifier = None

        ## Create applets
        self.dataSelectionApplet = DataSelectionApplet(self,
                                                       "Input Data",
                                                       "Input Data",
                                                       forceAxisOrder=['txyzc'],
                                                       instructionText=data_instructions,
                                                       max_lanes=None
                                                       )

        opDataSelection = self.dataSelectionApplet.topLevelOperator
        if self.fromBinary:
            opDataSelection.DatasetRoles.setValue( ['Raw Data', 'Segmentation Image'] )
        else:
            opDataSelection.DatasetRoles.setValue( ['Raw Data', 'Prediction Maps'] )

        # Thresholding is only set up when the input is not already a binary segmentation.
        if not self.fromBinary:
            self.thresholdTwoLevelsApplet = ThresholdTwoLevelsApplet( self,
                                                                  "Threshold and Size Filter",
                                                                  "ThresholdTwoLevels" )

        self.objectExtractionApplet = TrackingFeatureExtractionApplet(workflow=self, interactive=False,
                                                                      name="Object Feature Computation")

        opObjectExtraction = self.objectExtractionApplet.topLevelOperator

        self.divisionDetectionApplet = self._createDivisionDetectionApplet(configConservation.selectedFeaturesDiv) # Might be None

        if self.divisionDetectionApplet:
            feature_dict_division = {
                config.features_division_name: { name: {} for name in config.division_features }
            }
            opObjectExtraction.FeatureNamesDivision.setValue(feature_dict_division)

            opDivisionDetection = self.divisionDetectionApplet.topLevelOperator
            opDivisionDetection.SelectedFeatures.setValue(configConservation.selectedFeaturesDiv)
            opDivisionDetection.LabelNames.setValue(['Not Dividing', 'Dividing'])
            opDivisionDetection.AllowDeleteLabels.setValue(False)
            opDivisionDetection.AllowAddLabel.setValue(False)
            opDivisionDetection.EnableLabelTransfer.setValue(False)

        self.cellClassificationApplet = ObjectClassificationApplet(workflow=self,
                                                                     name="Object Count Classification",
                                                                     projectFileGroupName="CountClassification",
                                                                     selectedFeatures=configConservation.selectedFeaturesObjectCount)

        opCellClassification = self.cellClassificationApplet.topLevelOperator
        opCellClassification.SelectedFeatures.setValue(configConservation.selectedFeaturesObjectCount)
        opCellClassification.SuggestedLabelNames.setValue( ['False Detection', '1 Object'] + [str(i) + ' Objects' for i in range(2,10) ] )
        opCellClassification.AllowDeleteLastLabelOnly.setValue(True)
        opCellClassification.EnableLabelTransfer.setValue(False)

        self.trackingApplet = ConservationTrackingApplet( workflow=self )

        self.default_export_filename = '{dataset_dir}/{nickname}-exported_data.csv'
        self.dataExportApplet = TrackingBaseDataExportApplet(
            self,
            "Tracking Result Export",
            default_export_filename=self.default_export_filename,
            pluginExportFunc=self._pluginExportFunc,
        )

        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.SelectionNames.setValue( ['Object-Identities', 'Tracking-Result', 'Merger-Result'] )
        opDataExport.WorkingDirectory.connect( opDataSelection.WorkingDirectory )

        # Extra configuration for object export table (as CSV table or HDF5 table)
        opTracking = self.trackingApplet.topLevelOperator
        self.dataExportApplet.set_exporting_operator(opTracking)
        self.dataExportApplet.prepare_lane_for_export = self.prepare_lane_for_export

        # configure export settings
        # settings = {'file path': self.default_export_filename, 'compression': {}, 'file type': 'csv'}
        # selected_features = ['Count', 'RegionCenter', 'RegionRadii', 'RegionAxes']
        # opTracking.ExportSettings.setValue( (settings, selected_features) )

        self._applets = []
        self._applets.append(self.dataSelectionApplet)
        if not self.fromBinary:
            self._applets.append(self.thresholdTwoLevelsApplet)
        self._applets.append(self.objectExtractionApplet)

        if self.divisionDetectionApplet:
            self._applets.append(self.divisionDetectionApplet)

        self.batchProcessingApplet = BatchProcessingApplet(self, "Batch Processing", self.dataSelectionApplet, self.dataExportApplet)

        self._applets.append(self.cellClassificationApplet)
        self._applets.append(self.trackingApplet)
        self._applets.append(self.dataExportApplet)
        self._applets.append(self.batchProcessingApplet)

        # Parse export and batch command-line arguments for headless mode.
        # The export applet parses first; only what it did not consume is handed to the
        # batch applet, so export options are neither reported as unused nor mistaken
        # for batch inputs.
        if workflow_cmdline_args:
            self._data_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args( workflow_cmdline_args )
            self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args( unused_args )
        else:
            unused_args = None
            self._data_export_args = None
            self._batch_input_args = None

        if unused_args:
            logger.warning("Unused command-line args: {}".format( unused_args ))
        
    @property
    def applets(self):
        """Return the list of applets assembled in ``__init__``."""
        return self._applets

    def _createDivisionDetectionApplet(self, selectedFeatures=None):
        """
        Create the object classification applet used for division detection.

        :param selectedFeatures: feature selection passed on to the applet.
            When omitted, a fresh empty dict is used for each call, so no dict
            object is shared between calls.
        :return: the newly created ``ObjectClassificationApplet``.
        """
        if selectedFeatures is None:
            selectedFeatures = {}
        return ObjectClassificationApplet(workflow=self,
                                          name="Division Detection (optional)",
                                          projectFileGroupName="DivisionDetection",
                                          selectedFeatures=selectedFeatures)
    
    @property
    def imageNameListSlot(self):
        """Return the ``ImageName`` slot of the data selection applet's top-level operator."""
        return self.dataSelectionApplet.topLevelOperator.ImageName

    def prepareForNewLane(self, laneIndex):
        """
        Remember the current division and cell classifiers before a lane is added.

        A classifier is remembered only if its cache output is ready and the cache
        is not dirty; otherwise ``None`` is stored.  The division classifier is
        left untouched when there is no division detection applet.
        """
        def _current_classifier(top_level_op):
            # Value of the classifier cache if it is usable right now, else None.
            cache = top_level_op.classifier_cache
            if cache.Output.ready() and not cache._dirty:
                return cache.Output.value
            return None

        division_applet = self.divisionDetectionApplet
        if division_applet:
            self.stored_division_classifier = _current_classifier(division_applet.topLevelOperator)

        self.stored_cell_classifier = _current_classifier(self.cellClassificationApplet.topLevelOperator)

    def handleNewLanesAdded(self):
        """
        If new lanes were added, then we invalidated our classifiers unnecessarily.
        Here, we can restore the classifier so it doesn't need to be retrained.

        Each stored classifier is pushed back into its classifier cache with
        ``forceValue`` and the stored reference is then cleared.
        """
        # Compare against None instead of relying on truthiness, so a stored
        # classifier object that happens to evaluate as falsy is still restored.

        # If we have a stored division classifier, restore it into the workflow now.
        if self.stored_division_classifier is not None:
            opDivisionClassification = self.divisionDetectionApplet.topLevelOperator
            opDivisionClassification.classifier_cache.forceValue(self.stored_division_classifier)
            # Release reference
            self.stored_division_classifier = None

        # If we have a stored cell classifier, restore it into the workflow now.
        if self.stored_cell_classifier is not None:
            opCellClassification = self.cellClassificationApplet.topLevelOperator
            opCellClassification.classifier_cache.forceValue(self.stored_cell_classifier)
            # Release reference
            self.stored_cell_classifier = None
    
    def connectLane(self, laneIndex):
        """
        Wire together the per-lane operators of every applet for one lane.

        The raw image (``ImageGroup[0]``) and the binary source are each passed
        through an ``OpReorderAxes`` set to ``"txyzc"`` and then connected to the
        object extraction, (optional) division detection, cell classification,
        tracking and data export operators of the lane.

        :param laneIndex: index of the lane whose operators are connected.
        """
        # Fetch the lane-level view of each applet's top-level operator.
        opData = self.dataSelectionApplet.topLevelOperator.getLane(laneIndex)
        if not self.fromBinary:
            opTwoLevelThreshold = self.thresholdTwoLevelsApplet.topLevelOperator.getLane(laneIndex)
        opObjExtraction = self.objectExtractionApplet.topLevelOperator.getLane(laneIndex)
        opObjExtraction.setDefaultFeatures(configConservation.allFeaturesObjectCount)

        # The division detection applet is optional.
        if self.divisionDetectionApplet:
                opDivDetection = self.divisionDetectionApplet.topLevelOperator.getLane(laneIndex)

        opCellClassification = self.cellClassificationApplet.topLevelOperator.getLane(laneIndex)
        opTracking = self.trackingApplet.topLevelOperator.getLane(laneIndex)
        opDataExport = self.dataExportApplet.topLevelOperator.getLane(laneIndex)

        op5Raw = OpReorderAxes(parent=self)
        op5Raw.AxisOrder.setValue("txyzc")
        op5Raw.Input.connect(opData.ImageGroup[0])

        # The binary image comes either from the thresholding applet's cached
        # output or, when starting from binary data, directly from the second
        # input role.
        if not self.fromBinary:
            opTwoLevelThreshold.InputImage.connect(opData.ImageGroup[1])
            opTwoLevelThreshold.RawInput.connect(opData.ImageGroup[0])  # Used for display only
            # opTwoLevelThreshold.Channel.setValue(1)
            binarySrc = opTwoLevelThreshold.CachedOutput
        else:
            binarySrc = opData.ImageGroup[1]

        # Use Op5ifyers for both input datasets such that they are guaranteed to
        # have the same axis order after thresholding
        op5Binary = OpReorderAxes(parent=self)
        op5Binary.AxisOrder.setValue("txyzc")
        op5Binary.Input.connect(binarySrc)

        # # Connect operators ##
        opObjExtraction.RawImage.connect(op5Raw.Output)
        opObjExtraction.BinaryImage.connect(op5Binary.Output)

        if self.divisionDetectionApplet:
            opDivDetection.BinaryImages.connect( op5Binary.Output )
            opDivDetection.RawImages.connect( op5Raw.Output )
            opDivDetection.SegmentationImages.connect(opObjExtraction.LabelImage)
            opDivDetection.ObjectFeatures.connect(opObjExtraction.RegionFeaturesAll)
            opDivDetection.ComputedFeatureNames.connect(opObjExtraction.ComputedFeatureNamesAll)

        opCellClassification.BinaryImages.connect( op5Binary.Output )
        opCellClassification.RawImages.connect( op5Raw.Output )
        opCellClassification.SegmentationImages.connect(opObjExtraction.LabelImage)
        opCellClassification.ObjectFeatures.connect(opObjExtraction.RegionFeaturesVigra)
        opCellClassification.ComputedFeatureNames.connect(opObjExtraction.FeatureNamesVigra)

        # Division-related tracking inputs are only connected when the optional
        # division detection applet exists.
        if self.divisionDetectionApplet:
            opTracking.ObjectFeaturesWithDivFeatures.connect( opObjExtraction.RegionFeaturesAll)
            opTracking.ComputedFeatureNamesWithDivFeatures.connect( opObjExtraction.ComputedFeatureNamesAll )
            opTracking.DivisionProbabilities.connect( opDivDetection.Probabilities )

        opTracking.RawImage.connect( op5Raw.Output )
        opTracking.LabelImage.connect( opObjExtraction.LabelImage )
        opTracking.ObjectFeatures.connect( opObjExtraction.RegionFeaturesVigra )
        opTracking.ComputedFeatureNames.connect( opObjExtraction.FeatureNamesVigra)
        opTracking.DetectionProbabilities.connect( opCellClassification.Probabilities )
        opTracking.NumLabels.connect( opCellClassification.NumLabels )

        # Three export sources: relabeled image, tracking output, merger output.
        opDataExport.Inputs.resize(3)
        opDataExport.Inputs[0].connect( opTracking.RelabeledImage )
        opDataExport.Inputs[1].connect( opTracking.Output )
        opDataExport.Inputs[2].connect( opTracking.MergerOutput )
        opDataExport.RawData.connect( op5Raw.Output )
        opDataExport.RawDatasetInfo.connect( opData.DatasetGroup[0] )
         
    def prepare_lane_for_export(self, lane_index):
        """
        Run tracking on one lane over its full extent before the lane is exported.

        Switches the lane's object extraction cache (and, unless ``fromBinary``,
        the two thresholding caches) to bypass mode, records the axis ranges
        found in the tracking parameters as ``self.prev_*_range`` (falling back
        to the full extent), and finally calls ``track`` on the lane.  The
        ranges passed to ``track`` come from the first four entries of
        ``RawImage.meta.shape``, read as t, x, y, z; all other settings come
        from the tracking parameters.

        :param lane_index: index of the lane to prepare.
        """
        # Bypass cache on headless mode and batch processing mode
        self.objectExtractionApplet.topLevelOperator[lane_index].BypassModeEnabled.setValue(True)

        if not self.fromBinary:
            threshold_lane = self.thresholdTwoLevelsApplet.topLevelOperator[lane_index]
            threshold_lane.opCache.BypassModeEnabled.setValue(True)
            threshold_lane.opSmootherCache.BypassModeEnabled.setValue(True)

        # Get axes info
        tracking_lane = self.trackingApplet.topLevelOperator[lane_index]
        shape = tracking_lane.RawImage.meta.shape
        time_enum = list(range(shape[0]))
        x_range = (0, shape[1])
        y_range = (0, shape[2])
        z_range = (0, shape[3])

        # More than one z-slice means the data is treated as 3D.
        ndim = 3 if (z_range[1] - z_range[0]) > 1 else 2

        parameters = self.trackingApplet.topLevelOperator.Parameters.value

        # Save state of axis ranges
        self.prev_time_range = parameters['time_range'] if 'time_range' in parameters else time_enum
        self.prev_x_range = parameters['x_range'] if 'x_range' in parameters else x_range
        self.prev_y_range = parameters['y_range'] if 'y_range' in parameters else y_range
        self.prev_z_range = parameters['z_range'] if 'z_range' in parameters else z_range

        numFramesPerSplit = parameters['numFramesPerSplit'] if 'numFramesPerSplit' in parameters else 0

        tracking_lane.track(
            time_range=time_enum,
            x_range=x_range,
            y_range=y_range,
            z_range=z_range,
            size_range=parameters['size_range'],
            x_scale=parameters['scales'][0],
            y_scale=parameters['scales'][1],
            z_scale=parameters['scales'][2],
            maxDist=parameters['maxDist'],
            maxObj=parameters['maxObj'],
            divThreshold=parameters['divThreshold'],
            avgSize=parameters['avgSize'],
            withTracklets=parameters['withTracklets'],
            sizeDependent=parameters['sizeDependent'],
            divWeight=parameters['divWeight'],
            transWeight=parameters['transWeight'],
            withDivisions=parameters['withDivisions'],
            withOpticalCorrection=parameters['withOpticalCorrection'],
            withClassifierPrior=parameters['withClassifierPrior'],
            ndim=ndim,
            withMergerResolution=parameters['withMergerResolution'],
            borderAwareWidth=parameters['borderAwareWidth'],
            withArmaCoordinates=parameters['withArmaCoordinates'],
            cplex_timeout=parameters['cplex_timeout'],
            appearance_cost=parameters['appearanceCost'],
            disappearance_cost=parameters['disappearanceCost'],
            max_nearest_neighbors=parameters['max_nearest_neighbors'],
            numFramesPerSplit=numFramesPerSplit,
            force_build_hypotheses_graph=False,
            withBatchProcessing=True,
        )

    def _pluginExportFunc(self, lane_index, filename, exportPlugin, checkOverwriteFiles, plugArgsSlot) -> int:
        return (
            self.trackingApplet
            .topLevelOperator
            .getLane(lane_index)
            .exportPlugin(
                filename,
                exportPlugin,
                checkOverwriteFiles,
                plugArgsSlot
            )
        )

    def _inputReady(self, nRoles):
        slot = self.dataSelectionApplet.topLevelOperator.ImageGroup
        if len(slot) > 0:
            input_ready = True
            for sub in slot:
                input_ready = input_ready and \
                    all([sub[i].ready() for i in range(nRoles)])
        else:
            input_ready = False

        return input_ready

    def onProjectLoaded(self, projectManager):
        """
        Overridden from Workflow base class.  Called by the Project Manager.

        Applies parsed export arguments to the data export applet, and in
        headless mode with both batch and export arguments present, runs the
        batch export.

        :param projectManager: not used by this method.
        """
        export_args = self._data_export_args

        # Configure the data export operator.
        if export_args:
            self.dataExportApplet.configure_operator_with_parsed_args(export_args)

        # Batch processing only runs headless, and only with both argument sets.
        if not (self._headless and self._batch_input_args and export_args):
            return

        logger.info("Beginning Batch Processing")
        self.batchProcessingApplet.run_export_from_parsed_args(self._batch_input_args)
        logger.info("Completed Batch Processing")

    def handleAppletStateUpdateRequested(self):
        """
        Overridden from Workflow base class
        Called when an applet has fired the :py:attr:`Applet.statusUpdateSignal`

        Recomputes which applets may be used and pushes that state to the shell.
        Readiness is cumulative: input -> thresholding -> features -> tracking.
        Every applet is disabled while the data selection, tracking, data
        export or batch processing applet is busy.
        """
        # If no data, nothing else is ready.
        input_ready = self._inputReady(2) and not self.dataSelectionApplet.busy

        if not self.fromBinary:
            thresholdingOutput = self.thresholdTwoLevelsApplet.topLevelOperator.CachedOutput
            thresholding_ready = input_ready and len(thresholdingOutput) > 0
        else:
            # Without a thresholding applet, ready input is all that is needed.
            thresholding_ready = input_ready

        objectExtractionOutput = self.objectExtractionApplet.topLevelOperator.ComputedFeatureNamesAll
        features_ready = thresholding_ready and len(objectExtractionOutput) > 0

        objectCountClassifier_ready = features_ready
        tracking_ready = objectCountClassifier_ready

        busy = any([
            self.dataSelectionApplet.busy,
            self.trackingApplet.busy,
            self.dataExportApplet.busy,
            self.batchProcessingApplet.busy,
        ])
        self._shell.enableProjectChanges(not busy)

        self._shell.setAppletEnabled(self.dataSelectionApplet, not busy)
        if not self.fromBinary:
            self._shell.setAppletEnabled(self.thresholdTwoLevelsApplet, input_ready and not busy)

        if self.divisionDetectionApplet:
            self._shell.setAppletEnabled(self.divisionDetectionApplet, features_ready and not busy)

        self._shell.setAppletEnabled(self.objectExtractionApplet, thresholding_ready and not busy)
        self._shell.setAppletEnabled(self.cellClassificationApplet, features_ready and not busy)
        self._shell.setAppletEnabled(self.trackingApplet, objectCountClassifier_ready and not busy)

        # Export and batch processing share one condition, evaluated once.
        # Short-circuiting keeps ``Inputs[0][0]`` from being indexed unless
        # tracking is ready and nothing is busy.
        export_ready = (tracking_ready and not busy
                        and self.dataExportApplet.topLevelOperator.Inputs[0][0].ready())
        self._shell.setAppletEnabled(self.dataExportApplet, export_ready)
        self._shell.setAppletEnabled(self.batchProcessingApplet, export_ready)
예제 #41
0
class ConservationTrackingWorkflowBase(Workflow):
    workflowName = "Automatic Tracking Workflow (Conservation Tracking) BASE"

    def __init__(self, shell, headless, workflow_cmdline_args,
                 project_creation_args, *args, **kwargs):
        """
        Create all applets of the conservation tracking workflow and parse the
        export/batch command-line arguments.

        :param shell: forwarded to the ``Workflow`` constructor.
        :param headless: forwarded to the ``Workflow`` constructor.
        :param workflow_cmdline_args: command-line arguments for this workflow.
            When non-empty they are run through the data export applet's parser
            and then the batch processing applet's parser.
        :param project_creation_args: forwarded to the ``Workflow`` constructor.
        :param kwargs: may contain ``graph``; it is removed from ``kwargs`` and
            passed on explicitly.  Without it a new ``Graph`` is created.

        NOTE(review): ``self.fromBinary`` is read here but never assigned in
        this method, so it presumably comes from a class attribute of a
        subclass -- confirm.
        """
        if "graph" in kwargs:
            graph = kwargs.pop("graph")
        else:
            graph = Graph()
        # if 'withOptTrans' in kwargs:
        #     self.withOptTrans = kwargs['withOptTrans']
        # if 'fromBinary' in kwargs:
        #     self.fromBinary = kwargs['fromBinary']
        super(ConservationTrackingWorkflowBase,
              self).__init__(shell,
                             headless,
                             workflow_cmdline_args,
                             project_creation_args,
                             graph=graph,
                             *args,
                             **kwargs)

        data_instructions = 'Use the "Raw Data" tab to load your intensity image(s).\n\n'
        if self.fromBinary:
            data_instructions += 'Use the "Binary Image" tab to load your segmentation image(s).'
        else:
            data_instructions += 'Use the "Prediction Maps" tab to load your pixel-wise probability image(s).'

        # Variables to store division and cell classifiers to prevent retraining every-time batch processing runs
        self.stored_division_classifier = None
        self.stored_cell_classifier = None

        ## Create applets
        self.dataSelectionApplet = DataSelectionApplet(
            self,
            "Input Data",
            "Input Data",
            forceAxisOrder=["txyzc"],
            instructionText=data_instructions,
            max_lanes=None,
        )

        opDataSelection = self.dataSelectionApplet.topLevelOperator
        if self.fromBinary:
            opDataSelection.DatasetRoles.setValue(
                ["Raw Data", "Segmentation Image"])
        else:
            opDataSelection.DatasetRoles.setValue(
                ["Raw Data", "Prediction Maps"])

        if not self.fromBinary:
            self.thresholdTwoLevelsApplet = ThresholdTwoLevelsApplet(
                self, "Threshold and Size Filter", "ThresholdTwoLevels")

        self.objectExtractionApplet = TrackingFeatureExtractionApplet(
            workflow=self,
            interactive=False,
            name="Object Feature Computation")

        opObjectExtraction = self.objectExtractionApplet.topLevelOperator

        self.divisionDetectionApplet = self._createDivisionDetectionApplet(
            configConservation.selectedFeaturesDiv)  # Might be None

        if self.divisionDetectionApplet:
            feature_dict_division = {}
            feature_dict_division[config.features_division_name] = {
                name: {}
                for name in config.division_features
            }
            opObjectExtraction.FeatureNamesDivision.setValue(
                feature_dict_division)

            # NOTE(review): ``selected_features_div`` is built here but never
            # read again in this method; ``SelectedFeatures`` below is set from
            # ``configConservation.selectedFeaturesDiv`` instead.  Kept so that
            # the config lookups still happen -- confirm whether it can go.
            selected_features_div = {}
            for plugin_name, feature_names in config.selected_features_division.items():
                selected_features_div[plugin_name] = {
                    name: {}
                    for name in feature_names
                }
            # FIXME: do not hard code this
            for name in [
                    "SquaredDistances_" + str(i)
                    for i in range(config.n_best_successors)
            ]:
                selected_features_div[config.features_division_name][name] = {}

            opDivisionDetection = self.divisionDetectionApplet.topLevelOperator
            opDivisionDetection.SelectedFeatures.setValue(
                configConservation.selectedFeaturesDiv)
            opDivisionDetection.LabelNames.setValue(
                ["Not Dividing", "Dividing"])
            opDivisionDetection.AllowDeleteLabels.setValue(False)
            opDivisionDetection.AllowAddLabel.setValue(False)
            opDivisionDetection.EnableLabelTransfer.setValue(False)

        self.cellClassificationApplet = ObjectClassificationApplet(
            workflow=self,
            name="Object Count Classification",
            projectFileGroupName="CountClassification",
            selectedFeatures=configConservation.selectedFeaturesObjectCount,
        )

        # NOTE(review): like ``selected_features_div`` above, this dict is
        # built but never read again in this method.
        selected_features_objectcount = {}
        for plugin_name, feature_names in config.selected_features_objectcount.items():
            selected_features_objectcount[plugin_name] = {
                name: {}
                for name in feature_names
            }

        opCellClassification = self.cellClassificationApplet.topLevelOperator
        opCellClassification.SelectedFeatures.setValue(
            configConservation.selectedFeaturesObjectCount)
        opCellClassification.SuggestedLabelNames.setValue(
            ["False Detection"] + [str(1) + " Object"] +
            [str(i) + " Objects" for i in range(2, 10)])
        opCellClassification.AllowDeleteLastLabelOnly.setValue(True)
        opCellClassification.EnableLabelTransfer.setValue(False)

        self.trackingApplet = ConservationTrackingApplet(workflow=self)

        self.default_export_filename = "{dataset_dir}/{nickname}-exported_data.csv"
        self.dataExportApplet = TrackingBaseDataExportApplet(
            self,
            "Tracking Result Export",
            default_export_filename=self.default_export_filename,
            pluginExportFunc=self._pluginExportFunc,
        )

        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.SelectionNames.setValue(
            ["Object-Identities", "Tracking-Result", "Merger-Result"])
        opDataExport.WorkingDirectory.connect(opDataSelection.WorkingDirectory)

        # Extra configuration for object export table (as CSV table or HDF5 table)
        opTracking = self.trackingApplet.topLevelOperator
        self.dataExportApplet.set_exporting_operator(opTracking)
        self.dataExportApplet.prepare_lane_for_export = self.prepare_lane_for_export

        # configure export settings
        # settings = {'file path': self.default_export_filename, 'compression': {}, 'file type': 'csv'}
        # selected_features = ['Count', 'RegionCenter', 'RegionRadii', 'RegionAxes']
        # opTracking.ExportSettings.setValue( (settings, selected_features) )

        self._applets = []
        self._applets.append(self.dataSelectionApplet)
        if not self.fromBinary:
            self._applets.append(self.thresholdTwoLevelsApplet)
        self._applets.append(self.objectExtractionApplet)

        if self.divisionDetectionApplet:
            self._applets.append(self.divisionDetectionApplet)

        self.batchProcessingApplet = BatchProcessingApplet(
            self, "Batch Processing", self.dataSelectionApplet,
            self.dataExportApplet)

        self._applets.append(self.cellClassificationApplet)
        self._applets.append(self.trackingApplet)
        self._applets.append(self.dataExportApplet)
        self._applets.append(self.batchProcessingApplet)

        # Parse export and batch command-line arguments for headless mode
        if workflow_cmdline_args:
            # Chain the parsers: the batch parser only receives what the export
            # parser did not recognize, so ``unused_args`` ends up holding the
            # arguments that neither applet consumed.  Parsing the full list
            # twice made the warning below report every export argument.
            self._data_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args(
                workflow_cmdline_args)
            self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args(
                unused_args)

        else:
            unused_args = None
            self._data_export_args = None
            self._batch_input_args = None

        if unused_args:
            logger.warning("Unused command-line args: {}".format(unused_args))

    @property
    def applets(self):
        """Return the applet list held in ``self._applets``."""
        registered = self._applets
        return registered

    def _createDivisionDetectionApplet(self, selectedFeatures=None):
        """
        Build the object classification applet used for division detection.

        :param selectedFeatures: value forwarded as the applet's
            ``selectedFeatures`` argument.  ``None`` (the default) is replaced
            by a new empty dict.
        :returns: the newly created ``ObjectClassificationApplet``.
        """
        # A ``dict()`` default is evaluated once at definition time and would be
        # shared by every call; create a fresh dict per call instead.
        if selectedFeatures is None:
            selectedFeatures = {}
        return ObjectClassificationApplet(
            workflow=self,
            name="Division Detection (optional)",
            projectFileGroupName="DivisionDetection",
            selectedFeatures=selectedFeatures,
        )

    @property
    def imageNameListSlot(self):
        """Return the ``ImageName`` slot of the data selection applet's top-level operator."""
        selection_op = self.dataSelectionApplet.topLevelOperator
        return selection_op.ImageName

    def prepareForNewLane(self, laneIndex):
        """
        Remember the current division and cell classifiers before a lane is added.

        A classifier is stored only when its cache output is ready and the cache
        is not marked dirty; otherwise ``None`` is stored.  The division
        classifier is only touched when a division detection applet exists.

        :param laneIndex: index of the lane about to be added (not used here).
        """
        # Store division and cell classifiers
        if self.divisionDetectionApplet:
            division_cache = self.divisionDetectionApplet.topLevelOperator.classifier_cache
            division_usable = (division_cache.Output.ready()
                               and not division_cache._dirty)
            self.stored_division_classifier = (
                division_cache.Output.value if division_usable else None)

        cell_cache = self.cellClassificationApplet.topLevelOperator.classifier_cache
        cell_usable = cell_cache.Output.ready() and not cell_cache._dirty
        self.stored_cell_classifier = (
            cell_cache.Output.value if cell_usable else None)

    def handleNewLanesAdded(self):
        """
        If new lanes were added, then we invalidated our classifiers unnecessarily.
        Here, we can restore the classifier so it doesn't need to be retrained.

        Each stored classifier is pushed back into its classifier cache with
        ``forceValue`` and the stored reference is then cleared.
        """
        # Compare against None instead of relying on truthiness, so a stored
        # classifier object that happens to evaluate as falsy is still restored.

        # If we have a stored division classifier, restore it into the workflow now.
        if self.stored_division_classifier is not None:
            opDivisionClassification = self.divisionDetectionApplet.topLevelOperator
            opDivisionClassification.classifier_cache.forceValue(
                self.stored_division_classifier)
            # Release reference
            self.stored_division_classifier = None

        # If we have a stored cell classifier, restore it into the workflow now.
        if self.stored_cell_classifier is not None:
            opCellClassification = self.cellClassificationApplet.topLevelOperator
            opCellClassification.classifier_cache.forceValue(
                self.stored_cell_classifier)
            # Release reference
            self.stored_cell_classifier = None

    def connectLane(self, laneIndex):
        """
        Wire together the per-lane operators of every applet for one lane.

        The raw image (``ImageGroup[0]``) and the binary source are each passed
        through an ``OpReorderAxes`` set to ``"txyzc"`` and then connected to the
        object extraction, (optional) division detection, cell classification,
        tracking and data export operators of the lane.

        :param laneIndex: index of the lane whose operators are connected.
        """
        # Fetch the lane-level view of each applet's top-level operator.
        opData = self.dataSelectionApplet.topLevelOperator.getLane(laneIndex)
        if not self.fromBinary:
            opTwoLevelThreshold = self.thresholdTwoLevelsApplet.topLevelOperator.getLane(
                laneIndex)
        opObjExtraction = self.objectExtractionApplet.topLevelOperator.getLane(
            laneIndex)
        opObjExtraction.setDefaultFeatures(
            configConservation.allFeaturesObjectCount)

        # The division detection applet is optional.
        if self.divisionDetectionApplet:
            opDivDetection = self.divisionDetectionApplet.topLevelOperator.getLane(
                laneIndex)

        opCellClassification = self.cellClassificationApplet.topLevelOperator.getLane(
            laneIndex)
        opTracking = self.trackingApplet.topLevelOperator.getLane(laneIndex)
        opDataExport = self.dataExportApplet.topLevelOperator.getLane(
            laneIndex)

        op5Raw = OpReorderAxes(parent=self)
        op5Raw.AxisOrder.setValue("txyzc")
        op5Raw.Input.connect(opData.ImageGroup[0])

        # The binary image comes either from the thresholding applet's cached
        # output or, when starting from binary data, directly from the second
        # input role.
        if not self.fromBinary:
            opTwoLevelThreshold.InputImage.connect(opData.ImageGroup[1])
            opTwoLevelThreshold.RawInput.connect(
                opData.ImageGroup[0])  # Used for display only
            # opTwoLevelThreshold.Channel.setValue(1)
            binarySrc = opTwoLevelThreshold.CachedOutput
        else:
            binarySrc = opData.ImageGroup[1]

        # Use Op5ifyers for both input datasets such that they are guaranteed to
        # have the same axis order after thresholding
        op5Binary = OpReorderAxes(parent=self)
        op5Binary.AxisOrder.setValue("txyzc")
        op5Binary.Input.connect(binarySrc)

        # # Connect operators ##
        opObjExtraction.RawImage.connect(op5Raw.Output)
        opObjExtraction.BinaryImage.connect(op5Binary.Output)

        if self.divisionDetectionApplet:
            opDivDetection.BinaryImages.connect(op5Binary.Output)
            opDivDetection.RawImages.connect(op5Raw.Output)
            opDivDetection.SegmentationImages.connect(
                opObjExtraction.LabelImage)
            opDivDetection.ObjectFeatures.connect(
                opObjExtraction.RegionFeaturesAll)
            opDivDetection.ComputedFeatureNames.connect(
                opObjExtraction.ComputedFeatureNamesAll)

        opCellClassification.BinaryImages.connect(op5Binary.Output)
        opCellClassification.RawImages.connect(op5Raw.Output)
        opCellClassification.SegmentationImages.connect(
            opObjExtraction.LabelImage)
        opCellClassification.ObjectFeatures.connect(
            opObjExtraction.RegionFeaturesVigra)
        opCellClassification.ComputedFeatureNames.connect(
            opObjExtraction.FeatureNamesVigra)

        # Division-related tracking inputs are only connected when the optional
        # division detection applet exists.
        if self.divisionDetectionApplet:
            opTracking.ObjectFeaturesWithDivFeatures.connect(
                opObjExtraction.RegionFeaturesAll)
            opTracking.ComputedFeatureNamesWithDivFeatures.connect(
                opObjExtraction.ComputedFeatureNamesAll)
            opTracking.DivisionProbabilities.connect(
                opDivDetection.Probabilities)

        opTracking.RawImage.connect(op5Raw.Output)
        opTracking.LabelImage.connect(opObjExtraction.LabelImage)
        opTracking.ObjectFeatures.connect(opObjExtraction.RegionFeaturesVigra)
        opTracking.ComputedFeatureNames.connect(
            opObjExtraction.FeatureNamesVigra)
        opTracking.DetectionProbabilities.connect(
            opCellClassification.Probabilities)
        opTracking.NumLabels.connect(opCellClassification.NumLabels)

        # Three export sources: relabeled image, tracking output, merger output.
        opDataExport.Inputs.resize(3)
        opDataExport.Inputs[0].connect(opTracking.RelabeledImage)
        opDataExport.Inputs[1].connect(opTracking.Output)
        opDataExport.Inputs[2].connect(opTracking.MergerOutput)
        opDataExport.RawData.connect(op5Raw.Output)
        opDataExport.RawDatasetInfo.connect(opData.DatasetGroup[0])

    def prepare_lane_for_export(self, lane_index):
        """
        Run tracking on one lane over its full extent before the lane is exported.

        Switches the lane's object extraction cache (and, unless ``fromBinary``,
        the two thresholding caches) to bypass mode, records the axis ranges
        found in the tracking parameters as ``self.prev_*_range`` (falling back
        to the full extent), and finally calls ``track`` on the lane.  The
        ranges passed to ``track`` come from the first four entries of
        ``RawImage.meta.shape``, read as t, x, y, z; all other settings come
        from the tracking parameters.

        :param lane_index: index of the lane to prepare.
        """
        # Bypass cache on headless mode and batch processing mode
        extraction_lane = self.objectExtractionApplet.topLevelOperator[lane_index]
        extraction_lane.BypassModeEnabled.setValue(True)

        if not self.fromBinary:
            threshold_lane = self.thresholdTwoLevelsApplet.topLevelOperator[lane_index]
            threshold_lane.opCache.BypassModeEnabled.setValue(True)
            threshold_lane.opSmootherCache.BypassModeEnabled.setValue(True)

        # Get axes info
        tracking_lane = self.trackingApplet.topLevelOperator[lane_index]
        shape = tracking_lane.RawImage.meta.shape
        time_enum = list(range(shape[0]))
        x_range = (0, shape[1])
        y_range = (0, shape[2])
        z_range = (0, shape[3])

        # More than one z-slice means the data is treated as 3D.
        ndim = 3 if (z_range[1] - z_range[0]) > 1 else 2

        parameters = self.trackingApplet.topLevelOperator.Parameters.value

        # Save state of axis ranges
        self.prev_time_range = (parameters["time_range"]
                                if "time_range" in parameters else time_enum)
        self.prev_x_range = (parameters["x_range"]
                             if "x_range" in parameters else x_range)
        self.prev_y_range = (parameters["y_range"]
                             if "y_range" in parameters else y_range)
        self.prev_z_range = (parameters["z_range"]
                             if "z_range" in parameters else z_range)

        numFramesPerSplit = (parameters["numFramesPerSplit"]
                             if "numFramesPerSplit" in parameters else 0)

        scales = parameters  # alias only; the "scales" entry is looked up per axis below
        tracking_lane.track(
            time_range=time_enum,
            x_range=x_range,
            y_range=y_range,
            z_range=z_range,
            size_range=parameters["size_range"],
            x_scale=scales["scales"][0],
            y_scale=scales["scales"][1],
            z_scale=scales["scales"][2],
            maxDist=parameters["maxDist"],
            maxObj=parameters["maxObj"],
            divThreshold=parameters["divThreshold"],
            avgSize=parameters["avgSize"],
            withTracklets=parameters["withTracklets"],
            sizeDependent=parameters["sizeDependent"],
            divWeight=parameters["divWeight"],
            transWeight=parameters["transWeight"],
            withDivisions=parameters["withDivisions"],
            withOpticalCorrection=parameters["withOpticalCorrection"],
            withClassifierPrior=parameters["withClassifierPrior"],
            ndim=ndim,
            withMergerResolution=parameters["withMergerResolution"],
            borderAwareWidth=parameters["borderAwareWidth"],
            withArmaCoordinates=parameters["withArmaCoordinates"],
            cplex_timeout=parameters["cplex_timeout"],
            appearance_cost=parameters["appearanceCost"],
            disappearance_cost=parameters["disappearanceCost"],
            max_nearest_neighbors=parameters["max_nearest_neighbors"],
            numFramesPerSplit=numFramesPerSplit,
            force_build_hypotheses_graph=False,
            withBatchProcessing=True,
        )

    def _pluginExportFunc(self, lane_index, filename, exportPlugin,
                          checkOverwriteFiles, plugArgsSlot) -> int:
        return self.trackingApplet.topLevelOperator.getLane(
            lane_index).exportPlugin(filename, exportPlugin,
                                     checkOverwriteFiles, plugArgsSlot)

    def _inputReady(self, nRoles):
        slot = self.dataSelectionApplet.topLevelOperator.ImageGroup
        if len(slot) > 0:
            input_ready = True
            for sub in slot:
                input_ready = input_ready and all(
                    [sub[i].ready() for i in range(nRoles)])
        else:
            input_ready = False

        return input_ready

    def onProjectLoaded(self, projectManager):
        """
        Overridden from Workflow base class.  Called by the Project Manager.

        Applies parsed export arguments to the data export applet, and in
        headless mode with both batch and export arguments present, runs the
        batch export.

        :param projectManager: not used by this method.
        """
        export_args = self._data_export_args

        # Configure the data export operator.
        if export_args:
            self.dataExportApplet.configure_operator_with_parsed_args(export_args)

        # Batch processing only runs headless, and only with both argument sets.
        if not (self._headless and self._batch_input_args and export_args):
            return

        logger.info("Beginning Batch Processing")
        self.batchProcessingApplet.run_export_from_parsed_args(
            self._batch_input_args)
        logger.info("Completed Batch Processing")

    def handleAppletStateUpdateRequested(self):
        """
        Overridden from Workflow base class
        Called when an applet has fired the :py:attr:`Applet.statusUpdateSignal`

        Recomputes which applets the shell should enable.  Readiness is
        chained: input data -> thresholding (skipped when ``self.fromBinary``
        is set) -> object features -> object count classification ->
        tracking.  Every applet is disabled while data selection, tracking,
        export or batch processing is busy.
        """
        # If no data, nothing else is ready.
        input_ready = self._inputReady(2) and not self.dataSelectionApplet.busy

        if self.fromBinary:
            # There is no thresholding step when starting from a binary image.
            thresholding_ready = input_ready
        else:
            thresholding_output = self.thresholdTwoLevelsApplet.topLevelOperator.CachedOutput
            thresholding_ready = input_ready and len(thresholding_output) > 0

        feature_names = self.objectExtractionApplet.topLevelOperator.ComputedFeatureNamesAll
        features_ready = thresholding_ready and len(feature_names) > 0

        # No additional conditions are checked for these two stages.
        objectCountClassifier_ready = features_ready
        tracking_ready = objectCountClassifier_ready

        busy = any([
            self.dataSelectionApplet.busy,
            self.trackingApplet.busy,
            self.dataExportApplet.busy,
            self.batchProcessingApplet.busy,
        ])
        idle = not busy
        self._shell.enableProjectChanges(idle)

        self._shell.setAppletEnabled(self.dataSelectionApplet, idle)
        if not self.fromBinary:
            self._shell.setAppletEnabled(self.thresholdTwoLevelsApplet,
                                         input_ready and idle)

        if self.divisionDetectionApplet:
            self._shell.setAppletEnabled(self.divisionDetectionApplet,
                                         features_ready and idle)

        self._shell.setAppletEnabled(self.objectExtractionApplet,
                                     thresholding_ready and idle)
        self._shell.setAppletEnabled(self.cellClassificationApplet,
                                     features_ready and idle)
        self._shell.setAppletEnabled(self.trackingApplet,
                                     objectCountClassifier_ready and idle)

        # Export and batch processing share the same condition: tracking is
        # ready, nothing is busy, and the first export input of the first
        # lane is ready.  (The slot is only touched when the first two hold.)
        export_ready = (
            tracking_ready and idle
            and self.dataExportApplet.topLevelOperator.Inputs[0][0].ready()
        )
        self._shell.setAppletEnabled(self.dataExportApplet, export_ready)
        self._shell.setAppletEnabled(self.batchProcessingApplet, export_ready)
    def __init__(self, shell, headless, workflow_cmdline_args, project_creation_args, *args, **kwargs):
        """
        Create and configure all applets of the conservation tracking workflow.

        Applets are created in this order: data selection, thresholding
        (only when ``self.fromBinary`` is false), object feature extraction,
        division detection (only if ``_createDivisionDetectionApplet``
        returns one), object count classification, tracking, export and
        batch processing.

        ``self.fromBinary`` is read here but not assigned in this method; it
        must already be available on the instance or class.

        :param shell: forwarded to the base class constructor
        :param headless: forwarded to the base class constructor
        :param workflow_cmdline_args: command-line arguments of this session;
            parsed for export settings first, the remainder for batch inputs
        :param project_creation_args: forwarded to the base class constructor
        :param kwargs: may contain ``graph`` (a graph to share between
            operators); a new ``Graph`` is created when it is absent
        """
        # Use the caller's graph when one was handed in, otherwise create our own.
        graph = kwargs.pop('graph') if 'graph' in kwargs else Graph()
        super(ConservationTrackingWorkflowBase, self).__init__(
            shell, headless, workflow_cmdline_args, project_creation_args, graph=graph, *args, **kwargs
        )

        data_instructions = 'Use the "Raw Data" tab to load your intensity image(s).\n\n'
        if self.fromBinary:
            data_instructions += 'Use the "Binary Image" tab to load your segmentation image(s).'
        else:
            data_instructions += 'Use the "Prediction Maps" tab to load your pixel-wise probability image(s).'

        # Variables to store division and cell classifiers to prevent retraining every-time batch processing runs
        self.stored_division_classifier = None
        self.stored_cell_classifier = None

        ## Create applets
        self.dataSelectionApplet = DataSelectionApplet(
            self,
            "Input Data",
            "Input Data",
            forceAxisOrder=['txyzc'],
            instructionText=data_instructions,
            max_lanes=None,
        )

        opDataSelection = self.dataSelectionApplet.topLevelOperator
        if self.fromBinary:
            opDataSelection.DatasetRoles.setValue(['Raw Data', 'Segmentation Image'])
        else:
            opDataSelection.DatasetRoles.setValue(['Raw Data', 'Prediction Maps'])
            # Prediction maps still have to be thresholded into objects.
            self.thresholdTwoLevelsApplet = ThresholdTwoLevelsApplet(
                self, "Threshold and Size Filter", "ThresholdTwoLevels"
            )

        self.objectExtractionApplet = TrackingFeatureExtractionApplet(
            workflow=self, interactive=False, name="Object Feature Computation"
        )
        opObjectExtraction = self.objectExtractionApplet.topLevelOperator

        # Might be None
        self.divisionDetectionApplet = self._createDivisionDetectionApplet(configConservation.selectedFeaturesDiv)

        if self.divisionDetectionApplet:
            feature_dict_division = {
                config.features_division_name: {name: {} for name in config.division_features}
            }
            opObjectExtraction.FeatureNamesDivision.setValue(feature_dict_division)

            opDivisionDetection = self.divisionDetectionApplet.topLevelOperator
            opDivisionDetection.SelectedFeatures.setValue(configConservation.selectedFeaturesDiv)
            opDivisionDetection.LabelNames.setValue(['Not Dividing', 'Dividing'])
            opDivisionDetection.AllowDeleteLabels.setValue(False)
            opDivisionDetection.AllowAddLabel.setValue(False)
            opDivisionDetection.EnableLabelTransfer.setValue(False)

        self.cellClassificationApplet = ObjectClassificationApplet(
            workflow=self,
            name="Object Count Classification",
            projectFileGroupName="CountClassification",
            selectedFeatures=configConservation.selectedFeaturesObjectCount,
        )

        opCellClassification = self.cellClassificationApplet.topLevelOperator
        opCellClassification.SelectedFeatures.setValue(configConservation.selectedFeaturesObjectCount)
        opCellClassification.SuggestedLabelNames.setValue(
            ['False Detection', '1 Object'] + ['{} Objects'.format(i) for i in range(2, 10)]
        )
        opCellClassification.AllowDeleteLastLabelOnly.setValue(True)
        opCellClassification.EnableLabelTransfer.setValue(False)

        self.trackingApplet = ConservationTrackingApplet(workflow=self)

        self.default_export_filename = '{dataset_dir}/{nickname}-exported_data.csv'
        self.dataExportApplet = TrackingBaseDataExportApplet(
            self,
            "Tracking Result Export",
            default_export_filename=self.default_export_filename,
            pluginExportFunc=self._pluginExportFunc,
        )

        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.SelectionNames.setValue(['Object-Identities', 'Tracking-Result', 'Merger-Result'])
        opDataExport.WorkingDirectory.connect(opDataSelection.WorkingDirectory)

        # Extra configuration for object export table (as CSV table or HDF5 table)
        opTracking = self.trackingApplet.topLevelOperator
        self.dataExportApplet.set_exporting_operator(opTracking)
        self.dataExportApplet.prepare_lane_for_export = self.prepare_lane_for_export

        self.batchProcessingApplet = BatchProcessingApplet(
            self, "Batch Processing", self.dataSelectionApplet, self.dataExportApplet
        )

        # Expose the applets to the shell, in display order.
        self._applets = [self.dataSelectionApplet]
        if not self.fromBinary:
            self._applets.append(self.thresholdTwoLevelsApplet)
        self._applets.append(self.objectExtractionApplet)
        if self.divisionDetectionApplet:
            self._applets.append(self.divisionDetectionApplet)
        self._applets.append(self.cellClassificationApplet)
        self._applets.append(self.trackingApplet)
        self._applets.append(self.dataExportApplet)
        self._applets.append(self.batchProcessingApplet)

        # Parse export and batch command-line arguments for headless mode
        if workflow_cmdline_args:
            # Parse the export settings first and give the batch parser only
            # what is left over.  Otherwise values of export options could be
            # picked up as batch inputs, and every export option would be
            # reported as "unused" below.
            self._data_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args(
                workflow_cmdline_args
            )
            self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args(
                unused_args
            )
        else:
            unused_args = None
            self._data_export_args = None
            self._batch_input_args = None

        if unused_args:
            logger.warning("Unused command-line args: {}".format(unused_args))
class PixelClassificationWorkflow(Workflow):
    """
    Workflow that chains data selection, feature selection, pixel
    classification, prediction export and batch processing applets.
    """

    workflowName = "Pixel Classification"
    workflowDescription = "This is obviously self-explanatory."
    defaultAppletIndex = 0  # show DataSelection by default

    # Positions of the dataset roles; they line up with ROLE_NAMES below.
    DATA_ROLE_RAW = 0
    DATA_ROLE_PREDICTION_MASK = 1
    ROLE_NAMES = ['Raw Data', 'Prediction Mask']
    # Names offered as export sources.  connectLane() resizes the export
    # inputs to this length and connects one source per entry, in this order.
    EXPORT_NAMES = [
        'Probabilities', 'Simple Segmentation', 'Uncertainty', 'Features',
        'Labels'
    ]

    @property
    def applets(self):
        """The list of applets this workflow has registered in ``_applets``."""
        return self._applets

    @property
    def imageNameListSlot(self):
        """The ``ImageName`` slot of the data selection applet's top-level operator."""
        return self.dataSelectionApplet.topLevelOperator.ImageName

    def __init__(self, shell, headless, workflow_cmdline_args,
                 project_creation_args, *args, **kwargs):
        """
        Parse the workflow-specific command-line options and create the applets.

        The same argument parser is run twice: once over
        ``project_creation_args`` (only ``--filter`` is taken from there) and
        once over ``workflow_cmdline_args`` (all other options).  Whatever the
        second run does not recognize is handed first to the export applet
        and then to the batch processing applet.

        :param shell: forwarded to the base class constructor
        :param headless: forwarded to the base class constructor
        :param workflow_cmdline_args: command-line arguments of the current session
        :param project_creation_args: arguments saved to the project file when
            the project was first created
        """
        # Create a graph to be shared by all operators
        graph = Graph()
        super(PixelClassificationWorkflow, self).__init__(
            shell, headless, workflow_cmdline_args, project_creation_args,
            graph=graph, *args, **kwargs)
        self.stored_classifier = None
        self._applets = []
        self._workflow_cmdline_args = workflow_cmdline_args

        # Parse workflow-specific command-line args
        parser = argparse.ArgumentParser()
        parser.add_argument(
            '--filter',
            help="pixel feature filter implementation.",
            choices=['Original', 'Refactored', 'Interpolated'],
            default='Original')
        parser.add_argument(
            '--print-labels-by-slice',
            help="Print the number of labels for each Z-slice of each image.",
            action="store_true")
        parser.add_argument(
            '--label-search-value',
            help="If provided, only this value is considered when using --print-labels-by-slice",
            default=0,
            type=int)
        parser.add_argument(
            '--generate-random-labels',
            help="Add random labels to the project file.",
            action="store_true")
        parser.add_argument(
            '--random-label-value',
            help="The label value to use injecting random labels",
            default=1,
            type=int)
        parser.add_argument(
            '--random-label-count',
            help="The number of random labels to inject via --generate-random-labels",
            default=2000,
            type=int)
        parser.add_argument(
            '--retrain',
            help="Re-train the classifier based on labels stored in project file, and re-save.",
            action="store_true")
        parser.add_argument(
            '--tree-count',
            help='Number of trees for Vigra RF classifier.',
            type=int)
        parser.add_argument(
            '--variable-importance-path',
            help='Location of variable-importance table.',
            type=str)
        parser.add_argument(
            '--label-proportion',
            help='Proportion of feature-pixels used to train the classifier.',
            type=float)

        # Parse the creation args: These were saved to the project file when this project was first created.
        parsed_creation_args, _ = parser.parse_known_args(
            project_creation_args)
        self.filter_implementation = parsed_creation_args.filter

        # Parse the cmdline args for the current session.
        parsed_args, unused_args = parser.parse_known_args(
            workflow_cmdline_args)
        self.print_labels_by_slice = parsed_args.print_labels_by_slice
        self.label_search_value = parsed_args.label_search_value
        self.generate_random_labels = parsed_args.generate_random_labels
        self.random_label_value = parsed_args.random_label_value
        self.random_label_count = parsed_args.random_label_count
        self.retrain = parsed_args.retrain
        self.tree_count = parsed_args.tree_count
        self.variable_importance_path = parsed_args.variable_importance_path
        self.label_proportion = parsed_args.label_proportion

        # The filter implementation stored at creation time always wins.
        if parsed_args.filter and parsed_args.filter != parsed_creation_args.filter:
            logger.error(
                "Ignoring new --filter setting.  Filter implementation cannot be changed after initial project creation."
            )

        # Applets for training (interactive) workflow
        self.dataSelectionApplet = self.createDataSelectionApplet()
        opDataSelection = self.dataSelectionApplet.topLevelOperator

        # see role constants, above
        opDataSelection.DatasetRoles.setValue(
            PixelClassificationWorkflow.ROLE_NAMES)

        self.featureSelectionApplet = self.createFeatureSelectionApplet()

        self.pcApplet = self.createPixelClassificationApplet()
        opClassify = self.pcApplet.topLevelOperator

        self.dataExportApplet = PixelClassificationDataExportApplet(
            self, "Prediction Export")
        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.PmapColors.connect(opClassify.PmapColors)
        opDataExport.LabelNames.connect(opClassify.LabelNames)
        opDataExport.WorkingDirectory.connect(opDataSelection.WorkingDirectory)
        opDataExport.SelectionNames.setValue(self.EXPORT_NAMES)

        # Expose for shell
        self._applets.append(self.dataSelectionApplet)
        self._applets.append(self.featureSelectionApplet)
        self._applets.append(self.pcApplet)
        self._applets.append(self.dataExportApplet)

        # Hooks called by the export applet around a complete export run.
        self.dataExportApplet.prepare_for_entire_export = self.prepare_for_entire_export
        self.dataExportApplet.post_process_entire_export = self.post_process_entire_export

        self.batchProcessingApplet = BatchProcessingApplet(
            self, "Batch Processing", self.dataSelectionApplet,
            self.dataExportApplet)

        self._applets.append(self.batchProcessingApplet)
        if unused_args:
            # We parse the export setting args first.  All remaining args are considered input files by the input applet.
            self._batch_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args(
                unused_args)
            self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args(
                unused_args)
        else:
            self._batch_input_args = None
            self._batch_export_args = None

        if unused_args:
            # Logger.warn() is a deprecated alias and no longer exists on
            # recent Python versions; warning() is the supported spelling.
            logger.warning("Unused command-line args: {}".format(unused_args))

    def createDataSelectionApplet(self):
        """
        Build the applet used to select the input data.

        Subclasses can override this if they want to initialize the
        DataSelectionApplet with special parameters.
        """
        instructions = "Select your input data using the 'Raw Data' tab shown on the right"
        applet_options = dict(supportIlastik05Import=True,
                              instructionText=instructions)
        return DataSelectionApplet(self, "Input Data", "Input Data",
                                   **applet_options)

    def createFeatureSelectionApplet(self):
        """
        Build the feature selection applet, using ``self.filter_implementation``.

        Subclasses can override this to return their own type of
        FeatureSelectionApplet.
        NOTE: The applet returned here must have the same interface as the
              regular FeatureSelectionApplet.
        """
        applet_name = "Feature Selection"
        project_group = "FeatureSelections"
        return FeatureSelectionApplet(
            self, applet_name, project_group, self.filter_implementation)

    def createPixelClassificationApplet(self):
        """
        Build the pixel classification applet.

        Subclasses can override this to return their own type of
        PixelClassificationApplet.
        NOTE: The applet returned here must have the same interface as the
              regular PixelClassificationApplet.
        """
        classification_applet = PixelClassificationApplet(
            self, "PixelClassification")
        return classification_applet

    def prepareForNewLane(self, laneIndex):
        """
        Overridden from Workflow base class.
        Called immediately before a new lane is added to the workflow.

        Adding a lane sends dirty notifications through the whole graph, so
        the classifier gets marked 'dirty' although it is still usable.
        Remember it here, while it is still clean, so that
        handleNewLanesAdded() can put it back.
        """
        classifier_cache = self.pcApplet.topLevelOperator.classifier_cache
        classifier_usable = (classifier_cache.Output.ready()
                             and not classifier_cache._dirty)
        self.stored_classifier = (
            classifier_cache.Output.value if classifier_usable else None)

    def handleNewLanesAdded(self):
        """
        Overridden from Workflow base class.
        Called immediately after a new lane is added to the workflow and initialized.

        Puts back the classifier that prepareForNewLane() stored, if any.
        """
        saved_classifier = self.stored_classifier
        if not saved_classifier:
            return

        self.pcApplet.topLevelOperator.classifier_cache.forceValue(
            saved_classifier)
        # Release reference
        self.stored_classifier = None

    def connectLane(self, laneIndex):
        """
        Wire up the operators of lane ``laneIndex``.

        Connects data selection to feature selection and classification,
        feature selection to classification, and both data selection and
        classification to the export operator.
        """
        # Lane views of each applet's top-level operator
        data_lane = self.dataSelectionApplet.topLevelOperator.getLane(laneIndex)
        features_lane = self.featureSelectionApplet.topLevelOperator.getLane(
            laneIndex)
        classify_lane = self.pcApplet.topLevelOperator.getLane(laneIndex)
        export_lane = self.dataExportApplet.topLevelOperator.getLane(
            laneIndex)

        # The input image feeds the feature operator and, for display,
        # the classification operator.
        features_lane.InputImage.connect(data_lane.Image)
        classify_lane.InputImages.connect(data_lane.Image)

        # Prediction masks are only connected in debug mode.
        if ilastik_config.getboolean('ilastik', 'debug'):
            classify_lane.PredictionMasks.connect(
                data_lane.ImageGroup[self.DATA_ROLE_PREDICTION_MASK])

        # Feature images are used for training and prediction.
        classify_lane.FeatureImages.connect(features_lane.OutputImage)
        classify_lane.CachedFeatureImages.connect(
            features_lane.CachedOutputImage)

        # Data export: raw data and its dataset info
        raw_image = data_lane.ImageGroup[self.DATA_ROLE_RAW]
        export_lane.RawData.connect(raw_image)
        export_lane.RawDatasetInfo.connect(
            data_lane.DatasetGroup[self.DATA_ROLE_RAW])
        export_lane.ConstraintDataset.connect(raw_image)

        # Data export: one input per entry of EXPORT_NAMES, in that order
        export_lane.Inputs.resize(len(self.EXPORT_NAMES))
        export_sources = (
            classify_lane.HeadlessPredictionProbabilities,
            classify_lane.SimpleSegmentation,
            classify_lane.HeadlessUncertaintyEstimate,
            classify_lane.FeatureImages,
            classify_lane.LabelImages,
        )
        for position, source in enumerate(export_sources):
            export_lane.Inputs[position].connect(source)

        # Every export input must have been connected.
        for export_input in export_lane.Inputs:
            assert export_input.partner is not None

    def handleAppletStateUpdateRequested(self):
        """
        Overridden from Workflow base class
        Called when an applet has fired the :py:attr:`Applet.appletStateUpdateRequested`

        Decides which applets the shell enables, based on whether input data,
        feature images and predictions are available, whether live update is
        active and whether batch processing is running.
        """
        # If no data, nothing else is ready.
        data_operator = self.dataSelectionApplet.topLevelOperator
        input_ready = (len(data_operator.ImageGroup) > 0
                       and not self.dataSelectionApplet.busy)

        feature_images = self.featureSelectionApplet.topLevelOperator.OutputImage
        features_ready = (
            input_ready
            and len(feature_images) > 0
            and feature_images[0].ready()
            and (TinyVector(feature_images[0].meta.shape) > 0).all()
        )

        export_operator = self.dataExportApplet.topLevelOperator
        classify_operator = self.pcApplet.topLevelOperator
        classifier_cache = classify_operator.classifier_cache

        # A cache that is fixed at its current value but holds no classifier
        # cannot produce predictions.
        invalid_classifier = (
            classifier_cache.fixAtCurrent.value
            and classifier_cache.Output.ready()
            and classifier_cache.Output.value is None
        )

        predictions_ready = (
            features_ready
            and not invalid_classifier
            and len(export_operator.Inputs) > 0
            and export_operator.Inputs[0][0].ready()
            and (TinyVector(export_operator.Inputs[0][0].meta.shape) > 0).all()
        )

        # Problems can occur if the features or input data are changed during live update mode.
        # Don't let the user do that.
        live_update_active = not classify_operator.FreezePredictions.value
        predictions_frozen = not live_update_active

        # The user isn't allowed to touch anything while batch processing is running.
        batch_processing_busy = self.batchProcessingApplet.busy
        batch_idle = not batch_processing_busy

        set_enabled = self._shell.setAppletEnabled
        set_enabled(self.dataSelectionApplet,
                    predictions_frozen and batch_idle)
        set_enabled(self.featureSelectionApplet,
                    input_ready and predictions_frozen and batch_idle)
        set_enabled(self.pcApplet, features_ready and batch_idle)
        set_enabled(self.dataExportApplet, predictions_ready and batch_idle)

        if self.batchProcessingApplet is not None:
            set_enabled(self.batchProcessingApplet,
                        predictions_ready and batch_idle)

        # Lastly, check for certain "busy" conditions, during which we
        #  should prevent the shell from closing the project.
        busy = any([
            self.dataSelectionApplet.busy,
            self.featureSelectionApplet.busy,
            self.dataExportApplet.busy,
            self.batchProcessingApplet.busy,
        ])
        self._shell.enableProjectChanges(not busy)

    def onProjectLoaded(self, projectManager):
        """
        Overridden from Workflow base class.  Called by the Project Manager.

        If the user provided command-line arguments, use them to configure
        the workflow for batch mode and export all results.
        (This workflow's headless mode supports only batch mode for now.)

        The command-line options are applied in this order: random label
        generation (followed by a project save), label counts per slice,
        classifier factory settings, forced retraining, export configuration
        and finally the batch export itself.

        :param projectManager: used to save the project after a forced retrain
        """
        if self.generate_random_labels:
            self._generate_random_labels(self.random_label_count,
                                         self.random_label_value)
            logger.info("Saving project...")
            self._shell.projectManager.saveProject()
            logger.info("Done.")

        if self.print_labels_by_slice:
            self._print_labels_by_slice(self.label_search_value)

        if self._headless:
            # In headless mode, let's see the messages from the training operator.
            logging.getLogger(
                "lazyflow.operators.classifierOperators").setLevel(
                    logging.DEBUG)

        if self.variable_importance_path:
            classifier_factory = self.pcApplet.topLevelOperator.opTrain.ClassifierFactory.value
            classifier_factory.set_variable_importance_path(
                self.variable_importance_path)

        if self.tree_count:
            classifier_factory = self.pcApplet.topLevelOperator.opTrain.ClassifierFactory.value
            classifier_factory.set_num_trees(self.tree_count)

        if self.label_proportion:
            classifier_factory = self.pcApplet.topLevelOperator.opTrain.ClassifierFactory.value
            classifier_factory.set_label_proportion(self.label_proportion)

        # The factory object was modified in place above, so the slot has to
        # be marked dirty explicitly.
        if self.tree_count or self.label_proportion:
            self.pcApplet.topLevelOperator.ClassifierFactory.setDirty()

        if self.retrain:
            self._force_retrain_classifier(projectManager)

        # Configure the data export operator.
        if self._batch_export_args:
            self.dataExportApplet.configure_operator_with_parsed_args(
                self._batch_export_args)

        if self._batch_input_args and self.pcApplet.topLevelOperator.classifier_cache._dirty:
            # Logger.warn() is a deprecated alias and no longer exists on
            # recent Python versions; warning() is the supported spelling.
            logger.warning(
                "Your project file has no classifier.  A new classifier will be trained for this run."
            )

        if self._headless and self._batch_input_args and self._batch_export_args:
            logger.info("Beginning Batch Processing")
            self.batchProcessingApplet.run_export_from_parsed_args(
                self._batch_input_args)
            logger.info("Completed Batch Processing")

    def prepare_for_entire_export(self):
        """
        Assigned to DataExportApplet.prepare_for_entire_export

        Remembers the current FreezePredictions value in
        ``self.freeze_status`` and then sets FreezePredictions to False.
        """
        freeze_slot = self.pcApplet.topLevelOperator.FreezePredictions
        self.freeze_status = freeze_slot.value
        self.pcApplet.topLevelOperator.FreezePredictions.setValue(False)

    def post_process_entire_export(self):
        """
        Assigned to DataExportApplet.post_process_entire_export

        Sets FreezePredictions back to the value kept in
        ``self.freeze_status``.
        """
        freeze_slot = self.pcApplet.topLevelOperator.FreezePredictions
        freeze_slot.setValue(self.freeze_status)

    def _force_retrain_classifier(self, projectManager):
        """
        Retrain the classifier and save the project.

        Useful if the stored labels were changed outside ilastik.

        :param projectManager: its ``saveProject(force_all_save=False)`` is
            called once the classifier has been requested
        """
        classify_operator = self.pcApplet.topLevelOperator

        # Mark the classifier factory dirty so the classifier must be retrained.
        classify_operator.opTrain.ClassifierFactory.setDirty()

        # Requesting the classifier value is what forces the training.
        classify_operator.FreezePredictions.setValue(False)
        _ = classify_operator.Classifier.value

        # store new classifier to project file
        projectManager.saveProject(force_all_save=False)

    def _print_labels_by_slice(self, search_value):
        """
        Iterate over each label image in the project and print the number of labels present on each Z-slice of the image.
        (This is a special feature requested by the FlyEM proofreaders.)

        :param search_value: if truthy, only pixels equal to this value are
            counted; otherwise every non-zero pixel is counted
        """
        opTopLevelClassify = self.pcApplet.topLevelOperator
        project_label_count = 0
        for image_index, label_slot in enumerate(
                opTopLevelClassify.LabelImages):
            tagged_shape = label_slot.meta.getTaggedShape()
            if 'z' not in tagged_shape:
                logger.error(
                    "Can't print label counts by Z-slices.  Image #{} has no Z-dimension."
                    .format(image_index))
                continue

            logger.info("Label counts in Z-slices of Image #{}:".format(
                image_index))
            # Dict key views have no index() method on Python 3, so the axis
            # keys are copied into a list first.  The position of 'z' is the
            # same for every slice, so it is looked up once.
            z_axis = list(tagged_shape.keys()).index('z')
            slicing = [slice(None)] * len(tagged_shape)
            blank_slices = []
            image_label_count = 0
            for z in range(tagged_shape['z']):
                slicing[z_axis] = slice(z, z + 1)
                label_slice = label_slot[slicing].wait()
                if search_value:
                    count = (label_slice == search_value).sum()
                else:
                    count = (label_slice != 0).sum()
                if count > 0:
                    logger.info("Z={}: {}".format(z, count))
                    image_label_count += count
                else:
                    blank_slices.append(z)
            project_label_count += image_label_count

            if len(blank_slices) > 20:
                # Don't list the blank slices if there were a lot of them.
                logger.info("Image #{} has {} blank slices.".format(
                    image_index, len(blank_slices)))
            elif len(blank_slices) > 0:
                logger.info("Image #{} has {} blank slices: {}".format(
                    image_index, len(blank_slices), blank_slices))
            else:
                logger.info(
                    "Image #{} has no blank slices.".format(image_index))
            logger.info("Total labels for Image #{}: {}".format(
                image_index, image_label_count))
        logger.info("Total labels for project: {}".format(project_label_count))

    def _generate_random_labels(self, labels_per_image, label_value):
        """
        Inject random labels into the project file.
        (This is a special feature requested by the FlyEM proofreaders.)

        :param labels_per_image: number of pixels to label in every image
        :param label_value: the label value written to the chosen pixels;
            label names are added until at least this many exist
        """
        logger.info("Injecting {} labels of value {} into all images.".format(
            labels_per_image, label_value))
        opTopLevelClassify = self.pcApplet.topLevelOperator

        # Work on a copy so the list held by the slot is not changed in place.
        label_names = copy.copy(opTopLevelClassify.LabelNames.value)
        while len(label_names) < label_value:
            label_names.append("Label {}".format(len(label_names) + 1))

        opTopLevelClassify.LabelNames.setValue(label_names)

        for image_index in range(len(opTopLevelClassify.LabelImages)):
            logger.info("Injecting labels into image #{}".format(image_index))
            # For reproducibility of label generation
            SEED = 1
            numpy.random.seed([SEED, image_index])

            label_input_slot = opTopLevelClassify.LabelInputs[image_index]
            label_output_slot = opTopLevelClassify.LabelImages[image_index]

            shape = label_output_slot.meta.shape
            random_labels = numpy.zeros(shape=shape, dtype=numpy.uint8)
            num_pixels = len(random_labels.flat)
            current_progress = -1
            for sample_index in range(labels_per_image):
                flat_index = numpy.random.randint(0, num_pixels)
                # Don't overwrite existing labels
                # Keep looking until we find a blank pixel
                # NOTE: with a non-zero label_value this search never ends if
                # more labels are requested than the image has pixels.
                while random_labels.flat[flat_index]:
                    flat_index = numpy.random.randint(0, num_pixels)
                random_labels.flat[flat_index] = label_value

                # Print progress every 10%
                # (integer arithmetic; the previous floor division of the
                # fraction always produced 0)
                progress = 10 * ((10 * sample_index) // labels_per_image)
                if progress != current_progress:
                    current_progress = progress
                    sys.stdout.write("{}% ".format(current_progress))
                    sys.stdout.flush()

            sys.stdout.write("100%\n")
            # Write into the operator
            label_input_slot[fullSlicing(shape)] = random_labels

        logger.info("Done injecting labels")

    def getHeadlessOutputSlot(self, slotId):
        """
        Not used by the regular app.
        Only used for special cluster scripts.

        :param slotId: one of "Predictions", "PredictionsUint8",
            "BatchPredictions" or "BatchPredictionsUint8"
        :return: the matching output slot
        :raises Exception: if ``slotId`` is none of the names above
        """
        # "Regular" slots use the images the user selected as input data;
        # "Batch" slots use the images the user selected as batch inputs.
        # The owning operator is only looked up for the entry that matches.
        regular = lambda: self.pcApplet.topLevelOperator
        batch = lambda: self.opBatchPredictionPipeline
        known_slots = (
            ("Predictions", regular, "HeadlessPredictionProbabilities"),
            ("PredictionsUint8", regular, "HeadlessUint8PredictionProbabilities"),
            ("BatchPredictions", batch, "HeadlessPredictionProbabilities"),
            ("BatchPredictionsUint8", batch, "HeadlessUint8PredictionProbabilities"),
        )
        for known_id, get_operator, slot_name in known_slots:
            if slotId == known_id:
                return getattr(get_operator(), slot_name)

        raise Exception("Unknown headless output slot")
예제 #44
0
class NewAutocontextWorkflowBase(Workflow):
    """
    Base workflow that chains several (feature selection, pixel classification)
    stages.  The number of stages is passed to ``__init__`` as ``n_stages``.

    In ``connectLane`` the first stage is fed the raw image; every later stage
    is fed the raw image stacked (along the channel axis) with the prediction
    probabilities of the stage before it.
    """

    workflowName = "New Autocontext Base"
    defaultAppletIndex = 0  # show DataSelection by default

    # Positions of the dataset roles, in the order in which __init__ passes
    # the role names ("Raw Data", "Prediction Mask") to the data selection operator.
    DATA_ROLE_RAW = 0
    DATA_ROLE_PREDICTION_MASK = 1

    # First export names must match these for the export GUI, because we re-use the ordinary PC gui
    # (See PixelClassificationDataExportGui.)
    # connectLane connects the export inputs of every stage in exactly this order.
    EXPORT_NAMES_PER_STAGE = [
        "Probabilities", "Simple Segmentation", "Uncertainty", "Features",
        "Labels", "Input"
    ]

    @property
    def applets(self):
        """The list of applets that __init__ assembled for the shell."""
        return self._applets

    @property
    def imageNameListSlot(self):
        """The ``ImageName`` slot of the data selection applet's top-level operator."""
        return self.dataSelectionApplet.topLevelOperator.ImageName

    def __init__(self, shell, headless, workflow_cmdline_args,
                 project_creation_args, n_stages, *args, **kwargs):
        """
        Create all applets of the workflow and expose them to the shell:
        data selection, ``n_stages`` (feature selection, pixel classification)
        pairs, prediction export and batch processing.

        n_stages: How many iterations of feature selection and pixel classification should be inserted into the workflow.
                  Must be at least 1, because the export applet is connected to the last stage.

        All other params are just as in PixelClassificationWorkflow
        """
        # Create a graph to be shared by all operators
        graph = Graph()
        super(NewAutocontextWorkflowBase, self).__init__(
            shell,
            headless,
            workflow_cmdline_args,
            project_creation_args,
            graph=graph,
            *args,
            **kwargs)
        self.stored_classifers = []
        self._applets = []
        self._workflow_cmdline_args = workflow_cmdline_args

        # Parse workflow-specific command-line args
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "--retrain",
            help="Re-train the classifier based on labels stored in project file, and re-save.",
            action="store_true",
        )

        # The creation args were saved to the project file when this project was first created.
        # No setting is read from them here; they are only run through the parser.
        parser.parse_known_args(project_creation_args)

        # Parse the cmdline args for the current session.
        parsed_args, unused_args = parser.parse_known_args(workflow_cmdline_args)
        self.retrain = parsed_args.retrain

        self.dataSelectionApplet = self.createDataSelectionApplet()
        opDataSelection = self.dataSelectionApplet.topLevelOperator

        # Order must agree with DATA_ROLE_RAW / DATA_ROLE_PREDICTION_MASK
        role_names = ["Raw Data", "Prediction Mask"]
        opDataSelection.DatasetRoles.setValue(role_names)

        self.featureSelectionApplets = []
        self.pcApplets = []
        for i in range(n_stages):
            self.featureSelectionApplets.append(self.createFeatureSelectionApplet(i))
            self.pcApplets.append(self.createPixelClassificationApplet(i))
        opFinalClassify = self.pcApplets[-1].topLevelOperator

        # If *any* stage enters 'live update' mode, make sure they all enter live update mode.
        def sync_freeze_predictions_settings(slot, *args):
            freeze_predictions = slot.value
            for pcApplet in self.pcApplets:
                pcApplet.topLevelOperator.FreezePredictions.setValue(freeze_predictions)

        for pcApplet in self.pcApplets:
            pcApplet.topLevelOperator.FreezePredictions.notifyDirty(sync_freeze_predictions_settings)

        self.dataExportApplet = PixelClassificationDataExportApplet(self, "Prediction Export")
        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.PmapColors.connect(opFinalClassify.PmapColors)
        opDataExport.LabelNames.connect(opFinalClassify.LabelNames)
        opDataExport.WorkingDirectory.connect(opDataSelection.WorkingDirectory)

        # Export names are listed last stage first; connectLane relies on the same order.
        self.EXPORT_NAMES = []
        for stage_index in reversed(range(n_stages)):
            self.EXPORT_NAMES += [
                "{} Stage {}".format(name, stage_index + 1)
                for name in self.EXPORT_NAMES_PER_STAGE
            ]

        # And finally, one last item for *all* probabilities from all stages.
        self.EXPORT_NAMES += ["Probabilities All Stages"]
        opDataExport.SelectionNames.setValue(self.EXPORT_NAMES)

        # Expose for shell: features and classification applets alternate, stage by stage.
        self._applets.append(self.dataSelectionApplet)
        self._applets += itertools.chain.from_iterable(
            zip(self.featureSelectionApplets, self.pcApplets))
        self._applets.append(self.dataExportApplet)

        self.dataExportApplet.prepare_for_entire_export = self.prepare_for_entire_export
        self.dataExportApplet.post_process_entire_export = self.post_process_entire_export

        self.batchProcessingApplet = BatchProcessingApplet(
            self, "Batch Processing", self.dataSelectionApplet, self.dataExportApplet)

        self._applets.append(self.batchProcessingApplet)
        if unused_args:
            # We parse the export setting args first.  All remaining args are considered input files by the input applet.
            self._batch_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args(
                unused_args)
            self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args(
                unused_args)
        else:
            self._batch_input_args = None
            self._batch_export_args = None

        if unused_args:
            logger.warning("Unused command-line args: {}".format(unused_args))

    def createDataSelectionApplet(self):
        """
        Build the DataSelectionApplet used by this workflow.

        Subclasses may override this if they want to initialize the
        DataSelectionApplet with special parameters.
        """
        instruction_text = "Select your input data using the 'Raw Data' tab shown on the right"

        # Accepted axis orders all end in the channel axis: the two 2D orders,
        # followed by every ordering of three, then of all four, of t, z, y, x.
        axis_orders = ["yxc", "xyc"]
        for axis_count in (3, 4):
            axis_orders.extend(
                "".join(axes) + "c"
                for axes in itertools.permutations("tzyx", axis_count))

        applet_options = dict(
            supportIlastik05Import=False,
            instructionText=instruction_text,
            forceAxisOrder=axis_orders,
        )
        return DataSelectionApplet(self, "Input Data", "Input Data", **applet_options)

    def createFeatureSelectionApplet(self, index):
        """
        Create the FeatureSelectionApplet for stage ``index`` (0-based).

        Can be overridden by subclasses, if they want to return their own type of FeatureSelectionApplet.
        NOTE: The applet returned here must have the same interface as the regular FeatureSelectionApplet.
              (If it looks like a duck...)
        """
        # The first stage keeps the group name of the pixel classification workflow,
        # in case the user uses "Import Project"; later stages get a numbered group.
        if index > 0:
            group_name = "FeatureSelections{index:02d}".format(index=index)
        else:
            group_name = "FeatureSelections"

        stage_applet = FeatureSelectionApplet(self, "Feature Selection", group_name)

        # Tag the operator name with the stage index.
        top_level_op = stage_applet.topLevelOperator
        top_level_op.name = top_level_op.name + "{}".format(index)
        return stage_applet

    def createPixelClassificationApplet(self, index=0):
        """
        Create the PixelClassificationApplet for stage ``index`` (0-based).

        Can be overridden by subclasses, if they want to return their own type of PixelClassificationApplet.
        NOTE: The applet returned here must have the same interface as the regular PixelClassificationApplet.
              (If it looks like a duck...)
        """
        # The first stage keeps the group name of the pixel classification workflow,
        # in case the user uses "Import Project"; later stages get a numbered group.
        if index > 0:
            group_name = "PixelClassification{index:02d}".format(index=index)
        else:
            group_name = "PixelClassification"

        stage_applet = PixelClassificationApplet(self, group_name)

        # Tag the operator name with the stage index.
        top_level_op = stage_applet.topLevelOperator
        top_level_op.name = top_level_op.name + "{}".format(index)
        return stage_applet

    def prepareForNewLane(self, laneIndex):
        """
        Overridden from Workflow base class.
        Called immediately before a new lane is added to the workflow.

        Stores the classifier of every stage in ``self.stored_classifers`` so that
        handleNewLanesAdded() can restore them.  This is all-or-nothing: if any
        stage has no usable classifier (cache output not ready, or cache dirty),
        nothing is stored.  handleNewLanesAdded() pairs the stored classifiers
        with the stages by position, so a partial list would hand classifiers
        to the wrong stages.
        """
        # When the new lane is added, dirty notifications will propagate throughout the entire graph.
        # This means the classifier will be marked 'dirty' even though it is still usable.
        # Before that happens, let's store the classifier, so we can restore it in handleNewLanesAdded(), below.
        self.stored_classifers = []
        for pcApplet in self.pcApplets:
            classifier_cache = pcApplet.topLevelOperator.classifier_cache
            if classifier_cache.Output.ready() and not classifier_cache._dirty:
                self.stored_classifers.append(classifier_cache.Output.value)
            else:
                # Give up entirely.  Continuing would append the classifiers of
                # later stages at the front of the list, i.e. at the wrong positions.
                self.stored_classifers = []
                break

    def handleNewLanesAdded(self):
        """
        Overridden from Workflow base class.
        Called immediately after a new lane is added to the workflow and initialized.

        Forces the classifiers saved by prepareForNewLane() (if any) back into
        the classifier caches, stage by stage, then drops the saved references.
        """
        if not self.stored_classifers:
            # Nothing was saved, so there is nothing to restore.
            return

        for stage_applet, saved_classifier in zip(self.pcApplets, self.stored_classifers):
            stage_cache = stage_applet.topLevelOperator.classifier_cache
            stage_cache.forceValue(saved_classifier)

        # Release references
        self.stored_classifers = []

    def connectLane(self, laneIndex):
        """
        Connect the operators of all applets for the image lane ``laneIndex``.

        - The raw image feeds the first stage's feature and classification operators.
        - For every later stage, the raw image is stacked along the channel axis
          with the previous stage's ``PredictionProbabilitiesAutocontext``; the
          stacked image feeds that stage's feature and classification operators.
        - The export inputs are connected in the order of ``self.EXPORT_NAMES``
          (last stage first, ``EXPORT_NAMES_PER_STAGE`` within a stage).
        - The very last export input receives the probabilities of all stages,
          stacked along the channel axis.
        """
        # Get a handle to each operator
        opData = self.dataSelectionApplet.topLevelOperator.getLane(laneIndex)
        opFirstFeatures = self.featureSelectionApplets[
            0].topLevelOperator.getLane(laneIndex)
        opFirstClassify = self.pcApplets[0].topLevelOperator.getLane(laneIndex)
        # NOTE(review): opFinalClassify is not used anywhere below.
        opFinalClassify = self.pcApplets[-1].topLevelOperator.getLane(
            laneIndex)
        opDataExport = self.dataExportApplet.topLevelOperator.getLane(
            laneIndex)

        # Currently a no-op: the dtype constraint below is commented out.
        def checkConstraints(*_):
            # if (opData.Image.meta.dtype in [np.uint8, np.uint16]) == False:
            #    msg = "The Autocontext Workflow only supports 8-bit images (UINT8 pixel type)\n"\
            #          "or 16-bit images (UINT16 pixel type)\n"\
            #          "Your image has a pixel type of {}.  Please convert your data to UINT8 and try again."\
            #          .format( str(np.dtype(opData.Image.meta.dtype)) )
            #    raise DatasetConstraintError( "Autocontext Workflow", msg, unfixable=True )
            pass

        opData.Image.notifyReady(checkConstraints)

        # Input Image -> Feature Op
        #         and -> Classification Op (for display)
        opFirstFeatures.InputImage.connect(opData.Image)
        opFirstClassify.InputImages.connect(opData.Image)

        # Feature Images -> Classification Op (for training, prediction)
        opFirstClassify.FeatureImages.connect(opFirstFeatures.OutputImage)
        opFirstClassify.CachedFeatureImages.connect(
            opFirstFeatures.CachedOutputImage)

        # Pair every stage with the stage that follows it:
        # (stage i classifier, stage i+1 features, stage i+1 classifier)
        upstreamPcApplets = self.pcApplets[0:-1]
        downstreamFeatureApplets = self.featureSelectionApplets[1:]
        downstreamPcApplets = self.pcApplets[1:]

        for (upstreamPcApplet, downstreamFeaturesApplet,
             downstreamPcApplet) in zip(upstreamPcApplets,
                                        downstreamFeatureApplets,
                                        downstreamPcApplets):

            opUpstreamClassify = upstreamPcApplet.topLevelOperator.getLane(
                laneIndex)
            opDownstreamFeatures = downstreamFeaturesApplet.topLevelOperator.getLane(
                laneIndex)
            opDownstreamClassify = downstreamPcApplet.topLevelOperator.getLane(
                laneIndex)

            # Connect label inputs (all are connected together).
            # opDownstreamClassify.LabelInputs.connect( opUpstreamClassify.LabelInputs )

            # Connect data path
            assert opData.Image.meta.dtype == opUpstreamClassify.PredictionProbabilitiesAutocontext.meta.dtype, (
                "Probability dtype needs to match up with input image dtype, got: "
                f"input: {opData.Image.meta.dtype} "
                f"probabilities: {opUpstreamClassify.PredictionProbabilitiesAutocontext.meta.dtype}"
            )
            # Stack raw data and upstream probabilities along the channel axis
            opStacker = OpMultiArrayStacker(parent=self)
            opStacker.Images.resize(2)
            opStacker.Images[0].connect(opData.Image)
            opStacker.Images[1].connect(
                opUpstreamClassify.PredictionProbabilitiesAutocontext)
            opStacker.AxisFlag.setValue("c")

            opDownstreamFeatures.InputImage.connect(opStacker.Output)
            opDownstreamClassify.InputImages.connect(opStacker.Output)
            opDownstreamClassify.FeatureImages.connect(
                opDownstreamFeatures.OutputImage)
            opDownstreamClassify.CachedFeatureImages.connect(
                opDownstreamFeatures.CachedOutputImage)

        # Data Export connections
        opDataExport.RawData.connect(opData.ImageGroup[self.DATA_ROLE_RAW])
        opDataExport.RawDatasetInfo.connect(
            opData.DatasetGroup[self.DATA_ROLE_RAW])
        opDataExport.ConstraintDataset.connect(
            opData.ImageGroup[self.DATA_ROLE_RAW])

        # One input per export name.  Stages are visited last-to-first, matching
        # the order in which __init__ built self.EXPORT_NAMES.
        opDataExport.Inputs.resize(len(self.EXPORT_NAMES))
        for reverse_stage_index, (stage_index, pcApplet) in enumerate(
                reversed(list(enumerate(self.pcApplets)))):
            opPc = pcApplet.topLevelOperator.getLane(laneIndex)
            num_items_per_stage = len(self.EXPORT_NAMES_PER_STAGE)
            opDataExport.Inputs[num_items_per_stage * reverse_stage_index +
                                0].connect(
                                    opPc.HeadlessPredictionProbabilities)
            opDataExport.Inputs[num_items_per_stage * reverse_stage_index +
                                1].connect(opPc.SimpleSegmentation)
            opDataExport.Inputs[num_items_per_stage * reverse_stage_index +
                                2].connect(opPc.HeadlessUncertaintyEstimate)
            opDataExport.Inputs[num_items_per_stage * reverse_stage_index +
                                3].connect(opPc.FeatureImages)
            opDataExport.Inputs[num_items_per_stage * reverse_stage_index +
                                4].connect(opPc.LabelImages)
            opDataExport.Inputs[
                num_items_per_stage * reverse_stage_index + 5].connect(
                    opPc.InputImages
                )  # Input must come last due to an assumption in PixelClassificationDataExportGui

        # One last export slot for all probabilities, all stages
        opAllStageStacker = OpMultiArrayStacker(parent=self)
        opAllStageStacker.Images.resize(len(self.pcApplets))
        for stage_index, pcApplet in enumerate(self.pcApplets):
            opPc = pcApplet.topLevelOperator.getLane(laneIndex)
            opAllStageStacker.Images[stage_index].connect(
                opPc.HeadlessPredictionProbabilities)
            # NOTE(review): the axis flag is set again on every iteration, always to "c".
            opAllStageStacker.AxisFlag.setValue("c")

        # The ideal_blockshape metadata field will be bogus, so just eliminate it
        # (Otherwise, the channels could be split up in an unfortunate way.)
        opMetadataOverride = OpMetadataInjector(parent=self)
        opMetadataOverride.Input.connect(opAllStageStacker.Output)
        opMetadataOverride.Metadata.setValue({"ideal_blockshape": None})

        opDataExport.Inputs[-1].connect(opMetadataOverride.Output)

        # Sanity check: every export input must have been connected above.
        for slot in opDataExport.Inputs:
            assert slot.upstream_slot is not None

    def handleAppletStateUpdateRequested(self):
        """
        Overridden from Workflow base class
        Called when an applet has fired the :py:attr:`Applet.appletStateUpdateRequested`

        Determines, stage by stage, whether input, features and predictions are
        available, then tells the shell which applets to enable and whether
        project changes are currently allowed.
        """
        opDataSelection = self.dataSelectionApplet.topLevelOperator

        # First, determine various 'ready' states for each pixel classification stage (features+prediction)
        StageFlags = collections.namedtuple(
            "StageFlags",
            "input_ready features_ready invalid_classifier predictions_ready live_update_active"
        )
        stage_flags = []
        for stage_index, (featureSelectionApplet, pcApplet) in enumerate(
                zip(self.featureSelectionApplets, self.pcApplets)):
            if stage_index == 0:
                # If no data, nothing else is ready.
                input_ready = (len(opDataSelection.ImageGroup) > 0
                               and not self.dataSelectionApplet.busy)
            else:
                # Later stages consume the predictions of the stage before them.
                input_ready = stage_flags[stage_index - 1].predictions_ready

            featureOutput = featureSelectionApplet.topLevelOperator.OutputImage
            features_ready = (
                input_ready and len(featureOutput) > 0
                and featureOutput[0].ready()
                and (TinyVector(featureOutput[0].meta.shape) > 0).all())

            opPixelClassification = pcApplet.topLevelOperator
            classifier_cache = opPixelClassification.classifier_cache
            invalid_classifier = (
                classifier_cache.fixAtCurrent.value
                and classifier_cache.Output.ready()
                and classifier_cache.Output.value is None)

            predictions = opPixelClassification.HeadlessPredictionProbabilities
            predictions_ready = (
                features_ready and not invalid_classifier
                and len(predictions) > 0
                and predictions[0].ready()
                and (TinyVector(predictions[0].meta.shape) > 0).all())

            live_update_active = not opPixelClassification.FreezePredictions.value

            stage_flags.append(
                StageFlags(input_ready, features_ready, invalid_classifier,
                           predictions_ready, live_update_active))

        # Export and batch processing both depend on the predictions of the last stage.
        final_predictions_ready = stage_flags[-1].predictions_ready

        # Problems can occur if the features or input data are changed during live update mode.
        # Don't let the user do that.
        any_live_update = any(flags.live_update_active for flags in stage_flags)

        # The user isn't allowed to touch anything while batch processing is running.
        batch_processing_busy = self.batchProcessingApplet.busy

        self._shell.setAppletEnabled(
            self.dataSelectionApplet,
            not any_live_update and not batch_processing_busy)

        for stage_index, (featureSelectionApplet, pcApplet) in enumerate(
                zip(self.featureSelectionApplets, self.pcApplets)):
            flags = stage_flags[stage_index]
            downstream_live_update = any(
                later.live_update_active
                for later in stage_flags[stage_index + 1:])

            self._shell.setAppletEnabled(
                featureSelectionApplet,
                flags.input_ready
                and not flags.live_update_active
                and not downstream_live_update
                and not batch_processing_busy,
            )

            # No check for downstream live update here: live update modes are
            # synced, and labels can't be added in live update.
            self._shell.setAppletEnabled(
                pcApplet,
                flags.features_ready and not batch_processing_busy,
            )

        self._shell.setAppletEnabled(
            self.dataExportApplet,
            final_predictions_ready and not batch_processing_busy)
        self._shell.setAppletEnabled(
            self.batchProcessingApplet,
            final_predictions_ready and not batch_processing_busy)

        # Lastly, check for certain "busy" conditions, during which we
        #  should prevent the shell from closing the project.
        busy = False
        busy |= self.dataSelectionApplet.busy
        busy |= any(applet.busy for applet in self.featureSelectionApplets)
        busy |= self.dataExportApplet.busy
        busy |= self.batchProcessingApplet.busy
        self._shell.enableProjectChanges(not busy)

    def onProjectLoaded(self, projectManager):
        """
        Overridden from Workflow base class.  Called by the Project Manager.

        If the user provided command-line arguments, use them to configure
        the workflow for batch mode and export all results.
        (This workflow's headless mode supports only batch mode for now.)
        """
        if self._headless:
            # In headless mode, let's see the messages from the training operator.
            training_logger = logging.getLogger("lazyflow.operators.classifierOperators")
            training_logger.setLevel(logging.DEBUG)

        if self.retrain:
            self._force_retrain_classifiers(projectManager)

        # Configure the data export operator.
        if self._batch_export_args:
            self.dataExportApplet.configure_operator_with_parsed_args(self._batch_export_args)

        # Warn (once) if batch input was requested while some classifier cache is dirty.
        if self._batch_input_args and any(
                stage_applet.topLevelOperator.classifier_cache._dirty
                for stage_applet in self.pcApplets):
            logger.warning(
                "At least one of your classifiers is not yet trained.  "
                "A new classifier will be trained for this run.")

        # The export itself only runs headless, and only with both kinds of arguments.
        if not (self._headless and self._batch_input_args and self._batch_export_args):
            return

        logger.info("Beginning Batch Processing")
        self.batchProcessingApplet.run_export_from_parsed_args(self._batch_input_args)
        logger.info("Completed Batch Processing")

    def prepare_for_entire_export(self):
        """
        Hook run before an export: choose which feature caches to bypass,
        then remember and clear the FreezePredictions setting of every stage.
        """
        opDataExport = self.dataExportApplet.topLevelOperator
        selected_index = opDataExport.InputSelection.value
        selected_name = opDataExport.SelectionNames.value[selected_index]

        # While exporting, we don't want to cache any data,
        # UNLESS we're exporting from more than one stage at a time.
        # In that case, the caches help avoid unnecessary work (except for the last stage)
        exporting_all_stages = "all stages" in selected_name.lower()
        if exporting_all_stages:
            applets_to_bypass = [self.featureSelectionApplets[-1]]
        else:
            applets_to_bypass = self.featureSelectionApplets
        for feature_applet in applets_to_bypass:
            feature_applet.topLevelOperator.BypassCache.setValue(True)

        # Unfreeze the classifier caches (ensure that we're exporting based on up-to-date labels)
        # Each stage is read and then unfrozen before the next stage is read.
        self.freeze_statuses = []
        for stage_applet in self.pcApplets:
            freeze_slot = stage_applet.topLevelOperator.FreezePredictions
            self.freeze_statuses.append(freeze_slot.value)
            freeze_slot.setValue(False)

    def post_process_entire_export(self):
        """
        Hook run after an export: switch the feature caches back on and restore
        the FreezePredictions settings saved by prepare_for_entire_export().
        """
        # While exporting, we disabled caches, but now we can enable them again.
        for feature_applet in self.featureSelectionApplets:
            bypass_slot = feature_applet.topLevelOperator.BypassCache
            bypass_slot.setValue(False)

        # Re-freeze classifier caches (if necessary)
        saved_statuses = zip(self.pcApplets, self.freeze_statuses)
        for stage_applet, was_frozen in saved_statuses:
            stage_applet.topLevelOperator.FreezePredictions.setValue(was_frozen)

    def _force_retrain_classifiers(self, projectManager):
        """
        Retrain the classifiers of all stages and save the project via ``projectManager``.
        """
        stage_operators = [applet.topLevelOperator for applet in self.pcApplets]
        first_stage_op = stage_operators[0]
        last_stage_op = stage_operators[-1]

        # Cause the FIRST classifier to be dirty so it is forced to retrain.
        # (useful if the stored labels were changed outside ilastik)
        first_stage_op.opTrain.ClassifierFactory.setDirty()

        # Unfreeze all classifier caches.
        for stage_op in stage_operators:
            stage_op.FreezePredictions.setValue(False)

        # Request the LAST classifier, which forces training
        _ = last_stage_op.Classifier.value

        # store new classifiers to project file
        projectManager.saveProject(force_all_save=False)

    def menus(self):
        """
        Overridden from Workflow base class

        Returns a one-element list holding the "Autocontext Utilities" menu,
        whose single action starts distribute_labels_from_current_stage().
        """
        from PyQt5.QtWidgets import QMenu

        utilities_menu = QMenu("Autocontext Utilities")
        action = utilities_menu.addAction("Distribute Labels...")
        action.triggered.connect(self.distribute_labels_from_current_stage)

        # Must retain here as a member or else reference disappears and the menu is deleted.
        self._autocontext_menu = utilities_menu
        return [self._autocontext_menu]

    def distribute_labels_from_current_stage(self):
        """
        Distribute labels from the currently viewed stage across other stages.

        The user chooses the destination stages and the mode in a dialog
        (see get_label_distribution_settings):

        - copy: every destination stage receives all labels of the current stage.
        - partition: every labeled pixel is randomly assigned to exactly one of
          the destination stages.  If the current stage is itself a destination,
          its labels are cleared first and replaced by its share.

        Beforehand, the label names, label colors and pmap colors of the current
        stage are copied to each destination stage.

        Nothing is changed if the active page is not a training page (an error
        box is shown), if the user cancels the dialog, or if no destination
        stage was selected.
        """
        current_applet = self._applets[self.shell.currentAppletIndex]
        if current_applet not in self.pcApplets:
            # Late import.
            # (Don't import PyQt in headless mode.)
            from PyQt5.QtWidgets import QMessageBox

            QMessageBox.critical(
                self.shell, "Wrong page selected",
                "The currently active page isn't a Training page.")
            return

        current_stage_index = self.pcApplets.index(current_applet)
        destination_stage_indexes, partition = self.get_label_distribution_settings(
            current_stage_index, num_stages=len(self.pcApplets))
        if destination_stage_indexes is None:
            return  # User cancelled
        if not destination_stage_indexes:
            # Every stage was unchecked: there is nowhere to send labels.
            # (In partition mode, np.random.choice() below would raise on an empty list.)
            return

        opCurrentPixelClassification = current_applet.topLevelOperator
        num_current_stage_classes = len(
            opCurrentPixelClassification.LabelNames.value)

        # Before we get started, make sure the destination stages have the necessary label classes
        for stage_index in destination_stage_indexes:
            # Get this stage's OpPixelClassification
            opPc = self.pcApplets[stage_index].topLevelOperator

            # Overwrite the first entries of each per-label setting with those
            # of the current stage; extra entries of the destination are kept.
            for slot_name in ("LabelColors", "PmapColors", "LabelNames"):
                source_values = getattr(opCurrentPixelClassification, slot_name).value
                destination_slot = getattr(opPc, slot_name)
                merged_values = list(destination_slot.value)
                merged_values[:num_current_stage_classes] = source_values[:num_current_stage_classes]
                destination_slot.setValue(merged_values)

        # For each lane, copy over the labels from the source stage to the destination stages
        for lane_index in range(len(opCurrentPixelClassification.InputImages)):
            opPcLane = opCurrentPixelClassification.getLane(lane_index)

            # Gather all the labels for this lane
            blockwise_labels = {}
            for block_slicing in opPcLane.NonzeroLabelBlocks.value:
                # Convert from slicing to roi-tuple so we can hash it in a dict key
                block_roi = sliceToRoi(block_slicing,
                                       opPcLane.InputImages.meta.shape)
                block_roi = tuple(map(tuple, block_roi))
                blockwise_labels[block_roi] = opPcLane.LabelImages[
                    block_slicing].wait()

            if partition and current_stage_index in destination_stage_indexes:
                # Clear all labels in the current lane, since we'll be overwriting it with a subset of labels
                # FIXME: We could implement a fast function for this in OpCompressedUserLabelArray...
                for label_value in range(1, num_current_stage_classes + 1):
                    opPcLane.opLabelPipeline.opLabelArray.clearLabel(
                        label_value)

            # Now redistribute those labels across all destination stages
            for block_roi, block_labels in blockwise_labels.items():
                nonzero_coords = block_labels.nonzero()

                if partition:
                    # Draw one destination stage per labeled pixel of this block
                    num_labels = len(nonzero_coords[0])
                    destination_stage_map = np.random.choice(
                        destination_stage_indexes, (num_labels, ))

                for stage_index in destination_stage_indexes:
                    if not partition:
                        this_stage_block_labels = block_labels
                    else:
                        # Divide into disjoint partitions
                        # Find the coordinates of the labels destined for this stage
                        this_stage_coords = np.transpose(nonzero_coords)[
                            destination_stage_map == stage_index]
                        this_stage_coords = tuple(
                            this_stage_coords.transpose())

                        # Extract only the labels destined for this stage
                        this_stage_block_labels = np.zeros_like(block_labels)
                        this_stage_block_labels[
                            this_stage_coords] = block_labels[
                                this_stage_coords]

                    # Get the current lane's view of this stage's OpPixelClassification
                    opPc = self.pcApplets[
                        stage_index].topLevelOperator.getLane(lane_index)

                    # Inject
                    opPc.LabelInputs[roiToSlice(
                        *block_roi)] = this_stage_block_labels

    @staticmethod
    def get_label_distribution_settings(source_stage_index, num_stages):
        """
        Show a dialog asking where the labels of stage ``source_stage_index``
        (0-based) should be sent, and how.

        Returns ``(destination_stage_indexes, partition)``: the 0-based indexes
        of the checked stages, and True for "Partition" / False for "Copy".
        Returns ``(None, None)`` if the dialog was rejected.
        """
        # Late import.
        # (Don't import PyQt in headless mode.)
        from PyQt5.QtCore import Qt
        from PyQt5.QtWidgets import (QCheckBox, QDialog, QDialogButtonBox,
                                     QGroupBox, QRadioButton, QVBoxLayout)

        class LabelDistributionOptionsDlg(QDialog):
            """
            A little dialog to let the user specify how the labels should be
            distributed from the current stages to the other stages.
            """
            def __init__(self, source_stage_index, num_stages, *args,
                         **kwargs):
                super().__init__(*args, **kwargs)

                # Stages are presented to the user 1-based.
                self.setWindowTitle(
                    "Distributing from Stage {}".format(source_stage_index + 1))

                # One checkbox per stage.
                self.stage_checkboxes = [
                    QCheckBox("Stage {}".format(stage_number))
                    for stage_number in range(1, num_stages + 1)
                ]

                # By default, send labels back into the current stage, at least.
                self.stage_checkboxes[source_stage_index].setChecked(True)

                checkbox_column = QVBoxLayout()
                for stage_checkbox in self.stage_checkboxes:
                    checkbox_column.addWidget(stage_checkbox)

                destinations_box = QGroupBox(
                    "Send labels from Stage {} to:".format(source_stage_index + 1),
                    self)
                destinations_box.setLayout(checkbox_column)

                # Distribution mode: "Partition" is preselected.
                self.copy_button = QRadioButton("Copy", self)
                self.partition_button = QRadioButton("Partition", self)
                self.partition_button.setChecked(True)
                mode_column = QVBoxLayout()
                mode_column.addWidget(self.copy_button)
                mode_column.addWidget(self.partition_button)

                mode_box = QGroupBox("Distribution Mode", self)
                mode_box.setLayout(mode_column)

                ok_cancel = QDialogButtonBox(Qt.Horizontal, parent=self)
                ok_cancel.setStandardButtons(QDialogButtonBox.Ok
                                             | QDialogButtonBox.Cancel)
                ok_cancel.accepted.connect(self.accept)
                ok_cancel.rejected.connect(self.reject)

                main_column = QVBoxLayout()
                for section in (destinations_box, mode_box, ok_cancel):
                    main_column.addWidget(section)
                self.setLayout(main_column)

            def distribution_mode(self):
                """Return "copy" or "partition", whichever radio button is checked."""
                if self.copy_button.isChecked():
                    return "copy"
                if self.partition_button.isChecked():
                    return "partition"
                assert False, "Shouldn't get here."

            def destination_stages(self):
                """
                Return the list of stage_indexes (0-based) that the user checked.
                """
                checked_indexes = []
                for stage_index, stage_checkbox in enumerate(self.stage_checkboxes):
                    if stage_checkbox.isChecked():
                        checked_indexes.append(stage_index)
                return checked_indexes

        options_dialog = LabelDistributionOptionsDlg(source_stage_index, num_stages)
        if options_dialog.exec_() == QDialog.Rejected:
            return (None, None)

        chosen_stages = options_dialog.destination_stages()
        use_partition = options_dialog.distribution_mode() == "partition"
        return (chosen_stages, use_partition)
예제 #45
0
class StructuredTrackingWorkflowBase( Workflow ):
    workflowName = "Structured Learning Tracking Workflow BASE"

    @property
    def applets(self):
        """
        The list of applets owned by this workflow (``self._applets``,
        filled in ``__init__``).
        """
        owned_applets = self._applets
        return owned_applets

    @property
    def imageNameListSlot(self):
        """
        The ``ImageName`` slot of the data selection applet's top-level operator.
        """
        opDataSelection = self.dataSelectionApplet.topLevelOperator
        return opDataSelection.ImageName

    def __init__( self, shell, headless, workflow_cmdline_args, project_creation_args, *args, **kwargs ):
        """
        Build all applets of the structured learning tracking workflow and
        apply their static configuration.

        ``shell``, ``headless``, ``workflow_cmdline_args`` and
        ``project_creation_args`` are forwarded to the ``Workflow`` base class.
        An optional ``graph`` keyword argument is used as the operator graph;
        if it is absent a new ``Graph`` is created.

        ``workflow_cmdline_args`` is additionally parsed for data-export
        settings, batch-input settings and the ``--testFullAnnotations`` flag.
        The results are stored in ``self._data_export_args``,
        ``self._batch_input_args`` and ``self.testFullAnnotations``.

        This method reads ``self.fromBinary``, which is not assigned here
        (presumably it is provided by subclasses -- confirm).
        """
        # Use the caller's graph if one was handed in; it must not be passed
        # to the base class twice, so it is removed from kwargs.
        graph = kwargs.pop('graph') if 'graph' in kwargs else Graph()

        super(StructuredTrackingWorkflowBase, self).__init__(shell, headless, workflow_cmdline_args, project_creation_args, graph=graph, *args, **kwargs)

        data_instructions = 'Use the "Raw Data" tab to load your intensity image(s).\n\n'
        if self.fromBinary:
            data_instructions += 'Use the "Binary Image" tab to load your segmentation image(s).'
        else:
            data_instructions += 'Use the "Prediction Maps" tab to load your pixel-wise probability image(s).'

        # Create applets
        self.dataSelectionApplet = DataSelectionApplet(self,
            "Input Data",
            "Input Data",
            batchDataGui=False,
            forceAxisOrder=['txyzc'],
            instructionText=data_instructions,
            max_lanes=1)

        # The second dataset role depends on whether the workflow starts from
        # a binary segmentation or from prediction maps.
        opDataSelection = self.dataSelectionApplet.topLevelOperator
        if self.fromBinary:
            opDataSelection.DatasetRoles.setValue( ['Raw Data', 'Binary Image'] )
        else:
            opDataSelection.DatasetRoles.setValue( ['Raw Data', 'Prediction Maps'] )

        # Thresholding is only needed when the input is not already binary.
        if not self.fromBinary:
            self.thresholdTwoLevelsApplet = ThresholdTwoLevelsApplet( self,"Threshold and Size Filter","ThresholdTwoLevels" )

        self.divisionDetectionApplet = ObjectClassificationApplet(workflow=self,
                                                                     name="Division Detection (optional)",
                                                                     projectFileGroupName="DivisionDetection",
                                                                     selectedFeatures=configConservation.selectedFeaturesDiv)

        self.cellClassificationApplet = ObjectClassificationApplet(workflow=self,
                                                                     name="Object Count Classification",
                                                                     projectFileGroupName="CountClassification",
                                                                     selectedFeatures=configConservation.selectedFeaturesObjectCount)

        self.trackingFeatureExtractionApplet = TrackingFeatureExtractionApplet(name="Object Feature Computation",workflow=self, interactive=False)

        self.objectExtractionApplet = ObjectExtractionApplet(name="Object Feature Computation",workflow=self, interactive=False)

        self.annotationsApplet = AnnotationsApplet( name="Training", workflow=self )
        # NOTE(review): this local is not used further in this method; the
        # property access is kept in case it has an effect -- confirm.
        opAnnotations = self.annotationsApplet.topLevelOperator

        self.trackingApplet = StructuredTrackingApplet( name="Tracking - Structured Learning", workflow=self )
        opStructuredTracking = self.trackingApplet.topLevelOperator

        # Translate the configured backend name into the solver label that is
        # stored on the workflow and on the tracking operator.
        if SOLVER in ("CPLEX", "GUROBI"):
            self._solver = "ILP"
        elif SOLVER == "DPCT":
            self._solver = "Flow-based"
        else:
            self._solver = None
        opStructuredTracking._solver = self._solver

        self.default_tracking_export_filename = '{dataset_dir}/{nickname}-tracking_exported_data.csv'
        self.dataExportTrackingApplet = TrackingBaseDataExportApplet(self, "Tracking Result Export",default_export_filename=self.default_tracking_export_filename)
        opDataExportTracking = self.dataExportTrackingApplet.topLevelOperator
        opDataExportTracking.SelectionNames.setValue( ['Tracking-Result', 'Merger-Result', 'Object-Identities'] )
        opDataExportTracking.WorkingDirectory.connect( opDataSelection.WorkingDirectory )
        self.dataExportTrackingApplet.set_exporting_operator(opStructuredTracking)
        self.dataExportTrackingApplet.prepare_lane_for_export = self.prepare_lane_for_export
        self.dataExportTrackingApplet.post_process_lane_export = self.post_process_lane_export

        # configure export settings
        settings = {'file path': self.default_tracking_export_filename, 'compression': {}, 'file type': 'h5'}
        selected_features = ['Count', 'RegionCenter', 'RegionRadii', 'RegionAxes']
        opStructuredTracking.ExportSettings.setValue( (settings, selected_features) )

        # Expose the applets in the order in which they are listed here.
        self._applets = []
        self._applets.append(self.dataSelectionApplet)
        if not self.fromBinary:
            self._applets.append(self.thresholdTwoLevelsApplet)
        self._applets.append(self.trackingFeatureExtractionApplet)
        self._applets.append(self.divisionDetectionApplet)

        # The batch processing applet is created but not added to _applets.
        self.batchProcessingApplet = BatchProcessingApplet(self, "Batch Processing", self.dataSelectionApplet, self.dataExportTrackingApplet)

        self._applets.append(self.cellClassificationApplet)
        self._applets.append(self.objectExtractionApplet)
        self._applets.append(self.annotationsApplet)
        self._applets.append(self.trackingApplet)
        self._applets.append(self.dataExportTrackingApplet)

        if self.divisionDetectionApplet:
            opDivDetection = self.divisionDetectionApplet.topLevelOperator
            opDivDetection.SelectedFeatures.setValue(configConservation.selectedFeaturesDiv)
            opDivDetection.LabelNames.setValue(['Not Dividing', 'Dividing'])
            opDivDetection.AllowDeleteLabels.setValue(False)
            opDivDetection.AllowAddLabel.setValue(False)
            opDivDetection.EnableLabelTransfer.setValue(False)

        opCellClassification = self.cellClassificationApplet.topLevelOperator
        opCellClassification.SelectedFeatures.setValue(configConservation.selectedFeaturesObjectCount )
        # 'False Detection', '1 Object', '2 Objects', ..., '9 Objects'
        opCellClassification.SuggestedLabelNames.setValue(
            ['False Detection', '1 Object'] + ['{} Objects'.format(i) for i in range(2, 10)] )
        opCellClassification.AllowDeleteLastLabelOnly.setValue(True)
        opCellClassification.EnableLabelTransfer.setValue(False)

        if workflow_cmdline_args:
            self.testFullAnnotations = '--testFullAnnotations' in workflow_cmdline_args

            # Parse the export settings first and hand only the leftovers to
            # the batch-input parser, so that export options are neither seen
            # by the input parser nor reported as unused below.
            self._data_export_args, unused_args = self.dataExportTrackingApplet.parse_known_cmdline_args( workflow_cmdline_args )
            self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args( unused_args )

            # This flag is consumed by the workflow itself (see above).
            unused_args = [arg for arg in unused_args if arg != '--testFullAnnotations']
        else:
            unused_args = None
            self._data_export_args = None
            self._batch_input_args = None
            self.testFullAnnotations = False

        if unused_args:
            logger.warning("Unused command-line args: {}".format( unused_args ))

    def connectLane(self, laneIndex):
        """
        Connect the per-lane operators of all applets for lane ``laneIndex``.

        The raw image and the binary image (either the thresholding output or
        the second input dataset, depending on ``self.fromBinary``) are both
        reordered to "txyzc" and then fed into feature extraction, division
        detection, object count classification, annotations, structured
        tracking and the tracking export operator.
        """
        # Fetch the lane-specific view of every applet's top-level operator.
        opData = self.dataSelectionApplet.topLevelOperator.getLane(laneIndex)
        opObjExtraction = self.objectExtractionApplet.topLevelOperator.getLane(laneIndex)
        opTrackingFeatureExtraction = self.trackingFeatureExtractionApplet.topLevelOperator.getLane(laneIndex)

        opAnnotations = self.annotationsApplet.topLevelOperator.getLane(laneIndex)
        # The thresholding applet only exists when the input is not binary (see __init__).
        if not self.fromBinary:
            opTwoLevelThreshold = self.thresholdTwoLevelsApplet.topLevelOperator.getLane(laneIndex)

        opStructuredTracking = self.trackingApplet.topLevelOperator.getLane(laneIndex)
        opDataTrackingExport = self.dataExportTrackingApplet.topLevelOperator.getLane(laneIndex)

        ## Connect operators ##
        # ImageGroup[0] is the first dataset role ('Raw Data' in __init__).
        op5Raw = OpReorderAxes(parent=self)
        op5Raw.AxisOrder.setValue("txyzc")
        op5Raw.Input.connect(opData.ImageGroup[0])

        opDivDetection = self.divisionDetectionApplet.topLevelOperator.getLane(laneIndex)
        opCellClassification = self.cellClassificationApplet.topLevelOperator.getLane(laneIndex)

        # ImageGroup[1] is the second dataset role: prediction maps that still
        # have to be thresholded, or an already binary image.
        if not self.fromBinary:
            opTwoLevelThreshold.InputImage.connect( opData.ImageGroup[1] )
            opTwoLevelThreshold.RawInput.connect( opData.ImageGroup[0] ) # Used for display only
            binarySrc = opTwoLevelThreshold.CachedOutput
        else:
            binarySrc = opData.ImageGroup[1]
        # Use Op5ifyers for both input datasets such that they are guaranteed to
        # have the same axis order after thresholding
        op5Binary = OpReorderAxes(parent=self)
        op5Binary.AxisOrder.setValue("txyzc")
        op5Binary.Input.connect(binarySrc)

        opObjExtraction.RawImage.connect( op5Raw.Output )
        opObjExtraction.BinaryImage.connect( op5Binary.Output )

        opTrackingFeatureExtraction.RawImage.connect( op5Raw.Output )
        opTrackingFeatureExtraction.BinaryImage.connect( op5Binary.Output )

        # Configure which features the tracking feature extraction computes.
        opTrackingFeatureExtraction.setDefaultFeatures(configConservation.allFeaturesObjectCount)
        opTrackingFeatureExtraction.FeatureNamesVigra.setValue(configConservation.allFeaturesObjectCount)
        # Division features: one entry per configured feature name, each with an empty settings dict.
        feature_dict_division = {}
        feature_dict_division[config.features_division_name] = { name: {} for name in config.division_features }
        opTrackingFeatureExtraction.FeatureNamesDivision.setValue(feature_dict_division)

        if self.divisionDetectionApplet:
            opDivDetection.BinaryImages.connect( op5Binary.Output )
            opDivDetection.RawImages.connect( op5Raw.Output )
            opDivDetection.SegmentationImages.connect(opTrackingFeatureExtraction.LabelImage)
            opDivDetection.ObjectFeatures.connect(opTrackingFeatureExtraction.RegionFeaturesAll)
            opDivDetection.ComputedFeatureNames.connect(opTrackingFeatureExtraction.ComputedFeatureNamesAll)

        # The object count classifier uses the feature names without division features.
        opCellClassification.BinaryImages.connect( op5Binary.Output )
        opCellClassification.RawImages.connect( op5Raw.Output )
        opCellClassification.SegmentationImages.connect(opTrackingFeatureExtraction.LabelImage)
        opCellClassification.ObjectFeatures.connect(opTrackingFeatureExtraction.RegionFeaturesAll)
        opCellClassification.ComputedFeatureNames.connect(opTrackingFeatureExtraction.ComputedFeatureNamesNoDivisions)

        # Annotations take label image and features from the object extraction
        # operator, not from the tracking feature extraction operator.
        opAnnotations.RawImage.connect( op5Raw.Output )
        opAnnotations.BinaryImage.connect( op5Binary.Output )
        opAnnotations.LabelImage.connect( opObjExtraction.LabelImage )
        opAnnotations.ObjectFeatures.connect( opObjExtraction.RegionFeatures )
        opAnnotations.ComputedFeatureNames.connect(opObjExtraction.Features)
        # NOTE(review): opDivDetection is used unconditionally here, although the
        # other division connections are guarded by `if self.divisionDetectionApplet`.
        opAnnotations.DivisionProbabilities.connect( opDivDetection.Probabilities )
        opAnnotations.DetectionProbabilities.connect( opCellClassification.Probabilities )
        opAnnotations.MaxNumObj.connect (opCellClassification.MaxNumObj)

        opStructuredTracking.RawImage.connect( op5Raw.Output )
        opStructuredTracking.LabelImage.connect( opTrackingFeatureExtraction.LabelImage )
        opStructuredTracking.ObjectFeatures.connect( opTrackingFeatureExtraction.RegionFeaturesVigra )
        opStructuredTracking.ComputedFeatureNames.connect( opTrackingFeatureExtraction.FeatureNamesVigra )

        if self.divisionDetectionApplet:
            opStructuredTracking.ObjectFeaturesWithDivFeatures.connect( opTrackingFeatureExtraction.RegionFeaturesAll)
            opStructuredTracking.ComputedFeatureNamesWithDivFeatures.connect( opTrackingFeatureExtraction.ComputedFeatureNamesAll )
            opStructuredTracking.DivisionProbabilities.connect( opDivDetection.Probabilities )

        # Classifier output and user annotations feed the structured tracking operator.
        opStructuredTracking.DetectionProbabilities.connect( opCellClassification.Probabilities )
        opStructuredTracking.NumLabels.connect( opCellClassification.NumLabels )
        opStructuredTracking.Annotations.connect (opAnnotations.Annotations)
        opStructuredTracking.Labels.connect (opAnnotations.Labels)
        opStructuredTracking.Divisions.connect (opAnnotations.Divisions)
        opStructuredTracking.Appearances.connect (opAnnotations.Appearances)
        opStructuredTracking.Disappearances.connect (opAnnotations.Disappearances)
        opStructuredTracking.MaxNumObj.connect (opCellClassification.MaxNumObj)

        # Three export inputs; presumably in the same order as the SelectionNames
        # set in __init__ ('Tracking-Result', 'Merger-Result', 'Object-Identities') -- confirm.
        opDataTrackingExport.Inputs.resize(3)
        opDataTrackingExport.Inputs[0].connect( opStructuredTracking.RelabeledImage )
        opDataTrackingExport.Inputs[1].connect( opStructuredTracking.MergerOutput )
        opDataTrackingExport.Inputs[2].connect( opStructuredTracking.LabelImage )
        opDataTrackingExport.RawData.connect( op5Raw.Output )
        opDataTrackingExport.RawDatasetInfo.connect( opData.DatasetGroup[0] )

    def prepare_lane_for_export(self, lane_index):
        """
        Run tracking for lane ``lane_index`` so that its result can be exported.

        The image extents are read from the lane's ``RawImage`` metadata
        (shape entries 0-3 are used as t, x, y and z).  The axis ranges found
        in the tracking parameters (or the full extents, if a range is not
        configured) are remembered in ``self.prev_time_range``,
        ``self.prev_x_range``, ``self.prev_y_range`` and ``self.prev_z_range``,
        and the annotations of lane 0 are copied into this lane.  Tracking is
        then run over the full extents and stored in ``self.result``.

        If ``self.testFullAnnotations`` is set, structured learning and a
        tracking run without merger resolution are executed first, and the
        detection, division and transition results are compared against the
        solution derived from the annotations.  A result kind that is absent
        from the tracking solution is treated as having nothing to compare.

        :param lane_index: index of the lane that is about to be exported
        :raises AssertionError: in test mode, if a tracking result has no
            matching entry in the annotation-derived solution
        """
        import logging
        logger = logging.getLogger(__name__)

        opTracking = self.trackingApplet.topLevelOperator[lane_index]

        shape = opTracking.RawImage.meta.shape
        maxt, maxx, maxy, maxz = shape[0], shape[1], shape[2], shape[3]
        time_enum = list(range(maxt))
        x_range = (0, maxx)
        y_range = (0, maxy)
        z_range = (0, maxz)

        # More than one z-slice means the data is treated as 3D.
        ndim = 3 if (z_range[1] - z_range[0]) > 1 else 2

        parameters = self.trackingApplet.topLevelOperator.Parameters.value
        # Save state of axis ranges
        self.prev_time_range = parameters['time_range'] if 'time_range' in parameters else time_enum
        self.prev_x_range = parameters['x_range'] if 'x_range' in parameters else x_range
        self.prev_y_range = parameters['y_range'] if 'y_range' in parameters else y_range
        self.prev_z_range = parameters['z_range'] if 'z_range' in parameters else z_range

        # batch processing starts a new lane, so training data needs to be copied from the lane that loaded the project
        loaded_project_lane_index = 0
        self.annotationsApplet.topLevelOperator[lane_index].Annotations.setValue(
            self.trackingApplet.topLevelOperator[loaded_project_lane_index].Annotations.value)

        def runLearningAndTracking(withMergerResolution=True):
            # In test mode the weights are learned from the annotations before tracking.
            if self.testFullAnnotations:
                logger.info("Test: Structured Learning")
                weights = opTracking._runStructuredLearning(
                    z_range,
                    parameters['maxObj'],
                    parameters['max_nearest_neighbors'],
                    parameters['maxDist'],
                    parameters['divThreshold'],
                    [parameters['scales'][0],parameters['scales'][1],parameters['scales'][2]],
                    parameters['size_range'],
                    parameters['withDivisions'],
                    parameters['borderAwareWidth'],
                    parameters['withClassifierPrior'],
                    withBatchProcessing=True)
                logger.info("weights: {}".format(weights))

            logger.info("Test: Tracking")
            result = opTracking.track(
                time_range = time_enum,
                x_range = x_range,
                y_range = y_range,
                z_range = z_range,
                size_range = parameters['size_range'],
                x_scale = parameters['scales'][0],
                y_scale = parameters['scales'][1],
                z_scale = parameters['scales'][2],
                maxDist=parameters['maxDist'],
                maxObj = parameters['maxObj'],
                divThreshold=parameters['divThreshold'],
                avgSize=parameters['avgSize'],
                withTracklets=parameters['withTracklets'],
                sizeDependent=parameters['sizeDependent'],
                detWeight=parameters['detWeight'],
                divWeight=parameters['divWeight'],
                transWeight=parameters['transWeight'],
                withDivisions=parameters['withDivisions'],
                withOpticalCorrection=parameters['withOpticalCorrection'],
                withClassifierPrior=parameters['withClassifierPrior'],
                ndim=ndim,
                withMergerResolution=withMergerResolution,
                borderAwareWidth = parameters['borderAwareWidth'],
                withArmaCoordinates = parameters['withArmaCoordinates'],
                cplex_timeout = parameters['cplex_timeout'],
                appearance_cost = parameters['appearanceCost'],
                disappearance_cost = parameters['disappearanceCost'],
                force_build_hypotheses_graph = False,
                withBatchProcessing = True
            )

            return result

        if self.testFullAnnotations:

            self.result = runLearningAndTracking(withMergerResolution=False)

            # Solution as produced by tracking ...
            hypothesesGraph = opTracking.LearningHypothesesGraph.value
            hypothesesGraph.insertSolution(self.result)
            hypothesesGraph.computeLineage()
            solution = hypothesesGraph.getSolutionDictionary()
            annotations = opTracking.Annotations.value

            # ... and the solution obtained by inserting the annotations instead.
            opTracking.insertAnnotationsToHypothesesGraph(hypothesesGraph,annotations,misdetectionLabel=-1)
            hypothesesGraph.computeLineage()
            solutionFromAnnotations = hypothesesGraph.getSolutionDictionary()

            # (result key, fields that must agree, failure message, success message)
            checks = (
                ('detectionResults', ('id', 'value'),
                 "Detection results are NOT correct. They differ from your annotated detections.",
                 "Detection results are correct."),
                ('divisionResults', ('id', 'value'),
                 "Division results are NOT correct. They differ from your annotated divisions.",
                 "Division results are correct."),
                ('linkingResults', ('dest', 'src', 'gap', 'value'),
                 "Transition results are NOT correct. They differ from your annotated transitions.",
                 "Transition results are correct."),
            )
            for key, fields, failure_message, success_message in checks:
                # Raise explicitly rather than via `assert`, which is removed
                # under `python -O`.
                if key in solution and not self._solution_entries_match(
                        solution[key], solutionFromAnnotations[key], fields):
                    raise AssertionError(failure_message)
                logger.info(success_message)
        self.result = runLearningAndTracking(withMergerResolution=parameters['withMergerResolution'])

    @staticmethod
    def _solution_entries_match(entries, reference_entries, fields):
        """
        Return True if every item of ``entries`` has at least one counterpart
        in ``reference_entries`` that agrees with it on all ``fields``.

        Fields are compared in the given order and comparison of a pair stops
        at the first field that differs.
        """
        return all(
            any(
                all(entry[field] == reference[field] for field in fields)
                for reference in reference_entries
            )
            for entry in entries
        )

    def post_process_lane_export(self, lane_index, checkOverwriteFiles=False):
        """
        Run the selected export plugin for lane ``lane_index``, if the export
        source is the plugin-only source.

        :param lane_index: index of the lane whose tracking result is exported
        :param checkOverwriteFiles: forwarded to the tracking operator's
            ``exportPlugin`` call
        :returns: ``False`` if the plugin export reported a failure, ``True``
            in every other case (no plugin export selected, plugin not found,
            empty output filename, or successful export)
        """
        opDataExport = self.dataExportTrackingApplet.topLevelOperator
        export_source = opDataExport.SelectedExportSource.value

        # Plugin export if selected
        logger.info("Export source is: " + export_source)
        logger.debug("in post_process_lane_export")

        if export_source != OpTrackingBaseDataExport.PluginOnlyName:
            return True

        logger.info("Export source plugin selected!")
        selectedPlugin = opDataExport.SelectedPlugin.value
        additionalPluginArgumentsSlot = opDataExport.AdditionalPluginArguments

        exportPluginInfo = pluginManager.getPluginByName(selectedPlugin, category="TrackingExportFormats")
        if exportPluginInfo is None:
            # Report the requested plugin name; the lookup result is None here.
            logger.error("Could not find selected plugin %s" % selectedPlugin)
            return True

        exportPlugin = exportPluginInfo.plugin_object
        logger.info("Exporting tracking result using %s" % selectedPlugin)
        name_format = opDataExport.getLane(lane_index).OutputFilenameFormat.value
        filename = self.getPartiallyFormattedName(lane_index, name_format)

        # A plugin that exports to a file needs a file name, not just a directory.
        if exportPlugin.exportsToFile and os.path.basename(filename) == '':
            filename = os.path.join(filename, 'pluginExport.txt')

        if filename is None or len(str(filename)) == 0:
            logger.error("Cannot export from plugin with empty output filename")
            return True

        self.dataExportTrackingApplet.progressSignal(-1)
        exportStatus = self.trackingApplet.topLevelOperator.getLane(lane_index).exportPlugin(
            filename, exportPlugin, checkOverwriteFiles, additionalPluginArgumentsSlot)
        self.dataExportTrackingApplet.progressSignal(100)

        if not exportStatus:
            return False
        logger.info("Export done")
        return True

    def getPartiallyFormattedName(self, lane_index, path_format_string):
        """
        Fill the ``dataset_dir``, ``nickname`` and ``result_type`` placeholders
        of ``path_format_string`` for lane ``lane_index`` and return the result.

        ``dataset_dir`` is the dataset's external directory made absolute
        relative to the project file's directory.  ``nickname`` is the
        dataset nickname without '*' characters; if it contains
        ``os.path.pathsep``, only the file name base of its first component
        is used.  ``result_type`` is the currently selected export plugin.
        """
        opDataSelection = self.dataSelectionApplet.topLevelOperator
        dataset_info = opDataSelection.DatasetGroup[lane_index][0].value

        project_dir = os.path.dirname(self.shell.projectManager.currentProjectPath)
        external_dir = PathComponents(dataset_info.filePath).externalDirectory
        absolute_dataset_dir = make_absolute(external_dir, cwd=project_dir)

        clean_nickname = dataset_info.nickname.replace('*', '')
        if os.path.pathsep in clean_nickname:
            first_component = clean_nickname.split(os.path.pathsep)[0]
            clean_nickname = PathComponents(first_component).fileNameBase

        known_keys = {
            'dataset_dir': absolute_dataset_dir,
            'nickname': clean_nickname,
            'result_type': self.dataExportTrackingApplet.topLevelOperator.SelectedPlugin._value,
        }
        # use partial formatting to fill in non-coordinate name fields
        return format_known_keys(path_format_string, known_keys)

    def _inputReady(self, nRoles):
        """
        Return True if at least one lane exists and, in every lane of the
        data selection ``ImageGroup`` slot, the first ``nRoles`` subslots
        are ready.
        """
        image_group = self.dataSelectionApplet.topLevelOperator.ImageGroup
        if len(image_group) == 0:
            return False
        return all(
            all([lane_slot[role].ready() for role in range(nRoles)])
            for lane_slot in image_group
        )

    def onProjectLoaded(self, projectManager):
        """
        Overridden from Workflow base class.  Called by the Project Manager.

        Applies the parsed command-line export settings (if any) to the data
        export applet.  In headless mode, when both batch-input and export
        arguments were given, the batch export is run immediately.
        """
        # Configure the data export operator.
        export_args = self._data_export_args
        if export_args:
            self.dataExportTrackingApplet.configure_operator_with_parsed_args( export_args )

        # Batch processing only runs headless and only with both argument sets.
        run_batch = self._headless and self._batch_input_args and self._data_export_args
        if not run_batch:
            return

        logger.info("Beginning Batch Processing")
        self.batchProcessingApplet.run_export_from_parsed_args(self._batch_input_args)
        logger.info("Completed Batch Processing")

    def handleAppletStateUpdateRequested(self):
        """
        Overridden from Workflow base class
        Called when an applet has fired the :py:attr:`Applet.statusUpdateSignal`

        Recomputes which applets can be used from the readiness of the
        upstream outputs and the ``busy`` state of the applets, and passes the
        result to the shell.
        """
        # If no data, nothing else is ready.
        input_ready = self._inputReady(2) and not self.dataSelectionApplet.busy

        # Without a thresholding applet the binary input is used as it is.
        if not self.fromBinary:
            thresholdingOutput = self.thresholdTwoLevelsApplet.topLevelOperator.CachedOutput
            thresholding_ready = input_ready and len(thresholdingOutput) > 0
        else:
            thresholding_ready = input_ready

        trackingFeatureExtractionOutput = self.trackingFeatureExtractionApplet.topLevelOperator.ComputedFeatureNamesAll
        tracking_features_ready = thresholding_ready and len(trackingFeatureExtractionOutput) > 0

        objectExtractionOutput = self.objectExtractionApplet.topLevelOperator.RegionFeatures
        features_ready = thresholding_ready and len(objectExtractionOutput) > 0

        # Classification, tracking and export all depend on the tracking features.
        objectCountClassifier_ready = tracking_features_ready
        structured_tracking_ready = objectCountClassifier_ready

        # A list (not a generator) so that every applet's state is queried.
        busy = any([
            self.dataSelectionApplet.busy,
            self.annotationsApplet.busy,
            self.trackingApplet.busy,
            self.dataExportTrackingApplet.busy,
        ])

        self._shell.enableProjectChanges( not busy )

        self._shell.setAppletEnabled(self.dataSelectionApplet, not busy)
        if not self.fromBinary:
            self._shell.setAppletEnabled(self.thresholdTwoLevelsApplet, input_ready and not busy)
        self._shell.setAppletEnabled(self.trackingFeatureExtractionApplet, thresholding_ready and not busy)
        self._shell.setAppletEnabled(self.cellClassificationApplet, tracking_features_ready and not busy)
        self._shell.setAppletEnabled(self.divisionDetectionApplet, tracking_features_ready and not busy)
        self._shell.setAppletEnabled(self.objectExtractionApplet, not busy)
        self._shell.setAppletEnabled(self.annotationsApplet, features_ready and not busy)
        self._shell.setAppletEnabled(self.trackingApplet, objectCountClassifier_ready and not busy)
        # The export input slot is only inspected once everything upstream is ready.
        self._shell.setAppletEnabled(self.dataExportTrackingApplet, structured_tracking_ready and not busy and \
                                    self.dataExportTrackingApplet.topLevelOperator.Inputs[0][0].ready() )
예제 #46
0
    def __init__(self, shell, headless, workflow_cmdline_args,
                 project_creation_args, n_stages, *args, **kwargs):
        """
        Set up the data selection applet, ``n_stages`` pairs of feature
        selection / pixel classification applets, the prediction export applet
        and the batch processing applet.

        n_stages: number of feature selection + pixel classification stages
        to insert into the workflow.

        The remaining parameters are forwarded to the base class constructor.
        ``workflow_cmdline_args`` is also parsed for ``--retrain`` and for
        export and batch-input settings.
        """
        # One graph shared by all operators of this workflow.
        shared_graph = Graph()
        super(NewAutocontextWorkflowBase, self).__init__(
            shell, headless, workflow_cmdline_args, project_creation_args,
            graph=shared_graph, *args, **kwargs)
        self.stored_classifers = []
        self._applets = []
        self._workflow_cmdline_args = workflow_cmdline_args

        # Workflow-specific command-line options.
        arg_parser = argparse.ArgumentParser()
        arg_parser.add_argument(
            "--retrain",
            help="Re-train the classifier based on labels stored in project file, and re-save.",
            action="store_true",
        )

        # The creation args (saved to the project file when the project was
        # first created) are parsed, but the result is not used here.
        arg_parser.parse_known_args(project_creation_args)

        # Args of the current session; the leftovers are handled further below.
        session_args, unused_args = arg_parser.parse_known_args(workflow_cmdline_args)
        self.retrain = session_args.retrain

        # NOTE(review): this text is not passed on anywhere in this method.
        data_instructions = (
            "Select your input data using the 'Raw Data' tab shown on the right.\n\n"
            "Power users: Optionally use the 'Prediction Mask' tab to supply a binary image that tells ilastik where it should avoid computations you don't need."
        )

        self.dataSelectionApplet = self.createDataSelectionApplet()
        opDataSelection = self.dataSelectionApplet.topLevelOperator
        opDataSelection.DatasetRoles.setValue(["Raw Data", "Prediction Mask"])

        # Create the applets stage by stage: feature selection first, then
        # pixel classification.
        self.featureSelectionApplets = []
        self.pcApplets = []
        for stage in range(n_stages):
            feature_applet = self.createFeatureSelectionApplet(stage)
            self.featureSelectionApplets.append(feature_applet)
            pc_applet = self.createPixelClassificationApplet(stage)
            self.pcApplets.append(pc_applet)
        opFinalClassify = self.pcApplets[-1].topLevelOperator

        # Whenever one stage's FreezePredictions slot becomes dirty, copy its
        # value to the slot of every stage so that all stages agree.
        def _propagate_freeze_setting(slot, *args):
            new_value = slot.value
            for applet in self.pcApplets:
                applet.topLevelOperator.FreezePredictions.setValue(new_value)

        for applet in self.pcApplets:
            applet.topLevelOperator.FreezePredictions.notifyDirty(
                _propagate_freeze_setting)

        self.dataExportApplet = PixelClassificationDataExportApplet(
            self, "Prediction Export")
        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.PmapColors.connect(opFinalClassify.PmapColors)
        opDataExport.LabelNames.connect(opFinalClassify.LabelNames)
        opDataExport.WorkingDirectory.connect(opDataSelection.WorkingDirectory)

        # Export names are listed for the last stage first (stage numbers are 1-based).
        self.EXPORT_NAMES = []
        for stage_number in range(n_stages, 0, -1):
            for base_name in self.EXPORT_NAMES_PER_STAGE:
                self.EXPORT_NAMES.append(
                    "{} Stage {}".format(base_name, stage_number))

        # And finally, one last item for *all* probabilities from all stages.
        self.EXPORT_NAMES.append("Probabilities All Stages")
        opDataExport.SelectionNames.setValue(self.EXPORT_NAMES)

        # Expose for shell: data selection, then the stages interleaved
        # (feature selection, pixel classification), then export.
        self._applets.append(self.dataSelectionApplet)
        for feature_applet, pc_applet in zip(self.featureSelectionApplets,
                                             self.pcApplets):
            self._applets.append(feature_applet)
            self._applets.append(pc_applet)
        self._applets.append(self.dataExportApplet)

        self.dataExportApplet.prepare_for_entire_export = self.prepare_for_entire_export
        self.dataExportApplet.post_process_entire_export = self.post_process_entire_export

        self.batchProcessingApplet = BatchProcessingApplet(
            self, "Batch Processing", self.dataSelectionApplet,
            self.dataExportApplet)
        self._applets.append(self.batchProcessingApplet)

        self._batch_input_args = None
        self._batch_export_args = None
        if unused_args:
            # We parse the export setting args first.  All remaining args are considered input files by the input applet.
            export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args(
                unused_args)
            self._batch_export_args = export_args
            input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args(
                unused_args)
            self._batch_input_args = input_args

        if unused_args:
            logger.warning("Unused command-line args: {}".format(unused_args))
예제 #47
0
class NNClassificationWorkflow(Workflow):
    """
    Workflow for the Neural Network Classification Applet.

    The workflow owns four applets, exposed to the shell in this order:
    data selection, neural-network classification, data export and
    batch processing.
    """

    workflowName = "Neural Network Classification"
    workflowDescription = "This is obviously self-explanatory."
    defaultAppletIndex = 0  # show DataSelection by default

    # Index of the raw-data role; used to index the per-lane ImageGroup /
    # DatasetGroup slots in connectLane().
    DATA_ROLE_RAW = 0
    ROLE_NAMES = ["Raw Data"]
    # One export input slot is created per entry (see connectLane()).
    EXPORT_NAMES = ["Probabilities"]

    @property
    def applets(self):
        """
        Return the list of applets that are owned by this workflow
        """
        return self._applets

    @property
    def imageNameListSlot(self):
        """
        Return the "image name list" slot, which lists the names of
        all image lanes (i.e. files) currently loaded by the workflow
        """
        return self.dataSelectionApplet.topLevelOperator.ImageName

    def __init__(self, shell, headless, workflow_cmdline_args,
                 project_creation_args, *args, **kwargs):
        """
        Create the applets, wire the global export settings and parse the
        command-line arguments.

        :param shell: passed through to the Workflow base class.
        :param headless: passed through to the Workflow base class.
        :param workflow_cmdline_args: command-line args of the current session.
            Whatever the workflow parser does not consume is offered to the
            data export applet and then to the batch processing applet.
        :param project_creation_args: args that were saved to the project file
            when the project was first created.
        :param args: extra positional args for the base class.
        :param kwargs: extra keyword args for the base class.
        """
        # Create a graph to be shared by all operators
        graph = Graph()
        super(NNClassificationWorkflow, self).__init__(shell,
                                                       headless,
                                                       workflow_cmdline_args,
                                                       project_creation_args,
                                                       graph=graph,
                                                       *args,
                                                       **kwargs)
        self._applets = []
        self._workflow_cmdline_args = workflow_cmdline_args

        # Parse workflow-specific command-line args.
        # The parser declares no options of its own at the moment, so it only
        # separates out the args that the applets get to see below.
        parser = argparse.ArgumentParser()

        # Parse the creation args: These were saved to the project file when
        # this project was first created.  No option is read from them yet,
        # so the result is intentionally discarded.
        parser.parse_known_args(project_creation_args)

        # Parse the cmdline args for the current session.
        _, unused_args = parser.parse_known_args(workflow_cmdline_args)

        # Applets for training (interactive) workflow
        self.dataSelectionApplet = self.createDataSelectionApplet()
        opDataSelection = self.dataSelectionApplet.topLevelOperator

        # see role constants, above
        opDataSelection.DatasetRoles.setValue(
            NNClassificationWorkflow.ROLE_NAMES)

        self.nnClassificationApplet = NNClassApplet(self, "NNClassApplet")

        self.dataExportApplet = NNClassificationDataExportApplet(
            self, "Data Export")

        # Configure global DataExport settings
        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.WorkingDirectory.connect(opDataSelection.WorkingDirectory)
        opDataExport.SelectionNames.setValue(self.EXPORT_NAMES)

        self.batchProcessingApplet = BatchProcessingApplet(
            self, "Batch Processing", self.dataSelectionApplet,
            self.dataExportApplet)

        # Expose for shell
        self._applets.append(self.dataSelectionApplet)
        self._applets.append(self.nnClassificationApplet)
        self._applets.append(self.dataExportApplet)
        self._applets.append(self.batchProcessingApplet)

        if unused_args:
            # We parse the export setting args first.  All remaining args are considered input files by the input applet.
            self._batch_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args(
                unused_args)
            self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args(
                unused_args)
        else:
            self._batch_input_args = None
            self._batch_export_args = None

        if unused_args:
            # Logger.warn() is a deprecated alias; use warning().
            logger.warning("Unused command-line args: {}".format(unused_args))

    def createDataSelectionApplet(self):
        """
        Can be overridden by subclasses, if they want to use
        special parameters to initialize the DataSelectionApplet.
        """
        data_instructions = "Select your input data using the 'Raw Data' tab shown on the right"
        return DataSelectionApplet(self,
                                   "Input Data",
                                   "Input Data",
                                   supportIlastik05Import=True,
                                   instructionText=data_instructions)

    def connectLane(self, laneIndex):
        """
        connects the operators for different lanes, each lane has a laneIndex starting at 0
        """
        opData = self.dataSelectionApplet.topLevelOperator.getLane(laneIndex)
        opNNclassify = self.nnClassificationApplet.topLevelOperator.getLane(
            laneIndex)
        opDataExport = self.dataExportApplet.topLevelOperator.getLane(
            laneIndex)

        # Input Image -> Classification Op
        opNNclassify.InputImage.connect(opData.Image)

        # Data Export connections
        opDataExport.RawData.connect(opData.ImageGroup[self.DATA_ROLE_RAW])
        opDataExport.RawDatasetInfo.connect(
            opData.DatasetGroup[self.DATA_ROLE_RAW])
        # One export input per export name; the only one is the probabilities.
        opDataExport.Inputs.resize(len(self.EXPORT_NAMES))
        opDataExport.Inputs[0].connect(
            opNNclassify.CachedPredictionProbabilities)
        # for slot in opDataExport.Inputs:
        #     assert slot.upstream_slot is not None

    def handleAppletStateUpdateRequested(self):
        """
        Overridden from Workflow base class
        Called when an applet has fired the :py:attr:`Applet.appletStateUpdateRequested`

        Enables/disables each applet in the shell according to the current
        state and blocks project changes while any applet is busy.
        """
        # If no data, nothing else is ready.
        opDataSelection = self.dataSelectionApplet.topLevelOperator
        input_ready = len(opDataSelection.ImageGroup
                          ) > 0 and not self.dataSelectionApplet.busy

        opNNClassification = self.nnClassificationApplet.topLevelOperator

        opDataExport = self.dataExportApplet.topLevelOperator

        predictions_ready = input_ready and len(opDataExport.Inputs) > 0
        # opDataExport.Inputs[0][0].ready()
        # (TinyVector(opDataExport.Inputs[0][0].meta.shape) > 0).all()

        # Problems can occur if the features or input data are changed during live update mode.
        # Don't let the user do that.
        live_update_active = not opNNClassification.FreezePredictions.value

        # The user isn't allowed to touch anything while batch processing is running.
        batch_processing_busy = self.batchProcessingApplet.busy

        self._shell.setAppletEnabled(self.dataSelectionApplet,
                                     not batch_processing_busy)
        self._shell.setAppletEnabled(self.nnClassificationApplet, input_ready
                                     and not batch_processing_busy)
        self._shell.setAppletEnabled(
            self.dataExportApplet, predictions_ready
            and not batch_processing_busy and not live_update_active)

        # No None-check here: this method reads batchProcessingApplet.busy
        # unconditionally above and below, so a guard would never be false.
        self._shell.setAppletEnabled(
            self.batchProcessingApplet, predictions_ready
            and not batch_processing_busy)

        # Lastly, check for certain "busy" conditions, during which we
        #  should prevent the shell from closing the project.
        busy = False
        busy |= self.dataSelectionApplet.busy
        busy |= self.nnClassificationApplet.busy
        busy |= self.dataExportApplet.busy
        busy |= self.batchProcessingApplet.busy
        self._shell.enableProjectChanges(not busy)
예제 #48
0
    def __init__(self, shell, headless, workflow_cmdline_args,
                 project_creation_workflow, *args, **kwargs):
        """
        Instantiate the applets of the multicut workflow and parse the
        command-line arguments.

        The applets are created, and later exposed to the shell, in this
        order: data selection, DT watershed, edge training, multicut, data
        export, batch processing.  Command-line args are first offered to the
        data export applet, then to the batch processing applet; anything left
        over is reported with a warning.
        """
        self.stored_classifier = None

        # Every operator of this workflow lives in one shared graph.
        shared_graph = Graph()
        super(MulticutWorkflow, self).__init__(
            shell,
            headless,
            workflow_cmdline_args,
            project_creation_workflow,
            graph=shared_graph,
            *args,
            **kwargs
        )
        self._applets = []

        # ---- Input data ----
        self.dataSelectionApplet = DataSelectionApplet(
            self,
            "Input Data",
            "Input Data",
            forceAxisOrder=['zyxc', 'yxc'],
        )
        op_selection = self.dataSelectionApplet.topLevelOperator
        op_selection.DatasetRoles.setValue(self.ROLE_NAMES)

        # ---- Watershed ----
        self.wsdtApplet = WsdtApplet(self, "DT Watershed", "DT Watershed")

        # ---- Edge training ----
        self.edgeTrainingApplet = EdgeTrainingApplet(
            self, "Edge Training", "Edge Training")
        op_edge_training = self.edgeTrainingApplet.topLevelOperator
        # Initial feature selection: a single feature on the raw-data role.
        default_features = {
            self.ROLE_NAMES[self.DATA_ROLE_RAW]: ['standard_edge_mean'],
        }
        op_edge_training.FeatureNames.setValue(default_features)

        # ---- Multicut ----
        self.multicutApplet = MulticutApplet(
            self, "Multicut Segmentation", "Multicut Segmentation")

        # ---- Data export (with workflow-level export hooks) ----
        self.dataExportApplet = DataExportApplet(self, "Data Export")
        self.dataExportApplet.prepare_for_entire_export = self.prepare_for_entire_export
        self.dataExportApplet.post_process_entire_export = self.post_process_entire_export

        # Global export settings
        op_export = self.dataExportApplet.topLevelOperator
        op_export.WorkingDirectory.connect(op_selection.WorkingDirectory)
        op_export.SelectionNames.setValue(self.EXPORT_NAMES)

        # ---- Batch processing ----
        self.batchProcessingApplet = BatchProcessingApplet(
            self,
            "Batch Processing",
            self.dataSelectionApplet,
            self.dataExportApplet,
        )

        # ---- Hand the applets to the shell, in display order ----
        self._applets.extend([
            self.dataSelectionApplet,
            self.wsdtApplet,
            self.edgeTrainingApplet,
            self.multicutApplet,
            self.dataExportApplet,
            self.batchProcessingApplet,
        ])

        # ---- Command-line arguments ----
        #      (They are applied in onProjectLoaded(), below.)
        if not workflow_cmdline_args:
            leftover_args = None
            self._batch_input_args = None
            self._data_export_args = None
        else:
            # Export settings are parsed first; what remains goes to the
            # batch processing applet.
            self._data_export_args, leftover_args = (
                self.dataExportApplet.parse_known_cmdline_args(workflow_cmdline_args))
            self._batch_input_args, leftover_args = (
                self.batchProcessingApplet.parse_known_cmdline_args(leftover_args))

        if leftover_args:
            logger.warning("Unused command-line args: {}".format(leftover_args))

        # Only a GUI session has applet-change notifications to listen to.
        if not self._headless:
            shell.currentAppletChanged.connect(self.handle_applet_changed)
    def __init__(self, shell, headless,
                 workflow_cmdline_args,
                 project_creation_args,
                 *args, **kwargs):
        """
        Create the applets of the object classification workflow and parse
        the command-line arguments.

        :param shell: passed through to the base class.
        :param headless: passed through to the base class.
        :param workflow_cmdline_args: command-line args of the current session
            (``--fillmissing``, ``--filter``, ``--nobatch``, plus export and
            batch args that are handed on to the respective applets).
        :param project_creation_args: args saved when the project was created;
            ``--fillmissing`` and ``--filter`` are always taken from these.
        :param kwargs: may contain ``graph``; if so, that graph is used instead
            of a freshly created one.  Everything else goes to the base class.
        :raises RuntimeError: if ``--export_pixel_probability_img`` is given
            for a workflow whose input type is not 'raw', or if more than one
            of the ``--export_*_img`` flags is given.
        """
        # Use the caller's graph if one was supplied, otherwise create our own.
        # (Graph() is only constructed when actually needed.)
        graph = kwargs.pop('graph') if 'graph' in kwargs else Graph()
        super(ObjectClassificationWorkflow, self).__init__(shell, headless, workflow_cmdline_args, project_creation_args, graph=graph, *args, **kwargs)
        self.stored_pixel_classifier = None
        self.stored_object_classifier = None

        # Parse workflow-specific command-line args
        parser = argparse.ArgumentParser()
        parser.add_argument('--fillmissing', help="use 'fill missing' applet with chosen detection method", choices=['classic', 'svm', 'none'], default='none')
        parser.add_argument('--filter', help="pixel feature filter implementation.", choices=['Original', 'Refactored', 'Interpolated'], default='Original')
        parser.add_argument('--nobatch', help="do not append batch applets", action='store_true', default=False)

        parsed_creation_args, unused_args = parser.parse_known_args(project_creation_args)

        # These two settings are fixed at project creation time.
        self.fillMissing = parsed_creation_args.fillmissing
        self.filter_implementation = parsed_creation_args.filter

        parsed_args, unused_args = parser.parse_known_args(workflow_cmdline_args)
        if parsed_args.fillmissing != 'none' and parsed_creation_args.fillmissing != parsed_args.fillmissing:
            logger.error( "Ignoring --fillmissing cmdline arg.  Can't specify a different fillmissing setting after the project has already been created." )

        if parsed_args.filter != 'Original' and parsed_creation_args.filter != parsed_args.filter:
            logger.error( "Ignoring --filter cmdline arg.  Can't specify a different filter setting after the project has already been created." )

        # Unlike the settings above, --nobatch is read from the session args.
        self.batch = not parsed_args.nobatch

        self._applets = []

        self.pcApplet = None
        self.projectMetadataApplet = ProjectMetadataApplet()
        self._applets.append(self.projectMetadataApplet)

        self.setupInputs()

        if self.fillMissing != 'none':
            self.fillMissingSlicesApplet = FillMissingSlicesApplet(
                self, "Fill Missing Slices", "Fill Missing Slices", self.fillMissing)
            self._applets.append(self.fillMissingSlicesApplet)

        # The input type follows from the concrete subclass.
        # NOTE(review): for any other subclass this chain leaves
        # self.input_types unassigned here -- confirm it is set elsewhere.
        if isinstance(self, ObjectClassificationWorkflowPixel):
            self.input_types = 'raw'
        elif isinstance(self, ObjectClassificationWorkflowBinary):
            self.input_types = 'raw+binary'
        elif isinstance(self, ObjectClassificationWorkflowPrediction):
            self.input_types = 'raw+pmaps'

        # our main applets
        self.objectExtractionApplet = ObjectExtractionApplet(workflow=self, name="Object Feature Selection")
        self.objectClassificationApplet = ObjectClassificationApplet(workflow=self)
        self.dataExportApplet = ObjectClassificationDataExportApplet(self, "Object Information Export")
        self.dataExportApplet.set_exporting_operator(self.objectClassificationApplet.topLevelOperator)

        # Customization hooks
        self.dataExportApplet.prepare_for_entire_export = self.prepare_for_entire_export
        #self.dataExportApplet.prepare_lane_for_export = self.prepare_lane_for_export
        self.dataExportApplet.post_process_lane_export = self.post_process_lane_export
        self.dataExportApplet.post_process_entire_export = self.post_process_entire_export

        opDataExport = self.dataExportApplet.topLevelOperator
        opDataExport.WorkingDirectory.connect(self.dataSelectionApplet.topLevelOperator.WorkingDirectory)

        # See EXPORT_SELECTION_PREDICTIONS and EXPORT_SELECTION_PROBABILITIES, above
        export_selection_names = ['Object Predictions',
                                  'Object Probabilities',
                                  'Blockwise Object Predictions',
                                  'Blockwise Object Probabilities']
        if self.input_types == 'raw':
            # Re-configure to add the pixel probabilities option
            # See EXPORT_SELECTION_PIXEL_PROBABILITIES, above
            export_selection_names.append('Pixel Probabilities')
        opDataExport.SelectionNames.setValue(export_selection_names)

        self._batch_export_args = None
        self._batch_input_args = None
        self._export_args = None
        self.batchProcessingApplet = None
        if self.batch:
            self.batchProcessingApplet = BatchProcessingApplet(self,
                                                               "Batch Processing",
                                                               self.dataSelectionApplet,
                                                               self.dataExportApplet)

            if unused_args:
                # Additional export args (specific to the object classification workflow)
                export_arg_parser = argparse.ArgumentParser()
                export_arg_parser.add_argument( "--table_filename", help="The location to export the object feature/prediction CSV file.", required=False )
                export_arg_parser.add_argument( "--export_object_prediction_img", action="store_true" )
                export_arg_parser.add_argument( "--export_object_probability_img", action="store_true" )
                export_arg_parser.add_argument( "--export_pixel_probability_img", action="store_true" )

                # TODO: Support this, too, someday?
                #export_arg_parser.add_argument( "--export_object_label_img", action="store_true" )

                self._export_args, unused_args = export_arg_parser.parse_known_args(unused_args)
                if self.input_types != 'raw' and self._export_args.export_pixel_probability_img:
                    raise RuntimeError("Invalid command-line argument: \n"
                                       "--export_pixel_probability_img' can only be used with the combined "
                                       "'Pixel Classification + Object Classification' workflow.")

                # The three flags are mutually exclusive (True counts as 1).
                if sum([self._export_args.export_object_prediction_img,
                        self._export_args.export_object_probability_img,
                        self._export_args.export_pixel_probability_img]) > 1:
                    raise RuntimeError("Invalid command-line arguments: Only one type classification output can be exported at a time.")

                # We parse the export setting args first.  All remaining args are considered input files by the input applet.
                self._batch_export_args, unused_args = self.dataExportApplet.parse_known_cmdline_args(unused_args)
                self._batch_input_args, unused_args = self.batchProcessingApplet.parse_known_cmdline_args(unused_args)

                # For backwards compatibility, translate these special args into the standard syntax
                if self._export_args.export_object_prediction_img:
                    self._batch_input_args.export_source = "Object Predictions"
                if self._export_args.export_object_probability_img:
                    self._batch_input_args.export_source = "Object Probabilities"
                if self._export_args.export_pixel_probability_img:
                    self._batch_input_args.export_source = "Pixel Probabilities"

        self.blockwiseObjectClassificationApplet = BlockwiseObjectClassificationApplet(
            self, "Blockwise Object Classification", "Blockwise Object Classification")

        # Expose the remaining applets to the shell; the batch applet is
        # only present when --nobatch was not given.
        self._applets.append(self.objectExtractionApplet)
        self._applets.append(self.objectClassificationApplet)
        self._applets.append(self.dataExportApplet)
        if self.batchProcessingApplet:
            self._applets.append(self.batchProcessingApplet)
        self._applets.append(self.blockwiseObjectClassificationApplet)

        if unused_args:
            # Logger.warn() is a deprecated alias; use warning().
            logger.warning("Unused command-line args: {}".format(unused_args))
    def __init__(self, shell, headless,
                 workflow_cmdline_args,
                 project_creation_args,
                 *args, **kwargs):
        """
        Set up the applets of the object classification workflow and parse
        the command-line arguments.

        A ``graph`` keyword argument, if present, is used as the operator
        graph; otherwise a new one is created.  ``--fillmissing`` is taken
        from the project creation args, ``--nobatch`` from the session args.
        Session args the workflow parser does not consume are, when batch
        mode is on, offered in turn to the export arg parser, the data export
        applet and the batch processing applet; leftovers are logged.
        """
        if 'graph' in kwargs:
            shared_graph = kwargs.pop('graph')
        else:
            shared_graph = Graph()
        super().__init__(
            shell, headless, workflow_cmdline_args, project_creation_args,
            graph=shared_graph, *args, **kwargs)
        self.stored_object_classifier = None

        # ---- workflow-level command-line options ----
        parser = argparse.ArgumentParser()
        parser.add_argument(
            '--fillmissing',
            help="use 'fill missing' applet with chosen detection method",
            choices=['classic', 'svm', 'none'],
            default='none')
        parser.add_argument(
            '--nobatch',
            help="do not append batch applets",
            action='store_true',
            default=False)

        creation_opts, _ = parser.parse_known_args(project_creation_args)
        # The fill-missing mode is fixed when the project is created.
        self.fillMissing = creation_opts.fillmissing

        session_opts, leftover_args = parser.parse_known_args(workflow_cmdline_args)
        requested_fill = session_opts.fillmissing
        if requested_fill not in ('none', creation_opts.fillmissing):
            logger.error("Ignoring --fillmissing cmdline arg.  Can't specify a different fillmissing setting after the project has already been created.")

        self.batch = not session_opts.nobatch

        self._applets = []

        self.createInputApplets()

        if self.fillMissing != 'none':
            self.fillMissingSlicesApplet = FillMissingSlicesApplet(
                self,
                "Fill Missing Slices",
                "Fill Missing Slices",
                self.fillMissing)
            self._applets.append(self.fillMissingSlicesApplet)

        # ---- main applets ----
        self.objectExtractionApplet = ObjectExtractionApplet(
            workflow=self, name="Object Feature Selection")
        self.objectClassificationApplet = ObjectClassificationApplet(workflow=self)
        self.dataExportApplet = ObjectClassificationDataExportApplet(
            self, "Object Information Export")
        self.dataExportApplet.set_exporting_operator(
            self.objectClassificationApplet.topLevelOperator)

        # Export customization hooks (no per-lane "prepare" hook is installed).
        export_applet = self.dataExportApplet
        export_applet.prepare_for_entire_export = self.prepare_for_entire_export
        export_applet.post_process_lane_export = self.post_process_lane_export
        export_applet.post_process_entire_export = self.post_process_entire_export

        op_export = self.dataExportApplet.topLevelOperator
        op_export.WorkingDirectory.connect(
            self.dataSelectionApplet.topLevelOperator.WorkingDirectory)
        op_export.SelectionNames.setValue(self.ExportNames.asDisplayNameList())

        # Defaults; overwritten below when batch mode is on and args are given.
        self._batch_export_args = None
        self._batch_input_args = None
        self._export_args = None
        self.batchProcessingApplet = None

        self._applets.extend((
            self.objectExtractionApplet,
            self.objectClassificationApplet,
            self.dataExportApplet,
        ))

        if self.batch:
            self.batchProcessingApplet = BatchProcessingApplet(
                self,
                "Batch Processing",
                self.dataSelectionApplet,
                self.dataExportApplet)
            self._applets.append(self.batchProcessingApplet)

            if leftover_args:
                export_parser, _ = self.exportsArgParser
                self._export_args, leftover_args = export_parser.parse_known_args(leftover_args)

                # Export settings are parsed first; whatever then remains is
                # treated as input files by the batch processing applet.
                self._batch_export_args, leftover_args = (
                    self.dataExportApplet.parse_known_cmdline_args(leftover_args))
                self._batch_input_args, leftover_args = (
                    self.batchProcessingApplet.parse_known_cmdline_args(leftover_args))

                # Backwards compatibility: carry the special export arg over
                # into the standard syntax.
                self._batch_input_args.export_source = self._export_args.export_source

        self.blockwiseObjectClassificationApplet = BlockwiseObjectClassificationApplet(
            self,
            "Blockwise Object Classification",
            "Blockwise Object Classification")
        self._applets.append(self.blockwiseObjectClassificationApplet)

        if leftover_args:
            logger.warning("Unused command-line args: {}".format(leftover_args))