Code example #1
File: test_011_remotes.py  Project: yynst2/kipoi
def test_load_models_kipoi():
    k = kipoi.config.get_source("kipoi")

    ls = k.list_models()  # all the available models

    assert "HAL" in list(ls.model)
    model = "HAL"
    k.pull_model(model)

    # load the model
    kipoi.get_model(os.path.join(k.local_path, "HAL"), source="dir")

    kipoi.get_model(model, source="kipoi")
    kipoi.get_dataloader_factory(model)
Code example #2
def test_load_models_kipoi():
    k = kipoi.config.get_source("kipoi")

    l = k.list_models()  # all the available models

    assert "HAL" in list(l.model)
    model = "HAL"
    mpath = k.pull_model(model)
    m_dir = os.path.dirname(mpath)

    # load the model
    kipoi.get_model(m_dir, source="dir")

    kipoi.get_model(model, source="kipoi")
    kipoi.get_dataloader_factory(model)
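The two tests above differ only in how the local model directory is obtained: the first joins k.local_path with the model name, while the second takes the directory of the path returned by k.pull_model(model). A minimal sketch of the same loading pattern outside a test (assuming the public "kipoi" source is configured and that the "HAL" model is available in it):

import kipoi

# Load the model (with its default dataloader attached), then the dataloader class alone.
model = kipoi.get_model("HAL", source="kipoi")
Dl = kipoi.get_dataloader_factory("HAL", source="kipoi")
print(model.info.doc)
print(Dl.info.doc)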
Code example #3
def cli_info(command, raw_args):
    """CLI interface to predict
    """
    assert command == "info"
    parser = argparse.ArgumentParser('kipoi {}'.format(command),
                                     description="Prints dataloader" +
                                                 " keyword arguments.")
    parser.add_argument("-i", "--install_req", action='store_true',
                        help="Install required packages from requirements.txt")
    add_model(parser)
    add_dataloader(parser, with_args=False)
    args = parser.parse_args(raw_args)

    # --------------------------------------------
    # install args
    if args.install_req:
        kipoi.pipeline.install_model_requirements(args.model,
                                                  args.source,
                                                  and_dataloaders=True)
    # load model & dataloader
    model = kipoi.get_model(args.model, args.source)

    if args.dataloader is not None:
        dl_info = "dataloader '{0}' from source '{1}'".format(str(args.dataloader), str(args.dataloader_source))
        Dl = kipoi.get_dataloader_factory(args.dataloader, args.dataloader_source)
    else:
        dl_info = "default dataloader for model '{0}' from source '{1}'".format(str(model.name), str(args.source))
        Dl = model.default_dataloader

    print("-" * 80)
    print("Displaying keyword arguments for {0}".format(dl_info))
    print(kipoi.print_dl_kwargs(Dl))
    print("-" * 80)
Code example #4
def get_example_data(example, layer, writer=None):
    example_dir = "examples/{0}".format(example)
    if INSTALL_REQ:
        install_model_requirements(example_dir, "dir", and_dataloaders=True)

    model = kipoi.get_model(example_dir, source="dir")
    # The preprocessor
    Dataloader = kipoi.get_dataloader_factory(example_dir, source="dir")
    #
    with open(example_dir + "/example_files/test.json", "r") as ifh:
        dataloader_arguments = json.load(ifh)

    for k in dataloader_arguments:
        dataloader_arguments[k] = "example_files/" + dataloader_arguments[k]

    outputs = []
    with cd(model.source_dir):
        dl = Dataloader(**dataloader_arguments)
        it = dl.batch_iter(batch_size=32, num_workers=0)

        # Loop through the data, make predictions, save the output
        for i, batch in enumerate(tqdm(it)):

            # make the prediction
            pred_batch = model.input_grad(batch['inputs'], avg_func="sum", layer=layer,
                                          final_layer=False)
            # write out the predictions, metadata (, inputs, targets)
            # always keep the inputs so that input*grad can be generated!
            output_batch = prepare_batch(batch, pred_batch, keep_inputs=True)
            if writer is not None:
                writer.batch_write(output_batch)
            outputs.append(output_batch)
        if writer is not None:
            writer.close()
    return numpy_collate(outputs)
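A hedged usage sketch for the helper above, mirroring how the neighbouring tests drive it: the example name and the predict_activation_layers lookup are assumed to exist in the test module, and passing writer=None simply collates the batches in memory.

# Hedged usage sketch (example name illustrative; predict_activation_layers as used elsewhere in this test module).
layer = predict_activation_layers["tal1_model"]
collated = get_example_data("tal1_model", layer)  # writer=None: only collate the batches in memory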
Code example #5
def test_extractor_model(example):
    """Test extractor
    """
    if example == "rbp" and sys.version_info[0] == 2:
        pytest.skip("rbp example not supported on python 2 ")
    #
    example_dir = "examples/{0}".format(example)
    # install the dependencies
    # - TODO maybe put it implicitly in load_dataloader?
    if INSTALL_REQ:
        install_model_requirements(example_dir, "dir", and_dataloaders=True)
    #
    Dl = kipoi.get_dataloader_factory(example_dir, source="dir")
    #
    test_kwargs = get_test_kwargs(example_dir)
    #
    # install the dependencies
    # - TODO maybe put it implicitly in load_extractor?
    if INSTALL_REQ:
        install_model_requirements(example_dir, source="dir")
    #
    # get model
    model = kipoi.get_model(example_dir, source="dir")
    #
    with cd(example_dir + "/example_files"):
        # initialize the dataloader
        dataloader = Dl(**test_kwargs)
        #
        # sample a batch of data
        it = dataloader.batch_iter()
        batch = next(it)
        # predict with a model
        model.predict_on_batch(batch["inputs"])
        model.pred_grad(batch["inputs"], Slice_conv()[:, 0])
Code example #6
def test_dataloader_model(example):
    """Test dataloader
    """
    if example in {"rbp", "iris_model_template"} and sys.version_info[0] == 2:
        pytest.skip("example not supported on python 2 ")

    example_dir = "example/models/{0}".format(example)

    # install the dependencies
    if INSTALL_REQ:
        install_model_requirements(example_dir, "dir", and_dataloaders=True)

    Dl = kipoi.get_dataloader_factory(example_dir, source="dir")

    test_kwargs = Dl.example_kwargs

    # get dataloader

    # get model
    model = kipoi.get_model(example_dir, source="dir")

    with kipoi_utils.utils.cd(example_dir):
        # initialize the dataloader
        dataloader = Dl(**test_kwargs)

        # sample a batch of data
        it = dataloader.batch_iter()
        batch = next(it)
        # predict with a model
        model.predict_on_batch(batch["inputs"])
Code example #7
File: main.py  Project: jeffmylife/kipoi
def cli_preproc(command, raw_args):
    """Preprocess:
    - Run the dataloader and store the results to a (hdf5) file
    """
    assert command == "preproc"
    parser = argparse.ArgumentParser(
        'kipoi {}'.format(command),
        description='Run the dataloader and save the output to an hdf5 file.')
    add_dataloader_main(parser, with_args=True)
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='Batch size to use in data loading')
    parser.add_argument("-i",
                        "--install_req",
                        action='store_true',
                        help="Install required packages from requirements.txt")
    parser.add_argument(
        "-n",
        "--num_workers",
        type=int,
        default=0,
        help="Number of parallel workers for loading the dataset")
    parser.add_argument("-o",
                        "--output",
                        required=True,
                        help="Output hdf5 file")
    args = parser.parse_args(raw_args)

    dataloader_kwargs = parse_json_file_str(args.dataloader_args)

    dir_exists(os.path.dirname(args.output), logger)
    # --------------------------------------------
    # install args
    if args.install_req:
        kipoi.pipeline.install_dataloader_requirements(args.dataloader,
                                                       args.source)
    Dataloader = kipoi.get_dataloader_factory(args.dataloader, args.source)

    dataloader_kwargs = kipoi.pipeline.validate_kwargs(Dataloader,
                                                       dataloader_kwargs)
    dataloader = Dataloader(**dataloader_kwargs)

    it = dataloader.batch_iter(batch_size=args.batch_size,
                               num_workers=args.num_workers)

    logger.info("Writing to the hdf5 file: {0}".format(args.output))
    writer = writers.HDF5BatchWriter(file_path=args.output)

    for i, batch in enumerate(tqdm(it)):
        # check that the first batch was indeed correct
        if i == 0 and not Dataloader.output_schema.compatible_with_batch(
                batch):
            logger.warn(
                "First batch of data is not compatible with the dataloader schema."
            )
        writer.batch_write(batch)

    writer.close()
    logger.info("Done!")
Code example #8
def cli_info(command, raw_args):
    """CLI interface to predict
    """
    assert command == "info"
    parser = argparse.ArgumentParser('kipoi {}'.format(command),
                                     description="Prints dataloader" +
                                                 " keyword arguments.")
    add_model(parser)
    add_dataloader(parser, with_args=False)
    args = parser.parse_args(raw_args)

    # --------------------------------------------
    # load model & dataloader
    model = kipoi.get_model(args.model, args.source)

    if args.dataloader is not None:
        dl_info = "dataloader '{0}' from source '{1}'".format(str(args.dataloader), str(args.dataloader_source))
        Dl = kipoi.get_dataloader_factory(args.dataloader, args.dataloader_source)
    else:
        dl_info = "default dataloader for model '{0}' from source '{1}'".format(str(model.name), str(args.source))
        Dl = model.default_dataloader

    print("-" * 80)
    print("Displaying keyword arguments for {0}".format(dl_info))
    print(Dl.print_args())
    print("-" * 80)
Code example #9
def test_deeplift():
    # return True
    example = "tal1_model"
    layer = predict_activation_layers[example]
    example_dir = "tests/models/{0}".format(example)
    if INSTALL_REQ:
        install_model_requirements(example_dir, "dir", and_dataloaders=True)

    model = kipoi.get_model(example_dir, source="dir")
    # The preprocessor
    Dataloader = kipoi.get_dataloader_factory(example_dir, source="dir")
    #
    with open(example_dir + "/example_files/test.json", "r") as ifh:
        dataloader_arguments = json.load(ifh)

    for k in dataloader_arguments:
        dataloader_arguments[k] = "example_files/" + dataloader_arguments[k]

    d = DeepLift(model,
                 output_layer=-2,
                 task_idx=0,
                 preact=None,
                 mxts_mode='grad_times_inp')

    new_ofname = model.source_dir + "/example_files/deeplift_grads_pred.hdf5"
    if os.path.exists(new_ofname):
        os.unlink(new_ofname)

    writer = writers.HDF5BatchWriter(file_path=new_ofname)

    with kipoi.utils.cd(model.source_dir):
        dl = Dataloader(**dataloader_arguments)
        it = dl.batch_iter(batch_size=32, num_workers=0)
        # Loop through the data, make predictions, save the output
        for i, batch in enumerate(tqdm(it)):
            # make the prediction
            pred_batch = d.score(batch['inputs'], None)

            # Using Avanti's recommendation to check whether the model conversion has worked.
            pred_batch_fwd = d.predict_on_batch(batch['inputs'])
            orig_pred_batch_fwd = model.predict_on_batch(batch['inputs'])
            assert np.all(pred_batch_fwd == orig_pred_batch_fwd)

        output_batch = batch
        output_batch["input_grad"] = pred_batch
        writer.batch_write(output_batch)
    writer.close()

    new_res = readers.HDF5Reader.load(new_ofname)
    ref_res = readers.HDF5Reader.load(model.source_dir +
                                      "/example_files/grads.hdf5")
    assert np.all(
        np.isclose(new_res['input_grad'],
                   (ref_res['inputs'] * ref_res['grads'])))

    if os.path.exists(new_ofname):
        os.unlink(new_ofname)
Code example #10
def test_mutation_map():
    if sys.version_info[0] == 2:
        pytest.skip("rbp example not supported on python 2 ")

    # Take the rbp model
    model_dir = "examples/rbp/"
    if INSTALL_REQ:
        install_model_requirements(model_dir, "dir", and_dataloaders=True)

    model = kipoi.get_model(model_dir, source="dir")
    # The preprocessor
    Dataloader = kipoi.get_dataloader_factory(model_dir, source="dir")
    #
    dataloader_arguments = {
        "fasta_file": "example_files/hg38_chr22.fa",
        "preproc_transformer": "dataloader_files/encodeSplines.pkl",
        "gtf_file": "example_files/gencode_v25_chr22.gtf.pkl.gz",
    }
    dataloader_arguments = {
        k: model_dir + v
        for k, v in dataloader_arguments.items()
    }
    #
    # Run the actual predictions
    vcf_path = model_dir + "example_files/first_variant.vcf"
    #
    model_info = kipoi.postprocessing.variant_effects.ModelInfoExtractor(
        model, Dataloader)
    vcf_to_region = kipoi.postprocessing.variant_effects.SnvCenteredRg(
        model_info)
    mdmm = mm._generate_mutation_map(
        model,
        Dataloader,
        vcf_path,
        dataloader_args=dataloader_arguments,
        evaluation_function=analyse_model_preds,
        batch_size=32,
        vcf_to_region=vcf_to_region,
        evaluation_function_kwargs={'diff_types': {
            'diff': Diff("mean")
        }})
    with cd(model.source_dir):
        mdmm.save_to_file("example_files/first_variant_mm_totest.hdf5")
        from kipoi.postprocessing.variant_effects.utils.generic import read_hdf5
        reference = read_hdf5("example_files/first_variant_mm.hdf5")
        # compare the freshly generated mutation map against the reference
        obs = read_hdf5("example_files/first_variant_mm_totest.hdf5")
        compare_rec(reference[0], obs[0])
        import matplotlib
        matplotlib.pyplot.switch_backend('agg')
        mdmm.plot_mutmap(0, "seq", "diff", "rbp_prb")
        os.unlink("example_files/first_variant_mm_totest.hdf5")
Code example #11
def test_get_model(source):
    # model correctly instantiated
    assert kipoi.get_dataloader_factory("pyt", source).info.doc
    assert kipoi.get_model("pyt", source).info.doc

    assert kipoi.get_model("multiple_models/model1", source).dummy_add == 1
    assert kipoi.get_model("multiple_models/submodel/model2", source).dummy_add == 2

    # model examples correctly performed
    m = kipoi.get_model("multiple_models/model1", source)
    assert np.all(m.pipeline.predict_example() == 1)

    m = kipoi.get_model("multiple_models/submodel/model2", source)
    assert np.all(m.pipeline.predict_example() == 2)
Code example #12
def test_var_eff_pred_varseq():
    if sys.version_info[0] == 2:
        pytest.skip("rbp example not supported on python 2 ")
    model_dir = "examples/var_seqlen_model/"
    if INSTALL_REQ:
        install_model_requirements(model_dir, "dir", and_dataloaders=True)
    #
    model = kipoi.get_model(model_dir, source="dir")
    # The preprocessor
    Dataloader = kipoi.get_dataloader_factory(model_dir, source="dir")
    #
    dataloader_arguments = {
        "fasta_file": "example_files/hg38_chr22.fa",
        "preproc_transformer": "dataloader_files/encodeSplines.pkl",
        "gtf_file": "example_files/gencode_v25_chr22.gtf.pkl.gz",
        "intervals_file": "example_files/variant_centered_intervals.tsv"
    }
    vcf_path = "example_files/variants.vcf"
    out_vcf_fpath = "example_files/variants_generated.vcf"
    ref_out_vcf_fpath = "example_files/variants_ref_out.vcf"
    #
    with cd(model.source_dir):
        vcf_path = kipoi.postprocessing.variant_effects.ensure_tabixed_vcf(
            vcf_path)
        model_info = kipoi.postprocessing.variant_effects.ModelInfoExtractor(
            model, Dataloader)
        writer = kipoi.postprocessing.variant_effects.VcfWriter(
            model, vcf_path, out_vcf_fpath)
        vcf_to_region = None
        with pytest.raises(Exception):
            # This has to raise an exception as the sequence length is None.
            vcf_to_region = kipoi.postprocessing.variant_effects.SnvCenteredRg(
                model_info)
        res = sp.predict_snvs(
            model,
            Dataloader,
            vcf_path,
            dataloader_args=dataloader_arguments,
            evaluation_function=analyse_model_preds,
            batch_size=32,
            vcf_to_region=vcf_to_region,
            evaluation_function_kwargs={'diff_types': {
                'diff': Diff("mean")
            }},
            sync_pred_writer=writer)
        writer.close()
        # pass
        # assert filecmp.cmp(out_vcf_fpath, ref_out_vcf_fpath)
        compare_vcfs(out_vcf_fpath, ref_out_vcf_fpath)
        os.unlink(out_vcf_fpath)
Code example #13
def test_var_eff_pred2():
    if sys.version_info[0] == 2:
        pytest.skip("rbp example not supported on python 2 ")
    # Take the rbp model
    model_dir = "examples/rbp/"
    if INSTALL_REQ:
        install_model_requirements(model_dir, "dir", and_dataloaders=True)
    #
    model = kipoi.get_model(model_dir, source="dir")
    # The preprocessor
    Dataloader = kipoi.get_dataloader_factory(model_dir, source="dir")
    #
    dataloader_arguments = {
        "fasta_file": "example_files/hg38_chr22.fa",
        "preproc_transformer": "dataloader_files/encodeSplines.pkl",
        "gtf_file": "example_files/gencode_v25_chr22.gtf.pkl.gz",
    }
    #
    # Run the actual predictions
    vcf_path = "example_files/variants.vcf"
    out_vcf_fpath = "example_files/variants_generated2.vcf"
    ref_out_vcf_fpath = "example_files/variants_ref_out2.vcf"
    restricted_regions_fpath = "example_files/restricted_regions.bed"
    #
    with cd(model.source_dir):
        pbd = pb.BedTool(restricted_regions_fpath)
        model_info = kipoi.postprocessing.variant_effects.ModelInfoExtractor(
            model, Dataloader)
        vcf_to_region = kipoi.postprocessing.variant_effects.SnvPosRestrictedRg(
            model_info, pbd)
        writer = kipoi.postprocessing.variant_effects.utils.io.VcfWriter(
            model, vcf_path, out_vcf_fpath)
        res = sp.predict_snvs(
            model,
            Dataloader,
            vcf_path,
            dataloader_args=dataloader_arguments,
            evaluation_function=analyse_model_preds,
            batch_size=32,
            vcf_to_region=vcf_to_region,
            evaluation_function_kwargs={'diff_types': {
                'diff': Diff("mean")
            }},
            sync_pred_writer=writer)
        writer.close()
        # pass
        #assert filecmp.cmp(out_vcf_fpath, ref_out_vcf_fpath)
        compare_vcfs(out_vcf_fpath, ref_out_vcf_fpath)
        os.unlink(out_vcf_fpath)
Code example #14
def test_ref_seq():
    ### Get pure fasta predictions
    model_dir = model_root+"./"
    model = kipoi.get_model(model_dir, source="dir")
    # The preprocessor
    Dataloader = kipoi.get_dataloader_factory(model_dir, source="dir")
    dataloader_arguments = {
        "fasta_file": "/nfs/research1/stegle/users/rkreuzhu/opt/manuscript_code/data/raw/dataloader_files/shared/hg19.fa",
        "intervals_file": "test_files/encode_roadmap.bed"
    }
    # predict using results
    preds = model.pipeline.predict(dataloader_arguments)
    #
    res_orig = pd.read_csv("/nfs/research1/stegle/users/rkreuzhu/deeplearning/Basset/data/encode_roadmap_short_pred.txt", sep="\t", header=None)
    assert np.isclose(preds, res_orig.values, atol=1e-3).all()
Code example #15
def test_gradient_function_model(example):
    """Test extractor
    """
    if example == "rbp" and sys.version_info[0] == 2:
        pytest.skip("rbp example not supported on python 2 ")

    import keras
    backend = keras.backend._BACKEND
    if backend == 'theano' and example == "rbp":
        pytest.skip("extended_coda example not with theano ")
    #
    example_dir = "examples/{0}".format(example)
    # install the dependencies
    # - TODO maybe put it implicitly in load_dataloader?
    if INSTALL_REQ:
        install_model_requirements(example_dir, "dir", and_dataloaders=True)
    #
    Dl = kipoi.get_dataloader_factory(example_dir, source="dir")
    #
    test_kwargs = get_test_kwargs(example_dir)
    #
    # install the dependencies
    # - TODO maybe put it implicitly in load_extractor?
    if INSTALL_REQ:
        install_model_requirements(example_dir, source="dir")
    #
    # get model
    model = kipoi.get_model(example_dir, source="dir")
    #
    with cd(example_dir + "/example_files"):
        # initialize the dataloader
        dataloader = Dl(**test_kwargs)
        #
        # sample a batch of data
        it = dataloader.batch_iter()
        batch = next(it)
        # predict with a model
        model.predict_on_batch(batch["inputs"])
        if backend != 'theano':
            model.input_grad(batch["inputs"],
                             Slice_conv()[:, 0],
                             pre_nonlinearity=True)
        model.input_grad(batch["inputs"],
                         Slice_conv()[:, 0],
                         pre_nonlinearity=False)
        model.input_grad(batch["inputs"], 0,
                         pre_nonlinearity=False)  # same as Slice_conv()[:, 0]
        model.input_grad(batch["inputs"], avg_func="sum")
Code example #16
def test_activation_function_model(example):
    """Test extractor
    """
    if example == "rbp" and sys.version_info[0] == 2:
        pytest.skip("rbp example not supported on python 2 ")
    #
    import keras
    backend = keras.backend._BACKEND
    if backend == 'theano' and example == "rbp":
        pytest.skip("extended_coda example not with theano ")
    #
    example_dir = "examples/{0}".format(example)
    # install the dependencies
    # - TODO maybe put it implicitly in load_dataloader?
    if INSTALL_REQ:
        install_model_requirements(example_dir, "dir", and_dataloaders=True)
    #
    Dl = kipoi.get_dataloader_factory(example_dir, source="dir")
    #
    test_kwargs = get_test_kwargs(example_dir)
    #
    # install the dependencies
    # - TODO maybe put it implicitly in load_extractor?
    if INSTALL_REQ:
        install_model_requirements(example_dir, source="dir")
    #
    # get model
    model = kipoi.get_model(example_dir, source="dir")
    #
    with cd(example_dir + "/example_files"):
        # initialize the dataloader
        dataloader = Dl(**test_kwargs)
        #
        # sample a batch of data
        it = dataloader.batch_iter()
        batch = next(it)
        # predict with a model
        model.predict_on_batch(batch["inputs"])
        model.predict_activation_on_batch(batch["inputs"],
                                          layer=len(model.model.layers) - 2)
        if example == "rbp":
            model.predict_activation_on_batch(batch["inputs"],
                                              layer="flatten_6")
Code example #17
def test_score():
    example = "tal1_model"
    layer = predict_activation_layers[example]
    example_dir = "example/models/{0}".format(example)
    if INSTALL_REQ:
        install_model_requirements(example_dir, "dir", and_dataloaders=True)

    model = kipoi.get_model(example_dir, source="dir")
    # The preprocessor
    Dataloader = kipoi.get_dataloader_factory(example_dir, source="dir")
    #
    with open(example_dir + "/example_files/test.json", "r") as ifh:
        dataloader_arguments = json.load(ifh)

    for k in dataloader_arguments:
        dataloader_arguments[k] = "example_files/" + dataloader_arguments[k]

    g = Gradient(model, None, layer=layer, avg_func="sum")

    if os.path.exists(model.source_dir + "/example_files/grads_pred.hdf5"):
        os.unlink(model.source_dir + "/example_files/grads_pred.hdf5")

    writer = writers.HDF5BatchWriter(file_path=model.source_dir + "/example_files/grads_pred.hdf5")

    with kipoi_utils.utils.cd(model.source_dir):
        dl = Dataloader(**dataloader_arguments)
        it = dl.batch_iter(batch_size=32, num_workers=0)
        # Loop through the data, make predictions, save the output
        for i, batch in enumerate(tqdm(it)):
            # make the prediction
            pred_batch = g.score(batch['inputs'])
            output_batch = batch
            output_batch["grads"] = pred_batch
            writer.batch_write(output_batch)
        writer.close()

    obj1 = readers.HDF5Reader.load(model.source_dir + "/example_files/grads_pred.hdf5")
    obj2 = readers.HDF5Reader.load(model.source_dir + "/example_files/grads.hdf5")
    kipoi_utils.utils.compare_numpy_dict(obj1, obj2)

    if os.path.exists(model.source_dir + "/example_files/grads_pred.hdf5"):
        os.unlink(model.source_dir + "/example_files/grads_pred.hdf5")
Code example #18
def test_load_model(example):
    example_dir = "examples/{0}".format(example)

    if example in {"rbp", "iris_model_template"} and sys.version_info[0] == 2:
        pytest.skip("example not supported on python 2 ")

    if INSTALL_REQ:
        install_dataloader_requirements(example_dir, "dir")
    Dl = kipoi.get_dataloader_factory(example_dir, source="dir")

    Dl.type
    Dl.defined_as
    Dl.args
    Dl.info
    Dl.output_schema
    Dl.source
    # dataloader
    Dl.batch_iter
    Dl.load_all

    Dl.print_args()
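The same introspection works on any dataloader class outside a test; the dataloader directory below is illustrative.

# Hedged sketch (dataloader directory illustrative).
Dl = kipoi.get_dataloader_factory("examples/pyt", source="dir")
print(Dl.type, Dl.defined_as)
print(Dl.info.doc)
Dl.print_args()  # lists the accepted dataloader keyword arguments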
Code example #19
File: cli.py  Project: kipoi/kipoi-interpret
def cli_deeplift(command, raw_args):
    """CLI interface to predict
    """
    # TODO: find a way to define the "reference" for a scored sequence.
    # from .main import prepare_batch
    assert command == "deeplift"
    from tqdm import tqdm
    from .referencebased import DeepLift
    from .referencebased import get_mxts_modes
    parser = argparse.ArgumentParser('kipoi interpret {}'.format(command),
                                     description='Calculate DeepLIFT scores.')
    add_model(parser)
    add_dataloader(parser, with_args=True)
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='Batch size to use in prediction')
    parser.add_argument(
        "-n",
        "--num_workers",
        type=int,
        default=0,
        help="Number of parallel workers for loading the dataset")
    parser.add_argument(
        "-l",
        "--layer",
        type=int,
        default=None,
        help="With respect to which layer the scores should be calculated.",
        required=True)
    parser.add_argument(
        "--pre_nonlinearity",
        help=
        "Flag indicating that it should checked whether the selected output is post activation "
        "function. If a non-linear activation function is used attempt to use its input. This "
        "feature is not available for all models.",
        action='store_true')
    parser.add_argument(
        "-f",
        "--filter_idx",
        help="Filter index that should be inspected with gradients",
        default=None,
        required=True,
        type=int)
    parser.add_argument("-m",
                        "--mxts_mode",
                        help="Deeplift score, allowed values are: %s" %
                        str(list(get_mxts_modes().keys())),
                        default='rescale_conv_revealcancel_fc')
    parser.add_argument(
        '-o',
        '--output',
        required=True,
        nargs="+",
        help=
        "Output files. File format is inferred from the file path ending. Available file formats are: "
        + ", ".join(["." + k for k in writers.FILE_SUFFIX_MAP]))
    args = parser.parse_args(raw_args)

    dataloader_kwargs = parse_json_file_str(args.dataloader_args)

    # setup the files
    if not isinstance(args.output, list):
        args.output = [args.output]
    for o in args.output:
        ending = o.split('.')[-1]
        if ending not in writers.FILE_SUFFIX_MAP:
            logger.error("File ending: {0} for file {1} not from {2}".format(
                ending, o, writers.FILE_SUFFIX_MAP))
            sys.exit(1)
        dir_exists(os.path.dirname(o), logger)
    # --------------------------------------------

    layer = args.layer
    if layer is None and not args.final_layer:
        raise Exception(
            "A layer has to be selected explicitely using `--layer` or implicitely by using the"
            "`--final_layer` flag.")

    # Not a good idea
    # if layer is not None and isint(layer):
    #    logger.warn("Interpreting `--layer` value as integer layer index!")
    #    layer = int(args.layer)

    # load model & dataloader
    model = kipoi.get_model(args.model, args.source)

    if args.dataloader is not None:
        Dl = kipoi.get_dataloader_factory(args.dataloader,
                                          args.dataloader_source)
    else:
        Dl = model.default_dataloader

    dataloader_kwargs = kipoi.pipeline.validate_kwargs(Dl, dataloader_kwargs)
    dl = Dl(**dataloader_kwargs)

    # setup batching
    it = dl.batch_iter(batch_size=args.batch_size,
                       num_workers=args.num_workers)

    # Setup the writers
    use_writers = []
    for output in args.output:
        ending = output.split('.')[-1]
        W = writers.FILE_SUFFIX_MAP[ending]
        logger.info("Using {0} for file {1}".format(W.__name__, output))
        if ending == "tsv":
            assert W == writers.TsvBatchWriter
            use_writers.append(
                writers.TsvBatchWriter(file_path=output, nested_sep="/"))
        elif ending == "bed":
            raise Exception("Please use tsv or hdf5 output format.")
        elif ending in ["hdf5", "h5"]:
            assert W == writers.HDF5BatchWriter
            use_writers.append(writers.HDF5BatchWriter(file_path=output))
        else:
            logger.error("Unknown file format: {0}".format(ending))
            sys.exit(1)

    d = DeepLift(model,
                 output_layer=args.layer,
                 task_idx=args.filter_idx,
                 preact=args.pre_nonlinearity,
                 mxts_mode=args.mxts_mode,
                 batch_size=args.batch_size)

    # Loop through the data, make predictions, save the output
    for i, batch in enumerate(tqdm(it)):
        # validate the data schema in the first iteration
        if i == 0 and not Dl.output_schema.compatible_with_batch(batch):
            logger.warn(
                "First batch of data is not compatible with the dataloader schema."
            )

        # calculate scores without reference for the moment.
        pred_batch = d.score(batch['inputs'], None)

        # write out the predictions, metadata (, inputs, targets)
        # always keep the inputs so that input*grad can be generated!
        # output_batch = prepare_batch(batch, pred_batch, keep_inputs=True)
        output_batch = batch
        output_batch["scores"] = pred_batch
        for writer in use_writers:
            writer.batch_write(output_batch)

    for writer in use_writers:
        writer.close()
    logger.info('Done! Gradients stored in {0}'.format(",".join(args.output)))
Code example #20
def cli_predict(command, raw_args):
    """CLI interface to predict
    """
    assert command == "predict"
    parser = argparse.ArgumentParser('kipoi {}'.format(command),
                                     description='Run the model prediction.')
    add_model(parser)
    add_dataloader(parser, with_args=True)
    parser.add_argument('--batch_size', type=int, default=32,
                        help='Batch size to use in prediction')
    parser.add_argument("-n", "--num_workers", type=int, default=0,
                        help="Number of parallel workers for loading the dataset")
    parser.add_argument("-i", "--install_req", action='store_true',
                        help="Install required packages from requirements.txt")
    parser.add_argument("-k", "--keep_inputs", action='store_true',
                        help="Keep the inputs in the output file. ")
    parser.add_argument("-l", "--layer",
                        help="Which output layer to use to make the predictions. If specified," +
                        "`model.predict_activation_on_batch` will be invoked instead of `model.predict_on_batch`")
    parser.add_argument('-o', '--output', required=True, nargs="+",
                        help="Output files. File format is inferred from the file path ending. Available file formats are: " +
                        ", ".join(["." + k for k in writers.FILE_SUFFIX_MAP]))
    args = parser.parse_args(raw_args)

    dataloader_kwargs = parse_json_file_str(args.dataloader_args)

    # setup the files
    if not isinstance(args.output, list):
        args.output = [args.output]
    for o in args.output:
        ending = o.split('.')[-1]
        if ending not in writers.FILE_SUFFIX_MAP:
            logger.error("File ending: {0} for file {1} not from {2}".
                         format(ending, o, writers.FILE_SUFFIX_MAP))
            sys.exit(1)
        dir_exists(os.path.dirname(o), logger)
    # --------------------------------------------
    # install args
    if args.install_req:
        kipoi.pipeline.install_model_requirements(args.model,
                                                  args.source,
                                                  and_dataloaders=True)
    # load model & dataloader
    model = kipoi.get_model(args.model, args.source)

    if args.dataloader is not None:
        Dl = kipoi.get_dataloader_factory(args.dataloader, args.dataloader_source)
    else:
        Dl = model.default_dataloader

    dataloader_kwargs = kipoi.pipeline.validate_kwargs(Dl, dataloader_kwargs)
    dl = Dl(**dataloader_kwargs)

    # setup batching
    it = dl.batch_iter(batch_size=args.batch_size,
                       num_workers=args.num_workers)

    # Setup the writers
    use_writers = []
    for output in args.output:
        ending = output.split('.')[-1]
        W = writers.FILE_SUFFIX_MAP[ending]
        logger.info("Using {0} for file {1}".format(W.__name__, output))
        if ending == "tsv":
            assert W == writers.TsvBatchWriter
            use_writers.append(writers.TsvBatchWriter(file_path=output, nested_sep="/"))
        elif ending == "bed":
            assert W == writers.BedBatchWriter
            use_writers.append(writers.BedBatchWriter(file_path=output,
                                                      dataloader_schema=dl.output_schema.metadata,
                                                      header=True))
        elif ending in ["hdf5", "h5"]:
            assert W == writers.HDF5BatchWriter
            use_writers.append(writers.HDF5BatchWriter(file_path=output))
        else:
            logger.error("Unknown file format: {0}".format(ending))
            sys.exit(1)

    # Loop through the data, make predictions, save the output
    for i, batch in enumerate(tqdm(it)):
        # validate the data schema in the first iteration
        if i == 0 and not Dl.output_schema.compatible_with_batch(batch):
            logger.warn("First batch of data is not compatible with the dataloader schema.")

        # make the prediction
        if args.layer is None:
            pred_batch = model.predict_on_batch(batch['inputs'])
        else:
            pred_batch = model.predict_activation_on_batch(batch['inputs'], layer=args.layer)

        # write out the predictions, metadata (, inputs, targets)
        output_batch = prepare_batch(batch, pred_batch, keep_inputs=args.keep_inputs)
        for writer in use_writers:
            writer.batch_write(output_batch)

    for writer in use_writers:
        writer.close()
    logger.info('Done! Predictions stored in {0}'.format(",".join(args.output)))
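cli_predict can likewise be invoked programmatically with a raw argument list. In the sketch below the model name, dataloader arguments and output path are illustrative, and the --dataloader_args flag is assumed to be registered by add_dataloader(with_args=True), as the use of args.dataloader_args above suggests.

# Hedged sketch (model, dataloader arguments and output path illustrative).
cli_predict("predict", [
    "my_model", "--source", "dir",
    "--dataloader_args", '{"intervals_file": "intervals.bed", "fasta_file": "genome.fa"}',
    "--batch_size", "64",
    "-o", "preds.h5",
])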
Code example #21
File: cli.py  Project: kipoi/kipoi-interpret
def cli_ism(command, raw_args):
    # TODO: find a way to define the model output selection
    """CLI interface to predict
    """
    # from .main import prepare_batch
    assert command == "ism"
    from tqdm import tqdm
    from .ism import Mutation

    parser = argparse.ArgumentParser('kipoi interpret {}'.format(command),
                                     description='Run in-silico mutagenesis (ISM).')
    add_model(parser)
    add_dataloader(parser, with_args=True)
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='Batch size to use in prediction')
    parser.add_argument(
        "-n",
        "--num_workers",
        type=int,
        default=0,
        help="Number of parallel workers for loading the dataset")
    parser.add_argument("--model_input",
                        help="Name of the model input that should be scored.",
                        required=True)
    parser.add_argument(
        '-s',
        "--scores",
        default="diff",
        nargs="+",
        help=
        "Scoring method to be used. Only scoring methods selected in the model yaml file are"
        "available except for `diff` which is always available. Select scoring function by the"
        "`name` tag defined in the model yaml file.")
    parser.add_argument(
        '-k',
        "--score_kwargs",
        default=None,
        nargs="+",
        help=
        "JSON definition of the kwargs for the scoring functions selected in --scores. The "
        "definiton can either be in JSON in the command line or the path of a .json file. The "
        "individual JSONs are expected to be supplied in the same order as the labels defined in "
        "--scores. If the defaults or no arguments should be used define '{}' for that respective "
        "scoring method.")
    parser.add_argument(
        "-c",
        "--category_axis",
        help="Using the selected model input with `--model_input`: Which "
        "dimension of that array contains the one-hot encoded categories?"
        " e.g. for a one-hot encoded DNA-sequence"
        "array with input shape (1000, 4) for a single sample, "
        "`category_dim` is 1, for (4, 1000) `category_dim`"
        "is 0.",
        default=1,
        type=int,
        required=False)
    parser.add_argument(
        "-f",
        "--output_sel_fn",
        help="Define an output selection function in order to return effects"
        "on the output of the function. example definitoin: "
        "`--output_sel_fn my_file.py::my_sel_fn`",
        default=None,
        required=False)
    parser.add_argument(
        '-o',
        '--output',
        required=True,
        nargs="+",
        help="Output files. File format is inferred from the file path ending. "
        "Available file formats are: " +
        ", ".join(["." + k for k in writers.FILE_SUFFIX_MAP]))
    args = parser.parse_args(raw_args)

    dataloader_kwargs = parse_json_file_str(args.dataloader_args)

    # setup the files
    if not isinstance(args.output, list):
        args.output = [args.output]
    for o in args.output:
        ending = o.split('.')[-1]
        if ending not in writers.FILE_SUFFIX_MAP:
            logger.error("File ending: {0} for file {1} not from {2}".format(
                ending, o, writers.FILE_SUFFIX_MAP))
            sys.exit(1)
        dir_exists(os.path.dirname(o), logger)
    # --------------------------------------------
    if not isinstance(args.scores, list):
        args.scores = [args.scores]

    # load model & dataloader
    model = kipoi.get_model(args.model, args.source)

    if args.dataloader is not None:
        Dl = kipoi.get_dataloader_factory(args.dataloader,
                                          args.dataloader_source)
    else:
        Dl = model.default_dataloader

    dataloader_kwargs = kipoi.pipeline.validate_kwargs(Dl, dataloader_kwargs)
    dl = Dl(**dataloader_kwargs)

    # setup batching
    it = dl.batch_iter(batch_size=args.batch_size,
                       num_workers=args.num_workers)

    # Setup the writers
    use_writers = []
    for output in args.output:
        ending = output.split('.')[-1]
        W = writers.FILE_SUFFIX_MAP[ending]
        logger.info("Using {0} for file {1}".format(W.__name__, output))
        if ending == "tsv":
            assert W == writers.TsvBatchWriter
            use_writers.append(
                writers.TsvBatchWriter(file_path=output, nested_sep="/"))
        elif ending == "bed":
            raise Exception("Please use tsv or hdf5 output format.")
        elif ending in ["hdf5", "h5"]:
            assert W == writers.HDF5BatchWriter
            use_writers.append(writers.HDF5BatchWriter(file_path=output))
        else:
            logger.error("Unknown file format: {0}".format(ending))
            sys.exit(1)

    output_sel_fn = None
    if args.output_sel_fn is not None:
        file_path, obj_name = tuple(args.output_sel_fn.split("::"))
        output_sel_fn = getattr(load_module(file_path), obj_name)

    m = Mutation(model,
                 args.model_input,
                 scores=args.scores,
                 score_kwargs=args.score_kwargs,
                 batch_size=args.batch_size,
                 output_sel_fn=output_sel_fn,
                 category_axis=args.category_axis,
                 test_ref_ref=True)

    out_batches = {}

    # Loop through the data, make predictions, save the output..
    # TODO: batch writer fails because it tries to concatenate on highest dimension rather than the lowest!
    for i, batch in enumerate(tqdm(it)):
        # validate the data schema in the first iteration
        if i == 0 and not Dl.output_schema.compatible_with_batch(batch):
            logger.warn(
                "First batch of data is not compatible with the dataloader schema."
            )

        # calculate scores without reference for the moment.
        pred_batch = m.score(batch['inputs'])

        # with the current writers it's not possible to store the scores and the model inputs in the same file
        output_batch = {}
        output_batch["scores"] = pred_batch

        for k in output_batch:
            if k not in out_batches:
                out_batches[k] = []
            out_batches[k].append(output_batch[k])

    # concatenate batches:
    full_output = {
        k: np.concatenate([np.array(el) for el in v])
        for k, v in out_batches.items()
    }
    logger.info('Full output shape: {0}'.format(
        str(full_output["scores"].shape)))

    for writer in use_writers:
        writer.batch_write(full_output)

    for writer in use_writers:
        writer.close()
    logger.info('Done! ISM stored in {0}'.format(",".join(args.output)))
Code example #22
def cli_predict(command, raw_args):
    """CLI interface to predict
    """
    assert command == "predict"
    parser = argparse.ArgumentParser('kipoi {}'.format(command),
                                     description='Run the model prediction.')
    add_model(parser)
    add_dataloader(parser, with_args=True)
    parser.add_argument('--batch_size', type=int, default=32,
                        help='Batch size to use in prediction')
    parser.add_argument("-n", "--num_workers", type=int, default=0,
                        help="Number of parallel workers for loading the dataset")
    parser.add_argument("-k", "--keep_inputs", action='store_true',
                        help="Keep the inputs in the output file. ")
    parser.add_argument("-l", "--layer",
                        help="Which output layer to use to make the predictions. If specified," +
                        "`model.predict_activation_on_batch` will be invoked instead of `model.predict_on_batch`")
    parser.add_argument("--singularity", action='store_true',
                        help="Run `kipoi predict` in the appropriate singularity container. "
                        "Containters will get downloaded to ~/.kipoi/envs/ or to "
                        "$SINGULARITY_CACHEDIR if set")
    parser.add_argument('-o', '--output', required=True, nargs="+",
                        help="Output files. File format is inferred from the file path ending. Available file formats are: " +
                        ", ".join(["." + k for k in writers.FILE_SUFFIX_MAP]))
    args = parser.parse_args(raw_args)

    dataloader_kwargs = parse_json_file_str_or_arglist(args.dataloader_args, parser)

    # setup the files
    if not isinstance(args.output, list):
        args.output = [args.output]
    for o in args.output:
        ending = o.split('.')[-1]
        if ending not in writers.FILE_SUFFIX_MAP:
            logger.error("File ending: {0} for file {1} not from {2}".
                         format(ending, o, writers.FILE_SUFFIX_MAP))
            sys.exit(1)
        dir_exists(os.path.dirname(o), logger)

    # singularity_command
    if args.singularity:
        from kipoi.cli.singularity import singularity_command
        logger.info("Running kipoi predict in the singularity container")
        # Drop the singularity flag
        raw_args = [x for x in raw_args if x != '--singularity']
        singularity_command(['kipoi', command] + raw_args,
                            args.model,
                            dataloader_kwargs,
                            output_files=args.output,
                            source=args.source,
                            dry_run=False)
        return None
    # --------------------------------------------
    # load model & dataloader
    model = kipoi.get_model(args.model, args.source)

    if args.dataloader is not None:
        Dl = kipoi.get_dataloader_factory(args.dataloader, args.dataloader_source)
    else:
        Dl = model.default_dataloader

    dataloader_kwargs = kipoi.pipeline.validate_kwargs(Dl, dataloader_kwargs)
    dl = Dl(**dataloader_kwargs)

    # setup batching
    it = dl.batch_iter(batch_size=args.batch_size,
                       num_workers=args.num_workers)

    # Setup the writers
    use_writers = []
    for output in args.output:
        writer = writers.get_writer(output, metadata_schema=dl.get_output_schema().metadata)
        if writer is None:
            logger.error("Unknown file format: {0}".format(ending))
            sys.exit()
        else:
            use_writers.append(writer)
    output_writers = writers.MultipleBatchWriter(use_writers)

    # Loop through the data, make predictions, save the output
    for i, batch in enumerate(tqdm(it)):
        # validate the data schema in the first iteration
        if i == 0 and not Dl.get_output_schema().compatible_with_batch(batch):
            logger.warning("First batch of data is not compatible with the dataloader schema.")

        # make the prediction
        if args.layer is None:
            pred_batch = model.predict_on_batch(batch['inputs'])
        else:
            pred_batch = model.predict_activation_on_batch(batch['inputs'], layer=args.layer)

        # write out the predictions, metadata (, inputs, targets)
        output_batch = prepare_batch(batch, pred_batch, keep_inputs=args.keep_inputs)
        output_writers.batch_write(output_batch)

    output_writers.close()
    logger.info('Done! Predictions stored in {0}'.format(",".join(args.output)))
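Compared with the older cli_predict in code example #20, this version delegates writer selection to writers.get_writer, bundles the writers in a writers.MultipleBatchWriter, and adds a --singularity pass-through. The writer setup can be sketched on its own as follows; the output paths are illustrative and dl/batch are assumed to come from a dataloader as above.

# Hedged sketch of the writer setup used above (output paths illustrative).
use_writers = [writers.get_writer(path, metadata_schema=dl.get_output_schema().metadata)
               for path in ["preds.tsv", "preds.h5"]]
output_writers = writers.MultipleBatchWriter(use_writers)
output_writers.batch_write(batch)
output_writers.close()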
Code example #23
def get_model(model, source="kipoi", with_dataloader=True):
    """Load the `model` from `source`, as well as the
    default dataloder to model.default_dataloder.

    Args:
      model, str:  model name
      source, str:  source name
      with_dataloader, bool: if True, the default dataloader is
        loaded to model.default_dataloadera and the pipeline at model.pipeline enabled.
    """
    # TODO - model can be a yaml file or a directory
    source_name = source

    source = kipoi.config.get_source(source)

    # pull the model & get the model directory
    yaml_path = source.pull_model(model)
    source_dir = os.path.dirname(yaml_path)

    # Setup model description
    with cd(source_dir):
        md = ModelDescription.load(os.path.basename(yaml_path))
    # TODO - is there a way to prevent code duplication here?
    # TODO - possible to inherit from both classes and call the corresponding inits?
    # --------------------------------------------
    # TODO - load it into memory?

    # TODO - validate md.default_dataloader <-> model

    # attach the default dataloader already to the model
    if ":" in md.default_dataloader:
        dl_source, dl_path = md.default_dataloader.split(":")
    else:
        dl_source = source_name
        dl_path = md.default_dataloader

    if with_dataloader:
        # allow to use relative and absolute paths for referring to the dataloader
        default_dataloader_path = os.path.join("/" + model, dl_path)[1:]
        default_dataloader = kipoi.get_dataloader_factory(
            default_dataloader_path, dl_source)
    else:
        default_dataloader = None

    # Read the Model - append methods, attributes to self
    with cd(source_dir):  # move to the model directory temporarily
        if md.type == 'custom':
            Mod = load_model_custom(**md.args)
            assert issubclass(Mod, BaseModel)  # it should inherit from Model
            mod = Mod()
        elif md.type in AVAILABLE_MODELS:
            # TODO - this doesn't seem to work
            mod = AVAILABLE_MODELS[md.type](**md.args)
        else:
            raise ValueError("Unsupported model type: {0}. " +
                             "Model type needs to be one of: {1}".format(
                                 md.type, ['custom'] +
                                 list(AVAILABLE_MODELS.keys())))

    # populate the returned class
    mod.type = md.type
    mod.args = md.args
    mod.info = md.info
    mod.schema = md.schema
    mod.dependencies = md.dependencies
    mod.default_dataloader = default_dataloader
    mod.name = model
    mod.source = source
    mod.source_name = source_name
    mod.source_dir = source_dir
    # parse the postprocessing module
    mod.postprocessing = md.postprocessing
    if with_dataloader:
        mod.pipeline = Pipeline(model=mod, dataloader_cls=default_dataloader)
    else:
        mod.pipeline = None
    return mod
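A usage sketch for get_model as implemented above; the model name is illustrative and predict_example assumes the model ships a working example dataset, as in the tests earlier in this collection.

# Hedged usage sketch (model name illustrative).
m = get_model("HAL", source="kipoi", with_dataloader=True)
print(m.type, m.source_name, m.source_dir)
preds = m.pipeline.predict_example()  # run the bundled example through dataloader + model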
Code example #24
File: tlearn.py  Project: kipoi/manuscript
    else:
        hidden = [int(x) for x in args.add_n_hidden.split(",")]
    # -------
    odir = Path(args.output)
    odir.mkdir(parents=True, exist_ok=True)

    if args.gpu == -1:
        gpu = GPUtil.getFirstAvailable(attempts=3, includeNan=True)[0]
    else:
        gpu = args.gpu
    create_tf_session(gpu)

    # Get the model and the dataloader
    model = kipoi.get_model(args.model, args.source)
    if args.dataloader is not None:
        Dl = kipoi.get_dataloader_factory(args.dataloader,
                                          args.dataloader_source)
    else:
        Dl = model.default_dataloader

    if not model.type == "keras":
        raise ValueError("Only keras models are supported")

    dl_train = Dl(**dl_kwargs_train)
    dl_eval = Dl(**dl_kwargs_eval)

    # ---------------
    # Setup a new model

    # Transferred part
    tmodel = Model(model.model.inputs,
                   model.model.get_layer(args.transfer_to).output)
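The excerpt stops after building the transferred sub-model. A hedged continuation might stack the requested hidden layers and a fresh output head on top; the layer sizes and activations below are illustrative and not taken from the original script.

# Hedged continuation sketch (units and activations illustrative).
from keras.layers import Dense
from keras.models import Model

x = tmodel.output
for n in hidden:
    x = Dense(n, activation="relu")(x)
new_model = Model(tmodel.inputs, Dense(1, activation="sigmoid")(x))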
Code example #25
def cli_grad(command, raw_args):
    """CLI interface to predict
    """
    from .main import prepare_batch
    from kipoi.model import GradientMixin
    assert command == "grad"
    from tqdm import tqdm
    parser = argparse.ArgumentParser(
        'kipoi {}'.format(command),
        description='Save gradients and inputs to a hdf5 file.')
    add_model(parser)
    add_dataloader(parser, with_args=True)
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='Batch size to use in prediction')
    parser.add_argument(
        "-n",
        "--num_workers",
        type=int,
        default=0,
        help="Number of parallel workers for loading the dataset")
    parser.add_argument("-i",
                        "--install_req",
                        action='store_true',
                        help="Install required packages from requirements.txt")
    parser.add_argument(
        "-l",
        "--layer",
        default=None,
        help="Which output layer to use to make the predictions. If specified,"
        +
        "`model.predict_activation_on_batch` will be invoked instead of `model.predict_on_batch`",
        required=False)
    parser.add_argument(
        "--final_layer",
        help=
        "Alternatively to `--layer` this flag can be used to indicate that the last layer should "
        "be used.",
        action='store_true')
    parser.add_argument(
        "--pre_nonlinearity",
        help=
        "Flag indicating that it should checked whether the selected output is post activation "
        "function. If a non-linear activation function is used attempt to use its input. This "
        "feature is not available for all models.",
        action='store_true')
    parser.add_argument(
        "-f",
        "--filter_idx",
        help=
        "Filter index that should be inspected with gradients. If not set all filters will "
        + "be used.",
        default=None)
    parser.add_argument(
        "-a",
        "--avg_func",
        help=
        "Averaging function to be applied across selected filters (`--filter_idx`) in "
        + "layer `--layer`.",
        choices=GradientMixin.allowed_functions,
        default="sum")
    parser.add_argument(
        '--selected_fwd_node',
        help="If the selected layer has multiple inbound connections in "
        "the graph then those can be selected here with an integer "
        "index. Not necessarily supported by all models.",
        default=None,
        type=int)
    parser.add_argument(
        '-o',
        '--output',
        required=True,
        nargs="+",
        help=
        "Output files. File format is inferred from the file path ending. Available file formats are: "
        + ", ".join(["." + k for k in writers.FILE_SUFFIX_MAP]))
    args = parser.parse_args(raw_args)

    dataloader_kwargs = parse_json_file_str(args.dataloader_args)

    # setup the files
    if not isinstance(args.output, list):
        args.output = [args.output]
    for o in args.output:
        ending = o.split('.')[-1]
        if ending not in writers.FILE_SUFFIX_MAP:
            logger.error("File ending: {0} for file {1} not from {2}".format(
                ending, o, writers.FILE_SUFFIX_MAP))
            sys.exit(1)
        dir_exists(os.path.dirname(o), logger)
    # --------------------------------------------
    # install args
    if args.install_req:
        kipoi.pipeline.install_model_requirements(args.model,
                                                  args.source,
                                                  and_dataloaders=True)

    layer = args.layer
    if layer is None and not args.final_layer:
        raise Exception(
            "A layer has to be selected explicitely using `--layer` or implicitely by using the"
            "`--final_layer` flag.")

    # Not a good idea
    # if layer is not None and isint(layer):
    #    logger.warn("Interpreting `--layer` value as integer layer index!")
    #    layer = int(args.layer)

    # load model & dataloader
    model = kipoi.get_model(args.model, args.source)

    if not isinstance(model, GradientMixin):
        raise Exception("Model does not support gradient calculation.")

    if args.dataloader is not None:
        Dl = kipoi.get_dataloader_factory(args.dataloader,
                                          args.dataloader_source)
    else:
        Dl = model.default_dataloader

    dataloader_kwargs = kipoi.pipeline.validate_kwargs(Dl, dataloader_kwargs)
    dl = Dl(**dataloader_kwargs)

    filter_idx_parsed = None
    if args.filter_idx is not None:
        filter_idx_parsed = parse_filter_slice(args.filter_idx)

    # setup batching
    it = dl.batch_iter(batch_size=args.batch_size,
                       num_workers=args.num_workers)

    # Setup the writers
    use_writers = []
    for output in args.output:
        ending = output.split('.')[-1]
        W = writers.FILE_SUFFIX_MAP[ending]
        logger.info("Using {0} for file {1}".format(W.__name__, output))
        if ending == "tsv":
            assert W == writers.TsvBatchWriter
            use_writers.append(
                writers.TsvBatchWriter(file_path=output, nested_sep="/"))
        elif ending == "bed":
            raise Exception("Please use tsv or hdf5 output format.")
        elif ending in ["hdf5", "h5"]:
            assert W == writers.HDF5BatchWriter
            use_writers.append(writers.HDF5BatchWriter(file_path=output))
        else:
            logger.error("Unknown file format: {0}".format(ending))
            sys.exit(1)

    # Loop through the data, make predictions, save the output
    for i, batch in enumerate(tqdm(it)):
        # validate the data schema in the first iteration
        if i == 0 and not Dl.output_schema.compatible_with_batch(batch):
            logger.warn(
                "First batch of data is not compatible with the dataloader schema."
            )

        # make the prediction
        pred_batch = model.input_grad(batch['inputs'],
                                      filter_idx=filter_idx_parsed,
                                      avg_func=args.avg_func,
                                      layer=layer,
                                      final_layer=args.final_layer,
                                      selected_fwd_node=args.selected_fwd_node,
                                      pre_nonlinearity=args.pre_nonlinearity)

        # write out the predictions, metadata (, inputs, targets)
        # always keep the inputs so that input*grad can be generated!
        # output_batch = prepare_batch(batch, pred_batch, keep_inputs=True)
        output_batch = batch
        output_batch["grads"] = pred_batch
        for writer in use_writers:
            writer.batch_write(output_batch)

    for writer in use_writers:
        writer.close()
    logger.info('Done! Gradients stored in {0}'.format(",".join(args.output)))
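The same gradient computation can be driven from Python without the CLI. A minimal sketch follows (not part of the original source); the model name and dataloader arguments are assumptions, and only calls that appear above (`kipoi.get_model`, `model.input_grad`, `dl.batch_iter`, `writers.HDF5BatchWriter`) are used.

import kipoi
from kipoi import writers

# Hypothetical model name and dataloader kwargs -- substitute a real
# gradient-capable (GradientMixin) model and the arguments its dataloader expects.
model = kipoi.get_model("MyGradModel", source="kipoi")
dl = model.default_dataloader(fasta_file="hg38.fa")

writer = writers.HDF5BatchWriter(file_path="grads.h5")
for batch in dl.batch_iter(batch_size=32, num_workers=0):
    # gradient of the (summed) final-layer output w.r.t. the model inputs
    batch["grads"] = model.input_grad(batch["inputs"],
                                      avg_func="sum",
                                      layer=None,
                                      final_layer=True)
    writer.batch_write(batch)
writer.close()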
Code example #26
def cli_create_mutation_map(command, raw_args):
    """CLI interface to calculate mutation map data 
    """
    assert command == "create_mutation_map"
    parser = argparse.ArgumentParser(
        'kipoi postproc {}'.format(command),
        description='Predict effect of SNVs using ISM.')
    add_model(parser)
    add_dataloader(parser, with_args=True)
    parser.add_argument(
        '-r',
        '--regions_file',
        help='Region definition as VCF or bed file. Not a required input.')
    # TODO - rename path to fpath
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='Batch size to use in prediction')
    parser.add_argument(
        "-n",
        "--num_workers",
        type=int,
        default=0,
        help="Number of parallel workers for loading the dataset")
    parser.add_argument("-i",
                        "--install_req",
                        action='store_true',
                        help="Install required packages from requirements.txt")
    parser.add_argument(
        '-o',
        '--output',
        required=True,
        help="Output HDF5 file. To be used as input for plotting.")
    parser.add_argument(
        '-s',
        "--scores",
        default="diff",
        nargs="+",
        help=
        "Scoring method to be used. Only scoring methods selected in the model yaml file are"
        "available except for `diff` which is always available. Select scoring function by the"
        "`name` tag defined in the model yaml file.")
    parser.add_argument(
        '-k',
        "--score_kwargs",
        default="",
        nargs="+",
        help=
        "JSON definition of the kwargs for the scoring functions selected in --scores. The "
        "definiton can either be in JSON in the command line or the path of a .json file. The "
        "individual JSONs are expected to be supplied in the same order as the labels defined in "
        "--scores. If the defaults or no arguments should be used define '{}' for that respective "
        "scoring method.")
    parser.add_argument(
        '-l',
        "--seq_length",
        type=int,
        default=None,
        help=
        "Optional parameter: Model input sequence length - necessary if the model does not have a "
        "pre-defined input sequence length.")

    args = parser.parse_args(raw_args)

    # extract args for kipoi.variant_effects.predict_snvs

    dataloader_arguments = parse_json_file_str(args.dataloader_args)

    if args.output is None:
        raise Exception("Output file `--output` has to be set!")

    # --------------------------------------------
    # install args
    if args.install_req:
        kipoi.pipeline.install_model_requirements(args.model,
                                                  args.source,
                                                  and_dataloaders=True)
    # load model & dataloader
    model = kipoi.get_model(args.model, args.source)

    regions_file = os.path.realpath(args.regions_file)
    output = os.path.realpath(args.output)
    with cd(model.source_dir):
        if not os.path.exists(regions_file):
            raise Exception("Regions inputs file does not exist: %s" %
                            args.regions_file)

        # Check that all the folders exist
        file_exists(regions_file, logger)
        dir_exists(os.path.dirname(output), logger)

        if args.dataloader is not None:
            Dl = kipoi.get_dataloader_factory(args.dataloader,
                                              args.dataloader_source)
        else:
            Dl = model.default_dataloader

    if not isinstance(args.scores, list):
        args.scores = [args.scores]

    dts = get_scoring_fns(model, args.scores, args.score_kwargs)

    # Load effect prediction related model info
    model_info = kipoi.postprocessing.variant_effects.ModelInfoExtractor(
        model, Dl)
    manual_seq_len = args.seq_length

    # Select the appropriate region generator and vcf or bed file input
    args.file_format = regions_file.split(".")[-1]
    bed_region_file = None
    vcf_region_file = None
    bed_to_region = None
    vcf_to_region = None
    if args.file_format == "vcf" or regions_file.endswith("vcf.gz"):
        vcf_region_file = regions_file
        if model_info.requires_region_definition:
            # Select the SNV-centered region generator
            vcf_to_region = kipoi.postprocessing.variant_effects.SnvCenteredRg(
                model_info, seq_length=manual_seq_len)
            logger.info('Using variant-centered sequence generation.')
    elif args.file_format == "bed":
        if model_info.requires_region_definition:
            # Select the bed-overlap based region generator
            bed_to_region = kipoi.postprocessing.variant_effects.BedOverlappingRg(
                model_info, seq_length=manual_seq_len)
            logger.info('Using bed-file based sequence generation.')
        bed_region_file = regions_file
    else:
        raise Exception("Unsupported regions file format: {0}. Please provide a "
                        ".vcf(.gz) or .bed file.".format(args.file_format))

    if model_info.use_seq_only_rc:
        logger.info(
            'Model SUPPORTS simple reverse complementation of input DNA sequences.'
        )
    else:
        logger.info(
            'Model DOES NOT support simple reverse complementation of input DNA sequences.'
        )

    from kipoi.postprocessing.variant_effects.mutation_map import _generate_mutation_map
    mdmm = _generate_mutation_map(
        model,
        Dl,
        vcf_fpath=vcf_region_file,
        bed_fpath=bed_region_file,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        dataloader_args=dataloader_arguments,
        vcf_to_region=vcf_to_region,
        bed_to_region=bed_to_region,
        evaluation_function_kwargs={'diff_types': dts},
    )
    mdmm.save_to_file(output)

    logger.info('Successfully generated mutation map data')
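The command can also be invoked programmatically with a raw argument list. This is only an invocation sketch (not from the original source): the model name, file paths and dataloader arguments are hypothetical, and the positional `model` argument plus `--source` are assumed to be added by `add_model()`.

raw_args = [
    "MyModel",                                         # hypothetical model (positional, via add_model)
    "--source", "kipoi",
    "--regions_file", "variants.vcf",                  # hypothetical VCF with the variants of interest
    "--dataloader_args", '{"fasta_file": "hg38.fa"}',  # assumed dataloader kwargs
    "--scores", "diff",                                # `diff` is always available
    "--output", "mutation_map.hdf5",
]
cli_create_mutation_map("create_mutation_map", raw_args)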
Code example #27
def cli_score_variants(command, raw_args):
    """CLI interface to score variants
    """
    # Updated argument names:
    # - scoring -> scores
    # - --vcf_path -> --input_vcf, -i
    # - --out_vcf_fpath -> --output_vcf, -o
    # - --output -> -e, --extra_output
    # - removed --install_req
    # - scoring_kwargs -> score_kwargs
    AVAILABLE_FORMATS = ["tsv", "hdf5", "h5"]
    assert command == "score_variants"
    parser = argparse.ArgumentParser(
        'kipoi postproc {}'.format(command),
        description='Predict effect of SNVs using ISM.')
    parser.add_argument('model', help='Model name.', nargs="+")
    parser.add_argument(
        '--source',
        default=["kipoi"],
        nargs="+",
        choices=list(kipoi.config.model_sources().keys()),
        help='Model source to use. Specified in ~/.kipoi/config.yaml' +
        " under model_sources. " +
        "'dir' is an additional source referring to the local folder.")
    parser.add_argument(
        '--dataloader',
        nargs="+",
        default=[],
        help="Dataloader name. If not specified, the model's default" +
        "DataLoader will be used")
    parser.add_argument('--dataloader_source',
                        nargs="+",
                        default=["kipoi"],
                        help="Dataloader source")
    parser.add_argument('--dataloader_args',
                        nargs="+",
                        default=[],
                        help="Dataloader arguments either as a json string:" +
                        "'{\"arg1\": 1} or as a file path to a json file")
    parser.add_argument('-i', '--input_vcf', help='Input VCF.')
    parser.add_argument('-o',
                        '--output_vcf',
                        help='Output annotated VCF file path.',
                        default=None)
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='Batch size to use in prediction')
    parser.add_argument(
        "-n",
        "--num_workers",
        type=int,
        default=0,
        help="Number of parallel workers for loading the dataset")
    parser.add_argument(
        '-r',
        '--restriction_bed',
        default=None,
        help="Regions for prediction can only be subsets of this bed file")
    parser.add_argument(
        '-e',
        '--extra_output',
        required=False,
        help=
        "Additional output file. File format is inferred from the file path ending"
        + ". Available file formats are: {0}".format(
            ",".join(AVAILABLE_FORMATS)))
    parser.add_argument(
        '-s',
        "--scores",
        default="diff",
        nargs="+",
        help=
        "Scoring method to be used. Only scoring methods selected in the model yaml file are"
        "available except for `diff` which is always available. Select scoring function by the"
        "`name` tag defined in the model yaml file.")
    parser.add_argument(
        '-k',
        "--score_kwargs",
        default="",
        nargs="+",
        help=
        "JSON definition of the kwargs for the scoring functions selected in --scoring. The "
        "definiton can either be in JSON in the command line or the path of a .json file. The "
        "individual JSONs are expected to be supplied in the same order as the labels defined in "
        "--scoring. If the defaults or no arguments should be used define '{}' for that respective "
        "scoring method.")
    parser.add_argument(
        '-l',
        "--seq_length",
        type=int,
        nargs="+",
        default=[],
        help=
        "Optional parameter: Model input sequence length - necessary if the model does not have a "
        "pre-defined input sequence length.")
    parser.add_argument(
        '--std_var_id',
        action="store_true",
        help="If set then variant IDs in the annotated"
        " VCF will be replaced with a standardised, unique ID.")

    args = parser.parse_args(raw_args)
    # Make sure all the multi-model arguments like source, dataloader etc. fit together
    _prepare_multi_model_args(args)

    # Check that all the folders exist
    file_exists(args.input_vcf, logger)
    dir_exists(os.path.dirname(args.output_vcf), logger)
    if args.extra_output is not None:
        dir_exists(os.path.dirname(args.extra_output), logger)

        # infer the file format
        args.file_format = args.extra_output.split(".")[-1]
        if args.file_format not in AVAILABLE_FORMATS:
            logger.error("File ending: {0} for file {1} not from {2}".format(
                args.file_format, args.extra_output, AVAILABLE_FORMATS))
            sys.exit(1)

        if args.file_format in ["hdf5", "h5"]:
            # only if hdf5 output is used
            import deepdish

    if not isinstance(args.scores, list):
        args.scores = [args.scores]

    score_kwargs = []
    if len(args.score_kwargs) > 0:
        score_kwargs = args.score_kwargs
        if len(args.scores) >= 1:
            # Check if all scoring functions should be used:
            if args.scores == ["all"]:
                if len(score_kwargs) >= 1:
                    raise ValueError(
                        "`--score_kwargs` cannot be defined in combination will `--scoring all`!"
                    )
            else:
                score_kwargs = [parse_json_file_str(el) for el in score_kwargs]
                if not len(args.scores) == len(score_kwargs):
                    raise ValueError(
                        "When defining `--score_kwargs` a JSON representation of arguments (or the "
                        "path of a file containing them) must be given for every "
                        "`--scores` function.")

    keep_predictions = args.extra_output is not None

    n_models = len(args.model)

    res = {}
    for model_name, model_source, dataloader, dataloader_source, dataloader_args, seq_length in zip(
            args.model, args.source, args.dataloader, args.dataloader_source,
            args.dataloader_args, args.seq_length):
        model_name_safe = model_name.replace("/", "_")
        output_vcf_model = None
        if args.output_vcf is not None:
            output_vcf_model = args.output_vcf
            # If multiple models are to be analysed then vcfs need renaming.
            if n_models > 1:
                if output_vcf_model.endswith(".vcf"):
                    output_vcf_model = output_vcf_model[:-4]
                output_vcf_model += model_name_safe + ".vcf"

        dataloader_arguments = parse_json_file_str(dataloader_args)

        # --------------------------------------------
        # load model & dataloader
        model = kipoi.get_model(model_name, model_source)

        if dataloader is not None:
            Dl = kipoi.get_dataloader_factory(dataloader, dataloader_source)
        else:
            Dl = model.default_dataloader

        # Load effect prediction related model info
        model_info = kipoi.postprocessing.variant_effects.ModelInfoExtractor(
            model, Dl)

        if model_info.use_seq_only_rc:
            logger.info(
                'Model SUPPORTS simple reverse complementation of input DNA sequences.'
            )
        else:
            logger.info(
                'Model DOES NOT support simple reverse complementation of input DNA sequences.'
            )

        if output_vcf_model is not None:
            logger.info('Annotated VCF will be written to %s.' %
                        str(output_vcf_model))

        res[model_name_safe] = kipoi.postprocessing.variant_effects.score_variants(
            model,
            dataloader_arguments,
            args.input_vcf,
            output_vcf_model,
            scores=args.scores,
            score_kwargs=score_kwargs,
            num_workers=args.num_workers,
            batch_size=args.batch_size,
            seq_length=seq_length,
            std_var_id=args.std_var_id,
            restriction_bed=args.restriction_bed,
            return_predictions=keep_predictions)

    # tabular files
    if keep_predictions:
        if args.file_format in ["tsv"]:
            for model_name in res:
                for i, k in enumerate(res[model_name]):
                    # Remove an old file if it is still there...
                    if i == 0:
                        try:
                            os.unlink(args.extra_output)
                        except Exception:
                            pass
                    with open(args.extra_output, "w") as ofh:
                        ofh.write("KPVEP_%s:%s\n" % (k.upper(), model_name))
                        res[model_name][k].to_csv(args.extra_output,
                                                  sep="\t",
                                                  mode="a")

        if args.file_format in ["hdf5", "h5"]:
            deepdish.io.save(args.extra_output, res)

    logger.info('Successfully predicted samples')
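Since `model`, `--source`, `--dataloader_args` and related options accept several values, this older interface can score more than one model per call; with multiple models, the model name is appended to the annotated VCF file name. A minimal invocation sketch (not from the original source) with hypothetical model names and paths:

# _prepare_multi_model_args is expected to align the per-model argument lists
# (dataloader, dataloader_source, seq_length, ...).
raw_args = [
    "ModelA", "ModelB",                                        # hypothetical model names
    "--source", "kipoi", "kipoi",
    "--dataloader_args",
    '{"fasta_file": "hg38.fa"}', '{"fasta_file": "hg38.fa"}',  # assumed dataloader kwargs
    "--input_vcf", "input.vcf",
    "--output_vcf", "annotated.vcf",                           # written as annotated<Model>.vcf per model
    "--scores", "diff",
    "--extra_output", "scores.h5",                             # optional tabular/hdf5 output
]
cli_score_variants("score_variants", raw_args)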
Code example #28
File: test_011_remotes.py Project: yynst2/kipoi
def test_load_models_local():
    model = "example/models/iris_model_template"
    kipoi.get_model(model, source="dir")
    kipoi.get_dataloader_factory(model, source="dir")
Code example #29
def cli_score_variants(command, raw_args):
    """CLI interface to score variants
    """
    # Updated argument names:
    # - scoring -> scores
    # - --vcf_path -> --input_vcf, -i
    # - --out_vcf_fpath -> --output_vcf, -o
    # - --output -> -e, --extra_output
    # - removed --install_req
    # - scoring_kwargs -> score_kwargs
    AVAILABLE_FORMATS = [k for k in writers.FILE_SUFFIX_MAP if k != 'bed']
    assert command == "score_variants"
    parser = argparse.ArgumentParser(
        'kipoi veff {}'.format(command),
        description='Predict effect of SNVs using ISM.')
    parser.add_argument('model', help='Model name.')
    parser.add_argument(
        '--source',
        default="kipoi",
        choices=list(kipoi.config.model_sources().keys()),
        help='Model source to use. Specified in ~/.kipoi/config.yaml' +
        " under model_sources. " +
        "'dir' is an additional source referring to the local folder.")

    add_dataloader(parser=parser, with_args=True)

    parser.add_argument('-i', '--input_vcf', required=True, help='Input VCF.')
    parser.add_argument('-o',
                        '--output_vcf',
                        help='Output annotated VCF file path.',
                        default=None)
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='Batch size to use in prediction')
    parser.add_argument(
        "-n",
        "--num_workers",
        type=int,
        default=0,
        help="Number of parallel workers for loading the dataset")
    parser.add_argument(
        '-r',
        '--restriction_bed',
        default=None,
        help="Regions for prediction can only be subsets of this bed file")
    parser.add_argument(
        '-e',
        '--extra_output',
        type=str,
        default=None,
        required=False,
        help=
        "Additional output files in other (non-vcf) formats. File format is inferred from the file path ending"
        + ". Available file formats are: {0}".format(", ".join(
            ["." + k for k in AVAILABLE_FORMATS])))
    parser.add_argument(
        '-s',
        "--scores",
        default="diff",
        nargs="+",
        help=
        "Scoring method to be used. Only scoring methods selected in the model yaml file are"
        "available except for `diff` which is always available. Select scoring function by the"
        "`name` tag defined in the model yaml file.")
    parser.add_argument(
        '-k',
        "--score_kwargs",
        default="",
        nargs="+",
        help=
        "JSON definition of the kwargs for the scoring functions selected in --scoring. The "
        "definiton can either be in JSON in the command line or the path of a .json file. The "
        "individual JSONs are expected to be supplied in the same order as the labels defined in "
        "--scoring. If the defaults or no arguments should be used define '{}' for that respective "
        "scoring method.")
    parser.add_argument(
        '-l',
        "--seq_length",
        type=int,
        default=None,
        help=
        "Optional parameter: Model input sequence length - necessary if the model does not have a "
        "pre-defined input sequence length.")
    parser.add_argument(
        '--std_var_id',
        action="store_true",
        help="If set then variant IDs in the annotated"
        " VCF will be replaced with a standardised, unique ID.")

    parser.add_argument(
        "--model_outputs",
        type=str,
        default=None,
        nargs="+",
        help=
        "Optional parameter: Only return predictions for the selected model outputs. Naming"
        "according to the definition in model.yaml > schema > targets > column_labels"
    )

    parser.add_argument(
        "--model_outputs_i",
        type=int,
        default=None,
        nargs="+",
        help=
        "Optional parameter: Only return predictions for the selected model outputs. Give integer"
        "indices of the selected model output(s).")

    parser.add_argument(
        "--singularity",
        action='store_true',
        help="Run `kipoi predict` in the appropriate singularity container. "
        "Containters will get downloaded to ~/.kipoi/envs/ or to "
        "$SINGULARITY_CACHEDIR if set")

    args = parser.parse_args(raw_args)

    # OBSOLETE
    # Make sure all the multi-model arguments like source, dataloader etc. fit together
    #_prepare_multi_model_args(args)

    # Check that all the folders exist
    file_exists(args.input_vcf, logger)

    if args.output_vcf is None and args.extra_output is None:
        logger.error(
            "One of the two needs to be specified: --output_vcf or --extra_output"
        )
        sys.exit(1)

    if args.extra_output is not None:
        dir_exists(os.path.dirname(args.extra_output), logger)
        ending = args.extra_output.split('.')[-1]
        if ending not in AVAILABLE_FORMATS:
            logger.error("File ending: {0} for file {1} not from {2}".format(
                ending, args.extra_output, AVAILABLE_FORMATS))
            sys.exit(1)

    # singularity_command
    if args.singularity:
        from kipoi.cli.singularity import singularity_command
        logger.info(
            "Running kipoi veff {} in the singularity container".format(
                command))

        # Drop the singularity flag
        raw_args = [x for x in raw_args if x != '--singularity']

        dataloader_kwargs = parse_json_file_str_or_arglist(
            args.dataloader_args)

        # create output files
        output_files = []
        if args.output_vcf is not None:
            output_files.append(args.output_vcf)
        if args.extra_output is not None:
            output_files.append(args.extra_output)

        singularity_command(['kipoi', 'veff', command] + raw_args,
                            model=args.model,
                            dataloader_kwargs=dataloader_kwargs,
                            output_files=output_files,
                            source=args.source,
                            dry_run=False)
        return None

    if not isinstance(args.scores, list):
        args.scores = [args.scores]

    score_kwargs = []
    if len(args.score_kwargs) > 0:
        score_kwargs = args.score_kwargs
        if len(args.scores) >= 1:
            # Check if all scoring functions should be used:
            if args.scores == ["all"]:
                if len(score_kwargs) >= 1:
                    raise ValueError(
                        "`--score_kwargs` cannot be defined in combination will `--scoring all`!"
                    )
            else:
                score_kwargs = [parse_json_file_str(el) for el in score_kwargs]
                if not len(args.scores) == len(score_kwargs):
                    raise ValueError(
                        "When defining `--score_kwargs` a JSON representation of arguments (or the "
                        "path of a file containing them) must be given for every "
                        "`--scores` function.")

    # VCF writer
    output_vcf_model = None
    if args.output_vcf is not None:
        dir_exists(os.path.dirname(args.output_vcf), logger)
        output_vcf_model = args.output_vcf

    # Other writers
    if args.extra_output is not None:
        dir_exists(os.path.dirname(args.extra_output), logger)
        extra_output = args.extra_output
        writer = writers.get_writer(extra_output, metadata_schema=None)
        assert writer is not None
        extra_writers = [SyncBatchWriter(writer)]
    else:
        extra_writers = []

    dataloader_arguments = parse_json_file_str_or_arglist(args.dataloader_args)

    # --------------------------------------------
    # load model & dataloader
    model = kipoi.get_model(args.model, args.source)

    if args.dataloader is not None:
        Dl = kipoi.get_dataloader_factory(args.dataloader,
                                          args.dataloader_source)
    else:
        Dl = model.default_dataloader

    # Load effect prediction related model info
    model_info = kipoi_veff.ModelInfoExtractor(model, Dl)

    if model_info.use_seq_only_rc:
        logger.info(
            'Model SUPPORTS simple reverse complementation of input DNA sequences.'
        )
    else:
        logger.info(
            'Model DOES NOT support simple reverse complementation of input DNA sequences.'
        )

    if output_vcf_model is not None:
        logger.info('Annotated VCF will be written to %s.' %
                    str(output_vcf_model))

    model_outputs = None
    if args.model_outputs is not None:
        model_outputs = args.model_outputs

    elif args.model_outputs_i is not None:
        model_outputs = args.model_outputs_i

    kipoi_veff.score_variants(model,
                              dataloader_arguments,
                              args.input_vcf,
                              output_vcf=output_vcf_model,
                              output_writers=extra_writers,
                              scores=args.scores,
                              score_kwargs=score_kwargs,
                              num_workers=args.num_workers,
                              batch_size=args.batch_size,
                              seq_length=args.seq_length,
                              std_var_id=args.std_var_id,
                              restriction_bed=args.restriction_bed,
                              return_predictions=False,
                              model_outputs=model_outputs)

    logger.info('Successfully predicted samples')
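The same scoring can be done directly via the `kipoi_veff.score_variants` call used above. A minimal sketch (not from the original source); the model name, dataloader arguments and file paths are assumptions.

import kipoi
import kipoi_veff

# Hypothetical model and dataloader kwargs -- replace with a real
# variant-effect-enabled model and the arguments its dataloader expects.
model = kipoi.get_model("MyModel", source="kipoi")
kipoi_veff.score_variants(model,
                          {"fasta_file": "hg38.fa"},   # assumed dataloader args
                          "input.vcf",                 # VCF with the variants to score
                          output_vcf="annotated.vcf",  # annotated VCF output
                          scores=["diff"],             # `diff` is always available
                          batch_size=32,
                          num_workers=0,
                          return_predictions=False)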
Code example #30
File: cli.py Project: kipoi/kipoi-interpret
def cli_feature_importance(command, raw_args):
    """CLI interface to predict
    """
    # from .main import prepare_batch
    assert command == "feature_importance"
    parser = argparse.ArgumentParser(
        'kipoi {}'.format(command),
        description='Save importance scores and model inputs to a tsv/hdf5 file.')
    add_model(parser)
    add_dataloader(parser, with_args=True)
    parser.add_argument("--imp_score",
                        help="Importance score name",
                        choices=available_importance_scores())
    parser.add_argument("--imp_score_kwargs", help="Importance score kwargs")
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='Batch size to use in prediction')
    parser.add_argument(
        "-n",
        "--num_workers",
        type=int,
        default=0,
        help="Number of parallel workers for loading the dataset")
    # TODO - handle the reference-based importance scores...

    # io
    parser.add_argument(
        '-o',
        '--output',
        required=True,
        nargs="+",
        help=
        "Output files. File format is inferred from the file path ending. Available file formats are: "
        + ", ".join(["." + k for k in writers.FILE_SUFFIX_MAP]))
    args = parser.parse_args(raw_args)

    dataloader_kwargs = parse_json_file_str(args.dataloader_args)
    imp_score_kwargs = parse_json_file_str(args.imp_score_kwargs)

    # setup the files
    if not isinstance(args.output, list):
        args.output = [args.output]
    for o in args.output:
        ending = o.split('.')[-1]
        if ending not in writers.FILE_SUFFIX_MAP:
            logger.error("File ending: {0} for file {1} not from {2}".format(
                ending, o, writers.FILE_SUFFIX_MAP))
            sys.exit(1)
        dir_exists(os.path.dirname(o), logger)
    # --------------------------------------------
    # install args
    if args.install_req:
        kipoi.pipeline.install_model_requirements(args.model,
                                                  args.source,
                                                  and_dataloaders=True)

    # load model & dataloader
    model = kipoi.get_model(args.model,
                            args.source,
                            with_dataloader=args.dataloader is None)

    if args.dataloader is not None:
        Dl = kipoi.get_dataloader_factory(args.dataloader,
                                          args.dataloader_source)
    else:
        Dl = model.default_dataloader

    dataloader_kwargs = kipoi.pipeline.validate_kwargs(Dl, dataloader_kwargs)
    dl = Dl(**dataloader_kwargs)

    # get_importance_score
    ImpScore = get_importance_score(args.imp_score)
    if not ImpScore.is_compatible(model):
        raise ValueError("model not compatible with score: {0}".format(
            args.imp_score))
    impscore = ImpScore(model, **imp_score_kwargs)

    # setup batching
    it = dl.batch_iter(batch_size=args.batch_size,
                       num_workers=args.num_workers)

    # Setup the writers
    use_writers = []
    for output in args.output:
        ending = output.split('.')[-1]
        W = writers.FILE_SUFFIX_MAP[ending]
        logger.info("Using {0} for file {1}".format(W.__name__, output))
        if ending == "tsv":
            assert W == writers.TsvBatchWriter
            use_writers.append(
                writers.TsvBatchWriter(file_path=output, nested_sep="/"))
        elif ending == "bed":
            raise Exception("Please use tsv or hdf5 output format.")
        elif ending in ["hdf5", "h5"]:
            assert W == writers.HDF5BatchWriter
            use_writers.append(writers.HDF5BatchWriter(file_path=output))
        else:
            logger.error("Unknown file format: {0}".format(ending))
            sys.exit(1)

    # Loop through the data, make predictions, save the output
    for i, batch in enumerate(tqdm(it)):
        # validate the data schema in the first iteration
        if i == 0 and not Dl.output_schema.compatible_with_batch(batch):
            logger.warning(
                "First batch of data is not compatible with the dataloader schema."
            )

        # make the prediction
        # TODO - handle the reference-based importance scores...
        importance_scores = impscore.score(batch['inputs'])

        # write out the predictions, metadata (, inputs, targets)
        # always keep the inputs so that input*grad can be generated!
        # output_batch = prepare_batch(batch, pred_batch, keep_inputs=True)
        output_batch = batch
        output_batch["importance_scores"] = importance_scores
        for writer in use_writers:
            writer.batch_write(output_batch)

    for writer in use_writers:
        writer.close()
    logger.info('Done! Importance scores stored in {0}'.format(",".join(
        args.output)))
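The importance-score machinery can also be used directly from Python. The sketch below (not from the original source) assumes the same module-level imports as the CLI above (`get_importance_score`, `available_importance_scores`); the model name, score name and dataloader kwargs are hypothetical.

import kipoi

print(available_importance_scores())               # names accepted by --imp_score

# Hypothetical model and dataloader kwargs.
model = kipoi.get_model("MyModel", source="kipoi")
dl = model.default_dataloader(fasta_file="hg38.fa")

ImpScore = get_importance_score("my_score")        # hypothetical score name
if not ImpScore.is_compatible(model):
    raise ValueError("model not compatible with the selected importance score")
impscore = ImpScore(model)

batch = next(dl.batch_iter(batch_size=4))
importance_scores = impscore.score(batch["inputs"])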