예제 #1
0
def add_filters_expr(
    filters: Dict[str, hl.expr.BooleanExpression],
    current_filters: hl.expr.SetExpression = None,
) -> hl.expr.SetExpression:
    """
    Build an expression for a set of filter names, optionally extending an existing set.

    Every key of `filters` whose condition evaluates to `True` is added to the
    result; names already present in `current_filters` are carried over.

    NOTE(review): `hl.cond` yields missing for a missing condition, and set union
    presumably propagates missingness, so a single missing condition likely makes
    the whole result missing — confirm if callers rely on this.

    :param filters: Mapping of filter name to the boolean condition that triggers it
    :param current_filters: Existing filter set to extend; an empty set when None
    :return: Set expression of filter names, suitable for annotating a filters field
    """
    base = hl.empty_set(hl.tstr) if current_filters is None else current_filters

    triggered = []
    for name, condition in filters.items():
        # Singleton set when the filter fires, empty set otherwise.
        triggered.append(
            hl.cond(condition, hl.set([name]), hl.empty_set(hl.tstr))
        )

    return hl.fold(lambda acc, one: acc.union(one), base, triggered)
예제 #2
0
    def compute_same_hap_log_like(n, p, q, x):
        """
        Log10-likelihood that the two variants sit on the same haplotype.

        Each count in ``gt_counts`` (enclosing scope) is multiplied by the log10
        probability of the matching genotype combination under the same-haplotype
        model, built from ``p``, ``q``, ``x`` and the enclosing-scope ``e``, and the
        products are summed. ``n`` is accepted but not used. When ``q`` is not
        positive, a very large negative value is used instead of that sum. If
        ``distance`` (enclosing scope) is set, a log10 distance prior floored at -6
        is added.
        """
        # log10 genotype probabilities, paired positionally with gt_counts.
        log_probs = [
            hl.log10(x) * 2,
            hl.log10(2 * x * e),
            hl.log10(e) * 2,
            hl.log10(2 * x * p),
            hl.log10(2 * (p * e + x * q)),
            hl.log10(2 * q * e),
            hl.log10(p) * 2,
            hl.log10(2 * p * q),
            hl.log10(q) * 2,
        ]
        weighted_total = hl.fold(
            lambda acc, pair: acc + pair[0] * pair[1],
            0.0,
            hl.zip(gt_counts, log_probs),
        )
        # Very large negative value if no q is present.
        log_like = hl.cond(q > 0, weighted_total, -1e31)

        if distance is None:
            return log_like

        # Distance posterior based on a value derived from regression.
        return log_like + hl.max(-6, hl.log10(0.97 - 0.03 * hl.log(distance + 1)))
def annotate_with_genotype_num_alt(mt: hl.MatrixTable) -> hl.MatrixTable:
    """
    Annotate each row with a ``genotypes`` array of per-sample genotype summaries.

    Each array element is a struct holding the sample ID (``sample_id``), the
    number of alternate alleles in ``GT`` (``num_alt``), an allele balance
    (``ab``), ``DP`` (``dp``) and ``GQ`` (``gq``). Rows whose second allele is
    ``<CNV>`` get ``num_alt = 0`` and ``ab = 0.0``.

    The allele balance source depends on the entry schema (AD is preferred):

    * ``AD`` present (GATK-consistent VCF): ``AD[1] / sum(AD)``.
    * otherwise ``AO`` present: ``AO[0] / DP``.

    In both cases a zero denominator yields ``ab = 0.0``.

    :param mt: Matrix table with ``GT``, ``GQ``, ``DP`` and either ``AD`` or ``AO`` entries
    :return: ``mt`` with the ``genotypes`` row field added
    :raises ValueError: If the entries have neither ``AD`` nor ``AO``
    """
    entry_fields = set(mt.entry)
    # Validate the schema before building any expressions, so an unsupported
    # VCF always surfaces as ValueError.
    if 'AD' not in entry_fields and 'AO' not in entry_fields:
        raise ValueError("unrecognized vcf")

    is_cnv = mt.alleles[1] == '<CNV>'
    num_alt = hl.cond(is_cnv, 0, mt.GT.n_alt_alleles())

    if 'AD' in entry_fields:
        # GATK-consistent VCF: allele balance from allelic depths.
        total_ad = hl.fold(lambda i, j: i + j, 0, mt.AD)
        ab = hl.cond(
            is_cnv,
            0.0,
            # Guard a zero total depth, matching the AO/DP branch.
            hl.cond(total_ad == 0, 0.0,
                    hl.float(mt.AD[1]) / hl.float(total_ad)))
        genotype = hl.struct(num_alt=num_alt, ab=ab, gq=mt.GQ,
                             sample_id=mt.s, dp=mt.DP)
    else:
        # Python's `or` cannot combine Hail expressions (it forces bool() on
        # them and raises), so the CNV and zero-depth guards are nested
        # conditionals instead.
        ab = hl.cond(
            is_cnv,
            0.0,
            hl.cond(mt.DP == 0, 0.0,
                    hl.float(mt.AO[0]) / hl.float(mt.DP)))
        genotype = hl.struct(num_alt=num_alt, ab=ab, dp=mt.DP, gq=mt.GQ,
                             sample_id=mt.s)

    return mt.annotate_rows(genotypes=hl.agg.collect(genotype))
예제 #4
0
    def transform_entries(old_entry):
        """
        Split a local-allele (sparse) entry down to the bi-allelic row ``ds[new_id]``.

        ``LGT``/``LPGT`` are downcoded to ``GT``/``PGT``. ``LAD`` becomes a
        two-element ``AD`` of ``[sum(LAD) - alt_depth, alt_depth]``. ``LPL`` becomes
        a three-element ``PL``, with ``GQ`` recomputed from it when present. ``LA``
        and the consumed local fields are dropped. On rows with a single allele the
        local fields are simply renamed without the ``L`` prefix. ``ds`` and
        ``new_id`` come from the enclosing scope.
        """
        def with_local_a_index(local_a_index):
            # local_a_index: position within LA of this row's alt allele, or
            # missing when the sample has no local copy of it.
            fields = set(old_entry.keys())
            # When the allele is absent from LA, downcode against an index no local
            # allele has, so every allele in the call is set to ref.
            downcode_index = hl.or_else(local_a_index, hl.len(old_entry.LA))

            def with_pl(pl):
                new_exprs = {}
                dropped_fields = ['LA']
                if 'LGT' in fields:
                    new_exprs['GT'] = hl.downcode(old_entry.LGT, downcode_index)
                    dropped_fields.append('LGT')
                if 'LPGT' in fields:
                    new_exprs['PGT'] = hl.downcode(old_entry.LPGT, downcode_index)
                    dropped_fields.append('LPGT')
                if 'LAD' in fields:
                    # Non-ref depth is zeroed when the allele is not in LAD; all
                    # remaining reads count toward ref (hl.split_multi_hts convention).
                    non_ref_ad = hl.or_else(old_entry.LAD[local_a_index], 0)
                    new_exprs['AD'] = hl.or_missing(
                        hl.is_defined(old_entry.LAD),
                        [hl.sum(old_entry.LAD) - non_ref_ad, non_ref_ad])
                    dropped_fields.append('LAD')
                if 'LPL' in fields:
                    new_exprs['PL'] = pl
                    if 'GQ' in fields:
                        new_exprs['GQ'] = hl.or_else(hl.gq_from_pl(pl), old_entry.GQ)

                    dropped_fields.append('LPL')

                return hl.cond(hl.len(ds.alleles) == 1,
                               old_entry.annotate(**{f[1:]: old_entry[f] for f in ['LGT', 'LPGT', 'LAD', 'LPL'] if f in fields}).drop(*dropped_fields),
                               old_entry.annotate(**new_exprs).drop(*dropped_fields))

            if 'LPL' in fields:
                # Build the split PL only when LPL exists: reading old_entry.LPL on
                # an entry without that field raises.
                new_pl = hl.or_missing(
                    hl.is_defined(old_entry.LPL),
                    hl.or_missing(
                        hl.is_defined(local_a_index),
                        # PL[i] = min LPL over local genotypes that downcode to genotype i
                        hl.range(0, 3).map(lambda i: hl.min(
                            hl.range(0, hl.triangle(hl.len(old_entry.LA)))
                                .filter(lambda j: hl.downcode(hl.unphased_diploid_gt_index_call(j), local_a_index) == hl.unphased_diploid_gt_index_call(i))
                                .map(lambda idx: old_entry.LPL[idx])))))
                return hl.bind(with_pl, new_pl)
            return with_pl(None)

        # Position in LA whose global allele index equals this row's a_index;
        # missing if no local allele matches.
        lai = hl.fold(lambda accum, elt:
                        hl.cond(old_entry.LA[elt] == ds[new_id].a_index,
                                elt, accum),
                        hl.null(hl.tint32),
                        hl.range(0, hl.len(old_entry.LA)))
        return hl.bind(with_local_a_index, lai)
예제 #5
0
    def compute_chet_log_like(n, p, q, x):
        """
        Log10-likelihood that the two variants are on different haplotypes.

        "chet" presumably means compound heterozygous. Each count in ``gt_counts``
        (enclosing scope) is multiplied by the log10 probability of the matching
        genotype combination, built from ``p``, ``q``, ``x`` and the
        enclosing-scope ``e``, and the products are summed. ``n`` is accepted but
        not used. The model requires both ``p`` and ``q`` to be positive;
        otherwise a very large negative value is used. If ``distance`` (enclosing
        scope) is set, a log10 distance prior floored at -6 is added.
        """
        res = hl.cond(
            (p > 0) & (q > 0),
            hl.fold(
                # 0.0 zero keeps the accumulator float64, as in the same-hap model.
                lambda i, j: i + j[0] * j[1], 0.0,
                # log10 genotype probabilities, paired positionally with gt_counts.
                hl.zip(gt_counts, [
                    hl.log10(x) * 2,
                    hl.log10(2 * x * q),
                    hl.log10(q) * 2,
                    hl.log10(2 * x * p),
                    hl.log10(2 * (p * q + x * e)),
                    hl.log10(2 * q * e),
                    hl.log10(p) * 2,
                    hl.log10(2 * p * e),
                    hl.log10(e) * 2
                ])),
            # Very large negative value when p or q is absent. A value of -1e-31
            # would be ~0, i.e. the *highest* possible log10-likelihood.
            -1e31)

        # If desired, add distance posterior based on value derived from regression.
        # This is the complement of the same-haplotype prior
        # 0.97 - 0.03 * ln(distance + 1), so it uses ln(distance + 1) too;
        # ln(distance - 1) is -inf/NaN for distance <= 1.
        if distance is not None:
            res = res + hl.max(-6,
                               hl.log10(0.03 + 0.03 * hl.log(distance + 1)))

        return res
예제 #6
0
파일: misc.py 프로젝트: jigold/hail
def locus_windows(locus_expr, radius, coord_expr=None, _localize=True):
    """Returns start and stop indices for window around each locus.

    Examples
    --------

    Windows with 2bp radius for one contig with positions 1, 2, 3, 4, 5:

    >>> starts, stops = hl.linalg.utils.locus_windows(
    ...     hl.balding_nichols_model(1, 5, 5).locus,
    ...     radius=2)
    >>> starts, stops
    (array([0, 0, 0, 1, 2]), array([3, 4, 5, 5, 5]))

    The following examples involve three contigs.

    >>> loci = [{'locus': hl.Locus('1', 1), 'cm': 1.0},
    ...         {'locus': hl.Locus('1', 2), 'cm': 3.0},
    ...         {'locus': hl.Locus('1', 4), 'cm': 4.0},
    ...         {'locus': hl.Locus('2', 1), 'cm': 2.0},
    ...         {'locus': hl.Locus('2', 1), 'cm': 2.0},
    ...         {'locus': hl.Locus('3', 3), 'cm': 5.0}]

    >>> ht = hl.Table.parallelize(
    ...         loci,
    ...         hl.tstruct(locus=hl.tlocus('GRCh37'), cm=hl.tfloat64),
    ...         key=['locus'])

    Windows with 1bp radius:

    >>> hl.linalg.utils.locus_windows(ht.locus, 1)
    (array([0, 0, 2, 3, 3, 5]), array([2, 2, 3, 5, 5, 6]))

    Windows with 1cm radius:

    >>> hl.linalg.utils.locus_windows(ht.locus, 1.0, coord_expr=ht.cm)
    (array([0, 1, 1, 3, 3, 5]), array([1, 3, 3, 5, 5, 6]))

    Notes
    -----
    This function returns two 1-dimensional ndarrays of integers,
    ``starts`` and ``stops``, each of size equal to the number of rows.

    By default, for all indices ``i``, ``[starts[i], stops[i])`` is the maximal
    range of row indices ``j`` such that ``contig[i] == contig[j]`` and
    ``position[i] - radius <= position[j] <= position[i] + radius``.

    If the :meth:`.global_position` on `locus_expr` is not in ascending order,
    this method will fail. Ascending order should hold for a matrix table keyed
    by locus or variant (and the associated row table), or for a table that has
    been ordered by `locus_expr`.

    Set `coord_expr` to use a value other than position to define the windows.
    This row-indexed numeric expression must be non-missing, non-``nan``, on the
    same source as `locus_expr`, and ascending with respect to locus
    position for each contig; otherwise the function will fail.

    The last example above uses centimorgan coordinates, so
    ``[starts[i], stops[i])`` is the maximal range of row indices ``j`` such
    that ``contig[i] == contig[j]`` and
    ``cm[i] - radius <= cm[j] <= cm[i] + radius``.

    Index ranges are start-inclusive and stop-exclusive. This function is
    especially useful in conjunction with
    :meth:`.BlockMatrix.sparsify_row_intervals`.

    Parameters
    ----------
    locus_expr : :class:`.LocusExpression`
        Row-indexed locus expression on a table or matrix table.
    radius : :obj:`int` or :obj:`float`
        Radius of window for row values. Must be non-negative.
    coord_expr : :class:`.Float64Expression`, optional
        Row-indexed numeric expression for the row value.
        Must be on the same table or matrix table as `locus_expr`.
        By default, the row value is given by the locus position.

    Returns
    -------
    (:class:`ndarray` of :obj:`int64`, :class:`ndarray` of :obj:`int64`)
        Tuple of start indices array and stop indices array.

    Raises
    ------
    :class:`ValueError`
        If `radius` is negative.
    """
    if radius < 0:
        raise ValueError(f"locus_windows: 'radius' must be non-negative, found {radius}")
    check_row_indexed('locus_windows', locus_expr)
    if coord_expr is not None:
        check_row_indexed('locus_windows', coord_expr)

    src = locus_expr._indices.source
    # If `locus_expr` is not already a field of its source, annotate it (and
    # `coord_expr`, when given and not already a field) onto the source under
    # fresh uid names, then re-read both from the annotated source so they are
    # fields of that same source.
    if locus_expr not in src._fields_inverse:
        locus = Env.get_uid()
        annotate_fields = {locus: locus_expr}

        if coord_expr is not None:
            # `coords` here is a field *name*; it is reused for an expression below.
            if coord_expr not in src._fields_inverse:
                coords = Env.get_uid()
                annotate_fields[coords] = coord_expr
            else:
                coords = src._fields_inverse[coord_expr]

        if isinstance(src, hl.MatrixTable):
            new_src = src.annotate_rows(**annotate_fields)
        else:
            new_src = src.annotate(**annotate_fields)

        locus_expr = new_src[locus]
        if coord_expr is not None:
            coord_expr = new_src[coords]

    # Default row value: the locus position.
    if coord_expr is None:
        coord_expr = locus_expr.position

    rg = locus_expr.dtype.reference_genome
    # Collect row values per contig, keyed by the locus at position 1 of that contig.
    contig_group_expr = hl.agg.group_by(hl.locus(locus_expr.contig, 1, reference_genome=rg), hl.agg.collect(coord_expr))

    # check loci are in sorted order
    # The fold errors on any decrease in global position (ties allowed) and on a
    # missing locus; its result is the last global position, or -1 with no rows.
    last_pos = hl.fold(lambda a, elt: (hl.case()
                                         .when(a <= elt, elt)
                                         .or_error("locus_windows: 'locus_expr' global position must be in ascending order.")),
                       -1,
                       hl.agg.collect(hl.case()
                                        .when(hl.is_defined(locus_expr), locus_expr.global_position())
                                        .or_error("locus_windows: missing value for 'locus_expr'.")))
    # A last position of -1 means the fold saw no rows.
    checked_contig_groups = (hl.case()
                               .when(last_pos >= 0, contig_group_expr)
                               .or_error("locus_windows: 'locus_expr' has length 0"))

    # Aggregate over the source, leaving the result as an unevaluated expression.
    contig_groups = locus_expr._aggregation_method()(checked_contig_groups, _localize=False)

    # Sort the (contig key, values) pairs by their locus key and keep only the
    # per-contig value arrays.
    coords = hl.sorted(hl.array(contig_groups)).map(lambda t: t[1])
    starts_and_stops = hl._locus_windows_per_contig(coords, radius)

    # Private option: return the unevaluated expression instead of ndarrays.
    if not _localize:
        return starts_and_stops

    starts, stops = hl.eval(starts_and_stops)
    return np.array(starts), np.array(stops)
예제 #7
0
    def transform_entries(old_entry):
        """
        Split a local-allele (sparse) entry down to the bi-allelic row ``ds[new_id]``.

        ``LGT``/``LPGT`` are downcoded to ``GT``/``PGT``. ``LAD`` becomes
        ``AD = [sum(LAD) - alt_depth, alt_depth]``. ``LPL`` becomes a three-element
        ``PL``, with ``GQ`` recomputed from it when present. ``LA`` and the consumed
        local fields are dropped. Rows with a single allele just rename the local
        fields without the ``L`` prefix. Hom-ref ``LGT`` calls keep their original
        ``GT``/``PGT``. ``ds`` and ``new_id`` come from the enclosing scope.
        """
        def with_local_a_index(local_a_index):
            fields = set(old_entry.keys())

            def with_pl(pl):
                new_exprs = {}
                dropped_fields = ['LA']
                if 'LGT' in fields:
                    new_exprs['GT'] = hl.downcode(
                        old_entry.LGT,
                        hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LGT')
                if 'LPGT' in fields:
                    new_exprs['PGT'] = hl.downcode(
                        old_entry.LPGT,
                        hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LPGT')
                if 'LAD' in fields:
                    non_ref_ad = hl.or_else(old_entry.LAD[local_a_index],
                                            0)  # zeroed if not in LAD
                    new_exprs['AD'] = hl.or_missing(
                        hl.is_defined(old_entry.LAD),
                        [hl.sum(old_entry.LAD) - non_ref_ad, non_ref_ad])
                    dropped_fields.append('LAD')
                if 'LPL' in fields:
                    new_exprs['PL'] = pl
                    if 'GQ' in fields:
                        new_exprs['GQ'] = hl.or_else(hl.gq_from_pl(pl),
                                                     old_entry.GQ)

                    dropped_fields.append('LPL')

                # Single-allele rows: strip the "L" prefix from the local fields.
                mono_allelic_entry = old_entry.annotate(
                    **{
                        f[1:]: old_entry[f]
                        for f in ['LGT', 'LPGT', 'LAD', 'LPL']
                        if f in fields
                    }).drop(*dropped_fields)
                split_entry = old_entry.annotate(**new_exprs).drop(
                    *dropped_fields)

                result = hl.case().when(hl.len(ds.alleles) == 1,
                                        mono_allelic_entry)
                # The hom-ref shortcut reads LGT, so add it only when LGT exists:
                # referencing old_entry.LGT on an entry without it raises.
                if 'LGT' in fields:
                    hom_ref_entry = old_entry.annotate(
                        **{
                            f: old_entry[f'L{f}'] if f in
                            ['GT', 'PGT'] else e
                            for f, e in new_exprs.items()
                        }).drop(*dropped_fields)
                    result = result.when(
                        hl.or_else(old_entry.LGT.is_hom_ref(), False),
                        hom_ref_entry)
                return result.default(split_entry)

            if 'LPL' in fields:
                new_pl = hl.or_missing(
                    hl.is_defined(old_entry.LPL),
                    hl.or_missing(
                        hl.is_defined(local_a_index),
                        # PL[i] = min LPL over local genotypes that downcode to genotype i
                        hl.range(0, 3).map(lambda i: hl.min(
                            hl.range(0, hl.triangle(hl.len(old_entry.LA))).
                            filter(lambda j: hl.downcode(
                                hl.unphased_diploid_gt_index_call(j),
                                local_a_index) == hl.
                                   unphased_diploid_gt_index_call(i)).map(
                                       lambda idx: old_entry.LPL[idx])))))
                return hl.bind(with_pl, new_pl)
            else:
                return with_pl(None)

        # Position in LA whose global allele index equals this row's a_index;
        # missing if no local allele matches.
        lai = hl.fold(
            lambda accum, elt: hl.cond(old_entry.LA[elt] == ds[new_id].a_index,
                                       elt, accum), hl.null(hl.tint32),
            hl.range(0, hl.len(old_entry.LA)))
        return hl.bind(with_local_a_index, lai)
예제 #8
0
    def transform_entries(old_entry):
        """
        Split a local-allele (sparse) entry down to the bi-allelic row ``ds[new_id]``.

        ``LGT``/``LPGT`` are downcoded to ``GT``/``PGT``. ``LAD`` becomes a
        two-element ``AD`` of ``[sum(LAD) - alt_depth, alt_depth]``. ``LPL`` becomes
        a three-element ``PL``, with ``GQ`` recomputed from it when present. ``LA``
        and the consumed local fields are dropped. On rows with a single allele the
        local fields are simply renamed without the ``L`` prefix. ``ds`` and
        ``new_id`` come from the enclosing scope.
        """
        def with_local_a_index(local_a_index):
            fields = set(old_entry.keys())

            def with_pl(pl):
                new_exprs = {}
                dropped_fields = ['LA']
                if 'LGT' in fields:
                    new_exprs['GT'] = hl.downcode(
                        old_entry.LGT,
                        hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LGT')
                if 'LPGT' in fields:
                    new_exprs['PGT'] = hl.downcode(
                        old_entry.LPGT,
                        hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LPGT')
                if 'LAD' in fields:
                    # Non-ref depth is zeroed when the allele is not in LAD; all
                    # remaining reads count toward ref (hl.split_multi_hts convention).
                    non_ref_ad = hl.or_else(old_entry.LAD[local_a_index], 0)
                    new_exprs['AD'] = hl.or_missing(
                        hl.is_defined(old_entry.LAD),
                        [hl.sum(old_entry.LAD) - non_ref_ad, non_ref_ad])
                    dropped_fields.append('LAD')
                if 'LPL' in fields:
                    new_exprs['PL'] = pl
                    if 'GQ' in fields:
                        new_exprs['GQ'] = hl.or_else(hl.gq_from_pl(pl),
                                                     old_entry.GQ)

                    dropped_fields.append('LPL')

                return hl.cond(
                    hl.len(ds.alleles) == 1,
                    old_entry.annotate(
                        **{
                            f[1:]: old_entry[f]
                            for f in ['LGT', 'LPGT', 'LAD', 'LPL']
                            if f in fields
                        }).drop(*dropped_fields),
                    old_entry.annotate(**new_exprs).drop(*dropped_fields))

            if 'LPL' in fields:
                # Build the split PL only when LPL exists: reading old_entry.LPL on
                # an entry without that field raises.
                new_pl = hl.or_missing(
                    hl.is_defined(old_entry.LPL),
                    hl.or_missing(
                        hl.is_defined(local_a_index),
                        # PL[i] = min LPL over local genotypes that downcode to genotype i
                        hl.range(0, 3).map(lambda i: hl.min(
                            hl.range(0, hl.triangle(hl.len(old_entry.LA))).
                            filter(lambda j: hl.downcode(
                                hl.unphased_diploid_gt_index_call(j), local_a_index
                            ) == hl.unphased_diploid_gt_index_call(i)).map(
                                lambda idx: old_entry.LPL[idx])))))
                return hl.bind(with_pl, new_pl)
            return with_pl(None)

        # Position in LA whose global allele index equals this row's a_index;
        # missing if no local allele matches.
        lai = hl.fold(
            lambda accum, elt: hl.cond(old_entry.LA[elt] == ds[new_id].a_index,
                                       elt, accum), hl.null(hl.tint32),
            hl.range(0, hl.len(old_entry.LA)))
        return hl.bind(with_local_a_index, lai)