def plot_pca(
    self,
    rank=Rank.Auto,
    normalize="auto",
    org_vectors=0,
    org_vectors_scale=None,
    title=None,
    xlabel=None,
    ylabel=None,
    color=None,
    size=None,
    tooltip=None,
    return_chart=False,
    label=None,
    mark_size=100,
    width=None,
    height=None,
):
    """Perform principal component analysis and plot first two axes.

    Parameters
    ----------
    rank : {'auto', 'kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'}, optional
        Analysis will be restricted to abundances of taxa at the specified level.
    normalize : 'auto' or `bool`, optional
        Convert read counts to relative abundances such that each sample sums to 1.0. Setting
        'auto' will choose automatically based on the data.
    org_vectors : `int`, optional
        Plot this many of the top-contributing eigenvectors from the PCA results. Values larger
        than the number of taxa are capped at the number of taxa.
    org_vectors_scale : `float`, optional
        Multiply the length of the lines representing the eigenvectors by this constant. When
        omitted, 0.8 times the largest absolute transformed PCA value is used.
    title : `string`, optional
        Text label at the top of the plot.
    xlabel : `string`, optional
        Text label along the horizontal axis.
    ylabel : `string`, optional
        Text label along the vertical axis.
    size : `string` or `tuple`, optional
        A string or a tuple containing strings representing metadata fields. The size of points
        in the resulting plot will change based on the metadata associated with each sample.
    color : `string` or `tuple`, optional
        A string or a tuple containing strings representing metadata fields. The color of points
        in the resulting plot will change based on the metadata associated with each sample.
    tooltip : `string` or `list`, optional
        A string or list containing strings representing metadata fields. When a point in the
        plot is hovered over, the value of the metadata associated with that sample will be
        displayed in a modal. A list passed here is not modified.
    return_chart : `bool`, optional
        When True, return the chart object instead of displaying it.
    label : `string` or `callable`, optional
        A metadata field (or function) used to label each analysis. If passing a function, a
        dict containing the metadata for each analysis is passed as the first and only
        positional argument. The callable function must return a string.
    mark_size : `int`, optional
        The size of the points in the scatter plot.
    width : optional
        Chart width, forwarded to the chart properties via `prepare_props`.
    height : optional
        Chart height, forwarded to the chart properties via `prepare_props`.

    Returns
    -------
    The chart object when `return_chart` is True; otherwise the chart is displayed and
    nothing is returned.

    Raises
    ------
    OneCodexException
        If `rank` is None.
    PlottingException
        If fewer than 3 samples or fewer than 2 taxa remain after filtering.

    Examples
    --------
    Perform PCA on relative abundances at the species-level and color the resulting points by
    'geo_loc_name', a metadata field representing the geographical origin of each sample.

    >>> plot_pca(rank='species', normalize=True, color='geo_loc_name')

    Change the size of each point in the plot based on the abundance of Bacteroides.

    >>> plot_pca(size='Bacteroides')

    Display the abundances of Bacteroides, Prevotella, and Bifidobacterium in each sample when
    hovering over points in the plot.

    >>> plot_pca(tooltip=['Bacteroides', 'Prevotella', 'Bifidobacterium'])
    """
    # Deferred imports
    import altair as alt
    import numpy as np
    import pandas as pd
    from sklearn.decomposition import PCA

    if rank is None:
        raise OneCodexException("Please specify a rank or 'auto' to choose automatically")

    if len(self._results) < 3:
        raise PlottingException(
            "There are too few samples for PCA after filtering. Please select 3 or more "
            "samples to plot."
        )

    df = self.to_df(rank=rank, normalize=normalize)

    if len(df.columns) < 2:
        raise PlottingException(
            "There are too few taxa for PCA after filtering. Please select a rank that "
            "includes at least 2 taxa."
        )

    # Work on a copy: the inserts below must not leak into the caller's list.
    if tooltip:
        tooltip = list(tooltip) if isinstance(tooltip, list) else [tooltip]
    else:
        tooltip = []

    tooltip.insert(0, "Label")

    if color and color not in tooltip:
        tooltip.insert(1, color)
    if size and size not in tooltip:
        tooltip.insert(2, size)

    magic_metadata, magic_fields = self._metadata_fetch(tooltip, label=label)

    pca = PCA()
    pca_vals = pca.fit(df.values).transform(df.values)
    pca_vals = pd.DataFrame(pca_vals, index=df.index)
    pca_vals.rename(columns=lambda x: "PC{}".format(x + 1), inplace=True)

    # label the axes
    if xlabel is None:
        xlabel = "PC1 ({}%)".format(round(pca.explained_variance_ratio_[0] * 100, 2))
    if ylabel is None:
        ylabel = "PC2 ({}%)".format(round(pca.explained_variance_ratio_[1] * 100, 2))

    # don't send all the data to vega, just what we're plotting
    plot_data = pd.concat(
        [pca_vals.loc[:, ["PC1", "PC2"]], magic_metadata], axis=1
    ).reset_index()

    alt_kwargs = dict(
        x=alt.X("PC1", axis=alt.Axis(title=xlabel)),
        y=alt.Y("PC2", axis=alt.Axis(title=ylabel)),
        tooltip=[magic_fields[t] for t in tooltip],
        href="url:N",
        url=get_base_classification_url() + alt.datum.classification_id,
    )

    # only add these parameters if they are in use
    if color:
        # The metadata column is named after the resolved field, which can differ from the
        # requested `color` (the encoding below already uses the resolved name).
        color_field = magic_fields[color]
        color_kwargs = {
            "legend": alt.Legend(title=color_field),
        }
        if not is_continuous(plot_data[color_field]) or has_missing_values(
            plot_data[color_field]
        ):
            plot_data[color_field] = plot_data[color_field].fillna("N/A").astype(str)
            domain = plot_data[color_field].values
            color_range = interleave_palette(domain)
            color_kwargs["scale"] = alt.Scale(domain=domain, range=color_range)

        alt_kwargs["color"] = alt.Color(color_field, **color_kwargs)
    if size:
        alt_kwargs["size"] = magic_fields[size]

    chart = (
        alt.Chart(plot_data)
        .transform_calculate(url=alt_kwargs.pop("url"))
        .mark_circle(size=mark_size)
    )

    vector_chart = None

    # plot the organism eigenvectors that contribute the most
    if org_vectors > 0:
        vector_data = {
            "x": [],
            "y": [],
            "o": [],  # order these points should be connected in
            "Eigenvectors": [],
        }

        magnitudes = np.sqrt(pca.components_[0] ** 2 + pca.components_[1] ** 2)
        # Cap the request at the number of taxa so the cutoff lookup cannot go out of range.
        remaining = min(org_vectors, len(magnitudes))
        cutoff = np.sort(magnitudes)[-remaining]

        if org_vectors_scale is None:
            org_vectors_scale = 0.8 * np.max(pca_vals.abs().values)

        for tax_id, var1, var2 in zip(
            df.columns.values, pca.components_[0, :], pca.components_[1, :]
        ):
            if np.sqrt(var1 ** 2 + var2 ** 2) >= cutoff:
                vector_data["x"].extend([0, var1 * float(org_vectors_scale)])
                vector_data["y"].extend([0, var2 * float(org_vectors_scale)])
                vector_data["o"].extend([0, 1])
                vector_data["Eigenvectors"].extend([self.taxonomy["name"][tax_id]] * 2)

                remaining -= 1
                if remaining == 0:
                    break

        vector_data = pd.DataFrame(vector_data)

        vector_chart = (
            alt.Chart(vector_data)
            .mark_line(point=False)
            .encode(
                x=alt.X("x", axis=None),
                y=alt.Y("y", axis=None),
                order="o",
                color="Eigenvectors",
            )
        )

    chart = chart.encode(**alt_kwargs)

    if vector_chart is not None:
        # independent color scales keep the eigenvector legend separate from the sample colors
        chart = alt.layer(chart, vector_chart).resolve_scale(color="independent")

    chart = chart.properties(**prepare_props(title=title, height=height, width=width))

    if return_chart:
        return chart
    else:
        chart.interactive().display()
def plot_distance(
    self,
    rank=Rank.Auto,
    metric=BetaDiversityMetric.BrayCurtis,
    title=None,
    xlabel=None,
    ylabel=None,
    tooltip=None,
    return_chart=False,
    linkage=Linkage.Average,
    label=None,
    width=None,
    height=None,
):
    """Plot beta diversity distance matrix as a heatmap and dendrogram.

    Parameters
    ----------
    rank : {'auto', 'kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'}, optional
        Analysis will be restricted to abundances of taxa at the specified level.
    metric : {'braycurtis', 'cityblock', 'manhattan', 'jaccard', 'unifrac', 'unweighted_unifrac', 'aitchison'}, optional
        Function to use when calculating the distance between two samples. Note that 'cityblock'
        and 'manhattan' are equivalent metrics.
    linkage : {'average', 'single', 'complete', 'weighted', 'centroid', 'median'}
        The type of linkage to use when clustering axes.
    title : `string`, optional
        Text label at the top of the plot.
    xlabel : `string`, optional
        Text label along the horizontal axis.
    ylabel : `string`, optional
        Text label along the vertical axis.
    tooltip : `string` or `list`, optional
        A string or list containing strings representing metadata fields. When a point in the
        plot is hovered over, the value of the metadata associated with that sample will be
        displayed in a modal. A list passed here is not modified.
    return_chart : `bool`, optional
        When True, return the concatenated chart object instead of displaying it.
    label : `string` or `callable`, optional
        A metadata field (or function) used to label each analysis. If passing a function, a
        dict containing the metadata for each analysis is passed as the first and only
        positional argument. The callable function must return a string.
    width : optional
        Heatmap width, forwarded to the chart properties via `prepare_props`. The heatmap is
        initially created with 15 units per sample.
    height : optional
        Heatmap height, forwarded to the chart properties via `prepare_props`. When given, the
        dendrogram height is set to `height` minus half of one heatmap cell.

    Returns
    -------
    The concatenated dendrogram + heatmap chart when `return_chart` is True; otherwise the
    chart is displayed and nothing is returned.

    Raises
    ------
    PlottingException
        If fewer than 2 samples remain after filtering.

    Examples
    --------
    Plot the weighted UniFrac distance between all our samples, using counts at the genus
    level.

    >>> plot_distance(rank='genus', metric='unifrac')
    """
    # Deferred imports
    import altair as alt
    import numpy as np
    import pandas as pd

    from onecodex.viz import dendrogram

    if len(self._results) < 2:
        raise PlottingException(
            "There are too few samples for distance matrix plots after filtering. Please "
            "select 2 or more samples to plot."
        )

    # this will be passed to the heatmap chart as a dataframe eventually
    plot_data = {
        "1) Label": [],
        "2) Label": [],
        "Distance": [],
        "classification_id": [],
    }

    # here we figure out what to put in the tooltips and get the appropriate data.
    # Work on a copy: the insert below must not leak into the caller's list.
    if tooltip:
        tooltip = list(tooltip) if isinstance(tooltip, list) else [tooltip]
    else:
        tooltip = []

    tooltip.insert(0, "Label")

    magic_metadata, magic_fields = self._metadata_fetch(tooltip, label=label)

    # every tooltip field gets one column per side of the pairwise comparison
    formatted_fields = []

    for magic_field in magic_fields.values():
        field_group = []

        for i in (1, 2):
            field = "{}) {}".format(i, magic_field)
            plot_data[field] = []
            field_group.append(field)

        formatted_fields.append(field_group)

    clust = self._cluster_by_sample(rank=rank, metric=metric, linkage=linkage)
    dist_matrix = clust["dist_matrix"]
    sample_ids = dist_matrix.index

    # must convert to long format for heatmap plotting
    for idx1, id1 in enumerate(sample_ids):
        for idx2, id2 in enumerate(sample_ids):
            if idx1 == idx2:
                # the diagonal (a sample against itself) is stored as NaN rather than a distance
                plot_data["Distance"].append(np.nan)
            else:
                plot_data["Distance"].append(dist_matrix.iloc[idx1, idx2])

            plot_data["classification_id"].append(id1)

            for field_group, magic_field in zip(formatted_fields, magic_fields.values()):
                plot_data[field_group[0]].append(magic_metadata[magic_field][id1])
                plot_data[field_group[1]].append(magic_metadata[magic_field][id2])

    plot_data = pd.DataFrame(data=plot_data)

    labels_in_order = magic_metadata["Label"][clust["ids_in_order"]].tolist()

    # it's important to tell altair to order the cells in the heatmap according to the clustering
    # obtained from scipy
    alt_kwargs = dict(
        x=alt.X("1) Label:N", axis=alt.Axis(title=xlabel), sort=labels_in_order),
        y=alt.Y(
            "2) Label:N", axis=alt.Axis(title=ylabel, orient="right"), sort=labels_in_order
        ),
        color=alt.Color("Distance:Q", legend=alt.Legend(title="Distance")),
        tooltip=list(chain.from_iterable(formatted_fields)) + ["Distance:Q"],
        href="url:N",
        url=get_base_classification_url() + alt.datum.classification_id,
    )

    chart = (
        alt.Chart(
            plot_data,
            width=15 * len(sample_ids),
            height=15 * len(sample_ids),
        )
        .transform_calculate(url=alt_kwargs.pop("url"))
        .mark_rect()
        .encode(**alt_kwargs)
    )

    chart = chart.properties(**prepare_props(height=height, width=width))

    dendro_chart = dendrogram(clust["scipy_tree"])

    if height:
        cell_height = height / len(sample_ids)
        dendro_chart = dendro_chart.properties(height=height - cell_height / 2)

    title_kwargs = prepare_props(title=title)
    concat_chart = alt.hconcat(dendro_chart, chart, spacing=0, **title_kwargs).configure_view(
        strokeWidth=0
    )

    if return_chart:
        return concat_chart
    else:
        concat_chart.display()
def plot_mds(
    self,
    rank=Rank.Auto,
    metric=BetaDiversityMetric.BrayCurtis,
    method=OrdinationMethod.Pcoa,
    title=None,
    xlabel=None,
    ylabel=None,
    color=None,
    size=None,
    tooltip=None,
    return_chart=False,
    label=None,
    mark_size=100,
    width=None,
    height=None,
):
    """Plot beta diversity distance matrix using multidimensional scaling (MDS).

    Parameters
    ----------
    rank : {'auto', 'kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'}, optional
        Analysis will be restricted to abundances of taxa at the specified level.
    metric : {'braycurtis', 'cityblock', 'manhattan', 'jaccard', 'unifrac', 'unweighted_unifrac', 'aitchison'}, optional
        Function to use when calculating the distance between two samples. Note that 'cityblock'
        and 'manhattan' are equivalent metrics.
    method : {'pcoa', 'smacof'}
        Algorithm to use for ordination. PCoA uses eigenvalue decomposition and is not well
        suited to non-euclidean distance functions. SMACOF is an iterative optimization strategy
        that can be used as an alternative.
    title : `string`, optional
        Text label at the top of the plot.
    xlabel : `string`, optional
        Text label along the horizontal axis.
    ylabel : `string`, optional
        Text label along the vertical axis.
    size : `string` or `tuple`, optional
        A string or a tuple containing strings representing metadata fields. The size of points
        in the resulting plot will change based on the metadata associated with each sample.
    color : `string` or `tuple`, optional
        A string or a tuple containing strings representing metadata fields. The color of points
        in the resulting plot will change based on the metadata associated with each sample.
    tooltip : `string` or `list`, optional
        A string or list containing strings representing metadata fields. When a point in the
        plot is hovered over, the value of the metadata associated with that sample will be
        displayed in a modal. A list passed here is not modified.
    return_chart : `bool`, optional
        When True, return the chart object instead of displaying it.
    label : `string` or `callable`, optional
        A metadata field (or function) used to label each analysis. If passing a function, a
        dict containing the metadata for each analysis is passed as the first and only
        positional argument. The callable function must return a string.
    mark_size : `int`, optional
        The size of the points in the scatter plot.
    width : optional
        Chart width, forwarded to the chart properties via `prepare_props`.
    height : optional
        Chart height, forwarded to the chart properties via `prepare_props`.

    Returns
    -------
    The chart object when `return_chart` is True; otherwise the chart is displayed and
    nothing is returned.

    Raises
    ------
    PlottingException
        If fewer than 3 samples remain after filtering.
    OneCodexException
        If `method` is not a supported ordination method.

    Examples
    --------
    Scatter plot of weighted UniFrac distance between all our samples, using counts at the
    genus level.

    >>> plot_mds(rank='genus', metric='unifrac')

    Notes
    -----
    **For `smacof`**: The values reported on the axis labels are Pearson's correlations between
    the distances between points on each axis alone, and the corresponding distances in the
    distance matrix calculated using the user-specified metric. These values are related to the
    effectiveness of the MDS algorithm in placing points on the scatter plot in such a way that
    they truly represent the calculated distances. They do not reflect how well the distance
    metric captures similarities between the underlying data (in this case, an OTU table).
    """
    # Deferred imports
    import altair as alt
    import numpy as np
    import pandas as pd
    from scipy.spatial.distance import squareform
    from scipy.stats import pearsonr
    from skbio.stats import ordination
    from sklearn import manifold
    from sklearn.metrics.pairwise import euclidean_distances

    if len(self._results) < 3:
        raise PlottingException(
            "There are too few samples for MDS/PCoA after filtering. Please select 3 or more "
            "samples to plot."
        )

    dists = self._compute_distance(rank, metric).to_data_frame()

    # here we figure out what to put in the tooltips and get the appropriate data.
    # Work on a copy: the inserts below must not leak into the caller's list.
    if tooltip:
        tooltip = list(tooltip) if isinstance(tooltip, list) else [tooltip]
    else:
        tooltip = []

    tooltip.insert(0, "Label")

    if color and color not in tooltip:
        tooltip.insert(1, color)
    if size and size not in tooltip:
        tooltip.insert(2, size)

    magic_metadata, magic_fields = self._metadata_fetch(tooltip, label=label)

    if method == OrdinationMethod.Smacof:
        # adapted from https://scikit-learn.org/stable/auto_examples/manifold/plot_mds.html
        x_field = "MDS1"
        y_field = "MDS2"

        seed = np.random.RandomState(seed=3)
        mds = manifold.MDS(
            max_iter=3000,
            eps=1e-12,
            random_state=seed,
            dissimilarity="precomputed",
            n_jobs=1,
        )
        pos = mds.fit(dists).embedding_
        plot_data = pd.DataFrame(pos, columns=[x_field, y_field], index=dists.index)
        # scale each axis by its largest absolute value, i.e. into [-1, 1]
        plot_data = plot_data.div(plot_data.abs().max(axis=0), axis=1)

        # determine how much of the original distance is captured by each of the axes after MDS.
        # this implementation of MDS does not use eigen decomposition and so there's no simple
        # way of returning a 'percent of variance explained' value
        r_squared = []

        for axis in (0, 1):
            # Keep only the coordinates along `axis` (everything else zeroed) so the pairwise
            # distances reflect that axis alone; entry `axis` then labels that same axis.
            mds_dist = np.zeros_like(pos)
            mds_dist[:, axis] = pos[:, axis]
            mds_dist = squareform(euclidean_distances(mds_dist).round(6))
            r_squared.append(pearsonr(mds_dist, squareform(dists))[0])

        # label the axes
        # NOTE(review): pearsonr returns the correlation coefficient r, while the label text
        # says "r²"; the label is left unchanged here -- confirm which statistic is intended.
        x_extra_label = "r² = %.02f" % (r_squared[0],)
        y_extra_label = "r² = %.02f" % (r_squared[1],)
    elif method == OrdinationMethod.Pcoa:
        # suppress eigenvalue warning from skbio--not because it's an invalid warning, but
        # because lots of folks in the field run pcoa on these distances functions, even if
        # statistically inappropriate. perhaps this will change if we ever become more
        # opinionated about the analyses that we allow our users to do (roo)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            ord_result = ordination.pcoa(dists.round(6))  # round to avoid float precision errors

        plot_data = ord_result.samples.iloc[:, [0, 1]]  # get first two components
        # scale each axis by its largest absolute value, i.e. into [-1, 1]
        plot_data = plot_data.div(plot_data.abs().max(axis=0), axis=1)
        plot_data.index = dists.index
        x_field, y_field = plot_data.columns.tolist()  # name of first two components

        x_extra_label = "%0.02f%%" % (ord_result.proportion_explained[0] * 100,)
        y_extra_label = "%0.02f%%" % (ord_result.proportion_explained[1] * 100,)
    else:
        raise OneCodexException(
            "MDS method must be one of: {}".format(", ".join(OrdinationMethod.values))
        )

    # label the axes
    if xlabel is None:
        xlabel = "{} ({})".format(x_field, x_extra_label)
    if ylabel is None:
        ylabel = "{} ({})".format(y_field, y_extra_label)

    plot_data = pd.concat([plot_data, magic_metadata], axis=1).reset_index()

    alt_kwargs = dict(
        x=alt.X(x_field, axis=alt.Axis(title=xlabel)),
        y=alt.Y(y_field, axis=alt.Axis(title=ylabel)),
        tooltip=[magic_fields[t] for t in tooltip],
        href="url:N",
        url=get_base_classification_url() + alt.datum.classification_id,
    )

    # only add these parameters if they are in use
    if color:
        # The metadata column is named after the resolved field, which can differ from the
        # requested `color` (the encoding below already uses the resolved name).
        color_field = magic_fields[color]
        color_kwargs = {
            "legend": alt.Legend(title=color_field),
        }
        if not is_continuous(plot_data[color_field]) or has_missing_values(
            plot_data[color_field]
        ):
            plot_data[color_field] = plot_data[color_field].fillna("N/A").astype(str)
            domain = plot_data[color_field].values
            color_range = interleave_palette(domain)
            color_kwargs["scale"] = alt.Scale(domain=domain, range=color_range)

        alt_kwargs["color"] = alt.Color(color_field, **color_kwargs)
    if size:
        alt_kwargs["size"] = magic_fields[size]

    chart = (
        alt.Chart(plot_data)
        .transform_calculate(url=alt_kwargs.pop("url"))
        .mark_circle(size=mark_size)
        .encode(**alt_kwargs)
    )

    chart = chart.properties(**prepare_props(title=title, height=height, width=width))

    if return_chart:
        return chart
    else:
        chart.interactive().display()
def plot_metadata(
    self,
    rank=Rank.Auto,
    haxis="Label",
    vaxis=AlphaDiversityMetric.Shannon,
    title=None,
    xlabel=None,
    ylabel=None,
    return_chart=False,
    plot_type=PlotType.Auto,
    label=None,
    sort_x=None,
    width=200,
    height=400,
):
    """Plot an arbitrary metadata field versus an arbitrary quantity as a boxplot or scatter plot.

    Parameters
    ----------
    rank : {'auto', 'kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'}, optional
        Analysis will be restricted to abundances of taxa at the specified level.
    haxis : `string`, optional
        The metadata field (or tuple containing multiple categorical fields) to be plotted on
        the horizontal axis.
    vaxis : `string`, optional
        Data to be plotted on the vertical axis. Can be any one of the following:

        - A metadata field: the name of a metadata field containing numerical data
        - {'simpson', 'observed_taxa', 'shannon'}: an alpha diversity statistic to calculate for
          each sample
        - A taxon name: the name of a taxon in the analysis
        - A taxon ID: the ID of a taxon in the analysis
    title : `string`, optional
        Text label at the top of the plot.
    xlabel : `string`, optional
        Text label along the horizontal axis.
    ylabel : `string`, optional
        Text label along the vertical axis.
    return_chart : `bool`, optional
        When True, return the chart object instead of displaying it.
    plot_type : {'auto', 'boxplot', 'scatter'}
        By default, will determine plot type automatically based on the data. Otherwise, specify
        one of 'boxplot' or 'scatter' to set the type of plot manually.
    label : `string` or `callable`, optional
        A metadata field (or function) used to label each analysis. If passing a function, a
        dict containing the metadata for each analysis is passed as the first and only
        positional argument. The callable function must return a string.
    sort_x : `list` or `callable`, optional
        Either a list of sorted labels or a function that will be called with a list of x-axis
        labels as the only argument, and must return the same list in a user-specified order.
        Not allowed for boxplots.
    width : optional
        Chart width, forwarded to the chart properties via `prepare_props`. For boxplots, a
        numeric width is also used to shrink the boxes when they would not otherwise fit.
    height : optional
        Chart height, forwarded to the chart properties via `prepare_props`.

    Returns
    -------
    The chart object when `return_chart` is True; otherwise the chart is displayed and
    nothing is returned.

    Raises
    ------
    OneCodexException
        If `rank` is None, `plot_type` is invalid, the vertical axis is not numerical, the
        horizontal axis has an unplottable type, or `sort_x` is given for a boxplot.
    PlottingException
        If no samples remain after filtering.

    Examples
    --------
    Generate a boxplot of the abundance of Bacteroides (genus) of samples grouped by whether the
    individuals are allergic to dogs, cats, both, or neither.

    >>> plot_metadata(haxis=('allergy_dogs', 'allergy_cats'), vaxis='Bacteroides')
    """
    # Deferred imports
    import altair as alt
    import pandas as pd

    if rank is None:
        raise OneCodexException("Please specify a rank or 'auto' to choose automatically")

    if not PlotType.has_value(plot_type):
        raise OneCodexException("Plot type must be one of: auto, boxplot, scatter")

    if len(self._results) < 1:
        raise PlottingException(
            "There are too few samples for metadata plots after filtering. Please select 1 or "
            "more samples to plot."
        )

    # alpha diversity is only allowed on vertical axis--horizontal can be magically mapped
    df, magic_fields = self._metadata_fetch([haxis, "Label"], label=label)

    if AlphaDiversityMetric.has_value(vaxis):
        df.loc[:, vaxis] = self.alpha_diversity(vaxis, rank=rank)
        magic_fields[vaxis] = vaxis
        df.dropna(subset=[magic_fields[vaxis]], inplace=True)
    else:
        # if it's not alpha diversity, vertical axis can also be magically mapped
        vert_df, vert_magic_fields = self._metadata_fetch([vaxis])
        vert_values = vert_df[vert_magic_fields[vaxis]]

        # we require the vertical axis to be numerical otherwise plots get weird. Booleans count
        # as numeric for pandas, so they are rejected explicitly; categorical and object columns
        # are already non-numeric.
        if pd.api.types.is_bool_dtype(vert_values) or not pd.api.types.is_numeric_dtype(
            vert_values
        ):
            raise OneCodexException("Metadata field on vertical axis must be numerical")

        df = pd.concat([df, vert_df], axis=1).dropna(subset=[vert_magic_fields[vaxis]])
        magic_fields.update(vert_magic_fields)

    haxis_field = magic_fields[haxis]
    vaxis_field = magic_fields[vaxis]
    haxis_values = df[haxis_field]

    # plots can look different depending on what the horizontal axis contains
    if pd.api.types.is_datetime64_any_dtype(haxis_values):
        if plot_type == PlotType.Auto:
            plot_type = PlotType.BoxPlot
    elif "date" in haxis_field.split("_"):
        # field name suggests a date even though the column is not datetime-typed yet
        df.loc[:, haxis_field] = df.loc[:, haxis_field].apply(pd.to_datetime, utc=True)

        if plot_type == PlotType.Auto:
            plot_type = PlotType.BoxPlot
    elif (
        pd.api.types.is_bool_dtype(haxis_values)
        # isinstance check replaces the deprecated pd.api.types.is_categorical_dtype
        or isinstance(haxis_values.dtype, pd.api.types.CategoricalDtype)
        or pd.api.types.is_object_dtype(haxis_values)
    ):
        df = df.fillna({field: "N/A" for field in df.columns})

        if plot_type == PlotType.Auto:
            # if data is categorical but there is only one value per sample, scatter plot instead
            if len(df[haxis_field].unique()) == len(df[haxis_field]):
                plot_type = PlotType.Scatter
            else:
                plot_type = PlotType.BoxPlot
    elif pd.api.types.is_numeric_dtype(haxis_values):
        df = df.dropna(subset=[vaxis_field])

        if plot_type == PlotType.Auto:
            plot_type = PlotType.Scatter
    else:
        raise OneCodexException(
            "Unplottable column type for horizontal axis ({})".format(haxis)
        )

    if xlabel is None:
        xlabel = haxis_field
    if ylabel is None:
        ylabel = vaxis_field

    if plot_type == PlotType.Scatter:
        # reset_index exposes classification_id as a column for the link calculation below
        df = df.reset_index()

        sort_order = sort_helper(sort_x, df[haxis_field].tolist())

        alt_kwargs = dict(
            x=alt.X(haxis_field, axis=alt.Axis(title=xlabel), sort=sort_order),
            y=alt.Y(vaxis_field, axis=alt.Axis(title=ylabel)),
            tooltip=["Label", "{}:Q".format(vaxis_field)],
            href="url:N",
            url=get_base_classification_url() + alt.datum.classification_id,
        )

        chart = (
            alt.Chart(df)
            .transform_calculate(url=alt_kwargs.pop("url"))
            .mark_circle()
            .encode(**alt_kwargs)
        )
    elif plot_type == PlotType.BoxPlot:
        if sort_x:
            raise OneCodexException("Must not specify sort_x when plot_type is boxplot")

        # See the following issue in case this gets fixed in altair:
        # https://github.com/altair-viz/altair/issues/2144
        if (df.groupby(haxis_field).size() < 2).any():
            warnings.warn(
                "There is at least one sample group consisting of only a single sample. Groups "
                "of size 1 may not have their boxes displayed in the plot.",
                PlottingWarning,
            )

        box_size = 45
        increment = 5
        n_boxes = len(df[haxis_field].unique())

        if width and width != "container" and (n_boxes * (box_size + increment)) > width:
            # shrink boxes to fit, rounded down to a multiple of `increment`; never let the
            # size collapse to zero or go negative when there are very many boxes
            box_size = max(increment, ((width / n_boxes) // increment) * increment - increment)

        chart = (
            alt.Chart(df)
            .mark_boxplot(size=box_size)
            .encode(
                x=alt.X(haxis_field, axis=alt.Axis(title=xlabel)),
                y=alt.Y(vaxis_field, axis=alt.Axis(title=ylabel)),
            )
        )

    chart = chart.properties(**prepare_props(title=title, height=height, width=width))

    if return_chart:
        return chart
    else:
        chart.interactive().display()
def plot_bargraph(
    self,
    rank=Rank.Auto,
    normalize="auto",
    top_n="auto",
    threshold="auto",
    title=None,
    xlabel=None,
    ylabel=None,
    tooltip=None,
    return_chart=False,
    haxis=None,
    legend="auto",
    label=None,
    sort_x=None,
    include_taxa_missing_rank=None,
    include_other=True,
    width=None,
    height=None,
):
    """Plot a bargraph of relative abundance of taxa for multiple samples.

    Parameters
    ----------
    rank : {'auto', 'kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'}, optional
        Analysis will be restricted to abundances of taxa at the specified level.
    normalize : 'auto' or `bool`, optional
        Convert read counts to relative abundances such that each sample sums to 1.0. Setting
        'auto' will choose automatically based on the data.
    return_chart : `bool`, optional
        When True, return an `altair.Chart` object instead of displaying the resulting plot in
        the current notebook.
    top_n : `int`, optional
        Display the top N most abundant taxa in the entire cohort of samples. When both `top_n`
        and `threshold` are left as 'auto', the top 10 taxa are shown.
    threshold : `float`
        Display only taxa that are more abundant than this threshold in one or more samples.
    title : `string`, optional
        Text label at the top of the plot.
    xlabel : `string`, optional
        Text label along the horizontal axis.
    ylabel : `string`, optional
        Text label along the vertical axis.
    tooltip : `string` or `list`, optional
        A string or list containing strings representing metadata fields. When a point in the
        plot is hovered over, the value of the metadata associated with that sample will be
        displayed in a modal. A list passed here is not modified.
    haxis : `string`, optional
        The metadata field (or tuple containing multiple categorical fields) used to group
        samples together.
    legend: `string`, optional
        Title for color scale. Defaults to the metric used to generate the plot, e.g.
        readcount_w_children or abundance.
    label : `string` or `callable`, optional
        A metadata field (or function) used to label each analysis. If passing a function, a
        dict containing the metadata for each analysis is passed as the first and only
        positional argument. The callable function must return a string.
    sort_x : `list` or `callable`, optional
        Either a list of sorted labels or a function that will be called with a list of x-axis
        labels as the only argument, and must return the same list in a user-specified order.
    include_taxa_missing_rank : `bool`, optional
        Whether or not a row should be plotted for taxa that do not have a designated parent at
        `rank`. When left as None it is switched on for abundance metrics.
    include_other : `bool`, optional
        When True (and `normalize` is truthy), add an "Other" bar holding one minus the sum of
        the displayed taxa for each sample.
    width : optional
        Chart width, forwarded to the chart properties via `prepare_props`.
    height : optional
        Chart height, forwarded to the chart properties via `prepare_props`.

    Returns
    -------
    The chart object when `return_chart` is True; otherwise the result of displaying it.

    Raises
    ------
    OneCodexException
        If `rank` is None, if neither `threshold` nor `top_n` is set, or if taxa missing the
        rank are requested for a metric other than abundance with children.
    PlottingException
        If no samples remain after filtering.

    Examples
    --------
    Plot a bargraph of the top 10 most abundant genera

    >>> plot_bargraph(rank='genus', top_n=10)
    """
    # Deferred imports
    import altair as alt

    if rank is None:
        raise OneCodexException("Please specify a rank or 'auto' to choose automatically")

    if not (threshold or top_n):
        raise OneCodexException("Please specify at least one of: threshold, top_n")

    if len(self._results) < 1:
        raise PlottingException(
            "There are too few samples for bargraph plots after filtering. Please select 1 or "
            "more samples to plot."
        )

    # resolve the 'auto' placeholders: default to the top 10 taxa, otherwise only honor the
    # filter that the caller set explicitly
    if top_n == "auto" and threshold == "auto":
        top_n = 10
        threshold = None
    elif top_n == "auto" and threshold != "auto":
        top_n = None
    elif top_n != "auto" and threshold == "auto":
        threshold = None

    df = self.to_df(rank=rank, normalize=normalize, threshold=threshold)

    no_level_name = "No {}".format(rank)

    if AbundanceMetric.has_value(self._metric) and include_taxa_missing_rank is None:
        include_taxa_missing_rank = True

    if include_taxa_missing_rank:
        if self._metric != Metric.AbundanceWChildren:
            raise OneCodexException(
                "No-level data can only be imputed on abundances w/ children"
            )

        # whatever is not accounted for by taxa at this rank
        df[no_level_name] = 1 - df.sum(axis=1)

    # keep the taxa with the highest mean across samples (all of them when top_n is None)
    top_taxa = df.mean().sort_values(ascending=False).iloc[:top_n].index
    df = df[top_taxa]

    if include_other and normalize:
        df["Other"] = 1 - df.sum(axis=1)

    if legend == "auto":
        legend = self.metric

    # Work on a copy: the append/insert below must not leak into the caller's list.
    if tooltip:
        tooltip = list(tooltip) if isinstance(tooltip, list) else [tooltip]
    else:
        tooltip = []

    if haxis and haxis not in tooltip:
        tooltip.append(haxis)

    tooltip.insert(0, "Label")

    # takes metadata columns and returns a dataframe with just those columns
    # renames columns in the case where columns are taxids
    magic_metadata, magic_fields = self._metadata_fetch(tooltip, label=label)

    df = df.join(magic_metadata)

    df = df.reset_index().melt(
        id_vars=["classification_id"] + magic_metadata.columns.tolist(),
        var_name="tax_id",
        value_name=self.metric,
    )

    # add taxa names
    df["tax_name"] = df["tax_id"].apply(
        lambda t: "{} ({})".format(self.taxonomy["name"][t], t)
        if t in self.taxonomy["name"]
        else t
    )

    #
    # TODO: how to sort bars in bargraph
    # - abundance (mean across all samples)
    # - parent taxon (this will require that we make a few assumptions
    # about taxonomic ranks but as all taxonomic data will be coming from
    # OCX this should be okay)
    #

    ylabel = ylabel or self.metric
    xlabel = xlabel or ""

    # should ultimately be Label, tax_name, readcount_w_children, then custom fields
    tooltip_for_altair = [magic_fields[f] for f in tooltip]
    tooltip_for_altair.insert(1, "tax_name")
    tooltip_for_altair.insert(2, "{}:Q".format(self.metric))

    kwargs = {}

    if haxis:
        # Facet on the resolved column name; it can differ from the requested `haxis`
        # (e.g. when a tuple of fields is given).
        kwargs["column"] = alt.Column(
            magic_fields[haxis],
            header=alt.Header(titleOrient="bottom", labelOrient="bottom"),
        )

    domain = sorted(df["tax_name"].unique())
    color_range = interleave_palette(set(domain) - {"Other", no_level_name})
    other_color = ["#DCE0E5"]
    no_level_color = ["#eeefe1"]

    # the synthetic categories are moved to the front of the domain with fixed colors
    if include_taxa_missing_rank and no_level_name in domain:
        domain.remove(no_level_name)
        domain = [no_level_name] + domain
        color_range = no_level_color + color_range

    if include_other and "Other" in domain:
        domain.remove("Other")
        domain = ["Other"] + domain
        color_range = other_color + color_range

    sort_order = sort_helper(sort_x, df["Label"].tolist())

    # stacking order of the bars follows the position of each taxon in the color domain
    df["order"] = df["tax_name"].apply(domain.index)

    y_scale_kwargs = {"zero": True, "nice": False}
    if normalize:
        y_scale_kwargs["domain"] = [0, 1]

    chart = (
        alt.Chart(df)
        .transform_calculate(url=get_base_classification_url() + alt.datum.classification_id)
        .mark_bar()
        .encode(
            x=alt.X("Label", axis=alt.Axis(title=xlabel), sort=sort_order),
            y=alt.Y(
                self.metric,
                axis=alt.Axis(title=ylabel),
                scale=alt.Scale(**y_scale_kwargs),
            ),
            color=alt.Color(
                "tax_name",
                legend=alt.Legend(title=legend),
                sort=domain,
                scale=alt.Scale(domain=domain, range=color_range),
            ),
            tooltip=tooltip_for_altair,
            href="url:N",
            order=alt.Order("order", sort="descending"),
            **kwargs
        )
    )

    if haxis:
        chart = chart.resolve_scale(x="independent")

    chart = chart.properties(**prepare_props(title=title, width=width, height=height))

    return chart if return_chart else chart.display()
def plot_heatmap(
    self,
    rank=Rank.Auto,
    normalize="auto",
    top_n="auto",
    threshold="auto",
    title=None,
    xlabel=None,
    ylabel=None,
    tooltip=None,
    return_chart=False,
    linkage=Linkage.Average,
    haxis=None,
    metric="euclidean",
    legend="auto",
    label=None,
    sort_x=None,
    sort_y=None,
    width=None,
    height=None,
):
    """Plot heatmap of taxa abundance/count data for several samples.

    Parameters
    ----------
    rank : {'auto', 'kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'}, optional
        Analysis will be restricted to abundances of taxa at the specified level.
    normalize : 'auto' or `bool`, optional
        Convert read counts to relative abundances such that each sample sums to 1.0. Setting
        'auto' will choose automatically based on the data.
    return_chart : `bool`, optional
        When True, return an `altair.Chart` object instead of displaying the resulting plot in
        the current notebook.
    haxis : `string`, optional
        The metadata field (or tuple containing multiple categorical fields) used to group
        samples together. Each group of samples will be clustered independently.
    metric : {'euclidean', 'braycurtis', 'cityblock', 'manhattan', 'jaccard', 'unifrac',
              'unweighted_unifrac', 'aitchison'}, optional
        Function to use when calculating the distance between two samples. Note that
        'cityblock' and 'manhattan' are equivalent metrics.
    linkage : {'average', 'single', 'complete', 'weighted', 'centroid', 'median'}
        The type of linkage to use when clustering axes.
    top_n : `int`, optional
        Display the top N most abundant taxa in the entire cohort of samples. When both
        `top_n` and `threshold` are left at 'auto', the top 10 taxa are shown.
    threshold : `float`
        Display only taxa that are more abundant than this threshold in one or more samples.
    title : `string`, optional
        Text label at the top of the plot.
    xlabel : `string`, optional
        Text label along the horizontal axis.
    ylabel : `string`, optional
        Text label along the vertical axis.
    tooltip : `string` or `list`, optional
        A string or list containing strings representing metadata fields. When a point in the
        plot is hovered over, the value of the metadata associated with that sample will be
        displayed in a modal. A list passed here is copied and never modified.
    legend : `string`, optional
        Title for color scale. Defaults to the field used to generate the plot, e.g.
        readcount_w_children or abundance.
    label : `string` or `callable`, optional
        A metadata field (or function) used to label each analysis. If passing a function, a
        dict containing the metadata for each analysis is passed as the first and only
        positional argument. The callable function must return a string.
    sort_x : `list` or `callable`, optional
        Either a list of sorted labels or a function that will be called with a list of x-axis
        labels as the only argument, and must return the same list in a user-specified order.
        When given, samples are not clustered.
    sort_y : `list` or `callable`, optional
        Either a list of sorted labels or a function that will be called with a list of y-axis
        labels as the only argument, and must return the same list in a user-specified order.
        When given, taxa are not clustered.
    width : `int`, optional
        Chart width. When not given, 15 per entry in the x-axis ordering is used.
    height : `int`, optional
        Chart height. When not given, 15 per entry in the y-axis ordering is used.

    Returns
    -------
    `altair.Chart` or None
        The chart when `return_chart` is True; otherwise the chart is displayed and nothing
        is returned.

    Raises
    ------
    OneCodexException
        If `rank` is None, if neither `threshold` nor `top_n` is set, if the results are
        already normalized and `metric` is not 'euclidean', or if the `haxis` metadata field
        is not boolean, categorical or object typed.
    PlottingException
        If fewer than 2 samples or fewer than 2 taxa remain after filtering.

    Examples
    --------
    Plot a heatmap of the relative abundances of the top 10 most abundant families.

    >>> plot_heatmap(rank='family', top_n=10)
    """
    # Deferred imports
    import altair as alt
    import pandas as pd

    if rank is None:
        raise OneCodexException("Please specify a rank or 'auto' to choose automatically")

    if not (threshold or top_n):
        raise OneCodexException("Please specify at least one of: threshold, top_n")

    if len(self._results) < 2:
        raise PlottingException(
            "There are too few samples for heatmap plots after filtering. Please select 2 or "
            "more samples to plot."
        )

    # Resolve the 'auto' sentinels: an explicitly given filter disables the other one, and
    # when neither is given we fall back to the ten most abundant taxa.
    if top_n == "auto" and threshold == "auto":
        top_n = 10
        threshold = None
    elif top_n == "auto" and threshold != "auto":
        top_n = None
    elif top_n != "auto" and threshold == "auto":
        threshold = None

    df = self.to_df(
        rank=rank,
        normalize=normalize,
        top_n=top_n,
        threshold=threshold,
        table_format="long",
    )

    if len(df["tax_id"].unique()) < 2:
        raise PlottingException(
            "There are too few taxa for heatmap clustering after filtering. Please select a "
            "rank or threshold that includes at least 2 taxa."
        )

    if legend == "auto":
        legend = df.ocx.metric

    # Normalize `tooltip` to a list owned by this call. Copying a caller-supplied list is
    # essential: entries are appended/inserted below, and mutating the caller's object would
    # leak "Label"/haxis into it and accumulate duplicates when the list is reused.
    if not tooltip:
        tooltip = []
    elif isinstance(tooltip, list):
        tooltip = list(tooltip)
    else:
        tooltip = [tooltip]

    if haxis:
        tooltip.append(haxis)

    tooltip.insert(0, "Label")

    magic_metadata, magic_fields = self._metadata_fetch(tooltip, label=label)

    # add columns for prettier display
    df["Label"] = magic_metadata["Label"][df["classification_id"]].tolist()
    df["tax_name"] = ["{} ({})".format(self.taxonomy["name"][t], t) for t in df["tax_id"]]

    # and for metadata
    for f in tooltip:
        df[magic_fields[f]] = magic_metadata[magic_fields[f]][df["classification_id"]].tolist()

    # if we've already been normalized, we must cluster samples by euclidean distance. beta
    # diversity measures won't work with normalized distances.
    if self._guess_normalized():
        if metric != "euclidean":
            raise OneCodexException("Results are normalized. Please re-run with metric=euclidean")

        df_sample_cluster = self.to_df(
            rank=rank, normalize=normalize, top_n=top_n, threshold=threshold
        )
        df_taxa_cluster = df_sample_cluster
    else:
        # samples are clustered on un-normalized data, taxa on the requested normalization
        df_sample_cluster = self.to_df(
            rank=rank, normalize=False, top_n=top_n, threshold=threshold
        )
        df_taxa_cluster = self.to_df(
            rank=rank, normalize=normalize, top_n=top_n, threshold=threshold
        )

    # applying clustering to determine order of taxa, or use custom sorting function if given
    # NOTE: this must run before the haxis column is added to df_sample_cluster below, since
    # both names may refer to the same DataFrame.
    if sort_y is None:
        taxa_cluster = df_taxa_cluster.ocx._cluster_by_taxa(linkage=linkage)
        taxa_cluster = taxa_cluster["labels_in_order"]
    else:
        taxa_cluster = sort_helper(sort_y, df["tax_name"])

    if sort_x is None:
        if haxis is None:
            # cluster samples only once
            sample_cluster = df_sample_cluster.ocx._cluster_by_sample(
                rank=rank, metric=metric, linkage=linkage
            )
            labels_in_order = magic_metadata["Label"][sample_cluster["ids_in_order"]].tolist()
        else:
            haxis_values = df[magic_fields[haxis]]
            if not (
                pd.api.types.is_bool_dtype(haxis_values)
                or pd.api.types.is_categorical_dtype(haxis_values)
                or pd.api.types.is_object_dtype(haxis_values)
            ):
                raise OneCodexException("Metadata field on horizontal axis can not be numerical")

            # cluster each haxis group independently and concatenate the per-group orderings
            labels_in_order = []
            df_sample_cluster[haxis] = self.metadata[haxis]

            for group, group_df in df_sample_cluster.groupby(haxis):
                if group_df.shape[0] <= 3:
                    # we can't cluster; fall back to sorted label order for this group
                    labels_in_order.extend(
                        sorted(magic_metadata["Label"][group_df.index].tolist())
                    )
                    continue

                # the grouping column is not abundance data, so drop it before clustering
                sample_cluster = group_df.drop(columns=[haxis]).ocx._cluster_by_sample(
                    rank=rank, metric=metric, linkage=linkage
                )
                labels_in_order.extend(
                    magic_metadata["Label"][sample_cluster["ids_in_order"]].tolist()
                )
    else:
        labels_in_order = sort_helper(sort_x, magic_metadata["Label"].tolist())

    # should ultimately be Label, tax_name, readcount_w_children, then custom fields
    tooltip_for_altair = [magic_fields[f] for f in tooltip]
    tooltip_for_altair.insert(1, "tax_name")
    tooltip_for_altair.insert(2, "{}:Q".format(df.ocx.metric))

    alt_kwargs = dict(
        x=alt.X("Label:N", axis=alt.Axis(title=xlabel), sort=labels_in_order),
        y=alt.Y("tax_name:N", axis=alt.Axis(title=ylabel), sort=taxa_cluster),
        color=alt.Color("{}:Q".format(df.ocx.metric), legend=alt.Legend(title=legend)),
        tooltip=tooltip_for_altair,
        href="url:N",
        url=get_base_classification_url() + alt.datum.classification_id,
    )

    if haxis:
        alt_kwargs["column"] = alt.Column(
            haxis, header=alt.Header(titleOrient="bottom", labelOrient="bottom")
        )

    # "url" is a calculated field rather than an encoding channel, so it is popped out of the
    # encoding kwargs and fed to transform_calculate instead.
    chart = (
        alt.Chart(df)
        .transform_calculate(url=alt_kwargs.pop("url"))
        .mark_rect()
        .encode(**alt_kwargs)
    )

    col_count = len(labels_in_order)
    row_count = len(taxa_cluster)

    # default to 15 per cell along each axis unless the caller supplied explicit dimensions
    chart = chart.properties(
        **prepare_props(
            title=title,
            height=(height or 15 * row_count),
            width=(width or 15 * col_count),
        )
    )

    if haxis:
        chart = chart.resolve_scale(x="independent")

    if return_chart:
        return chart
    else:
        chart.interactive().display()