Example #1
def get_pssm(self, metacluster, pattern, rc=False, trim_frac=None):
    """Return the information-content-scaled PSSM of a pattern,
    optionally trimmed and/or reverse-complemented."""
    pattern_grp = self.get_pattern_grp(metacluster, pattern)
    pssm = ic_scale(pattern_grp["sequence"]['fwd'][:])
    if trim_frac is not None:
        pssm = trim_pssm(pssm, trim_frac)
    if rc:
        # reverse complement: flip both the position and the base axes
        pssm = pssm[::-1, ::-1]
    return pssm
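A hypothetical call sketch for the method above; the wrapper object `mr`, its constructor, the file path and the metacluster/pattern identifiers are assumptions, not part of the original snippet.

# Hypothetical usage sketch -- `mr` stands for whatever object exposes
# get_pattern_grp(); the file path and identifiers below are placeholders.
mr = ModiscoResultWrapper("modisco_results.h5")               # assumed wrapper
pssm = mr.get_pssm("metacluster_0", "pattern_0",
                   trim_frac=0.05)                            # trim low-IC flanks
rc_pssm = mr.get_pssm("metacluster_0", "pattern_0", rc=True)  # reverse complement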
Example #2
def plot_profiles(
        seqlets_by_pattern,
        x,
        tracks,
        importance_scores={},
        figsize=(20, 2),
        start_vec=None,
        width=20,
        legend=True,
        rotate_y=90,
        seq_height=1,
        ymax=None,  # y-axis maximum; inferred from the data if None
        n_limit=35,
        n_bootstrap=None,
        flip_neg=False,
        patterns=None,
        fpath_template=None,
        only_idx=None,
        mkdir=False,
        rc_fn=lambda x: x[::-1, ::-1]):
    """
    Plot the sequence profiles
    Args:
      x: one-hot-encoded sequence
      tracks: dictionary of profile tracks
      importance_scores: optional dictionary of importance scores

    """
    import matplotlib.pyplot as plt
    from concise.utils.plot import seqlogo_fig, seqlogo

    # Setup start-vec
    if start_vec is not None:
        if not isinstance(start_vec, list):
            start_vec = [start_vec] * len(patterns)
    else:
        start_vec = [0] * len(patterns)
        width = len(x)

    if patterns is None:
        patterns = list(seqlets_by_pattern)
    # aggregated profiles
    d_signal_patterns = {
        pattern: {
            k: aggregate_profiles(extract_signal(
                y, seqlets_by_pattern[pattern])[:,
                                                start_vec[ip]:(start_vec[ip] +
                                                               width)],
                                  n_bootstrap=n_bootstrap,
                                  only_idx=only_idx)
            for k, y in tracks.items()
        }
        for ip, pattern in enumerate(patterns)
    }
    if ymax is None:
        # infer ymax
        def take_max(x, dx):
            if dx is None:
                return x.max()
            else:
                # HACK - hard-coded 2
                return (x + 2 * dx).max()

        ymax = [
            max([
                take_max(*d_signal_patterns[pattern][k])
                for pattern in patterns
            ]) for k in tracks
        ]  # loop through all the tracks
    if not isinstance(ymax, list):
        ymax = [ymax] * len(tracks)

    figs = []
    for i, pattern in enumerate(tqdm(patterns)):
        # --------------
        # extract signal
        seqs = extract_signal(
            x,
            seqlets_by_pattern[pattern])[:,
                                         start_vec[i]:(start_vec[i] + width)]
        ext_importance_scores = {
            s: extract_signal(
                imp, seqlets_by_pattern[pattern])[:,
                                                  start_vec[i]:(start_vec[i] +
                                                                width)]
            for s, imp in importance_scores.items()
        }
        d_signal = d_signal_patterns[pattern]
        # --------------
        if only_idx is None:
            sequence = ic_scale(seqs.mean(axis=0))
        else:
            sequence = seqs[only_idx]

        n = len(seqs)
        if n < n_limit:
            continue
        fig, ax = plt.subplots(1 + len(importance_scores) + len(tracks),
                               1,
                               sharex=True,
                               figsize=figsize,
                               gridspec_kw={
                                   'height_ratios': [1] * len(tracks) +
                                   [seq_height] * (1 + len(importance_scores))
                               })

        # signal
        ax[0].set_title(f"{pattern} ({n})")
        for i, (k, signal) in enumerate(d_signal.items()):
            signal_mean, signal_std = d_signal_patterns[pattern][k]
            plot_stranded_profile(signal_mean,
                                  ax=ax[i],
                                  ymax=ymax[i],
                                  profile_std=signal_std,
                                  flip_neg=flip_neg)
            simple_yaxis_format(ax[i])
            strip_axis(ax[i])
            ax[i].set_ylabel(f"{k}", rotation=rotate_y, ha='right', labelpad=5)

            if legend:
                ax[i].legend()

        # -----------
        # importance scores (seqlogo)
        # -----------
        # average the importance scores
        if only_idx is None:
            norm_importance_scores = {
                k: v.mean(axis=0)
                for k, v in ext_importance_scores.items()
            }
        else:
            norm_importance_scores = {
                k: v[only_idx]
                for k, v in ext_importance_scores.items()
            }

        max_scale = max([
            np.maximum(v, 0).sum(axis=-1).max()
            for v in norm_importance_scores.values()
        ])
        min_scale = min([
            np.minimum(v, 0).sum(axis=-1).min()
            for v in norm_importance_scores.values()
        ])
        for k, (imp_score_name,
                logo) in enumerate(norm_importance_scores.items()):
            ax_id = len(tracks) + k

            # Trim the pattern if necessary
            # plot
            ax[ax_id].set_ylim([min_scale, max_scale])
            ax[ax_id].axhline(y=0, linewidth=1, linestyle='--', color='grey')
            seqlogo(logo, ax=ax[ax_id])

            # style
            simple_yaxis_format(ax[ax_id])
            strip_axis(ax[ax_id])
            # ax[ax_id].set_ylabel(imp_score_name)
            ax[ax_id].set_ylabel(imp_score_name,
                                 rotation=rotate_y,
                                 ha='right',
                                 labelpad=5)  # va='bottom',

        # -----------
        # information content (seqlogo)
        # -----------
        # plot
        seqlogo(sequence, ax=ax[-1])

        # style
        simple_yaxis_format(ax[-1])
        strip_axis(ax[-1])
        ax[-1].set_ylabel("Inf. content",
                          rotation=rotate_y,
                          ha='right',
                          labelpad=5)
        ax[-1].set_xticks(list(range(0, len(sequence) + 1, 5)))

        figs.append(fig)
        # save to file
        if fpath_template is not None:
            pname = pattern.replace("/", ".")
            basepath = fpath_template.format(pname=pname, pattern=pattern)
            if mkdir:
                os.makedirs(os.path.dirname(basepath), exist_ok=True)
            plt.savefig(basepath + '.png', dpi=600)
            plt.savefig(basepath + '.pdf', dpi=600)
            plt.close(fig)  # close the figure
            show_figure(fig)
            plt.show()
    return figs
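A hedged usage sketch for plot_profiles; the input arrays and the seqlets_by_pattern mapping are placeholders for objects produced earlier in a TF-MoDISco/BPNet-style pipeline.

# Hypothetical call -- all variable names below are placeholders.
figs = plot_profiles(
    seqlets_by_pattern,                       # dict: pattern name -> seqlets
    x=one_hot_sequences,                      # one-hot-encoded sequences
    tracks={"task1": profile_preds},          # stranded profile tracks
    importance_scores={"contrib": contrib_scores},
    n_bootstrap=100,                          # bootstrap the profile mean/std
    fpath_template="plots/{pname}",           # saves plots/<pattern>.png/.pdf
    mkdir=True)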
Example #3
def get_modisco_rank_scorer(loaded_tfmodisco_results, seqlet_size_to_score_with=25, n_cores=1, trim_pattern=False):
    """

    Args:
      loaded_tfmodisco_results: tf-modisco result containing the track_seq
      seqlet_size_to_score_with: width of the scored seqlet
      n_cores: number of cores to use for nearest-neighbour computation

    Returns:
      cross_metacluster_scorer, all_pattern_names, metacluster_idx_to_scorer
    """
    import modisco
    from modisco import affinitymat
    from modisco import hit_scoring
    from modisco import aggregator
    task_names = loaded_tfmodisco_results.task_names
    metacluster_idx_to_scorer = OrderedDict()
    all_pattern_scorers = []
    all_pattern_names = []

    # loop through the metaclusters
    for metacluster_name in sorted(loaded_tfmodisco_results.metacluster_idx_to_submetacluster_results.keys()):

        submetacluster_results = (loaded_tfmodisco_results.metacluster_idx_to_submetacluster_results[metacluster_name])

        activity_pattern = submetacluster_results.activity_pattern

        relevant_task_names = [task_name for (task_name, x) in
                               zip(task_names, activity_pattern) if np.abs(x) != 0]

        if trim_pattern:
            trim_sizes = {}
            trimmed_patterns = []
            for pattern_idx, pattern in enumerate(
                    submetacluster_results.seqlets_to_patterns_result.patterns):
                pssm = ic_scale(pattern["sequence"].fwd)
                t1, t2 = trim_pssm_idx(pssm)
                trim_sizes[pattern_idx] = t2 - t1
                trimmer = aggregator.TrimToBestWindow(
                    window_size=trim_sizes[pattern_idx],
                    track_names=([x + "_contrib_scores" for x in relevant_task_names]
                                 + [x + "_hypothetical_contribs" for x in relevant_task_names]))
                trimmed_patterns.extend(trimmer([pattern]))

            submetacluster_results.seqlets_to_patterns_result.patterns = trimmed_patterns

        pattern_comparison_settings = affinitymat.core.PatternComparisonSettings(
            track_names=([x + "_contrib_scores" for x in relevant_task_names] +
                         [x + "_hypothetical_contribs" for x in relevant_task_names]),  # only compare across relevant tasks
            track_transformer=affinitymat.L1Normalizer(),
            min_overlap=0.7)

        pattern_to_seqlets_sim_computer = hit_scoring.PatternsToSeqletsSimComputer(
            pattern_comparison_settings=pattern_comparison_settings,
            cross_metric_computer=affinitymat.core.ParallelCpuCrossMetricOnNNpairs(
                n_cores=n_cores,
                cross_metric_single_region=affinitymat.core.CrossContinJaccardSingleRegionWithArgmax(),
                verbose=False),
            seqlet_trimmer=modisco.hit_scoring.SeqletTrimToBestWindow(
                window_size=seqlet_size_to_score_with,
                track_names=[x + "_contrib_scores" for x
                             in relevant_task_names])
        )

        # Get a list of scorers for all the patterns in the metacluster
        metacluster_pattern_scorers = []
        if submetacluster_results.seqlets_to_patterns_result.patterns is None or \
           len(submetacluster_results.seqlets_to_patterns_result.patterns) == 0:
            # metacluster has no patterns
            # don't append anything
            continue

        for pattern_idx, pattern in enumerate(
                submetacluster_results.seqlets_to_patterns_result.patterns):
            metacluster_idx = int(metacluster_name.split("_")[1])
            all_pattern_names.append("metacluster_" + str(metacluster_idx)
                                     + ",pattern_" + str(pattern_idx))
            if trim_pattern:
                pattern_to_seqlets_sim_computer = hit_scoring.PatternsToSeqletsSimComputer(
                    pattern_comparison_settings=pattern_comparison_settings,
                    cross_metric_computer=affinitymat.core.ParallelCpuCrossMetricOnNNpairs(
                        n_cores=n_cores,
                        cross_metric_single_region=affinitymat.core.CrossContinJaccardSingleRegionWithArgmax(),
                        verbose=False),
                    seqlet_trimmer=modisco.hit_scoring.SeqletTrimToBestWindow(
                        window_size=min(seqlet_size_to_score_with, trim_sizes[pattern_idx]),
                        track_names=[x + "_contrib_scores" for x
                                     in relevant_task_names])
                )
            pattern_scorer = hit_scoring.RankBasedPatternScorer(
                aggseqlets=pattern,
                patterns_to_seqlets_sim_computer=pattern_to_seqlets_sim_computer)
            metacluster_pattern_scorers.append(pattern_scorer)
            all_pattern_scorers.append(pattern_scorer)
        # This is the final scorer for the metacluster;
        # it takes the maximum score produced by all the
        # individual scorers
        max_rank_based_pattern_scorer = hit_scoring.MaxRankBasedPatternScorer(pattern_scorers=metacluster_pattern_scorers)

        metacluster_idx_to_scorer[metacluster_idx] = max_rank_based_pattern_scorer

    cross_metacluster_scorer = hit_scoring.MaxRankBasedPatternScorer(pattern_scorers=all_pattern_scorers)

    return cross_metacluster_scorer, all_pattern_names, metacluster_idx_to_scorer
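A minimal usage sketch, assuming a TF-MoDISco results object has already been re-loaded with its seqlet track data attached; variable names are placeholders.

# Hypothetical usage -- `loaded_tfmodisco_results` is assumed to be a
# tf-modisco results object re-loaded with the underlying track data.
scorer, pattern_names, per_metacluster_scorers = get_modisco_rank_scorer(
    loaded_tfmodisco_results,
    seqlet_size_to_score_with=25,
    n_cores=4,
    trim_pattern=True)   # trim each pattern to its informative window first
# `scorer` ranks new seqlets against all patterns across metaclusters;
# `per_metacluster_scorers` does the same within each metacluster.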
Example #4
    def plot_profiles(self,
                      x,
                      tracks,
                      importance_scores={},
                      figsize=(20, 2),
                      rc_vec=None,
                      start_vec=None,
                      width=20,
                      legend=True,
                      seq_height=2,
                      ylim=[0, 3],
                      n_limit=35,
                      n_bootstrap=None,
                      patterns=None,
                      fpath_template=None,
                      rc_fn=lambda x: x[::-1, ::-1]):
        """
        Plot the sequence profiles
        Args:
          x: one-hot-encoded sequence
          tracks: dictionary of profile tracks
          importance_scores: optional dictionary of importance scores

        TODO - add the reverse complementation option to it
        """
        import matplotlib.pyplot as plt
        from concise.utils.plot import seqlogo_fig, seqlogo

        seqs_all = self.extract_signal(x)
        ext_importance_scores = {
            s: self.extract_signal(imp)
            for s, imp in importance_scores.items()
        }
        # TODO assert correct shape in imp

        if patterns is None:
            patterns = self.patterns()
        for i, pattern in enumerate(patterns):
            j = i
            seqs = seqs_all[pattern]
            sequence = ic_scale(seqs.mean(axis=0))
            if rc_vec is not None and rc_vec[i]:
                rc_seq = True
                sequence = rc_fn(sequence)
            else:
                rc_seq = False
            if start_vec is not None:
                start = start_vec[i]
                sequence = sequence[start:(start + width)]
            n = len(seqs)
            if n < n_limit:
                continue
            fig, ax = plt.subplots(1 + len(importance_scores) + len(tracks),
                                   1,
                                   sharex=True,
                                   figsize=figsize,
                                   gridspec_kw={
                                       'height_ratios':
                                       [1] * len(tracks) + [seq_height] *
                                       (1 + len(importance_scores))
                                   })
            ax[0].set_title(f"{pattern} ({n})")
            for i, (k, y) in enumerate(tracks.items()):
                signal = self.extract_signal(y, rc_fn)[pattern]

                if start_vec is not None:
                    start = start_vec[j]  # j = pattern index (the track loop shadows i)
                    signal = signal[:, start:(start + width)]

                if n_bootstrap is None:
                    signal_mean = signal.mean(axis=0)
                    signal_std = signal.std(axis=0)
                else:
                    signal_mean, signal_std = bootstrap_mean(signal,
                                                             n=n_bootstrap)
                if rc_seq:  # reverse-complement the profile along with the sequence
                    signal_mean = rc_fn(signal_mean)
                    signal_std = rc_fn(signal_std)

                ax[i].plot(np.arange(1,
                                     len(signal_mean) + 1),
                           signal_mean[:, 0],
                           label='pos')
                if n_bootstrap is not None:
                    ax[i].fill_between(
                        np.arange(1,
                                  len(signal_mean) + 1),
                        signal_mean[:, 0] - 2 * signal_std[:, 0],
                        signal_mean[:, 0] + 2 * signal_std[:, 0],
                        alpha=0.1)
                #                   label='pos')
                # plot also the other strand
                if signal_mean.shape[1] == 2:
                    ax[i].plot(np.arange(1,
                                         len(signal_mean) + 1),
                               signal_mean[:, 1],
                               label='neg')
                    if n_bootstrap is not None:
                        ax[i].fill_between(
                            np.arange(1,
                                      len(signal_mean) + 1),
                            signal_mean[:, 1] - 2 * signal_std[:, 1],
                            signal_mean[:, 1] + 2 * signal_std[:, 1],
                            alpha=0.1)
                #                   label='pos')
                ax[i].set_ylabel(f"{k}")
                ax[i].yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
                ax[i].spines['top'].set_visible(False)
                ax[i].spines['right'].set_visible(False)
                ax[i].spines['bottom'].set_visible(False)
                ax[i].xaxis.set_ticks_position('none')
                if isinstance(ylim[0], list):
                    ax[i].set_ylim(ylim[i])
                if legend:
                    ax[i].legend()

            for k, (imp_score_name,
                    values) in enumerate(ext_importance_scores.items()):
                ax_id = len(tracks) + k
                logo = values[pattern].mean(axis=0)
                # Trim the pattern if necessary
                if rc_seq:
                    logo = rc_fn(logo)
                if start_vec is not None:
                    start = start_vec[j]
                    logo = logo[start:(start + width)]
                seqlogo(logo, ax=ax[ax_id])
                ax[ax_id].yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
                ax[ax_id].set_ylabel(imp_score_name)
                ax[ax_id].spines['top'].set_visible(False)
                ax[ax_id].spines['right'].set_visible(False)
                ax[ax_id].spines['bottom'].set_visible(False)
                ax[ax_id].xaxis.set_ticks_position('none')

            seqlogo(sequence, ax=ax[-1])
            ax[-1].yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
            ax[-1].set_ylabel("Inf. content")
            ax[-1].spines['top'].set_visible(False)
            ax[-1].spines['right'].set_visible(False)
            ax[-1].spines['bottom'].set_visible(False)
            ax[-1].set_xticks(list(range(0, len(sequence) + 1, 5)))

            if fpath_template is not None:
                pname = pattern.replace("/", ".")
                plt.savefig(fpath_template.format(pname) + '.png', dpi=600)
                plt.savefig(fpath_template.format(pname) + '.pdf', dpi=600)
                plt.close(fig)  # close the figure
                show_figure(fig)
                plt.show()
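A hypothetical call for the method variant above (here on an object named `seqlet_group`); argument values are placeholders and the per-pattern vectors must match the number of plotted patterns.

# Hypothetical call -- object and variable names are placeholders.
seqlet_group.plot_profiles(
    x=one_hot_sequences,
    tracks={"task1": profile_preds},
    importance_scores={"contrib": contrib_scores},
    rc_vec=[False, True],       # reverse-complement the second pattern
    start_vec=[10, 5],          # per-pattern window start
    width=20,
    n_bootstrap=100,
    fpath_template="plots/{}")  # this variant formats the template positionally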
Example #5
def get_seq_ic(self):
    """Get the sequence on the information-content scale."""
    return ic_scale(self.seq)
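For reference, a minimal sketch of what an information-content scaling such as ic_scale typically computes for an (L, 4) base-probability matrix; this is an illustration of the idea, not the implementation imported above.

import numpy as np

def ic_scale_sketch(pwm, pseudocount=1e-6):
    """Scale each position of an (L, 4) probability matrix by its
    information content: IC = log2(4) + sum_b p_b * log2(p_b)."""
    p = pwm + pseudocount
    p = p / p.sum(axis=-1, keepdims=True)           # re-normalise per position
    ic = np.log2(4) + (p * np.log2(p)).sum(axis=-1, keepdims=True)
    return p * ic                                   # letter heights in bits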