def prepare_exac_constraint(exac_constraint_path):
    ds = hl.import_table(exac_constraint_path, force=True)

    ds = ds.select(
        transcript_id=ds.transcript.split("\\.")[0],
        # Expected
        exp_syn=hl.float(ds.exp_syn),
        exp_mis=hl.float(ds.exp_mis),
        exp_lof=hl.float(ds.exp_lof),
        # Actual
        obs_syn=hl.int(ds.n_syn),
        obs_mis=hl.int(ds.n_mis),
        obs_lof=hl.int(ds.n_lof),
        # mu
        mu_syn=hl.float(ds.mu_syn),
        mu_mis=hl.float(ds.mu_mis),
        mu_lof=hl.float(ds.mu_lof),
        # Z
        syn_z=hl.float(ds.syn_z),
        mis_z=hl.float(ds.mis_z),
        lof_z=hl.float(ds.lof_z),
        # Other
        pLI=hl.float(ds.pLI),
    )

    ds = ds.key_by("transcript_id")

    return ds
Example #2
def prepare_exac_constraint(path):
    ds = hl.import_table(path, force=True)
    ds = ds.repartition(32, shuffle=True)

    # Select relevant fields
    ds = ds.select(
        # Remove version number from transcript ID
        transcript_id=ds.transcript.split("\\.")[0],
        # Expected
        exp_syn=hl.float(ds.exp_syn),
        exp_mis=hl.float(ds.exp_mis),
        exp_lof=hl.float(ds.exp_lof),
        # Actual
        obs_syn=hl.int(ds.n_syn),
        obs_mis=hl.int(ds.n_mis),
        obs_lof=hl.int(ds.n_lof),
        # mu
        mu_syn=hl.float(ds.mu_syn),
        mu_mis=hl.float(ds.mu_mis),
        mu_lof=hl.float(ds.mu_lof),
        # Z
        syn_z=hl.float(ds.syn_z),
        mis_z=hl.float(ds.mis_z),
        lof_z=hl.float(ds.lof_z),
        # Other
        pli=hl.float(ds.pLI),
    )

    ds = ds.key_by("transcript_id")

    return ds
def format_exac_constraint(ds):
    # Select relevant fields
    ds = ds.select(
        transcript_id=ds.transcript.split("\\.")[0],
        # Expected
        exp_syn=hl.float(ds.exp_syn),
        exp_mis=hl.float(ds.exp_mis),
        exp_lof=hl.float(ds.exp_lof),
        # Actual
        obs_syn=hl.int(ds.n_syn),
        obs_mis=hl.int(ds.n_mis),
        obs_lof=hl.int(ds.n_lof),
        # mu
        mu_syn=hl.float(ds.mu_syn),
        mu_mis=hl.float(ds.mu_mis),
        mu_lof=hl.float(ds.mu_lof),
        # Z
        syn_z=hl.float(ds.syn_z),
        mis_z=hl.float(ds.mis_z),
        lof_z=hl.float(ds.lof_z),
        # Other
        pLI=hl.float(ds.pLI),
    )

    ds = ds.key_by("transcript_id")

    return ds
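
# Usage sketch (hedged): the paths below are hypothetical. The prepared table is
# keyed by transcript_id and can be written out or joined to a transcript-level table.
import hail as hl

constraint = prepare_exac_constraint("gs://my-bucket/exac_constraint.txt.bgz")
constraint.write("gs://my-bucket/exac_constraint.ht", overwrite=True)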
def intersect_target_ref(ref_mt_filt,
                         snp_list,
                         grch37_or_grch38,
                         intersect_out,
                         overwrite: bool = False):
    mt = hl.read_matrix_table(ref_mt_filt)
    if grch37_or_grch38.lower() == 'grch38':
        snp_list = snp_list.key_by(locus=hl.locus(hl.str(snp_list.chr),
                                                  hl.int(snp_list.pos),
                                                  reference_genome='GRCh38'),
                                   alleles=[snp_list.ref, snp_list.alt])
        mt = mt.filter_rows(hl.is_defined(snp_list[mt.row_key]))

    elif grch37_or_grch38.lower() == 'grch37':
        snp_list = snp_list.key_by(locus=hl.locus(hl.str(snp_list.chr),
                                                  hl.int(snp_list.pos),
                                                  reference_genome='GRCh37'),
                                   alleles=[snp_list.ref, snp_list.alt])
        # liftover snp list to GRCh38, filter to SNPs in mt
        rg37, rg38 = load_liftover()

        snp_liftover = snp_list.annotate(
            new_locus=hl.liftover(snp_list.locus, 'GRCh38'))
        snp_liftover = snp_liftover.filter(
            hl.is_defined(snp_liftover.new_locus))
        snp_liftover = snp_liftover.key_by(locus=snp_liftover.new_locus,
                                           alleles=snp_liftover.alleles)
        mt = mt.filter_rows(hl.is_defined(snp_liftover[mt.row_key]))

    mt = mt.repartition(5000)
    mt = mt.checkpoint(intersect_out,
                       overwrite=overwrite,
                       _read_if_exists=not overwrite)
Example #5
def import_vqsr(
    vqsr_path: str,
    vqsr_type: str = "alleleSpecificTrans",
    num_partitions: int = 5000,
    overwrite: bool = False,
    import_header_path: Optional[str] = None,
) -> None:
    """
    Import a VQSR site-only VCF into a Hail Table (HT).
    :param vqsr_path: Path to the input VQSR site-only VCF. This can be specified as a Hadoop glob pattern
    :param vqsr_type: One of `classic`, `alleleSpecific` (allele specific) or `alleleSpecificTrans`
        (allele specific with transmitted singletons)
    :param num_partitions: Number of partitions to use for the VQSR HT
    :param overwrite: Whether to overwrite the imported VQSR HT
    :param import_header_path: Optional path to a header file to use for import
    :return: None
    """

    logger.info(f"Importing VQSR annotations for {vqsr_type} VQSR...")
    mt = hl.import_vcf(
        vqsr_path,
        force_bgz=True,
        reference_genome="GRCh38",
        header_file=import_header_path,
    ).repartition(num_partitions)

    ht = mt.rows()

    ht = ht.annotate(info=ht.info.annotate(
        AS_VQSLOD=ht.info.AS_VQSLOD.map(lambda x: hl.float(x)),
        AS_QUALapprox=ht.info.AS_QUALapprox.split(r"\|")[1:].map(
            lambda x: hl.int(x)),
        AS_VarDP=ht.info.AS_VarDP.split(r"\|")[1:].map(lambda x: hl.int(x)),
        AS_SB_TABLE=ht.info.AS_SB_TABLE.split(r"\|").map(
            lambda x: x.split(",").map(lambda y: hl.int(y))),
    ))

    ht = ht.checkpoint(
        get_vqsr_filters(f"vqsr_{vqsr_type}", split=False,
                         finalized=False).path,
        overwrite=overwrite,
    )

    unsplit_count = ht.count()
    ht = hl.split_multi_hts(ht)

    ht = ht.annotate(
        info=ht.info.annotate(**split_info_annotation(ht.info, ht.a_index)), )

    ht = ht.checkpoint(
        get_vqsr_filters(f"vqsr_{vqsr_type}", split=True,
                         finalized=False).path,
        overwrite=overwrite,
    )
    split_count = ht.count()
    logger.info(
        f"Found {unsplit_count} unsplit and {split_count} split variants with VQSR annotations"
    )
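
# Hedged invocation sketch; all paths are placeholders, and get_vqsr_filters is
# assumed to resolve to writable checkpoint paths in the surrounding pipeline.
import_vqsr(
    vqsr_path="gs://my-bucket/vqsr/part-*.vcf.gz",  # Hadoop glob over site-only VCF shards
    vqsr_type="alleleSpecificTrans",
    num_partitions=5000,
    overwrite=True,
    import_header_path="gs://my-bucket/vqsr/header.txt",
)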
def prepare_gtex_expression_data(transcript_tpms_path, sample_annotations_path,
                                 tmp_path):
    # Recompress the TPMs file with block gzip (.bgz) so that import_matrix_table can read it
    ds = hl.import_table(transcript_tpms_path, force=True)
    tmp_transcript_tpms_path = tmp_path + "/" + transcript_tpms_path.split(
        "/")[-1].replace(".gz", ".bgz")
    ds.export(tmp_transcript_tpms_path)

    # Import data
    ds = hl.import_matrix_table(
        tmp_transcript_tpms_path,
        row_fields={
            "transcript_id": hl.tstr,
            "gene_id": hl.tstr
        },
        entry_type=hl.tfloat,
    )
    ds = ds.rename({"col_id": "sample_id"})
    ds = ds.repartition(1000, shuffle=True)

    samples = hl.import_table(sample_annotations_path, key="SAMPID")

    # Separate version numbers from transcript and gene IDs
    ds = ds.annotate_rows(
        transcript_id=ds.transcript_id.split(r"\.")[0],
        transcript_version=hl.int(ds.transcript_id.split(r"\.")[1]),
        gene_id=ds.gene_id.split(r"\.")[0],
        gene_version=hl.int(ds.gene_id.split(r"\.")[1]),
    )

    # Annotate columns with the tissue the sample came from
    ds = ds.annotate_cols(tissue=samples[ds.sample_id].SMTSD)

    # Aggregate expression to the approximate median across all samples in each tissue
    ds = ds.group_cols_by(ds.tissue).aggregate(**{
        "": hl.agg.approx_median(ds.x)
    }).make_table()

    # Format tissue names
    other_fields = {
        "transcript_id", "transcript_version", "gene_id", "gene_version"
    }
    tissues = [f for f in ds.row_value.dtype.fields if f not in other_fields]
    ds = ds.transmute(tissues=hl.struct(
        **{format_tissue_name(tissue): ds[tissue]
           for tissue in tissues}))

    ds = ds.key_by("transcript_id").drop("row_id")

    return ds
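
# Usage sketch with assumed GTEx-style inputs; all paths are placeholders.
expression = prepare_gtex_expression_data(
    transcript_tpms_path="gs://my-bucket/gtex/transcript_tpm.txt.gz",
    sample_annotations_path="gs://my-bucket/gtex/sample_attributes.txt",
    tmp_path="gs://my-bucket/tmp",
)
expression.write("gs://my-bucket/gtex/transcript_expression.ht", overwrite=True)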
Example #7
def specific_clumps(filename):
    clump = hl.import_table(filename,
                            delimiter=r'\s+',
                            min_partitions=10,
                            types={'P': hl.tfloat})
    clump = clump.key_by(locus=hl.locus(hl.str(clump.CHR), hl.int(clump.BP)))
    return clump
def make_sample_rank_table(phe_ht: hl.Table) -> hl.Table:
    """
    Make a table ranking samples by retention priority
    (a lower rank means higher priority).
    It mainly uses two bits of information:
      - cases are prioritised over controls
      - samples are preferred based on cohort as follows: chd > ddd > ukbb
    :param phe_ht: Table with sample meta-data annotations (e.g. phenotype, cohort info...)
    :return: Hail Table
    """

    phe_ht = (
        phe_ht.annotate(
            case_control_rank=hl.int(
                phe_ht['phe.is_case']),  # 0: control, 1: cases
            cohort_rank=hl.case().when(phe_ht.is_ukbb, 10).when(
                phe_ht.is_ddd, 100).when(phe_ht.is_chd,
                                         1000).or_missing()).key_by())

    phe_ht = (phe_ht.select('ega_id', 'case_control_rank', 'cohort_rank'))

    # sort table (descending)
    tb_rank = (phe_ht.order_by(hl.desc(phe_ht.case_control_rank),
                               hl.desc(phe_ht.cohort_rank)))

    tb_rank = (tb_rank.add_index(name='rank').key_by('ega_id'))

    tb_rank = tb_rank.annotate(rank=tb_rank.rank + 1)

    return tb_rank
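
# Sketch of downstream use (hedged): resolve duplicate individuals by keeping the
# member with the better (lower) rank. Here phe_ht, dups_ht, and the s1/s2 fields
# (both EGA IDs) are hypothetical inputs.
rank_ht = make_sample_rank_table(phe_ht)
dups_ht = dups_ht.annotate(
    keep=hl.cond(rank_ht[dups_ht.s1].rank <= rank_ht[dups_ht.s2].rank,
                 dups_ht.s1,
                 dups_ht.s2))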
Example #9
def get_codings():
    """
    Read codings data from Duncan's repo and load it into a Hail Table

    :return: Hail table with codings
    :rtype: Table
    """
    root = f'{tempfile.gettempdir()}/PHESANT'
    if subprocess.check_call(
        ['git', 'clone', 'https://github.com/astheeggeggs/PHESANT.git', root]):
        raise Exception('Could not clone repo')
    hts = []
    coding_dir = f'{root}/WAS/codings'
    for coding_file in os.listdir(f'{coding_dir}'):
        hl.hadoop_copy(f'file://{coding_dir}/{coding_file}',
                       f'{coding_dir}/{coding_file}')
        ht = hl.import_table(f'{coding_dir}/{coding_file}')
        if 'node_id' not in ht.row:
            ht = ht.annotate(node_id=hl.null(hl.tstr),
                             parent_id=hl.null(hl.tstr),
                             selectable=hl.null(hl.tstr))
        ht = ht.annotate(
            coding_id=hl.int(coding_file.split('.')[0].replace('coding', '')))
        hts.append(ht)
    full_ht = hts[0].union(*hts[1:]).key_by('coding_id', 'coding')
    return full_ht.repartition(10)
Example #10
def get_all_codings():
    """
    Download all coding data files from the UKB website
    """
    import requests
    coding_prefix = '/tmp/coding'
    all_codings = requests.post(url='http://biobank.ndph.ox.ac.uk/showcase/scdown.cgi', data={'fmt': 'txt', 'id': 2})
    all_codings = all_codings.text.strip().split('\n')[1:]
    hts = []
    for coding_list in all_codings:
        coding = coding_list.split('\t')[0]
        r = requests.post(url='http://biobank.ndph.ox.ac.uk/showcase/codown.cgi', data={'id': coding})
        req_data = r.text
        if r.status_code != 200 or not req_data or req_data.startswith('<!DOCTYPE HTML>'):
            print(f'Issue with {coding}: {r.text}')
            continue
        with open(f'{coding_prefix}{coding}.tsv', 'w') as f:
            f.write(req_data)
        hl.hadoop_copy(f'file://{coding_prefix}{coding}.tsv', f'{coding_prefix}{coding}.tsv')
        ht = hl.import_table(f'{coding_prefix}{coding}.tsv')
        if 'node_id' not in ht.row:
            ht = ht.annotate(node_id=hl.null(hl.tstr), parent_id=hl.null(hl.tstr), selectable=hl.null(hl.tstr))
        ht = ht.annotate(coding_id=hl.int(coding))
        hts.append(ht)
    full_ht = hts[0].union(*hts[1:]).key_by('coding_id', 'coding')
    return full_ht.repartition(10)
Example #11
def maf_filter(mt, maf, filter_ac0_after_pruning=False):
    """
    Takes a matrix table, removes variants whose alternate allele frequency is at or
    below the given MAF threshold, and returns the filtered matrix table.

    :param mt: matrix table to filter (should be LD pruned and have the X chromosome removed).
    :param maf: minor allele frequency threshold.
    :param filter_ac0_after_pruning: also remove monomorphic variants, i.e. variants with sum(AC) == 0.
    :return: MAF-filtered matrix table.
    """

    # Run hl.variant_qc() to get AFs
    mt = hl.variant_qc(mt)

    # Filter MAF
    logging.info(f'Filtering out variants with minor allele frequency < {maf}')
    mt = mt.filter_rows(mt.row.variant_qc.AF[1] > maf, keep=True)
    mt = mt.annotate_globals(maf_threshold_LDpruning=maf)

    if filter_ac0_after_pruning:
        logging.info(
            'Removing variants with alt allele count = 0 (monomorphic variants).'
        )
        mt = hl.variant_qc(mt)
        mt = mt.filter_rows(hl.sum(mt.row.variant_qc.AC) == hl.int(0),
                            keep=False)
        count = mt.count()
        logging.info(
            f"MT count after removing monomorphic variants and MAF filtering: {count}"
        )
    else:
        logging.info("MAF pruned mt count:" + str(mt.count()))

    return mt
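
# Minimal usage sketch (the input path is a placeholder); the matrix table is
# expected to be LD pruned already, with chromosome X removed.
mt = hl.read_matrix_table("gs://my-bucket/ld_pruned.mt")
mt = maf_filter(mt, maf=0.05, filter_ac0_after_pruning=True)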
Example #12
File: misc.py  Project: tpoterba/hail
def rename_duplicates(dataset, name='unique_id') -> MatrixTable:
    """Rename duplicate column keys.

    .. include:: ../_templates/req_tstring.rst

    Examples
    --------

    >>> renamed = hl.rename_duplicates(dataset).cols()
    >>> duplicate_samples = (renamed.filter(renamed.s != renamed.unique_id)
    ...                             .select()
    ...                             .collect())

    Notes
    -----

    This method produces a new column field from the string column key by
    appending a unique suffix ``_N`` as necessary. For example, if the column
    key "NA12878" appears three times in the dataset, the first will produce
    "NA12878", the second will produce "NA12878_1", and the third will produce
    "NA12878_2". The name of this new field is parameterized by `name`.

    Parameters
    ----------
    dataset : :class:`.MatrixTable`
        Dataset.
    name : :obj:`str`
        Name of new field.

    Returns
    -------
    :class:`.MatrixTable`
    """

    require_col_key_str(dataset, 'rename_duplicates')
    ids = dataset.col_key[0].collect()
    uniques = set()
    mapping = []
    new_ids = []

    fmt = lambda s, i: '{}_{}'.format(s, i)
    for s in ids:
        s_ = s
        i = 0
        while s_ in uniques:
            i += 1
            s_ = fmt(s, i)

        if s_ != s:
            mapping.append((s, s_))
        uniques.add(s_)
        new_ids.append(s_)

    if mapping:
        info(f'Renamed {len(mapping)} duplicate {plural("sample ID", len(mapping))}. Mangled IDs as follows:' +
             ''.join(f'\n  "{pre}" => "{post}"' for pre, post in mapping))
    else:
        info('No duplicate sample IDs found.')
    uid = Env.get_uid()
    return dataset.annotate_cols(**{name: hl.literal(new_ids)[hl.int(hl.scan.count())]})
Example #13
def main(args):
    full_vcf = hl.read_matrix_table(args.allreads_prefix + '.mt')

    # liftover chains
    rg37 = hl.get_reference('GRCh37')
    rg38 = hl.get_reference('GRCh38')
    rg37.add_liftover(
        'gs://hail-common/references/grch37_to_grch38.over.chain.gz', rg38)

    chips = hl.hadoop_open(args.chip_loci)
    chip_dict = {}
    for chip in chips:
        chip = chip.strip().split()
        chip_pos = hl.import_table(chip[1],
                                   filter=r'\[Controls\]',
                                   skip_blank_lines=True)
        chip_pos = chip_pos.filter(
            hl.array(list(map(str, range(1, 23))) + ['X', 'Y']).contains(
                chip_pos.chr))
        chip_pos = chip_pos.key_by(
            locus=hl.locus(chip_pos.chr, hl.int(chip_pos.pos)))

        #  liftover chip position info
        chip_pos = chip_pos.annotate(
            new_locus=hl.liftover(chip_pos.locus, 'GRCh38'))
        chip_pos = chip_pos.filter(hl.is_defined(chip_pos.new_locus))
        chip_pos = chip_pos.key_by(locus=chip_pos.new_locus)

        # filter full vcf to sites in genotype data
        geno_vcf = full_vcf.filter_rows(hl.is_defined(
            chip_pos[full_vcf.locus]))
        hl.export_vcf(
            geno_vcf,
            'gs://neurogap/high_coverage/NeuroGap_30x_' + chip[0] + '.vcf.bgz')
Example #14
def specific_clumps(filename):
    clump = hl.import_table(filename, delimiter=r'\s+', min_partitions=10, types={'P': hl.tfloat})
    clump_dict = clump.aggregate(hl.dict(hl.agg.collect(
        (hl.locus(hl.str(clump.CHR), hl.int(clump.BP)),
        True)
    )), _localize=False)
    return clump_dict
Example #15
def run_platform_pca(
    callrate_mt: hl.MatrixTable,
    binarization_threshold: Optional[float] = 0.25
) -> Tuple[List[float], hl.Table, hl.Table]:
    """
    Runs a PCA on a sample/interval MT with each entry containing the call rate.
    When `binarization_threshold` is set, the call rate is transformed into a 0/1 value based on the threshold.
    E.g. with the default threshold of 0.25, all entries with a call rate < 0.25 become 0 and the others become 1.

    :param callrate_mt: Input callrate MT
    :param binarization_threshold: binarization threshold; use None if no thresholding is desired
    :return: eigenvalues, scores_ht, loadings_ht
    """
    logger.info("Running platform PCA")

    if binarization_threshold is not None:
        callrate_mt = callrate_mt.annotate_entries(callrate=hl.int(
            callrate_mt.callrate > binarization_threshold))
    # Center until Hail's PCA does it for you
    callrate_mt = callrate_mt.annotate_rows(
        mean_callrate=hl.agg.mean(callrate_mt.callrate))
    callrate_mt = callrate_mt.annotate_entries(callrate=callrate_mt.callrate -
                                               callrate_mt.mean_callrate)
    eigenvalues, scores, loadings = hl.pca(
        callrate_mt.callrate, compute_loadings=True
    )  # TODO:  Evaluate whether computing loadings is a good / worthy thing
    logger.info("Platform PCA eigenvalues: {}".format(eigenvalues))

    return eigenvalues, scores, loadings
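
# Usage sketch (hedged): assumes a samples-by-intervals matrix table with an
# entry field named `callrate`; paths are placeholders.
callrate_mt = hl.read_matrix_table("gs://my-bucket/interval_callrate.mt")
eigenvalues, scores_ht, loadings_ht = run_platform_pca(callrate_mt)
scores_ht.write("gs://my-bucket/platform_pca_scores.ht", overwrite=True)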
Example #16
def contig_number(contig: hl.expr.StringExpression) -> hl.expr.Int32Expression:
    return hl.bind(
        lambda contig:
        (hl.case().when(contig == "X", 23).when(contig == "Y", 24).when(
            contig == "M", 25).default(hl.int(contig))),
        normalized_contig(contig),
    )
Example #17
 def compute_element(absolute):
     return hl.rbind(
         absolute % n_rows,
         absolute // n_rows,
         lambda row, col: hl.range(hl.int(n_inner)).map(
             lambda inner: multiply(left[row, inner], right[inner, col])
         ).fold(add, zero))
Example #18
def run_pipeline(args):
    hl.init(log='./hail_annotation_pipeline.log')
    '''
	rg = hl.get_reference('GRCh37')
	grch37_contigs = [x for x in rg.contigs if not x.startswith('GL') and not x.startswith('M')]
	contig_dict = dict(zip(grch37_contigs, ['chr'+x for x in grch37_contigs]))
	'''

    exome_intervals = hl.import_locus_intervals(
        '/gpfs/ycga/project/lek/shared/resources/hg38/exome_evaluation_regions.v1.interval_list',
        reference_genome='GRCh38')

    #mt = hl.import_vcf(args.vcf,reference_genome='GRCh38',contig_recoding=contig_dict,array_elements_required=False,force_bgz=True,filter='MONOALLELIC')
    mt = hl.import_vcf(args.vcf,
                       reference_genome='GRCh38',
                       array_elements_required=False,
                       force_bgz=True,
                       filter='MONOALLELIC')

    mt = mt.filter_rows(hl.is_defined(exome_intervals[mt.locus]))

    mt.describe()
    mt.show()

    mt = mt.repartition(hl.eval(hl.int(mt.n_partitions() / 10)))

    mt.write(args.out, overwrite=True)
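
# Hedged invocation sketch: supply the args attributes the pipeline expects
# (`vcf` and `out`) via argparse.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--vcf', required=True)
parser.add_argument('--out', required=True)
run_pipeline(parser.parse_args())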
def import_key(ss_filename, ss_keys, clump_name):
    keys = ss_keys.split(',')
    ss = hl.import_table(ss_filename,
                         impute=True,
                         delimiter=r'\s+',
                         types={
                             keys[1]: hl.tfloat,
                             keys[0]: hl.tstr
                         },
                         min_partitions=100)
    clump = hl.import_table(clump_name,
                            delimiter='\s+',
                            min_partitions=10,
                            types={
                                'P': hl.tfloat,
                                'CHR': hl.tstr,
                                'BP': hl.tint
                            })
    clump = clump.key_by(locus=hl.locus(clump.CHR, clump.BP))
    clump = clump.filter(clump.P < 5e-8)
    ss = ss.annotate(**{keys[1]: hl.int(ss[keys[1]])})
    chroms = set(map(str, range(1, 23)))
    ss = ss.filter(hl.literal(chroms).contains(ss[keys[0]]))
    ss = ss.annotate(locus=hl.locus(hl.str(ss[keys[0]]), ss[keys[1]]),
                     alleles=[ss[keys[2]], ss[keys[3]]])
    ss = ss.key_by(ss.locus)
    ss = ss.annotate(clump=hl.is_defined(clump[ss.key]))
    ss = ss.key_by(ss.locus, ss.alleles)
    p = keys[-1]
    return ss, p
Example #20
def make_sumstats_bm(sumstats_bm_path, high_quality):
    meta_mt = hl.read_matrix_table(get_meta_analysis_results_path())
    clump_mt = hl.read_matrix_table(
        get_clumping_results_path(high_quality_only=high_quality)).rename(
            {'pop': 'clump_pops'})
    mt = all_axis_join(meta_mt, clump_mt)
    mt = separate_results_mt_by_pop(mt,
                                    'clump_pops',
                                    'plink_clump',
                                    skip_drop=True)
    mt = separate_results_mt_by_pop(mt,
                                    'meta_analysis_data',
                                    'meta_analysis',
                                    skip_drop=True)
    mt = mt.filter_cols(mt.meta_analysis_data.pop == mt.clump_pops)
    mt = explode_by_p_threshold(mt).unfilter_entries()

    mt = mt.filter_cols((mt.description == 'Type 2 diabetes')
                        & (mt.p_threshold == 1))

    BlockMatrix.write_from_entry_expr(hl.or_else(
        mt.meta_analysis.BETA * hl.is_defined(mt.plink_clump.TOTAL) *
        hl.int(mt.meta_analysis.Pvalue < mt.p_threshold), 0.0),
                                      sumstats_bm_path,
                                      overwrite=True)
Example #22
def split_position_end(position):
    return hl.or_missing(
        hl.is_defined(position),
        hl.bind(
            lambda start: hl.cond(start == "?", hl.null(hl.tint), hl.int(start)
                                  ),
            position.split("-")[-1]),
    )
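
# Illustrative behaviour on range-style position strings (values are made up):
hl.eval(split_position_end(hl.str("100-200")))  # 200
hl.eval(split_position_end(hl.str("100-?")))    # missing (None)
hl.eval(split_position_end(hl.str("150")))      # 150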
Example #23
File: array.py  Project: tuyanglin/hail
 def block_product(left, right):
     product = left @ right
     n_rows, n_cols = product.shape
     return hl.struct(
         shape=product.shape,
         block=hl.range(hl.int(
             n_rows * n_cols)).map(lambda absolute: product[
                 absolute % n_rows, absolute // n_rows]))
Example #24
 def get_lgt(e, n_alleles, has_non_ref, row):
     index = e.GT.unphased_diploid_gt_index()
     n_no_nonref = n_alleles - hl.int(has_non_ref)
     triangle_without_nonref = hl.triangle(n_no_nonref)
     return (hl.case().when(index < triangle_without_nonref, e.GT).when(
         index < hl.triangle(n_alleles),
         hl.null('call')).or_error('invalid GT ' + hl.str(e.GT) +
                                   ' at site ' + hl.str(row.locus)))
Example #25
def get_expr_for_xpos(contig, position):
    return hl.bind(
        lambda contig_number: hl.int64(contig_number) * 1_000_000_000 +
        position,
        hl.case().when(contig == "X",
                       23).when(contig == "Y",
                                24).when(contig[0] == "M",
                                         25).default(hl.int(contig)),
    )
def xpos(chrom, position):
    contig_number = (
        hl.case()
        .when(chrom == "X", 23)
        .when(chrom == "Y", 24)
        .when(chrom[0] == "M", 25)
        .default(hl.int(chrom))
    )
    return hl.int64(contig_number) * 1_000_000_000 + position
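
# Worked example of the encoding (contig number * 1e9 + position); values are illustrative:
hl.eval(xpos(hl.str("2"), 12345))  # 2000012345
hl.eval(xpos(hl.str("X"), 123))    # 23000000123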
Example #27
def get_expr_for_contig_number(
        locus: hl.expr.LocusExpression) -> hl.expr.Int32Expression:
    """Convert contig name to contig number"""
    return hl.bind(
        lambda contig:
        (hl.case().when(contig == "X", 23).when(contig == "Y", 24).when(
            contig[0] == "M", 25).default(hl.int(contig))),
        get_expr_for_contig(locus),
    )
Example #28
def download_data():
    global _data_dir, _mt
    _data_dir = os.environ.get('HAIL_BENCHMARK_DIR',
                               '/tmp/hail_benchmark_data')
    print(f'using benchmark data directory {_data_dir}')
    os.makedirs(_data_dir, exist_ok=True)

    files = map(lambda f: os.path.join(_data_dir, f), [
        'profile.vcf.bgz', 'profile.mt', 'table_10M_par_1000.ht',
        'table_10M_par_100.ht', 'table_10M_par_10.ht',
        'gnomad_dp_simulation.mt', 'many_strings_table.ht'
    ])
    if not all(os.path.exists(file) for file in files):
        hl.init()  # use all cores

        vcf = os.path.join(_data_dir, 'profile.vcf.bgz')
        print('files not found - downloading...', end='', flush=True)
        urlretrieve(
            'https://storage.googleapis.com/hail-common/benchmark/profile.vcf.bgz',
            vcf)
        print('done', flush=True)
        print('importing...', end='', flush=True)
        hl.import_vcf(vcf, min_partitions=16).write(os.path.join(
            _data_dir, 'profile.mt'),
                                                    overwrite=True)

        ht = hl.utils.range_table(
            10_000_000,
            1000).annotate(**{f'f_{i}': hl.rand_unif(0, 1)
                              for i in range(5)})
        ht = ht.checkpoint(os.path.join(_data_dir, 'table_10M_par_1000.ht'),
                           overwrite=True)
        ht = ht.naive_coalesce(100).checkpoint(os.path.join(
            _data_dir, 'table_10M_par_100.ht'),
                                               overwrite=True)
        ht.naive_coalesce(10).write(os.path.join(_data_dir,
                                                 'table_10M_par_10.ht'),
                                    overwrite=True)

        mt = hl.utils.range_matrix_table(n_rows=250_000,
                                         n_cols=1_000,
                                         n_partitions=32)
        mt = mt.annotate_entries(x=hl.int(hl.rand_unif(0, 4.5)**3))
        mt.write(os.path.join(_data_dir, 'gnomad_dp_simulation.mt'),
                 overwrite=True)

        print('downloading many strings table...')
        mst_tsv = os.path.join(_data_dir, 'many_strings_table.tsv.bgz')
        mst_ht = os.path.join(_data_dir, 'many_strings_table.ht')
        urlretrieve(
            'https://storage.googleapis.com/hail-common/benchmark/many_strings_table.tsv.bgz',
            mst_tsv)
        print('importing...')
        hl.import_table(mst_tsv).write(mst_ht, overwrite=True)
        hl.stop()
    else:
        print('all files found.', flush=True)
Example #29
 def _create(self, resource_dir):
     logging.info('creating gnomad_dp_simulation matrix table...')
     mt = hl.utils.range_matrix_table(n_rows=250_000,
                                      n_cols=1_000,
                                      n_partitions=32)
     mt = mt.annotate_entries(x=hl.int(hl.rand_unif(0, 4.5)**3))
     mt.write(os.path.join(resource_dir, 'gnomad_dp_simulation.mt'),
              overwrite=True)
     logging.info('done creating gnomad_dp_simulation matrix table.')
def geno_stats(ht_dict, geno_prefix):
    """
    Computes the mean number of hom-ref, het, and hom-var calls per coverage group.
    :param ht_dict: dict of Hail Tables carrying `sample_qc` annotations and a `cov` field
    :param geno_prefix: output path prefix for the exported stats TSV
    :return: None
    """
    ht_samples = list(ht_dict.values())
    ht_samples = [
        ht.annotate(sample_qc=ht.sample_qc.drop('gq_stats', 'dp_stats'))
        if 'gq_stats' in list(ht.sample_qc) else ht for ht in ht_samples
    ]
    ht_joined = ht_samples[0].union(*ht_samples[1:], unify=True)
    geno_stats = ht_joined.group_by(ht_joined.cov).aggregate(
        n_hom_ref_stats=hl.int(hl.agg.mean(ht_joined.sample_qc.n_hom_ref)),
        n_het_stats=hl.int(hl.agg.mean(ht_joined.sample_qc.n_het)),
        n_hom_var_stats=hl.int(hl.agg.mean(ht_joined.sample_qc.n_hom_var)))

    geno_stats.show(40)
    geno_stats.export(geno_prefix + 'geno_stats.tsv')
Example #31
def test_blanczos_against_hail():
    k = 10

    def concatToNumpy(field, horizontal=True):
        blocks = field.collect()
        if horizontal:
            return np.concatenate(blocks, axis=0)
        else:
            return np.concatenate(blocks, axis=1)

    hl.utils.get_1kg('data/')
    hl.import_vcf('data/1kg.vcf.bgz').write('data/1kg.mt', overwrite=True)
    dataset = hl.read_matrix_table('data/1kg.mt')

    b_eigens, b_scores, b_loadings = hl._blanczos_pca(hl.int(
        hl.is_defined(dataset.GT)),
                                                      k=k,
                                                      q_iterations=3,
                                                      compute_loadings=True)
    b_scores = concatToNumpy(b_scores.scores)
    b_loadings = concatToNumpy(b_loadings.loadings)
    b_scores = np.reshape(b_scores, (len(b_scores) // k, k))
    b_loadings = np.reshape(b_loadings, (len(b_loadings) // k, k))

    h_eigens, h_scores, h_loadings = hl.pca(hl.int(hl.is_defined(dataset.GT)),
                                            k=k,
                                            compute_loadings=True)
    h_scores = np.reshape(concatToNumpy(h_scores.scores), b_scores.shape)
    h_loadings = np.reshape(concatToNumpy(h_loadings.loadings),
                            b_loadings.shape)

    # equation 12 from https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4827102/pdf/main.pdf
    def bound(vs, us):
        return 1 / k * sum([np.linalg.norm(us.T @ vs[:, i]) for i in range(k)])

    MEV = bound(h_loadings, b_loadings)

    np.testing.assert_allclose(b_eigens, h_eigens, rtol=0.05)
    assert MEV > 0.9
Example #32
    def phase_parent_call(call: hl.expr.CallExpression, transmitted_allele_index: int):
        """
        Given a genotype and which allele was transmitted to the offspring, returns the parent phased genotype.

        :param CallExpression call: Parent genotype
        :param int transmitted_allele_index: index of transmitted allele (0 or 1)
        :return: Phased parent genotype
        :rtype: CallExpression
        """
        return hl.call(
            call[transmitted_allele_index],
            call[hl.int(transmitted_allele_index == 0)],
            phased=True
        )
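
# Worked example of the phasing logic (hedged; assumes the helper is in scope):
# a 0/1 parent that transmitted allele index 0 is phased as 0|1 (transmitted
# allele first), and as 1|0 if it transmitted allele index 1.
parent = hl.call(0, 1)
hl.eval(phase_parent_call(parent, 0))  # 0|1
hl.eval(phase_parent_call(parent, 1))  # 1|0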
Example #33
def ld_score_regression(weight_expr,
                        ld_score_expr,
                        chi_sq_exprs,
                        n_samples_exprs,
                        n_blocks=200,
                        two_step_threshold=30,
                        n_reference_panel_variants=None) -> Table:
    r"""Estimate SNP-heritability and level of confounding biases from
    GWAS summary statistics.

    Given a set or multiple sets of genome-wide association study (GWAS)
    summary statistics, :func:`.ld_score_regression` estimates the heritability
    of a trait or set of traits and the level of confounding biases present in
    the underlying studies by regressing chi-squared statistics on LD scores,
    leveraging the model:

    .. math::

        \mathrm{E}[\chi_j^2] = 1 + Na + \frac{Nh_g^2}{M}l_j

    *  :math:`\mathrm{E}[\chi_j^2]` is the expected chi-squared statistic
       for variant :math:`j` resulting from a test of association between
       variant :math:`j` and a trait.
    *  :math:`l_j = \sum_{k} r_{jk}^2` is the LD score of variant
       :math:`j`, calculated as the sum of squared correlation coefficients
       between variant :math:`j` and nearby variants. See :func:`ld_score`
       for further details.
    *  :math:`a` captures the contribution of confounding biases, such as
       cryptic relatedness and uncontrolled population structure, to the
       association test statistic.
    *  :math:`h_g^2` is the SNP-heritability, or the proportion of variation
       in the trait explained by the effects of variants included in the
       regression model above.
    *  :math:`M` is the number of variants used to estimate :math:`h_g^2`.
    *  :math:`N` is the number of samples in the underlying association study.

    For more details on the method implemented in this function, see:

    * `LD Score regression distinguishes confounding from polygenicity in genome-wide association studies (Bulik-Sullivan et al, 2015) <https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4495769/>`__

    Examples
    --------

    Run the method on a matrix table of summary statistics, where the rows
    are variants and the columns are different phenotypes:

    >>> mt_gwas = hl.read_matrix_table('data/ld_score_regression.sumstats.mt')
    >>> ht_results = hl.experimental.ld_score_regression(
    ...     weight_expr=mt_gwas['ld_score'],
    ...     ld_score_expr=mt_gwas['ld_score'],
    ...     chi_sq_exprs=mt_gwas['chi_squared'],
    ...     n_samples_exprs=mt_gwas['n'])


    Run the method on a table with summary statistics for a single
    phenotype:

    >>> ht_gwas = hl.read_table('data/ld_score_regression.sumstats.ht')
    >>> ht_results = hl.experimental.ld_score_regression(
    ...     weight_expr=ht_gwas['ld_score'],
    ...     ld_score_expr=ht_gwas['ld_score'],
    ...     chi_sq_exprs=ht_gwas['chi_squared_50_irnt'],
    ...     n_samples_exprs=ht_gwas['n_50_irnt'])

    Run the method on a table with summary statistics for multiple
    phenotypes:

    >>> ht_gwas = hl.read_table('data/ld_score_regression.sumstats.ht')
    >>> ht_results = hl.experimental.ld_score_regression(
    ...     weight_expr=ht_gwas['ld_score'],
    ...     ld_score_expr=ht_gwas['ld_score'],
    ...     chi_sq_exprs=[ht_gwas['chi_squared_50_irnt'],
    ...                        ht_gwas['chi_squared_20160']],
    ...     n_samples_exprs=[ht_gwas['n_50_irnt'],
    ...                      ht_gwas['n_20160']])

    Notes
    -----
    The ``exprs`` provided as arguments to :func:`.ld_score_regression`
    must all be from the same object, either a :class:`Table` or a
    :class:`MatrixTable`.

    **If the arguments originate from a table:**

    *  The table must be keyed by fields ``locus`` of type
       :class:`.tlocus` and ``alleles``, a :py:data:`.tarray` of
       :py:data:`.tstr` elements.
    *  ``weight_expr``, ``ld_score_expr``, ``chi_sq_exprs``, and
       ``n_samples_exprs`` must be row-indexed fields.
    *  The number of expressions passed to ``n_samples_exprs`` must be
       equal to one or the number of expressions passed to
       ``chi_sq_exprs``. If just one expression is passed to
       ``n_samples_exprs``, that sample size expression is assumed to
       apply to all sets of statistics passed to ``chi_sq_exprs``.
       Otherwise, the expressions passed to ``chi_sq_exprs`` and
       ``n_samples_exprs`` are matched by index.
    *  The ``phenotype`` field that keys the table returned by
       :func:`.ld_score_regression` will have generic :obj:`int` values
       ``0``, ``1``, etc. corresponding to the ``0th``, ``1st``, etc.
       expressions passed to the ``chi_sq_exprs`` argument.

    **If the arguments originate from a matrix table:**

    *  The dimensions of the matrix table must be variants
       (rows) by phenotypes (columns).
    *  The rows of the matrix table must be keyed by fields
       ``locus`` of type :class:`.tlocus` and ``alleles``,
       a :py:data:`.tarray` of :py:data:`.tstr` elements.
    *  The columns of the matrix table must be keyed by a field
       of type :py:data:`.tstr` that uniquely identifies phenotypes
       represented in the matrix table. The column key must be a single
       expression; compound keys are not accepted.
    *  ``weight_expr`` and ``ld_score_expr`` must be row-indexed
       fields.
    *  ``chi_sq_exprs`` must be a single entry-indexed field
       (not a list of fields).
    *  ``n_samples_exprs`` must be a single entry-indexed field
       (not a list of fields).
    *  The ``phenotype`` field that keys the table returned by
       :func:`.ld_score_regression` will have values corresponding to the
       column keys of the input matrix table.

    This function returns a :class:`Table` with one row per set of summary
    statistics passed to the ``chi_sq_exprs`` argument. The following
    row-indexed fields are included in the table:

    *  **phenotype** (:py:data:`.tstr`) -- The name of the phenotype. The
       returned table is keyed by this field. See the notes below for
       details on the possible values of this field.
    *  **mean_chi_sq** (:py:data:`.tfloat64`) -- The mean chi-squared
       test statistic for the given phenotype.
    *  **intercept** (`Struct`) -- Contains fields:

       -  **estimate** (:py:data:`.tfloat64`) -- A point estimate of the
          intercept :math:`1 + Na`.
       -  **standard_error**  (:py:data:`.tfloat64`) -- An estimate of
          the standard error of this point estimate.

    *  **snp_heritability** (`Struct`) -- Contains fields:

       -  **estimate** (:py:data:`.tfloat64`) -- A point estimate of the
          SNP-heritability :math:`h_g^2`.
       -  **standard_error** (:py:data:`.tfloat64`) -- An estimate of
          the standard error of this point estimate.

    Warning
    -------
    :func:`.ld_score_regression` considers only the rows for which both row
    fields ``weight_expr`` and ``ld_score_expr`` are defined. Rows with missing
    values in either field are removed prior to fitting the LD score
    regression model.

    Parameters
    ----------
    weight_expr : :class:`.Float64Expression`
                  Row-indexed expression for the LD scores used to derive
                  variant weights in the model.
    ld_score_expr : :class:`.Float64Expression`
                    Row-indexed expression for the LD scores used as covariates
                    in the model.
    chi_sq_exprs : :class:`.Float64Expression` or :obj:`list` of
                        :class:`.Float64Expression`
                        One or more row-indexed (if table) or entry-indexed
                        (if matrix table) expressions for chi-squared
                        statistics resulting from genome-wide association
                        studies.
    n_samples_exprs: :class:`.NumericExpression` or :obj:`list` of
                     :class:`.NumericExpression`
                     One or more row-indexed (if table) or entry-indexed
                     (if matrix table) expressions indicating the number of
                     samples used in the studies that generated the test
                     statistics supplied to ``chi_sq_exprs``.
    n_blocks : :obj:`int`
               The number of blocks used in the jackknife approach to
               estimating standard errors.
    two_step_threshold : :obj:`int`
                         Variants with chi-squared statistics greater than this
                         value are excluded in the first step of the two-step
                         procedure used to fit the model.
    n_reference_panel_variants : :obj:`int`, optional
                                 Number of variants used to estimate the
                                 SNP-heritability :math:`h_g^2`.

    Returns
    -------
    :class:`.Table`
        Table keyed by ``phenotype`` with intercept and heritability estimates
        for each phenotype passed to the function."""

    chi_sq_exprs = wrap_to_list(chi_sq_exprs)
    n_samples_exprs = wrap_to_list(n_samples_exprs)

    assert ((len(chi_sq_exprs) == len(n_samples_exprs)) or
            (len(n_samples_exprs) == 1))
    __k = 2  # number of covariates, including intercept

    ds = chi_sq_exprs[0]._indices.source

    analyze('ld_score_regression/weight_expr',
            weight_expr,
            ds._row_indices)
    analyze('ld_score_regression/ld_score_expr',
            ld_score_expr,
            ds._row_indices)

    # format input dataset
    if isinstance(ds, MatrixTable):
        if len(chi_sq_exprs) != 1:
            raise ValueError("""Only one chi_sq_expr allowed if originating
                from a matrix table.""")
        if len(n_samples_exprs) != 1:
            raise ValueError("""Only one n_samples_expr allowed if
                originating from a matrix table.""")

        col_key = list(ds.col_key)
        if len(col_key) != 1:
            raise ValueError("""Matrix table must be keyed by a single
                phenotype field.""")

        analyze('ld_score_regression/chi_squared_expr',
                chi_sq_exprs[0],
                ds._entry_indices)
        analyze('ld_score_regression/n_samples_expr',
                n_samples_exprs[0],
                ds._entry_indices)

        ds = ds._select_all(row_exprs={'__locus': ds.locus,
                                       '__alleles': ds.alleles,
                                       '__w_initial': weight_expr,
                                       '__w_initial_floor': hl.max(weight_expr,
                                                                   1.0),
                                       '__x': ld_score_expr,
                                       '__x_floor': hl.max(ld_score_expr,
                                                           1.0)},
                            row_key=['__locus', '__alleles'],
                            col_exprs={'__y_name': ds[col_key[0]]},
                            col_key=['__y_name'],
                            entry_exprs={'__y': chi_sq_exprs[0],
                                         '__n': n_samples_exprs[0]})
        ds = ds.annotate_entries(**{'__w': ds.__w_initial})

        ds = ds.filter_rows(hl.is_defined(ds.__locus) &
                            hl.is_defined(ds.__alleles) &
                            hl.is_defined(ds.__w_initial) &
                            hl.is_defined(ds.__x))

    else:
        assert isinstance(ds, Table)
        for y in chi_sq_exprs:
            analyze('ld_score_regression/chi_squared_expr', y, ds._row_indices)
        for n in n_samples_exprs:
            analyze('ld_score_regression/n_samples_expr', n, ds._row_indices)

        ys = ['__y{:}'.format(i) for i, _ in enumerate(chi_sq_exprs)]
        ws = ['__w{:}'.format(i) for i, _ in enumerate(chi_sq_exprs)]
        ns = ['__n{:}'.format(i) for i, _ in enumerate(n_samples_exprs)]

        ds = ds.select(**dict(**{'__locus': ds.locus,
                                 '__alleles': ds.alleles,
                                 '__w_initial': weight_expr,
                                 '__x': ld_score_expr},
                              **{y: chi_sq_exprs[i]
                                 for i, y in enumerate(ys)},
                              **{w: weight_expr for w in ws},
                              **{n: n_samples_exprs[i]
                                 for i, n in enumerate(ns)}))
        ds = ds.key_by(ds.__locus, ds.__alleles)

        table_tmp_file = new_temp_file()
        ds.write(table_tmp_file)
        ds = hl.read_table(table_tmp_file)

        hts = [ds.select(**{'__w_initial': ds.__w_initial,
                            '__w_initial_floor': hl.max(ds.__w_initial,
                                                        1.0),
                            '__x': ds.__x,
                            '__x_floor': hl.max(ds.__x, 1.0),
                            '__y_name': i,
                            '__y': ds[ys[i]],
                            '__w': ds[ws[i]],
                            '__n': hl.int(ds[ns[i]])})
               for i, y in enumerate(ys)]

        mts = [ht.to_matrix_table(row_key=['__locus',
                                           '__alleles'],
                                  col_key=['__y_name'],
                                  row_fields=['__w_initial',
                                              '__w_initial_floor',
                                              '__x',
                                              '__x_floor'])
               for ht in hts]

        ds = mts[0]
        for i in range(1, len(ys)):
            ds = ds.union_cols(mts[i])

        ds = ds.filter_rows(hl.is_defined(ds.__locus) &
                            hl.is_defined(ds.__alleles) &
                            hl.is_defined(ds.__w_initial) &
                            hl.is_defined(ds.__x))

    mt_tmp_file1 = new_temp_file()
    ds.write(mt_tmp_file1)
    mt = hl.read_matrix_table(mt_tmp_file1)

    if not n_reference_panel_variants:
        M = mt.count_rows()
    else:
        M = n_reference_panel_variants

    # block variants for each phenotype
    n_phenotypes = mt.count_cols()

    mt = mt.annotate_entries(__in_step1=(hl.is_defined(mt.__y) &
                                         (mt.__y < two_step_threshold)),
                             __in_step2=hl.is_defined(mt.__y))

    mt = mt.annotate_cols(__col_idx=hl.int(hl.scan.count()),
                          __m_step1=hl.agg.count_where(mt.__in_step1),
                          __m_step2=hl.agg.count_where(mt.__in_step2))

    col_keys = list(mt.col_key)

    ht = mt.localize_entries(entries_array_field_name='__entries',
                             columns_array_field_name='__cols')

    ht = ht.annotate(__entries=hl.rbind(
        hl.scan.array_agg(
            lambda entry: hl.scan.count_where(entry.__in_step1),
            ht.__entries),
        lambda step1_indices: hl.map(
            lambda i: hl.rbind(
                hl.int(hl.or_else(step1_indices[i], 0)),
                ht.__cols[i].__m_step1,
                ht.__entries[i],
                lambda step1_idx, m_step1, entry: hl.rbind(
                    hl.map(
                        lambda j: hl.int(hl.floor(j * (m_step1 / n_blocks))),
                        hl.range(0, n_blocks + 1)),
                    lambda step1_separators: hl.rbind(
                        hl.set(step1_separators).contains(step1_idx),
                        hl.sum(
                            hl.map(
                                lambda s1: step1_idx >= s1,
                                step1_separators)) - 1,
                        lambda is_separator, step1_block: entry.annotate(
                            __step1_block=step1_block,
                            __step2_block=hl.cond(~entry.__in_step1 & is_separator,
                                                  step1_block - 1,
                                                  step1_block))))),
            hl.range(0, hl.len(ht.__entries)))))

    mt = ht._unlocalize_entries('__entries', '__cols', col_keys)

    mt_tmp_file2 = new_temp_file()
    mt.write(mt_tmp_file2)
    mt = hl.read_matrix_table(mt_tmp_file2)
    
    # initial coefficient estimates
    mt = mt.annotate_cols(__initial_betas=[
        1.0, (hl.agg.mean(mt.__y) - 1.0) / hl.agg.mean(mt.__x)])
    mt = mt.annotate_cols(__step1_betas=mt.__initial_betas,
                          __step2_betas=mt.__initial_betas)

    # step 1 iteratively reweighted least squares
    for i in range(3):
        mt = mt.annotate_entries(__w=hl.cond(
            mt.__in_step1,
            1.0/(mt.__w_initial_floor * 2.0 * (mt.__step1_betas[0] +
                                               mt.__step1_betas[1] *
                                               mt.__x_floor)**2),
            0.0))
        mt = mt.annotate_cols(__step1_betas=hl.agg.filter(
            mt.__in_step1,
            hl.agg.linreg(y=mt.__y,
                          x=[1.0, mt.__x],
                          weight=mt.__w).beta))
        mt = mt.annotate_cols(__step1_h2=hl.max(hl.min(
            mt.__step1_betas[1] * M / hl.agg.mean(mt.__n), 1.0), 0.0))
        mt = mt.annotate_cols(__step1_betas=[
            mt.__step1_betas[0],
            mt.__step1_h2 * hl.agg.mean(mt.__n) / M])

    # step 1 block jackknife
    mt = mt.annotate_cols(__step1_block_betas=[
        hl.agg.filter((mt.__step1_block != i) & mt.__in_step1,
                      hl.agg.linreg(y=mt.__y,
                                    x=[1.0, mt.__x],
                                    weight=mt.__w).beta)
        for i in range(n_blocks)])

    mt = mt.annotate_cols(__step1_block_betas_bias_corrected=hl.map(
        lambda x: n_blocks * mt.__step1_betas - (n_blocks - 1) * x,
        mt.__step1_block_betas))

    mt = mt.annotate_cols(
        __step1_jackknife_mean=hl.map(
            lambda i: hl.mean(
                hl.map(lambda x: x[i],
                       mt.__step1_block_betas_bias_corrected)),
            hl.range(0, __k)),
        __step1_jackknife_variance=hl.map(
            lambda i: (hl.sum(
                hl.map(lambda x: x[i]**2,
                       mt.__step1_block_betas_bias_corrected)) -
                       hl.sum(
                hl.map(lambda x: x[i],
                       mt.__step1_block_betas_bias_corrected))**2 /
                       n_blocks) /
            (n_blocks - 1) / n_blocks,
            hl.range(0, __k)))

    # step 2 iteratively reweighted least squares
    for i in range(3):
        mt = mt.annotate_entries(__w=hl.cond(
            mt.__in_step2,
            1.0/(mt.__w_initial_floor *
                 2.0 * (mt.__step2_betas[0] +
                        mt.__step2_betas[1] *
                        mt.__x_floor)**2),
            0.0))
        mt = mt.annotate_cols(__step2_betas=[
            mt.__step1_betas[0],
            hl.agg.filter(mt.__in_step2,
                          hl.agg.linreg(y=mt.__y - mt.__step1_betas[0],
                                        x=[mt.__x],
                                        weight=mt.__w).beta[0])])
        mt = mt.annotate_cols(__step2_h2=hl.max(hl.min(
            mt.__step2_betas[1] * M/hl.agg.mean(mt.__n), 1.0), 0.0))
        mt = mt.annotate_cols(__step2_betas=[
            mt.__step1_betas[0],
            mt.__step2_h2 * hl.agg.mean(mt.__n)/M])

    # step 2 block jackknife
    mt = mt.annotate_cols(__step2_block_betas=[
        hl.agg.filter((mt.__step2_block != i) & mt.__in_step2,
                      hl.agg.linreg(y=mt.__y - mt.__step1_betas[0],
                                    x=[mt.__x],
                                    weight=mt.__w).beta[0])
        for i in range(n_blocks)])

    mt = mt.annotate_cols(__step2_block_betas_bias_corrected=hl.map(
        lambda x: n_blocks * mt.__step2_betas[1] - (n_blocks - 1) * x,
        mt.__step2_block_betas))

    mt = mt.annotate_cols(
        __step2_jackknife_mean=hl.mean(
            mt.__step2_block_betas_bias_corrected),
        __step2_jackknife_variance=(
            hl.sum(mt.__step2_block_betas_bias_corrected**2) -
            hl.sum(mt.__step2_block_betas_bias_corrected)**2 /
            n_blocks) / (n_blocks - 1) / n_blocks)

    # combine step 1 and step 2 block jackknifes
    mt = mt.annotate_entries(
        __step2_initial_w=1.0/(mt.__w_initial_floor *
                               2.0 * (mt.__initial_betas[0] +
                                      mt.__initial_betas[1] *
                                      mt.__x_floor)**2))

    mt = mt.annotate_cols(
        __final_betas=[
            mt.__step1_betas[0],
            mt.__step2_betas[1]],
        __c=(hl.agg.sum(mt.__step2_initial_w * mt.__x) /
             hl.agg.sum(mt.__step2_initial_w * mt.__x**2)))

    mt = mt.annotate_cols(__final_block_betas=hl.map(
        lambda i: (mt.__step2_block_betas[i] - mt.__c *
                   (mt.__step1_block_betas[i][0] - mt.__final_betas[0])),
        hl.range(0, n_blocks)))

    mt = mt.annotate_cols(
        __final_block_betas_bias_corrected=(n_blocks * mt.__final_betas[1] -
                                            (n_blocks - 1) *
                                            mt.__final_block_betas))

    mt = mt.annotate_cols(
        __final_jackknife_mean=[
            mt.__step1_jackknife_mean[0],
            hl.mean(mt.__final_block_betas_bias_corrected)],
        __final_jackknife_variance=[
            mt.__step1_jackknife_variance[0],
            (hl.sum(mt.__final_block_betas_bias_corrected**2) -
             hl.sum(mt.__final_block_betas_bias_corrected)**2 /
             n_blocks) / (n_blocks - 1) / n_blocks])

    # convert coefficient to heritability estimate
    mt = mt.annotate_cols(
        phenotype=mt.__y_name,
        mean_chi_sq=hl.agg.mean(mt.__y),
        intercept=hl.struct(
            estimate=mt.__final_betas[0],
            standard_error=hl.sqrt(mt.__final_jackknife_variance[0])),
        snp_heritability=hl.struct(
            estimate=(M/hl.agg.mean(mt.__n)) * mt.__final_betas[1],
            standard_error=hl.sqrt((M/hl.agg.mean(mt.__n))**2 *
                                   mt.__final_jackknife_variance[1])))

    # format and return results
    ht = mt.cols()
    ht = ht.key_by(ht.phenotype)
    ht = ht.select(ht.mean_chi_sq,
                   ht.intercept,
                   ht.snp_heritability)

    ht_tmp_file = new_temp_file()
    ht.write(ht_tmp_file)
    ht = hl.read_table(ht_tmp_file)
    
    return ht
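
# Illustrative back-of-the-envelope reading of the fitted slope (numbers are made
# up): with M regression variants and mean sample size N, the SNP-heritability
# reported above is the slope rescaled by M / N.
M, N, slope = 1_000_000, 100_000, 0.05
h2 = slope * M / N  # 0.5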