Exemplo n.º 1
0
def main(args):
    """Train and/or apply the Sanger random-forest variant QC model.

    Always reads the Sanger per-variant-type RF annotation table first. Then:

    * if ``args.train_rf``: pick an 8-character run hash that is not already
      in the run registry, train via ``train_rf(ht, args)``, checkpoint the
      training table, add this run's metadata to the registry JSON and save
      the model;
    * otherwise: take the run hash from ``args.run_hash``;
    * if ``args.apply_rf``: load the model and training table stored for that
      run hash, apply the model, tag the result with an ``rf_hash`` global,
      checkpoint it and print counts grouped by tp/fp/train/label/prediction.

    Args:
        args: Parsed command-line namespace. Read here: ``train_rf``,
            ``apply_rf``, ``run_hash``, ``test_intervals`` (``train_rf`` may
            read more).
    """
    print("main")
    ht = hl.read_table(
        f'{temp_dir}/ddd-elgh-ukbb/variant_qc/Sanger_table_for_RF_by_variant_type.ht'
    )

    if args.train_rf:
        # Read and write the run registry through the SAME file. Earlier runs
        # are then loaded before a new hash is chosen, and they are preserved
        # when the file is rewritten below.
        rf_runs_path = f'{plot_dir}/ddd-elgh-ukbb/variant_qc/rf_runs.json'

        run_hash = str(uuid.uuid4())[:8]
        rf_runs = get_rf_runs(rf_runs_path)
        # Regenerate until the hash does not collide with a recorded run.
        while run_hash in rf_runs:
            run_hash = str(uuid.uuid4())[:8]

        ht_result, rf_model = train_rf(ht, args)
        print("Writing out ht_training data")
        ht_result = ht_result.checkpoint(
            get_rf(data="training", run_hash=run_hash).path,
            overwrite=True,
        )
        rf_runs[run_hash] = get_run_data(
            vqsr_training=False,
            transmitted_singletons=True,
            test_intervals=args.test_intervals,
            adj=False,
            features_importance=hl.eval(ht_result.features_importance),
            test_results=hl.eval(ht_result.test_results),
        )

        with hl.hadoop_open(rf_runs_path, "w") as f:
            json.dump(rf_runs, f)

        logger.info("Saving RF model")
        save_model(rf_model,
                   get_rf(data="model", run_hash=run_hash),
                   overwrite=True)
    else:
        run_hash = args.run_hash

    if args.apply_rf:

        logger.info(f"Applying RF model {run_hash}...")
        rf_model = load_model(get_rf(data="model", run_hash=run_hash))
        # The model is applied to the stored training table for this run,
        # using the feature list recorded in that table's globals.
        ht = get_rf(data="training", run_hash=run_hash).ht()
        features = hl.eval(ht.features)
        ht = apply_rf_model(ht, rf_model, features, label=LABEL_COL)

        logger.info("Finished applying RF model")
        ht = ht.annotate_globals(rf_hash=run_hash)
        ht = ht.checkpoint(
            get_rf("rf_result", run_hash=run_hash).path,
            overwrite=True,
        )

        ht_summary = ht.group_by("tp", "fp", TRAIN_COL, LABEL_COL,
                                 PREDICTION_COL).aggregate(n=hl.agg.count())
        ht_summary.show(n=20)
Exemplo n.º 2
0
def main(args):
    """Entry point for the random-forest variant QC pipeline.

    Each stage is gated by a flag on ``args``, and the stages run in order:

    1. ``list_rf_runs`` -- log and pretty-print the stored RF run registry.
    2. ``annotate_for_rf`` -- build the RF annotation table and write it.
    3. ``train_rf`` -- train a model under a fresh ``rf_XXXXXXXX`` id that is
       not already in the registry, checkpoint its training table, record
       the run in the registry JSON and save the model. Without this flag
       the model id is taken from ``args.model_id``.
    4. ``apply_rf`` -- apply that model to its training table, checkpoint
       the result and print grouped counts.

    Table/model writes honour ``args.overwrite``.
    """
    hl.init(log="/variant_qc_random_forest.log")

    if args.list_rf_runs:
        logger.info("RF runs:")
        pretty_print_runs(get_rf_runs(rf_run_path()))

    if args.annotate_for_rf:
        annotation_ht = create_rf_ht(
            impute_features=args.impute_features,
            adj=args.adj,
            n_partitions=args.n_partitions,
            checkpoint_path=get_checkpoint_path("rf_annotation"),
        )
        annotation_path = get_rf_annotations(args.adj).path
        annotation_ht.write(annotation_path, overwrite=args.overwrite)
        logger.info(
            "Completed annotation wrangling for random forests model training"
        )

    if not args.train_rf:
        model_id = args.model_id
    else:
        def fresh_model_id():
            # "rf_" followed by the first 8 characters of a random UUID4.
            return "rf_" + str(uuid.uuid4())[:8]

        model_id = fresh_model_id()
        known_runs = get_rf_runs(rf_run_path())
        # Keep drawing until the id does not clash with a recorded run.
        while model_id in known_runs:
            model_id = fresh_model_id()

        annotations = get_rf_annotations(args.adj).ht()
        training_ht, model = train_rf(
            annotations,
            fp_to_tp=args.fp_to_tp,
            num_trees=args.num_trees,
            max_depth=args.max_depth,
            no_transmitted_singletons=args.no_transmitted_singletons,
            no_inbreeding_coeff=args.no_inbreeding_coeff,
            vqsr_training=args.vqsr_training,
            vqsr_model_id=args.vqsr_model_id,
            filter_centromere_telomere=args.filter_centromere_telomere,
            test_intervals=args.test_intervals,
        )

        training_path = get_rf_training(model_id=model_id).path
        training_ht = training_ht.checkpoint(
            training_path, overwrite=args.overwrite
        )

        logger.info("Adding run to RF run list")
        # Recorded as None when training on VQSR labels.
        if args.vqsr_training:
            transmitted_singletons = None
        else:
            transmitted_singletons = not args.no_transmitted_singletons
        run_inputs = {
            "transmitted_singletons": transmitted_singletons,
            "adj": args.adj,
            "vqsr_training": args.vqsr_training,
            "filter_centromere_telomere": args.filter_centromere_telomere,
        }
        known_runs[model_id] = get_run_data(
            input_args=run_inputs,
            test_intervals=args.test_intervals,
            features_importance=hl.eval(training_ht.features_importance),
            test_results=hl.eval(training_ht.test_results),
        )

        with hl.hadoop_open(rf_run_path(), "w") as registry:
            json.dump(known_runs, registry)

        logger.info("Saving RF model")
        model_path = get_rf_model_path(model_id=model_id)
        save_model(model, model_path, overwrite=args.overwrite)

    if args.apply_rf:
        logger.info(f"Applying RF model {model_id}...")
        loaded_model = load_model(get_rf_model_path(model_id=model_id))
        source_ht = get_rf_training(model_id=model_id).ht()
        # Apply with the feature list recorded in the training table globals.
        result_ht = apply_rf_model(
            source_ht, loaded_model, hl.eval(source_ht.features), label=LABEL_COL
        )

        logger.info("Finished applying RF model")
        result_ht = result_ht.annotate_globals(rf_model_id=model_id).checkpoint(
            get_rf_result(model_id=model_id).path, overwrite=args.overwrite
        )

        summary_ht = result_ht.group_by(
            "tp", "fp", TRAIN_COL, LABEL_COL, PREDICTION_COL
        ).aggregate(n=hl.agg.count())
        summary_ht.show(n=20)
Exemplo n.º 3
0
def main(args):
    """Train, apply and/or finalize the CHD-UKBB random-forest variant QC model.

    Always reads the CHD-UKBB per-variant-type annotation table first. Then:

    * ``args.train_rf``: train a new model under an unused 8-character run
      hash, checkpoint its training table, record the run in
      ``{tmp_dir}/rf_runs.json``, pretty-print the registry and save the
      model; otherwise the run hash is taken from ``args.run_hash``.
    * ``args.apply_rf``: apply that run's model to its training table,
      checkpoint the result and print grouped counts.
    * ``args.finalize``: read ``rf_result_ac_added.ht`` for ``args.run_hash``,
      pass it to ``generate_final_rf_ht`` together with frequency-based filter
      expressions (from the adj frequency table) and the SNP/indel cutoffs,
      and write ``{tmp_dir}/rf_final.ht``.

    Args:
        args: Parsed command-line namespace. Read here: ``train_rf``,
            ``apply_rf``, ``finalize``, ``run_hash``, ``test_intervals``,
            ``snp_cutoff``, ``indel_cutoff`` (``train_rf`` may read more).
    """
    print("importing main table")
    ht = hl.read_table(
        f'{nfs_dir}/hail_data/variant_qc/chd_ukbb.table_for_RF_by_variant_type_all_cols.ht'
    )

    if args.train_rf:
        run_hash = str(uuid.uuid4())[:8]
        rf_runs = get_rf_runs(f'{tmp_dir}/rf_runs.json')
        # Regenerate until the hash does not collide with a recorded run.
        while run_hash in rf_runs:
            run_hash = str(uuid.uuid4())[:8]

        ht_result, rf_model = train_rf(ht, args)
        print("Writing out ht_training data")
        ht_result = ht_result.checkpoint(get_rf(data="training",
                                                run_hash=run_hash).path,
                                         overwrite=True)
        rf_runs[run_hash] = get_run_data(
            vqsr_training=False,
            transmitted_singletons=True,
            test_intervals=args.test_intervals,
            adj=True,
            features_importance=hl.eval(ht_result.features_importance),
            test_results=hl.eval(ht_result.test_results),
        )

        with hl.hadoop_open(f'{tmp_dir}/rf_runs.json', "w") as f:
            json.dump(rf_runs, f)
        pretty_print_runs(rf_runs)
        logger.info("Saving RF model")
        save_model(rf_model,
                   get_rf(data="model", run_hash=run_hash),
                   overwrite=True)
    else:
        run_hash = args.run_hash

    if args.apply_rf:

        logger.info(f"Applying RF model {run_hash}...")
        rf_model = load_model(get_rf(data="model", run_hash=run_hash))

        # Apply to the stored training table, using its recorded feature list.
        ht = get_rf(data="training", run_hash=run_hash).ht()
        features = hl.eval(ht.features)
        ht = apply_rf_model(ht, rf_model, features, label=LABEL_COL)
        logger.info("Finished applying RF model")
        ht = ht.annotate_globals(rf_hash=run_hash)
        ht = ht.checkpoint(
            get_rf("rf_result_chd_ukbb", run_hash=run_hash).path,
            overwrite=True,
        )

        ht_summary = ht.group_by("tp", "fp", TRAIN_COL, LABEL_COL,
                                 PREDICTION_COL).aggregate(n=hl.agg.count())
        ht_summary.show(n=20)

    if args.finalize:

        # TODO: Adjust this step to run on the CHD-UKBB cohort

        # NOTE(review): this uses args.run_hash even when a model was trained
        # above in the same invocation; the input table is produced elsewhere.
        run_hash = args.run_hash
        ht = hl.read_table(
            f'{tmp_dir}/variant_qc/models/{run_hash}/rf_result_ac_added.ht')
        freq_ht = hl.read_table(
            f'{tmp_dir}/variant_qc/mt_sampleQC_FILTERED_FREQ_adj.ht')
        freq = freq_ht[ht.key]

        print("created bin ht")

        # The previous call passed ``aggregated_bin_ht=bin_ht`` and
        # ``bin_id=bin_ht.bin``, but ``bin_ht`` is never defined here (its
        # creation is disabled), which raised NameError. With
        # determine_cutoff_from_bin=False those optional arguments are left
        # at their defaults -- presumably unused in that mode; confirm against
        # generate_final_rf_ht.
        ht = generate_final_rf_ht(
            ht,
            ac0_filter_expr=freq.freq[0].AC == 0,
            ts_ac_filter_expr=freq.freq[1].AC == 1,
            # Keyword spelling must match the callee's parameter name.
            mono_allelic_fiter_expr=(freq.freq[1].AF == 1) |
            (freq.freq[1].AF == 0),
            snp_cutoff=args.snp_cutoff,
            indel_cutoff=args.indel_cutoff,
            determine_cutoff_from_bin=False,
            inbreeding_coeff_cutoff=INBREEDING_COEFF_HARD_CUTOFF,
        )
        # NOTE: PREDICTION_COL is added by the RF module based on a 0.5
        # threshold, which doesn't correspond to the cutoffs used here; it is
        # currently NOT dropped from the output.
        ht.write(f'{tmp_dir}/rf_final.ht', overwrite=True)
Exemplo n.º 4
0
def main(args):
    """Run the selected random-forest variant QC steps for exomes or genomes.

    ``args.exomes`` selects the data type ('exomes' if set, else 'genomes').
    Steps, each gated by its flag and run in this order:

    * ``list_rf_runs`` -- pretty-print the stored RF runs for the data type.
    * ``annotate_for_rf`` -- build the RF annotation table and write it.
    * ``train_rf`` -- train a model; its return value becomes the run hash.
      Without this flag the run hash comes from ``args.run_hash``.
    * ``apply_rf`` -- load the model and training table for the run hash,
      apply the model, add rank annotations when the needed fields are
      present, and write the RF result table.
    * ``finalize`` -- build the final RF annotation table and write it.

    Args:
        args: Parsed command-line namespace (flags listed above plus
            ``debug``, ``overwrite``, ``adj``, ``vqsr_features``, cutoff and
            annotation options).
    """
    hl.init(log='/variantqc.log')

    data_type = 'exomes' if args.exomes else 'genomes'

    if args.debug:
        logger.setLevel(logging.DEBUG)

    if args.list_rf_runs:
        logger.info(f"RF runs for {data_type}:")
        random_forest.pretty_print_runs(get_rf_runs(data_type))

    if args.annotate_for_rf:
        ht = create_rf_ht(
            data_type,
            n_variants_median=args.n_variants_median,
            impute_features_by_variant_type=not args.impute_features_no_variant_type,
            group='adj' if args.adj else 'raw',
        )
        ht.write(rf_annotated_path(data_type, args.adj),
                 overwrite=args.overwrite)

    # Training returns the new run's hash; otherwise use the one supplied.
    run_hash = train_rf(data_type, args) if args.train_rf else args.run_hash

    if args.apply_rf:
        logger.info(f"Applying RF model {run_hash} to {data_type}.")

        rf_model = random_forest.load_model(
            rf_path(data_type, data='model', run_hash=run_hash))
        ht = hl.read_table(
            rf_path(data_type, data='training', run_hash=run_hash))

        ht = random_forest.apply_rf_model(
            ht,
            rf_model,
            get_features_list(True, not args.vqsr_features, args.vqsr_features),
            label=LABEL_COL,
        )

        # Needed for backwards compatibility for RF runs that happened prior
        # to updating annotations: rank only when both fields exist.
        if 'singleton' in ht.row and 'was_split' in ht.row:
            ht = add_rank(
                ht,
                score_expr=ht.rf_probability['FP'],
                subrank_expr={
                    'singleton_rank': ht.singleton,
                    'biallelic_rank': ~ht.was_split,
                    'biallelic_singleton_rank': ~ht.was_split & ht.singleton,
                },
            )
        else:
            # ``Logger.warn`` is a deprecated alias; ``warning`` is the
            # documented method.
            logger.warning(
                "Ranking was not added  because of missing annotations -- please run 'create_ranked_scores.py' to add rank."
            )

        ht.write(rf_path(data_type, 'rf_result', run_hash=run_hash),
                 overwrite=args.overwrite)

    if args.finalize:
        # NOTE(review): finalize uses args.run_hash, not a hash just produced
        # by --train_rf in the same invocation.
        ht = prepare_final_ht(data_type, args.run_hash, args.snp_bin_cutoff,
                              args.indel_bin_cutoff)
        ht.write(annotations_ht_path(data_type, 'rf'), overwrite=args.overwrite)