# NOTE: reconstructed imports for this excerpt. The One Codex helpers used below
# (normalize_classifications, collate_classification_results, alpha_diversity,
# braycurtis, cityblock, jaccard, unifrac, Classifications, OneCodexException)
# are assumed to be imported from the `onecodex` package; exact module paths
# may differ by version.
import warnings

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytest
import scipy.cluster.hierarchy
import seaborn as sns
from sklearn.decomposition import PCA


def test_unsuccessful_classification(ocx, api_data):
    classifications = [
        ocx.Classifications.get('45a573fb7833449a'),
        ocx.Classifications.get('593601a797914cbf'),
    ]
    classifications[0]._resource.success = False

    # a failed classification is skipped by default, with a warning
    with pytest.warns(UserWarning):
        normed_classifications, _ = normalize_classifications(classifications)
    assert len(normed_classifications) == 1

    # ...but is kept when skip_missing=False
    normed_classifications, _ = normalize_classifications(classifications, skip_missing=False)
    assert len(normed_classifications) == 2
def test_normalize_analyses_and_samples(ocx, api_data):
    analyses = [
        ocx.Analyses.get('45a573fb7833449a'),
        ocx.Analyses.get('593601a797914cbf'),
    ]
    normed_classifications, metadata = normalize_classifications(analyses)
    for n in normed_classifications:
        assert isinstance(n, Classifications)

    normed_classifications, metadata = normalize_classifications(
        [a.sample for a in analyses])
    for n in normed_classifications:
        assert isinstance(n, Classifications)
def test_normalize_classifications_labeling(ocx, api_data):
    cs = [
        ocx.Classifications.get('45a573fb7833449a'),
        ocx.Classifications.get('593601a797914cbf'),
    ]
    _, metadata = normalize_classifications(cs, label='created_at')
    assert (metadata['_display_name'] == metadata['created_at']).all()

    with pytest.raises(OneCodexException):
        _, metadata = normalize_classifications(cs, label='nonexisting')

    _, metadata = normalize_classifications(cs, label=lambda x: x.sample.filename)
    assert (metadata['_display_name'] == [c.sample.filename for c in cs]).all()
def plot_distance(analyses, metric='braycurtis', title=None, label=None,
                  xlabel=None, ylabel=None, field='readcount_w_children',
                  rank='species', **kwargs):
    """Plot a beta diversity distance matrix as a clustered heatmap.

    Additional **kwargs are passed to Seaborn's `sns.clustermap`.
    """
    # if taxonomy trees are inconsistent, unifrac will not work
    if metric in ['braycurtis', 'bray-curtis', 'bray curtis']:
        f = braycurtis
    elif metric in ['manhattan', 'cityblock']:
        f = cityblock
    elif metric == 'jaccard':
        f = jaccard
    elif metric == 'unifrac':
        f = unifrac
    else:
        raise OneCodexException("'metric' must be one of "
                                "braycurtis, manhattan, jaccard, or unifrac")

    normed_classifications, metadata = normalize_classifications(analyses, label=label)
    if len(normed_classifications) < 2:
        raise OneCodexException('`plot_distance` requires 2 or more valid classification results.')

    sns.set(style=kwargs.pop('style', 'darkgrid'))

    # there is no uniqueness constraint on metadata names,
    # so plot by uuid, then replace the labels in the dataframe with their names
    uuids = {}
    sample_names = {}
    for idx, analysis in enumerate(normed_classifications):
        uuids[analysis.id] = analysis.id
        sample_names[analysis.id] = metadata.loc[idx, '_display_name']

    distances = f(normed_classifications, field=field, rank=rank)
    ids = distances.ids
    distance_matrix = distances.data
    dists = {}
    for idx1, id1 in enumerate(ids):
        dists[uuids[id1]] = {}
        for idx2, id2 in enumerate(ids):
            dists[uuids[id1]][uuids[id2]] = distance_matrix[idx1][idx2]
    dists = pd.DataFrame(dists).rename(index=sample_names, columns=sample_names)

    # Plot the cluster map; ignore new SciPy cluster warnings
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', scipy.cluster.hierarchy.ClusterWarning)
        g = sns.clustermap(dists, **kwargs)

    plt.setp(g.ax_heatmap.yaxis.get_majorticklabels(), rotation=0)

    # Labels
    if xlabel is not None:
        plt.gca().set_xlabel(xlabel)
    if ylabel is not None:
        plt.gca().set_ylabel(ylabel)
    if title:
        g.fig.suptitle(title)
    plt.show()
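# Usage sketch (hypothetical, not part of the library): how `plot_distance`
# might be called against an authenticated `ocx` client, reusing the
# classification UUIDs from the tests above. Any `sns.clustermap` kwarg
# passes straight through, e.g. `figsize`.
def example_plot_distance(ocx):
    classifications = [
        ocx.Classifications.get('45a573fb7833449a'),
        ocx.Classifications.get('593601a797914cbf'),
    ]
    plot_distance(classifications, metric='braycurtis',
                  title='Bray-Curtis distance', figsize=(8, 8))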
def beta_counts(classifications, field='readcount_w_children', rank='species'):
    """Return (vectors, tax_ids, ids) count-matrix components for beta diversity calculations."""
    normed_classifications, _ = normalize_classifications(classifications)
    df, tax_info = collate_classification_results(normed_classifications, field=field, rank=rank)

    tax_ids = df.columns.values.tolist()
    vectors = [[int(i) for i in row] for row in df.values.tolist()]
    ids = df.index.values
    return (vectors, tax_ids, ids)
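# Usage sketch (hypothetical): `beta_counts` returns plain Python lists and
# string tax IDs, so its output can feed scikit-bio's distance tooling or any
# other distance implementation directly. Assumes an authenticated `ocx`
# client as in the tests above.
def example_beta_counts(ocx):
    classifications = [
        ocx.Classifications.get('45a573fb7833449a'),
        ocx.Classifications.get('593601a797914cbf'),
    ]
    vectors, tax_ids, ids = beta_counts(classifications, rank='genus')
    assert len(vectors) == len(ids)          # one count vector per sample
    assert len(vectors[0]) == len(tax_ids)   # one column per taxon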
def test_normalize_classifications(ocx, api_data):
    classifications = [
        ocx.Classifications.get('45a573fb7833449a'),
        ocx.Classifications.get('593601a797914cbf'),
    ]
    normed_classifications, metadata = normalize_classifications(classifications)

    assert len(normed_classifications) == 2
    assert len(metadata) == 2
    assert isinstance(metadata, pd.DataFrame)
    assert (metadata['_display_name'] == metadata['name']).all()
    for n in normed_classifications:
        assert isinstance(n, Classifications)
def plot_metadata(analyses, metadata='created_at', statistic=None, tax_id=None,
                  title=None, label=None, xlabel=None, ylabel=None,
                  field='readcount_w_children', rank='species', normalize=False,
                  **kwargs):
    """Plot an arbitrary metadata field against a statistic or a taxon's abundance.

    Note that `rank` only applies if you're calculating a `statistic`.
    Additional **kwargs are passed to the Seaborn or Matplotlib plot
    functions as appropriate.
    """
    if not tax_id and not statistic:
        raise OneCodexException('Please pass a `tax_id` or a `statistic`.')
    elif tax_id and statistic:
        raise OneCodexException('Please pass only a `tax_id` or a `statistic`, not both.')

    sns.set(style=kwargs.pop('style', 'darkgrid'))
    normed_classifications, md = normalize_classifications(analyses, label=label)

    if metadata not in md:
        raise OneCodexException('Selected metadata field `{}` not available'.format(metadata))

    stat = []
    if statistic:
        for analysis in normed_classifications:
            try:
                stat.append(alpha_diversity(analysis, statistic, field=field, rank=rank))
            except ValueError:
                raise NotImplementedError(
                    'Statistic `{}` is not currently supported.'.format(statistic))
    elif tax_id:
        df, tax_info = collate_classification_results(normed_classifications, field=field,
                                                      rank=None, normalize=normalize)
        stat = df.loc[:, str(tax_id)].values
        if stat.shape[0] == 0:
            raise OneCodexException(
                'No values found for `tax_id` {} and `field` {}.'.format(tax_id, field))

    md['_data'] = stat

    # Plot any column with `date` in its name (or an existing datetime dtype) as a time
    # series, numeric types as lmplots, and categorical/boolean/object types as boxplots
    if 'date' in metadata.split('_') or pd.api.types.is_datetime64_any_dtype(md[metadata]):
        if not pd.api.types.is_datetime64_any_dtype(md[metadata]):
            md.loc[:, metadata] = md.loc[:, metadata].apply(pd.to_datetime, utc=True)

        fig, ax = plt.subplots()
        ax.xaxis.set_major_locator(mdates.AutoDateLocator())
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%m/%d/%Y'))
        ax.plot_date(md[metadata].values, md['_data'].values, **kwargs)
        fig.autofmt_xdate()
    elif pd.api.types.is_numeric_dtype(md[metadata]):
        sns.lmplot(x=metadata, y='_data', truncate=True, data=md, **kwargs)
    elif pd.api.types.is_bool_dtype(md[metadata]) or \
            pd.api.types.is_categorical_dtype(md[metadata]) or \
            pd.api.types.is_object_dtype(md[metadata]):
        na = {col: 'N/A' for col in md.columns}
        md.fillna(value=na, inplace=True)
        sns.boxplot(x=metadata, y='_data', data=md, palette='pastel', **kwargs)
    else:
        raise OneCodexException('Unplottable column type for metadata `{}`'.format(metadata))

    # Labels
    if xlabel is not None:
        plt.gca().set_xlabel(xlabel)
    if ylabel is None:
        ylabel = statistic if statistic else field
    plt.gca().set_ylabel(ylabel)

    _, labels = plt.xticks()
    plt.setp(labels, rotation=90)

    if title:
        plt.suptitle(title)
    plt.show()
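# Usage sketch (hypothetical): the two mutually exclusive modes of
# `plot_metadata`. The statistic name and tax_id below are illustrative and
# depend on what `alpha_diversity` and the reference database support;
# assumes an authenticated `ocx` client.
def example_plot_metadata(ocx):
    analyses = [
        ocx.Classifications.get('45a573fb7833449a'),
        ocx.Classifications.get('593601a797914cbf'),
    ]
    # an alpha diversity statistic over a datetime field -> date plot
    plot_metadata(analyses, metadata='created_at', statistic='simpson')
    # a single taxon's normalized abundance over a categorical field -> boxplot
    plot_metadata(analyses, metadata='name', tax_id=1280, normalize=True)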
def plot_heatmap(analyses, top_n=20, threshold=None, title=None, label=None,
                 xlabel=None, ylabel=None, field='readcount_w_children',
                 rank=None, normalize=False, **kwargs):
    """Plot a heatmap of classification results by sample.

    Additional **kwargs are passed to Seaborn's `sns.clustermap`.
    """
    if not (threshold or top_n):
        raise OneCodexException('Please set either `threshold` or `top_n`.')

    normed_classifications, metadata = normalize_classifications(analyses, label=label)
    df, tax_info = collate_classification_results(normed_classifications, field=field,
                                                  rank=rank, normalize=normalize)
    if len(df) < 2:
        raise OneCodexException('`plot_heatmap` requires 2 or more valid classification results.')

    df.columns = ['{} ({})'.format(tax_info[tax_id]['name'], tax_id)
                  for tax_id in df.columns.values]
    df.index = metadata.loc[:, '_display_name']
    df.index.name = ''

    sns.set(style=kwargs.pop('style', 'darkgrid'))

    if threshold and top_n:
        idx = df.sum(axis=0).sort_values(ascending=False).head(top_n).index
        df = df.loc[:, idx]
        # can't use idx with multiple conditions
        g = sns.clustermap(df.loc[:, df.max() >= threshold], **kwargs)
    elif threshold:
        g = sns.clustermap(df.loc[:, df.max() >= threshold], **kwargs)
    elif top_n:
        idx = df.sum(axis=0).sort_values(ascending=False).head(top_n).index
        g = sns.clustermap(df.loc[:, idx], **kwargs)

    # Rotate the margin labels
    plt.setp(g.ax_heatmap.yaxis.get_majorticklabels(), rotation=0)

    # Labels
    if xlabel is not None:
        plt.gca().set_xlabel(xlabel)
    if ylabel is not None:
        plt.gca().set_ylabel(ylabel)
    if title:
        g.fig.suptitle(title)
    plt.show()
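# Usage sketch (hypothetical): combining `top_n` and `threshold` first keeps
# the 10 most abundant genera, then drops any that never reach 1% of a sample
# (the threshold reads as a fraction because `normalize=True`). Assumes an
# authenticated `ocx` client.
def example_plot_heatmap(ocx):
    analyses = [
        ocx.Classifications.get('45a573fb7833449a'),
        ocx.Classifications.get('593601a797914cbf'),
    ]
    plot_heatmap(analyses, top_n=10, threshold=0.01, rank='genus', normalize=True)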
def plot_pca(analyses, threshold=None, title=None, hue=None, xlabel=None,
             ylabel=None, org_vectors=0, org_vectors_scale=None,
             field='readcount_w_children', rank=None, normalize=False, **kwargs):
    """Perform Principal Components Analysis to visualize the similarity of samples.

    `hue` is a metadata field to color points by. `org_vectors` is the number
    of most highly contributing organism eigenvectors to overlay (0 disables
    them), and `org_vectors_scale` is a scale factor applied to their length.
    Additional **kwargs are passed to Seaborn's `sns.lmplot`.
    """
    normed_classifications, metadata = normalize_classifications(analyses)
    df, tax_info = collate_classification_results(normed_classifications, field=field,
                                                  rank=rank, normalize=normalize)
    if len(df) < 2:
        raise OneCodexException('`plot_pca` requires 2 or more valid classification results.')

    # normalize the magnitude of the data
    df = (df.T / df.sum(axis=1)).T

    pca = PCA()
    pca_vals = pca.fit(df.values).transform(df.values)
    pca_vals = pd.DataFrame(pca_vals, index=df.index)
    pca_vals.rename(columns=lambda x: 'PCA{}'.format(x + 1), inplace=True)

    metadata.index = df.index
    na = {col: 'N/A' for col in metadata.columns}
    plot_data = pd.concat([pca_vals, metadata.fillna(na)], axis=1)

    # Scatter plot of the first two principal components
    sns.set(style=kwargs.pop('style', 'darkgrid'))
    g = sns.lmplot(x='PCA1', y='PCA2', data=plot_data, fit_reg=False, hue=hue, **kwargs)

    # Overlay the organism eigenvectors that contribute the most
    if org_vectors > 0:
        magnitudes = np.sqrt(pca.components_[0]**2 + pca.components_[1]**2)
        magnitudes.sort()
        cutoff = magnitudes[-1 * org_vectors]

        # we can't use the "levels" method here b/c https://stackoverflow.com/questions/28772494
        tax_ids = [x for x in df.columns.values]
        tax_id_map = {x: tax_info[x]['name'] for x in df.columns}

        if org_vectors_scale is None:
            org_vectors_scale = 0.8 * np.max(pca_vals.abs().values)

        for taxid, var1, var2 in zip(tax_ids, pca.components_[0, :], pca.components_[1, :]):
            if np.sqrt(var1**2 + var2**2) >= cutoff:
                g.axes.flat[0].annotate(
                    '{} ({})'.format(tax_id_map[taxid], taxid),
                    xytext=(var1 * float(org_vectors_scale),
                            var2 * float(org_vectors_scale)),
                    xy=(0, 0), size=8,
                    arrowprops={'facecolor': 'black', 'width': 1, 'headwidth': 5})

    # Labels
    if xlabel is None:
        xlabel = 'PCA1 ({}%)'.format(round(pca.explained_variance_ratio_[0] * 100, 2))
    if ylabel is None:
        ylabel = 'PCA2 ({}%)'.format(round(pca.explained_variance_ratio_[1] * 100, 2))
    plt.gca().set_xlabel(xlabel)
    plt.gca().set_ylabel(ylabel)

    # Title
    if title:
        plt.title(title)
    plt.show()
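# Usage sketch (hypothetical): a PCA scatter colored by sample name with the
# three most influential organism eigenvectors overlaid. The `hue` column is
# illustrative; assumes an authenticated `ocx` client.
def example_plot_pca(ocx):
    analyses = [
        ocx.Classifications.get('45a573fb7833449a'),
        ocx.Classifications.get('593601a797914cbf'),
    ]
    plot_pca(analyses, hue='name', org_vectors=3, normalize=True)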