def main(args): print("main") ht = hl.read_table( f'{temp_dir}/ddd-elgh-ukbb/variant_qc/Sanger_table_for_RF_by_variant_type.ht' ) run_hash = str(uuid.uuid4())[:8] rf_runs = get_rf_runs(f'{tmp_dir}/ddd-elgh-ukbb/') while run_hash in rf_runs: run_hash = str(uuid.uuid4())[:8] ht_result, rf_model = train_rf(ht, args) print("Writing out ht_training data") ht_result = ht_result.checkpoint( f'{tmp_dir}/ddd-elgh-ukbb/Sanger_RF_training_data.ht', overwrite=True) rf_runs[run_hash] = get_run_data( vqsr_training=False, transmitted_singletons=True, test_intervals=args.test_intervals, adj=False, features_importance=hl.eval(ht_result.features_importance), test_results=hl.eval(ht_result.test_results), ) with hl.hadoop_open(f'{plot_dir}/ddd-elgh-ukbb/variant_qc/rf_runs.json', "w") as f: json.dump(rf_runs, f) logger.info("Saving RF model") save_model(rf_model, f'{tmp_dir}/ddd-elgh-ukbb/rf_model.model')
def main(args): print("main") ht = hl.read_table( f'{temp_dir}/ddd-elgh-ukbb/variant_qc/Sanger_table_for_RF_by_variant_type.ht' ) if args.train_rf: run_hash = str(uuid.uuid4())[:8] rf_runs = get_rf_runs(f'{tmp_dir}/ddd-elgh-ukbb/') while run_hash in rf_runs: run_hash = str(uuid.uuid4())[:8] ht_result, rf_model = train_rf(ht, args) print("Writing out ht_training data") ht_result = ht_result.checkpoint(get_rf(data="training", run_hash=run_hash).path, overwrite=True) # f'{tmp_dir}/ddd-elgh-ukbb/Sanger_RF_training_data.ht', overwrite=True) rf_runs[run_hash] = get_run_data( vqsr_training=False, transmitted_singletons=True, test_intervals=args.test_intervals, adj=False, features_importance=hl.eval(ht_result.features_importance), test_results=hl.eval(ht_result.test_results), ) with hl.hadoop_open( f'{plot_dir}/ddd-elgh-ukbb/variant_qc/rf_runs.json', "w") as f: json.dump(rf_runs, f) logger.info("Saving RF model") save_model(rf_model, get_rf(data="model", run_hash=run_hash), overwrite=True) # f'{tmp_dir}/ddd-elgh-ukbb/rf_model.model') else: run_hash = args.run_hash if args.apply_rf: logger.info(f"Applying RF model {run_hash}...") rf_model = load_model(get_rf(data="model", run_hash=run_hash)) ht = get_rf(data="training", run_hash=run_hash).ht() features = hl.eval(ht.features) ht = apply_rf_model(ht, rf_model, features, label=LABEL_COL) logger.info("Finished applying RF model") ht = ht.annotate_globals(rf_hash=run_hash) ht = ht.checkpoint( get_rf("rf_result", run_hash=run_hash).path, overwrite=True, ) ht_summary = ht.group_by("tp", "fp", TRAIN_COL, LABEL_COL, PREDICTION_COL).aggregate(n=hl.agg.count()) ht_summary.show(n=20)
def main(args):
    hl.init(log="/variant_qc_random_forest.log")

    if args.list_rf_runs:
        logger.info("RF runs:")
        pretty_print_runs(get_rf_runs(rf_run_path()))

    if args.annotate_for_rf:
        ht = create_rf_ht(
            impute_features=args.impute_features,
            adj=args.adj,
            n_partitions=args.n_partitions,
            checkpoint_path=get_checkpoint_path("rf_annotation"),
        )
        ht.write(
            get_rf_annotations(args.adj).path,
            overwrite=args.overwrite,
        )
        logger.info("Completed annotation wrangling for random forests model training")

    if args.train_rf:
        # Get a unique model id for this run, retrying on collisions.
        model_id = f"rf_{str(uuid.uuid4())[:8]}"
        rf_runs = get_rf_runs(rf_run_path())
        while model_id in rf_runs:
            model_id = f"rf_{str(uuid.uuid4())[:8]}"

        ht, rf_model = train_rf(
            get_rf_annotations(args.adj).ht(),
            fp_to_tp=args.fp_to_tp,
            num_trees=args.num_trees,
            max_depth=args.max_depth,
            no_transmitted_singletons=args.no_transmitted_singletons,
            no_inbreeding_coeff=args.no_inbreeding_coeff,
            vqsr_training=args.vqsr_training,
            vqsr_model_id=args.vqsr_model_id,
            filter_centromere_telomere=args.filter_centromere_telomere,
            test_intervals=args.test_intervals,
        )
        ht = ht.checkpoint(
            get_rf_training(model_id=model_id).path,
            overwrite=args.overwrite,
        )

        logger.info("Adding run to RF run list")
        rf_runs[model_id] = get_run_data(
            input_args={
                "transmitted_singletons": None
                if args.vqsr_training
                else not args.no_transmitted_singletons,
                "adj": args.adj,
                "vqsr_training": args.vqsr_training,
                "filter_centromere_telomere": args.filter_centromere_telomere,
            },
            test_intervals=args.test_intervals,
            features_importance=hl.eval(ht.features_importance),
            test_results=hl.eval(ht.test_results),
        )
        with hl.hadoop_open(rf_run_path(), "w") as f:
            json.dump(rf_runs, f)

        logger.info("Saving RF model")
        save_model(
            rf_model,
            get_rf_model_path(model_id=model_id),
            overwrite=args.overwrite,
        )
    else:
        model_id = args.model_id

    if args.apply_rf:
        logger.info(f"Applying RF model {model_id}...")
        rf_model = load_model(get_rf_model_path(model_id=model_id))
        ht = get_rf_training(model_id=model_id).ht()
        features = hl.eval(ht.features)
        ht = apply_rf_model(ht, rf_model, features, label=LABEL_COL)
        logger.info("Finished applying RF model")

        ht = ht.annotate_globals(rf_model_id=model_id)
        ht = ht.checkpoint(
            get_rf_result(model_id=model_id).path,
            overwrite=args.overwrite,
        )
        ht_summary = ht.group_by(
            "tp", "fp", TRAIN_COL, LABEL_COL, PREDICTION_COL
        ).aggregate(n=hl.agg.count())
        ht_summary.show(n=20)
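# --- Usage sketch (illustrative only) ---
# A minimal argparse wiring for the main() above, assuming flag names that
# mirror the attributes it reads (args.list_rf_runs, args.train_rf, etc.).
# The repo's actual parser, defaults, and help text may differ; treat every
# flag and default below as a hypothetical sketch, not the canonical CLI.
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--list_rf_runs", action="store_true")
    parser.add_argument("--annotate_for_rf", action="store_true")
    parser.add_argument("--train_rf", action="store_true")
    parser.add_argument("--apply_rf", action="store_true")
    parser.add_argument("--model_id", help="RF model id to apply (when not training)")
    parser.add_argument("--adj", action="store_true")
    parser.add_argument("--impute_features", action="store_true")
    parser.add_argument("--n_partitions", type=int, default=5000)  # assumed default
    parser.add_argument("--fp_to_tp", type=float, default=1.0)     # assumed default
    parser.add_argument("--num_trees", type=int, default=500)      # assumed default
    parser.add_argument("--max_depth", type=int, default=5)        # assumed default
    parser.add_argument("--no_transmitted_singletons", action="store_true")
    parser.add_argument("--no_inbreeding_coeff", action="store_true")
    parser.add_argument("--vqsr_training", action="store_true")
    parser.add_argument("--vqsr_model_id")
    parser.add_argument("--filter_centromere_telomere", action="store_true")
    parser.add_argument("--test_intervals", default=None)
    parser.add_argument("--overwrite", action="store_true")
    main(parser.parse_args())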
def main(args): print("importing main table") ht = hl.read_table( f'{nfs_dir}/hail_data/variant_qc/chd_ukbb.table_for_RF_by_variant_type_all_cols.ht' ) if args.train_rf: # ht = hl.read_table( # f'{temp_dir}/ddd-elgh-ukbb/variant_qc/Sanger_table_for_RF_by_variant_type.ht') run_hash = str(uuid.uuid4())[:8] rf_runs = get_rf_runs(f'{tmp_dir}/rf_runs.json') while run_hash in rf_runs: run_hash = str(uuid.uuid4())[:8] ht_result, rf_model = train_rf(ht, args) print("Writing out ht_training data") ht_result = ht_result.checkpoint(get_rf(data="training", run_hash=run_hash).path, overwrite=True) # f'{tmp_dir}/ddd-elgh-ukbb/Sanger_RF_training_data.ht', overwrite=True) rf_runs[run_hash] = get_run_data( vqsr_training=False, transmitted_singletons=True, test_intervals=args.test_intervals, adj=True, features_importance=hl.eval(ht_result.features_importance), test_results=hl.eval(ht_result.test_results), ) with hl.hadoop_open(f'{tmp_dir}/rf_runs.json', "w") as f: json.dump(rf_runs, f) pretty_print_runs(rf_runs) logger.info("Saving RF model") save_model(rf_model, get_rf(data="model", run_hash=run_hash), overwrite=True) # f'{tmp_dir}/ddd-elgh-ukbb/rf_model.model') else: run_hash = args.run_hash if args.apply_rf: logger.info(f"Applying RF model {run_hash}...") rf_model = load_model(get_rf(data="model", run_hash=run_hash)) ht = get_rf(data="training", run_hash=run_hash).ht() features = hl.eval(ht.features) ht = apply_rf_model(ht, rf_model, features, label=LABEL_COL) logger.info("Finished applying RF model") ht = ht.annotate_globals(rf_hash=run_hash) ht = ht.checkpoint( get_rf("rf_result_chd_ukbb", run_hash=run_hash).path, overwrite=True, ) ht_summary = ht.group_by("tp", "fp", TRAIN_COL, LABEL_COL, PREDICTION_COL).aggregate(n=hl.agg.count()) ht_summary.show(n=20) if args.finalize: # TODO: Adjust this step to run on the CHD-UKBB cohort run_hash = args.run_hash ht = hl.read_table( f'{tmp_dir}/variant_qc/models/{run_hash}/rf_result_ac_added.ht') # ht = create_grouped_bin_ht( # model_id=run_hash, overwrite=True) freq_ht = hl.read_table( f'{tmp_dir}/variant_qc/mt_sampleQC_FILTERED_FREQ_adj.ht') freq = freq_ht[ht.key] print("created bin ht") ht = generate_final_rf_ht( ht, ac0_filter_expr=freq.freq[0].AC == 0, ts_ac_filter_expr=freq.freq[1].AC == 1, mono_allelic_fiter_expr=(freq.freq[1].AF == 1) | (freq.freq[1].AF == 0), snp_cutoff=args.snp_cutoff, indel_cutoff=args.indel_cutoff, determine_cutoff_from_bin=False, aggregated_bin_ht=bin_ht, bin_id=bin_ht.bin, inbreeding_coeff_cutoff=INBREEDING_COEFF_HARD_CUTOFF, ) # This column is added by the RF module based on a 0.5 threshold which doesn't correspond to what we use # ht = ht.drop(ht[PREDICTION_COL]) ht.write(f'{tmp_dir}/rf_final.ht', overwrite=True)
def main(args): print("main") ht = hl.read_table( f'{temp_dir}/ddd-elgh-ukbb/variant_qc/Sanger_table_for_RF_by_variant_type.ht' ) if args.train_rf: run_hash = str(uuid.uuid4())[:8] rf_runs = get_rf_runs(f'{tmp_dir}/ddd-elgh-ukbb/') while run_hash in rf_runs: run_hash = str(uuid.uuid4())[:8] ht_result, rf_model = train_rf(ht, args) print("Writing out ht_training data") ht_result = ht_result.checkpoint(get_rf(data="training", run_hash=run_hash).path, overwrite=True) # f'{tmp_dir}/ddd-elgh-ukbb/Sanger_RF_training_data.ht', overwrite=True) rf_runs[run_hash] = get_run_data( vqsr_training=False, transmitted_singletons=True, test_intervals=args.test_intervals, adj=False, features_importance=hl.eval(ht_result.features_importance), test_results=hl.eval(ht_result.test_results), ) with hl.hadoop_open( f'{plot_dir}/ddd-elgh-ukbb/variant_qc/rf_runs.json', "w") as f: json.dump(rf_runs, f) pretty_print_runs(rf_runs) logger.info("Saving RF model") save_model(rf_model, get_rf(data="model", run_hash=run_hash), overwrite=True) # f'{tmp_dir}/ddd-elgh-ukbb/rf_model.model') else: run_hash = args.run_hash if args.apply_rf: logger.info(f"Applying RF model {run_hash}...") #rf_model = load_model(get_rf(data="model", run_hash=run_hash)) run_hash = args.run_hash rf_model = hl.read_table( f'{temp_dir}/ddd-elgh-ukbb/variant_qc/models/{run_hash}/model.model' ) # ht = hl.read_table( # f'{temp_dir}/ddd-elgh-ukbb/variant_qc/Sanger_cohorts_chr1-20-XY_sampleQC_FILTERED_FREQ_adj_inb.ht') ht = hl.read_table( f'{temp_dir}/ddd-elgh-ukbb/variant_qc/Sanger_cohorts_for_RF_unfiltered.ht' ) ht = ht.annotate(rf_label=rf_model[ht.key].rf_label) #ht = get_rf(data="training", run_hash=run_hash).ht() features = hl.eval(rf_model.features) ht = apply_rf_model(ht, rf_model, features, label=LABEL_COL) logger.info("Finished applying RF model") ht = ht.annotate_globals(rf_hash=run_hash) ht = ht.checkpoint( get_rf("rf_result_sanger_cohorts", run_hash=run_hash).path, overwrite=True, ) ht_summary = ht.group_by("tp", "fp", TRAIN_COL, LABEL_COL, PREDICTION_COL).aggregate(n=hl.agg.count()) ht_summary.show(n=20) if args.finalize: run_hash = args.run_hash ht = hl.read_table( f'{temp_dir}/ddd-elgh-ukbb/variant_qc/models/{run_hash}/rf_result_ac_added.ht' ) # ht = create_grouped_bin_ht( # model_id=run_hash, overwrite=True) freq_ht = hl.read_table( f'{temp_dir}/ddd-elgh-ukbb/variant_qc/Sanger_cohorts_chr1-20-XY_sampleQC_FILTERED_FREQ_adj.ht' ) freq = freq_ht[ht.key] bin_ht = create_quantile_bin_ht(run_hash, n_bins=100, vqsr=False, overwrite=True) # if not file_exists( # get_score_quantile_bins(args.run_hash, aggregated=True).path # ): # sys.exit( # f"Could not find binned HT for RF run {args.run_hash} (). Please run create_ranked_scores.py for that hash." # ) #aggregated_bin_ht = get_score_quantile_bins(ht, aggregated=True) print("created bin ht") ht = generate_final_rf_ht( ht, ac0_filter_expr=freq.freq[0].AC == 0, ts_ac_filter_expr=freq.freq[1].AC == 1, mono_allelic_fiter_expr=(freq.freq[1].AF == 1) | (freq.freq[1].AF == 0), snp_cutoff=args.snp_cutoff, indel_cutoff=args.indel_cutoff, determine_cutoff_from_bin=False, aggregated_bin_ht=bin_ht, bin_id=bin_ht.bin, inbreeding_coeff_cutoff=INBREEDING_COEFF_HARD_CUTOFF, ) # This column is added by the RF module based on a 0.5 threshold which doesn't correspond to what we use #ht = ht.drop(ht[PREDICTION_COL]) ht.write(f'{tmp_dir}/rf_final.ht', overwrite=True)
def train_rf(data_type, args):
    # Get unique hash for run and load previous runs
    run_hash = str(uuid.uuid4())[:8]
    rf_runs = get_rf_runs(data_type)
    while run_hash in rf_runs:
        run_hash = str(uuid.uuid4())[:8]

    ht = hl.read_table(rf_annotated_path(data_type, args.adj))
    ht = ht.repartition(500, shuffle=False)

    # Build the true-positive label unless training on VQSR labels.
    if not args.vqsr_training:
        if args.no_transmitted_singletons:
            tp_expr = ht.omni | ht.mills | ht.info_POSITIVE_TRAIN_SITE
        else:
            tp_expr = (
                ht.omni
                | ht.mills
                | ht.info_POSITIVE_TRAIN_SITE
                | ht.transmitted_singleton
            )
        ht = ht.annotate(tp=tp_expr)

    test_intervals_str = (
        []
        if not args.test_intervals
        else [args.test_intervals]
        if isinstance(args.test_intervals, str)
        else args.test_intervals
    )
    test_intervals_locus = [hl.parse_locus_interval(x) for x in test_intervals_str]
    if test_intervals_locus:
        ht = ht.annotate_globals(test_intervals=test_intervals_locus)

    ht = sample_rf_training_examples(
        ht,
        tp_col='info_POSITIVE_TRAIN_SITE' if args.vqsr_training else 'tp',
        fp_col='info_NEGATIVE_TRAIN_SITE' if args.vqsr_training else 'fail_hard_filters',
        fp_to_tp=args.fp_to_tp)
    ht = ht.persist()
    rf_ht = ht.filter(ht[TRAIN_COL])

    rf_features = get_features_list(
        True,
        not (args.vqsr_features or args.median_features),
        args.vqsr_features,
        args.median_features)

    logger.info(
        "Training RF model:\nfeatures: {}\nnum_tree: {}\nmax_depth: {}\nTest intervals: {}"
        .format(
            ",".join(rf_features),
            args.num_trees,
            args.max_depth,
            ",".join(test_intervals_str)))
    rf_model = random_forest.train_rf(
        rf_ht,
        features=rf_features,
        label=LABEL_COL,
        num_trees=args.num_trees,
        max_depth=args.max_depth)

    logger.info("Saving RF model")
    random_forest.save_model(
        rf_model,
        rf_path(data_type, data='model', run_hash=run_hash),
        overwrite=args.overwrite)

    test_results = None
    if args.test_intervals:
        logger.info("Testing model {} on intervals {}".format(
            run_hash, ",".join(test_intervals_str)))
        test_ht = hl.filter_intervals(ht, test_intervals_locus, keep=True)
        test_ht = test_ht.checkpoint(
            'gs://gnomad-tmp/test_random_forest.ht', overwrite=True)
        test_ht = test_ht.filter(hl.is_defined(test_ht[LABEL_COL]))
        test_results = random_forest.test_model(
            test_ht,
            rf_model,
            features=get_features_list(True, not args.vqsr_features, args.vqsr_features),
            label=LABEL_COL)
        ht = ht.annotate_globals(test_results=test_results)

    logger.info("Writing RF training HT")
    features_importance = random_forest.get_features_importance(rf_model)
    ht = ht.annotate_globals(
        features_importance=features_importance,
        features=get_features_list(True, not args.vqsr_features, args.vqsr_features),
        vqsr_training=args.vqsr_training,
        no_transmitted_singletons=args.no_transmitted_singletons,
        adj=args.adj)
    ht.write(
        rf_path(data_type, data='training', run_hash=run_hash),
        overwrite=args.overwrite)

    logger.info("Adding run to RF run list")
    rf_runs[run_hash] = get_run_data(
        data_type,
        args.vqsr_training,
        args.no_transmitted_singletons,
        args.adj,
        test_intervals_str,
        features_importance,
        test_results)
    with hl.hadoop_open(rf_run_hash_path(data_type), 'w') as f:
        json.dump(rf_runs, f)

    return run_hash
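# --- Usage sketch (illustrative only) ---
# train_rf() above expects an argparse-style namespace; a minimal sketch,
# assuming types.SimpleNamespace stands in for parsed CLI args. The field
# names mirror the attributes the function reads; the values and the
# "genomes" data type are hypothetical.
from types import SimpleNamespace

example_args = SimpleNamespace(
    adj=True,                        # use adj-filtered annotations
    vqsr_training=False,             # label with omni/mills/TS, not VQSR sites
    no_transmitted_singletons=False, # keep transmitted singletons as TPs
    vqsr_features=False,
    median_features=False,
    fp_to_tp=1.0,                    # FP:TP sampling ratio for training
    num_trees=500,
    max_depth=5,
    test_intervals="20",             # interval(s) held out for model testing
    overwrite=True,
)
# run_hash = train_rf("genomes", example_args)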