Example #1
0
def compute_and_annotate_ld_score(ht, r2_adj, radius, out_name, overwrite):
    """Compute LD scores from an r2 BlockMatrix and write them alongside ``ht``.

    The matrix is restricted to per-locus windows, its row and column sums are
    added together (plus 1), and the resulting vector is round-tripped through
    a temporary TSV so it can be joined back onto ``ht``.

    Args:
        ht: Hail Table providing a ``locus`` field (used to build the windows)
            and a ``new_idx`` field (used as the join key against the score
            index; presumably the row's position in ``r2_adj`` -- confirm
            with the caller).
        r2_adj: BlockMatrix of (adjusted) r2 values.
        radius: Window radius handed to ``hl.linalg.utils.locus_windows``.
        out_name: Path the annotated table is written to.
        overwrite: Passed through to ``Table.write``.

    Returns:
        None. Rows without a defined ``ld_score`` are dropped before writing,
        and all globals are removed from the written table.
    """
    windows = hl.linalg.utils.locus_windows(ht.locus, radius, _localize=False)

    # Lifted directly from https://github.com/hail-is/hail/blob/555e02d6c792263db2c3ed97db8002b489e2dacb/hail/python/hail/methods/statgen.py#L2595
    # for the time being, until efficient BlockMatrix filtering gets an easier interface
    # This is required, as the squaring/multiplication densifies, so this re-sparsifies.
    java_bm = r2_adj._jbm
    windows_ir = Env.backend()._to_java_ir(windows._ir)
    windowed_r2 = BlockMatrix._from_java(
        java_bm.filterRowIntervalsIR(windows_ir, False))

    # Row sums (transposed) plus column sums; the constant 1 is added on top
    # (presumably the variant's r2 with itself -- confirm).
    scores_bm = windowed_r2.sum(axis=0).T + windowed_r2.sum(axis=1) + 1

    bm_path = new_temp_file()
    tsv_path = new_temp_file()
    scores_bm.write(bm_path, force_row_major=True)
    BlockMatrix.export(bm_path, tsv_path)

    # The TSV has no header, so the single column is imported as 'f0';
    # add_index supplies the 'idx' field used as the join key.
    scores_ht = (
        hl.import_table(tsv_path, no_header=True, impute=True)
        .add_index()
        .rename({'f0': 'ld_score'})
        .key_by('idx')
    )

    annotated = ht.annotate(**scores_ht[ht.new_idx]).select_globals()
    scored_only = annotated.filter(hl.is_defined(annotated.ld_score))
    scored_only.write(out_name, overwrite)
Example #2
0
def generate_ld_scores_from_ld_matrix(pop_data,
                                      data_type,
                                      min_frequency=0.01,
                                      call_rate_cutoff=0.8,
                                      adj: bool = False,
                                      radius: int = 1000000,
                                      overwrite=False,
                                      skip_pops=('nfe', 'fin', 'asj')):
    """Compute and write LD scores for every population in ``pop_data``.

    For each population the LD index table is filtered by allele frequency
    and call rate, the matching rows/columns of the stored LD matrix are
    squared and adjusted using ``n``, and the result is handed to
    ``compute_and_annotate_ld_score`` which writes the scores table.

    Args:
        pop_data: Anything accepted by ``dict()`` whose values are mappings of
            population name -> ``n``. ``n`` is used in the call-rate filter
            (``AN / n >= 2 * call_rate_cutoff``) and in the r2 adjustment
            (presumably the number of samples in the population -- confirm).
            ``n == 2`` raises ``ZeroDivisionError`` in the adjustment.
        data_type: Forwarded to the ``ld_index_path`` / ``ld_matrix_path`` /
            ``ld_scores_path`` helpers.
        min_frequency: Variants are kept when
            ``min_frequency <= AF <= 1 - min_frequency``.
        call_rate_cutoff: Minimum value of ``AN / (2 * n)`` for a variant.
        adj: Forwarded to the path helpers.
        radius: Window radius used when computing the scores.
        overwrite: Whether existing output may be overwritten.
        skip_pops: Population names to skip. Defaults to the populations that
            were previously hard-coded here; pass ``()`` or ``None`` to
            process every population.

    Returns:
        None. One table per processed population is written to
        ``ld_scores_path(data_type, pop, adj)``.
    """
    # This function required a decent number of high-mem machines (with an SSD for good measure) to complete the AFR
    # For the rest, on 20 n1-standard-8's, 1h15m to export block matrix, 15 mins to compute LD scores per population (~$150 total)
    skipped = frozenset(skip_pops or ())

    # Only the inner {pop: n} mappings are needed; the outer keys are unused.
    for pops in dict(pop_data).values():
        for pop, n in pops.items():
            if pop in skipped:
                continue

            ht = hl.read_table(ld_index_path(data_type, pop, adj=adj))
            ht = ht.filter((ht.pop_freq.AF >= min_frequency)
                           & (ht.pop_freq.AF <= 1 - min_frequency)
                           & (ht.pop_freq.AN / n >= 2 *
                              call_rate_cutoff)).add_index(name='new_idx')

            # 'idx' holds each kept variant's position in the stored matrix;
            # 'new_idx' (added above) is its position after filtering.
            indices = ht.idx.collect()

            r2 = BlockMatrix.read(
                ld_matrix_path(data_type,
                               pop,
                               min_frequency >= COMMON_FREQ,
                               adj=adj))
            r2 = r2.filter(indices, indices)**2
            # Adjust r2 for the size n of the population.
            r2_adj = ((n - 1.0) / (n - 2.0)) * r2 - (1.0 / (n - 2.0))

            # Windowing, summation, export and annotation are shared with
            # compute_and_annotate_ld_score rather than duplicated here.
            compute_and_annotate_ld_score(ht, r2_adj, radius,
                                          ld_scores_path(data_type, pop, adj),
                                          overwrite)