Example #1
def main(args):
    # Start Hail
    hl.init(default_reference=args.default_ref_genome)

    # Import unfiltered split MT
    mt = get_mt_data(dataset=args.exome_cohort, part='unfiltered')

    # Compute stratified sample_qc (biallelic and multi-allelic sites)
    sample_qc_ht = compute_sample_qc(mt)

    # Write HT with sample QC metrics
    sample_qc_ht = sample_qc_ht.checkpoint(get_sample_qc_ht_path(
        dataset=args.exome_cohort, part='high_conf_autosomes'),
                                           overwrite=args.overwrite,
                                           _read_if_exists=not args.overwrite)

    # annotate sample population and platform qc info
    pop_qc = hl.read_table(get_sample_qc_ht_path(part='population_qc'))
    platform_qc = hl.read_table(get_sample_qc_ht_path(part='platform_pca'))

    ann_expr = {
        'qc_pop': pop_qc[sample_qc_ht.s].predicted_pop,
        'qc_platform': platform_qc[sample_qc_ht.s].qc_platform
    }

    sample_qc_ht = sample_qc_ht.annotate(**ann_expr)

    # Export HT to file
    if args.write_to_file:
        (sample_qc_ht.flatten().export(
            f"{get_sample_qc_ht_path(dataset=args.exome_cohort, part='high_conf_autosomes')}.tsv.bgz"
        ))

    # Apply stratified sample filters based on defined QC metrics
    exome_qc_metrics = [
        'n_snp', 'r_ti_tv', 'r_insertion_deletion', 'n_insertion',
        'n_deletion', 'r_het_hom_var'
    ]

    print('Computing stratified metrics filters...')
    exome_pop_platform_filter_ht = compute_stratified_metrics_filter(
        sample_qc_ht, exome_qc_metrics, ['qc_pop', 'qc_platform'])

    exome_pop_platform_filter_ht = exome_pop_platform_filter_ht.checkpoint(
        get_sample_qc_ht_path(dataset=args.exome_cohort,
                              part='stratified_metrics_filter'),
        overwrite=args.overwrite,
        _read_if_exists=not args.overwrite)

    # Export HT to file
    if args.write_to_file:
        (exome_pop_platform_filter_ht.export(
            f"{get_sample_qc_ht_path(dataset=args.exome_cohort, part='stratified_metrics_filter')}.tsv.bgz"
        ))

    # Stop Hail
    hl.stop()

    print("Finished!")
Example #2
def main(args):

    # nfs_dir = 'file:///home/ubuntu/data'

    hl.init(default_reference=args.default_reference)

    logger.info("Importing data...")

    # import unfiltered MT
    mt = get_mt_data(dataset=args.exome_cohort, part='unfiltered')

    # keep bi-allelic variants
    mt = (mt
          .filter_rows(bi_allelic_expr(mt), keep=True)
          )

    # read intervals for filtering variants (used mainly for exomes)
    def _get_interval_table(interval: str) -> Union[None, hl.Table]:
        return (get_capture_interval_ht(name=interval,
                                        reference=args.default_reference)
                if interval is not None else None)

    ht = compute_mean_coverage(mt=mt,
                               normalization_contig=args.normalization_contig,
                               included_calling_intervals=_get_interval_table(args.interval_to_include),
                               excluded_calling_intervals=_get_interval_table(args.interval_to_exclude),
                               chr_x=args.chr_x,
                               chr_y=args.chr_y)

    logger.info("Exporting data...")

    # write HT
    output_ht_path = get_sample_qc_ht_path(part='sex_chrom_coverage')
    ht.write(output=output_ht_path,
             overwrite=args.overwrite)

    # export to file if true
    if args.write_to_file:
        (ht
         .export(f'{output_ht_path}.tsv.bgz')
         )

    hl.stop()

    print("Done!")
Example #3
def main(args):

    # Start Hail
    hl.init(default_reference=args.default_ref_genome)

    # Import unfiltered split MT
    mt = get_mt_data(dataset=args.exome_cohort, part='unfiltered')

    # Compute stratified sample_qc (biallelic and multi-allelic sites)
    sample_qc_ht = compute_sample_qc(mt)

    # Write HT with sample QC metrics
    output_path = (
        f'{nfs_dir}/hail_data/sample_qc/chd_ukbb.sample_qc.high_conf.autosomes.cds.capture_intervals.rare_common.ht'
    )

    sample_qc_ht = sample_qc_ht.checkpoint(output_path,
                                           overwrite=args.overwrite,
                                           _read_if_exists=not args.overwrite)

    # annotate sample population and platform qc info
    pop_qc = hl.read_table(get_sample_qc_ht_path(part='population_qc'))
    platform_qc = hl.read_table(get_sample_qc_ht_path(part='platform_pca'))

    ann_expr = {
        'qc_pop': pop_qc[sample_qc_ht.s].predicted_pop,
        'qc_platform': platform_qc[sample_qc_ht.s].qc_platform
    }

    sample_qc_ht = sample_qc_ht.annotate(**ann_expr)

    # Export HT to file
    if args.write_to_file:
        (sample_qc_ht.flatten().export(f"{output_path}.tsv.bgz"))

    # Stop Hail
    hl.stop()

    print("Finished!")
Example #4
def main(args):
    # Start Hail
    hl.init(default_reference=args.default_reference)

    if not args.skip_filter_step:
        logger.info("Importing data...")

        # import unfiltered MT
        mt = get_mt_data(dataset=args.exome_cohort, part='unfiltered')

        # Read MT from 1kgenome and keep only locus defined in interval
        mt_1kg = get_1kg_mt(args.default_reference)

        # Joining dataset (inner join). Keep only 'GT' entry field
        mt_joint = (mt.select_entries('GT').union_cols(
            mt_1kg.select_entries('GT'), row_join_type='inner'))

        logger.info(
            "Filtering joint MT to bi-allelic, high-callrate, common SNPs...")
        mt_joint = (mt_joint.filter_rows(
            bi_allelic_expr(mt_joint)
            & hl.is_snp(mt_joint.alleles[0], mt_joint.alleles[1])
            & (hl.agg.mean(mt_joint.GT.n_alt_alleles()) / 2 > 0.001)
            & (hl.agg.fraction(hl.is_defined(mt_joint.GT)) > 0.99)).
                    naive_coalesce(1000))

        logger.info(
            "Checkpoint: writing joint filtered MT before LD pruning...")
        mt_joint = mt_joint.checkpoint(get_mt_checkpoint_path(
            dataset=args.exome_cohort,
            part='joint_1kg_high_callrate_common_snp_biallelic'),
                                       overwrite=True)

        logger.info(
            f"Running ld_prune with r2 = {args.ld_prune_r2} on MT with {mt_joint.count_rows()} variants..."
        )
        # remove correlated variants
        pruned_variant_table = hl.ld_prune(mt_joint.GT,
                                           r2=args.ld_prune_r2,
                                           bp_window_size=500000,
                                           memory_per_core=512)
        mt_joint = (mt_joint.filter_rows(
            hl.is_defined(pruned_variant_table[mt_joint.row_key])))

        logger.info("Writing filtered joint MT with variants in LD pruned...")
        (mt_joint.write(get_qc_mt_path(
            dataset=args.exome_cohort + '_1kg',
            part='joint_high_callrate_common_snp_biallelic',
            split=True,
            ld_pruned=True),
                        overwrite=args.overwrite))

    logger.info("Importing filtered joint MT...")
    mt_joint = hl.read_matrix_table(
        get_qc_mt_path(dataset=args.exome_cohort + '_1kg',
                       part='joint_high_callrate_common_snp_biallelic',
                       split=True,
                       ld_pruned=True))

    logger.info(f"Running PCA with {mt_joint.count_rows()} variants...")
    # run pca on merged dataset
    eigenvalues, pc_scores, _ = hl.hwe_normalized_pca(mt_joint.GT,
                                                      k=args.n_pcs)

    logger.info(f"Eigenvalues: {eigenvalues}")  # TODO: save eigenvalues?

    # Annotate PC array as independent fields.
    pca_table = (pc_scores.annotate(
        **
        {'PC' + str(k + 1): pc_scores.scores[k]
         for k in range(0, args.n_pcs)}).drop('scores'))

    logger.info(f"Writing HT with PCA results...")
    # write as HT
    output_ht_path = get_sample_qc_ht_path(dataset=args.exome_cohort,
                                           part='joint_pca_1kg')
    pca_table.write(output=output_ht_path)

    if args.write_to_file:
        (pca_table.export(f'{output_ht_path}.tsv.bgz'))

    # Stop Hail
    hl.stop()

    print("Done!")
Example #5
def main(args):
    hl.init(default_reference=args.default_ref_genome)

    if args.run_test_mode:
        logger.info('Running pipeline on test data...')
        mt = (get_mt_data(part='raw_chr20').sample_rows(0.1))
    else:
        logger.info(
            'Running pipeline on MatrixTable with adjusted genotypes...')
        ds = args.exome_cohort
        mt = hl.read_matrix_table(
            get_qc_mt_path(dataset=ds,
                           part='unphase_adj_genotypes',
                           split=True))

    # 1. Sample-QC filtering
    if not args.skip_sample_qc_filtering:
        logger.info('Applying per sample QC filtering...')

        mt = apply_sample_qc_filtering(mt)

        logger.info(
            'Writing sample qc-filtered mt with rare variants (internal maf 0.01) to disk...'
        )
        mt = mt.checkpoint(f'{hdfs_dir}/chd_ukbb.sample_qc_filtered.mt',
                           overwrite=True)

    # 2. Variant-QC filtering
    if not args.skip_variant_qc_filtering:

        logger.info('Applying per variant QC filtering...')

        if hl.hadoop_is_file(
                f'{hdfs_dir}/chd_ukbb.sample_qc_filtered.mt/_SUCCESS'):
            logger.info('Reading pre-existing sample qc-filtered MT...')
            mt = hl.read_matrix_table(
                f'{hdfs_dir}/chd_ukbb.sample_qc_filtered.mt')
        mt = apply_variant_qc_filtering(mt)

        # write hard filtered MT to disk
        logger.info(
            'Writing variant qc-filtered mt with rare variants (internal maf 0.01) to disk...'
        )
        mt = mt.checkpoint(f'{hdfs_dir}/chd_ukbb.variant_qc_filtered.mt',
                           overwrite=True)

    # 3. Annotate AFs

    # allelic frequency cut-off
    maf_cutoff = args.af_max_threshold

    if not args.skip_af_filtering:

        if hl.hadoop_is_file(
                f'{hdfs_dir}/chd_ukbb.variant_qc_filtered.mt/_SUCCESS'):
            logger.info(
                'Reading pre-existing sample/variant qc-filtered MT...')
            mt = hl.read_matrix_table(
                f'{hdfs_dir}/chd_ukbb.variant_qc_filtered.mt')

        # Annotate allelic frequencies from external source,
        # and compute internal AF on samples passing QC
        af_ht = get_af_annotation_ht()

        mt = (mt.annotate_rows(**af_ht[mt.row_key]))

        filter_expressions = [
            af_filter_expr(mt, 'internal_af', af_cutoff=maf_cutoff),
            af_filter_expr(mt, 'gnomad_genomes_af', af_cutoff=maf_cutoff),
            af_filter_expr(mt, 'gnomAD_AF', af_cutoff=maf_cutoff),
            af_filter_expr(mt, 'ger_af', af_cutoff=maf_cutoff),
            af_filter_expr(mt, 'rumc_af', af_cutoff=maf_cutoff),
            af_filter_expr(mt, 'bonn_af', af_cutoff=maf_cutoff)
        ]

        mt = (mt.filter_rows(functools.reduce(operator.iand,
                                              filter_expressions),
                             keep=True))

        logger.info(
            'Writing QC-filtered MT filtered to external MAF to disk...')
        mt = mt.checkpoint(f'{hdfs_dir}/chd_ukbb.qc_final.rare.mt',
                           overwrite=True)

    # 4. ##### Burden Test ######

    logger.info('Running burden test...')

    if hl.hadoop_is_file(f'{hdfs_dir}/chd_ukbb.qc_final.rare.mt/_SUCCESS'):
        logger.info(
            'Reading pre-existing sample/variant qc-filtered MT with rare variants...'
        )
        mt = hl.read_matrix_table(f'{hdfs_dir}/chd_ukbb.qc_final.rare.mt')

    ## Add VEP-annotated fields
    vep_ht = get_vep_annotation_ht()

    mt = (mt.annotate_rows(LoF=vep_ht[mt.row_key].vep.LoF,
                           Consequence=vep_ht[mt.row_key].vep.Consequence,
                           DOMAINS=vep_ht[mt.row_key].vep.DOMAINS,
                           SYMBOL=vep_ht[mt.row_key].vep.SYMBOL))

    ## Filter to bi-allelic variants
    if args.filter_biallelic:
        logger.info('Running burden test on biallelic variants...')
        mt = mt.filter_rows(bi_allelic_expr(mt))

    ## Filter to variants within protein domain(s)
    if args.filter_protein_domain:
        logger.info(
            'Running burden test on variants within protein domain(s)...')
        mt = mt.filter_rows(vep_protein_domain_filter_expr(mt.DOMAINS),
                            keep=True)

    ## Add cases/controls sample annotations
    tb_sample = get_sample_meta_data()
    mt = (mt.annotate_cols(**tb_sample[mt.s]))

    mt = (mt.filter_cols(mt['phe.is_case'] | mt['phe.is_control']))

    ## Annotate pathogenic scores
    ht_scores = get_vep_scores_ht()
    mt = mt.annotate_rows(**ht_scores[mt.row_key])

    ## Classify variant into (major) consequence groups
    score_expr_ann = {
        'hcLOF': mt.LoF == 'HC',
        'syn': mt.Consequence == 'synonymous_variant',
        'miss': mt.Consequence == 'missense_variant'
    }

    # Update dict expr annotations with combinations of variant consequence categories
    score_expr_ann.update({
        'missC': (hl.sum([(mt['vep.MVP_score'] >= MVP_THRESHOLD),
                          (mt['vep.REVEL_score'] >= REVEL_THRESHOLD),
                          (mt['vep.CADD_PHRED'] >= CADD_THRESHOLD)]) >= 2)
        & score_expr_ann.get('miss')
    })

    score_expr_ann.update({
        'hcLOF_missC':
        score_expr_ann.get('hcLOF') | score_expr_ann.get('missC')
    })

    mt = (mt.annotate_rows(csq_group=score_expr_ann))

    # Transmute csq_group and convert dict to set where the group is defined
    # (easier to explode and group by later)
    mt = (mt.transmute_rows(csq_group=hl.set(
        hl.filter(lambda x: mt.csq_group.get(x), mt.csq_group.keys()))))

    mt = (mt.filter_rows(hl.len(mt.csq_group) > 0))

    # Explode nested csq_group before grouping
    mt = (mt.explode_rows(mt.csq_group))

    # print('Number of samples/variants: ')
    # print(mt.count())

    # Group mt by gene/csq_group.
    mt_grouped = (mt.group_rows_by(mt['SYMBOL'], mt['csq_group']).aggregate(
        hets=hl.agg.any(mt.GT.is_het()),
        homs=hl.agg.any(mt.GT.is_hom_var()),
        chets=hl.agg.count_where(mt.GT.is_het()) >= 2,
        homs_chets=(hl.agg.count_where(mt.GT.is_het()) >= 2) |
        (hl.agg.any(mt.GT.is_hom_var()))).repartition(100).persist())
    mts = []

    if args.homs:
        # select homs genotypes.

        mt_homs = (mt_grouped.select_entries(
            mac=mt_grouped.homs).annotate_rows(agg_genotype='homs'))

        mts.append(mt_homs)

    if args.chets:
        # select compound hets (chets) genotypes.
        mt_chets = (mt_grouped.select_entries(
            mac=mt_grouped.chets).annotate_rows(agg_genotype='chets'))

        mts.append(mt_chets)

    if args.homs_chets:
        # select chets and/or homs genotypes.
        mt_homs_chets = (mt_grouped.select_entries(
            mac=mt_grouped.homs_chets).annotate_rows(
                agg_genotype='homs_chets'))

        mts.append(mt_homs_chets)

    if args.hets:
        # select hets genotypes
        mt_hets = (mt_grouped.select_entries(
            mac=mt_grouped.hets).annotate_rows(agg_genotype='hets'))

        mts.append(mt_hets)

    ## Join MatrixTables
    mt_grouped = hl.MatrixTable.union_rows(*mts)

    # Generate table of counts
    tb_gene = (mt_grouped.annotate_rows(
        n_cases=hl.agg.filter(mt_grouped['phe.is_case'],
                              hl.agg.sum(mt_grouped.mac)),
        n_syndromic=hl.agg.filter(mt_grouped['phe.is_syndromic'],
                                  hl.agg.sum(mt_grouped.mac)),
        n_nonsyndromic=hl.agg.filter(mt_grouped['phe.is_nonsyndromic'],
                                     hl.agg.sum(mt_grouped.mac)),
        n_controls=hl.agg.filter(mt_grouped['phe.is_control'],
                                 hl.agg.sum(mt_grouped.mac)),
        n_total_cases=hl.agg.filter(mt_grouped['phe.is_case'], hl.agg.count()),
        n_total_syndromic=hl.agg.filter(mt_grouped['phe.is_syndromic'],
                                        hl.agg.count()),
        n_total_nonsyndromic=hl.agg.filter(mt_grouped['phe.is_nonsyndromic'],
                                           hl.agg.count()),
        n_total_controls=hl.agg.filter(mt_grouped['phe.is_control'],
                                       hl.agg.count())).rows())

    # run fet stratified by proband type
    analysis = ['all_cases', 'syndromic', 'nonsyndromic']

    tbs = []
    for proband in analysis:
        logger.info(f'Running test for {proband}...')
        colCases = None
        colTotalCases = None
        colControls = 'n_controls'
        colTotalControls = 'n_total_controls'
        if proband == 'all_cases':
            colCases = 'n_cases'
            colTotalCases = 'n_total_cases'
        if proband == 'syndromic':
            colCases = 'n_syndromic'
            colTotalCases = 'n_total_syndromic'
        if proband == 'nonsyndromic':
            colCases = 'n_nonsyndromic'
            colTotalCases = 'n_total_nonsyndromic'

        tb_fet = compute_fisher_exact(tb=tb_gene,
                                      n_cases_col=colCases,
                                      n_control_col=colControls,
                                      total_cases_col=colTotalCases,
                                      total_controls_col=colTotalControls,
                                      correct_total_counts=True,
                                      root_col_name='fet',
                                      extra_fields={
                                          'analysis': proband,
                                          'maf': maf_cutoff
                                      })

        # filter out zero-count genes
        tb_fet = (tb_fet.filter(
            hl.sum([tb_fet[colCases], tb_fet[colControls]]) > 0, keep=True))

        tbs.append(tb_fet)

    tb_final = hl.Table.union(*tbs)

    tb_final.describe()

    # export results
    date = current_date()
    run_hash = str(uuid.uuid4())[:6]
    output_path = f'{args.output_dir}/{date}/{args.exome_cohort}.fet_burden.{run_hash}.ht'

    tb_final = (tb_final.checkpoint(output=output_path))

    if args.write_to_file:
        # write table to disk as TSV file
        (tb_final.export(f'{output_path}.tsv'))

    hl.stop()
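
Note: `compute_fisher_exact` is a project helper not shown here. Hail's built-in `hl.fisher_exact_test` covers the core step; a hedged sketch of a per-gene 2x2 test on count columns like those built above (the carrier vs non-carrier contingency layout is an assumption):

import hail as hl

def fisher_exact_sketch(tb: hl.Table, n_cases_col: str, n_controls_col: str,
                        total_cases_col: str, total_controls_col: str) -> hl.Table:
    # Per-row 2x2 contingency table: carriers vs non-carriers in cases and controls
    return tb.annotate(fet=hl.fisher_exact_test(
        hl.int32(tb[n_cases_col]),
        hl.int32(tb[total_cases_col] - tb[n_cases_col]),
        hl.int32(tb[n_controls_col]),
        hl.int32(tb[total_controls_col] - tb[n_controls_col])))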
Example #6
def main(args):
    # Start Hail
    hl.init(default_reference=args.default_ref_genome)

    # Import raw split MT
    mt = (get_mt_data(dataset=args.exome_cohort, part='raw',
                      split=True).select_cols())

    ht = (mt.cols().key_by('s'))

    # Annotate samples filters
    sample_qc_filters = {}

    # 1. Add sample hard filters annotation expr
    sample_qc_hard_filters_ht = hl.read_table(
        get_sample_qc_ht_path(dataset=args.exome_cohort, part='hard_filters'))

    sample_qc_filters.update(
        {'hard_filters': sample_qc_hard_filters_ht[ht.s]['hard_filters']})

    # 2. Add population qc filters annotation expr
    sample_qc_pop_ht = hl.read_table(
        get_sample_qc_ht_path(dataset=args.exome_cohort, part='population_qc'))

    sample_qc_filters.update(
        {'predicted_pop': sample_qc_pop_ht[ht.s]['predicted_pop']})

    # 3. Add relatedness filters annotation expr
    related_samples_to_drop = get_related_samples_to_drop()
    related_samples = hl.set(
        related_samples_to_drop.aggregate(
            hl.agg.collect_as_set(related_samples_to_drop.node.id)))

    sample_qc_filters.update({'is_related': related_samples.contains(ht.s)})

    # 4. Add stratified sample qc (population/platform) annotation expr
    sample_qc_pop_platform_filters_ht = hl.read_table(
        get_sample_qc_ht_path(dataset=args.exome_cohort,
                              part='stratified_metrics_filter'))

    sample_qc_filters.update({
        'pop_platform_filters':
        sample_qc_pop_platform_filters_ht[ht.s]['pop_platform_filters']
    })

    ht = (ht.annotate(**sample_qc_filters))

    # Final sample qc filter joint expression
    final_sample_qc_ann_expr = {
        'pass_filters': ((hl.len(ht.hard_filters) == 0) &
                         (hl.len(ht.pop_platform_filters) == 0) &
                         (ht.predicted_pop == 'EUR') & ~ht.is_related)
    }
    ht = (ht.annotate(**final_sample_qc_ann_expr))

    logger.info('Writing final sample qc HT to disk...')
    output_path_ht = get_sample_qc_ht_path(dataset=args.exome_cohort,
                                           part='final_qc')

    ht = ht.checkpoint(output_path_ht, overwrite=args.overwrite)

    # Export final sample QC annotations to file
    if args.write_to_file:
        (ht.export(f'{output_path_ht}.tsv.bgz'))

    ## Release final unphased MT with adjusted genotypes filtered
    mt = unphase_mt(mt)
    mt = annotate_adj(mt)
    mt = mt.filter_entries(mt.adj).select_entries('GT', 'DP', 'GQ', 'adj')

    logger.info('Writing unphased MT with adjusted genotypes to disk...')
    # write MT
    mt.write(get_qc_mt_path(dataset=args.exome_cohort,
                            part='unphase_adj_genotypes',
                            split=True),
             overwrite=args.overwrite)

    # Stop Hail
    hl.stop()

    print("Finished!")
Example #7
def main(args):
    hl.init(default_reference=args.default_ref_genome)

    if args.run_test_mode:
        logger.info('Running pipeline on test data...')
        mt = (get_mt_data(part='raw_chr20').sample_rows(0.1))
    else:
        logger.info(
            'Running pipeline on MatrixTable with adjusted genotypes...')
        ds = args.exome_cohort
        mt = hl.read_matrix_table(
            get_qc_mt_path(dataset=ds,
                           part='unphase_adj_genotypes',
                           split=True))

    # 1. Sample-QC filtering
    if not args.skip_sample_qc_filtering:
        logger.info('Applying per sample QC filtering...')

        mt = apply_sample_qc_filtering(mt)

        logger.info(
            'Writing sample qc-filtered mt with rare variants (internal maf 0.01) to disk...'
        )
        mt = mt.checkpoint(f'{hdfs_dir}/chd_ukbb.sample_qc_filtered.mt',
                           overwrite=True)

    # 2. Variant-QC filtering
    if not args.skip_variant_qc_filtering:

        logger.info('Applying per variant QC filtering...')

        if hl.hadoop_is_file(
                f'{hdfs_dir}/chd_ukbb.sample_qc_filtered.mt/_SUCCESS'):
            logger.info('Reading pre-existing sample qc-filtered MT...')
            mt = hl.read_matrix_table(
                f'{hdfs_dir}/chd_ukbb.sample_qc_filtered.mt')
        mt = apply_variant_qc_filtering(mt)

        # write hard filtered MT to disk
        logger.info(
            'Writing variant qc-filtered mt with rare variants (internal maf 0.01) to disk...'
        )
        mt = mt.checkpoint(f'{hdfs_dir}/chd_ukbb.variant_qc_filtered.mt',
                           overwrite=True)

    # 3. Annotate AFs

    # allelic frequency cut-off
    maf_cutoff = args.af_max_threshold

    if not args.skip_af_filtering:

        if hl.hadoop_is_file(
                f'{hdfs_dir}/chd_ukbb.variant_qc_filtered.mt/_SUCCESS'):
            logger.info(
                'Reading pre-existing sample/variant qc-filtered MT...')
            mt = hl.read_matrix_table(
                f'{hdfs_dir}/chd_ukbb.variant_qc_filtered.mt')

        # Annotate allelic frequencies from external source,
        # and compute internal AF on samples passing QC
        af_ht = get_af_annotation_ht()

        mt = (mt.annotate_rows(**af_ht[mt.row_key]))

        filter_expressions = [
            af_filter_expr(mt, 'internal_af', af_cutoff=maf_cutoff),
            af_filter_expr(mt, 'gnomad_genomes_af', af_cutoff=maf_cutoff),
            af_filter_expr(mt, 'gnomAD_AF', af_cutoff=maf_cutoff),
            af_filter_expr(mt, 'ger_af', af_cutoff=maf_cutoff),
            af_filter_expr(mt, 'rumc_af', af_cutoff=maf_cutoff),
            af_filter_expr(mt, 'bonn_af', af_cutoff=maf_cutoff)
        ]

        mt = (mt.filter_rows(functools.reduce(operator.iand,
                                              filter_expressions),
                             keep=True))

        logger.info(
            f'Writing sample/variant QCed MT with rare variants at maf: {args.af_max_threshold}.'
        )
        mt = mt.checkpoint(f'{hdfs_dir}/chd_ukbb.qc_final.rare.mt',
                           overwrite=True)

    # 4. ##### Run gene-set burden logistic regression ######

    logger.info('Running gene-set burden logistic regression test...')

    if hl.hadoop_is_file(f'{hdfs_dir}/chd_ukbb.qc_final.rare.mt/_SUCCESS'):
        logger.info(
            'Reading pre-existing sample/variant qc-filtered MT with rare variants...'
        )
        mt = hl.read_matrix_table(f'{hdfs_dir}/chd_ukbb.qc_final.rare.mt')

    ## Add VEP-annotated fields
    vep_ht = get_vep_annotation_ht()

    mt = (mt.annotate_rows(LoF=vep_ht[mt.row_key].vep.LoF,
                           Consequence=vep_ht[mt.row_key].vep.Consequence,
                           DOMAINS=vep_ht[mt.row_key].vep.DOMAINS,
                           SYMBOL=vep_ht[mt.row_key].vep.SYMBOL))

    ## Filter to bi-allelic variants
    if args.filter_biallelic:
        logger.info('Running burden test on biallelic variants...')
        mt = mt.filter_rows(bi_allelic_expr(mt))

    ## Filter to variants within protein domain(s)
    if args.filter_protein_domain:
        logger.info(
            'Running burden test on variants within protein domain(s)...')
        mt = mt.filter_rows(vep_protein_domain_filter_expr(mt.DOMAINS),
                            keep=True)

    ## Annotate pathogenic scores
    ht_scores = get_vep_scores_ht()
    mt = mt.annotate_rows(**ht_scores[mt.row_key])

    ## Classify variant into (major) consequence groups
    score_expr_ann = {
        'hcLOF': mt.LoF == 'HC',
        'syn': mt.Consequence == 'synonymous_variant',
        'miss': mt.Consequence == 'missense_variant'
    }

    # Update dict expr annotations with combinations of variant consequence categories
    score_expr_ann.update({
        'missC': (hl.sum([(mt['vep.MVP_score'] >= MVP_THRESHOLD),
                          (mt['vep.REVEL_score'] >= REVEL_THRESHOLD),
                          (mt['vep.CADD_PHRED'] >= CADD_THRESHOLD)]) >= 2)
        & score_expr_ann.get('miss')
    })

    score_expr_ann.update({
        'hcLOF_missC':
        score_expr_ann.get('hcLOF') | score_expr_ann.get('missC')
    })

    mt = (mt.annotate_rows(csq_group=score_expr_ann))

    # Transmute csq_group and convert dict to set where the group is defined
    # (easier to explode and group by later)
    mt = (mt.transmute_rows(csq_group=hl.set(
        hl.filter(lambda x: mt.csq_group.get(x), mt.csq_group.keys()))))

    mt = (mt.filter_rows(hl.len(mt.csq_group) > 0))

    # Explode nested csq_group and gene clusters before grouping
    mt = (mt.explode_rows(mt.csq_group))

    # First-step aggregation:
    # Generate a sample per gene/variant_type (binary) matrix aggregating genotypes as follows:
    #
    #   a) entry: hets
    #   b) entry: homs
    #   c) entry: chets (compound hets)

    mt_grouped = (mt.group_rows_by(mt['SYMBOL'], mt['csq_group']).aggregate(
        hets=hl.agg.any(mt.GT.is_het()),
        homs=hl.agg.any(mt.GT.is_hom_var()),
        chets=hl.agg.count_where(
            mt.GT.is_het()) >= 2).repartition(100).persist())

    # Import/generate gene clusters
    clusters = hl.import_table(args.set_file,
                               no_header=True,
                               delimiter="\t",
                               min_partitions=50,
                               impute=False)
    clusters = generate_clusters_map(clusters)

    # Annotate gene-set info
    mt_grouped = (mt_grouped.annotate_rows(**clusters[mt_grouped.SYMBOL]))

    # Explode nested csq_group before grouping
    mt_grouped = (mt_grouped.explode_rows(mt_grouped.cluster_id))

    # filter rows with defined consequence and gene-set name
    mt_grouped = (mt_grouped.filter_rows(
        hl.is_defined(mt_grouped.csq_group)
        & hl.is_defined(mt_grouped.cluster_id)))

    # 2. Second-step aggregation
    # Generate a sample per gene-set/variant-type matrix aggregating genotypes as follows:
    # if dominant -> sum hets (default)
    # if recessive -> sum (homs)
    # if recessive (a) -> sum (chets)
    # if recessive (b) -> sum (chets and/or homs)

    mts = []

    if args.homs:
        # Group mt by gene-sets/csq_group aggregating homs genotypes.
        mt_homs = (mt_grouped.group_rows_by(
            mt_grouped.csq_group, mt_grouped.cluster_id).aggregate(
                mac=hl.int(hl.agg.sum(mt_grouped.homs))).repartition(
                    100).persist().annotate_rows(agg_genotype='homs'))

        mts.append(mt_homs)

    if args.chets:
        # Group mt by gene-sets/csq_group aggregating compound hets (chets) genotypes.
        mt_chets = (mt_grouped.group_rows_by(
            mt_grouped.csq_group, mt_grouped.cluster_id).aggregate(
                mac=hl.int(hl.agg.sum(mt_grouped.chets))).repartition(
                    100).persist().annotate_rows(agg_genotype='chets'))

        mts.append(mt_chets)

    if args.homs_chets:
        # Group mt by gene-sets/csq_group aggregating chets and/or homs genotypes.
        mt_homs_chets = (mt_grouped.group_rows_by(
            mt_grouped.csq_group, mt_grouped.cluster_id).aggregate(mac=hl.int(
                hl.agg.count_where(mt_grouped.chets
                                   | mt_grouped.homs))).repartition(100).
                         persist().annotate_rows(agg_genotype='homs_chets'))

        mts.append(mt_homs_chets)

    if args.hets:
        # Group mt by gene-sets/csq_group aggregating hets genotypes (default)
        mt_hets = (mt_grouped.group_rows_by(
            mt_grouped.csq_group, mt_grouped.cluster_id).aggregate(
                mac=hl.int(hl.agg.sum(mt_grouped.hets))).repartition(
                    100).persist().annotate_rows(agg_genotype='hets'))

        mts.append(mt_hets)

    ## Join MatrixTables
    mt_joint = hl.MatrixTable.union_rows(*mts)

    ## Add samples annotations
    # annotate sample covs
    covariates = hl.read_table(
        f'{nfs_dir}/hail_data/sample_qc/chd_ukbb.sample_covariates.ht')
    mt_joint = (mt_joint.annotate_cols(**covariates[mt_joint.s]))

    # annotate case/control phenotype info
    tb_sample = get_sample_meta_data()
    mt_joint = (mt_joint.annotate_cols(**tb_sample[mt_joint.s]))

    mt_joint = (mt_joint.filter_cols(mt_joint['phe.is_case']
                                     | mt_joint['phe.is_control']))

    ## Run logistic regression stratified by proband type
    analysis = ['all_cases', 'syndromic', 'nonsyndromic']

    tbs = []

    covs = ['sex', 'PC1', 'PC2', 'PC3', 'PC4', 'PC5']

    for proband in analysis:
        logger.info(f'Running burden test for {proband}...')

        mt_tmp = None  # placeholder; assigned below depending on the analysis stratum

        if proband == 'all_cases':
            mt_tmp = mt_joint
        if proband == 'syndromic':
            mt_tmp = mt_joint.filter_cols(~mt_joint['phe.is_nonsyndromic'])
        if proband == 'nonsyndromic':
            mt_tmp = mt_joint.filter_cols(~mt_joint['phe.is_syndromic'])

        tb_logreg = logistic_regression(mt=mt_tmp,
                                        x_expr='mac',
                                        response='phe.is_case',
                                        covs=covs,
                                        pass_through=['agg_genotype'],
                                        extra_fields={
                                            'analysis': proband,
                                            'maf': maf_cutoff,
                                            'covs': '|'.join(covs)
                                        })

        tbs.append(tb_logreg)

    tb_final = hl.Table.union(*tbs)

    # export results
    date = current_date()
    run_hash = str(uuid.uuid4())[:6]
    output_path = f'{args.output_dir}/{date}/{args.exome_cohort}.logreg_burden.{run_hash}.ht'

    tb_final = (tb_final.checkpoint(output=output_path))

    if args.write_to_file:
        # write table to disk as TSV file
        (tb_final.export(f'{output_path}.tsv'))

    hl.stop()
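
Note: `logistic_regression` is a project helper, but Hail's built-in `hl.logistic_regression_rows` covers the core step: regressing case/control status on the aggregated allele count (`mac`) with an intercept and sample covariates. A hedged sketch, assuming `sex` and the PC covariates are numeric column fields:

import hail as hl

def logistic_regression_sketch(mt: hl.MatrixTable, covs: list) -> hl.Table:
    # Wald test of case/control status on the aggregated allele count,
    # with an intercept term plus the supplied sample covariates
    return hl.logistic_regression_rows(
        test='wald',
        y=mt['phe.is_case'],
        x=mt.mac,
        covariates=[1.0] + [mt[c] for c in covs],
        pass_through=['agg_genotype'])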