Example #1
def test_contrib_bias_model(tmp_path, method, trained_model_w_bias):
    """Test whether we can compute differnet contribution scores
    """
    K.clear_session()
    fpath = tmp_path / 'imp-score.h5'
    bpnet_contrib(str(trained_model_w_bias), str(fpath), method=method)

    cf = ContribFile(fpath)
    assert cf.get_contrib()['Task1'].shape[-1] == 4
Example #2
def load_ranges(modisco_dir):
    modisco_dir = Path(modisco_dir)
    included_samples = load_included_samples(modisco_dir)

    kwargs = read_json(modisco_dir / "modisco-run.kwargs.json")
    d = ContribFile(kwargs["contrib_file"], included_samples)
    df = d.get_ranges()
    d.close()
    return df
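
A hedged usage sketch (the directory path is a placeholder):

# one row per region used in the modisco run (chrom/start/end plus metadata)
df = load_ranges("output/modisco/Oct4")
print(df.head())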
Example #3
def cwm_scan_seqlets(modisco_dir,
                     output_file,
                     trim_frac=0.08,
                     num_workers=1,
                     contribsf=None,
                     verbose=False):
    """Compute the cwm scanning scores of the original modisco seqlets
    """
    from bpnet.modisco.table import ModiscoData
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    add_file_logging(os.path.dirname(output_file), logger, 'cwm_scan_seqlets')

    # open the modisco results; make sure modisco_dir is a Path for the `/` join below
    modisco_dir = Path(modisco_dir)
    mf = ModiscoFile(modisco_dir / "modisco.h5")

    if contribsf is None:
        contrib = ContribFile.from_modisco_dir(modisco_dir)
    else:
        contrib = contribsf

    tasks = mf.tasks()
    # HACK: strip the contribution-score suffix from the task names (in case it is present)
    tasks = [t.split("/")[0] for t in tasks]

    dfi_list = []

    for pattern_name in tqdm(mf.pattern_names()):
        pattern = mf.get_pattern(pattern_name).trim_seq_ic(trim_frac)
        seqlets = mf._get_seqlets(pattern_name, trim_frac=trim_frac)

        # scan only the existing locations of the seqlets instead of the full sequences
        # to obtain the distribution
        stacked_seqlets = contrib.extract(seqlets)

        match, contribution = pattern.scan_contribution(
            stacked_seqlets.contrib,
            hyp_contrib=None,
            tasks=tasks,
            n_jobs=num_workers,
            verbose=False,
            pad_mode=None)
        seq_match = pattern.scan_seq(stacked_seqlets.seq,
                                     n_jobs=num_workers,
                                     verbose=False,
                                     pad_mode=None)

        dfm = pattern.get_instances(tasks,
                                    match,
                                    contribution,
                                    seq_match,
                                    fdr=1,
                                    verbose=verbose,
                                    plot=verbose)
        dfm = dfm[dfm.seq_match > 0]

        dfi_list.append(dfm)
    df = pd.concat(dfi_list)
    df.to_csv(output_file)
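
A hedged invocation sketch (paths are placeholders):

cwm_scan_seqlets("output/modisco/Oct4",
                 output_file="output/modisco/Oct4/cwm-scan-seqlets.csv.gz",
                 trim_frac=0.08,
                 num_workers=4)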
Example #4
def get_signal(seqlets, d: ContribFile, tasks, resize_width=200):
    thr_one_hot = d.get_seq()

    if resize_width is None:
        # default to the width of the first seqlet
        resize_width = seqlets[0].end - seqlets[0].start

    # get valid seqlets
    start_pad = np.ceil(resize_width / 2)
    end_pad = thr_one_hot.shape[1] - start_pad
    valid_seqlets = [
        s.resize(resize_width) for s in seqlets
        if (s.center() > start_pad) and (s.center() < end_pad)
    ]

    # prepare data
    ex_signal = {
        task: extract_signal(d.get_profiles()[task], valid_seqlets)
        for task in tasks
    }

    ex_contrib_profile = {
        task: extract_signal(d.get_contrib()[task], valid_seqlets).sum(axis=-1)
        for task in tasks
    }

    if d.contains_contrib_score('count'):
        ex_contrib_counts = {
            task: extract_signal(d.get_contrib("count")[task],
                                 valid_seqlets).sum(axis=-1)
            for task in tasks
        }
    elif d.contains_contrib_score('counts/pre-act'):
        ex_contrib_counts = {
            task: extract_signal(
                d.get_contrib("counts/pre-act")[task],
                valid_seqlets).sum(axis=-1)
            for task in tasks
        }
    else:
        ex_contrib_counts = None

    ex_seq = extract_signal(thr_one_hot, valid_seqlets)

    total_counts = sum(
        [x.sum(axis=-1).sum(axis=-1) for x in ex_signal.values()])
    sort_idx = np.argsort(-total_counts)
    return ex_signal, ex_contrib_profile, ex_contrib_counts, ex_seq, sort_idx
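
A minimal sketch of one way to call get_signal; the pattern name, paths, and the open ModiscoFile `mf` are illustrative assumptions:

cf = ContribFile("output/contrib.h5")                  # placeholder path
seqlets = mf._get_seqlets("metacluster_0/pattern_0")   # assumes an open ModiscoFile `mf`
ex_signal, ex_contrib_profile, ex_contrib_counts, ex_seq, sort_idx = get_signal(
    seqlets, cf, tasks=cf.get_tasks())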
Example #5
def contrib2bw(contrib_file, output_prefix):
    """Convert the contribution file to bigwigs
    """
    from kipoi.writers import BigWigWriter
    from bpnet.cli.contrib import ContribFile
    from bpnet.cli.modisco import get_nonredundant_example_idx
    output_dir = os.path.dirname(output_prefix)
    add_file_logging(output_dir, logger, 'contrib2bw')
    os.makedirs(output_dir, exist_ok=True)

    cf = ContribFile(contrib_file)

    # remove overlapping intervals
    ranges = cf.get_ranges()
    keep_idx = get_nonredundant_example_idx(ranges, width=None)
    cf.include_samples = keep_idx
    discarded = len(ranges) - len(keep_idx)
    logger.info(
        f"{discarded}/{len(ranges)} of ranges will be discarded due to overlapping intervals"
    )

    contrib_scores = cf.available_contrib_scores()  # TODO - implement contrib_wildcard to filter them
    chrom_sizes = [(k, v) for k, v in cf.get_chrom_sizes().items()]
    ranges = cf.get_ranges()  # re-read the ranges, now restricted to keep_idx

    assert len(ranges) == len(keep_idx)

    delim = "." if not output_prefix.endswith("/") else ""

    for contrib_score in contrib_scores:
        contrib_dict = cf.get_contrib(contrib_score=contrib_score)
        contrib_score_name = contrib_score.replace("/", "_")

        for task, contrib in contrib_dict.items():
            output_file = output_prefix + f'{delim}contrib.{contrib_score_name}.{task}.bw'
            logger.info(f"Generating {output_file}")
            contrib_writer = BigWigWriter(output_file,
                                          chrom_sizes=chrom_sizes,
                                          is_sorted=False)

            for idx in range(len(ranges)):
                contrib_writer.region_write(
                    region={
                        "chr": ranges['chrom'].iloc[idx],
                        "start": ranges['start'].iloc[idx],
                        "end": ranges['end'].iloc[idx]
                    },
                    data=contrib[idx])
            contrib_writer.close()
    logger.info("Done!")
Example #6
def modisco_export_patterns(modisco_dir, output_file, contribsf=None):
    """Export patterns to a pkl file. Don't cluster them

    Adds `stacked_seqlet_contrib` and `n_seqlets` to pattern `attrs`

    Args:
      modisco_dir: modisco directory containing modisco.h5
      output_file: output file path for patterns.pkl
    """
    from bpnet.cli.contrib import ContribFile

    logger.info("Loading patterns")
    modisco_dir = Path(modisco_dir)

    mf = ModiscoFile(modisco_dir / 'modisco.h5')
    patterns = [mf.get_pattern(pname) for pname in mf.pattern_names()]

    if contribsf is None:
        contrib_file = ContribFile.from_modisco_dir(modisco_dir)
        logger.info("Loading ContribFile into memory")
        contrib_file.cache()
    else:
        logger.info("Using the provided ContribFile")
        contrib_file = contribsf

    logger.info("Extracting profile and contribution scores")
    extended_patterns = []
    for p in tqdm(patterns):
        p = p.copy()

        # get seqlets
        valid_seqlets = mf._get_seqlets(p.name)

        # extract the contribution scores
        sti = contrib_file.extract(valid_seqlets, profile_width=None)
        sti.dfi = mf.get_seqlet_intervals(p.name, as_df=True)
        p.attrs['stacked_seqlet_contrib'] = sti
        p.attrs['n_seqlets'] = mf.n_seqlets(p.name)
        extended_patterns.append(p)

    write_pkl(extended_patterns, output_file)
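
A hedged usage sketch (paths are placeholders):

# serializes the extended Pattern objects (seqlet contributions attached
# under `attrs`) to a pickle file
modisco_export_patterns("output/modisco/Oct4",
                        output_file="output/modisco/Oct4/patterns.pkl")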
Example #7
run_id = "2020-10-03_07-55-17_1e61b98d-8bb0-4220-b501-3def6877fd00"

# contribution scores
os.system(f"bpnet contrib {model_dir} --method=deeplift --memfrac-gpu=1 --contrib-wildcard='*/profile/wn' {contrib_file}")




import matplotlib.pyplot as plt
import seaborn as sns

from bpnet.cli.contrib import ContribFile
from bpnet.plot.tracks import plot_tracks, to_neg

cf = ContribFile(contrib_file)

profiles = cf.get_profiles()
contrib_scores = cf.get_contrib()

examples = list({v.max(axis=-2).mean(axis=-1).argmax() for k,v in profiles.items()})
examples

tasks = ['Oct4', 'Sox2', 'Nanog']



fig = plt.figure(figsize=[8,8])

xrange = slice(50, 150)
for idx in examples:
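    # (sketch) the original snippet is truncated here; presumably each selected
    # example is plotted with the observed profiles (mirrored via `to_neg`) and
    # the contribution scores over the `xrange` window; the track names are
    # illustrative
    plot_tracks({**{f"profile/{t}": to_neg(profiles[t][idx][xrange]) for t in tasks},
                 **{f"contrib/{t}": contrib_scores[t][idx][xrange] for t in tasks}})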
Example #8
def chip_nexus_analysis(modisco_dir,
                        trim_frac=0.08,
                        num_workers=20,
                        run_cwm_scan=False,
                        force=False,
                        footprint_width=200):
    """Compute all the results for modisco specific for ChIP-nexus/exo data. Runs:
    - modisco_plot
    - modisco_report
    - modisco_table
    - modisco_export_patterns
    - cwm_scan
    - modisco_export_seqlets

    Note:
      All the sub-commands are only executed if they have not been run before. Use --force to override this.
      Whether the commands have been run before is determined by checking if the following file exists:
        `{modisco_dir}/.modisco_report_all/{command}.done`.
    """
    plt.switch_backend('agg')
    from bpnet.utils import ConditionalRun

    modisco_dir = Path(modisco_dir)
    # figure out the contribution scores used
    kwargs = read_json(modisco_dir / "modisco-run.kwargs.json")
    contrib_scores = kwargs["contrib_file"]

    mf = ModiscoFile(f"{modisco_dir}/modisco.h5")
    all_patterns = mf.pattern_names()
    mf.close()
    if len(all_patterns) == 0:
        print("No patterns found.")
        # Touch modisco-chip.html for snakemake
        open(modisco_dir / 'modisco-chip.html', 'a').close()
        open(modisco_dir / 'seqlets/scored_regions.bed', 'a').close()
        return

    # class determining whether to run the command or not (poor-man's snakemake)
    cr = ConditionalRun("modisco_report_all", None, modisco_dir, force=force)

    sync = []
    # --------------------------------------------
    if (not cr.set_cmd('modisco_plot').done()
            or not cr.set_cmd('modisco_export_patterns').done()):
        # load ContribFile and pass it to all the functions
        logger.info("Loading ContribFile")
        contribsf = ContribFile.from_modisco_dir(modisco_dir)
        contribsf.cache()
    else:
        contribsf = None
    # --------------------------------------------
    # Basic reports
    if not cr.set_cmd('modisco_plot').done():
        modisco_plot(modisco_dir,
                     modisco_dir / 'plots',
                     heatmap_width=footprint_width,
                     figsize=(10, 10),
                     contribsf=contribsf)
        cr.write()
    sync.append("plots")

    if not cr.set_cmd('modisco_report').done():
        modisco_report(str(modisco_dir), str(modisco_dir))
        cr.write()
    sync.append("modisco-chip.html")

    if not cr.set_cmd('modisco_table').done():
        modisco_table(modisco_dir,
                      contrib_scores,
                      modisco_dir,
                      report_url=None,
                      contribsf=contribsf,
                      footprint_width=footprint_width)
        cr.write()
    sync.append("footprints.pkl")
    sync.append("pattern_table.*")

    if not cr.set_cmd('modisco_export_patterns').done():
        modisco_export_patterns(modisco_dir,
                                output_file=modisco_dir / 'patterns.pkl',
                                contribsf=contribsf)
        cr.write()
    sync.append("patterns.pkl")

    # --------------------------------------------
    # Finding new instances
    if run_cwm_scan:
        if not cr.set_cmd('cwm_scan').done():
            cwm_scan(modisco_dir,
                     modisco_dir / 'instances.bed.gz',
                     trim_frac=trim_frac,
                     contrib_file=None,
                     num_workers=num_workers)
            cr.write()

    # --------------------------------------------
    # Export bed-files and bigwigs

    # Seqlets
    if not cr.set_cmd('modisco_export_seqlets').done():
        modisco_export_seqlets(str(modisco_dir),
                               str(modisco_dir / 'seqlets'),
                               trim_frac=trim_frac)
        cr.write()
    sync.append("seqlets")

    # print the rsync command to run in order to sync the output
    # directories to the webserver
    logger.info("Run the following command to sync files to the webserver")
    dirs = " ".join(sync)
    print(f"rsync -av --progress {dirs} <output_dir>/")
Example #9
def cwm_scan(modisco_dir,
             output_file,
             trim_frac=0.08,
             patterns='all',
             filters='match_weighted_p>=.2,contrib_weighted_p>=.01',
             contrib_file=None,
             add_profile_features=False,
             num_workers=10):
    """Get motif instances via CWM scanning.
    """
    from bpnet.modisco.utils import longer_pattern, shorten_pattern
    from bpnet.modisco.pattern_instances import annotate_profile_single
    add_file_logging(os.path.dirname(output_file), logger, 'cwm-scan')
    modisco_dir = Path(modisco_dir)

    valid_suffixes = [
        '.csv',
        '.csv.gz',
        '.tsv',
        '.tsv.gz',
        '.parq',
        '.bed',
        '.bed.gz',
    ]
    if not any([output_file.endswith(suffix) for suffix in valid_suffixes]):
        raise ValueError(
            f"output_file doesn't have a valid file suffix. Valid file suffixes are: {valid_suffixes}"
        )

    # Centroid matches path
    cm_path = modisco_dir / f'cwm-scan-seqlets.trim-frac={trim_frac:.2f}.csv.gz'

    # save the hyper-parameters
    kwargs_json_file = os.path.join(os.path.dirname(output_file),
                                    'cwm-scan.kwargs.json')
    write_json(
        dict(modisco_dir=os.path.abspath(str(modisco_dir)),
             output_file=str(output_file),
             cwm_scan_seqlets_path=str(cm_path),
             trim_frac=trim_frac,
             patterns=patterns,
             filters=filters,
             contrib_file=contrib_file,
             add_profile_features=add_profile_features,
             num_workers=num_workers), str(kwargs_json_file))

    # figure out contrib_wildcard
    modisco_kwargs = read_json(
        os.path.join(modisco_dir, "modisco-run.kwargs.json"))
    contrib_type = load_contrib_type(modisco_kwargs)

    mf = ModiscoFile(modisco_dir / "modisco.h5")
    tasks = mf.tasks()
    # HACK prune the tasks of contribution (in case it's present)
    tasks = [t.split("/")[0] for t in tasks]

    logger.info(f"Using tasks: {tasks}")

    if contrib_file is None:
        cf = ContribFile.from_modisco_dir(modisco_dir)
        cf.cache()  # cache it since it is re-used by `cwm_scan_seqlets` below
    else:
        logger.info(f"Loading the contribution scores from: {contrib_file}")
        cf = ContribFile(contrib_file, default_contrib_score=contrib_type)

    if not cm_path.exists():
        logger.info(f"Generating centroid matches to {cm_path.resolve()}")
        cwm_scan_seqlets(modisco_dir,
                         output_file=cm_path,
                         trim_frac=trim_frac,
                         contribsf=cf if contrib_file is None else None,
                         num_workers=num_workers,
                         verbose=False)
    else:
        logger.info("Centroid matches already exist.")
    logger.info(f"Loading centroid matches from {cm_path.resolve()}")
    dfm_norm = pd.read_csv(cm_path)

    # get the raw data
    seq, contrib, ranges = cf.get_seq(), cf.get_contrib(), cf.get_ranges()

    logger.info("Scanning for patterns")
    dfl = []

    # patterns to scan. `longer_pattern` makes sure the patterns are in the long format
    scan_patterns = patterns.split(",") if patterns != 'all' else mf.pattern_names()
    scan_patterns = [longer_pattern(pn) for pn in scan_patterns]

    if add_profile_features:
        profile = cf.get_profiles()
        logger.info("Profile features will also be added to dfi")

    for pattern_name in tqdm(mf.pattern_names()):
        if pattern_name not in scan_patterns:
            # skip patterns that were not requested
            continue
        pattern = mf.get_pattern(pattern_name).trim_seq_ic(trim_frac)
        match, contribution = pattern.scan_contribution(contrib,
                                                        hyp_contrib=None,
                                                        tasks=tasks,
                                                        n_jobs=num_workers,
                                                        verbose=False)
        seq_match = pattern.scan_seq(seq, n_jobs=num_workers, verbose=False)
        dfm = pattern.get_instances(
            tasks,
            match,
            contribution,
            seq_match,
            norm_df=dfm_norm[dfm_norm.pattern == pattern_name],
            verbose=False,
            plot=False)
        for filt in filters.split(","):
            if len(filt) > 0:
                dfm = dfm.query(filt)

        if add_profile_features:
            dfm = annotate_profile_single(dfm,
                                          pattern_name,
                                          mf,
                                          profile,
                                          profile_width=70,
                                          trim_frac=trim_frac)
        dfm['pattern_short'] = shorten_pattern(pattern_name)

        # TODO - is it possible to write out the results incrementally?
        dfl.append(dfm)

    logger.info("Merging")
    # merge and write the results
    dfp = pd.concat(dfl)

    # append the ranges
    logger.info("Append ranges")
    ranges.columns = ["example_" + v for v in ranges.columns]
    dfp = dfp.merge(ranges, on="example_idx", how='left')

    # add the absolute coordinates
    dfp['pattern_start_abs'] = dfp['example_start'] + dfp['pattern_start']
    dfp['pattern_end_abs'] = dfp['example_start'] + dfp['pattern_end']

    logger.info("Table info")
    dfp.info()
    logger.info(
        f"Writing the resuling pd.DataFrame of shape {dfp.shape} to {output_file}"
    )

    # order the first columns to comply with the BED format (chrom, start, end, name, score, strand, ...)
    bed_columns = [
        'example_chrom', 'pattern_start_abs', 'pattern_end_abs', 'pattern',
        'contrib_weighted_p', 'strand', 'match_weighted_p'
    ]
    dfp = pd_first_cols(dfp, bed_columns)

    # write to a parquet file
    if output_file.endswith(".parq"):
        logger.info("Writing a parquet file")
        dfp.to_parquet(output_file,
                       partition_on=['pattern_short'],
                       engine='fastparquet')
    elif output_file.endswith(".csv.gz") or output_file.endswith(".csv"):
        logger.info("Writing a csv file")
        dfp.to_csv(output_file, compression='infer', index=False)
    elif output_file.endswith(".tsv.gz") or output_file.endswith(".tsv"):
        logger.info("Writing a tsv file")
        dfp.to_csv(output_file, sep='\t', compression='infer', index=False)
    elif output_file.endswith(".bed.gz") or output_file.endswith(".bed"):
        logger.info("Writing a BED file")
        # write only the first (and main) 7 columns
        dfp[bed_columns].to_csv(output_file,
                                sep='\t',
                                compression='infer',
                                index=False,
                                header=False)
    else:
        logger.warn("File suffix not recognized. Using .csv.gz file format")
        dfp.to_csv(output_file, compression='gzip', index=False)
    logger.info("Done!")
Example #10
def modisco_plot(
        modisco_dir,
        output_dir,
        # filter_npy=None,
        # ignore_dist_filter=False,
        heatmap_width=200,
        figsize=(10, 10),
        contribsf=None):
    """Plot the results of a modisco run

    Args:
      modisco_dir: modisco directory
      output_dir: Output directory for writing the results
      figsize: Output figure size
      contribsf: [optional] modisco contribution score file (ContribFile)
    """
    plt.switch_backend('agg')
    add_file_logging(output_dir, logger, 'modisco-plot')
    from bpnet.plot.vdom import write_heatmap_pngs
    from bpnet.plot.profiles import plot_profiles
    from bpnet.utils import flatten

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # load modisco
    mf = ModiscoFile(f"{modisco_dir}/modisco.h5")

    if contribsf is not None:
        d = contribsf
    else:
        d = ContribFile.from_modisco_dir(modisco_dir)
        logger.info("Loading the contribution scores")
        d.cache()  # load all

    thr_one_hot = d.get_seq()
    tracks = d.get_profiles()
    thr_hypothetical_contribs = dict()
    thr_contrib_scores = dict()
    # TODO - generalize this
    thr_hypothetical_contribs['profile'] = d.get_hyp_contrib()
    thr_contrib_scores['profile'] = d.get_contrib()

    tasks = d.get_tasks()

    # Count contribution (if it exists)
    if d.contains_contrib_score("counts/pre-act"):
        count_contrib_score = "counts/pre-act"
        thr_hypothetical_contribs['count'] = d.get_hyp_contrib(
            contrib_score=count_contrib_score)
        thr_contrib_scores['count'] = d.get_contrib(
            contrib_score=count_contrib_score)
    elif d.contains_contrib_score("count"):
        count_contrib_score = "count"
        thr_hypothetical_contribs['count'] = d.get_hyp_contrib(
            contrib_score=count_contrib_score)
        thr_contrib_scores['count'] = d.get_contrib(
            contrib_score=count_contrib_score)
    else:
        # Don't do anything
        pass

    thr_hypothetical_contribs = OrderedDict(
        flatten(thr_hypothetical_contribs, separator='/'))
    thr_contrib_scores = OrderedDict(flatten(thr_contrib_scores,
                                             separator='/'))
    # -------------------------------------------------

    all_seqlets = mf.seqlets()
    all_patterns = mf.pattern_names()
    if len(all_patterns) == 0:
        print("No patterns found")
        return

    # 1. Plots with tracks and contrib scores
    print("Writing results for contribution scores")
    plot_profiles(all_seqlets,
                  thr_one_hot,
                  tracks=tracks,
                  contribution_scores=thr_contrib_scores,
                  legend=False,
                  flip_neg=True,
                  rotate_y=0,
                  seq_height=.5,
                  patterns=all_patterns,
                  n_bootstrap=100,
                  fpath_template=str(output_dir /
                                     "{pattern}/agg_profile_contribcores"),
                  mkdir=True,
                  figsize=figsize)

    # 2. Plots only with hypothetical contrib scores
    print("Writing results for hypothetical contribution scores")
    plot_profiles(all_seqlets,
                  thr_one_hot,
                  tracks={},
                  contribution_scores=thr_hypothetical_contribs,
                  legend=False,
                  flip_neg=True,
                  rotate_y=0,
                  seq_height=1,
                  patterns=all_patterns,
                  n_bootstrap=100,
                  fpath_template=str(output_dir /
                                     "{pattern}/agg_profile_hypcontribscores"),
                  figsize=figsize)

    print("Plotting heatmaps")
    for pattern in tqdm(all_patterns):
        write_heatmap_pngs(all_seqlets[pattern],
                           d,
                           tasks,
                           pattern,
                           output_dir=str(output_dir / pattern),
                           resize_width=heatmap_width)

    mf.close()
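
A hedged usage sketch (paths are placeholders):

# writes per-pattern aggregate profile/contribution plots and heatmap PNGs
# under output_dir/{pattern}/
modisco_plot("output/modisco/Oct4", "output/modisco/Oct4/plots",
             heatmap_width=200, figsize=(10, 10))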
Example #11
def bpnet_modisco_run(
    contrib_file,
    output_dir,
    null_contrib_file=None,
    premade='modisco-50k',
    config=None,
    override='',
    contrib_wildcard="*/profile/wn",  # on which contribution scores to run modisco
    only_task_regions=False,
    filter_npy=None,
    exclude_chr="",
    num_workers=10,
    gpu=None,  # no need to use a gpu by default
    memfrac_gpu=0.45,
    overwrite=False,
):
    """Run TF-MoDISco on the contribution scores stored in the contribution score file
    generated by `bpnet contrib`.
    """
    add_file_logging(output_dir, logger, 'modisco-run')
    if gpu is not None:
        logger.info(f"Using gpu: {gpu}, memory fraction: {memfrac_gpu}")
        create_tf_session(gpu, per_process_gpu_memory_fraction=memfrac_gpu)
    else:
        # Don't use any GPUs
        os.environ['CUDA_VISIBLE_DEVICES'] = ''
        os.environ['MKL_THREADING_LAYER'] = 'GNU'

    import modisco
    assert '/' in contrib_wildcard

    if filter_npy is not None:
        filter_npy = os.path.abspath(str(filter_npy))
    if config is not None:
        config = os.path.abspath(str(config))

    # setup output file paths
    output_path = os.path.abspath(os.path.join(output_dir, "modisco.h5"))
    remove_exists(output_path, overwrite=overwrite)
    output_filter_npy = os.path.abspath(
        os.path.join(output_dir, 'modisco-run.subset-contrib-file.npy'))
    remove_exists(output_filter_npy, overwrite=overwrite)
    kwargs_json_file = os.path.join(output_dir, "modisco-run.kwargs.json")
    remove_exists(kwargs_json_file, overwrite=overwrite)
    if config is not None:
        config_output_file = os.path.join(output_dir,
                                          'modisco-run.input-config.gin')
        remove_exists(config_output_file, overwrite=overwrite)
        shutil.copyfile(config, config_output_file)

    # save the hyper-parameters
    write_json(
        dict(contrib_file=os.path.abspath(contrib_file),
             output_dir=str(output_dir),
             null_contrib_file=null_contrib_file,
             config=str(config),
             override=override,
             contrib_wildcard=contrib_wildcard,
             only_task_regions=only_task_regions,
             filter_npy=str(filter_npy),
             exclude_chr=exclude_chr,
             num_workers=num_workers,
             overwrite=overwrite,
             output_filter_npy=output_filter_npy,
             gpu=gpu,
             memfrac_gpu=memfrac_gpu), kwargs_json_file)

    # setup the gin config using premade, config and override
    cli_bindings = [f'num_workers={num_workers}']
    gin.parse_config_files_and_bindings(
        _get_gin_files(premade, config),
        bindings=cli_bindings + override.split(";"),
        # NOTE: custom files were inserted right after
        # the user's config file and before the `override`
        # parameters specified at the command-line
        skip_unknown=False)
    log_gin_config(output_dir, prefix='modisco-run.')
    # --------------------------------------------

    # load the contribution file
    logger.info(f"Loading the contribution file: {contrib_file}")
    cf = ContribFile(contrib_file)
    tasks = cf.get_tasks()

    # figure out subset_tasks
    subset_tasks = set()
    for w in contrib_wildcard.split(","):
        task, head, head_summary = w.split("/")
        if task == '*':
            subset_tasks = None
        else:
            if task not in tasks:
                raise ValueError(f"task {task} not found in tasks: {tasks}")
            subset_tasks.add(task)
    if subset_tasks is not None:
        subset_tasks = list(subset_tasks)

    # --------------------------------------------
    # subset the intervals
    logger.info(f"Loading ranges")
    ranges = cf.get_ranges()
    # include all samples at the beginning
    include_samples = np.ones(len(cf)).astype(bool)

    # --only-task-regions
    if only_task_regions:
        if subset_tasks is None:
            logger.warning(
                "contrib_wildcard contains all tasks (specified by */<head>/<summary>). Not using --only-task-regions"
            )
        elif np.all(ranges['interval_from_task'] == ''):
            raise ValueError(
                "The contribution file wasn't created from multiple sets of peaks "
                "(interval_from_task == '' for all ranges). Please disable --only-task-regions"
            )
        else:
            logger.info(f"Subsetting ranges according to `interval_from_task`")
            include_samples = include_samples & ranges[
                'interval_from_task'].isin(subset_tasks).values
            logger.info(
                f"Using {include_samples.sum()} / {len(include_samples)} regions after --only-task-regions subset"
            )

    # --exclude-chr
    if exclude_chr:
        logger.info(f"Excluding chromosomes: {exclude_chr}")
        chromosomes = ranges['chr']
        include_samples = include_samples & (
            ~pd.Series(chromosomes).isin(exclude_chr.split(","))).values
        logger.info(
            f"Using {include_samples.sum()} / {len(include_samples)} regions after --exclude-chr subset"
        )

    # -- filter-npy
    if filter_npy is not None:
        print(f"Loading a filter file from {filter_npy}")
        include_samples = include_samples & np.load(filter_npy)
        logger.info(
            f"Using {include_samples.sum()} / {len(include_samples)} regions after --filter-npy subset"
        )

    # store the subset-contrib-file.npy
    logger.info(
        f"Saving the included samples from ContribFile to {output_filter_npy}")
    np.save(output_filter_npy, include_samples)
    # --------------------------------------------

    # convert to indices
    idx = np.arange(len(include_samples))[include_samples]
    seqs = cf.get_seq(idx=idx)

    # fetch the contribution scores from the importance score file
    # expand * to use all possible values
    # TODO - allow this to be done also for all the heads?
    hyp_contrib = {}
    task_names = []
    for w in contrib_wildcard.split(","):
        wc_task, head, head_summary = w.split("/")
        if wc_task == '*':
            use_tasks = tasks
        else:
            use_tasks = [wc_task]
        for task in use_tasks:
            key = f"{task}/{head}/{head_summary}"
            task_names.append(key)
            hyp_contrib[key] = cf._subset(cf.data[f'/hyp_contrib/{key}'],
                                          idx=idx)
    contrib = {k: v * seqs for k, v in hyp_contrib.items()}

    if null_contrib_file is not None:
        logger.info(f"Using null-contrib-file: {null_contrib_file}")
        null_cf = ContribFile(null_contrib_file)
        null_seqs = null_cf.get_seq()
        null_per_pos_scores = {
            key: null_seqs * null_cf.data[f'/hyp_contrib/{key}'][:]
            for key in task_names
        }
    else:
        # default null distribution. Requires modisco >= 0.5
        logger.info("Using default null_contrib_scores")
        null_per_pos_scores = modisco.coordproducers.LaplaceNullDist(
            num_to_samp=10000)

    # run modisco.
    # NOTE: `workflow` and `report` parameters are provided by gin config files
    modisco_run(task_names=task_names,
                output_path=output_path,
                contrib_scores=contrib,
                hypothetical_contribs=hyp_contrib,
                one_hot=seqs,
                null_per_pos_scores=null_per_pos_scores)

    logger.info(
        f"bpnet modisco-run finished. modisco.h5 and other files can be found in: {output_dir}"
    )
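
A hedged usage sketch (paths are placeholders):

# runs TF-MoDISco on the profile contribution scores of all tasks and writes
# modisco.h5 (plus the kwargs json and the subset filter) into output_dir
bpnet_modisco_run("output/contrib.h5",
                  "output/modisco",
                  contrib_wildcard="*/profile/wn",
                  num_workers=8)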