Example 1
def rename_duplicates(dataset, name='unique_id') -> MatrixTable:
    """Rename duplicate column keys.

    .. include:: ../_templates/req_tstring.rst

    Examples
    --------

    >>> renamed = hl.rename_duplicates(dataset).cols()
    >>> duplicate_samples = (renamed.filter(renamed.s != renamed.unique_id)
    ...                             .select()
    ...                             .collect())

    Notes
    -----

    This method produces a new column field from the string column key by
    appending a unique suffix ``_N`` as necessary. For example, if the column
    key "NA12878" appears three times in the dataset, the first will produce
    "NA12878", the second will produce "NA12878_1", and the third will produce
    "NA12878_2". The name of this new field is parameterized by `name`.

    Parameters
    ----------
    dataset : :class:`.MatrixTable`
        Dataset.
    name : :obj:`str`
        Name of new field.

    Returns
    -------
    :class:`.MatrixTable`
    """

    require_col_key_str(dataset, 'rename_duplicates')
    ids = dataset.col_key[0].collect()
    uniques = set()
    mapping = []
    new_ids = []

    fmt = lambda s, i: '{}_{}'.format(s, i)
    for s in ids:
        s_ = s
        i = 0
        while s_ in uniques:
            i += 1
            s_ = fmt(s, i)

        if s_ != s:
            mapping.append((s, s_))
        uniques.add(s_)
        new_ids.append(s_)

    if mapping:
        info(f'Renamed {len(mapping)} duplicate {plural("sample ID", len(mapping))}. Mangled IDs as follows:' +
             ''.join(f'\n  "{pre}" => "{post}"' for pre, post in mapping))
    else:
        info('No duplicate sample IDs found.')
    return dataset.annotate_cols(**{name: hl.literal(new_ids)[hl.int(hl.scan.count())]})
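
A minimal usage sketch mirroring the suffixing scheme described in the Notes above (sample IDs are illustrative):

>>> mt = hl.utils.range_matrix_table(1, 3)
>>> mt = mt.key_cols_by(s=hl.literal(['NA12878', 'NA12878', 'NA12878'])[mt.col_idx])
>>> hl.rename_duplicates(mt).unique_id.collect()  # doctest: +SKIP
['NA12878', 'NA12878_1', 'NA12878_2']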
Example 2
    def test_collect_cols_by_key(self):
        mt = hl.utils.range_matrix_table(3, 3)
        col_dict = hl.literal({0: [1], 1: [2, 3], 2: [4, 5, 6]})
        mt = mt.annotate_cols(foo=col_dict.get(mt.col_idx)) \
            .explode_cols('foo')
        mt = mt.annotate_entries(bar=mt.row_idx * mt.foo)

        grouped = mt.collect_cols_by_key()

        self.assertListEqual(grouped.cols().order_by('col_idx').collect(),
                             [hl.Struct(col_idx=0, foo=[1]),
                              hl.Struct(col_idx=1, foo=[2, 3]),
                              hl.Struct(col_idx=2, foo=[4, 5, 6])])
        self.assertListEqual(
            grouped.entries().select('bar')
                .order_by('row_idx', 'col_idx').collect(),
            [hl.Struct(row_idx=0, col_idx=0, bar=[0]),
             hl.Struct(row_idx=0, col_idx=1, bar=[0, 0]),
             hl.Struct(row_idx=0, col_idx=2, bar=[0, 0, 0]),
             hl.Struct(row_idx=1, col_idx=0, bar=[1]),
             hl.Struct(row_idx=1, col_idx=1, bar=[2, 3]),
             hl.Struct(row_idx=1, col_idx=2, bar=[4, 5, 6]),
             hl.Struct(row_idx=2, col_idx=0, bar=[2]),
             hl.Struct(row_idx=2, col_idx=1, bar=[4, 6]),
             hl.Struct(row_idx=2, col_idx=2, bar=[8, 10, 12])])
Example 3
    def test_annotate_globals(self):
        mt = hl.utils.range_matrix_table(1, 1)
        ht = hl.utils.range_table(1, 1)
        data = [
            (5, hl.tint, operator.eq),
            (float('nan'), hl.tfloat32, lambda x, y: str(x) == str(y)),
            (float('inf'), hl.tfloat64, lambda x, y: str(x) == str(y)),
            (float('-inf'), hl.tfloat64, lambda x, y: str(x) == str(y)),
            (1.111, hl.tfloat64, operator.eq),
            ([hl.Struct(**{'a': None, 'b': 5}),
              hl.Struct(**{'a': 'hello', 'b': 10})], hl.tarray(hl.tstruct(a=hl.tstr, b=hl.tint)), operator.eq)
        ]

        for x, t, f in data:
            self.assertTrue(f(hl.eval(mt.annotate_globals(foo=hl.literal(x, t)).foo), x), f"{x}, {t}")
            self.assertTrue(f(hl.eval(ht.annotate_globals(foo=hl.literal(x, t)).foo), x), f"{x}, {t}")
Example 4
    def test_summarize_variants(self):
        mt = hl.utils.range_matrix_table(3, 3)
        variants = hl.literal({0: hl.Struct(locus=hl.Locus('1', 1), alleles=['A', 'T', 'C']),
                               1: hl.Struct(locus=hl.Locus('2', 1), alleles=['A', 'AT', '@']),
                               2: hl.Struct(locus=hl.Locus('2', 1), alleles=['AC', 'GT'])})
        mt = mt.annotate_rows(**variants[mt.row_idx]).key_rows_by('locus', 'alleles')
        r = hl.summarize_variants(mt, show=False)
        self.assertEqual(r.n_variants, 3)
        self.assertEqual(r.contigs, {'1': 1, '2': 2})
        self.assertEqual(r.allele_types, {'SNP': 2, 'MNP': 1, 'Unknown': 1, 'Insertion': 1})
        self.assertEqual(r.allele_counts, {2: 1, 3: 2})
Example 5
    def test_refs_with_process_joins(self):
        mt = hl.utils.range_matrix_table(10, 10)
        mt = mt.annotate_entries(
            a_literal=hl.literal(['a']),
            a_col_join=hl.is_defined(mt.cols()[mt.col_key]),
            a_row_join=hl.is_defined(mt.rows()[mt.row_key]),
            an_entry_join=hl.is_defined(mt[mt.row_key, mt.col_key]),
            the_global_failure=hl.cond(True, mt.globals, hl.null(mt.globals.dtype)),
            the_row_failure=hl.cond(True, mt.row, hl.null(mt.row.dtype)),
            the_col_failure=hl.cond(True, mt.col, hl.null(mt.col.dtype)),
            the_entry_failure=hl.cond(True, mt.entry, hl.null(mt.entry.dtype)),
        )
        mt.count()
Example 6
    def test_maximal_independent_set(self):
        # prefer to remove nodes with higher index
        t = hl.utils.range_table(10)
        graph = t.select(i=hl.int64(t.idx), j=hl.int64(t.idx + 10), bad_type=hl.float32(t.idx))

        mis_table = hl.maximal_independent_set(graph.i, graph.j, True, lambda l, r: l - r)
        mis = [row['node'] for row in mis_table.collect()]
        self.assertEqual(sorted(mis), list(range(0, 10)))
        self.assertEqual(mis_table.row.dtype, hl.tstruct(node=hl.tint64))
        self.assertEqual(mis_table.key.dtype, hl.tstruct(node=hl.tint64))

        self.assertRaises(ValueError, lambda: hl.maximal_independent_set(graph.i, graph.bad_type, True))
        self.assertRaises(ValueError, lambda: hl.maximal_independent_set(graph.i, hl.utils.range_table(10).idx, True))
        self.assertRaises(ValueError, lambda: hl.maximal_independent_set(hl.literal(1), hl.literal(2), True))
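
Why the expected set is exactly 0 through 9: every edge joins node i to node i + 10, so each edge has one endpoint below 10, and the tie breaker `lambda l, r: l - r` removes the higher-indexed endpoint (per the comment at the top of the test), leaving nodes 0-9 as the maximal independent set.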
Example 7
    def overlaps(self, interval):
        """True if the the supplied interval contains any value in common with this one.

        Parameters
        ----------
        interval : :class:`.Interval`
            Interval object with the same point type.

        Returns
        -------
        :obj:`bool`
        """

        return hl.eval(hl.literal(self, hl.tinterval(self._point_type)).overlaps(interval))
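
A minimal usage sketch (bounds are illustrative; hl.Interval includes its start and excludes its end by default):

>>> hl.Interval(3, 6).overlaps(hl.Interval(5, 9))
True
>>> hl.Interval(3, 6).overlaps(hl.Interval(6, 9))
False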
Example 8
    def test_rename_duplicates(self):
        mt = hl.utils.range_matrix_table(5, 5)

        assert hl.rename_duplicates(
            mt.key_cols_by(s=hl.str(mt.col_idx))
        ).unique_id.collect() == ['0', '1', '2', '3', '4']

        assert hl.rename_duplicates(
            mt.key_cols_by(s='0')
        ).unique_id.collect() == ['0', '0_1', '0_2', '0_3', '0_4']

        assert hl.rename_duplicates(
            mt.key_cols_by(s=hl.literal(['0', '0_1', '0', '0_2', '0'])[mt.col_idx])
        ).unique_id.collect() == ['0', '0_1', '0_2', '0_2_1', '0_3']

        assert hl.rename_duplicates(
            mt.key_cols_by(s=hl.str(mt.col_idx)),
            'foo'
        )['foo'].dtype == hl.tstr
Example 9
    def contains(self, value):
        """True if `value` is contained within the interval.

        Examples
        --------

        >>> interval2.contains(5)
        True

        >>> interval2.contains(6)
        False

        Parameters
        ----------
        value :
            Object with type :meth:`.point_type`.

        Returns
        -------
        :obj:`bool`
        """

        return hl.eval(hl.literal(self, hl.tinterval(self._point_type)).contains(value))
Example 10
def import_gtf(path,
               reference_genome=None,
               skip_invalid_contigs=False,
               min_partitions=None) -> hl.Table:
    """Import a GTF file.

       The GTF file format is identical to the GFF version 2 file format,
       and so this function can be used to import GFF version 2 files as
       well.

       See https://www.ensembl.org/info/website/upload/gff.html for more
       details on the GTF/GFF2 file format.

       The :class:`.Table` returned by this function will be keyed by the
       ``interval`` row field and will include the following row fields:

       .. code-block:: text

           'source': str
           'feature': str
           'score': float64
           'strand': str
           'frame': int32
           'interval': interval<>

       There will also be corresponding fields for every tag found in the
       attribute field of the GTF file.

       Note
       ----

       This function will return an ``interval`` field of type :class:`.tinterval`
       constructed from the ``seqname``, ``start``, and ``end`` fields in the
       GTF file. This interval is inclusive of both the start and end positions
       in the GTF file. 

       If the ``reference_genome`` parameter is specified, the start and end
       points of the ``interval`` field will be of type :class:`.tlocus`.
       Otherwise, the start and end points of the ``interval`` field will be of
       type :class:`.tstruct` with fields ``seqname`` (type :class:`str`) and
       ``position`` (type :class:`.tint32`).

       Furthermore, if the ``reference_genome`` parameter is specified and
       ``skip_invalid_contigs`` is ``True``, this import function will skip
       lines in the GTF where ``seqname`` is not consistent with the reference
       genome specified.

       Example
       -------

       >>> ht = hl.experimental.import_gtf('data/test.gtf', 
       ...                                 reference_genome='GRCh37',
       ...                                 skip_invalid_contigs=True)

       >>> ht.describe()  # doctest: +NOTEST
       ----------------------------------------
       Global fields:
       None
       ----------------------------------------
       Row fields:
           'source': str
           'feature': str
           'score': float64
           'strand': str
           'frame': int32
           'gene_type': str
           'exon_id': str
           'havana_transcript': str
           'level': str
           'transcript_name': str
           'gene_status': str
           'gene_id': str
           'transcript_type': str
           'tag': str
           'transcript_status': str
           'gene_name': str
           'transcript_id': str
           'exon_number': str
           'havana_gene': str
           'interval': interval<locus<GRCh37>>
       ----------------------------------------
       Key: ['interval']
       ----------------------------------------

       Parameters
       ----------

       path : :obj:`str`
           File to import.
       reference_genome : :obj:`str` or :class:`.ReferenceGenome`, optional
           Reference genome to use.
       skip_invalid_contigs : :obj:`bool`
           If ``True`` and `reference_genome` is not ``None``, skip lines where
           ``seqname`` is not consistent with the reference genome.
       min_partitions : :obj:`int` or :obj:`None`
           Minimum number of partitions (passed to import_table).

       Returns
       -------
       :class:`.Table`
       """

    ht = hl.import_table(path,
                         min_partitions=min_partitions,
                         comment='#',
                         no_header=True,
                         types={
                             'f3': hl.tint,
                             'f4': hl.tint,
                             'f5': hl.tfloat,
                             'f7': hl.tint
                         },
                         missing='.',
                         delimiter='\t')

    ht = ht.rename({
        'f0': 'seqname',
        'f1': 'source',
        'f2': 'feature',
        'f3': 'start',
        'f4': 'end',
        'f5': 'score',
        'f6': 'strand',
        'f7': 'frame',
        'f8': 'attribute'
    })

    ht = ht.annotate(attribute=hl.dict(
        hl.map(
            lambda x: (x.split(' ')[0],
                       x.split(' ')[1].replace('"', '').replace(';$', '')),
            ht['attribute'].split('; '))))

    attributes = ht.aggregate(
        hl.agg.explode(lambda x: hl.agg.collect_as_set(x),
                       ht['attribute'].keys()))

    ht = ht.transmute(
        **{
            x: hl.or_missing(ht['attribute'].contains(x), ht['attribute'][x])
            for x in attributes if x
        })

    if reference_genome:
        if reference_genome == 'GRCh37':
            ht = ht.annotate(seqname=ht['seqname'].replace('^chr', ''))
        else:
            ht = ht.annotate(
                seqname=hl.case()
                .when(ht['seqname'].startswith('HLA'), ht['seqname'])
                .when(ht['seqname'].startswith('chrHLA'),
                      ht['seqname'].replace('^chr', ''))
                .when(ht['seqname'].startswith('chr'), ht['seqname'])
                .default('chr' + ht['seqname']))
        if skip_invalid_contigs:
            valid_contigs = hl.literal(
                set(hl.get_reference(reference_genome).contigs))
            ht = ht.filter(valid_contigs.contains(ht['seqname']))
        ht = ht.transmute(
            interval=hl.locus_interval(ht['seqname'],
                                       ht['start'],
                                       ht['end'],
                                       includes_start=True,
                                       includes_end=True,
                                       reference_genome=reference_genome))
    else:
        ht = ht.transmute(interval=hl.interval(
            hl.struct(seqname=ht['seqname'], position=ht['start']),
            hl.struct(seqname=ht['seqname'], position=ht['end']),
            includes_start=True,
            includes_end=True))

    ht = ht.key_by('interval')

    return ht
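
A hypothetical follow-up on the returned table (the 'feature' row field is documented above; the value 'gene' is illustrative):

>>> ht = hl.experimental.import_gtf('data/test.gtf',
...                                 reference_genome='GRCh37',
...                                 skip_invalid_contigs=True)
>>> genes = ht.filter(ht['feature'] == 'gene')
>>> genes.show(5)  # doctest: +SKIP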
Example 11
def main(args):
    hl.init(log="/variant_qc_evaluation.log")

    if args.create_bin_ht:
        create_bin_ht(
            args.model_id,
            args.n_bins,
        ).write(
            get_score_bins(args.model_id, aggregated=False).path,
            overwrite=args.overwrite,
        )

    if args.run_sanity_checks:
        ht = get_score_bins(args.model_id, aggregated=False).ht()
        logger.info("Running sanity checks...")
        print(
            ht.aggregate(
                hl.struct(
                    was_biallelic=hl.agg.counter(~ht.was_split),
                    has_biallelic_rank=hl.agg.counter(
                        hl.is_defined(ht.biallelic_bin)),
                    was_singleton=hl.agg.counter(ht.singleton),
                    has_singleton_rank=hl.agg.counter(
                        hl.is_defined(ht.singleton_bin)),
                    was_biallelic_singleton=hl.agg.counter(ht.singleton
                                                           & ~ht.was_split),
                    has_biallelic_singleton_rank=hl.agg.counter(
                        hl.is_defined(ht.biallelic_singleton_bin)),
                )))

    if args.create_aggregated_bin_ht:
        logger.warning(
            "Use only workers, it typically crashes with preemptibles")
        create_aggregated_bin_ht(args.model_id).write(
            get_score_bins(args.model_id, aggregated=True).path,
            overwrite=args.overwrite,
        )

    if args.extract_truth_samples:
        logger.info(f"Extracting truth samples from MT...")
        mt = get_gnomad_v3_mt(key_by_locus_and_alleles=True,
                              remove_hard_filtered_samples=False)

        truth_ids = hl.literal([v["s"] for v in TRUTH_SAMPLES.values()])
        mt = mt.filter_cols(truth_ids.contains(mt.s))
        mt = hl.experimental.sparse_split_multi(mt, filter_changed_loci=True)

        # Checkpoint to prevent needing to go through the large table a second time
        mt = mt.checkpoint(
            get_checkpoint_path("truth_samples", mt=True),
            overwrite=args.overwrite,
        )

        for truth_sample in TRUTH_SAMPLES:
            truth_sample_mt = mt.filter_cols(
                mt.s == TRUTH_SAMPLES[truth_sample]["s"])
            # Filter to variants in truth data
            truth_sample_mt = truth_sample_mt.filter_rows(
                hl.agg.any(truth_sample_mt.GT.is_non_ref()))
            truth_sample_mt.naive_coalesce(args.n_partitions).write(
                get_callset_truth_data(truth_sample).path,
                overwrite=args.overwrite,
            )

    if args.merge_with_truth_data:
        for truth_sample in TRUTH_SAMPLES:
            logger.info(
                f"Creating a merged table with callset truth sample and truth data for {truth_sample}..."
            )

            # Load truth data
            mt = get_callset_truth_data(truth_sample).mt()
            truth_hc_intervals = TRUTH_SAMPLES[truth_sample]["hc_intervals"]
            truth_mt = TRUTH_SAMPLES[truth_sample]["truth_mt"]
            truth_mt = truth_mt.key_cols_by(
                s=hl.str(TRUTH_SAMPLES[truth_sample]["s"]))

            # Remove low quality sites
            info_ht = get_info(split=True).ht()
            mt = mt.filter_rows(~info_ht[mt.row_key].AS_lowqual)

            ht = create_truth_sample_ht(mt, truth_mt, truth_hc_intervals)
            ht.write(
                get_callset_truth_data(truth_sample, mt=False).path,
                overwrite=args.overwrite,
            )

    if args.bin_truth_sample_concordance:
        for truth_sample in TRUTH_SAMPLES:
            logger.info(
                f"Creating binned concordance table for {truth_sample} for model {args.model_id}"
            )
            ht = get_callset_truth_data(truth_sample, mt=False).ht()

            info_ht = get_info(split=True).ht()
            ht = ht.filter(
                ~info_ht[ht.key].AS_lowqual
                & ~hl.is_defined(telomeres_and_centromeres.ht()[ht.locus]))

            logger.info("Filtering out low confidence regions and segdups...")
            ht = filter_low_conf_regions(
                ht,
                filter_lcr=True,
                # TODO: Uncomment when we have decoy path
                filter_decoy=False,  # True,
                filter_segdup=True,
            )

            logger.info(
                "Loading HT containing RF or VQSR scores annotated with a bin based on the rank of score..."
            )
            metric_ht = get_score_bins(args.model_id, aggregated=False).ht()
            ht = ht.filter(hl.is_defined(metric_ht[ht.key]))

            ht = ht.annotate(score=metric_ht[ht.key].score)

            ht = compute_binned_truth_sample_concordance(
                ht, metric_ht, args.n_bins)
            ht.write(
                get_binned_concordance(args.model_id, truth_sample).path,
                overwrite=args.overwrite,
            )
Example 12
def get_het_hom_summary_dict(
    csq_set: Set[str],
    most_severe_csq_expr: hl.expr.StringExpression,
    defined_sites_expr: hl.expr.Int64Expression,
    num_homs_expr: hl.expr.Int64Expression,
    num_hets_expr: hl.expr.Int64Expression,
    pop_expr: hl.expr.StringExpression,
) -> Dict[str, hl.expr.Int64Expression]:
    """
    Generate dictionary containing summary counts.

    Summary counts are:
        - Number of sites with defined genotype calls
        - Number of samples with heterozygous calls
        - Number of samples with homozygous calls

    Each count is also generated per population, grouped by `pop_expr`.

    :param csq_set: Set containing transcript consequence string(s).
    :param most_severe_csq_expr: StringExpression containing most severe consequence.
    :param defined_sites_expr: Int64Expression containing number of sites with defined genotype calls.
    :param num_homs_expr: Int64Expression containing number of samples with homozygous genotype calls.
    :param num_hets_expr: Int64Expression containing number of samples with heterozygous genotype calls.
    :param pop_expr: StringExpression containing sample population labels.
    :return: Dictionary of summary annotation names and their values.
    """
    csq_filter_expr = hl.literal(csq_set).contains(most_severe_csq_expr)
    return {
        "no_alt_calls":
        hl.agg.count_where((csq_filter_expr)
                           & (defined_sites_expr > 0)
                           & (num_homs_expr + num_hets_expr == 0)),
        "obs_het":
        hl.agg.count_where((csq_filter_expr) & (num_homs_expr == 0)
                           & (num_hets_expr > 0)),
        "obs_hom":
        hl.agg.count_where((csq_filter_expr) & (num_homs_expr > 0)),
        "defined":
        hl.agg.count_where((csq_filter_expr) & (defined_sites_expr > 0)),
        "pop_no_alt_calls":
        hl.agg.group_by(
            pop_expr,
            hl.agg.count_where((csq_filter_expr)
                               & (defined_sites_expr > 0)
                               & (num_homs_expr + num_hets_expr == 0)),
        ),
        "pop_obs_het":
        hl.agg.group_by(
            pop_expr,
            hl.agg.count_where((csq_filter_expr) & (num_homs_expr == 0)
                               & (num_hets_expr > 0)),
        ),
        "pop_obs_hom":
        hl.agg.group_by(
            pop_expr,
            hl.agg.count_where((csq_filter_expr) & (num_homs_expr > 0)),
        ),
        "pop_defined":
        hl.agg.group_by(
            pop_expr,
            hl.agg.count_where((csq_filter_expr) & (defined_sites_expr > 0)),
        ),
    }
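
A hypothetical call site: every value in the returned dictionary is an aggregation expression, so it must be evaluated inside an aggregation (the Table and its field names below are illustrative, not from the source):

summary = ht.aggregate(
    hl.struct(**get_het_hom_summary_dict(
        csq_set={"missense_variant", "stop_gained"},
        most_severe_csq_expr=ht.most_severe_csq,
        defined_sites_expr=ht.defined_sites,
        num_homs_expr=ht.n_hom,
        num_hets_expr=ht.n_het,
        pop_expr=ht.pop,
    )))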
Example 13
File: qc.py Project: troels/hail
    def n_discordant(counter):
        # {i**2 for i in range(5)} == {0, 1, 4, 9, 16}; counts for these keys
        # are excluded from the sum, presumably the codes of concordant
        # genotype states in the concordance counter.
        concordant = hl.literal({i**2 for i in range(5)})
        return hl.sum(
            hl.array(counter)
            .filter(lambda tup: ~concordant.contains(tup[0]))
            .map(lambda tup: tup[1]))
Example 14
def transmission_disequilibrium_test(dataset, pedigree) -> Table:
    r"""Performs the transmission disequilibrium test on trios.

    .. include:: ../_templates/req_tstring.rst

    .. include:: ../_templates/req_tvariant.rst

    .. include:: ../_templates/req_biallelic.rst

    Examples
    --------
    Compute TDT association statistics and show the first two results:
    
    >>> pedigree = hl.Pedigree.read('data/tdt_trios.fam')
    >>> tdt_table = hl.transmission_disequilibrium_test(tdt_dataset, pedigree)
    >>> tdt_table.show(2)  # doctest: +NOTEST
    +---------------+------------+-------+-------+----------+----------+
    | locus         | alleles    |     t |     u |   chi_sq |  p_value |
    +---------------+------------+-------+-------+----------+----------+
    | locus<GRCh37> | array<str> | int64 | int64 |  float64 |  float64 |
    +---------------+------------+-------+-------+----------+----------+
    | 1:246714629   | ["C","A"]  |     0 |     4 | 4.00e+00 | 4.55e-02 |
    | 2:167262169   | ["T","C"]  |    NA |    NA |       NA |       NA |
    +---------------+------------+-------+-------+----------+----------+

    Export variants with p-values below 0.001:

    >>> tdt_table = tdt_table.filter(tdt_table.p_value < 0.001)
    >>> tdt_table.export("output/tdt_results.tsv")

    Notes
    -----
    The
    `transmission disequilibrium test <https://en.wikipedia.org/wiki/Transmission_disequilibrium_test#The_case_of_trios:_one_affected_child_per_family>`__
    compares the number of times the alternate allele is transmitted (t) versus
    not transmitted (u) from a heterozygous parent to an affected child. The null
    hypothesis holds that each case is equally likely. The TDT statistic is given by

    .. math::

        (t - u)^2 \over (t + u)

    and asymptotically follows a chi-squared distribution with one degree of
    freedom under the null hypothesis.

    :func:`transmission_disequilibrium_test` only considers complete trios (two
    parents and a proband with defined sex) and only returns results for the
    autosome, as defined by :meth:`~hail.genetics.Locus.in_autosome`, and
    chromosome X. Transmissions and non-transmissions are counted only for the
    configurations of genotypes and copy state in the table below, in order to
    filter out Mendel errors and configurations where transmission is
    guaranteed. The copy state of a locus with respect to a trio is defined as
    follows:

    - Auto -- in autosome or in PAR of X or female child
    - HemiX -- in non-PAR of X and male child

    Here PAR is the `pseudoautosomal region
    <https://en.wikipedia.org/wiki/Pseudoautosomal_region>`__
    of X and Y defined by :class:`.ReferenceGenome`, which many variant callers
    map to chromosome X.

    +--------+--------+--------+------------+---+---+
    |  Kid   | Dad    | Mom    | Copy State | t | u |
    +========+========+========+============+===+===+
    | HomRef | Het    | Het    | Auto       | 0 | 2 |
    +--------+--------+--------+------------+---+---+
    | HomRef | HomRef | Het    | Auto       | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | HomRef | Het    | HomRef | Auto       | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | Het    | Het    | Het    | Auto       | 1 | 1 |
    +--------+--------+--------+------------+---+---+
    | Het    | HomRef | Het    | Auto       | 1 | 0 |
    +--------+--------+--------+------------+---+---+
    | Het    | Het    | HomRef | Auto       | 1 | 0 |
    +--------+--------+--------+------------+---+---+
    | Het    | HomVar | Het    | Auto       | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | Het    | Het    | HomVar | Auto       | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | HomVar | Het    | Het    | Auto       | 2 | 0 |
    +--------+--------+--------+------------+---+---+
    | HomVar | Het    | HomVar | Auto       | 1 | 0 |
    +--------+--------+--------+------------+---+---+
    | HomVar | HomVar | Het    | Auto       | 1 | 0 |
    +--------+--------+--------+------------+---+---+
    | HomRef | HomRef | Het    | HemiX      | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | HomRef | HomVar | Het    | HemiX      | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | HomVar | HomRef | Het    | HemiX      | 1 | 0 |
    +--------+--------+--------+------------+---+---+
    | HomVar | HomVar | Het    | HemiX      | 1 | 0 |
    +--------+--------+--------+------------+---+---+

    :func:`transmission_disequilibrium_test` produces a table with the following columns:

     - `locus` (:class:`.tlocus`) -- Locus.
     - `alleles` (:class:`.tarray` of :py:data:`.tstr`) -- Alleles.
     - `t` (:py:data:`.tint32`) -- Number of transmitted alternate alleles.
     - `u` (:py:data:`.tint32`) -- Number of untransmitted alternate alleles.
     - `chi_sq` (:py:data:`.tfloat64`) -- TDT statistic.
     - `p_value` (:py:data:`.tfloat64`) -- p-value.

    Parameters
    ----------
    dataset : :class:`.MatrixTable`
        Dataset.
    pedigree : :class:`~hail.genetics.Pedigree`
        Sample pedigree.

    Returns
    -------
    :class:`.Table`
        Table of TDT results.
    """

    dataset = require_biallelic(dataset, 'transmission_disequilibrium_test')
    dataset = dataset.annotate_rows(auto_or_x_par = dataset.locus.in_autosome() | dataset.locus.in_x_par())
    dataset = dataset.filter_rows(dataset.auto_or_x_par | dataset.locus.in_x_nonpar())

    hom_ref = 0
    het = 1
    hom_var = 2

    auto = 2
    hemi_x = 1

    #                     kid,     dad,     mom,   copy, t, u
    config_counts = [(hom_ref,     het,     het,   auto, 0, 2),
                     (hom_ref, hom_ref,     het,   auto, 0, 1),
                     (hom_ref,     het, hom_ref,   auto, 0, 1),
                     (    het,     het,     het,   auto, 1, 1),
                     (    het, hom_ref,     het,   auto, 1, 0),
                     (    het,     het, hom_ref,   auto, 1, 0),
                     (    het, hom_var,     het,   auto, 0, 1),
                     (    het,     het, hom_var,   auto, 0, 1),
                     (hom_var,     het,     het,   auto, 2, 0),
                     (hom_var,     het, hom_var,   auto, 1, 0),
                     (hom_var, hom_var,     het,   auto, 1, 0),
                     (hom_ref, hom_ref,     het, hemi_x, 0, 1),
                     (hom_ref, hom_var,     het, hemi_x, 0, 1),
                     (hom_var, hom_ref,     het, hemi_x, 1, 0),
                     (hom_var, hom_var,     het, hemi_x, 1, 0)]

    count_map = hl.literal({(c[0], c[1], c[2], c[3]): [c[4], c[5]] for c in config_counts})

    tri = trio_matrix(dataset, pedigree, complete_trios=True)

    # this filter removes mendel error of het father in x_nonpar. It also avoids
    #   building and looking up config in common case that neither parent is het
    father_is_het = tri.father_entry.GT.is_het()
    parent_is_valid_het = ((father_is_het & tri.auto_or_x_par) |
                           (tri.mother_entry.GT.is_het() & ~father_is_het))

    copy_state = hl.cond(tri.auto_or_x_par | tri.is_female, 2, 1)

    config = (tri.proband_entry.GT.n_alt_alleles(),
              tri.father_entry.GT.n_alt_alleles(),
              tri.mother_entry.GT.n_alt_alleles(),
              copy_state)

    tri = tri.annotate_rows(counts = agg.filter(parent_is_valid_het, agg.array_sum(count_map.get(config))))

    tab = tri.rows().select('counts')
    tab = tab.transmute(t = tab.counts[0], u = tab.counts[1])
    tab = tab.annotate(chi_sq = ((tab.t - tab.u) ** 2) / (tab.t + tab.u))
    tab = tab.annotate(p_value = hl.pchisqtail(tab.chi_sq, 1.0))

    return tab.cache()
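
A quick arithmetic check against the sample output in the docstring: the first row shown has t = 0 and u = 4, so chi_sq = (0 - 4)^2 / (0 + 4) = 4.0, and the p-value is the upper tail of a 1-df chi-squared distribution at 4.0:

>>> hl.eval(hl.pchisqtail(4.0, 1.0))  # doctest: +SKIP
0.0455...

which matches the 4.55e-02 in the table.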
Example 15
def init(doctest_namespace):
    # This gets run once per process -- must avoid race conditions
    print("setting up doctest...")

    olddir = os.getcwd()
    os.chdir("docs/")

    doctest_namespace['hl'] = hl
    doctest_namespace['agg'] = agg

    if not os.path.isdir("output/"):
        try:
            os.mkdir("output/")
        except OSError:
            pass

    files = ["sample.vds", "sample.qc.vds", "sample.filtered.vds"]
    for f in files:
        if os.path.isdir(f):
            shutil.rmtree(f)

    ds = hl.read_matrix_table('data/example.vds')
    doctest_namespace['ds'] = ds
    doctest_namespace['dataset'] = ds
    doctest_namespace['dataset2'] = ds.annotate_globals(global_field=5)
    doctest_namespace['dataset_to_union_1'] = ds
    doctest_namespace['dataset_to_union_2'] = ds

    v_metadata = ds.rows().annotate_globals(global_field=5).annotate(consequence='SYN')
    doctest_namespace['v_metadata'] = v_metadata

    s_metadata = ds.cols().annotate(pop='AMR', is_case=False, sex='F')
    doctest_namespace['s_metadata'] = s_metadata

    # Table
    table1 = hl.import_table('data/kt_example1.tsv', impute=True, key='ID')
    table1 = table1.annotate_globals(global_field_1=5, global_field_2=10)
    doctest_namespace['table1'] = table1
    doctest_namespace['other_table'] = table1

    table2 = hl.import_table('data/kt_example2.tsv', impute=True, key='ID')
    doctest_namespace['table2'] = table2

    table4 = hl.import_table('data/kt_example4.tsv', impute=True,
                             types={'B': hl.tstruct(B0=hl.tbool, B1=hl.tstr),
                                    'D': hl.tstruct(cat=hl.tint32, dog=hl.tint32),
                                    'E': hl.tstruct(A=hl.tint32, B=hl.tint32)})
    doctest_namespace['table4'] = table4

    people_table = hl.import_table('data/explode_example.tsv', delimiter='\\s+',
                                   types={'Age': hl.tint32, 'Children': hl.tarray(hl.tstr)},
                                   key='Name')
    doctest_namespace['people_table'] = people_table

    # TDT
    doctest_namespace['tdt_dataset'] = hl.import_vcf('data/tdt_tiny.vcf')

    ds2 = hl.variant_qc(ds)
    doctest_namespace['ds2'] = ds2.select_rows(AF=ds2.variant_qc.AF)

    # Expressions
    doctest_namespace['names'] = hl.literal(['Alice', 'Bob', 'Charlie'])
    doctest_namespace['a1'] = hl.literal([0, 1, 2, 3, 4, 5])
    doctest_namespace['a2'] = hl.literal([1, -1, 1, -1, 1, -1])
    doctest_namespace['t'] = hl.literal(True)
    doctest_namespace['f'] = hl.literal(False)
    doctest_namespace['na'] = hl.null(hl.tbool)
    doctest_namespace['call'] = hl.call(0, 1, phased=False)
    doctest_namespace['a'] = hl.literal([1, 2, 3, 4, 5])
    doctest_namespace['d'] = hl.literal({'Alice': 43, 'Bob': 33, 'Charles': 44})
    doctest_namespace['interval'] = hl.interval(3, 11)
    doctest_namespace['locus_interval'] = hl.parse_locus_interval("1:53242-90543")
    doctest_namespace['locus'] = hl.locus('1', 1034245)
    doctest_namespace['x'] = hl.literal(3)
    doctest_namespace['y'] = hl.literal(4.5)
    doctest_namespace['s1'] = hl.literal({1, 2, 3})
    doctest_namespace['s2'] = hl.literal({1, 3, 5})
    doctest_namespace['s3'] = hl.literal({'Alice', 'Bob', 'Charlie'})
    doctest_namespace['struct'] = hl.struct(a=5, b='Foo')
    doctest_namespace['tup'] = hl.literal(("a", 1, [1, 2, 3]))
    doctest_namespace['s'] = hl.literal('The quick brown fox')
    doctest_namespace['interval2'] = hl.Interval(3, 6)

    # Overview
    doctest_namespace['ht'] = hl.import_table("data/kt_example1.tsv", impute=True)
    doctest_namespace['mt'] = ds

    gnomad_data = ds.rows()
    doctest_namespace['gnomad_data'] = gnomad_data.select(gnomad_data.info.AF)

    # BGEN
    bgen = hl.import_bgen('data/example.8bits.bgen',
                          entry_fields=['GT', 'GP', 'dosage'])
    doctest_namespace['variants_table'] = bgen.rows()

    print("finished setting up doctest...")
    yield
    os.chdir(olddir)
Example 16
def make_betas(mt, h2, pi=None, annot=None, rg=None):
    r"""Generates betas under different models.

    Simulates betas (SNP effects) under the infinitesimal, spike & slab, or
    annotation-informed models, depending on parameters passed.

    Parameters
    ----------
    mt : :class:`.MatrixTable`
        MatrixTable containing genotypes to be used. Also should contain
        variant annotations as row fields if running the annotation-informed
        model or covariates as column fields if adding population stratification.
    h2 : :obj:`float` or :obj:`int` or :obj:`list` or :class:`numpy.ndarray`
        SNP-based heritability of simulated trait(s).
    pi : :obj:`float` or :obj:`int` or :obj:`list` or :class:`numpy.ndarray`, optional
        Probability of SNP being causal when simulating under the spike & slab
        model. If doing two-trait spike & slab `pi` is a list of probabilities for
        overlapping causal SNPs (see docstring of :func:`.multitrait_ss`)
    annot : :class:`.Expression`, optional
        Row field of aggregated annotations for annotation-informed model.
    rg : :obj:`float` or :obj:`int` or :obj:`list` or :class:`numpy.ndarray`, optional
        Genetic correlation between traits.

    Returns
    -------
    mt : :class:`.MatrixTable`
        :class:`.MatrixTable` with betas as a row field, simulated according to specified model.
    pi : :obj:`list`
        Probability of a SNP being causal for different traits, possibly altered
        from input `pi` if covariance matrix for multitrait simulation was not
        positive semi-definite.
    rg : :obj:`list`
        Genetic correlation between traits, possibly altered from input `rg` if
        covariance matrix for multitrait simulation was not positive semi-definite.

    """
    h2 = h2.tolist() if type(h2) is np.ndarray else (
        [h2] if type(h2) is not list else h2)
    pi = pi.tolist() if type(pi) is np.ndarray else (
        [pi] if type(pi) is not list else pi)
    rg = rg.tolist() if type(rg) is np.ndarray else (
        [rg] if type(rg) is not list else rg)
    assert all(0 <= x <= 1
               for x in h2), 'h2 values must be between 0 and 1'
    assert (pi == [None]) or all(
        0 <= x <= 1
        for x in pi), 'pi values for spike & slab must be between 0 and 1'
    assert (rg == [None]
            or all(-1 <= x <= 1
                   for x in rg)), 'rg values must be between -1 and 1 or None'
    if annot is not None:  # multi-trait annotation-informed
        assert rg == [None], \
            'Correlated traits not supported for annotation-informed model'
        h2 = h2 if type(h2) is list else [h2]
        M = mt.count_rows()  # number of variants
        annot_sum = mt.aggregate_rows(hl.agg.sum(annot))
        mt = mt.annotate_rows(beta=hl.literal(h2).map(
            lambda x: hl.rand_norm(0, hl.sqrt(annot * x / (annot_sum * M)))))
    elif len(h2) > 1 and (pi == [None] or pi == [1]):
        # multi-trait correlated infinitesimal
        mt, rg = multitrait_inf(mt=mt, h2=h2, rg=rg)
    elif len(h2) == 2 and len(pi) > 1 and len(rg) == 1:
        # two-trait correlated spike & slab
        print('multitrait ss')
        mt, pi, rg = multitrait_ss(mt=mt,
                                   h2=h2,
                                   rg=0 if rg == [None] else rg[0],
                                   pi=pi)
    elif len(h2) == 1 and len(pi) == 1:
        # single-trait infinitesimal / spike & slab
        M = mt.count_rows()
        pi_temp = 1 if pi == [None] else pi[0]
        mt = mt.annotate_rows(beta=hl.rand_bool(pi_temp) *
                              hl.rand_norm(0, hl.sqrt(h2[0] / (M * pi_temp))))
    else:
        raise ValueError('Parameters passed do not match any models.')
    return mt, pi, rg
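
A minimal usage sketch exercising the single-trait spike & slab branch (simulated genotypes via hl.balding_nichols_model; parameter values are illustrative):

mt = hl.balding_nichols_model(n_populations=1, n_samples=100, n_variants=1000)
mt, pi, rg = make_betas(mt, h2=0.5, pi=[0.01])  # roughly 1% of SNPs causal
mt.rows().select('beta').show(5)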
Example 17
def test_round_trip_basics():
    assert_round_trip(hl.literal(1))
Example 18
    def test_locus_windows(self):
        def assert_eq(a, b):
            self.assertTrue(np.array_equal(a, np.array(b)))

        centimorgans = hl.literal([0.1, 1.0, 1.0, 1.5, 1.9])

        mt = hl.balding_nichols_model(1, 5, 5).add_row_index()
        mt = mt.annotate_rows(cm=centimorgans[hl.int32(mt.row_idx)]).cache()

        starts, stops = hl.linalg.utils.locus_windows(mt.locus, 2)
        assert_eq(starts, [0, 0, 0, 1, 2])
        assert_eq(stops, [3, 4, 5, 5, 5])
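        # The two assertions above follow from balding_nichols_model placing
        # its 5 loci at positions 1..5 of contig 1: with radius 2, the window
        # around row 0 (position 1) covers positions -1..3, i.e. rows [0, 3),
        # giving starts[0] == 0 and stops[0] == 3.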

        starts, stops = hl.linalg.utils.locus_windows(mt.locus, 0.5, coord_expr=mt.cm)
        assert_eq(starts, [0, 1, 1, 1, 3])
        assert_eq(stops, [1, 4, 4, 5, 5])

        starts, stops = hl.linalg.utils.locus_windows(mt.locus, 1.0, coord_expr=2 * centimorgans[hl.int32(mt.row_idx)])
        assert_eq(starts, [0, 1, 1, 1, 3])
        assert_eq(stops, [1, 4, 4, 5, 5])

        rows = [{'locus': hl.Locus('1', 1), 'cm': 1.0},
                {'locus': hl.Locus('1', 2), 'cm': 3.0},
                {'locus': hl.Locus('1', 4), 'cm': 4.0},
                {'locus': hl.Locus('2', 1), 'cm': 2.0},
                {'locus': hl.Locus('2', 1), 'cm': 2.0},
                {'locus': hl.Locus('3', 3), 'cm': 5.0}]

        ht = hl.Table.parallelize(rows,
                                  hl.tstruct(locus=hl.tlocus('GRCh37'), cm=hl.tfloat64),
                                  key=['locus'])

        starts, stops = hl.linalg.utils.locus_windows(ht.locus, 1)
        assert_eq(starts, [0, 0, 2, 3, 3, 5])
        assert_eq(stops, [2, 2, 3, 5, 5, 6])

        starts, stops = hl.linalg.utils.locus_windows(ht.locus, 1.0, coord_expr=ht.cm)
        assert_eq(starts, [0, 1, 1, 3, 3, 5])
        assert_eq(stops, [1, 3, 3, 5, 5, 6])

        with self.assertRaises(ValueError) as cm:
            hl.linalg.utils.locus_windows(ht.order_by(ht.cm).locus, 1.0)
        self.assertTrue('ascending order' in str(cm.exception))

        with self.assertRaises(ExpressionException) as cm:
            hl.linalg.utils.locus_windows(ht.locus, 1.0, coord_expr=hl.utils.range_table(1).idx)
        self.assertTrue('different source' in str(cm.exception))

        with self.assertRaises(ExpressionException) as cm:
            hl.linalg.utils.locus_windows(hl.locus('1', 1), 1.0)
        self.assertTrue("no source" in str(cm.exception))

        with self.assertRaises(ExpressionException) as cm:
            hl.linalg.utils.locus_windows(ht.locus, 1.0, coord_expr=0.0)
        self.assertTrue("no source" in str(cm.exception))

        ht = ht.annotate_globals(x = hl.locus('1', 1), y = 1.0)
        with self.assertRaises(ExpressionException) as cm:
            hl.linalg.utils.locus_windows(ht.x, 1.0)
        self.assertTrue("row-indexed" in str(cm.exception))
        with self.assertRaises(ExpressionException) as cm:
            hl.linalg.utils.locus_windows(ht.locus, 1.0, ht.y)
        self.assertTrue("row-indexed" in str(cm.exception))

        ht = hl.Table.parallelize([{'locus': hl.null(hl.tlocus()), 'cm': 1.0}],
                                  hl.tstruct(locus=hl.tlocus('GRCh37'), cm=hl.tfloat64), key=['locus'])
        with self.assertRaises(ValueError) as cm:
            hl.linalg.utils.locus_windows(ht.locus, 1.0)
        self.assertTrue("missing value for 'locus_expr'" in str(cm.exception))
        with self.assertRaises(ValueError) as cm:
            hl.linalg.utils.locus_windows(ht.locus, 1.0, coord_expr=ht.cm)
        self.assertTrue("missing value for 'locus_expr'" in str(cm.exception))

        ht = hl.Table.parallelize([{'locus': hl.Locus('1', 1), 'cm': hl.null(hl.tfloat64)}],
                                  hl.tstruct(locus=hl.tlocus('GRCh37'), cm=hl.tfloat64), key=['locus'])
        with self.assertRaises(ValueError) as cm:
            hl.linalg.utils.locus_windows(ht.locus, 1.0, coord_expr=ht.cm)
        self.assertTrue("missing value for 'coord_expr'" in str(cm.exception))
Example 19
def mendel_errors(call, pedigree) -> Tuple[Table, Table, Table, Table]:
    r"""Find Mendel errors; count per variant, individual and nuclear family.

    .. include:: ../_templates/req_tstring.rst

    .. include:: ../_templates/req_tvariant.rst

    .. include:: ../_templates/req_biallelic.rst

    Examples
    --------

    Find all violations of Mendelian inheritance in each (dad, mom, kid) trio in
    a pedigree and return four tables (all errors, errors by family, errors by
    individual, errors by variant):

    >>> ped = hl.Pedigree.read('data/trios.fam')
    >>> all_errors, per_fam, per_sample, per_variant = hl.mendel_errors(dataset['GT'], ped)

    Export all Mendel errors to a text file:

    >>> all_errors.export('output/all_mendel_errors.tsv')

    Annotate columns with the number of Mendel errors:

    >>> annotated_samples = dataset.annotate_cols(mendel=per_sample[dataset.s])

    Annotate rows with the number of Mendel errors:

    >>> annotated_variants = dataset.annotate_rows(mendel=per_variant[dataset.locus, dataset.alleles])

    Notes
    -----

    The example above returns four tables, which contain Mendelian violations
    grouped in various ways. These tables are modeled after the `PLINK mendel
    formats <https://www.cog-genomics.org/plink2/formats#mendel>`_, resembling
    the ``.mendel``, ``.fmendel``, ``.imendel``, and ``.lmendel`` formats,
    respectively.

    **First table:** all Mendel errors. This table contains one row per Mendel
    error, keyed by the variant and proband id.

        - `locus` (:class:`.tlocus`) -- Variant locus, key field.
        - `alleles` (:class:`.tarray` of :py:data:`.tstr`) -- Variant alleles, key field.
        - (column key of `dataset`) (:py:data:`.tstr`) -- Proband ID, key field.
        - `fam_id` (:py:data:`.tstr`) -- Family ID.
        - `mendel_code` (:py:data:`.tint32`) -- Mendel error code, see below.

    **Second table:** errors per nuclear family. This table contains one row
    per nuclear family, keyed by the parents.

        - `pat_id` (:py:data:`.tstr`) -- Paternal ID. (key field)
        - `mat_id` (:py:data:`.tstr`) -- Maternal ID. (key field)
        - `fam_id` (:py:data:`.tstr`) -- Family ID.
        - `children` (:py:data:`.tint32`) -- Number of children in this nuclear family.
        - `errors` (:py:data:`.tint64`) -- Number of Mendel errors in this nuclear family.
        - `snp_errors` (:py:data:`.tint64`) -- Number of Mendel errors at SNPs in this
          nuclear family.

    **Third table:** errors per individual. This table contains one row per
    individual. Each error is counted toward the proband, father, and mother
    according to the `Implicated` column in the table below.

        - (column key of `dataset`) (:py:data:`.tstr`) -- Sample ID (key field).
        - `fam_id` (:py:data:`.tstr`) -- Family ID.
        - `errors` (:py:data:`.tint64`) -- Number of Mendel errors involving this
          individual.
        - `snp_errors` (:py:data:`.tint64`) -- Number of Mendel errors involving this
          individual at SNPs.

    **Fourth table:** errors per variant.

        - `locus` (:class:`.tlocus`) -- Variant locus, key field.
        - `alleles` (:class:`.tarray` of :py:data:`.tstr`) -- Variant alleles, key field.
        - `errors` (:py:data:`.tint64`) -- Number of Mendel errors in this variant.

    This method only considers complete trios (two parents and proband with
    defined sex). The code of each Mendel error is determined by the table
    below, extending the
    `Plink classification <https://www.cog-genomics.org/plink2/basic_stats#mendel>`__.

    In the table, the copy state of a locus with respect to a trio is defined
    as follows, where PAR is the `pseudoautosomal region
    <https://en.wikipedia.org/wiki/Pseudoautosomal_region>`__ of X and Y
    defined by the reference genome and the autosome is defined by
    :meth:`~hail.genetics.Locus.in_autosome`.

    - Auto -- in autosome or in PAR or female child
    - HemiX -- in non-PAR of X and male child
    - HemiY -- in non-PAR of Y and male child

    `Any` refers to the set \{ HomRef, Het, HomVar, NoCall \} and `~`
    denotes complement in this set.

    +------+---------+---------+--------+------------+---------------+
    | Code | Dad     | Mom     | Kid    | Copy State | Implicated    |
    +======+=========+=========+========+============+===============+
    |    1 | HomVar  | HomVar  | Het    | Auto       | Dad, Mom, Kid |
    +------+---------+---------+--------+------------+---------------+
    |    2 | HomRef  | HomRef  | Het    | Auto       | Dad, Mom, Kid |
    +------+---------+---------+--------+------------+---------------+
    |    3 | HomRef  | ~HomRef | HomVar | Auto       | Dad, Kid      |
    +------+---------+---------+--------+------------+---------------+
    |    4 | ~HomRef | HomRef  | HomVar | Auto       | Mom, Kid      |
    +------+---------+---------+--------+------------+---------------+
    |    5 | HomRef  | HomRef  | HomVar | Auto       | Kid           |
    +------+---------+---------+--------+------------+---------------+
    |    6 | HomVar  | ~HomVar | HomRef | Auto       | Dad, Kid      |
    +------+---------+---------+--------+------------+---------------+
    |    7 | ~HomVar | HomVar  | HomRef | Auto       | Mom, Kid      |
    +------+---------+---------+--------+------------+---------------+
    |    8 | HomVar  | HomVar  | HomRef | Auto       | Kid           |
    +------+---------+---------+--------+------------+---------------+
    |    9 | Any     | HomVar  | HomRef | HemiX      | Mom, Kid      |
    +------+---------+---------+--------+------------+---------------+
    |   10 | Any     | HomRef  | HomVar | HemiX      | Mom, Kid      |
    +------+---------+---------+--------+------------+---------------+
    |   11 | HomVar  | Any     | HomRef | HemiY      | Dad, Kid      |
    +------+---------+---------+--------+------------+---------------+
    |   12 | HomRef  | Any     | HomVar | HemiY      | Dad, Kid      |
    +------+---------+---------+--------+------------+---------------+

    See Also
    --------
    :func:`.mendel_error_code`

    Parameters
    ----------
    call : :class:`.CallExpression`
    pedigree : :class:`.Pedigree`

    Returns
    -------
    (:class:`.Table`, :class:`.Table`, :class:`.Table`, :class:`.Table`)
    """
    source = call._indices.source
    if not isinstance(source, MatrixTable):
        raise ValueError("'mendel_errors': expected 'call' to be an expression of 'MatrixTable', found {}".format(
            "expression of '{}'".format(source.__class__) if source is not None else 'scalar expression'))

    source = source.select_entries(__GT=call)
    dataset = require_biallelic(source, 'mendel_errors')
    tm = trio_matrix(dataset, pedigree, complete_trios=True)
    tm = tm.select_entries(mendel_code=hl.mendel_error_code(
        tm.locus,
        tm.is_female,
        tm.father_entry['__GT'],
        tm.mother_entry['__GT'],
        tm.proband_entry['__GT']
    ))
    ck_name = next(iter(source.col_key))
    tm = tm.filter_entries(hl.is_defined(tm.mendel_code))
    tm = tm.rename({'id' : ck_name})

    entries = tm.entries()

    table1 = entries.select('fam_id', 'mendel_code')

    fam_counts = (
        entries
            .group_by(pat_id=entries.father[ck_name], mat_id=entries.mother[ck_name])
            .partition_hint(min(entries.n_partitions(), 8))
            .aggregate(children=hl.len(hl.agg.collect_as_set(entries[ck_name])),
                       errors=hl.agg.count_where(hl.is_defined(entries.mendel_code)),
                       snp_errors=hl.agg.count_where(hl.is_snp(entries.alleles[0], entries.alleles[1]) &
                                                     hl.is_defined(entries.mendel_code)))
    )
    table2 = tm.key_cols_by().cols()
    table2 = table2.select(pat_id=table2.father[ck_name],
                           mat_id=table2.mother[ck_name],
                           fam_id=table2.fam_id,
                           **fam_counts[table2.father[ck_name], table2.mother[ck_name]])
    table2 = table2.key_by('pat_id', 'mat_id').distinct()
    table2 = table2.annotate(errors=hl.or_else(table2.errors, hl.int64(0)),
                             snp_errors=hl.or_else(table2.snp_errors, hl.int64(0)))

    # in implicated, idx 0 is dad, idx 1 is mom, idx 2 is child
    implicated = hl.literal([
        [0, 0, 0],  # dummy
        [1, 1, 1],
        [1, 1, 1],
        [1, 0, 1],
        [0, 1, 1],
        [0, 0, 1],
        [1, 0, 1],
        [0, 1, 1],
        [0, 0, 1],
        [0, 1, 1],
        [0, 1, 1],
        [1, 0, 1],
        [1, 0, 1],
    ], dtype=hl.tarray(hl.tarray(hl.tint64)))
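    # e.g. implicated[5] == [0, 0, 1]: for Mendel code 5 only the child is
    # implicated, matching the "Kid" entry for code 5 in the docstring table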

    table3 = tm.annotate_cols(all_errors=hl.or_else(hl.agg.array_sum(implicated[tm.mendel_code]), [0, 0, 0]),
                              snp_errors=hl.or_else(
                                  hl.agg.filter(hl.is_snp(tm.alleles[0], tm.alleles[1]),
                                                hl.agg.array_sum(implicated[tm.mendel_code])),
                                  [0, 0, 0])).key_cols_by().cols()

    table3 = table3.select(xs=[
        hl.struct(**{ck_name: table3.father[ck_name],
                     'fam_id': table3.fam_id,
                     'errors': table3.all_errors[0],
                     'snp_errors': table3.snp_errors[0]}),
        hl.struct(**{ck_name: table3.mother[ck_name],
                     'fam_id': table3.fam_id,
                     'errors': table3.all_errors[1],
                     'snp_errors': table3.snp_errors[1]}),
        hl.struct(**{ck_name: table3.proband[ck_name],
                     'fam_id': table3.fam_id,
                     'errors': table3.all_errors[2],
                     'snp_errors': table3.snp_errors[2]}),
    ])
    table3 = table3.explode('xs')
    table3 = table3.select(**table3.xs)
    table3 = (table3.group_by(ck_name, 'fam_id')
              .aggregate(errors=hl.agg.sum(table3.errors),
                         snp_errors=hl.agg.sum(table3.snp_errors))
              .key_by(ck_name))

    table4 = tm.select_rows(errors=hl.agg.count_where(hl.is_defined(tm.mendel_code))).rows()

    return table1, table2, table3, table4
Example 20
def mendel_errors(call, pedigree) -> Tuple[Table, Table, Table, Table]:
    r"""Find Mendel errors; count per variant, individual and nuclear family.

    .. include:: ../_templates/req_tstring.rst

    .. include:: ../_templates/req_tvariant.rst

    .. include:: ../_templates/req_biallelic.rst

    Examples
    --------

    Find all violations of Mendelian inheritance in each (dad, mom, kid) trio in
    a pedigree and return four tables (all errors, errors by family, errors by
    individual, errors by variant):

    >>> ped = hl.Pedigree.read('data/trios.fam')
    >>> all_errors, per_fam, per_sample, per_variant = hl.mendel_errors(dataset['GT'], ped)

    Export all Mendel errors to a text file:

    >>> all_errors.export('output/all_mendel_errors.tsv')

    Annotate columns with the number of Mendel errors:

    >>> annotated_samples = dataset.annotate_cols(mendel=per_sample[dataset.s])

    Annotate rows with the number of Mendel errors:

    >>> annotated_variants = dataset.annotate_rows(mendel=per_variant[dataset.locus, dataset.alleles])

    Notes
    -----

    The example above returns four tables, which contain Mendelian violations
    grouped in various ways. These tables are modeled after the `PLINK mendel
    formats <https://www.cog-genomics.org/plink2/formats#mendel>`_, resembling
    the ``.mendel``, ``.fmendel``, ``.imendel``, and ``.lmendel`` formats,
    respectively.

    **First table:** all Mendel errors. This table contains one row per Mendel
    error, keyed by the variant and proband id.

        - `locus` (:class:`.tlocus`) -- Variant locus, key field.
        - `alleles` (:class:`.tarray` of :py:data:`.tstr`) -- Variant alleles, key field.
        - (column key of `dataset`) (:py:data:`.tstr`) -- Proband ID, key field.
        - `fam_id` (:py:data:`.tstr`) -- Family ID.
        - `mendel_code` (:py:data:`.tint32`) -- Mendel error code, see below.

    **Second table:** errors per nuclear family. This table contains one row
    per nuclear family, keyed by the parents.

        - `pat_id` (:py:data:`.tstr`) -- Paternal ID. (key field)
        - `mat_id` (:py:data:`.tstr`) -- Maternal ID. (key field)
        - `fam_id` (:py:data:`.tstr`) -- Family ID.
        - `children` (:py:data:`.tint32`) -- Number of children in this nuclear family.
        - `errors` (:py:data:`.tint64`) -- Number of Mendel errors in this nuclear family.
        - `snp_errors` (:py:data:`.tint64`) -- Number of Mendel errors at SNPs in this
          nuclear family.

    **Third table:** errors per individual. This table contains one row per
    individual. Each error is counted toward the proband, father, and mother
    according to the `Implicated` column in the table below.

        - (column key of `dataset`) (:py:data:`.tstr`) -- Sample ID (key field).
        - `fam_id` (:py:data:`.tstr`) -- Family ID.
        - `errors` (:py:data:`.tint64`) -- Number of Mendel errors involving this
          individual.
        - `snp_errors` (:py:data:`.tint64`) -- Number of Mendel errors involving this
          individual at SNPs.

    **Fourth table:** errors per variant.

        - `locus` (:class:`.tlocus`) -- Variant locus, key field.
        - `alleles` (:class:`.tarray` of :py:data:`.tstr`) -- Variant alleles, key field.
        - `errors` (:py:data:`.tint64`) -- Number of Mendel errors in this variant.

    This method only considers complete trios (two parents and proband with
    defined sex). The code of each Mendel error is determined by the table
    below, extending the
    `Plink classification <https://www.cog-genomics.org/plink2/basic_stats#mendel>`__.

    In the table, the copy state of a locus with respect to a trio is defined
    as follows, where PAR is the `pseudoautosomal region
    <https://en.wikipedia.org/wiki/Pseudoautosomal_region>`__ of X and Y
    defined by the reference genome and the autosome is defined by
    :meth:`~.LocusExpression.in_autosome`.

    - Auto -- in autosome or in PAR or female child
    - HemiX -- in non-PAR of X and male child
    - HemiY -- in non-PAR of Y and male child

    `Any` refers to the set \{ HomRef, Het, HomVar, NoCall \} and `~`
    denotes complement in this set.

    +------+---------+---------+--------+------------+---------------+
    | Code | Dad     | Mom     | Kid    | Copy State | Implicated    |
    +======+=========+=========+========+============+===============+
    |    1 | HomVar  | HomVar  | Het    | Auto       | Dad, Mom, Kid |
    +------+---------+---------+--------+------------+---------------+
    |    2 | HomRef  | HomRef  | Het    | Auto       | Dad, Mom, Kid |
    +------+---------+---------+--------+------------+---------------+
    |    3 | HomRef  | ~HomRef | HomVar | Auto       | Dad, Kid      |
    +------+---------+---------+--------+------------+---------------+
    |    4 | ~HomRef | HomRef  | HomVar | Auto       | Mom, Kid      |
    +------+---------+---------+--------+------------+---------------+
    |    5 | HomRef  | HomRef  | HomVar | Auto       | Kid           |
    +------+---------+---------+--------+------------+---------------+
    |    6 | HomVar  | ~HomVar | HomRef | Auto       | Dad, Kid      |
    +------+---------+---------+--------+------------+---------------+
    |    7 | ~HomVar | HomVar  | HomRef | Auto       | Mom, Kid      |
    +------+---------+---------+--------+------------+---------------+
    |    8 | HomVar  | HomVar  | HomRef | Auto       | Kid           |
    +------+---------+---------+--------+------------+---------------+
    |    9 | Any     | HomVar  | HomRef | HemiX      | Mom, Kid      |
    +------+---------+---------+--------+------------+---------------+
    |   10 | Any     | HomRef  | HomVar | HemiX      | Mom, Kid      |
    +------+---------+---------+--------+------------+---------------+
    |   11 | HomVar  | Any     | HomRef | HemiY      | Dad, Kid      |
    +------+---------+---------+--------+------------+---------------+
    |   12 | HomRef  | Any     | HomVar | HemiY      | Dad, Kid      |
    +------+---------+---------+--------+------------+---------------+

    See Also
    --------
    :func:`.mendel_error_code`

    Parameters
    ----------
    call : :class:`.CallExpression`
    pedigree : :class:`.Pedigree`

    Returns
    -------
    (:class:`.Table`, :class:`.Table`, :class:`.Table`, :class:`.Table`)
    """
    source = call._indices.source
    if not isinstance(source, MatrixTable):
        raise ValueError(
            "'mendel_errors': expected 'call' to be an expression of 'MatrixTable', found {}"
            .format("expression of '{}'".format(source.__class__)
                    if source is not None else 'scalar expression'))

    source = source.select_entries(__GT=call)
    dataset = require_biallelic(source, 'mendel_errors')
    tm = trio_matrix(dataset, pedigree, complete_trios=True)
    tm = tm.select_entries(mendel_code=hl.mendel_error_code(
        tm.locus, tm.is_female, tm.father_entry['__GT'],
        tm.mother_entry['__GT'], tm.proband_entry['__GT']))
    ck_name = next(iter(source.col_key))
    tm = tm.filter_entries(hl.is_defined(tm.mendel_code))
    tm = tm.rename({'id': ck_name})

    entries = tm.entries()

    table1 = entries.select('fam_id', 'mendel_code')

    t2 = tm.annotate_cols(errors=hl.agg.count(),
                          snp_errors=hl.agg.count_where(
                              hl.is_snp(tm.alleles[0], tm.alleles[1])))
    table2 = t2.key_cols_by().cols()
    table2 = table2.select(pat_id=table2.father[ck_name],
                           mat_id=table2.mother[ck_name],
                           fam_id=table2.fam_id,
                           errors=table2.errors,
                           snp_errors=table2.snp_errors)
    table2 = table2.group_by('pat_id', 'mat_id').aggregate(
        fam_id=hl.agg.take(table2.fam_id, 1)[0],
        children=hl.int32(hl.agg.count()),
        errors=hl.agg.sum(table2.errors),
        snp_errors=hl.agg.sum(table2.snp_errors))
    table2 = table2.annotate(errors=hl.or_else(table2.errors, hl.int64(0)),
                             snp_errors=hl.or_else(table2.snp_errors,
                                                   hl.int64(0)))

    # in implicated, idx 0 is dad, idx 1 is mom, idx 2 is child
    implicated = hl.literal(
        [
            [0, 0, 0],  # dummy
            [1, 1, 1],
            [1, 1, 1],
            [1, 0, 1],
            [0, 1, 1],
            [0, 0, 1],
            [1, 0, 1],
            [0, 1, 1],
            [0, 0, 1],
            [0, 1, 1],
            [0, 1, 1],
            [1, 0, 1],
            [1, 0, 1],
        ],
        dtype=hl.tarray(hl.tarray(hl.tint64)))

    table3 = tm.annotate_cols(
        all_errors=hl.or_else(hl.agg.array_sum(implicated[tm.mendel_code]),
                              [0, 0, 0]),
        snp_errors=hl.or_else(
            hl.agg.filter(hl.is_snp(tm.alleles[0], tm.alleles[1]),
                          hl.agg.array_sum(implicated[tm.mendel_code])),
            [0, 0, 0])).key_cols_by().cols()

    table3 = table3.select(xs=[
        hl.struct(
            **{
                ck_name: table3.father[ck_name],
                'fam_id': table3.fam_id,
                'errors': table3.all_errors[0],
                'snp_errors': table3.snp_errors[0]
            }),
        hl.struct(
            **{
                ck_name: table3.mother[ck_name],
                'fam_id': table3.fam_id,
                'errors': table3.all_errors[1],
                'snp_errors': table3.snp_errors[1]
            }),
        hl.struct(
            **{
                ck_name: table3.proband[ck_name],
                'fam_id': table3.fam_id,
                'errors': table3.all_errors[2],
                'snp_errors': table3.snp_errors[2]
            }),
    ])
    table3 = table3.explode('xs')
    table3 = table3.select(**table3.xs)
    table3 = (table3.group_by(ck_name, 'fam_id').aggregate(
        errors=hl.agg.sum(table3.errors),
        snp_errors=hl.agg.sum(table3.snp_errors)).key_by(ck_name))

    table4 = tm.select_rows(
        errors=hl.agg.count_where(hl.is_defined(tm.mendel_code))).rows()

    return table1, table2, table3, table4
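A minimal usage sketch for mendel_errors, assuming a biallelic MatrixTable `dataset` with a 'GT' entry field and a hypothetical pedigree file 'data/trios.fam':

import hail as hl

pedigree = hl.Pedigree.read('data/trios.fam')
all_errors, per_fam, per_sample, per_variant = hl.mendel_errors(dataset.GT, pedigree)

# Families with the most Mendel errors first.
per_fam.order_by(hl.desc(per_fam.errors)).show(5)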
Example n. 21
0
def transmission_disequilibrium_test(dataset, pedigree) -> Table:
    r"""Performs the transmission disequilibrium test on trios.

    .. include:: ../_templates/req_tstring.rst

    .. include:: ../_templates/req_tvariant.rst

    .. include:: ../_templates/req_biallelic.rst

    Examples
    --------
    Compute TDT association statistics and show the first two results:

    >>> pedigree = hl.Pedigree.read('data/tdt_trios.fam')
    >>> tdt_table = hl.transmission_disequilibrium_test(tdt_dataset, pedigree)
    >>> tdt_table.show(2)  # doctest: +SKIP_OUTPUT_CHECK
    +---------------+------------+-------+-------+----------+----------+
    | locus         | alleles    |     t |     u |   chi_sq |  p_value |
    +---------------+------------+-------+-------+----------+----------+
    | locus<GRCh37> | array<str> | int64 | int64 |  float64 |  float64 |
    +---------------+------------+-------+-------+----------+----------+
    | 1:246714629   | ["C","A"]  |     0 |     4 | 4.00e+00 | 4.55e-02 |
    | 2:167262169   | ["T","C"]  |    NA |    NA |       NA |       NA |
    +---------------+------------+-------+-------+----------+----------+

    Export variants with p-values below 0.001:

    >>> tdt_table = tdt_table.filter(tdt_table.p_value < 0.001)
    >>> tdt_table.export("output/tdt_results.tsv")

    Notes
    -----
    The
    `transmission disequilibrium test <https://en.wikipedia.org/wiki/Transmission_disequilibrium_test#The_case_of_trios:_one_affected_child_per_family>`__
    compares the number of times the alternate allele is transmitted (t) versus
    not transmitted (u) from a heterozygous parent to an affected child. The null
    hypothesis holds that each case is equally likely. The TDT statistic is given by

    .. math::

        (t - u)^2 \over (t + u)

    and asymptotically follows a chi-squared distribution with one degree of
    freedom under the null hypothesis.

    :func:`transmission_disequilibrium_test` only considers complete trios (two
    parents and a proband with defined sex) and only returns results for the
    autosome, as defined by :meth:`~.LocusExpression.in_autosome`, and
    chromosome X. Transmissions and non-transmissions are counted only for the
    configurations of genotypes and copy state in the table below, in order to
    filter out Mendel errors and configurations where transmission is
    guaranteed. The copy state of a locus with respect to a trio is defined as
    follows:

    - Auto -- in autosome or in PAR of X or female child
    - HemiX -- in non-PAR of X and male child

    Here PAR is the `pseudoautosomal region
    <https://en.wikipedia.org/wiki/Pseudoautosomal_region>`__
    of X and Y defined by :class:`.ReferenceGenome`, which many variant callers
    map to chromosome X.

    +--------+--------+--------+------------+---+---+
    |  Kid   | Dad    | Mom    | Copy State | t | u |
    +========+========+========+============+===+===+
    | HomRef | Het    | Het    | Auto       | 0 | 2 |
    +--------+--------+--------+------------+---+---+
    | HomRef | HomRef | Het    | Auto       | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | HomRef | Het    | HomRef | Auto       | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | Het    | Het    | Het    | Auto       | 1 | 1 |
    +--------+--------+--------+------------+---+---+
    | Het    | HomRef | Het    | Auto       | 1 | 0 |
    +--------+--------+--------+------------+---+---+
    | Het    | Het    | HomRef | Auto       | 1 | 0 |
    +--------+--------+--------+------------+---+---+
    | Het    | HomVar | Het    | Auto       | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | Het    | Het    | HomVar | Auto       | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | HomVar | Het    | Het    | Auto       | 2 | 0 |
    +--------+--------+--------+------------+---+---+
    | HomVar | Het    | HomVar | Auto       | 1 | 0 |
    +--------+--------+--------+------------+---+---+
    | HomVar | HomVar | Het    | Auto       | 1 | 0 |
    +--------+--------+--------+------------+---+---+
    | HomRef | HomRef | Het    | HemiX      | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | HomRef | HomVar | Het    | HemiX      | 0 | 1 |
    +--------+--------+--------+------------+---+---+
    | HomVar | HomRef | Het    | HemiX      | 1 | 0 |
    +--------+--------+--------+------------+---+---+
    | HomVar | HomVar | Het    | HemiX      | 1 | 0 |
    +--------+--------+--------+------------+---+---+

    :func:`.transmission_disequilibrium_test` produces a table with the following columns:

     - `locus` (:class:`.tlocus`) -- Locus.
     - `alleles` (:class:`.tarray` of :py:data:`.tstr`) -- Alleles.
     - `t` (:py:data:`.tint64`) -- Number of transmitted alternate alleles.
     - `u` (:py:data:`.tint64`) -- Number of untransmitted alternate alleles.
     - `chi_sq` (:py:data:`.tfloat64`) -- TDT statistic.
     - `p_value` (:py:data:`.tfloat64`) -- p-value.

    Parameters
    ----------
    dataset : :class:`.MatrixTable`
        Dataset.
    pedigree : :class:`~hail.genetics.Pedigree`
        Sample pedigree.

    Returns
    -------
    :class:`.Table`
        Table of TDT results.
    """

    dataset = require_biallelic(dataset, 'transmission_disequilibrium_test')
    dataset = dataset.annotate_rows(auto_or_x_par=dataset.locus.in_autosome()
                                    | dataset.locus.in_x_par())
    dataset = dataset.filter_rows(dataset.auto_or_x_par
                                  | dataset.locus.in_x_nonpar())

    hom_ref = 0
    het = 1
    hom_var = 2

    auto = 2
    hemi_x = 1

    #    kid      dad      mom      copy    t  u
    config_counts = [
        (hom_ref, het,     het,     auto,   0, 2),  # noqa: E241
        (hom_ref, hom_ref, het,     auto,   0, 1),  # noqa: E241
        (hom_ref, het,     hom_ref, auto,   0, 1),  # noqa: E241
        (het,     het,     het,     auto,   1, 1),  # noqa: E241, E201
        (het,     hom_ref, het,     auto,   1, 0),  # noqa: E241, E201
        (het,     het,     hom_ref, auto,   1, 0),  # noqa: E241, E201
        (het,     hom_var, het,     auto,   0, 1),  # noqa: E241, E201
        (het,     het,     hom_var, auto,   0, 1),  # noqa: E241, E201
        (hom_var, het,     het,     auto,   2, 0),  # noqa: E241
        (hom_var, het,     hom_var, auto,   1, 0),  # noqa: E241
        (hom_var, hom_var, het,     auto,   1, 0),  # noqa: E241
        (hom_ref, hom_ref, het,     hemi_x, 0, 1),  # noqa: E241
        (hom_ref, hom_var, het,     hemi_x, 0, 1),  # noqa: E241
        (hom_var, hom_ref, het,     hemi_x, 1, 0),  # noqa: E241
        (hom_var, hom_var, het,     hemi_x, 1, 0),  # noqa: E241
    ]

    count_map = hl.literal({(c[0], c[1], c[2], c[3]): [c[4], c[5]]
                            for c in config_counts})

    tri = trio_matrix(dataset, pedigree, complete_trios=True)

    # This filter removes Mendel errors from a het father in x_nonpar. It also avoids
    #   building and looking up the config in the common case that neither parent is het
    father_is_het = tri.father_entry.GT.is_het()
    parent_is_valid_het = ((father_is_het & tri.auto_or_x_par)
                           | (tri.mother_entry.GT.is_het() & ~father_is_het))

    copy_state = hl.if_else(tri.auto_or_x_par | tri.is_female, 2, 1)

    config = (tri.proband_entry.GT.n_alt_alleles(),
              tri.father_entry.GT.n_alt_alleles(),
              tri.mother_entry.GT.n_alt_alleles(), copy_state)

    tri = tri.annotate_rows(counts=agg.filter(
        parent_is_valid_het, agg.array_sum(count_map.get(config))))

    tab = tri.rows().select('counts')
    tab = tab.transmute(t=tab.counts[0], u=tab.counts[1])
    tab = tab.annotate(chi_sq=((tab.t - tab.u)**2) / (tab.t + tab.u))
    tab = tab.annotate(p_value=hl.pchisqtail(tab.chi_sq, 1.0))

    return tab.cache()
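A short follow-up sketch, reusing `tdt_dataset` and `pedigree` from the docstring example above; the 5e-8 genome-wide threshold is illustrative:

results = hl.transmission_disequilibrium_test(tdt_dataset, pedigree)

# Strongest TDT evidence first, then keep genome-wide significant hits.
top_hits = results.order_by(results.p_value)
significant = results.filter(results.p_value < 5e-8)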
Example n. 22
0
    def test_filter_cols_with_global_references(self):
        mt = hl.utils.range_matrix_table(10, 10)
        s = hl.literal({1, 3, 5, 7})
        self.assertEqual(mt.filter_cols(s.contains(mt.col_idx)).count_cols(), 4)
Example n. 23
0
def multitrait_ss(mt, h2, pi, rg=0, seed=None):
    r"""Generates spike & slab betas for simulation of two correlated phenotypes.

    Parameters
    ----------
    mt : :class:`.MatrixTable`
        :class:`.MatrixTable` for simulated phenotype.
    h2 : :obj:`list` or :class:`numpy.ndarray`
        Desired SNP-based heritability of simulated traits.
    pi : :obj:`list` or :class:`numpy.ndarray`
        List of proportion of SNPs: :math:`p_{TT}`, :math:`p_{TF}`, :math:`p_{FT}`
        :math:`p_{TT}` is the proportion of SNPs that are causal for both traits,
        :math:`p_{TF}` is the proportion of SNPs that are causal for trait 1 but not trait 2,
        :math:`p_{FT}` is the proportion of SNPs that are causal for trait 2 but not trait 1.
    rg : :obj:`float` or :obj:`int`
        Genetic correlation between traits.
    seed : :obj:`int`, optional
        Seed for random number generator. If `seed` is ``None``, `seed` is set randomly.

    Warning
    -------
    May give inaccurate results if chosen parameters make the covariance matrix
    not positive semi-definite. Covariance matrix is likely to not be positive
    semi-definite when :math:`p_{TT}` is small and rg is large.

    Returns
    -------
    mt : :class:`.MatrixTable`
        :class:`.MatrixTable` with simulated SNP effects as a row field of arrays.
    pi : :obj:`list` or :class:`numpy.ndarray`
        List of proportion of SNPs: :math:`p_{TT}`, :math:`p_{TF}`, :math:`p_{FT}`.
        Possibly altered if covariance matrix of traits was not positive semi-definite.
    rg : :obj:`list`
        Genetic correlation between traits, possibly altered from input `rg` if
        covariance matrix was not positive semi-definite.
    """
    assert sum(pi) <= 1, "probabilities of being causal must sum to at most 1"
    seed = seed if seed is not None else int(str(Env.next_seed())[:8])
    ptt, ptf, pft, pff = pi[0], pi[1], pi[2], 1 - sum(pi)
    cov_matrix = np.asarray([[1 / (ptt + ptf), rg / ptt],
                             [rg / ptt, 1 / (ptt + pft)]])
    M = mt.count_rows()
    # seed random state for replicability
    randstate = np.random.RandomState(int(seed))
    if np.any(np.linalg.eigvals(cov_matrix) < 0):
        print(
            'adjusting parameters to make covariance matrix positive semidefinite'
        )
        rg0, ptt0 = rg, ptt
        while np.any(np.linalg.eigvals(cov_matrix) < 0
                     ):  # check positive semidefinite
            rg = round(0.99 * rg, 6)
            ptt = round(ptt + (pff) * 0.001, 6)
            cov_matrix = np.asarray([[1 / (ptt + ptf), rg / ptt],
                                     [rg / ptt, 1 / (ptt + pft)]])
        pff0, pff = pff, 1 - sum([ptt, ptf, pft])
        print(f'rg: {rg0} -> {rg}\nptt: {ptt0} -> {ptt}\npff: {pff0} -> {pff}')
        pi = [ptt, ptf, pft, pff]
    beta = randstate.multivariate_normal(mean=np.zeros(2),
                                         cov=cov_matrix,
                                         size=[
                                             int(M),
                                         ])
    zeros = np.zeros(shape=int(M)).T
    beta_matrix = np.stack(
        (beta, np.asarray([beta[:, 0], zeros]).T, np.asarray(
            [zeros, zeros]).T, np.asarray([zeros, beta[:, 1]]).T),
        axis=1)
    # use the seeded random state so the causal-SNP assignment is reproducible
    idx = randstate.choice(a=[0, 1, 2, 3], size=int(M), p=[ptt, ptf, pft, pff])
    betas = beta_matrix[range(int(M)), idx, :]
    betas[:, 0] *= (h2[0] / M)**(1 / 2)
    betas[:, 1] *= (h2[1] / M)**(1 / 2)
    df = pd.DataFrame([0] * M, columns=['beta'])
    tb = hl.Table.from_pandas(df)
    tb = tb.add_index().key_by('idx')
    tb = tb.annotate(beta=hl.literal(betas.tolist())[hl.int32(tb.idx)])
    mt = mt.add_row_index()
    mt = mt.annotate_rows(beta=tb[mt.row_idx]['beta'])
    return mt, pi, [rg]
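A standalone sketch of the warning above: with small p_TT and large rg, the off-diagonal term rg / ptt dominates and the covariance matrix is no longer positive semi-definite (the parameter values here are illustrative):

import numpy as np

ptt, ptf, pft, rg = 0.01, 0.1, 0.1, 0.9
cov = np.asarray([[1 / (ptt + ptf), rg / ptt],
                  [rg / ptt, 1 / (ptt + pft)]])
print(np.linalg.eigvals(cov))  # one eigenvalue is negative, so not PSD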
def compute_quantile_bin(
    ht: hl.Table,
    score_expr: hl.expr.NumericExpression,
    bin_expr: Dict[str, hl.expr.BooleanExpression] = {"bin": True},
    compute_snv_indel_separately: bool = True,
    n_bins: int = 100,
    k: int = 1000,
    desc: bool = True,
) -> hl.Table:
    """
    Returns a table with a bin for each row based on quantiles of `score_expr`.

    The bin is computed by dividing the `score_expr` into `n_bins` bins containing an equal number of elements.
    This is done based on quantiles computed with hl.agg.approx_quantiles. If a single value in `score_expr` spans more
    than one bin, the rows with this value are distributed randomly across the bins it spans.

    If `compute_snv_indel_separately` is True, all items in `bin_expr` will be stratified by SNVs / indels for the bin
    calculation. Because SNV and indel rows are mutually exclusive, they are re-combined into a single annotation. For
    example, if we have the following four variants and scores and `n_bins` of 2:

    ========   =======   ======   =================   =================
    Variant    Type      Score    bin - `compute_snv_indel_separately`:
    --------   -------   ------   -------------------------------------
    \          \         \        False               True
    ========   =======   ======   =================   =================
    Var1       SNV       0.1      1                   1
    Var2       SNV       0.2      1                   2
    Var3       Indel     0.3      2                   1
    Var4       Indel     0.4      2                   2
    ========   =======   ======   =================   =================

    .. note::

        The `bin_expr` defines which data the bin(s) should be computed on. E.g., to get a biallelic quantile bin and a
        singleton quantile bin, the following could be used:

        .. code-block:: python

            bin_expr={
                'biallelic_bin': ~ht.was_split,
                'singleton_bin': ht.singleton
            }

    :param ht: Input Table
    :param score_expr: Expression containing the score
    :param bin_expr: Quantile bin(s) to be computed (see notes)
    :param compute_snv_indel_separately: Should all `bin_expr` items be stratified by SNVs / indels
    :param n_bins: Number of bins to bin the data into
    :param k: The `k` parameter of approx_quantiles
    :param desc: Whether to bin the score in descending order
    :return: Table with the quantile bins
    """
    import math

    def quantiles_to_bin_boundaries(quantiles: List[int]) -> Dict:
        """
        Merges bins with the same boundaries into a unique bin while keeping track of
        which bins have been merged and the global index of all bins.

        :param quantiles: Original bin boundaries
        :return: Dict with 'merged_bins' (indices of bins for which multiple bins were
                 collapsed -> number of bins collapsed), 'global_bin_indices' (global
                 indices of merged bins), and 'bin_boundaries' (merged bin boundaries)
        """

        # Pad the quantiles to create boundaries for the first and last bins
        bin_boundaries = [-math.inf] + quantiles + [math.inf]
        merged_bins = defaultdict(int)

        # If every quantile has a unique value, then bin boundaries are unique
        # and can be passed to binary_search as-is
        if len(quantiles) == len(set(quantiles)):
            return dict(
                merged_bins=merged_bins,
                global_bin_indices=list(range(len(bin_boundaries))),
                bin_boundaries=bin_boundaries,
            )

        indexed_bins = list(enumerate(bin_boundaries))
        i = 1
        while i < len(indexed_bins):
            if indexed_bins[i - 1][1] == indexed_bins[i][1]:
                merged_bins[i - 1] += 1
                indexed_bins.pop(i)
            else:
                i += 1

        return dict(
            merged_bins=merged_bins,
            global_bin_indices=[x[0] for x in indexed_bins],
            bin_boundaries=[x[1] for x in indexed_bins],
        )

    if compute_snv_indel_separately:
        # For each bin, add a SNV / indel stratification
        bin_expr = {
            f"{bin_id}_{snv}": (bin_expr & snv_expr)
            for bin_id, bin_expr in bin_expr.items() for snv, snv_expr in [
                ("snv", hl.is_snp(ht.alleles[0], ht.alleles[1])),
                ("indel", ~hl.is_snp(ht.alleles[0], ht.alleles[1])),
            ]
        }
        print("ADSADSADASDAS")
        print(bin_expr)

    bin_ht = ht.annotate(
        **{
            f"_filter_{bin_id}": bin_expr
            for bin_id, bin_expr in bin_expr.items()
        },
        _score=score_expr,
        snv=hl.is_snp(ht.alleles[0], ht.alleles[1]),
    )
    logger.info(
        f"Adding quantile bins using approx_quantiles, binned into {n_bins} bins with k={k}"
    )
    bin_stats = bin_ht.aggregate(
        hl.struct(
            **{
                bin_id: hl.agg.filter(
                    bin_ht[f"_filter_{bin_id}"],
                    hl.struct(
                        n=hl.agg.count(),
                        quantiles=hl.agg.approx_quantiles(
                            bin_ht._score,
                            [x / (n_bins) for x in range(1, n_bins)],
                            k=k),
                    ),
                )
                for bin_id in bin_expr
            }))

    # Take care of bins with duplicated boundaries
    bin_stats = bin_stats.annotate(
        **{
            rname: bin_stats[rname].annotate(
                **quantiles_to_bin_boundaries(bin_stats[rname].quantiles))
            for rname in bin_stats
        })

    bin_ht = bin_ht.annotate_globals(bin_stats=hl.literal(
        bin_stats,
        dtype=hl.tstruct(
            **{
                bin_id: hl.tstruct(
                    n=hl.tint64,
                    quantiles=hl.tarray(hl.tfloat64),
                    bin_boundaries=hl.tarray(hl.tfloat64),
                    global_bin_indices=hl.tarray(hl.tint32),
                    merged_bins=hl.tdict(hl.tint32, hl.tint32),
                )
                for bin_id in bin_expr
            }),
    ))

    # Annotate the bin as the index in the unique boundaries array
    bin_ht = bin_ht.annotate(
        **{
            bin_id: hl.or_missing(
                bin_ht[f"_filter_{bin_id}"],
                hl.binary_search(bin_ht.bin_stats[bin_id].bin_boundaries,
                                 bin_ht._score),
            )
            for bin_id in bin_expr
        })

    # Convert the bin to global bin by expanding merged bins, that is:
    # If a value falls in a bin that needs expansion, assign it randomly to one of the expanded bins
    # Otherwise, simply modify the bin to its global index (with expanded bins that is)
    bin_ht = bin_ht.select(
        "snv",
        **{
            bin_id: hl.if_else(
                bin_ht.bin_stats[bin_id].merged_bins.contains(bin_ht[bin_id]),
                bin_ht.bin_stats[bin_id].global_bin_indices[bin_ht[bin_id]] +
                hl.int(
                    hl.rand_unif(
                        0, bin_ht.bin_stats[bin_id].merged_bins[bin_ht[bin_id]]
                        + 1)),
                bin_ht.bin_stats[bin_id].global_bin_indices[bin_ht[bin_id]],
            )
            for bin_id in bin_expr
        },
    )

    if desc:
        bin_ht = bin_ht.annotate(
            **{bin_id: n_bins - bin_ht[bin_id]
               for bin_id in bin_expr})

    # Because SNV and indel rows are mutually exclusive, re-combine them into a single bin.
    # Update the global bin_stats struct to reflect the change in bin names in the table
    if compute_snv_indel_separately:
        bin_expr_no_snv = {
            bin_id.rsplit("_", 1)[0]
            for bin_id in bin_ht.bin_stats
        }
        bin_ht = bin_ht.annotate_globals(bin_stats=hl.struct(
            **{
                bin_id: hl.struct(
                    **{
                        snv: bin_ht.bin_stats[f"{bin_id}_{snv}"]
                        for snv in ["snv", "indel"]
                    })
                for bin_id in bin_expr_no_snv
            }))

        bin_ht = bin_ht.transmute(
            **{
                bin_id: hl.if_else(
                    bin_ht.snv,
                    bin_ht[f"{bin_id}_snv"],
                    bin_ht[f"{bin_id}_indel"],
                )
                for bin_id in bin_expr_no_snv
            })

    return bin_ht
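A hypothetical invocation of compute_quantile_bin, assuming `ht` carries `score` and `was_split` row fields (both names are illustrative): bin the score into deciles, overall and for biallelic sites, each stratified by SNV / indel by default:

binned = compute_quantile_bin(
    ht,
    score_expr=ht.score,
    bin_expr={'bin': True, 'biallelic_bin': ~ht.was_split},
    n_bins=10,
)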
Example n. 25
0
def _to_expr(e, dtype):
    if e is None:
        return None
    elif isinstance(e, Expression):
        if e.dtype != dtype:
            assert is_numeric(dtype), 'expected {}, got {}'.format(
                dtype, e.dtype)
            if dtype == tfloat64:
                return hl.float64(e)
            elif dtype == tfloat32:
                return hl.float32(e)
            elif dtype == tint64:
                return hl.int64(e)
            else:
                assert dtype == tint32
                return hl.int32(e)
        return e
    elif not is_compound(dtype):
        # these are not container types and cannot contain expressions if we got here
        return e
    elif isinstance(dtype, tstruct):
        new_fields = []
        found_expr = False
        for f, t in dtype.items():
            value = _to_expr(e[f], t)
            found_expr = found_expr or isinstance(value, Expression)
            new_fields.append(value)

        if not found_expr:
            return e
        else:
            exprs = [
                new_fields[i] if isinstance(new_fields[i], Expression) else
                hl.literal(new_fields[i], dtype[i])
                for i in range(len(new_fields))
            ]
            fields = {name: expr for name, expr in zip(dtype.keys(), exprs)}
            from .typed_expressions import StructExpression
            return StructExpression._from_fields(fields)

    elif isinstance(dtype, tarray):
        elements = []
        found_expr = False
        for element in e:
            value = _to_expr(element, dtype.element_type)
            found_expr = found_expr or isinstance(value, Expression)
            elements.append(value)
        if not found_expr:
            return e
        else:
            assert (len(elements) > 0)
            exprs = [
                element if isinstance(element, Expression) else hl.literal(
                    element, dtype.element_type) for element in elements
            ]
            indices, aggregations = unify_all(*exprs)
            x = ir.MakeArray([e._ir for e in exprs], None)
            return expressions.construct_expr(x, dtype, indices, aggregations)
    elif isinstance(dtype, tset):
        elements = []
        found_expr = False
        for element in e:
            value = _to_expr(element, dtype.element_type)
            found_expr = found_expr or isinstance(value, Expression)
            elements.append(value)
        if not found_expr:
            return e
        else:
            assert (len(elements) > 0)
            exprs = [
                element if isinstance(element, Expression) else hl.literal(
                    element, dtype.element_type) for element in elements
            ]
            indices, aggregations = unify_all(*exprs)
            x = ir.ToSet(
                ir.ToStream(ir.MakeArray([e._ir for e in exprs], None)))
            return expressions.construct_expr(x, dtype, indices, aggregations)
    elif isinstance(dtype, ttuple):
        elements = []
        found_expr = False
        assert len(e) == len(dtype.types)
        for i in range(len(e)):
            value = _to_expr(e[i], dtype.types[i])
            found_expr = found_expr or isinstance(value, Expression)
            elements.append(value)
        if not found_expr:
            return e
        else:
            exprs = [
                elements[i] if isinstance(elements[i], Expression) else
                hl.literal(elements[i], dtype.types[i])
                for i in range(len(elements))
            ]
            indices, aggregations = unify_all(*exprs)
            x = ir.MakeTuple([expr._ir for expr in exprs])
            return expressions.construct_expr(x, dtype, indices, aggregations)
    elif isinstance(dtype, tdict):
        keys = []
        values = []
        found_expr = False
        for k, v in e.items():
            k_ = _to_expr(k, dtype.key_type)
            v_ = _to_expr(v, dtype.value_type)
            found_expr = found_expr or isinstance(k_, Expression)
            found_expr = found_expr or isinstance(v_, Expression)
            keys.append(k_)
            values.append(v_)
        if not found_expr:
            return e
        else:
            assert len(keys) > 0
            # Here I use `to_expr` to call `lit` on the keys and values separately.
            # I anticipate a common mode is statically-known keys and Expression
            # values.
            key_array = to_expr(keys, tarray(dtype.key_type))
            value_array = to_expr(values, tarray(dtype.value_type))
            return hl.dict(hl.zip(key_array, value_array))
    elif isinstance(dtype, hl.tndarray):
        return hl.nd.array(e)
    else:
        raise NotImplementedError(dtype)
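A small sketch of what the dispatch above enables: Python containers whose leaves mix plain values and expressions are lifted into a single expression, with plain leaves passing through hl.literal in the matching branch:

import hail as hl

x = hl.int32(5)
s = hl.struct(a=x, b=[1, 2, 3])  # tstruct branch: the plain list is lifted to a literal
d = hl.dict({'n': x, 'm': 7})    # tdict branch: keys and values are lifted separately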
Example n. 26
0
def import_gtf(path, reference_genome=None, skip_invalid_contigs=False, min_partitions=None) -> hl.Table:
    """Import a GTF file.

       The GTF file format is identical to the GFF version 2 file format,
       and so this function can be used to import GFF version 2 files as
       well.

       See https://www.ensembl.org/info/website/upload/gff.html for more
       details on the GTF/GFF2 file format.

       The :class:`.Table` returned by this function will be keyed by the
       ``interval`` row field and will include the following row fields:

       .. code-block:: text

           'source': str
           'feature': str
           'score': float64
           'strand': str
           'frame': int32
           'interval': interval<>

       There will also be corresponding fields for every tag found in the
       attribute field of the GTF file.

       Note
       ----

       This function will return an ``interval`` field of type :class:`.tinterval`
       constructed from the ``seqname``, ``start``, and ``end`` fields in the
       GTF file. This interval is inclusive of both the start and end positions
       in the GTF file. 

       If the ``reference_genome`` parameter is specified, the start and end
       points of the ``interval`` field will be of type :class:`.tlocus`.
       Otherwise, the start and end points of the ``interval`` field will be of
       type :class:`.tstruct` with fields ``seqname`` (type :class:`str`) and
       ``position`` (type :class:`.tint32`).

       Furthermore, if the ``reference_genome`` parameter is specified and
       ``skip_invalid_contigs`` is ``True``, this import function will skip
       lines in the GTF where ``seqname`` is not consistent with the reference
       genome specified.

       Example
       -------

       >>> ht = hl.experimental.import_gtf('data/test.gtf', 
       ...                                 reference_genome='GRCh37',
       ...                                 skip_invalid_contigs=True)

       >>> ht.describe()  # doctest: +NOTEST
       ----------------------------------------
       Global fields:
       None
       ----------------------------------------
       Row fields:
           'source': str
           'feature': str
           'score': float64
           'strand': str
           'frame': int32
           'gene_type': str
           'exon_id': str
           'havana_transcript': str
           'level': str
           'transcript_name': str
           'gene_status': str
           'gene_id': str
           'transcript_type': str
           'tag': str
           'transcript_status': str
           'gene_name': str
           'transcript_id': str
           'exon_number': str
           'havana_gene': str
           'interval': interval<locus<GRCh37>>
       ----------------------------------------
       Key: ['interval']
       ----------------------------------------

       Parameters
       ----------

       path : :obj:`str`
           File to import.
       reference_genome : :obj:`str` or :class:`.ReferenceGenome`, optional
           Reference genome to use.
       skip_invalid_contigs : :obj:`bool`
           If ``True`` and `reference_genome` is not ``None``, skip lines where
           ``seqname`` is not consistent with the reference genome.
       min_partitions : :obj:`int` or :obj:`None`
           Minimum number of partitions (passed to import_table).

       Returns
       -------
       :class:`.Table`
       """

    ht = hl.import_table(path,
                         min_partitions=min_partitions,
                         comment='#',
                         no_header=True,
                         types={'f3': hl.tint,
                                'f4': hl.tint,
                                'f5': hl.tfloat,
                                'f7': hl.tint},
                         missing='.',
                         delimiter='\t')

    ht = ht.rename({'f0': 'seqname',
                    'f1': 'source',
                    'f2': 'feature',
                    'f3': 'start',
                    'f4': 'end',
                    'f5': 'score',
                    'f6': 'strand',
                    'f7': 'frame',
                    'f8': 'attribute'})

    ht = ht.annotate(attribute=hl.dict(
        hl.map(lambda x: (x.split(' ')[0],
                          x.split(' ')[1].replace('"', '').replace(';$', '')),
               ht['attribute'].split('; '))))

    attributes = ht.aggregate(hl.agg.explode(lambda x: hl.agg.collect_as_set(x), ht['attribute'].keys()))

    ht = ht.transmute(**{x: hl.or_missing(ht['attribute'].contains(x),
                                          ht['attribute'][x])
                         for x in attributes if x})

    if reference_genome:
        if reference_genome == 'GRCh37':
            ht = ht.annotate(seqname=ht['seqname'].replace('^chr', ''))
        else:
            ht = ht.annotate(seqname=hl.case()
                                       .when(ht['seqname'].startswith('HLA'), ht['seqname'])
                                       .when(ht['seqname'].startswith('chrHLA'), ht['seqname'].replace('^chr', ''))
                                       .when(ht['seqname'].startswith('chr'), ht['seqname'])
                                       .default('chr' + ht['seqname']))
        if skip_invalid_contigs:
            valid_contigs = hl.literal(set(hl.get_reference(reference_genome).contigs))
            ht = ht.filter(valid_contigs.contains(ht['seqname']))
        ht = ht.transmute(interval=hl.locus_interval(ht['seqname'],
                                                     ht['start'],
                                                     ht['end'],
                                                     includes_start=True,
                                                     includes_end=True,
                                                     reference_genome=reference_genome))
    else:
        ht = ht.transmute(interval=hl.interval(hl.struct(seqname=ht['seqname'], position=ht['start']),
                                               hl.struct(seqname=ht['seqname'], position=ht['end']),
                                               includes_start=True,
                                               includes_end=True))

    ht = ht.key_by('interval')

    return ht
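A hedged follow-up sketch: keep only 'gene' features and pull out gene symbols. The `gene_name` field exists only if the GTF's attribute column defines that tag:

ht = hl.experimental.import_gtf('data/test.gtf',
                                reference_genome='GRCh37',
                                skip_invalid_contigs=True)
genes = ht.filter(ht.feature == 'gene')
genes.select('gene_name').show(5)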
Example n. 27
0
    def test_literals_rebuild(self):
        mt = hl.utils.range_matrix_table(1, 1)
        mt = mt.annotate_rows(x=hl.cond(hl.len(hl.literal([1, 2, 3])) < hl.rand_unif(10, 11), mt.globals, hl.struct()))
        mt._force_count_rows()
Example n. 29
0
def explode_by_p_threshold(mt):
    mt = mt.annotate_cols(p_threshold=hl.literal(list(
        P_THRESHOLDS.items()))).explode_cols('p_threshold')
    mt = mt.transmute_cols(p_threshold_name=mt.p_threshold[0],
                           p_threshold=mt.p_threshold[1])
    return mt
def import_mnv_file(path, **kwargs):
    column_types = {
        "AC_mnv_ex": hl.tint,
        "AC_mnv_gen": hl.tint,
        "AC_mnv": hl.tint,
        "AC_snp1_ex": hl.tint,
        "AC_snp1_gen": hl.tint,
        "AC_snp1": hl.tint,
        "AC_snp2_ex": hl.tint,
        "AC_snp2_gen": hl.tint,
        "AC_snp2": hl.tint,
        "AN_snp1_ex": hl.tfloat,
        "AN_snp1_gen": hl.tfloat,
        "AN_snp2_ex": hl.tfloat,
        "AN_snp2_gen": hl.tfloat,
        "categ": hl.tstr,
        "filter_snp1_ex": hl.tarray(hl.tstr),
        "filter_snp1_gen": hl.tarray(hl.tstr),
        "filter_snp2_ex": hl.tarray(hl.tstr),
        "filter_snp2_gen": hl.tarray(hl.tstr),
        "gene_id": hl.tstr,
        "gene_name": hl.tstr,
        "locus.contig": hl.tstr,
        "locus.position": hl.tint,
        "mnv_amino_acids": hl.tstr,
        "mnv_codons": hl.tstr,
        "mnv_consequence": hl.tstr,
        "mnv_lof": hl.tstr,
        "mnv": hl.tstr,
        "n_homhom_ex": hl.tint,
        "n_homhom_gen": hl.tint,
        "n_homhom": hl.tint,
        "n_indv_ex": hl.tint,
        "n_indv_gen": hl.tint,
        "n_indv": hl.tint,
        "snp1_amino_acids": hl.tstr,
        "snp1_codons": hl.tstr,
        "snp1_consequence": hl.tstr,
        "snp1_lof": hl.tstr,
        "snp1": hl.tstr,
        "snp2_amino_acids": hl.tstr,
        "snp2_codons": hl.tstr,
        "snp2_consequence": hl.tstr,
        "snp2_lof": hl.tstr,
        "snp2": hl.tstr,
        "transcript_id": hl.tstr,
    }

    ds = hl.import_table(path,
                         key="mnv",
                         missing="",
                         types=column_types,
                         **kwargs)

    ds = ds.transmute(locus=hl.locus(ds["locus.contig"], ds["locus.position"]))

    ds = ds.transmute(
        contig=normalized_contig(ds.locus),
        pos=ds.locus.position,
        xpos=x_position(ds.locus),
    )

    ds = ds.annotate(ref=ds.mnv.split("-")[2],
                     alt=ds.mnv.split("-")[3],
                     variant_id=ds.mnv)

    ds = ds.annotate(snp1_copy=ds.snp1, snp2_copy=ds.snp2)
    ds = ds.transmute(constituent_snvs=[
        hl.bind(
            lambda variant_id_parts: hl.struct(
                variant_id=ds[f"{snp}_copy"],
                chrom=variant_id_parts[0],
                pos=hl.int(variant_id_parts[1]),
                ref=variant_id_parts[2],
                alt=variant_id_parts[3],
                exome=hl.or_missing(
                    hl.is_defined(ds[f"AN_{snp}_ex"]),
                    hl.struct(
                        filters=ds[f"filter_{snp}_ex"],
                        ac=ds[f"AC_{snp}_ex"],
                        an=hl.int(ds[f"AN_{snp}_ex"]),
                    ),
                ),
                genome=hl.or_missing(
                    hl.is_defined(ds[f"AN_{snp}_gen"]),
                    hl.struct(
                        filters=ds[f"filter_{snp}_gen"],
                        ac=ds[f"AC_{snp}_gen"],
                        an=hl.int(ds[f"AN_{snp}_gen"]),
                    ),
                ),
            ),
            ds[f"{snp}_copy"].split("-"),
        ) for snp in ["snp1", "snp2"]
    ])

    ds = ds.annotate(constituent_snv_ids=[ds.snp1, ds.snp2])

    ds = ds.annotate(
        mnv_in_exome=ds.constituent_snvs.all(lambda s: hl.is_defined(s.exome)),
        mnv_in_genome=ds.constituent_snvs.all(
            lambda s: hl.is_defined(s.genome)),
    )

    ds = ds.transmute(
        n_individuals=ds.n_indv,
        ac=ds.AC_mnv,
        ac_hom=ds.n_homhom,
        exome=hl.or_missing(
            ds.mnv_in_exome,
            hl.struct(n_individuals=ds.n_indv_ex,
                      ac=ds.AC_mnv_ex,
                      ac_hom=ds.n_homhom_ex),
        ),
        genome=hl.or_missing(
            ds.mnv_in_genome,
            hl.struct(n_individuals=ds.n_indv_gen,
                      ac=ds.AC_mnv_gen,
                      ac_hom=ds.n_homhom_gen),
        ),
    )

    ds = ds.drop("AC_snp1", "AC_snp2")

    ds = ds.transmute(consequence=hl.struct(
        category=ds.categ,
        gene_id=ds.gene_id,
        gene_name=ds.gene_name,
        transcript_id=ds.transcript_id,
        consequence=ds.mnv_consequence,
        codons=ds.mnv_codons,
        amino_acids=ds.mnv_amino_acids,
        lof=ds.mnv_lof,
        snv_consequences=[
            hl.struct(
                variant_id=ds[f"{snp}"],
                amino_acids=ds[f"{snp}_amino_acids"],
                codons=ds[f"{snp}_codons"],
                consequence=ds[f"{snp}_consequence"],
                lof=ds[f"{snp}_lof"],
            ) for snp in ["snp1", "snp2"]
        ],
    ))

    # Collapse table to one row per MNV, with all consequences for the MNV collected into an array
    consequences = ds.group_by(
        ds.mnv).aggregate(consequences=hl.agg.collect(ds.consequence))
    ds = ds.drop("consequence")
    ds = ds.distinct()
    ds = ds.join(consequences)

    # Sort consequences by severity
    ds = ds.annotate(consequences=hl.sorted(
        ds.consequences,
        key=lambda c: consequence_term_rank(c.consequence),
    ))

    ds = ds.annotate(changes_amino_acids_for_snvs=hl.literal([0, 1]).filter(
        lambda idx: ds.consequences.any(lambda csq: csq.snv_consequences[
            idx].amino_acids.lower() != csq.amino_acids.lower())).map(
                lambda idx: ds.constituent_snv_ids[idx]))

    return ds
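A sketch of the 'chrom-pos-ref-alt' variant ID convention that the parser above relies on when it calls ds.mnv.split("-") and indexes parts 2 and 3; the ID used here is illustrative:

parts = '1-55516888-G-GA'.split('-')
chrom, pos, ref, alt = parts[0], int(parts[1]), parts[2], parts[3]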
Example n. 31
0
File: qc.py Project: troels/hail
def concordance(
        left,
        right,
        *,
        _localize_global_statistics=True
) -> Tuple[List[List[int]], Table, Table]:
    """Calculate call concordance with another dataset.

    .. include:: ../_templates/req_tstring.rst

    .. include:: ../_templates/req_tvariant.rst

    .. include:: ../_templates/req_biallelic.rst

    .. include:: ../_templates/req_unphased_diploid_gt.rst

    Examples
    --------

    Compute concordance between two datasets and output the global concordance
    statistics and two tables with concordance computed per column key and per
    row key:

    >>> global_conc, cols_conc, rows_conc = hl.concordance(dataset, dataset2)

    Notes
    -----

    This method computes the genotype call concordance (from the entry
    field **GT**) between two biallelic variant datasets.  It requires
    unique sample IDs and performs an inner join on samples (only
    samples in both datasets will be considered). In addition, all genotype
    calls must be **diploid** and **unphased**.

    It performs an ordered zip join of the variants.  That means the
    variants of each dataset are sorted, with duplicate variants
    appearing in some random relative order, and then zipped together.
    When a variant appears a different number of times between the two
    datasets, the dataset with the fewer number of instances is padded
    with "no data".  For example, if a variant is only in one dataset,
    then each genotype is treated as "no data" in the other.

    This method returns a tuple of three objects: a nested list of
    list of int with global concordance summary statistics, a table
    with concordance statistics per column key, and a table with
    concordance statistics per row key.

    **Using the global summary result**

    The global summary is a list of list of int (conceptually a 5 by 5 matrix),
    where the indices have special meaning:

    0. No Data (missing variant)
    1. No Call (missing genotype call)
    2. Hom Ref
    3. Heterozygous
    4. Hom Var

    The first index is the state in the left dataset and the second index is
    the state in the right dataset. Typical uses of the summary list are shown
    below.

    >>> summary, samples, variants = hl.concordance(dataset, dataset2)
    >>> left_homref_right_homvar = summary[2][4]
    >>> left_het_right_missing = summary[3][1]
    >>> left_het_right_something_else = sum(summary[3][:]) - summary[3][3]
    >>> total_concordant = summary[2][2] + summary[3][3] + summary[4][4]
    >>> total_discordant = sum([sum(s[2:]) for s in summary[2:]]) - total_concordant

    **Using the table results**

    Table 1: Concordance statistics by column

    This table contains the column key field of `left`, and the following fields:

        - `n_discordant` (:py:data:`.tint64`) -- Count of discordant calls (see below for
          full definition).
        - `concordance` (:class:`.tarray` of :class:`.tarray` of :py:data:`.tint64`) --
          Array of concordance per state on left and right, matching the structure of
          the global summary defined above.

    Table 2: Concordance statistics by row

    This table contains the row key fields of `left`, and the following fields:

        - `n_discordant` (:py:data:`.tfloat64`) -- Count of discordant calls (see below for
          full definition).
        - `concordance` (:class:`.tarray` of :class:`.tarray` of :py:data:`.tint64`) --
          Array of concordance per state on left and right, matching the structure of the
          global summary defined above.

    In these tables, the column **n_discordant** is provided as a convenience,
    because this is often one of the most useful concordance statistics. This
    value is the number of genotypes which were called (homozygous reference,
    heterozygous, or homozygous variant) in both datasets, but where the call
    did not match between the two.

    The column `concordance` matches the structure of the global summary,
    which is detailed above. Once again, the first index into this array is the
    state on the left, and the second index is the state on the right. For
    example, ``concordance[1][4]`` is the number of "no call" genotypes on the
    left that were called homozygous variant on the right.

    Parameters
    ----------
    left : :class:`.MatrixTable`
        First dataset to compare.
    right : :class:`.MatrixTable`
        Second dataset to compare.

    Returns
    -------
    (list of list of int, :class:`.Table`, :class:`.Table`)
        The global concordance statistics, a table with concordance statistics
        per column key, and a table with concordance statistics per row key.

    """

    require_col_key_str(left, 'concordance, left')
    require_col_key_str(right, 'concordance, right')

    left_sample_counter = left.aggregate_cols(hl.agg.counter(left.col_key[0]))
    right_sample_counter = right.aggregate_cols(
        hl.agg.counter(right.col_key[0]))

    left_bad = [f'{k!r}: {v}' for k, v in left_sample_counter.items() if v > 1]
    right_bad = [
        f'{k!r}: {v}' for k, v in right_sample_counter.items() if v > 1
    ]
    if left_bad or right_bad:
        raise ValueError(f"Found duplicate sample IDs:\n"
                         f"  left:  {', '.join(left_bad)}\n"
                         f"  right: {', '.join(right_bad)}")

    included = set(left_sample_counter.keys()).intersection(
        set(right_sample_counter.keys()))

    info(
        f"concordance: including {len(included)} shared samples "
        f"({len(left_sample_counter)} total on left, {len(right_sample_counter)} total on right)"
    )

    left = require_biallelic(left, 'concordance, left')
    right = require_biallelic(right, 'concordance, right')

    lit = hl.literal(included, dtype=hl.tset(hl.tstr))
    left = left.filter_cols(lit.contains(left.col_key[0]))
    right = right.filter_cols(lit.contains(right.col_key[0]))

    left = left.select_entries('GT').select_rows().select_cols()
    right = right.select_entries('GT').select_rows().select_cols()

    joined = hl.experimental.full_outer_join_mt(left, right)

    def get_idx(struct):
        return hl.cond(hl.is_missing(struct), 0,
                       hl.coalesce(2 + struct.GT.n_alt_alleles(), 1))

    aggr = hl.agg.counter(
        get_idx(joined.left_entry) + 5 * get_idx(joined.right_entry))

    def concordance_array(counter):
        return hl.range(0, 5).map(
            lambda i: hl.range(0, 5).map(lambda j: counter.get(i + 5 * j, 0)))

    def n_discordant(counter):
        return hl.sum(
            hl.array(counter).filter(lambda tup: ~hl.literal(
                {i**2
                 for i in range(5)}).contains(tup[0])).map(lambda tup: tup[1]))

    glob = joined.aggregate_entries(concordance_array(aggr),
                                    _localize=_localize_global_statistics)
    if _localize_global_statistics:
        total_conc = [x[1:] for x in glob[1:]]
        on_diag = sum(total_conc[i][i] for i in range(len(total_conc)))
        total_obs = sum(sum(x) for x in total_conc)
        info(f"concordance: total concordance {on_diag/total_obs * 100:.2f}%")

    per_variant = joined.annotate_rows(concordance=aggr)
    per_variant = per_variant.annotate_rows(
        concordance=concordance_array(per_variant.concordance),
        n_discordant=n_discordant(per_variant.concordance))
    per_sample = joined.annotate_cols(concordance=aggr)
    per_sample = per_sample.annotate_cols(
        concordance=concordance_array(per_sample.concordance),
        n_discordant=n_discordant(per_sample.concordance))

    return glob, per_sample.cols(), per_variant.rows()
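A hedged post-processing sketch for the per-sample table: the number of concordant called genotypes per sample is the sum of the diagonal cells for Hom Ref, Het, and Hom Var (indices 2, 3, and 4):

summary, samples, variants = hl.concordance(dataset, dataset2)
samples = samples.annotate(
    n_concordant=samples.concordance[2][2]
    + samples.concordance[3][3]
    + samples.concordance[4][4])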
Example n. 32
0
gwas = hl.logistic_regression_rows(test='wald',
                                   y=final.pheno.All_Pneumonia,
                                   x=final.GT.n_alt_alleles(),
                                   covariates=[1, final.pheno.age, final.pheno.age2, final.pheno.Sex_numeric, final.pheno.ever_smoked, final.pheno.PC1, final.pheno.PC2, final.pheno.PC3, final.pheno.PC4, final.pheno.PC5, final.pheno.PC6, final.pheno.PC7, final.pheno.PC8, final.pheno.PC9, final.pheno.PC10, final.array],
                                   pass_through=['rsid', 'Gene', 'Consequence', 'clin_sig', 'metasvm', 'LOF_LOFTEE', 'PolyPhen', 'SIFT', 'hgvsp', 'AF', 'AC', 'AN', 'info'])

### Writing out the annotated GWAS results:
gwas.flatten().export('gs://ukbb_v2/projects/mzekavat/Pneumonia_GWAS/logreg_wald_All_Pneumonia.tsv.bgz')
gwas.write('gs://ukbb_v2/projects/mzekavat/Pneumonia_GWAS/logreg_wald_All_Pneumonia.ht')
gwas = hl.read_table('gs://ukbb_v2/projects/mzekavat/Pneumonia_GWAS/logreg_wald_All_Pneumonia.ht')
gwas_v2 = gwas.filter(gwas.p_value < 0.0001, keep=True)


### Filtering the pneumonia GWAS to just the SBP and DBP SNPs:
SBPsnps = hl.import_table('gs://ukbb_v2/projects/mzekavat/Pneumonia_GWAS/SBP_75SNP_instrument_hg37.txt', impute=True)
DBPsnps = hl.import_table('gs://ukbb_v2/projects/mzekavat/Pneumonia_GWAS/DBP_75SNP_instrument_hg37.txt', impute=True)

SBP_SNPs = [row['SNP'] for row in SBPsnps.select(SBPsnps.SNP).collect()]
gwas_SBPvar = gwas.filter(hl.literal(SBP_SNPs).contains(gwas.rsid), keep=True)

DBP_SNPs = [row['SNP'] for row in DBPsnps.select(DBPsnps.SNP).collect()]
gwas_DBPvar = gwas.filter(hl.literal(DBP_SNPs).contains(gwas.rsid), keep=True)


gwas_SBPvar.export('gs://ukbb_v2/projects/mzekavat/Pneumonia_GWAS/gwas_SBPvar.logreg_wald_All_Pneumonia.tsv.bgz')
gwas_DBPvar.export('gs://ukbb_v2/projects/mzekavat/Pneumonia_GWAS/gwas_DBPvar.logreg_wald_All_Pneumonia.tsv.bgz')
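The collect-then-hl.literal filter above works, but the same restriction can be written as a key join, which avoids pulling the SNP IDs to the driver (a sketch, assuming the 'SNP' column holds rsids):

SBPsnps = SBPsnps.key_by('SNP')
gwas_SBPvar = gwas.filter(hl.is_defined(SBPsnps[gwas.rsid]))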


Example n. 33
0
def default_generate_gene_lof_matrix(
    mt: hl.MatrixTable,
    tx_ht: Optional[hl.Table],
    high_expression_cutoff: float = 0.9,
    low_expression_cutoff: float = 0.1,
    filter_field: str = "filters",
    freq_field: str = "freq",
    freq_index: int = 0,
    additional_csq_set: Set[str] = {"missense_variant", "synonymous_variant"},
    all_transcripts: bool = False,
    filter_an: bool = False,
    filter_to_rare: bool = False,
    pre_loftee: bool = False,
    lof_csq_set: Set[str] = LOF_CSQ_SET,
    remove_ultra_common: bool = False,
) -> hl.MatrixTable:
    """
    Generate loss-of-function gene matrix.

    Used to generate summary metrics on LoF variants.

    :param mt: Input MatrixTable.
    :param tx_ht: Optional Table containing expression levels per transcript.
    :param high_expression_cutoff: Minimum mean proportion expressed cutoff for a transcript to be considered highly expressed. Default is 0.9.
    :param low_expression_cutoff: Upper mean proportion expressed cutoff for a transcript to be considered lowly expressed. Default is 0.1.
    :param filter_field: Name of field in MT that contains variant filters. Default is 'filters'.
    :param freq_field: Name of field in MT that contains frequency information. Default is 'freq'.
    :param freq_index: Which index of frequency struct to use. Default is 0.
    :param additional_csq_set: Set of additional consequences to keep. Default is {'missense_variant', 'synonymous_variant'}.
    :param all_transcripts: Whether to use all transcripts instead of just the transcript with most severe consequence. Default is False.
    :param filter_an: Whether to filter using allele number as proxy for call rate. Default is False.
    :param filter_to_rare: Whether to filter to rare (AF < 5%) variants. Default is False.
    :param pre_loftee: Whether LoF consequences have been annotated with LOFTEE. Default is False.
    :param lof_csq_set: Set of LoF consequence strings. Default is {"splice_acceptor_variant", "splice_donor_variant", "stop_gained", "frameshift_variant"}.
    :param remove_ultra_common: Whether to remove ultra common (AF > 95%) variants. Default is False.
    """
    logger.info("Filtering to PASS variants...")
    filt_criteria = hl.len(mt[filter_field]) == 0
    if filter_an:
        logger.info(
            "Using AN (as a call rate proxy) to filter to variants that meet a minimum call rate..."
        )
        mt = mt.filter_rows(get_an_criteria(mt))
    if remove_ultra_common:
        logger.info("Removing ultra common (AF > 95%) variants...")
        filt_criteria &= mt[freq_field][freq_index].AF < 0.95
    if filter_to_rare:
        logger.info("Filtering to rare (AF < 5%) variants...")
        filt_criteria &= mt[freq_field][freq_index].AF < 0.05
    mt = mt.filter_rows(filt_criteria)

    if all_transcripts:
        logger.info("Exploding transcript_consequences field...")
        explode_field = "transcript_consequences"
    else:
        logger.info(
            "Adding most severe (worst) consequence and expoding worst_csq_by_gene field..."
        )
        mt = process_consequences(mt)
        explode_field = "worst_csq_by_gene"

    if additional_csq_set:
        logger.info("Including these consequences: %s", additional_csq_set)
        additional_cats = hl.literal(additional_csq_set)

    if pre_loftee:
        logger.info("Filtering to LoF consequences: %s", lof_csq_set)
        lof_cats = hl.literal(lof_csq_set)
        criteria = lambda x: lof_cats.contains(
            add_most_severe_consequence_to_consequence(x).most_severe_consequence
        )
        if additional_csq_set:
            criteria = lambda x: lof_cats.contains(
                add_most_severe_consequence_to_consequence(x).most_severe_consequence
            ) | additional_cats.contains(
                add_most_severe_consequence_to_consequence(x).most_severe_consequence
            )

    else:
        logger.info(
            "Filtering to LoF variants that pass LOFTEE with no LoF flags...")
        criteria = lambda x: (x.lof == "HC") & hl.is_missing(x.lof_flags)
        if additional_csq_set:
            criteria = lambda x: (
                (x.lof == "HC") & hl.is_missing(x.lof_flags)
            ) | additional_cats.contains(
                add_most_severe_consequence_to_consequence(x).most_severe_consequence
            )

    csqs = mt.vep[explode_field].filter(criteria)
    mt = mt.select_rows(mt[freq_field], csqs=csqs)
    mt = mt.explode_rows(mt.csqs)
    annotation_expr = {
        "gene_id": mt.csqs.gene_id,
        "gene": mt.csqs.gene_symbol,
        "indel": hl.is_indel(mt.alleles[0], mt.alleles[1]),
        "most_severe_consequence": mt.csqs.most_severe_consequence,
    }

    if tx_ht:
        logger.info("Adding transcript expression annotation...")
        tx_annotation = get_tx_expression_expr(
            mt.row_key,
            tx_ht,
            mt.csqs,
        ).mean_proportion
        annotation_expr["expressed"] = (hl.case().when(
            tx_annotation >= high_expression_cutoff,
            "high").when(tx_annotation > low_expression_cutoff,
                         "medium").when(hl.is_defined(tx_annotation),
                                        "low").default("missing"))
    else:
        annotation_expr["transcript_id"] = mt.csqs.transcript_id
        annotation_expr["canonical"] = hl.is_defined(mt.csqs.canonical)

    mt = mt.annotate_rows(**annotation_expr)
    return (
        mt.group_rows_by(*list(annotation_expr.keys()))
        .aggregate_rows(
            n_sites=hl.agg.count(),
            n_sites_array=hl.agg.array_sum(
                mt[freq_field].map(lambda x: hl.int(x.AC > 0))),
            classic_caf=hl.agg.sum(mt[freq_field][freq_index].AF),
            max_af=hl.agg.max(mt[freq_field][freq_index].AF),
            classic_caf_array=hl.agg.array_sum(mt[freq_field].map(lambda x: x.AF)),
        )
        .aggregate_entries(
            num_homs=hl.agg.count_where(mt.GT.is_hom_var()),
            num_hets=hl.agg.count_where(mt.GT.is_het()),
            defined_sites=hl.agg.count_where(hl.is_defined(mt.GT)),
        )
        .result()
    )
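
# Hedged usage sketch (not part of the function above): the bucket paths are
# placeholders, and this assumes a VEP-annotated MatrixTable with 'filters'
# and 'freq' row fields plus an optional transcript-expression Table, as
# described in the docstring.
mt = hl.read_matrix_table('gs://my-bucket/vep_annotated.mt')  # placeholder path
tx_ht = hl.read_table('gs://my-bucket/tx_expression.ht')  # placeholder path
lof_mt = default_generate_gene_lof_matrix(
    mt=mt,
    tx_ht=tx_ht,
    filter_to_rare=True,    # keep only AF < 5% variants
    all_transcripts=False,  # use the transcript with the most severe consequence
)
lof_mt.rows().show(5)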
Esempio n. 34
0
    def test_literals_rebuild(self):
        mt = hl.utils.range_matrix_table(1, 1)
        mt = mt.annotate_rows(x=hl.cond(
            hl.len(hl.literal([1, 2, 3])) < hl.rand_unif(10, 11), mt.globals,
            hl.struct()))
        mt._force_count_rows()
Esempio n. 35
0
def import_gtf(path,
               reference_genome=None,
               skip_invalid_contigs=False,
               min_partitions=None,
               force_bgz=False,
               force=False) -> hl.Table:
    """Import a GTF file.

       The GTF file format is identical to the GFF version 2 file format,
       and so this function can be used to import GFF version 2 files as
       well.

       See https://www.ensembl.org/info/website/upload/gff.html for more
       details on the GTF/GFF2 file format.

       The :class:`.Table` returned by this function will be keyed by the
       ``interval`` row field and will include the following row fields:

       .. code-block:: text

           'source': str
           'feature': str
           'score': float64
           'strand': str
           'frame': int32
           'interval': interval<>

       There will also be corresponding fields for every tag found in the
       attribute field of the GTF file.

       Note
       ----

       This function will return an ``interval`` field of type :class:`.tinterval`
       constructed from the ``seqname``, ``start``, and ``end`` fields in the
       GTF file. This interval is inclusive of both the start and end positions
       in the GTF file.

       If the ``reference_genome`` parameter is specified, the start and end
       points of the ``interval`` field will be of type :class:`.tlocus`.
       Otherwise, the start and end points of the ``interval`` field will be of
       type :class:`.tstruct` with fields ``seqname`` (type :class:`str`) and
       ``position`` (type :obj:`.tint32`).

       Furthermore, if the ``reference_genome`` parameter is specified and
       ``skip_invalid_contigs`` is ``True``, this import function will skip
       lines in the GTF where ``seqname`` is not consistent with the reference
       genome specified.

       Example
       -------

       >>> ht = hl.experimental.import_gtf('data/test.gtf',
       ...                                 reference_genome='GRCh37',
       ...                                 skip_invalid_contigs=True)

       >>> ht.describe()  # doctest: +SKIP_OUTPUT_CHECK
       ----------------------------------------
       Global fields:
       None
       ----------------------------------------
       Row fields:
           'source': str
           'feature': str
           'score': float64
           'strand': str
           'frame': int32
           'gene_type': str
           'exon_id': str
           'havana_transcript': str
           'level': str
           'transcript_name': str
           'gene_status': str
           'gene_id': str
           'transcript_type': str
           'tag': str
           'transcript_status': str
           'gene_name': str
           'transcript_id': str
           'exon_number': str
           'havana_gene': str
           'interval': interval<locus<GRCh37>>
       ----------------------------------------
       Key: ['interval']
       ----------------------------------------

       Parameters
       ----------

       path : :class:`str`
           File to import.
       reference_genome : :class:`str` or :class:`.ReferenceGenome`, optional
           Reference genome to use.
       skip_invalid_contigs : :obj:`bool`
           If ``True`` and `reference_genome` is not ``None``, skip lines where
           ``seqname`` is not consistent with the reference genome.
       min_partitions : :obj:`int` or :obj:`None`
           Minimum number of partitions (passed to import_table).
       force_bgz : :obj:`bool`
           If ``True``, load files as blocked gzip files, assuming
           that they were actually compressed using the BGZ codec. This option is
           useful when the file extension is not ``'.bgz'``, but the file is
           blocked gzip, so that the file can be read in parallel and not on a
           single node.
       force : :obj:`bool`
           If ``True``, load gzipped files serially on one core. This should
           be used only when absolutely necessary, as processing time will be
           increased due to lack of parallelism.

       Returns
       -------
       :class:`.Table`
       """

    ht = hl.import_table(path,
                         min_partitions=min_partitions,
                         comment='#',
                         no_header=True,
                         types={
                             'f3': hl.tint,
                             'f4': hl.tint,
                             'f5': hl.tfloat,
                             'f7': hl.tint
                         },
                         missing='.',
                         delimiter='\t',
                         force_bgz=force_bgz,
                         force=force)

    ht = ht.rename({
        'f0': 'seqname',
        'f1': 'source',
        'f2': 'feature',
        'f3': 'start',
        'f4': 'end',
        'f5': 'score',
        'f6': 'strand',
        'f7': 'frame',
        'f8': 'attribute'
    })

    def parse_attributes(unparsed_attributes):
        def parse_attribute(attribute):
            key_and_value = attribute.split(' ')
            key = key_and_value[0]
            value = key_and_value[1]
            return (key, value.replace('"|;\\$', ''))

        return hl.dict(unparsed_attributes.split('; ').map(parse_attribute))

    ht = ht.annotate(attribute=parse_attributes(ht['attribute']))

    ht = ht.checkpoint(new_temp_file())

    attributes = ht.aggregate(
        hl.agg.explode(lambda x: hl.agg.collect_as_set(x),
                       ht['attribute'].keys()))

    ht = ht.transmute(
        **{
            x: hl.or_missing(ht['attribute'].contains(x), ht['attribute'][x])
            for x in attributes if x
        })

    if reference_genome:
        if reference_genome.name == 'GRCh37':
            ht = ht.annotate(
                seqname=hl.case().when((ht['seqname'] == 'M')
                                       | (ht['seqname'] == 'chrM'), 'MT').
                when(ht['seqname'].startswith('chr'), ht['seqname'].replace(
                    '^chr', '')).default(ht['seqname']))
        else:
            ht = ht.annotate(seqname=hl.case().when(
                ht['seqname'].startswith('HLA'), ht['seqname']).when(
                    ht['seqname'].startswith('chrHLA'), ht['seqname'].replace(
                        '^chr', '')).when(ht['seqname'].startswith(
                            'chr'), ht['seqname']).default('chr' +
                                                           ht['seqname']))
        if skip_invalid_contigs:
            valid_contigs = hl.literal(set(reference_genome.contigs))
            ht = ht.filter(valid_contigs.contains(ht['seqname']))
        ht = ht.transmute(
            interval=hl.locus_interval(ht['seqname'],
                                       ht['start'],
                                       ht['end'],
                                       includes_start=True,
                                       includes_end=True,
                                       reference_genome=reference_genome))
    else:
        ht = ht.transmute(interval=hl.interval(
            hl.struct(seqname=ht['seqname'], position=ht['start']),
            hl.struct(seqname=ht['seqname'], position=ht['end']),
            includes_start=True,
            includes_end=True))

    ht = ht.key_by('interval')

    return ht
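
# Illustration only: a plain-Python mirror of the attribute parsing above,
# run on an invented GTF attribute string. (The Hail version does equivalent
# split-and-strip work as an expression.)
raw = 'gene_id "ENSG00000223972"; gene_name "DDX11L1"; level 2'
parsed = {
    key: value.strip('"')
    for key, value in (attr.split(' ', 1) for attr in raw.split('; '))
}
print(parsed)  # {'gene_id': 'ENSG00000223972', 'gene_name': 'DDX11L1', 'level': '2'}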
Esempio n. 36
0
    def test_filter_cols_with_global_references(self):
        mt = hl.utils.range_matrix_table(10, 10)
        s = hl.literal({1, 3, 5, 7})
        self.assertEqual(
            mt.filter_cols(s.contains(mt.col_idx)).count_cols(), 4)
Esempio n. 37
0
def faf_expr(
    freq: hl.expr.ArrayExpression,
    freq_meta: hl.expr.ArrayExpression,
    locus: hl.expr.LocusExpression,
    pops_to_exclude: Optional[Set[str]] = None,
    faf_thresholds: List[float] = [0.95, 0.99],
) -> Tuple[hl.expr.ArrayExpression, List[Dict[str, str]]]:
    """
    Calculates the filtering allele frequency (FAF) for each threshold specified in `faf_thresholds`.
    See http://cardiodb.org/allelefrequencyapp/ for more information.

    The FAF is computed for each of the following population stratification if found in `freq_meta`:

        - All samples, with adj criteria
        - For each population, with adj criteria
        - For each sex/population combination on the non-PAR regions of sex chromosomes (will be missing on autosomes and PAR regions of sex chromosomes)

    Each FAF entry is a struct with one field per threshold specified in `faf_thresholds`, of type float64.

    This returns a tuple with two expressions:

        1. An array of FAF expressions as described above
        2. An array of dicts containing the metadata for each of the array elements, in the same format as that produced by `annotate_freq`.

    :param freq: ArrayExpression of call stats structs (typically generated by hl.agg.call_stats)
    :param freq_meta: ArrayExpression of meta dictionaries corresponding to freq (typically generated using annotate_freq)
    :param locus: locus
    :param pops_to_exclude: Set of populations to exclude from faf calculation (typically bottlenecked or consanguineous populations)
    :param faf_thresholds: List of FAF thresholds to compute
    :return: (FAF expression, FAF metadata)
    """
    # An empty set expression keeps the .contains calls below valid when no
    # populations are excluded.
    _pops_to_exclude = (
        hl.literal(pops_to_exclude)
        if pops_to_exclude is not None
        else hl.empty_set(hl.tstr)
    )

    # pylint: disable=invalid-unary-operand-type
    faf_freq_indices = hl.range(0, hl.len(freq_meta)).filter(
        lambda i: (freq_meta[i].get("group") == "adj")
        & (
            (freq_meta[i].size() == 1)
            | (
                (hl.set(freq_meta[i].keys()) == {"pop", "group"})
                & (~_pops_to_exclude.contains(freq_meta[i]["pop"]))
            )
        )
    )
    sex_faf_freq_indices = hl.range(0, hl.len(freq_meta)).filter(
        lambda i: (freq_meta[i].get("group") == "adj")
        & (freq_meta[i].contains("sex"))
        & (
            (freq_meta[i].size() == 2)
            | (
                (hl.set(freq_meta[i].keys()) == {"pop", "group", "sex"})
                & (~_pops_to_exclude.contains(freq_meta[i]["pop"]))
            )
        )
    )

    faf_expr = faf_freq_indices.map(
        lambda i: hl.struct(
            **{
                f"faf{str(threshold)[2:]}": hl.experimental.filtering_allele_frequency(
                    freq[i].AC, freq[i].AN, threshold
                )
                for threshold in faf_thresholds
            }
        )
    )

    faf_expr = faf_expr.extend(
        sex_faf_freq_indices.map(
            lambda i: hl.or_missing(
                ~locus.in_autosome_or_par(),
                hl.struct(
                    **{
                        f"faf{str(threshold)[2:]}": hl.experimental.filtering_allele_frequency(
                            freq[i].AC, freq[i].AN, threshold
                        )
                        for threshold in faf_thresholds
                    }
                ),
            )
        )
    )

    faf_meta = faf_freq_indices.extend(sex_faf_freq_indices).map(lambda i: freq_meta[i])
    return faf_expr, hl.eval(faf_meta)
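
# Hedged usage sketch: assumes a gnomAD-style sites Table with 'freq' and
# 'freq_meta' fields as described in the docstring; the path and excluded
# populations are placeholders.
ht = hl.read_table('gs://my-bucket/sites.ht')  # placeholder path
faf, faf_meta = faf_expr(
    freq=ht.freq,
    freq_meta=ht.freq_meta,
    locus=ht.locus,
    pops_to_exclude={'asj', 'fin', 'oth'},  # example exclusions
)
ht = ht.annotate(faf=faf)
ht = ht.annotate_globals(faf_meta=faf_meta)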
Esempio n. 38
0
def export_binary_eur(cluster_idx, num_clusters=10, batch_size = 256):
    r'''
    Export summary statistics for binary traits defined only for EUR. 
    Given the large number of such traits (4184), it makes sense to batch this 
    across `num_clusters` clusters for reduced wall time and robustness to mid-export errors.
    NOTE: `cluster_idx` is 1-indexed.
    '''
    mt0 = get_final_sumstats_mt_for_export()
    meta_mt0 = hl.read_matrix_table(get_meta_analysis_results_path())
    
    mt0 = mt0.annotate_cols(pheno_id = get_pheno_id(tb=mt0))
    mt0 = mt0.annotate_rows(chr = mt0.locus.contig,
                            pos = mt0.locus.position,
                            ref = mt0.alleles[0],
                            alt = mt0.alleles[1])
    
    trait_types_to_run = ['categorical','phecode', 'icd10', 'prescriptions'] # list of which trait_type to run
        
    # fields specific to each category of trait    
    meta_fields = ['AF_Cases','AF_Controls']
    fields = ['AF.Cases','AF.Controls']
    
    # dictionaries for renaming fields
    meta_field_rename_dict = {'BETA': 'beta_meta',
                              'SE': 'se_meta',
                              'Pvalue': 'pval_meta',
                              'AF_Cases': 'af_cases_meta',
                              'AF_Controls': 'af_controls_meta',
                              'Pvalue_het': 'pval_heterogeneity'}
    field_rename_dict = {'AF.Cases': 'af_cases',
                         'AF.Controls': 'af_controls',
                         'BETA': 'beta',
                         'SE': 'se',
                         'Pvalue': 'pval',
                         'low_confidence': 'low_confidence'}  # decided on this implementation to make later code cleaner
    
    all_binary_trait_types = {'categorical','phecode', 'icd10', 'prescriptions'}
    
    meta_fields += ['BETA','SE','Pvalue','Pvalue_het']
    fields += ['BETA','SE','Pvalue','low_confidence']
    
    trait_category = 'binary'        
    trait_types = all_binary_trait_types.intersection(trait_types_to_run) # get list of binary trait types to run
    pop_set = {'EUR'}
    start = time()
    
    mt1 = mt0.filter_cols((hl.literal(trait_types).contains(mt0.trait_type))&
                          (hl.set(mt0.pheno_data.pop)==hl.literal(pop_set)))
    
    pheno_id_list = mt1.pheno_id.collect()
    
    num_traits = len(pheno_id_list) # total number of traits to run
    
    traits_per_cluster = ceil(num_traits/num_clusters) # maximum traits to run per cluster
    
    cluster_pheno_id_list = pheno_id_list[(cluster_idx-1)*traits_per_cluster:cluster_idx*traits_per_cluster] # list of traits to run in current cluster
    
    print(len(cluster_pheno_id_list))
    
    mt1 = mt1.filter_cols(hl.literal(cluster_pheno_id_list).contains(mt1.pheno_id))
    
    pop_list = sorted(pop_set)
    
    annotate_dict = {}
    
    keyed_mt = meta_mt0[mt1.row_key,mt1.col_key]
    if len(pop_set)>1:
        for field in meta_fields: # NOTE: Meta-analysis columns go before per-population columns
            field_expr = keyed_mt.meta_analysis[field][0]
            annotate_dict.update({f'{meta_field_rename_dict[field]}': hl.if_else(hl.is_nan(field_expr),
                                                                      hl.str(field_expr),
                                                                      hl.format('%.3e', field_expr))})

    for field in fields:
        for pop_idx, pop in enumerate(pop_list):
            field_expr = mt1.summary_stats[field][pop_idx]
            annotate_dict.update({f'{field_rename_dict[field]}_{pop}': hl.if_else(hl.is_nan(field_expr),
                                                                       hl.str(field_expr),
                                                                       hl.str(field_expr) if field=='low_confidence' else hl.format('%.3e', field_expr))})
    
    mt2 = mt1.annotate_entries(**annotate_dict)
    
    mt2 = mt2.filter_cols(mt2.coding != 'zekavat_20200409')
    mt2 = mt2.key_cols_by('pheno_id')
    mt2 = mt2.key_rows_by().drop('locus','alleles','summary_stats') # row fields that are no longer included: 'gene','annotation'
    print(mt2.describe())
    
    batch_idx = 1
    get_export_path = lambda batch_idx: f'{ldprune_dir}/release/{trait_category}/{"-".join(pop_list)}_batch{batch_idx}/subbatch{cluster_idx}'

    while hl.hadoop_is_dir(get_export_path(batch_idx)):
        batch_idx += 1
    print(f'\nExporting {len(cluster_pheno_id_list)} phenos to: {get_export_path(batch_idx)}\n')
    hl.experimental.export_entries_by_col(mt = mt2,
                                          path = get_export_path(batch_idx),
                                          bgzip = True,
                                          batch_size = batch_size,
                                          use_string_key_as_file_name = True,
                                          header_json_in_file = False)
    end = time()
    print(f'\nExport complete for:\n{trait_types}\n{pop_list}\ntime: {round((end-start)/3600,2)} hrs')
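
# Hedged usage sketch: one invocation per cluster, with `cluster_idx`
# 1-indexed as noted in the docstring; 10 clusters mirrors the default above.
export_binary_eur(cluster_idx=3, num_clusters=10, batch_size=256)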
Esempio n. 39
0
for k in K:
    for o in OR:

        # Participation bias
        df['z'] = df['y0'] * log(o) + df['y1'] * log(o) + np.random.normal(
            0, 0.1, len(df.index))

        # Sex-differential effect Zm = Z*K
        df.loc[(df['sex'] == 0), 'z'] = df.loc[(df['sex'] == 0), 'z'] * k
        df['prob'] = [1 / (1 + exp(-z)) for z in df['z']]
        df['sel'] = np.random.binomial(n=1, p=df['prob'], size=len(df.index))

        # Filter MatrixTable and get sample
        samples_to_keep = set(df.loc[(df['sel'] == 1), 's'])
        set_to_keep = hl.literal(samples_to_keep)
        mt_sampled = mt.filter_cols(set_to_keep.contains(mt['s']), keep=True)

        i = '_' + str(k) + '_' + str(o)

        # Export phenotypes
        mt_sampled.cols().select(
            's', 'sex', 'y0',
            'y1').key_by().export(out_bucket + 'phenotypes/pheno' + i + '.tsv')

        # GWAS of sex
        gwas_s = gwas(mt_sampled.sex, mt_sampled.GT.n_alt_alleles(), [1.0])
        fn = out_bucket + 'gwas/gwas_sex' + i + '.tsv'
        export_gwas(gwas_s, fn)

        # GWAS of y0
Esempio n. 40
0
def export_results(num_pops, trait_types='all', batch_size=256, mt=None, 
                   export_path_str=None, skip_binary_eur=True):
    r'''
    `num_pops`: exact number of populations for which phenotype is defined
    `trait_types`: trait category (options: all, binary, quant)
    `batch_size`: batch size argument for export entries by col
    '''
    assert trait_types in {'all','quant','binary'}, "trait_types must be one of the following: {'all','quant','binary'}"
    print(f'\n\nExporting {trait_types} trait types for {num_pops} pops\n\n')
    if mt is None:
        mt0 = get_final_sumstats_mt_for_export()
    else:
        mt0 = mt
        
    meta_mt0 = hl.read_matrix_table(get_meta_analysis_results_path())
    
    mt0 = mt0.annotate_cols(pheno_id = get_pheno_id(tb=mt0))
    mt0 = mt0.annotate_rows(chr = mt0.locus.contig,
                            pos = mt0.locus.position,
                            ref = mt0.alleles[0],
                            alt = mt0.alleles[1])
    
    all_pops = ['AFR', 'AMR', 'CSA', 'EAS', 'EUR', 'MID']
    
    if trait_types == 'all':
        trait_types_to_run = ['continuous','biomarkers','categorical','phecode', 'icd10', 'prescriptions'] # list of which trait_type to run
    elif trait_types == 'quant':
        trait_types_to_run = ['continuous','biomarkers']
    elif trait_types == 'binary':
        trait_types_to_run = ['categorical','phecode', 'icd10', 'prescriptions']

    pop_sets = [set(i) for i in list(combinations(all_pops, num_pops))] # list of exact set of pops for which phenotype is defined
    
    
    # fields specific to each category of trait
    quant_meta_fields = ['AF_Allele2']
    quant_fields = ['AF_Allele2']
    
    binary_meta_fields = ['AF_Cases','AF_Controls']
    binary_fields = ['AF.Cases','AF.Controls']
    
    # dictionaries for renaming fields
    quant_meta_field_rename_dict = {'AF_Allele2':'af_meta',
                              'BETA':'beta_meta',
                              'SE':'se_meta',
                              'Pvalue':'pval_meta',
                              'Pvalue_het':'pval_heterogeneity'}
    quant_field_rename_dict = {'AF_Allele2':'af',
                         'BETA':'beta',
                         'SE':'se',
                         'Pvalue':'pval',
                         'low_confidence':'low_confidence'} # decided on this implementation to make later code cleaner
    
    binary_meta_field_rename_dict = {'BETA':'beta_meta',
                                     'SE':'se_meta',
                                     'Pvalue':'pval_meta',
                                     'AF_Cases':'af_cases_meta',
                                     'AF_Controls':'af_controls_meta',
                                     'Pvalue_het':'pval_heterogeneity'}
    binary_field_rename_dict = {'AF.Cases':'af_cases',
                                'AF.Controls':'af_controls',
                                'BETA':'beta',
                                'SE':'se',
                                'Pvalue':'pval',
                                'low_confidence':'low_confidence'} # decided on this implementation to make later code cleaner
    
    all_quant_trait_types = {'continuous','biomarkers'}
    all_binary_trait_types = {'categorical','phecode', 'icd10', 'prescriptions'}
    
    quant_trait_types = all_quant_trait_types.intersection(trait_types_to_run) # get list of quant trait types to run
    binary_trait_types = all_binary_trait_types.intersection(trait_types_to_run) # get list of binary trait types to run
    error_trait_types = set(trait_types_to_run).difference(quant_trait_types.union(binary_trait_types))
    assert len(error_trait_types)==0, f'ERROR: The following trait_types are invalid: {error_trait_types}'
        
    for trait_category, trait_types in [('binary', binary_trait_types), ('quant', quant_trait_types)]:
        if len(trait_types)==0: #if no traits in trait_types list
            continue

        print(f'{trait_category} trait types to run: {trait_types}')
        
        if trait_category == 'quant':
            meta_fields = quant_meta_fields
            fields = quant_fields
            meta_field_rename_dict = quant_meta_field_rename_dict
            field_rename_dict = quant_field_rename_dict
        elif trait_category == 'binary':
            meta_fields = binary_meta_fields
            fields = binary_fields
            meta_field_rename_dict = binary_meta_field_rename_dict
            field_rename_dict = binary_field_rename_dict
    
        meta_fields += ['BETA','SE','Pvalue','Pvalue_het']
        fields += ['BETA','SE','Pvalue','low_confidence']
            
        for pop_set in pop_sets:    
            start = time()
            
            if (pop_set == {'EUR'} and trait_category == 'binary') and skip_binary_eur: # run EUR-only binary traits separately
                print('\nSkipping EUR-only binary traits\n')
                continue
            
            mt1 = mt0.filter_cols((hl.literal(trait_types).contains(mt0.trait_type))&
                                  (hl.set(mt0.pheno_data.pop)==hl.literal(pop_set)))
            
            col_ct = mt1.count_cols()
            if col_ct==0:
                print(f'\nSkipping {trait_types},{sorted(pop_set)}, no phenotypes found\n')
                continue
            
            pop_list = sorted(pop_set)
            
            annotate_dict = {}
            keyed_mt = meta_mt0[mt1.row_key,mt1.col_key]
            if len(pop_set)>1:
                for field in meta_fields: # NOTE: Meta-analysis columns go before per-population columns
                    field_expr = keyed_mt.meta_analysis[field][0]
                    annotate_dict.update({f'{meta_field_rename_dict[field]}': hl.if_else(hl.is_nan(field_expr),
                                                                              hl.str(field_expr),
                                                                              hl.format('%.3e', field_expr))})
        
            for field in fields:
                for pop_idx, pop in enumerate(pop_list):
                    field_expr = mt1.summary_stats[field][pop_idx]
                    annotate_dict.update({f'{field_rename_dict[field]}_{pop}': hl.if_else(hl.is_nan(field_expr),
                                                                               hl.str(field_expr),
                                                                               hl.str(field_expr) if field=='low_confidence' else hl.format('%.3e', field_expr))})
            
            mt2 = mt1.annotate_entries(**annotate_dict)
            
            mt2 = mt2.filter_cols(mt2.coding != 'zekavat_20200409')
            mt2 = mt2.key_cols_by('pheno_id')
            mt2 = mt2.key_rows_by().drop('locus','alleles','summary_stats') # row fields that are no longer included: 'gene','annotation'
            
            batch_idx = 1        
            get_export_path = lambda batch_idx: f'{ldprune_dir}/export_results/{"" if export_path_str is None else f"{export_path_str}/"}{trait_category}/{"-".join(pop_list)}_batch{batch_idx}'
            print(mt2.describe())
            while hl.hadoop_is_dir(get_export_path(batch_idx)):
                batch_idx += 1
            print(f'\nExporting {col_ct} phenos to: {get_export_path(batch_idx)}\n')
            hl.experimental.export_entries_by_col(mt = mt2,
                                                  path = get_export_path(batch_idx),
                                                  bgzip = True,
                                                  batch_size = batch_size,
                                                  use_string_key_as_file_name = True,
                                                  header_json_in_file = False)
            end = time()
            print(f'\nExport complete for:\n{trait_types}\n{pop_list}\ntime: {round((end-start)/3600,2)} hrs')
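
# Hedged usage sketch: export binary traits defined in exactly one
# population, leaving the large EUR-only binary set for export_binary_eur.
export_results(num_pops=1, trait_types='binary', batch_size=256, skip_binary_eur=True)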
Esempio n. 41
0
def trio_matrix(dataset, pedigree, complete_trios=False) -> MatrixTable:
    """Builds and returns a matrix where columns correspond to trios and entries contain genotypes for the trio.

    .. include:: ../_templates/req_tstring.rst

    Examples
    --------

    Create a trio matrix:

    >>> pedigree = hl.Pedigree.read('data/case_control_study.fam')
    >>> trio_dataset = hl.trio_matrix(dataset, pedigree, complete_trios=True)

    Notes
    -----

    This method builds a new matrix table with one column per trio. If
    `complete_trios` is ``True``, then only trios that satisfy
    :meth:`.Trio.is_complete` are included. In this new dataset, the column
    identifiers are the sample IDs of the trio probands. The column fields and
    entries of the matrix are changed in the following ways:

    The new column fields consist of three structs (`proband`, `father`,
    `mother`), a Boolean field, and a string field:

    - **proband** (:class:`.tstruct`) - Column fields on the proband.
    - **father** (:class:`.tstruct`) - Column fields on the father.
    - **mother** (:class:`.tstruct`) - Column fields on the mother.
    - **id** (:py:data:`.tstr`) - Column key for the proband.
    - **is_female** (:py:data:`.tbool`) - Proband is female.
      ``True`` for female, ``False`` for male, missing if unknown.
    - **fam_id** (:py:data:`.tstr`) - Family ID.

    The new entry fields are:

    - **proband_entry** (:class:`.tstruct`) - Proband entry fields.
    - **father_entry** (:class:`.tstruct`) - Father entry fields.
    - **mother_entry** (:class:`.tstruct`) - Mother entry fields.

    Parameters
    ----------
    dataset : :class:`.MatrixTable`
        Dataset with a string column key (sample ID).
    pedigree : :class:`.Pedigree`
        Pedigree describing the trios.
    complete_trios : :obj:`bool`
        If ``True``, include only trios where proband, father, and mother are
        all present in `dataset`.

    Returns
    -------
    :class:`.MatrixTable`
    """
    mt = dataset
    require_col_key_str(mt, "trio_matrix")

    k = mt.col_key.dtype.fields[0]
    samples = mt[k].collect()

    pedigree = pedigree.filter_to(samples)
    trios = pedigree.complete_trios() if complete_trios else pedigree.trios
    n_trios = len(trios)

    sample_idx = {}
    for i, s in enumerate(samples):
        sample_idx[s] = i

    trios = [
        hl.Struct(id=sample_idx[t.s],
                  pat_id=None if t.pat_id is None else sample_idx[t.pat_id],
                  mat_id=None if t.mat_id is None else sample_idx[t.mat_id],
                  is_female=t.is_female,
                  fam_id=t.fam_id) for t in trios
    ]
    trios_type = hl.dtype(
        'array<struct{id:int,pat_id:int,mat_id:int,is_female:bool,fam_id:str}>'
    )

    trios_sym = Env.get_uid()
    entries_sym = Env.get_uid()
    cols_sym = Env.get_uid()

    mt = mt.annotate_globals(**{trios_sym: hl.literal(trios, trios_type)})
    mt = mt._localize_entries(entries_sym, cols_sym)
    mt = mt.annotate_globals(
        **{
            cols_sym:
            hl.map(
                lambda i: hl.bind(
                    lambda t: hl.struct(id=mt[cols_sym][t.id][k],
                                        proband=mt[cols_sym][t.id],
                                        father=mt[cols_sym][t.pat_id],
                                        mother=mt[cols_sym][t.mat_id],
                                        is_female=t.is_female,
                                        fam_id=t.fam_id), mt[trios_sym][i]),
                hl.range(0, n_trios))
        })
    mt = mt.annotate(
        **{
            entries_sym:
            hl.map(
                lambda i: hl.bind(
                    lambda t: hl.struct(proband_entry=mt[entries_sym][t.id],
                                        father_entry=mt[entries_sym][t.pat_id],
                                        mother_entry=mt[entries_sym][t.mat_id]
                                        ), mt[trios_sym][i]),
                hl.range(0, n_trios))
        })
    mt = mt.drop(trios_sym)

    return mt._unlocalize_entries(entries_sym, cols_sym, ['id'])
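
# Invented follow-on example (not part of trio_matrix): count entries where
# the proband is het while both parents are hom-ref, using the trio-shaped
# entry fields documented above. `dataset` and the .fam path come from the
# docstring example.
pedigree = hl.Pedigree.read('data/case_control_study.fam')
trio_mt = hl.trio_matrix(dataset, pedigree, complete_trios=True)
n_denovo_like = trio_mt.aggregate_entries(hl.agg.count_where(
    trio_mt.proband_entry.GT.is_het()
    & trio_mt.father_entry.GT.is_hom_ref()
    & trio_mt.mother_entry.GT.is_hom_ref()))
print(n_denovo_like)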
Esempio n. 42
0
def ld_score(entry_expr,
             locus_expr,
             radius,
             coord_expr=None,
             annotation_exprs=None,
             block_size=None) -> Table:
    """Calculate LD scores.

    Example
    -------

    >>> # Load genetic data into MatrixTable
    >>> mt = hl.import_plink(bed='data/ldsc.bed',
    ...                      bim='data/ldsc.bim',
    ...                      fam='data/ldsc.fam')

    >>> # Create locus-keyed Table with numeric variant annotations
    >>> ht = hl.import_table('data/ldsc.annot',
    ...                      types={'BP': hl.tint,
    ...                             'binary': hl.tfloat,
    ...                             'continuous': hl.tfloat})
    >>> ht = ht.annotate(locus=hl.locus(ht.CHR, ht.BP))
    >>> ht = ht.key_by('locus')

    >>> # Annotate MatrixTable with external annotations
    >>> mt = mt.annotate_rows(binary_annotation=ht[mt.locus].binary,
    ...                       continuous_annotation=ht[mt.locus].continuous)

    >>> # Calculate LD scores using centimorgan coordinates
    >>> ht_scores = hl.experimental.ld_score(entry_expr=mt.GT.n_alt_alleles(),
    ...                                      locus_expr=mt.locus,
    ...                                      radius=1.0,
    ...                                      coord_expr=mt.cm_position,
    ...                                      annotation_exprs=[mt.binary_annotation,
    ...                                                        mt.continuous_annotation])

    >>> # Show results
    >>> ht_scores.show(3)

    .. code-block:: text

        +---------------+-------------------+-----------------------+-------------+
        | locus         | binary_annotation | continuous_annotation |  univariate |
        +---------------+-------------------+-----------------------+-------------+
        | locus<GRCh37> |           float64 |               float64 |     float64 |
        +---------------+-------------------+-----------------------+-------------+
        | 20:82079      |       1.15183e+00 |           7.30145e+01 | 1.60117e+00 |
        | 20:103517     |       2.04604e+00 |           2.75392e+02 | 4.69239e+00 |
        | 20:108286     |       2.06585e+00 |           2.86453e+02 | 5.00124e+00 |
        +---------------+-------------------+-----------------------+-------------+


    Warning
    -------
        :func:`.ld_score` will fail if ``entry_expr`` results in any missing
        values. The special float value ``nan`` is not considered a
        missing value.

    **Further reading**

    For more in-depth discussion of LD scores, see:

    - `LD Score regression distinguishes confounding from polygenicity in genome-wide association studies (Bulik-Sullivan et al, 2015) <https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4495769/>`__
    - `Partitioning heritability by functional annotation using genome-wide association summary statistics (Finucane et al, 2015) <https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4626285/>`__

    Notes
    -----

    `entry_expr`, `locus_expr`, `coord_expr` (if specified), and
    `annotation_exprs` (if specified) must come from the same
    MatrixTable.


    Parameters
    ----------
    entry_expr : :class:`.NumericExpression`
        Expression for entries of genotype matrix
        (e.g. ``mt.GT.n_alt_alleles()``).
    locus_expr : :class:`.LocusExpression`
        Row-indexed locus expression.
    radius : :obj:`int` or :obj:`float`
        Radius of window for row values (in units of `coord_expr` if set,
        otherwise in units of basepairs).
    coord_expr: :class:`.Float64Expression`, optional
        Row-indexed numeric expression for the row value used to window
        variants. By default, the row value is given by the locus
        position.
    annotation_exprs : :class:`.NumericExpression` or
                       :obj:`list` of :class:`.NumericExpression`, optional
        Annotation expression(s) to partition LD scores. Univariate
        annotation will always be included and does not need to be
        specified.
    block_size : :obj:`int`, optional
        Block size. Default given by :meth:`.BlockMatrix.default_block_size`.

    Returns
    -------
    :class:`.Table`
        Table keyed by `locus_expr` with LD scores for each variant and
        `annotation_expr`. The function will always return LD scores for
        the univariate (all SNPs) annotation."""

    mt = entry_expr._indices.source
    mt_locus_expr = locus_expr._indices.source

    if coord_expr is None:
        mt_coord_expr = mt_locus_expr
    else:
        mt_coord_expr = coord_expr._indices.source

    if not annotation_exprs:
        check_mts = all([mt == mt_locus_expr,
                         mt == mt_coord_expr])
    else:
        check_mts = all([mt == mt_locus_expr,
                         mt == mt_coord_expr] +
                        [mt == x._indices.source
                         for x in wrap_to_list(annotation_exprs)])

    if not check_mts:
        raise ValueError("""ld_score: entry_expr, locus_expr, coord_expr
                            (if specified), and annotation_exprs (if
                            specified) must come from same MatrixTable.""")

    n = mt.count_cols()
    r2 = hl.row_correlation(entry_expr, block_size) ** 2
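    # Convert each squared sample correlation to the bias-corrected estimator
    # r2_adj = r2 - (1 - r2) / (n - 2); the line below is the expanded form.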
    r2_adj = ((n-1.0) / (n-2.0)) * r2 - (1.0 / (n-2.0))

    starts, stops = hl.linalg.utils.locus_windows(locus_expr,
                                                  radius,
                                                  coord_expr)
    r2_adj_sparse = r2_adj.sparsify_row_intervals(starts, stops)

    r2_adj_sparse_tmp = new_temp_file()
    r2_adj_sparse.write(r2_adj_sparse_tmp)
    r2_adj_sparse = BlockMatrix.read(r2_adj_sparse_tmp)

    if not annotation_exprs:
        cols = ['univariate']
        col_idxs = {0: 'univariate'}
        l2 = r2_adj_sparse.sum(axis=1)
    else:
        ht = mt.select_rows(*wrap_to_list(annotation_exprs)).rows()
        ht = ht.annotate(univariate=hl.literal(1.0))
        names = [name for name in ht.row if name not in ht.key]

        ht_union = hl.Table.union(
            *[(ht.annotate(name=hl.str(x),
                           value=hl.float(ht[x]))
                 .select('name', 'value')) for x in names])
        mt_annotations = ht_union.to_matrix_table(
            row_key=list(ht_union.key),
            col_key=['name'])

        cols = mt_annotations.key_cols_by()['name'].collect()
        col_idxs = {i: cols[i] for i in range(len(cols))}

        a_tmp = new_temp_file()
        BlockMatrix.write_from_entry_expr(mt_annotations.value, a_tmp)

        a = BlockMatrix.read(a_tmp)
        l2 = r2_adj_sparse @ a

    l2_bm_tmp = new_temp_file()
    l2_tsv_tmp = new_temp_file()
    l2.write(l2_bm_tmp, force_row_major=True)
    BlockMatrix.export(l2_bm_tmp, l2_tsv_tmp)

    ht_scores = hl.import_table(l2_tsv_tmp, no_header=True, impute=True)
    ht_scores = ht_scores.add_index()
    ht_scores = ht_scores.key_by('idx')
    ht_scores = ht_scores.rename({'f{:}'.format(i): col_idxs[i]
                                  for i in range(len(cols))})

    ht = mt.select_rows(__locus=locus_expr).rows()
    ht = ht.add_index()
    ht = ht.annotate(**ht_scores[ht.idx])
    ht = ht.key_by('__locus')
    ht = ht.select(*[x for x in ht_scores.row if x not in ht_scores.key])
    ht = ht.rename({'__locus': 'locus'})

    return ht
Esempio n. 43
0
def test_ndarray_reshape():
    np_single = np.array([8])
    single = hl.nd.array([8])

    np_zero_dim = np.array(4)
    zero_dim = hl.nd.array(4)

    np_a = np.array([1, 2, 3, 4, 5, 6])
    a = hl.nd.array(np_a)

    np_cube = np.array([0, 1, 2, 3, 4, 5, 6, 7]).reshape((2, 2, 2))
    cube = hl.nd.array([0, 1, 2, 3, 4, 5, 6, 7]).reshape((2, 2, 2))
    cube_to_rect = cube.reshape((2, 4))
    np_cube_to_rect = np_cube.reshape((2, 4))
    cube_t_to_rect = cube.transpose((1, 0, 2)).reshape((2, 4))
    np_cube_t_to_rect = np_cube.transpose((1, 0, 2)).reshape((2, 4))

    np_hypercube = np.arange(3 * 5 * 7 * 9).reshape((3, 5, 7, 9))
    hypercube = hl.nd.array(np_hypercube)

    np_shape_zero = np.array([])
    shape_zero = hl.nd.array(np_shape_zero)

    assert_ndarrays_eq(
        (single.reshape(()), np_single.reshape(())),
        (zero_dim.reshape(()), np_zero_dim.reshape(())),
        (zero_dim.reshape((1,)), np_zero_dim.reshape((1,))),
        (a.reshape((6,)), np_a.reshape((6,))),
        (a.reshape((2, 3)), np_a.reshape((2, 3))),
        (a.reshape((3, 2)), np_a.reshape((3, 2))),
        (a.reshape((3, -1)), np_a.reshape((3, -1))),
        (a.reshape((-1, 2)), np_a.reshape((-1, 2))),
        (cube_to_rect, np_cube_to_rect),
        (cube_t_to_rect, np_cube_t_to_rect),
        (hypercube.reshape((5, 7, 9, 3)).reshape((7, 9, 3, 5)),
         np_hypercube.reshape((7, 9, 3, 5))),
        (hypercube.reshape(hl.tuple([5, 7, 9, 3])),
         np_hypercube.reshape((5, 7, 9, 3))),
        (shape_zero.reshape((0, 5)), np_shape_zero.reshape((0, 5))),
        (shape_zero.reshape((-1, 5)), np_shape_zero.reshape((-1, 5))))

    assert hl.eval(hl.null(hl.tndarray(hl.tfloat, 2)).reshape((4, 5))) is None
    assert hl.eval(
        hl.nd.array(hl.range(20)).reshape(
            hl.null(hl.ttuple(hl.tint64, hl.tint64)))) is None

    with pytest.raises(FatalError) as exc:
        hl.eval(hl.literal(np_cube).reshape((-1, -1)))
    assert "more than one -1" in str(exc)

    with pytest.raises(FatalError) as exc:
        hl.eval(hl.literal(np_cube).reshape((20, )))
    assert "requested shape is incompatible with number of elements" in str(
        exc)

    with pytest.raises(FatalError) as exc:
        hl.eval(a.reshape((3, )))
    assert "requested shape is incompatible with number of elements" in str(
        exc)

    with pytest.raises(FatalError) as exc:
        hl.eval(a.reshape(()))
    assert "requested shape is incompatible with number of elements" in str(
        exc)

    with pytest.raises(FatalError) as exc:
        hl.eval(hl.literal(np_cube).reshape((0, 2, 2)))
    assert "requested shape is incompatible with number of elements" in str(
        exc)

    with pytest.raises(FatalError) as exc:
        hl.eval(hl.literal(np_cube).reshape((2, 2, -2)))
    assert "must contain only nonnegative numbers or -1" in str(exc)

    with pytest.raises(FatalError) as exc:
        hl.eval(shape_zero.reshape((0, -1)))
    assert "Can't reshape" in str(exc)
Esempio n. 44
0
def run_meta_split(i):
    print('####################')
    print('Starting split ' + str(i))
    print('####################')
    starttime = datetime.datetime.now()
    pi = ['A'] * int(n_chunks / 2) + ['B'] * int(n_chunks / 2)
    seed_id = int(batch +
                  str(i).zfill(4))  #create a seed_id unique to every split
    randstate = np.random.RandomState(seed_id)  #seed with seed_id

    randstate.shuffle(pi)
    gmt_shuf = gmt.annotate_cols(label=hl.literal(pi)[hl.int32(gmt.col_idx)])

    mt = gmt_shuf.group_cols_by(gmt_shuf.label).aggregate(
        unnorm_meta_beta=hl.agg.sum(gmt_shuf.beta /
                                    gmt_shuf.standard_error**2),
        inv_se2=hl.agg.sum(1 / gmt_shuf.standard_error**2)).key_rows_by('SNP')

    ht = mt.make_table()

    ht = ht.annotate(A_Z=ht['A.unnorm_meta_beta'] / hl.sqrt(ht['A.inv_se2']),
                     B_Z=ht['B.unnorm_meta_beta'] / hl.sqrt(ht['B.inv_se2']))

    ht = ht.drop('A.unnorm_meta_beta', 'B.unnorm_meta_beta', 'A.inv_se2',
                 'B.inv_se2')

    variants = hl.import_table('gs://nbaya/rg_sex/50_snps_alleles_N.tsv.gz',
                               types={'N': hl.tint64})
    variants = variants.key_by('SNP')
    #    mt_all = hl.read_matrix_table('gs://nbaya/split/ukb31063.hm3_variants.gwas_samples_'+phen+'_grouped'+str(n_chunks)+'_batch_'+batch+'.mt') #matrix table containing individual samples. OUTDATED
    ht_all = hl.read_table(
        'gs://nbaya/split/ukb31063.hm3_variants.gwas_samples_' + phen +
        '_grouped' + str(n_chunks) + '_batch_' + batch +
        '.ht')  #hail table containing individual samples
    variants = variants.annotate(N=hl.int32(ht_all.count() / 2))
    variants.show()

    metaA = variants.annotate(Z=ht[variants.SNP].A_Z)
    metaB = variants.annotate(Z=ht[variants.SNP].B_Z)

    #    metaA_path = 'gs://nbaya/rg_sex/'+phen+'_meta_A_n'+str(n_chunks)+'_batch_'+batch+'_s'+str(i)+'.tsv.bgz'
    #    metaB_path = 'gs://nbaya/rg_sex/'+phen+'_meta_B_n'+str(n_chunks)+'_batch_'+batch+'_s'+str(i)+'.tsv.bgz'
    # only used by qc_pos variant set and later hm3 phens
    metaA_path = (f'gs://nbaya/rg_sex/{variant_set}_{phen}_meta_A_n{n_chunks}'
                  f'_batch_{batch}_s{i}.tsv.bgz')
    metaB_path = (f'gs://nbaya/rg_sex/{variant_set}_{phen}_meta_B_n{n_chunks}'
                  f'_batch_{batch}_s{i}.tsv.bgz')
    metaA.export(metaA_path)
    metaB.export(metaB_path)

    endtime = datetime.datetime.now()
    elapsed = endtime - starttime
    print('####################')
    print('Completed iteration ' + str(i))
    print('Files written to:')
    print(metaA_path + '\t' + metaB_path)
    print('Time: {:%H:%M:%S (%Y-%b-%d)}'.format(datetime.datetime.now()))
    print('Iteration time: ' + str(round(elapsed.seconds / 60, 2)) +
          ' minutes')
    print('####################')
Esempio n. 45
0
def multitrait_inf(mt, h2=None, rg=None, cov_matrix=None, seed=None):
    r"""Generates correlated betas for multi-trait infinitesimal simulations for
    any number of phenotypes.

    Parameters
    ----------
    mt : :class:`.MatrixTable`
        MatrixTable for simulated phenotype.
    h2 : :obj:`float` or :obj:`int` or :obj:`list` or :class:`numpy.ndarray`, optional
        Desired SNP-based heritability (:math:`h^2`) of simulated traits.
        If `h2` is ``None``, :math:`h^2` is based on diagonal of `cov_matrix`.
    rg : :obj:`float` or :obj:`int` or :obj:`list` or :class:`numpy.ndarray`, optional
        Desired genetic correlation (:math:`r_g`) between simulated traits.
        If simulating more than two correlated traits, `rg` should be a list
        of :math:`rg` values corresponding to the upper right triangle of the
        covariance matrix. If `rg` is ``None`` and `cov_matrix` is ``None``, :math:`r_g`
        is assumed to be 0 between traits. If `rg` and `cov_matrix` are both
        not None, :math:`r_g` values from `cov_matrix` take precedence.
    cov_matrix : :class:`numpy.ndarray`, optional
        Covariance matrix for traits, **unscaled by :math:`M`**, the number of SNPs.
        Overrides `h2` and `rg` even when `h2` or `rg` are not ``None``.
    seed : :obj:`int`, optional
        Seed for random number generator. If `seed` is ``None``, `seed` is set randomly.

    Returns
    -------
    mt : :class:`.MatrixTable`
        :class:`.MatrixTable` with simulated SNP effects as a row field of arrays.
    rg : :obj:`list`
        Genetic correlation between traits, possibly altered from input `rg` if
        covariance matrix was not positive semi-definite.
    """
    uid = Env.get_uid(base=100)
    h2 = (h2.tolist() if type(h2) is np.ndarray else
          ([h2] if type(h2) is not list else h2))
    rg = rg.tolist() if type(rg) is np.ndarray else (
        [rg] if type(rg) is not list else rg)
    assert h2 != [None] or cov_matrix is not None, 'h2 and cov_matrix cannot both be None'
    assert all(x is None or 0 <= x <= 1 for x in h2), 'h2 values must be between 0 and 1'
    seed = seed if seed is not None else int(str(Env.next_seed())[:8])
    M = mt.count_rows()
    if cov_matrix is not None:
        n_phens = cov_matrix.shape[0]
    else:
        n_phens = len(h2)
        if rg == [None]:
            print(f'Assuming rg=0 for all {n_phens} traits')
            rg = [0] * int((n_phens**2 - n_phens) / 2)
        assert all(-1 <= x <= 1 for x in rg), 'rg values must be between -1 and 1'
        cov, rg = get_cov_matrix(h2, rg)
    cov = (1 / M) * cov
    # seed random state for replicability
    randstate = np.random.RandomState(int(seed))
    betas = randstate.multivariate_normal(mean=np.zeros(n_phens),
                                          cov=cov,
                                          size=[M])
    df = pd.DataFrame([0] * M, columns=['beta'])
    tb = hl.Table.from_pandas(df)
    tb = tb.add_index().key_by('idx')
    tb = tb.annotate(beta=hl.literal(betas.tolist())[hl.int32(tb.idx)])
    mt = mt.add_row_index(name='row_idx' + uid)
    mt = mt.annotate_rows(beta=tb[mt['row_idx' + uid]]['beta'])
    mt = _clean_fields(mt, uid)
    return mt, rg
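
# Hedged usage sketch: simulate two correlated infinitesimal traits on a toy
# genotype MatrixTable (balding_nichols_model is used here only as a
# convenient stand-in for real data).
mt = hl.balding_nichols_model(n_populations=1, n_samples=100, n_variants=1000)
mt_sim, rg_out = multitrait_inf(mt, h2=[0.3, 0.5], rg=[0.6], seed=42)
mt_sim.rows().show(3)  # 'beta' is an array with one effect size per trait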
Esempio n. 46
0
def generate_datasets(doctest_namespace):
    doctest_namespace['hl'] = hl
    doctest_namespace['np'] = np

    ds = hl.import_vcf('data/sample.vcf.bgz')
    ds = ds.sample_rows(0.03)
    ds = ds.annotate_rows(use_as_marker=hl.rand_bool(0.5),
                          panel_maf=0.1,
                          anno1=5,
                          anno2=0,
                          consequence="LOF",
                          gene="A",
                          score=5.0)
    ds = ds.annotate_rows(a_index=1)
    ds = hl.sample_qc(hl.variant_qc(ds))
    ds = ds.annotate_cols(is_case=True,
                          pheno=hl.struct(is_case=hl.rand_bool(0.5),
                                          is_female=hl.rand_bool(0.5),
                                          age=hl.rand_norm(65, 10),
                                          height=hl.rand_norm(70, 10),
                                          blood_pressure=hl.rand_norm(120, 20),
                                          cohort_name="cohort1"),
                          cov=hl.struct(PC1=hl.rand_norm(0, 1)),
                          cov1=hl.rand_norm(0, 1),
                          cov2=hl.rand_norm(0, 1),
                          cohort="SIGMA")
    ds = ds.annotate_globals(
        global_field_1=5,
        global_field_2=10,
        pli={
            'SCN1A': 0.999,
            'SONIC': 0.014
        },
        populations=['AFR', 'EAS', 'EUR', 'SAS', 'AMR', 'HIS'])
    ds = ds.annotate_rows(gene=['TTN'])
    ds = ds.annotate_cols(cohorts=['1kg'], pop='EAS')
    ds = ds.checkpoint('output/example.mt', overwrite=True)

    doctest_namespace['ds'] = ds
    doctest_namespace['dataset'] = ds
    doctest_namespace['dataset2'] = ds.annotate_globals(global_field=5)
    doctest_namespace['dataset_to_union_1'] = ds
    doctest_namespace['dataset_to_union_2'] = ds

    v_metadata = ds.rows().annotate_globals(global_field=5).annotate(
        consequence='SYN')
    doctest_namespace['v_metadata'] = v_metadata

    s_metadata = ds.cols().annotate(pop='AMR', is_case=False, sex='F')
    doctest_namespace['s_metadata'] = s_metadata
    doctest_namespace['cols_to_keep'] = s_metadata
    doctest_namespace['cols_to_remove'] = s_metadata
    doctest_namespace['rows_to_keep'] = v_metadata
    doctest_namespace['rows_to_remove'] = v_metadata

    small_mt = hl.balding_nichols_model(3, 4, 4)
    doctest_namespace['small_mt'] = small_mt.checkpoint('output/small.mt',
                                                        overwrite=True)

    # Table
    table1 = hl.import_table('data/kt_example1.tsv', impute=True, key='ID')
    table1 = table1.annotate_globals(global_field_1=5, global_field_2=10)
    doctest_namespace['table1'] = table1
    doctest_namespace['other_table'] = table1

    table2 = hl.import_table('data/kt_example2.tsv', impute=True, key='ID')
    doctest_namespace['table2'] = table2

    table4 = hl.import_table('data/kt_example4.tsv',
                             impute=True,
                             types={
                                 'B': hl.tstruct(B0=hl.tbool, B1=hl.tstr),
                                 'D': hl.tstruct(cat=hl.tint32, dog=hl.tint32),
                                 'E': hl.tstruct(A=hl.tint32, B=hl.tint32)
                             })
    doctest_namespace['table4'] = table4

    people_table = hl.import_table('data/explode_example.tsv',
                                   delimiter='\\s+',
                                   types={
                                       'Age': hl.tint32,
                                       'Children': hl.tarray(hl.tstr)
                                   },
                                   key='Name')
    doctest_namespace['people_table'] = people_table

    # TDT
    doctest_namespace['tdt_dataset'] = hl.import_vcf('data/tdt_tiny.vcf')

    ds2 = hl.variant_qc(ds)
    doctest_namespace['ds2'] = ds2.select_rows(AF=ds2.variant_qc.AF)

    # Expressions
    doctest_namespace['names'] = hl.literal(['Alice', 'Bob', 'Charlie'])
    doctest_namespace['a1'] = hl.literal([0, 1, 2, 3, 4, 5])
    doctest_namespace['a2'] = hl.literal([1, -1, 1, -1, 1, -1])
    doctest_namespace['t'] = hl.literal(True)
    doctest_namespace['f'] = hl.literal(False)
    doctest_namespace['na'] = hl.null(hl.tbool)
    doctest_namespace['call'] = hl.call(0, 1, phased=False)
    doctest_namespace['a'] = hl.literal([1, 2, 3, 4, 5])
    doctest_namespace['d'] = hl.literal({
        'Alice': 43,
        'Bob': 33,
        'Charles': 44
    })
    doctest_namespace['interval'] = hl.interval(3, 11)
    doctest_namespace['locus_interval'] = hl.parse_locus_interval(
        "1:53242-90543")
    doctest_namespace['locus'] = hl.locus('1', 1034245)
    doctest_namespace['x'] = hl.literal(3)
    doctest_namespace['y'] = hl.literal(4.5)
    doctest_namespace['s1'] = hl.literal({1, 2, 3})
    doctest_namespace['s2'] = hl.literal({1, 3, 5})
    doctest_namespace['s3'] = hl.literal({'Alice', 'Bob', 'Charlie'})
    doctest_namespace['struct'] = hl.struct(a=5, b='Foo')
    doctest_namespace['tup'] = hl.literal(("a", 1, [1, 2, 3]))
    doctest_namespace['s'] = hl.literal('The quick brown fox')
    doctest_namespace['interval2'] = hl.Interval(3, 6)
    doctest_namespace['nd'] = hl._nd.array([[1, 2], [3, 4]])

    # Overview
    doctest_namespace['ht'] = hl.import_table("data/kt_example1.tsv",
                                              impute=True)
    doctest_namespace['mt'] = ds

    gnomad_data = ds.rows()
    doctest_namespace['gnomad_data'] = gnomad_data.select(gnomad_data.info.AF)

    # BGEN
    bgen = hl.import_bgen('data/example.8bits.bgen',
                          entry_fields=['GT', 'GP', 'dosage'])
    doctest_namespace['variants_table'] = bgen.rows()

    burden_ds = hl.import_vcf('data/example_burden.vcf')
    burden_kt = hl.import_table('data/example_burden.tsv',
                                key='Sample',
                                impute=True)
    burden_ds = burden_ds.annotate_cols(burden=burden_kt[burden_ds.s])
    burden_ds = burden_ds.annotate_rows(
        weight=hl.float64(burden_ds.locus.position))
    burden_ds = hl.variant_qc(burden_ds)
    genekt = hl.import_locus_intervals('data/gene.interval_list')
    burden_ds = burden_ds.annotate_rows(gene=genekt[burden_ds.locus])
    burden_ds = burden_ds.checkpoint('output/example_burden.vds',
                                     overwrite=True)
    doctest_namespace['burden_ds'] = burden_ds

    ld_score_one_pheno_sumstats = hl.import_table(
        'data/ld_score_regression.one_pheno.sumstats.tsv',
        types={
            'locus': hl.tlocus('GRCh37'),
            'alleles': hl.tarray(hl.tstr),
            'chi_squared': hl.tfloat64,
            'n': hl.tint32,
            'ld_score': hl.tfloat64,
            'phenotype': hl.tstr,
            'chi_squared_50_irnt': hl.tfloat64,
            'n_50_irnt': hl.tint32,
            'chi_squared_20160': hl.tfloat64,
            'n_20160': hl.tint32
        },
        key=['locus', 'alleles'])
    doctest_namespace['ld_score_one_pheno_sumstats'] = ld_score_one_pheno_sumstats

    mt = hl.import_matrix_table(
        'data/ld_score_regression.all_phenos.sumstats.tsv',
        row_fields={
            'locus': hl.tstr,
            'alleles': hl.tstr,
            'ld_score': hl.tfloat64
        },
        entry_type=hl.tstr)
    mt = mt.key_cols_by(phenotype=mt.col_id)
    mt = mt.key_rows_by(locus=hl.parse_locus(mt.locus),
                        alleles=mt.alleles.split(','))
    mt = mt.drop('row_id', 'col_id')
    mt = mt.annotate_entries(x=mt.x.split(","))
    mt = mt.transmute_entries(chi_squared=hl.float64(mt.x[0]),
                              n=hl.int32(mt.x[1]))
    mt = mt.annotate_rows(ld_score=hl.float64(mt.ld_score))
    doctest_namespace['ld_score_all_phenos_sumstats'] = mt

    print("finished setting up doctest...")
Example n. 47
def calculate_phenotypes(mt,
                         genotype,
                         beta,
                         h2,
                         popstrat=None,
                         popstrat_var=None,
                         exact_h2=False):
    r"""Calculates phenotypes by multiplying genotypes and betas.

    Parameters
    ----------
    mt : :class:`.MatrixTable`
        :class:`.MatrixTable` with all relevant fields passed as parameters.
    genotype : :class:`.Expression` or :class:`.CallExpression`
        Entry field of genotypes.
    beta : :class:`.Expression`
        Row field of SNP effects.
    h2 : :obj:`float` or :obj:`int` or :obj:`list` or :class:`numpy.ndarray`
        SNP-based heritability (:math:`h^2`) of the simulated trait. Can only
        be ``None`` when running the annotation-informed model.
    popstrat : :class:`.Expression`, optional
        Column field containing population stratification term.
    popstrat_var : :obj:`float` or :obj:`int`
        Variance of population stratification term.
    exact_h2 : :obj:`bool`
        Whether to make the ratio of the variance of the genetic component of
        the phenotype to the total phenotypic variance exactly `h2`. If
        `False`, the ratio is `h2` only in expectation; observed h2 will be
        close to expected h2 for large-scale simulations.

    Returns
    -------
    :class:`.MatrixTable`
        :class:`.MatrixTable` with simulated phenotype as column field.
    """
    print('calculating phenotype')
    if isinstance(h2, np.ndarray):
        h2 = h2.tolist()
    elif not isinstance(h2, list):
        h2 = [h2]
    assert popstrat_var is None or popstrat_var >= 0, \
        'popstrat_var must be non-negative'
    uid = Env.get_uid(base=100)
    mt = annotate_all(
        mt=mt,
        row_exprs={'beta_' + uid: beta},
        col_exprs={} if popstrat is None else {'popstrat_' + uid: popstrat},
        entry_exprs={
            'gt_' + uid:
            genotype.n_alt_alleles()
            if genotype.dtype is hl.dtype('call') else genotype
        })
    mt = mt.filter_rows(hl.agg.stats(mt['gt_' + uid]).stdev > 0)
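    # normalize genotypes to mean 0 and variance 1 per SNP (adds the norm_gt
    # entry field); monomorphic SNPs were filtered above so stdev is never zero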
    mt = normalize_genotypes(mt['gt_' + uid])
    if mt['beta_' + uid].dtype == hl.dtype('array<float64>'):  # more than one trait
        if exact_h2:
            raise ValueError(
                'exact_h2=True not supported for multitrait simulations')
        else:
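            # one simulated phenotype per trait: aggregate each element of
            # the beta array against the normalized genotypes, then add
            # trait-specific environmental noise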
            mt = mt.annotate_cols(y_no_noise=hl.agg.array_agg(
                lambda beta: hl.agg.sum(beta * mt['norm_gt']),
                mt['beta_' + uid]))
            mt = mt.annotate_cols(
                y=mt.y_no_noise +
                hl.literal(h2).map(lambda x: hl.rand_norm(0, hl.sqrt(1 - x))))
    else:
        if exact_h2 and min([h2[0], 1 - h2[0]]) != 0:
            print('exact h2')
            mt = mt.annotate_cols(**{
                'y_no_noise_' + uid:
                hl.agg.sum(mt['beta_' + uid] * mt['norm_gt'])
            })
            y_no_noise_stdev = mt.aggregate_cols(
                hl.agg.stats(mt['y_no_noise_' + uid]).stdev)
            # normalize genetic component of phenotype to have variance of
            # exactly h2
            mt = mt.annotate_cols(y_no_noise=hl.sqrt(h2[0]) *
                                  mt['y_no_noise_' + uid] / y_no_noise_stdev)
            mt = mt.annotate_cols(
                **{'noise_' + uid: hl.rand_norm(0, hl.sqrt(1 - h2[0]))})
            noise_stdev = mt.aggregate_cols(
                hl.agg.stats(mt['noise_' + uid]).stdev)
            mt = mt.annotate_cols(noise=hl.sqrt(1 - h2[0]) *
                                  mt['noise_' + uid] / noise_stdev)
            mt = mt.annotate_cols(y=mt.y_no_noise + mt.noise)
        else:
            mt = mt.annotate_cols(y_no_noise=hl.agg.sum(mt['beta_' + uid] *
                                                        mt['norm_gt']))
            mt = mt.annotate_cols(y=mt.y_no_noise +
                                  hl.rand_norm(0, hl.sqrt(1 - h2[0])))
    if popstrat is not None:
        var_factor = 1 if popstrat_var is None else (
            popstrat_var ** 0.5 /
            mt.aggregate_cols(hl.agg.stats(mt['popstrat_' + uid])).stdev)
        mt = mt.rename({'y': 'y_no_popstrat'})
        mt = mt.annotate_cols(y=mt.y_no_popstrat +
                              mt['popstrat_' + uid] * var_factor)
    mt = _clean_fields(mt, uid)
    return mt
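
A minimal usage sketch (the field names `GT` and `beta` and the value
`h2=0.5` are hypothetical; any MatrixTable with a call entry field and a row
field of effect sizes would do):

    # hypothetical fields: mt.GT (call entries), mt.beta (simulated effects)
    sim = calculate_phenotypes(mt, genotype=mt.GT, beta=mt.beta, h2=0.5)
    sim.cols().show()  # the simulated phenotype `y` is a column field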
Example n. 48
File: plots.py Project: jigold/hail
def histogram2d(x, y, bins=40, range=None,
                 title=None, width=600, height=600, font_size='7pt',
                 colors=bokeh.palettes.all_palettes['Blues'][7][::-1]):
    """Plot a two-dimensional histogram.

    ``x`` and ``y`` must each be a :class:`NumericExpression` from the same :class:`Table`.

    If ``range`` (or either of its components) is not provided, the function will do a pass through
    the data to determine the min and max of each variable.

    Examples
    --------

    >>> ht = hail.utils.range_table(1000).annotate(x=hail.rand_norm(), y=hail.rand_norm())
    >>> p_hist = hail.plot.histogram2d(ht.x, ht.y)

    >>> ht = hail.utils.range_table(1000).annotate(x=hail.rand_norm(), y=hail.rand_norm())
    >>> p_hist = hail.plot.histogram2d(ht.x, ht.y, bins=10, range=((0, 1), None))

    Parameters
    ----------
    x : :class:`.NumericExpression`
        Expression for x-axis (from a Hail table).
    y : :class:`.NumericExpression`
        Expression for y-axis (from the same Hail table as ``x``).
    bins : int or [int, int]
        The bin specification:
        -   If int, the number of bins for the two dimensions (nx = ny = bins).
        -   If [int, int], the number of bins in each dimension (nx, ny = bins).
        The default value is 40.
    range : None or ((float, float), (float, float))
        The leftmost and rightmost edges of the bins along each dimension:
        ((xmin, xmax), (ymin, ymax)). All values outside of this range will be considered outliers
        and not tallied in the histogram. If this value is None, or either of the inner pairs is None,
        the range will be computed from the data.
    title : str
        Title of the plot.
    width : int
        Plot width (default 600px).
    height : int
        Plot height (default 600px).
    font_size : str
        String of font size in points (default '7pt').
    colors : List[str]
        List of colors (hex codes, or strings as described
        `here <https://bokeh.pydata.org/en/latest/docs/reference/colors.html>`__). Compatible with one of the many
        built-in palettes available `here <https://bokeh.pydata.org/en/latest/docs/reference/palettes.html>`__.

    Returns
    -------
    :class:`bokeh.plotting.figure.Figure`
    """
    source = x._indices.source
    y_source = y._indices.source

    if source is None or y_source is None:
        raise ValueError("histogram_2d expects two expressions of 'Table', found scalar expression")
    if isinstance(source, hail.MatrixTable):
        raise ValueError("histogram_2d requires source to be Table, not MatrixTable")
    if source != y_source:
        raise ValueError(f"histogram_2d expects two expressions from the same 'Table', found {source} and {y_source}")
    check_row_indexed('histogram_2d', x)
    check_row_indexed('histogram_2d', y)
    if isinstance(bins, int):
        x_bins = y_bins = bins
    else:
        x_bins, y_bins = bins
    if range is None:
        x_range = y_range = None
    else:
        x_range, y_range = range
    if x_range is None or y_range is None:
        warnings.warn('At least one range was not defined in histogram_2d. Doing two passes...')
        ranges = source.aggregate(hail.struct(x_stats=hail.agg.stats(x),
                                              y_stats=hail.agg.stats(y)))
        if x_range is None:
            x_range = (ranges.x_stats.min, ranges.x_stats.max)
        if y_range is None:
            y_range = (ranges.y_stats.min, ranges.y_stats.max)
    else:
        warnings.warn('If x_range or y_range are specified in histogram_2d, and there are points '
                      'outside of these ranges, they will not be plotted')
    x_range = list(map(float, x_range))
    y_range = list(map(float, y_range))
    x_spacing = (x_range[1] - x_range[0]) / x_bins
    y_spacing = (y_range[1] - y_range[0]) / y_bins

    def frange(start, stop, step):
        from itertools import count, takewhile
        return takewhile(lambda x: x <= stop, count(start, step))

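    # bin edges are generated in descending order so that `find` below picks
    # the largest edge <= the value, i.e. the left edge of the point's bin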
    x_levels = hail.literal(list(frange(x_range[0], x_range[1], x_spacing))[::-1])
    y_levels = hail.literal(list(frange(y_range[0], y_range[1], y_spacing))[::-1])

    grouped_ht = source.group_by(
        x=hail.str(x_levels.find(lambda w: x >= w)),
        y=hail.str(y_levels.find(lambda w: y >= w))
    ).aggregate(c=hail.agg.count())
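    # drop points outside the range: below-range points get missing labels,
    # and points at or above the top edge get the top-edge label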
    data = grouped_ht.filter(hail.is_defined(grouped_ht.x) & (grouped_ht.x != str(x_range[1])) &
                             hail.is_defined(grouped_ht.y) & (grouped_ht.y != str(y_range[1]))).to_pandas()

    mapper = LinearColorMapper(palette=colors, low=data.c.min(), high=data.c.max())

    x_axis = sorted(set(data.x), key=lambda z: float(z))
    y_axis = sorted(set(data.y), key=lambda z: float(z))
    p = figure(title=title,
               x_range=x_axis, y_range=y_axis,
               x_axis_location="above", plot_width=width, plot_height=height,
               tools="hover,save,pan,box_zoom,reset,wheel_zoom", toolbar_location='below')

    p.grid.grid_line_color = None
    p.axis.axis_line_color = None
    p.axis.major_tick_line_color = None
    p.axis.major_label_standoff = 0
    p.axis.major_label_text_font_size = font_size
    import math
    p.xaxis.major_label_orientation = math.pi / 3

    p.rect(x='x', y='y', width=1, height=1,
           source=data,
           fill_color={'field': 'c', 'transform': mapper},
           line_color=None)

    color_bar = ColorBar(color_mapper=mapper, major_label_text_font_size=font_size,
                         ticker=BasicTicker(desired_num_ticks=6),
                         label_standoff=6, border_line_color=None, location=(0, 0))
    p.add_layout(color_bar, 'right')

    def set_font_size(p, font_size: str = '12pt'):
        """Set most of the font sizes in a bokeh figure

        Parameters
        ----------
        p : :class:`bokeh.plotting.figure.Figure`
            Input figure.
        font_size : str
            String of font size in points (e.g. '12pt').

        Returns
        -------
        :class:`bokeh.plotting.figure.Figure`
        """
        p.legend.label_text_font_size = font_size
        p.xaxis.axis_label_text_font_size = font_size
        p.yaxis.axis_label_text_font_size = font_size
        p.xaxis.major_label_text_font_size = font_size
        p.yaxis.major_label_text_font_size = font_size
        if hasattr(p.title, 'text_font_size'):
            p.title.text_font_size = font_size
        if hasattr(p.xaxis, 'group_text_font_size'):
            p.xaxis.group_text_font_size = font_size
        return p

    p.select_one(HoverTool).tooltips = [('x', '@x'), ('y', '@y'), ('count', '@c')]
    p = set_font_size(p, font_size)
    return p
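
To render the returned figure, the usual bokeh machinery applies (a sketch
continuing the docstring example above; the output path is arbitrary):

    from bokeh.io import output_file, show
    output_file('histogram2d.html')  # any writable path
    show(p_hist)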