"--num_workers", type=int, default=0, help="Number of parallel workers for loading the dataset") parser.add_argument('-p', '--patience', default=10, type=int, help='Early stopping patience') parser.add_argument( '--gpu', default=-1, type=int, help='Which gpu to use. If -1, determine automatically') args = parser.parse_args() dl_kwargs_train = parse_json_file_str(args.dl_kwargs_train) dl_kwargs_eval = parse_json_file_str(args.dl_kwargs_eval) if args.add_n_hidden == "": hidden = [] else: hidden = [int(x) for x in args.add_n_hidden.split(",")] # ------- odir = Path(args.output) odir.mkdir(parents=True, exist_ok=True) if args.gpu == -1: gpu = GPUtil.getFirstAvailable(attempts=3, includeNan=True)[0] else: gpu = args.gpu create_tf_session(gpu)
def cli_create_mutation_map(command, raw_args):
    """CLI interface to calculate mutation map data
    """
    assert command == "create_mutation_map"
    parser = argparse.ArgumentParser(
        'kipoi postproc {}'.format(command),
        description='Predict effect of SNVs using ISM.')
    add_model(parser)
    add_dataloader(parser, with_args=True)
    parser.add_argument('-r', '--regions_file',
                        help='Region definition as VCF or bed file. Not a required input.')
    # TODO - rename path to fpath
    parser.add_argument('--batch_size', type=int, default=32,
                        help='Batch size to use in prediction')
    parser.add_argument("-n", "--num_workers", type=int, default=0,
                        help="Number of parallel workers for loading the dataset")
    parser.add_argument("-i", "--install_req", action='store_true',
                        help="Install required packages from requirements.txt")
    parser.add_argument('-o', '--output', required=True,
                        help="Output HDF5 file. To be used as input for plotting.")
    parser.add_argument('-s', "--scores", default="diff", nargs="+",
                        help="Scoring method to be used. Only scoring methods selected in the model yaml file are "
                             "available except for `diff` which is always available. Select scoring function by the "
                             "`name` tag defined in the model yaml file.")
    parser.add_argument('-k', "--score_kwargs", default="", nargs="+",
                        help="JSON definition of the kwargs for the scoring functions selected in --scores. The "
                             "definition can either be given as JSON on the command line or as the path of a .json file. The "
                             "individual JSONs are expected to be supplied in the same order as the labels defined in "
                             "--scores. If the defaults or no arguments should be used define '{}' for that respective "
                             "scoring method.")
    parser.add_argument('-l', "--seq_length", type=int, default=None,
                        help="Optional parameter: Model input sequence length - necessary if the model does not have a "
                             "pre-defined input sequence length.")

    args = parser.parse_args(raw_args)

    # extract args for kipoi.variant_effects.predict_snvs
    dataloader_arguments = parse_json_file_str(args.dataloader_args)

    if args.output is None:
        raise Exception("Output file `--output` has to be set!")

    # --------------------------------------------
    # install args
    if args.install_req:
        kipoi.pipeline.install_model_requirements(args.model,
                                                   args.source,
                                                   and_dataloaders=True)
    # load model & dataloader
    model = kipoi.get_model(args.model, args.source)

    regions_file = os.path.realpath(args.regions_file)
    output = os.path.realpath(args.output)
    with cd(model.source_dir):
        if not os.path.exists(regions_file):
            raise Exception("Regions input file does not exist: %s" % args.regions_file)

    # Check that all the folders exist
    file_exists(regions_file, logger)
    dir_exists(os.path.dirname(output), logger)
    if args.dataloader is not None:
        Dl = kipoi.get_dataloader_factory(args.dataloader, args.dataloader_source)
    else:
        Dl = model.default_dataloader

    if not isinstance(args.scores, list):
        args.scores = [args.scores]

    dts = get_scoring_fns(model, args.scores, args.score_kwargs)

    # Load effect prediction related model info
    model_info = kipoi.postprocessing.variant_effects.ModelInfoExtractor(model, Dl)
    manual_seq_len = args.seq_length

    # Select the appropriate region generator and vcf or bed file input
    args.file_format = regions_file.split(".")[-1]
    bed_region_file = None
    vcf_region_file = None
    bed_to_region = None
    vcf_to_region = None
    if args.file_format == "vcf" or regions_file.endswith("vcf.gz"):
        vcf_region_file = regions_file
        if model_info.requires_region_definition:
            # Select the SNV-centered region generator
            vcf_to_region = kipoi.postprocessing.variant_effects.SnvCenteredRg(
                model_info, seq_length=manual_seq_len)
            logger.info('Using variant-centered sequence generation.')
    elif args.file_format == "bed":
        if model_info.requires_region_definition:
            # Select the bed-file based region generator
            bed_to_region = kipoi.postprocessing.variant_effects.BedOverlappingRg(
                model_info, seq_length=manual_seq_len)
            logger.info('Using bed-file based sequence generation.')
        bed_region_file = regions_file
    else:
        raise Exception("Unsupported file format for --regions_file: {0}. "
                        "Use a VCF or bed file.".format(args.file_format))

    if model_info.use_seq_only_rc:
        logger.info('Model SUPPORTS simple reverse complementation of input DNA sequences.')
    else:
        logger.info('Model DOES NOT support simple reverse complementation of input DNA sequences.')

    from kipoi.postprocessing.variant_effects.mutation_map import _generate_mutation_map
    mdmm = _generate_mutation_map(model,
                                  Dl,
                                  vcf_fpath=vcf_region_file,
                                  bed_fpath=bed_region_file,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers,
                                  dataloader_args=dataloader_arguments,
                                  vcf_to_region=vcf_to_region,
                                  bed_to_region=bed_to_region,
                                  evaluation_function_kwargs={'diff_types': dts})
    mdmm.save_to_file(output)

    logger.info('Successfully generated mutation map data')
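# Illustrative usage sketch for the command defined above. The model name, FASTA path and
# output file names are placeholders, not part of this module; the flags mirror the parser
# built in cli_create_mutation_map:
#
#   kipoi postproc create_mutation_map DeepSEA/variantEffects \
#       --dataloader_args '{"fasta_file": "hg38.fa"}' \
#       -r variants.vcf \
#       -s diff \
#       -o mutation_map.hdf5
#
# The resulting HDF5 file is the expected input for the mutation-map plotting step.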
def cli_score_variants(command, raw_args):
    """CLI interface to score variants
    """
    # Updated argument names:
    # - scoring -> scores
    # - --vcf_path -> --input_vcf, -i
    # - --out_vcf_fpath -> --output_vcf, -o
    # - --output -> -e, --extra_output
    # - removed --install_req
    # - scoring_kwargs -> score_kwargs
    AVAILABLE_FORMATS = ["tsv", "hdf5", "h5"]
    assert command == "score_variants"
    parser = argparse.ArgumentParser(
        'kipoi postproc {}'.format(command),
        description='Predict effect of SNVs using ISM.')
    parser.add_argument('model', help='Model name.', nargs="+")
    parser.add_argument('--source', default=["kipoi"], nargs="+",
                        choices=list(kipoi.config.model_sources().keys()),
                        help='Model source to use. Specified in ~/.kipoi/config.yaml '
                             "under model_sources. "
                             "'dir' is an additional source referring to the local folder.")
    parser.add_argument('--dataloader', nargs="+", default=[],
                        help="Dataloader name. If not specified, the model's default "
                             "DataLoader will be used")
    parser.add_argument('--dataloader_source', nargs="+", default=["kipoi"],
                        help="Dataloader source")
    parser.add_argument('--dataloader_args', nargs="+", default=[],
                        help="Dataloader arguments either as a json string: "
                             "'{\"arg1\": 1}' or as a file path to a json file")
    parser.add_argument('-i', '--input_vcf', help='Input VCF.')
    parser.add_argument('-o', '--output_vcf', default=None,
                        help='Output annotated VCF file path.')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='Batch size to use in prediction')
    parser.add_argument("-n", "--num_workers", type=int, default=0,
                        help="Number of parallel workers for loading the dataset")
    parser.add_argument('-r', '--restriction_bed', default=None,
                        help="Regions for prediction can only be subsets of this bed file")
    parser.add_argument('-e', '--extra_output', required=False,
                        help="Additional output file. File format is inferred from the file path ending. "
                             "Available file formats are: {0}".format(",".join(AVAILABLE_FORMATS)))
    parser.add_argument('-s', "--scores", default="diff", nargs="+",
                        help="Scoring method to be used. Only scoring methods selected in the model yaml file are "
                             "available except for `diff` which is always available. Select scoring function by the "
                             "`name` tag defined in the model yaml file.")
    parser.add_argument('-k', "--score_kwargs", default="", nargs="+",
                        help="JSON definition of the kwargs for the scoring functions selected in --scores. The "
                             "definition can either be given as JSON on the command line or as the path of a .json file. The "
                             "individual JSONs are expected to be supplied in the same order as the labels defined in "
                             "--scores. If the defaults or no arguments should be used define '{}' for that respective "
                             "scoring method.")
    parser.add_argument('-l', "--seq_length", type=int, nargs="+", default=[],
                        help="Optional parameter: Model input sequence length - necessary if the model does not have a "
                             "pre-defined input sequence length.")
    parser.add_argument('--std_var_id', action="store_true",
                        help="If set then variant IDs in the annotated "
                             "VCF will be replaced with a standardised, unique ID.")

    args = parser.parse_args(raw_args)

    # Make sure all the multi-model arguments like source, dataloader etc. fit together
    _prepare_multi_model_args(args)

    # Check that all the folders exist
    file_exists(args.input_vcf, logger)
    dir_exists(os.path.dirname(args.output_vcf), logger)
    if args.extra_output is not None:
        dir_exists(os.path.dirname(args.extra_output), logger)

        # infer the file format
        args.file_format = args.extra_output.split(".")[-1]
        if args.file_format not in AVAILABLE_FORMATS:
            logger.error("File ending: {0} for file {1} not from {2}".format(
                args.file_format, args.extra_output, AVAILABLE_FORMATS))
            sys.exit(1)

        if args.file_format in ["hdf5", "h5"]:
            # only if hdf5 output is used
            import deepdish

    if not isinstance(args.scores, list):
        args.scores = [args.scores]

    score_kwargs = []
    if len(args.score_kwargs) > 0:
        score_kwargs = args.score_kwargs
        if len(args.scores) >= 1:
            # Check if all scoring functions should be used:
            if args.scores == ["all"]:
                if len(score_kwargs) >= 1:
                    raise ValueError("`--score_kwargs` cannot be defined in combination with `--scores all`!")
            else:
                score_kwargs = [parse_json_file_str(el) for el in score_kwargs]
                if not len(args.score_kwargs) == len(score_kwargs):
                    raise ValueError("When defining `--score_kwargs` a JSON representation of arguments (or the "
                                     "path of a file containing them) must be given for every "
                                     "`--scores` function.")

    keep_predictions = args.extra_output is not None

    n_models = len(args.model)

    res = {}
    for model_name, model_source, dataloader, dataloader_source, dataloader_args, seq_length in zip(
            args.model, args.source, args.dataloader, args.dataloader_source,
            args.dataloader_args, args.seq_length):
        model_name_safe = model_name.replace("/", "_")
        output_vcf_model = None
        if args.output_vcf is not None:
            output_vcf_model = args.output_vcf
            # If multiple models are to be analysed then vcfs need renaming.
            if n_models > 1:
                if output_vcf_model.endswith(".vcf"):
                    output_vcf_model = output_vcf_model[:-4]
                output_vcf_model += model_name_safe + ".vcf"

        dataloader_arguments = parse_json_file_str(dataloader_args)

        # --------------------------------------------
        # load model & dataloader
        model = kipoi.get_model(model_name, model_source)

        if dataloader is not None:
            Dl = kipoi.get_dataloader_factory(dataloader, dataloader_source)
        else:
            Dl = model.default_dataloader

        # Load effect prediction related model info
        model_info = kipoi.postprocessing.variant_effects.ModelInfoExtractor(model, Dl)

        if model_info.use_seq_only_rc:
            logger.info('Model SUPPORTS simple reverse complementation of input DNA sequences.')
        else:
            logger.info('Model DOES NOT support simple reverse complementation of input DNA sequences.')

        if output_vcf_model is not None:
            logger.info('Annotated VCF will be written to %s.' % str(output_vcf_model))

        res[model_name_safe] = kipoi.postprocessing.variant_effects.score_variants(
            model,
            dataloader_arguments,
            args.input_vcf,
            output_vcf_model,
            scores=args.scores,
            score_kwargs=score_kwargs,
            num_workers=args.num_workers,
            batch_size=args.batch_size,
            seq_length=seq_length,
            std_var_id=args.std_var_id,
            restriction_bed=args.restriction_bed,
            return_predictions=keep_predictions)

    # tabular files
    if keep_predictions:
        if args.file_format in ["tsv"]:
            for model_name in res:
                for i, k in enumerate(res[model_name]):
                    # Remove an old file if it is still there...
                    if i == 0:
                        try:
                            os.unlink(args.extra_output)
                        except Exception:
                            pass
                    with open(args.extra_output, "w") as ofh:
                        ofh.write("KPVEP_%s:%s\n" % (k.upper(), model_name))
                    res[model_name][k].to_csv(args.extra_output, sep="\t", mode="a")

        if args.file_format in ["hdf5", "h5"]:
            deepdish.io.save(args.extra_output, res)

    logger.info('Successfully predicted samples')
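# Illustrative invocation of the multi-model variant scoring command above. Model names and
# file paths are placeholders; --dataloader_args, --seq_length etc. take one value per model
# (see _prepare_multi_model_args) when several models are scored at once:
#
#   kipoi postproc score_variants ModelA ModelB \
#       --dataloader_args '{"fasta_file": "hg38.fa"}' '{"fasta_file": "hg38.fa"}' \
#       -i input.vcf \
#       -o annotated.vcf \
#       -s diff \
#       -e extra_output.tsv
#
# With more than one model the annotated VCF is split into one file per model, with the
# sanitised model name appended to the file name, as implemented in the renaming logic above.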
def cli_grad(command, raw_args):
    """CLI interface to predict
    """
    from .main import prepare_batch
    from kipoi.model import GradientMixin
    assert command == "grad"
    from tqdm import tqdm
    parser = argparse.ArgumentParser(
        'kipoi {}'.format(command),
        description='Save gradients and inputs to a hdf5 file.')
    add_model(parser)
    add_dataloader(parser, with_args=True)
    parser.add_argument('--batch_size', type=int, default=32,
                        help='Batch size to use in prediction')
    parser.add_argument("-n", "--num_workers", type=int, default=0,
                        help="Number of parallel workers for loading the dataset")
    parser.add_argument("-i", "--install_req", action='store_true',
                        help="Install required packages from requirements.txt")
    parser.add_argument("-l", "--layer", default=None, required=False,
                        help="Which output layer to use to make the predictions. If specified, "
                             "`model.predict_activation_on_batch` will be invoked instead of `model.predict_on_batch`")
    parser.add_argument("--final_layer", action='store_true',
                        help="Alternatively to `--layer` this flag can be used to indicate that the last layer should "
                             "be used.")
    parser.add_argument("--pre_nonlinearity", action='store_true',
                        help="Flag indicating that it should be checked whether the selected output is post activation "
                             "function. If a non-linear activation function is used attempt to use its input. This "
                             "feature is not available for all models.")
    parser.add_argument("-f", "--filter_idx", default=None,
                        help="Filter index that should be inspected with gradients. If not set all filters will "
                             "be used.")
    parser.add_argument("-a", "--avg_func", choices=GradientMixin.allowed_functions, default="sum",
                        help="Averaging function to be applied across selected filters (`--filter_idx`) in "
                             "layer `--layer`.")
    parser.add_argument('--selected_fwd_node', default=None, type=int,
                        help="If the selected layer has multiple inbound connections in "
                             "the graph then those can be selected here with an integer "
                             "index. Not necessarily supported by all models.")
    parser.add_argument('-o', '--output', required=True, nargs="+",
                        help="Output files. File format is inferred from the file path ending. Available file formats are: " +
                             ", ".join(["." + k for k in writers.FILE_SUFFIX_MAP]))
    args = parser.parse_args(raw_args)

    dataloader_kwargs = parse_json_file_str(args.dataloader_args)

    # setup the files
    if not isinstance(args.output, list):
        args.output = [args.output]
    for o in args.output:
        ending = o.split('.')[-1]
        if ending not in writers.FILE_SUFFIX_MAP:
            logger.error("File ending: {0} for file {1} not from {2}".format(
                ending, o, writers.FILE_SUFFIX_MAP))
            sys.exit(1)
        dir_exists(os.path.dirname(o), logger)
    # --------------------------------------------
    # install args
    if args.install_req:
        kipoi.pipeline.install_model_requirements(args.model,
                                                   args.source,
                                                   and_dataloaders=True)

    layer = args.layer
    if layer is None and not args.final_layer:
        raise Exception("A layer has to be selected explicitly using `--layer` or implicitly by using the "
                        "`--final_layer` flag.")

    # Not a good idea
    # if layer is not None and isint(layer):
    #     logger.warn("Interpreting `--layer` value as integer layer index!")
    #     layer = int(args.layer)

    # load model & dataloader
    model = kipoi.get_model(args.model, args.source)

    if not isinstance(model, GradientMixin):
        raise Exception("Model does not support gradient calculation.")

    if args.dataloader is not None:
        Dl = kipoi.get_dataloader_factory(args.dataloader, args.dataloader_source)
    else:
        Dl = model.default_dataloader

    dataloader_kwargs = kipoi.pipeline.validate_kwargs(Dl, dataloader_kwargs)
    dl = Dl(**dataloader_kwargs)

    filter_idx_parsed = None
    if args.filter_idx is not None:
        filter_idx_parsed = parse_filter_slice(args.filter_idx)

    # setup batching
    it = dl.batch_iter(batch_size=args.batch_size,
                       num_workers=args.num_workers)

    # Setup the writers
    use_writers = []
    for output in args.output:
        ending = output.split('.')[-1]
        W = writers.FILE_SUFFIX_MAP[ending]
        logger.info("Using {0} for file {1}".format(W.__name__, output))
        if ending == "tsv":
            assert W == writers.TsvBatchWriter
            use_writers.append(writers.TsvBatchWriter(file_path=output, nested_sep="/"))
        elif ending == "bed":
            raise Exception("Please use tsv or hdf5 output format.")
        elif ending in ["hdf5", "h5"]:
            assert W == writers.HDF5BatchWriter
            use_writers.append(writers.HDF5BatchWriter(file_path=output))
        else:
            logger.error("Unknown file format: {0}".format(ending))
            sys.exit(1)

    # Loop through the data, make predictions, save the output
    for i, batch in enumerate(tqdm(it)):
        # validate the data schema in the first iteration
        if i == 0 and not Dl.output_schema.compatible_with_batch(batch):
            logger.warn("First batch of data is not compatible with the dataloader schema.")

        # make the prediction
        pred_batch = model.input_grad(batch['inputs'],
                                      filter_idx=filter_idx_parsed,
                                      avg_func=args.avg_func,
                                      layer=layer,
                                      final_layer=args.final_layer,
                                      selected_fwd_node=args.selected_fwd_node,
                                      pre_nonlinearity=args.pre_nonlinearity)

        # write out the predictions, metadata (, inputs, targets)
        # always keep the inputs so that input*grad can be generated!
        # output_batch = prepare_batch(batch, pred_batch, keep_inputs=True)
        output_batch = batch
        output_batch["grads"] = pred_batch
        for writer in use_writers:
            writer.batch_write(output_batch)

    for writer in use_writers:
        writer.close()
    logger.info('Done! Gradients stored in {0}'.format(",".join(args.output)))
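# Illustrative invocation of the gradient command above (model name and file paths are
# placeholders). Either --layer or --final_layer must be given, otherwise an exception is raised:
#
#   kipoi grad DeepSEA/variantEffects \
#       --dataloader_args '{"fasta_file": "hg38.fa", "intervals_file": "intervals.bed"}' \
#       --final_layer \
#       -o grads.hdf5
#
# Gradients are stored under the "grads" key alongside the dataloader batch contents.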
help="Number of workers used to load the data") parser.add_argument("--tf", required=True, help="Transcription factor to benchmark") parser.add_argument("--output", "-o", required=True, help="Transcription factor to benchmark") args = parser.parse_args() model = kipoi.get_model(args.model) print("Obtaining a batch of data, using {} workers".format( args.num_workers)) dl_kwargs = kipoi.pipeline.validate_kwargs( model.default_dataloader, parse_json_file_str(args.dl_kwargs)) print("Used kwargs: {}".format(dl_kwargs)) dl = model.default_dataloader(**dl_kwargs) # batch = numpy_collate([dl[0]]*args.batch_size) it = dl.batch_iter(args.batch_size, num_workers=args.num_workers) batch = next(it) print("Measuring the forward time pass") times = [] for i in range(args.num_runs): start_time = time.time() model.predict_on_batch(batch['inputs']) duration = time.time() - start_time times.append(duration) print("Writing results to a json file")
def cli_score_variants(command, raw_args):
    """CLI interface to predict
    """
    AVAILABLE_FORMATS = ["tsv", "hdf5", "h5"]
    import pybedtools
    assert command == "score_variants"
    parser = argparse.ArgumentParser(
        'kipoi postproc {}'.format(command),
        description='Predict effect of SNVs using ISM.')
    add_model(parser)
    add_dataloader(parser, with_args=True)
    parser.add_argument('-v', '--vcf_path', help='Input VCF.')
    # TODO - rename path to fpath
    parser.add_argument('-a', '--out_vcf_fpath', default=None,
                        help='Output annotated VCF file path.')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='Batch size to use in prediction')
    parser.add_argument("-n", "--num_workers", type=int, default=0,
                        help="Number of parallel workers for loading the dataset")
    parser.add_argument("-i", "--install_req", action='store_true',
                        help="Install required packages from requirements.txt")
    parser.add_argument('-r', '--restriction_bed', default=None,
                        help="Regions for prediction can only be subsets of this bed file")
    parser.add_argument('-o', '--output', required=False,
                        help="Additional output file. File format is inferred from the file path ending. "
                             "Available file formats are: {0}".format(",".join(AVAILABLE_FORMATS)))
    parser.add_argument('-s', "--scoring", default="diff", nargs="+",
                        help="Scoring method to be used. Only scoring methods selected in the model yaml file are "
                             "available except for `diff` which is always available. Select scoring function by the "
                             "`name` tag defined in the model yaml file.")
    parser.add_argument('-k', "--scoring_kwargs", default="", nargs="+",
                        help="JSON definition of the kwargs for the scoring functions selected in --scoring. The "
                             "definition can either be given as JSON on the command line or as the path of a .json file. The "
                             "individual JSONs are expected to be supplied in the same order as the labels defined in "
                             "--scoring. If the defaults or no arguments should be used define '{}' for that respective "
                             "scoring method.")
    args = parser.parse_args(raw_args)

    # extract args for kipoi.variant_effects.predict_snvs
    vcf_path = args.vcf_path
    out_vcf_fpath = args.out_vcf_fpath
    dataloader_arguments = parse_json_file_str(args.dataloader_args)

    # infer the file format
    args.file_format = args.output.split(".")[-1]
    if args.file_format not in AVAILABLE_FORMATS:
        logger.error("File ending: {0} for file {1} not from {2}".format(
            args.file_format, args.output, AVAILABLE_FORMATS))
        sys.exit(1)

    if args.file_format in ["hdf5", "h5"]:
        # only if hdf5 output is used
        import deepdish

    # Check that all the folders exist
    file_exists(args.vcf_path, logger)
    dir_exists(os.path.dirname(args.out_vcf_fpath), logger)
    if args.output is not None:
        dir_exists(os.path.dirname(args.output), logger)
    # --------------------------------------------
    # install args
    if args.install_req:
        kipoi.pipeline.install_model_requirements(args.model,
                                                   args.source,
                                                   and_dataloaders=True)
    # load model & dataloader
    model = kipoi.get_model(args.model, args.source)

    if args.dataloader is not None:
        Dl = kipoi.get_dataloader_factory(args.dataloader, args.dataloader_source)
    else:
        Dl = model.default_dataloader

    if not os.path.exists(vcf_path):
        raise Exception("VCF file does not exist: %s" % vcf_path)

    if not isinstance(args.scoring, list):
        args.scoring = [args.scoring]

    dts = _get_scoring_fns(model, args.scoring, args.scoring_kwargs)

    # Load effect prediction related model info
    model_info = kipoi.postprocessing.variant_effects.ModelInfoExtractor(model, Dl)

    # Select the appropriate region generator
    if args.restriction_bed is not None:
        # Select the restricted SNV-centered region generator
        pbd = pybedtools.BedTool(args.restriction_bed)
        vcf_to_region = kipoi.postprocessing.variant_effects.SnvPosRestrictedRg(model_info, pbd)
        logger.info('Restriction bed file defined. Only variants in the defined regions will be tested.')
    elif model_info.requires_region_definition:
        # Select the SNV-centered region generator
        vcf_to_region = kipoi.postprocessing.variant_effects.SnvCenteredRg(model_info)
        logger.info('Using variant-centered sequence generation.')
    else:
        # No regions can be defined for the given model, VCF overlap will be inferred,
        # hence a tabixed VCF is necessary
        vcf_to_region = None
        # Make sure that the vcf is tabixed
        vcf_path = kipoi.postprocessing.variant_effects.ensure_tabixed_vcf(vcf_path)
        logger.info('Dataloader does not accept definition of a regions bed-file. Only VCF variants that lie within '
                    'produced regions can be predicted.')

    if model_info.use_seq_only_rc:
        logger.info('Model SUPPORTS simple reverse complementation of input DNA sequences.')
    else:
        logger.info('Model DOES NOT support simple reverse complementation of input DNA sequences.')

    # Get a vcf output writer if needed
    if out_vcf_fpath is not None:
        logger.info('Annotated VCF will be written to %s.' % str(out_vcf_fpath))
        vcf_writer = kipoi.postprocessing.variant_effects.VcfWriter(model, vcf_path, out_vcf_fpath)
    else:
        vcf_writer = None

    keep_predictions = args.output is not None

    res = kipoi.postprocessing.variant_effects.predict_snvs(
        model,
        Dl,
        vcf_path,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        dataloader_args=dataloader_arguments,
        vcf_to_region=vcf_to_region,
        evaluation_function_kwargs={"diff_types": dts},
        sync_pred_writer=vcf_writer,
        return_predictions=keep_predictions)

    # tabular files
    if args.output is not None:
        if args.file_format in ["tsv"]:
            for i, k in enumerate(res):
                # Remove an old file if it is still there...
                if i == 0:
                    try:
                        os.unlink(args.output)
                    except Exception:
                        pass
                with open(args.output, "w") as ofh:
                    ofh.write("KPVEP_%s\n" % k.upper())
                res[k].to_csv(args.output, sep="\t", mode="a")

        if args.file_format in ["hdf5", "h5"]:
            deepdish.io.save(args.output, res)

    logger.info('Successfully predicted samples')
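# A minimal programmatic sketch of the pipeline this CLI wires together, reusing only the calls
# that appear in the function above (model name, VCF and FASTA paths are placeholders):
#
#   model = kipoi.get_model("DeepSEA/variantEffects")
#   Dl = model.default_dataloader
#   dts = _get_scoring_fns(model, ["diff"], "")
#   model_info = kipoi.postprocessing.variant_effects.ModelInfoExtractor(model, Dl)
#   vcf_to_region = kipoi.postprocessing.variant_effects.SnvCenteredRg(model_info)
#   res = kipoi.postprocessing.variant_effects.predict_snvs(
#       model, Dl, "input.vcf",
#       batch_size=32,
#       dataloader_args={"fasta_file": "hg38.fa"},
#       vcf_to_region=vcf_to_region,
#       evaluation_function_kwargs={"diff_types": dts},
#       return_predictions=True)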
def _get_scoring_fns(model, sel_scoring_labels, sel_scoring_kwargs):
    # get the scoring methods according to the model definition
    avail_scoring_fns, avail_scoring_fn_def_args, avail_scoring_fn_names, \
        default_scoring_fns = get_avail_scoring_methods(model)

    errmsg_scoring_kwargs = "When defining `--scoring_kwargs` a JSON representation of arguments (or the path of a" \
                            " file containing them) must be given for every `--scoring` function."

    dts = {}
    if len(sel_scoring_labels) >= 1:
        # Check if all scoring functions should be used:
        if sel_scoring_labels == ["all"]:
            if len(sel_scoring_kwargs) >= 1:
                raise ValueError("`--scoring_kwargs` cannot be defined in combination with `--scoring all`!")
            for arg_iter, k in enumerate(avail_scoring_fn_names):
                si = avail_scoring_fn_names.index(k)
                # get the default kwargs
                kwargs = avail_scoring_fn_def_args[si]
                if kwargs is None:
                    raise ValueError("No default kwargs for scoring function: %s"
                                     " `--scoring all` cannot be used. "
                                     "Please also define `--scoring_kwargs`." % (k))
                # instantiate the scoring fn
                dts[k] = avail_scoring_fns[si](**kwargs)
        else:
            # if -k is set, check that its length matches -s
            if len(sel_scoring_kwargs) >= 1:
                if not len(sel_scoring_labels) == len(sel_scoring_kwargs):
                    raise ValueError(errmsg_scoring_kwargs)
            for arg_iter, k in enumerate(sel_scoring_labels):
                # if -s is set, check that the label is available for the model
                if k in avail_scoring_fn_names:
                    si = avail_scoring_fn_names.index(k)
                    # get the default kwargs
                    kwargs = avail_scoring_fn_def_args[si]
                    # if the user has set scoring function kwargs then load them here.
                    if len(sel_scoring_kwargs) >= 1:
                        # all the {}s in -k are replaced by their defaults; if the default is None
                        # raise an exception with the corresponding scoring function label etc.
                        defined_kwargs = parse_json_file_str(sel_scoring_kwargs[si])
                        if len(defined_kwargs) != 0:
                            kwargs = defined_kwargs
                    if kwargs is None:
                        raise ValueError("No kwargs were given for scoring function %s"
                                         " with no defaults but required arguments. "
                                         "Please also define `--scoring_kwargs`." % (k))
                    # instantiate the scoring fn
                    dts[k] = avail_scoring_fns[si](**kwargs)
                else:
                    logger.warn("Cannot choose scoring function %s. "
                                "Model only supports: %s." % (k, str(avail_scoring_fn_names)))
    # if -s is not set use all defaults
    elif len(default_scoring_fns) != 0:
        for arg_iter, k in enumerate(default_scoring_fns):
            si = avail_scoring_fn_names.index(k)
            kwargs = avail_scoring_fn_def_args[si]
            dts[k] = avail_scoring_fns[si](**kwargs)

    if len(dts) == 0:
        raise Exception("No scoring method was chosen!")

    return dts
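# How the helper above pairs -s/--scoring with -k/--scoring_kwargs. The scoring function names
# and the kwarg below are placeholders; the real names come from the model yaml:
#
#   -s diff my_score -k '{}' '{"example_kwarg": 1}'
#
# '{}' keeps the default kwargs for `diff`, while the other JSON is handed to `my_score`;
# per the help text, the JSONs are expected in the same order as the labels given to -s.
# Supplying -k with a different number of entries than -s raises a ValueError.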
def cli_feature_importance(command, raw_args):
    """CLI interface to predict
    """
    # from .main import prepare_batch
    from tqdm import tqdm
    assert command == "feature_importance"
    parser = argparse.ArgumentParser(
        'kipoi {}'.format(command),
        description='Save gradients and inputs to a hdf5 file.')
    add_model(parser)
    add_dataloader(parser, with_args=True)
    parser.add_argument("--imp_score", help="Importance score name",
                        choices=available_importance_scores())
    parser.add_argument("--imp_score_kwargs", help="Importance score kwargs")
    parser.add_argument('--batch_size', type=int, default=32,
                        help='Batch size to use in prediction')
    parser.add_argument("-n", "--num_workers", type=int, default=0,
                        help="Number of parallel workers for loading the dataset")
    # needed because args.install_req is used below
    parser.add_argument("-i", "--install_req", action='store_true',
                        help="Install required packages from requirements.txt")

    # TODO - handle the reference-based importance scores...

    # io
    parser.add_argument('-o', '--output', required=True, nargs="+",
                        help="Output files. File format is inferred from the file path ending. Available file formats are: " +
                             ", ".join(["." + k for k in writers.FILE_SUFFIX_MAP]))

    args = parser.parse_args(raw_args)
    dataloader_kwargs = parse_json_file_str(args.dataloader_args)
    imp_score_kwargs = parse_json_file_str(args.imp_score_kwargs)

    # setup the files
    if not isinstance(args.output, list):
        args.output = [args.output]
    for o in args.output:
        ending = o.split('.')[-1]
        if ending not in writers.FILE_SUFFIX_MAP:
            logger.error("File ending: {0} for file {1} not from {2}".format(
                ending, o, writers.FILE_SUFFIX_MAP))
            sys.exit(1)
        dir_exists(os.path.dirname(o), logger)
    # --------------------------------------------
    # install args
    if args.install_req:
        kipoi.pipeline.install_model_requirements(args.model,
                                                   args.source,
                                                   and_dataloaders=True)
    # load model & dataloader
    model = kipoi.get_model(args.model, args.source,
                            with_dataloader=args.dataloader is None)

    if args.dataloader is not None:
        Dl = kipoi.get_dataloader_factory(args.dataloader, args.dataloader_source)
    else:
        Dl = model.default_dataloader

    dataloader_kwargs = kipoi.pipeline.validate_kwargs(Dl, dataloader_kwargs)
    dl = Dl(**dataloader_kwargs)

    # get_importance_score
    ImpScore = get_importance_score(args.imp_score)
    if not ImpScore.is_compatible(model):
        raise ValueError("model not compatible with score: {0}".format(args.imp_score))
    impscore = ImpScore(model, **imp_score_kwargs)

    # setup batching
    it = dl.batch_iter(batch_size=args.batch_size,
                       num_workers=args.num_workers)

    # Setup the writers
    use_writers = []
    for output in args.output:
        ending = output.split('.')[-1]
        W = writers.FILE_SUFFIX_MAP[ending]
        logger.info("Using {0} for file {1}".format(W.__name__, output))
        if ending == "tsv":
            assert W == writers.TsvBatchWriter
            use_writers.append(writers.TsvBatchWriter(file_path=output, nested_sep="/"))
        elif ending == "bed":
            raise Exception("Please use tsv or hdf5 output format.")
        elif ending in ["hdf5", "h5"]:
            assert W == writers.HDF5BatchWriter
            use_writers.append(writers.HDF5BatchWriter(file_path=output))
        else:
            logger.error("Unknown file format: {0}".format(ending))
            sys.exit(1)

    # Loop through the data, make predictions, save the output
    for i, batch in enumerate(tqdm(it)):
        # validate the data schema in the first iteration
        if i == 0 and not Dl.output_schema.compatible_with_batch(batch):
            logger.warn("First batch of data is not compatible with the dataloader schema.")

        # make the prediction
        # TODO - handle the reference-based importance scores...
        importance_scores = impscore.score(batch['inputs'])

        # write out the predictions, metadata (, inputs, targets)
        # always keep the inputs so that input*grad can be generated!
        # output_batch = prepare_batch(batch, pred_batch, keep_inputs=True)
        output_batch = batch
        output_batch["importance_scores"] = importance_scores
        for writer in use_writers:
            writer.batch_write(output_batch)

    for writer in use_writers:
        writer.close()
    logger.info('Done! Importance scores stored in {0}'.format(",".join(args.output)))
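# Illustrative invocation of the feature importance command above. The model name, score name
# and file paths are placeholders; valid --imp_score values come from available_importance_scores():
#
#   kipoi feature_importance DeepSEA/variantEffects \
#       --dataloader_args '{"fasta_file": "hg38.fa", "intervals_file": "intervals.bed"}' \
#       --imp_score <score_name> \
#       -o importance_scores.hdf5
#
# Scores are written under the "importance_scores" key alongside the dataloader batch contents.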