class OpStreamingHdf5Reader(Operator):
    """
    Read image data from a dataset inside an already-opened HDF5 file.

    ``setupOutputs`` copies the dataset's dtype, shape and axistags (plus the
    optional ``drange`` and ``display_mode`` attributes) into
    ``OutputImage.meta``.  ``execute`` reads only the requested roi from the
    dataset into the result array.
    """
    name = "OpStreamingHdf5Reader"
    category = "Reader"

    # The project hdf5 File object (already opened)
    Hdf5File = InputSlot(stype='hdf5File')

    # The internal path for project-local datasets
    InternalPath = InputSlot(stype='string')

    # Output data
    OutputImage = OutputSlot()

    H5EXTS = ['.h5', '.hdf5', '.ilp']

    class DatasetReadError(Exception):
        """Raised by setupOutputs() when the internal path is not in the file."""

        def __init__(self, internalPath):
            self.internalPath = internalPath
            self.msg = "Unable to open Hdf5 dataset: {}".format(internalPath)
            super(OpStreamingHdf5Reader.DatasetReadError,
                  self).__init__(self.msg)

    def __init__(self, *args, **kwargs):
        super(OpStreamingHdf5Reader, self).__init__(*args, **kwargs)
        # Set by setupOutputs(); execute() reads from it.
        self._hdf5File = None

    def setupOutputs(self):
        """
        Populate ``OutputImage.meta`` from the HDF5 dataset's metadata.

        Raises:
            OpStreamingHdf5Reader.DatasetReadError: if ``InternalPath`` is not
                present in the HDF5 file.
        """
        # Read the dataset meta-info from the HDF5 dataset
        self._hdf5File = self.Hdf5File.value
        internalPath = self.InternalPath.value

        if internalPath not in self._hdf5File:
            raise OpStreamingHdf5Reader.DatasetReadError(internalPath)

        # Look the dataset up once and reuse the handle below.
        dataset = self._hdf5File[internalPath]
        attrs = dataset.attrs

        try:
            # Read the axistags property without actually importing the data
            axistagsJson = attrs['axistags']  # Throws KeyError if 'axistags' can't be found
            axistags = vigra.AxisTags.fromJSON(axistagsJson)
            axisorder = ''.join(tag.key for tag in axistags)
            if '?' in axisorder:
                # Unknown axis keys are treated the same as missing axistags.
                raise KeyError('?')
        except KeyError:
            # No (usable) axistags found: fall back to a default ordering.
            axisorder = get_default_axisordering(dataset.shape)
            axistags = vigra.defaultAxistags(str(axisorder))

        assert len(axistags) == len(dataset.shape),\
            "Mismatch between shape {} and axisorder {}".format(dataset.shape, axisorder)

        # Configure our slot meta-info
        self.OutputImage.meta.dtype = dataset.dtype.type
        self.OutputImage.meta.shape = dataset.shape
        self.OutputImage.meta.axistags = axistags

        # If the dataset specifies a datarange, add it to the slot metadata
        if 'drange' in attrs:
            self.OutputImage.meta.drange = tuple(attrs['drange'])

        # Same for display_mode
        if 'display_mode' in attrs:
            self.OutputImage.meta.display_mode = str(attrs['display_mode'])

        total_volume = numpy.prod(numpy.array(dataset.shape))
        chunks = dataset.chunks
        if not chunks and total_volume > 1e8:
            # Large, unchunked datasets are flagged so clients can react.
            self.OutputImage.meta.inefficient_format = True
            logger.warning(
                "This dataset ({}{}) is NOT chunked.  "
                "Performance for 3D access patterns will be bad!".format(
                    self._hdf5File.filename, internalPath))
        if chunks:
            self.OutputImage.meta.ideal_blockshape = chunks

    def execute(self, slot, subindex, roi, result):
        """
        Read the region ``roi`` of the dataset into ``result``.

        When DEBUG logging is enabled for this module's logger, the read is
        timed and the timings are logged.
        """
        start_time = time.time()
        assert self._hdf5File is not None
        # Read the desired data directly from the hdf5File
        key = roi.toSlice()
        dataset = self._hdf5File[self.InternalPath.value]

        debug_enabled = logger.isEnabledFor(logging.DEBUG)
        timer = None
        if debug_enabled:
            logger.debug("Reading HDF5 block: [{}, {}]".format(
                roi.start, roi.stop))
            timer = Timer()
            timer.unpause()

        if result.flags.c_contiguous:
            # Contiguous destination: let h5py write straight into it.
            dataset.read_direct(result[...], key)
        else:
            result[...] = dataset[key]

        # The original compared getEffectiveLevel() >= DEBUG, which is true
        # for nearly every level; only do the timing work when DEBUG is on.
        if debug_enabled:
            elapsed_msec = 1000.0 * (time.time() - start_time)
            logger.debug("took %f msec." % elapsed_msec)

        if timer is not None:
            timer.pause()
            logger.debug("Completed HDF5 read in {} seconds: [{}, {}]".format(
                timer.seconds(), roi.start, roi.stop))

    def propagateDirty(self, slot, subindex, roi):
        """Mark the whole output dirty when the file or internal path changes."""
        if slot == self.Hdf5File or slot == self.InternalPath:
            self.OutputImage.setDirty(slice(None))
# Example #2
# 0
class OpClusterize(Operator):
    """
    Divide a computation over ``Input`` into blockwise node tasks and launch them.

    The blocking scheme comes from the dataset description file named by
    ``OutputDatasetDescription``; the launch settings come from the cluster
    config file named by ``ConfigFilePath``.  Requesting ``ReturnCode``
    launches one command per block that still needs computing and returns
    without monitoring the tasks.
    """
    Input = InputSlot()
    OutputDatasetDescription = InputSlot()
    ProjectFilePath = InputSlot(stype="filestring")
    ConfigFilePath = InputSlot(stype="filestring")

    ReturnCode = OutputSlot()

    class TaskInfo(object):
        """Plain record describing one node task."""
        taskName = None
        command = None
        subregion = None

    def setupOutputs(self):
        """Declare ``ReturnCode`` as a single boolean."""
        self.ReturnCode.meta.dtype = bool
        self.ReturnCode.meta.shape = (1, )

    def execute(self, slot, subindex, roi, result):
        """
        Launch a node task for every block that is neither available nor locked.

        Sets ``result[0]`` to True once all commands have been launched and
        returns ``result``.  The blockwise fileset is always closed on exit.
        """
        dtypeBytes = self._getDtypeBytes()
        totalBytes = dtypeBytes * numpy.prod(self.Input.meta.shape)
        totalMB = old_div(totalBytes, (1000 * 1000))
        logger.info(
            "Clusterizing computation of {} MB dataset, outputting according to {}"
            .format(totalMB, self.OutputDatasetDescription.value))

        configFilePath = self.ConfigFilePath.value
        self._config = parseClusterConfigFile(configFilePath)

        # Create the destination file if necessary
        blockwiseFileset, taskInfos = self._prepareDestination()

        try:
            # Figure out which work doesn't need to be recomputed (if any).
            # We don't attempt to process currently locked blocks.
            unneeded_rois = [
                block_roi for block_roi in taskInfos
                if (blockwiseFileset.getBlockStatus(block_roi[0])
                    == BlockwiseFileset.BLOCK_AVAILABLE
                    or blockwiseFileset.isBlockLocked(block_roi[0]))
            ]

            # Remove any tasks that we don't need to compute (they were finished in a previous run)
            for block_roi in unneeded_rois:
                logger.info("No need to run task: {} for roi: {}".format(
                    taskInfos[block_roi].taskName, block_roi))
                del taskInfos[block_roi]

            absWorkDir, _ = getPathVariants(
                self._config.server_working_directory,
                os.path.split(configFilePath)[0])
            if self._config.task_launch_server == "localhost":

                def localCommand(cmd):
                    # Run the child in the work dir without touching this
                    # process's own cwd (the old chdir/chdir pair was not
                    # restored if the call raised).
                    # NOTE: cmd is built from the cluster config file and is
                    # executed through the shell; the config must be trusted.
                    subprocess.call(cmd, shell=True, cwd=absWorkDir)

                launchFunc = localCommand
            else:
                # We use fabric for executing remote tasks
                # Import it here because it isn't required that the nodes can use it.
                import fabric.api as fab

                @fab.hosts(self._config.task_launch_server)
                def remoteCommand(cmd):
                    with fab.cd(absWorkDir):
                        fab.run(cmd)

                launchFunc = functools.partial(fab.execute, remoteCommand)

            # Spawn each task
            for taskInfo in taskInfos.values():
                logger.info("Launching node task: " + taskInfo.command)
                launchFunc(taskInfo.command)

            # Return immediately.  We do not attempt to monitor the task progress.
            result[0] = True
            return result
        finally:
            blockwiseFileset.close()

    def _prepareTaskInfos(self, roiList):
        """
        Build one ``TaskInfo`` per roi in ``roiList``.

        Returns an ``OrderedDict`` mapping ``(start_tuple, stop_tuple)`` to the
        ``TaskInfo`` whose command was formatted from the config's
        ``command_format``.  Reads ``self._config``, which ``execute`` sets.
        """
        # Divide up the workload into large pieces
        logger.info("Dividing into {} node jobs.".format(len(roiList)))

        configFilePath = self.ConfigFilePath.value

        # Check the command format string: We need to know where to put our args...
        commandFormat = self._config.command_format
        assert commandFormat.find("{task_args}") != -1

        # Output log directory might be a relative path (relative to config file).
        # It is the same for every task, so resolve and create it once.
        absLogDir, _ = getPathVariants(
            self._config.output_log_directory,
            os.path.split(configFilePath)[0])
        if not os.path.exists(absLogDir):
            os.makedirs(absLogDir)

        taskInfos = collections.OrderedDict()
        for roiIndex, roi in enumerate(roiList):
            roi = (tuple(roi[0]), tuple(roi[1]))
            taskInfo = OpClusterize.TaskInfo()
            taskInfo.subregion = SubRegion(None, start=roi[0], stop=roi[1])

            taskName = "J{:02}".format(roiIndex)

            commandArgs = [
                "--option_config_file=" + configFilePath,
                "--project=" + self.ProjectFilePath.value,
                '--_node_work_="' + Roi.dumps(taskInfo.subregion) + '"',
                "--process_name={}".format(taskName),
                "--output_description_file={}".format(
                    self.OutputDatasetDescription.value),
            ]

            taskOutputLogPath = os.path.join(absLogDir, taskName + ".log")

            allArgs = " " + " ".join(commandArgs) + " "
            taskInfo.taskName = taskName
            taskInfo.command = commandFormat.format(
                task_args=allArgs,
                task_name=taskName,
                task_output_file=taskOutputLogPath)
            taskInfos[roi] = taskInfo

        return taskInfos

    @staticmethod
    def _asBytes(text):
        """Return ``text`` as bytes (UTF-8 encoded unless it already is bytes)."""
        if isinstance(text, bytes):
            return text
        return text.encode('utf-8')

    def _prepareDestination(self):
        """
        Bring the dataset description in line with ``Input`` and open the fileset.

        - Axes, view shape, block shape, chunks and dtype of the description
          are adjusted to match ``Input.meta``; the description file is
          rewritten if anything changed.
        - If the blocking-scheme hash differs from the one stored in the
          original description, every block is marked NOT available.

        Returns:
            ``(blockwiseFileset, taskInfos)``: the opened ``BlockwiseFileset``
            and the ordered mapping of block rois to ``TaskInfo`` objects for
            all of its blocks.
        """
        originalDescription = BlockwiseFileset.readDescription(
            self.OutputDatasetDescription.value)
        datasetDescription = copy.deepcopy(originalDescription)
        inputShape = self.Input.meta.shape

        # Modify description fields as needed
        # -- axes
        datasetDescription.axes = "".join(
            self.Input.meta.getTaggedShape().keys())
        assert set(originalDescription.axes) == set(datasetDescription.axes), (
            "Can't prepare destination dataset: original dataset description listed "
            "axes as {}, but actual output axes are {}".format(
                originalDescription.axes, datasetDescription.axes))

        # -- shape
        datasetDescription.view_shape = list(inputShape)
        # -- block_shape: reorder to the new axis order, then clip to the input shape
        assert originalDescription.block_shape is not None
        originalBlockDims = collections.OrderedDict(
            zip(originalDescription.axes, originalDescription.block_shape))
        reorderedBlockShape = [
            originalBlockDims[a] for a in datasetDescription.axes
        ]
        datasetDescription.block_shape = [
            min(pair) for pair in zip(reorderedBlockShape, inputShape)
        ]
        # -- chunks: same treatment as block_shape
        if originalDescription.chunks is not None:
            originalChunkDims = collections.OrderedDict(
                zip(originalDescription.axes, originalDescription.chunks))
            reorderedChunks = [
                originalChunkDims[a] for a in datasetDescription.axes
            ]
            datasetDescription.chunks = [
                min(pair) for pair in zip(reorderedChunks, inputShape)
            ]
        # -- dtype
        if datasetDescription.dtype != self.Input.meta.dtype:
            dtype = self.Input.meta.dtype
            # isinstance (not an exact type check) so that numpy.dtype
            # subclass instances are unwrapped as well.
            if isinstance(dtype, numpy.dtype):
                dtype = dtype.type
            datasetDescription.dtype = dtype().__class__.__name__

        # Create a unique hash for this blocking scheme.
        # If it changes, we can't use any previous data.
        # hashlib requires bytes on Python 3, so encode text first.
        sha = hashlib.sha1()
        sha.update(self._asBytes(str(tuple(datasetDescription.block_shape))))
        sha.update(self._asBytes(datasetDescription.axes))
        sha.update(self._asBytes(datasetDescription.block_file_name_format))

        datasetDescription.hash_id = sha.hexdigest()

        if datasetDescription != originalDescription:
            descriptionFilePath = self.OutputDatasetDescription.value
            logger.info("Overwriting dataset description: {}".format(
                descriptionFilePath))
            BlockwiseFileset.writeDescription(descriptionFilePath,
                                              datasetDescription)
            with open(descriptionFilePath, "r") as f:
                logger.info(f.read())

        # Now open the dataset
        blockwiseFileset = BlockwiseFileset(
            self.OutputDatasetDescription.value)

        taskInfos = self._prepareTaskInfos(blockwiseFileset.getAllBlockRois())

        if blockwiseFileset.description.hash_id != originalDescription.hash_id:
            # Something about our blocking scheme changed.
            # Make sure all blocks are marked as NOT available.
            # (Just in case some were left over from a previous run.)
            for block_roi in taskInfos:
                blockwiseFileset.setBlockStatus(
                    block_roi[0], BlockwiseFileset.BLOCK_NOT_AVAILABLE)

        return blockwiseFileset, taskInfos

    def _determineCompletedBlocks(self, blockwiseFileset, taskInfos):
        """Return the rois in ``taskInfos`` whose block status is BLOCK_AVAILABLE."""
        return [
            block_roi for block_roi in taskInfos
            if blockwiseFileset.getBlockStatus(block_roi[0])
            == BlockwiseFileset.BLOCK_AVAILABLE
        ]

    def propagateDirty(self, slot, subindex, roi):
        """Any input change invalidates the return code."""
        self.ReturnCode.setDirty(slice(None))

    def _getDtypeBytes(self):
        """
        Return the size of the dataset dtype in bytes.
        """
        dtype = self.Input.meta.dtype
        if isinstance(dtype, numpy.dtype):
            # Make sure we're dealing with a type (e.g. numpy.float64),
            #  not a numpy.dtype
            dtype = dtype.type

        return dtype().nbytes
class OpExportSlot(Operator):
    """
    Export a slot 'as-is', i.e. no subregion, no dtype conversion, no normalization, no axis re-ordering, etc.
    For sequence export formats, the sequence is indexed by the axistags' FIRST axis.
    For example, txyzc produces a sequence of xyzc volumes.
    """
    Input = InputSlot()

    OutputFormat = InputSlot(value='hdf5') # string.  See formats, below
    OutputFilenameFormat = InputSlot() # A format string allowing {roi}, {t_start}, {t_stop}, etc (but not {nickname} or {dataset_dir})
    OutputInternalPath = InputSlot(value='exported_data')

    CoordinateOffset = InputSlot(optional=True) # Add an offset to the roi coordinates in the export path (useful if Input is a subregion of a larger dataset)

    ExportPath = OutputSlot()
    FormatSelectionErrorMsg = OutputSlot()

    _2d_exts = vigra.impex.listExtensions().split()

    # List all supported formats
    _2d_formats = [FormatInfo(ext, ext, 2, 2) for ext in _2d_exts]
    _3d_sequence_formats = [FormatInfo(ext + ' sequence', ext, 3, 3) for ext in _2d_exts]
    _3d_volume_formats = [ FormatInfo('multipage tiff', 'tiff', 3, 3) ]
    _4d_sequence_formats = [ FormatInfo('multipage tiff sequence', 'tiff', 4, 4) ]
    nd_format_formats = [ FormatInfo('hdf5', 'h5', 0, 5),
                          FormatInfo('compressed hdf5', 'h5', 0, 5),
                          FormatInfo('numpy', 'npy', 0, 5),
                          FormatInfo('dvid', '', 2, 5),
                          FormatInfo('blockwise hdf5', 'json', 0, 5) ]

    ALL_FORMATS = _2d_formats + _3d_sequence_formats + _3d_volume_formats\
                + _4d_sequence_formats + nd_format_formats

    def __init__(self, *args, **kwargs):
        super( OpExportSlot, self ).__init__(*args, **kwargs)
        self.progressSignal = OrderedSignal()

        # Set up the impl function lookup dict:
        # format name -> (file extension, export function)
        export_impls = {}
        export_impls['hdf5'] = ('h5', self._export_hdf5)
        export_impls['compressed hdf5'] = ('h5', partial(self._export_hdf5, True))
        export_impls['numpy'] = ('npy', self._export_npy)
        export_impls['dvid'] = ('', self._export_dvid)
        export_impls['blockwise hdf5'] = ('json', self._export_blockwise_hdf5)

        for fmt in self._2d_formats:
            export_impls[fmt.name] = (fmt.extension, partial(self._export_2d, fmt.extension) )

        for fmt in self._3d_sequence_formats:
            export_impls[fmt.name] = (fmt.extension, partial(self._export_3d_sequence, fmt.extension) )

        export_impls['multipage tiff'] = ('tiff', self._export_multipage_tiff)
        export_impls['multipage tiff sequence'] = ('tiff', self._export_multipage_tiff_sequence)
        self._export_impls = export_impls

        self.Input.notifyMetaChanged( self._updateFormatSelectionErrorMsg )

    def setupOutputs(self):
        """Declare both outputs as single objects; flag ExportPath NOTREADY for hdf5 without an internal path."""
        self.ExportPath.meta.shape = (1,)
        self.ExportPath.meta.dtype = object
        self.FormatSelectionErrorMsg.meta.shape = (1,)
        self.FormatSelectionErrorMsg.meta.dtype = object

        if self.OutputFormat.value in ('hdf5', 'compressed hdf5') and self.OutputInternalPath.value == "":
            self.ExportPath.meta.NOTREADY = True

    def execute(self, slot, subindex, roi, result):
        if slot == self.ExportPath:
            return self._executeExportPath(result)
        else:
            assert False, "Unknown output slot: {}".format( slot.name )

    def _executeExportPath(self, result):
        """
        Compute the export path and store it in ``result[0]``.

        The path is ``OutputFilenameFormat`` with the extension of the
        selected format, plus the internal dataset path for the hdf5 formats,
        with the known ``{roi}`` / ``{<axis>_start}`` / ``{<axis>_stop}``
        placeholders filled in.
        """
        path_format = self.OutputFilenameFormat.value
        file_extension = self._export_impls[ self.OutputFormat.value ][0]

        # Remove existing extension (if present) and add the correct extension (if any)
        if file_extension:
            path_format = os.path.splitext(path_format)[0]
            path_format += '.' + file_extension

        # Provide the TOTAL path (including dataset name)
        if self.OutputFormat.value in ('hdf5', 'compressed hdf5'):
            path_format += '/' + self.OutputInternalPath.value

        roi = numpy.array( roiFromShape(self.Input.meta.shape) )

        # Intermediate state can cause coordinate offset and input shape to be mismatched.
        # Just don't use the offset if it looks wrong.
        # (The client will provide a valid offset later on.)
        if self.CoordinateOffset.ready() and len(self.CoordinateOffset.value) == len(roi[0]):
            offset = self.CoordinateOffset.value
            # The closing parenthesis used to be misplaced, which made this
            # assert check the length of a boolean array instead.
            assert len(roi[0]) == len(offset)
            roi += offset
        optional_replacements = {}
        optional_replacements['roi'] = list(map(tuple, roi))
        for key, (start, stop) in zip( self.Input.meta.getAxisKeys(), roi.transpose() ):
            optional_replacements[key + '_start'] = start
            optional_replacements[key + '_stop'] = stop
        formatted_path = format_known_keys( path_format, optional_replacements )
        result[0] = formatted_path
        return result

    def _updateFormatSelectionErrorMsg(self, *args):
        """Recompute the format error message and publish it on its slot."""
        error_msg = self._get_format_selection_error_msg()
        self.FormatSelectionErrorMsg.setValue( error_msg )

    def _get_format_selection_error_msg(self, *args):
        """
        If the currently selected format does not support the input image format,
        return an error message stating why. Otherwise, return an empty string.
        """
        if not self.Input.ready():
            return "Input not ready"
        output_format = self.OutputFormat.value

        # These cases support all combinations.
        # A missing comma used to fuse 'compressed hdf5' and 'npy' into one
        # string.  'numpy' is the format name registered in ALL_FORMATS;
        # 'npy' is kept for backward compatibility.
        if output_format in ('hdf5', 'compressed hdf5', 'numpy', 'npy', 'blockwise hdf5'):
            return ""

        tagged_shape = self.Input.meta.getTaggedShape()
        output_dtype = self.Input.meta.dtype

        if output_format == 'dvid':
            # dvid requires a channel axis, which must come last.
            # Internally, we transpose it before sending it over the wire
            if list(tagged_shape.keys())[-1] != 'c':
                return "DVID requires the last axis to be channel."

            # Make sure DVID supports this dtype/channel combo.
            from libdvid.voxels import VoxelsMetadata
            axiskeys = self.Input.meta.getAxisKeys()
            # We reverse the axiskeys because the export operator (see below) uses transpose_axes=True
            reverse_axiskeys = "".join(reversed( axiskeys ))
            reverse_shape = tuple(reversed(self.Input.meta.shape))
            metainfo = VoxelsMetadata.create_default_metadata( reverse_shape,
                                                               output_dtype,
                                                               reverse_axiskeys,
                                                               0.0,
                                                               'nanometers' )
            try:
                metainfo.determine_dvid_typename()
            except Exception as ex:
                return str(ex)
            else:
                return ""

        return FormatValidity.check(tagged_shape,
                                    output_dtype,
                                    output_format)

    def propagateDirty(self, slot, subindex, roi):
        if slot == self.OutputFormat or slot == self.OutputFilenameFormat:
            self.ExportPath.setDirty()
        if slot == self.OutputFormat:
            self._updateFormatSelectionErrorMsg()

    def run_export_to_array(self):
        """
        Export the slot data to an array, instead of to disk.
        The data is computed blockwise, as necessary.
        The result is returned.
        """
        self.progressSignal(0)
        opExport = OpExportToArray(parent=self)
        try:
            opExport.progressSignal.subscribe(self.progressSignal)
            opExport.Input.connect(self.Input)
            return opExport.run_export_to_array()
        finally:
            opExport.cleanUp()
            self.progressSignal(100)

    def run_export(self):
        """
        Perform the export and WAIT for it to complete.
        If you want asynchronous execution, run this function in a request:

            req = Request( opExport.run_export )
            req.submit()
        """
        output_format = self.OutputFormat.value
        try:
            export_func = self._export_impls[output_format][1]
        except KeyError:
            raise Exception( "Unknown export format: {}".format( output_format ) )
        else:
            mkdir_p( PathComponents(self.ExportPath.value).externalDirectory )
            export_func()

    def _export_hdf5(self, compress=False):
        """Write ``Input`` to a fresh hdf5 file at ``ExportPath``, replacing any existing file."""
        import errno
        self.progressSignal( 0 )

        # Create and open the hdf5 file
        export_components = PathComponents(self.ExportPath.value)
        try:
            os.remove(export_components.externalPath)
        except OSError as ex:
            # It's okay if the file isn't there.
            if ex.errno != errno.ENOENT:
                raise
        try:
            with h5py.File(export_components.externalPath, 'w') as hdf5File:
                # Create a temporary operator to do the work for us
                opH5Writer = OpH5WriterBigDataset(parent=self)
                try:
                    opH5Writer.CompressionEnabled.setValue( compress )
                    opH5Writer.hdf5File.setValue( hdf5File )
                    opH5Writer.hdf5Path.setValue( export_components.internalPath )
                    opH5Writer.Image.connect( self.Input )

                    # The H5 Writer provides its own progress signal, so just connect ours to it.
                    opH5Writer.progressSignal.subscribe( self.progressSignal )

                    # Perform the export and block for it in THIS THREAD.
                    opH5Writer.WriteImage[:].wait()
                finally:
                    opH5Writer.cleanUp()
                    self.progressSignal(100)
        except IOError as ex:
            import sys
            msg = "\nException raised when attempting to export to {}: {}\n"\
                  .format( export_components.externalPath, str(ex) )
            sys.stderr.write(msg)
            raise

    def _export_npy(self):
        """Write ``Input`` to the npy file at ``ExportPath``."""
        self.progressSignal(0)
        export_path = self.ExportPath.value
        # Construct outside the try block: if construction fails there is
        # nothing to clean up, and the finally clause must not mask the error.
        opWriter = OpNpyWriter( parent=self )
        try:
            opWriter.Filepath.setValue( export_path )
            opWriter.Input.connect( self.Input )

            # Run the export in this thread
            opWriter.write()
        finally:
            opWriter.cleanUp()
            self.progressSignal(100)

    def _export_dvid(self):
        """Send ``Input`` to the DVID node url given by ``ExportPath``."""
        self.progressSignal(0)
        export_path = self.ExportPath.value

        opExport = OpExportDvidVolume( transpose_axes=True, parent=self )
        try:
            opExport.Input.connect( self.Input )
            opExport.NodeDataUrl.setValue( export_path )

            # Run the export in this thread
            opExport.run_export()
        finally:
            opExport.cleanUp()
            self.progressSignal(100)

    def _export_blockwise_hdf5(self):
        raise NotImplementedError

    def _export_2d(self, fmt):
        """Write ``Input`` as a single 2D image file at ``ExportPath``."""
        self.progressSignal(0)
        export_path = self.ExportPath.value
        opExport = OpExport2DImage( parent=self )
        try:
            opExport.progressSignal.subscribe(self.progressSignal)
            opExport.Filepath.setValue( export_path )
            opExport.Input.connect( self.Input )

            # Run the export
            opExport.run_export()
        finally:
            opExport.cleanUp()
            self.progressSignal(100)

    def _export_3d_sequence(self, extension):
        """Write ``Input`` as an image sequence whose files use ``extension``."""
        self.progressSignal(0)
        export_path_base, _ = os.path.splitext( self.ExportPath.value )
        export_path_pattern = export_path_base + "." + extension

        # Construct outside the try block so a constructor failure is not
        # masked by a NameError in the finally clause.
        opWriter = OpStackWriter( parent=self )
        try:
            opWriter.FilepathPattern.setValue( export_path_pattern )
            opWriter.Input.connect( self.Input )
            opWriter.progressSignal.subscribe( self.progressSignal )

            if self.CoordinateOffset.ready():
                # Offset the slice index by the offset of the stepped axis.
                step_axis = opWriter.get_nonsingleton_axes()[0]
                step_axis_index = self.Input.meta.getAxisKeys().index(step_axis)
                step_axis_offset = self.CoordinateOffset.value[step_axis_index]
                opWriter.SliceIndexOffset.setValue( step_axis_offset )

            # Run the export
            opWriter.run_export()
        finally:
            opWriter.cleanUp()
            self.progressSignal(100)

    def _export_multipage_tiff(self):
        """Write ``Input`` as one multipage tiff at ``ExportPath``."""
        self.progressSignal(0)
        export_path = self.ExportPath.value
        # Construct outside the try block so a constructor failure is not
        # masked by a NameError in the finally clause.
        opExport = OpExportMultipageTiff( parent=self )
        try:
            opExport.Filepath.setValue( export_path )
            opExport.Input.connect( self.Input )
            opExport.progressSignal.subscribe( self.progressSignal )

            # Run the export
            opExport.run_export()
        finally:
            opExport.cleanUp()
            self.progressSignal(100)

    def _export_multipage_tiff_sequence(self):
        """Write ``Input`` as a sequence of multipage tiff files."""
        self.progressSignal(0)
        export_path_base, _ = os.path.splitext( self.ExportPath.value )
        export_path_pattern = export_path_base + ".tiff"

        # Construct outside the try block so a constructor failure is not
        # masked by a NameError in the finally clause.
        opExport = OpExportMultipageTiffSequence( parent=self )
        try:
            opExport.FilepathPattern.setValue( export_path_pattern )
            opExport.Input.connect( self.Input )
            opExport.progressSignal.subscribe( self.progressSignal )

            if self.CoordinateOffset.ready():
                # Offset the slice index by the offset of the stepped axis.
                step_axis = opExport.get_nonsingleton_axes()[0]
                step_axis_index = self.Input.meta.getAxisKeys().index(step_axis)
                step_axis_offset = self.CoordinateOffset.value[step_axis_index]
                opExport.SliceIndexOffset.setValue( step_axis_offset )

            # Run the export
            opExport.run_export()
        finally:
            opExport.cleanUp()
            self.progressSignal(100)
# Example #4
# 0
class OpMeanProjection(Operator):
    """
    Collapse one axis of the input image by taking the mean along it.

    The axis to reduce is read from the ``Axis`` slot.  The output has that
    axis removed from both its shape and its axistags.
    """
    name = "OpMeanProjection"
    category = "Pointwise"

    Input = InputSlot()

    Axis = InputSlot(value=0, stype="int")

    Output = OutputSlot()

    def __init__(self, *args, **kwargs):
        super(OpMeanProjection, self).__init__(*args, **kwargs)

        # Counter keyed by operator name; published via Output.meta.generation
        # and bumped whenever an input goes dirty.
        self._generation = {self.name: 0}

    def setupOutputs(self):
        """Derive the output meta-info from the input, minus the projected axis."""
        projected_axis = self.Axis.value

        # Start from a copy of the input metadata
        self.Output.meta.assignFrom(self.Input.meta)

        # Keep every axistag except the one being projected away
        kept_tags = list(
            nanshe.util.iters.iter_with_skip_indices(self.Output.meta.axistags,
                                                     projected_axis))
        self.Output.meta.axistags = vigra.AxisTags(*kept_tags)

        # Drop the projected axis from the shape as well
        leading = self.Output.meta.shape[:projected_axis]
        trailing = self.Output.meta.shape[projected_axis + 1:]
        self.Output.meta.shape = leading + trailing

        self.Output.meta.generation = self._generation

    def execute(self, slot, subindex, roi, result):
        """
        Fetch the input region matching ``roi`` over the full extent of the
        projected axis, average along that axis and write it to ``result``.
        """
        axis = self.Axis.value
        input_shape = self.Input.meta.shape

        assert (axis < len(input_shape))

        # Re-insert the projected axis into the requested slicing so the
        # whole extent of that axis is read from the input.
        input_key = list(roi.toSlice())
        input_key.insert(axis, slice(None))
        input_key[axis] = nanshe.util.iters.reformat_slice(
            input_key[axis], input_shape[axis])

        raw = self.Input[tuple(input_key)].wait()

        processed = raw.mean(axis=axis)

        if slot.name == 'Output':
            result[...] = processed

    def setInSlot(self, slot, subindex, roi, value):
        """Nothing to do when a value is pushed into an input slot."""
        pass

    def propagateDirty(self, slot, subindex, roi):
        """Bump the generation counter and forward dirtiness to ``Output``."""
        slot_name = slot.name

        if slot_name == "Input":
            self._generation[self.name] += 1

            axis = self.Axis.value

            # The output has no projected axis, so drop it from the dirty region.
            full_slicing = tuple(roi.toSlice())
            reduced_slicing = full_slicing[:axis] + full_slicing[axis + 1:]

            self.Output.setDirty(reduced_slicing)
        elif slot_name == "Axis":
            # A different axis changes everything.
            self._generation[self.name] += 1
            self.Output.setDirty(slice(None))
        else:
            assert False, "Unknown dirty input slot"
class _OpVigraLabelVolume(Operator):
    """
    Operator that simply wraps vigra's connected-component labeling functions
    (labelImageWithBackground for 2D data, labelVolumeWithBackground otherwise).

    The input may carry a time axis and/or a channel axis, but each of them
    must have length 1.  The output copies the input meta-data, except that
    its dtype is uint32.
    """
    name = "OpVigraLabelVolume"
    category = "Vigra"

    Input = InputSlot()
    BackgroundValue = InputSlot(optional=True)

    Output = OutputSlot()

    def setupOutputs(self):
        """
        Check the input shape constraints and configure the output meta-data.
        """
        inputShape = self.Input.meta.shape

        # Must have at most 1 time slice.
        # (An index equal to len(inputShape) is accepted as "axis not present".)
        timeIndex = self.Input.meta.axistags.index('t')
        assert timeIndex == len(inputShape) or inputShape[timeIndex] == 1

        # Must have at most 1 channel
        channelIndex = self.Input.meta.axistags.channelIndex
        assert channelIndex == len(inputShape) or inputShape[channelIndex] == 1

        self.Output.meta.assignFrom(self.Input.meta)
        self.Output.meta.dtype = numpy.uint32

    def execute(self, slot, subindex, roi, destination):
        """
        Label the requested roi and write the labels into ``destination``.

        Returns ``destination``.
        """
        assert slot == self.Output

        resultView = destination.view(vigra.VigraArray)
        resultView.axistags = self.Input.meta.axistags

        inputData = self.Input(roi.start, roi.stop).wait()
        inputData = inputData.view(vigra.VigraArray)
        inputData.axistags = self.Input.meta.axistags

        # Drop the time axis, which vigra.labelVolume doesn't remove automatically
        axiskeys = [tag.key for tag in inputData.axistags]
        if 't' in axiskeys:
            inputData = inputData.bindAxis('t', 0)
            resultView = resultView.bindAxis('t', 0)

        # Drop the channel axis, too.
        if 'c' in axiskeys:
            inputData = inputData.bindAxis('c', 0)
            resultView = resultView.bindAxis('c', 0)

        # I have no idea why, but vigra sometimes throws a precondition error if this line is present.
        # ...on the other hand, I can't remember why I added this line in the first place...
        # inputData = inputData.view(numpy.ndarray)

        if self.BackgroundValue.ready():
            bg = self.BackgroundValue.value
            if isinstance(bg, numpy.ndarray):
                # If background value was given as a 1-element array, extract it.
                assert bg.size == 1
                bg = bg.squeeze()[()]
            # Convert numpy scalars to plain Python numbers.
            # NOTE: this used to test against ``numpy.float``, which was merely an
            # alias of the builtin ``float`` and has been removed from NumPy 1.24+.
            # Testing against ``float`` keeps the exact same semantics
            # (numpy.float64 is a subclass of ``float``).
            if isinstance(bg, float):
                bg = float(bg)
            else:
                bg = int(bg)
            if len(inputData.shape) == 2:
                vigra.analysis.labelImageWithBackground(inputData, background_value=bg, out=resultView)
            else:
                vigra.analysis.labelVolumeWithBackground(inputData, background_value=bg, out=resultView)
        else:
            # No background value supplied: rely on vigra's default.
            if len(inputData.shape) == 2:
                vigra.analysis.labelImageWithBackground(inputData, out=resultView)
            else:
                vigra.analysis.labelVolumeWithBackground(inputData, out=resultView)

        return destination

    def propagateDirty(self, inputSlot, subindex, roi):
        """
        Mark the entire output dirty whenever either input changes.
        """
        if inputSlot == self.Input:
            # If anything changed, the whole image is now dirty
            #  because a single pixel change can trigger a cascade of relabeling.
            self.Output.setDirty(slice(None))
        elif inputSlot == self.BackgroundValue:
            self.Output.setDirty(slice(None))
class OpTrainClassifierBlocked(Operator):
    """
    Front-end for two child training operators, one for 'vectorwise' and one
    for 'pixelwise' classifier types.  The ``Classifier`` output is connected
    to whichever child matches the type of the ``ClassifierFactory`` input.
    """
    Images = InputSlot(level=1)
    Labels = InputSlot(level=1)
    ClassifierFactory = InputSlot()
    nonzeroLabelBlocks = InputSlot(level=1)  # Only connected to the pixelwise child.
    MaxLabel = InputSlot()

    Classifier = OutputSlot()

    def __init__(self, *args, **kwargs):
        """
        Create both child operators and connect all of their inputs; the
        output is connected later, in setupOutputs.
        """
        super(OpTrainClassifierBlocked, self).__init__(*args, **kwargs)
        self.progressSignal = OrderedSignal()
        self._mode = None

        # Vectorwise child: fully connected up front.
        self._opVectorwiseTrain = OpTrainVectorwiseClassifierBlocked(
            parent=self)
        vectorwise = self._opVectorwiseTrain
        vectorwise.Images.connect(self.Images)
        vectorwise.Labels.connect(self.Labels)
        vectorwise.ClassifierFactory.connect(self.ClassifierFactory)
        vectorwise.MaxLabel.connect(self.MaxLabel)
        vectorwise.progressSignal.subscribe(self.progressSignal)

        # Pixelwise child: fully connected up front as well.
        self._opPixelwiseTrain = OpTrainPixelwiseClassifierBlocked(parent=self)
        pixelwise = self._opPixelwiseTrain
        pixelwise.Images.connect(self.Images)
        pixelwise.Labels.connect(self.Labels)
        pixelwise.ClassifierFactory.connect(self.ClassifierFactory)
        pixelwise.nonzeroLabelBlocks.connect(self.nonzeroLabelBlocks)
        pixelwise.MaxLabel.connect(self.MaxLabel)
        pixelwise.progressSignal.subscribe(self.progressSignal)

    def setupOutputs(self):
        """
        Route ``Classifier`` to the child that matches the factory type.

        Raises a plain ``Exception`` if the factory implements neither the
        vectorwise nor the pixelwise factory interface.
        """
        factory_type = type(self.ClassifierFactory.value)

        if issubclass(factory_type, LazyflowVectorwiseClassifierFactoryABC):
            requested_mode = 'vectorwise'
            trainer = self._opVectorwiseTrain
        elif issubclass(factory_type, LazyflowPixelwiseClassifierFactoryABC):
            requested_mode = 'pixelwise'
            trainer = self._opPixelwiseTrain
        else:
            raise Exception(
                "Unknown classifier factory type: {}".format(factory_type))

        # Nothing to rewire if the mode did not change.
        if requested_mode == self._mode:
            return

        self.Classifier.disconnect()
        self._mode = requested_mode
        self.Classifier.connect(trainer.Classifier)

    def execute(self, slot, subindex, roi, result):
        """
        Never expected to run: the output is always connected to a child.
        """
        assert False, "Shouldn't get here..."

    def propagateDirty(self, slot, subindex, roi):
        """
        Intentionally does nothing.
        """
        pass
class OpVectorwiseClassifierPredict(Operator):
    """
    Produces per-class probability maps by handing the feature vector of every
    pixel (the last axis of ``Image``) to a vectorwise classifier.
    """
    Image = InputSlot()
    LabelsCount = InputSlot()
    Classifier = InputSlot()

    # Optional mask.  When it is entirely zero within the requested roi, the
    # request is answered with zeros and no prediction is run.  In every other
    # case the request is processed as usual and the mask is ignored.
    PredictionMask = InputSlot(optional=True)

    PMaps = OutputSlot()

    def __init__(self, *args, **kwargs):
        super(OpVectorwiseClassifierPredict, self).__init__(*args, **kwargs)

        # Removing the prediction mask invalidates the entire image.
        self.PredictionMask.notifyUnready(
            lambda unready_slot: self.PMaps.setDirty())

    def setupOutputs(self):
        """
        Configure the meta-data of ``PMaps`` from ``Image`` and ``LabelsCount``.
        """
        image_meta = self.Image.meta
        assert image_meta.getAxisKeys()[-1] == 'c'

        # We'll have at least 2 labels once we actually predict something.
        # Not setting it to 0 here is friendlier to possible downstream
        # ilastik operators; setting it to 2 causes errors in pixel
        # classification (live prediction doesn't work when only two labels
        # are present).
        nlabels = max(self.LabelsCount.value, 1)

        pmaps_meta = self.PMaps.meta
        pmaps_meta.assignFrom(image_meta)
        pmaps_meta.dtype = numpy.float32
        # FIXME: This assumes that channel is the last axis
        pmaps_meta.shape = image_meta.shape[:-1] + (nlabels, )
        pmaps_meta.drange = (0.0, 1.0)

        # Same ideal blockshape as the input, except for the channel extent.
        blockshape = image_meta.ideal_blockshape
        if blockshape is None:
            blockshape = (0, ) * len(image_meta.shape)
        blockshape = list(blockshape)
        blockshape[-1] = pmaps_meta.shape[-1]
        pmaps_meta.ideal_blockshape = tuple(blockshape)

        # Temporarily consumed RAM includes the following:
        # >> result array: 4 * N output_channels
        # >> (times 2 due to temporary variable)
        # >> input data allocation
        num_output_channels = nlabels
        num_input_channels = image_meta.shape[-1]
        factory = self.Classifier.meta.classifier_factory
        ram_per_pixelchannel = (
            factory.estimated_ram_usage_per_requested_predictionchannel())
        classifier_ram = ram_per_pixelchannel * num_output_channels
        feature_ram = max(image_meta.dtype().nbytes, 4) * num_input_channels
        pmaps_meta.ram_usage_per_requested_pixel = classifier_ram + feature_ram

    def execute(self, slot, subindex, roi, result):
        """
        Fill ``result`` with the predicted probabilities for ``roi``.

        The result is all zeros when there is no classifier, or when a
        prediction mask is present and is entirely zero within the roi.
        """
        classifier = self.Classifier.value

        # Training operator may return 'None' if there was no data to train with
        nothing_to_predict = classifier is None

        # Shortcut: If the mask is totally zero, skip this request entirely
        if not nothing_to_predict and self.PredictionMask.ready():
            # Same spatial extent as the request, but only channel 0 of the mask.
            mask_roi = numpy.array((roi.start, roi.stop))
            mask_roi[:, -1:] = [[0], [1]]
            mask_start, mask_stop = map(tuple, mask_roi)
            mask = self.PredictionMask(mask_start, mask_stop).wait()
            nothing_to_predict = not numpy.any(mask)
            del mask

        if nothing_to_predict:
            result[:] = 0.0
            return result

        assert issubclass(type(classifier), LazyflowVectorwiseClassifierABC), \
            "Classifier is of type {}, which does not satisfy the LazyflowVectorwiseClassifierABC interface."\
            .format(type(classifier))

        # Always fetch every feature channel, whatever output channels were asked for.
        all_channels = slice(0, self.Image.meta.shape[-1], None)
        feature_key = roi.toSlice()[:-1] + (all_channels, )

        with Timer() as features_timer:
            input_data = self.Image[feature_key].wait()

        input_data = numpy.asarray(input_data, numpy.float32)

        # Flatten to one row per pixel, one column per feature.
        image_shape = input_data.shape
        num_pixels = numpy.prod(image_shape[:-1])
        features = input_data.reshape((num_pixels, image_shape[-1]))

        with Timer() as prediction_timer:
            probabilities = classifier.predict_probabilities(features)

        logger.debug(
            "Features took {} seconds, Prediction took {} seconds for roi: {} : {}"
            .format(features_timer.seconds(), prediction_timer.seconds(),
                    roi.start, roi.stop))

        expected_classes = self.PMaps.meta.shape[-1]
        assert probabilities.shape[1] <= expected_classes, \
            "Error: Somehow the classifier has more label classes than expected:"\
            " Got {} classes, expected {} classes"\
            .format(probabilities.shape[1], expected_classes)

        # We're expecting a channel for each label class.
        # If we didn't provide at least one sample for each label,
        #  we may get back fewer channels.
        if probabilities.shape[1] < expected_classes:
            # Copy to an array of the correct shape
            # This is slow, but it's an unusual case
            assert probabilities.shape[-1] == len(classifier.known_classes)
            padded_shape = probabilities.shape[:-1] + (expected_classes, )
            padded = numpy.zeros(padded_shape, dtype=numpy.float32)
            for column, label in enumerate(classifier.known_classes):
                padded[:, label - 1] = probabilities[:, column]

            probabilities = padded

        # Reshape to image
        probabilities.shape = image_shape[:-1] + (expected_classes, )

        # Copy only the prediction channels the client requested.
        first_channel = roi.start[-1]
        last_channel = roi.stop[-1]
        result[...] = probabilities[..., first_channel:last_channel]
        return result

    def propagateDirty(self, slot, subindex, roi):
        """
        A changed classifier or image dirties the whole output; a changed
        prediction mask dirties only the affected roi.
        """
        if slot == self.Classifier:
            self.logger.debug("classifier changed, setting dirty")
            self.PMaps.setDirty()
        elif slot == self.Image:
            self.PMaps.setDirty()
        elif slot == self.PredictionMask:
            self.PMaps.setDirty(roi.start, roi.stop)
class OpPixelClassification(Operator):
    """
    Top-level operator for pixel classification.

    Instantiates the labeling pipeline, the training operator, a classifier
    value cache, the prediction pipeline and the feature-matrix caches, wires
    them together, and re-exports their outputs as its own output slots.
    """
    name = "OpPixelClassification"
    category = "Top-level"

    # Graph inputs

    InputImages = InputSlot(level=1)  # Original input data.  Used for display only.
    PredictionMasks = InputSlot(level=1, optional=True)  # Routed to OpClassifierPredict.PredictionMask.  See there for details.

    LabelInputs = InputSlot(optional=True, level=1)  # Input for providing label data from an external source

    FeatureImages = InputSlot(level=1)  # Computed feature images (each channel is a different feature)
    CachedFeatureImages = InputSlot(level=1)  # Cached feature data.

    FreezePredictions = InputSlot(stype='bool')
    ClassifierFactory = InputSlot(value=ParallelVigraRfLazyflowClassifierFactory(100))

    PredictionsFromDisk = InputSlot(optional=True, level=1)

    PredictionProbabilities = OutputSlot(level=1)  # Classification predictions (via feature cache for interactive speed)
    PredictionProbabilitiesUint8 = OutputSlot(level=1)  # Same thing, but converted to uint8 first

    PredictionProbabilityChannels = OutputSlot(level=2)  # Classification predictions, enumerated by channel
    SegmentationChannels = OutputSlot(level=2)  # Binary image of the final selections.

    LabelImages = OutputSlot(level=1)  # Labels from the user
    NonzeroLabelBlocks = OutputSlot(level=1)  # A list of slices that contain non-zero label values
    Classifier = OutputSlot()  # We provide the classifier as an external output for other applets to use

    CachedPredictionProbabilities = OutputSlot(level=1)  # Classification predictions (via feature cache AND prediction cache)

    HeadlessPredictionProbabilities = OutputSlot(level=1)  # Classification predictions (via no image caches, except for the classifier itself)
    HeadlessUint8PredictionProbabilities = OutputSlot(level=1)  # Same as above, but 0-255 uint8 instead of 0.0-1.0 float32
    HeadlessUncertaintyEstimate = OutputSlot(level=1)  # Same as uncertainty estimate, but does not rely on cached data.

    UncertaintyEstimate = OutputSlot(level=1)

    SimpleSegmentation = OutputSlot(level=1)  # For debug, for now

    # GUI-only (not part of the pipeline, but saved to the project)
    LabelNames = OutputSlot()
    LabelColors = OutputSlot()
    PmapColors = OutputSlot()
    Bookmarks = OutputSlot(level=1)

    NumClasses = OutputSlot()

    def setupOutputs(self):
        """
        Declare the meta-data of the GUI-only list outputs.
        """
        self.LabelNames.meta.dtype = object
        self.LabelNames.meta.shape = (1,)
        self.LabelColors.meta.dtype = object
        self.LabelColors.meta.shape = (1,)
        self.PmapColors.meta.dtype = object
        self.PmapColors.meta.shape = (1,)

    def __init__(self, *args, **kwargs):
        """
        Instantiate all internal operators and connect them together.
        """
        super(OpPixelClassification, self).__init__(*args, **kwargs)

        # Default values for some input slots
        self.FreezePredictions.setValue(True)
        self.LabelNames.setValue([])
        self.LabelColors.setValue([])
        self.PmapColors.setValue([])

        # SPECIAL connection: The LabelInputs slot doesn't get its data
        #  from the InputImages slot, but its shape must match.
        self.LabelInputs.connect(self.InputImages)

        # Hook up Labeling Pipeline
        self.opLabelPipeline = OpMultiLaneWrapper(OpLabelPipeline, parent=self, broadcastingSlotNames=['DeleteLabel'])
        self.opLabelPipeline.RawImage.connect(self.InputImages)
        self.opLabelPipeline.LabelInput.connect(self.LabelInputs)
        self.opLabelPipeline.DeleteLabel.setValue(-1)
        self.LabelImages.connect(self.opLabelPipeline.Output)
        self.NonzeroLabelBlocks.connect(self.opLabelPipeline.nonzeroBlocks)

        # Hook up the Training operator
        self.opTrain = OpTrainClassifierBlocked(parent=self)
        self.opTrain.ClassifierFactory.connect(self.ClassifierFactory)
        self.opTrain.Labels.connect(self.opLabelPipeline.Output)
        self.opTrain.Images.connect(self.FeatureImages)
        self.opTrain.nonzeroLabelBlocks.connect(self.opLabelPipeline.nonzeroBlocks)

        # Hook up the Classifier Cache
        # The classifier is cached here to allow serializers to force in
        #   a pre-calculated classifier (loaded from disk)
        self.classifier_cache = OpValueCache(parent=self)
        self.classifier_cache.name = "OpPixelClassification.classifier_cache"
        self.classifier_cache.inputs["Input"].connect(self.opTrain.outputs['Classifier'])
        self.classifier_cache.inputs["fixAtCurrent"].connect(self.FreezePredictions)
        self.Classifier.connect(self.classifier_cache.Output)

        # Hook up the prediction pipeline inputs
        self.opPredictionPipeline = OpMultiLaneWrapper(OpPredictionPipeline, parent=self)
        self.opPredictionPipeline.FeatureImages.connect(self.FeatureImages)
        self.opPredictionPipeline.CachedFeatureImages.connect(self.CachedFeatureImages)
        self.opPredictionPipeline.Classifier.connect(self.classifier_cache.Output)
        self.opPredictionPipeline.FreezePredictions.connect(self.FreezePredictions)
        self.opPredictionPipeline.PredictionsFromDisk.connect(self.PredictionsFromDisk)
        self.opPredictionPipeline.PredictionMask.connect(self.PredictionMasks)

        # Feature Selection Stuff
        self.opFeatureMatrixCaches = OpMultiLaneWrapper(OpFeatureMatrixCache, parent=self)
        self.opFeatureMatrixCaches.LabelImage.connect(self.opLabelPipeline.Output)
        self.opFeatureMatrixCaches.FeatureImage.connect(self.FeatureImages)
        self.opFeatureMatrixCaches.LabelImage.setDirty()  # do I still need this?

        def _updateNumClasses(*args):
            """
            When the number of labels changes, we MUST make sure that the prediction image changes its shape (the number of channels).
            Since setupOutputs is not called for mere dirty notifications, but is called in response to setValue(),
            we use this function to call setValue().
            """
            numClasses = len(self.LabelNames.value)
            self.opTrain.MaxLabel.setValue(numClasses)
            self.opPredictionPipeline.NumClasses.setValue(numClasses)
            self.NumClasses.setValue(numClasses)
        self.LabelNames.notifyDirty(_updateNumClasses)

        # Prediction pipeline outputs -> Top-level outputs
        self.PredictionProbabilities.connect(self.opPredictionPipeline.PredictionProbabilities)
        self.PredictionProbabilitiesUint8.connect(self.opPredictionPipeline.PredictionProbabilitiesUint8)
        self.CachedPredictionProbabilities.connect(self.opPredictionPipeline.CachedPredictionProbabilities)
        self.HeadlessPredictionProbabilities.connect(self.opPredictionPipeline.HeadlessPredictionProbabilities)
        self.HeadlessUint8PredictionProbabilities.connect(self.opPredictionPipeline.HeadlessUint8PredictionProbabilities)
        self.PredictionProbabilityChannels.connect(self.opPredictionPipeline.PredictionProbabilityChannels)
        self.SegmentationChannels.connect(self.opPredictionPipeline.SegmentationChannels)
        self.UncertaintyEstimate.connect(self.opPredictionPipeline.UncertaintyEstimate)
        self.SimpleSegmentation.connect(self.opPredictionPipeline.SimpleSegmentation)
        self.HeadlessUncertaintyEstimate.connect(self.opPredictionPipeline.HeadlessUncertaintyEstimate)

        def inputResizeHandler(slot, oldsize, newsize):
            # When the last input image is gone, empty these outputs as well.
            if newsize == 0:
                self.Bookmarks.resize(0)
                self.LabelImages.resize(0)
                self.NonzeroLabelBlocks.resize(0)
                self.PredictionProbabilities.resize(0)
                self.CachedPredictionProbabilities.resize(0)
        self.InputImages.notifyResized(inputResizeHandler)

        # Debug assertions: Check to make sure the non-wrapped operators stayed that way.
        assert self.opTrain.Images.operator == self.opTrain

        def handleNewInputImage(multislot, index, *args):
            def handleInputReady(slot):
                # Look up the lane at notification time: the 'index' captured
                # at insertion time is stale once an earlier lane was removed.
                laneIndex = multislot.index(slot)
                self._checkConstraints(laneIndex)
                self.setupCaches(laneIndex)
            multislot[index].notifyReady(handleInputReady)

        self.InputImages.notifyInserted(handleNewInputImage)

        # If any feature image changes shape, we need to verify that the
        #  channels are consistent with the currently cached classifier
        # Otherwise, delete the currently cached classifier.
        def handleNewFeatureImage(multislot, index, *args):
            def handleFeatureImageReady(slot):
                def handleFeatureMetaChanged(slot):
                    if (self.classifier_cache.fixAtCurrent.value and
                            self.classifier_cache.Output.ready() and
                            slot.meta.shape is not None):
                        classifier = self.classifier_cache.Output.value
                        channel_names = slot.meta.channel_names
                        if classifier and classifier.feature_names != channel_names:
                            self.classifier_cache.resetValue()
                slot.notifyMetaChanged(handleFeatureMetaChanged)
            multislot[index].notifyReady(handleFeatureImageReady)

        self.FeatureImages.notifyInserted(handleNewFeatureImage)

        def handleNewMaskImage(multislot, index, *args):
            def handleInputReady(slot):
                # Same as for input images: use the lane's current position,
                # not the position it had when it was inserted.
                self._checkConstraints(multislot.index(slot))
            multislot[index].notifyReady(handleInputReady)
        self.PredictionMasks.notifyInserted(handleNewMaskImage)

        # All input multi-slots should be kept in sync
        # Output multi-slots will auto-sync via the graph
        def insertSlot(a, b, position, finalsize):
            a.insertSlot(position, finalsize)

        def removeSlot(a, b, position, finalsize):
            a.removeSlot(position, finalsize)

        multiInputs = [s for s in self.inputs.values() if s.level >= 1]
        for s1 in multiInputs:
            for s2 in multiInputs:
                if s1 != s2:
                    # partial() binds s2 now, so there is no late-binding issue.
                    s1.notifyInserted(partial(insertSlot, s2))
                    s1.notifyRemoved(partial(removeSlot, s2))

    def setupCaches(self, imageIndex):
        """
        Resize ``LabelInputs`` to match ``InputImages`` and give lane
        ``imageIndex`` of ``LabelInputs`` the shape of the corresponding input
        image with the channel extent forced to 1.
        """
        numImages = len(self.InputImages)
        inputSlot = self.InputImages[imageIndex]
        self.LabelInputs.resize(numImages)

        # Special case: We have to set up the shape of our label *input* according to our image input shape
        shapeList = list(inputSlot.meta.shape)
        try:
            channelIndex = inputSlot.meta.axistags.index('c')
            shapeList[channelIndex] = 1
        except Exception:
            # Best effort: if the channel axis cannot be located, the shape is
            # used unchanged.  (Unlike a bare 'except', this no longer
            # swallows KeyboardInterrupt / SystemExit.)
            pass
        self.LabelInputs[imageIndex].meta.shape = tuple(shapeList)
        self.LabelInputs[imageIndex].meta.axistags = inputSlot.meta.axistags

    def _checkConstraints(self, laneIndex):
        """
        Ensure that all input images have the same number of channels and the
        same dimensionality (ignoring time), and that a prediction mask, if
        supplied, matches the shape of its image (ignoring the last axis).

        Raises DatasetConstraintError if a constraint is violated.
        Does nothing if the input image of the lane is not ready.
        """
        if not self.InputImages[laneIndex].ready():
            return

        thisLaneTaggedShape = self.InputImages[laneIndex].meta.getTaggedShape()

        # Find a different lane and use it for comparison
        validShape = thisLaneTaggedShape
        for i, slot in enumerate(self.InputImages):
            if slot.ready() and i != laneIndex:
                validShape = slot.meta.getTaggedShape()
                break

        # The time axis does not take part in the comparison.
        if 't' in thisLaneTaggedShape:
            del thisLaneTaggedShape['t']
        if 't' in validShape:
            del validShape['t']

        if validShape['c'] != thisLaneTaggedShape['c']:
            raise DatasetConstraintError(
                "Pixel Classification",
                "All input images must have the same number of channels.  "
                "Your new image has {} channel(s), but your other images have {} channel(s)."
                .format(thisLaneTaggedShape['c'], validShape['c']))

        if len(validShape) != len(thisLaneTaggedShape):
            raise DatasetConstraintError(
                "Pixel Classification",
                "All input images must have the same dimensionality.  "
                "Your new image has {} dimensions (including channel), but your other images have {} dimensions."
                .format(len(thisLaneTaggedShape), len(validShape)))

        mask_slot = self.PredictionMasks[laneIndex]
        input_shape = self.InputImages[laneIndex].meta.shape
        if mask_slot.ready() and mask_slot.meta.shape[:-1] != input_shape[:-1]:
            raise DatasetConstraintError(
                "Pixel Classification",
                "If you supply a prediction mask, it must have the same shape as the input image.  "
                "Your input image has shape {}, but your mask has shape {}."
                .format(input_shape, mask_slot.meta.shape))

    def setInSlot(self, slot, subindex, roi, value):
        # Nothing to do here: All inputs that support __setitem__
        #   are directly connected to internal operators.
        pass

    def propagateDirty(self, slot, subindex, roi):
        # Nothing to do here: All outputs are directly connected to
        #  internal operators that handle their own dirty propagation.
        pass

    def addLane(self, laneIndex):
        """
        Append a new image lane.  ``laneIndex`` must equal the current number
        of lanes.
        """
        numLanes = len(self.InputImages)
        assert numLanes == laneIndex, "Image lanes must be appended."
        self.InputImages.resize(numLanes + 1)
        self.Bookmarks.resize(numLanes + 1)
        self.Bookmarks[numLanes].setValue([])  # Default value

    def removeLane(self, laneIndex, finalLength):
        """
        Remove the image lane at ``laneIndex``.
        """
        self.InputImages.removeSlot(laneIndex, finalLength)
        self.Bookmarks.removeSlot(laneIndex, finalLength)

    def getLane(self, laneIndex):
        """
        Return an OperatorSubView of this operator for the given lane.
        """
        return OperatorSubView(self, laneIndex)

    def importLabels(self, laneIndex, slot):
        """
        Ingest label data from ``slot`` into the label array of the given lane
        and extend the label names and colors if new label values appeared.
        """
        # Load the data into the cache
        new_max = self.getLane(laneIndex).opLabelPipeline.opLabelArray.ingestData(slot)

        # Add to the list of label names if there's a new max label
        old_names = self.LabelNames.value
        old_max = len(old_names)
        if new_max > old_max:
            new_names = old_names + ["Label {}".format(x) for x in range(old_max + 1, new_max + 1)]
            self.LabelNames.setValue(new_names)

            # Make some default colors, too
            default_colors = [(255, 0, 0),
                              (0, 255, 0),
                              (0, 0, 255),
                              (255, 255, 0),
                              (255, 0, 255),
                              (0, 255, 255),
                              (128, 128, 128),
                              (255, 105, 180),
                              (255, 165, 0),
                              (240, 230, 140)]
            label_colors = self.LabelColors.value
            pmap_colors = self.PmapColors.value

            # One color per new label.  The palette is cycled when there are
            # more labels than palette entries, so the color lists always grow
            # by exactly as many entries as the name list did.
            palette_size = len(default_colors)
            new_colors = [default_colors[i % palette_size]
                          for i in range(old_max, new_max)]

            self.LabelColors.setValue(label_colors + new_colors)
            self.PmapColors.setValue(pmap_colors + new_colors)

    def mergeLabels(self, from_label, into_label):
        """
        Merge ``from_label`` into ``into_label`` in the label array of every lane.
        """
        for laneIndex in range(len(self.InputImages)):
            self.getLane(laneIndex).opLabelPipeline.opLabelArray.mergeLabels(from_label, into_label)

    def clearLabel(self, label_value):
        """
        Clear ``label_value`` from the label array of every lane.
        """
        for laneIndex in range(len(self.InputImages)):
            self.getLane(laneIndex).opLabelPipeline.opLabelArray.clearLabel(label_value)
class OpPredictionPipeline(OpPredictionPipelineNoCache):
    """
    Adds GUI-oriented outputs on top of the cacheless prediction pipeline it
    derives from.  These outputs are backed by caches, and the operator takes
    an additional input that supplies cached features.
    """
    FreezePredictions = InputSlot()
    CachedFeatureImages = InputSlot()

    PredictionProbabilities = OutputSlot()
    CachedPredictionProbabilities = OutputSlot()

    PredictionProbabilitiesUint8 = OutputSlot()

    PredictionProbabilityChannels = OutputSlot(level=1)
    SegmentationChannels = OutputSlot(level=1)
    UncertaintyEstimate = OutputSlot()

    def __init__(self, *args, **kwargs):
        """
        Create and connect the internal operators behind the GUI outputs.
        """
        super(OpPredictionPipeline, self).__init__(*args, **kwargs)

        # Prediction that reads from the CACHED feature images.
        self.predict = OpClassifierPredict(parent=self)
        self.predict.name = "OpClassifierPredict"
        self.predict.Classifier.connect(self.Classifier)
        self.predict.Image.connect(self.CachedFeatureImages)
        self.predict.PredictionMask.connect(self.PredictionMask)
        self.predict.LabelsCount.connect(self.NumClasses)
        self.PredictionProbabilities.connect(self.predict.PMaps)

        # uint8 variant of the predictions (instead of float).
        # Note that drange is automatically updated.
        self.opConvertToUint8 = OpPixelOperator(parent=self)
        self.opConvertToUint8.Input.connect(self.predict.PMaps)
        self.opConvertToUint8.Function.setValue(lambda a: (255*a).astype(numpy.uint8))
        self.PredictionProbabilitiesUint8.connect(self.opConvertToUint8.Output)

        # GUI prediction cache
        self.prediction_cache_gui = OpSlicedBlockedArrayCache(parent=self)
        self.prediction_cache_gui.name = "prediction_cache_gui"
        self.prediction_cache_gui.fixAtCurrent.connect(self.FreezePredictions)
        self.prediction_cache_gui.Input.connect(self.predict.PMaps)
        self.CachedPredictionProbabilities.connect(self.prediction_cache_gui.Output)

        # Expose every prediction channel as its own layer (for the GUI)
        self.opPredictionSlicer = OpMultiArraySlicer2(parent=self)
        self.opPredictionSlicer.name = "opPredictionSlicer"
        self.opPredictionSlicer.Input.connect(self.prediction_cache_gui.Output)
        self.opPredictionSlicer.AxisFlag.setValue('c')
        self.PredictionProbabilityChannels.connect(self.opPredictionSlicer.Slices)

        # Max-channel indicator of the cached predictions, sliced per channel
        self.opSegmentor = OpMaxChannelIndicatorOperator(parent=self)
        self.opSegmentor.Input.connect(self.prediction_cache_gui.Output)

        self.opSegmentationSlicer = OpMultiArraySlicer2(parent=self)
        self.opSegmentationSlicer.name = "opSegmentationSlicer"
        self.opSegmentationSlicer.Input.connect(self.opSegmentor.Output)
        self.opSegmentationSlicer.AxisFlag.setValue('c')
        self.SegmentationChannels.connect(self.opSegmentationSlicer.Slices)

        # Uncertainty estimate layer
        self.opUncertaintyEstimator = OpEnsembleMargin(parent=self)
        self.opUncertaintyEstimator.Input.connect(self.prediction_cache_gui.Output)

        # Cache the uncertainty so we get zeros for uncomputed points
        self.opUncertaintyCache = OpSlicedBlockedArrayCache(parent=self)
        self.opUncertaintyCache.name = "opUncertaintyCache"
        self.opUncertaintyCache.Input.connect(self.opUncertaintyEstimator.Output)
        self.opUncertaintyCache.fixAtCurrent.connect(self.FreezePredictions)
        self.UncertaintyEstimate.connect(self.opUncertaintyCache.Output)

    def setupOutputs(self):
        """
        Choose the cache block shapes according to the axis order of the
        feature images: one block shape each that is thin (extent 1) along
        x, y and z respectively.
        """
        axis_keys = [tag.key for tag in self.FeatureImages.meta.axistags]

        def block_shape_thin_along(thin_axis):
            # Block extent per axis key; the selected spatial axis gets extent 1.
            extents = {'t': 1, 'z': 256, 'y': 256, 'x': 256, 'c': 100}
            extents[thin_axis] = 1
            return tuple(extents[key] for key in axis_keys)

        block_shapes = (block_shape_thin_along('x'),
                        block_shape_thin_along('y'),
                        block_shape_thin_along('z'))

        self.prediction_cache_gui.BlockShape.setValue(block_shapes)
        self.opUncertaintyCache.BlockShape.setValue(block_shapes)

        assert self.opConvertToUint8.Output.meta.drange == (0,255)
class OpTrainCounter(Operator):
    """
    Training operator that fits the counting regressors.

    ``execute`` collects features and labels from all lanes and stores
    ``numRegressors`` fitted ``SVR`` objects in the ``Classifier`` output.
    """
    name = "TrainCounter"
    # NOTE(review): the text mentions a random forest, but execute() fits SVR objects -- confirm wording.
    description = "Train a random forest on multiple images"
    category = "Learning"

    # Definition of inputs:
    # Per-lane (level=1) inputs
    Images = InputSlot(level=1)
    ForegroundLabels = InputSlot(level=1)
    BackgroundLabels = InputSlot(level=1)
    nonzeroLabelBlocks = InputSlot(level=1)

    # While this is True, setupOutputs/propagateDirty leave the regressor untouched.
    fixClassifier = InputSlot(stype="bool")
    # Passed to the Gaussian smoothing of the foreground labels in execute().
    Sigma = InputSlot(stype="float", value=2.0)
    Epsilon = InputSlot(stype="float")
    C = InputSlot(stype="float")
    SelectedOption = InputSlot(stype="object")
    Ntrees = InputSlot(stype="int")  # RF parameter
    MaxDepth = InputSlot(stype="object")  # RF parameter, None means grow until purity
    BoxConstraintRois = InputSlot(level=1, stype="list", value=[])
    BoxConstraintValues = InputSlot(level=1, stype="list", value=[])
    UpperBound = InputSlot()

    # Definition of the outputs:
    Classifier = OutputSlot()
    options = SVR.options
    availableOptions = [checkOption(option["req"]) for option in SVR.options]
    # Number of regressors trained in execute(); also the length of the Classifier output.
    numRegressors = 4

    def __init__(self, *args, **kwargs):
        """
        Create the template regressor, seed the parameter slots from its
        defaults and install the lane bookkeeping callbacks.
        """
        super(OpTrainCounter, self).__init__(*args, **kwargs)
        self.progressSignal = OrderedSignal()
        self._svr = SVR()
        self.initInputs(self._svr.get_params())
        self.Classifier.meta.dtype = object
        self.Classifier.meta.shape = (self.numRegressors,)

        # Removing a lane normally emits no dirty notification. If the removed
        # lane ever carried label data, the classifier must be invalidated
        # anyway. Label slots that became dirty at least once are remembered
        # here as "touched".
        self._touched_slots = set()

        def on_lane_inserted(multislot, index, newlength):
            def remember_touched(slot, roi):
                self._touched_slots.add(slot)
            multislot[index].notifyDirty(remember_touched)

        def on_lane_removed(multislot, index, newlength):
            leaving = multislot[index]
            if leaving in self._touched_slots:
                # The lane being removed held label data: mark downstream dirty.
                self.Classifier.setDirty()
                self._touched_slots.remove(leaving)

        for label_slot in (self.ForegroundLabels, self.BackgroundLabels):
            label_slot.notifyInserted(on_lane_inserted)
        for label_slot in (self.ForegroundLabels, self.BackgroundLabels):
            label_slot.notifyRemove(on_lane_removed)

    def initInputs(self, params):
        """
        Copy regressor parameters into the corresponding input slots.

        ``fixClassifier`` is set to True while the values are written and is
        restored afterwards to its previous value (False if it was not ready).

        :param params: mapping providing the keys "Sigma", "epsilon", "C",
            "ntrees", "maxdepth" and "method".
        """
        previous_fix = self.fixClassifier.value if self.fixClassifier.ready() else False

        self.fixClassifier.setValue(True)
        slot_to_param = (
            ("Sigma", "Sigma"),
            ("Epsilon", "epsilon"),
            ("C", "C"),
            ("Ntrees", "ntrees"),
            ("MaxDepth", "maxdepth"),
            ("SelectedOption", "method"),
        )
        for slot_name, param_key in slot_to_param:
            getattr(self, slot_name).setValue(params[param_key])

        self.fixClassifier.setValue(previous_fix)

    def setupOutputs(self):
        """
        Forward the current parameter slot values to the template regressor.

        Nothing happens unless the ``fixClassifier`` value compares equal to False.
        """
        if not (self.inputs["fixClassifier"].value == False):
            return

        selected = self.SelectedOption.value
        # The option may arrive as a dict that carries the method name.
        if type(selected) is dict:
            selected = selected["method"]

        self._svr.set_params(
            method=selected,
            Sigma=self.Sigma.value,
            epsilon=self.Epsilon.value,
            C=self.C.value,
            ntrees=self.Ntrees.value,
            maxdepth=self.MaxDepth.value,
        )



    #@traceLogged(logger, level=logging.INFO, msg="OpTrainCounter: Training Counting Regressor")
    def execute(self, slot, subindex, roi, result):
        """
        Train ``numRegressors`` regressors from the labeled blocks of all lanes.

        For each lane whose foreground label slot has a shape, the labels are
        smoothed with a Gaussian (``Sigma``); smoothed labels, features and
        background labels are requested for every block listed in
        ``nonzeroLabelBlocks`` and turned into training samples by
        ``SVR.prepareDataRefactored``. The samples of all blocks are then
        merged (positive samples first, negative samples after) and used to
        fit the regressors in a ``RequestPool``.

        :param slot: requested output slot (not inspected).
        :param subindex: not used.
        :param roi: not used.
        :param result: container that receives one fitted ``SVR`` per index,
            or ``None`` in every entry when no training data is available.
        :returns: ``result``
        """
        progress = 0
        numImages = len(self.Images)
        # Guard all progress arithmetic against an operator without lanes,
        # which used to end in a ZeroDivisionError.
        progressDivisor = max(numImages, 1)
        self.progressSignal(progress)
        featMatrix = []
        labelsMatrix = []
        tagList = []

        for i, labels in enumerate(self.inputs["ForegroundLabels"]):
            if labels.meta.shape is None:
                continue

            opGaussian = OpGaussianSmoothing(parent=self, graph=self.graph)
            opGaussian.sigma.setValue(self.Sigma.value)
            opGaussian.Input.connect(self.ForegroundLabels[i])
            blocks = self.inputs["nonzeroLabelBlocks"][i][0].wait()

            reqlistlabels = []
            reqlistbg = []
            reqlistfeat = []
            progress += 10 // progressDivisor
            self.progressSignal(progress)

            for b in blocks[0]:
                # Features are requested with all channels (last axis).
                featurekey = list(b)
                featurekey[-1] = slice(None, None, None)
                reqlistlabels.append(opGaussian.Output[b])
                reqlistfeat.append(self.Images[i][featurekey])
                reqlistbg.append(self.inputs["BackgroundLabels"][i][b])

            traceLogger.debug("Requests prepared")

            numLabelBlocks = len(reqlistlabels)
            progress_outer = [progress]
            progressInc = 0
            if numLabelBlocks > 0:
                progressInc = (80 - 10) // (numLabelBlocks * progressDivisor)

            # Bind this lane's counter and increment as defaults so the
            # callback does not depend on variables rebound by later lanes.
            def progressNotify(req, _state=progress_outer, _inc=progressInc):
                _state[0] += _inc // 2
                self.progressSignal(_state[0])

            for req in reqlistfeat + reqlistlabels + reqlistbg:
                req.notify_finished(progressNotify)
                req.submit()

            traceLogger.debug("Requests fired")

            # Fixme: Maybe later request only part of the region?
            for labelReq, featReq, bgReq in zip(reqlistlabels, reqlistfeat, reqlistbg):
                labblock = labelReq.wait()
                image = featReq.wait()
                labbgblock = bgReq.wait()

                labblock = labblock.reshape(image.shape[:-1])
                image = image.reshape((-1, image.shape[-1]))
                # Flat indices of the pixels whose background label equals 2.
                labbgindices = np.where(labbgblock == 2)
                labbgindices = np.ravel_multi_index(labbgindices, labbgblock.shape)

                newDot, mapping, tags = \
                    self._svr.prepareDataRefactored(labblock, labbgindices)

                featMatrix.append(image[mapping])
                labelsMatrix.append(newDot[mapping])
                tagList.append(tags)

            progress = progress_outer[0]

            traceLogger.debug("Requests processed")

        # NOTE(review): the reported value shrinks with the number of images;
        # a constant 80 may have been intended -- confirm.
        self.progressSignal(80 / progressDivisor)

        if len(featMatrix) == 0 or len(labelsMatrix) == 0:
            result[:] = None
            return result

        posTags = [tag[0] for tag in tagList]
        negTags = [tag[1] for tag in tagList]
        numPosTags = np.sum(posTags)
        numNegTags = np.sum(negTags)
        numTags = numPosTags + numNegTags
        fullFeatMatrix = np.ndarray((numTags, self.Images[0].meta.shape[-1]), dtype=np.float64)
        fullLabelsMatrix = np.ndarray((numTags), dtype=np.float64)
        # np.nan instead of the np.NAN alias, which NumPy 2.0 removed.
        fullFeatMatrix[:] = np.nan
        fullLabelsMatrix[:] = np.nan

        # Positive samples of all blocks fill the first numPosTags rows,
        # negative samples the rows after them.
        currPosCount = 0
        currNegCount = numPosTags
        for feats, labs, posCount, negCount in zip(featMatrix, labelsMatrix, posTags, negTags):
            fullFeatMatrix[currPosCount:currPosCount + posCount, :] = feats[:posCount, :]
            fullLabelsMatrix[currPosCount:currPosCount + posCount] = labs[:posCount]
            fullFeatMatrix[currNegCount:currNegCount + negCount, :] = feats[posCount:, :]
            fullLabelsMatrix[currNegCount:currNegCount + negCount] = labs[posCount:]
            currPosCount += posCount
            currNegCount += negCount

        # Every row must have been written above.
        assert(not np.isnan(np.sum(fullFeatMatrix)))

        fullTags = [numPosTags, numNegTags]

        maxima = np.max(fullFeatMatrix, axis=0)
        minima = np.min(fullFeatMatrix, axis=0)
        normalizationFactors = (minima, maxima)

        boxConstraintList = []
        boxConstraints = None
        if self.BoxConstraintRois.ready() and self.BoxConstraintValues.ready():
            lanes = zip(self.BoxConstraintRois, self.BoxConstraintValues)
            for lane, (roiSlot, valueSlot) in enumerate(lanes):
                for constr, val in zip(roiSlot.value, valueSlot.value):
                    boxConstraintList.append((lane, constr, val))
            if len(boxConstraintList) > 0:
                boxConstraints = self.constructBoxConstraints(boxConstraintList)

        params = self._svr.get_params()
        try:
            pool = RequestPool()

            def train_and_store(i):
                result[i] = SVR(minmax=normalizationFactors, **params)
                result[i].fitPrepared(fullFeatMatrix, fullLabelsMatrix, tags=fullTags,
                                      boxConstraints=boxConstraints,
                                      numRegressors=self.numRegressors, trainAll=False)

            for i in range(self.numRegressors):
                pool.request(partial(train_and_store, i))

            pool.wait()
            pool.clean()

        except Exception:
            logger.error("ERROR: could not learn regressor")
            logger.error("fullFeatMatrix shape = {}, dtype = {}".format(fullFeatMatrix.shape, fullFeatMatrix.dtype) )
            logger.error("fullLabelsMatrix shape = {}, dtype = {}".format(fullLabelsMatrix.shape, fullLabelsMatrix.dtype) )
            raise
        finally:
            self.progressSignal(100)

        return result

    def propagateDirty(self, slot, subindex, roi):
        """
        Mark the whole ``Classifier`` output dirty when any input other than
        ``fixClassifier`` changes, provided ``fixClassifier`` equals False.
        """
        fix_slot = self.inputs["fixClassifier"]
        if slot is fix_slot:
            return
        if fix_slot.value == False:
            self.outputs["Classifier"].setDirty((slice(None),))
    
    def constructBoxConstraints(self, constraints):
        """
        Assemble the feature rows covered by the box constraints.

        :param constraints: sequence of ``(imagenumber, (start, stop), value)``
            tuples. Only the coordinates ``start[1:-2]`` / ``stop[1:-2]`` are
            used; the first two of them span the box.
        :returns: dict with
            "boxFeatures" -- matrix holding the features of all boxes row by row,
            "boxValues"   -- float64 array with one value per box,
            "boxIndices"  -- int array with the first row of every box plus the
            total row count as final entry;
            or ``None`` if anything went wrong (the error is logged).
        """
        try:
            extents = np.array([[stop - start for start, stop in zip(constr[0][1:-2], constr[1][1:-2])]
                                for _, constr, _ in constraints])
            taggedShape = self.Images[0].meta.getTaggedShape()
            numcols = taggedShape['c']
            # Rows per box = product of its two extents; summed over all boxes.
            totalrows = np.sum(extents[:, 0] * extents[:, 1], axis=0)
            constraintmatrix = np.ndarray(shape=(totalrows, numcols))
            constraintindices = []
            constraintvalues = []
            offset = 0
            for imagenumber, constr, value in constraints:
                slicing = [slice(start, stop) for start, stop in zip(constr[0][1:-2], constr[1][1:-2])]
                numrows = (slicing[0].stop - slicing[0].start) * (slicing[1].stop - slicing[1].start)
                # Take all entries along the trailing (feature) axis.
                slicing.append(slice(None))
                slicing = tuple(slicing)

                boxfeatures = self.Images[imagenumber][slicing].wait()
                constraintmatrix[offset:offset + numrows, :] = boxfeatures.reshape((numrows, -1))
                constraintindices.append(offset)
                constraintvalues.append(value)
                offset = offset + numrows
            constraintindices.append(offset)

            constraintvalues = np.array(constraintvalues, np.float64)
            # Builtin int: the np.int alias no longer exists in current NumPy,
            # and the resulting AttributeError used to be swallowed below,
            # silently discarding every box constraint.
            constraintindices = np.array(constraintindices, int)

            boxConstraints = {"boxFeatures": constraintmatrix,
                              "boxValues": constraintvalues,
                              "boxIndices": constraintindices}
        except Exception:
            # Best effort: training continues without box constraints.
            boxConstraints = None
            logger.exception("An error has occured with the box Constraints: {} ".format(constraints))

        return boxConstraints
class OpAnnotations(Operator):
    """
    Operator that keeps manually created tracking annotations.

    Track labels, divisions, appearances and disappearances are held in
    dictionaries on the instance; ``TrackImage`` and ``UntrackedImage`` are
    computed by relabeling ``LabelImage`` according to the stored track labels.
    """
    name = "Training"
    category = "other"

    # Image inputs
    BinaryImage = InputSlot()
    LabelImage = InputSlot()
    RawImage = InputSlot()
    ActiveTrack = InputSlot(stype='int', value=0)
    ObjectFeatures = InputSlot(stype=Opaque, rtype=List)
    DivisionProbabilities = InputSlot(stype=Opaque, rtype=List)
    DetectionProbabilities = InputSlot(stype=Opaque, rtype=List)
    MaxNumObj = InputSlot()
    ComputedFeatureNames = InputSlot(rtype=List, stype=Opaque)

    # Image and annotation outputs
    TrackImage = OutputSlot()
    Labels = OutputSlot(stype=Opaque, rtype=List)
    Divisions = OutputSlot(stype=Opaque, rtype=List)
    Appearances = OutputSlot(stype=Opaque)
    Disappearances = OutputSlot(stype=Opaque)
    UntrackedImage = OutputSlot()

    Annotations = OutputSlot(stype=Opaque)

    # Use a slot for storing the export settings in the project file.
    ExportSettings = OutputSlot()

    # Overrides of the ExportingOperator mixin functions.
    # NOTE(review): the mixin is not among the base classes listed above -- confirm.
    def configure_table_export_settings(self, settings, selected_features):
        """Store ``(settings, selected_features)`` as the value of ``ExportSettings``."""
        export_config = (settings, selected_features)
        self.ExportSettings.setValue(export_config)

    def get_table_export_settings(self):
        """
        Return the stored ``(settings, selected_features)`` pair, or
        ``(None, None)`` while ``ExportSettings`` is not ready.
        """
        if not self.ExportSettings.ready():
            return None, None

        settings, selected_features = self.ExportSettings.value
        return (settings, selected_features)

    def __init__(self, parent=None, graph=None):
        """
        Start with empty annotation stores, give the opaque output slots an
        empty dict as initial value and hook the constraint check to the
        readiness of the raw and binary images.
        """
        super(OpAnnotations, self).__init__(parent=parent, graph=graph)

        # Annotation stores
        self.labels = {}
        self.divisions = {}
        self.appearances = {}
        self.disappearances = {}

        # Every output slot receives its own empty dict.
        for output in (self.Annotations, self.Labels, self.Divisions,
                       self.Appearances, self.Disappearances):
            output.setValue({})

        for image in (self.RawImage, self.BinaryImage):
            image.notifyReady(self._checkConstraints)

        self.export_progress_dialog = None
        self.ExportSettings.setValue((None, None))

    def setupOutputs(self):
        """
        Copy the label image meta to the image outputs, make sure every time
        step has a label dictionary and declare the opaque outputs as
        single-element object slots.
        """
        label_meta = self.LabelImage.meta
        self.TrackImage.meta.assignFrom(label_meta)
        self.UntrackedImage.meta.assignFrom(label_meta)

        # One label dict per index along the first axis of the label image.
        for t in range(label_meta.shape[0]):
            self.labels.setdefault(t, {})

        for opaque_slot in (self.Annotations, self.Labels, self.Divisions,
                            self.Appearances, self.Disappearances):
            opaque_slot.meta.dtype = object
            opaque_slot.meta.shape = (1, )

    def initOutputs(self):
        """
        Copy the label image meta to the image outputs and make sure the label,
        appearance and disappearance stores have an entry for every time step.
        """
        label_meta = self.LabelImage.meta
        self.TrackImage.meta.assignFrom(label_meta)
        self.UntrackedImage.meta.assignFrom(label_meta)

        for t in range(label_meta.shape[0]):
            # Existing entries are kept; only missing time steps get an empty dict.
            for store in (self.labels, self.appearances, self.disappearances):
                store.setdefault(t, {})

    def _checkConstraints(self, *args):
        """
        Validate that the connected images can be used for tracking.

        Used as a ``notifyReady`` callback, so any positional arguments are ignored.

        :raises DatasetConstraintError: if the raw image or the label image has
            fewer than two time steps, or if both are ready and their tagged
            shapes differ in any axis other than 'c'.
        """
        needTimeSeriesMsg = (
            "For tracking, the dataset must have a time axis with at least 2 images.   "
            "Please load time-series data instead. See user documentation for details.")

        rawReady = self.RawImage.ready()
        labelReady = self.LabelImage.ready()

        if rawReady:
            rawTaggedShape = self.RawImage.meta.getTaggedShape()
            if rawTaggedShape['t'] < 2:
                raise DatasetConstraintError("Tracking", needTimeSeriesMsg)

        if labelReady:
            segmentationTaggedShape = self.LabelImage.meta.getTaggedShape()
            if segmentationTaggedShape['t'] < 2:
                raise DatasetConstraintError("Tracking", needTimeSeriesMsg)

        if rawReady and labelReady:
            # The channel axis is excluded from the comparison.
            rawTaggedShape['c'] = None
            segmentationTaggedShape['c'] = None
            if dict(rawTaggedShape) != dict(segmentationTaggedShape):
                # Report the shape of the slot that was actually compared
                # (LabelImage); BinaryImage was printed here before.
                raise DatasetConstraintError("Tracking",
                     "For tracking, the raw data and the prediction maps must contain the same "\
                     "number of timesteps and the same shape.   "\
                     "Your raw image has a shape of (t, x, y, z, c) = {}, whereas your prediction image has a "\
                     "shape of (t, x, y, z, c) = {}"\
                     .format( self.RawImage.meta.shape, self.LabelImage.meta.shape ) )

    def execute(self, slot, subindex, roi, result):
        """
        Compute the requested output.

        * ``Divisions``: returns a new dict ``{trackid: (children, t_parent)}``.
        * ``Labels``: returns a shallow copy of the per-time-step label dict.
        * ``TrackImage``: label image of the roi, with channel 0 relabeled by
          ``_relabel``; time steps without a label dict are filled with zeros.
        * ``UntrackedImage``: label image of the roi, with channel 0 relabeled
          by ``_relabelUntracked``.
        * ``Annotations`` / ``Appearances`` / ``Disappearances``: the value
          requested from the slot itself is written into ``result``.

        :param slot: the output slot being requested.
        :param subindex: not used.
        :param roi: requested region; its first axis is used as the time axis.
        :param result: destination for the image outputs.
        :returns: the filled ``result`` or, for ``Divisions``/``Labels``, a new dict.
        """
        key = roi.toSlice()
        if slot is self.Divisions:
            return {trackid: (children, t_parent)
                    for trackid, (children, t_parent) in self.divisions.items()}

        # The label data of the whole roi is fetched at most once per call
        # (it used to be requested again for every time step).
        fetched = []

        def label_data():
            if not fetched:
                fetched.append(self.LabelImage.get(roi).wait())
            return fetched[0]

        t_start = roi.start[0]

        if slot is self.Labels:
            result = dict(self.labels)

        elif slot is self.TrackImage:
            for t in range(t_start, roi.stop[0]):
                frame = t - t_start
                if t not in self.labels:
                    # No annotations for this time step: blank frame. Keep
                    # going, so the remaining frames of the roi are filled
                    # too (an early return used to leave them unwritten).
                    result[frame, ...] = 0
                    continue

                result[frame, ...] = label_data()[frame, ...]
                result[frame, ..., 0] = self._relabel(result[frame, ..., 0],
                                                      self.labels[t])

        elif slot is self.UntrackedImage:
            for t in range(t_start, roi.stop[0]):
                frame = t - t_start
                result[frame, ...] = label_data()[frame, ...]
                labels_at = self.labels[t] if t in self.labels else {}
                result[frame, ..., 0] = self._relabelUntracked(
                    result[frame, ..., 0], labels_at)

        if slot.name == 'Annotations':
            result[...] = self.Annotations[key].wait()
        elif slot.name == 'Appearances':
            result[...] = self.Appearances[key].wait()
        elif slot.name == 'Disappearances':
            result[...] = self.Disappearances[key].wait()

        return result

    def propagateDirty(self, slot, subindex, roi):
        """
        React to a dirty notification.

        A dirty ``LabelImage`` discards all stored annotations. For a slot
        named like one of the annotation outputs, the dirty roi is forwarded
        to that output. Every other slot is ignored.
        """
        if slot == self.LabelImage:
            self.labels = {}
            self.divisions = {}
            self.appearances = {}
            self.disappearances = {}
            return

        forwarded = ("Annotations", "Labels", "Divisions",
                     "Appearances", "Disappearances")
        if slot.name in forwarded:
            getattr(self, slot.name).setDirty(roi)

    def _relabel(self, volume, replace):
        """
        Map the object ids in ``volume`` to track ids.

        :param volume: integer array of object ids.
        :param replace: mapping from object id to a collection of track ids.
        :returns: array shaped like ``volume``; ids without a non-empty entry
            in ``replace`` become 0, a track id of -1 becomes ``2**16 - 1``.
        """
        # Lookup table indexed by object id; everything defaults to 0.
        lookup = np.arange(0, np.amax(volume) + 1, dtype=volume.dtype)
        lookup[1:] = 0

        present = np.sort(vigra.analysis.unique(volume)).tolist()
        if 0 in present:
            present.remove(0)

        for label in present:
            if label not in replace or len(replace[label]) == 0:
                continue
            # The last track id of the entry decides the new value.
            track_id = list(replace[label])[-1]
            lookup[label] = 2**16 - 1 if track_id == -1 else track_id

        return lookup[volume]

    def _relabelUntracked(self, volume, tracked_at):
        """
        Build a mask of the objects that have no track id yet.

        :param volume: integer array of object ids.
        :param tracked_at: mapping from object id to a collection of track ids.
        :returns: array shaped like ``volume`` that is 1 for every nonzero id
            without a non-empty entry in ``tracked_at`` and 0 elsewhere.
        """
        # Lookup table indexed by object id; every nonzero id starts as untracked.
        lookup = np.arange(0, np.amax(volume) + 1, dtype=volume.dtype)
        lookup[1:] = 1

        present = np.sort(vigra.analysis.unique(volume)).tolist()
        if 0 in present:
            present.remove(0)

        for label in present:
            if label in tracked_at and len(tracked_at[label]) > 0:
                lookup[label] = 0

        return lookup[volume]

    def _getObjects(self, trange, misdet_idx):
        """
        Collect the track ids of all labeled objects in a range of time steps.

        :param trange: ``(first, last)`` time steps; ``last`` is exclusive.
        :param misdet_idx: objects whose track ids contain this value are skipped.
        :returns: tuple ``(oid2tids, alltids)`` with
            ``oid2tids[t][oid]`` being the track ids of object ``oid`` at time
            ``t`` and ``alltids`` the set of all track ids encountered.
        """
        oid2tids = {}
        alltids = set()
        spatial_shape = self.LabelImage.meta.shape[1:]

        for t in range(trange[0], trange[1]):
            oid2tids[t] = {}
            found = 0

            # Request the complete label image of this single time step.
            troi = SubRegion(self.LabelImage,
                             start=[t] + [0] * len(spatial_shape),
                             stop=[t + 1] + list(spatial_shape))
            max_oid = np.max(self.LabelImage.get(troi).wait())

            labels_at_t = self.labels[t] if t in self.labels else {}
            for oid in range(1, max_oid + 2):
                if oid not in labels_at_t:
                    continue
                tids = labels_at_t[oid]
                if misdet_idx in tids:
                    continue
                oid2tids[t][oid] = tids
                alltids.update(tids)
                found += 1

            logger.info("at timestep {}, {} traxels found".format(t, found))

        return oid2tids, alltids

    def save_export_progress_dialog(self, dialog):
        """
        Keep a reference to the export progress dialog on this operator.

        Implements ExportOperator.save_export_progress_dialog; without the
        stored reference the progress dialog would be hidden after the export.

        :param dialog: the ProgressDialog to keep
        """
        self.export_progress_dialog = dialog

    @staticmethod
    def lookup_oid_for_tid(oid2tid, tid, t):
        mapping = oid2tid[t]
        for oid, tids in mapping.items():
            if tid in tids:
                return oid
        raise ValueError("TID {} at t={} not found!".format(tid, t))

    def do_export(self,
                  settings,
                  selected_features,
                  progress_slot,
                  lane_index,
                  filename_suffix=""):
        """
        Implements ExportOperator.do_export(settings, selected_features, progress_slot
        Most likely called from ExportOperator.export_object_data

        Writes the object table (ids, track ids, selected features), the
        division table (if there are divisions) and, for "h5" exports, the
        label/raw image data to the configured file.

        :param settings: the settings for the exporter; the keys read here are
            "file path", "file type", "compression" and, for "h5" exports,
            "margin" and "include raw".
        :param selected_features: feature selection handed to the feature table.
        :param progress_slot: callback subscribed to the export file's progress
            signals for the duration of the export.
        :param lane_index: Ignored. (This is a single-lane operator. It is the caller's responsibility to make sure he's calling the right lane.)
        :param filename_suffix: If provided, appended to the filename (before the extension).
        :return: None
        """

        obj_count = list(objects_per_frame(self.LabelImage))  # slow
        divisions = self.divisions
        t_range = (0, self.LabelImage.meta.shape[
            self.LabelImage.meta.axistags.index("t")])
        oid2tid, _ = self._getObjects(t_range, None)  # slow

        # Largest number of track ids attached to a single object (0 if none).
        track_counts = [len(tids)
                        for mapping in oid2tid.values()
                        for tids in mapping.values()]
        max_tracks = max(track_counts) if track_counts else 0
        ids = ilastik_ids(obj_count)

        file_path = settings["file path"]
        if filename_suffix:
            path, ext = os.path.splitext(file_path)
            file_path = path + "-" + filename_suffix + ext

        export_file = ExportFile(file_path)
        export_file.ExportProgress.subscribe(progress_slot)
        export_file.InsertionProgress.subscribe(progress_slot)

        # Unsubscribe even if the export fails, so that the progress callback
        # does not stay attached to a dead export.
        try:
            export_file.add_columns("table", list(range(sum(obj_count))),
                                    Mode.List, Default.KnimeId)
            export_file.add_columns("table", list(ids), Mode.List,
                                    Default.IlastikId)
            export_file.add_columns(
                "table", oid2tid, Mode.IlastikTrackingTable, {
                    "max": max_tracks,
                    "counts": obj_count,
                    "extra ids": {},
                    "range": t_range
                })
            export_file.add_columns("table", self.ObjectFeatures,
                                    Mode.IlastikFeatureTable,
                                    {"selection": selected_features})

            if divisions:
                ott = partial(self.lookup_oid_for_tid, oid2tid)
                # Row layout: (t_parent, parent oid, parent tid,
                #              child1 oid, child1 tid, child2 oid, child2 tid);
                # the children are looked up one time step after the parent.
                divs = [(t_parent, ott(parent_tid, t_parent), parent_tid,
                         ott(children[0], t_parent + 1), children[0],
                         ott(children[1], t_parent + 1), children[1])
                        for parent_tid, (children, t_parent) in sorted(
                            divisions.items(), key=itemgetter(0))]
                assert sum(Default.ManualDivMap) == len(divs[0])
                names = list(
                    compress(Default.DivisionNames["names"], Default.ManualDivMap))
                export_file.add_columns("divisions",
                                        divs,
                                        Mode.List,
                                        extra={"names": names})

            if settings["file type"] == "h5":
                export_file.add_rois(Default.LabelRoiPath, self.LabelImage,
                                     "table", settings["margin"], "labeling")
                if settings["include raw"]:
                    export_file.add_image(Default.RawPath, self.RawImage)
                else:
                    export_file.add_rois(Default.RawRoiPath, self.RawImage,
                                         "table", settings["margin"])
            export_file.write_all(settings["file type"], settings["compression"])
        finally:
            export_file.ExportProgress.unsubscribe(progress_slot)
            export_file.InsertionProgress.unsubscribe(progress_slot)
class OpLabelPreviewerRefactored(Operator):
    """
    Operator declaring a level-1 ``Images`` input and a level-1 ``Output`` slot.

    NOTE(review): the class body defines no setupOutputs/execute/propagateDirty;
    presumably a placeholder or unfinished -- confirm.
    """

    name = "LabelPreviewer"

    Images = InputSlot(level=1)
    Output = OutputSlot(level=1)
class OpPredictCounter(Operator):
    """
    Prediction operator: applies the regressors received through ``Classifier``
    to the features of ``Image`` and delivers one output channel per regressor
    in ``PMaps``.
    """
    name = "PredictCounter"
    description = "Predict on multiple images"
    category = "Learning"

    # Definition of inputs:
    Image = InputSlot()
    # Sequence of trained regressors; entries may be None if training had no data.
    Classifier = InputSlot()
    LabelsCount = InputSlot(stype='integer')

    # Definition of outputs:
    PMaps = OutputSlot()

    def setupOutputs(self):
        """
        Describe ``PMaps``: float32 data in [0, 1], the axistags of ``Image``
        and the shape of ``Image`` with the last axis replaced by
        ``OpTrainCounter.numRegressors`` entries.
        """
        # The labels count is read here but does not influence the output meta.
        labels_count = self.inputs["LabelsCount"].value

        image_meta = self.Image.meta
        pmaps_meta = self.PMaps.meta
        pmaps_meta.dtype = np.float32
        pmaps_meta.axistags = copy.copy(image_meta.axistags)
        # FIXME: This assumes that channel is the last axis
        pmaps_meta.shape = image_meta.shape[:-1] + (OpTrainCounter.numRegressors,)
        pmaps_meta.drange = (0.0, 1.0)

    def execute(self, slot, subindex, roi, result):
        """
        Predict with every regressor of ``Classifier`` on the requested region.

        The features of the region are requested with all channels, flattened
        to ``(pixels, channels)`` and passed to each regressor's ``predict``
        inside a ``RequestPool``. The predictions are stacked along a new last
        axis and the requested channel range is written to ``result``.

        :param slot: requested output slot (not inspected).
        :param subindex: not used.
        :param roi: requested region; the last axis is the channel axis.
        :param result: destination array for the predictions.
        :returns: ``result``, or a new all-zero float32 array of the roi's
            shape if any regressor is ``None``.
        """
        t1 = time.time()
        key = roi.toSlice()

        traceLogger.debug("OpPredictRandomForest: Requesting classifier. roi={}".format(roi))
        forests = self.inputs["Classifier"][:].wait()

        if any(forest is None for forest in forests):
            # Training operator may return 'None' if there was no data to train with
            return np.zeros(np.subtract(roi.stop, roi.start), dtype=np.float32)[...]

        traceLogger.debug("OpPredictRandomForest: Got classifier")

        # Same spatial region as requested, but with all feature channels.
        newKey = key[:-1]
        newKey += (slice(0, self.inputs["Image"].meta.shape[-1], None),)

        res = self.inputs["Image"][newKey].wait()

        shape = res.shape
        prod = np.prod(shape[:-1])
        res.shape = (prod, shape[-1])
        # Convert once instead of once per regressor.
        features = np.asarray(res, dtype=np.float32)

        predictions = [None] * len(forests)

        t2 = time.time()

        pool = RequestPool()

        def predict_forest(i):
            predictions[i] = forests[i].predict(features).reshape(result.shape[:-1])

        for i, _ in enumerate(forests):
            pool.request(partial(predict_forest, i))

        pool.wait()
        pool.clean()

        # Stack along a new last axis: unlike dstack this yields a
        # channel-last array for any number of leading axes.
        prediction = np.stack(predictions, axis=-1)
        # Deliver only the requested channels; a request for a subset of the
        # regressors used to fail with a shape mismatch.
        result[...] = prediction[..., key[-1]]

        t3 = time.time()

        logger.debug("Predict took %fseconds, actual RF time was %fs, feature time was %fs" % (t3-t1, t3-t2, t2-t1))
        return result



    def propagateDirty(self, slot, subindex, roi):
        """
        Translate a dirty input into a dirty region of ``PMaps``.

        * ``Classifier``: everything is dirty, provided ``LabelsCount`` is ready and positive.
        * ``Image``: the dirty region with the channel range ``0:LabelsCount``,
          provided ``LabelsCount`` is positive.
        * ``LabelsCount``: the outputs are set up again (if the operator is
          configured) and everything is marked dirty.
        """
        key = roi.toSlice()
        everything = slice(None, None, None)

        if slot == self.inputs["Classifier"]:
            logger.debug("OpPredictRandomForest: Classifier changed, setting dirty")
            if self.LabelsCount.ready() and self.LabelsCount.value > 0:
                self.outputs["PMaps"].setDirty(everything)
            return

        if slot == self.inputs["Image"]:
            nlabels = self.inputs["LabelsCount"].value
            if nlabels > 0:
                channels = (slice(0, nlabels, None),)
                self.outputs["PMaps"].setDirty(key[:-1] + channels)
            return

        if slot == self.inputs["LabelsCount"]:
            # When the labels count changes, we must resize the output
            if self.configured():
                # FIXME: Calling the 'private' _setupOutputs() is ugly, but the
                #  output shape has to change when this input becomes dirty,
                #  and that change must reach the rest of the graph.
                self._setupOutputs()
            self.outputs["PMaps"].setDirty(everything)
class OpSparseLabelArray(Operator, Cache):
    """
    Label cache that keeps a uint8 label volume in two forms: a dense numpy
    array (self._denseArray) and a sorted dict (self._sparseNZ) mapping
    raveled indices of the dense array to label values.

    Labels are written through setInSlot() and read back through the
    "Output" (dense), "nonzeroValues" / "nonzeroCoordinates" (sparse) and
    "maxLabel" slots.  Reads in execute(), writes in setInSlot() and the
    label deletion in setupOutputs() hold self.lock while they touch the
    internal arrays.
    """
    name = "Sparse Label Array"
    description = "simple cache for sparse label arrays"

    inputSlots = [
        InputSlot("Input", optional=True),
        InputSlot("shape"),
        InputSlot("eraser"),
        InputSlot("deleteLabel", optional=True)
    ]

    outputSlots = [
        OutputSlot("Output"),
        OutputSlot("nonzeroValues"),
        OutputSlot("nonzeroCoordinates"),
        OutputSlot("maxLabel")
    ]

    def __init__(self, *args, **kwargs):
        super(OpSparseLabelArray, self).__init__(*args, **kwargs)
        self.lock = threading.Lock()
        self._denseArray = None
        self._sparseNZ = None
        self._oldShape = (0, )
        self._maxLabel = 0

        # Now that we're initialized, it's safe to register with the memory manager
        self.registerWithMemoryManager()

    def usedMemory(self):
        """Return the size in bytes of the dense array (0 if not allocated)."""
        if self._denseArray is not None:
            return self._denseArray.nbytes
        return 0

    def lastAccessTime(self):
        """Always return 0; access times are not tracked by this cache."""
        return 0
        #return self._last_access

    def generateReport(self, report):
        """Fill `report` with name, memory usage, dtype, type and id of this cache."""
        report.name = self.name
        #report.fractionOfUsedMemoryDirty = self.fractionOfUsedMemoryDirty()
        report.usedMemory = self.usedMemory()
        #report.lastAccessTime = self.lastAccessTime()
        report.dtype = self.Output.meta.dtype
        report.type = type(self)
        report.id = id(self)

    def setupOutputs(self):
        """
        Allocate fresh storage when the "shape" input changed, and process a
        pending "deleteLabel" request.

        A deletion request (any value other than -1) resets the slot to -1,
        zeroes all pixels carrying that label, removes them from the sparse
        dict, shifts all higher labels down by one and marks the outputs
        dirty.
        """
        if (numpy.array(self._oldShape) != self.inputs["shape"].value).any():
            shape = self.inputs["shape"].value
            self._oldShape = shape
            self.outputs["Output"].meta.dtype = numpy.uint8
            self.outputs["Output"].meta.shape = shape

            # FIXME: Don't give arbitrary axistags.  Specify them correctly if you need them.
            #self.outputs["Output"].meta.axistags = vigra.defaultAxistags(len(shape))

            self.inputs["Input"].meta.shape = shape

            self.outputs["nonzeroValues"].meta.dtype = object
            self.outputs["nonzeroValues"].meta.shape = (1, )

            self.outputs["nonzeroCoordinates"].meta.dtype = object
            self.outputs["nonzeroCoordinates"].meta.shape = (1, )

            # A shape change discards all previously stored labels.
            self._denseArray = numpy.zeros(shape, numpy.uint8)
            self._sparseNZ = blist.sorteddict()

        if self.inputs["deleteLabel"].ready(
        ) and self.inputs["deleteLabel"].value != -1:
            labelNr = self.inputs["deleteLabel"].value

            neutralElement = 0
            self.inputs["deleteLabel"].setValue(-1)  #reset state of inputslot

            # The context manager guarantees the lock is released even if
            # one of the array operations below raises.
            with self.lock:
                # Find the entries to remove
                updateNZ = numpy.nonzero(
                    numpy.where(self._denseArray == labelNr, 1, 0))
                if len(updateNZ) > 0:
                    # Convert to 1-D indexes for the raveled version
                    updateNZRavel = numpy.ravel_multi_index(
                        updateNZ, self._denseArray.shape)
                    # Zero out the entries we don't want any more
                    self._denseArray.ravel()[updateNZRavel] = neutralElement
                    # Remove the zeros from the sparse list
                    for index in updateNZRavel:
                        self._sparseNZ.pop(index)
                # Labels are continuous values: Shift all higher label values down by 1.
                # NOTE(review): the values stored in self._sparseNZ are not
                #  shifted here -- confirm whether "nonzeroValues" is
                #  expected to stay in sync with the dense array.
                self._denseArray[:] = numpy.where(self._denseArray > labelNr,
                                                  self._denseArray - 1,
                                                  self._denseArray)
                self._maxLabel = self._denseArray.max()

            self.outputs["nonzeroValues"].setDirty(slice(None))
            self.outputs["nonzeroCoordinates"].setDirty(slice(None))
            self.outputs["Output"].setDirty(slice(None))
            self.outputs["maxLabel"].setValue(self._maxLabel)

    def execute(self, slot, subindex, roi, result):
        """
        Serve a read request for one of the output slots.

        "Output" copies the requested region of the dense array into
        `result`; the other slots store a single object in result[0]
        (sparse values, sparse raveled coordinates, or the max label).
        """
        key = roiToSlice(roi.start, roi.stop)

        # Hold the lock via `with` so that a failing assertion (or any other
        # exception) cannot leave it acquired and block later requests.
        with self.lock:
            assert (
                self.inputs["eraser"].ready() == True
                and self.inputs["shape"].ready() == True
            ), "OpDenseSparseArray:  One of the neccessary input slots is not ready: shape: %r, eraser: %r" % (
                self.inputs["shape"].ready(), self.inputs["eraser"].ready())
            if slot.name == "Output":
                result[:] = self._denseArray[key]
            elif slot.name == "nonzeroValues":
                result[0] = numpy.array(self._sparseNZ.values())
            elif slot.name == "nonzeroCoordinates":
                result[0] = numpy.array(self._sparseNZ.keys())
            elif slot.name == "maxLabel":
                result[0] = self._maxLabel
        return result

    def setInSlot(self, slot, subindex, roi, value):
        """
        Write the label block `value` into the region `roi` of the cache.

        Non-zero pixels of the written block are stored in both the dense
        array and the sparse dict; pixels equal to the current "eraser"
        value are then reset to 0 and removed from the sparse dict.
        Afterwards "maxLabel" is updated and the written region of "Output"
        is marked dirty.

        `value` must be a numpy.ndarray (or subclass) with the same dtype as
        the dense array.
        """
        key = roi.toSlice()
        assert value.dtype == self._denseArray.dtype, "Labels must be {}".format(
            self._denseArray.dtype)
        assert isinstance(value, numpy.ndarray)
        if type(value) != numpy.ndarray:
            # vigra.VigraArray doesn't handle advanced indexing correctly,
            #   so convert to numpy.ndarray first
            value = value.view(numpy.ndarray)

        shape = self.inputs["shape"].value
        eraseLabel = self.inputs["eraser"].value
        neutralElement = 0

        # The context manager guarantees the lock is released even if one
        # of the indexing operations below raises.
        with self.lock:
            #fix slicing of single dimensions:
            start, stop = sliceToRoi(key, shape, extendSingleton=False)
            start = start.floor()._asint()
            stop = stop.floor()._asint()

            tempKey = roiToSlice(start - start,
                                 stop - start)  #, hardBind = True)

            stop += numpy.where(stop - start == 0, 1, 0)

            key = roiToSlice(start, stop)

            updateShape = tuple(stop - start)

            # Work on a copy of the affected region with the new values
            # pasted in.
            update = self._denseArray[key].copy()

            update[tempKey] = value

            # Raveled index of the region's start corner in the full volume
            startRavel = numpy.ravel_multi_index(
                numpy.array(start, numpy.int32), shape)

            #insert values into dict
            updateNZ = numpy.nonzero(
                numpy.where(update != neutralElement, 1, 0))
            updateNZRavelSmall = numpy.ravel_multi_index(updateNZ, updateShape)

            if isinstance(value, numpy.ndarray):
                valuesNZ = value.ravel()[updateNZRavelSmall]
            else:
                valuesNZ = value

            # Region-local coordinates raveled against the full shape, plus
            # the raveled start corner, give indices into the full volume.
            updateNZRavel = numpy.ravel_multi_index(updateNZ, shape)
            updateNZRavel += startRavel

            self._denseArray.ravel()[updateNZRavel] = valuesNZ

            valuesNZ = self._denseArray.ravel()[updateNZRavel]

            self._denseArray.ravel()[updateNZRavel] = valuesNZ

            td = blist.sorteddict(
                zip(updateNZRavel.tolist(), valuesNZ.tolist()))

            self._sparseNZ.update(td)

            #remove values to be deleted
            updateNZ = numpy.nonzero(numpy.where(update == eraseLabel, 1, 0))
            if len(updateNZ) > 0:
                updateNZRavel = numpy.ravel_multi_index(updateNZ, shape)
                updateNZRavel += startRavel
                self._denseArray.ravel()[updateNZRavel] = neutralElement
                for index in updateNZRavel:
                    self._sparseNZ.pop(index)

            # Update our maxlabel
            self._maxLabel = self._denseArray.max()

        # Set our max label dirty if necessary
        self.outputs["maxLabel"].setValue(self._maxLabel)
        self.outputs["Output"].setDirty(key)

    def propagateDirty(self, dirtySlot, subindex, roi):
        """Forward dirtiness of "Input" to "Output"; ignore all other slots."""
        if dirtySlot == self.Input:
            self.Output.setDirty(roi)
        else:
            # All other inputs are single-value inputs that will trigger
            #  a new call to setupOutputs, which already sets the outputs dirty.
            # (See above.)
            pass
class OpTrainVectorwiseClassifierBlocked(Operator):
    """
    Composite operator that produces a classifier from lists of feature
    images and label images.

    All work is delegated to three internal operators that are wired
    together in __init__; the "Classifier" output is connected directly to
    the last of them.  Progress is published through self.progressSignal.
    """
    Images = InputSlot(level=1)
    Labels = InputSlot(level=1)
    ClassifierFactory = InputSlot()
    MaxLabel = InputSlot()

    Classifier = OutputSlot()

    # Images[N] ---                                                                                         MaxLabel ------
    #              \                                                                                                       \
    # Labels[N] --> opFeatureMatrixCaches ---(FeatureImage[N])---> opConcatenateFeatureImages ---(label+feature matrix)---> OpTrainFromFeatures ---(Classifier)--->

    def __init__(self, *args, **kwargs):
        super(OpTrainVectorwiseClassifierBlocked,
              self).__init__(*args, **kwargs)
        self.progressSignal = OrderedSignal()

        # Stage 1: one feature-matrix cache per (label image, feature image) pair.
        matrix_caches = self._opFeatureMatrixCaches = OperatorWrapper(
            OpFeatureMatrixCache, parent=self)
        matrix_caches.LabelImage.connect(self.Labels)
        matrix_caches.FeatureImage.connect(self.Images)

        # Stage 2: concatenate the per-image matrices.
        concatenate = self._opConcatenateFeatureMatrices = \
            OpConcatenateFeatureMatrices(parent=self)
        concatenate.FeatureMatrices.connect(
            matrix_caches.LabelAndFeatureMatrix)
        concatenate.ProgressSignals.connect(matrix_caches.ProgressSignal)

        # Stage 3: train on the concatenated matrix.
        trainer = self._opTrainFromFeatures = \
            OpTrainClassifierFromFeatureVectors(parent=self)
        trainer.ClassifierFactory.connect(self.ClassifierFactory)
        trainer.LabelAndFeatureMatrix.connect(concatenate.ConcatenatedOutput)
        trainer.MaxLabel.connect(self.MaxLabel)

        self.Classifier.connect(trainer.Classifier)

        # Progress reporting: concatenation progress is scaled by 0.8,
        # and a completed training reports 100.
        def _forward_feature_progress(progress):
            self.progressSignal(0.8 * progress)

        def _report_training_done():
            self.progressSignal(100.0)

        concatenate.progressSignal.subscribe(_forward_feature_progress)
        trainer.trainingCompleteSignal.subscribe(_report_training_done)

    def cleanUp(self):
        """Clean the progress signal and disconnect the output before the base-class cleanup."""
        self.progressSignal.clean()
        self.Classifier.disconnect()
        super(OpTrainVectorwiseClassifierBlocked, self).cleanUp()

    def setupOutputs(self):
        """Nothing to do; our output is connected to an internal operator."""
        pass

    def execute(self, slot, subindex, roi, result):
        """Not expected to be called: the output is served by the internal operator."""
        assert False, "Shouldn't get here..."

    def propagateDirty(self, slot, subindex, roi):
        """Nothing to do here."""
        pass
Exemple #16
0
class _OpThresholdOneLevel(Operator):
    """
    Single-threshold segmentation pipeline:

        InputImage -> threshold -> label volume -> size filter -> Output

    The labeled volume before size filtering is exposed on
    "BeforeSizeFilter" for debugging.
    """
    name = "_OpThresholdOneLevel"

    InputImage = InputSlot()
    MinSize = InputSlot(stype='int', value=0)
    MaxSize = InputSlot(stype='int', value=1000000)
    Threshold = InputSlot(stype='float', value=0.5)

    Output = OutputSlot()

    #debug output
    BeforeSizeFilter = OutputSlot()

    def __init__(self, *args, **kwargs):
        super(_OpThresholdOneLevel, self).__init__(*args, **kwargs)

        # Pixel-wise thresholding; the actual function is installed in
        # setupOutputs().
        thresholder = self._opThresholder = OpPixelOperator(parent=self)
        thresholder.Input.connect(self.InputImage)

        # Labeling of the thresholded image.
        labeler = self._opLabeler = OpLabelVolume(parent=self)
        labeler.Method.setValue(_labeling_impl)
        labeler.Input.connect(thresholder.Output)

        self.BeforeSizeFilter.connect(labeler.CachedOutput)

        # Size filter on the labeled objects (non-binary output).
        size_filter = self._opFilter = OpFilterLabels(parent=self)
        size_filter.Input.connect(labeler.CachedOutput)
        size_filter.MinLabelSize.connect(self.MinSize)
        size_filter.MaxLabelSize.connect(self.MaxSize)
        size_filter.BinaryOut.setValue(False)

        self.Output.connect(size_filter.Output)

    def setupOutputs(self):
        """Install the threshold function, bound to the current Threshold value."""

        def thresholdToUint8(thresholdValue, a):
            # If the input declares a data range, the threshold is
            # multiplied by the upper bound of that range.
            drange = self.InputImage.meta.drange
            if drange is not None:
                assert drange[0] == 0,\
                    "Don't know how to threshold data with this drange."
                thresholdValue *= drange[1]

            above = (a > thresholdValue)
            if a.dtype != numpy.uint8:
                return above.astype(numpy.uint8)

            # uint8 input: write the result back in-place.
            a[:] = above
            return a

        threshold_function = partial(thresholdToUint8, self.Threshold.value)
        self._opThresholder.Function.setValue(threshold_function)

        # self.Output already has metadata: it is directly connected to self._opFilter.Output

    def execute(self, slot, subindex, roi, result):
        """Not expected to be called: all outputs are connected to internal operators."""
        assert False, "Shouldn't get here..."

    def propagateDirty(self, slot, subindex, roi):
        """Nothing to do here."""
        pass

    def setInSlot(self, slot, subindex, roi, value):
        """
        Nothing to do here.

        Our Input slots are directly fed into the cache,
        so all calls to __setitem__ are forwarded automatically.
        """
        pass
class OpClassifierPredict(Operator):
    """
    Prediction front-end that delegates to an inner prediction operator.

    The inner operator is chosen in setupOutputs() from the type of the
    classifier factory found in the Classifier slot's metadata
    ('vectorwise' or 'pixelwise'), and is replaced whenever that mode
    changes.
    """
    Image = InputSlot()
    LabelsCount = InputSlot()
    Classifier = InputSlot()

    # An entire prediction request is skipped if the mask is all zeros for the requested roi.
    # Otherwise, the request is serviced as usual and the mask is ignored.
    PredictionMask = InputSlot(optional=True)

    PMaps = OutputSlot()

    def __init__(self, *args, **kwargs):
        super(OpClassifierPredict, self).__init__(*args, **kwargs)
        # Neither a mode nor an inner operator exists until setupOutputs() runs.
        self._prediction_op = None
        self._mode = None

    def setupOutputs(self):
        """
        Create (or replace) the inner prediction operator that matches the
        classifier factory type, and wire all slots through to it.

        Raises Exception if the factory metadata is missing or of an
        unknown type.
        """
        # We don't want to access the classifier directly here because that would trigger the full computation already.
        # Instead, we require the factory to be passed along with the classifier metadata.
        try:
            classifier_factory = self.Classifier.meta.classifier_factory
        except KeyError:
            raise Exception(
                "Classifier slot must include classifier factory as metadata.")

        factory_class = classifier_factory.__class__
        if issubclass(factory_class, LazyflowVectorwiseClassifierFactoryABC):
            new_mode = 'vectorwise'
        elif issubclass(factory_class, LazyflowPixelwiseClassifierFactoryABC):
            new_mode = 'pixelwise'
        else:
            raise Exception("Unknown classifier factory type: {}".format(
                type(classifier_factory)))

        # Same mode as before: the existing inner operator stays in place.
        if new_mode == self._mode:
            return

        # Tear down the previous inner operator, if there was one.
        if self._mode is not None:
            self.PMaps.disconnect()
            self._prediction_op.cleanUp()
        self._mode = new_mode

        if new_mode == 'vectorwise':
            inner = OpVectorwiseClassifierPredict(parent=self)
        else:
            inner = OpPixelwiseClassifierPredict(parent=self)
        self._prediction_op = inner

        inner.PredictionMask.connect(self.PredictionMask)
        inner.Image.connect(self.Image)
        inner.LabelsCount.connect(self.LabelsCount)
        inner.Classifier.connect(self.Classifier)
        self.PMaps.connect(inner.PMaps)

    def execute(self, slot, subindex, roi, result):
        """Not expected to be called: PMaps is connected to the inner operator."""
        assert False, "Shouldn't get here..."

    def propagateDirty(self, slot, subindex, roi):
        """Mark the whole PMaps output dirty when the Classifier slot is dirty."""
        if slot == self.Classifier:
            self.PMaps.setDirty()
Exemple #18
0
class _OpThresholdTwoLevels(Operator):
    """
    Two-threshold segmentation pipeline.

    The input is thresholded at a low and a high value, and both results
    are labeled.  The high-threshold labels are size filtered and combined
    with the low-threshold labels in OpSelectLabels.  The result is size
    filtered once more and cached.  Several intermediate results are
    exposed as debug outputs.
    """
    name = "_OpThresholdTwoLevels"

    InputImage = InputSlot()
    MinSize = InputSlot(stype='int', value=0)
    MaxSize = InputSlot(stype='int', value=1000000)
    HighThreshold = InputSlot(stype='float', value=0.5)
    LowThreshold = InputSlot(stype='float', value=0.2)

    Output = OutputSlot()
    CachedOutput = OutputSlot()  # For the GUI (blockwise-access)

    # For serialization
    InputHdf5 = InputSlot(optional=True)
    OutputHdf5 = OutputSlot()
    CleanBlocks = OutputSlot()

    # Debug outputs
    BigRegions = OutputSlot()
    SmallRegions = OutputSlot()
    FilteredSmallLabels = OutputSlot()

    # Schematic:
    #
    #           HighThreshold                         MinSize,MaxSize                       --(cache)--> opColorize -> FilteredSmallLabels
    #                   \                                       \                     /
    #           opHighThresholder --> opHighLabeler --> opHighLabelSizeFilter                           Output
    #          /                   \          /                 \                                            \                         /
    # InputImage        --(cache)--> SmallRegions                    opSelectLabels -->opFinalLabelSizeFilter--> opCache --> CachedOutput
    #          \                                                              /                                           /       \
    #           opLowThresholder ----> opLowLabeler --------------------------                                       InputHdf5     --> OutputHdf5
    #                   /                \                                                                                        -> CleanBlocks
    #           LowThreshold            --(cache)--> BigRegions

    def __init__(self, *args, **kwargs):
        super(_OpThresholdTwoLevels, self).__init__(*args, **kwargs)

        # --- Thresholding (functions are installed in setupOutputs) ---
        low_thresholder = self._opLowThresholder = OpPixelOperator(parent=self)
        low_thresholder.Input.connect(self.InputImage)

        high_thresholder = self._opHighThresholder = OpPixelOperator(
            parent=self)
        high_thresholder.Input.connect(self.InputImage)

        # --- Labeling of both thresholded images ---
        low_labeler = self._opLowLabeler = OpLabelVolume(parent=self)
        low_labeler.Method.setValue(_labeling_impl)
        low_labeler.Input.connect(low_thresholder.Output)

        high_labeler = self._opHighLabeler = OpLabelVolume(parent=self)
        high_labeler.Method.setValue(_labeling_impl)
        high_labeler.Input.connect(high_thresholder.Output)

        # --- Size filter on the high-threshold labels ---
        high_size_filter = self._opHighLabelSizeFilter = OpFilterLabels(
            parent=self)
        high_size_filter.Input.connect(high_labeler.CachedOutput)
        high_size_filter.MinLabelSize.connect(self.MinSize)
        high_size_filter.MaxLabelSize.connect(self.MaxSize)
        # we do the binarization in opSelectLabels
        # this way, we get to display pretty colors
        high_size_filter.BinaryOut.setValue(False)

        # --- Combine low-threshold and filtered high-threshold labels ---
        select_labels = self._opSelectLabels = OpSelectLabels(parent=self)
        select_labels.BigLabels.connect(low_labeler.CachedOutput)
        select_labels.SmallLabels.connect(high_size_filter.Output)

        # remove the remaining very large objects -
        # they might still be present in case a big object
        # was split into many small ones for the higher threshold
        # and they got reconnected again at lower threshold
        final_size_filter = self._opFinalLabelSizeFilter = OpFilterLabels(
            parent=self)
        final_size_filter.Input.connect(select_labels.Output)
        final_size_filter.MinLabelSize.connect(self.MinSize)
        final_size_filter.MaxLabelSize.connect(self.MaxSize)
        final_size_filter.BinaryOut.setValue(False)

        # --- Cache of the final result ---
        result_cache = self._opCache = OpCompressedCache(parent=self)
        result_cache.name = "_OpThresholdTwoLevels._opCache"
        result_cache.InputHdf5.connect(self.InputHdf5)
        result_cache.Input.connect(final_size_filter.Output)

        # Connect our own outputs
        self.Output.connect(final_size_filter.Output)
        self.CachedOutput.connect(result_cache.Output)

        # Serialization outputs
        self.CleanBlocks.connect(result_cache.CleanBlocks)
        self.OutputHdf5.connect(result_cache.OutputHdf5)

        #self.InputChannel.connect( self._opChannelSelector.Output )

        # More debug outputs.  These all go through their own caches
        big_cache = self._opBigRegionCache = OpCompressedCache(parent=self)
        big_cache.name = "_OpThresholdTwoLevels._opBigRegionCache"
        big_cache.Input.connect(low_thresholder.Output)
        self.BigRegions.connect(big_cache.Output)

        small_cache = self._opSmallRegionCache = OpCompressedCache(
            parent=self)
        small_cache.name = "_OpThresholdTwoLevels._opSmallRegionCache"
        small_cache.Input.connect(high_thresholder.Output)
        self.SmallRegions.connect(small_cache.Output)

        filtered_cache = self._opFilteredSmallLabelsCache = OpCompressedCache(
            parent=self)
        filtered_cache.name = "_OpThresholdTwoLevels._opFilteredSmallLabelsCache"
        filtered_cache.Input.connect(high_size_filter.Output)
        colorizer = self._opColorizeSmallLabels = OpColorizeLabels(
            parent=self)
        colorizer.Input.connect(filtered_cache.Output)
        self.FilteredSmallLabels.connect(colorizer.Output)

    def setupOutputs(self):
        """
        Install the two threshold functions (bound to the current threshold
        values) and set the block shape of all internal caches.
        """

        def thresholdToUint8(thresholdValue, a):
            # If the input declares a data range, the threshold is
            # multiplied by the upper bound of that range.
            drange = self.InputImage.meta.drange
            if drange is not None:
                assert drange[0] == 0,\
                    "Don't know how to threshold data with this drange."
                thresholdValue *= drange[1]

            above = (a > thresholdValue)
            if a.dtype != numpy.uint8:
                return above.astype(numpy.uint8)

            # uint8 input: write the result back in-place.
            a[:] = above
            return a

        low_function = partial(thresholdToUint8, self.LowThreshold.value)
        self._opLowThresholder.Function.setValue(low_function)
        high_function = partial(thresholdToUint8, self.HighThreshold.value)
        self._opHighThresholder.Function.setValue(high_function)

        # Output is already connected internally -- don't reassign new metadata
        # self.Output.meta.assignFrom(self.InputImage.meta)

        # Blockshape is the entire spatial volume (hysteresis thresholding is
        # a global operation)
        tagged_shape = self.Output.meta.getTaggedShape()
        tagged_shape['c'] = 1
        tagged_shape['t'] = 1
        block_shape = tuple(tagged_shape.values())

        all_caches = (
            self._opCache,
            self._opBigRegionCache,
            self._opSmallRegionCache,
            self._opFilteredSmallLabelsCache,
        )
        for cache in all_caches:
            cache.BlockShape.setValue(block_shape)

    def execute(self, slot, subindex, roi, result):
        """Not expected to be called: all outputs are connected to internal operators."""
        assert False, "Shouldn't get here..."

    def propagateDirty(self, slot, subindex, roi):
        """Nothing to do here."""
        pass

    def setInSlot(self, slot, subindex, roi, value):
        """Only InputHdf5 may be written to; nothing else needs to be done here."""
        assert slot == self.InputHdf5,\
            "Invalid slot for setInSlot(): {}".format(slot.name)
class OpPixelwiseClassifierPredict(Operator):
    """
    Compute per-label probability maps ("PMaps") for an image with a
    pixelwise classifier.

    The image is requested with the halo the classifier asks for, and the
    classifier is told which part of that data the predictions are wanted
    for.  The channel axis is expected to be the last axis.
    """
    Image = InputSlot()
    LabelsCount = InputSlot()
    Classifier = InputSlot()

    # An entire prediction request is skipped if the mask is all zeros for the requested roi.
    # Otherwise, the request is serviced as usual and the mask is ignored.
    PredictionMask = InputSlot(optional=True)

    PMaps = OutputSlot()

    def __init__(self, *args, **kwargs):
        super(OpPixelwiseClassifierPredict, self).__init__(*args, **kwargs)

        # Make sure the entire image is dirty if the prediction mask is removed.
        self.PredictionMask.notifyUnready(lambda s: self.PMaps.setDirty())

    def setupOutputs(self):
        """Derive the PMaps metadata from the image and the label count."""
        assert self.Image.meta.getAxisKeys()[-1] == 'c'

        #we'll have at least 2 labels once we actually predict something
        #not setting it to 0 here is friendlier to possible downstream
        #ilastik operators, setting it to 2 causes errors in pixel classification
        #(live prediction doesn't work when only two labels are present)
        channel_count = max(self.LabelsCount.value, 1)

        self.PMaps.meta.dtype = numpy.float32
        self.PMaps.meta.axistags = copy.copy(self.Image.meta.axistags)
        # FIXME: This assumes that channel is the last axis
        leading_shape = self.Image.meta.shape[:-1]
        self.PMaps.meta.shape = leading_shape + (channel_count, )
        self.PMaps.meta.drange = (0.0, 1.0)

    def execute(self, slot, subindex, roi, result):
        """
        Fill `result` with the class probabilities for `roi`.

        The result is all zeros if there is no classifier or if the
        prediction mask is ready and entirely zero within the roi.
        """
        classifier = self.Classifier.value

        # Training operator may return 'None' if there was no data to train with
        nothing_to_predict = (classifier is None)

        # Shortcut: If the mask is totally zero, skip this request entirely
        if not nothing_to_predict and self.PredictionMask.ready():
            # Same region as the request, but restricted to channel 0.
            mask_roi = numpy.array((roi.start, roi.stop))
            mask_roi[:, -1:] = [[0], [1]]
            mask_start, mask_stop = map(tuple, mask_roi)
            mask = self.PredictionMask(mask_start, mask_stop).wait()
            nothing_to_predict = not numpy.any(mask)

        if nothing_to_predict:
            result[:] = 0.0
            return result

        assert issubclass(type(classifier), LazyflowPixelwiseClassifierABC), \
            "Classifier is of type {}, which does not satisfy the LazyflowPixelwiseClassifierABC interface.".format(
                type(classifier))

        # Ask for the halo needed by the classifier
        axiskeys = self.Image.meta.getAxisKeys()
        halo = classifier.get_halo_shape(axiskeys)
        assert len(halo) == len(roi.start)
        assert halo[
            -1] == 0, "Didn't expect a non-zero halo for channel dimension."

        # Expand block by halo, then clip to image bounds
        padded_roi = numpy.array((roi.start, roi.stop))
        padded_roi[0] -= halo
        padded_roi[1] += halo
        image_roi = roiFromShape(self.Image.meta.shape)
        padded_roi = numpy.asarray(getIntersection(padded_roi, image_roi))

        # Determine how to extract the data from the result (without the halo)
        requested_roi = numpy.array((roi.start, roi.stop))
        predictions_roi = requested_roi[:, :-1] - padded_roi[0, :-1]

        # Request all upstream channels
        padded_roi[:, -1] = [0, self.Image.meta.shape[-1]]

        # Request the data
        input_data = self.Image(*padded_roi).wait()
        axistags = self.Image.meta.axistags
        probabilities = classifier.predict_probabilities_pixelwise(
            input_data, predictions_roi, axistags)

        # We're expecting a channel for each label class.
        # If we didn't provide at least one sample for each label,
        #  we may get back fewer channels.
        expected_channels = self.PMaps.meta.shape[-1]
        if probabilities.shape[-1] != expected_channels:
            # Copy to an array of the correct shape
            # This is slow, but it's an unusual case
            assert probabilities.shape[-1] == len(classifier.known_classes)
            padded_shape = probabilities.shape[:-1] + (expected_channels, )
            full_probabilities = numpy.zeros(padded_shape, dtype=numpy.float32)
            # Channel i of the classifier output belongs to label
            # known_classes[i], which is stored at channel (label - 1).
            for i, label in enumerate(classifier.known_classes):
                full_probabilities[..., label - 1] = probabilities[..., i]

            probabilities = full_probabilities

        # Copy only the prediction channels the client requested.
        first_channel, last_channel = roi.start[-1], roi.stop[-1]
        result[...] = probabilities[..., first_channel:last_channel]
        return result

    def propagateDirty(self, slot, subindex, roi):
        """
        A dirty Classifier or Image marks the whole output dirty; a dirty
        PredictionMask marks only the given roi dirty.
        """
        if slot == self.Classifier:
            self.logger.debug("classifier changed, setting dirty")
            self.PMaps.setDirty()
            return

        if slot == self.Image:
            self.PMaps.setDirty()
            return

        if slot == self.PredictionMask:
            self.PMaps.setDirty(roi.start, roi.stop)
Exemple #20
0
class _OpCacheWrapper(Operator):
    """
    Wrap an OpCompressedCache between two axis-reordering operators:

        Input -> op1 ('xyzct') -> cache -> op2 ('txyzc') -> Output

    A new cache is created and wired in on every call to setupOutputs().
    The serialization slots (InputHdf5, OutputHdf5, CleanBlocks) are passed
    through to that cache.
    """
    name = "OpCacheWrapper"
    Input = InputSlot()

    Output = OutputSlot()

    InputHdf5 = InputSlot(optional=True)
    CleanBlocks = OutputSlot()
    OutputHdf5 = OutputSlot()

    def __init__(self, *args, **kwargs):
        super(_OpCacheWrapper, self).__init__(*args, **kwargs)
        op1 = OpReorderAxes(parent=self)
        op1.name = "op1"
        op2 = OpReorderAxes(parent=self)
        op2.name = "op2"

        op1.AxisOrder.setValue('xyzct')
        op2.AxisOrder.setValue('txyzc')

        op1.Input.connect(self.Input)
        self.Output.connect(op2.Output)

        self._op1 = op1
        self._op2 = op2
        # Created in setupOutputs(); None until the input is configured.
        self._cache = None

    def setupOutputs(self):
        """
        Replace the internal cache with a new one and set its block shape.

        The block shape is 1 along 't' and 'c'.  Along the spatial axes it
        is capped at 256 when the lazy labeling implementation is in use,
        and spans the full volume otherwise.
        """
        self._disconnectInternals()

        # we need a new cache
        cache = OpCompressedCache(parent=self)
        cache.name = self.name + "WrappedCache"

        # connect cache outputs
        self.CleanBlocks.connect(cache.CleanBlocks)
        self.OutputHdf5.connect(cache.OutputHdf5)
        self._op2.Input.connect(cache.Output)

        # connect cache inputs
        cache.InputHdf5.connect(self.InputHdf5)
        cache.Input.connect(self._op1.Output)

        # set the cache block shape
        tagged_shape = self._op1.Output.meta.getTaggedShape()
        tagged_shape['t'] = 1
        tagged_shape['c'] = 1
        # Build a real list (not a lazy map object, as map() returns on
        # Python 3) so that numpy.minimum() below sees a sequence.
        cacheshape = [tagged_shape[k] for k in 'xyzct']
        if _labeling_impl == "lazy":
            #HACK hardcoded block shape
            blockshape = numpy.minimum(cacheshape, 256)
        else:
            # use full spatial volume if not lazy
            blockshape = cacheshape
        cache.BlockShape.setValue(tuple(blockshape))

        self._cache = cache

    def execute(self, slot, subindex, roi, result):
        """Not expected to be called: all outputs are connected to internal operators."""
        assert False

    def propagateDirty(self, slot, subindex, roi):
        """Nothing to do here."""
        pass

    def setInSlot(self, slot, subindex, key, value):
        """Forward data written to InputHdf5 to the internal cache."""
        assert slot == self.InputHdf5,\
            "setInSlot not implemented for slot {}".format(slot.name)
        assert self._cache is not None,\
            "setInSlot called before input was configured"
        self._cache.setInSlot(self._cache.InputHdf5, subindex, key, value)

    def _disconnectInternals(self):
        """Disconnect everything that is wired to the current cache and drop it."""
        self.CleanBlocks.disconnect()
        self.OutputHdf5.disconnect()
        self._op2.Input.disconnect()

        if self._cache is not None:
            self._cache.InputHdf5.disconnect()
            self._cache.Input.disconnect()
            # Reset to None instead of deleting the attribute, so that
            # `self._cache is not None` checks keep working even if
            # setupOutputs() fails before a new cache is assigned.
            self._cache = None
Exemple #21
0
class OpAutocontextBatch(Operator):
    """
    Chain of random-forest predictors for batch "autocontext" prediction.

    For every autocontext iteration one OpPredictRandomForest is created,
    together with an OpSlicedBlockedArrayCache on its ``PMaps`` output.
    Predictor 0 reads ``FeatureImage`` directly.  Every later predictor reads
    a cached stack (stacked along 'c') made of autocontext features computed
    from the previous iteration's cached prediction plus ``FeatureImage``.

    The internal operators are built once, by the first ``setupOutputs()``
    call for which ``AutocontextIterations`` is ready.
    """

    # One classifier per autocontext iteration (indexed by iteration).
    Classifiers = InputSlot(level=1)
    # Pixel features; input of the first predictor and part of every stack.
    FeatureImage = InputSlot()
    # Forwarded to the 'LabelsCount' input of every predictor.
    MaxLabelValue = InputSlot()
    # Number of predictors to chain.
    AutocontextIterations = InputSlot()

    PredictionProbabilities = OutputSlot()

    def __init__(self, *args, **kwargs):
        super(OpAutocontextBatch, self).__init__(*args, **kwargs)
        # ``predictors is None`` means "internal pipeline not built yet".
        self.prediction_caches = None
        self.predictors = None
        # Populated by setupOperators(); defined up front so the attributes
        # exist even before the pipeline has been built.
        self.autocontextFeatures = []
        self.autocontextFeaturesMulti = []
        self.autocontext_caches = []
        self.featureStackers = []

    def setupOperators(self, *args, **kwargs):
        """
        Create and wire all internal operators.

        Extra positional/keyword arguments are accepted and ignored, so the
        method can also be used as a slot callback.
        """
        # Read the iteration count first: if the slot cannot deliver a value
        # this raises before ``self.predictors`` stops being None, so a later
        # setupOutputs() call can try again.
        niter = self.AutocontextIterations.value

        self.predictors = []
        self.prediction_caches = []
        for _ in range(niter):
            self.predictors.append(OpPredictRandomForest(parent=self))
            self.prediction_caches.append(
                OpSlicedBlockedArrayCache(parent=self))

        # Autocontext feature operators: one set between each pair of
        # consecutive predictors, hence niter - 1 of them.
        self.autocontextFeatures = []
        self.autocontextFeaturesMulti = []
        self.autocontext_caches = []
        self.featureStackers = []
        for _ in range(niter - 1):
            self.autocontextFeatures.append(
                createAutocontextFeatureOperators(self, False))
            self.autocontextFeaturesMulti.append(Op50ToMulti(parent=self))
            stacker = OpMultiArrayStacker(parent=self)
            stacker.inputs["AxisFlag"].setValue("c")
            stacker.inputs["AxisIndex"].setValue(3)
            self.featureStackers.append(stacker)
            self.autocontext_caches.append(
                OpSlicedBlockedArrayCache(parent=self))

        # Connect the features of iteration i to the prediction of iteration i
        for i in range(niter - 1):
            features = self.autocontextFeatures[i]
            multi = self.autocontextFeaturesMulti[i]
            for ifeat, feat in enumerate(features):
                feat.inputs['Input'].connect(self.prediction_caches[i].Output)
                slotName = "Input%.2d" % ifeat
                # Single-argument print(): same output on Python 2 and 3.
                print("Multi: Connecting an output " + slotName)
                multi.inputs[slotName].connect(feat.outputs["Output"])

            # The pixel features take the next free input of the same
            # multi-slot operator.
            slotName = "Input%.2d" % len(features)
            print("Multi: Connecting an output " + slotName)
            multi.inputs[slotName].connect(self.FeatureImage)

            # Stack the autocontext features with the pixel features ...
            self.featureStackers[i].inputs["Images"].connect(
                multi.outputs["Outputs"])
            # ... and cache the stack.
            self.autocontext_caches[i].inputs["Input"].connect(
                self.featureStackers[i].outputs["Output"])
            self.autocontext_caches[i].inputs["fixAtCurrent"].setValue(False)

        for i in range(niter):
            predictor = self.predictors[i]
            predictor.inputs['Classifier'].connect(self.Classifiers[i])
            predictor.inputs['LabelsCount'].connect(self.MaxLabelValue)

            self.prediction_caches[i].inputs["fixAtCurrent"].setValue(False)
            self.prediction_caches[i].inputs["Input"].connect(predictor.PMaps)

        # First predictor: pixel features only.  Later predictors: cached
        # feature stack of the previous iteration.
        self.predictors[0].inputs['Image'].connect(self.FeatureImage)
        for i in range(1, niter):
            self.predictors[i].inputs['Image'].connect(
                self.autocontext_caches[i - 1].outputs["Output"])

        # NOTE(review): the output is wired to the FIRST predictor, i.e. the
        # pixel-feature-only prediction, so the autocontext chain built above
        # does not influence it.  A previously commented-out
        # "PixelOnlyPredictions" output was wired to predictors[-1]; the two
        # look swapped.  Behaviour kept as-is -- confirm the intended source.
        self.PredictionProbabilities.connect(self.predictors[0].PMaps)

    def setupOutputs(self):
        """
        Build the internal pipeline on first use and (re)configure the block
        shapes of all caches according to the axis order of ``FeatureImage``.
        """
        print("calling setupOutputs")

        if self.predictors is None:
            if not self.AutocontextIterations.ready():
                # Nothing has been built, so there are no caches to
                # configure yet (previously this fell through and failed
                # while iterating over the unset cache lists).
                return
            self.setupOperators()

        # Set the blockshapes depending on which axistags the input has.
        axisOrder = [tag.key for tag in self.FeatureImage.meta.axistags]

        def sliceBlockShapes(thinAxis):
            # (inner, outer) block extent per axis; the axis the cache is
            # sliced along gets thin (5, 5) blocks.
            extents = {
                't': (1, 1),
                'z': (128, 256),
                'y': (128, 256),
                'x': (128, 256),
                'c': (1000, 1000),
            }
            extents[thinAxis] = (5, 5)
            inner = tuple(extents[k][0] for k in axisOrder)
            outer = tuple(extents[k][1] for k in axisOrder)
            return inner, outer

        # One (inner, outer) pair per slicing direction, in x, y, z order.
        innerShapes, outerShapes = zip(
            *[sliceBlockShapes(axis) for axis in 'xyz'])

        for cache in self.prediction_caches + self.autocontext_caches:
            cache.innerBlockShape.setValue(innerShapes)
            cache.outerBlockShape.setValue(outerShapes)

    def setInSlot(self, slot, subindex, roi, value):
        # Nothing to do here: All inputs that support __setitem__
        #   are directly connected to internal operators.
        pass

    def propagateDirty(self, inputSlot, subindex, key):
        # Nothing to do here: All outputs are directly connected to
        #  internal operators that handle their own dirty propagation.
        pass
Exemple #22
0
class OpThresholdTwoLevels(Operator):
    """
    Thresholding front-end that switches between three segmentation modes.

    The input is reordered to 'txyzc', reduced to the channel selected by
    ``Channel`` and smoothed with an anisotropic Gaussian
    (``SmootherSigma``).  Depending on ``CurOperator`` the smoothed image is
    then processed by

    * 0 -- the single-threshold operator (``SingleThreshold``),
    * 1 -- the two-level threshold operator (``LowThreshold`` /
      ``HighThreshold``),
    * 2 -- graph cut (only if ``haveGraphCut()``), optionally on objects
      found by a pre-threshold (``UsePreThreshold``).

    The selected result is reordered back to the axis order of
    ``InputImage`` and published on ``Output``; ``CachedOutput`` serves the
    same data through an internal cache wrapper, which also provides the
    serialization slots.
    """
    name = "OpThresholdTwoLevels"

    RawInput = InputSlot(optional=True)  # Display only

    InputImage = InputSlot()
    MinSize = InputSlot(stype='int', value=10)
    MaxSize = InputSlot(stype='int', value=1000000)
    HighThreshold = InputSlot(stype='float', value=0.5)
    LowThreshold = InputSlot(stype='float', value=0.2)
    SingleThreshold = InputSlot(stype='float', value=0.5)
    SmootherSigma = InputSlot(value={'x': 1.0, 'y': 1.0, 'z': 1.0})
    Channel = InputSlot(value=0)
    # Selects the mode: 0 single threshold, 1 two-level, 2 graph cut
    CurOperator = InputSlot(stype='int', value=0)

    ## Graph-Cut options ##

    SingleThresholdGC = InputSlot(stype='float', value=0.5)

    Beta = InputSlot(value=.2)

    # apply thresholding before graph-cut
    UsePreThreshold = InputSlot(stype='bool', value=True)

    # margin around single object (only graph-cut)
    Margin = InputSlot(value=numpy.asarray((20, 20, 20)))

    ## Output slots ##

    Output = OutputSlot()

    CachedOutput = OutputSlot()  # For the GUI (blockwise-access)

    # For serialization
    InputHdf5 = InputSlot(optional=True)

    CleanBlocks = OutputSlot()

    OutputHdf5 = OutputSlot()

    ## Debug outputs

    InputChannel = OutputSlot()
    Smoothed = OutputSlot()
    BigRegions = OutputSlot()
    SmallRegions = OutputSlot()
    FilteredSmallLabels = OutputSlot()
    BeforeSizeFilter = OutputSlot()

    def __init__(self, *args, **kwargs):
        """Create all internal operators and the mode-independent wiring."""
        super(OpThresholdTwoLevels, self).__init__(*args, **kwargs)

        # Validate the channel selection whenever an input image shows up.
        self.InputImage.notifyReady(self.checkConstraints)

        def connectSizeLimits(thresholdOp):
            # All threshold operators share the same object size filter.
            thresholdOp.MinSize.connect(self.MinSize)
            thresholdOp.MaxSize.connect(self.MaxSize)

        # Preprocessing chain: reorder -> select channel -> smooth
        reorderIn = self._opReorder1 = OpReorderAxes(parent=self)
        reorderIn.AxisOrder.setValue('txyzc')
        reorderIn.Input.connect(self.InputImage)

        selector = self._opChannelSelector = OpSingleChannelSelector(
            parent=self)
        selector.Input.connect(reorderIn.Output)
        selector.Index.connect(self.Channel)

        # anisotropic gauss
        smoother = self._opSmoother = OpAnisotropicGaussianSmoothing5d(
            parent=self)
        smoother.Sigmas.connect(self.SmootherSigma)
        smoother.Input.connect(selector.Output)

        # debug output (also the source for all threshold operators)
        self.Smoothed.connect(smoother.Output)

        # single threshold operator
        single = self.opThreshold1 = _OpThresholdOneLevel(parent=self)
        single.Threshold.connect(self.SingleThreshold)
        connectSizeLimits(single)

        # double threshold operator
        twoLevel = self.opThreshold2 = _OpThresholdTwoLevels(parent=self)
        connectSizeLimits(twoLevel)
        twoLevel.LowThreshold.connect(self.LowThreshold)
        twoLevel.HighThreshold.connect(self.HighThreshold)

        if haveGraphCut():
            # pre-threshold used to find the objects handed to graph cut
            preThreshold = self.opThreshold1GC = _OpThresholdOneLevel(
                parent=self)
            preThreshold.Threshold.connect(self.SingleThresholdGC)
            connectSizeLimits(preThreshold)

            objectCut = self.opObjectsGraphCut = OpObjectsSegment(
                parent=self)
            objectCut.Prediction.connect(self.Smoothed)
            objectCut.LabelImage.connect(preThreshold.Output)
            objectCut.Beta.connect(self.Beta)
            objectCut.Margin.connect(self.Margin)

            globalCut = self.opGraphCut = OpGraphCut(parent=self)
            globalCut.Prediction.connect(self.Smoothed)
            globalCut.Beta.connect(self.Beta)

        # Created but not referenced anywhere else in this class.
        self._op5CacheOutput = OpReorderAxes(parent=self)

        # Restores the caller's axis order; its input is chosen per mode in
        # setupOutputs().
        reorderOut = self._opReorder2 = OpReorderAxes(parent=self)
        self.Output.connect(reorderOut.Output)

        #cache our own output, don't propagate from internal operator
        cacheWrapper = self._cache = _OpCacheWrapper(parent=self)
        cacheWrapper.name = "OpThresholdTwoLevels.OpCacheWrapper"
        cacheWrapper.Input.connect(self.Output)
        self.CachedOutput.connect(cacheWrapper.Output)

        # Serialization slots
        cacheWrapper.InputHdf5.connect(self.InputHdf5)
        self.CleanBlocks.connect(cacheWrapper.CleanBlocks)
        self.OutputHdf5.connect(cacheWrapper.OutputHdf5)

        #Debug outputs
        self.InputChannel.connect(selector.Output)

    def setupOutputs(self):
        """
        Rewire the internal operators for the mode selected by
        ``CurOperator`` and mark the cache input dirty.

        Raises:
            ValueError: if ``CurOperator`` is not 0, 1 or 2.
        """
        self._opReorder2.AxisOrder.setValue(self.InputImage.meta.getAxisKeys())

        # propagate drange
        # NOTE(review): these slots are disconnected by _disconnectAll()
        # right below -- confirm the drange set here survives that.
        thresholdOps = [self.opThreshold1]
        if haveGraphCut():
            thresholdOps.append(self.opThreshold1GC)
        thresholdOps.append(self.opThreshold2)
        for thresholdOp in thresholdOps:
            thresholdOp.InputImage.meta.drange = self.InputImage.meta.drange

        self._disconnectAll()

        mode = self.CurOperator.value

        if mode == 0:
            selectedOutput = self._connectForSingleThreshold(
                self.opThreshold1)
        elif mode == 1:
            selectedOutput = self._connectForTwoLevelThreshold()
        elif mode == 2:
            selectedOutput = self._connectForGraphCut()
        else:
            raise ValueError(
                "Unknown index {} for current tab.".format(mode))

        self._opReorder2.Input.connect(selectedOutput)
        # force the cache to emit a dirty signal
        self._cache.Input.setDirty(slice(None))

    def checkConstraints(self, *args):
        """
        Verify that ``Channel`` addresses an existing channel of the input.

        Does nothing while the reordered input is not ready.

        Raises:
            DatasetConstraintError: if the selected channel index is not
                smaller than the number of channels.
        """
        if not self._opReorder1.Output.ready():
            return

        numChannels = self._opReorder1.Output.meta.getTaggedShape()['c']
        selected = self.Channel.value
        if selected >= numChannels:
            raise DatasetConstraintError(
                "Two-Level Thresholding",
                "Your project is configured to select data from channel"
                " #{}, but your input data only has {} channels.".format(
                    selected, numChannels))

    def _disconnectAll(self):
        """Drop every mode-dependent connection and flag the debug outputs
        as not ready."""
        # start from back
        debugOutputs = (self.BigRegions, self.SmallRegions,
                        self.FilteredSmallLabels, self.BeforeSizeFilter)
        for debugSlot in debugOutputs:
            debugSlot.disconnect()
            debugSlot.meta.NOTREADY = True

        self._opReorder2.Input.disconnect()
        if haveGraphCut():
            self.opThreshold1GC.InputImage.disconnect()
        self.opThreshold1.InputImage.disconnect()
        self.opThreshold2.InputImage.disconnect()

    def _connectForSingleThreshold(self, threshOp):
        """Feed the smoothed image into ``threshOp`` and expose its
        pre-size-filter result; returns ``threshOp.Output``."""
        self.BeforeSizeFilter.connect(threshOp.BeforeSizeFilter)
        self.BeforeSizeFilter.meta.NOTREADY = None
        threshOp.InputImage.connect(self.Smoothed)
        return threshOp.Output

    def _connectForTwoLevelThreshold(self):
        """Feed the smoothed image into the two-level operator and expose
        its debug outputs; returns the operator's ``Output``."""
        twoLevel = self.opThreshold2
        self.BigRegions.connect(twoLevel.BigRegions)
        self.SmallRegions.connect(twoLevel.SmallRegions)
        self.FilteredSmallLabels.connect(twoLevel.FilteredSmallLabels)

        for debugSlot in (self.BigRegions, self.SmallRegions,
                          self.FilteredSmallLabels):
            debugSlot.meta.NOTREADY = None

        twoLevel.InputImage.connect(self.Smoothed)
        return twoLevel.Output

    def _connectForGraphCut(self):
        """Return the graph-cut output to use; with ``UsePreThreshold`` the
        pre-threshold operator is wired up first."""
        assert haveGraphCut(), "Module for graph cut is not available"
        if not self.UsePreThreshold.value:
            return self.opGraphCut.Output

        self._connectForSingleThreshold(self.opThreshold1GC)
        return self.opObjectsGraphCut.Output

    def execute(self, slot, subindex, roi, destination):
        """Not expected to be called; guarded by an assertion."""
        assert False, "Shouldn't get here."

    def propagateDirty(self, slot, subindex, roi):
        """No-op."""
        # dirtiness propagation is handled in the sub-operators
        pass

    def setInSlot(self, slot, subindex, roi, value):
        """Accept writes to ``InputHdf5`` only; nothing is done with the
        data in this method itself."""
        assert slot == self.InputHdf5,\
            "[{}] Wrong slot for setInSlot(): {}".format(self.name,
                                                         slot)
Exemple #23
0
class OpStructuredTracking(OpConservationTracking):
    """
    Extension of OpConservationTracking with extra input slots for crops and
    training annotations and extra output slots for tracking weights.
    """
    # --- additional inputs ---
    # Each of Crops, Labels and Divisions is mirrored into an instance
    # attribute by a notifyReady callback registered in __init__.
    Crops = InputSlot()
    Labels = InputSlot(stype=Opaque, rtype=List)
    Divisions = InputSlot(stype=Opaque, rtype=List)
    # Initialised to an empty dict in __init__.
    Annotations = InputSlot(stype=Opaque)
    MaxNumObj = InputSlot()

    # --- additional outputs; __init__ gives each of them a default value ---
    DivisionWeight = OutputSlot()
    DetectionWeight = OutputSlot()
    TransitionWeight = OutputSlot()
    AppearanceWeight = OutputSlot()
    DisappearanceWeight = OutputSlot()
    MaxNumObjOut = OutputSlot()

    def __init__(self, parent=None, graph=None):
        """
        Initialise annotation state, publish default weights on the output
        slots and keep the local copies of Crops/Labels/Divisions in sync
        with their slots.
        """
        super(OpStructuredTracking, self).__init__(parent=parent, graph=graph)

        # Local annotation state; Annotations starts out empty.
        self.labels = {}
        self.divisions = {}
        self.Annotations.setValue({})
        self._ndim = 3

        self._parent = parent

        # Default values of the output slots, set in declaration order.
        outputDefaults = (
            (self.DivisionWeight, 0.6),
            (self.DetectionWeight, 0.6),
            (self.TransitionWeight, 0.01),
            (self.AppearanceWeight, 0.3),
            (self.DisappearanceWeight, 0.2),
            (self.MaxNumObjOut, 1),
        )
        for outputSlot, default in outputDefaults:
            outputSlot.setValue(default)

        self.transition_parameter = 5
        # Plain-attribute weights all start at 1.
        self.detectionWeight = self.divisionWeight = 1
        self.transitionWeight = 1
        self.appearanceWeight = self.disappearanceWeight = 1

        # Mirror slot values into attributes as soon as the slots are ready.
        mirrored = (
            (self.Crops, self._updateCropsFromOperator),
            (self.Labels, self._updateLabelsFromOperator),
            (self.Divisions, self._updateDivisionsFromOperator),
        )
        for inputSlot, callback in mirrored:
            inputSlot.notifyReady(bind(callback))

    def _updateLabelsFromOperator(self):
        """Store the current value of the ``Labels`` slot in ``self.labels``."""
        current = self.Labels.value
        self.labels = current

    def _updateDivisionsFromOperator(self):
        """Store the current value of the ``Divisions`` slot in
        ``self.divisions``."""
        current = self.Divisions.value
        self.divisions = current

    def setupOutputs(self):
        """
        Configure the parent operator's outputs, then derive ``_ndim`` and
        make sure ``self.labels`` has an entry for every index along axis 0
        of ``LabelImage``.

        ``_ndim`` becomes 2 when axis 3 of ``LabelImage`` has extent 1 and 3
        otherwise.
        """
        super(OpStructuredTracking, self).setupOutputs()

        labelShape = self.LabelImage.meta.shape
        # presumably axis 0 is time and axis 3 is z -- TODO confirm
        self._ndim = 2 if labelShape[3] == 1 else 3

        # NOTE(review): self.labels may be the very object held by the Labels
        # slot (see _updateLabelsFromOperator), in which case it is modified
        # in place here -- confirm that this is intended.
        for t in range(labelShape[0]):
            # Direct membership test; ``.keys()`` built a new list for every
            # t on Python 2.
            if t not in self.labels:
                self.labels[t] = {}

    def execute(self, slot, subindex, roi, result):
        """
        Serve a request for ``slot``.

        Requests for ``Labels`` or ``Divisions`` return whatever the slot's
        ``wait()`` yields; everything else is delegated to the parent class
        and ``result`` is returned.
        """
        if slot is self.Labels:
            return self.Labels.wait()

        if slot is self.Divisions:
            return self.Divisions.wait()

        # NOTE(review): the parent's return value is discarded; only data
        # written into ``result`` reaches the caller -- confirm the parent
        # never returns a new object.
        super(OpStructuredTracking, self).execute(slot, subindex, roi,
                                                  result)
        return result

    def _updateCropsFromOperator(self):
        """Store the current value of the ``Crops`` slot in ``self._crops``."""
        current = self.Crops.value
        self._crops = current

    def _runStructuredLearning(self,
                               z_range,
                               maxObj,
                               maxNearestNeighbors,
                               maxDist,
                               divThreshold,
                               scales,
                               size_range,
                               withDivisions,
                               borderAwareWidth,
                               withClassifierPrior,
                               withBatchProcessing=False):

        if not withBatchProcessing:
            gui = self.parent.parent.trackingApplet._gui.currentGui()

        emptyAnnotations = False
        for crop in self.Annotations.value.keys():
            emptyCrop = self.Annotations.value[crop][
                "divisions"] == {} and self.Annotations.value[crop][
                    "labels"] == {}
            if emptyCrop and not withBatchProcessing:
                gui._criticalMessage("Error: Weights can not be calculated because training annotations for crop {} are missing. ".format(crop) +\
                                  "Go back to Training applet and Save your training for each crop.")
            emptyAnnotations = emptyAnnotations or emptyCrop

        if emptyAnnotations:
            return [
                self.DetectionWeight.value, self.DivisionWeight.value,
                self.TransitionWeight.value, self.AppearanceWeight.value,
                self.DisappearanceWeight.value
            ]

        self._updateCropsFromOperator()
        median_obj_size = [0]

        from_z = z_range[0]
        to_z = z_range[1]
        ndim = 3
        if (to_z - from_z == 0):
            ndim = 2

        time_range = [0, self.LabelImage.meta.shape[0] - 1]
        x_range = [0, self.LabelImage.meta.shape[1]]
        y_range = [0, self.LabelImage.meta.shape[2]]
        z_range = [0, self.LabelImage.meta.shape[3]]

        parameters = self.Parameters.value

        parameters['maxDist'] = maxDist
        parameters['maxObj'] = maxObj
        parameters['divThreshold'] = divThreshold
        parameters['withDivisions'] = withDivisions
        parameters['withClassifierPrior'] = withClassifierPrior
        parameters['borderAwareWidth'] = borderAwareWidth
        parameters['scales'] = scales
        parameters['time_range'] = [min(time_range), max(time_range)]
        parameters['x_range'] = x_range
        parameters['y_range'] = y_range
        parameters['z_range'] = z_range
        parameters['max_nearest_neighbors'] = maxNearestNeighbors

        # Set a size range with a minimum area equal to the max number of objects (since the GMM throws an error if we try to fit more gaussians than the number of pixels in the object)
        size_range = (max(maxObj, size_range[0]), size_range[1])
        parameters['size_range'] = size_range

        self.Parameters.setValue(parameters, check_changed=False)

        foundAllArcs = False
        new_max_nearest_neighbors = max([maxNearestNeighbors - 1, 1])
        maxObjOK = True
        parameters['max_nearest_neighbors'] = maxNearestNeighbors
        while not foundAllArcs and maxObjOK and new_max_nearest_neighbors < 10:
            new_max_nearest_neighbors += 1
            logger.info("new_max_nearest_neighbors: {}".format(
                new_max_nearest_neighbors))

            time_range = range(0, self.LabelImage.meta.shape[0])

            parameters['max_nearest_neighbors'] = new_max_nearest_neighbors
            self.Parameters.setValue(parameters, check_changed=False)

            hypothesesGraph = self._createHypothesesGraph()
            if hypothesesGraph.countNodes() == 0:
                raise DatasetConstraintError(
                    'Structured Learning',
                    'Can not track frames with 0 objects, abort.')

            hypothesesGraph.insertEnergies()
            # trackingGraph = hypothesesGraph.toTrackingGraph()
            # import pprint
            # pprint.pprint(trackingGraph.model)

            maxDist = 200
            sizeDependent = False
            divThreshold = float(0.5)

            logger.info(
                "Structured Learning: Adding Training Annotations to Hypotheses Graph"
            )

            mergeMsgStr = "Your tracking annotations contradict this model assumptions! All tracks must be continuous, tracks of length one are not allowed, and mergers may merge or split but all tracks in a merger appear/disappear together."
            foundAllArcs = True
            numAllAnnotatedDivisions = 0

            self.features = self.ObjectFeatures(
                range(0, self.LabelImage.meta.shape[0])).wait()

            for cropKey in self.Crops.value.keys():
                if foundAllArcs:

                    if not cropKey in self.Annotations.value.keys():
                        if not withBatchProcessing:
                            gui._criticalMessage("You have not trained or saved your training for " + str(cropKey) + \
                                              ". \nGo back to the Training applet and save all your training!")
                        return [
                            self.DetectionWeight.value,
                            self.DivisionWeight.value,
                            self.TransitionWeight.value,
                            self.AppearanceWeight.value,
                            self.DisappearanceWeight.value
                        ]

                    crop = self.Annotations.value[cropKey]
                    timeRange = self.Crops.value[cropKey]['time']

                    if "labels" in crop.keys():

                        labels = crop["labels"]

                        for time in labels.keys():
                            if time in range(timeRange[0], timeRange[1] + 1):

                                if not foundAllArcs:
                                    break

                                for label in labels[time].keys():

                                    if not foundAllArcs:
                                        break

                                    trackSet = labels[time][label]
                                    center = self.features[time][
                                        default_features_key]['RegionCenter'][
                                            label]
                                    trackCount = len(trackSet)

                                    if trackCount > maxObj:
                                        logger.info(
                                            "Your track count for object {} in time frame {} is {} =| {} |, which is greater than maximum object number {} defined by object count classifier!"
                                            .format(label, time, trackCount,
                                                    trackSet, maxObj))
                                        logger.info(
                                            "Either remove track(s) from this object or train the object count classifier with more labels!"
                                        )
                                        maxObjOK = False
                                        raise DatasetConstraintError('Structured Learning', "Your track count for object "+str(label)+" in time frame " +str(time)+ " equals "+str(trackCount)+"=|"+str(trackSet)+"|," + \
                                                " which is greater than the maximum object number "+str(maxObj)+" defined by object count classifier! " + \
                                                "Either remove track(s) from this object or train the object count classifier with more labels!")

                                    for track in trackSet:

                                        if not foundAllArcs:
                                            logger.info(
                                                "[structuredTrackingGui] Increasing max nearest neighbors!"
                                            )
                                            break

                                        # is this a FIRST, INTERMEDIATE, LAST, SINGLETON(FIRST_LAST) object of a track (or FALSE_DETECTION)
                                        type = self._type(
                                            cropKey, time, track
                                        )  # returns [type, previous_label] if type=="LAST" or "INTERMEDIATE" (else [type])
                                        if type == None:
                                            raise DatasetConstraintError(
                                                'Structured Learning',
                                                mergeMsgStr)

                                        elif type[0] in [
                                                "LAST", "INTERMEDIATE",
                                                "SINGLETON(FIRST_LAST)"
                                        ]:
                                            if type[0] == "SINGLETON(FIRST_LAST)":
                                                trackCountIntersection = len(
                                                    trackSet)
                                            else:
                                                previous_label = int(type[1])
                                                previousTrackSet = labels[
                                                    time - 1][previous_label]
                                                intersectionSet = trackSet.intersection(
                                                    previousTrackSet)
                                                trackCountIntersection = len(
                                                    intersectionSet)
                                            print "trackCountIntersection", trackCountIntersection

                                            if trackCountIntersection > maxObj:
                                                logger.info(
                                                    "Your track count for transition ( {},{} ) ---> ( {},{} ) is {} =| {} |, which is greater than maximum object number {} defined by object count classifier!"
                                                    .format(
                                                        previous_label,
                                                        time - 1, label, time,
                                                        trackCountIntersection,
                                                        intersectionSet,
                                                        maxObj))
                                                logger.info(
                                                    "Either remove track(s) from these objects or train the object count classifier with more labels!"
                                                )
                                                maxObjOK = False
                                                raise DatasetConstraintError('Structured Learning', "Your track count for transition ("+str(previous_label)+","+str(time-1)+") ---> ("+str(label)+","+str(time)+") is "+str(trackCountIntersection)+"=|"+str(intersectionSet)+"|, " + \
                                                        "which is greater than maximum object number "+str(maxObj)+" defined by object count classifier!" + \
                                                        "Either remove track(s) from these objects or train the object count classifier with more labels!")

                                            sink = (time, int(label))
                                            foundAllArcs = False
                                            for edge in hypothesesGraph._graph.in_edges(
                                                    sink
                                            ):  # an edge is a tuple of source and target nodes
                                                logger.info(
                                                    "Looking at in edge {} of node {}, searching for ({},{})"
                                                    .format(
                                                        edge, sink, time - 1,
                                                        previous_label))
                                                if edge[0][
                                                        0] == time - 1 and edge[
                                                            0][1] == int(
                                                                previous_label
                                                            ):  # every node 'id' is a tuple (timestep, label), so we need the in-edge coming from previous_label
                                                    foundAllArcs = True
                                                    hypothesesGraph._graph.edge[
                                                        edge[0]][edge[1]][
                                                            'value'] = int(
                                                                trackCountIntersection
                                                            )
                                                    print "[structuredTrackingGui] EDGE: ({},{})--->({},{})".format(
                                                        time - 1,
                                                        int(previous_label),
                                                        time, int(label))
                                                    break

                                            if not foundAllArcs:
                                                logger.info(
                                                    "[structuredTrackingGui] Increasing max nearest neighbors! LABELS/MERGERS {} {}"
                                                    .format(
                                                        time - 1,
                                                        int(previous_label)))
                                                logger.info(
                                                    "[structuredTrackingGui] Increasing max nearest neighbors! LABELS/MERGERS {} {}"
                                                    .format(time, int(label)))
                                                break

                                    if type == None:
                                        raise DatasetConstraintError(
                                            'Structured Learning', mergeMsgStr)

                                    elif type[0] in [
                                            "FIRST", "LAST", "INTERMEDIATE",
                                            "SINGLETON(FIRST_LAST)"
                                    ]:
                                        if (
                                                time, int(label)
                                        ) in hypothesesGraph._graph.node.keys(
                                        ):
                                            hypothesesGraph._graph.node[(
                                                time, int(label)
                                            )]['value'] = trackCount
                                            logger.info(
                                                "[structuredTrackingGui] NODE: {} {}"
                                                .format(time, int(label)))
                                            print "[structuredTrackingGui] NODE: {} {} {}".format(
                                                time, int(label),
                                                int(trackCount))
                                        else:
                                            logger.info(
                                                "[structuredTrackingGui] NODE: {} {} NOT found"
                                                .format(time, int(label)))

                                            foundAllArcs = False
                                            break

                    if foundAllArcs and "divisions" in crop.keys():
                        divisions = crop["divisions"]

                        numAllAnnotatedDivisions = numAllAnnotatedDivisions + len(
                            divisions)
                        for track in divisions.keys():
                            if not foundAllArcs:
                                break

                            division = divisions[track]
                            time = int(division[1])

                            parent = int(
                                self.getLabelInCrop(cropKey, time, track))

                            if parent >= 0:
                                children = [
                                    int(
                                        self.getLabelInCrop(
                                            cropKey, time + 1, division[0][i]))
                                    for i in [0, 1]
                                ]
                                parentNode = (time, parent)
                                hypothesesGraph._graph.node[parentNode][
                                    'divisionValue'] = 1
                                foundAllArcs = False
                                for child in children:
                                    for edge in hypothesesGraph._graph.out_edges(
                                            parentNode
                                    ):  # an edge is a tuple of source and target nodes
                                        if edge[1][0] == time + 1 and edge[1][
                                                1] == int(
                                                    child
                                                ):  # every node 'id' is a tuple (timestep, label), so we need the in-edge coming from previous_label
                                            foundAllArcs = True
                                            hypothesesGraph._graph.edge[
                                                edge[0]][edge[1]]['value'] = 1
                                            break
                                    if not foundAllArcs:
                                        break

                                if not foundAllArcs:
                                    logger.info(
                                        "[structuredTrackingGui] Increasing max nearest neighbors! DIVISION {} {}"
                                        .format(time, parent))
                                    break
        logger.info(
            "max nearest neighbors= {}".format(new_max_nearest_neighbors))

        if new_max_nearest_neighbors > maxNearestNeighbors:
            maxNearestNeighbors = new_max_nearest_neighbors
            parameters['maxNearestNeighbors'] = maxNearestNeighbors
            if not withBatchProcessing:
                gui._drawer.maxNearestNeighborsSpinBox.setValue(
                    maxNearestNeighbors)

        detectionWeight = self.DetectionWeight.value
        divisionWeight = self.DivisionWeight.value
        transitionWeight = self.TransitionWeight.value
        disappearanceWeight = self.DisappearanceWeight.value
        appearanceWeight = self.AppearanceWeight.value

        if not foundAllArcs:
            logger.info(
                "[structuredTracking] Increasing max nearest neighbors did not result in finding all training arcs!"
            )
            return [
                transitionWeight, detectionWeight, divisionWeight,
                appearanceWeight, disappearanceWeight
            ]

        hypothesesGraph.insertEnergies()

        # crops away everything (arcs and nodes) that doesn't have 'value' set
        prunedGraph = hypothesesGraph.pruneGraphToSolution(
            distanceToSolution=0
        )  # width of non-annotated border needed for negative training examples

        trackingGraph = prunedGraph.toTrackingGraph()

        # trackingGraph.convexifyCosts()
        model = trackingGraph.model
        model['settings']['optimizerEpGap'] = 0.005
        gt = prunedGraph.getSolutionDictionary()

        initialWeights = trackingGraph.weightsListToDict([
            transitionWeight, detectionWeight, divisionWeight,
            appearanceWeight, disappearanceWeight
        ])

        mht.trainWithWeightInitialization(model, gt, initialWeights)
        weightsDict = mht.train(model, gt)

        weights = trackingGraph.weightsDictToList(weightsDict)

        if not withBatchProcessing and withDivisions and numAllAnnotatedDivisions == 0 and not weights[
                2] == 0.0:
            gui._informationMessage("Divisible objects are checked, but you did not annotate any divisions in your tracking training. " + \
                                 "The resulting division weight might be arbitrarily and if there are divisions present in the dataset, " +\
                                 "they might not be present in the tracking solution.")

        norm = 0
        for i in range(len(weights)):
            norm += weights[i] * weights[i]
        norm = math.sqrt(norm)

        if norm > 0.0000001:
            self.TransitionWeight.setValue(weights[0] / norm)
            self.DetectionWeight.setValue(weights[1] / norm)
            self.DivisionWeight.setValue(weights[2] / norm)
            self.AppearanceWeight.setValue(weights[3] / norm)
            self.DisappearanceWeight.setValue(weights[4] / norm)

        if not withBatchProcessing:
            gui._drawer.detWeightBox.setValue(self.DetectionWeight.value)
            gui._drawer.divWeightBox.setValue(self.DivisionWeight.value)
            gui._drawer.transWeightBox.setValue(self.TransitionWeight.value)
            gui._drawer.appearanceBox.setValue(self.AppearanceWeight.value)
            gui._drawer.disappearanceBox.setValue(
                self.DisappearanceWeight.value)

        if not withBatchProcessing:
            if self.DetectionWeight.value < 0.0:
                gui._informationMessage ("Detection weight calculated was negative. Tracking solution will be re-calculated with non-negativity constraints for learning weights. " + \
                    "Furthermore, you should add more training and recalculate the learning weights in order to improve your tracking solution.")
            elif self.DivisionWeight.value < 0.0:
                gui._informationMessage ("Division weight calculated was negative. Tracking solution will be re-calculated with non-negativity constraints for learning weights. " + \
                    "Furthermore, you should add more division cells to your training and recalculate the learning weights in order to improve your tracking solution.")
            elif self.TransitionWeight.value < 0.0:
                gui._informationMessage ("Transition weight calculated was negative. Tracking solution will be re-calculated with non-negativity constraints for learning weights. " + \
                    "Furthermore, you should add more transitions to your training and recalculate the learning weights in order to improve your tracking solution.")
            elif self.AppearanceWeight.value < 0.0:
                gui._informationMessage ("Appearance weight calculated was negative. Tracking solution will be re-calculated with non-negativity constraints for learning weights. " + \
                    "Furthermore, you should add more appearances to your training and recalculate the learning weights in order to improve your tracking solution.")
            elif self.DisappearanceWeight.value < 0.0:
                gui._informationMessage ("Disappearance weight calculated was negative. Tracking solution will be re-calculated with non-negativity constraints for learning weights. " + \
                    "Furthermore, you should add more disappearances to your training and recalculate the learning weights in order to improve your tracking solution.")

        if self.DetectionWeight.value < 0.0 or self.DivisionWeight.value < 0.0 or self.TransitionWeight.value < 0.0 or \
            self.AppearanceWeight.value < 0.0 or self.DisappearanceWeight.value < 0.0:

            model['settings']['nonNegativeWeightsOnly'] = True
            weightsDict = mht.train(model, gt)

            weights = trackingGraph.weightsDictToList(weightsDict)

            norm = 0
            for i in range(len(weights)):
                norm += weights[i] * weights[i]
            norm = math.sqrt(norm)

            if norm > 0.0000001:
                self.TransitionWeight.setValue(weights[0] / norm)
                self.DetectionWeight.setValue(weights[1] / norm)
                self.DivisionWeight.setValue(weights[2] / norm)
                self.AppearanceWeight.setValue(weights[3] / norm)
                self.DisappearanceWeight.setValue(weights[4] / norm)

            if not withBatchProcessing:
                gui._drawer.detWeightBox.setValue(self.DetectionWeight.value)
                gui._drawer.divWeightBox.setValue(self.DivisionWeight.value)
                gui._drawer.transWeightBox.setValue(
                    self.TransitionWeight.value)
                gui._drawer.appearanceBox.setValue(self.AppearanceWeight.value)
                gui._drawer.disappearanceBox.setValue(
                    self.DisappearanceWeight.value)

        logger.info("Structured Learning Tracking Weights (normalized):")
        logger.info("   detection weight     = {}".format(
            self.DetectionWeight.value))
        logger.info("   division weight     = {}".format(
            self.DivisionWeight.value))
        logger.info("   transition weight     = {}".format(
            self.TransitionWeight.value))
        logger.info("   appearance weight     = {}".format(
            self.AppearanceWeight.value))
        logger.info("   disappearance weight     = {}".format(
            self.DisappearanceWeight.value))

        parameters['detWeight'] = self.DetectionWeight.value
        parameters['divWeight'] = self.DivisionWeight.value
        parameters['transWeight'] = self.TransitionWeight.value
        parameters['appearanceCost'] = self.AppearanceWeight.value
        parameters['disappearanceCost'] = self.DisappearanceWeight.value

        self.Parameters.setValue(parameters)

        return [
            self.DetectionWeight.value, self.DivisionWeight.value,
            self.TransitionWeight.value, self.AppearanceWeight.value,
            self.DisappearanceWeight.value
        ]

    def getLabelInCrop(self, cropKey, time, track):
        """
        Find the label annotated with exactly ``track`` in one frame of a crop.

        Looks at the annotation labels of crop ``cropKey`` at frame ``time``
        and returns the first label whose track set equals ``{track}``. A
        label that carries ``track`` together with other tracks does not
        match. Returns -1 when no label matches.
        """
        wanted = {track}
        frameLabels = self.Annotations.value[cropKey]["labels"][time]
        return next(
            (label for label, tracks in frameLabels.items()
             if tracks == wanted),
            -1)

    def _type(self, cropKey, time, track):
        """
        Classify the annotation of ``track`` at frame ``time`` within a crop.

        Returns a list whose first element names the type:

        * ``["FALSE_DETECTION"]`` when ``track`` is -1,
        * ``["FIRST"]`` when the track is not annotated in any earlier frame
          of the crop but continues in frame ``time + 1``,
        * ``["INTERMEDIATE", previous_label]`` when it is annotated in both
          ``time - 1`` and ``time + 1``,
        * ``["LAST", previous_label]`` when it is not annotated in any later
          frame of the crop,
        * ``["SINGLETON(FIRST_LAST)"]`` when it is annotated in neither an
          earlier nor a later frame.

        ``previous_label`` is the label that carried the track in the latest
        earlier frame where it was annotated. ``None`` is returned when the
        annotations have a gap, so that no type can be determined.

        Frames without any entry in the annotation dictionary are treated as
        frames in which the track is absent, in both scan directions.
        """
        if track == -1:
            return ["FALSE_DETECTION"]

        # Local is deliberately not called `type` to avoid shadowing the builtin.
        trackType = "FIRST" if time == 0 else None

        labels = self.Annotations.value[cropKey]["labels"]
        crop = self._crops[cropKey]

        # Scan the frames before `time`: remember the latest one containing the track.
        lastTime = -1
        lastLabel = -1
        for t in range(crop["time"][0], time):
            if t not in labels:
                # Frame without annotations; same guard as the forward scan below.
                continue
            for label in labels[t]:
                if track in labels[t][label]:
                    lastTime = t
                    lastLabel = label
        if lastTime == -1:
            trackType = "FIRST"
        elif lastTime < time - 1:
            logger.info(
                "ERROR: Your annotations are not complete. See time frame {}.".
                format(time - 1))
        elif lastTime == time - 1:
            trackType = "INTERMEDIATE"

        # Scan the frames after `time` from the crop end backwards, so the
        # value left over is the earliest later frame containing the track.
        firstTime = -1
        for t in range(crop["time"][1], time, -1):
            if t in labels:
                for label in labels[t]:
                    if track in labels[t][label]:
                        firstTime = t
        if firstTime == -1:
            if trackType == "FIRST":
                return ["SINGLETON(FIRST_LAST)"]
            return ["LAST", lastLabel]
        elif firstTime > time + 1:
            logger.info(
                "ERROR: Your annotations are not complete. See time frame {}.".
                format(time + 1))
        elif firstTime == time + 1:
            if trackType == "INTERMEDIATE":
                return ["INTERMEDIATE", lastLabel]
            elif trackType is not None:
                return [trackType]

        # Incomplete annotations: the caller is expected to handle None.
        return None
Exemple #24
0
class OpNNClassification(Operator):
    """
    Top-level operator for neural-network pixel classification.

    Wires together the internal label pipeline, the blocked training
    operator, a single-value classifier cache and the prediction pipeline,
    and keeps all multi-lane input slots the same length.
    """

    NO_MODEL = _NO_MODEL

    name = "OpNNClassification"
    category = "Top-level"

    # Graph inputs
    InputImages = InputSlot(level=1)
    ServerConfig = InputSlot(stype=stype.Opaque, nonlane=True)
    Checkpoints = InputSlot()

    NumClasses = InputSlot()
    LabelInputs = InputSlot(optional=True, level=1)
    FreezePredictions = InputSlot(stype="bool", value=False, nonlane=True)
    ModelBinary = InputSlot(stype=stype.Opaque, nonlane=True)
    # Contains cached model info
    ModelInfo = InputSlot(stype=stype.Opaque, nonlane=True, optional=True)
    ModelSession = InputSlot()

    Classifier = OutputSlot()
    # Classification predictions (via feature cache for interactive speed)
    PredictionProbabilities = OutputSlot(level=1)
    # Classification predictions, enumerated by channel
    PredictionProbabilityChannels = OutputSlot(level=2)
    CachedPredictionProbabilities = OutputSlot(level=1)
    LabelImages = OutputSlot(level=1)
    NonzeroLabelBlocks = OutputSlot(level=1)

    Halo_Size = InputSlot(value=0)
    Batch_Size = InputSlot(value=1)

    # Gui only (not part of the pipeline)
    LabelNames = OutputSlot()
    LabelColors = OutputSlot()
    PmapColors = OutputSlot()

    def setupOutputs(self):
        """Size the GUI-only slots by class count and hook up the block shape."""
        numClasses = self.NumClasses.value

        self.LabelNames.meta.dtype = object
        self.LabelNames.meta.shape = (numClasses, )
        self.LabelColors.meta.dtype = object
        self.LabelColors.meta.shape = (numClasses, )
        self.PmapColors.meta.dtype = object
        self.PmapColors.meta.shape = (numClasses, )

        # The inference block shape is only forwarded once it is available.
        if self.opBlockShape.BlockShapeInference.ready():
            self.opPredictionPipeline.BlockShape.connect(
                self.opBlockShape.BlockShapeInference)

    def cleanUp(self):
        """Close the model session (best effort), then run the base-class cleanup."""
        try:
            self.ModelSession.value.close()
        except Exception as e:
            # Best effort: a missing or already-closed session must not
            # prevent the rest of the teardown.
            logger.warning(e)
        # Chain to the base class so the operator's own teardown still runs.
        super(OpNNClassification, self).cleanUp()

    def __init__(self, *args, connectionFactory, **kwargs):
        """
        Instantiate all internal operators and connect them together.

        ``connectionFactory`` is keyword-only and is stored on the instance;
        all other arguments are forwarded to the base class.
        """
        super(OpNNClassification, self).__init__(*args, **kwargs)
        self._connectionFactory = connectionFactory
        #
        # Default values for some input slots
        self.FreezePredictions.setValue(True)
        self.LabelNames.setValue([])
        self.LabelColors.setValue([])
        self.PmapColors.setValue([])

        self.Checkpoints.setValue([])
        self._binary_model = None

        # SPECIAL connection: the LabelInputs slot doesn't get its data
        # from the InputImages slot, but its shape must match.
        self.LabelInputs.connect(self.InputImages)

        self.opBlockShape = OpMultiLaneWrapper(OpBlockShape, parent=self)
        self.opBlockShape.RawImage.connect(self.InputImages)
        self.opBlockShape.ModelSession.connect(self.ModelSession)

        # self.opModel = OpModel(parent=self.parent, connectionFactory=connectionFactory)
        # self.opModel.ServerConfig.connect(self.ServerConfig)
        # self.opModel.ModelBinary.connect(self.ModelBinary)

        # self.ModelSession.connect(self.opModel.TiktorchModel)
        # self.NumClasses.connect(self.opModel.NumClasses)

        # Hook up Labeling Pipeline
        self.opLabelPipeline = OpMultiLaneWrapper(
            OpLabelPipeline,
            parent=self,
            broadcastingSlotNames=["DeleteLabel"])
        self.opLabelPipeline.RawImage.connect(self.InputImages)
        self.opLabelPipeline.LabelInput.connect(self.LabelInputs)
        self.opLabelPipeline.DeleteLabel.setValue(-1)
        self.LabelImages.connect(self.opLabelPipeline.Output)
        self.NonzeroLabelBlocks.connect(self.opLabelPipeline.nonzeroBlocks)

        # TRAINING OPERATOR
        self.opTrain = OpTikTorchTrainClassifierBlocked(parent=self)
        self.opTrain.ModelSession.connect(self.ModelSession)
        self.opTrain.Labels.connect(self.opLabelPipeline.Output)
        self.opTrain.Images.connect(self.InputImages)
        self.opTrain.BlockShape.connect(self.opBlockShape.BlockShapeTrain)
        self.opTrain.nonzeroLabelBlocks.connect(
            self.opLabelPipeline.nonzeroBlocks)
        self.opTrain.MaxLabel.connect(self.NumClasses)

        # CLASSIFIER CACHE
        # This cache stores exactly one object: the classifier itself.
        self.classifier_cache = OpValueCache(parent=self)
        self.classifier_cache.name = "OpNetworkClassification.classifier_cache"
        self.classifier_cache.inputs["Input"].connect(
            self.opTrain.UpdatedModelSession)
        self.classifier_cache.inputs["fixAtCurrent"].connect(
            self.FreezePredictions)
        self.Classifier.connect(self.classifier_cache.Output)

        # Hook up the prediction pipeline inputs
        self.opPredictionPipeline = OpMultiLaneWrapper(OpPredictionPipeline,
                                                       parent=self)
        self.opPredictionPipeline.RawImage.connect(self.InputImages)
        # self.opPredictionPipeline.Classifier.connect(self.classifier_cache.Output)
        self.opPredictionPipeline.Classifier.connect(self.ModelSession)
        self.opPredictionPipeline.NumClasses.connect(self.NumClasses)
        self.opPredictionPipeline.FreezePredictions.connect(
            self.FreezePredictions)

        self.PredictionProbabilities.connect(
            self.opPredictionPipeline.PredictionProbabilities)
        self.CachedPredictionProbabilities.connect(
            self.opPredictionPipeline.CachedPredictionProbabilities)
        self.PredictionProbabilityChannels.connect(
            self.opPredictionPipeline.PredictionProbabilityChannels)

        def inputResizeHandler(slot, oldsize, newsize):
            # When the last lane disappears, shrink the lane outputs as well.
            if newsize == 0:
                self.LabelImages.resize(0)
                self.NonzeroLabelBlocks.resize(0)
                self.PredictionProbabilities.resize(0)
                self.CachedPredictionProbabilities.resize(0)

        self.InputImages.notifyResized(inputResizeHandler)

        # Debug assertions: Check to make sure the non-wrapped operators stayed that way.
        assert self.opTrain.Images.operator == self.opTrain

        def handleNewInputImage(multislot, index, *args):
            def handleInputReady(slot):
                # Resolve the lane position when the notification fires: the
                # insertion-time `index` is stale once earlier lanes have
                # been removed, and both calls must refer to the same lane.
                laneIndex = multislot.index(slot)
                self._checkConstraints(laneIndex)
                self.setupCaches(laneIndex)

            multislot[index].notifyReady(handleInputReady)

        self.InputImages.notifyInserted(handleNewInputImage)

        # All input multi-slots should be kept in sync
        # Output multi-slots will auto-sync via the graph
        def insertSlot(a, b, position, finalsize):
            a.insertSlot(position, finalsize)

        def removeSlot(a, b, position, finalsize):
            a.removeSlot(position, finalsize)

        multiInputs = [s for s in self.inputs.values() if s.level >= 1]
        for s1 in multiInputs:
            for s2 in multiInputs:
                if s1 != s2:
                    # `partial` binds the slot to mirror as the first argument.
                    s1.notifyInserted(partial(insertSlot, s2))
                    s1.notifyRemoved(partial(removeSlot, s2))

    def set_model(self, model_content: bytes) -> bool:
        """
        Store ``model_content`` in ``ModelBinary`` and report model readiness.

        Returns the readiness of ``self.opModel.TiktorchModel`` when an
        ``opModel`` attribute exists; otherwise the readiness of the
        ``ModelSession`` slot.
        """
        self.ModelBinary.disconnect()
        self.ModelBinary.setValue(model_content)
        # `opModel` is not created by __init__ (its construction is commented
        # out there), so an unconditional access would raise AttributeError.
        opModel = getattr(self, "opModel", None)
        if opModel is not None:
            return opModel.TiktorchModel.ready()
        return self.ModelSession.ready()

    def update_config(self, partial_config: dict):
        """
        Forward ``partial_config`` to the classifier factory, deferring the
        call until the ``ClassifierFactory`` slot is ready if necessary.
        """
        # NOTE(review): no `ClassifierFactory` slot is declared on this class
        # in the visible code - confirm where it is expected to come from.
        self.ClassifierFactory.meta.hparams = partial_config

        def _send_hparams(slot):
            classifierFactory = self.ClassifierFactory[:].wait()[0]
            classifierFactory.update_config(
                self.ClassifierFactory.meta.hparams)

        if not self.ClassifierFactory.ready():
            self.ClassifierFactory.notifyReady(_send_hparams)
        else:
            classifierFactory = self.ClassifierFactory[:].wait()[0]
            classifierFactory.update_config(partial_config)

    def setupCaches(self, imageIndex):
        """
        Resize ``LabelInputs`` to the number of images and give lane
        ``imageIndex`` the image shape with the channel axis set to 1.
        """
        numImages = len(self.InputImages)
        inputSlot = self.InputImages[imageIndex]

        self.LabelInputs.resize(numImages)

        # Special case: We have to set up the shape of our label *input* according to our image input shape
        shapeList = list(inputSlot.meta.shape)
        try:
            channelIndex = inputSlot.meta.axistags.index("c")
            shapeList[channelIndex] = 1
        except Exception:
            # No usable channel axis: keep the image shape unchanged.
            # (Narrowed from a bare `except:` so that KeyboardInterrupt and
            # SystemExit are no longer swallowed.)
            pass
        self.LabelInputs[imageIndex].meta.shape = tuple(shapeList)
        self.LabelInputs[imageIndex].meta.axistags = inputSlot.meta.axistags

    def _checkConstraints(self, laneIndex):
        """
        Ensure that all input images have the same number of channels
        and the same dimensionality (time axis excluded).

        Raises DatasetConstraintError on a mismatch. Does nothing when the
        lane is not ready.
        """
        if not self.InputImages[laneIndex].ready():
            return

        thisLaneTaggedShape = self.InputImages[laneIndex].meta.getTaggedShape()

        # Find a different lane and use it for comparison
        # (falls back to this lane itself when no other lane is ready)
        validShape = thisLaneTaggedShape
        for i, slot in enumerate(self.InputImages):
            if slot.ready() and i != laneIndex:
                validShape = slot.meta.getTaggedShape()
                break

        if "t" in thisLaneTaggedShape:
            del thisLaneTaggedShape["t"]
        if "t" in validShape:
            del validShape["t"]

        if validShape["c"] != thisLaneTaggedShape["c"]:
            raise DatasetConstraintError(
                "Pixel Classification with CNNs",
                "All input images must have the same number of channels.  "
                "Your new image has {} channel(s), but your other images have {} channel(s)."
                .format(thisLaneTaggedShape["c"], validShape["c"]),
            )

        if len(validShape) != len(thisLaneTaggedShape):
            raise DatasetConstraintError(
                "Pixel Classification with CNNs",
                "All input images must have the same dimensionality.  "
                "Your new image has {} dimensions (including channel), but your other images have {} dimensions."
                .format(len(thisLaneTaggedShape), len(validShape)),
            )

    def setInSlot(self, slot, subindex, roi, value):
        # Nothing to do here: All inputs that support __setitem__
        #   are directly connected to internal operators.
        pass

    def propagateDirty(self, slot, subindex, roi):
        # Any dirty input marks the whole per-channel prediction output
        # dirty; the remaining outputs are connected to internal operators.
        self.PredictionProbabilityChannels.setDirty(slice(None))

    def addLane(self, laneIndex):
        """Append one image lane; ``laneIndex`` must equal the current lane count."""
        numLanes = len(self.InputImages)
        assert numLanes == laneIndex, f"Image lanes must be appended. {numLanes}, {laneIndex})"
        self.InputImages.resize(numLanes + 1)

    def removeLane(self, laneIndex, finalLength):
        """Remove lane ``laneIndex`` from ``InputImages``."""
        self.InputImages.removeSlot(laneIndex, finalLength)

    def getLane(self, laneIndex):
        """Return an ``OperatorSubView`` of this operator for ``laneIndex``."""
        return OperatorSubView(self, laneIndex)

    def importLabels(self, laneIndex, slot):
        """
        Ingest label data from ``slot`` into the label array of lane
        ``laneIndex`` and extend the label names and colors if the ingested
        data contains a larger label value than the number of known names.
        """
        # Load the data into the cache
        new_max = self.getLane(
            laneIndex).opLabelPipeline.opLabelArray.ingestData(slot)

        # Add to the list of label names if there's a new max label
        old_names = self.LabelNames.value
        old_max = len(old_names)
        if new_max > old_max:
            new_names = old_names + [
                "Label {}".format(x) for x in range(old_max + 1, new_max + 1)
            ]
            self.LabelNames.setValue(new_names)

            # Make some default colors, too
            # FIXME: take the colors from default16_new
            from volumina import colortables

            default_colors = colortables.default16_new

            label_colors = self.LabelColors.value
            pmap_colors = self.PmapColors.value

            self.LabelColors.setValue(label_colors +
                                      default_colors[old_max:new_max])
            self.PmapColors.setValue(pmap_colors +
                                     default_colors[old_max:new_max])

    def mergeLabels(self, from_label, into_label):
        """Merge ``from_label`` into ``into_label`` in every lane's label array."""
        for laneIndex in range(len(self.InputImages)):
            self.getLane(laneIndex).opLabelPipeline.opLabelArray.mergeLabels(
                from_label, into_label)

    def clearLabel(self, label_value):
        """Clear ``label_value`` in every lane's label array."""
        for laneIndex in range(len(self.InputImages)):
            self.getLane(laneIndex).opLabelPipeline.opLabelArray.clearLabel(
                label_value)
Exemple #25
0
class OpFilterLabels(Operator):
    """
    Given a labeled volume, discard labels that have too few pixels.
    Zero is used as the background label.

    When ``MaxLabelSize`` is ready, labels with more pixels than that are
    discarded as well. When ``BinaryOut`` is true, every surviving non-zero
    label is replaced by 1.
    """
    name = "OpFilterLabels"
    category = "generic"

    Input = InputSlot()
    MinLabelSize = InputSlot(stype='int')
    MaxLabelSize = InputSlot(optional=True, stype='int')
    BinaryOut = InputSlot(optional=True, value=False, stype='bool')

    Output = OutputSlot()

    def setupOutputs(self):
        """The output has the same meta-data as the input."""
        self.Output.meta.assignFrom(self.Input.meta)

    def execute(self, slot, subindex, roi, result):
        """Read the requested region into ``result`` and filter it in place."""
        minSize = self.MinLabelSize.value
        maxSize = None
        if self.MaxLabelSize.ready():
            maxSize = self.MaxLabelSize.value
        req = self.Input.get(roi)
        req.writeInto(result)
        req.wait()

        # The return value is ignored on purpose: `result` is modified in place.
        self.remove_wrongly_sized_connected_components(result,
                                                       min_size=minSize,
                                                       max_size=maxSize,
                                                       in_place=True)
        return result

    def propagateDirty(self, inputSlot, subindex, roi):
        # Every input slot can affect the entire output. BinaryOut is an
        # input slot too, so it must be accepted here; otherwise changing it
        # trips the assertion.
        assert any(inputSlot == known for known in (self.Input,
                                                    self.MinLabelSize,
                                                    self.MaxLabelSize,
                                                    self.BinaryOut))
        self.Output.setDirty(slice(None))

    def remove_wrongly_sized_connected_components(self, a, min_size, max_size,
                                                  in_place):
        """
        Zero out every label in ``a`` whose pixel count is below ``min_size``
        or, if ``max_size`` is not None, above ``max_size``.

        ``a`` is modified directly when ``in_place`` is true, otherwise a
        copy is processed. When nothing can be filtered the processed array
        itself is returned; otherwise a new array with the original dtype is
        returned.

        Adapted from http://github.com/jni/ray/blob/develop/ray/morpho.py
        (MIT License)
        """
        bin_out = self.BinaryOut.value

        original_dtype = a.dtype

        if not in_place:
            a = a.copy()
        if min_size == 0 and (max_size is None or max_size > numpy.prod(
                a.shape)):  # shortcut for efficiency
            if (bin_out):
                numpy.place(a, a, 1)
            return a

        try:
            # component_sizes[label] == number of pixels carrying that label
            component_sizes = numpy.bincount(a.ravel())
        except TypeError:
            # On 32-bit systems, must explicitly convert from uint32 to int
            # (This fix is just for VM testing.)
            component_sizes = numpy.bincount(
                numpy.asarray(a.ravel(), dtype=int))
        bad_sizes = component_sizes < min_size
        if max_size is not None:
            numpy.logical_or(bad_sizes,
                             component_sizes > max_size,
                             out=bad_sizes)

        # Look up, per pixel, whether its label has a rejected size.
        bad_locations = bad_sizes[a]
        a[bad_locations] = 0
        if (bin_out):
            # Replace non-zero values with 1
            numpy.place(a, a, 1)
        return numpy.array(a, dtype=original_dtype)
Exemple #26
0
class OpDeviationFromMean(Operator):
    """
    Multi-image operator.
    Calculates the pixelwise mean of a set of images, and produces a set of corresponding images for the difference from the mean.
    Each difference image is multiplied by ``ScalingFactor`` and then shifted by ``Offset``.
    Note: Inputs must all have the same shape.
    """
    ScalingFactor = InputSlot()  # Scale after subtraction
    Offset = InputSlot()  # Offset final results
    Input = InputSlot(level=1)  # Multi-image input

    Mean = OutputSlot()
    Output = OutputSlot(level=1)  # Multi-image output

    # Set on the instance once the lane insert/remove callbacks are subscribed.
    _laneCallbacksRegistered = False

    def setupOutputs(self):
        """
        Validate the input shapes and copy the input meta-data to the outputs.

        Raises RuntimeError if the input images differ in shape.
        """
        # Ensure all inputs have the same shape
        if len(self.Input) > 0:
            shape = self.Input[0].meta.shape
            for islot in self.Input:
                if islot.meta.shape != shape:
                    raise RuntimeError(
                        "Input images must have the same shape.")

        # Copy the meta info from each input to the corresponding output
        self.Output.resize(len(self.Input))
        for index, islot in enumerate(self.Input):
            self.Output[index].meta.assignFrom(islot.meta)

        # The mean has the meta-data of the images; with no images there is
        # nothing to copy from (and indexing lane 0 would fail).
        if len(self.Input) > 0:
            self.Mean.meta.assignFrom(self.Input[0].meta)

        # Subscribe only once: setupOutputs can run repeatedly, and every
        # extra subscription would add another redundant callback.
        if not self._laneCallbacksRegistered:

            def markAllOutputsDirty(*args):
                self.propagateDirty(self.Input, (), slice(None))

            self.Input.notifyInserted(markAllOutputsDirty)
            self.Input.notifyRemoved(markAllOutputsDirty)
            self._laneCallbacksRegistered = True

    def execute(self, slot, subindex, roi, result):
        """
        Compute.  This is a simple implementation, without optimizations.

        Fills ``result`` with the mean of all inputs for ``Mean``, or with the
        scaled and shifted deviation of input ``subindex`` for ``Output``.
        """
        # Compute average of *all* inputs
        result[:] = 0.0
        for s in self.Input:
            result[:] += s.get(roi).wait()
        result[:] = result / len(self.Input)

        # If the user wanted the mean, we're done.
        if slot == self.Mean:
            return result

        assert slot == self.Output

        # Subtract average from the particular image being requested
        result[:] = self.Input[subindex].get(roi).wait() - result

        # Scale
        result[:] *= self.ScalingFactor.value

        # Add constant offset
        result[:] += self.Offset.value

        return result

    def propagateDirty(self, slot, subindex, roi):
        # If the dirty slot is one of our two constants, then the entire image region is dirty
        if slot == self.Offset or slot == self.ScalingFactor:
            roi = slice(None)  # The whole image region

        # All inputs affect all outputs, so every image is dirty now
        for oslot in self.Output:
            oslot.setDirty(roi)

    #############################################
    ## Methods to satisfy MultiLaneOperatorABC ##
    #############################################

    def addLane(self, laneIndex):
        """
        Add an image lane to the top-level operator.
        """
        numLanes = len(self.Input)
        assert numLanes == laneIndex, "Image lanes must be appended."
        self.Input.resize(numLanes + 1)
        self.Output.resize(numLanes + 1)

    def removeLane(self, laneIndex, finalLength):
        """
        Remove the specified image lane from the top-level operator.
        """
        self.Input.removeSlot(laneIndex, finalLength)
        self.Output.removeSlot(laneIndex, finalLength)

    def getLane(self, laneIndex):
        """Return an ``OperatorSubView`` of this operator for ``laneIndex``."""
        return OperatorSubView(self, laneIndex)
Exemple #27
0
class OpTaskWorker(Operator):
    """
    Executes a single task of a blockwise computation.

    Requesting ``ReturnCode`` streams the region named by ``RoiString`` out
    of ``Input``, writes it into the ``BlockwiseFileset`` opened from
    ``OutputFilesetDescription``, and marks that block as available.
    Progress is reported through ``progressSignal``.
    """
    Input = InputSlot()
    RoiString = InputSlot(stype="string")
    TaskName = InputSlot(stype="string")
    ConfigFilePath = InputSlot(stype="filestring")
    OutputFilesetDescription = InputSlot(stype="filestring")

    ReturnCode = OutputSlot()

    def __init__(self, *args, **kwargs):
        super(OpTaskWorker, self).__init__(*args, **kwargs)
        self.progressSignal = OrderedSignal()
        # Opened in setupOutputs(), released in _closeFiles().
        self._primaryBlockwiseFileset = None

    def setupOutputs(self):
        """
        Configure ``ReturnCode`` as a single bool and (re)open the fileset.
        """
        return_meta = self.ReturnCode.meta
        return_meta.dtype = bool
        return_meta.shape = (1, )

        # Release any fileset left over from a previous configuration
        # before opening the current one in append mode.
        self._closeFiles()
        self._primaryBlockwiseFileset = BlockwiseFileset(
            self.OutputFilesetDescription.value, "a")

    def cleanUp(self):
        """
        Close the output fileset, then run the base-class cleanup.
        """
        self._closeFiles()
        super(OpTaskWorker, self).cleanUp()

    def _closeFiles(self):
        """
        Close the primary fileset if one is open and forget it.
        """
        open_fileset = self._primaryBlockwiseFileset
        if open_fileset is not None:
            open_fileset.close()
        self._primaryBlockwiseFileset = None

    def execute(self, slot, subindex, ignored_roi, result):
        """
        Compute the block described by ``RoiString`` and write it to disk.

        The roi passed in by the caller is ignored; the region to process is
        parsed from ``RoiString``.  On completion ``result[0]`` is set to
        ``True`` and ``result`` is returned.  Invalid configurations are
        reported via ``AssertionError``.
        """
        config = parseClusterConfigFile(self.ConfigFilePath.value)
        blockwiseFileset = self._primaryBlockwiseFileset

        # Input and output must use the same set of axes (order may differ).
        input_axes = list(self.Input.meta.getTaggedShape().keys())
        output_axes = list(blockwiseFileset.description.axes)
        assert set(input_axes) == set(output_axes), \
            "Output dataset has the wrong set of axes.  Input axes: {}, Output axes: {}".format(
                "".join(input_axes), "".join(output_axes))

        roi_text = self.RoiString.value
        roi = Roi.loads(roi_text)
        assert len(roi.start) == len(self.Input.meta.shape), \
            "Task roi: {} is not valid for this input.  Did the master launch this task correctly?".format(
                roi_text)

        logger.info("Executing for roi: {}".format(roi))

        # Node-local scratch space is not implemented.
        assert not config.use_node_local_scratch, "FIXME."

        assert (blockwiseFileset.getEntireBlockRoi(roi.start)[1] == roi.stop).all(), \
            "Each task must execute exactly one full block.  ({},{}) is not a valid block roi.".format(
                roi.start, roi.stop)
        assert self.Input.ready()

        with Timer() as computeTimer:
            # Stream the data out to disk, one request-sized piece at a time.
            # The sub-block shape may be None; that is acceptable.
            sub_block_shape = self._primaryBlockwiseFileset.description.sub_block_shape
            streamer = BigRequestStreamer(self.Input, (roi.start, roi.stop),
                                          sub_block_shape)
            streamer.progressSignal.subscribe(self.progressSignal)
            streamer.resultSignal.subscribe(self._handlePrimaryResultBlock)
            streamer.execute()

            # The whole block has been written; update its status.
            blockwiseFileset.setBlockStatus(roi.start,
                                            BlockwiseFileset.BLOCK_AVAILABLE)

        elapsed_seconds = computeTimer.seconds()
        logger.info("Finished task in {} seconds".format(elapsed_seconds))
        result[0] = True
        return result

    def propagateDirty(self, slot, subindex, roi):
        """
        Any input change invalidates the return code.
        """
        self.ReturnCode.setDirty(slice(None))

    def _handlePrimaryResultBlock(self, roi, result):
        """
        Write one streamed piece to the fileset, then let the workflow
        post-process it.
        """
        # The primary data is written first...
        self._primaryBlockwiseFileset.writeData(roi, result)

        # ...then the workflow gets a chance to do any special post-processing.
        workflow = self.get_workflow()
        workflow.postprocessClusterSubResult(roi, result,
                                             self._primaryBlockwiseFileset)

    def get_workflow(self):
        """
        Walk up the ``parent`` chain and return the first ``Workflow`` found.
        """
        candidate = self
        while True:
            if isinstance(candidate, Workflow):
                return candidate
            candidate = candidate.parent
class OpTrainPixelwiseClassifierBlocked(Operator):
    """
    Trains a pixelwise classifier from the labeled blocks of all image lanes.

    For every lane, the blocks listed in ``nonzeroLabelBlocks`` are read from
    ``Labels``, cropped to the bounding box of their non-zero labels, padded
    by the halo requested by the classifier factory, and paired with the
    matching region of ``Images``.  The collected blocks are handed to
    ``ClassifierFactory`` and the trained classifier is written to
    ``Classifier`` (``None`` when no labeled block was found).
    """
    Images = InputSlot(level=1)
    Labels = InputSlot(level=1)
    ClassifierFactory = InputSlot()
    nonzeroLabelBlocks = InputSlot(level=1)
    MaxLabel = InputSlot()

    Classifier = OutputSlot()

    def __init__(self, *args, **kwargs):
        super(OpTrainPixelwiseClassifierBlocked,
              self).__init__(*args, **kwargs)
        self.progressSignal = OrderedSignal()

        # Removing a lane does not normally produce a dirty notification.
        # Here it must: if the removed lane ever held label data, the
        # classifier is stale.  A label slot counts as having held data once
        # it has been 'touched', i.e. it became dirty at some point.
        self._touched_slots = set()

        def remember_touched(slot, roi):
            self._touched_slots.add(slot)

        def on_lane_inserted(multislot, index, newlength):
            multislot[index].notifyDirty(remember_touched)

        def on_lane_removal(multislot, index, newlength):
            # A lane that contained label data invalidates the classifier.
            departing_lane = multislot[index]
            if departing_lane in self._touched_slots:
                self.Classifier.setDirty()
                self._touched_slots.remove(departing_lane)

        self.Labels.notifyInserted(on_lane_inserted)
        self.Labels.notifyRemove(on_lane_removal)

    def setupOutputs(self):
        """
        Validate the axis order of all lanes and configure ``Classifier``.
        """
        lane_slots = list(self.Images) + list(self.Labels)
        for lane_slot in lane_slots:
            last_axis = lane_slot.meta.getAxisKeys()[-1]
            assert last_axis == 'c', \
                "This opearator assumes channel is the last axis."

        classifier_meta = self.Classifier.meta
        classifier_meta.dtype = object
        classifier_meta.shape = (1, )

        # Extra metadata for downstream operators that use the classifier.
        classifier_meta.classifier_factory = self.ClassifierFactory.value

    def cleanUp(self):
        """
        Clear the progress signal, then run the base-class cleanup.
        """
        self.progressSignal.clean()
        super(OpTrainPixelwiseClassifierBlocked, self).cleanUp()

    def execute(self, slot, subindex, roi, result):
        """
        Gather all labeled blocks and train a classifier into ``result[0]``.
        """
        classifier_factory = self.ClassifierFactory.value
        assert issubclass(type(classifier_factory), LazyflowPixelwiseClassifierFactoryABC), \
            "Factory is of type {}, which does not satisfy the LazyflowPixelwiseClassifierFactoryABC interface.".format(
                type(classifier_factory))

        # One entry per labeled block, across all lanes.
        label_data_blocks = []
        image_data_blocks = []
        lanes = zip(self.Images, self.Labels, self.nonzeroLabelBlocks)
        for image_slot, label_slot, nonzero_block_slot in lanes:
            for block_slicing in nonzero_block_slot.value:
                # Fetch the labels of this block.
                block_label_roi = sliceToRoi(block_slicing,
                                             label_slot.meta.shape)
                block_label_data = label_slot(*block_label_roi).wait()

                # Tighten the roi to the bounding box of the label pixels,
                # expressed in image coordinates.
                bb_roi_within_block = nonzero_bounding_box(block_label_data)
                block_label_bb_roi = bb_roi_within_block + block_label_roi[0]

                # Skip blocks that turn out to hold no non-zero label at all.
                if not (block_label_bb_roi[1] > block_label_bb_roi[0]).all():
                    continue

                # The classifier decides how much context (halo) it needs.
                axiskeys = image_slot.meta.getAxisKeys()
                halo_shape = classifier_factory.get_halo_shape(axiskeys)
                assert len(halo_shape) == len(block_label_roi[0])
                assert halo_shape[-1] == 0, \
                    "Didn't expect a non-zero halo for channel dimension."

                # Grow the bounding box by the halo, clipped to the image.
                padded_label_roi, bb_roi_within_padded = enlargeRoiForHalo(
                    *block_label_bb_roi,
                    shape=label_slot.meta.shape,
                    sigma=halo_shape,
                    window=1,
                    return_result_roi=True)

                # Labels go into a fresh array of size bounding-box + halo;
                # the halo region stays zero (unlabeled).
                padded_shape = padded_label_roi[1] - padded_label_roi[0]
                padded_label_data = numpy.zeros(padded_shape,
                                                dtype=label_slot.meta.dtype)
                target_slicing = roiToSlice(*bb_roi_within_padded)
                source_slicing = roiToSlice(*bb_roi_within_block)
                padded_label_data[target_slicing] = block_label_data[
                    source_slicing]

                # Same spatial region of the image, but with all channels.
                padded_image_roi = numpy.array(padded_label_roi)
                assert (padded_image_roi[:, -1] == [0, 1]).all()
                num_channels = image_slot.meta.shape[-1]
                padded_image_roi[:, -1] = [0, num_channels]

                # Convert to a plain ndarray; some classifiers might have
                # trouble with a VigraArray.
                padded_image_data = numpy.asarray(
                    image_slot(*padded_image_roi).wait())

                label_data_blocks.append(padded_label_data)
                image_data_blocks.append(padded_image_data)

        if not image_data_blocks:
            # Nothing is labeled, so there is nothing to train on.
            result[0] = None
            return

        first_image_meta = self.Images[0].meta
        channel_names = first_image_meta.channel_names
        axistags = first_image_meta.axistags
        logger.debug("Training new pixelwise classifier: {}".format(
            classifier_factory.description))
        classifier = classifier_factory.create_and_train_pixelwise(
            image_data_blocks, label_data_blocks, axistags, channel_names)
        result[0] = classifier
        if classifier is not None:
            assert issubclass(type(classifier), LazyflowPixelwiseClassifierABC), \
                "Classifier is of type {}, which does not satisfy the LazyflowPixelwiseClassifierABC interface.".format(
                    type(classifier))

    def propagateDirty(self, slot, subindex, roi):
        """
        Any input change invalidates the trained classifier.
        """
        self.Classifier.setDirty()
Exemple #29
0
class OpTrackingFeatureExtraction(Operator):
    """
    Combines object (vigra) features and division features for tracking.

    Internally wires an ``OpObjectExtraction`` and an
    ``OpCachedDivisionFeatures`` together and exposes their outputs.  The two
    ``...All`` outputs are computed by this operator itself by merging the
    vigra and division results.
    """
    name = "Tracking Feature Extraction"

    TranslationVectors = InputSlot(optional=True)
    RawImage = InputSlot()
    BinaryImage = InputSlot()

    # Which features to compute.
    # Nested dictionary with format:
    # dict[plugin_name][feature_name][parameter_name] = parameter_value
    # for example {"Standard Object Features": {"Mean in neighborhood":{"margin": (5, 5, 2)}}}
    FeatureNamesVigra = InputSlot(rtype=List, stype=Opaque, value={})

    FeatureNamesDivision = InputSlot(rtype=List, stype=Opaque, value={})

    LabelImage = OutputSlot()
    ObjectCenterImage = OutputSlot()

    # The computed features.
    # Nested dictionary with format:
    # dict[plugin_name][feature_name] = feature_value
    RegionFeaturesVigra = OutputSlot(stype=Opaque, rtype=List)
    RegionFeaturesDivision = OutputSlot(stype=Opaque, rtype=List)
    RegionFeaturesAll = OutputSlot(stype=Opaque, rtype=List)

    ComputedFeatureNamesVigra = OutputSlot(rtype=List, stype=Opaque)
    ComputedFeatureNamesAll = OutputSlot(rtype=List, stype=Opaque)

    BlockwiseRegionFeaturesVigra = OutputSlot() # For compatibility with tracking workflow, the RegionFeatures output
                                                # has rtype=List, indexed by t.
                                                # For other workflows, output has rtype=ArrayLike, indexed by (t)
    BlockwiseRegionFeaturesDivision = OutputSlot()

    LabelInputHdf5 = InputSlot(optional=True)
    LabelOutputHdf5 = OutputSlot()
    CleanLabelBlocks = OutputSlot()

    RegionFeaturesCacheInputVigra = InputSlot(optional=True)
    RegionFeaturesCleanBlocksVigra = OutputSlot()

    RegionFeaturesCacheInputDivision = InputSlot(optional=True)
    RegionFeaturesCleanBlocksDivision = OutputSlot()

    def __init__(self, parent):
        super(OpTrackingFeatureExtraction, self).__init__(parent)

        # internal operators
        self._objectExtraction = OpObjectExtraction(parent=self)

        self._opDivFeats = OpCachedDivisionFeatures(parent=self)
        self._opDivFeatsAdaptOutput = OpAdaptTimeListRoi(parent=self)

        # connect internal operators
        self._objectExtraction.RawImage.connect(self.RawImage)
        self._objectExtraction.BinaryImage.connect(self.BinaryImage)

        self._objectExtraction.Features.connect(self.FeatureNamesVigra)
        self._objectExtraction.LabelInputHdf5.connect(self.LabelInputHdf5)
        self._objectExtraction.RegionFeaturesCacheInput.connect(self.RegionFeaturesCacheInputVigra)
        self.LabelOutputHdf5.connect(self._objectExtraction.LabelOutputHdf5)
        self.CleanLabelBlocks.connect(self._objectExtraction.CleanLabelBlocks)
        self.RegionFeaturesCleanBlocksVigra.connect(self._objectExtraction.RegionFeaturesCleanBlocks)
        self.ObjectCenterImage.connect(self._objectExtraction.ObjectCenterImage)
        self.LabelImage.connect(self._objectExtraction.LabelImage)
        self.BlockwiseRegionFeaturesVigra.connect(self._objectExtraction.BlockwiseRegionFeatures)
        self.ComputedFeatureNamesVigra.connect(self._objectExtraction.Features)
        self.RegionFeaturesVigra.connect(self._objectExtraction.RegionFeatures)

        self._opDivFeats.LabelImage.connect(self.LabelImage)
        self._opDivFeats.DivisionFeatureNames.connect(self.FeatureNamesDivision)
        self._opDivFeats.CacheInput.connect(self.RegionFeaturesCacheInputDivision)
        self._opDivFeats.RegionFeaturesVigra.connect(self._objectExtraction.BlockwiseRegionFeatures)
        self.RegionFeaturesCleanBlocksDivision.connect(self._opDivFeats.CleanBlocks)
        self.BlockwiseRegionFeaturesDivision.connect(self._opDivFeats.Output)

        self._opDivFeatsAdaptOutput.Input.connect(self._opDivFeats.Output)
        self.RegionFeaturesDivision.connect(self._opDivFeatsAdaptOutput.Output)

        # As soon as input data is available, check its constraints
        self.RawImage.notifyReady( self._checkConstraints )
        self.BinaryImage.notifyReady( self._checkConstraints )

    def setupOutputs(self, *args, **kwargs):
        """
        Copy the metadata of the vigra slots onto the merged ``...All`` slots.
        """
        self.ComputedFeatureNamesAll.meta.assignFrom(self.FeatureNamesVigra.meta)
        self.RegionFeaturesAll.meta.assignFrom(self.RegionFeaturesVigra.meta)

    def execute(self, slot, subindex, roi, result):
        """
        Compute one of the merged outputs and return it as a new dict.

        * ``ComputedFeatureNamesAll``: union of the vigra and division feature
          name dictionaries, with one ``SquaredDistances_<i>`` entry per
          ``config.n_best_successors`` added under
          ``config.features_division_name``.
        * ``RegionFeaturesAll``: for every time step, the union of the vigra
          and division feature dictionaries.

        Raises ``AssertionError`` when the two sources share a plugin name or
        cover different time steps, and ``KeyError`` when squared-distance
        names must be added but the division plugin is not present.
        """
        if slot == self.ComputedFeatureNamesAll:
            feat_names_vigra = self.ComputedFeatureNamesVigra([]).wait()
            feat_names_div = self.FeatureNamesDivision([]).wait()
            for plugin_name in feat_names_vigra.keys():
                assert plugin_name not in feat_names_div, "feature name dictionaries must be mutually exclusive"
            for plugin_name in feat_names_div.keys():
                assert plugin_name not in feat_names_vigra, "feature name dictionaries must be mutually exclusive"

            # copy-then-update works on Python 2 and 3 alike (dict views
            # cannot be concatenated with '+' on Python 3).
            merged_names = dict(feat_names_vigra)
            merged_names.update(feat_names_div)

            # FIXME do not hard-code this
            squared_distance_names = [ 'SquaredDistances_' + str(i) for i in range(config.n_best_successors) ]
            if squared_distance_names:
                # Work on a copy of the nested dict: the outer copy above is
                # shallow, so writing into the nested dict directly would
                # modify the value held by the FeatureNamesDivision slot.
                division_key = config.features_division_name
                division_features = dict(merged_names[division_key])
                for name in squared_distance_names:
                    division_features[name] = {}
                merged_names[division_key] = division_features

            return merged_names
        elif slot == self.RegionFeaturesAll:
            feat_vigra = self.RegionFeaturesVigra(roi).wait()
            feat_div = self.RegionFeaturesDivision(roi).wait()
            # Both sources must cover exactly the same time steps.
            assert set(feat_vigra.keys()) == set(feat_div.keys())
            merged_features = {}
            for t in feat_vigra.keys():
                for plugin_name in feat_vigra[t].keys():
                    assert plugin_name not in feat_div[t], "feature dictionaries must be mutually exclusive"
                for plugin_name in feat_div[t].keys():
                    assert plugin_name not in feat_vigra[t], "feature dictionaries must be mutually exclusive"
                features_at_t = dict(feat_div[t])
                features_at_t.update(feat_vigra[t])
                merged_features[t] = features_at_t
            return merged_features
        else:
            assert False, "Shouldn't get here."

    def propagateDirty(self, slot, subindex, roi):
        """
        Mark ``ComputedFeatureNamesAll`` dirty when a feature-name input changes.
        """
        # FeatureNamesVigra is the input that feeds ComputedFeatureNamesVigra;
        # the output slot itself is kept in the check for backward
        # compatibility.
        if slot == self.FeatureNamesVigra or slot == self.ComputedFeatureNamesVigra \
                or slot == self.FeatureNamesDivision:
            self.ComputedFeatureNamesAll.setDirty(roi)

    def setInSlot(self, slot, subindex, roi, value):
        """
        Accept writes only on the cache/label input slots; nothing is stored here.
        """
        assert slot == self.LabelInputHdf5 or slot == self.RegionFeaturesCacheInputVigra or \
            slot == self.RegionFeaturesCacheInputDivision, "Invalid slot for setInSlot(): {}".format(slot.name)

    def _checkConstraints(self, *args):
        """
        Validate the input datasets as soon as they become ready.

        Raises ``DatasetConstraintError`` when a ready image lacks a time axis
        with at least two frames, or when the raw and binary images differ in
        any dimension other than channel.
        """
        if self.RawImage.ready():
            rawTaggedShape = self.RawImage.meta.getTaggedShape()
            if 't' not in rawTaggedShape or rawTaggedShape['t'] < 2:
                msg = "Raw image must have a time dimension with at least 2 images.\n"\
                    "Your dataset has shape: {}".format(self.RawImage.meta.shape)
                raise DatasetConstraintError( "Object Extraction", msg )

        if self.BinaryImage.ready():
            binTaggedShape = self.BinaryImage.meta.getTaggedShape()
            if 't' not in binTaggedShape or binTaggedShape['t'] < 2:
                msg = "Binary image must have a time dimension with at least 2 images.\n"\
                    "Your dataset has shape: {}".format(self.BinaryImage.meta.shape)
                raise DatasetConstraintError( "Object Extraction", msg )

        if self.RawImage.ready() and self.BinaryImage.ready():
            rawTaggedShape = self.RawImage.meta.getTaggedShape()
            binTaggedShape = self.BinaryImage.meta.getTaggedShape()
            # Channel counts are allowed to differ, so exclude them from the comparison.
            rawTaggedShape['c'] = None
            binTaggedShape['c'] = None
            if dict(rawTaggedShape) != dict(binTaggedShape):
                msg = "Raw data and other data must have equal dimensions (different channels are okay).\n"\
                      "Your datasets have shapes: {} and {}".format( self.RawImage.meta.shape, self.BinaryImage.meta.shape )
                logger.info(msg)
                raise DatasetConstraintError( "Object Extraction", msg )
Exemple #30
0
class OpWrapperFeatureSelection(Operator):
    """
    Selects features with a wrapper method from ``ilastik_feature_selection``.

    The input matrix holds the labels in its first column and the feature
    values in the remaining columns.  Only the first lane of
    ``FeatureLabelMatrix`` is used.
    """
    FeatureLabelMatrix = InputSlot(level=1)

    # "SFS", "BFS", or "SBE"; "SFS" is used when not connected.
    WrapperMethod = InputSlot(optional=True)

    # When connected, its value is handed to the default evaluation function.
    # Otherwise a scikit-learn random forest is created.
    Classifier = InputSlot(optional=True)

    # When not connected, a default evaluation function is built.
    EvaluationFunction = InputSlot(optional=True)

    # Only read when the default evaluation function is built; defaults to 0.07.
    ComplexityPenalty = InputSlot(optional=True)

    SelectedFeatureIDs = OutputSlot()

    def setupOutputs(self):
        """
        Resolve the optional inputs to concrete settings and configure the output.
        """
        method_slot = self.WrapperMethod
        self._wrapper_method = method_slot.value if method_slot.connected() else "SFS"

        if not self.Classifier.connected():
            from sklearn import ensemble

            self._classifier = ensemble.RandomForestClassifier(
                n_estimators=100, n_jobs=-1)
        else:
            self._classifier = self.Classifier.value

        if self.EvaluationFunction.connected():
            self._evaluation_fct = self.EvaluationFunction.value
        else:
            penalty_slot = self.ComplexityPenalty
            # 0.07 is the default penalty.
            complexity_penalty = penalty_slot.value if penalty_slot.connected() else 0.07
            selection_module = ilastik_feature_selection.wrapper_feature_selection
            self._evaluator = selection_module.EvaluationFunction(
                self._classifier, complexity_penalty=complexity_penalty)
            self._evaluation_fct = self._evaluator.evaluate_feature_set_size_penalty

        # The output slot should maybe contain the internal feature IDs
        # or a bool list of len(internal_feature_ids).
        output_meta = self.SelectedFeatureIDs.meta
        output_meta.shape = (1, )
        output_meta.dtype = list

    def execute(self, slot, subindex, roi, result):
        """
        Run the feature selection and return ``[selected_features]``.
        """
        feature_label_matrix = self.FeatureLabelMatrix[0].value

        # Column 0 holds the labels; every other column is feature data.
        labels = feature_label_matrix[:, 0].astype("int")
        data = feature_label_matrix[:, 1:]

        selection_module = ilastik_feature_selection.wrapper_feature_selection
        feature_selector = selection_module.WrapperFeatureSelection(
            data, labels, self._evaluation_fct, self._wrapper_method)

        # run() returns a sequence; its first element is the selection.
        selected_features = feature_selector.run(overshoot=3)[0]

        return [selected_features]

    def propagateDirty(self, slot, subindex, roi):
        """
        Any input change invalidates the selected feature IDs.
        """
        self.SelectedFeatureIDs.setDirty()