Beispiel #1
0
 def _promote_scalar(self, typ):
     """Cast this scalar to the numeric type ``typ``.

     ``typ`` is compared against ``tint32``, ``tint64`` and ``tfloat32`` in
     that order and ``self`` is handed to the matching ``hail`` conversion
     function. Anything else must be ``tfloat64`` (checked with an
     assertion) and is converted with ``hail.float64``.
     """
     candidates = (
         (tint32, hail.int32),
         (tint64, hail.int64),
         (tfloat32, hail.float32),
     )
     for candidate, convert in candidates:
         if typ == candidate:
             return convert(self)

     # Only float64 is left as a valid promotion target.
     assert typ == tfloat64
     return hail.float64(self)
Beispiel #2
0
    def test_ibd(self):
        """Compare ``hl.identity_by_descent`` against PLINK's ``--genome`` output.

        The dataset is exported to VCF, PLINK is run on it through the shell,
        and every row Hail returns is checked against the matching PLINK
        pair: Z0/Z1/Z2/PI_HAT to four decimal places and IBS0/IBS1/IBS2
        exactly.
        """
        dataset = self.get_dataset()

        def plinkify(ds, min=None, max=None):
            """Run PLINK on ``ds`` and parse its ``.genome`` result file.

            ``min`` / ``max`` are forwarded as ``--min`` / ``--max`` whenever
            they are not ``None``. Returns a dict keyed by
            ``(row[1], row[3])`` whose values are
            ``([Z0, Z1, Z2, PI_HAT], [IBS0, IBS1, IBS2])``.
            """
            vcf = utils.new_temp_file(prefix="plink", suffix="vcf")
            plinkpath = utils.new_temp_file(prefix="plink")
            hl.export_vcf(ds, vcf)
            # Test against None rather than truthiness: 0.0 is a legitimate
            # threshold and must reach PLINK just as it reaches Hail.
            threshold_string = "{} {}".format(
                "--min {}".format(min) if min is not None else "",
                "--max {}".format(max) if max is not None else "")

            plink_command = "plink --double-id --allow-extra-chr --vcf {} --genome full --out {} {}" \
                .format(utils.uri_path(vcf),
                        utils.uri_path(plinkpath),
                        threshold_string)
            result_file = utils.uri_path(plinkpath + ".genome")

            syscall(plink_command, shell=True, stdout=DEVNULL, stderr=DEVNULL)

            # Columns of a .genome line after ``strip().split()``:
            #   fid1, iid1, fid2, iid2, rt, ez, z0, z1, z2, pihat, phe,
            #   dst, ppc, ratio, ibs0, ibs1, ibs2, homhom, hethet
            # so indices 1 and 3 are the two sample IDs, 6:10 is
            # z0..pihat and 14:17 is ibs0..ibs2.
            #
            # Rows of the Hail result carry:
            #   i (iid1), j (iid2), ibd: {Z0, Z1, Z2, PI_HAT}, ibs0, ibs1, ibs2
            results = {}
            with open(result_file) as f:
                f.readline()  # skip the header line
                for line in f:
                    row = line.strip().split()
                    results[(row[1], row[3])] = (list(map(float, row[6:10])),
                                                 list(map(int, row[14:17])))
            return results

        def compare(ds, min=None, max=None):
            """Assert that every Hail IBD row matches PLINK for the same pair."""
            plink_results = plinkify(ds, min, max)
            hail_results = hl.identity_by_descent(ds, min=min, max=max).collect()

            for row in hail_results:
                key = (row.i, row.j)
                self.assertAlmostEqual(plink_results[key][0][0], row.ibd.Z0, places=4)
                self.assertAlmostEqual(plink_results[key][0][1], row.ibd.Z1, places=4)
                self.assertAlmostEqual(plink_results[key][0][2], row.ibd.Z2, places=4)
                self.assertAlmostEqual(plink_results[key][0][3], row.ibd.PI_HAT, places=4)
                self.assertEqual(plink_results[key][1][0], row.ibs0)
                self.assertEqual(plink_results[key][1][1], row.ibs1)
                self.assertEqual(plink_results[key][1][2], row.ibs2)

        compare(dataset)
        compare(dataset, min=0.0, max=1.0)
        # The remaining calls only check that a row-field expression (float64
        # and float32) is accepted as the second argument; nothing is asserted
        # about their results.
        dataset = dataset.annotate_rows(dummy_maf=0.01)
        hl.identity_by_descent(dataset, dataset['dummy_maf'], min=0.0, max=1.0)
        hl.identity_by_descent(dataset, hl.float32(dataset['dummy_maf']), min=0.0, max=1.0)
Beispiel #3
0
    def test_maximal_independent_set(self):
        """Exercise ``hl.maximal_independent_set`` on a 10-edge graph.

        Each edge joins ``idx`` to ``idx + 10``. With the ``l - r``
        tie breaker the surviving nodes are expected to be 0..9, and the
        result row type a single int64 ``node`` field. Mismatched node types,
        node expressions from different tables, and literals must all raise
        ``ValueError``.
        """
        # prefer to remove nodes with higher index
        base = hl.utils.range_table(10)
        graph = base.select(
            i=hl.int64(base.idx),
            j=hl.int64(base.idx + 10),
            bad_type=hl.float32(base.idx),
        )

        mis_table = hl.maximal_independent_set(graph.i, graph.j, True, lambda l, r: l - r)
        kept_nodes = []
        for row in mis_table.collect():
            kept_nodes.append(row['node'])
        self.assertEqual(sorted(kept_nodes), list(range(0, 10)))
        self.assertEqual(mis_table.row.dtype, hl.tstruct(node=hl.tint64))

        with self.assertRaises(ValueError):
            hl.maximal_independent_set(graph.i, graph.bad_type, True)
        with self.assertRaises(ValueError):
            hl.maximal_independent_set(graph.i, hl.utils.range_table(10).idx, True)
        with self.assertRaises(ValueError):
            hl.maximal_independent_set(hl.literal(1), hl.literal(2), True)
Beispiel #4
0
    def test_maximal_independent_set(self):
        """Exercise ``hl.maximal_independent_set`` on a 10-edge graph.

        Each edge joins ``idx`` to ``idx + 10``. With the ``l - r``
        tie breaker the surviving nodes are expected to be 0..9, and both the
        row type and the key type of the result a single int64 ``node``
        field. Mismatched node types, node expressions from different tables,
        and literals must all raise ``ValueError``.
        """
        # prefer to remove nodes with higher index
        source = hl.utils.range_table(10)
        graph = source.select(
            i=hl.int64(source.idx),
            j=hl.int64(source.idx + 10),
            bad_type=hl.float32(source.idx),
        )

        mis_table = hl.maximal_independent_set(graph.i, graph.j, True, lambda l, r: l - r)
        survivors = sorted(row['node'] for row in mis_table.collect())
        self.assertEqual(survivors, list(range(0, 10)))

        node_struct = hl.tstruct(node=hl.tint64)
        self.assertEqual(mis_table.row.dtype, node_struct)
        self.assertEqual(mis_table.key.dtype, hl.tstruct(node=hl.tint64))

        with self.assertRaises(ValueError):
            hl.maximal_independent_set(graph.i, graph.bad_type, True)
        with self.assertRaises(ValueError):
            hl.maximal_independent_set(graph.i, hl.utils.range_table(10).idx, True)
        with self.assertRaises(ValueError):
            hl.maximal_independent_set(hl.literal(1), hl.literal(2), True)
def vcf_to_mt(splice_ai_snvs_path, splice_ai_indels_path, genome_version):
    """
    Import the SpliceAI SNV and indel VCFs into one annotated matrix table.

    After import the table is restricted to a single locus interval, the
    ``info.SpliceAI`` field is replaced by its four delta scores as floats,
    and ``info.max_DS`` / ``info.splice_consequence`` are added.

    :param splice_ai_snvs_path: source location of the SNV VCF
    :param splice_ai_indels_path: source location of the indel VCF
    :param genome_version: build suffix appended to ``GRCh``; ``"37"`` keeps
        the interval ``1-MT``, any other value keeps ``chr1-chrY``, and
        ``"38"`` additionally applies ``NO_CHR_TO_CHR_CONTIG_RECODING``
    :return: matrix table
    """

    logger.info("==> reading in splice_ai vcfs: %s, %s" %
                (splice_ai_snvs_path, splice_ai_indels_path))

    # for 37, extract to MT, for 38, MT not included.
    locus_range = "1-MT" if genome_version == "37" else "chr1-chrY"
    recoding = NO_CHR_TO_CHR_CONTIG_RECODING if genome_version == "38" else None
    reference = f"GRCh{genome_version}"

    vcf_paths = [splice_ai_snvs_path, splice_ai_indels_path]
    mt = hl.import_vcf(
        vcf_paths,
        reference_genome=reference,
        contig_recoding=recoding,
        force_bgz=True,
        min_partitions=10000,
        skip_invalid_loci=True,
    )
    kept_intervals = [
        hl.parse_locus_interval(locus_range, reference_genome=reference)
    ]
    mt = hl.filter_intervals(mt, kept_intervals)

    # The first SpliceAI entry is a |-delimited record; pieces 2..5 are the
    # delta scores, which are converted to float32.
    pieces = mt.info.SpliceAI[0].split(delim="\\|")
    delta_scores = pieces[2:6]
    info_with_scores = mt.info.annotate(
        SpliceAI=hl.map(lambda x: hl.float32(x), delta_scores))
    mt = mt.annotate_rows(info=info_with_scores)

    # The delta score array is ordered DS_AG, DS_AL, DS_DG, DS_DL, so the
    # labels below line up with it position by position.
    consequences = hl.literal(
        ["Acceptor gain", "Acceptor loss", "Donor gain", "Donor loss"])

    # info.max_DS is the largest of the four delta scores.
    info_with_max = mt.info.annotate(max_DS=hl.max(mt.info.SpliceAI))
    mt = mt.annotate_rows(info=info_with_max)

    # Label the row with the consequence whose score equals max_DS, unless
    # that maximum is not positive.
    has_signal = mt.info.max_DS > 0
    strongest = consequences[mt.info.SpliceAI.index(mt.info.max_DS)]
    info_with_label = mt.info.annotate(
        splice_consequence=hl.if_else(has_signal, strongest, "No consequence"))
    mt = mt.annotate_rows(info=info_with_label)
    return mt
Beispiel #6
0
def create_all_values():
    """Return a struct expression covering many Hail value types.

    Fields come in defined and missing (``hl.null``) flavours: numeric
    scalars, nested structs, sets, dicts, loci, an interval, calls and
    tuples.
    """
    # Insertion order of the dict is the field order of the struct.
    fields = {
        'f32': hl.float32(3.14),
        'i64': hl.int64(-9),
        'm': hl.null(hl.tfloat64),
        'astruct': hl.struct(a=hl.null(hl.tint32), b=5.5),
        'mstruct': hl.null(hl.tstruct(x=hl.tint32, y=hl.tstr)),
        'aset': hl.set(['foo', 'bar', 'baz']),
        'mset': hl.null(hl.tset(hl.tfloat64)),
        'd': hl.dict({
            hl.array(['a', 'b']): 0.5,
            hl.array(['x', hl.null(hl.tstr), 'z']): 0.3,
        }),
        'md': hl.null(hl.tdict(hl.tint32, hl.tstr)),
        'h38': hl.locus('chr22', 33878978, 'GRCh38'),
        'ml': hl.null(hl.tlocus('GRCh37')),
        'i': hl.interval(hl.locus('1', 999), hl.locus('1', 1001)),
        'c': hl.call(0, 1),
        'mc': hl.null(hl.tcall),
        't': hl.tuple([hl.call(1, 2, phased=True), 'foo', hl.null(hl.tstr)]),
        'mt': hl.null(hl.ttuple(hl.tlocus('GRCh37'), hl.tbool)),
    }
    return hl.struct(**fields)
Beispiel #7
0
def create_all_values():
    """Build one struct expression holding a sample of many Hail types.

    Each kind of value (numeric, struct, set, dict, locus, interval, call,
    tuple) appears as a defined field, and most also as a missing
    (``hl.null``) one.
    """
    # (field name, value) pairs, listed in the order the struct exposes them.
    named_values = [
        ('f32', hl.float32(3.14)),
        ('i64', hl.int64(-9)),
        ('m', hl.null(hl.tfloat64)),
        ('astruct', hl.struct(a=hl.null(hl.tint32), b=5.5)),
        ('mstruct', hl.null(hl.tstruct(x=hl.tint32, y=hl.tstr))),
        ('aset', hl.set(['foo', 'bar', 'baz'])),
        ('mset', hl.null(hl.tset(hl.tfloat64))),
        ('d', hl.dict({hl.array(['a', 'b']): 0.5,
                       hl.array(['x', hl.null(hl.tstr), 'z']): 0.3})),
        ('md', hl.null(hl.tdict(hl.tint32, hl.tstr))),
        ('h38', hl.locus('chr22', 33878978, 'GRCh38')),
        ('ml', hl.null(hl.tlocus('GRCh37'))),
        ('i', hl.interval(hl.locus('1', 999), hl.locus('1', 1001))),
        ('c', hl.call(0, 1)),
        ('mc', hl.null(hl.tcall)),
        ('t', hl.tuple([hl.call(1, 2, phased=True), 'foo', hl.null(hl.tstr)])),
        ('mt', hl.null(hl.ttuple(hl.tlocus('GRCh37'), hl.tbool))),
    ]
    return hl.struct(**dict(named_values))
Beispiel #8
0
def create_all_values_datasets():
    """Create a cached Table and MatrixTable annotated with many value types.

    A struct of sample values (defined and missing) is attached to a
    5-row table as globals (``global_`` prefix) and row fields, and to a
    3x2 matrix table as globals, row, column and entry fields (``global_``,
    ``row_``, ``col_``, ``entry_`` prefixes).

    Returns the pair ``(table, matrix_table)``.
    """
    # Insertion order of the dict is the field order of the struct.
    sample_fields = {
        'f32': hl.float32(3.14),
        'i64': hl.int64(-9),
        'm': hl.null(hl.tfloat64),
        'astruct': hl.struct(a=hl.null(hl.tint32), b=5.5),
        'mstruct': hl.null(hl.tstruct(x=hl.tint32, y=hl.tstr)),
        'aset': hl.set(['foo', 'bar', 'baz']),
        'mset': hl.null(hl.tset(hl.tfloat64)),
        'd': hl.dict({
            hl.array(['a', 'b']): 0.5,
            hl.array(['x', hl.null(hl.tstr), 'z']): 0.3,
        }),
        'md': hl.null(hl.tdict(hl.tint32, hl.tstr)),
        'h38': hl.locus('chr22', 33878978, 'GRCh38'),
        'ml': hl.null(hl.tlocus('GRCh37')),
        'i': hl.interval(hl.locus('1', 999), hl.locus('1', 1001)),
        'c': hl.call(0, 1),
        'mc': hl.null(hl.tcall),
        't': hl.tuple([hl.call(1, 2, phased=True), 'foo', hl.null(hl.tstr)]),
        'mt': hl.null(hl.ttuple(hl.tlocus('GRCh37'), hl.tbool)),
    }
    all_values = hl.struct(**sample_fields)

    def prefix(s, p):
        # Copy struct `s`, prepending `p` to every field name.
        renamed = {}
        for k in s:
            renamed[p + k] = s[k]
        return hl.struct(**renamed)

    table = hl.utils.range_table(5, n_partitions=3)
    table = table.annotate_globals(**prefix(all_values, 'global_'))
    table = table.annotate(**all_values)
    all_values_table = table.cache()

    matrix = hl.utils.range_matrix_table(3, 2, n_partitions=2)
    matrix = matrix.annotate_globals(**prefix(all_values, 'global_'))
    matrix = matrix.annotate_rows(**prefix(all_values, 'row_'))
    matrix = matrix.annotate_cols(**prefix(all_values, 'col_'))
    matrix = matrix.annotate_entries(**prefix(all_values, 'entry_'))
    all_values_matrix_table = matrix.cache()

    return all_values_table, all_values_matrix_table
Beispiel #9
0
def create_all_values_datasets():
    """Return a cached ``(Table, MatrixTable)`` pair carrying sample values.

    The same struct of defined and missing values is written to the globals
    and rows of a 5-row table, and to the globals, rows, columns and entries
    of a 3x2 matrix table. Every copy except the table's row fields gets a
    prefix naming where it lives (``global_``, ``row_``, ``col_``,
    ``entry_``).
    """
    # (field name, value) pairs, listed in the order the struct exposes them.
    named_values = [
        ('f32', hl.float32(3.14)),
        ('i64', hl.int64(-9)),
        ('m', hl.null(hl.tfloat64)),
        ('astruct', hl.struct(a=hl.null(hl.tint32), b=5.5)),
        ('mstruct', hl.null(hl.tstruct(x=hl.tint32, y=hl.tstr))),
        ('aset', hl.set(['foo', 'bar', 'baz'])),
        ('mset', hl.null(hl.tset(hl.tfloat64))),
        ('d', hl.dict({hl.array(['a', 'b']): 0.5,
                       hl.array(['x', hl.null(hl.tstr), 'z']): 0.3})),
        ('md', hl.null(hl.tdict(hl.tint32, hl.tstr))),
        ('h38', hl.locus('chr22', 33878978, 'GRCh38')),
        ('ml', hl.null(hl.tlocus('GRCh37'))),
        ('i', hl.interval(hl.locus('1', 999), hl.locus('1', 1001))),
        ('c', hl.call(0, 1)),
        ('mc', hl.null(hl.tcall)),
        ('t', hl.tuple([hl.call(1, 2, phased=True), 'foo', hl.null(hl.tstr)])),
        ('mt', hl.null(hl.ttuple(hl.tlocus('GRCh37'), hl.tbool))),
    ]
    all_values = hl.struct(**dict(named_values))

    def prefix(s, p):
        # Rebuild struct `s` with `p` prepended to each field name.
        relabelled = dict((p + k, s[k]) for k in s)
        return hl.struct(**relabelled)

    # Table: prefixed globals plus un-prefixed row fields.
    ht = hl.utils.range_table(5, n_partitions=3)
    ht = ht.annotate_globals(**prefix(all_values, 'global_'))
    ht = ht.annotate(**all_values)
    all_values_table = ht.cache()

    # Matrix table: one prefixed copy per axis.
    mt = hl.utils.range_matrix_table(3, 2, n_partitions=2)
    mt = mt.annotate_globals(**prefix(all_values, 'global_'))
    mt = mt.annotate_rows(**prefix(all_values, 'row_'))
    mt = mt.annotate_cols(**prefix(all_values, 'col_'))
    mt = mt.annotate_entries(**prefix(all_values, 'entry_'))
    all_values_matrix_table = mt.cache()

    return all_values_table, all_values_matrix_table
Beispiel #10
0
def plot_roc_curve(ht,
                   scores,
                   tp_label='tp',
                   fp_label='fp',
                   colors=None,
                   title='ROC Curve',
                   hover_mode='mouse'):
    """Create ROC curve from Hail Table.

    One or more `score` fields must be provided, which are assessed against `tp_label` and `fp_label` as truth data.

    High scores should correspond to true positives.

    Parameters
    ----------
    ht : :class:`.Table`
        Table with required data
    scores : :obj:`str` or :obj:`list` of :obj:`.str`
        Top-level location of scores in ht against which to generate ROC curves.
    tp_label : :obj:`str`
        Top-level location of true positives in ht.
    fp_label : :obj:`str`
        Top-level location of false positives in ht.
    colors : :obj:`dict` of :obj:`str`
        Optional colors to use (score -> desired color).
    title : :obj:`str`
        Title of plot.
    hover_mode : :obj:`str`
        Hover mode; one of 'mouse' (default), 'vline' or 'hline'

    Returns
    -------
    :obj:`tuple` of :class:`.Figure` and :obj:`list` of :obj:`float`
        Figure, and list of AUCs corresponding to scores.
    """
    # Normalise a single score name to a list *before* anything iterates over
    # `scores`. Otherwise the default palette below would be keyed by the
    # individual characters of the string and `colors[score]` would raise
    # KeyError later on.
    if isinstance(scores, str):
        scores = [scores]

    if colors is None:
        # Get a palette automatically
        from bokeh.palettes import d3
        palette = d3['Category10'][max(3, len(scores))]
        colors = {score: palette[i] for i, score in enumerate(scores)}

    total_tp, total_fp = ht.aggregate(
        (hl.agg.count_where(ht[tp_label]), hl.agg.count_where(ht[fp_label])))

    p = figure(title=title,
               x_axis_label='FPR',
               y_axis_label='TPR',
               tools="hover,save,pan,box_zoom,reset,wheel_zoom")
    p.add_layout(Title(text=f'Based on {total_tp} TPs and {total_fp} FPs'),
                 'above')

    aucs = []
    for score in scores:
        # Key by the negated score so the scans below run from the highest
        # score towards the lowest.
        ordered_ht = ht.key_by(_score=-ht[score])
        ordered_ht = ordered_ht.select(
            score_name=score,
            score=ordered_ht[score],
            tpr=hl.scan.count_where(ordered_ht[tp_label]) / total_tp,
            fpr=hl.scan.count_where(ordered_ht[fp_label]) / total_fp,
        ).key_by().drop('_score')
        # Close the curve with an explicit (fpr=1, tpr=1) point.
        last_row = hl.utils.range_table(1).key_by().select(score_name=score,
                                                           score=hl.float64(
                                                               float('-inf')),
                                                           tpr=hl.float32(1.0),
                                                           fpr=hl.float32(1.0))
        ordered_ht = ordered_ht.union(last_row)
        # Each row contributes (increase in fpr over the running maximum) *
        # tpr; a missing product is counted as 0.0.
        ordered_ht = ordered_ht.annotate(
            auc_contrib=hl.or_else((ordered_ht.fpr -
                                    hl.scan.max(ordered_ht.fpr)) *
                                   ordered_ht.tpr, 0.0))
        auc = ordered_ht.aggregate(hl.agg.sum(ordered_ht.auc_contrib))
        aucs.append(auc)
        df = ordered_ht.annotate(score_name=ordered_ht.score_name +
                                 f' (AUC = {auc:.4f})').to_pandas()
        p.line(x='fpr',
               y='tpr',
               legend='score_name',
               source=ColumnDataSource(df),
               color=colors[score],
               line_width=3)

    p.legend.location = 'bottom_right'
    p.legend.click_policy = 'hide'
    p.select_one(HoverTool).tooltips = [
        (x, f"@{x}") for x in ('score_name', 'score', 'tpr', 'fpr')
    ]
    p.select_one(HoverTool).mode = hover_mode
    return p, aucs
Beispiel #11
0
def _to_expr(e, dtype):
    if e is None:
        return None
    elif isinstance(e, Expression):
        if e.dtype != dtype:
            assert is_numeric(dtype), 'expected {}, got {}'.format(
                dtype, e.dtype)
            if dtype == tfloat64:
                return hl.float64(e)
            elif dtype == tfloat32:
                return hl.float32(e)
            elif dtype == tint64:
                return hl.int64(e)
            else:
                assert dtype == tint32
                return hl.int32(e)
        return e
    elif not is_compound(dtype):
        # these are not container types and cannot contain expressions if we got here
        return e
    elif isinstance(dtype, tstruct):
        new_fields = []
        found_expr = False
        for f, t in dtype.items():
            value = _to_expr(e[f], t)
            found_expr = found_expr or isinstance(value, Expression)
            new_fields.append(value)

        if not found_expr:
            return e
        else:
            exprs = [
                new_fields[i] if isinstance(new_fields[i], Expression) else
                hl.literal(new_fields[i], dtype[i])
                for i in range(len(new_fields))
            ]
            fields = {name: expr for name, expr in zip(dtype.keys(), exprs)}
            from .typed_expressions import StructExpression
            return StructExpression._from_fields(fields)

    elif isinstance(dtype, tarray):
        elements = []
        found_expr = False
        for element in e:
            value = _to_expr(element, dtype.element_type)
            found_expr = found_expr or isinstance(value, Expression)
            elements.append(value)
        if not found_expr:
            return e
        else:
            assert len(elements) > 0
            exprs = [
                element if isinstance(element, Expression) else hl.literal(
                    element, dtype.element_type) for element in elements
            ]
            indices, aggregations = unify_all(*exprs)
        x = ir.MakeArray([e._ir for e in exprs], None)
        return expressions.construct_expr(x, dtype, indices, aggregations)
    elif isinstance(dtype, tset):
        elements = []
        found_expr = False
        for element in e:
            value = _to_expr(element, dtype.element_type)
            found_expr = found_expr or isinstance(value, Expression)
            elements.append(value)
        if not found_expr:
            return e
        else:
            assert len(elements) > 0
            exprs = [
                element if isinstance(element, Expression) else hl.literal(
                    element, dtype.element_type) for element in elements
            ]
            indices, aggregations = unify_all(*exprs)
            x = ir.ToSet(
                ir.ToStream(ir.MakeArray([e._ir for e in exprs], None)))
            return expressions.construct_expr(x, dtype, indices, aggregations)
    elif isinstance(dtype, ttuple):
        elements = []
        found_expr = False
        assert len(e) == len(dtype.types)
        for i in range(len(e)):
            value = _to_expr(e[i], dtype.types[i])
            found_expr = found_expr or isinstance(value, Expression)
            elements.append(value)
        if not found_expr:
            return e
        else:
            exprs = [
                elements[i] if isinstance(elements[i], Expression) else
                hl.literal(elements[i], dtype.types[i])
                for i in range(len(elements))
            ]
            indices, aggregations = unify_all(*exprs)
            x = ir.MakeTuple([expr._ir for expr in exprs])
            return expressions.construct_expr(x, dtype, indices, aggregations)
    elif isinstance(dtype, tdict):
        keys = []
        values = []
        found_expr = False
        for k, v in e.items():
            k_ = _to_expr(k, dtype.key_type)
            v_ = _to_expr(v, dtype.value_type)
            found_expr = found_expr or isinstance(k_, Expression)
            found_expr = found_expr or isinstance(v_, Expression)
            keys.append(k_)
            values.append(v_)
        if not found_expr:
            return e
        else:
            assert len(keys) > 0
            # Here I use `to_expr` to call `lit` the keys and values separately.
            # I anticipate a common mode is statically-known keys and Expression
            # values.
            key_array = to_expr(keys, tarray(dtype.key_type))
            value_array = to_expr(values, tarray(dtype.value_type))
            return hl.dict(hl.zip(key_array, value_array))
    elif isinstance(dtype, hl.tndarray):
        return hl.nd.array(e)
    else:
        raise NotImplementedError(dtype)