Example #1
 def merge_alleles(alleles):
     from hail.expr.functions import _num_allele_type, _allele_ints
     return hl.rbind(
         alleles.map(lambda a: hl.or_else(a[0], ''))
                .fold(lambda s, t: hl.cond(hl.len(s) > hl.len(t), s, t), ''),
         lambda ref:
         hl.rbind(
             alleles.map(
                 lambda al: hl.rbind(
                     al[0],
                     lambda r:
                     hl.array([ref]).extend(
                         al[1:].map(
                             lambda a:
                             hl.rbind(
                                 _num_allele_type(r, a),
                                 lambda at:
                                 hl.cond(
                                     (_allele_ints['SNP'] == at) |
                                     (_allele_ints['Insertion'] == at) |
                                     (_allele_ints['Deletion'] == at) |
                                     (_allele_ints['MNP'] == at) |
                                     (_allele_ints['Complex'] == at),
                                     a + ref[hl.len(r):],
                                     a)))))),
             lambda lal:
             hl.struct(
                 globl=hl.array([ref]).extend(hl.array(hl.set(hl.flatten(lal)).remove(ref))),
                 local=lal)))
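
The first fold above simply keeps the longest reference allele seen across the
input datasets. A minimal standalone sketch of that step, assuming plain string
refs (`hl.cond` is the older spelling of today's `hl.if_else`):

    import hail as hl

    refs = hl.literal(['A', 'ATT', 'AT'])
    # keep whichever ref string is longer at each step of the fold
    longest = refs.fold(lambda s, t: hl.cond(hl.len(s) > hl.len(t), s, t), '')
    print(hl.eval(longest))  # 'ATT'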
Example #2
File: qc.py Project: tpoterba/hail
 def allele_type(ref, alt):
     return hl.bind(lambda at: hl.cond(at == allele_ints['SNP'],
                                       hl.cond(hl.is_transition(ref, alt),
                                               allele_ints['Transition'],
                                               allele_ints['Transversion']),
                                       at),
                    _num_allele_type(ref, alt))
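
A self-contained sketch of the same bind/cond pattern using only public Hail
functions (`simple_allele_type` and its string labels are illustrative, not
part of qc.py):

    import hail as hl

    def simple_allele_type(ref, alt):
        # evaluate hl.is_snp once, then refine SNPs into transition/transversion
        return hl.bind(
            lambda is_snp: hl.cond(is_snp,
                                   hl.cond(hl.is_transition(ref, alt),
                                           'transition', 'transversion'),
                                   'other'),
            hl.is_snp(ref, alt))

    print(hl.eval(simple_allele_type('A', 'G')))  # 'transition' (purine to purine)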
Example #3
 def test_refs_with_process_joins(self):
     mt = hl.utils.range_matrix_table(10, 10)
     mt = mt.annotate_entries(
         a_literal=hl.literal(['a']),
         a_col_join=hl.is_defined(mt.cols()[mt.col_key]),
         a_row_join=hl.is_defined(mt.rows()[mt.row_key]),
         an_entry_join=hl.is_defined(mt[mt.row_key, mt.col_key]),
         the_global_failure=hl.cond(True, mt.globals, hl.null(mt.globals.dtype)),
         the_row_failure=hl.cond(True, mt.row, hl.null(mt.row.dtype)),
         the_col_failure=hl.cond(True, mt.col, hl.null(mt.col.dtype)),
         the_entry_failure=hl.cond(True, mt.entry, hl.null(mt.entry.dtype)),
     )
     mt.count()
Example #4
def combine(ts):
    # pylint: disable=protected-access
    tmp = ts.annotate(
        alleles=merge_alleles(ts.data.map(lambda d: d.alleles)),
        rsid=hl.find(hl.is_defined, ts.data.map(lambda d: d.rsid)),
        filters=hl.set(hl.flatten(ts.data.map(lambda d: hl.array(d.filters)))),
        info=hl.struct(
            DP=hl.sum(ts.data.map(lambda d: d.info.DP)),
            MQ_DP=hl.sum(ts.data.map(lambda d: d.info.MQ_DP)),
            QUALapprox=hl.sum(ts.data.map(lambda d: d.info.QUALapprox)),
            RAW_MQ=hl.sum(ts.data.map(lambda d: d.info.RAW_MQ)),
            VarDP=hl.sum(ts.data.map(lambda d: d.info.VarDP)),
            SB=hl.array([
                hl.sum(ts.data.map(lambda d: d.info.SB[0])),
                hl.sum(ts.data.map(lambda d: d.info.SB[1])),
                hl.sum(ts.data.map(lambda d: d.info.SB[2])),
                hl.sum(ts.data.map(lambda d: d.info.SB[3]))
            ])))
    tmp = tmp.annotate(
        __entries=hl.bind(
            lambda combined_allele_index:
            hl.range(0, hl.len(tmp.data)).flatmap(
                lambda i:
                hl.cond(hl.is_missing(tmp.data[i].__entries),
                        hl.range(0, hl.len(tmp.g[i].__cols))
                          .map(lambda _: hl.null(tmp.data[i].__entries.dtype.element_type)),
                        hl.bind(
                            lambda old_to_new: tmp.data[i].__entries.map(lambda e: renumber_entry(e, old_to_new)),
                            hl.range(0, hl.len(tmp.data[i].alleles)).map(
                                lambda j: combined_allele_index[tmp.data[i].alleles[j]])))),
            hl.dict(hl.range(0, hl.len(tmp.alleles)).map(
                lambda j: hl.tuple([tmp.alleles[j], j])))))
    tmp = tmp.annotate_globals(__cols=hl.flatten(tmp.g.map(lambda g: g.__cols)))

    return tmp.drop('data', 'g')
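
`renumber_entry` is defined elsewhere in the combiner; a plausible sketch of
its role, matching how `old_to_new` is built above (an assumption, not copied
from this snippet):

    def renumber_entry(entry, old_to_new):
        # remap each local allele index through the old-to-new allele mapping
        return entry.annotate(LA=entry.LA.map(lambda lak: old_to_new[lak]))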
Example #5
    def phase_diploid_proband(
            locus: hl.expr.LocusExpression,
            alleles: hl.expr.ArrayExpression,
            proband_call: hl.expr.CallExpression,
            father_call: hl.expr.CallExpression,
            mother_call: hl.expr.CallExpression
    ) -> hl.expr.ArrayExpression:
        """
        Returns phased genotype calls in the case of a diploid proband
        (autosomes, PAR regions of sex chromosomes or non-PAR regions of a female proband)

        :param LocusExpression locus: Locus in the trio MatrixTable
        :param ArrayExpression alleles: Alleles in the trio MatrixTable
        :param CallExpression proband_call: Input proband genotype call
        :param CallExpression father_call: Input father genotype call
        :param CallExpression mother_call: Input mother genotype call
        :return: Array containing: phased proband call, phased father call, phased mother call
        :rtype: ArrayExpression
        """

        proband_v = proband_call.one_hot_alleles(alleles)
        father_v = hl.cond(
            locus.in_x_nonpar() | locus.in_y_nonpar(),
            hl.or_missing(father_call.is_haploid(), hl.array([father_call.one_hot_alleles(alleles)])),
            call_to_one_hot_alleles_array(father_call, alleles)
        )
        mother_v = call_to_one_hot_alleles_array(mother_call, alleles)
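
        # A (father allele, mother allele) pair is phase-consistent when its
        # one-hot vectors sum to the proband's; phasing is returned below only
        # when exactly one such pair exists.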

        combinations = hl.flatmap(
            lambda f:
            hl.zip_with_index(mother_v)
                .filter(lambda m: m[1] + f[1] == proband_v)
                .map(lambda m: hl.struct(m=m[0], f=f[0])),
            hl.zip_with_index(father_v)
        )

        return (
            hl.or_missing(
                hl.is_defined(combinations) & (hl.len(combinations) == 1),
                hl.array([
                    hl.call(father_call[combinations[0].f], mother_call[combinations[0].m], phased=True),
                    hl.cond(father_call.is_haploid(), hl.call(father_call[0], phased=True), phase_parent_call(father_call, combinations[0].f)),
                    phase_parent_call(mother_call, combinations[0].m)
                ])
            )
        )
Example #6
    def test_reference_genome_liftover(self):
        grch37 = hl.get_reference('GRCh37')
        grch38 = hl.get_reference('GRCh38')

        self.assertTrue(not grch37.has_liftover('GRCh38') and not grch38.has_liftover('GRCh37'))
        grch37.add_liftover(resource('grch37_to_grch38_chr20.over.chain.gz'), 'GRCh38')
        grch38.add_liftover(resource('grch38_to_grch37_chr20.over.chain.gz'), 'GRCh37')
        self.assertTrue(grch37.has_liftover('GRCh38') and grch38.has_liftover('GRCh37'))

        ds = hl.import_vcf(resource('sample.vcf'))
        t = ds.annotate_rows(liftover=hl.liftover(hl.liftover(ds.locus, 'GRCh38'), 'GRCh37')).rows()
        self.assertTrue(t.all(t.locus == t.liftover))

        null_locus = hl.null(hl.tlocus('GRCh38'))

        rows = [
            {'l37': hl.locus('20', 1, 'GRCh37'), 'l38': null_locus},
            {'l37': hl.locus('20', 60000, 'GRCh37'), 'l38': null_locus},
            {'l37': hl.locus('20', 60001, 'GRCh37'), 'l38': hl.locus('chr20', 79360, 'GRCh38')},
            {'l37': hl.locus('20', 278686, 'GRCh37'), 'l38': hl.locus('chr20', 298045, 'GRCh38')},
            {'l37': hl.locus('20', 278687, 'GRCh37'), 'l38': hl.locus('chr20', 298046, 'GRCh38')},
            {'l37': hl.locus('20', 278688, 'GRCh37'), 'l38': null_locus},
            {'l37': hl.locus('20', 278689, 'GRCh37'), 'l38': null_locus},
            {'l37': hl.locus('20', 278690, 'GRCh37'), 'l38': null_locus},
            {'l37': hl.locus('20', 278691, 'GRCh37'), 'l38': hl.locus('chr20', 298047, 'GRCh38')},
            {'l37': hl.locus('20', 37007586, 'GRCh37'), 'l38': hl.locus('chr12', 32563117, 'GRCh38')},
            {'l37': hl.locus('20', 62965520, 'GRCh37'), 'l38': hl.locus('chr20', 64334167, 'GRCh38')},
            {'l37': hl.locus('20', 62965521, 'GRCh37'), 'l38': null_locus}
        ]
        schema = hl.tstruct(l37=hl.tlocus(grch37), l38=hl.tlocus(grch38))
        t = hl.Table.parallelize(rows, schema)
        self.assertTrue(t.all(hl.cond(hl.is_defined(t.l38),
                                      hl.liftover(t.l37, 'GRCh38') == t.l38,
                                      hl.is_missing(hl.liftover(t.l37, 'GRCh38')))))

        t = t.filter(hl.is_defined(t.l38))
        self.assertTrue(t.count() == 6)

        t = t.key_by('l38')
        t.count()
        self.assertTrue(list(t.key) == ['l38'])

        null_locus_interval = hl.null(hl.tinterval(hl.tlocus('GRCh38')))
        rows = [
            {'i37': hl.locus_interval('20', 1, 60000, True, False, 'GRCh37'), 'i38': null_locus_interval},
            {'i37': hl.locus_interval('20', 60001, 82456, True, True, 'GRCh37'),
             'i38': hl.locus_interval('chr20', 79360, 101815, True, True, 'GRCh38')}
        ]
        schema = hl.tstruct(i37=hl.tinterval(hl.tlocus(grch37)), i38=hl.tinterval(hl.tlocus(grch38)))
        t = hl.Table.parallelize(rows, schema)
        self.assertTrue(t.all(hl.liftover(t.i37, 'GRCh38') == t.i38))

        grch37.remove_liftover("GRCh38")
        grch38.remove_liftover("GRCh37")
Example #7
    def test_null_joins_2(self):
        tr = hl.utils.range_table(7, 1)
        table1 = tr.key_by(new_key=hl.cond((tr.idx == 3) | (tr.idx == 5),
                                           hl.null(hl.tint32), tr.idx),
                           key2=tr.idx)
        table1 = table1.select(idx1=table1.idx)
        table2 = tr.key_by(new_key=hl.cond((tr.idx == 4) | (tr.idx == 6),
                                           hl.null(hl.tint32), tr.idx),
                           key2=tr.idx)
        table2 = table2.select(idx2=table2.idx)

        left_join = table1.join(table2, 'left')
        right_join = table1.join(table2, 'right')
        inner_join = table1.join(table2, 'inner')
        outer_join = table1.join(table2, 'outer')

        def row(new_key, key2, idx1, idx2):
            return hl.Struct(new_key=new_key, key2=key2, idx1=idx1, idx2=idx2)

        left_join_expected = [row(0, 0, 0, 0), row(1, 1, 1, 1), row(2, 2, 2, 2),
                              row(4, 4, 4, None), row(6, 6, 6, None),
                              row(None, 3, 3, None), row(None, 5, 5, None)]

        right_join_expected = [row(0, 0, 0, 0), row(1, 1, 1, 1), row(2, 2, 2, 2),
                               row(3, 3, None, 3), row(5, 5, None, 5),
                               row(None, 4, None, 4), row(None, 6, None, 6)]

        inner_join_expected = [row(0, 0, 0, 0), row(1, 1, 1, 1), row(2, 2, 2, 2)]

        outer_join_expected = [row(0, 0, 0, 0), row(1, 1, 1, 1), row(2, 2, 2, 2),
                               row(3, 3, None, 3), row(4, 4, 4, None),
                               row(5, 5, None, 5), row(6, 6, 6, None),
                               row(None, 3, 3, None), row(None, 4, None, 4),
                               row(None, 5, 5, None), row(None, 6, None, 6)]

        self.assertEqual(left_join.collect(), left_join_expected)
        self.assertEqual(right_join.collect(), right_join_expected)
        self.assertEqual(inner_join.collect(), inner_join_expected)
        self.assertEqual(outer_join.collect(), outer_join_expected)
Example #8
    def test_agg_cols_explode(self):
        t = hl.utils.range_matrix_table(1, 10)

        tests = [(agg.explode(lambda elt: agg.collect(elt + 1).append(0),
                              hl.cond(t.col_idx > 7, [t.col_idx, t.col_idx + 1], hl.empty_array(hl.tint32))),
                  [9, 10, 10, 11, 0]),
                 (agg.explode(lambda elt: agg.explode(lambda elt2: agg.collect(elt2 + 1).append(0),
                                                      [elt, elt + 1]),
                              hl.cond(t.col_idx > 7, [t.col_idx, t.col_idx + 1], hl.empty_array(hl.tint32))),
                  [9, 10, 10, 11, 10, 11, 11, 12, 0]),
                 (agg.explode(lambda elt: agg.filter(elt > 8,
                                                     agg.collect(elt + 1).append(0)),
                              hl.cond(t.col_idx > 7, [t.col_idx, t.col_idx + 1], hl.empty_array(hl.tint32))),
                  [10, 10, 11, 0]),
                 (agg.explode(lambda elt: agg.group_by(elt % 3,
                                                       agg.collect(elt + 1).append(0)),
                              hl.cond(t.col_idx > 7,
                                      [t.col_idx, t.col_idx + 1],
                                      hl.empty_array(hl.tint32))),
                  {0: [10, 10, 0], 1: [11, 0], 2: [9, 0]})
                 ]
        for aggregation, expected in tests:
            self.assertEqual(t.select_rows(result = aggregation).result.collect()[0], expected)
Example #9
    def test_nulls_in_distinct_joins(self):

        # MatrixAnnotateRowsTable uses left distinct join
        mr = hl.utils.range_matrix_table(7, 3, 4)
        matrix1 = mr.key_rows_by(new_key=hl.cond((mr.row_idx == 3) | (mr.row_idx == 5),
                                                hl.null(hl.tint32), mr.row_idx))
        matrix2 = mr.key_rows_by(new_key=hl.cond((mr.row_idx == 4) | (mr.row_idx == 6),
                                                hl.null(hl.tint32), mr.row_idx))

        joined = matrix1.select_rows(idx1=matrix1.row_idx,
                                     idx2=matrix2.rows()[matrix1.new_key].row_idx)

        def row(new_key, idx1, idx2):
            return hl.Struct(new_key=new_key, idx1=idx1, idx2=idx2)

        expected = [row(0, 0, 0),
                    row(1, 1, 1),
                    row(2, 2, 2),
                    row(4, 4, None),
                    row(6, 6, None),
                    row(None, 3, None),
                    row(None, 5, None)]
        self.assertEqual(joined.rows().collect(), expected)

        # union_cols uses inner distinct join
        matrix1 = matrix1.annotate_entries(ridx=matrix1.row_idx,
                                           cidx=matrix1.col_idx)
        matrix2 = matrix2.annotate_entries(ridx=matrix2.row_idx,
                                           cidx=matrix2.col_idx)
        matrix2 = matrix2.key_cols_by(col_idx=matrix2.col_idx + 3)

        expected = hl.utils.range_matrix_table(3, 6, 1)
        expected = expected.key_rows_by(new_key=expected.row_idx)
        expected = expected.annotate_entries(ridx=expected.row_idx,
                                             cidx=expected.col_idx % 3)

        self.assertTrue(matrix1.union_cols(matrix2)._same(expected))
Example #10
    def transform_entries(old_entry):
        def with_local_a_index(local_a_index):
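            # For each biallelic genotype i (0/0, 0/1, 1/1), the new PL is the
            # minimum LPL over all local genotypes j that downcode to i.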
            new_pl = hl.or_missing(
                hl.is_defined(old_entry.LPL),
                hl.or_missing(
                    hl.is_defined(local_a_index),
                    hl.range(0, 3).map(lambda i: hl.min(
                        hl.range(0, hl.triangle(hl.len(old_entry.LA)))
                            .filter(lambda j: hl.downcode(hl.unphased_diploid_gt_index_call(j), local_a_index) == hl.unphased_diploid_gt_index_call(i))
                            .map(lambda idx: old_entry.LPL[idx])))))
            fields = set(old_entry.keys())

            def with_pl(pl):
                new_exprs = {}
                dropped_fields = ['LA']
                if 'LGT' in fields:
                    new_exprs['GT'] = hl.downcode(old_entry.LGT, hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LGT')
                if 'LPGT' in fields:
                    new_exprs['PGT'] = hl.downcode(old_entry.LPGT, hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LPGT')
                if 'LAD' in fields:
                    new_exprs['AD'] = hl.or_missing(
                        hl.is_defined(old_entry.LAD),
                        [old_entry.LAD[0], hl.or_else(old_entry.LAD[local_a_index], 0)]) # second entry zeroed for lack of non-ref AD
                    dropped_fields.append('LAD')
                if 'LPL' in fields:
                    new_exprs['PL'] = pl
                    if 'GQ' in fields:
                        new_exprs['GQ'] = hl.or_else(hl.gq_from_pl(pl), old_entry.GQ)

                    dropped_fields.append('LPL')

                return hl.cond(hl.len(ds.alleles) == 1,
                                   old_entry.annotate(**{f[1:]: old_entry[f] for f in ['LGT', 'LPGT', 'LAD', 'LPL'] if f in fields}).drop(*dropped_fields),
                                   old_entry.annotate(**new_exprs).drop(*dropped_fields))

            if 'LPL' in fields:
                return hl.bind(with_pl, new_pl)
            else:
                return with_pl(None)
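
        # Find the local allele index matching this row's global a_index;
        # missing when the allele is absent from LA.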

        lai = hl.fold(lambda accum, elt:
                        hl.cond(old_entry.LA[elt] == ds[new_id].a_index,
                                elt, accum),
                        hl.null(hl.tint32),
                        hl.range(0, hl.len(old_entry.LA)))
        return hl.bind(with_local_a_index, lai)
Example #11
def densify(sparse_mt):
    """Convert sparse MatrixTable to a dense one.

    Parameters
    ----------
    sparse_mt : :class:`.MatrixTable`
        Sparse MatrixTable to densify.  The first row key field must
        be named ``locus`` and have type ``locus``.  Must have an
        ``END`` entry field of type ``int32``.

    Returns
    -------
    :class:`.MatrixTable`
        The densified MatrixTable.  The ``END`` entry field is dropped.

    """
    if list(sparse_mt.row_key)[0] != 'locus' or not isinstance(sparse_mt.locus.dtype, hl.tlocus):
        raise ValueError("first row key field must be named 'locus' and have type 'locus'")
    if 'END' not in sparse_mt.entry or sparse_mt.END.dtype != hl.tint32:
        raise ValueError("'densify' requires 'END' entry field of type 'int32'")
    col_key_fields = list(sparse_mt.col_key)

    mt = sparse_mt
    mt = mt.annotate_entries(__contig=mt.locus.contig)
    t = mt._localize_entries('__entries', '__cols')
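    # For each column, scan down the rows carrying the most recent entry that
    # had a defined END; a missing entry is then filled with that reference
    # block whenever the block's END still covers this locus on the same contig.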
    t = t.annotate(
        __entries = hl.rbind(
            hl.scan.array_agg(
                lambda entry: hl.scan._prev_nonnull(hl.or_missing(hl.is_defined(entry.END), entry)),
                t.__entries),
            lambda prev_entries: hl.map(
                lambda i:
                hl.rbind(
                    prev_entries[i], t.__entries[i],
                    lambda prev_entry, entry:
                    hl.cond(
                        (~hl.is_defined(entry) &
                         (prev_entry.END >= t.locus.position) &
                         (prev_entry.__contig == t.locus.contig)),
                        prev_entry,
                        entry)),
                hl.range(0, hl.len(t.__entries)))))
    mt = t._unlocalize_entries('__entries', '__cols', col_key_fields)
    mt = mt.drop('__contig', 'END')
    return mt
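
A minimal usage sketch (the path is hypothetical; the input must have `locus`
as its first row key and an `END` entry field, e.g. output of the VCF combiner):

    import hail as hl

    sparse = hl.read_matrix_table('gs://bucket/combined_sparse.mt')  # hypothetical path
    dense = densify(sparse)  # reference blocks are filled forward; END is dropped
    dense.write('gs://bucket/dense.mt', overwrite=True)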
Example #12
 def test_computed_key_join_3(self):
     # duplicate row keys
     ds = self.get_vds()
     kt = hl.Table.parallelize(
         [{'culprit': 'InbreedingCoeff', 'foo': 'bar', 'value': 'IB'}],
         hl.tstruct(culprit=hl.tstr, foo=hl.tstr, value=hl.tstr),
         key=['culprit', 'foo'])
     ds = ds.annotate_rows(
         dsfoo='bar',
         info=ds.info.annotate(culprit=[ds.info.culprit, "foo"]))
     ds = ds.explode_rows(ds.info.culprit)
     ds = ds.annotate_rows(value=kt[ds.info.culprit, ds.dsfoo]['value'])
     rt = ds.rows()
     self.assertTrue(
         rt.all(hl.cond(
             rt.info.culprit == "InbreedingCoeff",
             rt['value'] == "IB",
             hl.is_missing(rt['value']))))
Example #13
 def test_agg_cols_group_by(self):
     t = hl.utils.range_matrix_table(1, 10)
     tests = [(agg.group_by(t.col_idx % 2,
                            hl.array(agg.collect_as_set(t.col_idx + 1)).append(0)),
               {0: [1, 3, 5, 7, 9, 0], 1: [2, 4, 6, 8, 10, 0]}),
              (agg.group_by(t.col_idx % 3,
                            agg.filter(t.col_idx > 7,
                                       hl.array(agg.collect_as_set(t.col_idx + 1)).append(0))),
               {0: [10, 0], 1: [0], 2: [9, 0]}),
              (agg.group_by(t.col_idx % 3,
                            agg.explode(lambda elt: agg.collect(elt + 1).append(0),
                                        hl.cond(t.col_idx > 7,
                                                [t.col_idx, t.col_idx + 1],
                                                hl.empty_array(hl.tint32)))),
              {0: [10, 11, 0], 1: [0], 2: [9, 10, 0]}),
              ]
     for aggregation, expected in tests:
         self.assertEqual(t.select_rows(result = aggregation).result.collect()[0], expected)
Example #14
    def call_to_one_hot_alleles_array(call: hl.expr.CallExpression, alleles: hl.expr.ArrayExpression) -> hl.expr.ArrayExpression:
        """
        Get the set of all different one-hot-encoded allele-vectors in a genotype call.
        It is returned as an ordered array where the first vector corresponds to the first allele,
        and the second vector (only present if het) to the second allele.

        :param CallExpression call: genotype
        :param ArrayExpression alleles: Alleles at the site
        :return: Array of one-hot-encoded alleles
        :rtype: ArrayExpression
        """
        return hl.cond(
            call.is_het(),
            hl.array([
                hl.call(call[0]).one_hot_alleles(alleles),
                hl.call(call[1]).one_hot_alleles(alleles),
            ]),
            hl.array([hl.call(call[0]).one_hot_alleles(alleles)])
        )
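
A worked check of the het branch at a biallelic site (`hl.call` and
`one_hot_alleles` are standard Hail APIs; the values are illustrative):

    import hail as hl

    alleles = hl.literal(['A', 'T'])
    call = hl.call(0, 1)  # heterozygous ref/alt
    # each haploid sub-call one-hot-encodes a single allele
    print(hl.eval(hl.array([hl.call(call[0]).one_hot_alleles(alleles),
                            hl.call(call[1]).one_hot_alleles(alleles)])))
    # [[1, 0], [0, 1]]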
Example #15
def transform_one(mt: MatrixTable) -> MatrixTable:
    """transforms a gvcf into a form suitable for combining"""
    mt = mt.annotate_entries(
        # local (alt) allele index into global (alt) alleles
        LA=hl.range(0, hl.len(mt.alleles) - 1),
        END=mt.info.END,
        PL=mt['PL'][0:],
        BaseQRankSum=mt.info['BaseQRankSum'],
        ClippingRankSum=mt.info['ClippingRankSum'],
        MQ=mt.info['MQ'],
        MQRankSum=mt.info['MQRankSum'],
        ReadPosRankSum=mt.info['ReadPosRankSum'],
    )
    # This collects all fields with median combiners into arrays so we can calculate medians
    # when needed
    mt = mt.annotate_rows(
        # now minrep'ed (ref, alt) allele pairs
        alleles=hl.bind(lambda ref: mt.alleles[1:].map(lambda alt:
                                                       # minrep <NON_REF>
                                                       hl.struct(ref=hl.cond(alt == "<NON_REF>",
                                                                             ref[0:1],
                                                                             ref),
                                                                 alt=alt)),
                        mt.alleles[0]),
        info=mt.info.annotate(
            SB=hl.agg.array_sum(mt.entry.SB)
        ).select(
            "DP",
            "MQ_DP",
            "QUALapprox",
            "RAW_MQ",
            "VarDP",
            "SB",
        ))
    mt = mt.drop('SB', 'qual')

    return mt
Example #16
            def with_pl(pl):
                new_exprs = {}
                dropped_fields = ['LA']
                if 'LGT' in fields:
                    new_exprs['GT'] = hl.downcode(old_entry.LGT, hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LGT')
                if 'LPGT' in fields:
                    new_exprs['PGT'] = hl.downcode(old_entry.LPGT, hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LPGT')
                if 'LAD' in fields:
                    new_exprs['AD'] = hl.or_missing(
                        hl.is_defined(old_entry.LAD),
                        [old_entry.LAD[0], hl.or_else(old_entry.LAD[local_a_index], 0)]) # second entry zeroed for lack of non-ref AD
                    dropped_fields.append('LAD')
                if 'LPL' in fields:
                    new_exprs['PL'] = pl
                    if 'GQ' in fields:
                        new_exprs['GQ'] = hl.or_else(hl.gq_from_pl(pl), old_entry.GQ)

                    dropped_fields.append('LPL')

                return hl.cond(hl.len(ds.alleles) == 1,
                                   old_entry.annotate(**{f[1:]: old_entry[f] for f in ['LGT', 'LPGT', 'LAD', 'LPL'] if f in fields}).drop(*dropped_fields),
                                   old_entry.annotate(**new_exprs).drop(*dropped_fields))
Example #17
File: qc.py Project: tpoterba/hail
def variant_qc(mt, name='variant_qc') -> MatrixTable:
    """Compute common variant statistics (quality control metrics).

    .. include:: ../_templates/req_tvariant.rst

    Examples
    --------

    >>> dataset_result = hl.variant_qc(dataset)

    Notes
    -----
    This method computes variant statistics from the genotype data, returning
    a new struct field `name` with the following metrics based on the fields
    present in the entry schema.

    If `mt` contains an entry field `DP` of type :py:data:`.tint32`, then the
    field `dp_stats` is computed. If `mt` contains an entry field `GQ` of type
    :py:data:`.tint32`, then the field `gq_stats` is computed. Both `dp_stats`
    and `gq_stats` are structs with four fields:

    - `mean` (``float64``) -- Mean value.
    - `stdev` (``float64``) -- Standard deviation (zero degrees of freedom).
    - `min` (``int32``) -- Minimum value.
    - `max` (``int32``) -- Maximum value.

    If the dataset does not contain an entry field `GT` of type
    :py:data:`.tcall`, then an error is raised. The following fields are always
    computed from `GT`:

    - `AF` (``array<float64>``) -- Calculated allele frequency, one element
      per allele, including the reference. Sums to one. Equivalent to
      `AC` / `AN`.
    - `AC` (``array<int32>``) -- Calculated allele count, one element per
      allele, including the reference. Sums to `AN`.
    - `AN` (``int32``) -- Total number of called alleles.
    - `homozygote_count` (``array<int32>``) -- Number of homozygotes per
      allele. One element per allele, including the reference.
    - `n_called` (``int64``) -- Number of samples with a defined `GT`.
    - `n_not_called` (``int64``) -- Number of samples with a missing `GT`.
    - `call_rate` (``float32``) -- Fraction of samples with a defined `GT`.
      Equivalent to `n_called` / :meth:`.count_cols`.
    - `n_het` (``int64``) -- Number of heterozygous samples.
    - `n_non_ref` (``int64``) -- Number of samples with at least one called
      non-reference allele.
    - `het_freq_hwe` (``float64``) -- Expected frequency of heterozygous
      samples under Hardy-Weinberg equilibrium. See
      :func:`.functions.hardy_weinberg_test` for details.
    - `p_value_hwe` (``float64``) -- p-value from test of Hardy-Weinberg equilibrium.
      See :func:`.functions.hardy_weinberg_test` for details.

    Warning
    -------
    `het_freq_hwe` and `p_value_hwe` are calculated as in
    :func:`.functions.hardy_weinberg_test`, with non-diploid calls
    (``ploidy != 2``) ignored in the counts. As this test is only
    statistically rigorous in the biallelic setting, :func:`.variant_qc`
    sets both fields to missing for multiallelic variants. Consider using
    :func:`~hail.methods.split_multi` to split multi-allelic variants beforehand.

    Parameters
    ----------
    mt : :class:`.MatrixTable`
        Dataset.
    name : :obj:`str`
        Name for resulting field.

    Returns
    -------
    :class:`.MatrixTable`
    """
    require_row_key_variant(mt, 'variant_qc')

    exprs = {}
    struct_exprs = []

    def has_field_of_type(name, dtype):
        return name in mt.entry and mt[name].dtype == dtype

    n_samples = mt.count_cols()

    if has_field_of_type('DP', hl.tint32):
        exprs['dp_stats'] = hl.agg.stats(mt.DP).select('mean', 'stdev', 'min', 'max')

    if has_field_of_type('GQ', hl.tint32):
        exprs['gq_stats'] = hl.agg.stats(mt.GQ).select('mean', 'stdev', 'min', 'max')

    if not has_field_of_type('GT', hl.tcall):
        raise ValueError("'variant_qc': expect an entry field 'GT' of type 'call'")
    exprs['n_called'] = hl.agg.count_where(hl.is_defined(mt['GT']))
    struct_exprs.append(hl.agg.call_stats(mt.GT, mt.alleles))


    # the structure of this function makes it easy to add new nested computations
    def flatten_struct(*struct_exprs):
        flat = {}
        for struct in struct_exprs:
            for k, v in struct.items():
                flat[k] = v
        return hl.struct(
            **flat,
            **exprs,
        )

    mt = mt.annotate_rows(**{name: hl.bind(flatten_struct, *struct_exprs)})
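
    # AC[1] counts copies of the alternate allele; each alt homozygote carries
    # two, so AC[1] - 2 * homozygote_count[1] is the heterozygote count. The
    # arguments below are therefore (n_hom_ref, n_het, n_hom_var).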

    hwe = hl.hardy_weinberg_test(mt[name].homozygote_count[0],
                                 mt[name].AC[1] - 2 * mt[name].homozygote_count[1],
                                 mt[name].homozygote_count[1])
    hwe = hwe.select(het_freq_hwe=hwe.het_freq_hwe, p_value_hwe=hwe.p_value)
    mt = mt.annotate_rows(**{name: mt[name].annotate(n_not_called=n_samples - mt[name].n_called,
                                                     call_rate=mt[name].n_called / n_samples,
                                                     n_het=mt[name].n_called - hl.sum(mt[name].homozygote_count),
                                                     n_non_ref=mt[name].n_called - mt[name].homozygote_count[0],
                                                     **hl.cond(hl.len(mt.alleles) == 2,
                                                               hwe,
                                                               hl.null(hwe.dtype)))})
    return mt
Example #18
def transmission_disequilibrium_test(dataset, pedigree) -> Table:
    r"""Performs the transmission disequilibrium test on trios.

    .. include:: ../_templates/req_tstring.rst

    .. include:: ../_templates/req_tvariant.rst

    .. include:: ../_templates/req_biallelic.rst

    Examples
    --------
    Compute TDT association statistics and show the first two results:
    
    >>> pedigree = hl.Pedigree.read('data/tdt_trios.fam')
    >>> tdt_table = hl.transmission_disequilibrium_test(tdt_dataset, pedigree)
    >>> tdt_table.show(2)  # doctest: +NOTEST
    +---------------+------------+-------+-------+----------+----------+
    | locus         | alleles    |     t |     u |   chi_sq |  p_value |
    +---------------+------------+-------+-------+----------+----------+
    | locus<GRCh37> | array<str> | int64 | int64 |  float64 |  float64 |
    +---------------+------------+-------+-------+----------+----------+
    | 1:246714629   | ["C","A"]  |     0 |     4 | 4.00e+00 | 4.55e-02 |
    | 2:167262169   | ["T","C"]  |    NA |    NA |       NA |       NA |
    +---------------+------------+-------+-------+----------+----------+

    Export variants with p-values below 0.001:

    >>> tdt_table = tdt_table.filter(tdt_table.p_value < 0.001)
    >>> tdt_table.export("output/tdt_results.tsv")

    Notes
    -----
    The
    `transmission disequilibrium test <https://en.wikipedia.org/wiki/Transmission_disequilibrium_test#The_case_of_trios:_one_affected_child_per_family>`__
    compares the number of times the alternate allele is transmitted (t) versus
    not transmitted (u) from a heterozygous parent to an affected child. The null
    hypothesis holds that each case is equally likely. The TDT statistic is given by

    .. math::

        (t - u)^2 \over (t + u)

    and asymptotically follows a chi-squared distribution with one degree of
    freedom under the null hypothesis.
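
    For the first row of the example output above, t = 0 and u = 4, so the
    statistic is (0 - 4)^2 / (0 + 4) = 4, and ``hl.pchisqtail(4.0, 1.0)``
    yields the displayed p-value of roughly 4.55e-02.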

    :func:`transmission_disequilibrium_test` only considers complete trios (two
    parents and a proband with defined sex) and only returns results for the
    autosome, as defined by :meth:`~hail.genetics.Locus.in_autosome`, and
    chromosome X. Transmissions and non-transmissions are counted only for the
    configurations of genotypes and copy state in the table below, in order to
    filter out Mendel errors and configurations where transmission is
    guaranteed. The copy state of a locus with respect to a trio is defined as
    follows:

    - Auto -- in autosome or in PAR of X or female child
    - HemiX -- in non-PAR of X and male child

    Here PAR is the `pseudoautosomal region
    <https://en.wikipedia.org/wiki/Pseudoautosomal_region>`__
    of X and Y defined by :class:`.ReferenceGenome`, which many variant callers
    map to chromosome X.

    +--------+--------+--------+------------+---+---+
    |  Kid   | Dad    | Mom    | Copy State | t | u |
    +========+========+========+============+===+===+
    | HomRef | Het    | Het    | Auto       | 0 | 2 |
    +--------+--------+--------+------------+---+---+
    | HomRef | HomRef | Het    | Auto       | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | HomRef | Het    | HomRef | Auto       | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | Het    | Het    | Het    | Auto       | 1 | 1 |
    +--------+--------+--------+------------+---+---+
    | Het    | HomRef | Het    | Auto       | 1 | 0 |
    +--------+--------+--------+------------+---+---+
    | Het    | Het    | HomRef | Auto       | 1 | 0 |
    +--------+--------+--------+------------+---+---+
    | Het    | HomVar | Het    | Auto       | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | Het    | Het    | HomVar | Auto       | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | HomVar | Het    | Het    | Auto       | 2 | 0 |
    +--------+--------+--------+------------+---+---+
    | HomVar | Het    | HomVar | Auto       | 1 | 0 |
    +--------+--------+--------+------------+---+---+
    | HomVar | HomVar | Het    | Auto       | 1 | 0 |
    +--------+--------+--------+------------+---+---+
    | HomRef | HomRef | Het    | HemiX      | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | HomRef | HomVar | Het    | HemiX      | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | HomVar | HomRef | Het    | HemiX      | 1 | 0 |
    +--------+--------+--------+------------+---+---+
    | HomVar | HomVar | Het    | HemiX      | 1 | 0 |
    +--------+--------+--------+------------+---+---+

    :func:`transmission_disequilibrium_test` produces a table with the following columns:

     - `locus` (:class:`.tlocus`) -- Locus.
     - `alleles` (:class:`.tarray` of :py:data:`.tstr`) -- Alleles.
     - `t` (:py:data:`.tint32`) -- Number of transmitted alternate alleles.
     - `u` (:py:data:`.tint32`) -- Number of untransmitted alternate alleles.
     - `chi_sq` (:py:data:`.tfloat64`) -- TDT statistic.
     - `p_value` (:py:data:`.tfloat64`) -- p-value.

    Parameters
    ----------
    dataset : :class:`.MatrixTable`
        Dataset.
    pedigree : :class:`~hail.genetics.Pedigree`
        Sample pedigree.

    Returns
    -------
    :class:`.Table`
        Table of TDT results.
    """

    dataset = require_biallelic(dataset, 'transmission_disequilibrium_test')
    dataset = dataset.annotate_rows(auto_or_x_par=dataset.locus.in_autosome()
                                    | dataset.locus.in_x_par())
    dataset = dataset.filter_rows(dataset.auto_or_x_par
                                  | dataset.locus.in_x_nonpar())

    hom_ref = 0
    het = 1
    hom_var = 2

    auto = 2
    hemi_x = 1

    #                     kid,     dad,     mom,   copy, t, u
    config_counts = [(hom_ref, het, het, auto, 0, 2),
                     (hom_ref, hom_ref, het, auto, 0, 1),
                     (hom_ref, het, hom_ref, auto, 0, 1),
                     (het, het, het, auto, 1, 1),
                     (het, hom_ref, het, auto, 1, 0),
                     (het, het, hom_ref, auto, 1, 0),
                     (het, hom_var, het, auto, 0, 1),
                     (het, het, hom_var, auto, 0, 1),
                     (hom_var, het, het, auto, 2, 0),
                     (hom_var, het, hom_var, auto, 1, 0),
                     (hom_var, hom_var, het, auto, 1, 0),
                     (hom_ref, hom_ref, het, hemi_x, 0, 1),
                     (hom_ref, hom_var, het, hemi_x, 0, 1),
                     (hom_var, hom_ref, het, hemi_x, 1, 0),
                     (hom_var, hom_var, het, hemi_x, 1, 0)]

    count_map = hl.literal({(c[0], c[1], c[2], c[3]): [c[4], c[5]]
                            for c in config_counts})

    tri = trio_matrix(dataset, pedigree, complete_trios=True)

    # this filter removes mendel error of het father in x_nonpar. It also avoids
    #   building and looking up config in common case that neither parent is het
    father_is_het = tri.father_entry.GT.is_het()
    parent_is_valid_het = ((father_is_het & tri.auto_or_x_par) |
                           (tri.mother_entry.GT.is_het() & ~father_is_het))

    copy_state = hl.cond(tri.auto_or_x_par | tri.is_female, 2, 1)

    config = (tri.proband_entry.GT.n_alt_alleles(),
              tri.father_entry.GT.n_alt_alleles(),
              tri.mother_entry.GT.n_alt_alleles(), copy_state)

    tri = tri.annotate_rows(counts=agg.filter(
        parent_is_valid_het, agg.array_sum(count_map.get(config))))

    tab = tri.rows().select('counts')
    tab = tab.transmute(t=tab.counts[0], u=tab.counts[1])
    tab = tab.annotate(chi_sq=((tab.t - tab.u)**2) / (tab.t + tab.u))
    tab = tab.annotate(p_value=hl.pchisqtail(tab.chi_sq, 1.0))

    return tab.cache()
Example #19
def create_binned_data_initial(ht: hl.Table, data: str, data_type: str,
                               n_bins: int) -> hl.Table:
    # Count variants for ranking
    count_expr = {
        x: hl.agg.filter(
            hl.is_defined(ht[x]),
            hl.agg.counter(
                hl.cond(hl.is_snp(ht.alleles[0], ht.alleles[1]), 'snv',
                        'indel')))
        for x in ht.row if x.endswith('rank')
    }
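    # These per-rank SNV/indel totals set the bin width below:
    # bin = ceil((rank + 1) / floor(total / n_bins)).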
    rank_variant_counts = ht.aggregate(hl.Struct(**count_expr))
    logger.info(
        f"Found the following variant counts:\n {pformat(rank_variant_counts)}"
    )
    ht_truth_data = hl.read_table(f"{lustre_dir}/variant_qc/truthset.ht")
    ht = ht.annotate_globals(rank_variant_counts=rank_variant_counts)
    ht = ht.annotate(
        **ht_truth_data[ht.key],
        # **fam_ht[ht.key],
        # **gnomad_ht[ht.key],
        # **denovo_ht[ht.key],
        # clinvar=hl.is_defined(clinvar_ht[ht.key]),
        indel_length=hl.abs(ht.alleles[0].length() - ht.alleles[1].length()),
        rank_bins=hl.array([
            hl.Struct(
                rank_id=rank_name,
                bin=hl.int(
                    hl.ceil(
                        hl.float(ht[rank_name] + 1) / hl.floor(
                            ht.globals.rank_variant_counts[rank_name][hl.cond(
                                hl.is_snp(ht.alleles[0], ht.alleles[1]), 'snv',
                                'indel')] / n_bins))))
            for rank_name in rank_variant_counts
        ]),
        # lcr=hl.is_defined(lcr_intervals[ht.locus])
    )

    ht = ht.explode(ht.rank_bins)
    ht = ht.transmute(rank_id=ht.rank_bins.rank_id, bin=ht.rank_bins.bin)
    ht = ht.filter(hl.is_defined(ht.bin))

    ht = ht.checkpoint(f'{tmp_dir}/gnomad_score_binning_tmp.ht',
                       overwrite=True)

    # Create binned data
    return (ht.group_by(
        rank_id=ht.rank_id,
        contig=ht.locus.contig,
        snv=hl.is_snp(ht.alleles[0], ht.alleles[1]),
        #bi_allelic=hl.is_defined(ht.biallelic_rank),
        singleton=ht.transmitted_singleton,
        trans_singletons=hl.is_defined(ht.singleton_rank),
        de_novo_high_quality=ht.de_novo_high_quality_rank,
        de_novo_medium_quality=hl.is_defined(ht.de_novo_medium_quality_rank),
        de_novo_synonymous=hl.is_defined(ht.de_novo_synonymous_rank),
        # release_adj=ht.ac > 0,
        bin=ht.bin
    )._set_buffer_size(20000).aggregate(
        min_score=hl.agg.min(ht.score),
        max_score=hl.agg.max(ht.score),
        n=hl.agg.count(),
        n_ins=hl.agg.count_where(hl.is_insertion(ht.alleles[0],
                                                 ht.alleles[1])),
        n_del=hl.agg.count_where(hl.is_deletion(ht.alleles[0], ht.alleles[1])),
        n_ti=hl.agg.count_where(hl.is_transition(ht.alleles[0],
                                                 ht.alleles[1])),
        n_tv=hl.agg.count_where(
            hl.is_transversion(ht.alleles[0], ht.alleles[1])),
        n_1bp_indel=hl.agg.count_where(ht.indel_length == 1),
        n_mod3bp_indel=hl.agg.count_where((ht.indel_length % 3) == 0),
        # n_clinvar=hl.agg.count_where(ht.clinvar),
        n_singleton=hl.agg.count_where(ht.transmitted_singleton),
        n_high_quality_de_novos=hl.agg.count_where(
            ht.de_novo_data.p_de_novo[0] > 0.99),
        n_medium_quality_de_novos=hl.agg.count_where(
            ht.de_novo_data.p_de_novo[0] > 0.5),
        n_high_confidence_de_novos=hl.agg.count_where(
            ht.de_novo_data.confidence[0] == 'HIGH'),
        n_de_novo=hl.agg.filter(
            ht.family_stats.unrelated_qc_callstats.AC[0][1] == 0,
            hl.agg.sum(ht.family_stats.mendel[0].errors)),
        n_high_quality_de_novos_synonymous=hl.agg.count_where(
            (ht.de_novo_data.p_de_novo[0] > 0.99)
            & (ht.consequence == "synonymous_variant")),
        # n_de_novo_no_lcr=hl.agg.filter(~ht.lcr & (
        #    ht.family_stats.unrelated_qc_callstats.AC[1] == 0), hl.agg.sum(ht.family_stats.mendel.errors)),
        n_de_novo_sites=hl.agg.filter(
            ht.family_stats.unrelated_qc_callstats.AC[0][1] == 0,
            hl.agg.count_where(ht.family_stats.mendel[0].errors > 0)),
        # n_de_novo_sites_no_lcr=hl.agg.filter(~ht.lcr & (
        #    ht.family_stats.unrelated_qc_callstats.AC[1] == 0), hl.agg.count_where(ht.family_stats.mendel.errors > 0)),
        n_trans_singletons_old=hl.agg.filter(
            (ht.ac_raw < 3) &
            (ht.family_stats.unrelated_qc_callstats.AC[0][1] == 1),
            hl.agg.sum(ht.family_stats.tdt[0].t)),
        n_trans_singletons=hl.agg.filter(ht.ac_raw == 2,
                                         hl.agg.sum(ht.fam.n_transmitted_raw)),
        n_untrans_singletons=hl.agg.filter(
            (ht.ac_raw < 3) & (ht.ac_qc_samples_raw == 1),
            hl.agg.sum(ht.fam.n_untransmitted_raw),
        ),
        n_trans_singletons_synonymous=hl.agg.count_where(
            ht.variant_transmitted_singletons > 0),
        n_trans_singletons_synonymous_original=hl.agg.filter(
            (ht.ac_raw < 3) & (ht.consequence == "synonymous_variant") &
            (ht.family_stats.unrelated_qc_callstats.AC[0][1] == 1),
            hl.agg.sum(ht.family_stats.tdt[0].t)),
        n_untrans_singletons_old=hl.agg.filter(
            (ht.ac_raw < 3) &
            (ht.family_stats.unrelated_qc_callstats.AC[0][1] == 1),
            hl.agg.sum(ht.family_stats.tdt[0].u)),
        n_untrans_singletons_synonymous=hl.agg.count_where(
            ht.variant_untransmitted_singletons > 0),
        n_untrans_singletons_synonymous_original=hl.agg.filter(
            (ht.ac_raw < 3) & (ht.consequence == "synonymous_variant") &
            (ht.family_stats.unrelated_qc_callstats.AC[0][1] == 1),
            hl.agg.sum(ht.family_stats.tdt[0].u)),
        n_train_trans_singletons=hl.agg.count_where(
            (ht.family_stats.unrelated_qc_callstats.AC[0][1] == 1)
            & (ht.family_stats.tdt[0].t == 1)),
        n_omni=hl.agg.count_where(ht.omni),
        n_mills=hl.agg.count_where(ht.mills),
        n_hapmap=hl.agg.count_where(ht.hapmap),
        n_kgp_high_conf_snvs=hl.agg.count_where(ht.kgp_phase1_hc),
        fail_hard_filters=hl.agg.count_where(ht.fail_hard_filters),
        # n_vqsr_pos_train=hl.agg.count_where(ht.vqsr_positive_train_site),
        # n_vqsr_neg_train=hl.agg.count_where(ht.vqsr_negative_train_site)
    ))
Example #20
    def test_annotate(self):
        schema = hl.tstruct(a=hl.tint32, b=hl.tint32, c=hl.tint32, d=hl.tint32, e=hl.tstr, f=hl.tarray(hl.tint32))

        rows = [{'a': 4, 'b': 1, 'c': 3, 'd': 5, 'e': "hello", 'f': [1, 2, 3]},
                {'a': 0, 'b': 5, 'c': 13, 'd': -1, 'e': "cat", 'f': []},
                {'a': 4, 'b': 2, 'c': 20, 'd': 3, 'e': "dog", 'f': [5, 6, 7]}]

        kt = hl.Table.parallelize(rows, schema)

        self.assertTrue(kt.annotate()._same(kt))

        result1 = convert_struct_to_dict(kt.annotate(foo=kt.a + 1,
                                                     foo2=kt.a).take(1)[0])

        self.assertDictEqual(result1, {'a': 4,
                                       'b': 1,
                                       'c': 3,
                                       'd': 5,
                                       'e': "hello",
                                       'f': [1, 2, 3],
                                       'foo': 5,
                                       'foo2': 4})

        result3 = convert_struct_to_dict(kt.annotate(
            x1=kt.f.map(lambda x: x * 2),
            x2=kt.f.map(lambda x: [x, x + 1]).flatmap(lambda x: x),
            x3=hl.min(kt.f),
            x4=hl.max(kt.f),
            x5=hl.sum(kt.f),
            x6=hl.product(kt.f),
            x7=kt.f.length(),
            x8=kt.f.filter(lambda x: x == 3),
            x9=kt.f[1:],
            x10=kt.f[:],
            x11=kt.f[1:2],
            x12=kt.f.map(lambda x: [x, x + 1]),
            x13=kt.f.map(lambda x: [[x, x + 1], [x + 2]]).flatmap(lambda x: x),
            x14=hl.cond(kt.a < kt.b, kt.c, hl.null(hl.tint32)),
            x15={1, 2, 3}
        ).take(1)[0])

        self.assertDictEqual(result3, {'a': 4,
                                       'b': 1,
                                       'c': 3,
                                       'd': 5,
                                       'e': "hello",
                                       'f': [1, 2, 3],
                                       'x1': [2, 4, 6], 'x2': [1, 2, 2, 3, 3, 4],
                                       'x3': 1, 'x4': 3, 'x5': 6, 'x6': 6, 'x7': 3, 'x8': [3],
                                       'x9': [2, 3], 'x10': [1, 2, 3], 'x11': [2],
                                       'x12': [[1, 2], [2, 3], [3, 4]],
                                       'x13': [[1, 2], [3], [2, 3], [4], [3, 4], [5]],
                                       'x14': None, 'x15': set([1, 2, 3])})
        kt.annotate(
            x1=kt.a + 5,
            x2=5 + kt.a,
            x3=kt.a + kt.b,
            x4=kt.a - 5,
            x5=5 - kt.a,
            x6=kt.a - kt.b,
            x7=kt.a * 5,
            x8=5 * kt.a,
            x9=kt.a * kt.b,
            x10=kt.a / 5,
            x11=5 / kt.a,
            x12=kt.a / kt.b,
            x13=-kt.a,
            x14=+kt.a,
            x15=kt.a == kt.b,
            x16=kt.a == 5,
            x17=5 == kt.a,
            x18=kt.a != kt.b,
            x19=kt.a != 5,
            x20=5 != kt.a,
            x21=kt.a > kt.b,
            x22=kt.a > 5,
            x23=5 > kt.a,
            x24=kt.a >= kt.b,
            x25=kt.a >= 5,
            x26=5 >= kt.a,
            x27=kt.a < kt.b,
            x28=kt.a < 5,
            x29=5 < kt.a,
            x30=kt.a <= kt.b,
            x31=kt.a <= 5,
            x32=5 <= kt.a,
            x33=(kt.a == 0) & (kt.b == 5),
            x34=(kt.a == 0) | (kt.b == 5),
            x35=False,
            x36=True
        )
Example #21
def dmg_case_control_filter(mt, chrm='0', start=0, end=0, out_dir=os.getcwd(), name=""):
	"""Filters a matrix table that has been annotated with ANNOVAR for a case-control test."""
	#convert chrm number to string if input is integer
	chrm = str(chrm)
	
	if name == "":
		name = "case_control_filtering_{}_{}-{}".format(chrm,start,end)
	
	child_dir = os.path.join(out_dir,"case_control_filtering",name)
	os.makedirs(child_dir)
	
	out_f = os.path.join(child_dir,"case_control_filtering_{}_{}-{}.out".format(chrm,start,end))
	with open(out_f, "w") as f:
		#Import matrix table already annotated with bravo and MetaSVM (In this case, using annovar)
		mt = hl.read_matrix_table(mt)
		#Make annotations easier for filtering. If no bravo frequency for variants, set bravo freq to 0
		mt = mt.annotate_rows(bravo=hl.cond(hl.is_defined(mt.info["bravo"][0]), hl.float64(mt.info["bravo"][0]), 0.0))
		mt = mt.annotate_rows(meta_svm_pred=mt.info.MetaSVM_pred[0])

		db_ht = hl.read_table('/gpfs/ycga/project/kahle/sp2349/datasets/dbNSFP/dbNSFp4.1a/dbNSFP4.1a_chr_all_GRCh37.ht')

		#Annotate mt with CADD16 scores. Original vcf was annotated with CADD13
		mt = mt.annotate_rows(cadd16=hl.cond(hl.is_defined(db_ht[mt.row_key]), db_ht[mt.row_key].CADD_phred_hg19, 0.0))

		#Subset probands for case_control

		# Step 1: Create a text file with the sample IDs you want to keep and import that text file as a hail table. 
		table = (hl.import_table('/gpfs/ycga/project/kahle/sp2349/moyamoya/case_control/cc_output/final_probands_for_caseControl.txt', impute=True).key_by('Sample'))

		#The IDs_keep.txt file has the following format (including headers)
		# Sample        should_retain
		# 1-00005       yes
		# 1-00187       yes
		# 1-00252       yes
		# 1-00386       yes
		# 1-00668       yes
		# etc etc

		# Annotate columns of the matrix table with the sample IDs to keep
		mt = mt.annotate_cols(should_retain = table[mt.s].should_retain)

		# Filter matrix table columns
		mt = mt.filter_cols(mt.should_retain == 'yes', keep=True)
		
		sample_count = mt.cols().count()
		print("Sample count: {}".format(sample_count),file=f)
		print("Total allele count: {}".format(sample_count*2),file=f)

		#Filter on Bravo frequency
		mt_filtered = mt.filter_rows(mt.bravo <= 0.0005)

		#Filter on DIAPH1 coordinates
		mt_filtered = mt_filtered.filter_rows((mt_filtered.locus >= hl.locus(chrm,start)) & (mt_filtered.locus <= hl.locus(chrm,end)))

		print("Unique variants post bravo, CADD, and MetaSVM: {}".format(mt_filtered.rows().count()),file=f)

		#Filter for exonic and splice-site variants only
		mt_filtered = mt_filtered.filter_rows((mt_filtered.vep.most_severe_consequence == "stop_gained") | 
											  (mt_filtered.vep.most_severe_consequence == "splice_acceptor_variant") | 
											  (mt_filtered.vep.most_severe_consequence =="splice_donor_variant") | 
											  (mt_filtered.vep.most_severe_consequence == "frameshift_variant") | 
											  (mt_filtered.vep.most_severe_consequence =="stop_lost") | 
											  (mt_filtered.vep.most_severe_consequence =="start_lost") | 
											  ((mt_filtered.vep.most_severe_consequence =='missense_variant') & (mt_filtered.cadd16 >=20)) |
											  ((mt_filtered.vep.most_severe_consequence =='missense_variant') & (mt_filtered.meta_svm_pred == "D")) |
											  ((mt_filtered.vep.most_severe_consequence =='protein_altering_variant') & (mt_filtered.cadd16 >=20)) |
											  ((mt_filtered.vep.most_severe_consequence =='protein_altering_variant') & (mt_filtered.meta_svm_pred == "D")))
		print("Variants: {} kept".format(mt_filtered.count()))

		mt_filtered.count()

		#Convert matrix table to table for easier dropping of homozygous reference samples
		mt_filtered = mt_filtered.key_cols_by()
		mt_filtered_table = mt_filtered.entries()

		#Filter samples with homozygous reference calls (WT)
		mt_filtered_table = mt_filtered_table.filter(mt_filtered_table.GT.is_hom_ref(), keep=False)
		#mt_filtered_table.show()

		#Filter samples on GQ >= 20 and DP >= 10
		mt_filtered_table = mt_filtered_table.filter(mt_filtered_table.DP > 9)
		mt_filtered_table = mt_filtered_table.filter(mt_filtered_table.GQ > 19)
		
		#Write the filtered entries table
		mt_filtered_table.write(os.path.join(out_dir,'Damaging_Cases_chr{}_{}-{}.ht'.format(chrm,start,end)),overwrite=True)
		
		print("Total Variants post filtering: {}".format(mt_filtered_table.count()),file=f)
		
		print("Total Cases (Alleles): {}".format(mt_filtered_table.aggregate(hl.agg.sum(mt_filtered_table.GT.n_alt_alleles()))),file=f)
		print("Total Homozygous Cases: {}".format(mt_filtered_table.aggregate(hl.agg.count_where(mt_filtered_table.GT.is_hom_var() == True))),file=f)
		print("Samples with variants: {}".format(mt_filtered_table.aggregate(hl.agg.counter(mt_filtered_table.s))),file=f)
	
		#Write to text file
		df = mt_filtered_table.to_pandas()
		df.to_csv(os.path.join(out_dir,'Damaging_Cases_chr{}_{}-{}_table.txt'.format(chrm,start,end)),sep="\t")
Example #22
def sparse_split_multi(sparse_mt):
    """Splits multiallelic variants on a sparse MatrixTable.

    Takes a dataset formatted like the output of :func:`.vcf_combiner`. The
    splitting will add `was_split` and `a_index` fields, as :func:`.split_multi`
    does. This function drops the `LA` (local alleles) field, as it re-computes
    entry fields based on the new, split global alleles.

    Variants are split thus:

    - A row with only one (reference) or two (reference and alternate) alleles
      is not split; it is annotated with `was_split=False` and `a_index=1`.

    - A row with multiple alternate alleles will be split, with one row for
      each alternate allele, and each row will contain two alleles: ref and alt.
      The reference and alternate allele will be minrepped using
      :func:`.min_rep`.

    The split multi logic handles the following entry fields:

        .. code-block:: text

          struct {
            LGT: call
            LAD: array<int32>
            DP: int32
            GQ: int32
            LPL: array<int32>
            RGQ: int32
            LPGT: call
            LA: array<int32>
            END: int32
          }

    All fields except for `LA` are optional, and only handled if they exist.

    - `LA` is used to find the corresponding local allele index for the desired
      global `a_index`, and then dropped from the resulting dataset. If `LA`
      does not contain the global `a_index`, the index for the `<NON_REF>`
      allele is used to process the entry fields.

    - `LGT` and `LPGT` are downcoded using the corresponding local `a_index`.
      They are renamed to `GT` and `PGT` respectively, as the resulting call is
      no longer local.

    - `LAD` is used to create an `AD` field consisting of the allele depths
      corresponding to the reference, global `a_index` allele, and `<NON_REF>`
      allele.

    - `DP` is preserved unchanged.

    - `GQ` is recalculated from the updated `PL`, if it exists, but otherwise
      preserved unchanged.

    - `PL` array elements are calculated from the minimum `LPL` value for all
      allele pairs that downcode to the desired one. (This logic is identical to
      the `PL` logic in :func:`.split_multi_hts`; if a row has an alternate
      allele but it is not present in `LA`, the `PL` field is set to missing.
      The `PL` for `ref/<NON_REF>` in that case can be drawn from `RGQ`.)

    - `RGQ` (the ref genotype quality) is preserved unchanged.

    - `END` is untouched.

    Notes
    -----
    This version of split-multi doesn't deal with duplicate loci (in which case
    the explode could possibly result in out-of-order rows); the actual
    split_multi function also doesn't handle that case.

    It also checks that min-repping will not change the locus and will error if
    it does. (I believe the VCF combiner checks that this holds true,
    currently.)

    Parameters
    ----------
    sparse_mt : :class:`.MatrixTable`
        Sparse MatrixTable to split.

    Returns
    -------
    :class:`.MatrixTable`
        The split MatrixTable in sparse format.

    """

    hl.methods.misc.require_row_key_variant(sparse_mt, "sparse_split_multi")

    entries = hl.utils.java.Env.get_uid()
    cols = hl.utils.java.Env.get_uid()
    ds = sparse_mt.localize_entries(entries, cols)
    new_id = hl.utils.java.Env.get_uid()

    def struct_from_min_rep(i):
        return hl.bind(lambda mr:
                       (hl.case()
                        .when(ds.locus == mr.locus,
                              hl.struct(
                                  locus=ds.locus,
                                  alleles=[mr.alleles[0], mr.alleles[1]],
                                  a_index=i,
                                  was_split=True))
                        .or_error("Found non-left-aligned variant in sparse_split_multi")),
                       hl.min_rep(ds.locus, [ds.alleles[0], ds.alleles[i]]))

    explode_structs = hl.cond(hl.len(ds.alleles) < 3,
                              [hl.struct(
                                  locus=ds.locus,
                                  alleles=ds.alleles,
                                  a_index=1,
                                  was_split=False)],
                              hl._sort_by(
                                  hl.range(1, hl.len(ds.alleles))
                                      .map(struct_from_min_rep),
                                  lambda l, r: hl._compare(l.alleles, r.alleles) < 0
                              ))

    ds = ds.annotate(**{new_id: explode_structs}).explode(new_id)

    def transform_entries(old_entry):
        def with_local_a_index(local_a_index):
            new_pl = hl.or_missing(
                hl.is_defined(old_entry.LPL),
                hl.or_missing(
                    hl.is_defined(local_a_index),
                    hl.range(0, 3).map(lambda i: hl.min(
                        hl.range(0, hl.triangle(hl.len(old_entry.LA)))
                            .filter(lambda j: hl.downcode(hl.unphased_diploid_gt_index_call(j), local_a_index) == hl.unphased_diploid_gt_index_call(i))
                            .map(lambda idx: old_entry.LPL[idx])))))
            fields = set(old_entry.keys())

            def with_pl(pl):
                new_exprs = {}
                dropped_fields = ['LA']
                if 'LGT' in fields:
                    new_exprs['GT'] = hl.downcode(old_entry.LGT, hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LGT')
                if 'LPGT' in fields:
                    new_exprs['PGT'] = hl.downcode(old_entry.LPGT, hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LPGT')
                if 'LAD' in fields:
                    new_exprs['AD'] = hl.or_missing(
                        hl.is_defined(old_entry.LAD),
                        [old_entry.LAD[0], hl.or_else(old_entry.LAD[local_a_index], 0)]) # second entry zeroed for lack of non-ref AD
                    dropped_fields.append('LAD')
                if 'LPL' in fields:
                    new_exprs['PL'] = pl
                    if 'GQ' in fields:
                        new_exprs['GQ'] = hl.or_else(hl.gq_from_pl(pl), old_entry.GQ)

                    dropped_fields.append('LPL')

                return hl.cond(hl.len(ds.alleles) == 1,
                               old_entry.annotate(**{f[1:]: old_entry[f] for f in ['LGT', 'LPGT', 'LAD', 'LPL'] if f in fields}).drop(*dropped_fields),
                               old_entry.annotate(**new_exprs).drop(*dropped_fields))

            if 'LPL' in fields:
                return hl.bind(with_pl, new_pl)
            else:
                return with_pl(None)

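        # linear scan of LA for the local index whose entry matches this row's
        # global a_index; stays missing if the allele is absent from LA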
        lai = hl.fold(lambda accum, elt:
                        hl.cond(old_entry.LA[elt] == ds[new_id].a_index,
                                elt, accum),
                        hl.null(hl.tint32),
                        hl.range(0, hl.len(old_entry.LA)))
        return hl.bind(with_local_a_index, lai)

    new_row = ds.row.annotate(**{
        'locus': ds[new_id].locus,
        'alleles': ds[new_id].alleles,
        'a_index': ds[new_id].a_index,
        'was_split': ds[new_id].was_split,
        entries: ds[entries].map(transform_entries)
    }).drop(new_id)

    ds = hl.Table(
        hl.ir.TableKeyBy(
            hl.ir.TableMapRows(
                hl.ir.TableKeyBy(ds._tir, ['locus']),
                new_row._ir),
            ['locus', 'alleles'],
            is_sorted=True))
    return ds._unlocalize_entries(entries, cols, list(sparse_mt.col_key.keys()))
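
# A minimal usage sketch: the read path below is hypothetical, and the input
# is assumed to be VCF-combiner output carrying the LGT/LAD/LPL/LA entry
# fields described above.
import hail as hl

sparse_mt = hl.read_matrix_table('data/sparse.mt')  # hypothetical path
split_mt = sparse_split_multi(sparse_mt)
# rows now carry a_index/was_split; entries carry global GT/AD/PL
split_mt.filter_rows(split_mt.was_split).rows().show(5)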
Exemplo n.º 23
0
 def get_allele_type(allele_idx):
     return hl.cond(allele_idx > 0, mt[variant_atypes][allele_idx - 1],
                    hl.null(hl.tint32))
Exemplo n.º 24
0
def ld_score_regression(weight_expr,
                        ld_score_expr,
                        chi_sq_exprs,
                        n_samples_exprs,
                        n_blocks=200,
                        two_step_threshold=30,
                        n_reference_panel_variants=None) -> Table:
    r"""Estimate SNP-heritability and level of confounding biases from
    GWAS summary statistics.

    Given one or more sets of genome-wide association study (GWAS) summary
    statistics, :func:`.ld_score_regression` estimates the heritability of a
    trait (or set of traits) and the level of confounding biases present in
    the underlying studies by regressing chi-squared statistics on LD scores,
    leveraging the model:

    .. math::

        \mathrm{E}[\chi_j^2] = 1 + Na + \frac{Nh_g^2}{M}l_j

    *  :math:`\mathrm{E}[\chi_j^2]` is the expected chi-squared statistic
       for variant :math:`j` resulting from a test of association between
       variant :math:`j` and a trait.
    *  :math:`l_j = \sum_{k} r_{jk}^2` is the LD score of variant
       :math:`j`, calculated as the sum of squared correlation coefficients
       between variant :math:`j` and nearby variants. See :func:`.ld_score`
       for further details.
    *  :math:`a` captures the contribution of confounding biases, such as
       cryptic relatedness and uncontrolled population structure, to the
       association test statistic.
    *  :math:`h_g^2` is the SNP-heritability, or the proportion of variation
       in the trait explained by the effects of variants included in the
       regression model above.
    *  :math:`M` is the number of variants used to estimate :math:`h_g^2`.
    *  :math:`N` is the number of samples in the underlying association study.
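
    As a quick numeric illustration of the model (all values below are made
    up): with :math:`N = 100{,}000` samples, :math:`M = 1{,}000{,}000`
    variants, :math:`h_g^2 = 0.5`, :math:`a = 10^{-6}`, and an LD score of
    :math:`l_j = 120`:

    .. code-block:: python

        N, M, h2, a, l_j = 100_000, 1_000_000, 0.5, 1e-6, 120.0
        expected_chi_sq = 1 + N * a + (N * h2 / M) * l_j
        # 1 + 0.1 + 6.0 == 7.1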

    For more details on the method implemented in this function, see:

    * `LD Score regression distinguishes confounding from polygenicity in genome-wide association studies (Bulik-Sullivan et al, 2015) <https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4495769/>`__

    Examples
    --------

    Run the method on a matrix table of summary statistics, where the rows
    are variants and the columns are different phenotypes:

    >>> mt_gwas = hl.read_matrix_table('data/ld_score_regression.sumstats.mt')
    >>> ht_results = hl.experimental.ld_score_regression(
    ...     weight_expr=mt_gwas['ld_score'],
    ...     ld_score_expr=mt_gwas['ld_score'],
    ...     chi_sq_exprs=mt_gwas['chi_squared'],
    ...     n_samples_exprs=mt_gwas['n'])


    Run the method on a table with summary statistics for a single
    phenotype:

    >>> ht_gwas = hl.read_table('data/ld_score_regression.sumstats.ht')
    >>> ht_results = hl.experimental.ld_score_regression(
    ...     weight_expr=ht_gwas['ld_score'],
    ...     ld_score_expr=ht_gwas['ld_score'],
    ...     chi_sq_exprs=ht_gwas['chi_squared_50_irnt'],
    ...     n_samples_exprs=ht_gwas['n_50_irnt'])

    Run the method on a table with summary statistics for multiple
    phenotypes:

    >>> ht_gwas = hl.read_table('data/ld_score_regression.sumstats.ht')
    >>> ht_results = hl.experimental.ld_score_regression(
    ...     weight_expr=ht_gwas['ld_score'],
    ...     ld_score_expr=ht_gwas['ld_score'],
    ...     chi_sq_exprs=[ht_gwas['chi_squared_50_irnt'],
    ...                        ht_gwas['chi_squared_20160']],
    ...     n_samples_exprs=[ht_gwas['n_50_irnt'],
    ...                      ht_gwas['n_20160']])

    Notes
    -----
    The ``exprs`` provided as arguments to :func:`.ld_score_regression`
    must all be from the same object, either a :class:`Table` or a
    :class:`MatrixTable`.

    **If the arguments originate from a table:**

    *  The table must be keyed by fields ``locus`` of type
       :class:`.tlocus` and ``alleles``, a :py:data:`.tarray` of
       :py:data:`.tstr` elements.
    *  ``weight_expr``, ``ld_score_expr``, ``chi_sq_exprs``, and
       ``n_samples_exprs`` must be row-indexed fields.
    *  The number of expressions passed to ``n_samples_exprs`` must be
       equal to one or the number of expressions passed to
       ``chi_sq_exprs``. If just one expression is passed to
       ``n_samples_exprs``, that sample size expression is assumed to
       apply to all sets of statistics passed to ``chi_sq_exprs``.
       Otherwise, the expressions passed to ``chi_sq_exprs`` and
       ``n_samples_exprs`` are matched by index.
    *  The ``phenotype`` field that keys the table returned by
       :func:`.ld_score_regression` will have generic :obj:`int` values
       ``0``, ``1``, etc. corresponding to the ``0th``, ``1st``, etc.
       expressions passed to the ``chi_sq_exprs`` argument.

    **If the arguments originate from a matrix table:**

    *  The dimensions of the matrix table must be variants
       (rows) by phenotypes (columns).
    *  The rows of the matrix table must be keyed by fields
       ``locus`` of type :class:`.tlocus` and ``alleles``,
       a :py:data:`.tarray` of :py:data:`.tstr` elements.
    *  The columns of the matrix table must be keyed by a field
       of type :py:data:`.tstr` that uniquely identifies phenotypes
       represented in the matrix table. The column key must be a single
       expression; compound keys are not accepted.
    *  ``weight_expr`` and ``ld_score_expr`` must be row-indexed
       fields.
    *  ``chi_sq_exprs`` must be a single entry-indexed field
       (not a list of fields).
    *  ``n_samples_exprs`` must be a single entry-indexed field
       (not a list of fields).
    *  The ``phenotype`` field that keys the table returned by
       :func:`.ld_score_regression` will have values corresponding to the
       column keys of the input matrix table.

    This function returns a :class:`Table` with one row per set of summary
    statistics passed to the ``chi_sq_exprs`` argument. The following
    row-indexed fields are included in the table:

    *  **phenotype** (:py:data:`.tstr`) -- The name of the phenotype. The
       returned table is keyed by this field. See the notes below for
       details on the possible values of this field.
    *  **mean_chi_sq** (:py:data:`.tfloat64`) -- The mean chi-squared
       test statistic for the given phenotype.
    *  **intercept** (`Struct`) -- Contains fields:

       -  **estimate** (:py:data:`.tfloat64`) -- A point estimate of the
          intercept :math:`1 + Na`.
       -  **standard_error**  (:py:data:`.tfloat64`) -- An estimate of
          the standard error of this point estimate.

    *  **snp_heritability** (`Struct`) -- Contains fields:

       -  **estimate** (:py:data:`.tfloat64`) -- A point estimate of the
          SNP-heritability :math:`h_g^2`.
       -  **standard_error** (:py:data:`.tfloat64`) -- An estimate of
          the standard error of this point estimate.
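
    A hedged sketch of reading these fields off the returned table (reusing
    the ``ht_results`` object from the examples above):

    .. code-block:: python

        for r in ht_results.collect():
            print(r.phenotype,
                  r.intercept.estimate, r.intercept.standard_error,
                  r.snp_heritability.estimate,
                  r.snp_heritability.standard_error)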

    Warning
    -------
    :func:`.ld_score_regression` considers only the rows for which both row
    fields ``weight_expr`` and ``ld_score_expr`` are defined. Rows with missing
    values in either field are removed prior to fitting the LD score
    regression model.

    Parameters
    ----------
    weight_expr : :class:`.Float64Expression`
                  Row-indexed expression for the LD scores used to derive
                  variant weights in the model.
    ld_score_expr : :class:`.Float64Expression`
                    Row-indexed expression for the LD scores used as covariates
                    in the model.
    chi_sq_exprs : :class:`.Float64Expression` or :obj:`list` of
                        :class:`.Float64Expression`
                        One or more row-indexed (if table) or entry-indexed
                        (if matrix table) expressions for chi-squared
                        statistics resulting from genome-wide association
                        studies.
    n_samples_exprs : :class:`.NumericExpression` or :obj:`list` of
                     :class:`.NumericExpression`
                     One or more row-indexed (if table) or entry-indexed
                     (if matrix table) expressions indicating the number of
                     samples used in the studies that generated the test
                     statistics supplied to ``chi_sq_exprs``.
    n_blocks : :obj:`int`
               The number of blocks used in the jackknife approach to
               estimating standard errors.
    two_step_threshold : :obj:`int`
                         Variants with chi-squared statistics greater than this
                         value are excluded in the first step of the two-step
                         procedure used to fit the model.
    n_reference_panel_variants : :obj:`int`, optional
                                 Number of variants used to estimate the
                                 SNP-heritability :math:`h_g^2`.

    Returns
    -------
    :class:`.Table`
        Table keyed by ``phenotype`` with intercept and heritability estimates
        for each phenotype passed to the function."""

    chi_sq_exprs = wrap_to_list(chi_sq_exprs)
    n_samples_exprs = wrap_to_list(n_samples_exprs)

    assert ((len(chi_sq_exprs) == len(n_samples_exprs)) or
            (len(n_samples_exprs) == 1))
    __k = 2  # number of covariates, including intercept

    ds = chi_sq_exprs[0]._indices.source

    analyze('ld_score_regression/weight_expr',
            weight_expr,
            ds._row_indices)
    analyze('ld_score_regression/ld_score_expr',
            ld_score_expr,
            ds._row_indices)

    # format input dataset
    if isinstance(ds, MatrixTable):
        if len(chi_sq_exprs) != 1:
            raise ValueError("""Only one chi_sq_expr allowed if originating
                from a matrix table.""")
        if len(n_samples_exprs) != 1:
            raise ValueError("""Only one n_samples_expr allowed if
                originating from a matrix table.""")

        col_key = list(ds.col_key)
        if len(col_key) != 1:
            raise ValueError("""Matrix table must be keyed by a single
                phenotype field.""")

        analyze('ld_score_regression/chi_squared_expr',
                chi_sq_exprs[0],
                ds._entry_indices)
        analyze('ld_score_regression/n_samples_expr',
                n_samples_exprs[0],
                ds._entry_indices)

        ds = ds._select_all(row_exprs={'__locus': ds.locus,
                                       '__alleles': ds.alleles,
                                       '__w_initial': weight_expr,
                                       '__w_initial_floor': hl.max(weight_expr,
                                                                   1.0),
                                       '__x': ld_score_expr,
                                       '__x_floor': hl.max(ld_score_expr,
                                                           1.0)},
                            row_key=['__locus', '__alleles'],
                            col_exprs={'__y_name': ds[col_key[0]]},
                            col_key=['__y_name'],
                            entry_exprs={'__y': chi_sq_exprs[0],
                                         '__n': n_samples_exprs[0]})
        ds = ds.annotate_entries(**{'__w': ds.__w_initial})

        ds = ds.filter_rows(hl.is_defined(ds.__locus) &
                            hl.is_defined(ds.__alleles) &
                            hl.is_defined(ds.__w_initial) &
                            hl.is_defined(ds.__x))

    else:
        assert isinstance(ds, Table)
        for y in chi_sq_exprs:
            analyze('ld_score_regression/chi_squared_expr', y, ds._row_indices)
        for n in n_samples_exprs:
            analyze('ld_score_regression/n_samples_expr', n, ds._row_indices)

        ys = ['__y{}'.format(i) for i, _ in enumerate(chi_sq_exprs)]
        ws = ['__w{}'.format(i) for i, _ in enumerate(chi_sq_exprs)]
        ns = ['__n{}'.format(i) for i, _ in enumerate(n_samples_exprs)]

        ds = ds.select(**dict(**{'__locus': ds.locus,
                                 '__alleles': ds.alleles,
                                 '__w_initial': weight_expr,
                                 '__x': ld_score_expr},
                              **{y: chi_sq_exprs[i]
                                 for i, y in enumerate(ys)},
                              **{w: weight_expr for w in ws},
                              **{n: n_samples_exprs[i]
                                 for i, n in enumerate(ns)}))
        ds = ds.key_by(ds.__locus, ds.__alleles)

        table_tmp_file = new_temp_file()
        ds.write(table_tmp_file)
        ds = hl.read_table(table_tmp_file)

        hts = [ds.select(**{'__w_initial': ds.__w_initial,
                            '__w_initial_floor': hl.max(ds.__w_initial,
                                                        1.0),
                            '__x': ds.__x,
                            '__x_floor': hl.max(ds.__x, 1.0),
                            '__y_name': i,
                            '__y': ds[ys[i]],
                            '__w': ds[ws[i]],
                            '__n': hl.int(ds[ns[i]])})
               for i, y in enumerate(ys)]

        mts = [ht.to_matrix_table(row_key=['__locus',
                                           '__alleles'],
                                  col_key=['__y_name'],
                                  row_fields=['__w_initial',
                                              '__w_initial_floor',
                                              '__x',
                                              '__x_floor'])
               for ht in hts]

        ds = mts[0]
        for i in range(1, len(ys)):
            ds = ds.union_cols(mts[i])

        ds = ds.filter_rows(hl.is_defined(ds.__locus) &
                            hl.is_defined(ds.__alleles) &
                            hl.is_defined(ds.__w_initial) &
                            hl.is_defined(ds.__x))

    mt_tmp_file1 = new_temp_file()
    ds.write(mt_tmp_file1)
    mt = hl.read_matrix_table(mt_tmp_file1)

    if not n_reference_panel_variants:
        M = mt.count_rows()
    else:
        M = n_reference_panel_variants

    # block variants for each phenotype
    n_phenotypes = mt.count_cols()

    mt = mt.annotate_entries(__in_step1=(hl.is_defined(mt.__y) &
                                         (mt.__y < two_step_threshold)),
                             __in_step2=hl.is_defined(mt.__y))

    mt = mt.annotate_cols(__col_idx=hl.int(hl.scan.count()),
                          __m_step1=hl.agg.count_where(mt.__in_step1),
                          __m_step2=hl.agg.count_where(mt.__in_step2))

    col_keys = list(mt.col_key)

    ht = mt.localize_entries(entries_array_field_name='__entries',
                             columns_array_field_name='__cols')

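    # assign each entry to one of n_blocks roughly equal-size jackknife blocks:
    # the scan below counts step-1 variants seen so far in each column, and the
    # separator positions partition those running counts into blocks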
    ht = ht.annotate(__entries=hl.rbind(
        hl.scan.array_agg(
            lambda entry: hl.scan.count_where(entry.__in_step1),
            ht.__entries),
        lambda step1_indices: hl.map(
            lambda i: hl.rbind(
                hl.int(hl.or_else(step1_indices[i], 0)),
                ht.__cols[i].__m_step1,
                ht.__entries[i],
                lambda step1_idx, m_step1, entry: hl.rbind(
                    hl.map(
                        lambda j: hl.int(hl.floor(j * (m_step1 / n_blocks))),
                        hl.range(0, n_blocks + 1)),
                    lambda step1_separators: hl.rbind(
                        hl.set(step1_separators).contains(step1_idx),
                        hl.sum(
                            hl.map(
                                lambda s1: step1_idx >= s1,
                                step1_separators)) - 1,
                        lambda is_separator, step1_block: entry.annotate(
                            __step1_block=step1_block,
                            __step2_block=hl.cond(~entry.__in_step1 & is_separator,
                                                  step1_block - 1,
                                                  step1_block))))),
            hl.range(0, hl.len(ht.__entries)))))

    mt = ht._unlocalize_entries('__entries', '__cols', col_keys)

    mt_tmp_file2 = new_temp_file()
    mt.write(mt_tmp_file2)
    mt = hl.read_matrix_table(mt_tmp_file2)
    
    # initial coefficient estimates
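    # the slope guess solves mean(y) = 1 + beta * mean(x), i.e. the model with
    # the intercept pinned at its theoretical null value of 1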
    mt = mt.annotate_cols(__initial_betas=[
        1.0, (hl.agg.mean(mt.__y) - 1.0) / hl.agg.mean(mt.__x)])
    mt = mt.annotate_cols(__step1_betas=mt.__initial_betas,
                          __step2_betas=mt.__initial_betas)

    # step 1 iteratively reweighted least squares
    for i in range(3):
        mt = mt.annotate_entries(__w=hl.cond(
            mt.__in_step1,
            1.0/(mt.__w_initial_floor * 2.0 * (mt.__step1_betas[0] +
                                               mt.__step1_betas[1] *
                                               mt.__x_floor)**2),
            0.0))
        mt = mt.annotate_cols(__step1_betas=hl.agg.filter(
            mt.__in_step1,
            hl.agg.linreg(y=mt.__y,
                          x=[1.0, mt.__x],
                          weight=mt.__w).beta))
        mt = mt.annotate_cols(__step1_h2=hl.max(hl.min(
            mt.__step1_betas[1] * M / hl.agg.mean(mt.__n), 1.0), 0.0))
        mt = mt.annotate_cols(__step1_betas=[
            mt.__step1_betas[0],
            mt.__step1_h2 * hl.agg.mean(mt.__n) / M])

    # step 1 block jackknife
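    # delete-one-block refits: element i excludes block i's variants; the
    # pseudo-values n*beta_hat - (n-1)*beta_(-i) computed next give
    # bias-corrected jackknife estimates of the betas' mean and variance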
    mt = mt.annotate_cols(__step1_block_betas=[
        hl.agg.filter((mt.__step1_block != i) & mt.__in_step1,
                      hl.agg.linreg(y=mt.__y,
                                    x=[1.0, mt.__x],
                                    weight=mt.__w).beta)
        for i in range(n_blocks)])

    mt = mt.annotate_cols(__step1_block_betas_bias_corrected=hl.map(
        lambda x: n_blocks * mt.__step1_betas - (n_blocks - 1) * x,
        mt.__step1_block_betas))

    mt = mt.annotate_cols(
        __step1_jackknife_mean=hl.map(
            lambda i: hl.mean(
                hl.map(lambda x: x[i],
                       mt.__step1_block_betas_bias_corrected)),
            hl.range(0, __k)),
        __step1_jackknife_variance=hl.map(
            lambda i: (hl.sum(
                hl.map(lambda x: x[i]**2,
                       mt.__step1_block_betas_bias_corrected)) -
                       hl.sum(
                hl.map(lambda x: x[i],
                       mt.__step1_block_betas_bias_corrected))**2 /
                       n_blocks) /
            (n_blocks - 1) / n_blocks,
            hl.range(0, __k)))

    # step 2 iteratively reweighted least squares
    for i in range(3):
        mt = mt.annotate_entries(__w=hl.cond(
            mt.__in_step2,
            1.0/(mt.__w_initial_floor *
                 2.0 * (mt.__step2_betas[0] +
                        mt.__step2_betas[1] *
                        mt.__x_floor)**2),
            0.0))
        mt = mt.annotate_cols(__step2_betas=[
            mt.__step1_betas[0],
            hl.agg.filter(mt.__in_step2,
                          hl.agg.linreg(y=mt.__y - mt.__step1_betas[0],
                                        x=[mt.__x],
                                        weight=mt.__w).beta[0])])
        mt = mt.annotate_cols(__step2_h2=hl.max(hl.min(
            mt.__step2_betas[1] * M/hl.agg.mean(mt.__n), 1.0), 0.0))
        mt = mt.annotate_cols(__step2_betas=[
            mt.__step1_betas[0],
            mt.__step2_h2 * hl.agg.mean(mt.__n)/M])

    # step 2 block jackknife
    mt = mt.annotate_cols(__step2_block_betas=[
        hl.agg.filter((mt.__step2_block != i) & mt.__in_step2,
                      hl.agg.linreg(y=mt.__y - mt.__step1_betas[0],
                                    x=[mt.__x],
                                    weight=mt.__w).beta[0])
        for i in range(n_blocks)])

    mt = mt.annotate_cols(__step2_block_betas_bias_corrected=hl.map(
        lambda x: n_blocks * mt.__step2_betas[1] - (n_blocks - 1) * x,
        mt.__step2_block_betas))

    mt = mt.annotate_cols(
        __step2_jackknife_mean=hl.mean(
            mt.__step2_block_betas_bias_corrected),
        __step2_jackknife_variance=(
            hl.sum(mt.__step2_block_betas_bias_corrected**2) -
            hl.sum(mt.__step2_block_betas_bias_corrected)**2 /
            n_blocks) / (n_blocks - 1) / n_blocks)

    # combine step 1 and step 2 block jackknifes
    mt = mt.annotate_entries(
        __step2_initial_w=1.0/(mt.__w_initial_floor *
                               2.0 * (mt.__initial_betas[0] +
                                      mt.__initial_betas[1] *
                                      mt.__x_floor)**2))

    mt = mt.annotate_cols(
        __final_betas=[
            mt.__step1_betas[0],
            mt.__step2_betas[1]],
        __c=(hl.agg.sum(mt.__step2_initial_w * mt.__x) /
             hl.agg.sum(mt.__step2_initial_w * mt.__x**2)))

    mt = mt.annotate_cols(__final_block_betas=hl.map(
        lambda i: (mt.__step2_block_betas[i] - mt.__c *
                   (mt.__step1_block_betas[i][0] - mt.__final_betas[0])),
        hl.range(0, n_blocks)))

    mt = mt.annotate_cols(
        __final_block_betas_bias_corrected=(n_blocks * mt.__final_betas[1] -
                                            (n_blocks - 1) *
                                            mt.__final_block_betas))

    mt = mt.annotate_cols(
        __final_jackknife_mean=[
            mt.__step1_jackknife_mean[0],
            hl.mean(mt.__final_block_betas_bias_corrected)],
        __final_jackknife_variance=[
            mt.__step1_jackknife_variance[0],
            (hl.sum(mt.__final_block_betas_bias_corrected**2) -
             hl.sum(mt.__final_block_betas_bias_corrected)**2 /
             n_blocks) / (n_blocks - 1) / n_blocks])

    # convert coefficient to heritability estimate
    mt = mt.annotate_cols(
        phenotype=mt.__y_name,
        mean_chi_sq=hl.agg.mean(mt.__y),
        intercept=hl.struct(
            estimate=mt.__final_betas[0],
            standard_error=hl.sqrt(mt.__final_jackknife_variance[0])),
        snp_heritability=hl.struct(
            estimate=(M/hl.agg.mean(mt.__n)) * mt.__final_betas[1],
            standard_error=hl.sqrt((M/hl.agg.mean(mt.__n))**2 *
                                   mt.__final_jackknife_variance[1])))

    # format and return results
    ht = mt.cols()
    ht = ht.key_by(ht.phenotype)
    ht = ht.select(ht.mean_chi_sq,
                   ht.intercept,
                   ht.snp_heritability)

    ht_tmp_file = new_temp_file()
    ht.write(ht_tmp_file)
    ht = hl.read_table(ht_tmp_file)
    
    return ht
Exemplo n.º 25
0
    def test_annotate(self):
        schema = hl.tstruct(a=hl.tint32,
                            b=hl.tint32,
                            c=hl.tint32,
                            d=hl.tint32,
                            e=hl.tstr,
                            f=hl.tarray(hl.tint32))

        rows = [{
            'a': 4,
            'b': 1,
            'c': 3,
            'd': 5,
            'e': "hello",
            'f': [1, 2, 3]
        }, {
            'a': 0,
            'b': 5,
            'c': 13,
            'd': -1,
            'e': "cat",
            'f': []
        }, {
            'a': 4,
            'b': 2,
            'c': 20,
            'd': 3,
            'e': "dog",
            'f': [5, 6, 7]
        }]

        kt = hl.Table.parallelize(rows, schema)

        self.assertTrue(kt.annotate()._same(kt))

        result1 = convert_struct_to_dict(
            kt.annotate(foo=kt.a + 1, foo2=kt.a).take(1)[0])

        self.assertDictEqual(
            result1, {
                'a': 4,
                'b': 1,
                'c': 3,
                'd': 5,
                'e': "hello",
                'f': [1, 2, 3],
                'foo': 5,
                'foo2': 4
            })

        result3 = convert_struct_to_dict(
            kt.annotate(x1=kt.f.map(lambda x: x * 2),
                        x2=kt.f.map(lambda x: [x, x + 1]).flatmap(lambda x: x),
                        x3=hl.min(kt.f),
                        x4=hl.max(kt.f),
                        x5=hl.sum(kt.f),
                        x6=hl.product(kt.f),
                        x7=kt.f.length(),
                        x8=kt.f.filter(lambda x: x == 3),
                        x9=kt.f[1:],
                        x10=kt.f[:],
                        x11=kt.f[1:2],
                        x12=kt.f.map(lambda x: [x, x + 1]),
                        x13=kt.f.map(lambda x: [[x, x + 1], [x + 2]]).flatmap(
                            lambda x: x),
                        x14=hl.cond(kt.a < kt.b, kt.c, hl.null(hl.tint32)),
                        x15={1, 2, 3}).take(1)[0])

        self.assertDictEqual(
            result3, {
                'a': 4,
                'b': 1,
                'c': 3,
                'd': 5,
                'e': "hello",
                'f': [1, 2, 3],
                'x1': [2, 4, 6],
                'x2': [1, 2, 2, 3, 3, 4],
                'x3': 1,
                'x4': 3,
                'x5': 6,
                'x6': 6,
                'x7': 3,
                'x8': [3],
                'x9': [2, 3],
                'x10': [1, 2, 3],
                'x11': [2],
                'x12': [[1, 2], [2, 3], [3, 4]],
                'x13': [[1, 2], [3], [2, 3], [4], [3, 4], [5]],
                'x14': None,
                'x15': set([1, 2, 3])
            })
        kt.annotate(x1=kt.a + 5,
                    x2=5 + kt.a,
                    x3=kt.a + kt.b,
                    x4=kt.a - 5,
                    x5=5 - kt.a,
                    x6=kt.a - kt.b,
                    x7=kt.a * 5,
                    x8=5 * kt.a,
                    x9=kt.a * kt.b,
                    x10=kt.a / 5,
                    x11=5 / kt.a,
                    x12=kt.a / kt.b,
                    x13=-kt.a,
                    x14=+kt.a,
                    x15=kt.a == kt.b,
                    x16=kt.a == 5,
                    x17=5 == kt.a,
                    x18=kt.a != kt.b,
                    x19=kt.a != 5,
                    x20=5 != kt.a,
                    x21=kt.a > kt.b,
                    x22=kt.a > 5,
                    x23=5 > kt.a,
                    x24=kt.a >= kt.b,
                    x25=kt.a >= 5,
                    x26=5 >= kt.a,
                    x27=kt.a < kt.b,
                    x28=kt.a < 5,
                    x29=5 < kt.a,
                    x30=kt.a <= kt.b,
                    x31=kt.a <= 5,
                    x32=5 <= kt.a,
                    x33=(kt.a == 0) & (kt.b == 5),
                    x34=(kt.a == 0) | (kt.b == 5),
                    x35=False,
                    x36=True)
Exemplo n.º 26
0
def main():
    args = parse_args()

    tables = []
    for i, path in enumerate(args.paths):

        ht = import_SJ_out_tab(path)
        ht = ht.key_by("chrom", "start_1based", "end_1based")

        if args.normalize_read_counts:
            ht = ht.annotate_globals(
                unique_reads_in_sample=ht.aggregate(hl.agg.sum(
                    ht.unique_reads)),
                multi_mapped_reads_in_sample=ht.aggregate(
                    hl.agg.sum(ht.multi_mapped_reads)),
            )

        # add 'interval' column
        #ht = ht.annotate(interval=hl.interval(
        #    hl.locus(ht.chrom, ht.start_1based, reference_genome=reference_genome),
        #    hl.locus(ht.chrom, ht.end_1based, reference_genome=reference_genome),))

        tables.append(ht)

    ## use zip-join
    combined_ht2 = hl.Table.multi_way_zip_join(tables, "data", "globals")
    #combined_ht2 = combined_ht2.annotate(
    #    strand=hl.agg.array_agg(lambda elem: hl.agg.max(elem.strand), combined_ht2.data))
    print(combined_ht2.describe())
    combined_ht2.export("SJ.out.combined2.tab", header=True)

    # NOTE: this early return makes the code below unreachable; the pairwise
    # join implementation that follows appears to be kept for reference only.
    return

    # compute mean
    if args.normalize_read_counts:
        mean_unique_reads_in_sample = sum(
            [hl.eval(ht.unique_reads_in_sample)
             for ht in tables]) / float(len(tables))
        mean_multi_mapped_reads_in_sample = sum(
            [hl.eval(ht.multi_mapped_reads_in_sample)
             for ht in tables]) / float(len(tables))
        print(
            f"mean_unique_reads_in_sample: {mean_unique_reads_in_sample:01f}, mean_multi_mapped_reads_in_sample: {mean_multi_mapped_reads_in_sample:01f}"
        )

    combined_ht = None
    for i, ht in enumerate(tables):
        print(f"Processing table #{i} out of {len(tables)}")

        if args.normalize_read_counts:
            unique_reads_multiplier = mean_unique_reads_in_sample / float(
                hl.eval(ht.unique_reads_in_sample))
            multi_mapped_reads_multiplier = mean_multi_mapped_reads_in_sample / float(
                hl.eval(ht.multi_mapped_reads_in_sample))
            print(
                f"unique_reads_multiplier: {unique_reads_multiplier:01f}, multi_mapped_reads_multiplier: {multi_mapped_reads_multiplier:01f}"
            )

        ht = ht.annotate(
            strand_counter=hl.or_else(
                hl.switch(ht.strand).when(1, 1).when(2, -1).or_missing(), 0),
            num_samples_with_this_junction=1,
        )

        if args.normalize_read_counts:
            ht = ht.annotate(
                unique_reads=hl.int32(ht.unique_reads *
                                      unique_reads_multiplier),
                multi_mapped_reads=hl.int32(ht.multi_mapped_reads *
                                            multi_mapped_reads_multiplier),
            )

        if combined_ht is None:
            combined_ht = ht
            continue

        print("----")
        print_stats(path, ht)

        combined_ht = combined_ht.join(ht, how="outer")
        combined_ht = combined_ht.transmute(
            strand=hl.or_else(
                combined_ht.strand, combined_ht.strand_1
            ),  ## in rare cases, the strand for the same junction may differ across samples, so use a 2-step process that assigns strand based on majority of samples
            strand_counter=hl.sum([
                combined_ht.strand_counter, combined_ht.strand_counter_1
            ]),  # samples vote on whether strand = 1 (eg. '+') or 2 (eg. '-')
            intron_motif=hl.or_else(combined_ht.intron_motif,
                                    combined_ht.intron_motif_1
                                    ),  ## double-check that left == right?
            known_splice_junction=hl.or_else(
                hl.cond((combined_ht.known_splice_junction == 1) |
                        (combined_ht.known_splice_junction_1 == 1), 1, 0),
                0),  ## double-check that left == right?
            unique_reads=hl.sum(
                [combined_ht.unique_reads, combined_ht.unique_reads_1]),
            multi_mapped_reads=hl.sum([
                combined_ht.multi_mapped_reads,
                combined_ht.multi_mapped_reads_1
            ]),
            maximum_overhang=hl.max(
                [combined_ht.maximum_overhang,
                 combined_ht.maximum_overhang_1]),
            num_samples_with_this_junction=hl.sum([
                combined_ht.num_samples_with_this_junction,
                combined_ht.num_samples_with_this_junction_1
            ]),
        )

        combined_ht = combined_ht.checkpoint(
            f"checkpoint{i % 2}.ht", overwrite=True)  #, _read_if_exists=True)

    total_junctions_count = combined_ht.count()
    strand_conflicts_count = combined_ht.filter(
        hl.abs(combined_ht.strand_counter) /
        hl.float(combined_ht.num_samples_with_this_junction) < 0.1,
        keep=True).count()

    # set final strand value to 1 (e.g. '+'), 2 (e.g. '-'), or 0 (unknown) based on the setting in the majority of samples
    combined_ht = combined_ht.annotate(
        strand=hl.case().when(combined_ht.strand_counter > 0, 1).when(
            combined_ht.strand_counter < 0, 2).default(0))

    combined_ht = combined_ht.annotate_globals(combined_tables=args.paths,
                                               n_combined_tables=len(
                                                   args.paths))

    if strand_conflicts_count:
        print(
            f"WARNING: Found {strand_conflicts_count} strand_conflicts out of {total_junctions_count} total_junctions"
        )

    # write as HT
    combined_ht = combined_ht.checkpoint(
        "combined.SJ.out.ht", overwrite=True)  #, _read_if_exists=True)

    ## write as tsv
    output_prefix = f"combined.{len(tables)}_samples{'.normalized_counts' if args.normalize_read_counts else ''}"
    combined_ht = combined_ht.key_by()
    combined_ht.export(f"{output_prefix}.with_header.combined.SJ.out.tab",
                       header=True)
    combined_ht = combined_ht.select(
        "chrom",
        "start_1based",
        "end_1based",
        "strand",
        "intron_motif",
        "known_splice_junction",
        "unique_reads",
        "multi_mapped_reads",
        "maximum_overhang",
    )
    combined_ht.export(f"{output_prefix}.SJ.out.tab", header=False)

    print(
        f"unique_reads_in combined table: {combined_ht.aggregate(hl.agg.sum(combined_ht.unique_reads))}"
    )
Exemplo n.º 27
0
 def field_to_array(ds, field):
     return hl.cond(ds[field] != 0, hl.array([field]), hl.empty_array(hl.tstr))
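 # A hedged usage sketch (the table and the 0/1 flag fields are hypothetical):
 # collect the names of nonzero flag fields into one array per row.
 import hail as hl

 ht = hl.utils.range_table(3)
 ht = ht.annotate(filter_a=ht.idx % 2, filter_b=1)
 ht = ht.annotate(flags=field_to_array(ht, 'filter_a')
                        .extend(field_to_array(ht, 'filter_b')))
 ht.show()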
Exemplo n.º 28
0
def transmission_disequilibrium_test(dataset, pedigree) -> Table:
    r"""Performs the transmission disequilibrium test on trios.

    .. include:: ../_templates/req_tstring.rst

    .. include:: ../_templates/req_tvariant.rst

    .. include:: ../_templates/req_biallelic.rst

    Examples
    --------
    Compute TDT association statistics and show the first two results:
    
    >>> pedigree = hl.Pedigree.read('data/tdt_trios.fam')
    >>> tdt_table = hl.transmission_disequilibrium_test(tdt_dataset, pedigree)
    >>> tdt_table.show(2)  # doctest: +NOTEST
    +---------------+------------+-------+-------+----------+----------+
    | locus         | alleles    |     t |     u |   chi_sq |  p_value |
    +---------------+------------+-------+-------+----------+----------+
    | locus<GRCh37> | array<str> | int64 | int64 |  float64 |  float64 |
    +---------------+------------+-------+-------+----------+----------+
    | 1:246714629   | ["C","A"]  |     0 |     4 | 4.00e+00 | 4.55e-02 |
    | 2:167262169   | ["T","C"]  |    NA |    NA |       NA |       NA |
    +---------------+------------+-------+-------+----------+----------+

    Export variants with p-values below 0.001:

    >>> tdt_table = tdt_table.filter(tdt_table.p_value < 0.001)
    >>> tdt_table.export("output/tdt_results.tsv")

    Notes
    -----
    The
    `transmission disequilibrium test <https://en.wikipedia.org/wiki/Transmission_disequilibrium_test#The_case_of_trios:_one_affected_child_per_family>`__
    compares the number of times the alternate allele is transmitted (t) versus
    not transmitted (u) from a heterozygous parent to an affected child. The null
    hypothesis holds that each case is equally likely. The TDT statistic is given by

    .. math::

        (t - u)^2 \over (t + u)

    and asymptotically follows a chi-squared distribution with one degree of
    freedom under the null hypothesis.
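
    For instance, the first variant in the example output above has
    :math:`t = 0` and :math:`u = 4`, giving a statistic of
    :math:`(0 - 4)^2 / (0 + 4) = 4`; the displayed p-value is the upper tail
    of the corresponding chi-squared distribution:

    .. code-block:: python

        chi_sq = (0 - 4) ** 2 / (0 + 4)        # 4.0
        p = hl.eval(hl.pchisqtail(chi_sq, 1))  # ~0.0455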

    :func:`transmission_disequilibrium_test` only considers complete trios (two
    parents and a proband with defined sex) and only returns results for the
    autosome, as defined by :meth:`~hail.genetics.Locus.in_autosome`, and
    chromosome X. Transmissions and non-transmissions are counted only for the
    configurations of genotypes and copy state in the table below, in order to
    filter out Mendel errors and configurations where transmission is
    guaranteed. The copy state of a locus with respect to a trio is defined as
    follows:

    - Auto -- in autosome or in PAR of X or female child
    - HemiX -- in non-PAR of X and male child

    Here PAR is the `pseudoautosomal region
    <https://en.wikipedia.org/wiki/Pseudoautosomal_region>`__
    of X and Y defined by :class:`.ReferenceGenome`, which many variant callers
    map to chromosome X.

    +--------+--------+--------+------------+---+---+
    |  Kid   | Dad    | Mom    | Copy State | t | u |
    +========+========+========+============+===+===+
    | HomRef | Het    | Het    | Auto       | 0 | 2 |
    +--------+--------+--------+------------+---+---+
    | HomRef | HomRef | Het    | Auto       | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | HomRef | Het    | HomRef | Auto       | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | Het    | Het    | Het    | Auto       | 1 | 1 |
    +--------+--------+--------+------------+---+---+
    | Het    | HomRef | Het    | Auto       | 1 | 0 |
    +--------+--------+--------+------------+---+---+
    | Het    | Het    | HomRef | Auto       | 1 | 0 |
    +--------+--------+--------+------------+---+---+
    | Het    | HomVar | Het    | Auto       | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | Het    | Het    | HomVar | Auto       | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | HomVar | Het    | Het    | Auto       | 2 | 0 |
    +--------+--------+--------+------------+---+---+
    | HomVar | Het    | HomVar | Auto       | 1 | 0 |
    +--------+--------+--------+------------+---+---+
    | HomVar | HomVar | Het    | Auto       | 1 | 0 |
    +--------+--------+--------+------------+---+---+
    | HomRef | HomRef | Het    | HemiX      | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | HomRef | HomVar | Het    | HemiX      | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | HomVar | HomRef | Het    | HemiX      | 1 | 0 |
    +--------+--------+--------+------------+---+---+
    | HomVar | HomVar | Het    | HemiX      | 1 | 0 |
    +--------+--------+--------+------------+---+---+

    :func:`.transmission_disequilibrium_test` produces a table with the following columns:

     - `locus` (:class:`.tlocus`) -- Locus.
     - `alleles` (:class:`.tarray` of :py:data:`.tstr`) -- Alleles.
     - `t` (:py:data:`.tint32`) -- Number of transmitted alternate alleles.
     - `u` (:py:data:`.tint32`) -- Number of untransmitted alternate alleles.
     - `chi_sq` (:py:data:`.tfloat64`) -- TDT statistic.
     - `p_value` (:py:data:`.tfloat64`) -- p-value.

    Parameters
    ----------
    dataset : :class:`.MatrixTable`
        Dataset.
    pedigree : :class:`~hail.genetics.Pedigree`
        Sample pedigree.

    Returns
    -------
    :class:`.Table`
        Table of TDT results.
    """

    dataset = require_biallelic(dataset, 'transmission_disequilibrium_test')
    dataset = dataset.annotate_rows(auto_or_x_par = dataset.locus.in_autosome() | dataset.locus.in_x_par())
    dataset = dataset.filter_rows(dataset.auto_or_x_par | dataset.locus.in_x_nonpar())

    hom_ref = 0
    het = 1
    hom_var = 2

    auto = 2
    hemi_x = 1

    #                     kid,     dad,     mom,   copy, t, u
    config_counts = [(hom_ref,     het,     het,   auto, 0, 2),
                     (hom_ref, hom_ref,     het,   auto, 0, 1),
                     (hom_ref,     het, hom_ref,   auto, 0, 1),
                     (    het,     het,     het,   auto, 1, 1),
                     (    het, hom_ref,     het,   auto, 1, 0),
                     (    het,     het, hom_ref,   auto, 1, 0),
                     (    het, hom_var,     het,   auto, 0, 1),
                     (    het,     het, hom_var,   auto, 0, 1),
                     (hom_var,     het,     het,   auto, 2, 0),
                     (hom_var,     het, hom_var,   auto, 1, 0),
                     (hom_var, hom_var,     het,   auto, 1, 0),
                     (hom_ref, hom_ref,     het, hemi_x, 0, 1),
                     (hom_ref, hom_var,     het, hemi_x, 0, 1),
                     (hom_var, hom_ref,     het, hemi_x, 1, 0),
                     (hom_var, hom_var,     het, hemi_x, 1, 0)]

    count_map = hl.literal({(c[0], c[1], c[2], c[3]): [c[4], c[5]] for c in config_counts})
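    # lookup from (kid, dad, mom, copy state) to [t, u]; e.g. the
    # (hom_var, het, het, auto) config contributes two transmissions: [2, 0]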

    tri = trio_matrix(dataset, pedigree, complete_trios=True)

    # this filter removes Mendel errors of a het father in the X non-PAR. It also
    #   avoids building and looking up the config in the common case that neither parent is het
    father_is_het = tri.father_entry.GT.is_het()
    parent_is_valid_het = ((father_is_het & tri.auto_or_x_par) |
                           (tri.mother_entry.GT.is_het() & ~father_is_het))

    copy_state = hl.cond(tri.auto_or_x_par | tri.is_female, 2, 1)

    config = (tri.proband_entry.GT.n_alt_alleles(),
              tri.father_entry.GT.n_alt_alleles(),
              tri.mother_entry.GT.n_alt_alleles(),
              copy_state)

    tri = tri.annotate_rows(counts = agg.filter(parent_is_valid_het, agg.array_sum(count_map.get(config))))

    tab = tri.rows().select('counts')
    tab = tab.transmute(t = tab.counts[0], u = tab.counts[1])
    tab = tab.annotate(chi_sq = ((tab.t - tab.u) ** 2) / (tab.t + tab.u))
    tab = tab.annotate(p_value = hl.pchisqtail(tab.chi_sq, 1.0))

    return tab.cache()
Exemplo n.º 29
0
    def test_null_joins_2(self):
        tr = hl.utils.range_table(7, 1)
        table1 = tr.key_by(new_key=hl.cond((tr.idx == 3) | (tr.idx == 5),
                                           hl.null(hl.tint32), tr.idx),
                           key2=tr.idx)
        table1 = table1.select(idx1=table1.idx)
        table2 = tr.key_by(new_key=hl.cond((tr.idx == 4) | (tr.idx == 6),
                                           hl.null(hl.tint32), tr.idx),
                           key2=tr.idx)
        table2 = table2.select(idx2=table2.idx)

        left_join = table1.join(table2, 'left')
        right_join = table1.join(table2, 'right')
        inner_join = table1.join(table2, 'inner')
        outer_join = table1.join(table2, 'outer')

        def row(new_key, key2, idx1, idx2):
            return hl.Struct(new_key=new_key, key2=key2, idx1=idx1, idx2=idx2)

        left_join_expected = [
            row(0, 0, 0, 0),
            row(1, 1, 1, 1),
            row(2, 2, 2, 2),
            row(4, 4, 4, None),
            row(6, 6, 6, None),
            row(None, 3, 3, None),
            row(None, 5, 5, None)
        ]

        right_join_expected = [
            row(0, 0, 0, 0),
            row(1, 1, 1, 1),
            row(2, 2, 2, 2),
            row(3, 3, None, 3),
            row(5, 5, None, 5),
            row(None, 4, None, 4),
            row(None, 6, None, 6)
        ]

        inner_join_expected = [
            row(0, 0, 0, 0), row(1, 1, 1, 1),
            row(2, 2, 2, 2)
        ]

        outer_join_expected = [
            row(0, 0, 0, 0),
            row(1, 1, 1, 1),
            row(2, 2, 2, 2),
            row(3, 3, None, 3),
            row(4, 4, 4, None),
            row(5, 5, None, 5),
            row(6, 6, 6, None),
            row(None, 3, 3, None),
            row(None, 4, None, 4),
            row(None, 5, 5, None),
            row(None, 6, None, 6)
        ]

        self.assertEqual(left_join.collect(), left_join_expected)
        self.assertEqual(right_join.collect(), right_join_expected)
        self.assertEqual(inner_join.collect(), inner_join_expected)
        self.assertEqual(outer_join.collect(), outer_join_expected)
Exemplo n.º 30
0
    def get_csq_from_struct(element: hl.expr.StructExpression,
                            feature_type: str) -> hl.expr.StringExpression:
        # Most fields are 1-1, just lowercase
        fields = dict(element)

        # Add general exceptions
        fields.update({
            "allele":
            element.variant_allele,
            "consequence":
            hl.delimit(element.consequence_terms, delimiter="&"),
            "feature_type":
            feature_type,
            "feature":
            (element.transcript_id if "transcript_id" in element else
             element.regulatory_feature_id if "regulatory_feature_id"
             in element else element.motif_feature_id
             if "motif_feature_id" in element else ""),
            "variant_class":
            vep_expr.variant_class,
        })

        # Add exception for transcripts
        if feature_type == "Transcript":
            fields.update({
                "canonical":
                hl.cond(element.canonical == 1, "YES", ""),
                "ensp":
                element.protein_id,
                "gene":
                element.gene_id,
                "symbol":
                element.gene_symbol,
                "symbol_source":
                element.gene_symbol_source,
                "cdna_position":
                hl.str(element.cdna_start) + hl.cond(
                    element.cdna_start == element.cdna_end,
                    "",
                    "-" + hl.str(element.cdna_end),
                ),
                "cds_position":
                hl.str(element.cds_start) + hl.cond(
                    element.cds_start == element.cds_end,
                    "",
                    "-" + hl.str(element.cds_end),
                ),
                "protein_position":
                hl.str(element.protein_start) + hl.cond(
                    element.protein_start == element.protein_end,
                    "",
                    "-" + hl.str(element.protein_end),
                ),
                "sift":
                element.sift_prediction + "(" +
                hl.format("%.3f", element.sift_score) + ")",
                "polyphen":
                element.polyphen_prediction + "(" +
                hl.format("%.3f", element.polyphen_score) + ")",
                "domains":
                hl.delimit(element.domains.map(lambda d: d.db + ":" + d.name),
                           "&"),
            })
        elif feature_type == "MotifFeature":
            fields["motif_score_change"] = hl.format(
                "%.3f", element.motif_score_change)

        return hl.delimit(
            [hl.or_else(hl.str(fields.get(f, "")), "") for f in _csq_fields],
            "|")
Exemplo n.º 31
0
 def test_literals_rebuild(self):
     mt = hl.utils.range_matrix_table(1, 1)
     mt = mt.annotate_rows(x = hl.cond(hl.len(hl.literal([1,2,3])) < hl.rand_unif(10, 11), mt.globals, hl.struct()))
     mt._force_count_rows()
Exemplo n.º 32
0
def combine(ts):
    def merge_alleles(alleles):
        from hail.expr.functions import _num_allele_type, _allele_ints
        return hl.rbind(
            alleles.map(lambda a: hl.or_else(a[0], ''))
                   .fold(lambda s, t: hl.cond(hl.len(s) > hl.len(t), s, t), ''),
            lambda ref:
            hl.rbind(
                alleles.map(
                    lambda al: hl.rbind(
                        al[0],
                        lambda r:
                        hl.array([ref]).extend(
                            al[1:].map(
                                lambda a:
                                hl.rbind(
                                    _num_allele_type(r, a),
                                    lambda at:
                                    hl.cond(
                                        (_allele_ints['SNP'] == at) |
                                        (_allele_ints['Insertion'] == at) |
                                        (_allele_ints['Deletion'] == at) |
                                        (_allele_ints['MNP'] == at) |
                                        (_allele_ints['Complex'] == at),
                                        a + ref[hl.len(r):],
                                        a)))))),
                lambda lal:
                hl.struct(
                    globl=hl.array([ref]).extend(hl.array(hl.set(hl.flatten(lal)).remove(ref))),
                    local=lal)))

    def renumber_entry(entry, old_to_new) -> StructExpression:
        # global index of alternate (non-ref) alleles
        return entry.annotate(LA=entry.LA.map(lambda lak: old_to_new[lak]))

    if (ts.row.dtype, ts.globals.dtype) not in _merge_function_map:
        f = hl.experimental.define_function(
            lambda row, gbl:
            hl.rbind(
                merge_alleles(row.data.map(lambda d: d.alleles)),
                lambda alleles:
                hl.struct(
                    locus=row.locus,
                    alleles=alleles.globl,
                    rsid=hl.find(hl.is_defined, row.data.map(lambda d: d.rsid)),
                    __entries=hl.bind(
                        lambda combined_allele_index:
                        hl.range(0, hl.len(row.data)).flatmap(
                            lambda i:
                            hl.cond(hl.is_missing(row.data[i].__entries),
                                    hl.range(0, hl.len(gbl.g[i].__cols))
                                      .map(lambda _: hl.null(row.data[i].__entries.dtype.element_type)),
                                    hl.bind(
                                        lambda old_to_new: row.data[i].__entries.map(
                                            lambda e: renumber_entry(e, old_to_new)),
                                        hl.range(0, hl.len(alleles.local[i])).map(
                                            lambda j: combined_allele_index[alleles.local[i][j]])))),
                        hl.dict(hl.range(0, hl.len(alleles.globl)).map(
                            lambda j: hl.tuple([alleles.globl[j], j])))))),
            ts.row.dtype, ts.globals.dtype)
        _merge_function_map[(ts.row.dtype, ts.globals.dtype)] = f
    merge_function = _merge_function_map[(ts.row.dtype, ts.globals.dtype)]
    ts = Table(TableMapRows(ts._tir, Apply(merge_function._name,
                                           TopLevelReference('row'),
                                           TopLevelReference('global'))))
    return ts.transmute_globals(__cols=hl.flatten(ts.g.map(lambda g: g.__cols)))
Exemplo n.º 33
0
def main(args):
    # Start Hail
    hl.init(default_reference=args.default_ref_genome)

    # Import raw split MT
    mt = (get_mt_data(dataset=args.exome_cohort, part='raw',
                      split=True).select_cols())

    ht = (mt.cols().key_by('s'))

    # Annotate samples filters
    sample_qc_filters = {}

    # 1. Add sample hard filters annotation expr
    sample_qc_hard_filters_ht = hl.read_table(
        get_sample_qc_ht_path(dataset=args.exome_cohort, part='hard_filters'))

    sample_qc_filters.update(
        {'hard_filters': sample_qc_hard_filters_ht[ht.s]['hard_filters']})

    # 2. Add population qc filters annotation expr
    sample_qc_pop_ht = hl.read_table(
        get_sample_qc_ht_path(dataset=args.exome_cohort, part='population_qc'))

    sample_qc_filters.update(
        {'predicted_pop': sample_qc_pop_ht[ht.s]['predicted_pop']})

    # 3. Add relatedness filters annotation expr
    related_samples_to_drop = get_related_samples_to_drop()
    related_samples = hl.set(
        related_samples_to_drop.aggregate(
            hl.agg.collect_as_set(related_samples_to_drop.node.id)))

    sample_qc_filters.update({'is_related': related_samples.contains(ht.s)})

    # 4. Add stratified sample qc (population/platform) annotation expr
    sample_qc_pop_platform_filters_ht = hl.read_table(
        get_sample_qc_ht_path(dataset=args.exome_cohort,
                              part='stratified_metrics_filter'))

    sample_qc_filters.update({
        'pop_platform_filters':
        sample_qc_pop_platform_filters_ht[ht.s]['pop_platform_filters']
    })

    ht = (ht.annotate(**sample_qc_filters))

    # Final sample qc filter joint expression
    final_sample_qc_ann_expr = {
        'pass_filters':
        ((hl.len(ht.hard_filters) == 0) &
         (hl.len(ht.pop_platform_filters) == 0) &
         (ht.predicted_pop == 'EUR') & ~ht.is_related)
    }
    ht = (ht.annotate(**final_sample_qc_ann_expr))

    logger.info('Writing final sample qc HT to disk...')
    output_path_ht = get_sample_qc_ht_path(dataset=args.exome_cohort,
                                           part='final_qc')

    ht = ht.checkpoint(output_path_ht, overwrite=args.overwrite)

    # Export final sample QC annotations to file
    if args.write_to_file:
        (ht.export(f'{output_path_ht}.tsv.bgz'))

    ## Release final unphased MT with adjusted genotypes filtered
    mt = unphase_mt(mt)
    mt = annotate_adj(mt)
    mt = mt.filter_entries(mt.adj).select_entries('GT', 'DP', 'GQ', 'adj')

    logger.info('Writing unphased MT with adjusted genotypes to disk...')
    # write MT
    mt.write(get_qc_mt_path(dataset=args.exome_cohort,
                            part='unphase_adj_genotypes',
                            split=True),
             overwrite=args.overwrite)

    # Stop Hail
    hl.stop()

    print("Finished!")
Exemplo n.º 34
0
# These fields contain float values but are stored as strings
CONVERT_TO_FLOAT_FIELDS = [
    "ESP_AF_POPMAX", "ESP_AF_GLOBAL", "KG_AF_POPMAX", "KG_AF_GLOBAL"
]

# Convert "NA" and empty strings into null values
# Convert fields in chunks to avoid "Method code too large" errors
for i in range(0, len(SELECT_INFO_FIELDS), 10):
    mt = mt.annotate_rows(info=mt.info.annotate(
        **{
            field: hl.or_missing(
                hl.is_defined(mt.info[field]),
                hl.bind(
                    lambda value: hl.cond(
                        (value == "") | (value == "NA"),
                        hl.null(mt.info[field].dtype), mt.info[field]),
                    hl.str(mt.info[field]),
                ),
            )
            for field in SELECT_INFO_FIELDS[i:i + 10]
        }))
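
The chunking pattern above generalizes; a sketch of a reusable helper (the helper name and the make_expr callback are illustrative, not from the original):

def annotate_info_in_chunks(mt, fields, make_expr, chunk_size=10):
    # Annotating only a few fields per annotate_rows call keeps each
    # generated JVM method under the bytecode size limit that triggers
    # "Method code too large" errors.
    for i in range(0, len(fields), chunk_size):
        mt = mt.annotate_rows(info=mt.info.annotate(
            **{f: make_expr(mt, f) for f in fields[i:i + chunk_size]}))
    return mt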

# Convert field types
mt = mt.annotate_rows(info=mt.info.annotate(
    **{
        field: hl.cond(mt.info[field] == "", hl.null(hl.tint),
                       hl.int(mt.info[field]))
        for field in CONVERT_TO_INT_FIELDS
    }))
mt = mt.annotate_rows(info=mt.info.annotate(
    **{
        field: hl.cond(mt.info[field] == "", hl.null(hl.tfloat),
                       hl.float(mt.info[field]))
        for field in CONVERT_TO_FLOAT_FIELDS
    }))
Example #35
File: qc.py Project: tpoterba/hail
 def get_allele_type(allele_idx):
     return hl.cond(allele_idx > 0, mt[variant_atypes][allele_idx - 1], hl.null(hl.tint32))
Example #36
 def test_plot_roc_curve(self):
     x = hl.utils.range_table(100).annotate(score1=hl.rand_norm(), score2=hl.rand_norm())
     x = x.annotate(tp=hl.cond(x.score1 > 0, hl.rand_bool(0.7), False), score3=x.score1 + hl.rand_norm())
     ht = x.annotate(fp=hl.cond(~x.tp, hl.rand_bool(0.2), False))
     _, aucs = hl.experimental.plot_roc_curve(ht, ['score1', 'score2', 'score3'])
Example #37
def transform_one(mt, vardp_outlier=100_000) -> Table:
    """transforms a gvcf into a form suitable for combining

    The input to this should be some result of either :func:`.import_vcf` or
    :func:`.import_vcfs` with `array_elements_required=False`.

    There is a strong assumption that this function will be called on a matrix
    table with one column.
    """
    mt = localize(mt)
    if mt.row.dtype not in _transform_rows_function_map:
        f = hl.experimental.define_function(
            lambda row: hl.rbind(
                hl.len(row.alleles), '<NON_REF>' == row.alleles[-1],
                lambda alleles_len, has_non_ref: hl.struct(
                    locus=row.locus,
                    alleles=hl.cond(has_non_ref, row.alleles[:-1], row.alleles),
                    rsid=row.rsid,
                    __entries=row.__entries.map(
                        lambda e:
                        hl.struct(
                            DP=e.DP,
                            END=row.info.END,
                            GQ=e.GQ,
                            LA=hl.range(0, alleles_len - hl.cond(has_non_ref, 1, 0)),
                            LAD=hl.cond(has_non_ref, e.AD[:-1], e.AD),
                            LGT=e.GT,
                            LPGT=e.PGT,
                            LPL=hl.cond(has_non_ref,
                                        hl.cond(alleles_len > 2,
                                                e.PL[:-alleles_len],
                                                hl.null(e.PL.dtype)),
                                        hl.cond(alleles_len > 1,
                                                e.PL,
                                                hl.null(e.PL.dtype))),
                            MIN_DP=e.MIN_DP,
                            PID=e.PID,
                            RGQ=hl.cond(
                                has_non_ref,
                                e.PL[hl.call(0, alleles_len - 1).unphased_diploid_gt_index()],
                                hl.null(e.PL.dtype.element_type)),
                            SB=e.SB,
                            gvcf_info=hl.case()
                                .when(hl.is_missing(row.info.END),
                                      hl.struct(
                                          ClippingRankSum=row.info.ClippingRankSum,
                                          BaseQRankSum=row.info.BaseQRankSum,
                                          MQ=row.info.MQ,
                                          MQRankSum=row.info.MQRankSum,
                                          MQ_DP=row.info.MQ_DP,
                                          QUALapprox=row.info.QUALapprox,
                                          RAW_MQ=row.info.RAW_MQ,
                                          ReadPosRankSum=row.info.ReadPosRankSum,
                                          VarDP=hl.cond(row.info.VarDP > vardp_outlier,
                                                        row.info.DP, row.info.VarDP)))
                                .or_missing()
                        ))),
            ),
            mt.row.dtype)
        _transform_rows_function_map[mt.row.dtype] = f
    transform_row = _transform_rows_function_map[mt.row.dtype]
    return Table(TableMapRows(mt._tir, Apply(transform_row._name, TopLevelReference('row'))))
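
A minimal usage sketch for transform_one, assuming a single-sample GVCF imported with array_elements_required=False as the docstring requires (the path here is hypothetical):

mt = hl.import_vcf('gs://my-bucket/sample1.g.vcf.bgz',  # hypothetical path
                   reference_genome='GRCh38',
                   force_bgz=True,
                   array_elements_required=False)
ts = transform_one(mt)  # a localized Table ready to be combined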
Example #38
def prepare_pext_data(base_level_pext_path):
    tmp_dir = os.path.expanduser("~")

    #
    # Step 1: rename fields, extract chrom/pos from locus, convert missing values to 0, export to TSV
    #
    ds = hl.read_table(base_level_pext_path)

    ds = ds.select(
        gene_id=ds.ensg,
        chrom=ds.locus.contig,
        pos=ds.locus.position,
        # Replace NaNs and missing values with 0s
        mean=hl.cond(
            hl.is_missing(ds.mean_proportion) | hl.is_nan(ds.mean_proportion),
            hl.float(0), ds.mean_proportion),
        **{
            renamed: hl.cond(
                hl.is_missing(ds[original]) | hl.is_nan(ds[original]),
                hl.float(0), ds[original])
            for original, renamed in TISSUE_NAME_MAP.items()
        })

    ds = ds.order_by(ds.gene_id, hl.asc(ds.pos)).drop("locus")
    ds.export("file://" + os.path.join(tmp_dir, "bases.tsv"))

    #
    # Step 2: Collect base-level data into regions
    #
    with open(os.path.join(tmp_dir, "regions.tsv"), "w") as output_file:
        writer = csv.writer(output_file, delimiter="\t")
        writer.writerow(["gene_id", "chrom", "start", "stop", "mean"] +
                        TISSUE_FIELDS)

        def output_region(region):
            writer.writerow([
                region.gene, region.chrom, region.start, region.stop,
                region.tissues["mean"]
            ] + [region.tissues[t] for t in TISSUE_FIELDS])

        rows = read_bases_tsv(os.path.join(tmp_dir, "bases.tsv"))
        first_row = next(rows)
        current_region = Region(gene=first_row.gene,
                                chrom=first_row.chrom,
                                start=first_row.pos,
                                stop=None,
                                tissues=first_row.tissues)
        last_pos = first_row.pos

        for row in rows:
            if (row.gene != current_region.gene
                    or row.chrom != current_region.chrom
                    or row.pos > last_pos + 1
                    or any(row.tissues[t] != current_region.tissues[t]
                           for t in row.tissues)):
                output_region(current_region._replace(stop=last_pos))
                current_region = Region(gene=row.gene,
                                        chrom=row.chrom,
                                        start=row.pos,
                                        stop=None,
                                        tissues=row.tissues)

            last_pos = row.pos

        output_region(current_region._replace(stop=last_pos))

    # Copy regions file to HDFS
    subprocess.run(
        [
            "hdfs", "dfs", "-cp",
            "file://" + os.path.join(tmp_dir, "regions.tsv"),
            os.path.join(tmp_dir, "regions.tsv")
        ],
        check=True,
    )

    #
    # Step 3: Convert regions to a Hail table.
    #
    types = {t: hl.tfloat for t in TISSUE_FIELDS}
    types["gene_id"] = hl.tstr
    types["chrom"] = hl.tstr
    types["start"] = hl.tint
    types["stop"] = hl.tint
    types["mean"] = hl.tfloat

    ds = hl.import_table(os.path.join(tmp_dir, "regions.tsv"),
                         min_partitions=100,
                         missing="",
                         types=types)

    ds = ds.select("gene_id",
                   "chrom",
                   "start",
                   "stop",
                   "mean",
                   tissues=hl.struct(**{t: ds[t]
                                        for t in TISSUE_FIELDS}))

    ds = ds.group_by("gene_id").aggregate(
        regions=hl.agg.collect(ds.row_value.drop("gene_id")))

    return ds
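
Step 2 is effectively a run-length encoding over consecutive bases: a region extends while the gene, chromosome, and every tissue value stay the same and positions remain contiguous. A self-contained toy illustration of the same merge rule (all data and names below are made up; Region and read_bases_tsv in the original are module helpers not shown here):

from collections import namedtuple

Row = namedtuple("Row", ["gene", "chrom", "pos", "tissues"])

rows = [
    Row("ENSG00000000001", "1", 100, {"mean": 0.5}),
    Row("ENSG00000000001", "1", 101, {"mean": 0.5}),  # contiguous, identical -> extends region
    Row("ENSG00000000001", "1", 103, {"mean": 0.5}),  # position gap -> starts a new region
    Row("ENSG00000000001", "1", 104, {"mean": 0.7}),  # value change -> starts a new region
]

regions = []
current = dict(gene=rows[0].gene, chrom=rows[0].chrom,
               start=rows[0].pos, tissues=rows[0].tissues)
last_pos = rows[0].pos
for row in rows[1:]:
    if (row.gene != current["gene"] or row.chrom != current["chrom"]
            or row.pos > last_pos + 1 or row.tissues != current["tissues"]):
        regions.append({**current, "stop": last_pos})
        current = dict(gene=row.gene, chrom=row.chrom,
                       start=row.pos, tissues=row.tissues)
    last_pos = row.pos
regions.append({**current, "stop": last_pos})

# Resulting spans: (100, 101), (103, 103), (104, 104)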