Example No. 1
def set_female_y_metrics_to_na_expr(
        t: Union[hl.Table, hl.MatrixTable]) -> hl.expr.ArrayExpression:
    """
    Set Y-variant frequency callstats for female-specific metrics to missing structs.

    .. note:: Requires freq, freq_meta, and freq_index_dict annotations to be present in Table or MatrixTable

    :param t: Table or MatrixTable for which to adjust female metrics
    :return: Hail array expression to set female Y-variant metrics to missing values
    """
    female_idx = hl.map(
        lambda x: t.freq_index_dict[x],
        hl.filter(lambda x: x.contains("XX"), t.freq_index_dict.keys()),
    )
    freq_idx_range = hl.range(hl.len(t.freq_meta))

    new_freq_expr = hl.if_else(
        (t.locus.in_y_nonpar() | t.locus.in_y_par()),
        hl.map(
            lambda x: hl.if_else(female_idx.contains(x),
                                 missing_callstats_expr(), t.freq[x]),
            freq_idx_range,
        ),
        t.freq,
    )

    return new_freq_expr
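
The core pattern above — mapping over the index range of `freq` and substituting a missing value wherever the index matches — can be seen in isolation in this toy sketch (all names are illustrative; `import hail as hl` is assumed here and in the sketches below):

import hail as hl

# Null out array elements whose index appears in a given list.
arr = hl.literal([10, 20, 30, 40])
null_idx = hl.literal([1, 3])
hl.eval(hl.map(
    lambda i: hl.if_else(null_idx.contains(i), hl.missing(hl.tint32), arr[i]),
    hl.range(hl.len(arr))))
# [10, None, 30, None]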
Example No. 2
def get_pheno_id(tb):
    pheno_id = (
        tb.trait_type + '-' + tb.phenocode + '-' + tb.pheno_sex +
        hl.if_else(hl.len(tb.coding) > 0, '-' + tb.coding, '') +
        hl.if_else(hl.len(tb.modifier) > 0, '-' + tb.modifier, '')).replace(
            ' ', '_').replace('/', '_')
    return pheno_id
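
A hypothetical call, assuming a table that carries the five string fields referenced above; empty `coding`/`modifier` components are skipped rather than producing a trailing '-':

tb = hl.Table.parallelize(
    [{'trait_type': 'continuous', 'phenocode': '50', 'pheno_sex': 'both_sexes',
      'coding': '', 'modifier': 'irnt'}],
    hl.tstruct(trait_type=hl.tstr, phenocode=hl.tstr, pheno_sex=hl.tstr,
               coding=hl.tstr, modifier=hl.tstr))
tb.annotate(pheno_id=get_pheno_id(tb)).pheno_id.collect()
# ['continuous-50-both_sexes-irnt']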
Example No. 3
    def test_loop(self):
        def triangle_with_ints(n):
            return hl.experimental.loop(
                lambda f, x, c: hl.if_else(x > 0, f(x - 1, c + x), c),
                hl.tint32, n, 0)

        def triangle_with_tuple(n):
            return hl.experimental.loop(
                lambda f, xc: hl.if_else(xc[0] > 0,
                                         f((xc[0] - 1, xc[1] + xc[0])), xc[1]),
                hl.tint32, (n, 0))

        for triangle in [triangle_with_ints, triangle_with_tuple]:
            assert_evals_to(triangle(20), sum(range(21)))
            assert_evals_to(triangle(0), 0)
            assert_evals_to(triangle(-1), 0)

        def fails_typecheck(regex, f):
            with self.assertRaisesRegex(TypeError, regex):
                hl.eval(hl.experimental.loop(f, hl.tint32, 1))

        fails_typecheck("outside of tail position", lambda f, x: x + f(x))
        fails_typecheck("wrong number of arguments", lambda f, x: f(x, x + 1))
        fails_typecheck("bound value", lambda f, x: hl.bind(lambda x: x, f(x)))
        fails_typecheck("branch condition",
                        lambda f, x: hl.if_else(f(x) == 0, x, 1))
        fails_typecheck("Type error",
                        lambda f, x: hl.if_else(x == 0, f("foo"), 1))
Example No. 4
def allele_type(ref, alt):
    return hl.bind(
        lambda at: hl.if_else(
            at == allele_ints['SNP'],
            hl.if_else(hl.is_transition(ref, alt),
                       allele_ints['Transition'],
                       allele_ints['Transversion']),
            at),
        _num_allele_type(ref, alt))
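
`hl.bind` evaluates its argument expression once and passes the result to the lambda, so `_num_allele_type(ref, alt)` is computed a single time even though `at` is referenced in both branches of the conditional. The primitive in isolation:

hl.eval(hl.bind(lambda x: x * x, 3))  # 9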
Example No. 5
def impute_sex_aggregator(call,
                          aaf,
                          aaf_threshold=0.0,
                          include_par=False,
                          female_threshold=0.4,
                          male_threshold=0.8) -> hl.expr.StructExpression:
    """:func:`.impute_sex` as an aggregator."""
    mt = call._indices.source
    rg = mt.locus.dtype.reference_genome
    x_contigs = hl.literal(
        hl.eval(
            hl.map(lambda x_contig: hl.parse_locus_interval(x_contig, rg),
                   rg.x_contigs)))
    inbreeding = hl.agg.inbreeding(call, aaf)
    is_female = hl.if_else(
        inbreeding.f_stat < female_threshold, True,
        hl.if_else(inbreeding.f_stat > male_threshold, False,
                   hl.missing(hl.tbool)))
    expression = hl.struct(is_female=is_female, **inbreeding)
    if not include_par:
        interval_type = hl.tarray(hl.tinterval(hl.tlocus(rg)))
        par_intervals = hl.literal(rg.par, interval_type)
        expression = hl.agg.filter(
            ~par_intervals.any(
                lambda par_interval: par_interval.contains(mt.locus)),
            expression)
    expression = hl.agg.filter(
        (aaf > aaf_threshold) & (aaf < (1 - aaf_threshold)), expression)
    expression = hl.agg.filter(
        x_contigs.any(lambda contig: contig.contains(mt.locus)), expression)

    return expression
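
A hypothetical use, assuming a MatrixTable `mt` with a `GT` call field and a row-indexed alternate allele frequency `mt.AF` (both names illustrative). Since the result is an aggregation expression, it belongs inside `annotate_cols`:

mt = mt.annotate_cols(sex_imputation=impute_sex_aggregator(mt.GT, mt.AF))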
Example No. 6
def merge_alleles(alleles):
    from hail.expr.functions import _num_allele_type, _allele_ints
    return hl.rbind(
        # the merged reference allele is the longest ref among the inputs
        alleles.map(lambda a: hl.or_else(a[0], ''))
        .fold(lambda s, t: hl.if_else(hl.len(s) > hl.len(t), s, t), ''),
        lambda ref:
        hl.rbind(
            alleles.map(
                lambda al: hl.rbind(
                    al[0],
                    lambda r:
                    hl.array([ref]).extend(
                        al[1:].map(
                            lambda a:
                            hl.rbind(
                                _num_allele_type(r, a),
                                lambda at:
                                hl.if_else(
                                    # extend alts with the merged ref's suffix for
                                    # allele types whose representation depends on it
                                    (_allele_ints['SNP'] == at)
                                    | (_allele_ints['Insertion'] == at)
                                    | (_allele_ints['Deletion'] == at)
                                    | (_allele_ints['MNP'] == at)
                                    | (_allele_ints['Complex'] == at),
                                    a + ref[hl.len(r):],
                                    a)))))),
            lambda lal:
            hl.struct(
                # global alleles: merged ref first, then the deduplicated alts
                globl=hl.array([ref]).extend(hl.array(hl.set(hl.flatten(lal)).remove(ref))),
                local=lal)))
Example No. 7
def add_sample_annotations(mt: hl.MatrixTable,
                           annotations: str) -> hl.MatrixTable:
    # use annotations file to annotate VCF
    ann = hl.import_table(annotations,
                          impute=False,
                          types={
                              'Sample': hl.tstr,
                              'Sex': hl.tstr,
                              'Pheno': hl.tstr
                          }).key_by('Sample')
    ann_cols = dict(ann.row)

    mt = mt.annotate_cols(annotations=ann[mt.s])

    if 'is_female' not in mt.col:
        if 'Sex' in ann_cols:
            mt = mt.annotate_cols(is_female=hl.if_else((
                (mt.annotations.Sex == 'F') | (mt.annotations.Sex == str(2))
                | (mt.annotations.Sex == 'True')
                | (mt.annotations.Sex == 'Female')), True, False))
        else:
            print(
                'Sex column is missing from annotations file. Please add it and run GWASpy again'
            )
            sys.exit(2)

    if 'is_case' not in mt.col:
        if 'Pheno' in ann_cols:
            mt = mt.annotate_cols(is_case=hl.if_else((
                (mt.annotations.Pheno == str(2))
                | (mt.annotations.Pheno == 'True')
                | (mt.annotations.Pheno == 'Case')), True, False))

    return mt
Example No. 8
def add_global_af(ht: hl.Table, temp: str) -> hl.Table:
    '''
    Adds gnomAD global AF annotation to Table

    :param Table ht: Input Table
    :param str temp: Path to temp bucket (to store intermediary files)
    :return: Table with gnomAD global AF annotation
    :rtype: Table
    '''
    # checkpoint table after completing both gnomAD exomes and gnomAD genomes join
    temp_path = f'{temp}/join.ht'
    ht = ht.checkpoint(temp_path)

    # set gnomAD ACs and ANs to 0 if they are missing after the join
    ht = ht.transmute(
        gnomad_exomes_AC=hl.if_else(hl.is_defined(ht.gnomad_exomes_AC),
                                    ht.gnomad_exomes_AC, 0),
        gnomad_genomes_AC=hl.if_else(hl.is_defined(ht.gnomad_genomes_AC),
                                     ht.gnomad_genomes_AC, 0),
        gnomad_exomes_AN=hl.if_else(hl.is_defined(ht.gnomad_exomes_AN),
                                    ht.gnomad_exomes_AN, 0),
        gnomad_genomes_AN=hl.if_else(hl.is_defined(ht.gnomad_genomes_AN),
                                     ht.gnomad_genomes_AN, 0),
    )

    ht = ht.annotate(gnomad_global_AF=(
        hl.if_else(((ht.gnomad_exomes_AN == 0)
                    & (ht.gnomad_genomes_AN == 0)), 0.0,
                   hl.float((ht.gnomad_exomes_AC + ht.gnomad_genomes_AC) /
                            (ht.gnomad_exomes_AN + ht.gnomad_genomes_AN)))))
    ht.describe()
    return ht
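
As an aside, the `hl.if_else(hl.is_defined(x), x, 0)` pattern in the transmute above is equivalent to the shorter `hl.or_else(x, 0)`:

hl.eval(hl.or_else(hl.missing(hl.tint32), 0))  # 0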
Example No. 9
def conditional_phenotypes(mt: hl.MatrixTable,
                           column_field,
                           entry_field,
                           lists_of_columns,
                           new_col_name='grouping',
                           new_entry_name='new_entry'):
    """
    Create a conditional phenotype by setting phenotype1 to missing for any individual without phenotype2.

    Pheno1 Pheno2 new_pheno
    T      T      T
    T      F      NA
    F      F      NA
    F      T      F

    `lists_of_columns` should be a list of lists (of length 2 for the inner list).
    The first element corresponds to the phenotype to maintain, except for setting to missing when the
    phenotype coded by the second element is False.

    new_entry = Pheno1 conditioned on having Pheno2

    Example:

    mt = hl.balding_nichols_model(1, 3, 10).drop('GT')
    mt = mt.annotate_entries(pheno=hl.rand_bool(0.5))
    lists_of_columns = [[0, 1], [2, 1]]
    entry_field = mt.pheno
    column_field = mt.sample_idx

    :param MatrixTable mt: Input MatrixTable
    :param Expression column_field: Column-indexed Expression to group by
    :param Expression entry_field: Entry-indexed Expression to which to apply `grouping_function`
    :param list of list lists_of_columns: Entry in this list should be the same type as `column_field`
    :param str new_col_name: Name for new column key (default 'grouping')
    :param str new_entry_name: Name for new entry expression (default 'new_entry')
    :return: Re-grouped MatrixTable
    :rtype: MatrixTable
    """
    assert all([len(x) == 2 for x in lists_of_columns])
    lists_of_columns = hl.literal(lists_of_columns)
    mt = mt._annotate_all(col_exprs={'_col_expr': column_field},
                          entry_exprs={'_entry_expr': entry_field})
    mt = mt.annotate_cols(
        _col_expr=lists_of_columns.filter(lambda x: x.contains(
            mt._col_expr)).map(lambda y: (y, y[0] == mt._col_expr)))
    mt = mt.explode_cols('_col_expr')
    # mt._col_expr[1] is True when this column copy is the primary phenotype;
    # if it is the conditioning phenotype (~mt._col_expr[1]) and its value is
    # False (~mt._entry_expr), return missing; otherwise keep the entry value
    bool_array = hl.agg.collect(
        hl.if_else(~mt._col_expr[1] & ~mt._entry_expr, hl.null(hl.tbool),
                   mt._entry_expr))
    # if either collected value is missing, the result is missing; otherwise AND
    # them (the conditioning value, when present, is True, so this reduces to
    # the primary phenotype's value)
    return mt.group_cols_by(**{
        new_col_name: mt._col_expr[0]
    }).aggregate(
        **{
            new_entry_name:
            hl.if_else(hl.any(lambda x: hl.is_missing(x), bool_array),
                       hl.null(hl.tbool), bool_array[0] & bool_array[1])
        })
Example No. 10
    def test_plot_roc_curve(self):
        x = hl.utils.range_table(100).annotate(score1=hl.rand_norm(),
                                               score2=hl.rand_norm())
        x = x.annotate(tp=hl.if_else(x.score1 > 0, hl.rand_bool(0.7), False),
                       score3=x.score1 + hl.rand_norm())
        ht = x.annotate(fp=hl.if_else(~x.tp, hl.rand_bool(0.2), False))
        _, aucs = hl.experimental.plot_roc_curve(
            ht, ['score1', 'score2', 'score3'])
Example No. 11
def export_ma_format(batch_size=256):
    r'''
    Export columns for .ma format (A1, A2, freq, beta, se, N) for select phenotypes
    '''
    meta_mt0 = hl.read_matrix_table(get_meta_analysis_results_path())
    
    highprev = hl.import_table(f'{ldprune_dir}/joined_ukbb_lancet_age_high_prev.tsv', impute=True)
    highprev = highprev.annotate(pheno = highprev.code.replace('_irnt',''))
    pheno_list = highprev.pheno.collect()
    pheno_list = [p for p in pheno_list if p is not None]
    meta_mt0 = meta_mt0.filter_cols(hl.literal(pheno_list).contains(meta_mt0.pheno))

    meta_mt0 = meta_mt0.annotate_cols(pheno_id = (meta_mt0.trait_type+'-'+
                                      meta_mt0.phenocode+'-'+
                                      meta_mt0.pheno_sex+
                                      hl.if_else(hl.len(meta_mt0.coding)>0, '-'+meta_mt0.coding, '')+
                                      hl.if_else(hl.len(meta_mt0.modifier)>0, '-'+meta_mt0.modifier, '')
                                      ).replace(' ','_').replace('/','_'))
    
    meta_mt0 = meta_mt0.annotate_rows(SNP = meta_mt0.locus.contig+':'+hl.str(meta_mt0.locus.position)+':'+meta_mt0.alleles[0]+':'+meta_mt0.alleles[1],
                                      A1 = meta_mt0.alleles[1], # .ma format requires A1 = effect allele, which in this case is A2 for UKB GWAS
                                      A2 = meta_mt0.alleles[0])

    meta_field_rename_dict = {'BETA':'b',
                          'SE':'se',
                          'Pvalue':'p',
                          'AF_Allele2':'freq',
                          'N':'N'}

    for pop in ['AFR','EUR']: #['AFR','AMR','CSA','EAS','EUR','MID']:
        print(f'not_{pop}')

        req_pop_list = [p for p in POPS if p != pop]

        loo_pop = meta_mt0.annotate_cols(idx = meta_mt0.meta_analysis_data.pop.index(hl.literal(req_pop_list))) # get index of which meta-analysis is the leave-one-out for the current pop
        loo_pop = loo_pop.filter_cols(hl.is_defined(loo_pop.idx))

        annotate_dict = {meta_field_rename_dict[field]: loo_pop.meta_analysis[field][loo_pop.idx] for field in ['AF_Allele2','BETA','SE','Pvalue','N']}
        # apply the renamed meta-analysis fields as entry annotations (loo_pop.b below relies on this)
        loo_pop = loo_pop.annotate_entries(**annotate_dict)

        batch_idx = 1
        export_out = f'{ldprune_dir}/loo/not_{pop}/batch{batch_idx}'
        while hl.hadoop_is_dir(export_out):
            batch_idx += 1
            export_out = f'{ldprune_dir}/loo/not_{pop}/batch{batch_idx}'
        checkpoint_path = f'gs://ukbb-diverse-temp-30day/loo/not_{pop}/batch{batch_idx}.mt'
        # print(f'\nCheckpointing to: {checkpoint_path}\n')
        loo_pop = loo_pop.checkpoint(checkpoint_path,
                                     _read_if_exists=True,
                                     overwrite=True)
        loo_pop = loo_pop.filter_entries(hl.is_defined(loo_pop.b))
        print(f'\nExporting to: {export_out}\n')
        hl.experimental.export_entries_by_col(mt = loo_pop,
                                              path = export_out,
                                              bgzip = True,
                                              batch_size = batch_size,
                                              use_string_key_as_file_name = True,
                                              header_json_in_file = False)
Example No. 12
def parse_as_ranksum(string, has_non_ref):
    typ = hl.ttuple(hl.tfloat64, hl.tint32)
    items = string.split(r'\|')
    items = hl.if_else(has_non_ref, items[:-1], items)
    return items.map(lambda s: hl.if_else(
        (hl.len(s) == 0) | (s == '.'),
        hl.missing(typ),
        hl.rbind(s.split(','), lambda ss: hl.if_else(
            hl.len(ss) != 2,  # bad field, possibly 'NaN', just set it null
            hl.missing(typ),
            hl.tuple([hl.float64(ss[0]), hl.int32(ss[1])])))))
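
A hypothetical call on a pipe-delimited allele-specific RankSum string, with a trailing NON_REF entry to drop:

hl.eval(parse_as_ranksum(hl.str('-0.4,1|.|'), has_non_ref=True))
# [(-0.4, 1), None]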
Example No. 13
def add_rank(
    ht: hl.Table,
    score_expr: hl.expr.NumericExpression,
    subrank_expr: Optional[Dict[str, hl.expr.BooleanExpression]] = None,
) -> hl.Table:
    """
    Add rank based on the `score_expr`. Rank is added for snvs and indels separately.

    If one or more `subrank_expr` are provided, then subrank is added based on all sites for which the boolean expression is true.

    In addition, variant counts (SNVs and indels separately) are added as a global (`rank_variant_counts`).

    :param ht: input Hail Table containing variants (with QC annotations) to be ranked
    :param score_expr: the Table annotation by which ranking should be scored
    :param subrank_expr: Any subranking to be added in the form name_of_subrank: subrank_filtering_expr
    :return: Table with rankings added
    """
    key = ht.key
    if subrank_expr is None:
        subrank_expr = {}

    temp_expr = {"_score": score_expr}
    temp_expr.update({f"_{name}": expr for name, expr in subrank_expr.items()})
    rank_ht = ht.select(**temp_expr,
                        is_snv=hl.is_snp(ht.alleles[0], ht.alleles[1]))

    rank_ht = rank_ht.key_by("_score").persist()
    scan_expr = {
        "rank":
        hl.if_else(
            rank_ht.is_snv,
            hl.scan.count_where(rank_ht.is_snv),
            hl.scan.count_where(~rank_ht.is_snv),
        )
    }
    scan_expr.update({
        name: hl.or_missing(
            rank_ht[f"_{name}"],
            hl.if_else(
                rank_ht.is_snv,
                hl.scan.count_where(rank_ht.is_snv & rank_ht[f"_{name}"]),
                hl.scan.count_where(~rank_ht.is_snv & rank_ht[f"_{name}"]),
            ),
        )
        for name in subrank_expr
    })
    rank_ht = rank_ht.annotate(**scan_expr)

    rank_ht = rank_ht.key_by(*key).persist()
    rank_ht = rank_ht.select(*scan_expr.keys())

    ht = ht.annotate(**rank_ht[key])
    return ht
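
A hypothetical call, ranking by an illustrative `score` annotation with a subrank restricted to biallelic sites (`score` and `was_split` are assumed field names, not part of the function's contract):

ht = add_rank(ht, score_expr=ht.score,
              subrank_expr={'biallelic_rank': ~ht.was_split})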
Example No. 14
            def with_pl(pl):
                new_exprs = {}
                dropped_fields = ['LA']
                if 'LGT' in fields:
                    new_exprs['GT'] = hl.rbind(
                        old_entry.LGT, lambda lgt: hl.if_else(
                            lgt.is_non_ref(),
                            hl.downcode(
                                lgt,
                                hl.or_else(local_a_index, hl.len(old_entry.LA))
                            ), lgt))
                    dropped_fields.append('LGT')
                if 'LPGT' in fields:
                    new_exprs['PGT'] = hl.rbind(
                        old_entry.LPGT, lambda lpgt: hl.if_else(
                            lpgt.is_non_ref(),
                            hl.downcode(
                                lpgt,
                                hl.or_else(local_a_index, hl.len(old_entry.LA))
                            ), lpgt))
                    dropped_fields.append('LPGT')
                if 'LAD' in fields:
                    non_ref_ad = hl.or_else(old_entry.LAD[local_a_index],
                                            0)  # zeroed if not in LAD
                    new_exprs['AD'] = hl.or_missing(
                        hl.is_defined(old_entry.LAD),
                        [hl.sum(old_entry.LAD) - non_ref_ad, non_ref_ad])
                    dropped_fields.append('LAD')
                if 'LPL' in fields:
                    new_exprs['PL'] = pl
                    if 'GQ' in fields:
                        new_exprs['GQ'] = hl.or_else(hl.gq_from_pl(pl),
                                                     old_entry.GQ)

                    dropped_fields.append('LPL')

                return (hl.case()
                        .when(hl.len(ds.alleles) == 1,
                              old_entry.annotate(**{
                                  f[1:]: old_entry[f]
                                  for f in ['LGT', 'LPGT', 'LAD', 'LPL']
                                  if f in fields
                              }).drop(*dropped_fields))
                        .when(hl.or_else(old_entry.LGT.is_hom_ref(), False),
                              old_entry.annotate(**{
                                  f: old_entry[f'L{f}'] if f in ['GT', 'PGT'] else e
                                  for f, e in new_exprs.items()
                              }).drop(*dropped_fields))
                        .default(old_entry.annotate(**new_exprs)
                                 .drop(*dropped_fields)))
Example No. 15
    def f(base):
        # build cond chain bottom-up
        if default is self._base:
            expr = base
        else:
            expr = default
        for value, then in self._cases[::-1]:
            expr = hl.if_else(base == value, then, expr)
        # needs to be on the outside, because upstream missingness would propagate
        if self._when_missing_case is not None:
            expr = hl.if_else(hl.is_missing(base), self._when_missing_case,
                              expr)
        return expr
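
Unrolled for two cases, the chain the loop builds looks like this (a sketch; `base`, `value1`, `then1`, etc. stand in for the builder's stored expressions):

expr = hl.if_else(base == value1, then1,
                  hl.if_else(base == value2, then2, default))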
Example No. 16
def make_pheno_manifest():
    mt0 = load_final_sumstats_mt(filter_sumstats=False,
                                 filter_variants=False,
                                 separate_columns_by_pop=False,
                                 annotate_with_nearest_gene=False)
    ht = mt0.cols()
    annotate_dict = {}

    annotate_dict.update({
        'pops': hl.delimit(ht.pheno_data.pop),
        'num_pops': hl.len(ht.pheno_data.pop)
    })

    for field in ['n_cases', 'n_controls', 'heritability', 'lambda_gc']:
        for pop in ['AFR', 'AMR', 'CSA', 'EAS', 'EUR', 'MID']:
            new_field = field if field != 'heritability' else 'saige_heritability'  # new field name (only applicable to saige heritability)
            idx = ht.pheno_data.pop.index(pop)
            field_expr = ht.pheno_data[field]
            annotate_dict.update({
                f'{new_field}_{pop}':
                hl.if_else(hl.is_missing(idx), hl.null(field_expr[0].dtype),
                           field_expr[idx])
            })
    annotate_dict.update({
        'filename':
        (ht.trait_type + '-' + ht.phenocode + '-' + ht.pheno_sex +
         hl.if_else(hl.len(ht.coding) > 0, '-' + ht.coding, '') +
         hl.if_else(hl.len(ht.modifier) > 0, '-' + ht.modifier, '')).replace(
             ' ', '_').replace('/', '_') + '.tsv.bgz'
    })
    ht = ht.annotate(**annotate_dict)
    aws_bucket = 'https://pan-ukb-us-east-1.s3.amazonaws.com/sumstats_release'
    ht = ht.annotate(aws_link=aws_bucket + '/' + ht.filename,
                     aws_link_tabix=aws_bucket + '_tabix/' + ht.filename +
                     '.tbi')

    other_fields_ht = hl.import_table(
        f'{ldprune_dir}/release/md5_hex_and_file_size.tsv.bgz',
        force_bgz=True,
        key=PHENO_KEY_FIELDS)
    other_fields = [
        'size_in_bytes', 'size_in_bytes_tabix', 'md5_hex', 'md5_hex_tabix'
    ]

    ht = ht.annotate(wget='wget ' + ht.aws_link,
                     wget_tabix='wget ' + ht.aws_link_tabix,
                     **{f: other_fields_ht[ht.key][f]
                        for f in other_fields})

    ht = ht.drop('pheno_data', 'pheno_indices')
    ht.export(f'{bucket}/combined_results/phenotype_manifest.tsv.bgz')
Example No. 17
def remap_samples(
    original_mt_path: str,
    input_mt: hl.MatrixTable,
    pedigree: hl.Table,
    inferred_sex: str,
) -> Tuple[hl.MatrixTable, hl.Table]:
    """
    Rename `s` col in the MatrixTable and inferred sex ht.

    :param original_mt_path: Path to original MatrixTable location
    :param input_mt: MatrixTable 
    :param pedigree: Pedigree file from seqr loaded as a Hail Table
    :param inferred_sex: Path to text file of inferred sexes
    :return: mt and sex ht with sample names remapped
    """
    base_path = "/".join(
        dirname(original_mt_path).split("/")[:-1]) + ("/base/projects")
    project_list = list(set(pedigree.Project_GUID.collect()))

    # Get the list of hts containing sample remapping information for each project
    remap_hts = []
    sex_ht = hl.import_table(inferred_sex)

    for i in project_list:
        remap = f"{base_path}/{i}/{i}_remap.tsv"
        if hl.hadoop_is_file(remap):
            remap_ht = hl.import_table(remap)
            remap_ht = remap_ht.key_by("s", "seqr_id")
            remap_hts.append(remap_ht)

    logger.info("Found %d projects that need to be remapped.", len(remap_hts))

    if len(remap_hts) > 0:
        ht = remap_hts[0]
        for next_ht in remap_hts[1:]:
            ht = ht.join(next_ht, how="outer")

        # If a sample has a non-missing value for seqr_id, rename it to the sample name for the mt and sex ht
        ht = ht.key_by("s")
        input_mt = input_mt.annotate_cols(seqr_id=ht[input_mt.s].seqr_id)
        input_mt = input_mt.key_cols_by(s=hl.if_else(
            hl.is_missing(input_mt.seqr_id), input_mt.s, input_mt.seqr_id))

        sex_ht = sex_ht.annotate(seqr_id=ht[sex_ht.s].seqr_id).key_by("s")
        sex_ht = sex_ht.key_by(s=hl.if_else(hl.is_missing(sex_ht.seqr_id),
                                            sex_ht.s, sex_ht.seqr_id))
    else:
        sex_ht = sex_ht.key_by("s")

    return input_mt, sex_ht
Example No. 18
def hom_alt_depletion_fix(
    mt: hl.MatrixTable,
    het_non_ref_expr: hl.expr.BooleanExpression,
    af_expr: hl.expr.Float64Expression,
    af_cutoff: float = 0.01,
    ab_cutoff: float = 0.9,
) -> hl.MatrixTable:
    """
    Adjust MT genotypes with temporary fix for the depletion of homozygous alternate genotypes.
    
    More details about the problem can be found on the gnomAD blog:
    https://gnomad.broadinstitute.org/blog/2020-10-gnomad-v3-1-new-content-methods-annotations-and-data-availability/#tweaks-and-updates
    
    :param mt: Input MT that needs hom alt genotype fix
    :param het_non_ref_expr: Expression indicating whether the original genotype (pre split multi) is het non ref
    :param af_expr: Allele frequency expression to determine which variants need the hom alt fix
    :param af_cutoff: Allele frequency cutoff for variants that need the hom alt fix. Default is 0.01
    :param ab_cutoff: Allele balance cutoff to determine which genotypes need the hom alt fix. Default is 0.9
    :return: MatrixTable with genotypes adjusted for the hom alt depletion fix
    """
    return mt.annotate_entries(GT=hl.if_else(
        mt.GT.is_het()
        # Skip adjusting genotypes if sample originally had a het nonref genotype
        & ~het_non_ref_expr
        & (af_expr > af_cutoff)
        & (mt.AD[1] / mt.DP > ab_cutoff),
        hl.call(1, 1),
        mt.GT,
    ))
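
A hypothetical call, assuming the pre-split het-non-ref status was stored as an entry flag and allele frequency as a row annotation (both names illustrative):

mt = hom_alt_depletion_fix(mt,
                           het_non_ref_expr=mt._het_non_ref,
                           af_expr=mt.AF)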
Example No. 19
def annotate_variants_with_mnvs(variants_path, mnvs_path):
    ds = hl.read_table(mnvs_path)

    ds = ds.select("changes_amino_acids_for_snvs", "constituent_snvs", "constituent_snv_ids", "n_individuals",)

    ds = ds.explode(ds.constituent_snvs, "snv")
    ds = ds.annotate(
        locus=hl.locus(ds.snv.chrom, ds.snv.pos, reference_genome="GRCh37"), alleles=[ds.snv.ref, ds.snv.alt]
    )
    ds = ds.group_by(ds.locus, ds.alleles).aggregate(multi_nucleotide_variants=hl.agg.collect(ds.row.drop("snv")))

    variants = hl.read_table(variants_path)

    variants = variants.annotate(multi_nucleotide_variants=ds[variants.key].multi_nucleotide_variants)
    variants = variants.annotate(
        flags=hl.if_else(
            hl.len(variants.multi_nucleotide_variants) > 0,
            variants.flags.add("mnv"),
            variants.flags,
            missing_false=True,
        ),
        multi_nucleotide_variants=variants.multi_nucleotide_variants.map(
            lambda mnv: mnv.select(
                combined_variant_id=mnv.variant_id,
                changes_amino_acids=mnv.changes_amino_acids_for_snvs.contains(variants.variant_id),
                n_individuals=mnv.n_individuals,
                other_constituent_snvs=mnv.constituent_snv_ids.filter(lambda snv_id: snv_id != variants.variant_id),
            )
        ),
    )

    return variants
Example No. 20
def apply_mito_artifact_filter(
    mt: hl.MatrixTable,
    artifact_prone_sites_path: str,
) -> hl.MatrixTable:
    """Add back in artifact_prone_site filter

    :param hl.MatrixTable mt: MatrixTable to use as input
    :param str artifact_prone_sites_path: path to BED file of artifact_prone_sites to flag in the filters column

    :return: MatrixTable with artifact_prone_sites filter
    :rtype: hl.MatrixTable

    """

    # apply "artifact_prone_site" filter to any SNP or deletion that spans a known problematic site
    mt = mt.annotate_rows(
        position_range=hl.range(mt.locus.position, mt.locus.position +
                                hl.len(mt.alleles[0])))

    artifact_sites = []
    with hl.hadoop_open(artifact_prone_sites_path) as f:
        for line in f:
            pos = line.split()[2]
            artifact_sites.append(int(pos))
    sites = hl.literal(set(artifact_sites))

    mt = mt.annotate_rows(filters=hl.if_else(
        hl.len(hl.set(mt.position_range).intersection(sites)) > 0,
        {"artifact_prone_site"},
        {"PASS"},
    ))

    mt = mt.drop("position_range")

    return mt
Example No. 21
def remove_FT_values(
    mt: hl.MatrixTable,
    filters_to_remove: list = [
        'possible_numt', 'mt_many_low_hets', 'FAIL', 'blacklisted_site'
    ]
) -> hl.MatrixTable:
    """Removes the FT filters specified in filters_to_remove
    
    By default, this function removes the 'possible_numt', 'mt_many_low_hets', and 'FAIL' filters (because these filters were found to have low performance), 
    and the 'blacklisted_site' filter because this filter did not always behave as expected in early GATK versions (can be replaced with apply_mito_artifact_filter function)

    :param hl.MatrixTable mt:  MatrixTable
    :param list filters_to_remove: list of FT filters that should be removed from the entries
    
    :return: MatrixTable with certain FT filters removed
    :rtype: MatrixTable
    """

    filters_to_remove = hl.set(filters_to_remove)
    mt = mt.annotate_entries(
        FT=hl.array((mt.FT).difference(filters_to_remove)))

    # if no filters exists after removing those specified above, set the FT field to PASS
    mt = mt.annotate_entries(
        FT=hl.if_else(hl.len(mt.FT) == 0, ["PASS"], mt.FT))

    return mt
Example No. 22
    def _get_most_severe_csq(csq_list: hl.expr.ArrayExpression,
                             protein_coding: bool) -> hl.expr.StructExpression:
        """
        Processes VEP consequences to generate summary annotations.

        :param csq_list: VEP consequences list to be processed.
        :param protein_coding: Whether variant is in a protein-coding transcript.
        :return: Struct containing summary annotations.
        """
        lof = hl.null(hl.tstr)
        no_lof_flags = hl.null(hl.tbool)
        if protein_coding:
            all_lofs = csq_list.map(lambda x: x.lof)
            lof = hl.literal(loftee_labels).find(
                lambda x: all_lofs.contains(x))
            csq_list = hl.if_else(hl.is_defined(lof),
                                  csq_list.filter(lambda x: x.lof == lof),
                                  csq_list)
            no_lof_flags = hl.or_missing(
                hl.is_defined(lof),
                csq_list.any(lambda x:
                             (x.lof == lof) & hl.is_missing(x.lof_flags)),
            )
        all_csq_terms = csq_list.flatmap(lambda x: x.consequence_terms)
        most_severe_csq = hl.literal(csq_order).find(
            lambda x: all_csq_terms.contains(x))
        return hl.struct(
            most_severe_csq=most_severe_csq,
            protein_coding=protein_coding,
            lof=lof,
            no_lof_flags=no_lof_flags,
        )
Example No. 23
def subset_mt(project_guid,
              mt,
              skip_sample_subset=False,
              ignore_missing_samples=False):
    if not skip_sample_subset:
        sample_subset = get_sample_subset(project_guid, WGS_SAMPLE_TYPE)
        found_samples = sample_subset.intersection(
            mt.aggregate_cols(hl.agg.collect_as_set(mt.s)))
        if len(found_samples) != len(sample_subset):
            missed_samples = sample_subset - found_samples
            missing_sample_message = 'Missing the following {} samples:\n{}'.format(
                len(missed_samples), ', '.join(sorted(missed_samples)))
            if ignore_missing_samples:
                logger.info(missing_sample_message)
            else:
                logger.error(missing_sample_message)
                raise Exception(missing_sample_message)

        sample_remap = get_sample_remap(project_guid, WGS_SAMPLE_TYPE)
        message = 'Subsetting to {} samples'.format(len(sample_subset))
        if sample_remap:
            message += ' (remapping {} samples)'.format(len(sample_remap))
            sample_subset = sample_subset - set(sample_remap.keys())
            sample_subset.update(set(sample_remap.values()))
            mt = mt.key_cols_by()
            sample_remap = hl.literal(sample_remap)
            mt = mt.annotate_cols(s=hl.if_else(sample_remap.contains(mt.s),
                                               sample_remap[mt.s], mt.s))
        logger.info(message)

        mt = mt.filter_cols(hl.literal(sample_subset).contains(mt.s))

    return mt.filter_rows(hl.agg.any(mt.GT.is_non_ref()))
Example No. 24
def compute_prs_mt(genotype_mt_path, prs_mt_path):
    scratch_dir = 'gs://ukbb-diverse-temp-30day/nb-scratch'

    clumped = hl.read_table(
        'gs://ukb-diverse-pops/ld_prune/results_high_quality/not_AMR/phecode-250.2-both_sexes/clump_results.ht/'
    )
    sumstats = hl.import_table(
        'gs://ukb-diverse-pops/sumstats_flat_files/phecode-250.2-both_sexes.tsv.bgz',
        impute=True)
    sumstats = sumstats.annotate(locus=hl.locus(sumstats.chr, sumstats.pos),
                                 alleles=hl.array([sumstats.ref,
                                                   sumstats.alt]))
    sumstats = sumstats.key_by('locus', 'alleles')
    sumstats.describe()
    #    mt = hl.read_matrix_table(genotype_mt_path) # read genotype mt subset

    # get full genotype mt
    meta_mt = hl.read_matrix_table(get_meta_analysis_results_path())
    mt = get_filtered_mt_with_x()
    mt = mt.filter_rows(hl.is_defined(meta_mt.rows()[mt.row_key]))
    mt = mt.select_entries('dosage')
    mt = mt.select_rows()
    mt = mt.select_cols()

    mt = mt.annotate_rows(beta=hl.if_else(hl.is_defined(clumped[mt.row_key]),
                                          sumstats[mt.row_key].beta_meta, 0))
    mt = mt.annotate_cols(score=hl.agg.sum(mt.beta * mt.dosage))
    mt_cols = mt.cols()
    mt_cols = mt_cols.repartition(1000)
    mt_cols.write(f'{scratch_dir}/prs_all_samples.ht')
Example No. 25
def _spectral_moments(A,
                      num_moments,
                      p=None,
                      moment_samples=500,
                      block_size=128):
    if not isinstance(A, TallSkinnyMatrix):
        check_entry_indexed('_spectral_moments/entry_expr', A)
        A = _make_tsm_from_call(A, block_size)

    n = A.ncols

    if p is None:
        p = min(num_moments // 2, 10)

    # TODO: When moment_samples > n, we should just do a TSQR on A, and compute
    # the spectrum of R.
    assert moment_samples < n, '_spectral_moments: moment_samples must be smaller than num cols of A'
    G = hl.nd.zeros((n, moment_samples)).map(
        lambda _: hl.if_else(hl.rand_bool(0.5), -1, 1))
    Q1, R1 = hl.nd.qr(G)._persist()
    fact = _krylov_factorization(A, Q1, p, compute_U=False)
    moments_and_stdevs = hl.eval(fact.spectral_moments(num_moments, R1))
    moments = moments_and_stdevs.moments
    stdevs = moments_and_stdevs.stdevs
    return moments, stdevs
Example No. 26
def pre_process_subset_freq(subset: str,
                            global_ht: hl.Table,
                            test: bool = False) -> hl.Table:
    """
    Prepare subset frequency Table by filling in missing frequency fields for loci present only in the global cohort.

    .. note::

        The resulting final `freq` array will be as long as the subset `freq_meta` global (i.e., one `freq` entry for each `freq_meta` entry)

    :param subset: subset ID
    :param global_ht: Hail Table containing all variants discovered in the overall release cohort
    :param test: If True, filter to small region on chr20
    :return: Table containing subset frequencies with missing freq structs filled in
    """

    # Read in subset HTs
    subset_ht_path = get_freq(subset=subset).path
    subset_chr20_ht_path = qc_temp_prefix() + f"chr20_test_freq.{subset}.ht"

    if test:
        if file_exists(subset_chr20_ht_path):
            logger.info(
                "Loading chr20 %s subset frequency data for testing: %s",
                subset,
                subset_chr20_ht_path,
            )
            subset_ht = hl.read_table(subset_chr20_ht_path)

        elif file_exists(subset_ht_path):
            logger.info(
                "Loading %s subset frequency data for testing: %s",
                subset,
                subset_ht_path,
            )
            subset_ht = hl.read_table(subset_ht_path)
            subset_ht = hl.filter_intervals(
                subset_ht, [hl.parse_locus_interval("chr20:1-1000000")])

    elif file_exists(subset_ht_path):
        logger.info("Loading %s subset frequency data: %s", subset,
                    subset_ht_path)
        subset_ht = hl.read_table(subset_ht_path)

    else:
        raise DataException(
            f"Hail Table containing {subset} subset frequencies not found. You may need to run the script generate_freq_data.py to generate frequency annotations first."
        )

    # Fill in missing freq structs
    ht = subset_ht.join(global_ht.select().select_globals(), how="right")
    ht = ht.annotate(freq=hl.if_else(
        hl.is_missing(ht.freq),
        hl.map(lambda x: missing_callstats_expr(),
               hl.range(hl.len(ht.freq_meta))),
        ht.freq,
    ))

    return ht
Example No. 27
def merge_overlapping_regions(regions):
    return hl.if_else(
        hl.len(regions) > 1,
        hl.rbind(
            hl.sorted(regions, lambda region: region.start),
            lambda sorted_regions: sorted_regions[1:].fold(
                lambda acc, region: hl.if_else(
                    region.start <= acc[-1].stop + 1,
                    acc[:-1].append(acc[-1].annotate(stop=hl.max(
                        region.stop, acc[-1].stop))),
                    acc.append(region),
                ),
                [sorted_regions[0]],
            ),
        ),
        regions,
    )
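
A toy evaluation under assumed inputs: regions as structs with 1-based inclusive integer `start`/`stop` fields (the field names the function body relies on). Note that directly adjacent intervals also merge, because of the `stop + 1` check:

regions = hl.array([
    hl.struct(start=10, stop=20),
    hl.struct(start=15, stop=30),   # overlaps the first region
    hl.struct(start=31, stop=35),   # adjacent (stop + 1), merges too
    hl.struct(start=50, stop=60),
])
hl.eval(merge_overlapping_regions(regions))
# [struct(start=10, stop=35), struct(start=50, stop=60)]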
Example No. 28
def split_position_end(position):
    return hl.or_missing(
        hl.is_defined(position),
        hl.bind(
            lambda start: hl.if_else(start == "?", hl.null(hl.tint),
                                     hl.int(start)),
            position.split("-")[-1]),
    )
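
Hypothetical calls, assuming HGVS-style position strings where the end coordinate may be unknown ('?'):

hl.eval(split_position_end(hl.str('100-200')))    # 200
hl.eval(split_position_end(hl.str('100-?')))      # None
hl.eval(split_position_end(hl.missing(hl.tstr)))  # None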
Example No. 29
def export_loo(batch_size=256):
    r'''
    For exporting p-values of meta-analysis of leave-one-out population sets
    '''
    meta_mt0 = hl.read_matrix_table(get_meta_analysis_results_path())

    meta_mt0 = meta_mt0.filter_cols(hl.len(meta_mt0.pheno_data.pop) == 6)

    meta_mt0 = meta_mt0.annotate_cols(pheno_id=(
        meta_mt0.trait_type + '-' + meta_mt0.phenocode + '-' +
        meta_mt0.pheno_sex +
        hl.if_else(hl.len(meta_mt0.coding) > 0, '-' + meta_mt0.coding, '') +
        hl.if_else(hl.len(meta_mt0.modifier) > 0, '-' +
                   meta_mt0.modifier, '')).replace(' ', '_').replace('/', '_'))

    meta_mt0 = meta_mt0.annotate_rows(
        SNP=(meta_mt0.locus.contig + ':' + hl.str(meta_mt0.locus.position) +
             ':' + meta_mt0.alleles[0] + ':' + meta_mt0.alleles[1]))

    all_pops = sorted(['AFR', 'AMR', 'CSA', 'EAS', 'EUR', 'MID'])

    annotate_dict = {}
    # pop_idx corresponds to the alphabetical ordering of the pops
    # (the entry with idx=0 is the 6-pop meta-analysis)
    for pop_idx, pop in enumerate(all_pops, 1):
        annotate_dict.update(
            {f'pval_not_{pop}': meta_mt0.meta_analysis.Pvalue[pop_idx]})
    meta_mt1 = meta_mt0.annotate_entries(**annotate_dict)

    meta_mt1 = meta_mt1.key_cols_by('pheno_id')
    meta_mt1 = meta_mt1.key_rows_by().drop('locus', 'alleles', 'gene',
                                           'annotation', 'meta_analysis')

    meta_mt1.describe()

    batch_idx = 1
    get_export_path = lambda batch_idx: f'{ldprune_dir}/loo/sumstats/batch{batch_idx}'
    while hl.hadoop_is_dir(get_export_path(batch_idx)):
        batch_idx += 1
    print(f'\nExporting to: {get_export_path(batch_idx)}\n')
    hl.experimental.export_entries_by_col(mt=meta_mt1,
                                          path=get_export_path(batch_idx),
                                          bgzip=True,
                                          batch_size=batch_size,
                                          use_string_key_as_file_name=True,
                                          header_json_in_file=False)
Example No. 30
def make_pheno_manifest(export=True):
    mt0 = load_final_sumstats_mt(filter_sumstats=False,
                                 filter_variants=False,
                                 separate_columns_by_pop=False,
                                 annotate_with_nearest_gene=False)

    ht = mt0.cols()
    annotate_dict = {}

    annotate_dict.update({
        'pops': hl.delimit(ht.pheno_data.pop),
        'num_pops': hl.len(ht.pheno_data.pop)
    })

    for field in ['n_cases', 'n_controls', 'heritability', 'lambda_gc']:
        for pop in ['AFR', 'AMR', 'CSA', 'EAS', 'EUR', 'MID']:
            new_field = field if field != 'heritability' else 'saige_heritability'  # new field name (only applicable to saige heritability)
            idx = ht.pheno_data.pop.index(pop)
            field_expr = ht.pheno_data[field]
            annotate_dict.update({
                f'{new_field}_{pop}':
                hl.if_else(hl.is_missing(idx), hl.null(field_expr[0].dtype),
                           field_expr[idx])
            })
    annotate_dict.update({'filename': get_pheno_id(tb=ht) + '.tsv.bgz'})
    ht = ht.annotate(**annotate_dict)

    dropbox_manifest = hl.import_table(
        f'{ldprune_dir}/UKBB_Pan_Populations-Manifest_20200615-manifest_info.tsv',
        impute=True,
        key='File')
    dropbox_manifest = dropbox_manifest.filter(
        dropbox_manifest['is_old_file'] != '1')
    bgz = dropbox_manifest.filter(~dropbox_manifest.File.contains('.tbi'))
    bgz = bgz.rename({'File': 'filename'})
    tbi = dropbox_manifest.filter(dropbox_manifest.File.contains('.tbi'))
    tbi = tbi.annotate(
        filename=tbi.File.replace('.tbi', '')).key_by('filename')

    dropbox_annotate_dict = {}

    rename_dict = {
        'dbox link': 'dropbox_link',
        'size (bytes)': 'size_in_bytes'
    }

    dropbox_annotate_dict.update({'filename_tabix': tbi[ht.filename].File})
    for field in ['dbox link', 'wget', 'size (bytes)', 'md5 hex']:
        for tb, suffix in [(bgz, ''), (tbi, '_tabix')]:
            dropbox_annotate_dict.update({
                (rename_dict[field] if field in rename_dict else field.replace(
                     ' ', '_')) + suffix:
                tb[ht.filename][field]
            })
    ht = ht.annotate(**dropbox_annotate_dict)
    ht = ht.drop('pheno_data')
    ht.describe()
    ht.show()