Ejemplo n.º 1
0
    def __init__(self, schema, paths, key, intervals):
        """Store reader configuration and normalize ``intervals``.

        Parameters
        ----------
        schema
            Stored as ``self.schema``.
        paths
            Stored as ``self.paths``.
        key
            Stored as ``self.key``; must be ``None`` exactly when
            ``intervals`` is ``None``.
        intervals
            ``None`` or a value whose imputed hail type is an array of
            intervals. When the interval point type is not already a struct,
            ``self._interval_type`` wraps it in a single-field struct and each
            interval endpoint is wrapped the same way.

        Raises
        ------
        TypeError
            If ``intervals`` is not an array of intervals.
        """
        assert (key is None) == (intervals is None)
        self.schema = schema
        self.paths = paths
        self.key = key

        if intervals is not None:
            t = hl.expr.impute_type(intervals)
            # Reject when *either* requirement fails. ``or`` also short-circuits
            # before ``t.element_type`` is read on a non-array type.
            if not isinstance(t, hl.tarray) or not isinstance(
                    t.element_type, hl.tinterval):
                raise TypeError("'intervals' must be an array of tintervals")
            pt = t.element_type.point_type
            if isinstance(pt, hl.tstruct):
                self._interval_type = t
            else:
                self._interval_type = hl.tarray(
                    hl.tinterval(hl.tstruct(__point=pt)))

        if intervals is not None and t != self._interval_type:
            # The point type was wrapped above; wrap the interval values to match.
            self.intervals = [
                hl.Interval(hl.Struct(__point=i.start),
                            hl.Struct(__point=i.end), i.includes_start,
                            i.includes_end) for i in intervals
            ]
        else:
            self.intervals = intervals
Ejemplo n.º 2
0
def generate_5_sample_vds():
    """Build and write the five-sample chr22 VDS test resource.

    Imports five GRCh38 chr22 GVCFs using one partition interval. The interval
    runs from position 1 through ``contig_length('chr22') - 1``, inclusive.
    The GVCFs are combined into a VariantDataset and written, with overwrite,
    to ``1kg_chr22_5_samples.vds`` under ``resource('vds')``.
    """
    gvcf_names = [
        'HG00187.hg38.g.vcf.gz', 'HG00190.hg38.g.vcf.gz',
        'HG00308.hg38.g.vcf.gz', 'HG00313.hg38.g.vcf.gz',
        'HG00320.hg38.g.vcf.gz'
    ]
    gvcf_paths = [
        os.path.join(resource('gvcfs'), '1kg_chr22', name)
        for name in gvcf_names
    ]

    last_position = hl.get_reference('GRCh38').contig_length('chr22') - 1
    first_locus = hl.Locus('chr22', 1, reference_genome='GRCh38')
    last_locus = hl.Locus('chr22', last_position, reference_genome='GRCh38')
    whole_contig = hl.Interval(start=hl.Struct(locus=first_locus),
                               end=hl.Struct(locus=last_locus),
                               includes_end=True)

    imported = hl.import_gvcfs(gvcf_paths,
                               [whole_contig],
                               reference_genome='GRCh38',
                               array_elements_required=False)
    first = imported[0]
    # Decide which entry fields to keep by sampling reference-block rows
    # (rows with END defined) of the first GVCF.
    to_keep = defined_entry_fields(
        first.filter_rows(hl.is_defined(first.info.END)), 100_000)

    transformed = [
        hl.vds.combiner.transform_gvcf(mt, to_keep) for mt in imported
    ]
    combined = hl.vds.combiner.combine_variant_datasets(transformed)
    combined.variant_data = combined.variant_data._key_rows_by_assert_sorted(
        'locus', 'alleles')

    output_path = os.path.join(resource('vds'), '1kg_chr22_5_samples.vds')
    combined.write(output_path, overwrite=True)
Ejemplo n.º 3
0
    def test_aggregate(self):
        """Check row, column and entry aggregations over ``get_vds()``.

        The dataset is asserted to have 346 rows and 100 columns. The test
        then checks the struct-valued aggregations:

        * per-row collections have one element per row;
        * per-column collections have one element per column;
        * the per-row count ``x1`` equals the column count, and the
          per-column count ``y1`` equals the row count;
        * a filter that never passes collects nothing.
        """
        vds = self.get_vds()

        vds = vds.annotate_globals(foo=5)
        vds = vds.annotate_rows(x1=agg.count())
        vds = vds.annotate_cols(y1=agg.count())
        vds = vds.annotate_entries(z1=vds.DP)

        qv = vds.aggregate_rows(agg.count())
        qs = vds.aggregate_cols(agg.count())
        qg = vds.aggregate_entries(agg.count())

        self.assertIsNotNone(vds.aggregate_entries(hl.agg.take(vds.s, 1)[0]))

        self.assertEqual(qv, 346)
        self.assertEqual(qs, 100)
        self.assertEqual(qg, qv * qs)

        qvs = vds.aggregate_rows(
            hl.Struct(x=agg.collect(vds.locus.contig), y=agg.collect(vds.x1)))

        qss = vds.aggregate_cols(
            hl.Struct(x=agg.collect(vds.s), y=agg.collect(vds.y1)))

        qgs = vds.aggregate_entries(
            hl.Struct(x=agg.collect(agg.filter(False, vds.y1)),
                      y=agg.collect(agg.filter(hl.rand_bool(0.1), vds.GT))))

        # x1 was computed by annotate_rows, so it counts the entries in each
        # row (one per column). y1 counts the entries in each column (one per row).
        self.assertEqual(len(qvs.x), qv)
        self.assertEqual(qvs.y, [qs] * qv)
        self.assertEqual(len(qss.x), qs)
        self.assertEqual(qss.y, [qv] * qs)
        # A filter whose condition is always False admits no entries.
        # qgs.y depends on rand_bool, so it is not checked.
        self.assertEqual(qgs.x, [])
Ejemplo n.º 4
0
    def test_annotate_globals(self):
        """Round-trip literal values through ``annotate_globals``.

        Each value is attached as global ``foo`` on a 1x1 MatrixTable and on a
        1-row Table, then read back via ``.value``. The result is compared
        with a per-case check. NaN/inf floats compare by ``str`` because
        ``nan == nan`` is False.
        """
        def same_text(left, right):
            return str(left) == str(right)

        struct_values = [hl.Struct(a=None, b=5), hl.Struct(a='hello', b=10)]
        cases = [
            (5, hl.tint, operator.eq),
            (float('nan'), hl.tfloat32, same_text),
            (float('inf'), hl.tfloat64, same_text),
            (float('-inf'), hl.tfloat64, same_text),
            (1.111, hl.tfloat64, operator.eq),
            (struct_values, hl.tarray(hl.tstruct(a=hl.tstr, b=hl.tint)),
             operator.eq),
        ]

        matrix = hl.utils.range_matrix_table(1, 1)
        table = hl.utils.range_table(1, 1)
        for value, dtype, matches in cases:
            message = f"{value}, {dtype}"
            for dataset in (matrix, table):
                annotated = dataset.annotate_globals(foo=hl.literal(value, dtype))
                self.assertTrue(matches(annotated.foo.value, value), message)
Ejemplo n.º 5
0
        def collect_mappings_and_precomputed(selected):
            """Build per-geom aggregated mappings and evaluate stat precomputes.

            For every geom, the figure-level mapping (``selected["figure_mapping"]``)
            is merged with that geom's own mapping (``selected[geom_label]``).
            Each field that has an entry in ``self.scales`` is replaced by its
            ``transform_data`` output. The geom's stat is then asked for
            precompute expressions.

            Returns
            -------
            tuple
                ``(mappings, precomputed)``:

                * ``mappings`` is a list in ``self.geoms`` order.
                * ``precomputed`` is a Struct keyed by geom label. It comes from
                  one ``selected.aggregate`` call when any geom requested a
                  precompute. Otherwise every label maps to an empty Struct.
            """
            mappings = []
            stat_precomputes = {}
            for idx, geom in enumerate(self.geoms):
                label = make_geom_label(idx)

                mapping = selected["figure_mapping"].annotate(**selected[label])
                # Iterates the pre-scaling mapping's fields; rebinding
                # ``mapping`` inside the loop does not affect iteration.
                for field_name in mapping:
                    if field_name not in self.scales:
                        continue
                    scaled = self.scales[field_name].transform_data(
                        mapping[field_name])
                    mapping = mapping.annotate(**{field_name: scaled})

                mappings.append(mapping)
                stat_precomputes[label] = geom.get_stat().get_precomputes(mapping)

            # Is there anything to precompute?
            needs_aggregation = any(
                len(request) > 0 for request in stat_precomputes.values())
            if not needs_aggregation:
                empty = hl.Struct(
                    **{label: hl.Struct() for label in stat_precomputes})
                return mappings, empty
            return mappings, selected.aggregate(hl.struct(**stat_precomputes))
Ejemplo n.º 6
0
 def test_maximal_independent_set_on_floats(self):
     """``maximal_independent_set`` accepts a float-valued ``tie_breaker``.

     A single edge connects node a (x=3.0) and node b (x=2.82). The
     collected result (``keep=False``) is expected to contain node a only.
     """
     left = hl.struct(s="a", x=3.0)
     right = hl.struct(s="b", x=2.82)
     edges = hl.utils.range_table(1).annotate(l=left, r=right)
     wanted = [hl.Struct(node=hl.Struct(s="a", x=3.0))]

     def by_x(first, second):
         return first.x - second.x

     got = hl.maximal_independent_set(edges.l, edges.r, keep=False,
                                      tie_breaker=by_x).collect()
     assert got == wanted
Ejemplo n.º 7
0
    def test_filter_intervals_compound_partition_key(self):
        """``filter_intervals`` works on an interval over the full (locus, alleles) key.

        Expects 3 rows of ``sample.vcf`` to fall in the interval.
        """
        mt = hl.import_vcf(resource('sample.vcf'), min_partitions=20)
        mt = mt.annotate_rows(
            variant=hl.struct(locus=mt.locus, alleles=mt.alleles))
        mt = mt.key_rows_by('locus', 'alleles')

        lower = hl.Struct(locus=hl.Locus('20', 10639222), alleles=['A', 'T'])
        upper = hl.Struct(locus=hl.Locus('20', 10644700), alleles=['A', 'T'])
        filtered = hl.filter_intervals(mt, [hl.Interval(lower, upper)])
        self.assertEqual(filtered.count_rows(), 3)
Ejemplo n.º 8
0
def read(fname: str) -> 'DNDArray':
    """Read a DNDArray from ``fname`` with one table interval per block.

    The table is read once without custom intervals, only to learn the
    block-grid dimensions (``n_block_rows`` x ``n_block_cols``). It is then
    re-read with one ``Interval(Struct(r, c), Struct(r, c + 1))`` per block,
    in row-major order.
    """
    # read without good partitioning, just to get the globals
    grid = DNDArray(hl.read_table(fname))

    block_intervals = []
    for row in range(grid.n_block_rows):
        for col in range(grid.n_block_cols):
            block_intervals.append(
                hl.Interval(hl.Struct(r=row, c=col),
                            hl.Struct(r=row, c=col + 1)))

    return DNDArray(hl.read_table(fname, _intervals=block_intervals))
Ejemplo n.º 9
0
 def test_aggregate_by_key_partitioning(self):
     """Grouping by the table's existing key yields one aggregated row per key."""
     records = [
         {'k': 'foo', 'b': 1},
         {'k': 'bar', 'b': 2},
         {'k': 'bar', 'b': 2},
     ]
     row_type = hl.tstruct(k=hl.tstr, b=hl.tint32)
     ht1 = hl.Table.parallelize(records, row_type, key='k')

     grouped = ht1.group_by('k').aggregate(mean_b=hl.agg.mean(ht1.b))
     expected = {hl.Struct(k='foo', mean_b=1.0), hl.Struct(k='bar', mean_b=2.0)}
     self.assertEqual(set(grouped.collect()), expected)
Ejemplo n.º 10
0
    def test_row_joins_into_table(self):
        """Index Tables with MatrixTable row fields.

        Valid key, non-key and interval joins return the expected structs.
        Joins whose key types do not line up raise ``ExpressionException``.
        """
        rt = hl.utils.range_matrix_table(9, 13, 3)
        mt1 = rt.key_rows_by(idx=rt.row_idx)
        mt1 = mt1.select_rows(v=mt1.idx + 2)
        mt2 = rt.key_rows_by(idx=rt.row_idx, idx2=rt.row_idx + 1)
        mt2 = mt2.select_rows(v=mt2.idx + 2)

        t1 = hl.utils.range_table(10, 3)
        # t2 must be derived from t1 before t1 is re-selected below.
        t2 = t1.key_by(t1.idx, idx2=t1.idx + 1)
        t1 = t1.select(v=t1.idx + 2)
        t2 = t2.select(v=t2.idx + 2)

        tinterval1 = t1.key_by(k=hl.interval(t1.idx, t1.idx, True, True))
        tinterval1 = tinterval1.select(v=tinterval1.idx + 2)
        tinterval2 = t2.key_by(k=hl.interval(t2.key, t2.key, True, True))
        tinterval2 = tinterval2.select(v=tinterval2.idx + 2)

        values = [hl.Struct(v=i + 2) for i in range(9)]

        def assert_collects(make_join, expected):
            self.assertEqual(make_join().collect(), expected)

        def assert_rejected(*make_joins):
            for make_join in make_joins:
                with self.assertRaises(hl.expr.ExpressionException):
                    make_join().collect()

        # join on mt row key
        assert_collects(lambda: t1.index(mt1.row_key), values)
        assert_collects(lambda: t2.index(mt2.row_key), values)
        assert_collects(lambda: t1.index(mt1.idx), values)
        assert_collects(lambda: t2.index(mt2.idx, mt2.idx2), values)
        assert_collects(lambda: t1.index(mt2.idx), values)
        assert_rejected(lambda: t2.index(mt2.idx),
                        lambda: t2.index(mt1.row_key))

        # join on not mt row key
        assert_collects(lambda: t1.index(mt1.v),
                        [hl.Struct(v=i + 2) for i in range(2, 10)] + [None])
        assert_collects(lambda: t2.index(mt2.idx2, mt2.v),
                        [hl.Struct(v=i + 2) for i in range(1, 10)])
        assert_rejected(lambda: t2.index(mt2.v))

        # join on interval of first field of mt row key
        assert_collects(lambda: tinterval1.index(mt1.idx), values)
        assert_collects(lambda: tinterval1.index(mt1.row_key), values)
        assert_collects(lambda: tinterval1.index(mt2.idx), values)
        assert_rejected(lambda: tinterval1.index(mt2.row_key),
                        lambda: tinterval2.index(mt2.idx),
                        lambda: tinterval2.index(mt2.row_key),
                        lambda: tinterval2.index(mt2.idx, mt2.idx2))
Ejemplo n.º 11
0
    def test(self):
        """Exercise many expression functions on a one-row Table.

        The single row has a=4, b=1, c=3, d=5, i=True,
        h=Struct(a=5, b=3, c='winter') and j=Struct(x=3, y=2, z='summer').
        All annotations are evaluated. The deterministic results are checked
        against hand-computed values. Random draws (pcoin, rnorm, rpois,
        runif) and formatting-dependent outputs are only evaluated.
        """
        schema = hl.tstruct(a=hl.tint32, b=hl.tint32, c=hl.tint32, d=hl.tint32, e=hl.tstr,
                            f=hl.tarray(hl.tint32),
                            g=hl.tarray(
                                hl.tstruct(x=hl.tint32, y=hl.tint32, z=hl.tstr)),
                            h=hl.tstruct(a=hl.tint32, b=hl.tint32, c=hl.tstr),
                            i=hl.tbool,
                            j=hl.tstruct(x=hl.tint32, y=hl.tint32, z=hl.tstr))

        rows = [{'a': 4, 'b': 1, 'c': 3, 'd': 5,
                 'e': "hello", 'f': [1, 2, 3],
                 'g': [hl.Struct(x=1, y=5, z='banana')],
                 'h': hl.Struct(a=5, b=3, c='winter'),
                 'i': True,
                 'j': hl.Struct(x=3, y=2, z='summer')}]

        kt = hl.Table.parallelize(rows, schema)

        result = convert_struct_to_dict(kt.annotate(
            chisq=hl.chisq(kt.a, kt.b, kt.c, kt.d),
            ctt=hl.ctt(kt.a, kt.b, kt.c, kt.d, 5),
            dict=hl.dict(hl.zip([kt.a, kt.b], [kt.c, kt.d])),
            dpois=hl.dpois(4, kt.a),
            drop=kt.h.drop('b', 'c'),
            exp=hl.exp(kt.c),
            fet=hl.fisher_exact_test(kt.a, kt.b, kt.c, kt.d),
            hwe=hl.hardy_weinberg_p(1, 2, 1),
            index=hl.index(kt.g, 'z'),
            is_defined=hl.is_defined(kt.i),
            is_missing=hl.is_missing(kt.i),
            is_nan=hl.is_nan(hl.float64(kt.a)),
            json=hl.json(kt.g),
            log=hl.log(kt.a, kt.b),
            log10=hl.log10(kt.c),
            or_else=hl.or_else(kt.a, 5),
            or_missing=hl.or_missing(kt.i, kt.j),
            pchisqtail=hl.pchisqtail(kt.a, kt.b),
            pcoin=hl.rand_bool(0.5),
            pnorm=hl.pnorm(0.2),
            pow=2.0 ** kt.b,
            ppois=hl.ppois(kt.a, kt.b),
            qchisqtail=hl.qchisqtail(kt.a, kt.b),
            range=hl.range(0, 5, kt.b),
            rnorm=hl.rand_norm(0.0, kt.b),
            rpois=hl.rand_pois(kt.a),
            runif=hl.rand_unif(kt.b, kt.a),
            select=kt.h.select('c', 'b'),
            sqrt=hl.sqrt(kt.a),
            to_str=[hl.str(5), hl.str(kt.a), hl.str(kt.g)],
            where=hl.cond(kt.i, 5, 10)
        ).take(1)[0])

        # Deterministic spot checks against the single input row.
        self.assertEqual(result['drop'], {'a': 5})
        self.assertEqual(result['select'], {'c': 'winter', 'b': 3})
        self.assertTrue(result['is_defined'])
        self.assertFalse(result['is_missing'])
        self.assertFalse(result['is_nan'])
        self.assertEqual(result['or_else'], 4)
        self.assertEqual(result['or_missing'], {'x': 3, 'y': 2, 'z': 'summer'})
        self.assertEqual(result['pow'], 2.0)
        self.assertEqual(result['range'], [0, 1, 2, 3, 4])
        self.assertEqual(result['sqrt'], 2.0)
        self.assertEqual(result['to_str'][:2], ['5', '4'])
        self.assertEqual(result['where'], 5)
Ejemplo n.º 12
0
 def get_groupable_matrix():
     """Return a 100x20 range MatrixTable annotated with grouping fields.

     * Global ``foo`` is ``"foo"``.
     * Rows get ``group1`` (row_idx mod 6) and struct ``group2``.
     * Columns get ``group3`` (col_idx mod 6) and struct ``group4``.
     * Entries ``c``..``i`` copy the indices and group values; ``e`` is the
       constant ``"foo"``.
     """
     mt = hl.utils.range_matrix_table(n_rows=100, n_cols=20)
     mt = mt.annotate_globals(foo="foo")
     mt = mt.annotate_rows(group1=mt.row_idx % 6,
                           group2=hl.Struct(a=mt.row_idx % 6, b="foo"))
     mt = mt.annotate_cols(group3=mt.col_idx % 6,
                           group4=hl.Struct(a=mt.col_idx % 6, b="foo"))
     entry_fields = dict(c=mt.row_idx,
                         d=mt.col_idx,
                         e="foo",
                         f=mt.group1,
                         g=mt.group2.a,
                         h=mt.group3,
                         i=mt.group4.a)
     return mt.annotate_entries(**entry_fields)
Ejemplo n.º 13
0
def test_sampleqc_gq_dp():
    """Check ``hl.vds.sample_qc`` GQ/DP bin thresholds and HG00320's counts."""
    vds_path = os.path.join(resource('vds'), '1kg_chr22_5_samples.vds')
    qc = hl.vds.sample_qc(hl.vds.read_vds(vds_path))

    expected_globals = hl.Struct(gq_bins=(0, 20, 60),
                                 dp_bins=(0, 1, 10, 20, 30))
    assert hl.eval(qc.index_globals()) == expected_globals

    one_sample = qc.filter(qc.s == 'HG00320')
    one_sample = one_sample.select('bases_over_gq_threshold',
                                   'bases_over_dp_threshold')
    row = one_sample.collect()[0]
    expected_row = hl.Struct(s='HG00320',
                             bases_over_gq_threshold=(334822, 515, 82),
                             bases_over_dp_threshold=(334822, 10484, 388, 111, 52))
    assert row == expected_row
Ejemplo n.º 14
0
    def test_concordance(self):
        """Self-concordance puts every entry on the diagonal.

        When the dataset is compared with itself:

        * the 5x5 global matrix sums to rows * cols;
        * cell [0][0] is 0;
        * the diagonal holds the no-call, hom-ref, het and hom-var entry
          counts;
        * every off-diagonal cell is 0.

        The per-column matrices each sum to the row count, and the per-row
        matrices each sum to the column count. Both result tables must also
        be writable.
        """
        import os
        import tempfile

        dataset = get_dataset()
        glob_conc, cols_conc, rows_conc = hl.concordance(dataset, dataset)
        n_rows = dataset.count_rows()
        n_cols = dataset.count_cols()

        self.assertEqual(sum(sum(glob_conc[i]) for i in range(5)), n_rows * n_cols)

        counts = dataset.aggregate_entries(hl.Struct(n_het=agg.filter(dataset.GT.is_het(), agg.count()),
                                                     n_hom_ref=agg.filter(dataset.GT.is_hom_ref(),
                                                                          agg.count()),
                                                     n_hom_var=agg.filter(dataset.GT.is_hom_var(),
                                                                          agg.count()),
                                                     nNoCall=agg.filter(hl.is_missing(dataset.GT),
                                                                        agg.count())))

        self.assertEqual(glob_conc[0][0], 0)
        self.assertEqual(glob_conc[1][1], counts.nNoCall)
        self.assertEqual(glob_conc[2][2], counts.n_hom_ref)
        self.assertEqual(glob_conc[3][3], counts.n_het)
        self.assertEqual(glob_conc[4][4], counts.n_hom_var)
        for i in range(5):
            for j in range(5):
                if i != j:
                    self.assertEqual(glob_conc[i][j], 0)

        self.assertTrue(cols_conc.all(hl.sum(hl.flatten(cols_conc.concordance)) == n_rows))
        self.assertTrue(rows_conc.all(hl.sum(hl.flatten(rows_conc.concordance)) == n_cols))

        # Write each table to its own private, auto-cleaned location instead
        # of a fixed shared path where the second write clobbers the first.
        with tempfile.TemporaryDirectory() as tmp_dir:
            cols_conc.write(os.path.join(tmp_dir, 'cols_conc.ht'), overwrite=True)
            rows_conc.write(os.path.join(tmp_dir, 'rows_conc.ht'), overwrite=True)
Ejemplo n.º 15
0
def annotate_gene_transcripts_with_refseq_id(table_path,
                                             mane_select_transcripts_path):
    """Add MANE Select RefSeq ids to each gene's transcripts.

    Parameters
    ----------
    table_path
        Path to a Hail table with a ``transcripts`` array. Each element has
        ``transcript_id`` and ``transcript_version``.
    mane_select_transcripts_path
        Path to a Hail table with ``ensembl_id``, ``ensembl_version``,
        ``refseq_id`` and ``refseq_version`` fields. It is collected to the
        driver in full.

    Returns
    -------
    The genes table. Every transcript gains ``refseq_id`` and
    ``refseq_version``. Both are missing when no MANE row matches on
    (``transcript_id``, ``transcript_version``).
    """
    mane_select_transcripts = hl.read_table(mane_select_transcripts_path)

    # ensembl_id -> {ensembl_version -> Struct(refseq_id, refseq_version)}
    ensembl_to_refseq_map = {}
    for transcript in mane_select_transcripts.collect():
        # setdefault keeps every version seen for an Ensembl id. Assigning a
        # fresh inner dict would silently keep only the last one.
        versions = ensembl_to_refseq_map.setdefault(transcript.ensembl_id, {})
        versions[transcript.ensembl_version] = hl.Struct(
            refseq_id=transcript.refseq_id,
            refseq_version=transcript.refseq_version)

    ensembl_to_refseq_map = hl.literal(ensembl_to_refseq_map)

    genes = hl.read_table(table_path)

    def add_refseq_fields(transcript):
        # Two-level lookup; each level falls back to an empty/missing
        # default so unmatched transcripts get missing RefSeq fields.
        versions = ensembl_to_refseq_map.get(
            transcript.transcript_id,
            hl.empty_dict(
                hl.tstr, hl.tstruct(refseq_id=hl.tstr, refseq_version=hl.tstr)
            ),
        )
        refseq = versions.get(
            transcript.transcript_version,
            hl.struct(refseq_id=hl.null(hl.tstr),
                      refseq_version=hl.null(hl.tstr)),
        )
        return transcript.annotate(**refseq)

    genes = genes.annotate(transcripts=genes.transcripts.map(add_refseq_fields))

    return genes
Ejemplo n.º 16
0
def process_joins(obj, exprs, broadcast_f):
    """Apply the joins and broadcasts referenced by ``exprs`` to ``obj``.

    Each distinct ``Join`` node found in an expression's AST is applied once,
    via its ``join_func``. Within each expression, joins are applied in
    ascending ``idx`` order. All ``Broadcast`` nodes are bundled into one
    struct, serialized with ``_to_json``, and passed to ``broadcast_f``
    together with the struct's ``_jtype``.

    Returns
    -------
    tuple
        ``(joined, cleanup)``. ``cleanup(table)`` drops every temporary field
        introduced here that is still present in ``table``.
    """
    joined = obj
    temp_fields = []
    applied = set()
    broadcast_nodes = []

    def is_join(node):
        return isinstance(node, hail.expr.expr_ast.Join)

    def is_broadcast(node):
        return isinstance(node, hail.expr.expr_ast.Broadcast)

    for expr in exprs:
        found = expr._ast.search(is_join)
        for join in sorted(found, key=lambda node: node.idx):  # Make sure joins happen in order
            if join in applied:
                continue
            joined = join.join_func(joined)
            temp_fields.extend(join.temp_vars)
            applied.add(join)
        broadcast_nodes.extend(expr._ast.search(is_broadcast))

    if broadcast_nodes:
        bundle_type = hail.tstruct(**{node.uid: node.dtype for node in broadcast_nodes})
        temp_fields.extend(list(bundle_type))
        bundle = hail.Struct(**{node.uid: node.value for node in broadcast_nodes})
        joined = broadcast_f(joined, bundle_type._to_json(bundle), bundle_type._jtype)

    def cleanup(table):
        still_present = [name for name in temp_fields if name in table._fields]
        return table.drop(*still_present)

    return joined, cleanup
Ejemplo n.º 17
0
    def test_aggregate2(self):
        """Check a broad set of aggregators through ``group_by(...).aggregate``.

        Both rows have ``status == 0``, so they form one group. qPheno values
        are 3 and 13, and the GT values are 0/0 and 0/1.
        """
        schema = hl.tstruct(status=hl.tint32, GT=hl.tcall, qPheno=hl.tint32)
        rows = [
            {'status': 0, 'GT': hl.Call([0, 0]), 'qPheno': 3},
            {'status': 0, 'GT': hl.Call([0, 1]), 'qPheno': 13},
        ]
        kt = hl.Table.parallelize(rows, schema)

        pheno = kt.qPheno
        constant = hl.Struct(a=5, b="foo", c=hl.Struct(banana='apple'))

        def collect_each(elt):
            return agg.collect(elt)

        grouped = kt.group_by(status=kt.status).aggregate(
            x1=agg.collect(pheno * 2),
            x2=agg.explode(collect_each, [pheno, pheno + 1]),
            x3=agg.min(pheno),
            x4=agg.max(pheno),
            x5=agg.sum(pheno),
            x6=agg.product(hl.int64(pheno)),
            x7=agg.count(),
            x8=agg.count_where(pheno == 3),
            x9=agg.fraction(pheno == 1),
            x10=agg.stats(hl.float64(pheno)),
            x11=agg.hardy_weinberg_test(kt.GT),
            x13=agg.inbreeding(kt.GT, 0.1),
            x14=agg.call_stats(kt.GT, ["A", "T"]),
            x15=agg.collect(constant)[0],
            x16=agg.collect(constant.c.banana)[0],
            x17=agg.explode(collect_each, hl.null(hl.tarray(hl.tint32))),
            x18=agg.explode(collect_each, hl.null(hl.tset(hl.tint32))),
            x19=agg.take(kt.GT, 1, ordering=-pheno),
        )
        result = convert_struct_to_dict(grouped.take(1)[0])

        expected = {
            'status': 0,
            'x1': [6, 26],
            'x2': [3, 4, 13, 14],
            'x3': 3,
            'x4': 13,
            'x5': 16,
            'x6': 39,
            'x7': 2,
            'x8': 1,
            'x9': 0.0,
            'x10': {'min': 3.0, 'max': 13.0, 'sum': 16.0, 'stdev': 5.0, 'n': 2, 'mean': 8.0},
            'x11': {'het_freq_hwe': 0.5, 'p_value': 0.5},
            'x13': {'n_called': 2, 'expected_homs': 1.64, 'f_stat': -1.777777777777777,
                    'observed_homs': 1},
            'x14': {'AC': [3, 1], 'AF': [0.75, 0.25], 'AN': 4, 'homozygote_count': [1, 0]},
            'x15': {'a': 5, 'c': {'banana': 'apple'}, 'b': 'foo'},
            'x16': 'apple',
            'x17': [],
            'x18': [],
            'x19': [hl.Call([0, 1])],
        }

        self.maxDiff = None
        self.assertDictEqual(result, expected)
Ejemplo n.º 18
0
    def test_DB(self):
        """Annotate a simulated dataset with DANN scores via ``hl.experimental.DB``.

        Collects the ``DANN`` struct for every row where it is defined. The
        scores are compared with known values: the result must have the same
        number of rows, with each score matching to ``assertAlmostEqual``
        precision.
        """
        mt = hl.balding_nichols_model(n_populations=3,
                                      n_samples=50,
                                      n_variants=10010)

        db = hl.experimental.DB()
        mt = db.annotate_rows_db(mt, "DANN")

        actual = mt.filter_rows(hl.is_defined(mt.DANN)).DANN.collect()
        expected = [
            hl.Struct(score=0.3618202027281013),
            hl.Struct(score=0.36516159615040267),
            hl.Struct(score=0.3678246364006052),
            hl.Struct(score=0.3697632743148331)
        ]
        # Guard the length so an empty or truncated result cannot pass vacuously.
        self.assertEqual(len(actual), len(expected))
        for got, want in zip(actual, expected):
            # Compare the float fields: assertAlmostEqual subtracts its
            # arguments, which is not defined for Structs.
            self.assertAlmostEqual(got.score, want.score)
Ejemplo n.º 19
0
 def test_value_same_after_parsing(self):
     """Each literal from ``self.values()`` survives a TableMapGlobals round trip.

     The literal is inserted as global ``foo``, and the evaluated globals must
     equal ``Struct(foo=value)``.
     """
     for typ, value in self.values():
         literal = ir.Literal(typ, value)
         insert_foo = ir.InsertFields(ir.Ref("global"), [("foo", literal)], None)
         table_ir = ir.TableMapGlobals(ir.TableRange(1, 1), insert_foo)
         evaluated = hl.eval(hl.Table(table_ir).index_globals())
         self.assertEqual(evaluated, hl.Struct(foo=value))
Ejemplo n.º 20
0
 def test_value_same_after_parsing(self):
     """Each literal from ``self.values()`` survives a TableMapGlobals round trip.

     The literal is inserted as global ``foo``. The globals read back through
     ``Table._from_ir`` must equal ``Struct(foo=value)``.
     """
     for t, v in self.values():
         row_v = ir.Literal(t, v)
         map_globals_ir = ir.TableMapGlobals(
             ir.TableRange(1, 1),
             ir.InsertFields(ir.Ref("global", hl.tstruct()),
                             [("foo", row_v)]))
         new_globals = hl.Table._from_ir(map_globals_ir).globals.value
         # assertEqual, not the deprecated assertEquals alias (removed in Python 3.12).
         self.assertEqual(new_globals, hl.Struct(foo=v))
Ejemplo n.º 21
0
    def __init__(self, path, intervals, filter_intervals):
        """Store reader configuration and normalize ``intervals``.

        Parameters
        ----------
        path
            Stored as ``self.path``.
        intervals
            ``None`` or a value whose imputed hail type is an array of
            intervals. When the interval point type is not already a struct,
            ``self._interval_type`` wraps it in a single-field struct and each
            interval endpoint is wrapped the same way.
        filter_intervals
            Stored as ``self.filter_intervals``.

        Raises
        ------
        TypeError
            If ``intervals`` is not an array of intervals.
        """
        if intervals is not None:
            t = hl.expr.impute_type(intervals)
            # Reject when *either* requirement fails. ``or`` also short-circuits
            # before ``t.element_type`` is read on a non-array type.
            if not isinstance(t, hl.tarray) or not isinstance(t.element_type, hl.tinterval):
                raise TypeError("'intervals' must be an array of tintervals")
            pt = t.element_type.point_type
            if isinstance(pt, hl.tstruct):
                self._interval_type = t
            else:
                self._interval_type = hl.tarray(hl.tinterval(hl.tstruct(__point=pt)))

        self.path = path
        self.filter_intervals = filter_intervals
        if intervals is not None and t != self._interval_type:
            # The point type was wrapped above; wrap the interval values to match.
            self.intervals = [hl.Interval(hl.Struct(__point=i.start),
                                          hl.Struct(__point=i.end),
                                          i.includes_start,
                                          i.includes_end) for i in intervals]
        else:
            self.intervals = intervals
Ejemplo n.º 22
0
    def test_multiple_files_variant_filtering(self):
        """``import_bgen`` across several files honours the ``variants`` filter.

        The filtered import must contain exactly the 6 requested variants. It
        must also match an unfiltered import restricted to those row keys.
        """
        bgen_file = [resource(name) for name in
                     ('random-b.bgen', 'random-c.bgen', 'random-a.bgen')]
        hl.index_bgen(bgen_file)

        alleles = ['A', 'G']
        positions = (11, 13, 29, 28, 1, 12)
        desired_variants = [hl.Struct(locus=hl.Locus('20', pos), alleles=alleles)
                            for pos in positions]

        actual = hl.import_bgen(bgen_file, ['GT'],
                                n_partitions=10,
                                variants=desired_variants)
        self.assertEqual(actual.count_rows(), 6)

        everything = hl.import_bgen(bgen_file, ['GT'])
        self.assertEqual(everything.count(), (30, 10))

        wanted_keys = hl.set(desired_variants)
        expected = everything.filter_rows(wanted_keys.contains(everything.row_key))
        self.assertTrue(expected._same(actual))
Ejemplo n.º 23
0
def test_pc_relate_simple_example():
    """``pc_relate`` on a 4-sample x 8-variant toy dataset with fixed scores.

    ``genotype_indices[sample][variant]`` is converted with
    ``unphased_diploid_gt_index_call``. Per-sample scores are passed via
    ``scores_expr``. The expected table's ``i``/``j`` keys are wrapped as
    ``Struct(col_idx=...)`` before comparison with ``_same``.
    """
    genotype_indices = hl.literal(
        [[0, 0, 0, 0, 1, 1, 1, 1],
         [0, 0, 1, 1, 0, 0, 1, 1],
         [0, 1, 0, 1, 0, 1, 0, 1],
         [0, 0, 1, 1, 0, 0, 1, 1]])
    pc_scores = hl.literal([[0, 1], [1, 1], [1, 0], [0, 0]])

    mt = hl.utils.range_matrix_table(n_rows=8, n_cols=4)
    gt_index = genotype_indices[mt.col_idx][mt.row_idx]
    mt = mt.annotate_entries(GT=hl.unphased_diploid_gt_index_call(gt_index))
    mt = mt.annotate_cols(scores=pc_scores[mt.col_idx])
    result = hl.pc_relate(mt.GT, min_individual_maf=0, scores_expr=mt.scores)

    # Four of the six pairs share exactly these statistics.
    shared = dict(kin=-0.14570713364640647,
                  ibd0=1.4823511628401964,
                  ibd1=-0.38187379109476693,
                  ibd2=-0.10047737174542953)
    expected = [
        hl.Struct(i=0, j=1, **shared),
        hl.Struct(i=0, j=2, kin=0.16530591922102378,
                  ibd0=0.5234783206257841, ibd1=0.2918196818643366, ibd2=0.18470199750987923),
        hl.Struct(i=0, j=3, **shared),
        hl.Struct(i=1, j=2, **shared),
        hl.Struct(i=1, j=3, kin=0.14285714285714285,
                  ibd0=0.7027734170591313, ibd1=0.02302459445316596, ibd2=0.2742019884877027),
        hl.Struct(i=2, j=3, **shared),
    ]
    expected_ht = hl.Table.parallelize(expected)
    expected_ht = expected_ht.key_by(i=hl.struct(col_idx=expected_ht.i),
                                     j=hl.struct(col_idx=expected_ht.j))
    assert expected_ht._same(result)
Ejemplo n.º 24
0
def test_combiner_works():
    """Combine two chr20 GVCFs over three explicit partitions.

    The combined variant data must have one partition per input interval.
    Both the variant and reference data must be fully countable.
    """
    gvcf_paths = [resource(p) for p in
                  ['gvcfs/HG00096.g.vcf.gz', 'gvcfs/HG00268.g.vcf.gz']]

    def chr20(position):
        return hl.Struct(locus=hl.Locus('chr20', position, reference_genome='GRCh38'))

    bounds = [(17821257, 18708366), (18708367, 19776611), (19776612, 21144633)]
    parts = [hl.Interval(start=chr20(lo), end=chr20(hi), includes_end=True)
             for lo, hi in bounds]

    imported = hl.import_gvcfs(gvcf_paths, parts, reference_genome='GRCh38',
                               array_elements_required=False)
    reference_rows = imported[0].filter_rows(hl.is_defined(imported[0].info.END))
    entry_to_keep = defined_entry_fields(reference_rows, 100_000) - {'GT', 'PGT', 'PL'}

    def with_missing_info_fields(mt):
        return mt.annotate_rows(info=mt.info.annotate(
            MQ_DP=hl.missing(hl.tint32),
            VarDP=hl.missing(hl.tint32),
            QUALapprox=hl.missing(hl.tint32)))

    transformed = [transform_gvcf(with_missing_info_fields(mt),
                                  reference_entry_fields_to_keep=entry_to_keep)
                   for mt in imported]
    comb = combine_variant_datasets(transformed)
    assert len(parts) == comb.variant_data.n_partitions()
    comb.variant_data._force_count_rows()
    comb.reference_data._force_count_rows()
Ejemplo n.º 25
0
def test_vcf_vds_combiner_equivalence():
    """The VDS combiner and the legacy sparse-MT combiner produce the same data.

    Both combiners run on the same two chr20 GVCFs. The VDS result is
    converted with ``to_merged_sparse_mt`` (dropping ``RGQ``). The legacy
    result's entry fields are harmonized and its rows re-keyed before the
    two are compared with ``_same``.
    """
    import hail.experimental.vcf_combiner.vcf_combiner as vcf
    # Distinct alias: the combined dataset below is named ``vds``, which
    # previously shadowed this module.
    import hail.vds.combiner as vds_combiner
    _paths = ['gvcfs/HG00096.g.vcf.gz', 'gvcfs/HG00268.g.vcf.gz']
    paths = [resource(p) for p in _paths]
    parts = [
        hl.Interval(start=hl.Struct(locus=hl.Locus('chr20', 17821257, reference_genome='GRCh38')),
                    end=hl.Struct(locus=hl.Locus('chr20', 18708366, reference_genome='GRCh38')),
                    includes_end=True),
        hl.Interval(start=hl.Struct(locus=hl.Locus('chr20', 18708367, reference_genome='GRCh38')),
                    end=hl.Struct(locus=hl.Locus('chr20', 19776611, reference_genome='GRCh38')),
                    includes_end=True),
        hl.Interval(start=hl.Struct(locus=hl.Locus('chr20', 19776612, reference_genome='GRCh38')),
                    end=hl.Struct(locus=hl.Locus('chr20', 21144633, reference_genome='GRCh38')),
                    includes_end=True)
    ]
    vcfs = [mt.annotate_rows(info=mt.info.annotate(
        MQ_DP=hl.missing(hl.tint32),
        VarDP=hl.missing(hl.tint32),
        QUALapprox=hl.missing(hl.tint32)))
            for mt in hl.import_gvcfs(paths, parts, reference_genome='GRCh38',
                                      array_elements_required=False)]
    entry_to_keep = defined_entry_fields(vcfs[0].filter_rows(hl.is_defined(vcfs[0].info.END)), 100_000) - {'GT', 'PGT', 'PL'}
    vds = vds_combiner.combine_variant_datasets(
        [vds_combiner.transform_gvcf(mt, reference_entry_fields_to_keep=entry_to_keep) for mt in vcfs])
    smt = vcf.combine_gvcfs([vcf.transform_gvcf(mt) for mt in vcfs])
    smt_from_vds = hl.vds.to_merged_sparse_mt(vds).drop('RGQ')
    smt = smt.select_entries(*smt_from_vds.entry)  # harmonize fields and order
    smt = smt.key_rows_by('locus', 'alleles')
    assert smt._same(smt_from_vds)
Ejemplo n.º 26
0
    def test_loop_with_struct_of_strings(self):
        """``hl.experimental.loop`` can carry a struct of strings as its state.

        While ``s1`` is not longer than ``s2``, each step moves the last
        character of ``s2`` onto the end of ``s1``. Starting from ("a",
        "gfedcb"), the loop ends at ("abcd", "gfe").
        """
        def step(recur, state):
            finished = hl.len(state.s1) > hl.len(state.s2)
            advanced = hl.struct(s1=state.s1 + state.s2[-1],
                                 s2=state.s2[:-1])
            return hl.if_else(finished, state, recur(advanced))

        state_type = hl.tstruct(s1=hl.tstr, s2=hl.tstr)
        start = hl.struct(s1="a", s2="gfedcb")
        final = hl.eval(hl.experimental.loop(step, state_type, start))
        assert final == hl.Struct(s1="abcd", s2="gfe")
Ejemplo n.º 27
0
 def test_summarize_variants(self):
     """``summarize_variants`` counts variants, contigs, allele types and allele counts."""
     mt = hl.utils.range_matrix_table(3, 3)
     variant_by_row = hl.literal({
         0: hl.Struct(locus=hl.Locus('1', 1), alleles=['A', 'T', 'C']),
         1: hl.Struct(locus=hl.Locus('2', 1), alleles=['A', 'AT', '@']),
         2: hl.Struct(locus=hl.Locus('2', 1), alleles=['AC', 'GT']),
     })
     mt = mt.annotate_rows(**variant_by_row[mt.row_idx])
     mt = mt.key_rows_by('locus', 'alleles')

     summary = hl.summarize_variants(mt, show=False)
     expected_allele_types = {'SNP': 2, 'MNP': 1, 'Unknown': 1, 'Insertion': 1}
     self.assertEqual(summary.n_variants, 3)
     self.assertEqual(summary.contigs, {'1': 1, '2': 2})
     self.assertEqual(summary.allele_types, expected_allele_types)
     self.assertEqual(summary.allele_counts, {2: 1, 3: 2})
Ejemplo n.º 28
0
    def test_transmute(self):
        """``transmute`` adds ``h`` and removes the fields it referenced.

        Afterwards only ``d``, ``e``, ``f`` and ``h`` remain. ``h`` is missing
        for the row whose ``g`` is missing.
        """
        schema = hl.tstruct(a=hl.tint32, b=hl.tint32, c=hl.tint32, d=hl.tint32,
                            e=hl.tstr, f=hl.tarray(hl.tint32),
                            g=hl.tstruct(x=hl.tbool, y=hl.tint32))
        rows = [
            {'a': 4, 'b': 1, 'c': 3, 'd': 5, 'e': "hello", 'f': [1, 2, 3], 'g': {'x': True, 'y': 2}},
            {'a': 0, 'b': 5, 'c': 13, 'd': -1, 'e': "cat", 'f': [], 'g': {'x': True, 'y': 2}},
            {'a': 4, 'b': 2, 'c': 20, 'd': 3, 'e': "dog", 'f': [5, 6, 7], 'g': None},
        ]
        table = hl.Table.parallelize(rows, schema)
        table = table.transmute(h=table.a + table.b + table.c + table.g.y)

        collected = table.select('h').collect()
        self.assertEqual(list(table.row), ['d', 'e', 'f', 'h'])
        expected = [hl.Struct(h=10), hl.Struct(h=20), hl.Struct(h=None)]
        self.assertEqual(collected, expected)
Ejemplo n.º 29
0
 def values(self):
     """Return ``(hail type, Python value)`` pairs used by round-trip tests."""
     return [
         (hl.tbool, True),
         (hl.tint32, 0),
         (hl.tint64, 0),
         (hl.tfloat32, 0.5),
         (hl.tfloat64, 0.5),
         (hl.tstr, "foo"),
         (hl.tstruct(x=hl.tint32), hl.Struct(x=0)),
         (hl.tarray(hl.tint32), [0, 1, 4]),
         (hl.tset(hl.tint32), {0, 1, 4}),
         (hl.tdict(hl.tstr, hl.tint32), {"a": 0, "b": 1, "c": 4}),
         (hl.tinterval(hl.tint32), hl.Interval(0, 1, True, False)),
         (hl.tlocus(hl.default_reference()), hl.Locus("1", 1)),
         (hl.tcall, hl.Call([0, 1])),
     ]
Ejemplo n.º 30
0
    def test_table_head_returns_right_number(self):
        """``head(n)`` keeps ``min(n, 10)`` rows of a 10-row table.

        Checked through both ``count()`` and ``_force_count()`` on a range
        table, a parallelized table and a cached range table.
        """
        rt = hl.utils.range_table(10, 11)
        par = hl.Table.parallelize([hl.Struct(x=x) for x in range(10)], schema='struct{x: int32}', n_partitions=11)

        # test TableRange and TableParallelize rewrite rules
        tables = [rt, par, rt.cache()]
        for table in tables:
            for n, expected_rows in ((10, 10), (9, 9), (11, 10), (0, 0)):
                self.assertEqual(table.head(n).count(), expected_rows)
                self.assertEqual(table.head(n)._force_count(), expected_rows)