def findConflictingHypothesesInSeparateProcess(
        frame,
        labelImageFilenames,
        labelImagePaths,
        labelImageFrameIdToGlobalId,
        pluginPaths=['hytra/plugins'],
        imageProviderPluginName='LocalImageLoader'):
    """
    Find the objects of different segmentation hypotheses that overlap in one frame.

    Every pair of label images (one per hypothesis) is compared at ``frame``; two
    objects conflict as soon as they share a single pixel. Object ids are translated
    to global ids via ``labelImageFrameIdToGlobalId[(filename, frame, objectId)]``.

    Meant to be run in its own process using `concurrent.futures.ProcessPoolExecutor`

    **Returns** ``(frame, overlaps)`` where ``overlaps`` maps a globalId to the list of
    globalIds it overlaps with. Objects of all label images except the last one get an
    entry even if they overlap nothing (an empty list).
    """

    # the plugin manager is created here because this runs in a separate process
    from hytra.pluginsystem.plugin_manager import TrackingPluginManager
    manager = TrackingPluginManager(pluginPaths=pluginPaths, verbose=False)
    manager.setImageProvider(imageProviderPluginName)

    def loadLabelImage(index):
        # label image of hypothesis `index` at the current frame
        return manager.getImageProvider().getLabelImageForFrame(
            labelImageFilenames[index], labelImagePaths[index], frame)

    def toGlobalId(index, objectId):
        return labelImageFrameIdToGlobalId[(labelImageFilenames[index], frame, objectId)]

    overlaps = {}  # globalId -> list of overlapping globalIds
    numImages = len(labelImageFilenames)

    for indexA in range(numImages):
        imageA = loadLabelImage(indexA)
        for indexB in range(indexA + 1, numImages):
            imageB = loadLabelImage(indexB)
            for objectIdA in np.unique(imageA):
                if objectIdA == 0:  # background
                    continue
                # a single shared pixel is enough to make two objects mutually exclusive
                touchedIdsB = set(np.unique(imageB[imageA == objectIdA])) - set([0])
                globalIdsB = [toGlobalId(indexB, o) for o in touchedIdsB]
                globalIdA = toGlobalId(indexA, objectIdA)
                overlaps.setdefault(globalIdA, []).extend(globalIdsB)
                for globalIdB in globalIdsB:
                    overlaps.setdefault(globalIdB, []).append(globalIdA)

    return frame, overlaps
# --- Example #2 (0) ---
    def __init__(self,
                 pluginPaths=[os.path.abspath('../hytra/plugins')],
                 verbose=False):
        """
        Create the plugin manager and fetch its merger resolver plugin.

        Graphs, per-timestep bookkeeping, ``model`` and ``result`` all start as None.

        :param pluginPaths: plugin directories handed to ``TrackingPluginManager``. The
            default path is resolved against the current working directory at the time
            this ``def`` statement executes.
        :param verbose: verbosity flag handed to ``TrackingPluginManager``
        """
        # graphs and per-timestep information, filled in later
        self.unresolvedGraph = self.resolvedGraph = None
        self.mergersPerTimestep = self.detectionsPerTimestep = None

        self.pluginManager = TrackingPluginManager(verbose=verbose, pluginPaths=pluginPaths)
        self.mergerResolverPlugin = self.pluginManager.getMergerResolver()

        # should be filled by constructors of derived classes!
        self.model = self.result = None
    def __init__(self, parent=None, graph=None):
        """
        Create the internal caches, default provider and plugins of the operator.

        Wires three ``OpBlockedArrayCache`` instances (tracking output, merger output
        and relabeled image) and an ``OpZeroDefault`` whose meta input follows
        ``LabelImage``. Registers constraint checks that run once ``RawImage`` /
        ``LabelImage`` become ready. Also loads the hytra merger resolver plugin and
        resets the progress reporting state.

        :param parent: parent operator, forwarded to the base class constructor
        :param graph: graph, forwarded to the base class constructor
        """
        super(OpConservationTracking, self).__init__(parent=parent,
                                                     graph=graph)

        # cache for the tracking output
        self._opCache = OpBlockedArrayCache(parent=self)
        self._opCache.name = "OpConservationTracking._opCache"
        self._opCache.Input.connect(self.Output)
        self.CleanBlocks.connect(self._opCache.CleanBlocks)
        self.CachedOutput.connect(self._opCache.Output)

        self.zeroProvider = OpZeroDefault(parent=self)
        self.zeroProvider.MetaInput.connect(self.LabelImage)

        # As soon as input data is available, check its constraints
        self.RawImage.notifyReady(self._checkConstraints)
        self.LabelImage.notifyReady(self._checkConstraints)

        self.ExportSettings.setValue((None, None))

        # cache for the merger output
        self._mergerOpCache = OpBlockedArrayCache(parent=self)
        self._mergerOpCache.name = "OpConservationTracking._mergerOpCache"
        self._mergerOpCache.Input.connect(self.MergerOutput)
        self.MergerCleanBlocks.connect(self._mergerOpCache.CleanBlocks)
        self.MergerCachedOutput.connect(self._mergerOpCache.Output)

        # cache for the relabeled image. Its name used to be a copy of the merger
        # cache's name, which made the two caches indistinguishable by name.
        self._relabeledOpCache = OpBlockedArrayCache(parent=self)
        self._relabeledOpCache.name = "OpConservationTracking._relabeledOpCache"
        self._relabeledOpCache.Input.connect(self.RelabeledImage)
        self.RelabeledCleanBlocks.connect(self._relabeledOpCache.CleanBlocks)
        self.RelabeledCachedOutput.connect(self._relabeledOpCache.Output)

        # Merger resolver plugin manager (contains GMM fit routine);
        # plugins are looked up in the 'plugins' directory of the installed hytra package
        self.pluginPaths = [
            os.path.join(os.path.dirname(os.path.abspath(hytra.__file__)),
                         'plugins')
        ]
        pluginManager = TrackingPluginManager(verbose=False,
                                              pluginPaths=self.pluginPaths)
        self.mergerResolverPlugin = pluginManager.getMergerResolver()

        self.result = None

        # progress bar
        self.progressWindow = None
        self.progressVisitor = DefaultProgressVisitor()
# --- Example #4 (0) ---
    def __init__(self,
                 ilpOptions,
                 turnOffFeatures=None,
                 useMultiprocessing=True,
                 pluginPaths=None,
                 verbose=False):
        """
        Set up the plugin manager, load the classifiers and read shape and time range.

        :param ilpOptions: options object; its ``imageProviderName`` and
            ``featureSerializerName`` select the image provider and feature serializer plugins
        :param turnOffFeatures: forwarded to ``TrackingPluginManager`` as ``turnOffFeatures``
            (defaults to an empty list)
        :param useMultiprocessing: stored in ``self._useMultiprocessing``
        :param pluginPaths: plugin directories (defaults to ``['hytra/plugins']``)
        :param verbose: verbosity flag forwarded to ``TrackingPluginManager``
        """
        # Build fresh default lists per call. Mutable default arguments are shared
        # between calls, and pluginPaths is kept on the instance.
        if turnOffFeatures is None:
            turnOffFeatures = []
        if pluginPaths is None:
            pluginPaths = ['hytra/plugins']

        self._useMultiprocessing = useMultiprocessing
        self._options = ilpOptions
        self._pluginPaths = pluginPaths
        self._pluginManager = TrackingPluginManager(
            turnOffFeatures=turnOffFeatures,
            verbose=verbose,
            pluginPaths=pluginPaths)
        self._pluginManager.setImageProvider(ilpOptions.imageProviderName)
        self._pluginManager.setFeatureSerializer(
            ilpOptions.featureSerializerName)

        self._countClassifier = None
        self._divisionClassifier = None
        self._transitionClassifier = None

        self._loadClassifiers()

        self.shape, self.timeRange = self._getShapeAndTimeRange()

        # set default division feature names
        self._divisionFeatureNames = [
            'ParentChildrenRatio_Count', 'ParentChildrenRatio_Mean',
            'ChildrenRatio_Count', 'ChildrenRatio_Mean',
            'ParentChildrenAngle_RegionCenter',
            'ChildrenRatio_SquaredDistances'
        ]

        # other parameters that one might want to set
        self.x_scale = 1.0
        self.y_scale = 1.0
        self.z_scale = 1.0
        self.divisionProbabilityFeatureName = 'divProb'
        self.detectionProbabilityFeatureName = 'detProb'

        # public: contains all traxels if we're not using pgmlink
        self.TraxelsPerFrame = {}
    # Do the tracking
    start = time.time()

    feature_path = options.feats_path
    with_div = not bool(options.without_divisions)
    with_merger_prior = True

    # get selected time range
    time_range = [options.mints, options.maxts]
    if options.maxts == -1 and options.mints == 0:
        time_range = None

    try:
        # find shape of dataset
        pluginManager = TrackingPluginManager(verbose=options.verbose,
                                              pluginPaths=options.pluginPaths)
        shape = pluginManager.getImageProvider().getImageShape(
            ilp_fn, options.label_img_path)
        data_time_range = pluginManager.getImageProvider().getTimeRange(
            ilp_fn, options.label_img_path)

        if time_range is not None and time_range[1] < 0:
            time_range[1] += data_time_range[1]

    except:
        logging.warning("Could not read shape and time range from images")
        shape = None

    # set average object size if chosen
    obj_size = [0]
    if options.avg_obj_size != 0:
        rawimage = hytra.util.axesconversion.adjustOrder(rawimage, args.rawimage_axes[dataset], 'xyztc')
        logger.info('Done loading raw data from dataset {} of shape {}'.format(dataset, rawimage.shape))

        # find ground truth files
        # filepath is now a list of filepaths'
        filepath = args.filepath[dataset]
        # filepattern is now a list of filepatterns
        files = glob.glob(os.path.join(filepath, args.filepattern[dataset]))
        files.sort()
        initFrame = args.initFrame
        endFrame = args.endFrame
        if endFrame < 0:
            endFrame += len(files)
    
        # compute features
        trackingPluginManager = TrackingPluginManager(verbose=args.verbose, 
                                                      pluginPaths=args.pluginPaths)
        features = compute_features(rawimage,
                                    read_in_images(initFrame, endFrame, files, args.groundtruth_axes[dataset]),
                                    initFrame,
                                    endFrame,
                                    trackingPluginManager,
                                    rawimage_filename)
        logger.info('Done computing features from dataset {}'.format(dataset))

        selectedFeatures = find_features_without_NaNs(features)
        pos_labels = read_positiveLabels(initFrame, endFrame, files)
        neg_labels = negativeLabels(features, pos_labels)
        numSamples += 2 * sum([len(l) for l in pos_labels]) + sum([len(l) for l in neg_labels])
        logger.info('Done extracting {} samples'.format(numSamples))
        TC = TransitionClassifier(selectedFeatures, numSamples)
        if dataset > 0:
def computeJaccardScoresOnCloud(frame,
                                labelImageFilenames,
                                labelImagePaths,
                                labelImageFrameIdToGlobalId,
                                groundTruthFilename,
                                groundTruthPath,
                                groundTruthMinJaccardScore,
                                pluginPaths=['hytra/plugins'],
                                imageProviderPluginName='LocalImageLoader'):
    """
    Compute jaccard scores of all objects in the different segmentations with the ground truth for that frame.

    The Jaccard score of an object and a ground truth (GT) label is the number of
    pixels in their intersection divided by the number of pixels in their union.
    Only GT labels that overlap the object (label 0 excluded) are scored.

    **Returns** ``(frame, scores, gtToGlobalIdMap)`` where

    * ``scores`` maps a globalId to a list of ``(gtLabel, jaccardScore)`` tuples;
    * ``gtToGlobalIdMap`` maps ``(frame, gtLabel)`` to a list of ``(globalId, jaccardScore)``
      tuples, sorted by ascending score (best match last). An object is added only if its
      score exceeds ``groundTruthMinJaccardScore`` and beats the best score recorded so far
      for that GT label.

    Meant to be run in its own process using `concurrent.futures.ProcessPoolExecutor`
    """

    # set up plugin manager (inside the worker process)
    from hytra.pluginsystem.plugin_manager import TrackingPluginManager
    pluginManager = TrackingPluginManager(pluginPaths=pluginPaths,
                                          verbose=False)
    pluginManager.setImageProvider(imageProviderPluginName)

    scores = {}
    gtToGlobalIdMap = {}

    groundTruthLabelImage = pluginManager.getImageProvider(
    ).getLabelImageForFrame(groundTruthFilename, groundTruthPath, frame)

    for labelImageIndexA in range(len(labelImageFilenames)):
        labelImageA = pluginManager.getImageProvider().getLabelImageForFrame(
            labelImageFilenames[labelImageIndexA],
            labelImagePaths[labelImageIndexA], frame)
        for objectIdA in np.unique(labelImageA):
            if objectIdA == 0:  # background
                continue
            globalIdA = labelImageFrameIdToGlobalId[(
                labelImageFilenames[labelImageIndexA], frame, objectIdA)]
            # pixel mask of this object, reused for every GT label it overlaps
            objectMask = labelImageA == objectIdA
            overlap = groundTruthLabelImage[objectMask]
            overlappingGtElements = set(np.unique(overlap)) - set([0])

            for gtLabel in overlappingGtElements:
                # compute Jaccard score; the union is never empty because the
                # intersection contains at least one pixel
                intersectingPixels = np.sum(overlap == gtLabel)
                unionPixels = np.sum(
                    np.logical_or(groundTruthLabelImage == gtLabel, objectMask))
                jaccardScore = float(intersectingPixels) / float(unionPixels)

                # append to object's score list
                scores.setdefault(globalIdA, []).append(
                    (gtLabel, jaccardScore))

                # store this as GT mapping if there was no better object for this GT label yet
                gtKey = (frame, gtLabel)
                if jaccardScore > groundTruthMinJaccardScore and \
                        (gtKey not in gtToGlobalIdMap or gtToGlobalIdMap[gtKey][-1][1] < jaccardScore):
                    gtToGlobalIdMap.setdefault(gtKey, []).append(
                        (globalIdA, jaccardScore))

    # Sort all GT mappings by ascending Jaccard score (best match last).
    # values() works on Python 2 and 3 alike; the former iteritems() is Python-2-only.
    for matches in gtToGlobalIdMap.values():
        matches.sort(key=lambda match: match[1])

    return frame, scores, gtToGlobalIdMap
# --- Example #8 (0) ---
def writeEvents(timestep, activeLinks, activeDivisions, mergers, detections,
                fn, labelImagePath, ilpFilename, verbose, pluginPaths):
    """
    Store the events of one timestep, plus that timestep's label image, in an HDF5 file.

    The file ``fn`` receives the label image under ``segmentation/labels``, and the ids
    of all non-zero labels (each flagged valid) under ``objects/meta``. Under ``tracking``
    it gets one dataset per non-empty event type:

    * ``Moves`` from ``activeLinks``
    * ``Splits`` from ``activeDivisions`` (ancestor -> pair of descendants)
    * ``Mergers`` from ``mergers`` (descendant -> number of objects)

    ``Appearances``, ``Disappearances`` and ``MultiFrameMoves`` are always empty here and
    are therefore never written; ``detections`` is not used.

    Any exception is caught and reported as a warning on the
    ``json_result_to_events.py`` logger instead of being raised.
    """
    # these event types are never filled in by this function
    disappearances, appearances, multiFrameMoves = [], [], []

    pluginManager = TrackingPluginManager(verbose=verbose,
                                          pluginPaths=pluginPaths)

    log = logging.getLogger('json_result_to_events.py')
    log.debug("-- Writing results to {}".format(fn))
    try:
        # ndarrays are easier to index and to write
        disappearances = np.asarray(disappearances)
        appearances = np.asarray(appearances)
        splits = np.asarray([[ancestor, children[0], children[1]]
                             for ancestor, children in activeDivisions.items()])
        moves = np.asarray(activeLinks)
        mergerArray = np.asarray([[obj, count] for obj, count in mergers.items()])
        multiFrameMoves = np.asarray(multiFrameMoves)

        shape = pluginManager.getImageProvider().getImageShape(
            ilpFilename, labelImagePath)
        labelImage = pluginManager.getImageProvider().getLabelImageForFrame(
            ilpFilename, labelImagePath, timestep)

        # (dataset name, events, "Format" attribute), written in this order if non-empty
        eventTables = [
            ("Appearances", appearances,
             "cell label appeared in current file"),
            ("Disappearances", disappearances,
             "cell label disappeared in current file"),
            ("Moves", moves,
             "from (previous file), to (current file)"),
            ("Splits", splits,
             "ancestor (previous file), descendant (current file), descendant (current file)"),
            ("Mergers", mergerArray,
             "descendant (current file), number of objects"),
            ("MultiFrameMoves", multiFrameMoves,
             "from (given by timestep), to (current file), timestep"),
        ]

        with h5py.File(fn, 'w') as destination:
            # write meta fields and copy segmentation from project
            segmentation = destination.create_group('segmentation')
            segmentation.create_dataset("labels", data=labelImage, compression='gzip')
            meta = destination.create_group('objects/meta')
            labelIds = np.unique(labelImage)
            labelIds = labelIds[labelIds > 0]
            validity = np.ones(labelIds.shape)
            meta.create_dataset("id", data=labelIds, dtype=np.uint32)
            meta.create_dataset("valid", data=validity, dtype=np.uint32)

            tracking = destination.create_group("tracking")

            # write associations
            for datasetName, events, formatDescription in eventTables:
                if events is None or len(events) == 0:
                    continue
                dataset = tracking.create_dataset(datasetName, data=events, dtype=np.int32)
                dataset.attrs["Format"] = formatDescription

        log.debug("-> results successfully written")
    except Exception as e:
        log.warning("ERROR while writing events: {}".format(str(e)))
# --- Example #9 (0) ---
def _collectTrackInformation(hypotheses_graph):
    """
    Gather object->track mappings, track time steps and track parents from a hypotheses graph.

    The lineage (``trackId`` of every node) must have been computed beforehand.

    **Returns** ``(mappings, tracks, trackParents)``:

    * ``mappings``: frame -> {(idInSegmentation, segmentationFilename): trackId}
    * ``tracks``: trackId -> list of frames in which the track has a node
    * ``trackParents``: trackId -> trackId of the parent track, where known

    Nodes whose ``trackId`` is None belong to no track. They only make sure their frame
    has an entry in ``mappings``; they do not produce a track of their own.
    """
    mappings = {}  # dictionary over timeframes, containing another dict objectId -> trackId per frame
    tracks = {}  # stores a list of timeframes per track, so that we can find from<->to per track
    trackParents = {}  # store the parent trackID of a track if known
    nodeAttributes = hypotheses_graph._graph.node

    for n in hypotheses_graph.nodeIterator():
        frameMapping = mappings.setdefault(n[0], {})
        attributes = nodeAttributes[n]
        if 'trackId' not in attributes:
            raise ValueError("You need to compute the Lineage of every node before accessing the trackId!")
        trackId = attributes['trackId']
        if trackId is None:
            # not part of any track: must not turn into a "None" track in the track file
            continue
        traxel = attributes['traxel']
        frameMapping[(traxel.idInSegmentation, traxel.segmentationFilename)] = trackId
        tracks.setdefault(trackId, []).append(n[0])
        if 'parent' in attributes:
            assert trackId not in trackParents
            trackParents[trackId] = nodeAttributes[attributes['parent']]['trackId']

    return mappings, tracks, trackParents


def run_pipeline(options):
    """
    Run the complete tracking pipeline with competing segmentation hypotheses.

    The pipeline:

    1. Builds the hypotheses graph, or loads it from ``options.load_graph_filename``.
    2. Optionally dumps it to ``options.dump_graph_filename``.
    3. Maps the ground truth onto it when all ground truth options are given.
    4. Runs the tracking and computes the lineages.
    5. Writes the track text file via ``save_tracks``.
    6. Saves one relabeled image per frame via ``save_frame_to_tif``.
    """

    # set up probabilitygenerator (aka traxelstore) and hypothesesgraph or load them from a dump
    if options.load_graph_filename is not None:
        getLogger().info("Loading state from file: " + options.load_graph_filename)
        # NOTE: pickle can execute arbitrary code while loading - only load dumps you created yourself
        with gzip.open(options.load_graph_filename, 'r') as graphDump:
            ilpOptions = pickle.load(graphDump)
            probGenerator = pickle.load(graphDump)
            fieldOfView = pickle.load(graphDump)
            hypotheses_graph = pickle.load(graphDump)
            trackingGraph = pickle.load(graphDump)
        getLogger().info("Done loading state from file")
    else:
        fieldOfView, hypotheses_graph, ilpOptions, probGenerator, trackingGraph = setupGraph(options)

        if options.dump_graph_filename is not None:
            getLogger().info("Saving state to file: " + options.dump_graph_filename)
            with gzip.open(options.dump_graph_filename, 'w') as graphDump:
                pickle.dump(ilpOptions, graphDump)
                pickle.dump(probGenerator, graphDump)
                pickle.dump(fieldOfView, graphDump)
                pickle.dump(hypotheses_graph, graphDump)
                pickle.dump(trackingGraph, graphDump)
            getLogger().info("Done saving state to file")

    # map groundtruth to hypothesesgraph if all required variables are specified
    weights = None
    if options.gt_label_image_file is not None and options.gt_label_image_path is not None \
        and options.gt_text_file is not None and options.gt_jaccard_threshold is not None:
        weights = mapGroundTruth(options, hypotheses_graph, trackingGraph, probGenerator)

    # track
    result = runTracking(options, trackingGraph, weights)

    # insert the solution into the hypotheses graph and from that deduce the lineages
    getLogger().info("Inserting solution into graph")
    hypotheses_graph.insertSolution(result)
    hypotheses_graph.computeLineage()

    mappings, tracks, trackParents = _collectTrackInformation(hypotheses_graph)

    # write res_track.txt: per track its parent track (0 if none), first and last frame
    getLogger().info("Writing track text file")
    trackDict = {}
    for trackId, timestepList in tracks.items():
        trackDict[trackId] = [trackParents.get(trackId, 0), min(timestepList), max(timestepList)]
    save_tracks(trackDict, options)

    # export results
    getLogger().info("Saving relabeled images")
    pluginManager = TrackingPluginManager(verbose=options.verbose, pluginPaths=options.pluginPaths)
    pluginManager.setImageProvider('LocalImageLoader')
    imageProvider = pluginManager.getImageProvider()
    timeRange = probGenerator.timeRange

    for timeframe in range(timeRange[0], timeRange[1]):
        label_images = {}
        for f, p in zip(options.label_image_files, options.label_image_paths):
            label_images[f] = imageProvider.getLabelImageForFrame(f, p, timeframe)
        # a frame without any graph node has no mapping entry; use an empty mapping
        # instead of failing with a KeyError
        remapped_label_image = remap_label_image(label_images, mappings.get(timeframe, {}))
        save_frame_to_tif(timeframe, remapped_label_image, options)
# --- Example #10 (0) ---
def computeRegionFeaturesOnCloud(
        frame,
        rawImageFilename,
        rawImagePath,
        rawImageAxes,
        labelImageFilename,
        labelImagePath,
        turnOffFeatures,
        pluginPaths=['hytra/plugins'],
        featuresPerFrame=None,
        imageProviderPluginName='LocalImageLoader',
        featureSerializerPluginName='LocalFeatureSerializer'):
    '''
    Allow to use dispy to schedule feature computation to nodes running a dispynode,
    or to use multiprocessing.

    **Parameters**

    * `frame`: the frame number
    * `rawImageFilename`: the base filename of the raw image volume, or a dvid server address
    * `rawImagePath`: path inside the raw image HDF5 file, or DVID dataset UUID
    * `rawImageAxes`: axes configuration of the raw data
    * `labelImageFilename`: the base filename of the label image volume, or a dvid server address
    * `labelImagePath`: path inside the label image HDF5 file, or DVID dataset UUID
    * `turnOffFeatures`: forwarded to the `TrackingPluginManager` as `turnOffFeatures`
    * `pluginPaths`: where all yapsy plugins are stored (should be absolute for DVID)
    * `featuresPerFrame`: feature dictionary handed to the feature serializer as
      `features_per_frame` (used by the local serializer)
    * `imageProviderPluginName`: name of the image provider plugin to load images with
    * `featureSerializerPluginName`: name of the feature serializer plugin

    **returns** `(frame, featureDict)` for this frame if `featureSerializerPluginName == 'LocalFeatureSerializer'`
    and `featuresPerFrame == None`. Otherwise the features are stored via the feature
    serializer and `None` is returned.
    '''

    # set up plugin manager
    from hytra.pluginsystem.plugin_manager import TrackingPluginManager
    pluginManager = TrackingPluginManager(pluginPaths=pluginPaths,
                                          turnOffFeatures=turnOffFeatures,
                                          verbose=False)
    pluginManager.setImageProvider(imageProviderPluginName)
    pluginManager.setFeatureSerializer(featureSerializerPluginName)

    # load raw and label image (depending on chosen plugin this works via DVID or locally)
    rawImage = pluginManager.getImageProvider().getImageDataAtTimeFrame(
        rawImageFilename, rawImagePath, rawImageAxes, frame)
    labelImage = pluginManager.getImageProvider().getLabelImageForFrame(
        labelImageFilename, labelImagePath, frame)

    # untwist axes, if just x and y are messed up
    # NOTE(review): for square images this condition is always true, so the label
    # image is transposed unconditionally there; confirm this is intended.
    if rawImage.shape[0] == labelImage.shape[1] and rawImage.shape[
            1] == labelImage.shape[0]:
        labelImage = np.transpose(labelImage, axes=[1, 0])

    # compute features
    moreFeats, ignoreNames = pluginManager.applyObjectFeatureComputationPlugins(
        len(labelImage.shape), rawImage, labelImage, frame, rawImageFilename)

    # combine into one dictionary (update() works on Python 2 and 3 alike, unlike
    # concatenating dict.items() results)
    # WARNING: if there are multiple features with the same name, they will be overwritten!
    frameFeatures = {}
    for featureDict in moreFeats:
        frameFeatures.update(featureDict)

    # delete all ignored features
    for k in ignoreNames:
        frameFeatures.pop(k, None)

    # return or save features; compare the plugin name by value, not identity
    # ('is' only works if both strings happen to be the same object)
    if featuresPerFrame is None and featureSerializerPluginName == 'LocalFeatureSerializer':
        # simply return resulting dict
        return frame, frameFeatures
    else:
        # set up feature serializer (local or DVID for now)
        featureSerializer = pluginManager.getFeatureSerializer()

        # server address and uuid are only used by the DVID serializer
        featureSerializer.server_address = labelImageFilename
        featureSerializer.uuid = labelImagePath

        # feature dictionary used by local serializer
        featureSerializer.features_per_frame = featuresPerFrame

        # store
        featureSerializer.storeFeaturesForFrame(frameFeatures, frame)