예제 #1
0
 def plot_all_patterns(self,
                       kind='seq',
                       trim_frac=0,
                       n_min_seqlets=None,
                       ylim=None,
                       no_axis=False,
                       **kwargs):
     """
     Call `self.stats()` and then plot every pattern, one after another.

     Args:
       kind: forwarded to `self.plot_pattern` as `kind`
       trim_frac: forwarded to `self.plot_pattern` as `trim_frac`
       n_min_seqlets: when not None, patterns for which
         `self.n_seqlets(pattern)` is below this value are skipped
       ylim: when not None, applied through `plt.ylim` after each plot
       no_axis: when True, the current axes are handed to `strip_axis`
         and get `axison = False` after each plot
       **kwargs: forwarded unchanged to `self.plot_pattern`
     """
     self.stats()  # print stats
     apply_size_filter = n_min_seqlets is not None
     for name in self.pattern_names():
         # n_seqlets is only queried when a minimum size was requested
         if apply_size_filter and self.n_seqlets(name) < n_min_seqlets:
             continue

         self.plot_pattern(name, kind=kind, trim_frac=trim_frac, **kwargs)

         if ylim is not None:
             plt.ylim(ylim)

         if not no_axis:
             continue
         strip_axis(plt.gca())
         plt.gca().axison = False
예제 #2
0
def tidy_motif_plot(ax=None):
    """Apply a fixed motif-plot styling to an axes.

    The axes is passed through `strip_axis`, its x label is cleared and its
    x-axis hidden, the y range is fixed to [0, 2.0] and the y label is set
    to "IC".

    Args:
      ax: axes to style; when None the current axes (`plt.gca()`) is used
    """
    target = plt.gca() if ax is None else ax
    strip_axis(target)

    # x direction: no label, no axis
    target.set_xlabel(None)
    x_axis = target.get_xaxis()
    x_axis.set_visible(False)

    # y direction: fixed range and label
    target.set_ylim([0, 2.0])
    target.set_ylabel("IC")
예제 #3
0
def blank_ax(ax):
    """Blank out an axes.

    The axes is first passed through `strip_axis` and then its `axison`
    flag is switched off.

    Args:
      ax: axes to blank
    """
    strip_axis(ax)
    setattr(ax, "axison", False)
예제 #4
0
def plot_profiles_single(seqlet,
                         x,
                         tracks,
                         contribution_scores=None,
                         figsize=(20, 2),
                         legend=True,
                         rotate_y=90,
                         seq_height=1,
                         flip_neg=False,
                         rc_fn=lambda x: x[::-1, ::-1]):
    """
    Plot the profile tracks, contribution scores and sequence of one seqlet.

    A single figure is created with vertically stacked panels sharing the
    x-axis: one panel per entry of `tracks`, one sequence-logo panel per entry
    of `contribution_scores` and a final sequence-logo panel for `x`.

    Args:
      seqlet: object exposing `extract(arr)`; it is applied to `x`, to every
        track and to every contribution-score array
      x: one-hot-encoded sequence
      tracks: dictionary of profile tracks
      contribution_scores: optional dictionary of contribution scores.
        `None` or an empty dictionary produces no contribution panels
      figsize: figure size forwarded to `plt.subplots`
      legend: if True, `legend()` is called on every track panel
      rotate_y: rotation applied to the y-axis labels
      seq_height: height ratio of each sequence-logo panel (track panels
        have ratio 1)
      flip_neg: forwarded to `plot_stranded_profile`
      rc_fn: unused; kept so that existing callers keep working

    Returns:
      The created matplotlib figure.
    """
    import matplotlib.pyplot as plt
    from concise.utils.plot import seqlogo

    # Avoid a shared mutable default argument
    if contribution_scores is None:
        contribution_scores = {}

    # --------------
    # extract signal
    seq = seqlet.extract(x)
    ext_contribution_scores = {
        name: seqlet.extract(contrib)
        for name, contrib in contribution_scores.items()
    }

    n_tracks = len(tracks)
    n_logos = 1 + len(ext_contribution_scores)
    # squeeze=False always yields a 2-D array of axes, so indexing below also
    # works when the figure ends up with a single panel
    fig, axes = plt.subplots(n_tracks + n_logos,
                             1,
                             sharex=True,
                             figsize=figsize,
                             squeeze=False,
                             gridspec_kw={
                                 'height_ratios':
                                 [1] * n_tracks + [seq_height] * n_logos
                             })
    ax = axes[:, 0]

    # signal
    for i, (k, signal) in enumerate(tracks.items()):
        plot_stranded_profile(seqlet.extract(signal),
                              ax=ax[i],
                              flip_neg=flip_neg)
        simple_yaxis_format(ax[i])
        strip_axis(ax[i])
        ax[i].set_ylabel(f"{k}", rotation=rotate_y, ha='right', labelpad=5)

        if legend:
            ax[i].legend()

    # -----------
    # contribution scores (seqlogo)
    # -----------
    # max()/min() of an empty sequence raise ValueError, so the shared scale
    # is only computed when there is at least one contribution score
    if ext_contribution_scores:
        # common y-range: tallest positive / deepest negative stack over all
        # contribution scores
        max_scale = max(
            np.maximum(v, 0).sum(axis=-1).max()
            for v in ext_contribution_scores.values())
        min_scale = min(
            np.minimum(v, 0).sum(axis=-1).min()
            for v in ext_contribution_scores.values())
        for k, (contrib_score_name,
                logo) in enumerate(ext_contribution_scores.items()):
            ax_id = n_tracks + k
            # plot
            ax[ax_id].set_ylim([min_scale, max_scale])
            ax[ax_id].axhline(y=0, linewidth=1, linestyle='--', color='grey')
            seqlogo(logo, ax=ax[ax_id])

            # style
            simple_yaxis_format(ax[ax_id])
            strip_axis(ax[ax_id])
            ax[ax_id].set_ylabel(contrib_score_name,
                                 rotation=rotate_y,
                                 ha='right',
                                 labelpad=5)

    # -----------
    # information content (seqlogo)
    # -----------
    # plot
    seqlogo(seq, ax=ax[-1])

    # style
    simple_yaxis_format(ax[-1])
    strip_axis(ax[-1])
    ax[-1].set_ylabel("Inf. content",
                      rotation=rotate_y,
                      ha='right',
                      labelpad=5)
    # one tick every 5 positions
    ax[-1].set_xticks(list(range(0, len(seq) + 1, 5)))
    return fig
예제 #5
0
def plot_profiles(
        seqlets_by_pattern,
        x,
        tracks,
        contribution_scores=None,
        figsize=(20, 2),
        start_vec=None,
        width=20,
        legend=True,
        rotate_y=90,
        seq_height=1,
        ymax=None,  # determine y-max
        n_limit=35,
        n_bootstrap=None,
        flip_neg=False,
        patterns=None,
        fpath_template=None,
        only_idx=None,
        mkdir=False,
        rc_fn=lambda x: x[::-1, ::-1]):
    """
    Plot aggregated profiles, contribution scores and sequence per pattern.

    For every pattern with at least `n_limit` extracted seqlets one figure is
    created with vertically stacked panels sharing the x-axis: one panel per
    entry of `tracks`, one sequence-logo panel per entry of
    `contribution_scores` and a final sequence-logo panel for the sequence.

    Args:
      seqlets_by_pattern: mapping from pattern name to the seqlets handed to
        `extract_signal`
      x: one-hot-encoded sequence
      tracks: dictionary of profile tracks
      contribution_scores: optional dictionary of contribution scores.
        `None` or an empty dictionary produces no contribution panels
      figsize: figure size forwarded to `plt.subplots`
      start_vec: start offset of the plotted window along the second axis of
        the extracted arrays; either a list with one entry per pattern or a
        single value used for all patterns. If None, the offset is 0 and
        `width` is overridden with `len(x)`
      width: number of positions in the plotted window (ignored when
        `start_vec` is None)
      legend: if True, `legend()` is called on every track panel
      rotate_y: rotation applied to the y-axis labels
      seq_height: height ratio of each sequence-logo panel (track panels
        have ratio 1)
      ymax: y-maximum forwarded to `plot_stranded_profile`; a list with one
        entry per track or a single value used for all tracks. If None, it
        is inferred per track as the maximum over all patterns
      n_limit: patterns with fewer extracted seqlets are skipped
      n_bootstrap: forwarded to `aggregate_profiles`
      flip_neg: forwarded to `plot_stranded_profile`
      patterns: patterns to plot; defaults to all keys of
        `seqlets_by_pattern`
      fpath_template: if not None, a format string accepting `pname`
        (pattern with "/" replaced by ".") and `pattern`; each figure is
        written to `<formatted>.png` and `<formatted>.pdf`
      only_idx: if not None, plot the single example at this index instead
        of the average; also forwarded to `aggregate_profiles`
      mkdir: if True, create the output directory before saving
      rc_fn: unused; kept so that existing callers keep working

    Returns:
      List of the created matplotlib figures (skipped patterns excluded).
    """
    import matplotlib.pyplot as plt
    from concise.utils.plot import seqlogo

    # Avoid a shared mutable default argument
    if contribution_scores is None:
        contribution_scores = {}

    # Resolve the default pattern list first: the start-vec setup below
    # needs len(patterns)
    if patterns is None:
        patterns = list(seqlets_by_pattern)

    # Setup start-vec
    if start_vec is not None:
        if not isinstance(start_vec, list):
            start_vec = [start_vec] * len(patterns)
    else:
        start_vec = [0] * len(patterns)
        # NOTE(review): len(x) is the size of the first axis of x; presumably
        # this is meant to make the window cover everything - confirm
        width = len(x)

    def _window(arr, ip):
        # positions [start, start + width) along the second axis
        start = start_vec[ip]
        return arr[:, start:(start + width)]

    # aggregated profiles: pattern -> track name -> (mean, std-or-None)
    d_signal_patterns = {
        pattern: {
            k: aggregate_profiles(_window(
                extract_signal(y, seqlets_by_pattern[pattern]), ip),
                                  n_bootstrap=n_bootstrap,
                                  only_idx=only_idx)
            for k, y in tracks.items()
        }
        for ip, pattern in enumerate(patterns)
    }
    if ymax is None:
        # infer ymax
        def take_max(profile, spread):
            if spread is None:
                return profile.max()
            else:
                # HACK - hard-coded 2
                return (profile + 2 * spread).max()

        # one value per track; default=None keeps an empty `patterns` from
        # raising (nothing is plotted in that case)
        ymax = [
            max((take_max(*d_signal_patterns[pattern][k])
                 for pattern in patterns),
                default=None) for k in tracks
        ]
    if not isinstance(ymax, list):
        ymax = [ymax] * len(tracks)

    n_tracks = len(tracks)
    n_logos = 1 + len(contribution_scores)

    figs = []
    for ip, pattern in enumerate(tqdm(patterns)):
        # --------------
        # extract signal
        seqs = _window(extract_signal(x, seqlets_by_pattern[pattern]), ip)

        # skip small patterns before doing any further work on them
        n = len(seqs)
        if n < n_limit:
            continue

        ext_contribution_scores = {
            s: _window(extract_signal(contrib, seqlets_by_pattern[pattern]),
                       ip)
            for s, contrib in contribution_scores.items()
        }
        d_signal = d_signal_patterns[pattern]
        # --------------
        if only_idx is None:
            sequence = ic_scale(seqs.mean(axis=0))
        else:
            sequence = seqs[only_idx]

        # squeeze=False always yields a 2-D array of axes, so indexing below
        # also works when the figure ends up with a single panel
        fig, axes = plt.subplots(n_tracks + n_logos,
                                 1,
                                 sharex=True,
                                 figsize=figsize,
                                 squeeze=False,
                                 gridspec_kw={
                                     'height_ratios':
                                     [1] * n_tracks + [seq_height] * n_logos
                                 })
        ax = axes[:, 0]

        # signal
        ax[0].set_title(f"{pattern} ({n})")
        for it, (k, (signal_mean, signal_std)) in enumerate(d_signal.items()):
            plot_stranded_profile(signal_mean,
                                  ax=ax[it],
                                  ymax=ymax[it],
                                  profile_std=signal_std,
                                  flip_neg=flip_neg)
            simple_yaxis_format(ax[it])
            strip_axis(ax[it])
            ax[it].set_ylabel(f"{k}",
                              rotation=rotate_y,
                              ha='right',
                              labelpad=5)

            if legend:
                ax[it].legend()

        # -----------
        # contribution scores (seqlogo)
        # -----------
        # average the contribution scores (or pick the requested example)
        if only_idx is None:
            norm_contribution_scores = {
                k: v.mean(axis=0)
                for k, v in ext_contribution_scores.items()
            }
        else:
            norm_contribution_scores = {
                k: v[only_idx]
                for k, v in ext_contribution_scores.items()
            }

        # max()/min() of an empty sequence raise ValueError, so the shared
        # scale is only computed when there is at least one contribution score
        if norm_contribution_scores:
            max_scale = max(
                np.maximum(v, 0).sum(axis=-1).max()
                for v in norm_contribution_scores.values())
            min_scale = min(
                np.minimum(v, 0).sum(axis=-1).min()
                for v in norm_contribution_scores.values())
            for k, (contrib_score_name,
                    logo) in enumerate(norm_contribution_scores.items()):
                ax_id = n_tracks + k

                # plot
                ax[ax_id].set_ylim([min_scale, max_scale])
                ax[ax_id].axhline(y=0,
                                  linewidth=1,
                                  linestyle='--',
                                  color='grey')
                seqlogo(logo, ax=ax[ax_id])

                # style
                simple_yaxis_format(ax[ax_id])
                strip_axis(ax[ax_id])
                ax[ax_id].set_ylabel(contrib_score_name,
                                     rotation=rotate_y,
                                     ha='right',
                                     labelpad=5)

        # -----------
        # information content (seqlogo)
        # -----------
        # plot
        seqlogo(sequence, ax=ax[-1])

        # style
        simple_yaxis_format(ax[-1])
        strip_axis(ax[-1])
        ax[-1].set_ylabel("Inf. content",
                          rotation=rotate_y,
                          ha='right',
                          labelpad=5)
        # one tick every 5 positions
        ax[-1].set_xticks(list(range(0, len(sequence) + 1, 5)))

        figs.append(fig)
        # save to file
        if fpath_template is not None:
            pname = pattern.replace("/", ".")
            basepath = fpath_template.format(pname=pname, pattern=pattern)
            if mkdir:
                out_dir = os.path.dirname(basepath)
                # os.makedirs('') raises, so only create a directory when
                # the path actually contains one
                if out_dir:
                    os.makedirs(out_dir, exist_ok=True)
            fig.savefig(basepath + '.png', dpi=600)
            fig.savefig(basepath + '.pdf', dpi=600)
            plt.close(fig)  # close the figure
            # NOTE(review): show_figure is defined elsewhere; presumably it
            # makes the just-closed figure displayable again - confirm
            show_figure(fig)
            plt.show()
    return figs