def get_population(ids_el):
    """Build a uniform Population over the cells whose ids appear in ``ids_el``.

    ``df``, ``day`` and ``Population`` are resolved from the enclosing scope.

    Parameters
    ----------
    ids_el : iterable
        Cell identifiers to look up in ``df.index``. Identifiers that are not
        present in the index are ignored.

    Returns
    -------
    Population or None
        A Population at ``day`` whose probability vector puts equal mass on
        every matched row of ``df``. Returns None when none of the ids match.
    """
    cell_indices = df.index.get_indexer_for(ids_el)
    # get_indexer_for marks labels that are missing from the index with -1; drop them.
    cell_indices = cell_indices[cell_indices > -1]
    # Use ``==`` rather than ``is``: an identity test against an int literal
    # only works by accident (small-int caching) and warns on Python 3.8+.
    if len(cell_indices) == 0:
        return None
    p = np.zeros(len(df), dtype=np.float64)
    # Repeated positions just set the same entry to 1.0 again, so every
    # matched cell gets equal weight.
    p[cell_indices] = 1.0
    return Population(day, p / np.sum(p))
def pull_back(self, *populations, to_time=None, normalize=True, as_list=False):
    """
    Pulls the population back through the computed transport maps

    Parameters
    ----------
    *populations : wot.Population
        Measure over the cells at a given timepoint to be pulled back.
    to_time : int or float, optional
        Destination timepoint to pull back to. Defaults to the timepoint
        immediately preceding the populations' timepoint.
    normalize : bool, optional, default: True
        Whether to normalize to a probability distribution or keep growth.
    as_list : bool, optional, default: False
        Whether to return a list of length 1 when a single element is passed,
        or a Population

    Returns
    -------
    result : wot.Population
        The pull back of the input population through the proper transport map.
        List of populations if several populations were given as input.

    Raises
    ------
    ValueError
        If the populations' timepoint is not one of ``self.timepoints``.
    ValueError
        If there is no previous timepoint to pull the population back.
    ValueError
        If ``to_time`` is given but is not one of ``self.timepoints``.
    ValueError
        If the destination timepoint is after the source timepoint.
    ValueError
        If several populations are given as input but do not live in the same timepoint.

    Examples
    --------
    >>> self.pull_back(pop, to_time = 0) # -> wot.Population
    Pulling back several populations at once
    >>> self.pull_back(pop1, pop2, pop3) # -> list of wot.Population
    Pulling back after pushing forward
    >>> self.pull_back(self.push_forward(pop))
    Same, but several populations at once
    >>> self.pull_back(* self.push_forward(pop1, pop2, pop3))
    """
    source_time = wot.tmap.unique_timepoint(*populations)
    # list.index raises ValueError instead of returning -1, so test membership
    # first; otherwise the intended error messages below could never be raised.
    if source_time not in self.timepoints:
        raise ValueError("Timepoint not found")
    i = self.timepoints.index(source_time)
    if i == 0:
        raise ValueError("No previous timepoints. Unable to pull back")
    if to_time is None:
        j = i - 1
    elif to_time in self.timepoints:
        j = self.timepoints.index(to_time)
    else:
        raise ValueError("Destination timepoint not found")
    if i < j:
        raise ValueError(
            "Destination timepoint is after source. Unable to pull back")

    # One row per input population, one column per cell at the current timepoint.
    p = np.vstack([pop.p for pop in populations])
    while i > j:
        t0 = self.timepoints[i - 1]
        t1 = self.timepoints[i]
        tmap = self.get_coupling(t0, t1)
        # For the product to be defined, tmap.X needs one column per cell at t1.
        # Each resulting row is a measure over the rows of tmap.X.
        p = (tmap.X @ p.T).T
        if normalize:
            # Rescale each row to sum to 1. A row summing to 0 would become NaN.
            p = (p.T / np.sum(p, axis=1)).T
        i -= 1

    result = [
        Population(self.timepoints[i], p[k, :]) for k in range(p.shape[0])
    ]
    if len(result) == 1 and not as_list:
        return result[0]
    return result