Example #1
File: test_qc.py, Project: vedasha/hail
    def test_concordance(self):
        dataset = get_dataset()
        glob_conc, cols_conc, rows_conc = hl.concordance(dataset, dataset)

        self.assertEqual(sum([sum(glob_conc[i]) for i in range(5)]), dataset.count_rows() * dataset.count_cols())

        counts = dataset.aggregate_entries(hl.Struct(n_het=agg.filter(dataset.GT.is_het(), agg.count()),
                                                     n_hom_ref=agg.filter(dataset.GT.is_hom_ref(),
                                                                          agg.count()),
                                                     n_hom_var=agg.filter(dataset.GT.is_hom_var(),
                                                                          agg.count()),
                                                     nNoCall=agg.filter(hl.is_missing(dataset.GT),
                                                                        agg.count())))

        self.assertEqual(glob_conc[0][0], 0)
        self.assertEqual(glob_conc[1][1], counts.nNoCall)
        self.assertEqual(glob_conc[2][2], counts.n_hom_ref)
        self.assertEqual(glob_conc[3][3], counts.n_het)
        self.assertEqual(glob_conc[4][4], counts.n_hom_var)
        for i in range(5):
            for j in range(5):
                if i != j:
                    self.assertEqual(glob_conc[i][j], 0)

        self.assertTrue(cols_conc.all(hl.sum(hl.flatten(cols_conc.concordance)) == dataset.count_rows()))
        self.assertTrue(rows_conc.all(hl.sum(hl.flatten(rows_conc.concordance)) == dataset.count_cols()))

        cols_conc.write('/tmp/foo.kt', overwrite=True)
        rows_conc.write('/tmp/foo.kt', overwrite=True)
Example #2
File: vcf_combiner.py, Project: bcajes/hail
def combine(ts):
    # pylint: disable=protected-access
    tmp = ts.annotate(
        alleles=merge_alleles(ts.data.map(lambda d: d.alleles)),
        rsid=hl.find(hl.is_defined, ts.data.map(lambda d: d.rsid)),
        filters=hl.set(hl.flatten(ts.data.map(lambda d: hl.array(d.filters)))),
        info=hl.struct(
            DP=hl.sum(ts.data.map(lambda d: d.info.DP)),
            MQ_DP=hl.sum(ts.data.map(lambda d: d.info.MQ_DP)),
            QUALapprox=hl.sum(ts.data.map(lambda d: d.info.QUALapprox)),
            RAW_MQ=hl.sum(ts.data.map(lambda d: d.info.RAW_MQ)),
            VarDP=hl.sum(ts.data.map(lambda d: d.info.VarDP)),
            SB=hl.array([
                hl.sum(ts.data.map(lambda d: d.info.SB[0])),
                hl.sum(ts.data.map(lambda d: d.info.SB[1])),
                hl.sum(ts.data.map(lambda d: d.info.SB[2])),
                hl.sum(ts.data.map(lambda d: d.info.SB[3]))
            ])))
    tmp = tmp.annotate(
        __entries=hl.bind(
            lambda combined_allele_index:
            hl.range(0, hl.len(tmp.data)).flatmap(
                lambda i:
                hl.cond(hl.is_missing(tmp.data[i].__entries),
                        hl.range(0, hl.len(tmp.g[i].__cols))
                          .map(lambda _: hl.null(tmp.data[i].__entries.dtype.element_type)),
                        hl.bind(
                            lambda old_to_new: tmp.data[i].__entries.map(lambda e: renumber_entry(e, old_to_new)),
                            hl.range(0, hl.len(tmp.data[i].alleles)).map(
                                lambda j: combined_allele_index[tmp.data[i].alleles[j]])))),
            hl.dict(hl.range(0, hl.len(tmp.alleles)).map(
                lambda j: hl.tuple([tmp.alleles[j], j])))))
    tmp = tmp.annotate_globals(__cols=hl.flatten(tmp.g.map(lambda g: g.__cols)))

    return tmp.drop('data', 'g')
Example #3
File: vcf_combiner.py, Project: jigold/hail
def merge_alleles(alleles):
    from hail.expr.functions import _num_allele_type, _allele_ints
    return hl.rbind(
        alleles.map(lambda a: hl.or_else(a[0], ''))
               .fold(lambda s, t: hl.cond(hl.len(s) > hl.len(t), s, t), ''),
        lambda ref:
        hl.rbind(
            alleles.map(
                lambda al: hl.rbind(
                    al[0],
                    lambda r:
                    hl.array([ref]).extend(
                        al[1:].map(
                            lambda a:
                            hl.rbind(
                                _num_allele_type(r, a),
                                lambda at:
                                hl.cond(
                                    (_allele_ints['SNP'] == at) |
                                    (_allele_ints['Insertion'] == at) |
                                    (_allele_ints['Deletion'] == at) |
                                    (_allele_ints['MNP'] == at) |
                                    (_allele_ints['Complex'] == at),
                                    a + ref[hl.len(r):],
                                    a)))))),
            lambda lal:
            hl.struct(
                globl=hl.array([ref]).extend(hl.array(hl.set(hl.flatten(lal)).remove(ref))),
                local=lal)))
Example #4
def sor_from_sb(
    sb: Union[hl.expr.ArrayNumericExpression, hl.expr.ArrayExpression]
) -> hl.expr.Float64Expression:
    """
    Computes `SOR` (Symmetric Odds Ratio test) annotation from the `SB` (strand balance table) field.

    .. note::

        This function can either take
        - an array of length four containing the forward and reverse strands' counts of ref and alt alleles: [ref fwd, ref rev, alt fwd, alt rev]
        - a two dimensional array with arrays of length two, containing the counts: [[ref fwd, ref rev], [alt fwd, alt rev]]

    GATK code here: https://github.com/broadinstitute/gatk/blob/master/src/main/java/org/broadinstitute/hellbender/tools/walkers/annotator/StrandOddsRatio.java

    :param sb: Count of ref/alt reads on each strand
    :return: SOR value
    """

    if not isinstance(sb, hl.expr.ArrayNumericExpression):
        sb = hl.bind(lambda x: hl.flatten(x), sb)

    sb = sb.map(lambda x: hl.float64(x) + 1)

    ref_fw = sb[0]
    ref_rv = sb[1]
    alt_fw = sb[2]
    alt_rv = sb[3]
    symmetrical_ratio = ((ref_fw * alt_rv) / (alt_fw * ref_rv)) + (
        (alt_fw * ref_rv) / (ref_fw * alt_rv)
    )
    ref_ratio = hl.min(ref_rv, ref_fw) / hl.max(ref_rv, ref_fw)
    alt_ratio = hl.min(alt_fw, alt_rv) / hl.max(alt_fw, alt_rv)
    sor = hl.log(symmetrical_ratio) + hl.log(ref_ratio) - hl.log(alt_ratio)

    return sor
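
A minimal usage sketch (names are hypothetical; assumes a MatrixTable whose info.SB field carries the four strand counts in [ref fwd, ref rev, alt fwd, alt rev] order, or the nested 2x2 form):

import hail as hl

mt = hl.read_matrix_table('data/my_dataset.mt')  # placeholder path
mt = mt.annotate_rows(SOR=sor_from_sb(mt.info.SB))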
Example #5
def combine(ts):
    # pylint: disable=protected-access
    tmp = ts.annotate(
        alleles=merge_alleles(ts.data.map(lambda d: d.alleles)),
        rsid=hl.find(hl.is_defined, ts.data.map(lambda d: d.rsid)),
        info=hl.struct(
            MQ_DP=hl.sum(ts.data.map(lambda d: d.info.MQ_DP)),
            QUALapprox=hl.sum(ts.data.map(lambda d: d.info.QUALapprox)),
            RAW_MQ=hl.sum(ts.data.map(lambda d: d.info.RAW_MQ)),
            VarDP=hl.sum(ts.data.map(lambda d: d.info.VarDP)),
            SB_TABLE=hl.array([
                hl.sum(ts.data.map(lambda d: d.info.SB_TABLE[0])),
                hl.sum(ts.data.map(lambda d: d.info.SB_TABLE[1])),
                hl.sum(ts.data.map(lambda d: d.info.SB_TABLE[2])),
                hl.sum(ts.data.map(lambda d: d.info.SB_TABLE[3]))
            ])))
    tmp = tmp.annotate(
        __entries=hl.bind(
            lambda combined_allele_index:
            hl.range(0, hl.len(tmp.data)).flatmap(
                lambda i:
                hl.cond(hl.is_missing(tmp.data[i].__entries),
                        hl.range(0, hl.len(tmp.g[i].__cols))
                          .map(lambda _: hl.null(tmp.data[i].__entries.dtype.element_type)),
                        hl.bind(
                            lambda old_to_new: tmp.data[i].__entries.map(lambda e: renumber_entry(e, old_to_new)),
                            hl.array([0]).extend(
                                hl.range(0, hl.len(tmp.data[i].alleles)).map(
                                    lambda j: combined_allele_index[tmp.data[i].alleles[j]]))))),
            hl.dict(hl.range(1, hl.len(tmp.alleles) + 1).map(
                lambda j: hl.tuple([tmp.alleles[j - 1], j])))))
    tmp = tmp.annotate_globals(__cols=hl.flatten(tmp.g.map(lambda g: g.__cols)))

    return tmp.drop('data', 'g')
Example #6
def combine(ts):
    def merge_alleles(alleles):
        from hail.expr.functions import _num_allele_type, _allele_ints
        return hl.rbind(
            alleles.map(lambda a: hl.or_else(a[0], ''))
                   .fold(lambda s, t: hl.cond(hl.len(s) > hl.len(t), s, t), ''),
            lambda ref:
            hl.rbind(
                alleles.map(
                    lambda al: hl.rbind(
                        al[0],
                        lambda r:
                        hl.array([ref]).extend(
                            al[1:].map(
                                lambda a:
                                hl.rbind(
                                    _num_allele_type(r, a),
                                    lambda at:
                                    hl.cond(
                                        (_allele_ints['SNP'] == at) |
                                        (_allele_ints['Insertion'] == at) |
                                        (_allele_ints['Deletion'] == at) |
                                        (_allele_ints['MNP'] == at) |
                                        (_allele_ints['Complex'] == at),
                                        a + ref[hl.len(r):],
                                        a)))))),
                lambda lal:
                hl.struct(
                    globl=hl.array([ref]).extend(hl.array(hl.set(hl.flatten(lal)).remove(ref))),
                    local=lal)))

    def renumber_entry(entry, old_to_new) -> StructExpression:
        # global index of alternate (non-ref) alleles
        return entry.annotate(LA=entry.LA.map(lambda lak: old_to_new[lak]))

    if (ts.row.dtype, ts.globals.dtype) not in _merge_function_map:
        f = hl.experimental.define_function(
            lambda row, gbl:
            hl.rbind(
                merge_alleles(row.data.map(lambda d: d.alleles)),
                lambda alleles:
                hl.struct(
                    locus=row.locus,
                    alleles=alleles.globl,
                    rsid=hl.find(hl.is_defined, row.data.map(lambda d: d.rsid)),
                    __entries=hl.bind(
                        lambda combined_allele_index:
                        hl.range(0, hl.len(row.data)).flatmap(
                            lambda i:
                            hl.cond(hl.is_missing(row.data[i].__entries),
                                    hl.range(0, hl.len(gbl.g[i].__cols))
                                      .map(lambda _: hl.null(row.data[i].__entries.dtype.element_type)),
                                    hl.bind(
                                        lambda old_to_new: row.data[i].__entries.map(
                                            lambda e: renumber_entry(e, old_to_new)),
                                        hl.range(0, hl.len(alleles.local[i])).map(
                                            lambda j: combined_allele_index[alleles.local[i][j]])))),
                        hl.dict(hl.range(0, hl.len(alleles.globl)).map(
                            lambda j: hl.tuple([alleles.globl[j], j])))))),
            ts.row.dtype, ts.globals.dtype)
        _merge_function_map[(ts.row.dtype, ts.globals.dtype)] = f
    merge_function = _merge_function_map[(ts.row.dtype, ts.globals.dtype)]
    ts = Table(
        TableMapRows(
            ts._tir,
            Apply(merge_function._name, merge_function._ret_type,
                  TopLevelReference('row'), TopLevelReference('global'))))
    return ts.transmute_globals(
        __cols=hl.flatten(ts.g.map(lambda g: g.__cols)))
Example #7
def multi_way_union_mts(mts: list, tmp_dir: str,
                        chunk_size: int) -> hl.MatrixTable:
    """Joins MatrixTables in the provided list

    :param list mts: list of MatrixTables to join together
    :param str tmp_dir: path to temporary directory for intermediate results
    :param int chunk_size: number of MatrixTables to join per chunk

    :return: joined MatrixTable
    :rtype: MatrixTable
    """

    staging = [mt.localize_entries("__entries", "__cols") for mt in mts]
    stage = 0
    while len(staging) > 1:
        n_jobs = int(math.ceil(len(staging) / chunk_size))
        info(f"multi_way_union_mts: stage {stage}: {n_jobs} total jobs")
        next_stage = []
        for i in range(n_jobs):
            to_merge = staging[chunk_size * i:chunk_size * (i + 1)]
            info(
                f"multi_way_union_mts: stage {stage} / job {i}: merging {len(to_merge)} inputs"
            )

            merged = hl.Table.multi_way_zip_join(to_merge, "__entries",
                                                 "__cols")
            merged = merged.annotate(__entries=hl.flatten(
                hl.range(hl.len(merged.__entries)).map(
                    lambda i: hl.coalesce(
                        merged.__entries[i].__entries,
                        hl.range(hl.len(merged.__cols[i].__cols)).map(
                            lambda j: hl.null(
                                merged.__entries.__entries.dtype.element_type.element_type))))))
            merged = merged.annotate_globals(
                __cols=hl.flatten(merged.__cols.map(lambda x: x.__cols)))

            next_stage.append(
                merged.checkpoint(os.path.join(tmp_dir,
                                               f"stage_{stage}_job_{i}.ht"),
                                  overwrite=True))
        info(f"done stage {stage}")
        stage += 1
        staging.clear()
        staging.extend(next_stage)

    return (staging[0]._unlocalize_entries(
        "__entries", "__cols", list(mts[0].col_key)).unfilter_entries())
Example #8
def combine_r(ts):
    if (ts.row.dtype, ts.globals.dtype) not in _merge_function_map:
        f = hl.experimental.define_function(
            lambda row, gbl:
            hl.struct(
                locus=row.locus,
                ref_allele=hl.find(hl.is_defined, row.data.map(lambda d: d.ref_allele)),
                __entries=hl.range(0, hl.len(row.data)).flatmap(
                    lambda i:
                    hl.if_else(hl.is_missing(row.data[i]),
                               hl.range(0, hl.len(gbl.g[i].__cols))
                               .map(lambda _: hl.missing(row.data[i].__entries.dtype.element_type)),
                               row.data[i].__entries))),
            ts.row.dtype, ts.globals.dtype)
        _merge_function_map[(ts.row.dtype, ts.globals.dtype)] = f
    merge_function = _merge_function_map[(ts.row.dtype, ts.globals.dtype)]
    ts = Table(TableMapRows(ts._tir, Apply(merge_function._name,
                                           merge_function._ret_type,
                                           TopLevelReference('row'),
                                           TopLevelReference('global'))))
    return ts.transmute_globals(__cols=hl.flatten(ts.g.map(lambda g: g.__cols)))
Example #9
File: vcf_combiner.py, Project: jigold/hail
def combine(ts):
    def merge_alleles(alleles):
        from hail.expr.functions import _num_allele_type, _allele_ints
        return hl.rbind(
            alleles.map(lambda a: hl.or_else(a[0], ''))
                   .fold(lambda s, t: hl.cond(hl.len(s) > hl.len(t), s, t), ''),
            lambda ref:
            hl.rbind(
                alleles.map(
                    lambda al: hl.rbind(
                        al[0],
                        lambda r:
                        hl.array([ref]).extend(
                            al[1:].map(
                                lambda a:
                                hl.rbind(
                                    _num_allele_type(r, a),
                                    lambda at:
                                    hl.cond(
                                        (_allele_ints['SNP'] == at) |
                                        (_allele_ints['Insertion'] == at) |
                                        (_allele_ints['Deletion'] == at) |
                                        (_allele_ints['MNP'] == at) |
                                        (_allele_ints['Complex'] == at),
                                        a + ref[hl.len(r):],
                                        a)))))),
                lambda lal:
                hl.struct(
                    globl=hl.array([ref]).extend(hl.array(hl.set(hl.flatten(lal)).remove(ref))),
                    local=lal)))

    def renumber_entry(entry, old_to_new) -> StructExpression:
        # global index of alternate (non-ref) alleles
        return entry.annotate(LA=entry.LA.map(lambda lak: old_to_new[lak]))

    if (ts.row.dtype, ts.globals.dtype) not in _merge_function_map:
        f = hl.experimental.define_function(
            lambda row, gbl:
            hl.rbind(
                merge_alleles(row.data.map(lambda d: d.alleles)),
                lambda alleles:
                hl.struct(
                    locus=row.locus,
                    alleles=alleles.globl,
                    rsid=hl.find(hl.is_defined, row.data.map(lambda d: d.rsid)),
                    __entries=hl.bind(
                        lambda combined_allele_index:
                        hl.range(0, hl.len(row.data)).flatmap(
                            lambda i:
                            hl.cond(hl.is_missing(row.data[i].__entries),
                                    hl.range(0, hl.len(gbl.g[i].__cols))
                                      .map(lambda _: hl.null(row.data[i].__entries.dtype.element_type)),
                                    hl.bind(
                                        lambda old_to_new: row.data[i].__entries.map(
                                            lambda e: renumber_entry(e, old_to_new)),
                                        hl.range(0, hl.len(alleles.local[i])).map(
                                            lambda j: combined_allele_index[alleles.local[i][j]])))),
                        hl.dict(hl.range(0, hl.len(alleles.globl)).map(
                            lambda j: hl.tuple([alleles.globl[j], j])))))),
            ts.row.dtype, ts.globals.dtype)
        _merge_function_map[(ts.row.dtype, ts.globals.dtype)] = f
    merge_function = _merge_function_map[(ts.row.dtype, ts.globals.dtype)]
    ts = Table(TableMapRows(ts._tir, Apply(merge_function._name,
                                           TopLevelReference('row'),
                                           TopLevelReference('global'))))
    return ts.transmute_globals(__cols=hl.flatten(ts.g.map(lambda g: g.__cols)))
Example #10
    'f7': 'phase',
    'f8': 'attributes'
})

ht_transcripts = ht_transcripts.filter(
    ht_transcripts.feature_type == 'transcript')
ht_transcripts = ht_transcripts.annotate(interval=hl.interval(
    hl.locus(ht_transcripts.contig, ht_transcripts.start, 'GRCh37'),
    hl.locus(ht_transcripts.contig, ht_transcripts.end + 1, 'GRCh37')))
ht_transcripts = ht_transcripts.annotate(attributes=hl.dict(
    hl.map(
        lambda x: (x.split(' ')[0], x.split(' ')[1].replace('"', '').replace(
            ';$', '')), ht_transcripts.attributes.split('; '))))
attribute_cols = list(
    ht_transcripts.aggregate(
        hl.set(hl.flatten(hl.agg.collect(ht_transcripts.attributes.keys())))))
ht_transcripts = ht_transcripts.annotate(
    **{
        x: hl.or_missing(ht_transcripts.attributes.contains(x),
                         ht_transcripts.attributes[x])
        for x in attribute_cols
    })
ht_transcripts = ht_transcripts.select(*([
    'transcript_id', 'transcript_name', 'transcript_type', 'strand',
    'transcript_status', 'havana_transcript', 'ccdsid', 'ont', 'gene_name',
    'interval', 'gene_type', 'annotation_source', 'havana_gene', 'gene_status',
    'tag'
]))
ht_transcripts = ht_transcripts.rename({
    'havana_transcript': 'havana_transcript_id',
    'havana_gene': 'havana_gene_id'
Example #11
def merge_alleles(alleles) -> ArrayExpression:
    # alleles is tarray(tarray(tstruct(ref=tstr, alt=tstr)))
    return hl.rbind(hl.array(hl.set(hl.flatten(alleles))),
                    lambda arr:
                    hl.filter(lambda a: a.alt != '<NON_REF>', arr)
                      .extend(hl.filter(lambda a: a.alt == '<NON_REF>', arr)))
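
Sketch of the expected behavior on a literal input (hl.filter here follows the older top-level API used in the snippet): alleles are deduplicated and <NON_REF> entries are pushed to the end, though the relative order produced by hl.set is not guaranteed:

alleles = hl.literal(
    [[hl.Struct(ref='A', alt='T'), hl.Struct(ref='A', alt='<NON_REF>')],
     [hl.Struct(ref='A', alt='G')]])
hl.eval(merge_alleles(alleles))
# e.g. [Struct(ref='A', alt='T'), Struct(ref='A', alt='G'), Struct(ref='A', alt='<NON_REF>')]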
Example #12
def compute_from_vp_mt(chr20: bool, overwrite: bool):
    meta = get_gnomad_meta('exomes')
    vp_mt = hl.read_matrix_table(full_mt_path('exomes'))
    vp_mt = vp_mt.filter_cols(meta[vp_mt.col_key].release)
    ann_ht = hl.read_table(vp_ann_ht_path('exomes'))
    phase_ht = hl.read_table(phased_vp_count_ht_path('exomes'))

    if chr20:
        vp_mt, ann_ht, phase_ht = filter_to_chr20([vp_mt, ann_ht, phase_ht])

    vep1_expr = get_worst_gene_csq_code_expr(ann_ht.vep1)
    vep2_expr = get_worst_gene_csq_code_expr(ann_ht.vep2)
    ann_ht = ann_ht.select(
        'snv1',
        'snv2',
        is_singleton_vp=(ann_ht.freq1['all'].AC < 2) & (ann_ht.freq2['all'].AC < 2),
        pop_af=hl.dict(
            ann_ht.freq1.key_set().intersection(ann_ht.freq2.key_set())
                .map(
                lambda pop: hl.tuple([pop, hl.max(ann_ht.freq1[pop].AF, ann_ht.freq2[pop].AF)])
            )
        ),
        popmax_af=hl.max(ann_ht.popmax1.AF, ann_ht.popmax2.AF, filter_missing=False),
        filtered=(hl.len(ann_ht.filters1) > 0) | (hl.len(ann_ht.filters2) > 0),
        vep=vep1_expr.keys().filter(
            lambda k: vep2_expr.contains(k)
        ).map(
            lambda k: vep1_expr[k].annotate(
                csq=hl.max(vep1_expr[k].csq, vep2_expr[k].csq)
            )
        )
    )

    vp_mt = vp_mt.annotate_cols(
        pop=meta[vp_mt.col_key].pop
    )
    vp_mt = vp_mt.annotate_rows(
        **ann_ht[vp_mt.row_key],
        phase_info=phase_ht[vp_mt.row_key].phase_info
    )

    vp_mt = vp_mt.filter_rows(
        ~vp_mt.filtered
    )

    vp_mt = vp_mt.filter_entries(
        vp_mt.GT1.is_het() & vp_mt.GT2.is_het() & vp_mt.adj1 & vp_mt.adj2
    )

    vp_mt = vp_mt.select_entries(
        x=True
    )

    vp_mt = vp_mt.annotate_cols(
        pop=['all', vp_mt.pop]
    )
    vp_mt = vp_mt.explode_cols('pop')

    vp_mt = vp_mt.explode_rows('vep')
    vp_mt = vp_mt.transmute_rows(
        **vp_mt.vep
    )

    def get_grouped_phase_agg():
        return hl.agg.group_by(
            hl.case()
                .when(~vp_mt.is_singleton_vp & (vp_mt.phase_info[vp_mt.pop].em.adj.p_chet > CHET_THRESHOLD), 1)
                .when(~vp_mt.is_singleton_vp & (vp_mt.phase_info[vp_mt.pop].em.adj.p_chet < SAME_HAP_THRESHOLD), 2)
                .default(3),
            hl.agg.min(vp_mt.csq)
        )

    vp_mt = vp_mt.group_rows_by(
        'gene_id',
        'gene_symbol'
    ).aggregate(
        all=hl.agg.filter(
            vp_mt.x &
            hl.if_else(
                vp_mt.pop == 'all',
                hl.is_defined(vp_mt.popmax_af) &
                (vp_mt.popmax_af <= MAX_FREQ),
                vp_mt.pop_af[vp_mt.pop] <= MAX_FREQ
            ),
            get_grouped_phase_agg()
        ),
        af_le_0_001=hl.agg.filter(
            hl.if_else(
                vp_mt.pop == 'all',
                hl.is_defined(vp_mt.popmax_af) &
                (vp_mt.popmax_af <= 0.001),
                vp_mt.pop_af[vp_mt.pop] <= 0.001
            )
            & vp_mt.x,
            get_grouped_phase_agg()
        )
    )

    vp_mt = vp_mt.checkpoint('gs://gnomad-tmp/compound_hets/chet_per_gene{}.2.mt'.format(
        '.chr20' if chr20 else ''
    ), overwrite=True)

    gene_ht = vp_mt.annotate_rows(
        row_counts=hl.flatten([
            hl.array(
                hl.agg.group_by(
                    vp_mt.pop,
                    hl.struct(
                        csq=csq,
                        af=af,
                        # TODO: Review this
                        # These will only keep the worst csq -- now maybe it'd be better to keep either
                        # - the worst csq for chet or
                        # - the worst csq for both chet and same_hap
                        n_worst_chet=hl.agg.count_where(vp_mt[af].get(1) == csq_i),
                        n_chet=hl.agg.count_where((vp_mt[af].get(1) == csq_i) & (vp_mt[af].get(2, 9) >= csq_i) & (vp_mt[af].get(3, 9) >= csq_i)),
                        n_same_hap=hl.agg.count_where((vp_mt[af].get(2) == csq_i) & (vp_mt[af].get(1, 9) > csq_i) & (vp_mt[af].get(3, 9) >= csq_i)),
                        n_unphased=hl.agg.count_where((vp_mt[af].get(3) == csq_i) & (vp_mt[af].get(1, 9) > csq_i) & (vp_mt[af].get(2, 9) > csq_i))
                    )
                )
            ).filter(
                lambda x: (x[1].n_chet > 0) | (x[1].n_same_hap > 0) | (x[1].n_unphased > 0)
            ).map(
                lambda x: x[1].annotate(
                    pop=x[0]
                )
            )
            for csq_i, csq in enumerate(CSQ_CODES)
            for af in ['all', 'af_le_0_001']
        ])
    ).rows()

    gene_ht = gene_ht.explode('row_counts')
    gene_ht = gene_ht.select(
        **gene_ht.row_counts
    )

    gene_ht.describe()
    gene_ht = gene_ht.checkpoint(
        'gs://gnomad-lfran/compound_hets/chet_per_gene{}.ht'.format(
            '.chr20' if chr20 else ''
        ),
        overwrite=overwrite
    )

    gene_ht.flatten().export(
        'gs://gnomad-lfran/compound_hets/chet_per_gene{}.tsv.gz'.format(
            '.chr20' if chr20 else ''
        )
    )
Example #13
def compute_from_full_mt(chr20: bool, overwrite: bool):
    mt = get_gnomad_data('exomes', adj=True, release_samples=True)
    freq_ht = hl.read_table(annotations_ht_path('exomes', 'frequencies'))
    vep_ht = hl.read_table(annotations_ht_path('exomes', 'vep'))
    rf_ht = hl.read_table(annotations_ht_path('exomes', 'rf'))

    if chr20:
        mt, freq_ht, vep_ht, rf_ht = filter_to_chr20([mt, freq_ht, vep_ht, rf_ht])

    vep_ht = vep_ht.annotate(
        vep=get_worst_gene_csq_code_expr(vep_ht.vep).values()
    )

    freq_ht = freq_ht.select(
        freq=freq_ht.freq[:10],
        popmax=freq_ht.popmax
    )
    freq_meta = hl.eval(freq_ht.globals.freq_meta)
    freq_dict = {f['pop']: i for i, f in enumerate(freq_meta[:10]) if 'pop' in f}
    freq_dict['all'] = 0
    freq_dict = hl.literal(freq_dict)
    mt = mt.annotate_rows(
        **freq_ht[mt.row_key],
        vep=vep_ht[mt.row_key].vep,
        filters=rf_ht[mt.row_key].filters
    )
    mt = mt.filter_rows(
        (mt.freq[0].AF <= MAX_FREQ) &
        (hl.len(mt.vep) > 0) &
        (hl.len(mt.filters) == 0)
    )

    mt = mt.filter_entries(mt.GT.is_non_ref())
    mt = mt.select_entries(
        is_het=mt.GT.is_het()
    )

    mt = mt.explode_rows(mt.vep)
    mt = mt.transmute_rows(**mt.vep)

    mt = mt.annotate_cols(
        pop=['all', mt.meta.pop]
    )
    mt = mt.explode_cols(mt.pop)

    mt = mt.group_rows_by(
        'gene_id'
    ).aggregate_rows(
        gene_symbol=hl.agg.take(mt.gene_symbol, 1)[0]
    ).aggregate(
        counts=hl.agg.filter(
            hl.if_else(
                mt.pop == 'all',
                hl.is_defined(mt.popmax) & (mt.popmax.AF <= MAX_FREQ),
                mt.freq[freq_dict[mt.pop]].AF <= MAX_FREQ
            ),
            hl.agg.group_by(
                hl.if_else(
                    mt.pop == 'all',
                    mt.popmax.AF > 0.001,
                    mt.freq[freq_dict[mt.pop]].AF > 0.001
                ),
                hl.struct(
                    hom_csq=hl.agg.filter(~mt.is_het, hl.agg.min(mt.csq)),
                    het_csq=hl.agg.filter(mt.is_het, hl.agg.min(mt.csq)),
                    het_het_csq=hl.sorted(
                        hl.array(
                            hl.agg.filter(mt.is_het, hl.agg.counter(mt.csq))
                        ),
                        key=lambda x: x[0]
                    ).scan(
                        lambda i, j: (j[0], i[1] + j[1]),
                        (0, 0)
                    ).find(
                        lambda x: x[1] > 1
                    )[0]
                )
            )
        )
    )

    mt = mt.annotate_entries(
        counts=hl.struct(
            all=hl.struct(
                hom_csq=hl.min(mt.counts.get(True).hom_csq, mt.counts.get(False).hom_csq),
                het_csq=hl.min(mt.counts.get(True).het_csq, mt.counts.get(False).het_csq),
                het_het_csq=hl.min(
                    mt.counts.get(True).het_het_csq,
                    mt.counts.get(False).het_het_csq,
                    hl.or_missing(
                        hl.is_defined(mt.counts.get(True).het_csq) & hl.is_defined(mt.counts.get(False).het_csq),
                        hl.max(mt.counts.get(True).het_csq, mt.counts.get(False).het_csq)
                    )
                ),
            ),
            af_le_0_001=mt.counts.get(False)
        )
    )

    mt = mt.checkpoint('gs://gnomad-tmp/compound_hets/het_and_hom_per_gene{}.1.mt'.format(
        '.chr20' if chr20 else ''
    ), overwrite=True)

    gene_ht = mt.annotate_rows(
        row_counts=hl.flatten([
            hl.array(
                hl.agg.group_by(
                    mt.pop,
                    hl.struct(
                        csq=csq,
                        af=af,
                        n_hom=hl.agg.count_where(mt.counts[af].hom_csq == csq_i),
                        n_het=hl.agg.count_where(mt.counts[af].het_csq == csq_i),
                        n_het_het=hl.agg.count_where(mt.counts[af].het_het_csq == csq_i)
                    )
                )
            ).filter(
                lambda x: (x[1].n_het > 0) | (x[1].n_hom > 0) | (x[1].n_het_het > 0)
            ).map(
                lambda x: x[1].annotate(
                    pop=x[0]
                )
            )
            for csq_i, csq in enumerate(CSQ_CODES)
            for af in ['all', 'af_le_0_001']
        ])
    ).rows()

    gene_ht = gene_ht.explode('row_counts')
    gene_ht = gene_ht.select(
        'gene_symbol',
        **gene_ht.row_counts
    )

    gene_ht.describe()

    gene_ht = gene_ht.checkpoint(
        'gs://gnomad-lfran/compound_hets/het_and_hom_per_gene{}.ht'.format(
            '.chr20' if chr20 else ''
        ),
        overwrite=overwrite
    )

    gene_ht.flatten().export('gs://gnomad-lfran/compound_hets/het_and_hom_per_gene{}.tsv.gz'.format(
        '.chr20' if chr20 else ''
    ))
Example #14
File: misc.py, Project: chrisvittal/hail
def segment_intervals(ht, points):
    """Segment the interval keys of `ht` at a given set of points.

    Parameters
    ----------
    ht : :class:`.Table`
        Table with interval keys.
    points : :class:`.Table` or :class:`.ArrayExpression`
        Points at which to segment the intervals, a table or an array.

    Returns
    -------
    :class:`.Table`
    """
    if len(ht.key) != 1 or not isinstance(ht.key[0].dtype, hl.tinterval):
        raise ValueError(
            "'segment_intervals' expects a table with interval keys")
    point_type = ht.key[0].dtype.point_type
    if isinstance(points, Table):
        if len(points.key) != 1 or points.key[0].dtype != point_type:
            raise ValueError(
                "'segment_intervals' expects points to be a table with a single"
                " key of the same type as the intervals in 'ht', or an array of those points:"
                f"\n  expect {point_type}, found {list(points.key.dtype.values())}"
            )
        points = hl.array(hl.set(points.collect(_localize=False)))
    if points.dtype.element_type != point_type:
        raise ValueError(
            f"'segment_intervals' expects points to be a table with a single"
            f" key of the same type as the intervals in 'ht', or an array of those points:"
            f"\n  expect {point_type}, found {points.dtype.element_type}")

    points = hl._sort_by(points, lambda l, r: hl._compare(l, r) < 0)

    ht = ht.annotate_globals(__points=points)

    interval = ht.key[0]
    points = ht.__points
    lower = hl.expr.functions._lower_bound(points, interval.start)
    higher = hl.expr.functions._lower_bound(points, interval.end)
    n_points = hl.len(points)
    lower = hl.if_else((lower < n_points) & (points[lower] == interval.start),
                       lower + 1, lower)
    higher = hl.if_else((higher < n_points) & (points[higher] == interval.end),
                        higher - 1, higher)
    interval_results = hl.rbind(
        lower, higher, lambda lower, higher: hl.cond(
            lower >= higher, [interval],
            hl.flatten([
                [
                    hl.interval(interval.start,
                                points[lower],
                                includes_start=interval.includes_start,
                                includes_end=False)
                ],
                hl.range(lower, higher - 1).map(lambda x: hl.interval(
                    points[x],
                    points[x + 1],
                    includes_start=True,
                    includes_end=False)),
                [
                    hl.interval(points[higher - 1],
                                interval.end,
                                includes_start=True,
                                includes_end=interval.includes_end)
                ],
            ])))
    ht = ht.annotate(__new_intervals=interval_results,
                     lower=lower,
                     higher=higher).explode('__new_intervals')
    return ht.key_by(**{
        list(ht.key)[0]: ht.__new_intervals
    }).drop('__new_intervals')
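
A minimal usage sketch (table contents are illustrative; hl.Table.parallelize builds a small interval-keyed table):

import hail as hl

ht = hl.Table.parallelize(
    [hl.Struct(interval=hl.Interval(0, 10))],
    schema=hl.tstruct(interval=hl.tinterval(hl.tint32)),
    key='interval')
segmented = segment_intervals(ht, hl.literal([3, 7]))
# Keys become [0, 3), [3, 7), [7, 10)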
Example #15
File: load.GTEx.py, Project: saponas/hail
                'SMRRNANM',
                'SMVQCFL',
                'SMTRSCPT',
                'SMMPPDPR',
                'SMCGLGTH',
                'SMUNPDRD',
                'SMMPPDUN',
                'SME2ANTI',
                'SMALTALG',
                'SME2SNSE',
                'SMMFLGTH',
                'SMSPLTRD',
                'SME1ANTI',
                'SME1SNSE',
                'SMNUM5CD']

    ht_samples = ht_samples.annotate(**{x: hl.float(ht_samples[x]) for x in float_cols})
    ht_samples = ht_samples.annotate(**{x: hl.int(ht_samples[x].replace('.0$', '')) for x in int_cols})

    ht = ht.filter(ht.feature_type == 'gene')
    ht = ht.annotate(interval=hl.interval(hl.locus(ht['contig'], ht['start'], 'GRCh37'), hl.locus(ht['contig'], ht['end'] + 1, 'GRCh37')))
    ht = ht.annotate(attributes=hl.dict(hl.map(lambda x: (x.split(' ')[0], x.split(' ')[1].replace('"', '').replace(';$', '')), ht['attributes'].split('; '))))
    attribute_cols = list(ht.aggregate(hl.set(hl.flatten(hl.agg.collect(ht.attributes.keys())))))
    ht = ht.annotate(**{x: hl.or_missing(ht.attributes.contains(x), ht.attributes[x]) for x in attribute_cols})
    ht = ht.select(*(['gene_id', 'interval', 'gene_type', 'strand', 'annotation_source', 'havana_gene', 'gene_status', 'tag']))
    ht = ht.rename({'havana_gene': 'havana_gene_id'})
    ht = ht.key_by(ht.gene_id)

"""

Example #16
File: load.GTEx.py, Project: saponas/hail
def import_gtf(path, key=None):
    """Import a GTF file.

       The GTF file format is identical to the GFF version 2 file format,
       and so this function can be used to import GFF version 2 files as
       well.

       See https://www.ensembl.org/info/website/upload/gff.html for more
       details on the GTF/GFF2 file format.

       The :class:`.Table` returned by this function will include the following
       row fields:

       .. code-block:: text

           'seqname': str
           'source': str
           'feature': str
           'start': int32
           'end': int32
           'score': float64
           'strand': str
           'frame': int32

       There will also be corresponding fields for every tag found in the
       attribute field of the GTF file.

       .. note::

           The "end" field in the table will be incremented by 1 in
           comparison to the value found in the GTF file, as the end
           coordinate in a GTF file is inclusive while the end
           coordinate in Hail is exclusive.

       Example
       -------

       >>> ht = hl.experimental.import_gtf('data/test.gtf', key='gene_id')
       >>> ht.describe()

       .. code-block:: text

           ----------------------------------------
           Global fields:
           None
           ----------------------------------------
           Row fields:
               'seqname': str
               'source': str
               'feature': str
               'start': int32
               'end': int32
               'score': float64
               'strand': str
               'frame': int32
               'havana_gene': str
               'exon_id': str
               'havana_transcript': str
               'transcript_name': str
               'gene_type': str
               'tag': str
               'transcript_status': str
               'exon_number': str
               'level': str
               'transcript_id': str
               'transcript_type': str
               'gene_id': str
               'gene_name': str
               'gene_status': str
           ----------------------------------------
           Key: ['gene_id']
           ----------------------------------------

       Parameters
       ----------

       path : :obj:`str`
           File to import.
       key : :obj:`str` or :obj:`list` of :obj:`str`
           Key field(s). Can be tag name(s) found in the attribute field
           of the GTF file.

       Returns
       -------
       :class:`.Table`
       """

    ht = hl.import_table(path,
                         comment='#',
                         no_header=True,
                         types={'f3': hl.tint,
                                'f4': hl.tint,
                                'f5': hl.tfloat,
                                'f7': hl.tint},
                         missing='.',
                         delimiter='\t')

    ht = ht.rename({'f0': 'seqname',
                    'f1': 'source',
                    'f2': 'feature',
                    'f3': 'start',
                    'f4': 'end',
                    'f5': 'score',
                    'f6': 'strand',
                    'f7': 'frame',
                    'f8': 'attribute'})

    ht = ht.annotate(end=ht['end'] + 1)
    ht = ht.annotate(attribute=hl.dict(
        hl.map(lambda x: (x.split(' ')[0],
                          x.split(' ')[1].replace('"', '').replace(';$', '')),
               ht['attribute'].split('; '))))

    attributes = list(ht.aggregate(
        hl.set(hl.flatten(hl.agg.collect(ht['attribute'].keys())))))

    ht = ht.annotate(**{x: hl.or_missing(ht['attribute'].contains(x),
                                         ht['attribute'][x])
                        for x in attributes})

    ht = ht.drop(ht['attribute'])

    if key:
        key = wrap_to_list(key)
        ht = ht.key_by(*key)

    return ht
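
The key may also be a list of attribute tags, e.g. (illustrative):

ht = import_gtf('data/test.gtf', key=['gene_id', 'transcript_id'])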
Example #17
File: vcf_combiner.py, Project: bcajes/hail
def merge_alleles(alleles) -> ArrayExpression:
    return hl.array(hl.set(hl.flatten(alleles)))
Example #18
def fs_from_sb(
    sb: Union[hl.expr.ArrayNumericExpression, hl.expr.ArrayExpression],
    normalize: bool = True,
    min_cell_count: int = 200,
    min_count: int = 4,
    min_p_value: float = 1e-320,
) -> hl.expr.Int64Expression:
    """
    Computes `FS` (Fisher strand balance) annotation from the `SB` (strand balance table) field.
    `FS` is the phred-scaled value of the double-sided Fisher exact test on strand balance.

    Using default values will have the same behavior as the GATK implementation, that is:
    - If sum(counts) > 2*`min_cell_count` (defaults to the GATK value of 200), the counts are normalized
    - If sum(counts) < `min_count` (defaults to the GATK value of 4), missing is returned
    - Any p-value < `min_p_value` (defaults to the GATK value of 1e-320) is truncated to that value

    In addition to the default GATK behavior, setting `normalize` to `False` will perform a chi-squared test
    for large counts (> `min_cell_count`) instead of normalizing the cell values.

    .. note::

        This function can either take
        - an array of length four containing the forward and reverse strands' counts of ref and alt alleles: [ref fwd, ref rev, alt fwd, alt rev]
        - a two dimensional array with arrays of length two, containing the counts: [[ref fwd, ref rev], [alt fwd, alt rev]]

    GATK code here: https://github.com/broadinstitute/gatk/blob/master/src/main/java/org/broadinstitute/hellbender/tools/walkers/annotator/FisherStrand.java

    :param sb: Count of ref/alt reads on each strand
    :param normalize: Whether to normalize counts if sum(counts) > min_cell_count (normalize=True), or to use a chi-squared test instead of FET (normalize=False)
    :param min_cell_count: Maximum count for performing a FET
    :param min_count: Minimum total count to output FS (otherwise null is output)
    :param min_p_value: Minimum p-value; smaller p-values are truncated to this value
    :return: FS value
    """
    if not isinstance(sb, hl.expr.ArrayNumericExpression):
        sb = hl.bind(lambda x: hl.flatten(x), sb)

    sb_sum = hl.bind(lambda x: hl.sum(x), sb)

    # Normalize table if counts get too large
    if normalize:
        fs_expr = hl.bind(
            lambda sb, sb_sum: hl.cond(
                sb_sum <= 2 * min_cell_count,
                sb,
                sb.map(lambda x: hl.int(x / (sb_sum / min_cell_count))),
            ),
            sb,
            sb_sum,
        )

        # FET
        fs_expr = to_phred(
            hl.max(
                hl.fisher_exact_test(
                    fs_expr[0], fs_expr[1], fs_expr[2], fs_expr[3]
                ).p_value,
                min_p_value,
            )
        )
    else:
        fs_expr = to_phred(
            hl.max(
                hl.contingency_table_test(
                    sb[0], sb[1], sb[2], sb[3], min_cell_count=min_cell_count
                ).p_value,
                min_p_value,
            )
        )

    # Return null if counts <= `min_count`
    return hl.or_missing(
        sb_sum > min_count, hl.max(0, fs_expr)  # Needed to avoid -0.0 values
    )
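
A minimal usage sketch (names are hypothetical; mt.info.SB is assumed to be the four strand counts, flat or as a 2x2 nested array):

mt = mt.annotate_rows(FS=fs_from_sb(mt.info.SB))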
Example #19
File: tutorial.py, Project: lfrancioli/hail
def fields_to_array(ds, fields):
    return hl.flatten(hl.array([field_to_array(ds, f) for f in fields]))
Example #20
, "TFBS_amplification"
, "TF_binding_site_variant"
, "regulatory_region_ablation"
, "regulatory_region_amplification"
, "feature_elongation"
, "regulatory_region_variant"
, "feature_truncation"
, "intergenic_variant"
]



mt4 = kt2.annotate(transcript_canonicals=kt2.vep.transcript_consequences.filter(lambda tc: tc.canonical == 1))

mt4 = mt4.annotate(
   all_transcript_terms = hl.set(hl.flatten(mt4.transcript_canonicals.map(lambda x: x.consequence_terms)))
)

mt4 = mt4.annotate(
    Consequence = hl.coalesce(  # coalesce means take first non-missing
        hl.literal(consequence_in_severity_order).filter(lambda cnsq:
            mt4.all_transcript_terms.contains(cnsq)
        ).head(),
        mt4.vep.most_severe_consequence
    )
)
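
# Quick illustration of hl.coalesce, which returns its first non-missing argument:
# hl.eval(hl.coalesce(hl.missing(hl.tstr), 'fallback'))  # -> 'fallback'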

mt4 = mt4.annotate(
    Gene=hl.if_else(
        mt4.transcript_canonicals.any(lambda tc: hl.set(tc.consequence_terms).contains(mt4.Consequence)),
        mt4.transcript_canonicals.find(lambda tc: (tc.canonical == 1) & (hl.set(tc.consequence_terms).contains(mt4.Consequence))).gene_symbol,
        mt4.vep.transcript_consequences.find(lambda tc: hl.set(tc.consequence_terms).contains(mt4.Consequence)).gene_symbol),