Example #1
0
def copmute_ldscore(ht, bm_ld, n, radius, out_name, overwrite):
    """Compute LD scores from an LD block matrix and annotate them onto ``ht``.

    ``bm_ld`` is squared and adjusted using ``n``, restricted to the locus
    windows of size ``radius`` around ``ht.locus`` and to its triangle. Row
    sums, column sums and the diagonal are then combined into one score per
    row, exported to a temporary TSV, re-imported as a table and joined back
    onto ``ht`` by row index as the ``ld_score`` field.

    Returns the annotated table after checkpointing it at ``out_name``
    (``overwrite`` is forwarded to ``checkpoint``).
    """
    squared = bm_ld**2
    scale = (n - 1.0) / (n - 2.0)
    shift = 1.0 / (n - 2.0)
    r2_adj = scale * squared - shift

    # This is required, as the squaring/multiplication densifies, so this re-sparsifies.
    windows = hl.linalg.utils.locus_windows(ht.locus, radius, _localize=False)
    r2_adj = r2_adj._sparsify_row_intervals_expr(windows, blocks_only=False)
    r2_adj = r2_adj.sparsify_triangle()
    r2_adj = checkpoint_tmp(r2_adj)

    # Note that the original ld matrix is triangular, so each variant's score
    # is its row sum plus its column sum, minus the diagonal entry that both
    # sums include.
    row_part = checkpoint_tmp(r2_adj.sum(axis=0)).T
    col_part = checkpoint_tmp(r2_adj.sum(axis=1))
    diag_part = checkpoint_tmp(r2_adj.diagonal()).T
    scores_bm = row_part + col_part - diag_part

    bm_path = new_temp_file()
    tsv_path = new_gs_temp_path()
    scores_bm.write(bm_path, force_row_major=True)
    BlockMatrix.export(bm_path, tsv_path)

    # Round-trip through a TSV to turn the score vector into an indexed table.
    scores = (hl.import_table(tsv_path, no_header=True, impute=True)
              .add_index()
              .rename({'f0': 'ld_score'})
              .key_by('idx'))

    indexed = ht.add_index()
    annotated = indexed.annotate(**scores[indexed.idx]).drop('idx')
    return annotated.checkpoint(out_name, overwrite)
Example #2
0
def compute_and_annotate_ld_score(ht, r2_adj, radius, out_name, overwrite):
    """Compute LD scores from an adjusted r2 block matrix and write them out.

    ``r2_adj`` is filtered to the locus windows of size ``radius`` around
    ``ht.locus``. Its row sums and column sums plus one are exported to a
    temporary TSV, re-imported as a table and joined onto ``ht`` through the
    ``new_idx`` row field, which ``ht`` must already carry. Rows with a
    defined ``ld_score`` are written to ``out_name``; nothing is returned.
    """
    windows = hl.linalg.utils.locus_windows(ht.locus, radius, _localize=False)

    # Lifted directly from https://github.com/hail-is/hail/blob/555e02d6c792263db2c3ed97db8002b489e2dacb/hail/python/hail/methods/statgen.py#L2595
    # for the time being, until efficient BlockMatrix filtering gets an easier interface
    # This is required, as the squaring/multiplication densifies, so this re-sparsifies.
    windows_ir = Env.backend()._to_java_ir(windows._ir)
    r2_adj = BlockMatrix._from_java(
        r2_adj._jbm.filterRowIntervalsIR(windows_ir, False))

    row_part = r2_adj.sum(axis=0).T
    col_part = r2_adj.sum(axis=1)
    scores_bm = row_part + col_part + 1

    bm_path = new_temp_file()
    tsv_path = new_temp_file()
    scores_bm.write(bm_path, force_row_major=True)
    BlockMatrix.export(bm_path, tsv_path)

    # Round-trip through a TSV to turn the score vector into an indexed table.
    scores = (hl.import_table(tsv_path, no_header=True, impute=True)
              .add_index()
              .rename({'f0': 'ld_score'})
              .key_by('idx'))

    annotated = ht.annotate(**scores[ht.new_idx]).select_globals()
    with_scores = annotated.filter(hl.is_defined(annotated.ld_score))
    with_scores.write(out_name, overwrite)
Example #3
0
def generate_ld_scores_from_ld_matrix(pop_data,
                                      data_type,
                                      min_frequency=0.01,
                                      call_rate_cutoff=0.8,
                                      adj: bool = False,
                                      radius: int = 1000000,
                                      overwrite=False):
    """Compute and write LD scores for each population in ``pop_data``.

    ``pop_data`` is converted with ``dict()``; each value maps a population
    name to its sample count ``n``. Populations ``nfe``, ``fin`` and ``asj``
    are skipped. For every other population the LD index table is filtered by
    allele frequency (``min_frequency``) and call rate (``call_rate_cutoff``),
    the stored LD matrix is restricted to the surviving indices, squared and
    adjusted using ``n``, then windowed by ``radius`` around each locus. Row
    sums and column sums plus one form the score, which is written to
    ``ld_scores_path(data_type, pop, adj)``.
    """
    # This function required a decent number of high-mem machines (with an SSD for good measure) to complete the AFR
    # For the rest, on 20 n1-standard-8's, 1h15m to export block matrix, 15 mins to compute LD scores per population (~$150 total)
    skipped_pops = ('nfe', 'fin', 'asj')

    for pops in dict(pop_data).values():
        for pop, n in pops.items():
            if pop in skipped_pops:
                continue

            ht = hl.read_table(ld_index_path(data_type, pop, adj=adj))
            allele_freq = ht.pop_freq.AF
            keep = ((allele_freq >= min_frequency)
                    & (allele_freq <= 1 - min_frequency)
                    & (ht.pop_freq.AN / n >= 2 * call_rate_cutoff))
            ht = ht.filter(keep).add_index(name='new_idx')

            # Original matrix indices of the variants that survived filtering.
            kept_indices = ht.idx.collect()

            ld_bm = BlockMatrix.read(
                ld_matrix_path(data_type,
                               pop,
                               min_frequency >= COMMON_FREQ,
                               adj=adj))
            squared = ld_bm.filter(kept_indices, kept_indices)**2
            scale = (n - 1.0) / (n - 2.0)
            shift = 1.0 / (n - 2.0)
            r2_adj = scale * squared - shift

            windows = hl.linalg.utils.locus_windows(ht.locus,
                                                    radius,
                                                    _localize=False)

            # Lifted directly from https://github.com/hail-is/hail/blob/555e02d6c792263db2c3ed97db8002b489e2dacb/hail/python/hail/methods/statgen.py#L2595
            # for the time being, until efficient BlockMatrix filtering gets an easier interface
            windows_ir = Env.backend()._to_java_ir(windows._ir)
            r2_adj = BlockMatrix._from_java(
                r2_adj._jbm.filterRowIntervalsIR(windows_ir, False))

            row_part = r2_adj.sum(axis=0).T
            col_part = r2_adj.sum(axis=1)
            scores_bm = row_part + col_part + 1

            bm_path = new_temp_file()
            tsv_path = new_temp_file()
            scores_bm.write(bm_path, force_row_major=True)
            BlockMatrix.export(bm_path, tsv_path)

            # Round-trip through a TSV to get an indexed table of scores.
            scores = (hl.import_table(tsv_path, no_header=True, impute=True)
                      .add_index()
                      .rename({'f0': 'ld_score'})
                      .key_by('idx'))

            annotated = ht.annotate(**scores[ht.new_idx]).select_globals()
            with_scores = annotated.filter(hl.is_defined(annotated.ld_score))
            with_scores.write(ld_scores_path(data_type, pop, adj), overwrite)
Example #4
0
def ld_score(entry_expr,
             locus_expr,
             radius,
             coord_expr=None,
             annotation_exprs=None,
             block_size=None) -> Table:
    """Calculate LD scores.

    Example
    -------

    >>> # Load genetic data into MatrixTable
    >>> mt = hl.import_plink(bed='data/ldsc.bed',
    ...                      bim='data/ldsc.bim',
    ...                      fam='data/ldsc.fam')

    >>> # Create locus-keyed Table with numeric variant annotations
    >>> ht = hl.import_table('data/ldsc.annot',
    ...                      types={'BP': hl.tint,
    ...                             'binary': hl.tfloat,
    ...                             'continuous': hl.tfloat})
    >>> ht = ht.annotate(locus=hl.locus(ht.CHR, ht.BP))
    >>> ht = ht.key_by('locus')

    >>> # Annotate MatrixTable with external annotations
    >>> mt = mt.annotate_rows(binary_annotation=ht[mt.locus].binary,
    ...                       continuous_annotation=ht[mt.locus].continuous)

    >>> # Calculate LD scores using centimorgan coordinates
    >>> ht_scores = hl.experimental.ld_score(entry_expr=mt.GT.n_alt_alleles(),
    ...                                      locus_expr=mt.locus,
    ...                                      radius=1.0,
    ...                                      coord_expr=mt.cm_position,
    ...                                      annotation_exprs=[mt.binary_annotation,
    ...                                                        mt.continuous_annotation])

    >>> # Show results
    >>> ht_scores.show(3)

    .. code-block:: text

        +---------------+-------------------+-----------------------+-------------+
        | locus         | binary_annotation | continuous_annotation |  univariate |
        +---------------+-------------------+-----------------------+-------------+
        | locus<GRCh37> |           float64 |               float64 |     float64 |
        +---------------+-------------------+-----------------------+-------------+
        | 20:82079      |       1.15183e+00 |           7.30145e+01 | 1.60117e+00 |
        | 20:103517     |       2.04604e+00 |           2.75392e+02 | 4.69239e+00 |
        | 20:108286     |       2.06585e+00 |           2.86453e+02 | 5.00124e+00 |
        +---------------+-------------------+-----------------------+-------------+


    Warning
    -------
        :func:`.ld_score` will fail if ``entry_expr`` results in any missing
        values. The special float value ``nan`` is not considered a
        missing value.

    **Further reading**

    For more in-depth discussion of LD scores, see:

    - `LD Score regression distinguishes confounding from polygenicity in genome-wide association studies (Bulik-Sullivan et al, 2015) <https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4495769/>`__
    - `Partitioning heritability by functional annotation using genome-wide association summary statistics (Finucane et al, 2015) <https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4626285/>`__

    Notes
    -----

    All of `entry_expr`, `locus_expr`, `coord_expr` (if specified), and
    `annotation_exprs` (if specified) must originate from the same
    MatrixTable; a :class:`ValueError` is raised otherwise.


    Parameters
    ----------
    entry_expr : :class:`.NumericExpression`
        Expression for entries of genotype matrix
        (e.g. ``mt.GT.n_alt_alleles()``).
    locus_expr : :class:`.LocusExpression`
        Row-indexed locus expression.
    radius : :obj:`int` or :obj:`float`
        Radius of window for row values (in units of `coord_expr` if set,
        otherwise in units of basepairs).
    coord_expr : :class:`.Float64Expression`, optional
        Row-indexed numeric expression for the row value used to window
        variants. By default, the row value is given by the locus
        position.
    annotation_exprs : :class:`.NumericExpression` or
                       :obj:`list` of :class:`.NumericExpression`, optional
        Annotation expression(s) to partition LD scores. Univariate
        annotation will always be included and does not need to be
        specified.
    block_size : :obj:`int`, optional
        Block size. Default given by :meth:`.BlockMatrix.default_block_size`.

    Returns
    -------
    :class:`.Table`
        Table keyed by `locus_expr` with LD scores for each variant and
        `annotation_expr`. The function will always return LD scores for
        the univariate (all SNPs) annotation."""

    mt = entry_expr._indices.source
    locus_source = locus_expr._indices.source
    coord_source = (locus_source if coord_expr is None
                    else coord_expr._indices.source)

    # Every expression handed in has to be derived from the same MatrixTable.
    same_source = [mt == locus_source, mt == coord_source]
    if annotation_exprs:
        same_source += [mt == expr._indices.source
                        for expr in wrap_to_list(annotation_exprs)]

    if not all(same_source):
        raise ValueError("""ld_score: entry_expr, locus_expr, coord_expr
                            (if specified), and annotation_exprs (if
                            specified) must come from same MatrixTable.""")

    n = mt.count_cols()
    squared = hl.row_correlation(entry_expr, block_size) ** 2
    scale = (n - 1.0) / (n - 2.0)
    shift = 1.0 / (n - 2.0)
    r2_adj = scale * squared - shift

    starts, stops = hl.linalg.utils.locus_windows(locus_expr,
                                                  radius,
                                                  coord_expr)
    windowed = r2_adj.sparsify_row_intervals(starts, stops)

    # Write the windowed matrix out and read it back before using it below.
    windowed_path = new_temp_file()
    windowed.write(windowed_path)
    windowed = BlockMatrix.read(windowed_path)

    if annotation_exprs:
        annot_ht = mt.select_rows(*wrap_to_list(annotation_exprs)).rows()
        annot_ht = annot_ht.annotate(univariate=hl.literal(1.0))
        value_fields = [field for field in annot_ht.row
                        if field not in annot_ht.key]

        # Reshape to long form (one row per variant/annotation pair) and
        # pivot into a variants-by-annotations matrix.
        long_tables = [
            annot_ht.annotate(name=hl.str(field),
                              value=hl.float(annot_ht[field]))
                    .select('name', 'value')
            for field in value_fields]
        long_ht = hl.Table.union(*long_tables)
        annot_mt = long_ht.to_matrix_table(row_key=list(long_ht.key),
                                           col_key=['name'])

        cols = annot_mt.key_cols_by()['name'].collect()

        annot_path = new_temp_file()
        BlockMatrix.write_from_entry_expr(annot_mt.value, annot_path)
        l2 = windowed @ BlockMatrix.read(annot_path)
    else:
        cols = ['univariate']
        l2 = windowed.sum(axis=1)

    bm_path = new_temp_file()
    tsv_path = new_temp_file()
    l2.write(bm_path, force_row_major=True)
    BlockMatrix.export(bm_path, tsv_path)

    # Imported columns are named f0, f1, ...; map them back to the
    # annotation names in column order.
    field_names = {'f{:}'.format(i): name for i, name in enumerate(cols)}
    ht_scores = (hl.import_table(tsv_path, no_header=True, impute=True)
                 .add_index()
                 .key_by('idx')
                 .rename(field_names))

    ht = mt.select_rows(__locus=locus_expr).rows().add_index()
    ht = ht.annotate(**ht_scores[ht.idx])
    ht = ht.key_by('__locus')
    score_fields = [field for field in ht_scores.row
                    if field not in ht_scores.key]
    ht = ht.select(*score_fields)
    return ht.rename({'__locus': 'locus'})
Example #5
0
import hail as hl
from hail.linalg import BlockMatrix
from os import path
import sys
import pandas as pd

# Command-line arguments: chromosome identifier and group identifier.
chr_id, group_id = sys.argv[1], sys.argv[2]
print(f'{chr_id} {group_id}')

# Index map for the LD reference; compare chr/group as strings so they match
# the command-line arguments.
idx_comb = pd.read_csv("mapLDref.tsv.gz", sep="\t")
idx_comb = idx_comb.astype({'chr': str, 'group': str})

idx_comb_chr = idx_comb.loc[idx_comb['chr'] == chr_id]
chridx = idx_comb_chr.loc[idx_comb_chr['group'] == group_id, 'idx'].tolist()

ext = f'chr{chr_id}.{group_id}'
## Load data
bm = BlockMatrix.read('s3a://pan-ukb-us-east-1/ld_release/UKBB.EUR.ldadj.bm')

# Keep only the rows/columns of the selected chromosome and group, then
# persist the sub-matrix and export it as a tab-delimited file.
bmchr = bm.filter(chridx, chridx)
bm_out = ext + '.bm'
bmchr.write(bm_out, force_row_major=True)
BlockMatrix.export(bm_out, ext + '.csv.bgz', delimiter='\t')
Example #6
0
def ld_score(entry_expr,
             locus_expr,
             radius,
             coord_expr=None,
             annotation_exprs=None,
             block_size=None) -> Table:
    """Calculate LD scores.

    Example
    -------

    >>> # Load genetic data into MatrixTable
    >>> mt = hl.import_plink(bed='data/ldsc.bed',
    ...                      bim='data/ldsc.bim',
    ...                      fam='data/ldsc.fam')

    >>> # Create locus-keyed Table with numeric variant annotations
    >>> ht = hl.import_table('data/ldsc.annot',
    ...                      types={'BP': hl.tint,
    ...                             'binary': hl.tfloat,
    ...                             'continuous': hl.tfloat})
    >>> ht = ht.annotate(locus=hl.locus(ht.CHR, ht.BP))
    >>> ht = ht.key_by('locus')

    >>> # Annotate MatrixTable with external annotations
    >>> mt = mt.annotate_rows(binary_annotation=ht[mt.locus].binary,
    ...                       continuous_annotation=ht[mt.locus].continuous)

    >>> # Calculate LD scores using centimorgan coordinates
    >>> ht_scores = hl.experimental.ld_score(entry_expr=mt.GT.n_alt_alleles(),
    ...                                      locus_expr=mt.locus,
    ...                                      radius=1.0,
    ...                                      coord_expr=mt.cm_position,
    ...                                      annotation_exprs=[mt.binary_annotation,
    ...                                                        mt.continuous_annotation])

    >>> # Show results
    >>> ht_scores.show(3)

    .. code-block:: text

        +---------------+-------------------+-----------------------+-------------+
        | locus         | binary_annotation | continuous_annotation |  univariate |
        +---------------+-------------------+-----------------------+-------------+
        | locus<GRCh37> |           float64 |               float64 |     float64 |
        +---------------+-------------------+-----------------------+-------------+
        | 20:82079      |       1.15183e+00 |           7.30145e+01 | 1.60117e+00 |
        | 20:103517     |       2.04604e+00 |           2.75392e+02 | 4.69239e+00 |
        | 20:108286     |       2.06585e+00 |           2.86453e+02 | 5.00124e+00 |
        +---------------+-------------------+-----------------------+-------------+


    Warning
    -------
        :func:`.ld_score` will fail if ``entry_expr`` results in any missing
        values. The special float value ``nan`` is not considered a
        missing value.

    **Further reading**

    For more in-depth discussion of LD scores, see:

    - `LD Score regression distinguishes confounding from polygenicity in genome-wide association studies (Bulik-Sullivan et al, 2015) <https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4495769/>`__
    - `Partitioning heritability by functional annotation using genome-wide association summary statistics (Finucane et al, 2015) <https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4626285/>`__

    Notes
    -----

    All of `entry_expr`, `locus_expr`, `coord_expr` (if specified), and
    `annotation_exprs` (if specified) must originate from the same
    MatrixTable; a :class:`ValueError` is raised otherwise.


    Parameters
    ----------
    entry_expr : :class:`.NumericExpression`
        Expression for entries of genotype matrix
        (e.g. ``mt.GT.n_alt_alleles()``).
    locus_expr : :class:`.LocusExpression`
        Row-indexed locus expression.
    radius : :obj:`int` or :obj:`float`
        Radius of window for row values (in units of `coord_expr` if set,
        otherwise in units of basepairs).
    coord_expr : :class:`.Float64Expression`, optional
        Row-indexed numeric expression for the row value used to window
        variants. By default, the row value is given by the locus
        position.
    annotation_exprs : :class:`.NumericExpression` or
                       :obj:`list` of :class:`.NumericExpression`, optional
        Annotation expression(s) to partition LD scores. Univariate
        annotation will always be included and does not need to be
        specified.
    block_size : :obj:`int`, optional
        Block size. Default given by :meth:`.BlockMatrix.default_block_size`.

    Returns
    -------
    :class:`.Table`
        Table keyed by `locus_expr` with LD scores for each variant and
        `annotation_expr`. The function will always return LD scores for
        the univariate (all SNPs) annotation."""

    mt = entry_expr._indices.source
    locus_source = locus_expr._indices.source
    coord_source = (locus_source if coord_expr is None
                    else coord_expr._indices.source)

    # Every expression handed in has to be derived from the same MatrixTable.
    same_source = [mt == locus_source, mt == coord_source]
    if annotation_exprs:
        same_source += [mt == expr._indices.source
                        for expr in wrap_to_list(annotation_exprs)]

    if not all(same_source):
        raise ValueError("""ld_score: entry_expr, locus_expr, coord_expr
                            (if specified), and annotation_exprs (if
                            specified) must come from same MatrixTable.""")

    n = mt.count_cols()
    squared = hl.row_correlation(entry_expr, block_size) ** 2
    scale = (n - 1.0) / (n - 2.0)
    shift = 1.0 / (n - 2.0)
    r2_adj = scale * squared - shift

    starts, stops = hl.linalg.utils.locus_windows(locus_expr,
                                                  radius,
                                                  coord_expr)
    windowed = r2_adj.sparsify_row_intervals(starts, stops)

    # Write the windowed matrix out and read it back before using it below.
    windowed_path = new_temp_file()
    windowed.write(windowed_path)
    windowed = BlockMatrix.read(windowed_path)

    if annotation_exprs:
        annot_ht = mt.select_rows(*wrap_to_list(annotation_exprs)).rows()
        annot_ht = annot_ht.annotate(univariate=hl.literal(1.0))
        value_fields = [field for field in annot_ht.row
                        if field not in annot_ht.key]

        # Reshape to long form (one row per variant/annotation pair) and
        # pivot into a variants-by-annotations matrix.
        long_tables = [
            annot_ht.annotate(name=hl.str(field),
                              value=hl.float(annot_ht[field]))
                    .select('name', 'value')
            for field in value_fields]
        long_ht = hl.Table.union(*long_tables)
        annot_mt = long_ht.to_matrix_table(row_key=list(long_ht.key),
                                           col_key=['name'])

        cols = annot_mt.key_cols_by()['name'].collect()

        annot_path = new_temp_file()
        BlockMatrix.write_from_entry_expr(annot_mt.value, annot_path)
        l2 = windowed @ BlockMatrix.read(annot_path)
    else:
        cols = ['univariate']
        l2 = windowed.sum(axis=1)

    bm_path = new_temp_file()
    tsv_path = new_temp_file()
    l2.write(bm_path, force_row_major=True)
    BlockMatrix.export(bm_path, tsv_path)

    # Imported columns are named f0, f1, ...; map them back to the
    # annotation names in column order.
    field_names = {'f{:}'.format(i): name for i, name in enumerate(cols)}
    ht_scores = (hl.import_table(tsv_path, no_header=True, impute=True)
                 .add_index()
                 .key_by('idx')
                 .rename(field_names))

    ht = mt.select_rows(__locus=locus_expr).rows().add_index()
    ht = ht.annotate(**ht_scores[ht.idx])
    ht = ht.key_by('__locus')
    score_fields = [field for field in ht_scores.row
                    if field not in ht_scores.key]
    ht = ht.select(*score_fields)
    return ht.rename({'__locus': 'locus'})
Example #7
0
def ld_score(entry_expr, annotation_exprs, position_expr,
             window_size) -> Table:
    """Calculate LD scores.

    Example
    -------

    >>> # Load genetic data into MatrixTable
    >>> mt = hl.import_plink(bed='data/ldsc.bed',
    ...                      bim='data/ldsc.bim',
    ...                      fam='data/ldsc.fam')

    >>> # Create locus-keyed Table with numeric variant annotations
    >>> ht = hl.import_table('data/ldsc.annot',
    ...                      types={'BP': hl.tint,
    ...                             'binary': hl.tfloat,
    ...                             'continuous': hl.tfloat})
    >>> ht = ht.annotate(locus=hl.locus(ht.CHR, ht.BP))
    >>> ht = ht.key_by('locus')

    >>> # Annotate MatrixTable with external annotations
    >>> mt = mt.annotate_rows(univariate_annotation=1,
    ...                       binary_annotation=ht[mt.locus].binary,
    ...                       continuous_annotation=ht[mt.locus].continuous)

    >>> # Annotate MatrixTable with alt allele count stats
    >>> mt = mt.annotate_rows(stats=hl.agg.stats(mt.GT.n_alt_alleles()))

    >>> # Create standardized genotype entry
    >>> mt = mt.annotate_entries(GT_std=hl.or_else(
    ...     (mt.GT.n_alt_alleles() - mt.stats.mean)/mt.stats.stdev, 0.0))

    >>> # Calculate LD scores using standardized genotypes
    >>> ht_scores = hl.experimental.ld_score(entry_expr=mt.GT_std,
    ...                                      annotation_exprs=[
    ...                                         mt.univariate_annotation,
    ...                                         mt.binary_annotation,
    ...                                         mt.continuous_annotation],
    ...                                      position_expr=mt.cm_position,
    ...                                      window_size=1)

    Warning
    -------
        :func:`.ld_score` will fail if ``entry_expr`` results in any missing
        values. The special float value ``nan`` is not considered a
        missing value.

    **Further reading**

    For more in-depth discussion of LD scores, see:

    - `LD Score regression distinguishes confounding from polygenicity in genome-wide association studies (Bulik-Sullivan et al, 2015) <https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4495769/>`__
    - `Partitioning heritability by functional annotation using genome-wide association summary statistics (Finucane et al, 2015) <https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4626285/>`__

    Notes
    -----

    The source MatrixTable is taken from `entry_expr`. Window boundaries are
    computed per contig from ``mt.locus.contig`` and `position_expr`, in row
    order.

    Parameters
    ----------
    entry_expr : :class:`.NumericExpression`
        Expression for entries of genotype matrix
        (e.g. ``mt.GT.n_alt_alleles()``).
    annotation_exprs : :class:`.NumericExpression` or
                       :obj:`list` of :class:`.NumericExpression`
        Annotation expression(s) to partition LD scores.
    position_expr : :class:`.NumericExpression`
        Expression for position of variant
        (e.g. ``mt.cm_position`` or ``mt.locus.position``).
    window_size : :obj:`int` or :obj:`float`
        Size of variant window used to calculate LD scores,
        in units of ``position``. Must be non-negative.

    Returns
    -------
    :class:`.Table`
        Locus-keyed table with LD scores for each variant and annotation."""

    assert window_size >= 0

    mt = entry_expr._indices.source
    annotations = wrap_to_list(annotation_exprs)
    variant_key = [x for x in mt.row_key]

    ht_annotations = mt.select_rows(*annotations).rows()
    annotation_names = [x for x in ht_annotations.row if x not in variant_key]

    # Reshape annotations to long form (one row per variant/annotation pair)
    # and pivot into a variants-by-annotations matrix.
    ht_annotations = hl.Table.union(*[(ht_annotations.annotate(
        annotation=hl.str(x), value=hl.float(ht_annotations[x])).select(
            'annotation', 'value')) for x in annotation_names])
    mt_annotations = ht_annotations.to_matrix_table(row_key=variant_key,
                                                    col_key=['annotation'])

    cols = mt_annotations['annotation'].collect()
    col_idxs = dict(enumerate(cols))

    G = BlockMatrix.from_entry_expr(entry_expr)
    A = BlockMatrix.from_entry_expr(mt_annotations.value)

    n = G.n_cols

    R2 = ((G @ G.T) / n)**2
    R2_adj = R2 - (1.0 - R2) / (n - 2.0)

    # Collect (contig, position) per variant; the position travels through a
    # string so that both values fit into one array, then is parsed back.
    positions = [(x[0], float(x[1]))
                 for x in hl.array([mt.locus.contig,
                                    hl.str(position_expr)]).collect()]
    n_positions = len(positions)

    starts = np.zeros(n_positions, dtype='int')
    stops = np.zeros(n_positions, dtype='int')

    # ``None`` can never equal a real contig name, so the first variant
    # always initialises the cursors. The former ``'0'`` sentinel left ``j``
    # and ``k`` unbound when the first contig was literally '0'.
    contig = None
    j = k = 0
    for i, (c, p) in enumerate(positions):
        if c != contig:
            # New contig: restart both window cursors at this variant.
            j = i
            k = i
            contig = c

        min_val = p - window_size
        max_val = p + window_size

        # Advance the start cursor past variants left of the window.
        while j < n_positions and positions[j][1] < min_val:
            j += 1

        starts[i] = j

        # Advance the (exclusive) stop cursor while still on this contig and
        # inside the window.
        while (k < n_positions
               and positions[k][0] == contig
               and positions[k][1] <= max_val):
            k += 1

        stops[i] = k

    R2_adj_sparse = R2_adj.sparsify_row_intervals([int(x) for x in starts],
                                                  [int(x) for x in stops])
    L2 = R2_adj_sparse @ A

    tmp_bm_path = new_temp_file()
    tmp_tsv_path = new_temp_file()

    L2.write(tmp_bm_path, force_row_major=True)
    BlockMatrix.export(tmp_bm_path, tmp_tsv_path)

    # Imported columns are named f0, f1, ...; map them back to the
    # annotation names in column order.
    ht_scores = hl.import_table(tmp_tsv_path, no_header=True, impute=True)
    ht_scores = ht_scores.add_index()
    ht_scores = ht_scores.key_by('idx')
    ht_scores = ht_scores.rename(
        {'f{:}'.format(i): name for i, name in col_idxs.items()})

    # Keep only the row key fields, index them, and join scores by row index.
    ht_variants = mt.rows()
    ht_variants = ht_variants.drop(
        *[x for x in ht_variants.row if x not in variant_key])
    ht_variants = ht_variants.add_index()
    ht_variants = ht_variants.key_by('idx')

    ht_scores = ht_variants.join(ht_scores, how='inner')
    ht_scores = ht_scores.key_by('locus')
    ht_scores = ht_scores.drop('alleles', 'idx')

    return ht_scores