Ejemplo n.º 1
0
    def test_concordance(self):
        """Check ``hl.concordance`` of a dataset compared against itself.

        Because both inputs are the same dataset, every entry must land on the
        diagonal of the 5x5 global concordance matrix, and the per-column and
        per-row matrices must each account for every row / column.
        """
        dataset = get_dataset()
        glob_conc, cols_conc, rows_conc = hl.concordance(dataset, dataset)

        # The global matrix must account for every entry of the dataset.
        total = sum(sum(glob_conc[i]) for i in range(5))
        self.assertEqual(total, dataset.count_rows() * dataset.count_cols())

        counts = dataset.aggregate_entries(hl.Struct(
            n_het=agg.filter(dataset.GT.is_het(), agg.count()),
            n_hom_ref=agg.filter(dataset.GT.is_hom_ref(), agg.count()),
            n_hom_var=agg.filter(dataset.GT.is_hom_var(), agg.count()),
            nNoCall=agg.filter(hl.is_missing(dataset.GT), agg.count())))

        # Diagonal cells hold the per-genotype-class counts.
        self.assertEqual(glob_conc[0][0], 0)
        self.assertEqual(glob_conc[1][1], counts.nNoCall)
        self.assertEqual(glob_conc[2][2], counts.n_hom_ref)
        self.assertEqual(glob_conc[3][3], counts.n_het)
        self.assertEqual(glob_conc[4][4], counts.n_hom_var)

        # A dataset never disagrees with itself: all off-diagonal cells are 0.
        for i in range(5):
            for j in range(5):
                if i != j:
                    self.assertEqual(glob_conc[i][j], 0)

        self.assertTrue(cols_conc.all(hl.sum(hl.flatten(cols_conc.concordance)) == dataset.count_rows()))
        self.assertTrue(rows_conc.all(hl.sum(hl.flatten(rows_conc.concordance)) == dataset.count_cols()))

        # Check that both result tables can be written. Both go to the same
        # path with overwrite=True, so the second write replaces the first.
        cols_conc.write('/tmp/foo.kt', overwrite=True)
        rows_conc.write('/tmp/foo.kt', overwrite=True)
Ejemplo n.º 2
0
    def test_aggregate_ir(self):
        """Aggregate over columns, rows and entries of a 5x5 range matrix table.

        The same expressions are evaluated through ``aggregate_cols`` (indexed
        by ``col_idx``) and ``aggregate_rows`` (indexed by ``row_idx``) and must
        give the same results. A final check covers ``aggregate_entries``.
        """
        ds = (hl.utils.range_matrix_table(5, 5)
              .annotate_globals(g1=5)
              .annotate_entries(e1=3))

        aggregators = [("col_idx", ds.aggregate_cols),
                       ("row_idx", ds.aggregate_rows)]

        for name, aggregate in aggregators:
            r = aggregate(hl.struct(x=agg.sum(ds[name]) + ds.g1,
                                    y=agg.filter(ds[name] % 2 != 0, agg.sum(ds[name] + 2)) + ds.g1,
                                    z=agg.sum(ds.g1 + ds[name]) + ds.g1,
                                    mean=agg.mean(ds[name])))
            self.assertEqual(convert_struct_to_dict(r), {u'x': 15, u'y': 13, u'z': 40, u'mean': 2.0})

            # A plain constant is returned unchanged.
            r = aggregate(5)
            self.assertEqual(r, 5)

            # A missing value comes back as None.
            r = aggregate(hl.null(hl.tint32))
            self.assertIsNone(r)

            r = aggregate(agg.filter(ds[name] % 2 != 0, agg.sum(ds[name] + 2)) + ds.g1)
            self.assertEqual(r, 13)

        r = ds.aggregate_entries(agg.filter((ds.row_idx % 2 != 0) & (ds.col_idx % 2 != 0),
                                            agg.sum(ds.e1 + ds.g1 + ds.row_idx + ds.col_idx)) + ds.g1)
        # Entries with odd row and column index are (1, 1), (1, 3), (3, 1), (3, 3).
        # Each contributes e1 + g1 + row_idx + col_idx: 10 + 12 + 12 + 14 = 48.
        # The global g1 = 5 is then added once, giving 53.
        # The previous ``assertTrue(r, 48)`` passed 48 as the failure message,
        # so it only checked that ``r`` was truthy.
        self.assertEqual(r, 53)
Ejemplo n.º 3
0
    def test_concordance(self):
        """Check ``hl.concordance`` of a dataset compared against itself.

        Because both inputs are the same dataset, every entry must land on the
        diagonal of the 5x5 global concordance matrix, and the per-column and
        per-row matrices must each account for every row / column.
        """
        dataset = get_dataset()
        glob_conc, cols_conc, rows_conc = hl.concordance(dataset, dataset)

        # The global matrix must account for every entry of the dataset.
        total = sum(sum(glob_conc[i]) for i in range(5))
        self.assertEqual(total, dataset.count_rows() * dataset.count_cols())

        counts = dataset.aggregate_entries(hl.Struct(
            n_het=agg.filter(dataset.GT.is_het(), agg.count()),
            n_hom_ref=agg.filter(dataset.GT.is_hom_ref(), agg.count()),
            n_hom_var=agg.filter(dataset.GT.is_hom_var(), agg.count()),
            nNoCall=agg.filter(hl.is_missing(dataset.GT), agg.count())))

        # Diagonal cells hold the per-genotype-class counts.
        self.assertEqual(glob_conc[0][0], 0)
        self.assertEqual(glob_conc[1][1], counts.nNoCall)
        self.assertEqual(glob_conc[2][2], counts.n_hom_ref)
        self.assertEqual(glob_conc[3][3], counts.n_het)
        self.assertEqual(glob_conc[4][4], counts.n_hom_var)

        # A dataset never disagrees with itself: all off-diagonal cells are 0.
        for i in range(5):
            for j in range(5):
                if i != j:
                    self.assertEqual(glob_conc[i][j], 0)

        self.assertTrue(cols_conc.all(hl.sum(hl.flatten(cols_conc.concordance)) == dataset.count_rows()))
        self.assertTrue(rows_conc.all(hl.sum(hl.flatten(rows_conc.concordance)) == dataset.count_cols()))

        # Check that both result tables can be written. Both go to the same
        # path with overwrite=True, so the second write replaces the first.
        cols_conc.write('/tmp/foo.kt', overwrite=True)
        rows_conc.write('/tmp/foo.kt', overwrite=True)
Ejemplo n.º 4
0
    def test_aggregate(self):
        """Exercise row, column and entry aggregations on the test dataset."""
        vds = (self.get_vds()
               .annotate_globals(foo=5)
               .annotate_rows(x1=agg.count())
               .annotate_cols(y1=agg.count()))
        vds = vds.annotate_entries(z1=vds.DP)

        n_rows = vds.aggregate_rows(agg.count())
        n_cols = vds.aggregate_cols(agg.count())
        n_entries = vds.aggregate_entries(agg.count())

        taken = vds.aggregate_entries(hl.agg.take(vds.s, 1)[0])
        self.assertIsNotNone(taken)

        self.assertEqual(n_rows, 346)
        self.assertEqual(n_cols, 100)
        self.assertEqual(n_entries, n_rows * n_cols)

        # The results below are not inspected; these calls only check that
        # the aggregations execute.
        row_struct = hl.Struct(x=agg.collect(vds.locus.contig),
                               y=agg.collect(vds.x1))
        vds.aggregate_rows(row_struct)

        col_struct = hl.Struct(x=agg.collect(vds.s),
                               y=agg.collect(vds.y1))
        vds.aggregate_cols(col_struct)

        entry_struct = hl.Struct(
            x=agg.collect(agg.filter(False, vds.y1)),
            y=agg.collect(agg.filter(hl.rand_bool(0.1), vds.GT)))
        vds.aggregate_entries(entry_struct)
Ejemplo n.º 5
0
    def test_aggregate_ir(self):
        """Aggregate over columns, rows and entries of a 5x5 range matrix table.

        The same expressions are evaluated through ``aggregate_cols`` (indexed
        by ``col_idx``) and ``aggregate_rows`` (indexed by ``row_idx``) and must
        give the same results. A final check covers ``aggregate_entries``.
        """
        ds = (hl.utils.range_matrix_table(5, 5)
              .annotate_globals(g1=5)
              .annotate_entries(e1=3))

        aggregators = [("col_idx", ds.aggregate_cols),
                       ("row_idx", ds.aggregate_rows)]

        for name, aggregate in aggregators:
            r = aggregate(hl.struct(x=agg.sum(ds[name]) + ds.g1,
                                    y=agg.filter(ds[name] % 2 != 0, agg.sum(ds[name] + 2)) + ds.g1,
                                    z=agg.sum(ds.g1 + ds[name]) + ds.g1,
                                    mean=agg.mean(ds[name])))
            self.assertEqual(convert_struct_to_dict(r), {u'x': 15, u'y': 13, u'z': 40, u'mean': 2.0})

            # A plain constant is returned unchanged.
            r = aggregate(5)
            self.assertEqual(r, 5)

            # A missing value comes back as None.
            r = aggregate(hl.null(hl.tint32))
            self.assertIsNone(r)

            r = aggregate(agg.filter(ds[name] % 2 != 0, agg.sum(ds[name] + 2)) + ds.g1)
            self.assertEqual(r, 13)

        r = ds.aggregate_entries(agg.filter((ds.row_idx % 2 != 0) & (ds.col_idx % 2 != 0),
                                            agg.sum(ds.e1 + ds.g1 + ds.row_idx + ds.col_idx)) + ds.g1)
        # Entries with odd row and column index are (1, 1), (1, 3), (3, 1), (3, 3).
        # Each contributes e1 + g1 + row_idx + col_idx: 10 + 12 + 12 + 14 = 48.
        # The global g1 = 5 is then added once, giving 53.
        # The previous ``assertTrue(r, 48)`` passed 48 as the failure message,
        # so it only checked that ``r`` was truthy.
        self.assertEqual(r, 53)
Ejemplo n.º 6
0
    def test_aggregate(self):
        """Exercise row, column and entry aggregations on the test dataset."""
        vds = (self.get_vds()
               .annotate_globals(foo=5)
               .annotate_rows(x1=agg.count())
               .annotate_cols(y1=agg.count()))
        vds = vds.annotate_entries(z1=vds.DP)

        n_rows = vds.aggregate_rows(agg.count())
        n_cols = vds.aggregate_cols(agg.count())
        n_entries = vds.aggregate_entries(agg.count())

        taken = vds.aggregate_entries(hl.agg.take(vds.s, 1)[0])
        self.assertIsNotNone(taken)

        self.assertEqual(n_rows, 346)
        self.assertEqual(n_cols, 100)
        self.assertEqual(n_entries, n_rows * n_cols)

        # The results below are not inspected; these calls only check that
        # the aggregations execute.
        row_struct = hl.Struct(x=agg.collect(vds.locus.contig),
                               y=agg.collect(vds.x1))
        vds.aggregate_rows(row_struct)

        col_struct = hl.Struct(x=agg.collect(vds.s),
                               y=agg.collect(vds.y1))
        vds.aggregate_cols(col_struct)

        entry_struct = hl.Struct(
            x=agg.filter(False, agg.collect(vds.y1)),
            y=agg.filter(hl.rand_bool(0.1), agg.collect(vds.GT)))
        vds.aggregate_entries(entry_struct)
Ejemplo n.º 7
0
    def test_aggregate_ir(self):
        """Table aggregations may combine aggregators, globals and constants."""
        table = hl.utils.range_table(10).annotate_globals(g1=5)

        combined = hl.struct(
            x=agg.sum(table.idx) + table.g1,
            y=agg.filter(table.idx % 2 != 0, agg.sum(table.idx + 2)) + table.g1,
            z=agg.sum(table.g1 + table.idx) + table.g1)
        expected = {u'x': 50, u'y': 40, u'z': 100}
        self.assertEqual(convert_struct_to_dict(table.aggregate(combined)), expected)

        # A plain constant is returned unchanged.
        constant = table.aggregate(5)
        self.assertEqual(constant, 5)

        # A missing value comes back as None.
        missing = table.aggregate(hl.null(hl.tint32))
        self.assertEqual(missing, None)

        # Sum of (idx + 2) over odd idx, plus the global.
        odd_only = agg.filter(table.idx % 2 != 0, agg.sum(table.idx + 2)) + table.g1
        self.assertEqual(table.aggregate(odd_only), 40)
Ejemplo n.º 8
0
    def test_aggregate_ir(self):
        """Table aggregations may combine aggregators, globals and constants."""
        table = hl.utils.range_table(10).annotate_globals(g1=5)

        combined = hl.struct(
            x=agg.sum(table.idx) + table.g1,
            y=agg.filter(table.idx % 2 != 0, agg.sum(table.idx + 2)) + table.g1,
            z=agg.sum(table.g1 + table.idx) + table.g1)
        expected = {u'x': 50, u'y': 40, u'z': 100}
        self.assertEqual(convert_struct_to_dict(table.aggregate(combined)), expected)

        # A plain constant is returned unchanged.
        constant = table.aggregate(5)
        self.assertEqual(constant, 5)

        # A missing value comes back as None.
        missing = table.aggregate(hl.null(hl.tint32))
        self.assertEqual(missing, None)

        # Sum of (idx + 2) over odd idx, plus the global.
        odd_only = agg.filter(table.idx % 2 != 0, agg.sum(table.idx + 2)) + table.g1
        self.assertEqual(table.aggregate(odd_only), 40)
Ejemplo n.º 9
0
 def test_agg_cols_filter(self):
     """``agg.filter`` over columns composes with collect, explode and group_by."""
     t = hl.utils.range_matrix_table(1, 10)

     collected = agg.filter(
         t.col_idx > 7,
         agg.collect(t.col_idx + 1).append(0))
     exploded = agg.filter(
         t.col_idx > 7,
         agg.explode(lambda elt: agg.collect(elt + 1).append(0),
                     [t.col_idx, t.col_idx + 1]))
     grouped = agg.filter(
         t.col_idx > 7,
         agg.group_by(t.col_idx % 3,
                      hl.array(agg.collect_as_set(t.col_idx + 1)).append(0)))

     # Each aggregation is paired with the value expected for the single row.
     expectations = (
         (collected, [9, 10, 0]),
         (exploded, [9, 10, 10, 11, 0]),
         (grouped, {0: [10, 0], 2: [9, 0]}),
     )
     for aggregation, expected in expectations:
         per_row = t.select_rows(result=aggregation).result.collect()
         self.assertEqual(per_row[0], expected)
Ejemplo n.º 10
0
    def test_agg_cols_explode(self):
        """``agg.explode`` over columns composes with explode, filter and group_by."""
        t = hl.utils.range_matrix_table(1, 10)

        def tail_pairs():
            # [col_idx, col_idx + 1] where col_idx > 7, otherwise an empty array.
            return hl.cond(t.col_idx > 7,
                           [t.col_idx, t.col_idx + 1],
                           hl.empty_array(hl.tint32))

        flat = agg.explode(lambda elt: agg.collect(elt + 1).append(0),
                           tail_pairs())
        nested = agg.explode(
            lambda elt: agg.explode(lambda elt2: agg.collect(elt2 + 1).append(0),
                                    [elt, elt + 1]),
            tail_pairs())
        filtered = agg.explode(
            lambda elt: agg.filter(elt > 8, agg.collect(elt + 1).append(0)),
            tail_pairs())
        grouped = agg.explode(
            lambda elt: agg.group_by(elt % 3, agg.collect(elt + 1).append(0)),
            tail_pairs())

        # Each aggregation is paired with the value expected for the single row.
        expectations = (
            (flat, [9, 10, 10, 11, 0]),
            (nested, [9, 10, 10, 11, 10, 11, 11, 12, 0]),
            (filtered, [10, 10, 11, 0]),
            (grouped, {0: [10, 10, 0], 1: [11, 0], 2: [9, 0]}),
        )
        for aggregation, expected in expectations:
            per_row = t.select_rows(result=aggregation).result.collect()
            self.assertEqual(per_row[0], expected)
Ejemplo n.º 11
0
 def test_agg_cols_group_by(self):
     """``agg.group_by`` over columns composes with collect, filter and explode."""
     t = hl.utils.range_matrix_table(1, 10)

     by_parity = agg.group_by(
         t.col_idx % 2,
         hl.array(agg.collect_as_set(t.col_idx + 1)).append(0))
     filtered = agg.group_by(
         t.col_idx % 3,
         agg.filter(t.col_idx > 7,
                    hl.array(agg.collect_as_set(t.col_idx + 1)).append(0)))
     exploded = agg.group_by(
         t.col_idx % 3,
         agg.explode(lambda elt: agg.collect(elt + 1).append(0),
                     hl.cond(t.col_idx > 7,
                             [t.col_idx, t.col_idx + 1],
                             hl.empty_array(hl.tint32))))

     # Each aggregation is paired with the value expected for the single row.
     expectations = (
         (by_parity, {0: [1, 3, 5, 7, 9, 0], 1: [2, 4, 6, 8, 10, 0]}),
         (filtered, {0: [10, 0], 1: [0], 2: [9, 0]}),
         (exploded, {0: [10, 11, 0], 1: [0], 2: [9, 10, 0]}),
     )
     for aggregation, expected in expectations:
         per_row = t.select_rows(result=aggregation).result.collect()
         self.assertEqual(per_row[0], expected)
Ejemplo n.º 12
0
 def test_agg_cols_filter(self):
     """``agg.filter`` over columns composes with collect, explode and group_by."""
     t = hl.utils.range_matrix_table(1, 10)

     collected = agg.filter(
         t.col_idx > 7,
         agg.collect(t.col_idx + 1).append(0))
     exploded = agg.filter(
         t.col_idx > 7,
         agg.explode(lambda elt: agg.collect(elt + 1).append(0),
                     [t.col_idx, t.col_idx + 1]))
     grouped = agg.filter(
         t.col_idx > 7,
         agg.group_by(t.col_idx % 3,
                      hl.array(agg.collect_as_set(t.col_idx + 1)).append(0)))

     # Each aggregation is paired with the value expected for the single row.
     expectations = (
         (collected, [9, 10, 0]),
         (exploded, [9, 10, 10, 11, 0]),
         (grouped, {0: [10, 0], 2: [9, 0]}),
     )
     for aggregation, expected in expectations:
         per_row = t.select_rows(result=aggregation).result.collect()
         self.assertEqual(per_row[0], expected)
Ejemplo n.º 13
0
def summary_ccr(ht_ccr: hl.Table,
                file_output: str,
                ccr_pct_start: int = 0,
                ccr_pct_end: int = 100,
                ccr_pct_bins: int = 10,
                cumulative_histogram: bool = False,
                ccr_pct_cutoffs=None) -> None:
    """
    Summarize Coding Constraint Region (CCR) information per gene and export it.

    The table is grouped by its ``gene`` field and its ``ccr_pct`` field is
    summarized in one of two ways:

    * ``cumulative_histogram=True``: one ``ccr_above_<cutoff>`` field per
      cut-off, counting the rows with ``ccr_pct >= cutoff``.
    * ``cumulative_histogram=False``: a histogram of ``ccr_pct``, unpacked into
      one ``ccr_bin_<lower>_<upper>`` field per bin.

    :param ht_ccr: CCR Hail table
    :param file_output: File output path
    :param ccr_pct_start: Start of histogram range
    :param ccr_pct_end: End of histogram range
    :param ccr_pct_bins: Number of bins
    :param cumulative_histogram: Generate a cumulative histogram (rather than to use bins)
    :param ccr_pct_cutoffs: Cut-offs used to generate the cumulative histogram
                            (defaults to ``[90, 95, 99]``)
    :return: None
    """

    if ccr_pct_cutoffs is None:
        ccr_pct_cutoffs = [90, 95, 99]

    by_gene = ht_ccr.group_by('gene')

    if cumulative_histogram:
        # generate cumulative counts: one filtered row count per cut-off
        summary_tb = by_gene.aggregate(
            **{'ccr_above_' + str(cutoff): agg.filter(ht_ccr.ccr_pct >= cutoff, agg.count())
               for cutoff in ccr_pct_cutoffs})
    else:
        summary_tb = by_gene.aggregate(
            ccr_bins=agg.hist(ht_ccr.ccr_pct, ccr_pct_start, ccr_pct_end, ccr_pct_bins))

        # get bin edges as list (expected n_bins + 1); one row is enough
        # because every gene is histogrammed with the same range and bins.
        # NOTE(review): an empty table presumably yields no rows here, in which
        # case the [0] lookup raises IndexError -- confirm with callers.
        bin_edges = summary_tb.aggregate(agg.take(summary_tb.ccr_bins.bin_edges, 1))[0]

        # unpack array structure and annotate as individual fields, naming
        # each field after its pair of consecutive edges
        bin_fields = {'ccr_bin_' + str(lower) + '_' + str(upper): summary_tb.ccr_bins.bin_freq[k]
                      for k, (lower, upper) in enumerate(zip(bin_edges, bin_edges[1:]))}
        summary_tb = summary_tb.annotate(**bin_fields).flatten()

        # drop the flattened array fields now that they are unpacked
        summary_tb = summary_tb.drop('ccr_bins.bin_edges', 'ccr_bins.bin_freq')

    # Export summarized table
    summary_tb.export(output=file_output)
Ejemplo n.º 14
0
    def test_query(self):
        """Aggregate a small parallelized table with sum, count and collect."""
        schema = hl.tstruct(a=hl.tint32, b=hl.tint32, c=hl.tint32, d=hl.tint32,
                            e=hl.tstr, f=hl.tarray(hl.tint32))

        field_names = ('a', 'b', 'c', 'd', 'e', 'f')
        records = [(4, 1, 3, 5, "hello", [1, 2, 3]),
                   (0, 5, 13, -1, "cat", []),
                   (4, 2, 20, 3, "dog", [5, 6, 7])]
        rows = [dict(zip(field_names, record)) for record in records]

        kt = hl.Table.parallelize(rows, schema)

        # q4 keeps only rows with d >= 5 or a == 0
        keep = (kt.d >= 5) | (kt.a == 0)
        query = hl.Struct(q1=agg.sum(kt.b),
                          q2=agg.count(),
                          q3=agg.collect(kt.e),
                          q4=agg.collect(agg.filter(keep, kt.e)))
        results = kt.aggregate(query)

        self.assertEqual(results.q1, 8)
        self.assertEqual(results.q2, 3)
        # collect order is not asserted, so compare as sets
        self.assertEqual(set(results.q3), {"hello", "cat", "dog"})
        self.assertEqual(set(results.q4), {"hello", "cat"})
Ejemplo n.º 15
0
 def test_agg_cols_group_by(self):
     """``agg.group_by`` over columns composes with collect, filter and explode."""
     t = hl.utils.range_matrix_table(1, 10)

     by_parity = agg.group_by(
         t.col_idx % 2,
         hl.array(agg.collect_as_set(t.col_idx + 1)).append(0))
     filtered = agg.group_by(
         t.col_idx % 3,
         agg.filter(t.col_idx > 7,
                    hl.array(agg.collect_as_set(t.col_idx + 1)).append(0)))
     exploded = agg.group_by(
         t.col_idx % 3,
         agg.explode(lambda elt: agg.collect(elt + 1).append(0),
                     hl.cond(t.col_idx > 7,
                             [t.col_idx, t.col_idx + 1],
                             hl.empty_array(hl.tint32))))

     # Each aggregation is paired with the value expected for the single row.
     expectations = (
         (by_parity, {0: [1, 3, 5, 7, 9, 0], 1: [2, 4, 6, 8, 10, 0]}),
         (filtered, {0: [10, 0], 1: [0], 2: [9, 0]}),
         (exploded, {0: [10, 11, 0], 1: [0], 2: [9, 10, 0]}),
     )
     for aggregation, expected in expectations:
         per_row = t.select_rows(result=aggregation).result.collect()
         self.assertEqual(per_row[0], expected)
Ejemplo n.º 16
0
    def test_aggregate1(self):
        """Aggregate a small parallelized table with sum, count, collect, filter and explode."""
        schema = hl.tstruct(a=hl.tint32, b=hl.tint32, c=hl.tint32, d=hl.tint32,
                            e=hl.tstr, f=hl.tarray(hl.tint32))

        field_names = ('a', 'b', 'c', 'd', 'e', 'f')
        records = [(4, 1, 3, 5, "hello", [1, 2, 3]),
                   (0, 5, 13, -1, "cat", []),
                   (4, 2, 20, 3, "dog", [5, 6, 7])]
        rows = [dict(zip(field_names, record)) for record in records]

        kt = hl.Table.parallelize(rows, schema)

        # q4 keeps only rows with d >= 5 or a == 0
        keep = (kt.d >= 5) | (kt.a == 0)
        query = hl.Struct(q1=agg.sum(kt.b),
                          q2=agg.count(),
                          q3=agg.collect(kt.e),
                          q4=agg.filter(keep, agg.collect(kt.e)),
                          q5=agg.explode(lambda elt: agg.mean(elt), kt.f))
        results = kt.aggregate(query)

        self.assertEqual(results.q1, 8)
        self.assertEqual(results.q2, 3)
        # collect order is not asserted, so compare as sets
        self.assertEqual(set(results.q3), {"hello", "cat", "dog"})
        self.assertEqual(set(results.q4), {"hello", "cat"})
        # mean over all elements of f: (1 + 2 + 3 + 5 + 6 + 7) / 6
        self.assertAlmostEqual(results.q5, 4)
Ejemplo n.º 17
0
    def test_agg_cols_explode(self):
        """``agg.explode`` over columns composes with explode, filter and group_by."""
        t = hl.utils.range_matrix_table(1, 10)

        def tail_pairs():
            # [col_idx, col_idx + 1] where col_idx > 7, otherwise an empty array.
            return hl.cond(t.col_idx > 7,
                           [t.col_idx, t.col_idx + 1],
                           hl.empty_array(hl.tint32))

        flat = agg.explode(lambda elt: agg.collect(elt + 1).append(0),
                           tail_pairs())
        nested = agg.explode(
            lambda elt: agg.explode(lambda elt2: agg.collect(elt2 + 1).append(0),
                                    [elt, elt + 1]),
            tail_pairs())
        filtered = agg.explode(
            lambda elt: agg.filter(elt > 8, agg.collect(elt + 1).append(0)),
            tail_pairs())
        grouped = agg.explode(
            lambda elt: agg.group_by(elt % 3, agg.collect(elt + 1).append(0)),
            tail_pairs())

        # Each aggregation is paired with the value expected for the single row.
        expectations = (
            (flat, [9, 10, 10, 11, 0]),
            (nested, [9, 10, 10, 11, 10, 11, 11, 12, 0]),
            (filtered, [10, 10, 11, 0]),
            (grouped, {0: [10, 10, 0], 1: [11, 0], 2: [9, 0]}),
        )
        for aggregation, expected in expectations:
            per_row = t.select_rows(result=aggregation).result.collect()
            self.assertEqual(per_row[0], expected)
Ejemplo n.º 18
0
def transmission_disequilibrium_test(dataset, pedigree) -> Table:
    r"""Performs the transmission disequilibrium test on trios.

    .. include:: ../_templates/req_tstring.rst

    .. include:: ../_templates/req_tvariant.rst

    .. include:: ../_templates/req_biallelic.rst

    Examples
    --------
    Compute TDT association statistics and show the first two results:

    >>> pedigree = hl.Pedigree.read('data/tdt_trios.fam')
    >>> tdt_table = hl.transmission_disequilibrium_test(tdt_dataset, pedigree)
    >>> tdt_table.show(2)  # doctest: +NOTEST
    +---------------+------------+-------+-------+----------+----------+
    | locus         | alleles    |     t |     u |   chi_sq |  p_value |
    +---------------+------------+-------+-------+----------+----------+
    | locus<GRCh37> | array<str> | int64 | int64 |  float64 |  float64 |
    +---------------+------------+-------+-------+----------+----------+
    | 1:246714629   | ["C","A"]  |     0 |     4 | 4.00e+00 | 4.55e-02 |
    | 2:167262169   | ["T","C"]  |    NA |    NA |       NA |       NA |
    +---------------+------------+-------+-------+----------+----------+

    Export variants with p-values below 0.001:

    >>> tdt_table = tdt_table.filter(tdt_table.p_value < 0.001)
    >>> tdt_table.export("output/tdt_results.tsv")

    Notes
    -----
    The
    `transmission disequilibrium test <https://en.wikipedia.org/wiki/Transmission_disequilibrium_test#The_case_of_trios:_one_affected_child_per_family>`__
    compares the number of times the alternate allele is transmitted (t) versus
    not transmitted (u) from a heterozygous parent to an affected child. The null
    hypothesis holds that each case is equally likely. The TDT statistic is given by

    .. math::

        (t - u)^2 \over (t + u)

    and asymptotically follows a chi-squared distribution with one degree of
    freedom under the null hypothesis.

    :func:`transmission_disequilibrium_test` only considers complete trios (two
    parents and a proband with defined sex) and only returns results for the
    autosome, as defined by :meth:`~hail.genetics.Locus.in_autosome`, and
    chromosome X. Transmissions and non-transmissions are counted only for the
    configurations of genotypes and copy state in the table below, in order to
    filter out Mendel errors and configurations where transmission is
    guaranteed. The copy state of a locus with respect to a trio is defined as
    follows:

    - Auto -- in autosome or in PAR of X or female child
    - HemiX -- in non-PAR of X and male child

    Here PAR is the `pseudoautosomal region
    <https://en.wikipedia.org/wiki/Pseudoautosomal_region>`__
    of X and Y defined by :class:`.ReferenceGenome`, which many variant callers
    map to chromosome X.

    +--------+--------+--------+------------+---+---+
    |  Kid   | Dad    | Mom    | Copy State | t | u |
    +========+========+========+============+===+===+
    | HomRef | Het    | Het    | Auto       | 0 | 2 |
    +--------+--------+--------+------------+---+---+
    | HomRef | HomRef | Het    | Auto       | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | HomRef | Het    | HomRef | Auto       | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | Het    | Het    | Het    | Auto       | 1 | 1 |
    +--------+--------+--------+------------+---+---+
    | Het    | HomRef | Het    | Auto       | 1 | 0 |
    +--------+--------+--------+------------+---+---+
    | Het    | Het    | HomRef | Auto       | 1 | 0 |
    +--------+--------+--------+------------+---+---+
    | Het    | HomVar | Het    | Auto       | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | Het    | Het    | HomVar | Auto       | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | HomVar | Het    | Het    | Auto       | 2 | 0 |
    +--------+--------+--------+------------+---+---+
    | HomVar | Het    | HomVar | Auto       | 1 | 0 |
    +--------+--------+--------+------------+---+---+
    | HomVar | HomVar | Het    | Auto       | 1 | 0 |
    +--------+--------+--------+------------+---+---+
    | HomRef | HomRef | Het    | HemiX      | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | HomRef | HomVar | Het    | HemiX      | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | HomVar | HomRef | Het    | HemiX      | 1 | 0 |
    +--------+--------+--------+------------+---+---+
    | HomVar | HomVar | Het    | HemiX      | 1 | 0 |
    +--------+--------+--------+------------+---+---+

    :func:`transmission_disequilibrium_test` produces a table with the following columns:

     - `locus` (:class:`.tlocus`) -- Locus.
     - `alleles` (:class:`.tarray` of :py:data:`.tstr`) -- Alleles.
     - `t` (:py:data:`.tint32`) -- Number of transmitted alternate alleles.
     - `u` (:py:data:`.tint32`) -- Number of untransmitted alternate alleles.
     - `chi_sq` (:py:data:`.tfloat64`) -- TDT statistic.
     - `p_value` (:py:data:`.tfloat64`) -- p-value.

    Parameters
    ----------
    dataset : :class:`.MatrixTable`
        Dataset.
    pedigree : :class:`~hail.genetics.Pedigree`
        Sample pedigree.

    Returns
    -------
    :class:`.Table`
        Table of TDT results.
    """

    dataset = require_biallelic(dataset, 'transmission_disequilibrium_test')
    dataset = dataset.annotate_rows(auto_or_x_par=dataset.locus.in_autosome()
                                    | dataset.locus.in_x_par())
    # keep only loci in the autosome or on X (PAR or non-PAR)
    dataset = dataset.filter_rows(dataset.auto_or_x_par
                                  | dataset.locus.in_x_nonpar())

    # genotype classes, encoded as the number of alternate alleles
    hom_ref = 0
    het = 1
    hom_var = 2

    # copy states
    auto = 2
    hemi_x = 1

    #                     kid,     dad,     mom,   copy, t, u
    config_counts = [(hom_ref, het, het, auto, 0, 2),
                     (hom_ref, hom_ref, het, auto, 0, 1),
                     (hom_ref, het, hom_ref, auto, 0, 1),
                     (het, het, het, auto, 1, 1),
                     (het, hom_ref, het, auto, 1, 0),
                     (het, het, hom_ref, auto, 1, 0),
                     (het, hom_var, het, auto, 0, 1),
                     (het, het, hom_var, auto, 0, 1),
                     (hom_var, het, het, auto, 2, 0),
                     (hom_var, het, hom_var, auto, 1, 0),
                     (hom_var, hom_var, het, auto, 1, 0),
                     (hom_ref, hom_ref, het, hemi_x, 0, 1),
                     (hom_ref, hom_var, het, hemi_x, 0, 1),
                     (hom_var, hom_ref, het, hemi_x, 1, 0),
                     (hom_var, hom_var, het, hemi_x, 1, 0)]

    # (kid, dad, mom, copy state) -> [t, u]
    count_map = hl.literal({(c[0], c[1], c[2], c[3]): [c[4], c[5]]
                            for c in config_counts})

    tri = trio_matrix(dataset, pedigree, complete_trios=True)

    # this filter removes mendel error of het father in x_nonpar. It also avoids
    #   building and looking up config in common case that neither parent is het
    father_is_het = tri.father_entry.GT.is_het()
    parent_is_valid_het = ((father_is_het & tri.auto_or_x_par) |
                           (tri.mother_entry.GT.is_het() & ~father_is_het))

    # use the same named constants as the count_map keys so the lookup key
    # cannot drift from the table above
    copy_state = hl.cond(tri.auto_or_x_par | tri.is_female, auto, hemi_x)

    config = (tri.proband_entry.GT.n_alt_alleles(),
              tri.father_entry.GT.n_alt_alleles(),
              tri.mother_entry.GT.n_alt_alleles(), copy_state)

    # per-variant element-wise sum of the [t, u] pairs over the valid trios
    tri = tri.annotate_rows(counts=agg.filter(
        parent_is_valid_het, agg.array_sum(count_map.get(config))))

    tab = tri.rows().select('counts')
    tab = tab.transmute(t=tab.counts[0], u=tab.counts[1])
    tab = tab.annotate(chi_sq=((tab.t - tab.u)**2) / (tab.t + tab.u))
    tab = tab.annotate(p_value=hl.pchisqtail(tab.chi_sq, 1.0))

    return tab.cache()
Ejemplo n.º 19
0
def transmission_disequilibrium_test(dataset, pedigree) -> Table:
    r"""Performs the transmission disequilibrium test on trios.

    .. include:: ../_templates/req_tstring.rst

    .. include:: ../_templates/req_tvariant.rst

    .. include:: ../_templates/req_biallelic.rst

    Examples
    --------
    Compute TDT association statistics and show the first two results:

    >>> pedigree = hl.Pedigree.read('data/tdt_trios.fam')
    >>> tdt_table = hl.transmission_disequilibrium_test(tdt_dataset, pedigree)
    >>> tdt_table.show(2)  # doctest: +NOTEST
    +---------------+------------+-------+-------+----------+----------+
    | locus         | alleles    |     t |     u |   chi_sq |  p_value |
    +---------------+------------+-------+-------+----------+----------+
    | locus<GRCh37> | array<str> | int64 | int64 |  float64 |  float64 |
    +---------------+------------+-------+-------+----------+----------+
    | 1:246714629   | ["C","A"]  |     0 |     4 | 4.00e+00 | 4.55e-02 |
    | 2:167262169   | ["T","C"]  |    NA |    NA |       NA |       NA |
    +---------------+------------+-------+-------+----------+----------+

    Export variants with p-values below 0.001:

    >>> tdt_table = tdt_table.filter(tdt_table.p_value < 0.001)
    >>> tdt_table.export("output/tdt_results.tsv")

    Notes
    -----
    The
    `transmission disequilibrium test <https://en.wikipedia.org/wiki/Transmission_disequilibrium_test#The_case_of_trios:_one_affected_child_per_family>`__
    compares the number of times the alternate allele is transmitted (t) versus
    not transmitted (u) from a heterozygous parent to an affected child. The null
    hypothesis holds that each case is equally likely. The TDT statistic is given by

    .. math::

        (t - u)^2 \over (t + u)

    and asymptotically follows a chi-squared distribution with one degree of
    freedom under the null hypothesis.

    :func:`transmission_disequilibrium_test` only considers complete trios (two
    parents and a proband with defined sex) and only returns results for the
    autosome, as defined by :meth:`~hail.genetics.Locus.in_autosome`, and
    chromosome X. Transmissions and non-transmissions are counted only for the
    configurations of genotypes and copy state in the table below, in order to
    filter out Mendel errors and configurations where transmission is
    guaranteed. The copy state of a locus with respect to a trio is defined as
    follows:

    - Auto -- in autosome or in PAR of X or female child
    - HemiX -- in non-PAR of X and male child

    Here PAR is the `pseudoautosomal region
    <https://en.wikipedia.org/wiki/Pseudoautosomal_region>`__
    of X and Y defined by :class:`.ReferenceGenome`, which many variant callers
    map to chromosome X.

    +--------+--------+--------+------------+---+---+
    |  Kid   | Dad    | Mom    | Copy State | t | u |
    +========+========+========+============+===+===+
    | HomRef | Het    | Het    | Auto       | 0 | 2 |
    +--------+--------+--------+------------+---+---+
    | HomRef | HomRef | Het    | Auto       | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | HomRef | Het    | HomRef | Auto       | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | Het    | Het    | Het    | Auto       | 1 | 1 |
    +--------+--------+--------+------------+---+---+
    | Het    | HomRef | Het    | Auto       | 1 | 0 |
    +--------+--------+--------+------------+---+---+
    | Het    | Het    | HomRef | Auto       | 1 | 0 |
    +--------+--------+--------+------------+---+---+
    | Het    | HomVar | Het    | Auto       | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | Het    | Het    | HomVar | Auto       | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | HomVar | Het    | Het    | Auto       | 2 | 0 |
    +--------+--------+--------+------------+---+---+
    | HomVar | Het    | HomVar | Auto       | 1 | 0 |
    +--------+--------+--------+------------+---+---+
    | HomVar | HomVar | Het    | Auto       | 1 | 0 |
    +--------+--------+--------+------------+---+---+
    | HomRef | HomRef | Het    | HemiX      | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | HomRef | HomVar | Het    | HemiX      | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | HomVar | HomRef | Het    | HemiX      | 1 | 0 |
    +--------+--------+--------+------------+---+---+
    | HomVar | HomVar | Het    | HemiX      | 1 | 0 |
    +--------+--------+--------+------------+---+---+

    :func:`transmission_disequilibrium_test` produces a table with the following columns:

     - `locus` (:class:`.tlocus`) -- Locus.
     - `alleles` (:class:`.tarray` of :py:data:`.tstr`) -- Alleles.
     - `t` (:py:data:`.tint32`) -- Number of transmitted alternate alleles.
     - `u` (:py:data:`.tint32`) -- Number of untransmitted alternate alleles.
     - `chi_sq` (:py:data:`.tfloat64`) -- TDT statistic.
     - `p_value` (:py:data:`.tfloat64`) -- p-value.

    Parameters
    ----------
    dataset : :class:`.MatrixTable`
        Dataset.
    pedigree : :class:`~hail.genetics.Pedigree`
        Sample pedigree.

    Returns
    -------
    :class:`.Table`
        Table of TDT results.
    """

    dataset = require_biallelic(dataset, 'transmission_disequilibrium_test')
    dataset = dataset.annotate_rows(auto_or_x_par=dataset.locus.in_autosome() | dataset.locus.in_x_par())
    # keep only loci in the autosome or on X (PAR or non-PAR)
    dataset = dataset.filter_rows(dataset.auto_or_x_par | dataset.locus.in_x_nonpar())

    # genotype classes, encoded as the number of alternate alleles
    hom_ref = 0
    het = 1
    hom_var = 2

    # copy states
    auto = 2
    hemi_x = 1

    #                     kid,     dad,     mom,   copy, t, u
    config_counts = [(hom_ref,     het,     het,   auto, 0, 2),
                     (hom_ref, hom_ref,     het,   auto, 0, 1),
                     (hom_ref,     het, hom_ref,   auto, 0, 1),
                     (    het,     het,     het,   auto, 1, 1),
                     (    het, hom_ref,     het,   auto, 1, 0),
                     (    het,     het, hom_ref,   auto, 1, 0),
                     (    het, hom_var,     het,   auto, 0, 1),
                     (    het,     het, hom_var,   auto, 0, 1),
                     (hom_var,     het,     het,   auto, 2, 0),
                     (hom_var,     het, hom_var,   auto, 1, 0),
                     (hom_var, hom_var,     het,   auto, 1, 0),
                     (hom_ref, hom_ref,     het, hemi_x, 0, 1),
                     (hom_ref, hom_var,     het, hemi_x, 0, 1),
                     (hom_var, hom_ref,     het, hemi_x, 1, 0),
                     (hom_var, hom_var,     het, hemi_x, 1, 0)]

    # (kid, dad, mom, copy state) -> [t, u]
    count_map = hl.literal({(c[0], c[1], c[2], c[3]): [c[4], c[5]] for c in config_counts})

    tri = trio_matrix(dataset, pedigree, complete_trios=True)

    # this filter removes mendel error of het father in x_nonpar. It also avoids
    #   building and looking up config in common case that neither parent is het
    father_is_het = tri.father_entry.GT.is_het()
    parent_is_valid_het = ((father_is_het & tri.auto_or_x_par) |
                           (tri.mother_entry.GT.is_het() & ~father_is_het))

    # use the same named constants as the count_map keys so the lookup key
    # cannot drift from the table above
    copy_state = hl.cond(tri.auto_or_x_par | tri.is_female, auto, hemi_x)

    config = (tri.proband_entry.GT.n_alt_alleles(),
              tri.father_entry.GT.n_alt_alleles(),
              tri.mother_entry.GT.n_alt_alleles(),
              copy_state)

    # per-variant element-wise sum of the [t, u] pairs over the valid trios
    tri = tri.annotate_rows(counts=agg.filter(parent_is_valid_het, agg.array_sum(count_map.get(config))))

    tab = tri.rows().select('counts')
    tab = tab.transmute(t=tab.counts[0], u=tab.counts[1])
    tab = tab.annotate(chi_sq=((tab.t - tab.u) ** 2) / (tab.t + tab.u))
    tab = tab.annotate(p_value=hl.pchisqtail(tab.chi_sq, 1.0))

    return tab.cache()