def computeJaccardScoresOnCloud(frame, labelImageFilenames, labelImagePaths, labelImageFrameIdToGlobalId, groundTruthFilename, groundTruthPath, groundTruthMinJaccardScore, pluginPaths=['hytra/plugins'], imageProviderPluginName='LocalImageLoader'):
    """
    Compute jaccard scores of all objects in the different segmentations with the ground truth for that frame.
    Returns a dictionary of overlapping GT labels and the score per globalId in that frame,
    as well as a dictionary specifying the matching globalId and score for every GT label
    (as a list ordered by score, best match last).

    Meant to be run in its own process using `concurrent.futures.ProcessPoolExecutor`

    * `frame`: frame number to process
    * `labelImageFilenames` / `labelImagePaths`: parallel lists addressing each segmentation hypothesis
    * `labelImageFrameIdToGlobalId`: dict (filename, frame, objectId) -> globalId
    * `groundTruthFilename` / `groundTruthPath`: where the GT label image is found
    * `groundTruthMinJaccardScore`: minimum score for an object to be accepted as a GT match
    """
    # set up plugin manager (imported lazily so the function is self-contained
    # when shipped to a worker process)
    from hytra.pluginsystem.plugin_manager import TrackingPluginManager
    pluginManager = TrackingPluginManager(pluginPaths=pluginPaths, verbose=False)
    pluginManager.setImageProvider(imageProviderPluginName)

    scores = {}
    gtToGlobalIdMap = {}
    groundTruthLabelImage = pluginManager.getImageProvider().getLabelImageForFrame(
        groundTruthFilename, groundTruthPath, frame)

    for labelImageIndexA in range(len(labelImageFilenames)):
        labelImageA = pluginManager.getImageProvider().getLabelImageForFrame(
            labelImageFilenames[labelImageIndexA], labelImagePaths[labelImageIndexA], frame)

        # check for overlaps - even a 1-pixel overlap is enough to be mutually exclusive!
        for objectIdA in np.unique(labelImageA):
            if objectIdA == 0:
                # 0 is background, not an object
                continue
            globalIdA = labelImageFrameIdToGlobalId[(labelImageFilenames[labelImageIndexA], frame, objectIdA)]
            overlap = groundTruthLabelImage[labelImageA == objectIdA]
            overlappingGtElements = set(np.unique(overlap)) - set([0])

            for gtLabel in overlappingGtElements:
                # compute Jaccard scores: |intersection| / |union|
                intersectingPixels = np.sum(overlap == gtLabel)
                unionPixels = np.sum(np.logical_or(groundTruthLabelImage == gtLabel,
                                                   labelImageA == objectIdA))
                jaccardScore = float(intersectingPixels) / float(unionPixels)

                # append to object's score list
                scores.setdefault(globalIdA, []).append((gtLabel, jaccardScore))

                # store this as GT mapping if there was no better object for this GT label yet
                if jaccardScore > groundTruthMinJaccardScore and \
                    ((frame, gtLabel) not in gtToGlobalIdMap
                     or gtToGlobalIdMap[(frame, gtLabel)][-1][1] < jaccardScore):
                    gtToGlobalIdMap.setdefault((frame, gtLabel), []).append((globalIdA, jaccardScore))

    # sort all gt mappings by ascending jaccard score.
    # FIX: dict.iteritems() was removed in Python 3; .items() behaves the same here
    # and also works on Python 2.
    for _, v in gtToGlobalIdMap.items():
        v.sort(key=lambda x: x[1])

    return frame, scores, gtToGlobalIdMap
def findConflictingHypothesesInSeparateProcess(
        frame,
        labelImageFilenames,
        labelImagePaths,
        labelImageFrameIdToGlobalId,
        pluginPaths=['hytra/plugins'],
        imageProviderPluginName='LocalImageLoader'):
    """
    Look which objects between different segmentation hypotheses (given as different labelImages)
    overlap, and return a dictionary of those overlapping situations.

    Meant to be run in its own process using `concurrent.futures.ProcessPoolExecutor`

    Returns `(frame, overlaps)` where `overlaps` maps globalId -> list of conflicting globalIds.
    """
    # set up plugin manager (imported lazily so this function can be shipped to a worker)
    from hytra.pluginsystem.plugin_manager import TrackingPluginManager
    manager = TrackingPluginManager(pluginPaths=pluginPaths, verbose=False)
    manager.setImageProvider(imageProviderPluginName)
    provider = manager.getImageProvider()

    # conflict dict: key=globalId, value=[list of globalIds it overlaps with]
    conflicts = {}
    numHypotheses = len(labelImageFilenames)

    # compare every unordered pair of segmentation hypotheses exactly once
    for idxA in range(numHypotheses):
        imageA = provider.getLabelImageForFrame(labelImageFilenames[idxA],
                                                labelImagePaths[idxA],
                                                frame)
        for idxB in range(idxA + 1, numHypotheses):
            imageB = provider.getLabelImageForFrame(labelImageFilenames[idxB],
                                                    labelImagePaths[idxB],
                                                    frame)

            # check for overlaps - even a 1-pixel overlap is enough to be mutually exclusive!
            for labelA in np.unique(imageA):
                if labelA == 0:
                    continue  # skip background
                touchedLabels = set(np.unique(imageB[imageA == labelA])) - set([0])
                touchedGlobalIds = [
                    labelImageFrameIdToGlobalId[(labelImageFilenames[idxB], frame, other)]
                    for other in touchedLabels
                ]
                globalA = labelImageFrameIdToGlobalId[(labelImageFilenames[idxA], frame, labelA)]
                # record the conflict symmetrically for both participants
                conflicts.setdefault(globalA, []).extend(touchedGlobalIds)
                for globalB in touchedGlobalIds:
                    conflicts.setdefault(globalB, []).append(globalA)

    return frame, conflicts
def findConflictingHypothesesInSeparateProcess(frame, labelImageFilenames, labelImagePaths, labelImageFrameIdToGlobalId, pluginPaths=['hytra/plugins'], imageProviderPluginName='LocalImageLoader'): """ Look which objects between different segmentation hypotheses (given as different labelImages) overlap, and return a dictionary of those overlapping situations. Meant to be run in its own process using `concurrent.futures.ProcessPoolExecutor` """ # set up plugin manager from hytra.pluginsystem.plugin_manager import TrackingPluginManager pluginManager = TrackingPluginManager(pluginPaths=pluginPaths, verbose=False) pluginManager.setImageProvider(imageProviderPluginName) overlaps = {} # overlap dict: key=globalId, value=[list of globalIds] for labelImageIndexA in range(len(labelImageFilenames)): labelImageA = pluginManager.getImageProvider().getLabelImageForFrame(labelImageFilenames[labelImageIndexA], labelImagePaths[labelImageIndexA], frame) for labelImageIndexB in range(labelImageIndexA + 1, len(labelImageFilenames)): labelImageB = pluginManager.getImageProvider().getLabelImageForFrame(labelImageFilenames[labelImageIndexB], labelImagePaths[labelImageIndexB], frame) # check for overlaps - even a 1-pixel overlap is enough to be mutually exclusive! for objectIdA in np.unique(labelImageA): if objectIdA == 0: continue overlapping = set(np.unique(labelImageB[labelImageA == objectIdA])) - set([0]) overlappingGlobalIds = [labelImageFrameIdToGlobalId[(labelImageFilenames[labelImageIndexB], frame, o)] for o in overlapping] globalIdA = labelImageFrameIdToGlobalId[(labelImageFilenames[labelImageIndexA], frame, objectIdA)] overlaps.setdefault(globalIdA, []).extend(overlappingGlobalIds) for globalIdB in overlappingGlobalIds: overlaps.setdefault(globalIdB, []).append(globalIdA) return frame, overlaps
def writeEvents(timestep, activeLinks, activeDivisions, mergers, detections, fn, labelImagePath, ilpFilename, verbose, pluginPaths):
    """
    Write the tracking events of one `timestep` into the HDF5 event file `fn`.

    * `activeLinks`: list of (from, to) moves in this timestep
    * `activeDivisions`: dict parent -> (child, child)
    * `mergers`: dict objectId -> number of merged objects
    * `detections`: accepted for interface compatibility but not used in this body
    * `labelImagePath` / `ilpFilename`: where to read the segmentation from
    * `verbose` / `pluginPaths`: forwarded to the TrackingPluginManager

    Errors are logged as warnings and swallowed (best-effort export).
    """
    # appearances/disappearances/multi-frame moves are not populated here,
    # so their datasets are never written (guards below check for non-empty)
    dis = []
    app = []
    div = []
    mov = []
    mer = []
    mul = []

    pluginManager = TrackingPluginManager(verbose=verbose, pluginPaths=pluginPaths)
    logging.getLogger('json_result_to_events.py').debug("-- Writing results to {}".format(fn))

    try:
        # convert to ndarray for better indexing
        dis = np.asarray(dis)
        app = np.asarray(app)
        # FIX: dict.iteritems() was removed in Python 3; .items() also works on Python 2
        div = np.asarray([[k, v[0], v[1]] for k, v in activeDivisions.items()])
        mov = np.asarray(activeLinks)
        mer = np.asarray([[k, v] for k, v in mergers.items()])
        mul = np.asarray(mul)

        shape = pluginManager.getImageProvider().getImageShape(ilpFilename, labelImagePath)
        label_img = pluginManager.getImageProvider().getLabelImageForFrame(ilpFilename, labelImagePath, timestep)

        with h5py.File(fn, 'w') as dest_file:
            # write meta fields and copy segmentation from project
            seg = dest_file.create_group('segmentation')
            seg.create_dataset("labels", data=label_img, compression='gzip')
            meta = dest_file.create_group('objects/meta')
            ids = np.unique(label_img)
            ids = ids[ids > 0]
            valid = np.ones(ids.shape)
            meta.create_dataset("id", data=ids, dtype=np.uint32)
            meta.create_dataset("valid", data=valid, dtype=np.uint32)

            tg = dest_file.create_group("tracking")

            # write associations (only non-empty event tables get a dataset)
            if app is not None and len(app) > 0:
                ds = tg.create_dataset("Appearances", data=app, dtype=np.int32)
                ds.attrs["Format"] = "cell label appeared in current file"

            if dis is not None and len(dis) > 0:
                ds = tg.create_dataset("Disappearances", data=dis, dtype=np.int32)
                ds.attrs["Format"] = "cell label disappeared in current file"

            if mov is not None and len(mov) > 0:
                ds = tg.create_dataset("Moves", data=mov, dtype=np.int32)
                ds.attrs["Format"] = "from (previous file), to (current file)"

            if div is not None and len(div) > 0:
                ds = tg.create_dataset("Splits", data=div, dtype=np.int32)
                ds.attrs["Format"] = "ancestor (previous file), descendant (current file), descendant (current file)"

            if mer is not None and len(mer) > 0:
                ds = tg.create_dataset("Mergers", data=mer, dtype=np.int32)
                ds.attrs["Format"] = "descendant (current file), number of objects"

            if mul is not None and len(mul) > 0:
                ds = tg.create_dataset("MultiFrameMoves", data=mul, dtype=np.int32)
                ds.attrs["Format"] = "from (given by timestep), to (current file), timestep"

        logging.getLogger('json_result_to_events.py').debug("-> results successfully written")
    except Exception as e:
        # deliberate best-effort: log and continue instead of aborting the whole export
        logging.getLogger('json_result_to_events.py').warning("ERROR while writing events: {}".format(str(e)))
start = time.time() feature_path = options.feats_path with_div = not bool(options.without_divisions) with_merger_prior = True # get selected time range time_range = [options.mints, options.maxts] if options.maxts == -1 and options.mints == 0: time_range = None try: # find shape of dataset pluginManager = TrackingPluginManager(verbose=options.verbose, pluginPaths=options.pluginPaths) shape = pluginManager.getImageProvider().getImageShape( ilp_fn, options.label_img_path) data_time_range = pluginManager.getImageProvider().getTimeRange( ilp_fn, options.label_img_path) if time_range is not None and time_range[1] < 0: time_range[1] += data_time_range[1] except: logging.warning("Could not read shape and time range from images") shape = None # set average object size if chosen obj_size = [0] if options.avg_obj_size != 0: obj_size[0] = options.avg_obj_size else:
def computeJaccardScoresOnCloud(frame,
                                labelImageFilenames,
                                labelImagePaths,
                                labelImageFrameIdToGlobalId,
                                groundTruthFilename,
                                groundTruthPath,
                                groundTruthMinJaccardScore,
                                pluginPaths=['hytra/plugins'],
                                imageProviderPluginName='LocalImageLoader'):
    """
    Compute jaccard scores of all objects in the different segmentations with the ground truth for that frame.
    Returns a dictionary of overlapping GT labels and the score per globalId in that frame,
    as well as a dictionary specifying the matching globalId and score for every GT label
    (as a list ordered by score, best match last).

    Meant to be run in its own process using `concurrent.futures.ProcessPoolExecutor`

    * `frame`: frame number to process
    * `labelImageFilenames` / `labelImagePaths`: parallel lists addressing each segmentation hypothesis
    * `labelImageFrameIdToGlobalId`: dict (filename, frame, objectId) -> globalId
    * `groundTruthFilename` / `groundTruthPath`: where the GT label image is found
    * `groundTruthMinJaccardScore`: minimum score for an object to be accepted as a GT match
    """
    # set up plugin manager (imported lazily so the function is self-contained on a worker)
    from hytra.pluginsystem.plugin_manager import TrackingPluginManager
    pluginManager = TrackingPluginManager(pluginPaths=pluginPaths, verbose=False)
    pluginManager.setImageProvider(imageProviderPluginName)

    scores = {}
    gtToGlobalIdMap = {}
    groundTruthLabelImage = pluginManager.getImageProvider(
    ).getLabelImageForFrame(groundTruthFilename, groundTruthPath, frame)

    for labelImageIndexA in range(len(labelImageFilenames)):
        labelImageA = pluginManager.getImageProvider().getLabelImageForFrame(
            labelImageFilenames[labelImageIndexA],
            labelImagePaths[labelImageIndexA], frame)

        # check for overlaps - even a 1-pixel overlap is enough to be mutually exclusive!
        for objectIdA in np.unique(labelImageA):
            if objectIdA == 0:
                # 0 is background, not an object
                continue
            globalIdA = labelImageFrameIdToGlobalId[(
                labelImageFilenames[labelImageIndexA], frame, objectIdA)]
            overlap = groundTruthLabelImage[labelImageA == objectIdA]
            overlappingGtElements = set(np.unique(overlap)) - set([0])

            for gtLabel in overlappingGtElements:
                # compute Jaccard scores: |intersection| / |union|
                intersectingPixels = np.sum(overlap == gtLabel)
                unionPixels = np.sum(
                    np.logical_or(groundTruthLabelImage == gtLabel,
                                  labelImageA == objectIdA))
                jaccardScore = float(intersectingPixels) / float(unionPixels)

                # append to object's score list
                scores.setdefault(globalIdA, []).append(
                    (gtLabel, jaccardScore))

                # store this as GT mapping if there was no better object for this GT label yet
                if jaccardScore > groundTruthMinJaccardScore and \
                    ((frame, gtLabel) not in gtToGlobalIdMap
                     or gtToGlobalIdMap[(frame, gtLabel)][-1][1] < jaccardScore):
                    gtToGlobalIdMap.setdefault((frame, gtLabel), []).append(
                        (globalIdA, jaccardScore))

    # sort all gt mappings by ascending jaccard score.
    # FIX: dict.iteritems() was removed in Python 3; .items() behaves the same here
    # and also works on Python 2.
    for _, v in gtToGlobalIdMap.items():
        v.sort(key=lambda x: x[1])

    return frame, scores, gtToGlobalIdMap
# jumping over time frames, so creating if trackId in gapTrackParents.keys(): if gapTrackParents[trackId] != trackId: parent = gapTrackParents[trackId] getLogger().info( "Jumping over one time frame in this link: trackid: {}, parent: {}, time: {}" .format(trackId, parent, min(timestepList))) trackDict[trackId] = [parent, min(timestepList), max(timestepList)] save_tracks(trackDict, args) # load images, relabel, and export relabeled result getLogger().debug("Saving relabeled images") pluginManager = TrackingPluginManager(verbose=args.verbose, pluginPaths=args.pluginPaths) pluginManager.setImageProvider('LocalImageLoader') imageProvider = pluginManager.getImageProvider() timeRange = imageProvider.getTimeRange(args.label_image_filename, args.label_image_path) for timeframe in range(timeRange[0], timeRange[1]): label_image = imageProvider.getLabelImageForFrame( args.label_image_filename, args.label_image_path, timeframe) # check if frame is empty if timeframe in mappings.keys(): remapped_label_image = remap_label_image(label_image, mappings[timeframe]) save_frame_to_tif(timeframe, remapped_label_image, args) else: save_frame_to_tif(timeframe, label_image, args)
def writeEvents(timestep, activeLinks, activeDivisions, mergers, detections,
                fn, labelImagePath, ilpFilename, verbose, pluginPaths):
    """
    Write the tracking events of one `timestep` into the HDF5 event file `fn`.

    * `activeLinks`: list of (from, to) moves in this timestep
    * `activeDivisions`: dict parent -> (child, child)
    * `mergers`: dict objectId -> number of merged objects
    * `detections`: accepted for interface compatibility but not used in this body
    * `labelImagePath` / `ilpFilename`: where to read the segmentation from
    * `verbose` / `pluginPaths`: forwarded to the TrackingPluginManager

    Errors are logged as warnings and swallowed (best-effort export).
    """
    # appearances/disappearances/multi-frame moves are never populated below,
    # so their datasets are not written (the non-empty guards skip them)
    dis = []
    app = []
    div = []
    mov = []
    mer = []
    mul = []

    pluginManager = TrackingPluginManager(verbose=verbose,
                                          pluginPaths=pluginPaths)

    logging.getLogger('json_result_to_events.py').debug(
        "-- Writing results to {}".format(fn))
    try:
        # convert to ndarray for better indexing
        dis = np.asarray(dis)
        app = np.asarray(app)
        div = np.asarray([[k, v[0], v[1]] for k, v in activeDivisions.items()])
        mov = np.asarray(activeLinks)
        mer = np.asarray([[k, v] for k, v in mergers.items()])
        mul = np.asarray(mul)

        # shape is read but unused afterwards; presumably kept for its side-effect-free
        # sanity check of the input file -- TODO confirm before removing
        shape = pluginManager.getImageProvider().getImageShape(
            ilpFilename, labelImagePath)
        label_img = pluginManager.getImageProvider().getLabelImageForFrame(
            ilpFilename, labelImagePath, timestep)

        with h5py.File(fn, 'w') as dest_file:
            # write meta fields and copy segmentation from project
            seg = dest_file.create_group('segmentation')
            seg.create_dataset("labels", data=label_img, compression='gzip')
            meta = dest_file.create_group('objects/meta')
            ids = np.unique(label_img)
            ids = ids[ids > 0]
            valid = np.ones(ids.shape)
            meta.create_dataset("id", data=ids, dtype=np.uint32)
            meta.create_dataset("valid", data=valid, dtype=np.uint32)

            tg = dest_file.create_group("tracking")

            # write associations (only non-empty event tables get a dataset)
            if app is not None and len(app) > 0:
                ds = tg.create_dataset("Appearances", data=app, dtype=np.int32)
                ds.attrs["Format"] = "cell label appeared in current file"

            if dis is not None and len(dis) > 0:
                ds = tg.create_dataset("Disappearances",
                                       data=dis,
                                       dtype=np.int32)
                ds.attrs["Format"] = "cell label disappeared in current file"

            if mov is not None and len(mov) > 0:
                ds = tg.create_dataset("Moves", data=mov, dtype=np.int32)
                ds.attrs["Format"] = "from (previous file), to (current file)"

            if div is not None and len(div) > 0:
                ds = tg.create_dataset("Splits", data=div, dtype=np.int32)
                ds.attrs[
                    "Format"] = "ancestor (previous file), descendant (current file), descendant (current file)"

            if mer is not None and len(mer) > 0:
                ds = tg.create_dataset("Mergers", data=mer, dtype=np.int32)
                ds.attrs[
                    "Format"] = "descendant (current file), number of objects"

            if mul is not None and len(mul) > 0:
                ds = tg.create_dataset("MultiFrameMoves",
                                       data=mul,
                                       dtype=np.int32)
                ds.attrs[
                    "Format"] = "from (given by timestep), to (current file), timestep"

        logging.getLogger('json_result_to_events.py').debug(
            "-> results successfully written")
    except Exception as e:
        # deliberate best-effort: log and continue instead of aborting the whole export
        logging.getLogger('json_result_to_events.py').warning(
            "ERROR while writing events: {}".format(str(e)))
def run_pipeline(options):
    """
    Run the complete tracking pipeline with competing segmentation hypotheses.

    Steps (as implemented below): build or load the hypotheses graph, optionally map
    ground truth onto it to learn weights, run tracking, insert the solution and
    compute lineages, then write a track text file and relabeled tif frames.
    `options` is the parsed command-line namespace consumed throughout.
    """

    # set up probabilitygenerator (aka traxelstore) and hypothesesgraph or load them from a dump
    if options.load_graph_filename is not None:
        getLogger().info("Loading state from file: " +
                         options.load_graph_filename)
        # the pickles must be loaded in the same order they were dumped below
        with gzip.open(options.load_graph_filename, 'r') as graphDump:
            ilpOptions = pickle.load(graphDump)
            probGenerator = pickle.load(graphDump)
            fieldOfView = pickle.load(graphDump)
            hypotheses_graph = pickle.load(graphDump)
            trackingGraph = pickle.load(graphDump)
        getLogger().info("Done loading state from file")
    else:
        fieldOfView, hypotheses_graph, ilpOptions, probGenerator, trackingGraph = setupGraph(options)

        if options.dump_graph_filename is not None:
            getLogger().info("Saving state to file: " +
                             options.dump_graph_filename)
            with gzip.open(options.dump_graph_filename, 'w') as graphDump:
                pickle.dump(ilpOptions, graphDump)
                pickle.dump(probGenerator, graphDump)
                pickle.dump(fieldOfView, graphDump)
                pickle.dump(hypotheses_graph, graphDump)
                pickle.dump(trackingGraph, graphDump)
            getLogger().info("Done saving state to file")

    # map groundtruth to hypothesesgraph if all required variables are specified
    weights = None
    if options.gt_label_image_file is not None and options.gt_label_image_path is not None \
        and options.gt_text_file is not None and options.gt_jaccard_threshold is not None:
        weights = mapGroundTruth(options, hypotheses_graph, trackingGraph,
                                 probGenerator)

    # track
    result = runTracking(options, trackingGraph, weights)

    # insert the solution into the hypotheses graph and from that deduce the lineages
    getLogger().info("Inserting solution into graph")
    hypotheses_graph.insertSolution(result)
    hypotheses_graph.computeLineage()

    mappings = {} # dictionary over timeframes, containing another dict objectId -> trackId per frame
    tracks = {} # stores a list of timeframes per track, so that we can find from<->to per track
    trackParents = {} # store the parent trackID of a track if known

    # NOTE(review): uses the pre-networkx-2.4 `.node` attribute on the wrapped graph --
    # confirm the pinned networkx version before upgrading.
    for n in hypotheses_graph.nodeIterator():
        frameMapping = mappings.setdefault(n[0], {})
        if 'trackId' not in hypotheses_graph._graph.node[n]:
            raise ValueError("You need to compute the Lineage of every node before accessing the trackId!")
        trackId = hypotheses_graph._graph.node[n]['trackId']
        traxel = hypotheses_graph._graph.node[n]['traxel']
        if trackId is not None:
            frameMapping[(traxel.idInSegmentation, traxel.segmentationFilename)] = trackId
        if trackId in tracks:
            tracks[trackId].append(n[0])
        else:
            tracks[trackId] = [n[0]]
        if 'parent' in hypotheses_graph._graph.node[n]:
            # a track may only be assigned one parent
            assert(trackId not in trackParents)
            trackParents[trackId] = hypotheses_graph._graph.node[hypotheses_graph._graph.node[n]['parent']]['trackId']

    # write res_track.txt
    getLogger().info("Writing track text file")
    trackDict = {}
    for trackId, timestepList in tracks.items():
        timestepList.sort()
        try:
            parent = trackParents[trackId]
        except KeyError:
            # tracks without a recorded parent get parent id 0
            parent = 0
        trackDict[trackId] = [parent, min(timestepList), max(timestepList)]
    save_tracks(trackDict, options)

    # export results: load images, relabel them by track id, and save as tif
    getLogger().info("Saving relabeled images")
    pluginManager = TrackingPluginManager(verbose=options.verbose,
                                          pluginPaths=options.pluginPaths)
    pluginManager.setImageProvider('LocalImageLoader')
    imageProvider = pluginManager.getImageProvider()
    timeRange = probGenerator.timeRange

    for timeframe in range(timeRange[0], timeRange[1]):
        label_images = {}
        for f, p in zip(options.label_image_files, options.label_image_paths):
            label_images[f] = imageProvider.getLabelImageForFrame(f, p, timeframe)
        remapped_label_image = remap_label_image(label_images, mappings[timeframe])
        save_frame_to_tif(timeframe, remapped_label_image, options)
def computeRegionFeaturesOnCloud(frame,
                                 rawImageFilename,
                                 rawImagePath,
                                 rawImageAxes,
                                 labelImageFilename,
                                 labelImagePath,
                                 turnOffFeatures,
                                 pluginPaths=['hytra/plugins'],
                                 featuresPerFrame=None,
                                 imageProviderPluginName='LocalImageLoader',
                                 featureSerializerPluginName='LocalFeatureSerializer'):
    '''
    Allow to use dispy to schedule feature computation to nodes running a dispynode,
    or to use multiprocessing.

    **Parameters**

    * `frame`: the frame number
    * `rawImageFilename`: the base filename of the raw image volume, or a dvid server address
    * `rawImagePath`: path inside the raw image HDF5 file, or DVID dataset UUID
    * `rawImageAxes`: axes configuration of the raw data
    * `labelImageFilename`: the base filename of the label image volume, or a dvid server address
    * `labelImagePath`: path inside the label image HDF5 file, or DVID dataset UUID
    * `pluginPaths`: where all yapsy plugins are stored (should be absolute for DVID)

    **returns** the feature dictionary for this frame if
    `featureSerializerPluginName == 'LocalFeatureSerializer'` and `featuresPerFrame == None`.
    '''

    # set up plugin manager (local import so the function is self-contained on a worker)
    from hytra.pluginsystem.plugin_manager import TrackingPluginManager
    pluginManager = TrackingPluginManager(pluginPaths=pluginPaths,
                                          turnOffFeatures=turnOffFeatures,
                                          verbose=False)
    pluginManager.setImageProvider(imageProviderPluginName)
    pluginManager.setFeatureSerializer(featureSerializerPluginName)

    # load raw and label image (depending on chosen plugin this works via DVID or locally)
    rawImage = pluginManager.getImageProvider().getImageDataAtTimeFrame(
        rawImageFilename, rawImagePath, rawImageAxes, frame)
    labelImage = pluginManager.getImageProvider().getLabelImageForFrame(
        labelImageFilename, labelImagePath, frame)

    # untwist axes, if just x and y are messed up
    if rawImage.shape[0] == labelImage.shape[1] and rawImage.shape[1] == labelImage.shape[0]:
        labelImage = np.transpose(labelImage, axes=[1, 0])

    # compute features
    moreFeats, ignoreNames = pluginManager.applyObjectFeatureComputationPlugins(
        len(labelImage.shape), rawImage, labelImage, frame, rawImageFilename)

    # combine into one dictionary
    # WARNING: if there are multiple features with the same name, they will be overwritten!
    frameFeatureItems = []
    for f in moreFeats:
        # FIX: wrap in list() -- on Python 3, dict.items() returns a view that
        # cannot be concatenated to a list with `+`.
        frameFeatureItems = frameFeatureItems + list(f.items())
    frameFeatures = dict(frameFeatureItems)

    # delete all ignored features
    for k in ignoreNames:
        if k in frameFeatures.keys():
            del frameFeatures[k]

    # return or save features
    # FIX: compare strings with `==`, not `is` -- identity of equal string literals
    # is an interpreter implementation detail and must not be relied upon.
    if featuresPerFrame is None and featureSerializerPluginName == 'LocalFeatureSerializer':
        # simply return resulting dict
        return frame, frameFeatures
    else:
        # set up feature serializer (local or DVID for now)
        featureSerializer = pluginManager.getFeatureSerializer()

        # server address and uuid are only used by the DVID serializer
        featureSerializer.server_address = labelImageFilename
        featureSerializer.uuid = labelImagePath

        # feature dictionary used by local serializer
        featureSerializer.features_per_frame = featuresPerFrame

        # store
        featureSerializer.storeFeaturesForFrame(frameFeatures, frame)
class IlpProbabilityGenerator(ProbabilityGenerator):
    """
    The IlpProbabilityGenerator is a python wrapper around pgmlink's C++ traxelstore,
    but with the functionality to compute all region features
    and evaluate the division/count/transition classifiers.

    Python-3 compatibility fixes applied (all also valid on Python 2):
    `iteritems()` -> `items()`, `print x` -> `print(x)`,
    subscripting `dict.values()` via `list(...)`, and initializing
    `featuresPerFrame` in the dispy branch of `_extractAllFeatures`
    (it was previously unbound at the final `return`).
    """

    def __init__(self,
                 ilpOptions,
                 turnOffFeatures=[],
                 useMultiprocessing=True,
                 pluginPaths=['hytra/plugins'],
                 verbose=False):
        self._useMultiprocessing = useMultiprocessing
        self._options = ilpOptions
        self._pluginPaths = pluginPaths
        self._pluginManager = TrackingPluginManager(
            turnOffFeatures=turnOffFeatures,
            verbose=verbose,
            pluginPaths=pluginPaths)
        self._pluginManager.setImageProvider(ilpOptions.imageProviderName)
        self._pluginManager.setFeatureSerializer(
            ilpOptions.featureSerializerName)

        # classifiers are loaded lazily from the paths in ilpOptions (see _loadClassifiers)
        self._countClassifier = None
        self._divisionClassifier = None
        self._transitionClassifier = None

        self._loadClassifiers()

        self.shape, self.timeRange = self._getShapeAndTimeRange()

        # set default division feature names
        self._divisionFeatureNames = [
            'ParentChildrenRatio_Count', 'ParentChildrenRatio_Mean',
            'ChildrenRatio_Count', 'ChildrenRatio_Mean',
            'ParentChildrenAngle_RegionCenter',
            'ChildrenRatio_SquaredDistances'
        ]

        # other parameters that one might want to set
        self.x_scale = 1.0
        self.y_scale = 1.0
        self.z_scale = 1.0
        self.divisionProbabilityFeatureName = 'divProb'
        self.detectionProbabilityFeatureName = 'detProb'

        # this public variable contains all traxels if we're not using pgmlink
        self.TraxelsPerFrame = {}

    def _loadClassifiers(self):
        # each classifier is only instantiated when both its path and filename are set
        if self._options.objectCountClassifierPath != None and self._options.objectCountClassifierFilename != None:
            self._countClassifier = RandomForestClassifier(
                self._options.objectCountClassifierPath,
                self._options.objectCountClassifierFilename, self._options)
        if self._options.divisionClassifierPath != None and self._options.divisionClassifierFilename != None:
            self._divisionClassifier = RandomForestClassifier(
                self._options.divisionClassifierPath,
                self._options.divisionClassifierFilename, self._options)
        if self._options.transitionClassifierPath != None and self._options.transitionClassifierFilename != None:
            self._transitionClassifier = RandomForestClassifier(
                self._options.transitionClassifierPath,
                self._options.transitionClassifierFilename, self._options)

    def __getstate__(self):
        '''
        We define __getstate__ and __setstate__ to exclude the random forests from being pickled, as that is not allowed.
        See https://docs.python.org/3/library/pickle.html#pickle-state for more details.
        '''
        # Copy the object's state from self.__dict__ which contains
        # all our instance attributes. Always use the dict.copy()
        # method to avoid modifying the original state.
        state = self.__dict__.copy()
        # Remove the unpicklable entries.
        del state['_countClassifier']
        del state['_divisionClassifier']
        del state['_transitionClassifier']
        return state

    def __setstate__(self, state):
        # Restore instance attributes
        self.__dict__.update(state)
        # Restore the random forests by reading them from scratch
        self._loadClassifiers()

    def computeRegionFeatures(self, rawImage, labelImage, frameNumber):
        """
        Computes all region features for all objects in the given image
        """
        assert (labelImage.dtype == np.uint32)
        moreFeats, ignoreNames = self._pluginManager.applyObjectFeatureComputationPlugins(
            len(labelImage.shape), rawImage, labelImage, frameNumber,
            self._options.rawImageFilename)
        frameFeatureItems = []
        for f in moreFeats:
            # FIX: list() needed on Python 3, where items() is a non-concatenable view
            frameFeatureItems = frameFeatureItems + list(f.items())
        frameFeatures = dict(frameFeatureItems)

        # delete the "Global<Min/Max>" features as they are not nice when iterating over everything
        for k in ignoreNames:
            if k in frameFeatures.keys():
                del frameFeatures[k]

        return frameFeatures

    def computeDivisionFeatures(self, featuresAtT, featuresAtTPlus1,
                                labelImageAtTPlus1):
        """
        Computes the division features for all objects in the images
        """
        fm = hytra.core.divisionfeatures.FeatureManager(
            ndim=self.getNumDimensions())
        return fm.computeFeatures_at(featuresAtT, featuresAtTPlus1,
                                     labelImageAtTPlus1,
                                     self._divisionFeatureNames)

    def setDivisionFeatures(self, divisionFeatures):
        """
        Set which features should be computed explicitly for divisions by giving a list of strings.
        Each string could be a combination of <operation>_<feature>, where Operation is one of:
         * ParentIdentity
         * SquaredDistances
         * ChildrenRatio
         * ParentChildrenAngle
         * ParentChildrenRatio
        And <feature> is any region feature plus "SquaredDistances"
        """
        # TODO: check that the strings are valid?
        self._divisionFeatureNames = divisionFeatures

    def getNumDimensions(self):
        """
        Compute the number of dimensions which is the number of axis with more than 1 element
        """
        return np.count_nonzero(np.array(self.shape) != 1)

    def _getShapeAndTimeRange(self):
        """
        extract the shape from the labelimage
        """
        shape = self._pluginManager.getImageProvider().getImageShape(
            self._options.labelImageFilename, self._options.labelImagePath)
        timerange = self._pluginManager.getImageProvider().getTimeRange(
            self._options.labelImageFilename, self._options.labelImagePath)
        return shape, timerange

    def getLabelImageForFrame(self, timeframe):
        """
        Get the label image(volume) of one time frame
        """
        rawImage = self._pluginManager.getImageProvider(
        ).getLabelImageForFrame(self._options.labelImageFilename,
                                self._options.labelImagePath, timeframe)
        return rawImage

    def getRawImageForFrame(self, timeframe):
        """
        Get the raw image(volume) of one time frame
        """
        rawImage = self._pluginManager.getImageProvider(
        ).getImageDataAtTimeFrame(self._options.rawImageFilename,
                                  self._options.rawImagePath, timeframe)
        return rawImage

    def _extractFeaturesForFrame(self, timeframe):
        """
        extract the features of one frame, return a dictionary of features,
        where each feature vector contains N entries per object
        (where N is the dimensionality of the feature)
        """
        rawImage = self.getRawImageForFrame(timeframe)
        labelImage = self.getLabelImageForFrame(timeframe)

        return timeframe, self.computeRegionFeatures(rawImage, labelImage,
                                                     timeframe)

    def _extractDivisionFeaturesForFrame(self, timeframe, featuresPerFrame):
        """
        extract Division Features for one frame, and store them in the given featuresPerFrame dict
        """
        feats = {}
        if timeframe + 1 < self.timeRange[1]:
            labelImageAtTPlus1 = self.getLabelImageForFrame(timeframe + 1)
            feats = self.computeDivisionFeatures(
                featuresPerFrame[timeframe], featuresPerFrame[timeframe + 1],
                labelImageAtTPlus1)

        return timeframe, feats

    def _extractAllFeatures(self, dispyNodeIps=[], turnOffFeatures=[]):
        """
        Extract the features of all frames.

        If a list of IP addresses is given e.g. as
        `dispyNodeIps = ["104.197.178.206","104.196.46.138"]`, then the computation will be
        distributed across these nodes. Otherwise, multiprocessing will be used if
        `self._useMultiprocessing=True`, which it is by default.

        If `dispyNodeIps` is an empty list, then the feature extraction will be parallelized
        via multiprocessing.

        **TODO:** fix division feature computation for distributed mode
        """
        import logging

        # configure progress bar: a second pass (division features) doubles the step count
        numSteps = self.timeRange[1] - self.timeRange[0]
        if self._divisionClassifier is not None:
            numSteps *= 2

        t0 = time.time()

        # FIX: initialize before branching so the final `return` cannot hit a
        # NameError in the (work-in-progress) dispy branch below.
        featuresPerFrame = {}

        if (len(dispyNodeIps) == 0):
            # no dispy node IDs given, parallelize object feature computation via processes

            if self._useMultiprocessing:
                # use ProcessPoolExecutor, which instanciates as many processes as there CPU cores by default
                ExecutorType = concurrent.futures.ProcessPoolExecutor
                logging.getLogger('Traxelstore').info(
                    'Parallelizing feature extraction via multiprocessing on all cores!'
                )
            else:
                ExecutorType = DummyExecutor
                logging.getLogger('Traxelstore').info(
                    'Running feature extraction on single core!')

            progressBar = ProgressBar(stop=numSteps)
            progressBar.show(increase=0)

            with ExecutorType() as executor:
                # 1st pass for region features
                jobs = []
                for frame in range(self.timeRange[0], self.timeRange[1]):
                    jobs.append(
                        executor.submit(computeRegionFeaturesOnCloud, frame,
                                        self._options.rawImageFilename,
                                        self._options.rawImagePath,
                                        self._options.rawImageAxes,
                                        self._options.labelImageFilename,
                                        self._options.labelImagePath,
                                        turnOffFeatures, self._pluginPaths))
                for job in concurrent.futures.as_completed(jobs):
                    progressBar.show()
                    frame, feats = job.result()
                    featuresPerFrame[frame] = feats

                # 2nd pass for division features (needs consecutive frame pairs,
                # hence the `- 1` on the upper bound)
                if self._divisionClassifier is not None:
                    jobs = []
                    for frame in range(self.timeRange[0],
                                       self.timeRange[1] - 1):
                        jobs.append(
                            executor.submit(
                                computeDivisionFeaturesOnCloud, frame,
                                featuresPerFrame[frame],
                                featuresPerFrame[frame + 1],
                                self._pluginManager.getImageProvider(),
                                self._options.labelImageFilename,
                                self._options.labelImagePath,
                                self.getNumDimensions(),
                                self._divisionFeatureNames))

                    for job in concurrent.futures.as_completed(jobs):
                        progressBar.show()
                        frame, feats = job.result()
                        featuresPerFrame[frame].update(feats)

            # # serialize features??
            # for frame in range(self.timeRange[0], self.timeRange[1]):
            #     featureSerializer.storeFeaturesForFrame(featuresPerFrame[frame], frame)
        else:
            logging.getLogger('Traxelstore').warning(
                'Parallelization with dispy is WORK IN PROGRESS!')
            import dispy
            cluster = dispy.JobCluster(computeRegionFeaturesOnCloud,
                                       nodes=dispyNodeIps,
                                       loglevel=logging.DEBUG,
                                       depends=[self._pluginManager],
                                       secret="teamtracking")
            jobs = []
            for frame in range(self.timeRange[0], self.timeRange[1]):
                job = cluster.submit(
                    frame,
                    self._options.rawImageFilename,
                    self._options.rawImagePath,
                    self._options.rawImageAxes,
                    self._options.labelImageFilename,
                    self._options.labelImagePath,
                    turnOffFeatures,
                    pluginPaths=['/home/carstenhaubold/embryonic/plugins'])
                job.id = frame
                jobs.append(job)
            for job in jobs:
                job()  # wait for job to finish
                # FIX: print statements converted to function calls (Py3 syntax,
                # also valid on Py2)
                print(job.exception)
                print(job.stdout)
                print(job.stderr)
                print(job.id)
            logging.getLogger('Traxelstore').warning(
                'Using dispy we cannot compute division features yet!')
            # # 2nd pass for division features
            # if self._divisionClassifier is not None:
            #     for frame in range(self.timeRange[0], self.timeRange[1]):
            #         progressBar.show()
            #         featuresPerFrame[frame].update(self._extractDivisionFeaturesForFrame(frame, featuresPerFrame)[1])

        t1 = time.time()
        getLogger().info("Feature computation took {} secs".format(t1 - t0))

        return featuresPerFrame

    def _setTraxelFeatureArray(self, traxel, featureArray, name):
        '''
        store the specified `featureArray` in a `traxel`'s feature dictionary under the specified key=`name`
        '''
        if isinstance(featureArray, np.ndarray):
            featureArray = featureArray.flatten()
        traxel.add_feature_array(name, len(featureArray))
        for i, v in enumerate(featureArray):
            traxel.set_feature_value(name, i, float(v))

    def fillTraxels(self,
                    usePgmlink=True,
                    ts=None,
                    fs=None,
                    dispyNodeIps=[],
                    turnOffFeatures=[]):
        """
        Compute all the features and predict object count as well as division probabilities.
        Store the resulting information (and all other features) in the given pgmlink::TraxelStore,
        or create a new one if ts=None.

        usePgmlink: boolean whether pgmlink should be used and a pgmlink.TraxelStore and pgmlink.FeatureStore returned
        ts: an initial pgmlink.TraxelStore (only used if usePgmlink=True)
        fs: an initial pgmlink.FeatureStore (only used if usePgmlink=True)

        returns (ts, fs) but only if usePgmlink=True, otherwise it fills self.TraxelsPerFrame
        """
        if usePgmlink:
            import pgmlink
            if ts is None:
                ts = pgmlink.TraxelStore()
                fs = pgmlink.FeatureStore()
            else:
                assert (fs is not None)

        getLogger().info("Extracting features...")
        self._featuresPerFrame = self._extractAllFeatures(
            dispyNodeIps=dispyNodeIps, turnOffFeatures=turnOffFeatures)

        getLogger().info("Creating traxels...")
        progressBar = ProgressBar(stop=len(self._featuresPerFrame))
        progressBar.show(increase=0)

        # FIX: iteritems() -> items() throughout (Python 3 compatibility)
        for frame, features in self._featuresPerFrame.items():
            # predict random forests
            if self._countClassifier is not None:
                objectCountProbabilities = self._countClassifier.predictProbabilities(
                    features=None, featureDict=features)

            if self._divisionClassifier is not None and frame + 1 < self.timeRange[
                    1]:
                divisionProbabilities = self._divisionClassifier.predictProbabilities(
                    features=None, featureDict=features)

            # create traxels for all objects
            # FIX: dict.values() is not subscriptable on Python 3 -> wrap in list()
            for objectId in range(1, list(features.values())[0].shape[0]):
                # print("Frame {} Object {}".format(frame, objectId))
                pixelSize = features['Count'][objectId]
                if pixelSize == 0 or (self._options.sizeFilter is not None \
                        and (pixelSize < self._options.sizeFilter[0] \
                             or pixelSize > self._options.sizeFilter[1])):
                    # skip objects outside the configured size range
                    continue

                # create traxel
                if usePgmlink:
                    traxel = pgmlink.Traxel()
                else:
                    traxel = Traxel()
                traxel.Id = objectId
                traxel.Timestep = frame

                # add raw features
                for key, val in features.items():
                    if key == 'id':
                        traxel.idInSegmentation = val[objectId]
                    elif key == 'filename':
                        traxel.segmentationFilename = val[objectId]
                    else:
                        try:
                            if isinstance(val, list):  # polygon feature returns a list!
                                featureValues = val[objectId]
                            else:
                                featureValues = val[objectId, ...]
                        except:
                            getLogger().error(
                                "Could not get feature values of {} for key {} from matrix with shape {}"
                                .format(objectId, key, val.shape))
                            raise AssertionError()
                        try:
                            self._setTraxelFeatureArray(
                                traxel, featureValues, key)
                            if key == 'RegionCenter':
                                # also expose the region center under pgmlink's 'com' key
                                self._setTraxelFeatureArray(
                                    traxel, featureValues, 'com')
                        except:
                            getLogger().error(
                                "Could not add feature array {} for {}".format(
                                    featureValues, key))
                            raise AssertionError()

                # add random forest predictions
                if self._countClassifier is not None:
                    self._setTraxelFeatureArray(
                        traxel, objectCountProbabilities[objectId, :],
                        self.detectionProbabilityFeatureName)

                if self._divisionClassifier is not None and frame + 1 < self.timeRange[
                        1]:
                    self._setTraxelFeatureArray(
                        traxel, divisionProbabilities[objectId, :],
                        self.divisionProbabilityFeatureName)

                # set other parameters
                traxel.set_x_scale(self.x_scale)
                traxel.set_y_scale(self.y_scale)
                traxel.set_z_scale(self.z_scale)

                if usePgmlink:
                    # add to pgmlink's traxelstore
                    ts.add(fs, traxel)
                else:
                    self.TraxelsPerFrame.setdefault(frame,
                                                    {})[objectId] = traxel
            progressBar.show()
        if usePgmlink:
            return ts, fs

    def getTraxelFeatureDict(self, frame, objectId):
        """
        Getter method for features per traxel
        """
        assert self._featuresPerFrame != None
        traxelFeatureDict = {}
        for k, v in self._featuresPerFrame[frame].items():
            if 'Polygon' in k:
                # polygon features store one entry per object in a list
                traxelFeatureDict[k] = v[objectId]
            else:
                traxelFeatureDict[k] = v[objectId, ...]
        return traxelFeatureDict

    def getTransitionFeatureVector(self, featureDictObjectA,
                                   featureDictObjectB, selectedFeatures):
        """
        Return component wise difference and product of the selected features as input for the TransitionClassifier
        """
        features = np.array(
            self._pluginManager.
            applyTransitionFeatureVectorConstructionPlugins(
                featureDictObjectA, featureDictObjectB, selectedFeatures))
        features = np.expand_dims(features, axis=0)
        return features
def run_pipeline(options):
    """
    Run the complete tracking pipeline with competing segmentation hypotheses.

    Steps: build (or load from a gzip'd pickle dump) the probability
    generator, field of view and hypotheses graph; optionally map ground
    truth onto the graph to obtain weights; run the tracker; insert the
    solution and compute lineages; write the track text file; and export
    the relabeled label images frame by frame.
    """

    # set up probabilitygenerator (aka traxelstore) and hypothesesgraph or
    # load them from a dump
    if options.load_graph_filename is not None:
        getLogger().info("Loading state from file: " + options.load_graph_filename)
        with gzip.open(options.load_graph_filename, 'r') as graphDump:
            # load order must mirror the dump order below
            ilpOptions = pickle.load(graphDump)
            probGenerator = pickle.load(graphDump)
            fieldOfView = pickle.load(graphDump)
            hypotheses_graph = pickle.load(graphDump)
            trackingGraph = pickle.load(graphDump)
        getLogger().info("Done loading state from file")
    else:
        fieldOfView, hypotheses_graph, ilpOptions, probGenerator, trackingGraph = setupGraph(options)

        if options.dump_graph_filename is not None:
            getLogger().info("Saving state to file: " + options.dump_graph_filename)
            with gzip.open(options.dump_graph_filename, 'w') as graphDump:
                pickle.dump(ilpOptions, graphDump)
                pickle.dump(probGenerator, graphDump)
                pickle.dump(fieldOfView, graphDump)
                pickle.dump(hypotheses_graph, graphDump)
                pickle.dump(trackingGraph, graphDump)
            getLogger().info("Done saving state to file")

    # map groundtruth to hypothesesgraph if all required variables are specified
    weights = None
    if options.gt_label_image_file is not None and options.gt_label_image_path is not None \
            and options.gt_text_file is not None and options.gt_jaccard_threshold is not None:
        weights = mapGroundTruth(options, hypotheses_graph, trackingGraph, probGenerator)

    # track
    result = runTracking(options, trackingGraph, weights)

    # insert the solution into the hypotheses graph and from that deduce the
    # lineages (must happen before trackIds are read below)
    getLogger().info("Inserting solution into graph")
    hypotheses_graph.insertSolution(result)
    hypotheses_graph.computeLineage()

    mappings = {}  # dictionary over timeframes, containing another dict objectId -> trackId per frame
    tracks = {}  # stores a list of timeframes per track, so that we can find from<->to per track
    trackParents = {}  # store the parent trackID of a track if known

    # collect per-frame (segmentationId, filename) -> trackId mappings and
    # the frames each track appears in; graph nodes are (frame, id) tuples
    for n in hypotheses_graph.nodeIterator():
        frameMapping = mappings.setdefault(n[0], {})
        if 'trackId' not in hypotheses_graph._graph.node[n]:
            raise ValueError("You need to compute the Lineage of every node before accessing the trackId!")
        trackId = hypotheses_graph._graph.node[n]['trackId']
        traxel = hypotheses_graph._graph.node[n]['traxel']
        if trackId is not None:
            frameMapping[(traxel.idInSegmentation, traxel.segmentationFilename)] = trackId
        if trackId in tracks:
            tracks[trackId].append(n[0])
        else:
            tracks[trackId] = [n[0]]
        if 'parent' in hypotheses_graph._graph.node[n]:
            # NOTE(review): validation via assert is stripped under `python -O`
            assert(trackId not in trackParents)
            trackParents[trackId] = hypotheses_graph._graph.node[hypotheses_graph._graph.node[n]['parent']]['trackId']

    # write res_track.txt: trackId -> [parentTrackId, firstFrame, lastFrame]
    getLogger().info("Writing track text file")
    trackDict = {}
    for trackId, timestepList in tracks.iteritems():
        timestepList.sort()
        try:
            parent = trackParents[trackId]
        except KeyError:
            parent = 0  # no recorded parent: track starts a new lineage
        trackDict[trackId] = [parent, min(timestepList), max(timestepList)]
    save_tracks(trackDict, options)

    # export results: relabel every frame's competing label images with the
    # track ids and save them as tif
    getLogger().info("Saving relabeled images")
    pluginManager = TrackingPluginManager(verbose=options.verbose, pluginPaths=options.pluginPaths)
    pluginManager.setImageProvider('LocalImageLoader')
    imageProvider = pluginManager.getImageProvider()
    timeRange = probGenerator.timeRange

    for timeframe in range(timeRange[0], timeRange[1]):
        label_images = {}
        for f, p in zip(options.label_image_files, options.label_image_paths):
            label_images[f] = imageProvider.getLabelImageForFrame(f, p, timeframe)
        # NOTE(review): mappings[timeframe] raises KeyError for a frame with
        # no graph node at all; a sibling export loop elsewhere in this file
        # guards with `if timeframe in mappings` -- confirm whether empty
        # frames can occur here
        remapped_label_image = remap_label_image(label_images, mappings[timeframe])
        save_frame_to_tif(timeframe, remapped_label_image, options)
# Do the tracking
# NOTE(review): this excerpt starts mid-script; `options` and `ilp_fn` are
# defined above it.
start = time.time()

feature_path = options.feats_path
with_div = not bool(options.without_divisions)
with_merger_prior = True

# get selected time range; the default (mints=0, maxts=-1) means
# "use the dataset's full time range"
time_range = [options.mints, options.maxts]
if options.maxts == -1 and options.mints == 0:
    time_range = None

try:
    # find shape and time range of dataset from the label images
    pluginManager = TrackingPluginManager(verbose=options.verbose, pluginPaths=options.pluginPaths)
    shape = pluginManager.getImageProvider().getImageShape(ilp_fn, options.label_img_path)
    data_time_range = pluginManager.getImageProvider().getTimeRange(ilp_fn, options.label_img_path)

    # a negative upper bound counts backwards from the end of the dataset
    if time_range is not None and time_range[1] < 0:
        time_range[1] += data_time_range[1]
except Exception:
    # best effort only -- tracking can proceed without this metadata.
    # Narrowed from a bare `except:`, which would also have swallowed
    # KeyboardInterrupt and SystemExit.
    logging.warning("Could not read shape and time range from images")

# set average object size if chosen
obj_size = [0]
if options.avg_obj_size != 0:
    obj_size[0] = options.avg_obj_size
else:
    # NOTE(review): this stores the *list* [0] in options.avg_obj_size, not
    # the scalar 0 -- kept as-is; confirm downstream consumers expect a list
    options.avg_obj_size = obj_size
# NOTE(review): this excerpt starts mid-script; `tracks`, `trackParents`,
# `gapTrackParents`, `trackDict`, `mappings` and `args` are defined above it.

# Build the track dictionary: trackId -> [parentTrackId, firstFrame, lastFrame]
for trackId, timestepList in tracks.iteritems():
    timestepList.sort()
    # direct membership/.get() instead of `in dict.keys()`: on Python 2,
    # keys() materializes a list, making every test O(n)
    parent = trackParents.get(trackId, 0)
    # links that jump over one time frame: reroute the parent to the recorded
    # gap-parent (a gap-parent equal to the track itself means "no gap")
    if trackId in gapTrackParents:
        if gapTrackParents[trackId] != trackId:
            parent = gapTrackParents[trackId]
            getLogger().info("Jumping over one time frame in this link: trackid: {}, parent: {}, time: {}".format(trackId, parent, min(timestepList)))
    trackDict[trackId] = [parent, min(timestepList), max(timestepList)]
save_tracks(trackDict, args)

# load images, relabel, and export relabeled result
getLogger().debug("Saving relabeled images")
pluginManager = TrackingPluginManager(verbose=args.verbose, pluginPaths=args.pluginPaths)
pluginManager.setImageProvider('LocalImageLoader')
imageProvider = pluginManager.getImageProvider()
timeRange = imageProvider.getTimeRange(args.label_image_filename, args.label_image_path)

for timeframe in range(timeRange[0], timeRange[1]):
    label_image = imageProvider.getLabelImageForFrame(args.label_image_filename, args.label_image_path, timeframe)

    # frames without any tracked object have no mapping entry: export the
    # label image unchanged instead of failing
    if timeframe in mappings:
        remapped_label_image = remap_label_image(label_image, mappings[timeframe])
        save_frame_to_tif(timeframe, remapped_label_image, args)
    else:
        save_frame_to_tif(timeframe, label_image, args)