Beispiel #1
0
def default_compute_info(
    mt: hl.MatrixTable,
    site_annotations: bool = False,
    n_partitions: int = 5000,
) -> hl.Table:
    """
    Build a Table holding the usual GATK allele-specific (AS) info fields,
    together with allele counts (AC / AC_raw) and low-quality flags.

    Multi-allelic sites are not split by this function.

    :param mt: Input MatrixTable. It is expected to already be filtered to nonref sites.
    :param site_annotations: If True, site-level info fields (and a site-level `lowqual` flag) are added as well. Default is False.
    :param n_partitions: Number of partitions the returned Table is coalesced to. Default is 5000.
    :return: Table with info fields
    :rtype: Table
    """
    # Lift the fields nested under `gvcf_info` to top-level entry fields
    mt = mt.transmute_entries(**mt.gvcf_info)

    def _allele_count_by_adj(ai):
        # Only entries whose local alleles include `ai` contribute; the number of
        # copies of that allele is summed separately for adj and non-adj genotypes.
        local_alleles = mt.LA.map(lambda x: hl.str(x))
        return hl.agg.filter(
            mt.LA.contains(ai),
            hl.agg.group_by(
                get_adj_expr(mt.LGT, mt.GQ, mt.DP, mt.LAD),
                hl.agg.sum(
                    mt.LGT.one_hot_alleles(local_alleles)[mt.LA.index(ai)]
                ),
            ),
        )

    # Allele-specific info, optionally extended with site-level info
    info_expr = get_as_info_expr(mt)
    if site_annotations:
        site_info = get_site_info_expr(mt)
        info_expr = info_expr.annotate(**site_info)

    # One {adj: count} dictionary per non-ref allele
    non_ref_allele_indices = hl.range(1, hl.len(mt.alleles))
    grp_ac_expr = hl.agg.array_agg(_allele_count_by_adj, non_ref_allele_indices)

    # AC_raw adds up the adj and non-adj groups, AC keeps the adj group only
    info_expr = info_expr.annotate(
        AC_raw=grp_ac_expr.map(
            lambda by_adj: hl.int32(by_adj.get(True, 0) + by_adj.get(False, 0))
        ),
        AC=grp_ac_expr.map(lambda by_adj: hl.int32(by_adj.get(True, 0))),
    )

    info_ht = mt.select_rows(info=info_expr).rows()

    # The AS low-qual flag is always present; the site-level one only on request
    lowqual_flags = {
        "AS_lowqual": get_lowqual_expr(
            info_ht.alleles, info_ht.info.AS_QUALapprox
        ),
    }
    if site_annotations:
        lowqual_flags["lowqual"] = get_lowqual_expr(
            info_ht.alleles, info_ht.info.QUALapprox
        )
    info_ht = info_ht.annotate(**lowqual_flags)

    return info_ht.naive_coalesce(n_partitions)
Beispiel #2
0
def main(args):
    """
    Run the frequency data generation pipeline on the sparse gnomAD v3 MatrixTable.

    The MT is split, restricted to release samples, annotated with adj and
    sex-adjusted genotypes, densified, and then annotated with frequencies,
    filtering allele frequencies, popmax and quality histograms. Only the rows
    are written out.

    :param args: Parsed command-line arguments. `args.test` restricts the run to
        chr20:1-1000000 and writes to a temporary path; `args.overwrite` is passed
        to the final write otherwise.
    """
    hl.init(log='/frequency_data_generation.log', default_reference='GRCh38')

    logger.info("Reading sparse MT and metadata table...")
    mt = get_gnomad_v3_mt(key_by_locus_and_alleles=True)
    meta_fields = ('pop', 'sex', 'project_id', 'release', 'sample_filters')
    meta_ht = meta.ht().select(*meta_fields)

    if args.test:
        logger.info("Filtering to chr20:1-1000000")
        test_interval = hl.parse_locus_interval('chr20:1-1000000')
        mt = hl.filter_intervals(mt, [test_interval])

    mt = hl.experimental.sparse_split_multi(mt, filter_changed_loci=True)

    logger.info("Annotating sparse MT with metadata...")
    mt = mt.annotate_cols(meta=meta_ht[mt.s])
    mt = mt.filter_cols(mt.meta.release)
    samples = mt.count_cols()
    logger.info(f"Running frequency table prep and generation pipeline on {samples} samples")

    logger.info("Computing adj and sex adjusted genotypes.")
    # Both expressions are built from the genotypes as they are before this annotation
    sex_adjusted_gt = adjusted_sex_ploidy_expr(mt.locus, mt.GT, mt.meta.sex)
    adj = get_adj_expr(mt.GT, mt.GQ, mt.DP, mt.AD)
    mt = mt.annotate_entries(GT=sex_adjusted_gt, adj=adj)

    logger.info("Densify-ing...")
    mt = hl.experimental.densify(mt)
    # Keep only rows that still have at least one alternate allele
    mt = mt.filter_rows(hl.len(mt.alleles) > 1)

    logger.info("Generating frequency data...")
    mt = annotate_freq(mt, sex_expr=mt.meta.sex, pop_expr=mt.meta.pop)

    # Keep freq and add FAF and popmax as row fields
    faf, faf_meta = faf_expr(mt.freq, mt.freq_meta, mt.locus, POPS_TO_REMOVE_FOR_POPMAX)
    popmax = pop_max_expr(mt.freq, mt.freq_meta, POPS_TO_REMOVE_FOR_POPMAX)
    mt = mt.select_rows('freq', faf=faf, popmax=popmax)
    mt = mt.annotate_globals(faf_meta=faf_meta)

    # Quality metric histograms are computed here because they also need the dense MT
    quality_hists = qual_hist_expr(mt.GT, mt.GQ, mt.DP, mt.AD)
    mt = mt.annotate_rows(**quality_hists)

    logger.info("Writing out frequency data...")
    rows_ht = mt.rows()
    if args.test:
        out_path = "gs://gnomad-tmp/gnomad_freq/chr20_1_1000000_freq.ht"
        overwrite = True
    else:
        out_path = freq.path
        overwrite = args.overwrite
    rows_ht.write(out_path, overwrite=overwrite)
def compute_stats(stats_path: str):
    """
    Aggregate statistics on `END - locus.position` over the gnomAD v3 MT and pickle the result.

    Only entries with a defined `END` are considered. The same three aggregations
    (summary stats, a linear histogram and a log10 histogram) are computed twice:
    once over all such entries (`ref_block_stats`) and once restricted to entries
    passing `get_adj_expr` (`adj_ref_block_stats`).

    :param stats_path: Path the pickled aggregation result is written to.
    """
    mt = get_gnomad_v3_mt()
    mt = mt.filter_entries(hl.is_defined(mt.END))

    def _block_length_aggregations():
        # A fresh set of aggregators is built on every call so the unfiltered and
        # adj-filtered results do not share expression objects.
        return hl.struct(
            stats=hl.agg.stats(mt.END - mt.locus.position),
            hist=hl.agg.hist(mt.END - mt.locus.position, 0, 9999, 10000),
            hist_log=hl.agg.hist(
                hl.log10(1 + mt.END - mt.locus.position), 0, 5, 100
            ),
        )

    all_blocks = _block_length_aggregations()
    adj_blocks = hl.agg.filter(
        get_adj_expr(mt.LGT, mt.GQ, mt.DP, mt.LAD),
        _block_length_aggregations(),
    )
    ref_block_stats = mt.aggregate_entries(
        hl.struct(ref_block_stats=all_blocks, adj_ref_block_stats=adj_blocks)
    )

    with hl.hadoop_open(stats_path, 'wb') as out:
        pickle.dump(ref_block_stats, out)
Beispiel #4
0
def compute_qc_mt() -> hl.MatrixTable:
    """
    Build the QC MatrixTable from the gnomAD v3 sparse MT.

    The candidate sites are the union of the lifted-over v2 QC sites and the
    Purcell 5k sites, minus anything found in the LCR intervals table. The sparse
    MT is densified at those sites, restricted to bi-allelic SNVs whose alleles
    match the candidate site, annotated with non-AS info fields and finally passed
    to `get_qc_mt`.

    :return: MatrixTable returned by `get_qc_mt`
    """
    # Candidate QC sites: v2 QC sites plus p5k sites, with LCR sites removed
    v2_sites_ht = get_liftover_v2_qc_mt('joint', ld_pruned=True).rows()
    v2_sites_ht = v2_sites_ht.key_by('locus')
    qc_sites = v2_sites_ht.union(purcell_5k_intervals.ht(), unify=True)
    lcr_ht = lcr_intervals.ht()
    qc_sites = qc_sites.filter(hl.is_missing(lcr_ht[qc_sites.key]))

    # Reduce the entries to what densification and QC need
    mt = get_gnomad_v3_mt(key_by_locus_and_alleles=True)
    adj = get_adj_expr(mt.LGT, mt.GQ, mt.DP, mt.LAD)
    mt = mt.select_entries('END', GT=mt.LGT, adj=adj)
    last_end_ht = hl.read_table(last_END_position.path)
    mt = densify_sites(mt, qc_sites, last_end_ht)

    # Bi-allelic SNVs whose alleles agree with the candidate site
    is_biallelic = hl.len(mt.alleles) == 2
    is_snv = hl.is_snp(mt.alleles[0], mt.alleles[1])
    same_alleles = qc_sites[mt.row_key].alleles == mt.alleles
    mt = mt.filter_rows(is_biallelic & is_snv & same_alleles)

    mt = mt.checkpoint(
        'gs://gnomad-tmp/gnomad_v3_qc_mt_v2_sites_dense.mt', overwrite=True
    )
    mt = mt.naive_coalesce(5000)
    mt = mt.checkpoint(
        'gs://gnomad-tmp/gnomad_v3_qc_mt_v2_sites_dense_repartitioned.mt',
        overwrite=True,
    )

    # AS_ annotations are dropped since only bi-allelic sites remain
    info_ht = get_info(split=False).ht()
    non_as_info = {
        name: info_ht.info[name]
        for name in info_ht.info
        if not name.startswith('AS_')
    }
    info_ht = info_ht.annotate(info=info_ht.info.select(**non_as_info))
    mt = mt.annotate_rows(info=info_ht[mt.row_key].info)

    return get_qc_mt(
        mt,
        min_af=0.0,
        min_inbreeding_coeff_threshold=-0.025,
        min_hardy_weinberg_threshold=None,
        ld_r2=None,
        filter_lcr=False,
        filter_decoy=False,
        filter_segdup=False,
    )
Beispiel #5
0
def main(args):
    """
    Run the frequency data generation pipeline, including the homalt hotfix.

    After splitting, sample filtering, adj / sex-ploidy adjustment and
    densification, het genotypes at sites with v3.0 AF > 1% and allele balance
    > 0.9 are rewritten as homozygous alternate. The inbreeding coefficient,
    frequencies, filtering allele frequencies, popmax and quality histograms are
    then annotated and the rows are written out.

    :param args: Parsed command-line arguments. `args.test` restricts the run to
        chr20:1-1000000 and writes to a temporary path; `args.overwrite` is passed
        to the final write otherwise.
    """
    hl.init(log='/frequency_data_generation.log', default_reference='GRCh38')

    logger.info("Reading sparse MT and metadata table...")
    mt = get_gnomad_v3_mt(key_by_locus_and_alleles=True)
    meta_fields = ('pop', 'sex', 'project_id', 'release', 'sample_filters')
    meta_ht = meta.ht().select(*meta_fields)

    if args.test:
        logger.info("Filtering to chr20:1-1000000")
        test_interval = hl.parse_locus_interval('chr20:1-1000000')
        mt = hl.filter_intervals(mt, [test_interval])

    mt = hl.experimental.sparse_split_multi(mt, filter_changed_loci=True)

    logger.info("Annotating sparse MT with metadata...")
    mt = mt.annotate_cols(meta=meta_ht[mt.s])
    mt = mt.filter_cols(mt.meta.release)
    samples = mt.count_cols()
    logger.info(f"Running frequency table prep and generation pipeline on {samples} samples")

    logger.info("Computing adj and sex adjusted genotypes.")
    # Both expressions are built from the genotypes as they are before this annotation
    sex_adjusted_gt = adjusted_sex_ploidy_expr(mt.locus, mt.GT, mt.meta.sex)
    adj = get_adj_expr(mt.GT, mt.GQ, mt.DP, mt.AD)
    mt = mt.annotate_entries(GT=sex_adjusted_gt, adj=adj)

    logger.info("Densify-ing...")
    mt = hl.experimental.densify(mt)
    mt = mt.filter_rows(hl.len(mt.alleles) > 1)

    logger.info("Setting het genotypes at sites with >1% AF (using v3.0 frequencies) and > 0.9 AB to homalt...")
    # Hotfix for depletion of homozygous alternate genotypes.
    # The v3.0 AF is reused so that no additional frequency calculation is needed.
    # TODO: Using previous callset AF works for small incremental changes to a callset, but we need to revisit for large increments
    freq_ht = freq.versions["3"].ht()
    freq_ht = freq_ht.select(AF=freq_ht.freq[0].AF)

    common_high_ab_het = (
        (freq_ht[mt.row_key].AF > 0.01)
        & mt.GT.is_het()
        & (mt.AD[1] / mt.DP > 0.9)
    )
    mt = mt.annotate_entries(
        GT=hl.cond(common_high_ab_het, hl.call(1, 1), mt.GT)
    )

    logger.info("Calculating InbreedingCoefficient...")
    # NOTE: This is not the ideal location to calculate this, but added here to avoid another densify
    inbreeding_coeff = bi_allelic_site_inbreeding_expr(mt.GT)
    mt = mt.annotate_rows(InbreedingCoeff=inbreeding_coeff)

    logger.info("Generating frequency data...")
    mt = annotate_freq(mt, sex_expr=mt.meta.sex, pop_expr=mt.meta.pop)

    # Keep InbreedingCoeff and freq, and add FAF and popmax as row fields
    faf, faf_meta = faf_expr(mt.freq, mt.freq_meta, mt.locus, POPS_TO_REMOVE_FOR_POPMAX)
    popmax = pop_max_expr(mt.freq, mt.freq_meta, POPS_TO_REMOVE_FOR_POPMAX)
    mt = mt.select_rows('InbreedingCoeff', 'freq', faf=faf, popmax=popmax)
    mt = mt.annotate_globals(faf_meta=faf_meta)

    # Quality metric histograms are computed here because they also need the dense MT
    quality_hists = qual_hist_expr(mt.GT, mt.GQ, mt.DP, mt.AD)
    mt = mt.annotate_rows(**quality_hists)

    logger.info("Writing out frequency data...")
    rows_ht = mt.rows()
    if args.test:
        out_path = "gs://gnomad-tmp/gnomad_freq/chr20_1_1000000_freq.ht"
        overwrite = True
    else:
        out_path = freq.path
        overwrite = args.overwrite
    rows_ht.write(out_path, overwrite=overwrite)
Beispiel #6
0
def main(args):
    """
    Generate gnomAD v3 frequency data for the full callset or for a set of subsets.

    The requested subsets are validated first. The sparse MT is then split,
    optionally filtered to high quality samples and to the requested subsets,
    annotated with adj and sex-adjusted genotypes, densified, and patched with
    the homalt hotfix. For subsets only `freq` is written. For the full callset,
    age histograms, InbreedingCoeff, filtering allele frequencies, popmax and
    quality histograms are added as well. The Hail log is copied to the logging
    bucket whether or not the pipeline succeeds.

    :param args: Parsed command-line arguments. The attributes read are `subsets`,
        `include_non_release`, `test` and `overwrite`.
    :raises ValueError: If a requested subset is not in `SUBSETS`, or if only some
        (rather than all or none) of the requested subsets are in
        `COHORTS_WITH_POP_STORED_AS_SUBPOP`.
    """
    # An unset `subsets` argument is treated as "no subsets" so that the validation
    # below iterates over an empty list instead of failing on `None`
    subsets = args.subsets or []
    hl.init(
        log=
        f"/generate_frequency_data{'.' + '_'.join(subsets) if subsets else ''}.log",
        default_reference="GRCh38",
    )

    invalid_subsets = [s for s in subsets if s not in SUBSETS]
    n_subsets_use_subpops = sum(
        1 for s in subsets if s in COHORTS_WITH_POP_STORED_AS_SUBPOP)

    if invalid_subsets:
        raise ValueError(
            f"{', '.join(invalid_subsets)} subset(s) are not one of the following official subsets: {SUBSETS}"
        )
    # Logical `and` is required here: a bitwise `&` between the count and the
    # comparison result only looks at the lowest bit of the count, so an even
    # number of subpop cohorts in a mixed request would slip through.
    if n_subsets_use_subpops and n_subsets_use_subpops != len(subsets):
        raise ValueError(
            f"All or none of the supplied subset(s) should be in the list of cohorts that need to use subpops instead "
            f"of pops in frequency calculations: {COHORTS_WITH_POP_STORED_AS_SUBPOP}"
        )

    try:
        logger.info("Reading full sparse MT and metadata table...")
        mt = get_gnomad_v3_mt(
            key_by_locus_and_alleles=True,
            release_only=not args.include_non_release,
            samples_meta=True,
        )

        if args.test:
            logger.info("Filtering to two partitions on chr20")
            mt = hl.filter_intervals(
                mt, [hl.parse_locus_interval("chr20:1-1000000")])
            mt = mt._filter_partitions(range(2))

        mt = hl.experimental.sparse_split_multi(mt, filter_changed_loci=True)

        if args.include_non_release:
            logger.info("Filtering MT columns to high quality samples")
            total_sample_count = mt.count_cols()
            mt = mt.filter_cols(mt.meta.high_quality)
            high_quality_sample_count = mt.count_cols()
            logger.info(
                f"Filtered {total_sample_count - high_quality_sample_count} from the full set of {total_sample_count} "
                f"samples...")

        if subsets:
            # Keep samples that belong to at least one of the requested subsets
            mt = mt.filter_cols(hl.any([mt.meta.subsets[s] for s in subsets]))
            logger.info(
                f"Running frequency generation pipeline on {mt.count_cols()} samples in {', '.join(subsets)} subset(s)..."
            )
        else:
            logger.info(
                f"Running frequency generation pipeline on {mt.count_cols()} samples..."
            )

        logger.info("Computing adj and sex adjusted genotypes...")
        mt = mt.annotate_entries(
            GT=adjusted_sex_ploidy_expr(mt.locus, mt.GT,
                                        mt.meta.sex_imputation.sex_karyotype),
            adj=get_adj_expr(mt.GT, mt.GQ, mt.DP, mt.AD),
        )

        logger.info("Densify-ing...")
        mt = hl.experimental.densify(mt)
        mt = mt.filter_rows(hl.len(mt.alleles) > 1)

        # Temporary hotfix for depletion of homozygous alternate genotypes
        logger.info(
            "Setting het genotypes at sites with >1% AF (using v3.0 frequencies) and > 0.9 AB to homalt..."
        )
        # Load v3.0 allele frequencies to avoid an extra frequency calculation
        # NOTE: Using previous callset AF works for small incremental changes to a callset, but we will need to revisit for large increments
        freq_ht = get_freq(version="3").ht()
        freq_ht = freq_ht.select(AF=freq_ht.freq[0].AF)

        # `hl.if_else` is used (rather than the deprecated `hl.cond`) to match the
        # age annotation further down in this function
        mt = mt.annotate_entries(GT=hl.if_else(
            (freq_ht[mt.row_key].AF > 0.01)
            & mt.GT.is_het()
            & (mt.AD[1] / mt.DP > 0.9),
            hl.call(1, 1),
            mt.GT,
        ))

        logger.info("Generating frequency data...")
        if subsets:
            mt = annotate_freq(
                mt,
                sex_expr=mt.meta.sex_imputation.sex_karyotype,
                pop_expr=mt.meta.population_inference.pop
                if not n_subsets_use_subpops else
                mt.meta.project_meta.project_subpop,
                # NOTE: TGP and HGDP labeled populations are highly specific and are stored in the project_subpop meta field
            )

            # NOTE: no FAFs or popmax needed for subsets
            mt = mt.select_rows("freq")

            logger.info(
                f"Writing out frequency data for {', '.join(subsets)} subset(s)..."
            )
            if args.test:
                mt.rows().write(
                    get_checkpoint_path(
                        f"chr20_test_freq.{'_'.join(subsets)}"),
                    overwrite=True,
                )
            else:
                mt.rows().write(get_freq(subset="_".join(subsets)).path,
                                overwrite=args.overwrite)

        else:
            logger.info("Computing age histograms for each variant...")
            mt = mt.annotate_cols(age=hl.if_else(
                hl.is_defined(mt.meta.project_meta.age),
                mt.meta.project_meta.age,
                mt.meta.project_meta.age_alt,
                # NOTE: most age data is stored as integers in 'age' annotation, but for a select number of samples, age is stored as a bin range and 'age_alt' corresponds to an integer in the middle of the bin
            ))
            mt = mt.annotate_rows(**age_hists_expr(mt.adj, mt.GT, mt.age))

            # Compute callset-wide age histogram global
            mt = mt.annotate_globals(age_distribution=mt.aggregate_cols(
                hl.agg.hist(mt.age, 30, 80, 10)))

            mt = annotate_freq(
                mt,
                sex_expr=mt.meta.sex_imputation.sex_karyotype,
                pop_expr=mt.meta.population_inference.pop,
                downsamplings=DOWNSAMPLINGS,
            )
            # Remove all loci with raw AC=0
            mt = mt.filter_rows(mt.freq[1].AC > 0)

            logger.info("Calculating InbreedingCoeff...")
            # NOTE: This is not the ideal location to calculate this, but added here to avoid another densify
            mt = mt.annotate_rows(
                InbreedingCoeff=bi_allelic_site_inbreeding_expr(mt.GT))

            logger.info("Computing filtering allele frequencies and popmax...")
            faf, faf_meta = faf_expr(mt.freq, mt.freq_meta, mt.locus,
                                     POPS_TO_REMOVE_FOR_POPMAX)
            mt = mt.select_rows(
                "InbreedingCoeff",
                "freq",
                faf=faf,
                popmax=pop_max_expr(mt.freq, mt.freq_meta,
                                    POPS_TO_REMOVE_FOR_POPMAX),
            )
            mt = mt.annotate_globals(
                faf_meta=faf_meta,
                faf_index_dict=make_faf_index_dict(faf_meta))
            # Attach the faf95 of the popmax population (adj group) to the popmax struct
            mt = mt.annotate_rows(popmax=mt.popmax.annotate(
                faf95=mt.faf[mt.faf_meta.index(
                    lambda x: x.values() == ["adj", mt.popmax.pop])].faf95))

            logger.info("Annotating quality metrics histograms...")
            # NOTE: these are performed here as the quality metrics histograms also require densifying
            mt = mt.annotate_rows(
                qual_hists=qual_hist_expr(mt.GT, mt.GQ, mt.DP, mt.AD, mt.adj))
            ht = mt.rows()
            # Histograms whose name contains "_adj" become `qual_hists` (with the
            # suffix stripped); the remaining ones are stored as `raw_qual_hists`
            ht = ht.annotate(
                qual_hists=hl.Struct(
                    **{
                        i.replace("_adj", ""): ht.qual_hists[i]
                        for i in ht.qual_hists if "_adj" in i
                    }),
                raw_qual_hists=hl.Struct(**{
                    i: ht.qual_hists[i]
                    for i in ht.qual_hists if "_adj" not in i
                }),
            )

            logger.info("Writing out frequency data...")
            if args.test:
                ht.write(get_checkpoint_path("chr20_test_freq"),
                         overwrite=True)
            else:
                ht.write(get_freq().path, overwrite=args.overwrite)

    finally:
        logger.info("Copying hail log to logging bucket...")
        hl.copy_log(f"{qc_temp_prefix()}logs/")
Beispiel #7
0
def compute_info() -> hl.Table:
    """
    Build a Table with the usual GATK site-level and allele-specific (AS) info
    fields, plus AC, AC_raw, AS_pab_max and the low-quality flags.

    Multi-allelic sites are not split by this function.

    :return: Table with info fields
    :rtype: Table
    """
    mt = get_gnomad_v3_mt(
        key_by_locus_and_alleles=True, remove_hard_filtered_samples=False
    )

    mt = mt.filter_rows((hl.len(mt.alleles) > 1))
    mt = mt.transmute_entries(**mt.gvcf_info)
    # Indices of the non-ref alleles, reused by the per-allele aggregations below
    mt = mt.annotate_rows(
        alt_alleles_range_array=hl.range(1, hl.len(mt.alleles))
    )

    def _agg_field_kwargs():
        # Note that production defaults have changed:
        # For new releases, the `RAWMQ_andDP` field replaces the `RAW_MQ` and `MQ_DP` fields
        return {
            "sum_agg_fields": INFO_SUM_AGG_FIELDS + ["RAW_MQ"],
            "int32_sum_agg_fields": INFO_INT32_SUM_AGG_FIELDS + ["MQ_DP"],
            "array_sum_agg_fields": ["SB"],
        }

    def _allele_count_by_adj(ai):
        # Copies of allele `ai`, summed separately for adj and non-adj genotypes,
        # over the entries whose local alleles include it
        local_alleles = mt.LA.map(lambda x: hl.str(x))
        return hl.agg.filter(
            mt.LA.contains(ai),
            hl.agg.group_by(
                get_adj_expr(mt.LGT, mt.GQ, mt.DP, mt.LAD),
                hl.agg.sum(
                    mt.LGT.one_hot_alleles(local_alleles)[mt.LA.index(ai)]
                ),
            ),
        )

    def _max_het_binom_p(ai):
        # Largest two-sided binomial test p-value (p=0.5) of the allele's depth
        # against the total depth, over het genotypes carrying allele `ai`
        return hl.agg.filter(
            mt.LA.contains(ai) & mt.LGT.is_het(),
            hl.agg.max(
                hl.binom_test(
                    mt.LAD[mt.LA.index(ai)], hl.sum(mt.LAD), 0.5, "two-sided"
                )
            ),
        )

    # Site-level info first, then the AS fields on top of it
    info_expr = get_site_info_expr(mt, **_agg_field_kwargs())
    as_info = get_as_info_expr(mt, **_agg_field_kwargs())
    info_expr = info_expr.annotate(**as_info)

    # One {adj: count} dictionary per non-ref allele
    grp_ac_expr = hl.agg.array_agg(
        _allele_count_by_adj, mt.alt_alleles_range_array
    )

    # AC_raw adds up the adj and non-adj groups, AC keeps the adj group only
    info_expr = info_expr.annotate(
        AC_raw=grp_ac_expr.map(
            lambda by_adj: hl.int32(by_adj.get(True, 0) + by_adj.get(False, 0))
        ),
        AC=grp_ac_expr.map(lambda by_adj: hl.int32(by_adj.get(True, 0))),
    )
    info_expr = info_expr.annotate(
        AS_pab_max=hl.agg.array_agg(
            _max_het_binom_p, mt.alt_alleles_range_array
        )
    )

    info_ht = mt.select_rows(info=info_expr).rows()

    # The indel het prior used for gnomad v3 was 1/10k bases (phred=40).
    # This value is usually 1/8k bases (phred=39).
    v3_indel_phred_het_prior = 40
    info_ht = info_ht.annotate(
        lowqual=get_lowqual_expr(
            info_ht.alleles,
            info_ht.info.QUALapprox,
            indel_phred_het_prior=v3_indel_phred_het_prior,
        ),
        AS_lowqual=get_lowqual_expr(
            info_ht.alleles,
            info_ht.info.AS_QUALapprox,
            indel_phred_het_prior=v3_indel_phred_het_prior,
        ),
    )

    return info_ht.naive_coalesce(7500)
Beispiel #8
0
def determine_pca_variants(
    autosomes_only: bool = True,
    snv_only: bool = True,
    bi_allelic_only: bool = False,
    adj_only: bool = True,
    min_gnomad_v3_ac: Optional[int] = None,
    high_qual_ccdg_exome_interval_only: bool = False,
    high_qual_ukbb_exome_interval_only: bool = False,
    pct_samples_ukbb_exome_interval: float = 0.8,
    min_joint_af: float = 0.0001,  # TODO: Konrad mentioned that he might want to lower this
    min_joint_callrate: float = 0.95,
    min_inbreeding_coeff_threshold: Optional[float] = -0.8,
    min_hardy_weinberg_threshold: Optional[float] = 1e-8,
    min_ccdg_exome_callrate: float = 0.99,  # TODO: What parameter should this start with?
    min_ukbb_exome_callrate: float = 0.99,  # TODO: What parameter should this start with?
    filter_lcr: bool = True,
    filter_segdup: bool = True,
    ld_pruning: bool = True,
    ld_pruning_dataset: str = "ccdg_genomes",
    ld_r2: float = 0.1,
    read_per_dataset_checkpoint_if_exists: bool = False,
    read_pre_ld_prune_ht_checkpoint_if_exists: bool = False,
    read_pre_ld_prune_mt_checkpoint_if_exists: bool = False,
    overwrite: bool = True,
    filter_washu: bool = False,
) -> None:
    """
    Determine a diverse set of variants for relatedness/ancestry PCA using CCDG, gnomAD v3, and UK Biobank.

    :param autosomes_only: Whether to filter to variants in autosomes
    :param snv_only: Whether to filter to SNVs
    :param bi_allelic_only: Whether to filter to variants that are bi-allelic in either CCDG and gnomAD v3
    :param adj_only: If set, only ADJ genotypes (QD >= 2, FS <= 60 and MQ >= 30) are kept. This filter is applied before the call rate and AF calculation
    :param min_gnomad_v3_ac: Optional lower bound of AC for variants in gnomAD v3 genomes
    :param high_qual_ccdg_exome_interval_only: Whether to filter to high quality intervals in CCDG exomes
    :param float pct_samples_ukbb_exome_interval: Percent of samples with over 80% of bases having coverage of over 20x per interval
    :param high_qual_ukbb_exome_interval_only: Whether to filter to high quality intervals in UKBB 455K exomes
    :param float pct_samples_ukbb: Percent of samples with coverage greater than 20x over the interval for filtering
    :param min_joint_af: Lower bound for combined MAF computed from CCDG and gnomAD v3 genomes
    :param min_joint_callrate: Lower bound for combined callrate computed from CCDG and gnomAD v3 genomes
    :param min_inbreeding_coeff_threshold: Minimum site inbreeding coefficient to keep. Not applied if set to `None`
    :param min_hardy_weinberg_threshold: Minimum site HW test p-value to keep. Not applied if set to `None`
    :param min_ccdg_exome_callrate: Lower bound for CCDG exomes callrate
    :param min_ukbb_exome_callrate: Lower bound for UKBB exomes callrate
    :param filter_lcr: Whether to filter LCR regions
    :param filter_segdup: Whether to filter Segdup regions
    :param ld_pruning: Whether to conduct LD pruning
    :param ld_pruning_dataset: Which dataset is used for LD pruning, 'ccdg_genomes' or 'gnomAD_genomes'
    :param ld_r2: LD pruning cutoff
    :param read_per_dataset_checkpoint_if_exists: Whether to read the CCDG exome/genome pre filtered HT if it exists.
        Each dataset possible filtered to: autosomes only, SNVs only, gnomAD v3.1.2 AC filter, CCDG high quality exome
        intervals, and UK Biobank high quality exome intervals
    :param read_pre_ld_prune_ht_checkpoint_if_exists: Whether to read in the PCA variant HT with no LD-pruning if it exists
    :param read_pre_ld_prune_mt_checkpoint_if_exists: Whether to read in the checkpointed MT filtered to variants in the
        PCA variant HT with no LD-pruning if it exists
    :param overwrite: Whether to overwrite the final variant HT
    :param filter_washu: Whether to filter out washU samples
    :return: Table with desired variants for PCA
    """
    if not read_pre_ld_prune_ht_checkpoint_if_exists:
        logger.info(
            "Loading gnomAD v3.1.2 release HT and UK Biobank 455K release HT ..."
        )
        flag = "_without_washu" if filter_washu else ""
        gnomad_ht = gnomad_public_release("genomes").ht()
        gnomad_ht = gnomad_ht.select(
            gnomad_was_split=gnomad_ht.was_split,
            gnomad_AC=gnomad_ht.freq[0].AC,
            gnomad_AN=gnomad_ht.freq[0].AN,
            gnomad_genomes_site_inbreeding_coeff=gnomad_ht.info.InbreedingCoeff,
            gnomad_genomes_homozygote_count=gnomad_ht.freq[0].homozygote_count,
        )
        if min_hardy_weinberg_threshold is not None:
            gnomad_ht = gnomad_ht.annotate(
                gnomad_genomes_hwe=hl.hardy_weinberg_test(
                    hl.int32(
                        (gnomad_ht.gnomad_AN / 2)
                        - gnomad_ht.gnomad_genomes_homozygote_count
                        - (
                            gnomad_ht.gnomad_AC
                            - (gnomad_ht.gnomad_genomes_homozygote_count * 2)
                        )
                    ),  # Num hom ref genotypes
                    hl.int32(
                        (
                            gnomad_ht.gnomad_AC
                            - (gnomad_ht.gnomad_genomes_homozygote_count * 2)
                        )
                    ),  # Num het genotypes
                    gnomad_ht.gnomad_genomes_homozygote_count,  # Num hom alt genotypes
                ),
            )

        ukbb_ht = hl.read_table(ukbb_release_ht_path("broad", 7))
        ukbb_ht = ukbb_ht.select(
            ukbb_AC=ukbb_ht.freq[0].AC,
            ukbb_AN=ukbb_ht.freq[0].AN,
        )
        ukbb_meta_ht = hl.read_table(ukbb_meta_ht_path("broad", 7))

        # Only count samples used in the UK Biobank exome frequency calculations
        ukbb_exome_count = ukbb_meta_ht.filter(
            ukbb_meta_ht.sample_filters.high_quality
            & hl.is_defined(ukbb_meta_ht.ukbb_meta.batch)
            & ~ukbb_meta_ht.sample_filters.related
        ).count()

        logger.info("Getting CCDG genome and exome sample counts...")
        ccdg_genome_count = get_ccdg_vds(
            "genomes", filter_washu=filter_washu
        ).variant_data.count_cols()
        logger.info(f"Number of CCDG genome samples: {ccdg_genome_count}...")
        ccdg_exome_count = get_ccdg_vds("exomes").variant_data.count_cols()
        logger.info(f"Number of CCDG exome samples: {ccdg_exome_count} ...")

        def _initial_filter(data_type):
            """
            Get Table of CCDG variants passing desired filters.

            Possible filters are:
                - Autosomes only
                - SNVs only
                - gnomAD v3.1.2 AC filter
                - CCDG high quality exome intervals
                - UK Biobank high quality exome intervals

            After densification of the VDS, rows are annotated with:
                - ccdg_{data_type}_was_split
                - ccdg_{data_type}_AC
                - ccdg_{data_type}_AN

            The filtered and annotated rows are returned as a Table and are also checkpointed
            :param data_type: Whether data is from genomes or exomes

            :return: Table of CCDG filtered variants
            """
            logger.info(
                "Loading CCDG %s VDS and splitting multi-allelics for initial filtering steps...",
                data_type,
            )
            vds = get_ccdg_vds(data_type, filter_washu=filter_washu)
            logger.info(
                f"{vds.variant_data.count_cols()} CCDG {data_type} samples loaded..."
            )
            vds = hl.vds.split_multi(vds)

            if autosomes_only:
                logger.info("Filtering CCDG %s VDS to autosomes...", data_type)
                vds = hl.vds.filter_chromosomes(vds, keep_autosomes=True)

            ht = vds.variant_data.rows()
            variant_filter_expr = True
            if snv_only:
                logger.info("Filtering CCDG %s VDS to SNVs...", data_type)
                variant_filter_expr &= hl.is_snp(ht.alleles[0], ht.alleles[1])

            if min_gnomad_v3_ac:
                logger.info(
                    "Filtering CCDG %s VDS to gnomAD v3.1.2 variants with adj-filtered AC > %d...",
                    data_type,
                    min_gnomad_v3_ac,
                )
                variant_filter_expr &= gnomad_ht[ht.key].gnomad_AC > min_gnomad_v3_ac

            vds = hl.vds.filter_variants(vds, ht.filter(variant_filter_expr), keep=True)

            if high_qual_ccdg_exome_interval_only:
                logger.info(
                    f"Filtering CCDG %s VDS to high quality (>80%% of samples with %dX coverage) CCDG exome intervals...",
                    data_type,
                    INTERVAL_DP,
                )
                interval_qc_ht = hl.read_table(
                    get_ccdg_results_path(
                        data_type="exomes", result=f"intervals_{INTERVAL_DP}x"
                    )
                )
                interval_qc_ht = interval_qc_ht.filter(interval_qc_ht.to_keep)
                vds = hl.vds.filter_intervals(
                    vds, intervals=interval_qc_ht.interval.collect(), keep=True
                )

            if high_qual_ukbb_exome_interval_only:
                if not autosomes_only:
                    raise ValueError(
                        "UK Biobank interval QC filtering is only available for autosomes!"
                    )

                logger.info(
                    "Filtering CCDG %s VDS to high quality (>80%% of samples with 20X coverage) UK Biobank exome intervals...",
                    data_type,
                )
                interval_qc_ht = hl.read_table(
                    ukbb_interval_qc_path("broad", 7, "autosomes")
                )  # Note: freeze 7 is all included in gnomAD v4
                interval_qc_ht = interval_qc_ht.filter(
                    interval_qc_ht["pct_samples_20x"] > pct_samples_ukbb_exome_interval
                )
                vds = hl.vds.filter_intervals(
                    vds, intervals=interval_qc_ht.interval.collect(), keep=True
                )

            logger.info("Densifying filtered CCDG %s VDS...", data_type)
            mt = hl.vds.to_dense_mt(vds)
            if adj_only:
                mt = filter_to_adj(mt)

            annotation_expr = {
                f"ccdg_{data_type}_was_split": mt.was_split,
                f"ccdg_{data_type}_AC": hl.agg.sum(mt.GT.n_alt_alleles()),
                f"ccdg_{data_type}_AN": hl.agg.count_where(hl.is_defined(mt.GT)) * 2,
            }

            if min_inbreeding_coeff_threshold is not None:
                annotation_expr[
                    f"ccdg_{data_type}_site_inbreeding_coeff"
                ] = bi_allelic_site_inbreeding_expr(mt.GT)
            if min_hardy_weinberg_threshold is not None:
                annotation_expr[f"ccdg_{data_type}_hwe"] = hl.agg.hardy_weinberg_test(
                    mt.GT
                )

            mt = mt.annotate_rows(**annotation_expr)
            ht = mt.rows().checkpoint(
                get_ccdg_results_path(
                    data_type=data_type,
                    mt=False,
                    result=f"pre_filtered_variants_interval{INTERVAL_DP}x{flag}",
                ),
                overwrite=(not read_per_dataset_checkpoint_if_exists),
                _read_if_exists=read_per_dataset_checkpoint_if_exists,
            )

            return ht

        logger.info(
            "Creating Table with joint gnomAD v3.1.2 and CCDG genome allele frequencies and callrate...",
        )
        ccdg_genomes_ht = _initial_filter("genomes")
        ccdg_exomes_ht = _initial_filter("exomes")
        ht = ccdg_exomes_ht.join(ccdg_genomes_ht, how="inner")
        ht = ht.annotate(**gnomad_ht[ht.key], **ukbb_ht[ht.key])
        ht = ht.annotate(
            joint_biallelic=(~ht.ccdg_genomes_was_split) | (~ht.gnomad_was_split),
            joint_AC=ht.ccdg_genomes_AC + ht.gnomad_AC,
            joint_AN=ht.ccdg_genomes_AN + ht.gnomad_AN,
        )
        total_genome_an = hl.eval(
            (gnomad_ht.freq_sample_count[0] + ccdg_genome_count) * 2
        )
        ht = ht.annotate(
            joint_AF=ht.joint_AC / ht.joint_AN,
            joint_callrate=ht.joint_AN / total_genome_an,
        )
        ht = ht.checkpoint(
            f"{get_joint_pca_variants_ht_path(filter_washu=filter_washu)}",
            overwrite=(not read_pre_ld_prune_ht_checkpoint_if_exists),
            _read_if_exists=read_pre_ld_prune_ht_checkpoint_if_exists,
        )

        logger.info(
            "Filtering variants to combined gnomAD v3.1.2 and CCDG genome AF of %.3f and callrate of %.2f, CCDG exome callrate "
            "of %.2f, and UK Biobank exome callrate of %.2f....",
            min_joint_af,
            min_joint_callrate,
            min_ccdg_exome_callrate,
            min_ukbb_exome_callrate,
        )

        variant_filter_expr = True
        if bi_allelic_only:
            variant_filter_expr &= ht.joint_biallelic
        if min_inbreeding_coeff_threshold is not None:
            variant_filter_expr &= (
                ht.ccdg_genomes_site_inbreeding_coeff > min_inbreeding_coeff_threshold
            ) & (
                ht.gnomad_genomes_site_inbreeding_coeff > min_inbreeding_coeff_threshold
            )
        if min_hardy_weinberg_threshold is not None:
            variant_filter_expr &= (
                ht.ccdg_genomes_hwe.p_value > min_hardy_weinberg_threshold
            ) & (ht.gnomad_genomes_hwe.p_value > min_hardy_weinberg_threshold)

        variant_filter_expr &= (
            (ht.joint_AF > min_joint_af)
            & (ht.joint_callrate > min_joint_callrate)
            & (ht.ccdg_exomes_AN / (ccdg_exome_count * 2) > min_ccdg_exome_callrate)
            & (ht.ukbb_AN / (ukbb_exome_count * 2) > min_ukbb_exome_callrate)
        )

        ht = ht.filter(variant_filter_expr)

        ht = ht.annotate_globals(
            autosomes_only=autosomes_only,
            snv_only=snv_only,
            adj_only=adj_only,
            bi_allelic_only=bi_allelic_only,
            min_gnomad_v3_ac=min_gnomad_v3_ac,
            high_qual_ccdg_exome_interval_only=high_qual_ccdg_exome_interval_only,
            high_qual_ukbb_exome_interval_only=high_qual_ukbb_exome_interval_only,
            filter_lcr=filter_lcr,
            filter_segdup=filter_segdup,
            min_af=min_joint_af,
            min_callrate=min_joint_callrate,
            min_ccdg_exome_callrate=min_ccdg_exome_callrate,
            min_ukbb_exome_callrate=min_ukbb_exome_callrate,
            min_inbreeding_coeff_threshold=min_inbreeding_coeff_threshold,
            min_hardy_weinberg_threshold=min_hardy_weinberg_threshold,
        )

        ht = filter_low_conf_regions(
            ht,
            filter_lcr=filter_lcr,
            filter_decoy=False,  # No decoy for GRCh38
            filter_segdup=filter_segdup,
        )

        ht = ht.checkpoint(
            get_pca_variants_path(ld_pruned=False, filter_washu=filter_washu),
            overwrite=True,
        )
    else:
        ht = hl.read_table(
            get_pca_variants_path(
                ld_pruned=False, data=ld_pruning_dataset, filter_washu=filter_washu
            )
        )

    if ld_pruning:
        # Whether this is still required?
        logger.warning(
            "The LD-prune step of this function requires non-preemptible workers only!"
        )
        logger.info("Creating Table after LD pruning of %s...", ld_pruning_dataset)
        if ld_pruning_dataset == "ccdg_genomes":
            vds = get_ccdg_vds("genomes")
            vds = hl.vds.split_multi(vds, filter_changed_loci=True)
            vds = hl.vds.filter_variants(vds, ht, keep=True)
            mt = hl.vds.to_dense_mt(vds)
        elif ld_pruning_dataset == "gnomad_genomes":
            mt = get_gnomad_v3_mt(key_by_locus_and_alleles=True)
            logger.info("Converting gnomAD v3.1 MatrixTable to VDS...")
            mt = mt.select_entries(
                "END", "LA", "LGT", adj=get_adj_expr(mt.LGT, mt.GQ, mt.DP, mt.LAD)
            )
            vds = hl.vds.VariantDataset.from_merged_representation(mt)

            logger.info("Performing split-multi and filtering variants...")
            vds = hl.vds.split_multi(vds, filter_changed_loci=True)
            vds = hl.vds.filter_variants(vds, ht)

            logger.info("Densifying data...")
            mt = hl.vds.to_dense_mt(vds)
        else:
            ValueError(
                "Only options for LD pruning are `ccdg_genomes` and `gnomad_genomes`"
            )

        hl._set_flags(no_whole_stage_codegen="1")
        mt = mt.checkpoint(
            get_pca_variants_path(ld_pruned=False, data=ld_pruning_dataset, mt=True),
            overwrite=(not read_pre_ld_prune_mt_checkpoint_if_exists),
            _read_if_exists=read_pre_ld_prune_mt_checkpoint_if_exists,
        )
        hl._set_flags(no_whole_stage_codegen=None)
        ht = hl.ld_prune(mt.GT, r2=ld_r2)
        ht = ht.annotate_globals(ld_r2=ld_r2, ld_pruning_dataset=ld_pruning_dataset)
        ht = ht.checkpoint(
            get_pca_variants_path(ld_pruned=True, data=ld_pruning_dataset),
            overwrite=overwrite,
            _read_if_exists=(not overwrite),
        )
        mt = mt.filter_rows(hl.is_defined(ht[mt.row_key]))
        mt.naive_coalesce(1000).write(
            get_pca_variants_path(ld_pruned=True, data=ld_pruning_dataset, mt=True),
            overwrite=overwrite,
        )
def main(args):
    """
    Build the gene-level burden MatrixTable and/or run case/control burden tests.

    Two independent steps are run, depending on the flags set on `args`:

    - `args.create_gene_sample_mt`: reads the callset MatrixTable, annotates it with
      sample metadata, population distance and variant annotations, filters samples,
      sites and entries, then aggregates the entries by gene and writes the resulting
      gene x sample count MatrixTable.
    - `args.run_burden_tests`: reads the gene x sample count MatrixTable written by the
      step above, annotates each gene with dominant and recessive Fisher's exact tests
      for several variant categories and writes the result.

    :param args: Parsed arguments. The attributes read are `create_gene_sample_mt`,
        `run_burden_tests`, `pop_distance`, `max_gnomad_af`, `raw`, `pext_cutoff`
        and `overwrite`.
    :return: None
    """
    if args.create_gene_sample_mt:
        mt = hl.read_matrix_table(
            'gs://gnomad/projects/compound_hets/myoseq/MacArthur_LGMD_Callset_Jan2019.mt'
        )
        meta = hl.read_table(
            'gs://gnomad/projects/compound_hets/myoseq/sample_qc/MacArthur_LGMD_Callset_Jan2019.full_meta.ht'
        )
        pop_distance = hl.read_table(
            'gs://gnomad-lfran/compound_hets/myoseq/sample_qc/myoseq_pop_distance_to_max_kde.ht'
        )
        variant_annotations_ht = hl.read_table(
            'gs://gnomad/projects/compound_hets/myoseq/MacArthur_LGMD_Callset_Jan2019.annotations.ht'
        )
        # Hail Tables are immutable: `drop` returns a new Table, so the result has to
        # be re-assigned for the fields to actually be removed before the join below.
        variant_annotations_ht = variant_annotations_ht.drop('was_split', 'a_index')
        mt = mt.annotate_cols(
            **meta[mt.col_key],
            **pop_distance[mt.col_key],
        )
        mt = mt.annotate_rows(**variant_annotations_ht[mt.row_key])

        # Filter samples failing QC
        mt = mt.filter_cols(
            (hl.len(mt.sample_filters) == 0)
            # NFE pop-distance away from densest point in KDE in pc-space (selects only NFEs)
            & (mt.distance < args.pop_distance)
        )
        counts = mt.aggregate_cols(hl.agg.counter(mt.is_case))
        # Use `.get` so that a cohort left without cases (or without controls) after
        # filtering is reported as 0 rather than raising a KeyError.
        print(
            f'Found {counts.get(True, 0)} cases and {counts.get(False, 0)} controls for gene aggregation.'
        )

        # Filter sites failing QC, without any tx_annotation (i.e. without a protein-coding variant) or too common
        mt = mt.filter_rows(
            (hl.len(mt.filters) == 0) & hl.is_defined(mt.tx_annotation)
            & (hl.or_else(mt.gnomad_exomes_popmax.AF,
                          hl.or_else(mt.gnomad_genomes_popmax.AF, 0.0)) <
               args.max_gnomad_af))

        # Keep non-ref entries only (and, unless `--raw` was requested, adj entries only)
        entries_filter_expr = mt.GT.is_non_ref()
        if not args.raw:
            entries_filter_expr &= get_adj_expr(
                mt.GT, mt.GQ, mt.DP, mt.AD, haploid_adj_dp=5)
        mt = mt.filter_entries(entries_filter_expr)

        # Annotate each variant with the set of (gene_symbol, gene_id) pairs found in
        # its transcript annotations
        mt = mt.annotate_rows(gene=hl.set(
            mt.tx_annotation.map(
                lambda x: hl.struct(gene_symbol=x.symbol, gene_id=x.ensg))))

        # Explode to one row per (variant, gene) and keep only the transcript
        # annotations belonging to that row's gene
        mt = mt.explode_rows(mt.gene)
        mt = mt.annotate_rows(tx_annotation=mt.tx_annotation.filter(lambda x: (
            x.symbol == mt.gene.gene_symbol) & (x.ensg == mt.gene.gene_id)))

        # TODO: Add pext to missense counts

        # Predicates over a single transcript annotation, one per variant category
        def _is_lof(x):
            return x.lof == 'HC'

        def _is_lof_pext(x):
            return (x.lof == 'HC') & (x.Muscle_Skeletal >= args.pext_cutoff)

        def _is_missense(x):
            return x.csq == 'missense_variant'

        def _is_damaging_missense(x):
            # NOTE(review): this does not require `csq == 'missense_variant'` and the
            # polyphen label is spelled with a space -- confirm both against the
            # annotation table.
            return ((x.polyphen_prediction == 'probably damaging') |
                    (x.sift_prediction == 'deleterious'))

        def _is_synonymous(x):
            return x.csq == 'synonymous_variant'

        def _n_variants(gt_expr, tx_predicate):
            # Number of aggregated entries with the given genotype state for which at
            # least one transcript annotation satisfies `tx_predicate`
            return hl.agg.count_where(
                gt_expr & mt.tx_annotation.any(tx_predicate))

        is_het = mt.GT.is_het()
        is_hom_var = mt.GT.is_hom_var()

        # Aggregate by gene
        mt.group_rows_by(**mt.gene).aggregate(
            locus_interval=hl.locus_interval(
                hl.agg.take(mt.locus, 1)[0].contig,
                hl.agg.min(mt.locus.position),
                hl.agg.max(mt.locus.position),
                includes_end=True),
            n_het_lof=_n_variants(is_het, _is_lof),
            n_hom_lof=_n_variants(is_hom_var, _is_lof),
            n_het_lof_pext=_n_variants(is_het, _is_lof_pext),
            n_hom_lof_pext=_n_variants(is_hom_var, _is_lof_pext),
            n_het_missense=_n_variants(is_het, _is_missense),
            n_hom_missense=_n_variants(is_hom_var, _is_missense),
            n_het_damaging_missense=_n_variants(is_het, _is_damaging_missense),
            n_hom_damaging_missense=_n_variants(is_hom_var,
                                                _is_damaging_missense),
            n_het_synonymous=_n_variants(is_het, _is_synonymous),
            n_hom_synonymous=_n_variants(is_hom_var, _is_synonymous),
        ).write(
            'gs://gnomad/projects/compound_hets/myoseq/MacArthur_LGMD_Callset_Jan2019_gene_burden.mt',
            overwrite=args.overwrite)

    if args.run_burden_tests:
        mt = hl.read_matrix_table(
            'gs://gnomad/projects/compound_hets/myoseq/MacArthur_LGMD_Callset_Jan2019_gene_burden.mt'
        )

        def fet_expr(het_count_exp: hl.expr.Int64Expression,
                     hom_count_expr: hl.expr.Int64Expression):
            """
            Build the per-gene Fisher's exact test struct for one variant category.

            Each sample is binned by `min(2, n_het + 2 * n_hom)` and the bins are
            counted separately for `is_case == False` and `is_case == True`.

            :param het_count_exp: Per-sample count of heterozygous variants in the gene.
            :param hom_count_expr: Per-sample count of homozygous variants in the gene.
            :return: Struct with `counts` (row 0: `is_case == False`, row 1:
                `is_case == True`; columns: samples in bin 0, 1 and 2), `dominant`
                (bin 0 vs bins 1 + 2) and `recessive` (bins 0 + 1 vs bin 2).
            """

            def _bin_counts(by_status, status):
                # A missing status or a missing bin counts as 0 samples
                return [
                    hl.int32(
                        hl.cond(by_status.contains(status),
                                by_status[status].get(n_alleles, 0), 0))
                    for n_alleles in range(3)
                ]

            return hl.bind(
                lambda x: hl.struct(
                    counts=x,
                    dominant=hl.fisher_exact_test(x[0][0], x[0][1] + x[0][2],
                                                  x[1][0], x[1][1] + x[1][2]),
                    recessive=hl.fisher_exact_test(x[0][0] + x[0][1], x[0][2],
                                                   x[1][0] + x[1][1], x[1][2])),
                hl.bind(
                    lambda x: [_bin_counts(x, False), _bin_counts(x, True)],
                    hl.agg.group_by(
                        mt.is_case,
                        hl.agg.counter(
                            hl.min(2, het_count_exp + 2 * hom_count_expr)))))

        # For the combined categories the het counts are summed for the het argument
        # and the hom counts (`n_hom_lof`, not `n_het_lof`) for the hom argument.
        mt = mt.annotate_rows(
            **{
                'lof':
                fet_expr(mt.n_het_lof, mt.n_hom_lof),
                'lof_pext':
                fet_expr(mt.n_het_lof_pext, mt.n_hom_lof_pext),
                'lof_missense':
                fet_expr(mt.n_het_lof + mt.n_het_missense,
                         mt.n_hom_lof + mt.n_hom_missense),
                'lof_damaging_missense':
                fet_expr(mt.n_het_lof + mt.n_het_damaging_missense,
                         mt.n_hom_lof + mt.n_hom_damaging_missense),
                'synonymous':
                fet_expr(mt.n_het_synonymous, mt.n_hom_synonymous)
            })

        mt.write(
            'gs://gnomad/projects/compound_hets/myoseq/MacArthur_LGMD_Callset_Jan2019_gene_burden_tests.mt',
            overwrite=args.overwrite)