示例#1
0
    def phase_haploid_proband_x_nonpar(
            proband_call: hl.expr.CallExpression,
            father_call: hl.expr.CallExpression,
            mother_call: hl.expr.CallExpression) -> hl.expr.ArrayExpression:
        """
        Phase a trio whose proband is haploid in the non-PAR region of X.

        The proband's single allele is searched for among the mother's two
        alleles; the first maternal allele equal to it is used as the
        transmitted allele when phasing the mother.

        :param CallExpression proband_call: Input proband genotype call
        :param CallExpression father_call: Input father genotype call
        :param CallExpression mother_call: Input mother genotype call
        :return: Array containing: phased proband call, phased father call, phased mother call.
            The array is missing when no maternal allele matches the proband allele;
            the father element is missing when the father call is not haploid.
        :rtype: ArrayExpression
        """

        proband_allele = proband_call[0]
        indexed_maternal_alleles = hl.enumerate(
            hl.array([mother_call[0], mother_call[1]]))
        # (index, allele) pair of the maternal allele matching the proband's allele
        maternal_match = indexed_maternal_alleles.find(
            lambda idx_and_allele: idx_and_allele[1] == proband_allele)

        phased_proband = hl.call(proband_allele, phased=True)
        # A father call is only reported when it is haploid
        phased_father = hl.or_missing(father_call.is_haploid(),
                                      hl.call(father_call[0], phased=True))
        phased_mother = phase_parent_call(mother_call, maternal_match[0])

        return hl.or_missing(
            hl.is_defined(maternal_match),
            hl.array([phased_proband, phased_father, phased_mother]))
示例#2
0
def filter_samples(vds: 'VariantDataset', samples_table: 'Table', *,
                   keep: bool = True,
                   remove_dead_alleles: bool = False) -> 'VariantDataset':
    """Filter samples in a :class:`.VariantDataset`.

    Columns of both the reference data and the variant data are filtered
    against the keys of `samples_table`. Rows whose entry count
    (``hl.agg.count()``) is zero after the column filter are removed from
    each component.

    Parameters
    ----------
    vds : :class:`.VariantDataset`
        Dataset in VariantDataset representation.
    samples_table : :class:`.Table`
        Samples to filter on.
    keep : :obj:`bool`
        Whether to keep (default), or filter out the samples from `samples_table`.
    remove_dead_alleles : :obj:`bool`
        If true, remove alleles observed in no samples. Alleles with AC == 0 will be
        removed, and LA values recalculated. An allele counts as observed when its
        index appears in the `LA` of at least one remaining entry; the allele at
        index 0 is always kept.

    Returns
    -------
    :class:`.VariantDataset`

    Raises
    ------
    TypeError
        If `samples_table` is not keyed by exactly one field of type ``str``.
    """
    if [samples_table[x].dtype for x in samples_table.key] != [hl.tstr]:
        raise TypeError(f'invalid key: {samples_table.key.dtype}')

    # Collect the sample ids once as a (non-localized) set expression that both
    # column filters below can reuse.
    samples_to_keep = samples_table.aggregate(hl.agg.collect_as_set(samples_table.key[0]), _localize=False)._persist()
    reference_data = vds.reference_data.filter_cols(samples_to_keep.contains(vds.reference_data.col_key[0]), keep=keep)
    reference_data = reference_data.filter_rows(hl.agg.count() > 0)
    variant_data = vds.variant_data.filter_cols(samples_to_keep.contains(vds.variant_data.col_key[0]), keep=keep)

    if remove_dead_alleles:
        vd = variant_data
        # Per row: how often each allele index occurs across the LA arrays, and
        # how many entries the row has.
        vd = vd.annotate_rows(__allele_counts=hl.agg.explode(lambda x: hl.agg.counter(x), vd.LA), __n=hl.agg.count())
        vd = vd.filter_rows(vd.__n > 0)
        # '__n' is only needed for the row filter above; drop it so the temporary
        # field does not leak into the returned variant data.
        vd = vd.drop('__n')

        # Map old allele index -> new allele index for every allele that is kept.
        vd = vd.annotate_rows(__kept_indices=hl.dict(
            hl.enumerate(
                hl.range(hl.len(vd.alleles)).filter(lambda idx: (idx == 0) | (vd.__allele_counts.get(idx, 0) > 0)),
                index_first=False)))

        # Dense lookup array over the old allele indices; -1 marks a removed allele.
        vd = vd.annotate_rows(
            __old_to_new_LA=hl.range(hl.len(vd.alleles)).map(lambda idx: vd.__kept_indices.get(idx, -1)))

        def new_la_index(old_idx):
            # A removed allele must never be referenced by a remaining entry.
            raw_idx = vd.__old_to_new_LA[old_idx]
            return hl.case().when(raw_idx >= 0, raw_idx) \
                .or_error("'filter_samples': unexpected local allele: old index=" + hl.str(old_idx))

        vd = vd.annotate_entries(LA=vd.LA.map(lambda la: new_la_index(la)))
        # 'alleles' is part of the row key, so it is un-keyed while being rewritten
        # and the key is restored afterwards.
        vd = vd.key_rows_by('locus')
        vd = vd.annotate_rows(alleles=vd.__kept_indices.keys().map(lambda i: vd.alleles[i]))
        vd = vd._key_rows_by_assert_sorted('locus', 'alleles')
        vd = vd.drop('__allele_counts', '__kept_indices', '__old_to_new_LA')
        return VariantDataset(reference_data, vd)

    variant_data = variant_data.filter_rows(hl.agg.count() > 0)
    return VariantDataset(reference_data, variant_data)
示例#3
0
    def phase_diploid_proband(
            locus: hl.expr.LocusExpression, alleles: hl.expr.ArrayExpression,
            proband_call: hl.expr.CallExpression,
            father_call: hl.expr.CallExpression,
            mother_call: hl.expr.CallExpression) -> hl.expr.ArrayExpression:
        """
        Phase a trio whose proband is diploid
        (autosomes, PAR regions of sex chromosomes or non-PAR regions of a female proband).

        Every pairing of one paternal candidate with one maternal candidate is
        tested: a pairing is kept when the sum of the two one-hot allele vectors
        equals the proband's one-hot allele vector. Phasing succeeds only when
        exactly one pairing is kept.

        :param LocusExpression locus: Locus in the trio MatrixTable
        :param ArrayExpression alleles: Alleles in the trio MatrixTable
        :param CallExpression proband_call: Input proband genotype call
        :param CallExpression father_call: Input father genotype call
        :param CallExpression mother_call: Input mother genotype call
        :return: Array containing: phased proband call, phased father call, phased mother call.
            Missing unless exactly one father/mother allele pairing explains the proband call.
        :rtype: ArrayExpression
        """

        proband_one_hot = proband_call.one_hot_alleles(alleles)
        indexed_mother = hl.enumerate(
            call_to_one_hot_alleles_array(mother_call, alleles))

        # In the non-PAR regions of X and Y only a haploid father call is used
        # (as a single candidate); otherwise the father is treated like the mother.
        in_nonpar = locus.in_x_nonpar() | locus.in_y_nonpar()
        haploid_father_candidates = hl.or_missing(
            father_call.is_haploid(),
            hl.array([father_call.one_hot_alleles(alleles)]))
        indexed_father = hl.enumerate(hl.if_else(
            in_nonpar,
            haploid_father_candidates,
            call_to_one_hot_alleles_array(father_call, alleles)))

        def compatible_pairings(father_item):
            # All maternal candidates that, together with this paternal
            # candidate, add up to the proband's alleles.
            return (indexed_mother
                    .filter(lambda mother_item:
                            mother_item[1] + father_item[1] == proband_one_hot)
                    .map(lambda mother_item:
                         hl.struct(m=mother_item[0], f=father_item[0])))

        pairings = hl.flatmap(compatible_pairings, indexed_father)
        chosen = pairings[0]

        phased_proband = hl.call(father_call[chosen.f],
                                 mother_call[chosen.m],
                                 phased=True)
        phased_father = hl.if_else(father_call.is_haploid(),
                                   hl.call(father_call[0], phased=True),
                                   phase_parent_call(father_call, chosen.f))
        phased_mother = phase_parent_call(mother_call, chosen.m)

        is_unambiguous = hl.is_defined(pairings) & (hl.len(pairings) == 1)
        return hl.or_missing(
            is_unambiguous,
            hl.array([phased_proband, phased_father, phased_mother]))
示例#4
0
def full_outer_join_mt(left: hl.MatrixTable, right: hl.MatrixTable) -> hl.MatrixTable:
    """Performs a full outer join on `left` and `right`.

    The row, column, and entry fields of the result are replaced with:

     - `left_row` / `right_row`: structs of row fields from left and right.
     - `left_col` / `right_col`: structs of column fields from left and right.
     - `left_entry` / `right_entry`: structs of entry fields from left and right.

    The result is keyed by the row key of the join and by column key fields
    named after the column key fields of `left`. When a column key occurs more
    than once on both sides, every left/right combination yields a column.

    Examples
    --------

    The following creates and joins two random datasets with disjoint sample ids
    but non-disjoint variant sets. We use :func:`.or_else` to attempt to find a
    non-missing genotype. If neither genotype is non-missing, then the genotype
    is set to missing. In particular, note that Samples `2` and `3` have missing
    genotypes for loci 1:1 and 1:2 because those loci are not present in `mt2`
    and these samples are not present in `mt1`

    >>> hl.set_global_seed(0)
    >>> mt1 = hl.balding_nichols_model(1, 2, 3)
    >>> mt2 = hl.balding_nichols_model(1, 2, 3)
    >>> mt2 = mt2.key_rows_by(locus=hl.locus(mt2.locus.contig,
    ...                                      mt2.locus.position+2),
    ...                       alleles=mt2.alleles)
    >>> mt2 = mt2.key_cols_by(sample_idx=mt2.sample_idx+2)
    >>> mt1.show()
    +---------------+------------+------+------+
    | locus         | alleles    | 0.GT | 1.GT |
    +---------------+------------+------+------+
    | locus<GRCh37> | array<str> | call | call |
    +---------------+------------+------+------+
    | 1:1           | ["A","C"]  | 0/1  | 0/1  |
    | 1:2           | ["A","C"]  | 1/1  | 1/1  |
    | 1:3           | ["A","C"]  | 0/0  | 0/0  |
    +---------------+------------+------+------+
    <BLANKLINE>
    >>> mt2.show()  # doctest: +SKIP_OUTPUT_CHECK
    +---------------+------------+------+------+
    | locus         | alleles    | 0.GT | 1.GT |
    +---------------+------------+------+------+
    | locus<GRCh37> | array<str> | call | call |
    +---------------+------------+------+------+
    | 1:3           | ["A","C"]  | 0/1  | 1/1  |
    | 1:4           | ["A","C"]  | 0/1  | 0/1  |
    | 1:5           | ["A","C"]  | 1/1  | 0/0  |
    +---------------+------------+------+------+
    <BLANKLINE>
    >>> mt3 = hl.experimental.full_outer_join_mt(mt1, mt2)
    >>> mt3 = mt3.select_entries(GT=hl.or_else(mt3.left_entry.GT, mt3.right_entry.GT))
    >>> mt3.show()
    +---------------+------------+------+------+------+------+
    | locus         | alleles    | 0.GT | 1.GT | 2.GT | 3.GT |
    +---------------+------------+------+------+------+------+
    | locus<GRCh37> | array<str> | call | call | call | call |
    +---------------+------------+------+------+------+------+
    | 1:1           | ["A","C"]  | 0/1  | 0/1  | NA   | NA   |
    | 1:2           | ["A","C"]  | 1/1  | 1/1  | NA   | NA   |
    | 1:3           | ["A","C"]  | 0/0  | 0/0  | 0/1  | 1/1  |
    | 1:4           | ["A","C"]  | NA   | NA   | 0/1  | 0/1  |
    | 1:5           | ["A","C"]  | NA   | NA   | 1/1  | 0/0  |
    +---------------+------------+------+------+------+------+
    <BLANKLINE>

    Parameters
    ----------
    left : :class:`.MatrixTable`
    right : :class:`.MatrixTable`

    Returns
    -------
    :class:`.MatrixTable`

    Raises
    ------
    ValueError
        If the row key types or the column key types of `left` and `right` differ.
    """

    def key_dtypes(key):
        return [field.dtype for field in key.values()]

    if key_dtypes(left.row_key) != key_dtypes(right.row_key):
        raise ValueError(f"row key types do not match:\n"
                         f"  left:  {list(left.row_key.values())}\n"
                         f"  right: {list(right.row_key.values())}")

    if key_dtypes(left.col_key) != key_dtypes(right.col_key):
        raise ValueError(f"column key types do not match:\n"
                         f"  left:  {list(left.col_key.values())}\n"
                         f"  right: {list(right.col_key.values())}")

    left_key_names = list(left.col_key)
    right_key_names = list(right.col_key)

    # Pack each side's row fields into a single struct, then move columns into
    # globals and entries into a per-row array so the two sides can be joined
    # as ordinary tables.
    left_t = left.select_rows(left_row=left.row).localize_entries('left_entries', 'left_cols')
    right_t = right.select_rows(right_row=right.row).localize_entries('right_entries', 'right_cols')
    ht = left_t.join(right_t, how='outer')

    def positions_by_key(cols, key_names):
        # dict: column key tuple -> array of positions of the columns carrying that key
        key_tuples = cols.map(lambda col: hl.tuple([col[name] for name in key_names]))
        grouped = hl.group_by(lambda pair: pair[0],
                              hl.enumerate(key_tuples, index_first=False))
        return grouped.map_values(lambda pairs: pairs.map(lambda pair: pair[1]))

    ht = ht.annotate_globals(
        left_keys=positions_by_key(ht.left_cols, left_key_names),
        right_keys=positions_by_key(ht.right_cols, right_key_names))

    def pair_positions(s):
        # Expand one column key into the output columns it produces: the cross
        # product when the key is on both sides, otherwise one column per
        # occurrence with the absent side's index missing.
        on_left = hl.is_defined(s.left_indices)
        on_right = hl.is_defined(s.right_indices)
        cross_product = hl.range(0, s.left_indices.length()).flatmap(
            lambda i: hl.range(0, s.right_indices.length()).map(
                lambda j: hl.struct(k=s.k,
                                    left_index=s.left_indices[i],
                                    right_index=s.right_indices[j])))
        left_only = s.left_indices.map(
            lambda elt: hl.struct(k=s.k, left_index=elt, right_index=hl.missing('int32')))
        right_only = s.right_indices.map(
            lambda elt: hl.struct(k=s.k, left_index=hl.missing('int32'), right_index=elt))
        return (hl.case()
                .when(on_left & on_right, cross_product)
                .when(on_left, left_only)
                .when(on_right, right_only)
                .or_error('assertion error'))

    all_keys = hl.array(ht.left_keys.key_set().union(ht.right_keys.key_set()))
    ht = ht.annotate_globals(
        key_indices=all_keys
        .map(lambda k: hl.struct(k=k,
                                 left_indices=ht.left_keys.get(k),
                                 right_indices=ht.right_keys.get(k)))
        .flatmap(pair_positions))

    # One output entry / column per element of key_indices.
    ht = ht.annotate(__entries=ht.key_indices.map(
        lambda s: hl.struct(left_entry=ht.left_entries[s.left_index],
                            right_entry=ht.right_entries[s.right_index])))
    ht = ht.annotate_globals(__cols=ht.key_indices.map(
        lambda s: hl.struct(**{name: s.k[i] for i, name in enumerate(left_key_names)},
                            left_col=ht.left_cols[s.left_index],
                            right_col=ht.right_cols[s.right_index])))

    ht = ht.drop('left_entries', 'left_cols', 'left_keys',
                 'right_entries', 'right_cols', 'right_keys',
                 'key_indices')
    return ht._unlocalize_entries('__entries', '__cols', left_key_names)
示例#5
0
def explode_trio_matrix(tm: hl.MatrixTable,
                        col_keys: List[str] = ['s'],
                        keep_trio_cols: bool = True,
                        keep_trio_entries: bool = False) -> hl.MatrixTable:
    """Splits a trio MatrixTable back into a sample MatrixTable.

    Each trio column is exploded into three columns, in the order proband,
    father, mother.

    Example
    -------
    >>> # Create a trio matrix from a sample matrix
    >>> pedigree = hl.Pedigree.read('data/case_control_study.fam')
    >>> trio_dataset = hl.trio_matrix(dataset, pedigree, complete_trios=True)

    >>> # Explode trio matrix back into a sample matrix
    >>> exploded_trio_dataset = explode_trio_matrix(trio_dataset)

    Notes
    -----
    The resulting MatrixTable column schema is the same as the proband/father/mother schema,
    and the resulting entry schema is the same as the proband_entry/father_entry/mother_entry schema.
    If the `keep_trio_cols` option is set, then an additional `source_trio` column field is added
    with the trio column data.
    If the `keep_trio_entries` option is set, then an additional `source_trio_entry` entry field is
    added with the trio entry data.

    This assumes that the input MatrixTable is a trio MatrixTable (similar to
    the result of :func:`~.trio_matrix`). Its entry schema has to contain
    `proband_entry`, `father_entry` and `mother_entry`, all with the same type.
    Its column schema has to contain `proband`, `father` and `mother`, all with
    the same type.

    Parameters
    ----------
    tm : :class:`.MatrixTable`
        Trio MatrixTable (entries have to be a Struct with `proband_entry`, `mother_entry` and `father_entry` present)
    col_keys : :obj:`list` of str
        Column key(s) for the resulting sample MatrixTable. If empty, the result has no column key.
    keep_trio_cols: bool
        Whether to add a `source_trio` column field with the trio column data (default `True`)
    keep_trio_entries: bool
        Whether to add a `source_trio_entry` entry field with the trio entry data (default `False`)

    Returns
    -------
    :class:`.MatrixTable`
        Sample MatrixTable
    """

    # Entries: gather the three members' entries into one array per trio.
    member_entries = hl.array([tm.proband_entry, tm.father_entry, tm.mother_entry])
    if keep_trio_entries:
        tm = tm.select_entries(__trio_entries=member_entries,
                               source_trio_entry=hl.struct(**tm.entry))
    else:
        tm = tm.select_entries(__trio_entries=member_entries)

    # Columns: gather the three members as (position, member) pairs so that the
    # position can later pick the matching element of '__trio_entries'.
    tm = tm.key_cols_by()
    indexed_members = hl.enumerate(hl.array([tm.proband, tm.father, tm.mother]))
    if keep_trio_cols:
        tm = tm.select_cols(__trio_members=indexed_members,
                            source_trio=hl.struct(**tm.col))
    else:
        tm = tm.select_cols(__trio_members=indexed_members)

    # One column per trio member.
    mt = tm.explode_cols(tm['__trio_members'])

    # Replace the entry array with the fields of this member's own entry.
    mt = mt.transmute_entries(**mt['__trio_entries'][mt['__trio_members'][0]])

    # Replace the (position, member) pair with the member's own column fields.
    mt = mt.key_cols_by()
    mt = mt.transmute_cols(**mt['__trio_members'][1])

    if not col_keys:
        return mt
    return mt.key_cols_by(*col_keys)
示例#6
0
def segment_reference_blocks(ref: 'MatrixTable',
                             intervals: 'Table') -> 'MatrixTable':
    """Returns a matrix table of reference blocks segmented according to intervals.

    Loci outside the given intervals are discarded. Reference blocks that start before
    but span an interval will appear at the interval start locus.

    Note
    ----
        Assumes disjoint intervals which do not span contigs.

        Requires start-inclusive intervals.

    Parameters
    ----------
    ref : :class:`.MatrixTable`
        MatrixTable of reference blocks.
    intervals : :class:`.Table`
        Table of intervals at which to segment reference blocks.

    Returns
    -------
    :class:`.MatrixTable`

    Raises
    ------
    ValueError
        If the first key field of `intervals` is not an interval of the locus
        type of `ref`, if any interval is not start-inclusive, or if any
        interval starts and ends on different contigs.
    """
    interval_field = list(intervals.key)[0]
    if not intervals[interval_field].dtype == hl.tinterval(ref.locus.dtype):
        raise ValueError(
            f"expect intervals to be keyed by intervals of loci matching the VariantDataset:"
            f" found {intervals[interval_field].dtype} / {ref.locus.dtype}")
    intervals = intervals.select(_interval_dup=intervals[interval_field])

    # Validate both requirements in a single aggregation pass, but report them
    # separately so the error names the condition that actually failed.
    interval_expr = intervals[interval_field]
    checks = intervals.aggregate(hl.struct(
        start_inclusive=hl.agg.all(interval_expr.includes_start),
        single_contig=hl.agg.all(interval_expr.start.contig == interval_expr.end.contig)))
    if not checks.start_inclusive:
        raise ValueError("expect intervals to be start-inclusive")
    if not checks.single_contig:
        raise ValueError("expect intervals not to span contigs")

    # Outer-join the interval start loci onto the localized reference rows so
    # that every interval start has a row, even where no block starts there.
    starts = intervals.key_by(_start_locus=intervals[interval_field].start)
    starts = starts.annotate(_include_locus=True)
    refl = ref.localize_entries('_ref_entries', '_ref_cols')
    joined = refl.join(starts, how='outer')
    rg = ref.locus.dtype.reference_genome
    contigs = rg.contigs
    contig_idx_map = hl.literal({contig: i for i, contig in enumerate(contigs)},
                                'dict<str, int32>')
    # Tag every entry with the index of its row's contig so that an entry seen
    # earlier in the scan can be checked against the current row's contig.
    joined = joined.annotate(__contig_idx=contig_idx_map[joined.locus.contig])
    joined = joined.annotate(_ref_entries=joined._ref_entries.map(
        lambda e: e.annotate(__contig_idx=joined.__contig_idx)))
    # For interval start rows only: per column, prefer the entry at this row;
    # otherwise fall back to the entry supplied by the densifying scan
    # (presumably the latest preceding block for that column -- confirm against
    # `hl.scan._densify`), kept only if it is on the same contig and its END
    # reaches this position.
    dense = joined.annotate(dense_ref=hl.or_missing(
        joined._include_locus,
        hl.rbind(
            joined.locus.position, lambda pos: hl.enumerate(
                hl.scan._densify(hl.len(joined._ref_cols), joined._ref_entries)
            ).map(lambda idx_and_e: hl.rbind(
                idx_and_e[0], idx_and_e[1], lambda idx, e: hl.coalesce(
                    joined._ref_entries[idx],
                    hl.or_missing((e.__contig_idx == joined.__contig_idx) &
                                  (e.END >= pos), e))).drop('__contig_idx')))))
    dense = dense.filter(dense._include_locus).drop('_interval_dup',
                                                    '_include_locus',
                                                    '__contig_idx')

    # at this point, 'dense' is a table with dense rows of reference blocks, keyed by locus

    refl_filtered = refl.annotate(
        **{interval_field: intervals[refl.locus]._interval_dup})

    # remove rows that are not contained in an interval, and rows that are the start of an
    # interval (interval starts come from the 'dense' table)
    refl_filtered = refl_filtered.filter(
        hl.is_defined(refl_filtered[interval_field])
        & (refl_filtered.locus != refl_filtered[interval_field].start))

    # union dense interval starts with filtered table
    refl_filtered = refl_filtered.union(
        dense.transmute(_ref_entries=dense.dense_ref))

    # rewrite reference blocks to end at the first of (interval end, reference block end);
    # the interval's last covered position is its end position, minus one when the
    # end is exclusive
    refl_filtered = refl_filtered.annotate(
        interval_end=refl_filtered[interval_field].end.position -
        ~refl_filtered[interval_field].includes_end)
    refl_filtered = refl_filtered.annotate(
        _ref_entries=refl_filtered._ref_entries.map(
            lambda entry: entry.annotate(END=hl.min(entry.END, refl_filtered.
                                                    interval_end))))

    return refl_filtered._unlocalize_entries('_ref_entries', '_ref_cols',
                                             list(ref.col_key))