Пример #1
0
class OpGraphCutSegmentation(Operator):
    """Thin wrapper exposing the cached output of an inner OpObjectsSegment.

    ``InputImage``, ``LabelImage``, ``Beta`` and ``Channel`` are connected to
    an internal ``OpObjectsSegment``, and ``CachedOutput`` is connected
    directly to that operator's ``CachedOutput``. This operator therefore
    never computes anything itself.
    """
    RawInput = InputSlot()  # FIXME is this necessary? (not connected below)
    InputImage = InputSlot()
    LabelImage = InputSlot()
    Beta = InputSlot(value=.2)
    Channel = InputSlot(value=0)

    CachedOutput = OutputSlot()

    def __init__(self, *args, **kwargs):
        super(OpGraphCutSegmentation, self).__init__(*args, **kwargs)

        op = OpObjectsSegment(parent=self)

        # forward our inputs to the internal segmentation operator
        op.Prediction.connect(self.InputImage)
        op.LabelImage.connect(self.LabelImage)
        op.Beta.connect(self.Beta)
        # NOTE(review): the OpObjectsSegment / OpGraphCut versions visible in
        # this file declare no `Channel` slot -- confirm the inner operator
        # actually provides one.
        op.Channel.connect(self.Channel)

        # the output is served entirely by the internal operator
        self.CachedOutput.connect(op.CachedOutput)

        self._op = op

        self._filled = False

    def setupOutputs(self):
        # CachedOutput is connected to the inner operator's output, so there
        # is no meta data to set up here.
        pass

    def propagateDirty(self, slot, subindex, roi):
        # Nothing to do: our inputs are connected directly to the inner
        # operator.
        pass

    def execute(self, slot, subindex, roi, result):
        # CachedOutput is connected to the inner operator, so this method is
        # not expected to be reached. Raise explicitly rather than using
        # `assert`, so the guard is not stripped under `python -O`.
        raise AssertionError("Should not get here")
Пример #2
0
class DirtyAssert(Operator):
    """Test helper: verify that a dirty notification spans the whole input.

    After checking the roi, ``propagateDirty`` raises ``PropagateDirtyCalled``
    so the caller can detect that the notification arrived.
    """
    Input = InputSlot()

    def propagateDirty(self, slot, subindex, roi):
        # the dirty region must start at the origin ...
        assert np.all(roi.start == 0)
        # ... and extend to the full input shape
        expected_stop = self.Input.meta.shape
        assert np.all(roi.stop == expected_stop)
        raise PropagateDirtyCalled()
Пример #3
0
class OpFillMaskArray(Operator):
    """Replace the masked entries of a (possibly masked) array.

    ``Output`` has the metadata of ``InputArray``, with ``has_mask`` cleared.
    Masked entries are filled with ``InputFillValue`` when that optional slot
    is ready. Otherwise they are filled with the array's own fill value.
    Unmasked (plain) arrays are copied unchanged.
    """
    name = "OpFillMaskArray"
    category = "Pointwise"

    InputArray = InputSlot(allow_mask=True)
    InputFillValue = InputSlot(optional=True)

    Output = OutputSlot()

    def __init__(self, *args, **kwargs):
        super(OpFillMaskArray, self).__init__(*args, **kwargs)

    def setupOutputs(self):
        # Output mirrors the input metadata, but carries no mask
        self.Output.meta.assignFrom(self.InputArray.meta)
        self.Output.meta.has_mask = False

    def execute(self, slot, subindex, roi, result):
        data = self.InputArray[roi.toSlice()].wait()

        if slot.name != 'Output':
            return

        if not isinstance(data, numpy.ma.masked_array):
            # no mask present: copy verbatim
            result[...] = data
            return

        if self.InputFillValue.ready():
            filled = data.filled(self.InputFillValue.value)
        else:
            filled = data.filled()
        result[...] = filled

    def propagateDirty(self, slot, subindex, roi):
        slot_name = slot.name
        if slot_name == "InputArray":
            # pointwise operation: the same region becomes dirty
            self.Output.setDirty(roi.toSlice())
        elif slot_name == "InputFillValue":
            # any masked entry anywhere may change
            self.Output.setDirty(slice(None))
        else:
            assert False, "Unknown dirty input slot"
Пример #4
0
class DirtyAssert(Operator):
    """Test helper: verify that a dirty roi covers exactly one (t, c) slice.

    Announce the expected time and channel index with ``willBeDirty(t, c)``.
    ``propagateDirty`` then checks the roi against them and raises
    ``PropagateDirtyCalled``.
    """
    Input = InputSlot()

    def willBeDirty(self, t, c):
        # remember the expected indices for the next dirty notification
        self._t, self._c = t, c

    def propagateDirty(self, slot, subindex, roi):
        axistags = self.Input.meta.axistags
        t_axis = axistags.index("t")
        c_axis = axistags.index("c")
        assert roi.start[t_axis] == self._t
        assert roi.start[c_axis] == self._c
        assert roi.stop[t_axis] == self._t + 1
        assert roi.stop[c_axis] == self._c + 1
        raise PropagateDirtyCalled()
Пример #5
0
class OpMaskArray(Operator):
    """Add an extra mask to a (possibly already masked) array.

    ``Output`` has the metadata of ``InputArray`` with ``has_mask`` set. Its
    mask is the input's own mask OR-ed with ``InputMask``.
    """
    name = "OpMaskArray"
    category = "Pointwise"

    InputArray = InputSlot(allow_mask=True)
    InputMask = InputSlot()

    Output = OutputSlot(allow_mask=True)

    def __init__(self, *args, **kwargs):
        super(OpMaskArray, self).__init__(*args, **kwargs)

    def setupOutputs(self):
        # Output mirrors the input metadata but always carries a mask
        self.Output.meta.assignFrom(self.InputArray.meta)
        self.Output.meta.has_mask = True

    def execute(self, slot, subindex, roi, result):
        key = roi.toSlice()

        if slot.name != 'Output':
            return

        # copy the data (and the input's own mask, if any) into result
        self.InputArray[key].writeInto(result).wait()

        # OR the additional mask into the mask already present in result
        extra_mask = self.InputMask[key].wait()
        result.mask[...] |= extra_mask

    def propagateDirty(self, slot, subindex, roi):
        if slot.name in ("InputArray", "InputMask"):
            # pointwise operation: the same region becomes dirty
            self.Output.setDirty(roi.toSlice())
        else:
            assert False, "Unknown dirty input slot"
Пример #6
0
class CountExecutes(Operator):
    """Pass-through operator that counts how often ``execute`` is called."""
    Input = InputSlot()
    Output = OutputSlot()

    # class-level default; the first increment creates an instance attribute
    numExecutes = 0

    def setupOutputs(self):
        # Output is an exact copy of Input
        self.Output.meta.assignFrom(self.Input.meta)

    def propagateDirty(self, slot, subindex, roi):
        # forward the dirty region unchanged
        self.Output.setDirty(roi)

    def execute(self, slot, sunbindex, roi, result):
        self.numExecutes = self.numExecutes + 1
        request = self.Input.get(roi)
        request.writeInto(result)
        request.block()
Пример #7
0
class OpSplitMaskArray(Operator):
    """Split a masked array into its data, its mask and its fill value.

    ``OutputArray`` receives the raw data and ``OutputMask`` the boolean mask.
    ``OutputFillValue`` receives the fill value; its meta shape is set to
    ``()``.
    """
    name = "OpSplitMaskArray"
    category = "Pointwise"

    Input = InputSlot(allow_mask=True)

    OutputArray = OutputSlot()
    OutputMask = OutputSlot()
    OutputFillValue = OutputSlot()

    def __init__(self, *args, **kwargs):
        super( OpSplitMaskArray, self ).__init__( *args, **kwargs )

    def setupOutputs(self):
        # Copy the input metadata to all three outputs, then adjust each one
        self.OutputArray.meta.assignFrom( self.Input.meta )
        self.OutputArray.meta.has_mask = False

        self.OutputMask.meta.assignFrom( self.Input.meta )
        self.OutputMask.meta.has_mask = False
        # `numpy.bool8` is deprecated since NumPy 1.24 and removed in 2.0;
        # `numpy.bool_` is the same type.
        self.OutputMask.meta.dtype = numpy.bool_

        self.OutputFillValue.meta.assignFrom( self.Input.meta )
        self.OutputFillValue.meta.has_mask = False
        # the fill value is a single scalar
        self.OutputFillValue.meta.shape = tuple()

    def execute(self, slot, subindex, roi, result):
        key = roi.toSlice()

        input_subview = self.Input[key].wait()

        if slot.name == 'OutputArray':
            result[...] = input_subview.data
        elif slot.name == 'OutputMask':
            result[...] = input_subview.mask
        elif slot.name == 'OutputFillValue':
            result[...] = input_subview.fill_value

    def propagateDirty(self, slot, subindex, roi):
        if (slot.name == "Input"):
            slicing = roi.toSlice()
            self.OutputArray.setDirty(slicing)
            self.OutputMask.setDirty(slicing)
            # the scalar fill value output is invalidated as a whole
            self.OutputFillValue.setDirty(Ellipsis)
        else:
            assert False, "Unknown dirty input slot"
class OpA(Operator):
    """Minimal test operator that counts calls to ``setupOutputs``.

    ``output`` takes the shape and dtype of ``input``. ``execute`` and
    ``propagateDirty`` are no-ops.
    """
    input = InputSlot()
    output = OutputSlot()

    def __init__(self, parent=None):
        Operator.__init__(self, parent)
        self.countSetupOutputs = 0

    def setupOutputs(self):
        self.countSetupOutputs += 1
        in_meta = self.input.meta
        self.output.meta.shape = in_meta.shape
        self.output.meta.dtype = in_meta.dtype

    def execute(self, slot, subindex, roi, result):
        # intentionally does nothing
        return None

    def propagateDirty(self, slot, subindex, roi):
        # intentionally does nothing
        return None
Пример #9
0
class OpLabelVolume(Operator):
    """Connected-component labeling of a volume with a selectable backend.

    The input is reordered to 'txyzc' and labeled by the operator that
    ``Method`` selects from ``self._labelOps``. The result is reordered back
    to the input's axis order.
    """

    name = "OpLabelVolume"

    ## provide the volume to label here
    # (arbitrary shape, dtype could be restricted, see the implementations
    # property supportedDtypes below)
    Input = InputSlot()

    ## provide labels that are treated as background
    # the shape of the background labels must match the shape of the volume in
    # channel and in time axis, and must have no spatial axes.
    # E.g.: volume.taggedShape = {'x': 10, 'y': 12, 'z': 5, 'c': 3, 't': 100}
    # ==>
    # background.taggedShape = {'c': 3, 't': 100}
    # TODO relax requirements (single value is already working)
    Background = InputSlot(optional=True)

    # Bypass cache (for headless mode)
    BypassModeEnabled = InputSlot(value=False)

    ## decide which CCL method to use
    #
    # currently available:
    # * 'vigra': use the fast algorithm from ukoethe/vigra
    # * 'blocked': use the memory saving algorithm from thorbenk/blockedarray
    #
    # A change here deletes all previously cached results.
    Method = InputSlot(value="vigra")

    ## Labeled volume
    # Axistags and shape are the same as on the Input, dtype is an integer
    # datatype.
    # This slot operates on a what-you-request-is-what-you-get basis, if you
    # request a subregion only that subregion will be considered for labeling
    # and no internal caches are used. If you want consistent labels for
    # subsequent requests, use CachedOutput instead.
    # This slot will be set dirty by time and channel if the background or the
    # input changes for the respective time-channel-slice.
    Output = OutputSlot()

    ## Cached label image
    # Axistags and shape are the same as on the Input, dtype is an integer
    # datatype.
    # This slot extends the ROI to the full xyz volume (c and t are unaffected)
    # and computes the labeling for the whole volume. As long as the input does
    # not get dirty, subsequent requests to this slot guarantee consistent
    # labelings. The internal cache in use is an OpCompressedCache.
    # This slot will be set dirty by time and channel if the background or the
    # input changes for the respective time-channel-slice.
    CachedOutput = OutputSlot()

    # cache access, see OpCompressedCache
    CleanBlocks = OutputSlot()

    def __init__(self, *args, **kwargs):
        super(OpLabelVolume, self).__init__(*args, **kwargs)

        # we just want to have 5d data internally
        op5 = OpReorderAxes(parent=self)
        op5.Input.connect(self.Input)
        op5.AxisOrder.setValue("txyzc")
        self._op5 = op5

        # labeling operator; created lazily in setupOutputs
        self._opLabel = None

        # reorder the labeling results back to the input's axis order
        self._op5_2 = OpReorderAxes(parent=self)
        self._op5_2_cached = OpReorderAxes(parent=self)

        self.Output.connect(self._op5_2.Output)
        self.CachedOutput.connect(self._op5_2_cached.Output)

        # available OpLabelingABCs:
        # TODO: OpLazyConnectedComponents and _OpLabelBlocked does not conform to OpLabelingABC
        self._labelOps = {
            "vigra": _OpLabelVigra,
            "blocked": _OpLabelBlocked,
            "lazy": OpLazyConnectedComponents
        }

    def setupOutputs(self):
        method = self.Method.value
        # `unicode` only exists on Python 2 (or when imported at file level);
        # fall back to plain `str` instead of raising NameError on Python 3.
        try:
            string_types = (str, unicode)
        except NameError:
            string_types = (str,)
        if not isinstance(method, string_types):
            # non-string values are indexed; presumably a sequence whose
            # first element is the method name -- confirm with callers
            method = method[0]

        if self._opLabel is not None and type(
                self._opLabel) != self._labelOps[method]:
            # fully remove old labeling operator
            self._op5_2.Input.disconnect()
            self._op5_2_cached.Input.disconnect()
            self._opLabel.Input.disconnect()
            self._opLabel = None

        if self._opLabel is None:
            self._opLabel = self._labelOps[method](parent=self)
            self._opLabel.Input.connect(self._op5.Output)
            # Compare by value: `is` tests object identity, which equal
            # strings (e.g. values read from a project file) need not share,
            # so the bypass slot could silently stay unconnected.
            if method == "vigra":
                self._opLabel.BypassModeEnabled.connect(self.BypassModeEnabled)

        # connect reordering operators
        self._op5_2.Input.connect(self._opLabel.Output)
        self._op5_2_cached.Input.connect(self._opLabel.CachedOutput)

        # set the final reordering operator's AxisOrder to that of the input
        origOrder = self.Input.meta.getAxisKeys()
        self._op5_2.AxisOrder.setValue(origOrder)
        self._op5_2_cached.AxisOrder.setValue(origOrder)

        # connect cache access slots
        self.CleanBlocks.connect(self._opLabel.CleanBlocks)

        # set background values
        self._setBG()

    def propagateDirty(self, slot, subindex, roi):
        if slot == self.BypassModeEnabled:
            pass
        elif slot == self.Method:
            # We are changing the labeling method. In principle, the labelings
            # are equivalent, but not necessarily the same!
            self.Output.setDirty(slice(None))
        elif slot == self.Input:
            # handled by internal operator
            pass
        elif slot == self.Background:
            # propagate the background values, output will be set dirty in
            # internal operator
            self._setBG()

    def setInSlot(self, slot, subindex, roi, value):
        # Nothing to do here.
        # Our Input slots are directly fed into the cache,
        #  so all calls to __setitem__ are forwarded automatically
        pass

    ## set the background values of inner operator
    #
    # A single background value is broadcast to every (c, t) pair. Otherwise
    # the Background slot's value is reinterpreted with axes 'ct'. In both
    # cases it is expanded to 'txyzc' with singleton spatial axes.
    def _setBG(self):
        if self.Background.ready():
            val = self.Background.value
        else:
            val = 0
        bg = np.asarray(val)
        t = self._op5.Output.meta.shape[0]
        c = self._op5.Output.meta.shape[4]
        if bg.size == 1:
            bg = np.zeros((c, t))
            bg[:] = val
            bg = vigra.taggedView(bg, axistags="ct")
        else:
            bg = vigra.taggedView(val, axistags=self.Background.meta.axistags)
            bg = bg.withAxes(*"ct")
        bg = bg.withAxes(*"txyzc")
        self._opLabel.Background.setValue(bg)
Пример #10
0
class OpLabelingABC(with_metaclass(ABCMeta, Operator)):
    """Abstract base for connected-component labeling operators (axes 'txyzc').

    Subclasses provide ``supportedDtypes`` and ``_label3d``. This base class
    does the rest:

    * validates the input dtype;
    * splits each request into one 3d labeling job per (t, c) pair, run in
      parallel;
    * caches ``Output`` in an ``OpBlockedArrayCache`` whose blocks span
      whole spatial volumes.
    """
    Input = InputSlot()

    ## background with axes 'txyzc', spatial axes must be singletons
    Background = InputSlot()

    # Bypass cache (for headless mode)
    BypassModeEnabled = InputSlot(value=False)

    Output = OutputSlot()
    CachedOutput = OutputSlot()

    # cache access, see OpCompressedCache
    CleanBlocks = OutputSlot()

    # the numeric type that is used for labeling
    labelType = np.uint32

    ## list of supported dtypes
    @abstractproperty
    def supportedDtypes(self):
        pass

    def __init__(self, *args, **kwargs):
        super(OpLabelingABC, self).__init__(*args, **kwargs)
        self._cache = OpBlockedArrayCache(parent=self)
        self._cache.name = "OpLabelVolume.OutputCache"
        self._cache.BypassModeEnabled.connect(self.BypassModeEnabled)
        self._cache.Input.connect(self.Output)
        self.CachedOutput.connect(self._cache.Output)
        self.CleanBlocks.connect(self._cache.CleanBlocks)

    def setupOutputs(self):

        # check if the input dtype is valid
        if self.Input.ready():
            dtype = self.Input.meta.dtype
            if dtype not in self.supportedDtypes:
                # NOTE(review): the message always says 'vigra', even for
                # subclasses implementing other methods
                msg = "{}: dtype '{}' not supported " "with method 'vigra'. Supported types: {}"
                msg = msg.format(self.name, dtype, self.supportedDtypes)
                raise ValueError(msg)

        # set cache chunk shape to the whole spatial volume
        # (`np.int` was removed in NumPy 1.24; it was an alias of `int`)
        shape = np.asarray(self.Input.meta.shape, dtype=int)
        shape[0] = 1
        shape[4] = 1
        self._cache.BlockShape.setValue(tuple(shape))

        # setup meta for Output
        self.Output.meta.assignFrom(self.Input.meta)
        self.Output.meta.dtype = self.labelType

    def propagateDirty(self, slot, subindex, roi):
        if slot == self.BypassModeEnabled:
            pass
        else:
            # a change in either input or background makes the whole
            # time-channel-slice dirty (CCL is a global operation)
            outroi = roi.copy()
            outroi.start[1:4] = (0, 0, 0)
            outroi.stop[1:4] = self.Input.meta.shape[1:4]
            self.Output.setDirty(outroi)
            self.CachedOutput.setDirty(outroi)

    def setInSlot(self, slot, subindex, roi, value):
        # Nothing to do here.
        # Our Input slots are directly fed into the cache,
        #  so all calls to __setitem__ are forwarded automatically
        pass

    def execute(self, slot, subindex, roi, result):
        if slot == self.Output:
            # just label the ROI and write it to result
            self._label(roi, result)
        else:
            raise ValueError("Request to unknown slot {}".format(slot))

    ## label ``roi`` into ``result`` with one parallel request per (t, c)
    def _label(self, roi, result):
        result = vigra.taggedView(result, axistags=self.Output.meta.axistags)
        # get the background values
        bg = self.Background[...].wait()
        bg = vigra.taggedView(bg, axistags=self.Background.meta.axistags)
        bg = bg.withAxes(*"ct")
        assert np.all(
            self.Background.meta.shape[0] == self.Input.meta.shape[0]
        ), "Shape of background values incompatible to shape of Input"
        assert np.all(
            self.Background.meta.shape[4] == self.Input.meta.shape[4]
        ), "Shape of background values incompatible to shape of Input"

        # do labeling in parallel over channels and time slices
        pool = RequestPool()

        start = np.asarray(roi.start, dtype=int)
        stop = np.asarray(roi.stop, dtype=int)
        for ti, t in enumerate(range(roi.start[0], roi.stop[0])):
            start[0], stop[0] = t, t + 1
            for ci, c in enumerate(range(roi.start[4], roi.stop[4])):
                start[4], stop[4] = c, c + 1
                newRoi = SubRegion(self.Output,
                                   start=tuple(start),
                                   stop=tuple(stop))
                # ti/ci index into result (relative to roi); t/c index into
                # the full background array
                resView = result[ti, ..., ci].withAxes(*"xyz")
                req = Request(partial(self._label3d, newRoi, bg[c, t],
                                      resView))
                pool.add(req)

        logger.debug(
            "{}: Computing connected components for ROI {} ...".format(
                self.name, roi))
        pool.wait()
        pool.clean()
        logger.debug("{}: Connected components computed.".format(self.name))

    ## compute the requested roi and put the results into result
    #
    # @param result the array to write into, 3d xyz
    @abstractmethod
    def _label3d(self, roi, bg, result):
        pass
Пример #11
0
class OpObjectsSegment(OpGraphCut):
    """Graph-cut segmentation restricted to the neighbourhood of each object.

    ``LabelImage`` holds pre-computed object labels, with 0 as background.
    For every object:

    1. A box around its bounding box, enlarged by ``Margin``, is cut out of
       ``Prediction``.
    2. The box is segmented with ``segmentGC``.
    3. The connected graph-cut component that overlaps the object's seed is
       kept.

    The union of all kept components is relabeled and written to ``Output``.
    """
    name = "OpObjectsSegment"

    # thresholded predictions, or otherwise obtained ROI indicators
    # (a value of 0 is assumed to be background and ignored)
    LabelImage = InputSlot()

    # margin around each object (always xyz!)
    Margin = InputSlot(value=np.asarray((20, 20, 20)))

    # bounding boxes of the labeled objects: a request returns a dict with
    # the keys "Count", "Coord<Minimum>" and "Coord<Maximum>" for the labels
    # in the requested roi (see _execute_bbox)
    BoundingBoxes = OutputSlot(stype=Opaque)

    ### slots inherited from OpGraphCut ###
    # Prediction: prediction maps
    # Beta: graph cut parameter
    # Output, CachedOutput: labeled segmentation image

    def __init__(self, *args, **kwargs):
        super(OpObjectsSegment, self).__init__(*args, **kwargs)

    def setupOutputs(self):
        super(OpObjectsSegment, self).setupOutputs()
        # sanity checks
        # NOTE(review): the OpGraphCut version visible in this file asserts
        # Prediction is 'tzyxc', while LabelImage must be 'txyzc' here --
        # confirm which parent version this class is paired with.
        shape = self.LabelImage.meta.shape
        assert len(shape) == 5,\
            "Label image must be a full 5d volume (txyzc)"
        tags = self.LabelImage.meta.getAxisKeys()
        tags = "".join(tags)
        assert tags == 'txyzc',\
            "Label image has wrong axes order "\
            "(expected: txyzc, got: {})".format(tags)

        # bounding boxes are just one element arrays of type object, but we
        # want to request boxes from a specific region, therefore BoundingBoxes
        # needs a shape
        shape = self.Prediction.meta.shape
        self.BoundingBoxes.meta.shape = shape
        # `np.object` was removed in NumPy 1.24; it was an alias of `object`
        self.BoundingBoxes.meta.dtype = object
        self.BoundingBoxes.meta.axistags = vigra.defaultAxistags('txyzc')

    def execute(self, slot, subindex, roi, result):
        # check the axes - cannot do this in setupOutputs because we could be
        # in some invalid intermediate state where the dimensions do not agree
        shape = self.LabelImage.meta.shape
        agree = [i == j for i, j in zip(self.Prediction.meta.shape, shape)]
        assert all(agree),\
            "shape mismatch: {} vs. {}".format(self.Prediction.meta.shape,
                                               shape)
        if slot == self.BoundingBoxes:
            return self._execute_bbox(roi, result)
        elif slot == self.Output:
            self._execute_graphcut(roi, result)
        else:
            raise NotImplementedError(
                "execute() is not implemented for slot {}".format(str(slot)))

    def _execute_bbox(self, roi, result):
        """Return size and bounding box of every label in ``roi``.

        ``result`` is not written; the returned dict (keys "Count",
        "Coord<Minimum>", "Coord<Maximum>") is passed back from execute().
        """
        cc = self.LabelImage.get(roi).wait()
        cc = vigra.taggedView(cc, axistags=self.LabelImage.meta.axistags)
        cc = cc.withAxes(*'xyz')

        logger.debug("computing bboxes...")
        feats = vigra.analysis.extractRegionFeatures(
            cc.astype(np.float32),
            cc.astype(np.uint32),
            features=["Count", "Coord<Minimum>", "Coord<Maximum>"])
        feats_dict = {}
        feats_dict["Coord<Minimum>"] = feats["Coord<Minimum>"]
        feats_dict["Coord<Maximum>"] = feats["Coord<Maximum>"]
        feats_dict["Count"] = feats["Count"]
        return feats_dict

    def _execute_graphcut(self, roi, result):
        """Segment every object of a single (t, c) slice and label the result."""
        for i in (0, 4):
            assert roi.stop[i] - roi.start[i] == 1,\
                "Invalid roi for graph-cut: {}".format(str(roi))

        margin = self.Margin.value
        beta = self.Beta.value
        MAXBOXSIZE = 10000000  # FIXME justification??

        ## request the bounding box coordinates ##
        # BoundingBoxes.execute() returns the dict built by _execute_bbox
        feats = self.BoundingBoxes.get(roi).wait()
        mins = feats["Coord<Minimum>"]
        maxs = feats["Coord<Maximum>"]
        nobj = mins.shape[0]
        # These are indices, so they need an integer datatype. Use a signed
        # one so that subtracting the margin cannot wrap around below zero.
        mins = mins.astype(np.intp)
        maxs = maxs.astype(np.intp)

        ## request the prediction image ##
        pred = self.Prediction.get(roi).wait()
        pred = vigra.taggedView(pred, axistags=self.Prediction.meta.axistags)
        pred = pred.withAxes(*'xyz')

        ## request the connected components image ##
        cc = self.LabelImage.get(roi).wait()
        cc = vigra.taggedView(cc, axistags=self.LabelImage.meta.axistags)
        cc = cc.withAxes(*'xyz')

        # provide xyz view for the output (just need 8bit for segmentation)
        resultXYZ = vigra.taggedView(np.zeros(cc.shape, dtype=np.uint8),
                                     axistags='xyz')

        def processSingleObject(i):
            """Segment object ``i`` and mark its voxels with 1 in resultXYZ."""
            logger.debug("processing object {}".format(i))
            # maxs are inclusive, so we need to add 1
            xmin = max(mins[i][0] - margin[0], 0)
            ymin = max(mins[i][1] - margin[1], 0)
            zmin = max(mins[i][2] - margin[2], 0)
            xmax = min(maxs[i][0] + margin[0] + 1, cc.shape[0])
            ymax = min(maxs[i][1] + margin[1] + 1, cc.shape[1])
            zmax = min(maxs[i][2] + margin[2] + 1, cc.shape[2])
            ccbox = cc[xmin:xmax, ymin:ymax, zmin:zmax]
            resbox = resultXYZ[xmin:xmax, ymin:ymax, zmin:zmax]

            nVoxels = ccbox.size
            if nVoxels > MAXBOXSIZE:
                # problem too large to run graph cut, assign to seed
                logger.warning("Object {} too large for graph cut.".format(i))
                resbox[ccbox == i] = 1
                return

            probbox = pred[xmin:xmax, ymin:ymax, zmin:zmax]
            gcsegm = segmentGC(probbox, beta)
            gcsegm = vigra.taggedView(gcsegm, axistags='xyz')
            ccsegm = vigra.analysis.labelVolumeWithBackground(
                gcsegm.astype(np.uint8))

            # Extended bboxes of different objects might overlap.
            # To avoid conflicting segmentations, we find all connected
            # components in the results and only take the one, which
            # overlaps with the object "core" or "seed", defined by the
            # pre-thresholding
            seed = ccbox == i
            filtered = seed * ccsegm
            passed = vigra.analysis.unique(filtered.astype(np.uint32))
            assert len(passed.shape) == 1
            # Drop the background label explicitly instead of assuming it sits
            # at index 0: it is absent when the seed and the graph-cut
            # segmentation both cover the entire box.
            overlapping = passed[passed != 0]
            if overlapping.size > 1:
                logger.warning("ambiguous label assignment for region {}".format(
                    (xmin, xmax, ymin, ymax, zmin, zmax)))
                resbox[ccbox == i] = 1
            elif overlapping.size == 0:
                logger.warning("box {} segmented out with beta {}".format(
                    i, beta))
            else:
                # assign to the overlap region
                label = overlapping[0]
                resbox[ccsegm == label] = 1

        pool = RequestPool()
        # FIXME make sure that the parallel computations fit into memory
        # index 0 of the region features is the background label -> skip it
        for i in range(1, nobj):
            req = Request(functools.partial(processSingleObject, i))
            pool.add(req)

        logger.info("Processing {} objects ...".format(nobj - 1))

        pool.wait()
        pool.clean()

        logger.info("object loop done")

        # prepare result
        resView = vigra.taggedView(result, axistags=self.Output.meta.axistags)
        resView = resView.withAxes(*'xyz')

        # some labels could have been removed => relabel
        vigra.analysis.labelVolumeWithBackground(resultXYZ, out=resView)

    def propagateDirty(self, slot, subindex, roi):
        super(OpObjectsSegment, self).propagateDirty(slot, subindex, roi)

        if slot == self.LabelImage:
            # time-channel slices are pairwise independent

            # determine t, c from input volume
            t_ind = 0
            c_ind = 4
            t = (roi.start[t_ind], roi.stop[t_ind])
            c = (roi.start[c_ind], roi.stop[c_ind])

            # set output dirty
            start = t[0:1] + (0, ) * 3 + c[0:1]
            stop = t[1:2] + self.Output.meta.shape[1:4] + c[1:2]
            roi = SubRegion(self.Output, start=start, stop=stop)
            self.Output.setDirty(roi)
        elif slot == self.Margin:
            # margin affects the whole volume
            self.Output.setDirty(slice(None))
Пример #12
0
class OpGraphCut(Operator):
    """Binary graph-cut segmentation of a 'tzyxc' prediction map.

    Each (t, c) slice is segmented with ``segmentGC`` using ``Beta``, then
    labeled into connected components. Label 0 is background; label i > 0 is
    foreground object i. ``CachedOutput`` serves ``Output`` through an
    ``OpCompressedCache`` whose blocks span one whole c-t slice.
    """
    name = "OpGraphCut"

    # prediction maps
    Prediction = InputSlot()

    # graph cut parameter, usually called lambda
    Beta = InputSlot(value=.2)

    # labeled segmentation image
    #     i=0: background
    #     i>0: connected foreground object i
    Output = OutputSlot()
    CachedOutput = OutputSlot()

    def __init__(self, *args, **kwargs):
        super(OpGraphCut, self).__init__(*args, **kwargs)
        # created (and re-created) in setupOutputs
        self._cache = None

    def setupOutputs(self):
        # sanity checks on the prediction layout
        pred_shape = self.Prediction.meta.shape
        assert len(pred_shape) == 5,\
            "Prediction maps must be a full 5d volume (tzyxc)"
        axis_keys = "".join(self.Prediction.meta.getAxisKeys())
        assert axis_keys == 'tzyxc',\
            "Prediction maps have wrong axes order"\
            "(expected: tzyxc, got: {})".format(axis_keys)

        # tear down the cache of a previous configuration, if any
        if self._cache is not None:
            self.CachedOutput.disconnect()
            self._cache.cleanUp()
            self._cache = None

        new_cache = OpCompressedCache(parent=self)
        new_cache.name = "{}._cache".format(self.name)
        new_cache.Input.connect(self.Output)
        self._cache = new_cache
        self.CachedOutput.connect(self._cache.Output)

        # the output is a label image shaped like the predictions
        self.Output.meta.assignFrom(self.Prediction.meta)
        self.Output.meta.dtype = np.uint32

        # one cache block per c-t slice
        block_shape = list(self.Prediction.meta.shape)
        block_shape[0] = block_shape[4] = 1
        self._cache.BlockShape.setValue(tuple(block_shape))

    def execute(self, slot, subindex, roi, result):
        assert slot == self.Output, "Unknown slot requested: {}".format(slot)
        # the graph cut is global within a slice: exactly one t and one c
        for axis in (0, 4):
            assert roi.stop[axis] - roi.start[axis] == 1,\
                "Invalid roi for graph-cut: {}".format(str(roi))

        pred = vigra.taggedView(self.Prediction.get(roi).wait(),
                                axistags=self.Prediction.meta.axistags)
        pred = pred.withAxes(*'zyx')

        out_view = vigra.taggedView(result,
                                    axistags=self.Output.meta.axistags)
        out_view = out_view.withAxes(*'zyx')

        logger.info("Executing graph cut ... (this might take a while)")
        binary = vigra.taggedView(segmentGC(pred, self.Beta.value), 'zyx')
        logger.info("Graph-cut done")

        # label the segmentation so that this operator is consistent with
        # the other thresholding operators
        vigra.analysis.labelVolumeWithBackground(binary.astype(np.uint8),
                                                 out=out_view)

    def propagateDirty(self, slot, subindex, roi):
        if slot == self.Beta:
            # a new beta changes the whole volume
            self.Output.setDirty(slice(None))
        elif slot == self.Prediction:
            # every touched (t, c) slice becomes dirty in its entirety
            t_start, t_stop = roi.start[0], roi.stop[0]
            c_start, c_stop = roi.start[4], roi.stop[4]

            spatial_extent = self.Output.meta.shape[1:4]
            dirty_start = (t_start,) + (0, 0, 0) + (c_start,)
            dirty_stop = (t_stop,) + spatial_extent + (c_stop,)
            self.Output.setDirty(
                SubRegion(self.Output, start=dirty_start, stop=dirty_stop))
Пример #13
0
class OpLabelingABC(Operator):
    """Abstract base for connected-component labeling operators (axes 'xyzct').

    Subclasses provide ``supportedDtypes`` and ``_label3d``. This base class
    does the rest:

    * validates the input dtype;
    * splits each request into one 3d labeling job per (c, t) pair, run in
      parallel;
    * caches ``Output`` in an ``OpCompressedCache`` whose blocks span whole
      spatial volumes.
    """
    # NOTE(review): `__metaclass__` only takes effect on Python 2; on Python 3
    # abstractness is enforced only if Operator's metaclass derives from
    # ABCMeta -- confirm.
    __metaclass__ = ABCMeta

    ## input with axes 'xyzct'
    Input = InputSlot()

    ## background with axes 'xyzct', spatial axes must be singletons
    Background = InputSlot()

    Output = OutputSlot()
    CachedOutput = OutputSlot()

    # the numeric type that is used for labeling
    labelType = np.uint32

    ## list of supported dtypes
    @abstractproperty
    def supportedDtypes(self):
        pass

    def __init__(self, *args, **kwargs):
        super(OpLabelingABC, self).__init__(*args, **kwargs)
        self._cache = OpCompressedCache(parent=self)
        self._cache.name = "OpLabelVolume.OutputCache"
        self._cache.Input.connect(self.Output)
        self.CachedOutput.connect(self._cache.Output)

    def setupOutputs(self):

        # check if the input dtype is valid
        if self.Input.ready():
            dtype = self.Input.meta.dtype
            if dtype not in self.supportedDtypes:
                msg = "{}: dtype '{}' not supported "\
                    "with method 'vigra'. Supported types: {}"
                msg = msg.format(self.name, dtype, self.supportedDtypes)
                raise ValueError(msg)

        # set cache chunk shape to the whole spatial volume
        # (`np.int` was removed in NumPy 1.24; it was an alias of `int`)
        shape = np.asarray(self.Input.meta.shape, dtype=int)
        shape[3:5] = 1
        self._cache.BlockShape.setValue(tuple(shape))

        # setup meta for Output
        self.Output.meta.assignFrom(self.Input.meta)
        self.Output.meta.dtype = self.labelType

    def propagateDirty(self, slot, subindex, roi):
        # a change in either input or background makes the whole
        # time-channel-slice dirty (CCL is a global operation)
        outroi = roi.copy()
        outroi.start[:3] = (0, 0, 0)
        outroi.stop[:3] = self.Input.meta.shape[:3]
        self.Output.setDirty(outroi)
        self.CachedOutput.setDirty(outroi)

    def execute(self, slot, subindex, roi, result):
        if slot == self.Output:
            # just label the ROI and write it to result
            self._label(roi, result)
        else:
            raise ValueError("Request to unknown slot {}".format(slot))

    ## label ``roi`` into ``result`` with one parallel request per (c, t)
    def _label(self, roi, result):
        result = vigra.taggedView(result, axistags=self.Output.meta.axistags)
        # get the background values
        bg = self.Background[...].wait()
        bg = vigra.taggedView(bg, axistags=self.Background.meta.axistags)
        bg = bg.withAxes(*'ct')
        assert np.all(self.Background.meta.shape[3:] ==
                      self.Input.meta.shape[3:]),\
            "Shape of background values incompatible to shape of Input"

        # do labeling in parallel over channels and time slices
        pool = RequestPool()

        start = np.asarray(roi.start, dtype=int)
        stop = np.asarray(roi.stop, dtype=int)
        for ti, t in enumerate(range(roi.start[4], roi.stop[4])):
            start[4], stop[4] = t, t + 1
            for ci, c in enumerate(range(roi.start[3], roi.stop[3])):
                start[3], stop[3] = c, c + 1
                newRoi = SubRegion(self.Output,
                                   start=tuple(start),
                                   stop=tuple(stop))
                # ci/ti index into result (relative to roi); c/t index into
                # the full background array
                resView = result[..., ci, ti].withAxes(*'xyz')
                req = Request(partial(self._label3d, newRoi, bg[c, t],
                                      resView))
                pool.add(req)

        logger.debug(
            "{}: Computing connected components for ROI {} ...".format(
                self.name, roi))
        pool.wait()
        pool.clean()
        logger.debug("{}: Connected components computed.".format(self.name))

    ## compute the requested roi and put the results into result
    #
    # @param result the array to write into, 3d xyz
    @abstractmethod
    def _label3d(self, roi, bg, result):
        pass