Ejemplo n.º 1
0
 def plot_all_patterns(self,
                       kind='seq',
                       trim_frac=0,
                       n_min_seqlets=None,
                       ylim=None,
                       no_axis=False,
                       **kwargs):
     """Plot every pattern returned by `self.patterns()`.

     Args:
       kind: forwarded unchanged to `self.plot_pattern`
       trim_frac: forwarded unchanged to `self.plot_pattern`
       n_min_seqlets: when not None, a pattern is skipped if
         `self.n_seqlets(*pattern.split("/"))` is smaller than this value
       ylim: when not None, passed to `plt.ylim` after each pattern is
         plotted
       no_axis: when true, the current axes are stripped and hidden after
         each pattern is plotted
       **kwargs: forwarded unchanged to `self.plot_pattern`
     """
     self.stats()  # print stats

     filter_by_support = n_min_seqlets is not None
     for name in self.patterns():
         if filter_by_support:
             # Pattern names are split on "/" to form the n_seqlets arguments
             n_supporting = self.n_seqlets(*name.split("/"))
             if n_supporting < n_min_seqlets:
                 continue

         self.plot_pattern(name, kind=kind, trim_frac=trim_frac, **kwargs)

         if ylim is not None:
             plt.ylim(ylim)

         if not no_axis:
             continue
         strip_axis(plt.gca())
         plt.gca().axison = False
Ejemplo n.º 2
0
def tidy_motif_plot(ax=None):
    """Apply the common motif-plot styling to an axes.

    Args:
      ax: axes to style; when None the current axes (`plt.gca()`) are used

    The axes are passed through `strip_axis`, the x label is cleared and the
    x axis hidden, the y range is fixed to [0, 2.0] and the y label is set
    to "IC".
    """
    target = plt.gca() if ax is None else ax

    strip_axis(target)

    # x axis: no label, not drawn
    target.set_xlabel(None)
    target.get_xaxis().set_visible(False)

    # y axis: fixed range, labelled "IC"
    target.set_ylim([0, 2.0])
    target.set_ylabel("IC")
Ejemplo n.º 3
0
def plot_profiles_single(seqlet,
                         x,
                         tracks,
                         importance_scores=None,
                         figsize=(20, 2),
                         legend=True,
                         rotate_y=90,
                         seq_height=1,
                         flip_neg=False,
                         rc_fn=lambda x: x[::-1, ::-1]):
    """
    Plot the profiles, importance scores and sequence of a single seqlet.

    The figure has one column of rows sharing the x axis: one row per entry
    of `tracks`, then one row per entry of `importance_scores`, then a final
    row with the sequence logo of the extracted sequence (labelled
    "Inf. content").

    Args:
      seqlet: object providing `extract(array)`; it is applied to `x`, to
        every track and to every importance score
      x: one-hot-encoded sequence
      tracks: dictionary of profile tracks
      importance_scores: optional dictionary of importance scores; None or
        an empty dictionary draws no importance-score rows
      figsize: figure size passed to `plt.subplots`
      legend: if true, `legend()` is called on every profile row
      rotate_y: rotation of the y-axis labels
      seq_height: height ratio of the importance-score and sequence rows
        (profile rows use a ratio of 1)
      flip_neg: forwarded to `plot_stranded_profile`
      rc_fn: not used by this function; kept for interface compatibility

    Returns:
      The created matplotlib figure.
    """
    import matplotlib.pyplot as plt
    from concise.utils.plot import seqlogo

    if importance_scores is None:
        importance_scores = {}

    # --------------
    # extract signal
    seq = seqlet.extract(x)
    ext_importance_scores = {
        name: seqlet.extract(imp)
        for name, imp in importance_scores.items()
    }

    n_tracks = len(tracks)
    n_scores = len(ext_importance_scores)
    # squeeze=False keeps the axes indexable even for a single-row figure
    # (no tracks and no importance scores).
    fig, axes = plt.subplots(n_tracks + n_scores + 1,
                             1,
                             sharex=True,
                             figsize=figsize,
                             squeeze=False,
                             gridspec_kw={
                                 'height_ratios': [1] * n_tracks +
                                 [seq_height] * (1 + n_scores)
                             })
    ax = axes[:, 0]

    # signal
    for row, (track_name, signal) in enumerate(tracks.items()):
        plot_stranded_profile(seqlet.extract(signal),
                              ax=ax[row],
                              flip_neg=flip_neg)
        simple_yaxis_format(ax[row])
        strip_axis(ax[row])
        ax[row].set_ylabel(f"{track_name}",
                           rotation=rotate_y,
                           ha='right',
                           labelpad=5)

        if legend:
            ax[row].legend()

    # -----------
    # importance scores (seqlogo)
    # -----------
    # Guarded: max()/min() raise ValueError on an empty sequence, which is
    # what the default (no importance scores) would otherwise hit.
    if ext_importance_scores:
        # Shared y range: largest positive and most negative stacked height
        max_scale = max(
            np.maximum(v, 0).sum(axis=-1).max()
            for v in ext_importance_scores.values())
        min_scale = min(
            np.minimum(v, 0).sum(axis=-1).min()
            for v in ext_importance_scores.values())
        for offset, (imp_score_name,
                     logo) in enumerate(ext_importance_scores.items()):
            ax_id = n_tracks + offset
            # plot
            ax[ax_id].set_ylim([min_scale, max_scale])
            ax[ax_id].axhline(y=0, linewidth=1, linestyle='--', color='grey')
            seqlogo(logo, ax=ax[ax_id])

            # style
            simple_yaxis_format(ax[ax_id])
            strip_axis(ax[ax_id])
            ax[ax_id].set_ylabel(imp_score_name,
                                 rotation=rotate_y,
                                 ha='right',
                                 labelpad=5)

    # -----------
    # sequence (seqlogo), always the last row
    # -----------
    # plot
    seqlogo(seq, ax=ax[-1])

    # style
    simple_yaxis_format(ax[-1])
    strip_axis(ax[-1])
    ax[-1].set_ylabel("Inf. content",
                      rotation=rotate_y,
                      ha='right',
                      labelpad=5)
    # one tick every 5 positions
    ax[-1].set_xticks(list(range(0, len(seq) + 1, 5)))
    return fig
Ejemplo n.º 4
0
def plot_profiles(
        seqlets_by_pattern,
        x,
        tracks,
        importance_scores=None,
        figsize=(20, 2),
        start_vec=None,
        width=20,
        legend=True,
        rotate_y=90,
        seq_height=1,
        ymax=None,  # determine y-max
        n_limit=35,
        n_bootstrap=None,
        flip_neg=False,
        patterns=None,
        fpath_template=None,
        only_idx=None,
        mkdir=False,
        rc_fn=lambda x: x[::-1, ::-1]):
    """
    Plot the aggregated sequence profiles, one figure per pattern.

    Each figure has one column of rows sharing the x axis: one row per entry
    of `tracks`, then one row per entry of `importance_scores`, then a final
    row with the sequence logo (labelled "Inf. content").

    Args:
      seqlets_by_pattern: mapping from pattern name to the seqlets passed to
        `extract_signal`
      x: one-hot-encoded sequence
      tracks: dictionary of profile tracks
      importance_scores: optional dictionary of importance scores; None or
        an empty dictionary draws no importance-score rows
      figsize: figure size passed to `plt.subplots`
      start_vec: window start along the second axis of the extracted
        arrays; either a list with one entry per pattern or a single value
        used for every pattern. When None, the start is 0 and `width` is
        replaced by `len(x)`
      width: window width, used together with `start_vec`
      legend: if true, `legend()` is called on every profile row
      rotate_y: rotation of the y-axis labels
      seq_height: height ratio of the importance-score and sequence rows
        (profile rows use a ratio of 1)
      ymax: y maximum forwarded to `plot_stranded_profile`; a list with one
        entry per track or a single value used for all tracks. When None it
        is inferred per track as the maximum over all patterns
      n_limit: patterns with fewer extracted seqlets than this are skipped
      n_bootstrap: forwarded to `aggregate_profiles`
      flip_neg: forwarded to `plot_stranded_profile`
      patterns: patterns to plot; defaults to all keys of
        `seqlets_by_pattern`
      fpath_template: if given, a format string accepting `pname` (pattern
        with "/" replaced by ".") and `pattern`; the figure is written to
        `<formatted path>.png` and `<formatted path>.pdf`
      only_idx: if given, the single seqlet at this index is plotted instead
        of the average (also forwarded to `aggregate_profiles`)
      mkdir: if true, the directory of the output path is created first
      rc_fn: not used by this function; kept for interface compatibility

    Returns:
      List of the figures created (skipped patterns contribute nothing).
    """
    if importance_scores is None:
        importance_scores = {}

    # Resolve the pattern list first: the start_vec setup needs its length.
    if patterns is None:
        patterns = list(seqlets_by_pattern)
    if len(patterns) == 0:
        # Nothing to draw (and max() below would fail on an empty sequence)
        return []

    # Plotting imports are deferred until there is something to plot
    import matplotlib.pyplot as plt
    from concise.utils.plot import seqlogo

    # Setup start-vec
    if start_vec is None:
        start_vec = [0] * len(patterns)
        # NOTE(review): len(x) is the size of the first axis of x - confirm
        # this is the intended window width.
        width = len(x)
    elif not isinstance(start_vec, list):
        start_vec = [start_vec] * len(patterns)
    # One slice per pattern, applied to the second axis of extracted arrays
    windows = [slice(start, start + width) for start in start_vec]

    # aggregated profiles: pattern -> track name -> aggregate_profiles(...)
    d_signal_patterns = {
        pattern: {
            k: aggregate_profiles(
                extract_signal(y, seqlets_by_pattern[pattern])[:, windows[ip]],
                n_bootstrap=n_bootstrap,
                only_idx=only_idx)
            for k, y in tracks.items()
        }
        for ip, pattern in enumerate(patterns)
    }
    if ymax is None:
        # infer ymax
        def take_max(profile, profile_spread):
            if profile_spread is None:
                return profile.max()
            # HACK - hard-coded 2
            return (profile + 2 * profile_spread).max()

        # one value per track: maximum over all patterns
        ymax = [
            max(
                take_max(*d_signal_patterns[pattern][k])
                for pattern in patterns) for k in tracks
        ]
    if not isinstance(ymax, list):
        ymax = [ymax] * len(tracks)

    n_tracks = len(tracks)
    n_scores = len(importance_scores)

    figs = []
    for ip, pattern in enumerate(tqdm(patterns)):
        # --------------
        # extract signal
        seqs = extract_signal(x, seqlets_by_pattern[pattern])[:, windows[ip]]

        # Skip poorly supported patterns before doing any further work
        n = len(seqs)
        if n < n_limit:
            continue

        ext_importance_scores = {
            s: extract_signal(imp,
                              seqlets_by_pattern[pattern])[:, windows[ip]]
            for s, imp in importance_scores.items()
        }
        # --------------
        if only_idx is None:
            # average over seqlets
            sequence = ic_scale(seqs.mean(axis=0))
            norm_importance_scores = {
                k: v.mean(axis=0)
                for k, v in ext_importance_scores.items()
            }
        else:
            # single seqlet
            sequence = seqs[only_idx]
            norm_importance_scores = {
                k: v[only_idx]
                for k, v in ext_importance_scores.items()
            }

        # squeeze=False keeps the axes indexable even for a single-row
        # figure (no tracks and no importance scores).
        fig, axes = plt.subplots(1 + n_scores + n_tracks,
                                 1,
                                 sharex=True,
                                 figsize=figsize,
                                 squeeze=False,
                                 gridspec_kw={
                                     'height_ratios': [1] * n_tracks +
                                     [seq_height] * (1 + n_scores)
                                 })
        ax = axes[:, 0]

        # signal
        ax[0].set_title(f"{pattern} ({n})")
        for row, (k, (signal_mean, signal_std)) in enumerate(
                d_signal_patterns[pattern].items()):
            plot_stranded_profile(signal_mean,
                                  ax=ax[row],
                                  ymax=ymax[row],
                                  profile_std=signal_std,
                                  flip_neg=flip_neg)
            simple_yaxis_format(ax[row])
            strip_axis(ax[row])
            ax[row].set_ylabel(f"{k}",
                               rotation=rotate_y,
                               ha='right',
                               labelpad=5)

            if legend:
                ax[row].legend()

        # -----------
        # importance scores (seqlogo)
        # -----------
        # Guarded: max()/min() raise ValueError on an empty sequence, which
        # is what the default (no importance scores) would otherwise hit.
        if norm_importance_scores:
            # Shared y range across all importance-score rows of the figure
            max_scale = max(
                np.maximum(v, 0).sum(axis=-1).max()
                for v in norm_importance_scores.values())
            min_scale = min(
                np.minimum(v, 0).sum(axis=-1).min()
                for v in norm_importance_scores.values())
            for offset, (imp_score_name,
                         logo) in enumerate(norm_importance_scores.items()):
                ax_id = n_tracks + offset

                # plot
                ax[ax_id].set_ylim([min_scale, max_scale])
                ax[ax_id].axhline(y=0,
                                  linewidth=1,
                                  linestyle='--',
                                  color='grey')
                seqlogo(logo, ax=ax[ax_id])

                # style
                simple_yaxis_format(ax[ax_id])
                strip_axis(ax[ax_id])
                ax[ax_id].set_ylabel(imp_score_name,
                                     rotation=rotate_y,
                                     ha='right',
                                     labelpad=5)

        # -----------
        # information content (seqlogo), always the last row
        # -----------
        # plot
        seqlogo(sequence, ax=ax[-1])

        # style
        simple_yaxis_format(ax[-1])
        strip_axis(ax[-1])
        ax[-1].set_ylabel("Inf. content",
                          rotation=rotate_y,
                          ha='right',
                          labelpad=5)
        # one tick every 5 positions
        ax[-1].set_xticks(list(range(0, len(sequence) + 1, 5)))

        figs.append(fig)
        # save to file
        if fpath_template is not None:
            pname = pattern.replace("/", ".")
            basepath = fpath_template.format(pname=pname, pattern=pattern)
            if mkdir:
                out_dir = os.path.dirname(basepath)
                # dirname is '' for a bare file name; makedirs('') would fail
                if out_dir:
                    os.makedirs(out_dir, exist_ok=True)
            fig.savefig(basepath + '.png', dpi=600)
            fig.savefig(basepath + '.pdf', dpi=600)
            plt.close(fig)  # close the figure
            # NOTE(review): the figure is closed and then shown again via
            # show_figure - presumably that helper re-attaches it; confirm.
            show_figure(fig)
            plt.show()
    return figs
Ejemplo n.º 5
0
def plot_patterns(patterns,
                  tasks,
                  pattern_trim=(24, 41),
                  fp_slice=slice(10, 190),
                  n_blank=1):
    """Draw a grid summarising the patterns, one row per pattern.

    Columns, left to right: the importance-score logo of the pattern's own
    TF (with "<TF>/<id>" and the seqlet count written next to it), the
    sequence information-content logo, a blank spacer column, and one
    profile column per task.

    Args:
      patterns: patterns to plot; each is trimmed with `pattern_trim` and
        reverse-complemented (`.trim(...).rc()`) before plotting
      tasks: task names; used for the profile columns and as the TF names
        for which the importance-score y range is computed
      pattern_trim: arguments unpacked into `pattern.trim`
      fp_slice: slice applied to the first axis of each profile
      n_blank: not used by this function

    Returns:
      The created matplotlib figure.
    """
    df_info = get_df_info(patterns, tasks)
    # Symmetric y limit of every profile column
    profile_limit = {}
    for task in tasks:
        profile_limit[task] = df_info.max()[task + "_max"]

    fig, axes = plt.subplots(len(patterns),
                             3 + len(tasks),
                             figsize=get_figsize(1, aspect=1.2))
    fig.subplots_adjust(hspace=0, wspace=0)

    # Get the ylim for each TF (from the untrimmed patterns of that TF)
    contrib_ylim = {}
    for tf in tasks:
        lowest = min([
            cand.contrib[cand.attrs['TF']].min() for cand in patterns
            if cand.attrs['TF'] == tf
        ])
        highest = max([
            cand.contrib[cand.attrs['TF']].max() for cand in patterns
            if cand.attrs['TF'] == tf
        ])
        contrib_ylim[tf] = (lowest, highest)

    # Only referenced by the commented-out fixed-width count formatting
    max_digits = max([len(str(cand.attrs["n_seqlets"])) for cand in patterns])

    def squeezed_box(box, x0, box_width):
        # Same vertical placement for every repositioned cell
        return [x0, box.y0 + box.height * 0.4, box_width, box.height * .5]

    for row, original in enumerate(patterns):
        # Motif logo
        logo_ax = axes[row, 0]
        pat = original.trim(*pattern_trim).rc()

        seqlogo(pat.contrib[pat.attrs['TF']], ax=logo_ax)
        # all plots have the same shape
        logo_ax.set_ylim(contrib_ylim[pat.attrs['TF']])
        strip_axis(logo_ax)
        logo_ax.axison = False
        logo_box = logo_ax.get_position()  # get the original position
        # Horizontal amount by which the two logo columns are widened
        extra_x = logo_box.width * 0.2
        logo_ax.set_position(
            squeezed_box(logo_box, logo_box.x0, logo_box.width + extra_x))
        if row == 0:
            logo_ax.set_title("Importance\nscore")

        # Text columns before
        if "/" in pat.name:
            pid = pat.name.split("_")[-1]
        else:
            pid = pat.name.replace("m0_p", "")

        # e.g. "Oct4/1" followed by the seqlet count
        logo_ax.text(-9,
                     0,
                     pat.attrs["TF"] + "/" + pid,
                     fontsize=8,
                     horizontalalignment='right')
        logo_ax.text(-1,
                     0,
                     str(pat.attrs['n_seqlets']),
                     fontsize=8,
                     horizontalalignment='right')

        # Sequence information content logo
        ic_ax = axes[row, 1]
        seqlogo(pat.get_seq_ic(), ax=ic_ax)
        ic_ax.set_ylim([0, 2])  # all plots have the same shape
        strip_axis(ic_ax)
        ic_ax.axison = False
        ic_box = ic_ax.get_position()  # get the original position
        ic_ax.set_position(
            squeezed_box(ic_box, ic_box.x0 + extra_x, ic_box.width + extra_x))
        if row == 0:
            ic_ax.set_title("Sequence\ninfo. content")

        # Blank spacer, narrowed to make room for the widened logos
        spacer_ax = axes[row, 2]
        blank_ax(spacer_ax)
        spacer_box = spacer_ax.get_position()  # get the original position
        spacer_ax.set_position(
            squeezed_box(spacer_box, spacer_box.x0 + extra_x * 2,
                         spacer_box.width - 2 * extra_x))

        # Profile columns
        for col, task in enumerate(tasks):
            profile_ax = axes[row, 3 + col]
            fp = pat.profile[task]
            colour = tf_colors[task]

            profile_ax.plot(fp[fp_slice, 0], color=colour)
            # negative strand, mirrored below zero
            profile_ax.plot(-fp[fp_slice, 1], color=colour, alpha=0.5)

            # all plots have the same shape
            profile_ax.set_ylim([-profile_limit[task], profile_limit[task]])
            strip_axis(profile_ax)
            profile_ax.axison = False

            if row == 0:
                profile_ax.set_title(task)
    return fig
Ejemplo n.º 6
0
def blank_ax(ax):
    """Turn `ax` into an empty placeholder cell.

    The axes are passed through `strip_axis` and then switched off by
    setting `axison` to False.
    """
    strip_axis(ax)
    # Switch off drawing of the axes decorations
    ax.axison = False