Example #1
 def test_import_vcf_flags_are_defined(self):
     # issue 3277
     t = hl.import_vcf(resource('sample.vcf')).rows()
     self.assertTrue(t.all(hl.is_defined(t.info.NEGATIVE_TRAIN_SITE) &
                           hl.is_defined(t.info.POSITIVE_TRAIN_SITE) &
                           hl.is_defined(t.info.DB) &
                           hl.is_defined(t.info.DS)))
Example #2
    def recur_expr(expr, path):
        d = {}
        missingness = append_agg(hl.agg.count_where(hl.is_missing(expr)))
        d['type'] = lambda _: str(expr.dtype)
        d['missing'] = lambda results: (
            f'{results[missingness]} values ({pct(results[missingness] / results[count])})')

        t = expr.dtype

        if t in (hl.tint32, hl.tint64, hl.tfloat32, hl.tfloat64):
            stats = append_agg(hl.agg.stats(expr))
            if t in (hl.tint32, hl.tint64):
                d['minimum'] = lambda results: format(map_int(results[stats]['min']))
                d['maximum'] = lambda results: format(map_int(results[stats]['max']))
                d['sum'] = lambda results: format(map_int(results[stats]['sum']))
            else:
                d['minimum'] = lambda results: format(results[stats]['min'])
                d['maximum'] = lambda results: format(results[stats]['max'])
                d['sum'] = lambda results: format(results[stats]['sum'])
            d['mean'] = lambda results: format(results[stats]['mean'])
            d['stdev'] = lambda results: format(results[stats]['stdev'])
        elif t == hl.tbool:
            counter = append_agg(hl.agg.filter(hl.is_defined(expr), hl.agg.counter(expr)))
            d['counts'] = lambda results: format(results[counter])
        elif t == hl.tstr:
            size = append_agg(hl.agg.stats(hl.len(expr)))
            take = append_agg(hl.agg.filter(hl.is_defined(expr), hl.agg.take(expr, 5)))
            d['minimum size'] = lambda results: format(map_int(results[size]['min']))
            d['maximum size'] = lambda results: format(map_int(results[size]['max']))
            d['mean size'] = lambda results: format(results[size]['mean'])
            d['sample values'] = lambda results: format(results[take])
        elif t == hl.tcall:
            ploidy_counts = append_agg(hl.agg.filter(hl.is_defined(expr), hl.agg.counter(expr.ploidy)))
            phased_counts = append_agg(hl.agg.filter(hl.is_defined(expr), hl.agg.counter(expr.phased)))
            n_hom_ref = append_agg(hl.agg.count_where(expr.is_hom_ref()))
            n_hom_var = append_agg(hl.agg.count_where(expr.is_hom_var()))
            n_het = append_agg(hl.agg.count_where(expr.is_het()))
            d['homozygous reference'] = lambda results: format(results[n_hom_ref])
            d['heterozygous'] = lambda results: format(results[n_het])
            d['homozygous variant'] = lambda results: format(results[n_hom_var])
            d['ploidy'] = lambda results: format(results[ploidy_counts])
            d['phased'] = lambda results: format(results[phased_counts])
        elif isinstance(t, hl.tlocus):
            contig_counts = append_agg(hl.agg.filter(hl.is_defined(expr), hl.agg.counter(expr.contig)))
            d['contig counts'] = lambda results: format(results[contig_counts])
        elif isinstance(t, (hl.tset, hl.tdict, hl.tarray)):
            size = append_agg(hl.agg.stats(hl.len(expr)))
            d['minimum size'] = lambda results: format(map_int(results[size]['min']))
            d['maximum size'] = lambda results: format(map_int(results[size]['max']))
            d['mean size'] = lambda results: format(results[size]['mean'])
        to_print.append((path, d))
        if isinstance(t, hl.ttuple):
            for i in range(len(expr)):
                recur_expr(expr[i], f'{path} / {i}')
        if isinstance(t, hl.tstruct):
            for k, v in expr.items():
                recur_expr(v, f'{path} / {repr(k)[1:-1]}')
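The body above leans on helpers from its enclosing, unshown scope: `append_agg` registers an aggregator and returns a handle, `count` is a previously registered row count, and `pct`, `format`, and `map_int` pretty-print results. A minimal self-contained sketch of the same register-then-resolve pattern (every name here is a hypothetical reconstruction, not the original implementation):

import hail as hl

agg_exprs = []

def append_agg(agg_expr):
    # register an aggregator, hand back its index as an opaque handle
    agg_exprs.append(agg_expr)
    return len(agg_exprs) - 1

ht = hl.utils.range_table(10)
n_missing = append_agg(hl.agg.count_where(hl.is_missing(ht.idx)))
idx_stats = append_agg(hl.agg.stats(ht.idx))

# run every registered aggregator in a single pass, then resolve the handles
results = ht.aggregate(hl.tuple(agg_exprs))
print(results[n_missing], results[idx_stats].mean)  # 0 4.5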
Example #3
    def test_reference_genome_liftover(self):
        grch37 = hl.get_reference('GRCh37')
        grch38 = hl.get_reference('GRCh38')

        self.assertTrue(not grch37.has_liftover('GRCh38') and not grch38.has_liftover('GRCh37'))
        grch37.add_liftover(resource('grch37_to_grch38_chr20.over.chain.gz'), 'GRCh38')
        grch38.add_liftover(resource('grch38_to_grch37_chr20.over.chain.gz'), 'GRCh37')
        self.assertTrue(grch37.has_liftover('GRCh38') and grch38.has_liftover('GRCh37'))

        ds = hl.import_vcf(resource('sample.vcf'))
        t = ds.annotate_rows(liftover=hl.liftover(hl.liftover(ds.locus, 'GRCh38'), 'GRCh37')).rows()
        self.assertTrue(t.all(t.locus == t.liftover))

        null_locus = hl.null(hl.tlocus('GRCh38'))

        rows = [
            {'l37': hl.locus('20', 1, 'GRCh37'), 'l38': null_locus},
            {'l37': hl.locus('20', 60000, 'GRCh37'), 'l38': null_locus},
            {'l37': hl.locus('20', 60001, 'GRCh37'), 'l38': hl.locus('chr20', 79360, 'GRCh38')},
            {'l37': hl.locus('20', 278686, 'GRCh37'), 'l38': hl.locus('chr20', 298045, 'GRCh38')},
            {'l37': hl.locus('20', 278687, 'GRCh37'), 'l38': hl.locus('chr20', 298046, 'GRCh38')},
            {'l37': hl.locus('20', 278688, 'GRCh37'), 'l38': null_locus},
            {'l37': hl.locus('20', 278689, 'GRCh37'), 'l38': null_locus},
            {'l37': hl.locus('20', 278690, 'GRCh37'), 'l38': null_locus},
            {'l37': hl.locus('20', 278691, 'GRCh37'), 'l38': hl.locus('chr20', 298047, 'GRCh38')},
            {'l37': hl.locus('20', 37007586, 'GRCh37'), 'l38': hl.locus('chr12', 32563117, 'GRCh38')},
            {'l37': hl.locus('20', 62965520, 'GRCh37'), 'l38': hl.locus('chr20', 64334167, 'GRCh38')},
            {'l37': hl.locus('20', 62965521, 'GRCh37'), 'l38': null_locus}
        ]
        schema = hl.tstruct(l37=hl.tlocus(grch37), l38=hl.tlocus(grch38))
        t = hl.Table.parallelize(rows, schema)
        self.assertTrue(t.all(hl.cond(hl.is_defined(t.l38),
                                      hl.liftover(t.l37, 'GRCh38') == t.l38,
                                      hl.is_missing(hl.liftover(t.l37, 'GRCh38')))))

        t = t.filter(hl.is_defined(t.l38))
        self.assertTrue(t.count() == 6)

        t = t.key_by('l38')
        t.count()
        self.assertTrue(list(t.key) == ['l38'])

        null_locus_interval = hl.null(hl.tinterval(hl.tlocus('GRCh38')))
        rows = [
            {'i37': hl.locus_interval('20', 1, 60000, True, False, 'GRCh37'), 'i38': null_locus_interval},
            {'i37': hl.locus_interval('20', 60001, 82456, True, True, 'GRCh37'),
             'i38': hl.locus_interval('chr20', 79360, 101815, True, True, 'GRCh38')}
        ]
        schema = hl.tstruct(i37=hl.tinterval(hl.tlocus(grch37)), i38=hl.tinterval(hl.tlocus(grch38)))
        t = hl.Table.parallelize(rows, schema)
        self.assertTrue(t.all(hl.liftover(t.i37, 'GRCh38') == t.i38))

        grch37.remove_liftover("GRCh38")
        grch38.remove_liftover("GRCh37")
Example #4
 def test_refs_with_process_joins(self):
     mt = hl.utils.range_matrix_table(10, 10)
     mt = mt.annotate_entries(
         a_literal=hl.literal(['a']),
         a_col_join=hl.is_defined(mt.cols()[mt.col_key]),
         a_row_join=hl.is_defined(mt.rows()[mt.row_key]),
         an_entry_join=hl.is_defined(mt[mt.row_key, mt.col_key]),
         the_global_failure=hl.cond(True, mt.globals, hl.null(mt.globals.dtype)),
         the_row_failure=hl.cond(True, mt.row, hl.null(mt.row.dtype)),
         the_col_failure=hl.cond(True, mt.col, hl.null(mt.col.dtype)),
         the_entry_failure=hl.cond(True, mt.entry, hl.null(mt.entry.dtype)),
     )
     mt.count()
Example #5
def mean(expr) -> Float64Expression:
    """Compute the mean value of records of `expr`.

    Examples
    --------
    Compute the mean of field `HT`:

    >>> table1.aggregate(agg.mean(table1.HT))
    66.75

    Notes
    -----
    Missing values are ignored.

    Parameters
    ----------
    expr : :class:`.NumericExpression`
        Numeric expression.

    Returns
    -------
    :class:`.Expression` of type :py:data:`.tfloat64`
        Mean value of records of `expr`.
    """
    return sum(expr) / count_where(hl.is_defined(expr))
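A short check of the missingness note, using the public `hl.agg.mean` (a sketch, assuming `import hail as hl`):

import hail as hl

ht = hl.utils.range_table(3)
ht = ht.annotate(HT=hl.or_missing(ht.idx > 0, ht.idx))  # row 0 is missing
assert ht.aggregate(hl.agg.mean(ht.HT)) == 1.5  # (1 + 2) / 2; the missing row is ignored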
Example #6
def densify(sparse_mt):
    """Convert sparse MatrixTable to a dense one.

    Parameters
    ----------
    sparse_mt : :class:`.MatrixTable`
        Sparse MatrixTable to densify.  The first row key field must
        be named ``locus`` and have type ``locus``.  Must have an
        ``END`` entry field of type ``int32``.

    Returns
    -------
    :class:`.MatrixTable`
        The densified MatrixTable.  The ``END`` entry field is dropped.

    """
    if list(sparse_mt.row_key)[0] != 'locus' or not isinstance(sparse_mt.locus.dtype, hl.tlocus):
        raise ValueError("first row key field must be named 'locus' and have type 'locus'")
    if 'END' not in sparse_mt.entry or sparse_mt.END.dtype != hl.tint32:
        raise ValueError("'densify' requires 'END' entry field of type 'int32'")
    col_key_fields = list(sparse_mt.col_key)

    mt = sparse_mt
    mt = mt.annotate_entries(__contig=mt.locus.contig)
    t = mt._localize_entries('__entries', '__cols')
    t = t.annotate(
        __entries = hl.rbind(
            hl.scan.array_agg(
                lambda entry: hl.scan._prev_nonnull(hl.or_missing(hl.is_defined(entry.END), entry)),
                t.__entries),
            lambda prev_entries: hl.map(
                lambda i:
                hl.rbind(
                    prev_entries[i], t.__entries[i],
                    lambda prev_entry, entry:
                    hl.cond(
                        (~hl.is_defined(entry) &
                         (prev_entry.END >= t.locus.position) &
                         (prev_entry.__contig == t.locus.contig)),
                        prev_entry,
                        entry)),
                hl.range(0, hl.len(t.__entries)))))
    mt = t._unlocalize_entries('__entries', '__cols', col_key_fields)
    mt = mt.drop('__contig', 'END')
    return mt
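A hedged usage sketch; the paths are hypothetical, and the input must satisfy the `locus` and `END` requirements spelled out in the docstring:

import hail as hl

sparse_mt = hl.read_matrix_table('gs://my-bucket/sparse.mt')  # hypothetical path
dense_mt = densify(sparse_mt)  # reference blocks filled in; 'END' is dropped
dense_mt.write('gs://my-bucket/dense.mt')  # hypothetical path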
Example #7
    def test_undeclared_info(self):
        mt = hl.import_vcf(resource('undeclaredinfo.vcf'))

        rows = mt.rows()
        self.assertTrue(rows.all(hl.is_defined(rows.info)))

        info_type = mt.row.dtype['info']
        self.assertTrue('InbreedingCoeff' in info_type)
        self.assertFalse('undeclared' in info_type)
        self.assertFalse('undeclaredFlag' in info_type)
Example #8
        def with_local_a_index(local_a_index):
            new_pl = hl.or_missing(
                hl.is_defined(old_entry.LPL),
                hl.or_missing(
                    hl.is_defined(local_a_index),
                    hl.range(0, 3).map(lambda i: hl.min(
                        hl.range(0, hl.triangle(hl.len(old_entry.LA)))
                            .filter(lambda j: hl.downcode(hl.unphased_diploid_gt_index_call(j), local_a_index) == hl.unphased_diploid_gt_index_call(i))
                            .map(lambda idx: old_entry.LPL[idx])))))
            fields = set(old_entry.keys())

            def with_pl(pl):
                new_exprs = {}
                dropped_fields = ['LA']
                if 'LGT' in fields:
                    new_exprs['GT'] = hl.downcode(old_entry.LGT, hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LGT')
                if 'LPGT' in fields:
                    new_exprs['PGT'] = hl.downcode(old_entry.LPGT, hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LPGT')
                if 'LAD' in fields:
                    new_exprs['AD'] = hl.or_missing(
                        hl.is_defined(old_entry.LAD),
                        [old_entry.LAD[0], hl.or_else(old_entry.LAD[local_a_index], 0)]) # second entry zeroed for lack of non-ref AD
                    dropped_fields.append('LAD')
                if 'LPL' in fields:
                    new_exprs['PL'] = pl
                    if 'GQ' in fields:
                        new_exprs['GQ'] = hl.or_else(hl.gq_from_pl(pl), old_entry.GQ)

                    dropped_fields.append('LPL')

                return hl.cond(
                    hl.len(ds.alleles) == 1,
                    old_entry.annotate(**{f[1:]: old_entry[f]
                                          for f in ['LGT', 'LPGT', 'LAD', 'LPL']
                                          if f in fields}).drop(*dropped_fields),
                    old_entry.annotate(**new_exprs).drop(*dropped_fields))

            if 'LPL' in fields:
                return hl.bind(with_pl, new_pl)
            else:
                return with_pl(None)
Example #9
    def test_joins(self):
        kt = hl.utils.range_table(1).key_by().drop('idx')
        kt = kt.annotate(a='foo')

        kt1 = hl.utils.range_table(1).key_by().drop('idx')
        kt1 = kt1.annotate(a='foo', b='bar').key_by('a')

        kt2 = hl.utils.range_table(1).key_by().drop('idx')
        kt2 = kt2.annotate(b='bar', c='baz').key_by('b')

        kt3 = hl.utils.range_table(1).key_by().drop('idx')
        kt3 = kt3.annotate(c='baz', d='qux').key_by('c')

        kt4 = hl.utils.range_table(1).key_by().drop('idx')
        kt4 = kt4.annotate(d='qux', e='quam').key_by('d')

        ktr = kt.annotate(e=kt4[kt3[kt2[kt1[kt.a].b].c].d].e)
        self.assertTrue(ktr.aggregate(agg.collect(ktr.e)) == ['quam'])

        ktr = kt.select(e=kt4[kt3[kt2[kt1[kt.a].b].c].d].e)
        self.assertTrue(ktr.aggregate(agg.collect(ktr.e)) == ['quam'])

        self.assertEqual(kt.filter(kt4[kt3[kt2[kt1[kt.a].b].c].d].e == 'quam').count(), 1)

        m = hl.import_vcf(resource('sample.vcf'))
        vkt = m.rows()
        vkt = vkt.select(vkt.qual)
        vkt = vkt.annotate(qual2=m.index_rows(vkt.key).qual)
        self.assertTrue(vkt.filter(vkt.qual != vkt.qual2).count() == 0)

        m2 = m.annotate_rows(qual2=vkt.index(m.row_key).qual)
        self.assertTrue(m2.filter_rows(m2.qual != m2.qual2).count_rows() == 0)

        m3 = m.annotate_rows(qual2=m.index_rows(m.row_key).qual)
        self.assertTrue(m3.filter_rows(m3.qual != m3.qual2).count_rows() == 0)

        kt5 = hl.utils.range_table(1).annotate(key='C1589').key_by('key')
        m4 = m.annotate_cols(foo=m.s[:5])
        m4 = m4.annotate_cols(idx=kt5[m4.foo].idx)
        n_C1589 = m.filter_cols(m.s[:5] == 'C1589').count_cols()
        self.assertTrue(n_C1589 > 1)
        self.assertEqual(m4.filter_cols(hl.is_defined(m4.idx)).count_cols(), n_C1589)

        kt = hl.utils.range_table(1)
        kt = kt.annotate_globals(foo=5)
        self.assertEqual(hl.eval(kt.foo), 5)

        kt2 = hl.utils.range_table(1)

        kt2 = kt2.annotate_globals(kt_foo=kt.index_globals().foo)
        self.assertEqual(hl.eval(kt2.globals.kt_foo), 5)
Example #10
    def test_entry_join_missingness(self):
        mt1 = hl.utils.range_matrix_table(10, 10, n_partitions=4)
        mt1 = mt1.annotate_entries(x=mt1.row_idx + mt1.col_idx)

        mt2 = mt1.filter_cols(mt1.col_idx % 2 == 0)
        mt2 = mt2.filter_rows(mt2.row_idx % 2 == 0)
        mt_join = mt1.annotate_entries(x2=mt2[mt1.row_idx, mt1.col_idx].x * 10)
        mt_join_entries = mt_join.entries()

        kept = mt_join_entries.filter((mt_join_entries.row_idx % 2 == 0) & (mt_join_entries.col_idx % 2 == 0))
        removed = mt_join_entries.filter(~((mt_join_entries.row_idx % 2 == 0) & (mt_join_entries.col_idx % 2 == 0)))

        self.assertTrue(kept.all(hl.is_defined(kept.x2) & (kept.x2 == kept.x * 10)))
        self.assertTrue(removed.all(hl.is_missing(removed.x2)))
Example #11
def main(args):
    hl.init(master=f'local[{args.n_threads}]',
            log=hl.utils.timestamp_path(os.path.join(tempfile.gettempdir(), 'extract_vcf'), suffix='.log'),
            default_reference=args.reference)

    sys.path.append('/')
    add_args = []
    if args.additional_args is not None:
        add_args = args.additional_args.split(',')
    load_module = importlib.import_module(args.load_module)
    mt = getattr(load_module, args.load_mt_function)(*add_args)

    if args.gene_map_ht_path is None:
        interval = [hl.parse_locus_interval(args.interval)]
    else:
        gene_ht = hl.read_table(args.gene_map_ht_path)
        if args.gene is not None:
            gene_ht = gene_ht.filter(gene_ht.gene_symbol == args.gene)
            interval = gene_ht.aggregate(hl.agg.take(gene_ht.interval, 1), _localize=False)
        else:
            interval = [hl.parse_locus_interval(args.interval)]
            gene_ht = hl.filter_intervals(gene_ht, interval)

        gene_ht = gene_ht.filter(hl.set(args.groups.split(',')).contains(gene_ht.annotation))
        gene_ht.select(
            group=gene_ht.gene_id + '_' + gene_ht.gene_symbol + '_' + gene_ht.annotation,
            variant=hl.delimit(gene_ht.variants, '\t')
        ).key_by().drop('start').export(args.group_output_file, header=False)
        # TODO: possible minor optimization: filter output VCF to only variants in `gene_ht.variants`

    if not args.no_adj:
        mt = mt.filter_entries(mt.adj)

    mt = hl.filter_intervals(mt, interval)

    if not args.input_bgen:
        mt = mt.select_entries('GT')
        mt = mt.filter_rows(hl.agg.count_where(mt.GT.is_non_ref()) > 0)
    mt = mt.annotate_rows(rsid=mt.locus.contig + ':' + hl.str(mt.locus.position) + '_' + mt.alleles[0] + '/' + mt.alleles[1])

    if args.callrate_filter:
        mt = mt.filter_rows(hl.agg.fraction(hl.is_defined(mt.GT)) >= args.callrate_filter)

    if args.export_bgen:
        if not args.input_bgen:
            mt = gt_to_gp(mt)
            mt = impute_missing_gp(mt, mean_impute=args.mean_impute_missing)
        hl.export_bgen(mt, args.output_file, gp=mt.GP, varid=mt.rsid)
    else:
        mt = mt.annotate_entries(GT=hl.or_else(mt.GT, hl.call(0, 0)))
        # Note: no mean-imputation for VCF
        hl.export_vcf(mt, args.output_file)
Example #12
    def _ac_an_parent_child_count(
        proband_gt: hl.expr.CallExpression,
        father_gt: hl.expr.CallExpression,
        mother_gt: hl.expr.CallExpression,
    ) -> Dict[str, hl.expr.Int64Expression]:
        """
        Helper method to get AC and AN for parents and children
        """
        ac_parent_expr = hl.agg.sum(
            father_gt.n_alt_alleles() + mother_gt.n_alt_alleles()
        )
        an_parent_expr = hl.agg.sum(
            (hl.is_defined(father_gt) + hl.is_defined(mother_gt)) * 2
        )
        ac_child_expr = hl.agg.sum(proband_gt.n_alt_alleles())
        an_child_expr = hl.agg.sum(hl.is_defined(proband_gt) * 2)

        return {
            "ac_parents": ac_parent_expr,
            "an_parents": an_parent_expr,
            "ac_children": ac_child_expr,
            "an_children": an_child_expr,
        }
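A sketch of how this helper might be applied, assuming a trio MatrixTable built with `hl.trio_matrix` (the pedigree path and the input `mt` are hypothetical):

import hail as hl

ped = hl.Pedigree.read('gs://my-bucket/trios.fam')  # hypothetical pedigree
trio_mt = hl.trio_matrix(mt, pedigree=ped, complete_trios=True)  # mt: input MatrixTable
trio_mt = trio_mt.annotate_rows(**_ac_an_parent_child_count(
    trio_mt.proband_entry.GT, trio_mt.father_entry.GT, trio_mt.mother_entry.GT))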
Example #13
def get_summary_ac_dict(
    ac_expr: hl.expr.Int64Expression,
    lof_expr: hl.expr.StringExpression,
    no_lof_flags_expr: hl.expr.BooleanExpression,
    most_severe_csq_expr: hl.expr.StringExpression,
) -> Dict[str, hl.expr.Int64Expression]:
    """
    Returns dictionary containing containing total allele counts for variant categories.

    Categories are:
        - All variants
        - LoF variants
        - LoF variants that pass LOFTEE
        - LoF variants that pass LOFTEE without any flags
        - LoF variants that are annotate as 'other splice' (OS) by LOFTEE
        - LoF variants that fail LOFTEE
        - Missense variants
        - Synonymous variants

    .. warning:: 
        Assumes `allele_expr` contains only two variants (multi-allelics have been split).

    :param allele_expr: ArrayExpression containing alleles.
    :param lof_expr: StringExpression containing LOFTEE annotation.
    :param no_lof_flags_expr: BooleanExpression indicating whether LoF variant has any flags.
    :return: Dict of variant categories and their total allele counts.
    """
    logger.warning(
        "This function expects that multi-allelic variants have been split!")
    return {
        "total_ac": hl.agg.sum(ac_expr),
        "total_ac_LOF": hl.agg.filter(hl.is_defined(lof_expr), hl.agg.sum(ac_expr)),
        "total_ac_pass_loftee": hl.agg.filter(lof_expr == "HC", hl.agg.sum(ac_expr)),
        "total_ac_pass_loftee_no_flag": hl.agg.filter(
            (lof_expr == "HC") & no_lof_flags_expr, hl.agg.sum(ac_expr)),
        "total_ac_loftee_os": hl.agg.filter(lof_expr == "OS", hl.agg.sum(ac_expr)),
        "total_ac_fail_loftee": hl.agg.filter(lof_expr == "LC", hl.agg.sum(ac_expr)),
        "total_ac_missense": hl.agg.filter(
            most_severe_csq_expr == "missense_variant", hl.agg.sum(ac_expr)),
        "total_ac_synonymous": hl.agg.filter(
            most_severe_csq_expr == "synonymous_variant", hl.agg.sum(ac_expr)),
    }
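A hedged usage sketch; the field names on `ht` are hypothetical stand-ins for whatever annotations a real sites Table carries:

import hail as hl

summary = ht.aggregate(hl.struct(**get_summary_ac_dict(
    ht.ac, ht.lof, ht.no_lof_flags, ht.most_severe_csq)))
print(summary.total_ac, summary.total_ac_LOF)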
Example #14
 def _repartition():
     ## part 2: 5 min with 100 preemptibles
     mt = hl.read_matrix_table(tmp_mt_path)
     mt = mt.repartition(1000)
     withdrawn = hl.read_table(
         'gs://ukb31063/ukb31063.withdrawn_samples.ht')
     mt = mt.anti_join_cols(withdrawn)
     # print(mt.count()) # (1089172, 487409)
     covs = hl.read_table(
         'gs://ukb31063/ukb31063.neale_gwas_covariates.both_sexes.ht')
     mt = mt.annotate_cols(**covs[mt.s])
     mt = mt.filter_cols(hl.is_defined(mt.PC1))
     # print(mt.count())  # (1089172, 361144)
     return mt.checkpoint(mt_path, overwrite=overwrite)
Example #15
def run_gwas(mt,
             phen: str,
             sim_name: str,
             subset_idx: int,
             param_suffix: str,
             wd: str,
             is_logreg=True):
    assert ('GT' in mt.entry) or ('dosage' in mt.entry), \
        "mt does not have an entry field named 'dosage' or 'GT' corresponding to genotype data"

    mt = mt.filter_cols(mt.subset_idx == subset_idx)
    mt = mt.filter_cols(hl.is_defined(mt[phen]))
    print(
        f'\n\ngwas sample count (subset {subset_idx}): {mt.count_cols()}\n\n')

    if 'dosage' in mt.entry:
        mt = mt.annotate_rows(EAF=hl.agg.mean(mt.dosage) / 2)
    elif 'GT' in mt.entry:
        mt = mt.annotate_rows(EAF=hl.agg.mean(mt.GT.n_alt_alleles()) / 2)

    gwas_path = f'{wd}/gwas.{"logreg" if is_logreg else "linreg"}.{sim_name}.subset_{subset_idx}.{param_suffix}.tsv.gz'

    if not hl.hadoop_is_file(gwas_path):
        gt_field = mt.dosage if 'dosage' in mt.entry else mt.GT.n_alt_alleles()

        if is_logreg:
            gwas_ht = hl.logistic_regression_rows(test='wald',
                                                  y=mt[phen],
                                                  x=gt_field,
                                                  covariates=[1],
                                                  pass_through=['EAF'])
        else:
            gwas_ht = hl.linear_regression_rows(y=mt[phen],
                                                x=gt_field,
                                                covariates=[1],
                                                pass_through=['EAF'])
        gwas_ht.select('EAF', 'beta', 'standard_error',
                       'p_value').export(gwas_path)

    else:
        print(f'GWAS already run! ({gwas_path})')
        gwas_ht = hl.import_table(gwas_path, impute=True, force=True)
        gwas_ht = gwas_ht.annotate(
            locus=hl.parse_locus(gwas_ht.locus),
            alleles=gwas_ht.alleles.replace(r'\["', '').replace(r'"\]', '').split('","'))
        gwas_ht = gwas_ht.key_by('locus', 'alleles')

    return gwas_ht
Example #16
def run_pca_with_relateds(
    qc_mt: hl.MatrixTable,
    related_samples_to_drop: Optional[hl.Table],
    n_pcs: int = 10,
    autosomes_only: bool = True,
) -> Tuple[List[float], hl.Table, hl.Table]:
    """
    First runs PCA excluding the given related samples,
    then projects these samples in the PC space to return scores for all samples.

    The `related_samples_to_drop` Table has to be keyed by the sample ID and all samples present in this
    table will be excluded from the PCA.

    The loadings Table returned also contains a `pca_af` annotation which is the allele frequency
    used for PCA. This is useful to project other samples in the PC space.

    :param qc_mt: Input QC MT
    :param related_samples_to_drop: Optional table of related samples to drop
    :param n_pcs: Number of PCs to compute
    :param autosomes_only: Whether to run the analysis on autosomes only
    :return: eigenvalues, scores and loadings
    """

    unrelated_mt = qc_mt.persist()

    if autosomes_only:
        unrelated_mt = filter_to_autosomes(unrelated_mt)

    if related_samples_to_drop:
        unrelated_mt = unrelated_mt.filter_cols(
            hl.is_missing(related_samples_to_drop[unrelated_mt.col_key]))

    pca_evals, pca_scores, pca_loadings = hl.hwe_normalized_pca(
        unrelated_mt.GT, k=n_pcs, compute_loadings=True)
    pca_af_ht = unrelated_mt.annotate_rows(
        pca_af=hl.agg.mean(unrelated_mt.GT.n_alt_alleles()) / 2).rows()
    pca_loadings = pca_loadings.annotate(
        pca_af=pca_af_ht[pca_loadings.key].pca_af
    )  # TODO: Evaluate if needed to write results at this point if relateds or not

    if not related_samples_to_drop:
        return pca_evals, pca_scores, pca_loadings
    else:
        pca_loadings = pca_loadings.persist()
        pca_scores = pca_scores.persist()
        related_mt = qc_mt.filter_cols(
            hl.is_defined(related_samples_to_drop[qc_mt.col_key]))
        related_scores = pc_project(related_mt, pca_loadings)
        pca_scores = pca_scores.union(related_scores)
        return pca_evals, pca_scores, pca_loadings
Example #17
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("gencode")
    parser.add_argument("canonical_transcripts")
    parser.add_argument("hgnc")
    parser.add_argument("--min-partitions", type=int, default=8)
    parser.add_argument("--output", required=True)
    args = parser.parse_args()

    # Load genes from GTF file
    genes = load_gencode_gene_models(args.gencode, min_partitions=args.min_partitions)
    genes = genes.transmute(gencode_gene_symbol=genes.gene_symbol)

    # Annotate genes with canonical transcript
    canonical_transcripts = load_canonical_transcripts(args.canonical_transcripts, min_partitions=args.min_partitions)
    genes = genes.annotate(canonical_transcript_id=canonical_transcripts[genes.gene_id].transcript_id)

    # Drop transcripts except for canonical
    genes = genes.annotate(
        canonical_transcript=genes.transcripts.filter(
            lambda transcript: transcript.transcript_id == genes.canonical_transcript_id
        ).head()
    )
    genes = genes.drop("transcripts")

    # Annotate genes with information from HGNC
    hgnc = load_hgnc(args.hgnc)
    genes = genes.annotate(**hgnc[genes.gene_id])
    genes = genes.annotate(symbol_source=hl.cond(hl.is_defined(genes.symbol), "hgnc", hl.null(hl.tstr)))
    genes = genes.annotate(
        symbol=hl.or_else(genes.symbol, genes.gencode_gene_symbol),
        symbol_source=hl.or_else(genes.symbol_source, "gencode"),
    )

    # Collect all fields that can be used to search by gene symbol
    genes = genes.annotate(
        symbol_upper_case=genes.symbol.upper(),
        search_terms=hl.set(
            hl.empty_array(hl.tstr)
            .append(genes.symbol)
            .extend(genes.previous_symbols)
            .extend(genes.alias_symbols)
            .append(genes.gencode_gene_symbol)
            .map(lambda s: s.upper())
        ),
    )

    genes.describe()

    genes.write(args.output, overwrite=True)
Example #18
def prepare_clinvar_variants(clinvar_path, reference_genome):
    ds = hl.read_table(clinvar_path)

    ds = ds.filter(hl.is_defined(ds[f"locus_{reference_genome}"]) & hl.is_defined(ds[f"alleles_{reference_genome}"]))

    ds = ds.select(locus=ds[f"locus_{reference_genome}"], alleles=ds[f"alleles_{reference_genome}"], **ds.variant)

    # Remove any variants with alleles other than ACGT
    ds = ds.filter(
        hl.len(hl.set(hl.delimit(ds.alleles, "").split("")).difference(hl.set(["A", "C", "G", "T", ""]))) == 0
    )

    ds = ds.annotate(
        variant_id=variant_id(ds.locus, ds.alleles),
        chrom=normalized_contig(ds.locus.contig),
        pos=ds.locus.position,
        ref=ds.alleles[0],
        alt=ds.alleles[1],
    )

    ds = ds.key_by("locus", "alleles")

    return ds
Example #19
def pc_project(mt: hl.MatrixTable,
               pc_loadings: hl.Table,
               loading_location: str = "loadings",
               af_location: str = "pca_af") -> hl.MatrixTable:
    """
    Projects samples in `mt` on PCs computed in `pc_mt`
    :param MatrixTable mt: MT containing the samples to project
    :param Table pc_loadings: MT containing the PC loadings for the variants
    :param str loading_location: Location of expression for loadings in `pc_loadings`
    :param str af_location: Location of expression for allele frequency in `pc_loadings`
    :return: MT with scores calculated from loadings
    """
    n_variants = mt.count_rows()

    mt = mt.annotate_rows(**pc_loadings[mt.locus, mt.alleles])
    mt = mt.filter_rows(
        hl.is_defined(mt[loading_location]) & hl.is_defined(mt[af_location])
        & (mt[af_location] > 0) & (mt[af_location] < 1))

    gt_norm = (mt.GT.n_alt_alleles() - 2 * mt[af_location]) / hl.sqrt(
        n_variants * 2 * mt[af_location] * (1 - mt[af_location]))
    return mt.annotate_cols(pca_scores=hl.agg.array_sum(mt[loading_location] *
                                                        gt_norm))
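A hedged usage sketch, mirroring how Example #16 pairs loadings with a `pca_af` annotation before projecting (`ref_mt` and `new_mt` are hypothetical MatrixTables):

import hail as hl

_, _, loadings_ht = hl.hwe_normalized_pca(ref_mt.GT, k=10, compute_loadings=True)
af_ht = ref_mt.annotate_rows(pca_af=hl.agg.mean(ref_mt.GT.n_alt_alleles()) / 2).rows()
loadings_ht = loadings_ht.annotate(pca_af=af_ht[loadings_ht.key].pca_af)

projected = pc_project(new_mt, loadings_ht)  # new_mt: samples to project
scores = projected.cols().select('pca_scores')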
Example #21
def kyle_sex_specific_qc(mt_path):
    mt = hl.read_matrix_table(mt_path)
    mt = mt.annotate_cols(sex=hl.cond(hl.rand_bool(0.5), 'Male', 'Female'))
    (num_males, num_females) = mt.aggregate_cols(
        (hl.agg.count_where(mt.sex == 'Male'),
         hl.agg.count_where(mt.sex == 'Female')))
    mt = mt.annotate_rows(
        male_hets=hl.agg.count_where(mt.GT.is_het() & (mt.sex == 'Male')),
        male_homvars=hl.agg.count_where(mt.GT.is_hom_var()
                                        & (mt.sex == 'Male')),
        male_calls=hl.agg.count_where(
            hl.is_defined(mt.GT) & (mt.sex == 'Male')),
        female_hets=hl.agg.count_where(mt.GT.is_het() & (mt.sex == 'Female')),
        female_homvars=hl.agg.count_where(mt.GT.is_hom_var()
                                          & (mt.sex == 'Female')),
        female_calls=hl.agg.count_where(
            hl.is_defined(mt.GT) & (mt.sex == 'Female')))

    mt = mt.annotate_rows(
        call_rate=(hl.case()
                   .when(mt.locus.in_y_nonpar(), mt.male_calls / num_males)
                   .when(mt.locus.in_x_nonpar(),
                         (mt.male_calls + 2 * mt.female_calls) / (num_males + 2 * num_females))
                   .default((mt.male_calls + mt.female_calls) / (num_males + num_females))),
        AC=(hl.case()
            .when(mt.locus.in_y_nonpar(), mt.male_homvars)
            .when(mt.locus.in_x_nonpar(),
                  mt.male_homvars + mt.female_hets + 2 * mt.female_homvars)
            .default(mt.male_hets + 2 * mt.male_homvars +
                     mt.female_hets + 2 * mt.female_homvars)),
        AN=(hl.case()
            .when(mt.locus.in_y_nonpar(), mt.male_calls)
            .when(mt.locus.in_x_nonpar(), mt.male_calls + 2 * mt.female_calls)
            .default(2 * mt.male_calls + 2 * mt.female_calls)))

    mt.rows()._force_count()
Example #22
def filter_low_conf_regions(
        mt: hl.MatrixTable,
        filter_lcr: bool = True,
        filter_decoy: bool = True,
        filter_segdup: bool = True,
        high_conf_regions: Optional[List[str]] = None) -> hl.MatrixTable:
    """
    Filters low-confidence regions

    :param MatrixTable mt: MT to filter
    :param bool filter_lcr: Whether to filter LCR regions
    :param bool filter_decoy: Whether to filter decoy regions
    :param bool filter_segdup: Whether to filter Segdup regions
    :param list of str high_conf_regions: Paths to set of high confidence regions to restrict to (union of regions)
    :return: MT with low confidence regions removed
    :rtype: MatrixTable
    """
    from gnomad_hail.resources import lcr_intervals_path, decoy_intervals_path, segdup_intervals_path

    if filter_lcr:
        lcr = hl.import_locus_intervals(lcr_intervals_path)
        mt = mt.filter_rows(hl.is_defined(lcr[mt.locus]), keep=False)

    if filter_decoy:
        decoy = hl.import_bed(decoy_intervals_path)
        mt = mt.filter_rows(hl.is_defined(decoy[mt.locus]), keep=False)

    if filter_segdup:
        segdup = hl.import_bed(segdup_intervals_path)
        mt = mt.filter_rows(hl.is_defined(segdup[mt.locus]), keep=False)

    if high_conf_regions is not None:
        for region in high_conf_regions:
            region = hl.import_locus_intervals(region)
            mt = mt.filter_rows(hl.is_defined(region[mt.locus]), keep=True)

    return mt
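A minimal usage sketch; the high-confidence interval path is hypothetical, and the LCR/decoy/segdup resources come from the `gnomad_hail.resources` import above:

mt = filter_low_conf_regions(
    mt,
    filter_lcr=True,
    filter_decoy=True,
    filter_segdup=True,
    high_conf_regions=['gs://my-bucket/high_confidence.interval_list'])  # hypothetical path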
Example #23
def all(condition) -> BooleanExpression:
    """Returns ``True`` if `condition` is ``True`` for every record.

    Examples
    --------

    >>> (table1.group_by(table1.SEX)
    ... .aggregate(all_under_70 = agg.all(table1.HT < 70))
    ... .show())
    +-----+--------------+
    | SEX | all_under_70 |
    +-----+--------------+
    | str | bool         |
    +-----+--------------+
    | M   | false        |
    | F   | false        |
    +-----+--------------+

    Notes
    -----
    If there are no records to aggregate, the result is ``True``.

    Missing records are not considered. If every record is missing,
    the result is also ``True``.

    Parameters
    ----------
    condition : :class:`.BooleanExpression`
        Condition to test.

    Returns
    -------
    :class:`.BooleanExpression`
    """
    n_defined = count(filter(lambda x: hl.is_defined(x), condition))
    n_true = count(filter(lambda x: hl.is_defined(x) & x, condition))
    return n_defined == n_true
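A quick check of the all-missing edge case from the notes: with zero defined records both counts are zero, so the equality holds (a sketch, assuming `import hail as hl`):

import hail as hl

ht = hl.utils.range_table(3)
ht = ht.annotate(flag=hl.null(hl.tbool))  # every record missing
assert ht.aggregate(hl.agg.all(ht.flag)) is True  # 0 defined == 0 true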
Example #24
def get_filtered_mt(chrom: str = 'all',
                    pop: str = 'all',
                    imputed: bool = True,
                    min_mac: int = 20,
                    entry_fields=('GP', ),
                    filter_mac_instead_of_ac: bool = False):

    # get ac or mac based on filter_mac_instead_of_ac
    def get_ac(af, an):
        if filter_mac_instead_of_ac:
            # Note that the underlying file behind get_ukb_af_ht_path() accidentally double af and halve an
            return (1.0 - hl.abs(1.0 - af)) * an
        else:
            return af * an

    if imputed:
        ht = hl.read_table(get_ukb_af_ht_path())
        if pop == 'all':
            ht = ht.filter(
                hl.any(lambda x: get_ac(ht.af[x], ht.an[x]) >= min_mac,
                       hl.literal(POPS)))
        else:
            ht = ht.filter(get_ac(ht.af[pop], ht.an[pop]) >= min_mac)
        mt = get_ukb_imputed_data(chrom,
                                  variant_list=ht,
                                  entry_fields=entry_fields)
    else:
        mt = hl.read_matrix_table('gs://ukb31063/ukb31063.genotype.mt')

    covariates_ht = get_covariates()
    hq_samples_ht = get_hq_samples()
    mt = mt.annotate_cols(**covariates_ht[mt.s])
    mt = mt.filter_cols(
        hl.is_defined(mt.pop) & hl.is_defined(hq_samples_ht[mt.s]))

    if pop != 'all':
        mt = mt.filter_cols(mt.pop == pop)
    return mt
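A pure-Python sanity check of the doubled-af/halved-an workaround noted in `get_ac` (a sketch, not part of the pipeline): with stored af = 2 × true af and stored an = true an / 2, `(1.0 - abs(1.0 - af)) * an` recovers the minor allele count on both sides of af = 0.5:

def stored_mac(stored_af, stored_an):
    return (1.0 - abs(1.0 - stored_af)) * stored_an

true_an = 1000
assert abs(stored_mac(2 * 0.3, true_an / 2) - 0.3 * true_an) < 1e-9  # af <= 0.5: AC
assert abs(stored_mac(2 * 0.8, true_an / 2) - 0.2 * true_an) < 1e-9  # af > 0.5: minor AC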
Example #25
def get_cnt_matrix(mnv_table,
                   region="ALL",
                   dist=1,
                   minimum_cnt=0,
                   PASS=True,
                   part_size=1000,
                   hom=False):
    # mnv_table = Hail table of MNVs
    # region = BED file defining the regions of interest (e.g. enhancer regions)
    # dist = distance between the two SNPs
    # PASS=True: restrict to pairs where both variants pass filters
    # indels are no longer considered
    # filter by region if a BED file path is given as `region`
    if region != "ALL":
        bed = hl.import_bed(region)
        mnv_table = mnv_table.filter(hl.is_defined(bed[mnv_table.locus]))
    if PASS == "NO":  #exclusively getting at least one non-pass ones
        mnv_table = mnv_table.filter((mnv_table.filters.length() > 0)
                                     | (mnv_table.prev_filters.length() > 0))
    elif PASS == True:
        mnv_table = mnv_table.filter((mnv_table.filters.length() == 0)
                                     & (mnv_table.prev_filters.length() == 0))
    if hom:
        mnv_table = mnv_table.filter(mnv_table.n_homhom > 0)
    # count MNV occurrence -- restricting to SNPs
    mnv = mnv_table.filter(
        (mnv_table.alleles[0].length() == 1) &
        (mnv_table.alleles[1].length() == 1) &
        (mnv_table.prev_alleles[0].length() == 1) &
        (mnv_table.prev_alleles[1].length() == 1) &
        ((mnv_table.locus.position - mnv_table.prev_locus.position)
         == dist))  # filter to that specific distance

    # repartition to a proper size
    mnv = mnv.repartition(part_size)

    mnv_cnt = mnv.group_by("alleles", "prev_alleles").aggregate(
        cnt=agg.count())  # count occurrences
    mnv_cnt = mnv_cnt.annotate(refs=mnv_cnt.prev_alleles[0] + "N" *
                               (dist - 1) +
                               mnv_cnt.alleles[0])  # annotate combined refs
    mnv_cnt = mnv_cnt.annotate(alts=mnv_cnt.prev_alleles[1] + "N" *
                               (dist - 1) +
                               mnv_cnt.alleles[1])  # annotate combined alts

    if minimum_cnt > 0:
        mnv_cnt = mnv_cnt.filter(
            (mnv_cnt.cnt > minimum_cnt))  # remove trivial ones
    return (mnv_cnt.select("refs", "alts", "cnt"))
Example #26
def generate_fam_stats(
        mt: hl.MatrixTable,
        fam_file: str
) -> hl.Table:
    """
    Calculate transmission and de novo mutation statistics using trios in the dataset.

    :param mt: Input MatrixTable
    :param fam_file: path to text file containing trio pedigree
    :return: Table containing trio stats
    """
    # Load Pedigree data and filter MT to samples present in any of the trios
    ped = hl.Pedigree.read(fam_file, delimiter="\t")
    fam_ht = hl.import_fam(fam_file, delimiter="\t")
    fam_ht = fam_ht.annotate(
        fam_members=[fam_ht.id, fam_ht.pat_id, fam_ht.mat_id]
    )
    fam_ht = fam_ht.explode('fam_members', name='s')
    fam_ht = fam_ht.key_by('s').select().distinct()

    mt = mt.filter_cols(hl.is_defined(fam_ht[mt.col_key]))
    logger.info(f"Generating family stats using {mt.count_cols()} samples from {len(ped.trios)} trios.")

    mt = filter_to_autosomes(mt)
    mt = annotate_adj(mt)
    mt = mt.select_entries('GT', 'GQ', 'AD', 'END', 'adj')
    mt = hl.experimental.densify(mt)
    mt = mt.filter_rows(hl.len(mt.alleles) == 2)
    mt = hl.trio_matrix(mt, pedigree=ped, complete_trios=True)
    trio_adj = (mt.proband_entry.adj & mt.father_entry.adj & mt.mother_entry.adj)

    ht = mt.select_rows(
        **generate_trio_stats_expr(
            mt,
            transmitted_strata={
                'raw': True,
                'adj': trio_adj
            },
            de_novo_strata={
                'raw': True,
                'adj': trio_adj,
            },
            proband_is_female_expr=mt.is_female
        )
    ).rows()

    return ht.filter(
        ht.n_de_novos_raw + ht.n_transmitted_raw + ht.n_untransmitted_raw > 0
    )
Example #27
def get_mt(remove_withdrawn=True):
    ## matrix table of 1008898 HM3 variants and white British subset of UKB
    mt = hl.read_matrix_table(
        'gs://nbaya/split/ukb31063.hm3_variants.gwas_samples_repart.mt')

    if remove_withdrawn:
        withdrawn = hl.import_table('gs://nbaya/w31063_20200204.csv',
                                    missing='',
                                    no_header=True)
        withdrawn = withdrawn.rename({'f0': 's'})  # rename field with sample IDs
        withdrawn = withdrawn.key_by('s')
        mt = mt.filter_cols(hl.is_defined(withdrawn[mt.s]), keep=False)

    return mt
Example #28
def pc_project(mt: hl.MatrixTable,
               loadings_ht: hl.Table,
               loading_location: str = "loadings",
               af_location: str = "pca_af"):
    """
    Projects samples in `mt` on pre-computed PCs.
    :param MatrixTable mt: MT containing the samples to project
    :param Table loadings_ht: HT containing the PCA loadings and allele frequencies used for the PCA
    :param str loading_location: Location of expression for loadings in `loadings_ht`
    :param str af_location: Location of expression for allele frequency in `loadings_ht`
    :return: Table with scores calculated from loadings in column `scores`
    :rtype: Table
    """
    n_variants = loadings_ht.count()
    mt = mt.annotate_rows(
        pca_loadings=loadings_ht[mt.row_key][loading_location],
        pca_af=loadings_ht[mt.row_key][af_location])
    mt = mt.filter_rows(
        hl.is_defined(mt.pca_loadings) & hl.is_defined(mt.pca_af)
        & (mt.pca_af > 0) & (mt.pca_af < 1))
    gt_norm = (mt.GT.n_alt_alleles() - 2 * mt.pca_af) / hl.sqrt(
        n_variants * 2 * mt.pca_af * (1 - mt.pca_af))
    mt = mt.annotate_cols(scores=hl.agg.array_sum(mt.pca_loadings * gt_norm))
    return mt.cols().select('scores')
Example #29
def pc_project(mt,
               loadings_ht,
               loading_location="loadings",
               af_location="pca_af"):
    """
    Projects samples in `mt` on pre-computed PCs.
    :param MatrixTable mt: MT containing the samples to project into previously calculated PCs
    :param Table loadings_ht: HT containing the PCA loadings and allele frequencies used for the PCA
    :param str loading_location: Location of expression for loadings in `loadings_ht`
    :param str af_location: Location of expression for allele frequency in `loadings_ht`
    :return: Hail Table with scores calculated from loadings in column `scores`
    :rtype: Table

    From Konrad Karczewski
    """
    n_variants = loadings_ht.count()

    # Annotate matrix table with pca loadings and af from other dataset which pcs were calculated from
    mt = mt.annotate_rows(
        pca_loadings=loadings_ht[mt.row_key][loading_location],
        pca_af=loadings_ht[mt.row_key][af_location])

    # Filter to rows where pca_loadings and af are defined, and af > 0 and < 1
    mt = mt.filter_rows(
        hl.is_defined(mt.pca_loadings) & hl.is_defined(mt.pca_af)
        & (mt.pca_af > 0) & (mt.pca_af < 1))

    # Calculate genotype normalization constant
    # Basically, mean centers and normalizes the genotypes under the binomial distribution so that they can be
    # multiplied by the PC loadings to get the projected principal components
    gt_norm = (mt.GT.n_alt_alleles() - 2 * mt.pca_af) / hl.sqrt(
        n_variants * 2 * mt.pca_af * (1 - mt.pca_af))

    mt = mt.annotate_cols(scores=hl.agg.array_sum(mt.pca_loadings * gt_norm))

    return mt.cols().select('scores')
Example #30
def compute_grouped_binned_ht(
    bin_ht: hl.Table,
    checkpoint_path: Optional[str] = None,
) -> hl.GroupedTable:
    """
    Group a Table that has been annotated with bins (`compute_ranked_bin` or `create_binned_ht`).

    The table will be grouped by bin_id (bin, biallelic, etc.), contig, snv, bi_allelic and singleton.

    .. note::

        If performing an aggregation following this grouping (such as `score_bin_agg`) then the aggregation
        function will need to use `ht._parent` to get the origin Table from the GroupedTable for the aggregation

    :param bin_ht: Input Table with a `bin_id` annotation
    :param checkpoint_path: If provided, an intermediate checkpoint table is created with all required annotations before shuffling.
    :return: Table grouped by bin(s)
    """
    # Explode the rank table by bin_id
    bin_ht = bin_ht.annotate(
        bin_groups=hl.array(
            [
                hl.Struct(bin_id=bin_name, bin=bin_ht[bin_name])
                for bin_name in bin_ht.bin_group_variant_counts
            ]
        )
    )
    bin_ht = bin_ht.explode(bin_ht.bin_groups)
    bin_ht = bin_ht.transmute(
        bin_id=bin_ht.bin_groups.bin_id, bin=bin_ht.bin_groups.bin
    )
    bin_ht = bin_ht.filter(hl.is_defined(bin_ht.bin))

    if checkpoint_path is not None:
        bin_ht = bin_ht.checkpoint(checkpoint_path, overwrite=True)
    else:
        bin_ht = bin_ht.persist()

    # Group by bin_id, bin and additional stratification desired and compute QC metrics per bin
    return bin_ht.group_by(
        bin_id=bin_ht.bin_id,
        contig=bin_ht.locus.contig,
        snv=hl.is_snp(bin_ht.alleles[0], bin_ht.alleles[1]),
        bi_allelic=~bin_ht.was_split,
        singleton=bin_ht.singleton,
        release_adj=bin_ht.ac > 0,
        bin=bin_ht.bin,
    )._set_buffer_size(20000)
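A hedged usage sketch: a simple per-bin variant count over the grouped output (checkpoint path hypothetical; a real pipeline would aggregate QC metrics here, e.g. via `score_bin_agg`):

import hail as hl

grouped_ht = compute_grouped_binned_ht(
    bin_ht, checkpoint_path='gs://my-bucket/tmp/grouped_bin.ht')  # hypothetical path
per_bin = grouped_ht.aggregate(n_variants=hl.agg.count())
per_bin.show()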
Example #31
def get_all_sample_metadata(
    mt: hl.MatrixTable, build: int, data_type: str, data_source: str, version: int
) -> hl.Table:
    """
    Annotate MatrixTable with all current metadata: sample sequencing metrics, sample ID mapping,
    and callrate for bi-allelic, high-callrate common SNPs.
    :param MatrixTable mt: VCF converted to a MatrixTable
    :param int build: build for write, 37 or 38
    :param str data_type: WGS or WES for write path and flagging metrics
    :param str data_source: internal or external for write path
    :param int version: Int for write path
    :return: Table with seq metrics and mapping
    :rtype: Table
    """
    logger.info("Importing and annotating with sequencing metrics...")
    meta_ht = hl.import_table(
        seq_metrics_path(build, data_type, data_source, version), impute=True
    ).key_by("SAMPLE")

    logger.info("Importing and annotating seqr ID names...")
    remap_ht = hl.import_table(
        remap_path(build, data_type, data_source, version), impute=True
    ).key_by("s")
    meta_ht = meta_ht.annotate(**remap_ht[meta_ht.key])
    meta_ht = meta_ht.annotate(
        seqr_id=hl.if_else(
            hl.is_missing(meta_ht.seqr_id), meta_ht.SAMPLE, meta_ht.seqr_id
        )
    )

    logger.info(
        "Filtering to bi-allelic, high-callrate, common SNPs to calculate callrate..."
    )
    mt = filter_rows_for_qc(
        mt,
        min_af=0.001,
        min_callrate=0.99,
        bi_allelic_only=True,
        snv_only=True,
        apply_hard_filters=False,
        min_inbreeding_coeff_threshold=None,
        min_hardy_weinberg_threshold=None,
    )
    callrate_ht = mt.select_cols(
        filtered_callrate=hl.agg.fraction(hl.is_defined(mt.GT))
    ).cols()
    meta_ht = meta_ht.annotate(**callrate_ht[meta_ht.key])
    return meta_ht
Example #32
def test_model(
    ht: hl.Table,
    rf_model: pyspark.ml.PipelineModel,
    features: List[str],
    label: str,
    prediction_col_name: str = "rf_prediction",
) -> List[hl.Struct]:
    """
    A wrapper to test a model on a set of examples with known labels.

    1) Runs the model on the data
    2) Prints confusion matrix and accuracy
    3) Returns confusion matrix as a list of struct

    :param ht: Input table
    :param rf_model: RF Model
    :param features: Columns containing features that were used in the model
    :param label: Column containing label to be predicted
    :param prediction_col_name: Where to store the prediction
    :return: A list containing structs with {label, prediction, n}
    """

    ht = apply_rf_model(
        ht.filter(hl.is_defined(ht[label])),
        rf_model,
        features,
        label,
        prediction_col_name=prediction_col_name,
    )

    test_results = (
        ht.group_by(ht[prediction_col_name], ht[label])
        .aggregate(n=hl.agg.count())
        .collect()
    )

    # Print results
    df = pd.DataFrame(test_results)
    df = df.pivot(index=label, columns=prediction_col_name, values="n")
    logger.info("Testing results:\n{}".format(pprint.pformat(df)))
    logger.info(
        "Accuracy: {}".format(
            sum([x.n for x in test_results if x[label] == x[prediction_col_name]])
            / sum([x.n for x in test_results])
        )
    )

    return test_results
Example #33
def test_pca_against_numpy():
    mt = hl.import_vcf(resource('tiny_m.vcf'))
    mt = mt.filter_rows(hl.len(mt.alleles) == 2)
    mt = mt.annotate_rows(AC=hl.agg.sum(mt.GT.n_alt_alleles()),
                          n_called=hl.agg.count_where(hl.is_defined(mt.GT)))
    mt = mt.filter_rows((mt.AC > 0) & (mt.AC < 2 * mt.n_called)).persist()
    n_rows = mt.count_rows()

    def make_expr(mean):
        return hl.if_else(hl.is_defined(mt.GT),
                          (mt.GT.n_alt_alleles() - mean) /
                          hl.sqrt(mean * (2 - mean) * n_rows / 2), 0)

    eigen, scores, loadings = hl.pca(hl.bind(make_expr, mt.AC / mt.n_called),
                                     k=3,
                                     compute_loadings=True)
    hail_scores = scores.explode('scores').scores.collect()
    hail_loadings = loadings.explode('loadings').loadings.collect()

    assert len(eigen) == 3
    assert scores.count() == mt.count_cols()
    assert loadings.count() == n_rows

    assert len(scores.globals) == 0
    assert len(loadings.globals) == 0

    # compute PCA with numpy
    def normalize(a):
        ms = np.mean(a, axis=0, keepdims=True)
        return np.divide(
            np.subtract(a, ms),
            np.sqrt(2.0 * np.multiply(ms / 2.0, 1 - ms / 2.0) * a.shape[1]))

    g = np.pad(np.diag([1.0, 1, 2]), ((0, 1), (0, 0)), mode='constant')
    g[1, 0] = 1.0 / 3
    n = normalize(g)
    U, s, V = np.linalg.svd(n, full_matrices=0)
    np_scores = U.dot(np.diag(s)).flatten()
    np_loadings = V.transpose().flatten()
    np_eigenvalues = np.multiply(s, s).flatten()

    np.testing.assert_allclose(eigen, np_eigenvalues, rtol=1e-5)
    np.testing.assert_allclose(np.abs(hail_scores),
                               np.abs(np_scores),
                               rtol=1e-5)
    np.testing.assert_allclose(np.abs(hail_loadings),
                               np.abs(np_loadings),
                               rtol=1e-5)
Example #34
def query(output):  # pylint: disable=too-many-locals
    """Query script entry point."""

    hl.init(default_reference='GRCh38')

    hgdp_1kg = hl.read_matrix_table(GNOMAD_HGDP_1KG_MT)
    tob_wgs = hl.read_matrix_table(TOB_WGS).key_rows_by('locus', 'alleles')

    # filter to loci that are contained in both matrix tables after densifying
    tob_wgs = hl.experimental.densify(tob_wgs)

    # Entries and columns must be identical
    tob_wgs_select = tob_wgs.select_entries(
        GT=lgt_to_gt(tob_wgs.LGT, tob_wgs.LA))
    hgdp_1kg_select = hgdp_1kg.select_entries(hgdp_1kg.GT)
    hgdp_1kg_select = hgdp_1kg_select.select_cols()
    # Join datasets
    hgdp1kg_tobwgs_joined = hgdp_1kg_select.union_cols(tob_wgs_select)
    # Add in metadata information
    hgdp_1kg_metadata = hgdp_1kg.cols()
    hgdp1kg_tobwgs_joined = hgdp1kg_tobwgs_joined.annotate_cols(
        hgdp_1kg_metadata=hgdp_1kg_metadata[hgdp1kg_tobwgs_joined.s])

    # choose variants based off of gnomAD v3 parameters
    hgdp1kg_tobwgs_joined = hl.variant_qc(hgdp1kg_tobwgs_joined)
    hgdp1kg_tobwgs_joined = hgdp1kg_tobwgs_joined.annotate_rows(
        IB=hl.agg.inbreeding(hgdp1kg_tobwgs_joined.GT,
                             hgdp1kg_tobwgs_joined.variant_qc.AF[1]))
    hgdp1kg_tobwgs_joined = hgdp1kg_tobwgs_joined.filter_rows(
        (hl.len(hgdp1kg_tobwgs_joined.alleles) == 2)
        & (hgdp1kg_tobwgs_joined.locus.in_autosome())
        & (hgdp1kg_tobwgs_joined.variant_qc.AF[1] > 0.01)
        & (hgdp1kg_tobwgs_joined.variant_qc.call_rate > 0.99)
        & (hgdp1kg_tobwgs_joined.IB.f_stat > -0.25))

    hgdp1kg_tobwgs_joined = hgdp1kg_tobwgs_joined.cache()
    nrows = hgdp1kg_tobwgs_joined.count_rows()
    print(f'hgdp1kg_tobwgs_joined.count_rows() = {nrows}')
    hgdp1kg_tobwgs_joined = hgdp1kg_tobwgs_joined.sample_rows(
        NUM_ROWS_BEFORE_LD_PRUNE / nrows, seed=12345)

    pruned_variant_table = hl.ld_prune(hgdp1kg_tobwgs_joined.GT,
                                       r2=0.1,
                                       bp_window_size=500000)
    hgdp1kg_tobwgs_joined = hgdp1kg_tobwgs_joined.filter_rows(
        hl.is_defined(pruned_variant_table[hgdp1kg_tobwgs_joined.row_key]))
    mt_path = f'{output}/tob_wgs_hgdp_1kg_filtered_variants.mt'
    hgdp1kg_tobwgs_joined.write(mt_path)
Example #35
def get_doubleton_sites(
    vds_path: str = VDS_PATH,
    temp_path: str = TEMP_PATH,
    tranche_data: Tuple[str, int] = TRANCHE_DATA,
    sparse_entries: List[str] = SPARSE_ENTRIES,
) -> hl.Table:
    """
    Filter UKB VDS to bi-allelic, autosomal sites in interval QC pass regions with an adj allele count of two and no homozygotes.

    :param vds_path: Path to UKB 455k VDS. Default is VDS_PATH.
    :param temp_path: Path to bucket to store Table and other temporary data. Default is TEMP_PATH.
    :param tranche_data: UKB tranche data (data source and data freeze number). Default is TRANCHE_DATA.
    :param sparse_entries: List of fields to select from VDS. Default is SPARSE_ENTRIES.
    :return: Table of high quality sites with doubletons.
    """
    logger.info("Reading in VDS and filtering to bi-allelic SNPs...")
    mt = hl.vds.read_vds(vds_path).variant_data
    # Drop unnecessary annotations
    mt = mt.select_rows().select_entries(*sparse_entries)
    mt = mt.filter_rows(
        bi_allelic_expr(mt) & hl.is_snp(mt.alleles[0], mt.alleles[1]))

    logger.info("Filter to autosomes and splitting multiallelics...")
    mt = mt.filter_rows(mt.locus.in_autosome())
    # NOTE: UKB dataset does not have errors with changed loci
    # (`filter_changed_loci = False` will not throw errors here)
    mt = hl.experimental.sparse_split_multi(mt)

    logger.info("Removing AS_lowqual sites...")
    info_ht = hl.read_table(info_ht_path(*tranche_data, split=True))
    mt = mt.filter_rows(~info_ht[mt.row_key].AS_lowqual)

    logger.info("Filtering to interval QC pass regions...")
    interval_ht = hl.read_table(interval_qc_path(*tranche_data, "autosomes"))
    mt = mt.filter_rows(hl.is_defined(interval_ht[mt.locus]))

    logger.info("Filtering to adj and calculating allele count...")
    mt = filter_to_adj(mt)
    mt = mt.annotate_rows(call_stats=hl.agg.call_stats(mt.GT, mt.alleles))
    # Get AC at allele index 1 (call_stats includes a count for each allele, including reference)
    mt = mt.transmute_rows(ac=mt.call_stats.AC[1],
                           n_hom=mt.call_stats.homozygote_count[1])

    logger.info("Filtering to an allele count of two and returning...")
    ht = mt.rows()
    ht = ht.filter((ht.ac == 2) & (ht.n_hom == 0))
    ht = ht.checkpoint(f"{temp_path}/high_quality_sites.ht", overwrite=True)
    return ht
Example #36
def pca_filter_mt(in_mt: hl.MatrixTable,
                  maf: float = 0.05,
                  hwe: float = 1e-3,
                  call_rate: float = 0.98,
                  ld_cor: float = 0.2,
                  ld_window: int = 250000):

    print("\nInitial number of SNPs before filtering: {}".format(
        in_mt.count_rows()))
    mt = hl.variant_qc(in_mt)
    print(f'\nFiltering out variants with MAF < {maf}')
    mt_filt = mt.annotate_rows(maf=hl.min(mt.variant_qc.AF))
    mt_filt = mt_filt.filter_rows(mt_filt.maf > maf)

    print(f'\nFiltering out variants with HWE < {hwe:1e}')
    mt_filt = mt_filt.filter_rows(mt_filt.variant_qc.p_value_hwe > hwe)

    print(f'\nFiltering out variants with Call Rate < {call_rate}')
    mt_filt = mt_filt.filter_rows(mt_filt.variant_qc.call_rate >= call_rate)

    # no strand ambiguity
    print('\nFiltering out strand ambiguous variants')
    mt_filt = mt_filt.filter_rows(
        ~hl.is_strand_ambiguous(mt_filt.alleles[0], mt_filt.alleles[1]))

    # MHC chr6:25-35Mb
    # chr8.inversion chr8:7-13Mb
    print(
        '\nFiltering out variants in MHC [chr6:25M-35M] and chromosome 8 inversions [chr8:7M-13M]'
    )
    intervals = ['chr6:25M-35M', 'chr8:7M-13M']
    mt_filt = hl.filter_intervals(mt_filt, [
        hl.parse_locus_interval(x, reference_genome='GRCh38')
        for x in intervals
    ],
                                  keep=False)

    # This step is expensive (on local machine)
    print(
        f'\nLD pruning using correlation threshold of {ld_cor} and window size of {ld_window}'
    )
    mt_ld_prune = hl.ld_prune(mt_filt.GT, r2=ld_cor, bp_window_size=ld_window)
    mt_ld_pruned = mt_filt.filter_rows(
        hl.is_defined(mt_ld_prune[mt_filt.row_key]))
    print("\nNumber of SNPs after filtering: {}".format(
        mt_ld_pruned.count_rows()))

    return mt_ld_pruned
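A hedged usage sketch (the input path is hypothetical); the pruned matrix is what a downstream PCA such as hl.hwe_normalized_pca would typically consume:

import hail as hl

mt = hl.read_matrix_table('data/genotypes.mt')       # hypothetical input
mt_pruned = pca_filter_mt(mt, maf=0.01, ld_cor=0.1)
eigenvalues, scores, _ = hl.hwe_normalized_pca(mt_pruned.GT, k=10)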
Example #37
0
def total_ac_or_af(variant, field):
    return hl.cond(
        variant.type == "MCNV",
        hl.bind(
            lambda cn2_index: hl.bind(
                lambda values_to_sum: values_to_sum.fold(lambda acc, n: acc + n, 0),
                hl.cond(
                    hl.is_defined(cn2_index),
                    field[0:cn2_index].extend(field[cn2_index + 1 :]),
                    field,
                ),
            ),
            hl.zip_with_index(variant.alts).find(lambda t: t[1] == "<CN=2>")[0],
        ),
        field[0],
    )
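In words: for a multi-allelic CNV ("MCNV"), the entry for the reference-like <CN=2> allele is dropped and the remaining per-allele values (e.g. an INFO AC array) are summed; for any other variant the first value is taken. A self-contained sketch with literal stand-ins (the struct layout is an assumption, not the real row schema):

import hail as hl

# Illustrative row stand-in: <CN=2> sits at index 1 of alts.
variant = hl.struct(type="MCNV", alts=hl.array(["<CN=0>", "<CN=2>", "<CN=3>"]))
field = hl.array([5, 100, 7])
hl.eval(total_ac_or_af(variant, field))  # 12: the <CN=2> entry is excluded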
Example #38
0
def lift_data(
    t: Union[hl.MatrixTable, hl.Table],
    gnomad: bool,
    data_type: str,
    path: str,
    rg: hl.genetics.ReferenceGenome,
    overwrite: bool,
) -> Union[hl.MatrixTable, hl.Table]:
    """
    Lifts input Table or MatrixTable from one reference build to another

    :param t: Table or MatrixTable
    :param gnomad: Whether data is gnomAD data
    :param data_type: Data type (exomes or genomes for gnomAD; not used otherwise)
    :param path: Path to input Table/MatrixTable (if data is not gnomAD data)
    :param rg: Reference genome
    :param overwrite: Whether to overwrite data
    :return: Table or MatrixTable with liftover annotations
    """

    logger.info("Annotating input with liftover coordinates")
    liftover_expr = {
        "new_locus": hl.liftover(t.locus, rg, include_strand=True),
        "old_locus": t.locus,
    }
    t = (t.annotate(**liftover_expr)
         if isinstance(t, hl.Table) else t.annotate_rows(**liftover_expr))

    no_target_expr = hl.agg.count_where(hl.is_missing(t.new_locus))
    num_no_target = (t.aggregate(no_target_expr) if isinstance(t, hl.Table)
                     else t.aggregate_rows(no_target_expr))

    logger.info(f"Filtering out {num_no_target} sites that failed to liftover")
    keep_expr = hl.is_defined(t.new_locus)
    t = t.filter(keep_expr) if isinstance(
        t, hl.Table) else t.filter_rows(keep_expr)

    row_key_expr = {"locus": t.new_locus.result, "alleles": t.alleles}
    t = (t.key_by(**row_key_expr)
         if isinstance(t, hl.Table) else t.key_rows_by(**row_key_expr))

    logger.info("Writing out lifted over data")
    t = t.checkpoint(
        get_checkpoint_path(gnomad, data_type, path, isinstance(t, hl.Table)),
        overwrite=overwrite,
    )
    return t
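hl.liftover requires a chain file to have been attached to the source reference genome before this function runs; a minimal setup sketch (the chain-file path and the call's arguments are illustrative):

import hail as hl

rg37 = hl.get_reference('GRCh37')
rg38 = hl.get_reference('GRCh38')
rg37.add_liftover('references/grch37_to_grch38.over.chain.gz', rg38)  # hypothetical path

ht = hl.read_table('data/my_table.ht')  # keyed by GRCh37 locus/alleles
lifted = lift_data(ht, gnomad=False, data_type='', path='data/my_table.ht',
                   rg=rg38, overwrite=True)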
Example #39
0
def subset_samples_and_variants(
    mt: hl.MatrixTable,
    sample_path: str,
    header: bool = True,
    table_key: str = "s",
    sparse: bool = False,
    gt_expr: str = "GT",
) -> hl.MatrixTable:
    """
    Subset the MatrixTable to the provided list of samples and their variants.

    :param mt: Input MatrixTable
    :param sample_path: Path to a file with list of samples
    :param header: Whether file with samples has a header. Default is True
    :param table_key: Key to sample Table. Default is "s"
    :param sparse: Whether the MatrixTable is sparse. Default is False
    :param gt_expr: Name of field in MatrixTable containing genotype expression. Default is "GT"
    :return: MatrixTable subsetted to specified samples and their variants
    """
    sample_ht = hl.import_table(sample_path,
                                no_header=not header,
                                key=table_key)
    sample_count = sample_ht.count()
    missing_ht = sample_ht.anti_join(mt.cols())
    missing_ht_count = missing_ht.count()
    full_count = mt.count_cols()

    if missing_ht_count != 0:
        missing_samples = missing_ht.s.collect()
        raise DataException(
            f"Only {sample_count - missing_ht_count} out of {sample_count} "
            "subsetting-table IDs matched IDs in the MT.\n"
            f"IDs that aren't in the MT: {missing_samples}\n")

    mt = mt.semi_join_cols(sample_ht)
    if sparse:
        mt = mt.filter_rows(
            hl.agg.any(mt[gt_expr].is_non_ref() | hl.is_defined(mt.END)))
    else:
        mt = mt.filter_rows(hl.agg.any(mt[gt_expr].is_non_ref()))

    logger.info(
        "Finished subsetting samples. Kept %d out of %d samples in MT",
        mt.count_cols(),
        full_count,
    )
    return mt
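A usage sketch, assuming a one-column sample list whose header matches table_key (the path is hypothetical); any unmatched ID raises DataException before filtering happens:

subset_mt = subset_samples_and_variants(
    mt,
    sample_path='data/samples_to_keep.tsv',  # hypothetical file: header "s", one ID per line
    header=True,
    table_key='s',
)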
Example #40
0
def test_vcf_vds_combiner_equivalence():
    import hail.experimental.vcf_combiner.vcf_combiner as vcf
    import hail.vds.combiner as vds
    _paths = ['gvcfs/HG00096.g.vcf.gz', 'gvcfs/HG00268.g.vcf.gz']
    paths = [resource(p) for p in _paths]
    parts = [
        hl.Interval(
            start=hl.Struct(
                locus=hl.Locus('chr20', 17821257, reference_genome='GRCh38')),
            end=hl.Struct(
                locus=hl.Locus('chr20', 18708366, reference_genome='GRCh38')),
            includes_end=True),
        hl.Interval(
            start=hl.Struct(
                locus=hl.Locus('chr20', 18708367, reference_genome='GRCh38')),
            end=hl.Struct(
                locus=hl.Locus('chr20', 19776611, reference_genome='GRCh38')),
            includes_end=True),
        hl.Interval(
            start=hl.Struct(
                locus=hl.Locus('chr20', 19776612, reference_genome='GRCh38')),
            end=hl.Struct(
                locus=hl.Locus('chr20', 21144633, reference_genome='GRCh38')),
            includes_end=True)
    ]
    vcfs = [
        mt.annotate_rows(
            info=mt.info.annotate(MQ_DP=hl.missing(hl.tint32),
                                  VarDP=hl.missing(hl.tint32),
                                  QUALapprox=hl.missing(hl.tint32)))
        for mt in hl.import_gvcfs(paths,
                                  parts,
                                  reference_genome='GRCh38',
                                  array_elements_required=False)
    ]
    entry_to_keep = defined_entry_fields(
        vcfs[0].filter_rows(hl.is_defined(vcfs[0].info.END)),
        100_000) - {'GT', 'PGT', 'PL'}
    vds = vds.combine_variant_datasets([
        vds.transform_gvcf(mt, reference_entry_fields_to_keep=entry_to_keep)
        for mt in vcfs
    ])
    smt = vcf.combine_gvcfs([vcf.transform_gvcf(mt) for mt in vcfs])
    smt_from_vds = hl.vds.to_merged_sparse_mt(vds).drop('RGQ')
    smt = smt.select_entries(*smt_from_vds.entry)  # harmonize fields and order
    smt = smt.key_rows_by('locus', 'alleles')
    assert smt._same(smt_from_vds)
Example #41
0
    def phase_diploid_proband(
            locus: hl.expr.LocusExpression,
            alleles: hl.expr.ArrayExpression,
            proband_call: hl.expr.CallExpression,
            father_call: hl.expr.CallExpression,
            mother_call: hl.expr.CallExpression
    ) -> hl.expr.ArrayExpression:
        """
        Returns phased genotype calls in the case of a diploid proband
        (autosomes, PAR regions of sex chromosomes or non-PAR regions of a female proband)

        :param LocusExpression locus: Locus in the trio MatrixTable
        :param ArrayExpression alleles: Alleles in the trio MatrixTable
        :param CallExpression proband_call: Input proband genotype call
        :param CallExpression father_call: Input father genotype call
        :param CallExpression mother_call: Input mother genotype call
        :return: Array containing: phased proband call, phased father call, phased mother call
        :rtype: ArrayExpression
        """

        proband_v = proband_call.one_hot_alleles(alleles)
        father_v = hl.cond(
            locus.in_x_nonpar() | locus.in_y_nonpar(),
            hl.or_missing(father_call.is_haploid(), hl.array([father_call.one_hot_alleles(alleles)])),
            call_to_one_hot_alleles_array(father_call, alleles)
        )
        mother_v = call_to_one_hot_alleles_array(mother_call, alleles)

        combinations = hl.flatmap(
            lambda f:
            hl.zip_with_index(mother_v)
                .filter(lambda m: m[1] + f[1] == proband_v)
                .map(lambda m: hl.struct(m=m[0], f=f[0])),
            hl.zip_with_index(father_v)
        )

        return (
            hl.or_missing(
                hl.is_defined(combinations) & (hl.len(combinations) == 1),
                hl.array([
                    hl.call(father_call[combinations[0].f], mother_call[combinations[0].m], phased=True),
                    hl.cond(father_call.is_haploid(), hl.call(father_call[0], phased=True), phase_parent_call(father_call, combinations[0].f)),
                    phase_parent_call(mother_call, combinations[0].m)
                ])
            )
        )
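The consistency check above rests on one_hot_alleles: each call becomes a per-allele count vector, so a father/mother allele combination is compatible exactly when the parental vectors sum to the proband's vector. A small sketch of the primitive (values as documented for Hail):

import hail as hl

hl.eval(hl.call(0, 1).one_hot_alleles(['A', 'T']))  # [1, 1]
hl.eval(hl.call(1, 1).one_hot_alleles(['A', 'T']))  # [0, 2]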
Example #42
0
            def with_pl(pl):
                new_exprs = {}
                dropped_fields = ['LA']
                if 'LGT' in fields:
                    new_exprs['GT'] = hl.downcode(old_entry.LGT, hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LGT')
                if 'LPGT' in fields:
                    new_exprs['PGT'] = hl.downcode(old_entry.LPGT, hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LPGT')
                if 'LAD' in fields:
                    new_exprs['AD'] = hl.or_missing(
                        hl.is_defined(old_entry.LAD),
                        [old_entry.LAD[0], hl.or_else(old_entry.LAD[local_a_index], 0)]) # second entry zeroed for lack of non-ref AD
                    dropped_fields.append('LAD')
                if 'LPL' in fields:
                    new_exprs['PL'] = pl
                    if 'GQ' in fields:
                        new_exprs['GQ'] = hl.or_else(hl.gq_from_pl(pl), old_entry.GQ)

                    dropped_fields.append('LPL')

                return hl.cond(hl.len(ds.alleles) == 1,
                                   old_entry.annotate(**{f[1:]: old_entry[f] for f in ['LGT', 'LPGT', 'LAD', 'LPL'] if f in fields}).drop(*dropped_fields),
                                   old_entry.annotate(**new_exprs).drop(*dropped_fields))
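For context: with_pl is an inner helper from a split-multi entry transformation, where local fields (LGT, LPGT, LAD, LPL) are rewritten to global ones by downcoding every allele other than the current split allele to reference. A sketch of the hl.downcode primitive in isolation (output as documented for Hail):

import hail as hl

# Keep allele 2 as the alternate and recode every other allele to reference:
# the diploid call 1/2 downcodes to 0/1.
hl.eval(hl.downcode(hl.call(1, 2), 2))  # Call(alleles=[0, 1], phased=False)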
Example #43
0
    def phase_haploid_proband_x_nonpar(
            proband_call: hl.expr.CallExpression,
            father_call: hl.expr.CallExpression,
            mother_call: hl.expr.CallExpression
    ) -> hl.expr.ArrayExpression:
        """
        Returns phased genotype calls in the case of a haploid proband in the non-PAR region of X

        :param CallExpression proband_call: Input proband genotype call
        :param CallExpression father_call: Input father genotype call
        :param CallExpression mother_call: Input mother genotype call
        :return: Array containing: phased proband call, phased father call, phased mother call
        :rtype: ArrayExpression
        """

        transmitted_allele = hl.zip_with_index(hl.array([mother_call[0], mother_call[1]])).find(lambda m: m[1] == proband_call[0])
        return hl.or_missing(
            hl.is_defined(transmitted_allele),
            hl.array([
                hl.call(proband_call[0], phased=True),
                hl.or_missing(father_call.is_haploid(), hl.call(father_call[0], phased=True)),
                phase_parent_call(mother_call, transmitted_allele[0])
            ])
        )
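The maternal transmitted allele is recovered by pairing each of the mother's alleles with its index and finding the one equal to the proband's haploid allele; in isolation (illustrative values):

import hail as hl

mother_call = hl.call(0, 1)
proband_call = hl.call(1)
hl.eval(hl.zip_with_index(hl.array([mother_call[0], mother_call[1]]))
        .find(lambda m: m[1] == proband_call[0]))  # (1, 1): matched at allele index 1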
Example #44
0
def histogram2d(x, y, bins=40, range=None,
                 title=None, width=600, height=600, font_size='7pt',
                 colors=bokeh.palettes.all_palettes['Blues'][7][::-1]):
    """Plot a two-dimensional histogram.

    ``x`` and ``y`` must both be a :class:`NumericExpression` from the same :class:`Table`.

    If the x or y range is not provided via ``range``, the function will do a pass through the data to determine
    the min and max of each variable.

    Examples
    --------

    >>> ht = hail.utils.range_table(1000).annotate(x=hail.rand_norm(), y=hail.rand_norm())
    >>> p_hist = hail.plot.histogram2d(ht.x, ht.y)

    >>> ht = hail.utils.range_table(1000).annotate(x=hail.rand_norm(), y=hail.rand_norm())
    >>> p_hist = hail.plot.histogram2d(ht.x, ht.y, bins=10, range=((0, 1), None))

    Parameters
    ----------
    x : :class:`.NumericExpression`
        Expression for x-axis (from a Hail table).
    y : :class:`.NumericExpression`
        Expression for y-axis (from the same Hail table as ``x``).
    bins : int or [int, int]
        The bin specification:
        -   If int, the number of bins for the two dimensions (nx = ny = bins).
        -   If [int, int], the number of bins in each dimension (nx, ny = bins).
        The default value is 40.
    range : None or ((float, float), (float, float))
        The leftmost and rightmost edges of the bins along each dimension:
        ((xmin, xmax), (ymin, ymax)). All values outside of this range will be considered outliers
        and not tallied in the histogram. If this value is None, or either of the inner lists is None,
        the range will be computed from the data.
    width : int
        Plot width (default 600px).
    height : int
        Plot height (default 600px).
    title : str
        Title of the plot.
    font_size : str
        String of font size in points (default '7pt').
    colors : List[str]
        List of colors (hex codes, or strings as described
        `here <https://bokeh.pydata.org/en/latest/docs/reference/colors.html>`__). Compatible with one of the many
        built-in palettes available `here <https://bokeh.pydata.org/en/latest/docs/reference/palettes.html>`__.

    Returns
    -------
    :class:`bokeh.plotting.figure.Figure`
    """
    source = x._indices.source
    y_source = y._indices.source

    if source is None or y_source is None:
        raise ValueError("histogram_2d expects two expressions of 'Table', found scalar expression")
    if isinstance(source, hail.MatrixTable):
        raise ValueError("histogram_2d requires source to be Table, not MatrixTable")
    if source != y_source:
        raise ValueError(f"histogram_2d expects two expressions from the same 'Table', found {source} and {y_source}")
    check_row_indexed('histogram_2d', x)
    check_row_indexed('histogram_2d', y)
    if isinstance(bins, int):
        x_bins = y_bins = bins
    else:
        x_bins, y_bins = bins
    if range is None:
        x_range = y_range = None
    else:
        x_range, y_range = range
    if x_range is None or y_range is None:
        warnings.warn('At least one range was not defined in histogram_2d. Doing two passes...')
        ranges = source.aggregate(hail.struct(x_stats=hail.agg.stats(x),
                                              y_stats=hail.agg.stats(y)))
        if x_range is None:
            x_range = (ranges.x_stats.min, ranges.x_stats.max)
        if y_range is None:
            y_range = (ranges.y_stats.min, ranges.y_stats.max)
    else:
        warnings.warn('If x_range or y_range are specified in histogram_2d, and there are points '
                      'outside of these ranges, they will not be plotted')
    x_range = list(map(float, x_range))
    y_range = list(map(float, y_range))
    x_spacing = (x_range[1] - x_range[0]) / x_bins
    y_spacing = (y_range[1] - y_range[0]) / y_bins

    def frange(start, stop, step):
        from itertools import count, takewhile
        return takewhile(lambda x: x <= stop, count(start, step))

    x_levels = hail.literal(list(frange(x_range[0], x_range[1], x_spacing))[::-1])
    y_levels = hail.literal(list(frange(y_range[0], y_range[1], y_spacing))[::-1])

    grouped_ht = source.group_by(
        x=hail.str(x_levels.find(lambda w: x >= w)),
        y=hail.str(y_levels.find(lambda w: y >= w))
    ).aggregate(c=hail.agg.count())
    data = grouped_ht.filter(hail.is_defined(grouped_ht.x) & (grouped_ht.x != str(x_range[1])) &
                             hail.is_defined(grouped_ht.y) & (grouped_ht.y != str(y_range[1]))).to_pandas()

    mapper = LinearColorMapper(palette=colors, low=data.c.min(), high=data.c.max())

    x_axis = sorted(set(data.x), key=lambda z: float(z))
    y_axis = sorted(set(data.y), key=lambda z: float(z))
    p = figure(title=title,
               x_range=x_axis, y_range=y_axis,
               x_axis_location="above", plot_width=width, plot_height=height,
               tools="hover,save,pan,box_zoom,reset,wheel_zoom", toolbar_location='below')

    p.grid.grid_line_color = None
    p.axis.axis_line_color = None
    p.axis.major_tick_line_color = None
    p.axis.major_label_standoff = 0
    p.axis.major_label_text_font_size = font_size
    import math
    p.xaxis.major_label_orientation = math.pi / 3

    p.rect(x='x', y='y', width=1, height=1,
           source=data,
           fill_color={'field': 'c', 'transform': mapper},
           line_color=None)

    color_bar = ColorBar(color_mapper=mapper, major_label_text_font_size=font_size,
                         ticker=BasicTicker(desired_num_ticks=6),
                         label_standoff=6, border_line_color=None, location=(0, 0))
    p.add_layout(color_bar, 'right')

    def set_font_size(p, font_size: str = '12pt'):
        """Set most of the font sizes in a bokeh figure

        Parameters
        ----------
        p : :class:`bokeh.plotting.figure.Figure`
            Input figure.
        font_size : str
            String of font size in points (e.g. '12pt').

        Returns
        -------
        :class:`bokeh.plotting.figure.Figure`
        """
        p.legend.label_text_font_size = font_size
        p.xaxis.axis_label_text_font_size = font_size
        p.yaxis.axis_label_text_font_size = font_size
        p.xaxis.major_label_text_font_size = font_size
        p.yaxis.major_label_text_font_size = font_size
        if hasattr(p.title, 'text_font_size'):
            p.title.text_font_size = font_size
        if hasattr(p.xaxis, 'group_text_font_size'):
            p.xaxis.group_text_font_size = font_size
        return p

    p.select_one(HoverTool).tooltips = [('x', '@x'), ('y', '@y',), ('count', '@c')]
    p = set_font_size(p, font_size)
    return p
Example #45
0
def full_outer_join_mt(left: hl.MatrixTable, right: hl.MatrixTable) -> hl.MatrixTable:
    """Performs a full outer join on `left` and `right`.

    Replaces row, column, and entry fields with the following:

     - `left_row` / `right_row`: structs of row fields from left and right.
     - `left_col` / `right_col`: structs of column fields from left and right.
     - `left_entry` / `right_entry`: structs of entry fields from left and right.

    Parameters
    ----------
    left : :class:`.MatrixTable`
    right : :class:`.MatrixTable`

    Returns
    -------
    :class:`.MatrixTable`
    """

    if [x.dtype for x in left.row_key.values()] != [x.dtype for x in right.row_key.values()]:
        raise ValueError(f"row key types do not match:\n"
                         f"  left:  {list(left.row_key.values())}\n"
                         f"  right: {list(right.row_key.values())}")

    if [x.dtype for x in left.col_key.values()] != [x.dtype for x in right.col_key.values()]: 
        raise ValueError(f"column key types do not match:\n"
                         f"  left:  {list(left.col_key.values())}\n"
                         f"  right: {list(right.col_key.values())}")

    left = left.select_rows(left_row=left.row)
    left_t = left.localize_entries('left_entries', 'left_cols')
    right = right.select_rows(right_row=right.row)
    right_t = right.localize_entries('right_entries', 'right_cols')

    ht = left_t.join(right_t, how='outer')
    ht = ht.annotate_globals(
        left_keys=hl.group_by(
            lambda t: t[0],
            hl.zip_with_index(
                ht.left_cols.map(lambda x: hl.tuple([x[f] for f in left.col_key])), index_first=False)).map_values(
            lambda elts: elts.map(lambda t: t[1])),
        right_keys=hl.group_by(
            lambda t: t[0],
            hl.zip_with_index(
                ht.right_cols.map(lambda x: hl.tuple([x[f] for f in right.col_key])), index_first=False)).map_values(
            lambda elts: elts.map(lambda t: t[1])))
    ht = ht.annotate_globals(
        key_indices=hl.array(ht.left_keys.key_set().union(ht.right_keys.key_set()))
            .map(lambda k: hl.struct(k=k, left_indices=ht.left_keys.get(k), right_indices=ht.right_keys.get(k)))
            .flatmap(lambda s: hl.case()
                     .when(hl.is_defined(s.left_indices) & hl.is_defined(s.right_indices),
                           hl.range(0, s.left_indices.length()).flatmap(
                               lambda i: hl.range(0, s.right_indices.length()).map(
                                   lambda j: hl.struct(k=s.k, left_index=s.left_indices[i],
                                                       right_index=s.right_indices[j]))))
                     .when(hl.is_defined(s.left_indices),
                           s.left_indices.map(
                               lambda elt: hl.struct(k=s.k, left_index=elt, right_index=hl.null('int32'))))
                     .when(hl.is_defined(s.right_indices),
                           s.right_indices.map(
                               lambda elt: hl.struct(k=s.k, left_index=hl.null('int32'), right_index=elt)))
                     .or_error('assertion error')))
    ht = ht.annotate(__entries=ht.key_indices.map(lambda s: hl.struct(left_entry=ht.left_entries[s.left_index],
                                                                      right_entry=ht.right_entries[s.right_index])))
    ht = ht.annotate_globals(__cols=ht.key_indices.map(
        lambda s: hl.struct(**{f: s.k[i] for i, f in enumerate(left.col_key)},
                            left_col=ht.left_cols[s.left_index],
                            right_col=ht.right_cols[s.right_index])))
    ht = ht.drop('left_entries', 'left_cols', 'left_keys', 'right_entries', 'right_cols', 'right_keys', 'key_indices')
    return ht._unlocalize_entries('__entries', '__cols', list(left.col_key))
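A hedged usage sketch: comparing two callsets entry-by-entry over the union of row and column keys (the inputs and the GT entry field are assumptions):

import hail as hl

joined = full_outer_join_mt(mt_a, mt_b)  # mt_a, mt_b: hypothetical MatrixTables
n_concordant = joined.aggregate_entries(
    hl.agg.count_where(joined.left_entry.GT == joined.right_entry.GT))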
Example #46
0
File: qc.py Project: jigold/hail
def variant_qc(mt, name='variant_qc') -> MatrixTable:
    """Compute common variant statistics (quality control metrics).

    .. include:: ../_templates/req_tvariant.rst

    Examples
    --------

    >>> dataset_result = hl.variant_qc(dataset)

    Notes
    -----
    This method computes variant statistics from the genotype data, returning
    a new struct field `name` with the following metrics based on the fields
    present in the entry schema.

    If `mt` contains an entry field `DP` of type :py:data:`.tint32`, then the
    field `dp_stats` is computed. If `mt` contains an entry field `GQ` of type
    :py:data:`.tint32`, then the field `gq_stats` is computed. Both `dp_stats`
    and `gq_stats` are structs with four fields:

    - `mean` (``float64``) -- Mean value.
    - `stdev` (``float64``) -- Standard deviation (zero degrees of freedom).
    - `min` (``int32``) -- Minimum value.
    - `max` (``int32``) -- Maximum value.

    If the dataset does not contain an entry field `GT` of type
    :py:data:`.tcall`, then an error is raised. The following fields are always
    computed from `GT`:

    - `AF` (``array<float64>``) -- Calculated allele frequency, one element
      per allele, including the reference. Sums to one. Equivalent to
      `AC` / `AN`.
    - `AC` (``array<int32>``) -- Calculated allele count, one element per
      allele, including the reference. Sums to `AN`.
    - `AN` (``int32``) -- Total number of called alleles.
    - `homozygote_count` (``array<int32>``) -- Number of homozygotes per
      allele. One element per allele, including the reference.
    - `call_rate` (``float64``) -- Fraction of calls neither missing nor filtered.
       Equivalent to `n_called` / :meth:`.count_cols`.
    - `n_called` (``int64``) -- Number of samples with a defined `GT`.
    - `n_not_called` (``int64``) -- Number of samples with a missing `GT`.
    - `n_filtered` (``int64``) -- Number of filtered entries.
    - `n_het` (``int64``) -- Number of heterozygous samples.
    - `n_non_ref` (``int64``) -- Number of samples with at least one called
      non-reference allele.
    - `het_freq_hwe` (``float64``) -- Expected frequency of heterozygous
      samples under Hardy-Weinberg equilibrium. See
      :func:`.functions.hardy_weinberg_test` for details.
    - `p_value_hwe` (``float64``) -- p-value from test of Hardy-Weinberg equilibrium.
      See :func:`.functions.hardy_weinberg_test` for details.

    Warning
    -------
    `het_freq_hwe` and `p_value_hwe` are calculated as in
    :func:`.functions.hardy_weinberg_test`, with non-diploid calls
    (``ploidy != 2``) ignored in the counts. As this test is only
    statistically rigorous in the biallelic setting, :func:`.variant_qc`
    sets both fields to missing for multiallelic variants. Consider using
    :func:`~hail.methods.split_multi` to split multi-allelic variants beforehand.

    Parameters
    ----------
    mt : :class:`.MatrixTable`
        Dataset.
    name : :obj:`str`
        Name for resulting field.

    Returns
    -------
    :class:`.MatrixTable`
    """
    require_row_key_variant(mt, 'variant_qc')

    bound_exprs = {}
    gq_dp_exprs = {}

    def has_field_of_type(name, dtype):
        return name in mt.entry and mt[name].dtype == dtype

    if has_field_of_type('DP', hl.tint32):
        gq_dp_exprs['dp_stats'] = hl.agg.stats(mt.DP).select('mean', 'stdev', 'min', 'max')

    if has_field_of_type('GQ', hl.tint32):
        gq_dp_exprs['gq_stats'] = hl.agg.stats(mt.GQ).select('mean', 'stdev', 'min', 'max')

    if not has_field_of_type('GT', hl.tcall):
        raise ValueError("'variant_qc': expect an entry field 'GT' of type 'call'")

    bound_exprs['n_called'] = hl.agg.count_where(hl.is_defined(mt['GT']))
    bound_exprs['n_not_called'] = hl.agg.count_where(hl.is_missing(mt['GT']))
    bound_exprs['n_filtered'] = mt.count_cols(_localize=False) - hl.agg.count()
    bound_exprs['call_stats'] = hl.agg.call_stats(mt.GT, mt.alleles)

    result = hl.rbind(hl.struct(**bound_exprs),
                      lambda e1: hl.rbind(
                          hl.case().when(hl.len(mt.alleles) == 2,
                                         hl.hardy_weinberg_test(e1.call_stats.homozygote_count[0],
                                                                e1.call_stats.AC[1] - 2 *
                                                                e1.call_stats.homozygote_count[1],
                                                                e1.call_stats.homozygote_count[1])
                                         ).or_missing(),
                          lambda hwe: hl.struct(**{
                              **gq_dp_exprs,
                              **e1.call_stats,
                              'call_rate': hl.float(e1.n_called) / (e1.n_called + e1.n_not_called + e1.n_filtered),
                              'n_called': e1.n_called,
                              'n_not_called': e1.n_not_called,
                              'n_filtered': e1.n_filtered,
                              'n_het': e1.n_called - hl.sum(e1.call_stats.homozygote_count),
                              'n_non_ref': e1.n_called - e1.call_stats.homozygote_count[0],
                              'het_freq_hwe': hwe.het_freq_hwe,
                              'p_value_hwe': hwe.p_value})))

    return mt.annotate_rows(**{name: result})
Example #47
0
def variant_qc(mt, name='variant_qc') -> MatrixTable:
    """Compute common variant statistics (quality control metrics).

    .. include:: ../_templates/req_tvariant.rst

    Examples
    --------

    >>> dataset_result = hl.variant_qc(dataset)

    Notes
    -----
    This method computes variant statistics from the genotype data, returning
    a new struct field `name` with the following metrics based on the fields
    present in the entry schema.

    If `mt` contains an entry field `DP` of type :py:data:`.tint32`, then the
    field `dp_stats` is computed. If `mt` contains an entry field `GQ` of type
    :py:data:`.tint32`, then the field `gq_stats` is computed. Both `dp_stats`
    and `gq_stats` are structs with four fields:

    - `mean` (``float64``) -- Mean value.
    - `stdev` (``float64``) -- Standard deviation (zero degrees of freedom).
    - `min` (``int32``) -- Minimum value.
    - `max` (``int32``) -- Maximum value.

    If the dataset does not contain an entry field `GT` of type
    :py:data:`.tcall`, then an error is raised. The following fields are always
    computed from `GT`:

    - `AF` (``array<float64>``) -- Calculated allele frequency, one element
      per allele, including the reference. Sums to one. Equivalent to
      `AC` / `AN`.
    - `AC` (``array<int32>``) -- Calculated allele count, one element per
      allele, including the reference. Sums to `AN`.
    - `AN` (``int32``) -- Total number of called alleles.
    - `homozygote_count` (``array<int32>``) -- Number of homozygotes per
      allele. One element per allele, including the reference.
    - `n_called` (``int64``) -- Number of samples with a defined `GT`.
    - `n_not_called` (``int64``) -- Number of samples with a missing `GT`.
    - `call_rate` (``float32``) -- Fraction of samples with a defined `GT`.
      Equivalent to `n_called` / :meth:`.count_cols`.
    - `n_het` (``int64``) -- Number of heterozygous samples.
    - `n_non_ref` (``int64``) -- Number of samples with at least one called
      non-reference allele.
    - `het_freq_hwe` (``float64``) -- Expected frequency of heterozygous
      samples under Hardy-Weinberg equilibrium. See
      :func:`.functions.hardy_weinberg_test` for details.
    - `p_value_hwe` (``float64``) -- p-value from test of Hardy-Weinberg equilibrium.
      See :func:`.functions.hardy_weinberg_test` for details.

    Warning
    -------
    `het_freq_hwe` and `p_value_hwe` are calculated as in
    :func:`.functions.hardy_weinberg_test`, with non-diploid calls
    (``ploidy != 2``) ignored in the counts. As this test is only
    statistically rigorous in the biallelic setting, :func:`.variant_qc`
    sets both fields to missing for multiallelic variants. Consider using
    :func:`~hail.methods.split_multi` to split multi-allelic variants beforehand.

    Parameters
    ----------
    mt : :class:`.MatrixTable`
        Dataset.
    name : :obj:`str`
        Name for resulting field.

    Returns
    -------
    :class:`.MatrixTable`
    """
    require_row_key_variant(mt, 'variant_qc')

    exprs = {}
    struct_exprs = []

    def has_field_of_type(name, dtype):
        return name in mt.entry and mt[name].dtype == dtype

    n_samples = mt.count_cols()

    if has_field_of_type('DP', hl.tint32):
        exprs['dp_stats'] = hl.agg.stats(mt.DP).select('mean', 'stdev', 'min', 'max')

    if has_field_of_type('GQ', hl.tint32):
        exprs['gq_stats'] = hl.agg.stats(mt.GQ).select('mean', 'stdev', 'min', 'max')

    if not has_field_of_type('GT', hl.tcall):
        raise ValueError("'variant_qc': expect an entry field 'GT' of type 'call'")
    exprs['n_called'] = hl.agg.count_where(hl.is_defined(mt['GT']))
    struct_exprs.append(hl.agg.call_stats(mt.GT, mt.alleles))


    # the structure of this function makes it easy to add new nested computations
    def flatten_struct(*struct_exprs):
        flat = {}
        for struct in struct_exprs:
            for k, v in struct.items():
                flat[k] = v
        return hl.struct(
            **flat,
            **exprs,
        )

    mt = mt.annotate_rows(**{name: hl.bind(flatten_struct, *struct_exprs)})

    hwe = hl.hardy_weinberg_test(mt[name].homozygote_count[0],
                                 mt[name].AC[1] - 2 * mt[name].homozygote_count[1],
                                 mt[name].homozygote_count[1])
    hwe = hwe.select(het_freq_hwe=hwe.het_freq_hwe, p_value_hwe=hwe.p_value)
    mt = mt.annotate_rows(**{name: mt[name].annotate(n_not_called=n_samples - mt[name].n_called,
                                                     call_rate=mt[name].n_called / n_samples,
                                                     n_het=mt[name].n_called - hl.sum(mt[name].homozygote_count),
                                                     n_non_ref=mt[name].n_called - mt[name].homozygote_count[0],
                                                     **hl.cond(hl.len(mt.alleles) == 2,
                                                               hwe,
                                                               hl.null(hwe.dtype)))})
    return mt
Example #48
0
def sample_qc(mt, name='sample_qc') -> MatrixTable:
    """Compute per-sample metrics useful for quality control.

    .. include:: ../_templates/req_tvariant.rst

    Examples
    --------

    Compute sample QC metrics and remove low-quality samples:

    >>> dataset = hl.sample_qc(dataset, name='sample_qc')
    >>> filtered_dataset = dataset.filter_cols((dataset.sample_qc.dp_stats.mean > 20) & (dataset.sample_qc.r_ti_tv > 1.5))

    Notes
    -----

    This method computes summary statistics per sample from a genetic matrix and stores
    the results as a new column-indexed struct field in the matrix, named based on the
    `name` parameter.

    If `mt` contains an entry field `DP` of type :py:data:`.tint32`, then the
    field `dp_stats` is computed. If `mt` contains an entry field `GQ` of type
    :py:data:`.tint32`, then the field `gq_stats` is computed. Both `dp_stats`
    and `gq_stats` are structs with four fields:

    - `mean` (``float64``) -- Mean value.
    - `stdev` (``float64``) -- Standard deviation (zero degrees of freedom).
    - `min` (``int32``) -- Minimum value.
    - `max` (``int32``) -- Maximum value.

    If the dataset does not contain an entry field `GT` of type
    :py:data:`.tcall`, then an error is raised. The following fields are always
    computed from `GT`:

    - `call_rate` (``float64``) -- Fraction of calls non-missing.
    - `n_called` (``int64``) -- Number of non-missing calls.
    - `n_not_called` (``int64``) -- Number of missing calls.
    - `n_hom_ref` (``int64``) -- Number of homozygous reference calls.
    - `n_het` (``int64``) -- Number of heterozygous calls.
    - `n_hom_var` (``int64``) -- Number of homozygous alternate calls.
    - `n_non_ref` (``int64``) -- Sum of ``n_het`` and ``n_hom_var``.
    - `n_snp` (``int64``) -- Number of SNP alternate alleles.
    - `n_insertion` (``int64``) -- Number of insertion alternate alleles.
    - `n_deletion` (``int64``) -- Number of deletion alternate alleles.
    - `n_singleton` (``int64``) -- Number of private alleles.
    - `n_transition` (``int64``) -- Number of transition (A-G, C-T) alternate alleles.
    - `n_transversion` (``int64``) -- Number of transversion alternate alleles.
    - `n_star` (``int64``) -- Number of star (upstream deletion) alleles.
    - `r_ti_tv` (``float64``) -- Transition/Transversion ratio.
    - `r_het_hom_var` (``float64``) -- Het/HomVar call ratio.
    - `r_insertion_deletion` (``float64``) -- Insertion/Deletion allele ratio.

    Missing values ``NA`` may result from division by zero.

    Parameters
    ----------
    mt : :class:`.MatrixTable`
        Dataset.
    name : :obj:`str`
        Name for resulting field.

    Returns
    -------
    :class:`.MatrixTable`
        Dataset with a new column-indexed field `name`.
    """

    require_row_key_variant(mt, 'sample_qc')

    from hail.expr.functions import _num_allele_type, _allele_types

    allele_types = _allele_types[:]
    allele_types.extend(['Transition', 'Transversion'])
    allele_enum = {i: v for i, v in enumerate(allele_types)}
    allele_ints = {v: k for k, v in allele_enum.items()}

    def allele_type(ref, alt):
        return hl.bind(lambda at: hl.cond(at == allele_ints['SNP'],
                                          hl.cond(hl.is_transition(ref, alt),
                                                  allele_ints['Transition'],
                                                  allele_ints['Transversion']),
                                          at),
                       _num_allele_type(ref, alt))

    variant_ac = Env.get_uid()
    variant_atypes = Env.get_uid()
    mt = mt.annotate_rows(**{variant_ac: hl.agg.call_stats(mt.GT, mt.alleles).AC,
                             variant_atypes: mt.alleles[1:].map(lambda alt: allele_type(mt.alleles[0], alt))})

    exprs = {}

    def has_field_of_type(name, dtype):
        return name in mt.entry and mt[name].dtype == dtype

    if has_field_of_type('DP', hl.tint32):
        exprs['dp_stats'] = hl.agg.stats(mt.DP).select('mean', 'stdev', 'min', 'max')

    if has_field_of_type('GQ', hl.tint32):
        exprs['gq_stats'] = hl.agg.stats(mt.GQ).select('mean', 'stdev', 'min', 'max')

    if not has_field_of_type('GT', hl.tcall):
        raise ValueError("'sample_qc': expect an entry field 'GT' of type 'call'")

    exprs['n_called'] = hl.agg.count_where(hl.is_defined(mt['GT']))
    exprs['n_not_called'] = hl.agg.count_where(hl.is_missing(mt['GT']))
    exprs['n_hom_ref'] = hl.agg.count_where(mt['GT'].is_hom_ref())
    exprs['n_het'] = hl.agg.count_where(mt['GT'].is_het())
    exprs['n_singleton'] = hl.agg.sum(hl.sum(hl.range(0, mt['GT'].ploidy).map(lambda i: mt[variant_ac][mt['GT'][i]] == 1)))

    def get_allele_type(allele_idx):
        return hl.cond(allele_idx > 0, mt[variant_atypes][allele_idx - 1], hl.null(hl.tint32))

    exprs['allele_type_counts'] = hl.agg.explode(
        lambda elt: hl.agg.counter(elt),
        hl.range(0, mt['GT'].ploidy).map(lambda i: get_allele_type(mt['GT'][i])))

    mt = mt.annotate_cols(**{name: hl.struct(**exprs)})

    zero = hl.int64(0)

    select_exprs = {}
    if 'dp_stats' in exprs:
        select_exprs['dp_stats'] = mt[name].dp_stats
    if 'gq_stats' in exprs:
        select_exprs['gq_stats'] = mt[name].gq_stats

    select_exprs = {
        **select_exprs,
        'call_rate': hl.float64(mt[name].n_called) / (mt[name].n_called + mt[name].n_not_called),
        'n_called': mt[name].n_called,
        'n_not_called': mt[name].n_not_called,
        'n_hom_ref': mt[name].n_hom_ref,
        'n_het': mt[name].n_het,
        'n_hom_var': mt[name].n_called - mt[name].n_hom_ref - mt[name].n_het,
        'n_non_ref': mt[name].n_called - mt[name].n_hom_ref,
        'n_singleton': mt[name].n_singleton,
        'n_snp': mt[name].allele_type_counts.get(allele_ints["Transition"], zero) + \
                 mt[name].allele_type_counts.get(allele_ints["Transversion"], zero),
        'n_insertion': mt[name].allele_type_counts.get(allele_ints["Insertion"], zero),
        'n_deletion': mt[name].allele_type_counts.get(allele_ints["Deletion"], zero),
        'n_transition': mt[name].allele_type_counts.get(allele_ints["Transition"], zero),
        'n_transversion': mt[name].allele_type_counts.get(allele_ints["Transversion"], zero),
        'n_star': mt[name].allele_type_counts.get(allele_ints["Star"], zero)
    }

    mt = mt.annotate_cols(**{name: mt[name].select(**select_exprs)})

    mt = mt.annotate_cols(**{name: mt[name].annotate(
        r_ti_tv=divide_null(hl.float64(mt[name].n_transition), mt[name].n_transversion),
        r_het_hom_var=divide_null(hl.float64(mt[name].n_het), mt[name].n_hom_var),
        r_insertion_deletion=divide_null(hl.float64(mt[name].n_insertion), mt[name].n_deletion)
    )})        

    mt = mt.drop(variant_ac, variant_atypes)

    return mt
Example #49
0
def ld_score_regression(weight_expr,
                        ld_score_expr,
                        chi_sq_exprs,
                        n_samples_exprs,
                        n_blocks=200,
                        two_step_threshold=30,
                        n_reference_panel_variants=None) -> Table:
    r"""Estimate SNP-heritability and level of confounding biases from
    GWAS summary statistics.

    Given a set or multiple sets of genome-wide association study (GWAS)
    summary statistics, :func:`.ld_score_regression` estimates the heritability
    of a trait or set of traits and the level of confounding biases present in
    the underlying studies by regressing chi-squared statistics on LD scores,
    leveraging the model:

    .. math::

        \mathrm{E}[\chi_j^2] = 1 + Na + \frac{Nh_g^2}{M}l_j

    *  :math:`\mathrm{E}[\chi_j^2]` is the expected chi-squared statistic
       for variant :math:`j` resulting from a test of association between
       variant :math:`j` and a trait.
    *  :math:`l_j = \sum_{k} r_{jk}^2` is the LD score of variant
       :math:`j`, calculated as the sum of squared correlation coefficients
       between variant :math:`j` and nearby variants. See :func:`ld_score`
       for further details.
    *  :math:`a` captures the contribution of confounding biases, such as
       cryptic relatedness and uncontrolled population structure, to the
       association test statistic.
    *  :math:`h_g^2` is the SNP-heritability, or the proportion of variation
       in the trait explained by the effects of variants included in the
       regression model above.
    *  :math:`M` is the number of variants used to estimate :math:`h_g^2`.
    *  :math:`N` is the number of samples in the underlying association study.

    For more details on the method implemented in this function, see:

    * `LD Score regression distinguishes confounding from polygenicity in genome-wide association studies (Bulik-Sullivan et al, 2015) <https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4495769/>`__

    Examples
    --------

    Run the method on a matrix table of summary statistics, where the rows
    are variants and the columns are different phenotypes:

    >>> mt_gwas = hl.read_matrix_table('data/ld_score_regression.sumstats.mt')
    >>> ht_results = hl.experimental.ld_score_regression(
    ...     weight_expr=mt_gwas['ld_score'],
    ...     ld_score_expr=mt_gwas['ld_score'],
    ...     chi_sq_exprs=mt_gwas['chi_squared'],
    ...     n_samples_exprs=mt_gwas['n'])


    Run the method on a table with summary statistics for a single
    phenotype:

    >>> ht_gwas = hl.read_table('data/ld_score_regression.sumstats.ht')
    >>> ht_results = hl.experimental.ld_score_regression(
    ...     weight_expr=ht_gwas['ld_score'],
    ...     ld_score_expr=ht_gwas['ld_score'],
    ...     chi_sq_exprs=ht_gwas['chi_squared_50_irnt'],
    ...     n_samples_exprs=ht_gwas['n_50_irnt'])

    Run the method on a table with summary statistics for multiple
    phenotypes:

    >>> ht_gwas = hl.read_table('data/ld_score_regression.sumstats.ht')
    >>> ht_results = hl.experimental.ld_score_regression(
    ...     weight_expr=ht_gwas['ld_score'],
    ...     ld_score_expr=ht_gwas['ld_score'],
    ...     chi_sq_exprs=[ht_gwas['chi_squared_50_irnt'],
    ...                        ht_gwas['chi_squared_20160']],
    ...     n_samples_exprs=[ht_gwas['n_50_irnt'],
    ...                      ht_gwas['n_20160']])

    Notes
    -----
    The ``exprs`` provided as arguments to :func:`.ld_score_regression`
    must all be from the same object, either a :class:`Table` or a
    :class:`MatrixTable`.

    **If the arguments originate from a table:**

    *  The table must be keyed by fields ``locus`` of type
       :class:`.tlocus` and ``alleles``, a :py:data:`.tarray` of
       :py:data:`.tstr` elements.
    *  ``weight_expr``, ``ld_score_expr``, ``chi_sq_exprs``, and
       ``n_samples_exprs`` must be row-indexed fields.
    *  The number of expressions passed to ``n_samples_exprs`` must be
       equal to one or the number of expressions passed to
       ``chi_sq_exprs``. If just one expression is passed to
       ``n_samples_exprs``, that sample size expression is assumed to
       apply to all sets of statistics passed to ``chi_sq_exprs``.
       Otherwise, the expressions passed to ``chi_sq_exprs`` and
       ``n_samples_exprs`` are matched by index.
    *  The ``phenotype`` field that keys the table returned by
       :func:`.ld_score_regression` will have generic :obj:`int` values
       ``0``, ``1``, etc. corresponding to the ``0th``, ``1st``, etc.
       expressions passed to the ``chi_sq_exprs`` argument.

    **If the arguments originate from a matrix table:**

    *  The dimensions of the matrix table must be variants
       (rows) by phenotypes (columns).
    *  The rows of the matrix table must be keyed by fields
       ``locus`` of type :class:`.tlocus` and ``alleles``,
       a :py:data:`.tarray` of :py:data:`.tstr` elements.
    *  The columns of the matrix table must be keyed by a field
       of type :py:data:`.tstr` that uniquely identifies phenotypes
       represented in the matrix table. The column key must be a single
       expression; compound keys are not accepted.
    *  ``weight_expr`` and ``ld_score_expr`` must be row-indexed
       fields.
    *  ``chi_sq_exprs`` must be a single entry-indexed field
       (not a list of fields).
    *  ``n_samples_exprs`` must be a single entry-indexed field
       (not a list of fields).
    *  The ``phenotype`` field that keys the table returned by
       :func:`.ld_score_regression` will have values corresponding to the
       column keys of the input matrix table.

    This function returns a :class:`Table` with one row per set of summary
    statistics passed to the ``chi_sq_exprs`` argument. The following
    row-indexed fields are included in the table:

    *  **phenotype** (:py:data:`.tstr`) -- The name of the phenotype. The
       returned table is keyed by this field. See the notes below for
       details on the possible values of this field.
    *  **mean_chi_sq** (:py:data:`.tfloat64`) -- The mean chi-squared
       test statistic for the given phenotype.
    *  **intercept** (`Struct`) -- Contains fields:

       -  **estimate** (:py:data:`.tfloat64`) -- A point estimate of the
          intercept :math:`1 + Na`.
       -  **standard_error**  (:py:data:`.tfloat64`) -- An estimate of
          the standard error of this point estimate.

    *  **snp_heritability** (`Struct`) -- Contains fields:

       -  **estimate** (:py:data:`.tfloat64`) -- A point estimate of the
          SNP-heritability :math:`h_g^2`.
       -  **standard_error** (:py:data:`.tfloat64`) -- An estimate of
          the standard error of this point estimate.

    Warning
    -------
    :func:`.ld_score_regression` considers only the rows for which both row
    fields ``weight_expr`` and ``ld_score_expr`` are defined. Rows with missing
    values in either field are removed prior to fitting the LD score
    regression model.

    Parameters
    ----------
    weight_expr : :class:`.Float64Expression`
                  Row-indexed expression for the LD scores used to derive
                  variant weights in the model.
    ld_score_expr : :class:`.Float64Expression`
                    Row-indexed expression for the LD scores used as covariates
                    in the model.
    chi_sq_exprs : :class:`.Float64Expression` or :obj:`list` of
                        :class:`.Float64Expression`
                        One or more row-indexed (if table) or entry-indexed
                        (if matrix table) expressions for chi-squared
                        statistics resulting from genome-wide association
                        studies.
    n_samples_exprs: :class:`.NumericExpression` or :obj:`list` of
                     :class:`.NumericExpression`
                     One or more row-indexed (if table) or entry-indexed
                     (if matrix table) expressions indicating the number of
                     samples used in the studies that generated the test
                     statistics supplied to ``chi_sq_exprs``.
    n_blocks : :obj:`int`
               The number of blocks used in the jackknife approach to
               estimating standard errors.
    two_step_threshold : :obj:`int`
                         Variants with chi-squared statistics greater than this
                         value are excluded in the first step of the two-step
                         procedure used to fit the model.
    n_reference_panel_variants : :obj:`int`, optional
                                 Number of variants used to estimate the
                                 SNP-heritability :math:`h_g^2`.

    Returns
    -------
    :class:`.Table`
        Table keyed by ``phenotype`` with intercept and heritability estimates
        for each phenotype passed to the function."""

    chi_sq_exprs = wrap_to_list(chi_sq_exprs)
    n_samples_exprs = wrap_to_list(n_samples_exprs)

    assert ((len(chi_sq_exprs) == len(n_samples_exprs)) or
            (len(n_samples_exprs) == 1))
    __k = 2  # number of covariates, including intercept

    ds = chi_sq_exprs[0]._indices.source

    analyze('ld_score_regression/weight_expr',
            weight_expr,
            ds._row_indices)
    analyze('ld_score_regression/ld_score_expr',
            ld_score_expr,
            ds._row_indices)

    # format input dataset
    if isinstance(ds, MatrixTable):
        if len(chi_sq_exprs) != 1:
            raise ValueError("""Only one chi_sq_expr allowed if originating
                from a matrix table.""")
        if len(n_samples_exprs) != 1:
            raise ValueError("""Only one n_samples_expr allowed if
                originating from a matrix table.""")

        col_key = list(ds.col_key)
        if len(col_key) != 1:
            raise ValueError("""Matrix table must be keyed by a single
                phenotype field.""")

        analyze('ld_score_regression/chi_squared_expr',
                chi_sq_exprs[0],
                ds._entry_indices)
        analyze('ld_score_regression/n_samples_expr',
                n_samples_exprs[0],
                ds._entry_indices)

        ds = ds._select_all(row_exprs={'__locus': ds.locus,
                                       '__alleles': ds.alleles,
                                       '__w_initial': weight_expr,
                                       '__w_initial_floor': hl.max(weight_expr,
                                                                   1.0),
                                       '__x': ld_score_expr,
                                       '__x_floor': hl.max(ld_score_expr,
                                                           1.0)},
                            row_key=['__locus', '__alleles'],
                            col_exprs={'__y_name': ds[col_key[0]]},
                            col_key=['__y_name'],
                            entry_exprs={'__y': chi_sq_exprs[0],
                                         '__n': n_samples_exprs[0]})
        ds = ds.annotate_entries(**{'__w': ds.__w_initial})

        ds = ds.filter_rows(hl.is_defined(ds.__locus) &
                            hl.is_defined(ds.__alleles) &
                            hl.is_defined(ds.__w_initial) &
                            hl.is_defined(ds.__x))

    else:
        assert isinstance(ds, Table)
        for y in chi_sq_exprs:
            analyze('ld_score_regression/chi_squared_expr', y, ds._row_indices)
        for n in n_samples_exprs:
            analyze('ld_score_regression/n_samples_expr', n, ds._row_indices)

        ys = ['__y{:}'.format(i) for i, _ in enumerate(chi_sq_exprs)]
        ws = ['__w{:}'.format(i) for i, _ in enumerate(chi_sq_exprs)]
        ns = ['__n{:}'.format(i) for i, _ in enumerate(n_samples_exprs)]

        ds = ds.select(**dict(**{'__locus': ds.locus,
                                 '__alleles': ds.alleles,
                                 '__w_initial': weight_expr,
                                 '__x': ld_score_expr},
                              **{y: chi_sq_exprs[i]
                                 for i, y in enumerate(ys)},
                              **{w: weight_expr for w in ws},
                              **{n: n_samples_exprs[i]
                                 for i, n in enumerate(ns)}))
        ds = ds.key_by(ds.__locus, ds.__alleles)

        table_tmp_file = new_temp_file()
        ds.write(table_tmp_file)
        ds = hl.read_table(table_tmp_file)

        hts = [ds.select(**{'__w_initial': ds.__w_initial,
                            '__w_initial_floor': hl.max(ds.__w_initial,
                                                        1.0),
                            '__x': ds.__x,
                            '__x_floor': hl.max(ds.__x, 1.0),
                            '__y_name': i,
                            '__y': ds[ys[i]],
                            '__w': ds[ws[i]],
                            '__n': hl.int(ds[ns[i]])})
               for i, y in enumerate(ys)]

        mts = [ht.to_matrix_table(row_key=['__locus',
                                           '__alleles'],
                                  col_key=['__y_name'],
                                  row_fields=['__w_initial',
                                              '__w_initial_floor',
                                              '__x',
                                              '__x_floor'])
               for ht in hts]

        ds = mts[0]
        for i in range(1, len(ys)):
            ds = ds.union_cols(mts[i])

        ds = ds.filter_rows(hl.is_defined(ds.__locus) &
                            hl.is_defined(ds.__alleles) &
                            hl.is_defined(ds.__w_initial) &
                            hl.is_defined(ds.__x))

    mt_tmp_file1 = new_temp_file()
    ds.write(mt_tmp_file1)
    mt = hl.read_matrix_table(mt_tmp_file1)

    if not n_reference_panel_variants:
        M = mt.count_rows()
    else:
        M = n_reference_panel_variants

    # block variants for each phenotype
    n_phenotypes = mt.count_cols()

    mt = mt.annotate_entries(__in_step1=(hl.is_defined(mt.__y) &
                                         (mt.__y < two_step_threshold)),
                             __in_step2=hl.is_defined(mt.__y))

    mt = mt.annotate_cols(__col_idx=hl.int(hl.scan.count()),
                          __m_step1=hl.agg.count_where(mt.__in_step1),
                          __m_step2=hl.agg.count_where(mt.__in_step2))

    col_keys = list(mt.col_key)

    ht = mt.localize_entries(entries_array_field_name='__entries',
                             columns_array_field_name='__cols')

    ht = ht.annotate(__entries=hl.rbind(
        hl.scan.array_agg(
            lambda entry: hl.scan.count_where(entry.__in_step1),
            ht.__entries),
        lambda step1_indices: hl.map(
            lambda i: hl.rbind(
                hl.int(hl.or_else(step1_indices[i], 0)),
                ht.__cols[i].__m_step1,
                ht.__entries[i],
                lambda step1_idx, m_step1, entry: hl.rbind(
                    hl.map(
                        lambda j: hl.int(hl.floor(j * (m_step1 / n_blocks))),
                        hl.range(0, n_blocks + 1)),
                    lambda step1_separators: hl.rbind(
                        hl.set(step1_separators).contains(step1_idx),
                        hl.sum(
                            hl.map(
                                lambda s1: step1_idx >= s1,
                                step1_separators)) - 1,
                        lambda is_separator, step1_block: entry.annotate(
                            __step1_block=step1_block,
                            __step2_block=hl.cond(~entry.__in_step1 & is_separator,
                                                  step1_block - 1,
                                                  step1_block))))),
            hl.range(0, hl.len(ht.__entries)))))

    mt = ht._unlocalize_entries('__entries', '__cols', col_keys)

    mt_tmp_file2 = new_temp_file()
    mt.write(mt_tmp_file2)
    mt = hl.read_matrix_table(mt_tmp_file2)
    
    # initial coefficient estimates
    mt = mt.annotate_cols(__initial_betas=[
        1.0, (hl.agg.mean(mt.__y) - 1.0) / hl.agg.mean(mt.__x)])
    mt = mt.annotate_cols(__step1_betas=mt.__initial_betas,
                          __step2_betas=mt.__initial_betas)

    # step 1 iteratively reweighted least squares
    for i in range(3):
        mt = mt.annotate_entries(__w=hl.cond(
            mt.__in_step1,
            1.0/(mt.__w_initial_floor * 2.0 * (mt.__step1_betas[0] +
                                               mt.__step1_betas[1] *
                                               mt.__x_floor)**2),
            0.0))
        mt = mt.annotate_cols(__step1_betas=hl.agg.filter(
            mt.__in_step1,
            hl.agg.linreg(y=mt.__y,
                          x=[1.0, mt.__x],
                          weight=mt.__w).beta))
        mt = mt.annotate_cols(__step1_h2=hl.max(hl.min(
            mt.__step1_betas[1] * M / hl.agg.mean(mt.__n), 1.0), 0.0))
        mt = mt.annotate_cols(__step1_betas=[
            mt.__step1_betas[0],
            mt.__step1_h2 * hl.agg.mean(mt.__n) / M])

    # step 1 block jackknife
    mt = mt.annotate_cols(__step1_block_betas=[
        hl.agg.filter((mt.__step1_block != i) & mt.__in_step1,
                      hl.agg.linreg(y=mt.__y,
                                    x=[1.0, mt.__x],
                                    weight=mt.__w).beta)
        for i in range(n_blocks)])

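    # jackknife pseudovalues: n_blocks * full-data estimate
    # - (n_blocks - 1) * each leave-one-block-out estimate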
    mt = mt.annotate_cols(__step1_block_betas_bias_corrected=hl.map(
        lambda x: n_blocks * mt.__step1_betas - (n_blocks - 1) * x,
        mt.__step1_block_betas))

    mt = mt.annotate_cols(
        __step1_jackknife_mean=hl.map(
            lambda i: hl.mean(
                hl.map(lambda x: x[i],
                       mt.__step1_block_betas_bias_corrected)),
            hl.range(0, __k)),
        __step1_jackknife_variance=hl.map(
            lambda i: (hl.sum(
                hl.map(lambda x: x[i]**2,
                       mt.__step1_block_betas_bias_corrected)) -
                       hl.sum(
                hl.map(lambda x: x[i],
                       mt.__step1_block_betas_bias_corrected))**2 /
                       n_blocks) /
            (n_blocks - 1) / n_blocks,
            hl.range(0, __k)))

    # step 2 iteratively reweighted least squares
    for i in range(3):
        mt = mt.annotate_entries(__w=hl.cond(
            mt.__in_step2,
            1.0/(mt.__w_initial_floor *
                 2.0 * (mt.__step2_betas[0] +
                        mt.__step2_betas[1] *
                        mt.__x_floor)**2),
            0.0))
        mt = mt.annotate_cols(__step2_betas=[
            mt.__step1_betas[0],
            hl.agg.filter(mt.__in_step2,
                          hl.agg.linreg(y=mt.__y - mt.__step1_betas[0],
                                        x=[mt.__x],
                                        weight=mt.__w).beta[0])])
        mt = mt.annotate_cols(__step2_h2=hl.max(hl.min(
            mt.__step2_betas[1] * M/hl.agg.mean(mt.__n), 1.0), 0.0))
        mt = mt.annotate_cols(__step2_betas=[
            mt.__step1_betas[0],
            mt.__step2_h2 * hl.agg.mean(mt.__n)/M])

    # step 2 block jackknife
    mt = mt.annotate_cols(__step2_block_betas=[
        hl.agg.filter((mt.__step2_block != i) & mt.__in_step2,
                      hl.agg.linreg(y=mt.__y - mt.__step1_betas[0],
                                    x=[mt.__x],
                                    weight=mt.__w).beta[0])
        for i in range(n_blocks)])

    mt = mt.annotate_cols(__step2_block_betas_bias_corrected=hl.map(
        lambda x: n_blocks * mt.__step2_betas[1] - (n_blocks - 1) * x,
        mt.__step2_block_betas))

    mt = mt.annotate_cols(
        __step2_jackknife_mean=hl.mean(
            mt.__step2_block_betas_bias_corrected),
        __step2_jackknife_variance=(
            hl.sum(mt.__step2_block_betas_bias_corrected**2) -
            hl.sum(mt.__step2_block_betas_bias_corrected)**2 /
            n_blocks) / (n_blocks - 1) / n_blocks)

    # combine step 1 and step 2 block jackknifes
    mt = mt.annotate_entries(
        __step2_initial_w=1.0/(mt.__w_initial_floor *
                               2.0 * (mt.__initial_betas[0] +
                                      mt.__initial_betas[1] *
                                      mt.__x_floor)**2))

    mt = mt.annotate_cols(
        __final_betas=[
            mt.__step1_betas[0],
            mt.__step2_betas[1]],
        __c=(hl.agg.sum(mt.__step2_initial_w * mt.__x) /
             hl.agg.sum(mt.__step2_initial_w * mt.__x**2)))
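    # __c rescales each block's step-1 intercept deviation into slope units
    # so the step-1 and step-2 jackknives can be combined per block below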

    mt = mt.annotate_cols(__final_block_betas=hl.map(
        lambda i: (mt.__step2_block_betas[i] - mt.__c *
                   (mt.__step1_block_betas[i][0] - mt.__final_betas[0])),
        hl.range(0, n_blocks)))

    mt = mt.annotate_cols(
        __final_block_betas_bias_corrected=(n_blocks * mt.__final_betas[1] -
                                            (n_blocks - 1) *
                                            mt.__final_block_betas))

    mt = mt.annotate_cols(
        __final_jackknife_mean=[
            mt.__step1_jackknife_mean[0],
            hl.mean(mt.__final_block_betas_bias_corrected)],
        __final_jackknife_variance=[
            mt.__step1_jackknife_variance[0],
            (hl.sum(mt.__final_block_betas_bias_corrected**2) -
             hl.sum(mt.__final_block_betas_bias_corrected)**2 /
             n_blocks) / (n_blocks - 1) / n_blocks])

    # convert coefficient to heritability estimate
    mt = mt.annotate_cols(
        phenotype=mt.__y_name,
        mean_chi_sq=hl.agg.mean(mt.__y),
        intercept=hl.struct(
            estimate=mt.__final_betas[0],
            standard_error=hl.sqrt(mt.__final_jackknife_variance[0])),
        snp_heritability=hl.struct(
            estimate=(M/hl.agg.mean(mt.__n)) * mt.__final_betas[1],
            standard_error=hl.sqrt((M/hl.agg.mean(mt.__n))**2 *
                                   mt.__final_jackknife_variance[1])))

    # format and return results
    ht = mt.cols()
    ht = ht.key_by(ht.phenotype)
    ht = ht.select(ht.mean_chi_sq,
                   ht.intercept,
                   ht.snp_heritability)

    ht_tmp_file = new_temp_file()
    ht.write(ht_tmp_file)
    ht = hl.read_table(ht_tmp_file)
    
    return ht
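
# The jackknife annotations above all follow the standard block-jackknife
# pseudovalue construction. A minimal pure-Python sketch of that formula
# (a hypothetical helper for illustration, not part of Hail):
def jackknife_mean_variance(estimate, block_estimates):
    # pseudovalues: n * full-data estimate - (n - 1) * leave-one-block-out
    n = len(block_estimates)
    pseudovalues = [n * estimate - (n - 1) * b for b in block_estimates]
    mean = sum(pseudovalues) / n
    variance = ((sum(p ** 2 for p in pseudovalues)
                 - sum(pseudovalues) ** 2 / n)
                / (n - 1) / n)
    return mean, variance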
Example #50
    def test_trio_matrix(self):
        """
        This test depends on certain properties of the trio matrix VCF and
        pedigree structure. This test is NOT a valid test if the pedigree
        includes quads: the trio_matrix method will duplicate the parents
        appropriately, but the genotypes_table and samples_table orthogonal
        paths would require another duplication/explode that we haven't written.
        """
        ped = hl.Pedigree.read(resource('triomatrix.fam'))
        ht = hl.import_fam(resource('triomatrix.fam'))

        mt = hl.import_vcf(resource('triomatrix.vcf'))
        mt = mt.annotate_cols(fam=ht[mt.s].fam_id)

        dads = ht.filter(hl.is_defined(ht.pat_id))
        dads = dads.select(dads.pat_id, is_dad=True).key_by('pat_id')

        moms = ht.filter(hl.is_defined(ht.mat_id))
        moms = moms.select(moms.mat_id, is_mom=True).key_by('mat_id')

        et = (mt.entries()
              .key_by('s')
              .join(dads, how='left')
              .join(moms, how='left'))
        et = et.annotate(is_dad=hl.is_defined(et.is_dad),
                         is_mom=hl.is_defined(et.is_mom))

        et = (et
            .group_by(et.locus, et.alleles, fam=et.fam)
            .aggregate(data=hl.agg.collect(hl.struct(
            role=hl.case().when(et.is_dad, 1).when(et.is_mom, 2).default(0),
            g=hl.struct(GT=et.GT, AD=et.AD, DP=et.DP, GQ=et.GQ, PL=et.PL)))))

        et = et.filter(hl.len(et.data) == 3)
        et = et.select('data').explode('data')

        tt = hl.trio_matrix(mt, ped, complete_trios=True).entries().key_by('locus', 'alleles')
        tt = tt.annotate(fam=tt.proband.fam,
                         data=[hl.struct(role=0, g=tt.proband_entry.select('GT', 'AD', 'DP', 'GQ', 'PL')),
                               hl.struct(role=1, g=tt.father_entry.select('GT', 'AD', 'DP', 'GQ', 'PL')),
                               hl.struct(role=2, g=tt.mother_entry.select('GT', 'AD', 'DP', 'GQ', 'PL'))])
        tt = tt.select('fam', 'data').explode('data')
        tt = tt.filter(hl.is_defined(tt.data.g)).key_by('locus', 'alleles', 'fam')

        self.assertEqual(et.key.dtype, tt.key.dtype)
        self.assertEqual(et.row.dtype, tt.row.dtype)
        self.assertTrue(et._same(tt))

        # test annotations
        e_cols = (mt.cols()
                  .join(dads, how='left')
                  .join(moms, how='left'))
        e_cols = e_cols.annotate(is_dad=hl.is_defined(e_cols.is_dad),
                                 is_mom=hl.is_defined(e_cols.is_mom))
        e_cols = (e_cols.group_by(fam=e_cols.fam)
                  .aggregate(data=hl.agg.collect(hl.struct(role=hl.case()
                                                           .when(e_cols.is_dad, 1).when(e_cols.is_mom, 2).default(0),
                                                           sa=hl.struct(**e_cols.row.select(*mt.col))))))
        e_cols = e_cols.filter(hl.len(e_cols.data) == 3).select('data').explode('data')

        t_cols = hl.trio_matrix(mt, ped, complete_trios=True).cols()
        t_cols = t_cols.annotate(fam=t_cols.proband.fam,
                                 data=[
                                     hl.struct(role=0, sa=t_cols.proband),
                                     hl.struct(role=1, sa=t_cols.father),
                                     hl.struct(role=2, sa=t_cols.mother)]).key_by('fam').select('data').explode('data')
        t_cols = t_cols.filter(hl.is_defined(t_cols.data.sa))

        self.assertEqual(e_cols.key.dtype, t_cols.key.dtype)
        self.assertEqual(e_cols.row.dtype, t_cols.row.dtype)
        self.assertTrue(e_cols._same(t_cols))
Example #51
    def test_join_with_empty(self):
        kt = hl.utils.range_table(10)
        kt2 = kt.head(0)
        kt.annotate(foo=hl.is_defined(kt2[kt.idx]))
Example #52
def de_novo(mt: MatrixTable,
            pedigree: Pedigree,
            pop_frequency_prior,
            *,
            min_gq: int = 20,
            min_p: float = 0.05,
            max_parent_ab: float = 0.05,
            min_child_ab: float = 0.20,
            min_dp_ratio: float = 0.10) -> Table:
    r"""Call putative *de novo* events from trio data.

    .. include:: ../_templates/req_tstring.rst

    .. include:: ../_templates/req_tvariant.rst

    .. include:: ../_templates/req_biallelic.rst

    Examples
    --------

    Call de novo events:

    >>> pedigree = hl.Pedigree.read('data/trios.fam')
    >>> priors = hl.import_table('data/gnomadFreq.tsv', impute=True)
    >>> priors = priors.transmute(**hl.parse_variant(priors.Variant)).key_by('locus', 'alleles')
    >>> de_novo_results = hl.de_novo(dataset, pedigree, pop_frequency_prior=priors[dataset.row_key].AF)

    Notes
    -----
    This method assumes the GATK high-throughput sequencing fields exist:
    `GT`, `AD`, `DP`, `GQ`, `PL`.

    This method replicates the functionality of `Kaitlin Samocha's de novo
    caller <https://github.com/ksamocha/de_novo_scripts>`__. The version
    corresponding to git commit ``bde3e40`` is implemented in Hail with her
    permission and assistance.

    This method produces a :class:`.Table` with the following fields:

     - `locus` (``locus``) -- Variant locus.
     - `alleles` (``array<str>``) -- Variant alleles.
     - `id` (``str``) -- Proband sample ID.
     - `prior` (``float64``) -- Site frequency prior. It is the maximum of:
       the computed dataset alternate allele frequency, the
       `pop_frequency_prior` parameter, and the global prior
       ``100 / 3e7``.
     - `proband` (``struct``) -- Proband column fields from `mt`.
     - `father` (``struct``) -- Father column fields from `mt`.
     - `mother` (``struct``) -- Mother column fields from `mt`.
     - `proband_entry` (``struct``) -- Proband entry fields from `mt`.
     - `father_entry` (``struct``) -- Father entry fields from `mt`.
     - `mother_entry` (``struct``) -- Mother entry fields from `mt`.
     - `is_female` (``bool``) -- ``True`` if proband is female.
     - `p_de_novo` (``float64``) -- Unfiltered posterior probability
       that the event is *de novo* rather than a missed heterozygous
       event in a parent.
     - `confidence` (``str``) -- Validation confidence. One of: ``'HIGH'``,
       ``'MEDIUM'``, ``'LOW'``.

    The key of the table is ``['locus', 'alleles', 'id']``.

    The model looks for de novo events in which both parents are homozygous
    reference and the proband is heterozygous. The model makes the simplifying
    assumption that when this configuration ``x = (AA, AA, AB)`` of calls
    occurs, exactly one of the following is true:

     - ``d``: a de novo mutation occurred in the proband and all calls are
       accurate.
     - ``m``: at least one parental allele is actually heterozygous and
       the proband call is accurate.

    We can then estimate the posterior probability of a de novo mutation as:

    .. math::

        \mathrm{P_{\text{de novo}}} = \frac{\mathrm{P}(d\,|\,x)}{\mathrm{P}(d\,|\,x) + \mathrm{P}(m\,|\,x)}

    Applying Bayes' rule to the numerator and denominator yields

    .. math::

        \frac{\mathrm{P}(x\,|\,d)\,\mathrm{P}(d)}{\mathrm{P}(x\,|\,d)\,\mathrm{P}(d) +
        \mathrm{P}(x\,|\,m)\,\mathrm{P}(m)}

    The prior on de novo mutation is estimated from the rate in the literature:

    .. math::

        \mathrm{P}(d) = \frac{1\,\text{mutation}}{30,000,000\,\text{bases}}

    The prior used for at least one alternate allele between the parents
    depends on the alternate allele frequency:

    .. math::

        \mathrm{P}(m) = 1 - (1 - AF)^4

    The likelihoods :math:`\mathrm{P}(x\,|\,d)` and :math:`\mathrm{P}(x\,|\,m)`
    are computed from the PL (genotype likelihood) fields using these
    factorizations:

    .. math::

        \mathrm{P}(x = (AA, AA, AB) \,|\,d) = \Big(
        &\mathrm{P}(x_{\mathrm{father}} = AA \,|\, \mathrm{father} = AA) \\
        \cdot &\mathrm{P}(x_{\mathrm{mother}} = AA \,|\, \mathrm{mother} =
        AA) \\ \cdot &\mathrm{P}(x_{\mathrm{proband}} = AB \,|\,
        \mathrm{proband} = AB) \Big)

    .. math::

        \mathrm{P}(x = (AA, AA, AB) \,|\,m) = \Big( &
        \mathrm{P}(x_{\mathrm{father}} = AA \,|\, \mathrm{father} = AB)
        \cdot \mathrm{P}(x_{\mathrm{mother}} = AA \,|\, \mathrm{mother} =
        AA) \\ + \, &\mathrm{P}(x_{\mathrm{father}} = AA \,|\,
        \mathrm{father} = AA) \cdot \mathrm{P}(x_{\mathrm{mother}} = AA
        \,|\, \mathrm{mother} = AB) \Big) \\ \cdot \,
        &\mathrm{P}(x_{\mathrm{proband}} = AB \,|\, \mathrm{proband} = AB)

    (Technically, the second factorization assumes there is exactly (rather
    than at least) one alternate allele among the parents, which may be
    justified on the grounds that it is typically the most likely case by far.)

    While this posterior probability is a good metric for grouping putative de
    novo mutations by validation likelihood, there exist error modes in
    high-throughput sequencing data that are not appropriately accounted for by
    the phred-scaled genotype likelihoods. To this end, a number of hard filters
    are applied in order to assign validation likelihood.

    These filters are different for SNPs and insertions/deletions. In the
    rules below, the following variables are used:

     - ``DR`` refers to the ratio of the read depth in the proband to the
       combined read depth in the parents.
     - ``AB`` refers to the read allele balance of the proband (number of
       alternate reads divided by total reads).
     - ``AC`` refers to the count of alternate alleles across all individuals
       in the dataset at the site.
     - ``p`` refers to :math:`\mathrm{P_{\text{de novo}}}`.
     - ``min_p`` refers to the ``min_p`` function parameter.

    HIGH-quality SNV:

    .. code-block:: text

        p > 0.99 && AB > 0.3 && DR > 0.2
            or
        p > 0.99 && AB > 0.3 && AC == 1

    MEDIUM-quality SNV:

    .. code-block:: text

        p > 0.5 && AB > 0.3
            or
        p > 0.5 && AB > 0.2 && AC == 1

    LOW-quality SNV:

    .. code-block:: text

        p > min_p && AB > 0.2

    HIGH-quality indel:

    .. code-block:: text

        p > 0.99 && AB > 0.3 && AC == 1

    MEDIUM-quality indel:

    .. code-block:: text

        p > 0.5 && AB > 0.3 && AC <= 5

    LOW-quality indel:

    .. code-block:: text

        p > min_p && AB > 0.2

    Additionally, de novo candidates are not considered if the proband GQ is
    smaller than the ``min_gq`` parameter, if the proband allele balance is
    lower than the ``min_child_ab`` parameter, if the depth ratio between the
    proband and parents is smaller than the ``min_dp_ratio`` parameter, or if
    the allele balance in a parent is above the ``max_parent_ab`` parameter.

    Parameters
    ----------
    mt : :class:`.MatrixTable`
        High-throughput sequencing dataset.
    pedigree : :class:`.Pedigree`
        Sample pedigree.
    pop_frequency_prior : :class:`.Float64Expression`
        Expression for population alternate allele frequency prior.
    min_gq
        Minimum proband GQ to be considered for *de novo* calling.
    min_p
        Minimum posterior probability to be considered for *de novo* calling.
    max_parent_ab
        Maximum parent allele balance.
    min_child_ab
        Minimum proband allele balance.
    min_dp_ratio
        Minimum ratio between proband read depth and parental read depth.

    Returns
    -------
    :class:`.Table`
    """
    DE_NOVO_PRIOR = 1 / 30000000
    MIN_POP_PRIOR = 100 / 30000000

    required_entry_fields = {'GT', 'AD', 'DP', 'GQ', 'PL'}
    missing_fields = required_entry_fields - set(mt.entry)
    if missing_fields:
        raise ValueError(f"'de_novo': expected 'MatrixTable' to have at least {required_entry_fields}, "
                         f"missing {missing_fields}")

    mt = mt.annotate_rows(__prior=pop_frequency_prior,
                          __alt_alleles=hl.agg.sum(mt.GT.n_alt_alleles()),
                          __total_alleles=2 * hl.agg.sum(hl.is_defined(mt.GT)))
    # subtract 1 from __alt_alleles to correct for the observed genotype
    mt = mt.annotate_rows(__site_freq=hl.max((mt.__alt_alleles - 1) / mt.__total_alleles, mt.__prior, MIN_POP_PRIOR))
    mt = require_biallelic(mt, 'de_novo')

    # FIXME check that __site_freq is between 0 and 1 when possible in expr
    tm = trio_matrix(mt, pedigree, complete_trios=True)

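    # copy-state masks: diploid calls (autosome/PAR, plus X non-PAR in
    # females), hemizygous X and Y in males, and mitochondrial in females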
    autosomal = tm.locus.in_autosome_or_par() | (tm.locus.in_x_nonpar() & tm.is_female)
    hemi_x = tm.locus.in_x_nonpar() & ~tm.is_female
    hemi_y = tm.locus.in_y_nonpar() & ~tm.is_female
    hemi_mt = tm.locus.in_mito() & tm.is_female

    is_snp = hl.is_snp(tm.alleles[0], tm.alleles[1])
    n_alt_alleles = tm.__alt_alleles
    prior = tm.__site_freq
    het_hom_hom = tm.proband_entry.GT.is_het() & tm.father_entry.GT.is_hom_ref() & tm.mother_entry.GT.is_hom_ref()
    kid_ad_fail = tm.proband_entry.AD[1] / hl.sum(tm.proband_entry.AD) < min_child_ab

    failure = hl.null(hl.tstruct(p_de_novo=hl.tfloat64, confidence=hl.tstr))

    kid = tm.proband_entry
    dad = tm.father_entry
    mom = tm.mother_entry

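    # convert phred-scaled PLs to normalized linear-scale genotype
    # probabilities: pp[i] is the relative likelihood of genotype i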
    kid_linear_pl = 10 ** (-kid.PL / 10)
    kid_pp = hl.bind(lambda x: x / hl.sum(x), kid_linear_pl)

    dad_linear_pl = 10 ** (-dad.PL / 10)
    dad_pp = hl.bind(lambda x: x / hl.sum(x), dad_linear_pl)

    mom_linear_pl = 10 ** (-mom.PL / 10)
    mom_pp = hl.bind(lambda x: x / hl.sum(x), mom_linear_pl)

    kid_ad_ratio = kid.AD[1] / hl.sum(kid.AD)
    dp_ratio = kid.DP / (dad.DP + mom.DP)

    def call_auto(kid_pp, dad_pp, mom_pp, kid_ad_ratio):
        p_data_given_dn = dad_pp[0] * mom_pp[0] * kid_pp[1] * DE_NOVO_PRIOR
        p_het_in_parent = 1 - (1 - prior) ** 4
        p_data_given_missed_het = (dad_pp[1] * mom_pp[0] + dad_pp[0] * mom_pp[1]) * kid_pp[1] * p_het_in_parent
        p_de_novo = p_data_given_dn / (p_data_given_dn + p_data_given_missed_het)

        def solve(p_de_novo):
            return (
                hl.case()
                    .when(kid.GQ < min_gq, failure)
                    .when((kid.DP / (dad.DP + mom.DP) < min_dp_ratio) |
                          ~(kid_ad_ratio >= min_child_ab), failure)
                    .when((hl.sum(mom.AD) == 0) | (hl.sum(dad.AD) == 0), failure)
                    .when((mom.AD[1] / hl.sum(mom.AD) > max_parent_ab) |
                          (dad.AD[1] / hl.sum(dad.AD) > max_parent_ab), failure)
                    .when(p_de_novo < min_p, failure)
                    .when(~is_snp, hl.case()
                          .when((p_de_novo > 0.99) & (kid_ad_ratio > 0.3) & (n_alt_alleles == 1),
                                hl.struct(p_de_novo=p_de_novo, confidence='HIGH'))
                          .when((p_de_novo > 0.5) & (kid_ad_ratio > 0.3) & (n_alt_alleles <= 5),
                                hl.struct(p_de_novo=p_de_novo, confidence='MEDIUM'))
                          .when((p_de_novo > 0.05) & (kid_ad_ratio > 0.2),
                                hl.struct(p_de_novo=p_de_novo, confidence='LOW'))
                          .or_missing())
                    .default(hl.case()
                             .when(((p_de_novo > 0.99) & (kid_ad_ratio > 0.3) & (dp_ratio > 0.2)) |
                                   ((p_de_novo > 0.99) & (kid_ad_ratio > 0.3) & (n_alt_alleles == 1)) |
                                   ((p_de_novo > 0.5) & (kid_ad_ratio > 0.3) & (n_alt_alleles < 10) & (kid.DP > 10)),
                                   hl.struct(p_de_novo=p_de_novo, confidence='HIGH'))
                             .when((p_de_novo > 0.5) & ((kid_ad_ratio > 0.3) | (n_alt_alleles == 1)),
                                   hl.struct(p_de_novo=p_de_novo, confidence='MEDIUM'))
                             .when((p_de_novo > 0.05) & (kid_ad_ratio > 0.2),
                                   hl.struct(p_de_novo=p_de_novo, confidence='LOW'))
                             .or_missing()
                             )
            )

        return hl.bind(solve, p_de_novo)

    def call_hemi(kid_pp, parent, parent_pp, kid_ad_ratio):
        p_data_given_dn = parent_pp[0] * kid_pp[1] * DE_NOVO_PRIOR
        p_het_in_parent = 1 - (1 - prior) ** 4
        p_data_given_missed_het = (parent_pp[1] + parent_pp[2]) * kid_pp[2] * p_het_in_parent
        p_de_novo = p_data_given_dn / (p_data_given_dn + p_data_given_missed_het)

        def solve(p_de_novo):
            return (
                hl.case()
                    .when(kid.GQ < min_gq, failure)
                    .when((kid.DP / parent.DP < min_dp_ratio) |
                          (kid_ad_ratio < min_child_ab), failure)
                    .when((hl.sum(parent.AD) == 0), failure)
                    .when(parent.AD[1] / hl.sum(parent.AD) > max_parent_ab, failure)
                    .when(p_de_novo < min_p, failure)
                    .when(~is_snp, hl.case()
                          .when((p_de_novo > 0.99) & (kid_ad_ratio > 0.3) & (n_alt_alleles == 1),
                                hl.struct(p_de_novo=p_de_novo, confidence='HIGH'))
                          .when((p_de_novo > 0.5) & (kid_ad_ratio > 0.3) & (n_alt_alleles <= 5),
                                hl.struct(p_de_novo=p_de_novo, confidence='MEDIUM'))
                          .when((p_de_novo > 0.05) & (kid_ad_ratio > 0.3),
                                hl.struct(p_de_novo=p_de_novo, confidence='LOW'))
                          .or_missing())
                    .default(hl.case()
                             .when(((p_de_novo > 0.99) & (kid_ad_ratio > 0.3) & (dp_ratio > 0.2)) |
                                   ((p_de_novo > 0.99) & (kid_ad_ratio > 0.3) & (n_alt_alleles == 1)) |
                                   ((p_de_novo > 0.5) & (kid_ad_ratio > 0.3) & (n_alt_alleles < 10) & (kid.DP > 10)),
                                   hl.struct(p_de_novo=p_de_novo, confidence='HIGH'))
                             .when((p_de_novo > 0.5) & ((kid_ad_ratio > 0.3) | (n_alt_alleles == 1)),
                                   hl.struct(p_de_novo=p_de_novo, confidence='MEDIUM'))
                             .when((p_de_novo > 0.05) & (kid_ad_ratio > 0.2),
                                   hl.struct(p_de_novo=p_de_novo, confidence='LOW'))
                             .or_missing()
                             )
            )

        return hl.bind(solve, p_de_novo)

    de_novo_call = (
        hl.case()
            .when(~het_hom_hom | kid_ad_fail, failure)
            .when(autosomal, hl.bind(call_auto, kid_pp, dad_pp, mom_pp, kid_ad_ratio))
            .when(hemi_x | hemi_mt, hl.bind(call_hemi, kid_pp, mom, mom_pp, kid_ad_ratio))
            .when(hemi_y, hl.bind(call_hemi, kid_pp, dad, dad_pp, kid_ad_ratio))
            .or_missing()
    )

    tm = tm.annotate_entries(__call=de_novo_call)
    tm = tm.filter_entries(hl.is_defined(tm.__call))
    entries = tm.entries()
    return (entries.select('__site_freq',
                           'proband',
                           'father',
                           'mother',
                           'proband_entry',
                           'father_entry',
                           'mother_entry',
                           'is_female',
                           **entries.__call)
            .rename({'__site_freq': 'prior'}))
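
# The autosomal posterior described in the Notes reduces to a few lines of
# arithmetic on normalized genotype likelihoods. A worked sketch with assumed
# PL values and an assumed site frequency prior (hypothetical numbers, for
# illustration only):
def example_p_de_novo():
    de_novo_prior = 1 / 30000000

    def normalize_pl(pl):
        # phred-scaled likelihoods -> normalized linear-scale probabilities
        linear = [10 ** (-x / 10) for x in pl]
        return [x / sum(linear) for x in linear]

    kid_pp = normalize_pl([99, 0, 99])  # confidently het proband (assumed)
    dad_pp = normalize_pl([0, 60, 99])  # confidently hom-ref father (assumed)
    mom_pp = normalize_pl([0, 60, 99])  # confidently hom-ref mother (assumed)
    prior = 1e-4                        # assumed site frequency prior

    p_dn = dad_pp[0] * mom_pp[0] * kid_pp[1] * de_novo_prior
    p_missed_het = ((dad_pp[1] * mom_pp[0] + dad_pp[0] * mom_pp[1])
                    * kid_pp[1] * (1 - (1 - prior) ** 4))
    return p_dn / (p_dn + p_missed_het)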
Example #53
def mendel_errors(call, pedigree) -> Tuple[Table, Table, Table, Table]:
    r"""Find Mendel errors; count per variant, individual and nuclear family.

    .. include:: ../_templates/req_tstring.rst

    .. include:: ../_templates/req_tvariant.rst

    .. include:: ../_templates/req_biallelic.rst

    Examples
    --------

    Find all violations of Mendelian inheritance in each (dad, mom, kid) trio in
    a pedigree and return four tables (all errors, errors by family, errors by
    individual, errors by variant):

    >>> ped = hl.Pedigree.read('data/trios.fam')
    >>> all_errors, per_fam, per_sample, per_variant = hl.mendel_errors(dataset['GT'], ped)

    Export all mendel errors to a text file:

    >>> all_errors.export('output/all_mendel_errors.tsv')

    Annotate columns with the number of Mendel errors:

    >>> annotated_samples = dataset.annotate_cols(mendel=per_sample[dataset.s])

    Annotate rows with the number of Mendel errors:

    >>> annotated_variants = dataset.annotate_rows(mendel=per_variant[dataset.locus, dataset.alleles])

    Notes
    -----

    The example above returns four tables, which contain Mendelian violations
    grouped in various ways. These tables are modeled after the `PLINK mendel
    formats <https://www.cog-genomics.org/plink2/formats#mendel>`_, resembling
    the ``.mendel``, ``.fmendel``, ``.imendel``, and ``.lmendel`` formats,
    respectively.

    **First table:** all Mendel errors. This table contains one row per Mendel
    error, keyed by the variant and proband id.

        - `locus` (:class:`.tlocus`) -- Variant locus, key field.
        - `alleles` (:class:`.tarray` of :py:data:`.tstr`) -- Variant alleles, key field.
        - (column key of `dataset`) (:py:data:`.tstr`) -- Proband ID, key field.
        - `fam_id` (:py:data:`.tstr`) -- Family ID.
        - `mendel_code` (:py:data:`.tint32`) -- Mendel error code, see below.

    **Second table:** errors per nuclear family. This table contains one row
    per nuclear family, keyed by the parents.

        - `pat_id` (:py:data:`.tstr`) -- Paternal ID. (key field)
        - `mat_id` (:py:data:`.tstr`) -- Maternal ID. (key field)
        - `fam_id` (:py:data:`.tstr`) -- Family ID.
        - `children` (:py:data:`.tint32`) -- Number of children in this nuclear family.
        - `errors` (:py:data:`.tint64`) -- Number of Mendel errors in this nuclear family.
        - `snp_errors` (:py:data:`.tint64`) -- Number of Mendel errors at SNPs in this
          nuclear family.

    **Third table:** errors per individual. This table contains one row per
    individual. Each error is counted toward the proband, father, and mother
    according to the `Implicated` column in the table below.

        - (column key of `dataset`) (:py:data:`.tstr`) -- Sample ID (key field).
        - `fam_id` (:py:data:`.tstr`) -- Family ID.
        - `errors` (:py:data:`.tint64`) -- Number of Mendel errors involving this
          individual.
        - `snp_errors` (:py:data:`.tint64`) -- Number of Mendel errors involving this
          individual at SNPs.

    **Fourth table:** errors per variant.

        - `locus` (:class:`.tlocus`) -- Variant locus, key field.
        - `alleles` (:class:`.tarray` of :py:data:`.tstr`) -- Variant alleles, key field.
        - `errors` (:py:data:`.tint64`) -- Number of Mendel errors in this variant.

    This method only considers complete trios (two parents and proband with
    defined sex). The code of each Mendel error is determined by the table
    below, extending the
    `Plink classification <https://www.cog-genomics.org/plink2/basic_stats#mendel>`__.

    In the table, the copy state of a locus with respect to a trio is defined
    as follows, where PAR is the `pseudoautosomal region
    <https://en.wikipedia.org/wiki/Pseudoautosomal_region>`__ of X and Y
    defined by the reference genome, and the autosome is defined by
    :meth:`~hail.genetics.Locus.in_autosome`.

    - Auto -- in autosome or in PAR or female child
    - HemiX -- in non-PAR of X and male child
    - HemiY -- in non-PAR of Y and male child

    `Any` refers to the set \{ HomRef, Het, HomVar, NoCall \} and `~`
    denotes complement in this set.

    +------+---------+---------+--------+------------+---------------+
    | Code | Dad     | Mom     | Kid    | Copy State | Implicated    |
    +======+=========+=========+========+============+===============+
    |    1 | HomVar  | HomVar  | Het    | Auto       | Dad, Mom, Kid |
    +------+---------+---------+--------+------------+---------------+
    |    2 | HomRef  | HomRef  | Het    | Auto       | Dad, Mom, Kid |
    +------+---------+---------+--------+------------+---------------+
    |    3 | HomRef  | ~HomRef | HomVar | Auto       | Dad, Kid      |
    +------+---------+---------+--------+------------+---------------+
    |    4 | ~HomRef | HomRef  | HomVar | Auto       | Mom, Kid      |
    +------+---------+---------+--------+------------+---------------+
    |    5 | HomRef  | HomRef  | HomVar | Auto       | Kid           |
    +------+---------+---------+--------+------------+---------------+
    |    6 | HomVar  | ~HomVar | HomRef | Auto       | Dad, Kid      |
    +------+---------+---------+--------+------------+---------------+
    |    7 | ~HomVar | HomVar  | HomRef | Auto       | Mom, Kid      |
    +------+---------+---------+--------+------------+---------------+
    |    8 | HomVar  | HomVar  | HomRef | Auto       | Kid           |
    +------+---------+---------+--------+------------+---------------+
    |    9 | Any     | HomVar  | HomRef | HemiX      | Mom, Kid      |
    +------+---------+---------+--------+------------+---------------+
    |   10 | Any     | HomRef  | HomVar | HemiX      | Mom, Kid      |
    +------+---------+---------+--------+------------+---------------+
    |   11 | HomVar  | Any     | HomRef | HemiY      | Dad, Kid      |
    +------+---------+---------+--------+------------+---------------+
    |   12 | HomRef  | Any     | HomVar | HemiY      | Dad, Kid      |
    +------+---------+---------+--------+------------+---------------+

    See Also
    --------
    :func:`.mendel_error_code`

    Parameters
    ----------
    call : :class:`.CallExpression`
    pedigree : :class:`.Pedigree`

    Returns
    -------
    (:class:`.Table`, :class:`.Table`, :class:`.Table`, :class:`.Table`)
    """
    source = call._indices.source
    if not isinstance(source, MatrixTable):
        raise ValueError("'mendel_errors': expected 'call' to be an expression of 'MatrixTable', found {}".format(
            "expression of '{}'".format(source.__class__) if source is not None else 'scalar expression'))

    source = source.select_entries(__GT=call)
    dataset = require_biallelic(source, 'mendel_errors')
    tm = trio_matrix(dataset, pedigree, complete_trios=True)
    tm = tm.select_entries(mendel_code=hl.mendel_error_code(
        tm.locus,
        tm.is_female,
        tm.father_entry['__GT'],
        tm.mother_entry['__GT'],
        tm.proband_entry['__GT']
    ))
    ck_name = next(iter(source.col_key))
    tm = tm.filter_entries(hl.is_defined(tm.mendel_code))
    tm = tm.rename({'id': ck_name})

    entries = tm.entries()

    table1 = entries.select('fam_id', 'mendel_code')

    fam_counts = (
        entries
            .group_by(pat_id=entries.father[ck_name], mat_id=entries.mother[ck_name])
            .partition_hint(min(entries.n_partitions(), 8))
            .aggregate(children=hl.len(hl.agg.collect_as_set(entries[ck_name])),
                       errors=hl.agg.count_where(hl.is_defined(entries.mendel_code)),
                       snp_errors=hl.agg.count_where(hl.is_snp(entries.alleles[0], entries.alleles[1]) &
                                                     hl.is_defined(entries.mendel_code)))
    )
    table2 = tm.key_cols_by().cols()
    table2 = table2.select(pat_id=table2.father[ck_name],
                           mat_id=table2.mother[ck_name],
                           fam_id=table2.fam_id,
                           **fam_counts[table2.father[ck_name], table2.mother[ck_name]])
    table2 = table2.key_by('pat_id', 'mat_id').distinct()
    table2 = table2.annotate(errors=hl.or_else(table2.errors, hl.int64(0)),
                             snp_errors=hl.or_else(table2.snp_errors, hl.int64(0)))

    # in implicated, idx 0 is dad, idx 1 is mom, idx 2 is child
    implicated = hl.literal([
        [0, 0, 0],  # dummy
        [1, 1, 1],
        [1, 1, 1],
        [1, 0, 1],
        [0, 1, 1],
        [0, 0, 1],
        [1, 0, 1],
        [0, 1, 1],
        [0, 0, 1],
        [0, 1, 1],
        [0, 1, 1],
        [1, 0, 1],
        [1, 0, 1],
    ], dtype=hl.tarray(hl.tarray(hl.tint64)))

    table3 = tm.annotate_cols(all_errors=hl.or_else(hl.agg.array_sum(implicated[tm.mendel_code]), [0, 0, 0]),
                              snp_errors=hl.or_else(
                                  hl.agg.filter(hl.is_snp(tm.alleles[0], tm.alleles[1]),
                                                hl.agg.array_sum(implicated[tm.mendel_code])),
                                  [0, 0, 0])).key_cols_by().cols()

    table3 = table3.select(xs=[
        hl.struct(**{ck_name: table3.father[ck_name],
                     'fam_id': table3.fam_id,
                     'errors': table3.all_errors[0],
                     'snp_errors': table3.snp_errors[0]}),
        hl.struct(**{ck_name: table3.mother[ck_name],
                     'fam_id': table3.fam_id,
                     'errors': table3.all_errors[1],
                     'snp_errors': table3.snp_errors[1]}),
        hl.struct(**{ck_name: table3.proband[ck_name],
                     'fam_id': table3.fam_id,
                     'errors': table3.all_errors[2],
                     'snp_errors': table3.snp_errors[2]}),
    ])
    table3 = table3.explode('xs')
    table3 = table3.select(**table3.xs)
    table3 = (table3.group_by(ck_name, 'fam_id')
              .aggregate(errors=hl.agg.sum(table3.errors),
                         snp_errors=hl.agg.sum(table3.snp_errors))
              .key_by(ck_name))

    table4 = tm.select_rows(errors=hl.agg.count_where(hl.is_defined(tm.mendel_code))).rows()

    return table1, table2, table3, table4
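
# The error-code table above can be read as a pure function of the three
# genotype calls and the copy state. A toy restatement of just the autosomal
# rows (assumed encoding 0 = HomRef, 1 = Het, 2 = HomVar, None = NoCall;
# a hypothetical helper, not the implementation of hl.mendel_error_code):
def autosomal_mendel_code(dad, mom, kid):
    if kid == 1:
        if dad == 2 and mom == 2:
            return 1
        if dad == 0 and mom == 0:
            return 2
    if kid == 2:
        if dad == 0 and mom == 0:
            return 5
        if dad == 0:
            return 3
        if mom == 0:
            return 4
    if kid == 0:
        if dad == 2 and mom == 2:
            return 8
        if dad == 2:
            return 6
        if mom == 2:
            return 7
    return None  # configuration implies no Mendel error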
Example #54
def locus_windows(locus_expr, radius, coord_expr=None, _localize=True):
    """Returns start and stop indices for window around each locus.

    Examples
    --------

    Windows with 2bp radius for one contig with positions 1, 2, 3, 4, 5:

    >>> starts, stops = hl.linalg.utils.locus_windows(
    ...     hl.balding_nichols_model(1, 5, 5).locus,
    ...     radius=2)
    >>> starts, stops
    (array([0, 0, 0, 1, 2]), array([3, 4, 5, 5, 5]))

    The following examples involve three contigs.

    >>> loci = [{'locus': hl.Locus('1', 1), 'cm': 1.0},
    ...         {'locus': hl.Locus('1', 2), 'cm': 3.0},
    ...         {'locus': hl.Locus('1', 4), 'cm': 4.0},
    ...         {'locus': hl.Locus('2', 1), 'cm': 2.0},
    ...         {'locus': hl.Locus('2', 1), 'cm': 2.0},
    ...         {'locus': hl.Locus('3', 3), 'cm': 5.0}]

    >>> ht = hl.Table.parallelize(
    ...         loci,
    ...         hl.tstruct(locus=hl.tlocus('GRCh37'), cm=hl.tfloat64),
    ...         key=['locus'])

    Windows with 1bp radius:

    >>> hl.linalg.utils.locus_windows(ht.locus, 1)
    (array([0, 0, 2, 3, 3, 5]), array([2, 2, 3, 5, 5, 6]))

    Windows with 1cm radius:

    >>> hl.linalg.utils.locus_windows(ht.locus, 1.0, coord_expr=ht.cm)
    (array([0, 1, 1, 3, 3, 5]), array([1, 3, 3, 5, 5, 6]))

    Notes
    -----
    This function returns two 1-dimensional ndarrays of integers,
    ``starts`` and ``stops``, each of size equal to the number of rows.

    By default, for all indices ``i``, ``[starts[i], stops[i])`` is the maximal
    range of row indices ``j`` such that ``contig[i] == contig[j]`` and
    ``position[i] - radius <= position[j] <= position[i] + radius``.

    If the :meth:`.global_position` on `locus_expr` is not in ascending order,
    this method will fail. Ascending order should hold for a matrix table keyed
    by locus or variant (and the associated row table), or for a table that has
    been ordered by `locus_expr`.

    Set `coord_expr` to define the windows by a value other than position.
    This row-indexed numeric expression must be non-missing and non-``nan``,
    must be defined on the same source as `locus_expr`, and must be ascending
    with respect to locus position within each contig; otherwise the function
    will fail.

    The last example above uses centimorgan coordinates, so
    ``[starts[i], stops[i])`` is the maximal range of row indices ``j`` such
    that ``contig[i] == contig[j]`` and
    ``cm[i] - radius <= cm[j] <= cm[i] + radius``.

    Index ranges are start-inclusive and stop-exclusive. This function is
    especially useful in conjunction with
    :meth:`.BlockMatrix.sparsify_row_intervals`.

    Parameters
    ----------
    locus_expr : :class:`.LocusExpression`
        Row-indexed locus expression on a table or matrix table.
    radius : :obj:`int`
        Radius of window for row values.
    coord_expr : :class:`.Float64Expression`, optional
        Row-indexed numeric expression for the row value.
        Must be on the same table or matrix table as `locus_expr`.
        By default, the row value is given by the locus position.

    Returns
    -------
    (:class:`ndarray` of :obj:`int64`, :class:`ndarray` of :obj:`int64`)
        Tuple of start indices array and stop indices array.
    """
    if radius < 0:
        raise ValueError(f"locus_windows: 'radius' must be non-negative, found {radius}")
    check_row_indexed('locus_windows', locus_expr)
    if coord_expr is not None:
        check_row_indexed('locus_windows', coord_expr)

    src = locus_expr._indices.source
    if locus_expr not in src._fields_inverse:
        locus = Env.get_uid()
        annotate_fields = {locus: locus_expr}

        if coord_expr is not None:
            if coord_expr not in src._fields_inverse:
                coords = Env.get_uid()
                annotate_fields[coords] = coord_expr
            else:
                coords = src._fields_inverse[coord_expr]

        if isinstance(src, hl.MatrixTable):
            new_src = src.annotate_rows(**annotate_fields)
        else:
            new_src = src.annotate(**annotate_fields)

        locus_expr = new_src[locus]
        if coord_expr is not None:
            coord_expr = new_src[coords]

    if coord_expr is None:
        coord_expr = locus_expr.position

    rg = locus_expr.dtype.reference_genome
    contig_group_expr = hl.agg.group_by(hl.locus(locus_expr.contig, 1, reference_genome=rg), hl.agg.collect(coord_expr))

    # check loci are in sorted order
    last_pos = hl.fold(lambda a, elt: (hl.case()
                                         .when(a <= elt, elt)
                                         .or_error("locus_windows: 'locus_expr' global position must be in ascending order.")),
                       -1,
                       hl.agg.collect(hl.case()
                                        .when(hl.is_defined(locus_expr), locus_expr.global_position())
                                        .or_error("locus_windows: missing value for 'locus_expr'.")))
    checked_contig_groups = (hl.case()
                               .when(last_pos >= 0, contig_group_expr)
                               .or_error("locus_windows: 'locus_expr' has length 0"))

    contig_groups = locus_expr._aggregation_method()(checked_contig_groups, _localize=False)

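    # sort contig groups by locus key so the per-contig coordinate lists
    # follow the row order of the source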
    coords = hl.sorted(hl.array(contig_groups)).map(lambda t: t[1])
    starts_and_stops = hl._locus_windows_per_contig(coords, radius)

    if not _localize:
        return starts_and_stops

    starts, stops = hl.eval(starts_and_stops)
    return np.array(starts), np.array(stops)
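
# Per contig, the window computation described in the Notes reduces to two
# binary searches per row. A minimal pure-Python equivalent for a single
# contig (a hypothetical sketch; global indices would add the contig's row
# offset to each value):
import bisect

def windows_for_contig(coords, radius):
    # coords must be ascending within the contig, as required above
    starts = [bisect.bisect_left(coords, c - radius) for c in coords]
    stops = [bisect.bisect_right(coords, c + radius) for c in coords]
    return starts, stops

# matches the first example above: positions 1..5 with radius 2
assert windows_for_contig([1, 2, 3, 4, 5], 2) == ([0, 0, 0, 1, 2], [3, 4, 5, 5, 5])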