示例#1
0
文件: _pca.py 项目: fossabot/onecodex
    def plot_pca(
        self,
        rank=Rank.Auto,
        normalize="auto",
        org_vectors=0,
        org_vectors_scale=None,
        title=None,
        xlabel=None,
        ylabel=None,
        color=None,
        size=None,
        tooltip=None,
        return_chart=False,
        label=None,
        mark_size=100,
        width=None,
        height=None,
    ):
        """Perform principal component analysis and plot first two axes.

        Parameters
        ----------
        rank : {'auto', 'kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'}, optional
            Analysis will be restricted to abundances of taxa at the specified level.
        normalize : 'auto' or `bool`, optional
            Convert read counts to relative abundances such that each sample sums to 1.0. Setting
            'auto' will choose automatically based on the data.
        org_vectors : `int`, optional
            Plot this many of the top-contributing eigenvectors from the PCA results. Values larger
            than the number of taxa are clamped to the number of taxa.
        org_vectors_scale : `float`, optional
            Multiply the length of the lines representing the eigenvectors by this constant. When
            omitted, 0.8 times the largest absolute PCA coordinate is used.
        title : `string`, optional
            Text label at the top of the plot.
        xlabel : `string`, optional
            Text label along the horizontal axis. Defaults to "PC1" followed by the percentage of
            variance explained by that component.
        ylabel : `string`, optional
            Text label along the vertical axis. Defaults to "PC2" followed by the percentage of
            variance explained by that component.
        size : `string` or `tuple`, optional
            A string or a tuple containing strings representing metadata fields. The size of points
            in the resulting plot will change based on the metadata associated with each sample.
        color : `string` or `tuple`, optional
            A string or a tuple containing strings representing metadata fields. The color of points
            in the resulting plot will change based on the metadata associated with each sample.
        tooltip : `string` or `list`, optional
            A string or list containing strings representing metadata fields. When a point in the
            plot is hovered over, the value of the metadata associated with that sample will be
            displayed in a modal. A list passed here is copied and never modified.
        return_chart : `bool`, optional
            When True, return the chart object instead of displaying it.
        label : `string` or `callable`, optional
            A metadata field (or function) used to label each analysis. If passing a function, a
            dict containing the metadata for each analysis is passed as the first and only
            positional argument. The callable function must return a string.
        mark_size: `int`, optional
            The size of the points in the scatter plot.
        width : optional
            Forwarded, together with `height` and `title`, through `prepare_props` to the chart's
            properties.
        height : optional
            See `width`.

        Returns
        -------
        The altair chart when `return_chart` is True, otherwise None (the chart is displayed).

        Raises
        ------
        OneCodexException
            If `rank` is None.
        PlottingException
            If there are fewer than 3 samples, or fewer than 2 taxa at the chosen rank.

        Examples
        --------
        Perform PCA on relative abundances at the species-level and color the resulting points by
        'geo_loc_name', a metadata field representing the geographical origin of each sample.

        >>> plot_pca(rank='species', normalize=True, color='geo_loc_name')

        Change the size of each point in the plot based on the abundance of Bacteroides.

        >>> plot_pca(size='Bacteroides')

        Display the abundances of Bacteroides, Prevotella, and Bifidobacterium in each sample when
        hovering over points in the plot.

        >>> plot_pca(tooltip=['Bacteroides', 'Prevotella', 'Bifidobacterium'])
        """
        # Deferred imports
        import altair as alt
        import numpy as np
        import pandas as pd
        from sklearn.decomposition import PCA

        if rank is None:
            raise OneCodexException("Please specify a rank or 'auto' to choose automatically")

        if len(self._results) < 3:
            raise PlottingException(
                "There are too few samples for PCA after filtering. Please select 3 or more "
                "samples to plot."
            )

        df = self.to_df(rank=rank, normalize=normalize)

        if len(df.columns) < 2:
            raise PlottingException(
                "There are too few taxa for PCA after filtering. Please select a rank that "
                "includes at least 2 taxa."
            )

        # Build the list of tooltip fields. A list argument is copied so that the inserts below
        # never modify the caller's object (repeated calls would otherwise accumulate entries).
        if tooltip:
            tooltip = list(tooltip) if isinstance(tooltip, list) else [tooltip]
        else:
            tooltip = []

        tooltip.insert(0, "Label")

        if color and color not in tooltip:
            tooltip.insert(1, color)

        if size and size not in tooltip:
            tooltip.insert(2, size)

        magic_metadata, magic_fields = self._metadata_fetch(tooltip, label=label)

        pca = PCA()
        pca_vals = pca.fit(df.values).transform(df.values)
        pca_vals = pd.DataFrame(pca_vals, index=df.index)
        pca_vals.rename(columns=lambda x: "PC{}".format(x + 1), inplace=True)

        # label the axes
        if xlabel is None:
            xlabel = "PC1 ({}%)".format(round(pca.explained_variance_ratio_[0] * 100, 2))
        if ylabel is None:
            ylabel = "PC2 ({}%)".format(round(pca.explained_variance_ratio_[1] * 100, 2))

        # don't send all the data to vega, just what we're plotting
        plot_data = pd.concat(
            [pca_vals.loc[:, ("PC1", "PC2")], magic_metadata], axis=1
        ).reset_index()

        alt_kwargs = dict(
            x=alt.X("PC1", axis=alt.Axis(title=xlabel)),
            y=alt.Y("PC2", axis=alt.Axis(title=ylabel)),
            tooltip=[magic_fields[t] for t in tooltip],
            href="url:N",
            url=get_base_classification_url() + alt.datum.classification_id,
        )

        # only add these parameters if they are in use
        if color:
            # The metadata columns are named by the *mapped* field name, so use it for every
            # lookup (the requested name and the mapped name are not always identical).
            color_field = magic_fields[color]
            color_kwargs = {
                "legend": alt.Legend(title=color_field),
            }
            if not is_continuous(plot_data[color_field]) or has_missing_values(
                plot_data[color_field]
            ):
                plot_data[color_field] = plot_data[color_field].fillna("N/A").astype(str)
                domain = plot_data[color_field].values
                color_range = interleave_palette(domain)
                color_kwargs["scale"] = alt.Scale(domain=domain, range=color_range)

            alt_kwargs["color"] = alt.Color(color_field, **color_kwargs)
        if size:
            alt_kwargs["size"] = magic_fields[size]

        # "url" is not an encoding channel; it is a calculated field backing the href encoding
        chart = (
            alt.Chart(plot_data)
            .transform_calculate(url=alt_kwargs.pop("url"))
            .mark_circle(size=mark_size)
        )

        vector_chart = None
        # plot the organism eigenvectors that contribute the most
        if org_vectors > 0:
            vector_data = {
                "x": [],
                "y": [],
                "o": [],  # order these points should be connected in
                "Eigenvectors": [],
            }

            # contribution of each taxon to the first two components
            magnitudes = np.sqrt(pca.components_[0] ** 2 + pca.components_[1] ** 2)

            # never request more vectors than there are taxa (would index out of range)
            remaining = min(org_vectors, len(magnitudes))
            cutoff = np.sort(magnitudes)[-remaining]

            if org_vectors_scale is None:
                org_vectors_scale = 0.8 * np.max(pca_vals.abs().values)

            scale = float(org_vectors_scale)

            for tax_id, var1, var2, magnitude in zip(
                df.columns.values, pca.components_[0, :], pca.components_[1, :], magnitudes
            ):
                if magnitude < cutoff:
                    continue

                # each vector is a two-point line from the origin
                vector_data["x"].extend([0, var1 * scale])
                vector_data["y"].extend([0, var2 * scale])
                vector_data["o"].extend([0, 1])
                vector_data["Eigenvectors"].extend([self.taxonomy["name"][tax_id]] * 2)

                # ties at the cutoff could otherwise yield more vectors than requested
                remaining -= 1
                if remaining == 0:
                    break

            vector_chart = (
                alt.Chart(pd.DataFrame(vector_data))
                .mark_line(point=False)
                .encode(
                    x=alt.X("x", axis=None),
                    y=alt.Y("y", axis=None),
                    order="o",
                    color="Eigenvectors",
                )
            )

        chart = chart.encode(**alt_kwargs)

        if vector_chart is not None:
            # independent color scales keep the eigenvector legend separate from the sample colors
            chart = alt.layer(chart, vector_chart).resolve_scale(color="independent")

        chart = chart.properties(**prepare_props(title=title, height=height, width=width))

        if return_chart:
            return chart
        else:
            chart.interactive().display()
示例#2
0
    def plot_distance(
        self,
        rank=Rank.Auto,
        metric=BetaDiversityMetric.BrayCurtis,
        title=None,
        xlabel=None,
        ylabel=None,
        tooltip=None,
        return_chart=False,
        linkage=Linkage.Average,
        label=None,
        width=None,
        height=None,
    ):
        """Plot beta diversity distance matrix as a heatmap and dendrogram.

        Parameters
        ----------
        rank : {'auto', 'kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'}, optional
            Analysis will be restricted to abundances of taxa at the specified level.
        metric : {'braycurtis', 'cityblock', 'manhattan', 'jaccard', 'unifrac', 'unweighted_unifrac', 'aitchison'}, optional
            Function to use when calculating the distance between two samples.
            Note that 'cityblock' and 'manhattan' are equivalent metrics.
        linkage : {'average', 'single', 'complete', 'weighted', 'centroid', 'median'}
            The type of linkage to use when clustering axes.
        title : `string`, optional
            Text label at the top of the plot.
        xlabel : `string`, optional
            Text label along the horizontal axis.
        ylabel : `string`, optional
            Text label along the vertical axis.
        tooltip : `string` or `list`, optional
            A string or list containing strings representing metadata fields. When a point in the
            plot is hovered over, the value of the metadata associated with that sample will be
            displayed in a modal. A list passed here is copied and never modified.
        return_chart : `bool`, optional
            When True, return the concatenated chart object instead of displaying it.
        label : `string` or `callable`, optional
            A metadata field (or function) used to label each analysis. If passing a function, a
            dict containing the metadata for each analysis is passed as the first and only
            positional argument. The callable function must return a string.
        width : optional
            Forwarded, together with `height`, through `prepare_props` to the heatmap's
            properties. The heatmap is initially sized at 15 units per sample in each dimension.
        height : optional
            See `width`. When given, the dendrogram height is also adjusted to line up with the
            heatmap cells.

        Returns
        -------
        The altair chart when `return_chart` is True, otherwise None (the chart is displayed).

        Raises
        ------
        PlottingException
            If there are fewer than 2 samples.

        Examples
        --------
        Plot the weighted UniFrac distance between all our samples, using counts at the genus level.

        >>> plot_distance(rank='genus', metric='unifrac')
        """
        import altair as alt
        import numpy as np
        import pandas as pd
        from onecodex.viz import dendrogram

        if len(self._results) < 2:
            raise PlottingException(
                "There are too few samples for distance matrix plots after filtering. Please "
                "select 2 or more samples to plot.")

        # this will be passed to the heatmap chart as a dataframe eventually
        plot_data = {
            "1) Label": [],
            "2) Label": [],
            "Distance": [],
            "classification_id": []
        }

        # here we figure out what to put in the tooltips and get the appropriate data.
        # A list argument is copied so the insert below never modifies the caller's object.
        if tooltip:
            tooltip = list(tooltip) if isinstance(tooltip, list) else [tooltip]
        else:
            tooltip = []

        tooltip.insert(0, "Label")

        magic_metadata, magic_fields = self._metadata_fetch(tooltip,
                                                            label=label)

        # every metadata field appears twice in the long-format table: once for the sample on
        # the first axis ("1) <field>") and once for the sample on the second ("2) <field>")
        formatted_fields = []

        for magic_field in magic_fields.values():
            field_group = ["{}) {}".format(i, magic_field) for i in (1, 2)]

            for field in field_group:
                plot_data[field] = []

            formatted_fields.append(field_group)

        clust = self._cluster_by_sample(rank=rank,
                                        metric=metric,
                                        linkage=linkage)

        dist_matrix = clust["dist_matrix"]
        sample_ids = dist_matrix.index
        n_samples = len(sample_ids)

        # must convert to long format for heatmap plotting
        for idx1, id1 in enumerate(sample_ids):
            for idx2, id2 in enumerate(sample_ids):
                # the diagonal (a sample against itself) is recorded as NaN rather than a distance
                if idx1 == idx2:
                    plot_data["Distance"].append(np.nan)
                else:
                    plot_data["Distance"].append(dist_matrix.iloc[idx1, idx2])

                plot_data["classification_id"].append(id1)

                for field_group, magic_field in zip(formatted_fields,
                                                    magic_fields.values()):
                    plot_data[field_group[0]].append(
                        magic_metadata[magic_field][id1])
                    plot_data[field_group[1]].append(
                        magic_metadata[magic_field][id2])

        plot_data = pd.DataFrame(data=plot_data)

        labels_in_order = magic_metadata["Label"][
            clust["ids_in_order"]].tolist()

        # it's important to tell altair to order the cells in the heatmap according to the clustering
        # obtained from scipy
        alt_kwargs = dict(
            x=alt.X("1) Label:N",
                    axis=alt.Axis(title=xlabel),
                    sort=labels_in_order),
            y=alt.Y("2) Label:N",
                    axis=alt.Axis(title=ylabel, orient="right"),
                    sort=labels_in_order),
            color=alt.Color("Distance:Q", legend=alt.Legend(title="Distance")),
            tooltip=list(chain.from_iterable(formatted_fields)) +
            ["Distance:Q"],
            href="url:N",
            url=get_base_classification_url() + alt.datum.classification_id,
        )

        # "url" is not an encoding channel; it is a calculated field backing the href encoding
        chart = (alt.Chart(
            plot_data,
            width=15 * n_samples,
            height=15 * n_samples,
        ).transform_calculate(url=alt_kwargs.pop("url")).mark_rect().encode(
            **alt_kwargs))

        chart = chart.properties(**prepare_props(height=height, width=width))

        dendro_chart = dendrogram(clust["scipy_tree"])

        if height:
            # shorten the dendrogram by half a heatmap cell
            cell_height = height / n_samples
            dendro_chart = dendro_chart.properties(height=height -
                                                   cell_height / 2)

        title_kwargs = prepare_props(title=title)
        concat_chart = alt.hconcat(
            dendro_chart, chart, spacing=0,
            **title_kwargs).configure_view(strokeWidth=0)
        if return_chart:
            return concat_chart
        else:
            concat_chart.display()
示例#3
0
    def plot_mds(
        self,
        rank=Rank.Auto,
        metric=BetaDiversityMetric.BrayCurtis,
        method=OrdinationMethod.Pcoa,
        title=None,
        xlabel=None,
        ylabel=None,
        color=None,
        size=None,
        tooltip=None,
        return_chart=False,
        label=None,
        mark_size=100,
        width=None,
        height=None,
    ):
        """Plot beta diversity distance matrix using multidimensional scaling (MDS).

        Parameters
        ----------
        rank : {'auto', 'kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'}, optional
            Analysis will be restricted to abundances of taxa at the specified level.
        metric : {'braycurtis', 'cityblock', 'manhattan', 'jaccard', 'unifrac', 'unweighted_unifrac', 'aitchison'}, optional
            Function to use when calculating the distance between two samples.
            Note that 'cityblock' and 'manhattan' are equivalent metrics.
        method : {'pcoa', 'smacof'}
            Algorithm to use for ordination. PCoA uses eigenvalue decomposition and is not well
            suited to non-euclidean distance functions. SMACOF is an iterative optimization strategy
            that can be used as an alternative.
        title : `string`, optional
            Text label at the top of the plot.
        xlabel : `string`, optional
            Text label along the horizontal axis.
        ylabel : `string`, optional
            Text label along the vertical axis.
        size : `string` or `tuple`, optional
            A string or a tuple containing strings representing metadata fields. The size of points
            in the resulting plot will change based on the metadata associated with each sample.
        color : `string` or `tuple`, optional
            A string or a tuple containing strings representing metadata fields. The color of points
            in the resulting plot will change based on the metadata associated with each sample.
        tooltip : `string` or `list`, optional
            A string or list containing strings representing metadata fields. When a point in the
            plot is hovered over, the value of the metadata associated with that sample will be
            displayed in a modal. A list passed here is copied and never modified.
        return_chart : `bool`, optional
            When True, return the chart object instead of displaying it.
        label : `string` or `callable`, optional
            A metadata field (or function) used to label each analysis. If passing a function, a
            dict containing the metadata for each analysis is passed as the first and only
            positional argument. The callable function must return a string.
        mark_size: `int`, optional
            The size of the points in the scatter plot.
        width : optional
            Forwarded, together with `height` and `title`, through `prepare_props` to the chart's
            properties.
        height : optional
            See `width`.

        Returns
        -------
        The altair chart when `return_chart` is True, otherwise None (the chart is displayed).

        Raises
        ------
        PlottingException
            If there are fewer than 3 samples.
        OneCodexException
            If `method` is not a supported ordination method.

        Examples
        --------
        Scatter plot of weighted UniFrac distance between all our samples, using counts at the genus
        level.

        >>> plot_mds(rank='genus', metric='unifrac')

        Notes
        -----
        **For `smacof`**: The values reported on the axis labels are Pearson's correlations between
        the distances between points on each axis alone, and the corresponding distances in the
        distance matrix calculated using the user-specified metric. These values are related to the
        effectiveness of the MDS algorithm in placing points on the scatter plot in such a way that
        they truly represent the calculated distances. They do not reflect how well the distance
        metric captures similarities between the underlying data (in this case, an OTU table).
        """
        import altair as alt
        import numpy as np
        import pandas as pd
        from scipy.spatial.distance import squareform
        from scipy.stats import pearsonr
        from skbio.stats import ordination
        from sklearn import manifold
        from sklearn.metrics.pairwise import euclidean_distances

        if len(self._results) < 3:
            raise PlottingException(
                "There are too few samples for MDS/PCoA after filtering. Please select 3 or more "
                "samples to plot.")

        dists = self._compute_distance(rank, metric).to_data_frame()

        # here we figure out what to put in the tooltips and get the appropriate data.
        # A list argument is copied so the inserts below never modify the caller's object.
        if tooltip:
            tooltip = list(tooltip) if isinstance(tooltip, list) else [tooltip]
        else:
            tooltip = []

        tooltip.insert(0, "Label")

        if color and color not in tooltip:
            tooltip.insert(1, color)

        if size and size not in tooltip:
            tooltip.insert(2, size)

        magic_metadata, magic_fields = self._metadata_fetch(tooltip,
                                                            label=label)

        if method == OrdinationMethod.Smacof:
            # adapted from https://scikit-learn.org/stable/auto_examples/manifold/plot_mds.html
            x_field = "MDS1"
            y_field = "MDS2"

            # fixed seed so repeated calls produce the same embedding
            seed = np.random.RandomState(seed=3)
            mds = manifold.MDS(max_iter=3000,
                               eps=1e-12,
                               random_state=seed,
                               dissimilarity="precomputed",
                               n_jobs=1)
            pos = mds.fit(dists).embedding_
            plot_data = pd.DataFrame(pos,
                                     columns=[x_field, y_field],
                                     index=dists.index)
            # scale each axis by its largest absolute value, i.e. into [-1, 1]
            plot_data = plot_data.div(plot_data.abs().max(axis=0), axis=1)

            # determine how much of the original distance is captured by each of the axes after MDS.
            # this implementation of MDS does not use eigen decomposition and so there's no simple
            # way of returning a 'percent of variance explained' value
            r_squared = []

            for axis in [0, 1]:
                mds_dist = pos.copy()
                mds_dist[::, axis] = 0
                mds_dist = squareform(euclidean_distances(mds_dist).round(6))
                # NOTE(review): pearsonr(...)[0] is the correlation coefficient r, while the
                # axis label below reads "r²" -- confirm which of the two is intended.
                r_squared.append(pearsonr(mds_dist, squareform(dists))[0])

            # label the axes
            x_extra_label = "r² = %.02f" % (r_squared[0], )
            y_extra_label = "r² = %.02f" % (r_squared[1], )
        elif method == OrdinationMethod.Pcoa:
            # suppress eigenvalue warning from skbio--not because it's an invalid warning, but
            # because lots of folks in the field run pcoa on these distances functions, even if
            # statistically inappropriate. perhaps this will change if we ever become more
            # opinionated about the analyses that we allow our users to do (roo)
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                # round to avoid float precision errors
                ord_result = ordination.pcoa(dists.round(6))

            # get first two components
            plot_data = ord_result.samples.iloc[:, [0, 1]]
            # scale each axis by its largest absolute value, i.e. into [-1, 1]
            plot_data = plot_data.div(plot_data.abs().max(axis=0), axis=1)
            plot_data.index = dists.index
            # name of first two components
            x_field, y_field = plot_data.columns.tolist()

            x_extra_label = "%0.02f%%" % (ord_result.proportion_explained[0] *
                                          100, )
            y_extra_label = "%0.02f%%" % (ord_result.proportion_explained[1] *
                                          100, )
        else:
            raise OneCodexException("MDS method must be one of: {}".format(
                ", ".join(OrdinationMethod.values)))

        # label the axes
        if xlabel is None:
            xlabel = "{} ({})".format(x_field, x_extra_label)
        if ylabel is None:
            ylabel = "{} ({})".format(y_field, y_extra_label)

        plot_data = pd.concat([plot_data, magic_metadata],
                              axis=1).reset_index()

        alt_kwargs = dict(
            x=alt.X(x_field, axis=alt.Axis(title=xlabel)),
            y=alt.Y(y_field, axis=alt.Axis(title=ylabel)),
            tooltip=[magic_fields[t] for t in tooltip],
            href="url:N",
            url=get_base_classification_url() + alt.datum.classification_id,
        )

        # only add these parameters if they are in use
        if color:
            # The metadata columns are named by the *mapped* field name, so use it for every
            # lookup (the requested name and the mapped name are not always identical).
            color_field = magic_fields[color]
            color_kwargs = {
                "legend": alt.Legend(title=color_field),
            }
            if not is_continuous(plot_data[color_field]) or has_missing_values(
                    plot_data[color_field]):
                plot_data[color_field] = (
                    plot_data[color_field].fillna("N/A").astype(str))
                domain = plot_data[color_field].values
                color_range = interleave_palette(domain)
                color_kwargs["scale"] = alt.Scale(domain=domain,
                                                  range=color_range)

            alt_kwargs["color"] = alt.Color(color_field, **color_kwargs)
        if size:
            alt_kwargs["size"] = magic_fields[size]

        # "url" is not an encoding channel; it is a calculated field backing the href encoding
        chart = (alt.Chart(plot_data).transform_calculate(
            url=alt_kwargs.pop("url")).mark_circle(size=mark_size).encode(
                **alt_kwargs))

        chart = chart.properties(
            **prepare_props(title=title, height=height, width=width))

        if return_chart:
            return chart
        else:
            chart.interactive().display()
示例#4
0
    def plot_metadata(
        self,
        rank=Rank.Auto,
        haxis="Label",
        vaxis=AlphaDiversityMetric.Shannon,
        title=None,
        xlabel=None,
        ylabel=None,
        return_chart=False,
        plot_type=PlotType.Auto,
        label=None,
        sort_x=None,
        width=200,
        height=400,
    ):
        """Plot an arbitrary metadata field versus an arbitrary quantity as a boxplot or scatter plot.

        Parameters
        ----------
        rank : {'auto', 'kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'}, optional
            Analysis will be restricted to abundances of taxa at the specified level.

        haxis : `string`, optional
            The metadata field (or tuple containing multiple categorical fields) to be plotted on
            the horizontal axis.

        vaxis : `string`, optional
            Data to be plotted on the vertical axis. Can be any one of the following:

            - A metadata field: the name of a metadata field containing numerical data
            - {'simpson', 'observed_taxa', 'shannon'}: an alpha diversity statistic to calculate for each sample
            - A taxon name: the name of a taxon in the analysis
            - A taxon ID: the ID of a taxon in the analysis

        title : `string`, optional
            Text label at the top of the plot.

        xlabel : `string`, optional
            Text label along the horizontal axis.

        ylabel : `string`, optional
            Text label along the vertical axis.

        return_chart : `bool`, optional
            When True, return the chart object instead of displaying it.

        plot_type : {'auto', 'boxplot', 'scatter'}
            By default, will determine plot type automatically based on the data. Otherwise, specify
            one of 'boxplot' or 'scatter' to set the type of plot manually.

        label : `string` or `callable`, optional
            A metadata field (or function) used to label each analysis. If passing a function, a
            dict containing the metadata for each analysis is passed as the first and only
            positional argument. The callable function must return a string.

        sort_x : `list` or `callable`, optional
            Either a list of sorted labels or a function that will be called with a list of x-axis labels
            as the only argument, and must return the same list in a user-specified order. Only
            allowed for scatter plots.

        width : optional
            Forwarded, together with `height` and `title`, through `prepare_props` to the chart's
            properties. For boxplots a numeric width is also used to shrink the boxes when they
            would not otherwise fit.

        height : optional
            See `width`.

        Returns
        -------
        The altair chart when `return_chart` is True, otherwise None (the chart is displayed).

        Raises
        ------
        OneCodexException
            If `rank` is None, `plot_type` is invalid, the vertical axis is not numerical, the
            horizontal axis has an unplottable type, or `sort_x` is given for a boxplot.
        PlottingException
            If there are no samples to plot.

        Examples
        --------
        Generate a boxplot of the abundance of Bacteroides (genus) of samples grouped by whether the
        individuals are allergic to dogs, cats, both, or neither.

        >>> plot_metadata(haxis=('allergy_dogs', 'allergy_cats'), vaxis='Bacteroides')
        """
        # Deferred imports
        import altair as alt
        import pandas as pd

        if rank is None:
            raise OneCodexException(
                "Please specify a rank or 'auto' to choose automatically")

        if not PlotType.has_value(plot_type):
            raise OneCodexException(
                "Plot type must be one of: auto, boxplot, scatter")

        if len(self._results) < 1:
            raise PlottingException(
                "There are too few samples for metadata plots after filtering. Please select 1 or "
                "more samples to plot.")

        # alpha diversity is only allowed on vertical axis--horizontal can be magically mapped
        df, magic_fields = self._metadata_fetch([haxis, "Label"], label=label)

        if AlphaDiversityMetric.has_value(vaxis):
            df.loc[:, vaxis] = self.alpha_diversity(vaxis, rank=rank)
            magic_fields[vaxis] = vaxis
            df.dropna(subset=[magic_fields[vaxis]], inplace=True)
        else:
            # if it's not alpha diversity, vertical axis can also be magically mapped
            vert_df, vert_magic_fields = self._metadata_fetch([vaxis])
            vert_col = vert_df[vert_magic_fields[vaxis]]

            # we require the vertical axis to be numerical otherwise plots get weird.
            # booleans count as numeric for pandas, so they are rejected explicitly; categorical
            # and object columns are already covered by the "not numeric" test.
            if (pd.api.types.is_bool_dtype(vert_col)
                    or not pd.api.types.is_numeric_dtype(vert_col)):
                raise OneCodexException(
                    "Metadata field on vertical axis must be numerical")

            df = pd.concat([df, vert_df],
                           axis=1).dropna(subset=[vert_magic_fields[vaxis]])
            magic_fields.update(vert_magic_fields)

        haxis_field = magic_fields[haxis]
        haxis_col = df[haxis_field]

        # plots can look different depending on what the horizontal axis contains
        if pd.api.types.is_datetime64_any_dtype(haxis_col):
            if plot_type == PlotType.Auto:
                plot_type = PlotType.BoxPlot
        elif "date" in haxis_field.split("_"):
            # a field whose name has "date" as an underscore-separated word is parsed as datetimes
            df.loc[:, haxis_field] = df.loc[:, haxis_field].apply(
                pd.to_datetime, utc=True)

            if plot_type == PlotType.Auto:
                plot_type = PlotType.BoxPlot
        elif (pd.api.types.is_bool_dtype(haxis_col)
              or isinstance(haxis_col.dtype, pd.api.types.CategoricalDtype)
              or pd.api.types.is_object_dtype(haxis_col)):
            df = df.fillna({field: "N/A" for field in df.columns})

            if plot_type == PlotType.Auto:
                # if data is categorical but there is only one value per sample, scatter plot instead
                if len(df[haxis_field].unique()) == len(df[haxis_field]):
                    plot_type = PlotType.Scatter
                else:
                    plot_type = PlotType.BoxPlot
        elif pd.api.types.is_numeric_dtype(haxis_col):
            df = df.dropna(subset=[magic_fields[vaxis]])

            if plot_type == PlotType.Auto:
                plot_type = PlotType.Scatter
        else:
            raise OneCodexException(
                "Unplottable column type for horizontal axis ({})".format(
                    haxis))

        if xlabel is None:
            xlabel = haxis_field

        if ylabel is None:
            ylabel = magic_fields[vaxis]

        if plot_type == PlotType.Scatter:
            df = df.reset_index()

            sort_order = sort_helper(sort_x, df[haxis_field].tolist())

            alt_kwargs = dict(
                x=alt.X(haxis_field,
                        axis=alt.Axis(title=xlabel),
                        sort=sort_order),
                y=alt.Y(magic_fields[vaxis], axis=alt.Axis(title=ylabel)),
                tooltip=["Label", "{}:Q".format(magic_fields[vaxis])],
                href="url:N",
                url=get_base_classification_url() +
                alt.datum.classification_id,
            )

            # "url" is not an encoding channel; it is a calculated field backing the href encoding
            chart = (alt.Chart(df).transform_calculate(
                url=alt_kwargs.pop("url")).mark_circle().encode(**alt_kwargs))

        elif plot_type == PlotType.BoxPlot:
            if sort_x:
                raise OneCodexException(
                    "Must not specify sort_x when plot_type is boxplot")

            # See the following issue in case this gets fixed in altair:
            # https://github.com/altair-viz/altair/issues/2144
            if (df.groupby(haxis_field).size() < 2).any():
                warnings.warn(
                    "There is at least one sample group consisting of only a single sample. Groups "
                    "of size 1 may not have their boxes displayed in the plot.",
                    PlottingWarning,
                )

            box_size = 45
            increment = 5

            n_boxes = len(df[haxis_field].unique())

            if width and width != "container" and (
                    n_boxes * (box_size + increment)) > width:
                # shrink the boxes to fit, in steps of `increment`, but never to zero or below
                # (which would make the boxes invisible or invalid when there are many groups)
                box_size = max(
                    increment,
                    ((width / n_boxes) // increment) * increment - increment,
                )

            chart = (alt.Chart(df).mark_boxplot(size=box_size).encode(
                x=alt.X(haxis_field, axis=alt.Axis(title=xlabel)),
                y=alt.Y(magic_fields[vaxis], axis=alt.Axis(title=ylabel)),
            ))

        chart = chart.properties(
            **prepare_props(title=title, height=height, width=width))

        if return_chart:
            return chart
        else:
            chart.interactive().display()
示例#5
0
    def plot_bargraph(
        self,
        rank=Rank.Auto,
        normalize="auto",
        top_n="auto",
        threshold="auto",
        title=None,
        xlabel=None,
        ylabel=None,
        tooltip=None,
        return_chart=False,
        haxis=None,
        legend="auto",
        label=None,
        sort_x=None,
        include_taxa_missing_rank=None,
        include_other=True,
        width=None,
        height=None,
    ):
        """Plot a bargraph of relative abundance of taxa for multiple samples.

        Parameters
        ----------
        rank : {'auto', 'kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'}, optional
            Analysis will be restricted to abundances of taxa at the specified level.
        normalize : 'auto' or `bool`, optional
            Convert read counts to relative abundances such that each sample sums to 1.0. Setting
            'auto' will choose automatically based on the data.
        return_chart : `bool`, optional
            When True, return an `altair.Chart` object instead of displaying the resulting plot in
            the current notebook.
        top_n : `int`, optional
            Display the top N most abundant taxa in the entire cohort of samples.
        threshold : `float`
            Display only taxa that are more abundant than this threshold in one or more samples.
        title : `string`, optional
            Text label at the top of the plot.
        xlabel : `string`, optional
            Text label along the horizontal axis.
        ylabel : `string`, optional
            Text label along the vertical axis.
        tooltip : `string` or `list`, optional
            A string or list containing strings representing metadata fields. When a point in the
            plot is hovered over, the value of the metadata associated with that sample will be
            displayed in a modal. A list passed here is copied and never modified.
        haxis : `string`, optional
            The metadata field (or tuple containing multiple categorical fields) used to group
            samples together.
        legend: `string`, optional
            Title for color scale. Defaults to the metric used to generate the plot, e.g.
            readcount_w_children or abundance.
        label : `string` or `callable`, optional
            A metadata field (or function) used to label each analysis. If passing a function, a
            dict containing the metadata for each analysis is passed as the first and only
            positional argument. The callable function must return a string.
        sort_x : `list` or `callable`, optional
            Either a list of sorted labels or a function that will be called with a list of x-axis labels
            as the only argument, and must return the same list in a user-specified order.
        include_taxa_missing_rank : `bool`, optional
            Whether or not a "No <rank>" bar segment should be plotted, computed as one minus the
            sum of each sample's values at `rank`. When left as None it is switched on
            automatically if the current metric is an abundance metric.
        include_other : `bool`, optional
            When True and `normalize` is truthy, add an "Other" bar segment computed as one minus
            the sum of the displayed taxa for each sample.
        width : optional
            Chart width, forwarded to the chart properties.
        height : optional
            Chart height, forwarded to the chart properties.

        Returns
        -------
        The `altair.Chart` when `return_chart` is True; otherwise the result of displaying it.

        Raises
        ------
        OneCodexException
            If `rank` is None, if neither `threshold` nor `top_n` is given, or if
            `include_taxa_missing_rank` is requested for a metric other than abundance w/ children.
        PlottingException
            If there are no samples to plot.

        Examples
        --------
        Plot a bargraph of the top 10 most abundant genera

        >>> plot_bargraph(rank='genus', top_n=10)
        """
        # Deferred imports
        import altair as alt

        if rank is None:
            raise OneCodexException(
                "Please specify a rank or 'auto' to choose automatically")

        if not (threshold or top_n):
            raise OneCodexException(
                "Please specify at least one of: threshold, top_n")

        if len(self._results) < 1:
            raise PlottingException(
                "There are too few samples for bargraph plots after filtering. Please select 1 or "
                "more samples to plot.")

        # Resolve the 'auto' sentinels: an explicit value for one filter disables the other.
        if top_n == "auto" and threshold == "auto":
            top_n = 10
            threshold = None
        elif top_n == "auto" and threshold != "auto":
            top_n = None
        elif top_n != "auto" and threshold == "auto":
            threshold = None

        df = self.to_df(rank=rank, normalize=normalize, threshold=threshold)

        if AbundanceMetric.has_value(
                self._metric) and include_taxa_missing_rank is None:
            include_taxa_missing_rank = True

        # Label of the synthetic column/bar segment for taxa lacking a parent at `rank`.
        no_level_name = "No {}".format(rank)

        if include_taxa_missing_rank:
            if self._metric != Metric.AbundanceWChildren:
                raise OneCodexException(
                    "No-level data can only be imputed on abundances w/ children"
                )

            df[no_level_name] = 1 - df.sum(axis=1)

        # Keep the columns with the highest mean across samples; a `top_n` of None keeps all.
        top_n = df.mean().sort_values(ascending=False).iloc[:top_n].index

        df = df[top_n]

        if include_other and normalize:
            df["Other"] = 1 - df.sum(axis=1)

        if legend == "auto":
            legend = self.metric

        if tooltip:
            # Copy a caller-supplied list so the append/insert below never mutate it.
            tooltip = list(tooltip) if isinstance(tooltip, list) else [tooltip]
        else:
            tooltip = []

        if haxis:
            tooltip.append(haxis)

        tooltip.insert(0, "Label")

        # takes metadata columns and returns a dataframe with just those columns
        # renames columns in the case where columns are taxids
        magic_metadata, magic_fields = self._metadata_fetch(tooltip,
                                                            label=label)

        df = df.join(magic_metadata)

        # Reshape to long format: one row per (sample, taxon) pair.
        df = df.reset_index().melt(
            id_vars=["classification_id"] + magic_metadata.columns.tolist(),
            var_name="tax_id",
            value_name=self.metric,
        )

        # add taxa names; synthetic columns ("Other", "No <rank>") keep their own label
        df["tax_name"] = df["tax_id"].apply(lambda t: "{} ({})".format(
            self.taxonomy["name"][t], t) if t in self.taxonomy["name"] else t)

        #
        # TODO: how to sort bars in bargraph
        # - abundance (mean across all samples)
        # - parent taxon (this will require that we make a few assumptions
        # about taxonomic ranks but as all taxonomic data will be coming from
        # OCX this should be okay)
        #

        ylabel = ylabel or self.metric
        xlabel = xlabel or ""

        # should ultimately be Label, tax_name, readcount_w_children, then custom fields
        tooltip_for_altair = [magic_fields[f] for f in tooltip]
        tooltip_for_altair.insert(1, "tax_name")
        tooltip_for_altair.insert(2, "{}:Q".format(self.metric))

        kwargs = {}

        if haxis:
            kwargs["column"] = alt.Column(haxis,
                                          header=alt.Header(
                                              titleOrient="bottom",
                                              labelOrient="bottom"))

        domain = sorted(df["tax_name"].unique())

        color_range = interleave_palette(
            set(domain) - {"Other", no_level_name})

        other_color = ["#DCE0E5"]
        no_level_color = ["#eeefe1"]

        # Move the synthetic categories to the front of the domain and give them fixed colors.
        if include_taxa_missing_rank and no_level_name in domain:
            domain.remove(no_level_name)
            domain = [no_level_name] + domain
            color_range = no_level_color + color_range

        if include_other and "Other" in domain:
            domain.remove("Other")
            domain = ["Other"] + domain
            color_range = other_color + color_range

        sort_order = sort_helper(sort_x, df["Label"].tolist())

        # Numeric stacking order for each row, taken from its position in the color domain.
        df["order"] = df["tax_name"].apply(domain.index)

        y_scale_kwargs = {"zero": True, "nice": False}
        if normalize:
            y_scale_kwargs["domain"] = [0, 1]

        chart = (alt.Chart(df).transform_calculate(
            url=get_base_classification_url() +
            alt.datum.classification_id).mark_bar().encode(
                x=alt.X("Label", axis=alt.Axis(title=xlabel), sort=sort_order),
                y=alt.Y(self.metric,
                        axis=alt.Axis(title=ylabel),
                        scale=alt.Scale(**y_scale_kwargs)),
                color=alt.Color(
                    "tax_name",
                    legend=alt.Legend(title=legend),
                    sort=domain,
                    scale=alt.Scale(domain=domain, range=color_range),
                ),
                tooltip=tooltip_for_altair,
                href="url:N",
                order=alt.Order("order", sort="descending"),
                **kwargs))

        if haxis:
            chart = chart.resolve_scale(x="independent")

        chart = chart.properties(
            **prepare_props(title=title, width=width, height=height))

        return chart if return_chart else chart.display()
示例#6
0
    def plot_heatmap(
        self,
        rank=Rank.Auto,
        normalize="auto",
        top_n="auto",
        threshold="auto",
        title=None,
        xlabel=None,
        ylabel=None,
        tooltip=None,
        return_chart=False,
        linkage=Linkage.Average,
        haxis=None,
        metric="euclidean",
        legend="auto",
        label=None,
        sort_x=None,
        sort_y=None,
        width=None,
        height=None,
    ):
        """Plot heatmap of taxa abundance/count data for several samples.

        Parameters
        ----------
        rank : {'auto', 'kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'}, optional
            Analysis will be restricted to abundances of taxa at the specified level.
        normalize : 'auto' or `bool`, optional
            Convert read counts to relative abundances such that each sample sums to 1.0. Setting
            'auto' will choose automatically based on the data.
        return_chart : `bool`, optional
            When True, return an `altair.Chart` object instead of displaying the resulting plot in
            the current notebook.
        haxis : `string`, optional
            The metadata field (or tuple containing multiple categorical fields) used to group
            samples together. Each group of samples will be clustered independently.
        metric : {'euclidean', 'braycurtis', 'cityblock', 'manhattan', 'jaccard', 'unifrac', 'unweighted_unifrac', 'aitchison'}, optional
            Function to use when calculating the distance between two samples.
            Note that 'cityblock' and 'manhattan' are equivalent metrics.
        linkage : {'average', 'single', 'complete', 'weighted', 'centroid', 'median'}
            The type of linkage to use when clustering axes.
        top_n : `int`, optional
            Display the top N most abundant taxa in the entire cohort of samples.
        threshold : `float`
            Display only taxa that are more abundant than this threshold in one or more samples.
        title : `string`, optional
            Text label at the top of the plot.
        xlabel : `string`, optional
            Text label along the horizontal axis.
        ylabel : `string`, optional
            Text label along the vertical axis.
        tooltip : `string` or `list`, optional
            A string or list containing strings representing metadata fields. When a point in the
            plot is hovered over, the value of the metadata associated with that sample will be
            displayed in a modal. A list passed here is copied and never modified.
        legend: `string`, optional
            Title for color scale. Defaults to the field used to generate the plot, e.g.
            readcount_w_children or abundance.
        label : `string` or `callable`, optional
            A metadata field (or function) used to label each analysis. If passing a function, a
            dict containing the metadata for each analysis is passed as the first and only
            positional argument. The callable function must return a string.
        sort_x : `list` or `callable`, optional
            Either a list of sorted labels or a function that will be called with a list of x-axis labels
            as the only argument, and must return the same list in a user-specified order.
        sort_y : `list` or `callable`, optional
            Either a list of sorted labels or a function that will be called with a list of y-axis labels
            as the only argument, and must return the same list in a user-specified order.
        width : optional
            Chart width. When not given, 15 times the number of ordered sample labels is used.
        height : optional
            Chart height. When not given, 15 times the number of ordered taxa is used.

        Returns
        -------
        The `altair.Chart` when `return_chart` is True; otherwise None, after displaying an
        interactive version of the chart.

        Raises
        ------
        OneCodexException
            If `rank` is None, if neither `threshold` nor `top_n` is given, if the results are
            already normalized and `metric` is not 'euclidean', or if the `haxis` metadata field
            is numerical.
        PlottingException
            If fewer than 2 samples or fewer than 2 taxa remain after filtering.

        Examples
        --------
        Plot a heatmap of the relative abundances of the top 10 most abundant families.

        >>> plot_heatmap(rank='family', top_n=10)
        """
        # Deferred imports
        import altair as alt
        import pandas as pd

        if rank is None:
            raise OneCodexException(
                "Please specify a rank or 'auto' to choose automatically")

        if not (threshold or top_n):
            raise OneCodexException(
                "Please specify at least one of: threshold, top_n")

        if len(self._results) < 2:
            raise PlottingException(
                "There are too few samples for heatmap plots after filtering. Please select 2 or "
                "more samples to plot.")

        # Resolve the 'auto' sentinels: an explicit value for one filter disables the other.
        if top_n == "auto" and threshold == "auto":
            top_n = 10
            threshold = None
        elif top_n == "auto" and threshold != "auto":
            top_n = None
        elif top_n != "auto" and threshold == "auto":
            threshold = None

        df = self.to_df(rank=rank,
                        normalize=normalize,
                        top_n=top_n,
                        threshold=threshold,
                        table_format="long")

        if len(df["tax_id"].unique()) < 2:
            raise PlottingException(
                "There are too few taxa for heatmap clustering after filtering. Please select a "
                "rank or threshold that includes at least 2 taxa.")

        if legend == "auto":
            legend = df.ocx.metric

        if tooltip:
            # Copy a caller-supplied list so the append/insert below never mutate it.
            tooltip = list(tooltip) if isinstance(tooltip, list) else [tooltip]
        else:
            tooltip = []

        if haxis:
            tooltip.append(haxis)

        tooltip.insert(0, "Label")

        magic_metadata, magic_fields = self._metadata_fetch(tooltip,
                                                            label=label)

        # add columns for prettier display
        df["Label"] = magic_metadata["Label"][df["classification_id"]].tolist()
        df["tax_name"] = [
            "{} ({})".format(self.taxonomy["name"][t], t) for t in df["tax_id"]
        ]

        # and for metadata
        for f in tooltip:
            df[magic_fields[f]] = magic_metadata[magic_fields[f]][
                df["classification_id"]].tolist()

        # if we've already been normalized, we must cluster samples by euclidean distance. beta
        # diversity measures won't work with normalized distances.
        if self._guess_normalized():
            if metric != "euclidean":
                raise OneCodexException(
                    "Results are normalized. Please re-run with metric=euclidean"
                )

            df_sample_cluster = self.to_df(rank=rank,
                                           normalize=normalize,
                                           top_n=top_n,
                                           threshold=threshold)
            df_taxa_cluster = df_sample_cluster
        else:
            # Samples are clustered on unnormalized data; taxa on the requested normalization.
            df_sample_cluster = self.to_df(rank=rank,
                                           normalize=False,
                                           top_n=top_n,
                                           threshold=threshold)

            df_taxa_cluster = self.to_df(rank=rank,
                                         normalize=normalize,
                                         top_n=top_n,
                                         threshold=threshold)

        # applying clustering to determine order of taxa, or use custom sorting function if given
        if sort_y is None:
            taxa_cluster = df_taxa_cluster.ocx._cluster_by_taxa(
                linkage=linkage)
            taxa_cluster = taxa_cluster["labels_in_order"]
        else:
            taxa_cluster = sort_helper(sort_y, df["tax_name"])

        if sort_x is None:
            if haxis is None:
                # cluster samples only once
                sample_cluster = df_sample_cluster.ocx._cluster_by_sample(
                    rank=rank, metric=metric, linkage=linkage)
                labels_in_order = magic_metadata["Label"][
                    sample_cluster["ids_in_order"]].tolist()
            else:
                haxis_values = df[magic_fields[haxis]]

                # `is_categorical_dtype` is deprecated in recent pandas releases; checking the
                # dtype's class is the equivalent test for a Series.
                is_groupable = (
                    pd.api.types.is_bool_dtype(haxis_values)
                    or isinstance(haxis_values.dtype,
                                  pd.api.types.CategoricalDtype)
                    or pd.api.types.is_object_dtype(haxis_values))

                if not is_groupable:
                    raise OneCodexException(
                        "Metadata field on horizontal axis can not be numerical"
                    )

                labels_in_order = []
                df_sample_cluster[haxis] = self.metadata[haxis]

                for group, group_df in df_sample_cluster.groupby(haxis):

                    if group_df.shape[0] <= 3:
                        # we can't cluster
                        labels_in_order.extend(
                            sorted(magic_metadata["Label"][
                                group_df.index].tolist()))
                        continue

                    # Drop the grouping column again so only taxa columns are clustered.
                    sample_cluster = group_df.drop(
                        columns=[haxis]).ocx._cluster_by_sample(
                            rank=rank, metric=metric, linkage=linkage)
                    labels_in_order.extend(magic_metadata["Label"][
                        sample_cluster["ids_in_order"]].tolist())
        else:
            labels_in_order = sort_helper(sort_x,
                                          magic_metadata["Label"].tolist())

        # should ultimately be Label, tax_name, readcount_w_children, then custom fields
        tooltip_for_altair = [magic_fields[f] for f in tooltip]
        tooltip_for_altair.insert(1, "tax_name")
        tooltip_for_altair.insert(2, "{}:Q".format(df.ocx.metric))

        alt_kwargs = dict(
            x=alt.X("Label:N",
                    axis=alt.Axis(title=xlabel),
                    sort=labels_in_order),
            y=alt.Y("tax_name:N",
                    axis=alt.Axis(title=ylabel),
                    sort=taxa_cluster),
            color=alt.Color("{}:Q".format(df.ocx.metric),
                            legend=alt.Legend(title=legend)),
            tooltip=tooltip_for_altair,
            href="url:N",
            url=get_base_classification_url() + alt.datum.classification_id,
        )

        if haxis:
            alt_kwargs["column"] = alt.Column(haxis,
                                              header=alt.Header(
                                                  titleOrient="bottom",
                                                  labelOrient="bottom"))

        # "url" is a calculated field, not an encoding channel, so pop it before encode().
        chart = (alt.Chart(df).transform_calculate(
            url=alt_kwargs.pop("url")).mark_rect().encode(**alt_kwargs))

        col_count = len(labels_in_order)
        row_count = len(taxa_cluster)

        chart = chart.properties(
            **prepare_props(title=title,
                            height=(height or 15 * row_count),
                            width=(width or 15 * col_count)))

        if haxis:
            chart = chart.resolve_scale(x="independent")

        if return_chart:
            return chart
        else:
            chart.interactive().display()