Example 1
def test_unsuccessful_classification(ocx, api_data):
    classifications = [
        ocx.Classifications.get('45a573fb7833449a'),
        ocx.Classifications.get('593601a797914cbf')
    ]
    classifications[0]._resource.success = False
    with pytest.warns(UserWarning):
        normed_classifications, _ = normalize_classifications(classifications)
        assert len(normed_classifications) == 1

    normed_classifications, _ = normalize_classifications(classifications,
                                                          skip_missing=False)
    assert len(normed_classifications) == 2
Example 2
def test_normalize_analyses_and_samples(ocx, api_data):
    analyses = [
        ocx.Analyses.get('45a573fb7833449a'),
        ocx.Analyses.get('593601a797914cbf')
    ]
    normed_classifications, metadata = normalize_classifications(analyses)
    for n in normed_classifications:
        assert isinstance(n, Classifications)

    normed_classifications, metadata = normalize_classifications(
        [a.sample for a in analyses])
    for n in normed_classifications:
        assert isinstance(n, Classifications)
Example 3
def test_normalize_classifications_labeling(ocx, api_data):
    cs = [
        ocx.Classifications.get('45a573fb7833449a'),
        ocx.Classifications.get('593601a797914cbf')
    ]
    _, metadata = normalize_classifications(cs, label='created_at')
    assert (metadata['_display_name'] == metadata['created_at']).all()

    with pytest.raises(OneCodexException):
        _, metadata = normalize_classifications(cs, label='nonexisting')

    _, metadata = normalize_classifications(cs,
                                            label=lambda x: x.sample.filename)
    assert (metadata['_display_name'] == [c.sample.filename for c in cs]).all()
Example 4
def plot_distance(analyses, metric='braycurtis',
                  title=None, label=None, xlabel=None, ylabel=None,
                  field='readcount_w_children', rank='species', **kwargs):
    """Plot beta diversity distance matrix.

    Additional **kwargs are passed to Seaborn's `sns.clustermap`.
    """
    # if taxonomy trees are inconsistent, unifrac will not work
    if metric in ['braycurtis', 'bray-curtis', 'bray curtis']:
        f = braycurtis
    elif metric in ['manhattan', 'cityblock']:
        f = cityblock
    elif metric == 'jaccard':
        f = jaccard
    elif metric == 'unifrac':
        f = unifrac
    else:
        raise OneCodexException("'metric' must be one of "
                                "braycurtis, manhattan, jaccard, or unifrac")

    normed_classifications, metadata = normalize_classifications(analyses, label=label)
    if len(normed_classifications) < 2:
        raise OneCodexException('`plot_distance` requires 2 or more valid classification results.')

    sns.set(style=kwargs.pop('style', 'darkgrid'))

    # there is no uniqueness constraint on metadata names
    # so plot by uuid, then replace the labels in the dataframe with their names
    uuids = {}
    sample_names = {}
    for idx, analysis in enumerate(normed_classifications):
        uuids[analysis.id] = analysis.id
        sample_names[analysis.id] = metadata.loc[idx, '_display_name']

    distances = f(normed_classifications, field=field, rank=rank)
    ids = distances.ids
    distance_matrix = distances.data
    dists = {}
    for idx1, id1 in enumerate(ids):
        dists[uuids[id1]] = {}
        for idx2, id2 in enumerate(ids):
            dists[uuids[id1]][uuids[id2]] = distance_matrix[idx1][idx2]
    dists = pd.DataFrame(dists).rename(index=sample_names, columns=sample_names)

    # Plot cluster map; ignore new SciPy cluster warnings
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', scipy.cluster.hierarchy.ClusterWarning)
        g = sns.clustermap(dists, **kwargs)

    plt.setp(g.ax_heatmap.yaxis.get_majorticklabels(), rotation=0)

    # Labels
    if xlabel is not None:
        plt.gca().set_xlabel(xlabel)
    if ylabel is not None:
        plt.gca().set_ylabel(ylabel)

    if title:
        g.fig.suptitle(title)
    plt.show()
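
A minimal usage sketch for `plot_distance` above, assuming an authenticated One Codex client `ocx` as in the test examples; the classification IDs are placeholders borrowed from those tests:

# Hypothetical usage sketch: `ocx` and the IDs below are assumptions.
classifications = [
    ocx.Classifications.get('45a573fb7833449a'),
    ocx.Classifications.get('593601a797914cbf')
]
# `metric` accepts 'braycurtis', 'manhattan', 'jaccard', or 'unifrac'.
plot_distance(classifications, metric='braycurtis',
              title='Bray-Curtis distance between samples',
              label=lambda c: c.sample.filename)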
Example 5
def beta_counts(classifications, field='readcount_w_children', rank='species'):
    normed_classifications, _ = normalize_classifications(classifications)
    df, tax_info = collate_classification_results(normed_classifications, field=field, rank=rank)

    tax_ids = df.columns.values.tolist()
    vectors = df.values.tolist()
    vectors = [[int(i) for i in row] for row in vectors]
    ids = df.index.values
    return (vectors, tax_ids, ids)
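
A short sketch of what `beta_counts` returns, assuming the same placeholder `classifications` as above: per-sample count vectors, the tax IDs labeling each column, and the classification IDs labeling each row.

# Hypothetical usage sketch; `classifications` as in the examples above.
vectors, tax_ids, ids = beta_counts(classifications, rank='genus')
assert len(vectors) == len(ids)          # one count vector per classification
assert len(vectors[0]) == len(tax_ids)   # one integer count per tax ID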
Example 6
def test_normalize_classifications(ocx, api_data):
    classifications = [
        ocx.Classifications.get('45a573fb7833449a'),
        ocx.Classifications.get('593601a797914cbf')
    ]
    normed_classifications, metadata = normalize_classifications(
        classifications)
    assert len(normed_classifications) == 2
    assert len(metadata) == 2
    assert isinstance(metadata, pd.DataFrame)
    assert (metadata['_display_name'] == metadata['name']).all()

    for n in normed_classifications:
        assert isinstance(n, Classifications)
Example 7
def plot_metadata(analyses,
                  metadata='created_at',
                  statistic=None,
                  tax_id=None,
                  title=None,
                  label=None,
                  xlabel=None,
                  ylabel=None,
                  field='readcount_w_children',
                  rank='species',
                  normalize=False,
                  **kwargs):
    """Plot by arbitary metadata.

    Note that `rank` only applies if you're calculating a `statistic`.
    Additional **kwargs are passed to Seaborn or Matplotlib plot functions as appropriate.
    """
    if not tax_id and not statistic:
        raise OneCodexException('Please pass a `tax_id` or a `statistic`.')
    elif tax_id and statistic:
        raise OneCodexException(
            'Please pass only a `tax_id` or a `statistic`.')

    sns.set(style=kwargs.pop('style', 'darkgrid'))

    normed_classifications, md = normalize_classifications(analyses,
                                                           label=label)

    if metadata not in md:
        raise OneCodexException(
            'Selected metadata field `{}` not available'.format(metadata))

    stat = []
    if statistic:
        for analysis in normed_classifications:
            try:
                stat.append(
                    alpha_diversity(analysis,
                                    statistic,
                                    field=field,
                                    rank=rank))
            except ValueError:
                raise NotImplementedError(
                    '`{}` statistic is not currently supported.'.format(statistic))
    elif tax_id:
        df, tax_info = collate_classification_results(normed_classifications,
                                                      field=field,
                                                      rank=None,
                                                      normalize=normalize)
        stat = df.loc[:, str(tax_id)].values
        if stat.shape[0] == 0:
            raise OneCodexException(
                'No values found for `tax_id` {} and `field` {}.'.format(
                    tax_id, field))

    md['_data'] = stat

    # Try to plot any column with `date` in it as a datetime, or if it's already a datetime
    # Plot numeric types as lmplots
    # Plot categorical, boolean, and objects as boxplot
    if 'date' in metadata.split('_') or pd.api.types.is_datetime64_any_dtype(
            md[metadata]):
        if not pd.api.types.is_datetime64_any_dtype(md[metadata]):
            md.loc[:, metadata] = md.loc[:, metadata].apply(pd.to_datetime,
                                                            utc=True)
        fig, ax = plt.subplots()
        ax.xaxis.set_major_locator(mdates.AutoDateLocator())
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%m/%d/%Y'))
        ax.plot_date(md[metadata].values, md['_data'].values, **kwargs)
        fig.autofmt_xdate()
    elif pd.api.types.is_numeric_dtype(md[metadata]):
        sns.lmplot(x=metadata, y='_data', truncate=True, data=md, **kwargs)
    elif pd.api.types.is_bool_dtype(md[metadata]) or \
            pd.api.types.is_categorical_dtype(md[metadata]) or \
            pd.api.types.is_object_dtype(md[metadata]):
        na = {field: 'N/A' for field in md.columns}
        md.fillna(value=na, inplace=True)
        sns.boxplot(x=metadata, y='_data', data=md, palette='pastel', **kwargs)
    else:
        raise OneCodexException(
            'Unplottable column type for metadata {}'.format(metadata))

    # Labels
    if xlabel is not None:
        plt.gca().set_xlabel(xlabel)
    if ylabel is None:
        ylabel = statistic if statistic else field
    plt.gca().set_ylabel(ylabel)

    _, labels = plt.xticks()
    plt.setp(labels, rotation=90)

    if title:
        plt.suptitle(title)

    plt.show()
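
A minimal usage sketch for `plot_metadata`, assuming the same placeholder `classifications`; the 'simpson' statistic and the 'geo_loc_name' metadata field are assumptions and may need to be swapped for values available in your account:

# Hypothetical usage sketch: exactly one of `statistic` or `tax_id` may be passed.
plot_metadata(classifications, statistic='simpson', rank='genus',
              title='Simpson diversity by creation date')
# ...or plot a single taxon's counts against a metadata column
# (1280 and 'geo_loc_name' are placeholder values).
plot_metadata(classifications, metadata='geo_loc_name', tax_id=1280)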
Example 8
def plot_heatmap(analyses,
                 top_n=20,
                 threshold=None,
                 title=None,
                 label=None,
                 xlabel=None,
                 ylabel=None,
                 field='readcount_w_children',
                 rank=None,
                 normalize=False,
                 **kwargs):
    """Plot a heatmap of classification results by sample.

    Additional **kwargs are passed to Seaborn's `sns.clustermap`.
    """
    if not (threshold or top_n):
        raise OneCodexException('Please set either "threshold" or "top_n"')

    normed_classifications, metadata = normalize_classifications(analyses,
                                                                 label=label)
    df, tax_info = collate_classification_results(normed_classifications,
                                                  field=field,
                                                  rank=rank,
                                                  normalize=normalize)

    if len(df) < 2:
        raise OneCodexException(
            '`plot_heatmap` requires 2 or more valid classification results.')

    df.columns = [
        '{} ({})'.format(tax_info[tax_id]['name'], tax_id)
        for tax_id in df.columns.values
    ]
    df.index = metadata.loc[:, '_display_name']
    df.index.name = ''

    sns.set(style=kwargs.pop('style', 'darkgrid'))

    if threshold and top_n:
        idx = df.sum(axis=0).sort_values(ascending=False).head(top_n).index
        df = df.loc[:, idx]  # can't use idx with multiple conditions
        g = sns.clustermap(df.loc[:, df.max() >= threshold], **kwargs)
    elif threshold:
        g = sns.clustermap(df.loc[:, df.max() >= threshold], **kwargs)
    elif top_n:
        idx = df.sum(axis=0).sort_values(ascending=False).head(top_n).index
        g = sns.clustermap(df.loc[:, idx], **kwargs)

    # Rotate the margin labels
    plt.setp(g.ax_heatmap.yaxis.get_majorticklabels(), rotation=0)

    # Labels
    if xlabel is not None:
        plt.gca().set_xlabel(xlabel)
    if ylabel is not None:
        plt.gca().set_ylabel(ylabel)

    if title:
        g.fig.suptitle(title)

    plt.show()
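
A minimal usage sketch for `plot_heatmap`, again assuming the placeholder `classifications`; any extra keyword arguments flow through to `sns.clustermap`:

# Hypothetical usage sketch: keep the 10 most abundant genera, normalized per sample.
plot_heatmap(classifications, top_n=10, rank='genus', normalize=True,
             title='Top genera by sample', cmap='viridis')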
Example 9
def plot_pca(analyses,
             threshold=None,
             title=None,
             hue=None,
             xlabel=None,
             ylabel=None,
             org_vectors=0,
             org_vectors_scale=None,
             field='readcount_w_children',
             rank=None,
             normalize=False,
             **kwargs):
    """Perform Principal Components Analysis to visualize the similarity of samples.

    Additional **kwargs are passed to Seaborn's `sns.lmplot`.
    """
    # hue: piece of metadata to color by
    # org_vectors: number of the most highly contributing organism vectors to plot (0 plots none)
    # org_vectors_scale: scale factor to modify the length of the organism vectors
    normed_classifications, metadata = normalize_classifications(analyses)
    df, tax_info = collate_classification_results(normed_classifications,
                                                  field=field,
                                                  rank=rank,
                                                  normalize=normalize)

    if len(df) < 2:
        raise OneCodexException(
            '`plot_pca` requires 2 or more valid classification results.')

    # normalize the magnitude of the data
    df = (df.T / df.sum(axis=1)).T

    pca = PCA()
    pca_vals = pca.fit(df.values).transform(df.values)
    pca_vals = pd.DataFrame(pca_vals, index=df.index)
    pca_vals.rename(columns=lambda x: "PCA{}".format(x + 1), inplace=True)
    metadata.index = df.index
    na = {field: 'N/A' for field in metadata.columns}
    plot_data = pd.concat([pca_vals, metadata.fillna(na)], axis=1)

    # Scatter plot of PCA
    sns.set(style=kwargs.pop('style', 'darkgrid'))
    g = sns.lmplot('PCA1',
                   'PCA2',
                   data=plot_data,
                   fit_reg=False,
                   hue=hue,
                   **kwargs)

    # Plot the organism eigenvectors that contribute the most
    if org_vectors > 0:
        # Plot the most highly contributing taxa
        magnitudes = np.sqrt(pca.components_[0]**2 + pca.components_[1]**2)
        magnitudes.sort()
        cutoff = magnitudes[-1 * org_vectors]
        # we can't use the "levels" method here b/c https://stackoverflow.com/questions/28772494
        tax_ids = [x for x in df.columns.values]
        tax_id_map = {x: tax_info[x]['name'] for x in df.columns}
        if org_vectors_scale is None:
            org_vectors_scale = 0.8 * np.max(pca_vals.abs().values)
        for taxid, var1, var2 in \
                zip(tax_ids, pca.components_[0, :], pca.components_[1, :]):
            if np.sqrt(var1**2 + var2**2) >= cutoff:
                g.axes.flat[0].annotate(
                    "{} ({})".format(tax_id_map[taxid], taxid),
                    xytext=(var1 * float(org_vectors_scale),
                            var2 * float(org_vectors_scale)),
                    xy=(0, 0),
                    size=8,
                    arrowprops={
                        'facecolor': 'black',
                        'width': 1,
                        'headwidth': 5
                    })
    # Labels
    if xlabel is None:
        xlabel = 'PCA1 ({}%)'.format(
            round(pca.explained_variance_ratio_[0] * 100, 2))
    if ylabel is None:
        ylabel = 'PCA2 ({}%)'.format(
            round(pca.explained_variance_ratio_[1] * 100, 2))
    plt.gca().set_xlabel(xlabel)
    plt.gca().set_ylabel(ylabel)

    # Title
    if title:
        plt.title(title)

    plt.show()
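
A minimal usage sketch for `plot_pca`, assuming the placeholder `classifications`; `hue` must name a column present in the sample metadata, and 'geo_loc_name' is only an assumption:

# Hypothetical usage sketch: color points by a metadata column and overlay the
# five most influential organism vectors.
plot_pca(classifications, hue='geo_loc_name', rank='genus',
         org_vectors=5, title='PCA of classification results')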