def pre_process_subset_freq(subset: str,
                            global_ht: hl.Table,
                            test: bool = False) -> hl.Table:
    """
    Prepare subset frequency Table by filling in missing frequency fields for loci present only in the global cohort.

    .. note::

        The resulting final `freq` array will be as long as the subset `freq_meta` global (i.e., one `freq` entry for each `freq_meta` entry)

    :param subset: subset ID
    :param global_ht: Hail Table containing all variants discovered in the overall release cohort
    :param test: If True, filter to small region on chr20
    :return: Table containing subset frequencies with missing freq structs filled in
    """

    # Read in subset HTs
    subset_ht_path = get_freq(subset=subset).path
    subset_chr20_ht_path = qc_temp_prefix() + f"chr20_test_freq.{subset}.ht"

    if test:
        if file_exists(subset_chr20_ht_path):
            logger.info(
                "Loading chr20 %s subset frequency data for testing: %s",
                subset,
                subset_chr20_ht_path,
            )
            subset_ht = hl.read_table(subset_chr20_ht_path)

        elif file_exists(subset_ht_path):
            logger.info(
                "Loading %s subset frequency data for testing: %s",
                subset,
                subset_ht_path,
            )
            subset_ht = hl.read_table(subset_ht_path)
            subset_ht = hl.filter_intervals(
                subset_ht, [hl.parse_locus_interval("chr20:1-1000000")]
            )
        else:
            raise DataException(
                f"Hail Table containing {subset} subset frequencies not found for testing. You may need to "
                "run the script generate_freq_data.py to generate frequency annotations first."
            )

    elif file_exists(subset_ht_path):
        logger.info("Loading %s subset frequency data: %s", subset,
                    subset_ht_path)
        subset_ht = hl.read_table(subset_ht_path)

    else:
        raise DataException(
            f"Hail Table containing {subset} subset frequencies not found. You may need to run the script generate_freq_data.py to generate frequency annotations first."
        )

    # Fill in missing freq structs: right-join against the global HT so every
    # locus in the global cohort is present, then give loci absent from the
    # subset one missing-callstats entry per `freq_meta` element so that `freq`
    # and `freq_meta` stay aligned
    ht = subset_ht.join(global_ht.select().select_globals(), how="right")
    ht = ht.annotate(
        freq=hl.if_else(
            hl.is_missing(ht.freq),
            hl.map(lambda x: missing_callstats_expr(), hl.range(hl.len(ht.freq_meta))),
            ht.freq,
        )
    )

    return ht
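
For intuition, here is a minimal, self-contained sketch of the right-join fill pattern used above. The toy integer-keyed schema, the single "adj" freq_meta entry, and the inline missing struct (standing in for gnomad's missing_callstats_expr helper) are all hypothetical.

import hail as hl

call_stats_t = hl.tstruct(AC=hl.tint32, AF=hl.tfloat64, AN=hl.tint32)

# Subset table: one locus with a computed freq entry
subset_ht = hl.Table.parallelize(
    [hl.Struct(idx=0, freq=[hl.Struct(AC=1, AF=0.5, AN=2)])],
    hl.tstruct(idx=hl.tint32, freq=hl.tarray(call_stats_t)),
    key="idx",
)
subset_ht = subset_ht.annotate_globals(freq_meta=[{"group": "adj"}])

# Global table: a superset of the subset's loci
global_ht = hl.Table.parallelize(
    [hl.Struct(idx=0), hl.Struct(idx=1)], hl.tstruct(idx=hl.tint32), key="idx"
)

ht = subset_ht.join(global_ht.select().select_globals(), how="right")
ht = ht.annotate(
    freq=hl.if_else(
        hl.is_missing(ht.freq),
        # One all-missing callstats struct per freq_meta entry
        hl.map(
            lambda x: hl.struct(
                AC=hl.null(hl.tint32), AF=hl.null(hl.tfloat64), AN=hl.null(hl.tint32)
            ),
            hl.range(hl.len(ht.freq_meta)),
        ),
        ht.freq,
    )
)
ht.show()  # idx=1 now carries a one-element freq array of missing callstats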
Example #2
def create_rf_ht(
    impute_features: bool = True,
    adj: bool = False,
    n_partitions: int = 5000,
    checkpoint_path: Optional[str] = None,
) -> hl.Table:
    """
    Creates a Table with all necessary annotations for the random forest model.

    Annotations that are included:

        Features for RF:
            - InbreedingCoeff
            - variant_type
            - allele_type
            - n_alt_alleles
            - has_star
            - AS_QD
            - AS_pab_max
            - AS_MQRankSum
            - AS_SOR
            - AS_ReadPosRankSum

        Training sites (bool):
            - transmitted_singleton
            - fail_hard_filters - (ht.QD < 2) | (ht.FS > 60) | (ht.MQ < 30)

    :param bool impute_features: Whether to impute features using feature medians (this is done by variant type)
    :param str adj: Whether to use adj genotypes
    :param int n_partitions: Number of partitions to use for final annotated table
    :param str checkpoint_path: Optional checkpoint path for the Table before median imputation and/or aggregate summary
    :return: Hail Table ready for RF
    :rtype: Table
    """

    group = "adj" if adj else "raw"

    ht = get_info(split=True).ht()
    ht = ht.transmute(**ht.info)
    ht = ht.select("lowqual", "AS_lowqual", "FS", "MQ", "QD", *INFO_FEATURES)

    inbreeding_ht = get_freq().ht()
    inbreeding_ht = inbreeding_ht.select(
        InbreedingCoeff=hl.if_else(
            hl.is_nan(inbreeding_ht.InbreedingCoeff),
            hl.null(hl.tfloat32),
            inbreeding_ht.InbreedingCoeff,
        )
    )
    trio_stats_ht = fam_stats.ht()
    trio_stats_ht = trio_stats_ht.select(
        f"n_transmitted_{group}", f"ac_children_{group}"
    )

    truth_data_ht = get_truth_ht()
    allele_data_ht = allele_data.ht()
    allele_counts_ht = qc_ac.ht()

    logger.info("Annotating Table with all columns from multiple annotation Tables")
    ht = ht.annotate(
        **inbreeding_ht[ht.key],
        **trio_stats_ht[ht.key],
        **truth_data_ht[ht.key],
        **allele_data_ht[ht.key].allele_data,
        **allele_counts_ht[ht.key],
    )
    # Filter to variants that are present in high-quality samples and are not AS_lowqual
    ht = ht.filter((ht[f"ac_qc_samples_{group}"] > 0) & ~ht.AS_lowqual)
    ht = ht.select(
        "a_index",
        "was_split",
        *FEATURES,
        *TRUTH_DATA,
        **{
            # A transmitted singleton: an allele seen exactly twice in QC
            # samples (AC == 2) with exactly one transmission in the trios
            "transmitted_singleton": (ht[f"n_transmitted_{group}"] == 1)
            & (ht[f"ac_qc_samples_{group}"] == 2),
            "fail_hard_filters": (ht.QD < 2) | (ht.FS > 60) | (ht.MQ < 30),
        },
        singleton=ht.ac_release_samples_raw == 1,
        ac_raw=ht.ac_qc_samples_raw,
        ac=ht.ac_release_samples_adj,
        ac_qc_samples_unrelated_raw=ht.ac_qc_samples_unrelated_raw,
    )

    ht = ht.repartition(n_partitions, shuffle=False)
    if checkpoint_path:
        ht = ht.checkpoint(checkpoint_path, overwrite=True)

    if impute_features:
        ht = median_impute_features(ht, {"variant_type": ht.variant_type})

    summary = ht.group_by("omni", "mills", "transmitted_singleton",).aggregate(
        n=hl.agg.count()
    )
    logger.info("Summary of truth data annotations:")
    summary.show(20)

    return ht
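
Since impute_features delegates to gnomad's median_impute_features, here is a rough sketch of what per-stratum median imputation looks like for a single feature. This is assumed behavior only: the real helper handles many features at once and also records which rows were imputed.

import hail as hl

def impute_feature_median_sketch(ht: hl.Table, feature: str) -> hl.Table:
    # Compute the median of `feature` within each variant_type stratum
    medians = ht.group_by(ht.variant_type).aggregate(
        median=hl.agg.approx_median(ht[feature])
    )
    # Replace missing values with the median of the row's stratum
    return ht.annotate(
        **{feature: hl.coalesce(ht[feature], medians[ht.variant_type].median)}
    )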
Example #3
def main(args):
    subsets = args.subsets
    hl.init(
        log=f"/generate_frequency_data{'.' + '_'.join(subsets) if subsets else ''}.log",
        default_reference="GRCh38",
    )

    invalid_subsets = []
    n_subsets_use_subpops = 0
    for s in subsets:
        if s not in SUBSETS:
            invalid_subsets.append(s)
        if s in COHORTS_WITH_POP_STORED_AS_SUBPOP:
            n_subsets_use_subpops += 1

    if invalid_subsets:
        raise ValueError(
            f"{', '.join(invalid_subsets)} subset(s) are not one of the following official subsets: {SUBSETS}"
        )
    if n_subsets_use_subpops and (n_subsets_use_subpops != len(subsets)):
        raise ValueError(
            f"All or none of the supplied subset(s) should be in the list of cohorts that need to use subpops instead "
            f"of pops in frequency calculations: {COHORTS_WITH_POP_STORED_AS_SUBPOP}"
        )

    try:
        logger.info("Reading full sparse MT and metadata table...")
        mt = get_gnomad_v3_mt(
            key_by_locus_and_alleles=True,
            release_only=not args.include_non_release,
            samples_meta=True,
        )

        if args.test:
            logger.info("Filtering to two partitions on chr20")
            mt = hl.filter_intervals(
                mt, [hl.parse_locus_interval("chr20:1-1000000")])
            mt = mt._filter_partitions(range(2))

        mt = hl.experimental.sparse_split_multi(mt, filter_changed_loci=True)

        if args.include_non_release:
            logger.info("Filtering MT columns to high quality samples")
            total_sample_count = mt.count_cols()
            mt = mt.filter_cols(mt.meta.high_quality)
            high_quality_sample_count = mt.count_cols()
            logger.info(
                f"Filtered out {total_sample_count - high_quality_sample_count} low quality samples from the "
                f"full set of {total_sample_count} samples..."
            )

        if subsets:
            mt = mt.filter_cols(hl.any([mt.meta.subsets[s] for s in subsets]))
            logger.info(
                f"Running frequency generation pipeline on {mt.count_cols()} samples in {', '.join(subsets)} subset(s)..."
            )
        else:
            logger.info(
                f"Running frequency generation pipeline on {mt.count_cols()} samples..."
            )

        logger.info("Computing adj and sex adjusted genotypes...")
        mt = mt.annotate_entries(
            GT=adjusted_sex_ploidy_expr(mt.locus, mt.GT,
                                        mt.meta.sex_imputation.sex_karyotype),
            adj=get_adj_expr(mt.GT, mt.GQ, mt.DP, mt.AD),
        )

        logger.info("Densify-ing...")
        mt = hl.experimental.densify(mt)
        mt = mt.filter_rows(hl.len(mt.alleles) > 1)

        # Temporary hotfix for depletion of homozygous alternate genotypes
        logger.info(
            "Setting het genotypes at sites with >1% AF (using v3.0 frequencies) and > 0.9 AB to homalt..."
        )
        # Load v3.0 allele frequencies to avoid an extra frequency calculation
        # NOTE: Using previous callset AF works for small incremental changes to a callset, but we will need to revisit for large increments
        freq_ht = get_freq(version="3").ht()
        freq_ht = freq_ht.select(AF=freq_ht.freq[0].AF)

        mt = mt.annotate_entries(
            GT=hl.if_else(
                (freq_ht[mt.row_key].AF > 0.01)
                & mt.GT.is_het()
                & (mt.AD[1] / mt.DP > 0.9),
                hl.call(1, 1),
                mt.GT,
            )
        )

        logger.info("Generating frequency data...")
        if subsets:
            mt = annotate_freq(
                mt,
                sex_expr=mt.meta.sex_imputation.sex_karyotype,
                # NOTE: TGP and HGDP labeled populations are highly specific and
                # are stored in the project_subpop meta field
                pop_expr=mt.meta.population_inference.pop
                if not n_subsets_use_subpops
                else mt.meta.project_meta.project_subpop,
            )

            # NOTE: no FAFs or popmax needed for subsets
            mt = mt.select_rows("freq")

            logger.info(
                f"Writing out frequency data for {', '.join(subsets)} subset(s)..."
            )
            if args.test:
                mt.rows().write(
                    get_checkpoint_path(
                        f"chr20_test_freq.{'_'.join(subsets)}"),
                    overwrite=True,
                )
            else:
                mt.rows().write(get_freq(subset="_".join(subsets)).path,
                                overwrite=args.overwrite)

        else:
            logger.info("Computing age histograms for each variant...")
            mt = mt.annotate_cols(
                age=hl.if_else(
                    hl.is_defined(mt.meta.project_meta.age),
                    mt.meta.project_meta.age,
                    # NOTE: most age data is stored as an integer in the 'age'
                    # annotation, but for a select number of samples age is
                    # stored as a bin range and 'age_alt' holds an integer at
                    # the middle of that bin
                    mt.meta.project_meta.age_alt,
                )
            )
            mt = mt.annotate_rows(**age_hists_expr(mt.adj, mt.GT, mt.age))

            # Compute a callset-wide age histogram global: 10 equal bins from
            # age 30 to 80, plus under- and over-range counts
            mt = mt.annotate_globals(
                age_distribution=mt.aggregate_cols(hl.agg.hist(mt.age, 30, 80, 10))
            )

            mt = annotate_freq(
                mt,
                sex_expr=mt.meta.sex_imputation.sex_karyotype,
                pop_expr=mt.meta.population_inference.pop,
                downsamplings=DOWNSAMPLINGS,
            )
            # Remove all loci with raw AC=0 (freq[1] is the raw callstats entry)
            mt = mt.filter_rows(mt.freq[1].AC > 0)

            logger.info("Calculating InbreedingCoeff...")
            # NOTE: This is not the ideal location to calculate this, but added here to avoid another densify
            mt = mt.annotate_rows(
                InbreedingCoeff=bi_allelic_site_inbreeding_expr(mt.GT))

            logger.info("Computing filtering allele frequencies and popmax...")
            faf, faf_meta = faf_expr(mt.freq, mt.freq_meta, mt.locus,
                                     POPS_TO_REMOVE_FOR_POPMAX)
            mt = mt.select_rows(
                "InbreedingCoeff",
                "freq",
                faf=faf,
                popmax=pop_max_expr(mt.freq, mt.freq_meta,
                                    POPS_TO_REMOVE_FOR_POPMAX),
            )
            mt = mt.annotate_globals(
                faf_meta=faf_meta,
                faf_index_dict=make_faf_index_dict(faf_meta))
            # Annotate popmax with the faf95 of the popmax population, looked up
            # via the faf_meta entry matching ["adj", popmax population]
            mt = mt.annotate_rows(
                popmax=mt.popmax.annotate(
                    faf95=mt.faf[
                        mt.faf_meta.index(lambda x: x.values() == ["adj", mt.popmax.pop])
                    ].faf95
                )
            )

            logger.info("Annotating quality metrics histograms...")
            # NOTE: these are performed here as the quality metrics histograms also require densifying
            mt = mt.annotate_rows(
                qual_hists=qual_hist_expr(mt.GT, mt.GQ, mt.DP, mt.AD, mt.adj))
            ht = mt.rows()
            # Split the combined qual hists into adj ('qual_hists') and raw
            # ('raw_qual_hists') structs
            ht = ht.annotate(
                qual_hists=hl.Struct(
                    **{
                        i.replace("_adj", ""): ht.qual_hists[i]
                        for i in ht.qual_hists
                        if "_adj" in i
                    }
                ),
                raw_qual_hists=hl.Struct(
                    **{i: ht.qual_hists[i] for i in ht.qual_hists if "_adj" not in i}
                ),
            )

            logger.info("Writing out frequency data...")
            if args.test:
                ht.write(get_checkpoint_path("chr20_test_freq"),
                         overwrite=True)
            else:
                ht.write(get_freq().path, overwrite=args.overwrite)

    finally:
        logger.info("Copying hail log to logging bucket...")
        hl.copy_log(f"{qc_temp_prefix()}logs/")
Example #4

def main(args):
    hl.init(log="/create_release_ht.log", default_reference="GRCh38")

    # The concatenated HT contains all subset frequency annotations, plus the overall cohort frequency annotations,
    # concatenated together in a single freq annotation ('freq')

    # Load global frequency Table
    if args.test:
        global_freq_chr20_ht_path = "gs://gnomad-tmp/gnomad_freq/chr20_test_freq.ht"

        if file_exists(global_freq_chr20_ht_path):
            logger.info(
                "Loading chr20 global frequency data for testing: %s",
                global_freq_chr20_ht_path,
            )
            global_freq_ht = (
                hl.read_table(global_freq_chr20_ht_path)
                .select("freq")
                .select_globals("freq_meta")
            )

        elif file_exists(get_freq().path):
            logger.info("Loading global frequency data for testing: %s",
                        get_freq().path)
            global_freq_ht = (
                hl.read_table(get_freq().path)
                .select("freq")
                .select_globals("freq_meta")
            )
            global_freq_ht = hl.filter_intervals(
                global_freq_ht, [hl.parse_locus_interval("chr20:1-1000000")]
            )
        else:
            raise DataException(
                "Hail Table containing global callset frequencies not found for testing. You may need to "
                "run the script to generate frequency annotations first."
            )

    elif file_exists(get_freq().path):
        logger.info("Loading global frequency data: %s", get_freq().path)
        global_freq_ht = (
            hl.read_table(get_freq().path)
            .select("freq")
            .select_globals("freq_meta")
        )

    else:
        raise DataException(
            "Hail Table containing global callset frequencies not found. You may need to run the script to generate frequency annotations first."
        )

    # Load subset frequency Table(s)
    if args.test:
        test_subsets = args.test_subsets
        subset_freq_hts = [
            pre_process_subset_freq(subset, global_freq_ht, test=True)
            for subset in test_subsets
        ]

    else:
        subset_freq_hts = [
            pre_process_subset_freq(subset, global_freq_ht)
            for subset in SUBSETS
        ]

    logger.info("Concatenating subset frequencies...")
    freq_ht = hl.Table.multi_way_zip_join(
        [global_freq_ht] + subset_freq_hts,
        data_field_name="freq",
        global_field_name="freq_meta",
    )
    freq_ht = freq_ht.transmute(freq=freq_ht.freq.flatmap(lambda x: x.freq))
    freq_ht = freq_ht.transmute_globals(
        freq_meta=freq_ht.freq_meta.flatmap(lambda x: x.freq_meta))
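
To make the concatenation step concrete, here is a toy sketch with hypothetical integer-keyed tables: multi_way_zip_join collects each input table's row value into an array of structs, and flatmap then flattens the per-table freq arrays into one concatenated array per key (globals are combined into an array the same way).

import hail as hl

ht1 = hl.Table.parallelize(
    [hl.Struct(idx=0, freq=[1, 2])],
    hl.tstruct(idx=hl.tint32, freq=hl.tarray(hl.tint32)),
    key="idx",
)
ht2 = hl.Table.parallelize(
    [hl.Struct(idx=0, freq=[3])],
    hl.tstruct(idx=hl.tint32, freq=hl.tarray(hl.tint32)),
    key="idx",
)

joined = hl.Table.multi_way_zip_join([ht1, ht2], "freq", "freq_meta")
# Each row's `freq` is now array<struct{freq: array<int32>}>; flatten it
joined = joined.transmute(freq=joined.freq.flatmap(lambda x: x.freq))
joined.show()  # idx=0 -> freq=[1, 2, 3]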

    # Create frequency index dictionary on concatenated array (i.e., including all subsets)
    # NOTE: non-standard downsampling values are created in the frequency script corresponding to population totals, so
    # callset-specific DOWNSAMPLINGS must be used instead of the generic DOWNSAMPLING values
    global_freq_ht = hl.read_table(get_freq().path)
    freq_ht = freq_ht.annotate_globals(
        freq_index_dict=make_freq_index_dict(
            freq_meta=hl.eval(freq_ht.freq_meta),
            pops=POPS,
            downsamplings=hl.eval(global_freq_ht.downsamplings),
        )
    )

    # Add back in all global frequency annotations not present in concatenated frequencies HT
    row_fields = global_freq_ht.row_value.keys() - freq_ht.row_value.keys()
    logger.info(
        "Adding back the following row annotations onto concatenated frequencies: %s",
        row_fields)
    freq_ht = freq_ht.annotate(**global_freq_ht[freq_ht.key].select(
        *row_fields))

    global_fields = global_freq_ht.globals.keys() - freq_ht.globals.keys()
    global_fields.remove("downsamplings")
    logger.info(
        "Adding back the following global annotations onto concatenated frequencies: %s",
        global_fields)
    freq_ht = freq_ht.annotate_globals(**global_freq_ht.index_globals().select(
        *global_fields))

    logger.info("Preparing release Table annotations...")
    ht = add_release_annotations(freq_ht)

    logger.info("Removing chrM and sites without filter...")
    ht = hl.filter_intervals(ht, [hl.parse_locus_interval("chrM")], keep=False)
    ht = ht.filter(hl.is_defined(ht.filters))

    output_path = (
        qc_temp_prefix() + "release/gnomad.genomes.v3.1.sites.chr20.ht"
        if args.test
        else release_sites().path
    )
    ht = ht.checkpoint(output_path, args.overwrite)
    logger.info("Final variant count: %d", ht.count())
    ht.describe()
    ht.show()
    ht.summarize()
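
Downstream consumers locate entries in the concatenated freq array through the freq_index_dict global written above. A hedged illustration follows; the path and the "afr-adj" key are hypothetical, since the actual key format is whatever make_freq_index_dict produces.

import hail as hl

release_ht = hl.read_table("gs://my-bucket/release.ht")  # hypothetical path
idx = hl.eval(release_ht.freq_index_dict)["afr-adj"]  # index for one grouping
release_ht = release_ht.annotate(afr_adj_AF=release_ht.freq[idx].AF)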
Example #5
def main(args):
    hl.init(log="/variant_qc_finalize.log")

    ht = get_score_bins(args.model_id, aggregated=False).ht()
    if args.filter_centromere_telomere:
        ht = ht.filter(~hl.is_defined(telomeres_and_centromeres.ht()[ht.locus]))

    info_ht = get_info(split=True).ht()
    ht = ht.filter(~info_ht[ht.key].AS_lowqual)

    if args.model_id.startswith("vqsr_"):
        ht = ht.drop("info")

    freq_ht = get_freq().ht()
    ht = ht.annotate(InbreedingCoeff=freq_ht[ht.key].InbreedingCoeff)
    freq_idx = freq_ht[ht.key]
    aggregated_bin_path = get_score_bins(args.model_id, aggregated=True).path
    if not file_exists(aggregated_bin_path):
        sys.exit(
            f"Could not find binned HT for model: {args.model_id} ({aggregated_bin_path}). Please run create_ranked_scores.py for that hash."
        )
    aggregated_bin_ht = get_score_bins(args.model_id, aggregated=True).ht()

    ht = generate_final_filter_ht(
        ht,
        args.model_name,
        args.score_name,
        # freq[0] holds the adj callstats and freq[1] holds the raw callstats
        ac0_filter_expr=freq_idx.freq[0].AC == 0,
        ts_ac_filter_expr=freq_idx.freq[1].AC == 1,
        mono_allelic_flag_expr=(freq_idx.freq[1].AF == 1)
        | (freq_idx.freq[1].AF == 0),
        snp_bin_cutoff=args.snp_bin_cutoff,
        indel_bin_cutoff=args.indel_bin_cutoff,
        snp_score_cutoff=args.snp_score_cutoff,
        indel_score_cutoff=args.indel_score_cutoff,
        inbreeding_coeff_cutoff=args.inbreeding_coeff_threshold,
        aggregated_bin_ht=aggregated_bin_ht,
        bin_id="bin",
        vqsr_ht=get_vqsr_filters(args.vqsr_model_id, split=True).ht()
        if args.vqsr_model_id else None,
    )
    ht = ht.annotate_globals(
        filtering_model=ht.filtering_model.annotate(model_id=args.model_id)
    )
    if args.model_id.startswith("vqsr_"):
        ht = ht.annotate_globals(filtering_model=ht.filtering_model.annotate(
            snv_training_variables=[
                "AS_QD",
                "AS_MQRankSum",
                "AS_ReadPosRankSum",
                "AS_FS",
                "AS_SOR",
                "AS_MQ",
            ],
            indel_training_variables=[
                "AS_QD",
                "AS_MQRankSum",
                "AS_ReadPosRankSum",
                "AS_FS",
                "AS_SOR",
            ],
        ))
    else:
        ht = ht.annotate_globals(filtering_model=ht.filtering_model.annotate(
            snv_training_variables=ht.features,
            indel_training_variables=ht.features,
        ))

    ht.write(final_filter.path, args.overwrite)

    final_filter_ht = final_filter.ht()
    final_filter_ht.summarize()