Example 1
    def test_maximal_independent_set_types(self):
        ht = hl.utils.range_table(10)
        ht = ht.annotate(i=hl.struct(a='1', b=hl.rand_norm(0, 1)),
                         j=hl.struct(a='2', b=hl.rand_norm(0, 1)))
        ht = ht.annotate(ii=hl.struct(id=ht.i, rank=hl.rand_norm(0, 1)),
                         jj=hl.struct(id=ht.j, rank=hl.rand_norm(0, 1)))
        hl.maximal_independent_set(ht.ii, ht.jj).count()
Example 2
    def test_maximal_independent_set(self):
        # prefer to remove nodes with higher index
        t = hl.utils.range_table(10)
        graph = t.select(i=hl.int64(t.idx), j=hl.int64(t.idx + 10), bad_type=hl.float32(t.idx))

        mis_table = hl.maximal_independent_set(graph.i, graph.j, True, lambda l, r: l - r)
        mis = [row['node'] for row in mis_table.collect()]
        self.assertEqual(sorted(mis), list(range(0, 10)))
        self.assertEqual(mis_table.row.dtype, hl.tstruct(node=hl.tint64))

        self.assertRaises(ValueError, lambda: hl.maximal_independent_set(graph.i, graph.bad_type, True))
        self.assertRaises(ValueError, lambda: hl.maximal_independent_set(graph.i, hl.utils.range_table(10).idx, True))
        self.assertRaises(ValueError, lambda: hl.maximal_independent_set(hl.literal(1), hl.literal(2), True))
Example 3
    def test_maximal_independent_set(self):
        # prefer to remove nodes with higher index
        t = hl.utils.range_table(10)
        graph = t.select(i=hl.int64(t.idx), j=hl.int64(t.idx + 10), bad_type=hl.float32(t.idx))

        mis_table = hl.maximal_independent_set(graph.i, graph.j, True, lambda l, r: l - r)
        mis = [row['node'] for row in mis_table.collect()]
        self.assertEqual(sorted(mis), list(range(0, 10)))
        self.assertEqual(mis_table.row.dtype, hl.tstruct(node=hl.tint64))
        self.assertEqual(mis_table.key.dtype, hl.tstruct(node=hl.tint64))

        self.assertRaises(ValueError, lambda: hl.maximal_independent_set(graph.i, graph.bad_type, True))
        self.assertRaises(ValueError, lambda: hl.maximal_independent_set(graph.i, hl.utils.range_table(10).idx, True))
        self.assertRaises(ValueError, lambda: hl.maximal_independent_set(hl.literal(1), hl.literal(2), True))
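In the two tests above the tie breaker lambda l, r: l - r makes higher-index nodes compare greater, and hl.maximal_independent_set removes the node that compares greater first (hence the comment "prefer to remove nodes with higher index"). A minimal standalone sketch of the same behaviour, assuming an initialized Hail session:

    import hail as hl

    t = hl.utils.range_table(5)
    # Edges 0-5, 1-6, ..., 4-9: every node has degree one, so which endpoint
    # survives is decided entirely by the tie breaker.
    edges = t.select(i=hl.int64(t.idx), j=hl.int64(t.idx + 5))
    mis = hl.maximal_independent_set(edges.i, edges.j, keep=True,
                                     tie_breaker=lambda l, r: l - r)
    # Higher-valued nodes compare greater and are removed first, so the
    # surviving set is the low-index endpoints 0..4.
    assert sorted(row.node for row in mis.collect()) == list(range(5))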
Example 4
def query():
    """Query script entry point."""

    hl.init(default_reference='GRCh38')

    mt = hl.read_matrix_table(HGDP1KG_TOBWGS)

    # Perform kinship test with pc_relate
    pc_rel_path = output_path('pc_relate_kinship_estimate.ht')
    pc_rel = hl.pc_relate(mt.GT, 0.01, k=10, statistics='kin')
    pc_rel.write(pc_rel_path, overwrite=True)
    pairs = pc_rel.filter(pc_rel['kin'] >= 0.125)
    related_samples_to_remove = hl.maximal_independent_set(
        pairs.i, pairs.j, False)
    n_related_samples = related_samples_to_remove.count()
    print(f'related_samples_to_remove.count() = {n_related_samples}')

    # save as html
    html = pd.DataFrame({
        'removed_individual':
        related_samples_to_remove.node.s.collect()
    }).to_html()
    plot_filename_html = output_path('removed_samples.html', 'web')
    with hl.hadoop_open(plot_filename_html, 'w') as f:
        f.write(html)
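This script only counts and exports the flagged sample IDs. To actually drop them from the MatrixTable, the usual follow-up is an anti-join on the column key, as Example 6 below does; a sketch, reusing mt and related_samples_to_remove from above:

    # Keep only columns whose key is absent from the maximal-independent-set
    # table (its single row field is 'node', here a struct with sample ID 's').
    mt_unrelated = mt.filter_cols(
        hl.is_defined(related_samples_to_remove[mt.col_key]), keep=False)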
Example 5
def query():
    """Query script entry point."""

    hl.init(default_reference='GRCh38')

    mt = hl.read_matrix_table(HGDP1KG_TOBWGS)
    mt = mt.filter_cols(
        (mt.hgdp_1kg_metadata.population_inference.pop == 'nfe')
        | (mt.s.contains('TOB'))
    )
    # Remove related samples (at the 2nd degree or closer)
    king = hl.king(mt.GT)
    king_path = output_path('king_kinship_estimate_NFE.ht')
    king.write(king_path)
    ht = king.entries()
    related_samples = ht.filter((ht.s_1 != ht.s) & (ht.phi > 0.125), keep=True)
    struct = hl.struct(i=related_samples.s_1, j=related_samples.s)
    struct = struct.annotate(phi=related_samples.phi)
    related_samples_to_remove = hl.maximal_independent_set(
        struct.i, struct.j, False  # pylint: disable=E1101
    )
    n_related_samples = related_samples_to_remove.count()
    print(f'related_samples_to_remove.count() = {n_related_samples}')
    # save as html
    html = pd.DataFrame(
        {'related_individual': related_samples_to_remove.node.collect()}
    ).to_html()
    plot_filename_html = output_path('related_samples.html', 'web')
    with hl.hadoop_open(plot_filename_html, 'w') as f:
        f.write(html)
Example 6
def run_pc_relate(mt: hl.MatrixTable,
                  pca_prefix: str,
                  overwrite: bool = False):
    """
    Runs PC-relate to identify relatives in a matrix table
    :param mt: Matrix table to run PC-relate on
    :param pca_prefix: Prefix to path to output relatedness information
    :param overwrite: if True, overwrites existing data
    :return: None; writes the unrelated ('unrel.mt') and related ('rel.mt') MatrixTables
    """
    relatedness_ht = hl.pc_relate(mt.GT,
                                  min_individual_maf=0.05,
                                  min_kinship=0.05,
                                  statistics='kin',
                                  k=20).key_by()
    relatedness_ht.write(pca_prefix + 'relatedness.ht', overwrite)
    relatedness_ht = hl.read_table(pca_prefix + 'relatedness.ht')

    # identify individuals in pairs to remove
    related_samples_to_remove = hl.maximal_independent_set(
        relatedness_ht.i, relatedness_ht.j, False)
    mt_unrel = mt.filter_cols(hl.is_defined(
        related_samples_to_remove[mt.col_key]),
                              keep=False)
    mt_rel = mt.filter_cols(hl.is_defined(
        related_samples_to_remove[mt.col_key]),
                            keep=True)

    mt_unrel.write(pca_prefix + 'unrel.mt', overwrite)
    mt_rel.write(pca_prefix + 'rel.mt', overwrite)
Example 7
    def test_maximal_independent_set_on_floats(self):
        t = hl.utils.range_table(1).annotate(l=hl.struct(s="a", x=3.0),
                                             r=hl.struct(s="b", x=2.82))
        expected = [hl.Struct(node=hl.Struct(s="a", x=3.0))]
        actual = hl.maximal_independent_set(
            t.l, t.r, keep=False,
            tie_breaker=lambda l, r: l.x - r.x).collect()
        assert actual == expected
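With keep=False the call returns the complement of the independent set, i.e. the nodes to remove: the tie breaker makes ("a", x=3.0) compare greater, so it is removed and reported. Reusing t from the test, the keep=True variant returns the surviving node instead (a sketch):

    kept = hl.maximal_independent_set(
        t.l, t.r, keep=True,
        tie_breaker=lambda l, r: l.x - r.x).collect()
    assert kept == [hl.Struct(node=hl.Struct(s="b", x=2.82))]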
Example 8
    def test_maximal_independent_set2(self):
        edges = [(0, 4), (0, 1), (0, 2), (1, 5), (1, 3), (2, 3), (2, 6),
                 (3, 7), (4, 5), (4, 6), (5, 7), (6, 7)]
        edges = [{"i": l, "j": r} for l, r in edges]

        t = hl.Table.parallelize(edges, hl.tstruct(i=hl.tint64, j=hl.tint64))
        mis_t = hl.maximal_independent_set(t.i, t.j)
        self.assertTrue(mis_t.row.dtype == hl.tstruct(node=hl.tint64) and
                        mis_t.globals.dtype == hl.tstruct())

        mis = set([row.node for row in mis_t.collect()])
        maximal_indep_sets = [{0, 6, 5, 3}, {1, 4, 7, 2}]
        non_maximal_indep_sets = [{0, 7}, {6, 1}]
        self.assertTrue(mis in non_maximal_indep_sets or mis in maximal_indep_sets)
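A useful sanity check on any result (a sketch reusing edges and mis_t from the test): no edge of the graph may have both endpoints inside the returned set.

    mis = set(row.node for row in mis_t.collect())
    assert not any(e["i"] in mis and e["j"] in mis for e in edges)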
Example 9
def get_related_samples_to_drop(rank_table: hl.Table,
                                relatedness_ht: hl.Table) -> hl.Table:
    """
    Use the maximal independence function in Hail to intelligently prune clusters of related individuals, removing
    less desirable samples while maximizing the number of unrelated individuals kept in the sample set.

    :param Table rank_table: Table with ranking annotations across exomes and genomes, computed via make_rank_file()
    :param Table relatedness_ht: Table with kinship coefficient annotations computed via pc_relate()
    :return: Table containing sample IDs ('s') to be pruned from the combined exome and genome sample set
    :rtype: Table
    """
    # Define maximal independent set, using rank list
    related_pairs = relatedness_ht.filter(
        relatedness_ht.kin > 0.08838835).select('i', 'j')
    n_related_samples = hl.eval(
        hl.len(
            related_pairs.aggregate(hl.agg.explode(
                lambda x: hl.agg.collect_as_set(x),
                [related_pairs.i, related_pairs.j]),
                                    _localize=False)))
    logger.info('{} samples with at least 2nd-degree relatedness found in callset'.format(
        n_related_samples))
    max_rank = rank_table.count()
    related_pairs = related_pairs.annotate(
        id1_rank=hl.struct(id=related_pairs.i,
                           rank=rank_table[related_pairs.i].rank),
        id2_rank=hl.struct(id=related_pairs.j,
                           rank=rank_table[related_pairs.j].rank)).select(
                               'id1_rank', 'id2_rank')

    def tie_breaker(l, r):
        return hl.or_else(l.rank, max_rank + 1) - hl.or_else(
            r.rank, max_rank + 1)

    related_samples_to_drop_ranked = hl.maximal_independent_set(
        related_pairs.id1_rank,
        related_pairs.id2_rank,
        keep=False,
        tie_breaker=tie_breaker)
    return related_samples_to_drop_ranked.select(
        **related_samples_to_drop_ranked.node.id).key_by('data_type', 's')
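A hypothetical invocation (names illustrative; rank_table as produced by make_rank_file() and relatedness_ht by pc_relate(), per the docstring):

    samples_to_drop = get_related_samples_to_drop(rank_table, relatedness_ht)
    # The result is keyed by ('data_type', 's'); anti-join it against the
    # combined sample table to prune related individuals.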
Example 10
    def test_maximal_independent_set3(self):
        is_case = {"A", "C", "E", "G", "H"}
        edges = [("A", "B"), ("C", "D"), ("E", "F"), ("G", "H")]
        edges = [{"i": {"id": l, "is_case": l in is_case},
                  "j": {"id": r, "is_case": r in is_case}} for l, r in edges]

        t = hl.Table.parallelize(edges, hl.tstruct(i=hl.tstruct(id=hl.tstr, is_case=hl.tbool),
                                                   j=hl.tstruct(id=hl.tstr, is_case=hl.tbool)))

        tiebreaker = lambda l, r: (hl.case()
                                   .when(l.is_case & (~r.is_case), -1)
                                   .when(~(l.is_case) & r.is_case, 1)
                                   .default(0))

        mis = hl.maximal_independent_set(t.i, t.j, tie_breaker=tiebreaker)

        expected_sets = [{"A", "C", "E", "G"}, {"A", "C", "E", "H"}]

        self.assertTrue(mis.all(mis.node.is_case))
        self.assertTrue(set([row.id for row in mis.select(mis.node.id).collect()]) in expected_sets)
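Example 11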
def main(args):
    mt = hl.read_matrix_table(args.matrixtable)
    # ld pruning
    pruned_ht = hl.ld_prune(mt.GT, r2=0.1)
    pruned_mt = mt.filter_rows(hl.is_defined(pruned_ht[mt.row_key]))
    pruned_mt.write(f"{args.output_dir}/mt_ldpruned.mt", overwrite=True)

    # PC relate
    pruned_mt = pruned_mt.select_entries(
        GT=hl.unphased_diploid_gt_index_call(pruned_mt.GT.n_alt_alleles()))

    eig, scores, _ = hl.hwe_normalized_pca(pruned_mt.GT,
                                           k=10,
                                           compute_loadings=False)
    scores.write(f"{args.output_dir}/mt_pruned.pca_scores.ht", overwrite=True)

    relatedness_ht = hl.pc_relate(pruned_mt.GT,
                                  min_individual_maf=0.05,
                                  scores_expr=scores[pruned_mt.col_key].scores,
                                  block_size=4096,
                                  min_kinship=0.05,
                                  statistics='kin2')
    relatedness_ht.write(f"{args.output_dir}/mt_relatedness.ht",
                         overwrite=True)
    pairs = relatedness_ht.filter(relatedness_ht['kin'] > 0.125)
    related_samples_to_remove = hl.maximal_independent_set(pairs.i,
                                                           pairs.j,
                                                           keep=False)
    related_samples_to_remove.write(
        f"{args.output_dir}/mt_related_samples_to_remove.ht", overwrite=True)

    pca_mt = pruned_mt.filter_cols(hl.is_defined(
        related_samples_to_remove[pruned_mt.col_key]),
                                   keep=False)
    related_mt = pruned_mt.filter_cols(hl.is_defined(
        related_samples_to_remove[pruned_mt.col_key]),
                                       keep=True)

    variants, samples = pca_mt.count()

    print(f"{samples} samples after relatedness step.")

    # Population pca

    plink_mt = pca_mt.annotate_cols(uid=pca_mt.s).key_cols_by('uid')
    hl.export_plink(plink_mt,
                    f"{args.output_dir}/mt_unrelated.plink",
                    fam_id=plink_mt.uid,
                    ind_id=plink_mt.uid)
    pca_evals, pca_scores, pca_loadings = hl.hwe_normalized_pca(
        pca_mt.GT, k=20, compute_loadings=True)
    pca_af_ht = pca_mt.annotate_rows(
        pca_af=hl.agg.mean(pca_mt.GT.n_alt_alleles()) / 2).rows()
    pca_loadings = pca_loadings.annotate(
        pca_af=pca_af_ht[pca_loadings.key].pca_af)
    pca_scores.write(f"{args.output_dir}/mt_pca_scores.ht", overwrite=True)
    pca_loadings.write(f"{args.output_dir}/mt_pca_loadings.ht", overwrite=True)

    pca_mt = pca_mt.annotate_cols(scores=pca_scores[pca_mt.col_key].scores)

    variants, samples = related_mt.count()
    print(
        'Projecting population PCs for {} related samples...'.format(samples))
    #related_scores = pc_project(related_mt, pca_loadings)
    #relateds = related_mt.cols()
    #relateds = relateds.annotate(scores=related_scores[relateds.key].scores)

    pca_mt.write(f"{args.output_dir}/mt_pca.mt", overwrite=True)
    p = hl.plot.scatter(pca_mt.scores[0],
                        pca_mt.scores[1],
                        title='PCA',
                        xlabel='PC1',
                        ylabel='PC2')
    output_file(f"{args.plot_dir}/pca.html")
    save(p)
Example 12
def main(args):
    if args.load_ref:
        load_ref(args.dirname, args.basename)

    if args.load_ukbb:
        samples = hl.read_table(
            'gs://ukb-diverse-pops/pigmentation_phenos_covs_pops.ht')
        ukbb = hl.read_matrix_table('gs://ukb31063/ukb31063.genotype.mt')
        ukbb = ukbb.annotate_cols(**samples[ukbb.s])

    if args.intersect_ref:
        intersect_ref(args.dirname, args.basename, ukbb)

    if args.pca_project:
        """
        Compute PCA in global reference panel, project UKBB individuals into PCA space
        """
        ref_in_ukbb = hl.read_matrix_table(args.dirname + 'intersect_' +
                                           args.basename + 'ukbb.mt')
        print('Computing reference PCs')
        run_pca(ref_in_ukbb, args.out_prefix + args.basename + '_ukbb_')

        # project ukbb
        pca_loadings = hl.read_table(
            f'{args.out_prefix}{args.basename}_ukbb_loadings.ht')
        project_mt = hl.read_matrix_table(args.dirname + 'intersect_ukbb_' +
                                          args.basename + '.mt')
        ht = project_individuals(pca_loadings, project_mt)
        ht.export(args.out_prefix + 'ukbb_' + args.basename +
                  '_scores.txt.bgz')

    # if args.continental_pca:
    #     """
    #     Compute PCA within reference panel super pops, project UKBB individuals into PCA space
    #     1. Filter UKBB to individuals in continental population
    #     2. Run PCA on continental ref
    #     3. Project UKBB inds
    #     """
    #     pass

    if args.ukbb_pop_pca:
        """
        Compute PCA in each UKBB population (unrelateds), project reference individuals and relateds into PCA space
        1. Filter UKBB to individuals in continental population
        2. Run PC-relate on these individuals
        # New
        2.5 Filter to pruned set of individuals
        #
        3. Filter UKBB population to unrelated individuals
        4. Run PCA on UKBB unrelateds within population
        5. Project relateds
        """

        for pop in POPS:
            mt = hl.read_matrix_table(get_ukb_grm_mt_path(pop))
            pruned_ht = hl.read_table(get_ukb_grm_pruned_ht_path(pop))
            mt = mt.filter_rows(hl.is_defined(pruned_ht[mt.row_key]))

            # run PC-relate
            if args.overwrite or not hl.hadoop_exists(
                    get_relatedness_path(pop,
                                         extension='all_scores.ht/_SUCCESS')):
                _, scores, _ = hl.hwe_normalized_pca(mt.GT,
                                                     k=10,
                                                     compute_loadings=False)
                scores.write(
                    get_relatedness_path(pop, extension='all_scores.ht'),
                    args.overwrite)
            scores = hl.read_table(
                get_relatedness_path(pop, extension='all_scores.ht'))
            mt = mt.annotate_cols(scores=scores[mt.col_key].scores)
            # For EUR, this required highmem machines with SSDs (needed ~6T of HDFS space, so 20 workers + 100 preemptibles ran in ~7 hours)
            relatedness_ht = hl.pc_relate(
                mt.GT,
                min_individual_maf=0.05,
                scores_expr=mt.scores,
                min_kinship=0.05,
                statistics='kin',
                block_size=4096 if pop == 'EUR' else 512).key_by()
            relatedness_ht.write(get_relatedness_path(pop, extension='ht'),
                                 args.overwrite)
            relatedness_ht = hl.read_table(
                get_relatedness_path(pop, extension='ht'))

            # identify individuals in pairs to remove
            related_samples_to_remove = hl.maximal_independent_set(
                relatedness_ht.i, relatedness_ht.j, False)
            mt_unrel = mt.filter_cols(hl.is_defined(
                related_samples_to_remove[mt.col_key]),
                                      keep=False)
            mt_rel = mt.filter_cols(hl.is_defined(
                related_samples_to_remove[mt.col_key]),
                                    keep=True)

            mt_unrel.write(get_relatedness_path(pop, True, 'mt'),
                           args.overwrite)
            mt_rel.write(get_relatedness_path(pop, extension='mt'),
                         args.overwrite)

    if args.ukb_prune_pca_project:
        for pop in POPS:
            mt_unrel = hl.read_matrix_table(
                get_relatedness_path(pop, True, 'mt'))
            mt_rel = hl.read_matrix_table(
                get_relatedness_path(pop, extension='mt'))

            # Removing individuals
            pruned_inds = hl.import_table(get_pruned_tsv_path(), key='s')
            mt_rel = mt_rel.filter_cols(
                hl.is_defined(pruned_inds[mt_rel.col_key]))
            mt_unrel = mt_unrel.filter_cols(
                hl.is_defined(pruned_inds[mt_unrel.col_key]))

            # Removing sites
            window = '1e6' if pop != 'EUR' else '1e7'
            pruned_ht = hl.read_table(get_ukb_grm_pruned_ht_path(pop, window))
            mt_unrel = mt_unrel.filter_rows(
                hl.is_defined(pruned_ht[mt_unrel.row_key]))

            mt_unrel = mt_unrel.repartition(500).checkpoint(
                hl.utils.new_temp_file())

            pop = pop if window == '1e6' else f'{pop}_{window}'
            run_pca(
                mt_unrel,
                get_relatedness_path(pop, unrelated=True, extension='') + '.',
                args.overwrite)
            pca_loadings = hl.read_table(
                get_relatedness_path(pop,
                                     unrelated=True,
                                     extension='loadings.ht'))
            ht = project_individuals(pca_loadings, mt_rel)
            ht.write(
                get_relatedness_path(pop, extension='scores_projected.ht'),
                args.overwrite)
            hl.read_table(
                get_relatedness_path(
                    pop, extension='scores_projected.ht')).export(
                        get_relatedness_path(
                            pop, extension='scores_projected.txt.bgz'))

    if args.generate_covariates:
        hts = []
        for pop in POPS:
            pop_path = pop if pop != 'EUR' else 'EUR_1e7'
            ht = hl.read_table(
                get_relatedness_path(pop_path,
                                     extension='scores_projected.ht'))
            hts.append(ht.annotate(pop=pop, related=True))
            ht = hl.read_table(
                get_relatedness_path(pop_path, True, extension='scores.ht'))
            ht = ht.transmute(
                **{f'PC{i}': ht.scores[i - 1]
                   for i in range(1, 21)})
            hts.append(ht.annotate(pop=pop, related=False))

        ht = hts[0].union(*hts[1:])
        cov_ht = hl.import_table(get_age_sex_tsv_path(),
                                 impute=True,
                                 force=True,
                                 quote='"',
                                 key='userId').select('age', 'sex')
        cov_ht = cov_ht.annotate(age_sex=cov_ht.age * cov_ht.sex,
                                 age2=hl.int32(cov_ht.age**2),
                                 age2_sex=hl.int32(cov_ht.age**2) * cov_ht.sex)
        ht = ht.annotate(**cov_ht.key_by(userId=hl.str(cov_ht.userId))[ht.key])
        ht.write(get_covariates_ht_path(), args.overwrite)

    get_filtered_mt(imputed=False).cols().export(get_final_sample_set())
Example 13
def main(args):
    global output_prefix
    output_prefix = args.output_dir.rstrip("/") + "/" + splitext(
        basename(args.input_mt))[0]

    if args.compute_qc_mt:
        qc_mt = get_qc_mt(hl.read_matrix_table(args.input_mt))
        qc_mt = qc_mt.repartition(n_partitions=200)
        qc_mt.write(path('qc.mt'), overwrite=args.overwrite)

    if args.compute_qc_metrics:
        logger.info("Computing sample QC")
        mt = filter_to_autosomes(hl.read_matrix_table(args.input_mt))
        strats = {
            'bi_allelic': bi_allelic_expr(mt),
            'multi_allelic': ~bi_allelic_expr(mt)
        }
        for strat, filter_expr in strats.items():
            strat_sample_qc_ht = hl.sample_qc(
                mt.filter_rows(filter_expr)).cols()
            strat_sample_qc_ht.write(path(f'{strat}_sample_qc.ht'),
                                     overwrite=args.overwrite)
        strat_hts = [
            hl.read_table(path(f'{strat}_sample_qc.ht')) for strat in strats
        ]
        sample_qc_ht = strat_hts.pop()
        sample_qc_ht = sample_qc_ht.select(
            sample_qc=merge_sample_qc_expr([sample_qc_ht.sample_qc] + [
                strat_hts[i][sample_qc_ht.key].sample_qc
                for i in range(0, len(strat_hts))
            ]))
        sample_qc_ht.write(path('sample_qc.ht'), overwrite=args.overwrite)

    if args.compute_callrate_mt:
        callrate_mt = compute_callrate_mt(
            hl.read_matrix_table(args.input_mt),
            hl.import_locus_intervals(exome_calling_intervals_path))
        callrate_mt.write(path('callrate.mt'), args.overwrite)

    if args.run_platform_pca:
        eigenvalues, scores_ht, loadings_ht = run_platform_pca(
            hl.read_matrix_table(path('callrate.mt')))
        scores_ht.write(path('platform_pca_scores.ht'),
                        overwrite=args.overwrite)
        loadings_ht.write(path('platform_pca_loadings.ht'),
                          overwrite=args.overwrite)

    if args.assign_platforms:
        platform_ht = assign_platform_from_pcs(
            hl.read_table(path('platform_pca_scores.ht')),
            hdbscan_min_cluster_size=args.hdbscan_min_cluster_size,
            hdbscan_min_samples=args.hdbscan_min_samples)
        platform_ht.write(f'{output_prefix}.platform_pca_results.ht',
                          overwrite=args.overwrite)

    if args.impute_sex:
        sex_ht = infer_sex(hl.read_matrix_table(path('qc.mt')),
                           hl.read_matrix_table(args.input_mt),
                           hl.read_table(path('platform_pca_results.ht')),
                           args.male_threshold, args.female_threshold,
                           args.min_male_y_sites_called,
                           args.max_y_female_call_rate,
                           args.min_y_male_call_rate)
        sex_ht.write(path('sex.ht'), overwrite=args.overwrite)

    if args.run_pc_relate:
        logger.info('Running PCA for PC-Relate')
        qc_mt = hl.read_matrix_table(path('qc.mt')).unfilter_entries()
        eig, scores, _ = hl.hwe_normalized_pca(qc_mt.GT,
                                               k=10,
                                               compute_loadings=False)
        scores.write(path('pruned.pca_scores.ht'), args.overwrite)

        logger.info('Running PC-Relate')
        logger.warning(
            "PC-relate requires SSDs and doesn't work with preemptible workers!"
        )
        scores = hl.read_table(path('pruned.pca_scores.ht'))
        relatedness_ht = hl.pc_relate(qc_mt.GT,
                                      min_individual_maf=0.05,
                                      scores_expr=scores[qc_mt.col_key].scores,
                                      block_size=4096,
                                      min_kinship=args.min_emission_kinship,
                                      statistics='all')
        relatedness_ht.write(path('relatedness.ht'), args.overwrite)

    if args.filter_dups:
        logger.info("Filtering duplicate samples")
        sample_qc_ht = hl.read_table(path('sample_qc.ht'))
        samples_rankings_ht = sample_qc_ht.select(
            rank=-1 * sample_qc_ht.sample_qc.dp_stats.mean)
        dups_ht = filter_duplicate_samples(
            hl.read_table(path('relatedness.ht')), samples_rankings_ht)
        dups_ht.write(path('duplicates.ht'), overwrite=args.overwrite)

    if args.infer_families:
        logger.info("Inferring families")
        duplicates_ht = hl.read_table(path('duplicates.ht'))
        dups_to_remove = duplicates_ht.aggregate(
            hl.agg.explode(lambda x: hl.agg.collect_as_set(x.s),
                           duplicates_ht.filtered))
        ped = infer_families(hl.read_table(path('relatedness.ht')),
                             hl.read_table(path('sex.ht')), dups_to_remove)
        ped.write(path('pedigree.ped'))

    if args.filter_related_samples:
        logger.info("Filtering related samples")
        related_pairs_ht, related_pairs_tie_breaker = rank_related_samples(
            hl.read_table(path('relatedness.ht')), hl.read_table(args.meta),
            hl.read_table(path('sample_qc.ht')),
            hl.import_fam(path('pedigree.ped'), delimiter="\t"))

        related_samples_to_drop_ht = hl.maximal_independent_set(
            related_pairs_ht.i,
            related_pairs_ht.j,
            keep=False,
            tie_breaker=related_pairs_tie_breaker)
        related_samples_to_drop_ht = related_samples_to_drop_ht.key_by()
        related_samples_to_drop_ht = related_samples_to_drop_ht.select(
            **related_samples_to_drop_ht.node)
        related_samples_to_drop_ht = related_samples_to_drop_ht.key_by('s')
        related_samples_to_drop_ht.write(path('related_samples_to_drop.ht'),
                                         overwrite=args.overwrite)

    if args.run_pca:
        logger.info("Running population PCA")
        pca_evals, pop_pca_scores_ht, pop_pca_loadings_ht = run_pca_with_relateds(
            hl.read_matrix_table(path('qc.mt')),
            hl.read_table(path('related_samples_to_drop.ht')), args.n_pcs)
        pop_pca_loadings_ht.write(path('pop_pca_loadings.ht'), args.overwrite)
        pop_pca_scores_ht.write(path('pop_pca_scores.ht'), args.overwrite)

    if args.assign_pops:
        logger.info("Assigning global population labels")
        pop_pca_scores_ht = hl.read_table(path("pop_pca_scores.ht"))
        gnomad_meta_ht = get_gnomad_meta('exomes').select("pop")[
            pop_pca_scores_ht.key]
        pop_pca_scores_ht = pop_pca_scores_ht.annotate(known_pop=hl.or_missing(
            gnomad_meta_ht.pop != "oth", gnomad_meta_ht.pop))
        pop_ht, pops_rf_model = assign_population_pcs(
            pop_pca_scores_ht,
            pc_cols=pop_pca_scores_ht.scores[:args.n_pcs],
            known_col='known_pop',
            min_prob=args.min_pop_prob)

        pop_ht.write(path('pop.ht'), args.overwrite)
        with hl.hadoop_open(path('pop_rf_model.pkl'), 'wb') as out:
            pickle.dump(pops_rf_model, out)

    if args.assign_subpops:
        qc_mt = hl.read_matrix_table(path('qc.mt'))
        pop_ht = hl.read_table(path('pop.ht'))
        meta_ht = hl.read_table(args.meta)[qc_mt.col_key]
        qc_mt = qc_mt.annotate_cols(pop=pop_ht[qc_mt.col_key].pop,
                                    is_case=meta_ht.is_case,
                                    country=meta_ht.country)

        platform_specific_intervals = get_platform_specific_intervals(
            hl.read_table(path('platform_pca_loadings.ht')), threshold=0.01)
        logger.info(
            f'Excluding {len(platform_specific_intervals)} platform-specific intervals for subpop PCA.'
        )
        qc_mt = hl.filter_intervals(qc_mt,
                                    platform_specific_intervals,
                                    keep=False)

        assign_and_write_subpops(
            qc_mt,
            hl.read_table(path('related_samples_to_drop.ht')),
            min_samples_for_subpop=args.min_samples_for_subpop,
            n_pcs=args.n_pcs,
            min_pop_prob=args.min_pop_prob,
            overwrite=args.overwrite,
            pop_ann='pop',
            subpop_ann='country',
            include_in_pop_count=qc_mt.is_case)

    if args.run_kgp_pca:
        logger.info("Joining data with 1000 Genomes")
        qc_mt = hl.read_matrix_table(
            path('qc.mt')).select_rows().select_entries("GT")
        qc_mt = qc_mt.select_cols(known_pop=hl.null(hl.tstr),
                                  known_subpop=hl.null(hl.tstr))
        qc_mt = qc_mt.key_cols_by(_kgp=False, *qc_mt.col_key)

        kgp_mt = hl.read_matrix_table(
            kgp_phase3_genotypes_mt_path()).select_rows()
        kgp_mt = kgp_mt.select_cols(known_pop=kgp_mt.super_pops.get(
            kgp_mt.population, "oth").lower(),
                                    known_subpop=kgp_mt.population.lower())
        kgp_mt = kgp_mt.filter_rows(hl.is_defined(
            qc_mt.rows()[kgp_mt.row_key]))
        kgp_mt = filter_rows_for_qc(kgp_mt)
        kgp_mt = kgp_mt.key_cols_by(_kgp=True, *kgp_mt.col_key)

        union_kgp_qc_mt = qc_mt.union_cols(kgp_mt)
        union_kgp_qc_mt.write(path('union_kgp_qc.mt'),
                              overwrite=args.overwrite)

        logger.info("Computing PCA on data with 1000 Genomes")
        union_kgp_qc_mt = hl.read_matrix_table(path('union_kgp_qc.mt'))
        related_samples_to_drop_ht = hl.read_table(
            path('related_samples_to_drop.ht'))
        related_samples_to_drop_ht = related_samples_to_drop_ht.key_by(
            _kgp=False, *related_samples_to_drop_ht.key)
        pca_evals, union_kgp_pca_scores_ht, union_kgp_pca_loadings_ht = run_pca_with_relateds(
            union_kgp_qc_mt, related_samples_to_drop_ht, args.n_kgp_pcs)
        union_kgp_pca_loadings_ht.write(path('union_kgp_pca_loadings.ht'),
                                        args.overwrite)
        union_kgp_pca_scores_ht.write(path('union_kgp_pca_scores.ht'),
                                      args.overwrite)

    if args.assign_pops_kgp:
        logger.info("Assigning populations based on 1000 Genomes labels")
        union_kgp_qc_mt = hl.read_matrix_table(path('union_kgp_qc.mt'))
        union_kgp_pca_scores_ht = hl.read_table(
            path('union_kgp_pca_scores.ht'))
        union_kgp_pca_scores_ht = union_kgp_pca_scores_ht.annotate(
            known_pop=union_kgp_qc_mt[union_kgp_pca_scores_ht.key].known_pop)
        union_kgp_pop_ht, union_kgp_pop_rf_model = assign_population_pcs(
            union_kgp_pca_scores_ht,
            pc_cols=union_kgp_pca_scores_ht.scores[:args.n_kgp_pcs],
            known_col='known_pop',
            min_prob=args.min_kgp_pop_prob)

        union_kgp_pop_ht.write(path('union_kgp_pop.ht'), args.overwrite)

        with hl.hadoop_open(path('union_kgp_pop_rf_model.pkl'), 'wb') as out:
            pickle.dump(union_kgp_pop_rf_model, out)

    if args.assign_subpops_kgp:
        union_kgp_qc_mt = hl.read_matrix_table(path('union_kgp_qc.mt'))
        meta_ht = hl.read_table(args.meta)
        union_kgp_pop_ht = hl.read_table(path('union_kgp_pop.ht'))
        union_kgp_qc_mt = union_kgp_qc_mt.annotate_cols(
            is_case=meta_ht[union_kgp_qc_mt.col_key].is_case,
            pop=union_kgp_pop_ht[union_kgp_qc_mt.col_key].pop)

        platform_specific_intervals = get_platform_specific_intervals(
            hl.read_table(path('platform_pca_loadings.ht')))
        logger.info(
            f'Excluding {len(platform_specific_intervals)} platform-specific intervals for subpop PCA.'
        )
        union_kgp_qc_mt = hl.filter_intervals(union_kgp_qc_mt,
                                              platform_specific_intervals,
                                              keep=False)

        related_samples_to_drop_ht = hl.read_table(
            path('related_samples_to_drop.ht'))
        related_samples_to_drop_ht = related_samples_to_drop_ht.key_by(
            _kgp=False, *related_samples_to_drop_ht.key)

        assign_and_write_subpops(
            union_kgp_qc_mt,
            related_samples_to_drop_ht,
            min_samples_for_subpop=args.min_samples_for_subpop,
            n_pcs=args.n_kgp_pcs,
            min_pop_prob=args.min_kgp_pop_prob,
            overwrite=args.overwrite,
            pop_ann='pop',
            subpop_ann='known_subpop',
            include_in_pop_count=union_kgp_qc_mt.is_case,
            files_prefix='union_kgp_')

    if args.apply_stratified_filters:
        logger.info("Computing stratified QC")
        for variant_class_prefix in ['', 'bi_allelic_', 'multi_allelic_']:
            sample_qc_ht = hl.read_table(
                path(f'{variant_class_prefix}sample_qc.ht'))
            pop_ht = hl.read_table(path('pops.ht'))
            platform_ht = hl.read_table(path('platform_pca_results.ht'))
            sample_qc_ht = sample_qc_ht.annotate(
                qc_pop=pop_ht[sample_qc_ht.key].pop,
                qc_platform=platform_ht[sample_qc_ht.key].qc_platform)
            stratified_metrics_ht = compute_stratified_metrics_filter(
                sample_qc_ht, args.filtering_qc_metrics.split(","),
                ['qc_pop', 'qc_platform'])
            stratified_metrics_ht.write(
                path(f'{variant_class_prefix}stratified_metrics_filters.ht'),
                overwrite=args.overwrite)

    if args.write_full_meta:
        logger.info("Writing metadata table")

        # List all tables to join with the base meta
        meta_annotation_hts = [
            hl.read_table(path('platform_pca_results.ht')).rename(
                {'scores': 'platform_pc_scores'}),
            hl.read_table(path('sex.ht')),
            flatten_duplicate_samples_ht(hl.read_table(path('duplicates.ht'))),
            hl.read_table(path('related_samples_to_drop.ht')).select(
                related_filtered=True),
            hl.read_table(path('pca_scores.ht')).rename(
                {'scores': 'pop_pc_scores'}),
            hl.read_table(path('pops.ht')).select('pop'),
            hl.read_table(path('nfe.pca_scores.ht')).rename(
                {'scores': 'nfe_pc_scores'}),
            hl.read_table(path('subpops.nfe.ht')).select('subpop')
        ]

        # union_kgp_pops_ht = hl.read_table(path('union_kgp_pops.ht'))
        # union_kgp_pops_ht = union_kgp_pops_ht.filter(~union_kgp_pops_ht._kgp).key_by('s')
        # union_kgp_pops_ht = union_kgp_pops_ht.select(kgp_pop=union_kgp_pops_ht.pop)
        # meta_annotation_hts.append(union_kgp_pops_ht)
        #
        # union_kgp_pca_scores_ht = hl.read_table(path('union_kgp_pca_scores.ht')).rename({'scores': 'kgp_pop_pc_scores'})
        # union_kgp_pca_scores_ht = union_kgp_pca_scores_ht.filter(~union_kgp_pca_scores_ht._kgp).key_by('s')
        # meta_annotation_hts.append(union_kgp_pca_scores_ht)

        gnomad_meta_ht = get_gnomad_meta('exomes')
        gnomad_meta_ht = gnomad_meta_ht.select(
            gnomad_pop=gnomad_meta_ht.pop, gnomad_subpop=gnomad_meta_ht.subpop)
        meta_annotation_hts.append(gnomad_meta_ht)

        for variant_class_prefix in ['', 'bi_allelic_', 'multi_allelic_']:
            sample_qc_ht = hl.read_table(
                path(f'{variant_class_prefix}sample_qc.ht'))
            stratified_metrics_filters_ht = hl.read_table(
                path(f'{variant_class_prefix}stratified_metrics_filters.ht'))
            if variant_class_prefix:
                sample_qc_ht = sample_qc_ht.rename(
                    {'sample_qc': f'{variant_class_prefix}sample_qc'})
                stratified_metrics_filters_ht = stratified_metrics_filters_ht.rename(
                    {
                        f: f'{variant_class_prefix}{f}'
                        for f in list(stratified_metrics_filters_ht.globals) +
                        list(stratified_metrics_filters_ht.row_value)
                    })
            meta_annotation_hts.extend(
                [sample_qc_ht, stratified_metrics_filters_ht])

        meta_ht = hl.read_table(args.meta)
        meta_ht = meta_ht.annotate_globals(
            **{
                name: expr
                for ann_ht in meta_annotation_hts
                for name, expr in ann_ht.index_globals().items()
            })

        meta_ht = meta_ht.annotate(
            **{
                name: expr
                for ann_ht in meta_annotation_hts
                for name, expr in ann_ht[meta_ht.key].items()
            })

        filtering_col_prefix = '' if args.filtering_variant_class == 'all' else args.filtering_variant_class + "_"
        meta_ht = meta_ht.annotate_globals(
            filtering_variant_class=args.filtering_variant_class)
        meta_ht = meta_ht.annotate(sample_filters=add_filters_expr(
            filters={
                "ambiguous sex": hl.is_missing(meta_ht.is_female),
                'call_rate': meta_ht.sample_qc.call_rate < args.min_call_rate,
                'duplicate': hl.is_defined(meta_ht.dup_filtered)
                & meta_ht.dup_filtered,
                'related': meta_ht.related_filtered
            },
            current_filters=meta_ht[
                f'{filtering_col_prefix}pop_platform_filters']))

        meta_ht.write(path('full_meta.ht'), overwrite=args.overwrite)
Example 14
def compute_related_samples_to_drop(
    relatedness_ht: hl.Table,
    rank_ht: hl.Table,
    kin_threshold: float,
    filtered_samples: Optional[hl.expr.SetExpression] = None,
    min_related_hard_filter: Optional[int] = None,
) -> hl.Table:
    """
    Computes a Table with the list of samples to drop (and their global rank) to get the maximal independent set of unrelated samples.

    .. note::

        - `relatedness_ht` should be keyed by exactly two fields of the same type, identifying the pair of samples for each row.
        - `rank_ht` should be keyed by a single key of the same type as a single sample identifier in `relatedness_ht`.

    :param relatedness_ht: Relatedness HT, as produced by e.g. pc_relate
    :param rank_ht: Table with a global rank for each sample (smaller is preferred)
    :param kin_threshold: Kinship threshold to consider two samples as related
    :param filtered_samples: An optional set of samples to exclude (e.g. samples that were hard-filtered). These samples will then appear in the resulting samples to drop.
    :param min_related_hard_filter: If provided, any sample related to more samples than this threshold will be filtered out prior to computing the maximal independent set and will appear in the results.
    :return: A Table with the list of the samples to drop along with their rank.
    """

    # Make sure that the key types are valid
    assert len(list(relatedness_ht.key)) == 2
    assert relatedness_ht.key[0].dtype == relatedness_ht.key[1].dtype
    assert len(list(rank_ht.key)) == 1
    assert relatedness_ht.key[0].dtype == rank_ht.key[0].dtype

    logger.info(
        f"Filtering related samples using a kin threshold of {kin_threshold}")
    relatedness_ht = relatedness_ht.filter(relatedness_ht.kin > kin_threshold)

    filtered_samples_rel = set()
    if min_related_hard_filter is not None:
        logger.info(
            f"Computing samples related to too many individuals (>{min_related_hard_filter}) for exclusion"
        )
        gbi = relatedness_ht.annotate(s=list(relatedness_ht.key))
        gbi = gbi.explode(gbi.s)
        gbi = gbi.group_by(gbi.s).aggregate(n=hl.agg.count())
        filtered_samples_rel = gbi.aggregate(
            hl.agg.filter(gbi.n > min_related_hard_filter,
                          hl.agg.collect_as_set(gbi.s)))
        logger.info(
            f"Found {len(filtered_samples_rel)} samples with too many 1st/2nd degree relatives. These samples will be excluded."
        )

    if filtered_samples is not None:
        filtered_samples_rel = filtered_samples_rel.union(
            relatedness_ht.aggregate(
                hl.agg.explode(
                    lambda s: hl.agg.collect_as_set(s),
                    hl.array(list(relatedness_ht.key)).filter(
                        lambda s: filtered_samples.contains(s)),
                )))

    if len(filtered_samples_rel) > 0:
        filtered_samples_lit = hl.literal(filtered_samples_rel)
        relatedness_ht = relatedness_ht.filter(
            filtered_samples_lit.contains(relatedness_ht.key[0])
            | filtered_samples_lit.contains(relatedness_ht.key[1]),
            keep=False,
        )

    logger.info("Annotating related sample pairs with rank.")
    i, j = list(relatedness_ht.key)
    relatedness_ht = relatedness_ht.key_by(s=relatedness_ht[i])
    relatedness_ht = relatedness_ht.annotate(**{
        i:
        hl.struct(s=relatedness_ht.s, rank=rank_ht[relatedness_ht.key].rank)
    })
    relatedness_ht = relatedness_ht.key_by(s=relatedness_ht[j])
    relatedness_ht = relatedness_ht.annotate(**{
        j:
        hl.struct(s=relatedness_ht.s, rank=rank_ht[relatedness_ht.key].rank)
    })
    relatedness_ht = relatedness_ht.key_by(i, j)
    relatedness_ht = relatedness_ht.drop("s")
    relatedness_ht = relatedness_ht.persist()

    related_samples_to_drop_ht = hl.maximal_independent_set(
        relatedness_ht[i],
        relatedness_ht[j],
        keep=False,
        tie_breaker=lambda l, r: l.rank - r.rank,
    )
    related_samples_to_drop_ht = related_samples_to_drop_ht.key_by()
    related_samples_to_drop_ht = related_samples_to_drop_ht.select(
        **related_samples_to_drop_ht.node)
    related_samples_to_drop_ht = related_samples_to_drop_ht.key_by("s")

    if len(filtered_samples_rel) > 0:
        related_samples_to_drop_ht = related_samples_to_drop_ht.union(
            hl.Table.parallelize(
                [
                    hl.struct(s=s, rank=hl.null(hl.tint64))
                    for s in filtered_samples_rel
                ],
                key="s",
            ))

    return related_samples_to_drop_ht
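A hypothetical call (a sketch; the input tables are illustrative, and 0.08838835 ≈ 2^-3.5 is the second-degree kinship cutoff also used in Example 9 above):

    samples_to_drop_ht = compute_related_samples_to_drop(
        relatedness_ht=relatedness_ht,  # keyed by the sample pair (i, j)
        rank_ht=rank_ht,                # keyed by 's', with a 'rank' field
        kin_threshold=0.08838835,       # ~2nd-degree kinship cutoff
    )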
Example 15
def main(args):
    data_type = "exomes" if args.exomes else "genomes"
    hl.init(log=f"/ccdg_sample_qc_{data_type}.log")
    # gcloud compute scp wlu-m:/hard_filter_genomes.log .
    if args.sample_qc:
        compute_sample_qc(data_type).write(
            get_ccdg_results_path(data_type=data_type, result="sample_qc_all"),
            overwrite=args.overwrite,
        )

    if args.impute_sex:
        compute_sex(data_type).write(
            get_ccdg_results_path(data_type=data_type, result="sex"),
            overwrite=args.overwrite,
        )
    # elif args.reannotate_sex:
    #     reannotate_sex(
    #         args.min_cov,
    #         (args.upper_x, (args.lower_xx, args.upper_xx), args.lower_xxx),
    #         ((args.lower_y, args.upper_y), args.lower_yy),
    #     ).write(
    #         get_ccdg_results_path(data_type=data_type, result="sex"),
    #         overwrite=args.overwrite,
    #     )
    ##### Wait for more information
    # if args.compute_hard_filters:
    #     compute_hard_filters(args.min_cov).write(
    #         hard_filtered_samples.path, overwrite=args.overwrite
    #     )

    if args.run_pc_relate or args.reannotate_relatedness:
        if args.run_pc_relate:
            logger.warning(
                "PC-relate requires SSDs and doesn't work with preemptible workers!"
            )
            relatedness_ht = compute_relatedness(
                data_type,
                overwrite=args.overwrite,
            )
        else:
            relatedness_ht = hl.read_table(
                get_ccdg_results_path(data_type=data_type, result="relatedness")
            ).checkpoint(
                "gs://ccdg/tmp/relatedness_ht_checkpoint.ht", overwrite=True
            )  # Copy HT to temp location to overwrite annotation
        relatedness_ht = annotate_relatedness(
            relatedness_ht,
            first_degree_kin_thresholds=tuple(args.first_degree_kin_thresholds),
            second_degree_min_kin=args.second_degree_kin_cutoff,
            ibd0_0_max=args.ibd0_0_max,
        )
        relatedness_ht.write(
            get_ccdg_results_path(data_type=data_type, result="relatedness"),
            overwrite=args.overwrite,
        )

    if args.compute_related_samples_to_drop:
        relatedness_ht = hl.read_table(
            get_ccdg_results_path(data_type=data_type, result="relatedness")
        )
        related_samples_to_remove = hl.maximal_independent_set(
            relatedness_ht.i, relatedness_ht.j, False
        ).checkpoint(
            get_ccdg_results_path(data_type=data_type, result="related_samples"),
            overwrite=args.overwrite,
        )

    if args.update_variant_filtered_pca_mt:
        pca_var_ht = hl.read_table(get_pca_variants_path())
        mt = hl.vds.to_dense_mt(get_qc_vds(data_type, split=True))
        mt = mt.filter_rows(hl.is_defined(pca_var_ht[mt.row_key])).checkpoint(
            get_pca_variants_path(ld_pruned=True, data=f"ccdg_{data_type}", mt=True),
            overwrite=args.overwrite,
            _read_if_exists=(not args.overwrite),
        )

    if args.run_pc_project:
        ## TODO: Rank samples and hard filter samples
        mt = hl.read_matrix_table(
            get_pca_variants_path(ld_pruned=True, data=f"ccdg_{data_type}", mt=True)
        )

        pca_loadings = hl.read_table(path_to_gnomad_loadings)

        pca_ht = hl.experimental.pc_project(
            mt.GT,
            pca_loadings.loadings,
            pca_loadings.pca_af,
        )

        pca_ht.checkpoint(
            get_ccdg_results_path(
                data_type=data_type, result="gnomad_pc_project_scores"
            ),
            overwrite=args.overwrite,
        )

        # related_ht = hl.read_table(
        #     get_ccdg_results_path(data_type=data_type, result="related_samples")
        # )
        #
        # related_mt = mt.filter_cols(hl.is_defined(related_mt[mt.col_key]), keep=True)
        # pca_mt = mt.filter_cols(hl.is_defined(related_mt[mt.col_key]), keep=False)

        # pca_ht = hl.experimental.pc_project(
        #     pca_mt.GT, pca_loadings.loadings, pca_loadings.pca_af
        # )
        # pca_mt = pca_mt.annotate_cols(scores=pca_ht[pca_mt.col_key].scores)
        #
        # related_ht = hl.experimental.pc_project(
        #     related_mt.GT, pca_loadings.loadings, pca_loadings.pca_af
        # )
        # related_mt = related_mt.annotate_cols(
        #     scores=related_ht[related_mt.col_key].scores
        # )

    if args.assign_pops:
        with hl.hadoop_open(
            path_to_gnomad_rf,
            "rb",
        ) as f:
            fit = pickle.load(f)

        # Reduce the scores to only those used in the RF model, this was 6 for v2 and 16 for v3.1
        n_pcs = fit.n_features_
        pca_ht = hl.read_table(
            get_ccdg_results_path(
                data_type=data_type, result="gnomad_pc_project_scores"
            )
        )
        pca_ht = pca_ht.annotate(scores=pca_ht.scores[:n_pcs])
        pop_ht, rf_model = assign_population_pcs(
            pca_ht,
            pc_cols=pca_ht.scores,
            fit=fit,
        )

        pop_ht = pop_ht.checkpoint(
            get_ccdg_results_path(data_type=data_type, result="pop_assignment"),
            overwrite=args.overwrite,
            _read_if_exists=not args.overwrite,
        )
        pop_ht.transmute(
            **{f"PC{i + 1}": pop_ht.pca_scores[i] for i in range(n_pcs)}
        ).export(
            get_ccdg_results_path(data_type=data_type, result="pop_assignment")[:-2]
            + "tsv"
        )

        with hl.hadoop_open(
            get_ccdg_results_path(data_type=data_type, result="pop_RF_fit")[:-2]
            + "pickle",
            "wb",
        ) as out:
            pickle.dump(rf_model, out)

    if args.calculate_inbreeding:
        qc_mt = hl.read_matrix_table(
            get_pca_variants_path(ld_pruned=True, data=f"ccdg_{data_type}", mt=True)
        )
        pop_ht = hl.read_table(
            get_ccdg_results_path(data_type=data_type, result="pop_assignment"),
        )
        qc_mt = qc_mt.annotate_cols(pop=pop_ht[qc_mt.col_key].pop)
        qc_mt = qc_mt.annotate_rows(
            call_stats_by_pop=hl.agg.group_by(
                qc_mt.pop, hl.agg.call_stats(qc_mt.GT, qc_mt.alleles)
            )
        )
        inbreeding_ht = (
            qc_mt.annotate_cols(
                inbreeding=hl.agg.inbreeding(
                    qc_mt.GT, qc_mt.call_stats_by_pop[qc_mt.pop].AF[1]
                )
            )
            .cols()
            .select("inbreeding")
        )
        inbreeding_ht.write(
            get_ccdg_results_path(data_type=data_type, result="inbreeding"),
            overwrite=args.overwrite,
        )

    if args.apply_stratified_filters or args.apply_regressed_filters:
        filtering_qc_metrics = args.filtering_qc_metrics.split(",")
        sample_qc_ht = hl.read_table(
            get_ccdg_results_path(data_type=data_type, result="sample_qc_bi_allelic")
        )
        pc_scores = hl.read_table(
            get_ccdg_results_path(data_type=data_type, result="pc_scores")
        )
        sample_qc_ht = sample_qc_ht.select(
            scores=pc_scores[sample_qc_ht.key]["scores"],
        )
        pop_ht = hl.read_table(
            get_ccdg_results_path(data_type=data_type, result="pop_assignment"),
        )

        if "inbreeding" in filtering_qc_metrics:
            inbreeding_ht = hl.read_table(
                get_ccdg_results_path(data_type=data_type, result="inbreeding")
            )[sample_qc_ht.key]
            sample_qc_ht = sample_qc_ht.annotate(
                inbreeding=inbreeding_ht.inbreeding.f_stat
            )

        if args.apply_regressed_filters:
            n_pcs = args.regress_n_pcs
            residuals_ht = compute_qc_metrics_residuals(
                ht=sample_qc_ht,
                pc_scores=sample_qc_ht.scores[:n_pcs],
                qc_metrics={
                    metric: sample_qc_ht[metric] for metric in filtering_qc_metrics
                },
            )
            residuals_ht = residuals_ht.filter(
                hl.is_missing(hard_filtered_samples.ht()[residuals_ht.key])
            )
            stratified_metrics_ht = compute_stratified_metrics_filter(
                ht=residuals_ht,
                qc_metrics=dict(residuals_ht.row_value),
                metric_threshold={
                    "n_singleton_residual": (math.inf, 8.0),
                    "r_het_hom_var_residual": (math.inf, 4.0),
                },
            )

            residuals_ht = residuals_ht.annotate(
                **stratified_metrics_ht[residuals_ht.key]
            )
            residuals_ht = residuals_ht.annotate_globals(
                **stratified_metrics_ht.index_globals(),
                n_pcs=n_pcs,
            )
        else:
            logger.info(
                "Computing stratified QC metrics filters using metrics: "
                + ", ".join(filtering_qc_metrics)
            )
            sample_qc_ht = sample_qc_ht.annotate(qc_pop=pop_ht[sample_qc_ht.key].pop)
            # TODO: compute hard-filtered samples
            sample_qc_ht = sample_qc_ht.filter(
                hl.is_missing(hard_filtered_samples.ht()[sample_qc_ht.key])
            )
            stratified_metrics_ht = compute_stratified_metrics_filter(
                sample_qc_ht,
                qc_metrics={
                    metric: sample_qc_ht[metric] for metric in filtering_qc_metrics
                },
                strata={"qc_pop": sample_qc_ht.qc_pop},
                metric_threshold={"n_singleton": (4.0, 8.0)},
            )
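Example 16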
def query():
    """Query script entry point."""

    hl.init(default_reference='GRCh38')

    # save relatedness estimates for pc_relate global populations
    ht = hl.read_table(PC_RELATE_ESTIMATE_GLOBAL)
    related_samples = ht.filter(ht.kin > 0.1)
    pc_relate_global = pd.DataFrame({
        'i_s': related_samples.i.s.collect(),
        'j_s': related_samples.j.s.collect(),
        'kin': related_samples.kin.collect(),
    })
    filename = output_path('pc_relate_global_matrix.csv', 'analysis')
    pc_relate_global.to_csv(filename, index=False)

    # get maximal independent set
    pairs = ht.filter(ht['kin'] >= 0.125)
    related_samples_to_remove = hl.maximal_independent_set(
        pairs.i, pairs.j, False)

    related_samples = pd.DataFrame(
        {'removed_individual': related_samples_to_remove.node.s.collect()})
    filename = output_path('pc_relate_global_maximal_independent_set.csv',
                           'analysis')
    related_samples.to_csv(filename, index=False)

    # save relatedness estimates for pc_relate NFE samples
    ht = hl.read_table(PC_RELATE_ESTIMATE_NFE)
    related_samples = ht.filter(ht.kin > 0.1)
    pc_relate_nfe = pd.DataFrame({
        'i_s': related_samples.i.s.collect(),
        'j_s': related_samples.j.s.collect(),
        'kin': related_samples.kin.collect(),
    })
    filename = output_path('pc_relate_nfe_matrix.csv', 'analysis')
    pc_relate_nfe.to_csv(filename, index=False)
    # get maximal independent set
    pairs = ht.filter(ht['kin'] >= 0.125)
    related_samples_to_remove = hl.maximal_independent_set(
        pairs.i, pairs.j, False)
    related_samples = pd.DataFrame(
        {'removed_individual': related_samples_to_remove.node.s.collect()})
    filename = output_path('pc_relate_nfe_maximal_independent_set.csv',
                           'analysis')
    related_samples.to_csv(filename, index=False)

    # save relatedness estimates for KING NFE samples
    mt = hl.read_matrix_table(KING_ESTIMATE_NFE)
    ht = mt.entries()
    # remove entries where samples are identical and keep related pairs
    related_samples = ht.filter((ht.s_1 != ht.s) & (ht.phi > 0.1))
    king_nfe = pd.DataFrame({
        'i_s': related_samples.s_1.collect(),
        'j_s': related_samples.s.collect(),
        'kin': related_samples.phi.collect(),
    })
    filename = output_path('king_nfe_matrix_90k.csv', 'analysis')
    king_nfe.to_csv(filename, index=False)
    # save KING NFE maximal independent set
    second_degree_related_samples = ht.filter(
        (ht.s_1 != ht.s) & (ht.phi > 0.125), keep=True)
    struct = hl.struct(i=second_degree_related_samples.s_1,
                       j=second_degree_related_samples.s)
    struct = struct.annotate(phi=second_degree_related_samples.phi)
    related_samples_to_remove = hl.maximal_independent_set(
        struct.i,
        struct.j,
        False  # pylint: disable=E1101
    )
    related_samples = pd.DataFrame(
        {'related_individual': related_samples_to_remove.node.collect()})
    filename = output_path(
        'king_90k_related_samples_maximal_independent_set.csv', 'analysis')
    related_samples.to_csv(filename, index=False)