Example 1
        def with_local_a_index(local_a_index):
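            # Recompute PL over the downcoded genotypes 0/0, 0/1 and 1/1: for each
            # new genotype index i, take the minimum LPL over every original
            # genotype index j that downcodes to i.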
            new_pl = hl.or_missing(
                hl.is_defined(old_entry.LPL),
                hl.or_missing(
                    hl.is_defined(local_a_index),
                    hl.range(0, 3).map(lambda i: hl.min(
                        hl.range(0, hl.triangle(hl.len(old_entry.LA)))
                        .filter(lambda j: hl.downcode(hl.unphased_diploid_gt_index_call(j),
                                                      local_a_index)
                                == hl.unphased_diploid_gt_index_call(i))
                        .map(lambda idx: old_entry.LPL[idx])))))
            fields = set(old_entry.keys())

            def with_pl(pl):
                new_exprs = {}
                dropped_fields = ['LA']
                if 'LGT' in fields:
                    new_exprs['GT'] = hl.downcode(
                        old_entry.LGT,
                        hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LGT')
                if 'LPGT' in fields:
                    new_exprs['PGT'] = hl.downcode(
                        old_entry.LPGT,
                        hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LPGT')
                if 'LAD' in fields:
                    new_exprs['AD'] = hl.or_missing(
                        hl.is_defined(old_entry.LAD), [
                            old_entry.LAD[0],
                            hl.or_else(old_entry.LAD[local_a_index], 0)
                        ])  # second entry zeroed for lack of non-ref AD
                    dropped_fields.append('LAD')
                if 'LPL' in fields:
                    new_exprs['PL'] = pl
                    if 'GQ' in fields:
                        new_exprs['GQ'] = hl.or_else(hl.gq_from_pl(pl),
                                                     old_entry.GQ)

                    dropped_fields.append('LPL')

                return hl.cond(
                    hl.len(ds.alleles) == 1,
                    old_entry.annotate(
                        **{
                            f[1:]: old_entry[f]
                            for f in ['LGT', 'LPGT', 'LAD', 'LPL']
                            if f in fields
                        }).drop(*dropped_fields),
                    old_entry.annotate(**new_exprs).drop(*dropped_fields))

            if 'LPL' in fields:
                return hl.bind(with_pl, new_pl)
            else:
                return with_pl(None)
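
The PL recomputation above leans on two Hail primitives. A minimal sketch of their behavior (the value in the comment is what hl.eval returns; any recent Hail version is assumed):

import hail as hl

# Genotype index 4 in unphased diploid (triangle) order is the call 1/2.
call = hl.unphased_diploid_gt_index_call(4)
# Downcoding on allele 2 maps allele 2 to 1 and every other allele to 0,
# so 1/2 becomes 0/1.
print(hl.eval(hl.downcode(call, 2)))  # prints 0/1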
Example 2
def annotate_sex(mt: hl.MatrixTable,
                 out_internal_mt_prefix: str,
                 male_threshold: float = 0.8,
                 female_threshold: float = 0.5) -> hl.MatrixTable:
    """
    Imputes sex, exports data, and annotates mt with this data
    NOTE: Evaluated in R (plots) and decided on cutoff of F<0.5 for females and F>0.8 for males (default) for genomes

    :param MatrixTable mt: MT containing samples to be ascertained for sex
    :param str out_internal_mt_prefix: file path prefix for tsv containing samples and sex imputation annotations
    :return: MatrixTable with imputed sex annotations stashed in column annotation 'sex_check'
    :rtype: MatrixTable
    """
    mt1 = hl.filter_intervals(mt, [hl.parse_locus_interval('chrX')])
    #mt = mt.filter_rows(mt.locus.in_x_nonpar())
    mtx_unphased = mt1.select_entries(
        GT=hl.unphased_diploid_gt_index_call(mt1.GT.n_alt_alleles()))
    #imputed_sex = hl.impute_sex(mtx_unphased.GT)
    sex_ht = hl.impute_sex(mtx_unphased.GT,
                           aaf_threshold=0.05,
                           female_threshold=female_threshold,
                           male_threshold=male_threshold)
    sex_ht.export(out_internal_mt_prefix + '.sex_check.txt.bgz')
    sex_colnames = ['f_stat', 'is_female']
    sex_ht = sex_ht.select(*sex_colnames)
    mt = mt.annotate_cols(**sex_ht[mt.col_key])
    return mt
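
A minimal usage sketch (the input path and output prefix are hypothetical):

mt = hl.read_matrix_table('data/genomes.mt')  # hypothetical path
mt = annotate_sex(mt, 'output/genomes')       # writes output/genomes.sex_check.txt.bgz
mt.cols().select('f_stat', 'is_female').show()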
Example 3
def test_pc_relate_simple_example():
    gs = hl.literal(
        [[0, 0, 0, 0, 1, 1, 1, 1],
         [0, 0, 1, 1, 0, 0, 1, 1],
         [0, 1, 0, 1, 0, 1, 0, 1],
         [0, 0, 1, 1, 0, 0, 1, 1]])
    scores = hl.literal([[0, 1], [1, 1], [1, 0], [0, 0]])
    mt = hl.utils.range_matrix_table(n_rows=8, n_cols=4)
    mt = mt.annotate_entries(GT=hl.unphased_diploid_gt_index_call(gs[mt.col_idx][mt.row_idx]))
    mt = mt.annotate_cols(scores=scores[mt.col_idx])
    pcr = hl.pc_relate(mt.GT, min_individual_maf=0, scores_expr=mt.scores)

    expected = [
        hl.Struct(i=0, j=1, kin=-0.14570713364640647,
                  ibd0=1.4823511628401964, ibd1=-0.38187379109476693, ibd2=-0.10047737174542953),
        hl.Struct(i=0, j=2, kin=0.16530591922102378,
                  ibd0=0.5234783206257841, ibd1=0.2918196818643366, ibd2=0.18470199750987923),
        hl.Struct(i=0, j=3, kin=-0.14570713364640647,
                  ibd0=1.4823511628401964, ibd1=-0.38187379109476693, ibd2=-0.10047737174542953),
        hl.Struct(i=1, j=2, kin=-0.14570713364640647,
                  ibd0=1.4823511628401964, ibd1=-0.38187379109476693, ibd2=-0.10047737174542953),
        hl.Struct(i=1, j=3, kin=0.14285714285714285,
                  ibd0=0.7027734170591313, ibd1=0.02302459445316596, ibd2=0.2742019884877027),
        hl.Struct(i=2, j=3, kin=-0.14570713364640647,
                  ibd0=1.4823511628401964, ibd1=-0.38187379109476693, ibd2=-0.10047737174542953),
    ]
    ht_expected = hl.Table.parallelize(expected)
    ht_expected = ht_expected.key_by(i=hl.struct(col_idx=ht_expected.i),
                                     j=hl.struct(col_idx=ht_expected.j))
    assert ht_expected._same(pcr)
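
The hard-coded floating-point expectations work because Table._same compares numeric fields up to a tolerance (1e-6 by default) rather than exactly.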
Example 4
        def with_local_a_index(local_a_index):
            new_pl = hl.or_missing(
                hl.is_defined(old_entry.LPL),
                hl.or_missing(
                    hl.is_defined(local_a_index),
                    hl.range(0, 3).map(lambda i: hl.min(
                        hl.range(0, hl.triangle(hl.len(old_entry.LA)))
                            .filter(lambda j: hl.downcode(hl.unphased_diploid_gt_index_call(j), local_a_index) == hl.unphased_diploid_gt_index_call(i))
                            .map(lambda idx: old_entry.LPL[idx])))))
            fields = set(old_entry.keys())

            def with_pl(pl):
                new_exprs = {}
                dropped_fields = ['LA']
                if 'LGT' in fields:
                    new_exprs['GT'] = hl.downcode(old_entry.LGT, hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LGT')
                if 'LPGT' in fields:
                    new_exprs['PGT'] = hl.downcode(old_entry.LPGT, hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LPGT')
                if 'LAD' in fields:
                    new_exprs['AD'] = hl.or_missing(
                        hl.is_defined(old_entry.LAD),
                        [old_entry.LAD[0], hl.or_else(old_entry.LAD[local_a_index], 0)]) # second entry zeroed for lack of non-ref AD
                    dropped_fields.append('LAD')
                if 'LPL' in fields:
                    new_exprs['PL'] = pl
                    if 'GQ' in fields:
                        new_exprs['GQ'] = hl.or_else(hl.gq_from_pl(pl), old_entry.GQ)

                    dropped_fields.append('LPL')

                return hl.cond(hl.len(ds.alleles) == 1,
                                   old_entry.annotate(**{f[1:]: old_entry[f] for f in ['LGT', 'LPGT', 'LAD', 'LPL'] if f in fields}).drop(*dropped_fields),
                                   old_entry.annotate(**new_exprs).drop(*dropped_fields))

            if 'LPL' in fields:
                return hl.bind(with_pl, new_pl)
            else:
                return with_pl(None)
Example 5
def main(args):
    mt = hl.read_matrix_table(args.matrixtable)
    # ld pruning
    pruned_ht = hl.ld_prune(mt.GT, r2=0.1)
    pruned_mt = mt.filter_rows(hl.is_defined(pruned_ht[mt.row_key]))
    pruned_mt.write(f"{args.output_dir}/mt_ldpruned.mt", overwrite=True)

    # PC relate
    pruned_mt = pruned_mt.select_entries(
        GT=hl.unphased_diploid_gt_index_call(pruned_mt.GT.n_alt_alleles()))

    eig, scores, _ = hl.hwe_normalized_pca(pruned_mt.GT,
                                           k=10,
                                           compute_loadings=False)
    scores.write(f"{args.output_dir}/mt_pruned.pca_scores.ht", overwrite=True)

    relatedness_ht = hl.pc_relate(pruned_mt.GT,
                                  min_individual_maf=0.05,
                                  scores_expr=scores[pruned_mt.col_key].scores,
                                  block_size=4096,
                                  min_kinship=0.05,
                                  statistics='kin2')
    relatedness_ht.write(f"{args.output_dir}/mt_relatedness.ht",
                         overwrite=True)
    pairs = relatedness_ht.filter(relatedness_ht['kin'] > 0.125)
    related_samples_to_remove = hl.maximal_independent_set(pairs.i,
                                                           pairs.j,
                                                           keep=False)
    related_samples_to_remove.write(
        f"{args.output_dir}/mt_related_samples_to_remove.ht", overwrite=True)

    pca_mt = pruned_mt.filter_cols(
        hl.is_defined(related_samples_to_remove[pruned_mt.col_key]), keep=False)
    related_mt = pruned_mt.filter_cols(
        hl.is_defined(related_samples_to_remove[pruned_mt.col_key]), keep=True)

    variants, samples = pca_mt.count()

    print(f"{samples} samples after relatedness step.")

    # Population PCA

    plink_mt = pca_mt.annotate_cols(uid=pca_mt.s).key_cols_by('uid')
    hl.export_plink(plink_mt,
                    f"{args.output_dir}/mt_unrelated.plink",
                    fam_id=plink_mt.uid,
                    ind_id=plink_mt.uid)
    pca_evals, pca_scores, pca_loadings = hl.hwe_normalized_pca(
        pca_mt.GT, k=20, compute_loadings=True)
    pca_af_ht = pca_mt.annotate_rows(
        pca_af=hl.agg.mean(pca_mt.GT.n_alt_alleles()) / 2).rows()
    pca_loadings = pca_loadings.annotate(
        pca_af=pca_af_ht[pca_loadings.key].pca_af)
    pca_scores.write(f"{args.output_dir}/mt_pca_scores.ht", overwrite=True)
    pca_loadings.write(f"{args.output_dir}/mt_pca_loadings.ht", overwrite=True)

    pca_mt = pca_mt.annotate_cols(scores=pca_scores[pca_mt.col_key].scores)

    variants, samples = related_mt.count()
    print(
        'Projecting population PCs for {} related samples...'.format(samples))
    #related_scores = pc_project(related_mt, pca_loadings)
    #relateds = related_mt.cols()
    #relateds = relateds.annotate(scores=related_scores[relateds.key].scores)

    pca_mt.write(f"{args.output_dir}/mt_pca.mt", overwrite=True)
    p = hl.plot.scatter(pca_mt.scores[0],
                        pca_mt.scores[1],
                        title='PCA',
                        xlabel='PC1',
                        ylabel='PC2')
    output_file(f"{args.plot_dir}/pca.html")
    save(p)
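
This main() relies on imports and CLI wiring the snippet does not show. A hedged sketch, with flag names guessed from the attributes accessed above (args.matrixtable, args.output_dir, args.plot_dir):

import argparse
import hail as hl
from bokeh.io import output_file, save  # used by the plotting calls in main()

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--matrixtable', required=True)
    parser.add_argument('--output_dir', required=True)
    parser.add_argument('--plot_dir', required=True)
    main(parser.parse_args())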
Example 6
def main(args):
    n_partitions = 500

    # ANNOTATION TABLES:
    truth_data_ht = hl.read_table(args.truthset_table)
    trio_stats_table = hl.read_table(args.trio_stats_table)

    #inbreeding_ht = hl.read_table(f'{temp_dir}/ddd-elgh-ukbb/variant_qc/Sanger_cohorts_inbreeding.ht')
    allele_data_ht = hl.read_table(args.allele_data)
    allele_counts_ht = hl.read_table(args.allele_counts)
    allele_counts_ht = allele_counts_ht.select(
        'ac_qc_samples_raw', 'ac_qc_samples_adj')
    inbreeding_ht = hl.read_table(args.inbreeding)
    group = "raw"

    mt = hl.read_matrix_table(args.matrixtable)
    mt = (mt.key_rows_by('locus')
          .distinct_by_row()
          .key_rows_by('locus', 'alleles'))
    mt = mt.select_entries(
        GT=hl.unphased_diploid_gt_index_call(mt.GT.n_alt_alleles()))
    mt = mt.annotate_rows(InbreedingCoeff=hl.or_missing(
        ~hl.is_nan(mt.info.InbreedingCoeff), mt.info.InbreedingCoeff))
    ht = mt.rows()
    ht = ht.transmute(**ht.info)
    ht = ht.select("MQ", "InbreedingCoeff", *INFO_FEATURES)

    trio_stats_ht = trio_stats_table.select(
        f"n_transmitted_{group}", f"ac_children_{group}"
    )

    ht = ht.annotate(
        **inbreeding_ht[ht.key],
        **trio_stats_ht[ht.key],
        **truth_data_ht[ht.key],
        **allele_data_ht[ht.key].allele_data,
        **allele_data_ht[ht.key],
        **allele_counts_ht[ht.key],
    )
    # Filter to only variants found in high quality samples or controls with no LowQual filter
    #ht = ht.filter(
    #    (ht[f"ac_children_{group}"] > 0)
    # )  # TODO: change to AS_lowqual for v3.1 or leave as is to be more consistent with v3.0? I will need to add this annotation if so
    ht = ht.annotate(fail_hard_filters=(ht.QD < 2)
                     | (ht.FS > 60) | (ht.MQ < 30))
    ht = ht.annotate(ac_raw=ht.ac_qc_samples_raw)
    ht = ht.annotate(transmitted_singleton=(
        ht[f"n_transmitted_{group}"] == 1) & (ht[f"ac_qc_samples_{group}"] == 2))

    # Select only the fields required for the random forest (note: some fields excluded here may be needed later)
    ht = ht.select(
        "a_index",
        "was_split",
        *FEATURES,
        *TRUTH_DATA,
        **{
            "transmitted_singleton": (ht[f"n_transmitted_{group}"] == 1)
            & (ht[f"ac_qc_samples_{group}"] == 2),
            "fail_hard_filters": (ht.QD < 2) | (ht.FS > 60) | (ht.MQ < 30),
        },
        ac_raw=ht.ac_qc_samples_raw
    )
    logger.info("Repartirioning")
    ht = ht.repartition(n_partitions, shuffle=False)
    ht = ht.checkpoint(
        f'{args.output_dir}/variant_qc/MegaWES_for_RF_all_cols.ht', overwrite=True)
    ht = median_impute_features(ht, {"variant_type": ht.variant_type})
    ht = ht.checkpoint(
        f'{args.output_dir}/variant_qc/MegaWES_for_RF_by_variant_type_all_cols.ht', overwrite=True)
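
A hedged sketch of the argument parsing this main() implies; the flag names mirror the attributes accessed above and are assumptions, not the original CLI:

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    for flag in ('truthset_table', 'trio_stats_table', 'allele_data',
                 'allele_counts', 'inbreeding', 'matrixtable', 'output_dir'):
        parser.add_argument(f'--{flag}', required=True)
    main(parser.parse_args())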
Example 7
        def with_local_a_index(local_a_index):
            fields = set(old_entry.keys())

            def with_pl(pl):
                new_exprs = {}
                dropped_fields = ['LA']
                if 'LGT' in fields:
                    new_exprs['GT'] = hl.downcode(
                        old_entry.LGT,
                        hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LGT')
                if 'LPGT' in fields:
                    new_exprs['PGT'] = hl.downcode(
                        old_entry.LPGT,
                        hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LPGT')
                if 'LAD' in fields:
                    non_ref_ad = hl.or_else(old_entry.LAD[local_a_index],
                                            0)  # zeroed if not in LAD
                    new_exprs['AD'] = hl.or_missing(
                        hl.is_defined(old_entry.LAD),
                        [hl.sum(old_entry.LAD) - non_ref_ad, non_ref_ad])
                    dropped_fields.append('LAD')
                if 'LPL' in fields:
                    new_exprs['PL'] = pl
                    if 'GQ' in fields:
                        new_exprs['GQ'] = hl.or_else(hl.gq_from_pl(pl),
                                                     old_entry.GQ)

                    dropped_fields.append('LPL')

                return (hl.case()
                        .when(hl.len(ds.alleles) == 1,
                              old_entry.annotate(**{f[1:]: old_entry[f]
                                                    for f in ['LGT', 'LPGT', 'LAD', 'LPL']
                                                    if f in fields}).drop(*dropped_fields))
                        .when(hl.or_else(old_entry.LGT.is_hom_ref(), False),
                              old_entry.annotate(**{f: old_entry[f'L{f}'] if f in ['GT', 'PGT'] else e
                                                    for f, e in new_exprs.items()}).drop(*dropped_fields))
                        .default(old_entry.annotate(**new_exprs)
                                 .drop(*dropped_fields)))

            if 'LPL' in fields:
                new_pl = hl.or_missing(
                    hl.is_defined(old_entry.LPL),
                    hl.or_missing(
                        hl.is_defined(local_a_index),
                        hl.range(0, 3).map(lambda i: hl.min(
                            hl.range(0, hl.triangle(hl.len(old_entry.LA)))
                            .filter(lambda j: hl.downcode(hl.unphased_diploid_gt_index_call(j),
                                                          local_a_index)
                                    == hl.unphased_diploid_gt_index_call(i))
                            .map(lambda idx: old_entry.LPL[idx])))))
                return hl.bind(with_pl, new_pl)
            else:
                return with_pl(None)
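
This variant differs from Examples 1 and 4 in two ways: AD is rebuilt as [sum(LAD) - non_ref_ad, non_ref_ad] instead of taking LAD[0] directly, and hom-ref entries keep their local GT/PGT rather than being downcoded.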
Example 8
#! /usr/bin/python

import sys
import hail as hl

n_samples = int(sys.argv[1])
n_variants = int(sys.argv[2])
path = sys.argv[3]

mt = hl.balding_nichols_model(1, n_samples, n_variants)
mt = mt.key_cols_by(s=hl.str(mt.sample_idx))
mt = mt.annotate_entries(GT=hl.unphased_diploid_gt_index_call(hl.rand_bool(0.5) * 2))
hl.export_vcf(mt, path + ".vcf")
hl.export_plink(mt, path)
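
A hypothetical invocation (the script name is assumed): python make_random_data.py 100 1000 /tmp/random writes /tmp/random.vcf plus the PLINK .bed/.bim/.fam triple at /tmp/random.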
Example 9
def split_multi_dynamic(
        t: Union[hl.MatrixTable, hl.Table],
        keep_star: bool = False,
        left_aligned: bool = True) -> Union[hl.MatrixTable, hl.Table]:
    """
    Splits MatrixTable based on entry fields found. Downcodes whatever it can. Supported so far:
    GT, DP, AD, PL, GQ
    PGT, PID
    ADALL

    :param MatrixTable t: Input MatrixTable
    :param bool keep_star: whether to keep star alleles (passed to SplitMulti)
    :param bool left_aligned: whether matrix table is already left_aligned (passed to SplitMulti)
    :return: Split MatrixTable
    :rtype: MatrixTable
    """
    if isinstance(t, hl.Table):
        t = (t.annotate(a_index=hl.range(1, hl.len(t.alleles)))
             .explode('a_index'))
        # Note: does not minrep at the moment
        return t.annotate(alleles=[t.alleles[0], t.alleles[t.a_index]])
    fields = list(t.entry)
    sm = hl.SplitMulti(t, keep_star=keep_star, left_aligned=left_aligned)
    sm.update_rows(a_index=sm.a_index(), was_split=sm.was_split())
    expression = {}

    # HTS/standard
    if 'GT' in fields:
        expression['GT'] = hl.downcode(t.GT, sm.a_index())
    if 'DP' in fields:
        expression['DP'] = t.DP
    if 'AD' in fields:
        expression['AD'] = hl.or_missing(
            hl.is_defined(t.AD),
            [hl.sum(t.AD) - t.AD[sm.a_index()], t.AD[sm.a_index()]])
    if 'PL' in fields:
        pl = hl.or_missing(
            hl.is_defined(t.PL),
            hl.range(0, 3).map(lambda i: hl.min(
                hl.range(0, hl.triangle(t.alleles.length()))
                .filter(lambda j: hl.downcode(hl.unphased_diploid_gt_index_call(j),
                                              sm.a_index())
                        == hl.unphased_diploid_gt_index_call(i))
                .map(lambda j: t.PL[j]))))
        expression['PL'] = pl
        if 'GQ' in fields:
            expression['GQ'] = hl.gq_from_pl(pl)
    elif 'GQ' in fields:
        expression['GQ'] = t.GQ

    # Phased data
    if 'PGT' in fields:
        expression['PGT'] = hl.downcode(t.PGT, sm.a_index())
    if 'PID' in fields:
        expression['PID'] = t.PID

    # Custom data
    if 'ADALL' in fields:  # found in NA12878
        expression['ADALL'] = hl.or_missing(
            hl.is_defined(t.ADALL),
            [hl.sum(t.ADALL) - t.ADALL[sm.a_index()], t.ADALL[sm.a_index()]])

    sm.update_entries(**expression)
    return sm.result()
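
A minimal usage sketch, assuming a multi-allelic MatrixTable at a hypothetical path:

mt = hl.read_matrix_table('data/multiallelic.mt')  # hypothetical path
split_mt = split_multi_dynamic(mt)
split_mt.rows().select('a_index', 'was_split').show()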
Example 10
    ###################### INPUT DATA ##############################
    #####################################################################
    CHROMOSOME = "WGS"
    # mt = hl.read_matrix_table(
    #    f"{temp_dir}/ddd-elgh-ukbb/chr1_chr20_XY_sex_annotations.mt")

    # ld pruning
    #pruned_ht = hl.ld_prune(mt.GT, r2=0.1)
    #pruned_mt = mt.filter_rows(hl.is_defined(pruned_ht[mt.row_key]))
    # pruned_mt.write(
    #    f"{tmp_dir}/ddd-elgh-ukbb/chr1_chr20_XY_ldpruned.mt", overwrite=True)
    pruned_mt = hl.read_matrix_table(
        f"{temp_dir}/ddd-elgh-ukbb/chr1_chr20_XY_ldpruned.mt")
    # PC relate
    pruned_mt = pruned_mt.select_entries(
        GT=hl.unphased_diploid_gt_index_call(pruned_mt.GT.n_alt_alleles()))

    eig, scores, _ = hl.hwe_normalized_pca(pruned_mt.GT,
                                           k=10,
                                           compute_loadings=False)
    #  scores.write(
    #      f"{tmp_dir}/ddd-elgh-ukbb/chr1_chr20_XY_pruned.pca_scores.ht", overwrite=True)

    relatedness_ht = hl.pc_relate(pruned_mt.GT,
                                  min_individual_maf=0.05,
                                  scores_expr=scores[pruned_mt.col_key].scores,
                                  block_size=4096,
                                  min_kinship=0.05,
                                  statistics='kin2')
    # relatedness_ht.write(
    #     f"{tmp_dir}/ddd-elgh-ukbb/chr1_chr20_XY_relatedness.ht", overwrite=True)
Example 11
#! /usr/bin/python

import sys
import hail as hl

n_samples = int(sys.argv[1])
n_variants = int(sys.argv[2])
path = sys.argv[3]

mt = hl.balding_nichols_model(1, n_samples, n_variants)
mt = mt.key_cols_by(s=hl.str(mt.sample_idx))
mt = mt.annotate_entries(
    GT=hl.unphased_diploid_gt_index_call(hl.rand_bool(0.5) * 2))

hl.export_vcf(mt, path + ".vcf")
hl.export_plink(mt, path)

chimera0 = mt.filter_rows(mt.locus.position < n_variants / 2)
chimera0 = chimera0.filter_cols(chimera0.s == "0")

chimera1 = mt.filter_rows(mt.locus.position >= n_variants / 2)
chimera1 = chimera1.filter_cols(chimera1.s == "1")
chimera1 = chimera1.key_cols_by(s="0")

mt2 = chimera0.union_rows(chimera1)
hl.export_vcf(mt2, path + "-chimera.vcf")
hl.export_plink(mt2, path + "-chimera")
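
The chimera block builds a synthetic sample: sample "0" contributes the first half of the variants and sample "1" (re-keyed as "0") contributes the second half, and the union is exported as a separate VCF/PLINK pair.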
Example 12
def main(args):

    # Init Hail
    hl.init(default_reference=args.default_reference)

    if not args.skip_compute_pc_relate:

        if not args.skip_filter_data:
            # Read MatrixTable
            mt = hl.read_matrix_table(args.mt_input_path)

            # filter variants (bi-allelic, high-callrate, common SNPs)
            logger.info(
                f"Filtering to bi-allelic, high-callrate, common SNPs ({args.maf_threshold}) for pc_relate..."
            )

            mt = (mt.filter_rows(
                (hl.len(mt.alleles) == 2)
                & hl.is_snp(mt.alleles[0], mt.alleles[1])
                & (hl.agg.mean(mt.GT.n_alt_alleles()) / 2 > args.maf_threshold)
                & (hl.agg.fraction(hl.is_defined(mt.GT)) > 0.99)
                & ~mt.was_split).repartition(500, shuffle=False))

            # keep only the GT entry field and write to disk to force evaluation
            (mt.select_entries(mt.GT).write(
                f'{nfs_dir}/hail_data/sample_qc/chd_ukbb.filtered_high_confidence_variants.mt',
                overwrite=args.overwrite))

        mt = hl.read_matrix_table(
            f'{nfs_dir}/hail_data/sample_qc/chd_ukbb.filtered_high_confidence_variants.mt'
        )

        if not args.skip_prune_ld:
            # LD pruning
            # Avoid filtering entries (genotypes) to missing before running LD pruning
            # Zulip Hail support issue -> "BlockMatrix trouble when running pc_relate"
            # mt = mt.unfilter_entries()

            # Prune variants in linkage disequilibrium.
            # Return a table with nearly uncorrelated variants

            logger.info(
                f'Pruning variants in LD from MT with {mt.count_rows()} variants...'
            )

            pruned_variant_table = hl.ld_prune(mt.GT, r2=args.r2)

            # Keep LD-pruned variants
            pruned_mt = mt.filter_rows(
                hl.is_defined(pruned_variant_table[mt.row_key]), keep=True)
            pruned_mt.write(
                f'{nfs_dir}/hail_data/sample_qc/chd_ukbb.ld_pruned.mt',
                overwrite=args.overwrite)

        pruned_mt = hl.read_matrix_table(
            f'{nfs_dir}/hail_data/sample_qc/chd_ukbb.ld_pruned.mt')
        v, s = pruned_mt.count()
        logger.info(f'{s} samples, {v} variants found in LD-pruned MT')

        pruned_mt = pruned_mt.select_entries(
            GT=hl.unphased_diploid_gt_index_call(pruned_mt.GT.n_alt_alleles()))

        # run pc_relate method...compute all stats
        logger.info('Running PCA for PC-Relate...')
        eig, scores, _ = hl.hwe_normalized_pca(pruned_mt.GT,
                                               k=10,
                                               compute_loadings=False)
        scores.write(
            f'{nfs_dir}/hail_data/sample_qc/chd_ukbb.pruned.pca_scores_for_pc_relate.ht',
            overwrite=args.overwrite)

        logger.info(f'Running PC-Relate...')
        scores = hl.read_table(
            f'{nfs_dir}/hail_data/sample_qc/chd_ukbb.pruned.pca_scores_for_pc_relate.ht'
        )
        relatedness_ht = hl.pc_relate(
            call_expr=pruned_mt.GT,
            min_individual_maf=args.min_individual_maf,
            scores_expr=scores[pruned_mt.col_key].scores,
            block_size=4096,
            min_kinship=args.min_kinship,
            statistics='all')

        logger.info(f'Writing relatedness table...')
        # Write/export table to file
        relatedness_ht.write(
            f'{nfs_dir}/hail_data/sample_qc/chd_ukbb.relatedness_kinship.ht',
            overwrite=args.overwrite)

        # Write PCs table to file (if specified)
        # if args.write_to_file:
        #    # Export table to file
        #    relatedness_ht.export(output=f'{args.ht_output_path}.tsv.bgz')

    # retrieve maximal independent set of related samples
    logger.info('Getting optimal set of related samples to prune...')

    relatedness_ht = hl.read_table(
        f'{nfs_dir}/hail_data/sample_qc/chd_ukbb.relatedness_kinship.ht')

    relatedness_ht = (relatedness_ht.flatten().rename({
        'i.s': 'i',
        'j.s': 'j'
    }).repartition(100))

    # import trios info
    fam = import_fam_ht()
    mat_ids = hl.set(fam.mat_id.collect())
    fat_ids = hl.set(fam.pat_id.collect())

    # rank samples by retention priority (e.g. cases over controls)
    tb_rank = make_sample_rank_table(get_sample_meta_data())

    # apply min kinship to consider related pairs
    relatedness_ht = (relatedness_ht.filter(relatedness_ht.kin > MIN_KINSHIP))

    # run maximal_independent_set stratified by groups
    # Note: this method fails when all pairs are considered together (e.g. it removes
    # most of the index cases in trios; we want to keep the index cases, since they are
    # mostly affected individuals rather than parents).

    # defining pairs group
    # TODO: check groups with updated fam file
    relatedness_ht = relatedness_ht.annotate(
        pairs_group=hl.case()
        .when(relatedness_ht.kin > 0.40, 'twins_or_dups')
        .when(mat_ids.contains(relatedness_ht.i)
              | mat_ids.contains(relatedness_ht.j), 'pairs_child_mat')
        .when(fat_ids.contains(relatedness_ht.i)
              | fat_ids.contains(relatedness_ht.j), 'pairs_child_fat')
        .default('pairs_others'))

    groups = (relatedness_ht.aggregate(
        hl.agg.collect_as_set(relatedness_ht['pairs_group'])))
    tbs = []
    for pair_group in groups:
        pair_ht = relatedness_ht.filter(
            relatedness_ht.pairs_group == pair_group)
        tb = get_related_samples_to_drop(rank_table=tb_rank,
                                         relatedness_ht=pair_ht)
        tbs.append(tb)

    related_samples_to_remove = hl.Table.union(*tbs)

    related_samples_to_remove.describe()

    related_samples_to_remove = related_samples_to_remove.checkpoint(
        f'{nfs_dir}/hail_data/sample_qc/chd_ukbb.related_samples_to_remove.ht',
        overwrite=args.overwrite)

    if args.write_to_file:
        (related_samples_to_remove.flatten().export(
            f'{nfs_dir}/hail_data/sample_qc/chd_ukbb.related_samples_to_remove.tsv'
        ))

    hl.stop()
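
Besides the helpers it calls (import_fam_ht, make_sample_rank_table, get_sample_meta_data, get_related_samples_to_drop) and the constants nfs_dir and MIN_KINSHIP, this script implies CLI wiring along these lines; a hedged sketch in which every default value is a guess:

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--mt_input_path')
    parser.add_argument('--default_reference', default='GRCh38')            # guessed default
    parser.add_argument('--maf_threshold', type=float, default=0.01)        # guessed default
    parser.add_argument('--r2', type=float, default=0.1)                    # guessed default
    parser.add_argument('--min_individual_maf', type=float, default=0.05)   # guessed default
    parser.add_argument('--min_kinship', type=float, default=0.05)          # guessed default
    parser.add_argument('--skip_filter_data', action='store_true')
    parser.add_argument('--skip_prune_ld', action='store_true')
    parser.add_argument('--skip_compute_pc_relate', action='store_true')
    parser.add_argument('--overwrite', action='store_true')
    parser.add_argument('--write_to_file', action='store_true')
    main(parser.parse_args())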