Example #1
def generate_sim(bpnet, central_motif, side_motif, side_distances,
                 center_coords=[450, 550], repeat=128, contribution=['count', 'profile'], correct=False):
    outl = []
    tasks = bpnet.tasks
    seqlen = bpnet.input_seqlen()
    ref_preds = unflatten(bpnet.sim_pred(central_motif,
                                         repeat=repeat,
                                         contribution=contribution), "/")
    none_preds = unflatten(bpnet.sim_pred('', '', [],
                                          repeat=repeat, contribution=contribution), "/")

    alt_profiles = []
    for dist in tqdm(side_distances):
        alt_preds = unflatten(bpnet.sim_pred(central_motif, side_motif, [dist],
                                             repeat=repeat, contribution=contribution), "/")
        if correct:
            # Correct for the 'shoulder' effect
            #
            # this performs: AB - (B - 0)
            # Where:
            # - AB: contains both, central and side_motif
            # - B : contains only side_motif
            # - 0 : doesn't contain any motif
            edge_only_preds = unflatten(bpnet.sim_pred('', side_motif, [dist],
                                                       repeat=repeat, contribution=contribution), "/")

            alt_preds_f = flatten(alt_preds, '/')
            none_preds_f = flatten(none_preds, "/")
            # subtract the side-motif-only counts and add back the empty baseline
            alt_preds = unflatten({k: alt_preds_f[k] - v + none_preds_f[k]
                                   for k, v in flatten(edge_only_preds, "/").items()}, "/")
        alt_profiles.append((dist, alt_preds))

        # get_scores then normalizes the score by `A`, finally yielding:
        # (AB - B + 0) / A
        scores = get_scores(ref_preds, alt_preds, tasks, central_motif, seqlen, center_coords)

        # compute the distance metrics
        for task in tasks:
            d = scores[task]

            # book-keeping
            d['task'] = task
            d['central_motif'] = central_motif
            d['side_motif'] = side_motif
            d['position'] = dist
            d['distance'] = dist - seqlen // 2

            outl.append(d)

    return pd.DataFrame(outl), alt_profiles
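A hedged usage sketch follows; the motif strings, distances, and repeat count are illustrative, and `bpnet` is assumed to be a loaded model wrapper exposing `sim_pred` as used above:

# Illustrative: place a side motif at growing distances from the central
# motif and collect per-task similarity metrics (all values hypothetical).
dfd, alt_profiles = generate_sim(bpnet,
                                 central_motif='TGATTA',
                                 side_motif='GATAAG',
                                 side_distances=range(515, 800, 5),
                                 repeat=64,
                                 correct=True)
# dfd: tidy DataFrame with one row per (task, position);
# alt_profiles: list of (distance, predictions) tuples for later plotting.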
Example #2
    def evaluate(self,
                 dataset,
                 eval_metric=None,
                 num_workers=8,
                 batch_size=256):
        lpreds = []
        llabels = []
        for inputs, targets in tqdm(dataset.batch_train_iter(
                cycle=False, num_workers=num_workers, batch_size=batch_size),
                                    total=len(dataset) // batch_size):
            assert isinstance(targets, dict)
            target_keys = list(targets)
            llabels.append(deepcopy(targets))
            bpreds = {
                k: v
                for k, v in self.predict(inputs, batch_size=None).items()
                if k in target_keys
            }  # keep only the target key predictions
            lpreds.append(bpreds)
            del inputs
            del targets
        preds = numpy_collate_concat(lpreds)
        labels = numpy_collate_concat(llabels)
        del lpreds
        del llabels

        if eval_metric is not None:
            return eval_metric(labels, preds)
        else:
            task_avg_tape = defaultdict(list)
            out = {}
            for task, heads in self.all_heads.items():
                for head_i, head in enumerate(heads):
                    target_name = head.get_target(task)
                    if target_name not in labels:
                        print(
                            f"Target {target_name} not found. Skipping evaluation"
                        )
                        continue
                    res = head.metric(labels[target_name], preds[target_name])
                    out[target_name] = res
                    metrics_dict = flatten(res, separator='/')
                    for k, v in metrics_dict.items():
                        task_avg_tape[
                            head.target_name.replace("{task}", "avg") + "/" +
                            k].append(v)
            for k, v in task_avg_tape.items():
                # get the average
                out[k] = mean(v)

        # flatten everything
        out = flatten(out, separator='/')
        return out
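For illustration, a minimal sketch of the two calling modes; `model` and `valid` are hypothetical stand-ins for an instance of this class and a dataset exposing `batch_train_iter`:

# Default mode: per-head metrics plus per-task "avg" entries, flattened with '/'.
metrics = model.evaluate(valid, num_workers=4, batch_size=128)

# Custom mode: eval_metric short-circuits the per-head loop (hypothetical fn).
def mean_error(labels, preds):
    # preds only contains the target keys, so iterate over preds
    return {k: float((preds[k] - labels[k]).mean()) for k in preds}

errors = model.evaluate(valid, eval_metric=mean_error)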
Example #3
def get_scores(ref_pred, alt_pred, tasks, motif, seqlen, center_coords):
    d = {}
    cstart, cend = center_coords
    for task in tasks:
        # profile - use the filtered tracks
        d[task] = flatten({"profile": profile_sim_metrics(ref_pred['profile'][task][cstart:cend],
                                                          alt_pred['profile'][task][cstart:cend])}, "/")

        # contribution scores - use the central motif region
        if 'contrib' in ref_pred:
            for contrib_score in ref_pred["contrib"][task]:
                contrib, contrib_frac = contrib_sim_metrics(ref_pred["contrib"][task][contrib_score],
                                                            alt_pred["contrib"][task][contrib_score], motif, seqlen)
                d[task] = {f"contrib/{contrib_score}": contrib, f"contrib/{contrib_score}_frac": contrib_frac, **d[task]}
    return d
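To make the expected inputs concrete, a hedged sketch of a direct call; the task names, motif, and coordinates are illustrative, with `ref_preds`/`alt_preds` being un-flattened `sim_pred` outputs as in `generate_sim` above:

scores = get_scores(ref_preds, alt_preds,
                    tasks=['Oct4', 'Nanog'],   # illustrative task names
                    motif='TGATTA',            # the central motif
                    seqlen=1000,               # model input sequence length
                    center_coords=[450, 550])  # central 100-bp window
# scores['Oct4'] -> flat dict of 'profile/...' similarity metrics, plus
# 'contrib/<score>' and 'contrib/<score>_frac' when contrib scores exist.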
Example #4
 def flatten(self):
     # returning super().__init__(...) would yield None; construct a new
     # instance of the same class with the flattened data instead
     return type(self)(flatten(self.data), attrs=deepcopy(self.attrs))
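A minimal round-trip sketch of the intended behavior, assuming the class is a dict-like container over nested data (the class name `NestedDict` and its contents are hypothetical):

nd = NestedDict({'profile': {'Oct4': 1, 'Nanog': 2}}, attrs={'source': 'demo'})
flat = nd.flatten()
# flat.data maps collapsed keys (e.g. 'profile/Oct4') to the leaf values,
# and flat.attrs is a deep copy rather than a shared reference.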
Example #5
 def get_lens(self):
     return list(flatten(self.dapply(len)).values())
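A hedged usage sketch, assuming `dapply` applies a function to every leaf array of the container so that `get_lens` yields one length per track (`tracks` is hypothetical):

lens = tracks.get_lens()    # e.g. [1000, 1000, 1000]
assert len(set(lens)) == 1  # typical sanity check: all tracks share a length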
Example #6
def modisco_plot(
        modisco_dir,
        output_dir,
        heatmap_width=200,
        figsize=(10, 10),
        contribsf=None):
    """Plot the results of a modisco run

    Args:
      modisco_dir: modisco directory
      output_dir: Output directory for writing the results
      figsize: Output figure size
      contribsf: [optional] modisco contribution score file (ContribFile)
    """
    plt.switch_backend('agg')
    add_file_logging(output_dir, logger, 'modisco-plot')
    from bpnet.plot.vdom import write_heatmap_pngs
    from bpnet.plot.profiles import plot_profiles
    from bpnet.utils import flatten

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)  # ensure the output directory itself exists

    # load modisco
    mf = ModiscoFile(f"{modisco_dir}/modisco.h5")

    if contribsf is not None:
        d = contribsf
    else:
        d = ContribFile.from_modisco_dir(modisco_dir)
        logger.info("Loading the contribution scores")
        d.cache()  # load all

    thr_one_hot = d.get_seq()
    tracks = d.get_profiles()
    thr_hypothetical_contribs = dict()
    thr_contrib_scores = dict()
    # TODO - generalize this
    thr_hypothetical_contribs['profile'] = d.get_hyp_contrib()
    thr_contrib_scores['profile'] = d.get_contrib()

    tasks = d.get_tasks()

    # Count contribution (if it exists)
    if d.contains_contrib_score("counts/pre-act"):
        count_contrib_score = "counts/pre-act"
    elif d.contains_contrib_score("count"):
        count_contrib_score = "count"
    else:
        count_contrib_score = None
    if count_contrib_score is not None:
        thr_hypothetical_contribs['count'] = d.get_hyp_contrib(
            contrib_score=count_contrib_score)
        thr_contrib_scores['count'] = d.get_contrib(
            contrib_score=count_contrib_score)

    thr_hypothetical_contribs = OrderedDict(
        flatten(thr_hypothetical_contribs, separator='/'))
    thr_contrib_scores = OrderedDict(flatten(thr_contrib_scores,
                                             separator='/'))
    # -------------------------------------------------

    all_seqlets = mf.seqlets()
    all_patterns = mf.pattern_names()
    if len(all_patterns) == 0:
        print("No patterns found")
        return

    # 1. Plots with tracks and contrib scores
    print("Writing results for contribution scores")
    plot_profiles(all_seqlets,
                  thr_one_hot,
                  tracks=tracks,
                  contribution_scores=thr_contrib_scores,
                  legend=False,
                  flip_neg=True,
                  rotate_y=0,
                  seq_height=.5,
                  patterns=all_patterns,
                  n_bootstrap=100,
                  fpath_template=str(output_dir /
                                     "{pattern}/agg_profile_contribcores"),
                  mkdir=True,
                  figsize=figsize)

    # 2. Plots only with hypothetical contrib scores
    print("Writing results for hypothetical contribution scores")
    plot_profiles(all_seqlets,
                  thr_one_hot,
                  tracks={},
                  contribution_scores=thr_hypothetical_contribs,
                  legend=False,
                  flip_neg=True,
                  rotate_y=0,
                  seq_height=1,
                  patterns=all_patterns,
                  n_bootstrap=100,
                  fpath_template=str(output_dir /
                                     "{pattern}/agg_profile_hypcontribscores"),
                  figsize=figsize)

    print("Plotting heatmaps")
    for pattern in tqdm(all_patterns):
        write_heatmap_pngs(all_seqlets[pattern],
                           d,
                           tasks,
                           pattern,
                           output_dir=str(output_dir / pattern),
                           resize_width=heatmap_width)

    mf.close()
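Finally, a hedged invocation sketch; the paths are illustrative, and the modisco directory is assumed to contain the `modisco.h5` produced by a modisco run:

# Writes per-pattern profile plots and heatmap PNGs under output_dir.
modisco_plot(modisco_dir='runs/Oct4/modisco',
             output_dir='runs/Oct4/modisco/plots',
             heatmap_width=200,
             figsize=(10, 10))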