# Assumed imports for the snippets below, inferred from how names are used in
# the code. `co` and `rsp` appear to be the project's clustered-heatmap and
# region-set enrichment packages, and MidpointNormalize, get_text_width_height
# and paper_context are project-level helpers; their import statements are not
# reproduced here.
import pickle
from pathlib import Path
from typing import Union

import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import colors
from matplotlib.figure import Figure


def sampled_heatmap(meth_stats, sampled_cluster_ids: co.Linkage,
                    cluster_id_name, sampled_index, sampled_linkage, name,
                    linkage_obj, pop_order, out_pattern):

    zscores_df = meth_stats.stats['beta-value_zscores'].loc[sampled_index,
                                                            pop_order]
    betas_df = meth_stats.stats['beta-value'].loc[sampled_index, pop_order]
    deltas_df = meth_stats.stats['beta-value_delta_hsc'].loc[sampled_index,
                                                             pop_order]

    cdg_cols_clustered = co.ClusteredDataGrid(
        main_df=zscores_df, row_linkage=sampled_linkage.matrix)
    cdg_cols_clustered = cdg_cols_clustered.cluster_cols(
        method='complete', metric='cityblock')

    cdg_cols_ordered = co.ClusteredDataGrid(
        main_df=zscores_df, row_linkage=sampled_linkage.matrix)

    for cols_clustered, cdg in [(True, cdg_cols_clustered),
                                (False, cdg_cols_ordered)]:
        print(cols_clustered)

        if name == 'equal_sampling':
            cluster_sizes = [
                co.ClusterSizePlot(
                    cluster_ids=linkage_obj.cluster_ids.df[cluster_id_name],
                    bar_height=0.7,
                    xlabel='Cluster size')
            ]
        else:
            cluster_sizes = []
        # TODO: printing gm before create_or_update_figure is called:
        # -> AttributeError: 'GridManager' object has no attribute 'fig'
        gm = cdg.plot_grid(
            grid=[[
                co.Heatmap(df=zscores_df, cmap='RdBu_r', rasterized=True),
                co.Heatmap(df=deltas_df, cmap='RdBu_r', rasterized=True),
                co.Heatmap(df=betas_df, cmap='YlOrBr', rasterized=True),
            ] + cluster_sizes],
            figsize=(30 / 2.54, 15 / 2.54),
            height_ratios=[(1, 'rel')],
            row_annotation=sampled_cluster_ids.df,
            row_anno_heatmap_args={
                'colors': [(1, 1, 1), (.8, .8, .8)],
                'show_values': True
            },
            row_anno_col_width=2.5 / 2.54,
            col_dendrogram=cols_clustered,
        )
        gm.create_or_update_figure()
        out_pdf = out_pattern.format(name=name,
                                     clustercols=str(cols_clustered),
                                     poporder=','.join(pop_order))
        gm.fig.savefig(out_pdf)
        gm.fig.savefig(out_pdf.replace('.pdf', '.png'))
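

# Hedged illustration (not part of the original code): one way the
# `sampled_index` for the 'equal_sampling' case could be built is to draw the
# same number of rows from every cluster with a pandas groupby. The toy data
# in the docstring example is made up for this sketch only.
def equal_sample_index(cluster_ids: pd.Series, n_per_cluster: int,
                       random_state: int = 1) -> pd.Index:
    """Return an index with up to `n_per_cluster` members drawn per cluster.

    Example (toy data):
        cluster_ids = pd.Series(np.repeat([1, 2, 3], [10, 20, 5]))
        sampled_index = equal_sample_index(cluster_ids, n_per_cluster=5)
    """
    sampled = (
        cluster_ids
        .groupby(cluster_ids)
        .apply(lambda s: s.sample(min(len(s), n_per_cluster),
                                  random_state=random_state))
    )
    # groupby().apply() prepends the cluster id as an extra index level; drop it
    return sampled.index.get_level_values(-1)

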
def masked_plot(
        overlap_counts: Union[str, pd.DataFrame], output_dir, param_name,
        quantile, min_abs_log_odds,
        linewidth, col_width_cm, row_height_cm, cbar_size_cm,
        row_dendrogram=False, row_labels_show=True):

    print(param_name)

    if isinstance(overlap_counts, str):
        with open(overlap_counts, 'rb') as fin:
            overlap_counts = pickle.load(fin)
    assert isinstance(overlap_counts, rsp.ClusterOverlapStats)

    plot_data: pd.DataFrame = overlap_counts.log_odds_ratio
    # Drop NA because we want to cluster the features
    plot_data = plot_data.dropna(how='any', axis=1)

    # Set all log odds ratios below the quantile to 0
    # Possible improvement: separately for depletion and enrichment
    q = np.quantile(plot_data.abs(), quantile)
    plot_data = plot_data.where(plot_data.abs().gt(q), 0)
    plot_data = plot_data.loc[:, plot_data.abs().gt(min_abs_log_odds).any()]

    vmin = np.quantile(plot_data, 0.02)
    vmax = np.quantile(plot_data, 0.98)


    plot_data = plot_data.T
    width = plot_data.shape[1] * col_width_cm / 2.54
    height = plot_data.shape[0] * row_height_cm / 2.54
    print('width (cm)', width * 2.54, 'height (cm)', height * 2.54)
    shrink = cbar_size_cm / height

    with mpl.rc_context(paper_context):
        print('Clustered plot')
        cdg = (co.ClusteredDataGrid(main_df=plot_data)
               .cluster_rows(method='average', metric='cityblock'))
        gm = cdg.plot_grid(
            grid=[[
                co.Heatmap(df=plot_data,
                           cmap='RdBu_r',
                           vmin=vmin,
                           vmax=vmax,
                           norm=MidpointNormalize(vmin=vmin, vmax=vmax,
                                                  midpoint=0),
                           row_labels_show=row_labels_show,
                           rasterized=False,
                           edgecolor='white',
                           linewidth=linewidth,
                           cbar_args=dict(shrink=shrink, aspect=20)),
            ]],
            figsize=(width, height),
            height_ratios=[(1, 'rel')],
            row_dendrogram=row_dendrogram,
        )
        gm.create_or_update_figure()
        gm.fig.savefig(output_dir / f'all-significant_clustered_log-odds_masked-{param_name}.png')
        gm.fig.savefig(output_dir / f'all-significant_clustered_log-odds_masked-{param_name}.pdf')
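

# Hedged, self-contained illustration (not part of the original code) of the
# masking step in masked_plot above: log-odds values whose absolute value does
# not exceed the chosen quantile are zeroed, and feature columns without any
# entry above min_abs_log_odds are dropped. The toy data is made up for this
# sketch.
def _demo_log_odds_masking(quantile=0.5, min_abs_log_odds=1.0) -> pd.DataFrame:
    rng = np.random.default_rng(0)
    log_odds = pd.DataFrame(rng.normal(scale=2, size=(5, 8)),
                            columns=[f'feat_{i}' for i in range(8)])
    q = np.quantile(log_odds.abs(), quantile)          # threshold on |log-odds|
    masked = log_odds.where(log_odds.abs().gt(q), 0)   # keep strong values, zero the rest
    # drop features where no cluster exceeds min_abs_log_odds
    return masked.loc[:, masked.abs().gt(min_abs_log_odds).any()]

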
def barcode_heatmap(
    cluster_overlap_stats: rsp.ClusterOverlapStats,
    plot_stat="p-value",
    vmin=None,
    vmax=None,
    vmin_quantile=0.02,
    vlim=None,
    cluster_features=True,
    spaced_heatmap_kws=None,
    col_width_cm=2,
    row_height_cm=0.2,
    row_labels_show=True,
    divergent_cmap="RdBu_r",
    sequential_cmap="YlOrBr",
    linewidth=1,
    rasterized=False,
    cbar_args=None,
    robust=True,
    clusters_as_rows=False,
    force_row_height=False,
    metric="cityblock",
    method="average",
    **kwargs,
) -> Figure:
    """Barcode heatmap

    automatically normalizes to vcenter=0 if divergent stat, dont pass norm

    Args:
        filter_on_per_feature_pvalue: if False, filter based on aggregated
            feature-info from per_cluster_per_feature pvalues (not implemented yet)
        plot_stat: 'p-value' or 'log-odds' (could also add both at the same time...)
        cbar_args: defaults to dict(shrink=0.4, aspect=20, extend='both')
        kwargs: passed to co.Heatmap
        vmin, vmax, robust: if vmin or vmax are not set, the 0.02 and 0.98
            quantiles are used (robust=True), or the min and max of all values
            are used otherwise
        force_row_height: if the specified row height is smaller than the estimated label height, it is set to the estimated label height, unless this flag is set to True
    """
    # print("new barcode heatmap")

    colorbar_height_in = 2 / 2.54
    colorbar_width_in = 0.7 / 2.54

    if cbar_args is None:
        # TODO extend='both' fails
        # cbar_args = dict(shrink=0.4, aspect=20, extend='both')
        cbar_args = dict(shrink=0.4, aspect=20)

    # Get plot stat
    # --------------------------------------------------------------------------
    if plot_stat == "p-value":
        # To visualize the p-values, we give log10(p-values) associated with
        # positive log-odds ratios a positive sign, while p-values associated
        # with depletion retain the negative sign
        log10_pvalues = np.log10(
            cluster_overlap_stats.cluster_pvalues + 1e-100
        )  # add small float to avoid inf values
        plot_stat = log10_pvalues * -np.sign(cluster_overlap_stats.log_odds_ratio)

    elif plot_stat == "log-odds":
        plot_stat = cluster_overlap_stats.log_odds_ratio

    # Discard features with NA if we are clustering the features
    if cluster_features:
        plot_stat = plot_stat.dropna(how="any", axis=1)

    # Create heatmap
    # --------------------------------------------------------------------------
    plot_stat_is_divergent = plot_stat.lt(0).any(axis=None)
    cmap = divergent_cmap if plot_stat_is_divergent else sequential_cmap

    # Transpose plot stat for plotting and final processing
    if not clusters_as_rows:
        plot_stat = plot_stat.T

    # Get plot dimensions
    curr_font_size = mpl.rcParams["font.size"]
    row_label_width, row_label_height = get_text_width_height(
        plot_stat.index.astype(str), curr_font_size
    )
    col_label_width, col_label_height = get_text_width_height(
        plot_stat.columns.astype(str), curr_font_size, target_axis="x"
    )
    if force_row_height:
        height = plot_stat.shape[0] * row_height_cm + col_label_height
    else:
        height = (
            plot_stat.shape[0] * max(row_height_cm, row_label_height) + col_label_height
        )
    width = (
        row_label_width + (plot_stat.shape[1] * col_width_cm / 2.54) + colorbar_width_in
    )
    colorbar_height_in = min(colorbar_height_in, height)

    if vmin is None:
        if robust:
            vmin = np.quantile(plot_stat, vmin_quantile)
        else:
            vmin = plot_stat.min().min()
    if vmax is None:
        if robust:
            vmax = np.quantile(plot_stat, 1 - vmin_quantile)
        else:
            vmax = plot_stat.max().max()
    if vlim is not None:
        vmin = min(vmin, vlim[0])
        vmax = max(vmax, vlim[1])

    if plot_stat_is_divergent:
        norm = MidpointNormalize(vmin=vmin, vmax=vmax, vcenter=0.0)
    else:
        # note: this block may be wrong and untested
        norm = None
        # does not seem to be necessary?
        # cbar_args.update({'clim': (vmin, vmax)})

    # print("Clustered plot")
    cdg = co.ClusteredDataGrid(main_df=plot_stat)
    if cluster_features:
        if clusters_as_rows:
            cdg.cluster_cols(method=method, metric=metric)
        else:
            cdg.cluster_rows(method=method, metric=metric)

    # doesn't work well with these formulas
    # shrink = colorbar_height_in / height
    # aspect = colorbar_height_in / colorbar_width_in
    # other_cbar_args = dict(shrink=shrink, aspect=aspect)

    if spaced_heatmap_kws is None:
        heatmap = co.Heatmap(
            df=plot_stat,
            cmap=cmap,
            row_labels_show=row_labels_show,
            norm=norm,
            rasterized=rasterized,
            linewidth=linewidth,
            cbar_args=cbar_args,
            # cbar_args=other_cbar_args,
            edgecolor="white",
            **kwargs,
        )
    else:
        heatmap = co.SpacedHeatmap(
            df=plot_stat,
            pcolormesh_args=dict(
                cmap=cmap,
                norm=norm,
                rasterized=rasterized,
                linewidth=linewidth,
                edgecolor="white",
            ),
            show_row_labels=row_labels_show,
            show_col_labels=True,
            add_colorbar=True,
            cbar_args=cbar_args,
            **spaced_heatmap_kws,
            # cbar_args=other_cbar_args,
        )

    gm = cdg.plot_grid(
        grid=[[heatmap]],
        figsize=(width, height),
        height_ratios=[(1, "rel")],
        row_dendrogram=False,
    )
    gm.create_or_update_figure()
    return gm.fig
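

# Hedged, self-contained illustration (not part of the original code) of the
# signed -log10(p) transform used above for plot_stat == "p-value": p-values
# of enriched features (positive log-odds) get a positive sign, p-values of
# depleted features a negative one. The toy frames are made up for this sketch.
def _demo_signed_log10_pvalues() -> pd.DataFrame:
    pvalues = pd.DataFrame({'featA': [1e-5, 0.3], 'featB': [0.04, 1e-8]},
                           index=['cluster1', 'cluster2'])
    log_odds = pd.DataFrame({'featA': [2.0, -0.1], 'featB': [-1.5, 0.8]},
                            index=['cluster1', 'cluster2'])
    log10_pvalues = np.log10(pvalues + 1e-100)  # small offset avoids -inf at p == 0
    # e.g. featA/cluster1 (enriched) -> +5.0, featB/cluster1 (depleted) -> -1.4
    return log10_pvalues * -np.sign(log_odds)

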
def plot_all_enriched_features(col_width_cm, output_dir, overlap_counts,
                               row_height_cm, sign_threshold, statistics,
                               test_statistics_fp):

    print('Reading input data')

    if isinstance(overlap_counts, str):
        with open(overlap_counts, 'rb') as fin:
            overlap_counts = pickle.load(fin)
    assert isinstance(overlap_counts, rsp.ClusterOverlapStats)

    # test_statistics = pd.read_pickle(test_statistics_fp)

    for statistic in statistics:
        print(statistic)
        print('Prepare plot')
        if statistic == 'normalized_ratio':
            plot_data = overlap_counts.normalized_ratio
        elif statistic == 'log_odds_ratio':
            plot_data = overlap_counts.log_odds_ratio
        else:
            raise ValueError(f'Unknown statistic: {statistic}')

        # is_significant = test_statistics.qvalues < sign_threshold
        # plot_data = plot_data.loc[:, is_significant].sort_index(axis=1)
        plot_data = plot_data.sort_index(axis=1)
        plot_data = plot_data.dropna(how='any', axis=1)
        if plot_data.lt(0).any(axis=None):
            cmap = 'RdYlGn_r'
        else:
            cmap = 'YlGnBu_r'
        width = plot_data.shape[1] * col_width_cm / 2.54
        height = plot_data.shape[0] * row_height_cm

        print('Clustered plot')
        cdg = (co.ClusteredDataGrid(main_df=plot_data).cluster_cols(
            method='average', metric='cityblock'))
        gm = cdg.plot_grid(
            grid=[[
                co.Heatmap(
                    df=plot_data,
                    cmap=cmap,
                    row_labels_show=True,
                    rasterized=True,
                ),
            ]],
            figsize=(width, height),
            height_ratios=[(1, 'rel')],
            col_dendrogram=True,
        )
        gm.create_or_update_figure()
        gm.fig.savefig(output_dir /
                       f'all-significant_clustered_{statistic}.png')
        gm.fig.savefig(output_dir /
                       f'all-significant_clustered_{statistic}.pdf')

        print('Alphabetic plot')
        cdg = (co.ClusteredDataGrid(main_df=plot_data))
        gm = cdg.plot_grid(
            grid=[[
                co.Heatmap(
                    df=plot_data,
                    cmap=cmap,
                    row_labels_show=True,
                    rasterized=True,
                ),
            ]],
            figsize=(width, height),
            height_ratios=[(1, 'rel')],
        )
        gm.create_or_update_figure()
        gm.fig.savefig(output_dir /
                       f'all-significant_alphabetic_{statistic}.png')
        gm.fig.savefig(output_dir /
                       f'all-significant_alphabetic_{statistic}.pdf')
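

# Hedged usage sketch (not part of the original code): the paths below are
# hypothetical, and sign_threshold / test_statistics_fp are only consumed by
# the commented-out q-value filtering in plot_all_enriched_features.
def _demo_plot_all_enriched_features() -> None:
    plot_all_enriched_features(
        col_width_cm=0.5,
        row_height_cm=0.25,
        output_dir=Path('results/enrichment_plots'),  # hypothetical output dir
        overlap_counts='results/overlap_counts.p',    # hypothetical pickle path
        sign_threshold=0.05,
        statistics=['normalized_ratio', 'log_odds_ratio'],
        test_statistics_fp=None,
    )

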
def norm_plot(overlap_counts: Union[str, pd.DataFrame],
              output_dir,
              param_name,
              quantile,
              min_abs_log_odds,
              norm_plateau_height,
              linewidth,
              col_width_cm,
              row_height_cm,
              cbar_size_cm,
              row_dendrogram=False,
              row_labels_show=True):

    print(param_name)

    if isinstance(overlap_counts, str):
        with open(overlap_counts, 'rb') as fin:
            overlap_counts = pickle.load(fin)
    assert isinstance(overlap_counts, rsp.ClusterOverlapStats)

    class RangeNorm(colors.Normalize):
        """Map [vmin, vmax] linearly onto [0, 1], but compress values between
        neg_q and pos_q towards 0.5 with a quartic falloff, de-emphasizing
        small log-odds ratios. `height` is only used by the commented-out
        piecewise variant in __call__; the active code expects array input.
        """

        def __init__(self, neg_q, pos_q, vmin, vmax, height):
            self.vmin = vmin
            self.vmax = vmax
            self.height = height
            self.neg_q = neg_q
            self.pos_q = pos_q
            colors.Normalize.__init__(self, vmin, vmax, clip=True)

        def __call__(self, value, clip=None):
            # xp = [self.vmin, self.neg_q, 0, self.pos_q, self.vmax]
            # yp = [0, 0.5 - self.height/2, 0.5, 0.5 + self.height/2, 1]
            # return np.ma.masked_array(np.interp(value, xp, yp))
            xp = [self.vmin, self.vmax]
            yp = [0, 1]
            res = np.interp(value, xp, yp)
            midpoint = self.neg_q + (self.pos_q - self.neg_q) / 2
            # # res[(value > self.neg_q) & (value < self.pos_q)] = 0.5
            view1 = value[(value > self.neg_q) & (value < midpoint)]
            if view1.size != 0:
                cubic = ((view1 - midpoint) / np.abs(view1.min()))**4
                res[(value > self.neg_q)
                    & (value < midpoint)] = 0.5 + -cubic * (
                        0.5 - np.interp(self.neg_q, xp, yp))
            view2 = value[(value < self.pos_q) & (value > midpoint)]
            if view2.size != 0:
                cubic = ((view2 + midpoint) / np.abs(view2.max()))**4
                res[(value < self.pos_q)
                    & (value > midpoint)] = 0.5 + -cubic * (
                        0.5 - np.interp(self.pos_q, xp, yp))
            return np.ma.masked_array(res)

    plot_data: pd.DataFrame = overlap_counts.log_odds_ratio
    plot_data = plot_data.dropna(how='any', axis=1)

    # log_odds_flat = np.ravel(plot_data)
    # vmin, neg_sat, neg_quant = np.quantile(log_odds_flat[log_odds_flat < 0], [0.001, 0.2, 0.3])
    # pos_quant, pos_sat, vmax = np.quantile(log_odds_flat[log_odds_flat > 0], [0.7, 0.8, 0.999])
    # neg_quant = -0.7
    # pos_quant = +0.7
    vmin = np.quantile(plot_data, 0.01)
    vmax = np.quantile(plot_data, 0.99)
    q = np.quantile(plot_data.abs(), quantile)
    neg_quant = -q
    pos_quant = q

    plot_data = plot_data.loc[:, plot_data.abs().gt(min_abs_log_odds).any()]

    # plot_data = plot_data.where(plot_data.abs().gt(1), 0)

    norm = RangeNorm(neg_q=neg_quant,
                     pos_q=pos_quant,
                     vmin=vmin,
                     vmax=vmax,
                     height=norm_plateau_height)

    # Plot the norm function
    x = np.linspace(-6, 6, 300)
    y = norm(x)
    fig, ax = plt.subplots(1, 1)
    ax.plot(x, y)
    x = vmin, neg_quant, 0, pos_quant, vmax
    y = [
        0, 0.5 - norm_plateau_height / 2, 0.5, 0.5 + norm_plateau_height / 2, 1
    ]
    ax.scatter(x, y)
    fig.savefig(output_dir / 'test.png')

    plot_data = plot_data.T

    width = plot_data.shape[1] * col_width_cm / 2.54
    height = plot_data.shape[0] * row_height_cm / 2.54
    shrink = cbar_size_cm / height

    # with mpl.rc_context(paper_context):
    cdg = (co.ClusteredDataGrid(main_df=plot_data).cluster_rows(
        method='average', metric='cityblock'))
    gm = cdg.plot_grid(
        grid=[[
            co.Heatmap(df=plot_data,
                       cmap='RdBu_r',
                       norm=norm,
                       row_labels_show=row_labels_show,
                       rasterized=False,
                       edgecolor='white',
                       linewidth=linewidth,
                       cbar_args=dict(
                           shrink=shrink,
                           aspect=20,
                       )),
        ]],
        figsize=(width, height),
        height_ratios=[(1, 'rel')],
        row_dendrogram=row_dendrogram,
    )
    gm.create_or_update_figure()
    gm.fig.savefig(output_dir /
                   f'all-significant_clustered_log-odds_norm-{param_name}.png')
    gm.fig.savefig(output_dir /
                   f'all-significant_clustered_log-odds_norm-{param_name}.pdf')
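

# Hedged note (not part of the original code): the MidpointNormalize helper
# used in masked_plot and barcode_heatmap is project code and is not shown
# here. Matplotlib >= 3.2 ships an equivalent built-in, colors.TwoSlopeNorm,
# which maps vmin..vcenter..vmax onto 0..0.5..1 and could likely be passed to
# the Heatmap calls in the same way:
def _demo_two_slope_norm() -> Figure:
    rng = np.random.default_rng(0)
    data = rng.normal(size=(10, 10))
    norm = colors.TwoSlopeNorm(vcenter=0, vmin=-2, vmax=3)
    fig, ax = plt.subplots()
    mesh = ax.pcolormesh(data, cmap='RdBu_r', norm=norm)
    fig.colorbar(mesh, ax=ax)
    return fig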