Example #1
0
    def fates(self, populations):
        """
        Computes fates for each population

        The populations are pulled back, without normalization, through the
        transport maps for as long as ``self.can_pull_back`` allows. The masses
        obtained at each step are stacked from the earliest timepoint reached up
        to the populations' own timepoint, and every row is then rescaled so
        that it sums to one.

        Parameters
        ----------
        self : wot.TransportMapModel
            The TransportMapModel used to find fates
        populations : list of wot.Population
            The target populations such as ones from self.population_from_cell_sets.
            The populations must be from the same time.

        Returns
        -------
        fates : anndata.AnnData
            Rows : cells of ``self.meta`` whose ``day`` is at most the timepoint of
            the populations. Columns : names of the populations returned by
            ``Population.copy(..., add_missing=True)``. At point (i, j) : the
            probability that cell i belongs to population j.
        """
        # Also serves as the check that every population shares one timepoint.
        start_day = wot.tmap.unique_timepoint(*populations)
        current = Population.copy(*populations,
                                  normalize=False,
                                  add_missing=True)
        column_names = [pop.name for pop in current]

        def as_columns(pops):
            # One column per population, one row per cell.
            return np.array([pop.p for pop in pops]).T

        # Collected latest-first while walking backwards, then flipped so the
        # earliest timepoint ends up on top.
        stacked = [as_columns(current)]
        while self.can_pull_back(*current):
            current = self.pull_back(*current, as_list=True, normalize=False)
            stacked.append(as_columns(current))
        stacked.reverse()

        X = np.concatenate(stacked)
        X /= X.sum(axis=1, keepdims=True)

        # NOTE(review): relies on self.meta rows being ordered consistently with
        # the stacked timepoints — confirm against how self.meta is built.
        cell_meta = self.meta.copy()
        cell_meta = cell_meta[cell_meta['day'] <= start_day]
        return anndata.AnnData(X=X,
                               obs=cell_meta,
                               var=pd.DataFrame(index=column_names))
Example #2
0
 def get_population(ids_el):
     """
     Build a uniform Population over the cells of ``ids_el`` found in ``df``.

     Parameters
     ----------
     ids_el : iterable
         Cell ids to look up in ``df.index``. Ids missing from the index are ignored.

     Returns
     -------
     wot.Population or None
         A Population at ``day`` with equal mass on every matched cell
         (summing to one), or None when no id matches.
     """
     cell_indices = df.index.get_indexer_for(ids_el)
     # get_indexer_for marks ids missing from the index with -1; drop them.
     cell_indices = cell_indices[cell_indices > -1]
     # Compare by value: `is 0` only worked through CPython's small-int cache
     # and triggers a SyntaxWarning on Python 3.8+.
     if len(cell_indices) == 0:
         return None
     p = np.zeros(len(df), dtype=np.float64)
     p[cell_indices] = 1.0
     return Population(day, p / np.sum(p))
Example #3
0
    def transition_table(self, start_populations, end_populations):
        """
        Computes a transition table from the starting populations to the ending populations

        The ending populations are pulled back, without normalization, until
        their timepoint is no longer after the starting timepoint (or no further
        pull back is possible). The table is the product of the starting masses
        with the pulled-back ending masses, scaled so that all entries sum to one.

        Parameters
        ----------
        self : wot.TransportMapModel
            The TransportMapModel
        start_populations : list of wot.Population
            The starting populations such as ones from self.population_from_cell_sets.
            The populations must be from the same time.
        end_populations : list of wot.Population
            The ending populations. The populations must be from the same time.

        Returns
        -------
        transition table : anndata.AnnData
            Rows : starting populations, Columns : ending populations.
        """
        start_time = wot.tmap.unique_timepoint(*start_populations)
        # add "other" population if any cells are missing across all populations
        start_populations = Population.copy(*start_populations,
                                            normalize=False,
                                            add_missing=True)
        end_populations = Population.copy(*end_populations,
                                          normalize=False,
                                          add_missing=True)
        wot.tmap.unique_timepoint(
            *end_populations)  # check for unique timepoint

        # Walk the ending populations back towards the starting timepoint.
        populations = end_populations
        while self.can_pull_back(*populations) and wot.tmap.unique_timepoint(
                *populations) > start_time:
            populations = self.pull_back(*populations,
                                         as_list=True,
                                         normalize=False)

        # One row per population, one column per cell.
        end_p = np.vstack([pop.p for pop in populations])
        start_p = np.vstack([pop.p for pop in start_populations])
        # Entry (i, j): mass shared by start population i and the pulled-back
        # end population j; then scale the whole table to sum to one.
        table = start_p @ end_p.T
        table = table / table.sum()
        return anndata.AnnData(
            X=table,
            obs=pd.DataFrame(index=[pop.name for pop in start_populations]),
            var=pd.DataFrame(index=[pop.name for pop in end_populations]))
Example #4
0
    def trajectories(self, populations):
        """
        Computes a trajectory for each population

        The normalized populations are pulled back for as long as
        ``self.can_pull_back`` allows and pushed forward for as long as
        ``self.can_push_forward`` allows. The per-timepoint results are stacked
        from the earliest timepoint reached to the latest.

        Parameters
        ----------
        self : wot.TransportMapModel
            The TransportMapModel used to find ancestors and descendants of the population
        populations : list of wot.Population
            The target populations such as ones from self.population_from_cell_sets.
            The populations must be from the same time.

        Returns
        -------
        trajectories : anndata.AnnData
            Rows : all cells, Columns : populations index. At point (i, j) : the probability that cell i is an
            ancestor/descendant of population j
        """
        wot.tmap.unique_timepoint(*populations)  # check for unique timepoint

        seeds = Population.copy(*populations,
                                normalize=True,
                                add_missing=False)
        column_names = [pop.name for pop in seeds]

        def as_columns(pops):
            # One column per population, one row per cell.
            return np.array([pop.p for pop in pops]).T

        seed_block = as_columns(seeds)

        # Ancestors: gathered latest-first while walking backwards in time.
        ancestor_blocks = []
        current = seeds
        while self.can_pull_back(*current):
            current = self.pull_back(*current, as_list=True)
            ancestor_blocks.append(as_columns(current))

        # Descendants: gathered in chronological order, restarting from the seeds.
        descendant_blocks = []
        current = seeds
        while self.can_push_forward(*current):
            current = self.push_forward(*current, as_list=True)
            descendant_blocks.append(as_columns(current))

        # Earliest ancestors first, then the seeds, then the descendants.
        ordered = ancestor_blocks[::-1] + [seed_block] + descendant_blocks
        return anndata.AnnData(X=np.concatenate(ordered),
                               obs=self.meta.copy(),
                               var=pd.DataFrame(index=column_names))
Example #5
0
    def pull_back(self,
                  *populations,
                  to_time=None,
                  normalize=True,
                  as_list=False):
        """
        Pulls the population back through the computed transport maps

        Parameters
        ----------
        *populations : wot.Population
            Measure over the cells at a given timepoint to be pulled back.
        to_time : int or float, optional
            Destination timepoint to pull back to. Defaults to the timepoint
            immediately before the one of the populations.
        normalize : bool, optional, default: True
            Whether to normalize to a probability distribution or keep growth.
        as_list : bool, optional, default: False
            Whether to return a list of length 1 when a single element is passed, or a Population

        Returns
        -------
        result : wot.Population
            The pull back of the input population through the proper transport map.
            List of populations if several populations were given as input.

        Raises
        ------
        ValueError
            If the timepoint of the populations, or the destination timepoint,
            is not one of ``self.timepoints``.
        ValueError
            If there is no previous timepoint to pull the population back.
        ValueError
            If the destination timepoint is after the source timepoint.
        ValueError
            If several populations are given as input but do not live in the same timepoint.

        Examples
        --------
        >>> self.pull_back(pop, to_time = 0) # -> wot.Population
        Pulling back several populations at once
        >>> self.pull_back(pop1, pop2, pop3) # -> list of wot.Population
        Pulling back after pushing forward
        >>> self.pull_back(self.push_forward(pop))
        Same, but several populations at once
        >>> self.pull_back(* self.push_forward(pop1, pop2, pop3))
        """
        source_time = wot.tmap.unique_timepoint(*populations)

        # `.index` raises ValueError for a missing value (it never returns -1),
        # so translate that into the intended, more descriptive message.
        try:
            i = self.timepoints.index(source_time)
        except ValueError as err:
            raise ValueError("Timepoint not found") from err
        if i == 0:
            raise ValueError("No previous timepoints. Unable to pull back")

        if to_time is None:
            j = i - 1
        else:
            try:
                j = self.timepoints.index(to_time)
            except ValueError as err:
                raise ValueError("Destination timepoint not found") from err
        if i < j:
            raise ValueError(
                "Destination timepoint is after source. Unable to pull back")

        # One row per population, one column per cell.
        p = np.vstack([pop.p for pop in populations])
        # Step back one transport map at a time until the destination index.
        while i > j:
            t1 = self.timepoints[i]
            t0 = self.timepoints[i - 1]
            tmap = self.get_coupling(t0, t1)
            p = (tmap.X @ p.T).T
            if normalize:
                # Rescale each population (row) to sum to one.
                p = (p.T / np.sum(p, axis=1)).T
            i -= 1

        result = [
            Population(self.timepoints[i], p[k, :]) for k in range(p.shape[0])
        ]
        if len(result) == 1 and not as_list:
            return result[0]
        return result