Esempio n. 1
0
def combine(ts):
    """Merge the per-input data of every row of ``ts`` into one combined row.

    The code reads ``row.locus``, ``row.data`` (an array with one struct per
    input, each exposing ``alleles``, ``rsid`` and ``__entries``) and the
    global ``g`` (an array with one struct per input, each exposing
    ``__cols``).

    The row-merging function is compiled once per (row type, globals type)
    pair and memoized in the module-level ``_merge_function_map``.

    Returns the table with rows replaced by the merged struct (``locus``,
    ``alleles``, ``rsid``, ``__entries``) and with the global ``g`` replaced by
    a flat ``__cols`` array.
    """
    def merge_alleles(alleles):
        """Build a shared allele list from the per-input allele arrays.

        Returns a struct with ``globl`` (the chosen reference allele followed
        by the distinct rewritten alleles) and ``local`` (one rewritten allele
        array per input, each starting with the chosen reference allele).
        """
        from hail.expr.functions import _num_allele_type, _allele_ints
        # First rbind argument: the longest first allele over all inputs
        # (a missing first allele counts as ''), bound below as `ref`.
        # Alternate alleles classified as SNP/Insertion/Deletion/MNP/Complex
        # are suffixed with the part of `ref` that extends past the input's own
        # reference `r`; any other alternate allele is kept unchanged.
        return hl.rbind(
            alleles.map(lambda a: hl.or_else(a[0], '')).fold(
                lambda s, t: hl.cond(hl.len(s) > hl.len(t), s, t), ''),
            lambda ref: hl.rbind(
                alleles.map(lambda al: hl.rbind(
                    al[0], lambda r: hl.array([ref]).
                    extend(al[1:].map(lambda a: hl.rbind(
                        _num_allele_type(r, a), lambda at: hl.cond(
                            (_allele_ints['SNP'] == at)
                            | (_allele_ints['Insertion'] == at)
                            | (_allele_ints['Deletion'] == at)
                            | (_allele_ints['MNP'] == at)
                            | (_allele_ints['Complex'] == at), a + ref[hl.len(
                                r):], a)))))), lambda lal: hl.
                struct(globl=hl.array([ref]).extend(
                    hl.array(hl.set(hl.flatten(lal)).remove(ref))),
                       local=lal)))

    def renumber_entry(entry, old_to_new) -> StructExpression:
        """Return ``entry`` with each ``LA`` value replaced by ``old_to_new[value]``."""
        # global index of alternate (non-ref) alleles
        return entry.annotate(LA=entry.LA.map(lambda lak: old_to_new[lak]))

    # Compile the merge function only once per (row type, globals type) pair.
    if (ts.row.dtype, ts.globals.dtype) not in _merge_function_map:
        # Inside the function:
        #   * `rsid` is the first defined rsid among the inputs;
        #   * `combined_allele_index` maps each merged allele to its position
        #     in `alleles.globl`;
        #   * an input whose `__entries` is missing contributes one missing
        #     entry per column recorded in `gbl.g[i].__cols`;
        #   * otherwise every entry of input `i` gets its `LA` renumbered via
        #     `old_to_new`, the merged index of each of that input's rewritten
        #     local alleles.
        f = hl.experimental.define_function(
            lambda row, gbl: hl.rbind(
                merge_alleles(row.data.map(lambda d: d.alleles)), lambda
                alleles: hl.struct(
                    locus=row.locus,
                    alleles=alleles.globl,
                    rsid=hl.find(hl.is_defined, row.data.map(lambda d: d.rsid)
                                 ),
                    __entries=hl.bind(
                        lambda combined_allele_index: hl.
                        range(0, hl.len(row.data)).flatmap(lambda i: hl.cond(
                            hl.is_missing(row.data[i].__entries),
                            hl.range(0, hl.len(gbl.g[i].__cols)).map(
                                lambda _: hl.null(row.data[i].__entries.dtype.
                                                  element_type)),
                            hl.bind(
                                lambda old_to_new: row.data[i].__entries.map(
                                    lambda e: renumber_entry(e, old_to_new)),
                                hl.range(0, hl.len(alleles.local[i])).map(
                                    lambda j: combined_allele_index[
                                        alleles.local[i][j]])))),
                        hl.dict(
                            hl.range(0, hl.len(alleles.globl)).map(
                                lambda j: hl.tuple([alleles.globl[j], j])))))),
            ts.row.dtype, ts.globals.dtype)
        _merge_function_map[(ts.row.dtype, ts.globals.dtype)] = f
    merge_function = _merge_function_map[(ts.row.dtype, ts.globals.dtype)]
    # Apply the compiled function to every row by building the IR node directly.
    ts = Table(
        TableMapRows(
            ts._tir,
            Apply(merge_function._name, merge_function._ret_type,
                  TopLevelReference('row'), TopLevelReference('global'))))
    # Replace the per-input globals with one flat column array.
    return ts.transmute_globals(
        __cols=hl.flatten(ts.g.map(lambda g: g.__cols)))
Esempio n. 2
0
def solve_triangular(nd_coef, nd_dep, lower=False):
    """Solve a triangular linear system.

    Parameters
    ----------
    nd_coef : :class:`.NDArrayNumericExpression`, (N, N)
        Triangular coefficient matrix.
    nd_dep : :class:`.NDArrayNumericExpression`, (N,) or (N, K)
        Dependent variables.
    lower : `bool` or :class:`.BooleanExpression`
        If true, nd_coef is interpreted as a lower triangular matrix.
        If false, nd_coef is interpreted as an upper triangular matrix.

    Returns
    -------
    :class:`.NDArrayNumericExpression`, (N,) or (N, K)
        Solution to the triangular system Ax = B. Shape is same as shape of B.

    """
    nd_dep_ndim_orig = nd_dep.ndim
    nd_coef, nd_dep = solve_helper(nd_coef, nd_dep, nd_dep_ndim_orig)

    # The IR node needs `lower._ir`, which a plain Python bool (including the
    # default `False`) does not have. Lift it to a Hail expression first;
    # values that are already expressions are left untouched.
    if isinstance(lower, bool):
        lower = hl.literal(lower)

    return_type = hl.tndarray(hl.tfloat64, 2)
    ir = Apply("linear_triangular_solve", return_type, nd_coef._ir, nd_dep._ir,
               lower._ir)
    result = construct_expr(ir, return_type, nd_coef._indices,
                            nd_coef._aggregations)
    # The solve is performed with a 2-dimensional return type; flatten the
    # result again when the caller supplied a 1-dimensional right-hand side.
    if nd_dep_ndim_orig == 1:
        result = result.reshape((-1))
    return result
Esempio n. 3
0
def transform_one(mt, info_to_keep=None) -> Table:
    """transforms a gvcf into a form suitable for combining

    The input to this should be some result of either :func:`.import_vcf` or
    :func:`.import_vcfs` with `array_elements_required=False`.

    There is a strong assumption that this function will be called on a matrix
    table with one column.

    Parameters
    ----------
    mt
        The gvcf to transform.
    info_to_keep
        Names of ``INFO`` fields copied into the ``gvcf_info`` entry field.
        When empty or ``None``, every ``INFO`` field except ``END`` and ``DP``
        is kept.

    Returns
    -------
    :class:`.Table`
        Table whose rows hold ``locus``, ``alleles`` (without a trailing
        ``<NON_REF>``), ``rsid`` and the rewritten ``__entries`` array.
    """
    if not info_to_keep:
        info_to_keep = [name for name in mt.info if name not in ['END', 'DP']]
    # Freeze the selection: it is baked into the compiled row function, so it
    # must be part of the cache key. Order is preserved because it determines
    # the field order of `gvcf_info`.
    info_key = tuple(info_to_keep)
    mt = localize(mt)

    # Keying on the row type alone would hand back a function compiled for a
    # different `info_to_keep` on a later call with the same row type.
    cache_key = (mt.row.dtype, info_key)
    if cache_key not in _transform_rows_function_map:
        f = hl.experimental.define_function(
            lambda row: hl.rbind(
                hl.len(row.alleles), '<NON_REF>' == row.alleles[-1],
                lambda alleles_len, has_non_ref: hl.struct(
                    locus=row.locus,
                    alleles=hl.cond(has_non_ref, row.alleles[:-1], row.alleles),
                    rsid=row.rsid,
                    __entries=row.__entries.map(
                        lambda e:
                        hl.struct(
                            DP=e.DP,
                            END=row.info.END,
                            GQ=e.GQ,
                            # local allele indices, excluding <NON_REF>
                            LA=hl.range(0, alleles_len - hl.cond(has_non_ref, 1, 0)),
                            LAD=hl.cond(has_non_ref, e.AD[:-1], e.AD),
                            LGT=e.GT,
                            LPGT=e.PGT,
                            # drop the trailing `alleles_len` PL values when
                            # <NON_REF> is present; missing when no alternate
                            # allele remains
                            LPL=hl.cond(has_non_ref,
                                        hl.cond(alleles_len > 2,
                                                e.PL[:-alleles_len],
                                                hl.null(e.PL.dtype)),
                                        hl.cond(alleles_len > 1,
                                                e.PL,
                                                hl.null(e.PL.dtype))),
                            MIN_DP=e.MIN_DP,
                            PID=e.PID,
                            # PL of the genotype pairing allele 0 with the
                            # last allele (<NON_REF>)
                            RGQ=hl.cond(
                                has_non_ref,
                                e.PL[hl.call(0, alleles_len - 1).unphased_diploid_gt_index()],
                                hl.null(e.PL.dtype.element_type)),
                            SB=e.SB,
                            # only rows without END carry gvcf_info
                            gvcf_info=hl.case()
                                .when(hl.is_missing(row.info.END),
                                      hl.struct(**(row.info.select(*info_key))))
                                .or_missing()
                        ))),
            ),
            mt.row.dtype)
        _transform_rows_function_map[cache_key] = f
    transform_row = _transform_rows_function_map[cache_key]
    return Table(TableMapRows(mt._tir, Apply(transform_row._name, transform_row._ret_type, TopLevelReference('row'))))
Esempio n. 4
0
def transform_one(mt, info_to_keep=None) -> Table:
    """Transform a gvcf into a form suitable for combining.

    Parameters
    ----------
    mt
        The gvcf to transform.
    info_to_keep
        Names of ``INFO`` fields passed through ``parse_as_fields`` and stored
        in the ``gvcf_info`` entry field. When empty or ``None``, every
        ``INFO`` field except ``END`` and ``DP`` is kept.

    Returns
    -------
    :class:`.Table`
        Table whose rows hold ``locus``, ``alleles`` (without a trailing
        ``<NON_REF>``), ``rsid`` and the rewritten ``__entries`` array.
    """
    if not info_to_keep:
        info_to_keep = [name for name in mt.info if name not in ['END', 'DP']]
    # Freeze the selection: it is baked into the compiled row function, so it
    # must be part of the cache key. Order is preserved because it determines
    # the order in which fields are selected for `gvcf_info`.
    info_key = tuple(info_to_keep)
    mt = localize(mt)

    # Keying on the row type alone would hand back a function compiled for a
    # different `info_to_keep` on a later call with the same row type.
    cache_key = (mt.row.dtype, info_key)
    if cache_key not in _transform_rows_function_map:
        f = hl.experimental.define_function(
            lambda row: hl.rbind(
                hl.len(row.alleles), '<NON_REF>' == row.alleles[-1],
                lambda alleles_len, has_non_ref: hl.struct(
                    locus=row.locus,
                    alleles=hl.cond(has_non_ref, row.alleles[:-1], row.alleles),
                    rsid=row.rsid,
                    __entries=row.__entries.map(
                        lambda e:
                        hl.struct(
                            DP=e.DP,
                            END=row.info.END,
                            GQ=e.GQ,
                            # local allele indices, excluding <NON_REF>
                            LA=hl.range(0, alleles_len - hl.cond(has_non_ref, 1, 0)),
                            LAD=hl.cond(has_non_ref, e.AD[:-1], e.AD),
                            LGT=e.GT,
                            LPGT=e.PGT,
                            # drop the trailing `alleles_len` PL values when
                            # <NON_REF> is present; missing when no alternate
                            # allele remains
                            LPL=hl.cond(has_non_ref,
                                        hl.cond(alleles_len > 2,
                                                e.PL[:-alleles_len],
                                                hl.null(e.PL.dtype)),
                                        hl.cond(alleles_len > 1,
                                                e.PL,
                                                hl.null(e.PL.dtype))),
                            MIN_DP=e.MIN_DP,
                            PID=e.PID,
                            # PL of the genotype pairing allele 0 with the
                            # last allele (<NON_REF>)
                            RGQ=hl.cond(
                                has_non_ref,
                                e.PL[hl.call(0, alleles_len - 1).unphased_diploid_gt_index()],
                                hl.null(e.PL.dtype.element_type)),
                            SB=e.SB,
                            # only rows without END carry gvcf_info
                            gvcf_info=hl.case()
                                .when(hl.is_missing(row.info.END),
                                      hl.struct(**(
                                          parse_as_fields(
                                              row.info.select(*info_key),
                                              has_non_ref)
                                      )))
                                .or_missing()
                        ))),
            ),
            mt.row.dtype)
        _transform_rows_function_map[cache_key] = f
    transform_row = _transform_rows_function_map[cache_key]
    return Table(TableMapRows(mt._tir, Apply(transform_row._name, transform_row._ret_type, TopLevelReference('row'))))
Esempio n. 5
0
def combine_r(ts):
    """Merge the per-input data of every row of ``ts`` into one combined row.

    Each output row holds ``locus``, the first defined ``ref_allele`` among
    the inputs, and ``__entries``: the concatenation of every input's entries,
    where a missing input contributes one missing entry per column listed in
    the matching element of the global ``g``.

    The row function is compiled once per (row type, globals type) pair and
    memoized in ``_merge_function_map``. The global ``g`` is finally replaced
    by a flat ``__cols`` array.
    """
    row_type = ts.row.dtype
    globals_type = ts.globals.dtype
    cache_key = (row_type, globals_type)

    if cache_key not in _merge_function_map:
        def merge_row(row, gbl):
            def entries_for_input(i):
                # A missing input still occupies its columns, so pad with
                # missing entries of the right element type.
                padding = hl.range(0, hl.len(gbl.g[i].__cols)).map(
                    lambda _: hl.missing(row.data[i].__entries.dtype.element_type))
                return hl.if_else(hl.is_missing(row.data[i]),
                                  padding,
                                  row.data[i].__entries)

            return hl.struct(
                locus=row.locus,
                ref_allele=hl.find(hl.is_defined, row.data.map(lambda d: d.ref_allele)),
                __entries=hl.range(0, hl.len(row.data)).flatmap(entries_for_input))

        _merge_function_map[cache_key] = hl.experimental.define_function(
            merge_row, row_type, globals_type)

    merge_function = _merge_function_map[cache_key]
    apply_ir = Apply(merge_function._name,
                     merge_function._ret_type,
                     TopLevelReference('row'),
                     TopLevelReference('global'))
    ts = Table(TableMapRows(ts._tir, apply_ir))
    return ts.transmute_globals(__cols=hl.flatten(ts.g.map(lambda g: g.__cols)))
Esempio n. 6
0
def make_reference_matrix_table(mt: MatrixTable,
                                entry_to_keep: Collection[str]
                                ) -> MatrixTable:
    """Build the reference-block matrix table from a gvcf.

    Only rows with a defined ``info.END`` are kept. Each output row holds
    ``locus``, ``ref_allele`` (the first character of the first allele) and
    entries made of ``END`` plus the requested entry fields.

    Parameters
    ----------
    mt
        The gvcf to transform.
    entry_to_keep
        Names of entry fields to carry over. ``AD`` and ``PL`` are emitted as
        ``LAD`` and ``LPL``, each truncated to its first element.

    Returns
    -------
    :class:`.MatrixTable`
        The unlocalized result. Evaluating an entry whose ``GT`` is not
        homozygous reference raises a Hail error.
    """
    mt = mt.filter_rows(hl.is_defined(mt.info.END))
    entry_key = tuple(sorted(entry_to_keep))  # hashable stable value

    def make_entry_struct(e, row):
        handled_fields = {}
        # we drop PL by default, but if `entry_to_keep` has it then PL needs to be
        # turned into LPL
        handled_names = {'AD', 'PL'}

        if 'AD' in entry_to_keep:
            handled_fields['LAD'] = e['AD'][:1]
        if 'PL' in entry_to_keep:
            handled_fields['LPL'] = e['PL'][:1]

        reference_fields = {k: v for k, v in e.items()
                            if k in entry_to_keep and k not in handled_names}
        # The trailing space keeps the locus from being glued to the word "at"
        # in the error message.
        return (hl.case()
                  .when(e.GT.is_hom_ref(),
                        hl.struct(END=row.info.END, **reference_fields, **handled_fields))
                  .or_error('found END with non reference-genotype at ' + hl.str(row.locus)))

    mt = localize(mt)
    # The compiled function depends on both the row type and the kept fields.
    if (mt.row.dtype, entry_key) not in _transform_reference_fuction_map:
        f = hl.experimental.define_function(
            lambda row: hl.struct(
                locus=row.locus,
                ref_allele=row.alleles[0][0],
                __entries=row.__entries.map(
                    lambda e: make_entry_struct(e, row))),
            mt.row.dtype)
        _transform_reference_fuction_map[mt.row.dtype, entry_key] = f

    transform_row = _transform_reference_fuction_map[mt.row.dtype, entry_key]
    return unlocalize(Table(TableMapRows(mt._tir, Apply(transform_row._name, transform_row._ret_type, TopLevelReference('row')))))
Esempio n. 7
0
File: nd.py Progetto: saponas/hail
def solve(a, b):
    """Solve the linear system ``a @ x = b``.

    Parameters
    ----------
    a : :class:`.NDArrayNumericExpression`, (N, N)
        Coefficient matrix.
    b : :class:`.NDArrayNumericExpression`, (N,) or (N, K)
        Dependent variables.

    Returns
    -------
    :class:`.NDArrayNumericExpression`, (N,) or (N, K)
        Solution of the system, with the same number of dimensions as `b`.

    """
    assert a.ndim == 2
    assert b.ndim in (1, 2)

    rhs_is_vector = b.ndim == 1

    # The solver works on 2-dimensional float64 arrays: promote a vector to a
    # single-column matrix and cast both operands if needed.
    if rhs_is_vector:
        b = b.reshape((-1, 1))

    if a.dtype.element_type != hl.tfloat64:
        a = a.map(lambda e: hl.float64(e))
    if b.dtype.element_type != hl.tfloat64:
        b = b.map(lambda e: hl.float64(e))

    solution_type = hl.tndarray(hl.tfloat64, 2)
    solve_ir = Apply("linear_solve", solution_type, a._ir, b._ir)
    solution = construct_expr(solve_ir, solution_type, a._indices,
                              a._aggregations)

    # Give a vector right-hand side a vector answer.
    return solution.reshape((-1)) if rhs_is_vector else solution
Esempio n. 8
0
def solve(a, b, no_crash=False):
    """Solve the linear system ``a @ x = b``.

    Parameters
    ----------
    a : :class:`.NDArrayNumericExpression`, (N, N)
        Coefficient matrix.
    b : :class:`.NDArrayNumericExpression`, (N,) or (N, K)
        Dependent variables.
    no_crash : bool
        If true, call ``linear_solve_no_crash`` and return a struct with
        fields ``solution`` and ``failed`` instead of the bare solution.

    Returns
    -------
    :class:`.NDArrayNumericExpression`, (N,) or (N, K)
        Solution of the system, with the same number of dimensions as `b`
        (wrapped in a struct when `no_crash` is set).

    """
    rhs_ndim = b.ndim
    a, b = solve_helper(a, b, rhs_ndim)

    if no_crash:
        name = "linear_solve_no_crash"
        return_type = hl.tstruct(solution=hl.tndarray(hl.tfloat64, 2),
                                 failed=hl.tbool)
    else:
        name = "linear_solve"
        return_type = hl.tndarray(hl.tfloat64, 2)

    result = construct_expr(Apply(name, return_type, a._ir, b._ir),
                            return_type, a._indices, a._aggregations)

    # A matrix right-hand side needs no post-processing.
    if rhs_ndim != 1:
        return result

    # Flatten the solution again for a vector right-hand side.
    if no_crash:
        return hl.struct(solution=result.solution.reshape((-1)),
                         failed=result.failed)
    return result.reshape((-1))
Esempio n. 9
0
 def f(*args):
     """Build an expression applying the function ``mname`` to ``args``."""
     # Unify first so incompatible argument sources fail before any IR is built.
     indices, aggregations = unify_all(*args)
     arg_irs = [arg._ir for arg in args]
     call_ir = Apply(mname, ret_type, *arg_irs)
     return construct_expr(call_ir, ret_type, indices, aggregations)
Esempio n. 10
0
def transform_gvcf(mt, info_to_keep=None) -> Table:
    """Transforms a gvcf into a sparse matrix table

    The input to this should be some result of either :func:`.import_vcf` or
    :func:`.import_gvcfs` with ``array_elements_required=False``.

    There is an assumption that this function will be called on a matrix table
    with one column (or a localized table version of the same).

    Parameters
    ----------
    mt : :obj:`Union[Table, MatrixTable]`
        The gvcf being transformed, if it is a table, then it must be a localized matrix table with
        the entries array named ``__entries``
    info_to_keep : :obj:`List[str]`
        Any ``INFO`` fields in the gvcf that are to be kept and put in the ``gvcf_info`` entry
        field. By default (``None`` or empty), all ``INFO`` fields except ``END`` and ``DP``
        are kept.

    Returns
    -------
    :obj:`.Table`
        A localized matrix table that can be used as part of the input to :func:`.combine_gvcfs`

    Raises
    ------
    :class:`hl.utils.FatalError`
        If ``INFO`` has no ``END`` field or the entries have no ``GT`` field.

    Notes
    -----
    This function will parse the following allele specific annotations from
    pipe delimited strings into proper values. ::

        AS_QUALapprox
        AS_RAW_MQ
        AS_RAW_MQRankSum
        AS_RAW_ReadPosRankSum
        AS_SB_TABLE
        AS_VarDP

    """
    if not info_to_keep:
        info_to_keep = [name for name in mt.info if name not in ['END', 'DP']]
    # Freeze the selection: it is baked into the compiled row function, so it
    # must be part of the cache key. Order is preserved because it determines
    # the order in which fields are selected for `gvcf_info`.
    info_key = tuple(info_to_keep)
    mt = localize(mt)

    # Keying on the row type alone would hand back a function compiled for a
    # different `info_to_keep` on a later call with the same row type.
    cache_key = (mt.row.dtype, info_key)
    if cache_key not in _transform_rows_function_map:

        def get_lgt(e, n_alleles, has_non_ref, row):
            # Keep GT when it only involves alleles that survive the removal
            # of <NON_REF>; set it to missing when its index is valid only
            # with <NON_REF> counted; anything else is an error.
            index = e.GT.unphased_diploid_gt_index()
            n_no_nonref = n_alleles - hl.int(has_non_ref)
            triangle_without_nonref = hl.triangle(n_no_nonref)
            return (hl.case().when(index < triangle_without_nonref, e.GT).when(
                index < hl.triangle(n_alleles),
                hl.null('call')).or_error('invalid GT ' + hl.str(e.GT) +
                                          ' at site ' + hl.str(row.locus)))

        def make_entry_struct(e, alleles_len, has_non_ref, row):
            handled_fields = {}
            # Entry fields that are rewritten here (or replaced by a local
            # "L"-prefixed version) and must not be passed through unchanged.
            handled_names = {
                'LA', 'gvcf_info', 'END', 'LAD', 'AD', 'LGT', 'GT', 'LPL',
                'PL', 'LPGT', 'PGT'
            }

            if 'END' not in row.info:
                raise hl.utils.FatalError(
                    "the Hail GVCF combiner expects GVCFs to have an 'END' field in INFO."
                )
            if 'GT' not in e:
                raise hl.utils.FatalError(
                    "the Hail GVCF combiner expects GVCFs to have a 'GT' field in FORMAT."
                )

            # local allele indices, excluding <NON_REF>
            handled_fields['LA'] = hl.range(
                0, alleles_len - hl.cond(has_non_ref, 1, 0))
            handled_fields['LGT'] = get_lgt(e, alleles_len, has_non_ref, row)
            if 'AD' in e:
                handled_fields['LAD'] = hl.cond(has_non_ref, e.AD[:-1], e.AD)
            if 'PGT' in e:
                handled_fields['LPGT'] = e.PGT
            if 'PL' in e:
                # drop the trailing `alleles_len` PL values when <NON_REF> is
                # present; missing when no alternate allele remains
                handled_fields['LPL'] = hl.cond(
                    has_non_ref,
                    hl.cond(alleles_len > 2, e.PL[:-alleles_len],
                            hl.null(e.PL.dtype)),
                    hl.cond(alleles_len > 1, e.PL, hl.null(e.PL.dtype)))
                # PL of the genotype pairing allele 0 with <NON_REF>
                handled_fields['RGQ'] = hl.cond(
                    has_non_ref,
                    e.PL[hl.call(0,
                                 alleles_len - 1).unphased_diploid_gt_index()],
                    hl.null(e.PL.dtype.element_type))

            handled_fields['END'] = row.info.END
            # only rows without END carry gvcf_info
            handled_fields['gvcf_info'] = (hl.case().when(
                hl.is_missing(row.info.END),
                hl.struct(**(parse_as_fields(row.info.select(
                    *info_key), has_non_ref)))).or_missing())

            pass_through_fields = {
                k: v
                for k, v in e.items() if k not in handled_names
            }
            return hl.struct(**handled_fields, **pass_through_fields)

        f = hl.experimental.define_function(
            lambda row: hl.rbind(
                hl.len(row.alleles), '<NON_REF>' == row.alleles[-1], lambda
                alleles_len, has_non_ref: hl.struct(
                    locus=row.locus,
                    alleles=hl.cond(has_non_ref, row.alleles[:-1], row.alleles
                                    ),
                    rsid=row.rsid,
                    __entries=row.__entries.map(lambda e: make_entry_struct(
                        e, alleles_len, has_non_ref, row)))), mt.row.dtype)
        _transform_rows_function_map[cache_key] = f
    transform_row = _transform_rows_function_map[cache_key]
    return Table(
        TableMapRows(
            mt._tir,
            Apply(transform_row._name, transform_row._ret_type,
                  TopLevelReference('row'))))
Esempio n. 11
0
def transform_gvcf(mt, info_to_keep=None) -> Table:
    """Transforms a gvcf into a sparse matrix table

    The input to this should be some result of either :func:`.import_vcf` or
    :func:`.import_gvcfs` with ``array_elements_required=False``.

    There is an assumption that this function will be called on a matrix table
    with one column (or a localized table version of the same).

    Parameters
    ----------
    mt : :obj:`Union[Table, MatrixTable]`
        The gvcf being transformed, if it is a table, then it must be a localized matrix table with
        the entries array named ``__entries``
    info_to_keep : :obj:`List[str]`
        Any ``INFO`` fields in the gvcf that are to be kept and put in the ``gvcf_info`` entry
        field. By default (``None`` or empty), all ``INFO`` fields except ``END`` and ``DP``
        are kept.

    Returns
    -------
    :obj:`.Table`
        A localized matrix table that can be used as part of the input to :func:`.combine_gvcfs`

    Notes
    -----
    This function will parse the following allele specific annotations from
    pipe delimited strings into proper values. ::

        AS_QUALapprox
        AS_RAW_MQ
        AS_RAW_MQRankSum
        AS_RAW_ReadPosRankSum
        AS_SB_TABLE
        AS_VarDP

    """
    if not info_to_keep:
        info_to_keep = [name for name in mt.info if name not in ['END', 'DP']]
    # Freeze the selection: it is baked into the compiled row function, so it
    # must be part of the cache key. Order is preserved because it determines
    # the order in which fields are selected for `gvcf_info`.
    info_key = tuple(info_to_keep)
    mt = localize(mt)

    # Keying on the row type alone would hand back a function compiled for a
    # different `info_to_keep` on a later call with the same row type.
    cache_key = (mt.row.dtype, info_key)
    if cache_key not in _transform_rows_function_map:
        f = hl.experimental.define_function(
            lambda row: hl.rbind(
                hl.len(row.alleles),
                '<NON_REF>' == row.alleles[-1],
                lambda alleles_len, has_non_ref: hl.
                struct(locus=row.locus,
                       alleles=hl.cond(has_non_ref, row.alleles[:-1], row.
                                       alleles),
                       rsid=row.rsid,
                       __entries=row.__entries.map(lambda e: hl.struct(
                           DP=e.DP,
                           END=row.info.END,
                           GQ=e.GQ,
                           # local allele indices, excluding <NON_REF>
                           LA=hl.range(
                               0, alleles_len - hl.cond(has_non_ref, 1, 0)),
                           LAD=hl.cond(has_non_ref, e.AD[:-1], e.AD),
                           LGT=e.GT,
                           LPGT=e.PGT,
                           # drop the trailing `alleles_len` PL values when
                           # <NON_REF> is present; missing when no alternate
                           # allele remains
                           LPL=hl.cond(
                               has_non_ref,
                               hl.cond(alleles_len > 2, e.PL[:-alleles_len],
                                       hl.null(e.PL.dtype)),
                               hl.cond(alleles_len > 1, e.PL,
                                       hl.null(e.PL.dtype))),
                           MIN_DP=e.MIN_DP,
                           PID=e.PID,
                           # PL of the genotype pairing allele 0 with <NON_REF>
                           RGQ=hl.cond(
                               has_non_ref, e.PL[hl.call(0, alleles_len - 1).
                                                 unphased_diploid_gt_index()],
                               hl.null(e.PL.dtype.element_type)),
                           SB=e.SB,
                           # only rows without END carry gvcf_info
                           gvcf_info=hl.case().when(
                               hl.is_missing(row.info.END),
                               hl.struct(**(parse_as_fields(
                                   row.info.select(*info_key), has_non_ref)
                                            ))).or_missing()))),
            ), mt.row.dtype)
        _transform_rows_function_map[cache_key] = f
    transform_row = _transform_rows_function_map[cache_key]
    return Table(
        TableMapRows(
            mt._tir,
            Apply(transform_row._name, transform_row._ret_type,
                  TopLevelReference('row'))))
Esempio n. 12
0
def transform_one(mt) -> Table:
    """Rewrite a gvcf into the row/entry layout used for combining.

    The input to this should be some result of either :func:`.import_vcf` or
    :func:`.import_vcfs` with `array_elements_required=False`.

    The row function is compiled once per row type and memoized in
    ``_transform_rows_function_map``.
    """
    mt = localize(mt)
    row_type = mt.row.dtype

    if row_type not in _transform_rows_function_map:
        def transform(row):
            def summed_sb(i):
                # i-th SB component summed over all entries of the row
                return hl.sum(row.__entries.map(lambda d: d.SB[i]))

            def build_row(alleles_len, has_non_ref):
                def build_entry(e):
                    return hl.struct(
                        BaseQRankSum=row.info['BaseQRankSum'],
                        ClippingRankSum=row.info['ClippingRankSum'],
                        DP=e.DP,
                        END=row.info.END,
                        GQ=e.GQ,
                        # local allele indices, excluding <NON_REF>
                        LA=hl.range(
                            0, alleles_len - hl.cond(has_non_ref, 1, 0)),
                        LAD=hl.cond(has_non_ref, e.AD[:-1], e.AD),
                        LGT=e.GT,
                        LPGT=e.PGT,
                        LPL=hl.cond(
                            has_non_ref,
                            hl.cond(alleles_len > 2, e.PL[:-alleles_len],
                                    hl.null(e.PL.dtype)),
                            hl.cond(alleles_len > 1, e.PL,
                                    hl.null(e.PL.dtype))),
                        MIN_DP=e.MIN_DP,
                        MQ=row.info['MQ'],
                        MQRankSum=row.info['MQRankSum'],
                        PID=e.PID,
                        # PL of the genotype pairing allele 0 with <NON_REF>
                        RGQ=hl.cond(
                            has_non_ref,
                            e.PL[hl.call(0, alleles_len - 1)
                                 .unphased_diploid_gt_index()],
                            hl.null(e.PL.dtype.element_type)),
                        ReadPosRankSum=row.info['ReadPosRankSum'],
                    )

                return hl.struct(
                    locus=row.locus,
                    alleles=hl.cond(has_non_ref, row.alleles[:-1],
                                    row.alleles),
                    rsid=row.rsid,
                    info=row.info.annotate(
                        SB_TABLE=hl.array([summed_sb(i) for i in range(4)]),
                    ).select(
                        "MQ_DP",
                        "QUALapprox",
                        "RAW_MQ",
                        "SB_TABLE",
                        "VarDP",
                    ),
                    __entries=row.__entries.map(build_entry),
                )

            return hl.rbind(hl.len(row.alleles),
                            '<NON_REF>' == row.alleles[-1],
                            build_row)

        _transform_rows_function_map[row_type] = (
            hl.experimental.define_function(transform, row_type))

    transform_row = _transform_rows_function_map[row_type]
    row_ir = Apply(transform_row._name, TopLevelReference('row'))
    return Table(TableMapRows(mt._tir, row_ir))
Esempio n. 13
0
def transform_one(mt, vardp_outlier=100_000) -> Table:
    """transforms a gvcf into a form suitable for combining

    The input to this should be some result of either :func:`.import_vcf` or
    :func:`.import_vcfs` with `array_elements_required=False`.

    There is a strong assumption that this function will be called on a matrix
    table with one column.

    Parameters
    ----------
    mt
        The gvcf to transform.
    vardp_outlier
        Threshold for ``INFO/VarDP``: when ``VarDP`` exceeds it, ``INFO/DP``
        is stored in ``gvcf_info.VarDP`` instead. It is part of the cache key
        and is therefore expected to be a hashable Python number.
    """
    mt = localize(mt)
    # `vardp_outlier` is baked into the compiled row function, so it has to be
    # part of the cache key; keying on the row type alone would reuse a
    # function compiled for a different threshold.
    cache_key = (mt.row.dtype, vardp_outlier)
    if cache_key not in _transform_rows_function_map:
        f = hl.experimental.define_function(
            lambda row: hl.rbind(
                hl.len(row.alleles),
                '<NON_REF>' == row.alleles[-1],
                lambda alleles_len, has_non_ref: hl.
                struct(locus=row.locus,
                       alleles=hl.cond(has_non_ref, row.alleles[:-1], row.
                                       alleles),
                       rsid=row.rsid,
                       __entries=row.__entries.map(lambda e: hl.struct(
                           DP=e.DP,
                           END=row.info.END,
                           GQ=e.GQ,
                           # local allele indices, excluding <NON_REF>
                           LA=hl.range(
                               0, alleles_len - hl.cond(has_non_ref, 1, 0)),
                           LAD=hl.cond(has_non_ref, e.AD[:-1], e.AD),
                           LGT=e.GT,
                           LPGT=e.PGT,
                           LPL=hl.cond(
                               has_non_ref,
                               hl.cond(alleles_len > 2, e.PL[:-alleles_len],
                                       hl.null(e.PL.dtype)),
                               hl.cond(alleles_len > 1, e.PL,
                                       hl.null(e.PL.dtype))),
                           MIN_DP=e.MIN_DP,
                           PID=e.PID,
                           # PL of the genotype pairing allele 0 with <NON_REF>
                           RGQ=hl.cond(
                               has_non_ref, e.PL[hl.call(0, alleles_len - 1).
                                                 unphased_diploid_gt_index()],
                               hl.null(e.PL.dtype.element_type)),
                           SB=e.SB,
                           # only rows without END carry gvcf_info
                           gvcf_info=hl.case().when(
                               hl.is_missing(row.info.END),
                               hl.struct(
                                   ClippingRankSum=row.info.ClippingRankSum,
                                   BaseQRankSum=row.info.BaseQRankSum,
                                   MQ=row.info.MQ,
                                   MQRankSum=row.info.MQRankSum,
                                   MQ_DP=row.info.MQ_DP,
                                   QUALapprox=row.info.QUALapprox,
                                   RAW_MQ=row.info.RAW_MQ,
                                   ReadPosRankSum=row.info.ReadPosRankSum,
                                   VarDP=hl.cond(
                                       row.info.VarDP > vardp_outlier, row.info
                                       .DP, row.info.VarDP))).or_missing()))),
            ), mt.row.dtype)
        _transform_rows_function_map[cache_key] = f
    transform_row = _transform_rows_function_map[cache_key]
    return Table(
        TableMapRows(mt._tir,
                     Apply(transform_row._name, TopLevelReference('row'))))
Esempio n. 14
0
def make_variants_matrix_table(mt: MatrixTable,
                               info_to_keep: Optional[Collection[str]] = None
                               ) -> MatrixTable:
    """Build the variant-site matrix table from a gvcf.

    Only rows with a missing ``info.END`` are kept. Each output row holds
    ``locus``, ``alleles`` (without a trailing ``<NON_REF>``), ``rsid`` and
    entries with local-allele fields (``LA``, ``LGT``, ``LAD``, ``LPGT``,
    ``LPL``, ``RGQ``), ``gvcf_info`` and every other entry field unchanged.

    ``info_to_keep`` names the ``INFO`` fields stored in ``gvcf_info``; when
    ``None`` or empty, every ``INFO`` field except ``END`` and ``DP`` is used.

    The row function is compiled once per (row type, sorted info fields) pair
    and memoized in ``_transform_variant_function_map``.
    """
    # `None` and an empty collection both select the default INFO fields.
    if not info_to_keep:
        info_to_keep = [name for name in mt.info if name not in ['END', 'DP']]
    info_key = tuple(sorted(info_to_keep))  # hashable stable value
    mt = localize(mt)
    mt = mt.filter(hl.is_missing(mt.info.END))

    cache_key = (mt.row.dtype, info_key)
    if cache_key not in _transform_variant_function_map:
        def get_lgt(e, n_alleles, has_non_ref, row):
            index = e.GT.unphased_diploid_gt_index()
            n_no_nonref = n_alleles - hl.int(has_non_ref)
            triangle_without_nonref = hl.triangle(n_no_nonref)
            invalid_gt = 'invalid GT ' + hl.str(e.GT) + ' at site ' + hl.str(row.locus)
            # Checked in order: haploid calls, calls among the kept alleles,
            # calls that are valid only when <NON_REF> is counted.
            chain = hl.case()
            chain = chain.when(e.GT.is_haploid(),
                               hl.or_missing(e.GT[0] < n_no_nonref, e.GT))
            chain = chain.when(index < triangle_without_nonref, e.GT)
            chain = chain.when(index < hl.triangle(n_alleles), hl.missing('call'))
            return chain.or_error(invalid_gt)

        def make_entry_struct(e, alleles_len, has_non_ref, row):
            if 'GT' not in e:
                raise hl.utils.FatalError("the Hail GVCF combiner expects GVCFs to have a 'GT' field in FORMAT.")

            # Entry fields rewritten below; everything else passes through.
            rewritten_names = {'LA', 'gvcf_info',
                               'LAD', 'AD',
                               'LGT', 'GT',
                               'LPL', 'PL',
                               'LPGT', 'PGT'}

            # Insertion order fixes the field order of the resulting struct.
            rewritten = {
                'LA': hl.range(0, alleles_len - hl.if_else(has_non_ref, 1, 0)),
                'LGT': get_lgt(e, alleles_len, has_non_ref, row),
            }
            if 'AD' in e:
                rewritten['LAD'] = hl.if_else(has_non_ref, e.AD[:-1], e.AD)
            if 'PGT' in e:
                rewritten['LPGT'] = e.PGT
            if 'PL' in e:
                missing_pl = hl.missing(e.PL.dtype)
                rewritten['LPL'] = hl.if_else(
                    has_non_ref,
                    hl.if_else(alleles_len > 2, e.PL[:-alleles_len], missing_pl),
                    hl.if_else(alleles_len > 1, e.PL, missing_pl))
                # PL of the call involving <NON_REF>: indexed directly for
                # haploid calls, via the 0/<NON_REF> genotype otherwise.
                rewritten['RGQ'] = hl.if_else(
                    has_non_ref,
                    hl.if_else(e.GT.is_haploid(),
                               e.PL[alleles_len - 1],
                               e.PL[hl.call(0, alleles_len - 1).unphased_diploid_gt_index()]),
                    hl.missing(e.PL.dtype.element_type))

            kept_info = parse_as_fields(row.info.select(*info_to_keep), has_non_ref)
            rewritten['gvcf_info'] = (hl.case()
                                      .when(hl.is_missing(row.info.END),
                                            hl.struct(**kept_info))
                                      .or_missing())

            untouched = {k: v for k, v in e.items() if k not in rewritten_names}
            return hl.struct(**rewritten, **untouched)

        def transform(row):
            def build_row(alleles_len, has_non_ref):
                return hl.struct(
                    locus=row.locus,
                    alleles=hl.if_else(has_non_ref, row.alleles[:-1], row.alleles),
                    rsid=row.rsid,
                    __entries=row.__entries.map(
                        lambda e: make_entry_struct(e, alleles_len, has_non_ref, row)))

            return hl.rbind(hl.len(row.alleles),
                            '<NON_REF>' == row.alleles[-1],
                            build_row)

        _transform_variant_function_map[cache_key] = (
            hl.experimental.define_function(transform, mt.row.dtype))

    transform_row = _transform_variant_function_map[cache_key]
    row_ir = Apply(transform_row._name, transform_row._ret_type, TopLevelReference('row'))
    return unlocalize(Table(TableMapRows(mt._tir, row_ir)))