Example #1
0
class OpDataSelectionGroup(Operator):
    """Expose the datasets of all roles through a single operator.

    An internal ``OperatorWrapper`` around ``OpDataSelection`` is (re)built in
    ``setupOutputs`` whenever the value of ``DatasetRoles`` differs from the
    roles seen previously.  All output slots are plain connections to that
    wrapper, so ``execute`` is never expected to be called.
    """

    # ---- Inputs ----
    ProjectFile = InputSlot(stype='object', optional=True)
    ProjectDataGroup = InputSlot(stype='string', optional=True)
    WorkingDirectory = InputSlot(stype='filestring')
    DatasetRoles = InputSlot(stype='object')

    # Has to be optional, because not all of its subslots are required.
    DatasetGroup = InputSlot(stype='object', level=1, optional=True)

    # ---- Outputs ----
    ImageGroup = OutputSlot(level=1)

    # Convenience outputs: building a lane-wise level-1 multislot for just one
    # role is awkward otherwise (possible, but it needs OpTransposeSlots to
    # invert the level-2 multislot indexes).
    Image = OutputSlot()  # First dataset, i.e. ImageGroup[0]
    Image1 = OutputSlot()  # Second dataset, i.e. ImageGroup[1]
    Image2 = OutputSlot()  # Third dataset, i.e. ImageGroup[2]
    AllowLabels = OutputSlot(stype='bool')  # Taken from the first dataset only.

    _NonTransposedImageGroup = OutputSlot(level=1)

    # Keep this the LAST slot declared in the class: once the shell sees this
    # slot being resized it assumes every other slot was resized already.
    ImageName = OutputSlot()  # Name of the first dataset; other names are ignored.

    def __init__(self, forceAxisOrder=None, *args, **kwargs):
        """Create the operator.

        :param forceAxisOrder: handed to every inner ``OpDataSelection`` as its
            ``forceAxisOrder`` keyword argument.
        Remaining arguments go to the ``Operator`` constructor.
        """
        super(OpDataSelectionGroup, self).__init__(*args, **kwargs)
        self._opDatasets = None
        self._roles = []
        self._forceAxisOrder = forceAxisOrder

        def resize_dataset_group(*_):
            # One DatasetGroup subslot per role.
            role_count = len(self.DatasetRoles.value)
            self.DatasetGroup.resize(role_count)

        self.DatasetRoles.notifyReady(resize_dataset_group)

    def setupOutputs(self):
        """Rebuild the inner wrapper if the roles changed, then wire the outputs."""
        roles = self.DatasetRoles.value
        if roles != self._roles:
            self._roles = roles

            # Detach the forwarded outputs before discarding the old wrapper.
            for forwarded in (self.ImageGroup, self.Image, self.Image1,
                              self.Image2, self._NonTransposedImageGroup):
                forwarded.disconnect()
            if self._opDatasets is not None:
                self._opDatasets.cleanUp()

            broadcast_names = [
                'ProjectFile', 'ProjectDataGroup', 'WorkingDirectory'
            ]
            self._opDatasets = OperatorWrapper(
                OpDataSelection,
                parent=self,
                operator_kwargs={'forceAxisOrder': self._forceAxisOrder},
                broadcastingSlotNames=broadcast_names)

            wrapper = self._opDatasets
            self.ImageGroup.connect(wrapper.Image)
            self._NonTransposedImageGroup.connect(wrapper._NonTransposedImage)
            wrapper.Dataset.connect(self.DatasetGroup)
            wrapper.ProjectFile.connect(self.ProjectFile)
            wrapper.ProjectDataGroup.connect(self.ProjectDataGroup)
            wrapper.WorkingDirectory.connect(self.WorkingDirectory)

        images = self._opDatasets.Image
        image_count = len(images)

        if image_count == 0:
            # Nothing loaded: detach all single-dataset outputs and flag them
            # as not ready.
            single_outputs = (self.Image, self.Image1, self.Image2,
                              self.ImageName, self.AllowLabels)
            for output in single_outputs:
                output.disconnect()
            for output in single_outputs:
                output.meta.NOTREADY = True
            return

        self.Image.connect(images[0])

        # Image1 / Image2 only exist if enough datasets are present.
        for output, position in ((self.Image1, 1), (self.Image2, 2)):
            if image_count > position:
                output.connect(images[position])
            else:
                output.disconnect()
                output.meta.NOTREADY = True

        self.ImageName.connect(self._opDatasets.ImageName[0])
        self.AllowLabels.connect(self._opDatasets.AllowLabels[0])

    def execute(self, slot, subindex, rroi, result):
        """Always fails: every output is connected to an inner operator."""
        assert False, "Unknown or unconnected output slot: {}".format(slot.name)

    def propagateDirty(self, slot, subindex, roi):
        """Nothing to do: outputs are directly connected to internal operators."""
        pass
Example #2
0
class OpConservationTracking(OpTrackingBase):
    """Tracking operator built on pgmlink's ``ConsTracking``.

    ``track()`` runs the tracker and publishes the events; ``execute()``
    produces the relabeled/merger images on request.  The ``Merger*`` and
    ``Relabeled*`` slot groups below are each wired to an ``OpCompressedCache``
    in ``__init__``.
    """
    # NOTE(review): DivisionProbabilities is not read anywhere in this class;
    # presumably it is consumed by the base class - confirm.
    DivisionProbabilities = InputSlot(optional=True, stype=Opaque, rtype=List)
    # Requested for frame 0 in track() when withClassifierPrior is set.
    DetectionProbabilities = InputSlot(stype=Opaque, rtype=List)
    # Compared against maxObj + 1 in track(); also drives the GUI's
    # max-objects box in propagateDirty().
    NumLabels = InputSlot()

    # compressed cache for merger output
    MergerInputHdf5 = InputSlot(optional=True)
    MergerCleanBlocks = OutputSlot()
    MergerOutputHdf5 = OutputSlot()
    MergerCachedOutput = OutputSlot()  # For the GUI (blockwise access)
    MergerOutput = OutputSlot()

    # Set by track() to a pgmlink.TimestepIdCoordinateMap when merger
    # resolution is enabled; read by _relabelMergers().
    CoordinateMap = OutputSlot()

    # compressed cache for the relabeled image (same layout as the merger cache)
    RelabeledInputHdf5 = InputSlot(optional=True)
    RelabeledCleanBlocks = OutputSlot()
    RelabeledOutputHdf5 = OutputSlot()
    RelabeledCachedOutput = OutputSlot()  # For the GUI (blockwise access)
    RelabeledImage = OutputSlot()

    def __init__(self, parent=None, graph=None):
        """Create the operator and wire up the two compressed caches.

        One ``OpCompressedCache`` serves the merger output, the other the
        relabeled image; both are children of this operator.
        """
        super(OpConservationTracking, self).__init__(parent=parent,
                                                     graph=graph)

        def wire_cache(cache, hdf5_in, source, clean_blocks, hdf5_out,
                       cached_out):
            # Inputs of the cache first, then the outputs it feeds.
            cache.InputHdf5.connect(hdf5_in)
            cache.Input.connect(source)
            clean_blocks.connect(cache.CleanBlocks)
            hdf5_out.connect(cache.OutputHdf5)
            cached_out.connect(cache.Output)

        self._mergerOpCache = OpCompressedCache(parent=self)
        wire_cache(self._mergerOpCache,
                   self.MergerInputHdf5,
                   self.MergerOutput,
                   self.MergerCleanBlocks,
                   self.MergerOutputHdf5,
                   self.MergerCachedOutput)

        self._relabeledOpCache = OpCompressedCache(parent=self)
        wire_cache(self._relabeledOpCache,
                   self.RelabeledInputHdf5,
                   self.RelabeledImage,
                   self.RelabeledCleanBlocks,
                   self.RelabeledOutputHdf5,
                   self.RelabeledCachedOutput)

        # No pgmlink tracker yet; track() creates it.
        self.tracker = None
        # Default until setupOutputs() inspects the label image shape.
        self._ndim = 3

    def setupOutputs(self):
        """Derive the metadata of the merger/relabeled outputs from LabelImage."""
        super(OpConservationTracking, self).setupOutputs()

        label_meta = self.LabelImage.meta
        derived_outputs = (self.MergerOutput, self.RelabeledImage)

        for output in derived_outputs:
            output.meta.assignFrom(label_meta)

        # A singleton axis 3 means the data is treated as 2d.
        self._ndim = 3 if label_meta.shape[3] != 1 else 2

        for cache in (self._mergerOpCache, self._relabeledOpCache):
            cache.BlockShape.setValue(self._blockshape)

        # One block per time frame; assumes t,x,y,z,c order
        frame_shape = (1, ) + label_meta.shape[1:]
        assert frame_shape[-1] == 1
        for output in derived_outputs:
            output.meta.ideal_blockshape = frame_shape

    def execute(self, slot, subindex, roi, result):
        """Fill ``result`` for ``Output``, ``MergerOutput`` or ``RelabeledImage``.

        Any other slot is delegated to the base class.  For the three slots
        handled here the label image is loaded into ``result`` and then
        processed frame by frame (axis 0 is time, the last axis is channel).
        Returns ``result``.
        """
        base = super(OpConservationTracking, self)

        handled_here = (self.Output, self.MergerOutput, self.RelabeledImage)
        if not any(slot is candidate for candidate in handled_here):
            # default behaviour
            base.execute(slot, subindex, roi, result)
            return result

        def in_time_range(params, t):
            # True if a time range was stored and frame t lies inside it.
            if 'time_range' not in params:
                return False
            bounds = params['time_range']
            return bounds[0] <= t <= bounds[-1]

        def merger_resolution_enabled(params):
            return ('withMergerResolution' in params
                    and params['withMergerResolution'])

        parameters = self.Parameters.value
        first_t = roi.start[0]
        frames = range(first_t, roi.stop[0])
        # offset only in pixels, not time and channel
        pixel_offsets = roi.start[1:-1]

        if slot is self.Output:
            # The base class result is kept aside and overlaid with the
            # relabeled mergers at the end.
            original = np.zeros(result.shape, dtype=slot.meta.dtype)
            base.execute(slot, subindex, roi, original)
            result[:] = self.LabelImage.get(roi).wait()

            for t in frames:
                frame = t - first_t
                if (in_time_range(parameters, t)
                        and t < len(self.resolvedto)
                        and len(self.resolvedto[t]) > 0):
                    result[frame, ..., 0] = self._relabelMergers(
                        result[frame, ..., 0], t, pixel_offsets)
                else:
                    result[frame, ...] = 0

            relabeled = result != 0
            original[relabeled] = result[relabeled]
            result[:] = original

        elif slot is self.MergerOutput:
            result[:] = self.LabelImage.get(roi).wait()

            for t in frames:
                frame = t - first_t
                if (in_time_range(parameters, t)
                        and t < len(self.mergers)
                        and len(self.mergers[t]) > 0):
                    if merger_resolution_enabled(parameters):
                        result[frame, ..., 0] = self._relabelMergers(
                            result[frame, ..., 0], t, pixel_offsets, True)
                    else:
                        result[frame, ..., 0] = highlightMergers(
                            result[frame, ..., 0], self.mergers[t])
                else:
                    result[frame, ...] = 0

        else:
            # slot is self.RelabeledImage: frames without resolved mergers
            # keep the plain label image.
            result[:] = self.LabelImage.get(roi).wait()

            for t in frames:
                frame = t - first_t
                if (in_time_range(parameters, t)
                        and t < len(self.resolvedto)
                        and len(self.resolvedto[t]) > 0
                        and merger_resolution_enabled(parameters)):
                    result[frame, ..., 0] = self._relabelMergers(
                        result[frame, ..., 0], t, pixel_offsets, False, True)

        return result

    def setInSlot(self, slot, subindex, roi, value):
        """Accept writes only on the three ``*InputHdf5`` slots; does nothing else."""
        writable_slots = (self.InputHdf5, self.MergerInputHdf5,
                          self.RelabeledInputHdf5)
        assert any(slot == candidate for candidate in writable_slots), \
            "Invalid slot for setInSlot(): {}".format(slot.name)

    def track(self,
              time_range,
              x_range,
              y_range,
              z_range,
              size_range=(0, 100000),
              x_scale=1.0,
              y_scale=1.0,
              z_scale=1.0,
              maxDist=30,
              maxObj=2,
              divThreshold=0.5,
              avgSize=[0],
              withTracklets=False,
              sizeDependent=True,
              divWeight=10.0,
              transWeight=10.0,
              withDivisions=True,
              withOpticalCorrection=True,
              withClassifierPrior=False,
              ndim=3,
              cplex_timeout=None,
              withMergerResolution=True,
              borderAwareWidth=0.0,
              withArmaCoordinates=True,
              appearance_cost=500,
              disappearance_cost=500,
              motionModelWeight=10.0,
              force_build_hypotheses_graph=False,
              max_nearest_neighbors=1,
              withBatchProcessing=False,
              solverName="ILP"):
        """Run pgmlink conservation tracking and publish the resulting events.

        Steps: store the settings in the ``Parameters`` dictionary, validate
        the classifier inputs (if ``withClassifierPrior``), build the
        traxelstore via ``_generate_traxelstore``, build a new
        ``pgmlink.ConsTracking`` graph, track, optionally resolve mergers, then
        set ``Parameters`` and ``EventsVector`` and mark ``RelabeledImage``
        dirty.

        :param time_range, x_range, y_range, z_range: index ranges; first and
            last element (time) respectively elements 0 and 1 (space) define
            the field of view.
        :param size_range, x_scale, y_scale, z_scale: forwarded to
            ``_generate_traxelstore``; the scales also scale the field of view.
        :param avgSize: if ``avgSize[0] > 0`` it replaces the median object
            size reported by ``_generate_traxelstore``.
        :param cplex_timeout: a falsy value is stored as ``''`` and a timeout
            of ``1e75`` is handed to the solver instead.
        :param ndim: 2 or 3; for 2 the z field of view must be ``(0, 0)``.
        :param max_nearest_neighbors: forwarded to ``tracker.buildGraph``.
        :param withBatchProcessing: if true, the GUI layer update is skipped.
        :param solverName: one of the names accepted by
            ``getPgmlinkSolverType``.
        :param motionModelWeight: currently unused (motion model registration
            is disabled).
        :param force_build_hypotheses_graph: currently ignored; the hypotheses
            graph is rebuilt on every call.
        The remaining parameters are forwarded to pgmlink and/or stored in the
        ``Parameters`` dictionary.

        :raises DatasetConstraintError: classifier not ready, number of labels
            inconsistent with ``maxObj``, or a frame without objects.
        :raises Exception: ``Parameters`` slot not ready, pgmlink failed, or
            the events vector is empty.
        """
        if not self.Parameters.ready():
            raise Exception("Parameter slot is not ready")

        # The dictionary held by the Parameters slot is updated in place here
        # and written back with setValue() once tracking has succeeded.
        parameters = self.Parameters.value

        parameters['maxDist'] = maxDist
        parameters['maxObj'] = maxObj
        parameters['divThreshold'] = divThreshold
        parameters['avgSize'] = avgSize
        parameters['withTracklets'] = withTracklets
        parameters['sizeDependent'] = sizeDependent
        parameters['divWeight'] = divWeight
        parameters['transWeight'] = transWeight
        parameters['withDivisions'] = withDivisions
        parameters['withOpticalCorrection'] = withOpticalCorrection
        parameters['withClassifierPrior'] = withClassifierPrior
        parameters['withMergerResolution'] = withMergerResolution
        parameters['borderAwareWidth'] = borderAwareWidth
        parameters['withArmaCoordinates'] = withArmaCoordinates
        parameters['appearanceCost'] = appearance_cost
        parameters['disappearanceCost'] = disappearance_cost

        if cplex_timeout:
            parameters['cplex_timeout'] = cplex_timeout
        else:
            parameters['cplex_timeout'] = ''
            cplex_timeout = float(1e75)

        if withClassifierPrior:
            label_mismatch_msg = (
                'The max. number of objects must be consistent with the number of labels given in Object Count Classification.\n'
                'Check whether you have (i) the correct number of label names specified in Object Count Classification, and (ii) provided at least '
                'one training example for each class.')

            # Request the probabilities of frame 0 only once; they are needed
            # for two of the checks below.
            classifier_ready = self.DetectionProbabilities.ready()
            first_frame_probs = None
            if classifier_ready:
                first_frame_probs = self.DetectionProbabilities([0]).wait()[0]
                classifier_ready = len(first_frame_probs) != 0
            if not classifier_ready:
                raise DatasetConstraintError(
                    'Tracking',
                    'Classifier not ready yet. Did you forget to train the Object Count Classifier?'
                )
            if not self.NumLabels.ready() or self.NumLabels.value < (maxObj +
                                                                     1):
                raise DatasetConstraintError('Tracking', label_mismatch_msg)
            if len(first_frame_probs[0]) < (maxObj + 1):
                raise DatasetConstraintError('Tracking', label_mismatch_msg)

        # Passed by keyword as median_object_size;
        # NOTE(review): presumably filled in place by _generate_traxelstore - confirm.
        median_obj_size = [0]

        _fs, ts, empty_frame, max_traxel_id_at = self._generate_traxelstore(
            time_range,
            x_range,
            y_range,
            z_range,
            size_range,
            x_scale,
            y_scale,
            z_scale,
            median_object_size=median_obj_size,
            with_div=withDivisions,
            with_opt_correction=withOpticalCorrection,
            with_classifier_prior=withClassifierPrior)

        if empty_frame:
            raise DatasetConstraintError(
                'Tracking', 'Can not track frames with 0 objects, abort.')

        if avgSize[0] > 0:
            median_obj_size = avgSize

        logger.info('median_obj_size = {}'.format(median_obj_size))

        ep_gap = 0.05
        transition_parameter = 5

        # Lower corner (t, x, y, z) followed by upper corner (t, x, y, z).
        fov_bounds = (
            time_range[0] * 1.0,
            x_range[0] * x_scale,
            y_range[0] * y_scale,
            z_range[0] * z_scale,
            time_range[-1] * 1.0,
            (x_range[1] - 1) * x_scale,
            (y_range[1] - 1) * y_scale,
            (z_range[1] - 1) * z_scale,
        )
        fov = pgmlink.FieldOfView(*fov_bounds)
        logger.info('fov = {},{},{},{},{},{},{},{}'.format(*fov_bounds))

        if ndim == 2:
            assert z_range[0] * z_scale == 0 and (
                z_range[1] -
                1) * z_scale == 0, "fov of z must be (0,0) if ndim==2"

        solverType = self.getPgmlinkSolverType(solverName)

        # The hypotheses graph is rebuilt unconditionally on every call.
        logger.info("make new graph")
        self.tracker = pgmlink.ConsTracking(
            int(maxObj),
            bool(sizeDependent),  # size_dependent_detection_prob
            float(median_obj_size[0]),  # median_object_size
            float(maxDist),
            bool(withDivisions),
            float(divThreshold),
            "none",  # detection_rf_filename
            fov,
            "none",  # dump traxelstore,
            solverType,
            ndim)
        self.tracker.buildGraph(ts, max_nearest_neighbors)

        # create dummy uncertainty parameter object with just one iteration, so no perturbations at all (iter=0 -> MAP)
        sigmas = pgmlink.VectorOfDouble()
        for _ in range(5):
            sigmas.append(0.0)
        uncertaintyParams = pgmlink.UncertaintyParameter(
            1, pgmlink.DistrId.PerturbAndMAP, sigmas)

        params = self.tracker.get_conservation_tracking_parameters(
            0,  # forbidden_cost
            float(ep_gap),  # ep_gap
            bool(withTracklets),  # with tracklets
            float(10.0),  # detection weight
            float(divWeight),  # division weight
            float(transWeight),  # transition weight
            float(disappearance_cost),  # disappearance cost
            float(appearance_cost),  # appearance cost
            bool(withMergerResolution),  # with merger resolution
            int(ndim),  # ndim
            float(transition_parameter),  # transition param
            float(borderAwareWidth),  # border width
            True,  # with_constraints
            uncertaintyParams,  # uncertainty parameters
            float(cplex_timeout),  # cplex timeout
            None,  # transition classifier
            solverType,
            False,  # training to hard constraints
            1  # num threads
        )

        # NOTE: motion model registration is disabled, so motionModelWeight
        # has no effect at the moment.

        try:
            eventsVector = self.tracker.track(params, False)

            # we have a vector such that we could get a vector per perturbation
            eventsVector = eventsVector[0]

            # extract the coordinates with the given event vector
            if withMergerResolution:
                coordinate_map = pgmlink.TimestepIdCoordinateMap()

                self._get_merger_coordinates(coordinate_map, time_range,
                                             eventsVector)
                self.CoordinateMap.setValue(coordinate_map)

                eventsVector = self.tracker.resolve_mergers(
                    eventsVector,
                    params,
                    coordinate_map.get(),
                    float(ep_gap),
                    float(transWeight),
                    bool(withTracklets),
                    ndim,
                    transition_parameter,
                    max_traxel_id_at,
                    True,  # with_constraints
                    None)  # TransitionClassifier

        except Exception as e:
            raise Exception('Tracking terminated unsuccessfully: ' + str(e))

        if len(eventsVector) == 0:
            raise Exception(
                'Tracking terminated unsuccessfully: Events vector has zero length.'
            )

        events = get_events(eventsVector)
        self.Parameters.setValue(parameters, check_changed=False)
        self.EventsVector.setValue(events, check_changed=False)
        self.RelabeledImage.setDirty()

        if not withBatchProcessing:
            # Pick the colour table of the "Merger" layer depending on whether
            # mergers were resolved.
            gui = self.parent.parent.trackingApplet._gui.currentGui()
            merger_layer_idx = gui.layerstack.findMatchingIndex(
                lambda x: x.name == "Merger")
            if ('withMergerResolution' in parameters
                    and not parameters['withMergerResolution']):
                colortable = gui.merger_colortable
            else:
                colortable = gui.tracking_colortable
            gui.layerstack[merger_layer_idx].colorTable = colortable

    @staticmethod
    def getPgmlinkSolverType(solverName):
        """Map a solver name onto the matching ``pgmlink.ConsTrackingSolverType``.

        :param solverName: ``"ILP"``, ``"Magnusson"`` or ``"Flow"``.
        :raises AssertionError: pgmlink does not provide the requested solver.
        :raises ValueError: the name is none of the above.
        """
        # (name, attribute of ConsTrackingSolverType, message if it is missing)
        known_solvers = (
            ("ILP", "CplexSolver",
             "Cannot select ILP solver if pgmlink was compiled without ILP support"
             ),
            ("Magnusson", "DynProgSolver",
             "Cannot select Magnusson solver if pgmlink was compiled without Magnusson support"
             ),
            ("Flow", "FlowSolver",
             "Cannot select Flow solver if pgmlink was compiled without Flow support"
             ),
        )

        for name, attribute, unavailable_message in known_solvers:
            if solverName == name:
                solver_types = pgmlink.ConsTrackingSolverType
                if not hasattr(solver_types, attribute):
                    raise AssertionError(unavailable_message)
                return getattr(solver_types, attribute)

        raise ValueError("Unknown solver {} selected".format(solverName))

    def propagateDirty(self, inputSlot, subindex, roi):
        """Propagate dirtiness via the base class and sync the GUI with NumLabels.

        When ``NumLabels`` changes, is ready and exceeds 1, the max-objects box
        of the current tracking GUI (if there is one) is set to
        ``NumLabels - 1``.
        """
        super(OpConservationTracking,
              self).propagateDirty(inputSlot, subindex, roi)

        if not (inputSlot == self.NumLabels):
            return

        applet_gui = self.parent.parent.trackingApplet._gui
        if not applet_gui:
            return

        current_gui = applet_gui.currentGui()
        if not current_gui:
            return

        if self.NumLabels.ready() and self.NumLabels.value > 1:
            max_objects_box = current_gui._drawer.maxObjectsBox
            max_objects_box.setValue(self.NumLabels.value - 1)

    def _get_merger_coordinates(self, coordinate_map, time_range,
                                eventsVector):
        """Record the pixel coordinates of every merger object in ``coordinate_map``.

        For each merger event the bounding box of the object is cut out of
        ``LabelImage`` and handed to ``pgmlink.extract_coord_by_timestep_id``.

        :param coordinate_map: pgmlink coordinate map that is filled in place.
        :param time_range: frames for which ``ObjectFeatures`` is requested.
        :param eventsVector: indexable by time step, yielding pgmlink events.
        :raises ValueError: an object's region center is neither 2d nor 3d.
        """
        # fetch features
        feats = self.ObjectFeatures(time_range).wait()

        # iterate over all timesteps
        for t in feats:
            frame_feats = feats[t][default_features_key]
            rc = frame_feats['RegionCenter']
            lower = frame_feats['Coord<Minimum>']
            upper = frame_feats['Coord<Maximum>']
            size = frame_feats['Count']

            for event in eventsVector[t]:
                # only merger events are of interest
                if event.type != pgmlink.EventType.Merger:
                    continue

                idx = event.traxel_ids[0]
                n_dim = len(rc[idx])
                if n_dim not in (2, 3):
                    raise ValueError("n_dim = %s instead of 2 or 3" % n_dim)

                # generate roi: assume the following order: txyzc
                # Axes that are not set below keep the plain index 0.
                roi = [0] * 5
                roi[0] = slice(int(t), int(t + 1))
                roi[1] = slice(int(lower[idx][0]), int(upper[idx][0] + 1))
                roi[2] = slice(int(lower[idx][1]), int(upper[idx][1] + 1))
                if n_dim == 3:
                    roi[3] = slice(int(lower[idx][2]), int(upper[idx][2] + 1))

                image_excerpt = self.LabelImage[roi].wait()
                # Drop time and channel (and, for 2d data, z) from the excerpt.
                if n_dim == 2:
                    image_excerpt = image_excerpt[0, ..., 0, 0]
                else:
                    image_excerpt = image_excerpt[0, ..., 0]

                pgmlink.extract_coord_by_timestep_id(
                    coordinate_map, image_excerpt,
                    lower[idx].astype(np.int64), t, idx, int(size[idx, 0]))

    def _relabelMergers(self,
                        volume,
                        time,
                        pixel_offsets=(0, 0, 0),
                        onlyMergers=False,
                        noRelabeling=False):
        """Write the resolved merger ids of frame ``time`` into ``volume``.

        :param volume: label array of one frame; modified in place.
        :param time: frame index into ``self.resolvedto`` / ``self.label2color``.
        :param pixel_offsets: spatial offset of ``volume`` inside the full
            image; only read, converted to an int64 array for pgmlink.
        :param onlyMergers: if true, every pixel that does not carry one of the
            resolved ids is set to zero.
        :param noRelabeling: if true, return the ids as they are; otherwise
            they are mapped through ``relabel`` with ``self.label2color[time]``.
        :return: ``volume`` (possibly relabeled).  If the coordinate map is
            empty or ``time`` is beyond ``self.resolvedto``, the untouched
            ``volume`` is returned, or zeros if ``onlyMergers`` is set.
        """
        def unresolved():
            # Nothing can be relabeled for this frame.
            return np.zeros_like(volume) if onlyMergers else volume

        coordinate_map = self.CoordinateMap.value
        if coordinate_map.size() == 0:
            logger.info(
                "Skipping merger relabeling because coordinate map is empty")
            return unresolved()
        if time >= len(self.resolvedto):
            return unresolved()

        # TODO Reliable distinction between 2d and 3d?
        # For 2d data z is bound to zero, for 3d data the whole volume is used.
        relabel_volume = volume[..., 0] if self._ndim == 2 else volume
        offsets = np.array(pixel_offsets, dtype=np.int64)

        valid_ids = []
        for new_ids in self.resolvedto[time].values():
            for new_id in new_ids:
                pgmlink.update_labelimage(coordinate_map, relabel_volume,
                                          offsets, int(time), int(new_id))
                valid_ids.append(new_id)

        if onlyMergers:
            # keep the pixels of the resolved ids, set everything else to zero
            is_merger = np.in1d(volume.ravel(),
                                valid_ids).reshape(volume.shape)
            # `~` instead of unary minus: negating a boolean array with `-`
            # is rejected by newer NumPy versions.
            volume[~is_merger] = 0

        if noRelabeling:
            return volume
        return relabel(volume, self.label2color[time])

    def do_export(self,
                  settings,
                  selected_features,
                  progress_slot,
                  lane_index,
                  filename_suffix=""):
        """
        Implements ExportOperator.do_export(settings, selected_features, progress_slot).
        Most likely called from ExportOperator.export_object_data.

        The feature slot and label image that get exported depend on the
        stored tracking parameters: with divisions the features including
        division features are used, and with merger resolution the relabeled
        image together with freshly set up features for the relabeled objects.

        :param settings: the settings for the exporter (passed through to _do_export_impl)
        :param selected_features: passed through to _do_export_impl
        :param progress_slot: passed through to _do_export_impl
        :param lane_index: Passed through to _do_export_impl. (This is a single-lane operator. It is the caller's responsibility to make sure he's calling the right lane.)
        :param filename_suffix: If provided, appended to the filename (before the extension).
        :return: None
        """
        if self.Parameters.ready():
            parameters = self.Parameters.value
            with_divisions = parameters["withDivisions"]
            with_merger_resolution = parameters["withMergerResolution"]
        else:
            with_divisions = False
            with_merger_resolution = False

        if with_divisions:
            object_feature_slot = self.ObjectFeaturesWithDivFeatures
        else:
            object_feature_slot = self.ObjectFeatures

        opRelabeledRegionFeatures = None
        if with_merger_resolution:
            label_image = self.RelabeledImage

            opRelabeledRegionFeatures = self._setupRelabeledFeatureSlot(
                object_feature_slot)
            object_feature_slot = opRelabeledRegionFeatures.RegionFeatures
        else:
            label_image = self.LabelImage

        try:
            self._do_export_impl(settings, selected_features, progress_slot,
                                 object_feature_slot, label_image, lane_index,
                                 filename_suffix)
        finally:
            # The temporary operator must go away even if the export failed.
            if opRelabeledRegionFeatures is not None:
                opRelabeledRegionFeatures.cleanUp()

    def _setupRelabeledFeatureSlot(self, original_feature_slot):
        """Create an operator that computes region features for relabeled objects.

        Needed when exporting after merger resolving, because the stored object
        features are not up to date for the relabeled objects.

        :param original_feature_slot: slot connected to ``OriginalRegionFeatures``.
        :return: the new ``OpRelabeledMergerFeatureExtraction``; the caller is
            responsible for calling ``cleanUp()`` on it.
        """
        from ilastik.applets.trackingFeatureExtraction import config

        extractor = OpRelabeledMergerFeatureExtraction(parent=self)

        connections = (
            (extractor.RawImage, self.RawImage),
            (extractor.LabelImage, self.LabelImage),
            (extractor.RelabeledImage, self.RelabeledImage),
            (extractor.OriginalRegionFeatures, original_feature_slot),
        )
        for target, source in connections:
            target.connect(source)
        extractor.ResolvedTo.setValue(self.resolvedto)

        # Union of the general vigra features and those selected for object count.
        vigra_key = config.features_vigra_name
        objectcount_features = config.selected_features_objectcount[vigra_key]
        vigra_features = list(
            set(config.vigra_features).union(objectcount_features))

        feature_names_vigra = {
            vigra_key: dict((name, {}) for name in vigra_features)
        }
        extractor.FeatureNames.setValue(feature_names_vigra)

        return extractor
class OpSlicedBlockedArrayCache(OpCache):
    """
    Cache that owns one OpBlockedArrayCache per configured inner block shape.

    Every inner cache is connected to the same ``Input``.  A request on
    ``Output`` is answered by the inner cache whose inner block shape is
    closest (smallest squared Euclidean distance) to the shape of the
    requested roi.  The individual inner caches are also exposed through the
    ``InnerOutputs`` multi-slot.
    """
    name = "OpSlicedBlockedArrayCache"
    description = ""

    #Inputs
    Input = InputSlot()
    innerBlockShape = InputSlot()  # Sequence of block shapes; one inner cache is created per entry.
    outerBlockShape = InputSlot()  # Sequence of block shapes, indexed in parallel with innerBlockShape.
    fixAtCurrent = InputSlot(value=False)

    #Outputs
    Output = OutputSlot()
    InnerOutputs = OutputSlot(level=1)

    loggerName = __name__ + ".OpSlicedBlockedArrayCache"
    logger = logging.getLogger(loggerName)
    traceLogger = logging.getLogger("TRACE." + loggerName)

    def __init__(self, *args, **kwargs):
        super(OpSlicedBlockedArrayCache, self).__init__(*args, **kwargs)
        # Inner OpBlockedArrayCache instances, (re)created in setupOutputs().
        self._innerOps = []

    def generateReport(self, report):
        """
        Fill ``report`` with the memory statistics of this cache and append
        one child MemInfoNode per inner cache.
        """
        report.name = self.name
        report.fractionOfUsedMemoryDirty = self.fractionOfUsedMemoryDirty()
        report.usedMemory = self.usedMemory()
        report.lastAccessTime = self.lastAccessTime()
        report.dtype = self.Output.meta.dtype
        report.type = type(self)
        report.id = id(self)
        sh = self.Output.meta.shape
        if sh is not None:
            report.roi = ([0] * len(sh), sh)

        for iOp in self._innerOps:
            n = MemInfoNode()
            report.children.append(n)
            iOp.generateReport(n)

    def usedMemory(self):
        """Return the summed ``usedMemory()`` of all inner caches (as a float)."""
        tot = 0.0
        for iOp in self._innerOps:
            tot += iOp.usedMemory()
        return tot

    def setupOutputs(self):
        """
        (Re)configure the inner caches and the output metadata.

        If any configured block shape does not have the same number of
        dimensions as the input, the output is marked NOTREADY and nothing
        else is configured.
        """
        self.shape = self.inputs["Input"].meta.shape
        self._outerShapes = self.inputs["outerBlockShape"].value
        self._innerShapes = self.inputs["innerBlockShape"].value

        for blockshape in self._innerShapes + self._outerShapes:
            if len(blockshape) != len(self.Input.meta.shape):
                self.Output.meta.NOTREADY = True
                return

        # FIXME: This is wrong: Shouldn't it actually compare the new inner block shape with the old one?
        if len(self._innerShapes) != len(self._innerOps):
            # Clean up previous inner operators
            for slot in self.InnerOutputs:
                slot.disconnect()
            for o in self._innerOps:
                o.cleanUp()

            self._innerOps = []

            for _ in self._innerShapes:
                op = OpBlockedArrayCache(parent=self)
                op.inputs["fixAtCurrent"].connect(self.inputs["fixAtCurrent"])
                self._innerOps.append(op)

                op.inputs["Input"].connect(self.inputs["Input"])

                # Forward "value changed" notifications to our own output
                op.Output.notifyValueChanged(self.Output._sig_value_changed)

        for i, innershape in enumerate(self._innerShapes):
            op = self._innerOps[i]
            op.inputs["innerBlockShape"].setValue(innershape)
            op.inputs["outerBlockShape"].setValue(self._outerShapes[i])

        self.Output.meta.assignFrom(self.Input.meta)

        # Estimate ram usage
        ram_per_pixel = 0
        if self.Output.meta.dtype == object or self.Output.meta.dtype == numpy.object_:
            ram_per_pixel = sys.getsizeof(None)
        elif numpy.issubdtype(self.Output.meta.dtype, numpy.dtype):
            ram_per_pixel = self.Output.meta.dtype().nbytes

        tagged_shape = self.Output.meta.getTaggedShape()
        if 'c' in tagged_shape:
            ram_per_pixel *= float(tagged_shape['c'])

        # Never report less than what upstream already declared.
        if self.Output.meta.ram_usage_per_requested_pixel is not None:
            ram_per_pixel = max(ram_per_pixel,
                                self.Output.meta.ram_usage_per_requested_pixel)

        self.Output.meta.ram_usage_per_requested_pixel = ram_per_pixel

        # We also provide direct access to each of our inner cache outputs.
        self.InnerOutputs.resize(len(self._innerOps))
        for i, slot in enumerate(self.InnerOutputs):
            slot.connect(self._innerOps[i].Output)

    def execute(self, slot, subindex, roi, result):
        """
        Serve a request on ``Output`` from the best-matching inner cache.

        The inner cache is chosen by comparing the requested roi shape with
        every configured inner block shape; the one with the smallest squared
        Euclidean distance wins (the first one on ties).  The data is written
        into ``result``.
        """
        t = time.time()
        assert slot == self.Output

        key = roi.toSlice()
        start, stop = sliceToRoi(key, self.shape)
        roishape = numpy.array(stop) - numpy.array(start)

        # sys.maxint does not exist on Python 3; sys.maxsize is available on
        # both Python 2.6+ and Python 3.
        max_dist_squared = sys.maxsize
        index = 0

        for i, blockshape in enumerate(self._innerShapes):
            blockshape = numpy.array(blockshape)

            diff = roishape - blockshape
            diffsquared = diff * diff
            distance_squared = numpy.sum(diffsquared)
            if distance_squared < max_dist_squared:
                index = i
                max_dist_squared = distance_squared

        op = self._innerOps[index]
        op.outputs["Output"][key].writeInto(result).wait()
        self.logger.debug("read %r took %f msec.", roi.pprint(),
                          1000.0 * (time.time() - t))

    def propagateDirty(self, slot, subindex, roi):
        """
        Propagate dirtiness to ``Output`` unless the cache is fixed.

        While ``fixAtCurrent`` is true nothing is propagated.
        """
        key = roi.toSlice()
        # We *could* simply forward dirty notifications from our inner operators
        # to our output (by subscribing to their notifyDirty signals),
        # but that would result in duplicates of many (not all!) dirty notifications
        # (since we have more than one inner cache, and each is receiving dirty notifications)
        # Instead, we forward dirty input regions ourselves and mark *everything* dirty when we become unfixed.
        fixed = self.fixAtCurrent.value
        if not fixed:
            if slot == self.Input:
                self.Output.setDirty(key)
            elif slot == self.outerBlockShape or slot == self.innerBlockShape:
                #self.Output.setDirty( slice(None) )
                pass  # Blockshape changes don't trigger dirty notifications
                # It is considered an error to change the blockshape after the initial configuration.
            elif slot == self.fixAtCurrent:
                self.Output.setDirty(slice(None))
            else:
                assert False, "Unknown dirty input slot"
Example #4
0
class OpPixelClassification(Operator):
    """
    Top-level operator for pixel classification.

    Wires together a label pipeline (OpLabelPipeline, one per lane), a
    classifier training operator (OpTrainClassifierBlocked), a classifier
    cache (OpValueCache) and a prediction pipeline (OpPredictionPipeline, one
    per lane), and exposes their outputs as top-level slots.
    """
    name = "OpPixelClassification"
    category = "Top-level"

    # Graph inputs

    InputImages = InputSlot(
        level=1)  # Original input data.  Used for display only.
    PredictionMasks = InputSlot(
        level=1, optional=True
    )  # Routed to OpClassifierPredict.PredictionMask.  See there for details.

    LabelInputs = InputSlot(
        optional=True,
        level=1)  # Input for providing label data from an external source
    LabelsAllowedFlags = InputSlot(
        stype='bool',
        level=1)  # Specifies which images are permitted to be labeled

    FeatureImages = InputSlot(
        level=1
    )  # Computed feature images (each channel is a different feature)
    CachedFeatureImages = InputSlot(level=1)  # Cached feature data.

    FreezePredictions = InputSlot(stype='bool')
    ClassifierFactory = InputSlot(
        value=ParallelVigraRfLazyflowClassifierFactory(10, 10))

    PredictionsFromDisk = InputSlot(optional=True, level=1)

    PredictionProbabilities = OutputSlot(
        level=1
    )  # Classification predictions (via feature cache for interactive speed)

    PredictionProbabilityChannels = OutputSlot(
        level=2)  # Classification predictions, enumerated by channel
    SegmentationChannels = OutputSlot(
        level=2)  # Binary image of the final selections.

    LabelImages = OutputSlot(level=1)  # Labels from the user
    NonzeroLabelBlocks = OutputSlot(
        level=1)  # A list of slices that contain non-zero label values
    Classifier = OutputSlot(
    )  # We provide the classifier as an external output for other applets to use

    CachedPredictionProbabilities = OutputSlot(
        level=1
    )  # Classification predictions (via feature cache AND prediction cache)

    HeadlessPredictionProbabilities = OutputSlot(
        level=1
    )  # Classification predictions (via no image caches, except for the classifier itself)
    HeadlessUint8PredictionProbabilities = OutputSlot(
        level=1)  # Same as above, but 0-255 uint8 instead of 0.0-1.0 float32
    HeadlessUncertaintyEstimate = OutputSlot(
        level=1
    )  # Same as uncertainty estimate, but does not rely on cached data.

    UncertaintyEstimate = OutputSlot(level=1)

    SimpleSegmentation = OutputSlot(level=1)  # For debug, for now

    # GUI-only (not part of the pipeline, but saved to the project)
    LabelNames = OutputSlot()
    LabelColors = OutputSlot()
    PmapColors = OutputSlot()

    NumClasses = OutputSlot()

    def setupOutputs(self):
        """Declare the GUI-only outputs as single-element object slots."""
        self.LabelNames.meta.dtype = object
        self.LabelNames.meta.shape = (1, )
        self.LabelColors.meta.dtype = object
        self.LabelColors.meta.shape = (1, )
        self.PmapColors.meta.dtype = object
        self.PmapColors.meta.shape = (1, )

    def __init__(self, *args, **kwargs):
        """
        Instantiate all internal operators and connect them together.
        """
        super(OpPixelClassification, self).__init__(*args, **kwargs)

        # Default values for some input slots
        self.FreezePredictions.setValue(True)
        self.LabelNames.setValue([])
        self.LabelColors.setValue([])
        self.PmapColors.setValue([])

        # SPECIAL connection: The LabelInputs slot doesn't get its data
        #  from the InputImages slot, but its shape must match.
        self.LabelInputs.connect(self.InputImages)

        # Hook up Labeling Pipeline
        self.opLabelPipeline = OpMultiLaneWrapper(
            OpLabelPipeline,
            parent=self,
            broadcastingSlotNames=['DeleteLabel'])
        self.opLabelPipeline.RawImage.connect(self.InputImages)
        self.opLabelPipeline.LabelInput.connect(self.LabelInputs)
        self.opLabelPipeline.DeleteLabel.setValue(-1)
        self.LabelImages.connect(self.opLabelPipeline.Output)
        self.NonzeroLabelBlocks.connect(self.opLabelPipeline.nonzeroBlocks)

        # Hook up the Training operator
        self.opTrain = OpTrainClassifierBlocked(parent=self)
        self.opTrain.ClassifierFactory.connect(self.ClassifierFactory)
        self.opTrain.Labels.connect(self.opLabelPipeline.Output)
        self.opTrain.Images.connect(self.CachedFeatureImages)
        self.opTrain.nonzeroLabelBlocks.connect(
            self.opLabelPipeline.nonzeroBlocks)

        # Hook up the Classifier Cache
        # The classifier is cached here to allow serializers to force in
        #   a pre-calculated classifier (loaded from disk)
        self.classifier_cache = OpValueCache(parent=self)
        self.classifier_cache.name = "OpPixelClassification.classifier_cache"
        self.classifier_cache.inputs["Input"].connect(
            self.opTrain.outputs['Classifier'])
        self.classifier_cache.inputs["fixAtCurrent"].connect(
            self.FreezePredictions)
        self.Classifier.connect(self.classifier_cache.Output)

        # Hook up the prediction pipeline inputs
        self.opPredictionPipeline = OpMultiLaneWrapper(OpPredictionPipeline,
                                                       parent=self)
        self.opPredictionPipeline.FeatureImages.connect(self.FeatureImages)
        self.opPredictionPipeline.CachedFeatureImages.connect(
            self.CachedFeatureImages)
        self.opPredictionPipeline.Classifier.connect(
            self.classifier_cache.Output)
        self.opPredictionPipeline.FreezePredictions.connect(
            self.FreezePredictions)
        self.opPredictionPipeline.PredictionsFromDisk.connect(
            self.PredictionsFromDisk)
        self.opPredictionPipeline.PredictionMask.connect(self.PredictionMasks)

        def _updateNumClasses(*args):
            """
            When the number of labels changes, we MUST make sure that the prediction image changes its shape (the number of channels).
            Since setupOutputs is not called for mere dirty notifications, but is called in response to setValue(),
            we use this function to call setValue().
            """
            numClasses = len(self.LabelNames.value)
            self.opTrain.MaxLabel.setValue(numClasses)
            self.opPredictionPipeline.NumClasses.setValue(numClasses)
            self.NumClasses.setValue(numClasses)

        self.LabelNames.notifyDirty(_updateNumClasses)

        # Prediction pipeline outputs -> Top-level outputs
        self.PredictionProbabilities.connect(
            self.opPredictionPipeline.PredictionProbabilities)
        self.CachedPredictionProbabilities.connect(
            self.opPredictionPipeline.CachedPredictionProbabilities)
        self.HeadlessPredictionProbabilities.connect(
            self.opPredictionPipeline.HeadlessPredictionProbabilities)
        self.HeadlessUint8PredictionProbabilities.connect(
            self.opPredictionPipeline.HeadlessUint8PredictionProbabilities)
        self.PredictionProbabilityChannels.connect(
            self.opPredictionPipeline.PredictionProbabilityChannels)
        self.SegmentationChannels.connect(
            self.opPredictionPipeline.SegmentationChannels)
        self.UncertaintyEstimate.connect(
            self.opPredictionPipeline.UncertaintyEstimate)
        self.SimpleSegmentation.connect(
            self.opPredictionPipeline.SimpleSegmentation)
        self.HeadlessUncertaintyEstimate.connect(
            self.opPredictionPipeline.HeadlessUncertaintyEstimate)

        def inputResizeHandler(slot, oldsize, newsize):
            # When the last input lane disappears, explicitly empty these outputs.
            if (newsize == 0):
                self.LabelImages.resize(0)
                self.NonzeroLabelBlocks.resize(0)
                self.PredictionProbabilities.resize(0)
                self.CachedPredictionProbabilities.resize(0)

        self.InputImages.notifyResized(inputResizeHandler)

        # Debug assertions: Check to make sure the non-wrapped operators stayed that way.
        assert self.opTrain.Images.operator == self.opTrain

        def handleNewInputImage(multislot, index, *args):
            def handleInputReady(slot):
                self._checkConstraints(index)
                self.setupCaches(multislot.index(slot))

            multislot[index].notifyReady(handleInputReady)

        self.InputImages.notifyInserted(handleNewInputImage)

        def handleNewMaskImage(multislot, index, *args):
            def handleInputReady(slot):
                self._checkConstraints(index)

            multislot[index].notifyReady(handleInputReady)

        self.PredictionMasks.notifyInserted(handleNewMaskImage)

        # All input multi-slots should be kept in sync
        # Output multi-slots will auto-sync via the graph
        # NOTE: This must be a real list, not a filter() object: it is iterated
        #  in nested loops, and on Python 3 a filter iterator would be exhausted
        #  by the first pass of the inner loop.
        multiInputs = [s for s in self.inputs.values() if s.level >= 1]

        def insertSlot(a, b, position, finalsize):
            # 'a' is the slot to update (bound via partial); 'b' is the slot that changed.
            a.insertSlot(position, finalsize)

        def removeSlot(a, b, position, finalsize):
            a.removeSlot(position, finalsize)

        for s1 in multiInputs:
            for s2 in multiInputs:
                if s1 != s2:
                    s1.notifyInserted(partial(insertSlot, s2))
                    s1.notifyRemoved(partial(removeSlot, s2))

    def setupCaches(self, imageIndex):
        """
        Configure the LabelInputs lane ``imageIndex`` to match its input image.

        The label input gets the same shape and axistags as the input image,
        except that the channel axis (if one can be located) has size 1.
        """
        numImages = len(self.InputImages)
        inputSlot = self.InputImages[imageIndex]
        self.LabelInputs.resize(numImages)

        # Special case: We have to set up the shape of our label *input* according to our image input shape
        shapeList = list(inputSlot.meta.shape)
        try:
            channelIndex = inputSlot.meta.axistags.index('c')
            shapeList[channelIndex] = 1
        except Exception:
            # Best effort: if no channel axis can be located, keep the full
            # input shape.  (Unlike a bare 'except:', this does not swallow
            # KeyboardInterrupt or SystemExit.)
            pass
        self.LabelInputs[imageIndex].meta.shape = tuple(shapeList)
        self.LabelInputs[imageIndex].meta.axistags = inputSlot.meta.axistags

    def _checkConstraints(self, laneIndex):
        """
        Ensure that all input images have the same number of channels and the
        same dimensionality, and that a prediction mask (if ready) matches the
        shape of its input image in all but the last axis.

        :raises DatasetConstraintError: if any of these constraints is violated.
        """
        if not self.InputImages[laneIndex].ready():
            return

        thisLaneTaggedShape = self.InputImages[laneIndex].meta.getTaggedShape()

        # Find a different lane and use it for comparison
        validShape = thisLaneTaggedShape
        for i, slot in enumerate(self.InputImages):
            if slot.ready() and i != laneIndex:
                validShape = slot.meta.getTaggedShape()
                break

        if validShape['c'] != thisLaneTaggedShape['c']:
            raise DatasetConstraintError(
                 "Pixel Classification",
                 "All input images must have the same number of channels.  "\
                 "Your new image has {} channel(s), but your other images have {} channel(s)."\
                 .format( thisLaneTaggedShape['c'], validShape['c'] ) )

        if len(validShape) != len(thisLaneTaggedShape):
            raise DatasetConstraintError(
                 "Pixel Classification",
                 "All input images must have the same dimensionality.  "\
                 "Your new image has {} dimensions (including channel), but your other images have {} dimensions."\
                 .format( len(thisLaneTaggedShape), len(validShape) ) )

        mask_slot = self.PredictionMasks[laneIndex]
        input_shape = tuple(thisLaneTaggedShape.values())
        # The last axis is excluded from the comparison.
        if mask_slot.ready() and mask_slot.meta.shape[:-1] != input_shape[:-1]:
            raise DatasetConstraintError(
                 "Pixel Classification",
                 "If you supply a prediction mask, it must have the same shape as the input image.  "\
                 "Your input image has shape {}, but your mask has shape {}."\
                 .format( input_shape, mask_slot.meta.shape ) )

    def setInSlot(self, slot, subindex, roi, value):
        # Nothing to do here: All inputs that support __setitem__
        #   are directly connected to internal operators.
        pass

    def propagateDirty(self, slot, subindex, roi):
        # Nothing to do here: All outputs are directly connected to
        #  internal operators that handle their own dirty propagation.
        pass

    def addLane(self, laneIndex):
        """Append a new lane; ``laneIndex`` must equal the current number of lanes."""
        numLanes = len(self.InputImages)
        assert numLanes == laneIndex, "Image lanes must be appended."
        self.InputImages.resize(numLanes + 1)

    def removeLane(self, laneIndex, finalLength):
        """Remove lane ``laneIndex`` from InputImages."""
        self.InputImages.removeSlot(laneIndex, finalLength)

    def getLane(self, laneIndex):
        """Return an OperatorSubView of this operator for lane ``laneIndex``."""
        return OperatorSubView(self, laneIndex)
Example #5
0
class OpTrackingFeatureExtraction(Operator):
    """
    Feature extraction for tracking.

    Combines an OpObjectExtraction (label image and vigra region features)
    with an OpCachedDivisionFeatures (division features), forwards their
    outputs, and additionally provides the merged feature-name and
    feature-value dictionaries.
    """
    name = "Tracking Feature Extraction"

    TranslationVectors = InputSlot(optional=True)
    RawImage = InputSlot()
    BinaryImage = InputSlot()

    # which features to compute.
    # nested dictionary with format:
    # dict[plugin_name][feature_name][parameter_name] = parameter_value
    # for example {"Standard Object Features": {"Mean in neighborhood":{"margin": (5, 5, 2)}}}
    FeatureNamesVigra = InputSlot(rtype=List, stype=Opaque, value={})

    FeatureNamesDivision = InputSlot(rtype=List, stype=Opaque, value={})

    # Bypass cache (for headless mode)
    BypassModeEnabled = InputSlot(value=False)

    LabelImage = OutputSlot()
    ObjectCenterImage = OutputSlot()

    # the computed features.
    # nested dictionary with format:
    # dict[plugin_name][feature_name] = feature_value
    RegionFeaturesVigra = OutputSlot(stype=Opaque, rtype=List)
    RegionFeaturesDivision = OutputSlot(stype=Opaque, rtype=List)
    RegionFeaturesAll = OutputSlot(stype=Opaque, rtype=List)

    ComputedFeatureNamesAll = OutputSlot(rtype=List, stype=Opaque)
    ComputedFeatureNamesNoDivisions = OutputSlot(rtype=List, stype=Opaque)

    # For compatibility with tracking workflow, the RegionFeatures output
    # has rtype=List, indexed by t.
    # For other workflows, output has rtype=ArrayLike, indexed by (t)
    BlockwiseRegionFeaturesVigra = OutputSlot()
    BlockwiseRegionFeaturesDivision = OutputSlot()

    CleanLabelBlocks = OutputSlot()
    LabelImageCacheInput = InputSlot(optional=True)

    RegionFeaturesCacheInputVigra = InputSlot(optional=True)
    RegionFeaturesCleanBlocksVigra = OutputSlot()

    RegionFeaturesCacheInputDivision = InputSlot(optional=True)
    RegionFeaturesCleanBlocksDivision = OutputSlot()

    def __init__(self, parent):
        """Create the internal operators and connect them to the top-level slots."""
        super(OpTrackingFeatureExtraction, self).__init__(parent)
        # Full, unfiltered feature dictionary; see setDefaultFeatures() and _filterFeaturesByDim().
        self._default_features = None
        # internal operators
        self._objectExtraction = OpObjectExtraction(parent=self)

        self._opDivFeats = OpCachedDivisionFeatures(parent=self)
        self._opDivFeatsAdaptOutput = OpAdaptTimeListRoi(parent=self)

        # connect internal operators
        self._objectExtraction.RawImage.connect(self.RawImage)
        self._objectExtraction.BinaryImage.connect(self.BinaryImage)
        self._objectExtraction.BypassModeEnabled.connect(
            self.BypassModeEnabled)
        self._objectExtraction.Features.connect(self.FeatureNamesVigra)
        self._objectExtraction.RegionFeaturesCacheInput.connect(
            self.RegionFeaturesCacheInputVigra)
        self._objectExtraction.LabelImageCacheInput.connect(
            self.LabelImageCacheInput)
        self.CleanLabelBlocks.connect(self._objectExtraction.CleanLabelBlocks)
        self.RegionFeaturesCleanBlocksVigra.connect(
            self._objectExtraction.RegionFeaturesCleanBlocks)
        self.ObjectCenterImage.connect(
            self._objectExtraction.ObjectCenterImage)
        self.LabelImage.connect(self._objectExtraction.LabelImage)
        self.BlockwiseRegionFeaturesVigra.connect(
            self._objectExtraction.BlockwiseRegionFeatures)
        self.RegionFeaturesVigra.connect(self._objectExtraction.RegionFeatures)

        self._opDivFeats.LabelImage.connect(self.LabelImage)
        self._opDivFeats.DivisionFeatureNames.connect(
            self.FeatureNamesDivision)
        self._opDivFeats.CacheInput.connect(
            self.RegionFeaturesCacheInputDivision)
        self._opDivFeats.RegionFeaturesVigra.connect(
            self._objectExtraction.BlockwiseRegionFeatures)
        self.RegionFeaturesCleanBlocksDivision.connect(
            self._opDivFeats.CleanBlocks)
        self.BlockwiseRegionFeaturesDivision.connect(self._opDivFeats.Output)

        self._opDivFeatsAdaptOutput.Input.connect(self._opDivFeats.Output)
        self.RegionFeaturesDivision.connect(self._opDivFeatsAdaptOutput.Output)

        # As soon as input data is available, check its constraints
        self.RawImage.notifyReady(self._checkConstraints)
        self.BinaryImage.notifyReady(self._checkConstraints)

        # FIXME this shouldn't be done in post-filtering, but in reading the config or around that time
        self.RawImage.notifyReady(self._filterFeaturesByDim)

    def setDefaultFeatures(self, feats):
        """Store the full feature dictionary that _filterFeaturesByDim() filters from."""
        self._default_features = feats

    def setupOutputs(self, *args, **kwargs):
        self.ComputedFeatureNamesAll.meta.assignFrom(
            self.FeatureNamesVigra.meta)
        self.ComputedFeatureNamesNoDivisions.meta.assignFrom(
            self.FeatureNamesVigra.meta)
        self.RegionFeaturesAll.meta.assignFrom(self.RegionFeaturesVigra.meta)

    def execute(self, slot, subindex, roi, result):
        """
        Compute the merged dictionaries.

        * ComputedFeatureNamesAll: vigra and division feature names combined.
        * ComputedFeatureNamesNoDivisions: a copy of the vigra feature names.
        * RegionFeaturesAll: per key of the vigra feature result, the
          division and vigra feature dictionaries combined.

        The plugin names of the two sources must not overlap.
        """
        if slot == self.ComputedFeatureNamesAll:
            feat_names_vigra = self.FeatureNamesVigra([]).wait()
            feat_names_div = self.FeatureNamesDivision([]).wait()
            for plugin_name in list(feat_names_vigra.keys()):
                assert plugin_name not in feat_names_div, "feature name dictionaries must be mutually exclusive"
            for plugin_name in list(feat_names_div.keys()):
                assert plugin_name not in feat_names_vigra, "feature name dictionaries must be mutually exclusive"
            result = dict(
                list(feat_names_vigra.items()) + list(feat_names_div.items()))

            return result
        elif slot == self.ComputedFeatureNamesNoDivisions:
            feat_names_vigra = self.FeatureNamesVigra([]).wait()
            result = dict(list(feat_names_vigra.items()))

            return result
        elif slot == self.RegionFeaturesAll:
            feat_vigra = self.RegionFeaturesVigra(roi).wait()
            feat_div = self.RegionFeaturesDivision(roi).wait()
            assert np.all(list(feat_vigra.keys()) == list(feat_div.keys()))
            result = {}
            for t in list(feat_vigra.keys()):
                for plugin_name in list(feat_vigra[t].keys()):
                    assert plugin_name not in feat_div[
                        t], "feature dictionaries must be mutually exclusive"
                for plugin_name in list(feat_div[t].keys()):
                    assert plugin_name not in feat_vigra[
                        t], "feature dictionaries must be mutually exclusive"
                result[t] = dict(
                    list(feat_div[t].items()) + list(feat_vigra[t].items()))
            return result
        else:
            assert False, "Shouldn't get here."

    def propagateDirty(self, slot, subindex, roi):
        if slot == self.BypassModeEnabled:
            pass
        elif slot == self.FeatureNamesVigra or slot == self.FeatureNamesDivision:
            self.ComputedFeatureNamesAll.setDirty(roi)
            self.ComputedFeatureNamesNoDivisions.setDirty(roi)

    def setInSlot(self, slot, subindex, roi, value):
        # Only the cache-input slots may be written to; nothing else is done here.
        assert slot == self.RegionFeaturesCacheInputVigra or \
            slot == self.RegionFeaturesCacheInputDivision or \
            slot == self.LabelImageCacheInput, "Invalid slot for setInSlot(): {}".format(slot.name)

    def _checkConstraints(self, *args):
        """
        Check the input images as soon as they are ready.

        A missing or too short time axis is only logged as a warning.
        A shape mismatch between raw and binary image (ignoring channels)
        raises DatasetConstraintError.
        """
        # NOTE(review): the time-axis messages used to be built and then
        # discarded.  They are logged here instead of raised so that no
        # previously accepted dataset is rejected; confirm whether they
        # should become a DatasetConstraintError.
        if self.RawImage.ready():
            rawTaggedShape = self.RawImage.meta.getTaggedShape()
            if 't' not in rawTaggedShape or rawTaggedShape['t'] < 2:
                msg = "Raw image must have a time dimension with at least 2 images.\n"\
                    "Your dataset has shape: {}".format(self.RawImage.meta.shape)
                logger.warning(msg)

        if self.BinaryImage.ready():
            rawTaggedShape = self.BinaryImage.meta.getTaggedShape()
            if 't' not in rawTaggedShape or rawTaggedShape['t'] < 2:
                msg = "Binary image must have a time dimension with at least 2 images.\n"\
                    "Your dataset has shape: {}".format(self.BinaryImage.meta.shape)
                logger.warning(msg)

        if self.RawImage.ready() and self.BinaryImage.ready():
            rawTaggedShape = self.RawImage.meta.getTaggedShape()
            binTaggedShape = self.BinaryImage.meta.getTaggedShape()
            # Channels are allowed to differ, so exclude them from the comparison.
            rawTaggedShape['c'] = None
            binTaggedShape['c'] = None
            if dict(rawTaggedShape) != dict(binTaggedShape):
                msg = "Raw data and other data must have equal dimensions (different channels are okay).\n"\
                      "Your datasets have shapes: {} and {}".format( self.RawImage.meta.shape, self.BinaryImage.meta.shape )
                logger.info(msg)
                raise DatasetConstraintError("Object Extraction", msg)

    def _filterFeaturesByDim(self, *args):
        """
        Restrict FeatureNamesVigra to the plugins that fit the data dimensionality.

        Plugins with "2D" in their name are dropped for data with z > 1;
        plugins with "3D" in their name are dropped otherwise.  Nothing is
        done until default features have been provided via setDefaultFeatures().
        """
        # Features look as follows:
        # dict[plugin_name][feature_name][parameter_name] = parameter_value
        # for example {"Standard Object Features": {"Mean in neighborhood":{"margin": (5, 5, 2)}}}

        # FIXME: this is a hacky solution because we overwrite the INPUT slot FeatureNamesVigra depending on the data shape.
        # We store the _default_features separately because if the user switches from a 3D to a 2D dataset the value of
        # FeatureNamesVigra would not be populated with all the features again, but only those that work for 3D.

        if not (self.RawImage.ready() and self.FeatureNamesVigra.ready()):
            return

        current_dict = self._default_features
        if current_dict is None:
            # No defaults were provided yet: there is nothing to filter from.
            return

        rawTaggedShape = self.RawImage.meta.getTaggedShape()
        # Data without a 'z' axis is treated like data with a single z-slice.
        is_3d = 'z' in rawTaggedShape and rawTaggedShape['z'] > 1

        # The plugin names helpfully contain "2D" / "3D" when they only apply
        # to that dimensionality: drop the ones that do not fit the data.
        excluded_tag = "2D" if is_3d else "3D"
        filtered_features_dict = {}
        for plugin in list(current_dict.keys()):
            if excluded_tag not in plugin:
                filtered_features_dict[plugin] = current_dict[plugin]

        self.FeatureNamesVigra.setValue(filtered_features_dict)
class OpInterpMissingData(Operator):
    """Combine missing-region detection with interpolation of the affected data.

    An internal ``OpDetectMissing`` provides the ``Missing`` and ``Detector``
    outputs; an internal ``OpInterpolate`` provides the data that ``execute``
    copies into ``Output``.
    """
    name = "OpInterpMissingData"

    InputVolume = InputSlot()
    # Maximum number of extra z-slices searched on each side of a request
    # when looking for clean slices (see _extendRoi).
    InputSearchDepth = InputSlot(value=3)
    # The following three are forwarded unchanged to the detector.
    PatchSize = InputSlot(value=128)
    HaloSize = InputSlot(value=30)
    DetectionMethod = InputSlot(value="svm")
    # Must be one of the keys of _requiredMargin.
    InterpolationMethod = InputSlot(value="cubic")

    # be careful when using the following: setting the same thing twice will not trigger
    # the action you desire, even if something else has changed
    OverloadDetector = InputSlot(value="")

    Output = OutputSlot()
    Missing = OutputSlot()
    Detector = OutputSlot(stype=Opaque)

    # Number of consecutive clean z-slices looked for next to a request,
    # per interpolation method (used by _extendRoi).
    _requiredMargin = {"cubic": 2, "linear": 1, "constant": 0}
    # Set when OverloadDetector, PatchSize or HaloSize change; see isDirty()/resetDirty().
    _dirty = False

    def __init__(self, *args, **kwargs):
        """Create the detector and interpolator children and wire them up.

        All positional and keyword arguments are passed on to the base class.
        """
        super(OpInterpMissingData, self).__init__(*args, **kwargs)

        self.detector = OpDetectMissing(parent=self)
        self.interpolator = OpInterpolate(parent=self)

        # The detector receives the raw volume and all detection settings.
        self.detector.InputVolume.connect(self.InputVolume)
        self.detector.PatchSize.connect(self.PatchSize)
        self.detector.HaloSize.connect(self.HaloSize)
        self.detector.DetectionMethod.connect(self.DetectionMethod)
        self.detector.OverloadDetector.connect(self.OverloadDetector)

        # The interpolator receives the raw volume plus the detector's result.
        self.interpolator.InputVolume.connect(self.InputVolume)
        self.interpolator.Missing.connect(self.detector.Output)
        self.interpolator.InterpolationMethod.connect(self.InterpolationMethod)

        # Missing and Detector are forwarded directly from the detector;
        # Output is not connected here but computed in execute().
        self.Missing.connect(self.detector.Output)
        self.Detector.connect(self.detector.Detector)

    def isDirty(self):
        """Return the dirty flag that propagateDirty() sets when
        OverloadDetector, PatchSize or HaloSize change."""
        return self._dirty

    def resetDirty(self):
        """Clear the dirty flag reported by isDirty()."""
        self._dirty = False

    def setupOutputs(self):
        """Configure the metadata of ``Output`` and ``Detector``."""
        # Output has the same shape/axes/dtype/drange as input
        self.Output.meta.assignFrom(self.InputVolume.meta)

        # The detector output is declared as a single element.
        self.Detector.meta.shape = (1, )

    def execute(self, slot, subindex, roi, result):
        """
        Fill ``result`` with interpolated data for ``roi``.

        The request is processed one (channel, time) block at a time. For
        each block the z-range is widened by the offsets computed by
        ``_extendRoi``, the widened region is requested from the internal
        interpolator, and the extra slices are cut off again before the data
        is copied into ``result``.

        ``roi`` is modified while the request is processed and restored to
        its original extent before returning, also when a sub-request raises.

        :param slot: the requested output slot (not inspected here)
        :param subindex: unused
        :param roi: requested region; must expose mutable ``start``/``stop``
        :param result: destination array with the shape of ``roi``
        :returns: ``result``
        """

        method = self.InterpolationMethod.value

        assert method in self._requiredMargin, \
            "Unknown interpolation method {}".format(method)

        axistags = self.InputVolume.meta.axistags
        z_index = axistags.index("z")
        c_index = axistags.index("c")
        t_index = axistags.index("t")

        resultZYXCT = vigra.taggedView(result, axistags).withAxes(*"zyxct")

        # backup ROI
        oldStart = np.copy(roi.start)
        oldStop = np.copy(roi.stop)

        # An axis index past the end of the roi is treated as "axis not present".
        hasChannelAxis = c_index < len(roi.start)
        hasTimeAxis = t_index < len(roi.start)

        if hasChannelAxis:
            cRange = np.arange(roi.start[c_index], roi.stop[c_index])
            cOffset = oldStart[c_index]
        else:
            cRange = np.array([0])
            cOffset = 0

        if hasTimeAxis:
            tRange = np.arange(roi.start[t_index], roi.stop[t_index])
            tOffset = oldStart[t_index]
        else:
            tRange = np.array([0])
            tOffset = 0

        try:
            for c in cRange:
                for t in tRange:

                    # change roi to single block
                    if hasChannelAxis:
                        roi.start[c_index] = c
                        roi.stop[c_index] = c + 1

                    if hasTimeAxis:
                        roi.start[t_index] = t
                        roi.stop[t_index] = t + 1

                    # check if more input is needed, and how many
                    z_offsets = self._extendRoi(roi)

                    # get extended interpolation
                    roi.start[z_index] -= z_offsets[0]
                    roi.stop[z_index] += z_offsets[1]

                    a = self.interpolator.Output.get(roi).wait()

                    # reduce to original roi (expressed relative to `a`)
                    roi.stop = roi.stop - roi.start
                    roi.start *= 0
                    roi.start[z_index] += z_offsets[0]
                    roi.stop[z_index] -= z_offsets[1]
                    key = roiToSlice(roi.start, roi.stop)

                    # `result` only spans the requested roi, so c and t have to
                    # be translated from volume coordinates to result coordinates.
                    resultZYXCT[..., c - cOffset, t - tOffset] = vigra.taggedView(
                        a[key], axistags).withAxes(*"zyx")

                    # restore ROI, will be used in other methods!!!
                    roi.start = np.copy(oldStart)
                    roi.stop = np.copy(oldStop)
        finally:
            # Never hand back a modified roi, even if a request failed.
            roi.start = np.copy(oldStart)
            roi.stop = np.copy(oldStop)

        return result

    def propagateDirty(self, slot, subindex, roi):
        """React to a changed input.

        A change of ``InputVolume`` marks the same region of ``Output`` dirty.
        A change of ``OverloadDetector``, ``PatchSize`` or ``HaloSize`` only
        raises the internal dirty flag (see ``isDirty``).
        """
        if slot == self.InputVolume:
            self.Output.setDirty(roi)

        detectorSettingChanged = (slot == self.OverloadDetector
                                  or slot == self.PatchSize
                                  or slot == self.HaloSize)
        if detectorSettingChanged:
            self._dirty = True

    def train(self, force=False):
        """Delegate to the internal detector's ``train`` and return its result.

        :param force: passed through unchanged to the detector
        """
        return self.detector.train(force=force)

    def _extendRoi(self, roi):
        """Determine how many extra z-slices are needed around ``roi``.

        The detector output for ``roi`` is inspected slice by slice. If every
        requested slice is clean (all zeros), no extension is needed.
        Otherwise slices before ("top") and after ("bottom") the roi are
        checked one at a time until enough consecutive clean slices for the
        current interpolation method are found, the search depth
        (``InputSearchDepth``) is used up, or the volume border is reached.

        ``roi`` is modified during the search and restored before returning
        from the extended search (the early return happens before any
        modification).

        :param roi: region to inspect; must expose mutable ``start``/``stop``
        :returns: tuple ``(offset_top, offset_bot)`` with the number of
            slices to add before and after the roi along z
        """
        origStart = np.copy(roi.start)
        origStop = np.copy(roi.stop)

        offset_top = 0
        offset_bot = 0

        z_index = self.InputVolume.meta.axistags.index("z")

        depth = self.InputSearchDepth.value
        nRequestedSlices = roi.stop[z_index] - roi.start[z_index]
        nNeededSlices = self._requiredMargin[self.InterpolationMethod.value]

        missing = vigra.taggedView(
            self.detector.Output.get(roi).wait(),
            axistags=self.InputVolume.meta.axistags).withAxes(*"zyx")

        # Count the clean slices at the start of the roi.
        nGoodSlicesTop = 0
        # go inside the roi
        for k in range(nRequestedSlices):
            if np.all(missing[k, ...] == 0):  # clean slice
                nGoodSlicesTop += 1
            else:
                break

        # are we finished yet?
        if nGoodSlicesTop >= nRequestedSlices:
            return (0, 0)

        # looks like we need more slices on top
        # Walk away from the roi one slice at a time; the roi is shrunk to
        # that single slice for each detector request.
        while nGoodSlicesTop < nNeededSlices and offset_top < depth and roi.start[
                z_index] > 0:
            roi.stop[z_index] = roi.start[z_index]
            roi.start[z_index] -= 1
            offset_top += 1
            topmissing = self.detector.Output.get(roi).wait()
            if np.all(topmissing == 0):  # clean slice
                nGoodSlicesTop += 1
            else:  # need to start again
                nGoodSlicesTop = 0

        # Count the clean slices at the end of the roi
        # (uses the detector result fetched for the original roi).
        nGoodSlicesBot = 0
        # go inside the roi
        for k in range(1, nRequestedSlices + 1):
            if np.all(missing[-k, ...] == 0):  # clean slice
                nGoodSlicesBot += 1
            else:
                break

        # Undo the changes made by the top search before searching the bottom.
        roi.start = np.copy(origStart)
        roi.stop = np.copy(origStop)

        # looks like we need more slices on bottom
        while (roi.stop[z_index] < self.InputVolume.meta.getTaggedShape()["z"]
               and nGoodSlicesBot < nNeededSlices and offset_bot < depth):
            roi.start[z_index] = roi.stop[z_index]
            roi.stop[z_index] += 1
            offset_bot += 1
            botmissing = self.detector.Output.get(roi).wait()
            if np.all(botmissing == 0):  # clean slice
                nGoodSlicesBot += 1
            else:  # need to start again
                nGoodSlicesBot = 0

        roi.start = np.copy(origStart)
        roi.stop = np.copy(origStop)

        return (offset_top, offset_bot)
Example #7
0
class OpArrayCache(Operator, ManagedCache):
    """ Allocates a block of memory as large as Input.meta.shape (==Output.meta.shape)
        with the same dtype in order to be able to cache results.

        blockShape: dirty regions are tracked with a granularity of blockShape
    """

    name = "ArrayCache"
    description = "numpy.ndarray caching class"
    category = "misc"

    # Block edge length used until the blockShape slot provides another value.
    DefaultBlockSize = 64

    #Input
    Input = InputSlot(allow_mask=True)
    blockShape = InputSlot(value = DefaultBlockSize)
    # While True, dirty notifications are recorded (FIXED_DIRTY) instead of
    # being forwarded, and no new data is requested from Input.
    fixAtCurrent = InputSlot(value = False)

    #Output
    CleanBlocks = OutputSlot()
    Output = OutputSlot(allow_mask=True)

    loggingName = __name__ + ".OpArrayCache"
    logger = logging.getLogger(loggingName)
    traceLogger = logging.getLogger("TRACE." + loggingName)

    # Block states
    IN_PROCESS  = 0  # a request filling the block is under way
    DIRTY       = 1  # block content must be (re)computed
    CLEAN       = 2  # block content is valid
    FIXED_DIRTY = 3  # became dirty / was requested dirty while fixAtCurrent was set

    def __init__(self, *args, **kwargs):
        """Initialise the bookkeeping attributes; no memory is allocated here.

        All arguments are passed on to the base class constructor.
        """
        super( OpArrayCache, self ).__init__(*args, **kwargs)
        self._origBlockShape = self.DefaultBlockSize  # block shape as provided by the blockShape slot
        self._blockShape = None  # block shape in use
        self._dirtyShape = None  # number of blocks along each axis
        self._blockState = None  # per-block state array (see the block state constants)
        self._dirtyState = None  # state that setInSlot assigns to blocks it fills
        self._fixed = False  # mirrors the fixAtCurrent slot
        self._cache = None  # the cached data, allocated on demand
        self._lock = Lock()
        self._lazyAlloc = True  # if True, setupOutputs does not allocate the cache
        self._cacheHits = 0
        self._has_fixed_dirty_blocks = False
        self._running = 0

        # Now that we're initialized, it's safe to register with the memory manager
        self.registerWithMemoryManager()

    # ========== CACHE API ==========

    def usedMemory(self):
        """Return the size of the current cache as computed by ``_usedMemory``
        (0 if no cache is allocated)."""
        return self._usedMemory(self._cache)

    def fractionOfUsedMemoryDirty(self):
        """Return the fraction of the output volume covered by dirty blocks.

        Blocks in state ``DIRTY`` or ``FIXED_DIRTY`` count as dirty. Returns 0
        when the output shape or the block bookkeeping is not set up yet;
        blocks whose shape cannot be determined (no cache) are skipped.
        """
        outputShape = self.Output.meta.shape
        if outputShape is None:
            return 0

        totalPixels = numpy.prod(outputShape)
        blockStates = self._blockState
        if blockStates is None:
            return 0

        dirtyPixels = 0
        for blockIndex in numpy.ndindex(*blockStates.shape):
            state = blockStates[blockIndex]
            blockDims = self._blockShapeForIndex(blockIndex)
            if blockDims is None:
                continue
            if state == self.DIRTY or state == self.FIXED_DIRTY:
                dirtyPixels += numpy.prod(blockDims)
        return dirtyPixels/float(totalPixels)
    
    def lastAccessTime(self):
        """Return the last access time as reported by the base class."""
        return super(OpArrayCache, self).lastAccessTime()

    def freeMemory(self):
        """Try to release the cache; returns whatever ``_freeMemory`` returns."""
        return self._freeMemory()

    def freeDirtyMemory(self):
        """Release the cache only if it is completely dirty.

        The cache is a single allocation, so freeing it would also discard
        clean data unless everything is dirty. Returns the result of
        ``freeMemory()`` in that case and 0 otherwise.
        """
        # we can only free if all is dirty without touching non-dirty areas
        return self.freeMemory() if self.fractionOfUsedMemoryDirty() >= 1.0 else 0

    def generateReport(self, memInfoNode):
        """Fill ``memInfoNode`` via the base class and add the cache dtype.

        The dtype is only recorded when a cache exists and exposes a ``dtype``
        attribute; otherwise the node is left as the base class filled it.
        """
        super(OpArrayCache, self).generateReport(memInfoNode)
        cache = self._cache
        # if the cache is no array we cannot determine the dtype
        if cache is not None and hasattr(cache, "dtype"):
            memInfoNode.dtype = cache.dtype

    # ========== END CACHE API ==========

    @staticmethod
    def _usedMemory(item):
        """Return the number of bytes held by ``item``.

        * numeric ``numpy.ndarray``: its ``nbytes``
        * object ``numpy.ndarray``: the sum over all elements (recursively)
        * ``dict``: the sum over all values (recursively); entries that
          disappear while iterating are skipped
        * anything else: 0
        """
        if isinstance(item, numpy.ndarray):
            # numpy.object_ instead of the removed alias numpy.object
            if item.dtype == numpy.object_:
                return sum(OpArrayCache._usedMemory(x) for x in item.ravel())
            return item.nbytes

        if isinstance(item, dict):
            total = 0
            # Iterate over a snapshot of the keys: entries may vanish meanwhile.
            for key in list(item.keys()):
                try:
                    obj = item[key]
                except KeyError:
                    # cleaned up
                    continue
                total += OpArrayCache._usedMemory(obj)
            return total

        return 0
                    

    def _blockShapeForIndex(self, index):
        """Return the extent of the block at block-coordinate ``index``.

        Blocks at the upper border are clipped to the cache shape. Returns
        ``None`` when no cache is allocated, otherwise a tuple of ints.
        """
        if self._cache is None:
            return None

        blockIndex = numpy.asarray(index)
        blockDims = numpy.array(self._blockShape)
        cacheDims = numpy.array(self._cache.shape)

        lower = blockDims * blockIndex
        upper = blockDims * (blockIndex + 1)
        lower = numpy.maximum(lower, (0,)*len(lower))
        upper = numpy.minimum(upper, cacheDims)
        return tuple(int(extent) for extent in upper - lower)

    def _freeMemory(self, refcheck = True):
        """Release the cache array and return the number of bytes freed.

        Nothing is freed (and 0 is returned) when there is no cache, when any
        block is currently being filled, when the cache is already empty, or
        when the array cannot be resized because views of it are still
        referenced elsewhere.

        :param refcheck: passed to ``numpy.ndarray.resize``
        :returns: number of bytes freed (always an int)
        """
        with self._lock:
            if self._cache is None:
                return 0
            if (self._blockState == OpArrayCache.IN_PROCESS).any():
                # Pending requests write into the cache; leave it alone.
                return 0
            if self._cache.shape == ():
                return 0

            freed = self.usedMemory()
            fshape = self._cache.shape
            try:
                self._cache.resize((), refcheck = refcheck)
            except ValueError:
                self.logger.debug("OpArrayCache (name={}): freeing failed due to view references".format(self.name))
                return 0

            if freed > 0:
                self.logger.debug("OpArrayCache: freed cache of shape:{}".format(fshape))

                # The data is gone, so every block has to be recomputed.
                self._blockState[:] = OpArrayCache.DIRTY
                self._cache = None
            return freed

    def _get_full_blockshape(self, input_blockshape):
        """Expand ``input_blockshape`` to one entry per axis, clipped to the input shape.

        :param input_blockshape: a single number (used for every axis) or an
            iterable with one entry per axis of ``Input``
        :returns: tuple with the block extent along each axis, none of them
            larger than the corresponding extent of ``Input.meta.shape``
        """
        # collections.Iterable was removed in Python 3.10
        try:
            from collections.abc import Iterable
        except ImportError:  # Python 2
            from collections import Iterable

        max_shape = self.Input.meta.shape
        if isinstance(input_blockshape, Iterable):
            blockshape = tuple( input_blockshape )
        else:
            # Broadcast as a tuple
            blockshape = (input_blockshape,)*len(max_shape)
        blockshape = numpy.minimum(blockshape, max_shape)
        return tuple(blockshape)
    
    def _allocateManagementStructures(self):
        """(Re)create the per-block bookkeeping for the current output shape.

        Sets ``_blockShape`` and ``_dirtyShape`` (number of blocks per axis)
        and creates ``_blockQuery`` and ``_blockState`` with every block
        marked ``DIRTY``.
        """
        shape = self.Output.meta.shape
        self._blockShape = self._get_full_blockshape(self._origBlockShape)
        # builtin int instead of the removed alias numpy.int
        self._dirtyShape = numpy.ceil(1.0 * numpy.array(shape) / numpy.array(self._blockShape)).astype(int)

        self.logger.debug("Configured OpArrayCache with shape={}, blockShape={}, dirtyShape={}, origBlockShape={}".format(shape, self._blockShape, self._dirtyShape, self._origBlockShape))

        #if a request has been submitted to get a block, the request object
        #is stored within this array
        self._blockQuery = numpy.ndarray(self._dirtyShape, dtype=object)

        #keep track of the dirty state of each block
        self._blockState = OpArrayCache.DIRTY * numpy.ones(self._dirtyShape, numpy.uint8)

        # State assigned to blocks that are written via setInSlot.
        self._dirtyState = OpArrayCache.CLEAN
    
    def _allocateCache(self):
        """Reset the access statistics and allocate a zeroed cache if needed.

        A new cache is allocated when there is none yet or when its shape no
        longer matches ``Output.meta.shape``. The block bookkeeping is created
        as well if it does not exist.
        """
        self._last_access_time = 0
        self._cache_priority = 0
        self._running = 0

        cacheIsUsable = self._cache is not None and not (self._cache.shape != self.Output.meta.shape)
        if cacheIsUsable:
            return

        freshCache = self.Output.stype.allocateDestination(None)
        freshCache[:] = 0
        self.logger.debug("OpArrayCache: Allocating cache (size: %dbytes)" % freshCache.nbytes)
        if self._blockState is None:
            self._allocateManagementStructures()
        self._cache = freshCache

    def setupOutputs(self):
        """Configure the output metadata and the block bookkeeping.

        * ``CleanBlocks`` is a single object element.
        * ``Output`` takes over the metadata of ``Input``; it is flagged
          NOTREADY (together with ``CleanBlocks``) if an iterable block shape
          does not have one entry per input axis.
        * ``Output.meta.ram_usage_per_requested_pixel`` is raised to at least
          the size of one pixel (all channels).
        * The block bookkeeping is (re)allocated on first use and whenever the
          block shape changes.
        """
        # collections.Iterable was removed in Python 3.10
        try:
            from collections.abc import Iterable
        except ImportError:  # Python 2
            from collections import Iterable

        self.CleanBlocks.meta.shape = (1,)
        self.CleanBlocks.meta.dtype = object
        reconfigure = False
        if self.inputs["fixAtCurrent"].ready():
            self._fixed = self.inputs["fixAtCurrent"].value

        if self.inputs["blockShape"].ready() and self.inputs["Input"].ready():
            newBShape = self.inputs["blockShape"].value
            assert numpy.issubdtype(type(newBShape), numpy.integer) or all( map(lambda x: numpy.issubdtype(type(x), numpy.integer), newBShape) )
            if self._origBlockShape != newBShape:
                reconfigure = True
            self._origBlockShape = newBShape
            self._blockShape = newBShape

            inputSlot = self.inputs["Input"]
            self.Output.meta.assignFrom(inputSlot.meta)

            if isinstance(self._blockShape, Iterable) and \
               len(self._blockShape) != len(self.Input.meta.shape):
                self.Output.meta.NOTREADY = True
                self.CleanBlocks.meta.NOTREADY = True
                return

            # Estimate ram usage
            ram_per_pixel = 0
            dtype = self.Output.meta.dtype
            if dtype == object or dtype == numpy.object_:
                ram_per_pixel = sys.getsizeof(None)
            elif isinstance(dtype, numpy.dtype) or \
                    (isinstance(dtype, type) and issubclass(dtype, numpy.generic)):
                # Note: issubdtype(dtype, numpy.dtype) is False for numeric
                # types on current numpy, which silently produced 0 here.
                ram_per_pixel = numpy.dtype(dtype).itemsize

            tagged_shape = self.Output.meta.getTaggedShape()
            if 'c' in tagged_shape:
                ram_per_pixel *= float(tagged_shape['c'])

            if self.Output.meta.ram_usage_per_requested_pixel is not None:
                ram_per_pixel = max( ram_per_pixel, self.Output.meta.ram_usage_per_requested_pixel )

            self.Output.meta.ram_usage_per_requested_pixel = ram_per_pixel

        shape = self.Output.meta.shape
        if (self._dirtyShape is None or reconfigure) and shape is not None:
            with self._lock:
                self._allocateManagementStructures()
                if not self._lazyAlloc:
                    self._allocateCache()

        self.Output.meta.ideal_blockshape = self._get_full_blockshape(self._origBlockShape)

    def propagateDirty(self, slot, subindex, roi):
        """Update the block states after an input changed.

        ``Input`` dirty: every block touched by ``roi`` becomes ``DIRTY`` and
        the region is forwarded to ``Output``. While the cache is fixed the
        blocks become ``FIXED_DIRTY`` instead and nothing is forwarded.

        ``fixAtCurrent`` changed: when the cache becomes unfixed, all
        ``FIXED_DIRTY`` blocks are turned into ``DIRTY`` and their common
        bounding box is reported dirty on ``Output``.
        """
        shape = self.Input.meta.shape
        key = roi.toSlice()

        if slot == self.inputs["Input"]:
            start, stop = sliceToRoi(key, shape)

            with self._lock:
                if self._blockState is not None:
                    # Block indices must be integers to be usable as slice bounds.
                    blockStart = numpy.floor(1.0 * start / self._blockShape).astype(int)
                    blockStop = numpy.ceil(1.0 * stop / self._blockShape).astype(int)
                    blockKey = roiToSlice(blockStart,blockStop)
                    if self._fixed:
                        # Remember that this block became dirty while we were fixed
                        #  so we can notify downstream operators when we become unfixed.
                        self._blockState[blockKey] = OpArrayCache.FIXED_DIRTY
                        self._has_fixed_dirty_blocks = True
                    else:
                        self._blockState[blockKey] = OpArrayCache.DIRTY

            if not self._fixed:
                self.outputs["Output"].setDirty(key)
        if slot == self.inputs["fixAtCurrent"]:
            if self.inputs["fixAtCurrent"].ready():
                self._fixed = self.inputs["fixAtCurrent"].value
                if not self._fixed and self.Output.meta.shape is not None and self._has_fixed_dirty_blocks:
                    # We've become unfixed, so we need to notify downstream
                    #  operators of every block that became dirty while we were fixed.
                    # Convert all FIXED_DIRTY states into DIRTY states
                    with self._lock:
                        cond = (self._blockState[...] == OpArrayCache.FIXED_DIRTY)
                        self._blockState[...]  = fastWhere(cond, OpArrayCache.DIRTY, self._blockState, numpy.uint8)
                        self._has_fixed_dirty_blocks = False
                    newDirtyBlocks = numpy.transpose(numpy.nonzero(cond))

                    # To avoid lots of setDirty notifications, we simply merge all the dirtyblocks into one single superblock.
                    # This should be the best option in most cases, but could be bad in some cases.
                    # TODO: Optimize this by merging the dirty blocks via connected components or something.
                    cacheShape = numpy.array(self.Output.meta.shape)
                    dirtyStart = cacheShape
                    dirtyStop = [0] * len(cacheShape)
                    for index in newDirtyBlocks:
                        blockStart = index * self._blockShape
                        blockStop = numpy.minimum(blockStart + self._blockShape, cacheShape)

                        dirtyStart = numpy.minimum(dirtyStart, blockStart)
                        dirtyStop = numpy.maximum(dirtyStop, blockStop)

                    if len(newDirtyBlocks) > 0:
                        self.Output.setDirty( dirtyStart, dirtyStop )

    def _updatePriority(self, new_access = None):
        """Record an access and update the cache priority.

        The new priority is half the old priority plus the time elapsed since
        the previous access.

        :param new_access: only used as the previous access time when none has
            been recorded yet (``_last_access_time`` is ``None``)
        """
        previousAccess = self._last_access_time
        if previousAccess is None:
            previousAccess = new_access or time.time()
            self._last_access_time = previousAccess

        now = time.time()
        elapsed = now - previousAccess + 1e-9

        self._last_access_time = now
        self._cache_priority = 0.5 * self._cache_priority + elapsed

    def execute(self, slot, subindex, roi, result):
        """Dispatch a request to the handler of the requested output slot.

        ``Output`` is served by ``_executeOutput``, ``CleanBlocks`` by
        ``_executeCleanBlocks``; for any other slot nothing is done and
        ``None`` is returned.
        """
        if slot == self.Output:
            return self._executeOutput(slot, subindex, roi, result)
        elif slot == self.CleanBlocks:
            return self._executeCleanBlocks(slot, subindex, roi, result)
        
    def _executeOutput(self, slot, subindex, roi, result):
        """Serve a request on ``Output`` from the cache, filling it first if needed.

        Steps:

        1. Under the lock: allocate the cache if necessary and look up the
           states of all blocks touched by ``roi``. If all of them are
           ``CLEAN`` or ``FIXED_DIRTY`` the data is copied straight from the
           cache and the method returns.
        2. Otherwise (still under the lock) a request into the cache is
           created for every tile of ``DIRTY`` blocks and the blocks are
           marked ``IN_PROCESS``. While the cache is fixed no requests are
           created; the dirty blocks are marked ``FIXED_DIRTY`` instead.
        3. Outside the lock: wait for the new requests, mark their blocks
           ``CLEAN``, then wait for requests that other callers had already
           started for blocks in this region.
        4. Copy the requested region into ``result``.

        Nothing is returned; the data is written into ``result``.
        """
        t = time.time()
        key = roi.toSlice()

        shape = self.Output.meta.shape
        start, stop = sliceToRoi(key, shape)

        with self._lock:
            self._cacheHits += 1

            self._running += 1

            if (self._cache is None or
                    self._cache.shape != self.Output.meta.shape):
                self._allocateCache()

            cacheView = self._cache[:] #prevent freeing of cache during running this function

            blockStart = (1.0 * start / self._blockShape).floor()
            blockStop = (1.0 * stop / self._blockShape).ceil()
            blockKey = roiToSlice(blockStart,blockStop)

            # View into the state array: writing to blockSet updates _blockState.
            blockSet = self._blockState[blockKey]

            # this is a little optimization to shortcut
            # many lines of python code when all data is
            # is already in the cache:
            if numpy.logical_or(blockSet == OpArrayCache.CLEAN, blockSet == OpArrayCache.FIXED_DIRTY).all():
                cache_result = self._cache[roiToSlice(start, stop)]
                self.Output.stype.copy_data(result, cache_result)

                self._running -= 1
                self._updatePriority()
                cacheView = None
                return

            # Requests started by other callers for blocks in this region.
            extracted = numpy.extract( blockSet == OpArrayCache.IN_PROCESS, self._blockQuery[blockKey])
            inProcessQueries = numpy.unique(extracted)

            cond = (blockSet == OpArrayCache.DIRTY)
            tileWeights = fastWhere(cond, 1, 128**3, numpy.uint32)
            trueDirtyIndices = numpy.nonzero(cond)

            if has_drtile:
                tileArray = drtile.test_DRTILE(tileWeights, 128**3).swapaxes(0,1)
            else:
                # Without drtile every dirty block is its own tile.
                tileStartArray = numpy.array(trueDirtyIndices)
                tileStopArray = 1 + tileStartArray
                tileArray = numpy.concatenate((tileStartArray, tileStopArray), axis=0)

            # First half of each column is the tile start, second half the tile stop.
            half = tileArray.shape[0]//2
            dirtyPool = RequestPool()

            for i in range(tileArray.shape[1]):

                drStart3 = tileArray[:half,i]
                drStop3 = tileArray[half:,i]
                # block coordinates within the whole cache
                drStart2 = drStart3 + blockStart
                drStop2 = drStop3 + blockStart
                # pixel coordinates, clipped to the cache
                drStart = drStart2*self._blockShape
                drStop = drStop2*self._blockShape

                shape = self.Output.meta.shape
                drStop = numpy.minimum(drStop, shape)
                drStart = numpy.minimum(drStart, shape)

                key2 = roiToSlice(drStart2,drStop2)

                key = roiToSlice(drStart,drStop)

                if not self._fixed:
                    req = self.inputs["Input"][key].writeInto(self._cache[key])
                    req.uncancellable = True #FIXME

                    dirtyPool.add(req)

                    self._blockQuery[key2] = weakref.ref(req)

                    #sanity check:
                    if (self._blockState[key2] != OpArrayCache.DIRTY).any():
                        self.logger.warning( "original condition" + str(cond) )
                        self.logger.warning( "original tilearray {} {}".format( tileArray, tileArray.shape ) )
                        self.logger.warning( "original tileWeights {} {}".format( tileWeights, tileWeights.shape ) )
                        self.logger.warning( "sub condition {}".format( self._blockState[key2] == OpArrayCache.DIRTY ) )
                        self.logger.warning( "START={}, STOP={}".format( drStart2, drStop2 ) )
                        # Debugging aid: dumps the tile weights to test.h5 in the working directory.
                        import h5py
                        with h5py.File("test.h5", "w") as f:
                            f.create_dataset("data",data = tileWeights)
                            self.logger.warning( "%r \n %r \n %r\n %r\n %r \n%r" % (key2, blockKey,self._blockState[key2], self._blockState[blockKey][trueDirtyIndices],self._blockState[blockKey],tileWeights) )
                        assert False
                    self._blockState[key2] = OpArrayCache.IN_PROCESS

            # indicate the inprocessing state, by setting array to 0 (i.e. IN_PROCESS)
            if not self._fixed:
                blockSet[:]  = fastWhere(cond, OpArrayCache.IN_PROCESS, blockSet, numpy.uint8)
            else:
                # Someone asked for some dirty blocks while we were fixed.
                # Mark these blocks to be signaled as dirty when we become unfixed
                blockSet[:]  = fastWhere(cond, OpArrayCache.FIXED_DIRTY, blockSet, numpy.uint8)
                self._has_fixed_dirty_blocks = True

        #wait for all requests to finish
        something_updated = len( dirtyPool ) > 0
        dirtyPool.wait()
        if something_updated:
            # Signal that something was updated.
            # Note that we don't need to do this for the 'in process' queries (below)
            #  because they are already in the dirtyPool in some other thread
            self.Output._sig_value_changed()

        # indicate the finished inprocess state (i.e. CLEAN)
        if not self._fixed:
            with self._lock:
                blockSet[:] = fastWhere(cond, OpArrayCache.CLEAN, blockSet, numpy.uint8)
                self._blockQuery[blockKey] = fastWhere(cond, None, self._blockQuery[blockKey], object)

        # Wait for all in-process queries.
        # Can't use RequestPool here because these requests have already started.
        for req in inProcessQueries:
            req = req() # get original req object from weakref
            if req is not None:
                req.wait()

        # finally, store results in result area
        with self._lock:
            if self._cache is not None:
                cache_result = self._cache[roiToSlice(start, stop)]
                self.Output.stype.copy_data(result, cache_result)
            else:
                # The cache is gone; read directly from the input.
                self.inputs["Input"][roiToSlice(start, stop)].writeInto(result).wait()
            self._running -= 1
            self._updatePriority()
            cacheView = None
        self.logger.debug("read %s took %f sec." % (roi.pprint(), time.time()-t))

    def setInSlot(self, slot, subindex, roi, value):
        """Write ``value`` for ``roi`` directly into the cache.

        Only blocks that lie completely inside ``roi`` are written (blocks at
        the upper border of the volume count as complete). Nothing is written
        if all of those blocks are already ``CLEAN``; otherwise they are
        filled from ``value`` and get the state ``self._dirtyState``.

        :param slot: must be the ``Input`` slot
        :param subindex: unused
        :param roi: region covered by ``value``
        :param value: the data to store
        """
        assert slot == self.inputs["Input"]
        self._cacheHits += 1
        start, stop = roi.start, roi.stop
        # Block indices must be integers to be usable as slice bounds.
        blockStart = numpy.ceil(1.0 * start / self._blockShape).astype(int)
        blockStop = numpy.floor(1.0 * stop / self._blockShape).astype(int)
        # A roi that reaches the end of the volume includes the (possibly
        # smaller) last block along that axis.
        blockStop = numpy.where(stop == self.Output.meta.shape, self._dirtyShape, blockStop)
        blockKey = roiToSlice(blockStart,blockStop)

        if (self._blockState[blockKey] != OpArrayCache.CLEAN).any():
            start2 = blockStart * self._blockShape
            stop2 = blockStop * self._blockShape
            stop2 = numpy.minimum(stop2, self.Output.meta.shape)
            key2 = roiToSlice(start2,stop2)
            with self._lock:
                if self._cache is None:
                    self._allocateCache()
                self.Output.stype.copy_data(
                    self._cache[key2],
                    value[roiToSlice(start2-start,stop2-start)]
                )
                self._blockState[blockKey] = self._dirtyState
                self._blockQuery[blockKey] = None

    def _executeCleanBlocks(self, slot, subindex, roi, destination):
        """Store the bounds of all ``CLEAN`` blocks in ``destination[0]``.

        The stored value is a list with one entry per clean block; each entry
        is a list holding the elements returned by ``getBlockBounds`` for
        that block, each converted to ``TinyVector``.

        :returns: ``destination``
        """
        indexCols = numpy.where(self._blockState == OpArrayCache.CLEAN)
        clean_block_starts = numpy.array(indexCols).transpose()
        clean_block_starts *= self._blockShape

        inputShape = self.Input.meta.shape
        # Real lists instead of map(): on Python 3 map() is a one-shot iterator.
        clean_block_rois = [ getBlockBounds( inputShape, self._blockShape, block_start )
                             for block_start in clean_block_starts ]
        destination[0] = [ [ TinyVector(bound) for bound in block_roi ]
                           for block_roi in clean_block_rois ]
        return destination
Example #8
0
class OpVigraWatershedViewer(Operator):
    """Top-level operator combining channel selection, seed generation,
    watershed computation, caching and colorizing for display.

    See the schematic in ``__init__`` for the internal pipeline.
    """
    name = "OpWatershedViewer"
    category = "top-level"

    RawImage = InputSlot(
        optional=True)  # Displayed in the GUI (not used in pipeline)

    InputImage = InputSlot()  # The image to be sliced and watershedded

    FreezeCache = InputSlot(value=True)  # opWatershedCache

    InputChannelIndexes = InputSlot(value=[0])
    WatershedPadding = InputSlot(value=10)
    OverrideLabels = InputSlot(value={0: (0, 0, 0, 0)})
    SeedThresholdValue = InputSlot(value=0.0)
    MinSeedSize = InputSlot(value=0)
    CacheBlockShape = InputSlot(
        value=(256, 10)
    )  # opWatershedCache block shapes. Expected: tuple of (width, depth) viewing shape

    Seeds = OutputSlot()  # For batch export
    WatershedLabels = OutputSlot()  # Watershed labeled output
    SummedInput = OutputSlot()  # Watershed input (for gui display)
    ColoredPixels = OutputSlot()  # Colored watershed labels (for gui display)
    ColoredSeeds = OutputSlot()  # Seeds to the watershed (for gui display)

    SelectedInputChannels = OutputSlot(level=1)

    def __init__(self, *args, **kwargs):
        """Create all internal operators and connect the static part of the pipeline.

        The seed input of the watershed and the threshold function depend on
        ``SeedThresholdValue`` and are therefore set up in ``setupOutputs``.
        All arguments are passed on to the base class.
        """
        super(OpVigraWatershedViewer, self).__init__(*args, **kwargs)
        # Threshold that was used for the current seed setup (see setupOutputs).
        self._seedThreshold = None

        # Overview Schematic
        # Example here uses input channels 0,2,5

        # InputChannelIndexes=[0,2,5] ----
        #                                 \
        # InputImage --> opChannelSlicer .Slices[0] ---\
        #                                .Slices[1] ----> opAverage -------------------------------------------------> opWatershed --> opWatershedCache --> opColorizer --> GUI
        #                                .Slices[2] ---/           \                                                  /
        #                                                           \     MinSeedSize                                /
        #                                                            \               \                              /
        #                              SeedThresholdValue ----------> opThreshold --> opSeedLabeler --> opSeedFilter --> opSeedCache --> opSeedColorizer --> GUI

        # Create operators
        self.opChannelSlicer = OpMultiArraySlicer2(parent=self)
        self.opAverage = OpMultiArrayMerger(parent=self)
        self.opWatershed = OpVigraWatershed(parent=self)
        self.opWatershedCache = OpSlicedBlockedArrayCache(parent=self)
        self.opColorizer = OpColorizeLabels(parent=self)

        self.opThreshold = OpPixelOperator(parent=self)
        self.opSeedLabeler = OpVigraLabelVolume(parent=self)
        self.opSeedFilter = OpFilterLabels(parent=self)
        self.opSeedCache = OpSlicedBlockedArrayCache(parent=self)
        self.opSeedColorizer = OpColorizeLabels(parent=self)

        # Select specific input channels
        self.opChannelSlicer.Input.connect(self.InputImage)
        self.opChannelSlicer.SliceIndexes.connect(self.InputChannelIndexes)
        self.opChannelSlicer.AxisFlag.setValue('c')

        # Average selected channels
        def average(arrays):
            """Return the element-wise mean of ``arrays`` (0 for an empty sequence)."""
            if len(arrays) == 0:
                return 0
            else:
                return sum(arrays) / float(len(arrays))

        self.opAverage.MergingFunction.setValue(average)
        self.opAverage.Inputs.connect(self.opChannelSlicer.Slices)

        # Threshold for seeds
        self.opThreshold.Input.connect(self.opAverage.Output)

        # Label seeds
        self.opSeedLabeler.Input.connect(self.opThreshold.Output)

        # Filter seeds
        self.opSeedFilter.MinLabelSize.connect(self.MinSeedSize)
        self.opSeedFilter.Input.connect(self.opSeedLabeler.Output)

        # Cache seeds
        self.opSeedCache.fixAtCurrent.connect(self.FreezeCache)
        self.opSeedCache.Input.connect(self.opSeedFilter.Output)

        # Color seeds for RGB display
        self.opSeedColorizer.Input.connect(self.opSeedCache.Output)
        self.opSeedColorizer.OverrideColors.setValue({0: (0, 0, 0, 0)})

        # Compute watershed labels (possibly with seeds, see setupOutputs)
        self.opWatershed.InputImage.connect(self.opAverage.Output)
        self.opWatershed.PaddingWidth.connect(self.WatershedPadding)

        # Cache the watershed output
        self.opWatershedCache.fixAtCurrent.connect(self.FreezeCache)
        self.opWatershedCache.Input.connect(self.opWatershed.Output)

        # Colorize the watershed labels for RGB display
        self.opColorizer.Input.connect(self.opWatershedCache.Output)
        self.opColorizer.OverrideColors.connect(self.OverrideLabels)

        # Connect external outputs to the operators that provide them
        # (WatershedLabels is connected in setupOutputs)
        self.Seeds.connect(self.opSeedCache.Output)
        self.ColoredPixels.connect(self.opColorizer.Output)
        self.SelectedInputChannels.connect(self.opChannelSlicer.Slices)
        self.SummedInput.connect(self.opAverage.Output)
        self.ColoredSeeds.connect(self.opSeedColorizer.Output)

    def setupOutputs(self):
        """Configure the cache block shapes and the seed branch of the pipeline.

        ``CacheBlockShape`` provides ``(width, depth)``. For each of the three
        slicing views (x, y, z) the block is ``depth`` thick along the sliced
        axis, ``width`` wide along the other two spatial axes and 1 along
        ``t`` and ``c``. Inner and outer block shapes are the same: the cache
        is used for its "sliced" property, not its "blocked" property.

        If ``SeedThresholdValue`` is ready, the threshold function is
        (re)installed and the filtered seeds are connected to the watershed
        whenever the seed input is not connected yet or the threshold changed.
        Otherwise the seed input and the threshold function are disconnected.
        """
        # User has control over cache block shape
        # Width and depth are applied to x,y, or z depending on which slicing view is being used.
        width, depth = self.CacheBlockShape.value

        # Set the blockshapes depending on which axistags the input image has.
        axisKeys = [tag.key for tag in self.InputImage.meta.axistags]

        def blockShapeForView(slicedAxis):
            extents = {'t': 1, 'z': width, 'y': width, 'x': width, 'c': 1}
            extents[slicedAxis] = depth
            return tuple(extents[axisKey] for axisKey in axisKeys)

        # One block shape per slicing view, in the order x, y, z
        viewBlockShapes = (blockShapeForView('x'),
                           blockShapeForView('y'),
                           blockShapeForView('z'))

        self.opWatershedCache.innerBlockShape.setValue(viewBlockShapes)
        self.opWatershedCache.outerBlockShape.setValue(viewBlockShapes)

        # Seed cache has same shape as watershed cache
        self.opSeedCache.innerBlockShape.setValue(viewBlockShapes)
        self.opSeedCache.outerBlockShape.setValue(viewBlockShapes)

        # For now watershed labels always come from the X-Y slicing view
        if len(self.opWatershedCache.InnerOutputs) > 0:
            self.WatershedLabels.connect(self.opWatershedCache.InnerOutputs[2])

        if not self.SeedThresholdValue.ready():
            self.opWatershed.SeedImage.disconnect()
            self.opThreshold.Function.disconnect()
            return

        seedThreshold = self.SeedThresholdValue.value
        seedSetupOutdated = (not self.opWatershed.SeedImage.connected()
                             or seedThreshold != self._seedThreshold)
        if seedSetupOutdated:
            self._seedThreshold = seedThreshold

            def markSeedPixels(a):
                # Pixels at or below the threshold become 1, all others 0.
                return (a <= seedThreshold).astype(numpy.uint8)

            self.opThreshold.Function.setValue(markSeedPixels)
            self.opWatershed.SeedImage.connect(self.opSeedFilter.Output)

    def propagateDirty(self, slot, subindex, roi):
        """Nothing to do here: dirtiness reaches the outputs through the
        connected internal operators."""
        # All outputs are directly connected to internal operators
        pass
Example #9
0
class OpInputDataReader(Operator):
    """
    This operator can read input data of any supported type.
    The data format is determined from the file extension.
    """

    name = "OpInputDataReader"
    category = "Input"

    videoExts = ["ufmf", "mmf"]
    h5_n5_Exts = ["h5", "hdf5", "ilp", "n5"]
    n5Selection = [
        "json"
    ]  # n5 stores data in a directory, containing a json-file which we use to select the n5-file
    klbExts = ["klb"]
    npyExts = ["npy"]
    npzExts = ["npz"]
    rawExts = ["dat", "bin", "raw"]
    blockwiseExts = ["json"]
    tiledExts = ["json"]
    tiffExts = ["tif", "tiff"]
    vigraImpexExts = vigra.impex.listExtensions().split()

    SupportedExtensions = (h5_n5_Exts + n5Selection + npyExts + npzExts +
                           rawExts + vigraImpexExts + blockwiseExts +
                           videoExts + klbExts)

    if _supports_dvid:
        dvidExts = ["dvidvol"]
        SupportedExtensions += dvidExts

    if _supports_h5blockreader:
        h5blockstoreExts = ["json"]
        SupportedExtensions += h5blockstoreExts

    # FilePath is inspected to determine data type.
    # For hdf5 files, append the internal path to the filepath,
    #  e.g. /mydir/myfile.h5/internal/path/to/dataset
    # For stacks, provide a globstring, e.g. /mydir/input*.png
    # Other types are determined via file extension
    WorkingDirectory = InputSlot(stype="filestring", optional=True)
    FilePath = InputSlot(stype="filestring")
    SequenceAxis = InputSlot(optional=True)

    # FIXME: Document this.
    SubVolumeRoi = InputSlot(optional=True)  # (start, stop)

    Output = OutputSlot()

    loggingName = __name__ + ".OpInputDataReader"
    logger = logging.getLogger(loggingName)

    class DatasetReadError(Exception):
        """Raised by the ``_attemptOpenAs*`` methods when a dataset cannot be read."""

    def __init__(self, *args, **kwargs):
        """Create the reader with no internal operators and no open file."""
        super(OpInputDataReader, self).__init__(*args, **kwargs)
        # Both are filled in by setupOutputs() once an open function succeeds.
        self.internalOutput = None
        self.internalOperators = []
        # Handle of the h5/n5 file opened by _attemptOpenAsH5N5(), if any.
        self._file = None

    def cleanUp(self):
        """Run the base-class clean-up, then close the file handle if one is held."""
        super(OpInputDataReader, self).cleanUp()
        if self._file is None:
            return
        self._file.close()
        self._file = None

    def setupOutputs(self):
        """
        Inspect the file name and instantiate and connect an internal operator of the appropriate type.

        A relative, non-URL ``FilePath`` is resolved against ``WorkingDirectory``;
        in that case nothing is configured until ``WorkingDirectory`` is ready.
        Previously created internal operators (and a previously opened file) are
        released, then every ``_attemptOpenAs*`` method is tried in order until
        one of them returns internal operators.

        Raises:
            RuntimeError: if no open function produced an output slot.

        TODO: Handle datasets of non-standard (non-5d) dimensions.
        """
        filePath = self.FilePath.value
        # ``unicode`` does not exist on Python 3; ``str`` covers all text there.
        assert isinstance(
            filePath,
            str), "Error: filePath is not of type str.  It's of type {}".format(
                type(filePath))

        # Does this look like a relative path?
        useRelativePath = not isUrl(filePath) and not os.path.isabs(filePath)

        if useRelativePath:
            # If using a relative path, we need both inputs before proceeding
            if not self.WorkingDirectory.ready():
                return
            # Convert this relative path into an absolute path
            filePath = os.path.normpath(
                os.path.join(self.WorkingDirectory.value,
                             filePath)).replace("\\", "/")

        # Clean up before reconfiguring
        if self.internalOperators:
            self.Output.disconnect()
            self.opInjector.cleanUp()
            for op in self.internalOperators[::-1]:
                op.cleanUp()
            self.internalOperators = []
            self.internalOutput = None
        if self._file is not None:
            self._file.close()
            # Forget the handle so that a later setupOutputs()/cleanUp() does
            # not try to close the same file a second time.
            self._file = None

        openFuncs = [
            self._attemptOpenAsKlb,
            self._attemptOpenAsUfmf,
            self._attemptOpenAsMmf,
            self._attemptOpenAsRESTfulPrecomputedChunkedVolume,
            self._attemptOpenAsDvidVolume,
            self._attemptOpenAsH5N5Stack,
            self._attemptOpenAsTiffStack,
            self._attemptOpenAsStack,
            self._attemptOpenAsH5N5,
            self._attemptOpenAsNpy,
            self._attemptOpenAsRawBinary,
            self._attemptOpenAsTiledVolume,
            self._attemptOpenAsH5BlockStore,
            self._attemptOpenAsBlockwiseFileset,
            self._attemptOpenAsRESTfulBlockwiseFileset,
            self._attemptOpenAsBigTiff,
            self._attemptOpenAsTiff,
            self._attemptOpenWithVigraImpex,
        ]

        # Try every method of opening the file until one works.
        for openFunc in openFuncs:
            self.internalOperators, self.internalOutput = openFunc(filePath)
            if self.internalOperators:
                break

        if self.internalOutput is None:
            raise RuntimeError("Can't read " + filePath +
                               " because it has an unrecognized format.")

        # If we've got a ROI, append a subregion operator.
        if self.SubVolumeRoi.ready():
            self._opSubRegion = OpSubRegion(parent=self)
            self._opSubRegion.Roi.setValue(self.SubVolumeRoi.value)
            self._opSubRegion.Input.connect(self.internalOutput)
            self.internalOutput = self._opSubRegion.Output

        self.opInjector = OpMetadataInjector(parent=self)
        self.opInjector.Input.connect(self.internalOutput)

        # Add metadata for estimated RAM usage if the internal operator didn't already provide it.
        # NOTE(review): the key used here was previously the garbled
        # "ram_per_pixelram_usage_per_requested_pixel"; it is restored to
        # lazyflow's "ram_usage_per_requested_pixel" -- confirm against consumers.
        if self.internalOutput.meta.ram_usage_per_requested_pixel is None:
            ram_per_pixel = self.internalOutput.meta.dtype().nbytes
            taggedShape = self.internalOutput.meta.getTaggedShape()
            if "c" in taggedShape:
                ram_per_pixel *= taggedShape["c"]
            self.opInjector.Metadata.setValue(
                {"ram_usage_per_requested_pixel": ram_per_pixel})
        else:
            # Nothing to add
            self.opInjector.Metadata.setValue({})

        # Directly connect our own output to the internal output
        self.Output.connect(self.opInjector.Output)

    def _attemptOpenAsKlb(self, filePath):
        """
        Try to open ``filePath`` as a KLB file.

        Returns:
            ``([opReader], opReader.Output)`` when the extension is ``.klb``
            (compared case-insensitively), otherwise ``([], None)``.
        """
        if os.path.splitext(filePath)[1].lower() != ".klb":
            return ([], None)

        opReader = OpKlbReader(parent=self)
        opReader.FilePath.setValue(filePath)
        # The first element must be a *list* of operators: setupOutputs()
        # unpacks this pair and later slices/iterates the list when cleaning up.
        return ([opReader], opReader.Output)

    def _attemptOpenAsMmf(self, filePath):
        """
        Try to open ``filePath`` with the streaming MMF reader.

        Returns:
            ``([mmfReader], mmfReader.Output)`` when ``".mmf"`` occurs anywhere
            in ``filePath``, otherwise ``([], None)``.
        """
        if ".mmf" not in filePath:
            return ([], None)

        mmfReader = OpStreamingMmfReader(parent=self)
        mmfReader.FileName.setValue(filePath)

        # The reader's output is handed back as-is; no frame cache is inserted.
        return ([mmfReader], mmfReader.Output)

    def _attemptOpenAsUfmf(self, filePath):
        """
        Try to open ``filePath`` with the streaming UFMF reader.

        Returns:
            ``([ufmfReader], ufmfReader.Output)`` when ``".ufmf"`` occurs anywhere
            in ``filePath``, otherwise ``([], None)``.
        """
        if ".ufmf" not in filePath:
            return ([], None)

        ufmfReader = OpStreamingUfmfReader(parent=self)
        ufmfReader.FileName.setValue(filePath)

        # The reader's output is handed back as-is; no frame cache is inserted.
        return ([ufmfReader], ufmfReader.Output)

    def _attemptOpenAsRESTfulPrecomputedChunkedVolume(self, filePath):
        """
        Try to open ``filePath`` as a ``precomputed://`` chunked volume.

        Returns:
            ``([reader], reader.Output)`` when ``filePath`` starts with
            ``precomputed://`` (compared case-insensitively), with the remainder
            of the path used as the reader's base URL; otherwise ``([], None)``.
        """
        prefix = "precomputed://"
        if not filePath.lower().startswith(prefix):
            return ([], None)

        # Remove the scheme by length.  str.lstrip() treats its argument as a
        # *set of characters*, so it could also strip leading characters of the
        # real URL (e.g. a host name starting with "d", "e", "m", "o", ...).
        url = filePath[len(prefix):]
        reader = OpRESTfulPrecomputedChunkedVolumeReader(parent=self)
        reader.BaseUrl.setValue(url)
        return [reader], reader.Output

    def _attemptOpenAsH5N5Stack(self, filePath):
        """
        Try to open ``filePath`` as a stack of h5/n5 datasets.

        Only paths containing ``*`` or ``os.path.pathsep`` are considered.  The
        ``checkGlobString`` methods of the single-file and multi-file sequence
        readers decide which of the two readers is used.

        Returns:
            ``([opReader], opReader.OutputImage)`` on success, ``([], None)`` if
            the path is not a glob string, is not accepted by either reader, or
            is rejected with ``WrongFileTypeError``.
        """
        if not ("*" in filePath or os.path.pathsep in filePath):
            return ([], None)

        # Now use the .checkGlobString method of the stack readers
        isSingleFile = True
        try:
            OpStreamingH5N5SequenceReaderS.checkGlobString(filePath)
        except OpStreamingH5N5SequenceReaderS.WrongFileTypeError:
            return ([], None)
        except (
                OpStreamingH5N5SequenceReaderS.NoInternalPlaceholderError,
                OpStreamingH5N5SequenceReaderS.NotTheSameFileError,
                OpStreamingH5N5SequenceReaderS.ExternalPlaceholderError,
        ):
            isSingleFile = False

        isMultiFile = True
        try:
            OpStreamingH5N5SequenceReaderM.checkGlobString(filePath)
        except (
                OpStreamingH5N5SequenceReaderM.NoExternalPlaceholderError,
                OpStreamingH5N5SequenceReaderM.SameFileError,
                OpStreamingH5N5SequenceReaderM.InternalPlaceholderError,
        ):
            isMultiFile = False

        assert not (isMultiFile and isSingleFile)

        if isSingleFile:
            opReader = OpStreamingH5N5SequenceReaderS(parent=self)
        elif isMultiFile:
            opReader = OpStreamingH5N5SequenceReaderM(parent=self)
        else:
            # Neither reader accepts this glob string.  Without this branch
            # ``opReader`` would be unbound below (UnboundLocalError).
            return ([], None)

        try:
            opReader.SequenceAxis.connect(self.SequenceAxis)
            opReader.GlobString.setValue(filePath)
            return ([opReader], opReader.OutputImage)
        except (OpStreamingH5N5SequenceReaderM.WrongFileTypeError,
                OpStreamingH5N5SequenceReaderS.WrongFileTypeError):
            # Release the half-configured reader before the next open function
            # is tried, as the other open functions in this class do.
            opReader.cleanUp()
            return ([], None)

    def _attemptOpenAsTiffStack(self, filePath):
        """
        Try to open ``filePath`` as a stack of tiff files.

        Only paths containing ``*`` or ``os.path.pathsep`` are considered.

        Returns:
            ``([opReader], opReader.Output)`` on success, ``([], None)`` if the
            path is not a glob string or the reader raises
            ``WrongFileTypeError``.
        """
        if not ("*" in filePath or os.path.pathsep in filePath):
            return ([], None)

        opReader = OpTiffSequenceReader(parent=self)
        try:
            opReader.SequenceAxis.connect(self.SequenceAxis)
            opReader.GlobString.setValue(filePath)
            return ([opReader], opReader.Output)
        except OpTiffSequenceReader.WrongFileTypeError:
            # Release the rejected reader before the next open function is
            # tried, as the other open functions in this class do.
            opReader.cleanUp()
            return ([], None)

    def _attemptOpenAsStack(self, filePath):
        """
        Try to open ``filePath`` as an image stack via ``OpStackLoader``.

        Returns:
            ``([stackReader], stackReader.stack)`` when ``filePath`` contains
            ``*`` or ``os.path.pathsep``, otherwise ``([], None)``.
        """
        looksLikeStack = "*" in filePath or os.path.pathsep in filePath
        if not looksLikeStack:
            return ([], None)

        stackReader = OpStackLoader(parent=self)
        stackReader.SequenceAxis.connect(self.SequenceAxis)
        stackReader.globstring.setValue(filePath)
        return ([stackReader], stackReader.stack)

    def _attemptOpenAsH5N5(self, filePath):
        """
        Try to open ``filePath`` as a single hdf5/n5 dataset.

        The path is split into an external (file) path and an internal (dataset)
        path.  When no internal path is given and the file holds exactly one
        dataset, that dataset is used.  The opened file is stored in
        ``self._file``.

        Returns:
            ``([h5N5Reader], h5N5Reader.OutputImage)``, or ``([], None)`` when the
            extension is not one of ``h5_n5_Exts``.

        Raises:
            OpInputDataReader.DatasetReadError: if the file does not exist,
                cannot be opened, holds no dataset, holds several datasets while
                no internal path was given, or the dataset cannot be read.
        """
        # Check for an hdf5 or n5 extension
        components = PathComponents(filePath)
        if components.extension[1:] not in OpInputDataReader.h5_n5_Exts:
            return [], None

        externalPath = components.externalPath
        internalPath = components.internalPath

        if not os.path.exists(externalPath):
            raise OpInputDataReader.DatasetReadError(
                "Input file does not exist: " + externalPath)

        # Open the h5/n5 file in read-only mode
        try:
            h5N5File = OpStreamingH5N5Reader.get_h5_n5_file(externalPath, "r")
        except OpInputDataReader.DatasetReadError:
            raise
        except Exception as e:
            msg = "Unable to open H5/N5 File: {}\n{}".format(
                externalPath, str(e))
            raise OpInputDataReader.DatasetReadError(msg) from e

        if not internalPath:
            # No dataset named explicitly: accept the file only if the choice
            # is unambiguous.
            candidates = lsH5N5(h5N5File)
            datasetCount = len(candidates)
            if datasetCount == 1:
                internalPath = candidates[0]["name"]
            else:
                h5N5File.close()
                if datasetCount == 0:
                    msg = "H5/N5 file contains no datasets: {}".format(
                        externalPath)
                else:
                    msg = (
                        "When using hdf5/n5, you must append the hdf5 internal path to the "
                        "data set to your filename, e.g. myfile.h5/volume/data  "
                        "No internal path provided for dataset in file: {}".
                        format(externalPath))
                raise OpInputDataReader.DatasetReadError(msg)

        try:
            compression_setting = h5N5File[internalPath].compression
        except Exception as e:
            h5N5File.close()
            msg = "Error reading H5/N5 File: {}\n{}".format(externalPath, e)
            raise OpInputDataReader.DatasetReadError(msg) from e

        # If the h5 dataset is compressed, we'll have better performance
        #  with a multi-process hdf5 access object.
        # (Otherwise, single-process is faster.)
        allow_multiprocess_hdf5 = os.environ.get(
            "LAZYFLOW_MULTIPROCESS_HDF5", "") != ""
        wantsMultiProcess = (compression_setting is not None
                             and allow_multiprocess_hdf5
                             and isinstance(h5N5File, h5py.File))
        if wantsMultiProcess:
            h5N5File.close()
            h5N5File = MultiProcessHdf5File(externalPath, "r")

        # From here on the handle is owned by this operator.
        self._file = h5N5File

        h5N5Reader = OpStreamingH5N5Reader(parent=self)
        h5N5Reader.H5N5File.setValue(h5N5File)

        try:
            h5N5Reader.InternalPath.setValue(internalPath)
        except OpStreamingH5N5Reader.DatasetReadError as e:
            msg = "Error reading H5/N5 File: {}\n{}".format(
                externalPath, e.msg)
            raise OpInputDataReader.DatasetReadError(msg) from e

        return ([h5N5Reader], h5N5Reader.OutputImage)

    def _attemptOpenAsNpy(self, filePath):
        """
        Try to open ``filePath`` as a numpy ``.npy``/``.npz`` file.

        Returns:
            ``([npyReader], npyReader.Output)``, or ``([], None)`` when the
            extension is not one of ``npyExts``/``npzExts``.

        Raises:
            OpInputDataReader.DatasetReadError: if the file does not exist or the
                numpy reader reports a read error.
        """
        components = PathComponents(filePath)
        recognizedExts = [
            ".%s" % x
            for x in OpInputDataReader.npyExts + OpInputDataReader.npzExts
        ]
        if components.extension not in recognizedExts:
            return ([], None)

        externalPath = components.externalPath
        internalPath = components.internalPath
        # FIXME: check whether path is valid?!

        if not os.path.exists(externalPath):
            raise OpInputDataReader.DatasetReadError(
                "Input file does not exist: " + externalPath)

        try:
            # Create an internal operator
            npyReader = OpNpyFileReader(parent=self)
            if internalPath is not None:
                # The reader gets the internal path with every "/" removed.
                internalPath = internalPath.replace("/", "")
            npyReader.InternalPath.setValue(internalPath)
            npyReader.FileName.setValue(externalPath)
        except OpNpyFileReader.DatasetReadError as e:
            raise OpInputDataReader.DatasetReadError(*e.args) from e
        return ([npyReader], npyReader.Output)

    def _attemptOpenAsRawBinary(self, filePath):
        """
        Try to open ``filePath`` as a raw binary file.

        Returns:
            ``([opReader], opReader.Output)``, or ``([], None)`` when the
            lower-cased extension is not one of ``rawExts``.

        Raises:
            OpInputDataReader.DatasetReadError: if the raw reader reports a read
                error.
        """
        # Lower-cased extension without its leading dot
        extension = os.path.splitext(filePath)[1].lower().lstrip(".")

        # Check for a raw-binary extension
        if extension not in OpInputDataReader.rawExts:
            return ([], None)

        try:
            # Create an internal operator
            opReader = OpRawBinaryFileReader(parent=self)
            opReader.FilePath.setValue(filePath)
        except OpRawBinaryFileReader.DatasetReadError as e:
            raise OpInputDataReader.DatasetReadError(*e.args) from e
        return ([opReader], opReader.Output)

    def _attemptOpenAsH5BlockStore(self, filePath):
        """
        Try to open ``filePath`` as the json index of an H5 block store.

        Several open functions in this class claim ``.json`` files, so a failure
        here means "not this format": the reader is cleaned up and ``([], None)``
        is returned so that the remaining open functions still get their turn.

        Returns:
            ``([op], op.Output)`` on success, ``([], None)`` when the extension
            is not ``.json``, block-store support is unavailable, or the reader
            fails to load the file.
        """
        if os.path.splitext(filePath)[1].lower() != ".json":
            return ([], None)

        # NOTE(review): the reader class is presumably only importable when
        # this flag is set (the class body only registers the extension then).
        if not _supports_h5blockreader:
            return ([], None)

        op = OpH5BlockStoreReader(parent=self)
        try:
            # For now, there is no explicit schema validation for the json file,
            # but H5BlockStore constructor will fail to load the json.
            op.IndexFilepath.setValue(filePath)
        except Exception:
            op.cleanUp()
            return ([], None)
        return [op], op.Output

    def _attemptOpenAsDvidVolume(self, filePath):
        """
        Two ways to specify a dvid volume.
        1) via a file that contains the hostname, uuid, and dataset name (1 per line)
        2) as a url, e.g. http://localhost:8000/api/node/uuid/dataname

        Returns:
            ``([op], op.Output)`` for a DVID volume (or, if constructing the
            volume fails, a DVID roi), or ``([], None)`` when ``filePath`` is
            neither a ``.dvidvol`` file nor a URL.

        Raises:
            OpInputDataReader.DatasetReadError: if ``filePath`` is a URL that does
                not match the DVID node format, or the DVID operator reports a
                read error.
        """
        if os.path.splitext(filePath)[1] == ".dvidvol":
            with open(filePath) as f:
                hostname, uuid, dataname = f.read().splitlines()
            opDvidVolume = OpDvidVolume(hostname, uuid, dataname, parent=self)
            return [opDvidVolume], opDvidVolume.Output

        if "://" not in filePath:
            return ([], None)  # not a url

        # Turn every field name of the template into a named regex group.
        url_format = "^protocol://hostname/api/node/uuid/dataname(\\?query_string)?"
        for field in [
                "protocol", "hostname", "uuid", "dataname", "query_string"
        ]:
            url_format = url_format.replace(field, "(?P<" + field + ">[^?]+)")
        match = re.match(url_format, filePath)
        if not match:
            # DVID is the only url-based format we support right now.
            # So if it looks like the user gave a URL that isn't a valid DVID node, then error.
            raise OpInputDataReader.DatasetReadError(
                "Invalid URL format for DVID: {}".format(filePath))

        fields = match.groupdict()
        try:
            query_string = fields["query_string"]
            query_args = {}
            if query_string:
                # Split on the first "=" only, so values may contain "=".
                query_args = dict(
                    s.split("=", 1) for s in query_string.split("&"))
            try:
                opDvidVolume = OpDvidVolume(fields["hostname"],
                                            fields["uuid"],
                                            fields["dataname"],
                                            query_args,
                                            parent=self)
                return [opDvidVolume], opDvidVolume.Output
            except Exception:
                # Maybe this is actually a roi.
                # (``Exception`` rather than a bare ``except`` so that
                # KeyboardInterrupt/SystemExit are not swallowed.)
                opDvidRoi = OpDvidRoi(fields["hostname"],
                                      fields["uuid"],
                                      fields["dataname"],
                                      parent=self)
                return [opDvidRoi], opDvidRoi.Output
        except OpDvidVolume.DatasetReadError as e:
            raise OpInputDataReader.DatasetReadError(*e.args) from e

    def _attemptOpenAsBlockwiseFileset(self, filePath):
        """
        Try to open ``filePath`` as the json description of a blockwise fileset.

        Returns:
            ``([opReader], opReader.Output)`` on success, ``([], None)`` when the
            extension is not one of ``blockwiseExts`` or the json has the wrong
            schema.

        Raises:
            OpInputDataReader.DatasetReadError: if the reader raises
                ``MissingDatasetError``.
        """
        # Lower-cased extension without its leading dot
        extension = os.path.splitext(filePath)[1].lower().lstrip(".")
        if extension not in OpInputDataReader.blockwiseExts:
            return ([], None)

        opReader = OpBlockwiseFilesetReader(parent=self)
        try:
            # This will raise a SchemaError if this is the wrong type of json config.
            opReader.DescriptionFilePath.setValue(filePath)
        except JsonConfigParser.SchemaError:
            opReader.cleanUp()
            return ([], None)
        except OpBlockwiseFilesetReader.MissingDatasetError as e:
            raise OpInputDataReader.DatasetReadError(*e.args) from e
        return ([opReader], opReader.Output)

    def _attemptOpenAsRESTfulBlockwiseFileset(self, filePath):
        """
        Try to open ``filePath`` as the json description of a RESTful blockwise fileset.

        Returns:
            ``([opReader], opReader.Output)`` on success, ``([], None)`` when the
            extension is not one of ``blockwiseExts`` or the json has the wrong
            schema.

        Raises:
            OpInputDataReader.DatasetReadError: if the reader raises
                ``MissingDatasetError``.
        """
        # Lower-cased extension without its leading dot
        extension = os.path.splitext(filePath)[1].lower().lstrip(".")
        if extension not in OpInputDataReader.blockwiseExts:
            return ([], None)

        opReader = OpRESTfulBlockwiseFilesetReader(parent=self)
        try:
            # This will raise a SchemaError if this is the wrong type of json config.
            opReader.DescriptionFilePath.setValue(filePath)
        except JsonConfigParser.SchemaError:
            opReader.cleanUp()
            return ([], None)
        except OpRESTfulBlockwiseFilesetReader.MissingDatasetError as e:
            raise OpInputDataReader.DatasetReadError(*e.args) from e
        return ([opReader], opReader.Output)

    def _attemptOpenAsTiledVolume(self, filePath):
        """
        Try to open ``filePath`` as the json description of a tiled volume.

        Returns:
            ``([opReader], opReader.SpecifiedOutput)`` on success, ``([], None)``
            when the extension is not one of ``tiledExts`` or the json has the
            wrong schema.
        """
        # Lower-cased extension without its leading dot
        extension = os.path.splitext(filePath)[1].lower().lstrip(".")
        if extension not in OpInputDataReader.tiledExts:
            return ([], None)

        opReader = OpCachedTiledVolumeReader(parent=self)
        try:
            # This will raise a SchemaError if this is the wrong type of json config.
            opReader.DescriptionFilePath.setValue(filePath)
        except JsonConfigParser.SchemaError:
            opReader.cleanUp()
            return ([], None)
        return ([opReader], opReader.SpecifiedOutput)

    def _attemptOpenAsBigTiff(self, filePath):
        """
        Try to open ``filePath`` with the BigTiff reader.

        Returns:
            ``([opReader], opReader.Output)`` on success, ``([], None)`` when
            BigTiff support is unavailable, the extension is not one of
            ``tiffExts``, or the reader raises ``NotBigTiffError``.

        Raises:
            OpInputDataReader.DatasetReadError: if the file does not exist.
        """
        if not _supports_bigtiff:
            return ([], None)

        # Lower-cased extension without its leading dot
        extension = os.path.splitext(filePath)[1].lower().lstrip(".")
        if extension not in OpInputDataReader.tiffExts:
            return ([], None)

        if not os.path.exists(filePath):
            raise OpInputDataReader.DatasetReadError(
                "Input file does not exist: " + filePath)

        opReader = OpBigTiffReader(parent=self)
        try:
            opReader.Filepath.setValue(filePath)
        except OpBigTiffReader.NotBigTiffError:
            opReader.cleanUp()
            return ([], None)
        return ([opReader], opReader.Output)

    def _attemptOpenAsTiff(self, filePath):
        """
        Try to open ``filePath`` with the tiff reader, followed by a block cache.

        The cache's block shape is the reader output's ``ideal_blockshape``.

        Returns:
            ``([opReader, opCache], opCache.Output)``, or ``([], None)`` when the
            extension is not one of ``tiffExts``.

        Raises:
            OpInputDataReader.DatasetReadError: if the file does not exist.
        """
        # Lower-cased extension without its leading dot
        extension = os.path.splitext(filePath)[1].lower().lstrip(".")
        if extension not in OpInputDataReader.tiffExts:
            return ([], None)

        if not os.path.exists(filePath):
            raise OpInputDataReader.DatasetReadError(
                "Input file does not exist: " + filePath)

        tiffReader = OpTiffReader(parent=self)
        tiffReader.Filepath.setValue(filePath)
        pageShape = tiffReader.Output.meta.ideal_blockshape

        # Cache the pages we read
        pageCache = OpBlockedArrayCache(parent=self)
        pageCache.fixAtCurrent.setValue(False)
        pageCache.BlockShape.setValue(pageShape)
        pageCache.Input.connect(tiffReader.Output)

        return ([tiffReader, pageCache], pageCache.Output)

    def _attemptOpenWithVigraImpex(self, filePath):
        """
        Try to open ``filePath`` with the vigra-based image reader and cache it.

        The cache block covers the whole image, except that an image with a
        ``z`` axis is cached one z-slice at a time.

        Returns:
            ``([vigraReader, imageCache], imageCache.Output)``, or ``([], None)``
            when the extension is not one of ``vigraImpexExts``.

        Raises:
            OpInputDataReader.DatasetReadError: if the file does not exist.
        """
        # Lower-cased extension without its leading dot
        extension = os.path.splitext(filePath)[1].lower().lstrip(".")
        if extension not in OpInputDataReader.vigraImpexExts:
            return ([], None)

        if not os.path.exists(filePath):
            raise OpInputDataReader.DatasetReadError(
                "Input file does not exist: " + filePath)

        vigraReader = OpImageReader(parent=self)
        vigraReader.Filename.setValue(filePath)

        # Cache the image instead of reading the hard disk for every access.
        imageCache = OpBlockedArrayCache(parent=self)
        imageCache.Input.connect(vigraReader.Image)

        imageMeta = vigraReader.Image.meta
        # 2D: Just one block for the whole image
        cacheBlockShape = imageMeta.shape

        taggedShape = imageMeta.getTaggedShape()
        if "z" in taggedShape:
            # 3D: blocksize is one slice.
            taggedShape["z"] = 1
            cacheBlockShape = tuple(taggedShape.values())

        imageCache.fixAtCurrent.setValue(False)
        imageCache.BlockShape.setValue(cacheBlockShape)
        assert imageCache.Output.ready()

        return ([vigraReader, imageCache], imageCache.Output)

    def execute(self, slot, subindex, roi, result):
        """
        Always fail: the output is connected directly to an internal operator.

        Raises:
            AssertionError: unconditionally.  It is raised explicitly (rather
                than via ``assert False``) so the check also fires under ``-O``.
        """
        raise AssertionError(
            "Shouldn't get here because our output is directly connected...")

    def propagateDirty(self, slot, subindex, roi):
        """Do nothing: output slots are directly connected to internal operators."""
        return None
Example #10
0
class OpEdgeTraining(Operator):
    # Shared across lanes
    DEFAULT_FEATURES = {"Grayscale": ['standard_edge_mean']}
    FeatureNames = InputSlot(value=DEFAULT_FEATURES)
    FreezeClassifier = InputSlot(value=True)

    # Lane-wise
    EdgeLabelsDict = InputSlot(level=1, value={})
    VoxelData = InputSlot(level=1)
    Superpixels = InputSlot(level=1)
    GroundtruthSegmentation = InputSlot(level=1, optional=True)
    RawData = InputSlot(level=1,
                        optional=True)  # Used by the GUI for display only

    Rag = OutputSlot(level=1)
    EdgeProbabilities = OutputSlot(level=1)
    EdgeProbabilitiesDict = OutputSlot(
        level=1)  # A dict of id_pair -> probabilities
    NaiveSegmentation = OutputSlot(level=1)

    def __init__(self, *args, **kwargs):
        """
        Build the internal operator pipeline and connect it to this operator's slots.

        The pipeline is: rag creation -> edge feature computation -> classifier
        training -> edge probability prediction -> probability dict and naive
        segmentation, with a cache after each stage.  All input multi-slots are
        additionally kept the same size, and every lane of ``Superpixels`` is
        watched so that its edge labels are discarded when it changes.
        """
        super(OpEdgeTraining, self).__init__(*args, **kwargs)

        self.opCreateRag = OpMultiLaneWrapper(OpCreateRag, parent=self)
        self.opCreateRag.Superpixels.connect(self.Superpixels)

        self.opRagCache = OpMultiLaneWrapper(
            OpValueCache, parent=self, broadcastingSlotNames=['fixAtCurrent'])
        self.opRagCache.Input.connect(self.opCreateRag.Rag)
        self.opRagCache.name = 'opRagCache'

        self.opComputeEdgeFeatures = OpMultiLaneWrapper(
            OpComputeEdgeFeatures,
            parent=self,
            broadcastingSlotNames=['FeatureNames'])
        self.opComputeEdgeFeatures.FeatureNames.connect(self.FeatureNames)
        self.opComputeEdgeFeatures.VoxelData.connect(self.VoxelData)
        self.opComputeEdgeFeatures.Rag.connect(self.opRagCache.Output)

        self.opEdgeFeaturesCache = OpMultiLaneWrapper(
            OpValueCache, parent=self, broadcastingSlotNames=['fixAtCurrent'])
        self.opEdgeFeaturesCache.Input.connect(
            self.opComputeEdgeFeatures.EdgeFeaturesDataFrame)
        self.opEdgeFeaturesCache.name = 'opEdgeFeaturesCache'

        self.opTrainEdgeClassifier = OpTrainEdgeClassifier(parent=self)
        self.opTrainEdgeClassifier.EdgeLabelsDict.connect(self.EdgeLabelsDict)
        self.opTrainEdgeClassifier.EdgeFeaturesDataFrame.connect(
            self.opEdgeFeaturesCache.Output)

        # classifier cache input is set after training.
        self.opClassifierCache = OpValueCache(parent=self)
        self.opClassifierCache.Input.connect(
            self.opTrainEdgeClassifier.EdgeClassifier)
        self.opClassifierCache.fixAtCurrent.connect(self.FreezeClassifier)
        self.opClassifierCache.name = 'opClassifierCache'

        self.opPredictEdgeProbabilities = OpMultiLaneWrapper(
            OpPredictEdgeProbabilities,
            parent=self,
            broadcastingSlotNames=['EdgeClassifier'])
        self.opPredictEdgeProbabilities.EdgeClassifier.connect(
            self.opClassifierCache.Output)
        self.opPredictEdgeProbabilities.EdgeFeaturesDataFrame.connect(
            self.opEdgeFeaturesCache.Output)

        self.opEdgeProbabilitiesCache = OpMultiLaneWrapper(
            OpValueCache, parent=self, broadcastingSlotNames=['fixAtCurrent'])
        self.opEdgeProbabilitiesCache.Input.connect(
            self.opPredictEdgeProbabilities.EdgeProbabilities)
        self.opEdgeProbabilitiesCache.name = 'opEdgeProbabilitiesCache'
        self.opEdgeProbabilitiesCache.fixAtCurrent.connect(
            self.FreezeClassifier)

        self.opEdgeProbabilitiesDict = OpMultiLaneWrapper(
            OpEdgeProbabilitiesDict, parent=self)
        self.opEdgeProbabilitiesDict.Rag.connect(self.opRagCache.Output)
        self.opEdgeProbabilitiesDict.EdgeProbabilities.connect(
            self.opEdgeProbabilitiesCache.Output)

        self.opEdgeProbabilitiesDictCache = OpMultiLaneWrapper(
            OpValueCache, parent=self, broadcastingSlotNames=['fixAtCurrent'])
        self.opEdgeProbabilitiesDictCache.Input.connect(
            self.opEdgeProbabilitiesDict.EdgeProbabilitiesDict)
        self.opEdgeProbabilitiesDictCache.name = 'opEdgeProbabilitiesDictCache'

        self.opNaiveSegmentation = OpMultiLaneWrapper(OpNaiveSegmentation,
                                                      parent=self)
        self.opNaiveSegmentation.Superpixels.connect(self.Superpixels)
        self.opNaiveSegmentation.Rag.connect(self.opRagCache.Output)
        self.opNaiveSegmentation.EdgeProbabilities.connect(
            self.opEdgeProbabilitiesCache.Output)

        self.opNaiveSegmentationCache = OpMultiLaneWrapper(
            OpBlockedArrayCache,
            parent=self,
            broadcastingSlotNames=[
                'CompressionEnabled', 'fixAtCurrent', 'BypassModeEnabled'
            ])
        self.opNaiveSegmentationCache.CompressionEnabled.setValue(True)
        self.opNaiveSegmentationCache.Input.connect(
            self.opNaiveSegmentation.Output)
        self.opNaiveSegmentationCache.name = 'opNaiveSegmentationCache'

        self.Rag.connect(self.opRagCache.Output)
        self.EdgeProbabilities.connect(self.opEdgeProbabilitiesCache.Output)
        self.EdgeProbabilitiesDict.connect(
            self.opEdgeProbabilitiesDictCache.Output)
        self.NaiveSegmentation.connect(self.opNaiveSegmentationCache.Output)

        # All input multi-slots should be kept in sync
        # Output multi-slots will auto-sync via the graph
        #
        # This has to be a real list: it is walked by two nested loops, and on
        # Python 3 a lazy filter() object would be exhausted by the first inner
        # pass, leaving only the first slot subscribed to the others.
        multiInputs = [s for s in self.inputs.values() if s.level >= 1]

        # ``a`` is bound via partial() to the slot that must follow the change;
        # ``b`` is the slot that was resized and triggered the callback.
        def insertSlot(a, b, position, finalsize):
            a.insertSlot(position, finalsize)

        def removeSlot(a, b, position, finalsize):
            a.removeSlot(position, finalsize)

        for s1 in multiInputs:
            for s2 in multiInputs:
                if s1 != s2:
                    s1.notifyInserted(partial(insertSlot, s2))
                    s1.notifyRemoved(partial(removeSlot, s2))

        # If superpixels change, we have to delete our edge labels.
        # Since we're dealing with multi-lane slot, setting up dirty handlers is a two-stage process.
        # (1) React to lane insertion by subscribing to dirty signals for the new lane.
        # (2) React to each lane's dirty signal by deleting the labels for that lane.

        def subscribe_to_dirty_sp(slot, position, finalsize):
            # A new lane was added.  Subscribe to its dirty signal.
            assert slot is self.Superpixels
            self.Superpixels[position].notifyDirty(
                self.handle_dirty_superpixels)
            self.Superpixels[position].notifyReady(
                self.handle_dirty_superpixels)
            self.Superpixels[position].notifyUnready(
                self.handle_dirty_superpixels)

        # When a new lane is added, set up the listener for dirtyness.
        self.Superpixels.notifyInserted(subscribe_to_dirty_sp)

    def handle_dirty_superpixels(self, subslot, *args):
        """
        Discards the labels for a given lane.

        Args:
            subslot: the per-lane subslot of ``Superpixels`` that triggered the
                call; its index within ``self.Superpixels`` selects the lane.
            *args: ignored, so the method can serve as the callback for several
                kinds of slot notification.

        NOTE: In addition to callers in this file, this function is also called from multicutWorkflow.py
        """
        # Determine which lane triggered this and delete its labels
        lane_index = self.Superpixels.index(subslot)
        old_labels = self.EdgeLabelsDict[lane_index].value
        if old_labels:
            # Logger.warn() is a deprecated alias; warning() is the supported name.
            logger.warning(
                "Superpixels changed.  Deleting all labels in lane {}.".format(
                    lane_index))
            logger.info("Old labels were: {}".format(old_labels))
            self.EdgeLabelsDict[lane_index].setValue({})

    def setupOutputs(self):
        """
        Set each lane's naive-segmentation cache block shape to the full
        superpixel volume shape, after checking that the superpixels are
        ``uint32`` and have ``c`` as their last axis.
        """
        lanes = zip(self.Superpixels,
                    self.opNaiveSegmentationCache.outerBlockShape)
        for superpixels, blockshape in lanes:
            sp_meta = superpixels.meta
            assert sp_meta.dtype == np.uint32
            assert sp_meta.getAxisKeys()[-1] == 'c'
            blockshape.setValue(sp_meta.shape)

    def execute(self, slot, subindex, roi, result):
        """
        Always fail, naming the requested slot in the error message.

        Raises:
            AssertionError: unconditionally.  It is raised explicitly (rather
                than via ``assert False``) so the check also fires under ``-O``.
        """
        raise AssertionError(
            "Shouldn't get here, but requesting slot: {}".format(slot))

    def propagateDirty(self, slot, subindex, roi):
        """Do nothing when an input becomes dirty."""
        return None

    def setEdgeLabelsFromGroundtruth(self, lane_index):
        """
        For the given lane, read the ground truth volume and
        automatically determine edge label values.

        The ground truth is brought into the axis order of the lane's
        superpixels, with the channel axis dropped, before it is handed to the
        lane's rag.  Each edge decision (viewed as ``uint8``) plus one becomes
        the label of that edge, and the result is written to the lane's
        ``EdgeLabelsDict``.

        Raises:
            RuntimeError: if the lane's ``GroundtruthSegmentation`` is not ready.
        """
        lane = self.getLane(lane_index)

        if not lane.GroundtruthSegmentation.ready():
            raise RuntimeError(
                "There is no Ground Truth data available for lane: {}".format(
                    lane_index))

        logger.info("Loading groundtruth for lane {}...".format(lane_index))
        raw_gt = lane.GroundtruthSegmentation[:].wait()
        tagged_gt = vigra.taggedView(
            raw_gt, lane.GroundtruthSegmentation.meta.axistags)
        sp_axis_order = ''.join(
            tag.key for tag in lane.Superpixels.meta.axistags)
        gt_vol = tagged_gt.withAxes(sp_axis_order).dropChannelAxis()

        rag = lane.opRagCache.Output.value

        logger.info("Computing edge decisions from groundtruth...")
        decisions = rag.edge_decisions_from_groundtruth(gt_vol, asdict=False)
        edge_labels = decisions.view(np.uint8) + 1
        edge_labels_dict = {
            tuple(edge_id): label
            for edge_id, label in zip(rag.edge_ids, edge_labels)
        }
        lane.EdgeLabelsDict.setValue(edge_labels_dict)

    def addLane(self, laneIndex):
        """Append one lane by growing ``VoxelData``; ``laneIndex`` must be the current lane count."""
        currentCount = len(self.VoxelData)
        assert currentCount == laneIndex, "Image lanes must be appended."
        self.VoxelData.resize(currentCount + 1)

    def removeLane(self, laneIndex, finalLength):
        """Remove lane ``laneIndex`` from ``VoxelData``, passing ``finalLength`` through to the slot."""
        voxelData = self.VoxelData
        voxelData.removeSlot(laneIndex, finalLength)

    def getLane(self, laneIndex):
        """Return an ``OperatorSubView`` of this operator for lane ``laneIndex``."""
        laneView = OperatorSubView(self, laneIndex)
        return laneView
Example #11
0
class OpTrainEdgeClassifier(Operator):
    """
    Train one edge classifier from the labeled edges of all lanes.

    Inputs (one subslot per lane):
        EdgeLabelsDict: dict mapping an edge, given as a (sp1, sp2) pair, to
            its label.  Edges labeled 0 are ignored.
        EdgeFeaturesDataFrame: pandas DataFrame whose first two columns are
            'sp1' and 'sp2'; the remaining columns are the edge features.

    Output:
        EdgeClassifier: single-element object slot holding the trained
            classifier, or None while no nonzero labels exist.
    """
    EdgeLabelsDict = InputSlot(level=1)
    EdgeFeaturesDataFrame = InputSlot(level=1)

    EdgeClassifier = OutputSlot()

    def setupOutputs(self):
        """Declare the output as a single-element object slot."""
        self.EdgeClassifier.meta.shape = (1, )
        self.EdgeClassifier.meta.dtype = object

    def execute(self, slot, subindex, roi, result):
        """
        Gather labeled edges and their features from every lane, train a
        classifier on them and store it in result[0].

        result[0] is set to None when no lane provides a nonzero label.
        """
        per_lane_frames = []

        for lane_index, (labels_dict_slot, features_slot) in \
                enumerate(zip(self.EdgeLabelsDict, self.EdgeFeaturesDataFrame)):
            logger.info(
                "Retrieving features for lane {}...".format(lane_index))

            # Copy now to avoid threading issues.
            labels_dict = labels_dict_slot.value.copy()
            if not labels_dict:
                continue

            # Materialize keys and values as lists: on Python 3 they are
            # views, which numpy/pandas do not treat as sequences.
            edge_ids = list(labels_dict.keys())
            edge_labels = list(labels_dict.values())

            labels_df = pd.DataFrame(np.array(edge_ids),
                                     columns=['sp1', 'sp2'])
            labels_df['label'] = edge_labels

            # Drop zero labels
            labels_df = labels_df[labels_df['label'] != 0]
            if labels_df.empty:
                # Nothing usable in this lane; avoid contributing an empty
                # frame (and, if all lanes are like this, training on an
                # empty feature matrix).
                continue

            edge_features_df = features_slot.value
            assert list(edge_features_df.columns[0:2]) == ['sp1', 'sp2']

            # Merge in features
            per_lane_frames.append(
                pd.merge(edge_features_df,
                         labels_df,
                         how='right',
                         on=['sp1', 'sp2']))

        if not per_lane_frames:
            # No labels yet.
            result[0] = None
            return

        # One concat instead of repeated DataFrame.append(), which is
        # quadratic and no longer exists in pandas >= 2.0.
        all_features_and_labels_df = pd.concat(per_lane_frames)

        assert list(all_features_and_labels_df.columns[0:2]) == ['sp1', 'sp2']
        assert all_features_and_labels_df.columns[-1] == 'label'

        # Omit 'sp1', 'sp2', and 'label'
        feature_columns = all_features_and_labels_df.columns[2:-1]
        feature_matrix = all_features_and_labels_df.iloc[:, 2:-1].values
        labels = all_features_and_labels_df.iloc[:, -1].values

        logger.info("Training classifier with {} labels...".format(
            len(labels)))
        # TODO: Allow factory to be configured via an input slot
        classifier_factory = ParallelVigraRfLazyflowClassifierFactory()
        classifier = classifier_factory.create_and_train(
            feature_matrix,
            labels,
            feature_names=feature_columns.values)
        assert set(classifier.known_classes).issubset(set([1, 2]))
        result[0] = classifier

    def propagateDirty(self, slot, subindex, roi):
        """Any dirty input invalidates the trained classifier."""
        self.EdgeClassifier.setDirty()
Example #12
0
class OpCarving(Operator):
    """
    Operator for interactive carving: user-drawn seeds are written into an
    MST of a preprocessed graph, which computes the segmentation.
    """
    name = "Carving"
    category = "interactive segmentation"

    # I n p u t s #

    # MST of the preprocessed graph
    MST = InputSlot()

    # These three slots are for display only.
    # All computation is done with the MST.
    OverlayData = InputSlot(
        optional=True
    )  # Display-only: Available to the GUI in case the input data was preprocessed in some way but you still want to see the 'raw' data.
    InputData = InputSlot()  # The data used by preprocessing (display only)
    FilteredInputData = InputSlot()  # The output of the preprocessing filter

    # Write the seeds that the user draws into this slot
    WriteSeeds = InputSlot()

    # Trigger an update by writing into this slot
    Trigger = InputSlot(value=numpy.zeros((1, ), dtype=numpy.uint8))

    # Number between 0.0 and 1.0:
    # bias of the background
    # FIXME: correct name?
    BackgroundPriority = InputSlot(value=0.95)

    # NOTE: this is an output slot, although it is declared among the inputs.
    LabelNames = OutputSlot(stype='list')

    # A number between 0 and 256:
    # below this number, no background bias will be applied to the edge weights
    NoBiasBelow = InputSlot(value=64)

    # Uncertainty type
    UncertaintyType = InputSlot()

    # O u t p u t s #

    # Current object + background
    Segmentation = OutputSlot()

    Supervoxels = OutputSlot()

    Uncertainty = OutputSlot()

    # Array in which all objects done so far are labeled the same
    DoneObjects = OutputSlot()

    # Array with the labels of the objects done so far, one label for each
    # object
    DoneSegmentation = OutputSlot()

    CurrentObjectName = OutputSlot(stype='string')

    AllObjectNames = OutputSlot(rtype=List, stype=Opaque)

    # Whether the current object has an actual segmentation
    HasSegmentation = OutputSlot(stype='bool')

    # Hint overlay
    HintOverlay = OutputSlot()

    # Pmap overlay
    PmapOverlay = OutputSlot()

    MstOut = OutputSlot()

    def __init__(self,
                 graph=None,
                 hintOverlayFile=None,
                 pmapOverlayFile=None,
                 parent=None):
        """
        Set up the internal label array and MST cache and optionally load
        overlay volumes from HDF5 files.

        graph, parent: forwarded to the Operator base class.
        hintOverlayFile: optional path of an HDF5 file; its "/hints" dataset
            is loaded and served through the HintOverlay slot.
        pmapOverlayFile: optional path of an HDF5 file; its "/data" dataset
            is loaded and served through the PmapOverlay slot.

        Raises whatever h5py raises if the hint file cannot be opened, and
        RuntimeError if the pmap file cannot be opened.
        """
        super(OpCarving, self).__init__(graph=graph, parent=parent)
        self.opLabelArray = OpDenseLabelArray(parent=self)
        #self.opLabelArray.EraserLabelValue.setValue( 100 )
        self.opLabelArray.MetaInput.connect(self.InputData)

        self._hintOverlayFile = hintOverlayFile
        self._mst = None
        # Keeps track of whether or not there are seeds currently loaded,
        # either drawn by the user or loaded from a saved object.
        self.has_seeds = False

        self.LabelNames.setValue(["Background", "Object"])

        # Lookup tables for the supervoxels of finished and saved objects
        self._done_lut = None
        self._done_seg_lut = None
        self._hints = None
        self._pmap = None
        if hintOverlayFile is not None:
            try:
                f = h5py.File(hintOverlayFile, "r")
            except Exception:
                logger.info("Could not open hint overlay '%s'" %
                            hintOverlayFile)
                # Bare raise keeps the original traceback.
                raise
            # Read the whole dataset into memory and close the file again.
            # `[()]` replaces the `.value` property removed in h5py 3.
            with f:
                self._hints = f["/hints"][()][numpy.newaxis, :, :, :,
                                              numpy.newaxis]

        if pmapOverlayFile is not None:
            try:
                f = h5py.File(pmapOverlayFile, "r")
            except Exception:
                raise RuntimeError("Could not open pmap overlay '%s'" %
                                   pmapOverlayFile)
            with f:
                self._pmap = f["/data"][()][numpy.newaxis, :, :, :,
                                            numpy.newaxis]

        self._setCurrObjectName("<not saved yet>")
        self.HasSegmentation.setValue(False)

        # keep track of a set of object names that have changed since
        # the last serialization of this object to disk
        self._dirtyObjects = set()
        self.preprocessingApplet = None
        # connectToPreprocessingApplet() stores the applet under this
        # (capitalized) name, so make sure the attribute always exists.
        self.PreprocessingApplet = None

        self._opMstCache = OpValueCache(parent=self)
        self.MstOut.connect(self._opMstCache.Output)

        self.InputData.notifyReady(self._checkConstraints)

    def _checkConstraints(self, *args):
        slot = self.InputData
        numChannels = slot.meta.getTaggedShape()['c']
        if numChannels != 1:
            raise DatasetConstraintError(
                "Carving", "Input image must have exactly one channel.  " +
                "You attempted to add a dataset with {} channels".format(
                    numChannels))

        sh = slot.meta.shape
        ax = slot.meta.axistags
        if len(slot.meta.shape) != 5:
            # Raise a regular exception.  This error is for developers, not users.
            raise RuntimeError("was expecting a 5D dataset, got shape=%r" %
                               (sh, ))
        if slot.meta.getTaggedShape()['t'] != 1:
            raise DatasetConstraintError(
                "Carving",
                "Input image must not have more than one time slice.  " +
                "You attempted to add a dataset with {} time slices".format(
                    slot.meta.getTaggedShape()['t']))

        for i in range(1, 4):
            if not ax[i].isSpatial():
                # This is for developers.  Don't need a user-friendly error.
                raise RuntimeError("%d-th axis %r is not spatial" % (i, ax[i]))

    def _clearLabels(self):
        #clear the labels
        self.opLabelArray.DeleteLabel.setValue(2)
        self.opLabelArray.DeleteLabel.setValue(1)
        self.opLabelArray.DeleteLabel.setValue(-1)
        self.has_seeds = False

    def _setCurrObjectName(self, n):
        """
        Sets the current object name to n.
        """
        self._currObjectName = n
        self.CurrentObjectName.setValue(n)

    def _buildDone(self):
        """
        Rebuild the 'done' lookup tables, for example after saving or
        deleting an object.

        Both tables have ``numNodes + 1`` int32 entries and are filled from
        the supervoxels of every saved object except the one currently being
        edited:

        * ``_done_lut`` is incremented once per saved object for each of the
          object's supervoxels.
        * ``_done_seg_lut`` receives the object's number from
          ``object_names``; where objects overlap, the one visited last wins.

        Does nothing while no MST is available.
        """
        mst = self._mst
        if mst is None:
            return
        with Timer() as timer:
            lut_size = mst.numNodes + 1
            self._done_lut = numpy.zeros(lut_size, dtype=numpy.int32)
            self._done_seg_lut = numpy.zeros(lut_size, dtype=numpy.int32)
            logger.info("building 'done' luts")
            # items() works on Python 2 and 3; dict.iteritems() is gone in 3.
            for name, objectSupervoxels in mst.object_lut.items():
                if name == self._currObjectName:
                    continue
                # Check before touching the tables, so that a failure does
                # not leave them half-updated for this object.
                assert name in mst.object_names, "%s not in self._mst.object_names, keys are %r" % (
                    name, list(mst.object_names.keys()))
                self._done_lut[objectSupervoxels] += 1
                self._done_seg_lut[objectSupervoxels] = mst.object_names[name]
        logger.info("building the 'done' luts took {} seconds".format(
            timer.seconds()))

    def dataIsStorable(self):
        """
        Return True only if an MST is set and its node seeds contain at
        least one node with value 2 and at least one node with value 1.
        """
        mst = self._mst
        if mst is None:
            return False
        seeds = mst.gridSegmentor.getNodeSeeds()
        has_fg_seed = len(numpy.where(seeds == 2)[0]) > 0
        has_bg_seed = len(numpy.where(seeds == 1)[0]) > 0
        return has_fg_seed and has_bg_seed

    def setupOutputs(self):
        """
        Configure the metadata of the output slots from InputData.

        Segmentation (and the slots derived from it) become int32, the
        overlays keep the input metadata, Uncertainty becomes uint8, and
        AllObjectNames gets one entry per known object name (none while no
        MST is available).
        """
        input_meta = self.InputData.meta

        seg_meta = self.Segmentation.meta
        seg_meta.assignFrom(input_meta)
        seg_meta.dtype = numpy.int32

        # These follow the segmentation metadata, including its dtype.
        for derived in (self.Supervoxels, self.DoneObjects,
                        self.DoneSegmentation):
            derived.meta.assignFrom(seg_meta)

        # These follow the input metadata.
        for mirrored in (self.HintOverlay, self.PmapOverlay,
                         self.Uncertainty):
            mirrored.meta.assignFrom(input_meta)
        self.Uncertainty.meta.dtype = numpy.uint8

        trigger_meta = self.Trigger.meta
        trigger_meta.shape = (1, )
        trigger_meta.dtype = numpy.uint8

        if self._mst is None:
            object_count = 0
        else:
            object_count = len(self._mst.object_names.keys())
        names_meta = self.AllObjectNames.meta
        names_meta.shape = (object_count, )
        names_meta.dtype = object

    def connectToPreprocessingApplet(self, applet):
        """
        Remember the given applet in the PreprocessingApplet attribute.
        """
        setattr(self, "PreprocessingApplet", applet)

#     def updatePreprocessing(self):
#         if self.PreprocessingApplet is None or self._mst is None:
#             return
#FIXME: why were the following lines needed ?
# if len(self._mst.object_names)==0:
#     self.PreprocessingApplet.enableWriteprotect(True)
# else:
#     self.PreprocessingApplet.enableWriteprotect(False)

    def hasCurrentObject(self):
        """
        Return the name of the current object (whatever is stored in
        ``_currObjectName``), which callers use as a truth value.
        """
        #FIXME: Misleading name: "there is a current object" and "the
        #current object has a name" are not the same thing.
        current_name = self._currObjectName
        return current_name

    def currentObjectName(self):
        """
        Return the name of the current object.

        The stored name must not be None; that is checked with an assertion.
        """
        current_name = self._currObjectName
        assert current_name is not None, "FIXME: This function should either return '' or None.  Why does it sometimes return one and then the other?"
        return current_name

    def hasObjectWithName(self, name):
        """
        Tell whether the MST's object lookup table has an entry for name.
        """
        known_objects = self._mst.object_lut
        return name in known_objects

    def doneObjectNamesForPosition(self, position3d):
        """
        Return the names of all saved objects that contain the supervoxel at
        a 3D position.

        position3d: sequence of three integer coordinates into the MST's
            supervoxel volume.

        Returns a list of names; it is empty if no object contains the
        supervoxel.
        """
        assert len(position3d) == 3

        # Find the supervoxel that was clicked.  Indexing with a tuple
        # selects a single voxel even if the position arrives as a list or
        # an array (which would otherwise trigger fancy indexing).
        sv = self._mst.supervoxelUint32[tuple(position3d)]

        # items() works on Python 2 and 3; dict.iteritems() is gone in 3.
        names = [
            name
            for name, objectSupervoxels in self._mst.object_lut.items()
            if numpy.sum(sv == objectSupervoxels) > 0
        ]
        logger.info("click on %r, supervoxel=%d: %r" % (position3d, sv, names))
        return names

    @Operator.forbidParallelExecute
    def attachVoxelLabelsToObject(self, name, fgVoxels, bgVoxels):
        """
        Store the foreground and background seed voxels of the object
        called name in the MST.
        """
        mst = self._mst
        mst.object_seeds_fg_voxels[name] = fgVoxels
        mst.object_seeds_bg_voxels[name] = bgVoxels

    @Operator.forbidParallelExecute
    def clearCurrentLabeling(self, trigger_recompute=True):
        """
        Clear the current labeling: the labels in the label array and the
        seeds in the MST.

        trigger_recompute: when True (the default), additionally mark the
            Trigger slot dirty so that the carving is recomputed.  Pass
            False to clear without recomputing.
        """
        self._clearLabels()
        self._mst.gridSegmentor.clearSeeds()

        # Honor the flag; it used to be ignored and the recomputation was
        # triggered unconditionally.
        if trigger_recompute:
            self.Trigger.setDirty(slice(None))

    def loadObject_impl(self, name):
        """
        Make the saved object called name the currently edited object.

        Its seeds and its segmentation are restored in the MST, and it is
        taken out of the 'done' segmentation.

        Returns the tuple (foreground seed voxels, background seed voxels)
        as stored for the object.
        """
        mst = self._mst
        assert mst is not None
        logger.info("[OpCarving] load object %s (opCarving=%d, mst=%d)" %
                    (name, id(self), id(self._mst)))

        # Every per-object table has to know the object.
        for table in (mst.object_lut, mst.object_seeds_fg_voxels,
                      mst.object_seeds_bg_voxels, mst.bg_priority,
                      mst.no_bias_below):
            assert name in table

        # set foreground and background seeds
        fg_seed_voxels = mst.object_seeds_fg_voxels[name]
        bg_seed_voxels = mst.object_seeds_bg_voxels[name]
        mst.setSeeds(numpy.array(fg_seed_voxels), numpy.array(bg_seed_voxels))

        # load the actual segmentation
        fg_nodes = mst.object_lut[name]
        mst.setResulFgObj(fg_nodes[0])

        self._setCurrObjectName(name)
        self.HasSegmentation.setValue(True)

        # 'name' is no longer one of the finished objects, so the done
        # overlay has to be rebuilt.
        self._buildDone()
        return (fg_seed_voxels, bg_seed_voxels)

    def loadObject(self, name):
        """
        Load the saved object called name for editing.

        The current object (if any) is saved first and the labels are
        cleared.  Then the object's seeds are written back through the
        WriteSeeds slot (restricted to the bounding box of the seeds), its
        BackgroundPriority / NoBiasBelow values are restored and the
        Segmentation slot is marked dirty.

        Returns False if no object with that name exists, True otherwise.
        """
        logger.info("want to load object with name = %s" % name)
        if not self.hasObjectWithName(name):
            logger.info("  --> no such object '%s'" % name)
            return False

        if self.hasCurrentObject():
            self.saveCurrentObject()
        self._clearLabels()

        fgVoxels, bgVoxels = self.loadObject_impl(name)

        # List comprehensions instead of map(): on Python 3, numpy.array()
        # of a map object does not produce the per-axis extrema.
        fg_bounding_box_start = numpy.array([numpy.min(c) for c in fgVoxels])
        fg_bounding_box_stop = 1 + numpy.array(
            [numpy.max(c) for c in fgVoxels])

        bg_bounding_box_start = numpy.array([numpy.min(c) for c in bgVoxels])
        bg_bounding_box_stop = 1 + numpy.array(
            [numpy.max(c) for c in bgVoxels])

        bounding_box_start = numpy.minimum(fg_bounding_box_start,
                                           bg_bounding_box_start)
        bounding_box_stop = numpy.maximum(fg_bounding_box_stop,
                                          bg_bounding_box_stop)

        bounding_box_slicing = roiToSlice(bounding_box_start,
                                          bounding_box_stop)

        bounding_box_shape = tuple(bounding_box_stop - bounding_box_start)
        dtype = self.opLabelArray.Output.meta.dtype

        # Convert coordinates to be relative to the bounding box.  The
        # results are tuples (one coordinate array per axis): indexing with
        # a *list* of arrays is not treated as a multi-dimensional index by
        # current numpy versions.
        offset = bounding_box_start[:, numpy.newaxis]
        fg_index = tuple(numpy.array(fgVoxels) - offset)
        bg_index = tuple(numpy.array(bgVoxels) - offset)

        with Timer() as timer:
            logger.info("Loading seeds....")
            z = numpy.zeros(bounding_box_shape, dtype=dtype)
            logger.info("Allocating seed array took {} seconds".format(
                timer.seconds()))
            z[fg_index] = 2
            z[bg_index] = 1
            self.WriteSeeds[(slice(0, 1), ) + bounding_box_slicing +
                            (slice(0, 1), )] = z[numpy.newaxis, :, :, :,
                                                 numpy.newaxis]
        logger.info("Loading seeds took a total of {} seconds".format(
            timer.seconds()))

        #restore the correct parameter values
        mst = self._mst

        assert name in mst.object_lut
        assert name in mst.object_seeds_fg_voxels
        assert name in mst.object_seeds_bg_voxels
        assert name in mst.bg_priority
        assert name in mst.no_bias_below

        self.BackgroundPriority.setValue(mst.bg_priority[name])
        self.NoBiasBelow.setValue(mst.no_bias_below[name])

        # The entire segmentation layer needs to be refreshed now.
        self.Segmentation.setDirty()

        return True

    @Operator.forbidParallelExecute
    def deleteObject_impl(self, name):
        """
        Remove the object called name from all per-object tables of the
        MST, reset the current object name and rebuild the done overlay.

        Raises KeyError if one of the mandatory tables has no entry for
        name.
        """
        mst = self._mst
        mandatory_tables = (
            mst.object_lut,
            mst.object_seeds_fg_voxels,
            mst.object_seeds_bg_voxels,
            mst.bg_priority,
            mst.no_bias_below,
        )
        for table in mandatory_tables:
            del table[name]

        # object_names indicates whether the object exists, so the entry
        # has to go as well (if there is one).
        if name in mst.object_names:
            del mst.object_names[name]

        self._setCurrObjectName("<not saved yet>")

        # 'name' is gone now: rebuild the done overlay
        self._buildDone()
        #self.updatePreprocessing()

    def deleteObject(self, name):
        """
        Delete the saved object called name.

        Besides removing the object, the user labels are cleared, a
        recomputation is triggered, the object is recorded as dirty and the
        AllObjectNames / HasSegmentation slots are updated.

        Returns False if no object with that name exists, True otherwise.
        """
        logger.info("want to delete object with name = %s" % name)
        object_exists = self.hasObjectWithName(name)
        if not object_exists:
            logger.info("  --> no such object '%s'" % name)
            return False

        self.deleteObject_impl(name)
        # clear the user labels, then trigger a re-computation
        self._clearLabels()
        self.Trigger.setDirty(slice(None))
        self._dirtyObjects.add(name)

        remaining_names = self._mst.object_names.keys()
        logger.info("save: len = {}".format(len(remaining_names)))
        self.AllObjectNames.meta.shape = (len(remaining_names), )

        self.HasSegmentation.setValue(False)

        return True

    @Operator.forbidParallelExecute
    def saveCurrentObject(self):
        """
        Save the object that is currently being edited under its own name.

        Returns the name it was saved under, or "" if the current object
        name is empty/unset (in which case nothing is saved).
        """
        if not self._currObjectName:
            return ""

        # Take the copy before saving; saving changes the current name.
        saved_name = copy.copy(self._currObjectName)
        logger.info("saving object %s" % self._currObjectName)
        self.saveCurrentObjectAs(self._currObjectName)
        self.HasSegmentation.setValue(False)
        return saved_name

    @Operator.forbidParallelExecute
    def saveCurrentObjectAs(self, name):
        """
        Save the current segmentation as the object called name.

        An existing object of that name keeps its object number; a new
        object gets the highest number in use plus one (1 if there are no
        objects yet).  The object's supervoxels are those whose supervoxel
        segmentation value equals the foreground seed value 2.  The current
        BackgroundPriority / NoBiasBelow values are stored with the object.

        Afterwards the current object name is reset, HasSegmentation is set
        to False, AllObjectNames is resized and the done overlay rebuilt.
        """
        seed = 2
        logger.info("   --> Saving object %r from seed %r" % (name, seed))

        object_names = self._mst.object_names
        # `in` replaces dict.has_key(), which does not exist on Python 3.
        if name in object_names:
            objNr = object_names[name]
        else:
            # Find a free objNr.  The values are materialized as a list
            # because numpy cannot reduce a Python 3 dict view.
            used_numbers = list(object_names.values())
            if len(used_numbers) > 0:
                objNr = numpy.max(numpy.array(used_numbers)) + 1
            else:
                objNr = 1

        sVseg = self._mst.getSuperVoxelSeg()
        # Result unused; call kept in case it has side effects -- TODO confirm.
        self._mst.getSuperVoxelSeeds()

        self._mst.object_names[name] = objNr

        self._mst.bg_priority[name] = self.BackgroundPriority.value
        self._mst.no_bias_below[name] = self.NoBiasBelow.value

        self._mst.objects[name] = numpy.where(sVseg == seed)
        self._mst.object_lut[name] = numpy.where(sVseg == seed)

        self._setCurrObjectName("<not saved yet>")
        self.HasSegmentation.setValue(False)

        objects = self._mst.object_names.keys()
        self.AllObjectNames.meta.shape = (len(objects), )

        # 'name' is now one of the finished objects and no object is being
        # edited anymore: rebuild the done overlay.
        self._buildDone()
    def get_label_voxels(self):
        """
        Collect the coordinates of all labeled voxels in the label array.

        Returns a tuple (fg, bg): the coordinates of the voxels labeled 2
        and of those labeled 1.  Each is a list of three coordinate arrays,
        taken from axes 1..3 of the label array.  Returns (None, None) if
        the label array's NonzeroBlocks slot is not ready.
        """
        label_op = self.opLabelArray
        if not label_op.NonzeroBlocks.ready():
            return (None, None)

        block_slicings = label_op.NonzeroBlocks[:].wait()[0]

        bg_parts = [[], [], []]
        fg_parts = [[], [], []]
        for slicing in block_slicings:
            block = label_op.Output[slicing].wait()
            bg_hits = numpy.where(block == 1)
            fg_hits = numpy.where(block == 2)
            # Block-local indices become global ones by adding the start of
            # the block along the respective axis.
            bg_global = [bg_hits[axis] + slicing[axis].start
                         for axis in (1, 2, 3)]
            fg_global = [fg_hits[axis] + slicing[axis].start
                         for axis in (1, 2, 3)]
            for out_axis in range(3):
                bg_parts[out_axis].append(bg_global[out_axis])
                fg_parts[out_axis].append(fg_global[out_axis])

        def join_parts(parts):
            # One flat coordinate array per axis; empty if nothing was found.
            if len(parts) > 0:
                return numpy.concatenate(parts, 0)
            return numpy.ndarray((0, ), numpy.int32)

        bg_coords = [join_parts(parts) for parts in bg_parts]
        fg_coords = [join_parts(parts) for parts in fg_parts]
        return (fg_coords, bg_coords)

    def saveObjectAs(self, name):
        """
        Save the current object under name together with its seed voxels,
        then reset the labeling.

        After saving, the labeled voxels are attached to the object, the
        labels are cleared, a recomputation is triggered, the object is
        recorded as dirty, and the seeds and segmentation of the MST are
        cleared.
        """
        # first, save the object under "name"
        self.saveCurrentObjectAs(name)

        # Result unused by this method; the call itself is kept.
        self._mst.getSuperVoxelSeeds()

        fg_voxels, bg_voxels = self.get_label_voxels()
        self.attachVoxelLabelsToObject(name, fg_voxels, bg_voxels)

        self._clearLabels()

        # trigger a re-computation
        self.Trigger.setDirty(slice(None))

        self._dirtyObjects.add(name)

        mst = self._mst
        mst.gridSegmentor.clearSeeds()
        mst.clearSegmentation()

        self.clearCurrentLabeling()

    def getMaxUncertaintyPos(self, label):
        """
        Return the region center of the node with the highest uncertainty
        among the nodes whose segmentation value equals label.
        """
        # FIXME: currently working on
        mst = self._mst
        node_uncertainty = mst.uncertainty.lut
        node_segmentation = mst.segmentation.lut

        # Nodes carrying a different label count as zero uncertainty.
        masked_uncertainty = numpy.where(node_segmentation == label,
                                         node_uncertainty, 0)
        most_uncertain_node = numpy.argmax(masked_uncertainty, axis=0)
        return mst.regionCenter[most_uncertain_node, :]

    def execute(self, slot, subindex, roi, result):
        """
        Serve a request for one of the output slots.

        AllObjectNames returns the list of saved object names.  The image
        slots return 5D data for roi: Segmentation, Supervoxels, DoneObjects,
        DoneSegmentation and Uncertainty take the three middle axes from the
        MST and wrap them in singleton leading and trailing axes;
        HintOverlay and PmapOverlay are copied into result from the
        volumes loaded at construction time (zeros if none were loaded).
        DoneObjects / DoneSegmentation are zero until the done tables have
        been built.

        Raises RuntimeError for any other slot.
        """
        self._mst = self.MST.value

        if slot == self.AllObjectNames:
            # A real list, also on Python 3 where keys() is only a view.
            return list(self._mst.object_names.keys())

        sl = roi.toSlice()
        spatial = sl[1:4]

        if slot == self.Segmentation:
            temp = self._mst.getVoxelSegmentation(roi=roi)
        elif slot == self.Supervoxels:
            temp = self._mst.supervoxelUint32[spatial]
        elif slot == self.DoneObjects:
            if self._done_lut is None:
                result[0, :, :, :, 0] = 0
                return result
            temp = self._done_lut[self._mst.supervoxelUint32[spatial]]
        elif slot == self.DoneSegmentation:
            if self._done_seg_lut is None:
                result[0, :, :, :, 0] = 0
                return result
            temp = self._done_seg_lut[self._mst.supervoxelUint32[spatial]]
        elif slot == self.HintOverlay:
            if self._hints is None:
                result[:] = 0
            else:
                result[:] = self._hints[sl]
            return result
        elif slot == self.PmapOverlay:
            if self._pmap is None:
                result[:] = 0
            else:
                result[:] = self._pmap[sl]
            return result
        elif slot == self.Uncertainty:
            temp = self._mst.uncertainty[spatial]
        else:
            raise RuntimeError("unknown slot")

        # Add the singleton axes by indexing: like assigning to `.shape`
        # this does not copy the data, but it leaves the array handed out
        # by the MST untouched.
        return temp[numpy.newaxis, ..., numpy.newaxis]

    def setInSlot(self, slot, subindex, roi, value):
        """
        Handle data written into an input slot; only WriteSeeds is
        supported.

        The seeds in value are written to the label array at roi and then
        added to the MST as a brush stroke, after which the operator is
        marked as having seeds.

        Raises RuntimeError for any other slot, or if roi does not convert
        to a sequence of slices.
        """
        key = roi.toSlice()
        if not (slot == self.WriteSeeds):
            raise RuntimeError("unknown slots")

        with Timer() as timer:
            logger.info("Writing seeds to label array")
            self.opLabelArray.LabelSinkInput[key] = value
            logger.info(
                "Writing seeds to label array took {} seconds".format(
                    timer.seconds()))

        assert self._mst is not None

        # Important: mst.seeds requires erased values to be 255 (a.k.a. -1)
        with Timer() as timer:
            logger.info("Writing seeds to MST")
            if not hasattr(key, '__len__'):
                raise RuntimeError("when is this part of the code called")
            self._mst.addSeeds(roi=roi, brushStroke=value.squeeze())
        logger.info("Writing seeds to MST took {} seconds".format(
            timer.seconds()))

        self.has_seeds = True

    def propagateDirty(self, slot, subindex, roi):
        """
        React to a dirty input slot.

        * Trigger, BackgroundPriority, NoBiasBelow, UncertaintyType: rerun
          the carving on the MST with the current parameters, mark the
          Segmentation dirty and update HasSegmentation.  Nothing happens
          while the MST or one of the parameter slots is not available.
        * MST: remember the new MST and push it into the MST cache.
        * OverlayData, InputData, FilteredInputData, WriteSeeds: ignored.

        Raises AssertionError for any other slot.
        """
        recompute_slots = (self.Trigger, self.BackgroundPriority,
                           self.NoBiasBelow, self.UncertaintyType)
        ignored_slots = (self.OverlayData, self.InputData,
                         self.FilteredInputData, self.WriteSeeds)

        if any(slot == candidate for candidate in recompute_slots):
            if self._mst is None:
                return
            # All three parameter values are read below, so all three have
            # to be ready (UncertaintyType has no default value).
            for parameter_slot in (self.BackgroundPriority, self.NoBiasBelow,
                                   self.UncertaintyType):
                if not parameter_slot.ready():
                    return

            bgPrio = self.BackgroundPriority.value
            noBiasBelow = self.NoBiasBelow.value

            logger.info(
                "compute new carving results with bg priority = %f, no bias below %d"
                % (bgPrio, noBiasBelow))
            t1 = time.time()
            labelCount = 2
            params = {
                "prios": [1.0, bgPrio, 1.0],
                "uncertainty": self.UncertaintyType.value,
                "noBiasBelow": noBiasBelow,
            }

            unaries = numpy.zeros((self._mst.numNodes + 1, labelCount + 1),
                                  dtype=numpy.float32)
            self._mst.run(unaries, **params)
            logger.info(" ... carving took %f sec." % (time.time() - t1))

            self.Segmentation.setDirty(slice(None))
            hasSeg = numpy.any(self._mst.hasSeg)
            self.HasSegmentation.setValue(hasSeg)

        elif slot == self.MST:
            self._opMstCache.Input.disconnect()
            self._mst = self.MST.value
            self._opMstCache.Input.setValue(self._mst)
        elif any(slot == candidate for candidate in ignored_slots):
            pass
        else:
            # Raised explicitly: an `assert False` would be stripped under
            # `python -O` and the unknown slot would be ignored silently.
            raise AssertionError("Unknown input slot: {}".format(slot.name))
Example #13
0
class OpTrackingBase(Operator, ExportingOperator):
    """
    Operator that serves a relabelled version of a label image based on
    tracking events, and exports the result as tables.

    The Output slot returns the LabelImage data relabelled with the
    per-timestep ``label2color`` mappings (see execute()).  The mappings are
    rebuilt from the EventsVector slot whenever that slot becomes dirty
    (see propagateDirty() and _setLabel2Color()).
    """
    name = "Tracking"
    category = "other"

    # Label image to relabel; Output copies its meta data.
    LabelImage = InputSlot()
    ObjectFeatures = InputSlot(stype=Opaque, rtype=List)
    # Used instead of ObjectFeatures for export when Parameters["withDivisions"] is set.
    ObjectFeaturesWithDivFeatures = InputSlot(optional=True, stype=Opaque, rtype=List)
    ComputedFeatureNames = InputSlot(rtype=List, stype=Opaque)
    ComputedFeatureNamesWithDivFeatures = InputSlot(optional=True, rtype=List, stype=Opaque)
    # Mapping from str(timestep offset) to the events of that timestep
    # ("app", "div", "mov", "merger", "res"), as read by _setLabel2Color().
    EventsVector = InputSlot(value={})
    # Mapping from str(timestep offset) to the object ids that were filtered
    # out; written by _generate_traxelstore().
    FilteredLabels = InputSlot(value={})
    RawImage = InputSlot()
    # Dict of tracking parameters (e.g. 'time_range', 'withDivisions', 'maxObj').
    Parameters = InputSlot(value={})

    # for serialization
    InputHdf5 = InputSlot(optional=True)
    CleanBlocks = OutputSlot()
    # Single-element output holding a list with one ROI per timestep (see execute()).
    AllBlocks = OutputSlot()
    OutputHdf5 = OutputSlot()
    CachedOutput = OutputSlot()  # For the GUI (blockwise-access)

    # LabelImage relabelled with the tracking result.
    Output = OutputSlot()

    # Use a slot for storing the export settings in the project file.
    # The value is a (settings, selected_features) tuple.
    ExportSettings = InputSlot()

    # Override functions ExportingOperator mixin
    def configure_table_export_settings(self, settings, selected_features):
        """
        Store the table-export configuration in the ExportSettings slot
        as a ``(settings, selected_features)`` tuple.
        """
        export_config = (settings, selected_features)
        self.ExportSettings.setValue(export_config)

    def get_table_export_settings(self):
        """
        Return the ``(settings, selected_features)`` pair stored in the
        ExportSettings slot, or ``(None, None)`` if the slot is not ready.
        """
        if not self.ExportSettings.ready():
            return None, None

        settings, selected_features = self.ExportSettings.value
        return settings, selected_features

    def __init__(self, parent=None, graph=None):
        """
        Set up the result containers, the internal cache for Output, the
        zero provider and the input-constraint checks.
        """
        super(OpTrackingBase, self).__init__(parent=parent, graph=graph)

        # Per-timestep lookup tables, filled in by _setLabel2Color().
        self.label2color = []
        self.mergers = []
        self.resolvedto = []

        # Export-related results, also filled in by _setLabel2Color().
        self.track_id = self.extra_track_ids = self.divisions = None

        # Cache our own Output and expose the cache's slots.
        cache = OpCompressedCache(parent=self)
        self._opCache = cache
        cache.InputHdf5.connect(self.InputHdf5)
        cache.Input.connect(self.Output)
        self.CleanBlocks.connect(cache.CleanBlocks)
        self.OutputHdf5.connect(cache.OutputHdf5)
        self.CachedOutput.connect(cache.Output)

        zero_provider = OpZeroDefault(parent=self)
        self.zeroProvider = zero_provider
        zero_provider.MetaInput.connect(self.LabelImage)

        # As soon as input data is available, check its constraints
        for image_slot in (self.RawImage, self.LabelImage):
            image_slot.notifyReady(self._checkConstraints)

        self.export_progress_dialog = None
        self.ExportSettings.setValue((None, None))

    def setupOutputs(self):
        """
        Copy the LabelImage meta data to Output, configure the cache to
        store one block per entry of the first axis, and declare AllBlocks
        as a single-element object slot.
        """
        label_meta = self.LabelImage.meta
        self.Output.meta.assignFrom(label_meta)

        # cache our own output, don't propagate from internal operator
        # FIXME: assumes t,x,y,z,c
        block = list(label_meta.shape)
        block[0] = 1  # 't'
        self._blockshape = tuple(block)
        self._opCache.BlockShape.setValue(self._blockshape)

        all_blocks_meta = self.AllBlocks.meta
        all_blocks_meta.shape = (1,)
        all_blocks_meta.dtype = object


    def _checkConstraints(self, *args):
        """
        Validate RawImage and LabelImage once they are ready.

        Registered as a ``notifyReady`` callback for both slots, hence the
        ignored ``*args``.

        :raises DatasetConstraintError: if a ready image has fewer than two
            time steps, or if both images are ready and their tagged shapes
            differ in anything but the channel axis.
        """
        too_few_timesteps = (
            "For tracking, the dataset must have a time axis with at least 2 images.   "
            "Please load time-series data instead. See user documentation for details.")

        if self.RawImage.ready():
            rawTaggedShape = self.RawImage.meta.getTaggedShape()
            if rawTaggedShape['t'] < 2:
                raise DatasetConstraintError("Tracking", too_few_timesteps)

        if self.LabelImage.ready():
            segmentationTaggedShape = self.LabelImage.meta.getTaggedShape()
            if segmentationTaggedShape['t'] < 2:
                raise DatasetConstraintError("Tracking", too_few_timesteps)

        if self.RawImage.ready() and self.LabelImage.ready():
            # The number of channels may differ, so take it out of the comparison.
            rawTaggedShape['c'] = None
            segmentationTaggedShape['c'] = None
            if dict(rawTaggedShape) != dict(segmentationTaggedShape):
                # Report the shape of LabelImage, which is the slot actually
                # compared above (this operator has no BinaryImage slot).
                raise DatasetConstraintError(
                    "Tracking",
                    "For tracking, the raw data and the prediction maps must contain the same "
                    "number of timesteps and the same shape.   "
                    "Your raw image has a shape of (t, x, y, z, c) = {}, whereas your prediction image has a "
                    "shape of (t, x, y, z, c) = {}"
                    .format(self.RawImage.meta.shape, self.LabelImage.meta.shape))

    def execute(self, slot, subindex, roi, result):
        """
        Serve the Output and AllBlocks slots.

        Output: the LabelImage data for ``roi``; every time step inside
        ``Parameters['time_range']`` that has an entry in ``label2color`` is
        relabelled (first channel only), all other time steps are zeroed.

        AllBlocks: ``result[0]`` receives a list with one ROI per time step
        of Output, or an empty list if nothing has been computed yet.

        :raises Exception: if Output is requested while Parameters is not ready.
        """
        if slot is self.Output:
            result[:] = self.LabelImage.get(roi).wait()
            if not self.Parameters.ready():
                raise Exception("Parameter slot is not ready")
            parameters = self.Parameters.value

            first_t = roi.start[0]
            for offset, t in enumerate(range(first_t, roi.stop[0])):
                is_tracked = ('time_range' in parameters
                              and t <= parameters['time_range'][-1]
                              and t >= parameters['time_range'][0]
                              and len(self.label2color) > t)
                if is_tracked:
                    result[offset, ..., 0] = relabel(result[offset, ..., 0], self.label2color[t])
                else:
                    result[offset, ...] = 0
            return result

        elif slot == self.AllBlocks:
            # if nothing was computed, return empty list
            if not self.label2color:
                result[0] = []
                return result

            shape = self.Output.meta.shape
            # assumes t,x,y,z,c
            result[0] = [
                sliceToRoi([slice(t, t + 1)] + [slice(None)] * 4, shape)
                for t in range(shape[0])
            ]
            return result


    def propagateDirty(self, inputSlot, subindex, roi):
        """
        Forward dirtiness of LabelImage to Output and rebuild the label
        mappings when EventsVector changes.

        The export-mode rebuild is best effort: failures are logged (with
        traceback) at debug level and otherwise ignored.
        """
        if inputSlot is self.LabelImage:
            self.Output.setDirty(roi)
        elif inputSlot is self.EventsVector:
            self._setLabel2Color()
            try:
                self._setLabel2Color(export_mode=True)
            except Exception:
                # Only swallow ordinary errors; KeyboardInterrupt/SystemExit
                # must still propagate.
                logger.debug("Warning: some label information might be wrong...",
                             exc_info=True)


    def setInSlot(self, slot, subindex, roi, value):
        """
        Only InputHdf5 may be written to; nothing else is done here.
        """
        is_hdf5_input = (slot == self.InputHdf5)
        assert is_hdf5_input, "Invalid slot for setInSlot(): {}".format(slot.name)

    def _setLabel2Color(self, successive_ids=True, export_mode=False):
        """
        Rebuild the per-timestep label mappings from the tracking events.

        Reads the EventsVector, Parameters and FilteredLabels slots (and the
        first-axis extent of RawImage) and processes the events of every
        time step in the tracked range:

        * ``app``: the object ``e[0]`` starts a new track.
        * ``mov``: object ``e[1]`` inherits the id of object ``e[0]`` of the
          previous time step.
        * ``div`` (``(parent, child, child)``): in export mode both children
          get fresh ids and the division is recorded, otherwise they inherit
          the id of the parent.
        * ``merger``: ``e[0] -> e[1]`` is stored in ``mergers``.
        * ``res``: maps an object id to the ids it was resolved to (the last
          element of each value is dropped).

        Objects listed in FilteredLabels are mapped to 0.

        :param successive_ids: if True, new tracks get consecutive ids
            starting at 2 (1 is reserved for misdetections); otherwise a
            value drawn with ``np.random.randint(1, 255)``.
        :param export_mode: if True, only ``track_id``, ``divisions`` and
            ``extra_track_ids`` are stored on the operator and returned;
            ``label2color``/``mergers``/``resolvedto`` and the output slots
            are left untouched.  Requires ``successive_ids``.
        :return: ``(label2color, extra_track_ids, divisions)`` in export
            mode; ``None`` otherwise, and also ``None`` (without doing
            anything) if one of the input slots is not ready.
        """
        if not self.EventsVector.ready() or not self.Parameters.ready() \
                or not self.FilteredLabels.ready():
            return

        if export_mode:
            assert successive_ids, "Export mode only works for successive ids"

        events = self.EventsVector.value
        parameters = self.Parameters.value

        time_min = 0
        time_max = self.RawImage.meta.shape[0] - 1  # Assumes t,x,y,z,c
        if 'time_range' in parameters:
            time_min, time_max = parameters['time_range']
        time_range = range(time_min, time_max)

        filtered_labels = self.FilteredLabels.value

        label2color = []
        label2color.append({})
        mergers = []
        resolvedto = []

        maxId = 2  # misdetections have id 1

        # handle start time offsets
        for i in range(time_range[0]):
            label2color.append({})
            mergers.append({})
            resolvedto.append({})

        extra_track_ids = {}
        if export_mode:
            # NOTE: multi_move / multi_move_next are filled below but are
            # currently neither stored nor returned.
            multi_move = {}
            multi_move_next = {}
            divisions = []

        for i in time_range:
            # app/div/mov are stored with the *target* time step,
            # merger/res with the current one.
            events_at_next = events[str(i - time_range[0] + 1)]
            app = get_dict_value(events_at_next, "app", [])
            div = get_dict_value(events_at_next, "div", [])
            mov = get_dict_value(events_at_next, "mov", [])
            events_at_current = events[str(i - time_range[0])]
            merger = get_dict_value(events_at_current, "merger", [])
            res = get_dict_value(events_at_current, "res", {})

            logger.debug(" {} app at {}".format(len(app), i))
            logger.debug(" {} div at {}".format(len(div), i))
            logger.debug(" {} mov at {}".format(len(mov), i))
            logger.debug(" {} merger at {}".format(len(merger), i))

            label2color.append({})
            mergers.append({})
            moves_at = []
            resolvedto.append({})

            if export_mode:
                moves_to = {}

            for e in app:
                if successive_ids:
                    label2color[-1][int(e[0])] = maxId  # in export mode, the label color is used as track ID
                    maxId += 1
                else:
                    label2color[-1][int(e[0])] = np.random.randint(1, 255)

            for e in mov:
                if export_mode:
                    if e[1] in moves_to:
                        multi_move.setdefault(i, {})
                        multi_move[i][e[0]] = e[1]
                        if len(moves_to[e[1]]) == 1:  # if we are just setting up this multi move
                            multi_move[i][moves_to[e[1]][0]] = e[1]
                        multi_move_next[(i, e[1])] = 0
                    moves_to.setdefault(e[1], [])
                    moves_to[e[1]].append(e[0])  # moves_to[target] contains list of incoming object ids

                # alternative way of appearance
                if int(e[0]) not in label2color[-2]:
                    if successive_ids:
                        label2color[-2][int(e[0])] = maxId
                        maxId += 1
                    else:
                        label2color[-2][int(e[0])] = np.random.randint(1, 255)

                # assign color of parent
                label2color[-1][int(e[1])] = label2color[-2][int(e[0])]
                moves_at.append(int(e[0]))

                if export_mode:
                    key = i - 1, e[0]
                    if key in multi_move_next:  # captures mergers staying connected over longer time spans
                        multi_move_next[key] = e[1]  # redirects output of last merger to target in this frame
                        multi_move_next[(i, e[1])] = 0  # sets current end to zero (might be changed by above line in the future)

            for e in div:  # event(parent, child, child)
                if int(e[0]) not in label2color[-2]:
                    if successive_ids:
                        label2color[-2][int(e[0])] = maxId
                        maxId += 1
                    else:
                        label2color[-2][int(e[0])] = np.random.randint(1, 255)
                ancestor_color = label2color[-2][int(e[0])]
                if export_mode:
                    label2color[-1][int(e[1])] = maxId
                    label2color[-1][int(e[2])] = maxId + 1
                    divisions.append((i, int(e[0]), ancestor_color,
                                      int(e[1]), maxId,
                                      int(e[2]), maxId + 1
                    ))
                    maxId += 2
                else:
                    label2color[-1][int(e[1])] = ancestor_color
                    label2color[-1][int(e[2])] = ancestor_color

            for e in merger:
                mergers[-1][int(e[0])] = int(e[1])

            for o, r in res.items():
                resolvedto[-1][int(o)] = [int(c) for c in r[:-1]]
                # label the original object with the false detection label
                mergers[-1][int(o)] = len(r[:-1])

                if export_mode:
                    extra_track_ids.setdefault(i, {})
                    extra_track_ids[i][int(o)] = [int(c) for c in r[:-1]]

        # last timestep
        last_events = events[str(time_range[-1] - time_range[0] + 1)]
        merger = get_dict_value(last_events, "merger", [])
        mergers.append({})
        for e in merger:
            mergers[-1][int(e[0])] = int(e[1])

        res = get_dict_value(events[str(time_range[-1] - time_range[0] + 1)], "res", {})
        resolvedto.append({})
        if export_mode:
            extra_track_ids[time_range[-1] + 1] = {}
        for o, r in res.items():
            resolvedto[-1][int(o)] = [int(c) for c in r[:-1]]
            mergers[-1][int(o)] = len(r[:-1])

            if export_mode:
                extra_track_ids[time_range[-1] + 1][int(o)] = [int(c) for c in r[:-1]]

        # mark the filtered objects
        for i in filtered_labels.keys():
            if int(i) + time_range[0] >= len(label2color):
                continue
            fl_at = filtered_labels[i]
            for l in fl_at:
                assert l not in label2color[int(i) + time_range[0]]
                label2color[int(i) + time_range[0]][l] = 0

        if export_mode:
            # Only the export-related fields are stored in export mode; the
            # display fields and the output slots are left untouched.
            self.track_id = label2color
            self.divisions = divisions
            self.extra_track_ids = extra_track_ids
            return label2color, extra_track_ids, divisions

        self.track_id = label2color
        self.extra_track_ids = extra_track_ids
        self.label2color = label2color
        self.resolvedto = resolvedto
        self.mergers = mergers

        # Drop any stored output value and force consumers to re-request.
        self.Output._value = None
        self.Output.setDirty(slice(None))

        if 'MergerOutput' in self.outputs:
            self.MergerOutput._value = None
            self.MergerOutput.setDirty(slice(None))

    def export_track_ids(self):
        """
        Run _setLabel2Color() in export mode and hand back its result,
        i.e. ``(track_ids, extra_track_ids, divisions)`` or ``None`` if the
        required input slots are not ready.
        """
        export_result = self._setLabel2Color(export_mode=True)
        return export_result

    def track_children(self, track_id, start=0):
        """
        Return the track ids of all descendants of ``track_id``.

        ``self.divisions`` is a list of
        ``(t, parent_id, parent_track, child1_id, child1_track, child2_id,
        child2_track)`` tuples (see _setLabel2Color()).  The first division
        of ``track_id`` at a time step ``>= start`` yields its two child
        tracks, whose own descendants are collected recursively.

        :param track_id: track whose descendants are wanted.
        :param start: earliest time step of divisions to consider.
        :return: list of descendant track ids; empty if the track never
            divides or no divisions have been computed yet.
        """
        # divisions is None until an export-mode run of _setLabel2Color().
        for t, _, track, _, child_track1, _, child_track2 in (self.divisions or []):
            # Filter on the time stamp stored in the tuple; list positions
            # are not time steps.
            if t < start or track != track_id:
                continue
            children_of = partial(self.track_children, start=t)
            return [child_track1, child_track2] + \
                   children_of(child_track1) + children_of(child_track2)
        return []

    def track_parent(self, track_id):
        """
        Return the chain of ancestor track ids of ``track_id``, nearest
        parent first.

        ``self.divisions`` is a list of
        ``(t, parent_id, parent_track, child1_id, child1_track, child2_id,
        child2_track)`` tuples (see _setLabel2Color()).

        :return: list of ancestor track ids; empty if the track did not
            originate from a division or no divisions have been computed yet.
        """
        # divisions is None until an export-mode run of _setLabel2Color().
        # All divisions are searched, including the most recent one.
        for t, oid, track, _, child_track1, _, child_track2 in (self.divisions or []):
            if track_id in (child_track1, child_track2):
                return [track] + self.track_parent(track)
        return []

    def track_family(self, track_id):
        """
        Return ``(children, parents)`` of ``track_id`` as computed by
        track_children() and track_parent().
        """
        descendants = self.track_children(track_id)
        ancestors = self.track_parent(track_id)
        return descendants, ancestors


    def _generate_traxelstore(self,
                              time_range,
                              x_range,
                              y_range,
                              z_range,
                              size_range,
                              x_scale=1.0,
                              y_scale=1.0,
                              z_scale=1.0,
                              with_div=False,
                              with_local_centers=False,
                              median_object_size=None,
                              max_traxel_id_at=None,
                              with_opt_correction=False,
                              with_coordinate_list=False,
                              with_classifier_prior=False):
        """
        Build a pgmlink traxel store from the object features.

        Every object whose region center and size lie inside the given
        ranges becomes a traxel; all other objects are recorded in the
        FilteredLabels slot.  The ranges/scales are also written into the
        dict held by the Parameters slot.

        :param time_range: time steps to request from the feature slots.
        :param x_range: ``(lower, upper)``; objects with ``x < lower`` or
            ``x >= upper`` are filtered out.  Same for ``y_range``,
            ``z_range`` and ``size_range`` (object pixel count).
        :param x_scale: scale set on each traxel (same for y/z).
        :param with_div: add the ``divProb`` feature from the
            DivisionProbabilities slot (expected to be provided by the
            concrete operator).
        :param with_local_centers: add ``localCentersX/Y/Z`` from the
            RegionLocalCenters slot (expected to be provided by the concrete
            operator).
        :param median_object_size: if not None, a mutable sequence whose
            first element receives the median size of the accepted objects.
        :param max_traxel_id_at: ignored; a fresh ``pgmlink.VectorOfInt`` is
            always created and returned.
        :param with_opt_correction: add ``com_corrected`` from the
            ``RegionCenter_corr`` feature.
        :param with_coordinate_list: currently unused.
        :param with_classifier_prior: add ``detProb`` from the
            DetectionProbabilities slot, clamped away from 0 and 1.
        :return: ``(feature_store, traxel_store, empty_frame,
            max_traxel_id_at)`` where ``empty_frame`` tells whether at least
            one time step had no accepted object.
        :raises Exception: if Parameters is not ready, or optical correction
            is requested but its feature is unavailable.
        :raises DatasetConstraintError: if a required classifier has not
            been trained or the region centers are not 2D/3D.
        """
        if not self.Parameters.ready():
            raise Exception("Parameter slot is not ready")

        # NOTE: this updates the dict held by the Parameters slot in place.
        parameters = self.Parameters.value
        parameters['scales'] = [x_scale, y_scale, z_scale]
        parameters['time_range'] = [min(time_range), max(time_range)]
        parameters['x_range'] = x_range
        parameters['y_range'] = y_range
        parameters['z_range'] = z_range
        parameters['size_range'] = size_range

        logger.info("generating traxels")
        logger.info("fetching region features and division probabilities")
        feats = self.ObjectFeatures(time_range).wait()

        if with_div:
            if not self.DivisionProbabilities.ready() or len(self.DivisionProbabilities([0]).wait()[0]) == 0:
                msgStr = "\nDivision classifier has not been trained! " + \
                         "Uncheck divisible objects if your objects don't divide or " + \
                         "go back to the Division Detection applet and train it."
                raise DatasetConstraintError("Tracking", msgStr)
            divProbs = self.DivisionProbabilities(time_range).wait()

        if with_local_centers:
            localCenters = self.RegionLocalCenters(time_range).wait()

        if with_classifier_prior:
            if not self.DetectionProbabilities.ready() or len(self.DetectionProbabilities([0]).wait()[0]) == 0:
                msgStr = "\nObject count classifier has not been trained! " + \
                         "Go back to the Object Count Classification applet and train it."
                raise DatasetConstraintError("Tracking", msgStr)
            detProbs = self.DetectionProbabilities(time_range).wait()

        logger.info("filling traxelstore")
        ts = pgmlink.TraxelStore()
        fs = pgmlink.FeatureStore()

        max_traxel_id_at = pgmlink.VectorOfInt()
        filtered_labels = {}
        obj_sizes = []
        empty_frame = False

        for t in feats.keys():
            default_feats = feats[t][default_features_key]
            rc = default_feats['RegionCenter']
            lower = default_feats['Coord<Minimum>']
            upper = default_feats['Coord<Maximum>']
            # Drop the first row (index 0) so that row idx belongs to
            # object id idx + 1.
            if rc.size:
                rc = rc[1:, ...]
                lower = lower[1:, ...]
                upper = upper[1:, ...]

            if with_opt_correction:
                try:
                    rc_corr = feats[t][config.features_vigra_name]['RegionCenter_corr']
                except Exception:
                    raise Exception('Can not consider optical correction since it has not been computed before')
                if rc_corr.size:
                    rc_corr = rc_corr[1:, ...]

            ct = feats[t][default_features_key]['Count']
            if ct.size:
                ct = ct[1:, ...]

            logger.debug("at timestep {}, {} traxels found".format(t, rc.shape[0]))
            count = 0
            filtered_labels_at = []
            for idx in range(rc.shape[0]):
                # for 2d data, set z-coordinate to 0:
                if len(rc[idx]) == 2:
                    x, y = rc[idx]
                    z = 0
                elif len(rc[idx]) == 3:
                    x, y, z = rc[idx]
                else:
                    raise DatasetConstraintError("Tracking", "The RegionCenter feature must have dimensionality 2 or 3.")
                size = ct[idx]
                if (x < x_range[0] or x >= x_range[1] or
                        y < y_range[0] or y >= y_range[1] or
                        z < z_range[0] or z >= z_range[1] or
                        size < size_range[0] or size >= size_range[1]):
                    filtered_labels_at.append(int(idx + 1))
                    continue
                else:
                    count += 1
                tr = pgmlink.Traxel()
                tr.set_feature_store(fs)
                tr.set_x_scale(x_scale)
                tr.set_y_scale(y_scale)
                tr.set_z_scale(z_scale)
                tr.Id = int(idx + 1)
                tr.Timestep = int(t)

                # pgmlink expects always 3 coordinates, z=0 for 2d data
                tr.add_feature_array("com", 3)
                for i, v in enumerate([x, y, z]):
                    tr.set_feature_value('com', i, float(v))

                tr.add_feature_array("CoordMinimum", 3)
                for i, v in enumerate(lower[idx]):
                    tr.set_feature_value("CoordMinimum", i, float(v))
                tr.add_feature_array("CoordMaximum", 3)
                for i, v in enumerate(upper[idx]):
                    tr.set_feature_value("CoordMaximum", i, float(v))

                if with_opt_correction:
                    tr.add_feature_array("com_corrected", 3)
                    for i, v in enumerate(rc_corr[idx]):
                        tr.set_feature_value("com_corrected", i, float(v))
                    if len(rc_corr[idx]) == 2:
                        tr.set_feature_value("com_corrected", 2, 0.)

                if with_div:
                    tr.add_feature_array("divProb", 1)
                    # idx+1 because rc and ct start from 1, divProbs starts from 0
                    tr.set_feature_value("divProb", 0, float(divProbs[t][idx + 1][1]))

                if with_classifier_prior:
                    tr.add_feature_array("detProb", len(detProbs[t][idx + 1]))
                    for i, v in enumerate(detProbs[t][idx + 1]):
                        # Keep probabilities strictly inside (0, 1).
                        val = float(v)
                        if val < 0.0000001:
                            val = 0.0000001
                        if val > 0.99999999:
                            val = 0.99999999
                        tr.set_feature_value("detProb", i, float(val))

                # FIXME: check whether it is 2d or 3d data!
                if with_local_centers:
                    tr.add_feature_array("localCentersX", len(localCenters[t][idx + 1]))
                    tr.add_feature_array("localCentersY", len(localCenters[t][idx + 1]))
                    tr.add_feature_array("localCentersZ", len(localCenters[t][idx + 1]))
                    for i, v in enumerate(localCenters[t][idx + 1]):
                        tr.set_feature_value("localCentersX", i, float(v[0]))
                        tr.set_feature_value("localCentersY", i, float(v[1]))
                        tr.set_feature_value("localCentersZ", i, float(v[2]))

                tr.add_feature_array("count", 1)
                tr.set_feature_value("count", 0, float(size))
                if median_object_size is not None:
                    obj_sizes.append(float(size))

                ts.add(fs, tr)

            if len(filtered_labels_at) > 0:
                # Keys are time offsets relative to the start of the range.
                filtered_labels[str(int(t) - time_range[0])] = filtered_labels_at
            logger.debug("at timestep {}, {} traxels passed filter".format(t, count))
            max_traxel_id_at.append(int(rc.shape[0]))
            if count == 0:
                empty_frame = True

        if median_object_size is not None:
            median_object_size[0] = np.median(np.array(obj_sizes), overwrite_input=True)
            logger.info('median object size = ' + str(median_object_size[0]))

        self.FilteredLabels.setValue(filtered_labels, check_changed=True)

        return fs, ts, empty_frame, max_traxel_id_at

    def save_export_progress_dialog(self, dialog):
        """
        Keep a reference to the export progress dialog on the operator.

        Implements ExportOperator.save_export_progress_dialog; without the
        stored reference the progress dialog would be hidden after the export.

        :param dialog: the ProgressDialog to keep
        """
        self.export_progress_dialog = dialog

    def do_export(self, settings, selected_features, progress_slot, lane_index, filename_suffix=""):
        """
        Implements ExportOperator.do_export(settings, selected_features, progress_slot
        Most likely called from ExportOperator.export_object_data

        Picks the feature slot matching the ``withDivisions`` parameter and
        delegates to _do_export_impl() with LabelImage as the label source.

        :param settings: the settings for the exporter, see
        :param selected_features:
        :param progress_slot:
        :param lane_index: Ignored. (This is a single-lane operator. It is the caller's responsibility to make sure he's calling the right lane.)
        :param filename_suffix: If provided, appended to the filename (before the extension).
        :return:
        """
        # Parameters defaults to an empty dict, so the key may be missing
        # even when the slot is ready; treat that like "no divisions".
        with_divisions = False
        if self.Parameters.ready():
            with_divisions = self.Parameters.value.get("withDivisions", False)

        if with_divisions:
            object_feature_slot = self.ObjectFeaturesWithDivFeatures
        else:
            object_feature_slot = self.ObjectFeatures

        self._do_export_impl(settings, selected_features, progress_slot, object_feature_slot, self.LabelImage, lane_index, filename_suffix)


    def _do_export_impl(self, settings, selected_features, progress_slot, object_feature_slot, label_image_slot, lane_index, filename_suffix=""):
        """
        Write the tracking result tables (and, for h5, image data) to a file.

        :param settings: export settings; the keys read here are
            ``"file path"``, ``"file type"``, ``"compression"`` and, for the
            ``"h5"`` file type, ``"margin"`` and ``"include raw"``.
        :param selected_features: features to include in the feature table.
        :param progress_slot: callback subscribed to the export file's
            progress signals for the duration of the export.
        :param object_feature_slot: slot providing the object features.
        :param label_image_slot: slot providing the label image.
        :param lane_index: unused here.
        :param filename_suffix: if provided, inserted (with a leading ``-``)
            before the file extension.
        """
        from ilastik.utility.exportFile import objects_per_frame, ExportFile, ilastik_ids, Mode, Default, \
            flatten_dict, division_flatten_dict

        selected_features = list(selected_features)
        # Parameters defaults to an empty dict, so fall back to the same
        # defaults for missing keys as for a slot that is not ready.
        parameters = self.Parameters.value if self.Parameters.ready() else {}
        with_divisions = parameters.get("withDivisions", False)
        obj_count = list(objects_per_frame(label_image_slot))
        track_ids, extra_track_ids, divisions = self.export_track_ids()
        self._setLabel2Color()
        lineage = flatten_dict(self.label2color, obj_count)
        multi_move_max = parameters.get("maxObj", 2)
        t_range = parameters.get("time_range", (0, 0))
        ids = ilastik_ids(obj_count)

        file_path = settings["file path"]
        if filename_suffix:
            path, ext = os.path.splitext(file_path)
            file_path = path + "-" + filename_suffix + ext

        export_file = ExportFile(file_path)
        export_file.ExportProgress.subscribe(progress_slot)
        export_file.InsertionProgress.subscribe(progress_slot)
        try:
            export_file.add_columns("table", list(range(sum(obj_count))), Mode.List, Default.KnimeId)
            export_file.add_columns("table", list(ids), Mode.List, Default.IlastikId)
            export_file.add_columns("table", lineage, Mode.List, Default.Lineage)
            export_file.add_columns("table", track_ids, Mode.IlastikTrackingTable,
                                    {"max": multi_move_max, "counts": obj_count, "extra ids": extra_track_ids,
                                     "range": t_range})

            export_file.add_columns("table", object_feature_slot, Mode.IlastikFeatureTable,
                                    {"selection": selected_features})

            if with_divisions:
                if divisions:
                    div_lineage = division_flatten_dict(divisions, self.label2color)
                    # Insert the lineage as second column of every division row.
                    # list() keeps this working where zip() returns an iterator.
                    zips = list(zip(*divisions))
                    divisions = list(zip(zips[0], div_lineage, *zips[1:]))
                    export_file.add_columns("divisions", divisions, Mode.List, Default.DivisionNames)
                else:
                    logger.debug("No divisions occurred. Division Table will not be exported!")

            if settings["file type"] == "h5":
                export_file.add_rois(Default.LabelRoiPath, label_image_slot, "table", settings["margin"], "labeling")
                if settings["include raw"]:
                    export_file.add_image(Default.RawPath, self.RawImage)
                else:
                    export_file.add_rois(Default.RawRoiPath, self.RawImage, "table", settings["margin"])
            export_file.write_all(settings["file type"], settings["compression"])
        finally:
            # Always release the progress callbacks, even if the export failed.
            export_file.ExportProgress.unsubscribe(progress_slot)
            export_file.InsertionProgress.unsubscribe(progress_slot)
Example #14
0
class _OpCachedLabelImage(Operator):
    """
    An OpLabelImage whose output is stored in an OpCompressedCache.
    If no BlockShape is given, a default block shape is derived from the input.
    """

    Input = InputSlot()

    BackgroundLabels = InputSlot(
        optional=True)  # Optional. See OpLabelImage for details.
    BlockShape = InputSlot(
        optional=True
    )  # If not provided, blockshape is 1 time slice, 1 channel slice,
    #  and the entire volume in xyz.
    Output = OutputSlot()

    # Serialization support
    InputHdf5 = InputSlot(optional=True)
    CleanBlocks = OutputSlot()
    OutputHdf5 = OutputSlot()  # See OpCachedLabelImage for details

    # Schematic:
    #
    # BackgroundLabels --     BlockShape --
    #                    \                 \
    # Input ------------> OpLabelImage ---> OpCompressedCache --> Output
    #                                                        \
    #                                                         --> CleanBlocks

    def __init__(self, *args, **kwargs):
        """Create the internal labeler and cache and wire them up."""
        super(_OpCachedLabelImage, self).__init__(*args, **kwargs)

        # Labeler: fed directly from our inputs
        labeler = OpLabelImage(parent=self)
        self._opLabelImage = labeler
        labeler.Input.connect(self.Input)
        labeler.BackgroundLabels.connect(self.BackgroundLabels)

        # Cache: stores the labeler's output
        cache = OpCompressedCache(parent=self)
        self._opCache = cache
        cache.Input.connect(labeler.Output)
        cache.InputHdf5.connect(self.InputHdf5)

        # Our outputs are simply the cache's outputs
        self.Output.connect(cache.Output)
        self.CleanBlocks.connect(cache.CleanBlocks)
        self.OutputHdf5.connect(cache.OutputHdf5)

    # The following four methods simply delegate to the internal cache.

    def generateReport(self, report):
        return self._opCache.generateReport(report)

    def usedMemory(self):
        return self._opCache.usedMemory()

    def fractionOfUsedMemoryDirty(self):
        return self._opCache.fractionOfUsedMemoryDirty()

    def lastAccessTime(self):
        return self._opCache.lastAccessTime()

    def setupOutputs(self):
        """Configure the block shape of the internal cache."""
        if self.BlockShape.ready():
            block_shape = self.BlockShape.value
        else:
            # Default: the full image shape, restricted to a single
            # time slice and a single channel slice.
            tagged_shape = self.Input.meta.getTaggedShape()
            tagged_shape["t"] = 1
            tagged_shape["c"] = 1
            block_shape = tuple(tagged_shape.values())
        self._opCache.BlockShape.setValue(block_shape)

    def execute(self, slot, subindex, roi, destination):
        # All outputs are connected to the internal cache.
        assert False, "Shouldn't get here."

    def propagateDirty(self, slot, subindex, roi):
        # Nothing to do...
        pass

    def setInSlot(self, slot, subindex, roi, value):
        # Only Input and InputHdf5 accept writes.
        assert slot == self.Input or slot == self.InputHdf5, \
            "Invalid slot for setInSlot(): {}".format(slot.name)
class OpSimpleBlockedArrayCache(OpUnblockedArrayCache):
    """
    Variant of OpUnblockedArrayCache that splits every request into the
    blocks of a fixed grid (given by BlockShape) and serves each block
    separately.
    """
    BlockShape = InputSlot(
        optional=True
    )  # Must be a tuple.  Any 'None' elements will be interpreted as 'max' for that dimension.
    BypassModeEnabled = InputSlot(value=False)

    def __init__(self, *args, **kwargs):
        super(OpSimpleBlockedArrayCache, self).__init__(*args, **kwargs)
        self._blockshape = None

    def setupOutputs(self):
        """
        Determine the effective block shape and publish ideal_blockshape
        and ram_usage_per_requested_pixel on the Output meta data.
        """
        super(OpSimpleBlockedArrayCache, self).setupOutputs()
        input_meta = self.Input.meta
        input_shape = input_meta.shape

        self._blockshape = self.BlockShape.value if self.BlockShape.ready() else input_shape

        # A block shape of the wrong dimensionality can't be used.
        if len(self._blockshape) != len(input_shape):
            self.Output.meta.NOTREADY = True
            return

        # Replace 'None' (or zero) with default (from Input shape)
        self._blockshape = tuple(
            numpy.where(self._blockshape, self._blockshape, input_shape))

        self.Output.meta.ideal_blockshape = tuple(
            numpy.minimum(self._blockshape, input_shape))

        # Estimate ram usage per requested pixel;
        # one 'pixel' includes all channels
        ram_per_pixel = get_ram_per_element(input_meta.dtype)
        tagged_shape = input_meta.getTaggedShape()
        if "c" in tagged_shape:
            ram_per_pixel *= float(tagged_shape["c"])

        upstream_ram_per_pixel = input_meta.ram_usage_per_requested_pixel
        if upstream_ram_per_pixel is not None:
            ram_per_pixel = max(ram_per_pixel, upstream_ram_per_pixel)

        self.Output.meta.ram_usage_per_requested_pixel = ram_per_pixel

    def _execute_Output(self, slot, subindex, roi, result):
        """
        Overridden from OpUnblockedArrayCache

        Fills ``result`` block by block; every block that intersects
        ``roi`` is handled in its own Request.
        """
        def copy_block(full_block_roi, clipped_block_roi):
            full_roi = numpy.asarray(full_block_roi)
            clipped_roi = numpy.asarray(clipped_block_roi)
            # Where the clipped block lands inside 'result'
            output_slicing = roiToSlice(*(numpy.asarray(clipped_roi) - roi.start))

            def crop_into_result(full_block_data):
                # Copy only the requested part of the full block into 'result'
                roi_within_block = clipped_roi - full_roi[0]
                self.Output.stype.copy_data(
                    result[output_slicing],
                    full_block_data[roiToSlice(*roi_within_block)])

            stored_block_roi = self._get_containing_block_roi(clipped_roi)

            if self.BypassModeEnabled.value:
                # Skip cache and copy full block directly
                full_block_data = self.Output.stype.allocateDestination(
                    SubRegion(self.Output, *full_roi))
                self.Input(*full_roi).writeInto(full_block_data).block()
                crop_into_result(full_block_data)
            elif stored_block_roi is not None or (full_roi == clipped_roi).all():
                # If the data exists already or we can just fetch it without
                # needing extra scratch space, just call the base class
                self._execute_Output_impl(clipped_roi, result[output_slicing])
            elif self.Input.meta.dontcache:
                # Data isn't in the cache, but we don't need it in the cache anyway.
                self.Input(*clipped_roi).writeInto(result[output_slicing]).block()
            else:
                # Data doesn't exist yet in the cache.
                # Request the full block, but then discard the parts we don't need.

                # (We use allocateDestination() here to support MaskedArray types.)
                # TODO: We should probably just get rid of MaskedArray support altogether...
                full_block_data = self.Output.stype.allocateDestination(
                    SubRegion(self.Output, *full_roi))
                self._execute_Output_impl(full_roi, full_block_data)
                crop_into_result(full_block_data)

        requested = (roi.start, roi.stop)
        clipped_block_rois = getIntersectingRois(self.Input.meta.shape,
                                                 self._blockshape, requested, True)
        full_block_rois = getIntersectingRois(self.Input.meta.shape,
                                              self._blockshape, requested, False)

        pool = RequestPool()
        for full_block_roi, clipped_block_roi in zip(full_block_rois, clipped_block_rois):
            pool.add(Request(partial(copy_block, full_block_roi, clipped_block_roi)))
        pool.wait()

    def propagateDirty(self, slot, subindex, roi):
        # Changes of the configuration slots don't invalidate any data.
        if slot not in (self.BypassModeEnabled, self.BlockShape):
            super(OpSimpleBlockedArrayCache,
                  self).propagateDirty(slot, subindex, roi)
class OpConservationTracking(Operator):
    """
    Operator that runs conservation tracking (see track()) on a label image and
    serves the relabeled volumes (lineage IDs, mergers only, object IDs) through
    its output slots. Each image output is additionally exposed through a
    blocked array cache.
    """
    # Source volume for all image outputs; execute() copies it into the result first.
    LabelImage = InputSlot()
    ObjectFeatures = InputSlot(stype=Opaque, rtype=List)
    ObjectFeaturesWithDivFeatures = InputSlot(optional=True, stype=Opaque, rtype=List)
    ComputedFeatureNames = InputSlot(rtype=List, stype=Opaque)
    ComputedFeatureNamesWithDivFeatures = InputSlot(optional=True, rtype=List, stype=Opaque)
    FilteredLabels = InputSlot(value={})
    RawImage = InputSlot()
    # Dictionary of tracking parameters; filled and written back by track().
    Parameters = InputSlot(value={})
    # Set by track() once a solution has been computed; empty until then.
    HypothesesGraph = InputSlot(value={})
    # Set by track(); stays empty when merger resolution is disabled.
    ResolvedMergers = InputSlot(value={})
 
    # for serialization
    CleanBlocks = OutputSlot()
    AllBlocks = OutputSlot()
    CachedOutput = OutputSlot()  # For the GUI (blockwise-access)
 
    Output = OutputSlot() # Volume relabelled with lineage IDs
 
    # Use a slot for storing the export settings in the project file.
    # just here so that old projects still load!
    ExportSettings = InputSlot()

    DivisionProbabilities = InputSlot(optional=True, stype=Opaque, rtype=List)
    # DetectionProbabilities and NumLabels are validated in track() when a classifier prior is requested.
    DetectionProbabilities = InputSlot(stype=Opaque, rtype=List)
    NumLabels = InputSlot()

    # compressed cache for merger output
    MergerCleanBlocks = OutputSlot()
    MergerCachedOutput = OutputSlot() # For the GUI (blockwise access)
    MergerOutput = OutputSlot() # Volume showing only merger IDs

    # cache for the relabeled (object ID) output
    RelabeledCleanBlocks = OutputSlot()
    RelabeledCachedOutput = OutputSlot() # For the GUI (blockwise access)
    RelabeledImage = OutputSlot() # Volume showing object IDs

    def __init__(self, parent=None, graph=None):
        """
        Create the internal caches, the zero provider and the merger resolver plugin.

        :param parent: parent operator, forwarded to the Operator constructor
        :param graph: graph, forwarded to the Operator constructor
        """
        super(OpConservationTracking, self).__init__(parent=parent, graph=graph)

        # Cache for the lineage-ID output
        self._opCache = OpBlockedArrayCache(parent=self)
        self._opCache.name = "OpConservationTracking._opCache"
        self._opCache.Input.connect(self.Output)
        self.CleanBlocks.connect(self._opCache.CleanBlocks)
        self.CachedOutput.connect(self._opCache.Output)

        self.zeroProvider = OpZeroDefault(parent=self)
        self.zeroProvider.MetaInput.connect(self.LabelImage)

        # As soon as input data is available, check its constraints
        self.RawImage.notifyReady(self._checkConstraints)
        self.LabelImage.notifyReady(self._checkConstraints)

        self.ExportSettings.setValue( (None, None) )

        # Cache for the merger-only output
        self._mergerOpCache = OpBlockedArrayCache(parent=self)
        self._mergerOpCache.name = "OpConservationTracking._mergerOpCache"
        self._mergerOpCache.Input.connect(self.MergerOutput)
        self.MergerCleanBlocks.connect(self._mergerOpCache.CleanBlocks)
        self.MergerCachedOutput.connect(self._mergerOpCache.Output)

        # Cache for the relabeled (object ID) output.
        # The name used to be a copy of the merger cache's name, which made the
        # two caches indistinguishable wherever operator names are shown.
        self._relabeledOpCache = OpBlockedArrayCache(parent=self)
        self._relabeledOpCache.name = "OpConservationTracking._relabeledOpCache"
        self._relabeledOpCache.Input.connect(self.RelabeledImage)
        self.RelabeledCleanBlocks.connect(self._relabeledOpCache.CleanBlocks)
        self.RelabeledCachedOutput.connect(self._relabeledOpCache.Output)

        # Merger resolver plugin manager (contains GMM fit routine)
        self.pluginPaths = [os.path.join(os.path.dirname(os.path.abspath(hytra.__file__)), 'plugins')]
        pluginManager = TrackingPluginManager(verbose=False, pluginPaths=self.pluginPaths)
        self.mergerResolverPlugin = pluginManager.getMergerResolver()

        # Not assigned anywhere else in the visible part of this class.
        self.result = None

        # gui progress; both are replaced on every call to track()
        self.progressWindow = None
        self.progressVisitor=DefaultProgressVisitor()

    def setupOutputs(self):
        """
        Copy the LabelImage meta data to the image outputs and configure the
        caches to use one block per time frame.
        """
        self.Output.meta.assignFrom(self.LabelImage.meta)

        # cache our own output, don't propagate from internal operator
        # FIXME: assumes t,x,y,z,c
        per_frame = list(self.LabelImage.meta.shape)
        per_frame[0] = 1  # 't'
        self._blockshape = tuple(per_frame)
        self._opCache.BlockShape.setValue(self._blockshape)

        self.AllBlocks.meta.shape = (1,)
        self.AllBlocks.meta.dtype = object

        for image_slot in (self.MergerOutput, self.RelabeledImage):
            image_slot.meta.assignFrom(self.LabelImage.meta)

        for cache in (self._mergerOpCache, self._relabeledOpCache):
            cache.BlockShape.setValue(self._blockshape)

        # assumes t,x,y,z,c order
        frame_shape = (1,) + self.LabelImage.meta.shape[1:]
        assert frame_shape[-1] == 1
        for image_slot in (self.MergerOutput, self.RelabeledImage):
            image_slot.meta.ideal_blockshape = frame_shape
          
    def execute(self, slot, subindex, roi, result):
        """
        Fill ``result`` for the requested output slot.

        * ``Output``: label image relabeled with lineage IDs
        * ``MergerOutput``: like ``Output``, but only mergers are kept
        * ``RelabeledImage``: label image with resolved mergers relabeled
          (object IDs, before lineage IDs are assigned)
        * ``AllBlocks``: ``result[0]`` becomes a list with one roi per time frame

        The image slots are filled in place; only the ``AllBlocks`` branch
        returns ``result`` explicitly.

        :raises Exception: for ``Output`` if the Parameters slot is not ready
        """
        if slot is self.Output or slot is self.MergerOutput or slot is self.RelabeledImage:
            if slot is self.Output and not self.Parameters.ready():
                raise Exception("Parameter slot is not ready")

            parameters = self.Parameters.value
            resolvedMergers = self.ResolvedMergers.value

            # Assume [t,x,y,z,c] order
            frames = range(roi.start[0], roi.stop[0])
            offset = roi.start[1:-1]

            result[:] =  self.LabelImage.get(roi).wait()

            def tracked(t):
                # True if frame t lies within the time range that was tracked
                return ('time_range' in parameters
                        and t <= parameters['time_range'][-1]
                        and t >= parameters['time_range'][0])

            object_ids_only = slot is self.RelabeledImage
            mergers_only = slot is self.MergerOutput

            for t in frames:
                pos = t - roi.start[0]
                if object_ids_only:
                    # Object IDs: only resolved mergers get relabeled, nothing is zeroed
                    if resolvedMergers and tracked(t):
                        self._labelMergers(result[pos, ..., 0], t, offset)
                elif tracked(t):
                    if resolvedMergers:
                        self._labelMergers(result[pos, ..., 0], t, offset)
                    result[pos, ..., 0] = self._labelLineageIds(result[pos, ..., 0], t, onlyMergers=mergers_only)
                else:
                    # Frames outside the tracked time range are blanked
                    result[pos, ...][:] = 0

        # Cache blocks
        elif slot == self.AllBlocks:
            # if nothing was computed, return empty list
            if not self.HypothesesGraph.value:
                result[0] = []
                return result

            shape = self.Output.meta.shape
            # assumes t,x,y,z,c
            result[0] = [
                sliceToRoi([slice(t, t + 1)] + [slice(None), ] * 4, shape)
                for t in range(shape[0])
            ]
            return result

    def setInSlot(self, slot, subindex, roi, value):
        """
        Check that a value is only written into one of the HDF5 input slots.

        The written value is not stored by this method; it only asserts that
        ``slot`` is InputHdf5, MergerInputHdf5 or RelabeledInputHdf5.
        """
        # NOTE(review): none of these three slots is declared among the slot
        # definitions visible at the top of this class -- confirm they still exist,
        # otherwise this line raises AttributeError instead of AssertionError.
        assert slot == self.InputHdf5 or slot == self.MergerInputHdf5 or slot == self.RelabeledInputHdf5, "Invalid slot for setInSlot(): {}".format( slot.name )
    
    def _createHypothesesGraph(self):
        '''
        Build a hypotheses graph from the values currently stored in the Parameters slot.

        The progress visitor is only handed to the graph when WITH_HYTRA is set.

        :return: the newly constructed IlastikHypothesesGraph
        '''
        settings = self.Parameters.value
        time_range = range(settings['time_range'][0], settings['time_range'][1] + 1)
        x_range = settings['x_range']
        y_range = settings['y_range']
        z_range = settings['z_range']
        size_range = settings['size_range']
        scales = settings['scales']
        withDivisions = settings['withDivisions']
        withClassifierPrior = settings['withClassifierPrior']
        maxDist = settings['maxDist']
        maxObj = settings['maxObj']
        divThreshold = settings['divThreshold']
        max_nearest_neighbors = settings['max_nearest_neighbors']
        borderAwareWidth = settings['borderAwareWidth']

        traxelstore = self._generate_traxelstore(
            time_range, x_range, y_range, z_range, size_range,
            scales[0], scales[1], scales[2],
            with_div=withDivisions,
            with_classifier_prior=withClassifierPrior)

        # Field of view: starts at the origin and ends at the scaled upper
        # bound of each spatial range; the time interval is half-open.
        xshape, yshape, zshape = (x_range[1], y_range[1], z_range[1])
        t_begin = time_range[0]
        t_end = time_range[-1] + 1
        xscale, yscale, zscale = scales
        fieldOfView = FieldOfView(t_begin, 0, 0, 0, t_end,
                                  xscale * (xshape - 1),
                                  yscale * (yshape - 1),
                                  zscale * (zshape - 1))

        graph_options = dict(
            probabilityGenerator=traxelstore,
            timeRange=(time_range[0], time_range[-1] + 1),
            maxNumObjects=maxObj,
            numNearestNeighbors=max_nearest_neighbors,
            fieldOfView=fieldOfView,
            withDivisions=withDivisions,
            maxNeighborDistance=maxDist,
            divisionThreshold=divThreshold,
            borderAwareWidth=borderAwareWidth,
        )
        if WITH_HYTRA:
            graph_options['progressVisitor'] = self.progressVisitor

        return IlastikHypothesesGraph(**graph_options)
    
    def _resolveMergers(self, hypothesesGraph, model):
        '''
        Run merger resolution on the hypotheses graph that holds the current solution.

        :param hypothesesGraph: solved hypotheses graph (its referenceTraxelGraph is used when tracking with tracklets)
        :param model: tracking model, used to obtain the time steps to process
        :return: the dictionary returned by the merger resolver's run(), or an empty dict if the graph contains no mergers
        '''
        logger.info("Resolving mergers.")

        parameters = self.Parameters.value
        if parameters['withTracklets']:
            originalGraph = hypothesesGraph.referenceTraxelGraph
        else:
            originalGraph = hypothesesGraph

        # Enable full graph computation for animal tracking workflow
        # TODO: Setting this parameter outside of the track() function (on AnimalConservationTrackingWorkflow) is not desirable
        withFullGraph = bool('withAnimalTracking' in parameters and parameters['withAnimalTracking'])
        if withFullGraph:
            logger.info("Computing full graph on merger resolver (Only enabled on animal tracking workflow)")

        mergerResolver = IlastikMergerResolver(originalGraph, pluginPaths=self.pluginPaths, withFullGraph=withFullGraph)

        # Check if graph contains mergers, otherwise skip merger resolving
        if not mergerResolver.mergerNum:
            logger.info("Graph contains no mergers. Skipping merger resolving.")
            return {}

        # Fit and refine merger nodes using a GMM.
        # This is done per time step in order to avoid loading the whole video into RAM.
        traxelIdPerTimestepToUniqueIdMap, _ = getMappingsBetweenUUIDsAndTraxels(model)
        timesteps = sorted(int(t) for t in traxelIdPerTimestepToUniqueIdMap.keys())

        timeIndex = self.LabelImage.meta.axistags.index('t')
        numTimeStep = len(timesteps)
        for count, timestep in enumerate(timesteps, 1):
            self.progressVisitor.showProgress(count/float(numTimeStep))

            # Request exactly one frame of the label image
            frame_roi = [slice(None)] * len(self.LabelImage.meta.shape)
            frame_roi[timeIndex] = slice(timestep, timestep+1)
            labelImage = self.LabelImage[tuple(frame_roi)].wait()
            frame = labelImage[0, ..., 0]

            # Get coordinates for object IDs in label image. Used by GMM merger fit.
            objectIds = vigra.analysis.unique(frame)
            maxObjectId = max(objectIds)

            coordinatesForIds = {}

            # Run requests to get object ID coordinates
            pool = RequestPool()
            for objectId in objectIds:
                pool.add(Request(partial(mergerResolver.getCoordinatesForObjectId, coordinatesForIds, frame, timestep, objectId)))
            pool.wait()

            # Fit mergers and store fit info in nodes
            if coordinatesForIds:
                mergerResolver.fitAndRefineNodesForTimestep(coordinatesForIds, maxObjectId, timestep)

        self.parent.parent.trackingApplet.progressSignal.emit(100)

        # Compute object features, re-run flow solver, update model and result, and get merger dictionary
        return mergerResolver.run()

    def raiseException(self, progressWindow, str):
        """
        Tell the progress window (if there is one) that tracking is over, then raise.

        :param progressWindow: progress window to notify, or None
        :param str: message of the raised exception
        :raises Exception: always
        """
        if progressWindow is None:
            raise Exception (str)
        progressWindow.onTrackDone()
        raise Exception (str)

    def raiseDatasetConstraintError(self, progressWindow, titleStr, str):
        """
        Tell the progress window (if there is one) that tracking is over, then
        raise a DatasetConstraintError.

        :param progressWindow: progress window to notify, or None
        :param titleStr: first argument of the DatasetConstraintError
        :param str: second argument of the DatasetConstraintError
        :raises DatasetConstraintError: always
        """
        if progressWindow is None:
            raise DatasetConstraintError(titleStr, str)
        progressWindow.onTrackDone()
        raise DatasetConstraintError(titleStr, str)

    def track(self,
            time_range,
            x_range,
            y_range,
            z_range,
            size_range=(0, 100000),
            x_scale=1.0,
            y_scale=1.0,
            z_scale=1.0,
            maxDist=30,     
            maxObj=2,       
            divThreshold=0.5,
            avgSize=[0],                        
            withTracklets=False,
            sizeDependent=True,
            detWeight=10.0,
            divWeight=10.0,
            transWeight=10.0,
            withDivisions=True,
            withOpticalCorrection=True,
            withClassifierPrior=False,
            ndim=3,
            cplex_timeout=None,
            withMergerResolution=True,
            borderAwareWidth = 0.0,
            withArmaCoordinates = True,
            appearance_cost = 500,
            disappearance_cost = 500,
            motionModelWeight=10.0,
            force_build_hypotheses_graph = False,
            max_nearest_neighbors = 1,
            numFramesPerSplit=0,
            withBatchProcessing = False,
            solverName="Flow-based",
            progressWindow=None,
            progressVisitor=CommandLineProgressVisitor()
            ):
        """
        Main conservation tracking function. Runs tracking solver, generates hypotheses graph, and resolves mergers.

        Steps, in order:
          1. store the arguments in the Parameters slot,
          2. validate the classifier inputs (only if withClassifierPrior),
          3. build the hypotheses graph (optionally as tracklet graph) and the tracking model,
          4. run the selected solver,
          5. resolve mergers (only if withMergerResolution),
          6. compute lineages, publish HypothesesGraph / ResolvedMergers and mark all image outputs dirty.

        :param time_range: iterable of time steps; only its min and max are stored
        :param x_range, y_range, z_range: spatial ranges, stored unchanged in the parameters
        :param size_range: (min, max) object size; the minimum is raised to at least maxObj
        :param x_scale, y_scale, z_scale: stored together as parameters['scales']
        :param numFramesPerSplit: if non-zero, the flow-based solver is run on frame splits
        :param solverName: 'Flow-based' or 'ILP'
        :param progressWindow: only used when WITH_HYTRA is set; its onTrackDone() is called when tracking finishes or fails via raiseException/raiseDatasetConstraintError
        :param progressVisitor: only used when WITH_HYTRA is set, otherwise a DefaultProgressVisitor is used
        :return: the result returned by the solver

        The remaining keyword arguments are stored in the Parameters dictionary, except
        ndim, motionModelWeight, force_build_hypotheses_graph and withBatchProcessing,
        which are not referenced in this method.

        :raises Exception: if the Parameters slot is not ready
        :raises DatasetConstraintError: if withClassifierPrior is set and the object count classifier inputs are missing or inconsistent with maxObj
        :raises ValueError: if solverName does not select an available solver
        """

        # NOTE(review): the default progressVisitor instance is created once, when the
        # method is defined, and is therefore shared between calls.
        if WITH_HYTRA:
            self.progressWindow = progressWindow
            self.progressVisitor=progressVisitor
        else:
            self.progressWindow = None
            self.progressVisitor = DefaultProgressVisitor()

        if not self.Parameters.ready():
            self.raiseException(self.progressWindow, "Parameter slot is not ready")
        
        # It is assumed that the self.Parameters object is changed only at this
        # place (ugly assumption). The dictionary held by the slot is updated in
        # place below and written back with check_changed=False.
        parameters = self.Parameters.value

        parameters['maxDist'] = maxDist
        parameters['maxObj'] = maxObj
        parameters['divThreshold'] = divThreshold
        parameters['avgSize'] = avgSize
        parameters['withTracklets'] = withTracklets
        parameters['sizeDependent'] = sizeDependent
        parameters['detWeight'] = detWeight
        parameters['divWeight'] = divWeight
        parameters['transWeight'] = transWeight
        parameters['withDivisions'] = withDivisions
        parameters['withOpticalCorrection'] = withOpticalCorrection
        parameters['withClassifierPrior'] = withClassifierPrior
        parameters['withMergerResolution'] = withMergerResolution
        parameters['borderAwareWidth'] = borderAwareWidth
        parameters['withArmaCoordinates'] = withArmaCoordinates
        parameters['appearanceCost'] = appearance_cost
        parameters['disappearanceCost'] = disappearance_cost       
        parameters['scales'] = [x_scale, y_scale, z_scale]
        parameters['time_range'] = [min(time_range), max(time_range)]
        parameters['x_range'] = x_range
        parameters['y_range'] = y_range
        parameters['z_range'] = z_range
        parameters['max_nearest_neighbors'] = max_nearest_neighbors
        parameters['numFramesPerSplit'] = numFramesPerSplit
        parameters['solver'] = str(solverName)

        # Set a size range with a minimum area equal to the max number of objects (since the GMM throws an error if we try to fit more gaussians than the number of pixels in the object)
        size_range = (max(maxObj, size_range[0]), size_range[1])
        parameters['size_range'] = size_range

        if cplex_timeout:
            parameters['cplex_timeout'] = cplex_timeout
        else:
            parameters['cplex_timeout'] = ''
            # The local value is not read again in this method.
            cplex_timeout = float(1e75)
        
        self.Parameters.setValue(parameters, check_changed=False)
        
        # The classifier prior needs detection probabilities with at least maxObj + 1 classes.
        if withClassifierPrior:
            if not self.DetectionProbabilities.ready() or len(self.DetectionProbabilities([0]).wait()[0]) == 0:
                self.raiseDatasetConstraintError(self.progressWindow, 'Tracking', 'Classifier not ready yet. Did you forget to train the Object Count Classifier?')
            if not self.NumLabels.ready() or self.NumLabels.value < (maxObj + 1):
                self.raiseDatasetConstraintError(self.progressWindow, 'Tracking', 'The max. number of objects must be consistent with the number of labels given in Object Count Classification.\n' +\
                    'Check whether you have (i) the correct number of label names specified in Object Count Classification, and (ii) provided at least ' +\
                    'one training example for each class.')
            if len(self.DetectionProbabilities([0]).wait()[0][0]) < (maxObj + 1):
                self.raiseDatasetConstraintError(self.progressWindow, 'Tracking', 'The max. number of objects must be consistent with the number of labels given in Object Count Classification.\n' +\
                    'Check whether you have (i) the correct number of label names specified in Object Count Classification, and (ii) provided at least ' +\
                    'one training example for each class.')

        # Reads the values just written to the Parameters slot.
        hypothesesGraph = self._createHypothesesGraph()
        hypothesesGraph.allowLengthOneTracks = True

        if withTracklets:
            hypothesesGraph = hypothesesGraph.generateTrackletGraph()

        hypothesesGraph.insertEnergies()
        trackingGraph = hypothesesGraph.toTrackingGraph()
        trackingGraph.convexifyCosts()
        model = trackingGraph.model
        model['settings']['allowLengthOneTracks'] = True

        # NOTE: this overrides the detWeight argument, so the caller's value only
        # ends up in the Parameters slot and never in the solver weights.
        detWeight = 10.0 # FIXME: Should we store this weight in the parameters slot?
        weights = trackingGraph.weightsListToDict([transWeight, detWeight, divWeight, appearance_cost, disappearance_cost])

        stepStr = "Tracking solver"
        self.progressVisitor.showState(stepStr)
        self.progressVisitor.showProgress(0)

        # dpct / mht are tested for truthiness, so a missing solver module ends up in the ValueError branch.
        if solverName == 'Flow-based' and dpct:
            if numFramesPerSplit:
                # Run solver with frame splits (split, solve, and stitch video to improve running-time)
                from hytra.core.splittracking import SplitTracking 
                result = SplitTracking.trackFlowBasedWithSplits(model, weights, numFramesPerSplit=numFramesPerSplit)
            else:
                result = dpct.trackFlowBased(model, weights)

        elif solverName == 'ILP' and mht:
            result = mht.track(model, weights)
        else:
            # NOTE(review): unlike raiseException(), this does not call progressWindow.onTrackDone().
            raise ValueError("Invalid tracking solver selected")

        self.progressVisitor.showProgress(1.0)
        # Insert the solution into the hypotheses graph and from that deduce the lineages
        if hypothesesGraph:
            hypothesesGraph.insertSolution(result)
            
        # Merger resolution
        resolvedMergersDict = {}
        if withMergerResolution:
            stepStr = "Merger resolution"
            self.progressVisitor.showState(stepStr)
            resolvedMergersDict = self._resolveMergers(hypothesesGraph, model)

        # Set value of resolved mergers slot (Should be empty if mergers are disabled)
        self.ResolvedMergers.setValue(resolvedMergersDict, check_changed=False)
                
        # Computing tracking lineage IDs from within Hytra
        hypothesesGraph.computeLineage()

        if self.progressWindow is not None:
            self.progressWindow.onTrackDone()

        # Uncomment to export a hypotheses graph diagram
        #logger.info("Exporting hypotheses graph diagram")
        #from hytra.util.hypothesesgraphdiagram import HypothesesGraphDiagram
        #hgv = HypothesesGraphDiagram(hypothesesGraph._graph, timeRange=(0, 10), fileName='HypothesesGraph.png' )
                
        # Set value of hypotheses graph slot (use referenceTraxelGraph if using tracklets)
        hypothesesGraph = hypothesesGraph.referenceTraxelGraph if withTracklets else hypothesesGraph
        self.HypothesesGraph.setValue(hypothesesGraph, check_changed=False)

        # Set all the output slots dirty (See execute() function)
        self.Output.setDirty()
        self.MergerOutput.setDirty()
        self.RelabeledImage.setDirty()

        return result

    def propagateDirty(self, inputSlot, subindex, roi):
        """
        React to a dirty input slot.

        * LabelImage: the same roi of Output is marked dirty.
        * HypothesesGraph / ResolvedMergers: nothing is propagated.
        * NumLabels: if a GUI is available and NumLabels is ready and greater
          than one, the GUI's max-objects box is set to NumLabels - 1.
        """
        if inputSlot is self.LabelImage:
            self.Output.setDirty(roi)
            return

        if inputSlot is self.HypothesesGraph or inputSlot is self.ResolvedMergers:
            return

        if inputSlot == self.NumLabels:
            applet = self.parent.parent.trackingApplet
            gui_available = applet._gui and applet._gui.currentGui()
            if gui_available and self.NumLabels.ready() and self.NumLabels.value > 1:
                applet._gui.currentGui()._drawer.maxObjectsBox.setValue(self.NumLabels.value-1)

    def _labelMergers(self, volume, time, offset):
        """
        Relabel the resolved mergers of one frame with their new IDs, using the
        GMM fits stored in the ResolvedMergers slot and the merger resolver plugin.

        :param volume: label volume of the frame; handed to the plugin for updating
        :param time: time step of the frame
        :param offset: forwarded to the plugin's updateLabelImage()
        :return: the ``volume`` object that was passed in
        """
        mergersPerFrame = self.ResolvedMergers.value

        # Nothing was resolved in this frame
        if time not in mergersPerFrame:
            return volume

        for label in vigra.analysis.unique(volume):
            if label not in mergersPerFrame[time]:
                continue
            merger = mergersPerFrame[time][label]
            fits = merger['fits']
            newIds = merger['newIds']
            self.mergerResolverPlugin.updateLabelImage(volume, label, fits, newIds, offset=offset)

        return volume

    def _labelLineageIds(self, volume, time, onlyMergers=False):
        """
        Label every object in the volume of the given time frame with the lineage ID it belongs to.
        If onlyMergers is True, only segments that belong to a merger are kept, everything else is set to zero.

        :param volume: label volume of one time frame
        :param time: time step of the frame
        :param onlyMergers: restrict the output to mergers
        :return: the relabeled volume, where 0 means background, 1 means false detection, and all higher numbers indicate lineages
        """
        graph = self.HypothesesGraph.value

        # No tracking result available yet
        if not graph:
            return np.zeros_like(volume)

        resolvedMergersDict = self.ResolvedMergers.value

        # Lookup table: label -> lineage ID (labels that are not mapped stay 0)
        indexMapping = np.zeros(np.amax(volume) + 1, dtype=volume.dtype)

        labels = vigra.analysis.unique(volume)

        # Reduce labels to the ones that contain mergers
        if onlyMergers and resolvedMergersDict:
            if time in resolvedMergersDict:
                # Keep only the IDs that were newly assigned by merger resolution
                newIds = []
                for _, nodeDict in resolvedMergersDict[time].items():
                    newIds.extend(nodeDict['newIds'])
                labels = [label for label in labels if label in newIds]
            else:
                labels = []
        elif onlyMergers:
            # Without resolved mergers, a merger is a node with a value above one
            labels = [label for label in labels
                      if label > 0
                      and graph.hasNode((time,label))
                      and graph._graph.node[(time,label)]['value'] > 1]

        # Map labels to corresponding lineage IDs
        for label in labels:
            if label > 0 and graph.hasNode((time,label)):
                lineage_id = graph.getLineageId(time, label)
                indexMapping[label] = 1 if lineage_id is None else lineage_id

        return indexMapping[volume]
 
 
    def _setupRelabeledFeatureSlot(self, original_feature_slot):
        """
        Create a feature extraction operator for the relabeled image.

        When exporting after merger resolving, the stored object features are
        not up to date for the relabeled objects, so they are recomputed here.

        :param original_feature_slot: slot connected to the new operator's OriginalRegionFeatures input
        :return: the new OpRelabeledMergerFeatureExtraction operator (a child of this operator)
        """
        from ilastik.applets.trackingFeatureExtraction import config

        opFeatures = OpRelabeledMergerFeatureExtraction(parent=self)
        opFeatures.RawImage.connect(self.RawImage)
        opFeatures.LabelImage.connect(self.LabelImage)
        opFeatures.RelabeledImage.connect(self.RelabeledImage)
        opFeatures.OriginalRegionFeatures.connect(original_feature_slot)

        # Default vigra features plus the ones needed for object count classification
        object_count_features = config.selected_features_objectcount[config.features_vigra_name]
        vigra_features = list((set(config.vigra_features)).union(object_count_features))
        feature_names_vigra = {
            config.features_vigra_name: { name: {} for name in vigra_features }
        }
        opFeatures.FeatureNames.setValue(feature_names_vigra)

        return opFeatures
                     

    def exportPlugin(self, filename, plugin, checkOverwriteFiles=False):
        """
        Export the current tracking solution with the given export plugin.

        The feature slot and label image handed to the plugin depend on the
        stored parameters: relabeled features/image after merger resolution,
        features with division features when tracking with divisions, plain
        object features otherwise.

        :param filename: export target, handed to the plugin
        :param plugin: export plugin providing checkFilesExist(filename) and export(...)
        :param checkOverwriteFiles: if True, nothing is exported when plugin.checkFilesExist(filename) is true
        :return: False if the export was skipped to avoid overwriting files, True after a successful export
        :raises RuntimeError: if the plugin reports that the export failed
        """
        # Do not export if we would otherwise overwrite files. This is decided
        # first, so that a refused export does not create a feature extraction
        # operator that is never used.
        if checkOverwriteFiles and plugin.checkFilesExist(filename):
            return False

        with_divisions = self.Parameters.value["withDivisions"] if self.Parameters.ready() else False
        with_merger_resolution = self.Parameters.value["withMergerResolution"] if self.Parameters.ready() else False

        # Create opRegionFeatures to extract features of relabeled volume
        if with_merger_resolution:
            parameters = self.Parameters.value

            # Use simple relabeled merger feature slot configuration instead of opRelabeledMergerFeatureExtraction
            # This is faster for videos with few mergers and few number of objects per frame
            # (Currently disabled; the intended condition is kept in the trailing comment.)
            if False:#'withAnimalTracking' in parameters and parameters['withAnimalTracking']:  
                logger.info('Setting relabeled merger feature slots for animal tracking')
                from ilastik.applets.trackingFeatureExtraction import config

                self._opRegionFeatures = OpRegionFeatures(parent=self)
                self._opRegionFeatures.RawVolume.connect(self.RawImage)
                self._opRegionFeatures.LabelVolume.connect(self.RelabeledImage)

                vigra_features = list((set(config.vigra_features)).union(config.selected_features_objectcount[config.features_vigra_name]))
                feature_names_vigra = {}
                feature_names_vigra[config.features_vigra_name] = { name: {} for name in vigra_features }
                self._opRegionFeatures.Features.setValue(feature_names_vigra)

                self._opAdaptTimeListRoi = OpAdaptTimeListRoi(parent=self)
                self._opAdaptTimeListRoi.Input.connect(self._opRegionFeatures.Output)

                object_feature_slot = self._opAdaptTimeListRoi.Output
            # Use opRelabeledMergerFeatureExtraction for cell tracking
            else:
                # NOTE(review): a new child operator is created on every export -- confirm whether it should be cleaned up afterwards.
                opRelabeledRegionFeatures = self._setupRelabeledFeatureSlot(self.ObjectFeatures)
                object_feature_slot = opRelabeledRegionFeatures.RegionFeatures

            label_image = self.RelabeledImage

        # Use ObjectFeaturesWithDivFeatures slot
        elif with_divisions:
            object_feature_slot = self.ObjectFeaturesWithDivFeatures
            label_image = self.LabelImage
        # Use ObjectFeatures slot only
        else:
            object_feature_slot = self.ObjectFeatures
            label_image = self.LabelImage

        hypothesesGraph = self.HypothesesGraph.value

        if not plugin.export(filename, hypothesesGraph, object_feature_slot, label_image, self.RawImage):
            raise RuntimeError('Exporting tracking solution with plugin failed')
        return True

    def get_table_export_settings(self):
        """Return a pair of ``None`` values; no table export settings are provided here."""
        # TODO: remove once tracking is hytra-only
        no_settings = (None, None)
        return no_settings

    def _checkConstraints(self, *args):
        """
        Validate the input images as soon as they become ready.

        Registered as notifyReady callback for RawImage and LabelImage; the
        callback arguments are ignored.

        :raises DatasetConstraintError: if a ready image has fewer than 2 time
            steps, or if both images are ready and their shapes differ in any
            axis other than the channel axis
        """
        if self.RawImage.ready():
            rawTaggedShape = self.RawImage.meta.getTaggedShape()
            if rawTaggedShape['t'] < 2:
                raise DatasetConstraintError(
                    "Tracking",
                    "For tracking, the dataset must have a time axis with at least 2 images.   " \
                    "Please load time-series data instead. See user documentation for details.")

        if self.LabelImage.ready():
            segmentationTaggedShape = self.LabelImage.meta.getTaggedShape()
            if segmentationTaggedShape['t'] < 2:
                raise DatasetConstraintError(
                    "Tracking",
                    "For tracking, the dataset must have a time axis with at least 2 images.   " \
                    "Please load time-series data instead. See user documentation for details.")

        if self.RawImage.ready() and self.LabelImage.ready():
            # The channel axis is excluded from the comparison
            rawTaggedShape['c'] = None
            segmentationTaggedShape['c'] = None
            if dict(rawTaggedShape) != dict(segmentationTaggedShape):
                # Report the shape of LabelImage, the slot that was actually compared.
                # (This used to read self.BinaryImage, which is not a slot of this
                # operator and turned the constraint error into an AttributeError.)
                raise DatasetConstraintError("Tracking",
                                             "For tracking, the raw data and the prediction maps must contain the same " \
                                             "number of timesteps and the same shape.   " \
                                             "Your raw image has a shape of (t, x, y, z, c) = {}, whereas your prediction image has a " \
                                             "shape of (t, x, y, z, c) = {}" \
                                             .format(self.RawImage.meta.shape, self.LabelImage.meta.shape))

    def _generate_traxelstore(self,
                              time_range,
                              x_range,
                              y_range,
                              z_range,
                              size_range,
                              x_scale=1.0,
                              y_scale=1.0,
                              z_scale=1.0,
                              with_div=False,
                              with_local_centers=False,
                              with_classifier_prior=False):
        """
        Build a ``ProbabilityGenerator`` with one ``Traxel`` per object that passes the filters.

        Objects whose bounding box lies outside ``x_range``/``y_range``/``z_range``
        or whose size lies outside ``size_range`` are skipped. Their labels are
        collected per frame and published on the ``FilteredLabels`` slot.

        :param time_range: time steps requested from the feature and probability slots;
            ``time_range[0]`` is used as the offset of the ``FilteredLabels`` keys
        :param x_range: pair ``(lower, upper)``; an object is dropped if its maximum x is
            below ``lower`` or its minimum x is at or above ``upper``
        :param y_range: same as ``x_range`` for the y axis
        :param z_range: same as ``x_range`` for the z axis (2d objects use z = 0)
        :param size_range: pair ``(min, max)``; objects with ``Count`` outside ``[min, max)`` are dropped
        :param x_scale: value handed to ``Traxel.set_x_scale``
        :param y_scale: value handed to ``Traxel.set_y_scale``
        :param z_scale: value handed to ``Traxel.set_z_scale``
        :param with_div: attach the clamped division probability as feature ``divProb``
        :param with_local_centers: attach ``localCentersX/Y/Z`` features
        :param with_classifier_prior: attach the clamped detection probabilities as feature ``detProb``
        :return: the filled ``ProbabilityGenerator``
        :raises DatasetConstraintError: if a required classifier has not been trained or the
            region centers are neither 2d nor 3d
        """
        # Probabilities are clamped to this interval before they are stored.
        min_prob = 0.0000001
        max_prob = 0.99999999

        def clamp(value):
            value = float(value)
            if value < min_prob:
                return min_prob
            if value > max_prob:
                return max_prob
            return value

        logger.info("generating traxels")
        traxelstore = ProbabilityGenerator()

        logger.info("fetching region features and division probabilities")
        feats = self.ObjectFeatures(time_range).wait()

        if with_div:
            if not self.DivisionProbabilities.ready() or len(self.DivisionProbabilities([0]).wait()[0]) == 0:
                msgStr = "\nDivision classifier has not been trained! " + \
                         "Uncheck divisible objects if your objects don't divide or " + \
                         "go back to the Division Detection applet and train it."
                raise DatasetConstraintError("Tracking", msgStr)
            divProbs = self.DivisionProbabilities(time_range).wait()

        if with_local_centers:
            localCenters = self.RegionLocalCenters(time_range).wait()

        if with_classifier_prior:
            if not self.DetectionProbabilities.ready() or len(self.DetectionProbabilities([0]).wait()[0]) == 0:
                msgStr = "\nObject count classifier has not been trained! " + \
                         "Go back to the Object Count Classification applet and train it."
                raise DatasetConstraintError("Tracking", msgStr)
            detProbs = self.DetectionProbabilities(time_range).wait()

        logger.info("filling traxelstore")

        filtered_labels = {}
        numTimeStep = len(feats.keys())

        stepStr = "Creating traxel store"
        self.progressVisitor.showState(stepStr + "                              ")

        for countT, t in enumerate(feats.keys(), start=1):
            self.progressVisitor.showProgress(countT / float(numTimeStep))

            frame_feats = feats[t][default_features_key]
            rc = frame_feats['RegionCenter']
            lower = frame_feats['Coord<Minimum>']
            upper = frame_feats['Coord<Maximum>']
            if rc.size:
                # Row 0 is dropped so that row idx describes the object with label idx + 1.
                rc = rc[1:, ...]
                lower = lower[1:, ...]
                upper = upper[1:, ...]

            ct = frame_feats['Count']
            if ct.size:
                ct = ct[1:, ...]

            logger.debug("at timestep {}, {} traxels found".format(t, rc.shape[0]))
            count = 0
            filtered_labels_at = []
            for idx in range(rc.shape[0]):
                label = int(idx + 1)

                # for 2d data, set z-coordinate to 0:
                if len(rc[idx]) == 2:
                    x, y = rc[idx]
                    z = 0
                    x_lower, y_lower = lower[idx]
                    x_upper, y_upper = upper[idx]
                    z_lower = 0
                    z_upper = 0
                elif len(rc[idx]) == 3:
                    x, y, z = rc[idx]
                    x_lower, y_lower, z_lower = lower[idx]
                    x_upper, y_upper, z_upper = upper[idx]
                else:
                    raise DatasetConstraintError("Tracking", "The RegionCenter feature must have dimensionality 2 or 3.")

                size = ct[idx]

                # Objects outside the field of view or the size range are only
                # recorded as filtered; no traxel is created for them.
                if (x_upper < x_range[0] or x_lower >= x_range[1] or
                        y_upper < y_range[0] or y_lower >= y_range[1] or
                        z_upper < z_range[0] or z_lower >= z_range[1] or
                        size < size_range[0] or size >= size_range[1]):
                    filtered_labels_at.append(label)
                    continue
                count += 1

                traxel = Traxel()
                traxel.Id = label
                traxel.Timestep = int(t)
                traxel.set_x_scale(x_scale)
                traxel.set_y_scale(y_scale)
                traxel.set_z_scale(z_scale)

                # Expects always 3 coordinates, z=0 for 2d data
                traxel.add_feature_array("com", 3)
                for i, v in enumerate([x, y, z]):
                    traxel.set_feature_value('com', i, float(v))

                traxel.add_feature_array("CoordMinimum", 3)
                for i, v in enumerate(lower[idx]):
                    traxel.set_feature_value("CoordMinimum", i, float(v))
                traxel.add_feature_array("CoordMaximum", 3)
                for i, v in enumerate(upper[idx]):
                    traxel.set_feature_value("CoordMaximum", i, float(v))

                if with_div:
                    traxel.add_feature_array("divProb", 2)
                    # idx+1 because rc and ct start from 1, divProbs starts from 0
                    prob = clamp(divProbs[t][idx + 1][1])
                    traxel.set_feature_value("divProb", 0, 1.0 - prob)
                    traxel.set_feature_value("divProb", 1, prob)

                if with_classifier_prior:
                    traxel.add_feature_array("detProb", len(detProbs[t][idx + 1]))
                    for i, v in enumerate(detProbs[t][idx + 1]):
                        traxel.set_feature_value("detProb", i, clamp(v))

                # FIXME: check whether it is 2d or 3d data!
                # (v[2] below requires every local center to have three coordinates)
                if with_local_centers:
                    centers = localCenters[t][idx + 1]
                    traxel.add_feature_array("localCentersX", len(centers))
                    traxel.add_feature_array("localCentersY", len(centers))
                    traxel.add_feature_array("localCentersZ", len(centers))

                    for i, v in enumerate(centers):
                        traxel.set_feature_value("localCentersX", i, float(v[0]))
                        traxel.set_feature_value("localCentersY", i, float(v[1]))
                        traxel.set_feature_value("localCentersZ", i, float(v[2]))

                traxel.add_feature_array("count", 1)
                traxel.set_feature_value("count", 0, float(size))

                logger.info("Adding traxel with ID: {}  {}".format(traxel.Id, t))
                traxelstore.TraxelsPerFrame.setdefault(int(t), {})[label] = traxel

            if filtered_labels_at:
                filtered_labels[str(int(t) - time_range[0])] = filtered_labels_at

            logger.debug("at timestep {}, {} traxels passed filter".format(t, count))

            if count == 0:
                logger.info('Found empty frames for time {}'.format(t))

        self.parent.parent.trackingApplet.progressSignal.emit(100)
        self.FilteredLabels.setValue(filtered_labels, check_changed=True)

        return traxelstore
    
    def isTrackingSolutionAvailable(self):
        """
        check whether the hypotheses graph is filled and contains a tracking solution

        Only the first node of the graph is inspected: a solution is reported
        when that node carries a ``'value'`` attribute. An empty graph, or a
        ``HypothesesGraph`` slot value of another type, yields ``False``.

        :return: True if there is a tracking solution available, False otherwise
        """
        hypothesesGraph = self.HypothesesGraph.value

        from hytra.core.hypothesesgraph import HypothesesGraph
        if not isinstance(hypothesesGraph, HypothesesGraph):
            return False

        if hypothesesGraph.withTracklets:
            hypothesesGraph = hypothesesGraph.referenceTraxelGraph

        # data=True yields (node, attribute dict) pairs. Iterating instead of
        # indexing with [0] works for both list results and node views, and
        # tolerates a graph without nodes.
        first_node = next(iter(hypothesesGraph._graph.nodes(data=True)), None)
        return first_node is not None and 'value' in first_node[1]
class OpInterpolate(Operator):
    """
    Operator that fills regions flagged as missing by interpolating along z.

    ``Output`` starts as a copy of ``InputVolume``. For every channel and time
    point, each labeled region of ``Missing`` is then overwritten with values
    computed from the neighbouring z slices, using ``InterpolationMethod``
    ('cubic', 'linear' or 'constant') or one of its fallbacks.
    """
    InputVolume = InputSlot()
    Missing = InputSlot()
    InterpolationMethod = InputSlot(value="cubic")

    Output = OutputSlot()

    # number of intact slices required on each side of a gap, per method
    _requiredMargin = {"cubic": 2, "linear": 1, "constant": 0}
    # largest gap (in slices) a method is applied to; larger gaps use the fallback
    _maxInterpolationDistance = {
        "cubic": 1,
        "linear": np.inf,
        "constant": np.inf
    }
    # method that is tried next when a method is not applicable
    _fallbacks = {"cubic": "linear", "linear": "constant", "constant": None}

    def propagateDirty(self, slot, subindex, roi):
        """
        Forward dirtiness of the inputs to ``Output``.

        A change of the interpolation method invalidates the whole output,
        because its roi describes the method slot and not the volume.
        """
        if slot == self.InterpolationMethod:
            self.Output.setDirty()
        else:
            self.Output.setDirty(roi)

    def setupOutputs(self):
        """Copy the input meta data to ``Output`` and remember the integer range of the dtype."""
        # Output has the same shape/axes/dtype/drange as input
        self.Output.meta.assignFrom(self.InputVolume.meta)

        try:
            self._iinfo = np.iinfo(self.InputVolume.meta.dtype)
        except ValueError:
            # not integer type, no casting needed
            self._iinfo = None

        assert (
            self.InputVolume.meta.getTaggedShape() ==
            self.Missing.meta.getTaggedShape()
        ), "InputVolume and Missing must have the same shape " + "({} vs {})".format(
            self.InputVolume.meta.getTaggedShape(),
            self.Missing.meta.getTaggedShape())

    def execute(self, slot, subindex, roi, result):
        """Write the input for ``roi`` into ``result`` and interpolate every missing region in place."""

        # start from the unmodified input; missing regions are overwritten below
        result[:] = self.InputVolume.get(roi).wait()

        resultZYXCT = vigra.taggedView(
            result, self.InputVolume.meta.axistags).withAxes(*"zyxct")
        missingZYXCT = vigra.taggedView(
            self.Missing.get(roi).wait(),
            self.Missing.meta.axistags).withAxes(*"zyxct")

        for t in range(resultZYXCT.shape[4]):
            for c in range(resultZYXCT.shape[3]):
                missingLabeled = vigra.analysis.labelVolumeWithBackground(
                    missingZYXCT[..., c, t])
                maxLabel = missingLabeled.max()
                # every labeled region is interpolated on its own
                for i in range(1, maxLabel + 1):
                    self._interpolate(resultZYXCT[..., c, t],
                                      missingLabeled == i)

        return result

    def _cast(self, x):
        """
        casts the array to expected range (i.e. 0..255 for uint8 types, ...)

        Values are left untouched when the input dtype is not an integer type.
        """
        if self._iinfo is not None:
            x = np.where(x > self._iinfo.max, self._iinfo.max, x)
            x = np.where(x < self._iinfo.min, self._iinfo.min, x)
        return x

    def _interpolate(self, volume, missing, method=None):
        """
        interpolates in z direction

        The whole y/x bounding box of the missing region is overwritten in
        every affected slice of ``volume``; nothing is returned.

        :param volume: 3d block with axistags 'zyx'
        :type volume: array-like
        :param missing: True where data is missing
        :type missing: bool, 3d block with axistags 'zyx'
        :param method: 'cubic' or 'linear' or 'constant'
        :type method: str
        :raises AssertionError: for an unknown method, wrongly shaped data, or when
            the margin is too small and no fallback method is left
        """

        method = self.InterpolationMethod.value if method is None else method
        # sanity checks
        assert method in self._requiredMargin, "Unknown method '{}'".format(method)

        assert (volume.axistags.index("z") == 0
                and volume.axistags.index("y") == 1
                and volume.axistags.index("x") == 2 and len(volume.shape)
                == 3), "Data must be 3d with z as first axis."

        # number and z-location of missing slices (z-axis is at zero)
        black_z_ind, black_y_ind, black_x_ind = np.where(missing)

        if len(black_z_ind) == 0:  # no need for interpolation
            return

        margin = self._requiredMargin[method]
        fallback = self._fallbacks[method]

        # number of consecutive slices spanned by the missing region
        n = black_z_ind.max() - black_z_ind.min() + 1

        if n > self._maxInterpolationDistance[method]:
            self._interpolate(volume, missing, fallback)
            return

        # indices with respect to the required margin around the missing values
        minZ = black_z_ind.min() - margin
        maxZ = black_z_ind.max() + margin

        if not (minZ > -1 and maxZ < volume.shape[0]):
            # this method is not applicable, try another one
            logger.warning(" ".join((
                "Margin not big enough for interpolation ",
                "(need at least {} pixels for '{}')".format(margin, method),
            )))

            if fallback is None:
                # raised explicitly so that the check survives `python -O`
                raise AssertionError(" ".join((
                    "Margin not big enough for interpolation",
                    "(need at least {} pixels for '{}')".format(margin, method),
                    "and no fallback available",
                )))

            logger.warning("Falling back to method '{}'".format(fallback))
            self._interpolate(volume, missing, fallback)
            return

        minY, maxY = (black_y_ind.min(), black_y_ind.max())
        minX, maxX = (black_x_ind.min(), black_x_ind.max())

        if method == "linear" or (method == "cubic" and n > 1):
            # do a convex combination of the boundary slices
            xs = np.linspace(0, 1, n + 2)
            left = volume[minZ, minY:maxY + 1, minX:maxX + 1]
            right = volume[maxZ, minY:maxY + 1, minX:maxX + 1]

            for i in range(n):
                # interpolate every slice
                volume[minZ + i + 1, minY:maxY + 1,
                       minX:maxX + 1] = self._cast((1 - xs[i + 1]) * left +
                                                   xs[i + 1] * right)

        elif method == "cubic":
            # interpolation coefficients, computed from the two slices on each
            # side of the gap

            D = np.rollaxis(
                volume[[minZ, minZ + 1, maxZ - 1, maxZ], minY:maxY + 1,
                       minX:maxX + 1], 0, 3)
            F = np.tensordot(D, _cubic_mat(n), ([2], [1]))

            xs = np.linspace(0, 1, n + 2)
            for i in range(n):
                # interpolate every slice
                x = xs[i + 1]
                volume[minZ + i + 2, minY:maxY + 1, minX:maxX +
                       1] = self._cast(F[..., 0] + F[..., 1] * x +
                                       F[..., 2] * x**2 + F[..., 3] * x**3)

        else:  # constant
            if minZ > 0:
                # copy the last good slice below the gap into every missing slice
                for i in range(maxZ - minZ + 1):
                    volume[minZ + i, minY:maxY + 1,
                           minX:maxX + 1] = volume[minZ - 1, minY:maxY + 1,
                                                   minX:maxX + 1]
            elif maxZ < volume.shape[0] - 1:
                # no slice below the gap: copy the first good slice above it
                for i in range(maxZ - minZ + 1):
                    volume[minZ + i, minY:maxY + 1,
                           minX:maxX + 1] = volume[maxZ + 1, minY:maxY + 1,
                                                   minX:maxX + 1]
            else:
                # nothing to do for empty block
                logger.warning(
                    "Not enough data for interpolation leaving slice as is ..."
                )
Example #18
0
class OpStackLoader(Operator):
    """Imports an image stack.

    Note: This operator does NOT cache the images, so direct access
          via the execute() function is very inefficient, especially
          through the Z-axis. Typically, you'll want to connect this
          operator to a cache whose block size is large in the X-Y
          plane.

    :param globstring: A glob string as defined by the glob module. We
        also support the following special extension to globstring
        syntax: A single string can hold a *list* of globstrings.
        The delimiter that separates the globstrings in the list is
        OS-specific via os.path.pathsep.

        For example, on Linux the pathsep is ':', so

            '/a/b/c.txt:/d/e/f.txt:../g/i/h.txt'

        is parsed as

            ['/a/b/c.txt', '/d/e/f.txt', '../g/i/h.txt']

    :param SequenceAxis: Optional. One of 't', 'z' or 'c': the axis along
        which the files are stacked. Defaults to 'z' for single-image files
        and to 't' for multi-image files.
    """

    name = "Image Stack Reader"
    category = "Input"

    globstring = InputSlot()
    SequenceAxis = InputSlot(optional=True)
    stack = OutputSlot()

    class FileOpenError(Exception):
        """Raised when the first file of the stack cannot be opened."""

        def __init__(self, filename):
            self.filename = filename
            self.msg = f"Unable to open file: {filename}"
            super().__init__(self.msg)

    def _sequence_axis(self, default):
        """Return the validated stacking axis, or ``default`` if SequenceAxis is not set."""
        if not self.SequenceAxis.ready():
            return default
        sequence_axis = str(self.SequenceAxis.value)
        # Tuple membership: a substring test would also accept "" or "tz".
        assert sequence_axis in ("t", "z", "c"), f"Invalid sequence axis: {sequence_axis!r}"
        return sequence_axis

    def _check_file(self, fileName):
        """Raise RuntimeError if ``fileName`` differs from the first file in shape or slice count."""
        file_shape = vigra.impex.ImageInfo(fileName).getShape()
        if self.info.getShape() != file_shape:
            raise RuntimeError("not all files have the same shape")
        images_per_file = vigra.impex.numberImages(fileName)
        if self.slices_per_file != images_per_file:
            raise RuntimeError("Not all files have the same number of "
                               "slices")

    def setupOutputs(self):
        """Expand the glob string and derive shape, axistags and dtype from the first file.

        The output is marked NOTREADY when no file matches.

        :raises OpStackLoader.FileOpenError: if the first file cannot be read
        """
        self.fileNameList = self.expandGlobStrings(self.globstring.value)

        num_files = len(self.fileNameList)
        if num_files == 0:
            self.stack.meta.NOTREADY = True
            return

        first_file = self.fileNameList[0]
        try:
            self.info = vigra.impex.ImageInfo(first_file)
            self.slices_per_file = vigra.impex.numberImages(first_file)
        except RuntimeError as e:
            logger.error(str(e))
            raise OpStackLoader.FileOpenError(first_file) from e

        X, Y, C = self.info.getShape()
        if self.slices_per_file == 1:
            sequence_axis = self._sequence_axis(default="z")
            # For stacks of 2D images, we assume xy slices
            if sequence_axis == "c":
                shape = (X, Y, C * num_files)
                axistags = vigra.defaultAxistags("xyc")
            else:
                shape = (num_files, Y, X, C)
                axistags = vigra.defaultAxistags(sequence_axis + "yxc")
        else:
            sequence_axis = self._sequence_axis(default="t")

            # For stacks of 3D volumes, we assume xyz blocks stacked along
            # sequence_axis
            if sequence_axis == "c":
                axistags = vigra.defaultAxistags("czyx")
                shape = (num_files * C, self.slices_per_file, Y, X)
            else:
                if sequence_axis == "z":
                    axistags = vigra.defaultAxistags("ztyxc")
                else:
                    axistags = vigra.defaultAxistags("tzyxc")
                shape = (num_files, self.slices_per_file, Y, X, C)

        self.stack.meta.shape = shape
        self.stack.meta.axistags = axistags
        self.stack.meta.dtype = self.info.getDtype()

    def propagateDirty(self, slot, subindex, roi):
        """Mark the whole output dirty whenever one of the inputs changes."""
        # Both globstring and SequenceAxis determine the entire output.
        self.stack.setDirty()

    def execute(self, slot, subindex, roi, result):
        """Dispatch to the reader that matches the dimensionality of the output."""
        ndim = len(self.stack.meta.shape)
        if ndim == 3:
            return self._execute_3d(roi, result)
        if ndim == 4:
            return self._execute_4d(roi, result)
        if ndim == 5:
            return self._execute_5d(roi, result)
        raise AssertionError(f"Unexpected output shape: {self.stack.meta.shape}")

    def _execute_3d(self, roi, result):
        """Read single-image files that are stacked along the channel axis."""
        traceLogger.debug("OpStackLoader: Execute for: " + str(roi))
        # roi is in xyc order; stacking over c
        x_start, y_start, c_start = roi.start
        x_stop, y_stop, c_stop = roi.stop

        # get C of slice
        C = self.info.getShape()[2]

        # Copy each c-slice one at a time.
        # NOTE(review): the floor divisions select whole files only, so the
        # channel bounds of roi are presumably expected to be multiples of C
        # -- confirm.
        for i, fileName in enumerate(self.fileNameList[c_start // C:c_stop // C]):
            traceLogger.debug(f"Reading image: {fileName}")
            self._check_file(fileName)

            result[:, :, i * C:(i + 1) * C] = vigra.impex.readImage(fileName)[
                x_start:x_stop, y_start:y_stop, :].withAxes(*"xyc")
        return result

    def _execute_4d(self, roi, result):
        """Read a 4d output: single-image files stacked along z/t, or volumes stacked along c."""
        traceLogger.debug("OpStackLoader: Execute for: " + str(roi))
        # roi is in zyxc, tyxc or czyx order, depending on SequenceAxis.
        # The variable names below follow the zyxc case.
        z_start, y_start, x_start, c_start = roi.start
        z_stop, y_stop, x_stop, c_stop = roi.stop

        # get C of slice
        C = self.info.getShape()[2]

        # Copy each z-slice one at a time.
        for result_z, fileName in enumerate(self.fileNameList[z_start:z_stop]):
            traceLogger.debug(f"Reading image: {fileName}")
            self._check_file(fileName)

            if self.stack.meta.axistags.channelIndex == 0:
                # czyx order -> read slice along z (here y)
                for result_y, y in enumerate(range(y_start, y_stop)):
                    result[result_z * C:(result_z + 1) * C, result_y,
                           ...] = vigra.impex.readImage(
                               fileName,
                               index=y)[c_start:c_stop,
                                        x_start:x_stop].withAxes(*"cyx")
            else:
                result[result_z, ...] = vigra.impex.readImage(
                    fileName)[x_start:x_stop, y_start:y_stop,
                              c_start:c_stop].withAxes(*"yxc")
        return result

    def _execute_5d(self, roi, result):
        """Read multi-image files that are stacked along t or z."""
        # Technically, t and z might be switched depending on SequenceAxis.
        # Beware these variable names for t/z might be misleading.
        t_start, z_start, y_start, x_start, c_start = roi.start
        t_stop, z_stop, y_stop, x_stop, c_stop = roi.stop

        # Use *enumerated* range to get global t coords and result t coords
        for result_t, t in enumerate(range(t_start, t_stop)):
            file_name = self.fileNameList[t]
            for result_z, z in enumerate(range(z_start, z_stop)):
                img = vigra.impex.readImage(file_name, index=z)
                result[result_t,
                       result_z, :, :, :] = img[x_start:x_stop, y_start:y_stop,
                                                c_start:c_stop].withAxes(
                                                    *"yxc")
        return result

    @staticmethod
    def expandGlobStrings(globStrings):
        """Return the files matched by each globstring in ``globStrings``.

        :param globStrings: one or more globstrings joined by ``os.path.pathsep``
        :return: list of matches; the matches of each globstring are sorted and
            the groups keep the order of the globstrings
        """
        ret = []
        # Parse list into separate globstrings and combine them
        for globString in globStrings.split(os.path.pathsep):
            ret.extend(sorted(glob.glob(globString.strip())))
        return ret
Example #19
0
class OpDetectMissing(Operator):
    """
    Sub-Operator for detection of missing image content

    Every z slice of the input is cut into patches, a histogram is computed
    per patch, and each patch is classified with ``self.predict``. ``Output``
    holds 1 where a patch was classified as missing and 0 elsewhere.

    ``fit``, ``predict``, ``dumps`` and ``loads`` are used by this class but
    are not defined in this part of it.
    """

    InputVolume = InputSlot()
    PatchSize = InputSlot(value=128)
    HaloSize = InputSlot(value=30)
    DetectionMethod = InputSlot(value="classic")
    NHistogramBins = InputSlot(value=_defaultBinSize)
    OverloadDetector = InputSlot(value="")

    # histograms: ndarray, shape: nHistograms x (NHistogramBins.value + 1)
    # the last column holds the label, i.e. {0: negative, 1: positive}
    TrainingHistograms = InputSlot()

    Output = OutputSlot()
    Detector = OutputSlot(stype=Opaque)

    ### PRIVATE class attributes ###
    _manager = None

    ### PRIVATE attributes ###
    _inputRange = (0, 255)
    _needsTraining = True
    _felzenOpts = {
        "firstSamples": 250,
        "maxRemovePerStep": 0,
        "maxAddPerStep": 250,
        "maxSamples": 1000,
        "nTrainingSteps": 4,
    }

    def __init__(self, *args, **kwargs):
        super(OpDetectMissing, self).__init__(*args, **kwargs)
        self.TrainingHistograms.setValue(_defaultTrainingHistograms())

    def propagateDirty(self, slot, subindex, roi):
        """Update the output dirtiness and the class-wide training flag for a changed input."""
        if slot == self.InputVolume:
            self.Output.setDirty(roi)

        if slot == self.TrainingHistograms:
            OpDetectMissing._needsTraining = True

        if slot == self.NHistogramBins:
            # Training is needed when no detector exists for the new bin
            # count. The flag is never cleared here, so a pending retraining
            # request (e.g. from new histograms) is not lost.
            if not OpDetectMissing._manager.has(self.NHistogramBins.value):
                OpDetectMissing._needsTraining = True

        if slot == self.PatchSize or slot == self.HaloSize:
            self.Output.setDirty()

        if slot == self.OverloadDetector:
            s = self.OverloadDetector.value
            self.loads(s)
            self.Output.setDirty()

    def setupOutputs(self):
        """Configure ``Output`` as a uint8 volume shaped like the input and pick the histogram range."""
        self.Output.meta.assignFrom(self.InputVolume.meta)
        self.Output.meta.dtype = np.uint8

        # determine range of input
        if self.InputVolume.meta.dtype == np.uint8:
            r = (0, 255)
        elif self.InputVolume.meta.dtype == np.uint16:
            r = (0, 65535)
        else:
            # FIXME hardcoded range, use np.iinfo
            r = (0, 255)
        self._inputRange = r

        self.Detector.meta.shape = (1, )

    def execute(self, slot, subindex, roi, result):
        """
        Serve a request on ``Detector`` (serialized detector) or ``Output`` (detection mask).

        :raises AssertionError: if ``DetectionMethod`` is neither 'svm' nor 'classic'
        """

        if slot == self.Detector:
            result = self.dumps()
            return result

        # sanity check
        assert self.DetectionMethod.value in [
            "svm", "classic"
        ], "Unknown detection method '{}'".format(self.DetectionMethod.value)

        # 5d view on the result buffer; writing to the view fills `result`
        resultZYXCT = vigra.taggedView(
            result, self.InputVolume.meta.axistags).withAxes(*"zyxct")

        # acquire data
        data = self.InputVolume.get(roi).wait()
        dataZYXCT = vigra.taggedView(
            data, self.InputVolume.meta.axistags).withAxes(*"zyxct")

        # walk over time and channel axes
        for t in range(dataZYXCT.shape[4]):
            for c in range(dataZYXCT.shape[3]):
                resultZYXCT[..., c, t] = self._detectMissing(dataZYXCT[..., c,
                                                                       t])

        return result

    def _detectMissing(self, data):
        """
        detects missing regions and labels each missing region with 1
        :param data: 3d data with axistags 'zyx'
        :type data: array-like
        :return: uint8 array of the same shape as data
        :raises ValueError: if PatchSize is not positive or HaloSize is negative
        """

        assert (data.axistags.index("z") == 0 and data.axistags.index("y") == 1
                and data.axistags.index("x") == 2
                and len(data.shape) == 3), "Data must be 3d with axis 'zyx'."

        result = np.zeros(data.shape, dtype=np.uint8)

        patchSize = self.PatchSize.value
        haloSize = self.HaloSize.value

        if patchSize is None or not patchSize > 0:
            raise ValueError("PatchSize must be a positive integer")
        if haloSize is None or haloSize < 0:
            raise ValueError("HaloSize must be a non-negative integer")

        nBins = self.NHistogramBins.value
        method = self.DetectionMethod.value

        # walk over slices
        for z in range(data.shape[0]):
            patches, slices = _patchify(data[z, :, :], patchSize, haloSize)
            # one normalized histogram per patch
            hists = np.vstack([
                np.histogram(patch,
                             bins=nBins,
                             range=self._inputRange,
                             density=True)[0] for patch in patches
            ])

            pred = self.predict(hists, method=method)
            for i, p in enumerate(pred):
                if p > 0:
                    # patch is classified as missing
                    result[z, slices[i][0], slices[i][1]] |= 1

        return result

    def train(self, force=False):
        """
        trains with samples drawn from slot TrainingHistograms
        (retrains only if bin size is currently untrained or force is True)

        Nothing is done for the 'classic' method or when sklearn is unavailable.

        :param force: retrain even if a detector for the bin size already exists
        :type force: bool
        """
        nBins = self.NHistogramBins.value

        # return early if unneccessary
        if not force and not OpDetectMissing._needsTraining and OpDetectMissing._manager.has(
                nBins):
            return

        # the classic method needs no training, and svms require sklearn
        if not havesklearn or self.DetectionMethod.value == "classic":
            return

        logger.debug("Training for {} histogram bins ...".format(nBins))

        histograms = self.TrainingHistograms[:].wait()

        logger.debug("Finished loading histogram data of shape {}.".format(
            histograms.shape))

        # check the rank first so that shape[1] is known to exist
        assert (
            len(histograms.shape) == 2
            and histograms.shape[1] >= nBins + 1
        ), "Training data has wrong shape (expected: (n,{}), got: {}.".format(
            nBins + 1, histograms.shape)

        # the label is stored in the column right after the histogram bins
        labels = histograms[:, nBins]
        histograms = histograms[:, :nBins]

        neg_inds = np.where(labels == 0)[0]
        pos_inds = np.setdiff1d(np.arange(len(labels)), neg_inds)

        pos = histograms[pos_inds]
        neg = histograms[neg_inds]

        logger.debug("Starting training with " +
                     "{} negative patches and {} positive patches...".format(
                         len(neg), len(pos)))
        self._felzenszwalbTraining(neg, pos)
        logger.debug("Finished training.")

        OpDetectMissing._needsTraining = False

    def _felzenszwalbTraining(self, negative, positive):
        """
        we want to train on a 'hard' subset of the training data, see
        FELZENSZWALB ET AL.: OBJECT DETECTION WITH DISCRIMINATIVELY TRAINED PART-BASED MODELS (4.4), PAMI 32/9

        :param negative: histograms labeled as negative (index 0 in the internal lists)
        :param positive: histograms labeled as positive (index 1 in the internal lists)
        """

        # TODO sanity checks

        method = self.DetectionMethod.value

        # set options for Felzenszwalb training
        firstSamples = self._felzenOpts["firstSamples"]
        maxRemovePerStep = self._felzenOpts["maxRemovePerStep"]
        maxAddPerStep = self._felzenOpts["maxAddPerStep"]
        maxSamples = self._felzenOpts["maxSamples"]
        nTrainingSteps = self._felzenOpts["nTrainingSteps"]

        # initial choice of training samples
        (initNegative, choiceNegative, _,
         _) = _chooseRandomSubset(negative, min(firstSamples, len(negative)))
        (initPositive, choicePositive, _,
         _) = _chooseRandomSubset(positive, min(firstSamples, len(positive)))

        # setup for parallel training; index 0 is the negative set, 1 the positive set
        samples = [negative, positive]
        choice = [choiceNegative, choicePositive]
        S_t = [initNegative, initPositive]

        finished = [False, False]

        ### BEGIN SUBROUTINE ###
        def felzenstep(x, cache, ind):
            """Update the cache of sample indices of set `ind` with currently misclassified samples."""

            case = ("positive" if ind > 0 else "negative") + " set"
            pred = self.predict(x, method=method)

            # samples whose prediction differs from the set index are 'hard'
            hard = np.where(pred != ind)[0]
            easy = np.setdiff1d(np.arange(len(x)), hard)
            logger.debug(" {}: currently {} hard and {} easy samples".format(
                case, len(hard), len(easy)))

            # shrink the cache
            easyInCache = np.intersect1d(easy, cache) if len(easy) > 0 else []
            if len(easyInCache) > 0:
                (removeFromCache, _, _, _) = _chooseRandomSubset(
                    easyInCache, min(len(easyInCache), maxRemovePerStep))
                cache = np.setdiff1d(cache, removeFromCache)
                logger.debug(" {}: shrunk the cache by {} elements".format(
                    case, len(removeFromCache)))

            # grow the cache
            sizeBefore = len(cache)
            addToCache = _chooseRandomSubset(hard, min(len(hard),
                                                       maxAddPerStep))[0]
            cache = np.union1d(cache, addToCache)
            addedHard = len(cache) - sizeBefore
            logger.debug(" {}: grown the cache by {} elements".format(
                case, addedHard))

            if len(cache) > maxSamples:
                logger.debug(
                    " {}: Cache to big, removing elements.".format(case))
                cache = _chooseRandomSubset(cache, maxSamples)[0]

            # apply the cache
            C = x[cache]

            # the step is finished when no new hard sample entered the cache
            return (C, cache, addedHard == 0)

        ### END SUBROUTINE ###

        ### BEGIN PARALLELIZATION FUNCTION ###
        def partFun(i):
            (C, newChoice, newFinished) = felzenstep(samples[i], choice[i], i)
            S_t[i] = C
            choice[i] = newChoice
            finished[i] = newFinished

        ### END PARALLELIZATION FUNCTION ###

        for k in range(nTrainingSteps):

            logger.debug("Felzenszwalb Training " +
                         "(step {}/{}): {} hard negative samples, {} ".format(
                             k + 1, nTrainingSteps, len(S_t[0]), len(S_t[1])) +
                         "hard positive samples.")
            self.fit(S_t[0], S_t[1], method=method)

            # update both sets in parallel
            pool = RequestPool()

            for i in range(len(S_t)):
                req = Request(partial(partFun, i))
                pool.add(req)

            pool.wait()
            pool.clean()

            if np.all(finished):
                # already have all hard examples in training set
                break

        self.fit(S_t[0], S_t[1], method=method)

        logger.debug(" Finished Felzenszwalb Training.")
Example #20
0
class OpStackWriter(Operator):
    """
    Export the ``Input`` volume as a stack of image files, one file per slice.

    The first non-singleton axis of ``Input`` is the axis that is stepped
    through; every step is requested as a single slice and written with
    ``vigra.impex.writeImage`` to a path built from ``FilepathPattern``.
    The operator has no output slots: call :meth:`run_export` to export.
    """
    name = "Stack File Writer"
    category = "Output"

    # The last two non-singleton axes (except 'c') are the axes of the slices.
    # Re-order the axes yourself if you want an alternative slicing direction
    Input = InputSlot()

    # A complete filepath including a {slice_index} member and a valid file extension.
    FilepathPattern = InputSlot()

    # Added to the {slice_index} in the export filename.
    SliceIndexOffset = InputSlot(value=0)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Receives the progress values forwarded from the streamer in run_export().
        self.progressSignal = OrderedSignal()

    def run_export(self):
        """
        Request the volume in slices (running in parallel), and write each slice to a separate image.

        Raises:
            MemoryError: if the RAM estimate says that not even one slice fits
                into (half of) the currently available memory.
        """
        # Make the directory first if necessary.
        # A pattern without a directory part yields '', which os.makedirs() rejects,
        # so only create a directory when there is one to create.
        export_dir = os.path.split(self.FilepathPattern.value)[0]
        if export_dir and not os.path.exists(export_dir):
            # exist_ok guards against another process creating it in the meantime.
            os.makedirs(export_dir, exist_ok=True)

        # Sliceshape is the same as the input shape, except for the sliced dimension
        tagged_sliceshape = self.Input.meta.getTaggedShape()
        tagged_sliceshape[self._volume_axes[0]] = 1
        slice_shape = list(tagged_sliceshape.values())

        parallel_requests = 4

        # If ram usage info is available, make a better guess about how many requests we can launch in parallel
        ram_usage_per_requested_pixel = self.Input.meta.ram_usage_per_requested_pixel
        if ram_usage_per_requested_pixel is not None:
            pixels_per_slice = numpy.prod(slice_shape)
            if "c" in tagged_sliceshape:
                pixels_per_slice //= tagged_sliceshape["c"]

            ram_usage_per_slice = pixels_per_slice * ram_usage_per_requested_pixel

            # A zero estimate carries no information (and would divide by zero below),
            # so keep the default request count in that case.
            if ram_usage_per_slice > 0:
                # Fudge factor: Reduce RAM usage by a bit
                available_ram = psutil.virtual_memory().available * 0.5

                parallel_requests = int(available_ram // ram_usage_per_slice)

                if parallel_requests < 1:
                    raise MemoryError(
                        "Not enough RAM to export to the selected format. "
                        "Consider exporting to hdf5 (h5).")

        streamer = BigRequestStreamer(self.Input,
                                      roiFromShape(self.Input.meta.shape),
                                      slice_shape, parallel_requests)

        # Write the slices as they come in (possibly out-of-order, but probably not)
        streamer.resultSignal.subscribe(self._write_slice)
        streamer.progressSignal.subscribe(self.progressSignal)

        logger.debug(
            f"Starting Stack Export with slicing shape: {slice_shape}")
        streamer.execute()

    def setupOutputs(self):
        """
        Cache the export axes and the zero-padding width, and validate the inputs.

        Raises:
            AssertionError: if the input does not have 3 non-singleton axes
                (plus an optional channel axis), or if the filepath pattern
                lacks the ``{slice_index}`` field.
        """
        # If stacking XY images in Z-steps,
        #  then self._volume_axes = 'zxy'
        self._volume_axes = self.get_nonsingleton_axes()

        # Check for errors *before* indexing into the axes, so that a degenerate
        # input produces the explanatory message instead of an IndexError.
        assert len(self._volume_axes) == 3 or len(
            self._volume_axes
        ) == 4 and "c" in self._volume_axes[1:], (
            "Exported stacks must have exactly 3 non-singleton dimensions (other than the channel dimension).  "
            "Your stack dimensions are: {}".format(
                self.Input.meta.getTaggedShape()))

        step_axis = self._volume_axes[0]
        max_slice = self.SliceIndexOffset.value + self.Input.meta.getTaggedShape(
        )[step_axis]
        # Number of digits used to zero-pad the slice index in file names.
        self._max_slice_digits = int(math.ceil(math.log10(max_slice + 1)))

        # Test to make sure the filepath pattern includes slice index field
        filepath_pattern = self.FilepathPattern.value
        assert "123456789" in filepath_pattern.format(
            slice_index=123_456_789
        ), ("Output filepath pattern must contain the '{{slice_index}}' field for formatting.\n"
            "Your format was: {}".format(filepath_pattern))

    # No output slots...
    def execute(self, slot, subindex, roi, result):
        pass

    def propagateDirty(self, slot, subindex, roi):
        pass

    def get_nonsingleton_axes(self):
        """Return the keys of the non-singleton axes of ``Input``, in axis order."""
        return self.get_nonsingleton_axes_for_tagged_shape(
            self.Input.meta.getTaggedShape())

    @classmethod
    def get_nonsingleton_axes_for_tagged_shape(cls, tagged_shape):
        """
        Return a tuple with the keys of all axes whose extent is greater than 1.

        The first non-singleton axis is the step axis.
        The last 2 non-channel non-singleton axes will be the axes of the slices.
        An all-singleton shape yields an empty tuple.
        """
        return tuple(key for key, extent in tagged_shape.items() if extent > 1)

    def _write_slice(self, roi, slice_data):
        """
        Write the data from the given roi into a slice image.

        ``roi`` is a (start, stop) pair in the axis order of ``Input`` and must
        cover exactly one step along the step axis.
        """
        step_axis = self._volume_axes[0]
        input_axes = self.Input.meta.getAxisKeys()
        tagged_roi = OrderedDict(zip(input_axes, zip(*roi)))
        # e.g. tagged_roi={ 'x':(0,1), 'y':(3,4), 'z':(10,20) }
        assert tagged_roi[step_axis][1] - tagged_roi[step_axis][
            0] == 1, "Expected roi to be a single slice."
        slice_index = tagged_roi[step_axis][0] + self.SliceIndexOffset.value
        filepattern = self.FilepathPattern.value

        # If the user didn't provide custom formatting for the slice field,
        #  auto-format to include zero-padding
        if "{slice_index}" in filepattern:
            filepattern = filepattern.format(
                slice_index="{" +
                "slice_index:0{}".format(self._max_slice_digits) + "}")
        formatted_path = filepattern.format(slice_index=slice_index)

        # Drop the singleton axes and tag the remaining ones with the slice axes.
        squeezed_data = slice_data.squeeze()
        squeezed_data = vigra.taggedView(
            squeezed_data,
            vigra.defaultAxistags("".join(self._volume_axes[1:])))
        assert len(squeezed_data.shape) == len(self._volume_axes) - 1

        logger.debug("Writing slice: {}".format(formatted_path))
        vigra.impex.writeImage(squeezed_data, formatted_path)
Example #21
0
class OpLayerViewer(Operator):
    """Operator that exposes a ``RawInput`` input slot."""
    name = "OpLayerViewer"
    category = "top-level"

    # The input slot declared on this operator (no stype or default value given).
    RawInput = InputSlot()
Example #22
0
class OpStackToH5Writer(Operator):
    """
    Copy an image stack, selected by a glob string, into an HDF5 dataset.

    The stack is read through an internal ``OpStackLoader`` and written
    z-slice by z-slice into ``hdf5Group`` under the name ``hdf5Path``.
    """
    name = "OpStackToH5Writer"
    category = "IO"

    GlobString = InputSlot(stype="globstring")
    hdf5Group = InputSlot(stype="object")
    hdf5Path = InputSlot(stype="string")

    # Pulling this output is what triggers the stack -> h5 copy.
    WriteImage = OutputSlot(stype="bool")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.progressSignal = OrderedSignal()
        self.opStackLoader = OpStackLoader(parent=self)
        self.opStackLoader.globstring.connect(self.GlobString)

    def setupOutputs(self):
        # The output is a single flag.
        self.WriteImage.meta.shape = (1, )
        self.WriteImage.meta.dtype = object

    def propagateDirty(self, slot, subindex, roi):
        # Whichever input changed, the written image is out of date.
        known_inputs = (self.GlobString, self.hdf5Group, self.hdf5Path)
        assert any(slot == candidate for candidate in known_inputs)
        self.WriteImage.setDirty(slice(None))

    def execute(self, slot, subindex, roi, result):
        """
        Write the whole stack into the hdf5 group and set ``result`` to True.

        Progress (0..100) is reported through ``progressSignal``.
        Raises ``Exception`` if the stack loader found no files.
        """
        loader = self.opStackLoader
        if not loader.fileNameList:
            raise Exception(
                f"Didn't find any files to combine.  Is the glob string valid?  "
                f"globstring = {self.GlobString.value}")

        # Gather the stack's meta-data.
        stack_meta = loader.stack.meta
        axistags = stack_meta.axistags
        z_pos = axistags.index("z")
        data_shape = stack_meta.shape
        num_images = data_shape[z_pos]

        dtype = stack_meta.dtype
        if isinstance(dtype, numpy.dtype):
            # We want the scalar type (e.g. numpy.float64), not the numpy.dtype.
            dtype = dtype.type

        # An index at or beyond the end of the shape is treated as "no channel axis".
        c_pos = axistags.index("c")
        num_channels = data_shape[c_pos] if c_pos < len(data_shape) else 1

        # Chunk shape: aim for a cube that's roughly 300k in size.
        item_bytes = dtype().nbytes
        cube_edge = int(
            math.pow(300_000 // (num_channels * item_bytes), (1 / 3.0)))
        chunk_extent = {
            "t": 1,
            "x": cube_edge,
            "y": cube_edge,
            "z": cube_edge,
            "c": num_channels,
        }

        # h5py guide to chunking says chunks of 300k or less "work best"
        chunk_bytes = (chunk_extent["x"] * chunk_extent["y"] *
                       chunk_extent["z"] * num_channels * item_bytes)
        assert chunk_bytes <= 300_000

        # A chunk can't be larger than the data along any axis.
        chunk_shape = tuple(
            min(chunk_extent[axistags[axis].key], extent)
            for axis, extent in enumerate(data_shape))

        # (Re-)create the destination dataset.
        internal_path = self.hdf5Path.value.replace("\\", "/")  # Windows fix
        group = self.hdf5Group.value
        if internal_path in group:
            del group[internal_path]

        dataset = group.create_dataset(
            internal_path,
            shape=data_shape,
            dtype=dtype,
            chunks=chunk_shape,
        )

        # Copy the data, one whole image (z-slice) per request.
        self.progressSignal(0)
        for z in range(num_images):
            slicing = [slice(None)] * len(axistags)
            slicing[z_pos] = slice(z, z + 1)
            dataset[tuple(slicing)] = loader.stack[slicing].wait()
            self.progressSignal(z * 100 // num_images)

        dataset.attrs["axistags"] = axistags.toJSON()

        # All done.
        result[...] = True
        self.progressSignal(100)
        return result
Example #23
0
class OpPredictionPipeline(OpPredictionPipelineNoCache):
    """
    This operator extends the cacheless prediction pipeline above with additional outputs for the GUI.
    (It uses caches for these outputs, and has an extra input for cached features.)
    """
    FreezePredictions = InputSlot()
    CachedFeatureImages = InputSlot()

    PredictionProbabilities = OutputSlot()
    CachedPredictionProbabilities = OutputSlot()
    PredictionProbabilityChannels = OutputSlot(level=1)
    SegmentationChannels = OutputSlot(level=1)
    UncertaintyEstimate = OutputSlot()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # --- Prediction, fed by the CACHED feature images ---
        self.predict = OpClassifierPredict(parent=self)
        self.predict.name = "OpClassifierPredict"
        self.predict.Classifier.connect(self.Classifier)
        self.predict.Image.connect(self.CachedFeatureImages)
        self.predict.PredictionMask.connect(self.PredictionMask)
        self.predict.LabelsCount.connect(self.NumClasses)
        self.PredictionProbabilities.connect(self.predict.PMaps)

        # --- GUI cache of the prediction maps; FreezePredictions drives fixAtCurrent ---
        self.prediction_cache_gui = OpSlicedBlockedArrayCache(parent=self)
        self.prediction_cache_gui.name = "prediction_cache_gui"
        self.prediction_cache_gui.inputs["fixAtCurrent"].connect(
            self.FreezePredictions)
        self.prediction_cache_gui.inputs["Input"].connect(self.predict.PMaps)
        self.CachedPredictionProbabilities.connect(
            self.prediction_cache_gui.Output)

        # --- One GUI layer per prediction channel (sliced along 'c') ---
        self.opPredictionSlicer = OpMultiArraySlicer2(parent=self)
        self.opPredictionSlicer.name = "opPredictionSlicer"
        self.opPredictionSlicer.Input.connect(self.prediction_cache_gui.Output)
        self.opPredictionSlicer.AxisFlag.setValue('c')
        self.PredictionProbabilityChannels.connect(
            self.opPredictionSlicer.Slices)

        # --- Segmentation channels: max-channel indicator of the cached predictions, sliced along 'c' ---
        self.opSegmentor = OpMaxChannelIndicatorOperator(parent=self)
        self.opSegmentor.Input.connect(self.prediction_cache_gui.Output)

        self.opSegmentationSlicer = OpMultiArraySlicer2(parent=self)
        self.opSegmentationSlicer.name = "opSegmentationSlicer"
        self.opSegmentationSlicer.Input.connect(self.opSegmentor.Output)
        self.opSegmentationSlicer.AxisFlag.setValue('c')
        self.SegmentationChannels.connect(self.opSegmentationSlicer.Slices)

        # --- Uncertainty layer, computed from the cached predictions ---
        self.opUncertaintyEstimator = OpEnsembleMargin(parent=self)
        self.opUncertaintyEstimator.Input.connect(
            self.prediction_cache_gui.Output)

        # Cache the uncertainty so we get zeros for uncomputed points
        self.opUncertaintyCache = OpSlicedBlockedArrayCache(parent=self)
        self.opUncertaintyCache.name = "opUncertaintyCache"
        self.opUncertaintyCache.Input.connect(
            self.opUncertaintyEstimator.Output)
        self.opUncertaintyCache.fixAtCurrent.connect(self.FreezePredictions)
        self.UncertaintyEstimate.connect(self.opUncertaintyCache.Output)

    def setupOutputs(self):
        """
        Configure the block shapes of both GUI caches.

        The shapes are derived from the axis order of ``FeatureImages``; each
        cache receives one (inner, outer) block shape per slicing direction,
        in the order X, Y, Z.
        """
        axis_keys = [tag.key for tag in self.FeatureImages.meta.axistags]

        # (inner, outer) block extents for an axis
        flat = (1, 1)
        wide = (128, 256)
        channels = (100, 100)

        # One table per slicing direction: the sliced axis is flat, the other two spatial axes are wide.
        block_tables = (
            {'t': flat, 'z': wide, 'y': wide, 'x': flat, 'c': channels},  # X
            {'t': flat, 'z': wide, 'y': flat, 'x': wide, 'c': channels},  # Y
            {'t': flat, 'z': flat, 'y': wide, 'x': wide, 'c': channels},  # Z
        )

        inner_shapes = tuple(
            tuple(table[key][0] for key in axis_keys)
            for table in block_tables)
        outer_shapes = tuple(
            tuple(table[key][1] for key in axis_keys)
            for table in block_tables)

        for cache in (self.prediction_cache_gui, self.opUncertaintyCache):
            cache.inputs["innerBlockShape"].setValue(inner_shapes)
            cache.inputs["outerBlockShape"].setValue(outer_shapes)
Example #24
0
class OpImageReader(Operator):
    """
    Read an image file via vigra.impex.readImage().

    A file with a single image is exposed as a 2D image with axes yxc;
    a file with several images (e.g. a multi-page tiff) is exposed as a
    volume with axes zyxc, one z-step per image.
    """

    Filename = InputSlot(stype="filestring")
    Image = OutputSlot()

    def setupOutputs(self):
        """Derive dtype, shape and axistags of ``Image`` from the file header."""
        filename = self.Filename.value

        info = vigra.impex.ImageInfo(filename)
        assert [tag.key for tag in info.getAxisTags()] == ["x", "y", "c"]

        # The file reports xyc; we expose yxc.
        shape_xyc = info.getShape()
        shape_yxc = (shape_xyc[1], shape_xyc[0], shape_xyc[2])

        self.Image.meta.dtype = info.getDtype()
        self.Image.meta.prefer_2d = True

        numImages = vigra.impex.numberImages(filename)
        if numImages != 1:
            # Several images: stack them along a leading z axis (zyxc).
            self.Image.meta.shape = (numImages, ) + shape_yxc

            z_tag = vigra.defaultAxistags("z")[0]
            x_tag, y_tag, c_tag = list(info.getAxisTags())
            self.Image.meta.axistags = vigra.AxisTags(
                [z_tag, y_tag, x_tag, c_tag])
        else:
            # Single image: plain yxc.
            self.Image.meta.shape = shape_yxc
            file_tags = info.getAxisTags()
            self.Image.meta.axistags = vigra.AxisTags(
                [file_tags[key] for key in "yxc"])

    def execute(self, slot, subindex, rroi, result):
        """Fill ``result`` with the requested roi, re-reading the file on every call."""
        filename = self.Filename.value
        bounds = numpy.array([rroi.start, rroi.stop])

        if "z" not in self.Image.meta.getAxisKeys():
            page = vigra.impex.readImage(filename).transpose(
                1, 0, 2)  # xyc -> yxc
            assert page.shape == self.Image.meta.shape
            result[:] = page[roiToSlice(*bounds)]
            return result

        # Volume case: read each requested page and copy its yxc sub-region
        # into the matching z-position of the result.
        requested_pages = range(*bounds[:, 0])
        for z_global, z_result in zip(requested_pages,
                                      range(result.shape[0])):
            page = vigra.impex.readImage(filename, index=z_global)
            page = page.transpose(1, 0, 2)  # xyc -> yxc
            assert page.shape == self.Image.meta.shape[1:]
            result[z_result] = page[roiToSlice(*bounds[:, 1:])]
        return result

    def propagateDirty(self, slot, subindex, roi):
        # Filename is the only input; a new file invalidates the whole image.
        if slot == self.Filename:
            self.Image.setDirty()
            return
        assert False, "Unknown dirty input slot."
Example #25
0
class OpDivisionFeatures(Operator):
    """
    Computes division features on a 5D volume.

    ``BlockwiseDivisionFeatures`` has one entry per time step. Each entry is
    a dict holding, under ``config.features_division_name``, the value
    returned by ``FeatureManager.computeFeatures_at`` for that frame and its
    successor frame (``None`` is passed for the successor of the last frame).
    """
    LabelVolume = InputSlot()
    DivisionFeatureNames = InputSlot(rtype=List, stype=Opaque)
    RegionFeaturesVigra = InputSlot()

    BlockwiseDivisionFeatures = OutputSlot()

    def __init__(self, *args, **kwargs):
        super(OpDivisionFeatures, self).__init__(*args, **kwargs)

    def setupOutputs(self):
        """
        Configure the output meta-data and build the feature manager.

        Raises:
            Exception: if ``LabelVolume`` does not have exactly the axes t, x, y, z, c.
        """
        taggedShape = self.LabelVolume.meta.getTaggedShape()

        if set(taggedShape.keys()) != set('txyzc'):
            raise Exception("Input volumes must have txyzc axes.")

        # One (object-typed) entry per time step.
        self.BlockwiseDivisionFeatures.meta.shape = tuple([taggedShape['t']])
        self.BlockwiseDivisionFeatures.meta.axistags = vigra.defaultAxistags(
            "t")
        self.BlockwiseDivisionFeatures.meta.dtype = object

        # A singleton spatial axis means the data is treated as 2D.
        ndim = 2 if any(taggedShape.get(k, 0) == 1 for k in "xyz") else 3

        self.featureManager = FeatureManager(
            scales=config.image_scale,
            n_best=config.n_best_successors,
            com_name_cur=config.com_name_cur,
            com_name_next=config.com_name_next,
            size_name=config.size_name,
            delim=config.delim,
            template_size=config.template_size,
            ndim=ndim,
            size_filter=config.size_filter,
            squared_distance_default=config.squared_distance_default)

    def execute(self, slot, subindex, roi, result):
        """
        Compute the division features for the time steps ``roi.start[0]:roi.stop[0]``.

        The inputs are requested for one extra frame past the roi whenever
        such a frame exists, because every frame is compared with its successor.
        """
        assert len(roi.start) == len(roi.stop) == len(
            self.BlockwiseDivisionFeatures.meta.shape)
        assert slot == self.BlockwiseDivisionFeatures
        assert len(roi.start) == 1

        volumeShape = self.LabelVolume.meta.shape
        taggedShape = self.LabelVolume.meta.getTaggedShape()
        timeIndex = list(taggedShape.keys()).index('t')
        # The frame lookup below (labelVolume[t + 1, ...]) needs time as the first axis.
        assert timeIndex == 0

        import time
        start = time.perf_counter()

        t_start = roi.start[0]
        t_stop = roi.stop[0]

        # roi.stop is exclusive, so the successor of the last requested frame
        # is frame ``t_stop``; it exists as long as t_stop < number of frames.
        # (The previous test ``t_stop + 1 < n`` wrongly skipped the successor
        # when it was the very last frame of the volume.)
        if t_stop < volumeShape[timeIndex]:
            fetch_stop = t_stop + 1
        else:
            fetch_stop = t_stop

        # Full spatial/channel extent, restricted in time to the frames we need.
        vroi_start = len(volumeShape) * [0]
        vroi_stop = list(volumeShape)
        vroi_start[timeIndex] = t_start
        vroi_stop[timeIndex] = fetch_stop
        vroi = [
            slice(lo, hi) for lo, hi in zip(vroi_start, vroi_stop)
        ]

        feats = self.RegionFeaturesVigra[slice(t_start, fetch_stop)].wait()
        labelVolume = self.LabelVolume[vroi].wait()
        divisionFeatNames = self.DivisionFeatureNames[(
        )].wait()[config.features_division_name]

        num_fetched = fetch_stop - t_start
        for t in range(t_stop - t_start):
            result[t] = {}
            feats_cur = feats[t][config.features_vigra_name]
            if t + 1 < num_fetched:
                feats_next = feats[t + 1][config.features_vigra_name]
                img_next = labelVolume[t + 1, ...]
            else:
                # Last frame of the volume: there is no successor.
                feats_next = None
                img_next = None
            res = self.featureManager.computeFeatures_at(
                feats_cur, feats_next, img_next, divisionFeatNames)
            result[t][config.features_division_name] = res

        stop = time.perf_counter()
        logger.debug(
            "TIMING: computing division features took {:.3f}s".format(stop -
                                                                      start))
        return result

    def propagateDirty(self, slot, subindex, roi):
        # NOTE(review): the features of frame t are computed from frame t+1 as well
        # (see execute), so a change at frame t presumably also invalidates frame t-1;
        # that frame is not marked dirty here -- confirm whether this is intended.
        if slot is self.DivisionFeatureNames:
            self.BlockwiseDivisionFeatures.setDirty(slice(None))
        elif slot is self.RegionFeaturesVigra:
            self.BlockwiseDivisionFeatures.setDirty(roi)
        else:
            axes = list(self.LabelVolume.meta.getTaggedShape().keys())
            dirtyStart = collections.OrderedDict(zip(axes, roi.start))
            dirtyStop = collections.OrderedDict(zip(axes, roi.stop))

            # Remove the spatial and channel dims (keep t, if present)
            for key in 'xyzc':
                del dirtyStart[key]
                del dirtyStop[key]

            self.BlockwiseDivisionFeatures.setDirty(list(dirtyStart.values()),
                                                    list(dirtyStop.values()))
Example #26
0
class OpH5N5WriterBigDataset(Operator):
    """
    Stream the ``Image`` input block-wise into a dataset of an open h5/n5 file.

    The destination dataset is (re-)created in ``setupOutputs``; the data is
    written when ``WriteImage`` is executed.
    """
    name = "H5 and N5 File Writer BigDataset"
    category = "Output"

    # Must be an already-open hdf5File/n5File (or group) for writing to
    h5N5File = InputSlot()
    h5N5Path = InputSlot()
    Image = InputSlot()
    # h5py uses single-threaded gzip compression, which really slows down export.
    CompressionEnabled = InputSlot(value=False)
    BatchSize = InputSlot(optional=True)

    WriteImage = OutputSlot()

    loggingName = __name__ + ".OpH5N5WriterBigDataset"
    logger = logging.getLogger(loggingName)
    traceLogger = logging.getLogger("TRACE." + loggingName)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.progressSignal = OrderedSignal()
        self.d = None  # destination dataset
        self.f = None  # destination file (or group)

    def cleanUp(self):
        super().cleanUp()
        # Drop our references to dataset and file, to ensure that the file can be closed.
        self.d = None
        self.f = None
        self.progressSignal.clean()

    def setupOutputs(self):
        """Create the destination dataset, replacing any existing one of the same name."""
        self.WriteImage.meta.shape = (1, )
        self.WriteImage.meta.dtype = object

        self.f = self.h5N5File.value

        # On windows, there may be backslashes.
        internal_path = self.h5N5Path.value.replace("\\", "/")

        # Locate (or create) the group that will contain the dataset.
        group_name, dataset_name = os.path.split(internal_path)
        if group_name == "":
            group = self.f
        elif group_name in self.f:
            group = self.f[group_name]
        else:
            group = self.f.create_group(group_name)

        data_shape = self.Image.meta.shape
        self.logger.info(f"Data shape: {data_shape}")

        dtype = self.Image.meta.dtype
        if isinstance(dtype, numpy.dtype):
            # We want the scalar type (e.g. numpy.float64), not the numpy.dtype.
            dtype = dtype.type
        item_bytes = dtype().nbytes

        # Chunk shape: aim for a cube that's roughly 512k in size.
        # Assume that chunks should not span multiple t-slices,
        # and channels are often handled separately, too.
        max_chunk = self.Image.meta.getTaggedShape()
        for single_step_axis in ("t", "c"):
            if single_step_axis in max_chunk:
                max_chunk[single_step_axis] = 1

        self.chunkShape = determineBlockShape(list(max_chunk.values()),
                                              512_000.0 / item_bytes)

        if dataset_name in list(group.keys()):
            del group[dataset_name]

        create_args = dict(shape=data_shape,
                           dtype=dtype,
                           chunks=self.chunkShape)
        if self.CompressionEnabled.value:
            # Would be nice to use lzf compression here, but that is h5py-specific.
            create_args["compression"] = "gzip"
            # Optimize for speed, not disk space; z5py uses a different name for the option.
            if isinstance(self.f, h5py.File):
                create_args["compression_opts"] = 1
            else:
                create_args["level"] = 1
        elif isinstance(self.f, z5py.N5File):
            # n5 uses gzip level 5 as default compression.
            create_args["compression"] = "raw"

        self.d = group.create_dataset(dataset_name, **create_args)

        # Carry over optional display meta-data as dataset attributes.
        image_meta = self.Image.meta
        if image_meta.drange is not None:
            self.d.attrs["drange"] = image_meta.drange
        if image_meta.display_mode is not None:
            self.d.attrs["display_mode"] = image_meta.display_mode

    def execute(self, slot, subindex, rroi, result):
        """Write the complete image to the dataset and set ``result[0]`` to True."""
        self.progressSignal(0)

        # Store the axistags alongside the data.
        self.d.attrs["axistags"] = self.Image.meta.axistags.toJSON()

        def store_block(roi, block):
            target = roiToSlice(*roi)
            if not block.flags.c_contiguous:
                self.d[target] = block
                return
            self.d.write_direct(block.view(numpy.ndarray), dest_sel=target)

        batch_size = self.BatchSize.value if self.BatchSize.ready() else None
        streamer = BigRequestStreamer(self.Image,
                                      roiFromShape(self.Image.meta.shape),
                                      batchSize=batch_size)
        streamer.resultSignal.subscribe(store_block)
        streamer.progressSignal.subscribe(self.progressSignal)
        streamer.execute()

        # Be paranoid: Flush right now.
        if isinstance(self.f, h5py.File):
            self.f.file.flush()  # not available in z5py

        # We're finished.
        result[0] = True

        self.progressSignal(100)

    def propagateDirty(self, slot, subindex, roi):
        # This operator's output isn't generally connected to other operators.
        # If someone does use it that way, we assume they want to be told that
        # the input image became dirty and may need to be written to disk again.
        self.WriteImage.setDirty(slice(None))
Example #27
0
class OpStreamingHdf5Reader(Operator):
    """
    Reads image data on demand from a dataset of an already-opened HDF5 file.

    The output meta-data (dtype, shape, axistags, optional drange) is taken
    from the dataset; if the dataset has no 'axistags' attribute, a default
    axis order is chosen from its number of dimensions.
    """
    name = "OpStreamingHdf5Reader"
    category = "Reader"

    # The project hdf5 File object (already opened)
    Hdf5File = InputSlot(stype='hdf5File')

    # The internal path for project-local datasets
    InternalPath = InputSlot(stype='string')

    # Output data
    OutputImage = OutputSlot()
    
    class DatasetReadError(Exception):
        """Raised when ``InternalPath`` does not exist in the hdf5 file."""
        def __init__(self, internalPath):
            self.internalPath = internalPath
            self.msg = "Unable to open Hdf5 dataset: {}".format( internalPath )
            super(OpStreamingHdf5Reader.DatasetReadError, self).__init__( self.msg )

    def __init__(self, *args, **kwargs):
        super(OpStreamingHdf5Reader, self).__init__(*args, **kwargs)
        self._hdf5File = None

    def setupOutputs(self):
        """
        Read the dataset meta-info from the HDF5 dataset and publish it on ``OutputImage``.

        Raises:
            OpStreamingHdf5Reader.DatasetReadError: if the internal path is not in the file.
            AssertionError: if the dimensionality is unsupported or the axistags
                do not match the dataset shape.
        """
        self._hdf5File = self.Hdf5File.value
        internalPath = self.InternalPath.value

        if internalPath not in self._hdf5File:
            raise OpStreamingHdf5Reader.DatasetReadError(internalPath)

        dataset = self._hdf5File[internalPath]

        try:
            # Read the axistags property without actually importing the data
            axistagsJson = dataset.attrs['axistags'] # Throws KeyError if 'axistags' can't be found
            axistags = vigra.AxisTags.fromJSON(axistagsJson)
            # Needed for the mismatch message below (it used to be undefined on this path).
            axisorder = "".join(tag.key for tag in axistags)
        except KeyError:
            # No axistags found.
            ndims = len(dataset.shape)
            assert ndims != 0, "OpStreamingHdf5Reader: Zero-dimensional datasets not supported."
            assert ndims != 1, "OpStreamingHdf5Reader: Support for 1-D data not yet supported"
            assert ndims <= 5, "OpStreamingHdf5Reader: No support for data with more than 5 dimensions."

            axisorders = { 2 : 'xy',
                           3 : 'xyz',
                           4 : 'xyzc',
                           5 : 'txyzc' }

            axisorder = axisorders[ndims]
            if ndims == 3 and dataset.shape[2] <= 4:
                # Special case: If the 3rd dim is small, assume it's 'c', not 'z'
                axisorder = 'xyc'

            axistags = vigra.defaultAxistags(axisorder)

        assert len(axistags) == len( dataset.shape ),\
            "Mismatch between shape {} and axisorder {}".format( dataset.shape, axisorder )

        # Configure our slot meta-info
        self.OutputImage.meta.dtype = dataset.dtype.type
        self.OutputImage.meta.shape = dataset.shape
        self.OutputImage.meta.axistags = axistags

        # If the dataset specifies a datarange, add it to the slot metadata
        if 'drange' in dataset.attrs:
            self.OutputImage.meta.drange = tuple( dataset.attrs['drange'] )

        # Large datasets without chunking are flagged as an inefficient format.
        total_volume = numpy.prod(numpy.array(dataset.shape))
        if not dataset.chunks and total_volume > 1e8:
            self.OutputImage.meta.inefficient_format = True
            logger.warning("This dataset ({}{}) is NOT chunked.  "
                           "Performance for 3D access patterns will be bad!"
                           .format( self._hdf5File.filename, internalPath ))

    def execute(self, slot, subindex, roi, result):
        """Read the requested roi directly from the hdf5 dataset into ``result``."""
        t = time.time()
        assert self._hdf5File is not None
        key = roi.toSlice()
        dataset = self._hdf5File[self.InternalPath.value]

        # Only pay for timing and message formatting when debug output is wanted.
        debug_enabled = logger.isEnabledFor(logging.DEBUG)
        timer = None
        if debug_enabled:
            logger.debug("Reading HDF5 block: [{}, {}]".format( roi.start, roi.stop ))
            timer = Timer()
            timer.unpause()

        if result.flags.c_contiguous:
            # Contiguous target: read straight into the result buffer.
            dataset.read_direct( result[...], key )
        else:
            result[...] = dataset[key]

        if debug_enabled:
            t = 1000.0*(time.time()-t)
            logger.debug("took %f msec.", t)

        if timer:
            timer.pause()
            logger.debug("Completed HDF5 read in {} seconds: [{}, {}]".format( timer.seconds(), roi.start, roi.stop ))

    def propagateDirty(self, slot, subindex, roi):
        # A different file or dataset path invalidates the whole output.
        if slot == self.Hdf5File or slot == self.InternalPath:
            self.OutputImage.setDirty( slice(None) )
Example #28
0
class OpBlockedArrayCache(Operator, ManagedBlockedCache):
    """
    A blockwise array cache designed to replace the old OpBlockedArrayCache.
    Instead of a monolithic implementation, this operator is a thin wrapper
    around a small internal pipeline:

    - ``OpCacheFixer`` receives ``Input`` and ``fixAtCurrent``.
    - ``OpSimpleBlockedArrayCache`` receives the fixer's output together with
      ``BlockShape``, ``CompressionEnabled`` and ``BypassModeEnabled``.

    ``Output`` and ``CleanBlocks`` are connected directly to the internal
    cache, and the cache-management interface (``usedMemory()``,
    ``freeMemory()``, ...) is delegated to it.
    """

    fixAtCurrent = InputSlot(value=False)
    Input = InputSlot(allow_mask=True)
    BlockShape = InputSlot(
        optional=True
    )  # If 'None' is present, those items will be treated as max for the dimension.
    # If not provided, will be set to Input.meta.shape
    BypassModeEnabled = InputSlot(value=False)
    CompressionEnabled = InputSlot(value=False)

    Output = OutputSlot(allow_mask=True)
    CleanBlocks = OutputSlot(
    )  # A list of slicings indicating which blocks are stored in the cache and clean.

    def __init__(self, *args, **kwargs):
        """Build and wire the internal pipeline; all arguments go to the base class."""
        super(OpBlockedArrayCache, self).__init__(*args, **kwargs)

        # SCHEMATIC:
        #
        # Input ---------> opCacheFixer -> opSimpleBlockedArrayCache -> Output
        #                 /               /                          \
        # fixAtCurrent --                /                            -> CleanBlocks
        #                               /
        # BlockShape -------------------
        # CompressionEnabled -----------
        # BypassModeEnabled ------------

        self._opCacheFixer = OpCacheFixer(parent=self)
        self._opCacheFixer.Input.connect(self.Input)
        self._opCacheFixer.fixAtCurrent.connect(self.fixAtCurrent)

        self._opSimpleBlockedArrayCache = OpSimpleBlockedArrayCache(
            parent=self)
        self._opSimpleBlockedArrayCache.Input.connect(
            self._opCacheFixer.Output)
        self._opSimpleBlockedArrayCache.CompressionEnabled.connect(
            self.CompressionEnabled)
        self._opSimpleBlockedArrayCache.BlockShape.connect(self.BlockShape)
        self._opSimpleBlockedArrayCache.BypassModeEnabled.connect(
            self.BypassModeEnabled)
        self.CleanBlocks.connect(self._opSimpleBlockedArrayCache.CleanBlocks)

        # Our Output is connected directly to the internal pipeline, so
        # execute() of this operator is never expected to run.
        self.Output.connect(self._opSimpleBlockedArrayCache.Output)

        # Explicitly forward dirty notifications from the internal cache.
        # NOTE(review): Output is connected directly above, so this forwarding
        # may be redundant -- confirm against the Slot implementation before
        # removing it.
        self._opSimpleBlockedArrayCache.Output.notifyDirty(
            lambda slot, roi: self.Output.setDirty(roi.start, roi.stop))

        # This member is used by tests that check RAM usage.
        self.setup_ram_context = RamMeasurementContext()
        self.registerWithMemoryManager()

    def setupOutputs(self):
        """Default ``BlockShape`` to the input shape and mirror the inner cache's metadata."""
        if not self.BlockShape.ready():
            self.BlockShape.setValue(self.Input.meta.shape)
        # Copy metadata from the internal pipeline to the output
        self.Output.meta.assignFrom(
            self._opSimpleBlockedArrayCache.Output.meta)

    def execute(self, slot, subindex, roi, result):
        """Never expected to run: Output is connected to the internal cache.

        :raises AssertionError: always.
        """
        # Raised explicitly (instead of ``assert False``) so that the check
        # is not stripped when Python runs with -O.
        raise AssertionError("Shouldn't get here")

    def propagateDirty(self, slot, subindex, roi):
        """Nothing to do here; dirtiness is handled via the internal pipeline set up in __init__."""
        pass

    def setInSlot(self, slot, subindex, key, value):
        pass  # Nothing to do here: Input is connected to an internal operator

    # ======= mimic cache interface for wrapping operators =======
    # Every method below simply delegates to the internal
    # OpSimpleBlockedArrayCache instance.

    def usedMemory(self):
        return self._opSimpleBlockedArrayCache.usedMemory()

    def fractionOfUsedMemoryDirty(self):
        # dirty memory is discarded immediately
        return self._opSimpleBlockedArrayCache.fractionOfUsedMemoryDirty()

    def lastAccessTime(self):
        return self._opSimpleBlockedArrayCache.lastAccessTime()

    def getBlockAccessTimes(self):
        return self._opSimpleBlockedArrayCache.getBlockAccessTimes()

    def freeMemory(self):
        return self._opSimpleBlockedArrayCache.freeMemory()

    def freeBlock(self, key):
        return self._opSimpleBlockedArrayCache.freeBlock(key)

    def freeDirtyMemory(self):
        return self._opSimpleBlockedArrayCache.freeDirtyMemory()

    def generateReport(self, report):
        """
        Fill ``report`` for this operator and attach the inner cache's report
        as a child.

        The inner cache fills ``report`` first; a shallow copy of that state
        is kept as the child entry, then the base class overwrites ``report``
        with this operator's own data and the child is appended to
        ``report.children``.
        """
        self._opSimpleBlockedArrayCache.generateReport(report)
        child = copy.copy(report)
        super(OpBlockedArrayCache, self).generateReport(report)
        report.children.append(child)
Example #29
0
class OpAnisotropicGaussianSmoothing(Operator):
    """
    Gaussian smoothing with an individual sigma per spatial axis.

    ``Sigmas`` must be a dict with exactly the keys 'z', 'y' and 'x'.  The
    smoothing itself is done by ``vigra.filters.gaussianSmoothing`` on
    float32 data; the input is requested with a halo computed by
    ``enlargeRoiForHalo`` (window=2.0).

    NOTE(review): the handling of the time axis and of axes located at
    index 0 looks incomplete (see the notes in the methods below).  Those
    parts are left as they were because the intended geometry cannot be
    determined from this class alone.
    """
    Input = InputSlot()
    Sigmas = InputSlot(value={'z': 1.0, 'y': 1.0, 'x': 1.0})

    Output = OutputSlot()

    def setupOutputs(self):
        """Derive the output metadata from the input and cache the sigmas."""
        self.Output.meta.assignFrom(self.Input.meta)
        # If there is a time axis, the output won't have it.
        # The check below relies on axistags.index() yielding an out-of-range
        # index when the key is absent.
        # NOTE(review): the axis is dropped without checking that its length
        # is 1; _getInputComputeRois() asserts that later.
        timeIndex = self.Output.meta.axistags.index('t')
        if timeIndex < len(self.Output.meta.shape):
            newshape = list(self.Output.meta.shape)
            newshape.pop(timeIndex)
            self.Output.meta.shape = tuple(newshape)
            del self.Output.meta.axistags[timeIndex]
        self.Output.meta.dtype = numpy.float32  # vigra gaussian only supports float32

        sigmas = self.Sigmas.value
        assert isinstance(sigmas, dict), "Sigmas slot expects a dict"
        assert set(sigmas.keys()) == set('zyx'), "Sigmas slot expects three key-value pairs for z,y,x"
        self._sigmas = sigmas

        # Diagnostics go to the logger instead of stdout.
        logger.debug("Assigning output: {} ====> {}".format(self.Input.meta.getTaggedShape(), self.Output.meta.getTaggedShape()))

    def execute(self, slot, subindex, roi, result):
        """
        Smooth the requested region and write it into ``result``.

        :param roi: Requested region (``start``/``stop``).
        :param result: Pre-allocated destination array, filled in place.
        :returns: ``result``.
        :raises AssertionError: if the roi exceeds the input shape or the
            smoothed block does not have the expected shape.
        """
        assert all(roi.stop <= self.Input.meta.shape), "Requested roi {} is too large for this input image of shape {}.".format( roi, self.Input.meta.shape )
        # Determine how much input data we'll need, and where the result will be relative to that input roi
        inputRoi, computeRoi = self._getInputComputeRois(roi)
        # Obtain the input data
        with Timer() as resultTimer:
            data = self.Input(*inputRoi).wait()
        logger.debug("Obtaining input data took {} seconds for roi {}".format( resultTimer.seconds(), inputRoi ))

        axistags = self.Input.meta.axistags
        ndim = len(self.Input.meta.shape)
        # An index >= ndim means "axis not present".
        zLookup = axistags.index('z')
        zIndex = zLookup if zLookup < ndim else None
        xIndex = axistags.index('x')
        yIndex = axistags.index('y')
        cLookup = axistags.index('c')
        cIndex = cLookup if cLookup < ndim else None

        # Must be float32
        if data.dtype != numpy.float32:
            data = data.astype(numpy.float32)

        axiskeys = self.Input.meta.getAxisKeys()
        # Materialised as lists: they are iterated more than once below.
        spatialkeys = [k for k in axiskeys if k in 'zyx']

        # we need to remove a singleton z axis, otherwise we get
        # 'kernel longer than line' errors
        reskey = [slice(None, None, None)] * ndim
        if cIndex is not None:
            # Only channel 0 of the result is written
            # (_getInputComputeRois requests channel range [0, 1)).
            reskey[cIndex] = 0
        # NOTE(review): the truthiness test also skips a z axis at index 0;
        # kept as-is because _getInputComputeRois() treats index 0 the same way.
        if zIndex and self.Input.meta.shape[zIndex] == 1:
            removedZ = True
            data = data.reshape((data.shape[xIndex], data.shape[yIndex]))
            reskey[zIndex] = 0
            spatialkeys = [k for k in axiskeys if k in 'yx']
        else:
            removedZ = False

        sigma = list(map(self._sigmas.get, spatialkeys))
        # Check if we need to smooth
        if any(s < 0.1 for s in sigma):
            # At least one sigma is too small: pass the data through unsmoothed.
            if removedZ:
                resultYX = vigra.taggedView(result, axistags="".join(axiskeys))
                resultYX = resultYX.withAxes(*'yx')
                resultYX[:] = data
            else:
                result[:] = data
            return result

        # Smooth the input data
        smoothed = vigra.filters.gaussianSmoothing(data, sigma, window_size=2.0, roi=computeRoi, out=result[tuple(reskey)]) # FIXME: Assumes channel is last axis
        expectedShape = tuple(TinyVector(computeRoi[1]) - TinyVector(computeRoi[0]))
        assert tuple(smoothed.shape) == expectedShape, "Smoothed data shape {} didn't match expected shape {}".format( smoothed.shape, expectedShape )

        return result

    def _getInputComputeRois(self, roi):
        """
        Compute the input region to request and where the result lies in it.

        :param roi: Region with ``start``/``stop``.
        :returns: ``(inputRoi, computeRoi)``: ``inputRoi`` is a pair of
            lists (start, stop) including the halo; ``computeRoi`` is a pair
            of int tuples locating the requested region inside the halo'd
            spatial block.
        """
        axiskeys = self.Input.meta.getAxisKeys()
        spatialkeys = [k for k in axiskeys if k in 'zyx']
        sigma = list(map(self._sigmas.get, spatialkeys))
        inputSpatialShape = self.Input.meta.getTaggedShape()
        spatialRoi = (TinyVector(roi.start), TinyVector(roi.stop))
        tIndex = None
        cIndex = None
        zIndex = None
        if 'c' in inputSpatialShape:
            del inputSpatialShape['c']
            cIndex = axiskeys.index('c')
        if 't' in inputSpatialShape:
            assert inputSpatialShape['t'] == 1
            tIndex = axiskeys.index('t')

        if 'z' in inputSpatialShape and inputSpatialShape['z'] == 1:
            # 2D image, avoid kernel longer than line exception
            del inputSpatialShape['z']
            zIndex = axiskeys.index('z')

        # Axes to strip from the roi, highest index first so that popping one
        # does not shift the others.
        # NOTE(review): the truthiness filter drops None *and* index 0, which
        # is exactly what the previous ``if ind:`` checks did.  Whether an
        # axis at index 0 should be stripped too depends on intent that is
        # not visible here ('t' is never removed from inputSpatialShape).
        removed = sorted((i for i in (tIndex, cIndex, zIndex) if i), reverse=True)
        for ind in removed:
            spatialRoi[0].pop(ind)
            spatialRoi[1].pop(ind)

        inputSpatialRoi = enlargeRoiForHalo(spatialRoi[0], spatialRoi[1], list(inputSpatialShape.values()), sigma, window=2.0)

        # Determine the roi within the input data we're going to request
        inputRoiOffset = spatialRoi[0] - inputSpatialRoi[0]
        computeRoi = (inputRoiOffset, inputRoiOffset + spatialRoi[1] - spatialRoi[0])

        # For some reason, vigra.filters.gaussianSmoothing will raise an exception if this parameter doesn't have the correct integer type.
        # (for example, if we give it as a numpy.ndarray with dtype=int64, we get an error)
        computeRoi = (tuple(map(int, computeRoi[0])),
                      tuple(map(int, computeRoi[1])))

        # Re-insert the stripped axes (lowest index first) as the range [0, 1).
        inputRoi = (list(inputSpatialRoi[0]), list(inputSpatialRoi[1]))
        for ind in reversed(removed):
            inputRoi[0].insert(ind, 0)
            inputRoi[1].insert(ind, 1)

        return inputRoi, computeRoi

    def propagateDirty(self, slot, subindex, roi):
        """
        Forward dirtiness to the output.

        A dirty ``Input`` region is enlarged by the halo; a change of
        ``Sigmas`` dirties everything.

        :raises AssertionError: for any other slot.
        """
        if slot == self.Input:
            # Halo calculation is bidirectional, so we can re-use the function that computes the halo during execute()
            inputRoi, _ = self._getInputComputeRois(roi)
            self.Output.setDirty(inputRoi[0], inputRoi[1])
        elif slot == self.Sigmas:
            self.Output.setDirty(slice(None))
        else:
            # Raised explicitly so the check is not stripped under -O.
            raise AssertionError("Unknown input slot: {}".format( slot.name ))
Example #30
0
class OpDataSelection(Operator):
    """
    The top-level operator for the data selection applet, implemented as a single-image operator.
    The applet uses an OperatorWrapper to make it suitable for use in a workflow.

    On every setupOutputs() a chain of internal operators is (re)built --
    a reader, optionally a metadata injector and axis re-ordering
    operators -- and the public output slots are connected to its end.
    """
    name = "OpDataSelection"
    category = "Top-level"

    SupportedExtensions = OpInputDataReader.SupportedExtensions

    # Inputs
    ProjectFile = InputSlot(
        stype='object',
        optional=True)  #: The project hdf5 File object (already opened)
    ProjectDataGroup = InputSlot(
        stype='string', optional=True
    )  #: The internal path to the hdf5 group where project-local datasets are stored within the project file
    WorkingDirectory = InputSlot(
        stype='filestring'
    )  #: The filesystem directory where the project file is located
    Dataset = InputSlot(stype='object')  #: A DatasetInfo object

    # Outputs
    Image = OutputSlot()  #: The output image
    AllowLabels = OutputSlot(
        stype='bool'
    )  #: A bool indicating whether or not this image can be used for training

    _NonTransposedImage = OutputSlot(
    )  #: The output slot, in the data's original axis ordering (regardless of forceAxisOrder)

    ImageName = OutputSlot(stype='string')  #: The name of the output image

    class InvalidDimensionalityError(Exception):
        """Raised if the user tries to replace the dataset with a new one of differing dimensionality."""
        def __init__(self, message):
            # Hand the message to Exception as well, so that ``args`` is
            # populated (it used to be left empty).
            super(OpDataSelection.InvalidDimensionalityError,
                  self).__init__(message)
            self.message = message

        def __str__(self):
            return self.message

    def __init__(self, forceAxisOrder=False, *args, **kwargs):
        """
        :param forceAxisOrder: If truthy, the axis order that ``Image`` is
            re-ordered to (passed to ``OpReorderAxes.AxisOrder``).
        """
        super(OpDataSelection, self).__init__(*args, **kwargs)
        self.forceAxisOrder = forceAxisOrder
        self._opReaders = []

        # If the gui calls disconnect() on an input slot without replacing it with something else,
        #  we still need to clean up the internal operator that was providing our data.
        self.ProjectFile.notifyUnready(self.internalCleanup)
        self.ProjectDataGroup.notifyUnready(self.internalCleanup)
        self.WorkingDirectory.notifyUnready(self.internalCleanup)
        self.Dataset.notifyUnready(self.internalCleanup)

    def internalCleanup(self, *args):
        """Disconnect the image outputs and clean up all internal operators.

        Accepts (and ignores) arbitrary arguments so that it can be used as a
        slot callback.
        """
        if len(self._opReaders) > 0:
            self.Image.disconnect()
            self._NonTransposedImage.disconnect()
            # Clean up in reverse creation order (downstream operators first).
            for reader in reversed(self._opReaders):
                reader.cleanUp()
            self._opReaders = []

    def setupOutputs(self):
        """
        Rebuild the internal operator chain for the current ``Dataset``.

        :raises DatasetConstraintError: if ``forceAxisOrder`` would drop a
            non-singleton axis.
        :raises Exception: if the user-provided axistags do not match the
            dimensionality of the data.

        On any error the partially built chain is cleaned up before the
        exception is re-raised.
        """
        self.internalCleanup()
        datasetInfo = self.Dataset.value

        try:
            # Data only comes from the project file if the user said so AND it exists in the project
            datasetInProject = (
                datasetInfo.location == DatasetInfo.Location.ProjectInternal)
            datasetInProject &= self.ProjectFile.ready()
            if datasetInProject:
                internalPath = self.ProjectDataGroup.value + '/' + datasetInfo.datasetId
                datasetInProject &= internalPath in self.ProjectFile.value

            # If we should find the data in the project file, use a dataset reader
            if datasetInProject:
                opReader = OpStreamingHdf5Reader(parent=self)
                opReader.Hdf5File.setValue(self.ProjectFile.value)
                opReader.InternalPath.setValue(internalPath)
                providerSlot = opReader.OutputImage
            elif datasetInfo.location == DatasetInfo.Location.PreloadedArray:
                preloaded_array = datasetInfo.preloaded_array
                assert preloaded_array is not None
                if not hasattr(preloaded_array, 'axistags'):
                    # Guess the axis order, since one was not provided.
                    axisorders = {2: 'yx', 3: 'zyx', 4: 'zyxc', 5: 'tzyxc'}

                    shape = preloaded_array.shape
                    ndim = preloaded_array.ndim
                    assert ndim != 0, "Support for 0-D data not yet supported"
                    assert ndim != 1, "Support for 1-D data not yet supported"
                    assert ndim <= 5, "No support for data with more than 5 dimensions."

                    axisorder = axisorders[ndim]
                    if ndim == 3 and shape[2] <= 4:
                        # Special case: If the 3rd dim is small, assume it's 'c', not 'z'
                        axisorder = 'yxc'
                    preloaded_array = vigra.taggedView(preloaded_array,
                                                       axisorder)
                opReader = OpArrayPiper(parent=self)
                opReader.Input.setValue(preloaded_array)
                providerSlot = opReader.Output
            else:
                # Use a normal (filesystem) reader
                opReader = OpInputDataReader(parent=self)
                if datasetInfo.subvolume_roi is not None:
                    opReader.SubVolumeRoi.setValue(datasetInfo.subvolume_roi)
                opReader.WorkingDirectory.setValue(self.WorkingDirectory.value)
                opReader.FilePath.setValue(datasetInfo.filePath)
                providerSlot = opReader.Output
            self._opReaders.append(opReader)

            # Inject metadata if the dataset info specified any.
            # Also, inject if dtype is uint8, which we can reasonably assume has drange (0,255)
            if datasetInfo.normalizeDisplay is not None or \
               datasetInfo.drange is not None or \
               datasetInfo.axistags is not None or \
               (providerSlot.meta.drange is None and providerSlot.meta.dtype == numpy.uint8):
                metadata = {}
                if datasetInfo.drange is not None:
                    metadata['drange'] = datasetInfo.drange
                elif providerSlot.meta.dtype == numpy.uint8:
                    # SPECIAL case for uint8 data: Provide a default drange.
                    # The user can always override this herself if she wants.
                    metadata['drange'] = (0, 255)
                if datasetInfo.normalizeDisplay is not None:
                    metadata['normalizeDisplay'] = datasetInfo.normalizeDisplay
                if datasetInfo.axistags is not None:
                    if len(datasetInfo.axistags) != len(
                            providerSlot.meta.shape):
                        raise Exception(
                            "Your dataset's provided axistags ({}) do not have the "
                            "correct dimensionality for your dataset, which has {} dimensions."
                            .format(
                                "".join(tag.key
                                        for tag in datasetInfo.axistags),
                                len(providerSlot.meta.shape)))
                    # FIXME: We are overwriting the axistags metadata to intentionally allow
                    #        the user to change our interpretation of which axis is which.
                    #        That's okay, but technically there's a special corner case if
                    #        the user redefines the channel axis index.
                    #        Technically, it invalidates the meaning of meta.ram_usage_per_requested_pixel.
                    #        For most use-cases, that won't really matter, which is why I'm not worrying about it right now.
                    metadata['axistags'] = datasetInfo.axistags
                if datasetInfo.subvolume_roi is not None:
                    metadata['subvolume_roi'] = datasetInfo.subvolume_roi

                opMetadataInjector = OpMetadataInjector(parent=self)
                opMetadataInjector.Input.connect(providerSlot)
                opMetadataInjector.Metadata.setValue(metadata)
                providerSlot = opMetadataInjector.Output
                self._opReaders.append(opMetadataInjector)

            self._NonTransposedImage.connect(providerSlot)

            if self.forceAxisOrder:
                # Before we re-order, make sure no non-singleton
                #  axes would be dropped by the forced order.
                output_order = "".join(self.forceAxisOrder)
                provider_order = "".join(providerSlot.meta.getAxisKeys())
                tagged_provider_shape = providerSlot.meta.getTaggedShape()
                dropped_axes = set(provider_order) - set(output_order)
                if any(tagged_provider_shape[a] > 1 for a in dropped_axes):
                    msg = "The axes of your dataset ({}) are not compatible with the axes used by this workflow ({}). Please fix them."\
                          .format(provider_order, output_order)
                    raise DatasetConstraintError("DataSelection", msg)

                op5 = OpReorderAxes(parent=self)
                op5.AxisOrder.setValue(self.forceAxisOrder)
                op5.Input.connect(providerSlot)
                providerSlot = op5.Output
                self._opReaders.append(op5)

            # If the channel axis is not last (or is missing),
            #  make sure the axes are re-ordered so that channel is last.
            if providerSlot.meta.axistags.index('c') != len(
                    providerSlot.meta.axistags) - 1:
                op5 = OpReorderAxes(parent=self)
                # Drop 'c' if present and append it at the end.  Built as a
                # new list because the result of dict.keys() is not a
                # mutable list on Python 3.
                keys = [
                    k for k in providerSlot.meta.getTaggedShape().keys()
                    if k != 'c'
                ]
                keys.append('c')
                op5.AxisOrder.setValue("".join(keys))
                op5.Input.connect(providerSlot)
                providerSlot = op5.Output
                self._opReaders.append(op5)

            # Connect our external outputs to the internal operators we chose
            self.Image.connect(providerSlot)

            # Set the image name and usage flag
            self.AllowLabels.setValue(datasetInfo.allowLabels)

            # If the reading operator provides a nickname, use it.
            # (This writes the nickname back into the DatasetInfo object.)
            if self.Image.meta.nickname is not None:
                datasetInfo.nickname = self.Image.meta.nickname

            imageName = datasetInfo.nickname
            if imageName == "":
                imageName = datasetInfo.filePath
            self.ImageName.setValue(imageName)

        except:
            # Deliberately broad: whatever went wrong, tear down the partially
            # built chain and let the original exception propagate unchanged.
            self.internalCleanup()
            raise

    def propagateDirty(self, slot, subindex, roi):
        # Output slots are directly connected to internal operators
        pass

    @classmethod
    def getInternalDatasets(cls, filePath):
        """Return ``OpInputDataReader.getInternalDatasets(filePath)``."""
        return OpInputDataReader.getInternalDatasets(filePath)