class OpArrayCacheCpp(OpCache):
    """
    Allocates a block of memory as large as Input.meta.shape (==Output.meta.shape)
    with the same dtype in order to be able to cache results.

    The cached data is held by a ``BlockedArray<ndim><dtype>`` object (``self.b``)
    that is created in setupOutputs() once both ``Input`` and ``blockShape``
    are ready.

    blockShape: dirty regions are tracked with a granularity of blockShape
    """

    name = "ArrayCache"
    description = "numpy.ndarray caching class"
    category = "misc"

    DefaultBlockSize = 64

    # Input
    Input = InputSlot()
    blockShape = InputSlot(value=DefaultBlockSize)
    fixAtCurrent = InputSlot(value=False)

    # Output
    CleanBlocks = OutputSlot()
    Output = OutputSlot()

    loggingName = __name__ + ".OpArrayCache"
    logger = logging.getLogger(loggingName)
    traceLogger = logging.getLogger("TRACE." + loggingName)

    # Block states
    IN_PROCESS = 0
    DIRTY = 1
    CLEAN = 2
    FIXED_DIRTY = 3

    # Names of the numpy scalar types for which a BlockedArray class is looked up.
    _SUPPORTED_DTYPE_NAMES = ("uint32", "int32", "float32", "int64", "uint8")

    def __init__(self, *args, **kwargs):
        super(OpArrayCacheCpp, self).__init__(*args, **kwargs)
        self._origBlockShape = self.DefaultBlockSize
        self._last_access = None
        self._blockShape = None
        self._fixed = False
        self._lock = Lock()
        self._cacheHits = 0
        self._has_fixed_dirty_blocks = False
        self._memory_manager = ArrayCacheMemoryMgr.instance
        self._running = 0

    def usedMemory(self):
        """Not implemented for this cache; currently returns None."""
        return None

    def fractionOfUsedMemoryDirty(self):
        """Not implemented for this cache; currently returns None."""
        return None

    def lastAccessTime(self):
        """Return the stored last-access timestamp (None until one is set)."""
        return self._last_access

    def generateReport(self, report):
        """Not implemented for this cache; ``report`` is left untouched."""
        return None

    def _freeMemory(self, refcheck=True):
        """Not implemented for this cache; nothing is freed and None is returned."""
        return None

    def _allocateManagementStructures(self):
        """Not implemented for this cache; block bookkeeping lives in ``self.b``."""
        return None

    def _blockedArrayTypeName(self, dtype):
        """
        Return the dtype suffix used in the BlockedArray class name.

        Raises:
            RuntimeError: if ``dtype`` is not one of the supported types.
        """
        for type_name in self._SUPPORTED_DTYPE_NAMES:
            if dtype == getattr(numpy, type_name):
                return type_name
        raise RuntimeError("dtype %r not supported" % (dtype,))

    def setupOutputs(self):
        """
        Configure the output metadata and (re)create the blocked array
        whenever the effective block shape changes.

        Raises:
            RuntimeError: if the input dtype has no matching BlockedArray class.
        """
        self.CleanBlocks.meta.shape = (1,)
        self.CleanBlocks.meta.dtype = object

        reconfigure = False
        if self.inputs["fixAtCurrent"].ready():
            self._fixed = self.inputs["fixAtCurrent"].value

        if self.inputs["blockShape"].ready() and self.inputs["Input"].ready():
            newBShape = self.inputs["blockShape"].value
            if not isinstance(newBShape, tuple):
                # A scalar block size applies to every axis.
                newBShape = tuple([int(newBShape)] * len(self.Input.meta.shape))
            if self._origBlockShape != newBShape:
                reconfigure = True
            self._origBlockShape = newBShape
            self._blockShape = newBShape

            inputSlot = self.inputs["Input"]
            self.outputs["Output"].meta.assignFrom(inputSlot.meta)

        shape = self.outputs["Output"].meta.shape
        if reconfigure and shape is not None:
            # Validate the dtype *before* taking the lock, so that an
            # unsupported dtype cannot leave the lock held forever.
            type_name = self._blockedArrayTypeName(self.Input.meta.dtype)
            cls = "BlockedArray%d%s" % (len(self._blockShape), type_name)
            with self._lock:
                # The class name is assembled from internal values only
                # (dimension count + whitelisted dtype name), never from user input.
                self.b = eval(cls)(self._blockShape)
                # A fresh cache holds no valid data: mark the whole volume dirty.
                self.b.setDirty(
                    tuple([0] * len(self._blockShape)),
                    self.Input.meta.shape,
                    True,
                )

    def propagateDirty(self, slot, subindex, roi):
        """
        Mark the affected region dirty in the blocked array and, unless the
        cache is fixed, forward the dirty notification downstream.
        """
        key = roi.toSlice()

        if slot == self.inputs["Input"]:
            with self._lock:
                self.b.setDirty(roi.start, roi.stop, True)
            if not self._fixed:
                self.outputs["Output"].setDirty(key)

        if slot == self.inputs["fixAtCurrent"]:
            if self.inputs["fixAtCurrent"].ready():
                self._fixed = self.inputs["fixAtCurrent"].value
                if (
                    not self._fixed
                    and self.Output.meta.shape is not None
                    and self._has_fixed_dirty_blocks
                ):
                    # TODO: We've become unfixed, so downstream operators should
                    # be notified of every block that became dirty while we were
                    # fixed (e.g. by merging those blocks into one bounding roi
                    # and calling Output.setDirty on it). Not implemented yet.
                    pass

    def _updatePriority(self, new_access=None):
        """Not implemented for this cache; the access time is not updated."""
        return None

    def execute(self, slot, subindex, roi, result):
        """Dispatch a request to the handler of the requested output slot."""
        if slot == self.Output:
            return self._executeOutput(slot, subindex, roi, result)
        elif slot == self.CleanBlocks:
            # FIXME: CleanBlocks is not implemented; the request yields None.
            pass

    def _executeOutput(self, slot, subindex, roi, result):
        """
        Fill ``result`` with the cached data for ``roi``.

        Unless the cache is fixed, all dirty blocks intersecting the roi are
        first requested from ``Input`` and written into the blocked array.
        """
        key = roi.toSlice()
        shape = self.Output.meta.shape
        start, stop = sliceToRoi(key, shape)

        self.traceLogger.debug("Acquiring ArrayCache lock...")
        # 'with' guarantees the lock is released even if an upstream request fails.
        with self._lock:
            self.traceLogger.debug("ArrayCache lock acquired.")
            self._cacheHits += 1
            # NOTE(review): _running is only ever incremented in the visible code.
            self._running += 1

            bp, bq = self.b.dirtyBlocks(start, stop)

            if not self._fixed:
                sh = self.outputs["Output"].meta.shape
                reqs = []
                for i in range(bp.shape[0]):
                    bStart = tuple([int(t) for t in bp[i, :]])
                    # Clip the block's stop coordinate to the image bounds.
                    bStop = tuple([int(t) for t in numpy.minimum(bq[i, :], sh)])
                    req = self.Input[roiToSlice(bStart, bStop)]
                    reqs.append((req, bStart, bStop))

                # Wait for every block first, so nothing is written if any request fails.
                fetched = [(r.wait(), bStart, bStop) for r, bStart, bStop in reqs]
                for data, bStart, bStop in fetched:
                    self.b.writeSubarray(bStart, bStop, data)

            self.b.readSubarray(start, stop, result)
        return result

    def setInSlot(self, slot, subindex, roi, value):
        """Force ``value`` into the cache at ``roi`` (only valid for ``Input``)."""
        self.logger.debug("setInSlot called")
        assert slot == self.inputs["Input"]
        self._cacheHits += 1
        start, stop = roi.start, roi.stop
        with self._lock:
            self.logger.debug(
                "setInSlot: start=%s, stop=%s, value.shape=%s", start, stop, value.shape
            )
            self.b.writeSubarray(start, stop, value)
class OpStreamingH5N5SequenceReaderS(Operator):
    """
    Imports a sequence of (ND) volumes inside one hdf5/N5 file into a single volume (ND+1)

    The 'S' at the end of the file name implies that this class handles multiple
    volumes in a single file.

    :param globstring: A glob string as defined by the glob module. We also support
        the following special extension to globstring syntax: A single string can hold
        a *list* of globstrings. The delimiter that separates the globstrings in the
        list is OS-specific via os.path.pathsep.

        For example, on Linux the pathsep is ':', so

            '/a/b/c.txt:/d/e/f.txt:../g/i/h.txt'

        is parsed as

            ['/a/b/c.txt', '/d/e/f.txt', '../g/i/h.txt']
    """

    GlobString = InputSlot()  # Glob string(s) selecting the internal datasets of one hdf5/N5 file
    SequenceAxis = InputSlot(optional=True)  # The axis to stack across.
    OutputImage = OutputSlot()

    class WrongFileTypeError(Exception):
        def __init__(self, globString):
            self.filename = globString
            self.msg = f"File is not a HDF5 or N5: {globString}"
            super().__init__(self.msg)

    class FileOpenError(Exception):
        """Raised by setupOutputs() when a dataset cannot be opened by the reader."""

        def __init__(self, filename):
            self.filename = filename
            self.msg = f"Unable to open file: {filename}"
            super().__init__(self.msg)

    class InconsistentShape(Exception):
        def __init__(self, fileName, datasetName):
            self.fileName = fileName
            self.msg = (
                f"Cannot stack dataset: {fileName}/{datasetName} because its shape differs from the shape of "
                f"the previous datasets")
            super().__init__(self.msg)

    class InconsistentDType(Exception):
        def __init__(self, fileName, datasetName):
            self.fileName = fileName
            self.msg = (
                f"Cannot stack dataset: {fileName}/{datasetName} because its data type differs from the "
                f"type of the previous datasets")
            super().__init__(self.msg)

    class NotTheSameFileError(Exception):
        def __init__(self, globString):
            self.globString = globString
            self.msg = f"Glob string encompasses more than one HDF5/N5 file: {globString}"
            super().__init__(self.msg)

    class NoInternalPlaceholderError(Exception):
        def __init__(self, globString):
            self.globString = globString
            self.msg = f"Glob string does not contain a placeholder: {globString}"
            super().__init__(self.msg)

    class ExternalPlaceholderError(Exception):
        def __init__(self, globString):
            self.globString = globString
            self.msg = f"Glob string does contains an external placeholder (not supported!): {globString}"
            super().__init__(self.msg)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._h5N5File = None
        self._readers = []
        self._opStacker = OpMultiArrayStacker(parent=self)
        self._opStacker.AxisIndex.setValue(0)

    def cleanUp(self):
        """Disconnect the stacker, clean up all readers and close the file."""
        self._opStacker.Images.resize(0)
        for opReader in self._readers:
            opReader.cleanUp()
        if self._h5N5File is not None:
            assert isinstance(
                self._h5N5File, (h5py.File, z5py.N5File)
            ), "_h5N5File should not be of any other type"
            self._h5N5File.close()
        super().cleanUp()

    def setupOutputs(self):
        """
        Open the file named in GlobString, create one reader per matching
        internal dataset and connect all readers to the internal stacker.

        Raises:
            FileOpenError: if a reader raises RuntimeError while being set up.
            InconsistentDType / InconsistentShape: if the datasets cannot be stacked.
            Any exception raised by checkGlobString().
        """
        pcs = PathComponents(self.GlobString.value.split(os.path.pathsep)[0])
        self._h5N5File = OpStreamingH5N5Reader.get_h5_n5_file(pcs.externalPath, mode="r")
        self.checkGlobString(self.GlobString.value)
        file_paths = self.expandGlobStrings(self._h5N5File, self.GlobString.value)

        num_files = len(file_paths)
        if num_files == 0:
            self.OutputImage.disconnect()
            self.OutputImage.meta.NOTREADY = True
            return

        self.OutputImage.connect(self._opStacker.Output)

        # Get slice axes from first image
        try:
            opFirstImg = OpStreamingH5N5Reader(parent=self)
            try:
                opFirstImg.InternalPath.setValue(file_paths[0])
                opFirstImg.H5N5File.setValue(self._h5N5File)
                slice_axes = opFirstImg.OutputImage.meta.getAxisKeys()
            finally:
                # The probe reader is temporary: release it even on failure.
                opFirstImg.cleanUp()
        except RuntimeError as e:
            logger.error(str(e))
            raise OpStreamingH5N5SequenceReaderS.FileOpenError(file_paths[0]) from e

        # Use given new axis or try to do something sensible
        if self.SequenceAxis.ready():
            new_axis = self.SequenceAxis.value
            assert len(new_axis) == 1
            assert new_axis in "tzyxc"
        else:
            # Try to pick an axis that doesn't already exist in each volume.
            # '0' is a sentinel that is never an axis key.
            for new_axis in "tzc0":
                if new_axis not in slice_axes:
                    break

            if new_axis == "0":
                # All axes used already.
                # Stack across first existing axis
                new_axis = slice_axes[0]

        self._opStacker.Images.resize(0)
        self._opStacker.Images.resize(num_files)
        self._opStacker.AxisFlag.setValue(new_axis)

        for opReader in self._readers:
            opReader.cleanUp()
        self._readers = []

        dtype = None
        shape = None
        for filename, stacker_slot in zip(file_paths, self._opStacker.Images):
            opReader = OpStreamingH5N5Reader(parent=self)
            try:
                # Abort if the image-stack has no consistent dtype or shape
                dataset = self._h5N5File[filename]
                if dtype is None:
                    dtype = dataset.dtype
                    shape = dataset.shape
                else:
                    if dtype != dataset.dtype:
                        raise OpStreamingH5N5SequenceReaderS.InconsistentDType(
                            pcs.externalPath, filename)
                    if shape != dataset.shape:
                        raise OpStreamingH5N5SequenceReaderS.InconsistentShape(
                            pcs.externalPath, filename)

                opReader.InternalPath.setValue(filename)
                opReader.H5N5File.setValue(self._h5N5File)
            except RuntimeError as e:
                logger.error(str(e))
                # Report the dataset that actually failed, not the first one.
                raise OpStreamingH5N5SequenceReaderS.FileOpenError(filename) from e
            else:
                stacker_slot.connect(opReader.OutputImage)
                self._readers.append(opReader)

    def propagateDirty(self, slot, subindex, roi):
        """A change of either input invalidates the whole output."""
        if slot == self.GlobString or slot == self.SequenceAxis:
            self.OutputImage.setDirty(slice(None))

    @staticmethod
    def expandGlobStrings(h5N5File, globStrings):
        """Matches a list of globStrings to internal paths of files

        Args:
            h5N5File: h5py or z5py File object, or path (string). If a string
                is given, the file is opened and closed in this method.
            globStrings: string. glob or path strings delimited by os.pathsep

        Returns:
            List of internal paths matching the globstrings that were found in
            the provided h5py.File object
        """
        if not isinstance(h5N5File, (h5py.File, z5py.N5File)):
            with OpStreamingH5N5Reader.get_h5_n5_file(h5N5File, mode="r") as f:
                ret = OpStreamingH5N5SequenceReaderS.expandGlobStrings(f, globStrings)
            return ret

        ret = []
        # Parse list into separate globstrings and combine them
        for globString in globStrings.split(os.path.pathsep):
            s = globString.strip()
            components = PathComponents(s)
            # Matches of each globstring are sorted; globstring order is kept.
            ret += sorted(globH5N5(h5N5File, components.internalPath.lstrip("/")))
        return ret

    @staticmethod
    def checkGlobString(globString):
        """Checks whether globString is valid for this class

        Rules for globString:
            * must only contain one distinct external path
            * multiple internal paths, or placeholder '*' must be contained

        Args:
            globString (string): String, one or multiple paths separated with
                os.path.pathsep and possibly containing '*' as a placeholder.

        Raises:
            OpStreamingH5N5SequenceReaderS.ExternalPlaceholderError: External
                placeholders are not supported.
            OpStreamingH5N5SequenceReaderS.NoInternalPlaceholderError: This
                exception is raised if only a single path is provided
                -> OpStreamingH5N5Reader should be used in this case.
            OpStreamingH5N5SequenceReaderS.NotTheSameFileError: if multiple
                hdf5 files are (possibly) referenced in the globstring, this
                Exception is raised
                -> OpStreamingH5N5SequenceReaderM should be used in this case.
            OpStreamingH5N5SequenceReaderS.WrongFileTypeError: If file-
                extensions are not among the known H5 extensions, this
                error is raised (see OpStreamingH5N5Reader.H5EXTS and
                OpStreamingH5N5Reader.N5EXTS)
        """
        pathStrings = globString.split(os.path.pathsep)
        pathComponents = [PathComponents(p.strip()) for p in pathStrings]
        assert len(pathComponents) > 0

        known_extensions = OpStreamingH5N5Reader.H5EXTS + OpStreamingH5N5Reader.N5EXTS
        if not all(p.extension in known_extensions for p in pathComponents):
            raise OpStreamingH5N5SequenceReaderS.WrongFileTypeError(globString)

        first = pathComponents[0]
        if len(pathComponents) == 1:
            # A single path only describes a sequence if it contains a placeholder.
            if first.internalPath is None or "*" not in first.internalPath:
                raise OpStreamingH5N5SequenceReaderS.NoInternalPlaceholderError(globString)
        else:
            sameExternal = all(first.externalPath == x.externalPath for x in pathComponents[1:])
            if not sameExternal:
                raise OpStreamingH5N5SequenceReaderS.NotTheSameFileError(globString)

        # Checked last in both cases, so the more specific errors above win.
        if "*" in first.externalPath:
            raise OpStreamingH5N5SequenceReaderS.ExternalPlaceholderError(globString)
class OpValueCache(Operator, ObservableCache):
    """
    This operator caches a value in its entirety,
    and allows for the value to be "forced in" from an external user.
    No memory management, no blockwise access.
    """

    name = "OpValueCache"
    category = "Cache"

    Input = InputSlot()
    fixAtCurrent = InputSlot(value=False)
    Output = OutputSlot()

    loggerName = __name__ + ".OpValueCache"
    logger = logging.getLogger(loggerName)
    traceLogger = logging.getLogger("TRACE." + loggerName)

    def __init__(self, *args, **kwargs):
        super(OpValueCache, self).__init__(*args, **kwargs)
        self._dirty = True
        self._value = None
        self._lock = threading.Lock()
        self._request = None  # Request currently computing the value (None if there is none)
        # Now that we're initialized, it's safe to register with the memory manager
        self.registerWithMemoryManager()

        def handle_unready(slot):
            # Once the input is gone the cached value can no longer be trusted.
            self._dirty = True

        self.Input.notifyUnready(handle_unready)

    def usedMemory(self):
        """Return the byte size of the cached value if it is a numpy array, else 0."""
        if isinstance(self._value, numpy.ndarray):
            return self._value.nbytes
        return 0  # FIXME: non-array values are not accounted for

    def fractionOfUsedMemoryDirty(self):
        """Return 1.0 if the cached value is dirty, 0.0 otherwise."""
        if self._dirty:
            return 1.0
        else:
            return 0.0

    def generateReport(self, report):
        """Fill ``report.info`` with a short description of the cached value."""
        super(OpValueCache, self).generateReport(report)
        if self._value is None:
            s = "no value"
        else:
            # Build the type name directly; slicing str(type(...)) only worked
            # for the Python 2 "<type '...'>" representation.
            value_type = type(self._value)
            if value_type.__module__ in ("builtins", "__builtin__"):
                t = value_type.__name__
            else:
                t = "{}.{}".format(value_type.__module__, value_type.__name__)
            s = "value of type '{}'".format(t)
        report.info = s

    def setupOutputs(self):
        self.Output.meta.assignFrom(self.Input.meta)

    def execute(self, slot, subindex, roi, result):
        """
        Write the cached value into ``result``, computing it from ``Input``
        first if the cache is dirty (and not fixed).

        Raises:
            Request.CancellationException: re-raised if the upstream request
                this call was waiting for has been cancelled.
        """
        # Fast path: serve the stored value without taking the lock.
        if self.fixAtCurrent.value is True or self._dirty is False:
            if result.shape == (1,):
                result[0] = self._value
            else:
                result[:] = self._value
            return result

        # Optimization: We don't let more than one caller trigger the value to be computed at the same time
        # If some other caller has already requested the value, we'll just wait for the request he already made.
        class State():
            Dirty = 0
            Waiting = 1
            Clean = 2

        request = None
        value = None
        with self._lock:
            # What state are we in?
            if not self._dirty:
                state = State.Clean
            elif self._request is not None:
                state = State.Waiting
            else:
                state = State.Dirty

            self.traceLogger.debug("State is: {}".format(
                {State.Dirty : 'Dirty',
                 State.Waiting : 'Waiting',
                 State.Clean : 'Clean'}[state]) )

            # Obtain the request to wait for (create it if necessary)
            if state == State.Dirty:
                request = self.Input[...]
                self._request = request
            elif state == State.Waiting:
                request = self._request
            else:
                value = self._value

        # Now release the lock and block for the request
        if state != State.Clean:
            success = False
            while not success:
                try:
                    if result.shape == (1,):
                        value = request.wait()[0]
                    else:
                        value = request.wait()
                    success = True
                except Request.InvalidRequestException:
                    # Oops, we're sharing the request with another thread
                    # and that other thread cancelled it before we got a chance to call wait().
                    # Just regenerate the request and try again...
                    with self._lock:
                        if request == self._request or self._request is None:
                            request = self.Input[...]
                            self._request = request
                        else:
                            request = self._request
                        state = State.Dirty
                except Request.CancellationException:
                    if state == State.Dirty:
                        with self._lock:
                            # If no other request has 'taken responsibility' since we were cancelled
                            # (i.e. self._request is still the request that raised this exception.)
                            if request == self._request:
                                # This is mostly to aid testing.
                                self._request = None
                    raise

        if result.shape == (1,):
            result[0] = value
        else:
            result[:] = value

        # If we made the request, set the members
        if state == State.Dirty:
            with self._lock:
                self._value = value
                self._request = None
                self._dirty = False
            # Notify only after the new value is stored and the (non-reentrant)
            # lock is released: a handler that reads Output would otherwise see
            # a dirty cache and block on the lock we are still holding.
            self.Output._sig_value_changed()
        return result

    def propagateDirty(self, slot, subindex, roi):
        """Track dirtiness; notify downstream unless the cache is fixed."""
        if slot is self.Input:
            self._dirty = True
            if not self.fixAtCurrent.value:
                self.Output.setDirty(roi)
        elif slot is self.fixAtCurrent:
            # Becoming unfixed: report the dirtiness that was withheld while fixed.
            if self.fixAtCurrent.value is False and self._dirty:
                self.Output.setDirty()

    def forceValue(self, value):
        """
        Allows a 'back door' to force data into this cache.
        Note: Use this function carefully.
        """
        with self._lock:
            self._value = value
            self._dirty = False
        self.Output.setDirty()

    def resetValue(self):
        """
        Remove the value from the cache.
        """
        with self._lock:
            self._value = None
            self._dirty = True
        self.Output.setDirty()
class OpSplitRequestsBlockwise(Operator):
    """
    Breaks large requests on the downstream Output into smaller requests that
    are fetched in parallel from the upstream Input.

    The BlockShape slot determines the size of the smaller requests.
    A constructor argument controls how a request is translated into blocks.
    """

    Input = InputSlot(allow_mask=True)
    BlockShape = InputSlot()
    Output = OutputSlot(allow_mask=True)

    def __init__(self, always_request_full_blocks, *args, **kwargs):
        """
        always_request_full_blocks: If True, every upstream request covers the
            "full" block given by BlockShape instead of being truncated to the
            user's requested ROI. (The user's ROI is still used to extract the
            data from the block results.)
            This allows an "unblocked" cache to be used as a "blocked" cache:
            without expanding requests to the full blocks they intersect, the
            upstream cache blocks would not have uniform size.
        """
        super(OpSplitRequestsBlockwise, self).__init__(*args, **kwargs)
        self._always_request_full_blocks = always_request_full_blocks

    def setupOutputs(self):
        """Copy the input metadata and publish blockshape and RAM estimates."""
        input_meta = self.Input.meta
        output_meta = self.Output.meta
        output_meta.assignFrom(input_meta)

        block_shape = self.BlockShape.value
        if len(block_shape) != len(input_meta.shape):
            # BlockShape does not match the input's dimensionality (yet).
            output_meta.NOTREADY = True
            return

        output_meta.ideal_blockshape = tuple(numpy.minimum(block_shape, input_meta.shape))

        # Estimate ram usage per requested pixel
        pixel_ram = get_ram_per_element(input_meta.dtype)

        # One 'pixel' includes all channels
        axis_sizes = input_meta.getTaggedShape()
        if "c" in axis_sizes:
            pixel_ram = pixel_ram * float(axis_sizes["c"])

        upstream_ram = input_meta.ram_usage_per_requested_pixel
        if upstream_ram is not None:
            pixel_ram = max(pixel_ram, upstream_ram)

        output_meta.ram_usage_per_requested_pixel = pixel_ram

    def execute(self, slot, subindex, roi, result):
        """Fill ``result`` by requesting each intersecting block in parallel."""
        image_shape = self.Input.meta.shape
        block_shape = self.BlockShape.value
        requested = (roi.start, roi.stop)

        clipped_rois = getIntersectingRois(image_shape, block_shape, requested, True)
        if self._always_request_full_blocks:
            upstream_rois = getIntersectingRois(image_shape, block_shape, requested, False)
        else:
            upstream_rois = clipped_rois

        def copy_block_portion(dest_roi, src_roi, block_data):
            # Copy the requested part of a full block into the result array.
            self.Output.stype.copy_data(
                result[roiToSlice(*dest_roi)],
                block_data[roiToSlice(*src_roi)]
            )

        pool = RequestPool()
        for upstream_roi, clipped_roi in zip(upstream_rois, clipped_rois):
            upstream_roi = numpy.asarray(upstream_roi)
            clipped_roi = numpy.asarray(clipped_roi)

            block_request = self.Input(*upstream_roi)
            dest_roi = clipped_roi - roi.start

            if (upstream_roi != clipped_roi).any():
                # The block extends beyond the requested roi: extract only the
                # needed portion once the block has been computed.
                src_roi = clipped_roi - upstream_roi[0]
                block_request.notify_finished(partial(copy_block_portion, dest_roi, src_roi))
            else:
                # The block lies entirely inside the roi: write it in place.
                block_request.writeInto(result[roiToSlice(*dest_roi)])

            pool.add(block_request)
            del block_request

        pool.wait()

    def propagateDirty(self, slot, subindex, roi):
        """Forward dirtiness of the Input unchanged to the Output."""
        if slot is not self.Input:
            return
        self.Output.setDirty(roi.start, roi.stop)
class OpUnmanagedCompressedCache(Operator): """ A blockwise cache that stores each block as a separate in-memory hdf5 file with a compressed dataset. The files for each block have an internal chunk-shape, which corresponds to the amount of data that has to be decompressed for a single pixel lookup. The chunk shape is prioritized as follows: 1. Input.meta.ideal_blockshape (make sure to set BlockShape to a multiple of ideal_blockshape!) 2. BlockShape, if available and smaller than 1MiB (raw) 3. Automatically determined shape with t=1, c=1 and xyz such that the blocks are smaller than 1MiB (raw) Note: This class is not managed by the memory manager, so there can be non-managed subclasses. The "managed" version is OpCompressedCache, defined below. Note: * It is not safe to call execute() and change the blockshape simultaneously. * it is not safe to reuse this cache #FIXME """ # Also used to asynchronously force data into the cache via __setitem__ (see setInSlot(), below() Input = InputSlot(allow_mask=True) # shape of internal in-memory hdf5 files (defaults to the whole volume) BlockShape = InputSlot(optional=True) # Output as numpy arrays Output = OutputSlot(allow_mask=True) InputHdf5 = InputSlot(optional=True, allow_mask=True) # A list of rois (tuples) of the blocks that are currently stored in the cache CleanBlocks = OutputSlot() # Provides data as hdf5 datasets. Only allowed for rois that exactly match a block. 
OutputHdf5 = OutputSlot(allow_mask=True) def __init__(self, *args, **kwargs): super(OpUnmanagedCompressedCache, self).__init__(*args, **kwargs) self._lock = RequestLock() self._init_cache(None) self._block_id_counter = itertools.count( ) # Used to ensure unique in-memory file names self._ignore_ideal_blockshape = False def _init_cache(self, new_blockshape): with self._lock: self._blockshape = new_blockshape self._cacheFiles = {} self._dirtyBlocks = set() self._blockLocks = {} self._chunkshape = self._chooseChunkshape(self._blockshape) self._last_access_times = collections.defaultdict(float) def cleanUp(self): logger.debug("Cleaning up") self._closeAllCacheFiles() super(OpUnmanagedCompressedCache, self).cleanUp() def setupOutputs(self): self.Output.meta.assignFrom(self.Input.meta) self.OutputHdf5.meta.assignFrom(self.Input.meta) self.CleanBlocks.meta.shape = (1, ) self.CleanBlocks.meta.dtype = object # no block shape given -> use the whole volume as one block new_blockshape = self.Input.meta.shape if self.BlockShape.ready(): new_blockshape = self.BlockShape.value if len(new_blockshape) != len(self.Input.meta.shape): self.Output.meta.NOTREADY = True self.CleanBlocks.meta.NOTREADY = True self.OutputHdf5.meta.NOTREADY = True self._init_cache(None) return # Clip blockshape to image bounds new_blockshape = tuple( numpy.minimum(new_blockshape, self.Input.meta.shape)) if new_blockshape != self._blockshape: # If the blockshape changes, we have to reset the entire cache. 
self._init_cache(new_blockshape) self.Output.meta.ideal_blockshape = new_blockshape def execute(self, slot, subindex, roi, destination): if slot == self.Output: return self._executeOutput(roi, destination) elif slot == self.CleanBlocks: return self._executeCleanBlocks(destination) elif slot == self.OutputHdf5: return self._executeOutputHdf5(roi, destination) else: assert False, "Unknown output slot: {}".format(slot.name) def _executeOutput(self, roi, destination): assert len(roi.stop) == len( self.Input.meta.shape ), "roi: {} has the wrong number of dimensions for Input shape: {}".format( roi, self.Input.meta.shape) assert numpy.less_equal(roi.stop, self.Input.meta.shape).all( ), "roi: {} is out-of-bounds for Input shape: {}".format( roi, self.Input.meta.shape) block_starts = getIntersectingBlocks(self._blockshape, (roi.start, roi.stop)) block_starts = list(map(tuple, block_starts)) # Ensure all block cache files are up-to-date self._waitForBlocks(block_starts) self._copyData(roi, destination, block_starts) return destination def _waitForBlocks(self, block_starts): """ Make sure that all blocks in the given list of blocks are present in the cache before returning. (Blocks that are not yet present will be requested from our Input slot.) """ reqPool = RequestPool() # (Do the work in parallel.) 
for block_start in block_starts:
    entire_block_roi = getBlockBounds(self.Output.meta.shape,
                                      self._blockshape, block_start)
    f = partial(self._ensureCached, entire_block_roi)
    reqPool.add(Request(f))
logger.debug("Waiting for {} blocks...".format(len(block_starts)))
reqPool.wait()


def _copyData(self, roi, destination, block_starts):
    """
    Copy the requested roi out of the cached blocks into ``destination``.

    roi: the requested region (read via ``roi.start`` / ``roi.stop``).
    destination: array-like that receives the data.  If the output carries a
                 mask, its ``data``, ``mask`` and ``fill_value`` are written.
    block_starts: start coordinates of every block intersecting ``roi``.
    """
    # Copy data from each block
    # (Parallelism not needed here: h5py will serialize these requests anyway)
    logger.debug("Copying data from {} blocks...".format(len(block_starts)))
    for block_start in block_starts:
        entire_block_roi = getBlockBounds(self.Output.meta.shape,
                                          self._blockshape, block_start)

        # This block's portion of the roi
        intersecting_roi = getIntersection((roi.start, roi.stop),
                                           entire_block_roi)

        # Compute slicing within destination array and slicing within this block
        destination_relative_intersection = numpy.subtract(
            intersecting_roi, roi.start)
        block_relative_intersection = numpy.subtract(
            intersecting_roi, block_start)
        destination_relative_intersection_slicing = roiToSlice(
            *destination_relative_intersection)
        block_relative_intersection_slicing = roiToSlice(
            *block_relative_intersection)

        # Copy from block to destination
        dataset = self._getBlockDataset(entire_block_roi)
        if self.Output.meta.has_mask:
            destination.data[destination_relative_intersection_slicing] = \
                dataset["data"][block_relative_intersection_slicing]
            destination.mask[destination_relative_intersection_slicing] = \
                dataset["mask"][block_relative_intersection_slicing]
            destination.fill_value = dataset["fill_value"][()]
        else:
            destination[destination_relative_intersection_slicing] = \
                dataset[block_relative_intersection_slicing]
        # Remember when this block was last read.
        self._last_access_times[block_start] = time.time()


def _executeCleanBlocks(self, destination):
    """
    Execute function for the CleanBlocks output slot, which produces
    an *unsorted* list of block rois that the cache currently holds.

    The list is stored in ``destination[0]`` and ``destination`` is returned.
    """
    # Set difference: clean = existing - dirty
    clean_block_starts = set(self._cacheFiles.keys()) - self._dirtyBlocks

    output_shape = self.Output.meta.shape
    results = []
    for block_start in clean_block_starts:
        cbr = getBlockBounds(output_shape, self._blockshape, block_start)
        results.append([TinyVector(cbr[0]), TinyVector(cbr[1])])

    destination[0] = results
    return destination


def _executeOutputHdf5(self, roi, destination):
    """
    Copy the (up-to-date) hdf5 dataset of exactly one block into
    the hdf5 group ``destination``, under the name ``str(block_roi)``.

    The roi must coincide with exactly one cache block.
    """
    logger.debug("Servicing request for hdf5 block {}".format(roi))

    assert isinstance(destination, h5py.Group), \
        "OutputHdf5 slot requires an hdf5 GROUP to copy into (not a numpy array)."
    assert ((roi.start % self._blockshape) == 0).all(), \
        "OutputHdf5 slot requires roi to be exactly one block."
    block_roi = getBlockBounds(self.Output.meta.shape, self._blockshape,
                               roi.start)
    assert (block_roi == numpy.array((roi.start, roi.stop))).all(), \
        "OutputHdf5 slot requires roi to be exactly one block."

    block_roi = [roi.start, roi.stop]
    self._ensureCached(block_roi)
    dataset = self._getBlockDataset(block_roi)
    assert str(block_roi) not in destination, \
        "destination hdf5 group already has a dataset with this block's name"
    destination.copy(dataset, str(block_roi))
    return destination


def propagateDirty(self, slot, subindex, roi):
    """
    Mark every block touched by a dirty Input roi as dirty and forward the
    dirtiness to Output.  A dirty BlockShape makes the whole Output dirty.
    """
    if slot == self.Input:
        # Keep track of dirty blocks
        if self._blockshape is not None:
            with self._lock:
                block_starts = getIntersectingBlocks(
                    self._blockshape, (roi.start, roi.stop))
                block_starts = list(map(tuple, block_starts))

                for block_start in block_starts:
                    self._dirtyBlocks.add(block_start)

        # Forward to downstream connections
        self.Output.setDirty(roi)
    elif slot == self.BlockShape:
        # Everything is dirty
        self.Output.setDirty(slice(None))
    else:
        # propagateDirty() is only ever called for *input* slots.
        assert False, "Unknown input slot"


def _chooseChunkshape(self, blockshape):
    """
    Choose an optimal chunkshape for our blockshape and Input shape. We
    assume access patterns to vary more in space than in time or channel
    and choose the inner chunk shape to be about 1MiB slices of t and c.
    Furthermore, we use the function
        lazyflow.utility.chunkHelpers.chooseChunkShape()
    to preserve the aspect ratio of the input (at least approximately).

    Returns None if ``blockshape`` is None, otherwise a chunk shape.
    """
    if blockshape is None:
        return None

    def isConsistent(idealshape):
        """
        check if ideal block shape and given block shape are consistent

        shapes are consistent if, for each dimension,
            * input is unready, or
            * blockshape equals fullshape, or
            * idealshape divides blockshape evenly
        """
        if not self.Input.ready():
            return True

        fullshape = self.Input.meta.shape
        return all(b == f or b % i == 0
                   for i, b, f in zip(idealshape, blockshape, fullshape))

    if not self._ignore_ideal_blockshape and self.Input.ready():
        # take the ideal chunk shape, but check if sane
        ideal = self.Input.meta.ideal_blockshape
        if ideal is not None:
            if len(ideal) == len(blockshape):
                # numpy.int was only an alias of the builtin int and has been
                # removed from NumPy; use the builtin directly.
                ideal = numpy.asarray(ideal, dtype=int)
                # A zero entry means "no preference": fall back to blockshape.
                for i, d in enumerate(ideal):
                    if d == 0:
                        ideal[i] = blockshape[i]
                if not isConsistent(ideal):
                    logger.warning(
                        "{}: BlockShape and ideal_blockshape are "
                        "inconsistent {} vs {}".format(
                            self.name, blockshape, ideal))
                else:
                    return tuple(ideal)
            else:
                logger.warning(
                    "{}: Encountered meta.ideal_blockshape that does not fit the data"
                    .format(self.name))

    # we need to figure out an ideal chunk shape on our own

    # Start with a copy of blockshape
    # NOTE(review): this reads self._blockshape rather than the 'blockshape'
    # argument -- presumably they are the same at every call site; confirm.
    axes = list(self.Output.meta.getTaggedShape().keys())
    taggedBlockShape = collections.OrderedDict(
        list(zip(axes, self._blockshape)))

    dtypeBytes = self._getDtypeBytes(self.Output.meta.dtype)

    # Number of elements that fit into 1 MiB
    desiredSpace = 1024**2 / float(dtypeBytes)

    if bigintprod(blockshape) <= desiredSpace:
        return blockshape

    # set t and c to 1
    for key in "tc":
        if key in taggedBlockShape:
            taggedBlockShape[key] = 1
    logger.debug("desired space: {}".format(desiredSpace))

    # extract only the spatial shape
    spatialKeys = [k for k in list(taggedBlockShape.keys()) if k in "xyz"]
    spatialShape = [taggedBlockShape[k] for k in spatialKeys]
    newSpatialShape = chooseChunkShape(spatialShape, desiredSpace)
    for k, v in zip(spatialKeys, newSpatialShape):
        taggedBlockShape[k] = v
    chunkShape = tuple(taggedBlockShape.values())
    logger.debug("Using chunk shape: {}".format(chunkShape))
    return chunkShape


def _getDtypeBytes(self, dtype):
    """Return the size in bytes of a single element of ``dtype``."""
    if isinstance(dtype, numpy.dtype):
        # Make sure we're dealing with a type (e.g. numpy.float64),
        # not a numpy.dtype
        dtype = dtype.type
    return dtype().nbytes


def usedMemory(self):
    """
    Return the number of bytes actually used by all cache blocks and
    update ``self._compression_factor`` (uncompressed / stored size).
    """
    tot, unc = self._usedMemory()
    self._compression_factor = 1.0
    if tot > 0:
        self._compression_factor = unc / float(tot)
    return tot


def _usedMemory(self):
    """Return ``(stored_bytes, uncompressed_bytes)`` summed over all blocks."""
    tot = 0.0
    unc = 0.0
    # Iterate over a snapshot of the keys: entries may disappear meanwhile
    # (see the KeyError handling in _memoryForBlock).
    for key in list(self._cacheFiles.keys()):
        real, virt = self._memoryForBlock(key)
        tot += real
        unc += virt
    return tot, unc


def _memoryForBlock(self, key):
    """
    Return ``(stored_bytes, uncompressed_bytes)`` for the block ``key``.

    A block that no longer exists contributes ``(0, 0)``.
    """
    try:
        group = self._cacheFiles[key]
    except KeyError:
        # entry was removed, ignore it
        # (must be a pair: the caller unpacks two values)
        return 0, 0
    tot = 0
    unc = 0
    if "data" in group:
        ds = group["data"]
        # actual size
        tot += get_storage_size(ds)
        # uncompressed size
        unc += ds.size * self._getDtypeBytes(ds.dtype)
    if "mask" in group:
        tot += group["mask"].size * self._getDtypeBytes(group["mask"].dtype)
    if "fill_value" in group:
        tot += group["fill_value"].size * self._getDtypeBytes(
            group["fill_value"].dtype)
    return tot, unc


def _getCacheFile(self, entire_block_roi):
    """
    Get the cache file for the block that starts at block_start.
    If it doesn't exist yet, create it first.

    A newly created block is registered as dirty.
    """
    block_start = tuple(entire_block_roi[0])
    if block_start in self._cacheFiles:
        return self._cacheFiles[block_start]
    with self._lock:
        # Check again under the lock: another caller may have created it.
        if block_start not in self._cacheFiles:
            # Create an in-memory hdf5 file with a unique name
            # (the counter ensures that even blocks that have been deleted
            # previously get a unique name when they are re-created).
            logger.debug("Creating a cache file for block: {}".format(
                list(block_start)))
            filename = (str(id(self)) + str(id(self._cacheFiles)) +
                        str(block_start) + str(next(self._block_id_counter)))
            mem_file = h5py.File(filename,
                                 driver="core",
                                 backing_store=False,
                                 mode="w")

            # h5py will crash if the chunkshape is larger than the dataset shape.
            datashape = tuple(entire_block_roi[1] - entire_block_roi[0])
            chunkshape = numpy.minimum(numpy.array(datashape),
                                       self._chunkshape)
            chunkshape = tuple(chunkshape)

            # Make a compressed dataset
            # lzf should be faster than gzip,
            # with a slightly worse compression ratio
            mem_file.create_dataset("data",
                                    shape=datashape,
                                    dtype=self.Output.meta.dtype,
                                    chunks=chunkshape,
                                    compression="lzf")

            # Add mask information if needed.
            if self.Output.meta.has_mask:
                mem_file.create_dataset("mask",
                                        shape=datashape,
                                        dtype=bool,
                                        chunks=chunkshape,
                                        compression="lzf")
                mem_file.create_dataset("fill_value",
                                        shape=tuple(),
                                        dtype=self.Output.meta.dtype)

            self._blockLocks[block_start] = RequestLock()
            self._cacheFiles[block_start] = mem_file
            self._dirtyBlocks.add(block_start)
        return self._cacheFiles[block_start]


def _ensureCached(self, entire_block_roi):
    """
    Ensure that the cache file for the given block is up-to-date.
    (Refresh it if it's dirty.)
    """
    block_start = tuple(entire_block_roi[0])
    block_file = self._getCacheFile(entire_block_roi)
    if block_start in self._dirtyBlocks:
        updated_cache = False
        with self._blockLocks[block_start]:
            # Check AGAIN now that we have the lock.
            # (Avoid doing this twice in parallel requests.)
            if block_start in self._dirtyBlocks:
                # Can't write directly into the hdf5 dataset because
                # h5py.dataset.__getitem__ creates a copy, not a view.
                # We must use a temporary numpy array to hold the data.
                data = self.Input(*entire_block_roi).wait()
                block_file["data"][...] = data

                if self.Output.meta.has_mask:
                    block_file["mask"][...] = data.mask
                    block_file["fill_value"][...] = data.fill_value

                if logger.isEnabledFor(logging.DEBUG):
                    uncompressed_size = bigintprod(
                        data.shape) * self._getDtypeBytes(data.dtype)
                    storage_size = block_file["data"].id.get_storage_size()
                    if "mask" in block_file:
                        storage_size += block_file[
                            "mask"].id.get_storage_size()
                    if "fill_value" in block_file:
                        storage_size += block_file[
                            "fill_value"].id.get_storage_size()
                    logger.debug(
                        "Storage for block: {} is {}. ({}% of original)".
                        format(block_start, storage_size,
                               100 * storage_size / uncompressed_size))

                with self._lock:
                    self._dirtyBlocks.remove(block_start)
                updated_cache = True

        if updated_cache:
            # Now that the lock is released, signal that the cache was updated.
            self.Output._sig_value_changed()
            self.OutputHdf5._sig_value_changed()
            self.CleanBlocks._sig_value_changed()


def setInSlot(self, slot, subindex, roi, value):
    """
    Overridden from Operator
    """
    if slot == self.Input:
        self._setInSlotInput(slot, subindex, roi, value)
    elif slot == self.InputHdf5:
        self._setInSlotInputHdf5(slot, subindex, roi, value)
    else:
        assert False, "Invalid input slot for setInSlot(): {}".format(
            slot.name)


def _setInSlotInput(self, slot, subindex, roi, value, store_zero_blocks=True):
    """
    Write the data in the array 'value' into the cache.
    If the optional store_zero_blocks param is False, then don't bother
    creating cache blocks for blocks that are totally zero.
    """
    assert len(roi.stop) == len(self.Input.meta.shape), \
        "roi: {} has the wrong number of dimensions for Input shape: {}".format(
            roi, self.Input.meta.shape)
    assert numpy.less_equal(roi.stop, self.Input.meta.shape).all(), \
        "roi: {} is out-of-bounds for Input shape: {}".format(
            roi, self.Input.meta.shape)

    block_starts = getIntersectingBlocks(self._blockshape,
                                         (roi.start, roi.stop))
    block_starts = list(map(tuple, block_starts))

    # Copy data to each block
    logger.debug("Copying data INTO {} blocks...".format(len(block_starts)))
    for block_start in block_starts:
        entire_block_roi = getBlockBounds(self.Output.meta.shape,
                                          self._blockshape, block_start)

        # This block's portion of the roi
        intersecting_roi = getIntersection((roi.start, roi.stop),
                                           entire_block_roi)

        # Compute slicing within source array and slicing within this block
        source_relative_intersection = numpy.subtract(
            intersecting_roi, roi.start)
        block_relative_intersection = numpy.subtract(
            intersecting_roi, block_start)
        source_relative_intersection_slicing = roiToSlice(
            *source_relative_intersection)
        block_relative_intersection_slicing = roiToSlice(
            *block_relative_intersection)

        new_block_data = value[source_relative_intersection_slicing]
        new_block_sum = new_block_data.sum()
        if (not store_zero_blocks and new_block_sum == 0
                and block_start not in self._cacheFiles):
            # Special fast-path: If this block doesn't exist yet,
            # don't bother creating if we're just going to fill it with zeros.
            # (This feature is used by the OpCompressedUserLabelArray)
            pass
        else:
            # Copy from source to block
            dataset = self._getBlockDataset(entire_block_roi)

            if self.Output.meta.has_mask:
                dataset["data"][
                    block_relative_intersection_slicing] = new_block_data.data
                dataset["mask"][
                    block_relative_intersection_slicing] = new_block_data.mask
                dataset["fill_value"][()] = new_block_data.fill_value

                # Untested. Write a test to use this.
                # # If we can, remove this block entirely.
                # if not store_zero_blocks and new_block_sum == 0 and (dataset["data"][:] == 0).all() and (dataset["mask"]).any() and (dataset["fill_value"] == 0).all():
                #     with self._lock:
                #         with self._blockLocks[block_start]:
                #             self._cacheFiles[block_start].close()
                #             del self._cacheFiles[block_start]
                #             del self._blockLocks[block_start]
            else:
                dataset[block_relative_intersection_slicing] = new_block_data

                # If we can, remove this block entirely.
                if (not store_zero_blocks and new_block_sum == 0
                        and (dataset[:] == 0).all()):
                    with self._lock:
                        with self._blockLocks[block_start]:
                            self._cacheFiles[block_start].close()
                            del self._cacheFiles[block_start]
                            del self._blockLocks[block_start]

        # Here, we assume that if this function is used to update ANY PART of a
        # block, he is responsible for updating the ENTIRE block.
        # Therefore, this block is no longer 'dirty'
        self._dirtyBlocks.discard(block_start)

    # self.Output._sig_value_changed()
    # self.OutputHdf5._sig_value_changed()
    # self.CleanBlocks._sig_value_changed()


def _setInSlotInputHdf5(self, slot, subindex, roi, value):
    """
    Write hdf5 data into the cache.

    ``value`` must be an h5py Group (with "data", "mask" and "fill_value")
    if the output carries a mask, otherwise an h5py Dataset.  If the roi
    is exactly one block the hdf5 objects are copied directly; otherwise
    the data is read into memory and written through the Input slot.
    """
    logger.debug("Setting block {} from hdf5".format(roi))
    if self.Output.meta.has_mask:
        assert isinstance(value, h5py.Group), \
            "InputHdf5 slot requires an hdf5 Group to copy from (not a numpy masked array)."
    else:
        assert isinstance(value, h5py.Dataset), \
            "InputHdf5 slot requires an hdf5 Dataset to copy from (not a numpy array)."

    block_roi = getBlockBounds(self.Output.meta.shape, self._blockshape,
                               roi.start)

    roi_is_exactly_one_block = True
    roi_is_exactly_one_block &= ((roi.start % self._blockshape) == 0).all()
    roi_is_exactly_one_block &= (block_roi == numpy.array(
        (roi.start, roi.stop))).all()

    if roi_is_exactly_one_block:
        cachefile = self._getCacheFile(block_roi)
        logger.debug(
            "Copying HDF5 data directly into block {}".format(block_roi))

        if self.Output.meta.has_mask:
            assert len(value) == 3
            # Validate everything first, so nothing is deleted on a mismatch.
            for each in ["data", "mask", "fill_value"]:
                assert each in value
                assert cachefile[each].dtype == value[each].dtype
                assert cachefile[each].shape == value[each].shape

            for each in ["data", "mask", "fill_value"]:
                del cachefile[each]
                cachefile.copy(value[each], each)
        else:
            assert cachefile["data"].dtype == value.dtype
            assert cachefile["data"].shape == value.shape
            del cachefile["data"]
            cachefile.copy(value, "data")

        block_start = tuple(roi.start)
        self._dirtyBlocks.discard(block_start)
    else:
        # This hdf5 data does not correspond to exactly one block.
        # We must uncompress it and write it the "normal" way (the slow way)
        # FIXME: This would use less memory if we uncompressed the data block-by-block
        data = None
        if self.Output.meta.has_mask:
            data = numpy.ma.masked_array(value["data"][()],
                                         mask=value["mask"][()],
                                         fill_value=value["fill_value"][()],
                                         shrink=False)
        else:
            data = value[()]

        self.Input[roiToSlice(roi.start, roi.stop)] = data

    # self.Output._sig_value_changed()
    # self.OutputHdf5._sig_value_changed()
    # self.CleanBlocks._sig_value_changed()


def _getBlockDataset(self, entire_block_roi):
    """
    Get the correct cache file and return the *dataset* handle, not a
    numpy array of its contents.

    With a masked output the file's root group is returned instead
    (it holds "data", "mask" and "fill_value").
    """
    block_file = self._getCacheFile(entire_block_roi)
    if self.Output.meta.has_mask:
        return block_file["/"]
    else:
        return block_file["data"]


def _closeAllCacheFiles(self):
    """Close every cache file and forget all blocks and their locks."""
    logger.debug("Closing all caches")
    cacheFiles = self._cacheFiles
    for k, v in list(cacheFiles.items()):
        with self._blockLocks[k]:
            v.close()
    with self._lock:
        self._blockLocks = {}
        self._cacheFiles = {}
class OpDataSelection(Operator):
    """
    The top-level operator for the data selection applet, implemented as a single-image operator.
    The applet uses an OperatorWrapper to make it suitable for use in a workflow.
    """
    name = "OpDataSelection"
    category = "Top-level"

    SupportedExtensions = OpInputDataReader.SupportedExtensions

    # Inputs
    RoleName = InputSlot(stype='string', value='')
    ProjectFile = InputSlot(stype='object', optional=True)  #: The project hdf5 File object (already opened)
    ProjectDataGroup = InputSlot(stype='string', optional=True)  #: The internal path to the hdf5 group where project-local datasets are stored within the project file
    WorkingDirectory = InputSlot(stype='filestring')  #: The filesystem directory where the project file is located
    Dataset = InputSlot(stype='object')  #: A DatasetInfo object

    # Outputs
    Image = OutputSlot()  #: The output image
    AllowLabels = OutputSlot(stype='bool')  #: A bool indicating whether or not this image can be used for training
    _NonTransposedImage = OutputSlot()  #: The output slot, in the data's original axis ordering (regardless of forceAxisOrder)
    ImageName = OutputSlot(stype='string')  #: The name of the output image

    class InvalidDimensionalityError(Exception):
        """Raised if the user tries to replace the dataset with a new one of differing dimensionality."""

        def __init__(self, message):
            # Forward the message so that e.args / repr() / pickling work.
            super(OpDataSelection.InvalidDimensionalityError, self).__init__(message)
            self.message = message

        def __str__(self):
            return self.message

    def __init__(self, forceAxisOrder=False, *args, **kwargs):
        """
        forceAxisOrder: If given (truthy), the axis order the output Image is
                        re-ordered to.  Non-singleton axes missing from it
                        cause a DatasetConstraintError in setupOutputs().
        """
        super(OpDataSelection, self).__init__(*args, **kwargs)
        self.forceAxisOrder = forceAxisOrder
        # Internal operators created by setupOutputs(), in creation order.
        self._opReaders = []

        # If the gui calls disconnect() on an input slot without replacing it with something else,
        #  we still need to clean up the internal operator that was providing our data.
        self.ProjectFile.notifyUnready(self.internalCleanup)
        self.ProjectDataGroup.notifyUnready(self.internalCleanup)
        self.WorkingDirectory.notifyUnready(self.internalCleanup)
        self.Dataset.notifyUnready(self.internalCleanup)

    def internalCleanup(self, *args):
        """Disconnect the outputs and clean up all internal operators (newest first)."""
        if len(self._opReaders) > 0:
            self.Image.disconnect()
            self._NonTransposedImage.disconnect()
            for reader in reversed(self._opReaders):
                reader.cleanUp()
            self._opReaders = []

    def setupOutputs(self):
        """
        (Re-)build the internal operator chain
        reader -> metadata injector -> optional axis re-ordering
        and connect it to the output slots.

        On any failure the partially built chain is cleaned up and the
        exception is re-raised.
        """
        self.internalCleanup()
        datasetInfo = self.Dataset.value

        try:
            # Data only comes from the project file if the user said so AND it exists in the project
            datasetInProject = (datasetInfo.location == DatasetInfo.Location.ProjectInternal)
            datasetInProject &= self.ProjectFile.ready()
            if datasetInProject:
                internalPath = self.ProjectDataGroup.value + '/' + datasetInfo.datasetId
                datasetInProject &= internalPath in self.ProjectFile.value

            # If we should find the data in the project file, use a dataset reader
            if datasetInProject:
                opReader = OpStreamingHdf5Reader(parent=self)
                opReader.Hdf5File.setValue(self.ProjectFile.value)
                opReader.InternalPath.setValue(internalPath)
                providerSlot = opReader.OutputImage
            elif datasetInfo.location == DatasetInfo.Location.PreloadedArray:
                preloaded_array = datasetInfo.preloaded_array
                assert preloaded_array is not None
                if not hasattr(preloaded_array, 'axistags'):
                    # Guess the axis order, since one was not provided.
                    axisorders = {2: 'yx',
                                  3: 'zyx',
                                  4: 'zyxc',
                                  5: 'tzyxc'}

                    shape = preloaded_array.shape
                    ndim = preloaded_array.ndim
                    assert ndim != 0, "Support for 0-D data not yet supported"
                    assert ndim != 1, "Support for 1-D data not yet supported"
                    assert ndim <= 5, "No support for data with more than 5 dimensions."
                    axisorder = axisorders[ndim]
                    if ndim == 3 and shape[2] <= 4:
                        # Special case: If the 3rd dim is small, assume it's 'c', not 'z'
                        axisorder = 'yxc'
                    preloaded_array = vigra.taggedView(preloaded_array, axisorder)

                opReader = OpArrayPiper(parent=self)
                opReader.Input.setValue(preloaded_array)
                providerSlot = opReader.Output
            else:
                # Use a normal (filesystem) reader
                opReader = OpInputDataReader(parent=self)
                if datasetInfo.subvolume_roi is not None:
                    opReader.SubVolumeRoi.setValue(datasetInfo.subvolume_roi)
                opReader.WorkingDirectory.setValue(self.WorkingDirectory.value)
                opReader.FilePath.setValue(datasetInfo.filePath)
                providerSlot = opReader.Output
            self._opReaders.append(opReader)

            # Inject metadata if the dataset info specified any.
            # Also, inject if dtype is uint8, which we can reasonably assume has drange (0,255)
            metadata = {}
            metadata['display_mode'] = datasetInfo.display_mode
            role_name = self.RoleName.value
            if 'c' not in providerSlot.meta.getTaggedShape():
                num_channels = 0
            else:
                num_channels = providerSlot.meta.getTaggedShape()['c']
            if num_channels > 1:
                metadata['channel_names'] = ["{}-{}".format(role_name, i) for i in range(num_channels)]
            else:
                metadata['channel_names'] = [role_name]

            if datasetInfo.drange is not None:
                metadata['drange'] = datasetInfo.drange
            elif providerSlot.meta.dtype == numpy.uint8:
                # SPECIAL case for uint8 data: Provide a default drange.
                # The user can always override this herself if she wants.
                metadata['drange'] = (0, 255)
            if datasetInfo.normalizeDisplay is not None:
                metadata['normalizeDisplay'] = datasetInfo.normalizeDisplay
            if datasetInfo.axistags is not None:
                if len(datasetInfo.axistags) != len(providerSlot.meta.shape):
                    # This usually only happens when we copied a DatasetInfo from another lane,
                    # and used it as a 'template' to initialize this lane.
                    # This happens in the BatchProcessingApplet when it attempts to guess the axistags of
                    # batch images based on the axistags chosen by the user in the interactive images.
                    # If the interactive image tags don't make sense for the batch image, you get this error.
                    raise Exception("Your dataset's provided axistags ({}) do not have the "
                                    "correct dimensionality for your dataset, which has {} dimensions."
                                    .format("".join(tag.key for tag in datasetInfo.axistags),
                                            len(providerSlot.meta.shape)))
                metadata['axistags'] = datasetInfo.axistags
            if datasetInfo.subvolume_roi is not None:
                metadata['subvolume_roi'] = datasetInfo.subvolume_roi

            # FIXME: We are overwriting the axistags metadata to intentionally allow
            #        the user to change our interpretation of which axis is which.
            #        That's okay, but technically there's a special corner case if
            #        the user redefines the channel axis index.
            #        Technically, it invalidates the meaning of meta.ram_usage_per_requested_pixel.
            #        For most use-cases, that won't really matter, which is why I'm not worrying about it right now.

            opMetadataInjector = OpMetadataInjector(parent=self)
            opMetadataInjector.Input.connect(providerSlot)
            opMetadataInjector.Metadata.setValue(metadata)
            providerSlot = opMetadataInjector.Output
            self._opReaders.append(opMetadataInjector)

            self._NonTransposedImage.connect(providerSlot)

            if self.forceAxisOrder:
                # Before we re-order, make sure no non-singleton
                #  axes would be dropped by the forced order.
                output_order = "".join(self.forceAxisOrder)
                provider_order = "".join(providerSlot.meta.getAxisKeys())
                tagged_provider_shape = providerSlot.meta.getTaggedShape()
                dropped_axes = set(provider_order) - set(output_order)
                if any(tagged_provider_shape[a] > 1 for a in dropped_axes):
                    msg = "The axes of your dataset ({}) are not compatible with the axes used by this workflow ({}). Please fix them."\
                          .format(provider_order, output_order)
                    raise DatasetConstraintError("DataSelection", msg)

                op5 = OpReorderAxes(parent=self)
                op5.AxisOrder.setValue(self.forceAxisOrder)
                op5.Input.connect(providerSlot)
                providerSlot = op5.Output
                self._opReaders.append(op5)

            # If the channel axis is not last (or is missing),
            #  make sure the axes are re-ordered so that channel is last.
            if providerSlot.meta.axistags.index('c') != len(providerSlot.meta.axistags) - 1:
                op5 = OpReorderAxes(parent=self)
                # Materialize the keys as a list: on Python 3, dict.keys() is a
                # view without remove()/append().
                keys = list(providerSlot.meta.getTaggedShape().keys())
                try:
                    # Remove if present.
                    keys.remove('c')
                except ValueError:
                    pass
                # Append
                keys.append('c')
                op5.AxisOrder.setValue("".join(keys))
                op5.Input.connect(providerSlot)
                providerSlot = op5.Output
                self._opReaders.append(op5)

            # Connect our external outputs to the internal operators we chose
            self.Image.connect(providerSlot)

            # Set the image name and usage flag
            self.AllowLabels.setValue(datasetInfo.allowLabels)

            # If the reading operator provides a nickname, use it.
            if self.Image.meta.nickname is not None:
                datasetInfo.nickname = self.Image.meta.nickname

            imageName = datasetInfo.nickname
            if imageName == "":
                imageName = datasetInfo.filePath
            self.ImageName.setValue(imageName)

        except BaseException:
            # Whatever went wrong (including interrupts): don't leave a
            # half-built operator chain behind, then propagate the error.
            self.internalCleanup()
            raise

    def propagateDirty(self, slot, subindex, roi):
        # Output slots are directly connected to internal operators
        pass

    @classmethod
    def getInternalDatasets(cls, filePath):
        """Delegate to OpInputDataReader.getInternalDatasets(filePath)."""
        return OpInputDataReader.getInternalDatasets(filePath)
class OpDataSelection(Operator):
    """
    The top-level operator for the data selection applet, implemented as a single-image operator.
    The applet uses an OperatorWrapper to make it suitable for use in a workflow.
    """
    name = "OpDataSelection"
    category = "Top-level"

    SupportedExtensions = OpInputDataReader.SupportedExtensions

    # Inputs
    RoleName = InputSlot(stype='string', value='')
    # : The project hdf5 File object (already opened)
    ProjectFile = InputSlot(stype='object', optional=True)
    # : The internal path to the hdf5 group where project-local datasets are stored within the project file
    ProjectDataGroup = InputSlot(stype='string', optional=True)
    # : The filesystem directory where the project file is located
    WorkingDirectory = InputSlot(stype='filestring')
    # : A DatasetInfo object
    Dataset = InputSlot(stype='object')

    # Outputs
    # : The output image
    Image = OutputSlot()
    # : A bool indicating whether or not this image can be used for training
    AllowLabels = OutputSlot(stype='bool')
    # : The output slot, in the data's original axis ordering (regardless of forceAxisOrder)
    _NonTransposedImage = OutputSlot()
    # : The name of the output image
    ImageName = OutputSlot(stype='string')

    class InvalidDimensionalityError(Exception):
        """Raised if the user tries to replace the dataset with a new one of differing dimensionality."""

        def __init__(self, message):
            # Forward the message so that e.args / repr() / pickling work.
            super(OpDataSelection.InvalidDimensionalityError, self).__init__(message)
            self.message = message

        def __str__(self):
            return self.message

    # NOTE: the list default is shared between calls, but this class only
    #       ever reads it (setupOutputs() works on its own filtered list).
    def __init__(self, forceAxisOrder=['tczyx'], *args, **kwargs):
        """
        forceAxisOrder: How to auto-reorder the input data before connecting it to the rest of the workflow.
                        Must be a *list* of axis orders that are allowed by the workflow
                        (setupOutputs() asserts this).
                        For example, if the workflow can handle 2D and 3D, you might pass ['yxc', 'zyxc'].
                        If it only handles exactly 5D, you might pass ['tzyxc'], assuming that's how you wrote the
                        workflow.
                        A falsy value keeps the original axis order of the data.
                        todo: move toward 'tczyx' standard.
        """
        super(OpDataSelection, self).__init__(*args, **kwargs)
        self.forceAxisOrder = forceAxisOrder
        # Internal operators created by setupOutputs(), in creation order.
        self._opReaders = []

        # If the gui calls disconnect() on an input slot without replacing it with something else,
        #  we still need to clean up the internal operator that was providing our data.
        self.ProjectFile.notifyUnready(self.internalCleanup)
        self.ProjectDataGroup.notifyUnready(self.internalCleanup)
        self.WorkingDirectory.notifyUnready(self.internalCleanup)
        self.Dataset.notifyUnready(self.internalCleanup)

    def internalCleanup(self, *args):
        """Disconnect the outputs and clean up all internal operators (newest first)."""
        if len(self._opReaders) > 0:
            self.Image.disconnect()
            self._NonTransposedImage.disconnect()
            for reader in reversed(self._opReaders):
                reader.cleanUp()
            self._opReaders = []

    def setupOutputs(self):
        """
        (Re-)build the internal operator chain
        reader -> metadata injector -> axis re-ordering (-> append channel axis)
        and connect it to the output slots.

        Raises DatasetConstraintError if the data lacks an x or y axis, or if
        none of the allowed axis orders can hold its non-singleton axes.
        On any failure the partially built chain is cleaned up and the
        exception is re-raised.
        """
        self.internalCleanup()
        datasetInfo = self.Dataset.value

        try:
            # Data only comes from the project file if the user said so AND it exists in the project
            datasetInProject = (
                datasetInfo.location == DatasetInfo.Location.ProjectInternal)
            datasetInProject &= self.ProjectFile.ready()
            if datasetInProject:
                internalPath = self.ProjectDataGroup.value + '/' + datasetInfo.datasetId
                datasetInProject &= internalPath in self.ProjectFile.value

            # If we should find the data in the project file, use a dataset reader
            if datasetInProject:
                opReader = OpStreamingH5N5Reader(parent=self)
                opReader.H5N5File.setValue(self.ProjectFile.value)
                opReader.InternalPath.setValue(internalPath)
                providerSlot = opReader.OutputImage
            elif datasetInfo.location == DatasetInfo.Location.PreloadedArray:
                preloaded_array = datasetInfo.preloaded_array
                assert preloaded_array is not None
                if not hasattr(preloaded_array, 'axistags'):
                    axisorder = get_default_axisordering(preloaded_array.shape)
                    preloaded_array = vigra.taggedView(preloaded_array, axisorder)

                opReader = OpArrayPiper(parent=self)
                opReader.Input.setValue(preloaded_array)
                providerSlot = opReader.Output
            else:
                if datasetInfo.realDataSource:
                    # Use a normal (filesystem) reader
                    opReader = OpInputDataReader(parent=self)
                    if datasetInfo.subvolume_roi is not None:
                        opReader.SubVolumeRoi.setValue(datasetInfo.subvolume_roi)
                    opReader.WorkingDirectory.setValue(self.WorkingDirectory.value)
                    opReader.SequenceAxis.setValue(datasetInfo.sequenceAxis)
                    opReader.FilePath.setValue(datasetInfo.filePath)
                else:
                    # Use fake reader: allows to run the project in a headless
                    # mode without the raw data
                    opReader = OpZeroDefault(parent=self)
                    opReader.MetaInput.meta = MetaDict(
                        shape=datasetInfo.laneShape,
                        dtype=datasetInfo.laneDtype,
                        drange=datasetInfo.drange,
                        axistags=datasetInfo.axistags)
                    opReader.MetaInput.setValue(
                        numpy.zeros(datasetInfo.laneShape, dtype=datasetInfo.laneDtype))
                providerSlot = opReader.Output
            self._opReaders.append(opReader)

            # Inject metadata if the dataset info specified any.
            # Also, inject if dtype is uint8, which we can reasonably assume has drange (0,255)
            metadata = {}
            metadata['display_mode'] = datasetInfo.display_mode
            role_name = self.RoleName.value
            if 'c' not in providerSlot.meta.getTaggedShape():
                num_channels = 0
            else:
                num_channels = providerSlot.meta.getTaggedShape()['c']
            if num_channels > 1:
                metadata['channel_names'] = [
                    "{}-{}".format(role_name, i) for i in range(num_channels)
                ]
            else:
                metadata['channel_names'] = [role_name]

            if datasetInfo.drange is not None:
                metadata['drange'] = datasetInfo.drange
            elif providerSlot.meta.dtype == numpy.uint8:
                # SPECIAL case for uint8 data: Provide a default drange.
                # The user can always override this herself if she wants.
                metadata['drange'] = (0, 255)
            if datasetInfo.normalizeDisplay is not None:
                metadata['normalizeDisplay'] = datasetInfo.normalizeDisplay
            if datasetInfo.axistags is not None:
                if len(datasetInfo.axistags) != len(providerSlot.meta.shape):
                    ts = providerSlot.meta.getTaggedShape()
                    if 'c' in ts and 'c' not in datasetInfo.axistags and len(
                            datasetInfo.axistags) + 1 == len(ts):
                        # provider has a channel axis, but the template tags lack one (and are
                        # exactly one axis short) => append 'c' to the template tags
                        # fixme: Optimize the axistag guess in BatchProcessingApplet instead of hoping for the best here
                        metadata['axistags'] = vigra.defaultAxistags(
                            ''.join(datasetInfo.axistags.keys()) + 'c')
                    else:
                        # This usually only happens when we copied a DatasetInfo from another lane,
                        # and used it as a 'template' to initialize this lane.
                        # This happens in the BatchProcessingApplet when it attempts to guess the axistags of
                        # batch images based on the axistags chosen by the user in the interactive images.
                        # If the interactive image tags don't make sense for the batch image, you get this error.
                        raise Exception(
                            "Your dataset's provided axistags ({}) do not have the "
                            "correct dimensionality for your dataset, which has {} dimensions."
                            .format(
                                "".join(tag.key for tag in datasetInfo.axistags),
                                len(providerSlot.meta.shape)))
                else:
                    metadata['axistags'] = datasetInfo.axistags
            if datasetInfo.original_axistags is not None:
                metadata['original_axistags'] = datasetInfo.original_axistags

            if datasetInfo.subvolume_roi is not None:
                metadata['subvolume_roi'] = datasetInfo.subvolume_roi

            # FIXME: We are overwriting the axistags metadata to intentionally allow
            #        the user to change our interpretation of which axis is which.
            #        That's okay, but technically there's a special corner case if
            #        the user redefines the channel axis index.
            #        Technically, it invalidates the meaning of meta.ram_usage_per_requested_pixel.
            #        For most use-cases, that won't really matter, which is why I'm not worrying about it right now.

            opMetadataInjector = OpMetadataInjector(parent=self)
            opMetadataInjector.Input.connect(providerSlot)
            opMetadataInjector.Metadata.setValue(metadata)
            providerSlot = opMetadataInjector.Output
            self._opReaders.append(opMetadataInjector)

            self._NonTransposedImage.connect(providerSlot)

            # make sure that x and y axes are present in the selected axis order
            if 'x' not in providerSlot.meta.axistags or 'y' not in providerSlot.meta.axistags:
                raise DatasetConstraintError(
                    "DataSelection",
                    "Data must always have at least the axes x and y for ilastik to work."
                )

            if self.forceAxisOrder:
                assert isinstance(self.forceAxisOrder, list), \
                    "forceAxisOrder should be a *list* of preferred axis orders"

                # Before we re-order, make sure no non-singleton
                #  axes would be dropped by the forced order.
                tagged_provider_shape = providerSlot.meta.getTaggedShape()
                minimal_axes = {
                    key for key, extent in tagged_provider_shape.items()
                    if extent > 1
                }

                # Pick the shortest of the possible 'forced' orders that
                # still contains all the axes of the original dataset.
                candidate_orders = [
                    order for order in self.forceAxisOrder
                    if minimal_axes.issubset(order)
                ]

                if len(candidate_orders) == 0:
                    msg = "The axes of your dataset ({}) are not compatible with any of the allowed"\
                          " axis configurations used by this workflow ({}). Please fix them."\
                          .format(providerSlot.meta.getAxisKeys(), self.forceAxisOrder)
                    raise DatasetConstraintError("DataSelection", msg)

                # min() returns the first of equally short candidates
                output_order = "".join(min(candidate_orders, key=len))
            else:
                # No forced axisorder is supplied. Use original axisorder as
                # output order: it is assumed by the export-applet, that the
                # an OpReorderAxes operator is added in the beginning
                output_order = "".join(providerSlot.meta.axistags.keys())

            op5 = OpReorderAxes(parent=self)
            op5.AxisOrder.setValue(output_order)
            op5.Input.connect(providerSlot)
            providerSlot = op5.Output
            self._opReaders.append(op5)

            # If the channel axis is missing, add it as last axis
            if 'c' not in providerSlot.meta.axistags:
                op5 = OpReorderAxes(parent=self)
                keys = providerSlot.meta.getAxisKeys()

                # Append
                keys.append('c')
                op5.AxisOrder.setValue("".join(keys))
                op5.Input.connect(providerSlot)
                providerSlot = op5.Output
                self._opReaders.append(op5)

            # Connect our external outputs to the internal operators we chose
            self.Image.connect(providerSlot)

            self.AllowLabels.setValue(datasetInfo.allowLabels)

            # If the reading operator provides a nickname, use it.
            if self.Image.meta.nickname is not None:
                datasetInfo.nickname = self.Image.meta.nickname

            imageName = datasetInfo.nickname
            if imageName == "":
                imageName = datasetInfo.filePath
            self.ImageName.setValue(imageName)

        except BaseException:
            # Whatever went wrong (including interrupts): don't leave a
            # half-built operator chain behind, then propagate the error.
            self.internalCleanup()
            raise

    def propagateDirty(self, slot, subindex, roi):
        # Output slots are directly connected to internal operators
        pass

    @classmethod
    def getInternalDatasets(cls, filePath):
        """Delegate to OpInputDataReader.getInternalDatasets(filePath)."""
        return OpInputDataReader.getInternalDatasets(filePath)
class OpCountingBatchResults(OpDataExport):
    """
    OpDataExport subclass that only declares two extra input slots.

    No behaviour is added or overridden here.
    """
    # Add these additional input slots, to be used by the GUI.
    # NOTE(review): presumably the probability-map colours and the label
    # names shown by the GUI -- confirm against the code that sets them.
    PmapColors = InputSlot()
    LabelNames = InputSlot()
class _OpThresholdOneLevel(Operator):
    """
    Internal pipeline: threshold -> label connected components -> filter
    the labels by size.

    All outputs are wired straight to internal operators, so execute()
    is never expected to run.
    """
    name = "_OpThresholdOneLevel"

    InputImage = InputSlot()
    MinSize = InputSlot(stype='int', value=0)
    MaxSize = InputSlot(stype='int', value=1000000)
    Threshold = InputSlot(stype='float', value=0.5)

    Output = OutputSlot()

    #debug output
    BeforeSizeFilter = OutputSlot()

    def __init__(self, *args, **kwargs):
        super(_OpThresholdOneLevel, self).__init__(*args, **kwargs)

        # Stage 1: pixel-wise threshold (its function is set in setupOutputs)
        thresholder = OpPixelOperator(parent=self)
        self._opThresholder = thresholder
        thresholder.Input.connect(self.InputImage)

        # Stage 2: label the thresholded volume
        labeler = OpLabelVolume(parent=self)
        self._opLabeler = labeler
        labeler.Method.setValue(_labeling_impl)
        labeler.Input.connect(thresholder.Output)

        self.BeforeSizeFilter.connect(labeler.Output)

        # Stage 3: drop labels outside [MinSize, MaxSize]; keep label values
        size_filter = OpFilterLabels(parent=self)
        self._opFilter = size_filter
        size_filter.Input.connect(labeler.Output)
        size_filter.MinLabelSize.connect(self.MinSize)
        size_filter.MaxLabelSize.connect(self.MaxSize)
        size_filter.BinaryOut.setValue(False)

        self.Output.connect(size_filter.Output)

    def setupOutputs(self):
        def thresholdToUint8(level, pixels):
            # The threshold is scaled by the upper bound of the data range
            # (only ranges starting at 0 are supported).
            value_range = self.InputImage.meta.drange
            if value_range is not None:
                assert value_range[0] == 0,\
                    "Don't know how to threshold data with this drange."
                level *= value_range[1]

            above = pixels > level
            if pixels.dtype != numpy.uint8:
                return above.astype(numpy.uint8)

            # uint8 input: overwrite it in place (numpy optimizes this!)
            pixels[:] = above
            return pixels

        threshold_function = partial(thresholdToUint8, self.Threshold.value)
        self._opThresholder.Function.setValue(threshold_function)

        # self.Output already has metadata: it is directly connected to self._opFilter.Output

    def execute(self, slot, subindex, roi, result):
        assert False, "Shouldn't get here..."

    def propagateDirty(self, slot, subindex, roi):
        # nothing to do here
        pass

    def setInSlot(self, slot, subindex, roi, value):
        # Nothing to do here.
        # Our Input slots are directly fed into the cache,
        #  so all calls to __setitem__ are forwarded automatically
        pass
class OpResize5D(Operator):
    """
    Resize a 5D image.

    Notes:
    - Input must be 5D, tzyxc
    - Resizing is performed across zyx dimensions only.
      The time and channel dimensions may not be resized (setupOutputs asserts this).
    """
    Input = InputSlot()
    ResizedShape = InputSlot()
    Output = OutputSlot()

    def __init__(self, *args, **kwargs):
        super(OpResize5D, self).__init__(*args, **kwargs)
        # Per-axis factor (output_shape / input_shape); set in setupOutputs().
        self._input_to_output_scales = None
        self.progressSignal = OrderedSignal()

    def setupOutputs(self):
        """Copy the input metadata, override the shape, and compute the per-axis scale factors."""
        assert self.Input.meta.getAxisKeys() == list('tzyxc')

        input_shape = self.Input.meta.shape
        output_shape = self.ResizedShape.value
        assert isinstance(output_shape, tuple)
        assert len(output_shape) == len(input_shape)

        self.Output.meta.assignFrom(self.Input.meta)
        self.Output.meta.shape = output_shape

        self._input_to_output_scales = numpy.array(
            output_shape, dtype=numpy.float32) / input_shape

        axes = self.Input.meta.getAxisKeys()
        if 'c' in axes:
            assert self._input_to_output_scales[ axes.index('c') ] == 1.0, \
                "Resizing the channel dimension is not supported."
        if 't' in axes:
            assert self._input_to_output_scales[ axes.index('t') ] == 1.0, \
                "Resizing the time dimension is not supported (yet)."

    def execute(self, slot, subindex, output_roi, result):
        """
        Fill ``result`` with the resized data for ``output_roi``.

        Each requested time step is fetched from the input and resized
        separately.  ``result`` is returned.
        """
        # Special fast path if no resampling needed
        if self.Input.meta.shape == self.Output.meta.shape:
            self.Input(output_roi.start, output_roi.stop).writeInto(result).wait()
            return result

        # Map output_roi to input_roi
        output_roi = numpy.array((output_roi.start, output_roi.stop))
        input_roi = output_roi / self._input_to_output_scales

        # Convert to int: start is truncated, stop is rounded to the nearest
        # integer (+0.5, then truncated).
        input_roi[1] += 0.5
        input_roi = input_roi.astype(int)

        t_start = output_roi[0][0]
        t_stop = output_roi[1][0]

        def process_timestep(t):
            # Request input and resize it.
            # FIXME: This is not quite correct.  We should request a halo that is wide enough
            #        for the BSpline used by resize().  See vigra docs for BSplineBase.radius()
            step_input_roi = copy.copy(input_roi)
            step_input_roi[0][0] = t
            step_input_roi[1][0] = t + 1

            step_input_data = self.Input(*step_input_roi).wait()
            step_input_data = vigra.taggedView(step_input_data, 'tzyxc')

            # Drop singleton spatial axes (index 0) but always keep the channel axis.
            step_shape_4d = numpy.array(step_input_data[0].shape)
            step_shape_4d_nochannel = step_shape_4d[:-1]
            squeezed_slicing = numpy.where(step_shape_4d_nochannel == 1, 0, slice(None))
            squeezed_slicing = tuple(squeezed_slicing) + (slice(None), )

            step_input_squeezed = step_input_data[0][squeezed_slicing]

            # `result` only covers the requested roi, so its time axis starts
            # at t_start, not at 0: index it relative to the roi start.
            result_step = result[t - t_start][squeezed_slicing]
            # vigra assumes wrong axis order if we don't specify one explicitly here...
            result_step = vigra.taggedView(result_step, step_input_squeezed.axistags)

            if self.Input.meta.dtype == numpy.float32:
                vigra.sampling.resize(step_input_squeezed, out=result_step)
            else:
                # Resize in float32, then round back into the result's dtype.
                step_input_squeezed = step_input_squeezed.astype(numpy.float32)
                result_float = vigra.sampling.resize(
                    step_input_squeezed, shape=result_step.shape[:-1])
                result_step[:] = result_float.round()

        # FIXME: Progress here will not be correct for multiple threads.
        self.progressSignal(0)

        # FIXME: request pool...
        for t in range(t_start, t_stop):
            process_timestep(t)
            progress = 100 * (t - t_start) / (t_stop - t_start)
            self.progressSignal(int(progress))

        self.progressSignal(100)
        return result

    def propagateDirty(self, slot, subindex, input_roi):
        # FIXME: When execute() is fixed to use a halo, we should also
        #        incorporate the halo into this dirty propagation logic.

        # Map input_roi to output_roi
        input_roi = numpy.array((input_roi.start, input_roi.stop))
        output_roi = input_roi * self._input_to_output_scales

        # Convert to int: start is truncated, stop is rounded to the nearest
        # integer (+0.5, then truncated).
        output_roi[1] += 0.5
        output_roi = output_roi.astype(int)

        self.Output.setDirty(*output_roi)
class OpExportMultipageTiffSequence(Operator):
    """
    Export a volume as a sequence of multipage tiff files, one file per step
    along the first non-singleton axis of the input.
    """
    # The last two non-singleton axes (except 'c') are the axes of the slices.
    # Re-order the axes yourself if you want an alternative slicing direction
    Input = InputSlot()

    # A complete filepath including a {slice_index} member and a valid file extension.
    FilepathPattern = InputSlot()

    # Added to the {slice_index} in the export filename.
    SliceIndexOffset = InputSlot(value=0)

    def __init__(self, *args, **kwargs):
        super(OpExportMultipageTiffSequence, self).__init__(*args, **kwargs)
        self.progressSignal = OrderedSignal()

    def run_export(self):
        """
        Request the volume in slices (running in parallel), and write each slice to the correct page.
        Note: We can't use BigRequestStreamer here, because the data for each slice wouldn't be
              guaranteed to arrive in the correct order.
        """
        # Make the directory first if necessary.
        # (A pattern without a directory part yields '', which must not be passed to makedirs.)
        export_dir = os.path.split(self.FilepathPattern.value)[0]
        if export_dir and not os.path.exists(export_dir):
            os.makedirs(export_dir)

        # Blockshape is the same as the input shape, except for the sliced dimension
        step_axis = self._volume_axes[0]
        tagged_blockshape = self.Input.meta.getTaggedShape()
        tagged_blockshape[step_axis] = 1
        block_shape = (list(tagged_blockshape.values()))
        logger.debug(
            "Starting Multipage Sequence Export with block shape: {}".format(
                block_shape))

        # Block step is all zeros except step axis, e.g. (0, 1, 0, 0, 0)
        block_step = numpy.array(self.Input.meta.getAxisKeys()) == step_axis
        block_step = block_step.astype(int)

        filepattern = self.FilepathPattern.value

        # If the user didn't provide custom formatting for the slice field,
        #  auto-format to include zero-padding
        if '{slice_index}' in filepattern:
            filepattern = filepattern.format(
                slice_index='{' + 'slice_index:0{}'.format(self._max_slice_digits) + '}')

        self.progressSignal(0)

        # Nothing fancy here: Just loop over the blocks in order.
        tagged_shape = self.Input.meta.getTaggedShape()
        num_blocks = tagged_shape[step_axis]
        for block_index in range(num_blocks):
            roi = numpy.array(roiFromShape(block_shape))
            roi += block_index * block_step
            roi = list(map(tuple, roi))

            # Pre-bind the temporaries so that the cleanup below never touches
            # an unbound name (or an operator left over from the previous
            # iteration) if construction fails part-way.
            opSubregion = None
            opExportBlock = None
            try:
                opSubregion = OpSubRegion(parent=self)
                opSubregion.Roi.setValue(roi)
                opSubregion.Input.connect(self.Input)

                formatted_path = filepattern.format(
                    slice_index=(block_index + self.SliceIndexOffset.value))
                opExportBlock = OpExportMultipageTiff(parent=self)
                opExportBlock.Input.connect(opSubregion.Output)
                opExportBlock.Filepath.setValue(formatted_path)

                block_start_progress = 100 * block_index // num_blocks

                def _handleBlockProgress(block_progress):
                    # Scale the block's own 0..100 progress into its share of the total.
                    self.progressSignal(block_start_progress +
                                        block_progress // num_blocks)

                opExportBlock.progressSignal.subscribe(_handleBlockProgress)

                # Run the export for this block
                opExportBlock.run_export()
            finally:
                if opExportBlock is not None:
                    opExportBlock.cleanUp()
                if opSubregion is not None:
                    opSubregion.cleanUp()

        self.progressSignal(100)

    def setupOutputs(self):
        """Cache the non-singleton axes and the slice-index digit count, and validate the inputs."""
        # If stacking XY images in Z-steps,
        #  then self._volume_axes = 'zxy'
        self._volume_axes = self.get_nonsingleton_axes()
        step_axis = self._volume_axes[0]
        max_slice = self.SliceIndexOffset.value + self.Input.meta.getTaggedShape(
        )[step_axis]
        self._max_slice_digits = int(math.ceil(math.log10(max_slice + 1)))

        # Check for errors
        assert len(self._volume_axes) == 4 or len(self._volume_axes) == 5 and 'c' in self._volume_axes[1:], \
            "Exported stacks must have exactly 4 non-singleton dimensions (other than the channel dimension).  "\
            "You stack dimensions are: {}".format( self.Input.meta.getTaggedShape() )

        # Test to make sure the filepath pattern includes slice index field
        filepath_pattern = self.FilepathPattern.value
        assert '123456789' in filepath_pattern.format( slice_index=123456789 ), \
            "Output filepath pattern must contain the '{slice_index}' field for formatting.\n"\
            "Your format was: {}".format( filepath_pattern )

    # No output slots...
    def execute(self, slot, subindex, roi, result):
        pass

    def propagateDirty(self, slot, subindex, roi):
        pass

    def get_nonsingleton_axes(self):
        """Return the keys of the Input axes whose extent is greater than 1."""
        return self.get_nonsingleton_axes_for_tagged_shape(
            self.Input.meta.getTaggedShape())

    @classmethod
    def get_nonsingleton_axes_for_tagged_shape(cls, tagged_shape):
        """
        Return a tuple of the axis keys in ``tagged_shape`` whose size is > 1,
        in their original order.  An all-singleton shape yields an empty tuple.
        """
        # Find the non-singleton axes.
        # The first non-singleton axis is the step axis.
        # The last 2 non-channel non-singleton axes will be the axes of the slices.
        return tuple(key for key, size in tagged_shape.items() if size > 1)
class OpStreamingUfmfReader(Operator):
    """
    Imports videos in uFMF format. For more information refer to the Ctrax file tracker:
    http://ctrax.sourceforge.net/
    """
    name = "OpStreamingUfmfReader"
    category = "Input"

    # Index of the frame currently held in self.frame (None: nothing cached).
    position = None
    # Currently open movie reader (None until setupOutputs() has opened a file).
    fmf = None

    FileName = InputSlot(stype="filestring")
    Output = OutputSlot()

    class DatasetReadError(Exception):
        pass

    def __init__(self, *args, **kwargs):
        super(OpStreamingUfmfReader, self).__init__(*args, **kwargs)
        # Serializes seek + read on the shared reader in execute().
        self._lock = threading.Lock()

    def setupOutputs(self):
        """
        Load the file specified via our input slot and present its data on the output slot.

        Raises:
            DatasetReadError: if not even a first frame can be read from the file.
        """
        # Release the reader of a previously configured file before replacing it.
        if self.fmf is not None:
            self.fmf.close()
            self.fmf = None

        fileName = self.FileName.value
        self.fmf = UfmfParser.FlyMovieEmulator(str(fileName))
        # The cached frame (if any) belongs to the old file: force a re-seek.
        self.position = None
        frameNum = self.fmf.get_n_frames()

        try:
            self.frame, timestamp = self.fmf.get_next_frame()
        except FMF.NoMoreFramesException as err:
            logger.info("Error reading uFMF frame.")
            # Without a frame there is no dtype/shape to publish below.
            raise self.DatasetReadError(
                "Could not read any frame from uFMF file: {}".format(fileName)) from err

        self.Output.meta.dtype = self.frame.dtype.type
        self.Output.meta.axistags = vigra.defaultAxistags(AXIS_ORDER)
        self.Output.meta.shape = (frameNum, self.frame.shape[0],
                                  self.frame.shape[1], 1)
        self.Output.meta.ideal_blockshape = (1, ) + self.Output.meta.shape[1:]

    def execute(self, slot, subindex, roi, result):
        """Copy the requested (t, y, x) region into ``result``, one frame at a time."""
        start, stop = roi.start, roi.stop

        tStart, tStop = start[0], stop[0]
        yStart, yStop = start[1], stop[1]
        xStart, xStop = start[2], stop[2]

        for tResult, tFrame in enumerate(range(tStart, tStop)):
            with self._lock:
                # Only seek/decode when the cached frame is not the one requested.
                if self.position != tFrame:
                    self.position = tFrame
                    self.fmf.seek(tFrame)
                    self.frame, timestamp = self.fmf.get_next_frame()
                result[tResult, ..., 0] = self.frame[yStart:yStop, xStart:xStop]

    def propagateDirty(self, slot, subindex, roi):
        if slot == self.FileName:
            self.Output.setDirty(slice(None))

    def cleanUp(self):
        # The reader only exists once setupOutputs() has run.
        if self.fmf is not None:
            self.fmf.close()
            self.fmf = None
        super(OpStreamingUfmfReader, self).cleanUp()
class OpExportSlot(Operator):
    """
    Export a slot 'as-is', i.e. no subregion, no dtype conversion, no normalization, no axis re-ordering, etc.

    For sequence export formats, the sequence is indexed by the axistags' FIRST axis.
    For example, txyzc produces a sequence of xyzc volumes.
    """

    Input = InputSlot()

    OutputFormat = InputSlot(value="hdf5")  # string.  See formats, below
    OutputFilenameFormat = (
        InputSlot()
    )  # A format string allowing {roi}, {t_start}, {t_stop}, etc (but not {nickname} or {dataset_dir})
    OutputInternalPath = InputSlot(value="exported_data")

    CoordinateOffset = InputSlot(
        optional=True
    )  # Add an offset to the roi coordinates in the export path (useful if Input is a subregion of a larger dataset)

    ExportPath = OutputSlot()
    FormatSelectionErrorMsg = OutputSlot()

    _2d_exts = vigra.impex.listExtensions().split()

    # List all supported formats
    _2d_formats = [FormatInfo(ext, ext, 2, 2) for ext in _2d_exts]
    _3d_sequence_formats = [
        FormatInfo(ext + " sequence", ext, 3, 3) for ext in _2d_exts
    ]
    _3d_volume_formats = [FormatInfo("multipage tiff", "tiff", 3, 3)]
    _4d_sequence_formats = [
        FormatInfo("multipage tiff sequence", "tiff", 4, 4)
    ]
    nd_format_formats = [
        FormatInfo("hdf5", "h5", 0, 5),
        FormatInfo("compressed hdf5", "h5", 0, 5),
        FormatInfo("n5", "n5", 0, 5),
        FormatInfo("compressed n5", "n5", 0, 5),
        FormatInfo("numpy", "npy", 0, 5),
        FormatInfo("dvid", "", 2, 5),
        FormatInfo("blockwise hdf5", "json", 0, 5),
    ]

    ALL_FORMATS = _2d_formats + _3d_sequence_formats + _3d_volume_formats + _4d_sequence_formats + nd_format_formats

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.progressSignal = OrderedSignal()

        # Set up the impl function lookup dict:
        #   format name -> (file extension, zero-argument export callable)
        export_impls = {}
        export_impls["hdf5"] = ("h5", self._export_h5n5)
        export_impls["compressed hdf5"] = ("h5", partial(self._export_h5n5, True))
        export_impls["n5"] = ("n5", self._export_h5n5)
        export_impls["compressed n5"] = ("n5", partial(self._export_h5n5, True))
        export_impls["numpy"] = ("npy", self._export_npy)
        export_impls["dvid"] = ("", self._export_dvid)
        export_impls["blockwise hdf5"] = ("json", self._export_blockwise_hdf5)

        for fmt in self._2d_formats:
            export_impls[fmt.name] = (fmt.extension,
                                      partial(self._export_2d, fmt.extension))

        for fmt in self._3d_sequence_formats:
            export_impls[fmt.name] = (fmt.extension,
                                      partial(self._export_3d_sequence, fmt.extension))

        export_impls["multipage tiff"] = ("tiff", self._export_multipage_tiff)
        export_impls["multipage tiff sequence"] = (
            "tiff", self._export_multipage_tiff_sequence)
        self._export_impls = export_impls

        self.Input.notifyMetaChanged(self._updateFormatSelectionErrorMsg)

    def setupOutputs(self):
        self.ExportPath.meta.shape = (1, )
        self.ExportPath.meta.dtype = object
        self.FormatSelectionErrorMsg.meta.shape = (1, )
        self.FormatSelectionErrorMsg.meta.dtype = object

        # An hdf5 export without an internal dataset path cannot produce a path yet.
        if self.OutputFormat.value in (
                "hdf5", "compressed hdf5") and self.OutputInternalPath.value == "":
            self.ExportPath.meta.NOTREADY = True

    def execute(self, slot, subindex, roi, result):
        if slot == self.ExportPath:
            return self._executeExportPath(result)
        else:
            assert False, "Unknown output slot: {}".format(slot.name)

    def _executeExportPath(self, result):
        """
        Compute the fully formatted export path and store it in ``result[0]``.

        The filename format gets the extension of the selected output format,
        the internal dataset path (for hdf5/n5), and the roi-derived
        placeholders ({roi}, {<axis>_start}, {<axis>_stop}) filled in.
        """
        path_format = self.OutputFilenameFormat.value
        file_extension = self._export_impls[self.OutputFormat.value][0]

        # Remove existing extension (if present) and add the correct extension (if any)
        if file_extension:
            path_format = os.path.splitext(path_format)[0]
            path_format += "." + file_extension

        # Provide the TOTAL path (including dataset name)
        if self.OutputFormat.value in ("hdf5", "compressed hdf5", "n5", "compressed n5"):
            path_format += "/" + self.OutputInternalPath.value

        roi = numpy.array(roiFromShape(self.Input.meta.shape))

        # Intermediate state can cause coordinate offset and input shape to be mismatched.
        # Just don't use the offset if it looks wrong.
        # (The client will provide a valid offset later on.)
        if self.CoordinateOffset.ready() and len(
                self.CoordinateOffset.value) == len(roi[0]):
            offset = self.CoordinateOffset.value
            assert len(roi[0]) == len(offset)
            roi += offset

        optional_replacements = {}
        optional_replacements["roi"] = list(map(tuple, roi))
        for key, (start, stop) in zip(self.Input.meta.getAxisKeys(), roi.transpose()):
            optional_replacements[key + "_start"] = start
            optional_replacements[key + "_stop"] = stop
        formatted_path = format_known_keys(path_format, optional_replacements, strict=False)
        result[0] = formatted_path
        return result

    def _updateFormatSelectionErrorMsg(self, *args):
        error_msg = self._get_format_selection_error_msg()
        self.FormatSelectionErrorMsg.setValue(error_msg)

    def _get_format_selection_error_msg(self, *args):
        """
        If the currently selected format does not support the input image format,
        return an error message stating why. Otherwise, return an empty string.
        """
        if not self.Input.ready():
            return "Input not ready"
        output_format = self.OutputFormat.value

        # These cases support all combinations.
        # The numpy format is registered under the name "numpy" ("npy" is only
        # its file extension); "npy" is kept for backward compatibility.
        if output_format in ("hdf5", "compressed hdf5", "n5", "compressed n5",
                             "numpy", "npy", "blockwise hdf5"):
            return ""

        tagged_shape = self.Input.meta.getTaggedShape()
        output_dtype = self.Input.meta.dtype

        if output_format == "dvid":
            # dvid requires a channel axis, which must come last.
            # Internally, we transpose it before sending it over the wire
            if list(tagged_shape.keys())[-1] != "c":
                return "DVID requires the last axis to be channel."

            # Make sure DVID supports this dtype/channel combo.
            from libdvid.voxels import VoxelsMetadata

            axiskeys = self.Input.meta.getAxisKeys()
            # We reverse the axiskeys because the export operator (see below) uses transpose_axes=True
            reverse_axiskeys = "".join(reversed(axiskeys))
            reverse_shape = tuple(reversed(self.Input.meta.shape))
            metainfo = VoxelsMetadata.create_default_metadata(
                reverse_shape, output_dtype, reverse_axiskeys, 0.0, "nanometers")
            try:
                metainfo.determine_dvid_typename()
            except Exception as ex:
                # The exception text is the user-facing explanation.
                return str(ex)
            else:
                return ""

        return FormatValidity.check(self.Input.meta.getTaggedShape(),
                                    self.Input.meta.dtype, output_format)

    def propagateDirty(self, slot, subindex, roi):
        if slot == self.OutputFormat or slot == self.OutputFilenameFormat:
            self.ExportPath.setDirty()
        if slot == self.OutputFormat:
            self._updateFormatSelectionErrorMsg()

    def run_export_to_array(self):
        """
        Export the slot data to an array, instead of to disk.
        The data is computed blockwise, as necessary.
        The result is returned.
        """
        self.progressSignal(0)
        opExport = OpExportToArray(parent=self)
        try:
            opExport.progressSignal.subscribe(self.progressSignal)
            opExport.Input.connect(self.Input)
            return opExport.run_export_to_array()
        finally:
            opExport.cleanUp()
            self.progressSignal(100)

    def run_export(self):
        """
        Perform the export and WAIT for it to complete.
        If you want asynchronous execution, run this function in a request:

            req = Request( opExport.run_export )
            req.submit()
        """
        output_format = self.OutputFormat.value
        try:
            export_func = self._export_impls[output_format][1]
        except KeyError as e:
            raise Exception(f"Unknown export format: {output_format}") from e
        else:
            mkdir_p(PathComponents(self.ExportPath.value).externalDirectory)
            export_func()

    def _export_h5n5(self, compress=False):
        """Write Input to an hdf5/n5 file at ExportPath, replacing any existing file."""
        self.progressSignal(0)

        # Create and open the hdf5/n5 file
        export_components = PathComponents(self.ExportPath.value)
        try:
            if os.path.isdir(export_components.externalPath):
                # externalPath leads to a n5 file;
                # n5 is stored as a directory structure
                shutil.rmtree(export_components.externalPath)
            else:
                os.remove(export_components.externalPath)
        except FileNotFoundError:
            # It's okay if the file isn't there.
            pass

        try:
            with OpStreamingH5N5Reader.get_h5_n5_file(
                    export_components.externalPath, "w") as h5N5File:
                # Create a temporary operator to do the work for us
                opH5N5Writer = OpH5N5WriterBigDataset(parent=self)
                try:
                    opH5N5Writer.CompressionEnabled.setValue(compress)
                    opH5N5Writer.h5N5File.setValue(h5N5File)
                    opH5N5Writer.h5N5Path.setValue(
                        export_components.internalPath)
                    opH5N5Writer.Image.connect(self.Input)

                    # The H5 Writer provides it's own progress signal, so just connect ours to it.
                    opH5N5Writer.progressSignal.subscribe(self.progressSignal)

                    # Perform the export and block for it in THIS THREAD.
                    opH5N5Writer.WriteImage[:].wait()
                finally:
                    opH5N5Writer.cleanUp()
                    self.progressSignal(100)
        except IOError as ex:
            import sys

            msg = "\nException raised when attempting to export to {}: {}\n".format(
                export_components.externalPath, str(ex))
            sys.stderr.write(msg)
            raise

    def _export_npy(self):
        self.progressSignal(0)
        export_path = self.ExportPath.value
        # Construct outside the try block so that a failing constructor is not
        # masked by a NameError from the cleanup code.
        opWriter = OpNpyWriter(parent=self)
        try:
            opWriter.Filepath.setValue(export_path)
            opWriter.Input.connect(self.Input)

            # Run the export in this thread
            opWriter.write()
        finally:
            opWriter.cleanUp()
            self.progressSignal(100)

    def _export_dvid(self):
        self.progressSignal(0)
        export_path = self.ExportPath.value

        opExport = OpExportDvidVolume(transpose_axes=True, parent=self)
        try:
            opExport.Input.connect(self.Input)
            opExport.NodeDataUrl.setValue(export_path)

            # Run the export in this thread
            opExport.run_export()
        finally:
            opExport.cleanUp()
            self.progressSignal(100)

    def _export_blockwise_hdf5(self):
        raise NotImplementedError

    def _export_2d(self, fmt):
        # `fmt` (the file extension) is bound by the lookup table in __init__
        # but not needed here: ExportPath already carries the extension.
        self.progressSignal(0)
        export_path = self.ExportPath.value
        opExport = OpExport2DImage(parent=self)
        try:
            opExport.progressSignal.subscribe(self.progressSignal)
            opExport.Filepath.setValue(export_path)
            opExport.Input.connect(self.Input)

            # Run the export
            opExport.run_export()
        finally:
            opExport.cleanUp()
            self.progressSignal(100)

    def _export_3d_sequence(self, extension):
        self.progressSignal(0)
        export_path_base = os.path.splitext(self.ExportPath.value)[0]
        export_path_pattern = export_path_base + "." + extension

        opWriter = OpStackWriter(parent=self)
        try:
            opWriter.FilepathPattern.setValue(export_path_pattern)
            opWriter.Input.connect(self.Input)
            opWriter.progressSignal.subscribe(self.progressSignal)

            if self.CoordinateOffset.ready():
                # Number the slices relative to the offset along the step axis.
                step_axis = opWriter.get_nonsingleton_axes()[0]
                step_axis_index = self.Input.meta.getAxisKeys().index(
                    step_axis)
                step_axis_offset = self.CoordinateOffset.value[step_axis_index]
                opWriter.SliceIndexOffset.setValue(step_axis_offset)

            # Run the export
            opWriter.run_export()
        finally:
            opWriter.cleanUp()
            self.progressSignal(100)

    def _export_multipage_tiff(self):
        self.progressSignal(0)
        export_path = self.ExportPath.value
        opExport = OpExportMultipageTiff(parent=self)
        try:
            opExport.Filepath.setValue(export_path)
            opExport.Input.connect(self.Input)
            opExport.progressSignal.subscribe(self.progressSignal)

            # Run the export
            opExport.run_export()
        finally:
            opExport.cleanUp()
            self.progressSignal(100)

    def _export_multipage_tiff_sequence(self):
        self.progressSignal(0)
        export_path_base = os.path.splitext(self.ExportPath.value)[0]
        export_path_pattern = export_path_base + ".tiff"

        opExport = OpExportMultipageTiffSequence(parent=self)
        try:
            opExport.FilepathPattern.setValue(export_path_pattern)
            opExport.Input.connect(self.Input)
            opExport.progressSignal.subscribe(self.progressSignal)

            if self.CoordinateOffset.ready():
                # Number the slices relative to the offset along the step axis.
                step_axis = opExport.get_nonsingleton_axes()[0]
                step_axis_index = self.Input.meta.getAxisKeys().index(
                    step_axis)
                step_axis_offset = self.CoordinateOffset.value[step_axis_index]
                opExport.SliceIndexOffset.setValue(step_axis_offset)

            # Run the export
            opExport.run_export()
        finally:
            opExport.cleanUp()
            self.progressSignal(100)
class _OpVigraLabelVolume(Operator):
    """
    Operator that simply wraps vigra's labelVolume function.
    """
    name = "OpVigraLabelVolume"
    category = "Vigra"

    Input = InputSlot()
    BackgroundValue = InputSlot(optional=True)

    Output = OutputSlot()

    def setupOutputs(self):
        inputShape = self.Input.meta.shape

        # Must have at most 1 time slice
        timeIndex = self.Input.meta.axistags.index("t")
        assert timeIndex == len(inputShape) or inputShape[timeIndex] == 1

        # Must have at most 1 channel
        channelIndex = self.Input.meta.axistags.channelIndex
        assert channelIndex == len(inputShape) or inputShape[channelIndex] == 1

        self.Output.meta.assignFrom(self.Input.meta)
        self.Output.meta.dtype = numpy.uint32

    def execute(self, slot, subindex, roi, destination):
        """
        Label the requested region of Input and write the labels into ``destination``.

        The 2D or 3D vigra labeling function is chosen by the number of axes
        left after dropping the 't' and 'c' axes.  ``destination`` is returned.
        """
        assert slot == self.Output

        resultView = destination.view(vigra.VigraArray)
        resultView.axistags = self.Input.meta.axistags

        inputData = self.Input(roi.start, roi.stop).wait()
        inputData = inputData.view(vigra.VigraArray)
        inputData.axistags = self.Input.meta.axistags

        # Drop the time axis, which vigra.labelVolume doesn't remove automatically
        axiskeys = [tag.key for tag in inputData.axistags]
        if "t" in axiskeys:
            inputData = inputData.bindAxis("t", 0)
            resultView = resultView.bindAxis("t", 0)

        # Drop the channel axis, too.
        if "c" in axiskeys:
            inputData = inputData.bindAxis("c", 0)
            resultView = resultView.bindAxis("c", 0)

        # I have no idea why, but vigra sometimes throws a precondition error if this line is present.
        # ...on the other hand, I can't remember why I added this line in the first place...
        # inputData = inputData.view(numpy.ndarray)

        is_2d = len(inputData.shape) == 2

        if self.BackgroundValue.ready():
            bg = self.BackgroundValue.value
            if isinstance(bg, numpy.ndarray):
                # If background value was given as a 1-element array, extract it.
                assert bg.size == 1
                bg = bg.squeeze()[()]
            # Hand vigra a plain Python scalar.
            # (`numpy.float` no longer exists in current NumPy releases;
            #  numpy.floating additionally covers float16/float32 scalars.)
            if isinstance(bg, (float, numpy.floating)):
                bg = float(bg)
            else:
                bg = int(bg)
            if is_2d:
                vigra.analysis.labelImageWithBackground(inputData,
                                                        background_value=bg,
                                                        out=resultView)
            else:
                vigra.analysis.labelVolumeWithBackground(inputData,
                                                         background_value=bg,
                                                         out=resultView)
        else:
            # No explicit background: rely on vigra's default background value.
            if is_2d:
                vigra.analysis.labelImageWithBackground(inputData,
                                                        out=resultView)
            else:
                vigra.analysis.labelVolumeWithBackground(inputData,
                                                         out=resultView)

        return destination

    def propagateDirty(self, inputSlot, subindex, roi):
        # If anything changed (input data or background value), the whole image
        # is now dirty because a single pixel change can trigger a cascade of
        # relabeling.
        if inputSlot == self.Input or inputSlot == self.BackgroundValue:
            self.Output.setDirty(slice(None))
class OpAnisotropicGaussianSmoothing5d(Operator):
    """
    Gaussian smoothing with a separate sigma per spatial axis, applied
    independently to every (time, channel) slice of a 5d 'txyzc' volume.
    """
    # raw volume, in 5d 'txyzc' order
    Input = InputSlot()
    Sigmas = InputSlot(value={'x': 1.0, 'y': 1.0, 'z': 1.0})

    Output = OutputSlot()

    def setupOutputs(self):
        self.Output.meta.assignFrom(self.Input.meta)
        self.Output.meta.dtype = numpy.float32  # vigra gaussian only supports float32
        self._sigmas = self.Sigmas.value
        assert isinstance(self.Sigmas.value, dict), "Sigmas slot expects a dict"
        assert set(self._sigmas.keys()) == set(
            'xyz'), "Sigmas slot expects three key-value pairs for x,y,z"

    def execute(self, slot, subindex, roi, result):
        """Smooth the requested roi and write the outcome into ``result``."""
        assert all(roi.stop <= self.Input.meta.shape),\
            "Requested roi {} is too large for this input image of shape {}.".format(roi, self.Input.meta.shape)

        # Determine how much input data we'll need, and where the result will be
        # relative to that input roi
        # inputRoi is a 5d roi, computeRoi depends on the number of singletons
        # in shape, but is at most 3d
        inputRoi, computeRoi = self._getInputComputeRois(roi)

        # Obtain the input data
        with Timer() as resultTimer:
            data = self.Input(*inputRoi).wait()
        logger.debug("Obtaining input data took {} seconds for roi {}".format(
            resultTimer.seconds(), inputRoi))
        data = vigra.taggedView(data, axistags='txyzc')

        # input is in txyzc order
        tIndex = 0
        cIndex = 4

        # Must be float32
        if data.dtype != numpy.float32:
            data = data.astype(numpy.float32)

        # we need to remove a singleton z axis, otherwise we get
        # 'kernel longer than line' errors
        ts = self.Input.meta.getTaggedShape()
        tags = [k for k in 'xyz' if ts[k] > 1]
        sigma = [self._sigmas[k] for k in tags]

        # Check if we need to smooth
        if any([x < 0.1 for x in sigma]):
            # just pipe the input through
            # NOTE(review): `data` was requested for inputRoi (roi plus halo);
            # this assignment presumably only works when that halo is empty --
            # confirm against enlargeRoiForHalo.
            result[...] = data
            return

        for i, t in enumerate(range(roi.start[tIndex], roi.stop[tIndex])):
            for j, c in enumerate(range(roi.start[cIndex], roi.stop[cIndex])):
                # prepare the result as an argument
                resview = vigra.taggedView(result[i, ..., j], axistags='xyz')
                dataview = data[i, ..., j]
                # TODO make this general, not just for z axis
                resview = resview.withAxes(*tags)
                dataview = dataview.withAxes(*tags)

                # Smooth the input data
                vigra.filters.gaussianSmoothing(dataview,
                                                sigma,
                                                window_size=2.0,
                                                roi=computeRoi,
                                                out=resview)

    def _getInputComputeRois(self, roi):
        """
        Return ``(inputRoi, computeRoi)`` for the requested output ``roi``.

        inputRoi is the 5d roi (as two lists) to request from Input: ``roi``
        enlarged by the halo computed by enlargeRoiForHalo.  computeRoi is
        the position of ``roi`` inside that enlarged region, restricted to the
        non-singleton spatial axes and expressed as two tuples of Python ints.
        """
        shape = self.Input.meta.shape
        start = numpy.asarray(roi.start)
        stop = numpy.asarray(roi.stop)
        n = len(stop)

        # No halo along t and c; per-axis sigma along x, y, z.
        sigma = [0] + list(map(self._sigmas.get, 'xyz')) + [0]
        inputSpatialRoi = enlargeRoiForHalo(roi.start,
                                            roi.stop,
                                            shape,
                                            sigma,
                                            window=2.0)

        # Determine the roi within the input data we're going to request
        inputRoiOffset = roi.start - inputSpatialRoi[0]
        computeRoi = [inputRoiOffset, inputRoiOffset + stop - start]
        for i in (0, 1):
            # Keep only non-singleton spatial axes (skip t at 0 and c at 4).
            computeRoi[i] = [
                computeRoi[i][j] for j in range(n)
                if shape[j] > 1 and j not in (0, 4)
            ]

        # make sure that vigra understands our integer types
        computeRoi = (tuple(map(int, computeRoi[0])),
                      tuple(map(int, computeRoi[1])))

        inputRoi = (list(inputSpatialRoi[0]), list(inputSpatialRoi[1]))
        return inputRoi, computeRoi

    def propagateDirty(self, slot, subindex, roi):
        if slot == self.Input:
            # Halo calculation is bidirectional, so we can re-use the function
            # that computes the halo during execute()
            inputRoi, _ = self._getInputComputeRois(roi)
            self.Output.setDirty(inputRoi[0], inputRoi[1])
        elif slot == self.Sigmas:
            self.Output.setDirty(slice(None))
        else:
            assert False, "Unknown input slot: {}".format(slot.name)
class _OpThresholdTwoLevels(Operator):
    """
    Two-threshold (hysteresis-style) segmentation pipeline built entirely from
    internal operators; see the schematic below for the wiring.
    """
    name = "_OpThresholdTwoLevels"

    InputImage = InputSlot()
    MinSize = InputSlot(stype='int', value=0)
    MaxSize = InputSlot(stype='int', value=1000000)
    HighThreshold = InputSlot(stype='float', value=0.5)
    LowThreshold = InputSlot(stype='float', value=0.2)

    Output = OutputSlot()
    CachedOutput = OutputSlot()  # For the GUI (blockwise-access)

    # For serialization
    InputHdf5 = InputSlot(optional=True)
    OutputHdf5 = OutputSlot()
    CleanBlocks = OutputSlot()

    # Debug outputs
    BigRegions = OutputSlot()
    SmallRegions = OutputSlot()
    FilteredSmallLabels = OutputSlot()

    # Schematic:
    #
    #           HighThreshold                         MinSize,MaxSize       --(cache)--> opColorize -> FilteredSmallLabels
    #                   \                                       \          /
    #           opHighThresholder --> opHighLabeler --> opHighLabelSizeFilter                                Output
    #          /                   \                                     \                                  /
    # InputImage        --(cache)--> SmallRegions                 opSelectLabels -->opFinalLabelSizeFilter--> opCache --> CachedOutput
    #          \                                                         /                                   /       \
    #           opLowThresholder ----> opLowLabeler ---------------------                            InputHdf5        --> OutputHdf5
    #                   /                \                                                                            -> CleanBlocks
    #           LowThreshold            --(cache)--> BigRegions

    def __init__(self, *args, **kwargs):
        super(_OpThresholdTwoLevels, self).__init__(*args, **kwargs)

        # --- thresholding (low and high) ---
        self._opLowThresholder = low_thresholder = OpPixelOperator(parent=self)
        low_thresholder.Input.connect(self.InputImage)

        self._opHighThresholder = high_thresholder = OpPixelOperator(parent=self)
        high_thresholder.Input.connect(self.InputImage)

        # --- labeling (low and high) ---
        self._opLowLabeler = low_labeler = OpLabelVolume(parent=self)
        low_labeler.Method.setValue(_labeling_impl)
        low_labeler.Input.connect(low_thresholder.Output)

        self._opHighLabeler = high_labeler = OpLabelVolume(parent=self)
        high_labeler.Method.setValue(_labeling_impl)
        high_labeler.Input.connect(high_thresholder.Output)

        # --- size filter on the high-threshold labels ---
        self._opHighLabelSizeFilter = high_size_filter = OpFilterLabels(parent=self)
        high_size_filter.Input.connect(high_labeler.Output)
        high_size_filter.MinLabelSize.connect(self.MinSize)
        high_size_filter.MaxLabelSize.connect(self.MaxSize)
        # we do the binarization in opSelectLabels
        # this way, we get to display pretty colors
        high_size_filter.BinaryOut.setValue(False)

        # --- select the low-threshold labels via the filtered high-threshold labels ---
        self._opSelectLabels = selector = OpSelectLabels(parent=self)
        selector.BigLabels.connect(low_labeler.Output)
        selector.SmallLabels.connect(high_size_filter.Output)

        # remove the remaining very large objects -
        # they might still be present in case a big object
        # was split into many small ones for the higher threshold
        # and they got reconnected again at lower threshold
        self._opFinalLabelSizeFilter = final_size_filter = OpFilterLabels(parent=self)
        final_size_filter.Input.connect(selector.Output)
        final_size_filter.MinLabelSize.connect(self.MinSize)
        final_size_filter.MaxLabelSize.connect(self.MaxSize)
        final_size_filter.BinaryOut.setValue(False)

        # --- cache for the final result ---
        self._opCache = result_cache = OpCompressedCache(parent=self)
        result_cache.name = "_OpThresholdTwoLevels._opCache"
        result_cache.InputHdf5.connect(self.InputHdf5)
        result_cache.Input.connect(final_size_filter.Output)

        # Connect our own outputs
        self.Output.connect(final_size_filter.Output)
        self.CachedOutput.connect(result_cache.Output)

        # Serialization outputs
        self.CleanBlocks.connect(result_cache.CleanBlocks)
        self.OutputHdf5.connect(result_cache.OutputHdf5)

        #self.InputChannel.connect( self._opChannelSelector.Output )

        # More debug outputs.  These all go through their own caches
        self._opBigRegionCache = big_cache = OpCompressedCache(parent=self)
        big_cache.name = "_OpThresholdTwoLevels._opBigRegionCache"
        big_cache.Input.connect(low_thresholder.Output)
        self.BigRegions.connect(big_cache.Output)

        self._opSmallRegionCache = small_cache = OpCompressedCache(parent=self)
        small_cache.name = "_OpThresholdTwoLevels._opSmallRegionCache"
        small_cache.Input.connect(high_thresholder.Output)
        self.SmallRegions.connect(small_cache.Output)

        self._opFilteredSmallLabelsCache = filtered_cache = OpCompressedCache(parent=self)
        filtered_cache.name = "_OpThresholdTwoLevels._opFilteredSmallLabelsCache"
        filtered_cache.Input.connect(high_size_filter.Output)
        self._opColorizeSmallLabels = colorizer = OpColorizeLabels(parent=self)
        colorizer.Input.connect(filtered_cache.Output)
        self.FilteredSmallLabels.connect(colorizer.Output)

    def setupOutputs(self):
        def thresholdToUint8(level, pixels):
            """Return a uint8 mask of ``pixels > level`` (level scaled by drange[1] if known)."""
            value_range = self.InputImage.meta.drange
            if value_range is not None:
                assert value_range[0] == 0,\
                    "Don't know how to threshold data with this drange."
                level *= value_range[1]

            mask = (pixels > level)
            if pixels.dtype != numpy.uint8:
                return mask.astype(numpy.uint8)

            # Input is already uint8: overwrite it in-place instead of allocating.
            pixels[:] = mask
            return pixels

        low_function = partial(thresholdToUint8, self.LowThreshold.value)
        self._opLowThresholder.Function.setValue(low_function)
        high_function = partial(thresholdToUint8, self.HighThreshold.value)
        self._opHighThresholder.Function.setValue(high_function)

        # Output is already connected internally -- don't reassign new metadata
        # self.Output.meta.assignFrom(self.InputImage.meta)

        # Blockshape is the entire spatial volume (hysteresis thresholding is
        # a global operation)
        tagged_shape = self.Output.meta.getTaggedShape()
        tagged_shape['c'] = 1
        tagged_shape['t'] = 1
        all_caches = (
            self._opCache,
            self._opBigRegionCache,
            self._opSmallRegionCache,
            self._opFilteredSmallLabelsCache,
        )
        for cache in all_caches:
            cache.BlockShape.setValue(tuple(tagged_shape.values()))

    def execute(self, slot, subindex, roi, result):
        assert False, "Shouldn't get here..."

    def propagateDirty(self, slot, subindex, roi):
        # Nothing to do here
        pass

    def setInSlot(self, slot, subindex, roi, value):
        # Only the serialization input may be written to.
        assert slot == self.InputHdf5,\
            "Invalid slot for setInSlot(): {}".format(slot.name)
class OpVigraWatershed(Operator):
    """
    Operator wrapper for vigra's default watershed function.

    Every request on ``Output`` is computed on the fly (no caching): the
    requested roi is padded by ``PaddingWidth`` along the x/y/z axes, the
    watershed is computed on the padded region, and only the requested part
    is copied into the result.  The max label found for each padded region
    is remembered in ``maxLabels``.
    """
    name = "OpVigraWatershed"
    category = "Vigra"

    InputImage = InputSlot()
    # Specifies the extra pixels around the border of the image to use when computing the watershed.
    # (Region is clipped to the size of the input image.)
    PaddingWidth = InputSlot()
    SeedImage = InputSlot(optional=True)

    Output = OutputSlot()

    def __init__(self, *args, **kwargs):
        super(OpVigraWatershed, self).__init__(*args, **kwargs)

        # Keep a dict of roi : max label
        # (keys are (start, stop) tuples of the *padded* region, see execute())
        self._maxLabels = {}
        self._lock = threading.Lock()

    @property
    def maxLabels(self):
        """Dict mapping (start, stop) of each computed padded region to its max label."""
        return self._maxLabels

    def clearMaxLabels(self):
        """Forget all recorded max labels."""
        with self._lock:
            self._maxLabels = {}

    def setupOutputs(self):
        """Output has the input's metadata, except that the dtype is uint32."""
        self.Output.meta.assignFrom(self.InputImage.meta)
        self.Output.meta.dtype = numpy.uint32
        #warnings.warn("FIXME: How can this drange be right?")
        #self.Output.meta.drange = (0,255)

        if self.SeedImage.ready():
            assert numpy.issubdtype(self.SeedImage.meta.dtype, numpy.uint32)
            assert self.SeedImage.meta.shape == self.InputImage.meta.shape, \
                "{} != {}".format(self.SeedImage.meta.shape,
                                  self.InputImage.meta.shape)

    def getSlicings(self, roi):
        """
        Pad the given roi to obtain a new slicing to use for obtaining input data.
        Return the padded slicing and the slicing that returns the original roi
        within the padded data.

        Only the 'x', 'y' and 'z' axes are padded; the padded region is clipped
        to the input image.  Both return values are lists of ``slice`` objects,
        one per axis of ``InputImage``.
        """
        keys = [tag.key for tag in self.InputImage.meta.axistags]
        image_shape = self.InputImage.meta.shape

        # Compute the watershed over a larger area than requested (padded area)
        padding = self.PaddingWidth.value

        paddedSlices = []  # The requested slicing + padding
        outputSlices = []  # The slicing to get the requested slicing from the padded data
        for i, (key, start, stop) in enumerate(zip(keys, roi.start, roi.stop)):
            if key in 'xyz':
                p = slice(max(start - padding, 0),
                          min(stop + padding, image_shape[i]))
            else:
                p = slice(start, stop)
            paddedSlices.append(p)
            outputSlices.append(slice(start - p.start, stop - p.start))

        return paddedSlices, outputSlices

    def execute(self, slot, subindex, roi, result):
        """Compute the watershed for ``roi`` and write the labels into ``result``."""
        assert slot == self.Output

        # Every request is computed on-the-fly.
        # (No caching)
        paddedSlices, outputSlices = self.getSlicings(roi)

        # Get input data
        inputRegion = self.InputImage[paddedSlices].wait()

        # Makes sure vigra will understand this type
        if inputRegion.dtype != numpy.uint8 and inputRegion.dtype != numpy.float32:
            inputRegion = inputRegion.astype('float32')

        # Convert to vigra array
        inputRegion = inputRegion.view(vigra.VigraArray)
        inputRegion.axistags = self.InputImage.meta.axistags

        # Reduce to 3-D (keep order of xyz axes)
        tags = self.InputImage.meta.axistags
        axes3d = "".join([tag.key for tag in tags if tag.key in 'xyz'])
        inputRegion = inputRegion.withAxes(*axes3d)
        logger.debug('inputRegion 3D shape:{}'.format(inputRegion.shape))

        logger.debug("roi={}".format(roi))
        logger.debug("paddedSlices={}".format(paddedSlices))
        logger.debug("outputSlices={}".format(outputSlices))

        # If we know the range of the data, then convert to uint8
        # so we can automatically benefit from vigra's "turbo" mode
        if self.InputImage.meta.drange is not None:
            drange = self.InputImage.meta.drange
            inputRegion = numpy.asarray(inputRegion, dtype=numpy.float32)
            inputRegion = vigra.taggedView(inputRegion, axes3d)
            inputRegion -= drange[0]
            inputRegion /= (drange[1] - drange[0])
            inputRegion *= 255.0
            inputRegion = inputRegion.astype(numpy.uint8)

        # This is where the magic happens
        if self.SeedImage.ready():
            seedImage = self.SeedImage[paddedSlices].wait()
            seedImage = seedImage.view(vigra.VigraArray)
            seedImage.axistags = tags
            seedImage = seedImage.withAxes(*axes3d)
            logger.debug("Input shape = {}, seed shape = {}".format(
                inputRegion.shape, seedImage.shape))
            logger.debug("Input axes = {}, seed axes = {}".format(
                inputRegion.axistags, seedImage.axistags))
            watershed, maxLabel = vigra.analysis.watersheds(inputRegion,
                                                            seeds=seedImage)
        else:
            watershed, maxLabel = vigra.analysis.watersheds(inputRegion)
        logger.debug("Finished Watershed")

        logger.debug("watershed 3D output shape={}".format(watershed.shape))
        logger.debug("maxLabel={}".format(maxLabel))

        # Promote back to 5-D
        watershed = vigra.taggedView(watershed, axes3d)
        watershed = watershed.withAxes(*[tag.key for tag in tags])
        logger.debug("watershed 5D shape: {}".format(watershed.shape))
        logger.debug("watershed axistags: {}".format(watershed.axistags))

        with self._lock:
            start = tuple(s.start for s in paddedSlices)
            stop = tuple(s.stop for s in paddedSlices)
            self._maxLabels[(start, stop)] = maxLabel

        # Return only the region the user requested.
        # The slicing must be a tuple: indexing an array with a *list* of
        # slices is not valid multidimensional indexing in current numpy.
        requested = watershed[tuple(outputSlices)]
        result[:] = requested.view(numpy.ndarray).reshape(result.shape)
        return result

    def propagateDirty(self, inputSlot, subindex, roi):
        """Mark the padded region around a dirty input roi (or everything) dirty."""
        if not self.configured():
            self.Output.setDirty(slice(None))
        elif inputSlot.name == "InputImage" or inputSlot.name == "SeedImage":
            paddedSlicing, outputSlicing = self.getSlicings(roi)
            self.Output.setDirty(paddedSlicing)
        elif inputSlot.name == "PaddingWidth":
            self.Output.setDirty(slice(None))
        else:
            assert False, "Unknown input slot."
class _OpCacheWrapper(Operator):
    """
    Wraps an OpCompressedCache between two axis-reordering operators.

    ``Input`` is reordered to 'xyzct' before it enters the cache and the
    cache output is reordered to 'txyzc' for ``Output``.  A fresh cache is
    created on every call to setupOutputs().
    """
    name = "OpCacheWrapper"

    Input = InputSlot()
    Output = OutputSlot()

    InputHdf5 = InputSlot(optional=True)
    CleanBlocks = OutputSlot()
    OutputHdf5 = OutputSlot()

    def __init__(self, *args, **kwargs):
        super(_OpCacheWrapper, self).__init__(*args, **kwargs)
        op1 = OpReorderAxes(parent=self)
        op1.name = "op1"
        op2 = OpReorderAxes(parent=self)
        op2.name = "op2"

        op1.AxisOrder.setValue('xyzct')
        op2.AxisOrder.setValue('txyzc')

        op1.Input.connect(self.Input)
        self.Output.connect(op2.Output)

        self._op1 = op1
        self._op2 = op2
        # Created in setupOutputs(); None while no cache is wired up.
        self._cache = None

    def setupOutputs(self):
        """Replace the internal cache with a new one and configure its block shape."""
        self._disconnectInternals()

        # we need a new cache
        cache = OpCompressedCache(parent=self)
        cache.name = self.name + "WrappedCache"

        # connect cache outputs
        self.CleanBlocks.connect(cache.CleanBlocks)
        self.OutputHdf5.connect(cache.OutputHdf5)
        self._op2.Input.connect(cache.Output)

        # connect cache inputs
        cache.InputHdf5.connect(self.InputHdf5)
        cache.Input.connect(self._op1.Output)

        # set the cache block shape
        tagged_shape = self._op1.Output.meta.getTaggedShape()
        tagged_shape['t'] = 1
        tagged_shape['c'] = 1
        # A real list is required here: on Python 3 map() returns an iterator,
        # which numpy.minimum() cannot broadcast against.
        cacheshape = [tagged_shape[k] for k in 'xyzct']
        if _labeling_impl == "lazy":
            #HACK hardcoded block shape
            blockshape = numpy.minimum(cacheshape, 256)
        else:
            # use full spatial volume if not lazy
            blockshape = cacheshape
        cache.BlockShape.setValue(tuple(blockshape))

        self._cache = cache

    def execute(self, slot, subindex, roi, result):
        # All outputs are connected to internal operators.
        assert False

    def propagateDirty(self, slot, subindex, roi):
        pass

    def setInSlot(self, slot, subindex, key, value):
        """Forward data written to ``InputHdf5`` to the wrapped cache."""
        assert slot == self.InputHdf5,\
            "setInSlot not implemented for slot {}".format(slot.name)
        assert self._cache is not None,\
            "setInSlot called before input was configured"
        self._cache.setInSlot(self._cache.InputHdf5, subindex, key, value)

    def _disconnectInternals(self):
        """Detach the current cache (if any) from all slots and forget it."""
        self.CleanBlocks.disconnect()
        self.OutputHdf5.disconnect()
        self._op2.Input.disconnect()

        if self._cache is not None:
            self._cache.InputHdf5.disconnect()
            self._cache.Input.disconnect()
            # Reset instead of `del`: the attribute must keep existing so that
            # the `self._cache is not None` checks stay valid even if the
            # following re-creation of the cache fails.
            # NOTE(review): the old cache operator is only disconnected here,
            # not cleaned up -- confirm whether cleanUp() should be called.
            self._cache = None
class OpSparseLabelArray(OpCache):
    """
    Simple cache for sparse label arrays.

    Labels are kept twice: in a dense uint8 array (``_denseArray``) and in a
    sorted dict (``_sparseNZ``) that maps the raveled index of every non-zero
    pixel to its label value.  Labels are written through setInSlot().
    """
    name = "Sparse Label Array"
    description = "simple cache for sparse label arrays"

    inputSlots = [
        InputSlot("Input", optional=True),
        InputSlot("shape"),
        InputSlot("eraser"),
        InputSlot("deleteLabel", optional=True)
    ]

    outputSlots = [
        OutputSlot("Output"),
        OutputSlot("nonzeroValues"),
        OutputSlot("nonzeroCoordinates"),
        OutputSlot("maxLabel")
    ]

    def __init__(self, *args, **kwargs):
        super(OpSparseLabelArray, self).__init__(*args, **kwargs)
        self.lock = threading.Lock()
        self._denseArray = None
        self._sparseNZ = None
        self._oldShape = (0, )
        self._maxLabel = 0

    def usedMemory(self):
        """Return the size of the dense array in bytes (0 if not allocated)."""
        if self._denseArray is not None:
            return self._denseArray.nbytes
        return 0

    def lastAccessTime(self):
        return 0
        #return self._last_access

    def generateReport(self, report):
        report.name = self.name
        #report.fractionOfUsedMemoryDirty = self.fractionOfUsedMemoryDirty()
        report.usedMemory = self.usedMemory()
        #report.lastAccessTime = self.lastAccessTime()
        report.dtype = self.Output.meta.dtype
        report.type = type(self)
        report.id = id(self)

    def setupOutputs(self):
        """
        (Re)allocate the storage when the ``shape`` input changed, and process
        a pending ``deleteLabel`` request.
        """
        shape = self.inputs["shape"].value
        # array_equal() also reports a change when only the number of
        # dimensions differs (an elementwise `!=` would broadcast instead).
        if not numpy.array_equal(self._oldShape, shape):
            self._oldShape = shape
            self.outputs["Output"].meta.dtype = numpy.uint8
            self.outputs["Output"].meta.shape = shape

            # FIXME: Don't give arbitrary axistags.  Specify them correctly if you need them.
            #self.outputs["Output"].meta.axistags = vigra.defaultAxistags(len(shape))

            self.inputs["Input"].meta.shape = shape

            self.outputs["nonzeroValues"].meta.dtype = object
            self.outputs["nonzeroValues"].meta.shape = (1, )

            self.outputs["nonzeroCoordinates"].meta.dtype = object
            self.outputs["nonzeroCoordinates"].meta.shape = (1, )

            self._denseArray = numpy.zeros(shape, numpy.uint8)
            self._sparseNZ = blist.sorteddict()

        if self.inputs["deleteLabel"].ready() and self.inputs["deleteLabel"].value != -1:
            labelNr = self.inputs["deleteLabel"].value

            neutralElement = 0
            self.inputs["deleteLabel"].setValue(-1)  # reset state of inputslot

            # `with` guarantees the lock is released even if something below raises.
            with self.lock:
                # Find the entries to remove
                updateNZ = numpy.nonzero(
                    numpy.where(self._denseArray == labelNr, 1, 0))
                if len(updateNZ) > 0:
                    # Convert to 1-D indexes for the raveled version
                    updateNZRavel = numpy.ravel_multi_index(
                        updateNZ, self._denseArray.shape)
                    # Zero out the entries we don't want any more
                    self._denseArray.ravel()[updateNZRavel] = neutralElement
                    # Remove the zeros from the sparse list
                    for index in updateNZRavel:
                        self._sparseNZ.pop(index)

                # Labels are continuous values: Shift all higher label values down by 1.
                self._denseArray[:] = numpy.where(self._denseArray > labelNr,
                                                  self._denseArray - 1,
                                                  self._denseArray)
                # Apply the same shift to the sparse values, otherwise
                # nonzeroValues would keep reporting the old label numbers.
                for index, label in list(self._sparseNZ.items()):
                    if label > labelNr:
                        self._sparseNZ[index] = label - 1

                self._maxLabel = self._denseArray.max()

            self.outputs["nonzeroValues"].setDirty(slice(None))
            self.outputs["nonzeroCoordinates"].setDirty(slice(None))
            self.outputs["Output"].setDirty(slice(None))
            self.outputs["maxLabel"].setValue(self._maxLabel)

    def execute(self, slot, subindex, roi, result):
        """Serve the dense labels, the sparse values/coordinates or the max label."""
        key = roiToSlice(roi.start, roi.stop)

        eraser_ready = self.inputs["eraser"].ready()
        shape_ready = self.inputs["shape"].ready()
        assert eraser_ready and shape_ready, \
            "OpDenseSparseArray: One of the neccessary input slots is not ready: shape: %r, eraser: %r" % (
                shape_ready, eraser_ready)

        with self.lock:
            if slot.name == "Output":
                result[:] = self._denseArray[key]
            elif slot.name == "nonzeroValues":
                # list(): dict views do not convert to a 1-D array on Python 3
                result[0] = numpy.array(list(self._sparseNZ.values()))
            elif slot.name == "nonzeroCoordinates":
                result[0] = numpy.array(list(self._sparseNZ.keys()))
            elif slot.name == "maxLabel":
                result[0] = self._maxLabel
        return result

    def setInSlot(self, slot, subindex, roi, value):
        """
        Write the non-zero pixels of ``value`` into the label storage.
        Pixels equal to the ``eraser`` input value are reset to 0.
        """
        key = roi.toSlice()
        assert isinstance(value, numpy.ndarray)
        assert value.dtype == self._denseArray.dtype, "Labels must be {}".format(
            self._denseArray.dtype)
        if type(value) != numpy.ndarray:
            # vigra.VigraArray doesn't handle advanced indexing correctly,
            #   so convert to numpy.ndarray first
            value = value.view(numpy.ndarray)

        shape = self.inputs["shape"].value
        eraseLabel = self.inputs["eraser"].value
        neutralElement = 0

        # `with` guarantees the lock is released even if something below raises.
        with self.lock:
            #fix slicing of single dimensions:
            start, stop = sliceToRoi(key, shape, extendSingleton=False)
            start = start.floor()._asint()
            stop = stop.floor()._asint()

            tempKey = roiToSlice(start - start, stop - start)  #, hardBind = True)

            stop += numpy.where(stop - start == 0, 1, 0)

            key = roiToSlice(start, stop)

            updateShape = tuple(stop - start)

            update = self._denseArray[key].copy()
            update[tempKey] = value

            startRavel = numpy.ravel_multi_index(
                numpy.array(start, numpy.int32), shape)

            #insert values into dict
            updateNZ = numpy.nonzero(
                numpy.where(update != neutralElement, 1, 0))
            updateNZRavelSmall = numpy.ravel_multi_index(updateNZ, updateShape)

            if isinstance(value, numpy.ndarray):
                valuesNZ = value.ravel()[updateNZRavelSmall]
            else:
                valuesNZ = value

            # Raveled indexes are linear in the coordinates, so the index in
            # the full array is the local index (computed with the full
            # shape) plus the raveled start offset.
            updateNZRavel = numpy.ravel_multi_index(updateNZ, shape)
            updateNZRavel += startRavel

            self._denseArray.ravel()[updateNZRavel] = valuesNZ

            # Read back what was actually stored in the dense array
            valuesNZ = self._denseArray.ravel()[updateNZRavel]

            td = blist.sorteddict(
                zip(updateNZRavel.tolist(), valuesNZ.tolist()))

            self._sparseNZ.update(td)

            #remove values to be deleted
            updateNZ = numpy.nonzero(numpy.where(update == eraseLabel, 1, 0))
            if len(updateNZ) > 0:
                updateNZRavel = numpy.ravel_multi_index(updateNZ, shape)
                updateNZRavel += startRavel
                self._denseArray.ravel()[updateNZRavel] = neutralElement
                for index in updateNZRavel:
                    self._sparseNZ.pop(index)

            # Update our maxlabel
            self._maxLabel = self._denseArray.max()

        # Set our max label dirty if necessary
        self.outputs["maxLabel"].setValue(self._maxLabel)
        self.outputs["Output"].setDirty(key)

    def propagateDirty(self, dirtySlot, subindex, roi):
        if dirtySlot == self.Input:
            self.Output.setDirty(roi)
        else:
            # All other inputs are single-value inputs that will trigger
            #  a new call to setupOutputs, which already sets the outputs dirty.
            # (See above.)
            pass
class OpThresholdTwoLevels(Operator):
    """
    Top-level thresholding operator.

    The input image is reordered to 'txyzc', reduced to the channel given by
    ``Channel`` and smoothed.  ``CurOperator`` then selects which internal
    operator feeds ``Output`` and ``CachedOutput``:

    * 0 -- single threshold
    * 1 -- two-level threshold
    * 2 -- graph cut (requires haveGraphCut())
    * 3 -- identity-preserving hysteresis thresholding
    """
    name = "OpThresholdTwoLevels"

    RawInput = InputSlot(optional=True)  # Display only
    InputChannelColors = InputSlot(optional=True)  # Display only

    InputImage = InputSlot()
    MinSize = InputSlot(stype='int', value=10)
    MaxSize = InputSlot(stype='int', value=1000000)
    HighThreshold = InputSlot(stype='float', value=0.5)
    LowThreshold = InputSlot(stype='float', value=0.2)
    SingleThreshold = InputSlot(stype='float', value=0.5)
    SmootherSigma = InputSlot(value={'x': 1.0, 'y': 1.0, 'z': 1.0})
    Channel = InputSlot(value=0)
    CurOperator = InputSlot(stype='int', value=0)

    ## Graph-Cut options ##

    SingleThresholdGC = InputSlot(stype='float', value=0.5)
    Beta = InputSlot(value=.2)

    # apply thresholding before graph-cut
    UsePreThreshold = InputSlot(stype='bool', value=True)

    # margin around single object (only graph-cut)
    Margin = InputSlot(value=numpy.asarray((20, 20, 20)))

    ## Output slots ##

    Output = OutputSlot()

    CachedOutput = OutputSlot()  # For the GUI (blockwise-access)

    # For serialization
    InputHdf5 = InputSlot(optional=True)

    CleanBlocks = OutputSlot()

    OutputHdf5 = OutputSlot()

    ## Debug outputs

    InputChannel = OutputSlot()
    Smoothed = OutputSlot()
    BigRegions = OutputSlot()
    SmallRegions = OutputSlot()
    FilteredSmallLabels = OutputSlot()
    BeforeSizeFilter = OutputSlot()

    def __init__(self, *args, **kwargs):
        super(OpThresholdTwoLevels, self).__init__(*args, **kwargs)

        self.InputImage.notifyReady(self.checkConstraints)

        # Bring the input into 'txyzc' order
        reorder_in = self._opReorder1 = OpReorderAxes(parent=self)
        reorder_in.AxisOrder.setValue('txyzc')
        reorder_in.Input.connect(self.InputImage)

        # Select the channel to work on
        selector = self._opChannelSelector = OpSingleChannelSelector(parent=self)
        selector.Input.connect(reorder_in.Output)
        selector.Index.connect(self.Channel)

        # anisotropic gauss
        smoother = self._opSmoother = OpAnisotropicGaussianSmoothing5d(parent=self)
        smoother.Sigmas.connect(self.SmootherSigma)
        smoother.Input.connect(selector.Output)

        # debug output
        self.Smoothed.connect(smoother.Output)

        # single threshold operator
        single = self.opThreshold1 = _OpThresholdOneLevel(parent=self)
        for target, source in ((single.Threshold, self.SingleThreshold),
                               (single.MinSize, self.MinSize),
                               (single.MaxSize, self.MaxSize)):
            target.connect(source)

        # double threshold operator
        double = self.opThreshold2 = _OpThresholdTwoLevels(parent=self)
        for target, source in ((double.MinSize, self.MinSize),
                               (double.MaxSize, self.MaxSize),
                               (double.LowThreshold, self.LowThreshold),
                               (double.HighThreshold, self.HighThreshold)):
            target.connect(source)

        # Identity-preserving hysteresis thresholding
        ipht = self.opIpht = OpIpht(parent=self)
        for target, source in ((ipht.MinSize, self.MinSize),
                               (ipht.MaxSize, self.MaxSize),
                               (ipht.LowThreshold, self.LowThreshold),
                               (ipht.HighThreshold, self.HighThreshold),
                               (ipht.InputImage, smoother.Output)):
            target.connect(source)

        if haveGraphCut():
            # pre-thresholding for the object-wise graph cut
            single_gc = self.opThreshold1GC = _OpThresholdOneLevel(parent=self)
            single_gc.Threshold.connect(self.SingleThresholdGC)
            single_gc.MinSize.connect(self.MinSize)
            single_gc.MaxSize.connect(self.MaxSize)

            # graph cut restricted to the pre-thresholded objects
            objects_gc = self.opObjectsGraphCut = OpObjectsSegment(parent=self)
            objects_gc.Prediction.connect(self.Smoothed)
            objects_gc.LabelImage.connect(single_gc.Output)
            objects_gc.Beta.connect(self.Beta)
            objects_gc.Margin.connect(self.Margin)

            # graph cut without pre-thresholding
            plain_gc = self.opGraphCut = OpGraphCut(parent=self)
            plain_gc.Prediction.connect(self.Smoothed)
            plain_gc.Beta.connect(self.Beta)

        self._op5CacheOutput = OpReorderAxes(parent=self)

        self._opReorder2 = OpReorderAxes(parent=self)
        self.Output.connect(self._opReorder2.Output)

        #cache our own output, don't propagate from internal operator
        cache = self._cache = _OpCacheWrapper(parent=self)
        cache.name = "OpThresholdTwoLevels.OpCacheWrapper"
        self.CachedOutput.connect(cache.Output)

        # Serialization slots
        cache.InputHdf5.connect(self.InputHdf5)
        self.CleanBlocks.connect(cache.CleanBlocks)
        self.OutputHdf5.connect(cache.OutputHdf5)

        #Debug outputs
        self.InputChannel.connect(selector.Output)

    def setupOutputs(self):
        """Rewire the internal operators according to ``CurOperator``."""
        input_meta = self.InputImage.meta
        self._opReorder2.AxisOrder.setValue(input_meta.getAxisKeys())

        # propagate drange
        self.opThreshold1.InputImage.meta.drange = input_meta.drange
        if haveGraphCut():
            self.opThreshold1GC.InputImage.meta.drange = input_meta.drange
        self.opThreshold2.InputImage.meta.drange = input_meta.drange

        self._disconnectAll()

        curIndex = self.CurOperator.value

        if curIndex == 0:
            outputSlot = self._connectForSingleThreshold(self.opThreshold1)
        elif curIndex == 1:
            outputSlot = self._connectForTwoLevelThreshold()
        elif curIndex == 2:
            outputSlot = self._connectForGraphCut()
        elif curIndex == 3:
            outputSlot = self.opIpht.Output
        else:
            message = "Unknown index {} for current tab.".format(curIndex)
            raise ValueError(message)

        self._opReorder2.Input.connect(outputSlot)
        # force the cache to emit a dirty signal
        cache_input = self._cache.Input
        cache_input.connect(outputSlot)
        cache_input.setDirty(slice(None))

    def checkConstraints(self, *args):
        """Raise DatasetConstraintError if ``Channel`` exceeds the input's channel count."""
        reordered = self._opReorder1.Output
        if not reordered.ready():
            return
        numChannels = reordered.meta.getTaggedShape()['c']
        if self.Channel.value >= numChannels:
            raise DatasetConstraintError(
                "Two-Level Thresholding",
                "Your project is configured to select data from channel"
                " #{}, but your input data only has {} channels.".format(
                    self.Channel.value, numChannels))

    def _disconnectAll(self):
        """Disconnect the debug outputs and the inputs of the threshold operators."""
        # start from back
        debug_outputs = (self.BigRegions, self.SmallRegions,
                         self.FilteredSmallLabels, self.BeforeSizeFilter)
        for debug_slot in debug_outputs:
            debug_slot.disconnect()
            debug_slot.meta.NOTREADY = True
        self._opReorder2.Input.disconnect()
        if haveGraphCut():
            self.opThreshold1GC.InputImage.disconnect()
        self.opThreshold1.InputImage.disconnect()
        self.opThreshold2.InputImage.disconnect()

    def _connectForSingleThreshold(self, threshOp):
        """Wire ``threshOp`` behind the smoother and return its output slot."""
        # connect the operators for SingleThreshold
        before_filter = self.BeforeSizeFilter
        before_filter.connect(threshOp.BeforeSizeFilter)
        before_filter.meta.NOTREADY = None
        threshOp.InputImage.connect(self._opSmoother.Output)
        return threshOp.Output

    def _connectForTwoLevelThreshold(self):
        """Wire the two-level operator behind the smoother and return its output slot."""
        # connect the operators for TwoLevelThreshold
        two_level = self.opThreshold2
        self.BigRegions.connect(two_level.BigRegions)
        self.SmallRegions.connect(two_level.SmallRegions)
        self.FilteredSmallLabels.connect(two_level.FilteredSmallLabels)
        for debug_slot in (self.BigRegions, self.SmallRegions,
                           self.FilteredSmallLabels):
            debug_slot.meta.NOTREADY = None
        two_level.InputImage.connect(self._opSmoother.Output)

        return two_level.Output

    def _connectForGraphCut(self):
        """Return the output slot of the graph-cut variant chosen by ``UsePreThreshold``."""
        assert haveGraphCut(), "Module for graph cut is not available"
        if not self.UsePreThreshold.value:
            return self.opGraphCut.Output
        self._connectForSingleThreshold(self.opThreshold1GC)
        return self.opObjectsGraphCut.Output

    # raise an error if setInSlot is called, we do not pre-cache input
    #def setInSlot(self, slot, subindex, roi, value):
        #pass

    def execute(self, slot, subindex, roi, destination):
        # All outputs are connected to internal operators.
        assert False, "Shouldn't get here."

    def propagateDirty(self, slot, subindex, roi):
        # dirtiness propagation is handled in the sub-operators
        pass

    def setInSlot(self, slot, subindex, roi, value):
        # Only the serialization input may be written; nothing is done with
        # the data in this operator.
        assert slot == self.InputHdf5,\
            "[{}] Wrong slot for setInSlot(): {}".format(self.name, slot)
class OpDataSelectionGroup(Operator):
    """
    Groups one OpDataSelection per dataset role.

    The internal operators live in an OperatorWrapper which is (re)built
    whenever ``DatasetRoles`` changes.  ``ImageGroup`` exposes all images,
    while ``Image``, ``Image1`` and ``Image2`` expose the first three of them.
    """
    # Inputs
    ProjectFile = InputSlot(stype='object', optional=True)
    ProjectDataGroup = InputSlot(stype='string', optional=True)
    WorkingDirectory = InputSlot(stype='filestring')
    DatasetRoles = InputSlot(stype='object')

    # Must mark as optional because not all subslots are required.
    DatasetGroup = InputSlot(stype='object', level=1, optional=True)

    # Outputs
    ImageGroup = OutputSlot(level=1)

    # These output slots are provided as a convenience, since otherwise it is tricky to create a lane-wise multislot of
    # level-1 for only a single role.
    # (It can be done, but requires OpTransposeSlots to invert the level-2 multislot indexes...)
    Image = OutputSlot()  # The first dataset. Equivalent to ImageGroup[0]
    Image1 = OutputSlot()  # The second dataset. Equivalent to ImageGroup[1]
    Image2 = OutputSlot()  # The third dataset. Equivalent to ImageGroup[2]
    AllowLabels = OutputSlot(stype='bool')  # Pulled from the first dataset only.

    _NonTransposedImageGroup = OutputSlot(level=1)

    # Must be the LAST slot declared in this class.
    # When the shell detects that this slot has been resized,
    #  it assumes all the others have already been resized.
    ImageName = OutputSlot()  # Name of the first dataset is used. Other names are ignored.

    def __init__(self, forceAxisOrder=None, *args, **kwargs):
        super(OpDataSelectionGroup, self).__init__(*args, **kwargs)
        self._opDatasets = None
        self._roles = []
        self._forceAxisOrder = forceAxisOrder

        def handleNewRoles(*args):
            # One DatasetGroup subslot per role
            self.DatasetGroup.resize(len(self.DatasetRoles.value))

        self.DatasetRoles.notifyReady(handleNewRoles)

    def setupOutputs(self):
        """(Re)build the internal operators if needed and connect the outputs."""
        # Create internal operators.
        # The wrapper must also be created when it does not exist yet: if the
        # very first role list equals the initial (empty) self._roles, the
        # code below would otherwise operate on None.
        if self._opDatasets is None or self.DatasetRoles.value != self._roles:
            self._roles = self.DatasetRoles.value
            # Clean up the old operators
            self.ImageGroup.disconnect()
            self.Image.disconnect()
            self.Image1.disconnect()
            self.Image2.disconnect()
            self._NonTransposedImageGroup.disconnect()
            if self._opDatasets is not None:
                self._opDatasets.cleanUp()

            self._opDatasets = OperatorWrapper(
                OpDataSelection,
                parent=self,
                operator_kwargs={'forceAxisOrder': self._forceAxisOrder},
                broadcastingSlotNames=[
                    'ProjectFile', 'ProjectDataGroup', 'WorkingDirectory'
                ])
            self.ImageGroup.connect(self._opDatasets.Image)
            self._NonTransposedImageGroup.connect(
                self._opDatasets._NonTransposedImage)
            self._opDatasets.Dataset.connect(self.DatasetGroup)
            self._opDatasets.ProjectFile.connect(self.ProjectFile)
            self._opDatasets.ProjectDataGroup.connect(self.ProjectDataGroup)
            self._opDatasets.WorkingDirectory.connect(self.WorkingDirectory)

        for role_index, opDataSelection in enumerate(self._opDatasets):
            opDataSelection.RoleName.setValue(self._roles[role_index])

        num_images = len(self._opDatasets.Image)
        if num_images > 0:
            self.Image.connect(self._opDatasets.Image[0])

            if num_images >= 2:
                self.Image1.connect(self._opDatasets.Image[1])
            else:
                self.Image1.disconnect()
                self.Image1.meta.NOTREADY = True

            if num_images >= 3:
                self.Image2.connect(self._opDatasets.Image[2])
            else:
                self.Image2.disconnect()
                self.Image2.meta.NOTREADY = True

            self.ImageName.connect(self._opDatasets.ImageName[0])
            self.AllowLabels.connect(self._opDatasets.AllowLabels[0])
        else:
            # No datasets at all: nothing can be provided.
            self.Image.disconnect()
            self.Image1.disconnect()
            self.Image2.disconnect()
            self.ImageName.disconnect()
            self.AllowLabels.disconnect()
            self.Image.meta.NOTREADY = True
            self.Image1.meta.NOTREADY = True
            self.Image2.meta.NOTREADY = True
            self.ImageName.meta.NOTREADY = True
            self.AllowLabels.meta.NOTREADY = True

    def execute(self, slot, subindex, rroi, result):
        assert False, "Unknown or unconnected output slot: {}".format(
            slot.name)

    def propagateDirty(self, slot, subindex, roi):
        # Output slots are directly connected to internal operators
        pass
class OpLabeledThreshold(Operator):
    """
    Threshold a single-channel 'tzyxc' image and label the result (uint32).

    ``Method`` selects the implementation that is run for every time slice
    of a request; see the ``_execute_*`` methods.
    """
    Input = InputSlot()  # Must have exactly 1 channel
    CoreLabels = InputSlot(optional=True)  # Not used for 'Simple' method.
    Method = InputSlot(value=ThresholdMethod.SIMPLE)
    FinalThreshold = InputSlot(value=0.2)
    GraphcutBeta = InputSlot(value=0.2)  # Graphcut only

    Output = OutputSlot()

    def __init__(self, *args, **kwargs):
        super(OpLabeledThreshold, self).__init__(*args, **kwargs)

        # Dispatch table: threshold method -> implementation
        self.execute_funcs = {
            ThresholdMethod.SIMPLE: self._execute_SIMPLE,
            ThresholdMethod.HYSTERESIS: self._execute_HYSTERESIS,
            ThresholdMethod.GRAPHCUT: self._execute_GRAPHCUT,
            ThresholdMethod.IPHT: self._execute_IPHT,
        }

    def setupOutputs(self):
        expected_axes = list("tzyxc")
        assert self.Input.meta.getAxisKeys() == expected_axes
        assert self.Input.meta.shape[-1] == 1
        if self.CoreLabels.ready():
            assert self.CoreLabels.meta.getAxisKeys() == expected_axes

        self.Output.meta.assignFrom(self.Input.meta)
        self.Output.meta.dtype = np.uint32

    def propagateDirty(self, slot, subindex, roi):
        self.Output.setDirty()

    def execute(self, slot, subindex, roi, result):
        """Run the selected method separately on each time slice of ``roi``."""
        result = vigra.taggedView(result, self.Output.meta.axistags)

        # Iterate over time slices to avoid connected component problems.
        t_offset = 0
        for t in range(roi.start[0], roi.stop[0]):
            single_t_roi = roi.copy()
            single_t_roi.start[0] = t
            single_t_roi.stop[0] = t + 1
            single_t_result = result[t_offset:t_offset + 1]

            handler = self.execute_funcs[self.Method.value]
            handler(single_t_roi, single_t_result)
            t_offset += 1

    def _execute_SIMPLE(self, roi, result):
        """Label the connected regions of ``Input >= FinalThreshold``."""
        assert result.shape[0] == 1
        assert tuple(roi.stop - roi.start) == result.shape

        threshold = self.FinalThreshold.value

        pixels = self.Input(roi.start, roi.stop).wait()
        pixels = vigra.taggedView(pixels, self.Input.meta.axistags)

        result = vigra.taggedView(result, self.Output.meta.axistags)

        # Reinterpret the boolean mask as uint8 for vigra
        mask = (pixels >= threshold).view(np.uint8)
        vigra.analysis.labelMultiArrayWithBackground(mask[0, ..., 0],
                                                     out=result[0, ..., 0])

    def _execute_HYSTERESIS(self, roi, result):
        """Simple thresholding, followed by select_labels() with the core labels."""
        self._execute_SIMPLE(roi, result)
        labels = vigra.taggedView(result, self.Output.meta.axistags)

        cores = self.CoreLabels(roi.start, roi.stop).wait()
        cores = vigra.taggedView(cores, self.CoreLabels.meta.axistags)

        # Edits labels in-place
        select_labels(cores, labels)

    def _execute_IPHT(self, roi, result):
        """Delegate to threshold_from_cores() with the input and the core labels."""
        cores = self.CoreLabels(roi.start, roi.stop).wait()
        cores = vigra.taggedView(cores, self.CoreLabels.meta.axistags)

        pixels = self.Input(roi.start, roi.stop).wait()
        pixels = vigra.taggedView(pixels, self.Input.meta.axistags)

        threshold = self.FinalThreshold.value
        result = vigra.taggedView(result, self.Output.meta.axistags)
        threshold_from_cores(pixels[0, ..., 0],
                             cores[0, ..., 0],
                             threshold,
                             out=result[0, ..., 0])

    def _execute_GRAPHCUT(self, roi, result):
        """Rescale the input around FinalThreshold, run segmentGC() and label the result."""
        data = self.Input(roi.start, roi.stop).wait()
        data = vigra.taggedView(data, self.Input.meta.axistags)
        data_zyx = data[0, ..., 0]

        beta = self.GraphcutBeta.value
        ft = self.FinalThreshold.value

        # The segmentGC() function will implicitly threshold at 0.5,
        # but we want to respect the user's FinalThreshold setting.
        # The intended scaling from input pixels --> gc potentials is:
        #
        #   0.0..FT --> 0.0..0.5
        #   FT..1.0 --> 0.5..1.0
        #
        # For instance, input pixels that match the user's FT exactly will map to 0.5,
        # and graphcut will place them on the threshold border.
        #
        # NOTE(review): as written, the upper branch maps FT..1.0 onto
        # 0.5..1.5 (there is no factor 0.5 on the second term) -- confirm
        # whether that is intended before changing it.
        is_above = data_zyx >= ft
        is_below = ~is_above

        # The input data is modified in-place here.
        data_zyx[is_below] *= old_div(0.5, ft)
        data_zyx[is_above] = 0.5 + old_div((data_zyx[is_above] - ft),
                                           (1 - ft))

        binary_seg_zyx = segmentGC(data_zyx, beta).astype(np.uint8)
        del data_zyx
        vigra.analysis.labelMultiArrayWithBackground(binary_seg_zyx,
                                                     out=result[0, ..., 0])
class OpCompressedUserLabelArray(OpUnmanagedCompressedCache):
    """
    A subclass of OpUnmanagedCompressedCache that is suitable for storing user-drawn label pixels.
    (This is not a 'managed' cache because its data must never be deleted by the memory manager.)
    Note that setInSlot has special functionality (only non-zero pixels are written, and there is also an "eraser" pixel value).

    See note below about blockshape changes.
    """
    #Input = InputSlot()
    shape = InputSlot(optional=True) # Should not be used.
    eraser = InputSlot()
    deleteLabel = InputSlot(optional = True)
    blockShape = InputSlot() # If the blockshape is changed after labels have been stored, all cache data is lost.

    #Output = OutputSlot()
    #nonzeroValues = OutputSlot()
    #nonzeroCoordinates = OutputSlot()
    nonzeroBlocks = OutputSlot()
    #maxLabel = OutputSlot()

    Projection2D = OutputSlot(allow_mask=True) # A somewhat magic output that returns a projection of all
                                               # label data underneath a given roi, from all slices.
                                               # If, for example, a 256x1x256 tile is requested from this slot,
                                               # It will return a projection of ALL labels that fall within the 256 x ... x 256 tile.
                                               # (The projection axis is *inferred* from the shape of the requested data).
                                               # The projection data is float32 between 0.0 and 1.0, where:
                                               # - Exactly 0.0 means "no labels under this pixel"
                                               # - 1/256.0 means "labels in the first slice"
                                               # - ...
                                               # - 1.0 means "last slice"
                                               # The output is suitable for display in a colortable.

    def __init__(self, *args, **kwargs):
        # These attributes are set before the superclass constructor runs.
        self._blockshape = None
        self._label_to_purge = 0

        super(OpCompressedUserLabelArray, self).__init__( *args, **kwargs )

        # ignoring the ideal chunk shape is ok because we use the input only
        # to get the volume shape
        self._ignore_ideal_blockshape = True

    def clearLabel(self, label_value):
        """
        Clear (reset to 0) all pixels of the given label value.
        Unlike using the deleteLabel slot, this function does not "shift down" all labels above this label value.
        """
        self._purge_label( label_value, False )

    def mergeLabels(self, from_label, into_label):
        """
        Reassign every pixel labeled ``from_label`` to ``into_label``,
        then decrement all labels above ``from_label``.
        """
        self._purge_label(from_label, True, into_label)

    def setupOutputs(self):
        """
        Configure output metadata and the blockshape, and purge a label
        if the deleteLabel slot asks for one.

        Raises:
            RuntimeError: if the blockshape changes after labels were stored.
        """
        # Due to a temporary naming clash, pass our subclass blockshape to the superclass
        # TODO: Fix this by renaming the BlockShape slots to be consistent.
        self.BlockShape.setValue( self.blockShape.value )

        super( OpCompressedUserLabelArray, self ).setupOutputs()
        if self.Output.meta.NOTREADY:
            self.nonzeroBlocks.meta.NOTREADY = True
            self.Projection2D.meta.NOTREADY = True
            return
        self.nonzeroBlocks.meta.dtype = object
        self.nonzeroBlocks.meta.shape = (1,)

        # Overwrite the Output metadata (should be uint8 no matter what the input data is...)
        self.Output.meta.assignFrom(self.Input.meta)
        self.Output.meta.dtype = numpy.uint8
        self.Output.meta.shape = self.Input.meta.shape[:-1] + (1,)
        self.Output.meta.drange = (0,255)
        self.OutputHdf5.meta.assignFrom(self.Output.meta)

        # The Projection2D slot is a strange beast:
        # It appears to have the same output shape as any other output slot,
        # but it can only be accessed in 2D slices.
        self.Projection2D.meta.assignFrom(self.Output.meta)
        self.Projection2D.meta.dtype = numpy.float32
        self.Projection2D.meta.drange = (0.0, 1.0)

        # Overwrite the blockshape
        # NOTE(review): if self._blockshape is a numpy array here, the `!=`
        # below is elementwise and its truth value is ambiguous -- confirm
        # what type the superclass leaves in self._blockshape.
        if self._blockshape is None:
            self._blockshape = numpy.minimum( self.BlockShape.value, self.Output.meta.shape )
        elif self.blockShape.value != self._blockshape:
            nonzero_blocks_destination = [None]
            self._execute_nonzeroBlocks(nonzero_blocks_destination)
            nonzero_blocks = nonzero_blocks_destination[0]
            if len(nonzero_blocks) > 0:
                raise RuntimeError( "You are not permitted to reconfigure the labeling operator after you've already stored labels in it." )

        # Overwrite chunkshape now that blockshape has been overwritten
        self._chunkshape = self._chooseChunkshape(self._blockshape)

        self._eraser_magic_value = self.eraser.value

        # Are we being told to delete a label?
        if self.deleteLabel.ready():
            new_purge_label = self.deleteLabel.value
            if self._label_to_purge != new_purge_label:
                self._label_to_purge = new_purge_label
                if self._label_to_purge > 0:
                    self._purge_label( self._label_to_purge, True )

    def _purge_label(self, label_to_purge, decrement_remaining, replacement_value=0):
        """
        Scan through all labeled pixels.
        (1) Reassign all pixels of the given value (set to replacement_value)
        (2) If decrement_remaining=True, decrement all labels above that
            value so the set of stored labels remains consecutive.
            Note that the decrement is performed AFTER replacement.
        """
        changed_block_rois = []
        #stored_block_rois = self.CleanBlocks.value
        stored_block_roi_destination = [None]
        self.execute(self.CleanBlocks, (), SubRegion( self.Output, (0,),(1,) ), stored_block_roi_destination)
        stored_block_rois = stored_block_roi_destination[0]

        for block_roi in stored_block_rois:
            # Get data
            block_shape = numpy.subtract( block_roi[1], block_roi[0] )
            block = self.Output.stype.allocateDestination(SubRegion(self.Output, *roiFromShape(block_shape)))

            self.execute(self.Output, (), SubRegion( self.Output, *block_roi ), block)

            # Locate pixels to change
            matching_label_coords = numpy.nonzero( block == label_to_purge )

            # Change the data
            block[matching_label_coords] = replacement_value
            coords_to_decrement = block > label_to_purge
            if decrement_remaining:
                block[coords_to_decrement] -= 1

            # Update cache with the new data (only if something really changed)
            if len(matching_label_coords[0]) > 0 or (decrement_remaining and coords_to_decrement.sum() > 0):
                super( OpCompressedUserLabelArray, self )._setInSlotInput( self.Input, (), SubRegion( self.Output, *block_roi ), block, store_zero_blocks=False )
                changed_block_rois.append( block_roi )

        for block_roi in changed_block_rois:
            # FIXME: Shouldn't this dirty notification be handled in OpUnmanagedCompressedCache?
            self.Output.setDirty( *block_roi )

    def execute(self, slot, subindex, roi, destination):
        """Dispatch a request to the handler for ``slot``; unknown slots go to the superclass."""
        if slot == self.Output:
            self._executeOutput(roi, destination)
        elif slot == self.nonzeroBlocks:
            self._execute_nonzeroBlocks(destination)
        elif slot == self.Projection2D:
            self._executeProjection2D(roi, destination)
        else:
            return super( OpCompressedUserLabelArray, self ).execute( slot, subindex, roi, destination )

    def _executeOutput(self, roi, destination):
        """Copy the label data under ``roi`` into ``destination`` and return it."""
        assert len(roi.stop) == len(self.Output.meta.shape), \
            "roi: {} has the wrong number of dimensions for Output shape: {}"\
            "".format( roi, self.Output.meta.shape )
        assert numpy.less_equal(roi.stop, self.Output.meta.shape).all(), \
            "roi: {} is out-of-bounds for Output shape: {}"\
            "".format( roi, self.Output.meta.shape )

        block_starts = getIntersectingBlocks( self._blockshape, (roi.start, roi.stop) )
        self._copyData(roi, destination, block_starts)
        return destination

    def _execute_nonzeroBlocks(self, destination):
        """
        Store in ``destination[0]`` a list with one slicing per stored block.
        """
        stored_block_rois_destination = [None]
        self._executeCleanBlocks( stored_block_rois_destination )
        stored_block_rois = stored_block_rois_destination[0]
        # Must be a real list: callers take len() of it, and on Python 3
        # a bare map() is a one-shot iterator without a length.
        destination[0] = [roiToSlice(*block_roi) for block_roi in stored_block_rois]

    def _executeProjection2D(self, roi, destination):
        """
        Fill ``destination`` with a projection of all labels under ``roi``.

        The projection axis is inferred from the shape of ``destination``;
        see the comment on the Projection2D slot for the value encoding.
        """
        assert sum(TinyVector(destination.shape) > 1) <= 2, "Projection result must be exactly 2D"

        # First, we have to determine which axis we are projecting along.
        # We infer this from the shape of the roi.
        # For example, if the roi is of shape
        # zyx = (1,256,256), then we know we're projecting along Z
        # If more than one axis has a width of 1, then we choose an
        # axis according to the following priority order: zyxt
        tagged_input_shape = self.Output.meta.getTaggedShape()
        tagged_result_shape = collections.OrderedDict( zip( tagged_input_shape.keys(),
                                                            destination.shape ) )
        nonprojection_axes = []
        for key in tagged_input_shape.keys():
            if (key == 'c' or tagged_input_shape[key] == 1 or tagged_result_shape[key] > 1):
                nonprojection_axes.append( key )

        possible_projection_axes = set(tagged_input_shape) - set(nonprojection_axes)
        if len(possible_projection_axes) == 0:
            # If the image is 2D to begin with,
            # then the projection is simply the same as the normal output,
            # EXCEPT it is made binary
            self.Output(roi.start, roi.stop).writeInto(destination).wait()

            # make binary
            numpy.greater(destination, 0, out=destination)
            return

        for k in 'zyxt':
            if k in possible_projection_axes:
                projection_axis_key = k
                break

        # Now we know which axis we're projecting along.
        # Proceed with the projection, working blockwise to avoid unnecessary work in unlabeled blocks
        projection_axis_index = self.Output.meta.getAxisKeys().index(projection_axis_key)
        projection_length = tagged_input_shape[projection_axis_key]
        input_roi = roi.copy()
        input_roi.start[projection_axis_index] = 0
        input_roi.stop[projection_axis_index] = projection_length

        destination[:] = 0.0

        # Get the logical blocking.
        block_starts = getIntersectingBlocks( self._blockshape, (input_roi.start, input_roi.stop) )

        # (Parallelism wouldn't help here: h5py will serialize these requests anyway)
        block_starts = map( tuple, block_starts )
        for block_start in block_starts:
            if block_start not in self._cacheFiles:
                # No label data in this block. Move on.
                continue

            entire_block_roi = getBlockBounds( self.Output.meta.shape, self._blockshape, block_start )

            # This block's portion of the roi
            intersecting_roi = getIntersection( (input_roi.start, input_roi.stop), entire_block_roi )

            # Compute slicing within the deep array and slicing within this block
            deep_relative_intersection = numpy.subtract(intersecting_roi, input_roi.start)
            block_relative_intersection = numpy.subtract(intersecting_roi, block_start)
            block_relative_intersection_slicing = roiToSlice(*block_relative_intersection)

            block = self._getBlockDataset( entire_block_roi )
            deep_data = None
            if self.Output.meta.has_mask:
                deep_data = numpy.ma.masked_array(
                    block["data"][block_relative_intersection_slicing],
                    mask=block["mask"][block_relative_intersection_slicing],
                    fill_value=block["fill_value"][()],
                    shrink=False
                )
            else:
                deep_data = block[block_relative_intersection_slicing]

            # make binary and convert to float (must copy)
            deep_data_float = deep_data.astype(numpy.float32)
            deep_data_float[deep_data_float.nonzero()] = 1

            # multiply by slice-index
            # (deep_data_view is a view, so scaling it scales deep_data_float too)
            deep_data_view = numpy.rollaxis(deep_data_float, projection_axis_index, 0)

            min_deep_slice_index = deep_relative_intersection[0][projection_axis_index]
            max_deep_slice_index = deep_relative_intersection[1][projection_axis_index]

            def calc_color_value(slice_index):
                # Note 1: We assume that the colortable has at least 256 entries in it,
                #         so, we try to ensure that all colors are above 1/256
                #         (we don't want colors in low slices to be rounded to 0)
                # Note 2: Ideally, we'd use a min projection in the code below, so that
                #         labels in the "back" slices would appear occluded. But the
                #         min projection would favor 0.0. Instead, we invert the
                #         relationship between color and slice index, do a max projection,
                #         and then re-invert the colors after everything is done.
                #         Hence, this function starts with (1.0 - ...)
                return (1.0 - (float(slice_index) / projection_length)) * (1.0 - 1.0/255) + 1.0/255.0
            min_color_value = calc_color_value(min_deep_slice_index)
            max_color_value = calc_color_value(max_deep_slice_index)

            num_slices = max_deep_slice_index - min_deep_slice_index
            deep_data_view *= numpy.linspace( min_color_value, max_color_value, num=num_slices )\
                              [ (slice(None),) + (None,)*(deep_data_view.ndim-1) ]

            # Take the max projection of this block's data.
            block_max_projection = deep_data_float.max(axis=projection_axis_index)
            block_max_projection = numpy.ma.expand_dims(block_max_projection, axis=projection_axis_index)

            # Merge this block's projection into the overall projection.
            destination_relative_intersection = numpy.array(deep_relative_intersection)
            destination_relative_intersection[:, projection_axis_index] = (0,1)
            destination_relative_intersection_slicing = roiToSlice(*destination_relative_intersection)

            destination_subview = destination[destination_relative_intersection_slicing]
            numpy.maximum(block_max_projection, destination_subview, out=destination_subview)

            # Invert the nonzero pixels so increasing colors correspond to increasing slices.
            # See comment in calc_color_value(), above.
            destination_subview[destination_subview.nonzero()] -= 1
            destination_subview[()] = -destination_subview

        return

    def _copyData(self, roi, destination, block_starts):
        """
        Copy data from each block into the destination array.
        For blocks that aren't currently stored, just write zeros.
        """
        # (Parallelism not needed here: h5py will serialize these requests anyway)
        block_starts = map( tuple, block_starts )
        for block_start in block_starts:
            entire_block_roi = getBlockBounds( self.Output.meta.shape, self._blockshape, block_start )

            # This block's portion of the roi
            intersecting_roi = getIntersection( (roi.start, roi.stop), entire_block_roi )

            # Compute slicing within destination array and slicing within this block
            destination_relative_intersection = numpy.subtract(intersecting_roi, roi.start)
            block_relative_intersection = numpy.subtract(intersecting_roi, block_start)
            destination_relative_intersection_slicing = roiToSlice(*destination_relative_intersection)
            block_relative_intersection_slicing = roiToSlice(*block_relative_intersection)

            if block_start in self._cacheFiles:
                # Copy from block to destination
                dataset = self._getBlockDataset( entire_block_roi )

                if self.Output.meta.has_mask:
                    destination[ destination_relative_intersection_slicing ] = dataset["data"][ block_relative_intersection_slicing ]
                    destination.mask[ destination_relative_intersection_slicing ] = dataset["mask"][ block_relative_intersection_slicing ]
                    destination.fill_value = dataset["fill_value"][()]
                else:
                    destination[ destination_relative_intersection_slicing ] = dataset[ block_relative_intersection_slicing ]
            else:
                # Not stored yet. Overwrite with zeros.
                destination[ destination_relative_intersection_slicing ] = 0

    def propagateDirty(self, slot, subindex, roi):
        # There should be no way to make the output dirty except via setInSlot()
        pass

    def setInSlot(self, slot, subindex, roi, new_pixels):
        """Write ``new_pixels`` into the cache; only the Input slot is supported."""
        if slot is self.Input:
            self._setInSlotInput(slot, subindex, roi, new_pixels)
        else:
            # We don't yet support the InputHdf5 slot in this function.
            assert False, "Unsupported slot for setInSlot: {}".format( slot.name )

    def _setInSlotInput(self, slot, subindex, roi, new_pixels):
        """
        Since this is a label array, inserting pixels has a special meaning:
        We only overwrite the new non-zero pixels. In the new data, zeros mean "don't change".

        So, here's what each pixel we're adding means:
        0: don't change
        1: change to 1
        2: change to 2
        ...
        N: change to N
        magic_eraser_value: change to 0

        Returns the cleaned data that was stored (eraser values replaced by 0).
        """
        # Extract the data to modify
        original_data = self.Output.stype.allocateDestination(SubRegion(self.Output, *roiFromShape(new_pixels.shape)))
        self.execute(self.Output, (), roi, original_data)

        # Reset the pixels we need to change (so we can use |= below)
        original_data[new_pixels.nonzero()] = 0

        # Update
        original_data |= new_pixels

        # Replace 'eraser' values with zeros.
        cleaned_data = original_data.copy()
        cleaned_data[original_data == self._eraser_magic_value] = 0

        # Set in the cache (our superclass).
        super( OpCompressedUserLabelArray, self )._setInSlotInput( slot, subindex, roi, cleaned_data, store_zero_blocks=False )

        # FIXME: Shouldn't this notification be triggered from within OpUnmanagedCompressedCache?
        self.Output.setDirty( roi.start, roi.stop )

        return cleaned_data # Internal use: Return the cleaned_data

    def ingestData(self, slot):
        """
        Read the data from the given slot and copy it into this cache.
        The rules about special pixel meanings apply here, just like setInSlot

        Returns: the max label found in the slot.
        """
        assert self._blockshape is not None
        assert self.Output.meta.shape[:-1] == slot.meta.shape[:-1], \
            "{} != {}".format( self.Output.meta.shape, slot.meta.shape )
        max_label = 0

        # Get logical blocking.
        block_starts = getIntersectingBlocks( self._blockshape, roiFromShape(self.Output.meta.shape) )
        block_starts = map( tuple, block_starts )

        # Write each block
        for block_start in block_starts:
            block_roi = getBlockBounds( self.Output.meta.shape, self._blockshape, block_start )

            # Request the block data
            block_data = slot(*block_roi).wait()

            # Write into the array
            subregion_roi = SubRegion(self.Output, *block_roi)
            cleaned_block_data = self._setInSlotInput( self.Input, (), subregion_roi, block_data )

            max_label = max( max_label, cleaned_block_data.max() )

        return max_label
class OpThresholdTwoLevels(Operator):
    """
    Composite operator that wires up a thresholding pipeline from
    sub-operators (see the schematic below). It performs no computation
    itself: every output slot is connected to a sub-operator's output.
    """
    RawInput = InputSlot(optional=True)  # Display only
    InputChannelColors = InputSlot(optional=True)  # Display only

    InputImage = InputSlot()
    MinSize = InputSlot(stype="int", value=10)
    MaxSize = InputSlot(stype="int", value=1000000)
    HighThreshold = InputSlot(stype="float", value=0.8)
    LowThreshold = InputSlot(stype="float", value=0.5)
    SmootherSigma = InputSlot(value={"z": 1.0, "y": 1.0, "x": 1.0})
    Channel = InputSlot(value=0)
    CoreChannel = InputSlot(value=0)
    CurOperator = InputSlot(stype="int", value=ThresholdMethod.SIMPLE )  # This slot would be better named 'method',
                                                                         # but we're keeping this slot name for backwards
                                                                         # compatibility with old project files
    Beta = InputSlot(value=0.2)  # For GraphCut

    ## Output slots ##
    Output = OutputSlot()
    CachedOutput = OutputSlot()  # For the GUI (blockwise-access)

    # For serialization
    CacheInput = InputSlot(optional=True)
    CleanBlocks = OutputSlot()

    ## Debug outputs
    InputChannel = OutputSlot()
    Smoothed = OutputSlot()
    BigRegions = OutputSlot()
    SmallRegions = OutputSlot()
    FilteredSmallLabels = OutputSlot()
    BeforeSizeFilter = OutputSlot()

    ## Basic schematic (debug outputs not shown)
    ##
    ## InputImage -> opReorder -> opSmoother -> opSmootherCache --> opFinalChannelSelector --> opSumInputs --------------------> opFinalThreshold -> opFinalFilter -> opReorder -> Output
    ##                                                         \                             /                                  /                                                   \
    ##                                                          --> opCoreChannelSelector --> opCoreThreshold -> opCoreFilter --                                                     opCache -> CachedOutput
    ##                                                                                                                                                                                       `-> CleanBlocks

    def __init__(self, *args, **kwargs):
        super(OpThresholdTwoLevels, self).__init__(*args, **kwargs)

        # All internal processing happens in 'tzyxc' axis order.
        self.opReorderInput = OpReorderAxes(parent=self)
        self.opReorderInput.AxisOrder.setValue("tzyxc")
        self.opReorderInput.Input.connect(self.InputImage)

        # PROBABILITIES: Convert to float32
        self.opConvertProbabilities = OpConvertDtype(parent=self)
        self.opConvertProbabilities.ConversionDtype.setValue(np.float32)
        self.opConvertProbabilities.Input.connect(self.opReorderInput.Output)

        # PROBABILITIES: Normalize drange to [0.0, 1.0]
        self.opNormalizeProbabilities = OpPixelOperator(parent=self)

        def normalize_inplace(a):
            # Pass data through untouched when no drange is known or it is already [0, 1].
            drange = self.opNormalizeProbabilities.Input.meta.drange
            if drange is None or (drange[0] == 0.0 and drange[1] == 1.0):
                return a
            a[:] -= drange[0]
            a[:] = a[:] / float((drange[1] - drange[0]))
            return a

        self.opNormalizeProbabilities.Input.connect(self.opConvertProbabilities.Output)
        self.opNormalizeProbabilities.Function.setValue(normalize_inplace)

        self.opSmoother = OpAnisotropicGaussianSmoothing5d(parent=self)
        self.opSmoother.Sigmas.connect(self.SmootherSigma)
        self.opSmoother.Input.connect(self.opNormalizeProbabilities.Output)

        self.opSmootherCache = OpBlockedArrayCache(parent=self)
        self.opSmootherCache.BlockShape.setValue((1, None, None, None, 1))
        self.opSmootherCache.Input.connect(self.opSmoother.Output)

        self.opCoreChannelSelector = OpSingleChannelSelector(parent=self)
        self.opCoreChannelSelector.Index.connect(self.CoreChannel)
        self.opCoreChannelSelector.Input.connect(self.opSmootherCache.Output)

        self.opCoreThreshold = OpLabeledThreshold(parent=self)
        self.opCoreThreshold.Method.setValue(ThresholdMethod.SIMPLE)
        self.opCoreThreshold.FinalThreshold.connect(self.HighThreshold)
        self.opCoreThreshold.Input.connect(self.opCoreChannelSelector.Output)

        self.opCoreFilter = OpFilterLabels(parent=self)
        self.opCoreFilter.BinaryOut.setValue(False)
        self.opCoreFilter.MinLabelSize.connect(self.MinSize)
        self.opCoreFilter.MaxLabelSize.connect(self.MaxSize)
        self.opCoreFilter.Input.connect(self.opCoreThreshold.Output)

        self.opFinalChannelSelector = OpSingleChannelSelector(parent=self)
        self.opFinalChannelSelector.Index.connect(self.Channel)
        self.opFinalChannelSelector.Input.connect(self.opSmootherCache.Output)

        self.opSumInputs = OpMultiArrayMerger(parent=self)  # see setupOutputs (below) for input connections
        self.opSumInputs.MergingFunction.setValue(sum)

        self.opFinalThreshold = OpLabeledThreshold(parent=self)
        self.opFinalThreshold.Method.connect(self.CurOperator)
        self.opFinalThreshold.FinalThreshold.connect(self.LowThreshold)
        self.opFinalThreshold.GraphcutBeta.connect(self.Beta)
        self.opFinalThreshold.CoreLabels.connect(self.opCoreFilter.Output)
        self.opFinalThreshold.Input.connect(self.opSumInputs.Output)

        self.opFinalFilter = OpFilterLabels(parent=self)
        self.opFinalFilter.BinaryOut.setValue(False)
        self.opFinalFilter.MinLabelSize.connect(self.MinSize)
        self.opFinalFilter.MaxLabelSize.connect(self.MaxSize)
        self.opFinalFilter.Input.connect(self.opFinalThreshold.Output)

        self.opReorderOutput = OpReorderAxes(parent=self)
        # self.opReorderOutput.AxisOrder.setValue('tzyxc')  # See setupOutputs()
        self.opReorderOutput.Input.connect(self.opFinalFilter.Output)

        self.Output.connect(self.opReorderOutput.Output)

        self.opCache = OpBlockedArrayCache(parent=self)
        self.opCache.CompressionEnabled.setValue(True)
        self.opCache.Input.connect(self.opReorderOutput.Output)

        self.CachedOutput.connect(self.opCache.Output)
        self.CleanBlocks.connect(self.opCache.CleanBlocks)

        ## Debug outputs
        self.Smoothed.connect(self.opSmootherCache.Output)
        self.InputChannel.connect(self.opFinalChannelSelector.Output)
        self.SmallRegions.connect(self.opCoreThreshold.Output)
        self.BeforeSizeFilter.connect(self.opFinalThreshold.Output)

        self.opFilteredSmallLabelsCache = OpBlockedArrayCache(parent=self)
        self.opFilteredSmallLabelsCache.CompressionEnabled.setValue(True)
        self.opFilteredSmallLabelsCache.Input.connect(self.opCoreFilter.Output)
        self.FilteredSmallLabels.connect(self.opFilteredSmallLabelsCache.Output)

        # Since hysteresis thresholding creates the big regions and immediately discards the bad ones,
        # we have to recreate it here if the user wants to view it as a debug layer
        self.opBigRegionsThreshold = OpLabeledThreshold(parent=self)
        self.opBigRegionsThreshold.Method.setValue(ThresholdMethod.SIMPLE)
        self.opBigRegionsThreshold.FinalThreshold.connect(self.LowThreshold)
        self.opBigRegionsThreshold.Input.connect(self.opFinalChannelSelector.Output)
        self.BigRegions.connect(self.opBigRegionsThreshold.Output)

    def setupOutputs(self):
        """
        Restore the caller's axis order on the output, configure cache
        blockshapes, and (re)connect the inputs of opSumInputs according to
        the selected method and channels.
        """
        axes = self.InputImage.meta.getAxisKeys()
        self.opReorderOutput.AxisOrder.setValue(axes)

        # Cache individual t,c slices
        blockshape = tuple(1 if k in "tc" else None for k in axes)
        self.opCache.BlockShape.setValue(blockshape)

        # This cache is fed by opCoreFilter, whose data is in the internal
        # 'tzyxc' order (set on opReorderInput), so use the same blockshape
        # as opSmootherCache: one block per time point and channel.
        self.opFilteredSmallLabelsCache.BlockShape.setValue((1, None, None, None, 1))

        # The core channel is only summed in when the method uses core labels
        # and the two selected channels differ.
        if (self.CurOperator.value in (ThresholdMethod.HYSTERESIS, ThresholdMethod.IPHT)
                and self.Channel.value != self.CoreChannel.value):
            self.opSumInputs.Inputs.resize(2)
            self.opSumInputs.Inputs[0].connect(self.opFinalChannelSelector.Output)
            self.opSumInputs.Inputs[1].connect(self.opCoreChannelSelector.Output)
        else:
            self.opSumInputs.Inputs.resize(1)
            self.opSumInputs.Inputs[0].connect(self.opFinalChannelSelector.Output)

    def setInSlot(self, slot, subindex, roi, value):
        # Whatever slot was written, the value is forwarded into the output cache.
        self.opCache.setInSlot(self.opCache.Input, subindex, roi, value)

    def execute(self, slot, subindex, roi, destination):
        # All outputs are connected to sub-operators, so this is never expected to run.
        assert False, "Shouldn't get here."

    def propagateDirty(self, slot, subindex, roi):
        pass  # dirtiness propagation is handled in the sub-operators
class OpExportDvidVolume(Operator):
    """
    Push the contents of the Input slot to a DVID voxels dataset.

    The operator has no output slots; call run_export() to do the work.
    Progress is reported through ``progressSignal``.
    """
    Input = InputSlot()
    NodeDataUrl = InputSlot()  # Should be a url of the form http://<hostname>[:<port>]/api/node/<uuid>/<dataname>
    OffsetCoord = InputSlot(optional=True)

    def __init__(self, transpose_axes, *args, **kwargs):
        """
        transpose_axes: if true, axes, shape, roi and data are reversed
                        before being sent (see run_export()).
        """
        super(OpExportDvidVolume, self).__init__(*args, **kwargs)
        self.progressSignal = OrderedSignal()
        self._transpose_axes = transpose_axes

    # No output slots...
    def setupOutputs(self):
        # The channel axis must be last when transposing, first otherwise.
        if self._transpose_axes:
            assert self.Input.meta.axistags.channelIndex == len(self.Input.meta.axistags) - 1
        else:
            assert self.Input.meta.axistags.channelIndex == 0

    def execute(self, slot, subindex, roi, result):
        pass

    def propagateDirty(self, slot, subindex, roi):
        pass

    def run_export(self):
        """
        Stream the whole Input volume to the dataset named by NodeDataUrl,
        creating the dataset first if the server answers 404 for it.
        """
        self.progressSignal(0)

        # Split "scheme://hostname/api/node/uuid/dataname" into its parts.
        url = self.NodeDataUrl.value
        url_path = url.split('://')[1]
        hostname, api, node, uuid, dataname = url_path.split('/')
        assert api == 'api'
        assert node == 'node'

        axiskeys = self.Input.meta.getAxisKeys()
        shape = self.Input.meta.shape

        if self._transpose_axes:
            axiskeys = reversed(axiskeys)
            shape = tuple(reversed(shape))

        axiskeys = "".join(axiskeys)

        if self.OffsetCoord.ready():
            offset_start = self.OffsetCoord.value
        else:
            offset_start = (0, ) * len(self.Input.meta.shape)

        connection = pydvid.dvid_connection.DvidConnection(hostname)
        with contextlib.closing(connection):
            self.progressSignal(5)

            # Get the dataset details
            try:
                metadata = pydvid.voxels.get_metadata(connection, uuid, dataname)
            except pydvid.errors.DvidHttpError as ex:
                if ex.status_code != 404:
                    raise
                # Dataset doesn't exist yet. Let's create it.
                metadata = pydvid.voxels.VoxelsMetadata.create_default_metadata(
                    shape, self.Input.meta.dtype, axiskeys, 0.0, "")
                pydvid.voxels.create_new(connection, uuid, dataname, metadata)

            # Since this class is generally used to push large blocks of data,
            # we'll be nice and set throttle=True
            client = pydvid.voxels.VoxelsAccessor(connection, uuid, dataname, throttle=True)

            def handle_block_result(roi, data):
                # Send it to dvid
                # Work on a private copy: numpy.asarray() would alias an
                # incoming ndarray, and the in-place shift below would then
                # modify the roi object owned by the request streamer.
                roi = numpy.array(roi)
                roi += offset_start
                start, stop = roi
                if self._transpose_axes:
                    data = data.transpose()
                    start = tuple(reversed(start))
                    stop = tuple(reversed(stop))
                client.post_ndarray(start, stop, data)

            requester = BigRequestStreamer(self.Input, roiFromShape(self.Input.meta.shape))
            requester.resultSignal.subscribe(handle_block_result)
            requester.progressSignal.subscribe(self.progressSignal)
            requester.execute()

        self.progressSignal(100)
class OpAnisotropicGaussianSmoothing(Operator):
    """
    Gaussian smoothing with a separate sigma per spatial axis.

    Sigmas must be a dict with exactly the keys 'x', 'y' and 'z'.
    The output is float32; an input time axis is removed from the output.
    """
    Input = InputSlot()
    Sigmas = InputSlot(value={'x': 1.0, 'y': 1.0, 'z': 1.0})

    Output = OutputSlot()

    def setupOutputs(self):
        self.Output.meta.assignFrom(self.Input.meta)

        #if there is a time of dim 1, output won't have that
        timeIndex = self.Output.meta.axistags.index('t')
        if timeIndex < len(self.Output.meta.shape):
            newshape = list(self.Output.meta.shape)
            newshape.pop(timeIndex)
            self.Output.meta.shape = tuple(newshape)
            del self.Output.meta.axistags[timeIndex]
        self.Output.meta.dtype = numpy.float32  # vigra gaussian only supports float32
        self._sigmas = self.Sigmas.value
        assert isinstance(self.Sigmas.value, dict), "Sigmas slot expects a dict"
        assert set(self._sigmas.keys()) == set('xyz'), "Sigmas slot expects three key-value pairs for x,y,z"
        print("Assigning output: {} ====> {}".format(
            self.Input.meta.getTaggedShape(), self.Output.meta.getTaggedShape()))
        #self.Output.setDirty( slice(None) )

    def execute(self, slot, subindex, roi, result):
        """
        Smooth the requested roi into ``result`` and return ``result``.

        If any sigma is below 0.1 the input data is copied through unsmoothed.
        """
        assert all(roi.stop <= self.Input.meta.shape), \
            "Requested roi {} is too large for this input image of shape {}.".format(roi, self.Input.meta.shape)

        # Determine how much input data we'll need, and where the result will be relative to that input roi
        inputRoi, computeRoi = self._getInputComputeRois(roi)

        # Obtain the input data
        with Timer() as resultTimer:
            data = self.Input(*inputRoi).wait()
        logger.debug("Obtaining input data took {} seconds for roi {}".format(
            resultTimer.seconds(), inputRoi))

        # An axis counts as missing when its index is not smaller than the
        # number of dimensions.
        ndim = len(self.Input.meta.shape)
        xIndex = self.Input.meta.axistags.index('x')
        yIndex = self.Input.meta.axistags.index('y')
        zIndex = self.Input.meta.axistags.index('z')
        if zIndex >= ndim:
            zIndex = None
        cIndex = self.Input.meta.axistags.index('c')
        if cIndex >= ndim:
            cIndex = None

        # Must be float32
        if data.dtype != numpy.float32:
            data = data.astype(numpy.float32)

        axiskeys = self.Input.meta.getAxisKeys()
        # Real lists (not filter()/map() iterators): sigma is read twice below.
        spatialkeys = [k for k in axiskeys if k in 'xyz']

        # we need to remove a singleton z axis, otherwise we get
        # 'kernel longer than line' errors
        reskey = [slice(None, None, None)] * len(self.Input.meta.shape)
        # NOTE(review): cIndex is None when there is no channel axis, which
        # makes this assignment raise TypeError -- confirm a channel axis is
        # always present.
        reskey[cIndex] = 0
        if zIndex and self.Input.meta.shape[zIndex] == 1:
            removedZ = True
            data = data.reshape((data.shape[xIndex], data.shape[yIndex]))
            reskey[zIndex] = 0
            spatialkeys = [k for k in axiskeys if k in 'xy']
        else:
            removedZ = False

        sigma = [self._sigmas.get(k) for k in spatialkeys]

        #Check if we need to smooth
        if any([x < 0.1 for x in sigma]):
            if removedZ:
                resultXY = vigra.taggedView(result, axistags="".join(axiskeys))
                resultXY = resultXY.withAxes(*'xy')
                resultXY[:] = data
            else:
                result[:] = data
            return result

        # Smooth the input data
        smoothed = vigra.filters.gaussianSmoothing(
            data, sigma, window_size=2.0, roi=computeRoi,
            out=result[tuple(reskey)])  # FIXME: Assumes channel is last axis
        expectedShape = tuple(TinyVector(computeRoi[1]) - TinyVector(computeRoi[0]))
        assert tuple(smoothed.shape) == expectedShape, \
            "Smoothed data shape {} didn't match expected shape {}".format(smoothed.shape, expectedShape)
        return result

    def _getInputComputeRois(self, roi):
        """
        Return ``(inputRoi, computeRoi)`` for the requested ``roi``.

        inputRoi:   the roi, enlarged by the smoothing halo, to request from Input.
        computeRoi: where the requested region lies inside that input data,
                    as tuples of plain ints.
        """
        axiskeys = self.Input.meta.getAxisKeys()
        spatialkeys = [k for k in axiskeys if k in 'xyz']
        sigma = [self._sigmas.get(k) for k in spatialkeys]
        inputSpatialShape = self.Input.meta.getTaggedShape()
        spatialRoi = (TinyVector(roi.start), TinyVector(roi.stop))
        tIndex = None
        cIndex = None
        zIndex = None
        if 'c' in inputSpatialShape:
            del inputSpatialShape['c']
            cIndex = axiskeys.index('c')
        if 't' in inputSpatialShape.keys():
            assert inputSpatialShape['t'] == 1
            tIndex = axiskeys.index('t')

        if 'z' in inputSpatialShape.keys() and inputSpatialShape['z'] == 1:
            #2D image, avoid kernel longer than line exception
            del inputSpatialShape['z']
            zIndex = axiskeys.index('z')

        # Only truthy indices were ever removed (None and 0 were skipped), so
        # drop them before sorting: Python 3 cannot order None against ints.
        indices = sorted((ind for ind in (tIndex, cIndex, zIndex) if ind), reverse=True)
        for ind in indices:
            spatialRoi[0].pop(ind)
            spatialRoi[1].pop(ind)

        inputSpatialRoi = enlargeRoiForHalo(spatialRoi[0], spatialRoi[1],
                                            list(inputSpatialShape.values()),
                                            sigma, window=2.0)

        # Determine the roi within the input data we're going to request
        inputRoiOffset = spatialRoi[0] - inputSpatialRoi[0]
        computeRoi = (inputRoiOffset, inputRoiOffset + spatialRoi[1] - spatialRoi[0])

        # For some reason, vigra.filters.gaussianSmoothing will raise an exception if this parameter doesn't have the correct integer type.
        # (for example, if we give it as a numpy.ndarray with dtype=int64, we get an error)
        computeRoi = (tuple(map(int, computeRoi[0])),
                      tuple(map(int, computeRoi[1])))

        # Re-insert the removed axes (lowest index first) with extent [0, 1).
        inputRoi = (list(inputSpatialRoi[0]), list(inputSpatialRoi[1]))
        for ind in reversed(indices):
            inputRoi[0].insert(ind, 0)
            inputRoi[1].insert(ind, 1)

        return inputRoi, computeRoi

    def propagateDirty(self, slot, subindex, roi):
        if slot == self.Input:
            # Halo calculation is bidirectional, so we can re-use the function that computes the halo during execute()
            inputRoi, _ = self._getInputComputeRois(roi)
            self.Output.setDirty(inputRoi[0], inputRoi[1])
        elif slot == self.Sigmas:
            self.Output.setDirty(slice(None))
        else:
            assert False, "Unknown input slot: {}".format(slot.name)
class OpCacheFixer(Operator):
    """
    Pass-through operator that adds "fixAtCurrent" (freeze) behavior, intended
    to be placed in front of a cache operator.

    - While fixAtCurrent=False, requests and dirty notifications are simply forwarded.
    - While fixAtCurrent=True, requests are answered with zeros and dirty
      notifications are withheld.  The bounding box of every region that was
      marked dirty (or was served as fake zeros) is accumulated, and it is
      emitted as one single dirty notification once the operator becomes
      "unfixed" again.
    """
    fixAtCurrent = InputSlot(value=False)
    Input = InputSlot(allow_mask=True)
    Output = OutputSlot(allow_mask=True)

    def __init__(self, *args, **kwargs):
        super(OpCacheFixer, self).__init__(*args, **kwargs)
        # Mirrors the last fixAtCurrent value seen by propagateDirty().
        self._fixed = False
        # (start, stop) bounding box of withheld dirtiness, or None if not tracked yet.
        self._fixed_dirty_roi = None

    def setupOutputs(self):
        """Copy the input metadata to the output and re-sync the 'fixed' state."""
        self.Output.meta.assignFrom(self.Input.meta)
        self.Output.meta.dontcache = self.fixAtCurrent.value

        # If fixAtCurrent was configured before Input during initialization,
        # propagateDirty() was never called for it.  Call it explicitly so the
        # fixAtCurrent bookkeeping is guaranteed to be up to date.
        self.propagateDirty(self.fixAtCurrent, (), slice(None))

    def execute(self, slot, subindex, roi, result):
        """Fill `result` with upstream data, or with zeros while fixed."""
        if not self._fixed:
            self.Input(roi.start, roi.stop).writeInto(result).wait()
            return

        # The downstream consumer can't tell that this data is fake, so the
        # requested region must be reported as dirty once we get "unfixed".
        self._expand_fixed_dirty_roi((roi.start, roi.stop))
        result[:] = 0

    def setInSlot(self, slot, subindex, roi, value):
        """Forward data written into the input to the output slot."""
        self.Output[roiToSlice(roi.start, roi.stop)] = value

        whole_roi = roiFromShape(self.Input.meta.shape)
        written_roi = numpy.array((roi.start, roi.stop))
        if (written_roi == whole_roi).all():
            # The entire volume was just overwritten: nothing is dirty any more.
            self._init_fixed_dirty_roi()

    def propagateDirty(self, slot, subindex, roi):
        """Forward, withhold, or flush dirty notifications depending on the fixed state."""
        if slot is self.Input:
            dirty_bounds = (roi.start, roi.stop)
            if not self._fixed:
                self.Output.setDirty(*dirty_bounds)
            else:
                # Can't notify downstream right now; just remember the region
                # by growing the accumulated bounding box.
                self._expand_fixed_dirty_roi(dirty_bounds)
            return

        if slot is not self.fixAtCurrent:
            return

        # When switching from fixed to un-fixed, emit everything that was
        # withheld as one big dirty notification (if the box is non-empty).
        becoming_unfixed = self._fixed and not self.fixAtCurrent.value
        if becoming_unfixed and self._fixed_dirty_roi:
            pending_start, pending_stop = self._fixed_dirty_roi
            if (pending_stop - pending_start > 0).all():
                self.Output.setDirty(pending_start, pending_stop)
                self._fixed_dirty_roi = None
        self._fixed = self.fixAtCurrent.value

    def _init_fixed_dirty_roi(self):
        """Reset the accumulated dirty box to 'empty'."""
        full_roi = roiFromShape(self.Input.meta.shape)
        # Deliberately inverted (start=full stop, stop=full start): the box has
        # negative extent, so any min/max expansion replaces it entirely.
        self._fixed_dirty_roi = (full_roi[1], full_roi[0])

    def _expand_fixed_dirty_roi(self, roi):
        """Grow the accumulated dirty box so that it also covers `roi` (start, stop)."""
        if self._fixed_dirty_roi is None:
            self._init_fixed_dirty_roi()
        known_start, known_stop = self._fixed_dirty_roi
        self._fixed_dirty_roi = (
            numpy.minimum(known_start, roi[0]),
            numpy.maximum(known_stop, roi[1]),
        )
class OpSelectLabels(Operator):
    """
    Select those labels of BigLabels that overlap at least one nonzero pixel
    of SmallLabels (within the requested roi).

    The selected labels are renumbered consecutively (1, 2, ...) in ascending
    order of their original value; every other pixel of the output is 0.
    """
    ## The smaller clusters
    # i.e. results of high thresholding
    SmallLabels = InputSlot()

    ## The larger clusters
    # i.e. results of low thresholding
    BigLabels = InputSlot()

    Output = OutputSlot()

    def setupOutputs(self):
        self.Output.meta.assignFrom(self.BigLabels.meta)
        self.Output.meta.dtype = numpy.uint32
        self.Output.meta.drange = (0, 1)

    def execute(self, slot, subindex, roi, result):
        """
        Write the relabeled selection for `roi` into `result` and return it.

        This operator is typically used with very big rois, so it is
        deliberately memory-conscious:
          - The small and big inputs are NOT requested in parallel.
          - Finished requests are cleaned immediately.
          - Intermediate results are deleted as soon as possible.
        """
        assert slot == self.Output

        # Evaluate the log level exactly once.  If it were re-checked inside
        # logMemoryIncrease(), enabling DEBUG while this (potentially long)
        # function is running would make the helper read a baseline that was
        # never measured, raising a NameError.
        debug_enabled = logger.isEnabledFor(logging.DEBUG)
        starting_memory_usage_mb = None
        if debug_enabled:
            dtypeBytes = self.SmallLabels.meta.getDtypeBytes()
            roiShape = roi.stop - roi.start
            logger.debug("Roi shape is {} = {} MB".format(
                roiShape, numpy.prod(roiShape) * dtypeBytes / 1e6))
            starting_memory_usage_mb = getMemoryUsageMb()
            logger.debug("Starting with memory usage: {} MB".format(
                starting_memory_usage_mb))

        def logMemoryIncrease(msg):
            """Log a debug message about the RAM usage compared to when this function started execution."""
            if debug_enabled:
                memory_increase_mb = getMemoryUsageMb() - starting_memory_usage_mb
                logger.debug("{}, memory increase is: {} MB".format(
                    msg, memory_increase_mb))

        smallLabelsReq = self.SmallLabels(roi.start, roi.stop)
        smallLabels = smallLabelsReq.wait()
        smallLabelsReq.clean()
        logMemoryIncrease("After obtaining small labels")

        smallNonZero = numpy.ndarray(shape=smallLabels.shape, dtype=bool)
        smallNonZero[...] = (smallLabels != 0)
        del smallLabels

        logMemoryIncrease("Before obtaining big labels")
        bigLabels = self.BigLabels(roi.start, roi.stop).wait()
        logMemoryIncrease("After obtaining big labels")

        # Mask the big labels: pixels outside any small label become 0.
        prod = smallNonZero * bigLabels
        del smallNonZero

        # Labels that passed the masking.
        # bincount().nonzero() is much faster than unique(), which copies and sorts.
        passed = numpy.bincount(prod.flat).nonzero()[0]

        # 0 is not a valid label
        if passed[0] == 0:
            passed = passed[1:]

        logMemoryIncrease("After taking product")
        del prod

        # Lookup table: old big-label value -> new consecutive label.
        # Entries that are never assigned (including 0) stay 0.
        all_label_values = numpy.zeros((bigLabels.max() + 1, ),
                                       dtype=numpy.uint32)
        all_label_values[passed] = numpy.arange(1, len(passed) + 1,
                                                dtype=numpy.uint32)

        # tricky: map the old labels to the new ones, labels that didn't pass
        # are mapped to zero
        result[:] = all_label_values[bigLabels]

        logMemoryIncrease("Just before return")
        return result

    def propagateDirty(self, slot, subindex, roi):
        # A change anywhere in either input can change which labels pass and
        # how they are numbered, so the whole output is marked dirty.
        if slot == self.SmallLabels or slot == self.BigLabels:
            self.Output.setDirty(slice(None))
        else:
            assert False, "Unknown input slot: {}".format(slot.name)
class OpMockPixelClassifier(Operator):
    """
    This class is a simple stand-in for the real pixel classification operator.
    Uses hard-coded data shape and block shape.
    Provides hard-coded outputs.
    """
    name = "OpMockPixelClassifier"

    LabelInputs = InputSlot(
        optional=True,
        level=1)  # Input for providing label data from an external source

    PredictionsFromDisk = InputSlot(
        optional=True,
        level=1)  # TODO: Actually use this input for something

    ClassifierFactory = InputSlot(
        value=ParallelVigraRfLazyflowClassifierFactory(10, 10))

    NonzeroLabelBlocks = OutputSlot(
        level=1,
        stype='object')  # A list of slices that contain non-zero label values
    LabelImages = OutputSlot(level=1)  # Labels from the user

    Classifier = OutputSlot(stype='object')

    PredictionProbabilities = OutputSlot(level=1)

    FreezePredictions = InputSlot()

    LabelNames = OutputSlot()
    LabelColors = OutputSlot()
    PmapColors = OutputSlot()
    Bookmarks = OutputSlot(level=1)

    def __init__(self, *args, **kwargs):
        super(OpMockPixelClassifier, self).__init__(*args, **kwargs)

        self.LabelNames.setValue(["Membrane", "Cytoplasm"])
        self.LabelColors.setValue([(255, 0, 0), (0, 255, 0)])  # Red, Green
        self.PmapColors.setValue([(255, 0, 0), (0, 255, 0)])  # Red, Green

        # One label array per image lane; filled lazily in setupOutputs().
        self._data = []
        self.dataShape = (1, 10, 100, 100, 1)
        self.prediction_shape = self.dataShape[:-1] + (
            2, )  # Hard-coded to provide 2 classes

        self.FreezePredictions.setValue(False)

        self.opClassifier = OpTrainClassifierBlocked(graph=self.graph,
                                                     parent=self)
        self.opClassifier.ClassifierFactory.connect(self.ClassifierFactory)
        self.opClassifier.Labels.connect(self.LabelImages)
        self.opClassifier.nonzeroLabelBlocks.connect(self.NonzeroLabelBlocks)
        self.opClassifier.MaxLabel.setValue(2)

        self.classifier_cache = OpValueCache(graph=self.graph, parent=self)
        self.classifier_cache.Input.connect(self.opClassifier.Classifier)

        # Deterministic fake predictions: two channels that sum to 1.
        p1 = old_div(numpy.indices(self.dataShape).sum(0), 207.0)
        p2 = 1 - p1

        self.predictionData = numpy.concatenate((p1, p2), axis=4)

    def setupOutputs(self):
        """Resize all level-1 slots to the number of label inputs and set their metadata."""
        numImages = len(self.LabelInputs)

        self.PredictionsFromDisk.resize(numImages)
        self.NonzeroLabelBlocks.resize(numImages)
        self.LabelImages.resize(numImages)
        self.PredictionProbabilities.resize(numImages)
        self.opClassifier.Images.resize(numImages)

        # Allocate label storage only for lanes that don't have any yet.
        # setupOutputs() can run many times; appending unconditionally on each
        # call would make self._data grow without bound.
        while len(self._data) < numImages:
            self._data.append(numpy.zeros(self.dataShape))

        for i in range(numImages):
            self.NonzeroLabelBlocks[i].meta.shape = (1, )
            self.NonzeroLabelBlocks[i].meta.dtype = object

            # Hard-coded: Two prediction classes
            self.PredictionProbabilities[i].meta.shape = self.prediction_shape
            self.PredictionProbabilities[i].meta.dtype = numpy.float64
            self.PredictionProbabilities[
                i].meta.axistags = vigra.defaultAxistags('txyzc')

            # Classify with random data
            self.opClassifier.Images[i].setValue(
                vigra.taggedView(numpy.random.random(self.dataShape), 'txyzc'))

            self.LabelImages[i].meta.shape = self.dataShape
            self.LabelImages[i].meta.dtype = numpy.float64
            self.LabelImages[i].meta.axistags = self.opClassifier.Images[
                i].meta.axistags

        self.Classifier.connect(self.opClassifier.Classifier)

    def setInSlot(self, slot, subindex, roi, value):
        """Store label data written into LabelInputs and mark the matching LabelImages region dirty."""
        key = roi.toSlice()
        assert slot.name == "LabelInputs"
        self._data[subindex[0]][key] = value
        self.LabelImages[subindex[0]].setDirty(key)

    def execute(self, slot, subindex, roi, result):
        """Serve NonzeroLabelBlocks, LabelImages and PredictionProbabilities from the hard-coded data."""
        key = roiToSlice(roi.start, roi.stop)
        index = subindex[0]
        if slot.name == "NonzeroLabelBlocks":
            # Split into 10 chunks along axis 2 and report those with labels.
            blocks = []
            slicing = [slice(0, maximum) for maximum in self.dataShape]
            for i in range(10):
                slicing[2] = slice(i * 10, (i + 1) * 10)
                # numpy requires a TUPLE for multi-dimensional indexing; a list
                # of slices is rejected by current numpy versions.
                if not (self._data[index][tuple(slicing)] == 0).all():
                    blocks.append(list(slicing))

            result[0] = blocks
        if slot.name == "LabelImages":
            result[...] = self._data[index][key]
        if slot.name == "PredictionProbabilities":
            result[...] = self.predictionData[key]

    def propagateDirty(self, slot, subindex, roi):
        # Intentionally a no-op.
        pass
class OpPredictionPipeline(OpPredictionPipelineNoCache):
    """
    Cached variant of the cacheless prediction pipeline it derives from.

    On top of the inherited slots it offers the outputs needed by the GUI
    (served from caches) and accepts cached feature images as an extra input.
    """
    FreezePredictions = InputSlot()
    CachedFeatureImages = InputSlot()

    PredictionProbabilities = OutputSlot()
    CachedPredictionProbabilities = OutputSlot()
    PredictionProbabilitiesUint8 = OutputSlot()
    PredictionProbabilityChannels = OutputSlot(level=1)
    SegmentationChannels = OutputSlot(level=1)
    UncertaintyEstimate = OutputSlot()

    def __init__(self, *args, **kwargs):
        super(OpPredictionPipeline, self).__init__(*args, **kwargs)

        # --- Random forest prediction, fed by the CACHED features ---
        self.predict = OpClassifierPredict(parent=self)
        self.predict.name = "OpClassifierPredict"
        self.predict.Classifier.connect(self.Classifier)
        self.predict.Image.connect(self.CachedFeatureImages)
        self.predict.PredictionMask.connect(self.PredictionMask)
        self.predict.LabelsCount.connect(self.NumClasses)
        pmaps = self.predict.PMaps
        self.PredictionProbabilities.connect(pmaps)

        # --- Alternate headless output: uint8 instead of float ---
        # (The drange is updated automatically.)
        self.opConvertToUint8 = OpPixelOperator(parent=self)
        self.opConvertToUint8.Input.connect(pmaps)
        self.opConvertToUint8.Function.setValue(
            lambda a: (255*a).astype(numpy.uint8))
        self.PredictionProbabilitiesUint8.connect(self.opConvertToUint8.Output)

        # --- Prediction cache used by the GUI ---
        self.prediction_cache_gui = OpSlicedBlockedArrayCache(parent=self)
        self.prediction_cache_gui.name = "prediction_cache_gui"
        self.prediction_cache_gui.inputs["fixAtCurrent"].connect(
            self.FreezePredictions)
        self.prediction_cache_gui.inputs["Input"].connect(pmaps)
        cached_pmaps = self.prediction_cache_gui.Output
        self.CachedPredictionProbabilities.connect(cached_pmaps)

        # --- One separate layer per prediction channel (for the GUI) ---
        self.opPredictionSlicer = OpMultiArraySlicer2(parent=self)
        self.opPredictionSlicer.name = "opPredictionSlicer"
        self.opPredictionSlicer.Input.connect(cached_pmaps)
        self.opPredictionSlicer.AxisFlag.setValue('c')
        self.PredictionProbabilityChannels.connect(
            self.opPredictionSlicer.Slices)

        # --- Segmentation (max-channel indicator), also sliced per channel ---
        self.opSegmentor = OpMaxChannelIndicatorOperator(parent=self)
        self.opSegmentor.Input.connect(cached_pmaps)

        self.opSegmentationSlicer = OpMultiArraySlicer2(parent=self)
        self.opSegmentationSlicer.name = "opSegmentationSlicer"
        self.opSegmentationSlicer.Input.connect(self.opSegmentor.Output)
        self.opSegmentationSlicer.AxisFlag.setValue('c')
        self.SegmentationChannels.connect(self.opSegmentationSlicer.Slices)

        # --- Uncertainty estimate layer ---
        self.opUncertaintyEstimator = OpEnsembleMargin(parent=self)
        self.opUncertaintyEstimator.Input.connect(cached_pmaps)

        # The uncertainty is cached so that uncomputed points read as zeros.
        self.opUncertaintyCache = OpSlicedBlockedArrayCache(parent=self)
        self.opUncertaintyCache.name = "opUncertaintyCache"
        self.opUncertaintyCache.Input.connect(
            self.opUncertaintyEstimator.Output)
        self.opUncertaintyCache.fixAtCurrent.connect(self.FreezePredictions)
        self.UncertaintyEstimate.connect(self.opUncertaintyCache.Output)

    def setupOutputs(self):
        """Configure the block shapes of both sliced caches from the axis order of FeatureImages."""
        axis_keys = [tag.key for tag in self.FeatureImages.meta.axistags]

        # (inner, outer) block extent per axis.  For each of the three slicing
        # directions the sliced axis is flattened to an extent of 1.
        base_extents = {
            't': (1, 1),
            'z': (256, 256),
            'y': (256, 256),
            'x': (256, 256),
            'c': (100, 100),
        }

        inner_shapes = []
        outer_shapes = []
        for sliced_axis in 'xyz':
            extents = dict(base_extents)
            extents[sliced_axis] = (1, 1)
            inner_shapes.append(tuple(extents[k][0] for k in axis_keys))
            outer_shapes.append(tuple(extents[k][1] for k in axis_keys))

        inner_block_shapes = tuple(inner_shapes)
        outer_block_shapes = tuple(outer_shapes)

        self.prediction_cache_gui.inputs["innerBlockShape"].setValue(
            inner_block_shapes)
        self.prediction_cache_gui.inputs["outerBlockShape"].setValue(
            outer_block_shapes)

        self.opUncertaintyCache.inputs["innerBlockShape"].setValue(
            inner_block_shapes)
        self.opUncertaintyCache.inputs["outerBlockShape"].setValue(
            outer_block_shapes)

        assert self.opConvertToUint8.Output.meta.drange == (0,255)