def calc_time_censored_tree_frequencies(self):
    '''
    Fit time-censored clade frequencies and their slopes for every timepoint.

    For each ``time`` in ``self.timepoints`` the tree frequencies are
    re-estimated using only nodes whose ``numdate`` lies in the half-open
    window ``[time - freq_window, time)``. The method writes:

    * ``self.frequencies[time]``: frequencies estimated for that window.
    * ``node.freq`` / ``node.logit_freq``: replaced by single-entry dicts
      keyed by ``"global_censored"``. After the loop they hold the values
      of the last processed timepoint.
    * ``node.fit_frequencies[time]``: the logit frequencies clipped to
      ``[-freq_cutoff, freq_cutoff]``, or the parent's entry when the node
      has no logit frequencies.
    * ``node.freq_slope[time]``: slope of a linear regression of the
      clipped trajectory against the window pivots, skipping the first
      ``pivots_fit`` pivots and the last one.

    Finally ``self.rootnode.pivots`` is reset to ``self.pivots``.

    Raises:
        RuntimeError: if the linear regression fails for a node. The message
            names the node's clade and the timepoint.
    '''
    print("fitting time censored tree frequencies")
    # The censored estimates are stored under their own region key and under
    # per-timepoint keys of self.frequencies, so the cached "global" entry of
    # self.frequencies is left alone.
    # NOTE: node.freq and node.logit_freq are replaced wholesale below, so any
    # "global" annotation previously stored on the nodes is dropped.
    region = "global_censored"
    freq_cutoff = 25.0
    pivots_fit = 6
    freq_window = 1.0

    for node in self.nodes:
        node.fit_frequencies = {}
        node.freq_slope = {}

    for time in self.timepoints:
        window_start = time - freq_window
        window_end = time
        pivots = make_pivots(window_start, window_end, 1 / self.pivot_spacing)

        def in_window(candidate):
            # Half-open interval: the window start is included, `time` is not.
            return candidate.numdate >= window_start and candidate.numdate < window_end

        # Recalculate tree frequencies for the given time interval and its
        # corresponding pivots.
        tree_freqs = tree_frequencies(self.tree, pivots, node_filter=in_window)
        tree_freqs.estimate_clade_frequencies()
        self.frequencies[time] = tree_freqs.frequencies

        # Annotate censored frequencies on nodes.
        # TODO: replace node-based annotation with dicts indexed by node name.
        for node in self.nodes:
            censored = self.frequencies[time][node.clade]
            node.freq = {region: censored}
            node.logit_freq = {region: logit_transform(censored, 1e-4)}

        for node in self.nodes:
            logit_freq = node.logit_freq[region]
            if logit_freq is not None:
                node.fit_frequencies[time] = np.clip(
                    logit_freq, -freq_cutoff, freq_cutoff)
            else:
                # No estimate for this node: inherit the parent's trajectory.
                # This relies on the parent having been processed already.
                node.fit_frequencies[time] = self.node_parents[
                    node].fit_frequencies[time]

            try:
                # linregress returns the slope as its first element.
                slope = linregress(
                    pivots[pivots_fit:-1],
                    node.fit_frequencies[time][pivots_fit:-1])[0]
            except Exception as error:
                # Fail with context instead of opening an interactive
                # debugger, which would hang non-interactive runs.
                raise RuntimeError(
                    "could not fit frequency slope for clade %r at timepoint %r: %s"
                    % (node.clade, time, error))
            node.freq_slope[time] = slope

    # reset pivots in tree to global pivots
    self.rootnode.pivots = self.pivots
def calc_node_frequencies(self):
    '''
    Go over all nodes and calculate frequencies at the timepoints, based on
    previously calculated frequency trajectories.

    Global tree/clade frequencies are estimated only when
    ``self.frequencies["global"]`` is missing or lacks the root clade;
    otherwise the existing values are reused. The method then writes:

    * ``node.freq`` / ``node.logit_freq``: single-entry dicts keyed by
      ``"global"``.
    * ``node.timepoint_freqs[time]``: the node's frequency linearly
      interpolated over ``self.rootnode.pivots`` at each timepoint.
    * ``node.delta_freqs[time]``: the interpolated frequency at
      ``time + self.delta_time`` for every timepoint except the last.
    * ``self.freq_arrays[time]``: array of ``timepoint_freqs[time]`` for
      every tip in ``self.tips``, in that order.

    Raises:
        ValueError: raised by ``interp1d`` (``bounds_error=True``) when a
            requested time falls outside the range of the pivots.
    '''
    region = "global"

    # Calculate global tree/clade frequencies if they have not been calculated already.
    if (region not in self.frequencies
            or self.rootnode.clade not in self.frequencies[region]):
        print("calculating global node frequencies")
        tree_freqs = tree_frequencies(self.tree, self.pivots, method="SLSQP", verbose=1)
        tree_freqs.estimate_clade_frequencies()
        self.frequencies[region] = tree_freqs.frequencies
    else:
        print("found existing global node frequencies")

    # Annotate frequencies on nodes.
    # TODO: replace node-based annotation with dicts indexed by node name.
    global_freqs = self.frequencies[region]
    for node in self.nodes:
        node.freq = {region: global_freqs[node.clade]}
        node.logit_freq = {
            region: logit_transform(global_freqs[node.clade], 1e-4)
        }

    for node in self.nodes:
        interpolation = interp1d(self.rootnode.pivots, node.freq[region],
                                 kind='linear', bounds_error=True)
        node.timepoint_freqs = defaultdict(float)
        node.delta_freqs = defaultdict(float)
        # interp1d returns a 0-d array for scalar input; .item() converts it
        # to a plain Python scalar. np.asscalar did the same but has been
        # removed from NumPy.
        for time in self.timepoints:
            node.timepoint_freqs[time] = interpolation(time).item()
        for time in self.timepoints[:-1]:
            node.delta_freqs[time] = interpolation(time + self.delta_time).item()

    # freq_arrays list *all* tips for each initial timepoint
    self.freq_arrays = {
        time: np.array([tip.timepoint_freqs[time] for tip in self.tips])
        for time in self.timepoints
    }