def fates(self, populations):
    """
    Computes fates for each population

    Starting from ``populations`` (which must share one timepoint), the
    unnormalized populations are pulled back through every earlier transport
    map. Each row of the resulting matrix is then rescaled to sum to one.

    Parameters
    ----------
    self : wot.TransportMapModel
        The TransportMapModel used to find fates
    populations : list of wot.Population
        The target populations such as ones from self.population_from_cell_sets.
        The populations must be from the same time.

    Returns
    -------
    fates : anndata.AnnData
        Rows : all cells up to the populations' day, Columns : populations index.
        At point (i, j) : the probability that cell i belongs to population j
    """
    # Also validates that all populations share a single timepoint.
    start_day = wot.tmap.unique_timepoint(*populations)
    current = Population.copy(*populations, normalize=False, add_missing=True)
    names = [pop.name for pop in current]

    # Gather one (cells x populations) block per day, latest day first;
    # the list is flipped below so that rows come out in chronological order.
    blocks_latest_first = [np.array([pop.p for pop in current]).T]
    while self.can_pull_back(*current):
        current = self.pull_back(*current, as_list=True, normalize=False)
        blocks_latest_first.append(np.array([pop.p for pop in current]).T)

    fate_matrix = np.concatenate(blocks_latest_first[::-1])
    # NOTE(review): a row whose pulled-back mass sums to zero becomes NaN here.
    fate_matrix /= fate_matrix.sum(axis=1, keepdims=True)

    # NOTE(review): assumes self.meta rows are ordered by day in the same way
    # as the concatenated blocks above — confirm against TransportMapModel.
    obs = self.meta.copy()
    obs = obs[obs['day'] <= start_day]
    return anndata.AnnData(X=fate_matrix, obs=obs, var=pd.DataFrame(index=names))
def get_population(ids_el):
    """
    Build a normalized Population over the rows of ``df`` whose index labels are in ``ids_el``.

    ``df``, ``day`` and ``Population`` are free names resolved from a scope not
    visible here (presumably an enclosing function — verify against the caller).

    Parameters
    ----------
    ids_el : iterable
        Index labels of ``df`` to include. Labels not present in ``df.index`` are ignored.

    Returns
    -------
    Population or None
        A Population at ``day`` with uniform mass over the matched rows, or
        None when none of the labels are found.
    """
    cell_indices = df.index.get_indexer_for(ids_el)
    # get_indexer_for marks labels that are not in the index with -1.
    cell_indices = cell_indices[cell_indices > -1]
    # Compare by value: `is 0` relies on CPython small-int caching.
    if len(cell_indices) == 0:
        return None
    p = np.zeros(len(df), dtype=np.float64)
    p[cell_indices] = 1.0
    return Population(day, p / np.sum(p))
def transition_table(self, start_populations, end_populations):
    """
    Computes a transition table from the starting populations to the ending populations

    Parameters
    ----------
    self : wot.TransportMapModel
        The TransportMapModel
    start_populations : list of wot.Population
        The source populations such as ones from self.population_from_cell_sets.
        The populations must be from the same time.
    end_populations : list of wot.Population
        The destination populations. The populations must be from the same time,
        and it must be possible to pull them back to the time of start_populations.

    Returns
    -------
    transition table : anndata.AnnData
        Rows : starting populations, Columns : ending populations.
        Entries are normalized so that the whole table sums to one.

    Raises
    ------
    ValueError
        If the end populations cannot be pulled back exactly to the time of the
        start populations.
    """
    start_time = wot.tmap.unique_timepoint(*start_populations)
    # add "other" population if any cells are missing across all populations
    start_populations = Population.copy(*start_populations, normalize=False, add_missing=True)
    end_populations = Population.copy(*end_populations, normalize=False, add_missing=True)
    wot.tmap.unique_timepoint(*end_populations)  # check for unique timepoint

    # Pull the unnormalized end populations back until they live at start_time.
    populations = end_populations
    while self.can_pull_back(*populations) and wot.tmap.unique_timepoint(
            *populations) > start_time:
        populations = self.pull_back(*populations, as_list=True, normalize=False)

    # Without this check, a mismatch would surface as an opaque matmul shape
    # error, or as silently wrong numbers when two days happen to have the
    # same number of cells.
    reached_time = wot.tmap.unique_timepoint(*populations)
    if reached_time != start_time:
        raise ValueError(
            "Unable to pull end populations back to time {}; stopped at time {}".format(
                start_time, reached_time))

    end_p = np.vstack([pop.p for pop in populations])
    start_p = np.vstack([pop.p for pop in start_populations])
    table = start_p @ end_p.T
    table = table / table.sum()
    return anndata.AnnData(
        X=table,
        obs=pd.DataFrame(index=[pop.name for pop in start_populations]),
        var=pd.DataFrame(index=[pop.name for pop in end_populations]))
def trajectories(self, populations):
    """
    Computes a trajectory for each population

    The populations, normalized, are pulled back through every earlier
    transport map and pushed forward through every later one. The per-day
    results are stacked in chronological order.

    Parameters
    ----------
    self : wot.TransportMapModel
        The TransportMapModel used to find ancestors and descendants of the population
    populations : list of wot.Population
        The target populations such as ones from self.population_from_cell_sets.
        The populations must be from the same time.

    Returns
    -------
    trajectories : anndata.AnnData
        Rows : all cells, Columns : populations index.
        At point (i, j) : the probability that cell i is an ancestor/descendant of population j
    """
    wot.tmap.unique_timepoint(*populations)  # check for unique timepoint
    anchor = Population.copy(*populations, normalize=True, add_missing=False)
    population_names = [pop.name for pop in anchor]

    def as_block(pops):
        # One (cells x populations) matrix for a single day.
        return np.array([pop.p for pop in pops]).T

    anchor_block = as_block(anchor)

    ancestors = []  # nearest earlier day first
    current = anchor
    while self.can_pull_back(*current):
        current = self.pull_back(*current, as_list=True)
        ancestors.append(as_block(current))

    descendants = []  # nearest later day first
    current = anchor
    while self.can_push_forward(*current):
        current = self.push_forward(*current, as_list=True)
        descendants.append(as_block(current))

    blocks = ancestors[::-1] + [anchor_block] + descendants
    return anndata.AnnData(X=np.concatenate(blocks),
                           obs=self.meta.copy(),
                           var=pd.DataFrame(index=population_names))
def pull_back(self, *populations, to_time=None, normalize=True, as_list=False):
    """
    Pulls the population back through the computed transport maps

    Parameters
    ----------
    *populations : wot.Population
        Measure over the cells at a given timepoint to be pulled back.
    to_time : int or float, optional
        Destination timepoint to pull back to. Defaults to the previous timepoint.
    normalize : bool, optional, default: True
        Whether to normalize to a probability distribution or keep growth.
    as_list : bool, optional, default: False
        Whether to return a list of length 1 when a single element is passed, or a Population

    Returns
    -------
    result : wot.Population
        The pull back of the input population through the proper transport map.
        List of populations if several populations were given as input.

    Raises
    ------
    ValueError
        If the source or destination timepoint is not one of self.timepoints.
    ValueError
        If there is no previous timepoint to pull the population back.
    ValueError
        If the destination timepoint is after the source timepoint.
    ValueError
        If several populations are given as input but do not live in the same timepoint.

    Examples
    --------
    >>> self.pull_back(pop, to_time = 0) # -> wot.Population

    Pulling back several populations at once

    >>> self.pull_back(pop1, pop2, pop3) # -> list of wot.Population

    Pulling back after pushing forward

    >>> self.pull_back(self.push_forward(pop))

    Same, but several populations at once

    >>> self.pull_back(* self.push_forward(pop1, pop2, pop3))
    """
    # Resolved outside the try blocks so that its own ValueError (populations
    # at different timepoints) is not masked by the "not found" messages.
    source_time = wot.tmap.unique_timepoint(*populations)
    # list.index raises ValueError rather than returning -1, so the lookups are
    # wrapped to produce the intended messages.
    try:
        i = self.timepoints.index(source_time)
    except ValueError:
        raise ValueError("Timepoint not found") from None
    if to_time is None:
        j = i - 1
    else:
        try:
            j = self.timepoints.index(to_time)
        except ValueError:
            raise ValueError("Destination timepoint not found") from None
    if i == 0:
        raise ValueError("No previous timepoints. Unable to pull back")
    if i < j:
        raise ValueError(
            "Destination timepoint is after source. Unable to pull back")

    # One row per population; each step applies the map between consecutive timepoints.
    p = np.vstack([pop.p for pop in populations])
    while i > j:
        t1 = self.timepoints[i]
        t0 = self.timepoints[i - 1]
        tmap = self.get_coupling(t0, t1)
        p = (tmap.X @ p.T).T
        if normalize:
            p = (p.T / np.sum(p, axis=1)).T
        i -= 1

    result = [
        Population(self.timepoints[i], p[k, :])
        for k in range(p.shape[0])
    ]
    if len(result) == 1 and not as_list:
        return result[0]
    else:
        return result