Example #1
def main(args):

    print("main")
    ht = hl.read_table(
        f'{temp_dir}/ddd-elgh-ukbb/variant_qc/Sanger_table_for_RF_by_variant_type.ht'
    )

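    # Draw an 8-character run hash and re-draw until it does not collide with a
    # previously recorded run.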
    run_hash = str(uuid.uuid4())[:8]
    rf_runs = get_rf_runs(f'{tmp_dir}/ddd-elgh-ukbb/')
    while run_hash in rf_runs:
        run_hash = str(uuid.uuid4())[:8]
    ht_result, rf_model = train_rf(ht, args)
    print("Writing out ht_training data")
    ht_result = ht_result.checkpoint(
        f'{tmp_dir}/ddd-elgh-ukbb/Sanger_RF_training_data.ht', overwrite=True)
    rf_runs[run_hash] = get_run_data(
        vqsr_training=False,
        transmitted_singletons=True,
        test_intervals=args.test_intervals,
        adj=False,
        features_importance=hl.eval(ht_result.features_importance),
        test_results=hl.eval(ht_result.test_results),
    )

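    # Persist the updated run registry as JSON so later invocations can see this run.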
    with hl.hadoop_open(f'{plot_dir}/ddd-elgh-ukbb/variant_qc/rf_runs.json',
                        "w") as f:
        json.dump(rf_runs, f)

    logger.info("Saving RF model")
    save_model(rf_model, f'{tmp_dir}/ddd-elgh-ukbb/rf_model.model')
Example #2
def main(args):

    print("main")
    ht = hl.read_table(
        f'{temp_dir}/ddd-elgh-ukbb/variant_qc/Sanger_table_for_RF_by_variant_type.ht'
    )

    if args.train_rf:

        run_hash = str(uuid.uuid4())[:8]
        rf_runs = get_rf_runs(f'{tmp_dir}/ddd-elgh-ukbb/')
        while run_hash in rf_runs:
            run_hash = str(uuid.uuid4())[:8]
        ht_result, rf_model = train_rf(ht, args)
        print("Writing out ht_training data")
        ht_result = ht_result.checkpoint(get_rf(data="training",
                                                run_hash=run_hash).path,
                                         overwrite=True)
        # f'{tmp_dir}/ddd-elgh-ukbb/Sanger_RF_training_data.ht', overwrite=True)
        rf_runs[run_hash] = get_run_data(
            vqsr_training=False,
            transmitted_singletons=True,
            test_intervals=args.test_intervals,
            adj=False,
            features_importance=hl.eval(ht_result.features_importance),
            test_results=hl.eval(ht_result.test_results),
        )

        with hl.hadoop_open(
                f'{plot_dir}/ddd-elgh-ukbb/variant_qc/rf_runs.json', "w") as f:
            json.dump(rf_runs, f)

        logger.info("Saving RF model")
        save_model(rf_model,
                   get_rf(data="model", run_hash=run_hash),
                   overwrite=True)
        # f'{tmp_dir}/ddd-elgh-ukbb/rf_model.model')
    else:
        run_hash = args.run_hash

    if args.apply_rf:

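        # Reload the saved Spark RF model and the training HT written for this
        # run hash; the feature list lives in the training HT's globals.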
        logger.info(f"Applying RF model {run_hash}...")
        rf_model = load_model(get_rf(data="model", run_hash=run_hash))
        ht = get_rf(data="training", run_hash=run_hash).ht()
        features = hl.eval(ht.features)
        ht = apply_rf_model(ht, rf_model, features, label=LABEL_COL)

        logger.info("Finished applying RF model")
        ht = ht.annotate_globals(rf_hash=run_hash)
        ht = ht.checkpoint(
            get_rf("rf_result", run_hash=run_hash).path,
            overwrite=True,
        )

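        # Confusion-style summary: variant counts broken down by truth labels,
        # train/test membership and RF prediction.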
        ht_summary = ht.group_by("tp", "fp", TRAIN_COL, LABEL_COL,
                                 PREDICTION_COL).aggregate(n=hl.agg.count())
        ht_summary.show(n=20)
Example #3
def main(args):
    hl.init(log="/variant_qc_random_forest.log")

    if args.list_rf_runs:
        logger.info(f"RF runs:")
        pretty_print_runs(get_rf_runs(rf_run_path()))

    if args.annotate_for_rf:
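        # Build the feature-annotated sites HT that serves as input to RF training.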
        ht = create_rf_ht(
            impute_features=args.impute_features,
            adj=args.adj,
            n_partitions=args.n_partitions,
            checkpoint_path=get_checkpoint_path("rf_annotation"),
        )
        ht.write(
            get_rf_annotations(args.adj).path, overwrite=args.overwrite,
        )
        logger.info(f"Completed annotation wrangling for random forests model training")

    if args.train_rf:
        model_id = f"rf_{str(uuid.uuid4())[:8]}"
        rf_runs = get_rf_runs(rf_run_path())
        while model_id in rf_runs:
            model_id = f"rf_{str(uuid.uuid4())[:8]}"

        ht, rf_model = train_rf(
            get_rf_annotations(args.adj).ht(),
            fp_to_tp=args.fp_to_tp,
            num_trees=args.num_trees,
            max_depth=args.max_depth,
            no_transmitted_singletons=args.no_transmitted_singletons,
            no_inbreeding_coeff=args.no_inbreeding_coeff,
            vqsr_training=args.vqsr_training,
            vqsr_model_id=args.vqsr_model_id,
            filter_centromere_telomere=args.filter_centromere_telomere,
            test_intervals=args.test_intervals,
        )

        ht = ht.checkpoint(
            get_rf_training(model_id=model_id).path, overwrite=args.overwrite,
        )

        logger.info("Adding run to RF run list")
        rf_runs[model_id] = get_run_data(
            input_args={
                "transmitted_singletons": None
                if args.vqsr_training
                else not args.no_transmitted_singletons,
                "adj": args.adj,
                "vqsr_training": args.vqsr_training,
                "filter_centromere_telomere": args.filter_centromere_telomere,
            },
            test_intervals=args.test_intervals,
            features_importance=hl.eval(ht.features_importance),
            test_results=hl.eval(ht.test_results),
        )

        with hl.hadoop_open(rf_run_path(), "w") as f:
            json.dump(rf_runs, f)

        logger.info("Saving RF model")
        save_model(
            rf_model, get_rf_model_path(model_id=model_id), overwrite=args.overwrite,
        )

    else:
        model_id = args.model_id

    if args.apply_rf:
        logger.info(f"Applying RF model {model_id}...")
        rf_model = load_model(get_rf_model_path(model_id=model_id))
        ht = get_rf_training(model_id=model_id).ht()
        features = hl.eval(ht.features)
        ht = apply_rf_model(ht, rf_model, features, label=LABEL_COL)

        logger.info("Finished applying RF model")
        ht = ht.annotate_globals(rf_model_id=model_id)
        ht = ht.checkpoint(
            get_rf_result(model_id=model_id).path, overwrite=args.overwrite,
        )

        ht_summary = ht.group_by(
            "tp", "fp", TRAIN_COL, LABEL_COL, PREDICTION_COL
        ).aggregate(n=hl.agg.count())
        ht_summary.show(n=20)
Example #4
def main(args):

    print("importing main table")
    ht = hl.read_table(
        f'{nfs_dir}/hail_data/variant_qc/chd_ukbb.table_for_RF_by_variant_type_all_cols.ht'
    )

    if args.train_rf:
        # ht = hl.read_table(
        #    f'{temp_dir}/ddd-elgh-ukbb/variant_qc/Sanger_table_for_RF_by_variant_type.ht')
        run_hash = str(uuid.uuid4())[:8]
        rf_runs = get_rf_runs(f'{tmp_dir}/rf_runs.json')
        while run_hash in rf_runs:
            run_hash = str(uuid.uuid4())[:8]

        ht_result, rf_model = train_rf(ht, args)
        print("Writing out ht_training data")
        ht_result = ht_result.checkpoint(get_rf(data="training",
                                                run_hash=run_hash).path,
                                         overwrite=True)
        # f'{tmp_dir}/ddd-elgh-ukbb/Sanger_RF_training_data.ht', overwrite=True)
        rf_runs[run_hash] = get_run_data(
            vqsr_training=False,
            transmitted_singletons=True,
            test_intervals=args.test_intervals,
            adj=True,
            features_importance=hl.eval(ht_result.features_importance),
            test_results=hl.eval(ht_result.test_results),
        )

        with hl.hadoop_open(f'{tmp_dir}/rf_runs.json', "w") as f:
            json.dump(rf_runs, f)
        pretty_print_runs(rf_runs)
        logger.info("Saving RF model")
        save_model(rf_model,
                   get_rf(data="model", run_hash=run_hash),
                   overwrite=True)
        # f'{tmp_dir}/ddd-elgh-ukbb/rf_model.model')
    else:
        run_hash = args.run_hash

    if args.apply_rf:

        logger.info(f"Applying RF model {run_hash}...")
        rf_model = load_model(get_rf(data="model", run_hash=run_hash))

        ht = get_rf(data="training", run_hash=run_hash).ht()
        features = hl.eval(ht.features)
        ht = apply_rf_model(ht, rf_model, features, label=LABEL_COL)
        logger.info("Finished applying RF model")
        ht = ht.annotate_globals(rf_hash=run_hash)
        ht = ht.checkpoint(
            get_rf("rf_result_chd_ukbb", run_hash=run_hash).path,
            overwrite=True,
        )

        ht_summary = ht.group_by("tp", "fp", TRAIN_COL, LABEL_COL,
                                 PREDICTION_COL).aggregate(n=hl.agg.count())
        ht_summary.show(n=20)

    if args.finalize:

        # TODO: Adjust this step to run on the CHD-UKBB cohort

        run_hash = args.run_hash
        ht = hl.read_table(
            f'{tmp_dir}/variant_qc/models/{run_hash}/rf_result_ac_added.ht')
        # ht = create_grouped_bin_ht(
        #    model_id=run_hash, overwrite=True)
        freq_ht = hl.read_table(
            f'{tmp_dir}/variant_qc/mt_sampleQC_FILTERED_FREQ_adj.ht')
        freq = freq_ht[ht.key]

        print("created bin ht")

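        # In gnomAD-style frequency arrays, index 0 is typically the high-quality
        # (adj) frequency and index 1 the raw frequency: AC0 filtering uses adj
        # counts, while the transmitted-singleton and mono-allelic checks use raw counts.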
        ht = generate_final_rf_ht(
            ht,
            ac0_filter_expr=freq.freq[0].AC == 0,
            ts_ac_filter_expr=freq.freq[1].AC == 1,
            mono_allelic_fiter_expr=(freq.freq[1].AF == 1) |
            (freq.freq[1].AF == 0),
            snp_cutoff=args.snp_cutoff,
            indel_cutoff=args.indel_cutoff,
            determine_cutoff_from_bin=False,
            aggregated_bin_ht=bin_ht,
            bin_id=bin_ht.bin,
            inbreeding_coeff_cutoff=INBREEDING_COEFF_HARD_CUTOFF,
        )
        # This column is added by the RF module based on a 0.5 threshold which doesn't correspond to what we use
        # ht = ht.drop(ht[PREDICTION_COL])
        ht.write(f'{tmp_dir}/rf_final.ht', overwrite=True)
Example #5
def main(args):

    print("main")
    ht = hl.read_table(
        f'{temp_dir}/ddd-elgh-ukbb/variant_qc/Sanger_table_for_RF_by_variant_type.ht'
    )

    if args.train_rf:

        run_hash = str(uuid.uuid4())[:8]
        rf_runs = get_rf_runs(f'{tmp_dir}/ddd-elgh-ukbb/')
        while run_hash in rf_runs:
            run_hash = str(uuid.uuid4())[:8]
        ht_result, rf_model = train_rf(ht, args)
        print("Writing out ht_training data")
        ht_result = ht_result.checkpoint(get_rf(data="training",
                                                run_hash=run_hash).path,
                                         overwrite=True)
        # f'{tmp_dir}/ddd-elgh-ukbb/Sanger_RF_training_data.ht', overwrite=True)
        rf_runs[run_hash] = get_run_data(
            vqsr_training=False,
            transmitted_singletons=True,
            test_intervals=args.test_intervals,
            adj=False,
            features_importance=hl.eval(ht_result.features_importance),
            test_results=hl.eval(ht_result.test_results),
        )

        with hl.hadoop_open(
                f'{plot_dir}/ddd-elgh-ukbb/variant_qc/rf_runs.json', "w") as f:
            json.dump(rf_runs, f)
        pretty_print_runs(rf_runs)
        logger.info("Saving RF model")
        save_model(rf_model,
                   get_rf(data="model", run_hash=run_hash),
                   overwrite=True)
        # f'{tmp_dir}/ddd-elgh-ukbb/rf_model.model')
    else:
        run_hash = args.run_hash

    if args.apply_rf:

        logger.info(f"Applying RF model {run_hash}...")
        #rf_model = load_model(get_rf(data="model", run_hash=run_hash))
        run_hash = args.run_hash
        rf_model = hl.read_table(
            f'{temp_dir}/ddd-elgh-ukbb/variant_qc/models/{run_hash}/model.model'
        )
        # ht = hl.read_table(
        #    f'{temp_dir}/ddd-elgh-ukbb/variant_qc/Sanger_cohorts_chr1-20-XY_sampleQC_FILTERED_FREQ_adj_inb.ht')
        ht = hl.read_table(
            f'{temp_dir}/ddd-elgh-ukbb/variant_qc/Sanger_cohorts_for_RF_unfiltered.ht'
        )
        ht = ht.annotate(rf_label=rf_model[ht.key].rf_label)
        #ht = get_rf(data="training", run_hash=run_hash).ht()
        features = hl.eval(rf_model.features)
        ht = apply_rf_model(ht, rf_model, features, label=LABEL_COL)
        logger.info("Finished applying RF model")
        ht = ht.annotate_globals(rf_hash=run_hash)
        ht = ht.checkpoint(
            get_rf("rf_result_sanger_cohorts", run_hash=run_hash).path,
            overwrite=True,
        )

        ht_summary = ht.group_by("tp", "fp", TRAIN_COL, LABEL_COL,
                                 PREDICTION_COL).aggregate(n=hl.agg.count())
        ht_summary.show(n=20)

    if args.finalize:
        run_hash = args.run_hash
        ht = hl.read_table(
            f'{temp_dir}/ddd-elgh-ukbb/variant_qc/models/{run_hash}/rf_result_ac_added.ht'
        )
        # ht = create_grouped_bin_ht(
        #    model_id=run_hash, overwrite=True)
        freq_ht = hl.read_table(
            f'{temp_dir}/ddd-elgh-ukbb/variant_qc/Sanger_cohorts_chr1-20-XY_sampleQC_FILTERED_FREQ_adj.ht'
        )
        freq = freq_ht[ht.key]
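        # Bin variants into 100 quantiles of RF score; these bins feed the
        # cutoff selection in generate_final_rf_ht below.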
        bin_ht = create_quantile_bin_ht(run_hash,
                                        n_bins=100,
                                        vqsr=False,
                                        overwrite=True)
        # if not file_exists(
        #    get_score_quantile_bins(args.run_hash, aggregated=True).path
        # ):
        #    sys.exit(
        #        f"Could not find binned HT for RF  run {args.run_hash} (). Please run create_ranked_scores.py for that hash."
        #    )
        #aggregated_bin_ht = get_score_quantile_bins(ht, aggregated=True)
        print("created bin ht")

        ht = generate_final_rf_ht(
            ht,
            ac0_filter_expr=freq.freq[0].AC == 0,
            ts_ac_filter_expr=freq.freq[1].AC == 1,
            mono_allelic_fiter_expr=(freq.freq[1].AF == 1) |
            (freq.freq[1].AF == 0),
            snp_cutoff=args.snp_cutoff,
            indel_cutoff=args.indel_cutoff,
            determine_cutoff_from_bin=False,
            aggregated_bin_ht=bin_ht,
            bin_id=bin_ht.bin,
            inbreeding_coeff_cutoff=INBREEDING_COEFF_HARD_CUTOFF,
        )
        # This column is added by the RF module based on a 0.5 threshold which doesn't correspond to what we use
        #ht = ht.drop(ht[PREDICTION_COL])
        ht.write(f'{tmp_dir}/rf_final.ht', overwrite=True)
Example #6
def train_rf(data_type, args):

    # Get unique hash for run and load previous runs
    run_hash = str(uuid.uuid4())[:8]
    rf_runs = get_rf_runs(data_type)
    while run_hash in rf_runs:
        run_hash = str(uuid.uuid4())[:8]

    ht = hl.read_table(rf_annotated_path(data_type, args.adj))

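    # Repartition without a shuffle to keep partition sizes reasonable for training.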
    ht = ht.repartition(500, shuffle=False)

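    # True-positive sites: known-good training resources (Omni, Mills, VQSR
    # positive-training sites), optionally augmented with transmitted singletons.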
    if not args.vqsr_training:
        if args.no_transmitted_singletons:
            tp_expr = ht.omni | ht.mills | ht.info_POSITIVE_TRAIN_SITE
        else:
            tp_expr = ht.omni | ht.mills | ht.info_POSITIVE_TRAIN_SITE | ht.transmitted_singleton

        ht = ht.annotate(tp=tp_expr)

    # Normalize test intervals to a list of interval strings, then parse them
    # into locus intervals.
    if not args.test_intervals:
        test_intervals_str = []
    elif isinstance(args.test_intervals, str):
        test_intervals_str = [args.test_intervals]
    else:
        test_intervals_str = args.test_intervals
    test_intervals_locus = [
        hl.parse_locus_interval(x) for x in test_intervals_str
    ]

    if test_intervals_locus:
        ht = ht.annotate_globals(test_intervals=test_intervals_locus)

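    # Select a balanced training set: TP/FP columns depend on whether we train
    # on VQSR training sites or on our own tp / hard-filter-failure labels.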
    ht = sample_rf_training_examples(
        ht,
        tp_col='info_POSITIVE_TRAIN_SITE' if args.vqsr_training else 'tp',
        fp_col='info_NEGATIVE_TRAIN_SITE'
        if args.vqsr_training else 'fail_hard_filters',
        fp_to_tp=args.fp_to_tp)

    ht = ht.persist()

    rf_ht = ht.filter(ht[TRAIN_COL])
    rf_features = get_features_list(
        True, not (args.vqsr_features or args.median_features),
        args.vqsr_features, args.median_features)

    logger.info(
        "Training RF model:\nfeatures: {}\nnum_trees: {}\nmax_depth: {}\nTest intervals: {}"
        .format(",".join(rf_features), args.num_trees, args.max_depth,
                ",".join(test_intervals_str)))

    rf_model = random_forest.train_rf(rf_ht,
                                      features=rf_features,
                                      label=LABEL_COL,
                                      num_trees=args.num_trees,
                                      max_depth=args.max_depth)

    logger.info("Saving RF model")
    random_forest.save_model(rf_model,
                             rf_path(data_type,
                                     data='model',
                                     run_hash=run_hash),
                             overwrite=args.overwrite)

    test_results = None
    if args.test_intervals:
        logger.info("Testing model {} on intervals {}".format(
            run_hash, ",".join(test_intervals_str)))
        test_ht = hl.filter_intervals(ht, test_intervals_locus, keep=True)
        test_ht = test_ht.checkpoint('gs://gnomad-tmp/test_random_forest.ht',
                                     overwrite=True)
        test_ht = test_ht.filter(hl.is_defined(test_ht[LABEL_COL]))
        test_results = random_forest.test_model(
            test_ht,
            rf_model,
            features=get_features_list(True, not args.vqsr_features,
                                       args.vqsr_features),
            label=LABEL_COL)
        ht = ht.annotate_globals(test_results=test_results)

    logger.info("Writing RF training HT")
    features_importance = random_forest.get_features_importance(rf_model)
    ht = ht.annotate_globals(
        features_importance=features_importance,
        features=get_features_list(True, not args.vqsr_features,
                                   args.vqsr_features),
        vqsr_training=args.vqsr_training,
        no_transmitted_singletons=args.no_transmitted_singletons,
        adj=args.adj)
    ht.write(rf_path(data_type, data='training', run_hash=run_hash),
             overwrite=args.overwrite)

    logger.info("Adding run to RF run list")
    rf_runs[run_hash] = get_run_data(data_type, args.vqsr_training,
                                     args.no_transmitted_singletons, args.adj,
                                     test_intervals_str, features_importance,
                                     test_results)
    with hl.hadoop_open(rf_run_hash_path(data_type), 'w') as f:
        json.dump(rf_runs, f)

    return run_hash