def sparse_split_multi(sparse_mt):
    """Splits multiallelic variants on a sparse MatrixTable.

    Takes a dataset formatted like the output of :func:`.vcf_combiner`. The
    splitting will add `was_split` and `a_index` fields, as :func:`.split_multi`
    does. This function drops the `LA` (local alleles) field, as it re-computes
    entry fields based on the new, split global alleles.

    Variants are split thus:

    - A row with only one (reference) or two (reference and alternate) alleles
      is unchanged, as local and global alleles are the same.

    - A row with multiple alternate alleles will be split, with one row for
      each alternate allele, and each row will contain two alleles: ref and
      alt. The reference and alternate allele will be minrepped using
      :func:`.min_rep`.

    The split multi logic handles the following entry fields:

    .. code-block:: text

      struct {
        LGT: call
        LAD: array<int32>
        DP: int32
        GQ: int32
        LPL: array<int32>
        RGQ: int32
        LPGT: call
        LA: array<int32>
        END: int32
      }

    All fields except for `LA` are optional, and only handled if they exist.

    - `LA` is used to find the corresponding local allele index for the
      desired global `a_index`, and is then dropped from the resulting dataset.
      If `LA` does not contain the global `a_index`, the index for the
      `<NON_REF>` allele is used to process the entry fields.

    - `LGT` and `LPGT` are downcoded using the corresponding local `a_index`.
      They are renamed to `GT` and `PGT` respectively, as the resulting call is
      no longer local.

    - `LAD` is used to create an `AD` field consisting of the allele depths
      corresponding to the reference, global `a_index` allele, and `<NON_REF>`
      allele.

    - `DP` is preserved unchanged.

    - `GQ` is recalculated from the updated `PL`, if it exists, but otherwise
      preserved unchanged.

    - `PL` array elements are calculated from the minimum `LPL` value for all
      allele pairs that downcode to the desired one. (This logic is identical
      to the `PL` logic in :func:`.split_multi_hts`.) If a row has an alternate
      allele but it is not present in `LA`, the `PL` field is set to missing.
      The `PL` for `ref/<NON_REF>` in that case can be drawn from `RGQ`.

    - `RGQ` (the reference genotype quality) is preserved unchanged.

    - `END` is untouched.

    Notes
    -----
    This version of split-multi doesn't deal with duplicate loci, in which
    case the explode could possibly result in out-of-order rows (the actual
    split_multi function also doesn't handle that case).

    It also checks that min-repping will not change the locus and will error
    if it does. (I believe the VCF combiner currently checks that this holds.)

    Parameters
    ----------
    sparse_mt : :class:`.MatrixTable`
        Sparse MatrixTable to split.

    Returns
    -------
    :class:`.MatrixTable`
        The split MatrixTable in sparse format.
""" hl.methods.misc.require_row_key_variant(sparse_mt, "sparse_split_multi") entries = hl.utils.java.Env.get_uid() cols = hl.utils.java.Env.get_uid() ds = sparse_mt.localize_entries(entries, cols) new_id = hl.utils.java.Env.get_uid() def struct_from_min_rep(i): return hl.bind(lambda mr: (hl.case() .when(ds.locus == mr.locus, hl.struct( locus=ds.locus, alleles=[mr.alleles[0], mr.alleles[1]], a_index=i, was_split=True)) .or_error("Found non-left-aligned variant in sparse_split_multi")), hl.min_rep(ds.locus, [ds.alleles[0], ds.alleles[i]])) explode_structs = hl.cond(hl.len(ds.alleles) < 3, [hl.struct( locus=ds.locus, alleles=ds.alleles, a_index=1, was_split=False)], hl._sort_by( hl.range(1, hl.len(ds.alleles)) .map(struct_from_min_rep), lambda l, r: hl._compare(l.alleles, r.alleles) < 0 )) ds = ds.annotate(**{new_id: explode_structs}).explode(new_id) def transform_entries(old_entry): def with_local_a_index(local_a_index): new_pl = hl.or_missing( hl.is_defined(old_entry.LPL), hl.or_missing( hl.is_defined(local_a_index), hl.range(0, 3).map(lambda i: hl.min( hl.range(0, hl.triangle(hl.len(old_entry.LA))) .filter(lambda j: hl.downcode(hl.unphased_diploid_gt_index_call(j), local_a_index) == hl.unphased_diploid_gt_index_call(i)) .map(lambda idx: old_entry.LPL[idx]))))) fields = set(old_entry.keys()) def with_pl(pl): new_exprs = {} dropped_fields = ['LA'] if 'LGT' in fields: new_exprs['GT'] = hl.downcode(old_entry.LGT, hl.or_else(local_a_index, hl.len(old_entry.LA))) dropped_fields.append('LGT') if 'LPGT' in fields: new_exprs['PGT'] = hl.downcode(old_entry.LPGT, hl.or_else(local_a_index, hl.len(old_entry.LA))) dropped_fields.append('LPGT') if 'LAD' in fields: new_exprs['AD'] = hl.or_missing( hl.is_defined(old_entry.LAD), [old_entry.LAD[0], hl.or_else(old_entry.LAD[local_a_index], 0)]) # second entry zeroed for lack of non-ref AD dropped_fields.append('LAD') if 'LPL' in fields: new_exprs['PL'] = pl if 'GQ' in fields: new_exprs['GQ'] = hl.or_else(hl.gq_from_pl(pl), old_entry.GQ) dropped_fields.append('LPL') return hl.cond(hl.len(ds.alleles) == 1, old_entry.annotate(**{f[1:]: old_entry[f] for f in ['LGT', 'LPGT', 'LAD', 'LPL'] if f in fields}).drop(*dropped_fields), old_entry.annotate(**new_exprs).drop(*dropped_fields)) if 'LPL' in fields: return hl.bind(with_pl, new_pl) else: return with_pl(None) lai = hl.fold(lambda accum, elt: hl.cond(old_entry.LA[elt] == ds[new_id].a_index, elt, accum), hl.null(hl.tint32), hl.range(0, hl.len(old_entry.LA))) return hl.bind(with_local_a_index, lai) new_row = ds.row.annotate(**{ 'locus': ds[new_id].locus, 'alleles': ds[new_id].alleles, 'a_index': ds[new_id].a_index, 'was_split': ds[new_id].was_split, entries: ds[entries].map(transform_entries) }).drop(new_id) ds = hl.Table( hl.ir.TableKeyBy( hl.ir.TableMapRows( hl.ir.TableKeyBy(ds._tir, ['locus']), new_row._ir), ['locus', 'alleles'], is_sorted=True)) return ds._unlocalize_entries(entries, cols, list(sparse_mt.col_key.keys()))
def sparse_split_multi(sparse_mt, *, filter_changed_loci=False):
    """Splits multiallelic variants on a sparse MatrixTable.

    Takes a dataset formatted like the output of :func:`.vcf_combiner`. The
    splitting will add `was_split` and `a_index` fields, as :func:`.split_multi`
    does. This function drops the `LA` (local alleles) field, as it re-computes
    entry fields based on the new, split global alleles.

    Variants are split thus:

    - A row with only one (reference) or two (reference and alternate) alleles
      is unchanged, as local and global alleles are the same.

    - A row with multiple alternate alleles will be split, with one row for
      each alternate allele, and each row will contain two alleles: ref and
      alt. The reference and alternate allele will be minrepped using
      :func:`.min_rep`.

    The split multi logic handles the following entry fields:

    .. code-block:: text

      struct {
        LGT: call
        LAD: array<int32>
        DP: int32
        GQ: int32
        LPL: array<int32>
        RGQ: int32
        LPGT: call
        LA: array<int32>
        END: int32
      }

    All fields except for `LA` are optional, and only handled if they exist.

    - `LA` is used to find the corresponding local allele index for the
      desired global `a_index`, and is then dropped from the resulting dataset.
      If `LA` does not contain the global `a_index`, calls will be downcoded to
      hom ref and `PL` will be set to missing.

    - `LGT` and `LPGT` are downcoded using the corresponding local `a_index`.
      They are renamed to `GT` and `PGT` respectively, as the resulting call is
      no longer local.

    - `LAD` is used to create an `AD` field consisting of the allele depths
      corresponding to the reference and global `a_index` alleles.

    - `DP` is preserved unchanged.

    - `GQ` is recalculated from the updated `PL`, if it exists, but otherwise
      preserved unchanged.

    - `PL` array elements are calculated from the minimum `LPL` value for all
      allele pairs that downcode to the desired one. (This logic is identical
      to the `PL` logic in :func:`.split_multi_hts`.) If a row has an alternate
      allele but it is not present in `LA`, the `PL` field is set to missing.
      The `PL` for `ref/<NON_REF>` in that case can be drawn from `RGQ`.

    - `RGQ` (the reference genotype quality) is preserved unchanged.

    - `END` is untouched.

    Notes
    -----
    This version of split-multi doesn't deal with duplicate loci, in which
    case the explode could possibly result in out-of-order rows (the actual
    split_multi function also doesn't handle that case).

    It also checks that min-repping will not change the locus and will error
    if it does. (I believe the VCF combiner currently checks that this holds.)

    Parameters
    ----------
    sparse_mt : :class:`.MatrixTable`
        Sparse MatrixTable to split.
    filter_changed_loci : :obj:`bool`
        Rather than erroring if any REF/ALT pair changes locus under
        :func:`.min_rep`, filter that variant instead.

    Returns
    -------
    :class:`.MatrixTable`
        The split MatrixTable in sparse format.
    """
    hl.methods.misc.require_row_key_variant(sparse_mt, "sparse_split_multi")
    entries = hl.utils.java.Env.get_uid()
    cols = hl.utils.java.Env.get_uid()
    ds = sparse_mt.localize_entries(entries, cols)
    new_id = hl.utils.java.Env.get_uid()

    def struct_from_min_rep(i):
        return hl.bind(
            lambda mr: (hl.case()
                        .when(ds.locus == mr.locus,
                              hl.struct(locus=ds.locus,
                                        alleles=[mr.alleles[0], mr.alleles[1]],
                                        a_index=i,
                                        was_split=True))
                        .when(filter_changed_loci,
                              hl.null(hl.tstruct(locus=ds.locus.dtype,
                                                 alleles=hl.tarray(hl.tstr),
                                                 a_index=hl.tint,
                                                 was_split=hl.tbool)))
or_error("Found non-left-aligned variant in sparse_split_multi\n" + "old locus: " + hl.str(ds.locus) + "\n" + "old ref : " + ds.alleles[0] + "\n" + "old alt : " + ds.alleles[ i] + "\n" + "mr locus : " + hl.str( mr.locus) + "\n" + "mr ref : " + mr.alleles[ 0] + "\n" + "mr alt : " + mr.alleles[1])), hl.min_rep(ds.locus, [ds.alleles[0], ds.alleles[i]])) explode_structs = hl.cond( hl.len(ds.alleles) < 3, [ hl.struct( locus=ds.locus, alleles=ds.alleles, a_index=1, was_split=False) ], hl._sort_by( hl.cond( filter_changed_loci, hl.range(1, hl.len(ds.alleles)).map(struct_from_min_rep).filter( hl.is_defined), hl.range(1, hl.len(ds.alleles)).map(struct_from_min_rep)), lambda l, r: hl._compare(l.alleles, r.alleles) < 0)) ds = ds.annotate(**{new_id: explode_structs}).explode(new_id) def transform_entries(old_entry): def with_local_a_index(local_a_index): fields = set(old_entry.keys()) def with_pl(pl): new_exprs = {} dropped_fields = ['LA'] if 'LGT' in fields: new_exprs['GT'] = hl.downcode( old_entry.LGT, hl.or_else(local_a_index, hl.len(old_entry.LA))) dropped_fields.append('LGT') if 'LPGT' in fields: new_exprs['PGT'] = hl.downcode( old_entry.LPGT, hl.or_else(local_a_index, hl.len(old_entry.LA))) dropped_fields.append('LPGT') if 'LAD' in fields: non_ref_ad = hl.or_else(old_entry.LAD[local_a_index], 0) # zeroed if not in LAD new_exprs['AD'] = hl.or_missing( hl.is_defined(old_entry.LAD), [hl.sum(old_entry.LAD) - non_ref_ad, non_ref_ad]) dropped_fields.append('LAD') if 'LPL' in fields: new_exprs['PL'] = pl if 'GQ' in fields: new_exprs['GQ'] = hl.or_else(hl.gq_from_pl(pl), old_entry.GQ) dropped_fields.append('LPL') return (hl.case().when( hl.len(ds.alleles) == 1, old_entry.annotate( **{ f[1:]: old_entry[f] for f in ['LGT', 'LPGT', 'LAD', 'LPL'] if f in fields }).drop(*dropped_fields)).when( hl.or_else(old_entry.LGT.is_hom_ref(), False), old_entry.annotate( **{ f: old_entry[f'L{f}'] if f in ['GT', 'PGT'] else e for f, e in new_exprs.items() }).drop(*dropped_fields)).default( old_entry.annotate(**new_exprs).drop( *dropped_fields))) if 'LPL' in fields: new_pl = hl.or_missing( hl.is_defined(old_entry.LPL), hl.or_missing( hl.is_defined(local_a_index), hl.range(0, 3).map(lambda i: hl.min( hl.range(0, hl.triangle(hl.len(old_entry.LA))). filter(lambda j: hl.downcode( hl.unphased_diploid_gt_index_call(j), local_a_index) == hl. unphased_diploid_gt_index_call(i)).map( lambda idx: old_entry.LPL[idx]))))) return hl.bind(with_pl, new_pl) else: return with_pl(None) lai = hl.fold( lambda accum, elt: hl.cond(old_entry.LA[elt] == ds[new_id].a_index, elt, accum), hl.null(hl.tint32), hl.range(0, hl.len(old_entry.LA))) return hl.bind(with_local_a_index, lai) new_row = ds.row.annotate( **{ 'locus': ds[new_id].locus, 'alleles': ds[new_id].alleles, 'a_index': ds[new_id].a_index, 'was_split': ds[new_id].was_split, entries: ds[entries].map(transform_entries) }).drop(new_id) ds = hl.Table( hl.ir.TableKeyBy(hl.ir.TableMapRows( hl.ir.TableKeyBy(ds._tir, ['locus']), new_row._ir), ['locus', 'alleles'], is_sorted=True)) return ds._unlocalize_entries(entries, cols, list(sparse_mt.col_key.keys()))
def segment_intervals(ht, points):
    """Segment the interval keys of `ht` at a given set of points.

    Parameters
    ----------
    ht : :class:`.Table`
        Table with interval keys.
    points : :class:`.Table` or :class:`.ArrayExpression`
        Points at which to segment the intervals, a table or an array.

    Returns
    -------
    :class:`.Table`
    """
    if len(ht.key) != 1 or not isinstance(ht.key[0].dtype, hl.tinterval):
        raise ValueError("'segment_intervals' expects a table with interval keys")
    point_type = ht.key[0].dtype.point_type
    if isinstance(points, Table):
        if len(points.key) != 1 or points.key[0].dtype != point_type:
            raise ValueError("'segment_intervals' expects points to be a table with a single"
                             " key of the same type as the intervals in 'ht', or an array of those points:"
                             f"\n  expected: {point_type}, found: {list(points.key.dtype.values())}")
        points = hl.array(hl.set(points.collect(_localize=False)))
    if points.dtype.element_type != point_type:
        raise ValueError("'segment_intervals' expects points to be a table with a single"
                         " key of the same type as the intervals in 'ht', or an array of those points:"
                         f"\n  expected: {point_type}, found: {points.dtype.element_type}")

    points = hl._sort_by(points, lambda l, r: hl._compare(l, r) < 0)

    ht = ht.annotate_globals(__points=points)

    interval = ht.key[0]
    points = ht.__points
    lower = hl.expr.functions._lower_bound(points, interval.start)
    higher = hl.expr.functions._lower_bound(points, interval.end)
    n_points = hl.len(points)
    lower = hl.if_else((lower < n_points) & (points[lower] == interval.start), lower + 1, lower)
    higher = hl.if_else((higher < n_points) & (points[higher] == interval.end), higher - 1, higher)
    interval_results = hl.rbind(
        lower,
        higher,
        lambda lower, higher: hl.cond(
            lower >= higher,
            [interval],
            hl.flatten([
                [hl.interval(interval.start,
                             points[lower],
                             includes_start=interval.includes_start,
                             includes_end=False)],
                hl.range(lower, higher - 1).map(
                    lambda x: hl.interval(points[x],
                                          points[x + 1],
                                          includes_start=True,
                                          includes_end=False)),
                [hl.interval(points[higher - 1],
                             interval.end,
                             includes_start=True,
                             includes_end=interval.includes_end)],
            ])))
    ht = ht.annotate(__new_intervals=interval_results,
                     lower=lower,
                     higher=higher).explode('__new_intervals')
    return ht.key_by(**{list(ht.key)[0]: ht.__new_intervals}).drop('__new_intervals')
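

# A hedged usage sketch of segment_intervals, added for illustration (the
# interval and point values are arbitrary, and the function is never called at
# import time). A table keyed by the single interval [0, 10) is segmented at
# the points 3 and 7, producing one row per sub-interval:
# [0, 3), [3, 7), [7, 10).
def _example_segment_intervals_usage():
    ht = hl.utils.range_table(1)
    ht = ht.key_by(interval=hl.interval(ht.idx, ht.idx + 10)).drop('idx')
    points = hl.literal([3, 7])
    segmented = segment_intervals(ht, points)
    segmented.show()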