class OpDataSelectionGroup(Operator):
    """
    Bundle the datasets of all roles behind one set of slots.

    ``DatasetGroup`` is resized to ``len(DatasetRoles.value)`` as soon as the
    roles become ready.  ``setupOutputs`` wraps ``OpDataSelection`` in an
    ``OperatorWrapper`` fed by ``DatasetGroup`` and exposes the wrapper's
    outputs, plus convenience slots for the first three datasets.
    """
    # Inputs
    ProjectFile = InputSlot(stype='object', optional=True)
    ProjectDataGroup = InputSlot(stype='string', optional=True)
    WorkingDirectory = InputSlot(stype='filestring')
    DatasetRoles = InputSlot(stype='object')
    # Must mark as optional because not all subslots are required.
    DatasetGroup = InputSlot(stype='object', level=1, optional=True)

    # Outputs
    ImageGroup = OutputSlot(level=1)

    # These output slots are provided as a convenience, since otherwise it is
    # tricky to create a lane-wise multislot of level-1 for only a single role.
    # (It can be done, but requires OpTransposeSlots to invert the level-2
    # multislot indexes...)
    Image = OutputSlot()   # The first dataset. Equivalent to ImageGroup[0]
    Image1 = OutputSlot()  # The second dataset. Equivalent to ImageGroup[1]
    Image2 = OutputSlot()  # The third dataset. Equivalent to ImageGroup[2]
    AllowLabels = OutputSlot(stype='bool')  # Pulled from the first dataset only.

    _NonTransposedImageGroup = OutputSlot(level=1)

    # Must be the LAST slot declared in this class.
    # When the shell detects that this slot has been resized,
    # it assumes all the others have already been resized.
    ImageName = OutputSlot()  # Name of the first dataset is used. Other names are ignored.

    def __init__(self, forceAxisOrder=None, *args, **kwargs):
        """
        :param forceAxisOrder: Forwarded unchanged to every wrapped
            ``OpDataSelection`` as its ``forceAxisOrder`` keyword argument.
        """
        super(OpDataSelectionGroup, self).__init__(*args, **kwargs)
        self._opDatasets = None
        self._roles = []
        self._forceAxisOrder = forceAxisOrder

        def handleNewRoles(*args):
            # One DatasetGroup subslot per role.
            self.DatasetGroup.resize(len(self.DatasetRoles.value))

        self.DatasetRoles.notifyReady(handleNewRoles)

    def setupOutputs(self):
        """(Re)build the internal wrapper when needed and route the outputs."""
        # Create internal operators.
        # The wrapper must also be built when it does not exist yet: if the
        # first roles value equals the initial ``self._roles`` (an empty
        # list), the roles comparison alone would leave ``self._opDatasets``
        # as None and the length checks below would fail on it.
        if self._opDatasets is None or self.DatasetRoles.value != self._roles:
            self._roles = self.DatasetRoles.value

            # Clean up the old operators
            self.ImageGroup.disconnect()
            self.Image.disconnect()
            self.Image1.disconnect()
            self.Image2.disconnect()
            self._NonTransposedImageGroup.disconnect()
            if self._opDatasets is not None:
                self._opDatasets.cleanUp()

            self._opDatasets = OperatorWrapper(
                OpDataSelection,
                parent=self,
                operator_kwargs={'forceAxisOrder': self._forceAxisOrder},
                broadcastingSlotNames=[
                    'ProjectFile', 'ProjectDataGroup', 'WorkingDirectory'
                ])
            self.ImageGroup.connect(self._opDatasets.Image)
            self._NonTransposedImageGroup.connect(
                self._opDatasets._NonTransposedImage)
            self._opDatasets.Dataset.connect(self.DatasetGroup)
            self._opDatasets.ProjectFile.connect(self.ProjectFile)
            self._opDatasets.ProjectDataGroup.connect(self.ProjectDataGroup)
            self._opDatasets.WorkingDirectory.connect(self.WorkingDirectory)

        if len(self._opDatasets.Image) > 0:
            self.Image.connect(self._opDatasets.Image[0])

            if len(self._opDatasets.Image) >= 2:
                self.Image1.connect(self._opDatasets.Image[1])
            else:
                self.Image1.disconnect()
                self.Image1.meta.NOTREADY = True

            if len(self._opDatasets.Image) >= 3:
                self.Image2.connect(self._opDatasets.Image[2])
            else:
                self.Image2.disconnect()
                self.Image2.meta.NOTREADY = True

            self.ImageName.connect(self._opDatasets.ImageName[0])
            self.AllowLabels.connect(self._opDatasets.AllowLabels[0])
        else:
            # No datasets at all: every convenience output is unavailable.
            self.Image.disconnect()
            self.Image1.disconnect()
            self.Image2.disconnect()
            self.ImageName.disconnect()
            self.AllowLabels.disconnect()
            self.Image.meta.NOTREADY = True
            self.Image1.meta.NOTREADY = True
            self.Image2.meta.NOTREADY = True
            self.ImageName.meta.NOTREADY = True
            self.AllowLabels.meta.NOTREADY = True

    def execute(self, slot, subindex, rroi, result):
        # Every output is connected to an internal operator, so a direct
        # request here means the slot is unknown or not connected.
        assert False, "Unknown or unconnected output slot: {}".format(
            slot.name)

    def propagateDirty(self, slot, subindex, roi):
        # Output slots are directly connected to internal operators
        pass
class OpConservationTracking(OpTrackingBase):
    """
    Conservation tracking operator built on pgmlink.

    Adds two compressed caches (merger output and relabeled image) on top of
    ``OpTrackingBase``, runs ``pgmlink.ConsTracking`` in :meth:`track`, and
    relabels resolved mergers when the output images are requested.
    """
    DivisionProbabilities = InputSlot(optional=True, stype=Opaque, rtype=List)
    DetectionProbabilities = InputSlot(stype=Opaque, rtype=List)
    NumLabels = InputSlot()

    # compressed cache for merger output
    MergerInputHdf5 = InputSlot(optional=True)
    MergerCleanBlocks = OutputSlot()
    MergerOutputHdf5 = OutputSlot()
    MergerCachedOutput = OutputSlot()  # For the GUI (blockwise access)
    MergerOutput = OutputSlot()

    CoordinateMap = OutputSlot()

    # compressed cache for the relabeled image
    RelabeledInputHdf5 = InputSlot(optional=True)
    RelabeledCleanBlocks = OutputSlot()
    RelabeledOutputHdf5 = OutputSlot()
    RelabeledCachedOutput = OutputSlot()  # For the GUI (blockwise access)
    RelabeledImage = OutputSlot()

    def __init__(self, parent=None, graph=None):
        """Create both compressed caches and wire them to this operator's slots."""
        super(OpConservationTracking, self).__init__(parent=parent,
                                                     graph=graph)

        self._mergerOpCache = OpCompressedCache(parent=self)
        self._mergerOpCache.InputHdf5.connect(self.MergerInputHdf5)
        self._mergerOpCache.Input.connect(self.MergerOutput)
        self.MergerCleanBlocks.connect(self._mergerOpCache.CleanBlocks)
        self.MergerOutputHdf5.connect(self._mergerOpCache.OutputHdf5)
        self.MergerCachedOutput.connect(self._mergerOpCache.Output)

        self._relabeledOpCache = OpCompressedCache(parent=self)
        self._relabeledOpCache.InputHdf5.connect(self.RelabeledInputHdf5)
        self._relabeledOpCache.Input.connect(self.RelabeledImage)
        self.RelabeledCleanBlocks.connect(self._relabeledOpCache.CleanBlocks)
        self.RelabeledOutputHdf5.connect(self._relabeledOpCache.OutputHdf5)
        self.RelabeledCachedOutput.connect(self._relabeledOpCache.Output)

        self.tracker = None
        self._ndim = 3

    def setupOutputs(self):
        """Copy the label image meta data to the extra outputs and size the caches."""
        super(OpConservationTracking, self).setupOutputs()

        self.MergerOutput.meta.assignFrom(self.LabelImage.meta)
        self.RelabeledImage.meta.assignFrom(self.LabelImage.meta)

        # A singleton axis 3 is treated as 2d data.
        self._ndim = 2 if self.LabelImage.meta.shape[3] == 1 else 3

        self._mergerOpCache.BlockShape.setValue(self._blockshape)
        self._relabeledOpCache.BlockShape.setValue(self._blockshape)

        frame_shape = (1, ) + self.LabelImage.meta.shape[1:]  # assumes t,x,y,z,c order
        assert frame_shape[-1] == 1
        self.MergerOutput.meta.ideal_blockshape = frame_shape
        self.RelabeledImage.meta.ideal_blockshape = frame_shape

    @staticmethod
    def _in_tracked_time_range(parameters, t):
        """Return True if ``parameters`` holds a 'time_range' whose first and last entries enclose ``t``."""
        return ('time_range' in parameters
                and t <= parameters['time_range'][-1]
                and t >= parameters['time_range'][0])

    def execute(self, slot, subindex, roi, result):
        """
        Compute ``Output``, ``MergerOutput`` or ``RelabeledImage`` for ``roi``.

        Any other slot is delegated to the base class.  ``result`` is filled
        in place and also returned.
        """
        if slot is self.Output:
            parameters = self.Parameters.value
            trange = range(roi.start[0], roi.stop[0])
            original = np.zeros(result.shape, dtype=slot.meta.dtype)
            super(OpConservationTracking, self).execute(
                slot, subindex, roi, original)

            result[:] = self.LabelImage.get(roi).wait()
            pixel_offsets = roi.start[1:-1]  # offset only in pixels, not time and channel
            for t in trange:
                frame = t - roi.start[0]
                if (self._in_tracked_time_range(parameters, t)
                        and len(self.resolvedto) > t
                        and len(self.resolvedto[t])):
                    result[frame, ..., 0] = self._relabelMergers(
                        result[frame, ..., 0], t, pixel_offsets)
                else:
                    result[frame, ...][:] = 0

            # Overlay the relabeled (non-zero) pixels on the base class result.
            original[result != 0] = result[result != 0]
            result[:] = original

        elif slot is self.MergerOutput:
            parameters = self.Parameters.value
            trange = range(roi.start[0], roi.stop[0])

            result[:] = self.LabelImage.get(roi).wait()
            pixel_offsets = roi.start[1:-1]  # offset only in pixels, not time and channel
            for t in trange:
                frame = t - roi.start[0]
                if (self._in_tracked_time_range(parameters, t)
                        and len(self.mergers) > t
                        and len(self.mergers[t])):
                    if parameters.get('withMergerResolution'):
                        result[frame, ..., 0] = self._relabelMergers(
                            result[frame, ..., 0], t, pixel_offsets, True)
                    else:
                        result[frame, ..., 0] = highlightMergers(
                            result[frame, ..., 0], self.mergers[t])
                else:
                    result[frame, ...][:] = 0

        elif slot is self.RelabeledImage:
            parameters = self.Parameters.value
            trange = range(roi.start[0], roi.stop[0])

            result[:] = self.LabelImage.get(roi).wait()
            pixel_offsets = roi.start[1:-1]  # offset only in pixels, not time and channel
            for t in trange:
                frame = t - roi.start[0]
                if (self._in_tracked_time_range(parameters, t)
                        and len(self.resolvedto) > t
                        and len(self.resolvedto[t])
                        and parameters.get('withMergerResolution')):
                    result[frame, ..., 0] = self._relabelMergers(
                        result[frame, ..., 0], t, pixel_offsets, False, True)
        else:
            # default behaviour
            super(OpConservationTracking, self).execute(
                slot, subindex, roi, result)

        return result

    def setInSlot(self, slot, subindex, roi, value):
        # Only the hdf5 input slots accept direct writes; nothing else to do here.
        assert slot == self.InputHdf5 or slot == self.MergerInputHdf5 or slot == self.RelabeledInputHdf5, \
            "Invalid slot for setInSlot(): {}".format(slot.name)

    def track(self,
              time_range,
              x_range,
              y_range,
              z_range,
              size_range=(0, 100000),
              x_scale=1.0,
              y_scale=1.0,
              z_scale=1.0,
              maxDist=30,
              maxObj=2,
              divThreshold=0.5,
              avgSize=None,
              withTracklets=False,
              sizeDependent=True,
              divWeight=10.0,
              transWeight=10.0,
              withDivisions=True,
              withOpticalCorrection=True,
              withClassifierPrior=False,
              ndim=3,
              cplex_timeout=None,
              withMergerResolution=True,
              borderAwareWidth=0.0,
              withArmaCoordinates=True,
              appearance_cost=500,
              disappearance_cost=500,
              motionModelWeight=10.0,
              force_build_hypotheses_graph=False,
              max_nearest_neighbors=1,
              withBatchProcessing=False,
              solverName="ILP"):
        """
        Run conservation tracking and publish the resulting events.

        The tracking settings are written into the dictionary held by the
        ``Parameters`` slot, a ``pgmlink.ConsTracking`` tracker is built and
        run, mergers are optionally resolved, and the events are stored in
        ``EventsVector``.  Unless ``withBatchProcessing`` is set, the color
        table of the GUI layer named "Merger" is updated afterwards.

        :param avgSize: One-element sequence; if its first entry is > 0 it is
            used as the median object size.  Defaults to ``[0]``.
        :param cplex_timeout: Solver timeout; a falsy value is stored as ``''``
            and passed to the solver as ``1e75``.
        :param solverName: "ILP", "Magnusson" or "Flow"
            (see :meth:`getPgmlinkSolverType`).
        :param motionModelWeight: Currently unused (motion model registration
            is disabled).
        :param force_build_hypotheses_graph: Currently unused; the hypotheses
            graph is always rebuilt.
        :raises Exception: if the ``Parameters`` slot is not ready or tracking
            fails.
        :raises DatasetConstraintError: if the classifier prior is requested
            but inconsistent, or if a frame contains no objects.
        """
        # A fresh list per call: the value is stored in the shared parameters
        # dictionary, so a mutable default would be shared between calls.
        if avgSize is None:
            avgSize = [0]

        if not self.Parameters.ready():
            raise Exception("Parameter slot is not ready")

        # It is assumed that the self.Parameters object is changed only at
        # this place (ugly assumption).
        # NOTE(review): the original comment here described a
        # "parameters_changed" dictionary recording which keys changed; no
        # such bookkeeping exists in this method.
        parameters = self.Parameters.value

        parameters['maxDist'] = maxDist
        parameters['maxObj'] = maxObj
        parameters['divThreshold'] = divThreshold
        parameters['avgSize'] = avgSize
        parameters['withTracklets'] = withTracklets
        parameters['sizeDependent'] = sizeDependent
        parameters['divWeight'] = divWeight
        parameters['transWeight'] = transWeight
        parameters['withDivisions'] = withDivisions
        parameters['withOpticalCorrection'] = withOpticalCorrection
        parameters['withClassifierPrior'] = withClassifierPrior
        parameters['withMergerResolution'] = withMergerResolution
        parameters['borderAwareWidth'] = borderAwareWidth
        parameters['withArmaCoordinates'] = withArmaCoordinates
        parameters['appearanceCost'] = appearance_cost
        parameters['disappearanceCost'] = disappearance_cost

        do_build_hypotheses_graph = True

        if cplex_timeout:
            parameters['cplex_timeout'] = cplex_timeout
        else:
            parameters['cplex_timeout'] = ''
            cplex_timeout = float(1e75)

        if withClassifierPrior:
            label_count_message = (
                'The max. number of objects must be consistent with the number of labels given in Object Count Classification.\n' +
                'Check whether you have (i) the correct number of label names specified in Object Count Classification, and (ii) provided at least ' +
                'one training example for each class.')
            if not self.DetectionProbabilities.ready() or len(
                    self.DetectionProbabilities([0]).wait()[0]) == 0:
                raise DatasetConstraintError(
                    'Tracking',
                    'Classifier not ready yet. Did you forget to train the Object Count Classifier?'
                )
            if not self.NumLabels.ready() or self.NumLabels.value < (maxObj + 1):
                raise DatasetConstraintError('Tracking', label_count_message)
            if len(self.DetectionProbabilities([0]).wait()[0][0]) < (maxObj + 1):
                raise DatasetConstraintError('Tracking', label_count_message)

        median_obj_size = [0]

        fs, ts, empty_frame, max_traxel_id_at = self._generate_traxelstore(
            time_range,
            x_range,
            y_range,
            z_range,
            size_range,
            x_scale,
            y_scale,
            z_scale,
            median_object_size=median_obj_size,
            with_div=withDivisions,
            with_opt_correction=withOpticalCorrection,
            with_classifier_prior=withClassifierPrior)

        if empty_frame:
            raise DatasetConstraintError(
                'Tracking', 'Can not track frames with 0 objects, abort.')

        if avgSize[0] > 0:
            median_obj_size = avgSize

        logger.info('median_obj_size = {}'.format(median_obj_size))

        ep_gap = 0.05
        transition_parameter = 5

        fov_bounds = (
            time_range[0] * 1.0,
            x_range[0] * x_scale,
            y_range[0] * y_scale,
            z_range[0] * z_scale,
            time_range[-1] * 1.0,
            (x_range[1] - 1) * x_scale,
            (y_range[1] - 1) * y_scale,
            (z_range[1] - 1) * z_scale,
        )
        fov = pgmlink.FieldOfView(*fov_bounds)

        logger.info('fov = {},{},{},{},{},{},{},{}'.format(*fov_bounds))

        if ndim == 2:
            assert z_range[0] * z_scale == 0 and (
                z_range[1] - 1) * z_scale == 0, "fov of z must be (0,0) if ndim==2"

        if self.tracker is None:
            do_build_hypotheses_graph = True

        solverType = self.getPgmlinkSolverType(solverName)

        if do_build_hypotheses_graph:
            print('\033[94m' + "make new graph" + '\033[0m')
            self.tracker = pgmlink.ConsTracking(
                int(maxObj),
                bool(sizeDependent),  # size_dependent_detection_prob
                float(median_obj_size[0]),  # median_object_size
                float(maxDist),
                bool(withDivisions),
                float(divThreshold),
                "none",  # detection_rf_filename
                fov,
                "none",  # dump traxelstore,
                solverType,
                ndim)
            self.tracker.buildGraph(ts, max_nearest_neighbors)

        # create dummy uncertainty parameter object with just one iteration,
        # so no perturbations at all (iter=0 -> MAP)
        sigmas = pgmlink.VectorOfDouble()
        for i in range(5):
            sigmas.append(0.0)
        uncertaintyParams = pgmlink.UncertaintyParameter(
            1, pgmlink.DistrId.PerturbAndMAP, sigmas)

        params = self.tracker.get_conservation_tracking_parameters(
            0,  # forbidden_cost
            float(ep_gap),  # ep_gap
            bool(withTracklets),  # with tracklets
            float(10.0),  # detection weight
            float(divWeight),  # division weight
            float(transWeight),  # transition weight
            float(disappearance_cost),  # disappearance cost
            float(appearance_cost),  # appearance cost
            bool(withMergerResolution),  # with merger resolution
            int(ndim),  # ndim
            float(transition_parameter),  # transition param
            float(borderAwareWidth),  # border width
            True,  # with_constraints
            uncertaintyParams,  # uncertainty parameters
            float(cplex_timeout),  # cplex timeout
            None,  # transition classifier
            solverType,
            False,  # training to hard constraints
            1  # num threads
        )

        # Motion model registration is currently disabled:
        # if motionModelWeight > 0:
        #     logger.info("Registering motion model with weight {}".format(motionModelWeight))
        #     params.register_motion_model4_func(swirl_motion_func_creator(motionModelWeight), motionModelWeight * 25.0)

        try:
            eventsVector = self.tracker.track(params, False)

            # we have a vector such that we could get a vector per perturbation
            eventsVector = eventsVector[0]

            # extract the coordinates with the given event vector
            if withMergerResolution:
                coordinate_map = pgmlink.TimestepIdCoordinateMap()

                self._get_merger_coordinates(coordinate_map, time_range,
                                             eventsVector)
                self.CoordinateMap.setValue(coordinate_map)

                eventsVector = self.tracker.resolve_mergers(
                    eventsVector,
                    params,
                    coordinate_map.get(),
                    float(ep_gap),
                    float(transWeight),
                    bool(withTracklets),
                    ndim,
                    transition_parameter,
                    max_traxel_id_at,
                    True,  # with_constraints
                    None)  # TransitionClassifier
        except Exception as e:
            raise Exception('Tracking terminated unsuccessfully: ' + str(e))

        if len(eventsVector) == 0:
            raise Exception(
                'Tracking terminated unsuccessfully: Events vector has zero length.')

        events = get_events(eventsVector)
        self.Parameters.setValue(parameters, check_changed=False)
        self.EventsVector.setValue(events, check_changed=False)
        self.RelabeledImage.setDirty()

        if not withBatchProcessing:
            gui = self.parent.parent.trackingApplet._gui.currentGui()
            merger_layer_idx = gui.layerstack.findMatchingIndex(
                lambda x: x.name == "Merger")
            # NOTE(review): the index is not used below; the lookup is kept
            # because it was part of the original call sequence.
            tracking_layer_idx = gui.layerstack.findMatchingIndex(
                lambda x: x.name == "Tracking")
            if 'withMergerResolution' in parameters and not parameters['withMergerResolution']:
                gui.layerstack[merger_layer_idx].colorTable = gui.merger_colortable
            else:
                gui.layerstack[merger_layer_idx].colorTable = gui.tracking_colortable

    @staticmethod
    def getPgmlinkSolverType(solverName):
        """
        Map a solver name to the matching ``pgmlink.ConsTrackingSolverType``.

        :param solverName: "ILP", "Magnusson" or "Flow".
        :raises AssertionError: if pgmlink does not provide the requested solver.
        :raises ValueError: if ``solverName`` is not one of the known names.
        """
        attribute_by_name = {
            "ILP": "CplexSolver",
            "Magnusson": "DynProgSolver",
            "Flow": "FlowSolver",
        }
        if solverName not in attribute_by_name:
            raise ValueError("Unknown solver {} selected".format(solverName))

        attribute = attribute_by_name[solverName]
        if not hasattr(pgmlink.ConsTrackingSolverType, attribute):
            raise AssertionError(
                "Cannot select {0} solver if pgmlink was compiled without {0} support"
                .format(solverName))
        return getattr(pgmlink.ConsTrackingSolverType, attribute)

    def propagateDirty(self, inputSlot, subindex, roi):
        """Forward to the base class; on ``NumLabels`` changes also update the GUI's max-objects box."""
        super(OpConservationTracking, self).propagateDirty(
            inputSlot, subindex, roi)

        if inputSlot == self.NumLabels:
            if self.parent.parent.trackingApplet._gui \
                    and self.parent.parent.trackingApplet._gui.currentGui() \
                    and self.NumLabels.ready() \
                    and self.NumLabels.value > 1:
                self.parent.parent.trackingApplet._gui.currentGui(
                )._drawer.maxObjectsBox.setValue(self.NumLabels.value - 1)

    def _get_merger_coordinates(self, coordinate_map, time_range,
                                eventsVector):
        """
        Fill ``coordinate_map`` with the pixel coordinates of every merger event.

        For each merger in ``eventsVector`` the object's bounding box is read
        from ``LabelImage`` and handed to
        ``pgmlink.extract_coord_by_timestep_id``.
        """
        # fetch features
        feats = self.ObjectFeatures(time_range).wait()
        # iterate over all timesteps
        for t in feats.keys():
            rc = feats[t][default_features_key]['RegionCenter']
            lower = feats[t][default_features_key]['Coord<Minimum>']
            upper = feats[t][default_features_key]['Coord<Maximum>']
            size = feats[t][default_features_key]['Count']
            for event in eventsVector[t]:
                # check for merger events
                if event.type == pgmlink.EventType.Merger:
                    idx = event.traxel_ids[0]
                    # generate roi: assume the following order: txyzc
                    n_dim = len(rc[idx])
                    roi = [0] * 5
                    roi[0] = slice(int(t), int(t + 1))
                    roi[1] = slice(int(lower[idx][0]), int(upper[idx][0] + 1))
                    roi[2] = slice(int(lower[idx][1]), int(upper[idx][1] + 1))
                    if n_dim == 3:
                        roi[3] = slice(int(lower[idx][2]),
                                       int(upper[idx][2] + 1))
                    else:
                        assert n_dim == 2
                    image_excerpt = self.LabelImage[roi].wait()
                    if n_dim == 2:
                        image_excerpt = image_excerpt[0, ..., 0, 0]
                    elif n_dim == 3:
                        image_excerpt = image_excerpt[0, ..., 0]
                    else:
                        # Only reachable when asserts are disabled.
                        raise Exception(
                            "n_dim = %s instead of 2 or 3" % n_dim)

                    pgmlink.extract_coord_by_timestep_id(
                        coordinate_map, image_excerpt,
                        lower[idx].astype(np.int64), t, idx,
                        int(size[idx, 0]))

    def _relabelMergers(self,
                        volume,
                        time,
                        pixel_offsets=(0, 0, 0),
                        onlyMergers=False,
                        noRelabeling=False):
        """
        Write the resolved merger ids of frame ``time`` into ``volume``.

        :param volume: Label volume of one frame; modified in place.
        :param time: Frame index into ``self.resolvedto``.
        :param pixel_offsets: Offset of ``volume`` within the full frame.
        :param onlyMergers: If True, every pixel that does not carry one of
            the new ids is set to zero.
        :param noRelabeling: If True, return the volume without mapping it
            through ``self.label2color``.
        """
        if self.CoordinateMap.value.size() == 0:
            logger.info(
                "Skipping merger relabeling because coordinate map is empty")
            if onlyMergers:
                return np.zeros_like(volume)
            else:
                return volume
        if time >= len(self.resolvedto):
            if onlyMergers:
                return np.zeros_like(volume)
            else:
                return volume

        coordinate_map = self.CoordinateMap.value

        # TODO Reliable distinction between 2d and 3d?
        if self._ndim == 2:
            # Assume we have 2d data: bind z to zero
            relabel_volume = volume[..., 0]
        else:
            # For 3d data use the whole volume
            relabel_volume = volume

        valid_ids = []
        for new_ids in self.resolvedto[time].values():
            for new_id in new_ids:
                # relabel
                pgmlink.update_labelimage(
                    coordinate_map, relabel_volume,
                    np.array(pixel_offsets, dtype=np.int64), int(time),
                    int(new_id))
                valid_ids.append(new_id)

        if onlyMergers:
            # find indices of merger ids, set everything else to zero
            idx = np.in1d(volume.ravel(), valid_ids).reshape(volume.shape)
            # Logical negation of a boolean mask must use `~`; unary minus on
            # boolean arrays is rejected by current NumPy versions.
            volume[~idx] = 0

        if noRelabeling:
            return volume
        else:
            return relabel(volume, self.label2color[time])

    def do_export(self,
                  settings,
                  selected_features,
                  progress_slot,
                  lane_index,
                  filename_suffix=""):
        """
        Implements ExportOperator.do_export(settings, selected_features, progress_slot).
        Most likely called from ExportOperator.export_object_data.

        :param settings: the settings for the exporter
        :param selected_features:
        :param progress_slot:
        :param lane_index: Ignored. (This is a single-lane operator. It is the
            caller's responsibility to make sure he's calling the right lane.)
        :param filename_suffix: If provided, appended to the filename (before
            the extension).
        :return:
        """
        # assert lane_index == 0, "This has only been tested in tracking workflows with a single image."

        with_divisions = self.Parameters.value[
            "withDivisions"] if self.Parameters.ready() else False
        with_merger_resolution = self.Parameters.value[
            "withMergerResolution"] if self.Parameters.ready() else False

        if with_divisions:
            object_feature_slot = self.ObjectFeaturesWithDivFeatures
        else:
            object_feature_slot = self.ObjectFeatures

        opRelabeledRegionFeatures = None
        if with_merger_resolution:
            label_image = self.RelabeledImage
            opRelabeledRegionFeatures = self._setupRelabeledFeatureSlot(
                object_feature_slot)
            object_feature_slot = opRelabeledRegionFeatures.RegionFeatures
        else:
            label_image = self.LabelImage

        try:
            self._do_export_impl(settings, selected_features, progress_slot,
                                 object_feature_slot, label_image, lane_index,
                                 filename_suffix)
        finally:
            # Release the temporary operator even if the export fails.
            if opRelabeledRegionFeatures is not None:
                opRelabeledRegionFeatures.cleanUp()

    def _setupRelabeledFeatureSlot(self, original_feature_slot):
        """
        Create and return an operator that provides features for the relabeled objects.

        The caller is responsible for calling ``cleanUp()`` on the returned
        operator.
        """
        from ilastik.applets.trackingFeatureExtraction import config
        # when exporting after merger resolving, the stored object features
        # are not up to date for the relabeled objects
        opRelabeledRegionFeatures = OpRelabeledMergerFeatureExtraction(
            parent=self)
        opRelabeledRegionFeatures.RawImage.connect(self.RawImage)
        opRelabeledRegionFeatures.LabelImage.connect(self.LabelImage)
        opRelabeledRegionFeatures.RelabeledImage.connect(self.RelabeledImage)
        opRelabeledRegionFeatures.OriginalRegionFeatures.connect(
            original_feature_slot)
        opRelabeledRegionFeatures.ResolvedTo.setValue(self.resolvedto)

        vigra_features = list((set(config.vigra_features)).union(
            config.selected_features_objectcount[config.features_vigra_name]))
        feature_names_vigra = {}
        feature_names_vigra[config.features_vigra_name] = {
            name: {}
            for name in vigra_features
        }
        opRelabeledRegionFeatures.FeatureNames.setValue(feature_names_vigra)

        return opRelabeledRegionFeatures
class OpSlicedBlockedArrayCache(OpCache):
    """
    Cache that keeps one ``OpBlockedArrayCache`` per configured block shape.

    A request on ``Output`` is served by the inner cache whose inner block
    shape is closest (squared euclidean distance) to the requested roi shape.
    Each inner cache is also reachable directly through ``InnerOutputs``.
    """
    name = "OpSlicedBlockedArrayCache"
    description = ""

    # Inputs
    Input = InputSlot()
    innerBlockShape = InputSlot()
    outerBlockShape = InputSlot()
    fixAtCurrent = InputSlot(value=False)

    # Outputs
    Output = OutputSlot()
    InnerOutputs = OutputSlot(level=1)

    loggerName = __name__ + ".OpSlicedBlockedArrayCache"
    logger = logging.getLogger(loggerName)
    traceLogger = logging.getLogger("TRACE." + loggerName)

    def __init__(self, *args, **kwargs):
        super(OpSlicedBlockedArrayCache, self).__init__(*args, **kwargs)
        # One inner cache per entry of innerBlockShape; built in setupOutputs().
        self._innerOps = []

    def generateReport(self, report):
        """Fill ``report`` with this cache's statistics and append one child node per inner cache."""
        report.name = self.name
        report.fractionOfUsedMemoryDirty = self.fractionOfUsedMemoryDirty()
        report.usedMemory = self.usedMemory()
        report.lastAccessTime = self.lastAccessTime()
        report.dtype = self.Output.meta.dtype
        report.type = type(self)
        report.id = id(self)

        output_shape = self.Output.meta.shape
        if output_shape is not None:
            report.roi = ([0] * len(output_shape), output_shape)

        for inner_op in self._innerOps:
            child = MemInfoNode()
            report.children.append(child)
            inner_op.generateReport(child)

    def usedMemory(self):
        """Return the summed ``usedMemory()`` of all inner caches as a float."""
        total = 0.0
        for inner_op in self._innerOps:
            total += inner_op.usedMemory()
        return total

    def setupOutputs(self):
        """Create/configure the inner caches and publish the output meta data."""
        self.shape = self.inputs["Input"].meta.shape
        self._outerShapes = self.inputs["outerBlockShape"].value
        self._innerShapes = self.inputs["innerBlockShape"].value

        # Every block shape must have the same number of axes as the input.
        input_rank = len(self.Input.meta.shape)
        for blockshape in self._innerShapes + self._outerShapes:
            if len(blockshape) != input_rank:
                self.Output.meta.NOTREADY = True
                return

        # FIXME: This is wrong: Shouldn't it actually compare the new inner
        #        block shape with the old one?
        if len(self._innerShapes) != len(self._innerOps):
            # Clean up previous inner operators
            for inner_output in self.InnerOutputs:
                inner_output.disconnect()
            for old_op in self._innerOps:
                old_op.cleanUp()

            self._innerOps = []

            for _ in self._innerShapes:
                new_op = OpBlockedArrayCache(parent=self)
                new_op.inputs["fixAtCurrent"].connect(
                    self.inputs["fixAtCurrent"])
                self._innerOps.append(new_op)
                new_op.inputs["Input"].connect(self.inputs["Input"])

                # Forward "value changed" notifications to our own output
                new_op.Output.notifyValueChanged(
                    self.Output._sig_value_changed)

        for position, innershape in enumerate(self._innerShapes):
            inner_op = self._innerOps[position]
            inner_op.inputs["innerBlockShape"].setValue(innershape)
            inner_op.inputs["outerBlockShape"].setValue(
                self._outerShapes[position])

        self.Output.meta.assignFrom(self.Input.meta)

        # Estimate ram usage
        output_meta = self.Output.meta
        ram_per_pixel = 0
        if output_meta.dtype == object or output_meta.dtype == numpy.object_:
            ram_per_pixel = sys.getsizeof(None)
        elif numpy.issubdtype(output_meta.dtype, numpy.dtype):
            ram_per_pixel = output_meta.dtype().nbytes

        tagged_shape = output_meta.getTaggedShape()
        if 'c' in tagged_shape:
            ram_per_pixel *= float(tagged_shape['c'])

        if output_meta.ram_usage_per_requested_pixel is not None:
            ram_per_pixel = max(ram_per_pixel,
                                output_meta.ram_usage_per_requested_pixel)

        output_meta.ram_usage_per_requested_pixel = ram_per_pixel

        # We also provide direct access to each of our inner cache outputs.
        self.InnerOutputs.resize(len(self._innerOps))
        for position, inner_output in enumerate(self.InnerOutputs):
            inner_output.connect(self._innerOps[position].Output)

    def execute(self, slot, subindex, roi, result):
        """Serve ``roi`` from the inner cache whose block shape best matches the roi shape."""
        started = time.time()
        assert slot == self.Output

        key = roi.toSlice()
        start, stop = sliceToRoi(key, self.shape)
        roishape = numpy.array(stop) - numpy.array(start)

        # Pick the first inner cache with the smallest squared distance
        # between its inner block shape and the requested shape.
        best_index = 0
        best_distance = sys.maxint
        for candidate, blockshape in enumerate(self._innerShapes):
            delta = roishape - numpy.array(blockshape)
            distance = numpy.sum(delta * delta)
            if distance < best_distance:
                best_index = candidate
                best_distance = distance

        chosen_op = self._innerOps[best_index]
        chosen_op.outputs["Output"][key].writeInto(result).wait()
        self.logger.debug("read %r took %f msec." %
                          (roi.pprint(), 1000.0 * (time.time() - started)))

    def propagateDirty(self, slot, subindex, roi):
        """Translate dirtiness of an input slot into dirtiness of ``Output``."""
        key = roi.toSlice()
        # We *could* simply forward dirty notifications from our inner operators
        # to our output (by subscribing to their notifyDirty signals),
        # but that would result in duplicates of many (not all!) dirty notifications
        # (since we have more than one inner cache, and each is receiving dirty notifications)
        # Instead, we simply mark *everything* dirty when we become unfixed or if the block shape changes.
        if self.fixAtCurrent.value:
            # While fixed, nothing is propagated.
            return

        if slot == self.Input:
            self.Output.setDirty(key)
        elif slot == self.outerBlockShape or slot == self.innerBlockShape:
            # Blockshape changes don't trigger dirty notifications.
            # It is considered an error to change the blockshape after the
            # initial configuration.
            pass
        elif slot == self.fixAtCurrent:
            self.Output.setDirty(slice(None))
        else:
            assert False, "Unknown dirty input slot"
class OpPixelClassification(Operator): """ Top-level operator for pixel classification """ name = "OpPixelClassification" category = "Top-level" # Graph inputs InputImages = InputSlot( level=1) # Original input data. Used for display only. PredictionMasks = InputSlot( level=1, optional=True ) # Routed to OpClassifierPredict.PredictionMask. See there for details. LabelInputs = InputSlot( optional=True, level=1) # Input for providing label data from an external source LabelsAllowedFlags = InputSlot( stype='bool', level=1) # Specifies which images are permitted to be labeled FeatureImages = InputSlot( level=1 ) # Computed feature images (each channel is a different feature) CachedFeatureImages = InputSlot(level=1) # Cached feature data. FreezePredictions = InputSlot(stype='bool') ClassifierFactory = InputSlot( value=ParallelVigraRfLazyflowClassifierFactory(10, 10)) PredictionsFromDisk = InputSlot(optional=True, level=1) PredictionProbabilities = OutputSlot( level=1 ) # Classification predictions (via feature cache for interactive speed) PredictionProbabilityChannels = OutputSlot( level=2) # Classification predictions, enumerated by channel SegmentationChannels = OutputSlot( level=2) # Binary image of the final selections. LabelImages = OutputSlot(level=1) # Labels from the user NonzeroLabelBlocks = OutputSlot( level=1) # A list if slices that contain non-zero label values Classifier = OutputSlot( ) # We provide the classifier as an external output for other applets to use CachedPredictionProbabilities = OutputSlot( level=1 ) # Classification predictions (via feature cache AND prediction cache) HeadlessPredictionProbabilities = OutputSlot( level=1 ) # Classification predictions ( via no image caches (except for the classifier itself ) HeadlessUint8PredictionProbabilities = OutputSlot( level=1) # Same as above, but 0-255 uint8 instead of 0.0-1.0 float32 HeadlessUncertaintyEstimate = OutputSlot( level=1 ) # Same as uncertaintly estimate, but does not rely on cached data. 
UncertaintyEstimate = OutputSlot(level=1) SimpleSegmentation = OutputSlot(level=1) # For debug, for now # GUI-only (not part of the pipeline, but saved to the project) LabelNames = OutputSlot() LabelColors = OutputSlot() PmapColors = OutputSlot() NumClasses = OutputSlot() def setupOutputs(self): self.LabelNames.meta.dtype = object self.LabelNames.meta.shape = (1, ) self.LabelColors.meta.dtype = object self.LabelColors.meta.shape = (1, ) self.PmapColors.meta.dtype = object self.PmapColors.meta.shape = (1, ) def __init__(self, *args, **kwargs): """ Instantiate all internal operators and connect them together. """ super(OpPixelClassification, self).__init__(*args, **kwargs) # Default values for some input slots self.FreezePredictions.setValue(True) self.LabelNames.setValue([]) self.LabelColors.setValue([]) self.PmapColors.setValue([]) # SPECIAL connection: The LabelInputs slot doesn't get it's data # from the InputImages slot, but it's shape must match. self.LabelInputs.connect(self.InputImages) # Hook up Labeling Pipeline self.opLabelPipeline = OpMultiLaneWrapper( OpLabelPipeline, parent=self, broadcastingSlotNames=['DeleteLabel']) self.opLabelPipeline.RawImage.connect(self.InputImages) self.opLabelPipeline.LabelInput.connect(self.LabelInputs) self.opLabelPipeline.DeleteLabel.setValue(-1) self.LabelImages.connect(self.opLabelPipeline.Output) self.NonzeroLabelBlocks.connect(self.opLabelPipeline.nonzeroBlocks) # Hook up the Training operator self.opTrain = OpTrainClassifierBlocked(parent=self) self.opTrain.ClassifierFactory.connect(self.ClassifierFactory) self.opTrain.Labels.connect(self.opLabelPipeline.Output) self.opTrain.Images.connect(self.CachedFeatureImages) self.opTrain.nonzeroLabelBlocks.connect( self.opLabelPipeline.nonzeroBlocks) # Hook up the Classifier Cache # The classifier is cached here to allow serializers to force in # a pre-calculated classifier (loaded from disk) self.classifier_cache = OpValueCache(parent=self) self.classifier_cache.name = 
"OpPixelClassification.classifier_cache" self.classifier_cache.inputs["Input"].connect( self.opTrain.outputs['Classifier']) self.classifier_cache.inputs["fixAtCurrent"].connect( self.FreezePredictions) self.Classifier.connect(self.classifier_cache.Output) # Hook up the prediction pipeline inputs self.opPredictionPipeline = OpMultiLaneWrapper(OpPredictionPipeline, parent=self) self.opPredictionPipeline.FeatureImages.connect(self.FeatureImages) self.opPredictionPipeline.CachedFeatureImages.connect( self.CachedFeatureImages) self.opPredictionPipeline.Classifier.connect( self.classifier_cache.Output) self.opPredictionPipeline.FreezePredictions.connect( self.FreezePredictions) self.opPredictionPipeline.PredictionsFromDisk.connect( self.PredictionsFromDisk) self.opPredictionPipeline.PredictionMask.connect(self.PredictionMasks) def _updateNumClasses(*args): """ When the number of labels changes, we MUST make sure that the prediction image changes its shape (the number of channels). Since setupOutputs is not called for mere dirty notifications, but is called in response to setValue(), we use this function to call setValue(). 
""" numClasses = len(self.LabelNames.value) self.opTrain.MaxLabel.setValue(numClasses) self.opPredictionPipeline.NumClasses.setValue(numClasses) self.NumClasses.setValue(numClasses) self.LabelNames.notifyDirty(_updateNumClasses) # Prediction pipeline outputs -> Top-level outputs self.PredictionProbabilities.connect( self.opPredictionPipeline.PredictionProbabilities) self.CachedPredictionProbabilities.connect( self.opPredictionPipeline.CachedPredictionProbabilities) self.HeadlessPredictionProbabilities.connect( self.opPredictionPipeline.HeadlessPredictionProbabilities) self.HeadlessUint8PredictionProbabilities.connect( self.opPredictionPipeline.HeadlessUint8PredictionProbabilities) self.PredictionProbabilityChannels.connect( self.opPredictionPipeline.PredictionProbabilityChannels) self.SegmentationChannels.connect( self.opPredictionPipeline.SegmentationChannels) self.UncertaintyEstimate.connect( self.opPredictionPipeline.UncertaintyEstimate) self.SimpleSegmentation.connect( self.opPredictionPipeline.SimpleSegmentation) self.HeadlessUncertaintyEstimate.connect( self.opPredictionPipeline.HeadlessUncertaintyEstimate) def inputResizeHandler(slot, oldsize, newsize): if (newsize == 0): self.LabelImages.resize(0) self.NonzeroLabelBlocks.resize(0) self.PredictionProbabilities.resize(0) self.CachedPredictionProbabilities.resize(0) self.InputImages.notifyResized(inputResizeHandler) # Debug assertions: Check to make sure the non-wrapped operators stayed that way. 
assert self.opTrain.Images.operator == self.opTrain def handleNewInputImage(multislot, index, *args): def handleInputReady(slot): self._checkConstraints(index) self.setupCaches(multislot.index(slot)) multislot[index].notifyReady(handleInputReady) self.InputImages.notifyInserted(handleNewInputImage) def handleNewMaskImage(multislot, index, *args): def handleInputReady(slot): self._checkConstraints(index) multislot[index].notifyReady(handleInputReady) self.PredictionMasks.notifyInserted(handleNewMaskImage) # All input multi-slots should be kept in sync # Output multi-slots will auto-sync via the graph multiInputs = filter(lambda s: s.level >= 1, self.inputs.values()) for s1 in multiInputs: for s2 in multiInputs: if s1 != s2: def insertSlot(a, b, position, finalsize): a.insertSlot(position, finalsize) s1.notifyInserted(partial(insertSlot, s2)) def removeSlot(a, b, position, finalsize): a.removeSlot(position, finalsize) s1.notifyRemoved(partial(removeSlot, s2)) def setupCaches(self, imageIndex): numImages = len(self.InputImages) inputSlot = self.InputImages[imageIndex] # # Can't setup if all inputs haven't been set yet. # if numImages != len(self.FeatureImages) or \ # numImages != len(self.CachedFeatureImages): # return # # self.LabelImages.resize(numImages) self.LabelInputs.resize(numImages) # Special case: We have to set up the shape of our label *input* according to our image input shape shapeList = list(self.InputImages[imageIndex].meta.shape) try: channelIndex = self.InputImages[imageIndex].meta.axistags.index( 'c') shapeList[channelIndex] = 1 except: pass self.LabelInputs[imageIndex].meta.shape = tuple(shapeList) self.LabelInputs[imageIndex].meta.axistags = inputSlot.meta.axistags def _checkConstraints(self, laneIndex): """ Ensure that all input images have the same number of channels. 
""" if not self.InputImages[laneIndex].ready(): return thisLaneTaggedShape = self.InputImages[laneIndex].meta.getTaggedShape() # Find a different lane and use it for comparison validShape = thisLaneTaggedShape for i, slot in enumerate(self.InputImages): if slot.ready() and i != laneIndex: validShape = slot.meta.getTaggedShape() break if validShape['c'] != thisLaneTaggedShape['c']: raise DatasetConstraintError( "Pixel Classification", "All input images must have the same number of channels. "\ "Your new image has {} channel(s), but your other images have {} channel(s)."\ .format( thisLaneTaggedShape['c'], validShape['c'] ) ) if len(validShape) != len(thisLaneTaggedShape): raise DatasetConstraintError( "Pixel Classification", "All input images must have the same dimensionality. "\ "Your new image has {} dimensions (including channel), but your other images have {} dimensions."\ .format( len(thisLaneTaggedShape), len(validShape) ) ) mask_slot = self.PredictionMasks[laneIndex] input_shape = tuple(thisLaneTaggedShape.values()) if mask_slot.ready() and mask_slot.meta.shape[:-1] != input_shape[:-1]: raise DatasetConstraintError( "Pixel Classification", "If you supply a prediction mask, it must have the same shape as the input image."\ "Your input image has shape {}, but your mask has shape {}."\ .format( input_shape, mask_slot.meta.shape ) ) def setInSlot(self, slot, subindex, roi, value): # Nothing to do here: All inputs that support __setitem__ # are directly connected to internal operators. pass def propagateDirty(self, slot, subindex, roi): # Nothing to do here: All outputs are directly connected to # internal operators that handle their own dirty propagation. pass def addLane(self, laneIndex): numLanes = len(self.InputImages) assert numLanes == laneIndex, "Image lanes must be appended." 
self.InputImages.resize(numLanes + 1) def removeLane(self, laneIndex, finalLength): self.InputImages.removeSlot(laneIndex, finalLength) def getLane(self, laneIndex): return OperatorSubView(self, laneIndex)
class OpTrackingFeatureExtraction(Operator):
    """
    Feature extraction for tracking.

    Wires an internal OpObjectExtraction (region features) together with an
    OpCachedDivisionFeatures (division features, adapted through
    OpAdaptTimeListRoi) and additionally exposes merged outputs
    (RegionFeaturesAll, ComputedFeatureNamesAll) that combine both.
    """
    name = "Tracking Feature Extraction"

    TranslationVectors = InputSlot(optional=True)
    RawImage = InputSlot()
    BinaryImage = InputSlot()

    # which features to compute.
    # nested dictionary with format:
    # dict[plugin_name][feature_name][parameter_name] = parameter_value
    # for example {"Standard Object Features": {"Mean in neighborhood":{"margin": (5, 5, 2)}}}
    FeatureNamesVigra = InputSlot(rtype=List, stype=Opaque, value={})
    FeatureNamesDivision = InputSlot(rtype=List, stype=Opaque, value={})

    # Bypass cache (for headless mode)
    BypassModeEnabled = InputSlot(value=False)

    LabelImage = OutputSlot()
    ObjectCenterImage = OutputSlot()

    # the computed features.
    # nested dictionary with format:
    # dict[plugin_name][feature_name] = feature_value
    RegionFeaturesVigra = OutputSlot(stype=Opaque, rtype=List)
    RegionFeaturesDivision = OutputSlot(stype=Opaque, rtype=List)
    RegionFeaturesAll = OutputSlot(stype=Opaque, rtype=List)

    ComputedFeatureNamesAll = OutputSlot(rtype=List, stype=Opaque)
    ComputedFeatureNamesNoDivisions = OutputSlot(rtype=List, stype=Opaque)

    # For compatibility with tracking workflow, the RegionFeatures output
    # has rtype=List, indexed by t.
    # For other workflows, output has rtype=ArrayLike, indexed by (t)
    BlockwiseRegionFeaturesVigra = OutputSlot()
    BlockwiseRegionFeaturesDivision = OutputSlot()

    CleanLabelBlocks = OutputSlot()
    LabelImageCacheInput = InputSlot(optional=True)

    RegionFeaturesCacheInputVigra = InputSlot(optional=True)
    RegionFeaturesCleanBlocksVigra = OutputSlot()

    RegionFeaturesCacheInputDivision = InputSlot(optional=True)
    RegionFeaturesCleanBlocksDivision = OutputSlot()

    def __init__(self, parent):
        """Create the internal operators and connect them to this operator's slots."""
        super(OpTrackingFeatureExtraction, self).__init__(parent)

        # Full, unfiltered feature dict; set via setDefaultFeatures().
        # Stays None until then (see _filterFeaturesByDim).
        self._default_features = None

        # internal operators
        self._objectExtraction = OpObjectExtraction(parent=self)
        self._opDivFeats = OpCachedDivisionFeatures(parent=self)
        self._opDivFeatsAdaptOutput = OpAdaptTimeListRoi(parent=self)

        # connect internal operators
        self._objectExtraction.RawImage.connect(self.RawImage)
        self._objectExtraction.BinaryImage.connect(self.BinaryImage)
        self._objectExtraction.BypassModeEnabled.connect(
            self.BypassModeEnabled)
        self._objectExtraction.Features.connect(self.FeatureNamesVigra)
        self._objectExtraction.RegionFeaturesCacheInput.connect(
            self.RegionFeaturesCacheInputVigra)
        self._objectExtraction.LabelImageCacheInput.connect(
            self.LabelImageCacheInput)
        self.CleanLabelBlocks.connect(self._objectExtraction.CleanLabelBlocks)
        self.RegionFeaturesCleanBlocksVigra.connect(
            self._objectExtraction.RegionFeaturesCleanBlocks)
        self.ObjectCenterImage.connect(
            self._objectExtraction.ObjectCenterImage)
        self.LabelImage.connect(self._objectExtraction.LabelImage)
        self.BlockwiseRegionFeaturesVigra.connect(
            self._objectExtraction.BlockwiseRegionFeatures)
        self.RegionFeaturesVigra.connect(self._objectExtraction.RegionFeatures)

        self._opDivFeats.LabelImage.connect(self.LabelImage)
        self._opDivFeats.DivisionFeatureNames.connect(
            self.FeatureNamesDivision)
        self._opDivFeats.CacheInput.connect(
            self.RegionFeaturesCacheInputDivision)
        self._opDivFeats.RegionFeaturesVigra.connect(
            self._objectExtraction.BlockwiseRegionFeatures)
        self.RegionFeaturesCleanBlocksDivision.connect(
            self._opDivFeats.CleanBlocks)
        self.BlockwiseRegionFeaturesDivision.connect(self._opDivFeats.Output)

        self._opDivFeatsAdaptOutput.Input.connect(self._opDivFeats.Output)
        self.RegionFeaturesDivision.connect(self._opDivFeatsAdaptOutput.Output)

        # As soon as input data is available, check its constraints
        self.RawImage.notifyReady(self._checkConstraints)
        self.BinaryImage.notifyReady(self._checkConstraints)

        # FIXME this shouldn't be done in post-filtering, but in reading the config or around that time
        self.RawImage.notifyReady(self._filterFeaturesByDim)

    def setDefaultFeatures(self, feats):
        """Store the complete feature dict that _filterFeaturesByDim filters from."""
        self._default_features = feats

    def setupOutputs(self, *args, **kwargs):
        """Copy slot metadata to the outputs that are computed in execute()."""
        self.ComputedFeatureNamesAll.meta.assignFrom(
            self.FeatureNamesVigra.meta)
        self.ComputedFeatureNamesNoDivisions.meta.assignFrom(
            self.FeatureNamesVigra.meta)
        self.RegionFeaturesAll.meta.assignFrom(self.RegionFeaturesVigra.meta)

    def execute(self, slot, subindex, roi, result):
        """
        Compute the merged outputs.

        - ComputedFeatureNamesAll: union of the vigra and division feature-name dicts.
        - ComputedFeatureNamesNoDivisions: copy of the vigra feature-name dict.
        - RegionFeaturesAll: for every key of the requested features, the union
          of the division and vigra feature dicts.

        Plugin names of the two sources must be disjoint (asserted).
        """
        if slot == self.ComputedFeatureNamesAll:
            feat_names_vigra = self.FeatureNamesVigra([]).wait()
            feat_names_div = self.FeatureNamesDivision([]).wait()
            for plugin_name in list(feat_names_vigra.keys()):
                assert plugin_name not in feat_names_div, "feature name dictionaries must be mutually exclusive"
            for plugin_name in list(feat_names_div.keys()):
                assert plugin_name not in feat_names_vigra, "feature name dictionaries must be mutually exclusive"
            result = dict(
                list(feat_names_vigra.items()) + list(feat_names_div.items()))
            return result
        elif slot == self.ComputedFeatureNamesNoDivisions:
            feat_names_vigra = self.FeatureNamesVigra([]).wait()
            result = dict(list(feat_names_vigra.items()))
            return result
        elif slot == self.RegionFeaturesAll:
            feat_vigra = self.RegionFeaturesVigra(roi).wait()
            feat_div = self.RegionFeaturesDivision(roi).wait()
            assert np.all(list(feat_vigra.keys()) == list(feat_div.keys()))
            result = {}
            for t in list(feat_vigra.keys()):
                for plugin_name in list(feat_vigra[t].keys()):
                    assert plugin_name not in feat_div[
                        t], "feature dictionaries must be mutually exclusive"
                for plugin_name in list(feat_div[t].keys()):
                    assert plugin_name not in feat_vigra[
                        t], "feature dictionaries must be mutually exclusive"
                result[t] = dict(
                    list(feat_div[t].items()) + list(feat_vigra[t].items()))
            return result
        else:
            assert False, "Shouldn't get here."

    def propagateDirty(self, slot, subindex, roi):
        """Mark the computed feature-name outputs dirty when either feature-name input changes."""
        if slot == self.BypassModeEnabled:
            pass
        elif slot == self.FeatureNamesVigra or slot == self.FeatureNamesDivision:
            self.ComputedFeatureNamesAll.setDirty(roi)
            self.ComputedFeatureNamesNoDivisions.setDirty(roi)

    def setInSlot(self, slot, subindex, roi, value):
        """Only validates the slot; the cache-input slots are connected to internal operators."""
        assert slot == self.RegionFeaturesCacheInputVigra or \
            slot == self.RegionFeaturesCacheInputDivision or \
            slot == self.LabelImageCacheInput, "Invalid slot for setInSlot(): {}".format(slot.name)

    def _checkConstraints(self, *args):
        """
        Validate the input images once they are ready.

        Raises DatasetConstraintError if raw and binary image differ in any
        dimension other than channel.
        """
        # NOTE(review): in the two time-dimension checks below the message is
        # built but nothing is raised or logged. Confirm whether a
        # DatasetConstraintError was intended here.
        if self.RawImage.ready():
            rawTaggedShape = self.RawImage.meta.getTaggedShape()
            if 't' not in rawTaggedShape or rawTaggedShape['t'] < 2:
                msg = "Raw image must have a time dimension with at least 2 images.\n"\
                    "Your dataset has shape: {}".format(self.RawImage.meta.shape)

        if self.BinaryImage.ready():
            rawTaggedShape = self.BinaryImage.meta.getTaggedShape()
            if 't' not in rawTaggedShape or rawTaggedShape['t'] < 2:
                msg = "Binary image must have a time dimension with at least 2 images.\n"\
                    "Your dataset has shape: {}".format(self.BinaryImage.meta.shape)

        if self.RawImage.ready() and self.BinaryImage.ready():
            rawTaggedShape = self.RawImage.meta.getTaggedShape()
            binTaggedShape = self.BinaryImage.meta.getTaggedShape()
            # Channel counts are allowed to differ, so neutralize them before comparing.
            rawTaggedShape['c'] = None
            binTaggedShape['c'] = None
            if dict(rawTaggedShape) != dict(binTaggedShape):
                logger.info("Raw data and other data must have equal dimensions (different channels are okay).\n"\
                      "Your datasets have shapes: {} and {}".format( self.RawImage.meta.shape, self.BinaryImage.meta.shape ))

                msg = "Raw data and other data must have equal dimensions (different channels are okay).\n"\
                      "Your datasets have shapes: {} and {}".format( self.RawImage.meta.shape, self.BinaryImage.meta.shape )
                raise DatasetConstraintError("Object Extraction", msg)

    def _filterFeaturesByDim(self, *args):
        """
        Restrict FeatureNamesVigra to the plugins that match the data dimensionality.

        Plugins with "2D" in their name are dropped for data with z > 1;
        plugins with "3D" in their name are dropped otherwise. Does nothing
        until setDefaultFeatures() has provided the full feature dict.
        """
        # Features look as follows:
        # dict[plugin_name][feature_name][parameter_name] = parameter_value
        # for example {"Standard Object Features": {"Mean in neighborhood":{"margin": (5, 5, 2)}}}

        # FIXME: this is a hacky solution because we overwrite the INPUT slot FeatureNamesVigra depending on the data shape.
        # We store the _default_features separately because if the user switches from a 3D to a 2D dataset the value of
        # FeatureNamesVigra would not be populated with all the features again, but only those that work for 3D.
        if not (self.RawImage.ready() and self.FeatureNamesVigra.ready()):
            return

        default_features = self._default_features
        if default_features is None:
            # This callback can fire before setDefaultFeatures() was called;
            # there is nothing to filter from yet.
            return

        rawTaggedShape = self.RawImage.meta.getTaggedShape()
        # Data without a z axis is treated like data with a single z-slice.
        z_extent = rawTaggedShape['z'] if 'z' in rawTaggedShape else 1
        # The dimension-specific plugins helpfully carry "2D"/"3D" in their plugin name.
        unwanted_tag = "2D" if z_extent > 1 else "3D"

        filtered_features_dict = {}
        for plugin in list(default_features.keys()):
            if unwanted_tag not in plugin:
                filtered_features_dict[plugin] = default_features[plugin]
        self.FeatureNamesVigra.setValue(filtered_features_dict)
class OpInterpMissingData(Operator):
    """
    Combine an OpDetectMissing and an OpInterpolate operator.

    The detector's output is exposed as `Missing` and fed to the interpolator;
    `Output` is assembled from the interpolator's output, requested over a
    z-extended roi (see _extendRoi).
    """
    name = "OpInterpMissingData"

    InputVolume = InputSlot()
    InputSearchDepth = InputSlot(value=3)
    PatchSize = InputSlot(value=128)
    HaloSize = InputSlot(value=30)
    DetectionMethod = InputSlot(value="svm")
    InterpolationMethod = InputSlot(value="cubic")

    # be careful when using the following: setting the same thing twice will not trigger
    # the action you desire, even if something else has changed
    OverloadDetector = InputSlot(value="")

    Output = OutputSlot()
    Missing = OutputSlot()
    Detector = OutputSlot(stype=Opaque)

    # Number of clean z-slices required on each side, per interpolation method
    _requiredMargin = {"cubic": 2, "linear": 1, "constant": 0}

    _dirty = False

    def __init__(self, *args, **kwargs):
        """Create the detector and interpolator and connect them to this operator's slots."""
        super(OpInterpMissingData, self).__init__(*args, **kwargs)

        self.detector = OpDetectMissing(parent=self)
        self.interpolator = OpInterpolate(parent=self)

        self.detector.InputVolume.connect(self.InputVolume)
        self.detector.PatchSize.connect(self.PatchSize)
        self.detector.HaloSize.connect(self.HaloSize)
        self.detector.DetectionMethod.connect(self.DetectionMethod)
        self.detector.OverloadDetector.connect(self.OverloadDetector)

        self.interpolator.InputVolume.connect(self.InputVolume)
        self.interpolator.Missing.connect(self.detector.Output)
        self.interpolator.InterpolationMethod.connect(self.InterpolationMethod)

        self.Missing.connect(self.detector.Output)
        self.Detector.connect(self.detector.Detector)

    def isDirty(self):
        """Return True if OverloadDetector, PatchSize or HaloSize changed since the last resetDirty()."""
        return self._dirty

    def resetDirty(self):
        """Clear the flag reported by isDirty()."""
        self._dirty = False

    def setupOutputs(self):
        """Configure output metadata."""
        # Output has the same shape/axes/dtype/drange as input
        self.Output.meta.assignFrom(self.InputVolume.meta)

        self.Detector.meta.shape = (1, )

    def execute(self, slot, subindex, roi, result):
        """
        Fill `result` with interpolated data, one (c, t) block at a time.

        For each channel/time index in the roi, the roi is narrowed to that
        single block, extended along z as reported by _extendRoi(), requested
        from the interpolator, and the originally requested z-range is copied
        into `result`.

        `roi` is modified in place while working, but is always restored to
        its original start/stop before returning, including when an exception
        is raised.
        """
        method = self.InterpolationMethod.value
        assert method in list(self._requiredMargin.keys()), \
            "Unknown interpolation method {}".format(method)

        z_index = self.InputVolume.meta.axistags.index("z")
        c_index = self.InputVolume.meta.axistags.index("c")
        t_index = self.InputVolume.meta.axistags.index("t")
        # nz = self.InputVolume.meta.getTaggedShape()['z']

        resultZYXCT = vigra.taggedView(
            result, self.InputVolume.meta.axistags).withAxes(*"zyxct")

        # backup ROI
        oldStart = np.copy(roi.start)
        oldStop = np.copy(roi.stop)

        # An axis index beyond the roi length means the input has no such axis.
        has_c = c_index < len(roi.start)
        has_t = t_index < len(roi.start)

        if has_c:
            cRange = np.arange(roi.start[c_index], roi.stop[c_index])
        else:
            cRange = np.array([0])

        if has_t:
            tRange = np.arange(roi.start[t_index], roi.stop[t_index])
        else:
            tRange = np.array([0])

        try:
            # c/t are absolute input coordinates. c_out/t_out are positions
            # inside `result`, which only covers the requested roi.
            for c_out, c in enumerate(cRange):
                for t_out, t in enumerate(tRange):
                    # change roi to single block
                    if has_c:
                        roi.start[c_index] = c
                        roi.stop[c_index] = c + 1
                    if has_t:
                        roi.start[t_index] = t
                        roi.stop[t_index] = t + 1

                    # check if more input is needed, and how many
                    z_offsets = self._extendRoi(roi)

                    # get extended interpolation
                    roi.start[z_index] -= z_offsets[0]
                    roi.stop[z_index] += z_offsets[1]

                    a = self.interpolator.Output.get(roi).wait()

                    # reduce to original roi
                    roi.stop = roi.stop - roi.start
                    roi.start *= 0
                    roi.start[z_index] += z_offsets[0]
                    roi.stop[z_index] -= z_offsets[1]
                    key = roiToSlice(roi.start, roi.stop)

                    resultZYXCT[..., c_out, t_out] = vigra.taggedView(
                        a[key],
                        self.InputVolume.meta.axistags).withAxes(*"zyx")

                    # restore ROI, will be used in other methods!!!
                    roi.start = np.copy(oldStart)
                    roi.stop = np.copy(oldStop)
        finally:
            # Never hand a half-modified roi back to the caller, even on error.
            roi.start = np.copy(oldStart)
            roi.stop = np.copy(oldStop)

        return result

    def propagateDirty(self, slot, subindex, roi):
        """Forward input dirtiness to Output and record detector-parameter changes in the dirty flag."""
        if slot == self.InputVolume:
            self.Output.setDirty(roi)

        if slot == self.OverloadDetector:
            self._dirty = True

        if slot == self.PatchSize or slot == self.HaloSize:
            self._dirty = True

    def train(self, force=False):
        """Delegate to the detector's train()."""
        return self.detector.train(force=force)

    def _extendRoi(self, roi):
        """
        Determine how far the roi must be extended along z.

        Counts the clean slices (detector output all zero) at the top and
        bottom of the roi. While fewer than the margin required by the current
        interpolation method are available, it looks at additional slices
        outside the roi, limited by InputSearchDepth and the volume bounds.

        `roi` is modified temporarily and restored before the final return.

        Returns:
            (offset_top, offset_bot): number of additional z-slices needed
            before roi.start and after roi.stop.
        """
        origStart = np.copy(roi.start)
        origStop = np.copy(roi.stop)

        offset_top = 0
        offset_bot = 0

        z_index = self.InputVolume.meta.axistags.index("z")

        depth = self.InputSearchDepth.value
        nRequestedSlices = roi.stop[z_index] - roi.start[z_index]
        nNeededSlices = self._requiredMargin[self.InterpolationMethod.value]

        missing = vigra.taggedView(
            self.detector.Output.get(roi).wait(),
            axistags=self.InputVolume.meta.axistags).withAxes(*"zyx")

        nGoodSlicesTop = 0
        # go inside the roi
        for k in range(nRequestedSlices):
            if np.all(missing[k, ...] == 0):  # clean slice
                nGoodSlicesTop += 1
            else:
                break

        # are we finished yet?
        if nGoodSlicesTop >= nRequestedSlices:
            return (0, 0)

        # looks like we need more slices on top
        while nGoodSlicesTop < nNeededSlices and offset_top < depth and roi.start[
                z_index] > 0:
            # move the roi to the single slice directly above the current one
            roi.stop[z_index] = roi.start[z_index]
            roi.start[z_index] -= 1
            offset_top += 1
            topmissing = self.detector.Output.get(roi).wait()
            if np.all(topmissing == 0):  # clean slice
                nGoodSlicesTop += 1
            else:  # need to start again
                nGoodSlicesTop = 0

        nGoodSlicesBot = 0
        # go inside the roi
        for k in range(1, nRequestedSlices + 1):
            if np.all(missing[-k, ...] == 0):  # clean slice
                nGoodSlicesBot += 1
            else:
                break

        roi.start = np.copy(origStart)
        roi.stop = np.copy(origStop)

        # looks like we need more slices on bottom
        while (roi.stop[z_index] < self.InputVolume.meta.getTaggedShape()["z"]
               and nGoodSlicesBot < nNeededSlices and offset_bot < depth):
            # move the roi to the single slice directly below the current one
            roi.start[z_index] = roi.stop[z_index]
            roi.stop[z_index] += 1
            offset_bot += 1
            botmissing = self.detector.Output.get(roi).wait()
            if np.all(botmissing == 0):  # clean slice
                nGoodSlicesBot += 1
            else:  # need to start again
                nGoodSlicesBot = 0

        roi.start = np.copy(origStart)
        roi.stop = np.copy(origStop)

        return (offset_top, offset_bot)
class OpArrayCache(Operator, ManagedCache):
    """ Allocates a block of memory as large as Input.meta.shape (==Output.meta.shape)
        with the same dtype in order to be able to cache results.

        blockShape: dirty regions are tracked with a granularity of blockShape
    """

    name = "ArrayCache"
    description = "numpy.ndarray caching class"
    category = "misc"

    DefaultBlockSize = 64

    #Input
    Input = InputSlot(allow_mask=True)
    blockShape = InputSlot(value = DefaultBlockSize)
    fixAtCurrent = InputSlot(value = False)

    #Output
    CleanBlocks = OutputSlot()
    Output = OutputSlot(allow_mask=True)

    loggingName = __name__ + ".OpArrayCache"
    logger = logging.getLogger(loggingName)
    traceLogger = logging.getLogger("TRACE." + loggingName)

    # Block states
    IN_PROCESS = 0
    DIRTY = 1
    CLEAN = 2
    FIXED_DIRTY = 3

    def __init__(self, *args, **kwargs):
        """Initialize the (still unallocated) cache state and register with the memory manager."""
        super( OpArrayCache, self ).__init__(*args, **kwargs)
        self._origBlockShape = self.DefaultBlockSize
        self._blockShape = None
        self._dirtyShape = None
        self._blockState = None
        self._dirtyState = None
        self._fixed = False
        self._cache = None
        self._lock = Lock()
        self._lazyAlloc = True
        self._cacheHits = 0
        self._has_fixed_dirty_blocks = False
        self._running = 0

        # Now that we're initialized, it's safe to register with the memory manager
        self.registerWithMemoryManager()

    # ========== CACHE API ==========

    def usedMemory(self):
        """Return the number of bytes currently held by the cache array."""
        return self._usedMemory(self._cache)

    def fractionOfUsedMemoryDirty(self):
        """Return the fraction of Output pixels that lie in DIRTY or FIXED_DIRTY blocks."""
        if self.Output.meta.shape is None:
            return 0
        totAll = numpy.prod(self.Output.meta.shape)
        totDirty = 0
        if self._blockState is None:
            return 0
        it = numpy.nditer(self._blockState, flags=['multi_index'])
        while not it.finished:
            v = it[0]
            sh = self._blockShapeForIndex(it.multi_index)
            it.iternext()
            if sh is None:
                # no cache allocated, so this block holds no memory
                continue
            if v == self.DIRTY or v == self.FIXED_DIRTY:
                totDirty += numpy.prod(sh)
        return totDirty/float(totAll)

    def lastAccessTime(self):
        return super(OpArrayCache, self).lastAccessTime()

    def freeMemory(self):
        """Try to release the whole cache; return the number of bytes freed."""
        return self._freeMemory()

    def freeDirtyMemory(self):
        """Release the cache only if every block is dirty; return the number of bytes freed."""
        if self.fractionOfUsedMemoryDirty() >= 1.0:
            # we can only free if all is dirty without touching non-dirty areas
            return self.freeMemory()
        else:
            return 0

    def generateReport(self, memInfoNode):
        super(OpArrayCache, self).generateReport(memInfoNode)
        if self._cache is not None:
            if hasattr(self._cache, "dtype"):
                memInfoNode.dtype = self._cache.dtype
            else:
                # cache is no array, so we cannot determine the dtype
                pass

    # ========== END CACHE API ==========

    @staticmethod
    def _usedMemory(item):
        """
        Recursively sum the memory of `item`.

        ndarrays contribute their nbytes (object arrays: the sum over their
        elements), dicts the sum over their values, anything else 0.
        """
        s = 0
        if isinstance(item, numpy.ndarray):
            # `object` is what the removed alias numpy.object referred to
            if item.dtype == object:
                for x in item.ravel():
                    s += OpArrayCache._usedMemory(x)
            else:
                s = item.nbytes
        elif isinstance(item, dict):
            for key in item.keys():
                try:
                    obj = item[key]
                except KeyError:
                    # cleaned up
                    pass
                else:
                    s += OpArrayCache._usedMemory(obj)
        return s

    def _blockShapeForIndex(self, index):
        """Return the shape of the block at block-index `index`, clipped to the cache; None without a cache."""
        if self._cache is None:
            return None
        index = numpy.asarray(index)
        cacheShape = numpy.array(self._cache.shape)
        blockShape = numpy.array(self._blockShape)
        start = blockShape * index
        stop = blockShape * (index + 1)
        start = numpy.maximum(start, (0,)*len(start))
        stop = numpy.minimum(stop, cacheShape)
        ret = tuple(map(int, stop - start))
        return ret

    def _freeMemory(self, refcheck = True):
        """
        Release the cache array and mark all blocks DIRTY.

        Nothing is released while any block is IN_PROCESS or when the array
        still has view references (and `refcheck` is True).

        Returns:
            Number of bytes actually freed (0 if nothing was released).
        """
        with self._lock:
            if self._cache is None:
                return 0
            if (self._blockState == OpArrayCache.IN_PROCESS).any():
                # Requests are writing into the cache right now, so nothing
                # is released. Report that honestly instead of claiming the
                # whole cache was freed.
                return 0
            if self._cache.shape == ():
                return 0

            freed = self.usedMemory()
            fshape = self._cache.shape
            try:
                self._cache.resize((), refcheck = refcheck)
            except ValueError:
                freed = 0
                self.logger.debug("OpArrayCache (name={}): freeing failed due to view references".format(self.name))
            if freed > 0:
                self.logger.debug("OpArrayCache: freed cache of shape:{}".format(fshape))

                self._blockState[:] = OpArrayCache.DIRTY
                del self._cache
                self._cache = None
            return freed

    def _get_full_blockshape(self, input_blockshape):
        """Broadcast a scalar block size to all axes and clip the block shape to the input shape."""
        max_shape = self.Input.meta.shape
        # NOTE(review): collections.Iterable was removed in Python 3.10
        # (it lives in collections.abc); switch once the file's imports allow it.
        if not isinstance(input_blockshape, collections.Iterable):
            # Broadcast as a tuple
            blockshape = (input_blockshape,)*len(max_shape)
        else:
            blockshape = tuple( input_blockshape )
        blockshape = numpy.minimum(blockshape, max_shape)
        return tuple(blockshape)

    def _allocateManagementStructures(self):
        """(Re)create the per-block state and query arrays; every block starts out DIRTY."""
        shape = self.Output.meta.shape
        self._blockShape = self._get_full_blockshape(self._origBlockShape)
        # `int` is what the removed alias numpy.int referred to
        self._dirtyShape = numpy.ceil(1.0 * numpy.array(shape) / numpy.array(self._blockShape)).astype(int)

        self.logger.debug("Configured OpArrayCache with shape={}, blockShape={}, dirtyShape={}, origBlockShape={}".format(shape, self._blockShape, self._dirtyShape, self._origBlockShape))

        #if a request has been submitted to get a block, the request object
        #is stored within this array
        self._blockQuery = numpy.ndarray(self._dirtyShape, dtype=object)

        #keep track of the dirty state of each block
        self._blockState = OpArrayCache.DIRTY * numpy.ones(self._dirtyShape, numpy.uint8)

        self._blockState[:]= OpArrayCache.DIRTY
        self._dirtyState = OpArrayCache.CLEAN

    def _allocateCache(self):
        """Allocate a zero-filled cache array if none with the current Output shape exists."""
        self._last_access_time = 0
        self._cache_priority = 0
        self._running = 0

        if self._cache is None or (self._cache.shape != self.Output.meta.shape):
            mem = self.Output.stype.allocateDestination(None)
            mem[:] = 0
            self.logger.debug("OpArrayCache: Allocating cache (size: %dbytes)" % mem.nbytes)
            if self._blockState is None:
                self._allocateManagementStructures()
            self._cache = mem

    def setupOutputs(self):
        """Configure output metadata and (re)allocate the block bookkeeping when the block shape changed."""
        self.CleanBlocks.meta.shape = (1,)
        self.CleanBlocks.meta.dtype = object

        reconfigure = False
        if self.inputs["fixAtCurrent"].ready():
            self._fixed = self.inputs["fixAtCurrent"].value

        if self.inputs["blockShape"].ready() and self.inputs["Input"].ready():
            newBShape = self.inputs["blockShape"].value
            assert numpy.issubdtype(type(newBShape), numpy.integer) or all( map(lambda x: numpy.issubdtype(type(x), numpy.integer), newBShape) )
            if self._origBlockShape != newBShape and self.inputs["Input"].ready():
                reconfigure = True
            self._origBlockShape = newBShape
            self._blockShape = newBShape

            inputSlot = self.inputs["Input"]
            self.Output.meta.assignFrom(inputSlot.meta)
            if isinstance(self._blockShape, collections.Iterable) and \
               len(self._blockShape) != len(self.Input.meta.shape):
                # block shape and input dimensionality disagree: not usable yet
                self.Output.meta.NOTREADY = True
                self.CleanBlocks.meta.NOTREADY = True
                return

            # Estimate ram usage
            ram_per_pixel = 0
            if self.Output.meta.dtype == object or self.Output.meta.dtype == numpy.object_:
                ram_per_pixel = sys.getsizeof(None)
            elif numpy.issubdtype(self.Output.meta.dtype, numpy.dtype):
                ram_per_pixel = self.Output.meta.dtype().nbytes

            tagged_shape = self.Output.meta.getTaggedShape()
            if 'c' in tagged_shape:
                ram_per_pixel *= float(tagged_shape['c'])

            if self.Output.meta.ram_usage_per_requested_pixel is not None:
                ram_per_pixel = max( ram_per_pixel, self.Output.meta.ram_usage_per_requested_pixel )

            self.Output.meta.ram_usage_per_requested_pixel = ram_per_pixel

        shape = self.Output.meta.shape
        if (self._dirtyShape is None or reconfigure) and shape is not None:
            with self._lock:
                self._allocateManagementStructures()
                if not self._lazyAlloc:
                    self._allocateCache()

        self.Output.meta.ideal_blockshape = self._get_full_blockshape(self._origBlockShape)

    def propagateDirty(self, slot, subindex, roi):
        """
        Update block states for a dirty input region or a change of fixAtCurrent.

        While fixed, dirty input blocks are recorded as FIXED_DIRTY and not
        propagated. When the cache becomes unfixed, they are turned into DIRTY
        and announced downstream as one bounding region.
        """
        shape = self.Input.meta.shape
        key = roi.toSlice()

        if slot == self.inputs["Input"]:
            start, stop = sliceToRoi(key, shape)

            with self._lock:
                if self._blockState is not None:
                    blockStart = numpy.floor(1.0 * start / self._blockShape)
                    blockStop = numpy.ceil(1.0 * stop / self._blockShape)
                    blockKey = roiToSlice(blockStart,blockStop)
                    if self._fixed:
                        # Remember that this block became dirty while we were fixed
                        #  so we can notify downstream operators when we become unfixed.
                        self._blockState[blockKey] = OpArrayCache.FIXED_DIRTY
                        self._has_fixed_dirty_blocks = True
                    else:
                        self._blockState[blockKey] = OpArrayCache.DIRTY

            if not self._fixed:
                self.outputs["Output"].setDirty(key)
        if slot == self.inputs["fixAtCurrent"]:
            if self.inputs["fixAtCurrent"].ready():
                self._fixed = self.inputs["fixAtCurrent"].value
                if not self._fixed and self.Output.meta.shape is not None and self._has_fixed_dirty_blocks:
                    # We've become unfixed, so we need to notify downstream
                    #  operators of every block that became dirty while we were fixed.
                    # Convert all FIXED_DIRTY states into DIRTY states
                    with self._lock:
                        cond = (self._blockState[...] == OpArrayCache.FIXED_DIRTY)
                        self._blockState[...] = fastWhere(cond, OpArrayCache.DIRTY, self._blockState, numpy.uint8)
                        self._has_fixed_dirty_blocks = False
                    newDirtyBlocks = numpy.transpose(numpy.nonzero(cond))

                    # To avoid lots of setDirty notifications, we simply merge all the dirtyblocks into one single superblock.
                    # This should be the best option in most cases, but could be bad in some cases.
                    # TODO: Optimize this by merging the dirty blocks via connected components or something.
                    cacheShape = numpy.array(self.Output.meta.shape)
                    dirtyStart = cacheShape
                    dirtyStop = [0] * len(cacheShape)
                    for index in newDirtyBlocks:
                        blockStart = index * self._blockShape
                        blockStop = numpy.minimum(blockStart + self._blockShape, cacheShape)

                        dirtyStart = numpy.minimum(dirtyStart, blockStart)
                        dirtyStop = numpy.maximum(dirtyStop, blockStop)

                    if len(newDirtyBlocks) > 0:
                        self.Output.setDirty( dirtyStart, dirtyStop )

    def _updatePriority(self, new_access = None):
        """Update the last-access time and the cache priority derived from it."""
        if self._last_access_time is None:
            self._last_access_time = new_access or time.time()
        cur_time = time.time()
        delta = cur_time - self._last_access_time + 1e-9

        self._last_access_time = cur_time
        new_prio = 0.5 * self._cache_priority + delta
        self._cache_priority = new_prio

    def execute(self, slot, subindex, roi, result):
        """Dispatch to the handler for the requested output slot."""
        if slot == self.Output:
            return self._executeOutput(slot, subindex, roi, result)
        elif slot == self.CleanBlocks:
            return self._executeCleanBlocks(slot, subindex, roi, result)

    def _executeOutput(self, slot, subindex, roi, result):
        """
        Serve `roi` from the cache, first requesting any DIRTY blocks it touches.

        Dirty blocks are requested from Input directly into the cache (unless
        the cache is fixed, in which case they are marked FIXED_DIRTY and the
        current cache content is served). Blocks already IN_PROCESS in another
        request are waited for. The data is written into `result`.
        """
        t = time.time()
        key = roi.toSlice()

        shape = self.Output.meta.shape
        start, stop = sliceToRoi(key, shape)

        with self._lock:
            ch = self._cacheHits
            ch += 1
            self._cacheHits = ch

            self._running += 1

            if (self._cache is None or
                    self._cache.shape != self.Output.meta.shape):
                self._allocateCache()

            cacheView = self._cache[:] #prevent freeing of cache during running this function

            blockStart = (1.0 * start / self._blockShape).floor()
            blockStop = (1.0 * stop / self._blockShape).ceil()
            blockKey = roiToSlice(blockStart,blockStop)

            blockSet = self._blockState[blockKey]

            # this is a little optimization to shortcut
            # many lines of python code when all data is
            # is already in the cache:
            if numpy.logical_or(blockSet == OpArrayCache.CLEAN, blockSet == OpArrayCache.FIXED_DIRTY).all():
                cache_result = self._cache[roiToSlice(start, stop)]
                self.Output.stype.copy_data(result, cache_result)

                self._running -= 1
                self._updatePriority()
                cacheView = None
                return

            extracted = numpy.extract( blockSet == OpArrayCache.IN_PROCESS, self._blockQuery[blockKey])
            inProcessQueries = numpy.unique(extracted)

            cond = (blockSet == OpArrayCache.DIRTY)
            tileWeights = fastWhere(cond, 1, 128**3, numpy.uint32)
            trueDirtyIndices = numpy.nonzero(cond)

            if has_drtile:
                tileArray = drtile.test_DRTILE(tileWeights, 128**3).swapaxes(0,1)
            else:
                # one tile per dirty block: rows are [start coords..., stop coords...]
                tileStartArray = numpy.array(trueDirtyIndices)
                tileStopArray = 1 + tileStartArray
                tileArray = numpy.concatenate((tileStartArray, tileStopArray), axis=0)

            dirtyRois = []
            half = tileArray.shape[0]//2
            dirtyPool = RequestPool()

            for i in range(tileArray.shape[1]):

                drStart3 = tileArray[:half,i]
                drStop3 = tileArray[half:,i]
                drStart2 = drStart3 + blockStart
                drStop2 = drStop3 + blockStart
                drStart = drStart2*self._blockShape
                drStop = drStop2*self._blockShape

                shape = self.Output.meta.shape
                drStop = numpy.minimum(drStop, shape)
                drStart = numpy.minimum(drStart, shape)

                key2 = roiToSlice(drStart2,drStop2)

                key = roiToSlice(drStart,drStop)

                if not self._fixed:
                    dirtyRois.append([drStart,drStop])

                    req = self.inputs["Input"][key].writeInto(self._cache[key])

                    req.uncancellable = True #FIXME

                    dirtyPool.add(req)

                    self._blockQuery[key2] = weakref.ref(req)

                    #sanity check:
                    if (self._blockState[key2] != OpArrayCache.DIRTY).any():
                        logger.warning( "original condition" + str(cond) )
                        logger.warning( "original tilearray {} {}".format( tileArray, tileArray.shape ) )
                        logger.warning( "original tileWeights {} {}".format( tileWeights, tileWeights.shape ) )
                        logger.warning( "sub condition {}".format( self._blockState[key2] == OpArrayCache.DIRTY ) )
                        logger.warning( "START={}, STOP={}".format( drStart2, drStop2 ) )
                        import h5py
                        with h5py.File("test.h5", "w") as f:
                            f.create_dataset("data",data = tileWeights)
                            logger.warning( "%r \n %r \n %r\n %r\n %r \n%r" % (key2, blockKey,self._blockState[key2], self._blockState[blockKey][trueDirtyIndices],self._blockState[blockKey],tileWeights) )
                        assert False
                    self._blockState[key2] = OpArrayCache.IN_PROCESS

            # indicate the inprocessing state, by setting array to 0 (i.e. IN_PROCESS)
            if not self._fixed:
                blockSet[:] = fastWhere(cond, OpArrayCache.IN_PROCESS, blockSet, numpy.uint8)
            else:
                # Someone asked for some dirty blocks while we were fixed.
                # Mark these blocks to be signaled as dirty when we become unfixed
                blockSet[:] = fastWhere(cond, OpArrayCache.FIXED_DIRTY, blockSet, numpy.uint8)
                self._has_fixed_dirty_blocks = True

        #wait for all requests to finish
        something_updated = len( dirtyPool ) > 0
        dirtyPool.wait()
        if something_updated:
            # Signal that something was updated.
            # Note that we don't need to do this for the 'in process' queries (below)
            #  because they are already in the dirtyPool in some other thread
            self.Output._sig_value_changed()

        # indicate the finished inprocess state (i.e. CLEAN)
        if not self._fixed:
            with self._lock:
                blockSet[:] = fastWhere(cond, OpArrayCache.CLEAN, blockSet, numpy.uint8)
                self._blockQuery[blockKey] = fastWhere(cond, None, self._blockQuery[blockKey], object)

        # Wait for all in-process queries.
        # Can't use RequestPool here because these requests have already started.
        for req in inProcessQueries:
            req = req() # get original req object from weakref
            if req is not None:
                req.wait()

        # finally, store results in result area
        with self._lock:
            if self._cache is not None:
                cache_result = self._cache[roiToSlice(start, stop)]
                self.Output.stype.copy_data(result, cache_result)
            else:
                # cache was freed in the meantime: fetch straight from the input
                self.inputs["Input"][roiToSlice(start, stop)].writeInto(result).wait()
            self._running -= 1
            self._updatePriority()
            cacheView = None
        self.logger.debug("read %s took %f sec." % (roi.pprint(), time.time()-t))

    def setInSlot(self, slot, subindex, roi, value):
        """
        Write `value` directly into the cache for the blocks fully covered by `roi`.

        Only happens if at least one of those blocks is not CLEAN yet; the
        blocks are then set to the state stored in _dirtyState.
        """
        assert slot == self.inputs["Input"]
        ch = self._cacheHits
        ch += 1
        self._cacheHits = ch
        start, stop = roi.start, roi.stop
        # only blocks lying completely inside the roi (hence ceil/floor) ...
        blockStart = numpy.ceil(1.0 * start / self._blockShape)
        blockStop = numpy.floor(1.0 * stop / self._blockShape)
        # ... except that a roi reaching the volume border also covers the last (partial) block
        blockStop = numpy.where(stop == self.Output.meta.shape, self._dirtyShape, blockStop)
        blockKey = roiToSlice(blockStart,blockStop)

        if (self._blockState[blockKey] != OpArrayCache.CLEAN).any():
            start2 = blockStart * self._blockShape
            stop2 = blockStop * self._blockShape
            stop2 = numpy.minimum(stop2, self.Output.meta.shape)
            key2 = roiToSlice(start2,stop2)
            with self._lock:
                if self._cache is None:
                    self._allocateCache()
                self.Output.stype.copy_data( self._cache[key2], value[roiToSlice(start2-start,stop2-start)] )
                self._blockState[blockKey] = self._dirtyState
                self._blockQuery[blockKey] = None

    def _executeCleanBlocks(self, slot, subindex, roi, destination):
        """Store in destination[0] a list with one list of TinyVector bounds per CLEAN block."""
        indexCols = numpy.where(self._blockState == OpArrayCache.CLEAN)
        clean_block_starts = numpy.array(indexCols).transpose()
        clean_block_starts *= self._blockShape

        inputShape = self.Input.meta.shape
        # Build real lists (not lazy map objects) so the result is the same on Python 2 and 3.
        clean_block_rois = [ getBlockBounds( inputShape, self._blockShape, block_start )
                             for block_start in clean_block_starts ]
        destination[0] = [ [ TinyVector(bound) for bound in block_roi ]
                           for block_roi in clean_block_rois ]
        return destination
class OpVigraWatershedViewer(Operator):
    """
    Top-level operator that wires together the watershed preview pipeline.

    The channels of ``InputImage`` selected by ``InputChannelIndexes`` are
    averaged.  The average is thresholded, labeled and size-filtered to obtain
    seeds, and it is also the input of the watershed operator.  Seeds and
    watershed labels are both cached and colorized for display.

    Every output slot is connected to an internal operator, so this class
    defines no ``execute()`` of its own.
    """
    name = "OpWatershedViewer"
    category = "top-level"

    RawImage = InputSlot(optional=True)  # Displayed in the GUI (not used in pipeline)
    InputImage = InputSlot()  # The image to be sliced and watershedded

    FreezeCache = InputSlot(value=True)  # opWatershedCache

    InputChannelIndexes = InputSlot(value=[0])
    WatershedPadding = InputSlot(value=10)
    OverrideLabels = InputSlot(value={0: (0, 0, 0, 0)})
    SeedThresholdValue = InputSlot(value=0.0)
    MinSeedSize = InputSlot(value=0)
    # opWatershedCache block shapes. Expected: tuple of (width, depth) viewing shape
    CacheBlockShape = InputSlot(value=(256, 10))

    Seeds = OutputSlot()  # For batch export
    WatershedLabels = OutputSlot()  # Watershed labeled output
    SummedInput = OutputSlot()  # Watershed input (for gui display)
    ColoredPixels = OutputSlot()  # Colored watershed labels (for gui display)
    ColoredSeeds = OutputSlot()  # Seeds to the watershed (for gui display)

    SelectedInputChannels = OutputSlot(level=1)

    def __init__(self, *args, **kwargs):
        """Create the internal operators and connect the static part of the graph."""
        super(OpVigraWatershedViewer, self).__init__(*args, **kwargs)
        # Last threshold value applied in setupOutputs(); None until the first configuration.
        self._seedThreshold = None

        # Overview Schematic
        # Example here uses input channels 0,2,5
        #
        # InputChannelIndexes=[0,2,5] ----
        #                                 \
        # InputImage --> opChannelSlicer .Slices[0] ---\
        #                                .Slices[1] ----> opAverage ---------------------------------------> opWatershed --> opWatershedCache --> opColorizer --> GUI
        #                                .Slices[2] ---/           \                                        /
        #                                                           \     MinSeedSize                      /
        #                                                            \               \                    /
        #                     SeedThresholdValue ----------> opThreshold --> opSeedLabeler --> opSeedFilter --> opSeedCache --> opSeedColorizer --> GUI

        # Create operators
        self.opChannelSlicer = OpMultiArraySlicer2(parent=self)
        self.opAverage = OpMultiArrayMerger(parent=self)
        self.opWatershed = OpVigraWatershed(parent=self)
        self.opWatershedCache = OpSlicedBlockedArrayCache(parent=self)
        self.opColorizer = OpColorizeLabels(parent=self)

        self.opThreshold = OpPixelOperator(parent=self)
        self.opSeedLabeler = OpVigraLabelVolume(parent=self)
        self.opSeedFilter = OpFilterLabels(parent=self)
        self.opSeedCache = OpSlicedBlockedArrayCache(parent=self)
        self.opSeedColorizer = OpColorizeLabels(parent=self)

        # Select specific input channels
        self.opChannelSlicer.Input.connect(self.InputImage)
        self.opChannelSlicer.SliceIndexes.connect(self.InputChannelIndexes)
        self.opChannelSlicer.AxisFlag.setValue('c')

        # Average selected channels
        def average(arrays):
            # An empty selection averages to 0 instead of dividing by zero.
            if len(arrays) == 0:
                return 0
            else:
                return sum(arrays) / float(len(arrays))

        self.opAverage.MergingFunction.setValue(average)
        self.opAverage.Inputs.connect(self.opChannelSlicer.Slices)

        # Threshold for seeds
        self.opThreshold.Input.connect(self.opAverage.Output)

        # Label seeds
        self.opSeedLabeler.Input.connect(self.opThreshold.Output)

        # Filter seeds
        self.opSeedFilter.MinLabelSize.connect(self.MinSeedSize)
        self.opSeedFilter.Input.connect(self.opSeedLabeler.Output)

        # Cache seeds
        self.opSeedCache.fixAtCurrent.connect(self.FreezeCache)
        self.opSeedCache.Input.connect(self.opSeedFilter.Output)

        # Color seeds for RGB display
        self.opSeedColorizer.Input.connect(self.opSeedCache.Output)
        self.opSeedColorizer.OverrideColors.setValue({0: (0, 0, 0, 0)})

        # Compute watershed labels (possibly with seeds, see setupOutputs)
        self.opWatershed.InputImage.connect(self.opAverage.Output)
        self.opWatershed.PaddingWidth.connect(self.WatershedPadding)

        # Cache the watershed output
        self.opWatershedCache.fixAtCurrent.connect(self.FreezeCache)
        self.opWatershedCache.Input.connect(self.opWatershed.Output)

        # Colorize the watershed labels for RGB display
        self.opColorizer.Input.connect(self.opWatershedCache.Output)
        self.opColorizer.OverrideColors.connect(self.OverrideLabels)

        # Connect external outputs to the operators that provide them
        self.Seeds.connect(self.opSeedCache.Output)
        self.ColoredPixels.connect(self.opColorizer.Output)
        self.SelectedInputChannels.connect(self.opChannelSlicer.Slices)
        self.SummedInput.connect(self.opAverage.Output)
        self.ColoredSeeds.connect(self.opSeedColorizer.Output)

    def setupOutputs(self):
        """
        Configure the cache block shapes and the seed threshold.

        The block shapes of both sliced caches are derived from
        ``CacheBlockShape`` and the axis order of ``InputImage``.  The seed
        branch is (re)connected to the watershed when ``SeedThresholdValue``
        is ready, and disconnected otherwise.
        """
        # User has control over cache block shape
        # Width and depth are applied to x,y, or z depending on which slicing view is being used.
        width, depth = self.CacheBlockShape.value

        ## Cache blocks
        # Inner and outer block shapes are the same.
        # We're using this cache for the "sliced" property, not the "blocked" property.
        blockDimsX = {
            't': (1, 1),
            'z': (width, width),
            'y': (width, width),
            'x': (depth, depth),
            'c': (1, 1),
        }

        blockDimsY = {
            't': (1, 1),
            'z': (width, width),
            'y': (depth, depth),
            'x': (width, width),
            'c': (1, 1),
        }

        blockDimsZ = {
            't': (1, 1),
            'z': (depth, depth),
            'y': (width, width),
            'x': (width, width),
            'c': (1, 1),
        }

        # Set the blockshapes for each input image separately, depending on which axistags it has.
        axisOrder = [tag.key for tag in self.InputImage.meta.axistags]

        innerBlockShapeX = tuple(blockDimsX[k][0] for k in axisOrder)
        outerBlockShapeX = tuple(blockDimsX[k][1] for k in axisOrder)

        innerBlockShapeY = tuple(blockDimsY[k][0] for k in axisOrder)
        outerBlockShapeY = tuple(blockDimsY[k][1] for k in axisOrder)

        innerBlockShapeZ = tuple(blockDimsZ[k][0] for k in axisOrder)
        outerBlockShapeZ = tuple(blockDimsZ[k][1] for k in axisOrder)

        self.opWatershedCache.innerBlockShape.setValue(
            (innerBlockShapeX, innerBlockShapeY, innerBlockShapeZ))
        self.opWatershedCache.outerBlockShape.setValue(
            (outerBlockShapeX, outerBlockShapeY, outerBlockShapeZ))

        # Seed cache has same shape as watershed cache
        self.opSeedCache.innerBlockShape.setValue(
            (innerBlockShapeX, innerBlockShapeY, innerBlockShapeZ))
        self.opSeedCache.outerBlockShape.setValue(
            (outerBlockShapeX, outerBlockShapeY, outerBlockShapeZ))

        # For now watershed labels always come from the X-Y slicing view.
        # That view is subslot 2, so make sure that index really exists before using it.
        if len(self.opWatershedCache.InnerOutputs) > 2:
            self.WatershedLabels.connect(self.opWatershedCache.InnerOutputs[2])

        if self.SeedThresholdValue.ready():
            seedThreshold = self.SeedThresholdValue.value
            # Only touch the seed branch if it is not connected yet or the threshold changed.
            if not self.opWatershed.SeedImage.connected() or seedThreshold != self._seedThreshold:
                self._seedThreshold = seedThreshold

                self.opThreshold.Function.setValue(
                    lambda a: (a <= seedThreshold).astype(numpy.uint8))
                self.opWatershed.SeedImage.connect(self.opSeedFilter.Output)
        else:
            # No threshold available: run the watershed without seeds.
            self.opWatershed.SeedImage.disconnect()
            self.opThreshold.Function.disconnect()

    def propagateDirty(self, slot, subindex, roi):
        # All outputs are directly connected to internal operators
        pass
class OpInputDataReader(Operator):
    """
    This operator can read input data of any supported type.
    The data format is determined from the file extension.

    ``setupOutputs()`` tries a list of ``_attemptOpenAs*`` methods in order.
    Each of them returns ``(operators, output_slot)``; ``([], None)`` means
    "this is not my format, try the next one".
    """

    name = "OpInputDataReader"
    category = "Input"

    videoExts = ["ufmf", "mmf"]
    h5_n5_Exts = ["h5", "hdf5", "ilp", "n5"]
    # n5 stores data in a directory, containing a json-file which we use to select the n5-file
    n5Selection = ["json"]
    klbExts = ["klb"]
    npyExts = ["npy"]
    npzExts = ["npz"]
    rawExts = ["dat", "bin", "raw"]
    blockwiseExts = ["json"]
    tiledExts = ["json"]
    tiffExts = ["tif", "tiff"]
    vigraImpexExts = vigra.impex.listExtensions().split()

    SupportedExtensions = (h5_n5_Exts + n5Selection + npyExts + npzExts +
                           rawExts + vigraImpexExts + blockwiseExts +
                           videoExts + klbExts)

    if _supports_dvid:
        dvidExts = ["dvidvol"]
        SupportedExtensions += dvidExts

    if _supports_h5blockreader:
        h5blockstoreExts = ["json"]
        SupportedExtensions += h5blockstoreExts

    # FilePath is inspected to determine data type.
    # For hdf5 files, append the internal path to the filepath,
    #  e.g. /mydir/myfile.h5/internal/path/to/dataset
    # For stacks, provide a globstring, e.g. /mydir/input*.png
    # Other types are determined via file extension
    WorkingDirectory = InputSlot(stype="filestring", optional=True)
    FilePath = InputSlot(stype="filestring")
    SequenceAxis = InputSlot(optional=True)

    # FIXME: Document this.
    SubVolumeRoi = InputSlot(optional=True)  # (start, stop)

    Output = OutputSlot()

    loggingName = __name__ + ".OpInputDataReader"
    logger = logging.getLogger(loggingName)

    class DatasetReadError(Exception):
        """Raised when the requested dataset cannot be opened or read."""
        pass

    def __init__(self, *args, **kwargs):
        super(OpInputDataReader, self).__init__(*args, **kwargs)
        self.internalOperators = []  # Operators created by the successful opener
        self.internalOutput = None  # Output slot of the internal reader chain
        self._file = None  # Open h5/n5 file handle, if any (see _attemptOpenAsH5N5)

    def cleanUp(self):
        """Clean up the operator and close the h5/n5 file handle, if one is open."""
        super(OpInputDataReader, self).cleanUp()
        if self._file is not None:
            self._file.close()
            self._file = None

    def setupOutputs(self):
        """
        Inspect the file name and instantiate and connect an internal operator of the appropriate type.
        TODO: Handle datasets of non-standard (non-5d) dimensions.

        Raises RuntimeError if no opener recognizes the file.  Individual
        openers may raise DatasetReadError.
        """
        filePath = self.FilePath.value
        # This class uses Python 3-only syntax, where 'str' is the only text type.
        assert isinstance(
            filePath,
            str), "Error: filePath is not of type str. It's of type {}".format(
                type(filePath))

        # Does this look like a relative path?
        useRelativePath = not isUrl(filePath) and not os.path.isabs(filePath)

        if useRelativePath:
            # If using a relative path, we need both inputs before proceeding
            if not self.WorkingDirectory.ready():
                return
            else:
                # Convert this relative path into an absolute path
                filePath = os.path.normpath(
                    os.path.join(self.WorkingDirectory.value,
                                 filePath)).replace("\\", "/")

        # Clean up before reconfiguring
        if self.internalOperators:
            self.Output.disconnect()
            self.opInjector.cleanUp()
            for op in self.internalOperators[::-1]:
                op.cleanUp()
            self.internalOperators = []
            self.internalOutput = None
        if self._file is not None:
            self._file.close()
            # Forget the closed handle so that it is not closed a second time later.
            self._file = None

        openFuncs = [
            self._attemptOpenAsKlb,
            self._attemptOpenAsUfmf,
            self._attemptOpenAsMmf,
            self._attemptOpenAsRESTfulPrecomputedChunkedVolume,
            self._attemptOpenAsDvidVolume,
            self._attemptOpenAsH5N5Stack,
            self._attemptOpenAsTiffStack,
            self._attemptOpenAsStack,
            self._attemptOpenAsH5N5,
            self._attemptOpenAsNpy,
            self._attemptOpenAsRawBinary,
            self._attemptOpenAsTiledVolume,
            self._attemptOpenAsH5BlockStore,
            self._attemptOpenAsBlockwiseFileset,
            self._attemptOpenAsRESTfulBlockwiseFileset,
            self._attemptOpenAsBigTiff,
            self._attemptOpenAsTiff,
            self._attemptOpenWithVigraImpex,
        ]

        # Try every method of opening the file until one works.
        for openFunc in openFuncs:
            self.internalOperators, self.internalOutput = openFunc(filePath)
            if self.internalOperators:
                break

        if self.internalOutput is None:
            raise RuntimeError("Can't read " + filePath +
                               " because it has an unrecognized format.")

        # If we've got a ROI, append a subregion operator.
        if self.SubVolumeRoi.ready():
            self._opSubRegion = OpSubRegion(parent=self)
            self._opSubRegion.Roi.setValue(self.SubVolumeRoi.value)
            self._opSubRegion.Input.connect(self.internalOutput)
            self.internalOutput = self._opSubRegion.Output

        self.opInjector = OpMetadataInjector(parent=self)
        self.opInjector.Input.connect(self.internalOutput)

        # Add metadata for estimated RAM usage if the internal operator didn't already provide it.
        # NOTE(review): the key 'ram_per_pixelram_usage_per_requested_pixel' looks like the
        # result of a botched search-and-replace.  It is kept unchanged here because the
        # consumers of this metadata are not visible from this class -- confirm the intended name.
        if self.internalOutput.meta.ram_per_pixelram_usage_per_requested_pixel is None:
            ram_per_pixel = self.internalOutput.meta.dtype().nbytes
            if "c" in self.internalOutput.meta.getTaggedShape():
                ram_per_pixel *= self.internalOutput.meta.getTaggedShape()["c"]
            self.opInjector.Metadata.setValue(
                {"ram_per_pixelram_usage_per_requested_pixel": ram_per_pixel})
        else:
            # Nothing to add
            self.opInjector.Metadata.setValue({})

        # Directly connect our own output to the internal output
        self.Output.connect(self.opInjector.Output)

    def _attemptOpenAsKlb(self, filePath):
        """Open files with a '.klb' extension (case-insensitive)."""
        if not os.path.splitext(filePath)[1].lower() == ".klb":
            return ([], None)

        opReader = OpKlbReader(parent=self)
        opReader.FilePath.setValue(filePath)
        # Same (operators, output) shape as every other opener.  setupOutputs() iterates
        # over the first element when cleaning up, so it must be a list of operators.
        return ([opReader], opReader.Output)

    def _attemptOpenAsMmf(self, filePath):
        """Open the path with the streaming mmf reader if it contains '.mmf'."""
        if ".mmf" in filePath:
            mmfReader = OpStreamingMmfReader(parent=self)
            mmfReader.FileName.setValue(filePath)

            return ([mmfReader], mmfReader.Output)

            # Disabled: cache the frames we read
            # frameShape = mmfReader.Output.meta.ideal_blockshape
            # mmfCache = OpBlockedArrayCache( parent=self )
            # mmfCache.fixAtCurrent.setValue( False )
            # mmfCache.BlockShape.setValue( frameShape )
            # mmfCache.Input.connect( mmfReader.Output )
            # return ([mmfReader, mmfCache], mmfCache.Output)
        else:
            return ([], None)

    def _attemptOpenAsUfmf(self, filePath):
        """Open the path with the streaming ufmf reader if it contains '.ufmf'."""
        if ".ufmf" in filePath:
            ufmfReader = OpStreamingUfmfReader(parent=self)
            ufmfReader.FileName.setValue(filePath)

            return ([ufmfReader], ufmfReader.Output)

            # Disabled: cache the frames we read
            # frameShape = ufmfReader.Output.meta.ideal_blockshape
            # ufmfCache = OpBlockedArrayCache( parent=self )
            # ufmfCache.fixAtCurrent.setValue( False )
            # ufmfCache.BlockShape.setValue( frameShape )
            # ufmfCache.Input.connect( ufmfReader.Output )
            # return ([ufmfReader, ufmfCache], ufmfCache.Output)
        else:
            return ([], None)

    def _attemptOpenAsRESTfulPrecomputedChunkedVolume(self, filePath):
        """Open urls that start with 'precomputed://' (case-insensitive)."""
        prefix = "precomputed://"
        if not filePath.lower().startswith(prefix):
            return ([], None)
        else:
            # Cut off exactly the prefix.  str.lstrip() would strip a *set of characters*
            # and could also eat the beginning of the actual url.
            url = filePath[len(prefix):]
            reader = OpRESTfulPrecomputedChunkedVolumeReader(parent=self)
            reader.BaseUrl.setValue(url)
            return [reader], reader.Output

    def _attemptOpenAsH5N5Stack(self, filePath):
        """
        Open a glob string (or path list) as a stack of h5/n5 datasets,
        either within a single file or spread over multiple files.
        """
        if not ("*" in filePath or os.path.pathsep in filePath):
            return ([], None)

        # Now use the .checkGlobString method of the stack readers
        isSingleFile = True
        try:
            OpStreamingH5N5SequenceReaderS.checkGlobString(filePath)
        except OpStreamingH5N5SequenceReaderS.WrongFileTypeError:
            return ([], None)
        except (
                OpStreamingH5N5SequenceReaderS.NoInternalPlaceholderError,
                OpStreamingH5N5SequenceReaderS.NotTheSameFileError,
                OpStreamingH5N5SequenceReaderS.ExternalPlaceholderError,
        ):
            isSingleFile = False

        isMultiFile = True
        try:
            OpStreamingH5N5SequenceReaderM.checkGlobString(filePath)
        except (
                OpStreamingH5N5SequenceReaderM.NoExternalPlaceholderError,
                OpStreamingH5N5SequenceReaderM.SameFileError,
                OpStreamingH5N5SequenceReaderM.InternalPlaceholderError,
        ):
            isMultiFile = False

        assert not (isMultiFile and isSingleFile)

        if isSingleFile is True:
            opReader = OpStreamingH5N5SequenceReaderS(parent=self)
        elif isMultiFile is True:
            opReader = OpStreamingH5N5SequenceReaderM(parent=self)
        else:
            # Neither reader accepts this glob string.  Without this branch 'opReader'
            # would be unbound below; let the remaining openers have a try instead.
            return ([], None)

        try:
            opReader.SequenceAxis.connect(self.SequenceAxis)
            opReader.GlobString.setValue(filePath)
            return ([opReader], opReader.OutputImage)
        except (OpStreamingH5N5SequenceReaderM.WrongFileTypeError,
                OpStreamingH5N5SequenceReaderS.WrongFileTypeError):
            return ([], None)

    def _attemptOpenAsTiffStack(self, filePath):
        """Open a glob string (or path list) as a sequence of tiff files."""
        if not ("*" in filePath or os.path.pathsep in filePath):
            return ([], None)

        try:
            opReader = OpTiffSequenceReader(parent=self)
            opReader.SequenceAxis.connect(self.SequenceAxis)
            opReader.GlobString.setValue(filePath)
            return ([opReader], opReader.Output)
        except OpTiffSequenceReader.WrongFileTypeError:
            return ([], None)

    def _attemptOpenAsStack(self, filePath):
        """Open a glob string (or path list) with the generic stack loader."""
        if "*" in filePath or os.path.pathsep in filePath:
            stackReader = OpStackLoader(parent=self)
            stackReader.SequenceAxis.connect(self.SequenceAxis)
            stackReader.globstring.setValue(filePath)
            return ([stackReader], stackReader.stack)
        else:
            return ([], None)

    def _attemptOpenAsH5N5(self, filePath):
        """
        Open a dataset inside an hdf5 or n5 file.

        If the path has no internal part and the file contains exactly one
        dataset, that dataset is used.  The opened file is remembered in
        ``self._file`` so that it can be closed later.

        Raises DatasetReadError if the file is missing, cannot be opened, or
        the internal path is missing/ambiguous/unreadable.
        """
        # Check for an hdf5 or n5 extension
        pathComponents = PathComponents(filePath)
        ext = pathComponents.extension
        if ext[1:] not in OpInputDataReader.h5_n5_Exts:
            return [], None

        externalPath = pathComponents.externalPath
        internalPath = pathComponents.internalPath

        if not os.path.exists(externalPath):
            raise OpInputDataReader.DatasetReadError(
                "Input file does not exist: " + externalPath)

        # Open the h5/n5 file in read-only mode
        try:
            h5N5File = OpStreamingH5N5Reader.get_h5_n5_file(externalPath, "r")
        except OpInputDataReader.DatasetReadError:
            raise
        except Exception as e:
            msg = "Unable to open H5/N5 File: {}\n{}".format(
                externalPath, str(e))
            raise OpInputDataReader.DatasetReadError(msg) from e
        else:
            if not internalPath:
                possible_internal_paths = lsH5N5(h5N5File)
                if len(possible_internal_paths) == 1:
                    internalPath = possible_internal_paths[0]["name"]
                elif len(possible_internal_paths) == 0:
                    h5N5File.close()
                    msg = "H5/N5 file contains no datasets: {}".format(
                        externalPath)
                    raise OpInputDataReader.DatasetReadError(msg)
                else:
                    h5N5File.close()
                    msg = (
                        "When using hdf5/n5, you must append the hdf5 internal path to the "
                        "data set to your filename, e.g. myfile.h5/volume/data "
                        "No internal path provided for dataset in file: {}".
                        format(externalPath))
                    raise OpInputDataReader.DatasetReadError(msg)
            try:
                compression_setting = h5N5File[internalPath].compression
            except Exception as e:
                h5N5File.close()
                msg = "Error reading H5/N5 File: {}\n{}".format(
                    externalPath, e)
                raise OpInputDataReader.DatasetReadError(msg) from e

            # If the h5 dataset is compressed, we'll have better performance
            #  with a multi-process hdf5 access object.
            # (Otherwise, single-process is faster.)
            allow_multiprocess_hdf5 = (
                "LAZYFLOW_MULTIPROCESS_HDF5" in os.environ
                and os.environ["LAZYFLOW_MULTIPROCESS_HDF5"] != "")
            if compression_setting is not None and allow_multiprocess_hdf5 and isinstance(
                    h5N5File, h5py.File):
                h5N5File.close()
                h5N5File = MultiProcessHdf5File(externalPath, "r")

        self._file = h5N5File

        h5N5Reader = OpStreamingH5N5Reader(parent=self)
        h5N5Reader.H5N5File.setValue(h5N5File)

        try:
            h5N5Reader.InternalPath.setValue(internalPath)
        except OpStreamingH5N5Reader.DatasetReadError as e:
            msg = "Error reading H5/N5 File: {}\n{}".format(
                externalPath, e.msg)
            raise OpInputDataReader.DatasetReadError(msg) from e

        return ([h5N5Reader], h5N5Reader.OutputImage)

    def _attemptOpenAsNpy(self, filePath):
        """Open '.npy' / '.npz' files.  Raises DatasetReadError on failure."""
        pathComponents = PathComponents(filePath)
        ext = pathComponents.extension
        npyzExts = OpInputDataReader.npyExts + OpInputDataReader.npzExts
        if ext not in (".%s" % x for x in npyzExts):
            return ([], None)

        externalPath = pathComponents.externalPath
        internalPath = pathComponents.internalPath
        # FIXME: check whether path is valid?!

        if not os.path.exists(externalPath):
            raise OpInputDataReader.DatasetReadError(
                "Input file does not exist: " + externalPath)

        try:
            # Create an internal operator
            npyReader = OpNpyFileReader(parent=self)
            if internalPath is not None:
                internalPath = internalPath.replace("/", "")
            npyReader.InternalPath.setValue(internalPath)
            npyReader.FileName.setValue(externalPath)
            return ([npyReader], npyReader.Output)
        except OpNpyFileReader.DatasetReadError as e:
            raise OpInputDataReader.DatasetReadError(*e.args) from e

    def _attemptOpenAsRawBinary(self, filePath):
        """Open raw binary files ('dat', 'bin', 'raw').  Raises DatasetReadError on failure."""
        fileExtension = os.path.splitext(filePath)[1].lower()
        fileExtension = fileExtension.lstrip(".")  # Remove leading dot

        # Check for a raw binary extension
        if fileExtension not in OpInputDataReader.rawExts:
            return ([], None)
        else:
            try:
                # Create an internal operator
                opReader = OpRawBinaryFileReader(parent=self)
                opReader.FilePath.setValue(filePath)
                return ([opReader], opReader.Output)
            except OpRawBinaryFileReader.DatasetReadError as e:
                raise OpInputDataReader.DatasetReadError(*e.args) from e

    def _attemptOpenAsH5BlockStore(self, filePath):
        """Open a '.json' index file with the H5BlockStore reader."""
        if not os.path.splitext(filePath)[1] == ".json":
            return ([], None)

        op = OpH5BlockStoreReader(parent=self)

        # For now, there is no explicit schema validation for the json file,
        # but H5BlockStore constructor will fail to load the json.
        #
        # NOTE(review): this call used to be wrapped in a handler that just re-raised
        # (marked "DELME"), followed by an unreachable fallback
        # ('op.cleanUp(); return ([], None)').  The effective behavior -- any error
        # propagates to the caller -- is preserved here.  Confirm whether the
        # fallback should be restored so that later json openers get a chance.
        op.IndexFilepath.setValue(filePath)
        return [op], op.Output

    def _attemptOpenAsDvidVolume(self, filePath):
        """
        Two ways to specify a dvid volume.
        1) via a file that contains the hostname, uuid, and dataset name (1 per line)
        2) as a url, e.g. http://localhost:8000/api/node/uuid/dataname
        """
        if os.path.splitext(filePath)[1] == ".dvidvol":
            with open(filePath) as f:
                filetext = f.read()
            hostname, uuid, dataname = filetext.splitlines()
            opDvidVolume = OpDvidVolume(hostname, uuid, dataname, parent=self)
            return [opDvidVolume], opDvidVolume.Output

        if "://" not in filePath:
            return ([], None)  # not a url

        # Turn the url template into a regex with one named group per field.
        url_format = "^protocol://hostname/api/node/uuid/dataname(\\?query_string)?"
        for field in [
                "protocol", "hostname", "uuid", "dataname", "query_string"
        ]:
            url_format = url_format.replace(field, "(?P<" + field + ">[^?]+)")
        match = re.match(url_format, filePath)
        if not match:
            # DVID is the only url-based format we support right now.
            # So if it looks like the user gave a URL that isn't a valid DVID node, then error.
            raise OpInputDataReader.DatasetReadError(
                "Invalid URL format for DVID: {}".format(filePath))

        fields = match.groupdict()
        try:
            query_string = fields["query_string"]
            query_args = {}
            if query_string:
                query_args = dict(
                    [s.split("=") for s in query_string.split("&")])
            try:
                opDvidVolume = OpDvidVolume(fields["hostname"],
                                            fields["uuid"],
                                            fields["dataname"],
                                            query_args,
                                            parent=self)
                return [opDvidVolume], opDvidVolume.Output
            except Exception:
                # Maybe this is actually a roi
                opDvidRoi = OpDvidRoi(fields["hostname"],
                                      fields["uuid"],
                                      fields["dataname"],
                                      parent=self)
                return [opDvidRoi], opDvidRoi.Output
        except OpDvidVolume.DatasetReadError as e:
            raise OpInputDataReader.DatasetReadError(*e.args) from e

    def _attemptOpenAsBlockwiseFileset(self, filePath):
        """Open a json description of a blockwise fileset.  Raises DatasetReadError if a dataset is missing."""
        fileExtension = os.path.splitext(filePath)[1].lower()
        fileExtension = fileExtension.lstrip(".")  # Remove leading dot

        if fileExtension in OpInputDataReader.blockwiseExts:
            opReader = OpBlockwiseFilesetReader(parent=self)
            try:
                # This will raise a SchemaError if this is the wrong type of json config.
                opReader.DescriptionFilePath.setValue(filePath)
                return ([opReader], opReader.Output)
            except JsonConfigParser.SchemaError:
                opReader.cleanUp()
            except OpBlockwiseFilesetReader.MissingDatasetError as e:
                raise OpInputDataReader.DatasetReadError(*e.args) from e
        return ([], None)

    def _attemptOpenAsRESTfulBlockwiseFileset(self, filePath):
        """Open a json description of a RESTful blockwise fileset.  Raises DatasetReadError if a dataset is missing."""
        fileExtension = os.path.splitext(filePath)[1].lower()
        fileExtension = fileExtension.lstrip(".")  # Remove leading dot

        if fileExtension in OpInputDataReader.blockwiseExts:
            opReader = OpRESTfulBlockwiseFilesetReader(parent=self)
            try:
                # This will raise a SchemaError if this is the wrong type of json config.
                opReader.DescriptionFilePath.setValue(filePath)
                return ([opReader], opReader.Output)
            except JsonConfigParser.SchemaError:
                opReader.cleanUp()
            except OpRESTfulBlockwiseFilesetReader.MissingDatasetError as e:
                raise OpInputDataReader.DatasetReadError(*e.args) from e
        return ([], None)

    def _attemptOpenAsTiledVolume(self, filePath):
        """Open a json description of a tiled volume."""
        fileExtension = os.path.splitext(filePath)[1].lower()
        fileExtension = fileExtension.lstrip(".")  # Remove leading dot

        if fileExtension in OpInputDataReader.tiledExts:
            opReader = OpCachedTiledVolumeReader(parent=self)
            try:
                # This will raise a SchemaError if this is the wrong type of json config.
                opReader.DescriptionFilePath.setValue(filePath)
                return ([opReader], opReader.SpecifiedOutput)
            except JsonConfigParser.SchemaError:
                opReader.cleanUp()
        return ([], None)

    def _attemptOpenAsBigTiff(self, filePath):
        """Open a tiff file with the BigTiff reader, if BigTiff support is available."""
        if not _supports_bigtiff:
            return ([], None)

        fileExtension = os.path.splitext(filePath)[1].lower()
        fileExtension = fileExtension.lstrip(".")  # Remove leading dot

        if fileExtension not in OpInputDataReader.tiffExts:
            return ([], None)

        if not os.path.exists(filePath):
            raise OpInputDataReader.DatasetReadError(
                "Input file does not exist: " + filePath)

        opReader = OpBigTiffReader(parent=self)
        try:
            opReader.Filepath.setValue(filePath)
            return ([opReader], opReader.Output)
        except OpBigTiffReader.NotBigTiffError:
            opReader.cleanUp()
        return ([], None)

    def _attemptOpenAsTiff(self, filePath):
        """Open a tiff file and cache the pages that are read."""
        fileExtension = os.path.splitext(filePath)[1].lower()
        fileExtension = fileExtension.lstrip(".")  # Remove leading dot

        if fileExtension not in OpInputDataReader.tiffExts:
            return ([], None)

        if not os.path.exists(filePath):
            raise OpInputDataReader.DatasetReadError(
                "Input file does not exist: " + filePath)

        opReader = OpTiffReader(parent=self)
        opReader.Filepath.setValue(filePath)

        page_shape = opReader.Output.meta.ideal_blockshape

        # Cache the pages we read
        opCache = OpBlockedArrayCache(parent=self)
        opCache.fixAtCurrent.setValue(False)
        opCache.BlockShape.setValue(page_shape)
        opCache.Input.connect(opReader.Output)

        return ([opReader, opCache], opCache.Output)

    def _attemptOpenWithVigraImpex(self, filePath):
        """Open any image format listed by vigra.impex and cache the result."""
        fileExtension = os.path.splitext(filePath)[1].lower()
        fileExtension = fileExtension.lstrip(".")  # Remove leading dot

        if fileExtension not in OpInputDataReader.vigraImpexExts:
            return ([], None)

        if not os.path.exists(filePath):
            raise OpInputDataReader.DatasetReadError(
                "Input file does not exist: " + filePath)

        vigraReader = OpImageReader(parent=self)
        vigraReader.Filename.setValue(filePath)

        # Cache the image instead of reading the hard disk for every access.
        imageCache = OpBlockedArrayCache(parent=self)
        imageCache.Input.connect(vigraReader.Image)

        # 2D: Just one block for the whole image
        cacheBlockShape = vigraReader.Image.meta.shape

        taggedShape = vigraReader.Image.meta.getTaggedShape()
        if "z" in taggedShape:
            # 3D: blocksize is one slice.
            taggedShape["z"] = 1
            cacheBlockShape = tuple(taggedShape.values())

        imageCache.fixAtCurrent.setValue(False)
        imageCache.BlockShape.setValue(cacheBlockShape)
        assert imageCache.Output.ready()

        return ([vigraReader, imageCache], imageCache.Output)

    def execute(self, slot, subindex, roi, result):
        assert False, "Shouldn't get here because our output is directly connected..."

    def propagateDirty(self, slot, subindex, roi):
        # Output slots are directly connected to internal operators
        pass
class OpEdgeTraining(Operator):
    """
    Multi-lane operator for training an edge classifier on a superpixel RAG.

    For every lane a RAG is created from the superpixels and edge features
    are computed on it.  One classifier, shared across lanes, is trained from
    the edge labels of all lanes and used to predict per-lane edge
    probabilities and a naive segmentation.  The intermediate results are
    cached.
    """
    # Shared across lanes
    DEFAULT_FEATURES = {"Grayscale": ['standard_edge_mean']}
    FeatureNames = InputSlot(value=DEFAULT_FEATURES)
    FreezeClassifier = InputSlot(value=True)

    # Lane-wise
    EdgeLabelsDict = InputSlot(level=1, value={})
    VoxelData = InputSlot(level=1)
    Superpixels = InputSlot(level=1)
    GroundtruthSegmentation = InputSlot(level=1, optional=True)
    RawData = InputSlot(level=1, optional=True)  # Used by the GUI for display only

    Rag = OutputSlot(level=1)
    EdgeProbabilities = OutputSlot(level=1)
    EdgeProbabilitiesDict = OutputSlot(level=1)  # A dict of id_pair -> probabilities
    NaiveSegmentation = OutputSlot(level=1)

    def __init__(self, *args, **kwargs):
        """Create and connect the internal operators, and keep the lane-wise input slots in sync."""
        super(OpEdgeTraining, self).__init__(*args, **kwargs)

        self.opCreateRag = OpMultiLaneWrapper(OpCreateRag, parent=self)
        self.opCreateRag.Superpixels.connect(self.Superpixels)

        self.opRagCache = OpMultiLaneWrapper(
            OpValueCache, parent=self, broadcastingSlotNames=['fixAtCurrent'])
        self.opRagCache.Input.connect(self.opCreateRag.Rag)
        self.opRagCache.name = 'opRagCache'

        self.opComputeEdgeFeatures = OpMultiLaneWrapper(
            OpComputeEdgeFeatures,
            parent=self,
            broadcastingSlotNames=['FeatureNames'])
        self.opComputeEdgeFeatures.FeatureNames.connect(self.FeatureNames)
        self.opComputeEdgeFeatures.VoxelData.connect(self.VoxelData)
        self.opComputeEdgeFeatures.Rag.connect(self.opRagCache.Output)

        self.opEdgeFeaturesCache = OpMultiLaneWrapper(
            OpValueCache, parent=self, broadcastingSlotNames=['fixAtCurrent'])
        self.opEdgeFeaturesCache.Input.connect(
            self.opComputeEdgeFeatures.EdgeFeaturesDataFrame)
        self.opEdgeFeaturesCache.name = 'opEdgeFeaturesCache'

        # The trainer is NOT wrapped: a single classifier is trained from all lanes.
        self.opTrainEdgeClassifier = OpTrainEdgeClassifier(parent=self)
        self.opTrainEdgeClassifier.EdgeLabelsDict.connect(self.EdgeLabelsDict)
        self.opTrainEdgeClassifier.EdgeFeaturesDataFrame.connect(
            self.opEdgeFeaturesCache.Output)

        # classifier cache input is set after training.
        self.opClassifierCache = OpValueCache(parent=self)
        self.opClassifierCache.Input.connect(
            self.opTrainEdgeClassifier.EdgeClassifier)
        self.opClassifierCache.fixAtCurrent.connect(self.FreezeClassifier)
        self.opClassifierCache.name = 'opClassifierCache'

        self.opPredictEdgeProbabilities = OpMultiLaneWrapper(
            OpPredictEdgeProbabilities,
            parent=self,
            broadcastingSlotNames=['EdgeClassifier'])
        self.opPredictEdgeProbabilities.EdgeClassifier.connect(
            self.opClassifierCache.Output)
        self.opPredictEdgeProbabilities.EdgeFeaturesDataFrame.connect(
            self.opEdgeFeaturesCache.Output)

        self.opEdgeProbabilitiesCache = OpMultiLaneWrapper(
            OpValueCache, parent=self, broadcastingSlotNames=['fixAtCurrent'])
        self.opEdgeProbabilitiesCache.Input.connect(
            self.opPredictEdgeProbabilities.EdgeProbabilities)
        self.opEdgeProbabilitiesCache.name = 'opEdgeProbabilitiesCache'
        self.opEdgeProbabilitiesCache.fixAtCurrent.connect(
            self.FreezeClassifier)

        self.opEdgeProbabilitiesDict = OpMultiLaneWrapper(
            OpEdgeProbabilitiesDict, parent=self)
        self.opEdgeProbabilitiesDict.Rag.connect(self.opRagCache.Output)
        self.opEdgeProbabilitiesDict.EdgeProbabilities.connect(
            self.opEdgeProbabilitiesCache.Output)

        self.opEdgeProbabilitiesDictCache = OpMultiLaneWrapper(
            OpValueCache, parent=self, broadcastingSlotNames=['fixAtCurrent'])
        self.opEdgeProbabilitiesDictCache.Input.connect(
            self.opEdgeProbabilitiesDict.EdgeProbabilitiesDict)
        self.opEdgeProbabilitiesDictCache.name = 'opEdgeProbabilitiesDictCache'

        self.opNaiveSegmentation = OpMultiLaneWrapper(OpNaiveSegmentation,
                                                      parent=self)
        self.opNaiveSegmentation.Superpixels.connect(self.Superpixels)
        self.opNaiveSegmentation.Rag.connect(self.opRagCache.Output)
        self.opNaiveSegmentation.EdgeProbabilities.connect(
            self.opEdgeProbabilitiesCache.Output)

        self.opNaiveSegmentationCache = OpMultiLaneWrapper(
            OpBlockedArrayCache,
            parent=self,
            broadcastingSlotNames=[
                'CompressionEnabled', 'fixAtCurrent', 'BypassModeEnabled'
            ])
        self.opNaiveSegmentationCache.CompressionEnabled.setValue(True)
        self.opNaiveSegmentationCache.Input.connect(
            self.opNaiveSegmentation.Output)
        self.opNaiveSegmentationCache.name = 'opNaiveSegmentationCache'

        self.Rag.connect(self.opRagCache.Output)
        self.EdgeProbabilities.connect(self.opEdgeProbabilitiesCache.Output)
        self.EdgeProbabilitiesDict.connect(
            self.opEdgeProbabilitiesDictCache.Output)
        self.NaiveSegmentation.connect(self.opNaiveSegmentationCache.Output)

        # All input multi-slots should be kept in sync
        # Output multi-slots will auto-sync via the graph
        #
        # This must be a real list: it is iterated in two nested loops, and a
        # one-shot iterator (what filter() returns on Python 3) would already be
        # exhausted by the first pass of the inner loop.
        multiInputs = [s for s in self.inputs.values() if s.level >= 1]

        # 'a' is the slot to keep in sync (bound via partial); 'b' is the slot
        # that fired the signal and is not needed.
        def insertSlot(a, b, position, finalsize):
            a.insertSlot(position, finalsize)

        def removeSlot(a, b, position, finalsize):
            a.removeSlot(position, finalsize)

        for s1 in multiInputs:
            for s2 in multiInputs:
                if s1 != s2:
                    s1.notifyInserted(partial(insertSlot, s2))
                    s1.notifyRemoved(partial(removeSlot, s2))

        # If superpixels change, we have to delete our edge labels.
        # Since we're dealing with multi-lane slot, setting up dirty handlers is a two-stage process.
        # (1) React to lane insertion by subscribing to dirty signals for the new lane.
        # (2) React to each lane's dirty signal by deleting the labels for that lane.

        def subscribe_to_dirty_sp(slot, position, finalsize):
            # A new lane was added.  Subscribe to its dirty signal.
            assert slot is self.Superpixels
            self.Superpixels[position].notifyDirty(
                self.handle_dirty_superpixels)
            self.Superpixels[position].notifyReady(
                self.handle_dirty_superpixels)
            self.Superpixels[position].notifyUnready(
                self.handle_dirty_superpixels)

        # When a new lane is added, set up the listener for dirtyness.
        self.Superpixels.notifyInserted(subscribe_to_dirty_sp)

    def handle_dirty_superpixels(self, subslot, *args):
        """
        Discards the labels for a given lane.
        NOTE: In addition to callers in this file, this function is also called from multicutWorkflow.py
        """
        # Determine which lane triggered this and delete its labels
        lane_index = self.Superpixels.index(subslot)
        old_labels = self.EdgeLabelsDict[lane_index].value
        if old_labels:
            logger.warning(
                "Superpixels changed.  Deleting all labels in lane {}.".format(
                    lane_index))
            logger.info("Old labels were: {}".format(old_labels))
        self.EdgeLabelsDict[lane_index].setValue({})

    def setupOutputs(self):
        """Use the full superpixel shape of each lane as the block shape of its segmentation cache."""
        for sp_slot, seg_cache_blockshape_slot in zip(
                self.Superpixels,
                self.opNaiveSegmentationCache.outerBlockShape):
            assert sp_slot.meta.dtype == np.uint32
            assert sp_slot.meta.getAxisKeys()[-1] == 'c'
            seg_cache_blockshape_slot.setValue(sp_slot.meta.shape)

    def execute(self, slot, subindex, roi, result):
        assert False, "Shouldn't get here, but requesting slot: {}".format(
            slot)

    def propagateDirty(self, slot, subindex, roi):
        pass

    def setEdgeLabelsFromGroundtruth(self, lane_index):
        """
        For the given lane, read the ground truth volume and
        automatically determine edge label values.

        Raises RuntimeError if the lane has no ground truth data.
        """
        op_view = self.getLane(lane_index)

        if not op_view.GroundtruthSegmentation.ready():
            raise RuntimeError(
                "There is no Ground Truth data available for lane: {}".format(
                    lane_index))

        logger.info("Loading groundtruth for lane {}...".format(lane_index))
        gt_vol = op_view.GroundtruthSegmentation[:].wait()
        gt_vol = vigra.taggedView(
            gt_vol, op_view.GroundtruthSegmentation.meta.axistags)
        # Bring the ground truth into the axis order of the superpixels, without channel axis.
        gt_vol = gt_vol.withAxes(''.join(
            tag.key for tag in op_view.Superpixels.meta.axistags))
        gt_vol = gt_vol.dropChannelAxis()

        rag = op_view.opRagCache.Output.value

        logger.info("Computing edge decisions from groundtruth...")
        decisions = rag.edge_decisions_from_groundtruth(gt_vol, asdict=False)
        # presumably boolean decisions, reinterpreted as 0/1 and shifted to labels 1/2 -- verify against the rag API
        edge_labels = decisions.view(np.uint8) + 1
        edge_ids = map(tuple, rag.edge_ids)
        edge_labels_dict = dict(zip(edge_ids, edge_labels))

        op_view.EdgeLabelsDict.setValue(edge_labels_dict)

    def addLane(self, laneIndex):
        """Append a new lane.  The other lane-wise input slots follow via the sync handlers set up in __init__."""
        numLanes = len(self.VoxelData)
        assert numLanes == laneIndex, "Image lanes must be appended."
        self.VoxelData.resize(numLanes + 1)

    def removeLane(self, laneIndex, finalLength):
        """Remove the given lane from VoxelData."""
        self.VoxelData.removeSlot(laneIndex, finalLength)

    def getLane(self, laneIndex):
        """Return a single-lane view of this operator."""
        return OperatorSubView(self, laneIndex)
class OpTrainEdgeClassifier(Operator):
    """
    Train one edge classifier from the labeled edges of all lanes.

    Inputs (one subslot per lane):
        EdgeLabelsDict: dict mapping an edge id ``(sp1, sp2)`` to its label.
            Label 0 is treated as "unlabeled" and dropped.
        EdgeFeaturesDataFrame: DataFrame whose first two columns are
            'sp1' and 'sp2'; the remaining columns are used as features.

    Output:
        EdgeClassifier: the trained classifier, or None while no lane has
            any non-zero label.
    """
    EdgeLabelsDict = InputSlot(level=1)
    EdgeFeaturesDataFrame = InputSlot(level=1)
    EdgeClassifier = OutputSlot()

    def setupOutputs(self):
        """The output is a single Python object (the classifier)."""
        self.EdgeClassifier.meta.shape = (1, )
        self.EdgeClassifier.meta.dtype = object

    @staticmethod
    def _labels_to_dataframe(labels_dict):
        """
        Turn ``{(sp1, sp2): label}`` into a DataFrame with columns
        'sp1', 'sp2', 'label', omitting rows whose label is 0.
        """
        # list(...) so this works whether keys()/values() return lists or
        # dict views.  Both are taken from the same unmodified dict, so the
        # two sequences are in matching order.
        edge_ids = np.array(list(labels_dict.keys()))
        labels_df = pd.DataFrame(edge_ids, columns=['sp1', 'sp2'])
        labels_df['label'] = list(labels_dict.values())
        # Drop zero labels
        return labels_df[labels_df['label'] != 0]

    def execute(self, slot, subindex, roi, result):
        """
        Gather labeled edges and their features from every lane, train a
        classifier on them and store it in ``result[0]``.

        ``result[0]`` is set to None when no lane contributes any non-zero
        label.
        """
        lane_frames = []
        for lane_index, (labels_dict_slot, features_slot) in enumerate(
                zip(self.EdgeLabelsDict, self.EdgeFeaturesDataFrame)):
            logger.info(
                "Retrieving features for lane {}...".format(lane_index))
            # Copy now to avoid threading issues.
            labels_dict = labels_dict_slot.value.copy()
            if not labels_dict:
                continue

            edge_features_df = features_slot.value
            assert list(edge_features_df.columns[0:2]) == ['sp1', 'sp2']

            labels_df = self._labels_to_dataframe(labels_dict)
            if labels_df.empty:
                # Every label in this lane was 0: nothing to train on.
                continue

            # Merge in features.  how='right' keeps exactly the labeled
            # edges; 'label' ends up as the last column.
            lane_frames.append(
                pd.merge(edge_features_df,
                         labels_df,
                         how='right',
                         on=['sp1', 'sp2']))

        if not lane_frames:
            # No labels yet.
            result[0] = None
            return

        # A single concat replaces repeated DataFrame.append calls, which
        # copied the accumulated frame once per lane.
        all_features_and_labels_df = pd.concat(lane_frames)

        assert list(
            all_features_and_labels_df.columns[0:2]) == ['sp1', 'sp2']
        assert all_features_and_labels_df.columns[-1] == 'label'

        # Omit 'sp1', 'sp2', and 'label'
        feature_matrix = all_features_and_labels_df.iloc[:, 2:-1].values
        labels = all_features_and_labels_df.iloc[:, -1].values

        logger.info("Training classifier with {} labels...".format(
            len(labels)))
        # TODO: Allow factory to be configured via an input slot
        classifier_factory = ParallelVigraRfLazyflowClassifierFactory()
        classifier = classifier_factory.create_and_train(
            feature_matrix,
            labels,
            feature_names=all_features_and_labels_df.columns[2:-1].values)
        assert set(classifier.known_classes).issubset(set([1, 2]))
        result[0] = classifier

    def propagateDirty(self, slot, subindex, roi):
        """Any change to labels or features invalidates the classifier."""
        self.EdgeClassifier.setDirty()
class OpCarving(Operator):
    """
    Interactive carving operator.

    Holds the user's seeds (via an internal OpDenseLabelArray), forwards
    them to the MST object received on the ``MST`` slot, runs the
    segmentation whenever a parameter slot becomes dirty, and manages a
    collection of named, saved objects stored on the MST object.
    """
    name = "Carving"
    category = "interactive segmentation"

    # I n p u t s #

    #MST of preprocessed Graph
    MST = InputSlot()

    # These three slots are for display only.
    # All computation is done with the MST.
    # Display-only: Available to the GUI in case the input data was
    # preprocessed in some way but you still want to see the 'raw' data.
    OverlayData = InputSlot(optional=True)
    InputData = InputSlot()  # The data used by preprocessing (display only)
    FilteredInputData = InputSlot()  # The output of the preprocessing filter

    #write the seeds that the users draw into this slot
    WriteSeeds = InputSlot()

    #trigger an update by writing into this slot
    Trigger = InputSlot(value=numpy.zeros((1, ), dtype=numpy.uint8))

    #number between 0.0 and 1.0
    #bias of the background
    #FIXME: correct name?
    BackgroundPriority = InputSlot(value=0.95)

    LabelNames = OutputSlot(stype='list')

    #a number between 0 and 256
    #below the number, no background bias will be applied to the edge weights
    NoBiasBelow = InputSlot(value=64)

    # uncertainty type
    UncertaintyType = InputSlot()

    # O u t p u t s #

    #current object + background
    Segmentation = OutputSlot()

    Supervoxels = OutputSlot()

    Uncertainty = OutputSlot()

    #contains an array with where all objects done so far are labeled the same
    DoneObjects = OutputSlot()

    #contains an array with the object labels done so far, one label for each
    #object
    DoneSegmentation = OutputSlot()

    CurrentObjectName = OutputSlot(stype='string')

    AllObjectNames = OutputSlot(rtype=List, stype=Opaque)

    #current object has an actual segmentation
    HasSegmentation = OutputSlot(stype='bool')

    #Hint Overlay
    HintOverlay = OutputSlot()

    #Pmap Overlay
    PmapOverlay = OutputSlot()

    MstOut = OutputSlot()

    def __init__(self,
                 graph=None,
                 hintOverlayFile=None,
                 pmapOverlayFile=None,
                 parent=None):
        """
        graph, parent: forwarded to the Operator base class.
        hintOverlayFile: optional path of an HDF5 file whose 3D dataset
            "/hints" is served on the HintOverlay slot.
        pmapOverlayFile: optional path of an HDF5 file whose 3D dataset
            "/data" is served on the PmapOverlay slot.

        A failure to open the hint file is logged and re-raised; a failure
        to open the pmap file raises RuntimeError.
        """
        super(OpCarving, self).__init__(graph=graph, parent=parent)
        self.opLabelArray = OpDenseLabelArray(parent=self)
        self.opLabelArray.MetaInput.connect(self.InputData)

        self._hintOverlayFile = hintOverlayFile
        self._mst = None
        # keeps track of whether or not there are seeds currently loaded,
        # either drawn by the user or loaded from a saved object
        self.has_seeds = False

        self.LabelNames.setValue(["Background", "Object"])

        #supervoxels of finished and saved objects
        self._done_lut = None
        self._done_seg_lut = None
        self._hints = None
        self._pmap = None

        if hintOverlayFile is not None:
            try:
                f = h5py.File(hintOverlayFile, "r")
            except Exception:
                logger.info("Could not open hint overlay '%s'" %
                            hintOverlayFile)
                raise  # bare raise keeps the original traceback
            # Read everything into memory and close the file again.
            with f:
                self._hints = f["/hints"][()][numpy.newaxis, :, :, :,
                                              numpy.newaxis]

        if pmapOverlayFile is not None:
            try:
                f = h5py.File(pmapOverlayFile, "r")
            except Exception:
                raise RuntimeError("Could not open pmap overlay '%s'" %
                                   pmapOverlayFile)
            with f:
                self._pmap = f["/data"][()][numpy.newaxis, :, :, :,
                                            numpy.newaxis]

        self._setCurrObjectName("<not saved yet>")
        self.HasSegmentation.setValue(False)

        # keep track of a set of object names that have changed since
        # the last serialization of this object to disk
        self._dirtyObjects = set()
        self.preprocessingApplet = None

        self._opMstCache = OpValueCache(parent=self)
        self.MstOut.connect(self._opMstCache.Output)

        self.InputData.notifyReady(self._checkConstraints)

    def _checkConstraints(self, *args):
        """
        Validate InputData: exactly one channel, 5D, exactly one time
        slice, and spatial axes at positions 1..3.

        Raises DatasetConstraintError for problems the user can fix and
        RuntimeError for problems that indicate a programming error.
        """
        slot = self.InputData
        numChannels = slot.meta.getTaggedShape()['c']
        if numChannels != 1:
            raise DatasetConstraintError(
                "Carving", "Input image must have exactly one channel. " +
                "You attempted to add a dataset with {} channels".format(
                    numChannels))

        sh = slot.meta.shape
        ax = slot.meta.axistags
        if len(sh) != 5:
            # Raise a regular exception. This error is for developers, not users.
            raise RuntimeError("was expecting a 5D dataset, got shape=%r" %
                               (sh, ))
        numTimeSlices = slot.meta.getTaggedShape()['t']
        if numTimeSlices != 1:
            raise DatasetConstraintError(
                "Carving",
                "Input image must not have more than one time slice. " +
                "You attempted to add a dataset with {} time slices".format(
                    numTimeSlices))

        for i in range(1, 4):
            if not ax[i].isSpatial():
                # This is for developers. Don't need a user-friendly error.
                raise RuntimeError("%d-th axis %r is not spatial" %
                                   (i, ax[i]))

    def _clearLabels(self):
        """Delete all labels from the internal label array."""
        self.opLabelArray.DeleteLabel.setValue(2)
        self.opLabelArray.DeleteLabel.setValue(1)
        self.opLabelArray.DeleteLabel.setValue(-1)
        self.has_seeds = False

    def _setCurrObjectName(self, n):
        """
        Sets the current object name to n.
        """
        self._currObjectName = n
        self.CurrentObjectName.setValue(n)

    def _buildDone(self):
        """
        Builds the done segmentation anew, for example after saving an object or
        deleting an object.

        Does nothing while no MST is available.  The current object is
        excluded from both lookup tables.
        """
        if self._mst is None:
            return

        with Timer() as timer:
            self._done_lut = numpy.zeros(self._mst.numNodes + 1,
                                         dtype=numpy.int32)
            self._done_seg_lut = numpy.zeros(self._mst.numNodes + 1,
                                             dtype=numpy.int32)
            logger.info("building 'done' luts")
            for name, objectSupervoxels in self._mst.object_lut.iteritems():
                if name == self._currObjectName:
                    continue
                # _done_lut counts how many saved objects cover a supervoxel;
                # _done_seg_lut stores the number of the covering object.
                self._done_lut[objectSupervoxels] += 1
                assert name in self._mst.object_names, \
                    "%s not in self._mst.object_names, keys are %r" % (
                        name, self._mst.object_names.keys())
                self._done_seg_lut[objectSupervoxels] = \
                    self._mst.object_names[name]
        logger.info("building the 'done' luts took {} seconds".format(
            timer.seconds()))

    def dataIsStorable(self):
        """
        Return True only if an MST is present and the current node seeds
        contain at least one foreground (2) and one background (1) seed.
        """
        if self._mst is None:
            return False
        nodeSeeds = self._mst.gridSegmentor.getNodeSeeds()
        fg_seedNum = len(numpy.where(nodeSeeds == 2)[0])
        bg_seedNum = len(numpy.where(nodeSeeds == 1)[0])
        return fg_seedNum > 0 and bg_seedNum > 0

    def setupOutputs(self):
        """Derive the metadata of all image-like outputs from InputData."""
        self.Segmentation.meta.assignFrom(self.InputData.meta)
        self.Segmentation.meta.dtype = numpy.int32

        self.Supervoxels.meta.assignFrom(self.Segmentation.meta)
        self.DoneObjects.meta.assignFrom(self.Segmentation.meta)
        self.DoneSegmentation.meta.assignFrom(self.Segmentation.meta)

        self.HintOverlay.meta.assignFrom(self.InputData.meta)
        self.PmapOverlay.meta.assignFrom(self.InputData.meta)

        self.Uncertainty.meta.assignFrom(self.InputData.meta)
        self.Uncertainty.meta.dtype = numpy.uint8

        self.Trigger.meta.shape = (1, )
        self.Trigger.meta.dtype = numpy.uint8

        if self._mst is not None:
            self.AllObjectNames.meta.shape = (len(self._mst.object_names), )
        else:
            self.AllObjectNames.meta.shape = (0, )
        self.AllObjectNames.meta.dtype = object

    def connectToPreprocessingApplet(self, applet):
        """
        Remember the preprocessing applet.

        NOTE(review): this stores ``PreprocessingApplet`` (capital P) while
        __init__ initialises ``preprocessingApplet``; kept as is because
        external code may read either attribute.
        """
        self.PreprocessingApplet = applet

    # def updatePreprocessing(self):
    #     if self.PreprocessingApplet is None or self._mst is None:
    #         return
    #     #FIXME: why were the following lines needed ?
    #     # if len(self._mst.object_names)==0:
    #     #     self.PreprocessingApplet.enableWriteprotect(True)
    #     # else:
    #     #     self.PreprocessingApplet.enableWriteprotect(False)

    def hasCurrentObject(self):
        """
        Returns current object name. None if it is not set.
        """
        #FIXME: This is misleading. Having a current object and that object having
        #a name is not the same thing.
        return self._currObjectName

    def currentObjectName(self):
        """
        Returns current object name. Return "" if no current object
        """
        assert self._currObjectName is not None, "FIXME: This function should either return '' or None. Why does it sometimes return one and then the other?"
        return self._currObjectName

    def hasObjectWithName(self, name):
        """
        Returns True if object with name is existent. False otherwise.
        """
        return name in self._mst.object_lut

    def doneObjectNamesForPosition(self, position3d):
        """
        Returns a list of names of objects which occupy a specific 3D position.
        List is empty if there are no objects present.
        """
        assert len(position3d) == 3

        #find the supervoxel that was clicked
        # tuple(): a list or array here would be taken as a fancy index
        # along the first axis instead of one coordinate per axis.
        sv = self._mst.supervoxelUint32[tuple(position3d)]
        names = []
        for name, objectSupervoxels in self._mst.object_lut.iteritems():
            if numpy.sum(sv == objectSupervoxels) > 0:
                names.append(name)
        logger.info("click on %r, supervoxel=%d: %r" %
                    (position3d, sv, names))
        return names

    @Operator.forbidParallelExecute
    def attachVoxelLabelsToObject(self, name, fgVoxels, bgVoxels):
        """
        Attaches Voxellabes to an object called name.
        """
        self._mst.object_seeds_fg_voxels[name] = fgVoxels
        self._mst.object_seeds_bg_voxels[name] = bgVoxels

    @Operator.forbidParallelExecute
    def clearCurrentLabeling(self, trigger_recompute=True):
        """
        Clears the current labeling.

        NOTE(review): ``trigger_recompute`` is currently ignored; a
        recomputation is always triggered.
        """
        self._clearLabels()
        self._mst.gridSegmentor.clearSeeds()
        self.Trigger.setDirty(slice(None))

    def loadObject_impl(self, name):
        """
        Loads a single object called name to be the currently edited object. Its
        not part of the done segmentation anymore.

        Returns the stored (foreground, background) seed voxel coordinates.
        """
        assert self._mst is not None
        logger.info("[OpCarving] load object %s (opCarving=%d, mst=%d)" %
                    (name, id(self), id(self._mst)))

        assert name in self._mst.object_lut
        assert name in self._mst.object_seeds_fg_voxels
        assert name in self._mst.object_seeds_bg_voxels
        assert name in self._mst.bg_priority
        assert name in self._mst.no_bias_below

        # set foreground and background seeds
        fgVoxelsSeedPos = self._mst.object_seeds_fg_voxels[name]
        bgVoxelsSeedPos = self._mst.object_seeds_bg_voxels[name]
        self._mst.setSeeds(numpy.array(fgVoxelsSeedPos),
                           numpy.array(bgVoxelsSeedPos))

        # load the actual segmentation
        fgNodes = self._mst.object_lut[name]
        self._mst.setResulFgObj(fgNodes[0])

        self._setCurrObjectName(name)
        self.HasSegmentation.setValue(True)

        #now that 'name' is no longer part of the set of finished objects, rebuild the done overlay
        self._buildDone()
        return (fgVoxelsSeedPos, bgVoxelsSeedPos)

    def loadObject(self, name):
        """
        Make the saved object ``name`` the current object: save the current
        one, restore the stored seeds into the label array and restore the
        object's parameter values.

        Returns False if no object with that name exists, True otherwise.
        """
        logger.info("want to load object with name = %s" % name)
        if not self.hasObjectWithName(name):
            logger.info(" --> no such object '%s'" % name)
            return False

        if self.hasCurrentObject():
            self.saveCurrentObject()
        self._clearLabels()

        fgVoxels, bgVoxels = self.loadObject_impl(name)

        # Bounding box around all seeds.  List comprehensions rather than
        # map(), so this does not depend on map() returning a list.
        fg_bounding_box_start = numpy.array([numpy.min(c) for c in fgVoxels])
        fg_bounding_box_stop = 1 + numpy.array(
            [numpy.max(c) for c in fgVoxels])
        bg_bounding_box_start = numpy.array([numpy.min(c) for c in bgVoxels])
        bg_bounding_box_stop = 1 + numpy.array(
            [numpy.max(c) for c in bgVoxels])

        bounding_box_start = numpy.minimum(fg_bounding_box_start,
                                           bg_bounding_box_start)
        bounding_box_stop = numpy.maximum(fg_bounding_box_stop,
                                          bg_bounding_box_stop)

        bounding_box_slicing = roiToSlice(bounding_box_start,
                                          bounding_box_stop)
        bounding_box_shape = tuple(bounding_box_stop - bounding_box_start)
        dtype = self.opLabelArray.Output.meta.dtype

        # Convert coordinates to be relative to bounding box.
        # The results are tuples (one index array per axis): indexing a
        # multi-dimensional array with a *list* of arrays is deprecated
        # numpy behaviour.
        offset = numpy.array([bounding_box_start]).transpose()
        fgVoxels = tuple(numpy.array(fgVoxels) - offset)
        bgVoxels = tuple(numpy.array(bgVoxels) - offset)

        with Timer() as timer:
            logger.info("Loading seeds....")
            z = numpy.zeros(bounding_box_shape, dtype=dtype)
            logger.info("Allocating seed array took {} seconds".format(
                timer.seconds()))
            z[fgVoxels] = 2
            z[bgVoxels] = 1
            self.WriteSeeds[(slice(0, 1), ) + bounding_box_slicing +
                            (slice(0, 1), )] = z[numpy.newaxis, :, :, :,
                                                 numpy.newaxis]
        logger.info("Loading seeds took a total of {} seconds".format(
            timer.seconds()))

        #restore the correct parameter values
        mst = self._mst
        assert name in mst.object_lut
        assert name in mst.object_seeds_fg_voxels
        assert name in mst.object_seeds_bg_voxels
        assert name in mst.bg_priority
        assert name in mst.no_bias_below

        self.BackgroundPriority.setValue(mst.bg_priority[name])
        self.NoBiasBelow.setValue(mst.no_bias_below[name])

        # The entire segmentation layer needs to be refreshed now.
        self.Segmentation.setDirty()
        return True

    @Operator.forbidParallelExecute
    def deleteObject_impl(self, name):
        """
        Deletes an object called name.
        """
        del self._mst.object_lut[name]
        del self._mst.object_seeds_fg_voxels[name]
        del self._mst.object_seeds_bg_voxels[name]
        del self._mst.bg_priority[name]
        del self._mst.no_bias_below[name]

        #delete it from object_names, as it indicates
        #whether the object exists
        if name in self._mst.object_names:
            del self._mst.object_names[name]

        self._setCurrObjectName("<not saved yet>")

        #now that 'name' has been deleted, rebuild the done overlay
        self._buildDone()

    def deleteObject(self, name):
        """
        Delete the saved object ``name``, clear the user labels and trigger
        a recomputation.

        Returns False if no object with that name exists, True otherwise.
        """
        logger.info("want to delete object with name = %s" % name)
        if not self.hasObjectWithName(name):
            logger.info(" --> no such object '%s'" % name)
            return False

        self.deleteObject_impl(name)
        #clear the user labels
        self._clearLabels()
        # trigger a re-computation
        self.Trigger.setDirty(slice(None))
        self._dirtyObjects.add(name)

        numObjects = len(self._mst.object_names)
        logger.info("save: len = {}".format(numObjects))
        self.AllObjectNames.meta.shape = (numObjects, )

        self.HasSegmentation.setValue(False)
        return True

    @Operator.forbidParallelExecute
    def saveCurrentObject(self):
        """
        Saves the objects which is currently edited.

        Returns the name it was saved under, or "" if there is no current
        object name.
        """
        if self._currObjectName:
            # Remember the name first: saveCurrentObjectAs() resets it.
            name = self._currObjectName
            logger.info("saving object %s" % name)
            self.saveCurrentObjectAs(name)
            self.HasSegmentation.setValue(False)
            return name
        return ""

    @Operator.forbidParallelExecute
    def saveCurrentObjectAs(self, name):
        """
        Saves current object as name.

        An existing object of that name keeps its object number; otherwise
        the next free number (current maximum + 1, or 1) is assigned.
        """
        seed = 2
        logger.info(" --> Saving object %r from seed %r" % (name, seed))

        if name in self._mst.object_names:
            objNr = self._mst.object_names[name]
        elif len(self._mst.object_names) > 0:
            # find free objNr
            objNr = numpy.max(
                numpy.array(list(self._mst.object_names.values()))) + 1
        else:
            objNr = 1

        sVseg = self._mst.getSuperVoxelSeg()
        # NOTE(review): the result of this call was never used; the call is
        # retained in case it has side effects -- confirm and remove.
        self._mst.getSuperVoxelSeeds()

        self._mst.object_names[name] = objNr
        self._mst.bg_priority[name] = self.BackgroundPriority.value
        self._mst.no_bias_below[name] = self.NoBiasBelow.value

        # Supervoxels labeled 2 in the supervoxel segmentation form the object.
        self._mst.objects[name] = numpy.where(sVseg == 2)
        self._mst.object_lut[name] = numpy.where(sVseg == 2)

        self._setCurrObjectName("<not saved yet>")
        self.HasSegmentation.setValue(False)

        self.AllObjectNames.meta.shape = (len(self._mst.object_names), )

        #now that 'name' is no longer part of the set of finished objects, rebuild the done overlay
        self._buildDone()

    def get_label_voxels(self):
        """
        Collect the voxel coordinates of fg and bg labels.

        Returns ``(fg, bg)``: each a list of three coordinate arrays (one
        per spatial axis) for label 2 and label 1 respectively, or
        ``(None, None)`` if the label array is not ready.
        """
        if not self.opLabelArray.NonzeroBlocks.ready():
            return (None, None)

        nonzeroSlicings = self.opLabelArray.NonzeroBlocks[:].wait()[0]

        coors1 = [[], [], []]
        coors2 = [[], [], []]
        for sl in nonzeroSlicings:
            a = self.opLabelArray.Output[sl].wait()
            w1 = numpy.where(a == 1)
            w2 = numpy.where(a == 2)
            # Axes 1..3 are the spatial ones; shift the block-local
            # coordinates by the block's start.
            for i in range(3):
                coors1[i].append(w1[i + 1] + sl[i + 1].start)
                coors2[i].append(w2[i + 1] + sl[i + 1].start)

        for coors in (coors1, coors2):
            for i in range(3):
                if len(coors[i]) > 0:
                    coors[i] = numpy.concatenate(coors[i], 0)
                else:
                    coors[i] = numpy.ndarray((0, ), numpy.int32)
        return (coors2, coors1)

    def saveObjectAs(self, name):
        """
        Save the current object under ``name`` together with its seed
        voxels, then clear the labels, the MST seeds and the segmentation.
        """
        # first, save the object under "name"
        self.saveCurrentObjectAs(name)

        # NOTE(review): the result of this call was never used; the call is
        # retained in case it has side effects -- confirm and remove.
        self._mst.getSuperVoxelSeeds()

        fgVoxels, bgVoxels = self.get_label_voxels()
        self.attachVoxelLabelsToObject(name,
                                       fgVoxels=fgVoxels,
                                       bgVoxels=bgVoxels)

        self._clearLabels()

        # trigger a re-computation
        self.Trigger.setDirty(slice(None))
        self._dirtyObjects.add(name)

        self._mst.gridSegmentor.clearSeeds()
        self._mst.clearSegmentation()
        self.clearCurrentLabeling()

    def getMaxUncertaintyPos(self, label):
        """
        Return the region center of the supervoxel with the largest
        uncertainty among those segmented as ``label``.
        """
        # FIXME: currently working on
        uncertainties = self._mst.uncertainty.lut
        segmentation = self._mst.segmentation.lut
        uncertainty_fg = numpy.where(segmentation == label, uncertainties, 0)
        index_max_uncert = numpy.argmax(uncertainty_fg, axis=0)
        pos = self._mst.regionCenter[index_max_uncert, :]
        return pos

    @staticmethod
    def _addSingletonAxes(volume):
        """
        Give a 3D array singleton leading and trailing axes, in place
        (no data is copied), and return it.
        """
        volume.shape = (1, ) + volume.shape + (1, )
        return volume

    def execute(self, slot, subindex, roi, result):
        """
        Serve the requested output from the current MST.

        For the image-like slots backed by the MST a newly shaped array is
        returned instead of filling ``result``, to avoid copying data.
        Raises RuntimeError for an unknown slot.
        """
        self._mst = self.MST.value

        if slot == self.AllObjectNames:
            return list(self._mst.object_names.keys())

        sl = roi.toSlice()
        if slot == self.Segmentation:
            temp = self._mst.getVoxelSegmentation(roi=roi)
        elif slot == self.Supervoxels:
            temp = self._mst.supervoxelUint32[sl[1:4]]
        elif slot == self.DoneObjects:
            if self._done_lut is None:
                result[0, :, :, :, 0] = 0
                return result
            temp = self._done_lut[self._mst.supervoxelUint32[sl[1:4]]]
        elif slot == self.DoneSegmentation:
            if self._done_seg_lut is None:
                result[0, :, :, :, 0] = 0
                return result
            temp = self._done_seg_lut[self._mst.supervoxelUint32[sl[1:4]]]
        elif slot == self.HintOverlay:
            if self._hints is None:
                result[:] = 0
            else:
                result[:] = self._hints[sl]
            return result
        elif slot == self.PmapOverlay:
            if self._pmap is None:
                result[:] = 0
            else:
                result[:] = self._pmap[sl]
            return result
        elif slot == self.Uncertainty:
            temp = self._mst.uncertainty[sl[1:4]]
        else:
            raise RuntimeError("unknown slot")

        #avoid copying data
        return self._addSingletonAxes(temp)

    def setInSlot(self, slot, subindex, roi, value):
        """
        Accept seeds written into WriteSeeds: store them in the label array
        and add them to the MST.  Raises RuntimeError for any other slot.
        """
        if slot != self.WriteSeeds:
            raise RuntimeError("unknown slots")

        key = roi.toSlice()
        with Timer() as timer:
            logger.info("Writing seeds to label array")
            self.opLabelArray.LabelSinkInput[key] = value
            logger.info("Writing seeds to label array took {} seconds".format(
                timer.seconds()))

        assert self._mst is not None

        with Timer() as timer:
            logger.info("Writing seeds to MST")
            if not hasattr(key, '__len__'):
                raise RuntimeError("when is this part of the code called")
            self._mst.addSeeds(roi=roi, brushStroke=value.squeeze())
            logger.info("Writing seeds to MST took {} seconds".format(
                timer.seconds()))

        self.has_seeds = True

    def propagateDirty(self, slot, subindex, roi):
        """
        Re-run the carving when a parameter slot changes, pick up a new MST
        when the MST slot changes, and ignore the display-only slots.
        """
        if slot in (self.Trigger, self.BackgroundPriority, self.NoBiasBelow,
                    self.UncertaintyType):
            if self._mst is None:
                return
            # All three parameter values are read below; wait until every
            # one of them is available.
            if not self.BackgroundPriority.ready():
                return
            if not self.NoBiasBelow.ready():
                return
            if not self.UncertaintyType.ready():
                return

            bgPrio = self.BackgroundPriority.value
            noBiasBelow = self.NoBiasBelow.value

            logger.info(
                "compute new carving results with bg priority = %f, no bias below %d"
                % (bgPrio, noBiasBelow))
            t1 = time.time()
            labelCount = 2

            params = dict()
            params["prios"] = [1.0, bgPrio, 1.0]
            params["uncertainty"] = self.UncertaintyType.value
            params["noBiasBelow"] = noBiasBelow

            unaries = numpy.zeros((self._mst.numNodes + 1, labelCount + 1),
                                  dtype=numpy.float32)
            self._mst.run(unaries, **params)
            logger.info(" ... carving took %f sec." % (time.time() - t1))

            self.Segmentation.setDirty(slice(None))
            hasSeg = numpy.any(self._mst.hasSeg)
            self.HasSegmentation.setValue(hasSeg)

        elif slot == self.MST:
            self._opMstCache.Input.disconnect()
            self._mst = self.MST.value
            self._opMstCache.Input.setValue(self._mst)

        elif slot in (self.OverlayData, self.InputData,
                      self.FilteredInputData, self.WriteSeeds):
            pass
        else:
            assert False, "Unknown input slot: {}".format(slot.name)
class OpTrackingBase(Operator, ExportingOperator): name = "Tracking" category = "other" LabelImage = InputSlot() ObjectFeatures = InputSlot(stype=Opaque, rtype=List) ObjectFeaturesWithDivFeatures = InputSlot(optional=True, stype=Opaque, rtype=List) ComputedFeatureNames = InputSlot(rtype=List, stype=Opaque) ComputedFeatureNamesWithDivFeatures = InputSlot(optional=True, rtype=List, stype=Opaque) EventsVector = InputSlot(value={}) FilteredLabels = InputSlot(value={}) RawImage = InputSlot() Parameters = InputSlot(value={}) # for serialization InputHdf5 = InputSlot(optional=True) CleanBlocks = OutputSlot() AllBlocks = OutputSlot() OutputHdf5 = OutputSlot() CachedOutput = OutputSlot() # For the GUI (blockwise-access) Output = OutputSlot() # Use a slot for storing the export settings in the project file. ExportSettings = InputSlot() # Override functions ExportingOperator mixin def configure_table_export_settings(self, settings, selected_features): self.ExportSettings.setValue( (settings, selected_features) ) def get_table_export_settings(self): if self.ExportSettings.ready(): (settings, selected_features) = self.ExportSettings.value return (settings, selected_features) else: return None, None def __init__(self, parent=None, graph=None): super(OpTrackingBase, self).__init__(parent=parent, graph=graph) self.label2color = [] self.mergers = [] self.resolvedto = [] self.track_id = None self.extra_track_ids = None self.divisions = None self._opCache = OpCompressedCache(parent=self) self._opCache.InputHdf5.connect(self.InputHdf5) self._opCache.Input.connect(self.Output) self.CleanBlocks.connect(self._opCache.CleanBlocks) self.OutputHdf5.connect(self._opCache.OutputHdf5) self.CachedOutput.connect(self._opCache.Output) self.zeroProvider = OpZeroDefault(parent=self) self.zeroProvider.MetaInput.connect(self.LabelImage) # As soon as input data is available, check its constraints self.RawImage.notifyReady(self._checkConstraints) self.LabelImage.notifyReady(self._checkConstraints) 
self.export_progress_dialog = None self.ExportSettings.setValue( (None, None) ) def setupOutputs(self): self.Output.meta.assignFrom(self.LabelImage.meta) # cache our own output, don't propagate from internal operator chunks = list(self.LabelImage.meta.shape) # FIXME: assumes t,x,y,z,c chunks[0] = 1 # 't' self._blockshape = tuple(chunks) self._opCache.BlockShape.setValue(self._blockshape) self.AllBlocks.meta.shape = (1,) self.AllBlocks.meta.dtype = object def _checkConstraints(self, *args): if self.RawImage.ready(): rawTaggedShape = self.RawImage.meta.getTaggedShape() if rawTaggedShape['t'] < 2: raise DatasetConstraintError( "Tracking", "For tracking, the dataset must have a time axis with at least 2 images. " \ "Please load time-series data instead. See user documentation for details.") if self.LabelImage.ready(): segmentationTaggedShape = self.LabelImage.meta.getTaggedShape() if segmentationTaggedShape['t'] < 2: raise DatasetConstraintError( "Tracking", "For tracking, the dataset must have a time axis with at least 2 images. " \ "Please load time-series data instead. See user documentation for details.") if self.RawImage.ready() and self.LabelImage.ready(): rawTaggedShape['c'] = None segmentationTaggedShape['c'] = None if dict(rawTaggedShape) != dict(segmentationTaggedShape): raise DatasetConstraintError("Tracking", "For tracking, the raw data and the prediction maps must contain the same " \ "number of timesteps and the same shape. 
" \ "Your raw image has a shape of (t, x, y, z, c) = {}, whereas your prediction image has a " \ "shape of (t, x, y, z, c) = {}" \ .format(self.RawImage.meta.shape, self.BinaryImage.meta.shape)) def execute(self, slot, subindex, roi, result): if slot is self.Output: result[:] = self.LabelImage.get(roi).wait() if not self.Parameters.ready(): raise Exception("Parameter slot is not ready") parameters = self.Parameters.value t_start = roi.start[0] t_end = roi.stop[0] for t in range(t_start, t_end): if ('time_range' in parameters and t <= parameters['time_range'][-1] and t >= parameters['time_range'][ 0]) and len(self.label2color) > t: result[t - t_start, ..., 0] = relabel(result[t - t_start, ..., 0], self.label2color[t]) else: result[t - t_start, ...] = 0 return result elif slot == self.AllBlocks: # if nothing was computed, return empty list if len(self.label2color) == 0: result[0] = [] return result all_block_rois = [] shape = self.Output.meta.shape # assumes t,x,y,z,c slicing = [slice(None), ] * 5 for t in range(shape[0]): slicing[0] = slice(t, t + 1) all_block_rois.append(sliceToRoi(slicing, shape)) result[0] = all_block_rois return result def propagateDirty(self, inputSlot, subindex, roi): if inputSlot is self.LabelImage: self.Output.setDirty(roi) elif inputSlot is self.EventsVector: self._setLabel2Color() try: self._setLabel2Color(export_mode=True) except: logger.debug("Warning: some label information might be wrong...") def setInSlot(self, slot, subindex, roi, value): assert slot == self.InputHdf5, "Invalid slot for setInSlot(): {}".format(slot.name) def _setLabel2Color(self, successive_ids=True, export_mode=False): if not self.EventsVector.ready() or not self.Parameters.ready() \ or not self.FilteredLabels.ready(): return if export_mode: assert successive_ids, "Export mode only works for successive ids" events = self.EventsVector.value parameters = self.Parameters.value time_min = 0 time_max = self.RawImage.meta.shape[0] - 1 # Assumes t,x,y,z,c if 'time_range' 
in parameters: time_min, time_max = parameters['time_range'] time_range = range(time_min, time_max) filtered_labels = self.FilteredLabels.value label2color = [] label2color.append({}) mergers = [] resolvedto = [] maxId = 2 # misdetections have id 1 # handle start time offsets for i in range(time_range[0]): label2color.append({}) mergers.append({}) resolvedto.append({}) extra_track_ids = {} if export_mode: multi_move = {} multi_move_next = {} divisions = [] for i in time_range: app = get_dict_value(events[str(i - time_range[0] + 1)], "app", []) div = get_dict_value(events[str(i - time_range[0] + 1)], "div", []) mov = get_dict_value(events[str(i - time_range[0] + 1)], "mov", []) merger = get_dict_value(events[str(i - time_range[0])], "merger", []) res = get_dict_value(events[str(i - time_range[0])], "res", {}) logger.debug(" {} app at {}".format(len(app), i)) logger.debug(" {} div at {}".format(len(div), i)) logger.debug(" {} mov at {}".format(len(mov), i)) logger.debug(" {} merger at {}".format(len(merger), i)) label2color.append({}) mergers.append({}) moves_at = [] resolvedto.append({}) if export_mode: moves_to = {} for e in app: if successive_ids: label2color[-1][int(e[0])] = maxId # in export mode, the label color is used as track ID maxId += 1 else: label2color[-1][int(e[0])] = np.random.randint(1, 255) for e in mov: if export_mode: if e[1] in moves_to: multi_move.setdefault(i, {}) multi_move[i][e[0]] = e[1] if len(moves_to[e[1]]) == 1: # if we are just setting up this multi move multi_move[i][moves_to[e[1]][0]] = e[1] multi_move_next[(i, e[1])] = 0 moves_to.setdefault(e[1], []) moves_to[e[1]].append(e[0]) # moves_to[target] contains list of incoming object ids # alternative way of appearance if not label2color[-2].has_key(int(e[0])): if successive_ids: label2color[-2][int(e[0])] = maxId maxId += 1 else: label2color[-2][int(e[0])] = np.random.randint(1, 255) # assign color of parent label2color[-1][int(e[1])] = label2color[-2][int(e[0])] 
moves_at.append(int(e[0])) if export_mode: key = i - 1, e[0] if key in multi_move_next: # captures mergers staying connected over longer time spans multi_move_next[key] = e[1] # redirects output of last merger to target in this frame multi_move_next[(i, e[1])] = 0 # sets current end to zero (might be changed by above line in the future) for e in div: # event(parent, child, child) # if not label2color[-2].has_key(int(e[0])): if not int(e[0]) in label2color[-2]: if successive_ids: label2color[-2][int(e[0])] = maxId maxId += 1 else: label2color[-2][int(e[0])] = np.random.randint(1, 255) ancestor_color = label2color[-2][int(e[0])] if export_mode: label2color[-1][int(e[1])] = maxId label2color[-1][int(e[2])] = maxId + 1 divisions.append((i, int(e[0]), ancestor_color, int(e[1]), maxId, int(e[2]), maxId + 1 )) maxId += 2 else: label2color[-1][int(e[1])] = ancestor_color label2color[-1][int(e[2])] = ancestor_color for e in merger: mergers[-1][int(e[0])] = int(e[1]) for o, r in res.iteritems(): resolvedto[-1][int(o)] = [int(c) for c in r[:-1]] # label the original object with the false detection label mergers[-1][int(o)] = len(r[:-1]) if export_mode: extra_track_ids.setdefault(i, {}) extra_track_ids[i][int(o)] = [int(c) for c in r[:-1]] # last timestep merger = get_dict_value(events[str(time_range[-1] - time_range[0] + 1)], "merger", []) mergers.append({}) for e in merger: mergers[-1][int(e[0])] = int(e[1]) res = get_dict_value(events[str(time_range[-1] - time_range[0] + 1)], "res", {}) resolvedto.append({}) if export_mode: extra_track_ids[time_range[-1] + 1] = {} for o, r in res.iteritems(): resolvedto[-1][int(o)] = [int(c) for c in r[:-1]] mergers[-1][int(o)] = len(r[:-1]) if export_mode: extra_track_ids[time_range[-1] + 1][int(o)] = [int(c) for c in r[:-1]] # mark the filtered objects for i in filtered_labels.keys(): if int(i) + time_range[0] >= len(label2color): continue fl_at = filtered_labels[i] for l in fl_at: assert l not in label2color[int(i) + time_range[0]] 
label2color[int(i) + time_range[0]][l] = 0 if export_mode: # don't set fields when in export_mode self.track_id = label2color self.divisions = divisions self.extra_track_ids = extra_track_ids return label2color, extra_track_ids, divisions self.track_id = label2color self.extra_track_ids = extra_track_ids self.label2color = label2color self.resolvedto = resolvedto self.mergers = mergers self.Output._value = None self.Output.setDirty(slice(None)) if 'MergerOutput' in self.outputs: self.MergerOutput._value = None self.MergerOutput.setDirty(slice(None)) def export_track_ids(self): return self._setLabel2Color(export_mode=True) def track_children(self, track_id, start=0): if start in self.divisions: for t, _, track, _, child_track1, _, child_track2 in self.divisions[start:]: if track == track_id: children_of = partial(self.track_children, start=t) return [child_track1, child_track2] + \ children_of(child_track1) + children_of(child_track2) return [] def track_parent(self, track_id): if not self.divisions == {}: for t, oid, track, _, child_track1, _, child_track2 in self.divisions[:-1]: if track_id in (child_track1, child_track2): return [track] + self.track_parent(track) return [] def track_family(self, track_id): return self.track_children(track_id), self.track_parent(track_id) def _generate_traxelstore(self, time_range, x_range, y_range, z_range, size_range, x_scale=1.0, y_scale=1.0, z_scale=1.0, with_div=False, with_local_centers=False, median_object_size=None, max_traxel_id_at=None, with_opt_correction=False, with_coordinate_list=False, with_classifier_prior=False): if not self.Parameters.ready(): raise Exception("Parameter slot is not ready") parameters = self.Parameters.value parameters['scales'] = [x_scale, y_scale, z_scale] parameters['time_range'] = [min(time_range), max(time_range)] parameters['x_range'] = x_range parameters['y_range'] = y_range parameters['z_range'] = z_range parameters['size_range'] = size_range logger.info("generating traxels") 
logger.info("fetching region features and division probabilities") feats = self.ObjectFeatures(time_range).wait() if with_div: if not self.DivisionProbabilities.ready() or len(self.DivisionProbabilities([0]).wait()[0]) == 0: msgStr = "\nDivision classifier has not been trained! " + \ "Uncheck divisible objects if your objects don't divide or " + \ "go back to the Division Detection applet and train it." raise DatasetConstraintError ("Tracking",msgStr) divProbs = self.DivisionProbabilities(time_range).wait() if with_local_centers: localCenters = self.RegionLocalCenters(time_range).wait() if with_classifier_prior: if not self.DetectionProbabilities.ready() or len(self.DetectionProbabilities([0]).wait()[0]) == 0: msgStr = "\nObject count classifier has not been trained! " + \ "Go back to the Object Count Classification applet and train it." raise DatasetConstraintError ("Tracking",msgStr) detProbs = self.DetectionProbabilities(time_range).wait() logger.info("filling traxelstore") ts = pgmlink.TraxelStore() fs = pgmlink.FeatureStore() max_traxel_id_at = pgmlink.VectorOfInt() filtered_labels = {} obj_sizes = [] total_count = 0 empty_frame = False for t in feats.keys(): rc = feats[t][default_features_key]['RegionCenter'] lower = feats[t][default_features_key]['Coord<Minimum>'] upper = feats[t][default_features_key]['Coord<Maximum>'] if rc.size: rc = rc[1:, ...] lower = lower[1:, ...] upper = upper[1:, ...] if with_opt_correction: try: rc_corr = feats[t][config.features_vigra_name]['RegionCenter_corr'] except: raise Exception, 'Can not consider optical correction since it has not been computed before' if rc_corr.size: rc_corr = rc_corr[1:, ...] ct = feats[t][default_features_key]['Count'] if ct.size: ct = ct[1:, ...] 
logger.debug("at timestep {}, {} traxels found".format(t, rc.shape[0])) count = 0 filtered_labels_at = [] for idx in range(rc.shape[0]): # for 2d data, set z-coordinate to 0: if len(rc[idx]) == 2: x, y = rc[idx] z = 0 elif len(rc[idx]) == 3: x, y, z = rc[idx] else: raise DatasetConstraintError ("Tracking", "The RegionCenter feature must have dimensionality 2 or 3.") size = ct[idx] if (x < x_range[0] or x >= x_range[1] or y < y_range[0] or y >= y_range[1] or z < z_range[0] or z >= z_range[1] or size < size_range[0] or size >= size_range[1]): filtered_labels_at.append(int(idx + 1)) continue else: count += 1 tr = pgmlink.Traxel() tr.set_feature_store(fs) tr.set_x_scale(x_scale) tr.set_y_scale(y_scale) tr.set_z_scale(z_scale) tr.Id = int(idx + 1) tr.Timestep = int(t) # pgmlink expects always 3 coordinates, z=0 for 2d data tr.add_feature_array("com", 3) for i, v in enumerate([x, y, z]): tr.set_feature_value('com', i, float(v)) tr.add_feature_array("CoordMinimum", 3) for i, v in enumerate(lower[idx]): tr.set_feature_value("CoordMinimum", i, float(v)) tr.add_feature_array("CoordMaximum", 3) for i, v in enumerate(upper[idx]): tr.set_feature_value("CoordMaximum", i, float(v)) if with_opt_correction: tr.add_feature_array("com_corrected", 3) for i, v in enumerate(rc_corr[idx]): tr.set_feature_value("com_corrected", i, float(v)) if len(rc_corr[idx]) == 2: tr.set_feature_value("com_corrected", 2, 0.) if with_div: tr.add_feature_array("divProb", 1) # idx+1 because rc and ct start from 1, divProbs starts from 0 tr.set_feature_value("divProb", 0, float(divProbs[t][idx + 1][1])) if with_classifier_prior: tr.add_feature_array("detProb", len(detProbs[t][idx + 1])) for i, v in enumerate(detProbs[t][idx + 1]): val = float(v) if val < 0.0000001: val = 0.0000001 if val > 0.99999999: val = 0.99999999 tr.set_feature_value("detProb", i, float(val)) # FIXME: check whether it is 2d or 3d data! 
if with_local_centers: tr.add_feature_array("localCentersX", len(localCenters[t][idx + 1])) tr.add_feature_array("localCentersY", len(localCenters[t][idx + 1])) tr.add_feature_array("localCentersZ", len(localCenters[t][idx + 1])) for i, v in enumerate(localCenters[t][idx + 1]): tr.set_feature_value("localCentersX", i, float(v[0])) tr.set_feature_value("localCentersY", i, float(v[1])) tr.set_feature_value("localCentersZ", i, float(v[2])) tr.add_feature_array("count", 1) tr.set_feature_value("count", 0, float(size)) if median_object_size is not None: obj_sizes.append(float(size)) ts.add(fs, tr) if len(filtered_labels_at) > 0: filtered_labels[str(int(t) - time_range[0])] = filtered_labels_at logger.debug("at timestep {}, {} traxels passed filter".format(t, count)) max_traxel_id_at.append(int(rc.shape[0])) if count == 0: empty_frame = True total_count += count if median_object_size is not None: median_object_size[0] = np.median(np.array(obj_sizes), overwrite_input=True) logger.info('median object size = ' + str(median_object_size[0])) self.FilteredLabels.setValue(filtered_labels, check_changed=True) return fs, ts, empty_frame, max_traxel_id_at def save_export_progress_dialog(self, dialog): """ Implements ExportOperator.save_export_progress_dialog Without this the progress dialog would be hidden after the export :param dialog: the ProgressDialog to save """ self.export_progress_dialog = dialog def do_export(self, settings, selected_features, progress_slot, lane_index, filename_suffix=""): """ Implements ExportOperator.do_export(settings, selected_features, progress_slot Most likely called from ExportOperator.export_object_data :param settings: the settings for the exporter, see :param selected_features: :param progress_slot: :param lane_index: Ignored. (This is a single-lane operator. It is the caller's responsibility to make sure he's calling the right lane.) :param filename_suffix: If provided, appended to the filename (before the extension). 
:return: """ with_divisions = self.Parameters.value["withDivisions"] if self.Parameters.ready() else False if with_divisions: object_feature_slot = self.ObjectFeaturesWithDivFeatures else: object_feature_slot = self.ObjectFeatures self._do_export_impl(settings, selected_features, progress_slot, object_feature_slot, self.LabelImage, lane_index, filename_suffix) def _do_export_impl(self, settings, selected_features, progress_slot, object_feature_slot, label_image_slot, lane_index, filename_suffix=""): from ilastik.utility.exportFile import objects_per_frame, ExportFile, ilastik_ids, Mode, Default, \ flatten_dict, division_flatten_dict selected_features = list(selected_features) with_divisions = self.Parameters.value["withDivisions"] if self.Parameters.ready() else False obj_count = list(objects_per_frame(label_image_slot)) track_ids, extra_track_ids, divisions = self.export_track_ids() self._setLabel2Color() lineage = flatten_dict(self.label2color, obj_count) multi_move_max = self.Parameters.value["maxObj"] if self.Parameters.ready() else 2 t_range = self.Parameters.value["time_range"] if self.Parameters.ready() else (0, 0) ids = ilastik_ids(obj_count) file_path = settings["file path"] if filename_suffix: path, ext = os.path.splitext(file_path) file_path = path + "-" + filename_suffix + ext export_file = ExportFile(file_path) export_file.ExportProgress.subscribe(progress_slot) export_file.InsertionProgress.subscribe(progress_slot) export_file.add_columns("table", range(sum(obj_count)), Mode.List, Default.KnimeId) export_file.add_columns("table", list(ids), Mode.List, Default.IlastikId) export_file.add_columns("table", lineage, Mode.List, Default.Lineage) export_file.add_columns("table", track_ids, Mode.IlastikTrackingTable, {"max": multi_move_max, "counts": obj_count, "extra ids": extra_track_ids, "range": t_range}) export_file.add_columns("table", object_feature_slot, Mode.IlastikFeatureTable, {"selection": selected_features}) if with_divisions: if divisions: 
div_lineage = division_flatten_dict(divisions, self.label2color) zips = zip(*divisions) divisions = zip(zips[0], div_lineage, *zips[1:]) export_file.add_columns("divisions", divisions, Mode.List, Default.DivisionNames) else: logger.debug("No divisions occurred. Division Table will not be exported!") if settings["file type"] == "h5": export_file.add_rois(Default.LabelRoiPath, label_image_slot, "table", settings["margin"], "labeling") if settings["include raw"]: export_file.add_image(Default.RawPath, self.RawImage) else: export_file.add_rois(Default.RawRoiPath, self.RawImage, "table", settings["margin"]) export_file.write_all(settings["file type"], settings["compression"]) export_file.ExportProgress.unsubscribe(progress_slot) export_file.InsertionProgress.unsubscribe(progress_slot)
class _OpCachedLabelImage(Operator):
    """
    Combines OpLabelImage with OpCompressedCache, and provides a default block shape.

    Every output slot is connected straight to the internal cache, so this
    operator never computes data itself (see execute()).
    """
    Input = InputSlot()

    BackgroundLabels = InputSlot(optional=True)  # Optional. See OpLabelImage for details.
    BlockShape = InputSlot(optional=True)  # If not provided, blockshape is 1 time slice, 1 channel slice,
                                           # and the entire volume in xyz.

    Output = OutputSlot()

    # Serialization support
    InputHdf5 = InputSlot(optional=True)
    CleanBlocks = OutputSlot()
    OutputHdf5 = OutputSlot()  # See OpCachedLabelImage for details

    # Schematic:
    #
    # BackgroundLabels --     BlockShape --
    #                    \                 \
    # Input ------------> OpLabelImage ---> OpCompressedCache --> Output
    #                                                        \
    #                                                         --> CleanBlocks

    def __init__(self, *args, **kwargs):
        """Create the internal labeler and cache and wire them to this operator's slots."""
        super(_OpCachedLabelImage, self).__init__(*args, **kwargs)

        # Hook up the labeler
        self._opLabelImage = OpLabelImage(parent=self)
        self._opLabelImage.Input.connect(self.Input)
        self._opLabelImage.BackgroundLabels.connect(self.BackgroundLabels)

        # Hook up the cache
        self._opCache = OpCompressedCache(parent=self)
        self._opCache.Input.connect(self._opLabelImage.Output)
        self._opCache.InputHdf5.connect(self.InputHdf5)

        # Hook up our output slots
        self.Output.connect(self._opCache.Output)
        self.CleanBlocks.connect(self._opCache.CleanBlocks)
        self.OutputHdf5.connect(self._opCache.OutputHdf5)

    # The four methods below simply delegate to the internal cache.
    def generateReport(self, report):
        return self._opCache.generateReport(report)

    def usedMemory(self):
        return self._opCache.usedMemory()

    def fractionOfUsedMemoryDirty(self):
        return self._opCache.fractionOfUsedMemoryDirty()

    def lastAccessTime(self):
        return self._opCache.lastAccessTime()

    def setupOutputs(self):
        """
        Configure the cache's block shape.

        If BlockShape is ready its value is forwarded unchanged.  Otherwise the
        block shape defaults to the entire image shape, but only 1 time slice
        and 1 channel slice.
        """
        if self.BlockShape.ready():
            self._opCache.BlockShape.setValue(self.BlockShape.value)
            return

        taggedBlockShape = self.Input.meta.getTaggedShape()
        for axis in ("t", "c"):
            # Only collapse axes the input actually has.  Assigning to a missing
            # key would append a new entry, producing a block shape with more
            # dimensions than the image itself.
            if axis in taggedBlockShape:
                taggedBlockShape[axis] = 1
        self._opCache.BlockShape.setValue(tuple(taggedBlockShape.values()))

    def execute(self, slot, subindex, roi, destination):
        # All outputs are connected to the internal cache, so requests are
        # never expected to reach this operator directly.
        assert False, "Shouldn't get here."

    def propagateDirty(self, slot, subindex, roi):
        pass  # Nothing to do...

    def setInSlot(self, slot, subindex, roi, value):
        # Only validates the slot; no data is stored here.
        assert slot == self.Input or slot == self.InputHdf5, "Invalid slot for setInSlot(): {}".format(
            slot.name)
class OpSimpleBlockedArrayCache(OpUnblockedArrayCache):
    """
    Extension of OpUnblockedArrayCache that splits every request along a block
    grid (given by BlockShape) and serves each intersecting block separately.
    """
    BlockShape = InputSlot(optional=True)  # Must be a tuple.  Any 'None' elements will be interpreted as 'max' for that dimension.
    BypassModeEnabled = InputSlot(value=False)

    def __init__(self, *args, **kwargs):
        super(OpSimpleBlockedArrayCache, self).__init__(*args, **kwargs)
        # Effective block shape; filled in by setupOutputs().
        self._blockshape = None

    def setupOutputs(self):
        """
        Resolve the effective block shape and publish blockshape / RAM hints
        on the Output metadata.
        """
        super(OpSimpleBlockedArrayCache, self).setupOutputs()

        input_meta = self.Input.meta
        input_shape = input_meta.shape

        # Without an explicit BlockShape, one block spans the whole input.
        self._blockshape = self.BlockShape.value if self.BlockShape.ready() else input_shape

        # A block shape of the wrong dimensionality cannot be used.
        if len(self._blockshape) != len(input_shape):
            self.Output.meta.NOTREADY = True
            return

        # Replace 'None' (or zero) with default (from Input shape)
        self._blockshape = tuple(
            numpy.where(self._blockshape, self._blockshape, input_shape))

        self.Output.meta.ideal_blockshape = tuple(
            numpy.minimum(self._blockshape, input_shape))

        # Estimate ram usage per requested pixel.
        # One 'pixel' includes all channels.
        bytes_per_pixel = get_ram_per_element(input_meta.dtype)
        axis_sizes = input_meta.getTaggedShape()
        if "c" in axis_sizes:
            bytes_per_pixel *= float(axis_sizes["c"])

        # Never report less than what the upstream operator already declared.
        upstream_estimate = input_meta.ram_usage_per_requested_pixel
        if upstream_estimate is not None:
            bytes_per_pixel = max(bytes_per_pixel, upstream_estimate)

        self.Output.meta.ram_usage_per_requested_pixel = bytes_per_pixel

    def _execute_Output(self, slot, subindex, roi, result):
        """
        Overridden from OpUnblockedArrayCache

        Fills ``result`` by handling every block that intersects ``roi`` in its
        own Request and waiting for all of them to finish.
        """
        def serve_block(whole_roi, visible_roi):
            # whole_roi: the full block; visible_roi: the block clipped to the request.
            whole_roi = numpy.asarray(whole_roi)
            visible_roi = numpy.asarray(visible_roi)
            output_roi = visible_roi - roi.start  # position of this piece inside `result`

            cached_roi = self._get_containing_block_roi(visible_roi)

            def crop_into_result(block_data):
                # Copy only the requested part of a full block into `result`.
                inner_roi = visible_roi - whole_roi[0]
                self.Output.stype.copy_data(
                    result[roiToSlice(*output_roi)],
                    block_data[roiToSlice(*inner_roi)])

            # Skip cache and copy full block directly
            if self.BypassModeEnabled.value:
                scratch = self.Output.stype.allocateDestination(
                    SubRegion(self.Output, *whole_roi))
                self.Input(*whole_roi).writeInto(scratch).block()
                crop_into_result(scratch)
                return

            # If the data exists already or we can just fetch it without needing
            # extra scratch space, just call the base class
            if cached_roi is not None or (whole_roi == visible_roi).all():
                self._execute_Output_impl(visible_roi, result[roiToSlice(*output_roi)])
                return

            # Data isn't in the cache, but we don't need it in the cache anyway.
            if self.Input.meta.dontcache:
                self.Input(*visible_roi).writeInto(
                    result[roiToSlice(*output_roi)]).block()
                return

            # Data doesn't exist yet in the cache.
            # Request the full block, but then discard the parts we don't need.
            # (We use allocateDestination() here to support MaskedArray types.)
            # TODO: We should probably just get rid of MaskedArray support altogether...
            scratch = self.Output.stype.allocateDestination(
                SubRegion(self.Output, *whole_roi))
            self._execute_Output_impl(whole_roi, scratch)
            crop_into_result(scratch)

        input_shape = self.Input.meta.shape
        requested = (roi.start, roi.stop)
        clipped_block_rois = getIntersectingRois(input_shape, self._blockshape, requested, True)
        full_block_rois = getIntersectingRois(input_shape, self._blockshape, requested, False)

        pool = RequestPool()
        for roi_pair in zip(full_block_rois, clipped_block_rois):
            pool.add(Request(partial(serve_block, *roi_pair)))
        pool.wait()

    def propagateDirty(self, slot, subindex, roi):
        # Changes to BypassModeEnabled or BlockShape are deliberately not forwarded;
        # everything else is handled by the base class.
        if slot not in (self.BypassModeEnabled, self.BlockShape):
            super(OpSimpleBlockedArrayCache, self).propagateDirty(slot, subindex, roi)
class OpConservationTracking(Operator): LabelImage = InputSlot() ObjectFeatures = InputSlot(stype=Opaque, rtype=List) ObjectFeaturesWithDivFeatures = InputSlot(optional=True, stype=Opaque, rtype=List) ComputedFeatureNames = InputSlot(rtype=List, stype=Opaque) ComputedFeatureNamesWithDivFeatures = InputSlot(optional=True, rtype=List, stype=Opaque) FilteredLabels = InputSlot(value={}) RawImage = InputSlot() Parameters = InputSlot(value={}) HypothesesGraph = InputSlot(value={}) ResolvedMergers = InputSlot(value={}) # for serialization CleanBlocks = OutputSlot() AllBlocks = OutputSlot() CachedOutput = OutputSlot() # For the GUI (blockwise-access) Output = OutputSlot() # Volume relabelled with lineage IDs # Use a slot for storing the export settings in the project file. # just here so that old projects still load! ExportSettings = InputSlot() DivisionProbabilities = InputSlot(optional=True, stype=Opaque, rtype=List) DetectionProbabilities = InputSlot(stype=Opaque, rtype=List) NumLabels = InputSlot() # compressed cache for merger output MergerCleanBlocks = OutputSlot() MergerCachedOutput = OutputSlot() # For the GUI (blockwise access) MergerOutput = OutputSlot() # Volume showing only merger IDs RelabeledCleanBlocks = OutputSlot() RelabeledCachedOutput = OutputSlot() # For the GUI (blockwise access) RelabeledImage = OutputSlot() # Volume showing object IDs def __init__(self, parent=None, graph=None): super(OpConservationTracking, self).__init__(parent=parent, graph=graph) self._opCache = OpBlockedArrayCache(parent=self) self._opCache.name = "OpConservationTracking._opCache" self._opCache.Input.connect(self.Output) self.CleanBlocks.connect(self._opCache.CleanBlocks) self.CachedOutput.connect(self._opCache.Output) self.zeroProvider = OpZeroDefault(parent=self) self.zeroProvider.MetaInput.connect(self.LabelImage) # As soon as input data is available, check its constraints self.RawImage.notifyReady(self._checkConstraints) self.LabelImage.notifyReady(self._checkConstraints) 
self.ExportSettings.setValue( (None, None) ) self._mergerOpCache = OpBlockedArrayCache(parent=self) self._mergerOpCache.name = "OpConservationTracking._mergerOpCache" self._mergerOpCache.Input.connect(self.MergerOutput) self.MergerCleanBlocks.connect(self._mergerOpCache.CleanBlocks) self.MergerCachedOutput.connect(self._mergerOpCache.Output) self._relabeledOpCache = OpBlockedArrayCache(parent=self) self._relabeledOpCache.name = "OpConservationTracking._mergerOpCache" self._relabeledOpCache.Input.connect(self.RelabeledImage) self.RelabeledCleanBlocks.connect(self._relabeledOpCache.CleanBlocks) self.RelabeledCachedOutput.connect(self._relabeledOpCache.Output) # Merger resolver plugin manager (contains GMM fit routine) self.pluginPaths = [os.path.join(os.path.dirname(os.path.abspath(hytra.__file__)), 'plugins')] pluginManager = TrackingPluginManager(verbose=False, pluginPaths=self.pluginPaths) self.mergerResolverPlugin = pluginManager.getMergerResolver() self.result = None # gui progress self.progressWindow = None self.progressVisitor=DefaultProgressVisitor() def setupOutputs(self): self.Output.meta.assignFrom(self.LabelImage.meta) # cache our own output, don't propagate from internal operator chunks = list(self.LabelImage.meta.shape) # FIXME: assumes t,x,y,z,c chunks[0] = 1 # 't' self._blockshape = tuple(chunks) self._opCache.BlockShape.setValue(self._blockshape) self.AllBlocks.meta.shape = (1,) self.AllBlocks.meta.dtype = object self.MergerOutput.meta.assignFrom(self.LabelImage.meta) self.RelabeledImage.meta.assignFrom(self.LabelImage.meta) self._mergerOpCache.BlockShape.setValue( self._blockshape ) self._relabeledOpCache.BlockShape.setValue( self._blockshape ) frame_shape = (1,) + self.LabelImage.meta.shape[1:] # assumes t,x,y,z,c order assert frame_shape[-1] == 1 self.MergerOutput.meta.ideal_blockshape = frame_shape self.RelabeledImage.meta.ideal_blockshape = frame_shape def execute(self, slot, subindex, roi, result): # Output showing lineage IDs if slot is 
self.Output: if not self.Parameters.ready(): raise Exception("Parameter slot is not ready") parameters = self.Parameters.value resolvedMergers = self.ResolvedMergers.value # Assume [t,x,y,z,c] order trange = range(roi.start[0], roi.stop[0]) offset = roi.start[1:-1] result[:] = self.LabelImage.get(roi).wait() for t in trange: if 'time_range' in parameters and t <= parameters['time_range'][-1] and t >= parameters['time_range'][0]: if resolvedMergers: self._labelMergers(result[t-roi.start[0],...,0], t, offset) result[t-roi.start[0],...,0] = self._labelLineageIds(result[t-roi.start[0],...,0], t) else: result[t-roi.start[0],...][:] = 0 # Output showing mergers only elif slot is self.MergerOutput: parameters = self.Parameters.value resolvedMergers = self.ResolvedMergers.value # Assume [t,x,y,z,c] order trange = range(roi.start[0], roi.stop[0]) offset = roi.start[1:-1] result[:] = self.LabelImage.get(roi).wait() for t in trange: if 'time_range' in parameters and t <= parameters['time_range'][-1] and t >= parameters['time_range'][0]: if resolvedMergers: self._labelMergers(result[t-roi.start[0],...,0], t, offset) result[t-roi.start[0],...,0] = self._labelLineageIds(result[t-roi.start[0],...,0], t, onlyMergers=True) else: result[t-roi.start[0],...][:] = 0 # Output showing object Ids (before lineage IDs are assigned) elif slot is self.RelabeledImage: parameters = self.Parameters.value resolvedMergers = self.ResolvedMergers.value # Assume [t,x,y,z,c] order trange = range(roi.start[0], roi.stop[0]) offset = roi.start[1:-1] result[:] = self.LabelImage.get(roi).wait() for t in trange: if resolvedMergers and 'time_range' in parameters and t <= parameters['time_range'][-1] and t >= parameters['time_range'][0]: self._labelMergers(result[t-roi.start[0],...,0], t, offset) # Cache blocks elif slot == self.AllBlocks: # if nothing was computed, return empty list if not self.HypothesesGraph.value: result[0] = [] return result all_block_rois = [] shape = self.Output.meta.shape # assumes 
t,x,y,z,c slicing = [slice(None), ] * 5 for t in range(shape[0]): slicing[0] = slice(t, t + 1) all_block_rois.append(sliceToRoi(slicing, shape)) result[0] = all_block_rois return result def setInSlot(self, slot, subindex, roi, value): assert slot == self.InputHdf5 or slot == self.MergerInputHdf5 or slot == self.RelabeledInputHdf5, "Invalid slot for setInSlot(): {}".format( slot.name ) def _createHypothesesGraph(self): ''' Construct a hypotheses graph given the current settings in the parameters slot ''' parameters = self.Parameters.value time_range = range(parameters['time_range'][0],parameters['time_range'][1] + 1) x_range = parameters['x_range'] y_range = parameters['y_range'] z_range = parameters['z_range'] size_range = parameters['size_range'] scales = parameters['scales'] withDivisions = parameters['withDivisions'] withClassifierPrior = parameters['withClassifierPrior'] maxDist = parameters['maxDist'] maxObj = parameters['maxObj'] divThreshold = parameters['divThreshold'] max_nearest_neighbors = parameters['max_nearest_neighbors'] borderAwareWidth = parameters['borderAwareWidth'] traxelstore = self._generate_traxelstore(time_range, x_range, y_range, z_range, size_range, scales[0], scales[1], scales[2], with_div=withDivisions, with_classifier_prior=withClassifierPrior) def constructFov(shape, t0, t1, scale=[1, 1, 1]): [xshape, yshape, zshape] = shape [xscale, yscale, zscale] = scale fov = FieldOfView(t0, 0, 0, 0, t1, xscale * (xshape - 1), yscale * (yshape - 1), zscale * (zshape - 1)) return fov fieldOfView = constructFov((x_range[1], y_range[1], z_range[1]), time_range[0], time_range[-1]+1, scales) if WITH_HYTRA: hypothesesGraph = IlastikHypothesesGraph( probabilityGenerator=traxelstore, timeRange=(time_range[0],time_range[-1]+1), maxNumObjects=maxObj, numNearestNeighbors=max_nearest_neighbors, fieldOfView=fieldOfView, withDivisions=withDivisions, maxNeighborDistance=maxDist, divisionThreshold=divThreshold, borderAwareWidth=borderAwareWidth, 
progressVisitor=self.progressVisitor ) else: hypothesesGraph = IlastikHypothesesGraph( probabilityGenerator=traxelstore, timeRange=(time_range[0],time_range[-1]+1), maxNumObjects=maxObj, numNearestNeighbors=max_nearest_neighbors, fieldOfView=fieldOfView, withDivisions=withDivisions, maxNeighborDistance=maxDist, divisionThreshold=divThreshold, borderAwareWidth=borderAwareWidth ) return hypothesesGraph def _resolveMergers(self, hypothesesGraph, model): ''' run merger resolution on the hypotheses graph which contains the current solution ''' logger.info("Resolving mergers.") parameters = self.Parameters.value withTracklets = parameters['withTracklets'] originalGraph = hypothesesGraph.referenceTraxelGraph if withTracklets else hypothesesGraph resolvedMergersDict = {} # Enable full graph computation for animal tracking workflow withFullGraph = False if 'withAnimalTracking' in parameters and parameters['withAnimalTracking']: # TODO: Setting this parameter outside of the track() function (on AnimalConservationTrackingWorkflow) is not desirable withFullGraph = True logger.info("Computing full graph on merger resolver (Only enabled on animal tracking workflow)") mergerResolver = IlastikMergerResolver(originalGraph, pluginPaths=self.pluginPaths, withFullGraph=withFullGraph) # Check if graph contains mergers, otherwise skip merger resolving if not mergerResolver.mergerNum: logger.info("Graph contains no mergers. 
Skipping merger resolving.") else: # Fit and refine merger nodes using a GMM # It has to be done per time-step in order to aviod loading the whole video on RAM traxelIdPerTimestepToUniqueIdMap, uuidToTraxelMap = getMappingsBetweenUUIDsAndTraxels(model) timesteps = [int(t) for t in traxelIdPerTimestepToUniqueIdMap.keys()] timesteps.sort() timeIndex = self.LabelImage.meta.axistags.index('t') numTimeStep = len(timesteps) count=0 for timestep in timesteps: count +=1 self.progressVisitor.showProgress(count/float(numTimeStep)) roi = [slice(None) for i in range(len(self.LabelImage.meta.shape))] roi[timeIndex] = slice(timestep, timestep+1) roi = tuple(roi) labelImage = self.LabelImage[roi].wait() # Get coordinates for object IDs in label image. Used by GMM merger fit. objectIds = vigra.analysis.unique(labelImage[0,...,0]) maxObjectId = max(objectIds) coordinatesForIds = {} pool = RequestPool() for objectId in objectIds: pool.add(Request(partial(mergerResolver.getCoordinatesForObjectId, coordinatesForIds, labelImage[0, ..., 0], timestep, objectId))) # Run requests to get object ID coordinates pool.wait() # Fit mergers and store fit info in nodes if coordinatesForIds: mergerResolver.fitAndRefineNodesForTimestep(coordinatesForIds, maxObjectId, timestep) self.parent.parent.trackingApplet.progressSignal.emit(100) # Compute object features, re-run flow solver, update model and result, and get merger dictionary resolvedMergersDict = mergerResolver.run() return resolvedMergersDict def raiseException(self, progressWindow, str): if progressWindow is not None: progressWindow.onTrackDone() raise Exception (str) def raiseDatasetConstraintError(self, progressWindow, titleStr, str): if progressWindow is not None: progressWindow.onTrackDone() raise DatasetConstraintError(titleStr, str) def track(self, time_range, x_range, y_range, z_range, size_range=(0, 100000), x_scale=1.0, y_scale=1.0, z_scale=1.0, maxDist=30, maxObj=2, divThreshold=0.5, avgSize=[0], withTracklets=False, 
sizeDependent=True, detWeight=10.0, divWeight=10.0, transWeight=10.0, withDivisions=True, withOpticalCorrection=True, withClassifierPrior=False, ndim=3, cplex_timeout=None, withMergerResolution=True, borderAwareWidth = 0.0, withArmaCoordinates = True, appearance_cost = 500, disappearance_cost = 500, motionModelWeight=10.0, force_build_hypotheses_graph = False, max_nearest_neighbors = 1, numFramesPerSplit=0, withBatchProcessing = False, solverName="Flow-based", progressWindow=None, progressVisitor=CommandLineProgressVisitor() ): """ Main conservation tracking function. Runs tracking solver, generates hypotheses graph, and resolves mergers. """ if WITH_HYTRA: self.progressWindow = progressWindow self.progressVisitor=progressVisitor else: self.progressWindow = None self.progressVisitor = DefaultProgressVisitor() if not self.Parameters.ready(): self.raiseException(self.progressWindow, "Parameter slot is not ready") # it is assumed that the self.Parameters object is changed only at this # place (ugly assumption). 
Therefore we can track any changes in the # parameters as done in the following lines: If the same value for the # key is already written in the parameters dictionary, the # paramters_changed dictionary will get a "False" entry for this key, # otherwise it is set to "True" parameters = self.Parameters.value parameters['maxDist'] = maxDist parameters['maxObj'] = maxObj parameters['divThreshold'] = divThreshold parameters['avgSize'] = avgSize parameters['withTracklets'] = withTracklets parameters['sizeDependent'] = sizeDependent parameters['detWeight'] = detWeight parameters['divWeight'] = divWeight parameters['transWeight'] = transWeight parameters['withDivisions'] = withDivisions parameters['withOpticalCorrection'] = withOpticalCorrection parameters['withClassifierPrior'] = withClassifierPrior parameters['withMergerResolution'] = withMergerResolution parameters['borderAwareWidth'] = borderAwareWidth parameters['withArmaCoordinates'] = withArmaCoordinates parameters['appearanceCost'] = appearance_cost parameters['disappearanceCost'] = disappearance_cost parameters['scales'] = [x_scale, y_scale, z_scale] parameters['time_range'] = [min(time_range), max(time_range)] parameters['x_range'] = x_range parameters['y_range'] = y_range parameters['z_range'] = z_range parameters['max_nearest_neighbors'] = max_nearest_neighbors parameters['numFramesPerSplit'] = numFramesPerSplit parameters['solver'] = str(solverName) # Set a size range with a minimum area equal to the max number of objects (since the GMM throws an error if we try to fit more gaussians than the number of pixels in the object) size_range = (max(maxObj, size_range[0]), size_range[1]) parameters['size_range'] = size_range if cplex_timeout: parameters['cplex_timeout'] = cplex_timeout else: parameters['cplex_timeout'] = '' cplex_timeout = float(1e75) self.Parameters.setValue(parameters, check_changed=False) if withClassifierPrior: if not self.DetectionProbabilities.ready() or 
len(self.DetectionProbabilities([0]).wait()[0]) == 0: self.raiseDatasetConstraintError(self.progressWindow, 'Tracking', 'Classifier not ready yet. Did you forget to train the Object Count Classifier?') if not self.NumLabels.ready() or self.NumLabels.value < (maxObj + 1): self.raiseDatasetConstraintError(self.progressWindow, 'Tracking', 'The max. number of objects must be consistent with the number of labels given in Object Count Classification.\n' +\ 'Check whether you have (i) the correct number of label names specified in Object Count Classification, and (ii) provided at least ' +\ 'one training example for each class.') if len(self.DetectionProbabilities([0]).wait()[0][0]) < (maxObj + 1): self.raiseDatasetConstraintError(self.progressWindow, 'Tracking', 'The max. number of objects must be consistent with the number of labels given in Object Count Classification.\n' +\ 'Check whether you have (i) the correct number of label names specified in Object Count Classification, and (ii) provided at least ' +\ 'one training example for each class.') hypothesesGraph = self._createHypothesesGraph() hypothesesGraph.allowLengthOneTracks = True if withTracklets: hypothesesGraph = hypothesesGraph.generateTrackletGraph() hypothesesGraph.insertEnergies() trackingGraph = hypothesesGraph.toTrackingGraph() trackingGraph.convexifyCosts() model = trackingGraph.model model['settings']['allowLengthOneTracks'] = True detWeight = 10.0 # FIXME: Should we store this weight in the parameters slot? 
weights = trackingGraph.weightsListToDict([transWeight, detWeight, divWeight, appearance_cost, disappearance_cost]) stepStr = "Tracking solver" self.progressVisitor.showState(stepStr) self.progressVisitor.showProgress(0) if solverName == 'Flow-based' and dpct: if numFramesPerSplit: # Run solver with frame splits (split, solve, and stitch video to improve running-time) from hytra.core.splittracking import SplitTracking result = SplitTracking.trackFlowBasedWithSplits(model, weights, numFramesPerSplit=numFramesPerSplit) else: result = dpct.trackFlowBased(model, weights) elif solverName == 'ILP' and mht: result = mht.track(model, weights) else: raise ValueError("Invalid tracking solver selected") self.progressVisitor.showProgress(1.0) # Insert the solution into the hypotheses graph and from that deduce the lineages if hypothesesGraph: hypothesesGraph.insertSolution(result) # Merger resolution resolvedMergersDict = {} if withMergerResolution: stepStr = "Merger resolution" self.progressVisitor.showState(stepStr) resolvedMergersDict = self._resolveMergers(hypothesesGraph, model) # Set value of resolved mergers slot (Should be empty if mergers are disabled) self.ResolvedMergers.setValue(resolvedMergersDict, check_changed=False) # Computing tracking lineage IDs from within Hytra hypothesesGraph.computeLineage() if self.progressWindow is not None: self.progressWindow.onTrackDone() # Uncomment to export a hypothese graph diagram #logger.info("Exporting hypotheses graph diagram") #from hytra.util.hypothesesgraphdiagram import HypothesesGraphDiagram #hgv = HypothesesGraphDiagram(hypothesesGraph._graph, timeRange=(0, 10), fileName='HypothesesGraph.png' ) # Set value of hypotheses grap slot (use referenceTraxelGraph if using tracklets) hypothesesGraph = hypothesesGraph.referenceTraxelGraph if withTracklets else hypothesesGraph self.HypothesesGraph.setValue(hypothesesGraph, check_changed=False) # Set all the output slots dirty (See execute() function) self.Output.setDirty() 
self.MergerOutput.setDirty() self.RelabeledImage.setDirty() return result def propagateDirty(self, inputSlot, subindex, roi): if inputSlot is self.LabelImage: self.Output.setDirty(roi) elif inputSlot is self.HypothesesGraph: pass elif inputSlot is self.ResolvedMergers: pass elif inputSlot == self.NumLabels: if self.parent.parent.trackingApplet._gui \ and self.parent.parent.trackingApplet._gui.currentGui() \ and self.NumLabels.ready() \ and self.NumLabels.value > 1: self.parent.parent.trackingApplet._gui.currentGui()._drawer.maxObjectsBox.setValue(self.NumLabels.value-1) def _labelMergers(self, volume, time, offset): """ Label volume mergers with correspoding IDs, using the plugin GMM fit """ resolvedMergersDict = self.ResolvedMergers.value if time not in resolvedMergersDict: return volume idxs = vigra.analysis.unique(volume) for idx in idxs: if idx in resolvedMergersDict[time]: fits = resolvedMergersDict[time][idx]['fits'] newIds = resolvedMergersDict[time][idx]['newIds'] self.mergerResolverPlugin.updateLabelImage(volume, idx, fits, newIds, offset=offset) return volume def _labelLineageIds(self, volume, time, onlyMergers=False): """ Label the every object in the volume for the given time frame by the lineage ID it belongs to. If onlyMergers is True, then only those segments that were resolved from a merger are shown, everything else set to zero. 
:return: the relabeled volume, where 0 means background, 1 means false detection, and all higher numbers indicate lineages """ hypothesesGraph = self.HypothesesGraph.value if not hypothesesGraph: return np.zeros_like(volume) resolvedMergersDict = self.ResolvedMergers.value indexMapping = np.zeros(np.amax(volume) + 1, dtype=volume.dtype) idxs = vigra.analysis.unique(volume) # Reduce labels to the ones that contain mergers if onlyMergers: if resolvedMergersDict: if time not in resolvedMergersDict: idxs = [] else: newIds = [newId for _, nodeDict in resolvedMergersDict[time].items() for newId in nodeDict['newIds']] idxs = [id for id in idxs if id in newIds] else: idxs = [idx for idx in idxs if idx > 0 and hypothesesGraph.hasNode((time,idx)) and hypothesesGraph._graph.node[(time,idx)]['value'] > 1] # Map labels to corresponding lineage IDs for idx in idxs: if idx > 0 and hypothesesGraph.hasNode((time,idx)): lineage_id = hypothesesGraph.getLineageId(time, idx) if lineage_id is None: lineage_id = 1 indexMapping[idx] = lineage_id return indexMapping[volume] def _setupRelabeledFeatureSlot(self, original_feature_slot): from ilastik.applets.trackingFeatureExtraction import config # when exporting after merger resolving, the stored object features are not up to date for the relabeled objects opRelabeledRegionFeatures = OpRelabeledMergerFeatureExtraction(parent=self) opRelabeledRegionFeatures.RawImage.connect(self.RawImage) opRelabeledRegionFeatures.LabelImage.connect(self.LabelImage) opRelabeledRegionFeatures.RelabeledImage.connect(self.RelabeledImage) opRelabeledRegionFeatures.OriginalRegionFeatures.connect(original_feature_slot) vigra_features = list((set(config.vigra_features)).union(config.selected_features_objectcount[config.features_vigra_name])) feature_names_vigra = {} feature_names_vigra[config.features_vigra_name] = { name: {} for name in vigra_features } opRelabeledRegionFeatures.FeatureNames.setValue(feature_names_vigra) return opRelabeledRegionFeatures def 
exportPlugin(self, filename, plugin, checkOverwriteFiles=False): with_divisions = self.Parameters.value["withDivisions"] if self.Parameters.ready() else False with_merger_resolution = self.Parameters.value["withMergerResolution"] if self.Parameters.ready() else False # Create opRegionFeatures to extract features of relabeled volume if with_merger_resolution: parameters = self.Parameters.value # Use simple relabeled merger feature slot configuration instead of opRelabeledMergerFeatureExtraction # This is faster for videos with few mergers and few number of objects per frame if False:#'withAnimalTracking' in parameters and parameters['withAnimalTracking']: logger.info('Setting relabeled merger feature slots for animal tracking') from ilastik.applets.trackingFeatureExtraction import config self._opRegionFeatures = OpRegionFeatures(parent=self) self._opRegionFeatures.RawVolume.connect(self.RawImage) self._opRegionFeatures.LabelVolume.connect(self.RelabeledImage) vigra_features = list((set(config.vigra_features)).union(config.selected_features_objectcount[config.features_vigra_name])) feature_names_vigra = {} feature_names_vigra[config.features_vigra_name] = { name: {} for name in vigra_features } self._opRegionFeatures.Features.setValue(feature_names_vigra) self._opAdaptTimeListRoi = OpAdaptTimeListRoi(parent=self) self._opAdaptTimeListRoi.Input.connect(self._opRegionFeatures.Output) object_feature_slot = self._opAdaptTimeListRoi.Output # Use opRelabeledMergerFeatureExtraction for cell tracking else: opRelabeledRegionFeatures = self._setupRelabeledFeatureSlot(self.ObjectFeatures) object_feature_slot = opRelabeledRegionFeatures.RegionFeatures label_image = self.RelabeledImage # Use ObjectFeaturesWithDivFeatures slot elif with_divisions: object_feature_slot = self.ObjectFeaturesWithDivFeatures label_image = self.LabelImage # Use ObjectFeatures slot only else: object_feature_slot = self.ObjectFeatures label_image = self.LabelImage hypothesesGraph = 
self.HypothesesGraph.value if checkOverwriteFiles and plugin.checkFilesExist(filename): # do not export if we would otherwise overwrite files return False if not plugin.export(filename, hypothesesGraph, object_feature_slot, label_image, self.RawImage): raise RuntimeError('Exporting tracking solution with plugin failed') else: return True def get_table_export_settings(self): # TODO: remove once tracking is hytra-only return None, None def _checkConstraints(self, *args): if self.RawImage.ready(): rawTaggedShape = self.RawImage.meta.getTaggedShape() if rawTaggedShape['t'] < 2: raise DatasetConstraintError( "Tracking", "For tracking, the dataset must have a time axis with at least 2 images. " \ "Please load time-series data instead. See user documentation for details.") if self.LabelImage.ready(): segmentationTaggedShape = self.LabelImage.meta.getTaggedShape() if segmentationTaggedShape['t'] < 2: raise DatasetConstraintError( "Tracking", "For tracking, the dataset must have a time axis with at least 2 images. " \ "Please load time-series data instead. See user documentation for details.") if self.RawImage.ready() and self.LabelImage.ready(): rawTaggedShape['c'] = None segmentationTaggedShape['c'] = None if dict(rawTaggedShape) != dict(segmentationTaggedShape): raise DatasetConstraintError("Tracking", "For tracking, the raw data and the prediction maps must contain the same " \ "number of timesteps and the same shape. 
" \ "Your raw image has a shape of (t, x, y, z, c) = {}, whereas your prediction image has a " \ "shape of (t, x, y, z, c) = {}" \ .format(self.RawImage.meta.shape, self.BinaryImage.meta.shape)) def _generate_traxelstore(self, time_range, x_range, y_range, z_range, size_range, x_scale=1.0, y_scale=1.0, z_scale=1.0, with_div=False, with_local_centers=False, with_classifier_prior=False): logger.info("generating traxels") traxelstore = ProbabilityGenerator() logger.info("fetching region features and division probabilities") feats = self.ObjectFeatures(time_range).wait() if with_div: if not self.DivisionProbabilities.ready() or len(self.DivisionProbabilities([0]).wait()[0]) == 0: msgStr = "\nDivision classifier has not been trained! " + \ "Uncheck divisible objects if your objects don't divide or " + \ "go back to the Division Detection applet and train it." raise DatasetConstraintError ("Tracking",msgStr) divProbs = self.DivisionProbabilities(time_range).wait() if with_local_centers: localCenters = self.RegionLocalCenters(time_range).wait() if with_classifier_prior: if not self.DetectionProbabilities.ready() or len(self.DetectionProbabilities([0]).wait()[0]) == 0: msgStr = "\nObject count classifier has not been trained! " + \ "Go back to the Object Count Classification applet and train it." raise DatasetConstraintError ("Tracking",msgStr) detProbs = self.DetectionProbabilities(time_range).wait() logger.info("filling traxelstore") filtered_labels = {} total_count = 0 empty_frame = False numTimeStep = len(feats.keys()) countT = 0 stepStr = "Creating traxel store" self.progressVisitor.showState(stepStr+" ") for t in feats.keys(): countT +=1 self.progressVisitor.showProgress(countT/float(numTimeStep)) rc = feats[t][default_features_key]['RegionCenter'] lower = feats[t][default_features_key]['Coord<Minimum>'] upper = feats[t][default_features_key]['Coord<Maximum>'] if rc.size: rc = rc[1:, ...] lower = lower[1:, ...] upper = upper[1:, ...] 
ct = feats[t][default_features_key]['Count'] if ct.size: ct = ct[1:, ...] logger.debug("at timestep {}, {} traxels found".format(t, rc.shape[0])) count = 0 filtered_labels_at = [] for idx in range(rc.shape[0]): traxel = Traxel() # for 2d data, set z-coordinate to 0: if len(rc[idx]) == 2: x, y = rc[idx] z = 0 x_lower, y_lower = lower[idx] x_upper, y_upper = upper[idx] z_lower = 0 z_upper = 0 elif len(rc[idx]) == 3: x, y, z = rc[idx] x_lower, y_lower, z_lower = lower[idx] x_upper, y_upper, z_upper = upper[idx] else: raise DatasetConstraintError ("Tracking", "The RegionCenter feature must have dimensionality 2 or 3.") size = ct[idx] if (x_upper < x_range[0] or x_lower >= x_range[1] or y_upper < y_range[0] or y_lower >= y_range[1] or z_upper < z_range[0] or z_lower >= z_range[1] or size < size_range[0] or size >= size_range[1]): filtered_labels_at.append(int(idx + 1)) continue else: count += 1 traxel.Id = int(idx + 1) traxel.Timestep = int(t) traxel.set_x_scale(x_scale) traxel.set_y_scale(y_scale) traxel.set_z_scale(z_scale) # Expects always 3 coordinates, z=0 for 2d data traxel.add_feature_array("com", 3) for i, v in enumerate([x, y, z]): traxel.set_feature_value('com', i, float(v)) traxel.add_feature_array("CoordMinimum", 3) for i, v in enumerate(lower[idx]): traxel.set_feature_value("CoordMinimum", i, float(v)) traxel.add_feature_array("CoordMaximum", 3) for i, v in enumerate(upper[idx]): traxel.set_feature_value("CoordMaximum", i, float(v)) if with_div: traxel.add_feature_array("divProb", 2) # idx+1 because rc and ct start from 1, divProbs starts from 0 prob = float(divProbs[t][idx + 1][1]) prob = float(prob) if prob < 0.0000001: prob = 0.0000001 if prob > 0.99999999: prob = 0.99999999 traxel.set_feature_value("divProb", 0, 1.0 - prob) traxel.set_feature_value("divProb", 1, prob) if with_classifier_prior: traxel.add_feature_array("detProb", len(detProbs[t][idx + 1])) for i, v in enumerate(detProbs[t][idx + 1]): val = float(v) if val < 0.0000001: val = 0.0000001 if 
val > 0.99999999: val = 0.99999999 traxel.set_feature_value("detProb", i, float(val)) # FIXME: check whether it is 2d or 3d data! if with_local_centers: traxel.add_feature_array("localCentersX", len(localCenters[t][idx + 1])) traxel.add_feature_array("localCentersY", len(localCenters[t][idx + 1])) traxel.add_feature_array("localCentersZ", len(localCenters[t][idx + 1])) for i, v in enumerate(localCenters[t][idx + 1]): traxel.set_feature_value("localCentersX", i, float(v[0])) traxel.set_feature_value("localCentersY", i, float(v[1])) traxel.set_feature_value("localCentersZ", i, float(v[2])) traxel.add_feature_array("count", 1) traxel.set_feature_value("count", 0, float(size)) if (x_upper < x_range[0] or x_lower >= x_range[1] or y_upper < y_range[0] or y_lower >= y_range[1] or z_upper < z_range[0] or z_lower >= z_range[1] or size < size_range[0] or size >= size_range[1]): logger.info("Omitting traxel with ID: {} {}".format(traxel.Id,t)) print "Omitting traxel with ID: {} {}".format(traxel.Id,t) else: logger.info("Adding traxel with ID: {} {}".format(traxel.Id,t)) traxelstore.TraxelsPerFrame.setdefault(int(t), {})[int(idx + 1)] = traxel if len(filtered_labels_at) > 0: filtered_labels[str(int(t) - time_range[0])] = filtered_labels_at logger.debug("at timestep {}, {} traxels passed filter".format(t, count)) if count == 0: empty_frame = True logger.info('Found empty frames for time {}'.format(t)) total_count += count self.parent.parent.trackingApplet.progressSignal.emit(100) self.FilteredLabels.setValue(filtered_labels, check_changed=True) return traxelstore def isTrackingSolutionAvailable(self): """ check whether the hypotheses graph is filled and contains a tracking solution :return: True if there is a tracking solution available, False otherwise """ hypothesesGraph = self.HypothesesGraph.value from hytra.core.hypothesesgraph import HypothesesGraph if isinstance(hypothesesGraph, HypothesesGraph): hypothesesGraph = hypothesesGraph.referenceTraxelGraph if 
hypothesesGraph.withTracklets else hypothesesGraph if 'value' in hypothesesGraph._graph.nodes(data='True')[0][1]: return True return False
class OpInterpolate(Operator):
    """Fill regions flagged as missing by interpolating along the z axis.

    ``InputVolume`` is copied to ``Output``. For every time point and channel
    the ``Missing`` mask is labelled with
    ``vigra.analysis.labelVolumeWithBackground`` and each labelled region is
    overwritten in place by :meth:`_interpolate`.
    """

    InputVolume = InputSlot()
    Missing = InputSlot()
    InterpolationMethod = InputSlot(value="cubic")
    Output = OutputSlot()

    # Number of z-slices that must be available on each side of the gap.
    _requiredMargin = {"cubic": 2, "linear": 1, "constant": 0}
    # Largest z-extent of a gap a method is applied to; larger gaps use the
    # fallback method instead.
    _maxInterpolationDistance = {
        "cubic": 1,
        "linear": np.inf,
        "constant": np.inf,
    }
    # Method to try next when a method is not applicable (None: no fallback).
    _fallbacks = {"cubic": "linear", "linear": "constant", "constant": None}

    def propagateDirty(self, slot, subindex, roi):
        """Forward the dirty roi of any input unchanged to ``Output``."""
        # TODO
        self.Output.setDirty(roi)

    def setupOutputs(self):
        """Copy the input meta data to ``Output`` and cache the integer range.

        ``self._iinfo`` holds ``np.iinfo`` of the input dtype, or ``None``
        when ``np.iinfo`` rejects the dtype (non-integer data).
        """
        # Output has the same shape/axes/dtype/drange as input
        self.Output.meta.assignFrom(self.InputVolume.meta)

        try:
            self._iinfo = np.iinfo(self.InputVolume.meta.dtype)
        except ValueError:
            # not integer type, no casting needed
            self._iinfo = None

        input_shape = self.InputVolume.meta.getTaggedShape()
        missing_shape = self.Missing.meta.getTaggedShape()
        assert input_shape == missing_shape, (
            "InputVolume and Missing must have the same shape "
            + "({} vs {})".format(input_shape, missing_shape)
        )

    def execute(self, slot, subindex, roi, result):
        """Write the input data for ``roi`` into ``result`` and patch the gaps.

        :returns: ``result``, modified in place
        """
        # prefill result
        result[:] = self.InputVolume.get(roi).wait()

        resultZYXCT = vigra.taggedView(
            result, self.InputVolume.meta.axistags
        ).withAxes(*"zyxct")
        missingZYXCT = vigra.taggedView(
            self.Missing.get(roi).wait(), self.Missing.meta.axistags
        ).withAxes(*"zyxct")

        for t in range(resultZYXCT.shape[4]):
            for c in range(resultZYXCT.shape[3]):
                missingLabeled = vigra.analysis.labelVolumeWithBackground(
                    missingZYXCT[..., c, t]
                )
                maxLabel = missingLabeled.max()
                # label 0 is skipped; every other label is handled separately
                for i in range(1, maxLabel + 1):
                    self._interpolate(
                        resultZYXCT[..., c, t], missingLabeled == i
                    )

        return result

    def _cast(self, x):
        """
        casts the array to expected range (i.e. 0..255 for uint8 types, ...)

        Values are only clamped when the input dtype is an integer type
        (``self._iinfo`` is set); otherwise ``x`` is returned untouched.
        """
        if self._iinfo is not None:
            x = np.where(x > self._iinfo.max, self._iinfo.max, x)
            x = np.where(x < self._iinfo.min, self._iinfo.min, x)
        return x

    def _interpolate(self, volume, missing, method=None):
        """
        interpolates in z direction

        The bounding box (in y/x) of the missing region is overwritten in
        ``volume`` in place. If ``method`` cannot be applied (gap too long or
        not enough margin) the method registered in ``_fallbacks`` is tried.

        :param volume: 3d block with axistags 'zyx'
        :type volume: array-like
        :param missing: True where data is missing
        :type missing: bool, 3d block with axistags 'zyx'
        :param method: 'cubic' or 'linear' or 'constant'; defaults to the
            value of the ``InterpolationMethod`` slot
        :type method: str
        :raises AssertionError: for unknown methods, wrong axis layout, or
            when the margin is too small and no fallback is available
        """
        method = self.InterpolationMethod.value if method is None else method

        # sanity checks
        assert method in self._requiredMargin, "Unknown method '{}'".format(
            method
        )
        assert (
            volume.axistags.index("z") == 0
            and volume.axistags.index("y") == 1
            and volume.axistags.index("x") == 2
            and len(volume.shape) == 3
        ), "Data must be 3d with z as first axis."

        # number and z-location of missing slices (z-axis is at zero)
        black_z_ind, black_y_ind, black_x_ind = np.where(missing)

        if len(black_z_ind) == 0:
            # no need for interpolation
            return

        margin = self._requiredMargin[method]
        fallback = self._fallbacks[method]

        gap_extent = black_z_ind.max() - black_z_ind.min() + 1
        if gap_extent > self._maxInterpolationDistance[method]:
            self._interpolate(volume, missing, fallback)
            return

        # indices with respect to the required margin around the missing values
        minZ = black_z_ind.min() - margin
        maxZ = black_z_ind.max() + margin

        # number of slices that have to be filled
        n = maxZ - minZ - 2 * margin + 1

        if not (minZ > -1 and maxZ < volume.shape[0]):
            # this method is not applicable, try another one
            need = "(need at least {} pixels for '{}')".format(margin, method)
            logger.warning(
                " ".join(("Margin not big enough for interpolation ", need))
            )
            if fallback is not None:
                logger.warning(
                    "Falling back to method '{}'".format(fallback)
                )
                self._interpolate(volume, missing, fallback)
                return
            # Raise explicitly instead of `assert False`, so the error is not
            # stripped (and the code below executed) when running with -O.
            raise AssertionError(
                " ".join(
                    (
                        "Margin not big enough for interpolation",
                        need,
                        "and no fallback available",
                    )
                )
            )

        minY, maxY = (black_y_ind.min(), black_y_ind.max())
        minX, maxX = (black_x_ind.min(), black_x_ind.max())

        if method == "linear" or (method == "cubic" and n > 1):
            # do a convex combination of the boundary slices
            xs = np.linspace(0, 1, n + 2)
            left = volume[minZ, minY:maxY + 1, minX:maxX + 1]
            right = volume[maxZ, minY:maxY + 1, minX:maxX + 1]

            for i in range(n):
                # interpolate every slice
                volume[minZ + i + 1, minY:maxY + 1, minX:maxX + 1] = self._cast(
                    (1 - xs[i + 1]) * left + xs[i + 1] * right
                )

        elif method == "cubic":
            # interpolation coefficients, built from the two slices on
            # either side of the gap
            D = np.rollaxis(
                volume[
                    [minZ, minZ + 1, maxZ - 1, maxZ],
                    minY:maxY + 1,
                    minX:maxX + 1,
                ],
                0,
                3,
            )
            F = np.tensordot(D, _cubic_mat(n), ([2], [1]))

            xs = np.linspace(0, 1, n + 2)
            for i in range(n):
                # interpolate every slice
                x = xs[i + 1]
                volume[minZ + i + 2, minY:maxY + 1, minX:maxX + 1] = self._cast(
                    F[..., 0]
                    + F[..., 1] * x
                    + F[..., 2] * x**2
                    + F[..., 3] * x**3
                )

        else:  # constant
            if minZ > 0:
                # fill right hand side with last good slice
                for i in range(maxZ - minZ + 1):
                    volume[minZ + i, minY:maxY + 1, minX:maxX + 1] = volume[
                        minZ - 1, minY:maxY + 1, minX:maxX + 1
                    ]
            elif maxZ < volume.shape[0] - 1:
                # fill left hand side with last good slice
                for i in range(maxZ - minZ + 1):
                    volume[minZ + i, minY:maxY + 1, minX:maxX + 1] = volume[
                        maxZ + 1, minY:maxY + 1, minX:maxX + 1
                    ]
            else:
                # nothing to do for empty block
                logger.warning(
                    "Not enough data for interpolation leaving slice as is ..."
                )
class OpStackLoader(Operator):
    """Imports an image stack.

    Note: This operator does NOT cache the images, so direct access
          via the execute() function is very inefficient, especially
          through the Z-axis. Typically, you'll want to connect this
          operator to a cache whose block size is large in the X-Y
          plane.

    :param globstring: A glob string as defined by the glob module. We
        also support the following special extension to globstring
        syntax: A single string can hold a *list* of globstrings.
        The delimiter that separates the globstrings in the list is
        OS-specific via os.path.pathsep.

        For example, on Linux the pathsep is ':', so

            '/a/b/c.txt:/d/e/f.txt:../g/i/h.txt'

        is parsed as

            ['/a/b/c.txt', '/d/e/f.txt', '../g/i/h.txt']

    :param SequenceAxis: optional; one of 't', 'z' or 'c'. The axis along
        which the files are stacked. Defaults to 'z' for single-image files
        and to 't' for multi-image files.
    """

    name = "Image Stack Reader"
    category = "Input"

    globstring = InputSlot()
    SequenceAxis = InputSlot(optional=True)
    stack = OutputSlot()

    # Axis keys accepted on the SequenceAxis slot.
    _VALID_SEQUENCE_AXES = ("t", "z", "c")

    class FileOpenError(Exception):
        """Raised when the first file of the stack cannot be opened."""

        def __init__(self, filename):
            self.filename = filename
            self.msg = f"Unable to open file: {filename}"
            super().__init__(self.msg)

    def _get_sequence_axis(self, default):
        """Return the validated ``SequenceAxis`` value, or ``default`` if unset.

        :raises AssertionError: if the slot value is not 't', 'z' or 'c'
        """
        if not self.SequenceAxis.ready():
            return default
        sequence_axis = str(self.SequenceAxis.value)
        # Membership in a tuple, not a substring test: `x in "tzc"` would
        # also accept "", "tz" and "zc".
        assert sequence_axis in self._VALID_SEQUENCE_AXES, (
            f"SequenceAxis must be one of {self._VALID_SEQUENCE_AXES}, "
            f"got {sequence_axis!r}"
        )
        return sequence_axis

    def setupOutputs(self):
        """Expand the globstring and derive shape/axistags/dtype of ``stack``.

        The meta data is taken from the first file only. ``stack`` is marked
        NOTREADY when the globstring matches no file.

        :raises OpStackLoader.FileOpenError: if vigra raises ``RuntimeError``
            while inspecting the first file
        """
        self.fileNameList = self.expandGlobStrings(self.globstring.value)

        num_files = len(self.fileNameList)
        if num_files == 0:
            self.stack.meta.NOTREADY = True
            return

        first_file = self.fileNameList[0]
        try:
            self.info = vigra.impex.ImageInfo(first_file)
            self.slices_per_file = vigra.impex.numberImages(first_file)
        except RuntimeError as e:
            logger.error(str(e))
            raise OpStackLoader.FileOpenError(first_file) from e

        slice_shape = self.info.getShape()
        X, Y, C = slice_shape
        if self.slices_per_file == 1:
            sequence_axis = self._get_sequence_axis(default="z")

            # For stacks of 2D images, we assume xy slices
            if sequence_axis == "c":
                shape = (X, Y, C * num_files)
                axistags = vigra.defaultAxistags("xyc")
            else:
                shape = (num_files, Y, X, C)
                axistags = vigra.defaultAxistags(sequence_axis + "yxc")
        else:
            sequence_axis = self._get_sequence_axis(default="t")

            if sequence_axis == "z":
                axistags = vigra.defaultAxistags("ztyxc")
            elif sequence_axis == "t":
                axistags = vigra.defaultAxistags("tzyxc")
            else:
                axistags = vigra.defaultAxistags("czyx")

            # For stacks of 3D volumes, we assume xyz blocks stacked along
            # sequence_axis
            if sequence_axis == "c":
                shape = (num_files * C, self.slices_per_file, Y, X)
            else:
                shape = (num_files, self.slices_per_file, Y, X, C)

        self.stack.meta.shape = shape
        self.stack.meta.axistags = axistags
        self.stack.meta.dtype = self.info.getDtype()

    def propagateDirty(self, slot, subindex, roi):
        """Mark the whole output dirty when any input slot changes."""
        # SequenceAxis is an input slot too, so it must be accepted here.
        assert slot == self.globstring or slot == self.SequenceAxis
        # Any change to the globstring (or the stacking axis) means our
        # entire output is dirty.
        self.stack.setDirty()

    def execute(self, slot, subindex, roi, result):
        """Dispatch to the reader matching the dimensionality of ``stack``."""
        ndim = len(self.stack.meta.shape)
        if ndim == 3:
            return self._execute_3d(roi, result)
        if ndim == 4:
            return self._execute_4d(roi, result)
        if ndim == 5:
            return self._execute_5d(roi, result)
        # Explicit raise (instead of `assert False`) so it survives -O.
        raise AssertionError(
            f"Unexpected output shape: {self.stack.meta.shape}"
        )

    def _check_file_matches_first(self, fileName):
        """Verify that ``fileName`` has the same layout as the first file.

        :raises RuntimeError: if the image shape or the number of images in
            ``fileName`` differs from what ``setupOutputs`` recorded
        """
        file_shape = vigra.impex.ImageInfo(fileName).getShape()
        if self.info.getShape() != file_shape:
            raise RuntimeError("not all files have the same shape")
        # Inspect the file that is about to be read; comparing against the
        # first file again could never detect a mismatch.
        images_per_file = vigra.impex.numberImages(fileName)
        if self.slices_per_file != images_per_file:
            raise RuntimeError("Not all files have the same number of slices")

    def _execute_3d(self, roi, result):
        """Read an 'xyc' roi from 2D images stacked along the channel axis."""
        traceLogger.debug("OpStackLoader: Execute for: " + str(roi))
        # roi is in xyc order; stacking over c
        x_start, y_start, c_start = roi.start
        x_stop, y_stop, c_stop = roi.stop

        # get C of slice
        C = self.info.getShape()[2]

        # Copy each c-slice one at a time.
        files = self.fileNameList[c_start // C:c_stop // C]
        for i, fileName in enumerate(files):
            traceLogger.debug(f"Reading image: {fileName}")
            self._check_file_matches_first(fileName)

            result[:, :, i * C:(i + 1) * C] = vigra.impex.readImage(fileName)[
                x_start:x_stop, y_start:y_stop, :
            ].withAxes(*"xyc")
        return result

    def _execute_4d(self, roi, result):
        """Read a 4D roi, one file per index along the first roi axis."""
        traceLogger.debug("OpStackLoader: Execute for: " + str(roi))
        # roi is in zyxc, tyxc or czyx order, depending on SequenceAxis
        z_start, y_start, x_start, c_start = roi.start
        z_stop, y_stop, x_stop, c_stop = roi.stop

        # get C of slice
        C = self.info.getShape()[2]

        # Copy each z-slice one at a time.
        for result_z, fileName in enumerate(self.fileNameList[z_start:z_stop]):
            traceLogger.debug(f"Reading image: {fileName}")
            self._check_file_matches_first(fileName)

            if self.stack.meta.axistags.channelIndex == 0:
                # czyx order -> read slice along z (here y)
                for result_y, y in enumerate(range(y_start, y_stop)):
                    result[
                        result_z * C:(result_z + 1) * C, result_y, ...
                    ] = vigra.impex.readImage(fileName, index=y)[
                        c_start:c_stop, x_start:x_stop
                    ].withAxes(*"cyx")
            else:
                result[result_z, ...] = vigra.impex.readImage(fileName)[
                    x_start:x_stop, y_start:y_stop, c_start:c_stop
                ].withAxes(*"yxc")
        return result

    def _execute_5d(self, roi, result):
        """Read a 5D roi: one file per outer index, one image per inner index."""
        # Technically, t and z might be switched depending on SequenceAxis.
        # Beware these variable names for t/z might be misleading.
        t_start, z_start, y_start, x_start, c_start = roi.start
        t_stop, z_stop, y_stop, x_stop, c_stop = roi.stop

        # Use *enumerated* range to get global t coords and result t coords
        for result_t, t in enumerate(range(t_start, t_stop)):
            file_name = self.fileNameList[t]
            for result_z, z in enumerate(range(z_start, z_stop)):
                img = vigra.readImage(file_name, index=z)
                result[result_t, result_z, :, :, :] = img[
                    x_start:x_stop, y_start:y_stop, c_start:c_stop
                ].withAxes(*"yxc")
        return result

    @staticmethod
    def expandGlobStrings(globStrings):
        """Expand an ``os.path.pathsep``-separated list of globstrings.

        Each globstring is stripped and expanded on its own; its matches are
        sorted and appended in the order the globstrings were given.

        :returns: list of matching file names
        """
        ret = []
        # Parse list into separate globstrings and combine them
        for globString in globStrings.split(os.path.pathsep):
            s = globString.strip()
            ret += sorted(glob.glob(s))
        return ret
class OpDetectMissing(Operator):
    """
    Sub-Operator for detection of missing image content

    Every z-slice of the input is cut into patches, a histogram is computed
    per patch and ``self.predict`` decides for each patch whether it is
    missing. ``Output`` is a uint8 mask with 1 where a patch was classified
    as missing.
    """

    InputVolume = InputSlot()
    PatchSize = InputSlot(value=128)
    HaloSize = InputSlot(value=30)
    DetectionMethod = InputSlot(value="classic")
    NHistogramBins = InputSlot(value=_defaultBinSize)
    OverloadDetector = InputSlot(value="")

    # histograms: ndarray, shape: nHistograms x (NHistogramBins.value + 1)
    # the last column holds the label, i.e. {0: negative, 1: positive}
    TrainingHistograms = InputSlot()

    Output = OutputSlot()
    Detector = OutputSlot(stype=Opaque)

    ### PRIVATE class attributes ###
    _manager = None

    ### PRIVATE attributes ###
    _inputRange = (0, 255)
    _needsTraining = True
    _felzenOpts = {
        "firstSamples": 250,
        "maxRemovePerStep": 0,
        "maxAddPerStep": 250,
        "maxSamples": 1000,
        "nTrainingSteps": 4,
    }

    def __init__(self, *args, **kwargs):
        super(OpDetectMissing, self).__init__(*args, **kwargs)
        self.TrainingHistograms.setValue(_defaultTrainingHistograms())

    def propagateDirty(self, slot, subindex, roi):
        """Update the training flag and dirty ``Output`` depending on ``slot``."""
        if slot == self.InputVolume:
            self.Output.setDirty(roi)

        if slot == self.TrainingHistograms:
            OpDetectMissing._needsTraining = True

        if slot == self.NHistogramBins:
            # NOTE(review): the flag is set to True when the manager already
            # knows this bin count, which looks inverted; train() checks
            # _manager.has() as well, so confirm the intent before changing.
            OpDetectMissing._needsTraining = OpDetectMissing._manager.has(
                self.NHistogramBins.value
            )

        if slot == self.PatchSize or slot == self.HaloSize:
            self.Output.setDirty()

        if slot == self.OverloadDetector:
            s = self.OverloadDetector.value
            self.loads(s)
            self.Output.setDirty()

    def setupOutputs(self):
        """Configure ``Output`` as a uint8 copy of the input meta data and
        choose the histogram range from the input dtype."""
        self.Output.meta.assignFrom(self.InputVolume.meta)
        self.Output.meta.dtype = np.uint8

        # determine range of input
        if self.InputVolume.meta.dtype == np.uint8:
            r = (0, 255)
        elif self.InputVolume.meta.dtype == np.uint16:
            r = (0, 65535)
        else:
            # FIXME hardcoded range, use np.iinfo
            r = (0, 255)
        self._inputRange = r

        self.Detector.meta.shape = (1,)

    def execute(self, slot, subindex, roi, result):
        """Serve the ``Detector`` slot via ``self.dumps()``; for any other
        slot fill ``result`` with the missing-mask of the requested roi."""
        if slot == self.Detector:
            result = self.dumps()
            return result

        # sanity check
        assert self.DetectionMethod.value in [
            "svm",
            "classic",
        ], "Unknown detection method '{}'".format(self.DetectionMethod.value)

        # prefill result
        resultZYXCT = vigra.taggedView(
            result, self.InputVolume.meta.axistags
        ).withAxes(*"zyxct")

        # acquire data
        data = self.InputVolume.get(roi).wait()
        dataZYXCT = vigra.taggedView(
            data, self.InputVolume.meta.axistags
        ).withAxes(*"zyxct")

        # walk over time and channel axes
        for t in range(dataZYXCT.shape[4]):
            for c in range(dataZYXCT.shape[3]):
                resultZYXCT[..., c, t] = self._detectMissing(
                    dataZYXCT[..., c, t]
                )

        return result

    def _detectMissing(self, data):
        """
        detects missing regions and labels each missing region with 1

        :param data: 3d data with axistags 'zyx'
        :type data: array-like
        :returns: uint8 array of the same shape as ``data``
        :raises ValueError: if PatchSize is not positive or HaloSize is
            negative
        """
        assert (
            data.axistags.index("z") == 0
            and data.axistags.index("y") == 1
            and data.axistags.index("x") == 2
            and len(data.shape) == 3
        ), "Data must be 3d with axis 'zyx'."

        result = np.zeros(data.shape, dtype=np.uint8)

        patchSize = self.PatchSize.value
        haloSize = self.HaloSize.value

        if patchSize is None or not patchSize > 0:
            raise ValueError("PatchSize must be a positive integer")
        if haloSize is None or haloSize < 0:
            raise ValueError("HaloSize must be a non-negative integer")

        nBins = self.NHistogramBins.value
        method = self.DetectionMethod.value
        maxZ = data.shape[0]

        # walk over slices
        for z in range(maxZ):
            patches, slices = _patchify(data[z, :, :], patchSize, haloSize)

            # one density histogram per patch
            hists = np.vstack(
                [
                    np.histogram(
                        patch, bins=nBins, range=self._inputRange, density=True
                    )[0]
                    for patch in patches
                ]
            )

            pred = self.predict(hists, method=method)
            for i, p in enumerate(pred):
                if p > 0:
                    # patch is classified as missing
                    result[z, slices[i][0], slices[i][1]] |= 1

        return result

    def train(self, force=False):
        """
        trains with samples drawn from slot TrainingHistograms
        (retrains only if bin size is currently untrained or force is True)

        Returns without training when sklearn is unavailable or the
        detection method is "classic".
        """
        nBins = self.NHistogramBins.value

        # return early if unneccessary
        if (
            not force
            and not OpDetectMissing._needsTraining
            and OpDetectMissing._manager.has(nBins)
        ):
            return

        # return if we don't have svms
        if not havesklearn:
            return

        logger.debug("Training for {} histogram bins ...".format(nBins))

        if self.DetectionMethod.value == "classic":
            # no need to train this
            return

        histograms = self.TrainingHistograms[:].wait()

        logger.debug(
            "Finished loading histogram data of shape {}.".format(
                histograms.shape
            )
        )

        # Check the dimensionality first: indexing shape[1] on 1-d data would
        # raise IndexError before the assertion message could be shown.
        assert (
            len(histograms.shape) == 2 and histograms.shape[1] >= nBins + 1
        ), "Training data has wrong shape (expected: (n,{}), got: {}.".format(
            nBins + 1, histograms.shape
        )

        # column nBins holds the label, the columns before it the histogram
        labels = histograms[:, nBins]
        histograms = histograms[:, :nBins]

        neg_inds = np.where(labels == 0)[0]
        pos_inds = np.setdiff1d(np.arange(len(labels)), neg_inds)

        pos = histograms[pos_inds]
        neg = histograms[neg_inds]

        logger.debug(
            "Starting training with "
            + "{} negative patches and {} positive patches...".format(
                len(neg), len(pos)
            )
        )
        self._felzenszwalbTraining(neg, pos)
        logger.debug("Finished training.")

        OpDetectMissing._needsTraining = False

    def _felzenszwalbTraining(self, negative, positive):
        """
        we want to train on a 'hard' subset of the training data, see
        FELZENSZWALB ET AL.: OBJECT DETECTION WITH DISCRIMINATIVELY TRAINED
        PART-BASED MODELS (4.4), PAMI 32/9

        :param negative: histograms labelled as negative (index 0)
        :param positive: histograms labelled as positive (index 1)
        """
        # TODO sanity checks

        method = self.DetectionMethod.value

        # set options for Felzenszwalb training
        firstSamples = self._felzenOpts["firstSamples"]
        maxRemovePerStep = self._felzenOpts["maxRemovePerStep"]
        maxAddPerStep = self._felzenOpts["maxAddPerStep"]
        maxSamples = self._felzenOpts["maxSamples"]
        nTrainingSteps = self._felzenOpts["nTrainingSteps"]

        # initial choice of training samples
        (initNegative, choiceNegative, _, _) = _chooseRandomSubset(
            negative, min(firstSamples, len(negative))
        )
        (initPositive, choicePositive, _, _) = _chooseRandomSubset(
            positive, min(firstSamples, len(positive))
        )

        # setup for parallel training; index 0 is negative, 1 is positive
        samples = [negative, positive]
        choice = [choiceNegative, choicePositive]
        S_t = [initNegative, initPositive]

        finished = [False, False]

        ### BEGIN SUBROUTINE ###
        def felzenstep(x, cache, ind):
            """Update the cache of sample indices for one class.

            Returns the cached samples, the new cache and whether no hard
            sample was added in this step.
            """
            case = ("positive" if ind > 0 else "negative") + " set"
            pred = self.predict(x, method=method)

            # samples the current detector gets wrong / right
            hard = np.where(pred != ind)[0]
            easy = np.setdiff1d(np.arange(len(x)), hard)
            logger.debug(
                " {}: currently {} hard and {} easy samples".format(
                    case, len(hard), len(easy)
                )
            )

            # shrink the cache
            easyInCache = np.intersect1d(easy, cache) if len(easy) > 0 else []
            if len(easyInCache) > 0:
                (removeFromCache, _, _, _) = _chooseRandomSubset(
                    easyInCache, min(len(easyInCache), maxRemovePerStep)
                )
                cache = np.setdiff1d(cache, removeFromCache)
                logger.debug(
                    " {}: shrunk the cache by {} elements".format(
                        case, len(removeFromCache)
                    )
                )

            # grow the cache
            temp = len(cache)
            addToCache = _chooseRandomSubset(
                hard, min(len(hard), maxAddPerStep)
            )[0]
            cache = np.union1d(cache, addToCache)
            addedHard = len(cache) - temp
            logger.debug(
                " {}: grown the cache by {} elements".format(case, addedHard)
            )

            if len(cache) > maxSamples:
                logger.debug(
                    " {}: Cache to big, removing elements.".format(case)
                )
                cache = _chooseRandomSubset(cache, maxSamples)[0]

            # apply the cache
            C = x[cache]

            return (C, cache, addedHard == 0)

        ### END SUBROUTINE ###

        ### BEGIN PARALLELIZATION FUNCTION ###
        def partFun(i):
            (C, newChoice, newFinished) = felzenstep(samples[i], choice[i], i)
            S_t[i] = C
            choice[i] = newChoice
            finished[i] = newFinished

        ### END PARALLELIZATION FUNCTION ###

        for k in range(nTrainingSteps):
            logger.debug(
                "Felzenszwalb Training "
                + "(step {}/{}): {} hard negative samples, {} ".format(
                    k + 1, nTrainingSteps, len(S_t[0]), len(S_t[1])
                )
                + "hard positive samples."
            )
            self.fit(S_t[0], S_t[1], method=method)

            # update both classes in separate requests
            pool = RequestPool()

            for i in range(len(S_t)):
                req = Request(partial(partFun, i))
                pool.add(req)

            pool.wait()
            pool.clean()

            if np.all(finished):
                # already have all hard examples in training set
                break

        self.fit(S_t[0], S_t[1], method=method)

        logger.debug(" Finished Felzenszwalb Training.")
class OpStackWriter(Operator):
    """
    Write the ``Input`` volume to disk as a stack of images, one file per
    slice along the first non-singleton axis.

    The export is triggered explicitly via :meth:`run_export`; the operator
    declares no output slots.
    """

    name = "Stack File Writer"
    category = "Output"

    # The last two non-singleton axes (except 'c') are the axes of the slices.
    # Re-order the axes yourself if you want an alternative slicing direction.
    Input = InputSlot()

    # A complete filepath including a {slice_index} member and a valid file extension.
    FilepathPattern = InputSlot()

    # Added to the {slice_index} in the export filename.
    SliceIndexOffset = InputSlot(value=0)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.progressSignal = OrderedSignal()

    def run_export(self):
        """
        Request the volume in slices (running in parallel), and write each slice to a separate image.

        Raises:
            MemoryError: if the RAM estimate says not even one slice request fits
                into half of the currently available memory.
        """
        # Make the directory first if necessary.
        # A pattern without a directory part yields "", which os.makedirs() rejects,
        # so only create a directory when there is one to create.
        export_dir = os.path.split(self.FilepathPattern.value)[0]
        if export_dir and not os.path.exists(export_dir):
            os.makedirs(export_dir, exist_ok=True)

        # Sliceshape is the same as the input shape, except for the sliced dimension
        tagged_sliceshape = self.Input.meta.getTaggedShape()
        tagged_sliceshape[self._volume_axes[0]] = 1
        slice_shape = list(tagged_sliceshape.values())

        parallel_requests = 4

        # If ram usage info is available, make a better guess about how many requests we can launch in parallel
        ram_usage_per_requested_pixel = self.Input.meta.ram_usage_per_requested_pixel
        if ram_usage_per_requested_pixel is not None:
            pixels_per_slice = numpy.prod(slice_shape)
            if "c" in tagged_sliceshape:
                pixels_per_slice //= tagged_sliceshape["c"]

            ram_usage_per_slice = pixels_per_slice * ram_usage_per_requested_pixel

            # A zero estimate carries no information (and would divide by zero below),
            # so keep the default request count in that case.
            if ram_usage_per_slice > 0:
                # Fudge factor: Reduce RAM usage by a bit
                available_ram = psutil.virtual_memory().available
                available_ram *= 0.5
                parallel_requests = int(available_ram // ram_usage_per_slice)

                if parallel_requests < 1:
                    raise MemoryError(
                        "Not enough RAM to export to the selected format. "
                        "Consider exporting to hdf5 (h5)."
                    )

        streamer = BigRequestStreamer(
            self.Input, roiFromShape(self.Input.meta.shape), slice_shape, parallel_requests
        )

        # Write the slices as they come in (possibly out-of-order, but probably not)
        streamer.resultSignal.subscribe(self._write_slice)
        streamer.progressSignal.subscribe(self.progressSignal)

        logger.debug(f"Starting Stack Export with slicing shape: {slice_shape}")
        streamer.execute()

    def setupOutputs(self):
        """Validate the inputs and cache the slicing axes and filename padding width."""
        # If stacking XY images in Z-steps,
        #  then self._volume_axes = 'zxy'
        self._volume_axes = self.get_nonsingleton_axes()
        tagged_shape = self.Input.meta.getTaggedShape()

        # Check for errors first, so that an unsuitable shape is reported with
        # the descriptive message instead of failing on the indexing below.
        num_axes = len(self._volume_axes)
        assert num_axes == 3 or (num_axes == 4 and "c" in self._volume_axes[1:]), (
            "Exported stacks must have exactly 3 non-singleton dimensions (other than the channel dimension).  "
            "Your stack dimensions are: {}".format(tagged_shape)
        )

        step_axis = self._volume_axes[0]
        max_slice = self.SliceIndexOffset.value + tagged_shape[step_axis]
        self._max_slice_digits = int(math.ceil(math.log10(max_slice + 1)))

        # Test to make sure the filepath pattern includes slice index field
        filepath_pattern = self.FilepathPattern.value
        assert "123456789" in filepath_pattern.format(slice_index=123_456_789), (
            "Output filepath pattern must contain the '{{slice_index}}' field for formatting.\n"
            "Your format was: {}".format(filepath_pattern)
        )

    # No output slots...
    def execute(self, slot, subindex, roi, result):
        pass

    def propagateDirty(self, slot, subindex, roi):
        pass

    def get_nonsingleton_axes(self):
        """Return the keys of all ``Input`` axes whose extent is greater than 1."""
        return self.get_nonsingleton_axes_for_tagged_shape(self.Input.meta.getTaggedShape())

    @classmethod
    def get_nonsingleton_axes_for_tagged_shape(cls, tagged_shape):
        """
        Return a tuple with the keys of all axes in ``tagged_shape`` whose extent is > 1.

        The first non-singleton axis is the step axis.
        The last 2 non-channel non-singleton axes will be the axes of the slices.
        An empty tuple is returned when every axis is a singleton.
        """
        return tuple(key for key, extent in tagged_shape.items() if extent > 1)

    def _write_slice(self, roi, slice_data):
        """
        Write the data from the given roi into a slice image.
        """
        step_axis = self._volume_axes[0]
        input_axes = self.Input.meta.getAxisKeys()
        tagged_roi = OrderedDict(zip(input_axes, zip(*roi)))
        # e.g. tagged_roi={ 'x':(0,1), 'y':(3,4), 'z':(10,20) }
        assert tagged_roi[step_axis][1] - tagged_roi[step_axis][0] == 1, "Expected roi to be a single slice."
        slice_index = tagged_roi[step_axis][0] + self.SliceIndexOffset.value
        filepattern = self.FilepathPattern.value

        # If the user didn't provide custom formatting for the slice field,
        #  auto-format to include zero-padding
        if "{slice_index}" in filepattern:
            padded_field = "{{slice_index:0{}}}".format(self._max_slice_digits)
            filepattern = filepattern.format(slice_index=padded_field)

        formatted_path = filepattern.format(slice_index=slice_index)

        squeezed_data = slice_data.squeeze()
        squeezed_data = vigra.taggedView(
            squeezed_data, vigra.defaultAxistags("".join(self._volume_axes[1:]))
        )
        assert len(squeezed_data.shape) == len(self._volume_axes) - 1

        logger.debug("Writing slice: {}".format(formatted_path))
        vigra.impex.writeImage(squeezed_data, formatted_path)
class OpLayerViewer(Operator):
    """Operator that declares a single input slot, ``RawInput``, and no output slots."""

    name = "OpLayerViewer"
    category = "top-level"

    RawInput = InputSlot()
class OpStackToH5Writer(Operator):
    """
    Copy an image stack, selected by a glob string, into a dataset of an hdf5 group.

    The copy happens as a side effect of requesting ``WriteImage``; progress
    (0-100) is reported through ``progressSignal``.
    """

    name = "OpStackToH5Writer"
    category = "IO"

    GlobString = InputSlot(stype="globstring")
    hdf5Group = InputSlot(stype="object")
    hdf5Path = InputSlot(stype="string")

    # Requesting the output induces the copy from stack to h5 file.
    WriteImage = OutputSlot(stype="bool")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.progressSignal = OrderedSignal()
        self.opStackLoader = OpStackLoader(parent=self)
        self.opStackLoader.globstring.connect(self.GlobString)

    def setupOutputs(self):
        self.WriteImage.meta.shape = (1,)
        self.WriteImage.meta.dtype = object

    def propagateDirty(self, slot, subindex, roi):
        # Any change to our inputs means we're dirty
        known_inputs = (self.GlobString, self.hdf5Group, self.hdf5Path)
        assert any(slot == known for known in known_inputs)
        self.WriteImage.setDirty(slice(None))

    def execute(self, slot, subindex, roi, result):
        """Create the target dataset and fill it one z-slice at a time."""
        loader = self.opStackLoader
        if not loader.fileNameList:
            raise Exception(
                f"Didn't find any files to combine.  Is the glob string valid?  "
                f"globstring = {self.GlobString.value}"
            )

        stack_slot = loader.stack
        axistags = stack_slot.meta.axistags
        data_shape = stack_slot.meta.shape
        z_axis = axistags.index("z")
        num_images = data_shape[z_axis]

        dtype = stack_slot.meta.dtype
        if isinstance(dtype, numpy.dtype):
            # Make sure we're dealing with a type (e.g. numpy.float64),
            #  not a numpy.dtype
            dtype = dtype.type

        # An index at or past the end of the shape is treated as "no channel axis"
        channel_axis = axistags.index("c")
        num_channels = data_shape[channel_axis] if channel_axis < len(data_shape) else 1

        # Set up our chunk shape: Aim for a cube that's roughly 300k in size
        bytes_per_item = dtype().nbytes
        cube_dim = int(math.pow(300_000 // (num_channels * bytes_per_item), (1 / 3.0)))
        chunk_dims = {"t": 1, "x": cube_dim, "y": cube_dim, "z": cube_dim, "c": num_channels}

        # h5py guide to chunking says chunks of 300k or less "work best"
        assert chunk_dims["x"] * chunk_dims["y"] * chunk_dims["z"] * num_channels * bytes_per_item <= 300_000

        # Chunk shape can't be larger than the data shape
        chunk_shape = tuple(
            min(chunk_dims[axistags[axis].key], extent) for axis, extent in enumerate(data_shape)
        )

        # Create the dataset
        internal_path = self.hdf5Path.value.replace("\\", "/")  # Windows fix
        group = self.hdf5Group.value
        if internal_path in group:
            del group[internal_path]

        data = group.create_dataset(
            internal_path,
            # compression='gzip',
            # compression_opts=4,
            shape=data_shape,
            dtype=dtype,
            chunks=chunk_shape,
        )

        # Now copy each image
        self.progressSignal(0)
        num_axes = len(axistags)
        for z in range(num_images):
            # Ask for an entire z-slice (exactly one whole image from the stack)
            slicing = [slice(z, z + 1) if axis == z_axis else slice(None) for axis in range(num_axes)]
            data[tuple(slicing)] = stack_slot[slicing].wait()
            self.progressSignal(z * 100 // num_images)

        data.attrs["axistags"] = axistags.toJSON()

        # We're done
        result[...] = True
        self.progressSignal(100)
        return result
class OpPredictionPipeline(OpPredictionPipelineNoCache):
    """
    This operator extends the cacheless prediction pipeline above with additional outputs for the GUI.
    (It uses caches for these outputs, and has an extra input for cached features.)
    """

    FreezePredictions = InputSlot()
    CachedFeatureImages = InputSlot()

    PredictionProbabilities = OutputSlot()
    CachedPredictionProbabilities = OutputSlot()

    PredictionProbabilityChannels = OutputSlot(level=1)
    SegmentationChannels = OutputSlot(level=1)
    UncertaintyEstimate = OutputSlot()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Random forest prediction using CACHED features.
        predictor = self.predict = OpClassifierPredict(parent=self)
        predictor.name = "OpClassifierPredict"
        predictor.Classifier.connect(self.Classifier)
        predictor.Image.connect(self.CachedFeatureImages)
        predictor.PredictionMask.connect(self.PredictionMask)
        predictor.LabelsCount.connect(self.NumClasses)
        self.PredictionProbabilities.connect(predictor.PMaps)

        # Prediction cache for the GUI
        gui_cache = self.prediction_cache_gui = OpSlicedBlockedArrayCache(parent=self)
        gui_cache.name = "prediction_cache_gui"
        gui_cache.fixAtCurrent.connect(self.FreezePredictions)
        gui_cache.Input.connect(predictor.PMaps)
        self.CachedPredictionProbabilities.connect(gui_cache.Output)

        # Also provide each prediction channel as a separate layer (for the GUI)
        prediction_slicer = self.opPredictionSlicer = OpMultiArraySlicer2(parent=self)
        prediction_slicer.name = "opPredictionSlicer"
        prediction_slicer.Input.connect(gui_cache.Output)
        prediction_slicer.AxisFlag.setValue('c')
        self.PredictionProbabilityChannels.connect(prediction_slicer.Slices)

        # Per-channel segmentation layers, derived from the cached predictions
        segmentor = self.opSegmentor = OpMaxChannelIndicatorOperator(parent=self)
        segmentor.Input.connect(gui_cache.Output)

        segmentation_slicer = self.opSegmentationSlicer = OpMultiArraySlicer2(parent=self)
        segmentation_slicer.name = "opSegmentationSlicer"
        segmentation_slicer.Input.connect(segmentor.Output)
        segmentation_slicer.AxisFlag.setValue('c')
        self.SegmentationChannels.connect(segmentation_slicer.Slices)

        # Create a layer for uncertainty estimate
        uncertainty = self.opUncertaintyEstimator = OpEnsembleMargin(parent=self)
        uncertainty.Input.connect(gui_cache.Output)

        # Cache the uncertainty so we get zeros for uncomputed points
        uncertainty_cache = self.opUncertaintyCache = OpSlicedBlockedArrayCache(parent=self)
        uncertainty_cache.name = "opUncertaintyCache"
        uncertainty_cache.Input.connect(uncertainty.Output)
        uncertainty_cache.fixAtCurrent.connect(self.FreezePredictions)
        self.UncertaintyEstimate.connect(uncertainty_cache.Output)

    def setupOutputs(self):
        """Configure the block shapes of both caches from the axis order of ``FeatureImages``."""
        axis_keys = [tag.key for tag in self.FeatureImages.meta.axistags]

        def block_shapes(flat_axis):
            # (inner, outer) block extent per axis; the cache sliced along
            # `flat_axis` uses an extent of 1 on that axis.
            extents = {
                't': (1, 1),
                'z': (128, 256),
                'y': (128, 256),
                'x': (128, 256),
                'c': (100, 100),
            }
            extents[flat_axis] = (1, 1)
            inner = tuple(extents[key][0] for key in axis_keys)
            outer = tuple(extents[key][1] for key in axis_keys)
            return inner, outer

        # One (inner, outer) pair for each of the x-, y- and z-sliced views, in that order
        per_view = [block_shapes(axis) for axis in 'xyz']
        inner_shapes = tuple(inner for inner, _ in per_view)
        outer_shapes = tuple(outer for _, outer in per_view)

        self.prediction_cache_gui.inputs["innerBlockShape"].setValue(inner_shapes)
        self.prediction_cache_gui.inputs["outerBlockShape"].setValue(outer_shapes)

        self.opUncertaintyCache.inputs["innerBlockShape"].setValue(inner_shapes)
        self.opUncertaintyCache.inputs["outerBlockShape"].setValue(outer_shapes)
class OpImageReader(Operator):
    """
    Read an image using vigra.impex.readImage().
    Supports 2D images (output as yxc) and also multi-page tiffs (output as zyxc).
    """

    Filename = InputSlot(stype="filestring")
    Image = OutputSlot()

    def setupOutputs(self):
        """Derive dtype, shape and axistags of ``Image`` from the file header."""
        filename = self.Filename.value

        info = vigra.impex.ImageInfo(filename)
        assert [tag.key for tag in info.getAxisTags()] == ["x", "y", "c"]

        # vigra reports xyc; the output uses yxc
        shape_xyc = info.getShape()
        shape_yxc = (shape_xyc[1], shape_xyc[0], shape_xyc[2])

        self.Image.meta.dtype = info.getDtype()
        self.Image.meta.prefer_2d = True

        page_count = vigra.impex.numberImages(filename)
        if page_count != 1:
            # For 3D, we use zyxc: prepend the page count as the z extent
            self.Image.meta.shape = (page_count,) + shape_yxc

            # Prepend a z tag and swap x/y, keeping the channel tag last
            z_tag = vigra.defaultAxistags("z")[0]
            tags_xyc = list(info.getAxisTags())
            tags_zyxc = [z_tag] + tags_xyc[-2::-1] + tags_xyc[-1:]
            self.Image.meta.axistags = vigra.AxisTags(tags_zyxc)
        else:
            # For 2D, we use order yxc.
            self.Image.meta.shape = shape_yxc
            file_tags = info.getAxisTags()
            self.Image.meta.axistags = vigra.AxisTags([file_tags[key] for key in "yxc"])

    def execute(self, slot, subindex, rroi, result):
        """Read the requested roi from the file into ``result`` and return it."""
        filename = self.Filename.value
        bounds = numpy.array([rroi.start, rroi.stop])

        if "z" not in self.Image.meta.getAxisKeys():
            page = vigra.impex.readImage(filename).transpose(1, 0, 2)  # xyc -> yxc
            assert page.shape == self.Image.meta.shape
            result[:] = page[roiToSlice(*bounds)]
            return result

        # Copy from each image slice into the corresponding slice of the result.
        z_start, z_stop = bounds[:, 0]
        in_plane = roiToSlice(*bounds[:, 1:])
        for z_result, z_global in zip(range(result.shape[0]), range(z_start, z_stop)):
            page = vigra.impex.readImage(filename, index=z_global)
            page = page.transpose(1, 0, 2)  # xyc -> yxc
            assert page.shape == self.Image.meta.shape[1:]
            result[z_result] = page[in_plane]
        return result

    def propagateDirty(self, slot, subindex, roi):
        if slot == self.Filename:
            self.Image.setDirty()
        else:
            assert False, "Unknown dirty input slot."
class OpDivisionFeatures(Operator):
    """Computes division features on a 5D volume."""

    LabelVolume = InputSlot()
    DivisionFeatureNames = InputSlot(rtype=List, stype=Opaque)
    RegionFeaturesVigra = InputSlot()

    BlockwiseDivisionFeatures = OutputSlot()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def setupOutputs(self):
        """Set up the per-timestep output meta and build the feature manager."""
        tagged_shape = self.LabelVolume.meta.getTaggedShape()

        if set(tagged_shape.keys()) != set('txyzc'):
            raise Exception("Input volumes must have txyzc axes.")

        # One (object-typed) entry per time step
        self.BlockwiseDivisionFeatures.meta.shape = (tagged_shape['t'],)
        self.BlockwiseDivisionFeatures.meta.axistags = vigra.defaultAxistags("t")
        self.BlockwiseDivisionFeatures.meta.dtype = object

        # A singleton spatial axis means the data is treated as 2D
        has_flat_axis = np.any([tagged_shape.get(key, 0) == 1 for key in "xyz"])
        ndim = 2 if has_flat_axis else 3

        self.featureManager = FeatureManager(
            scales=config.image_scale,
            n_best=config.n_best_successors,
            com_name_cur=config.com_name_cur,
            com_name_next=config.com_name_next,
            size_name=config.size_name,
            delim=config.delim,
            template_size=config.template_size,
            ndim=ndim,
            size_filter=config.size_filter,
            squared_distance_default=config.squared_distance_default,
        )

    def execute(self, slot, subindex, roi, result):
        """Fill ``result[t]`` with a dict of division features for each requested time step."""
        assert len(roi.start) == len(roi.stop) == len(self.BlockwiseDivisionFeatures.meta.shape)
        assert slot == self.BlockwiseDivisionFeatures

        time_axis = list(self.LabelVolume.meta.getTaggedShape().keys()).index('t')

        import time

        tic = time.perf_counter()

        assert len(roi.start) == 1
        assert time_axis == 0

        volume_shape = self.LabelVolume.meta.shape
        t_start, t_stop = roi.start[0], roi.stop[0]

        # Read one extra frame past the request so the last requested frame
        # can be paired with its successor.
        # NOTE(review): the guard `t_stop + 1 < shape` skips the extra frame when
        # t_stop == shape - 1, although frame `t_stop` exists then -- possible
        # off-by-one; confirm the intent before changing.
        read_stop = t_stop + 1 if t_stop + 1 < volume_shape[time_axis] else t_stop

        volume_slicing = [slice(0, extent) for extent in volume_shape]
        volume_slicing[time_axis] = slice(t_start, read_stop)

        feats = self.RegionFeaturesVigra[slice(t_start, read_stop)].wait()
        label_volume = self.LabelVolume[volume_slicing].wait()
        division_feature_names = self.DivisionFeatureNames[()].wait()[config.features_division_name]

        frames_read = read_stop - t_start
        for t in range(t_stop - t_start):
            result[t] = {}
            feats_cur = feats[t][config.features_vigra_name]

            has_successor = t + 1 < frames_read
            feats_next = feats[t + 1][config.features_vigra_name] if has_successor else None
            img_next = label_volume[t + 1, ...] if has_successor else None

            result[t][config.features_division_name] = self.featureManager.computeFeatures_at(
                feats_cur, feats_next, img_next, division_feature_names
            )

        toc = time.perf_counter()
        logger.debug("TIMING: computing division features took {:.3f}s".format(toc - tic))
        return result

    def propagateDirty(self, slot, subindex, roi):
        if slot is self.DivisionFeatureNames:
            self.BlockwiseDivisionFeatures.setDirty(slice(None))
            return
        if slot is self.RegionFeaturesVigra:
            self.BlockwiseDivisionFeatures.setDirty(roi)
            return

        axes = list(self.LabelVolume.meta.getTaggedShape().keys())
        dirty_start = collections.OrderedDict(zip(axes, roi.start))
        dirty_stop = collections.OrderedDict(zip(axes, roi.stop))

        # Remove the spatial and channel dims (keep t, if present)
        for bounds in (dirty_start, dirty_stop):
            for key in 'xyzc':
                del bounds[key]

        self.BlockwiseDivisionFeatures.setDirty(list(dirty_start.values()), list(dirty_stop.values()))
class OpH5N5WriterBigDataset(Operator):
    """
    Stream the ``Image`` input block-wise into a dataset of an open h5/n5 file.

    The dataset is (re)created in ``setupOutputs``; the data is written when
    ``WriteImage`` is requested.
    """

    name = "H5 and N5 File Writer BigDataset"
    category = "Output"

    # Must be an already-open hdf5File/n5File (or group) for writing to
    h5N5File = InputSlot()
    h5N5Path = InputSlot()
    Image = InputSlot()
    # h5py uses single-threaded gzip compression, which really slows down export.
    CompressionEnabled = InputSlot(value=False)
    BatchSize = InputSlot(optional=True)

    WriteImage = OutputSlot()

    loggingName = __name__ + ".OpH5N5WriterBigDataset"
    logger = logging.getLogger(loggingName)
    traceLogger = logging.getLogger("TRACE." + loggingName)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.progressSignal = OrderedSignal()
        self.d = None
        self.f = None

    def cleanUp(self):
        super().cleanUp()
        # Discard the reference to the dataset, to ensure that the file can be closed.
        self.d = None
        self.f = None
        self.progressSignal.clean()

    def setupOutputs(self):
        """Create the target dataset (replacing an existing one of the same name)."""
        self.WriteImage.meta.shape = (1,)
        self.WriteImage.meta.dtype = object

        self.f = self.h5N5File.value

        # On windows, there may be backslashes.
        dataset_path = self.h5N5Path.value.replace("\\", "/")
        group_name, dataset_name = os.path.split(dataset_path)

        if not group_name:
            group = self.f
        elif group_name in self.f:
            group = self.f[group_name]
        else:
            group = self.f.create_group(group_name)

        dataShape = self.Image.meta.shape
        self.logger.info(f"Data shape: {dataShape}")

        dtype = self.Image.meta.dtype
        if isinstance(dtype, numpy.dtype):
            # Make sure we're dealing with a type (e.g. numpy.float64),
            # not a numpy.dtype
            dtype = dtype.type

        # Set up our chunk shape: Aim for a cube that's roughly 512k in size
        bytes_per_item = dtype().nbytes

        # Assume that chunks should not span multiple t-slices,
        # and channels are often handled separately, too.
        tagged_maxshape = self.Image.meta.getTaggedShape()
        for key in ("t", "c"):
            if key in tagged_maxshape:
                tagged_maxshape[key] = 1

        self.chunkShape = determineBlockShape(list(tagged_maxshape.values()), 512_000.0 / bytes_per_item)

        if dataset_name in list(group.keys()):
            del group[dataset_name]

        kwargs = {"shape": dataShape, "dtype": dtype, "chunks": self.chunkShape}
        if self.CompressionEnabled.value:
            # Would be nice to use lzf compression here, but that is h5py-specific.
            kwargs["compression"] = "gzip"
            # Optimize for speed, not disk space. (z5py uses a different name here.)
            level_key = "compression_opts" if isinstance(self.f, h5py.File) else "level"
            kwargs[level_key] = 1
        elif isinstance(self.f, z5py.N5File):
            # n5 uses gzip level 5 as default compression.
            kwargs["compression"] = "raw"

        self.d = group.create_dataset(dataset_name, **kwargs)

        if self.Image.meta.drange is not None:
            self.d.attrs["drange"] = self.Image.meta.drange
        if self.Image.meta.display_mode is not None:
            self.d.attrs["display_mode"] = self.Image.meta.display_mode

    def execute(self, slot, subindex, rroi, result):
        """Write the whole image into the dataset and set ``result[0]`` to True."""
        self.progressSignal(0)

        # Save the axistags as a dataset attribute
        self.d.attrs["axistags"] = self.Image.meta.axistags.toJSON()

        def handle_block_result(roi, data):
            slicing = roiToSlice(*roi)
            if not data.flags.c_contiguous:
                self.d[slicing] = data
                return
            self.d.write_direct(data.view(numpy.ndarray), dest_sel=slicing)

        batch_size = self.BatchSize.value if self.BatchSize.ready() else None

        requester = BigRequestStreamer(self.Image, roiFromShape(self.Image.meta.shape), batchSize=batch_size)
        requester.resultSignal.subscribe(handle_block_result)
        requester.progressSignal.subscribe(self.progressSignal)
        requester.execute()

        # Be paranoid: Flush right now.
        if isinstance(self.f, h5py.File):
            self.f.file.flush()  # not available in z5py

        # We're finished.
        result[0] = True

        self.progressSignal(100)

    def propagateDirty(self, slot, subindex, roi):
        # The output from this operator isn't generally connected to other operators.
        # If someone is using it that way, we'll assume that the user wants to know that
        #  the input image has become dirty and may need to be written to disk again.
        self.WriteImage.setDirty(slice(None))
class OpStreamingHdf5Reader(Operator):
    """
    Provide the contents of a dataset inside an already-open hdf5 file as an image slot.

    Shape, dtype, axistags and (if stored) the data range are read from the
    dataset in ``setupOutputs``; the pixel data itself is only read on request.
    """

    name = "OpStreamingHdf5Reader"
    category = "Reader"

    # The project hdf5 File object (already opened)
    Hdf5File = InputSlot(stype='hdf5File')

    # The internal path for project-local datasets
    InternalPath = InputSlot(stype='string')

    # Output data
    OutputImage = OutputSlot()

    class DatasetReadError(Exception):
        """Raised when ``InternalPath`` does not exist in the hdf5 file."""

        def __init__(self, internalPath):
            self.internalPath = internalPath
            self.msg = "Unable to open Hdf5 dataset: {}".format(internalPath)
            super(OpStreamingHdf5Reader.DatasetReadError, self).__init__(self.msg)

    def __init__(self, *args, **kwargs):
        super(OpStreamingHdf5Reader, self).__init__(*args, **kwargs)
        self._hdf5File = None

    def setupOutputs(self):
        """
        Read the dataset meta-info from the HDF5 dataset.

        Raises:
            OpStreamingHdf5Reader.DatasetReadError: if the internal path is not in the file.
        """
        self._hdf5File = self.Hdf5File.value
        internalPath = self.InternalPath.value

        if internalPath not in self._hdf5File:
            raise OpStreamingHdf5Reader.DatasetReadError(internalPath)

        dataset = self._hdf5File[internalPath]

        try:
            # Read the axistags property without actually importing the data
            axistagsJson = dataset.attrs['axistags']  # Throws KeyError if 'axistags' can't be found
            axistags = vigra.AxisTags.fromJSON(axistagsJson)
        except KeyError:
            # No axistags found: guess an axis order from the dimensionality.
            ndims = len(dataset.shape)
            assert ndims != 0, "OpStreamingHdf5Reader: Zero-dimensional datasets not supported."
            assert ndims != 1, "OpStreamingHdf5Reader: Support for 1-D data not yet supported"
            assert ndims <= 5, "OpStreamingHdf5Reader: No support for data with more than 5 dimensions."

            axisorders = {2: 'xy', 3: 'xyz', 4: 'xyzc', 5: 'txyzc'}
            axisorder = axisorders[ndims]
            if ndims == 3 and dataset.shape[2] <= 4:
                # Special case: If the 3rd dim is small, assume it's 'c', not 'z'
                axisorder = 'xyc'

            axistags = vigra.defaultAxistags(axisorder)

        # Build the message from the tags themselves, so that it is valid no matter
        # whether the tags came from the file or from the fallback guess above.
        assert len(axistags) == len(dataset.shape), \
            "Mismatch between shape {} and axisorder {}".format(
                dataset.shape, "".join(tag.key for tag in axistags))

        # Configure our slot meta-info
        self.OutputImage.meta.dtype = dataset.dtype.type
        self.OutputImage.meta.shape = dataset.shape
        self.OutputImage.meta.axistags = axistags

        # If the dataset specifies a datarange, add it to the slot metadata
        if 'drange' in dataset.attrs:
            self.OutputImage.meta.drange = tuple(dataset.attrs['drange'])

        # Large unchunked datasets are flagged, since they are slow to read block-wise
        total_volume = numpy.prod(numpy.array(dataset.shape))
        if not dataset.chunks and total_volume > 1e8:
            self.OutputImage.meta.inefficient_format = True
            logger.warning("This dataset ({}{}) is NOT chunked.  "
                           "Performance for 3D access patterns will be bad!"
                           .format(self._hdf5File.filename, internalPath))

    def execute(self, slot, subindex, roi, result):
        """Read the requested roi directly from the hdf5 dataset into ``result``."""
        t = time.time()
        assert self._hdf5File is not None

        # Read the desired data directly from the hdf5File
        key = roi.toSlice()
        hdf5File = self._hdf5File
        internalPath = self.InternalPath.value

        # Only pay for timing and message formatting when debug output is actually emitted
        debug_enabled = logger.isEnabledFor(logging.DEBUG)

        timer = None
        if debug_enabled:
            logger.debug("Reading HDF5 block: [{}, {}]".format(roi.start, roi.stop))
            timer = Timer()
            timer.unpause()

        if result.flags.c_contiguous:
            hdf5File[internalPath].read_direct(result[...], key)
        else:
            result[...] = hdf5File[internalPath][key]

        if debug_enabled:
            t = 1000.0 * (time.time() - t)
            logger.debug("took %f msec.", t)

        if timer is not None:
            timer.pause()
            logger.debug("Completed HDF5 read in {} seconds: [{}, {}]".format(
                timer.seconds(), roi.start, roi.stop))

    def propagateDirty(self, slot, subindex, roi):
        if slot == self.Hdf5File or slot == self.InternalPath:
            self.OutputImage.setDirty(slice(None))
class OpBlockedArrayCache(Operator, ManagedBlockedCache):
    """
    A blockwise array cache designed to replace the old OpBlockedArrayCache.

    Instead of a monolithic implementation, this operator is a small pipeline of three simple operators.

    The actual caching of data is handled by an unblocked cache, so the "blocked" functionality is
    implemented via separate "splitting" operator that comes after the cache.
    Also, the "fixAtCurrent" feature is implemented in a special operator, which comes before the cache.

    NOTE(review): the description above mentions three operators and a separate
    splitting operator, but this class only instantiates an OpCacheFixer and an
    OpSimpleBlockedArrayCache -- confirm and update the description.
    """

    fixAtCurrent = InputSlot(value=False)
    Input = InputSlot(allow_mask=True)
    # BlockShape = InputSlot()
    BlockShape = InputSlot(optional=True)  # If 'None' is present, those items will be treated as max for the dimension.
    # If not provided, will be set to Input.meta.shape
    BypassModeEnabled = InputSlot(value=False)
    CompressionEnabled = InputSlot(value=False)

    Output = OutputSlot(allow_mask=True)
    CleanBlocks = OutputSlot()  # A list of slicings indicating which blocks are stored in the cache and clean.

    def __init__(self, *args, **kwargs):
        super(OpBlockedArrayCache, self).__init__(*args, **kwargs)

        # Internal wiring set up below:
        #
        # Input ---------> opCacheFixer -> opSimpleBlockedArrayCache -> Output
        #                 /               /
        # fixAtCurrent --                /
        #                               /
        # BlockShape, BypassModeEnabled, CompressionEnabled
        #
        # NOTE(review): an earlier comment described Output as being served
        # "indirectly via execute" with a bypass mode; the code below connects
        # Output directly and forwards BypassModeEnabled to the inner cache.

        self._opCacheFixer = OpCacheFixer(parent=self)
        self._opCacheFixer.Input.connect(self.Input)
        self._opCacheFixer.fixAtCurrent.connect(self.fixAtCurrent)

        self._opSimpleBlockedArrayCache = OpSimpleBlockedArrayCache(parent=self)
        self._opSimpleBlockedArrayCache.Input.connect(self._opCacheFixer.Output)
        self._opSimpleBlockedArrayCache.CompressionEnabled.connect(self.CompressionEnabled)
        # NOTE(review): repeats the Input connection made two lines above -- looks redundant; confirm.
        self._opSimpleBlockedArrayCache.Input.connect(self._opCacheFixer.Output)
        self._opSimpleBlockedArrayCache.BlockShape.connect(self.BlockShape)
        self._opSimpleBlockedArrayCache.BypassModeEnabled.connect(self.BypassModeEnabled)
        self.CleanBlocks.connect(self._opSimpleBlockedArrayCache.CleanBlocks)
        self.Output.connect(self._opSimpleBlockedArrayCache.Output)

        # NOTE(review): the comments that used to sit here said Output is NOT connected
        # directly and that data is forwarded manually through execute(). That no longer
        # matches the code: Output is connected above and execute() asserts it is never
        # reached. The stale, commented-out connect call is kept for reference:
        # self.Output.connect( self._opSimpleBlockedArrayCache.Output )

        # Explicitly forward dirty notifications from the inner cache to our Output.
        # NOTE(review): possibly redundant now that Output is directly connected -- confirm before removing.
        self._opSimpleBlockedArrayCache.Output.notifyDirty(
            lambda slot, roi: self.Output.setDirty(roi.start, roi.stop))

        # This member is used by tests that check RAM usage.
        self.setup_ram_context = RamMeasurementContext()
        self.registerWithMemoryManager()

    def setupOutputs(self):
        # Default the block shape to the full input shape when none was provided
        if not self.BlockShape.ready():
            self.BlockShape.setValue(self.Input.meta.shape)
        # Copy metadata from the internal pipeline to the output
        self.Output.meta.assignFrom(self._opSimpleBlockedArrayCache.Output.meta)

    def execute(self, slot, subindex, roi, result):
        # Output and CleanBlocks are both connected to the inner cache in __init__,
        # so no request is expected to end up here.
        assert False, "Shouldn't get here"

    def propagateDirty(self, slot, subindex, roi):
        pass

    def setInSlot(self, slot, subindex, key, value):
        pass  # Nothing to do here: Input is connected to an internal operator

    # ======= mimic cache interface for wrapping operators =======
    # Every method below simply delegates to the inner OpSimpleBlockedArrayCache.

    def usedMemory(self):
        return self._opSimpleBlockedArrayCache.usedMemory()

    def fractionOfUsedMemoryDirty(self):
        # dirty memory is discarded immediately
        return self._opSimpleBlockedArrayCache.fractionOfUsedMemoryDirty()

    def lastAccessTime(self):
        return self._opSimpleBlockedArrayCache.lastAccessTime()

    def getBlockAccessTimes(self):
        return self._opSimpleBlockedArrayCache.getBlockAccessTimes()

    def freeMemory(self):
        return self._opSimpleBlockedArrayCache.freeMemory()

    def freeBlock(self, key):
        return self._opSimpleBlockedArrayCache.freeBlock(key)

    def freeDirtyMemory(self):
        return self._opSimpleBlockedArrayCache.freeDirtyMemory()

    def generateReport(self, report):
        # Let the inner cache fill in the report first, keep a shallow copy of that
        # state as a child entry, then let the base class fill in the report for this operator.
        self._opSimpleBlockedArrayCache.generateReport(report)
        child = copy.copy(report)
        super(OpBlockedArrayCache, self).generateReport(report)
        report.children.append(child)
class OpAnisotropicGaussianSmoothing(Operator):
    """
    Gaussian smoothing with an individual sigma for each of the z, y and x axes.

    ``Sigmas`` must be a dict with exactly the keys 'z', 'y' and 'x'.
    The output is float32. A singleton z axis is removed before filtering.

    NOTE(review): ``setupOutputs`` drops a 't' axis from the output meta, while
    ``execute`` compares and indexes the requested roi against the *input*
    axes. Inputs that carry a 't' axis therefore look unsupported -- confirm
    against the callers.
    """

    Input = InputSlot()
    Sigmas = InputSlot(value={'z': 1.0, 'y': 1.0, 'x': 1.0})

    Output = OutputSlot()

    def setupOutputs(self):
        self.Output.meta.assignFrom(self.Input.meta)

        # if there is a time of dim 1, output won't have that
        timeIndex = self.Output.meta.axistags.index('t')
        if timeIndex < len(self.Output.meta.shape):
            newshape = list(self.Output.meta.shape)
            newshape.pop(timeIndex)
            self.Output.meta.shape = tuple(newshape)
            del self.Output.meta.axistags[timeIndex]

        self.Output.meta.dtype = numpy.float32  # vigra gaussian only supports float32
        self._sigmas = self.Sigmas.value
        assert isinstance(self.Sigmas.value, dict), "Sigmas slot expects a dict"
        assert set(self._sigmas.keys()) == set('zyx'), "Sigmas slot expects three key-value pairs for z,y,x"

        logger.debug("Assigning output: {} ====> {}".format(
            self.Input.meta.getTaggedShape(), self.Output.meta.getTaggedShape()))

    def execute(self, slot, subindex, roi, result):
        """Smooth the requested roi into ``result`` and return it."""
        assert all(roi.stop <= self.Input.meta.shape), \
            "Requested roi {} is too large for this input image of shape {}.".format(roi, self.Input.meta.shape)

        # Determine how much input data we'll need, and where the result will be relative to that input roi
        inputRoi, computeRoi = self._getInputComputeRois(roi)

        # Obtain the input data
        with Timer() as resultTimer:
            data = self.Input(*inputRoi).wait()
        logger.debug("Obtaining input data took {} seconds for roi {}".format(resultTimer.seconds(), inputRoi))

        inputShape = self.Input.meta.shape
        inputTags = self.Input.meta.axistags

        # An index at or past the end of the shape means the axis is absent
        zIndex = inputTags.index('z') if inputTags.index('z') < len(inputShape) else None
        cIndex = inputTags.index('c') if inputTags.index('c') < len(inputShape) else None

        # Must be float32
        if data.dtype != numpy.float32:
            data = data.astype(numpy.float32)

        axiskeys = self.Input.meta.getAxisKeys()
        # Materialized as lists: the keys are iterated more than once below.
        spatialkeys = [k for k in axiskeys if k in 'zyx']

        reskey = [slice(None, None, None)] * len(inputShape)
        if cIndex is not None:
            reskey[cIndex] = 0

        # we need to remove a singleton z axis, otherwise we get
        # 'kernel longer than line' errors
        # (compare against None: index 0 is a perfectly valid position for z)
        if zIndex is not None and inputShape[zIndex] == 1:
            removedZ = True
            spatialkeys = [k for k in axiskeys if k in 'yx']
            # Keep the remaining axes in their original order
            data = data.reshape(tuple(data.shape[axiskeys.index(k)] for k in spatialkeys))
            reskey[zIndex] = 0
        else:
            removedZ = False

        # A list, not a lazy map(): it is inspected here and handed to vigra afterwards.
        sigma = [self._sigmas.get(k) for k in spatialkeys]

        # Check if we need to smooth
        if any(x < 0.1 for x in sigma):
            if removedZ:
                resultYX = vigra.taggedView(result, axistags="".join(axiskeys))
                resultYX = resultYX.withAxes(*'yx')
                resultYX[:] = data
            else:
                result[:] = data
            return result

        # Smooth the input data
        smoothed = vigra.filters.gaussianSmoothing(
            data, sigma, window_size=2.0, roi=computeRoi, out=result[tuple(reskey)]
        )  # FIXME: Assumes channel is last axis
        expectedShape = tuple(TinyVector(computeRoi[1]) - TinyVector(computeRoi[0]))
        assert tuple(smoothed.shape) == expectedShape, \
            "Smoothed data shape {} didn't match expected shape {}".format(smoothed.shape, expectedShape)
        return result

    def _getInputComputeRois(self, roi):
        """
        Return ``(inputRoi, computeRoi)`` for the given roi.

        ``inputRoi`` is the roi enlarged by the filter halo, as (start, stop) lists in
        input coordinates. ``computeRoi`` is the location of the original roi
        inside that enlarged block, as a pair of int tuples over the smoothed
        (spatial) axes only.
        """
        axiskeys = self.Input.meta.getAxisKeys()
        inputSpatialShape = self.Input.meta.getTaggedShape()
        spatialRoi = (TinyVector(roi.start), TinyVector(roi.stop))

        # Axes that take no part in the smoothing: channel, time and a singleton z.
        # They are removed from both the shape and the roi, and re-inserted below.
        droppedIndices = []
        if 'c' in inputSpatialShape:
            del inputSpatialShape['c']
            droppedIndices.append(axiskeys.index('c'))
        if 't' in inputSpatialShape:
            assert inputSpatialShape['t'] == 1
            del inputSpatialShape['t']
            droppedIndices.append(axiskeys.index('t'))
        if 'z' in inputSpatialShape and inputSpatialShape['z'] == 1:
            # 2D image, avoid kernel longer than line exception
            del inputSpatialShape['z']
            droppedIndices.append(axiskeys.index('z'))

        # One sigma per remaining spatial axis, in axis order
        spatialkeys = [k for k in axiskeys if k in 'zyx' and k in inputSpatialShape]
        sigma = [self._sigmas.get(k) for k in spatialkeys]

        # Pop from the back so that earlier indices stay valid
        for ind in sorted(droppedIndices, reverse=True):
            spatialRoi[0].pop(ind)
            spatialRoi[1].pop(ind)

        inputSpatialRoi = enlargeRoiForHalo(
            spatialRoi[0], spatialRoi[1], list(inputSpatialShape.values()), sigma, window=2.0
        )

        # Determine the roi within the input data we're going to request
        inputRoiOffset = spatialRoi[0] - inputSpatialRoi[0]
        computeRoi = (inputRoiOffset, inputRoiOffset + spatialRoi[1] - spatialRoi[0])

        # For some reason, vigra.filters.gaussianSmoothing will raise an exception if this parameter doesn't have the correct integer type.
        # (for example, if we give it as a numpy.ndarray with dtype=int64, we get an error)
        computeRoi = (tuple(map(int, computeRoi[0])), tuple(map(int, computeRoi[1])))

        # Re-insert the dropped axes (front to back) as the range [0, 1)
        inputRoi = (list(inputSpatialRoi[0]), list(inputSpatialRoi[1]))
        for ind in sorted(droppedIndices):
            inputRoi[0].insert(ind, 0)
            inputRoi[1].insert(ind, 1)

        return inputRoi, computeRoi

    def propagateDirty(self, slot, subindex, roi):
        if slot == self.Input:
            # Halo calculation is bidirectional, so we can re-use the function that computes the halo during execute()
            inputRoi, _ = self._getInputComputeRois(roi)
            self.Output.setDirty(inputRoi[0], inputRoi[1])
        elif slot == self.Sigmas:
            self.Output.setDirty(slice(None))
        else:
            assert False, "Unknown input slot: {}".format(slot.name)
class OpDataSelection(Operator):
    """
    The top-level operator for the data selection applet, implemented as a single-image operator.
    The applet uses an OperatorWrapper to make it suitable for use in a workflow.

    Every time the inputs change, ``setupOutputs`` builds a small chain of internal
    operators (reader -> optional metadata injector -> optional axis re-ordering)
    and connects the output slots to the end of that chain.  All operators in the
    chain are tracked in ``self._opReaders`` so they can be cleaned up again.
    """
    name = "OpDataSelection"
    category = "Top-level"

    SupportedExtensions = OpInputDataReader.SupportedExtensions

    # Inputs
    ProjectFile = InputSlot(stype='object', optional=True)  #: The project hdf5 File object (already opened)
    ProjectDataGroup = InputSlot(stype='string', optional=True)  #: The internal path to the hdf5 group where project-local datasets are stored within the project file
    WorkingDirectory = InputSlot(stype='filestring')  #: The filesystem directory where the project file is located
    Dataset = InputSlot(stype='object')  #: A DatasetInfo object

    # Outputs
    Image = OutputSlot()  #: The output image
    AllowLabels = OutputSlot(stype='bool')  #: A bool indicating whether or not this image can be used for training

    _NonTransposedImage = OutputSlot()  #: The output slot, in the data's original axis ordering (regardless of forceAxisOrder)

    ImageName = OutputSlot(stype='string')  #: The name of the output image

    class InvalidDimensionalityError(Exception):
        """Raised if the user tries to replace the dataset with a new one of differing dimensionality."""

        def __init__(self, message):
            # Forward the message to the base class so that ``args`` (and
            # therefore ``repr()`` and pickling) carry it as well.
            super(OpDataSelection.InvalidDimensionalityError, self).__init__(message)
            self.message = message

        def __str__(self):
            return self.message

    def __init__(self, forceAxisOrder=False, *args, **kwargs):
        """
        :param forceAxisOrder: If truthy, the axis order that ``Image`` is
            re-ordered to in ``setupOutputs`` (it is joined into a string and
            handed to an OpReorderAxes).  If falsy, no forced re-ordering is done.
        :param args: Forwarded to the base-class constructor.
        :param kwargs: Forwarded to the base-class constructor.
        """
        super(OpDataSelection, self).__init__(*args, **kwargs)
        self.forceAxisOrder = forceAxisOrder
        # Internal operators created by setupOutputs(), in creation order.
        self._opReaders = []

        # If the gui calls disconnect() on an input slot without replacing it with something else,
        #  we still need to clean up the internal operator that was providing our data.
        self.ProjectFile.notifyUnready(self.internalCleanup)
        self.ProjectDataGroup.notifyUnready(self.internalCleanup)
        self.WorkingDirectory.notifyUnready(self.internalCleanup)
        self.Dataset.notifyUnready(self.internalCleanup)

    def internalCleanup(self, *args):
        """
        Disconnect the image outputs and clean up all internal operators.

        Does nothing if no internal operators exist.  Accepts (and ignores)
        arbitrary positional arguments so it can be used directly as a slot
        ``notifyUnready`` callback.
        """
        if len(self._opReaders) > 0:
            self.Image.disconnect()
            self._NonTransposedImage.disconnect()
            # Tear the chain down back-to-front (most recently created first).
            for reader in reversed(self._opReaders):
                reader.cleanUp()
            self._opReaders = []

    def setupOutputs(self):
        """
        Rebuild the internal operator chain for the current ``Dataset`` value
        and connect/assign all output slots.

        Every internal operator is appended to ``self._opReaders`` right after
        it is constructed, *before* it is configured.  That way, if configuring
        it raises, the ``internalCleanup()`` call in the exception handler below
        still cleans it up instead of leaving a stray child operator behind.

        :raises DatasetConstraintError: if ``forceAxisOrder`` would drop an axis
            whose extent is greater than 1.
        :raises Exception: if the axistags given in the dataset info do not have
            the same length as the shape of the data.
        """
        self.internalCleanup()
        datasetInfo = self.Dataset.value

        try:
            # Data only comes from the project file if the user said so AND it exists in the project
            datasetInProject = (datasetInfo.location == DatasetInfo.Location.ProjectInternal)
            datasetInProject &= self.ProjectFile.ready()
            if datasetInProject:
                internalPath = self.ProjectDataGroup.value + '/' + datasetInfo.datasetId
                datasetInProject &= internalPath in self.ProjectFile.value

            # If we should find the data in the project file, use a dataset reader
            if datasetInProject:
                opReader = OpStreamingHdf5Reader(parent=self)
                self._opReaders.append(opReader)
                opReader.Hdf5File.setValue(self.ProjectFile.value)
                opReader.InternalPath.setValue(internalPath)
                providerSlot = opReader.OutputImage
            elif datasetInfo.location == DatasetInfo.Location.PreloadedArray:
                preloaded_array = datasetInfo.preloaded_array
                assert preloaded_array is not None
                if not hasattr(preloaded_array, 'axistags'):
                    # Guess the axis order, since one was not provided.
                    axisorders = {2: 'yx',
                                  3: 'zyx',
                                  4: 'zyxc',
                                  5: 'tzyxc'}

                    shape = preloaded_array.shape
                    ndim = preloaded_array.ndim
                    assert ndim != 0, "Support for 0-D data not yet supported"
                    assert ndim != 1, "Support for 1-D data not yet supported"
                    assert ndim <= 5, "No support for data with more than 5 dimensions."

                    axisorder = axisorders[ndim]
                    if ndim == 3 and shape[2] <= 4:
                        # Special case: If the 3rd dim is small, assume it's 'c', not 'z'
                        axisorder = 'yxc'
                    preloaded_array = vigra.taggedView(preloaded_array, axisorder)

                opReader = OpArrayPiper(parent=self)
                self._opReaders.append(opReader)
                opReader.Input.setValue(preloaded_array)
                providerSlot = opReader.Output
            else:
                # Use a normal (filesystem) reader
                opReader = OpInputDataReader(parent=self)
                self._opReaders.append(opReader)
                if datasetInfo.subvolume_roi is not None:
                    opReader.SubVolumeRoi.setValue(datasetInfo.subvolume_roi)
                opReader.WorkingDirectory.setValue(self.WorkingDirectory.value)
                opReader.FilePath.setValue(datasetInfo.filePath)
                providerSlot = opReader.Output

            # Inject metadata if the dataset info specified any.
            # Also, inject if dtype is uint8, which we can reasonably assume has drange (0,255)
            if (datasetInfo.normalizeDisplay is not None
                    or datasetInfo.drange is not None
                    or datasetInfo.axistags is not None
                    or (providerSlot.meta.drange is None
                        and providerSlot.meta.dtype == numpy.uint8)):
                metadata = {}
                if datasetInfo.drange is not None:
                    metadata['drange'] = datasetInfo.drange
                elif providerSlot.meta.dtype == numpy.uint8:
                    # SPECIAL case for uint8 data: Provide a default drange.
                    # The user can always override this herself if she wants.
                    metadata['drange'] = (0, 255)
                if datasetInfo.normalizeDisplay is not None:
                    metadata['normalizeDisplay'] = datasetInfo.normalizeDisplay
                if datasetInfo.axistags is not None:
                    if len(datasetInfo.axistags) != len(providerSlot.meta.shape):
                        raise Exception(
                            "Your dataset's provided axistags ({}) do not have the "
                            "correct dimensionality for your dataset, which has {} dimensions."
                            .format("".join(tag.key for tag in datasetInfo.axistags),
                                    len(providerSlot.meta.shape)))
                    metadata['axistags'] = datasetInfo.axistags
                if datasetInfo.subvolume_roi is not None:
                    metadata['subvolume_roi'] = datasetInfo.subvolume_roi

                    # FIXME: We are overwriting the axistags metadata to intentionally allow
                    #        the user to change our interpretation of which axis is which.
                    #        That's okay, but technically there's a special corner case if
                    #        the user redefines the channel axis index.
                    #        Technically, it invalidates the meaning of meta.ram_usage_per_requested_pixel.
                    #        For most use-cases, that won't really matter, which is why I'm not worrying about it right now.

                opMetadataInjector = OpMetadataInjector(parent=self)
                self._opReaders.append(opMetadataInjector)
                opMetadataInjector.Input.connect(providerSlot)
                opMetadataInjector.Metadata.setValue(metadata)
                providerSlot = opMetadataInjector.Output

            self._NonTransposedImage.connect(providerSlot)

            if self.forceAxisOrder:
                # Before we re-order, make sure no non-singleton
                #  axes would be dropped by the forced order.
                output_order = "".join(self.forceAxisOrder)
                provider_order = "".join(providerSlot.meta.getAxisKeys())
                tagged_provider_shape = providerSlot.meta.getTaggedShape()
                dropped_axes = set(provider_order) - set(output_order)
                if any(tagged_provider_shape[a] > 1 for a in dropped_axes):
                    msg = "The axes of your dataset ({}) are not compatible with the axes used by this workflow ({}). Please fix them."\
                          .format(provider_order, output_order)
                    raise DatasetConstraintError("DataSelection", msg)

                op5 = OpReorderAxes(parent=self)
                self._opReaders.append(op5)
                op5.AxisOrder.setValue(self.forceAxisOrder)
                op5.Input.connect(providerSlot)
                providerSlot = op5.Output

            # If the channel axis is not last (or is missing),
            #  make sure the axes are re-ordered so that channel is last.
            if providerSlot.meta.axistags.index('c') != len(providerSlot.meta.axistags) - 1:
                op5 = OpReorderAxes(parent=self)
                self._opReaders.append(op5)
                # Drop 'c' (if present) and append it at the end.  Build a real
                # list instead of mutating the result of keys(), which is only
                # a list on Python 2.
                keys = [key for key in providerSlot.meta.getTaggedShape().keys()
                        if key != 'c']
                keys.append('c')
                op5.AxisOrder.setValue("".join(keys))
                op5.Input.connect(providerSlot)
                providerSlot = op5.Output

            # Connect our external outputs to the internal operators we chose
            self.Image.connect(providerSlot)

            # Set the image name and usage flag
            self.AllowLabels.setValue(datasetInfo.allowLabels)

            # If the reading operator provides a nickname, use it.
            if self.Image.meta.nickname is not None:
                datasetInfo.nickname = self.Image.meta.nickname

            imageName = datasetInfo.nickname
            if imageName == "":
                imageName = datasetInfo.filePath
            self.ImageName.setValue(imageName)

        except BaseException:
            # Deliberately broad: whatever went wrong, tear down the partially
            # built chain, then let the original exception propagate unchanged.
            self.internalCleanup()
            raise

    def propagateDirty(self, slot, subindex, roi):
        """Nothing to do here: see the comment in the body."""
        # Output slots are directly connected to internal operators
        pass

    @classmethod
    def getInternalDatasets(cls, filePath):
        """Delegate to ``OpInputDataReader.getInternalDatasets(filePath)`` and return its result."""
        return OpInputDataReader.getInternalDatasets(filePath)