def computeJaccardScoresOnCloud(frame,
                                labelImageFilenames,
                                labelImagePaths,
                                labelImageFrameIdToGlobalId,
                                groundTruthFilename,
                                groundTruthPath,
                                groundTruthMinJaccardScore,
                                pluginPaths=['hytra/plugins'],
                                imageProviderPluginName='LocalImageLoader'):
    """
    Compute jaccard scores of all objects in the different segmentations with the ground truth for that frame.

    Meant to be run in its own process using `concurrent.futures.ProcessPoolExecutor`.
    The plugin manager is imported and set up inside the function.

    ** params **

    * `frame`: frame index handed to the image provider and used in the lookup keys
    * `labelImageFilenames`, `labelImagePaths`: parallel sequences, one entry per segmentation
    * `labelImageFrameIdToGlobalId`: mapping `(labelImageFilename, frame, objectId) -> globalId`
    * `groundTruthFilename`, `groundTruthPath`: where to load the ground truth label image from
    * `groundTruthMinJaccardScore`: a match is only stored as GT mapping if its score is
      strictly greater than this value
    * `pluginPaths`, `imageProviderPluginName`: forwarded to the `TrackingPluginManager`

    ** returns ** a tuple `(frame, scores, gtToGlobalIdMap)` where

    * `scores` maps each globalId to a list of `(gtLabel, jaccardScore)` for all overlapping GT labels
    * `gtToGlobalIdMap` maps `(frame, gtLabel)` to a list of `(globalId, jaccardScore)`,
      ordered by score, best match last
    """
    # set up plugin manager
    from hytra.pluginsystem.plugin_manager import TrackingPluginManager
    pluginManager = TrackingPluginManager(pluginPaths=pluginPaths, verbose=False)
    pluginManager.setImageProvider(imageProviderPluginName)

    scores = {}
    gtToGlobalIdMap = {}
    groundTruthLabelImage = pluginManager.getImageProvider().getLabelImageForFrame(
        groundTruthFilename, groundTruthPath, frame)

    for labelImageIndexA, labelImageFilenameA in enumerate(labelImageFilenames):
        labelImageA = pluginManager.getImageProvider().getLabelImageForFrame(
            labelImageFilenameA, labelImagePaths[labelImageIndexA], frame)

        for objectIdA in np.unique(labelImageA):
            # label 0 is skipped in both the segmentation and the ground truth
            if objectIdA == 0:
                continue
            globalIdA = labelImageFrameIdToGlobalId[(labelImageFilenameA, frame, objectIdA)]

            # the mask does not depend on the GT label, so build it once per object
            objectMaskA = labelImageA == objectIdA
            overlap = groundTruthLabelImage[objectMaskA]
            # every GT label that shares at least one pixel with the object gets a score
            overlappingGtElements = set(np.unique(overlap)) - set([0])

            for gtLabel in overlappingGtElements:
                # compute Jaccard score = |intersection| / |union|
                intersectingPixels = np.sum(overlap == gtLabel)
                unionPixels = np.sum(np.logical_or(groundTruthLabelImage == gtLabel, objectMaskA))
                jaccardScore = float(intersectingPixels) / float(unionPixels)

                # append to object's score list
                scores.setdefault(globalIdA, []).append((gtLabel, jaccardScore))

                # store this as GT mapping if there was no better object for this GT label yet
                gtKey = (frame, gtLabel)
                if jaccardScore > groundTruthMinJaccardScore and \
                        (gtKey not in gtToGlobalIdMap or gtToGlobalIdMap[gtKey][-1][1] < jaccardScore):
                    gtToGlobalIdMap.setdefault(gtKey, []).append((globalIdA, jaccardScore))

    # sort all gt mappings by ascending jaccard score
    # (`.values()` instead of `.iteritems()`: the latter does not exist on Python 3)
    for matches in gtToGlobalIdMap.values():
        matches.sort(key=lambda x: x[1])

    return frame, scores, gtToGlobalIdMap
def __init__(self, pluginPaths=[os.path.abspath('../hytra/plugins')], verbose=False):
    """
    Initialise all members to an empty state and fetch the merger resolver plugin.

    `pluginPaths` and `verbose` are forwarded to the `TrackingPluginManager`.
    """
    # graphs and per-timestep bookkeeping start out empty
    self.unresolvedGraph = self.resolvedGraph = None
    self.mergersPerTimestep = self.detectionsPerTimestep = None

    # the plugin manager hands out the merger resolver plugin
    manager = TrackingPluginManager(verbose=verbose, pluginPaths=pluginPaths)
    self.pluginManager = manager
    self.mergerResolverPlugin = manager.getMergerResolver()

    # should be filled by constructors of derived classes!
    self.model = self.result = None
def __init__(self, parent=None, graph=None):
    """
    Wire up the operator's internal caches and helpers.

    Creates one `OpBlockedArrayCache` each for the `Output`, `MergerOutput` and
    `RelabeledImage` slots, a zero provider fed by `LabelImage`, registers the
    constraint check for `RawImage` and `LabelImage`, and loads the merger
    resolver plugin from the `plugins` folder next to the `hytra` package.

    `parent` and `graph` are passed on to the base class constructor.
    """
    super(OpConservationTracking, self).__init__(parent=parent, graph=graph)

    # cache for the tracking output
    self._opCache = OpBlockedArrayCache(parent=self)
    self._opCache.name = "OpConservationTracking._opCache"
    self._opCache.Input.connect(self.Output)
    self.CleanBlocks.connect(self._opCache.CleanBlocks)
    self.CachedOutput.connect(self._opCache.Output)

    self.zeroProvider = OpZeroDefault(parent=self)
    self.zeroProvider.MetaInput.connect(self.LabelImage)

    # As soon as input data is available, check its constraints
    self.RawImage.notifyReady(self._checkConstraints)
    self.LabelImage.notifyReady(self._checkConstraints)

    self.ExportSettings.setValue((None, None))

    # cache for the merger output
    self._mergerOpCache = OpBlockedArrayCache(parent=self)
    self._mergerOpCache.name = "OpConservationTracking._mergerOpCache"
    self._mergerOpCache.Input.connect(self.MergerOutput)
    self.MergerCleanBlocks.connect(self._mergerOpCache.CleanBlocks)
    self.MergerCachedOutput.connect(self._mergerOpCache.Output)

    # cache for the relabeled image; it gets its own name so that it can be
    # told apart from the merger cache (it used to reuse the merger cache's name)
    self._relabeledOpCache = OpBlockedArrayCache(parent=self)
    self._relabeledOpCache.name = "OpConservationTracking._relabeledOpCache"
    self._relabeledOpCache.Input.connect(self.RelabeledImage)
    self.RelabeledCleanBlocks.connect(self._relabeledOpCache.CleanBlocks)
    self.RelabeledCachedOutput.connect(self._relabeledOpCache.Output)

    # Merger resolver plugin manager (contains GMM fit routine)
    self.pluginPaths = [os.path.join(os.path.dirname(os.path.abspath(hytra.__file__)), 'plugins')]
    pluginManager = TrackingPluginManager(verbose=False, pluginPaths=self.pluginPaths)
    self.mergerResolverPlugin = pluginManager.getMergerResolver()

    self.result = None

    # progress bar
    self.progressWindow = None
    self.progressVisitor = DefaultProgressVisitor()
def __init__(self, ilpOptions, turnOffFeatures=None, useMultiprocessing=True, pluginPaths=None, verbose=False):
    """
    Set up the plugin manager, load the classifiers and query shape and time range.

    ** params **

    * `ilpOptions`: options object; `imageProviderName` and `featureSerializerName`
      are read here, the object itself is kept in `self._options`
    * `turnOffFeatures`: forwarded to the `TrackingPluginManager`; `None` (default) means `[]`
    * `useMultiprocessing`: stored in `self._useMultiprocessing`
    * `pluginPaths`: forwarded to the `TrackingPluginManager` and kept in `self._pluginPaths`;
      `None` (default) means `['hytra/plugins']`
    * `verbose`: forwarded to the `TrackingPluginManager`
    """
    # Create the default lists per call: a list literal in the signature is shared
    # by all instances, and `pluginPaths` is stored on `self` below.
    if turnOffFeatures is None:
        turnOffFeatures = []
    if pluginPaths is None:
        pluginPaths = ['hytra/plugins']

    self._useMultiprocessing = useMultiprocessing
    self._options = ilpOptions
    self._pluginPaths = pluginPaths
    self._pluginManager = TrackingPluginManager(
        turnOffFeatures=turnOffFeatures, verbose=verbose, pluginPaths=pluginPaths)
    self._pluginManager.setImageProvider(ilpOptions.imageProviderName)
    self._pluginManager.setFeatureSerializer(ilpOptions.featureSerializerName)

    # the classifier members must exist before `_loadClassifiers` runs
    self._countClassifier = None
    self._divisionClassifier = None
    self._transitionClassifier = None
    self._loadClassifiers()

    self.shape, self.timeRange = self._getShapeAndTimeRange()

    # set default division feature names
    self._divisionFeatureNames = [
        'ParentChildrenRatio_Count',
        'ParentChildrenRatio_Mean',
        'ChildrenRatio_Count',
        'ChildrenRatio_Mean',
        'ParentChildrenAngle_RegionCenter',
        'ChildrenRatio_SquaredDistances',
    ]

    # other parameters that one might want to set
    self.x_scale = 1.0
    self.y_scale = 1.0
    self.z_scale = 1.0
    self.divisionProbabilityFeatureName = 'divProb'
    self.detectionProbabilityFeatureName = 'detProb'

    # this public variable contains all traxels if we're not using pgmlink
    self.TraxelsPerFrame = {}
def __init__(self, pluginPaths=[os.path.abspath('../hytra/plugins')], verbose=False):
    """
    Initialise all members to an empty state and fetch the merger resolver plugin.

    `pluginPaths` and `verbose` are forwarded to the `TrackingPluginManager`.
    """
    # graphs and per-timestep bookkeeping start out empty
    self.unresolvedGraph = self.resolvedGraph = None
    self.mergersPerTimestep = self.detectionsPerTimestep = None

    # the plugin manager hands out the merger resolver plugin
    manager = TrackingPluginManager(verbose=verbose, pluginPaths=pluginPaths)
    self.pluginManager = manager
    self.mergerResolverPlugin = manager.getMergerResolver()

    # should be filled by constructors of derived classes!
    self.model = self.result = None
def __init__(self, parent=None, graph=None):
    """
    Wire up the operator's internal caches and helpers.

    Creates one `OpBlockedArrayCache` each for the `Output`, `MergerOutput` and
    `RelabeledImage` slots, a zero provider fed by `LabelImage`, registers the
    constraint check for `RawImage` and `LabelImage`, and loads the merger
    resolver plugin from the `plugins` folder next to the `hytra` package.

    `parent` and `graph` are passed on to the base class constructor.
    """
    super(OpConservationTracking, self).__init__(parent=parent, graph=graph)

    # cache for the tracking output
    self._opCache = OpBlockedArrayCache(parent=self)
    self._opCache.name = "OpConservationTracking._opCache"
    self._opCache.Input.connect(self.Output)
    self.CleanBlocks.connect(self._opCache.CleanBlocks)
    self.CachedOutput.connect(self._opCache.Output)

    self.zeroProvider = OpZeroDefault(parent=self)
    self.zeroProvider.MetaInput.connect(self.LabelImage)

    # As soon as input data is available, check its constraints
    self.RawImage.notifyReady(self._checkConstraints)
    self.LabelImage.notifyReady(self._checkConstraints)

    self.ExportSettings.setValue((None, None))

    # cache for the merger output
    self._mergerOpCache = OpBlockedArrayCache(parent=self)
    self._mergerOpCache.name = "OpConservationTracking._mergerOpCache"
    self._mergerOpCache.Input.connect(self.MergerOutput)
    self.MergerCleanBlocks.connect(self._mergerOpCache.CleanBlocks)
    self.MergerCachedOutput.connect(self._mergerOpCache.Output)

    # cache for the relabeled image; it gets its own name so that it can be
    # told apart from the merger cache (it used to reuse the merger cache's name)
    self._relabeledOpCache = OpBlockedArrayCache(parent=self)
    self._relabeledOpCache.name = "OpConservationTracking._relabeledOpCache"
    self._relabeledOpCache.Input.connect(self.RelabeledImage)
    self.RelabeledCleanBlocks.connect(self._relabeledOpCache.CleanBlocks)
    self.RelabeledCachedOutput.connect(self._relabeledOpCache.Output)

    # Merger resolver plugin manager (contains GMM fit routine)
    self.pluginPaths = [os.path.join(os.path.dirname(os.path.abspath(hytra.__file__)), 'plugins')]
    pluginManager = TrackingPluginManager(verbose=False, pluginPaths=self.pluginPaths)
    self.mergerResolverPlugin = pluginManager.getMergerResolver()

    self.result = None

    # progress bar
    self.progressWindow = None
    self.progressVisitor = DefaultProgressVisitor()
def findConflictingHypothesesInSeparateProcess(frame,
                                               labelImageFilenames,
                                               labelImagePaths,
                                               labelImageFrameIdToGlobalId,
                                               pluginPaths=['hytra/plugins'],
                                               imageProviderPluginName='LocalImageLoader'):
    """
    Find objects of different segmentation hypotheses (one label image per hypothesis)
    that overlap in the given frame.

    Meant to be run in its own process using `concurrent.futures.ProcessPoolExecutor`

    ** returns ** a tuple `(frame, overlaps)`, where `overlaps` maps every globalId
    to the list of globalIds it overlaps with (the relation is stored in both directions)
    """
    # set up plugin manager
    from hytra.pluginsystem.plugin_manager import TrackingPluginManager
    manager = TrackingPluginManager(pluginPaths=pluginPaths, verbose=False)
    manager.setImageProvider(imageProviderPluginName)

    def loadLabelImage(index):
        ''' load the label image of hypothesis `index` for this frame '''
        return manager.getImageProvider().getLabelImageForFrame(
            labelImageFilenames[index], labelImagePaths[index], frame)

    def toGlobalId(index, objectId):
        ''' translate an object id of hypothesis `index` into its global id '''
        return labelImageFrameIdToGlobalId[(labelImageFilenames[index], frame, objectId)]

    conflicts = {}  # overlap dict: key=globalId, value=[list of globalIds]

    # visit every unordered pair (A, B) of hypotheses exactly once
    for indexA in range(len(labelImageFilenames)):
        imageA = loadLabelImage(indexA)

        for indexB in range(indexA + 1, len(labelImageFilenames)):
            imageB = loadLabelImage(indexB)

            # check for overlaps - even a 1-pixel overlap is enough to be mutually exclusive!
            for objectIdA in np.unique(imageA):
                if objectIdA == 0:
                    continue

                idsInB = set(np.unique(imageB[imageA == objectIdA])) - {0}
                globalIdsB = [toGlobalId(indexB, objectIdB) for objectIdB in idsInB]
                globalIdA = toGlobalId(indexA, objectIdA)

                # record the conflict for A ...
                conflicts.setdefault(globalIdA, []).extend(globalIdsB)
                # ... and for each partner in B
                for globalIdB in globalIdsB:
                    conflicts.setdefault(globalIdB, []).append(globalIdA)

    return frame, conflicts
def findConflictingHypothesesInSeparateProcess(frame,
                                               labelImageFilenames,
                                               labelImagePaths,
                                               labelImageFrameIdToGlobalId,
                                               pluginPaths=['hytra/plugins'],
                                               imageProviderPluginName='LocalImageLoader'):
    """
    Find objects of different segmentation hypotheses (one label image per hypothesis)
    that overlap in the given frame.

    Meant to be run in its own process using `concurrent.futures.ProcessPoolExecutor`

    ** returns ** a tuple `(frame, overlaps)`, where `overlaps` maps every globalId
    to the list of globalIds it overlaps with (the relation is stored in both directions)
    """
    # set up plugin manager
    from hytra.pluginsystem.plugin_manager import TrackingPluginManager
    manager = TrackingPluginManager(pluginPaths=pluginPaths, verbose=False)
    manager.setImageProvider(imageProviderPluginName)

    def loadLabelImage(index):
        ''' load the label image of hypothesis `index` for this frame '''
        return manager.getImageProvider().getLabelImageForFrame(
            labelImageFilenames[index], labelImagePaths[index], frame)

    def toGlobalId(index, objectId):
        ''' translate an object id of hypothesis `index` into its global id '''
        return labelImageFrameIdToGlobalId[(labelImageFilenames[index], frame, objectId)]

    conflicts = {}  # overlap dict: key=globalId, value=[list of globalIds]

    # visit every unordered pair (A, B) of hypotheses exactly once
    for indexA in range(len(labelImageFilenames)):
        imageA = loadLabelImage(indexA)

        for indexB in range(indexA + 1, len(labelImageFilenames)):
            imageB = loadLabelImage(indexB)

            # check for overlaps - even a 1-pixel overlap is enough to be mutually exclusive!
            for objectIdA in np.unique(imageA):
                if objectIdA == 0:
                    continue

                idsInB = set(np.unique(imageB[imageA == objectIdA])) - {0}
                globalIdsB = [toGlobalId(indexB, objectIdB) for objectIdB in idsInB]
                globalIdA = toGlobalId(indexA, objectIdA)

                # record the conflict for A ...
                conflicts.setdefault(globalIdA, []).extend(globalIdsB)
                # ... and for each partner in B
                for globalIdB in globalIdsB:
                    conflicts.setdefault(globalIdB, []).append(globalIdA)

    return frame, conflicts
# Do the tracking start = time.time() feature_path = options.feats_path with_div = not bool(options.without_divisions) with_merger_prior = True # get selected time range time_range = [options.mints, options.maxts] if options.maxts == -1 and options.mints == 0: time_range = None try: # find shape of dataset pluginManager = TrackingPluginManager(verbose=options.verbose, pluginPaths=options.pluginPaths) shape = pluginManager.getImageProvider().getImageShape( ilp_fn, options.label_img_path) data_time_range = pluginManager.getImageProvider().getTimeRange( ilp_fn, options.label_img_path) if time_range is not None and time_range[1] < 0: time_range[1] += data_time_range[1] except: logging.warning("Could not read shape and time range from images") shape = None # set average object size if chosen obj_size = [0] if options.avg_obj_size != 0:
rawimage = hytra.util.axesconversion.adjustOrder(rawimage, args.rawimage_axes[dataset], 'xyztc') logger.info('Done loading raw data from dataset {} of shape {}'.format(dataset, rawimage.shape)) # find ground truth files # filepath is now a list of filepaths' filepath = args.filepath[dataset] # filepattern is now a list of filepatterns files = glob.glob(os.path.join(filepath, args.filepattern[dataset])) files.sort() initFrame = args.initFrame endFrame = args.endFrame if endFrame < 0: endFrame += len(files) # compute features trackingPluginManager = TrackingPluginManager(verbose=args.verbose, pluginPaths=args.pluginPaths) features = compute_features(rawimage, read_in_images(initFrame, endFrame, files, args.groundtruth_axes[dataset]), initFrame, endFrame, trackingPluginManager, rawimage_filename) logger.info('Done computing features from dataset {}'.format(dataset)) selectedFeatures = find_features_without_NaNs(features) pos_labels = read_positiveLabels(initFrame, endFrame, files) neg_labels = negativeLabels(features, pos_labels) numSamples += 2 * sum([len(l) for l in pos_labels]) + sum([len(l) for l in neg_labels]) logger.info('Done extracting {} samples'.format(numSamples)) TC = TransitionClassifier(selectedFeatures, numSamples) if dataset > 0:
def computeJaccardScoresOnCloud(frame,
                                labelImageFilenames,
                                labelImagePaths,
                                labelImageFrameIdToGlobalId,
                                groundTruthFilename,
                                groundTruthPath,
                                groundTruthMinJaccardScore,
                                pluginPaths=['hytra/plugins'],
                                imageProviderPluginName='LocalImageLoader'):
    """
    Compute jaccard scores of all objects in the different segmentations with the ground truth for that frame.

    Meant to be run in its own process using `concurrent.futures.ProcessPoolExecutor`.
    The plugin manager is imported and set up inside the function.

    ** params **

    * `frame`: frame index handed to the image provider and used in the lookup keys
    * `labelImageFilenames`, `labelImagePaths`: parallel sequences, one entry per segmentation
    * `labelImageFrameIdToGlobalId`: mapping `(labelImageFilename, frame, objectId) -> globalId`
    * `groundTruthFilename`, `groundTruthPath`: where to load the ground truth label image from
    * `groundTruthMinJaccardScore`: a match is only stored as GT mapping if its score is
      strictly greater than this value
    * `pluginPaths`, `imageProviderPluginName`: forwarded to the `TrackingPluginManager`

    ** returns ** a tuple `(frame, scores, gtToGlobalIdMap)` where

    * `scores` maps each globalId to a list of `(gtLabel, jaccardScore)` for all overlapping GT labels
    * `gtToGlobalIdMap` maps `(frame, gtLabel)` to a list of `(globalId, jaccardScore)`,
      ordered by score, best match last
    """
    # set up plugin manager
    from hytra.pluginsystem.plugin_manager import TrackingPluginManager
    pluginManager = TrackingPluginManager(pluginPaths=pluginPaths, verbose=False)
    pluginManager.setImageProvider(imageProviderPluginName)

    scores = {}
    gtToGlobalIdMap = {}
    groundTruthLabelImage = pluginManager.getImageProvider().getLabelImageForFrame(
        groundTruthFilename, groundTruthPath, frame)

    for labelImageIndexA, labelImageFilenameA in enumerate(labelImageFilenames):
        labelImageA = pluginManager.getImageProvider().getLabelImageForFrame(
            labelImageFilenameA, labelImagePaths[labelImageIndexA], frame)

        for objectIdA in np.unique(labelImageA):
            # label 0 is skipped in both the segmentation and the ground truth
            if objectIdA == 0:
                continue
            globalIdA = labelImageFrameIdToGlobalId[(labelImageFilenameA, frame, objectIdA)]

            # the mask does not depend on the GT label, so build it once per object
            objectMaskA = labelImageA == objectIdA
            overlap = groundTruthLabelImage[objectMaskA]
            # every GT label that shares at least one pixel with the object gets a score
            overlappingGtElements = set(np.unique(overlap)) - set([0])

            for gtLabel in overlappingGtElements:
                # compute Jaccard score = |intersection| / |union|
                intersectingPixels = np.sum(overlap == gtLabel)
                unionPixels = np.sum(np.logical_or(groundTruthLabelImage == gtLabel, objectMaskA))
                jaccardScore = float(intersectingPixels) / float(unionPixels)

                # append to object's score list
                scores.setdefault(globalIdA, []).append((gtLabel, jaccardScore))

                # store this as GT mapping if there was no better object for this GT label yet
                gtKey = (frame, gtLabel)
                if jaccardScore > groundTruthMinJaccardScore and \
                        (gtKey not in gtToGlobalIdMap or gtToGlobalIdMap[gtKey][-1][1] < jaccardScore):
                    gtToGlobalIdMap.setdefault(gtKey, []).append((globalIdA, jaccardScore))

    # sort all gt mappings by ascending jaccard score
    # (`.values()` instead of `.iteritems()`: the latter does not exist on Python 3)
    for matches in gtToGlobalIdMap.values():
        matches.sort(key=lambda x: x[1])

    return frame, scores, gtToGlobalIdMap
parent = trackParents[trackId] else: parent = 0 # jumping over time frames, so creating if trackId in gapTrackParents.keys(): if gapTrackParents[trackId] != trackId: parent = gapTrackParents[trackId] getLogger().info( "Jumping over one time frame in this link: trackid: {}, parent: {}, time: {}" .format(trackId, parent, min(timestepList))) trackDict[trackId] = [parent, min(timestepList), max(timestepList)] save_tracks(trackDict, args) # load images, relabel, and export relabeled result getLogger().debug("Saving relabeled images") pluginManager = TrackingPluginManager(verbose=args.verbose, pluginPaths=args.pluginPaths) pluginManager.setImageProvider('LocalImageLoader') imageProvider = pluginManager.getImageProvider() timeRange = imageProvider.getTimeRange(args.label_image_filename, args.label_image_path) for timeframe in range(timeRange[0], timeRange[1]): label_image = imageProvider.getLabelImageForFrame( args.label_image_filename, args.label_image_path, timeframe) # check if frame is empty if timeframe in mappings.keys(): remapped_label_image = remap_label_image(label_image, mappings[timeframe]) save_frame_to_tif(timeframe, remapped_label_image, args) else:
def writeEvents(timestep, activeLinks, activeDivisions, mergers, detections, fn, labelImagePath, ilpFilename, verbose, pluginPaths):
    """
    Write the tracking events of one timestep, together with its label image, to the HDF5 file `fn`.

    The file receives a `segmentation/labels` dataset, `objects/meta/id` and
    `objects/meta/valid`, and a `tracking` group with one dataset per non-empty
    event type. Appearances, disappearances and multi-frame moves are always empty here.

    Any exception raised while converting or writing is caught and logged as a warning.
    """
    logger = logging.getLogger('json_result_to_events.py')

    dis = []
    app = []
    div = []
    mov = []
    mer = []
    mul = []

    pluginManager = TrackingPluginManager(verbose=verbose, pluginPaths=pluginPaths)
    logger.debug("-- Writing results to {}".format(fn))

    try:
        # convert to ndarray for better indexing
        dis = np.asarray(dis)
        app = np.asarray(app)
        div = np.asarray([[k, v[0], v[1]] for k, v in activeDivisions.items()])
        mov = np.asarray(activeLinks)
        mer = np.asarray([[k, v] for k, v in mergers.items()])
        mul = np.asarray(mul)

        # the shape is queried but not used further in this function
        shape = pluginManager.getImageProvider().getImageShape(ilpFilename, labelImagePath)
        label_img = pluginManager.getImageProvider().getLabelImageForFrame(
            ilpFilename, labelImagePath, timestep)

        # (dataset name, data, content of the "Format" attribute), in writing order
        eventTables = (
            ("Appearances", app, "cell label appeared in current file"),
            ("Disappearances", dis, "cell label disappeared in current file"),
            ("Moves", mov, "from (previous file), to (current file)"),
            ("Splits", div, "ancestor (previous file), descendant (current file), descendant (current file)"),
            ("Mergers", mer, "descendant (current file), number of objects"),
            ("MultiFrameMoves", mul, "from (given by timestep), to (current file), timestep"),
        )

        with h5py.File(fn, 'w') as dest_file:
            # write meta fields and copy segmentation from project
            seg = dest_file.create_group('segmentation')
            seg.create_dataset("labels", data=label_img, compression='gzip')

            meta = dest_file.create_group('objects/meta')
            ids = np.unique(label_img)
            ids = ids[ids > 0]
            valid = np.ones(ids.shape)
            meta.create_dataset("id", data=ids, dtype=np.uint32)
            meta.create_dataset("valid", data=valid, dtype=np.uint32)

            tg = dest_file.create_group("tracking")

            # write associations, skipping empty tables
            for datasetName, table, formatDescription in eventTables:
                if table is None or len(table) == 0:
                    continue
                ds = tg.create_dataset(datasetName, data=table, dtype=np.int32)
                ds.attrs["Format"] = formatDescription

        logger.debug("-> results successfully written")
    except Exception as e:
        logger.warning("ERROR while writing events: {}".format(str(e)))
def run_pipeline(options):
    """
    Run the complete tracking pipeline with competing segmentation hypotheses

    Steps, in order:

    1. Load the state (options, probability generator, field of view, hypotheses graph,
       tracking graph) from the gzip'ed pickle `options.load_graph_filename`, or build it
       with `setupGraph(options)` and optionally dump it to `options.dump_graph_filename`.
    2. If all ground truth options are given, obtain `weights` from `mapGroundTruth`.
    3. Run `runTracking`, insert the result into the hypotheses graph and compute the lineage.
    4. Collect per-frame object-to-track mappings and per-track time ranges, and hand
       them to `save_tracks`.
    5. Load the label images of every hypothesis per frame with the `LocalImageLoader`,
       remap them with `remap_label_image` and save them with `save_frame_to_tif`.
    """
    # set up probabilitygenerator (aka traxelstore) and hypothesesgraph or load them from a dump
    if options.load_graph_filename is not None:
        getLogger().info("Loading state from file: " + options.load_graph_filename)
        # NOTE(review): unpickling executes code embedded in the file, so only
        # load dumps from a trusted source.
        # The objects must be read in the same order as they are dumped below.
        with gzip.open(options.load_graph_filename, 'r') as graphDump:
            ilpOptions = pickle.load(graphDump)
            probGenerator = pickle.load(graphDump)
            fieldOfView = pickle.load(graphDump)
            hypotheses_graph = pickle.load(graphDump)
            trackingGraph = pickle.load(graphDump)
        getLogger().info("Done loading state from file")
    else:
        fieldOfView, hypotheses_graph, ilpOptions, probGenerator, trackingGraph = setupGraph(options)

        # NOTE(review): presumably the dump only happens for a freshly built state
        # (i.e. inside this else branch) - confirm against the original layout
        if options.dump_graph_filename is not None:
            getLogger().info("Saving state to file: " + options.dump_graph_filename)
            with gzip.open(options.dump_graph_filename, 'w') as graphDump:
                pickle.dump(ilpOptions, graphDump)
                pickle.dump(probGenerator, graphDump)
                pickle.dump(fieldOfView, graphDump)
                pickle.dump(hypotheses_graph, graphDump)
                pickle.dump(trackingGraph, graphDump)
            getLogger().info("Done saving state to file")

    # map groundtruth to hypothesesgraph if all required variables are specified
    weights = None
    if options.gt_label_image_file is not None and options.gt_label_image_path is not None \
            and options.gt_text_file is not None and options.gt_jaccard_threshold is not None:
        weights = mapGroundTruth(options, hypotheses_graph, trackingGraph, probGenerator)

    # track
    result = runTracking(options, trackingGraph, weights)

    # insert the solution into the hypotheses graph and from that deduce the lineages
    getLogger().info("Inserting solution into graph")
    hypotheses_graph.insertSolution(result)
    hypotheses_graph.computeLineage()

    mappings = {} # dictionary over timeframes, containing another dict objectId -> trackId per frame
    tracks = {} # stores a list of timeframes per track, so that we can find from<->to per track
    trackParents = {} # store the parent trackID of a track if known

    # n[0] is used as the frame of node n throughout this loop
    for n in hypotheses_graph.nodeIterator():
        frameMapping = mappings.setdefault(n[0], {})
        if 'trackId' not in hypotheses_graph._graph.node[n]:
            raise ValueError("You need to compute the Lineage of every node before accessing the trackId!")
        trackId = hypotheses_graph._graph.node[n]['trackId']
        traxel = hypotheses_graph._graph.node[n]['traxel']
        # the per-frame mapping is keyed by (object id, label image file)
        if trackId is not None:
            frameMapping[(traxel.idInSegmentation, traxel.segmentationFilename)] = trackId
        # NOTE(review): as laid out here, `tracks` and `trackParents` are also
        # updated when trackId is None - confirm this is intended
        if trackId in tracks:
            tracks[trackId].append(n[0])
        else:
            tracks[trackId] = [n[0]]
        if 'parent' in hypotheses_graph._graph.node[n]:
            # each track may only get its parent assigned once
            assert(trackId not in trackParents)
            trackParents[trackId] = hypotheses_graph._graph.node[hypotheses_graph._graph.node[n]['parent']]['trackId']

    # write res_track.txt
    getLogger().info("Writing track text file")
    trackDict = {}
    for trackId, timestepList in tracks.items():
        timestepList.sort()
        # tracks without a known parent get parent id 0
        try:
            parent = trackParents[trackId]
        except KeyError:
            parent = 0
        # one entry per track: [parent, first frame, last frame]
        trackDict[trackId] = [parent, min(timestepList), max(timestepList)]
    save_tracks(trackDict, options)

    # export results
    getLogger().info("Saving relabeled images")
    pluginManager = TrackingPluginManager(verbose=options.verbose, pluginPaths=options.pluginPaths)
    pluginManager.setImageProvider('LocalImageLoader')
    imageProvider = pluginManager.getImageProvider()
    timeRange = probGenerator.timeRange

    # timeRange[1] is exclusive here
    for timeframe in range(timeRange[0], timeRange[1]):
        # collect the label image of every segmentation hypothesis for this frame
        label_images = {}
        for f, p in zip(options.label_image_files, options.label_image_paths):
            label_images[f] = imageProvider.getLabelImageForFrame(f, p, timeframe)
        # mappings[timeframe] raises a KeyError for a frame without any graph node
        remapped_label_image = remap_label_image(label_images, mappings[timeframe])
        save_frame_to_tif(timeframe, remapped_label_image, options)
for trackId, timestepList in tracks.iteritems(): timestepList.sort() if trackId in trackParents.keys(): parent = trackParents[trackId] else: parent = 0 # jumping over time frames, so creating if trackId in gapTrackParents.keys(): if gapTrackParents[trackId] != trackId: parent = gapTrackParents[trackId] getLogger().info("Jumping over one time frame in this link: trackid: {}, parent: {}, time: {}".format(trackId, parent, min(timestepList))) trackDict[trackId] = [parent, min(timestepList), max(timestepList)] save_tracks(trackDict, args) # load images, relabel, and export relabeled result getLogger().debug("Saving relabeled images") pluginManager = TrackingPluginManager(verbose=args.verbose, pluginPaths=args.pluginPaths) pluginManager.setImageProvider('LocalImageLoader') imageProvider = pluginManager.getImageProvider() timeRange = imageProvider.getTimeRange(args.label_image_filename, args.label_image_path) for timeframe in range(timeRange[0], timeRange[1]): label_image = imageProvider.getLabelImageForFrame(args.label_image_filename, args.label_image_path, timeframe) # check if frame is empty if timeframe in mappings.keys(): remapped_label_image = remap_label_image(label_image, mappings[timeframe]) save_frame_to_tif(timeframe, remapped_label_image, args) else: save_frame_to_tif(timeframe, label_image, args)
class MergerResolver(object):
    """
    Base class for all merger resolving implementations. Use one of the derived classes
    that handle reading/writing data to the respective sources.

    Derived classes are expected to fill `self.model` and `self.result` and to
    implement `_readLabelImage`, `_exportRefinedSegmentation` and
    `_computeObjectFeatures`.
    """

    def __init__(self, pluginPaths=[os.path.abspath('../hytra/plugins')], verbose=False):
        """
        Set up the plugin manager and fetch the merger resolver plugin.

        `pluginPaths` and `verbose` are handed to the `TrackingPluginManager` unchanged.
        The default list is only passed through by this constructor.
        """
        self.unresolvedGraph = None
        self.resolvedGraph = None
        self.mergersPerTimestep = None
        self.detectionsPerTimestep = None
        self.pluginManager = TrackingPluginManager(verbose=verbose, pluginPaths=pluginPaths)
        self.mergerResolverPlugin = self.pluginManager.getMergerResolver()

        # should be filled by constructors of derived classes!
        self.model = None
        self.result = None

    @staticmethod
    def _lastDivisionFrame(divisionsPerTimestep):
        """
        Return the numerically largest timestep key of `divisionsPerTimestep` as an int,
        or `None` if the dictionary is `None` or empty.
        """
        if not divisionsPerTimestep:
            return None
        return int(max(divisionsPerTimestep.keys(), key=int))

    @staticmethod
    def _contiguousTimesteps(traxelIdPerTimestepToUniqueIdMap):
        """
        Return every timestep between the smallest and largest key of the given map
        (both inclusive) as a list of text strings, so that empty frames in between
        are covered as well.

        The keys are compared numerically: comparing the string keys directly would be
        lexicographic ('9' > '10') and silently drop frames.
        """
        frames = [int(t) for t in traxelIdPerTimestepToUniqueIdMap.keys()]
        # u"{}".format yields `unicode` on Python 2 (same as the former
        # `str(t).decode("utf-8")`) and plain `str` on Python 3.
        return [u"{}".format(t) for t in range(min(frames), max(frames) + 1)]

    def _createUnresolvedGraph(self, divisionsPerTimestep, mergersPerTimestep, mergerLinks, withFullGraph=False):
        """
        Set up a networkx graph consisting of mergers that need to be resolved (not resolved yet!)
        and their direct neighbors.

        With `withFullGraph=True` the graph is a copy of `self.hypothesesGraph._graph`
        (an attribute this base class does not set itself) annotated with `division`
        and `count`; otherwise only the end points of `mergerLinks` are inserted.

        `divisionsPerTimestep` may be `None` or empty, in which case no node is marked
        as division.

        ** returns ** the `unresolvedGraph`
        """
        self.unresolvedGraph = nx.DiGraph()

        def source(timestep, link):
            return int(timestep) - 1, link[0]

        def target(timestep, link):
            return int(timestep), link[1]

        # computed once, and only dereferenced if division information is present
        lastFrame = self._lastDivisionFrame(divisionsPerTimestep)

        def isDivision(timestep, idx):
            ''' divisions of frame t are looked up under key str(t + 1), hence none in the last frame '''
            if lastFrame is None or int(timestep) >= lastFrame:
                return False
            return idx in divisionsPerTimestep[str(timestep + 1)]

        if withFullGraph:
            # Recompute full graph
            self.unresolvedGraph = self.hypothesesGraph._graph.copy()

            # Add division parameter to nodes
            # TODO: Add the division parameter only to nodes that contain divisions
            # (we're already doing these with 'count')
            for node in self.unresolvedGraph.nodes_iter():
                timestep, idx = node
                self.unresolvedGraph.node[node]['division'] = isDivision(timestep, idx)

            # Add count parameter to nodes
            for t, link in mergerLinks:
                for node in [source(t, link), target(t, link)]:
                    timestep, idx = node
                    if idx in mergersPerTimestep[str(timestep)]:
                        count = mergersPerTimestep[str(timestep)][idx]
                        self.unresolvedGraph.node[node]['count'] = count
        else:
            # Recompute graph only with merger nodes and neighbors
            def addNode(node):
                ''' add a node to the unresolved graph and fill in the properties `division` and `count` '''
                intT, idx = node
                division = isDivision(intT, idx)
                count = 1
                if idx in mergersPerTimestep[str(intT)]:
                    assert(not division)
                    count = mergersPerTimestep[str(intT)][idx]
                self.unresolvedGraph.add_node(node, division=division, count=count)

            # add nodes
            for t, link in mergerLinks:
                for n in [source(t, link), target(t, link)]:
                    if not self.unresolvedGraph.has_node(n):
                        addNode(n)
                self.unresolvedGraph.add_edge(source(t, link), target(t, link))

        return self.unresolvedGraph

    def _prepareResolvedGraph(self):
        """
        Initialize the `resolvedGraph` as a copy of the `unresolvedGraph`.

        ** returns ** the `resolvedGraph`
        """
        self.resolvedGraph = self.unresolvedGraph.copy()
        return self.resolvedGraph

    def _readLabelImage(self, timeframe):
        '''
        Should return the labelimage for the given timeframe
        '''
        raise NotImplementedError()

    def _fitAndRefineNodes(self, detectionsPerTimestep, mergersPerTimestep, timesteps):
        '''
        Update segmentation of mergers (nodes in unresolvedGraph) from first timeframe to last
        and create new nodes in `resolvedGraph`. Links to merger nodes are duplicated to all new nodes.

        Uses the mergerResolver plugin to update the segmentations in the labelImages.

        `detectionsPerTimestep` and `mergersPerTimestep` are indexed by `str(timestep)`.
        '''
        for intT in sorted(int(t) for t in timesteps):
            t = str(intT)
            # use image provider plugin to load labelimage
            labelImage = self._readLabelImage(int(t))
            nextObjectId = labelImage.max() + 1

            for idx in detectionsPerTimestep[t]:
                node = (intT, idx)
                if node not in self.resolvedGraph:
                    continue

                count = 1
                if idx in mergersPerTimestep[t]:
                    count = mergersPerTimestep[t][idx]
                getLogger().debug("Looking at node {} in timestep {} with count {}".format(idx, t, count))

                # collect initializations from incoming
                initializations = []
                for predecessor, _ in self.unresolvedGraph.in_edges(node):
                    initializations.extend(self.unresolvedGraph.node[predecessor]['fits'])
                # TODO: what shall we do if e.g. a 2-merger and a single object merge to 2 + 1,
                # so there are 3 initializations for the 2-merger, and two initializations for the 1 merger?
                # What does pgmlink do in that case?

                # use merger resolving plugin to fit `count` objects, also updates labelimage!
                fittedObjects = self.mergerResolverPlugin.resolveMerger(
                    labelImage, idx, nextObjectId, count, initializations)
                assert(len(fittedObjects) == count)

                # split up node if count > 1, duplicate incoming and outgoing arcs
                if count > 1:
                    # a list (not a lazy range) so that it can be stored, iterated repeatedly and returned
                    newIds = list(range(nextObjectId, nextObjectId + count))
                    # `newIdx` on purpose: do not rebind the `idx` of the enclosing loop
                    for newIdx in newIds:
                        newNode = (intT, newIdx)
                        self.resolvedGraph.add_node(newNode, division=False, count=1, origin=node)

                        for e in self.unresolvedGraph.out_edges(node):
                            self.resolvedGraph.add_edge(newNode, e[1])
                        for e in self.unresolvedGraph.in_edges(node):
                            if 'newIds' in self.unresolvedGraph.node[e[0]]:
                                # predecessor was de-merged itself: link all of its replacements
                                for newId in self.unresolvedGraph.node[e[0]]['newIds']:
                                    self.resolvedGraph.add_edge((e[0][0], newId), newNode)
                            else:
                                self.resolvedGraph.add_edge(e[0], newNode)

                    self.resolvedGraph.remove_node(node)
                    self.unresolvedGraph.node[node]['newIds'] = newIds
                    nextObjectId += count

                # each unresolved node stores its fitted shape(s) to be used
                # as initialization in the next frame, this way division duplicates
                # and de-merged nodes in the resolved graph do not need to store a fit as well
                self.unresolvedGraph.node[node]['fits'] = fittedObjects

    def _minCostMaxFlowMergerResolving(self, objectFeatures, transitionClassifier=None, transitionParameter=5.0):
        """
        Find the optimal assignments within the `resolvedGraph` by running min-cost max-flow from the
        `dpct` module.

        Converts the `resolvedGraph` to our JSON model structure, predicts the transition probabilities
        either using the given transitionClassifier, or using distance-based probabilities
        (`exp(-distance / transitionParameter)` between the `RegionCenter` features).

        **returns** a `nodeFlowMap` and `arcFlowMap` holding information on the usage of the respective
        nodes and links

        **Note:** cannot use `networkx` flow methods because they don't work with floating point weights.
        """
        trackingGraph = JsonTrackingGraph()
        for node in self.resolvedGraph.nodes_iter():
            additionalFeatures = {}

            # nodes with no in/out
            numStates = 2

            if len(self.resolvedGraph.in_edges(node)) == 0:
                # division nodes with no incoming arcs offer 2 units of flow without the need to de-merge
                # (`node in graph` is a dictionary lookup, unlike scanning `graph.nodes()`)
                if node in self.unresolvedGraph \
                        and self.unresolvedGraph.node[node]['division'] \
                        and len(self.unresolvedGraph.out_edges(node)) == 2:
                    numStates = 3
                additionalFeatures['appearanceFeatures'] = [[i**2 * 0.01] for i in range(numStates)]
            if len(self.resolvedGraph.out_edges(node)) == 0:
                # division nodes with no incoming should have outgoing,
                # or they shouldn't show up in resolved graph
                assert(numStates == 2)
                additionalFeatures['disappearanceFeatures'] = [[i**2 * 0.01] for i in range(numStates)]

            features = [[i**2] for i in range(numStates)]
            uuid = trackingGraph.addDetectionHypotheses(features, **additionalFeatures)
            self.resolvedGraph.node[node]['id'] = uuid

        for edge in self.resolvedGraph.edges_iter():
            src = self.resolvedGraph.node[edge[0]]['id']
            dest = self.resolvedGraph.node[edge[1]]['id']

            featuresAtSrc = objectFeatures[edge[0]]
            featuresAtDest = objectFeatures[edge[1]]

            if transitionClassifier is not None:
                try:
                    featVec = self.pluginManager.applyTransitionFeatureVectorConstructionPlugins(
                        featuresAtSrc, featuresAtDest, transitionClassifier.selectedFeatures)
                except Exception:
                    # log the offending inputs, then let the original error propagate
                    getLogger().error("Could not compute transition features of link {}->{}:".format(src, dest))
                    getLogger().error(featuresAtSrc)
                    getLogger().error(featuresAtDest)
                    raise
                featVec = np.expand_dims(np.array(featVec), axis=0)
                probs = transitionClassifier.predictProbabilities(featVec)[0]
            else:
                dist = np.linalg.norm(featuresAtDest['RegionCenter'] - featuresAtSrc['RegionCenter'])
                prob = np.exp(-dist / transitionParameter)
                probs = [1.0 - prob, prob]

            trackingGraph.addLinkingHypotheses(src, dest, listify(negLog(probs)))

        # track
        import dpct
        weights = {"weights": [1, 1, 1, 1]}
        mergerResult = dpct.trackMaxFlow(trackingGraph.model, weights)

        # transform results to dictionaries that can be indexed by id or (src,dest)
        nodeFlowMap = {int(d['id']): int(d['value']) for d in mergerResult['detectionResults']}
        arcFlowMap = {(int(link['src']), int(link['dest'])): int(link['value'])
                      for link in mergerResult['linkingResults']}

        return nodeFlowMap, arcFlowMap

    def _refineModel(self, uuidToTraxelMap, traxelIdPerTimestepToUniqueIdMap, mergerNodeFilter, mergerLinkFilter):
        """
        Take the `self.model` (JSON format) with mergers, remove the merger nodes, but add new
        de-merged nodes and links. Also updates `traxelIdPerTimestepToUniqueIdMap` in place,
        such that the traxel IDs match the new connected component IDs in the refined images.

        `mergerNodeFilter` and `mergerLinkFilter` are methods that can filter merger detections
        and links from the respective lists in the `model` dict (entries for which they return
        a true value are kept).

        **Returns** the updated `model` dictionary, which is the same object as `self.model`
        """
        # remove merger detections
        self.model['segmentationHypotheses'] = [
            seg for seg in self.model['segmentationHypotheses'] if mergerNodeFilter(seg)]

        # remove merger links
        self.model['linkingHypotheses'] = [
            link for link in self.model['linkingHypotheses'] if mergerLinkFilter(link)]

        # insert new nodes and update UUID to traxel map
        nextUuid = max(uuidToTraxelMap.keys()) + 1
        for node in self.unresolvedGraph.nodes_iter():
            nodeData = self.unresolvedGraph.node[node]
            if 'count' in nodeData and nodeData['count'] > 1:
                frameKey = str(node[0])
                del traxelIdPerTimestepToUniqueIdMap[frameKey][str(node[1])]
                for newId in nodeData['newIds']:
                    newDetection = {'id': nextUuid, 'timestep': [node[0], node[0]]}
                    self.model['segmentationHypotheses'].append(newDetection)
                    traxelIdPerTimestepToUniqueIdMap[frameKey][str(newId)] = nextUuid
                    nextUuid += 1

        # insert new links
        for edge in self.resolvedGraph.edges_iter():
            newLink = {
                'src': traxelIdPerTimestepToUniqueIdMap[str(edge[0][0])][str(edge[0][1])],
                'dest': traxelIdPerTimestepToUniqueIdMap[str(edge[1][0])][str(edge[1][1])],
            }
            self.model['linkingHypotheses'].append(newLink)

        return self.model

    def _refineResult(self, nodeFlowMap, arcFlowMap, traxelIdPerTimestepToUniqueIdMap,
                      mergerNodeFilter, mergerLinkFilter):
        """
        Update the `self.result` dict by removing the mergers and adding the refined nodes and links.

        Operates on a `result` dictionary in our JSON result style with mergers,
        the resolved and unresolved graph as well as
        the `nodeFlowMap` and `arcFlowMap` obtained by running tracking on the `resolvedGraph`.

        Updates the `result` dictionary so that all merger nodes are removed but the new nodes
        are contained with the appropriate links and values.

        `mergerNodeFilter` and `mergerLinkFilter` are methods that can filter merger detections
        and links from the respective lists in the `result` dict.

        **Returns** the updated `result` dict, which is the same object as `self.result`
        """
        # filter merger edges
        self.result['detectionResults'] = [r for r in self.result['detectionResults'] if mergerNodeFilter(r)]
        self.result['linkingResults'] = [r for r in self.result['linkingResults'] if mergerLinkFilter(r)]

        # add new nodes
        for node in self.unresolvedGraph.nodes_iter():
            nodeData = self.unresolvedGraph.node[node]
            if 'count' in nodeData and nodeData['count'] > 1:
                for newId in nodeData['newIds']:
                    uuid = traxelIdPerTimestepToUniqueIdMap[str(node[0])][str(newId)]
                    resolvedResultId = self.resolvedGraph.node[(node[0], newId)]['id']
                    newDetection = {'id': uuid, 'value': nodeFlowMap[resolvedResultId]}
                    self.result['detectionResults'].append(newDetection)

        # add new links
        for edge in self.resolvedGraph.edges_iter():
            srcId = self.resolvedGraph.node[edge[0]]['id']
            destId = self.resolvedGraph.node[edge[1]]['id']
            newLink = {
                'src': traxelIdPerTimestepToUniqueIdMap[str(edge[0][0])][str(edge[0][1])],
                'dest': traxelIdPerTimestepToUniqueIdMap[str(edge[1][0])][str(edge[1][1])],
                'value': arcFlowMap[(srcId, destId)],
            }
            self.result['linkingResults'].append(newLink)

        return self.result

    def _exportRefinedSegmentation(self, timesteps):
        """
        Store the resulting label images, if needed.

        `timesteps` is the list of timestep strings handled by `run`.
        Does nothing in this base class.
        """
        pass

    def _computeObjectFeatures(self, timesteps):
        '''
        Return the features per object as nested dictionaries:
        { (int(Timestep), int(Id)):{ "FeatureName" : np.array(value), "NextFeature": ...} }

        Does nothing (returns None) in this base class.
        '''
        pass

    # ------------------------------------------------------------
    def run(self, transition_classifier_filename=None, transition_classifier_path=None):
        """
        Run merger resolving

        1. find mergers in the given model and result
        2. build graph of the unresolved (merger) nodes and their direct neighbors
        3. use a mergerResolving plugin to refine the merger nodes and their segmentation
        4. run min-cost max-flow tracking to find the fate of all the de-merged objects
        5. export refined segmentation, update member variables `model` and `result`

        **Returns** a nested dictionary, indexed first by time, then object Id, containing a list
        of new segmentIDs per merger. If there are no mergers at all, only the segmentation is
        exported and `None` is returned.
        """
        traxelIdPerTimestepToUniqueIdMap, uuidToTraxelMap = \
            hytra.core.jsongraph.getMappingsBetweenUUIDsAndTraxels(self.model)

        # there might be empty frames. We want them as output too.
        timesteps = self._contiguousTimesteps(traxelIdPerTimestepToUniqueIdMap)

        mergers, detections, links, divisions = \
            hytra.core.jsongraph.getMergersDetectionsLinksDivisions(self.result, uuidToTraxelMap)

        # ------------------------------------------------------------
        # it may be, that there are no mergers, so do basically nothing, just copy all the ingoing data
        if len(mergers) == 0:
            getLogger().info("The maximum number of objects is 1, so nothing to be done. Writing the output...")
            self._exportRefinedSegmentation(timesteps)
            return None

        self.mergersPerTimestep = hytra.core.jsongraph.getMergersPerTimestep(mergers, timesteps)
        self.detectionsPerTimestep = hytra.core.jsongraph.getDetectionsPerTimestep(detections, timesteps)

        linksPerTimestep = hytra.core.jsongraph.getLinksPerTimestep(links, timesteps)
        divisionsPerTimestep = hytra.core.jsongraph.getDivisionsPerTimestep(
            divisions, linksPerTimestep, timesteps)
        mergerLinks = hytra.core.jsongraph.getMergerLinks(linksPerTimestep, self.mergersPerTimestep, timesteps)

        # set up unresolved graph and then refine the nodes to get the resolved graph
        self._createUnresolvedGraph(divisionsPerTimestep, self.mergersPerTimestep, mergerLinks)
        self._prepareResolvedGraph()
        self._fitAndRefineNodes(self.detectionsPerTimestep, self.mergersPerTimestep, timesteps)

        # ------------------------------------------------------------
        # compute new object features
        objectFeatures = self._computeObjectFeatures(timesteps)

        # ------------------------------------------------------------
        # load transition classifier if any
        if transition_classifier_filename is not None:
            getLogger().info("\tLoading transition classifier")
            transitionClassifier = probabilitygenerator.RandomForestClassifier(
                transition_classifier_path, transition_classifier_filename)
        else:
            getLogger().info("\tUsing distance based transition energies")
            transitionClassifier = None

        # ------------------------------------------------------------
        # run min-cost max-flow to find merger assignments
        getLogger().info("Running min-cost max-flow to find resolved merger assignments")
        nodeFlowMap, arcFlowMap = self._minCostMaxFlowMergerResolving(objectFeatures, transitionClassifier)

        # ------------------------------------------------------------
        # fuse results into a new solution

        # 1.) replace merger nodes in JSON graph by their replacements -> new JSON graph
        #     update UUID to traxel map.
        #     a) how do we deal with the smaller number of states?
        #        Does it matter as we're done with tracking anyway..?

        def containsMerger(traxels):
            ''' True if any (timestep, id) traxel of the list is a merger '''
            return any(t[1] in self.mergersPerTimestep[str(t[0])] for t in traxels)

        def mergerNodeFilter(jsonNode):
            return not containsMerger(uuidToTraxelMap[int(jsonNode['id'])])

        def mergerLinkFilter(jsonLink):
            # return True if there was no traxel in either source or target node that was a merger.
            srcTraxels = uuidToTraxelMap[int(jsonLink['src'])]
            destTraxels = uuidToTraxelMap[int(jsonLink['dest'])]
            return not (containsMerger(srcTraxels) or containsMerger(destTraxels))

        self.model = self._refineModel(
            uuidToTraxelMap, traxelIdPerTimestepToUniqueIdMap, mergerNodeFilter, mergerLinkFilter)

        # 2.) new result = union(old result, resolved mergers) - old mergers
        self.result = self._refineResult(
            nodeFlowMap, arcFlowMap, traxelIdPerTimestepToUniqueIdMap, mergerNodeFilter, mergerLinkFilter)

        # 3.) export refined segmentation
        self._exportRefinedSegmentation(timesteps)

        # return a dictionary telling about which mergers were resolved into what
        mergerDict = {}
        for n in self.unresolvedGraph.nodes_iter():
            nodeData = self.unresolvedGraph.node[n]
            # skip non-mergers
            if 'newIds' not in nodeData or len(nodeData['newIds']) < 2:
                continue
            mergerDict.setdefault(n[0], {})[n[1]] = nodeData['newIds']

        return mergerDict

    def relabelMergers(self, labelImage, time):
        """
        Calls the merger resolving plugin to relabel the mergers based on a previously found fit,
        which is stored in the hypotheses graph node.

        Returns the given `labelImage`; it is returned untouched if no detections are known
        for `time`.
        """
        t = str(time)
        if self.detectionsPerTimestep is not None and t in self.detectionsPerTimestep:
            for idx in self.detectionsPerTimestep[t]:
                if idx not in self.mergersPerTimestep[t]:
                    continue

                # use fits stored in graph
                nodeData = self.unresolvedGraph.node[(time, idx)]
                fits = nodeData['fits']
                newIds = nodeData['newIds']

                # use merger resolving plugin to update labelImage with merger IDs
                self.mergerResolverPlugin.updateLabelImage(labelImage, idx, fits, newIds)

        return labelImage
class IlpProbabilityGenerator(ProbabilityGenerator):
    """
    The IlpProbabilityGenerator is a python wrapper around pgmlink's C++ traxelstore,
    but with the functionality to compute all region features
    and evaluate the division/count/transition classifiers.
    """

    def __init__(self,
                 ilpOptions,
                 turnOffFeatures=[],
                 useMultiprocessing=True,
                 pluginPaths=['hytra/plugins'],
                 verbose=False):
        """
        Set up the plugin manager, load the configured classifiers and read the
        shape and time range of the label image referenced by `ilpOptions`.

        `turnOffFeatures`, `pluginPaths` and `verbose` are forwarded to the
        `TrackingPluginManager`; the default lists are not modified here.
        """
        self._useMultiprocessing = useMultiprocessing
        self._options = ilpOptions
        self._pluginPaths = pluginPaths
        self._pluginManager = TrackingPluginManager(
            turnOffFeatures=turnOffFeatures,
            verbose=verbose,
            pluginPaths=pluginPaths)
        self._pluginManager.setImageProvider(ilpOptions.imageProviderName)
        self._pluginManager.setFeatureSerializer(
            ilpOptions.featureSerializerName)

        self._countClassifier = None
        self._divisionClassifier = None
        self._transitionClassifier = None
        self._loadClassifiers()

        self.shape, self.timeRange = self._getShapeAndTimeRange()

        # set default division feature names
        self._divisionFeatureNames = [
            'ParentChildrenRatio_Count', 'ParentChildrenRatio_Mean',
            'ChildrenRatio_Count', 'ChildrenRatio_Mean',
            'ParentChildrenAngle_RegionCenter',
            'ChildrenRatio_SquaredDistances'
        ]

        # other parameters that one might want to set
        self.x_scale = 1.0
        self.y_scale = 1.0
        self.z_scale = 1.0
        self.divisionProbabilityFeatureName = 'divProb'
        self.detectionProbabilityFeatureName = 'detProb'

        # this public variable contains all traxels if we're not using pgmlink
        self.TraxelsPerFrame = {}

        # filled by `fillTraxels`, read by `getTraxelFeatureDict`
        self._featuresPerFrame = None

    def _loadClassifiers(self):
        """
        Instantiate the count, division and transition random forests for which both
        path and filename are set in the options. Attributes of classifiers that are
        not configured are left untouched.
        """
        options = self._options
        if options.objectCountClassifierPath is not None \
                and options.objectCountClassifierFilename is not None:
            self._countClassifier = RandomForestClassifier(
                options.objectCountClassifierPath,
                options.objectCountClassifierFilename, options)
        if options.divisionClassifierPath is not None \
                and options.divisionClassifierFilename is not None:
            self._divisionClassifier = RandomForestClassifier(
                options.divisionClassifierPath,
                options.divisionClassifierFilename, options)
        if options.transitionClassifierPath is not None \
                and options.transitionClassifierFilename is not None:
            self._transitionClassifier = RandomForestClassifier(
                options.transitionClassifierPath,
                options.transitionClassifierFilename, options)

    def __getstate__(self):
        '''
        We define __getstate__ and __setstate__ to exclude the random forests from being pickled,
        as that is not allowed.

        See https://docs.python.org/3/library/pickle.html#pickle-state for more details.
        '''
        # Copy the object's state from self.__dict__ which contains
        # all our instance attributes. Always use the dict.copy()
        # method to avoid modifying the original state.
        state = self.__dict__.copy()
        # Remove the unpicklable entries (tolerating ones that are already absent).
        for name in ('_countClassifier', '_divisionClassifier', '_transitionClassifier'):
            state.pop(name, None)
        return state

    def __setstate__(self, state):
        """ Restore the instance attributes and reload the random forests. """
        # Restore instance attributes
        self.__dict__.update(state)
        # The classifier attributes were stripped by __getstate__ and `_loadClassifiers`
        # only assigns the configured ones, so reset all of them first. Otherwise a
        # restored object lacks the attribute and fails on `self._xClassifier is not None`.
        self._countClassifier = None
        self._divisionClassifier = None
        self._transitionClassifier = None
        # Restore the random forests by reading them from scratch
        self._loadClassifiers()

    def computeRegionFeatures(self, rawImage, labelImage, frameNumber):
        """
        Computes all region features for all objects in the given image.

        Merges the dictionaries returned by the object feature plugins (later ones win on
        equal keys) and drops the feature names the plugin manager asks to ignore.
        `labelImage` must have dtype `np.uint32`.
        """
        assert (labelImage.dtype == np.uint32)

        moreFeats, ignoreNames = self._pluginManager.applyObjectFeatureComputationPlugins(
            len(labelImage.shape), rawImage, labelImage, frameNumber,
            self._options.rawImageFilename)
        frameFeatures = {}
        for f in moreFeats:
            frameFeatures.update(f)

        # delete the "Global<Min/Max>" features as they are not nice when iterating over everything
        for k in ignoreNames:
            frameFeatures.pop(k, None)

        return frameFeatures

    def computeDivisionFeatures(self, featuresAtT, featuresAtTPlus1,
                                labelImageAtTPlus1):
        """
        Computes the division features for all objects in the images
        """
        fm = hytra.core.divisionfeatures.FeatureManager(
            ndim=self.getNumDimensions())
        return fm.computeFeatures_at(featuresAtT, featuresAtTPlus1,
                                     labelImageAtTPlus1,
                                     self._divisionFeatureNames)

    def setDivisionFeatures(self, divisionFeatures):
        """
        Set which features should be computed explicitly for divisions by giving a list of strings.
        Each string could be a combination of <operation>_<feature>, where Operation is one of:
            * ParentIdentity
            * SquaredDistances
            * ChildrenRatio
            * ParentChildrenAngle
            * ParentChildrenRatio

        And <feature> is any region feature plus "SquaredDistances"
        """
        # TODO: check that the strings are valid?
        self._divisionFeatureNames = divisionFeatures

    def getNumDimensions(self):
        """
        Compute the number of dimensions which is the number of axis with more than 1 element
        """
        return np.count_nonzero(np.array(self.shape) != 1)

    def _getShapeAndTimeRange(self):
        """
        extract the shape and the time range from the labelimage
        """
        imageProvider = self._pluginManager.getImageProvider()
        shape = imageProvider.getImageShape(
            self._options.labelImageFilename, self._options.labelImagePath)
        timerange = imageProvider.getTimeRange(
            self._options.labelImageFilename, self._options.labelImagePath)
        return shape, timerange

    def getLabelImageForFrame(self, timeframe):
        """
        Get the label image(volume) of one time frame
        """
        return self._pluginManager.getImageProvider().getLabelImageForFrame(
            self._options.labelImageFilename, self._options.labelImagePath,
            timeframe)

    def getRawImageForFrame(self, timeframe):
        """
        Get the raw image(volume) of one time frame
        """
        return self._pluginManager.getImageProvider().getImageDataAtTimeFrame(
            self._options.rawImageFilename, self._options.rawImagePath,
            timeframe)

    def _extractFeaturesForFrame(self, timeframe):
        """
        extract the features of one frame, return the frame number and a dictionary of features,
        where each feature vector contains N entries per object (where N is the dimensionality of the feature)
        """
        rawImage = self.getRawImageForFrame(timeframe)
        labelImage = self.getLabelImageForFrame(timeframe)

        return timeframe, self.computeRegionFeatures(rawImage, labelImage,
                                                     timeframe)

    def _extractDivisionFeaturesForFrame(self, timeframe, featuresPerFrame):
        """
        extract Division Features for one frame from the given featuresPerFrame dict.

        Returns the frame number and the division features, which are an empty dict
        for the last frame of the time range.
        """
        feats = {}
        if timeframe + 1 < self.timeRange[1]:
            labelImageAtTPlus1 = self.getLabelImageForFrame(timeframe + 1)
            feats = self.computeDivisionFeatures(
                featuresPerFrame[timeframe], featuresPerFrame[timeframe + 1],
                labelImageAtTPlus1)

        return timeframe, feats

    def _extractAllFeatures(self, dispyNodeIps=[], turnOffFeatures=[]):
        """
        Extract the features of all frames.

        If a list of IP addresses is given e.g. as `dispyNodeIps = ["104.197.178.206","104.196.46.138"]`,
        then the computation will be distributed across these nodes. Otherwise, multiprocessing will
        be used if `self._useMultiprocessing=True`, which it is by default.

        If `dispyNodeIps` is an empty list, then the feature extraction will be parallelized via
        multiprocessing.

        Returns a dictionary mapping each frame number to its feature dictionary.

        **TODO:** fix division feature computation for distributed mode
        """
        import logging
        # configure progress bar
        numSteps = self.timeRange[1] - self.timeRange[0]
        if self._divisionClassifier is not None:
            numSteps *= 2

        t0 = time.time()

        # bound up front so that both branches below return a dictionary
        featuresPerFrame = {}

        if len(dispyNodeIps) == 0:
            # no dispy node IDs given, parallelize object feature computation via processes

            if self._useMultiprocessing:
                # use ProcessPoolExecutor, which instanciates as many processes as there CPU cores by default
                ExecutorType = concurrent.futures.ProcessPoolExecutor
                logging.getLogger('Traxelstore').info(
                    'Parallelizing feature extraction via multiprocessing on all cores!'
                )
            else:
                ExecutorType = DummyExecutor
                logging.getLogger('Traxelstore').info(
                    'Running feature extraction on single core!')

            progressBar = ProgressBar(stop=numSteps)
            progressBar.show(increase=0)

            with ExecutorType() as executor:
                # 1st pass for region features
                jobs = []
                for frame in range(self.timeRange[0], self.timeRange[1]):
                    jobs.append(
                        executor.submit(computeRegionFeaturesOnCloud, frame,
                                        self._options.rawImageFilename,
                                        self._options.rawImagePath,
                                        self._options.rawImageAxes,
                                        self._options.labelImageFilename,
                                        self._options.labelImagePath,
                                        turnOffFeatures, self._pluginPaths))
                for job in concurrent.futures.as_completed(jobs):
                    progressBar.show()
                    frame, feats = job.result()
                    featuresPerFrame[frame] = feats

                # 2nd pass for division features
                if self._divisionClassifier is not None:
                    jobs = []
                    for frame in range(self.timeRange[0],
                                       self.timeRange[1] - 1):
                        jobs.append(
                            executor.submit(
                                computeDivisionFeaturesOnCloud, frame,
                                featuresPerFrame[frame],
                                featuresPerFrame[frame + 1],
                                self._pluginManager.getImageProvider(),
                                self._options.labelImageFilename,
                                self._options.labelImagePath,
                                self.getNumDimensions(),
                                self._divisionFeatureNames))

                    for job in concurrent.futures.as_completed(jobs):
                        progressBar.show()
                        frame, feats = job.result()
                        featuresPerFrame[frame].update(feats)

            # TODO: serialize features via the feature serializer plugin?
        else:
            logging.getLogger('Traxelstore').warning(
                'Parallelization with dispy is WORK IN PROGRESS!')
            import dispy
            cluster = dispy.JobCluster(computeRegionFeaturesOnCloud,
                                       nodes=dispyNodeIps,
                                       loglevel=logging.DEBUG,
                                       depends=[self._pluginManager],
                                       secret="teamtracking")

            jobs = []
            for frame in range(self.timeRange[0], self.timeRange[1]):
                # NOTE(review): hard-coded plugin path of a development machine, kept as found;
                # presumably this has to be the plugin location on the dispy nodes - confirm.
                job = cluster.submit(
                    frame,
                    self._options.rawImageFilename,
                    self._options.rawImagePath,
                    self._options.rawImageAxes,
                    self._options.labelImageFilename,
                    self._options.labelImagePath,
                    turnOffFeatures,
                    pluginPaths=['/home/carstenhaubold/embryonic/plugins'])
                job.id = frame
                jobs.append(job)

            for job in jobs:
                result = job()  # wait for job to finish, yields the job's result
                print(job.exception)
                print(job.stdout)
                print(job.stderr)
                print(job.id)
                # the computation returns (frame, features) like in the executor
                # branch above; a failed job has no result to store
                if result is not None:
                    frame, feats = result
                    featuresPerFrame[frame] = feats

            logging.getLogger('Traxelstore').warning(
                'Using dispy we cannot compute division features yet!')

        t1 = time.time()
        getLogger().info("Feature computation took {} secs".format(t1 - t0))

        return featuresPerFrame

    def _setTraxelFeatureArray(self, traxel, featureArray, name):
        '''
        store the specified `featureArray` in a `traxel`'s feature dictionary under the specified key=`name`

        numpy arrays are flattened first, every value is stored as float.
        '''
        if isinstance(featureArray, np.ndarray):
            featureArray = featureArray.flatten()
        traxel.add_feature_array(name, len(featureArray))
        for i, v in enumerate(featureArray):
            traxel.set_feature_value(name, i, float(v))

    def fillTraxels(self,
                    usePgmlink=True,
                    ts=None,
                    fs=None,
                    dispyNodeIps=[],
                    turnOffFeatures=[]):
        """
        Compute all the features and predict object count as well as division probabilities.
        Store the resulting information (and all other features) in the given pgmlink::TraxelStore,
        or create a new one if ts=None.

        usePgmlink: boolean whether pgmlink should be used and a pgmlink.TraxelStore and pgmlink.FeatureStore returned
        ts: an initial pgmlink.TraxelStore (only used if usePgmlink=True)
        fs: an initial pgmlink.FeatureStore (only used if usePgmlink=True)

        returns (ts, fs) but only if usePgmlink=True, otherwise it fills self.TraxelsPerFrame

        Raises an AssertionError if a feature cannot be read for or stored in a traxel.
        """
        if usePgmlink:
            import pgmlink
            if ts is None:
                ts = pgmlink.TraxelStore()
                fs = pgmlink.FeatureStore()
            else:
                assert (fs is not None)

        getLogger().info("Extracting features...")
        self._featuresPerFrame = self._extractAllFeatures(
            dispyNodeIps=dispyNodeIps, turnOffFeatures=turnOffFeatures)

        getLogger().info("Creating traxels...")
        progressBar = ProgressBar(stop=len(self._featuresPerFrame))
        progressBar.show(increase=0)

        sizeFilter = self._options.sizeFilter

        for frame, features in self._featuresPerFrame.items():
            # division probabilities are not predicted for the last frame of the time range
            predictDivisions = self._divisionClassifier is not None \
                and frame + 1 < self.timeRange[1]

            # predict random forests
            if self._countClassifier is not None:
                objectCountProbabilities = self._countClassifier.predictProbabilities(
                    features=None, featureDict=features)

            if predictDivisions:
                divisionProbabilities = self._divisionClassifier.predictProbabilities(
                    features=None, featureDict=features)

            # create traxels for all objects, label 0 is skipped;
            # the number of rows is taken from the first feature of the dictionary
            numObjects = list(features.values())[0].shape[0]
            for objectId in range(1, numObjects):
                pixelSize = features['Count'][objectId]
                if pixelSize == 0 or (sizeFilter is not None
                                      and (pixelSize < sizeFilter[0]
                                           or pixelSize > sizeFilter[1])):
                    continue

                # create traxel
                if usePgmlink:
                    traxel = pgmlink.Traxel()
                else:
                    traxel = Traxel()
                traxel.Id = objectId
                traxel.Timestep = frame

                # add raw features
                for key, val in features.items():
                    if key == 'id':
                        traxel.idInSegmentation = val[objectId]
                        continue
                    if key == 'filename':
                        traxel.segmentationFilename = val[objectId]
                        continue

                    try:
                        if isinstance(val, list):  # polygon feature returns a list!
                            featureValues = val[objectId]
                        else:
                            featureValues = val[objectId, ...]
                    except Exception:
                        # lists have no `shape`, do not fail while reporting the failure
                        getLogger().error(
                            "Could not get feature values of {} for key {} from matrix with shape {}"
                            .format(objectId, key, getattr(val, 'shape', None)))
                        raise AssertionError()
                    try:
                        self._setTraxelFeatureArray(traxel, featureValues, key)
                        if key == 'RegionCenter':
                            self._setTraxelFeatureArray(traxel, featureValues, 'com')
                    except Exception:
                        getLogger().error(
                            "Could not add feature array {} for {}".format(
                                featureValues, key))
                        raise AssertionError()

                # add random forest predictions
                if self._countClassifier is not None:
                    self._setTraxelFeatureArray(
                        traxel, objectCountProbabilities[objectId, :],
                        self.detectionProbabilityFeatureName)

                if predictDivisions:
                    self._setTraxelFeatureArray(
                        traxel, divisionProbabilities[objectId, :],
                        self.divisionProbabilityFeatureName)

                # set other parameters
                traxel.set_x_scale(self.x_scale)
                traxel.set_y_scale(self.y_scale)
                traxel.set_z_scale(self.z_scale)

                if usePgmlink:
                    # add to pgmlink's traxelstore
                    ts.add(fs, traxel)
                else:
                    self.TraxelsPerFrame.setdefault(frame, {})[objectId] = traxel
            progressBar.show()

        if usePgmlink:
            return ts, fs

    def getTraxelFeatureDict(self, frame, objectId):
        """
        Getter method for features per traxel.

        Requires that the features have been extracted (see `fillTraxels`) before.
        """
        # getattr: instances restored from dumps written before the attribute was
        # initialised in __init__ may not carry it at all
        featuresPerFrame = getattr(self, '_featuresPerFrame', None)
        assert featuresPerFrame is not None
        traxelFeatureDict = {}
        for k, v in featuresPerFrame[frame].items():
            if 'Polygon' in k:
                traxelFeatureDict[k] = v[objectId]
            else:
                traxelFeatureDict[k] = v[objectId, ...]
        return traxelFeatureDict

    def getTransitionFeatureVector(self, featureDictObjectA,
                                   featureDictObjectB, selectedFeatures):
        """
        Return component wise difference and product of the selected features as input for the TransitionClassifier

        The vector built by the transition feature plugins is returned as array with a
        leading axis of length one.
        """
        features = np.array(
            self._pluginManager.applyTransitionFeatureVectorConstructionPlugins(
                featureDictObjectA, featureDictObjectB, selectedFeatures))
        features = np.expand_dims(features, axis=0)
        return features
def run_pipeline(options):
    """
    Run the complete tracking pipeline with competing segmentation hypotheses.

    Steps, in order:

    1. build the graph state via `setupGraph(options)`, or load a previously
       dumped state from `options.load_graph_filename`
    2. optionally dump a freshly built state to `options.dump_graph_filename`
    3. map the ground truth onto the hypotheses graph (only if all four
       `gt_*` options are set) to obtain weights
    4. run tracking and insert the solution into the hypotheses graph
    5. write the track text file via `save_tracks`
    6. relabel every frame's label images by track id and save them as tif

    **Parameters**

    * `options`: the parsed options object; this function reads
      `load_graph_filename`, `dump_graph_filename`, `gt_label_image_file`,
      `gt_label_image_path`, `gt_text_file`, `gt_jaccard_threshold`,
      `verbose`, `pluginPaths`, `label_image_files` and `label_image_paths`
      and forwards the whole object to the helpers it calls.

    **Raises** `ValueError` if a graph node has no `trackId` attribute after
    the lineage computation.
    """
    # set up probabilitygenerator (aka traxelstore) and hypothesesgraph or load them from a dump
    if options.load_graph_filename is not None:
        getLogger().info("Loading state from file: " + options.load_graph_filename)
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # only load state dumps that come from a trusted source.
        with gzip.open(options.load_graph_filename, 'r') as graphDump:
            # the load order must mirror the dump order below
            ilpOptions = pickle.load(graphDump)
            probGenerator = pickle.load(graphDump)
            fieldOfView = pickle.load(graphDump)
            hypotheses_graph = pickle.load(graphDump)
            trackingGraph = pickle.load(graphDump)
        getLogger().info("Done loading state from file")
    else:
        fieldOfView, hypotheses_graph, ilpOptions, probGenerator, trackingGraph = setupGraph(options)

        if options.dump_graph_filename is not None:
            getLogger().info("Saving state to file: " + options.dump_graph_filename)
            with gzip.open(options.dump_graph_filename, 'w') as graphDump:
                for state in (ilpOptions, probGenerator, fieldOfView, hypotheses_graph, trackingGraph):
                    pickle.dump(state, graphDump)
            getLogger().info("Done saving state to file")

    # map groundtruth to hypothesesgraph if all required variables are specified
    weights = None
    groundTruthOptions = (options.gt_label_image_file, options.gt_label_image_path,
                          options.gt_text_file, options.gt_jaccard_threshold)
    if all(o is not None for o in groundTruthOptions):
        weights = mapGroundTruth(options, hypotheses_graph, trackingGraph, probGenerator)

    # track
    result = runTracking(options, trackingGraph, weights)

    # insert the solution into the hypotheses graph and from that deduce the lineages
    getLogger().info("Inserting solution into graph")
    hypotheses_graph.insertSolution(result)
    hypotheses_graph.computeLineage()

    mappings = {}  # dictionary over timeframes, containing another dict objectId -> trackId per frame
    tracks = {}  # stores a list of timeframes per track, so that we can find from<->to per track
    trackParents = {}  # store the parent trackID of a track if known

    for n in hypotheses_graph.nodeIterator():
        frameMapping = mappings.setdefault(n[0], {})
        nodeAttributes = hypotheses_graph._graph.node[n]
        if 'trackId' not in nodeAttributes:
            raise ValueError("You need to compute the Lineage of every node before accessing the trackId!")
        trackId = nodeAttributes['trackId']
        if trackId is None:
            # node has no track assigned: it must neither show up in the
            # relabeled images nor be recorded as a track with id `None`
            continue

        traxel = nodeAttributes['traxel']
        frameMapping[(traxel.idInSegmentation, traxel.segmentationFilename)] = trackId
        tracks.setdefault(trackId, []).append(n[0])
        if 'parent' in nodeAttributes:
            assert(trackId not in trackParents)
            trackParents[trackId] = hypotheses_graph._graph.node[nodeAttributes['parent']]['trackId']

    # write res_track.txt
    getLogger().info("Writing track text file")
    trackDict = {}
    for trackId, timestepList in tracks.items():
        timestepList.sort()
        # tracks without a known parent get parent id 0
        parent = trackParents.get(trackId, 0)
        trackDict[trackId] = [parent, timestepList[0], timestepList[-1]]
    save_tracks(trackDict, options)

    # export results
    getLogger().info("Saving relabeled images")
    pluginManager = TrackingPluginManager(verbose=options.verbose, pluginPaths=options.pluginPaths)
    pluginManager.setImageProvider('LocalImageLoader')
    imageProvider = pluginManager.getImageProvider()
    timeRange = probGenerator.timeRange

    for timeframe in range(timeRange[0], timeRange[1]):
        label_images = {}
        for f, p in zip(options.label_image_files, options.label_image_paths):
            label_images[f] = imageProvider.getLabelImageForFrame(f, p, timeframe)
        # a frame without any graph node has no entry in `mappings`;
        # treat it as "nothing to relabel" instead of failing with a KeyError
        remapped_label_image = remap_label_image(label_images, mappings.get(timeframe, {}))
        save_frame_to_tif(timeframe, remapped_label_image, options)
# Do the tracking start = time.time() feature_path = options.feats_path with_div = not bool(options.without_divisions) with_merger_prior = True # get selected time range time_range = [options.mints, options.maxts] if options.maxts == -1 and options.mints == 0: time_range = None try: # find shape of dataset pluginManager = TrackingPluginManager(verbose=options.verbose, pluginPaths=options.pluginPaths) shape = pluginManager.getImageProvider().getImageShape(ilp_fn, options.label_img_path) data_time_range = pluginManager.getImageProvider().getTimeRange(ilp_fn, options.label_img_path) if time_range is not None and time_range[1] < 0: time_range[1] += data_time_range[1] except: logging.warning("Could not read shape and time range from images") # set average object size if chosen obj_size = [0] if options.avg_obj_size != 0: obj_size[0] = options.avg_obj_size else: options.avg_obj_size = obj_size
class MergerResolver(object):
    """
    Base class for all merger resolving implementations. Use one of the derived classes
    that handle reading/writing data to the respective sources.

    Derived classes are expected to fill `self.model` and `self.result` in their
    constructors and to implement `_readLabelImage`, `_exportRefinedSegmentation`
    and `_computeObjectFeatures`.
    """

    def __init__(self,
                 pluginPaths=[os.path.abspath('../hytra/plugins')],
                 numSplits=None,
                 verbose=False,
                 progressVisitor=None):
        """
        **Parameters**

        * `pluginPaths`: forwarded to the `TrackingPluginManager` (the default list is
          only forwarded, never modified here)
        * `numSplits`: if truthy, tracking of the resolved graph uses
          `SplitTracking.trackFlowBasedWithSplits` with that many splits
        * `verbose`: forwarded to the `TrackingPluginManager`
        * `progressVisitor`: progress reporting object; a fresh `DefaultProgressVisitor`
          is created per instance when `None`, instead of sharing one default object
          created at class definition time between all resolvers
        """
        self.unresolvedGraph = None
        self.resolvedGraph = None
        self.mergersPerTimestep = None
        self.detectionsPerTimestep = None

        self.pluginManager = TrackingPluginManager(verbose=verbose,
                                                   pluginPaths=pluginPaths)
        self.mergerResolverPlugin = self.pluginManager.getMergerResolver()
        self.numSplits = numSplits

        # should be filled by constructors of derived classes!
        self.model = None
        self.result = None

        if progressVisitor is None:
            progressVisitor = DefaultProgressVisitor()
        self.progressVisitor = progressVisitor

    @staticmethod
    def _timestepsSpanning(timestepKeys):
        """
        Return all timesteps from the smallest to the largest of `timestepKeys`
        (inclusive) as a list of strings, so that frames without any entry are
        covered as well. The keys are compared numerically: comparing the string
        keys directly would order '10' before '2'.
        """
        frames = [int(t) for t in timestepKeys]
        return [str(t) for t in range(min(frames), max(frames) + 1)]

    def _createUnresolvedGraph(self,
                               divisionsPerTimestep,
                               mergersPerTimestep,
                               mergerLinks,
                               withFullGraph=False):
        """
        Set up a networkx graph consisting of mergers that need to be resolved (not resolved yet!)
        and their direct neighbors. Nodes are `(int(timestep), id)` tuples.

        **Parameters**

        * `divisionsPerTimestep`: dict `str(timestep) -> ids`; may be `None` or empty,
          in which case no node is marked as division
        * `mergersPerTimestep`: dict `str(timestep) -> {id: count}`
        * `mergerLinks`: iterable of `(timestep, link)` where `link[0]` is the id in
          the previous timestep and `link[1]` the id in `timestep`
        * `withFullGraph`: start from a copy of `self.hypothesesGraph._graph` (which
          must have been set by a derived class) instead of only the merger links

        ** returns ** the `unresolvedGraph`
        """
        self.unresolvedGraph = nx.DiGraph()

        def source(timestep, link):
            return int(timestep) - 1, link[0]

        def target(timestep, link):
            return int(timestep), link[1]

        # Determine the last frame with division information once, and only if
        # that information exists at all.
        if divisionsPerTimestep:
            lastframe = int(max(divisionsPerTimestep.keys(), key=int))
        else:
            lastframe = None

        def isDivision(timestep, idx):
            # divisions are looked up in the *next* timestep's entry, hence there
            # can be none for the last frame
            if lastframe is None or int(timestep) >= lastframe:
                return False
            return idx in divisionsPerTimestep[str(timestep + 1)]

        if withFullGraph:
            # Recompute full graph
            self.unresolvedGraph = self.hypothesesGraph._graph.copy()

            # Add division parameter to nodes
            # TODO: Add the division parameter only to nodes that contain divisions
            # (we're already doing these with 'count')
            for node in self.unresolvedGraph.nodes_iter():
                timestep, idx = node
                self.unresolvedGraph.node[node]['division'] = isDivision(timestep, idx)

            # Add count parameter to nodes (only merger nodes get a 'count')
            for t, link in mergerLinks:
                for node in [source(t, link), target(t, link)]:
                    timestep, idx = node
                    if idx in mergersPerTimestep[str(timestep)]:
                        self.unresolvedGraph.node[node]['count'] = mergersPerTimestep[str(timestep)][idx]
        else:
            # Recompute graph only with merger nodes and neighbors
            def addNode(node):
                ''' add a node to the unresolved graph and fill in the properties `division` and `count` '''
                intT, idx = node
                division = isDivision(intT, idx)
                count = 1
                if idx in mergersPerTimestep[str(intT)]:
                    assert (not division)
                    count = mergersPerTimestep[str(intT)][idx]
                self.unresolvedGraph.add_node(node, division=division, count=count)

            # add nodes
            for t, link in mergerLinks:
                src, dest = source(t, link), target(t, link)
                for n in (src, dest):
                    if not self.unresolvedGraph.has_node(n):
                        addNode(n)
                self.unresolvedGraph.add_edge(src, dest)

        return self.unresolvedGraph

    def _prepareResolvedGraph(self):
        """
        Initialize `resolvedGraph` as a copy of `unresolvedGraph`.

        ** returns ** the `resolvedGraph`
        """
        self.resolvedGraph = self.unresolvedGraph.copy()
        return self.resolvedGraph

    def _readLabelImage(self, timeframe):
        '''
        Should return the labelimage for the given timeframe
        '''
        raise NotImplementedError()

    def _fitAndRefineNodes(self, detectionsPerTimestep, mergersPerTimestep,
                           timesteps):
        '''
        Update segmentation of mergers (nodes in unresolvedGraph) from first timeframe to last
        and create new nodes in `resolvedGraph`. Links to merger nodes are duplicated to all new nodes.

        Uses the mergerResolver plugin to update the segmentations in the labelImages.

        Every processed node of the `unresolvedGraph` gets a `fits` attribute, and every
        merger additionally a `newIds` attribute listing the ids of its replacements.
        '''
        for intT in sorted(int(t) for t in timesteps):
            t = str(intT)
            # use image provider plugin to load labelimage
            labelImage = self._readLabelImage(intT)
            nextObjectId = labelImage.max() + 1

            for idx in detectionsPerTimestep[t]:
                node = (intT, idx)
                if node not in self.resolvedGraph:
                    continue

                count = 1
                if idx in mergersPerTimestep[t]:
                    count = mergersPerTimestep[t][idx]
                getLogger().debug(
                    "Looking at node {} in timestep {} with count {}".format(
                        idx, t, count))

                # collect initializations from incoming
                initializations = []
                for predecessor, _ in self.unresolvedGraph.in_edges(node):
                    initializations.extend(
                        self.unresolvedGraph.node[predecessor]['fits'])
                # TODO: what shall we do if e.g. a 2-merger and a single object merge to 2 + 1,
                # so there are 3 initializations for the 2-merger, and two initializations for the 1 merger?
                # What does pgmlink do in that case?

                # use merger resolving plugin to fit `count` objects, also updates labelimage!
                fittedObjects = list(
                    self.mergerResolverPlugin.resolveMerger(
                        labelImage, idx, nextObjectId, count, initializations))
                assert (len(fittedObjects) == count)

                # split up node if count > 1, duplicate incoming and outgoing arcs
                if count > 1:
                    # stored as a real list so that it stays a plain container
                    # on Python 3 as well
                    newIds = list(range(nextObjectId, nextObjectId + count))
                    for newIdx in newIds:
                        newNode = (intT, newIdx)
                        self.resolvedGraph.add_node(newNode,
                                                    division=False,
                                                    count=1,
                                                    origin=node)

                        for e in self.unresolvedGraph.out_edges(node):
                            self.resolvedGraph.add_edge(newNode, e[1])
                        for e in self.unresolvedGraph.in_edges(node):
                            predecessorAttributes = self.unresolvedGraph.node[e[0]]
                            if 'newIds' in predecessorAttributes:
                                # the predecessor was a merger itself: link to its replacements
                                for newId in predecessorAttributes['newIds']:
                                    self.resolvedGraph.add_edge(
                                        (e[0][0], newId), newNode)
                            else:
                                self.resolvedGraph.add_edge(e[0], newNode)

                    self.resolvedGraph.remove_node(node)
                    self.unresolvedGraph.node[node]['newIds'] = newIds
                    nextObjectId += count

                # each unresolved node stores its fitted shape(s) to be used
                # as initialization in the next frame, this way division duplicates
                # and de-merged nodes in the resolved graph do not need to store a fit as well
                self.unresolvedGraph.node[node]['fits'] = fittedObjects

    def _minCostMaxFlowMergerResolving(self,
                                       objectFeatures,
                                       transitionClassifier=None,
                                       transitionParameter=5.0):
        """
        Find the optimal assignments within the `resolvedGraph` by running min-cost max-flow from the
        `dpct` module.

        Converts the `resolvedGraph` to our JSON model structure, predicts the transition probabilities
        either using the given transitionClassifier, or using distance-based probabilities
        (`exp(-distance / transitionParameter)` of the `RegionCenter` features).

        `objectFeatures` is indexed by the `(timestep, id)` nodes of the `resolvedGraph`.
        As a side effect every node of the `resolvedGraph` gets an `id` attribute holding
        its id in the temporary tracking model.

        **returns** a `nodeFlowMap` and `arcFlowMap` holding information on the usage of the respective nodes and links

        **Note:** cannot use `networkx` flow methods because they don't work with floating point weights.
        """
        trackingGraph = JsonTrackingGraph(progressVisitor=self.progressVisitor)

        for node in self.resolvedGraph.nodes_iter():
            additionalFeatures = {'nid': node}

            # nodes with no in/out
            numStates = 2

            if len(self.resolvedGraph.in_edges(node)) == 0:
                # division nodes with no incoming arcs offer 2 units of flow without the need to de-merge
                if (node in self.unresolvedGraph
                        and self.unresolvedGraph.node[node]['division']
                        and len(self.unresolvedGraph.out_edges(node)) == 2):
                    numStates = 3
                additionalFeatures['appearanceFeatures'] = [
                    [i**2 * 0.01] for i in range(numStates)
                ]
            if len(self.resolvedGraph.out_edges(node)) == 0:
                # division nodes with no incoming should have outgoing,
                # or they shouldn't show up in resolved graph
                assert (numStates == 2)
                additionalFeatures['disappearanceFeatures'] = [
                    [i**2 * 0.01] for i in range(numStates)
                ]

            features = [[i**2] for i in range(numStates)]
            uuid = trackingGraph.addDetectionHypotheses(
                features, **additionalFeatures)
            self.resolvedGraph.node[node]['id'] = uuid

        for edge in self.resolvedGraph.edges_iter():
            src = self.resolvedGraph.node[edge[0]]['id']
            dest = self.resolvedGraph.node[edge[1]]['id']

            featuresAtSrc = objectFeatures[edge[0]]
            featuresAtDest = objectFeatures[edge[1]]

            if transitionClassifier is not None:
                try:
                    featVec = self.pluginManager.applyTransitionFeatureVectorConstructionPlugins(
                        featuresAtSrc, featuresAtDest,
                        transitionClassifier.selectedFeatures)
                except Exception:
                    # log the offending features, then let the error propagate
                    getLogger().error(
                        "Could not compute transition features of link {}->{}:"
                        .format(src, dest))
                    getLogger().error(featuresAtSrc)
                    getLogger().error(featuresAtDest)
                    raise
                featVec = np.expand_dims(np.array(featVec), axis=0)
                probs = transitionClassifier.predictProbabilities(featVec)[0]
            else:
                dist = np.linalg.norm(featuresAtDest['RegionCenter'] -
                                      featuresAtSrc['RegionCenter'])
                prob = np.exp(-dist / transitionParameter)
                probs = [1.0 - prob, prob]

            trackingGraph.addLinkingHypotheses(src, dest,
                                               listify(negLog(probs)))

        # Set TraxelToUniqueId on resolvedGraph's json graph
        traxelIdPerTimestepToUniqueIdMap = {}
        for node in self.resolvedGraph.nodes_iter():
            uuid = self.resolvedGraph.node[node]['id']
            traxelIdPerTimestepToUniqueIdMap.setdefault(
                str(node[0]), {})[str(node[1])] = uuid
        trackingGraph.setTraxelToUniqueId(traxelIdPerTimestepToUniqueIdMap)

        # track
        import dpct

        weights = {"weights": [1, 1, 1, 1]}

        if not self.numSplits:
            mergerResult = dpct.trackMaxFlow(trackingGraph.model, weights)
        else:
            getLogger().info("Running split tracking with {} splits.".format(
                self.numSplits))
            mergerResult = SplitTracking.trackFlowBasedWithSplits(
                trackingGraph.model,
                weights,
                numSplits=self.numSplits,
                withMergerResolver=True)

        # transform results to dictionaries that can be indexed by id or (src,dest)
        nodeFlowMap = {int(d['id']): int(d['value'])
                       for d in mergerResult['detectionResults']}
        arcFlowMap = {(int(l['src']), int(l['dest'])): int(l['value'])
                      for l in mergerResult['linkingResults']}

        return nodeFlowMap, arcFlowMap

    def _refineModel(self, uuidToTraxelMap, traxelIdPerTimestepToUniqueIdMap,
                     mergerNodeFilter, mergerLinkFilter):
        """
        Take the `self.model` (JSON format) with mergers, remove the merger nodes, but add new
        de-merged nodes and links. Also updates `traxelIdPerTimestepToUniqueIdMap` locally and in the resulting file,
        such that the traxel IDs match the new connected component IDs in the refined images.

        `mergerNodeFilter` and `mergerLinkFilter` are methods that can filter merger detections
        and links from the respective lists in the `model` dict (they return `True` for
        entries that should be kept).

        **Returns** the updated `model` dictionary, which is the same as the input `model` (works in-place)
        """
        # remove merger detections
        self.model['segmentationHypotheses'] = [
            seg for seg in self.model['segmentationHypotheses']
            if mergerNodeFilter(seg)
        ]

        # remove merger links
        self.model['linkingHypotheses'] = [
            link for link in self.model['linkingHypotheses']
            if mergerLinkFilter(link)
        ]

        # insert new nodes and update UUID to traxel map
        nextUuid = max(uuidToTraxelMap.keys()) + 1
        for node in self.unresolvedGraph.nodes_iter():
            nodeAttributes = self.unresolvedGraph.node[node]
            if 'count' in nodeAttributes and nodeAttributes['count'] > 1:
                idsInFrame = traxelIdPerTimestepToUniqueIdMap[str(node[0])]
                # the merger itself disappears, its replacements get fresh UUIDs
                del idsInFrame[str(node[1])]
                for newId in nodeAttributes['newIds']:
                    newDetection = {
                        'id': nextUuid,
                        'timestep': [node[0], node[0]],
                    }
                    self.model['segmentationHypotheses'].append(newDetection)
                    idsInFrame[str(newId)] = nextUuid
                    nextUuid += 1

        # insert new links
        for edge in self.resolvedGraph.edges_iter():
            newLink = {
                'src': traxelIdPerTimestepToUniqueIdMap[str(edge[0][0])][str(edge[0][1])],
                'dest': traxelIdPerTimestepToUniqueIdMap[str(edge[1][0])][str(edge[1][1])],
            }
            self.model['linkingHypotheses'].append(newLink)

        # save
        return self.model

    def _refineResult(self, nodeFlowMap, arcFlowMap,
                      traxelIdPerTimestepToUniqueIdMap, mergerNodeFilter,
                      mergerLinkFilter):
        """
        Update the `self.result` dict by removing the mergers and adding the refined nodes and links.

        Operates on a `result` dictionary in our JSON result style with mergers,
        the resolved and unresolved graph as well as
        the `nodeFlowMap` and `arcFlowMap` obtained by running tracking on the `resolvedGraph`.

        Updates the `result` dictionary so that all merger nodes are removed but the new nodes
        are contained with the appropriate links and values.

        `mergerNodeFilter` and `mergerLinkFilter` are methods that can filter merger detections
        and links from the respective lists in the `result` dict.

        **Returns** the updated `result` dict, which is the same as the input `result` (works in-place)
        """
        # filter merger edges
        self.result['detectionResults'] = [
            r for r in self.result['detectionResults'] if mergerNodeFilter(r)
        ]
        self.result['linkingResults'] = [
            r for r in self.result['linkingResults'] if mergerLinkFilter(r)
        ]

        # add new nodes
        for node in self.unresolvedGraph.nodes_iter():
            nodeAttributes = self.unresolvedGraph.node[node]
            if 'count' in nodeAttributes and nodeAttributes['count'] > 1:
                for newId in nodeAttributes['newIds']:
                    uuid = traxelIdPerTimestepToUniqueIdMap[str(node[0])][str(newId)]
                    resolvedResultId = self.resolvedGraph.node[(node[0], newId)]['id']
                    newDetection = {
                        'id': uuid,
                        'value': nodeFlowMap[resolvedResultId]
                    }
                    self.result['detectionResults'].append(newDetection)

        # add new links
        for edge in self.resolvedGraph.edges_iter():
            srcId = self.resolvedGraph.node[edge[0]]['id']
            destId = self.resolvedGraph.node[edge[1]]['id']
            newLink = {
                'src': traxelIdPerTimestepToUniqueIdMap[str(edge[0][0])][str(edge[0][1])],
                'dest': traxelIdPerTimestepToUniqueIdMap[str(edge[1][0])][str(edge[1][1])],
                'value': arcFlowMap[(srcId, destId)],
            }
            self.result['linkingResults'].append(newLink)

        return self.result

    def _exportRefinedSegmentation(self, timesteps):
        """
        Store the resulting label images, if needed.

        `labelImages` is a dictionary with str(timestep) as keys.
        """
        pass

    def _computeObjectFeatures(self, timesteps):
        '''
        Return the features per object as nested dictionaries:
        { (int(Timestep), int(Id)):{ "FeatureName" : np.array(value), "NextFeature": ...} }
        '''
        pass

    # ------------------------------------------------------------
    def run(self,
            transition_classifier_filename=None,
            transition_classifier_path=None):
        """
        Run merger resolving

        1. find mergers in the given model and result
        2. build graph of the unresolved (merger) nodes and their direct neighbors
        3. use a mergerResolving plugin to refine the merger nodes and their segmentation
        4. run min-cost max-flow tracking to find the fate of all the de-merged objects
        5. export refined segmentation, update member variables `model` and `result`

        **Returns** a nested dictionary, indexed first by time, then object Id, containing a list of new segmentIDs per merger.
        If the result contains no mergers at all, only the segmentation is exported and `None` is returned.
        """
        traxelIdPerTimestepToUniqueIdMap, uuidToTraxelMap = hytra.core.jsongraph.getMappingsBetweenUUIDsAndTraxels(
            self.model)
        # there might be empty frames. We want them as output too.
        timesteps = self._timestepsSpanning(
            traxelIdPerTimestepToUniqueIdMap.keys())

        mergers, detections, links, divisions = hytra.core.jsongraph.getMergersDetectionsLinksDivisions(
            self.result, uuidToTraxelMap)

        # ------------------------------------------------------------
        # it may be, that there are no mergers, so do basically nothing, just copy all the ingoing data
        if len(mergers) == 0:
            getLogger().info(
                "The maximum number of objects is 1, so nothing to be done. Writing the output..."
            )
            self._exportRefinedSegmentation(timesteps)
            # no graphs were built, so there is nothing to report
            return None

        self.mergersPerTimestep = hytra.core.jsongraph.getMergersPerTimestep(
            mergers, timesteps)
        self.detectionsPerTimestep = hytra.core.jsongraph.getDetectionsPerTimestep(
            detections, timesteps)

        linksPerTimestep = hytra.core.jsongraph.getLinksPerTimestep(
            links, timesteps)
        divisionsPerTimestep = hytra.core.jsongraph.getDivisionsPerTimestep(
            divisions, linksPerTimestep, timesteps)
        mergerLinks = hytra.core.jsongraph.getMergerLinks(
            linksPerTimestep, self.mergersPerTimestep, timesteps)

        # set up unresolved graph and then refine the nodes to get the resolved graph
        self._createUnresolvedGraph(divisionsPerTimestep,
                                    self.mergersPerTimestep, mergerLinks)
        self._prepareResolvedGraph()
        self._fitAndRefineNodes(self.detectionsPerTimestep,
                                self.mergersPerTimestep, timesteps)

        # ------------------------------------------------------------
        # compute new object features
        objectFeatures = self._computeObjectFeatures(timesteps)

        # ------------------------------------------------------------
        # load transition classifier if any
        if transition_classifier_filename is not None:
            getLogger().info("\tLoading transition classifier")
            transitionClassifier = probabilitygenerator.RandomForestClassifier(
                transition_classifier_path, transition_classifier_filename)
        else:
            getLogger().info("\tUsing distance based transition energies")
            transitionClassifier = None

        # ------------------------------------------------------------
        # run min-cost max-flow to find merger assignments
        getLogger().info(
            "Running min-cost max-flow to find resolved merger assignments"
        )

        nodeFlowMap, arcFlowMap = self._minCostMaxFlowMergerResolving(
            objectFeatures, transitionClassifier)

        # ------------------------------------------------------------
        # fuse results into a new solution

        # 1.) replace merger nodes in JSON graph by their replacements -> new JSON graph
        #     update UUID to traxel map.
        #     a) how do we deal with the smaller number of states?
        #        Does it matter as we're done with tracking anyway..?

        def containsMerger(traxels):
            return any(t[1] in self.mergersPerTimestep[str(t[0])]
                       for t in traxels)

        def mergerNodeFilter(jsonNode):
            # keep only detections that do not contain a merger traxel
            return not containsMerger(uuidToTraxelMap[int(jsonNode['id'])])

        def mergerLinkFilter(jsonLink):
            # return True if there was no traxel in either source or target node that was a merger.
            srcTraxels = uuidToTraxelMap[int(jsonLink['src'])]
            destTraxels = uuidToTraxelMap[int(jsonLink['dest'])]
            return not (containsMerger(srcTraxels)
                        or containsMerger(destTraxels))

        self.model = self._refineModel(uuidToTraxelMap,
                                       traxelIdPerTimestepToUniqueIdMap,
                                       mergerNodeFilter, mergerLinkFilter)

        # 2.) new result = union(old result, resolved mergers) - old mergers
        self.result = self._refineResult(nodeFlowMap, arcFlowMap,
                                         traxelIdPerTimestepToUniqueIdMap,
                                         mergerNodeFilter,
                                         mergerLinkFilter)

        # 3.) export refined segmentation
        self._exportRefinedSegmentation(timesteps)

        # return a dictionary telling about which mergers were resolved into what
        mergerDict = {}
        for n in self.unresolvedGraph.nodes_iter():
            newIds = self.unresolvedGraph.node[n].get('newIds')
            # skip non-mergers
            if newIds is None or len(newIds) < 2:
                continue
            mergerDict.setdefault(n[0], {})[n[1]] = newIds

        return mergerDict

    def relabelMergers(self, labelImage, time):
        """
        Calls the merger resolving plugin to relabel the mergers based on a previously found fit,
        which is stored in the hypotheses graph node.

        Does nothing if `run` did not find detections for `time`.

        **Returns** the `labelImage` that was passed in.
        """
        t = str(time)
        if self.detectionsPerTimestep is None or t not in self.detectionsPerTimestep:
            return labelImage

        for idx in self.detectionsPerTimestep[t]:
            if idx not in self.mergersPerTimestep[t]:
                continue

            # use fits stored in graph
            nodeAttributes = self.unresolvedGraph.node[(time, idx)]
            fits = nodeAttributes['fits']
            newIds = nodeAttributes['newIds']

            # use merger resolving plugin to update labelImage with merger IDs
            self.mergerResolverPlugin.updateLabelImage(
                labelImage, idx, fits, newIds)

        return labelImage
def writeEvents(timestep, activeLinks, activeDivisions, mergers, detections, fn, labelImagePath, ilpFilename, verbose, pluginPaths):
    """
    Write the tracking events of one timestep, together with its label image,
    into the HDF5 file `fn`.

    **Parameters**

    * `timestep`: frame whose label image is loaded and stored
    * `activeLinks`: sequence of moves, stored as dataset `tracking/Moves`
    * `activeDivisions`: dict `parent -> (child, child)`, stored as `tracking/Splits`
    * `mergers`: dict `object id -> number of objects`, stored as `tracking/Mergers`
    * `detections`: not used by this function
    * `fn`: name of the HDF5 file to create (an existing file is overwritten)
    * `labelImagePath`, `ilpFilename`: passed to the image provider to load the label image
    * `verbose`, `pluginPaths`: forwarded to the `TrackingPluginManager`

    Datasets are only created for non-empty event lists. Appearances,
    disappearances and multi frame moves are always empty here, so those
    datasets are never written.

    Writing is best effort: any exception raised while converting or writing is
    logged as a warning and not propagated.
    """
    logger = logging.getLogger('json_result_to_events.py')
    pluginManager = TrackingPluginManager(verbose=verbose, pluginPaths=pluginPaths)
    logger.debug("-- Writing results to {}".format(fn))

    try:
        # convert to ndarray for better indexing
        dis = np.asarray([])
        app = np.asarray([])
        div = np.asarray([[k, v[0], v[1]] for k, v in activeDivisions.items()])
        mov = np.asarray(activeLinks)
        mer = np.asarray([[k, v] for k, v in mergers.items()])
        mul = np.asarray([])

        imageProvider = pluginManager.getImageProvider()
        # NOTE(review): the returned shape is not used; the call is kept in case the
        # provider relies on it (e.g. to validate the dataset) - confirm and drop.
        imageProvider.getImageShape(ilpFilename, labelImagePath)
        label_img = imageProvider.getLabelImageForFrame(ilpFilename, labelImagePath, timestep)

        # (dataset name, data, description stored in the "Format" attribute)
        eventTables = [
            ("Appearances", app, "cell label appeared in current file"),
            ("Disappearances", dis, "cell label disappeared in current file"),
            ("Moves", mov, "from (previous file), to (current file)"),
            ("Splits", div, "ancestor (previous file), descendant (current file), descendant (current file)"),
            ("Mergers", mer, "descendant (current file), number of objects"),
            ("MultiFrameMoves", mul, "from (given by timestep), to (current file), timestep"),
        ]

        with h5py.File(fn, 'w') as dest_file:
            # write meta fields and copy segmentation from project
            seg = dest_file.create_group('segmentation')
            seg.create_dataset("labels", data=label_img, compression='gzip')

            meta = dest_file.create_group('objects/meta')
            ids = np.unique(label_img)
            ids = ids[ids > 0]  # label 0 is excluded from the object list
            valid = np.ones(ids.shape)
            meta.create_dataset("id", data=ids, dtype=np.uint32)
            meta.create_dataset("valid", data=valid, dtype=np.uint32)

            tg = dest_file.create_group("tracking")

            # write associations
            for datasetName, data, formatDescription in eventTables:
                if len(data) > 0:
                    ds = tg.create_dataset(datasetName, data=data, dtype=np.int32)
                    ds.attrs["Format"] = formatDescription

        logger.debug("-> results successfully written")
    except Exception as e:
        logger.warning("ERROR while writing events: {}".format(str(e)))
def computeRegionFeaturesOnCloud(
        frame,
        rawImageFilename,
        rawImagePath,
        rawImageAxes,
        labelImageFilename,
        labelImagePath,
        turnOffFeatures,
        pluginPaths=['hytra/plugins'],
        featuresPerFrame=None,
        imageProviderPluginName='LocalImageLoader',
        featureSerializerPluginName='LocalFeatureSerializer'):
    '''
    Allow to use dispy to schedule feature computation to nodes running a dispynode,
    or to use multiprocessing.

    **Parameters**

    * `frame`: the frame number
    * `rawImageFilename`: the base filename of the raw image volume, or a dvid server address
    * `rawImagePath`: path inside the raw image HDF5 file, or DVID dataset UUID
    * `rawImageAxes`: axes configuration of the raw data
    * `labelImageFilename`: the base filename of the label image volume, or a dvid server address
    * `labelImagePath`: path inside the label image HDF5 file, or DVID dataset UUID
    * `turnOffFeatures`: forwarded to the `TrackingPluginManager`
    * `pluginPaths`: where all yapsy plugins are stored (should be absolute for DVID)
    * `featuresPerFrame`: feature dictionary handed to the feature serializer when
      the features are stored instead of returned
    * `imageProviderPluginName`: name of the image provider plugin to activate
    * `featureSerializerPluginName`: name of the feature serializer plugin to activate

    **returns** the tuple `(frame, featureDict)` if
    `featureSerializerPluginName == 'LocalFeatureSerializer'` and `featuresPerFrame == None`.
    Otherwise the features are stored through the feature serializer and `None` is returned.
    '''
    # set up plugin manager
    from hytra.pluginsystem.plugin_manager import TrackingPluginManager
    pluginManager = TrackingPluginManager(pluginPaths=pluginPaths,
                                          turnOffFeatures=turnOffFeatures,
                                          verbose=False)
    pluginManager.setImageProvider(imageProviderPluginName)
    pluginManager.setFeatureSerializer(featureSerializerPluginName)

    # load raw and label image (depending on chosen plugin this works via DVID or locally)
    imageProvider = pluginManager.getImageProvider()
    rawImage = imageProvider.getImageDataAtTimeFrame(
        rawImageFilename, rawImagePath, rawImageAxes, frame)
    labelImage = imageProvider.getLabelImageForFrame(
        labelImageFilename, labelImagePath, frame)

    # untwist axes, if just x and y are messed up
    if (rawImage.shape[0] == labelImage.shape[1]
            and rawImage.shape[1] == labelImage.shape[0]):
        labelImage = np.transpose(labelImage, axes=[1, 0])

    # compute features
    moreFeats, ignoreNames = pluginManager.applyObjectFeatureComputationPlugins(
        len(labelImage.shape), rawImage, labelImage, frame, rawImageFilename)

    # combine into one dictionary
    # WARNING: if there are multiple features with the same name, they will be overwritten!
    # (the dictionary that comes later in `moreFeats` wins)
    frameFeatures = {}
    for featureDict in moreFeats:
        frameFeatures.update(featureDict)

    # delete all ignored features
    for k in ignoreNames:
        frameFeatures.pop(k, None)

    # return or save features
    # The plugin name must be compared by value: when it was transferred to a
    # worker process it is an equal, but not necessarily identical, string object.
    if featuresPerFrame is None and featureSerializerPluginName == 'LocalFeatureSerializer':
        # simply return resulting dict
        return frame, frameFeatures

    # set up feature serializer (local or DVID for now)
    featureSerializer = pluginManager.getFeatureSerializer()

    # server address and uuid are only used by the DVID serializer
    featureSerializer.server_address = labelImageFilename
    featureSerializer.uuid = labelImagePath

    # feature dictionary used by local serializer
    featureSerializer.features_per_frame = featuresPerFrame

    # store
    featureSerializer.storeFeaturesForFrame(frameFeatures, frame)