Example #1
0
def vdom_pssm(pssm, letter_width=0.2, letter_height=0.8, **kwargs):
    """Render `pssm` as a borderless sequence logo and convert it to VDOM.

    Args:
      pssm: matrix handed to `seqlogo`; ``len(pssm)`` (its number of
        positions) determines the figure width.
      letter_width: figure width, in inches, allotted to each position.
      letter_height: total figure height in inches.
      **kwargs: forwarded unchanged to `fig2vdom`.

    Returns:
      Whatever ``fig2vdom(fig, **kwargs)`` returns.
    """
    import matplotlib.pyplot as plt
    from concise.utils.plot import seqlogo_fig, seqlogo

    size = (len(pssm) * letter_width, letter_height)
    fig, axis = plt.subplots(figsize=size)
    # Hide ticks, spines and frame: only the letters are drawn.
    axis.set_axis_off()
    seqlogo(pssm, ax=axis)
    # Remove all margins so the logo fills the whole canvas.
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    return fig2vdom(fig, **kwargs)
Example #2
0
def seqlogo_clean(seq, letter_width=0.2, height=0.8, title=None):
    """Draw `seq` as a sequence logo on a new, axis-free figure.

    Args:
      seq: matrix handed to `seqlogo`; ``len(seq)`` positions set the
        figure width.
      letter_width: figure width, in inches, per position.
      height: figure height in inches.
      title: optional axes title. If given, the default figure margins are
        kept (presumably to leave room for the title). Otherwise the margins
        are removed so the logo fills the figure.

    Returns:
      The matplotlib Figure.
    """
    import matplotlib.pyplot as plt
    from concise.utils.plot import seqlogo

    fig, axis = plt.subplots(figsize=(len(seq) * letter_width, height))
    axis.set_axis_off()
    seqlogo(seq, ax=axis)
    if title is None:
        plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    else:
        axis.set_title(title)
    return fig
Example #3
0
def plot_filters(W, ncol=2, figsize=(10, 10)):
    """Plot convolutional filters as sequence-logo motifs, one per subplot.

    Args:
      W: filter weights indexed as ``W[:, :, i]`` for filter ``i``, e.g. the
        array returned by ``model.layers[0].get_weights()[0]``. The third
        axis enumerates the filters.
      ncol: number of columns in the subplot grid.
      figsize: Matplotlib figure size (width, height).

    Returns:
      The matplotlib Figure holding the grid of logos.
    """
    n_filters = W.shape[2]
    # At least one row, so an empty W yields a blank figure instead of a
    # matplotlib error for a 0-row grid.
    nrow = max(1, int(np.ceil(n_filters / ncol)))
    fig, _ = plt.subplots(nrow, ncol, figsize=figsize)
    # Flat list of all subplot axes (row-major), independent of grid shape.
    axes = fig.axes
    for i in range(n_filters):
        seqlogo(W[:, :, i], ax=axes[i])
        axes[i].set_title(f"Filter: {i}")
    # The grid can have more cells than filters; hide the empty ones.
    for unused in axes[n_filters:]:
        unused.set_axis_off()
    plt.tight_layout()
    return fig
Example #4
0
    def plot_profiles(self,
                      x,
                      tracks,
                      contribution_scores=None,
                      figsize=(20, 2),
                      rc_vec=None,
                      start_vec=None,
                      width=20,
                      legend=True,
                      seq_height=2,
                      ylim=(0, 3),
                      n_limit=35,
                      n_bootstrap=None,
                      pattern_names=None,
                      fpath_template=None,
                      rc_fn=lambda x: x[::-1, ::-1]):
        """
        Plot the sequence profiles of every pattern, one figure per pattern.

        Each figure stacks these panels, top to bottom:
          - one panel per entry of `tracks` (mean profile; both strands when
            the signal has two columns);
          - one seqlogo per entry of `contribution_scores` (mean
            contribution);
          - the information-content-scaled sequence logo.

        Args:
          x: one-hot-encoded sequence
          tracks: dictionary of profile tracks
          contribution_scores: optional dictionary of contribution scores
          figsize: matplotlib figure size (width, height)
          rc_vec: optional per-pattern flags, indexed like `pattern_names`;
            a truthy entry draws that pattern reverse-complemented via
            `rc_fn`
          start_vec: optional per-pattern start offsets, indexed like
            `pattern_names`; every panel of that pattern is trimmed to
            positions ``start:start + width``
          width: window width used together with `start_vec`
          legend: draw a legend on every track panel
          seq_height: height ratio of the logo panels (track panels use 1)
          ylim: applied only when ``ylim[0]`` is a list, in which case
            ``ylim[t]`` is the y-limit of the t-th track; otherwise it is
            ignored, so the default has no effect
          n_limit: patterns with fewer than `n_limit` extracted sequences
            are skipped
          n_bootstrap: if set, the track mean/std come from `bootstrap_mean`
            with this many samples, and a mean +- 2 std band is drawn
          pattern_names: patterns to plot (default ``self.pattern_names()``)
          fpath_template: if set, each figure is saved to
            ``fpath_template.format(name)`` + ``.png`` and ``.pdf``, where
            ``name`` is the pattern name with "/" replaced by "."
          rc_fn: reverse-complement function for 2-D arrays

        NOTE(review): track signals are trimmed by `start_vec` *before*
        `rc_fn` is applied. The logos are reverse-complemented first and
        trimmed afterwards. Confirm this asymmetry is intended.
        """
        import matplotlib.pyplot as plt
        from concise.utils.plot import seqlogo_fig, seqlogo

        if contribution_scores is None:
            contribution_scores = {}

        seqs_all = self.extract_signal(x)
        ext_contribution_scores = {
            s: self.extract_signal(contrib)
            for s, contrib in contribution_scores.items()
        }
        # TODO assert correct shape in contrib

        if pattern_names is None:
            pattern_names = self.pattern_names()

        n_tracks = len(tracks)
        n_logos = 1 + len(contribution_scores)  # contribution logos + sequence
        height_ratios = [1] * n_tracks + [seq_height] * n_logos

        def strip_spines(axis):
            """Hide the top, right and bottom spines of `axis`."""
            for side in ('top', 'right', 'bottom'):
                axis.spines[side].set_visible(False)

        for ip, pattern in enumerate(pattern_names):
            # `ip` is the *pattern* index. It selects this pattern's entries of
            # `rc_vec` and `start_vec` in every panel below. The track loop
            # uses its own index so it cannot clobber `ip`.
            seqs = seqs_all[pattern]
            n = len(seqs)
            if n < n_limit:
                continue
            rc_seq = rc_vec is not None and bool(rc_vec[ip])
            trim = start_vec is not None
            start = start_vec[ip] if trim else 0

            sequence = ic_scale(seqs.mean(axis=0))
            if rc_seq:
                sequence = rc_fn(sequence)
            if trim:
                sequence = sequence[start:(start + width)]

            # squeeze=False keeps `ax` indexable even with a single panel.
            fig, ax = plt.subplots(n_tracks + n_logos,
                                   1,
                                   sharex=True,
                                   figsize=figsize,
                                   squeeze=False,
                                   gridspec_kw={'height_ratios': height_ratios})
            ax = ax[:, 0]
            ax[0].set_title(f"{pattern} ({n})")

            # ---- profile tracks ----
            for it, (k, y) in enumerate(tracks.items()):
                signal = self.extract_signal(y, rc_fn)[pattern]
                if trim:
                    signal = signal[:, start:(start + width)]

                if n_bootstrap is None:
                    signal_mean = signal.mean(axis=0)
                    signal_std = signal.std(axis=0)
                else:
                    signal_mean, signal_std = bootstrap_mean(signal,
                                                             n=n_bootstrap)
                if rc_seq:
                    signal_mean = rc_fn(signal_mean)
                    signal_std = rc_fn(signal_std)

                positions = np.arange(1, len(signal_mean) + 1)
                strands = [(0, 'pos')]
                if signal_mean.shape[1] == 2:
                    strands.append((1, 'neg'))  # plot also the other strand
                for s, label in strands:
                    ax[it].plot(positions, signal_mean[:, s], label=label)
                    if n_bootstrap is not None:
                        ax[it].fill_between(
                            positions,
                            signal_mean[:, s] - 2 * signal_std[:, s],
                            signal_mean[:, s] + 2 * signal_std[:, s],
                            alpha=0.1)

                ax[it].set_ylabel(f"{k}")
                ax[it].yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
                strip_spines(ax[it])
                ax[it].xaxis.set_ticks_position('none')
                if isinstance(ylim[0], list):
                    ax[it].set_ylim(ylim[it])
                if legend:
                    ax[it].legend()

            # ---- contribution-score logos ----
            for ic, (contrib_score_name,
                     values) in enumerate(ext_contribution_scores.items()):
                ax_id = n_tracks + ic
                logo = values[pattern].mean(axis=0)
                if rc_seq:
                    logo = rc_fn(logo)
                # Trim the pattern if necessary
                if trim:
                    logo = logo[start:(start + width)]
                seqlogo(logo, ax=ax[ax_id])
                ax[ax_id].yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
                ax[ax_id].set_ylabel(contrib_score_name)
                strip_spines(ax[ax_id])
                ax[ax_id].xaxis.set_ticks_position('none')

            # ---- sequence information content ----
            seqlogo(sequence, ax=ax[-1])
            ax[-1].yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
            ax[-1].set_ylabel("Inf. content")
            strip_spines(ax[-1])
            ax[-1].set_xticks(list(range(0, len(sequence) + 1, 5)))

            if fpath_template is not None:
                basepath = fpath_template.format(pattern.replace("/", "."))
                # Save `fig` explicitly rather than whichever figure is current.
                fig.savefig(basepath + '.png', dpi=600)
                fig.savefig(basepath + '.pdf', dpi=600)
                plt.close(fig)  # close the figure
                # NOTE(review): presumably re-displays the closed figure
                # (show_figure is defined elsewhere) -- confirm.
                show_figure(fig)
                plt.show()
Example #5
0
def plot_profiles_single(seqlet,
                         x,
                         tracks,
                         importance_scores=None,
                         figsize=(20, 2),
                         legend=True,
                         rotate_y=90,
                         seq_height=1,
                         flip_neg=False,
                         rc_fn=lambda x: x[::-1, ::-1]):
    """
    Plot the profiles, importance scores and sequence of a single seqlet.

    The figure stacks these panels, top to bottom:
      - one panel per entry of `tracks`;
      - one seqlogo per entry of `importance_scores`, all sharing one
        y-range;
      - the logo of the seqlet's sequence.

    Args:
      seqlet: object whose ``extract(array)`` method is applied to `x`, to
        every track and to every importance-score array (presumably it cuts
        out the seqlet's region -- confirm)
      x: one-hot-encoded sequence
      tracks: dictionary of profile tracks
      importance_scores: optional dictionary of importance scores
      figsize: matplotlib figure size (width, height)
      legend: draw a legend on every track panel
      rotate_y: rotation (degrees) of the y-axis labels
      seq_height: height ratio of the logo panels (track panels use 1)
      flip_neg: forwarded to `plot_stranded_profile`
      rc_fn: currently unused; kept for API compatibility

    Returns:
      The matplotlib Figure.

    NOTE(review): the bottom panel is labelled "Inf. content", but it plots
    the extracted `x` without `ic_scale`. Confirm this is intended.
    """
    import matplotlib.pyplot as plt
    from concise.utils.plot import seqlogo_fig, seqlogo

    if importance_scores is None:
        importance_scores = {}

    # --------------
    # extract signal
    seq = seqlet.extract(x)
    ext_importance_scores = {
        s: seqlet.extract(imp)
        for s, imp in importance_scores.items()
    }

    n_tracks = len(tracks)
    # squeeze=False keeps `ax` indexable even when there is a single panel.
    fig, ax = plt.subplots(1 + len(importance_scores) + n_tracks,
                           1,
                           sharex=True,
                           figsize=figsize,
                           squeeze=False,
                           gridspec_kw={
                               'height_ratios': [1] * n_tracks +
                               [seq_height] * (1 + len(importance_scores))
                           })
    ax = ax[:, 0]

    # signal
    for i, (k, signal) in enumerate(tracks.items()):
        plot_stranded_profile(seqlet.extract(signal),
                              ax=ax[i],
                              flip_neg=flip_neg)
        simple_yaxis_format(ax[i])
        strip_axis(ax[i])
        ax[i].set_ylabel(f"{k}", rotation=rotate_y, ha='right', labelpad=5)

        if legend:
            ax[i].legend()

    # -----------
    # importance scores (seqlogo)
    # -----------
    # Skipped entirely when there are no importance scores: max()/min() of
    # an empty list would raise ValueError.
    if ext_importance_scores:
        # Shared y-range. The top is the largest sum of positive values over
        # the last axis; the bottom is the smallest sum of negative values.
        # Both are taken across all importance-score entries.
        max_scale = max(
            np.maximum(v, 0).sum(axis=-1).max()
            for v in ext_importance_scores.values())
        min_scale = min(
            np.minimum(v, 0).sum(axis=-1).min()
            for v in ext_importance_scores.values())
        for k, (imp_score_name,
                logo) in enumerate(ext_importance_scores.items()):
            ax_id = n_tracks + k
            # plot
            ax[ax_id].set_ylim([min_scale, max_scale])
            ax[ax_id].axhline(y=0, linewidth=1, linestyle='--', color='grey')
            seqlogo(logo, ax=ax[ax_id])

            # style
            simple_yaxis_format(ax[ax_id])
            strip_axis(ax[ax_id])
            ax[ax_id].set_ylabel(imp_score_name,
                                 rotation=rotate_y,
                                 ha='right',
                                 labelpad=5)

    # -----------
    # information content (seqlogo)
    # -----------
    seqlogo(seq, ax=ax[-1])

    # style
    simple_yaxis_format(ax[-1])
    strip_axis(ax[-1])
    ax[-1].set_ylabel("Inf. content",
                      rotation=rotate_y,
                      ha='right',
                      labelpad=5)
    ax[-1].set_xticks(list(range(0, len(seq) + 1, 5)))
    return fig
Example #6
0
def plot_profiles(
        seqlets_by_pattern,
        x,
        tracks,
        importance_scores=None,
        figsize=(20, 2),
        start_vec=None,
        width=20,
        legend=True,
        rotate_y=90,
        seq_height=1,
        ymax=None,  # determine y-max
        n_limit=35,
        n_bootstrap=None,
        flip_neg=False,
        patterns=None,
        fpath_template=None,
        only_idx=None,
        mkdir=False,
        rc_fn=lambda x: x[::-1, ::-1]):
    """
    Plot the sequence profiles of every pattern, one figure per pattern.

    Each figure stacks these panels, top to bottom:
      - one panel per entry of `tracks` (aggregated profile);
      - one seqlogo per entry of `importance_scores`, all sharing one
        y-range;
      - the sequence logo (`ic_scale` of the mean unless `only_idx` is set).

    Args:
      seqlets_by_pattern: mapping pattern name -> seqlets, passed to
        `extract_signal`
      x: one-hot-encoded sequence
      tracks: dictionary of profile tracks
      importance_scores: optional dictionary of importance scores
      figsize: matplotlib figure size (width, height)
      start_vec: start offset(s) of the plotted window. A scalar applies to
        every pattern; a list is indexed like `patterns`. When None, the
        full extracted region is plotted and `width` is ignored.
      width: window width used together with `start_vec`
      legend: draw a legend on every track panel
      rotate_y: rotation (degrees) of the y-axis labels
      seq_height: height ratio of the logo panels (track panels use 1)
      ymax: y-max of the track panels. Use a scalar for all tracks, a list
        with one value per track, or None to infer each track's maximum over
        all patterns.
      n_limit: patterns with fewer than `n_limit` extracted seqlets are
        skipped
      n_bootstrap: forwarded to `aggregate_profiles`
      flip_neg: forwarded to `plot_stranded_profile`
      patterns: patterns to plot (default: all keys of `seqlets_by_pattern`)
      fpath_template: if set, each figure is saved to
        ``fpath_template.format(pname=..., pattern=...)`` + ``.png``/``.pdf``,
        where ``pname`` is the pattern name with "/" replaced by "."
      only_idx: if set, plot the single extracted seqlet at this index
        instead of the average
      mkdir: create the directory of each saved figure if it is missing
      rc_fn: currently unused; kept for API compatibility

    Returns:
      List of the created figures. Skipped patterns are not included.
    """
    import matplotlib.pyplot as plt
    from concise.utils.plot import seqlogo_fig, seqlogo

    if importance_scores is None:
        importance_scores = {}
    # Resolve `patterns` first: the start_vec setup below needs its length.
    if patterns is None:
        patterns = list(seqlets_by_pattern)
    if not patterns:
        return []

    # Setup start-vec
    if start_vec is not None:
        if not isinstance(start_vec, list):
            start_vec = [start_vec] * len(patterns)
    else:
        # No window requested: keep the full extracted width.
        start_vec = [0] * len(patterns)
        width = None

    def window(ip):
        """Slice over positions that selects the window of pattern `ip`."""
        start = start_vec[ip]
        return slice(start, None if width is None else start + width)

    # aggregated profiles: pattern -> track name -> aggregate_profiles(...)
    d_signal_patterns = {
        pattern: {
            k: aggregate_profiles(
                extract_signal(y, seqlets_by_pattern[pattern])[:, window(ip)],
                n_bootstrap=n_bootstrap,
                only_idx=only_idx)
            for k, y in tracks.items()
        }
        for ip, pattern in enumerate(patterns)
    }
    if ymax is None:
        # infer ymax
        def take_max(mean, dx):
            if dx is None:
                return mean.max()
            # HACK - hard-coded 2
            return (mean + 2 * dx).max()

        # one value per track, maximised over all patterns
        ymax = [
            max(take_max(*d_signal_patterns[pattern][k])
                for pattern in patterns)
            for k in tracks
        ]
    if not isinstance(ymax, list):
        ymax = [ymax] * len(tracks)

    n_tracks = len(tracks)
    figs = []
    for ip, pattern in enumerate(tqdm(patterns)):
        seqlets = seqlets_by_pattern[pattern]
        # --------------
        # extract signal
        seqs = extract_signal(x, seqlets)[:, window(ip)]
        n = len(seqs)
        if n < n_limit:
            continue
        ext_importance_scores = {
            s: extract_signal(imp, seqlets)[:, window(ip)]
            for s, imp in importance_scores.items()
        }
        d_signal = d_signal_patterns[pattern]
        # --------------
        if only_idx is None:
            sequence = ic_scale(seqs.mean(axis=0))
        else:
            sequence = seqs[only_idx]

        # squeeze=False keeps `ax` indexable even with a single panel.
        fig, ax = plt.subplots(1 + len(importance_scores) + n_tracks,
                               1,
                               sharex=True,
                               figsize=figsize,
                               squeeze=False,
                               gridspec_kw={
                                   'height_ratios': [1] * n_tracks +
                                   [seq_height] * (1 + len(importance_scores))
                               })
        ax = ax[:, 0]

        # signal
        ax[0].set_title(f"{pattern} ({n})")
        for it, (k, (signal_mean, signal_std)) in enumerate(d_signal.items()):
            plot_stranded_profile(signal_mean,
                                  ax=ax[it],
                                  ymax=ymax[it],
                                  profile_std=signal_std,
                                  flip_neg=flip_neg)
            simple_yaxis_format(ax[it])
            strip_axis(ax[it])
            ax[it].set_ylabel(f"{k}", rotation=rotate_y, ha='right', labelpad=5)

            if legend:
                ax[it].legend()

        # -----------
        # importance scores (seqlogo)
        # -----------
        # average the importance scores (or take the single selected seqlet)
        if only_idx is None:
            norm_importance_scores = {
                k: v.mean(axis=0)
                for k, v in ext_importance_scores.items()
            }
        else:
            norm_importance_scores = {
                k: v[only_idx]
                for k, v in ext_importance_scores.items()
            }

        # Skipped when empty: max()/min() of an empty list would raise.
        if norm_importance_scores:
            # Shared y-range. The top is the largest sum of positive values
            # over the last axis; the bottom is the smallest sum of negative
            # values.
            max_scale = max(
                np.maximum(v, 0).sum(axis=-1).max()
                for v in norm_importance_scores.values())
            min_scale = min(
                np.minimum(v, 0).sum(axis=-1).min()
                for v in norm_importance_scores.values())
            for ii, (imp_score_name,
                     logo) in enumerate(norm_importance_scores.items()):
                ax_id = n_tracks + ii
                # plot
                ax[ax_id].set_ylim([min_scale, max_scale])
                ax[ax_id].axhline(y=0, linewidth=1, linestyle='--',
                                  color='grey')
                seqlogo(logo, ax=ax[ax_id])

                # style
                simple_yaxis_format(ax[ax_id])
                strip_axis(ax[ax_id])
                ax[ax_id].set_ylabel(imp_score_name,
                                     rotation=rotate_y,
                                     ha='right',
                                     labelpad=5)

        # -----------
        # information content (seqlogo)
        # -----------
        seqlogo(sequence, ax=ax[-1])

        # style
        simple_yaxis_format(ax[-1])
        strip_axis(ax[-1])
        ax[-1].set_ylabel("Inf. content",
                          rotation=rotate_y,
                          ha='right',
                          labelpad=5)
        ax[-1].set_xticks(list(range(0, len(sequence) + 1, 5)))

        figs.append(fig)
        # save to file
        if fpath_template is not None:
            pname = pattern.replace("/", ".")
            basepath = fpath_template.format(pname=pname, pattern=pattern)
            if mkdir:
                dirname = os.path.dirname(basepath)
                # os.makedirs('') raises FileNotFoundError, so only create a
                # directory when the template actually contains one.
                if dirname:
                    os.makedirs(dirname, exist_ok=True)
            # Save `fig` explicitly rather than whichever figure is current.
            fig.savefig(basepath + '.png', dpi=600)
            fig.savefig(basepath + '.pdf', dpi=600)
            plt.close(fig)  # close the figure
            show_figure(fig)
            plt.show()
    return figs
Example #7
0
def plot_patterns(patterns,
                  tasks,
                  pattern_trim=(24, 41),
                  fp_slice=slice(10, 190),
                  n_blank=1):
    """Draw a summary grid of patterns, one row per pattern.

    Columns:
      0. importance-score logo for the pattern's own TF (``attrs['TF']``),
         with its "<TF>/<id>" name and seqlet count written to the left;
      1. sequence information-content logo;
      2. an empty spacer;
      3+. one column per task. Each shows the task's profile: the positive
         strand, plus the negated second column at half opacity.

    Args:
      patterns: pattern objects. Each must provide ``.attrs['TF']``,
        ``.attrs['n_seqlets']``, ``.name``, ``.contrib``, ``.profile``,
        ``.trim(start, end)``, ``.rc()`` and ``.get_seq_ic()``.
      tasks: task names; one profile column per task.
      pattern_trim: ``(start, end)`` passed to ``p.trim`` before plotting.
      fp_slice: positions of each profile to plot.
      n_blank: currently unused; kept for API compatibility.

    Returns:
      The matplotlib Figure.
    """
    df_info = get_df_info(patterns, tasks)
    col_max = df_info.max()  # computed once instead of once per task
    max_vals = {t: col_max[t + "_max"] for t in tasks}

    # squeeze=False keeps `axes` 2-D even when there is only one pattern.
    fig, axes = plt.subplots(len(patterns),
                             3 + len(tasks),
                             figsize=get_figsize(1, aspect=1.2),
                             squeeze=False)
    fig.subplots_adjust(hspace=0, wspace=0)

    # Shared importance-score y-limits per TF, so every logo of the same TF
    # uses one scale. TFs without any pattern are left out: min()/max() of an
    # empty list would raise.
    contrib_ylim = {}
    for tf in tasks:
        tf_contribs = [p.contrib[tf] for p in patterns if p.attrs['TF'] == tf]
        if tf_contribs:
            contrib_ylim[tf] = (min(c.min() for c in tf_contribs),
                                max(c.max() for c in tf_contribs))

    for i, p in enumerate(patterns):
        row = axes[i]
        p = p.trim(*pattern_trim).rc()
        tf = p.attrs['TF']

        # --- column 0: importance-score logo ---
        ax = row[0]
        seqlogo(p.contrib[tf], ax=ax)
        ax.set_ylim(contrib_ylim[tf])  # all plots have the same shape
        strip_axis(ax)
        ax.axison = False
        pos1 = ax.get_position()  # get the original position
        # The two logo columns are widened by 20% of a column width; later
        # columns are shifted right by the same amount. Every panel keeps
        # half of its row height, starting 40% up from the row's bottom.
        extra_x = pos1.width * 0.2
        ax.set_position([
            pos1.x0, pos1.y0 + pos1.height * 0.4, pos1.width + extra_x,
            pos1.height * .5
        ])
        if i == 0:
            ax.set_title("Importance\nscore")

        # Text left of the logo, e.g. "Oct4/1   150": "<TF>/<pattern id>"
        # followed by the number of seqlets.
        if "/" in p.name:
            pid = p.name.split("_")[-1]
        else:
            pid = p.name.replace("m0_p", "")
        ax.text(-9,
                0,
                tf + "/" + pid,
                fontsize=8,
                horizontalalignment='right')
        ax.text(-1,
                0,
                str(p.attrs['n_seqlets']),
                fontsize=8,
                horizontalalignment='right')

        # --- column 1: sequence information content ---
        ax = row[1]
        seqlogo(p.get_seq_ic(), ax=ax)
        ax.set_ylim([0, 2])  # all plots have the same shape
        strip_axis(ax)
        ax.axison = False
        pos1 = ax.get_position()  # get the original position
        ax.set_position([
            pos1.x0 + extra_x, pos1.y0 + pos1.height * 0.4,
            pos1.width + extra_x, pos1.height * .5
        ])
        if i == 0:
            ax.set_title("Sequence\ninfo. content")

        # --- column 2: spacer, narrowed to absorb the widened logos ---
        ax = row[2]
        blank_ax(ax)
        pos1 = ax.get_position()  # get the original position
        ax.set_position([
            pos1.x0 + extra_x * 2, pos1.y0 + pos1.height * 0.4,
            pos1.width - 2 * extra_x, pos1.height * .5
        ])

        # --- columns 3+: per-task profiles ---
        for j, task in enumerate(tasks):
            ax = row[3 + j]
            fp = p.profile[task]

            ax.plot(fp[fp_slice, 0], color=tf_colors[task])
            ax.plot(-fp[fp_slice, 1], color=tf_colors[task],
                    alpha=0.5)  # negative strand

            ax.set_ylim([-max_vals[task],
                         max_vals[task]])  # all plots have the same shape
            strip_axis(ax)
            ax.axison = False

            if i == 0:
                ax.set_title(task)
    return fig
Example #8
0
        print("unsuccessfull parsing. letter: " + let)

# A, D, O, P


# Scratch / demo script for the letter polygons and `seqlogo`.
# NOTE(review): relies on names defined outside this view -- `dragonn`,
# `plt`, `np`, `string`, `seqlogo`, `add_letter_to_axis`, `VOCAB_colors`,
# the letter objects `A`..`Z` (e.g. `O`) and `pwm` -- TODO confirm.
# The return value of this call is discarded.
dragonn.plot.standardize_polygons_str(O)

# Draw the letter "P" on the current axes. The numeric arguments are
# presumably x position, y position and height -- verify against
# add_letter_to_axis.
ax = plt.gca()
add_letter_to_axis(ax, "P", 0, 0, 1)

plt.show()

# Map each uppercase letter to the module-level object of the same name.
all_polygons = {l: globals()[l] for l in string.ascii_uppercase}


# Render the same `pwm` with the "DNA", "RNA" and "AA" vocabularies.
seqlogo(pwm, "DNA")
plt.show()

seqlogo(pwm, "RNA")
plt.show()

seqlogo(pwm, "AA")
plt.show()


# Random 10 x len(VOCAB_colors["AA"]) matrix with entries in [0, 1).
aa_matrix = np.random.uniform(0, 1, (10, len(VOCAB_colors["AA"])))
seqlogo(aa_matrix, "AA")
plt.show()

# Same shape, entries in [-1, 1): also exercises negative values.
aa_matrix = np.random.uniform(-1, 1, (10, len(VOCAB_colors["AA"])))
seqlogo(aa_matrix, "AA")