Ejemplo n.º 1
0
    def maybe_create_lineage(
        direction: Union[str, Direction], pretty_name: Optional[str] = None
    ):
        """
        Wrap ``adata.obsm[<key>]`` in a :class:`Lineage`, if that key exists.

        ``adata`` is taken from the enclosing scope. Lineage names and colors are loaded from
        ``adata.uns``. When they are missing, have the wrong length, or (for colors) are not
        color-like, new ones are created. The final names and colors are written back to
        ``adata.uns``.

        Parameters
        ----------
        direction
            Either a :class:`Direction`, which selects the forward or backward absorption
            probability key, or the ``adata.obsm`` key itself.
        pretty_name
            Optional prefix used in the log messages.

        Returns
        -------
        Nothing, just updates ``adata.obsm`` and ``adata.uns`` in place.
        """
        if isinstance(direction, Direction):
            lin_key = str(
                AbsProbKey.FORWARD
                if direction == Direction.FORWARD
                else AbsProbKey.BACKWARD
            )
        else:
            lin_key = direction

        pretty_name = "" if pretty_name is None else (pretty_name + " ")
        names_key, colors_key = _lin_names(lin_key), _colors(lin_key)

        if lin_key not in adata.obsm.keys():
            logg.debug(
                f"Unable to load {pretty_name}`Lineage` from `adata.obsm[{lin_key!r}]`"
            )
            return

        # unpacking also enforces that the stored array is 2-dimensional (cells x lineages)
        _, n_lineages = adata.obsm[lin_key].shape
        logg.info(f"Creating {pretty_name}`Lineage` from `adata.obsm[{lin_key!r}]`")

        if names_key not in adata.uns.keys():
            logg.warning(
                f"    Lineage names not found in `adata.uns[{names_key!r}]`, creating new names"
            )
            names = [f"Lineage {i}" for i in range(n_lineages)]
        elif len(adata.uns[names_key]) != n_lineages:
            logg.warning(
                f"    Lineage names don't have the required length ({n_lineages}), creating new names"
            )
            names = [f"Lineage {i}" for i in range(n_lineages)]
        else:
            logg.info("    Successfully loaded names")
            names = adata.uns[names_key]

        if colors_key not in adata.uns.keys():
            logg.warning(
                f"    Lineage colors not found in `adata.uns[{colors_key!r}]`, creating new colors"
            )
            colors = _create_categorical_colors(n_lineages)
        elif len(adata.uns[colors_key]) != n_lineages or not all(
            is_color_like(c) for c in adata.uns[colors_key]
        ):
            logg.warning(
                f"    Lineage colors don't have the required length ({n_lineages}) "
                f"or are not color-like, creating new colors"
            )
            colors = _create_categorical_colors(n_lineages)
        else:
            logg.info("    Successfully loaded colors")
            colors = adata.uns[colors_key]

        adata.obsm[lin_key] = Lineage(
            adata.obsm[lin_key], names=names, colors=colors
        )
        # keep `.uns` in sync with the (possibly regenerated) annotations
        adata.uns[colors_key] = colors
        adata.uns[names_key] = names
Ejemplo n.º 2
0
    def _reconstruct_lineage(self, attr: PrettyEnum, obsm_key: str):
        """
        Restore a :class:`Lineage` attribute from :attr:`adata`.

        The values are loaded from ``adata.obsm[obsm_key]`` into ``attr``. Lineage names and
        colors are loaded from ``adata.uns``. Names or colors that do not match the number of
        lineages, and colors that are not color-like strings, are replaced. The replacements come
        from the loaded values if those already are a :class:`Lineage`; otherwise new ones are
        created. The resulting :class:`Lineage` is stored in ``attr`` and written back to
        ``adata.obsm`` and ``adata.uns``.

        Parameters
        ----------
        attr
            Attribute under which the lineage is stored.
        obsm_key
            Key in ``adata.obsm`` from which the lineage is loaded and to which it is written.

        Returns
        -------
        Nothing, just updates the object and :attr:`adata`.
        """
        self._set_or_debug(obsm_key, self.adata.obsm, attr)
        names_key, colors_key = _lin_names(self._term_key), _colors(self._term_key)
        # NOTE(review): `_set_or_debug` presumably returns `None` for missing keys; treat that as a mismatch
        names = self._set_or_debug(names_key, self.adata.uns)
        colors = self._set_or_debug(colors_key, self.adata.uns)

        probs = self._get(attr)
        if probs is None:
            return

        n_lineages = probs.shape[1]

        if names is None or len(names) != n_lineages:
            if isinstance(probs, Lineage):
                names = probs.names
            else:
                n_names = 0 if names is None else len(names)
                logg.warning(
                    f"Expected lineage names to be of length `{n_lineages}`, found `{n_names}`. "
                    f"Creating new names"
                )
                names = [f"Lineage {i}" for i in range(n_lineages)]

        if (
            colors is None
            or len(colors) != n_lineages
            or not all(isinstance(c, str) and is_color_like(c) for c in colors)
        ):
            if isinstance(probs, Lineage):
                colors = probs.colors
            else:
                # report the number of colors (previously this wrongly reported `len(names)`)
                n_colors = 0 if colors is None else len(colors)
                logg.warning(
                    f"Expected lineage colors to be of length `{n_lineages}`, found `{n_colors}`. "
                    f"Creating new colors"
                )
                colors = _create_categorical_colors(n_lineages)

        self._set(attr, Lineage(probs, names=names, colors=colors))

        self.adata.obsm[obsm_key] = self._get(attr)
        self.adata.uns[names_key] = names
        self.adata.uns[colors_key] = colors
Ejemplo n.º 3
0
    def test_automatic_color_assignment(self):
        """Without explicit colors, a `Lineage` uses the hex-encoded categorical palette."""
        data = np.random.random((10, 3))
        lineage = Lineage(data, names=["foo", "bar", "baz"])

        expected = list(map(colors.to_hex, _create_categorical_colors(3)))

        np.testing.assert_array_equal(lineage.colors, expected)
Ejemplo n.º 4
0
    def _reconstruct_lineage(self, attr: PrettyEnum, obsm_key: str):
        """
        Restore a :class:`Lineage` attribute from :attr:`adata`.

        The values are loaded from ``adata.obsm[obsm_key]`` into ``attr``. Lineage names and
        colors are loaded from ``adata.uns``. Names or colors that do not match the number of
        lineages, and colors that are not color-like strings, are recreated. The resulting
        :class:`Lineage` is stored in ``attr`` and written back to ``adata.obsm`` and ``adata.uns``.

        Parameters
        ----------
        attr
            Attribute under which the lineage is stored.
        obsm_key
            Key in ``adata.obsm`` from which the lineage is loaded and to which it is written.

        Returns
        -------
        Nothing, just updates the object and :attr:`adata`.
        """
        self._set_or_debug(obsm_key, self.adata.obsm, attr)
        names_key, colors_key = _lin_names(self._term_key), _colors(self._term_key)
        # NOTE(review): `_set_or_debug` presumably returns `None` for missing keys; treat that as a mismatch
        names = self._set_or_debug(names_key, self.adata.uns)
        colors = self._set_or_debug(colors_key, self.adata.uns)

        # choosing this instead of property because GPCCA doesn't have property for FIN_ABS_PROBS
        probs = self._get(attr)
        if probs is None:
            return

        n_lineages = probs.shape[1]

        if names is None or len(names) != n_lineages:
            n_names = 0 if names is None else len(names)
            logg.debug(
                f"Expected lineage names to be of length `{n_lineages}`, found `{n_names}`. "
                f"Creating new names")
            names = [f"Lineage {i}" for i in range(n_lineages)]

        if (
            colors is None
            or len(colors) != n_lineages
            or not all(isinstance(c, str) and is_color_like(c) for c in colors)
        ):
            # report the number of colors (previously this wrongly reported `len(names)`)
            n_colors = 0 if colors is None else len(colors)
            logg.debug(
                f"Expected lineage colors to be of length `{n_lineages}`, found `{n_colors}`. "
                f"Creating new colors")
            colors = _create_categorical_colors(n_lineages)

        self._set(attr, Lineage(probs, names=names, colors=colors))

        self.adata.obsm[obsm_key] = self._get(attr)
        self.adata.uns[names_key] = names
        self.adata.uns[colors_key] = colors
Ejemplo n.º 5
0
    def create_col_categorical_color(cluster_key: str,
                                     rng: np.ndarray) -> np.ndarray:
        """
        Return the hex color of the cluster of every selected cell.

        ``adata`` and ``time_key`` come from the enclosing scope. Cells are selected with
        ``find_indices(adata.obs[time_key], rng)``. Colors are read from
        ``adata.uns['{cluster_key}_colors']``; if that key is missing, new colors are created
        and stored there.

        Raises
        ------
        TypeError
            If ``adata.obs[cluster_key]`` is not categorical.
        """
        clusters = adata.obs[cluster_key]
        if not is_categorical_dtype(clusters):
            raise TypeError(
                f"Expected `adata.obs[{cluster_key!r}]` to be categorical, "
                f"found `{clusters.dtype.name!r}`.")

        categories = clusters.cat.categories
        color_key = f"{cluster_key}_colors"
        if color_key in adata.uns:
            colors = adata.uns[color_key]
        else:
            logg.warning(
                f"Color key `{color_key!r}` not found in `adata.uns`. Creating new colors"
            )
            colors = _create_categorical_colors(len(categories))
            adata.uns[color_key] = colors

        selected = find_indices(adata.obs[time_key], rng)
        color_of = dict(zip(categories, colors))
        labels = adata[selected, :].obs[cluster_key].values

        return np.array([mcolors.to_hex(color_of[label]) for label in labels])
Ejemplo n.º 6
0
    def test_no_colors(self, adata: AnnData, path: Path, lin_key: str, n_lins: int):
        """Missing lineage colors are recreated on read and written back to ``.uns``."""
        key = _colors(lin_key)
        del adata.uns[key]

        sc.write(path, adata)
        restored = cr.read(path)
        lineage = restored.obsm[lin_key]

        assert isinstance(lineage, Lineage)
        np.testing.assert_array_equal(
            lineage.colors, _create_categorical_colors(n_lins)
        )
        np.testing.assert_array_equal(lineage.colors, restored.uns[key])
Ejemplo n.º 7
0
 def cmap(self) -> Mapping[str, Any]:
     """
     Colormap for :attr:`clusters`.

     Colors are read from ``adata.uns['{cluster_key}_colors']``. Only when that entry is
     missing are new colors created, one per category.
     """
     categories = self.clusters.cat.categories
     colors = self._adata.uns.get(f"{self._ckey}_colors")
     if colors is None:
         # create lazily: passing the palette as the `.get` default would build it on every
         # call, and `_create_categorical_colors` raises for very many categories even when
         # colors were already stored
         colors = _create_categorical_colors(len(categories))
     return dict(zip(categories, colors))
Ejemplo n.º 8
0
    def colors(self, value: Optional[Iterable[ColorLike]]):
        """
        Set the lineage colors.

        Parameters
        ----------
        value
            Colors, one per lineage. If `None`, ``self._n_lineages`` new colors are created.

        Raises
        ------
        TypeError
            If ``value`` is neither `None` nor iterable.
        """
        if value is not None and not isinstance(value, Iterable):
            raise TypeError(_ERROR_NOT_ITERABLE.format("colors", type(value).__name__))
        if value is None:
            value = _create_categorical_colors(self._n_lineages)

        checked = self._check_shape(value, _ERROR_WRONG_SIZE.format("colors"))
        # `checker`/`transformer` presumably act per element: validate each color, store it as hex
        self._colors = self._prepare_annotation(
            checked,
            checker=c.is_color_like,
            transformer=c.to_hex,
            checker_msg="Value `{}` is not a valid color.",
        )
Ejemplo n.º 9
0
    def test_normal_run(self, adata: AnnData, path: Path, lin_key: str, n_lins: int):
        """Lineage names and colors stored in ``.uns`` survive a write/read round trip."""
        expected_colors = _create_categorical_colors(10)[-n_lins:]
        expected_names = [f"foo {i}" for i in range(n_lins)]

        adata.uns[_colors(lin_key)] = expected_colors
        adata.uns[_lin_names(lin_key)] = expected_names

        sc.write(path, adata)
        restored = cr.read(path).obsm[lin_key]

        np.testing.assert_array_equal(restored.colors, expected_colors)
        np.testing.assert_array_equal(restored.names, expected_names)
Ejemplo n.º 10
0
def _get_categorical_colors(
    adata: AnnData, cluster_key: str
) -> Tuple[np.ndarray, Mapping[str, str]]:
    """
    Get the colors of a categorical annotation, creating them if needed.

    Parameters
    ----------
    adata
        Annotated data object.
    cluster_key
        Key in ``adata.obs`` containing categorical annotations. Colors are read from
        ``adata.uns['{cluster_key}_colors']``; if that key is missing, new colors are created
        and stored there.

    Returns
    -------
    The colors and a mapping from each category to its color, where `NaN` maps to ``'grey'``.

    Raises
    ------
    KeyError
        If ``cluster_key`` is not in ``adata.obs``.
    TypeError
        If ``adata.obs[cluster_key]`` is not categorical.
    """
    if cluster_key not in adata.obs:
        raise KeyError(f"Unable to find data in `adata.obs[{cluster_key!r}]`.")
    if not is_categorical_dtype(adata.obs[cluster_key]):
        raise TypeError(
            f"Expected `adata.obs[{cluster_key!r}]` to be categorical, "
            f"found `{infer_dtype(adata.obs[cluster_key])}`."
        )

    color_key = f"{cluster_key}_colors"
    try:
        colors = adata.uns[color_key]
    except KeyError:
        colors = _create_categorical_colors(len(adata.obs[cluster_key].cat.categories))
        adata.uns[color_key] = colors

    mapper = dict(zip(adata.obs[cluster_key].cat.categories, colors))
    # NOTE: since `nan != nan`, this key is only found when looking up the `np.nan` object itself
    mapper[np.nan] = "grey"

    return colors, mapper
Ejemplo n.º 11
0
 def test_create_categorical_colors_too_many_colors(self):
     """Requesting 1000 colors raises a `ValueError`."""
     pytest.raises(ValueError, _create_categorical_colors, 1000)
Ejemplo n.º 12
0
def _create_root_final_annotations(
    adata: AnnData,
    final_key: str = "terminal_states",
    root_key: str = "initial_states",
    final_pref: Optional[str] = "terminal",
    root_pref: Optional[str] = "initial",
    key_added: Optional[str] = "initial_terminal",
) -> None:
    """
    Create categorical annotations of both root and final states.

    This is a utility function for creating a categorical Series object which combines the information about root
    and final states. The Series is written directly to the AnnData object. This can for example be used to create a
    scatter plot in scvelo.

    Parameters
    ----------
    adata
        AnnData object to write to (`.obs[key_added]`).
    final_key
        Key from `.obs` where final states have been saved.
    root_key
        Key from `.obs` where root states have been saved.
    final_pref, root_pref
        Prefixes used in the annotations of final and root states, respectively.
    key_added
        Key added to `adata.obs`. The merged colors are written to `adata.uns['{key_added}_colors']`.

    Returns
    -------
    Nothing, just writes to AnnData.
    """
    from cellrank.tl._utils import _merge_categorical_series
    from cellrank.tl._colors import _create_categorical_colors

    if f"{final_key}_colors" not in adata.uns:
        adata.uns[f"{final_key}_colors"] = _create_categorical_colors(
            len(adata.obs[final_key].cat.categories))

    if f"{root_key}_colors" not in adata.uns:
        # one color per root category, taken from the reversed end of a 30-color palette
        # (presumably to avoid reusing the colors assigned to the final states)
        n_root = len(adata.obs[root_key].cat.categories)
        adata.uns[f"{root_key}_colors"] = _create_categorical_colors(30)[::-1][:n_root]

    # get both Series objects
    cats_final, colors_final = adata.obs[final_key], adata.uns[f"{final_key}_colors"]
    cats_root, colors_root = adata.obs[root_key], adata.uns[f"{root_key}_colors"]

    # merge
    cats_merged, colors_merged = _merge_categorical_series(
        cats_final, cats_root, list(colors_final), list(colors_root))

    # adjust the names
    final_names = cats_final.cat.categories
    final_labels = [
        f"{final_pref if key in final_names else root_pref}: {key}"
        for key in cats_merged.cat.categories
    ]
    # `rename_categories(..., inplace=True)` is no longer supported by pandas >= 2.0
    cats_merged = cats_merged.cat.rename_categories(final_labels)

    # write to AnnData
    adata.obs[key_added] = cats_merged
    adata.uns[f"{key_added}_colors"] = colors_merged
Ejemplo n.º 13
0
    def plot_macrostate_composition(
        self,
        key: str,
        width: float = 0.8,
        title: Optional[str] = None,
        labelrot: float = 45,
        legend_loc: Optional[str] = "upper right out",
        figsize: Optional[Tuple[float, float]] = None,
        dpi: Optional[int] = None,
        save: Optional[Union[str, Path]] = None,
        show: bool = True,
    ) -> Optional[Axes]:
        """
        Plot stacked histogram of macrostates over categorical annotations.

        Parameters
        ----------
        %(adata)s
        key
            Key from :attr:`adata` ``.obs`` containing categorical annotations.
        width
            Bar width in `[0, 1]`.
        title
            Title of the figure. If `None`, create one automatically.
        labelrot
            Rotation of labels on x-axis.
        legend_loc
            Position of the legend. If `None`, don't show legend.
        %(plotting)s
        show
            If `False`, return :class:`matplotlib.pyplot.Axes`.

        Returns
        -------
        :class:`matplotlib.pyplot.Axes`
            The axis object if ``show=False``.
        %(just_plots)s

        Raises
        ------
        RuntimeError
            If the macrostates have not been computed yet.
        KeyError
            If ``key`` is not in ``adata.obs``.
        TypeError
            If ``adata.obs[key]`` is not categorical.
        """
        from cellrank.pl._utils import _position_legend

        macrostates = self._get(P.MACRO)
        if macrostates is None:
            raise RuntimeError(
                "Compute macrostates first as `.compute_macrostates()`.")
        if key not in self.adata.obs:
            raise KeyError(f"Key `{key}` not found in `adata.obs`.")
        if not is_categorical_dtype(self.adata.obs[key]):
            raise TypeError(
                f"Expected `adata.obs[{key!r}]` to be `categorical`, "
                f"found `{infer_dtype(self.adata.obs[key])}`.")

        categories = self.adata.obs[key].cat.categories
        macro_categories = macrostates.cat.categories

        mask = ~macrostates.isnull()
        df = pd.DataFrame({
            "macrostates": macrostates,
            key: self.adata.obs[key]
        })[mask]
        # dense (category x macrostate) count table: every category gets exactly one count per
        # macrostate, so the bars always align with `x_indices`, whatever pandas' `observed` default
        counts = (
            df.groupby([key, "macrostates"], observed=False)
            .size()
            .unstack(fill_value=0)
            .reindex(index=categories, columns=macro_categories, fill_value=0)
        )

        try:
            cats_colors = self.adata.uns[f"{key}_colors"]
        except KeyError:
            cats_colors = _create_categorical_colors(len(categories))
        cat_color_mapper = dict(zip(categories, cats_colors))
        x_indices = np.arange(len(macro_categories))
        bottom = np.zeros_like(x_indices, dtype=np.float32)

        width = min(1, max(0, width))
        fig, ax = plt.subplots(figsize=figsize, dpi=dpi, tight_layout=True)
        for cat, color in cat_color_mapper.items():
            frequencies = np.asarray(counts.loc[cat])
            # do not add to legend if category is missing
            if np.sum(frequencies) > 0:
                ax.bar(
                    x_indices,
                    frequencies,
                    width,
                    label=cat,
                    color=color,
                    bottom=bottom,
                    ec="black",
                    lw=0.5,
                )
                bottom += frequencies

        ax.set_xticks(x_indices)
        ax.set_xticklabels(
            macro_categories,
            rotation=labelrot,
            ha="center" if labelrot in (0, 90) else "right",
        )
        y_max = bottom.max()
        ax.set_ylim([0, y_max + 0.05 * y_max])
        ax.set_yticks(np.linspace(0, y_max, 5))
        ax.margins(0.05)

        ax.set_xlabel("macrostate")
        ax.set_ylabel("frequency")
        if title is None:
            title = f"distribution over {key}"
        ax.set_title(title)
        if legend_loc not in (None, "none"):
            _position_legend(ax, legend_loc=legend_loc)

        if save is not None:
            save_fig(fig, save)

        if not show:
            return ax
Ejemplo n.º 14
0
    def _plot_discrete(
        self,
        data: pd.Series,
        prop: str,
        lineages: Optional[Union[str, Sequence[str]]] = None,
        cluster_key: Optional[str] = None,
        same_plot: bool = True,
        title: Optional[Union[str, List[str]]] = None,
        **kwargs,
    ) -> None:
        """
        Plot the states for each uncovered lineage.

        Parameters
        ----------
        data
            Categorical states to plot.
        prop
            Name of the property the states come from; used in error messages, titles and to
            select the colors.
        lineages
            Plot only these lineages. If `None`, plot all lineages.
        cluster_key
            Key from :attr:`adata` ``.obs`` for plotting categorical observations.
        same_plot
            Whether to plot the lineages on the same plot or separately.
        title
            The title of the plot.
        %(basis)s
        kwargs
            Keyword arguments for :func:`scvelo.pl.scatter`.

        Returns
        -------
        %(just_plots)s
        """

        if data is None:
            raise RuntimeError(
                f"Compute `.{prop}` first as `.{F.COMPUTE.fmt(prop)}()`.")
        if not is_categorical_dtype(data):
            raise TypeError(
                f"Expected property `.{prop}` to be categorical, found `{type(data).__name__!r}`."
            )

        # NOTE(review): the first check compares against `.s`, the second against `.v` - confirm intended
        if prop in (P.ABS_PROBS.s, P.TERM.s):
            colors = getattr(self, A.TERM_COLORS.v, None)
        elif prop == P.MACRO.v:
            colors = getattr(self, A.MACRO_COLORS.v, None)
        else:
            colors = None
        # the attributes may be unset (`None`) or too short, which would break the color mapping below
        if colors is None or len(colors) < len(data.cat.categories):
            logg.debug("No colors found. Creating new ones")
            colors = _create_categorical_colors(len(data.cat.categories))
        colors = dict(zip(data.cat.categories, colors))

        # these are states per-se, but I want to keep the arg names for dispatch the same
        if lineages is not None:
            if isinstance(lineages, str):
                lineages = [lineages]
            for state in lineages:
                if state not in data.cat.categories:
                    raise ValueError(
                        f"Invalid state `{state!r}`. Valid options are `{list(data.cat.categories)}`."
                    )
            data = data.copy()
            to_remove = list(set(data.cat.categories) - set(lineages))

            if len(to_remove) == len(data.cat.categories):
                raise RuntimeError(
                    "Nothing to plot because empty subset has been selected.")

            for state in to_remove:
                data[data == state] = np.nan
            data = data.cat.remove_categories(to_remove)

        if cluster_key is None:
            cluster_key = []
        elif isinstance(cluster_key, str):
            cluster_key = [cluster_key]
        if not isinstance(cluster_key, list):
            cluster_key = list(cluster_key)

        same_plot = same_plot or len(data.cat.categories) == 1
        kwargs["legend_loc"] = kwargs.get("legend_loc", "on data")

        with RandomKeys(self.adata,
                        None if same_plot else len(data.cat.categories),
                        where="obs") as keys:
            if same_plot:
                key = keys[0]
                self.adata.obs[key] = data
                self.adata.uns[f"{key}_colors"] = [
                    colors[c] for c in data.cat.categories
                ]

                if title is None:
                    title = (
                        f"{prop.replace('_', ' ')} "
                        f"({Direction.BACKWARD if self.kernel.backward else Direction.FORWARD})"
                    )
                if isinstance(title, str):
                    title = [title]

                scv.pl.scatter(
                    self.adata,
                    title=cluster_key + title,
                    color=cluster_key + keys,
                    **_filter_kwargs(scv.pl.scatter, **kwargs),
                )
            else:
                # one temporary `.obs` column per state, containing only that state
                for key, cat in zip(keys, data.cat.categories):
                    d = data.copy()
                    d[data != cat] = None
                    d = d.cat.set_categories([cat])

                    self.adata.obs[key] = d
                    self.adata.uns[f"{key}_colors"] = [colors[cat]]

                scv.pl.scatter(
                    self.adata,
                    color=cluster_key + keys,
                    title=(cluster_key + [
                        f"{_initial if self.kernel.backward else _terminal} state {c}"
                        for c in data.cat.categories
                    ]) if title is None else title,
                    **_filter_kwargs(scv.pl.scatter, **kwargs),
                )
Ejemplo n.º 15
0
def _trends_helper(
    models: Dict[str, Dict[str, Any]],
    gene: str,
    transpose: bool = False,
    lineage_names: Optional[Sequence[str]] = None,
    same_plot: bool = False,
    sharey: Union[str, bool] = False,
    show_ylabel: bool = True,
    show_lineage: Union[bool, np.ndarray] = True,
    show_xticks_and_label: Union[bool, np.ndarray] = True,
    lineage_cmap: Optional[Union[mpl.colors.ListedColormap, Sequence]] = None,
    lineage_probability_color: Optional[str] = None,
    abs_prob_cmap=cm.viridis,
    gene_as_title: bool = False,
    legend_loc: Optional[str] = "best",
    fig: mpl.figure.Figure = None,
    axes: Union[mpl.axes.Axes, Sequence[mpl.axes.Axes]] = None,
    **kwargs,
) -> None:
    """
    Plot an expression gene for some lineages.

    Parameters
    ----------
    models
        Fitted models, indexed as ``models[gene][lineage]``.
    gene
        Name of the gene in ``adata.var_names``.
    transpose
        Only used when ``same_plot=True``. If `True`, ``lineage_cmap`` colors the genes instead of the lineages.
    lineage_names
        Names of lineages to plot.
    same_plot
        Whether to plot all lineages in the same plot or separately.
    sharey
        Whether the y-axis is being shared.
    show_ylabel
        Whether to show y-label on the y-axis. Usually, only the first column will contain the y-label.
    show_lineage
        Whether to show the lineage as the title. Usually, only first row will contain the lineage names.
    show_xticks_and_label
        Whether to show x-ticks and x-label. Usually, only the last row will show this.
    lineage_cmap
        Colormap to use when coloring the lineage. When ``transpose``, this corresponds to the color of genes.
    lineage_probability_color
        Actual color of 1 ``lineage``. Only used when ``same_plot=True`` and ``transpose=True`` and
        ``lineage_probability=True``.
    abs_prob_cmap
        Colormap to use when coloring in the absorption probabilities, if they are being plotted.
    gene_as_title
        Whether to use the gene names as titles (with lineage names as well) or on the y-axis.
    legend_loc
        Location of the legend. If `None`, don't show any legend.
    fig
        Figure to use.
    axes
        Axes to use. If ``same_plot=True``, a single axis which is shared by all lineages.
    **kwargs
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.plot`.

    Returns
    -------
    %(just_plots)s
    """

    n_lineages = len(lineage_names)
    if same_plot:
        axes = [axes] * len(lineage_names)

    fig.tight_layout()
    axes = np.ravel(axes)

    percs = kwargs.pop("perc", None)
    if percs is None or not isinstance(percs[0], (tuple, list)):
        percs = [percs]

    same_perc = False  # we need to show colorbar always if percs differ
    if len(percs) != n_lineages or n_lineages == 1:
        if len(percs) != 1:
            raise ValueError(
                f"Percentile must be a collection of size `1` or `{n_lineages}`, got `{len(percs)}`."
            )
        same_perc = True
        percs = percs * n_lineages

    hide_cells = kwargs.pop("hide_cells", False)
    show_cbar = kwargs.pop("cbar", True)
    show_prob = kwargs.pop("lineage_probability", False)

    if same_plot:
        if not transpose:
            lineage_colors = (lineage_cmap.colors if lineage_cmap is not None
                              and hasattr(lineage_cmap, "colors") else
                              lineage_cmap)
        else:
            # this should be fine w.r.t. to the missing genes, since they are in the same order AND
            # we're also passing the failed models (this is important)
            # these are actually gene colors
            if lineage_cmap is not None:
                lineage_colors = (
                    lineage_cmap.colors if hasattr(lineage_cmap, "colors") else
                    [c for _, c in zip(lineage_names, lineage_cmap)])
            else:
                lineage_colors = _create_categorical_colors(n_lineages)
    else:
        lineage_colors = (("black" if not mcolors.is_color_like(lineage_cmap)
                           else lineage_cmap), ) * n_lineages

    if n_lineages > len(lineage_colors):
        raise ValueError(
            f"Expected at least `{n_lineages}` colors, found `{len(lineage_colors)}`."
        )

    lineage_color_mapper = {
        ln: lineage_colors[i]
        for i, ln in enumerate(lineage_names)
    }

    successful_models = {
        ln: models[gene][ln]
        for ln in lineage_names if models[gene][ln]
    }

    if show_prob and same_plot:
        minns, maxxs = zip(*[
            models[gene][n]._return_min_max(show_conf_int=kwargs.get(
                "conf_int", False), ) for n in lineage_names
        ])
        minn, maxx = min(minns), max(maxxs)
        kwargs["loc"] = legend_loc
        kwargs["scaler"] = lambda x: (x - minn) / (maxx - minn)
    else:
        kwargs["loc"] = None

    if isinstance(show_xticks_and_label, bool):
        show_xticks_and_label = [show_xticks_and_label] * len(lineage_names)
    elif len(show_xticks_and_label) != len(lineage_names):
        raise ValueError(
            f"Expected `show_xticks_label` to be the same length as `lineage_names`, "
            f"found `{len(show_xticks_and_label)}` != `{len(lineage_names)}`.")

    if isinstance(show_lineage, bool):
        show_lineage = [show_lineage] * len(lineage_names)
    elif len(show_lineage) != len(lineage_names):
        raise ValueError(
            f"Expected `show_lineage` to be the same length as `lineage_names`, "
            f"found `{len(show_lineage)}` != `{len(lineage_names)}`.")

    # last axis that was actually plotted into; stays `None` if every model failed
    last_ax = None
    ylabel_shown = False
    cells_shown = False

    for i, (name, ax, perc) in enumerate(zip(lineage_names, axes, percs)):
        model = models[gene][name]
        if isinstance(model, FailedModel):
            if not same_plot:
                ax.remove()
            continue

        if same_plot:
            if gene_as_title:
                title = gene
                ylabel = "expression" if show_ylabel else None
            else:
                title = ""
                ylabel = gene
        else:
            if gene_as_title:
                title = None
                ylabel = "expression" if not ylabel_shown else None
            else:
                title = ((name if name is not None else "no lineage")
                         if show_lineage[i] else "")
                ylabel = gene if not ylabel_shown else None

        model.plot(
            ax=ax,
            fig=fig,
            perc=perc,
            cbar=False,
            title=title,
            # cells are drawn only once (on the first successfully plotted axis)
            hide_cells=cells_shown if not hide_cells else True,
            same_plot=same_plot,
            lineage_color=lineage_color_mapper[name],
            lineage_probability_color=lineage_probability_color,
            abs_prob_cmap=abs_prob_cmap,
            lineage_probability=show_prob,
            ylabel=ylabel,
            **kwargs,
        )
        if sharey in ("row", "all", True) and not ylabel_shown:
            plt.setp(ax.get_yticklabels(), visible=True)

        if show_xticks_and_label[i]:
            plt.setp(ax.get_xticklabels(), visible=True)
        else:
            ax.set_xlabel(None)

        last_ax = ax
        ylabel_shown = True
        cells_shown = True

    # a shared colorbar needs at least one plotted model (`np.min` of an empty list raises)
    if (
        not same_plot
        and same_perc
        and show_cbar
        and not hide_cells
        and successful_models
        and last_ax is not None
    ):
        vmin = np.min([model.w_all for model in successful_models.values()])
        vmax = np.max([model.w_all for model in successful_models.values()])
        norm = mcolors.Normalize(vmin=vmin, vmax=vmax)

        for ax in axes:
            children = [
                c for c in ax.get_children()
                if isinstance(c, mpl.collections.PathCollection)
            ]
            if len(children):
                children[0].set_norm(norm)

        divider = make_axes_locatable(last_ax)
        cax = divider.append_axes("right", size="2%", pad=0.1)
        _ = mpl.colorbar.ColorbarBase(
            cax,
            norm=norm,
            cmap=abs_prob_cmap,
            label="absorption probability",
            ticks=np.linspace(norm.vmin, norm.vmax, 5),
        )

    if (
        same_plot
        and lineage_names != [None]
        and legend_loc is not None
        and last_ax is not None
    ):
        handles = [
            mpl.lines.Line2D([], [], color=lineage_color_mapper[ln], label=ln)
            for ln in successful_models.keys()
        ]
        last_ax.legend(handles=handles, loc=legend_loc)
Ejemplo n.º 16
0
    def test_create_categorical_colors_normal_run(self):
        """62 colors are created, each a color-like string."""
        n_colors = 62
        palette = _create_categorical_colors(n_colors)

        assert len(palette) == n_colors
        assert all(isinstance(color, str) for color in palette)
        assert all(is_color_like(color) for color in palette)
Ejemplo n.º 17
0
 def test_create_categorical_colors_neg_categories(self):
     """A negative number of categories raises a `RuntimeError`."""
     pytest.raises(RuntimeError, _create_categorical_colors, -1)
Ejemplo n.º 18
0
    def test_create_categorical_colors_no_categories(self):
        """Zero categories yield an empty list."""
        assert _create_categorical_colors(0) == []
Ejemplo n.º 19
0
    def _set_categorical_labels(
        self,
        attr_key: str,
        color_key: str,
        pretty_attr_key: str,
        categories: Union[Series, Dict[Any, Any]],
        add_to_existing_error_msg: Optional[str] = None,
        cluster_key: Optional[str] = None,
        en_cutoff: Optional[float] = None,
        p_thresh: Optional[float] = None,
        add_to_existing: bool = False,
    ) -> None:
        """
        Set categorical state labels and their colors on this object.

        Parameters
        ----------
        attr_key
            Name of the attribute where the categorical labels are stored.
        color_key
            Name of the attribute where the colors of the labels are stored.
        pretty_attr_key
            Attribute name shown in the debug message when existing labels are overwritten.
        categories
            Categorical series, or a :class:`dict` which is converted to one (indexed by
            ``adata.obs_names``) via ``_convert_to_categorical_series``.
        add_to_existing_error_msg
            Message of the :class:`RuntimeError` raised when ``add_to_existing=True`` but no labels
            have been set yet.
        cluster_key
            If not `None`, key in ``adata.obs`` whose categories and colors are used, via
            ``_map_names_and_colors``, to rename and color ``categories``. Otherwise, new colors
            are created.
        en_cutoff
            Passed to ``_map_names_and_colors``; only used when ``cluster_key`` is given.
        p_thresh
            If not `None`, passed to ``self._detect_cc_stages``.
        add_to_existing
            Whether to merge ``categories`` with the labels already stored in ``attr_key``.

        Raises
        ------
        TypeError
            If ``categories`` is not categorical.
        RuntimeError
            If ``add_to_existing=True`` and no labels have been set yet.
        KeyError
            If ``cluster_key`` is not in ``adata.obs``.
        """
        if isinstance(categories, dict):
            categories = _convert_to_categorical_series(
                categories, list(self.adata.obs_names)
            )
        if not is_categorical_dtype(categories):
            raise TypeError(
                f"Object must be `categorical`, found `{infer_dtype(categories)}`."
            )

        if add_to_existing:
            if getattr(self, attr_key) is None:
                raise RuntimeError(add_to_existing_error_msg)
            # `inplace=False`: combine into a new series
            categories = _merge_categorical_series(
                getattr(self, attr_key), categories, inplace=False
            )

        if cluster_key is not None:
            logg.debug(f"Creating colors based on `{cluster_key}`")

            # check that we can load the reference series from adata
            if cluster_key not in self.adata.obs:
                raise KeyError(
                    f"Cluster key `{cluster_key!r}` not found in `adata.obs`."
                )
            series_query, series_reference = categories, self.adata.obs[cluster_key]

            # load the reference colors if they exist
            if _colors(cluster_key) in self.adata.uns.keys():
                colors_reference = _convert_to_hex_colors(
                    self.adata.uns[_colors(cluster_key)]
                )
            else:
                colors_reference = _create_categorical_colors(
                    len(series_reference.cat.categories)
                )

            approx_rcs_names, colors = _map_names_and_colors(
                series_reference=series_reference,
                series_query=series_query,
                colors_reference=colors_reference,
                en_cutoff=en_cutoff,
            )
            setattr(self, color_key, colors)
            # if approx_rcs_names is categorical, the info is taken from .cat.categories
            # NOTE(review): this renames `categories` in place, which may be the caller's series when
            # no conversion/merge happened above; direct assignment to `.cat.categories` may also be
            # unsupported in newer pandas (`rename_categories` is the alternative) - confirm
            categories.cat.categories = approx_rcs_names
        else:
            setattr(
                self,
                color_key,
                _create_categorical_colors(len(categories.cat.categories)),
            )

        if p_thresh is not None:
            self._detect_cc_stages(categories, p_thresh=p_thresh)

        # write to class and adata
        if getattr(self, attr_key) is not None:
            logg.debug(f"Overwriting `.{pretty_attr_key}`")

        setattr(self, attr_key, categories)