Example #1
def download_data():
    global _data_dir, _mt
    _data_dir = os.environ.get('HAIL_BENCHMARK_DIR',
                               '/tmp/hail_benchmark_data')
    print(f'using benchmark data directory {_data_dir}')
    os.makedirs(_data_dir, exist_ok=True)

    files = map(lambda f: os.path.join(_data_dir, f), [
        'profile.vcf.bgz', 'profile.mt', 'table_10M_par_1000.ht',
        'table_10M_par_100.ht', 'table_10M_par_10.ht',
        'gnomad_dp_simulation.mt', 'many_strings_table.ht'
    ])
    if not all(os.path.exists(file) for file in files):
        hl.init()  # use all cores

        vcf = os.path.join(_data_dir, 'profile.vcf.bgz')
        print('files not found - downloading...', end='', flush=True)
        urlretrieve(
            'https://storage.googleapis.com/hail-common/benchmark/profile.vcf.bgz',
            vcf)
        print('done', flush=True)
        print('importing...', end='', flush=True)
        hl.import_vcf(vcf, min_partitions=16).write(
            os.path.join(_data_dir, 'profile.mt'), overwrite=True)

        ht = hl.utils.range_table(
            10_000_000,
            1000).annotate(**{f'f_{i}': hl.rand_unif(0, 1)
                              for i in range(5)})
        ht = ht.checkpoint(os.path.join(_data_dir, 'table_10M_par_1000.ht'),
                           overwrite=True)
        ht = ht.naive_coalesce(100).checkpoint(
            os.path.join(_data_dir, 'table_10M_par_100.ht'), overwrite=True)
        ht.naive_coalesce(10).write(
            os.path.join(_data_dir, 'table_10M_par_10.ht'), overwrite=True)

        mt = hl.utils.range_matrix_table(n_rows=250_000,
                                         n_cols=1_000,
                                         n_partitions=32)
        mt = mt.annotate_entries(x=hl.int(hl.rand_unif(0, 4.5)**3))
        mt.write(os.path.join(_data_dir, 'gnomad_dp_simulation.mt'),
                 overwrite=True)

        print('downloading many strings table...')
        mst_tsv = os.path.join(_data_dir, 'many_strings_table.tsv.bgz')
        mst_ht = os.path.join(_data_dir, 'many_strings_table.ht')
        urlretrieve(
            'https://storage.googleapis.com/hail-common/benchmark/many_strings_table.tsv.bgz',
            mst_tsv)
        print('importing...')
        hl.import_table(mst_tsv).write(mst_ht, overwrite=True)
        hl.stop()
    else:
        print('all files found.', flush=True)
Example #2
def linear_regression_rows_nd(mt_path):
    mt = hl.read_matrix_table(mt_path)
    num_phenos = 100
    num_covs = 20
    pheno_dict = {f"pheno_{i}": hl.rand_unif(0, 1) for i in range(num_phenos)}
    cov_dict = {f"cov_{i}": hl.rand_unif(0, 1) for i in range(num_covs)}
    mt = mt.annotate_cols(**pheno_dict)
    mt = mt.annotate_cols(**cov_dict)
    res = hl._linear_regression_rows_nd(
        y=[mt[key] for key in pheno_dict.keys()],
        x=mt.x,
        covariates=[mt[key] for key in cov_dict.keys()])
    res._force_count()
Example #3
def generate_random_gen():
    mt = hl.utils.range_matrix_table(30, 10)
    mt = (mt.annotate_rows(locus=hl.locus('20', mt.row_idx + 1),
                           alleles=['A', 'G']).key_rows_by('locus', 'alleles'))
    mt = (mt.annotate_cols(s=hl.str(mt.col_idx)).key_cols_by('s'))
    # using totally random values leads to rounding differences where
    # identical GEN values get rounded differently, leading to
    # differences in the GT call between import_{gen, bgen}
    mt = mt.annotate_entries(a=hl.int32(hl.rand_unif(0.0, 255.0)))
    mt = mt.annotate_entries(b=hl.int32(hl.rand_unif(0.0, 255.0 - mt.a)))
    mt = mt.transmute_entries(GP=hl.array([mt.a, mt.b, 255.0 - mt.a - mt.b]) /
                              255.0)
    # 20% missing
    mt = mt.filter_entries(hl.rand_bool(0.8))
    hl.export_gen(mt, 'random', precision=4)
Example #4
def generate_random_gen():
    mt = hl.utils.range_matrix_table(30, 10)
    mt = (mt.annotate_rows(locus = hl.locus('20', mt.row_idx + 1),
                           alleles = ['A', 'G'])
          .key_rows_by('locus', 'alleles'))
    mt = (mt.annotate_cols(s = hl.str(mt.col_idx))
          .key_cols_by('s'))
    # using totally random values leads to rounding differences where
    # identical GEN values get rounded differently, leading to
    # differences in the GT call between import_{gen, bgen}
    mt = mt.annotate_entries(a = hl.int32(hl.rand_unif(0.0, 255.0)))
    mt = mt.annotate_entries(b = hl.int32(hl.rand_unif(0.0, 255.0 - mt.a)))
    mt = mt.transmute_entries(GP = hl.array([mt.a, mt.b, 255.0 - mt.a - mt.b]) / 255.0)
    # 20% missing
    mt = mt.filter_entries(hl.rand_bool(0.8))
    hl.export_gen(mt, 'random', precision=4)
Example #5
def pc_relate_5k_5k(mt_path):
    mt = hl.read_matrix_table(mt_path)
    mt = mt.annotate_cols(scores=hl.range(2).map(lambda x: hl.rand_unif(0, 1)))
    rel = hl.pc_relate(mt.GT,
                       0.05,
                       scores_expr=mt.scores,
                       statistics='kin',
                       min_kinship=0.05)
    rel._force_count()
Example #6
    def sample_ordering_expr(mt):
        """It can be problematic for downstream steps when several samples have many times more variants selected
        than in other samples. To avoid this, and distribute variants more evenly across samples,
        add a random number as the secondary sort order. This way, when many samples have an identically high GQ
        (as often happens for common variants), the same few samples don't get selected repeatedly for all common
        variants.
        """

        return -mt.GQ, hl.rand_unif(0, 1, seed=1)
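
A hedged usage sketch (an assumption, not part of the example above): the tuple returned by sample_ordering_expr can be passed as the `ordering` of hl.agg.take so that, per variant, the highest-GQ samples are taken first and GQ ties are broken by the seeded random draw. `mt` is assumed to be a MatrixTable with a GQ entry field and a sample ID field `s`.

# take up to 3 sample IDs per variant, ordered by (-GQ, seeded random tie-breaker);
# passing the tuple as `ordering` is an assumption about hl.agg.take here
mt = mt.annotate_rows(
    selected_samples=hl.agg.take(mt.s, 3, ordering=sample_ordering_expr(mt)))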
Example #7
 def _create(self, resource_dir):
     logging.info('creating gnomad_dp_simulation matrix table...')
     mt = hl.utils.range_matrix_table(n_rows=250_000,
                                      n_cols=1_000,
                                      n_partitions=32)
     mt = mt.annotate_entries(x=hl.int(hl.rand_unif(0, 4.5)**3))
     mt.write(os.path.join(resource_dir, 'gnomad_dp_simulation.mt'),
              overwrite=True)
     logging.info('done creating gnomad_dp_simulation matrix table.')
Example #8
def pc_relate_big():
    mt = hl.balding_nichols_model(3, 2 * 4096, 2 * 4096).checkpoint(
        hl.utils.new_temp_file(extension='mt'))
    mt = mt.annotate_cols(scores=hl.range(2).map(lambda x: hl.rand_unif(0, 1)))
    rel = hl.pc_relate(mt.GT,
                       0.05,
                       scores_expr=mt.scores,
                       statistics='kin',
                       min_kinship=0.05)
    rel._force_count()
Example #9
    def test_mt_full_outer_join(self):
        mt1 = hl.utils.range_matrix_table(10, 10)
        mt1 = mt1.annotate_cols(c1=hl.rand_unif(0, 1))
        mt1 = mt1.annotate_rows(r1=hl.rand_unif(0, 1))
        mt1 = mt1.annotate_entries(e1=hl.rand_unif(0, 1))

        mt2 = hl.utils.range_matrix_table(10, 10)
        mt2 = mt2.annotate_cols(c1=hl.rand_unif(0, 1))
        mt2 = mt2.annotate_rows(r1=hl.rand_unif(0, 1))
        mt2 = mt2.annotate_entries(e1=hl.rand_unif(0, 1))

        mtj = hl.experimental.full_outer_join_mt(mt1, mt2)
        assert (mtj.aggregate_entries(
            hl.agg.all(
                mtj.left_entry == mt1.index_entries(mtj.row_key, mtj.col_key)))
                )
        assert (mtj.aggregate_entries(
            hl.agg.all(mtj.right_entry == mt2.index_entries(
                mtj.row_key, mtj.col_key))))

        mt2 = mt2.key_cols_by(new_col_key=5 -
                              (mt2.col_idx // 2))  # duplicate col keys
        mt1 = mt1.key_rows_by(new_row_key=5 -
                              (mt1.row_idx // 2))  # duplicate row keys
        mtj = hl.experimental.full_outer_join_mt(mt1, mt2)

        assert (mtj.count() == (15, 15))
Example #10
    def test(self):
        schema = hl.tstruct(a=hl.tint32, b=hl.tint32, c=hl.tint32, d=hl.tint32, e=hl.tstr,
                            f=hl.tarray(hl.tint32),
                            g=hl.tarray(
                                hl.tstruct(x=hl.tint32, y=hl.tint32, z=hl.tstr)),
                            h=hl.tstruct(a=hl.tint32, b=hl.tint32, c=hl.tstr),
                            i=hl.tbool,
                            j=hl.tstruct(x=hl.tint32, y=hl.tint32, z=hl.tstr))

        rows = [{'a': 4, 'b': 1, 'c': 3, 'd': 5,
                 'e': "hello", 'f': [1, 2, 3],
                 'g': [hl.Struct(x=1, y=5, z='banana')],
                 'h': hl.Struct(a=5, b=3, c='winter'),
                 'i': True,
                 'j': hl.Struct(x=3, y=2, z='summer')}]

        kt = hl.Table.parallelize(rows, schema)

        result = convert_struct_to_dict(kt.annotate(
            chisq=hl.chisq(kt.a, kt.b, kt.c, kt.d),
            ctt=hl.ctt(kt.a, kt.b, kt.c, kt.d, 5),
            dict=hl.dict(hl.zip([kt.a, kt.b], [kt.c, kt.d])),
            dpois=hl.dpois(4, kt.a),
            drop=kt.h.drop('b', 'c'),
            exp=hl.exp(kt.c),
            fet=hl.fisher_exact_test(kt.a, kt.b, kt.c, kt.d),
            hwe=hl.hardy_weinberg_p(1, 2, 1),
            index=hl.index(kt.g, 'z'),
            is_defined=hl.is_defined(kt.i),
            is_missing=hl.is_missing(kt.i),
            is_nan=hl.is_nan(hl.float64(kt.a)),
            json=hl.json(kt.g),
            log=hl.log(kt.a, kt.b),
            log10=hl.log10(kt.c),
            or_else=hl.or_else(kt.a, 5),
            or_missing=hl.or_missing(kt.i, kt.j),
            pchisqtail=hl.pchisqtail(kt.a, kt.b),
            pcoin=hl.rand_bool(0.5),
            pnorm=hl.pnorm(0.2),
            pow=2.0 ** kt.b,
            ppois=hl.ppois(kt.a, kt.b),
            qchisqtail=hl.qchisqtail(kt.a, kt.b),
            range=hl.range(0, 5, kt.b),
            rnorm=hl.rand_norm(0.0, kt.b),
            rpois=hl.rand_pois(kt.a),
            runif=hl.rand_unif(kt.b, kt.a),
            select=kt.h.select('c', 'b'),
            sqrt=hl.sqrt(kt.a),
            to_str=[hl.str(5), hl.str(kt.a), hl.str(kt.g)],
            where=hl.cond(kt.i, 5, 10)
        ).take(1)[0])
Example #11
    def _create(self, resource_dir):

        def compatible_checkpoint(obj, path):
            obj.write(path, overwrite=True)
            return hl.read_table(path)

        ht = hl.utils.range_table(10_000_000, 1000).annotate(**{f'f_{i}': hl.rand_unif(0, 1) for i in range(5)})
        logging.info('Writing 1000-partition table...')
        ht = compatible_checkpoint(ht, os.path.join(resource_dir, 'table_10M_par_1000.ht'))
        logging.info('Writing 100-partition table...')
        ht = compatible_checkpoint(ht.repartition(100, shuffle=False), os.path.join(resource_dir, 'table_10M_par_100.ht'))
        logging.info('Writing 10-partition table...')
        ht.repartition(10, shuffle=False).write(os.path.join(resource_dir, 'table_10M_par_10.ht'), overwrite=True)
        logging.info('done writing many-partitions tables.')
Example #12
 def _create(self, resource_dir):
     ht = hl.utils.range_table(
         10_000_000,
         1000).annotate(**{f'f_{i}': hl.rand_unif(0, 1)
                           for i in range(5)})
     logging.info('Writing 1000-partition table...')
     ht = ht.checkpoint(os.path.join(resource_dir, 'table_10M_par_1000.ht'),
                        overwrite=True)
     logging.info('Writing 100-partition table...')
     ht = ht.naive_coalesce(100).checkpoint(
         os.path.join(resource_dir, 'table_10M_par_100.ht'), overwrite=True)
     logging.info('Writing 10-partition table...')
     ht.naive_coalesce(10).write(
         os.path.join(resource_dir, 'table_10M_par_10.ht'), overwrite=True)
     logging.info('done writing many-partitions tables.')
Example #13
def filter_ht_for_plink(ht: hl.Table,
                        n_samples: int,
                        min_call_rate: float = 0.95,
                        variants_per_mac_category: int = 2000,
                        variants_per_maf_category: int = 10000):
    from gnomad.utils.filtering import filter_to_autosomes
    ht = filter_to_autosomes(ht)
    ht = ht.filter((ht.call_stats.AN >= n_samples * 2 * min_call_rate)
                   & (ht.call_stats.AC > 0))
    ht = ht.annotate(mac_category=mac_category_case_builder(ht.call_stats))
    category_counter = ht.aggregate(hl.agg.counter(ht.mac_category))
    print(category_counter)
    ht = ht.annotate_globals(category_counter=category_counter)
    return ht.filter(
        hl.rand_unif(0, 1) <
        hl.cond(ht.mac_category >= 1, variants_per_mac_category,
                variants_per_maf_category) /
        ht.category_counter[ht.mac_category])
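
The final filter keeps each variant with probability target / category_count, which downsamples every MAC/MAF category to roughly the requested number of variants. A minimal sketch of the same Bernoulli downsampling pattern on a plain table (the range table and the target of 500 rows are assumptions, not from the source):

ht = hl.utils.range_table(10_000)
target = 500
# keep each row independently with probability target / total, i.e. ~500 rows on average
ht = ht.filter(hl.rand_unif(0, 1) < target / 10_000)
print(ht.count())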
Example #14
def prepare_mt_for_plink(mt: hl.MatrixTable,
                         n_samples: int,
                         min_call_rate: float = 0.95,
                         variants_per_mac_category: int = 2000,
                         variants_per_maf_category: int = 10000):
    from gnomad.utils.filtering import filter_to_autosomes
    mt = filter_to_autosomes(mt)
    mt = mt.filter_rows((mt.call_stats.AN >= n_samples * 2 * min_call_rate)
                        & (mt.call_stats.AC[1] > 0))
    mt = mt.annotate_rows(
        mac_category=mac_category_case_builder(mt.call_stats))
    category_counter = mt.aggregate_rows(hl.agg.counter(mt.mac_category))
    print(category_counter)
    mt = mt.annotate_globals(category_counter=category_counter)
    return mt.filter_rows(
        hl.rand_unif(0, 1) <
        hl.cond(mt.mac_category >= 1, variants_per_mac_category,
                variants_per_maf_category) /
        mt.category_counter[mt.mac_category])
Example #15
def logistic_regression_rows_wald_nd(mt_path):
    mt = hl.read_matrix_table(mt_path)
    mt = mt.head(2000)
    num_phenos = 5
    num_covs = 2
    pheno_dict = {
        f"pheno_{i}": hl.rand_bool(.5, seed=i)
        for i in range(num_phenos)
    }
    cov_dict = {
        f"cov_{i}": hl.rand_unif(0, 1, seed=i)
        for i in range(num_covs)
    }
    mt = mt.annotate_cols(**pheno_dict)
    mt = mt.annotate_cols(**cov_dict)
    res = hl._logistic_regression_rows_nd(
        test='wald',
        y=[mt[key] for key in pheno_dict.keys()],
        x=mt.x,
        covariates=[mt[key] for key in cov_dict.keys()])
    res._force_count()
Example #16
    def test_mt_full_outer_join(self):
        mt1 = hl.utils.range_matrix_table(10, 10)
        mt1 = mt1.annotate_cols(c1=hl.rand_unif(0, 1))
        mt1 = mt1.annotate_rows(r1=hl.rand_unif(0, 1))
        mt1 = mt1.annotate_entries(e1=hl.rand_unif(0, 1))

        mt2 = hl.utils.range_matrix_table(10, 10)
        mt2 = mt2.annotate_cols(c1=hl.rand_unif(0, 1))
        mt2 = mt2.annotate_rows(r1=hl.rand_unif(0, 1))
        mt2 = mt2.annotate_entries(e1=hl.rand_unif(0, 1))

        mtj = hl.experimental.full_outer_join_mt(mt1, mt2)
        assert(mtj.aggregate_entries(hl.agg.all(mtj.left_entry == mt1.index_entries(mtj.row_key, mtj.col_key))))
        assert(mtj.aggregate_entries(hl.agg.all(mtj.right_entry == mt2.index_entries(mtj.row_key, mtj.col_key))))

        mt2 = mt2.key_cols_by(new_col_key = 5 - (mt2.col_idx // 2)) # duplicate col keys
        mt1 = mt1.key_rows_by(new_row_key = 5 - (mt1.row_idx // 2)) # duplicate row keys
        mtj = hl.experimental.full_outer_join_mt(mt1, mt2)

        assert(mtj.count() == (15, 15))
Example #17
def download_data(data_dir):
    global _data_dir, _mt
    _data_dir = data_dir or os.environ.get(
        'HAIL_BENCHMARK_DIR') or '/tmp/hail_benchmark_data'
    logging.info(f'using benchmark data directory {_data_dir}')
    os.makedirs(_data_dir, exist_ok=True)

    files = map(lambda f: os.path.join(_data_dir, f), [
        'profile.vcf.bgz', 'profile.mt', 'table_10M_par_1000.ht',
        'table_10M_par_100.ht', 'table_10M_par_10.ht',
        'gnomad_dp_simulation.mt', 'many_strings_table.ht',
        'many_ints_table.ht', 'sim_ukb.bgen'
    ])
    if not all(os.path.exists(file) for file in files):
        hl.init()  # use all cores

        vcf = os.path.join(_data_dir, 'profile.vcf.bgz')
        logging.info('downloading profile.vcf.bgz...')
        urlretrieve(
            'https://storage.googleapis.com/hail-common/benchmark/profile.vcf.bgz',
            vcf)
        logging.info('done downloading profile.vcf.bgz.')
        logging.info('importing profile.vcf.bgz...')
        hl.import_vcf(vcf, min_partitions=16).write(
            os.path.join(_data_dir, 'profile.mt'), overwrite=True)
        logging.info('done importing profile.vcf.bgz.')

        logging.info('writing 10M row partitioned tables...')

        ht = hl.utils.range_table(
            10_000_000,
            1000).annotate(**{f'f_{i}': hl.rand_unif(0, 1)
                              for i in range(5)})
        ht = ht.checkpoint(os.path.join(_data_dir, 'table_10M_par_1000.ht'),
                           overwrite=True)
        ht = ht.naive_coalesce(100).checkpoint(
            os.path.join(_data_dir, 'table_10M_par_100.ht'), overwrite=True)
        ht.naive_coalesce(10).write(
            os.path.join(_data_dir, 'table_10M_par_10.ht'), overwrite=True)
        logging.info('done writing 10M row partitioned tables.')

        logging.info('creating gnomad_dp_simulation matrix table...')
        mt = hl.utils.range_matrix_table(n_rows=250_000,
                                         n_cols=1_000,
                                         n_partitions=32)
        mt = mt.annotate_entries(x=hl.int(hl.rand_unif(0, 4.5)**3))
        mt.write(os.path.join(_data_dir, 'gnomad_dp_simulation.mt'),
                 overwrite=True)
        logging.info('done creating gnomad_dp_simulation matrix table.')

        logging.info('downloading many_strings_table.tsv.bgz...')
        mst_tsv = os.path.join(_data_dir, 'many_strings_table.tsv.bgz')
        mst_ht = os.path.join(_data_dir, 'many_strings_table.ht')
        urlretrieve(
            'https://storage.googleapis.com/hail-common/benchmark/many_strings_table.tsv.bgz',
            mst_tsv)
        logging.info('done downloading many_strings_table.tsv.bgz.')
        logging.info('importing many_strings_table.tsv.bgz...')
        hl.import_table(mst_tsv).write(mst_ht, overwrite=True)
        logging.info('done importing many_strings_table.tsv.bgz.')

        logging.info('downloading many_ints_table.tsv.bgz...')
        mit_tsv = os.path.join(_data_dir, 'many_ints_table.tsv.bgz')
        mit_ht = os.path.join(_data_dir, 'many_ints_table.ht')
        urlretrieve(
            'https://storage.googleapis.com/hail-common/benchmark/many_ints_table.tsv.bgz',
            mit_tsv)
        logging.info('done downloading many_ints_table.tsv.bgz.')
        logging.info('importing many_ints_table.tsv.bgz...')
        hl.import_table(mit_tsv,
                        types={
                            'idx': 'int',
                            **{f'i{i}': 'int'
                               for i in range(5)},
                            **{f'array{i}': 'array<int>'
                               for i in range(2)}
                        }).write(mit_ht, overwrite=True)
        logging.info('done importing many_ints_table.tsv.bgz.')

        bgen = 'sim_ukb.bgen'
        sample = 'sim_ukb.sample'
        logging.info(f'downloading {bgen}...')
        local_bgen = os.path.join(_data_dir, bgen)
        local_sample = os.path.join(_data_dir, sample)
        urlretrieve(
            f'https://storage.googleapis.com/hail-common/benchmark/{bgen}',
            local_bgen)
        urlretrieve(
            f'https://storage.googleapis.com/hail-common/benchmark/{sample}',
            local_sample)
        logging.info(f'done downloading {bgen}...')
        logging.info(f'indexing {bgen}...')
        hl.index_bgen(local_bgen)
        logging.info(f'done indexing {bgen}.')

        hl.stop()
    else:
        logging.info('all files found.')
Example #18
    def _test_linear_mixed_model_low_rank(self):
        seed = 0
        n_populations = 8
        fst = n_populations * [.9]
        n_samples = 500
        n_variants = 200
        n_orig_markers = 100
        n_culprits = 10
        n_covariates = 3
        sigma_sq = 1
        tau_sq = 1

        from numpy.random import RandomState
        prng = RandomState(seed)

        x = np.hstack((np.ones(shape=(n_samples, 1)),
                       prng.normal(size=(n_samples, n_covariates - 1))))

        mt = hl.balding_nichols_model(n_populations=n_populations,
                                      n_samples=n_samples,
                                      n_variants=n_variants,
                                      fst=fst,
                                      af_dist=hl.rand_unif(0.1, 0.9, seed=seed),
                                      seed=seed)

        pa_t_path = utils.new_temp_file(suffix='bm')
        a_t_path = utils.new_temp_file(suffix='bm')

        BlockMatrix.write_from_entry_expr(mt.GT.n_alt_alleles(), a_t_path)

        a = BlockMatrix.read(a_t_path).T.to_numpy()
        g = a[:, -n_orig_markers:]
        g_std = self._filter_and_standardize_cols(g)

        n_markers = g_std.shape[1]

        k = (g_std @ g_std.T) * n_samples / n_markers

        beta = np.arange(n_covariates)
        beta_stars = np.array([1] * n_culprits)

        y = prng.multivariate_normal(
            np.hstack((a[:, 0:n_culprits], x)) @ np.hstack((beta_stars, beta)),
            sigma_sq * k + tau_sq * np.eye(n_samples))

        # low rank computation of S, P
        l = g_std.T @ g_std
        sl, v = np.linalg.eigh(l)
        n_eigenvectors = int(np.sum(sl > 1e-10))
        sl = sl[-n_eigenvectors:]
        v = v[:, -n_eigenvectors:]
        s = sl * (n_samples / n_markers)
        p = (g_std @ (v / np.sqrt(sl))).T

        # compare with full rank S, P
        sk0, uk = np.linalg.eigh(k)
        sk = sk0[-n_eigenvectors:]
        pk = uk[:, -n_eigenvectors:].T
        assert np.allclose(sk, s)
        assert np.allclose(np.abs(pk), np.abs(p))

        # build and fit model
        py = p @ y
        px = p @ x
        pa = p @ a

        model = LinearMixedModel(py, px, s, y, x)
        assert model.n == n_samples
        assert model.f == n_covariates
        assert model.r == n_eigenvectors
        assert model.low_rank

        model.fit()

        # check effect sizes tend to be near 1 for first n_marker alternative models
        BlockMatrix.from_numpy(pa).T.write(pa_t_path, force_row_major=True)
        df_lmm = model.fit_alternatives(pa_t_path, a_t_path).to_pandas()

        assert 0.9 < np.mean(df_lmm['beta'][:n_culprits]) < 1.1

        # compare NumPy and Hail LMM per alternative
        df_numpy = model.fit_alternatives_numpy(pa, a).to_pandas()
        assert np.min(df_numpy['chi_sq']) > 0

        na_numpy = df_numpy.isna().any(axis=1)
        na_lmm = df_lmm.isna().any(axis=1)

        assert na_numpy.sum() <= 10
        assert na_lmm.sum() <= 10
        assert np.logical_xor(na_numpy, na_lmm).sum() <= 5

        mask = ~(na_numpy | na_lmm)

        lmm_vs_numpy_p_value = np.sort(np.abs(df_lmm['p_value'][mask] - df_numpy['p_value'][mask]))

        assert lmm_vs_numpy_p_value[10] < 1e-12  # 10 smallest p-value differences
        assert lmm_vs_numpy_p_value[-1] < 1e-8   # all p-value differences
Example #19
def annotate_freq(
    mt: hl.MatrixTable,
    sex_expr: Optional[hl.expr.StringExpression] = None,
    pop_expr: Optional[hl.expr.StringExpression] = None,
    subpop_expr: Optional[hl.expr.StringExpression] = None,
    additional_strata_expr: Optional[Dict[str,
                                          hl.expr.StringExpression]] = None,
    downsamplings: Optional[List[int]] = None,
) -> hl.MatrixTable:
    """
    Adds a row annotation `freq` to the input `mt` with stratified allele frequencies,
    and a global annotation `freq_meta` with metadata.

    .. note::

        Currently this only supports bi-allelic sites.
        The input `mt` needs to have the following entry fields:
        - GT: a CallExpression containing the genotype
        - adj: a BooleanExpression indicating whether the genotype is of high quality.
        All expression arguments need to be expressions on the input `mt`.

    .. rubric:: `freq` row annotation

    The `freq` row annotation is an Array of Struct, with each Struct containing the following fields:

        - AC: int32
        - AF: float64
        - AN: int32
        - homozygote_count: int32

    Each element of the array corresponds to a stratification of the data, and the metadata about these annotations is
    stored in the globals.

    .. rubric:: Global `freq_meta` metadata annotation

    The global annotation `freq_meta` is added to the input `mt`. It is a list of dicts.
    Each element of the list contains metadata on a frequency stratification and the index in the list corresponds
    to the index of that frequency stratification in the `freq` row annotation.

    .. rubric:: The `downsamplings` parameter

    If the `downsamplings` parameter is used, frequencies will be computed for all samples and by population
    (if `pop_expr` is specified) by downsampling the number of samples without replacement to each of the numbers specified in the
    `downsamplings` array, provided that there are enough samples in the dataset.
    In addition, if `pop_expr` is specified, a downsampling to each of the exact number of samples present in each population is added.
    Note that samples are randomly sampled only once, meaning that the lower downsamplings are subsets of the higher ones.

    :param mt: Input MatrixTable
    :param sex_expr: When specified, frequencies are stratified by sex. If `pop_expr` is also specified, then a pop/sex stratification is added.
    :param pop_expr: When specified, frequencies are stratified by population. If `sex_expr` is also specified, then a pop/sex stratification is added.
    :param subpop_expr: When specified, frequencies are stratified by sub-continental population. Note that `pop_expr` is required as well when using this option.
    :param additional_strata_expr: When specified, frequencies are stratified by the given additional strata found in the dict. This can e.g. be used to stratify by platform.
    :param downsamplings: When specified, frequencies are computed by downsampling the data to the number of samples given in the list. Note that if `pop_expr` is specified, downsamplings by population are also computed.
    :return: MatrixTable with `freq` annotation
    """

    if subpop_expr is not None and pop_expr is None:
        raise NotImplementedError(
            "annotate_freq requires pop_expr when using subpop_expr")

    if additional_strata_expr is None:
        additional_strata_expr = {}

    _freq_meta_expr = hl.struct(**additional_strata_expr)
    if sex_expr is not None:
        _freq_meta_expr = _freq_meta_expr.annotate(sex=sex_expr)
    if pop_expr is not None:
        _freq_meta_expr = _freq_meta_expr.annotate(pop=pop_expr)
    if subpop_expr is not None:
        _freq_meta_expr = _freq_meta_expr.annotate(subpop=subpop_expr)

    # Annotate cols with provided cuts
    mt = mt.annotate_cols(_freq_meta=_freq_meta_expr)

    # Get counters for sex, pop and subpop if set
    cut_dict = {
        cut: hl.agg.filter(hl.is_defined(mt._freq_meta[cut]),
                           hl.agg.counter(mt._freq_meta[cut]))
        for cut in mt._freq_meta if cut != "subpop"
    }
    if "subpop" in mt._freq_meta:
        cut_dict["subpop"] = hl.agg.filter(
            hl.is_defined(mt._freq_meta.pop)
            & hl.is_defined(mt._freq_meta.subpop),
            hl.agg.counter(
                hl.struct(subpop=mt._freq_meta.subpop, pop=mt._freq_meta.pop)),
        )

    cut_data = mt.aggregate_cols(hl.struct(**cut_dict))
    sample_group_filters = []

    # Create downsamplings if needed
    if downsamplings is not None:
        # Add exact pop size downsampling if pops were provided
        if cut_data.get("pop"):
            downsamplings = list(
                set(downsamplings + list(cut_data.get("pop").values()))
            )  # Add the pops values if not in yet
            downsamplings = sorted([
                x for x in downsamplings
                if x <= sum(cut_data.get("pop").values())
            ])
        logger.info(
            f"Found {len(downsamplings)} downsamplings: {downsamplings}")

        # Shuffle the samples, then create a global index for downsampling
        # And a pop-index if pops were provided
        downsampling_ht = mt.cols()
        downsampling_ht = downsampling_ht.annotate(r=hl.rand_unif(0, 1))
        downsampling_ht = downsampling_ht.order_by(downsampling_ht.r)
        scan_expr = {"global_idx": hl.scan.count()}
        if cut_data.get("pop"):
            scan_expr["pop_idx"] = hl.scan.counter(
                downsampling_ht._freq_meta.pop).get(
                    downsampling_ht._freq_meta.pop, 0)
        downsampling_ht = downsampling_ht.annotate(**scan_expr)
        downsampling_ht = downsampling_ht.key_by("s").select(*scan_expr)
        mt = mt.annotate_cols(downsampling=downsampling_ht[mt.s])
        mt = mt.annotate_globals(downsamplings=downsamplings)

        # Create downsampled sample groups
        sample_group_filters.extend([(
            {
                "downsampling": str(ds),
                "pop": "global"
            },
            mt.downsampling.global_idx < ds,
        ) for ds in downsamplings])
        if cut_data.get("pop"):
            sample_group_filters.extend([
                (
                    {
                        "downsampling": str(ds),
                        "pop": pop
                    },
                    (mt.downsampling.pop_idx < ds) &
                    (mt._freq_meta.pop == pop),
                ) for ds in downsamplings
                for pop, pop_count in cut_data.get("pop", {}).items()
                if ds <= pop_count
            ])

    # Add all desired strata, starting with the full set and ending with downsamplings (if any)
    sample_group_filters = (
        [({}, True)]
        + [({"pop": pop}, mt._freq_meta.pop == pop)
           for pop in cut_data.get("pop", {})]
        + [({"sex": sex}, mt._freq_meta.sex == sex)
           for sex in cut_data.get("sex", {})]
        + [({"pop": pop, "sex": sex},
            (mt._freq_meta.sex == sex) & (mt._freq_meta.pop == pop))
           for sex in cut_data.get("sex", {})
           for pop in cut_data.get("pop", {})]
        + [({"subpop": subpop.subpop, "pop": subpop.pop},
            (mt._freq_meta.pop == subpop.pop)
            & (mt._freq_meta.subpop == subpop.subpop))
           for subpop in cut_data.get("subpop", {})]
        + [({strata: str(s_value)}, mt._freq_meta[strata] == s_value)
           for strata in additional_strata_expr
           for s_value in cut_data.get(strata, {})]
        + sample_group_filters)

    # Annotate columns with group_membership
    mt = mt.annotate_cols(
        group_membership=[x[1] for x in sample_group_filters])

    # Create and annotate global expression with meta information
    freq_meta_expr = [
        dict(**sample_group[0], group="adj")
        for sample_group in sample_group_filters
    ]
    freq_meta_expr.insert(1, {"group": "raw"})
    mt = mt.annotate_globals(freq_meta=freq_meta_expr)

    # Create frequency expression array from the sample groups
    freq_expr = hl.agg.array_agg(
        lambda i: hl.agg.filter(mt.group_membership[i] & mt.adj,
                                hl.agg.call_stats(mt.GT, mt.alleles)),
        hl.range(len(sample_group_filters)),
    )

    # Insert raw as the second element of the array
    freq_expr = (freq_expr[:1].extend([hl.agg.call_stats(mt.GT, mt.alleles)
                                       ]).extend(freq_expr[1:]))

    # Select non-ref allele (assumes bi-allelic)
    freq_expr = freq_expr.map(lambda cs: cs.annotate(
        AC=cs.AC[1],
        AF=cs.AF[1],  # TODO This is NA in case AC and AN are 0 -- should we set it to 0?
        homozygote_count=cs.homozygote_count[1],
    ))

    # Return MT with freq row annotation
    return mt.annotate_rows(freq=freq_expr).drop("_freq_meta")
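
A hedged usage sketch (field names are assumptions, not from the source): calling annotate_freq with sex and population stratification plus downsamplings. The input mt is assumed to carry the GT and adj entry fields required by the docstring, plus sex_karyotype and pop column fields.

mt = annotate_freq(
    mt,
    sex_expr=mt.sex_karyotype,
    pop_expr=mt.pop,
    downsamplings=[1_000, 10_000],
)
# mt.freq is an array of call-stats structs; the global mt.freq_meta
# describes which stratification each array element corresponds to.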
Example #20
def main(args):
    betas = ['beta_01', 'beta_1', 'beta_10', 'beta_100']
    spike_slab = hl.import_table(
        'gs://armartin/mama/spike_slab/BBJ_UKB_hm3.chr22.cm.beta.txt',
        impute=True)
    spike_slab = spike_slab.key_by(**hl.parse_variant(spike_slab.v))
    if args.compute_true_phenotypes:
        # get the white british subset
        eur = hl.import_table(
            'gs://phenotype_31063/ukb31063.gwas_covariates.both_sexes.tsv'
        ).key_by('s')

        # read in imputed data, subset to chr22
        mt = hl.read_matrix_table(
            'gs://phenotype_31063/hail/imputed/ukb31063.dosage.autosomes.mt')
        mt = hl.filter_intervals(mt, [hl.parse_locus_interval('22')])

        # annotate and filter imputed data to all sites with causal effects
        mt = mt.annotate_rows(ss=spike_slab[mt.row_key])
        mt = mt.filter_rows(hl.is_defined(mt.ss))

        # compute true PRS (i.e. phenotypes)
        annot_expr = {i: hl.agg.sum(mt.ss[i] * mt.dosage) for i in betas}

        # write out phenos for white British unrelated subset
        mt = mt.annotate_cols(**annot_expr)
        mt = mt.filter_cols(hl.is_defined(eur[mt.s]))
        mt.cols().write(
            'gs://armartin/mama/spike_slab/BBJ_UKB_hm3.chr22.cm.beta.true_PRS.ht',
            stage_locally=True,
            overwrite=True)

    if args.run_gwas:
        # read back in PRS (now true phenotypes)
        phenos = hl.read_table(
            'gs://armartin/mama/spike_slab/BBJ_UKB_hm3.chr22.cm.beta.true_PRS.ht'
        ).key_by('s')
        phenos.show()
        covariates = hl.import_table(
            'gs://phenotype_31063/ukb31063.gwas_covariates.both_sexes.tsv',
            impute=True,
            types={
                's': hl.tstr
            }).key_by('s')
        full_mt = hl.read_matrix_table(
            'gs://phenotype_31063/hail/imputed/ukb31063.dosage.autosomes.mt')
        full_mt = full_mt.annotate_cols(**covariates[full_mt.s])
        full_mt = hl.filter_intervals(full_mt, [hl.parse_locus_interval('22')])

        # annotate and filter imputed data to all sites with causal effects
        full_mt = full_mt.annotate_rows(ss=spike_slab[full_mt.row_key])
        full_mt = full_mt.filter_rows(hl.is_defined(full_mt.ss))

        # subset to white British subset, get 10 sets of 10k and run a gwas for each of these w/ PCs as covs
        for i in range(10):
            subset_pheno = phenos.annotate(r=hl.rand_unif(0, 1))
            subset_pheno = subset_pheno.order_by(
                subset_pheno.r).add_index('global_idx').key_by('s')
            subset_pheno = subset_pheno.filter(subset_pheno.global_idx < 10000)
            mt = full_mt.annotate_cols(**subset_pheno[full_mt.s])
            mt = mt.annotate_rows(maf=hl.agg.mean(mt.dosage) / 2)
            result_ht = hl.linear_regression_rows(
                y=[mt[i] for i in betas],
                x=mt.dosage,
                covariates=[1] + [mt['PC' + str(i)] for i in range(1, 21)],
                pass_through=['rsid', 'maf'])

            subset_pheno.export(
                'gs://armartin/mama/spike_slab/UKB_hm3.chr22.cm.beta.true_PRS.gwas_inds_'
                + str(i) + '.tsv.gz')
            result_ht.write(
                'gs://armartin/mama/spike_slab/UKB_hm3.chr22.cm.beta.true_PRS.gwas_sumstat_'
                + str(i) + '.ht',
                overwrite=True)

    if args.write_gwas:
        for i in range(10):
            result_ht = hl.read_table(
                'gs://armartin/mama/spike_slab/UKB_hm3.chr22.cm.beta.true_PRS.gwas_sumstat_'
                + str(i) + '.ht')
            result_ht = result_ht.key_by()
            get_expr = {
                field + '_' + x: result_ht[field][i]
                for i, x in enumerate(betas)
                for field in ['beta', 'standard_error', 'p_value']
            }
            result_ht.select(chr=result_ht.locus.contig, pos=result_ht.locus.position, rsid=result_ht.rsid, ref=result_ht.alleles[0],
                             alt=result_ht.alleles[1], maf=result_ht.maf, n=result_ht.n, **get_expr)\
                .export('gs://armartin/mama/spike_slab/UKB_hm3.chr22.cm.beta.true_PRS.gwas_sumstat_' + str(i) + '.tsv.gz')
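
The loop above draws a fresh random subset on each iteration by annotating a uniform random number, ordering by it, and keeping the first 10,000 rows. A minimal sketch of that pattern on a generic table (the range table and field names are assumptions, not from the source):

ht = hl.utils.range_table(100_000)
ht = ht.annotate(r=hl.rand_unif(0, 1))
# order by the random draw and keep the first 10,000 rows: a uniform random subset
subset = ht.order_by(ht.r).add_index('rank')
subset = subset.filter(subset.rank < 10_000)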
Example #21
def compute_quantile_bin(
    ht: hl.Table,
    score_expr: hl.expr.NumericExpression,
    bin_expr: Dict[str, hl.expr.BooleanExpression] = {"bin": True},
    compute_snv_indel_separately: bool = True,
    n_bins: int = 100,
    k: int = 1000,
    desc: bool = True,
) -> hl.Table:
    """
    Returns a table with a bin for each row based on quantiles of `score_expr`.
    The bin is computed by dividing the `score_expr` into `n_bins` bins containing an equal number of elements.
    This is done based on quantiles computed with hl.agg.approx_quantiles. If a single value in `score_expr` spans more
    than one bin, the rows with this value are distributed randomly across the bins it spans.
    If `compute_snv_indel_separately` is True all items in `bin_expr` will be stratified by snv / indels for the bin
    calculation. Because SNV and indel rows are mutually exclusive, they are re-combined into a single annotation. For
    example if we have the following four variants and scores and `n_bins` of 2:
    ========   =======   ======   =================   =================
    Variant    Type      Score    bin - `compute_snv_indel_separately`:
    --------   -------   ------   -------------------------------------
    \          \         \        False               True
    ========   =======   ======   =================   =================
    Var1       SNV       0.1      1                   1
    Var2       SNV       0.2      1                   2
    Var3       Indel     0.3      2                   1
    Var4       Indel     0.4      2                   2
    ========   =======   ======   =================   =================
    .. note::
        The `bin_expr` defines which data the bin(s) should be computed on. E.g., to get a biallelic quantile bin and a
        singleton quantile bin, the following could be used:
        .. code-block:: python
            bin_expr={
                'biallelic_bin': ~ht.was_split,
                'singleton_bin': ht.singleton
            }
    :param ht: Input Table
    :param score_expr: Expression containing the score
    :param bin_expr: Quantile bin(s) to be computed (see notes)
    :param compute_snv_indel_separately: Should all `bin_expr` items be stratified by snv / indels
    :param n_bins: Number of bins to bin the data into
    :param k: The `k` parameter of approx_quantiles
    :param desc: Whether to bin the score in descending order
    :return: Table with the quantile bins
    """
    import math

    def quantiles_to_bin_boundaries(quantiles: List[int]) -> Dict:
        """
        Merges bins with the same boundaries into a unique bin while keeping track of
        which bins have been merged and the global index of all bins.
        :param quantiles: Original bin boundaries
        :return: Dict with `merged_bins` (indices of bins for which multiple bins were collapsed -> number of bins collapsed),
                 `global_bin_indices` (global indices of all bins after merging), and
                 `bin_boundaries` (merged bin boundaries)
        """

        # Pad the quantiles to create boundaries for the first and last bins
        bin_boundaries = [-math.inf] + quantiles + [math.inf]
        merged_bins = defaultdict(int)

        # If every quantile has a unique value, then bin boundaries are unique
        # and can be passed to binary_search as-is
        if len(quantiles) == len(set(quantiles)):
            return dict(
                merged_bins=merged_bins,
                global_bin_indices=list(range(len(bin_boundaries))),
                bin_boundaries=bin_boundaries,
            )

        indexed_bins = list(enumerate(bin_boundaries))
        i = 1
        while i < len(indexed_bins):
            if indexed_bins[i - 1][1] == indexed_bins[i][1]:
                merged_bins[i - 1] += 1
                indexed_bins.pop(i)
            else:
                i += 1

        return dict(
            merged_bins=merged_bins,
            global_bin_indices=[x[0] for x in indexed_bins],
            bin_boundaries=[x[1] for x in indexed_bins],
        )

    if compute_snv_indel_separately:
        # For each bin, add a SNV / indel stratification
        bin_expr = {
            f"{bin_id}_{snv}": (bin_expr & snv_expr)
            for bin_id, bin_expr in bin_expr.items() for snv, snv_expr in [
                ("snv", hl.is_snp(ht.alleles[0], ht.alleles[1])),
                ("indel", ~hl.is_snp(ht.alleles[0], ht.alleles[1])),
            ]
        }
        print("ADSADSADASDAS")
        print(bin_expr)

    bin_ht = ht.annotate(
        **{
            f"_filter_{bin_id}": bin_expr
            for bin_id, bin_expr in bin_expr.items()
        },
        _score=score_expr,
        snv=hl.is_snp(ht.alleles[0], ht.alleles[1]),
    )
    print(bin_ht.show())
    logger.info(
        f"Adding quantile bins using approximate_quantiles binned into {n_bins}, using k={k}"
    )
    bin_stats = bin_ht.aggregate(
        hl.struct(
            **{
                bin_id: hl.agg.filter(
                    bin_ht[f"_filter_{bin_id}"],
                    hl.struct(
                        n=hl.agg.count(),
                        quantiles=hl.agg.approx_quantiles(
                            bin_ht._score,
                            [x / (n_bins) for x in range(1, n_bins)],
                            k=k),
                    ),
                )
                for bin_id in bin_expr
            }))

    # Take care of bins with duplicated boundaries
    bin_stats = bin_stats.annotate(
        **{
            rname: bin_stats[rname].annotate(
                **quantiles_to_bin_boundaries(bin_stats[rname].quantiles))
            for rname in bin_stats
        })

    bin_ht = bin_ht.annotate_globals(bin_stats=hl.literal(
        bin_stats,
        dtype=hl.tstruct(
            **{
                bin_id: hl.tstruct(
                    n=hl.tint64,
                    quantiles=hl.tarray(hl.tfloat64),
                    bin_boundaries=hl.tarray(hl.tfloat64),
                    global_bin_indices=hl.tarray(hl.tint32),
                    merged_bins=hl.tdict(hl.tint32, hl.tint32),
                )
                for bin_id in bin_expr
            }),
    ))

    # Annotate the bin as the index in the unique boundaries array
    bin_ht = bin_ht.annotate(
        **{
            bin_id: hl.or_missing(
                bin_ht[f"_filter_{bin_id}"],
                hl.binary_search(bin_ht.bin_stats[bin_id].bin_boundaries,
                                 bin_ht._score),
            )
            for bin_id in bin_expr
        })

    # Convert the bin to global bin by expanding merged bins, that is:
    # If a value falls in a bin that needs expansion, assign it randomly to one of the expanded bins
    # Otherwise, simply modify the bin to its global index (with expanded bins that is)
    bin_ht = bin_ht.select(
        "snv",
        **{
            bin_id: hl.if_else(
                bin_ht.bin_stats[bin_id].merged_bins.contains(bin_ht[bin_id]),
                bin_ht.bin_stats[bin_id].global_bin_indices[bin_ht[bin_id]] +
                hl.int(
                    hl.rand_unif(
                        0, bin_ht.bin_stats[bin_id].merged_bins[bin_ht[bin_id]]
                        + 1)),
                bin_ht.bin_stats[bin_id].global_bin_indices[bin_ht[bin_id]],
            )
            for bin_id in bin_expr
        },
    )

    if desc:
        bin_ht = bin_ht.annotate(
            **{bin_id: n_bins - bin_ht[bin_id]
               for bin_id in bin_expr})

    # Because SNV and indel rows are mutually exclusive, re-combine them into a single bin.
    # Update the global bin_stats struct to reflect the change in bin names in the table
    if compute_snv_indel_separately:
        bin_expr_no_snv = {
            bin_id.rsplit("_", 1)[0]
            for bin_id in bin_ht.bin_stats
        }
        bin_ht = bin_ht.annotate_globals(bin_stats=hl.struct(
            **{
                bin_id: hl.struct(
                    **{
                        snv: bin_ht.bin_stats[f"{bin_id}_{snv}"]
                        for snv in ["snv", "indel"]
                    })
                for bin_id in bin_expr_no_snv
            }))

        bin_ht = bin_ht.transmute(
            **{
                bin_id: hl.if_else(
                    bin_ht.snv,
                    bin_ht[f"{bin_id}_snv"],
                    bin_ht[f"{bin_id}_indel"],
                )
                for bin_id in bin_expr_no_snv
            })

    return bin_ht
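
A hedged usage sketch mirroring the docstring's bin_expr example (the table ht and its score, was_split and singleton fields are assumptions, not from the source):

binned_ht = compute_quantile_bin(
    ht,
    score_expr=ht.score,
    bin_expr={
        'biallelic_bin': ~ht.was_split,
        'singleton_bin': ht.singleton,
    },
    n_bins=10,
)
# each row now carries biallelic_bin / singleton_bin quantile bin assignments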
Example #22
 def test_literals_rebuild(self):
     mt = hl.utils.range_matrix_table(1, 1)
     mt = mt.annotate_rows(x = hl.cond(hl.len(hl.literal([1,2,3])) < hl.rand_unif(10, 11), mt.globals, hl.struct()))
     mt._force_count_rows()
Example #23
 def test_literals_rebuild(self):
     mt = hl.utils.range_matrix_table(1, 1)
     mt = mt.annotate_rows(x=hl.cond(
         hl.len(hl.literal([1, 2, 3])) < hl.rand_unif(10, 11), mt.globals,
         hl.struct()))
     mt._force_count_rows()
Example #24
def samples_qc(mt, mt_to_annotate, args):
    """
    Performs sample QC on a matrix table, flagging samples that fail on chimera and contamination percentages, as well
    as samples that are +/- 4 standard deviations from the mean on TiTv, het/homvar, insertion/deletion ratios and
    n_singletons for a specific batch or cohort

    :param mt: matrix table, low-pass failing variants and genotypes filtered out
    :param mt_to_annotate: matrix table to annotate with failing samples information after calculating on filtered mt
    :param args:
    :return: returns annotated, unfiltered matrix table
    """
    datestr = time.strftime("%Y.%m.%d")

    # Run sample QC to get up-to-date sample QC metrics
    mt = hl.sample_qc(mt)

    # Pull data to cols and checkpoint
    mt_cols = mt.cols()
    mt_cols = mt_cols.checkpoint("samples_qc_cols_tmp.ht", overwrite=True)

    # Instantiate empty array for failing samples QC tags
    mt_cols = mt_cols.annotate(failing_samples_qc=hl.empty_array(hl.tstr))

    ############################################################
    # Find samples failing on chimeras or contamination values #
    ############################################################
    mt_cols = mt_cols.annotate(failing_samples_qc=hl.cond(
        (mt_cols[args.chimeras_col] > args.chimeras_max)
        & hl.is_defined(mt_cols[args.chimeras_col]),
        mt_cols.failing_samples_qc.append(
            "failing_chimeras"), mt_cols.failing_samples_qc))

    mt_cols = mt_cols.annotate(failing_samples_qc=hl.cond(
        (mt_cols[args.contamination_col] > args.contamination_max)
        & hl.is_defined(mt_cols[args.contamination_col]),
        mt_cols.failing_samples_qc.append(
            "failing_contamination"), mt_cols.failing_samples_qc))

    failing_chim = mt_cols.aggregate(
        hl.agg.count_where(
            mt_cols.failing_samples_qc.contains("failing_chimeras")))
    miss_chim = mt_cols.aggregate(
        hl.agg.count_where(~(hl.is_defined(mt_cols[args.chimeras_col]))))
    failing_contam = mt_cols.aggregate(
        hl.agg.count_where(
            mt_cols.failing_samples_qc.contains("failing_contamination")))
    miss_contam = mt_cols.aggregate(
        hl.agg.count_where(~(hl.is_defined(mt_cols[args.contamination_col]))))

    logging.info(
        f"Number of samples failing on chimeras % > {args.chimeras_max}: {failing_chim}"
    )
    logging.info(f"Number of samples missing chimeras %: {miss_chim}")
    logging.info(
        f"Number of samples failing on contamination % > {args.contamination_max}: {failing_contam}"
    )
    logging.info(f"Number of samples missing contamination %: {miss_contam}")

    chim_stats = mt_cols.aggregate(hl.agg.stats(mt_cols[args.chimeras_col]))
    cont_stats = mt_cols.aggregate(
        hl.agg.stats(mt_cols[args.contamination_col]))
    logging.info(f"Chimeras statistics: {chim_stats}")
    logging.info(f"Contamination statistics: {cont_stats}")

    ###############################################
    # Find samples failing on sex-aware call rate #
    ###############################################
    if args.sample_call_rate is not None:
        mt_cols = mt_cols.annotate(failing_samples_qc=hl.cond(
            (mt_cols.sexaware_sample_call_rate < args.sample_call_rate)
            & hl.is_defined(mt_cols.sexaware_sample_call_rate),
            mt_cols.failing_samples_qc.append(
                "failing_sexaware_sample_call_rate"),
            mt_cols.failing_samples_qc))

        mt_cols = mt_cols.annotate(failing_samples_qc=hl.cond(
            ~(hl.is_defined(mt_cols.sexaware_sample_call_rate)),
            mt_cols.failing_samples_qc.append(
                "missing_sexaware_sample_call_rate"),
            mt_cols.failing_samples_qc))

        failing_cr = mt_cols.aggregate(
            hl.agg.count_where(
                mt_cols.failing_samples_qc.contains(
                    "failing_sexaware_sample_call_rate")))
        missing_cr = mt_cols.aggregate(
            hl.agg.count_where(
                mt_cols.failing_samples_qc.contains(
                    "missing_sexaware_sample_call_rate")))

        logging.info(
            f"Number of samples failing on sex-aware call rate > {args.sample_call_rate}: {failing_cr}"
        )
        logging.info(
            f"Number of samples missing sex-aware call rate : {missing_cr}")

        cr_stats = mt_cols.aggregate(
            hl.agg.stats(mt_cols.sexaware_sample_call_rate))

        logging.info(f"Sex-aware call rate statistics: {cr_stats}")

    ######################################################################################
    # Find samples failing per-cohort on titv, het_homvar ratio, indel, and # singletons #
    ######################################################################################
    if args.batch_col_name is not None:
        batch_none = mt_cols.aggregate(
            hl.agg.count_where(~(hl.is_defined(mt_cols[args.batch_col_name]))))
        mt_cols = mt_cols.annotate(
            **{
                args.batch_col_name:
                hl.or_else(mt_cols[args.batch_col_name], "no_batch_info")
            })

        if batch_none > 0:
            logging.info(
                f"Warning- {batch_none} samples have batch undefined. These samples will be grouped in one"
                f"batch for sample QC (named no_batch_info).")
            mt_cols.filter_cols(mt_cols[args.batch_col_name] ==
                                "no_batch_info").s.show(batch_none + 1)

        batch_set = mt_cols.aggregate(
            hl.agg.collect_as_set(mt_cols[args.batch_col_name]))
    else:
        args.batch_col_name = "mock_batch_col"
        mt_cols = mt_cols.annotate(mock_batch_col="all")
        batch_set = ["all"]

    # Convert batch strings to numeric values, create label for plotting
    batch_set_numeric = list(range(len(batch_set)))
    batch_key = list(zip(batch_set, batch_set_numeric))

    mt_cols = mt_cols.annotate(plot_batch=0)
    for batch in batch_key:
        mt_cols = mt_cols.annotate(
            plot_batch=hl.cond(mt_cols[args.batch_col_name] == batch[0],
                               batch[1], mt_cols.plot_batch))
        mt_cols = mt_cols.annotate(plot_batch_jitter=mt_cols.plot_batch +
                                   hl.rand_unif(-0.3, 0.3))

    batch_thresholds = {}
    batch_statistics = {}
    for measure in [
            'r_ti_tv', 'r_het_hom_var', 'r_insertion_deletion', 'n_singleton'
    ]:
        logging.info(f"Performing sample QC for measure {measure}")

        # Instantiate/reset box plot label
        mt_cols = mt_cols.annotate(boxplot_label=mt_cols[args.batch_col_name])

        batch_thresholds[measure] = {}
        batch_statistics[measure] = {}

        mt_cols = mt_cols.annotate(failing_samples_qc=hl.cond(
            ~(hl.is_defined(mt_cols.sample_qc[measure])),
            mt_cols.failing_samples_qc.append(f"missing_{measure}"),
            mt_cols.failing_samples_qc))

        for batch in batch_set:
            # Check whether any values are defined for this measure at all
            defined_values = mt_cols.aggregate(
                hl.agg.count_where(hl.is_defined(mt_cols.sample_qc[measure])))

            if defined_values > 0:
                # Get mean and standard deviation for each measure, for each batch's samples
                stats = mt_cols.aggregate(
                    hl.agg.filter(mt_cols[args.batch_col_name] == batch,
                                  hl.agg.stats(mt_cols.sample_qc[measure])))

                # Get cutoffs for each measure
                cutoff_upper = stats.mean + (args.sampleqc_sd_threshold *
                                             stats.stdev)
                cutoff_lower = stats.mean - (args.sampleqc_sd_threshold *
                                             stats.stdev)

                if measure == "n_singleton":
                    logging.info(
                        f"Max number of singletons for batch {batch}: {stats.max}"
                    )

                mt_cols = mt_cols.annotate(failing_samples_qc=hl.cond(
                    ((mt_cols.sample_qc[measure] > cutoff_upper)
                     | (mt_cols.sample_qc[measure] < cutoff_lower))
                    & hl.is_defined(mt_cols.sample_qc[measure])
                    & (mt_cols[args.batch_col_name] == batch),
                    mt_cols.failing_samples_qc.append(
                        f"failing_{measure}"), mt_cols.failing_samples_qc))

                mt_cols = mt_cols.annotate(boxplot_label=hl.cond(
                    ((mt_cols.sample_qc[measure] > cutoff_upper)
                     | (mt_cols.sample_qc[measure] < cutoff_lower))
                    & hl.is_defined(mt_cols.sample_qc[measure])
                    & (mt_cols[args.batch_col_name] == batch), "outlier",
                    mt_cols.boxplot_label))

                # Collect thresholds and statistics for each batch
                batch_thresholds[measure][batch] = {
                    'min_thresh': cutoff_lower,
                    'max_thresh': cutoff_upper
                }
                batch_statistics[measure][batch] = stats

            else:
                logging.error(
                    f"Error- no defined values for measure {measure}. NAs can be introduced by division by "
                    f"zero. Samples not filtered on {measure}!")

        # Create plot for measure for each batch
        output_file(f"{datestr}_samples_qc_plots_{measure}.html")
        p = hl.plot.scatter(mt_cols.plot_batch_jitter,
                            mt_cols.sample_qc[measure],
                            label=mt_cols.boxplot_label,
                            title=f"{measure} values split by batch.")
        save(p)

    ##########################
    # Report failing samples #
    ##########################
    for measure in [
            'r_ti_tv', 'r_het_hom_var', 'r_insertion_deletion', 'n_singleton'
    ]:
        failing_count = mt_cols.aggregate(
            hl.agg.count_where(
                mt_cols.failing_samples_qc.contains(f"failing_{measure}")))
        missing_count = mt_cols.aggregate(
            hl.agg.count_where(
                mt_cols.failing_samples_qc.contains(f"missing_{measure}")))
        logging.info(
            f"Number of samples failing on {measure}: {failing_count}")
        logging.info(f"Number of samples missing {measure}: {missing_count}")

    failing_any = mt_cols.aggregate(
        hl.agg.count_where(hl.len(mt_cols.failing_samples_qc) != 0))
    logging.info(
        f"Number of samples failing samples QC on any measure: {failing_any}")

    if args.pheno_col is not None:
        cases_failing = mt_cols.aggregate(
            hl.agg.filter(
                mt_cols[args.pheno_col] == True,
                hl.agg.count_where(hl.len(mt_cols.failing_samples_qc) != 0)))
        controls_failing = mt_cols.aggregate(
            hl.agg.filter(
                mt_cols[args.pheno_col] == False,
                hl.agg.count_where(hl.len(mt_cols.failing_samples_qc) != 0)))
        logging.info(f"Cases failing QC: {cases_failing}")
        logging.info(f"Controls failing QC: {controls_failing}")

    #######################################################################################################
    # Annotate original (unfiltered) matrix table with failing samples QC information + sample QC measure #
    #######################################################################################################
    mt_to_annotate = mt_to_annotate.annotate_cols(
        sample_qc=mt_cols[mt_to_annotate.s].sample_qc)
    mt_to_annotate = mt_to_annotate.annotate_cols(
        failing_samples_qc=mt_cols[mt_to_annotate.s].failing_samples_qc)

    mt_to_annotate = mt_to_annotate.annotate_globals(
        samples_qc_stats_batches=batch_statistics)
    mt_to_annotate = mt_to_annotate.annotate_globals(
        samples_qc_stats_chim_cont={
            'chimeras': chim_stats,
            'contamination': cont_stats
        })
    mt_to_annotate = mt_to_annotate.annotate_globals(
        samples_qc_thresholds={
            'chimeras_max': str(args.chimeras_max),
            'contamination_max': str(args.contamination_max),
            'deviation_multiplier_threshold': str(args.sampleqc_sd_threshold),
            'batches': str(batch_set),
            'batch_cohort_name': str(args.batch_col_name)
        })

    mt_to_annotate = mt_to_annotate.annotate_globals(
        samples_qc_batch_thresholds=batch_thresholds)

    return mt_to_annotate
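A minimal, hypothetical sketch of the per-batch outlier pattern used above: for each batch, compute hl.agg.stats on the measure and flag samples beyond mean ± N standard deviations. The table and field names (qc_ht, batch, sd_threshold) are illustrative, not taken from the original script.

import hail as hl

def flag_batch_outliers(qc_ht, measure, sd_threshold=4):
    """Hypothetical helper: flag samples whose `measure` lies more than
    `sd_threshold` standard deviations from their batch's mean."""
    batches = qc_ht.aggregate(hl.agg.collect_as_set(qc_ht.batch))
    qc_ht = qc_ht.annotate(failing=hl.empty_array(hl.tstr))
    for batch in batches:
        # Per-batch mean and standard deviation of the measure
        stats = qc_ht.aggregate(
            hl.agg.filter(qc_ht.batch == batch, hl.agg.stats(qc_ht[measure])))
        upper = stats.mean + sd_threshold * stats.stdev
        lower = stats.mean - sd_threshold * stats.stdev
        # Append a failure label for defined values outside the cutoffs
        qc_ht = qc_ht.annotate(failing=hl.if_else(
            (qc_ht.batch == batch) & hl.is_defined(qc_ht[measure])
            & ((qc_ht[measure] > upper) | (qc_ht[measure] < lower)),
            qc_ht.failing.append(f"failing_{measure}"),
            qc_ht.failing))
    return qc_ht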
Beispiel #25
0
def compute_ranked_bin(
    ht: hl.Table,
    score_expr: hl.expr.NumericExpression,
    bin_expr: Dict[str, hl.expr.BooleanExpression] = {"bin": True},
    compute_snv_indel_separately: bool = True,
    n_bins: int = 100,
    desc: bool = True,
) -> hl.Table:
    r"""
    Return a table with a bin for each row based on the ranking of `score_expr`.

    The bin is computed by dividing the `score_expr` into `n_bins` bins containing approximately equal numbers of elements.
    This is done by ranking the rows by `score_expr` (and a random number in cases where multiple variants have the same score)
    and then assigning the variant to a bin based on its ranking.

    If `compute_snv_indel_separately` is True, all items in `bin_expr` will be stratified by SNVs / indels for the ranking and
    bin calculation. Because SNV and indel rows are mutually exclusive, they are re-combined into a single annotation. For
    example, if we have the following four variants and scores with `n_bins` of 2:

    ========   =======   ======   =================   =================
    Variant    Type      Score    bin - `compute_snv_indel_separately`:
    --------   -------   ------   -------------------------------------
    \          \         \        False               True
    ========   =======   ======   =================   =================
    Var1       SNV       0.1      1                   1
    Var2       SNV       0.2      1                   2
    Var3       Indel     0.3      2                   1
    Var4       Indel     0.4      2                   2
    ========   =======   ======   =================   =================

    .. note::

        The `bin_expr` defines which data the bin(s) should be computed on. E.g., to get biallelic specific binning
        and singleton specific binning, the following could be used:

        .. code-block:: python

            bin_expr={
                'biallelic_bin': ~ht.was_split,
                'singleton_bin': ht.singleton
            }

    :param ht: Input Table
    :param score_expr: Expression containing the score
    :param bin_expr: Specific row grouping(s) to perform ranking and binning on (see note)
    :param compute_snv_indel_separately: Should all `bin_expr` items be stratified by SNVs / indels
    :param n_bins: Number of bins to bin the data into
    :param desc: Whether to bin the score in descending order
    :return: Table with the requested bin annotations
    """
    if compute_snv_indel_separately:
        # For each bin, add a SNV / indel stratification
        bin_expr = {
            f"{bin_id}_{snv}": (bin_expr & snv_expr)
            for bin_id, bin_expr in bin_expr.items() for snv, snv_expr in [
                ("snv", hl.is_snp(ht.alleles[0], ht.alleles[1])),
                ("indel", ~hl.is_snp(ht.alleles[0], ht.alleles[1])),
            ]
        }

    bin_ht = ht.select(
        **{
            f"_filter_{bin_id}": bin_expr
            for bin_id, bin_expr in bin_expr.items()
        },
        _score=score_expr,
        snv=hl.is_snp(ht.alleles[0], ht.alleles[1]),
        _rand=hl.rand_unif(0, 1),
    )

    logger.info(
        "Sorting the HT by score_expr followed by a random float between 0 and 1. "
        "Then adding a row index per grouping defined by bin_expr...")
    bin_ht = bin_ht.order_by("_score", "_rand")
    bin_ht = bin_ht.annotate(
        **{
            f"{bin_id}_rank": hl.or_missing(
                bin_ht[f"_filter_{bin_id}"],
                hl.scan.count_where(bin_ht[f"_filter_{bin_id}"]),
            )
            for bin_id in bin_expr
        })
    bin_ht = bin_ht.key_by("locus", "alleles")

    # Annotate globals with variant counts per group defined by bin_expr. This is used to determine bin assignment
    bin_ht = bin_ht.annotate_globals(bin_group_variant_counts=bin_ht.aggregate(
        hl.Struct(
            **{
                bin_id: hl.agg.filter(
                    bin_ht[f"_filter_{bin_id}"],
                    hl.agg.count(),
                )
                for bin_id in bin_expr
            })))

    logger.info("Binning ranked rows into %d bins...", n_bins)
    bin_ht = bin_ht.select(
        "snv",
        **{
            bin_id: hl.int(
                hl.floor(
                    (n_bins *
                     (bin_ht[f"{bin_id}_rank"] /
                      hl.float64(bin_ht.bin_group_variant_counts[bin_id]))) +
                    1))
            for bin_id in bin_expr
        },
    )

    if desc:
        bin_ht = bin_ht.annotate(
            **{bin_id: n_bins - bin_ht[bin_id] + 1
               for bin_id in bin_expr})

    # Because SNV and indel rows are mutually exclusive, re-combine them into a single bin.
    # Update the global bin_group_variant_counts struct to reflect the change in bin names in the table
    if compute_snv_indel_separately:
        bin_expr_no_snv = {
            bin_id.rsplit("_", 1)[0]
            for bin_id in bin_ht.bin_group_variant_counts
        }
        bin_ht = bin_ht.annotate_globals(bin_group_variant_counts=hl.struct(
            **{
                bin_id: hl.struct(
                    **{
                        snv: bin_ht.bin_group_variant_counts[f"{bin_id}_{snv}"]
                        for snv in ["snv", "indel"]
                    })
                for bin_id in bin_expr_no_snv
            }))

        bin_ht = bin_ht.transmute(
            **{
                bin_id: hl.if_else(
                    bin_ht.snv,
                    bin_ht[f"{bin_id}_snv"],
                    bin_ht[f"{bin_id}_indel"],
                )
                for bin_id in bin_expr_no_snv
            })

    return bin_ht
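A short, hypothetical usage sketch for compute_ranked_bin; the table path, the `score` row field and the `was_split` flag are illustrative assumptions, not part of the function above.

import hail as hl

ht = hl.read_table('gs://my-bucket/variant_scores.ht')  # keyed by locus, alleles

binned = compute_ranked_bin(
    ht,
    score_expr=ht.score,
    bin_expr={
        'bin': True,                     # bin all variants
        'biallelic_bin': ~ht.was_split,  # bin biallelic variants only
    },
    n_bins=100,
)

# Each returned row carries `bin` / `biallelic_bin` values in 1..n_bins
# (SNVs and indels ranked separately, then re-combined); join them back on.
ht = ht.annotate(**binned[ht.key])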
Beispiel #26
0
 def test_partitioning_rewrite(self):
     ht = hl.utils.range_table(10, 3)
     ht1 = ht.annotate(x=hl.rand_unif(0, 1))
     self.assertEqual(ht1.x.collect()[:5], ht1.head(5).x.collect())
Beispiel #28
0
 def make_random_function(self, mt, h2, pi):  # pi is the slab (non-zero effect) probability
     M = mt.count_rows()  # number of variants
     # hl.rand_unif(0, 1) < pi is a Hail BooleanExpression, so the spike-and-slab
     # choice must be made with hl.if_else rather than a Python `if` statement.
     return hl.if_else(hl.rand_unif(0, 1) < pi,
                       hl.rand_norm(0, h2 / (M * pi)),  # slab: draw an effect size
                       0.0)                             # spike: no effect
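A hypothetical usage sketch, assuming the method above belongs to a simulation object `sim` and `mt` is a genotype MatrixTable: annotating rows with it draws one spike-and-slab effect size per variant for heritability h2 and slab probability pi.

# Per-variant effect sizes; `sim`, h2 and pi values are illustrative.
mt = mt.annotate_rows(beta=sim.make_random_function(mt, h2=0.5, pi=0.01))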