def _loadClassifiers(self):
    """
    Instantiate the count, division and transition random forests that are configured in `self._options`.

    For each classifier, `self._options.<prefix>Path` and `self._options.<prefix>Filename` are read
    (with the prefixes `objectCountClassifier`, `divisionClassifier` and `transitionClassifier`).
    Only if both are not `None`, a `RandomForestClassifier(path, filename, self._options)` is created
    and stored in `self._countClassifier`, `self._divisionClassifier` or `self._transitionClassifier`.
    Attributes of classifiers that are not configured are left untouched.
    """
    options = self._options
    classifierSpecs = (
        ('_countClassifier', 'objectCountClassifier'),
        ('_divisionClassifier', 'divisionClassifier'),
        ('_transitionClassifier', 'transitionClassifier'),
    )
    for attributeName, optionPrefix in classifierSpecs:
        # read the filename only if the path is set, such that an option object
        # without the filename attribute is fine as long as the path is None
        path = getattr(options, optionPrefix + 'Path')
        if path is None:
            continue
        filename = getattr(options, optionPrefix + 'Filename')
        if filename is None:
            continue
        setattr(self, attributeName,
                RandomForestClassifier(path, filename, options))
def trainDetectionClassifier(hypothesesGraph,
                             gtFrameIdToGlobalIdsWithScoresMap,
                             numSamples=100,
                             selectedFeatures=None):
    """
    Finds the given number of training examples, half as positive and half as negative examples, from the
    given graph, and trains a random forest on them.

    Every node whose traxel has a non-empty `JaccardScores` feature is a candidate, scored by the largest
    score among its `JaccardScores` entries (the score is element `[1]` of each entry).
    The `numSamples // 2` candidates with the lowest score are used as negative examples (label 0),
    the `numSamples // 2` candidates with the highest score as positive examples (label 1).

    * `hypothesesGraph`: must provide `getNodeTraxelMap()` and `nodeIterator()`
    * `gtFrameIdToGlobalIdsWithScoresMap`: currently not used by this function
    * `numSamples`: total number of training examples, must be at least 2
    * `selectedFeatures`: list of feature names to train on. If `None`, all features of the first
      selected sample are used, except some bookkeeping features and those containing an underscore.

    **Raises**: `ValueError` if `numSamples < 2`, `AssertionError` if there are fewer candidates than `numSamples`

    **Returns**: a trained random forest
    """
    getLogger().debug("Extracting candidates")

    # create helper class for candidates, and store a list of these
    @attr.s
    class Candidate(object):
        ''' Helper class to combine a hypotheses graph `node` and its `score` to find the proper samples for classifier training '''
        node = attr.ib()
        score = attr.ib(validator=attr.validators.instance_of(float))

    numPerClass = numSamples // 2
    if numPerClass < 1:
        raise ValueError(
            "numSamples must be at least 2 to get one positive and one negative example, got {}"
            .format(numSamples))

    candidates = []
    nodeTraxelMap = hypothesesGraph.getNodeTraxelMap()
    for node in hypothesesGraph.nodeIterator():
        nodeFeatures = nodeTraxelMap[node].Features
        if 'JaccardScores' in nodeFeatures and len(
                nodeFeatures['JaccardScores']) > 0:
            bestScore = max(entry[1] for entry in nodeFeatures['JaccardScores'])
            candidates.append(Candidate(node, bestScore))

    # raised explicitly (instead of `assert`) such that the check survives `python -O`
    if len(candidates) < numSamples:
        raise AssertionError(
            "Need at least {} candidates with JaccardScores, but found only {}"
            .format(numSamples, len(candidates)))
    candidates.sort(key=lambda c: c.score)

    # Pick the numPerClass lowest and the numPerClass highest scoring candidates.
    # The upper slice is computed from len(candidates) such that the best candidate is included
    # and both classes have the same size for odd numSamples as well.
    negativeSamples = candidates[:numPerClass]
    positiveSamples = candidates[len(candidates) - numPerClass:]
    selectedSamples = negativeSamples + positiveSamples
    labels = np.hstack(
        [np.zeros(len(negativeSamples)),
         np.ones(len(positiveSamples))])
    getLogger().info("Using {} of {} available training examples".format(
        len(selectedSamples), len(candidates)))

    # TODO: make sure that the positive examples were all selected in the GT mapping

    getLogger().debug("construct feature matrix")
    node = selectedSamples[0].node
    if selectedFeatures is None:
        forbidden = set([
            'JaccardScores', 'id', 'filename', 'Polygon', 'detProb', 'divProb',
            'com'
        ])
        # build a new list instead of removing from `keys()`, which is a view in python 3
        selectedFeatures = [
            f for f in nodeTraxelMap[node].Features.keys()
            if f not in forbidden and '_' not in f
        ]
        getLogger().info(
            "No list of selected features was specified, using {}".format(
                selectedFeatures))

    rf = RandomForestClassifier(selectedFeatures=selectedFeatures)
    # the first feature vector determines the number of columns of the feature matrix
    features = rf.extractFeatureVector(nodeTraxelMap[node].Features,
                                       singleObject=True)
    featureMatrix = np.zeros([len(selectedSamples), features.shape[1]])
    featureMatrix[0, :] = features
    for row, candidate in enumerate(selectedSamples[1:], 1):
        featureMatrix[row, :] = rf.extractFeatureVector(
            nodeTraxelMap[candidate.node].Features, singleObject=True)

    rf.train(featureMatrix, labels)

    return rf
class IlpProbabilityGenerator(ProbabilityGenerator):
    """
    The IlpProbabilityGenerator is a python wrapper around pgmlink's C++ traxelstore,
    but with the functionality to compute all region features
    and evaluate the division/count/transition classifiers.
    """

    def __init__(self,
                 ilpOptions,
                 turnOffFeatures=None,
                 useMultiprocessing=True,
                 pluginPaths=None,
                 verbose=False):
        """
        Set up the plugin manager, load the configured classifiers and query shape and time range
        of the label image.

        * `ilpOptions`: option object, stored as `self._options`. Its `imageProviderName` and
          `featureSerializerName` configure the plugin manager.
        * `turnOffFeatures`: passed on to the `TrackingPluginManager`, defaults to an empty list
        * `useMultiprocessing`: whether `_extractAllFeatures` uses a process pool
        * `pluginPaths`: passed on to the `TrackingPluginManager`, defaults to `['hytra/plugins']`
        * `verbose`: passed on to the `TrackingPluginManager`
        """
        # fresh lists per instance instead of shared mutable default arguments
        if turnOffFeatures is None:
            turnOffFeatures = []
        if pluginPaths is None:
            pluginPaths = ['hytra/plugins']

        self._useMultiprocessing = useMultiprocessing
        self._options = ilpOptions
        self._pluginPaths = pluginPaths
        self._pluginManager = TrackingPluginManager(
            turnOffFeatures=turnOffFeatures,
            verbose=verbose,
            pluginPaths=pluginPaths)
        self._pluginManager.setImageProvider(ilpOptions.imageProviderName)
        self._pluginManager.setFeatureSerializer(
            ilpOptions.featureSerializerName)

        self._countClassifier = None
        self._divisionClassifier = None
        self._transitionClassifier = None

        self._loadClassifiers()

        self.shape, self.timeRange = self._getShapeAndTimeRange()

        # set default division feature names
        self._divisionFeatureNames = [
            'ParentChildrenRatio_Count', 'ParentChildrenRatio_Mean',
            'ChildrenRatio_Count', 'ChildrenRatio_Mean',
            'ParentChildrenAngle_RegionCenter',
            'ChildrenRatio_SquaredDistances'
        ]

        # other parameters that one might want to set
        self.x_scale = 1.0
        self.y_scale = 1.0
        self.z_scale = 1.0
        self.divisionProbabilityFeatureName = 'divProb'
        self.detectionProbabilityFeatureName = 'detProb'

        # this public variable contains all traxels if we're not using pgmlink
        self.TraxelsPerFrame = {}

        # filled by fillTraxels, read by getTraxelFeatureDict
        self._featuresPerFrame = None

    def _loadClassifiers(self):
        """
        Instantiate the count, division and transition random forests that are configured in `self._options`.

        A classifier is only created if both its `<prefix>Path` and `<prefix>Filename` option are not `None`,
        otherwise the corresponding attribute is left untouched.
        """
        options = self._options
        classifierSpecs = (
            ('_countClassifier', 'objectCountClassifier'),
            ('_divisionClassifier', 'divisionClassifier'),
            ('_transitionClassifier', 'transitionClassifier'),
        )
        for attributeName, optionPrefix in classifierSpecs:
            # read the filename only if the path is set (same evaluation order as an `and` of both checks)
            path = getattr(options, optionPrefix + 'Path')
            if path is None:
                continue
            filename = getattr(options, optionPrefix + 'Filename')
            if filename is None:
                continue
            setattr(self, attributeName,
                    RandomForestClassifier(path, filename, options))

    def __getstate__(self):
        '''
        We define __getstate__ and __setstate__ to exclude the random forests from being pickled,
        as that is not allowed.

        See https://docs.python.org/3/library/pickle.html#pickle-state for more details.
        '''
        # Copy the object's state from self.__dict__ which contains
        # all our instance attributes. Always use the dict.copy()
        # method to avoid modifying the original state.
        state = self.__dict__.copy()
        # Remove the unpicklable entries (tolerating entries that are missing already).
        for classifierName in ('_countClassifier', '_divisionClassifier',
                               '_transitionClassifier'):
            state.pop(classifierName, None)
        return state

    def __setstate__(self, state):
        '''
        Restore the instance attributes from `state` and reload the random forests.
        '''
        # Restore instance attributes
        self.__dict__.update(state)
        # __getstate__ removed the classifier attributes and _loadClassifiers only sets those
        # that are configured, so make sure all three exist before loading
        self._countClassifier = None
        self._divisionClassifier = None
        self._transitionClassifier = None
        # Restore the random forests by reading them from scratch
        self._loadClassifiers()

    def computeRegionFeatures(self, rawImage, labelImage, frameNumber):
        """
        Computes all region features for all objects in the given image.

        The feature dictionaries returned by the object feature computation plugins are merged into
        one dictionary (later plugins overwrite equally named features of earlier ones), and the
        features that the plugins asked to ignore are removed.

        `labelImage` must be of dtype `np.uint32`.

        **Returns**: a dictionary of features
        """
        assert (labelImage.dtype == np.uint32)

        moreFeats, ignoreNames = self._pluginManager.applyObjectFeatureComputationPlugins(
            len(labelImage.shape), rawImage, labelImage, frameNumber,
            self._options.rawImageFilename)

        frameFeatures = {}
        for pluginFeatures in moreFeats:
            frameFeatures.update(pluginFeatures.items())

        # delete the "Global<Min/Max>" features as they are not nice when iterating over everything
        for k in ignoreNames:
            frameFeatures.pop(k, None)

        return frameFeatures

    def computeDivisionFeatures(self, featuresAtT, featuresAtTPlus1,
                                labelImageAtTPlus1):
        """
        Computes the division features for all objects in the images,
        using the feature names in `self._divisionFeatureNames`
        """
        fm = hytra.core.divisionfeatures.FeatureManager(
            ndim=self.getNumDimensions())
        return fm.computeFeatures_at(featuresAtT, featuresAtTPlus1,
                                     labelImageAtTPlus1,
                                     self._divisionFeatureNames)

    def setDivisionFeatures(self, divisionFeatures):
        """
        Set which features should be computed explicitly for divisions by giving a list of strings.
        Each string could be a combination of <operation>_<feature>, where Operation is one of:
            * ParentIdentity
            * SquaredDistances
            * ChildrenRatio
            * ParentChildrenAngle
            * ParentChildrenRatio

        And <feature> is any region feature plus "SquaredDistances"
        """
        # TODO: check that the strings are valid?
        self._divisionFeatureNames = divisionFeatures

    def getNumDimensions(self):
        """
        Compute the number of dimensions which is the number of axis with more than 1 element
        (more precisely: the number of entries of `self.shape` that are not equal to 1)
        """
        return np.count_nonzero(np.array(self.shape) != 1)

    def _getShapeAndTimeRange(self):
        """
        extract the shape and the time range from the labelimage

        **Returns**: tuple `(shape, timerange)` as reported by the image provider
        """
        imageProvider = self._pluginManager.getImageProvider()
        shape = imageProvider.getImageShape(self._options.labelImageFilename,
                                            self._options.labelImagePath)
        timerange = self._pluginManager.getImageProvider().getTimeRange(
            self._options.labelImageFilename, self._options.labelImagePath)
        return shape, timerange

    def getLabelImageForFrame(self, timeframe):
        """
        Get the label image(volume) of one time frame
        """
        labelImage = self._pluginManager.getImageProvider(
        ).getLabelImageForFrame(self._options.labelImageFilename,
                                self._options.labelImagePath, timeframe)
        return labelImage

    def getRawImageForFrame(self, timeframe):
        """
        Get the raw image(volume) of one time frame
        """
        rawImage = self._pluginManager.getImageProvider(
        ).getImageDataAtTimeFrame(self._options.rawImageFilename,
                                  self._options.rawImagePath, timeframe)
        return rawImage

    def _extractFeaturesForFrame(self, timeframe):
        """
        extract the features of one frame, return the `timeframe` and a dictionary of features,
        where each feature vector contains N entries per object
        (where N is the dimensionality of the feature)
        """
        rawImage = self.getRawImageForFrame(timeframe)
        labelImage = self.getLabelImageForFrame(timeframe)

        return timeframe, self.computeRegionFeatures(rawImage, labelImage,
                                                     timeframe)

    def _extractDivisionFeaturesForFrame(self, timeframe, featuresPerFrame):
        """
        extract Division Features for one frame from the region features in `featuresPerFrame`.

        **Returns**: tuple `(timeframe, feats)`, where `feats` is an empty dictionary
        if `timeframe + 1` is not smaller than `self.timeRange[1]`
        """
        feats = {}
        if timeframe + 1 < self.timeRange[1]:
            labelImageAtTPlus1 = self.getLabelImageForFrame(timeframe + 1)
            feats = self.computeDivisionFeatures(
                featuresPerFrame[timeframe], featuresPerFrame[timeframe + 1],
                labelImageAtTPlus1)

        return timeframe, feats

    def _extractAllFeatures(self, dispyNodeIps=None, turnOffFeatures=None):
        """
        Extract the features of all frames.

        If a list of IP addresses is given e.g. as `dispyNodeIps = ["104.197.178.206","104.196.46.138"]`,
        then the computation will be distributed across these nodes. Otherwise, multiprocessing will
        be used if `self._useMultiprocessing=True`, which it is by default.

        If `dispyNodeIps` is `None` or an empty list, then the feature extraction will be parallelized via
        multiprocessing (or run through the `DummyExecutor` if `self._useMultiprocessing` is false).

        **TODO:** fix division feature computation for distributed mode

        **Returns**: dictionary mapping each frame to its feature dictionary
        """
        import logging
        if dispyNodeIps is None:
            dispyNodeIps = []
        if turnOffFeatures is None:
            turnOffFeatures = []

        # configure progress bar
        numSteps = self.timeRange[1] - self.timeRange[0]
        if self._divisionClassifier is not None:
            numSteps *= 2

        t0 = time.time()

        # defined up front, because both branches fill it and it is returned at the end
        featuresPerFrame = {}

        if (len(dispyNodeIps) == 0):
            # no dispy node IDs given, parallelize object feature computation via processes

            if self._useMultiprocessing:
                # use ProcessPoolExecutor, which instantiates as many processes as there are CPU cores by default
                ExecutorType = concurrent.futures.ProcessPoolExecutor
                logging.getLogger('Traxelstore').info(
                    'Parallelizing feature extraction via multiprocessing on all cores!'
                )
            else:
                ExecutorType = DummyExecutor
                logging.getLogger('Traxelstore').info(
                    'Running feature extraction on single core!')

            progressBar = ProgressBar(stop=numSteps)
            progressBar.show(increase=0)

            with ExecutorType() as executor:
                # 1st pass for region features
                jobs = []
                for frame in range(self.timeRange[0], self.timeRange[1]):
                    jobs.append(
                        executor.submit(computeRegionFeaturesOnCloud, frame,
                                        self._options.rawImageFilename,
                                        self._options.rawImagePath,
                                        self._options.rawImageAxes,
                                        self._options.labelImageFilename,
                                        self._options.labelImagePath,
                                        turnOffFeatures, self._pluginPaths))
                for job in concurrent.futures.as_completed(jobs):
                    progressBar.show()
                    frame, feats = job.result()
                    featuresPerFrame[frame] = feats

                # 2nd pass for division features
                if self._divisionClassifier is not None:
                    jobs = []
                    for frame in range(self.timeRange[0],
                                       self.timeRange[1] - 1):
                        jobs.append(
                            executor.submit(
                                computeDivisionFeaturesOnCloud, frame,
                                featuresPerFrame[frame],
                                featuresPerFrame[frame + 1],
                                self._pluginManager.getImageProvider(),
                                self._options.labelImageFilename,
                                self._options.labelImagePath,
                                self.getNumDimensions(),
                                self._divisionFeatureNames))

                    for job in concurrent.futures.as_completed(jobs):
                        progressBar.show()
                        frame, feats = job.result()
                        featuresPerFrame[frame].update(feats)

            # # serialize features??
            # for frame in range(self.timeRange[0], self.timeRange[1]):
            #     featureSerializer.storeFeaturesForFrame(featuresPerFrame[frame], frame)
        else:
            logging.getLogger('Traxelstore').warning(
                'Parallelization with dispy is WORK IN PROGRESS!')
            import random
            import dispy
            cluster = dispy.JobCluster(computeRegionFeaturesOnCloud,
                                       nodes=dispyNodeIps,
                                       loglevel=logging.DEBUG,
                                       depends=[self._pluginManager],
                                       secret="teamtracking")

            jobs = []
            for frame in range(self.timeRange[0], self.timeRange[1]):
                # NOTE(review): the plugin path is hard coded here, presumably it is the
                # location of the plugins on the dispy nodes -- confirm before relying on it
                job = cluster.submit(
                    frame,
                    self._options.rawImageFilename,
                    self._options.rawImagePath,
                    self._options.rawImageAxes,
                    self._options.labelImageFilename,
                    self._options.labelImagePath,
                    turnOffFeatures,
                    pluginPaths=['/home/carstenhaubold/embryonic/plugins'])
                job.id = frame
                jobs.append(job)

            for job in jobs:
                result = job()  # wait for job to finish
                print(job.exception)
                print(job.stdout)
                print(job.stderr)
                print(job.id)
                # computeRegionFeaturesOnCloud yields (frame, features), as used in the
                # multiprocessing branch above. A job without result is skipped.
                if result is not None:
                    resultFrame, feats = result
                    featuresPerFrame[resultFrame] = feats

            logging.getLogger('Traxelstore').warning(
                'Using dispy we cannot compute division features yet!')
            # # 2nd pass for division features
            # if self._divisionClassifier is not None:
            #     for frame in range(self.timeRange[0], self.timeRange[1]):
            #         progressBar.show()
            #         featuresPerFrame[frame].update(self._extractDivisionFeaturesForFrame(frame, featuresPerFrame)[1])

        t1 = time.time()
        getLogger().info("Feature computation took {} secs".format(t1 - t0))

        return featuresPerFrame

    def _setTraxelFeatureArray(self, traxel, featureArray, name):
        ''' store the specified `featureArray` in a `traxel`'s feature dictionary under the specified key=`name` '''
        if isinstance(featureArray, np.ndarray):
            featureArray = featureArray.flatten()
        traxel.add_feature_array(name, len(featureArray))
        for i, v in enumerate(featureArray):
            traxel.set_feature_value(name, i, float(v))

    def fillTraxels(self,
                    usePgmlink=True,
                    ts=None,
                    fs=None,
                    dispyNodeIps=None,
                    turnOffFeatures=None):
        """
        Compute all the features and predict object count as well as division probabilities.
        Store the resulting information (and all other features) in the given pgmlink::TraxelStore,
        or create a new one if ts=None.

        usePgmlink: boolean whether pgmlink should be used and a pgmlink.TraxelStore and pgmlink.FeatureStore returned
        ts: an initial pgmlink.TraxelStore (only used if usePgmlink=True)
        fs: an initial pgmlink.FeatureStore (only used if usePgmlink=True, replaced by a new one if ts=None)
        dispyNodeIps: passed on to `_extractAllFeatures`, defaults to an empty list
        turnOffFeatures: passed on to `_extractAllFeatures`, defaults to an empty list

        Objects with a `Count` of zero, or with a `Count` outside of `self._options.sizeFilter`
        (if that is not None), do not get a traxel.

        raises an AssertionError if a feature value cannot be read or cannot be stored in a traxel

        returns (ts, fs) but only if usePgmlink=True, otherwise it fills self.TraxelsPerFrame
        """
        if dispyNodeIps is None:
            dispyNodeIps = []
        if turnOffFeatures is None:
            turnOffFeatures = []

        if usePgmlink:
            import pgmlink
            if ts is None:
                ts = pgmlink.TraxelStore()
                fs = pgmlink.FeatureStore()
            else:
                assert (fs is not None)

        getLogger().info("Extracting features...")
        self._featuresPerFrame = self._extractAllFeatures(
            dispyNodeIps=dispyNodeIps, turnOffFeatures=turnOffFeatures)

        getLogger().info("Creating traxels...")
        progressBar = ProgressBar(stop=len(self._featuresPerFrame))
        progressBar.show(increase=0)

        for frame, features in self._featuresPerFrame.items():
            hasDivisionPrediction = (self._divisionClassifier is not None
                                     and frame + 1 < self.timeRange[1])

            # predict random forests
            if self._countClassifier is not None:
                objectCountProbabilities = self._countClassifier.predictProbabilities(
                    features=None, featureDict=features)

            if hasDivisionPrediction:
                divisionProbabilities = self._divisionClassifier.predictProbabilities(
                    features=None, featureDict=features)

            # create traxels for all objects, the number of objects is taken
            # from the first feature of the dictionary (index 0 is skipped)
            numObjects = next(iter(features.values())).shape[0]
            for objectId in range(1, numObjects):
                # print("Frame {} Object {}".format(frame, objectId))
                pixelSize = features['Count'][objectId]
                if pixelSize == 0 or (self._options.sizeFilter is not None \
                        and (pixelSize < self._options.sizeFilter[0] \
                                     or pixelSize > self._options.sizeFilter[1])):
                    continue

                # create traxel
                if usePgmlink:
                    traxel = pgmlink.Traxel()
                else:
                    traxel = Traxel()
                traxel.Id = objectId
                traxel.Timestep = frame

                # add raw features
                for key, val in features.items():
                    if key == 'id':
                        traxel.idInSegmentation = val[objectId]
                    elif key == 'filename':
                        traxel.segmentationFilename = val[objectId]
                    else:
                        try:
                            if isinstance(
                                    val,
                                    list):  # polygon feature returns a list!
                                featureValues = val[objectId]
                            else:
                                featureValues = val[objectId, ...]
                        except Exception:
                            # lists have no shape, do not let that hide the actual problem
                            getLogger().error(
                                "Could not get feature values of {} for key {} from matrix with shape {}"
                                .format(objectId, key,
                                        getattr(val, 'shape', '<unknown>')))
                            raise AssertionError()
                        try:
                            self._setTraxelFeatureArray(
                                traxel, featureValues, key)
                            if key == 'RegionCenter':
                                self._setTraxelFeatureArray(
                                    traxel, featureValues, 'com')
                        except Exception:
                            getLogger().error(
                                "Could not add feature array {} for {}".format(
                                    featureValues, key))
                            raise AssertionError()

                # add random forest predictions
                if self._countClassifier is not None:
                    self._setTraxelFeatureArray(
                        traxel, objectCountProbabilities[objectId, :],
                        self.detectionProbabilityFeatureName)

                if hasDivisionPrediction:
                    self._setTraxelFeatureArray(
                        traxel, divisionProbabilities[objectId, :],
                        self.divisionProbabilityFeatureName)

                # set other parameters
                traxel.set_x_scale(self.x_scale)
                traxel.set_y_scale(self.y_scale)
                traxel.set_z_scale(self.z_scale)

                if usePgmlink:
                    # add to pgmlink's traxelstore
                    ts.add(fs, traxel)
                else:
                    self.TraxelsPerFrame.setdefault(frame,
                                                    {})[objectId] = traxel
            progressBar.show()

        if usePgmlink:
            return ts, fs

    def getTraxelFeatureDict(self, frame, objectId):
        """
        Getter method for features per traxel.

        Requires that the features were extracted (see `fillTraxels`) before.

        **Returns**: dictionary with the entry of object `objectId` for every feature of `frame`
        """
        assert self._featuresPerFrame is not None
        traxelFeatureDict = {}
        for k, v in self._featuresPerFrame[frame].items():
            if 'Polygon' in k:
                traxelFeatureDict[k] = v[objectId]
            else:
                traxelFeatureDict[k] = v[objectId, ...]
        return traxelFeatureDict

    def getTransitionFeatureVector(self, featureDictObjectA,
                                   featureDictObjectB, selectedFeatures):
        """
        Return the feature vector that the transition feature vector construction plugins build
        from the selected features of both objects, as input for the TransitionClassifier.
        The result is a numpy array with an additional leading axis of length one.
        """
        features = np.array(
            self._pluginManager.
            applyTransitionFeatureVectorConstructionPlugins(
                featureDictObjectA, featureDictObjectB, selectedFeatures))
        features = np.expand_dims(features, axis=0)
        return features
def trainDetectionClassifier(hypothesesGraph, gtFrameIdToGlobalIdsWithScoresMap, numSamples=100, selectedFeatures=None):
    """
    Finds the given number of training examples, half as positive and half as negative examples, from the
    given graph, and trains a random forest on them.

    Every node whose traxel has a non-empty `JaccardScores` feature is a candidate, scored by the largest
    score among its `JaccardScores` entries (the score is element `[1]` of each entry).
    The `numSamples // 2` candidates with the lowest score are used as negative examples (label 0),
    the `numSamples // 2` candidates with the highest score as positive examples (label 1).

    * `hypothesesGraph`: must provide `getNodeTraxelMap()` and `nodeIterator()`
    * `gtFrameIdToGlobalIdsWithScoresMap`: currently not used by this function
    * `numSamples`: total number of training examples, must be at least 2
    * `selectedFeatures`: list of feature names to train on. If `None`, all features of the first
      selected sample are used, except some bookkeeping features and those containing an underscore.

    **Raises**: `ValueError` if `numSamples < 2`, `AssertionError` if there are fewer candidates than `numSamples`

    **Returns**: a trained random forest
    """
    getLogger().debug("Extracting candidates")

    # create helper class for candidates, and store a list of these
    @attr.s
    class Candidate(object):
        ''' Helper class to combine a hypotheses graph `node` and its `score` to find the proper samples for classifier training '''
        node = attr.ib()
        score = attr.ib(validator=attr.validators.instance_of(float))

    numPerClass = numSamples // 2
    if numPerClass < 1:
        raise ValueError(
            "numSamples must be at least 2 to get one positive and one negative example, got {}".format(numSamples))

    candidates = []
    nodeTraxelMap = hypothesesGraph.getNodeTraxelMap()
    for node in hypothesesGraph.nodeIterator():
        nodeFeatures = nodeTraxelMap[node].Features
        if 'JaccardScores' in nodeFeatures and len(nodeFeatures['JaccardScores']) > 0:
            bestScore = max(entry[1] for entry in nodeFeatures['JaccardScores'])
            candidates.append(Candidate(node, bestScore))

    # raised explicitly (instead of `assert`) such that the check survives `python -O`
    if len(candidates) < numSamples:
        raise AssertionError(
            "Need at least {} candidates with JaccardScores, but found only {}".format(numSamples, len(candidates)))
    candidates.sort(key=lambda c: c.score)

    # Pick the numPerClass lowest and the numPerClass highest scoring candidates.
    # The upper slice is computed from len(candidates) such that the best candidate is included
    # and both classes have the same size for odd numSamples as well.
    negativeSamples = candidates[:numPerClass]
    positiveSamples = candidates[len(candidates) - numPerClass:]
    selectedSamples = negativeSamples + positiveSamples
    labels = np.hstack([np.zeros(len(negativeSamples)), np.ones(len(positiveSamples))])
    getLogger().info("Using {} of {} available training examples".format(len(selectedSamples), len(candidates)))

    # TODO: make sure that the positive examples were all selected in the GT mapping

    getLogger().debug("construct feature matrix")
    node = selectedSamples[0].node
    if selectedFeatures is None:
        forbidden = set(['JaccardScores', 'id', 'filename', 'Polygon', 'detProb', 'divProb', 'com'])
        # build a new list instead of removing from `keys()`, which is a view in python 3
        selectedFeatures = [f for f in nodeTraxelMap[node].Features.keys()
                            if f not in forbidden and '_' not in f]
        getLogger().info("No list of selected features was specified, using {}".format(selectedFeatures))

    rf = RandomForestClassifier(selectedFeatures=selectedFeatures)
    # the first feature vector determines the number of columns of the feature matrix
    features = rf.extractFeatureVector(nodeTraxelMap[node].Features, singleObject=True)
    featureMatrix = np.zeros([len(selectedSamples), features.shape[1]])
    featureMatrix[0, :] = features
    for row, candidate in enumerate(selectedSamples[1:], 1):
        featureMatrix[row, :] = rf.extractFeatureVector(nodeTraxelMap[candidate.node].Features, singleObject=True)

    rf.train(featureMatrix, labels)

    return rf