Пример #1
0
    def calc_time_censored_tree_frequencies(self):
        """Fit clade frequencies using only data available up to each timepoint.

        For every ``time`` in ``self.timepoints`` the tree frequencies are
        re-estimated from nodes with ``time - freq_window <= numdate < time``
        on pivots built for that window. Results are stored in
        ``self.frequencies[time]``. For each node, the clipped logit
        frequencies are stored in ``node.fit_frequencies[time]``. The slope of
        a linear regression of those values against the pivots is stored in
        ``node.freq_slope[time]``. That regression skips the first
        ``pivots_fit`` pivots and the last one.

        Side effects: overwrites ``node.freq``, ``node.logit_freq``,
        ``node.fit_frequencies`` and ``node.freq_slope`` on every node in
        ``self.nodes``, and finally resets ``self.rootnode.pivots`` to
        ``self.pivots``.

        Raises:
            ValueError: if the slope regression fails for some node/timepoint
                (e.g. too few pivots remain after trimming).
        """
        print("fitting time censored tree frequencies")
        # this doesn't interfere with the previous freq estimates via difference in region: global_censored vs global
        region = "global_censored"
        freq_cutoff = 25.0  # logit frequencies are clipped to [-freq_cutoff, freq_cutoff]
        pivots_fit = 6  # number of leading pivots excluded from the slope fit
        freq_window = 1.0  # width of the censoring window, in numdate units
        for node in self.nodes:
            node.fit_frequencies = {}
            node.freq_slope = {}
        for time in self.timepoints:
            window_start, window_end = time - freq_window, time
            pivots = make_pivots(window_start, window_end,
                                 1 / self.pivot_spacing)

            # Bind the window bounds as defaults so the filter never observes
            # a later iteration's values, even if tree_frequencies keeps it.
            def node_filter_func(node, start=window_start, end=window_end):
                return start <= node.numdate < end

            # Recalculate tree frequencies for the given time interval and its
            # corresponding pivots.
            tree_freqs = tree_frequencies(self.tree,
                                          pivots,
                                          node_filter=node_filter_func)
            tree_freqs.estimate_clade_frequencies()
            self.frequencies[time] = tree_freqs.frequencies

            # Annotate censored frequencies on nodes.
            # TODO: replace node-based annotation with dicts indexed by node name.
            for node in self.nodes:
                clade_freq = self.frequencies[time][node.clade]
                node.freq = {region: clade_freq}
                node.logit_freq = {region: logit_transform(clade_freq, 1e-4)}

            # A node without logit frequencies inherits its parent's fitted
            # values for this timepoint, so this loop assumes self.nodes lists
            # parents before their children.
            for node in self.nodes:
                logit_freq = node.logit_freq[region]
                if logit_freq is not None:
                    node.fit_frequencies[time] = np.minimum(
                        freq_cutoff, np.maximum(-freq_cutoff, logit_freq))
                else:
                    node.fit_frequencies[time] = self.node_parents[
                        node].fit_frequencies[time]
                try:
                    fit = linregress(pivots[pivots_fit:-1],
                                     node.fit_frequencies[time][pivots_fit:-1])
                except ValueError as err:
                    raise ValueError(
                        "could not fit frequency slope for clade %s at time %s: %s"
                        % (node.clade, time, err))
                # First element of the linregress result is the slope.
                node.freq_slope[time] = fit[0]
        # reset pivots in tree to global pivots
        self.rootnode.pivots = self.pivots
Пример #2
0
    def calc_node_frequencies(self):
        '''
        Annotate every node with global frequencies and interpolate them at timepoints.

        If ``self.frequencies["global"]`` is missing, or lacks the root clade,
        global tree frequencies are estimated on ``self.pivots`` and stored
        there. Otherwise the existing estimates are reused.

        Each node in ``self.nodes`` then receives the following attributes:
          - ``node.freq`` / ``node.logit_freq``: ``{"global": ...}`` holding the
            clade frequency trajectory and its logit transform.
          - ``node.timepoint_freqs[t]``: the trajectory linearly interpolated
            at each ``t`` in ``self.timepoints``.
          - ``node.delta_freqs[t]``: the trajectory interpolated at
            ``t + self.delta_time`` for every timepoint except the last.

        Finally ``self.freq_arrays[t]`` is set to an array of
        ``tip.timepoint_freqs[t]`` over all ``self.tips``.

        Raises:
            ValueError: from ``interp1d`` (``bounds_error=True``) if a
                requested time lies outside ``self.rootnode.pivots``.
        '''
        region = "global"

        # Calculate global tree/clade frequencies if they have not been calculated already.
        if (region not in self.frequencies or
                self.rootnode.clade not in self.frequencies[region]):
            print("calculating global node frequencies")
            tree_freqs = tree_frequencies(self.tree,
                                          self.pivots,
                                          method="SLSQP",
                                          verbose=1)
            tree_freqs.estimate_clade_frequencies()
            self.frequencies[region] = tree_freqs.frequencies
        else:
            print("found existing global node frequencies")

        # Annotate frequencies on nodes.
        # TODO: replace node-based annotation with dicts indexed by node name.
        region_freqs = self.frequencies[region]
        for node in self.nodes:
            clade_freq = region_freqs[node.clade]
            node.freq = {region: clade_freq}
            node.logit_freq = {region: logit_transform(clade_freq, 1e-4)}

        for node in self.nodes:
            interpolation = interp1d(self.rootnode.pivots,
                                     node.freq[region],
                                     kind='linear',
                                     bounds_error=True)
            node.timepoint_freqs = defaultdict(float)
            node.delta_freqs = defaultdict(float)
            # ndarray.item() converts the 0-d interpolation result to a Python
            # scalar; it replaces np.asscalar, which NumPy 1.23 removed.
            for time in self.timepoints:
                node.timepoint_freqs[time] = interpolation(time).item()
            for time in self.timepoints[:-1]:
                node.delta_freqs[time] = interpolation(
                    time + self.delta_time).item()

        # freq_arrays list *all* tips for each initial timepoint
        self.freq_arrays = {
            time: np.array([tip.timepoint_freqs[time] for tip in self.tips])
            for time in self.timepoints
        }