def computeJaccardScoresOnCloud(frame,
                                labelImageFilenames,
                                labelImagePaths,
                                labelImageFrameIdToGlobalId,
                                groundTruthFilename,
                                groundTruthPath,
                                groundTruthMinJaccardScore,
                                pluginPaths=['hytra/plugins'],
                                imageProviderPluginName='LocalImageLoader'):
    """
    Compute jaccard scores of all objects in the different segmentations with the ground truth for that frame.

    For every object of every label image in `frame`, the Jaccard index |A n G| / |A u G| is computed
    against every ground truth label G it shares at least one pixel with (label 0 is skipped on both sides).

    :param frame: timestep to evaluate
    :param labelImageFilenames: list of files containing the segmentation hypotheses
    :param labelImagePaths: internal paths, element-wise matching `labelImageFilenames`
    :param labelImageFrameIdToGlobalId: dict mapping `(filename, frame, objectId)` to a globalId
    :param groundTruthFilename: file containing the ground truth label image
    :param groundTruthPath: internal path of the ground truth label image
    :param groundTruthMinJaccardScore: only scores strictly above this threshold enter the GT mapping
    :param pluginPaths: plugin search paths for the TrackingPluginManager
    :param imageProviderPluginName: name of the image provider plugin used to load all label images

    :returns: tuple `(frame, scores, gtToGlobalIdMap)`:
        - `scores`: globalId -> list of `(gtLabel, jaccardScore)` for every overlapping GT label
        - `gtToGlobalIdMap`: `(frame, gtLabel)` -> list of `(globalId, jaccardScore)`,
          ordered by ascending score (best match last)

    Meant to be run in its own process using `concurrent.futures.ProcessPoolExecutor`
    """

    # set up plugin manager
    from hytra.pluginsystem.plugin_manager import TrackingPluginManager
    pluginManager = TrackingPluginManager(pluginPaths=pluginPaths, verbose=False)
    pluginManager.setImageProvider(imageProviderPluginName)
    imageProvider = pluginManager.getImageProvider()

    scores = {}
    gtToGlobalIdMap = {}

    groundTruthLabelImage = imageProvider.getLabelImageForFrame(groundTruthFilename, groundTruthPath, frame)

    # Pixel count of every GT label, computed once. The union of object A and GT label G is then
    # |A| + |G| - |A n G|, instead of a full-image logical_or for every (object, GT label) pair.
    gtLabels, gtLabelSizes = np.unique(groundTruthLabelImage, return_counts=True)
    gtLabelSize = dict(zip(gtLabels, gtLabelSizes))

    for labelImageFilename, labelImagePath in zip(labelImageFilenames, labelImagePaths):
        labelImage = imageProvider.getLabelImageForFrame(labelImageFilename, labelImagePath, frame)

        for objectId in np.unique(labelImage):
            if objectId == 0:
                continue  # label 0 is treated as background
            globalId = labelImageFrameIdToGlobalId[(labelImageFilename, frame, objectId)]

            # GT labels underneath this object's pixels, with the number of shared pixels for each
            overlap = groundTruthLabelImage[labelImage == objectId]
            overlapLabels, intersectionSizes = np.unique(overlap, return_counts=True)

            for gtLabel, intersectingPixels in zip(overlapLabels, intersectionSizes):
                if gtLabel == 0:
                    continue  # GT background is not a match
                unionPixels = overlap.size + gtLabelSize[gtLabel] - intersectingPixels
                jaccardScore = float(intersectingPixels) / float(unionPixels)

                # append to object's score list
                scores.setdefault(globalId, []).append((gtLabel, jaccardScore))

                # store this as GT mapping if there was no better object for this GT label yet
                if jaccardScore > groundTruthMinJaccardScore:
                    matches = gtToGlobalIdMap.setdefault((frame, gtLabel), [])
                    if not matches or matches[-1][1] < jaccardScore:
                        matches.append((globalId, jaccardScore))

    # sort all gt mappings by ascending jaccard score (best match last);
    # values() works on Python 2 and 3, iteritems() only on Python 2
    for matches in gtToGlobalIdMap.values():
        matches.sort(key=lambda match: match[1])

    return frame, scores, gtToGlobalIdMap
Example #2
0
    def __init__(self,
                 pluginPaths=[os.path.abspath('../hytra/plugins')],
                 verbose=False):
        """
        Initialise the shared merger-resolving state.

        The graph and per-timestep containers start out as None. A TrackingPluginManager searching
        `pluginPaths` supplies the merger resolver plugin.

        Note: the default `pluginPaths` is evaluated once, when this `def` is executed, relative to
        the working directory at that time.
        """
        self.unresolvedGraph = self.resolvedGraph = None
        self.mergersPerTimestep = self.detectionsPerTimestep = None

        manager = TrackingPluginManager(verbose=verbose, pluginPaths=pluginPaths)
        self.pluginManager = manager
        self.mergerResolverPlugin = manager.getMergerResolver()

        # should be filled by constructors of derived classes!
        self.model = None
        self.result = None
    def __init__(self, parent=None, graph=None):
        """
        Wire up the operator's internal caches, helpers and plugins.

        Creates three OpBlockedArrayCache instances (fed by `Output`, `MergerOutput` and
        `RelabeledImage`) and a zero-default provider fed by `LabelImage`. It registers
        `_checkConstraints` to run when `RawImage` or `LabelImage` become ready, and loads the
        merger resolver plugin from the `plugins` directory next to the installed `hytra` package.
        """
        super(OpConservationTracking, self).__init__(parent=parent,
                                                     graph=graph)

        # cache for the tracking output
        self._opCache = OpBlockedArrayCache(parent=self)
        self._opCache.name = "OpConservationTracking._opCache"
        self._opCache.Input.connect(self.Output)
        self.CleanBlocks.connect(self._opCache.CleanBlocks)
        self.CachedOutput.connect(self._opCache.Output)

        self.zeroProvider = OpZeroDefault(parent=self)
        self.zeroProvider.MetaInput.connect(self.LabelImage)

        # As soon as input data is available, check its constraints
        self.RawImage.notifyReady(self._checkConstraints)
        self.LabelImage.notifyReady(self._checkConstraints)

        self.ExportSettings.setValue((None, None))

        # cache for the merger output
        self._mergerOpCache = OpBlockedArrayCache(parent=self)
        self._mergerOpCache.name = "OpConservationTracking._mergerOpCache"
        self._mergerOpCache.Input.connect(self.MergerOutput)
        self.MergerCleanBlocks.connect(self._mergerOpCache.CleanBlocks)
        self.MergerCachedOutput.connect(self._mergerOpCache.Output)

        # cache for the relabeled image; it gets its own name so it cannot be
        # confused with the merger cache above
        self._relabeledOpCache = OpBlockedArrayCache(parent=self)
        self._relabeledOpCache.name = "OpConservationTracking._relabeledOpCache"
        self._relabeledOpCache.Input.connect(self.RelabeledImage)
        self.RelabeledCleanBlocks.connect(self._relabeledOpCache.CleanBlocks)
        self.RelabeledCachedOutput.connect(self._relabeledOpCache.Output)

        # Merger resolver plugin manager (contains GMM fit routine)
        self.pluginPaths = [
            os.path.join(os.path.dirname(os.path.abspath(hytra.__file__)),
                         'plugins')
        ]
        pluginManager = TrackingPluginManager(verbose=False,
                                              pluginPaths=self.pluginPaths)
        self.mergerResolverPlugin = pluginManager.getMergerResolver()

        self.result = None

        # progress bar
        self.progressWindow = None
        self.progressVisitor = DefaultProgressVisitor()
    def __init__(self,
                 ilpOptions,
                 turnOffFeatures=[],
                 useMultiprocessing=True,
                 pluginPaths=['hytra/plugins'],
                 verbose=False):
        """
        Configure plugins from `ilpOptions`, load the classifiers and read the data shape and time range.

        :param ilpOptions: options object; `imageProviderName` and `featureSerializerName` select the
            image provider and feature serializer plugins
        :param turnOffFeatures: handed to the TrackingPluginManager
        :param useMultiprocessing: stored as `self._useMultiprocessing`
        :param pluginPaths: plugin search paths for the TrackingPluginManager
        :param verbose: verbosity flag for the TrackingPluginManager
        """
        self._options = ilpOptions
        self._useMultiprocessing = useMultiprocessing
        self._pluginPaths = pluginPaths

        self._pluginManager = manager = TrackingPluginManager(
            turnOffFeatures=turnOffFeatures,
            verbose=verbose,
            pluginPaths=pluginPaths)
        manager.setImageProvider(ilpOptions.imageProviderName)
        manager.setFeatureSerializer(ilpOptions.featureSerializerName)

        # classifiers start empty; presumably populated by _loadClassifiers()
        self._countClassifier = self._divisionClassifier = self._transitionClassifier = None
        self._loadClassifiers()

        self.shape, self.timeRange = self._getShapeAndTimeRange()

        # default feature names used for divisions
        self._divisionFeatureNames = [
            'ParentChildrenRatio_Count',
            'ParentChildrenRatio_Mean',
            'ChildrenRatio_Count',
            'ChildrenRatio_Mean',
            'ParentChildrenAngle_RegionCenter',
            'ChildrenRatio_SquaredDistances',
        ]

        # other parameters that one might want to set
        self.x_scale = self.y_scale = self.z_scale = 1.0
        self.divisionProbabilityFeatureName = 'divProb'
        self.detectionProbabilityFeatureName = 'detProb'

        # public: contains all traxels if we're not using pgmlink
        self.TraxelsPerFrame = {}
    def __init__(self, pluginPaths=[os.path.abspath('../hytra/plugins')], verbose=False):
        """
        Set up the common state of merger resolvers.

        :param pluginPaths: plugin search paths (the default is resolved against the working directory
            at the time this `def` executes)
        :param verbose: verbosity flag passed to the TrackingPluginManager
        """
        # graphs and per-timestep bookkeeping start empty
        self.unresolvedGraph = None
        self.resolvedGraph = None
        self.mergersPerTimestep = None
        self.detectionsPerTimestep = None

        self.pluginManager = TrackingPluginManager(verbose=verbose,
                                                   pluginPaths=pluginPaths)
        resolver = self.pluginManager.getMergerResolver()
        self.mergerResolverPlugin = resolver

        # should be filled by constructors of derived classes!
        self.model = self.result = None
    def __init__(self, parent=None, graph=None):
        """
        Wire up the operator's internal caches, helpers and plugins.

        Creates three OpBlockedArrayCache instances (fed by `Output`, `MergerOutput` and
        `RelabeledImage`) and a zero-default provider fed by `LabelImage`. It registers
        `_checkConstraints` to run when `RawImage` or `LabelImage` become ready, and loads the
        merger resolver plugin from the `plugins` directory next to the installed `hytra` package.
        """
        super(OpConservationTracking, self).__init__(parent=parent, graph=graph)

        # cache for the tracking output
        self._opCache = OpBlockedArrayCache(parent=self)
        self._opCache.name = "OpConservationTracking._opCache"
        self._opCache.Input.connect(self.Output)
        self.CleanBlocks.connect(self._opCache.CleanBlocks)
        self.CachedOutput.connect(self._opCache.Output)

        self.zeroProvider = OpZeroDefault(parent=self)
        self.zeroProvider.MetaInput.connect(self.LabelImage)

        # As soon as input data is available, check its constraints
        self.RawImage.notifyReady(self._checkConstraints)
        self.LabelImage.notifyReady(self._checkConstraints)

        self.ExportSettings.setValue( (None, None) )

        # cache for the merger output
        self._mergerOpCache = OpBlockedArrayCache(parent=self)
        self._mergerOpCache.name = "OpConservationTracking._mergerOpCache"
        self._mergerOpCache.Input.connect(self.MergerOutput)
        self.MergerCleanBlocks.connect(self._mergerOpCache.CleanBlocks)
        self.MergerCachedOutput.connect(self._mergerOpCache.Output)

        # cache for the relabeled image; it gets its own name so it cannot be
        # confused with the merger cache above
        self._relabeledOpCache = OpBlockedArrayCache(parent=self)
        self._relabeledOpCache.name = "OpConservationTracking._relabeledOpCache"
        self._relabeledOpCache.Input.connect(self.RelabeledImage)
        self.RelabeledCleanBlocks.connect(self._relabeledOpCache.CleanBlocks)
        self.RelabeledCachedOutput.connect(self._relabeledOpCache.Output)

        # Merger resolver plugin manager (contains GMM fit routine)
        self.pluginPaths = [os.path.join(os.path.dirname(os.path.abspath(hytra.__file__)), 'plugins')]
        pluginManager = TrackingPluginManager(verbose=False, pluginPaths=self.pluginPaths)
        self.mergerResolverPlugin = pluginManager.getMergerResolver()

        self.result = None

        # progress bar
        self.progressWindow = None
        self.progressVisitor=DefaultProgressVisitor()
def findConflictingHypothesesInSeparateProcess(frame,
                                               labelImageFilenames,
                                               labelImagePaths,
                                               labelImageFrameIdToGlobalId,
                                               pluginPaths=['hytra/plugins'],
                                               imageProviderPluginName='LocalImageLoader'):
    """
    Look which objects between different segmentation hypotheses (given as different labelImages)
    overlap, and return a dictionary of those overlapping situations.

    Every pair of distinct label images (A, B) of `frame` is compared. Two objects conflict as soon as
    they share a single pixel. Label 0 is ignored.

    :param frame: timestep to evaluate
    :param labelImageFilenames: list of files containing the segmentation hypotheses
    :param labelImagePaths: internal paths, element-wise matching `labelImageFilenames`
    :param labelImageFrameIdToGlobalId: dict mapping `(filename, frame, objectId)` to a globalId
    :param pluginPaths: plugin search paths for the TrackingPluginManager
    :param imageProviderPluginName: name of the image provider plugin used to load the label images

    :returns: tuple `(frame, overlaps)` where `overlaps` maps a globalId to the list of globalIds it
        overlaps with. Conflicts are recorded in both directions.

    Meant to be run in its own process using `concurrent.futures.ProcessPoolExecutor`
    """

    # set up plugin manager
    from hytra.pluginsystem.plugin_manager import TrackingPluginManager
    pluginManager = TrackingPluginManager(pluginPaths=pluginPaths, verbose=False)
    pluginManager.setImageProvider(imageProviderPluginName)
    imageProvider = pluginManager.getImageProvider()

    overlaps = {} # overlap dict: key=globalId, value=[list of globalIds]
    numLabelImages = len(labelImageFilenames)

    # the last label image has no later image to be compared with, so it never needs loading as "A"
    for labelImageIndexA in range(numLabelImages - 1):
        filenameA = labelImageFilenames[labelImageIndexA]
        labelImageA = imageProvider.getLabelImageForFrame(filenameA, labelImagePaths[labelImageIndexA], frame)
        # the objects of A (and their global ids) do not depend on B, so look them up only once
        objectsA = [(objectIdA, labelImageFrameIdToGlobalId[(filenameA, frame, objectIdA)])
                    for objectIdA in np.unique(labelImageA) if objectIdA != 0]

        for labelImageIndexB in range(labelImageIndexA + 1, numLabelImages):
            filenameB = labelImageFilenames[labelImageIndexB]
            labelImageB = imageProvider.getLabelImageForFrame(filenameB, labelImagePaths[labelImageIndexB], frame)
            # check for overlaps - even a 1-pixel overlap is enough to be mutually exclusive!
            for objectIdA, globalIdA in objectsA:
                overlapping = set(np.unique(labelImageB[labelImageA == objectIdA])) - set([0])
                overlappingGlobalIds = [labelImageFrameIdToGlobalId[(filenameB, frame, o)] for o in overlapping]
                overlaps.setdefault(globalIdA, []).extend(overlappingGlobalIds)
                for globalIdB in overlappingGlobalIds:
                    overlaps.setdefault(globalIdB, []).append(globalIdA)

    return frame, overlaps
def findConflictingHypothesesInSeparateProcess(
        frame,
        labelImageFilenames,
        labelImagePaths,
        labelImageFrameIdToGlobalId,
        pluginPaths=['hytra/plugins'],
        imageProviderPluginName='LocalImageLoader'):
    """
    Look which objects between different segmentation hypotheses (given as different labelImages)
    overlap, and return a dictionary of those overlapping situations.

    Every pair of distinct label images (A, B) of `frame` is compared. Two objects conflict as soon
    as they share a single pixel. Label 0 is ignored.

    :param frame: timestep to evaluate
    :param labelImageFilenames: list of files containing the segmentation hypotheses
    :param labelImagePaths: internal paths, element-wise matching `labelImageFilenames`
    :param labelImageFrameIdToGlobalId: dict mapping `(filename, frame, objectId)` to a globalId
    :param pluginPaths: plugin search paths for the TrackingPluginManager
    :param imageProviderPluginName: name of the image provider plugin used to load the label images

    :returns: tuple `(frame, overlaps)` where `overlaps` maps a globalId to the list of globalIds it
        overlaps with. Conflicts are recorded in both directions.

    Meant to be run in its own process using `concurrent.futures.ProcessPoolExecutor`
    """

    # set up plugin manager
    from hytra.pluginsystem.plugin_manager import TrackingPluginManager
    pluginManager = TrackingPluginManager(pluginPaths=pluginPaths,
                                          verbose=False)
    pluginManager.setImageProvider(imageProviderPluginName)
    imageProvider = pluginManager.getImageProvider()

    overlaps = {}  # overlap dict: key=globalId, value=[list of globalIds]
    numLabelImages = len(labelImageFilenames)

    # the last label image has no later image to be compared with, so it never needs loading as "A"
    for labelImageIndexA in range(numLabelImages - 1):
        filenameA = labelImageFilenames[labelImageIndexA]
        labelImageA = imageProvider.getLabelImageForFrame(
            filenameA, labelImagePaths[labelImageIndexA], frame)
        # the objects of A (and their global ids) do not depend on B, so look them up only once
        objectsA = [
            (objectIdA,
             labelImageFrameIdToGlobalId[(filenameA, frame, objectIdA)])
            for objectIdA in np.unique(labelImageA) if objectIdA != 0
        ]

        for labelImageIndexB in range(labelImageIndexA + 1, numLabelImages):
            filenameB = labelImageFilenames[labelImageIndexB]
            labelImageB = imageProvider.getLabelImageForFrame(
                filenameB, labelImagePaths[labelImageIndexB], frame)
            # check for overlaps - even a 1-pixel overlap is enough to be mutually exclusive!
            for objectIdA, globalIdA in objectsA:
                overlapping = set(
                    np.unique(labelImageB[labelImageA == objectIdA])) - set(
                        [0])
                overlappingGlobalIds = [
                    labelImageFrameIdToGlobalId[(filenameB, frame, o)]
                    for o in overlapping
                ]
                overlaps.setdefault(globalIdA, []).extend(overlappingGlobalIds)
                for globalIdB in overlappingGlobalIds:
                    overlaps.setdefault(globalIdB, []).append(globalIdA)

    return frame, overlaps
    # Do the tracking
    start = time.time()

    feature_path = options.feats_path
    with_div = not bool(options.without_divisions)
    with_merger_prior = True

    # get selected time range
    time_range = [options.mints, options.maxts]
    if options.maxts == -1 and options.mints == 0:
        time_range = None

    try:
        # find shape of dataset
        pluginManager = TrackingPluginManager(verbose=options.verbose,
                                              pluginPaths=options.pluginPaths)
        shape = pluginManager.getImageProvider().getImageShape(
            ilp_fn, options.label_img_path)
        data_time_range = pluginManager.getImageProvider().getTimeRange(
            ilp_fn, options.label_img_path)

        if time_range is not None and time_range[1] < 0:
            time_range[1] += data_time_range[1]

    except:
        logging.warning("Could not read shape and time range from images")
        shape = None

    # set average object size if chosen
    obj_size = [0]
    if options.avg_obj_size != 0:
        rawimage = hytra.util.axesconversion.adjustOrder(rawimage, args.rawimage_axes[dataset], 'xyztc')
        logger.info('Done loading raw data from dataset {} of shape {}'.format(dataset, rawimage.shape))

        # find ground truth files
        # filepath is now a list of filepaths'
        filepath = args.filepath[dataset]
        # filepattern is now a list of filepatterns
        files = glob.glob(os.path.join(filepath, args.filepattern[dataset]))
        files.sort()
        initFrame = args.initFrame
        endFrame = args.endFrame
        if endFrame < 0:
            endFrame += len(files)
    
        # compute features
        trackingPluginManager = TrackingPluginManager(verbose=args.verbose, 
                                                      pluginPaths=args.pluginPaths)
        features = compute_features(rawimage,
                                    read_in_images(initFrame, endFrame, files, args.groundtruth_axes[dataset]),
                                    initFrame,
                                    endFrame,
                                    trackingPluginManager,
                                    rawimage_filename)
        logger.info('Done computing features from dataset {}'.format(dataset))

        selectedFeatures = find_features_without_NaNs(features)
        pos_labels = read_positiveLabels(initFrame, endFrame, files)
        neg_labels = negativeLabels(features, pos_labels)
        numSamples += 2 * sum([len(l) for l in pos_labels]) + sum([len(l) for l in neg_labels])
        logger.info('Done extracting {} samples'.format(numSamples))
        TC = TransitionClassifier(selectedFeatures, numSamples)
        if dataset > 0:
def computeJaccardScoresOnCloud(frame,
                                labelImageFilenames,
                                labelImagePaths,
                                labelImageFrameIdToGlobalId,
                                groundTruthFilename,
                                groundTruthPath,
                                groundTruthMinJaccardScore,
                                pluginPaths=['hytra/plugins'],
                                imageProviderPluginName='LocalImageLoader'):
    """
    Compute jaccard scores of all objects in the different segmentations with the ground truth for that frame.

    For every object of every label image in `frame`, the Jaccard index |A n G| / |A u G| is computed
    against every ground truth label G it shares at least one pixel with (label 0 is skipped on both sides).

    :param frame: timestep to evaluate
    :param labelImageFilenames: list of files containing the segmentation hypotheses
    :param labelImagePaths: internal paths, element-wise matching `labelImageFilenames`
    :param labelImageFrameIdToGlobalId: dict mapping `(filename, frame, objectId)` to a globalId
    :param groundTruthFilename: file containing the ground truth label image
    :param groundTruthPath: internal path of the ground truth label image
    :param groundTruthMinJaccardScore: only scores strictly above this threshold enter the GT mapping
    :param pluginPaths: plugin search paths for the TrackingPluginManager
    :param imageProviderPluginName: name of the image provider plugin used to load all label images

    :returns: tuple `(frame, scores, gtToGlobalIdMap)`:
        - `scores`: globalId -> list of `(gtLabel, jaccardScore)` for every overlapping GT label
        - `gtToGlobalIdMap`: `(frame, gtLabel)` -> list of `(globalId, jaccardScore)`,
          ordered by ascending score (best match last)

    Meant to be run in its own process using `concurrent.futures.ProcessPoolExecutor`
    """

    # set up plugin manager
    from hytra.pluginsystem.plugin_manager import TrackingPluginManager
    pluginManager = TrackingPluginManager(pluginPaths=pluginPaths,
                                          verbose=False)
    pluginManager.setImageProvider(imageProviderPluginName)
    imageProvider = pluginManager.getImageProvider()

    scores = {}
    gtToGlobalIdMap = {}

    groundTruthLabelImage = imageProvider.getLabelImageForFrame(
        groundTruthFilename, groundTruthPath, frame)

    # Pixel count of every GT label, computed once. The union of object A and GT label G is then
    # |A| + |G| - |A n G|, instead of a full-image logical_or for every (object, GT label) pair.
    gtLabels, gtLabelSizes = np.unique(groundTruthLabelImage,
                                       return_counts=True)
    gtLabelSize = dict(zip(gtLabels, gtLabelSizes))

    for labelImageFilename, labelImagePath in zip(labelImageFilenames,
                                                  labelImagePaths):
        labelImage = imageProvider.getLabelImageForFrame(
            labelImageFilename, labelImagePath, frame)

        for objectId in np.unique(labelImage):
            if objectId == 0:
                continue  # label 0 is treated as background
            globalId = labelImageFrameIdToGlobalId[(labelImageFilename, frame,
                                                    objectId)]

            # GT labels underneath this object's pixels, with the number of shared pixels for each
            overlap = groundTruthLabelImage[labelImage == objectId]
            overlapLabels, intersectionSizes = np.unique(overlap,
                                                         return_counts=True)

            for gtLabel, intersectingPixels in zip(overlapLabels,
                                                   intersectionSizes):
                if gtLabel == 0:
                    continue  # GT background is not a match
                unionPixels = (overlap.size + gtLabelSize[gtLabel] -
                               intersectingPixels)
                jaccardScore = float(intersectingPixels) / float(unionPixels)

                # append to object's score list
                scores.setdefault(globalId, []).append(
                    (gtLabel, jaccardScore))

                # store this as GT mapping if there was no better object for this GT label yet
                if jaccardScore > groundTruthMinJaccardScore:
                    matches = gtToGlobalIdMap.setdefault((frame, gtLabel), [])
                    if not matches or matches[-1][1] < jaccardScore:
                        matches.append((globalId, jaccardScore))

    # sort all gt mappings by ascending jaccard score (best match last);
    # values() works on Python 2 and 3, iteritems() only on Python 2
    for matches in gtToGlobalIdMap.values():
        matches.sort(key=lambda match: match[1])

    return frame, scores, gtToGlobalIdMap
Example #12
0
            parent = trackParents[trackId]
        else:
            parent = 0
        # jumping over time frames, so creating
        if trackId in gapTrackParents.keys():
            if gapTrackParents[trackId] != trackId:
                parent = gapTrackParents[trackId]
                getLogger().info(
                    "Jumping over one time frame in this link: trackid: {}, parent: {}, time: {}"
                    .format(trackId, parent, min(timestepList)))
        trackDict[trackId] = [parent, min(timestepList), max(timestepList)]
    save_tracks(trackDict, args)

    # load images, relabel, and export relabeled result
    getLogger().debug("Saving relabeled images")
    pluginManager = TrackingPluginManager(verbose=args.verbose,
                                          pluginPaths=args.pluginPaths)
    pluginManager.setImageProvider('LocalImageLoader')
    imageProvider = pluginManager.getImageProvider()
    timeRange = imageProvider.getTimeRange(args.label_image_filename,
                                           args.label_image_path)

    for timeframe in range(timeRange[0], timeRange[1]):
        label_image = imageProvider.getLabelImageForFrame(
            args.label_image_filename, args.label_image_path, timeframe)

        # check if frame is empty
        if timeframe in mappings.keys():
            remapped_label_image = remap_label_image(label_image,
                                                     mappings[timeframe])
            save_frame_to_tif(timeframe, remapped_label_image, args)
        else:
Example #13
0
def writeEvents(timestep, activeLinks, activeDivisions, mergers, detections,
                fn, labelImagePath, ilpFilename, verbose, pluginPaths):
    """
    Write the tracking events of one timestep, together with its label image, to the HDF5 file `fn`.

    File layout written here:
        - `segmentation/labels`: label image of `timestep` (gzip compressed)
        - `objects/meta/id`, `objects/meta/valid`: all non-zero object ids, each marked valid
        - `tracking/<Event>`: one int32 dataset per non-empty event type, with a "Format" attribute

    :param timestep: frame whose label image is copied
    :param activeLinks: rows for the "Moves" dataset (from previous file, to current file)
    :param activeDivisions: dict ancestor -> (descendant, descendant), written as "Splits"
    :param mergers: dict descendant -> number of objects, written as "Mergers"
    :param detections: unused
    :param fn: output HDF5 filename (overwritten)
    :param labelImagePath: internal path of the label image
    :param ilpFilename: file the label image is read from
    :param verbose: verbosity flag for the TrackingPluginManager
    :param pluginPaths: plugin search paths for the TrackingPluginManager

    Any exception while writing is logged as a warning (with traceback) and not re-raised.
    """
    logger = logging.getLogger('json_result_to_events.py')
    pluginManager = TrackingPluginManager(verbose=verbose,
                                          pluginPaths=pluginPaths)

    logger.debug("-- Writing results to {}".format(fn))
    try:
        # convert to ndarray for better indexing
        # appearances, disappearances and multi-frame moves are never filled in here,
        # so their (empty) datasets are always skipped below
        app = np.asarray([])
        dis = np.asarray([])
        mov = np.asarray(activeLinks)
        div = np.asarray([[k, v[0], v[1]] for k, v in activeDivisions.items()])
        mer = np.asarray([[k, v] for k, v in mergers.items()])
        mul = np.asarray([])

        # (dataset name, data, value of its "Format" attribute), in the order they are written
        events = [
            ("Appearances", app, "cell label appeared in current file"),
            ("Disappearances", dis, "cell label disappeared in current file"),
            ("Moves", mov, "from (previous file), to (current file)"),
            ("Splits", div,
             "ancestor (previous file), descendant (current file), descendant (current file)"),
            ("Mergers", mer, "descendant (current file), number of objects"),
            ("MultiFrameMoves", mul,
             "from (given by timestep), to (current file), timestep"),
        ]

        label_img = pluginManager.getImageProvider().getLabelImageForFrame(
            ilpFilename, labelImagePath, timestep)

        with h5py.File(fn, 'w') as dest_file:
            # write meta fields and copy segmentation from project
            seg = dest_file.create_group('segmentation')
            seg.create_dataset("labels", data=label_img, compression='gzip')
            meta = dest_file.create_group('objects/meta')
            ids = np.unique(label_img)
            ids = ids[ids > 0]
            valid = np.ones(ids.shape)
            meta.create_dataset("id", data=ids, dtype=np.uint32)
            meta.create_dataset("valid", data=valid, dtype=np.uint32)

            # write associations; empty event types get no dataset
            tg = dest_file.create_group("tracking")
            for name, data, formatDescription in events:
                if len(data) > 0:
                    ds = tg.create_dataset(name, data=data, dtype=np.int32)
                    ds.attrs["Format"] = formatDescription

        logger.debug("-> results successfully written")
    except Exception as e:
        # best effort: report the failure (including traceback) instead of aborting the caller
        logger.warning("ERROR while writing events: {}".format(str(e)),
                       exc_info=True)
Example #14
0
def run_pipeline(options):
    """
    Run the complete tracking pipeline with competing segmentation hypotheses

    Steps:
      1. set up the field of view, hypotheses graph, ilp options, probability generator and tracking
         graph (or load them from a gzip-compressed pickle dump), optionally dumping them
      2. if all ground truth options are given, map the ground truth onto the hypotheses graph
      3. run the tracking, insert the solution into the hypotheses graph and compute lineages
      4. write the track text file and the relabeled label image of every frame

    :param options: parsed command line options
    """

    # set up probabilitygenerator (aka traxelstore) and hypothesesgraph or load them from a dump
    if options.load_graph_filename is not None:
        getLogger().info("Loading state from file: " + options.load_graph_filename)
        # NOTE: pickle.load can execute arbitrary code - only load dumps from trusted sources
        with gzip.open(options.load_graph_filename, 'r') as graphDump:
            ilpOptions = pickle.load(graphDump)
            probGenerator = pickle.load(graphDump)
            fieldOfView = pickle.load(graphDump)
            hypotheses_graph = pickle.load(graphDump)
            trackingGraph = pickle.load(graphDump)
        getLogger().info("Done loading state from file")
    else:
        fieldOfView, hypotheses_graph, ilpOptions, probGenerator, trackingGraph = setupGraph(options)

        if options.dump_graph_filename is not None:
            getLogger().info("Saving state to file: " + options.dump_graph_filename)
            with gzip.open(options.dump_graph_filename, 'w') as graphDump:
                pickle.dump(ilpOptions, graphDump)
                pickle.dump(probGenerator, graphDump)
                pickle.dump(fieldOfView, graphDump)
                pickle.dump(hypotheses_graph, graphDump)
                pickle.dump(trackingGraph, graphDump)
            getLogger().info("Done saving state to file")

    # map groundtruth to hypothesesgraph if all required variables are specified
    weights = None
    if options.gt_label_image_file is not None and options.gt_label_image_path is not None \
        and options.gt_text_file is not None and options.gt_jaccard_threshold is not None:
        weights = mapGroundTruth(options, hypotheses_graph, trackingGraph, probGenerator)

    # track
    result = runTracking(options, trackingGraph, weights)

    # insert the solution into the hypotheses graph and from that deduce the lineages
    getLogger().info("Inserting solution into graph")
    hypotheses_graph.insertSolution(result)
    hypotheses_graph.computeLineage()

    mappings = {} # dictionary over timeframes, containing another dict objectId -> trackId per frame
    tracks = {} # stores a list of timeframes per track, so that we can find from<->to per track
    trackParents = {} # store the parent trackID of a track if known

    for n in hypotheses_graph.nodeIterator():
        nodeData = hypotheses_graph._graph.node[n]
        frameMapping = mappings.setdefault(n[0], {})
        if 'trackId' not in nodeData:
            raise ValueError("You need to compute the Lineage of every node before accessing the trackId!")
        trackId = nodeData['trackId']
        traxel = nodeData['traxel']
        if trackId is not None:
            # objects are identified by their id within *and* the file of their segmentation hypothesis
            frameMapping[(traxel.idInSegmentation, traxel.segmentationFilename)] = trackId
        tracks.setdefault(trackId, []).append(n[0])
        if 'parent' in nodeData:
            assert(trackId not in trackParents)
            trackParents[trackId] = hypotheses_graph._graph.node[nodeData['parent']]['trackId']

    # write res_track.txt: [parent (0 if unknown), first frame, last frame] per track
    getLogger().info("Writing track text file")
    trackDict = {}
    for trackId, timestepList in tracks.items():
        timestepList.sort()
        trackDict[trackId] = [trackParents.get(trackId, 0), min(timestepList), max(timestepList)]
    save_tracks(trackDict, options)

    # export results
    getLogger().info("Saving relabeled images")
    pluginManager = TrackingPluginManager(verbose=options.verbose, pluginPaths=options.pluginPaths)
    pluginManager.setImageProvider('LocalImageLoader')
    imageProvider = pluginManager.getImageProvider()
    timeRange = probGenerator.timeRange

    for timeframe in range(timeRange[0], timeRange[1]):
        label_images = {}
        for f, p in zip(options.label_image_files, options.label_image_paths):
            label_images[f] = imageProvider.getLabelImageForFrame(f, p, timeframe)
        remapped_label_image = remap_label_image(label_images, mappings[timeframe])
        save_frame_to_tif(timeframe, remapped_label_image, options)
class MergerResolver(object):
    """
    Base class for all merger resolving implementations. Use one of the derived classes
    that handle reading/writing data to the respective sources.

    Derived classes are expected to fill `self.model` and `self.result` (JSON model / result
    dictionaries) and to implement `_readLabelImage`, `_computeObjectFeatures` and, if
    needed, `_exportRefinedSegmentation`.
    """

    def __init__(self, pluginPaths=[os.path.abspath('../hytra/plugins')], verbose=False):
        # NOTE: the default list is built once when the module is imported (relative to the
        # working directory at that time); it is only passed on here, never mutated.
        self.unresolvedGraph = None
        self.resolvedGraph = None
        self.mergersPerTimestep = None
        self.detectionsPerTimestep = None
        self.pluginManager = TrackingPluginManager(
            verbose=verbose, pluginPaths=pluginPaths)
        self.mergerResolverPlugin = self.pluginManager.getMergerResolver()

        # should be filled by constructors of derived classes!
        self.model = None
        self.result = None

    @staticmethod
    def _allTimesteps(timestepKeys):
        """
        Return every timestep from the smallest to the largest entry of `timestepKeys`
        (inclusive) as strings, so that empty frames in between are part of the output too.

        The keys are compared numerically: as strings, '9' would be considered larger than '10'.
        Returns an empty list if `timestepKeys` is empty.
        """
        intKeys = [int(t) for t in timestepKeys]
        if not intKeys:
            return []
        # u'' yields unicode on Python 2 (like the former str(t).decode("utf-8")) and str on Python 3
        return [u'{}'.format(t) for t in range(min(intKeys), max(intKeys) + 1)]

    @staticmethod
    def _lastFrame(divisionsPerTimestep):
        """
        Return the largest timestep (as int) that has an entry in `divisionsPerTimestep`,
        or None if `divisionsPerTimestep` is None or empty.
        """
        if not divisionsPerTimestep:
            return None
        return max(int(t) for t in divisionsPerTimestep.keys())

    @staticmethod
    def _isDivision(divisionsPerTimestep, lastframe, timestep, idx):
        """
        Return True if object `idx` at `timestep` is listed in
        `divisionsPerTimestep[str(timestep + 1)]`.

        Returns False when there is no division information (`divisionsPerTimestep` or
        `lastframe` is None) and for timesteps at or after `lastframe`, which have no
        successor entry to look up.
        """
        if divisionsPerTimestep is None or lastframe is None or int(timestep) >= lastframe:
            return False
        return idx in divisionsPerTimestep[str(int(timestep) + 1)]

    def _createUnresolvedGraph(self, divisionsPerTimestep, mergersPerTimestep, mergerLinks, withFullGraph=False):
        """
        Set up a networkx graph consisting of mergers that need to be resolved (not resolved yet!)
        and their direct neighbors.

        Nodes are `(int timestep, objectId)` tuples carrying a boolean `division` attribute and
        a `count` attribute (number of merged objects). `mergerLinks` is an iterable of
        `(timestep, (srcId, destId))` pairs, where the source lives in `timestep - 1`.

        If `withFullGraph` is True, the whole `self.hypothesesGraph` (which must have been set
        by the derived class) is copied instead; then only merger nodes get a `count`.

        ** returns ** the `unresolvedGraph`
        """
        # computed once; None if there is no division information at all
        lastframe = self._lastFrame(divisionsPerTimestep)

        def source(timestep, link):
            return int(timestep) - 1, link[0]

        def target(timestep, link):
            return int(timestep), link[1]

        if withFullGraph:
            # Recompute full graph
            self.unresolvedGraph = self.hypothesesGraph._graph.copy()

            # Add division parameter to nodes
            # TODO: Add the division parameter only to nodes that contain divisions (we're already doing these with 'count')
            for node in self.unresolvedGraph.nodes_iter():
                timestep, idx = node
                self.unresolvedGraph.node[node]['division'] = self._isDivision(
                    divisionsPerTimestep, lastframe, timestep, idx)

            # Add count parameter to merger nodes
            for t, link in mergerLinks:
                for node in [source(t, link), target(t, link)]:
                    timestep, idx = node
                    if idx in mergersPerTimestep[str(timestep)]:
                        self.unresolvedGraph.node[node]['count'] = mergersPerTimestep[str(timestep)][idx]
        else:
            # Recompute graph only with merger nodes and neighbors
            self.unresolvedGraph = nx.DiGraph()

            def addNode(node):
                ''' add a node to the unresolved graph and fill in the properties `division` and `count` '''
                intT, idx = node
                division = self._isDivision(divisionsPerTimestep, lastframe, intT, idx)
                count = 1
                if idx in mergersPerTimestep[str(intT)]:
                    # mergers are expected not to be dividing at the same time
                    assert(not division)
                    count = mergersPerTimestep[str(intT)][idx]
                self.unresolvedGraph.add_node(node, division=division, count=count)

            for t, link in mergerLinks:
                for n in [source(t, link), target(t, link)]:
                    if not self.unresolvedGraph.has_node(n):
                        addNode(n)
                self.unresolvedGraph.add_edge(source(t, link), target(t, link))

        return self.unresolvedGraph

    def _prepareResolvedGraph(self):
        """
        Initialize the resolved graph as a copy of the unresolved graph. Merger nodes are
        split up later by `_fitAndRefineNodes`.

        ** returns ** the `resolvedGraph`
        """
        self.resolvedGraph = self.unresolvedGraph.copy()

        return self.resolvedGraph

    def _readLabelImage(self, timeframe):
        '''
        Should return the labelimage for the given timeframe
        '''
        raise NotImplementedError()

    def _fitAndRefineNodes(self,
                            detectionsPerTimestep,
                            mergersPerTimestep,
                            timesteps):
        '''
        Update segmentation of mergers (nodes in unresolvedGraph) from first timeframe to last
        and create new nodes in `resolvedGraph`. Links to merger nodes are duplicated to all new nodes.

        Uses the mergerResolver plugin to update the segmentations in the labelImages.

        Every processed node of the unresolved graph stores its fitted objects in the `fits`
        attribute; split mergers additionally store the list of their new object IDs in `newIds`.
        '''
        # process frames in increasing order: fits of frame t-1 initialize the fits of frame t
        for intT in sorted(int(t) for t in timesteps):
            t = str(intT)
            # use image provider plugin to load labelimage
            labelImage = self._readLabelImage(intT)
            nextObjectId = labelImage.max() + 1

            for idx in detectionsPerTimestep[t]:
                node = (intT, idx)
                if node not in self.resolvedGraph:
                    continue

                count = 1
                if idx in mergersPerTimestep[t]:
                    count = mergersPerTimestep[t][idx]
                getLogger().debug("Looking at node {} in timestep {} with count {}".format(idx, t, count))

                # collect initializations from incoming
                initializations = []
                for predecessor, _ in self.unresolvedGraph.in_edges(node):
                    initializations.extend(self.unresolvedGraph.node[predecessor]['fits'])
                # TODO: what shall we do if e.g. a 2-merger and a single object merge to 2 + 1,
                # so there are 3 initializations for the 2-merger, and two initializations for the 1 merger?
                # What does pgmlink do in that case?

                # use merger resolving plugin to fit `count` objects, also updates labelimage!
                fittedObjects = self.mergerResolverPlugin.resolveMerger(labelImage, idx, nextObjectId, count, initializations)
                assert(len(fittedObjects) == count)

                # split up node if count > 1, duplicate incoming and outgoing arcs
                if count > 1:
                    newIds = list(range(nextObjectId, nextObjectId + count))
                    for newId in newIds:
                        newNode = (intT, newId)
                        self.resolvedGraph.add_node(newNode, division=False, count=1, origin=node)

                        # successors live in later frames, so they have not been split yet
                        for _, successor in self.unresolvedGraph.out_edges(node):
                            self.resolvedGraph.add_edge(newNode, successor)
                        # predecessors that were split already are replaced by all their new nodes
                        for predecessor, _ in self.unresolvedGraph.in_edges(node):
                            predecessorData = self.unresolvedGraph.node[predecessor]
                            if 'newIds' in predecessorData:
                                for predecessorNewId in predecessorData['newIds']:
                                    self.resolvedGraph.add_edge((predecessor[0], predecessorNewId), newNode)
                            else:
                                self.resolvedGraph.add_edge(predecessor, newNode)

                    self.resolvedGraph.remove_node(node)
                    self.unresolvedGraph.node[node]['newIds'] = newIds
                    nextObjectId += count

                # each unresolved node stores its fitted shape(s) to be used
                # as initialization in the next frame, this way division duplicates
                # and de-merged nodes in the resolved graph do not need to store a fit as well
                self.unresolvedGraph.node[node]['fits'] = fittedObjects

    def _minCostMaxFlowMergerResolving(self, objectFeatures, transitionClassifier=None, transitionParameter=5.0):
        """
        Find the optimal assignments within the `resolvedGraph` by running min-cost max-flow from the
        `dpct` module.

        Converts the `resolvedGraph` to our JSON model structure, predicts the transition probabilities
        either using the given transitionClassifier, or using distance-based probabilities
        (`exp(-dist / transitionParameter)` on the `RegionCenter` feature).

        Stores the JSON id of every resolved node in its `id` attribute.

        **returns** a `nodeFlowMap` and `arcFlowMap` holding information on the usage of the respective nodes and links

        **Note:** cannot use `networkx` flow methods because they don't work with floating point weights.
        """

        trackingGraph = JsonTrackingGraph()
        for node in self.resolvedGraph.nodes_iter():
            additionalFeatures = {}

            # nodes with no in/out
            numStates = 2
            if len(self.resolvedGraph.in_edges(node)) == 0:
                # division nodes with no incoming arcs offer 2 units of flow without the need to de-merge
                # (`node in graph` is an O(1) lookup, unlike `node in graph.nodes()`)
                if node in self.unresolvedGraph and self.unresolvedGraph.node[node]['division'] and len(self.unresolvedGraph.out_edges(node)) == 2:
                    numStates = 3
                additionalFeatures['appearanceFeatures'] = [[i**2 * 0.01] for i in range(numStates)]
            if len(self.resolvedGraph.out_edges(node)) == 0:
                assert(numStates == 2) # division nodes with no incoming should have outgoing, or they shouldn't show up in resolved graph
                additionalFeatures['disappearanceFeatures'] = [[i**2 * 0.01] for i in range(numStates)]

            features = [[i**2] for i in range(numStates)]
            uuid = trackingGraph.addDetectionHypotheses(features, **additionalFeatures)
            self.resolvedGraph.node[node]['id'] = uuid

        for edge in self.resolvedGraph.edges_iter():
            src = self.resolvedGraph.node[edge[0]]['id']
            dest = self.resolvedGraph.node[edge[1]]['id']

            featuresAtSrc = objectFeatures[edge[0]]
            featuresAtDest = objectFeatures[edge[1]]

            if transitionClassifier is not None:
                try:
                    featVec = self.pluginManager.applyTransitionFeatureVectorConstructionPlugins(
                        featuresAtSrc, featuresAtDest, transitionClassifier.selectedFeatures)
                except Exception:
                    # log the offending link for debugging, then propagate the error
                    getLogger().error("Could not compute transition features of link {}->{}:".format(src, dest))
                    getLogger().error(featuresAtSrc)
                    getLogger().error(featuresAtDest)
                    raise
                featVec = np.expand_dims(np.array(featVec), axis=0)
                probs = transitionClassifier.predictProbabilities(featVec)[0]
            else:
                dist = np.linalg.norm(featuresAtDest['RegionCenter'] - featuresAtSrc['RegionCenter'])
                prob = np.exp(-dist / transitionParameter)
                probs = [1.0 - prob, prob]

            trackingGraph.addLinkingHypotheses(src, dest, listify(negLog(probs)))

        # track
        import dpct
        weights = {"weights": [1, 1, 1, 1]}
        mergerResult = dpct.trackMaxFlow(trackingGraph.model, weights)

        # transform results to dictionaries that can be indexed by id or (src,dest)
        nodeFlowMap = {int(d['id']): int(d['value']) for d in mergerResult['detectionResults']}
        arcFlowMap = {(int(link['src']), int(link['dest'])): int(link['value'])
                      for link in mergerResult['linkingResults']}

        return nodeFlowMap, arcFlowMap

    def _refineModel(self,
                     uuidToTraxelMap,
                     traxelIdPerTimestepToUniqueIdMap,
                     mergerNodeFilter,
                     mergerLinkFilter):
        """
        Take the `self.model` (JSON format) with mergers, remove the merger nodes, but add new
        de-merged nodes and links. Also updates `traxelIdPerTimestepToUniqueIdMap` locally and in the resulting file,
        such that the traxel IDs match the new connected component IDs in the refined images.

        `mergerNodeFilter` and `mergerLinkFilter` are methods that can filter merger detections
        and links from the respective lists in the `model` dict.

        **Returns** the updated `model` dictionary, which is the same as the input `model` (works in-place)
        """

        # remove merger detections
        self.model['segmentationHypotheses'] = [seg for seg in self.model['segmentationHypotheses'] if mergerNodeFilter(seg)]

        # remove merger links
        self.model['linkingHypotheses'] = [link for link in self.model['linkingHypotheses'] if mergerLinkFilter(link)]

        # insert new nodes and update UUID to traxel map; new UUIDs continue after the largest existing one
        nextUuid = max(uuidToTraxelMap.keys()) + 1
        for node in self.unresolvedGraph.nodes_iter():
            nodeData = self.unresolvedGraph.node[node]
            if 'count' in nodeData and nodeData['count'] > 1:
                del traxelIdPerTimestepToUniqueIdMap[str(node[0])][str(node[1])]
                for newId in nodeData['newIds']:
                    newDetection = {'id': nextUuid, 'timestep': [node[0], node[0]]}
                    self.model['segmentationHypotheses'].append(newDetection)
                    traxelIdPerTimestepToUniqueIdMap[str(node[0])][str(newId)] = nextUuid
                    nextUuid += 1

        # insert new links
        for edge in self.resolvedGraph.edges_iter():
            newLink = {
                'src': traxelIdPerTimestepToUniqueIdMap[str(edge[0][0])][str(edge[0][1])],
                'dest': traxelIdPerTimestepToUniqueIdMap[str(edge[1][0])][str(edge[1][1])],
            }
            self.model['linkingHypotheses'].append(newLink)

        return self.model

    def _refineResult(self,
                      nodeFlowMap,
                      arcFlowMap,
                      traxelIdPerTimestepToUniqueIdMap,
                      mergerNodeFilter,
                      mergerLinkFilter):
        """
        Update the `self.result` dict by removing the mergers and adding the refined nodes and links.

        Operates on a `result` dictionary in our JSON result style with mergers,
        the resolved and unresolved graph as well as
        the `nodeFlowMap` and `arcFlowMap` obtained by running tracking on the `resolvedGraph`.

        Updates the `result` dictionary so that all merger nodes are removed but the new nodes
        are contained with the appropriate links and values.

        `mergerNodeFilter` and `mergerLinkFilter` are methods that can filter merger detections
        and links from the respective lists in the `result` dict.

        **Returns** the updated `result` dict, which is the same as the input `result` (works in-place)
        """

        # filter merger edges
        self.result['detectionResults'] = [r for r in self.result['detectionResults'] if mergerNodeFilter(r)]
        self.result['linkingResults'] = [r for r in self.result['linkingResults'] if mergerLinkFilter(r)]

        # add new nodes
        for node in self.unresolvedGraph.nodes_iter():
            nodeData = self.unresolvedGraph.node[node]
            if 'count' in nodeData and nodeData['count'] > 1:
                for newId in nodeData['newIds']:
                    uuid = traxelIdPerTimestepToUniqueIdMap[str(node[0])][str(newId)]
                    resolvedResultId = self.resolvedGraph.node[(node[0], newId)]['id']
                    newDetection = {'id': uuid, 'value': nodeFlowMap[resolvedResultId]}
                    self.result['detectionResults'].append(newDetection)

        # add new links
        for edge in self.resolvedGraph.edges_iter():
            newLink = {}
            newLink['src'] = traxelIdPerTimestepToUniqueIdMap[str(edge[0][0])][str(edge[0][1])]
            newLink['dest'] = traxelIdPerTimestepToUniqueIdMap[str(edge[1][0])][str(edge[1][1])]
            srcId = self.resolvedGraph.node[edge[0]]['id']
            destId = self.resolvedGraph.node[edge[1]]['id']
            newLink['value'] = arcFlowMap[(srcId, destId)]
            self.result['linkingResults'].append(newLink)

        return self.result

    def _exportRefinedSegmentation(self, timesteps):
        """
        Store the resulting label images, if needed.

        `timesteps` is the list of timesteps (as strings) to export.
        The base implementation does nothing.
        """
        pass

    def _computeObjectFeatures(self, timesteps):
        '''
        Return the features per object as nested dictionaries:
        { (int(Timestep), int(Id)):{ "FeatureName" : np.array(value), "NextFeature": ...} }

        Should be overridden by derived classes; the base implementation returns None.
        '''
        pass

    # ------------------------------------------------------------
    def run(self, transition_classifier_filename=None, transition_classifier_path=None):
        """
        Run merger resolving

        1. find mergers in the given model and result
        2. build graph of the unresolved (merger) nodes and their direct neighbors
        3. use a mergerResolving plugin to refine the merger nodes and their segmentation
        4. run min-cost max-flow tracking to find the fate of all the de-merged objects
        5. export refined segmentation, update member variables `model` and `result`

        **Returns** a nested dictionary, indexed first by time, then object Id, containing a list of new segmentIDs per merger.
        If there are no mergers, the segmentation is exported unchanged and an empty dictionary is returned.
        """

        traxelIdPerTimestepToUniqueIdMap, uuidToTraxelMap = hytra.core.jsongraph.getMappingsBetweenUUIDsAndTraxels(self.model)
        # there might be empty frames. We want them as output too.
        # The map's keys are strings, so the range must be determined numerically.
        timesteps = self._allTimesteps(traxelIdPerTimestepToUniqueIdMap.keys())

        mergers, detections, links, divisions = hytra.core.jsongraph.getMergersDetectionsLinksDivisions(self.result, uuidToTraxelMap)

        # ------------------------------------------------------------

        # it may be, that there are no mergers, so do basically nothing, just copy all the ingoing data
        if len(mergers) == 0:
            getLogger().info("The maximum number of objects is 1, so nothing to be done. Writing the output...")
            self._exportRefinedSegmentation(timesteps)
            return {}

        self.mergersPerTimestep = hytra.core.jsongraph.getMergersPerTimestep(mergers, timesteps)
        self.detectionsPerTimestep = hytra.core.jsongraph.getDetectionsPerTimestep(detections, timesteps)

        linksPerTimestep = hytra.core.jsongraph.getLinksPerTimestep(links, timesteps)
        divisionsPerTimestep = hytra.core.jsongraph.getDivisionsPerTimestep(divisions, linksPerTimestep, timesteps)
        mergerLinks = hytra.core.jsongraph.getMergerLinks(linksPerTimestep, self.mergersPerTimestep, timesteps)

        # set up unresolved graph and then refine the nodes to get the resolved graph
        self._createUnresolvedGraph(divisionsPerTimestep, self.mergersPerTimestep, mergerLinks)
        self._prepareResolvedGraph()
        self._fitAndRefineNodes(self.detectionsPerTimestep,
                                self.mergersPerTimestep,
                                timesteps)

        # ------------------------------------------------------------
        # compute new object features
        objectFeatures = self._computeObjectFeatures(timesteps)

        # ------------------------------------------------------------
        # load transition classifier if any
        if transition_classifier_filename is not None:
            getLogger().info("\tLoading transition classifier")
            transitionClassifier = probabilitygenerator.RandomForestClassifier(
                transition_classifier_path, transition_classifier_filename)
        else:
            getLogger().info("\tUsing distance based transition energies")
            transitionClassifier = None

        # ------------------------------------------------------------
        # run min-cost max-flow to find merger assignments
        getLogger().info("Running min-cost max-flow to find resolved merger assignments")

        nodeFlowMap, arcFlowMap = self._minCostMaxFlowMergerResolving(objectFeatures, transitionClassifier)

        # ------------------------------------------------------------
        # fuse results into a new solution

        # 1.) replace merger nodes in JSON graph by their replacements -> new JSON graph
        #     update UUID to traxel map.
        #     a) how do we deal with the smaller number of states?
        #        Does it matter as we're done with tracking anyway..?

        def isMergerTraxel(traxel):
            # traxel is a (timestep, objectId) pair
            return traxel[1] in self.mergersPerTimestep[str(traxel[0])]

        def mergerNodeFilter(jsonNode):
            # return True if none of the node's traxels was a merger
            traxels = uuidToTraxelMap[int(jsonNode['id'])]
            return not any(isMergerTraxel(t) for t in traxels)

        def mergerLinkFilter(jsonLink):
            srcTraxels = uuidToTraxelMap[int(jsonLink['src'])]
            destTraxels = uuidToTraxelMap[int(jsonLink['dest'])]

            # return True if there was no traxel in either source or target node that was a merger.
            return not (any(isMergerTraxel(t) for t in srcTraxels) or any(isMergerTraxel(t) for t in destTraxels))

        self.model = self._refineModel(uuidToTraxelMap,
                                       traxelIdPerTimestepToUniqueIdMap,
                                       mergerNodeFilter,
                                       mergerLinkFilter)

        # 2.) new result = union(old result, resolved mergers) - old mergers

        self.result = self._refineResult(nodeFlowMap,
                                         arcFlowMap,
                                         traxelIdPerTimestepToUniqueIdMap,
                                         mergerNodeFilter,
                                         mergerLinkFilter)

        # 3.) export refined segmentation
        self._exportRefinedSegmentation(timesteps)

        # return a dictionary telling about which mergers were resolved into what
        mergerDict = {}
        for n in self.unresolvedGraph.nodes_iter():
            newIds = self.unresolvedGraph.node[n].get('newIds', [])
            # skip non-mergers
            if len(newIds) < 2:
                continue
            mergerDict.setdefault(n[0], {})[n[1]] = newIds

        return mergerDict

    def relabelMergers(self, labelImage, time):
        """
        Calls the merger resolving plugin to relabel the mergers based on a previously found fit,
        which is stored in the hypotheses graph node.

        Returns `labelImage` (updated by the plugin). It is returned untouched if `run` did not
        set up per-timestep detections or `time` has none.
        """
        t = str(time)

        if self.detectionsPerTimestep is None or t not in self.detectionsPerTimestep:
            return labelImage

        for idx in self.detectionsPerTimestep[t]:
            if idx not in self.mergersPerTimestep[t]:
                continue

            node = (time, idx)
            # Mergers without computed `newIds` (e.g. not reached by any merger link, so never
            # added to the unresolved graph and never fitted) cannot be relabeled; skip them
            # instead of failing with a KeyError.
            if node not in self.unresolvedGraph or 'newIds' not in self.unresolvedGraph.node[node]:
                continue

            # use fits stored in graph
            fits = self.unresolvedGraph.node[node]['fits']
            newIds = self.unresolvedGraph.node[node]['newIds']

            # use merger resolving plugin to update labelImage with merger IDs
            self.mergerResolverPlugin.updateLabelImage(labelImage, idx, fits, newIds)

        return labelImage
# Example #17
# 0
class IlpProbabilityGenerator(ProbabilityGenerator):
    """
    The IlpProbabilityGenerator is a python wrapper around pgmlink's C++ traxelstore,
    but with the functionality to compute all region features 
    and evaluate the division/count/transition classifiers.
    """
    def __init__(self,
                 ilpOptions,
                 turnOffFeatures=[],
                 useMultiprocessing=True,
                 pluginPaths=['hytra/plugins'],
                 verbose=False):
        """
        Create the plugin manager and configure its image provider and feature serializer
        from `ilpOptions`. Load the classifiers configured in `ilpOptions`, then read the
        label image shape and time range.
        """
        self._useMultiprocessing = useMultiprocessing
        self._options = ilpOptions
        self._pluginPaths = pluginPaths

        self._pluginManager = manager = TrackingPluginManager(
            turnOffFeatures=turnOffFeatures,
            verbose=verbose,
            pluginPaths=pluginPaths)
        manager.setImageProvider(ilpOptions.imageProviderName)
        manager.setFeatureSerializer(ilpOptions.featureSerializerName)

        # _loadClassifiers only replaces these for classifiers configured in the options
        self._countClassifier = self._divisionClassifier = self._transitionClassifier = None
        self._loadClassifiers()

        self.shape, self.timeRange = self._getShapeAndTimeRange()

        # default division features (can be replaced via setDivisionFeatures)
        self._divisionFeatureNames = [
            'ParentChildrenRatio_Count', 'ParentChildrenRatio_Mean',
            'ChildrenRatio_Count', 'ChildrenRatio_Mean',
            'ParentChildrenAngle_RegionCenter',
            'ChildrenRatio_SquaredDistances'
        ]

        # tunable parameters
        self.x_scale = self.y_scale = self.z_scale = 1.0
        self.divisionProbabilityFeatureName = 'divProb'
        self.detectionProbabilityFeatureName = 'detProb'

        # public: holds all traxels when pgmlink is not used
        self.TraxelsPerFrame = {}

    def _loadClassifiers(self):
        """
        Instantiate a `RandomForestClassifier` for each of the object count, division and
        transition classifiers whose path AND filename are both set (not None) in the options.

        Classifiers that are not configured are left untouched.
        """
        # (attribute to set, option holding the path, option holding the filename)
        classifierSpecs = (
            ('_countClassifier', 'objectCountClassifierPath', 'objectCountClassifierFilename'),
            ('_divisionClassifier', 'divisionClassifierPath', 'divisionClassifierFilename'),
            ('_transitionClassifier', 'transitionClassifierPath', 'transitionClassifierFilename'),
        )
        for attributeName, pathOption, filenameOption in classifierSpecs:
            path = getattr(self._options, pathOption)
            if path is None:
                continue
            # only read the filename once the path is known to be set (like the former `and` chain)
            filename = getattr(self._options, filenameOption)
            if filename is None:
                continue
            setattr(self, attributeName, RandomForestClassifier(path, filename, self._options))

    def __getstate__(self):
        '''
        Return the state to pickle: a shallow copy of the instance dictionary without the
        three random forest classifiers, which must not be pickled. `__setstate__` reads
        them again from disk.

        See https://docs.python.org/3/library/pickle.html#pickle-state for more details.
        '''
        # work on a copy so the live instance keeps its classifiers
        state = dict(self.__dict__)
        for unpicklable in ('_countClassifier', '_divisionClassifier', '_transitionClassifier'):
            # pop without a default raises KeyError for a missing entry, exactly like `del`
            state.pop(unpicklable)
        return state

    def __setstate__(self, state):
        '''
        Restore an instance from a state dict and re-read the random forests from disk,
        since they are not part of the pickled state.
        '''
        # Restore instance attributes
        self.__dict__.update(state)
        # The classifier attributes are not in the pickled state and _loadClassifiers only
        # creates the configured ones. Reset all three first so that unconfigured classifiers
        # are None (as after __init__) instead of missing attributes.
        self._countClassifier = None
        self._divisionClassifier = None
        self._transitionClassifier = None
        # Restore the random forests by reading them from scratch
        self._loadClassifiers()

    def computeRegionFeatures(self, rawImage, labelImage, frameNumber):
        """
        Computes all region features for all objects in the given image.

        Runs the object feature computation plugins and merges their feature dictionaries;
        on duplicate feature names, the dictionary that comes later in the plugin results wins.
        Features named in the plugins' ignore list are removed from the result.

        `labelImage` must have dtype uint32 (asserted).

        **returns** the merged feature dictionary
        """
        assert (labelImage.dtype == np.uint32)

        moreFeats, ignoreNames = self._pluginManager.applyObjectFeatureComputationPlugins(
            len(labelImage.shape), rawImage, labelImage, frameNumber,
            self._options.rawImageFilename)

        # merge in order; update() keeps "last one wins" like dict(all items) did, but is
        # linear and also works on Python 3 (where dict.items() cannot be added to a list)
        frameFeatures = {}
        for f in moreFeats:
            frameFeatures.update(f)

        # delete the "Global<Min/Max>" features as they are not nice when iterating over everything
        for k in ignoreNames:
            frameFeatures.pop(k, None)

        return frameFeatures

    def computeDivisionFeatures(self, featuresAtT, featuresAtTPlus1,
                                labelImageAtTPlus1):
        """
        Compute the division features listed in `self._divisionFeatureNames` for all objects,
        from the region features of frames t and t+1 and the label image of frame t+1.
        """
        dimensionCount = self.getNumDimensions()
        featureManager = hytra.core.divisionfeatures.FeatureManager(ndim=dimensionCount)
        requestedFeatures = self._divisionFeatureNames
        return featureManager.computeFeatures_at(
            featuresAtT, featuresAtTPlus1, labelImageAtTPlus1, requestedFeatures)

    def setDivisionFeatures(self, divisionFeatures):
        """
        Choose which division features are computed, given as a list of strings of the form
        ``<operation>_<feature>``. The operation is one of:
            * ParentIdentity
            * SquaredDistances
            * ChildrenRatio
            * ParentChildrenAngle
            * ParentChildrenRatio

        and ``<feature>`` is any region feature or "SquaredDistances".
        The list is stored as given (neither copied nor validated).
        """
        # TODO: validate the feature names?
        self._divisionFeatureNames = divisionFeatures

    def getNumDimensions(self):
        """
        Return the number of axes of `self.shape` whose extent is not 1
        (singleton axes do not count as spatial dimensions).
        """
        return sum(1 for extent in self.shape if extent != 1)

    def _getShapeAndTimeRange(self):
        """
        Ask the image provider for the shape and the time range of the label image
        configured in the options.

        **returns** the tuple `(shape, timeRange)`
        """
        options = self._options
        imageShape = self._pluginManager.getImageProvider().getImageShape(
            options.labelImageFilename, options.labelImagePath)
        timeRange = self._pluginManager.getImageProvider().getTimeRange(
            options.labelImageFilename, options.labelImagePath)
        return imageShape, timeRange

    def getLabelImageForFrame(self, timeframe):
        """
        Return the label image (or volume) of frame `timeframe`, read through the image
        provider from the label image file and path given in the options.
        """
        imageProvider = self._pluginManager.getImageProvider()
        labelImage = imageProvider.getLabelImageForFrame(
            self._options.labelImageFilename, self._options.labelImagePath, timeframe)
        return labelImage

    def getRawImageForFrame(self, timeframe):
        """
        Return the raw image (or volume) of frame `timeframe`, read through the image
        provider from the raw image file and path given in the options.
        """
        imageProvider = self._pluginManager.getImageProvider()
        rawImage = imageProvider.getImageDataAtTimeFrame(
            self._options.rawImageFilename, self._options.rawImagePath, timeframe)
        return rawImage

    def _extractFeaturesForFrame(self, timeframe):
        """
        Compute the region features of a single frame.

        **returns** `(timeframe, features)`, where `features` is the dictionary produced by
        `computeRegionFeatures` for that frame's raw and label image
        """
        raw = self.getRawImageForFrame(timeframe)
        labels = self.getLabelImageForFrame(timeframe)
        features = self.computeRegionFeatures(raw, labels, timeframe)
        return timeframe, features

    def _extractDivisionFeaturesForFrame(self, timeframe, featuresPerFrame):
        """
        Compute the division features between `timeframe` and `timeframe + 1` from the
        region features in `featuresPerFrame` (indexed by frame).

        **returns** `(timeframe, divisionFeatures)`. `divisionFeatures` is an empty dict when
        `timeframe + 1` lies outside `self.timeRange`. `featuresPerFrame` is not modified.
        """
        if timeframe + 1 >= self.timeRange[1]:
            # no successor frame to compare against
            return timeframe, {}

        nextLabelImage = self.getLabelImageForFrame(timeframe + 1)
        divisionFeatures = self.computeDivisionFeatures(
            featuresPerFrame[timeframe], featuresPerFrame[timeframe + 1],
            nextLabelImage)
        return timeframe, divisionFeatures

    def _extractAllFeatures(self, dispyNodeIps=None, turnOffFeatures=None):
        """
        Extract the features of all frames in `self.timeRange`.

        If a list of IP addresses is given, e.g. `dispyNodeIps = ["104.197.178.206","104.196.46.138"]`,
        the region feature computation is distributed across these nodes using dispy
        (work in progress: no division features are computed in that mode).
        If `dispyNodeIps` is None or empty, frames are processed via a `ProcessPoolExecutor`
        when `self._useMultiprocessing` is True (the default), or one after another via
        `DummyExecutor` otherwise.

        :param dispyNodeIps: IP addresses of dispy nodes, or None / [] for local computation
        :param turnOffFeatures: forwarded to `computeRegionFeaturesOnCloud`
            (presumably names of features to skip); None is treated as []
        :returns: dict frame -> feature dict of that frame. In local mode, division features
            are merged into the frame dicts when a division classifier is set.
        :raises RuntimeError: if a dispy job reports an exception

        **TODO:** fix division feature computation for distributed mode
        """
        import logging
        logger = logging.getLogger('Traxelstore')

        if dispyNodeIps is None:
            dispyNodeIps = []
        if turnOffFeatures is None:
            turnOffFeatures = []

        # configure progress bar: one step per frame, doubled if division features are computed
        numSteps = self.timeRange[1] - self.timeRange[0]
        if self._divisionClassifier is not None:
            numSteps *= 2

        t0 = time.time()
        # defined up-front so that every branch returns a dict
        featuresPerFrame = {}

        if len(dispyNodeIps) == 0:
            # no dispy node IDs given, parallelize object feature computation via processes
            if self._useMultiprocessing:
                # use ProcessPoolExecutor, which instantiates as many processes as there are CPU cores by default
                ExecutorType = concurrent.futures.ProcessPoolExecutor
                logger.info(
                    'Parallelizing feature extraction via multiprocessing on all cores!'
                )
            else:
                ExecutorType = DummyExecutor
                logger.info('Running feature extraction on single core!')

            progressBar = ProgressBar(stop=numSteps)
            progressBar.show(increase=0)

            with ExecutorType() as executor:
                # 1st pass for region features
                jobs = []
                for frame in range(self.timeRange[0], self.timeRange[1]):
                    jobs.append(
                        executor.submit(computeRegionFeaturesOnCloud, frame,
                                        self._options.rawImageFilename,
                                        self._options.rawImagePath,
                                        self._options.rawImageAxes,
                                        self._options.labelImageFilename,
                                        self._options.labelImagePath,
                                        turnOffFeatures, self._pluginPaths))
                for job in concurrent.futures.as_completed(jobs):
                    progressBar.show()
                    frame, feats = job.result()
                    featuresPerFrame[frame] = feats

                # 2nd pass for division features, needs the region features of frame and frame + 1
                if self._divisionClassifier is not None:
                    jobs = []
                    for frame in range(self.timeRange[0],
                                       self.timeRange[1] - 1):
                        jobs.append(
                            executor.submit(
                                computeDivisionFeaturesOnCloud, frame,
                                featuresPerFrame[frame],
                                featuresPerFrame[frame + 1],
                                self._pluginManager.getImageProvider(),
                                self._options.labelImageFilename,
                                self._options.labelImagePath,
                                self.getNumDimensions(),
                                self._divisionFeatureNames))

                    for job in concurrent.futures.as_completed(jobs):
                        progressBar.show()
                        frame, feats = job.result()
                        featuresPerFrame[frame].update(feats)
        else:
            logger.warning('Parallelization with dispy is WORK IN PROGRESS!')
            import dispy
            cluster = dispy.JobCluster(computeRegionFeaturesOnCloud,
                                       nodes=dispyNodeIps,
                                       loglevel=logging.DEBUG,
                                       depends=[self._pluginManager],
                                       secret="teamtracking")
            try:
                jobs = []
                for frame in range(self.timeRange[0], self.timeRange[1]):
                    # NOTE(review): hard-coded plugin path on the remote nodes - should probably be configurable
                    job = cluster.submit(
                        frame,
                        self._options.rawImageFilename,
                        self._options.rawImagePath,
                        self._options.rawImageAxes,
                        self._options.labelImageFilename,
                        self._options.labelImagePath,
                        turnOffFeatures,
                        pluginPaths=['/home/carstenhaubold/embryonic/plugins'])
                    job.id = frame
                    jobs.append(job)

                for job in jobs:
                    result = job()  # waits for the job to finish and returns its result
                    logger.debug("dispy job {}: stdout={} stderr={}".format(
                        job.id, job.stdout, job.stderr))
                    if job.exception is not None:
                        raise RuntimeError(
                            "Feature computation for frame {} failed on dispy node: {}"
                            .format(job.id, job.exception))
                    frame, feats = result
                    featuresPerFrame[frame] = feats
            finally:
                cluster.close()

            logger.warning(
                'Using dispy we cannot compute division features yet!')

        t1 = time.time()
        getLogger().info("Feature computation took {} secs".format(t1 - t0))

        return featuresPerFrame

    def _setTraxelFeatureArray(self, traxel, featureArray, name):
        '''
        Store `featureArray` in the feature dictionary of `traxel` under the key `name`.

        numpy arrays are flattened first, so multi-dimensional features are stored as one
        flat vector; every entry is converted to a python float.
        '''
        values = featureArray.flatten() if isinstance(featureArray, np.ndarray) else featureArray
        traxel.add_feature_array(name, len(values))
        for position, entry in enumerate(values):
            traxel.set_feature_value(name, position, float(entry))

    def fillTraxels(self,
                    usePgmlink=True,
                    ts=None,
                    fs=None,
                    dispyNodeIps=None,
                    turnOffFeatures=None):
        """
        Compute all the features and predict object count as well as division probabilities.
        Store the resulting information (and all other features) in the given pgmlink::TraxelStore,
        or create a new one if ts=None.

        usePgmlink: boolean whether pgmlink should be used and a pgmlink.TraxelStore and pgmlink.FeatureStore returned
        ts: an initial pgmlink.TraxelStore (only used if usePgmlink=True)
        fs: an initial pgmlink.FeatureStore (only used if usePgmlink=True); required if ts is given
        dispyNodeIps: IP addresses of dispy nodes for distributed feature extraction (None/[] = local)
        turnOffFeatures: forwarded to the feature extraction (None is treated as [])

        Objects whose 'Count' is 0 or lies outside `self._options.sizeFilter` (if set) are skipped.
        The extracted features are kept in `self._featuresPerFrame`.

        returns (ts, fs) but only if usePgmlink=True, otherwise it fills self.TraxelsPerFrame
        """
        if usePgmlink:
            import pgmlink
            if ts is None:
                ts = pgmlink.TraxelStore()
                fs = pgmlink.FeatureStore()
            else:
                assert (fs is not None)

        getLogger().info("Extracting features...")
        self._featuresPerFrame = self._extractAllFeatures(
            dispyNodeIps=dispyNodeIps if dispyNodeIps is not None else [],
            turnOffFeatures=turnOffFeatures if turnOffFeatures is not None else [])

        getLogger().info("Creating traxels...")
        progressBar = ProgressBar(stop=len(self._featuresPerFrame))
        progressBar.show(increase=0)

        for frame, features in self._featuresPerFrame.items():
            # predict random forests
            if self._countClassifier is not None:
                objectCountProbabilities = self._countClassifier.predictProbabilities(
                    features=None, featureDict=features)

            hasDivisionPrediction = (self._divisionClassifier is not None
                                     and frame + 1 < self.timeRange[1])
            if hasDivisionPrediction:
                divisionProbabilities = self._divisionClassifier.predictProbabilities(
                    features=None, featureDict=features)

            # The number of objects (index 0 = background) is taken from 'Count', which is needed
            # for the size filter anyway. An arbitrary dict entry (`features.values()[0]`) could be
            # a polygon feature list without `.shape`, and `values()` is not indexable on Python 3.
            numObjects = len(features['Count'])

            # create traxels for all objects
            for objectId in range(1, numObjects):
                pixelSize = features['Count'][objectId]
                if pixelSize == 0 or (self._options.sizeFilter is not None \
                        and (pixelSize < self._options.sizeFilter[0] \
                                     or pixelSize > self._options.sizeFilter[1])):
                    continue

                # create traxel
                if usePgmlink:
                    traxel = pgmlink.Traxel()
                else:
                    traxel = Traxel()
                traxel.Id = objectId
                traxel.Timestep = frame

                # add raw features
                for key, val in features.items():
                    if key == 'id':
                        traxel.idInSegmentation = val[objectId]
                    elif key == 'filename':
                        traxel.segmentationFilename = val[objectId]
                    else:
                        try:
                            if isinstance(val, list):  # polygon feature returns a list!
                                featureValues = val[objectId]
                            else:
                                featureValues = val[objectId, ...]
                        except Exception:
                            # lists have no `.shape`: report their length instead of failing here
                            shape = getattr(val, 'shape', None)
                            if shape is None and hasattr(val, '__len__'):
                                shape = (len(val), )
                            message = "Could not get feature values of {} for key {} from matrix with shape {}".format(
                                objectId, key, shape)
                            getLogger().error(message)
                            raise AssertionError(message)
                        try:
                            self._setTraxelFeatureArray(
                                traxel, featureValues, key)
                            if key == 'RegionCenter':
                                # the region center is additionally stored as 'com'
                                self._setTraxelFeatureArray(
                                    traxel, featureValues, 'com')
                        except Exception:
                            message = "Could not add feature array {} for {}".format(
                                featureValues, key)
                            getLogger().error(message)
                            raise AssertionError(message)

                # add random forest predictions
                if self._countClassifier is not None:
                    self._setTraxelFeatureArray(
                        traxel, objectCountProbabilities[objectId, :],
                        self.detectionProbabilityFeatureName)

                if hasDivisionPrediction:
                    self._setTraxelFeatureArray(
                        traxel, divisionProbabilities[objectId, :],
                        self.divisionProbabilityFeatureName)

                # set other parameters
                traxel.set_x_scale(self.x_scale)
                traxel.set_y_scale(self.y_scale)
                traxel.set_z_scale(self.z_scale)

                if usePgmlink:
                    # add to pgmlink's traxelstore
                    ts.add(fs, traxel)
                else:
                    self.TraxelsPerFrame.setdefault(frame,
                                                    {})[objectId] = traxel
            progressBar.show()

        if usePgmlink:
            return ts, fs

    def getTraxelFeatureDict(self, frame, objectId):
        """
        Return a dict holding all extracted features of object `objectId` in `frame`.

        Requires `self._featuresPerFrame` to be filled (e.g. by `fillTraxels`).
        Polygon features are python lists and are indexed by the object id only;
        all other features are indexed as `value[objectId, ...]`.
        """
        assert self._featuresPerFrame is not None
        traxelFeatureDict = {}
        for k, v in self._featuresPerFrame[frame].items():
            if 'Polygon' in k:
                traxelFeatureDict[k] = v[objectId]
            else:
                traxelFeatureDict[k] = v[objectId, ...]
        return traxelFeatureDict

    def getTransitionFeatureVector(self, featureDictObjectA,
                                   featureDictObjectB, selectedFeatures):
        """
        Build the input row for the TransitionClassifier from the features of two objects.

        The transition feature vector construction plugins compute the vector (component
        wise difference and product of the selected features). It is returned as a numpy
        array with an extra leading axis of length 1, i.e. a single sample.
        """
        vector = self._pluginManager.applyTransitionFeatureVectorConstructionPlugins(
            featureDictObjectA, featureDictObjectB, selectedFeatures)
        return np.array(vector)[np.newaxis, ...]
def run_pipeline(options):
    """
    Run the complete tracking pipeline with competing segmentation hypotheses.

    1. Set up the probability generator (aka traxelstore), hypotheses graph and tracking graph
       via `setupGraph`, or load them from `options.load_graph_filename`. A freshly created
       state is dumped to `options.dump_graph_filename` if that is set.
    2. If all ground truth options are set, map the ground truth onto the graph to get weights.
    3. Track, insert the solution into the hypotheses graph and compute the lineages.
    4. Write the track text file (`save_tracks`) and one relabeled label image per frame of the
       probability generator's time range (`save_frame_to_tif`).
    """
    # set up probabilitygenerator (aka traxelstore) and hypothesesgraph or load them from a dump
    if options.load_graph_filename is not None:
        getLogger().info("Loading state from file: " + options.load_graph_filename)
        # NOTE(security): unpickling can execute arbitrary code - only load dumps from trusted sources
        with gzip.open(options.load_graph_filename, 'r') as graphDump:
            ilpOptions = pickle.load(graphDump)
            probGenerator = pickle.load(graphDump)
            fieldOfView = pickle.load(graphDump)
            hypotheses_graph = pickle.load(graphDump)
            trackingGraph = pickle.load(graphDump)
        getLogger().info("Done loading state from file")
    else:
        fieldOfView, hypotheses_graph, ilpOptions, probGenerator, trackingGraph = setupGraph(options)

        if options.dump_graph_filename is not None:
            getLogger().info("Saving state to file: " + options.dump_graph_filename)
            # the load branch above reads the objects back in exactly this order
            with gzip.open(options.dump_graph_filename, 'w') as graphDump:
                pickle.dump(ilpOptions, graphDump)
                pickle.dump(probGenerator, graphDump)
                pickle.dump(fieldOfView, graphDump)
                pickle.dump(hypotheses_graph, graphDump)
                pickle.dump(trackingGraph, graphDump)
            getLogger().info("Done saving state to file")

    # map groundtruth to hypothesesgraph if all required variables are specified
    weights = None
    if options.gt_label_image_file is not None and options.gt_label_image_path is not None \
        and options.gt_text_file is not None and options.gt_jaccard_threshold is not None:
        weights = mapGroundTruth(options, hypotheses_graph, trackingGraph, probGenerator)

    # track
    result = runTracking(options, trackingGraph, weights)

    # insert the solution into the hypotheses graph and from that deduce the lineages
    getLogger().info("Inserting solution into graph")
    hypotheses_graph.insertSolution(result)
    hypotheses_graph.computeLineage()

    mappings = {}  # dictionary over timeframes, containing another dict (objectId, filename) -> trackId per frame
    tracks = {}  # stores a list of timeframes per track, so that we can find from<->to per track
    trackParents = {}  # store the parent trackID of a track if known

    for n in hypotheses_graph.nodeIterator():
        nodeData = hypotheses_graph._graph.node[n]
        frameMapping = mappings.setdefault(n[0], {})
        if 'trackId' not in nodeData:
            raise ValueError("You need to compute the Lineage of every node before accessing the trackId!")
        trackId = nodeData['trackId']
        if trackId is None:
            # NOTE(review): presumably a node that is not part of the solution. It is not relabeled,
            # and it must not be collected as a bogus `None` track for the track text file either.
            continue
        traxel = nodeData['traxel']
        frameMapping[(traxel.idInSegmentation, traxel.segmentationFilename)] = trackId
        tracks.setdefault(trackId, []).append(n[0])
        if 'parent' in nodeData:
            assert(trackId not in trackParents)
            trackParents[trackId] = hypotheses_graph._graph.node[nodeData['parent']]['trackId']

    # write res_track.txt: per track [parent (0 = none), first frame, last frame]
    getLogger().info("Writing track text file")
    trackDict = {}
    for trackId, timestepList in tracks.items():
        trackDict[trackId] = [trackParents.get(trackId, 0), min(timestepList), max(timestepList)]
    save_tracks(trackDict, options)

    # export results
    getLogger().info("Saving relabeled images")
    pluginManager = TrackingPluginManager(verbose=options.verbose, pluginPaths=options.pluginPaths)
    pluginManager.setImageProvider('LocalImageLoader')
    imageProvider = pluginManager.getImageProvider()
    timeRange = probGenerator.timeRange

    for timeframe in range(timeRange[0], timeRange[1]):
        label_images = {}
        for f, p in zip(options.label_image_files, options.label_image_paths):
            label_images[f] = imageProvider.getLabelImageForFrame(f, p, timeframe)
        # frames without any node in the hypotheses graph have no mapping entry
        remapped_label_image = remap_label_image(label_images, mappings.get(timeframe, {}))
        save_frame_to_tif(timeframe, remapped_label_image, options)
# Example #20
# 0
class MergerResolver(object):
    """
    Base class for all merger resolving implementations. Use one of the derived classes
    that handle reading/writing data to the respective sources.
    """
    def __init__(self,
                 pluginPaths=[os.path.abspath('../hytra/plugins')],
                 numSplits=None,
                 verbose=False,
                 progressVisitor=None):
        """
        Set up the plugin manager and the merger resolver plugin.

        The graphs, detections and mergers are filled in later; `model` and `result`
        should be filled by the constructors of derived classes.

        :param pluginPaths: directories searched for plugins (the default is resolved
            relative to the working directory when this module is imported)
        :param numSplits: if set, the min-cost max-flow step runs split tracking with this
            many splits; otherwise `dpct.trackMaxFlow` runs on the whole graph
        :param verbose: forwarded to the `TrackingPluginManager`
        :param progressVisitor: progress reporting visitor. If None, a new
            `DefaultProgressVisitor` is created per resolver instead of sharing a single
            default instance created at function definition time.
        """
        if progressVisitor is None:
            progressVisitor = DefaultProgressVisitor()

        self.unresolvedGraph = None
        self.resolvedGraph = None
        self.mergersPerTimestep = None
        self.detectionsPerTimestep = None
        self.pluginManager = TrackingPluginManager(verbose=verbose,
                                                   pluginPaths=pluginPaths)
        self.mergerResolverPlugin = self.pluginManager.getMergerResolver()
        self.numSplits = numSplits

        # should be filled by constructors of derived classes!
        self.model = None
        self.result = None
        self.progressVisitor = progressVisitor

    def _createUnresolvedGraph(self,
                               divisionsPerTimestep,
                               mergersPerTimestep,
                               mergerLinks,
                               withFullGraph=False):
        """
        Set up a networkx graph consisting of mergers that need to be resolved (not resolved yet!)
        and their direct neighbors. With `withFullGraph=True`, a copy of the complete
        `self.hypothesesGraph._graph` is used instead.

        Node attributes:
        - 'division': whether the object id is listed in `divisionsPerTimestep` of the next timestep.
          This is False in the last timestep that has division information, and for all nodes
          if `divisionsPerTimestep` is None or empty.
        - 'count': number of merged objects from `mergersPerTimestep`. In merger-only mode it is 1
          for non-mergers. In full-graph mode it is only set on merger nodes touched by a merger link.

        :param divisionsPerTimestep: dict str(timestep) -> collection of object ids, or None
        :param mergersPerTimestep: dict str(timestep) -> {object id: count}
        :param mergerLinks: iterable of (timestep, link) where link[0] is an object id in
            timestep - 1 and link[1] one in timestep
        :param withFullGraph: recompute the full graph instead of the merger neighborhood only

        ** returns ** the `unresolvedGraph`
        """

        self.unresolvedGraph = nx.DiGraph()

        def source(timestep, link):
            return int(timestep) - 1, link[0]

        def target(timestep, link):
            return int(timestep), link[1]

        # Last timestep with division information, computed once. It is only computed if
        # divisions are given: `max` over the keys of None or an empty dict would fail.
        if divisionsPerTimestep:
            lastframe = int(max(divisionsPerTimestep.keys(), key=int))
        else:
            lastframe = None

        def hasDivision(timestep, idx):
            ''' whether object `idx` of `timestep` is listed as dividing in the next timestep '''
            if lastframe is None or int(timestep) >= lastframe:
                # +1 below needs a following frame with division information
                return False
            return idx in divisionsPerTimestep[str(int(timestep) + 1)]

        # Recompute full graph
        if withFullGraph:
            self.unresolvedGraph = self.hypothesesGraph._graph.copy()

            # Add division parameter to nodes
            # TODO: Add the division parameter only to nodes that contain divisions (we're already doing these with 'count')
            for node in self.unresolvedGraph.nodes_iter():
                timestep, idx = node
                self.unresolvedGraph.node[node]['division'] = hasDivision(timestep, idx)

            # Add count parameter to nodes
            for t, link in mergerLinks:
                for node in [source(t, link), target(t, link)]:
                    timestep, idx = node
                    if idx in mergersPerTimestep[str(timestep)]:
                        count = mergersPerTimestep[str(timestep)][idx]
                        self.unresolvedGraph.node[node]['count'] = count

        # Recompute graph only with merger nodes and neighbors
        else:

            def addNode(node):
                ''' add a node to the unresolved graph and fill in the properties `division` and `count` '''
                intT, idx = node
                division = hasDivision(intT, idx)
                count = 1
                if idx in mergersPerTimestep[str(intT)]:
                    # a merger is not expected to divide at the same time
                    assert (not division)
                    count = mergersPerTimestep[str(intT)][idx]
                self.unresolvedGraph.add_node(node,
                                              division=division,
                                              count=count)

            # add nodes
            for t, link in mergerLinks:
                for n in [source(t, link), target(t, link)]:
                    if not self.unresolvedGraph.has_node(n):
                        addNode(n)
                self.unresolvedGraph.add_edge(source(t, link), target(t, link))

        return self.unresolvedGraph

    def _prepareResolvedGraph(self):
        """
        Initialize `self.resolvedGraph` as a copy of `self.unresolvedGraph`.

        The copy is the graph that `_fitAndRefineNodes` later modifies (splitting merger
        nodes), so the unresolved graph stays untouched.

        ** returns ** the `resolvedGraph`
        """
        resolved = self.unresolvedGraph.copy()
        self.resolvedGraph = resolved
        return resolved

    def _readLabelImage(self, timeframe):
        '''
        Return the label image of `timeframe`.

        Abstract hook: this base class has no data source, so derived classes must
        override it. `_fitAndRefineNodes` calls it once per timestep.
        '''
        raise NotImplementedError

    def _fitAndRefineNodes(self, detectionsPerTimestep, mergersPerTimestep,
                           timesteps):
        '''
        Update segmentation of mergers (nodes in unresolvedGraph) from first timeframe to last
        and create new nodes in `resolvedGraph`. Links to merger nodes are duplicated to all new nodes.

        Timesteps are visited in increasing numeric order. For every detection still present in
        `resolvedGraph`, the mergerResolver plugin fits `count` objects (1 for non-mergers) and
        updates the segmentation in the label image. The fits stored on the node's predecessors
        in `unresolvedGraph` are used as initialization.

        A node with count > 1 is replaced in `resolvedGraph` by `count` new nodes. Their object ids
        start right after the largest label of that frame. Each new node gets the merger's
        outgoing links and its incoming links. An incoming link from an already split
        predecessor is connected from all of that predecessor's new ids. The node of
        `unresolvedGraph` remembers the new ids under 'newIds'. Every processed node stores its
        fitted objects under 'fits'.
        '''
        for intT in sorted(int(timestep) for timestep in timesteps):
            t = str(intT)
            # use image provider plugin to load labelimage
            labelImage = self._readLabelImage(intT)
            nextObjectId = labelImage.max() + 1

            for idx in detectionsPerTimestep[t]:
                node = (intT, idx)
                if node not in self.resolvedGraph:
                    continue

                count = mergersPerTimestep[t][idx] if idx in mergersPerTimestep[t] else 1
                getLogger().debug(
                    "Looking at node {} in timestep {} with count {}".format(
                        idx, t, count))

                # the fits of all predecessors serve as initializations
                initializations = []
                for predecessor, _ in self.unresolvedGraph.in_edges(node):
                    initializations += self.unresolvedGraph.node[predecessor]['fits']
                # TODO: what shall we do if e.g. a 2-merger and a single object merge to 2 + 1,
                # so there are 3 initializations for the 2-merger, and two initializations for the 1 merger?
                # What does pgmlink do in that case?

                # the plugin fits `count` objects and also updates the labelimage
                fittedObjects = list(
                    self.mergerResolverPlugin.resolveMerger(
                        labelImage, idx, nextObjectId, count, initializations))
                assert (len(fittedObjects) == count)

                if count > 1:
                    # replace the merger by one node per fitted object, duplicating its arcs
                    for newObjectId in range(nextObjectId, nextObjectId + count):
                        newNode = (intT, newObjectId)
                        self.resolvedGraph.add_node(newNode,
                                                    division=False,
                                                    count=1,
                                                    origin=node)

                        for _, successor in self.unresolvedGraph.out_edges(node):
                            self.resolvedGraph.add_edge(newNode, successor)
                        for predecessor, _ in self.unresolvedGraph.in_edges(node):
                            predecessorData = self.unresolvedGraph.node[predecessor]
                            if 'newIds' in predecessorData:
                                # predecessor was split before: link from each of its parts
                                for predecessorNewId in predecessorData['newIds']:
                                    self.resolvedGraph.add_edge(
                                        (predecessor[0], predecessorNewId), newNode)
                            else:
                                self.resolvedGraph.add_edge(predecessor, newNode)

                    self.resolvedGraph.remove_node(node)
                    self.unresolvedGraph.node[node]['newIds'] = range(
                        nextObjectId, nextObjectId + count)
                    nextObjectId += count

                # each unresolved node stores its fitted shape(s) to be used
                # as initialization in the next frame, this way division duplicates
                # and de-merged nodes in the resolved graph do not need to store a fit as well
                self.unresolvedGraph.node[node]['fits'] = fittedObjects

    def _minCostMaxFlowMergerResolving(self,
                                       objectFeatures,
                                       transitionClassifier=None,
                                       transitionParameter=5.0):
        """
        Find the optimal assignments within the `resolvedGraph` by running min-cost max-flow from the
        `dpct` module.

        Converts the `resolvedGraph` to our JSON model structure, predicts the transition probabilities
        either using the given transitionClassifier, or using distance-based probabilities.

        :param objectFeatures: dict mapping each node of `resolvedGraph` to its feature dict
            (needs 'RegionCenter' if no `transitionClassifier` is given)
        :param transitionClassifier: optional classifier providing `selectedFeatures` and
            `predictProbabilities`
        :param transitionParameter: length scale of the distance-based transition probability
            `exp(-distance / transitionParameter)`

        **returns** a `nodeFlowMap` and `arcFlowMap` holding information on the usage of the respective nodes and links

        **Note:** cannot use `networkx` flow methods because they don't work with floating point weights.
        """

        trackingGraph = JsonTrackingGraph(progressVisitor=self.progressVisitor)
        for node in self.resolvedGraph.nodes_iter():
            additionalFeatures = {}
            additionalFeatures['nid'] = node

            # by default a node can carry 0 or 1 unit of flow
            numStates = 2

            if self.resolvedGraph.in_degree(node) == 0:
                # division nodes with no incoming arcs offer 2 units of flow without the need to de-merge
                # (`has_node` instead of scanning the `nodes()` list for every node)
                if self.unresolvedGraph.has_node(node) \
                        and self.unresolvedGraph.node[node]['division'] \
                        and self.unresolvedGraph.out_degree(node) == 2:
                    numStates = 3
                additionalFeatures['appearanceFeatures'] = [
                    [i**2 * 0.01] for i in range(numStates)
                ]
            if self.resolvedGraph.out_degree(node) == 0:
                assert (
                    numStates == 2
                )  # division nodes with no incoming should have outgoing, or they shouldn't show up in resolved graph
                additionalFeatures['disappearanceFeatures'] = [
                    [i**2 * 0.01] for i in range(numStates)
                ]

            features = [[i**2] for i in range(numStates)]
            uuid = trackingGraph.addDetectionHypotheses(
                features, **additionalFeatures)
            self.resolvedGraph.node[node]['id'] = uuid

        for srcNode, destNode in self.resolvedGraph.edges_iter():
            src = self.resolvedGraph.node[srcNode]['id']
            dest = self.resolvedGraph.node[destNode]['id']

            featuresAtSrc = objectFeatures[srcNode]
            featuresAtDest = objectFeatures[destNode]

            if transitionClassifier is not None:
                try:
                    featVec = self.pluginManager.applyTransitionFeatureVectorConstructionPlugins(
                        featuresAtSrc, featuresAtDest,
                        transitionClassifier.selectedFeatures)
                except Exception:
                    getLogger().error(
                        "Could not compute transition features of link {}->{}:"
                        .format(src, dest))
                    getLogger().error(featuresAtSrc)
                    getLogger().error(featuresAtDest)
                    raise
                featVec = np.expand_dims(np.array(featVec), axis=0)
                probs = transitionClassifier.predictProbabilities(featVec)[0]
            else:
                # the transition probability decays exponentially with the region center distance
                dist = np.linalg.norm(featuresAtDest['RegionCenter'] -
                                      featuresAtSrc['RegionCenter'])
                prob = np.exp(-dist / transitionParameter)
                probs = [1.0 - prob, prob]

            trackingGraph.addLinkingHypotheses(src, dest,
                                               listify(negLog(probs)))

        # Set TraxelToUniqueId on resolvedGraph's json graph: str(timestep) -> str(object id) -> uuid
        traxelIdPerTimestepToUniqueIdMap = {}
        for node in self.resolvedGraph.nodes_iter():
            uuid = self.resolvedGraph.node[node]['id']
            traxelIdPerTimestepToUniqueIdMap.setdefault(str(node[0]), {})[str(
                node[1])] = uuid

        trackingGraph.setTraxelToUniqueId(traxelIdPerTimestepToUniqueIdMap)

        # track
        import dpct

        weights = {"weights": [1, 1, 1, 1]}

        if not self.numSplits:
            mergerResult = dpct.trackMaxFlow(trackingGraph.model, weights)
        else:
            getLogger().info("Running split tracking with {} splits.".format(
                self.numSplits))
            mergerResult = SplitTracking.trackFlowBasedWithSplits(
                trackingGraph.model,
                weights,
                numSplits=self.numSplits,
                withMergerResolver=True)

        # transform results to dictionaries that can be indexed by id or (src,dest)
        nodeFlowMap = {int(d['id']): int(d['value'])
                       for d in mergerResult['detectionResults']}
        arcFlowMap = {(int(l['src']), int(l['dest'])): int(l['value'])
                      for l in mergerResult['linkingResults']}

        return nodeFlowMap, arcFlowMap

    def _refineModel(self, uuidToTraxelMap, traxelIdPerTimestepToUniqueIdMap,
                     mergerNodeFilter, mergerLinkFilter):
        """
        Replace the mergers in `self.model` (JSON format) by their de-merged nodes and links.

        Detections and links for which `mergerNodeFilter` / `mergerLinkFilter` return False are
        dropped. Every node of `unresolvedGraph` with 'count' > 1 then gets one new detection per
        id in its 'newIds'. The new detections use fresh UUIDs, starting above the largest key of
        `uuidToTraxelMap`. `traxelIdPerTimestepToUniqueIdMap` is updated in place: the merger's
        traxel id is removed and the new ids are added, so that traxel IDs match the new connected
        component IDs in the refined images. Finally a link is appended for every edge of
        `resolvedGraph`.

        **Returns** the updated `model` dictionary, which is the same as the input `model` (works in-place)
        """
        model = self.model

        # drop merger detections and merger links
        model['segmentationHypotheses'] = list(
            filter(mergerNodeFilter, model['segmentationHypotheses']))
        model['linkingHypotheses'] = list(
            filter(mergerLinkFilter, model['linkingHypotheses']))

        # one new detection per de-merged object
        nextUuid = max(uuidToTraxelMap) + 1
        for node in self.unresolvedGraph.nodes_iter():
            nodeData = self.unresolvedGraph.node[node]
            if not (nodeData.get('count', 0) > 1):
                continue
            timestepKey = str(node[0])
            del traxelIdPerTimestepToUniqueIdMap[timestepKey][str(node[1])]
            for newId in nodeData['newIds']:
                model['segmentationHypotheses'].append(
                    {'id': nextUuid, 'timestep': [node[0], node[0]]})
                traxelIdPerTimestepToUniqueIdMap[timestepKey][str(newId)] = nextUuid
                nextUuid += 1

        # links between the refined nodes
        for src, dest in self.resolvedGraph.edges_iter():
            model['linkingHypotheses'].append({
                'src': traxelIdPerTimestepToUniqueIdMap[str(src[0])][str(src[1])],
                'dest': traxelIdPerTimestepToUniqueIdMap[str(dest[0])][str(dest[1])],
            })

        return model

    def _refineResult(self, nodeFlowMap, arcFlowMap,
                      traxelIdPerTimestepToUniqueIdMap, mergerNodeFilter,
                      mergerLinkFilter):
        """
        Update the `self.result` dict by removing the mergers and adding the refined nodes and links.

        Operates on a `result` dictionary in our JSON result style with mergers,
        the resolved and unresolved graph as well as
        the `nodeFlowMap` and `arcFlowMap` obtained by running tracking on the `resolvedGraph`.

        Updates the `result` dictionary so that all merger nodes are removed but the new nodes
        are contained with the appropriate links and values.

        `mergerNodeFilter` and `mergerLinkFilter` are methods that can filter merger detections
        and links from the respective lists in the `result` dict.

        **Returns** the updated `result` dict, which is the same as the input `result` (works in-place)
        """

        # filter merger edges
        self.result['detectionResults'] = [
            r for r in self.result['detectionResults'] if mergerNodeFilter(r)
        ]
        self.result['linkingResults'] = [
            r for r in self.result['linkingResults'] if mergerLinkFilter(r)
        ]

        # add new nodes
        for node in self.unresolvedGraph.nodes_iter():
            if 'count' in self.unresolvedGraph.node[
                    node] and self.unresolvedGraph.node[node]['count'] > 1:
                newIds = self.unresolvedGraph.node[node]['newIds']
                for newId in newIds:
                    uuid = traxelIdPerTimestepToUniqueIdMap[str(
                        node[0])][str(newId)]
                    resolvedNode = (node[0], newId)
                    resolvedResultId = self.resolvedGraph.node[resolvedNode][
                        'id']
                    newDetection = {
                        'id': uuid,
                        'value': nodeFlowMap[resolvedResultId]
                    }
                    self.result['detectionResults'].append(newDetection)

        # add new links
        for edge in self.resolvedGraph.edges_iter():
            newLink = {}
            newLink['src'] = traxelIdPerTimestepToUniqueIdMap[str(
                edge[0][0])][str(edge[0][1])]
            newLink['dest'] = traxelIdPerTimestepToUniqueIdMap[str(
                edge[1][0])][str(edge[1][1])]
            srcId = self.resolvedGraph.node[edge[0]]['id']
            destId = self.resolvedGraph.node[edge[1]]['id']
            newLink['value'] = arcFlowMap[(srcId, destId)]
            self.result['linkingResults'].append(newLink)

        return self.result

    def _exportRefinedSegmentation(self, timesteps):
        """
        Store the resulting label images, if needed.

        `labelImages` is a dictionary with str(timestep) as keys.
        """
        pass

    def _computeObjectFeatures(self, timesteps):
        '''
        Return the features per object as nested dictionaries:
        { (int(Timestep), int(Id)):{ "FeatureName" : np.array(value), "NextFeature": ...} }
        '''
        pass

    # ------------------------------------------------------------
    def run(self,
            transition_classifier_filename=None,
            transition_classifier_path=None):
        """
        Run merger resolving

        1. find mergers in the given model and result
        2. build graph of the unresolved (merger) nodes and their direct neighbors
        3. use a mergerResolving plugin to refine the merger nodes and their segmentation
        4. run min-cost max-flow tracking to find the fate of all the de-merged objects
        5. export refined segmentation, update member variables `model` and `result`

        **Parameters**

        * `transition_classifier_filename`: file of a transition classifier, handed to
          `probabilitygenerator.RandomForestClassifier`. If `None`, distance based transition
          energies are used instead.
        * `transition_classifier_path`: location of the classifier, passed along with the filename.

        **Returns** a nested dictionary, indexed first by time, then object Id, containing a list of
        new segmentIDs per merger. The dictionary is empty when there are no mergers to resolve.
        """

        traxelIdPerTimestepToUniqueIdMap, uuidToTraxelMap = hytra.core.jsongraph.getMappingsBetweenUUIDsAndTraxels(
            self.model)

        # There might be empty frames. We want them as output too, so cover the whole range
        # between the first and the last frame. The map is keyed by *strings*: convert to int
        # before taking min/max. Otherwise '10' sorts before '9' and leading frames get dropped.
        frameIndices = [int(t) for t in traxelIdPerTimestepToUniqueIdMap.keys()]
        if frameIndices:
            timesteps = [
                str(t) for t in range(min(frameIndices), max(frameIndices) + 1)
            ]
        else:
            timesteps = []

        mergers, detections, links, divisions = hytra.core.jsongraph.getMergersDetectionsLinksDivisions(
            self.result, uuidToTraxelMap)

        # ------------------------------------------------------------

        # it may be, that there are no mergers, so do basically nothing, just copy all the ingoing data
        if not mergers:
            getLogger().info(
                "The maximum number of objects is 1, so nothing to be done. Writing the output..."
            )
            self._exportRefinedSegmentation(timesteps)
            return {}

        self.mergersPerTimestep = hytra.core.jsongraph.getMergersPerTimestep(
            mergers, timesteps)
        self.detectionsPerTimestep = hytra.core.jsongraph.getDetectionsPerTimestep(
            detections, timesteps)

        linksPerTimestep = hytra.core.jsongraph.getLinksPerTimestep(
            links, timesteps)
        divisionsPerTimestep = hytra.core.jsongraph.getDivisionsPerTimestep(
            divisions, linksPerTimestep, timesteps)
        mergerLinks = hytra.core.jsongraph.getMergerLinks(
            linksPerTimestep, self.mergersPerTimestep, timesteps)

        # set up unresolved graph and then refine the nodes to get the resolved graph
        self._createUnresolvedGraph(divisionsPerTimestep,
                                    self.mergersPerTimestep, mergerLinks)
        self._prepareResolvedGraph()
        self._fitAndRefineNodes(self.detectionsPerTimestep,
                                self.mergersPerTimestep, timesteps)

        # ------------------------------------------------------------
        # compute new object features
        objectFeatures = self._computeObjectFeatures(timesteps)

        # ------------------------------------------------------------
        # load transition classifier if any
        if transition_classifier_filename is not None:
            getLogger().info("\tLoading transition classifier")
            transitionClassifier = probabilitygenerator.RandomForestClassifier(
                transition_classifier_path, transition_classifier_filename)
        else:
            getLogger().info("\tUsing distance based transition energies")
            transitionClassifier = None

        # ------------------------------------------------------------
        # run min-cost max-flow to find merger assignments
        getLogger().info(
            "Running min-cost max-flow to find resolved merger assignments"
        )

        nodeFlowMap, arcFlowMap = self._minCostMaxFlowMergerResolving(
            objectFeatures, transitionClassifier)

        # ------------------------------------------------------------
        # fuse results into a new solution

        # 1.) replace merger nodes in JSON graph by their replacements -> new JSON graph
        #     update UUID to traxel map.
        #     a) how do we deal with the smaller number of states?
        #        Does it matter as we're done with tracking anyway..?

        def isMergerTraxel(traxel):
            # traxel is a (timestep, traxelId) pair
            return traxel[1] in self.mergersPerTimestep[str(traxel[0])]

        def mergerNodeFilter(jsonNode):
            # keep a detection only if none of its traxels is a merger
            traxels = uuidToTraxelMap[int(jsonNode['id'])]
            return not any(isMergerTraxel(t) for t in traxels)

        def mergerLinkFilter(jsonLink):
            srcTraxels = uuidToTraxelMap[int(jsonLink['src'])]
            destTraxels = uuidToTraxelMap[int(jsonLink['dest'])]

            # return True if there was no traxel in either source or target node that was a merger.
            return not (any(isMergerTraxel(t) for t in srcTraxels)
                        or any(isMergerTraxel(t) for t in destTraxels))

        self.model = self._refineModel(uuidToTraxelMap,
                                       traxelIdPerTimestepToUniqueIdMap,
                                       mergerNodeFilter, mergerLinkFilter)

        # 2.) new result = union(old result, resolved mergers) - old mergers

        self.result = self._refineResult(nodeFlowMap, arcFlowMap,
                                         traxelIdPerTimestepToUniqueIdMap,
                                         mergerNodeFilter,
                                         mergerLinkFilter)

        # 3.) export refined segmentation
        self._exportRefinedSegmentation(timesteps)

        # return a dictionary telling about which mergers were resolved into what
        mergerDict = {}
        for n in self.unresolvedGraph.nodes_iter():
            newIds = self.unresolvedGraph.node[n].get('newIds', [])
            # skip non-mergers
            if len(newIds) < 2:
                continue
            mergerDict.setdefault(n[0], {})[n[1]] = newIds

        return mergerDict

    def relabelMergers(self, labelImage, time):
        """
        Relabel the merger objects of frame `time` in `labelImage` using previously found fits.

        For every detection of that frame that is listed as a merger, the `fits` and `newIds`
        stored on the corresponding `(time, objectId)` node of `self.unresolvedGraph` are handed to
        `self.mergerResolverPlugin.updateLabelImage`. That call's return value is ignored, so the
        plugin presumably updates `labelImage` in place.

        **Returns** `labelImage`, the same object that was passed in. It is returned without any
        plugin call when no detections are known for this frame.
        """
        frameKey = str(time)
        detectionsPerTimestep = self.detectionsPerTimestep
        if detectionsPerTimestep is None or frameKey not in detectionsPerTimestep:
            return labelImage

        for objectId in detectionsPerTimestep[frameKey]:
            # only mergers carry fits; every other detection keeps its label
            if objectId in self.mergersPerTimestep[frameKey]:
                nodeData = self.unresolvedGraph.node[(time, objectId)]
                self.mergerResolverPlugin.updateLabelImage(
                    labelImage, objectId, nodeData['fits'], nodeData['newIds'])

        return labelImage
def writeEvents(timestep, activeLinks, activeDivisions, mergers, detections, fn, labelImagePath, ilpFilename, verbose, pluginPaths):
    """
    Write the tracking events of a single frame, together with that frame's label image, to the HDF5 file `fn`.

    **Parameters**

    * `timestep`: frame whose label image is copied into the output
    * `activeLinks`: sequence of `(previous label, current label)` pairs, stored as "Moves"
    * `activeDivisions`: dict mapping a parent label to its two children, stored as "Splits"
    * `mergers`: dict mapping a label to the number of objects it contains, stored as "Mergers"
    * `detections`: currently unused
    * `fn`: output HDF5 filename (opened with mode 'w', so an existing file is overwritten)
    * `labelImagePath`, `ilpFilename`: location of the label image, passed to the image provider plugin
    * `verbose`, `pluginPaths`: forwarded to the `TrackingPluginManager`

    Appearances, disappearances and multi-frame moves are not computed here, so those datasets are
    never written. Any error while loading the label image or writing the file is logged as a
    warning (including the traceback) and not re-raised.
    """
    logger = logging.getLogger('json_result_to_events.py')
    pluginManager = TrackingPluginManager(verbose=verbose, pluginPaths=pluginPaths)

    logger.debug("-- Writing results to {}".format(fn))
    try:
        # convert to ndarray for better indexing.
        # `.items()` works on Python 2 and 3 alike, unlike the Python-2-only `.iteritems()`
        dis = np.asarray([])
        app = np.asarray([])
        div = np.asarray([[k, v[0], v[1]] for k, v in activeDivisions.items()])
        mov = np.asarray(activeLinks)
        mer = np.asarray([[k, v] for k, v in mergers.items()])
        mul = np.asarray([])

        label_img = pluginManager.getImageProvider().getLabelImageForFrame(ilpFilename, labelImagePath, timestep)

        with h5py.File(fn, 'w') as dest_file:
            # write meta fields and copy segmentation from project
            seg = dest_file.create_group('segmentation')
            seg.create_dataset("labels", data=label_img, compression='gzip')
            meta = dest_file.create_group('objects/meta')
            ids = np.unique(label_img)
            ids = ids[ids > 0]  # label 0 is background
            valid = np.ones(ids.shape)
            meta.create_dataset("id", data=ids, dtype=np.uint32)
            meta.create_dataset("valid", data=valid, dtype=np.uint32)

            tg = dest_file.create_group("tracking")

            # write associations; an empty event list produces no dataset at all
            events = [
                ("Appearances", app, "cell label appeared in current file"),
                ("Disappearances", dis, "cell label disappeared in current file"),
                ("Moves", mov, "from (previous file), to (current file)"),
                ("Splits", div, "ancestor (previous file), descendant (current file), descendant (current file)"),
                ("Mergers", mer, "descendant (current file), number of objects"),
                ("MultiFrameMoves", mul, "from (given by timestep), to (current file), timestep"),
            ]
            for name, data, formatDescription in events:
                if len(data) > 0:
                    ds = tg.create_dataset(name, data=data, dtype=np.int32)
                    ds.attrs["Format"] = formatDescription

        logger.debug("-> results successfully written")
    except Exception as e:
        # deliberate best-effort: report (with traceback) instead of aborting the caller
        logger.warning("ERROR while writing events: {}".format(str(e)), exc_info=True)
# Example #22
# 0
def computeRegionFeaturesOnCloud(
        frame,
        rawImageFilename,
        rawImagePath,
        rawImageAxes,
        labelImageFilename,
        labelImagePath,
        turnOffFeatures,
        pluginPaths=['hytra/plugins'],
        featuresPerFrame=None,
        imageProviderPluginName='LocalImageLoader',
        featureSerializerPluginName='LocalFeatureSerializer'):
    '''
    Allow to use dispy to schedule feature computation to nodes running a dispynode,
    or to use multiprocessing.

    **Parameters**

    * `frame`: the frame number
    * `rawImageFilename`: the base filename of the raw image volume, or a dvid server address
    * `rawImagePath`: path inside the raw image HDF5 file, or DVID dataset UUID
    * `rawImageAxes`: axes configuration of the raw data
    * `labelImageFilename`: the base filename of the label image volume, or a dvid server address
    * `labelImagePath`: path inside the label image HDF5 file, or DVID dataset UUID
    * `turnOffFeatures`: forwarded to the `TrackingPluginManager` (presumably names of features
      that should not be computed -- confirm against the plugin manager)
    * `pluginPaths`: where all yapsy plugins are stored (should be absolute for DVID)
    * `featuresPerFrame`: feature dictionary handed to the feature serializer as `features_per_frame`
      (used by the local serializer)
    * `imageProviderPluginName`: name of the image provider plugin used to load raw and label images
    * `featureSerializerPluginName`: name of the feature serializer plugin

    **returns** the tuple `(frame, featureDict)` if `featureSerializerPluginName == 'LocalFeatureSerializer'`
    and `featuresPerFrame is None`. Otherwise the features are handed to the feature serializer and
    `None` is returned.
    '''

    # set up plugin manager
    from hytra.pluginsystem.plugin_manager import TrackingPluginManager
    pluginManager = TrackingPluginManager(pluginPaths=pluginPaths,
                                          turnOffFeatures=turnOffFeatures,
                                          verbose=False)
    pluginManager.setImageProvider(imageProviderPluginName)
    pluginManager.setFeatureSerializer(featureSerializerPluginName)

    # load raw and label image (depending on chosen plugin this works via DVID or locally)
    rawImage = pluginManager.getImageProvider().getImageDataAtTimeFrame(
        rawImageFilename, rawImagePath, rawImageAxes, frame)
    labelImage = pluginManager.getImageProvider().getLabelImageForFrame(
        labelImageFilename, labelImagePath, frame)

    # untwist axes, if just x and y are messed up
    # NOTE(review): when the first two axes have equal length this condition is always true,
    # so the label image is transposed even if its axes already match the raw data. Also,
    # `axes=[1, 0]` only accepts 2D label images. Confirm that this is intended.
    if rawImage.shape[0] == labelImage.shape[1] and rawImage.shape[
            1] == labelImage.shape[0]:
        labelImage = np.transpose(labelImage, axes=[1, 0])

    # compute features
    moreFeats, ignoreNames = pluginManager.applyObjectFeatureComputationPlugins(
        len(labelImage.shape), rawImage, labelImage, frame, rawImageFilename)

    # combine into one dictionary
    # WARNING: if there are multiple features with the same name, they will be overwritten
    # (the last plugin result wins)!
    # `dict.update` works on Python 2 and 3; `list + dict.items()` raises TypeError on Python 3.
    frameFeatures = {}
    for f in moreFeats:
        frameFeatures.update(f)

    # delete all ignored features
    for k in ignoreNames:
        frameFeatures.pop(k, None)

    # return or save features.
    # Compare the plugin name by value: `is` only matches interned strings, so a name built at
    # runtime (e.g. parsed from a config or the command line) would take the wrong branch.
    if featuresPerFrame is None and featureSerializerPluginName == 'LocalFeatureSerializer':
        # simply return resulting dict
        return frame, frameFeatures
    else:
        # set up feature serializer (local or DVID for now)
        featureSerializer = pluginManager.getFeatureSerializer()

        # server address and uuid are only used by the DVID serializer
        featureSerializer.server_address = labelImageFilename
        featureSerializer.uuid = labelImagePath

        # feature dictionary used by local serializer
        featureSerializer.features_per_frame = featuresPerFrame

        # store
        featureSerializer.storeFeaturesForFrame(frameFeatures, frame)