Beispiel #1
0
        def fn(dist):
            """Plot the 'p' and 'contrib' tracks for the side motif at offset ``dist``.

            The looked-up position is ``dist + 500``; the center seqlet
            (``cstart``..``cend``) and the side seqlet returned by
            ``motif_coords`` are both passed to ``plot_tracks``.
            NOTE(review): 500 is presumably the center of the simulated
            sequence -- confirm against the enclosing function.
            """
            position = dist + 500
            side_start, side_end = motif_coords(side_motif, position)
            highlighted = [
                Seqlet(None, cstart, cend, "center", ""),
                Seqlet(None, side_start, side_end, "side", ""),
            ]
            # TODO - add also contribution scores
            scores = {"p": p[position], "contrib": contrib[position]}

            # TODO - order them correctly
            # Tracks are interleaved per task: "p/<task>" then "contrib/<task>".
            tracks = OrderedDict()
            for task in p[position]:
                for prefix in ('p', 'contrib'):
                    tracks[f"{prefix}/{task}"] = scores[prefix][task]

            # One y-range per track, in track order.
            ylims = [(0, ymax) if name.startswith("p") else (0, ymax_contrib)
                     for name in tracks]

            plot_tracks(tracks,
                        highlighted,
                        title=dist, ylim=ylims)
Beispiel #2
0
    def plot_regions(self, regions, ds=None, variants=None,
                     seqlets=None,
                     pred_summary='profile/wn',
                     contrib_method='grad',
                     batch_size=128,
                     xlim=None,
                     rotate_y=0,
                     add_title=True,
                     fig_height_per_track=2,
                     same_ylim=False,
                     fig_width=20):
        """Plot predictions (and optionally observations) for each region.

        For every entry returned by ``self.predict_regions`` one figure is
        drawn with, per task, a "Pred" track, an "Obs" track (only when
        ``ds`` is given) and a "Contrib profile" track.

        Args:
          regions: list of pybedtools.Interval
          ds: DataSpec. If provided, the ground truth will be added to the plot
          variants: a single instance or a list of bpnet.extractors.Variant
          seqlets: seqlets to highlight; each is shifted by ``-xlim[0]``
            before plotting. Defaults to no seqlets.
          pred_summary: suffix of the contribution score key, looked up as
            ``pred['contrib_score'][f"{task}/{pred_summary}"]``
          contrib_method: forwarded to ``self.predict_regions``
          batch_size: forwarded to ``self.predict_regions``
          xlim: if not None, passed to ``filter_tracks`` to restrict the
            plotted tracks; ``xlim[0]`` is also the seqlet offset
          rotate_y: forwarded to ``plot_tracks``
          add_title: if True, title each figure with the interval
            coordinates, its name and the variant (if any)
          fig_height_per_track: forwarded to ``plot_tracks``
          same_ylim: if True, all tasks share the same y-range per track
            type, computed from the unfiltered tracks
          fig_width: forwarded to ``plot_tracks``

        Returns:
          list with the return value of ``plot_tracks`` for each prediction
        """
        if seqlets is None:
            seqlets = []

        out = self.predict_regions(regions,
                                   variants=variants,
                                   contrib_method=contrib_method,
                                   batch_size=batch_size)

        xmin = 0 if xlim is None else xlim[0]
        shifted_seqlets = [s.shift(-xmin) for s in seqlets]

        figs = []
        for pred in out:
            interval = pred['interval']

            if ds is not None:
                obs = {task: ds.task_specs[task].load_counts([interval])[0]
                       for task in self.tasks}
            else:
                obs = None

            # handle the DNase case
            seq = pred['seq']
            if isinstance(seq, dict):
                seq = seq['seq']

            # Track types drawn for every task, in plotting order.
            features = ['Pred', 'Contrib profile'] if obs is None else ['Pred', 'Obs', 'Contrib profile']

            viz_dict = OrderedDict()
            # Remember each track's type explicitly instead of parsing it back
            # out of the track name (task names may contain spaces).
            feature_of = {}
            for task in self.tasks:
                values = {
                    'Pred': pred['pred'][task],
                    'Contrib profile': pred['contrib_score'][f"{task}/{pred_summary}"] * seq,
                }
                if obs is not None:
                    values['Obs'] = obs[task]
                for feature in features:
                    name = f"{task} {feature}"
                    viz_dict[name] = values[feature]
                    feature_of[name] = feature

            if add_title:
                # Must be a plain string (no trailing comma, which would make it a tuple).
                title = "{i.chrom}:{i.start}-{i.end}, {i.name} {v}".format(i=interval, v=pred.get('variant', ''))
            else:
                title = None

            if same_ylim:
                # Only consider track types that are actually present; 'Obs'
                # tracks do not exist when no DataSpec was provided.
                fmax = {feature: max(np.abs(viz_dict[f"{task} {feature}"]).max()
                                     for task in self.tasks)
                        for feature in features}

                ylim = []
                for name in viz_dict:
                    feature = feature_of[name]
                    if "Contrib" in feature:
                        # contribution scores are signed -> symmetric range
                        ylim.append((-fmax[feature], fmax[feature]))
                    else:
                        ylim.append((0, fmax[feature]))
            else:
                ylim = None

            fig = plot_tracks(filter_tracks(viz_dict, xlim),
                              seqlets=shifted_seqlets,
                              title=title,
                              fig_height_per_track=fig_height_per_track,
                              rotate_y=rotate_y,
                              fig_width=fig_width,
                              ylim=ylim,
                              legend=True)
            figs.append(fig)
        return figs
Beispiel #3
0
from bpnet.plot.tracks import plot_tracks, to_neg
import seaborn as sns
import matplotlib.pyplot as plt

from matplotlib.backends.backend_pdf import PdfPages

cf = ContribFile(contrib_file)

profiles = cf.get_profiles()
contrib_scores = cf.get_contrib()

# For every task, pick the example index that maximizes the profile summary
# (max over axis -2, then mean over axis -1). Duplicates across tasks are
# dropped and the indices sorted so the page order is reproducible.
# NOTE(review): assumes each profile array is (example, position, strand) -- confirm.
examples = sorted({v.max(axis=-2).mean(axis=-1).argmax() for v in profiles.values()})

tasks = ['Oct4', 'Sox2', 'Nanog']

# Window of positions shown for every example.
xrange = slice(50, 150)

# plot_tracks returns the figure it draws, so save those figures (one page
# per example) rather than a separately created, empty plt.figure().
with PdfPages("/groups/lackgrp/ll_members/berkay/SYNTAX/exampleRun/results/train/bpnet/examples/chip-nexus/output/contributions.pdf") as pdf:
    for idx in examples:
        fig = plot_tracks({**{'profile/' + k: to_neg(v[idx, xrange]) for k, v in profiles.items()},
                           **{'contrib/' + k: v[idx, xrange] for k, v in contrib_scores.items()}},
                          title=idx,
                          rotate_y=0,
                          fig_width=10,
                          fig_height_per_track=1)
        sns.despine(top=True, right=True, bottom=True)
        pdf.savefig(fig)
Beispiel #4
0
## Visualize the locus with motif instances highlighted

# Window of positions to display; seqlet coordinates are shifted by its start
# so they line up with the sliced tracks.
xrange = slice(300, 700)

## get the contribution scores and profile score for that example idx
cf = ContribFile(contrib_file)
profiles = cf.get_profiles(idx=idx)
contrib_scores = cf.get_contrib(idx=idx)

## Let's focus only on the best match per track
dfi_best = (dfi[dfi.example_idx == idx]
            .sort_values("match_weighted_p", ascending=False)
            .groupby('tf')
            .first())

## convert dfi to Seqlet objects (only once, for the window actually plotted)
seqlets = [s.shift(-xrange.start)
           for s in dfi2seqlets(dfi_best, short_name=True)]

fig = plot_tracks({**{'profile/' + k: to_neg(v[xrange]) for k, v in profiles.items()},
                   **{'contrib/' + k: v[xrange] for k, v in contrib_scores.items()}},
                  title=idx,
                  rotate_y=0,
                  fig_width=20,
                  # plot seqlets to the matching 'contrib/<name>' track, e.g. 'contrib/Nanog'
                  seqlets=[s.set_seqname('contrib/' + s.name.split("/")[0]) for s in seqlets],
                  fig_height_per_track=1)
sns.despine(top=True, right=True, bottom=True)
fig.savefig(output_file, bbox_inches="tight")