class OpArrayCacheCpp(OpCache):
    """ Allocates a block of memory as large as Input.meta.shape (==Output.meta.shape)
        with the same dtype in order to be able to cache results.
        
        blockShape: dirty regions are tracked with a granularity of blockShape
    """

    name = "ArrayCache"
    description = "numpy.ndarray caching class"
    category = "misc"

    DefaultBlockSize = 64

    #Input
    Input = InputSlot()
    blockShape = InputSlot(value=DefaultBlockSize)
    fixAtCurrent = InputSlot(value=False)

    #Output
    CleanBlocks = OutputSlot()
    Output = OutputSlot()

    loggingName = __name__ + ".OpArrayCache"
    logger = logging.getLogger(loggingName)
    traceLogger = logging.getLogger("TRACE." + loggingName)

    # Block states
    IN_PROCESS = 0
    DIRTY = 1
    CLEAN = 2
    FIXED_DIRTY = 3

    def __init__(self, *args, **kwargs):
        super(OpArrayCacheCpp, self).__init__(*args, **kwargs)
        self._origBlockShape = self.DefaultBlockSize

        self._last_access = None
        self._blockShape = None
        self._fixed = False
        self._lock = Lock()
        self._cacheHits = 0
        self._has_fixed_dirty_blocks = False
        self._memory_manager = ArrayCacheMemoryMgr.instance
        self._running = 0

    def usedMemory(self):
        return 0  # FIXME: real accounting is disabled (see commented-out version below)

    #def usedMemory(self):
    #    if self._cache is not None:
    #        return self._cache.nbytes
    #    else:
    #        return 0

    #def _blockShapeForIndex(self, index):
    #    if self._cache is None:
    #        return None
    #    cacheShape = numpy.array(self._cache.shape)
    #    blockStart = index * self._blockShape
    #    blockStop = numpy.minimum(blockStart + self._blockShape, cacheShape)

    def fractionOfUsedMemoryDirty(self):
        return 0.0  # FIXME: real accounting is disabled (see commented-out version below)

    #def fractionOfUsedMemoryDirty(self):
    #    totAll   = numpy.prod(self.Output.meta.shape)
    #    totDirty = 0
    #    for i, v in enumerate(self._blockState.ravel()):
    #        sh = self._blockShapeForIndex(i)
    #        if sh is None:
    #            continue
    #        if v == self.DIRTY or v == self.FIXED_DIRTY:
    #            totDirty += numpy.prod(sh)
    #    return totDirty/float(totAll)

    def lastAccessTime(self):
        return self._last_access

    def generateReport(self, report):
        pass

    #def generateReport(self, report):
    #    report.name = self.name
    #    report.fractionOfUsedMemoryDirty = self.fractionOfUsedMemoryDirty()
    #    report.usedMemory = self.usedMemory()
    #    report.lastAccessTime = self.lastAccessTime()
    #    report.dtype = self.Output.meta.dtype
    #    report.type = type(self)
    #    report.id = id(self)

    def _freeMemory(self, refcheck=True):
        return 0  # FIXME: freeing is disabled (see commented-out version below)

    #def _freeMemory(self, refcheck = True):
    #    with self._cacheLock:
    #        freed  = self.usedMemory()
    #        if self._cache is not None:
    #            fshape = self._cache.shape
    #            try:
    #                self._cache.resize((1,), refcheck = refcheck)
    #            except ValueError:
    #                freed = 0
    #                self.logger.warn("OpArrayCache: freeing failed due to view references")
    #            if freed > 0:
    #                self.logger.debug("OpArrayCache: freed cache of shape:{}".format(fshape))
    #
    #                self._lock.acquire()
    #                self._blockState[:] = OpArrayCache.DIRTY
    #                del self._cache
    #                self._cache = None
    #                self._lock.release()
    #        return freed

    def _allocateManagementStructures(self):
        pass

    #def _allocateManagementStructures(self):
    #    with Tracer(self.traceLogger):
    #        shape = self.Output.meta.shape
    #        if type(self._origBlockShape) != tuple:
    #            self._blockShape = (self._origBlockShape,)*len(shape)
    #        else:
    #            self._blockShape = self._origBlockShape
    #
    #        self._blockShape = numpy.minimum(self._blockShape, shape)
    #
    #        self._dirtyShape = numpy.ceil(1.0 * numpy.array(shape) / numpy.array(self._blockShape))
    #
    #        self.logger.debug("Configured OpArrayCache with shape={}, blockShape={}, dirtyShape={}, origBlockShape={}".format(shape, self._blockShape, self._dirtyShape, self._origBlockShape))
    #
    #        #if a request has been submitted to get a block, the request object
    #        #is stored within this array
    #        self._blockQuery = numpy.ndarray(self._dirtyShape, dtype=object)
    #
    #        #keep track of the dirty state of each block
    #        self._blockState = OpArrayCache.DIRTY * numpy.ones(self._dirtyShape, numpy.uint8)
    #
    #        self._blockState[:]= OpArrayCache.DIRTY
    #        self._dirtyState = OpArrayCache.CLEAN

    #def _allocateCache(self):
    #    with self._cacheLock:
    #        self._last_access = None
    #        self._cache_priority = 0
    #        self._running = 0
    #
    #        if self._cache is None or (self._cache.shape != self.Output.meta.shape):
    #            mem = numpy.zeros(self.Output.meta.shape, dtype = self.Output.meta.dtype)
    #            self.logger.debug("OpArrayCache: Allocating cache (size: %dbytes)" % mem.nbytes)
    #            if self._blockState is None:
    #                self._allocateManagementStructures()
    #            self._cache = mem
    #    self._memory_manager.add(self)

    def setupOutputs(self):
        self.CleanBlocks.meta.shape = (1, )
        self.CleanBlocks.meta.dtype = object

        reconfigure = False
        if self.inputs["fixAtCurrent"].ready():
            self._fixed = self.inputs["fixAtCurrent"].value

        if self.inputs["blockShape"].ready() and self.inputs["Input"].ready():
            newBShape = self.inputs["blockShape"].value
            if not isinstance(newBShape, tuple):
                newBShape = tuple([int(newBShape)] *
                                  len(self.Input.meta.shape))
            # Input.ready() is already guaranteed by the enclosing condition
            if self._origBlockShape != newBShape:
                reconfigure = True
            self._origBlockShape = newBShape
            self._blockShape = newBShape

            inputSlot = self.inputs["Input"]
            self.outputs["Output"].meta.assignFrom(inputSlot.meta)

        shape = self.outputs["Output"].meta.shape

        if reconfigure and shape is not None:
            # Use a 'with' block so the lock is released even if we raise below.
            with self._lock:
                if self.Input.meta.dtype == numpy.uint32:
                    t = "uint32"
                elif self.Input.meta.dtype == numpy.int32:
                    t = "int32"
                elif self.Input.meta.dtype == numpy.float32:
                    t = "float32"
                elif self.Input.meta.dtype == numpy.int64:
                    t = "int64"
                elif self.Input.meta.dtype == numpy.uint8:
                    t = "uint8"
                else:
                    raise RuntimeError("dtype %r not supported" %
                                       self.Input.meta.dtype)
                # Instantiate the BlockedArray class that matches our
                # dimensionality and dtype (e.g. BlockedArray3uint8) and
                # mark the entire volume dirty.
                cls = "BlockedArray%d%s" % (len(self._blockShape), t)
                self.b = eval(cls)(self._blockShape)
                self.b.setDirty(tuple([0] * len(self._blockShape)),
                                self.Input.meta.shape, True)

    def propagateDirty(self, slot, subindex, roi):
        shape = self.Output.meta.shape

        key = roi.toSlice()

        if slot == self.inputs["Input"]:
            with self._lock:
                self.b.setDirty(roi.start, roi.stop, True)
            if not self._fixed:
                self.outputs["Output"].setDirty(key)

        if slot == self.inputs["fixAtCurrent"]:
            if self.inputs["fixAtCurrent"].ready():
                self._fixed = self.inputs["fixAtCurrent"].value
                if not self._fixed and self.Output.meta.shape is not None and self._has_fixed_dirty_blocks:
                    pass
                    '''
                    # We've become unfixed, so we need to notify downstream 
                    #  operators of every block that became dirty while we were fixed.
                    # Convert all FIXED_DIRTY states into DIRTY states
                    with self._lock:
                        dirtyBlocks = self.b.dirtyBlocks(000, self.Output.meta.shape)
                        newDirtyBlocks = diff(dirtyBlocks, oldDityBlocks)
                        self._has_fixed_dirty_blocks = False
                    newDirtyBlocks = numpy.transpose(numpy.nonzero(cond))
                    
                    # To avoid lots of setDirty notifications, we simply merge all the dirtyblocks into one single superblock.
                    # This should be the best option in most cases, but could be bad in some cases.
                    # TODO: Optimize this by merging the dirty blocks via connected components or something.
                    cacheShape = numpy.array(self.Output.meta.shape)
                    dirtyStart = cacheShape
                    dirtyStop = [0] * len(cacheShape)
                    for index in newDirtyBlocks:
                        blockStart = index * self._blockShape
                        blockStop = numpy.minimum(blockStart + self._blockShape, cacheShape)
                        
                        dirtyStart = numpy.minimum(dirtyStart, blockStart)
                        dirtyStop = numpy.maximum(dirtyStop, blockStop)

                    if len(newDirtyBlocks) > 0:
                        self.Output.setDirty( dirtyStart, dirtyStop )
                    '''

    def _updatePriority(self, new_access=None):
        pass

    #def _updatePriority(self, new_access = None):
    #    if self._last_access is None:
    #        self._last_access = new_access or time.time()
    #    cur_time = time.time()
    #    delta = cur_time - self._last_access + 1e-9
    #
    #    self._last_access = cur_time
    #    new_prio = 0.5 * self._cache_priority + delta
    #    self._cache_priority = new_prio

    def execute(self, slot, subindex, roi, result):
        if slot == self.Output:
            return self._executeOutput(slot, subindex, roi, result)
        elif slot == self.CleanBlocks:
            pass  #FIXME
            #return self._executeCleanBlocks(slot, subindex, roi, result)

    def _executeOutput(self, slot, subindex, roi, result):
        key = roi.toSlice()

        shape = self.Output.meta.shape
        start, stop = sliceToRoi(key, shape)

        self.traceLogger.debug("Acquiring ArrayCache lock...")
        self._lock.acquire()
        self.traceLogger.debug("ArrayCache lock acquired.")

        ch = self._cacheHits
        ch += 1
        self._cacheHits = ch

        self._running += 1

        bp, bq = self.b.dirtyBlocks(start, stop)

        #print "there are %d dirty blocks" % bp.shape[0]

        if not self._fixed:
            reqs = []
            sh = self.outputs["Output"].meta.shape
            for i in range(bp.shape[0]):
                bStart = tuple([int(t) for t in bp[i, :]])
                bStop = tuple([int(t) for t in numpy.minimum(bq[i, :], sh)])
                key = roiToSlice(bStart, bStop)
                req = self.Input[key]
                reqs.append((req, bStart, bStop))
            for r, bStart, bStop in reqs:
                r.wait()
            for r, bStart, bStop in reqs:
                x = r.wait()
                self.b.writeSubarray(bStart, bStop, r.wait())

        t1 = time.time()
        self.b.readSubarray(start, stop, result)

        #print "read subarray took %f" % (time.time()-t1)

        self._lock.release()

        return result

    def setInSlot(self, slot, subindex, roi, value):
        assert slot == self.inputs["Input"]
        self._cacheHits += 1
        start, stop = roi.start, roi.stop

        with self._lock:
            self.traceLogger.debug("Writing subarray: start=%s, stop=%s, shape=%s",
                                   start, stop, value.shape)
            self.b.writeSubarray(start, stop, value)
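
# A minimal usage sketch for OpArrayCacheCpp above (hedged: this assumes the
# standard lazyflow Graph/slot API and that the BlockedArray* extension
# classes are importable; array contents and shapes are illustrative only):
#
#     from lazyflow.graph import Graph
#     import numpy
#
#     data = numpy.random.randint(0, 255, (64, 64, 64)).astype(numpy.uint8)
#     op = OpArrayCacheCpp(graph=Graph())
#     op.Input.setValue(data)
#     op.blockShape.setValue((32, 32, 32))   # granularity of dirty tracking
#     cached = op.Output[0:64, 0:64, 0:64].wait()  # first access fills the cache
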
class OpStreamingH5N5SequenceReaderS(Operator):
    """
    Imports a sequence of (ND) volumes inside one hdf5/N5 file into a single volume (ND+1)

    The 'S' at the end of the file name implies that this class handles multiple
    volumes in a single file.

    :param globstring: A glob string as defined by the glob module. We
        also support the following special extension to globstring
        syntax: A single string can hold a *list* of globstrings.
        The delimiter that separates the globstrings in the list is
        OS-specific via os.path.pathsep.

        For example, on Linux the pathsep is ':', so

            '/a/b/c.txt:/d/e/f.txt:../g/i/h.txt'

        is parsed as

            ['/a/b/c.txt', '/d/e/f.txt', '../g/i/h.txt']

    """

    GlobString = InputSlot()
    SequenceAxis = InputSlot(optional=True)  # The axis to stack across.
    OutputImage = OutputSlot()

    class WrongFileTypeError(Exception):
        def __init__(self, globString):
            self.filename = globString
            self.msg = f"File is not a HDF5 or N5: {globString}"
            super().__init__(self.msg)

    class InconsistentShape(Exception):
        def __init__(self, fileName, datasetName):
            self.fileName = fileName
            self.msg = (
                f"Cannot stack dataset: {fileName}/{datasetName} because its shape differs from the shape of "
                f"the previous datasets")
            super().__init__(self.msg)

    class InconsistentDType(Exception):
        def __init__(self, fileName, datasetName):
            self.fileName = fileName
            self.msg = (
                f"Cannot stack dataset: {fileName}/{datasetName} because its data type differs from the "
                f"type of the previous datasets")
            super().__init__(self.msg)

    class NotTheSameFileError(Exception):
        def __init__(self, globString):
            self.globString = globString
            self.msg = f"Glob string encompasses more than one HDF5/N5 file: {globString}"
            super().__init__(self.msg)

    class NoInternalPlaceholderError(Exception):
        def __init__(self, globString):
            self.globString = globString
            self.msg = f"Glob string does not contain a placeholder: {globString}"
            super().__init__(self.msg)

    class ExternalPlaceholderError(Exception):
        def __init__(self, globString):
            self.globString = globString
            self.msg = f"Glob string does contains an external placeholder (not supported!): {globString}"
            super().__init__(self.msg)
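
    # Assumed addition (hedged): setupOutputs() below raises
    # OpStreamingH5N5SequenceReaderS.FileOpenError, but no such class appears
    # in this excerpt. A minimal sketch in the style of the other exceptions:
    class FileOpenError(Exception):
        def __init__(self, fileName):
            self.fileName = fileName
            self.msg = f"Could not open file: {fileName}"
            super().__init__(self.msg)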

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._h5N5File = None
        self._readers = []
        self._opStacker = OpMultiArrayStacker(parent=self)
        self._opStacker.AxisIndex.setValue(0)

    def cleanUp(self):
        self._opStacker.Images.resize(0)
        for opReader in self._readers:
            opReader.cleanUp()
        if self._h5N5File is not None:
            assert isinstance(
                self._h5N5File,
                (h5py.File,
                 z5py.N5File)), "_h5N5File should not be of any other type"
            self._h5N5File.close()

        super().cleanUp()

    def setupOutputs(self):
        pcs = PathComponents(self.GlobString.value.split(os.path.pathsep)[0])
        self._h5N5File = OpStreamingH5N5Reader.get_h5_n5_file(pcs.externalPath,
                                                              mode="r")
        self.checkGlobString(self.GlobString.value)
        file_paths = self.expandGlobStrings(self._h5N5File,
                                            self.GlobString.value)

        num_files = len(file_paths)
        if num_files == 0:
            self.OutputImage.disconnect()
            self.OutputImage.meta.NOTREADY = True
            return

        self.OutputImage.connect(self._opStacker.Output)
        # Get slice axes from first image
        try:
            opFirstImg = OpStreamingH5N5Reader(parent=self)
            opFirstImg.InternalPath.setValue(file_paths[0])
            opFirstImg.H5N5File.setValue(self._h5N5File)
            slice_axes = opFirstImg.OutputImage.meta.getAxisKeys()
            opFirstImg.cleanUp()
        except RuntimeError as e:
            logger.error(str(e))
            raise OpStreamingH5N5SequenceReaderS.FileOpenError(
                file_paths[0]) from e

        # Use given new axis or try to do something sensible
        if self.SequenceAxis.ready():
            new_axis = self.SequenceAxis.value
            assert len(new_axis) == 1
            assert new_axis in "tzyxc"
        else:
            # Try to pick an axis that doesn't already exist in each volume
            for new_axis in "tzc0":
                if new_axis not in slice_axes:
                    break

            if new_axis == "0":
                # All axes used already.
                # Stack across first existing axis
                new_axis = slice_axes[0]

        self._opStacker.Images.resize(0)
        self._opStacker.Images.resize(num_files)
        self._opStacker.AxisFlag.setValue(new_axis)

        for opReader in self._readers:
            opReader.cleanUp()

        self._readers = []
        dtype = None
        shape = None

        for filename, stacker_slot in zip(file_paths, self._opStacker.Images):
            opReader = OpStreamingH5N5Reader(parent=self)
            try:
                # Abort if the image-stack has no consistent dtype or shape
                if dtype is None:
                    dtype = self._h5N5File[filename].dtype
                    shape = self._h5N5File[filename].shape
                else:
                    if dtype != self._h5N5File[filename].dtype:
                        raise OpStreamingH5N5SequenceReaderS.InconsistentDType(
                            pcs.externalPath, filename)
                    if shape != self._h5N5File[filename].shape:
                        raise OpStreamingH5N5SequenceReaderS.InconsistentShape(
                            pcs.externalPath, filename)

                opReader.InternalPath.setValue(filename)
                opReader.H5N5File.setValue(self._h5N5File)
            except RuntimeError as e:
                logger.error(str(e))
                raise OpStreamingH5N5SequenceReaderS.FileOpenError(
                    file_paths[0]) from e
            else:
                stacker_slot.connect(opReader.OutputImage)
                self._readers.append(opReader)

    def propagateDirty(self, slot, subindex, roi):
        if slot == self.GlobString or slot == self.SequenceAxis:
            self.OutputImage.setDirty(slice(None))

    @staticmethod
    def expandGlobStrings(h5N5File, globStrings):
        """Matches a list of globStrings to internal paths of files

        Args:
            h5N5File: h5py or z5py File object, or path(string). If a string is given,
              the file is opened and closed in this method.
            globStrings: string. glob or path strings delimited by os.pathsep

        Returns:
            List of internal paths matching the globstrings that were found in
            the provided h5py.File object
        """
        if not isinstance(h5N5File, (h5py.File, z5py.N5File)):
            with OpStreamingH5N5Reader.get_h5_n5_file(h5N5File, mode="r") as f:
                ret = OpStreamingH5N5SequenceReaderS.expandGlobStrings(
                    f, globStrings)
            return ret

        ret = []
        # Parse list into separate globstrings and combine them
        for globString in globStrings.split(os.path.pathsep):
            s = globString.strip()
            components = PathComponents(s)
            ret += sorted(
                globH5N5(h5N5File, components.internalPath.lstrip("/")))
        return ret

    @staticmethod
    def checkGlobString(globString):
        """Checks whether globString is valid for this class

        Rules for globString:
            * must only contain one distinct external path
            * multiple internal paths, or placeholder '*' must be contained

        Args:
            globString (string): String, one or multiple paths separated with
              os.path.pathsep and possibly containing '*' as a placeholder.

        Raises:
            OpStreamingH5N5SequenceReaderS.ExternalPlaceholderError: External
              placeholders are not supported.
            OpStreamingH5N5SequenceReaderS.NoInternalPlaceholderError: This
              exception is raised if only a single path is provided ->
              OpStreamingH5N5Reader should be used in this case.
            OpStreamingH5N5SequenceReaderS.NotTheSameFileError: if multiple
              hdf5 files are (possibly) referenced in the globstring, this
              Exception is raised -> OpStreamingH5N5SequenceReaderM should
              be used in this case.
            OpStreamingH5N5SequenceReaderS.WrongFileTypeError: If file
              extensions are not among the known H5/N5 extensions, this error
              is raised (see OpStreamingH5N5Reader.H5EXTS and
              OpStreamingH5N5Reader.N5EXTS).
        """
        pathStrings = globString.split(os.path.pathsep)

        pathComponents = [PathComponents(p.strip()) for p in pathStrings]
        assert len(pathComponents) > 0

        if not all(p.extension in OpStreamingH5N5Reader.H5EXTS +
                   OpStreamingH5N5Reader.N5EXTS for p in pathComponents):
            raise OpStreamingH5N5SequenceReaderS.WrongFileTypeError(globString)

        if len(pathComponents) == 1:
            if pathComponents[0].internalPath is None:
                raise OpStreamingH5N5SequenceReaderS.NoInternalPlaceholderError(
                    globString)
            if "*" not in pathComponents[0].internalPath:
                raise OpStreamingH5N5SequenceReaderS.NoInternalPlaceholderError(
                    globString)
            if "*" in pathComponents[0].externalPath:
                raise OpStreamingH5N5SequenceReaderS.ExternalPlaceholderError(
                    globString)
        else:
            sameExternal = all(pathComponents[0].externalPath == x.externalPath
                               for x in pathComponents[1:])
            if not sameExternal:
                raise OpStreamingH5N5SequenceReaderS.NotTheSameFileError(
                    globString)
            if "*" in pathComponents[0].externalPath:
                raise OpStreamingH5N5SequenceReaderS.ExternalPlaceholderError(
                    globString)
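
# Hedged usage sketch for the glob-string helpers above (file and dataset
# names are illustrative; assumes 'stack.h5' holds datasets vol000, vol001, ...):
#
#     globstring = "stack.h5/vol*"    # one external file, many internal datasets
#     OpStreamingH5N5SequenceReaderS.checkGlobString(globstring)  # raises if invalid
#     internal_paths = OpStreamingH5N5SequenceReaderS.expandGlobStrings(
#         "stack.h5", globstring)
#     # -> e.g. ['vol000', 'vol001', ...]
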
class OpValueCache(Operator, ObservableCache):
    """
    This operator caches a value in its entirety, 
    and allows for the value to be "forced in" from an external user.
    No memory management, no blockwise access.
    """
    
    name = "OpValueCache"
    category = "Cache"
    
    Input = InputSlot()
    fixAtCurrent = InputSlot(value=False)
    Output = OutputSlot()
    
    loggerName = __name__ + ".OpValueCache"
    logger = logging.getLogger(loggerName)
    traceLogger = logging.getLogger("TRACE." + loggerName)
    
    def __init__(self, *args, **kwargs):
        super(OpValueCache, self).__init__(*args, **kwargs)
        self._dirty = True
        self._value = None
        self._lock = threading.Lock()
        self._request = None

        # Now that we're initialized, it's safe to register with the memory manager
        self.registerWithMemoryManager()
        
        def handle_unready(slot):
            self._dirty = True
        self.Input.notifyUnready(handle_unready)
        
    def usedMemory(self):
        if isinstance(self._value, numpy.ndarray):
            return self._value.nbytes
        return 0 #FIXME

    def fractionOfUsedMemoryDirty(self):
        if self._dirty:
            return 1.0
        else:
            return 0.0

    def generateReport(self, report):
        super(OpValueCache, self).generateReport(report)
        if self._value is None:
            s = "no value"
        else:
            # (Python 3: avoid parsing "<class '...'>" reprs)
            t = type(self._value).__name__
            s = "value of type '{}'".format(t)
        report.info = s
    
    def setupOutputs(self):
        self.Output.meta.assignFrom(self.Input.meta)
    
    def execute(self, slot, subindex, roi, result):
        if self.fixAtCurrent.value is True or self._dirty is False:
            if result.shape == (1,):
                result[0] = self._value
            else:
                result[:] = self._value
            return result
        
        # Optimization: We don't let more than one caller trigger the value to be computed at the same time
        # If some other caller has already requested the value, we'll just wait for the request he already made.
        class State():
            Dirty = 0
            Waiting = 1
            Clean = 2

        request = None
        value = None
        with self._lock:
            # What state are we in?
            if not self._dirty:
                state = State.Clean
            elif self._request is not None:
                state = State.Waiting
            else:
                state = State.Dirty

            self.traceLogger.debug("State is: {}".format( {State.Dirty : 'Dirty',
                                                           State.Waiting : 'Waiting',
                                                           State.Clean : 'Clean'}[state]) )
            
            # Obtain the request to wait for (create it if necessary)
            if state == State.Dirty:
                request = self.Input[...]
                self._request = request
            elif state == State.Waiting:
                request = self._request
            else:
                value = self._value

        # Now release the lock and block for the request
        if state != State.Clean:
            success = False
            while not success:
                try:
                    if result.shape == (1,):
                        value = request.wait()[0]
                    else:
                        value = request.wait()
                    success = True
                except Request.InvalidRequestException:
                    # Oops, we're sharing the request with another thread 
                    #  and that other thread cancelled it before we got a chance to call wait().
                    # Just regenerate the request and try again...
                    with self._lock:
                        if request == self._request or self._request is None:
                            request = self.Input[...]
                            self._request = request
                        else:
                            request = self._request
                    state = State.Dirty
                except Request.CancellationException:
                    if state == State.Dirty:
                        with self._lock:
                            # If no other request has 'taken responsibility' since we were cancelled
                            # (i.e. self._request is still the request that raised this exception.)
                            if request == self._request:
                                self._request = None # This is mostly to aid testing.
                    raise
        
        if result.shape == (1,):
            result[0] = value
        else:
            result[:] = value
        
        # If we made the request, set the members
        if state == State.Dirty:
            with self._lock:
                self._value = value
                self._request = None
                self._dirty = False
            # Signal only after the members are updated, so listeners
            # never observe a stale value.
            self.Output._sig_value_changed()

        return result

    def propagateDirty(self, slot, subindex, roi):
        if slot is self.Input:
            self._dirty = True
            if not self.fixAtCurrent.value:
                self.Output.setDirty(roi)
        elif slot is self.fixAtCurrent:
            if self.fixAtCurrent.value is False and self._dirty:
                self.Output.setDirty()


    def forceValue(self, value):
        """
        Allows a 'back door' to force data into this cache.
        Note: Use this function carefully.
        """
        with self._lock:
            self._value = value
            self._dirty = False
        self.Output.setDirty()
        
    def resetValue(self):
        """
        Remove the value from the cache.
        """
        with self._lock:
            self._value = None
            self._dirty = True
        self.Output.setDirty()
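
# Hedged usage sketch for OpValueCache (assumes the standard lazyflow Graph
# API; the cached value is illustrative):
#
#     from lazyflow.graph import Graph
#
#     op = OpValueCache(graph=Graph())
#     op.Input.setValue(42)
#     assert op.Output.value == 42   # first read computes and caches
#     op.forceValue(17)              # 'back door': inject a value directly
#     assert op.Output.value == 17
#     op.resetValue()                # next read recomputes from Input
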
class OpSplitRequestsBlockwise(Operator):
    """
    Large requests serviced on the downstream Output will be broken up into smaller requests,
    and requested in parallel from the upstream Input.
    The size of the smaller requests is determined by the BlockShape slot.
    A constructor argument offers an additional feature for exactly how requests are translated into blocks.
    """

    Input = InputSlot(allow_mask=True)
    BlockShape = InputSlot()
    Output = OutputSlot(allow_mask=True)

    def __init__(self, always_request_full_blocks, *args, **kwargs):
        """
        always_request_full_blocks: If True, requests for upstream data will always be the "full" block as specified
                                    by the BlockShape.  The requests will not be truncated to match the user's
                                    requested ROI.  (But the user's requested ROI will be used to extract the data
                                    from the block results.)

                                    This feature allows us to turn an "unblocked" cache into a "blocked" cache.
                                    (If we didn't expand requests to the full blocks they intersect, the upstream
                                    cache blocks would not have uniform size.)
        """
        super(OpSplitRequestsBlockwise, self).__init__(*args, **kwargs)
        self._always_request_full_blocks = always_request_full_blocks

    def setupOutputs(self):
        self.Output.meta.assignFrom(self.Input.meta)
        if len(self.BlockShape.value) != len(self.Input.meta.shape):
            self.Output.meta.NOTREADY = True
            return
        self.Output.meta.ideal_blockshape = tuple(numpy.minimum(self.BlockShape.value, self.Input.meta.shape))

        # Estimate ram usage per requested pixel
        ram_per_pixel = get_ram_per_element(self.Input.meta.dtype)

        # One 'pixel' includes all channels
        tagged_shape = self.Input.meta.getTaggedShape()
        if "c" in tagged_shape:
            ram_per_pixel *= float(tagged_shape["c"])

        if self.Input.meta.ram_usage_per_requested_pixel is not None:
            ram_per_pixel = max(ram_per_pixel, self.Input.meta.ram_usage_per_requested_pixel)

        self.Output.meta.ram_usage_per_requested_pixel = ram_per_pixel

    def execute(self, slot, subindex, roi, result):
        clipped_block_rois = getIntersectingRois(
            self.Input.meta.shape, self.BlockShape.value, (roi.start, roi.stop), True
        )
        if self._always_request_full_blocks:
            full_block_rois = getIntersectingRois(
                self.Input.meta.shape, self.BlockShape.value, (roi.start, roi.stop), False
            )
        else:
            full_block_rois = clipped_block_rois

        pool = RequestPool()
        for full_block_roi, clipped_block_roi in zip(full_block_rois, clipped_block_rois):
            full_block_roi = numpy.asarray(full_block_roi)
            clipped_block_roi = numpy.asarray(clipped_block_roi)

            req = self.Input(*full_block_roi)
            output_roi = clipped_block_roi - roi.start
            if (full_block_roi == clipped_block_roi).all():
                req.writeInto(result[roiToSlice(*output_roi)])
            else:
                roi_within_block = clipped_block_roi - full_block_roi[0]

                def copy_request_result(output_roi, roi_within_block, request_result):
                    self.Output.stype.copy_data(
                        result[roiToSlice(*output_roi)], request_result[roiToSlice(*roi_within_block)]
                    )

                req.notify_finished(partial(copy_request_result, output_roi, roi_within_block))
            pool.add(req)
            del req
        pool.wait()

    def propagateDirty(self, slot, subindex, roi):
        if slot is self.Input:
            self.Output.setDirty(roi.start, roi.stop)
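
# Hedged usage sketch for OpSplitRequestsBlockwise (assumes the standard
# lazyflow Graph API; shapes are illustrative):
#
#     op = OpSplitRequestsBlockwise(always_request_full_blocks=True, graph=Graph())
#     op.Input.connect(upstream_op.Output)
#     op.BlockShape.setValue((1, 128, 128, 128, 1))
#     # One large request is serviced as many parallel block-shaped requests
#     # against op.Input:
#     data = op.Output[0:1, 0:512, 0:512, 0:512, 0:1].wait()
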
class OpUnmanagedCompressedCache(Operator):
    """
    A blockwise cache that stores each block as a separate in-memory hdf5 file with a compressed dataset.

    The files for each block have an internal chunk-shape, which corresponds to
    the amount of data that has to be decompressed for a single pixel lookup.
    The chunk shape is prioritized as follows:
        1. Input.meta.ideal_blockshape
           (make sure to set BlockShape to a multiple of ideal_blockshape!)
        2. BlockShape, if available and smaller than 1MiB (raw)
        3. Automatically determined shape with t=1, c=1 and xyz such that the
           blocks are smaller than 1MiB (raw)

    Note: This class is not managed by the memory manager, so there can be non-managed subclasses.
          The "managed" version is OpCompressedCache, defined below.

    Note:
      * It is not safe to call execute() and change the blockshape
        simultaneously.
      * It is not safe to reuse this cache. #FIXME
    """

    # Also used to asynchronously force data into the cache via __setitem__ (see setInSlot(), below)
    Input = InputSlot(allow_mask=True)

    # shape of internal in-memory hdf5 files (defaults to the whole volume)
    BlockShape = InputSlot(optional=True)

    # Output as numpy arrays
    Output = OutputSlot(allow_mask=True)

    InputHdf5 = InputSlot(optional=True, allow_mask=True)
    # A list of rois (tuples) of the blocks that are currently stored in the cache
    CleanBlocks = OutputSlot()
    # Provides data as hdf5 datasets.  Only allowed for rois that exactly match a block.
    OutputHdf5 = OutputSlot(allow_mask=True)

    def __init__(self, *args, **kwargs):
        super(OpUnmanagedCompressedCache, self).__init__(*args, **kwargs)
        self._lock = RequestLock()
        self._init_cache(None)
        # Used to ensure unique in-memory file names
        self._block_id_counter = itertools.count()
        self._ignore_ideal_blockshape = False

    def _init_cache(self, new_blockshape):
        with self._lock:
            self._blockshape = new_blockshape
            self._cacheFiles = {}
            self._dirtyBlocks = set()
            self._blockLocks = {}
            self._chunkshape = self._chooseChunkshape(self._blockshape)
            self._last_access_times = collections.defaultdict(float)

    def cleanUp(self):
        logger.debug("Cleaning up")
        self._closeAllCacheFiles()
        super(OpUnmanagedCompressedCache, self).cleanUp()

    def setupOutputs(self):
        self.Output.meta.assignFrom(self.Input.meta)
        self.OutputHdf5.meta.assignFrom(self.Input.meta)
        self.CleanBlocks.meta.shape = (1, )
        self.CleanBlocks.meta.dtype = object

        # no block shape given -> use the whole volume as one block
        new_blockshape = self.Input.meta.shape
        if self.BlockShape.ready():
            new_blockshape = self.BlockShape.value

        if len(new_blockshape) != len(self.Input.meta.shape):
            self.Output.meta.NOTREADY = True
            self.CleanBlocks.meta.NOTREADY = True
            self.OutputHdf5.meta.NOTREADY = True
            self._init_cache(None)
            return

        # Clip blockshape to image bounds
        new_blockshape = tuple(
            numpy.minimum(new_blockshape, self.Input.meta.shape))

        if new_blockshape != self._blockshape:
            # If the blockshape changes, we have to reset the entire cache.
            self._init_cache(new_blockshape)

        self.Output.meta.ideal_blockshape = new_blockshape

    def execute(self, slot, subindex, roi, destination):
        if slot == self.Output:
            return self._executeOutput(roi, destination)
        elif slot == self.CleanBlocks:
            return self._executeCleanBlocks(destination)
        elif slot == self.OutputHdf5:
            return self._executeOutputHdf5(roi, destination)
        else:
            assert False, "Unknown output slot: {}".format(slot.name)

    def _executeOutput(self, roi, destination):
        assert len(roi.stop) == len(
            self.Input.meta.shape
        ), "roi: {} has the wrong number of dimensions for Input shape: {}".format(
            roi, self.Input.meta.shape)
        assert numpy.less_equal(roi.stop, self.Input.meta.shape).all(
        ), "roi: {} is out-of-bounds for Input shape: {}".format(
            roi, self.Input.meta.shape)

        block_starts = getIntersectingBlocks(self._blockshape,
                                             (roi.start, roi.stop))
        block_starts = list(map(tuple, block_starts))

        # Ensure all block cache files are up-to-date
        self._waitForBlocks(block_starts)
        self._copyData(roi, destination, block_starts)
        return destination

    def _waitForBlocks(self, block_starts):
        """
        Make sure that all blocks in the given list of blocks are present in the cache before returning.
        (Blocks that are not yet present will be requested from our Input slot.)
        """
        reqPool = RequestPool()  # (Do the work in parallel.)
        for block_start in block_starts:
            entire_block_roi = getBlockBounds(self.Output.meta.shape,
                                              self._blockshape, block_start)
            f = partial(self._ensureCached, entire_block_roi)
            reqPool.add(Request(f))
        logger.debug("Waiting for {} blocks...".format(len(block_starts)))
        reqPool.wait()

    def _copyData(self, roi, destination, block_starts):
        # Copy data from each block
        # (Parallelism not needed here: h5py will serialize these requests anyway)
        logger.debug("Copying data from {} blocks...".format(
            len(block_starts)))
        for block_start in block_starts:
            entire_block_roi = getBlockBounds(self.Output.meta.shape,
                                              self._blockshape, block_start)

            # This block's portion of the roi
            intersecting_roi = getIntersection((roi.start, roi.stop),
                                               entire_block_roi)

            # Compute slicing within destination array and slicing within this block
            destination_relative_intersection = numpy.subtract(
                intersecting_roi, roi.start)
            block_relative_intersection = numpy.subtract(
                intersecting_roi, block_start)
            destination_relative_intersection_slicing = roiToSlice(
                *destination_relative_intersection)
            block_relative_intersection_slicing = roiToSlice(
                *block_relative_intersection)

            # Copy from block to destination
            dataset = self._getBlockDataset(entire_block_roi)
            if self.Output.meta.has_mask:
                destination.data[
                    destination_relative_intersection_slicing] = dataset[
                        "data"][block_relative_intersection_slicing]
                destination.mask[
                    destination_relative_intersection_slicing] = dataset[
                        "mask"][block_relative_intersection_slicing]
                destination.fill_value = dataset["fill_value"][()]
            else:
                destination[
                    destination_relative_intersection_slicing] = dataset[
                        block_relative_intersection_slicing]
            self._last_access_times[block_start] = time.time()

    def _executeCleanBlocks(self, destination):
        """
        Execute function for the CleanBlocks output slot, which produces
        an *unsorted* list of block rois that the cache currently holds.
        """
        # Set difference: clean = existing - dirty
        clean_block_starts = set(self._cacheFiles.keys()) - self._dirtyBlocks

        output_shape = self.Output.meta.shape
        clean_block_rois = list(
            map(partial(getBlockBounds, output_shape, self._blockshape),
                clean_block_starts))
        results = []
        for cbr in clean_block_rois:
            results.append([TinyVector(cbr[0]), TinyVector(cbr[1])])
        destination[0] = results
        return destination

    def _executeOutputHdf5(self, roi, destination):
        logger.debug("Servicing request for hdf5 block {}".format(roi))

        assert isinstance(
            destination, h5py.Group
        ), "OutputHdf5 slot requires an hdf5 GROUP to copy into (not a numpy array)."
        assert ((roi.start % self._blockshape) == 0).all(
        ), "OutputHdf5 slot requires roi to be exactly one block."
        block_roi = getBlockBounds(self.Output.meta.shape, self._blockshape,
                                   roi.start)
        assert (block_roi == numpy.array(
            (roi.start, roi.stop
             ))).all(), "OutputHdf5 slot requires roi to be exactly one block."

        block_roi = [roi.start, roi.stop]
        self._ensureCached(block_roi)
        dataset = self._getBlockDataset(block_roi)
        assert str(
            block_roi
        ) not in destination, "destination hdf5 group already has a dataset with this block's name"
        destination.copy(dataset, str(block_roi))
        return destination

    def propagateDirty(self, slot, subindex, roi):
        if slot == self.Input:
            # Keep track of dirty blocks
            if self._blockshape is not None:
                with self._lock:
                    block_starts = getIntersectingBlocks(
                        self._blockshape, (roi.start, roi.stop))
                    block_starts = list(map(tuple, block_starts))

                    for block_start in block_starts:
                        self._dirtyBlocks.add(block_start)
            # Forward to downstream connections
            self.Output.setDirty(roi)
        elif slot == self.BlockShape:
            # Everything is dirty
            self.Output.setDirty(slice(None))
        else:
            assert False, "Unknown output slot"

    def _chooseChunkshape(self, blockshape):
        """
        Choose an optimal chunkshape for our blockshape and Input shape.
        We assume access patterns to vary more in space than in time or
        channel, so we choose chunks that each span a single t and c slice
        and hold roughly 1MiB of raw data.
        Furthermore, we use the function
          lazyflow.utility.chunkHelpers.chooseChunkShape()
        to preserve the aspect ratio of the input (at least approximately).
        """
        if blockshape is None:
            return None

        def isConsistent(idealshape):
            """
            check if ideal block shape and given block shape are consistent

            shapes are consistent if, for each dimension,
                * input is unready, or
                * blockshape equals fullshape, or
                * idealshape divides blockshape evenly
            """
            if not self.Input.ready():
                return True

            fullshape = self.Input.meta.shape
            m = [
                blk == full or blk % ideal == 0
                for ideal, blk, full in zip(idealshape, blockshape, fullshape)
            ]
            return all(m)

        if not self._ignore_ideal_blockshape and self.Input.ready():
            # take the ideal chunk shape, but check if sane
            ideal = self.Input.meta.ideal_blockshape
            if ideal is not None:
                if len(ideal) == len(blockshape):
                    ideal = numpy.asarray(ideal, dtype=int)  # numpy.int is removed in modern numpy
                    for i, d in enumerate(ideal):
                        if d == 0:
                            ideal[i] = blockshape[i]
                    if not isConsistent(ideal):
                        logger.warning(
                            "{}: BlockShape and ideal_blockshape are "
                            "inconsistent {} vs {}".format(
                                self.name, blockshape, ideal))
                    else:
                        return tuple(ideal)
                else:
                    logger.warning(
                        "{}: Encountered meta.ideal_blockshape that does not fit the data"
                        .format(self.name))

        # we need to figure out an ideal chunk shape on our own

        # Start with a copy of blockshape
        axes = list(self.Output.meta.getTaggedShape().keys())
        taggedBlockShape = collections.OrderedDict(
            list(zip(axes, self._blockshape)))

        dtypeBytes = self._getDtypeBytes(self.Output.meta.dtype)

        desiredSpace = 1024**2 / float(dtypeBytes)

        if bigintprod(blockshape) <= desiredSpace:
            return blockshape

        # set t and c to 1
        for key in "tc":
            if key in taggedBlockShape:
                taggedBlockShape[key] = 1
        logger.debug("desired space: {}".format(desiredSpace))

        # extract only the spatial shape
        spatialKeys = [k for k in list(taggedBlockShape.keys()) if k in "xyz"]
        spatialShape = [taggedBlockShape[k] for k in spatialKeys]
        newSpatialShape = chooseChunkShape(spatialShape, desiredSpace)
        for k, v in zip(spatialKeys, newSpatialShape):
            taggedBlockShape[k] = v
        chunkShape = tuple(taggedBlockShape.values())
        logger.debug("Using chunk shape: {}".format(chunkShape))
        return chunkShape
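
    # Worked example for the fallback path above (hedged; the numbers are
    # illustrative): for uint8 data (1 byte/element), desiredSpace is
    # 1 MiB = 1048576 elements. A blockshape of (t=1, x=256, y=256, z=256, c=1)
    # holds 16 MiB, so t and c are pinned to 1 and chooseChunkShape() shrinks
    # the spatial extents by ~16x while approximately preserving the 1:1:1
    # aspect ratio, e.g. to about (1, 101, 101, 101, 1) ~= 1 MiB per chunk.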

    def _getDtypeBytes(self, dtype):
        if isinstance(dtype, numpy.dtype):
            # Make sure we're dealing with a type (e.g. numpy.float64),
            #  not a numpy.dtype
            dtype = dtype.type
        return dtype().nbytes

    def usedMemory(self):
        tot, unc = self._usedMemory()
        self._compression_factor = 1.0
        if tot > 0:
            self._compression_factor = unc / float(tot)
        return tot

    def _usedMemory(self):
        tot = 0.0
        unc = 0.0
        for key in list(self._cacheFiles.keys()):
            real, virt = self._memoryForBlock(key)
            tot += real
            unc += virt
        return tot, unc

    def _memoryForBlock(self, key):
        try:
            group = self._cacheFiles[key]
        except KeyError:
            # Entry was removed; report zero usage for both the actual
            # and the uncompressed size (callers unpack two values).
            return 0, 0
        tot = 0
        unc = 0
        if "data" in group:
            ds = group["data"]
            # actual size
            tot += get_storage_size(ds)
            # uncompressed size
            unc += ds.size * self._getDtypeBytes(ds.dtype)
        if "mask" in group:
            tot += group["mask"].size * self._getDtypeBytes(
                group["mask"].dtype)
        if "fill_value" in group:
            tot += group["fill_value"].size * self._getDtypeBytes(
                group["fill_value"].dtype)
        return tot, unc

    def _getCacheFile(self, entire_block_roi):
        """
        Get the cache file for the block that starts at block_start.
        If it doesn't exist yet, create it first.
        """
        block_start = tuple(entire_block_roi[0])
        if block_start in self._cacheFiles:
            return self._cacheFiles[block_start]
        with self._lock:
            if block_start not in self._cacheFiles:
                # Create an in-memory hdf5 file with a unique name
                # (the counter ensures that even blocks that have been deleted previously get a unique name when they are re-created).
                logger.debug("Creating a cache file for block: {}".format(
                    list(block_start)))
                filename = (str(id(self)) + str(id(self._cacheFiles)) +
                            str(block_start) +
                            str(next(self._block_id_counter)))
                mem_file = h5py.File(filename,
                                     driver="core",
                                     backing_store=False,
                                     mode="w")

                # h5py will crash if the chunkshape is larger than the dataset shape.
                datashape = tuple(entire_block_roi[1] - entire_block_roi[0])
                chunkshape = numpy.minimum(numpy.array(datashape),
                                           self._chunkshape)
                chunkshape = tuple(chunkshape)

                # Make a compressed dataset
                mem_file.create_dataset(
                    "data",
                    shape=datashape,
                    dtype=self.Output.meta.dtype,
                    chunks=chunkshape,
                    compression="lzf")  # lzf should be faster than gzip,
                # with a slightly worse compression ratio
                # Add mask information if needed.
                if self.Output.meta.has_mask:
                    mem_file.create_dataset(
                        "mask",
                        shape=datashape,
                        dtype=bool,
                        chunks=chunkshape,
                        compression="lzf")  # lzf should be faster than gzip,
                    # with a slightly worse compression ratio
                    mem_file.create_dataset("fill_value",
                                            shape=tuple(),
                                            dtype=self.Output.meta.dtype)

                self._blockLocks[block_start] = RequestLock()
                self._cacheFiles[block_start] = mem_file
                self._dirtyBlocks.add(block_start)
            return self._cacheFiles[block_start]

    def _ensureCached(self, entire_block_roi):
        """
        Ensure that the cache file for the given block is up-to-date.
        (Refresh it if it's dirty.)
        """
        block_start = tuple(entire_block_roi[0])
        block_file = self._getCacheFile(entire_block_roi)
        if block_start in self._dirtyBlocks:
            updated_cache = False
            with self._blockLocks[block_start]:
                # Check AGAIN now that we have the lock.
                # (Avoid doing this twice in parallel requests.)
                if block_start in self._dirtyBlocks:
                    # Can't write directly into the hdf5 dataset because
                    #  h5py.dataset.__getitem__ creates a copy, not a view.
                    # We must use a temporary numpy array to hold the data.
                    data = self.Input(*entire_block_roi).wait()
                    block_file["data"][...] = data
                    if self.Output.meta.has_mask:
                        block_file["mask"][...] = data.mask
                        block_file["fill_value"][...] = data.fill_value

                    if logger.isEnabledFor(logging.DEBUG):
                        uncompressed_size = bigintprod(
                            data.shape) * self._getDtypeBytes(data.dtype)
                        storage_size = block_file["data"].id.get_storage_size()
                        if "mask" in block_file:
                            storage_size += block_file[
                                "mask"].id.get_storage_size()
                        if "fill_value" in block_file:
                            storage_size += block_file[
                                "fill_value"].id.get_storage_size()
                        logger.debug(
                            "Storage for block: {} is {}. ({}% of original)".
                            format(block_start, storage_size,
                                   100 * storage_size / uncompressed_size))
                    with self._lock:
                        self._dirtyBlocks.remove(block_start)
                    updated_cache = True

            if updated_cache:
                # Now that the lock is released, signal that the cache was updated.
                self.Output._sig_value_changed()
                self.OutputHdf5._sig_value_changed()
                self.CleanBlocks._sig_value_changed()

    def setInSlot(self, slot, subindex, roi, value):
        """
        Overridden from Operator
        """
        if slot == self.Input:
            self._setInSlotInput(slot, subindex, roi, value)
        elif slot == self.InputHdf5:
            self._setInSlotInputHdf5(slot, subindex, roi, value)
        else:
            assert False, "Invalid input slot for setInSlot(): {}".format(
                slot.name)

    def _setInSlotInput(self,
                        slot,
                        subindex,
                        roi,
                        value,
                        store_zero_blocks=True):
        """
        Write the data in the array 'value' into the cache.
        If the optional store_zero_blocks param is False, then don't bother
        creating cache blocks for blocks that are totally zero.
        """
        assert len(roi.stop) == len(
            self.Input.meta.shape
        ), "roi: {} has the wrong number of dimensions for Input shape: {}".format(
            roi, self.Input.meta.shape)
        assert numpy.less_equal(roi.stop, self.Input.meta.shape).all(
        ), "roi: {} is out-of-bounds for Input shape: {}".format(
            roi, self.Input.meta.shape)

        block_starts = getIntersectingBlocks(self._blockshape,
                                             (roi.start, roi.stop))
        block_starts = list(map(tuple, block_starts))

        # Copy data to each block
        logger.debug("Copying data INTO {} blocks...".format(
            len(block_starts)))
        for block_start in block_starts:
            entire_block_roi = getBlockBounds(self.Output.meta.shape,
                                              self._blockshape, block_start)

            # This block's portion of the roi
            intersecting_roi = getIntersection((roi.start, roi.stop),
                                               entire_block_roi)

            # Compute slicing within source array and slicing within this block
            source_relative_intersection = numpy.subtract(
                intersecting_roi, roi.start)
            block_relative_intersection = numpy.subtract(
                intersecting_roi, block_start)
            source_relative_intersection_slicing = roiToSlice(
                *source_relative_intersection)
            block_relative_intersection_slicing = roiToSlice(
                *block_relative_intersection)

            new_block_data = value[source_relative_intersection_slicing]
            new_block_sum = new_block_data.sum()
            if not store_zero_blocks and new_block_sum == 0 and block_start not in self._cacheFiles:
                # Special fast-path: If this block doesn't exist yet,
                #  don't bother creating if we're just going to fill it with zeros.
                # (This feature is used by the OpCompressedUserLabelArray)
                pass
            else:
                # Copy from source to block
                dataset = self._getBlockDataset(entire_block_roi)
                if self.Output.meta.has_mask:
                    dataset["data"][
                        block_relative_intersection_slicing] = new_block_data.data
                    dataset["mask"][
                        block_relative_intersection_slicing] = new_block_data.mask
                    dataset["fill_value"][()] = new_block_data.fill_value

                    # Untested. Write a test to use this.
                    # # If we can, remove this block entirely.
                    # if not store_zero_blocks and new_block_sum == 0 and (dataset["data"][:] == 0).all() and (dataset["mask"]).any() and (dataset["fill_value"] == 0).all():
                    #     with self._lock:
                    #         with self._blockLocks[block_start]:
                    #            self._cacheFiles[block_start].close()
                    #            del self._cacheFiles[block_start]
                    #         del self._blockLocks[block_start]
                else:
                    dataset[
                        block_relative_intersection_slicing] = new_block_data

                    # If we can, remove this block entirely.
                    if not store_zero_blocks and new_block_sum == 0 and (
                            dataset[:] == 0).all():
                        with self._lock:
                            with self._blockLocks[block_start]:
                                self._cacheFiles[block_start].close()
                                del self._cacheFiles[block_start]
                            del self._blockLocks[block_start]

            # Here, we assume that if this function is used to update ANY PART of a
            #  block, the caller is responsible for updating the ENTIRE block.
            # Therefore, this block is no longer 'dirty'.
            self._dirtyBlocks.discard(block_start)

    #            self.Output._sig_value_changed()
    #            self.OutputHdf5._sig_value_changed()
    #            self.CleanBlocks._sig_value_changed()

    def _setInSlotInputHdf5(self, slot, subindex, roi, value):
        logger.debug("Setting block {} from hdf5".format(roi))
        if self.Output.meta.has_mask:
            assert isinstance(
                value, h5py.Group
            ), "InputHdf5 slot requires an hdf5 Group to copy from (not a numpy masked array)."
        else:
            assert isinstance(
                value, h5py.Dataset
            ), "InputHdf5 slot requires an hdf5 Dataset to copy from (not a numpy array)."

        block_roi = getBlockBounds(self.Output.meta.shape, self._blockshape,
                                   roi.start)

        roi_is_exactly_one_block = True
        roi_is_exactly_one_block &= ((roi.start % self._blockshape) == 0).all()
        roi_is_exactly_one_block &= (block_roi == numpy.array(
            (roi.start, roi.stop))).all()
        if roi_is_exactly_one_block:
            cachefile = self._getCacheFile(block_roi)
            logger.debug(
                "Copying HDF5 data directly into block {}".format(block_roi))

            if self.Output.meta.has_mask:
                assert len(value) == 3

                for each in ["data", "mask", "fill_value"]:
                    assert each in value
                    assert cachefile[each].dtype == value[each].dtype
                    assert cachefile[each].shape == value[each].shape

                for each in ["data", "mask", "fill_value"]:
                    del cachefile[each]
                    cachefile.copy(value[each], each)
            else:
                assert cachefile["data"].dtype == value.dtype
                assert cachefile["data"].shape == value.shape
                del cachefile["data"]
                cachefile.copy(value, "data")

            block_start = tuple(roi.start)
            self._dirtyBlocks.discard(block_start)
        else:
            # This hdf5 data does not correspond to exactly one block.
            # We must uncompress it and write it the "normal" way (the slow way)
            # FIXME: This would use less memory if we uncompressed the data block-by-block
            data = None

            if self.Output.meta.has_mask:
                data = numpy.ma.masked_array(
                    value["data"][()],
                    mask=value["mask"][()],
                    fill_value=value["fill_value"][()],
                    shrink=False)
            else:
                data = value[()]

            self.Input[roiToSlice(roi.start, roi.stop)] = data

    #        self.Output._sig_value_changed()
    #        self.OutputHdf5._sig_value_changed()
    #        self.CleanBlocks._sig_value_changed()

    def _getBlockDataset(self, entire_block_roi):
        """
        Get the correct cache file and return the *dataset* handle,
        not a numpy array of its contents.
        """
        block_file = self._getCacheFile(entire_block_roi)
        if self.Output.meta.has_mask:
            return block_file["/"]
        else:
            return block_file["data"]

    def _closeAllCacheFiles(self):
        logger.debug("Closing all caches")
        cacheFiles = self._cacheFiles
        for k, v in list(cacheFiles.items()):
            with self._blockLocks[k]:
                v.close()
        with self._lock:
            self._blockLocks = {}
            self._cacheFiles = {}
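
# --- Illustrative sketch (not part of the original operator) ---
# A minimal, self-contained model of the roi/block arithmetic used by the
# block-copy loop in setInSlot above. The helpers below mirror the roles of
# lazyflow's getIntersectingBlocks / getBlockBounds / getIntersection /
# roiToSlice, but they are assumptions for illustration, not the library's
# actual implementations.
import itertools
import numpy

def intersecting_block_starts(blockshape, roi):
    start, stop = numpy.array(roi[0]), numpy.array(roi[1])
    first = (start // blockshape) * blockshape  # round start down to a block boundary
    ranges = [range(f, s, b) for f, s, b in zip(first, stop, blockshape)]
    return [tuple(p) for p in itertools.product(*ranges)]

def block_bounds(full_shape, blockshape, block_start):
    start = numpy.array(block_start)
    return start, numpy.minimum(start + blockshape, full_shape)

def roi_to_slice(start, stop):
    return tuple(slice(a, b) for a, b in zip(start, stop))

# Copy a 5x5 source patch at offset (3,3) into a 10x10 array tiled in 4x4 blocks.
full = numpy.zeros((10, 10))
src = numpy.ones((5, 5))
roi = (numpy.array([3, 3]), numpy.array([8, 8]))
blockshape = numpy.array([4, 4])
for bs in intersecting_block_starts(blockshape, roi):
    b0, b1 = block_bounds(full.shape, blockshape, bs)
    i0, i1 = numpy.maximum(roi[0], b0), numpy.minimum(roi[1], b1)  # intersection
    # Write the source's intersecting sub-rectangle into the matching region,
    # using the source-relative slicing exactly as setInSlot does.
    full[roi_to_slice(i0, i1)] = src[roi_to_slice(i0 - roi[0], i1 - roi[0])]
assert full[3:8, 3:8].all() and full.sum() == 25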
Example #6
0
class OpDataSelection(Operator):
    """
    The top-level operator for the data selection applet, implemented as a single-image operator.
    The applet uses an OperatorWrapper to make it suitable for use in a workflow.
    """
    name = "OpDataSelection"
    category = "Top-level"
    
    SupportedExtensions = OpInputDataReader.SupportedExtensions

    # Inputs
    RoleName = InputSlot(stype='string', value='')
    ProjectFile = InputSlot(stype='object', optional=True) #: The project hdf5 File object (already opened)
    ProjectDataGroup = InputSlot(stype='string', optional=True) #: The internal path to the hdf5 group where project-local datasets are stored within the project file
    WorkingDirectory = InputSlot(stype='filestring') #: The filesystem directory where the project file is located
    Dataset = InputSlot(stype='object') #: A DatasetInfo object

    # Outputs
    Image = OutputSlot() #: The output image
    AllowLabels = OutputSlot(stype='bool') #: A bool indicating whether or not this image can be used for training

    _NonTransposedImage = OutputSlot() #: The output slot, in the data's original axis ordering (regardless of forceAxisOrder)

    ImageName = OutputSlot(stype='string') #: The name of the output image
    
    class InvalidDimensionalityError(Exception):
        """Raised if the user tries to replace the dataset with a new one of differing dimensionality."""
        def __init__(self, message ):
            super( OpDataSelection.InvalidDimensionalityError, self ).__init__()
            self.message = message
        
        def __str__(self):
            return self.message

    def __init__(self, forceAxisOrder=False, *args, **kwargs):
        super(OpDataSelection, self).__init__(*args, **kwargs)
        self.forceAxisOrder = forceAxisOrder
        self._opReaders = []

        # If the gui calls disconnect() on an input slot without replacing it with something else,
        #  we still need to clean up the internal operator that was providing our data.
        self.ProjectFile.notifyUnready(self.internalCleanup)
        self.ProjectDataGroup.notifyUnready(self.internalCleanup)
        self.WorkingDirectory.notifyUnready(self.internalCleanup)
        self.Dataset.notifyUnready(self.internalCleanup)

    def internalCleanup(self, *args):
        if len(self._opReaders) > 0:
            self.Image.disconnect()
            self._NonTransposedImage.disconnect()
            for reader in reversed(self._opReaders):
                reader.cleanUp()
            self._opReaders = []
    
    def setupOutputs(self):
        self.internalCleanup()
        datasetInfo = self.Dataset.value

        try:
            # Data only comes from the project file if the user said so AND it exists in the project
            datasetInProject = (datasetInfo.location == DatasetInfo.Location.ProjectInternal)
            datasetInProject &= self.ProjectFile.ready()
            if datasetInProject:
                internalPath = self.ProjectDataGroup.value + '/' + datasetInfo.datasetId
                datasetInProject &= internalPath in self.ProjectFile.value
    
            # If we should find the data in the project file, use a dataset reader
            if datasetInProject:
                opReader = OpStreamingHdf5Reader(parent=self)
                opReader.Hdf5File.setValue(self.ProjectFile.value)
                opReader.InternalPath.setValue(internalPath)
                providerSlot = opReader.OutputImage
            elif datasetInfo.location == DatasetInfo.Location.PreloadedArray:
                preloaded_array = datasetInfo.preloaded_array
                assert preloaded_array is not None
                if not hasattr(preloaded_array, 'axistags'):
                    # Guess the axis order, since one was not provided.
                    axisorders = { 2 : 'yx',
                                   3 : 'zyx',
                                   4 : 'zyxc',
                                   5 : 'tzyxc' }

                    shape = preloaded_array.shape
                    ndim = preloaded_array.ndim            
                    assert ndim != 0, "0-D data is not yet supported"
                    assert ndim != 1, "1-D data is not yet supported"
                    assert ndim <= 5, "No support for data with more than 5 dimensions."
        
                    axisorder = axisorders[ndim]
                    if ndim == 3 and shape[2] <= 4:
                        # Special case: If the 3rd dim is small, assume it's 'c', not 'z'
                        axisorder = 'yxc'
                    preloaded_array = vigra.taggedView(preloaded_array, axisorder)
                opReader = OpArrayPiper(parent=self)
                opReader.Input.setValue( preloaded_array )
                providerSlot = opReader.Output
            else:
                # Use a normal (filesystem) reader
                opReader = OpInputDataReader(parent=self)
                if datasetInfo.subvolume_roi is not None:
                    opReader.SubVolumeRoi.setValue( datasetInfo.subvolume_roi )
                opReader.WorkingDirectory.setValue( self.WorkingDirectory.value )
                opReader.FilePath.setValue(datasetInfo.filePath)
                providerSlot = opReader.Output
            self._opReaders.append(opReader)
            
            # Inject metadata if the dataset info specified any.
            # Also, inject if dtype is uint8, which we can reasonably assume has drange (0,255)
            metadata = {}
            metadata['display_mode'] = datasetInfo.display_mode
            role_name = self.RoleName.value
            if 'c' not in providerSlot.meta.getTaggedShape():
                num_channels = 0
            else:
                num_channels = providerSlot.meta.getTaggedShape()['c']
            if num_channels > 1:
                metadata['channel_names'] = ["{}-{}".format(role_name, i) for i in range(num_channels)]
            else:
                metadata['channel_names'] = [role_name]
                 
            if datasetInfo.drange is not None:
                metadata['drange'] = datasetInfo.drange
            elif providerSlot.meta.dtype == numpy.uint8:
                # SPECIAL case for uint8 data: Provide a default drange.
                # The user can always override this herself if she wants.
                metadata['drange'] = (0,255)
            if datasetInfo.normalizeDisplay is not None:
                metadata['normalizeDisplay'] = datasetInfo.normalizeDisplay
            if datasetInfo.axistags is not None:
                if len(datasetInfo.axistags) != len(providerSlot.meta.shape):
                    # This usually only happens when we copied a DatasetInfo from another lane,
                    # and used it as a 'template' to initialize this lane.
                    # This happens in the BatchProcessingApplet when it attempts to guess the axistags of 
                    # batch images based on the axistags chosen by the user in the interactive images.
                    # If the interactive image tags don't make sense for the batch image, you get this error.
                    raise Exception( "Your dataset's provided axistags ({}) do not have the "
                                     "correct dimensionality for your dataset, which has {} dimensions."
                                     .format( "".join(tag.key for tag in datasetInfo.axistags), len(providerSlot.meta.shape) ) )
                metadata['axistags'] = datasetInfo.axistags
            if datasetInfo.subvolume_roi is not None:
                metadata['subvolume_roi'] = datasetInfo.subvolume_roi
                
                # FIXME: We are overwriting the axistags metadata to intentionally allow 
                #        the user to change our interpretation of which axis is which.
                #        That's okay, but technically there's a special corner case if 
                #        the user redefines the channel axis index.  
                #        Technically, it invalidates the meaning of meta.ram_usage_per_requested_pixel.
                #        For most use-cases, that won't really matter, which is why I'm not worrying about it right now.
            
            opMetadataInjector = OpMetadataInjector( parent=self )
            opMetadataInjector.Input.connect( providerSlot )
            opMetadataInjector.Metadata.setValue( metadata )
            providerSlot = opMetadataInjector.Output
            self._opReaders.append( opMetadataInjector )

            self._NonTransposedImage.connect(providerSlot)
            
            if self.forceAxisOrder:
                # Before we re-order, make sure no non-singleton 
                #  axes would be dropped by the forced order.
                output_order = "".join(self.forceAxisOrder)
                provider_order = "".join(providerSlot.meta.getAxisKeys())
                tagged_provider_shape = providerSlot.meta.getTaggedShape()
                dropped_axes = set(provider_order) - set(output_order)
                if any(tagged_provider_shape[a] > 1 for a in dropped_axes):
                    msg = "The axes of your dataset ({}) are not compatible with the axes used by this workflow ({}). Please fix them."\
                          .format(provider_order, output_order)
                    raise DatasetConstraintError("DataSelection", msg)

                op5 = OpReorderAxes(parent=self)
                op5.AxisOrder.setValue(self.forceAxisOrder)
                op5.Input.connect(providerSlot)
                providerSlot = op5.Output
                self._opReaders.append(op5)
            
            # If the channel axis is not last (or is missing),
            #  make sure the axes are re-ordered so that channel is last.
            if providerSlot.meta.axistags.index('c') != len( providerSlot.meta.axistags )-1:
                op5 = OpReorderAxes( parent=self )
                keys = list(providerSlot.meta.getTaggedShape().keys())
                try:
                    # Remove if present.
                    keys.remove('c')
                except ValueError:
                    pass
                # Append
                keys.append('c')
                op5.AxisOrder.setValue( "".join( keys ) )
                op5.Input.connect( providerSlot )
                providerSlot = op5.Output
                self._opReaders.append( op5 )
            
            # Connect our external outputs to the internal operators we chose
            self.Image.connect(providerSlot)
            
            # Set the image name and usage flag
            self.AllowLabels.setValue( datasetInfo.allowLabels )
            
            # If the reading operator provides a nickname, use it.
            if self.Image.meta.nickname is not None:
                datasetInfo.nickname = self.Image.meta.nickname
            
            imageName = datasetInfo.nickname
            if imageName == "":
                imageName = datasetInfo.filePath
            self.ImageName.setValue(imageName)
        
        except:
            self.internalCleanup()
            raise

    def propagateDirty(self, slot, subindex, roi):
        # Output slots are directly connected to internal operators
        pass

    @classmethod
    def getInternalDatasets(cls, filePath):
        return OpInputDataReader.getInternalDatasets( filePath )
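
# --- Illustrative sketch (not part of the original operator) ---
# A standalone version of the axis-order guess used above for preloaded arrays
# that carry no axistags. The lookup table and the "small 3rd dim means
# channel" special case mirror the inline code in setupOutputs; the function
# name is an assumption for illustration.
def guess_axis_order(shape):
    axisorders = {2: 'yx', 3: 'zyx', 4: 'zyxc', 5: 'tzyxc'}
    ndim = len(shape)
    assert 2 <= ndim <= 5, "Only 2-D to 5-D data is supported."
    if ndim == 3 and shape[2] <= 4:
        # Special case: if the 3rd dim is small, assume it's 'c', not 'z'.
        return 'yxc'
    return axisorders[ndim]

assert guess_axis_order((100, 100)) == 'yx'
assert guess_axis_order((100, 100, 3)) == 'yxc'   # RGB image, not a thin volume
assert guess_axis_order((64, 100, 100)) == 'zyx'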
Example #7
0
class OpDataSelection(Operator):
    """
    The top-level operator for the data selection applet, implemented as a single-image operator.
    The applet uses an OperatorWrapper to make it suitable for use in a workflow.
    """
    name = "OpDataSelection"
    category = "Top-level"

    SupportedExtensions = OpInputDataReader.SupportedExtensions

    # Inputs
    RoleName = InputSlot(stype='string', value='')
    ProjectFile = InputSlot(
        stype='object',
        optional=True)  # : The project hdf5 File object (already opened)
    # : The internal path to the hdf5 group where project-local datasets are stored within the project file
    ProjectDataGroup = InputSlot(stype='string', optional=True)
    WorkingDirectory = InputSlot(
        stype='filestring'
    )  # : The filesystem directory where the project file is located
    Dataset = InputSlot(stype='object')  # : A DatasetInfo object

    # Outputs
    Image = OutputSlot()  # : The output image
    AllowLabels = OutputSlot(
        stype='bool'
    )  # : A bool indicating whether or not this image can be used for training

    # : The output slot, in the data's original axis ordering (regardless of forceAxisOrder)
    _NonTransposedImage = OutputSlot()

    ImageName = OutputSlot(stype='string')  # : The name of the output image

    class InvalidDimensionalityError(Exception):
        """Raised if the user tries to replace the dataset with a new one of differing dimensionality."""
        def __init__(self, message):
            super(OpDataSelection.InvalidDimensionalityError, self).__init__()
            self.message = message

        def __str__(self):
            return self.message

    def __init__(self, forceAxisOrder=['tczyx'], *args, **kwargs):
        """
        forceAxisOrder: How to auto-reorder the input data before connecting it to the rest of the workflow.
                        Should be a list of input orders that are allowed by the workflow
                        For example, if the workflow can handle 2D and 3D, you might pass ['yxc', 'zyxc'].
                        If it only handles exactly 5D, you might pass 'tzyxc', assuming that's how you wrote the
                        workflow.
                        todo: move toward 'tczyx' standard.
        """
        super(OpDataSelection, self).__init__(*args, **kwargs)
        self.forceAxisOrder = forceAxisOrder
        self._opReaders = []

        # If the gui calls disconnect() on an input slot without replacing it with something else,
        #  we still need to clean up the internal operator that was providing our data.
        self.ProjectFile.notifyUnready(self.internalCleanup)
        self.ProjectDataGroup.notifyUnready(self.internalCleanup)
        self.WorkingDirectory.notifyUnready(self.internalCleanup)
        self.Dataset.notifyUnready(self.internalCleanup)

    def internalCleanup(self, *args):
        if len(self._opReaders) > 0:
            self.Image.disconnect()
            self._NonTransposedImage.disconnect()
            for reader in reversed(self._opReaders):
                reader.cleanUp()
            self._opReaders = []

    def setupOutputs(self):
        self.internalCleanup()
        datasetInfo = self.Dataset.value

        try:
            # Data only comes from the project file if the user said so AND it exists in the project
            datasetInProject = (
                datasetInfo.location == DatasetInfo.Location.ProjectInternal)
            datasetInProject &= self.ProjectFile.ready()
            if datasetInProject:
                internalPath = self.ProjectDataGroup.value + '/' + datasetInfo.datasetId
                datasetInProject &= internalPath in self.ProjectFile.value

            # If we should find the data in the project file, use a dataset reader
            if datasetInProject:
                opReader = OpStreamingH5N5Reader(parent=self)
                opReader.H5N5File.setValue(self.ProjectFile.value)
                opReader.InternalPath.setValue(internalPath)
                providerSlot = opReader.OutputImage
            elif datasetInfo.location == DatasetInfo.Location.PreloadedArray:
                preloaded_array = datasetInfo.preloaded_array
                assert preloaded_array is not None
                if not hasattr(preloaded_array, 'axistags'):
                    axisorder = get_default_axisordering(preloaded_array.shape)
                    preloaded_array = vigra.taggedView(preloaded_array,
                                                       axisorder)

                opReader = OpArrayPiper(parent=self)
                opReader.Input.setValue(preloaded_array)
                providerSlot = opReader.Output
            else:
                if datasetInfo.realDataSource:
                    # Use a normal (filesystem) reader
                    opReader = OpInputDataReader(parent=self)
                    if datasetInfo.subvolume_roi is not None:
                        opReader.SubVolumeRoi.setValue(
                            datasetInfo.subvolume_roi)
                    opReader.WorkingDirectory.setValue(
                        self.WorkingDirectory.value)
                    opReader.SequenceAxis.setValue(datasetInfo.sequenceAxis)
                    opReader.FilePath.setValue(datasetInfo.filePath)
                else:
                    # Use a fake reader: allows running the project in
                    # headless mode without the raw data
                    opReader = OpZeroDefault(parent=self)
                    opReader.MetaInput.meta = MetaDict(
                        shape=datasetInfo.laneShape,
                        dtype=datasetInfo.laneDtype,
                        drange=datasetInfo.drange,
                        axistags=datasetInfo.axistags)
                    opReader.MetaInput.setValue(
                        numpy.zeros(datasetInfo.laneShape,
                                    dtype=datasetInfo.laneDtype))
                providerSlot = opReader.Output
            self._opReaders.append(opReader)

            # Inject metadata if the dataset info specified any.
            # Also, inject if dtype is uint8, which we can reasonably assume has drange (0,255)
            metadata = {}
            metadata['display_mode'] = datasetInfo.display_mode
            role_name = self.RoleName.value
            if 'c' not in providerSlot.meta.getTaggedShape():
                num_channels = 0
            else:
                num_channels = providerSlot.meta.getTaggedShape()['c']
            if num_channels > 1:
                metadata['channel_names'] = [
                    "{}-{}".format(role_name, i) for i in range(num_channels)
                ]
            else:
                metadata['channel_names'] = [role_name]

            if datasetInfo.drange is not None:
                metadata['drange'] = datasetInfo.drange
            elif providerSlot.meta.dtype == numpy.uint8:
                # SPECIAL case for uint8 data: Provide a default drange.
                # The user can always override this herself if she wants.
                metadata['drange'] = (0, 255)
            if datasetInfo.normalizeDisplay is not None:
                metadata['normalizeDisplay'] = datasetInfo.normalizeDisplay
            if datasetInfo.axistags is not None:
                if len(datasetInfo.axistags) != len(providerSlot.meta.shape):
                    ts = providerSlot.meta.getTaggedShape()
                    if 'c' in ts and 'c' not in datasetInfo.axistags and len(
                            datasetInfo.axistags) + 1 == len(ts):
                        # provider has no channel axis, but template has => add channel axis to provider
                        # fixme: Optimize the axistag guess in BatchProcessingApplet instead of hoping for the best here
                        metadata['axistags'] = vigra.defaultAxistags(
                            ''.join(datasetInfo.axistags.keys()) + 'c')
                    else:
                        # This usually only happens when we copied a DatasetInfo from another lane,
                        # and used it as a 'template' to initialize this lane.
                        # This happens in the BatchProcessingApplet when it attempts to guess the axistags of
                        # batch images based on the axistags chosen by the user in the interactive images.
                        # If the interactive image tags don't make sense for the batch image, you get this error.
                        raise Exception(
                            "Your dataset's provided axistags ({}) do not have the "
                            "correct dimensionality for your dataset, which has {} dimensions."
                            .format(
                                "".join(tag.key
                                        for tag in datasetInfo.axistags),
                                len(providerSlot.meta.shape)))
                else:
                    metadata['axistags'] = datasetInfo.axistags
            if datasetInfo.original_axistags is not None:
                metadata['original_axistags'] = datasetInfo.original_axistags

            if datasetInfo.subvolume_roi is not None:
                metadata['subvolume_roi'] = datasetInfo.subvolume_roi

                # FIXME: We are overwriting the axistags metadata to intentionally allow
                #        the user to change our interpretation of which axis is which.
                #        That's okay, but technically there's a special corner case if
                #        the user redefines the channel axis index.
                #        Technically, it invalidates the meaning of meta.ram_usage_per_requested_pixel.
                #        For most use-cases, that won't really matter, which is why I'm not worrying about it right now.

            opMetadataInjector = OpMetadataInjector(parent=self)
            opMetadataInjector.Input.connect(providerSlot)
            opMetadataInjector.Metadata.setValue(metadata)
            providerSlot = opMetadataInjector.Output
            self._opReaders.append(opMetadataInjector)

            self._NonTransposedImage.connect(providerSlot)

            # make sure that x and y axes are present in the selected axis order
            if 'x' not in providerSlot.meta.axistags or 'y' not in providerSlot.meta.axistags:
                raise DatasetConstraintError(
                    "DataSelection",
                    "Data must always have at leaset the axes x and y for ilastik to work."
                )

            if self.forceAxisOrder:
                assert isinstance(self.forceAxisOrder, list), \
                    "forceAxisOrder should be a *list* of preferred axis orders"

                # Before we re-order, make sure no non-singleton
                #  axes would be dropped by the forced order.
                tagged_provider_shape = providerSlot.meta.getTaggedShape()
                minimal_axes = [
                    k_v for k_v in list(tagged_provider_shape.items())
                    if k_v[1] > 1
                ]
                minimal_axes = set(k for k, v in minimal_axes)

                # Pick the shortest of the possible 'forced' orders that
                # still contains all the axes of the original dataset.
                candidate_orders = list(self.forceAxisOrder)
                candidate_orders = [
                    order for order in candidate_orders
                    if minimal_axes.issubset(set(order))
                ]

                if len(candidate_orders) == 0:
                    msg = "The axes of your dataset ({}) are not compatible with any of the allowed"\
                          " axis configurations used by this workflow ({}). Please fix them."\
                          .format(providerSlot.meta.getAxisKeys(), self.forceAxisOrder)
                    raise DatasetConstraintError("DataSelection", msg)

                output_order = sorted(candidate_orders,
                                      key=len)[0]  # the shortest one
                output_order = "".join(output_order)
            else:
                # No forced axis order was supplied. Use the original axis
                # order as the output order: the export applet assumes that
                # an OpReorderAxes operator is added at the beginning.
                output_order = "".join(
                    [x for x in providerSlot.meta.axistags.keys()])

            op5 = OpReorderAxes(parent=self)
            op5.AxisOrder.setValue(output_order)
            op5.Input.connect(providerSlot)
            providerSlot = op5.Output
            self._opReaders.append(op5)

            # If the channel axis is missing, add it as last axis
            if 'c' not in providerSlot.meta.axistags:
                op5 = OpReorderAxes(parent=self)
                keys = providerSlot.meta.getAxisKeys()

                # Append
                keys.append('c')
                op5.AxisOrder.setValue("".join(keys))
                op5.Input.connect(providerSlot)
                providerSlot = op5.Output
                self._opReaders.append(op5)

            # Connect our external outputs to the internal operators we chose
            self.Image.connect(providerSlot)

            self.AllowLabels.setValue(datasetInfo.allowLabels)

            # If the reading operator provides a nickname, use it.
            if self.Image.meta.nickname is not None:
                datasetInfo.nickname = self.Image.meta.nickname

            imageName = datasetInfo.nickname
            if imageName == "":
                imageName = datasetInfo.filePath
            self.ImageName.setValue(imageName)

        except:
            self.internalCleanup()
            raise

    def propagateDirty(self, slot, subindex, roi):
        # Output slots are directly connected to internal operators
        pass

    @classmethod
    def getInternalDatasets(cls, filePath):
        return OpInputDataReader.getInternalDatasets(filePath)
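
# --- Illustrative sketch (not part of the original operator) ---
# A compact model of the forceAxisOrder resolution in setupOutputs above: keep
# only the candidate orders that contain every non-singleton axis of the data,
# then pick the shortest. The function name is an assumption; the operator
# performs these steps inline.
def pick_output_order(tagged_shape, force_axis_order):
    minimal_axes = set(k for k, v in tagged_shape.items() if v > 1)
    candidates = [order for order in force_axis_order
                  if minimal_axes.issubset(set(order))]
    if not candidates:
        raise ValueError("Data axes {} fit none of the allowed orders {}"
                         .format(sorted(minimal_axes), force_axis_order))
    return "".join(sorted(candidates, key=len)[0])  # the shortest one

# A 2D multichannel dataset fits 'yxc' (the shortest order that covers it).
assert pick_output_order({'y': 100, 'x': 100, 'c': 3}, ['yxc', 'zyxc']) == 'yxc'
# A true 3D volume (singleton channel) forces the longer order.
assert pick_output_order({'z': 10, 'y': 100, 'x': 100, 'c': 1}, ['yxc', 'zyxc']) == 'zyxc'
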
class OpCountingBatchResults(OpDataExport):
    # Add these additional input slots, to be used by the GUI.
    PmapColors = InputSlot()
    LabelNames = InputSlot()
Example #9
0
class _OpThresholdOneLevel(Operator):
    name = "_OpThresholdOneLevel"

    InputImage = InputSlot()
    MinSize = InputSlot(stype='int', value=0)
    MaxSize = InputSlot(stype='int', value=1000000)
    Threshold = InputSlot(stype='float', value=0.5)

    Output = OutputSlot()

    #debug output
    BeforeSizeFilter = OutputSlot()

    def __init__(self, *args, **kwargs):
        super(_OpThresholdOneLevel, self).__init__(*args, **kwargs)

        self._opThresholder = OpPixelOperator(parent=self)
        self._opThresholder.Input.connect(self.InputImage)

        self._opLabeler = OpLabelVolume(parent=self)
        self._opLabeler.Method.setValue(_labeling_impl)
        self._opLabeler.Input.connect(self._opThresholder.Output)

        self.BeforeSizeFilter.connect(self._opLabeler.Output)

        self._opFilter = OpFilterLabels(parent=self)
        self._opFilter.Input.connect(self._opLabeler.Output)
        self._opFilter.MinLabelSize.connect(self.MinSize)
        self._opFilter.MaxLabelSize.connect(self.MaxSize)
        self._opFilter.BinaryOut.setValue(False)

        self.Output.connect(self._opFilter.Output)

    def setupOutputs(self):
        def thresholdToUint8(thresholdValue, a):
            drange = self.InputImage.meta.drange
            if drange is not None:
                assert drange[0] == 0,\
                    "Don't know how to threshold data with this drange."
                thresholdValue *= drange[1]
            if a.dtype == numpy.uint8:
                # In-place (numpy optimizes this!)
                a[:] = (a > thresholdValue)
                return a
            else:
                return (a > thresholdValue).astype(numpy.uint8)

        self._opThresholder.Function.setValue(
            partial(thresholdToUint8, self.Threshold.value))

        # self.Output already has metadata: it is directly connected to self._opFilter.Output

    def execute(self, slot, subindex, roi, result):
        assert False, "Shouldn't get here..."

    def propagateDirty(self, slot, subindex, roi):
        pass  # nothing to do here

    def setInSlot(self, slot, subindex, roi, value):
        # Nothing to do here.
        # Our Input slots are directly fed into the cache,
        #  so all calls to __setitem__ are forwarded automatically
        pass
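
# --- Illustrative sketch (not part of the original operator) ---
# A self-contained version of the thresholdToUint8 closure defined in
# setupOutputs above: scale the [0,1] threshold by the data's drange if one is
# present, then binarize to uint8 (in place when the input already is uint8).
# The standalone signature is an assumption; the operator binds the threshold
# via partial() and reads the drange from slot metadata.
import numpy

def threshold_to_uint8(threshold, a, drange=None):
    if drange is not None:
        assert drange[0] == 0, "Don't know how to threshold data with this drange."
        threshold *= drange[1]
    if a.dtype == numpy.uint8:
        a[:] = (a > threshold)  # in-place (numpy optimizes this!)
        return a
    return (a > threshold).astype(numpy.uint8)

probs = numpy.array([0.1, 0.4, 0.6, 0.9], dtype=numpy.float32)
assert threshold_to_uint8(0.5, probs).tolist() == [0, 0, 1, 1]
pixels = numpy.array([10, 100, 200], dtype=numpy.uint8)  # drange (0,255)
assert threshold_to_uint8(0.5, pixels, drange=(0, 255)).tolist() == [0, 0, 1]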
Example #10
0
class OpResize5D(Operator):
    """
    Resize a 5D image.
    Notes:
        - Input must be 5D, tzyxc
        - Resizing is performed across the zyx dimensions only. The time dimension may not be resized.
    """
    Input = InputSlot()
    ResizedShape = InputSlot()

    Output = OutputSlot()

    def __init__(self, *args, **kwargs):
        super(OpResize5D, self).__init__(*args, **kwargs)
        self._input_to_output_scales = None
        self.progressSignal = OrderedSignal()

    def setupOutputs(self):
        assert self.Input.meta.getAxisKeys() == list('tzyxc')
        input_shape = self.Input.meta.shape
        output_shape = self.ResizedShape.value
        assert isinstance(output_shape, tuple)
        assert len(output_shape) == len(input_shape)
        self.Output.meta.assignFrom(self.Input.meta)
        self.Output.meta.shape = output_shape

        self._input_to_output_scales = numpy.array(
            output_shape, dtype=numpy.float32) / input_shape

        axes = self.Input.meta.getAxisKeys()
        if 'c' in axes:
            assert self._input_to_output_scales[ axes.index('c') ] == 1.0, \
                "Resizing the channel dimension is not supported."
        if 't' in axes:
            assert self._input_to_output_scales[ axes.index('t') ] == 1.0, \
                "Resizing the time dimension is not supported (yet)."

    def execute(self, slot, subindex, output_roi, result):
        # Special fast path if no resampling needed
        if self.Input.meta.shape == self.Output.meta.shape:
            self.Input(output_roi.start,
                       output_roi.stop).writeInto(result).wait()
            return result

        # Map output_roi to input_roi
        output_roi = numpy.array((output_roi.start, output_roi.stop))
        input_roi = output_roi / self._input_to_output_scales

        # Convert to int (round start down, round stop up)
        input_roi[1] += 0.5
        input_roi = input_roi.astype(int)

        t_start = output_roi[0][0]
        t_stop = output_roi[1][0]

        def process_timestep(t):
            # Request input and resize it.
            # FIXME: This is not quite correct.  We should request a halo that is wide enough
            #        for the BSpline used by resize(). See vigra docs for BSplineBase.radius()
            step_input_roi = copy.copy(input_roi)
            step_input_roi[0][0] = t
            step_input_roi[1][0] = t + 1

            step_input_data = self.Input(*step_input_roi).wait()
            step_input_data = vigra.taggedView(step_input_data, 'tzyxc')

            step_shape_4d = numpy.array(step_input_data[0].shape)
            step_shape_4d_nochannel = step_shape_4d[:-1]
            squeezed_slicing = numpy.where(step_shape_4d_nochannel == 1, 0,
                                           slice(None))
            squeezed_slicing = tuple(squeezed_slicing) + (slice(None), )

            step_input_squeezed = step_input_data[0][squeezed_slicing]
            result_step = result[t][squeezed_slicing]
            # vigra assumes the wrong axis order if we don't specify one explicitly here...
            result_step = vigra.taggedView(result_step,
                                           step_input_squeezed.axistags)

            if self.Input.meta.dtype == numpy.float32:
                vigra.sampling.resize(step_input_squeezed, out=result_step)
            else:
                step_input_squeezed = step_input_squeezed.astype(numpy.float32)
                result_float = vigra.sampling.resize(
                    step_input_squeezed, shape=result_step.shape[:-1])
                result_step[:] = result_float.round()

        # FIXME: Progress here will not be correct for multiple threads.
        self.progressSignal(0)

        # FIXME: request pool...
        for t in range(t_start, t_stop):
            process_timestep(t)
            progress = 100 * (t - t_start) / (t_stop - t_start)
            self.progressSignal(int(progress))
        self.progressSignal(100)

        return result

    def propagateDirty(self, slot, subindex, input_roi):
        # FIXME: When execute() is fixed to use a halo, we should also
        #        incorporate the halo into this dirty propagation logic.

        # Map input_roi to output_roi
        input_roi = numpy.array((input_roi.start, input_roi.stop))
        output_roi = input_roi * self._input_to_output_scales

        # Convert to int (round start down, round stop up)
        output_roi[1] += 0.5
        output_roi = output_roi.astype(int)

        self.Output.setDirty(*output_roi)
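
# --- Illustrative sketch (not part of the original operator) ---
# The roi mapping used by execute() and propagateDirty() above, isolated:
# scale a roi by the per-axis scale factors, rounding the start down and the
# stop up so the mapped roi always covers the requested region. The helper
# name is an assumption for illustration.
import numpy

def map_roi(roi_start, roi_stop, scales):
    roi = numpy.array((roi_start, roi_stop), dtype=numpy.float64) * scales
    roi[1] += 0.5            # round stop up
    return roi.astype(int)   # truncation rounds non-negative starts down

# An input axis of length 10 resized to 20: to produce output roi (3, 7),
# execute() needs input roi (1, 4).
input_to_output = numpy.array([2.0])
print(map_roi([3], [7], 1.0 / input_to_output))  # -> [[1] [4]]
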
class OpExportMultipageTiffSequence(Operator):
    Input = InputSlot(
    )  # The last two non-singleton axes (except 'c') are the axes of the slices.
    # Re-order the axes yourself if you want an alternative slicing direction

    FilepathPattern = InputSlot(
    )  # A complete filepath including a {slice_index} member and a valid file extension.
    SliceIndexOffset = InputSlot(
        value=0)  # Added to the {slice_index} in the export filename.

    def __init__(self, *args, **kwargs):
        super(OpExportMultipageTiffSequence, self).__init__(*args, **kwargs)
        self.progressSignal = OrderedSignal()

    def run_export(self):
        """
        Request the volume in slices (running in parallel), and write each slice to the correct page.
        Note: We can't use BigRequestStreamer here, because the data for each slice wouldn't be 
              guaranteed to arrive in the correct order.
        """
        # Make the directory first if necessary
        export_dir = os.path.split(self.FilepathPattern.value)[0]
        if not os.path.exists(export_dir):
            os.makedirs(export_dir)

        # Blockshape is the same as the input shape, except for the sliced dimension
        step_axis = self._volume_axes[0]
        tagged_blockshape = self.Input.meta.getTaggedShape()
        tagged_blockshape[step_axis] = 1
        block_shape = (list(tagged_blockshape.values()))
        logger.debug(
            "Starting Multipage Sequence Export with block shape: {}".format(
                block_shape))

        # Block step is all zeros except step axis, e.g. (0, 1, 0, 0, 0)
        block_step = numpy.array(self.Input.meta.getAxisKeys()) == step_axis
        block_step = block_step.astype(int)

        filepattern = self.FilepathPattern.value

        # If the user didn't provide custom formatting for the slice field,
        #  auto-format to include zero-padding
        if '{slice_index}' in filepattern:
            filepattern = filepattern.format(
                slice_index='{' +
                'slice_index:0{}'.format(self._max_slice_digits) + '}')

        self.progressSignal(0)
        # Nothing fancy here: Just loop over the blocks in order.
        tagged_shape = self.Input.meta.getTaggedShape()
        for block_index in range(tagged_shape[step_axis]):
            roi = numpy.array(roiFromShape(block_shape))
            roi += block_index * block_step
            roi = list(map(tuple, roi))

            try:
                opSubregion = OpSubRegion(parent=self)
                opSubregion.Roi.setValue(roi)
                opSubregion.Input.connect(self.Input)

                formatted_path = filepattern.format(
                    slice_index=(block_index + self.SliceIndexOffset.value))
                opExportBlock = OpExportMultipageTiff(parent=self)
                opExportBlock.Input.connect(opSubregion.Output)
                opExportBlock.Filepath.setValue(formatted_path)

                block_start_progress = 100 * block_index // tagged_shape[
                    step_axis]

                def _handleBlockProgress(block_progress):
                    self.progressSignal(block_start_progress +
                                        block_progress //
                                        tagged_shape[step_axis])

                opExportBlock.progressSignal.subscribe(_handleBlockProgress)

                # Run the export for this block
                opExportBlock.run_export()
            finally:
                opExportBlock.cleanUp()
                opSubregion.cleanUp()

        self.progressSignal(100)

    def setupOutputs(self):
        # If stacking XY images in Z-steps,
        #  then self._volume_axes = 'zxy'
        self._volume_axes = self.get_nonsingleton_axes()
        step_axis = self._volume_axes[0]
        max_slice = self.SliceIndexOffset.value + self.Input.meta.getTaggedShape(
        )[step_axis]
        self._max_slice_digits = int(math.ceil(math.log10(max_slice + 1)))

        # Check for errors
        assert len(self._volume_axes) == 4 or len(self._volume_axes) == 5 and 'c' in self._volume_axes[1:], \
            "Exported stacks must have exactly 4 non-singleton dimensions (other than the channel dimension).  "\
            "You stack dimensions are: {}".format( self.Input.meta.getTaggedShape() )

        # Test to make sure the filepath pattern includes slice index field
        filepath_pattern = self.FilepathPattern.value
        assert '123456789' in filepath_pattern.format( slice_index=123456789 ), \
            "Output filepath pattern must contain the '{slice_index}' field for formatting.\n"\
            "Your format was: {}".format( filepath_pattern )

    # No output slots...
    def execute(self, slot, subindex, roi, result):
        pass

    def propagateDirty(self, slot, subindex, roi):
        pass

    def get_nonsingleton_axes(self):
        return self.get_nonsingleton_axes_for_tagged_shape(
            self.Input.meta.getTaggedShape())

    @classmethod
    def get_nonsingleton_axes_for_tagged_shape(cls, tagged_shape):
        # Find the non-singleton axes.
        # The first non-singleton axis is the step axis.
        # The last 2 non-channel non-singleton axes will be the axes of the slices.
        tagged_items = list(tagged_shape.items())
        filtered_items = [k_v for k_v in tagged_items if k_v[1] > 1]
        filtered_axes = list(zip(*filtered_items))[0]
        return filtered_axes
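
# --- Illustrative sketch (not part of the original operator) ---
# The filename-pattern handling from run_export() above, isolated: a bare
# '{slice_index}' field is rewritten into a zero-padded format field sized for
# the largest slice number, which is then filled per block. The function name
# is an assumption for illustration.
import math

def pad_slice_pattern(filepattern, max_slice):
    max_digits = int(math.ceil(math.log10(max_slice + 1)))
    if '{slice_index}' in filepattern:
        filepattern = filepattern.format(
            slice_index='{' + 'slice_index:0{}'.format(max_digits) + '}')
    return filepattern

pattern = pad_slice_pattern('/tmp/export_{slice_index}.tiff', 120)
assert pattern == '/tmp/export_{slice_index:03}.tiff'
assert pattern.format(slice_index=7) == '/tmp/export_007.tiff'
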
class OpStreamingUfmfReader(Operator):
    """
    Imports videos in uFMF format. For more information, refer to the Ctrax fly tracker: http://ctrax.sourceforge.net/
    """

    name = "OpStreamingUfmfReader"
    category = "Input"

    position = None

    FileName = InputSlot(stype="filestring")
    Output = OutputSlot()

    class DatasetReadError(Exception):
        pass

    def __init__(self, *args, **kwargs):
        super(OpStreamingUfmfReader, self).__init__(*args, **kwargs)
        self._lock = threading.Lock()

    def setupOutputs(self):
        """
        Load the file specified via our input slot and present its data on the output slot.
        """
        fileName = self.FileName.value

        self.fmf = UfmfParser.FlyMovieEmulator(str(fileName))
        frameNum = self.fmf.get_n_frames()
        width = self.fmf.get_width()
        height = self.fmf.get_height()

        try:
            self.frame, timestamp = self.fmf.get_next_frame()
        except FMF.NoMoreFramesException as err:
            logger.info("Error reading uFMF frame.")

        self.Output.meta.dtype = self.frame.dtype.type
        self.Output.meta.axistags = vigra.defaultAxistags(AXIS_ORDER)
        self.Output.meta.shape = (frameNum, self.frame.shape[0],
                                  self.frame.shape[1], 1)
        self.Output.meta.ideal_blockshape = (1, ) + self.Output.meta.shape[1:]

    def execute(self, slot, subindex, roi, result):
        start, stop = roi.start, roi.stop

        tStart, tStop = start[0], stop[0]
        yStart, yStop = start[1], stop[1]
        xStart, xStop = start[2], stop[2]
        cStart, cStop = start[3], stop[3]

        for tResult, tFrame in enumerate(range(tStart, tStop)):
            with self._lock:
                if self.position != tFrame:
                    self.position = tFrame
                    self.fmf.seek(tFrame)
                    self.frame, timestamp = self.fmf.get_next_frame()
                result[tResult, ..., 0] = self.frame[yStart:yStop,
                                                     xStart:xStop]

    def propagateDirty(self, slot, subindex, roi):
        if slot == self.FileName:
            self.Output.setDirty(slice(None))

    def cleanUp(self):
        self.fmf.close()
        super(OpStreamingUfmfReader, self).cleanUp()
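
# --- Illustrative sketch (not part of the original operator) ---
# The seek-avoidance pattern from execute() above, isolated: remember the last
# frame position and only seek when a request jumps, so repeated reads of the
# same frame (e.g. different xy sub-rois) stay cheap. The source object is a
# hypothetical stand-in for the uFMF parser (it must provide seek() and
# get_next_frame()).
import threading

class CachedFrameReader(object):
    def __init__(self, source):
        self._source = source
        self._lock = threading.Lock()
        self._position = None
        self._frame = None

    def read(self, t):
        with self._lock:
            if self._position != t:
                self._position = t
                self._source.seek(t)
                self._frame, _timestamp = self._source.get_next_frame()
            return self._frame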
Example #13
0
class OpExportSlot(Operator):
    """
    Export a slot 'as-is', i.e. no subregion, no dtype conversion, no normalization, no axis re-ordering, etc.
    For sequence export formats, the sequence is indexed by the axistags' FIRST axis.
    For example, txyzc produces a sequence of xyzc volumes.
    """

    Input = InputSlot()

    OutputFormat = InputSlot(value="hdf5")  # string.  See formats, below
    OutputFilenameFormat = (
        InputSlot()
    )  # A format string allowing {roi}, {t_start}, {t_stop}, etc (but not {nickname} or {dataset_dir})
    OutputInternalPath = InputSlot(value="exported_data")

    CoordinateOffset = InputSlot(
        optional=True
    )  # Add an offset to the roi coordinates in the export path (useful if Input is a subregion of a larger dataset)

    ExportPath = OutputSlot()
    FormatSelectionErrorMsg = OutputSlot()

    _2d_exts = vigra.impex.listExtensions().split()

    # List all supported formats
    _2d_formats = [FormatInfo(ext, ext, 2, 2) for ext in _2d_exts]
    _3d_sequence_formats = [
        FormatInfo(ext + " sequence", ext, 3, 3) for ext in _2d_exts
    ]
    _3d_volume_formats = [FormatInfo("multipage tiff", "tiff", 3, 3)]
    _4d_sequence_formats = [
        FormatInfo("multipage tiff sequence", "tiff", 4, 4)
    ]
    nd_format_formats = [
        FormatInfo("hdf5", "h5", 0, 5),
        FormatInfo("compressed hdf5", "h5", 0, 5),
        FormatInfo("n5", "n5", 0, 5),
        FormatInfo("compressed n5", "n5", 0, 5),
        FormatInfo("numpy", "npy", 0, 5),
        FormatInfo("dvid", "", 2, 5),
        FormatInfo("blockwise hdf5", "json", 0, 5),
    ]

    ALL_FORMATS = _2d_formats + _3d_sequence_formats + _3d_volume_formats + _4d_sequence_formats + nd_format_formats

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.progressSignal = OrderedSignal()

        # Set up the impl function lookup dict
        export_impls = {}
        export_impls["hdf5"] = ("h5", self._export_h5n5)
        export_impls["compressed hdf5"] = ("h5",
                                           partial(self._export_h5n5, True))
        export_impls["n5"] = ("n5", self._export_h5n5)
        export_impls["compressed n5"] = ("n5", partial(self._export_h5n5,
                                                       True))
        export_impls["numpy"] = ("npy", self._export_npy)
        export_impls["dvid"] = ("", self._export_dvid)
        export_impls["blockwise hdf5"] = ("json", self._export_blockwise_hdf5)

        for fmt in self._2d_formats:
            export_impls[fmt.name] = (fmt.extension,
                                      partial(self._export_2d, fmt.extension))

        for fmt in self._3d_sequence_formats:
            export_impls[fmt.name] = (fmt.extension,
                                      partial(self._export_3d_sequence,
                                              fmt.extension))

        export_impls["multipage tiff"] = ("tiff", self._export_multipage_tiff)
        export_impls["multipage tiff sequence"] = (
            "tiff", self._export_multipage_tiff_sequence)
        self._export_impls = export_impls

        self.Input.notifyMetaChanged(self._updateFormatSelectionErrorMsg)

    def setupOutputs(self):
        self.ExportPath.meta.shape = (1, )
        self.ExportPath.meta.dtype = object
        self.FormatSelectionErrorMsg.meta.shape = (1, )
        self.FormatSelectionErrorMsg.meta.dtype = object

        if self.OutputFormat.value in (
                "hdf5",
                "compressed hdf5") and self.OutputInternalPath.value == "":
            self.ExportPath.meta.NOTREADY = True

    def execute(self, slot, subindex, roi, result):
        if slot == self.ExportPath:
            return self._executeExportPath(result)
        else:
            assert False, "Unknown output slot: {}".format(slot.name)

    def _executeExportPath(self, result):
        path_format = self.OutputFilenameFormat.value
        file_extension = self._export_impls[self.OutputFormat.value][0]

        # Remove existing extension (if present) and add the correct extension (if any)
        if file_extension:
            path_format = os.path.splitext(path_format)[0]
            path_format += "." + file_extension

        # Provide the TOTAL path (including dataset name)
        if self.OutputFormat.value in ("hdf5", "compressed hdf5", "n5",
                                       "compressed n5"):
            path_format += "/" + self.OutputInternalPath.value

        roi = numpy.array(roiFromShape(self.Input.meta.shape))

        # Intermediate state can cause coordinate offset and input shape to be mismatched.
        # Just don't use the offset if it looks wrong.
        # (The client will provide a valid offset later on.)
        if self.CoordinateOffset.ready() and len(
                self.CoordinateOffset.value) == len(roi[0]):
            offset = self.CoordinateOffset.value
            assert len(roi[0]) == len(offset)
            roi += offset
        optional_replacements = {}
        optional_replacements["roi"] = list(map(tuple, roi))
        for key, (start, stop) in zip(self.Input.meta.getAxisKeys(),
                                      roi.transpose()):
            optional_replacements[key + "_start"] = start
            optional_replacements[key + "_stop"] = stop
        formatted_path = format_known_keys(path_format,
                                           optional_replacements,
                                           strict=False)
        result[0] = formatted_path
        return result

    def _updateFormatSelectionErrorMsg(self, *args):
        error_msg = self._get_format_selection_error_msg()
        self.FormatSelectionErrorMsg.setValue(error_msg)

    def _get_format_selection_error_msg(self, *args):
        """
        If the currently selected format does not support the input image format,
        return an error message stating why. Otherwise, return an empty string.
        """
        if not self.Input.ready():
            return "Input not ready"
        output_format = self.OutputFormat.value

        # These cases support all combinations
        if output_format in ("hdf5", "compressed hdf5", "n5", "compressed n5",
                             "npy", "blockwise hdf5"):
            return ""

        tagged_shape = self.Input.meta.getTaggedShape()
        axes = OpStackWriter.get_nonsingleton_axes_for_tagged_shape(
            tagged_shape)
        output_dtype = self.Input.meta.dtype

        if output_format == "dvid":
            # dvid requires a channel axis, which must come last.
            # Internally, we transpose it before sending it over the wire
            if list(tagged_shape.keys())[-1] != "c":
                return "DVID requires the last axis to be channel."

            # Make sure DVID supports this dtype/channel combo.
            from libdvid.voxels import VoxelsMetadata

            axiskeys = self.Input.meta.getAxisKeys()
            # We reverse the axiskeys because the export operator (see below) uses transpose_axes=True
            reverse_axiskeys = "".join(reversed(axiskeys))
            reverse_shape = tuple(reversed(self.Input.meta.shape))
            metainfo = VoxelsMetadata.create_default_metadata(
                reverse_shape, output_dtype, reverse_axiskeys, 0.0,
                "nanometers")
            try:
                metainfo.determine_dvid_typename()
            except Exception as ex:
                return str(ex)
            else:
                return ""

        return FormatValidity.check(self.Input.meta.getTaggedShape(),
                                    self.Input.meta.dtype, output_format)

    def propagateDirty(self, slot, subindex, roi):
        if slot == self.OutputFormat or slot == self.OutputFilenameFormat:
            self.ExportPath.setDirty()
        if slot == self.OutputFormat:
            self._updateFormatSelectionErrorMsg()

    def run_export_to_array(self):
        """
        Export the slot data to an array, instead of to disk.
        The data is computed blockwise, as necessary.
        The result is returned.
        """
        self.progressSignal(0)
        opExport = OpExportToArray(parent=self)
        try:
            opExport.progressSignal.subscribe(self.progressSignal)
            opExport.Input.connect(self.Input)
            return opExport.run_export_to_array()
        finally:
            opExport.cleanUp()
            self.progressSignal(100)

    def run_export(self):
        """
        Perform the export and WAIT for it to complete.
        If you want asynchronous execution, run this function in a request:

            req = Request( opExport.run_export )
            req.submit()
        """
        output_format = self.OutputFormat.value
        try:
            export_func = self._export_impls[output_format][1]
        except KeyError as e:
            raise Exception(f"Unknown export format: {output_format}") from e
        else:
            mkdir_p(PathComponents(self.ExportPath.value).externalDirectory)
            export_func()

    def _export_h5n5(self, compress=False):
        self.progressSignal(0)

        # Create and open the hdf5/n5 file
        export_components = PathComponents(self.ExportPath.value)
        try:
            if os.path.isdir(export_components.externalPath
                             ):  # externalPath leads to a n5 file
                shutil.rmtree(export_components.externalPath
                              )  # n5 is stored as a directory structure
            else:
                os.remove(export_components.externalPath)
        except OSError as ex:
            # It's okay if the file isn't there.
            if ex.errno != 2:
                raise
        try:
            with OpStreamingH5N5Reader.get_h5_n5_file(
                    export_components.externalPath, "w") as h5N5File:
                # Create a temporary operator to do the work for us
                opH5N5Writer = OpH5N5WriterBigDataset(parent=self)
                try:
                    opH5N5Writer.CompressionEnabled.setValue(compress)
                    opH5N5Writer.h5N5File.setValue(h5N5File)
                    opH5N5Writer.h5N5Path.setValue(
                        export_components.internalPath)
                    opH5N5Writer.Image.connect(self.Input)

                    # The H5 Writer provides its own progress signal, so just connect ours to it.
                    opH5N5Writer.progressSignal.subscribe(self.progressSignal)

                    # Perform the export and block for it in THIS THREAD.
                    opH5N5Writer.WriteImage[:].wait()
                finally:
                    opH5N5Writer.cleanUp()
                    self.progressSignal(100)
        except IOError as ex:
            import sys

            msg = "\nException raised when attempting to export to {}: {}\n".format(
                export_components.externalPath, str(ex))
            sys.stderr.write(msg)
            raise

    def _export_npy(self):
        self.progressSignal(0)
        export_path = self.ExportPath.value
        try:
            opWriter = OpNpyWriter(parent=self)
            opWriter.Filepath.setValue(export_path)
            opWriter.Input.connect(self.Input)

            # Run the export in this thread
            opWriter.write()
        finally:
            opWriter.cleanUp()
            self.progressSignal(100)

    def _export_dvid(self):
        self.progressSignal(0)
        export_path = self.ExportPath.value

        opExport = OpExportDvidVolume(transpose_axes=True, parent=self)
        try:
            opExport.Input.connect(self.Input)
            opExport.NodeDataUrl.setValue(export_path)

            # Run the export in this thread
            opExport.run_export()
        finally:
            opExport.cleanUp()
            self.progressSignal(100)

    def _export_blockwise_hdf5(self):
        raise NotImplementedError

    def _export_2d(self, fmt):
        self.progressSignal(0)
        export_path = self.ExportPath.value
        opExport = OpExport2DImage(parent=self)
        try:
            opExport.progressSignal.subscribe(self.progressSignal)
            opExport.Filepath.setValue(export_path)
            opExport.Input.connect(self.Input)

            # Run the export
            opExport.run_export()
        finally:
            opExport.cleanUp()
            self.progressSignal(100)

    def _export_3d_sequence(self, extension):
        self.progressSignal(0)
        export_path_base, export_path_ext = os.path.splitext(
            self.ExportPath.value)
        export_path_pattern = export_path_base + "." + extension

        try:
            opWriter = OpStackWriter(parent=self)
            opWriter.FilepathPattern.setValue(export_path_pattern)
            opWriter.Input.connect(self.Input)
            opWriter.progressSignal.subscribe(self.progressSignal)

            if self.CoordinateOffset.ready():
                step_axis = opWriter.get_nonsingleton_axes()[0]
                step_axis_index = self.Input.meta.getAxisKeys().index(
                    step_axis)
                step_axis_offset = self.CoordinateOffset.value[step_axis_index]
                opWriter.SliceIndexOffset.setValue(step_axis_offset)

            # Run the export
            opWriter.run_export()
        finally:
            opWriter.cleanUp()
            self.progressSignal(100)

    def _export_multipage_tiff(self):
        self.progressSignal(0)
        export_path = self.ExportPath.value
        try:
            opExport = OpExportMultipageTiff(parent=self)
            opExport.Filepath.setValue(export_path)
            opExport.Input.connect(self.Input)
            opExport.progressSignal.subscribe(self.progressSignal)

            # Run the export
            opExport.run_export()
        finally:
            opExport.cleanUp()
            self.progressSignal(100)

    def _export_multipage_tiff_sequence(self):
        self.progressSignal(0)
        export_path_base, export_path_ext = os.path.splitext(
            self.ExportPath.value)
        export_path_pattern = export_path_base + ".tiff"

        try:
            opExport = OpExportMultipageTiffSequence(parent=self)
            opExport.FilepathPattern.setValue(export_path_pattern)
            opExport.Input.connect(self.Input)
            opExport.progressSignal.subscribe(self.progressSignal)

            if self.CoordinateOffset.ready():
                step_axis = opExport.get_nonsingleton_axes()[0]
                step_axis_index = self.Input.meta.getAxisKeys().index(
                    step_axis)
                step_axis_offset = self.CoordinateOffset.value[step_axis_index]
                opExport.SliceIndexOffset.setValue(step_axis_offset)

            # Run the export
            opExport.run_export()
        finally:
            opExport.cleanUp()
            self.progressSignal(100)
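
All of the _export_* helpers above share one lifecycle: reset progress, create a short-lived child operator, connect the input, forward progress, run the export synchronously, and always clean up in a finally block. A minimal sketch of that shared pattern (run_child_export and make_writer are illustrative names, not from the source; some writers use write() or WriteImage[:].wait() instead of run_export()):

def run_child_export(parent_op, make_writer):
    # Sketch of the lifecycle shared by the _export_* methods above.
    parent_op.progressSignal(0)
    opWriter = make_writer(parent=parent_op)  # e.g. OpExport2DImage
    try:
        opWriter.Input.connect(parent_op.Input)
        # Forward the child's progress to our own signal, if it has one.
        if hasattr(opWriter, 'progressSignal'):
            opWriter.progressSignal.subscribe(parent_op.progressSignal)
        opWriter.run_export()  # Blocks in THIS thread
    finally:
        opWriter.cleanUp()
        parent_op.progressSignal(100)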
Example #14
class _OpVigraLabelVolume(Operator):
    """
    Operator that simply wraps vigra's labelVolume function.
    """

    name = "OpVigraLabelVolume"
    category = "Vigra"

    Input = InputSlot()
    BackgroundValue = InputSlot(optional=True)

    Output = OutputSlot()

    def setupOutputs(self):
        inputShape = self.Input.meta.shape

        # Must have at most 1 time slice
        timeIndex = self.Input.meta.axistags.index("t")
        assert timeIndex == len(inputShape) or inputShape[timeIndex] == 1

        # Must have at most 1 channel
        channelIndex = self.Input.meta.axistags.channelIndex
        assert channelIndex == len(inputShape) or inputShape[channelIndex] == 1

        self.Output.meta.assignFrom(self.Input.meta)
        self.Output.meta.dtype = numpy.uint32

    def execute(self, slot, subindex, roi, destination):
        assert slot == self.Output

        resultView = destination.view(vigra.VigraArray)
        resultView.axistags = self.Input.meta.axistags

        inputData = self.Input(roi.start, roi.stop).wait()
        inputData = inputData.view(vigra.VigraArray)
        inputData.axistags = self.Input.meta.axistags

        # Drop the time axis, which vigra.labelVolume doesn't remove automatically
        axiskeys = [tag.key for tag in inputData.axistags]
        if "t" in axiskeys:
            inputData = inputData.bindAxis("t", 0)
            resultView = resultView.bindAxis("t", 0)

        # Drop the channel axis, too.
        if "c" in axiskeys:
            inputData = inputData.bindAxis("c", 0)
            resultView = resultView.bindAxis("c", 0)

        # I have no idea why, but vigra sometimes throws a precondition error if this line is present.
        # ...on the other hand, I can't remember why I added this line in the first place...
        # inputData = inputData.view(numpy.ndarray)

        if self.BackgroundValue.ready():
            bg = self.BackgroundValue.value
            if isinstance(bg, numpy.ndarray):
                # If background value was given as a 1-element array, extract it.
                assert bg.size == 1
                bg = bg.squeeze()[()]
            if isinstance(bg, numpy.floating):
                bg = float(bg)
            else:
                bg = int(bg)
            if len(inputData.shape) == 2:
                vigra.analysis.labelImageWithBackground(inputData,
                                                        background_value=bg,
                                                        out=resultView)
            else:
                vigra.analysis.labelVolumeWithBackground(inputData,
                                                         background_value=bg,
                                                         out=resultView)
        else:
            if len(inputData.shape) == 2:
                vigra.analysis.labelImageWithBackground(inputData,
                                                        out=resultView)
            else:
                vigra.analysis.labelVolumeWithBackground(inputData,
                                                         out=resultView)

        return destination

    def propagateDirty(self, inputSlot, subindex, roi):
        if inputSlot == self.Input:
            # If anything changed, the whole image is now dirty
            #  because a single pixel change can trigger a cascade of relabeling.
            self.Output.setDirty(slice(None))
        elif inputSlot == self.BackgroundValue:
            self.Output.setDirty(slice(None))
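
A possible usage sketch (not from the source), assuming lazyflow's usual Graph/slot API: feed a 5d vigra-tagged uint8 volume with singleton time and channel axes through the operator and read back the connected components.

import numpy
import vigra
from lazyflow.graph import Graph

binary = (numpy.random.rand(1, 64, 64, 64, 1) > 0.8).astype(numpy.uint8)
binary = vigra.taggedView(binary, 'txyzc')  # 1 time slice, 1 channel

op = _OpVigraLabelVolume(graph=Graph())
op.Input.setValue(binary)
labels = op.Output[:].wait()  # uint32 labels; background stays 0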
Example #15
class OpAnisotropicGaussianSmoothing5d(Operator):
    # raw volume, in 5d 'txyzc' order
    Input = InputSlot()
    Sigmas = InputSlot(value={'x': 1.0, 'y': 1.0, 'z': 1.0})

    Output = OutputSlot()

    def setupOutputs(self):
        self.Output.meta.assignFrom(self.Input.meta)
        self.Output.meta.dtype = numpy.float32  # vigra gaussian only supports float32
        self._sigmas = self.Sigmas.value
        assert isinstance(self.Sigmas.value,
                          dict), "Sigmas slot expects a dict"
        assert set(self._sigmas.keys()) == set(
            'xyz'), "Sigmas slot expects three key-value pairs for x,y,z"

    def execute(self, slot, subindex, roi, result):
        assert all(roi.stop <= self.Input.meta.shape),\
            "Requested roi {} is too large for this input image of shape {}.".format(roi, self.Input.meta.shape)

        # Determine how much input data we'll need, and where the result will be
        # relative to that input roi
        # inputRoi is a 5d roi, computeRoi depends on the number of singletons
        # in shape, but is at most 3d
        inputRoi, computeRoi = self._getInputComputeRois(roi)

        # Obtain the input data
        with Timer() as resultTimer:
            data = self.Input(*inputRoi).wait()
        logger.debug("Obtaining input data took {} seconds for roi {}".format(
            resultTimer.seconds(), inputRoi))
        data = vigra.taggedView(data, axistags='txyzc')

        # input is in txyzc order
        tIndex = 0
        cIndex = 4

        # Must be float32
        if data.dtype != numpy.float32:
            data = data.astype(numpy.float32)

        # we need to remove a singleton z axis, otherwise we get
        # 'kernel longer than line' errors
        ts = self.Input.meta.getTaggedShape()
        tags = [k for k in 'xyz' if ts[k] > 1]
        sigma = [self._sigmas[k] for k in tags]

        # Check if we need to smooth
        if any([x < 0.1 for x in sigma]):
            # just pipe the input through
            result[...] = data
            return

        for i, t in enumerate(range(roi.start[tIndex], roi.stop[tIndex])):
            for j, c in enumerate(range(roi.start[cIndex], roi.stop[cIndex])):
                # prepare the result as an argument
                resview = vigra.taggedView(result[i, ..., j], axistags='xyz')
                dataview = data[i, ..., j]
                # TODO make this general, not just for z axis
                resview = resview.withAxes(*tags)
                dataview = dataview.withAxes(*tags)

                # Smooth the input data
                vigra.filters.gaussianSmoothing(dataview,
                                                sigma,
                                                window_size=2.0,
                                                roi=computeRoi,
                                                out=resview)

    def _getInputComputeRois(self, roi):
        shape = self.Input.meta.shape
        start = numpy.asarray(roi.start)
        stop = numpy.asarray(roi.stop)
        n = len(stop)
        spatStart = [roi.start[i] for i in range(n) if shape[i] > 1]
        spatStop = [roi.stop[i] for i in range(n) if shape[i] > 1]
        sigma = [0] + [self._sigmas[k] for k in 'xyz'] + [0]
        spatialRoi = (spatStart, spatStop)

        inputSpatialRoi = enlargeRoiForHalo(roi.start,
                                            roi.stop,
                                            shape,
                                            sigma,
                                            window=2.0)

        # Determine the roi within the input data we're going to request
        inputRoiOffset = roi.start - inputSpatialRoi[0]
        computeRoi = [inputRoiOffset, inputRoiOffset + stop - start]
        for i in (0, 1):
            computeRoi[i] = [
                computeRoi[i][j] for j in range(n)
                if shape[j] > 1 and j not in (0, 4)
            ]

        # make sure that vigra understands our integer types
        computeRoi = (tuple(map(int, computeRoi[0])),
                      tuple(map(int, computeRoi[1])))

        inputRoi = (list(inputSpatialRoi[0]), list(inputSpatialRoi[1]))

        return inputRoi, computeRoi

    def propagateDirty(self, slot, subindex, roi):
        if slot == self.Input:
            # Halo calculation is bidirectional, so we can re-use the function
            # that computes the halo during execute()
            inputRoi, _ = self._getInputComputeRois(roi)
            self.Output.setDirty(inputRoi[0], inputRoi[1])
        elif slot == self.Sigmas:
            self.Output.setDirty(slice(None))
        else:
            assert False, "Unknown input slot: {}".format(slot.name)
Example #16
class _OpThresholdTwoLevels(Operator):
    name = "_OpThresholdTwoLevels"

    InputImage = InputSlot()
    MinSize = InputSlot(stype='int', value=0)
    MaxSize = InputSlot(stype='int', value=1000000)
    HighThreshold = InputSlot(stype='float', value=0.5)
    LowThreshold = InputSlot(stype='float', value=0.2)

    Output = OutputSlot()
    CachedOutput = OutputSlot()  # For the GUI (blockwise-access)

    # For serialization
    InputHdf5 = InputSlot(optional=True)
    OutputHdf5 = OutputSlot()
    CleanBlocks = OutputSlot()

    # Debug outputs
    BigRegions = OutputSlot()
    SmallRegions = OutputSlot()
    FilteredSmallLabels = OutputSlot()

    # Schematic:
    #
    #           HighThreshold                         MinSize,MaxSize                       --(cache)--> opColorize -> FilteredSmallLabels
    #                   \                                       \                     /
    #           opHighThresholder --> opHighLabeler --> opHighLabelSizeFilter                           Output
    #          /                   \          /                 \                                            \                         /
    # InputImage        --(cache)--> SmallRegions                    opSelectLabels -->opFinalLabelSizeFilter--> opCache --> CachedOutput
    #          \                                                              /                                           /       \
    #           opLowThresholder ----> opLowLabeler --------------------------                                       InputHdf5     --> OutputHdf5
    #                   /                \                                                                                        -> CleanBlocks
    #           LowThreshold            --(cache)--> BigRegions

    def __init__(self, *args, **kwargs):
        super(_OpThresholdTwoLevels, self).__init__(*args, **kwargs)

        self._opLowThresholder = OpPixelOperator(parent=self)
        self._opLowThresholder.Input.connect(self.InputImage)

        self._opHighThresholder = OpPixelOperator(parent=self)
        self._opHighThresholder.Input.connect(self.InputImage)

        self._opLowLabeler = OpLabelVolume(parent=self)
        self._opLowLabeler.Method.setValue(_labeling_impl)
        self._opLowLabeler.Input.connect(self._opLowThresholder.Output)

        self._opHighLabeler = OpLabelVolume(parent=self)
        self._opHighLabeler.Method.setValue(_labeling_impl)
        self._opHighLabeler.Input.connect(self._opHighThresholder.Output)

        self._opHighLabelSizeFilter = OpFilterLabels(parent=self)
        self._opHighLabelSizeFilter.Input.connect(self._opHighLabeler.Output)
        self._opHighLabelSizeFilter.MinLabelSize.connect(self.MinSize)
        self._opHighLabelSizeFilter.MaxLabelSize.connect(self.MaxSize)
        self._opHighLabelSizeFilter.BinaryOut.setValue(
            False)  # we do the binarization in opSelectLabels
        # this way, we get to display pretty colors

        self._opSelectLabels = OpSelectLabels(parent=self)
        self._opSelectLabels.BigLabels.connect(self._opLowLabeler.Output)
        self._opSelectLabels.SmallLabels.connect(
            self._opHighLabelSizeFilter.Output)

        # remove the remaining very large objects -
        # they might still be present in case a big object
        # was split into many small ones for the higher threshold
        # and they got reconnected again at lower threshold
        self._opFinalLabelSizeFilter = OpFilterLabels(parent=self)
        self._opFinalLabelSizeFilter.Input.connect(self._opSelectLabels.Output)
        self._opFinalLabelSizeFilter.MinLabelSize.connect(self.MinSize)
        self._opFinalLabelSizeFilter.MaxLabelSize.connect(self.MaxSize)
        self._opFinalLabelSizeFilter.BinaryOut.setValue(False)

        self._opCache = OpCompressedCache(parent=self)
        self._opCache.name = "_OpThresholdTwoLevels._opCache"
        self._opCache.InputHdf5.connect(self.InputHdf5)
        self._opCache.Input.connect(self._opFinalLabelSizeFilter.Output)

        # Connect our own outputs
        self.Output.connect(self._opFinalLabelSizeFilter.Output)
        self.CachedOutput.connect(self._opCache.Output)

        # Serialization outputs
        self.CleanBlocks.connect(self._opCache.CleanBlocks)
        self.OutputHdf5.connect(self._opCache.OutputHdf5)

        #self.InputChannel.connect( self._opChannelSelector.Output )

        # More debug outputs.  These all go through their own caches
        self._opBigRegionCache = OpCompressedCache(parent=self)
        self._opBigRegionCache.name = "_OpThresholdTwoLevels._opBigRegionCache"
        self._opBigRegionCache.Input.connect(self._opLowThresholder.Output)
        self.BigRegions.connect(self._opBigRegionCache.Output)

        self._opSmallRegionCache = OpCompressedCache(parent=self)
        self._opSmallRegionCache.name = "_OpThresholdTwoLevels._opSmallRegionCache"
        self._opSmallRegionCache.Input.connect(self._opHighThresholder.Output)
        self.SmallRegions.connect(self._opSmallRegionCache.Output)

        self._opFilteredSmallLabelsCache = OpCompressedCache(parent=self)
        self._opFilteredSmallLabelsCache.name = "_OpThresholdTwoLevels._opFilteredSmallLabelsCache"
        self._opFilteredSmallLabelsCache.Input.connect(
            self._opHighLabelSizeFilter.Output)
        self._opColorizeSmallLabels = OpColorizeLabels(parent=self)
        self._opColorizeSmallLabels.Input.connect(
            self._opFilteredSmallLabelsCache.Output)
        self.FilteredSmallLabels.connect(self._opColorizeSmallLabels.Output)

    def setupOutputs(self):
        def thresholdToUint8(thresholdValue, a):
            drange = self.InputImage.meta.drange
            if drange is not None:
                assert drange[0] == 0,\
                    "Don't know how to threshold data with this drange."
                thresholdValue *= drange[1]
            if a.dtype == numpy.uint8:
                # In-place (numpy optimizes this!)
                a[:] = (a > thresholdValue)
                return a
            else:
                return (a > thresholdValue).astype(numpy.uint8)

        self._opLowThresholder.Function.setValue(
            partial(thresholdToUint8, self.LowThreshold.value))
        self._opHighThresholder.Function.setValue(
            partial(thresholdToUint8, self.HighThreshold.value))

        # Output is already connected internally -- don't reassign new metadata
        # self.Output.meta.assignFrom(self.InputImage.meta)

        # Blockshape is the entire spatial volume (hysteresis thresholding is
        # a global operation)
        tagged_shape = self.Output.meta.getTaggedShape()
        tagged_shape['c'] = 1
        tagged_shape['t'] = 1
        self._opCache.BlockShape.setValue(tuple(tagged_shape.values()))
        self._opBigRegionCache.BlockShape.setValue(tuple(
            tagged_shape.values()))
        self._opSmallRegionCache.BlockShape.setValue(
            tuple(tagged_shape.values()))
        self._opFilteredSmallLabelsCache.BlockShape.setValue(
            tuple(tagged_shape.values()))

    def execute(self, slot, subindex, roi, result):
        assert False, "Shouldn't get here..."

    def propagateDirty(self, slot, subindex, roi):
        pass  # Nothing to do here

    def setInSlot(self, slot, subindex, roi, value):
        assert slot == self.InputHdf5,\
            "Invalid slot for setInSlot(): {}".format(slot.name)
Example #17
class OpVigraWatershed(Operator):
    """
    Operator wrapper for vigra's default watershed function.
    """
    name = "OpVigraWatershed"
    category = "Vigra"

    InputImage = InputSlot()
    PaddingWidth = InputSlot()  # Specifies the extra pixels around the border of the image to use when computing the watershed.
    # (Region is clipped to the size of the input image.)

    SeedImage = InputSlot(optional=True)

    Output = OutputSlot()

    def __init__(self, *args, **kwargs):
        super(OpVigraWatershed, self).__init__(*args, **kwargs)

        # Keep a dict of roi : max label
        self._maxLabels = {}
        self._lock = threading.Lock()

    @property
    def maxLabels(self):
        return self._maxLabels

    def clearMaxLabels(self):
        with self._lock:
            self._maxLabels = {}

    def setupOutputs(self):
        self.Output.meta.assignFrom(self.InputImage.meta)
        self.Output.meta.dtype = numpy.uint32

        #warnings.warn("FIXME: How can this drange be right?")
        #self.Output.meta.drange = (0,255)

        if self.SeedImage.ready():
            assert numpy.issubdtype(self.SeedImage.meta.dtype, numpy.uint32)
            assert self.SeedImage.meta.shape == self.InputImage.meta.shape, \
                "{} != {}".format(self.SeedImage.meta.shape, self.InputImage.meta.shape)

    def getSlicings(self, roi):
        """
        Pad the given roi to obtain a new slicing to use for obtaining input data.
        Return the padded slicing and the slicing that returns the original roi within the padded data.
        """
        tags = self.InputImage.meta.axistags
        pairs = list(
            zip([tag.key for tag in tags], list(zip(roi.start, roi.stop))))
        slices = [(k, slice(start, stop)) for (k, (start, stop)) in pairs]

        # Compute the watershed over a larger area than requested (padded area)
        padding = self.PaddingWidth.value
        paddedSlices = []  # The requested slicing + padding
        outputSlices = []  # The slicing to get the requested slicing from the padded data
        for i, (key, s) in enumerate(slices):
            p = s
            if key in 'xyz':
                p_start = max(s.start - padding, 0)
                p_stop = min(s.stop + padding, self.InputImage.meta.shape[i])
                p = slice(p_start, p_stop)

            paddedSlices += [p]
            o = slice(s.start - p.start, s.stop - p.start)
            outputSlices += [o]

        return paddedSlices, outputSlices

    def execute(self, slot, subindex, roi, result):
        assert slot == self.Output

        # Every request is computed on-the-fly.
        # (No caching)
        paddedSlices, outputSlices = self.getSlicings(roi)

        # Get input data
        inputRegion = self.InputImage[paddedSlices].wait()

        # Makes sure vigra will understand this type
        if inputRegion.dtype != numpy.uint8 and inputRegion.dtype != numpy.float32:
            inputRegion = inputRegion.astype('float32')

        # Convert to vigra array
        inputRegion = inputRegion.view(vigra.VigraArray)
        inputRegion.axistags = self.InputImage.meta.axistags

        # Reduce to 3-D (keep order of xyz axes)
        tags = self.InputImage.meta.axistags
        axes3d = "".join([tag.key for tag in tags if tag.key in 'xyz'])
        inputRegion = inputRegion.withAxes(*axes3d)
        logger.debug('inputRegion 3D shape:{}'.format(inputRegion.shape))

        logger.debug("roi={}".format(roi))
        logger.debug("paddedSlices={}".format(paddedSlices))
        logger.debug("outputSlices={}".format(outputSlices))

        # If we know the range of the data, then convert to uint8
        # so we can automatically benefit from vigra's "turbo" mode
        if self.InputImage.meta.drange is not None:
            drange = self.InputImage.meta.drange
            inputRegion = numpy.asarray(inputRegion, dtype=numpy.float32)
            inputRegion = vigra.taggedView(inputRegion, axes3d)
            inputRegion -= drange[0]
            inputRegion /= (drange[1] - drange[0])
            inputRegion *= 255.0
            inputRegion = inputRegion.astype(numpy.uint8)

        # This is where the magic happens
        if self.SeedImage.ready():
            seedImage = self.SeedImage[paddedSlices].wait()
            seedImage = seedImage.view(vigra.VigraArray)
            seedImage.axistags = tags
            seedImage = seedImage.withAxes(*axes3d)
            logger.debug("Input shape = {}, seed shape = {}".format(
                inputRegion.shape, seedImage.shape))
            logger.debug("Input axes = {}, seed axes = {}".format(
                inputRegion.axistags, seedImage.axistags))
            watershed, maxLabel = vigra.analysis.watersheds(inputRegion,
                                                            seeds=seedImage)
        else:
            watershed, maxLabel = vigra.analysis.watersheds(inputRegion)
        logger.debug("Finished Watershed")

        logger.debug("watershed 3D output shape={}".format(watershed.shape))
        logger.debug("maxLabel={}".format(maxLabel))

        # Promote back to 5-D
        watershed = vigra.taggedView(watershed, axes3d)
        watershed = watershed.withAxes(*[tag.key for tag in tags])
        logger.debug("watershed 5D shape: {}".format(watershed.shape))
        logger.debug("watershed axistags: {}".format(watershed.axistags))

        with self._lock:
            start = tuple(s.start for s in paddedSlices)
            stop = tuple(s.stop for s in paddedSlices)
            self._maxLabels[(start, stop)] = maxLabel

        #print numpy.sort(vigra.analysis.unique(watershed[outputSlices])).shape
        # Return only the region the user requested
        result[:] = watershed[outputSlices].view(numpy.ndarray).reshape(
            result.shape)
        return result

    def propagateDirty(self, inputSlot, subindex, roi):
        if not self.configured():
            self.Output.setDirty(slice(None))
        elif inputSlot.name == "InputImage" or inputSlot.name == "SeedImage":
            paddedSlicing, outputSlicing = self.getSlicings(roi)
            self.Output.setDirty(paddedSlicing)
        elif inputSlot.name == "PaddingWidth":
            self.Output.setDirty(slice(None))
        else:
            assert False, "Unknown input slot."
Example #18
class _OpCacheWrapper(Operator):
    name = "OpCacheWrapper"
    Input = InputSlot()

    Output = OutputSlot()

    InputHdf5 = InputSlot(optional=True)
    CleanBlocks = OutputSlot()
    OutputHdf5 = OutputSlot()

    def __init__(self, *args, **kwargs):
        super(_OpCacheWrapper, self).__init__(*args, **kwargs)
        op1 = OpReorderAxes(parent=self)
        op1.name = "op1"
        op2 = OpReorderAxes(parent=self)
        op2.name = "op2"

        op1.AxisOrder.setValue('xyzct')
        op2.AxisOrder.setValue('txyzc')

        op1.Input.connect(self.Input)
        self.Output.connect(op2.Output)

        self._op1 = op1
        self._op2 = op2
        self._cache = None

    def setupOutputs(self):
        self._disconnectInternals()

        # we need a new cache
        cache = OpCompressedCache(parent=self)
        cache.name = self.name + "WrappedCache"

        # connect cache outputs
        self.CleanBlocks.connect(cache.CleanBlocks)
        self.OutputHdf5.connect(cache.OutputHdf5)
        self._op2.Input.connect(cache.Output)

        # connect cache inputs
        cache.InputHdf5.connect(self.InputHdf5)
        cache.Input.connect(self._op1.Output)

        # set the cache block shape
        tagged_shape = self._op1.Output.meta.getTaggedShape()
        tagged_shape['t'] = 1
        tagged_shape['c'] = 1
        cacheshape = [tagged_shape[k] for k in 'xyzct']
        if _labeling_impl == "lazy":
            #HACK hardcoded block shape
            blockshape = numpy.minimum(cacheshape, 256)
        else:
            # use full spatial volume if not lazy
            blockshape = cacheshape
        cache.BlockShape.setValue(tuple(blockshape))

        self._cache = cache

    def execute(self, slot, subindex, roi, result):
        assert False

    def propagateDirty(self, slot, subindex, roi):
        pass

    def setInSlot(self, slot, subindex, key, value):
        assert slot == self.InputHdf5,\
            "setInSlot not implemented for slot {}".format(slot.name)
        assert self._cache is not None,\
            "setInSlot called before input was configured"
        self._cache.setInSlot(self._cache.InputHdf5, subindex, key, value)

    def _disconnectInternals(self):
        self.CleanBlocks.disconnect()
        self.OutputHdf5.disconnect()
        self._op2.Input.disconnect()

        if self._cache is not None:
            self._cache.InputHdf5.disconnect()
            self._cache.Input.disconnect()
            del self._cache
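
The wrapper is an axis-order sandwich: op1 reorders to 'xyzct' for the cache, op2 restores 'txyzc' for consumers. The round trip itself, sketched with plain vigra views:

import numpy
import vigra

a = vigra.taggedView(numpy.zeros((2, 8, 8, 8, 1), numpy.uint8), 'txyzc')
b = a.withAxes(*'xyzct')  # what op1 hands to the cache
c = b.withAxes(*'txyzc')  # what op2 restores on the way out
assert c.shape == a.shape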
Example #19
class OpSparseLabelArray(OpCache):
    name = "Sparse Label Array"
    description = "simple cache for sparse label arrays"

    inputSlots = [
        InputSlot("Input", optional=True),
        InputSlot("shape"),
        InputSlot("eraser"),
        InputSlot("deleteLabel", optional=True)
    ]

    outputSlots = [
        OutputSlot("Output"),
        OutputSlot("nonzeroValues"),
        OutputSlot("nonzeroCoordinates"),
        OutputSlot("maxLabel")
    ]

    def __init__(self, *args, **kwargs):
        super(OpSparseLabelArray, self).__init__(*args, **kwargs)
        self.lock = threading.Lock()
        self._denseArray = None
        self._sparseNZ = None
        self._oldShape = (0, )
        self._maxLabel = 0

    def usedMemory(self):
        if self._denseArray is not None:
            return self._denseArray.nbytes
        return 0

    def lastAccessTime(self):
        return 0
        #return self._last_access

    def generateReport(self, report):
        report.name = self.name
        #report.fractionOfUsedMemoryDirty = self.fractionOfUsedMemoryDirty()
        report.usedMemory = self.usedMemory()
        #report.lastAccessTime = self.lastAccessTime()
        report.dtype = self.Output.meta.dtype
        report.type = type(self)
        report.id = id(self)

    def setupOutputs(self):
        if (numpy.array(self._oldShape) != self.inputs["shape"].value).any():
            shape = self.inputs["shape"].value
            self._oldShape = shape
            self.outputs["Output"].meta.dtype = numpy.uint8
            self.outputs["Output"].meta.shape = shape

            # FIXME: Don't give arbitrary axistags.  Specify them correctly if you need them.
            #self.outputs["Output"].meta.axistags = vigra.defaultAxistags(len(shape))

            self.inputs["Input"].meta.shape = shape

            self.outputs["nonzeroValues"].meta.dtype = object
            self.outputs["nonzeroValues"].meta.shape = (1, )

            self.outputs["nonzeroCoordinates"].meta.dtype = object
            self.outputs["nonzeroCoordinates"].meta.shape = (1, )

            self._denseArray = numpy.zeros(shape, numpy.uint8)
            self._sparseNZ = blist.sorteddict()

        if self.inputs["deleteLabel"].ready(
        ) and self.inputs["deleteLabel"].value != -1:
            labelNr = self.inputs["deleteLabel"].value

            neutralElement = 0
            self.inputs["deleteLabel"].setValue(-1)  #reset state of inputslot
            self.lock.acquire()

            # Find the entries to remove
            updateNZ = numpy.nonzero(
                numpy.where(self._denseArray == labelNr, 1, 0))
            if updateNZ[0].size > 0:  # any entries to remove?
                # Convert to 1-D indexes for the raveled version
                updateNZRavel = numpy.ravel_multi_index(
                    updateNZ, self._denseArray.shape)
                # Zero out the entries we don't want any more
                self._denseArray.ravel()[updateNZRavel] = neutralElement
                # Remove the zeros from the sparse list
                for index in updateNZRavel:
                    self._sparseNZ.pop(index)
            # Labels are continuous values: Shift all higher label values down by 1.
            self._denseArray[:] = numpy.where(self._denseArray > labelNr,
                                              self._denseArray - 1,
                                              self._denseArray)
            self._maxLabel = self._denseArray.max()
            self.lock.release()
            self.outputs["nonzeroValues"].setDirty(slice(None))
            self.outputs["nonzeroCoordinates"].setDirty(slice(None))
            self.outputs["Output"].setDirty(slice(None))
            self.outputs["maxLabel"].setValue(self._maxLabel)

    def execute(self, slot, subindex, roi, result):
        key = roiToSlice(roi.start, roi.stop)

        self.lock.acquire()
        assert self.inputs["eraser"].ready() and self.inputs["shape"].ready(), \
            "OpSparseLabelArray: one of the necessary input slots is not ready: shape: %r, eraser: %r" % (
                self.inputs["eraser"].ready(), self.inputs["shape"].ready())
        if slot.name == "Output":
            result[:] = self._denseArray[key]
        elif slot.name == "nonzeroValues":
            result[0] = numpy.array(list(self._sparseNZ.values()))
        elif slot.name == "nonzeroCoordinates":
            result[0] = numpy.array(list(self._sparseNZ.keys()))
        elif slot.name == "maxLabel":
            result[0] = self._maxLabel
        self.lock.release()
        return result

    def setInSlot(self, slot, subindex, roi, value):
        key = roi.toSlice()
        assert value.dtype == self._denseArray.dtype, "Labels must be {}".format(
            self._denseArray.dtype)
        assert isinstance(value, numpy.ndarray)
        if type(value) != numpy.ndarray:
            # vigra.VigraArray doesn't handle advanced indexing correctly,
            #   so convert to numpy.ndarray first
            value = value.view(numpy.ndarray)

        shape = self.inputs["shape"].value
        eraseLabel = self.inputs["eraser"].value
        neutralElement = 0

        self.lock.acquire()
        #fix slicing of single dimensions:
        start, stop = sliceToRoi(key, shape, extendSingleton=False)
        start = start.floor()._asint()
        stop = stop.floor()._asint()

        tempKey = roiToSlice(start - start, stop - start)  #, hardBind = True)

        stop += numpy.where(stop - start == 0, 1, 0)

        key = roiToSlice(start, stop)

        updateShape = tuple(stop - start)

        update = self._denseArray[key].copy()

        update[tempKey] = value

        startRavel = numpy.ravel_multi_index(numpy.array(start, numpy.int32),
                                             shape)

        #insert values into dict
        updateNZ = numpy.nonzero(numpy.where(update != neutralElement, 1, 0))
        updateNZRavelSmall = numpy.ravel_multi_index(updateNZ, updateShape)

        if isinstance(value, numpy.ndarray):
            valuesNZ = value.ravel()[updateNZRavelSmall]
        else:
            valuesNZ = value

        updateNZRavel = numpy.ravel_multi_index(updateNZ, shape)
        updateNZRavel += startRavel

        self._denseArray.ravel()[updateNZRavel] = valuesNZ

        # Read back the stored values so the sparse dict matches the array's dtype
        valuesNZ = self._denseArray.ravel()[updateNZRavel]

        td = blist.sorteddict(zip(updateNZRavel.tolist(), valuesNZ.tolist()))

        self._sparseNZ.update(td)

        #remove values to be deleted
        updateNZ = numpy.nonzero(numpy.where(update == eraseLabel, 1, 0))
        if updateNZ[0].size > 0:  # any entries to erase?
            updateNZRavel = numpy.ravel_multi_index(updateNZ, shape)
            updateNZRavel += startRavel
            self._denseArray.ravel()[updateNZRavel] = neutralElement
            for index in updateNZRavel:
                self._sparseNZ.pop(index)

        # Update our maxlabel
        self._maxLabel = self._denseArray.max()

        self.lock.release()

        # Set our max label dirty if necessary
        self.outputs["maxLabel"].setValue(self._maxLabel)
        self.outputs["Output"].setDirty(key)

    def propagateDirty(self, dirtySlot, subindex, roi):
        if dirtySlot == self.Input:
            self.Output.setDirty(roi)
        else:
            # All other inputs are single-value inputs that will trigger
            #  a new call to setupOutputs, which already sets the outputs dirty.
            # (See above.)
            pass
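
The dense/sparse bookkeeping in setInSlot() rests on numpy.ravel_multi_index: flat indices into the full volume double as keys of the blist.sorteddict. A minimal worked example:

import numpy

shape = (4, 5)
dense = numpy.zeros(shape, numpy.uint8)
dense[1, 2] = 7

nz = numpy.nonzero(dense)                  # per-axis coordinate arrays
flat = numpy.ravel_multi_index(nz, shape)  # array([7]) == 1 * 5 + 2
assert dense.ravel()[flat[0]] == 7         # same element through the flat view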
Example #20
class OpThresholdTwoLevels(Operator):
    name = "OpThresholdTwoLevels"

    RawInput = InputSlot(optional=True)  # Display only
    InputChannelColors = InputSlot(optional=True)  # Display only

    InputImage = InputSlot()
    MinSize = InputSlot(stype='int', value=10)
    MaxSize = InputSlot(stype='int', value=1000000)
    HighThreshold = InputSlot(stype='float', value=0.5)
    LowThreshold = InputSlot(stype='float', value=0.2)
    SingleThreshold = InputSlot(stype='float', value=0.5)
    SmootherSigma = InputSlot(value={'x': 1.0, 'y': 1.0, 'z': 1.0})
    Channel = InputSlot(value=0)
    CurOperator = InputSlot(stype='int', value=0)

    ## Graph-Cut options ##

    SingleThresholdGC = InputSlot(stype='float', value=0.5)

    Beta = InputSlot(value=.2)

    # apply thresholding before graph-cut
    UsePreThreshold = InputSlot(stype='bool', value=True)

    # margin around single object (only graph-cut)
    Margin = InputSlot(value=numpy.asarray((20, 20, 20)))

    ## Output slots ##

    Output = OutputSlot()

    CachedOutput = OutputSlot()  # For the GUI (blockwise-access)

    # For serialization
    InputHdf5 = InputSlot(optional=True)

    CleanBlocks = OutputSlot()

    OutputHdf5 = OutputSlot()

    ## Debug outputs

    InputChannel = OutputSlot()
    Smoothed = OutputSlot()
    BigRegions = OutputSlot()
    SmallRegions = OutputSlot()
    FilteredSmallLabels = OutputSlot()
    BeforeSizeFilter = OutputSlot()

    def __init__(self, *args, **kwargs):
        super(OpThresholdTwoLevels, self).__init__(*args, **kwargs)

        self.InputImage.notifyReady(self.checkConstraints)

        self._opReorder1 = OpReorderAxes(parent=self)
        self._opReorder1.AxisOrder.setValue('txyzc')
        self._opReorder1.Input.connect(self.InputImage)

        self._opChannelSelector = OpSingleChannelSelector(parent=self)
        self._opChannelSelector.Input.connect(self._opReorder1.Output)
        self._opChannelSelector.Index.connect(self.Channel)

        # anisotropic gauss
        self._opSmoother = OpAnisotropicGaussianSmoothing5d(parent=self)
        self._opSmoother.Sigmas.connect(self.SmootherSigma)
        self._opSmoother.Input.connect(self._opChannelSelector.Output)

        # debug output
        self.Smoothed.connect(self._opSmoother.Output)

        # single threshold operator
        self.opThreshold1 = _OpThresholdOneLevel(parent=self)
        self.opThreshold1.Threshold.connect(self.SingleThreshold)
        self.opThreshold1.MinSize.connect(self.MinSize)
        self.opThreshold1.MaxSize.connect(self.MaxSize)

        # double threshold operator
        self.opThreshold2 = _OpThresholdTwoLevels(parent=self)
        self.opThreshold2.MinSize.connect(self.MinSize)
        self.opThreshold2.MaxSize.connect(self.MaxSize)
        self.opThreshold2.LowThreshold.connect(self.LowThreshold)
        self.opThreshold2.HighThreshold.connect(self.HighThreshold)

        # Identity-preserving hysteresis thresholding
        self.opIpht = OpIpht(parent=self)
        self.opIpht.MinSize.connect(self.MinSize)
        self.opIpht.MaxSize.connect(self.MaxSize)
        self.opIpht.LowThreshold.connect(self.LowThreshold)
        self.opIpht.HighThreshold.connect(self.HighThreshold)
        self.opIpht.InputImage.connect(self._opSmoother.Output)

        if haveGraphCut():
            self.opThreshold1GC = _OpThresholdOneLevel(parent=self)
            self.opThreshold1GC.Threshold.connect(self.SingleThresholdGC)
            self.opThreshold1GC.MinSize.connect(self.MinSize)
            self.opThreshold1GC.MaxSize.connect(self.MaxSize)

            self.opObjectsGraphCut = OpObjectsSegment(parent=self)
            self.opObjectsGraphCut.Prediction.connect(self.Smoothed)
            self.opObjectsGraphCut.LabelImage.connect(
                self.opThreshold1GC.Output)
            self.opObjectsGraphCut.Beta.connect(self.Beta)
            self.opObjectsGraphCut.Margin.connect(self.Margin)

            self.opGraphCut = OpGraphCut(parent=self)
            self.opGraphCut.Prediction.connect(self.Smoothed)
            self.opGraphCut.Beta.connect(self.Beta)

        self._op5CacheOutput = OpReorderAxes(parent=self)

        self._opReorder2 = OpReorderAxes(parent=self)
        self.Output.connect(self._opReorder2.Output)

        #cache our own output, don't propagate from internal operator
        self._cache = _OpCacheWrapper(parent=self)
        self._cache.name = "OpThresholdTwoLevels.OpCacheWrapper"
        self.CachedOutput.connect(self._cache.Output)

        # Serialization slots
        self._cache.InputHdf5.connect(self.InputHdf5)
        self.CleanBlocks.connect(self._cache.CleanBlocks)
        self.OutputHdf5.connect(self._cache.OutputHdf5)

        #Debug outputs
        self.InputChannel.connect(self._opChannelSelector.Output)

    def setupOutputs(self):

        self._opReorder2.AxisOrder.setValue(self.InputImage.meta.getAxisKeys())

        # propagate drange
        self.opThreshold1.InputImage.meta.drange = self.InputImage.meta.drange
        if haveGraphCut():
            self.opThreshold1GC.InputImage.meta.drange = self.InputImage.meta.drange
        self.opThreshold2.InputImage.meta.drange = self.InputImage.meta.drange

        self._disconnectAll()

        curIndex = self.CurOperator.value

        if curIndex == 0:
            outputSlot = self._connectForSingleThreshold(self.opThreshold1)
        elif curIndex == 1:
            outputSlot = self._connectForTwoLevelThreshold()
        elif curIndex == 2:
            outputSlot = self._connectForGraphCut()
        elif curIndex == 3:
            outputSlot = self.opIpht.Output
        else:
            raise ValueError(
                "Unknown index {} for current tab.".format(curIndex))

        self._opReorder2.Input.connect(outputSlot)
        # force the cache to emit a dirty signal
        self._cache.Input.connect(outputSlot)
        self._cache.Input.setDirty(slice(None))

    def checkConstraints(self, *args):
        if self._opReorder1.Output.ready():
            numChannels = self._opReorder1.Output.meta.getTaggedShape()['c']
            if self.Channel.value >= numChannels:
                raise DatasetConstraintError(
                    "Two-Level Thresholding",
                    "Your project is configured to select data from channel"
                    " #{}, but your input data only has {} channels.".format(
                        self.Channel.value, numChannels))

    def _disconnectAll(self):
        # start from back
        for slot in [
                self.BigRegions, self.SmallRegions, self.FilteredSmallLabels,
                self.BeforeSizeFilter
        ]:
            slot.disconnect()
            slot.meta.NOTREADY = True
        self._opReorder2.Input.disconnect()
        if haveGraphCut():
            self.opThreshold1GC.InputImage.disconnect()
        self.opThreshold1.InputImage.disconnect()
        self.opThreshold2.InputImage.disconnect()

    def _connectForSingleThreshold(self, threshOp):
        # connect the operators for SingleThreshold
        self.BeforeSizeFilter.connect(threshOp.BeforeSizeFilter)
        self.BeforeSizeFilter.meta.NOTREADY = None
        threshOp.InputImage.connect(self._opSmoother.Output)
        return threshOp.Output

    def _connectForTwoLevelThreshold(self):
        # connect the operators for TwoLevelThreshold
        self.BigRegions.connect(self.opThreshold2.BigRegions)
        self.SmallRegions.connect(self.opThreshold2.SmallRegions)
        self.FilteredSmallLabels.connect(self.opThreshold2.FilteredSmallLabels)
        for slot in [
                self.BigRegions, self.SmallRegions, self.FilteredSmallLabels
        ]:
            slot.meta.NOTREADY = None
        self.opThreshold2.InputImage.connect(self._opSmoother.Output)

        return self.opThreshold2.Output

    def _connectForGraphCut(self):
        assert haveGraphCut(), "Module for graph cut is not available"
        if self.UsePreThreshold.value:
            self._connectForSingleThreshold(self.opThreshold1GC)
            return self.opObjectsGraphCut.Output
        else:
            return self.opGraphCut.Output

    # raise an error if setInSlot is called, we do not pre-cache input
    #def setInSlot(self, slot, subindex, roi, value):
    #pass

    def execute(self, slot, subindex, roi, destination):
        assert False, "Shouldn't get here."

    def propagateDirty(self, slot, subindex, roi):
        # dirtiness propagation is handled in the sub-operators
        pass

    def setInSlot(self, slot, subindex, roi, value):
        assert slot == self.InputHdf5,\
            "[{}] Wrong slot for setInSlot(): {}".format(self.name,
                                                         slot)
        pass
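
A possible configuration sketch (not from the source), using the slot names defined above: select the two-level method via CurOperator and read the blockwise-cached result.

import numpy
import vigra
from lazyflow.graph import Graph

prob = vigra.taggedView(
    numpy.random.rand(1, 64, 64, 64, 1).astype(numpy.float32), 'txyzc')

op = OpThresholdTwoLevels(graph=Graph())
op.InputImage.setValue(prob)
op.CurOperator.setValue(1)  # 0=single, 1=two-level, 2=graph-cut, 3=IPHT
op.LowThreshold.setValue(0.2)
op.HighThreshold.setValue(0.5)
result = op.CachedOutput[:].wait()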
Example #21
class OpDataSelectionGroup(Operator):
    # Inputs
    ProjectFile = InputSlot(stype='object', optional=True)
    ProjectDataGroup = InputSlot(stype='string', optional=True)
    WorkingDirectory = InputSlot(stype='filestring')
    DatasetRoles = InputSlot(stype='object')

    # Must mark as optional because not all subslots are required.
    DatasetGroup = InputSlot(stype='object', level=1, optional=True)

    # Outputs
    ImageGroup = OutputSlot(level=1)

    # These output slots are provided as a convenience, since otherwise it is tricky to create a lane-wise multislot of
    # level-1 for only a single role.
    # (It can be done, but requires OpTransposeSlots to invert the level-2 multislot indexes...)
    Image = OutputSlot()  # The first dataset. Equivalent to ImageGroup[0]
    Image1 = OutputSlot()  # The second dataset. Equivalent to ImageGroup[1]
    Image2 = OutputSlot()  # The third dataset. Equivalent to ImageGroup[2]
    AllowLabels = OutputSlot(stype='bool')  # Pulled from the first dataset only.

    _NonTransposedImageGroup = OutputSlot(level=1)

    # Must be the LAST slot declared in this class.
    # When the shell detects that this slot has been resized,
    #  it assumes all the others have already been resized.
    ImageName = OutputSlot()  # Name of the first dataset is used.  Other names are ignored.

    def __init__(self, forceAxisOrder=None, *args, **kwargs):
        super(OpDataSelectionGroup, self).__init__(*args, **kwargs)
        self._opDatasets = None
        self._roles = []
        self._forceAxisOrder = forceAxisOrder

        def handleNewRoles(*args):
            self.DatasetGroup.resize(len(self.DatasetRoles.value))

        self.DatasetRoles.notifyReady(handleNewRoles)

    def setupOutputs(self):
        # Create internal operators
        if self.DatasetRoles.value != self._roles:
            self._roles = self.DatasetRoles.value
            # Clean up the old operators
            self.ImageGroup.disconnect()
            self.Image.disconnect()
            self.Image1.disconnect()
            self.Image2.disconnect()
            self._NonTransposedImageGroup.disconnect()
            if self._opDatasets is not None:
                self._opDatasets.cleanUp()

            self._opDatasets = OperatorWrapper(
                OpDataSelection,
                parent=self,
                operator_kwargs={'forceAxisOrder': self._forceAxisOrder},
                broadcastingSlotNames=[
                    'ProjectFile', 'ProjectDataGroup', 'WorkingDirectory'
                ])
            self.ImageGroup.connect(self._opDatasets.Image)
            self._NonTransposedImageGroup.connect(
                self._opDatasets._NonTransposedImage)
            self._opDatasets.Dataset.connect(self.DatasetGroup)
            self._opDatasets.ProjectFile.connect(self.ProjectFile)
            self._opDatasets.ProjectDataGroup.connect(self.ProjectDataGroup)
            self._opDatasets.WorkingDirectory.connect(self.WorkingDirectory)

        for role_index, opDataSelection in enumerate(self._opDatasets):
            opDataSelection.RoleName.setValue(self._roles[role_index])

        if len(self._opDatasets.Image) > 0:
            self.Image.connect(self._opDatasets.Image[0])

            if len(self._opDatasets.Image) >= 2:
                self.Image1.connect(self._opDatasets.Image[1])
            else:
                self.Image1.disconnect()
                self.Image1.meta.NOTREADY = True

            if len(self._opDatasets.Image) >= 3:
                self.Image2.connect(self._opDatasets.Image[2])
            else:
                self.Image2.disconnect()
                self.Image2.meta.NOTREADY = True

            self.ImageName.connect(self._opDatasets.ImageName[0])
            self.AllowLabels.connect(self._opDatasets.AllowLabels[0])
        else:
            self.Image.disconnect()
            self.Image1.disconnect()
            self.Image2.disconnect()
            self.ImageName.disconnect()
            self.AllowLabels.disconnect()
            self.Image.meta.NOTREADY = True
            self.Image1.meta.NOTREADY = True
            self.Image2.meta.NOTREADY = True
            self.ImageName.meta.NOTREADY = True
            self.AllowLabels.meta.NOTREADY = True

    def execute(self, slot, subindex, rroi, result):
        assert False, "Unknown or unconnected output slot: {}".format(
            slot.name)

    def propagateDirty(self, slot, subindex, roi):
        # Output slots are directly connected to internal operators
        pass
Example #22
class OpLabeledThreshold(Operator):
    Input = InputSlot()  # Must have exactly 1 channel
    CoreLabels = InputSlot(optional=True)  # Not used for 'Simple' method.
    Method = InputSlot(value=ThresholdMethod.SIMPLE)
    FinalThreshold = InputSlot(value=0.2)
    GraphcutBeta = InputSlot(value=0.2)  # Graphcut only

    Output = OutputSlot()

    def __init__(self, *args, **kwargs):
        super(OpLabeledThreshold, self).__init__(*args, **kwargs)

        execute_funcs = {}
        execute_funcs[ThresholdMethod.SIMPLE] = self._execute_SIMPLE
        execute_funcs[ThresholdMethod.HYSTERESIS] = self._execute_HYSTERESIS
        execute_funcs[ThresholdMethod.GRAPHCUT] = self._execute_GRAPHCUT
        execute_funcs[ThresholdMethod.IPHT] = self._execute_IPHT
        self.execute_funcs = execute_funcs

    def setupOutputs(self):
        assert self.Input.meta.getAxisKeys() == list("tzyxc")
        assert self.Input.meta.shape[-1] == 1
        if self.CoreLabels.ready():
            assert self.CoreLabels.meta.getAxisKeys() == list("tzyxc")

        self.Output.meta.assignFrom(self.Input.meta)
        self.Output.meta.dtype = np.uint32

    def propagateDirty(self, slot, subindex, roi):
        self.Output.setDirty()

    def execute(self, slot, subindex, roi, result):
        result = vigra.taggedView(result, self.Output.meta.axistags)

        # Iterate over time slices to avoid connected component problems.
        for t_index, t in enumerate(range(roi.start[0], roi.stop[0])):
            t_slice_roi = roi.copy()
            t_slice_roi.start[0] = t
            t_slice_roi.stop[0] = t + 1

            result_slice = result[t_index:t_index + 1]
            self.execute_funcs[self.Method.value](t_slice_roi, result_slice)

    def _execute_SIMPLE(self, roi, result):
        assert result.shape[0] == 1
        assert tuple(roi.stop - roi.start) == result.shape

        final_threshold = self.FinalThreshold.value

        data = self.Input(roi.start, roi.stop).wait()
        data = vigra.taggedView(data, self.Input.meta.axistags)

        result = vigra.taggedView(result, self.Output.meta.axistags)

        binary = (data >= final_threshold).view(np.uint8)
        vigra.analysis.labelMultiArrayWithBackground(binary[0, ..., 0],
                                                     out=result[0, ..., 0])

    def _execute_HYSTERESIS(self, roi, result):
        self._execute_SIMPLE(roi, result)
        final_labels = vigra.taggedView(result, self.Output.meta.axistags)

        core_labels = self.CoreLabels(roi.start, roi.stop).wait()
        core_labels = vigra.taggedView(core_labels,
                                       self.CoreLabels.meta.axistags)

        select_labels(core_labels, final_labels)  # Edits final_labels in-place

    def _execute_IPHT(self, roi, result):
        core_labels = self.CoreLabels(roi.start, roi.stop).wait()
        core_labels = vigra.taggedView(core_labels,
                                       self.CoreLabels.meta.axistags)

        data = self.Input(roi.start, roi.stop).wait()
        data = vigra.taggedView(data, self.Input.meta.axistags)

        final_threshold = self.FinalThreshold.value
        result = vigra.taggedView(result, self.Output.meta.axistags)
        threshold_from_cores(data[0, ..., 0],
                             core_labels[0, ..., 0],
                             final_threshold,
                             out=result[0, ..., 0])

    def _execute_GRAPHCUT(self, roi, result):
        data = self.Input(roi.start, roi.stop).wait()
        data = vigra.taggedView(data, self.Input.meta.axistags)
        data_zyx = data[0, ..., 0]

        beta = self.GraphcutBeta.value
        ft = self.FinalThreshold.value

        # The segmentGC() function will implicitly threshold at 0.5,
        # but we want to respect the user's FinalThreshold setting.
        # Here, we scale from input pixels --> gc potentials in the following way:
        #
        # 0.0..FT --> 0.0..0.5
        # FT..1.0 --> 0.5..1.0
        #
        # For instance, input pixels that match the user's FT exactly will map to 0.5,
        # and graphcut will place them on the threshold border.

        above_threshold_mask = data_zyx >= ft
        below_threshold_mask = ~above_threshold_mask

        data_zyx[below_threshold_mask] *= old_div(0.5, ft)
        data_zyx[above_threshold_mask] = 0.5 + 0.5 * old_div(
            (data_zyx[above_threshold_mask] - ft), (1 - ft))

        binary_seg_zyx = segmentGC(data_zyx, beta).astype(np.uint8)
        del data_zyx
        vigra.analysis.labelMultiArrayWithBackground(binary_seg_zyx,
                                                     out=result[0, ..., 0])
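
A worked check of the FinalThreshold rescaling used in _execute_GRAPHCUT (this assumes the 0.5-scaled mapping shown above): pixels at FT land exactly on segmentGC()'s implicit 0.5 decision boundary, and 1.0 maps to 1.0.

import numpy

ft = 0.2
x = numpy.array([0.0, 0.1, 0.2, 0.6, 1.0], numpy.float32)
above = x >= ft
y = x.copy()
y[~above] *= 0.5 / ft                              # 0.0..FT  -> 0.0..0.5
y[above] = 0.5 + 0.5 * (y[above] - ft) / (1 - ft)  # FT..1.0 -> 0.5..1.0
assert numpy.allclose(y, [0.0, 0.25, 0.5, 0.75, 1.0])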
Example #23
class OpCompressedUserLabelArray(OpUnmanagedCompressedCache):
    """
    A subclass of OpUnmanagedCompressedCache that is suitable for storing user-drawn label pixels.
    (This is not a 'managed' cache because its data must never be deleted by the memory manager.)
    Note that setInSlot has special functionality (only non-zero pixels are written, and there is also an "eraser" pixel value).

    See note below about blockshape changes.
    """
    #Input = InputSlot()
    shape = InputSlot(optional=True) # Should not be used.
    eraser = InputSlot()
    deleteLabel = InputSlot(optional = True)
    blockShape = InputSlot() # If the blockshape is changed after labels have been stored, all cache data is lost.

    #Output = OutputSlot()
    #nonzeroValues = OutputSlot()
    #nonzeroCoordinates = OutputSlot()
    nonzeroBlocks = OutputSlot()
    #maxLabel = OutputSlot()
    
    Projection2D = OutputSlot(allow_mask=True) # A somewhat magic output that returns a projection of all
                                               # label data underneath a given roi, from all slices.
                                               # If, for example, a 256x1x256 tile is requested from this slot,
                                               # It will return a projection of ALL labels that fall within the 256 x ... x 256 tile.
                                               # (The projection axis is *inferred* from the shape of the requested data).
                                               # The projection data is float32 between 0.0 and 1.0, where:
                                               # - Exactly 0.0 means "no labels under this pixel"
                                               # - 1/256.0 means "labels in the first slice"
                                               # - ...
                                               # - 1.0 means "last slice"
                                               # The output is suitable for display in a colortable.
    
    def __init__(self, *args, **kwargs):
        self._blockshape = None
        self._label_to_purge = 0
        super(OpCompressedUserLabelArray, self).__init__( *args, **kwargs )

        # ignoring the ideal chunk shape is ok because we use the input only
        # to get the volume shape
        self._ignore_ideal_blockshape = True
    
    def clearLabel(self, label_value):
        """
        Clear (reset to 0) all pixels of the given label value.
        Unlike using the deleteLabel slot, this function does not "shift down" all labels above this label value.
        """
        self._purge_label( label_value, False )

    def mergeLabels(self, from_label, into_label):
        self._purge_label(from_label, True, into_label)
    
    def setupOutputs(self):
        # Due to a temporary naming clash, pass our subclass blockshape to the superclass
        # TODO: Fix this by renaming the BlockShape slots to be consistent.
        self.BlockShape.setValue( self.blockShape.value )
        
        super( OpCompressedUserLabelArray, self ).setupOutputs()
        if self.Output.meta.NOTREADY:
            self.nonzeroBlocks.meta.NOTREADY = True
            self.Projection2D.meta.NOTREADY = True
            return
        self.nonzeroBlocks.meta.dtype = object
        self.nonzeroBlocks.meta.shape = (1,)
        
        # Overwrite the Output metadata (should be uint8 no matter what the input data is...)
        self.Output.meta.assignFrom(self.Input.meta)
        self.Output.meta.dtype = numpy.uint8
        self.Output.meta.shape = self.Input.meta.shape[:-1] + (1,)
        self.Output.meta.drange = (0,255)
        self.OutputHdf5.meta.assignFrom(self.Output.meta)
        
        # The Projection2D slot is a strange beast:
        # It appears to have the same output shape as any other output slot,
        #  but it can only be accessed in 2D slices.
        self.Projection2D.meta.assignFrom(self.Output.meta)
        self.Projection2D.meta.dtype = numpy.float32
        self.Projection2D.meta.drange = (0.0, 1.0)
        

        # Overwrite the blockshape
        if self._blockshape is None:
            self._blockshape = numpy.minimum( self.BlockShape.value, self.Output.meta.shape )
        elif self.blockShape.value != self._blockshape:
            nonzero_blocks_destination = [None]
            self._execute_nonzeroBlocks(nonzero_blocks_destination)
            nonzero_blocks = nonzero_blocks_destination[0]
            if len(nonzero_blocks) > 0:
                raise RuntimeError( "You are not permitted to reconfigure the labeling operator after you've already stored labels in it." )

        # Overwrite chunkshape now that blockshape has been overwritten
        self._chunkshape = self._chooseChunkshape(self._blockshape)

        self._eraser_magic_value = self.eraser.value
        
        # Are we being told to delete a label?
        if self.deleteLabel.ready():
            new_purge_label = self.deleteLabel.value
            if self._label_to_purge != new_purge_label:
                self._label_to_purge = new_purge_label
                if self._label_to_purge > 0:
                    self._purge_label( self._label_to_purge, True )

    def _purge_label(self, label_to_purge, decrement_remaining, replacement_value=0):
        """
        Scan through all labeled pixels.
        (1) Reassign all pixels of the given value (set to replacement_value)
        (2) If decrement_remaining=True, decrement all labels above that 
            value so the set of stored labels remains consecutive.
            Note that the decrement is performed AFTER replacement.
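
        For example, purging label 2 with decrement_remaining=True turns
        [0, 1, 2, 3, 3] into [0, 1, 0, 2, 2]: the 2s are replaced first,
        then every remaining label above 2 is decremented.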
        """
        changed_block_rois = []
        #stored_block_rois = self.CleanBlocks.value
        stored_block_roi_destination = [None]
        self.execute(self.CleanBlocks, (), SubRegion( self.Output, (0,),(1,) ), stored_block_roi_destination)
        stored_block_rois = stored_block_roi_destination[0]

        for block_roi in stored_block_rois:
            # Get data
            block_shape = numpy.subtract( block_roi[1], block_roi[0] )
            block = self.Output.stype.allocateDestination(SubRegion(self.Output, *roiFromShape(block_shape)))

            self.execute(self.Output, (), SubRegion( self.Output, *block_roi ), block)

            # Locate pixels to change
            matching_label_coords = numpy.nonzero( block == label_to_purge )
            changed = len(matching_label_coords[0]) > 0

            # Change the data: replacement first, then the optional decrement
            block[matching_label_coords] = replacement_value
            if decrement_remaining:
                coords_to_decrement = block > label_to_purge
                block[coords_to_decrement] -= 1
                changed = changed or bool(coords_to_decrement.any())

            # Update cache with the new data (only if something really changed)
            if changed:
                super( OpCompressedUserLabelArray, self )._setInSlotInput( self.Input, (), SubRegion( self.Output, *block_roi ), block, store_zero_blocks=False )
                changed_block_rois.append( block_roi )

        for block_roi in changed_block_rois:
            # FIXME: Shouldn't this dirty notification be handled in OpUnmanagedCompressedCache?
            self.Output.setDirty( *block_roi )
    
    def execute(self, slot, subindex, roi, destination):
        if slot == self.Output:
            self._executeOutput(roi, destination)
        elif slot == self.nonzeroBlocks:
            self._execute_nonzeroBlocks(destination)
        elif slot == self.Projection2D:
            self._executeProjection2D(roi, destination)
        else:
            return super( OpCompressedUserLabelArray, self ).execute( slot, subindex, roi, destination )

    def _executeOutput(self, roi, destination):
        assert len(roi.stop) == len(self.Output.meta.shape), \
            "roi: {} has the wrong number of dimensions for Output shape: {}"\
            "".format( roi, self.Output.meta.shape )
        assert numpy.less_equal(roi.stop, self.Output.meta.shape).all(), \
            "roi: {} is out-of-bounds for Output shape: {}"\
            "".format( roi, self.Output.meta.shape )
        
        block_starts = getIntersectingBlocks( self._blockshape, (roi.start, roi.stop) )
        self._copyData(roi, destination, block_starts)
        return destination

    def _execute_nonzeroBlocks(self, destination):
        stored_block_rois_destination = [None]
        self._executeCleanBlocks( stored_block_rois_destination )
        stored_block_rois = stored_block_rois_destination[0]
        block_slicings = [ roiToSlice(*block_roi) for block_roi in stored_block_rois ]
        destination[0] = block_slicings

    def _executeProjection2D(self, roi, destination):
        assert sum(TinyVector(destination.shape) > 1) <= 2, "Projection result must have at most two non-singleton axes"
        
        # First, we have to determine which axis we are projecting along.
        # We infer this from the shape of the roi.
        # For example, if the roi is of shape 
        #  zyx = (1,256,256), then we know we're projecting along Z
        # If more than one axis has a width of 1, then we choose an 
        #  axis according to the following priority order: zyxt
        tagged_input_shape = self.Output.meta.getTaggedShape()
        tagged_result_shape = collections.OrderedDict( zip( tagged_input_shape.keys(),
                                                            destination.shape ) )
        nonprojection_axes = []
        for key in tagged_input_shape.keys():
            if (key == 'c' or tagged_input_shape[key] == 1 or tagged_result_shape[key] > 1):
                nonprojection_axes.append( key )
            
        possible_projection_axes = set(tagged_input_shape) - set(nonprojection_axes)
        if len(possible_projection_axes) == 0:
            # If the image is 2D to begin with, 
            #   then the projection is simply the same as the normal output,
            #   EXCEPT it is made binary
            self.Output(roi.start, roi.stop).writeInto(destination).wait()
            
            # make binary
            numpy.greater(destination, 0, out=destination)
            return
        
        for k in 'zyxt':
            if k in possible_projection_axes:
                projection_axis_key = k
                break

        # Now we know which axis we're projecting along.
        # Proceed with the projection, working blockwise to avoid unnecessary work in unlabeled blocks
        
        projection_axis_index = self.Output.meta.getAxisKeys().index(projection_axis_key)
        projection_length = tagged_input_shape[projection_axis_key]
        input_roi = roi.copy()
        input_roi.start[projection_axis_index] = 0
        input_roi.stop[projection_axis_index] = projection_length

        destination[:] = 0.0

        # Get the logical blocking.
        block_starts = getIntersectingBlocks( self._blockshape, (input_roi.start, input_roi.stop) )

        # (Parallelism wouldn't help here: h5py will serialize these requests anyway)
        block_starts = map( tuple, block_starts )
        for block_start in block_starts:
            if block_start not in self._cacheFiles:
                # No label data in this block.  Move on.
                continue

            entire_block_roi = getBlockBounds( self.Output.meta.shape, self._blockshape, block_start )

            # This block's portion of the roi
            intersecting_roi = getIntersection( (input_roi.start, input_roi.stop), entire_block_roi )
            
            # Compute slicing within the deep array and slicing within this block
            deep_relative_intersection = numpy.subtract(intersecting_roi, input_roi.start)
            block_relative_intersection = numpy.subtract(intersecting_roi, block_start)
            block_relative_intersection_slicing = roiToSlice(*block_relative_intersection)

            block = self._getBlockDataset( entire_block_roi )
            deep_data = None
            if self.Output.meta.has_mask:
                deep_data = numpy.ma.masked_array(
                    block["data"][block_relative_intersection_slicing],
                    mask=block["mask"][block_relative_intersection_slicing],
                    fill_value=block["fill_value"][()],
                    shrink=False
                )
            else:
                deep_data = block[block_relative_intersection_slicing]

            # make binary and convert to float (must copy)
            deep_data_float = deep_data.astype(numpy.float32)
            deep_data_float[deep_data_float.nonzero()] = 1

            # multiply by slice-index
            deep_data_view = numpy.rollaxis(deep_data_float, projection_axis_index, 0)

            min_deep_slice_index = deep_relative_intersection[0][projection_axis_index]
            max_deep_slice_index = deep_relative_intersection[1][projection_axis_index]
            
            def calc_color_value(slice_index):
                # Note 1: We assume that the colortable has at least 256 entries in it,
                #           so, we try to ensure that all colors are above 1/256 
                #           (we don't want colors in low slices to be rounded to 0)
                # Note 2: Ideally, we'd use a min projection in the code below, so that 
                #           labels in the "back" slices would appear occluded.  But the 
                #           min projection would favor 0.0.  Instead, we invert the 
                #           relationship between color and slice index, do a max projection, 
                #           and then re-invert the colors after everything is done.
                #           Hence, this function starts with (1.0 - ...)
                return (1.0 - (float(slice_index) / projection_length)) * (1.0 - 1.0/255) + 1.0/255.0
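            # e.g. with projection_length=256: slice 0 maps to 1.0 and deeper
            # slices approach 1/255, so no labeled slice can be rounded down
            # to the "no label" color 0.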
            min_color_value = calc_color_value(min_deep_slice_index)
            max_color_value = calc_color_value(max_deep_slice_index)
            
            num_slices = max_deep_slice_index - min_deep_slice_index
            deep_data_view *= numpy.linspace( min_color_value, max_color_value, num=num_slices )\
                              [ (slice(None),) + (None,)*(deep_data_view.ndim-1) ]

            # Take the max projection of this block's data.
            block_max_projection = deep_data_float.max(axis=projection_axis_index)
            block_max_projection = numpy.ma.expand_dims(block_max_projection, axis=projection_axis_index)

            # Merge this block's projection into the overall projection.
            destination_relative_intersection = numpy.array(deep_relative_intersection)
            destination_relative_intersection[:, projection_axis_index] = (0,1)
            destination_relative_intersection_slicing = roiToSlice(*destination_relative_intersection)

            destination_subview = destination[destination_relative_intersection_slicing]
            numpy.maximum(block_max_projection, destination_subview, out=destination_subview)
            
            # Invert the nonzero pixels so increasing colors correspond to increasing slices.
            # See comment in calc_color_value(), above.
            destination_subview[destination_subview.nonzero()] -= 1
            destination_subview[()] = -destination_subview

        return

    def _copyData(self, roi, destination, block_starts):
        """
        Copy data from each block into the destination array.
        For blocks that aren't currently stored, just write zeros.
        """
        # (Parallelism not needed here: h5py will serialize these requests anyway)
        block_starts = map( tuple, block_starts )
        for block_start in block_starts:
            entire_block_roi = getBlockBounds( self.Output.meta.shape, self._blockshape, block_start )

            # This block's portion of the roi
            intersecting_roi = getIntersection( (roi.start, roi.stop), entire_block_roi )
            
            # Compute slicing within destination array and slicing within this block
            destination_relative_intersection = numpy.subtract(intersecting_roi, roi.start)
            block_relative_intersection = numpy.subtract(intersecting_roi, block_start)
            destination_relative_intersection_slicing = roiToSlice(*destination_relative_intersection)
            block_relative_intersection_slicing = roiToSlice(*block_relative_intersection)
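            # e.g. with roi = ([5], [20]), blockshape 16, block_start = (16,):
            # entire_block_roi = ([16], [32]), intersecting_roi = ([16], [20]),
            # destination slice = [11:15], block slice = [0:4].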

            if block_start in self._cacheFiles:
                # Copy from block to destination
                dataset = self._getBlockDataset( entire_block_roi )

                if self.Output.meta.has_mask:
                    destination[ destination_relative_intersection_slicing ] = dataset["data"][ block_relative_intersection_slicing ]
                    destination.mask[ destination_relative_intersection_slicing ] = dataset["mask"][ block_relative_intersection_slicing ]
                    destination.fill_value = dataset["fill_value"][()]
                else:
                    destination[ destination_relative_intersection_slicing ] = dataset[ block_relative_intersection_slicing ]
            else:
                # Not stored yet.  Overwrite with zeros.
                destination[ destination_relative_intersection_slicing ] = 0

    def propagateDirty(self, slot, subindex, roi):
        # There should be no way to make the output dirty except via setInSlot()
        pass

    def setInSlot(self, slot, subindex, roi, new_pixels):
        if slot is self.Input:
            self._setInSlotInput(slot, subindex, roi, new_pixels)
        else:
            # We don't yet support the InputHdf5 slot in this function.
            assert False, "Unsupported slot for setInSlot: {}".format( slot.name )
            
    def _setInSlotInput(self, slot, subindex, roi, new_pixels):
        """
        Since this is a label array, inserting pixels has a special meaning:
        We only overwrite the new non-zero pixels. In the new data, zeros mean "don't change".
        
        So, here's what each pixel we're adding means:
        0: don't change
        1: change to 1
        2: change to 2
        ...
        N: change to N
        magic_eraser_value: change to 0  
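
        For example, with eraser value 100: writing new_pixels [0, 2, 100]
        over existing data [5, 5, 5] yields [5, 2, 0].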
        """

        # Extract the data to modify
        original_data = self.Output.stype.allocateDestination(SubRegion(self.Output, *roiFromShape(new_pixels.shape)))

        self.execute(self.Output, (), roi, original_data)
        
        # Reset the pixels we need to change (so we can use |= below)
        original_data[new_pixels.nonzero()] = 0
        
        # Update
        original_data |= new_pixels

        # Replace 'eraser' values with zeros.
        cleaned_data = original_data.copy()
        cleaned_data[original_data == self._eraser_magic_value] = 0

        # Set in the cache (our superclass).
        super( OpCompressedUserLabelArray, self )._setInSlotInput( slot, subindex, roi, cleaned_data, store_zero_blocks=False )
        
        # FIXME: Shouldn't this notification be triggered from within OpUnmanagedCompressedCache?
        self.Output.setDirty( roi.start, roi.stop )
        
        return cleaned_data # Internal use: Return the cleaned_data        

    def ingestData(self, slot):
        """
        Read the data from the given slot and copy it into this cache.
        The rules about special pixel meanings apply here, just as in setInSlot().
        
        Returns: the max label found in the slot.
        """
        assert self._blockshape is not None
        assert self.Output.meta.shape[:-1] == slot.meta.shape[:-1], \
            "{} != {}".format( self.Output.meta.shape, slot.meta.shape )
        max_label = 0

        # Get logical blocking.
        block_starts = getIntersectingBlocks( self._blockshape, roiFromShape(self.Output.meta.shape) )
        block_starts = map( tuple, block_starts )

        # Write each block
        for block_start in block_starts:
            block_roi = getBlockBounds( self.Output.meta.shape, self._blockshape, block_start )
            
            # Request the block data
            block_data = slot(*block_roi).wait()
            
            # Write into the array
            subregion_roi = SubRegion(self.Output, *block_roi)
            cleaned_block_data = self._setInSlotInput( self.Input, (), subregion_roi, block_data )
            
            max_label = max( max_label, cleaned_block_data.max() )
        
        return max_label
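
# A minimal usage sketch (hypothetical wiring; graph construction and the
# exact slot defaults follow lazyflow conventions but are assumptions here,
# not guarantees of this example):
#
#   from lazyflow.graph import Graph
#   import numpy, vigra
#
#   op = OpCompressedUserLabelArray(graph=Graph())
#   op.Input.setValue(vigra.taggedView(numpy.zeros((100, 100, 1), numpy.uint8), 'yxc'))
#   op.blockShape.setValue((64, 64, 1))
#   op.eraser.setValue(255)       # magic eraser value
#   op.deleteLabel.setValue(0)    # no label scheduled for deletion
#
#   # Paint label 1 into a corner; zeros in the written array mean "don't change".
#   op.Input[0:10, 0:10, 0:1] = numpy.ones((10, 10, 1), numpy.uint8)
#   assert op.Output[0:10, 0:10, 0:1].wait().max() == 1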
Example #24
0
class OpThresholdTwoLevels(Operator):
    RawInput = InputSlot(optional=True)  # Display only
    InputChannelColors = InputSlot(optional=True)  # Display only

    InputImage = InputSlot()
    MinSize = InputSlot(stype="int", value=10)
    MaxSize = InputSlot(stype="int", value=1000000)
    HighThreshold = InputSlot(stype="float", value=0.8)
    LowThreshold = InputSlot(stype="float", value=0.5)
    SmootherSigma = InputSlot(value={"z": 1.0, "y": 1.0, "x": 1.0})
    Channel = InputSlot(value=0)
    CoreChannel = InputSlot(value=0)
    # This slot would be better named 'Method', but the name is kept for
    # backwards compatibility with old project files.
    CurOperator = InputSlot(stype="int", value=ThresholdMethod.SIMPLE)
    Beta = InputSlot(value=0.2)  # For GraphCut

    ## Output slots ##
    Output = OutputSlot()
    CachedOutput = OutputSlot()  # For the GUI (blockwise-access)

    # For serialization
    CacheInput = InputSlot(optional=True)
    CleanBlocks = OutputSlot()

    ## Debug outputs

    InputChannel = OutputSlot()
    Smoothed = OutputSlot()
    BigRegions = OutputSlot()
    SmallRegions = OutputSlot()
    FilteredSmallLabels = OutputSlot()
    BeforeSizeFilter = OutputSlot()

    ## Basic schematic (debug outputs not shown)
    ##
    ## InputImage -> opReorder -> opSmoother -> opSmootherCache --> opFinalChannelSelector --> opSumInputs --------------------> opFinalThreshold -> opFinalFilter -> opReorder -> Output
    ##                                                         \                             /                                  /                                              \
    ##                                                          --> opCoreChannelSelector --> opCoreThreshold -> opCoreFilter --                                                opCache -> CachedOutput
    ##                                                                                                                                                                                 `-> CleanBlocks
    def __init__(self, *args, **kwargs):
        super(OpThresholdTwoLevels, self).__init__(*args, **kwargs)

        self.opReorderInput = OpReorderAxes(parent=self)
        self.opReorderInput.AxisOrder.setValue("tzyxc")
        self.opReorderInput.Input.connect(self.InputImage)

        # PROBABILITIES: Convert to float32
        self.opConvertProbabilities = OpConvertDtype(parent=self)
        self.opConvertProbabilities.ConversionDtype.setValue(np.float32)
        self.opConvertProbabilities.Input.connect(self.opReorderInput.Output)

        # PROBABILITIES: Normalize drange to [0.0, 1.0]
        self.opNormalizeProbabilities = OpPixelOperator(parent=self)

        def normalize_inplace(a):
            drange = self.opNormalizeProbabilities.Input.meta.drange
            if drange is None or (drange[0] == 0.0 and drange[1] == 1.0):
                return a
            a[:] -= drange[0]
            a[:] /= float(drange[1] - drange[0])
            return a

        self.opNormalizeProbabilities.Input.connect(
            self.opConvertProbabilities.Output)
        self.opNormalizeProbabilities.Function.setValue(normalize_inplace)
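        # e.g. with drange=(50, 200), normalize_inplace maps a pixel value of
        # 125 to (125 - 50) / 150 = 0.5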

        self.opSmoother = OpAnisotropicGaussianSmoothing5d(parent=self)
        self.opSmoother.Sigmas.connect(self.SmootherSigma)
        self.opSmoother.Input.connect(self.opNormalizeProbabilities.Output)

        self.opSmootherCache = OpBlockedArrayCache(parent=self)
        self.opSmootherCache.BlockShape.setValue((1, None, None, None, 1))
        self.opSmootherCache.Input.connect(self.opSmoother.Output)

        self.opCoreChannelSelector = OpSingleChannelSelector(parent=self)
        self.opCoreChannelSelector.Index.connect(self.CoreChannel)
        self.opCoreChannelSelector.Input.connect(self.opSmootherCache.Output)

        self.opCoreThreshold = OpLabeledThreshold(parent=self)
        self.opCoreThreshold.Method.setValue(ThresholdMethod.SIMPLE)
        self.opCoreThreshold.FinalThreshold.connect(self.HighThreshold)
        self.opCoreThreshold.Input.connect(self.opCoreChannelSelector.Output)

        self.opCoreFilter = OpFilterLabels(parent=self)
        self.opCoreFilter.BinaryOut.setValue(False)
        self.opCoreFilter.MinLabelSize.connect(self.MinSize)
        self.opCoreFilter.MaxLabelSize.connect(self.MaxSize)
        self.opCoreFilter.Input.connect(self.opCoreThreshold.Output)

        self.opFinalChannelSelector = OpSingleChannelSelector(parent=self)
        self.opFinalChannelSelector.Index.connect(self.Channel)
        self.opFinalChannelSelector.Input.connect(self.opSmootherCache.Output)

        self.opSumInputs = OpMultiArrayMerger(
            parent=self)  # see setupOutputs (below) for input connections
        self.opSumInputs.MergingFunction.setValue(sum)

        self.opFinalThreshold = OpLabeledThreshold(parent=self)
        self.opFinalThreshold.Method.connect(self.CurOperator)
        self.opFinalThreshold.FinalThreshold.connect(self.LowThreshold)
        self.opFinalThreshold.GraphcutBeta.connect(self.Beta)
        self.opFinalThreshold.CoreLabels.connect(self.opCoreFilter.Output)
        self.opFinalThreshold.Input.connect(self.opSumInputs.Output)

        self.opFinalFilter = OpFilterLabels(parent=self)
        self.opFinalFilter.BinaryOut.setValue(False)
        self.opFinalFilter.MinLabelSize.connect(self.MinSize)
        self.opFinalFilter.MaxLabelSize.connect(self.MaxSize)
        self.opFinalFilter.Input.connect(self.opFinalThreshold.Output)

        self.opReorderOutput = OpReorderAxes(parent=self)
        # self.opReorderOutput.AxisOrder.setValue('tzyxc') # See setupOutputs()
        self.opReorderOutput.Input.connect(self.opFinalFilter.Output)

        self.Output.connect(self.opReorderOutput.Output)

        self.opCache = OpBlockedArrayCache(parent=self)
        self.opCache.CompressionEnabled.setValue(True)
        self.opCache.Input.connect(self.opReorderOutput.Output)

        self.CachedOutput.connect(self.opCache.Output)
        self.CleanBlocks.connect(self.opCache.CleanBlocks)

        ## Debug outputs
        self.Smoothed.connect(self.opSmootherCache.Output)
        self.InputChannel.connect(self.opFinalChannelSelector.Output)
        self.SmallRegions.connect(self.opCoreThreshold.Output)
        self.BeforeSizeFilter.connect(self.opFinalThreshold.Output)

        self.opFilteredSmallLabelsCache = OpBlockedArrayCache(parent=self)
        self.opFilteredSmallLabelsCache.CompressionEnabled.setValue(True)
        self.opFilteredSmallLabelsCache.Input.connect(self.opCoreFilter.Output)
        self.FilteredSmallLabels.connect(
            self.opFilteredSmallLabelsCache.Output)

        # Since hysteresis thresholding creates the big regions and immediately discards the bad ones,
        # we have to recreate it here if the user wants to view it as a debug layer
        self.opBigRegionsThreshold = OpLabeledThreshold(parent=self)
        self.opBigRegionsThreshold.Method.setValue(ThresholdMethod.SIMPLE)
        self.opBigRegionsThreshold.FinalThreshold.connect(self.LowThreshold)
        self.opBigRegionsThreshold.Input.connect(
            self.opFinalChannelSelector.Output)
        self.BigRegions.connect(self.opBigRegionsThreshold.Output)

    def setupOutputs(self):
        axes = self.InputImage.meta.getAxisKeys()
        self.opReorderOutput.AxisOrder.setValue(axes)

        # Cache individual t,c slices
        blockshape = tuple(1 if k in "tc" else None for k in axes)
        self.opCache.BlockShape.setValue(blockshape)
        # opCoreFilter's output is in tzyxc order (see opReorderInput),
        # so this caches one (t, z) slice per block.
        self.opFilteredSmallLabelsCache.BlockShape.setValue(
            (1, 1, None, None, None))

        if (self.CurOperator.value
                in (ThresholdMethod.HYSTERESIS, ThresholdMethod.IPHT)
                and self.Channel.value != self.CoreChannel.value):
            self.opSumInputs.Inputs.resize(2)
            self.opSumInputs.Inputs[0].connect(
                self.opFinalChannelSelector.Output)
            self.opSumInputs.Inputs[1].connect(
                self.opCoreChannelSelector.Output)
        else:
            self.opSumInputs.Inputs.resize(1)
            self.opSumInputs.Inputs[0].connect(
                self.opFinalChannelSelector.Output)

    def setInSlot(self, slot, subindex, roi, value):
        self.opCache.setInSlot(self.opCache.Input, subindex, roi, value)

    def execute(self, slot, subindex, roi, destination):
        assert False, "Shouldn't get here."

    def propagateDirty(self, slot, subindex, roi):
        pass  # dirtiness propagation is handled in the sub-operators
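
# The hysteresis idea implemented by this pipeline, sketched standalone in
# plain numpy/vigra (illustrative only; thresholds, filtering rules, and
# function choices here are assumptions, not the operator's internals):
#
#   import numpy, vigra
#   smoothed = vigra.filters.gaussianSmoothing(probabilities, 1.0)
#   cores = (smoothed > 0.8)                      # HighThreshold -> seed cores
#   candidates = vigra.analysis.labelImageWithBackground(
#       (smoothed > 0.5).astype(numpy.uint8))     # LowThreshold -> candidate regions
#   keep = numpy.unique(candidates[cores])        # regions containing a core
#   final = numpy.isin(candidates, keep[keep != 0]) * candidates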
Example #25
0
class OpExportDvidVolume(Operator):
    Input = InputSlot()
    # Should be a url of the form http://<hostname>[:<port>]/api/node/<uuid>/<dataname>
    NodeDataUrl = InputSlot()
    OffsetCoord = InputSlot(optional=True)

    def __init__(self, transpose_axes, *args, **kwargs):
        super(OpExportDvidVolume, self).__init__(*args, **kwargs)
        self.progressSignal = OrderedSignal()
        self._transpose_axes = transpose_axes

    # No output slots...
    def setupOutputs(self):
        if self._transpose_axes:
            assert self.Input.meta.axistags.channelIndex == len(
                self.Input.meta.axistags) - 1
        else:
            assert self.Input.meta.axistags.channelIndex == 0

    def execute(self, slot, subindex, roi, result):
        pass

    def propagateDirty(self, slot, subindex, roi):
        pass

    def run_export(self):
        self.progressSignal(0)

        url = self.NodeDataUrl.value
        url_path = url.split('://')[1]
        hostname, api, node, uuid, dataname = url_path.split('/')
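        # e.g. 'http://localhost:8000/api/node/a3f1/grayscale' splits into
        #   hostname='localhost:8000', api='api', node='node',
        #   uuid='a3f1', dataname='grayscale'   (hypothetical values)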
        assert api == 'api'
        assert node == 'node'

        axiskeys = self.Input.meta.getAxisKeys()
        shape = self.Input.meta.shape

        if self._transpose_axes:
            axiskeys = reversed(axiskeys)
            shape = tuple(reversed(shape))

        axiskeys = "".join(axiskeys)

        if self.OffsetCoord.ready():
            offset_start = self.OffsetCoord.value
        else:
            offset_start = (0, ) * len(self.Input.meta.shape)

        connection = pydvid.dvid_connection.DvidConnection(hostname)
        with contextlib.closing(connection):
            self.progressSignal(5)

            # Get the dataset details
            try:
                metadata = pydvid.voxels.get_metadata(connection, uuid,
                                                      dataname)
            except pydvid.errors.DvidHttpError as ex:
                if ex.status_code != 404:
                    raise
                # Dataset doesn't exist yet.  Let's create it.
                metadata = pydvid.voxels.VoxelsMetadata.create_default_metadata(
                    shape, self.Input.meta.dtype, axiskeys, 0.0, "")
                pydvid.voxels.create_new(connection, uuid, dataname, metadata)

            # Since this class is generally used to push large blocks of data,
            #  we'll be nice and set throttle=True
            client = pydvid.voxels.VoxelsAccessor(connection,
                                                  uuid,
                                                  dataname,
                                                  throttle=True)

            def handle_block_result(roi, data):
                # Send it to dvid
                roi = numpy.asarray(roi)
                roi += offset_start
                start, stop = roi
                if self._transpose_axes:
                    data = data.transpose()
                    start = tuple(reversed(start))
                    stop = tuple(reversed(stop))
                client.post_ndarray(start, stop, data)

            requester = BigRequestStreamer(self.Input,
                                           roiFromShape(self.Input.meta.shape))
            requester.resultSignal.subscribe(handle_block_result)
            requester.progressSignal.subscribe(self.progressSignal)
            requester.execute()

        self.progressSignal(100)
Example #26
0
class OpAnisotropicGaussianSmoothing(Operator):
    Input = InputSlot()
    Sigmas = InputSlot(value={'x': 1.0, 'y': 1.0, 'z': 1.0})

    Output = OutputSlot()

    def setupOutputs(self):

        self.Output.meta.assignFrom(self.Input.meta)
        # If there is a time axis of length 1, the output won't have it
        timeIndex = self.Output.meta.axistags.index('t')
        if timeIndex < len(self.Output.meta.shape):
            newshape = list(self.Output.meta.shape)
            newshape.pop(timeIndex)
            self.Output.meta.shape = tuple(newshape)
            del self.Output.meta.axistags[timeIndex]
        self.Output.meta.dtype = numpy.float32  # vigra gaussian only supports float32
        self._sigmas = self.Sigmas.value
        assert isinstance(self.Sigmas.value,
                          dict), "Sigmas slot expects a dict"
        assert set(self._sigmas.keys()) == set(
            'xyz'), "Sigmas slot expects three key-value pairs for x,y,z"
        print("Assigning output: {} ====> {}".format(
            self.Input.meta.getTaggedShape(),
            self.Output.meta.getTaggedShape()))
        #self.Output.setDirty( slice(None) )

    def execute(self, slot, subindex, roi, result):
        assert all(
            roi.stop <= self.Input.meta.shape
        ), "Requested roi {} is too large for this input image of shape {}.".format(
            roi, self.Input.meta.shape)
        # Determine how much input data we'll need, and where the result will be relative to that input roi
        inputRoi, computeRoi = self._getInputComputeRois(roi)
        # Obtain the input data
        with Timer() as resultTimer:
            data = self.Input(*inputRoi).wait()
        logger.debug("Obtaining input data took {} seconds for roi {}".format(
            resultTimer.seconds(), inputRoi))

        xIndex = self.Input.meta.axistags.index('x')
        yIndex = self.Input.meta.axistags.index('y')
        zIndex = self.Input.meta.axistags.index(
            'z') if self.Input.meta.axistags.index('z') < len(
                self.Input.meta.shape) else None
        cIndex = self.Input.meta.axistags.index(
            'c') if self.Input.meta.axistags.index('c') < len(
                self.Input.meta.shape) else None

        # Must be float32
        if data.dtype != numpy.float32:
            data = data.astype(numpy.float32)

        axiskeys = self.Input.meta.getAxisKeys()
        spatialkeys = [k for k in axiskeys if k in 'xyz']

        # we need to remove a singleton z axis, otherwise we get
        # 'kernel longer than line' errors
        reskey = [slice(None, None, None)] * len(self.Input.meta.shape)
        if cIndex is not None:
            reskey[cIndex] = 0
        if zIndex is not None and self.Input.meta.shape[zIndex] == 1:
            removedZ = True
            data = data.reshape((data.shape[xIndex], data.shape[yIndex]))
            reskey[zIndex] = 0
            spatialkeys = [k for k in axiskeys if k in 'xy']
        else:
            removedZ = False

        sigma = [self._sigmas[k] for k in spatialkeys]
        #Check if we need to smooth
        if any([x < 0.1 for x in sigma]):
            if removedZ:
                resultXY = vigra.taggedView(result, axistags="".join(axiskeys))
                resultXY = resultXY.withAxes(*'xy')
                resultXY[:] = data
            else:
                result[:] = data
            return result

        # Smooth the input data
        smoothed = vigra.filters.gaussianSmoothing(
            data,
            sigma,
            window_size=2.0,
            roi=computeRoi,
            out=result[tuple(reskey)])  # FIXME: Assumes channel is last axis
        expectedShape = tuple(
            TinyVector(computeRoi[1]) - TinyVector(computeRoi[0]))
        assert tuple(smoothed.shape) == expectedShape, \
            "Smoothed data shape {} didn't match expected shape {}".format(
                smoothed.shape, expectedShape)

        return result

    def _getInputComputeRois(self, roi):
        axiskeys = self.Input.meta.getAxisKeys()
        spatialkeys = [k for k in axiskeys if k in 'xyz']
        sigma = [self._sigmas[k] for k in spatialkeys]
        inputSpatialShape = self.Input.meta.getTaggedShape()
        spatialRoi = (TinyVector(roi.start), TinyVector(roi.stop))
        tIndex = None
        cIndex = None
        zIndex = None
        if 'c' in inputSpatialShape:
            del inputSpatialShape['c']
            cIndex = axiskeys.index('c')
        if 't' in inputSpatialShape.keys():
            assert inputSpatialShape['t'] == 1
            del inputSpatialShape['t']
            tIndex = axiskeys.index('t')

        if 'z' in inputSpatialShape.keys() and inputSpatialShape['z'] == 1:
            #2D image, avoid kernel longer than line exception
            del inputSpatialShape['z']
            zIndex = axiskeys.index('z')

        # Drop non-spatial axes, highest index first so positions stay valid
        indices = sorted((i for i in (tIndex, cIndex, zIndex) if i is not None),
                         reverse=True)
        for ind in indices:
            spatialRoi[0].pop(ind)
            spatialRoi[1].pop(ind)

        inputSpatialRoi = enlargeRoiForHalo(spatialRoi[0],
                                            spatialRoi[1],
                                            list(inputSpatialShape.values()),
                                            sigma,
                                            window=2.0)

        # Determine the roi within the input data we're going to request
        inputRoiOffset = spatialRoi[0] - inputSpatialRoi[0]
        computeRoi = (inputRoiOffset,
                      inputRoiOffset + spatialRoi[1] - spatialRoi[0])

        # For some reason, vigra.filters.gaussianSmoothing will raise an exception if this parameter doesn't have the correct integer type.
        # (for example, if we give it as a numpy.ndarray with dtype=int64, we get an error)
        computeRoi = (tuple(map(int,
                                computeRoi[0])), tuple(map(int,
                                                           computeRoi[1])))

        inputRoi = (list(inputSpatialRoi[0]), list(inputSpatialRoi[1]))
        for ind in reversed(indices):
            inputRoi[0].insert(ind, 0)
            inputRoi[1].insert(ind, 1)

        return inputRoi, computeRoi

    def propagateDirty(self, slot, subindex, roi):
        if slot == self.Input:
            # Halo calculation is bidirectional, so we can re-use the function that computes the halo during execute()
            inputRoi, _ = self._getInputComputeRois(roi)
            self.Output.setDirty(inputRoi[0], inputRoi[1])
        elif slot == self.Sigmas:
            self.Output.setDirty(slice(None))
        else:
            assert False, "Unknown input slot: {}".format(slot.name)
Example #27
0
class OpCacheFixer(Operator):
    """
    Can be inserted in front of a cache operator to implement the "fixAtCurrent" 
    behavior currently implemented by multiple lazyflow caches.
    
    While fixAtCurrent=False, this operator is merely a pass-through.
    
    While fixAtCurrent=True, this operator does not forward dirty notifications 
    to downstream operators. Instead, it remembers the total ROI of the dirty area 
    (as a bounding box), and emits the entire dirty ROI at once as soon as it becomes "unfixed".
    Also, this operator returns only zeros while fixAtCurrent=True.
    """
    fixAtCurrent = InputSlot(value=False)
    Input = InputSlot(allow_mask=True)
    Output = OutputSlot(allow_mask=True)

    def __init__(self, *args, **kwargs):
        super(OpCacheFixer, self).__init__(*args, **kwargs)
        self._fixed = False
        self._fixed_dirty_roi = None

    def setupOutputs(self):
        self.Output.meta.assignFrom( self.Input.meta )
        self.Output.meta.dontcache = self.fixAtCurrent.value

        # During initialization, if fixAtCurrent is configured before Input, then propagateDirty was never called.
        # We need to make sure that the dirty logic for fixAtCurrent has definitely been called here.
        self.propagateDirty(self.fixAtCurrent, (), slice(None))

    def execute(self, slot, subindex, roi, result):
        if self._fixed:
            # The downstream user doesn't know he's getting fake data.
            # When we become "unfixed", we need to tell him.
            self._expand_fixed_dirty_roi( (roi.start, roi.stop) )
            result[:] = 0
        else:
            self.Input(roi.start, roi.stop).writeInto(result).wait()
        
    def setInSlot(self, slot, subindex, roi, value):
        # Forward to the output
        self.Output[roiToSlice(roi.start, roi.stop)] = value
        
        entire_roi = roiFromShape(self.Input.meta.shape)
        if (numpy.array((roi.start, roi.stop)) == entire_roi).all():
            # Nothing is dirty any more.
            self._init_fixed_dirty_roi()
    
    def propagateDirty(self, slot, subindex, roi):
        if slot is self.fixAtCurrent:
            # If we're becoming UN-fixed, send out a big dirty notification
            if ( self._fixed and not self.fixAtCurrent.value and
                 self._fixed_dirty_roi and (self._fixed_dirty_roi[1] - self._fixed_dirty_roi[0] > 0).all() ):
                self.Output.setDirty( *self._fixed_dirty_roi )
                self._fixed_dirty_roi = None
            self._fixed = self.fixAtCurrent.value
        elif slot is self.Input:
            if self._fixed:
                # We can't propagate this downstream,
                #  but we need to remember that it was marked dirty.
                # Expand our dirty bounding box.
                self._expand_fixed_dirty_roi( (roi.start, roi.stop) )
            else:
                self.Output.setDirty(roi.start, roi.stop)

    def _init_fixed_dirty_roi(self):
        # Intentionally flipped: nothing is dirty at first.
        entire_roi = roiFromShape(self.Input.meta.shape)
        self._fixed_dirty_roi = (entire_roi[1], entire_roi[0])

    def _expand_fixed_dirty_roi(self, roi):
        if self._fixed_dirty_roi is None:
            self._init_fixed_dirty_roi()
        start, stop = self._fixed_dirty_roi
        start = numpy.minimum(start, roi[0])
        stop = numpy.maximum(stop, roi[1])
        self._fixed_dirty_roi = (start, stop)
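
# A minimal behavior sketch (hypothetical wiring, lazyflow-style; the exact
# semantics of value-slots and slicing syntax are assumed here):
#
#   from lazyflow.graph import Graph
#   import numpy
#
#   op = OpCacheFixer(graph=Graph())
#   op.Input.setValue(numpy.arange(16).reshape(4, 4))
#   op.fixAtCurrent.setValue(True)
#   assert (op.Output[0:4, 0:4].wait() == 0).all()   # fixed: only zeros come out
#   op.fixAtCurrent.setValue(False)                  # unfixing emits the saved dirty roi
#   assert (op.Output[0:4, 0:4].wait() == numpy.arange(16).reshape(4, 4)).all()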
Example #28
0
class OpSelectLabels(Operator):

    ## The smaller clusters
    # i.e. results of high thresholding
    SmallLabels = InputSlot()

    ## The larger clusters
    # i.e. results of low thresholding
    BigLabels = InputSlot()

    Output = OutputSlot()

    def setupOutputs(self):
        self.Output.meta.assignFrom(self.BigLabels.meta)
        self.Output.meta.dtype = numpy.uint32
        self.Output.meta.drange = (0, 1)

    def execute(self, slot, subindex, roi, result):
        assert slot == self.Output

        # This operator is typically used with very big rois, so be extremely memory-conscious:
        # - Don't request the small and big inputs in parallel.
        # - Clean finished requests immediately (don't wait for this function to exit)
        # - Delete intermediate results as soon as possible.

        if logger.isEnabledFor(logging.DEBUG):
            dtypeBytes = self.SmallLabels.meta.getDtypeBytes()
            roiShape = roi.stop - roi.start
            logger.debug("Roi shape is {} = {} MB".format(
                roiShape,
                numpy.prod(roiShape) * dtypeBytes / 1e6))
            starting_memory_usage_mb = getMemoryUsageMb()
            logger.debug("Starting with memory usage: {} MB".format(
                starting_memory_usage_mb))

        def logMemoryIncrease(msg):
            """Log a debug message about the RAM usage compared to when this function started execution."""
            if logger.isEnabledFor(logging.DEBUG):
                memory_increase_mb = getMemoryUsageMb() - starting_memory_usage_mb
                logger.debug("{}, memory increase is: {} MB".format(
                    msg, memory_increase_mb))

        smallLabelsReq = self.SmallLabels(roi.start, roi.stop)
        smallLabels = smallLabelsReq.wait()
        smallLabelsReq.clean()
        logMemoryIncrease("After obtaining small labels")

        smallNonZero = numpy.ndarray(shape=smallLabels.shape, dtype=bool)
        smallNonZero[...] = (smallLabels != 0)
        del smallLabels

        logMemoryIncrease("Before obtaining big labels")
        bigLabels = self.BigLabels(roi.start, roi.stop).wait()
        logMemoryIncrease("After obtaining big labels")

        prod = smallNonZero * bigLabels

        del smallNonZero

        # get labels that passed the masking
        #passed = numpy.unique(prod)
        # (bincount + nonzero is much faster than unique(), which copies and sorts)
        passed = numpy.bincount(prod.flat).nonzero()[0]

        # 0 is not a valid label
        if len(passed) > 0 and passed[0] == 0:
            passed = passed[1:]

        logMemoryIncrease("After taking product")
        del prod

        all_label_values = numpy.zeros((bigLabels.max() + 1, ),
                                       dtype=numpy.uint32)

        for i, l in enumerate(passed):
            all_label_values[l] = i + 1
        all_label_values[0] = 0

        # tricky: map the old labels to the new ones; labels that didn't pass
        # are mapped to zero
        result[:] = all_label_values[bigLabels]

        logMemoryIncrease("Just before return")
        return result

    def propagateDirty(self, slot, subindex, roi):
        if slot == self.SmallLabels or slot == self.BigLabels:
            self.Output.setDirty(slice(None))
        else:
            assert False, "Unknown input slot: {}".format(slot.name)
Example #29
0
class OpMockPixelClassifier(Operator):
    """
    This class is a simple stand-in for the real pixel classification operator.
    Uses hard-coded data shape and block shape.
    Provides hard-coded outputs.
    """
    name = "OpMockPixelClassifier"

    LabelInputs = InputSlot(
        optional=True,
        level=1)  # Input for providing label data from an external source

    PredictionsFromDisk = InputSlot(
        optional=True, level=1)  # TODO: Actually use this input for something

    ClassifierFactory = InputSlot(
        value=ParallelVigraRfLazyflowClassifierFactory(10, 10))

    NonzeroLabelBlocks = OutputSlot(
        level=1,
        stype='object')  # A list of slices that contain non-zero label values
    LabelImages = OutputSlot(level=1)  # Labels from the user

    Classifier = OutputSlot(stype='object')

    PredictionProbabilities = OutputSlot(level=1)

    FreezePredictions = InputSlot()

    LabelNames = OutputSlot()
    LabelColors = OutputSlot()
    PmapColors = OutputSlot()
    Bookmarks = OutputSlot(level=1)

    def __init__(self, *args, **kwargs):
        super(OpMockPixelClassifier, self).__init__(*args, **kwargs)

        self.LabelNames.setValue(["Membrane", "Cytoplasm"])
        self.LabelColors.setValue([(255, 0, 0), (0, 255, 0)])  # Red, Green
        self.PmapColors.setValue([(255, 0, 0), (0, 255, 0)])  # Red, Green

        self._data = []
        self.dataShape = (1, 10, 100, 100, 1)
        self.prediction_shape = self.dataShape[:-1] + (
            2, )  # Hard-coded to provide 2 classes

        self.FreezePredictions.setValue(False)

        self.opClassifier = OpTrainClassifierBlocked(graph=self.graph,
                                                     parent=self)
        self.opClassifier.ClassifierFactory.connect(self.ClassifierFactory)
        self.opClassifier.Labels.connect(self.LabelImages)
        self.opClassifier.nonzeroLabelBlocks.connect(self.NonzeroLabelBlocks)
        self.opClassifier.MaxLabel.setValue(2)

        self.classifier_cache = OpValueCache(graph=self.graph, parent=self)
        self.classifier_cache.Input.connect(self.opClassifier.Classifier)

        p1 = old_div(numpy.indices(self.dataShape).sum(0), 207.0)
        p2 = 1 - p1

        self.predictionData = numpy.concatenate((p1, p2), axis=4)

    def setupOutputs(self):
        numImages = len(self.LabelInputs)

        self.PredictionsFromDisk.resize(numImages)
        self.NonzeroLabelBlocks.resize(numImages)
        self.LabelImages.resize(numImages)
        self.PredictionProbabilities.resize(numImages)
        self.opClassifier.Images.resize(numImages)

        for i in range(numImages):
            self._data.append(numpy.zeros(self.dataShape))
            self.NonzeroLabelBlocks[i].meta.shape = (1, )
            self.NonzeroLabelBlocks[i].meta.dtype = object

            # Hard-coded: Two prediction classes
            self.PredictionProbabilities[i].meta.shape = self.prediction_shape
            self.PredictionProbabilities[i].meta.dtype = numpy.float64
            self.PredictionProbabilities[i].meta.axistags = \
                vigra.defaultAxistags('txyzc')

            # Classify with random data
            self.opClassifier.Images[i].setValue(
                vigra.taggedView(numpy.random.random(self.dataShape), 'txyzc'))

            self.LabelImages[i].meta.shape = self.dataShape
            self.LabelImages[i].meta.dtype = numpy.float64
            self.LabelImages[i].meta.axistags = \
                self.opClassifier.Images[i].meta.axistags

        self.Classifier.connect(self.opClassifier.Classifier)

    def setInSlot(self, slot, subindex, roi, value):
        key = roi.toSlice()
        assert slot.name == "LabelInputs"
        self._data[subindex[0]][key] = value
        self.LabelImages[subindex[0]].setDirty(key)

    def execute(self, slot, subindex, roi, result):
        key = roiToSlice(roi.start, roi.stop)
        index = subindex[0]
        if slot.name == "NonzeroLabelBlocks":
            # Split into 10 chunks
            blocks = []
            slicing = [slice(0, maximum) for maximum in self.dataShape]
            for i in range(10):
                slicing[2] = slice(i * 10, (i + 1) * 10)
                if not (self._data[index][tuple(slicing)] == 0).all():
                    blocks.append(list(slicing))

            result[0] = blocks
        if slot.name == "LabelImages":
            result[...] = self._data[index][key]
        if slot.name == "PredictionProbabilities":
            result[...] = self.predictionData[key]

    def propagateDirty(self, slot, subindex, roi):
        pass
Example #30
0
class OpPredictionPipeline(OpPredictionPipelineNoCache):
    """
    This operator extends the cacheless prediction pipeline above with additional outputs for the GUI.
    (It uses caches for these outputs, and has an extra input for cached features.)
    """        
    FreezePredictions = InputSlot()
    CachedFeatureImages = InputSlot()

    PredictionProbabilities = OutputSlot()
    CachedPredictionProbabilities = OutputSlot()

    PredictionProbabilitiesUint8 = OutputSlot()
    
    PredictionProbabilityChannels = OutputSlot( level=1 )
    SegmentationChannels = OutputSlot( level=1 )
    UncertaintyEstimate = OutputSlot()

    def __init__(self, *args, **kwargs):
        super(OpPredictionPipeline, self).__init__( *args, **kwargs )

        # Random forest prediction using CACHED features.
        self.predict = OpClassifierPredict( parent=self )
        self.predict.name = "OpClassifierPredict"
        self.predict.Classifier.connect(self.Classifier) 
        self.predict.Image.connect(self.CachedFeatureImages)
        self.predict.PredictionMask.connect(self.PredictionMask)
        self.predict.LabelsCount.connect( self.NumClasses )
        self.PredictionProbabilities.connect( self.predict.PMaps )

        # Alternate headless output: uint8 instead of float.
        # Note that drange is automatically updated.        
        self.opConvertToUint8 = OpPixelOperator( parent=self )
        self.opConvertToUint8.Input.connect( self.predict.PMaps )
        self.opConvertToUint8.Function.setValue( lambda a: (255*a).astype(numpy.uint8) )
        self.PredictionProbabilitiesUint8.connect( self.opConvertToUint8.Output )

        # Prediction cache for the GUI
        self.prediction_cache_gui = OpSlicedBlockedArrayCache( parent=self )
        self.prediction_cache_gui.name = "prediction_cache_gui"
        self.prediction_cache_gui.inputs["fixAtCurrent"].connect( self.FreezePredictions )
        self.prediction_cache_gui.inputs["Input"].connect( self.predict.PMaps )
        self.CachedPredictionProbabilities.connect(self.prediction_cache_gui.Output )

        # Also provide each prediction channel as a separate layer (for the GUI)
        self.opPredictionSlicer = OpMultiArraySlicer2( parent=self )
        self.opPredictionSlicer.name = "opPredictionSlicer"
        self.opPredictionSlicer.Input.connect( self.prediction_cache_gui.Output )
        self.opPredictionSlicer.AxisFlag.setValue('c')
        self.PredictionProbabilityChannels.connect( self.opPredictionSlicer.Slices )
        
        self.opSegmentor = OpMaxChannelIndicatorOperator( parent=self )
        self.opSegmentor.Input.connect( self.prediction_cache_gui.Output )

        self.opSegmentationSlicer = OpMultiArraySlicer2( parent=self )
        self.opSegmentationSlicer.name = "opSegmentationSlicer"
        self.opSegmentationSlicer.Input.connect( self.opSegmentor.Output )
        self.opSegmentationSlicer.AxisFlag.setValue('c')
        self.SegmentationChannels.connect( self.opSegmentationSlicer.Slices )

        # Create a layer for uncertainty estimate
        self.opUncertaintyEstimator = OpEnsembleMargin( parent=self )
        self.opUncertaintyEstimator.Input.connect( self.prediction_cache_gui.Output )

        # Cache the uncertainty so we get zeros for uncomputed points
        self.opUncertaintyCache = OpSlicedBlockedArrayCache( parent=self )
        self.opUncertaintyCache.name = "opUncertaintyCache"
        self.opUncertaintyCache.Input.connect( self.opUncertaintyEstimator.Output )
        self.opUncertaintyCache.fixAtCurrent.connect( self.FreezePredictions )
        self.UncertaintyEstimate.connect( self.opUncertaintyCache.Output )

    def setupOutputs(self):
        # Set the blockshapes for each input image separately, depending on which axistags it has.
        axisOrder = [ tag.key for tag in self.FeatureImages.meta.axistags ]
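        # Each blockDims entry below maps an axis key to a pair
        # (innerBlockShape, outerBlockShape) for one sliced view of the
        # OpSlicedBlockedArrayCache: thin along the slicing axis, large in-plane.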

        blockDimsX = { 't' : (1,1),
                       'z' : (256,256),
                       'y' : (256,256),
                       'x' : (1,1),
                       'c' : (100, 100) }

        blockDimsY = { 't' : (1,1),
                       'z' : (256,256),
                       'y' : (1,1),
                       'x' : (256,256),
                       'c' : (100,100) }

        blockDimsZ = { 't' : (1,1),
                       'z' : (1,1),
                       'y' : (256,256),
                       'x' : (256,256),
                       'c' : (100,100) }

        innerBlockShapeX = tuple( blockDimsX[k][0] for k in axisOrder )
        outerBlockShapeX = tuple( blockDimsX[k][1] for k in axisOrder )

        innerBlockShapeY = tuple( blockDimsY[k][0] for k in axisOrder )
        outerBlockShapeY = tuple( blockDimsY[k][1] for k in axisOrder )

        innerBlockShapeZ = tuple( blockDimsZ[k][0] for k in axisOrder )
        outerBlockShapeZ = tuple( blockDimsZ[k][1] for k in axisOrder )

        self.prediction_cache_gui.inputs["innerBlockShape"].setValue( (innerBlockShapeX, innerBlockShapeY, innerBlockShapeZ) )
        self.prediction_cache_gui.inputs["outerBlockShape"].setValue( (outerBlockShapeX, outerBlockShapeY, outerBlockShapeZ) )

        self.opUncertaintyCache.inputs["innerBlockShape"].setValue( (innerBlockShapeX, innerBlockShapeY, innerBlockShapeZ) )
        self.opUncertaintyCache.inputs["outerBlockShape"].setValue( (outerBlockShapeX, outerBlockShapeY, outerBlockShapeZ) )

        assert self.opConvertToUint8.Output.meta.drange == (0,255)