Example #1
def get_duplicated_samples_ht(duplicated_samples: List[Set[str]],
                              samples_rankings_ht: hl.Table,
                              rank_ann: str = 'rank'):
    """
    Creates a HT of duplicated sample sets.
    Each row is indexed by the sample that is kept and also contains the set of duplicate samples that should be filtered.

    `samples_rankings_ht` is a HT containing a global rank for each of the samples (smaller is better).

    :param duplicated_samples: List of sets of duplicated samples
    :param samples_rankings_ht: HT with global rank for each sample
    :param rank_ann: Annotation in `samples_rankings_ht` containing each sample's global rank (smaller is better).
    :return: HT with duplicate sample sets, including which to keep/filter
    """
    dups_ht = hl.Table.parallelize([
        hl.struct(dup_set=i, dups=duplicated_samples[i])
        for i in range(0, len(duplicated_samples))
    ])
    dups_ht = dups_ht.explode(dups_ht.dups, name='_dup')
    dups_ht = dups_ht.key_by('_dup')
    dups_ht = dups_ht.annotate(rank=samples_rankings_ht[dups_ht.key][rank_ann])
    dups_cols = hl.bind(
        lambda x: hl.struct(kept=x[0], filtered=x[1:]),
        hl.sorted(hl.agg.collect(hl.tuple([dups_ht._dup, dups_ht.rank])),
                  key=lambda x: x[1]).map(lambda x: x[0]))
    dups_ht = dups_ht.group_by(dups_ht.dup_set).aggregate(**dups_cols)

    if isinstance(dups_ht.kept, hl.expr.StructExpression):
        dups_ht = dups_ht.key_by(**dups_ht.kept).drop('kept')
    else:
        dups_ht = dups_ht.key_by(
            s=dups_ht.kept
        )  # Since there is no defined name in the case of a non-struct type, use `s`
    return dups_ht
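# A minimal usage sketch for the function above, on toy data (sample names and
# ranks here are illustrative, not from the original source):
import hail as hl

ranks_ht = hl.Table.parallelize(
    [{'s': 'a', 'rank': 0}, {'s': 'b', 'rank': 1}, {'s': 'c', 'rank': 2}],
    hl.tstruct(s=hl.tstr, rank=hl.tint32),
    key='s')

# One duplicate set: 'a' and 'b' are the same individual; 'a' has the best rank.
dups_ht = get_duplicated_samples_ht([{'a', 'b'}], ranks_ht)
dups_ht.show()  # expected: keyed by kept sample 'a', with filtered == ['b']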
Example #2
    def find_worst_transcript_consequence(
            tcl: hl.expr.ArrayExpression) -> hl.expr.StructExpression:
        """
        Gets the worst transcript consequence from an array of transcript consequences.
        """
        flag_score = 500
        no_flag_score = flag_score * (1 + penalize_flags)

        def csq_score(tc):
            return csq_dict[csqs.find(
                lambda x: x == tc.most_severe_consequence)]

        tcl = tcl.map(
            lambda tc: tc.annotate(
                csq_score=hl.case(missing_false=True)
                .when((tc.lof == 'HC') & (tc.lof_flags == ''),
                      csq_score(tc) - no_flag_score)
                .when((tc.lof == 'HC') & (tc.lof_flags != ''),
                      csq_score(tc) - flag_score)
                .when(tc.lof == 'LC', csq_score(tc) - 10)
                .when(tc.polyphen_prediction == 'probably_damaging',
                      csq_score(tc) - 0.5)
                .when(tc.polyphen_prediction == 'possibly_damaging',
                      csq_score(tc) - 0.25)
                .when(tc.polyphen_prediction == 'benign', csq_score(tc) - 0.1)
                .default(csq_score(tc))))
        return hl.or_missing(
            hl.len(tcl) > 0,
            hl.sorted(tcl, lambda x: x.csq_score)[0])
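# The core pattern above in isolation: annotate each element with a score, then
# take the minimum with hl.sorted(...)[0]. A toy sketch, independent of VEP:
import hail as hl

tcl = hl.array([hl.struct(name='t1', csq_score=3.0),
                hl.struct(name='t2', csq_score=1.0)])
worst = hl.or_missing(hl.len(tcl) > 0, hl.sorted(tcl, lambda x: x.csq_score)[0])
print(hl.eval(worst.name))  # 't2' -- the lowest score wins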
Example #3
def add_popmax_expr(freq: hl.expr.ArrayExpression,
                    freq_meta: hl.expr.ArrayExpression,
                    populations: Set[str]) -> hl.expr.ArrayExpression:
    """
    Calculates popmax (add an additional entry into freq with popmax: pop)

    :param ArrayExpression freq: ArrayExpression of Structs with fields ['AC', 'AF', 'AN', 'homozygote_count']
    :param ArrayExpression freq_meta: ArrayExpression of meta dictionaries corresponding to freq
    :param set of str populations: Set of populations over which to calculate popmax
    :return: Frequency data with annotated popmax
    :rtype: ArrayExpression
    """
    pops_to_use = hl.literal(populations)
    freq = hl.map(lambda x: x[0].annotate(meta=x[1]), hl.zip(freq, freq_meta))
    freq_filtered = hl.filter(
        lambda f: (f.meta.size() == 2) & (f.meta.get('group') == 'adj') &
        pops_to_use.contains(f.meta.get('pop')) & (f.AC > 0), freq)
    sorted_freqs = hl.sorted(freq_filtered, key=lambda x: x.AF, reverse=True)
    return hl.or_missing(
        hl.len(sorted_freqs) > 0,
        hl.struct(AC=sorted_freqs[0].AC,
                  AF=sorted_freqs[0].AF,
                  AN=sorted_freqs[0].AN,
                  homozygote_count=sorted_freqs[0].homozygote_count,
                  pop=sorted_freqs[0].meta['pop']))
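# A sketch of calling this version on literal inputs (assumes the function
# above is in scope; fields follow the code's AC/AF/AN/homozygote_count schema):
import hail as hl

freq = hl.array([
    hl.struct(AC=5, AF=0.05, AN=100, homozygote_count=1),
    hl.struct(AC=1, AF=0.01, AN=100, homozygote_count=0),
])
freq_meta = hl.array([
    hl.dict({'group': 'adj', 'pop': 'nfe'}),
    hl.dict({'group': 'adj', 'pop': 'eas'}),
])
print(hl.eval(add_popmax_expr(freq, freq_meta, populations={'nfe', 'eas'})))
# expected: Struct(AC=5, AF=0.05, AN=100, homozygote_count=1, pop='nfe')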
Example #4
def load_cmg(cmg_csv: str) -> hl.Table:
    cmg_ht = hl.import_table(cmg_csv, impute=True, delimiter=",", quote='"')

    cmg_ht = cmg_ht.transmute(
        locus1_b38=hl.locus("chr" + hl.str(cmg_ht.chrom_1), cmg_ht.pos_1, reference_genome='GRCh38'),
        alleles1_b38=[cmg_ht.ref_1, cmg_ht.alt_1],
        locus2_b38=hl.locus("chr" + hl.str(cmg_ht.chrom_2), cmg_ht.pos_2, reference_genome='GRCh38'),
        alleles2_b38=[cmg_ht.ref_2, cmg_ht.alt_2]
    )

    liftover_references = get_liftover_genome(cmg_ht.rename({'locus1_b38': 'locus'}))
    lifted_over_variants = hl.sorted(
        hl.array([
            liftover_expr(cmg_ht.locus1_b38, cmg_ht.alleles1_b38, liftover_references[1]),
            liftover_expr(cmg_ht.locus2_b38, cmg_ht.alleles2_b38, liftover_references[1])
        ]),
        lambda x: x.locus
    )

    cmg_ht = cmg_ht.key_by(
        locus1=lifted_over_variants[0].locus,
        alleles1=lifted_over_variants[0].alleles,
        locus2=lifted_over_variants[1].locus,
        alleles2=lifted_over_variants[1].alleles
    )

    return cmg_ht.annotate(
        bad_liftover=(
                hl.is_missing(cmg_ht.locus1) |
                hl.is_missing(cmg_ht.locus2) |
                (cmg_ht.locus1.sequence_context() != cmg_ht.alleles1[0][0]) |
                (cmg_ht.locus2.sequence_context() != cmg_ht.alleles2[0][0])
        )
    )
Example #5
def pop_max_expr(
    freq: hl.expr.ArrayExpression,
    freq_meta: hl.expr.ArrayExpression,
    pops_to_exclude: Optional[Set[str]] = None,
) -> hl.expr.StructExpression:
    """
    Creates an expression containing popmax: the frequency information about the population
    that has the highest AF from the populations provided in `freq_meta`,
    excluding those specified in `pops_to_exclude`.
    Only frequencies from adj populations are considered.

    This resulting struct contains the following fields:

        - AC: int32
        - AF: float64
        - AN: int32
        - homozygote_count: int32
        - pop: str

    :param freq: ArrayExpression of Structs with fields ['AC', 'AF', 'AN', 'homozygote_count']
    :param freq_meta: ArrayExpression of meta dictionaries corresponding to freq (as returned by annotate_freq)
    :param pops_to_exclude: Set of populations to skip for popmax calculation

    :return: Popmax struct
    """
    # Default to an empty set: hl.literal(None) is not a valid literal.
    _pops_to_exclude = (hl.literal(pops_to_exclude)
                        if pops_to_exclude else hl.empty_set(hl.tstr))
    popmax_freq_indices = hl.range(0, hl.len(freq_meta)).filter(
        lambda i: (hl.set(freq_meta[i].keys()) == {"group", "pop"})
        & (freq_meta[i]["group"] == "adj")
        & (~_pops_to_exclude.contains(freq_meta[i]["pop"])))
    freq_filtered = popmax_freq_indices.map(lambda i: freq[i].annotate(
        pop=freq_meta[i]["pop"])).filter(lambda f: f.AC > 0)

    sorted_freqs = hl.sorted(freq_filtered, key=lambda x: x.AF, reverse=True)
    return hl.or_missing(hl.len(sorted_freqs) > 0, sorted_freqs[0])
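# A toy check against literal inputs (a sketch; assumes pop_max_expr is in scope):
import hail as hl

freq = hl.array([
    hl.struct(AC=3, AF=0.03, AN=100, homozygote_count=0),
    hl.struct(AC=8, AF=0.08, AN=100, homozygote_count=1),
])
freq_meta = hl.array([
    hl.dict({'group': 'adj', 'pop': 'afr'}),
    hl.dict({'group': 'adj', 'pop': 'nfe'}),
])
print(hl.eval(pop_max_expr(freq, freq_meta, pops_to_exclude={'oth'})))
# expected: Struct(AC=8, AF=0.08, AN=100, homozygote_count=1, pop='nfe')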
Example #6
def pull_out_worst_from_tx_annotate(mt):
    csq_order = []
    for loftee_filter in ["HC", "LC"]:
        for no_flag in [True, False]:
            for consequence in CSQ_CODING_HIGH_IMPACT:
                csq_order.append((loftee_filter, no_flag, consequence))

    # Prioritize missense and synonymous variants on protein-coding transcripts
    csq_order.extend([(hl.null(hl.tstr), True, x)
                      for x in CSQ_CODING_MEDIUM_IMPACT + CSQ_CODING_LOW_IMPACT
                      ])

    # Any variant on a non-protein-coding transcript (i.e. where lof is missing)
    csq_order.extend([(hl.null(hl.tstr), True, x)
                      for x in CSQ_CODING_HIGH_IMPACT +
                      CSQ_CODING_MEDIUM_IMPACT + CSQ_CODING_LOW_IMPACT])

    csq_order = hl.literal({x: i for i, x in enumerate(csq_order)})

    mt = mt.annotate_rows(**hl.sorted(
        mt.tx_annotation,
        key=lambda x: csq_order[
            (x.lof, hl.or_else(hl.is_missing(x.lof_flag), False), x.csq)])[0])

    return mt
Example #7
def get_worst_gene_csq_code_expr(vep_expr: hl.expr.StructExpression) -> hl.expr.DictExpression:
    worst_gene_csq_expr = vep_expr.transcript_consequences.filter(
        lambda tc: tc.biotype == 'protein_coding'
    ).map(
        lambda ts: ts.select(
            'gene_id',
            'gene_symbol',
            csq=(
                hl.case(missing_false=True)
                    .when(ts.lof == 'HC', CSQ_CODES.index('lof'))
                    .when(ts.polyphen_prediction == 'probably_damaging', CSQ_CODES.index('damaging_missense'))
                    .when(ts.consequence_terms.any(lambda x: x == 'missense_variant'), CSQ_CODES.index('missense_variant'))
                    .when(ts.consequence_terms.all(lambda x: x == 'synonymous_variant'), CSQ_CODES.index('synonymous_variant'))
                    .or_missing()
            )
        )
    )

    worst_gene_csq_expr = worst_gene_csq_expr.filter(lambda x: hl.is_defined(x.csq))
    worst_gene_csq_expr = worst_gene_csq_expr.group_by(lambda x: x.gene_id)
    worst_gene_csq_expr = worst_gene_csq_expr.map_values(
        lambda x: hl.sorted(x, key=lambda y: y.csq)[0]
    )

    return worst_gene_csq_expr
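# The group-and-reduce idiom used above, in isolation on a toy array
# (illustrative field names only):
import hail as hl

xs = hl.array([hl.struct(gene_id='A', csq=2),
               hl.struct(gene_id='A', csq=0),
               hl.struct(gene_id='B', csq=1)])
best = xs.group_by(lambda x: x.gene_id).map_values(
    lambda grouped: hl.sorted(grouped, key=lambda y: y.csq)[0])
print(hl.eval(best))
# expected: {'A': Struct(gene_id='A', csq=0), 'B': Struct(gene_id='B', csq=1)}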
Example #8
def get_sorted_variants_expr(
        locus1: hl.expr.LocusExpression, alleles1: hl.expr.ArrayExpression,
        locus2: hl.expr.LocusExpression,
        alleles2: hl.expr.ArrayExpression) -> hl.expr.StructExpression:
    if locus1.dtype.reference_genome.name == 'GRCh37':
        variants = [
            hl.struct(locus=locus1, alleles=alleles1),
            hl.struct(locus=locus2, alleles=alleles2)
        ]
    else:
        logger.warning("Variants are not on GRCh37; they will be lifted over.")
        _, destination_ref = get_liftover_genome(hl.struct(locus=locus1))
        variants = [
            liftover_expr(locus=locus1,
                          alleles=alleles1,
                          destination_ref=destination_ref),
            liftover_expr(locus=locus2,
                          alleles=alleles2,
                          destination_ref=destination_ref)
        ]

    sorted_variants = hl.sorted(variants, key=lambda x: x.locus)
    return hl.struct(locus1=sorted_variants[0].locus,
                     alleles1=sorted_variants[0].alleles,
                     locus2=sorted_variants[1].locus,
                     alleles2=sorted_variants[1].alleles)
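# A quick sanity check on GRCh37 inputs, which takes the no-liftover branch
# (a sketch; assumes the function and its helpers are importable):
import hail as hl

expr = get_sorted_variants_expr(
    hl.locus('1', 2000, reference_genome='GRCh37'), hl.literal(['A', 'T']),
    hl.locus('1', 1000, reference_genome='GRCh37'), hl.literal(['G', 'C']))
print(hl.eval(expr.locus1))  # 1:1000 -- the earlier locus sorts first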
Example #9
def filter_kin_ht(
    ht: hl.Table,
    out_summary: io.TextIOWrapper,
    first_degree_pi_hat: float = 0.40,
    grandparent_pi_hat: float = 0.20,
    grandparent_ibd1: float = 0.25,
    grandparent_ibd2: float = 0.15,
) -> hl.Table:
    """
    Filter the kinship table to relationships of grandparents and above.

    :param ht: hl.Table
    :param out_summary: Summary file with a summary statistics and notes
    :param first_degree_pi_hat: Minimum pi_hat threshold to use to filter the kinship table to first degree relatives
    :param grandparent_pi_hat: Minimum pi_hat threshold to use to filter the kinship table to grandparents
    :param grandparent_ibd1: Minimum IBD1 threshold to use to filter the kinship table to grandparents
    :param grandparent_ibd2: Maximum IBD2 threshold to use to filter the kinship table to grandparents
    :return: Table containing only relationships of grandparents and above
    """
    # Filter to anything above the relationship of a grandparent
    ht = ht.filter((ht.pi_hat > first_degree_pi_hat)
                   | ((ht.pi_hat > grandparent_pi_hat)
                      & (ht.ibd1 > grandparent_ibd1)
                      & (ht.ibd2 < grandparent_ibd2)))
    ht = ht.annotate(pair=hl.sorted([ht.i, ht.j]))

    out_summary.write(
        f"NOTE: kinship table was filtered to:\n(pi_hat > {first_degree_pi_hat}) or "
        f"(pi_hat > {grandparent_pi_hat} and IBD1 > {grandparent_ibd1} and IBD2 < {grandparent_ibd2})\n"
    )
    out_summary.write(
        "relationships not meeting these criteria were not evaluated\n\n")

    return ht
Example #10
def load_prescription_data(prescription_data_tsv_path: str, prescription_mapping_tsv_path: str) -> hl.MatrixTable:
    ht = hl.import_table(prescription_data_tsv_path, types={'eid': hl.tint, 'data_provider': hl.tint}, key='eid')
    mapping_ht = hl.import_table(prescription_mapping_tsv_path, impute=True, key='Original_Prescription')
    ht = ht.annotate(issue_date=hl.cond(hl.len(ht.issue_date) == 0, hl.null(hl.tint64),
                                        hl.experimental.strptime(ht.issue_date + ' 00:00:00', '%d/%m/%Y %H:%M:%S', 'GMT')),
                     **mapping_ht[ht.drug_name])
    ht = ht.filter(ht.Generic_Name != '').key_by('eid', 'Generic_Name', 'Drug_Category_and_Indication').collect_by_key()
    ht = ht.annotate(values=hl.sorted(ht.values, key=lambda x: x.issue_date))
    return ht.to_matrix_table(row_key=['eid'], col_key=['Generic_Name'], col_fields=['Drug_Category_and_Indication'])
Example #11
def project_max_expr(
    project_expr: hl.expr.StringExpression,
    gt_expr: hl.expr.CallExpression,
    alleles_expr: hl.expr.ArrayExpression,
    n_projects: int = 5,
) -> hl.expr.ArrayExpression:
    """
    Create an expression that computes allele frequency information by project for the `n_projects` with the largest AF at this row.

    Will return an array with one element per non-reference allele.

    Each of these elements is itself an array of structs with the following fields:

        - AC: int32
        - AF: float64
        - AN: int32
        - homozygote_count: int32
        - project: str

    .. note::

        Only projects with AF > 0 are returned.
        In case of ties, the project ordering is not guaranteed, and at most `n_projects` are returned.

    :param project_expr: column expression containing the project
    :param gt_expr: entry expression containing the genotype
    :param alleles_expr: row expression containing the alleles
    :param n_projects: Maximum number of projects to return for each row
    :return: projectmax expression
    """
    n_alleles = hl.len(alleles_expr)

    # Compute call stats by project
    project_cs = hl.array(
        hl.agg.group_by(project_expr, hl.agg.call_stats(gt_expr,
                                                        alleles_expr)))

    return hl.or_missing(
        n_alleles > 1,  # Exclude monomorphic sites
        hl.range(1, n_alleles).map(lambda ai: hl.sorted(
            project_cs.filter(
                # filter to projects with AF > 0
                lambda x: x[1].AF[ai] > 0),
            # order the callstats computed by AF in decreasing order
            lambda x: -x[1].AF[ai]
            # take the n_projects projects with largest AF
        )[:n_projects].map(
            # add the project in the callstats struct
            lambda x: x[1].annotate(
                AC=x[1].AC[ai],
                AF=x[1].AF[ai],
                AN=x[1].AN,
                homozygote_count=x[1].homozygote_count[ai],
                project=x[0],
            ))),
    )
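# A hedged usage sketch on a simulated MatrixTable: balding_nichols_model
# provides GT and alleles, and the `project` column here is synthetic.
import hail as hl

mt = hl.balding_nichols_model(n_populations=2, n_samples=10, n_variants=5)
mt = mt.annotate_cols(
    project=hl.if_else(mt.sample_idx % 2 == 0, 'projA', 'projB'))
mt = mt.annotate_rows(
    project_max=project_max_expr(mt.project, mt.GT, mt.alleles))
mt.rows().select('project_max').show()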
Example #12
def post_process_gene_map_ht(gene_ht):
    groups = [
        'pLoF', 'missense|LC', 'pLoF|missense|LC', 'synonymous', 'missense'
    ]
    variant_groups = hl.map(
        lambda group: group.split('\\|').flatmap(
            lambda csq: gene_ht.variants.get(csq)),
        groups)
    gene_ht = gene_ht.transmute(variant_groups=hl.zip(
        groups, variant_groups)).explode('variant_groups')
    gene_ht = gene_ht.transmute(annotation=gene_ht.variant_groups[0],
                                variants=hl.sorted(gene_ht.variant_groups[1]))
    gene_ht = gene_ht.key_by(start=gene_ht.interval.start)
    return gene_ht.filter(hl.len(gene_ht.variants) > 0)
Example #13
def format_regional_missense_constraint(ds):
    ds = ds.annotate(obs_mis=hl.int(ds.obs_mis))

    ds = ds.annotate(start=hl.min(ds.genomic_start, ds.genomic_end), stop=hl.max(ds.genomic_start, ds.genomic_end))

    ds = ds.drop("amino_acids", "chr", "gene", "genomic_start", "genomic_end", "region_name")

    ds = ds.transmute(transcript_id=ds.transcript.split("\\.")[0])

    ds = ds.group_by("transcript_id").aggregate(regions=hl.agg.collect(ds.row_value))

    ds = ds.annotate(regions=hl.sorted(ds.regions, lambda region: region.start))

    return ds
Example #14
def merge_overlapping_regions(regions):
    return hl.cond(
        hl.len(regions) > 1,
        hl.rbind(
            hl.sorted(regions, lambda region: region.start),
            lambda sorted_regions: sorted_regions[1:].fold(
                lambda acc, region: hl.cond(
                    region.start <= acc[-1].stop + 1,
                    acc[:-1].append(acc[-1].annotate(stop=hl.max(
                        region.stop, acc[-1].stop))),
                    acc.append(region),
                ),
                [sorted_regions[0]],
            ),
        ),
        regions,
    )
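# A toy evaluation (a sketch; assumes merge_overlapping_regions is in scope):
# overlapping or adjacent regions collapse, disjoint regions survive.
import hail as hl

regions = hl.array([
    hl.struct(start=10, stop=20),
    hl.struct(start=1, stop=5),
    hl.struct(start=4, stop=8),
])
print(hl.eval(merge_overlapping_regions(regions)))
# expected: [Struct(start=1, stop=8), Struct(start=10, stop=20)]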
Example #15
def prepare_variant_results():
    results_path = pipeline_config.get("SCHEMA", "variant_results_path")
    annotations_path = pipeline_config.get("SCHEMA",
                                           "variant_annotations_path")

    results = hl.read_table(results_path)

    results = results.drop("v", "af_case", "af_ctrl")

    # Add n_denovos to AC_case
    results = results.annotate(ac_case=hl.or_else(results.ac_case, 0) +
                               hl.or_else(results.n_denovos, 0))

    results = results.annotate(
        source=hl.delimit(hl.sorted(hl.array(results.source)), ", "))

    results = results.group_by(
        "locus",
        "alleles").aggregate(group_results=hl.agg.collect(results.row_value))
    results = results.annotate(group_results=hl.dict(
        results.group_results.map(lambda group_result:
                                  (group_result.analysis_group,
                                   group_result.drop("analysis_group")))))

    variants = hl.read_table(annotations_path)
    variants = variants.select(
        gene_id=variants.gene_id,
        consequence=hl.case().when(
            (variants.canonical_term == "missense_variant") &
            (variants.mpc >= 3), "missense_variant_mpc_>=3").when(
                (variants.canonical_term == "missense_variant") &
                (variants.mpc >= 2), "missense_variant_mpc_2-3").when(
                    variants.canonical_term == "missense_variant",
                    "missense_variant_mpc_<2").default(
                        variants.canonical_term),
        hgvsc=variants.hgvsc_canonical.split(":")[-1],
        hgvsp=variants.hgvsp_canonical.split(":")[-1],
        info=hl.struct(cadd=variants.cadd,
                       mpc=variants.mpc,
                       polyphen=variants.polyphen),
    )

    variants = variants.annotate(**results[variants.key])
    variants = variants.filter(hl.is_defined(variants.group_results))

    return variants
Example #16
def compute_kinship_ht(mt, genome_version="GRCh38"):

    mt = filter_to_biallelics(mt)
    mt = filter_to_autosomes(mt)
    mt = mt.filter_rows(hl.is_snp(mt.alleles[0], mt.alleles[1]))

    mt = hl.variant_qc(mt)
    mt = mt.filter_rows(mt.variant_qc.call_rate > 0.99)
    #mt = mt.filter_rows(mt.info.AF > 0.001) # leaves 100% of variants

    mt = ld_prune(mt, genome_version=genome_version)

    ibd_results_ht = hl.identity_by_descent(mt,
                                            maf=mt.info.AF,
                                            min=0.10,
                                            max=1.0)
    ibd_results_ht = ibd_results_ht.annotate(
        ibd0=ibd_results_ht.ibd.Z0,
        ibd1=ibd_results_ht.ibd.Z1,
        ibd2=ibd_results_ht.ibd.Z2,
        pi_hat=ibd_results_ht.ibd.PI_HAT).drop("ibs0", "ibs1", "ibs2", "ibd")

    kin_ht = ibd_results_ht

    # filter to anything above the relationship of a grandparent
    first_degree_pi_hat = .40
    grandparent_pi_hat = .20
    grandparent_ibd1 = 0.25
    grandparent_ibd2 = 0.15

    kin_ht = kin_ht.key_by("i", "j")
    kin_ht = kin_ht.filter((kin_ht.pi_hat > first_degree_pi_hat) | (
        (kin_ht.pi_hat > grandparent_pi_hat) & (kin_ht.ibd1 > grandparent_ibd1)
        & (kin_ht.ibd2 < grandparent_ibd2)))

    kin_ht = kin_ht.annotate(
        relation=hl.sorted([kin_ht.i, kin_ht.j]))  # TODO: use a better variable name (e.g. `pair`)

    return kin_ht
Example #17
def prepare_exac_regional_missense_constraint(path):
    ds = hl.import_table(
        path,
        missing="",
        types={
            "transcript": hl.tstr,
            "gene": hl.tstr,
            "chr": hl.tstr,
            "amino_acids": hl.tstr,
            "genomic_start": hl.tint,
            "genomic_end": hl.tint,
            "obs_mis": hl.tfloat,
            "exp_mis": hl.tfloat,
            "obs_exp": hl.tfloat,
            "chisq_diff_null": hl.tfloat,
            "region_name": hl.tstr,
        },
    )

    ds = ds.annotate(obs_mis=hl.int(ds.obs_mis))

    ds = ds.annotate(start=hl.min(ds.genomic_start, ds.genomic_end),
                     stop=hl.max(ds.genomic_start, ds.genomic_end))

    ds = ds.drop("amino_acids", "chr", "gene", "genomic_start", "genomic_end",
                 "region_name")

    ds = ds.transmute(transcript_id=ds.transcript.split("\\.")[0])

    ds = ds.group_by("transcript_id").aggregate(
        regions=hl.agg.collect(ds.row_value))

    ds = ds.annotate(
        regions=hl.sorted(ds.regions, lambda region: region.start))

    ds = ds.select(exac_regional_missense_constraint_regions=ds.regions)

    return ds
Example #18
def add_popmax_expr(freq: hl.expr.ArrayExpression) -> hl.expr.ArrayExpression:
    """
    Calculates popmax (add an additional entry into freq with popmax: pop)

    :param ArrayExpression freq: ArrayExpression of Structs with ['ac', 'an', 'hom', 'meta']
    :return: Frequency data with annotated popmax
    :rtype: ArrayExpression
    """
    freq_filtered = hl.filter(
        lambda x:
        (x.meta.keys() == ['population']) & (x.meta['population'] != 'oth'),
        freq)
    sorted_freqs = hl.sorted(freq_filtered,
                             key=lambda x: x.ac / x.an,
                             reverse=True)
    return hl.cond(
        hl.len(sorted_freqs) > 0,
        freq.append(
            hl.struct(ac=sorted_freqs[0].ac,
                      an=sorted_freqs[0].an,
                      hom=sorted_freqs[0].hom,
                      meta={'popmax': sorted_freqs[0].meta['population']})),
        freq)
Example #19
def process_consequences(
    mt: Union[hl.MatrixTable, hl.Table],
    vep_root: str = "vep",
    penalize_flags: bool = True,
) -> Union[hl.MatrixTable, hl.Table]:
    """
    Adds most_severe_consequence (worst consequence for a transcript) into [vep_root].transcript_consequences,
    and worst_csq_by_gene, any_lof into [vep_root]

    :param mt: Input MT
    :param vep_root: Root for vep annotation (probably vep)
    :param penalize_flags: Whether to penalize LOFTEE flagged variants, or treat them as equal to HC
    :return: MT with better formatted consequences
    """
    csqs = hl.literal(CSQ_ORDER)
    csq_dict = hl.literal(dict(zip(CSQ_ORDER, range(len(CSQ_ORDER)))))

    def find_worst_transcript_consequence(
            tcl: hl.expr.ArrayExpression) -> hl.expr.StructExpression:
        """
        Gets the worst transcript consequence from an array of transcript consequences.
        """
        flag_score = 500
        no_flag_score = flag_score * (1 + penalize_flags)

        def csq_score(tc):
            return csq_dict[csqs.find(
                lambda x: x == tc.most_severe_consequence)]

        tcl = tcl.map(
            lambda tc: tc.annotate(
                csq_score=hl.case(missing_false=True)
                .when((tc.lof == "HC") & (tc.lof_flags == ""),
                      csq_score(tc) - no_flag_score)
                .when((tc.lof == "HC") & (tc.lof_flags != ""),
                      csq_score(tc) - flag_score)
                .when(tc.lof == "OS", csq_score(tc) - 20)
                .when(tc.lof == "LC", csq_score(tc) - 10)
                .when(tc.polyphen_prediction == "probably_damaging",
                      csq_score(tc) - 0.5)
                .when(tc.polyphen_prediction == "possibly_damaging",
                      csq_score(tc) - 0.25)
                .when(tc.polyphen_prediction == "benign", csq_score(tc) - 0.1)
                .default(csq_score(tc))))
        return hl.or_missing(
            hl.len(tcl) > 0,
            hl.sorted(tcl, lambda x: x.csq_score)[0])

    transcript_csqs = mt[vep_root].transcript_consequences.map(
        add_most_severe_consequence_to_consequence)

    gene_dict = transcript_csqs.group_by(lambda tc: tc.gene_symbol)
    worst_csq_gene = gene_dict.map_values(
        find_worst_transcript_consequence).values()
    sorted_scores = hl.sorted(worst_csq_gene, key=lambda tc: tc.csq_score)

    canonical = transcript_csqs.filter(lambda csq: csq.canonical == 1)
    gene_canonical_dict = canonical.group_by(lambda tc: tc.gene_symbol)
    worst_csq_gene_canonical = gene_canonical_dict.map_values(
        find_worst_transcript_consequence).values()
    sorted_canonical_scores = hl.sorted(worst_csq_gene_canonical,
                                        key=lambda tc: tc.csq_score)

    vep_data = mt[vep_root].annotate(
        transcript_consequences=transcript_csqs,
        worst_consequence_term=csqs.find(lambda c: transcript_csqs.map(
            lambda csq: csq.most_severe_consequence).contains(c)),
        worst_csq_by_gene=sorted_scores,
        worst_csq_for_variant=hl.or_missing(
            hl.len(sorted_scores) > 0, sorted_scores[0]),
        worst_csq_by_gene_canonical=sorted_canonical_scores,
        worst_csq_for_variant_canonical=hl.or_missing(
            hl.len(sorted_canonical_scores) > 0, sorted_canonical_scores[0]),
    )

    return (mt.annotate_rows(**{vep_root: vep_data}) if isinstance(
        mt, hl.MatrixTable) else mt.annotate(**{vep_root: vep_data}))
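# A hedged call-site sketch (assumes `mt` is a VEP-annotated MatrixTable and
# that module-level constants such as CSQ_ORDER are defined):
mt = process_consequences(mt, vep_root='vep', penalize_flags=True)
worst = mt.vep.worst_csq_for_variant  # per-variant transcript with the lowest csq_score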
Example #20
def prepare_mitochondrial_variants(path, mnvs_path=None):
    ds = hl.read_table(path)

    haplogroups = hl.eval(ds.globals.hap_order)

    ds = ds.annotate(hl_hist=ds.hl_hist.annotate(
        bin_edges=ds.hl_hist.bin_edges.map(
            lambda n: hl.float(hl.format("%.2f", n)))))

    filter_names = hl.dict({
        "artifact_prone_site": "Artifact-prone site",
        "indel_stack": "Indel stack",
        "npg": "No passing genotype"
    })

    ds = ds.select(
        # ID
        variant_id=variant_id(ds.locus, ds.alleles),
        reference_genome=ds.locus.dtype.reference_genome.name,
        chrom=normalized_contig(ds.locus.contig),
        pos=ds.locus.position,
        ref=ds.alleles[0],
        alt=ds.alleles[1],
        rsid=ds.rsid,
        # Quality
        filters=ds.filters.map(lambda f: filter_names.get(f, f)),
        qual=ds.qual,
        genotype_quality_metrics=[
            hl.struct(name="Depth", alt=ds.dp_hist_alt, all=ds.dp_hist_all)
        ],
        genotype_quality_filters=[
            hl.struct(
                name="Base Quality",
                filtered=hl.struct(bin_edges=ds.hl_hist.bin_edges,
                                   bin_freq=ds.base_qual_hist),
            ),
            hl.struct(
                name="Contamination",
                filtered=hl.struct(bin_edges=ds.hl_hist.bin_edges,
                                   bin_freq=ds.contamination_hist),
            ),
            hl.struct(
                name="Heteroplasmy below 10%",
                filtered=hl.struct(
                    bin_edges=ds.hl_hist.bin_edges,
                    bin_freq=ds.heteroplasmy_below_10_percent_hist),
            ),
            hl.struct(name="Position",
                      filtered=hl.struct(bin_edges=ds.hl_hist.bin_edges,
                                         bin_freq=ds.position_hist)),
            hl.struct(
                name="Strand Bias",
                filtered=hl.struct(bin_edges=ds.hl_hist.bin_edges,
                                   bin_freq=ds.strand_bias_hist),
            ),
            hl.struct(
                name="Weak Evidence",
                filtered=hl.struct(bin_edges=ds.hl_hist.bin_edges,
                                   bin_freq=ds.weak_evidence_hist),
            ),
        ],
        site_quality_metrics=[
            hl.struct(name="Mean Depth", value=nullify_nan(ds.dp_mean)),
            hl.struct(name="Mean MQ", value=nullify_nan(ds.mq_mean)),
            hl.struct(name="Mean TLOD", value=nullify_nan(ds.tlod_mean)),
        ],
        # Frequency
        an=ds.AN,
        ac_hom=ds.AC_hom,
        ac_het=ds.AC_het,
        excluded_ac=ds.excluded_AC,
        # Heteroplasmy
        common_low_heteroplasmy=ds.common_low_heteroplasmy,
        heteroplasmy_distribution=ds.hl_hist,
        max_heteroplasmy=ds.max_hl,
        # Populations
        populations=hl.sorted(
            hl.range(hl.len(
                ds.globals.pop_order)).map(lambda pop_index: hl.struct(
                    id=ds.globals.pop_order[pop_index],
                    an=ds.pop_AN[pop_index],
                    ac_het=ds.pop_AC_het[pop_index],
                    ac_hom=ds.pop_AC_hom[pop_index],
                    heteroplasmy_distribution=hl.struct(
                        bin_edges=ds.hl_hist.bin_edges,
                        bin_freq=ds.pop_hl_hist[pop_index],
                        n_smaller=0,
                        n_larger=0,
                    ),
                )),
            key=lambda pop: pop.id,
        ),
        # Haplogroups
        hapmax_af_hom=ds.hapmax_AF_hom,
        hapmax_af_het=ds.hapmax_AF_het,
        faf_hapmax_hom=ds.faf_hapmax_hom,
        haplogroup_defining=ds.hap_defining_variant,
        haplogroups=[
            hl.struct(
                id=haplogroup,
                an=ds.hap_AN[i],
                ac_het=ds.hap_AC_het[i],
                ac_hom=ds.hap_AC_hom[i],
                faf_hom=ds.hap_faf_hom[i],
                heteroplasmy_distribution=ds.hap_hl_hist[i],
            ) for i, haplogroup in enumerate(haplogroups)
        ],
        # Other
        age_distribution=hl.struct(het=ds.age_hist_het, hom=ds.age_hist_hom),
        flags=hl.set([
            hl.or_missing(ds.common_low_heteroplasmy,
                          "common_low_heteroplasmy")
        ]).filter(hl.is_defined),
        mitotip_score=ds.mitotip_score,
        mitotip_trna_prediction=ds.mitotip_trna_prediction,
        pon_ml_probability_of_pathogenicity=ds.pon_ml_probability_of_pathogenicity,
        pon_mt_trna_prediction=ds.pon_mt_trna_prediction,
        variant_collapsed=ds.variant_collapsed,
        vep=ds.vep,
    )

    if mnvs_path:
        mnvs = hl.import_table(mnvs_path,
                               types={
                                   "pos": hl.tint,
                                   "ref": hl.tstr,
                                   "alt": hl.tstr,
                                   "AC_hom_MNV": hl.tint
                               })
        mnvs = mnvs.key_by(
            locus=hl.locus("chrM",
                           mnvs.pos,
                           reference_genome=ds.locus.dtype.reference_genome),
            alleles=[mnvs.ref, mnvs.alt],
        )
        ds = ds.annotate(ac_hom_mnv=hl.or_else(mnvs[ds.key].AC_hom_MNV, 0))
        ds = ds.annotate(
            flags=hl.if_else(ds.ac_hom_mnv > 0, ds.flags.add("mnv"), ds.flags))

    return ds
Example #21
def sorted_transcript_consequences_v2(vep_root):
    """Sort transcripts by 3 properties:

        1. coding > non-coding
        2. transcript consequence severity
        3. canonical > non-canonical

    so that the 1st array entry will be for the coding, most-severe, canonical transcript (assuming
    one exists).

    Also, for each transcript in the array, computes these additional fields:
        domains: converts structs with db/name fields to string db:name
        hgvs: hgvsp (formatted for synonymous variants) if it exists, otherwise hgvsc
        major_consequence: set to most severe consequence for that transcript (
            VEP sometimes provides multiple consequences for a single transcript)
        major_consequence_rank: major_consequence rank based on VEP SO ontology (most severe = 1)
            (see http://www.ensembl.org/info/genome/variation/predicted_data.html)
        category: set to one of: "lof", "missense", "synonymous", "other" based on the value of major_consequence.

    Args:
        vep_root (StructExpression): root path of the VEP struct in the MT
    """

    consequences = (vep_root.transcript_consequences.map(
        lambda c: c.annotate(consequence_terms=c.consequence_terms.filter(
            lambda t: ~OMIT_CONSEQUENCE_TERMS.contains(t)))
    ).filter(lambda c: c.consequence_terms.size() > 0).map(
        lambda c: c.annotate(major_consequence=hl.sorted(
            c.consequence_terms, key=consequence_term_rank)[0])
    ).map(lambda c: c.annotate(
        category=(hl.case().when(
            consequence_term_rank(c.major_consequence) <=
            consequence_term_rank("frameshift_variant"), "lof").when(
                consequence_term_rank(c.major_consequence) <=
                consequence_term_rank("missense_variant"),
                "missense",
            ).when(
                consequence_term_rank(c.major_consequence) <=
                consequence_term_rank("synonymous_variant"),
                "synonymous",
            ).default("other")),
        domains=c.domains.map(lambda domain: domain.db + ":" + domain.name),
        hgvs=hl.cond(
            hl.is_missing(c.hgvsp) | SPLICE_CONSEQUENCES.contains(
                c.major_consequence),
            c.hgvsc.split(":")[-1],
            hgvsp_from_consequence_amino_acids(c),
        ),
        major_consequence_rank=consequence_term_rank(c.major_consequence),
    )))

    consequences = hl.sorted(
        consequences,
        lambda c: (hl.bind(
            lambda is_coding, is_most_severe, is_canonical: (hl.cond(
                is_coding,
                hl.cond(is_most_severe, hl.cond(is_canonical, 1, 2),
                        hl.cond(is_canonical, 3, 4)),
                hl.cond(is_most_severe, hl.cond(is_canonical, 5, 6),
                        hl.cond(is_canonical, 7, 8)),
            )),
            hl.or_else(c.biotype, "") == "protein_coding",
            hl.set(c.consequence_terms).contains(vep_root.most_severe_consequence),
            hl.or_else(c.canonical, 0) == 1,
        )),
    )

    consequences = hl.zip_with_index(consequences).map(
        lambda csq_with_index: csq_with_index[1].annotate(
            transcript_rank=csq_with_index[0]))

    # TODO: Discard most of lof_info field
    # Keep whether lof_info contains DONOR_DISRUPTION, ACCEPTOR_DISRUPTION, or DE_NOVO_DONOR
    consequences = consequences.map(lambda c: c.select(
        "amino_acids",
        "biotype",
        "canonical",
        "category",
        "cdna_end",
        "cdna_start",
        "codons",
        "consequence_terms",
        "domains",
        "gene_id",
        "gene_symbol",
        "hgvs",
        "hgvsc",
        "hgvsp",
        "lof_filter",
        "lof_flags",
        "lof_info",
        "lof",
        "major_consequence",
        "major_consequence_rank",
        "polyphen_prediction",
        "protein_id",
        "protein_start",
        "sift_prediction",
        "transcript_id",
        "transcript_rank",
    ))

    return consequences
Example #22
def import_mnv_file(path, **kwargs):
    column_types = {
        "AC_mnv_ex": hl.tint,
        "AC_mnv_gen": hl.tint,
        "AC_mnv": hl.tint,
        "AC_snp1_ex": hl.tint,
        "AC_snp1_gen": hl.tint,
        "AC_snp1": hl.tint,
        "AC_snp2_ex": hl.tint,
        "AC_snp2_gen": hl.tint,
        "AC_snp2": hl.tint,
        "AN_snp1_ex": hl.tfloat,
        "AN_snp1_gen": hl.tfloat,
        "AN_snp2_ex": hl.tfloat,
        "AN_snp2_gen": hl.tfloat,
        "categ": hl.tstr,
        "filter_snp1_ex": hl.tarray(hl.tstr),
        "filter_snp1_gen": hl.tarray(hl.tstr),
        "filter_snp2_ex": hl.tarray(hl.tstr),
        "filter_snp2_gen": hl.tarray(hl.tstr),
        "gene_id": hl.tstr,
        "gene_name": hl.tstr,
        "locus.contig": hl.tstr,
        "locus.position": hl.tint,
        "mnv_amino_acids": hl.tstr,
        "mnv_codons": hl.tstr,
        "mnv_consequence": hl.tstr,
        "mnv_lof": hl.tstr,
        "mnv": hl.tstr,
        "n_homhom_ex": hl.tint,
        "n_homhom_gen": hl.tint,
        "n_homhom": hl.tint,
        "n_indv_ex": hl.tint,
        "n_indv_gen": hl.tint,
        "n_indv": hl.tint,
        "snp1_amino_acids": hl.tstr,
        "snp1_codons": hl.tstr,
        "snp1_consequence": hl.tstr,
        "snp1_lof": hl.tstr,
        "snp1": hl.tstr,
        "snp2_amino_acids": hl.tstr,
        "snp2_codons": hl.tstr,
        "snp2_consequence": hl.tstr,
        "snp2_lof": hl.tstr,
        "snp2": hl.tstr,
        "transcript_id": hl.tstr,
    }

    ds = hl.import_table(path,
                         key="mnv",
                         missing="",
                         types=column_types,
                         **kwargs)

    ds = ds.transmute(locus=hl.locus(ds["locus.contig"], ds["locus.position"]))

    ds = ds.transmute(
        contig=normalized_contig(ds.locus),
        pos=ds.locus.position,
        xpos=x_position(ds.locus),
    )

    ds = ds.annotate(ref=ds.mnv.split("-")[2],
                     alt=ds.mnv.split("-")[3],
                     variant_id=ds.mnv)

    ds = ds.annotate(snp1_copy=ds.snp1, snp2_copy=ds.snp2)
    ds = ds.transmute(constituent_snvs=[
        hl.bind(
            lambda variant_id_parts: hl.struct(
                variant_id=ds[f"{snp}_copy"],
                chrom=variant_id_parts[0],
                pos=hl.int(variant_id_parts[1]),
                ref=variant_id_parts[2],
                alt=variant_id_parts[3],
                exome=hl.or_missing(
                    hl.is_defined(ds[f"AN_{snp}_ex"]),
                    hl.struct(
                        filters=ds[f"filter_{snp}_ex"],
                        ac=ds[f"AC_{snp}_ex"],
                        an=hl.int(ds[f"AN_{snp}_ex"]),
                    ),
                ),
                genome=hl.or_missing(
                    hl.is_defined(ds[f"AN_{snp}_gen"]),
                    hl.struct(
                        filters=ds[f"filter_{snp}_gen"],
                        ac=ds[f"AC_{snp}_gen"],
                        an=hl.int(ds[f"AN_{snp}_gen"]),
                    ),
                ),
            ),
            ds[f"{snp}_copy"].split("-"),
        ) for snp in ["snp1", "snp2"]
    ])

    ds = ds.annotate(constituent_snv_ids=[ds.snp1, ds.snp2])

    ds = ds.annotate(
        mnv_in_exome=ds.constituent_snvs.all(lambda s: hl.is_defined(s.exome)),
        mnv_in_genome=ds.constituent_snvs.all(
            lambda s: hl.is_defined(s.genome)),
    )

    ds = ds.transmute(
        n_individuals=ds.n_indv,
        ac=ds.AC_mnv,
        ac_hom=ds.n_homhom,
        exome=hl.or_missing(
            ds.mnv_in_exome,
            hl.struct(n_individuals=ds.n_indv_ex,
                      ac=ds.AC_mnv_ex,
                      ac_hom=ds.n_homhom_ex),
        ),
        genome=hl.or_missing(
            ds.mnv_in_genome,
            hl.struct(n_individuals=ds.n_indv_gen,
                      ac=ds.AC_mnv_gen,
                      ac_hom=ds.n_homhom_gen),
        ),
    )

    ds = ds.drop("AC_snp1", "AC_snp2")

    ds = ds.transmute(consequence=hl.struct(
        category=ds.categ,
        gene_id=ds.gene_id,
        gene_name=ds.gene_name,
        transcript_id=ds.transcript_id,
        consequence=ds.mnv_consequence,
        codons=ds.mnv_codons,
        amino_acids=ds.mnv_amino_acids,
        lof=ds.mnv_lof,
        snv_consequences=[
            hl.struct(
                variant_id=ds[f"{snp}"],
                amino_acids=ds[f"{snp}_amino_acids"],
                codons=ds[f"{snp}_codons"],
                consequence=ds[f"{snp}_consequence"],
                lof=ds[f"{snp}_lof"],
            ) for snp in ["snp1", "snp2"]
        ],
    ))

    # Collapse table to one row per MNV, with all consequences for the MNV collected into an array
    consequences = ds.group_by(
        ds.mnv).aggregate(consequences=hl.agg.collect(ds.consequence))
    ds = ds.drop("consequence")
    ds = ds.distinct()
    ds = ds.join(consequences)

    # Sort consequences by severity
    ds = ds.annotate(consequences=hl.sorted(
        ds.consequences,
        key=lambda c: consequence_term_rank(c.consequence),
    ))

    ds = ds.annotate(changes_amino_acids_for_snvs=hl.literal([0, 1])
                     .filter(lambda idx: ds.consequences.any(
                         lambda csq: csq.snv_consequences[idx].amino_acids.lower()
                         != csq.amino_acids.lower()))
                     .map(lambda idx: ds.constituent_snv_ids[idx]))

    return ds
Example #23
def prepare_gnomad_v2_variants_helper(path, exome_or_genome):
    ds = hl.read_table(path)

    ###############
    # Frequencies #
    ###############

    g = hl.eval(ds.globals)

    subsets = ["gnomad", "controls", "non_neuro", "non_topmed"] + (["non_cancer"] if exome_or_genome == "exome" else [])

    ds = ds.select_globals()

    ds = ds.annotate(
        freq=hl.struct(
            **{
                subset: hl.struct(
                    ac=ds.freq[g.freq_index_dict[subset]].AC,
                    ac_raw=ds.freq[g.freq_index_dict[f"{subset}_raw"]].AC,
                    an=ds.freq[g.freq_index_dict[subset]].AN,
                    hemizygote_count=hl.if_else(ds.nonpar, ds.freq[g.freq_index_dict[f"{subset}_male"]].AC, 0),
                    homozygote_count=ds.freq[g.freq_index_dict[subset]].homozygote_count,
                    populations=population_frequencies_expression(ds, g.freq_index_dict, subset),
                )
                for subset in subsets
            }
        )
    )

    # If a variant is not present in a subset, do not store population frequencies for that subset
    ds = ds.annotate(
        freq=ds.freq.annotate(
            **{
                subset: ds.freq[subset].annotate(
                    populations=hl.if_else(
                        ds.freq[subset].ac_raw == 0,
                        hl.empty_array(ds.freq[subset].populations.dtype.element_type),
                        ds.freq[subset].populations,
                    )
                )
                for subset in subsets
            }
        )
    )

    ###########################################
    # Subsets in which the variant is present #
    ###########################################

    ds = ds.annotate(
        subsets=hl.set(
            hl.array([(subset, ds.freq[subset].ac_raw > 0) for subset in subsets])
            .filter(lambda t: t[1])
            .map(lambda t: t[0])
        )
    )

    if exome_or_genome == "genome":
        ds = ds.annotate(subsets=ds.subsets.add("non_cancer"))

    ##############################
    # Filtering allele frequency #
    ##############################

    ds = ds.annotate(
        freq=ds.freq.annotate(
            **{
                subset: ds.freq[subset].annotate(
                    faf95=hl.rbind(
                        hl.sorted(
                            hl.array(
                                [
                                    hl.struct(
                                        faf=ds.faf[g.faf_index_dict[f"{subset}_{pop_id}"]].faf95, population=pop_id,
                                    )
                                    for pop_id in (
                                        ["afr", "amr", "eas", "nfe"] + (["sas"] if exome_or_genome == "exome" else [])
                                    )
                                ]
                            ).filter(lambda f: f.faf > 0),
                            key=lambda f: (-f.faf, f.population),
                        ),
                        lambda fafs: hl.if_else(
                            hl.len(fafs) > 0,
                            hl.struct(popmax=fafs[0].faf, popmax_population=fafs[0].population,),
                            hl.struct(popmax=hl.null(hl.tfloat), popmax_population=hl.null(hl.tstr),),
                        ),
                    ),
                    faf99=hl.rbind(
                        hl.sorted(
                            hl.array(
                                [
                                    hl.struct(
                                        faf=ds.faf[g.faf_index_dict[f"{subset}_{pop_id}"]].faf99, population=pop_id,
                                    )
                                    for pop_id in (
                                        ["afr", "amr", "eas", "nfe"] + (["sas"] if exome_or_genome == "exome" else [])
                                    )
                                ]
                            ).filter(lambda f: f.faf > 0),
                            key=lambda f: (-f.faf, f.population),
                        ),
                        lambda fafs: hl.if_else(
                            hl.len(fafs) > 0,
                            hl.struct(popmax=fafs[0].faf, popmax_population=fafs[0].population,),
                            hl.struct(popmax=hl.null(hl.tfloat), popmax_population=hl.null(hl.tstr),),
                        ),
                    ),
                )
                for subset in subsets
            }
        ),
    )

    ds = ds.drop("faf")

    ####################
    # Age distribution #
    ####################

    # Format age distributions
    ds = ds.transmute(
        age_distribution=hl.struct(
            **{
                subset: hl.struct(het=ds.age_hist_het[index], hom=ds.age_hist_hom[index],)
                for subset, index in g.age_index_dict.items()
            },
        )
    )

    ###################
    # Quality metrics #
    ###################

    ds = ds.transmute(
        quality_metrics=hl.struct(
            allele_balance=hl.struct(
                alt_raw=ds.ab_hist_alt.annotate(
                    bin_edges=ds.ab_hist_alt.bin_edges.map(lambda n: hl.float(hl.format("%.3f", n)))
                )
            ),
            genotype_depth=hl.struct(all_raw=ds.dp_hist_all, alt_raw=ds.dp_hist_alt),
            genotype_quality=hl.struct(all_raw=ds.gq_hist_all, alt_raw=ds.gq_hist_alt),
            # Use the same fields as the VCFs
            # Based on https://github.com/macarthur-lab/gnomad_qc/blob/25a81bc2166fbe4ccbb2f7a87d36aba661150413/variant_qc/prepare_data_release.py#L128-L159
            site_quality_metrics=[
                hl.struct(metric="BaseQRankSum", value=ds.allele_info.BaseQRankSum),
                hl.struct(metric="ClippingRankSum", value=ds.allele_info.ClippingRankSum),
                hl.struct(metric="DP", value=hl.float(ds.allele_info.DP)),
                hl.struct(metric="FS", value=ds.info_FS),
                hl.struct(metric="InbreedingCoeff", value=ds.info_InbreedingCoeff),
                hl.struct(metric="MQ", value=ds.info_MQ),
                hl.struct(metric="MQRankSum", value=ds.info_MQRankSum),
                hl.struct(metric="pab_max", value=ds.pab_max),
                hl.struct(metric="QD", value=ds.info_QD),
                hl.struct(metric="ReadPosRankSum", value=ds.info_ReadPosRankSum),
                hl.struct(metric="RF", value=ds.rf_probability),
                hl.struct(metric="SiteQuality", value=ds.qual),
                hl.struct(metric="SOR", value=ds.info_SOR),
                hl.struct(metric="VQSLOD", value=ds.allele_info.VQSLOD),
                hl.struct(metric="VQSR_NEGATIVE_TRAIN_SITE", value=hl.float(ds.info_NEGATIVE_TRAIN_SITE)),
                hl.struct(metric="VQSR_POSITIVE_TRAIN_SITE", value=hl.float(ds.info_POSITIVE_TRAIN_SITE)),
            ],
        )
    )

    #################
    # Unused fields #
    #################

    ds = ds.drop(
        "adj_biallelic_rank",
        "adj_biallelic_singleton_rank",
        "adj_rank",
        "adj_singleton_rank",
        "allele_type",
        "biallelic_rank",
        "biallelic_singleton_rank",
        "has_star",
        "info_DP",
        "mills",
        "n_alt_alleles",
        "n_nonref",
        "omni",
        "popmax",
        "qd",
        "rank",
        "score",
        "singleton_rank",
        "singleton",
        "transmitted_singleton",
        "variant_type",
        "was_mixed",
        "was_split",
    )

    # These two fields appear only in the genomes table
    if "_score" in ds.row_value.dtype.fields:
        ds = ds.drop("_score", "_singleton")

    ds = ds.select(**{exome_or_genome: ds.row_value})

    return ds
Example #24
def MAF(AF_list):
    AF_list = AF_list.append(1 - hl.sum(AF_list))
    return (hl.sorted(AF_list)[-2])
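# MAF appends the implied reference-allele frequency and returns the
# second-largest entry, i.e. the minor allele frequency. A toy check (sketch):
import hail as hl

# Alt AFs 0.6 and 0.1 imply a reference AF of 0.3, so the minor AF is 0.3.
print(hl.eval(MAF(hl.array([0.6, 0.1]))))  # ~0.3 (up to float rounding)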
Example #25
def main(args):
    data_type = 'exomes' if args.exomes else 'genomes'

    if args.pbt_tm:
        mt = get_gnomad_data(data_type, split=False)
        meta = mt.cols()
        hq_samples = meta.aggregate(
            hl.agg.filter(meta.meta.high_quality, hl.agg.collect(meta.s)))
        ped = hl.Pedigree.read(fam_path(data_type),
                               delimiter='\\t').filter_to(hq_samples)
        ped_samples = hl.literal(
            set([
                s for trio in ped.complete_trios()
                for s in [trio.s, trio.pat_id, trio.mat_id]
            ]))

        mt = mt.filter_cols(ped_samples.contains(mt.s))
        mt = mt.select_cols().select_rows()
        mt = mt.filter_rows(hl.agg.any(mt.GT.is_non_ref()))

        tm = hl.trio_matrix(mt, ped, complete_trios=True)
        tm = hl.experimental.phase_trio_matrix_by_transmission(tm)
        tm.write(pbt_phased_trios_mt_path(data_type,
                                          split=False,
                                          trio_matrix=True),
                 overwrite=args.overwrite)

    if args.pbt_explode:
        tm = hl.read_matrix_table(
            pbt_phased_trios_mt_path(data_type, split=False, trio_matrix=True))

        tm = tm.annotate_entries(trio_adj=tm.proband_entry.adj
                                 & tm.father_entry.adj & tm.mother_entry.adj)
        pmt = explode_trio_matrix(tm, keep_trio_entries=True)
        pmt = pmt.transmute_entries(trio_adj=pmt.source_trio_entry.trio_adj)
        pmt.write(pbt_phased_trios_mt_path(data_type, split=False),
                  overwrite=args.overwrite)

        pmt = hl.read_matrix_table(
            pbt_phased_trios_mt_path(data_type, split=False))
        pmt = pmt.rename({'PBT_GT': 'PGT'})  # ugly but supported by hl.split_multi_hts
        pmt = hl.split_multi_hts(pmt)
        pmt = pmt.rename({'PGT': 'PBT_GT'})
        pmt.write(pbt_phased_trios_mt_path(data_type),
                  overwrite=args.overwrite)

    if args.phase_multi_families:
        pbt = hl.read_matrix_table(pbt_phased_trios_mt_path(data_type))
        # Keep samples that:
        # 1. There are more than one entry in the Matrix (i.e. they are part of multiple trios)
        # 2. In all their entries, the parents are the same (there are only two exceptions to this, so best to ignore these and focus on parents/multi-offspring families)
        nt_samples = pbt.cols()
        nt_samples = nt_samples.group_by('s').aggregate(
            trios=hl.agg.collect(nt_samples.source_trio))
        nt_samples = nt_samples.filter(
            (hl.len(nt_samples.trios) > 1) &
            nt_samples.trios[1:].any(
                lambda x: (x.mother.s != nt_samples.trios[0].mother.s) |
                          (x.father.s != nt_samples.trios[0].father.s)),
            keep=False)
        pbt = pbt.filter_cols(hl.is_defined(nt_samples[pbt.col_key]))

        # Group cols for these samples, keeping all GTs in an array
        # Compute the consensus GT (incl. phase) + QC metrics based on (a) phased genotypes have priority, (b) genotypes with most votes
        pbt = pbt.group_cols_by('s').aggregate(PBT_GTs=hl.agg.filter(
            hl.is_defined(pbt.GT), hl.agg.collect(pbt.GT)))
        gt_counter = hl.sorted(
            hl.array(pbt.PBT_GTs.group_by(lambda x: x).map_values(
                lambda x: hl.len(x))),
            key=lambda x: x[0].phased * 100 + x[1],
            reverse=True)
        phased_gt_counts = gt_counter.filter(lambda x: x[0].phased).map(
            lambda x: x[1])
        pbt = pbt.annotate_entries(
            consensus_gt=gt_counter.map(lambda x: x[0]).find(lambda x: True),
            phase_concordance=phased_gt_counts.find(lambda x: True) /
            hl.sum(phased_gt_counts),
            discordant_gts=hl.len(
                hl.set(
                    pbt.PBT_GTs.map(lambda x: hl.cond(
                        x.phased, hl.call(x[0], x[1]), x)))) > 1)
        pbt.write('gs://gnomad/projects/compound_hets/pbt_multi_families.mt')
Example #26
def locus_windows(locus_expr, radius, coord_expr=None, _localize=True):
    """Returns start and stop indices for window around each locus.

    Examples
    --------

    Windows with 2bp radius for one contig with positions 1, 2, 3, 4, 5:

    >>> starts, stops = hl.linalg.utils.locus_windows(
    ...     hl.balding_nichols_model(1, 5, 5).locus,
    ...     radius=2)
    >>> starts, stops
    (array([0, 0, 0, 1, 2]), array([3, 4, 5, 5, 5]))

    The following examples involve three contigs.

    >>> loci = [{'locus': hl.Locus('1', 1), 'cm': 1.0},
    ...         {'locus': hl.Locus('1', 2), 'cm': 3.0},
    ...         {'locus': hl.Locus('1', 4), 'cm': 4.0},
    ...         {'locus': hl.Locus('2', 1), 'cm': 2.0},
    ...         {'locus': hl.Locus('2', 1), 'cm': 2.0},
    ...         {'locus': hl.Locus('3', 3), 'cm': 5.0}]

    >>> ht = hl.Table.parallelize(
    ...         loci,
    ...         hl.tstruct(locus=hl.tlocus('GRCh37'), cm=hl.tfloat64),
    ...         key=['locus'])

    Windows with 1bp radius:

    >>> hl.linalg.utils.locus_windows(ht.locus, 1)
    (array([0, 0, 2, 3, 3, 5]), array([2, 2, 3, 5, 5, 6]))

    Windows with 1cm radius:

    >>> hl.linalg.utils.locus_windows(ht.locus, 1.0, coord_expr=ht.cm)
    (array([0, 1, 1, 3, 3, 5]), array([1, 3, 3, 5, 5, 6]))

    Notes
    -----
    This function returns two 1-dimensional ndarrays of integers,
    ``starts`` and ``stops``, each of size equal to the number of rows.

    By default, for all indices ``i``, ``[starts[i], stops[i])`` is the maximal
    range of row indices ``j`` such that ``contig[i] == contig[j]`` and
    ``position[i] - radius <= position[j] <= position[i] + radius``.

    If the :meth:`.global_position` on `locus_expr` is not in ascending order,
    this method will fail. Ascending order should hold for a matrix table keyed
    by locus or variant (and the associated row table), or for a table that has
    been ordered by `locus_expr`.

    Set `coord_expr` to use a value other than position to define the windows.
    This row-indexed numeric expression must be non-missing, non-``nan``, on the
    same source as `locus_expr`, and ascending with respect to locus
    position for each contig; otherwise the function will fail.

    The last example above uses centimorgan coordinates, so
    ``[starts[i], stops[i])`` is the maximal range of row indices ``j`` such
    that ``contig[i] == contig[j]`` and
    ``cm[i] - radius <= cm[j] <= cm[i] + radius``.

    Index ranges are start-inclusive and stop-exclusive. This function is
    especially useful in conjunction with
    :meth:`.BlockMatrix.sparsify_row_intervals`.

    Parameters
    ----------
    locus_expr : :class:`.LocusExpression`
        Row-indexed locus expression on a table or matrix table.
    radius: :obj:`int`
        Radius of window for row values.
    coord_expr: :class:`.Float64Expression`, optional
        Row-indexed numeric expression for the row value.
        Must be on the same table or matrix table as `locus_expr`.
        By default, the row value is given by the locus position.

    Returns
    -------
    (:class:`ndarray` of :obj:`int64`, :class:`ndarray` of :obj:`int64`)
        Tuple of start indices array and stop indices array.
    """
    if radius < 0:
        raise ValueError(f"locus_windows: 'radius' must be non-negative, found {radius}")
    check_row_indexed('locus_windows', locus_expr)
    if coord_expr is not None:
        check_row_indexed('locus_windows', coord_expr)

    src = locus_expr._indices.source
    if locus_expr not in src._fields_inverse:
        locus = Env.get_uid()
        annotate_fields = {locus: locus_expr}

        if coord_expr is not None:
            if coord_expr not in src._fields_inverse:
                coords = Env.get_uid()
                annotate_fields[coords] = coord_expr
            else:
                coords = src._fields_inverse[coord_expr]

        if isinstance(src, hl.MatrixTable):
            new_src = src.annotate_rows(**annotate_fields)
        else:
            new_src = src.annotate(**annotate_fields)

        locus_expr = new_src[locus]
        if coord_expr is not None:
            coord_expr = new_src[coords]

    if coord_expr is None:
        coord_expr = locus_expr.position

    rg = locus_expr.dtype.reference_genome
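    # Collect coordinates per contig, keyed by hl.locus(contig, 1) so the groups
    # can be sorted into reference-genome order before being concatenated below.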
    contig_group_expr = hl.agg.group_by(hl.locus(locus_expr.contig, 1, reference_genome=rg), hl.agg.collect(coord_expr))

    # check loci are in sorted order
    last_pos = hl.fold(lambda a, elt: (hl.case()
                                         .when(a <= elt, elt)
                                         .or_error("locus_windows: 'locus_expr' global position must be in ascending order.")),
                       -1,
                       hl.agg.collect(hl.case()
                                        .when(hl.is_defined(locus_expr), locus_expr.global_position())
                                        .or_error("locus_windows: missing value for 'locus_expr'.")))
    checked_contig_groups = (hl.case()
                               .when(last_pos >= 0, contig_group_expr)
                               .or_error("locus_windows: 'locus_expr' has length 0"))

    contig_groups = locus_expr._aggregation_method()(checked_contig_groups, _localize=False)

    coords = hl.sorted(hl.array(contig_groups)).map(lambda t: t[1])
    starts_and_stops = hl._locus_windows_per_contig(coords, radius)

    if not _localize:
        return starts_and_stops

    starts, stops = hl.eval(starts_and_stops)
    return np.array(starts), np.array(stops)
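The docstring above points at :meth:`.BlockMatrix.sparsify_row_intervals` as the main consumer of these windows. Below is a minimal sketch of that pairing; it assumes a locus-keyed MatrixTable and uses `hl.row_correlation` to build the matrix being banded (the dataset and radius are illustrative only).

import hail as hl

# Simulated variant-by-sample dataset; any locus-keyed MatrixTable works here.
mt = hl.balding_nichols_model(n_populations=1, n_samples=50, n_variants=100)

# For each variant, the range of row indices within 10kb on the same contig.
starts, stops = hl.linalg.utils.locus_windows(mt.locus, radius=10_000)

# Correlation between all pairs of variants, then zero out entries outside each
# row's window so that only local LD is kept.
ld = hl.row_correlation(mt.GT.n_alt_alleles())
banded_ld = ld.sparsify_row_intervals(starts.tolist(), stops.tolist())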
def get_gold_stars(review_status):
    review_status_str = hl.delimit(hl.sorted(review_status, key=lambda s: s.replace("^_", "z")))
    return CLINVAR_GOLD_STARS[review_status_str]
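Note that Hail's `StringExpression.replace` is regex-based, so the sort key `s.replace("^_", "z")` rewrites only a leading underscore, pushing statuses such as `_no_assertion` after the alphabetic ones. A hedged sketch of the lookup this helper assumes follows; the literal is illustrative only, and the real `CLINVAR_GOLD_STARS` presumably maps every delimited combination of sorted review statuses to a star count.

import hail as hl

# Illustrative subset only, not the project's full mapping.
CLINVAR_GOLD_STARS = hl.literal({
    "no_assertion_provided": 0,
    "criteria_provided,_single_submitter": 1,
    "criteria_provided,_multiple_submitters,_no_conflicts": 2,
    "reviewed_by_expert_panel": 3,
    "practice_guideline": 4,
})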
Example #29
0
def populate_clinvar():

    clinvar_release_date = _parse_clinvar_release_date('clinvar.vcf.gz')
    mt = import_vcf('clinvar.vcf.gz',
                    "38",
                    drop_samples=True,
                    min_partitions=2000,
                    skip_invalid_loci=True)
    mt = mt.annotate_globals(version=clinvar_release_date)

    print("\n=== Running VEP ===")
    mt = hl.vep(mt, 'vep85-loftee-ruddle-b38.json', name="vep")

    print("\n=== Processing ===")
    mt = mt.annotate_rows(
        sortedTranscriptConsequences=get_expr_for_vep_sorted_transcript_consequences_array(
            vep_root=mt.vep))

    mt = mt.annotate_rows(
        main_transcript=get_expr_for_worst_transcript_consequence_annotations_struct(
            vep_sorted_transcript_consequences_root=mt.sortedTranscriptConsequences))

    mt = mt.annotate_rows(gene_ids=get_expr_for_vep_gene_ids_set(
        vep_transcript_consequences_root=mt.sortedTranscriptConsequences))

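    # Deduplicate CLNREVSTAT via a set, then sort with underscore-prefixed
    # statuses last (Hail's `replace` is regex-based, so "^_" anchors a
    # leading underscore).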
    review_status_str = hl.delimit(
        hl.sorted(hl.array(hl.set(mt.info.CLNREVSTAT)),
                  key=lambda s: s.replace("^_", "z")))

    mt = mt.select_rows(
        allele_id=mt.info.ALLELEID,
        alt=get_expr_for_alt_allele(mt),
        chrom=get_expr_for_contig(mt.locus),
        clinical_significance=hl.delimit(
            hl.sorted(hl.array(hl.set(mt.info.CLNSIG)),
                      key=lambda s: s.replace("^_", "z"))),
        domains=get_expr_for_vep_protein_domains_set(
            vep_transcript_consequences_root=mt.vep.transcript_consequences),
        gene_ids=mt.gene_ids,
        gene_id_to_consequence_json=get_expr_for_vep_gene_id_to_consequence_map(
            vep_sorted_transcript_consequences_root=mt.sortedTranscriptConsequences,
            gene_ids=mt.gene_ids),
        gold_stars=CLINVAR_GOLD_STARS_LOOKUP[review_status_str],
        **{
            f"main_transcript_{field}": mt.main_transcript[field]
            for field in mt.main_transcript.dtype.fields
        },
        pos=get_expr_for_start_pos(mt),
        ref=get_expr_for_ref_allele(mt),
        review_status=review_status_str,
        transcript_consequence_terms=get_expr_for_vep_consequence_terms_set(
            vep_transcript_consequences_root=mt.sortedTranscriptConsequences),
        transcript_ids=get_expr_for_vep_transcript_ids_set(
            vep_transcript_consequences_root=mt.sortedTranscriptConsequences),
        transcript_id_to_consequence_json=get_expr_for_vep_transcript_id_to_consequence_map(
            vep_transcript_consequences_root=mt.sortedTranscriptConsequences),
        variant_id=get_expr_for_variant_id(mt),
        xpos=get_expr_for_xpos(mt.locus),
    )

    print("\n=== Summary ===")
    hl.summarize_variants(mt)

    # Drop key columns for export
    rows = mt.rows()
    rows = rows.order_by(rows.variant_id).drop("locus", "alleles")
    rows.write('clinvar.ht', overwrite=True)
Example #30
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input-url",
        help="URL of ExAC sites VCF",
        default="gs://gnomad-public/legacy/exac_browser/ExAC.r1.sites.vep.vcf.gz")
    parser.add_argument("--output-url",
                        help="URL to write Hail table to",
                        required=True)
    parser.add_argument("--subset",
                        help="Filter variants to this chrom:start-end range")
    args = parser.parse_args()

    hl.init(log="/tmp/hail.log")

    print("\n=== Importing VCF ===")

    ds = hl.import_vcf(args.input_url,
                       force_bgz=True,
                       min_partitions=2000,
                       skip_invalid_loci=True).rows()

    if args.subset:
        print(f"\n=== Filtering to interval {args.subset} ===")
        subset_interval = hl.parse_locus_interval(args.subset)
        ds = ds.filter(subset_interval.contains(ds.locus))

    print("\n=== Splitting multiallelic variants ===")

    ds = hl.split_multi(ds)
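    # split_multi adds a_index (the 1-based index of this alt allele in the
    # original multiallelic variant) plus old_locus/old_alleles, all used below.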

    ds = ds.repartition(2000, shuffle=True)

    # Get value corresponding to the split variant
    ds = ds.annotate(info=ds.info.annotate(
        **{
            field: hl.or_missing(hl.is_defined(ds.info[field]), ds.info[field][
                ds.a_index - 1])
            for field in PER_ALLELE_FIELDS
        }))

    # For DP_HIST and GQ_HIST, the first value in the array is the histogram for
    # all individuals, which is identical for every alt allele of the original variant.
    ds = ds.annotate(info=ds.info.annotate(
        DP_HIST=hl.struct(all=ds.info.DP_HIST[0],
                          alt=ds.info.DP_HIST[ds.a_index]),
        GQ_HIST=hl.struct(all=ds.info.GQ_HIST[0],
                          alt=ds.info.GQ_HIST[ds.a_index]),
    ))

    ds = ds.cache()

    print("\n=== Munging data ===")

    # Convert "NA" and empty strings into null values
    # Convert fields in chunks to avoid "Method code too large" errors
    for i in range(0, len(SELECT_INFO_FIELDS), 10):
        ds = ds.annotate(info=ds.info.annotate(
            **{
                field: hl.or_missing(
                    hl.is_defined(ds.info[field]),
                    hl.bind(
                        lambda value: hl.cond(
                            (value == "") | (value == "NA"),
                            hl.null(ds.info[field].dtype), ds.info[field]),
                        hl.str(ds.info[field]),
                    ),
                )
                for field in SELECT_INFO_FIELDS[i:i + 10]
            }))

    # Convert field types
    ds = ds.annotate(info=ds.info.annotate(
        **{
            field: hl.cond(ds.info[field] == "", hl.null(hl.tint),
                           hl.int(ds.info[field]))
            for field in CONVERT_TO_INT_FIELDS
        }))
    ds = ds.annotate(info=ds.info.annotate(
        **{
            field: hl.cond(ds.info[field] == "", hl.null(hl.tfloat),
                           hl.float(ds.info[field]))
            for field in CONVERT_TO_FLOAT_FIELDS
        }))

    # Format VEP annotations to mimic the output of hail.vep
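    # VEP percent-encodes characters that are reserved in VCF INFO fields
    # (":", ";", "=", "%", ","); decode them before splitting on "|".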
    ds = ds.annotate(info=ds.info.annotate(CSQ=ds.info.CSQ.map(
        lambda s: s.replace("%3A", ":").replace("%3B", ";").replace(
            "%3D", "=").replace("%25", "%").replace("%2C", ","))))
    ds = ds.annotate(vep=hl.struct(
        transcript_consequences=ds.info.CSQ.map(lambda csq_str: hl.bind(
            lambda csq_values: hl.struct(
                **{
                    field: hl.cond(csq_values[index] == "", hl.null(hl.tstr),
                                   csq_values[index])
                    for index, field in enumerate(VEP_FIELDS)
                }),
            csq_str.split("\\|"),
        )).filter(lambda annotation: annotation.Feature.startswith("ENST"))
        .filter(lambda annotation: hl.int(annotation.ALLELE_NUM) == ds.a_index)
        .map(lambda annotation: annotation.select(
            amino_acids=annotation.Amino_acids,
            biotype=annotation.BIOTYPE,
            canonical=annotation.CANONICAL == "YES",
            # cDNA_position may contain either "start-end" or, when start == end, "start"
            cdna_start=split_position_start(annotation.cDNA_position),
            cdna_end=split_position_end(annotation.cDNA_position),
            codons=annotation.Codons,
            consequence_terms=annotation.Consequence.split("&"),
            distance=hl.int(annotation.DISTANCE),
            domains=hl.or_missing(
                hl.is_defined(annotation.DOMAINS),
                annotation.DOMAINS.split("&").map(lambda d: hl.struct(
                    db=d.split(":")[0], name=d.split(":")[1])),
            ),
            exon=annotation.EXON,
            gene_id=annotation.Gene,
            gene_symbol=annotation.SYMBOL,
            gene_symbol_source=annotation.SYMBOL_SOURCE,
            hgnc_id=annotation.HGNC_ID,
            hgvsc=annotation.HGVSc,
            hgvsp=annotation.HGVSp,
            lof=annotation.LoF,
            lof_filter=annotation.LoF_filter,
            lof_flags=annotation.LoF_flags,
            lof_info=annotation.LoF_info,
            # PolyPhen field contains "polyphen_prediction(polyphen_score)"
            polyphen_prediction=hl.or_missing(
                hl.is_defined(annotation.PolyPhen),
                annotation.PolyPhen.split("\\(")[0]),
            protein_id=annotation.ENSP,
            # Protein_position may contain either "start-end" or, when start == end, "start"
            protein_start=split_position_start(annotation.Protein_position),
            protein_end=split_position_end(annotation.Protein_position),
            # SIFT field contains "sift_prediction(sift_score)"
            sift_prediction=hl.or_missing(hl.is_defined(annotation.SIFT),
                                          annotation.SIFT.split("\\(")[0]),
            transcript_id=annotation.Feature,
        ))))
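    # The hl.bind above follows a general pattern: split the raw CSQ string once
    # on "|", then zip the resulting values with the VEP header fields
    # (VEP_FIELDS) into a struct, nulling out empty strings.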

    ds = ds.annotate(vep=ds.vep.annotate(most_severe_consequence=hl.bind(
        lambda all_consequence_terms: hl.or_missing(
            all_consequence_terms.size() != 0,
            hl.sorted(all_consequence_terms, key=consequence_term_rank)[0]),
        ds.vep.transcript_consequences.flatmap(lambda c: c.consequence_terms),
    )))
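    # consequence_term_rank is assumed from this snippet's home module: it maps a
    # consequence term to its position in a severity-ordered list, so sorting by
    # it and taking element [0] picks the most severe term across transcripts.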

    ds = ds.cache()

    print("\n=== Adding derived fields ===")

    ds = ds.annotate(
        sorted_transcript_consequences=sorted_transcript_consequences_v3(
            ds.vep))

    ds = ds.select(
        "filters",
        "qual",
        "rsid",
        "sorted_transcript_consequences",
        AC=ds.info.AC,
        AC_Adj=ds.info.AC_Adj,
        AC_Hemi=ds.info.AC_Hemi,
        AC_Hom=ds.info.AC_Hom,
        AF=ds.info.AF,
        AN=ds.info.AN,
        AN_Adj=ds.info.AN_Adj,
        BaseQRankSum=ds.info.BaseQRankSum,
        CCC=ds.info.CCC,
        ClippingRankSum=ds.info.ClippingRankSum,
        DB=ds.info.DB,
        DP=ds.info.DP,
        DS=ds.info.DS,
        END=ds.info.END,
        FS=ds.info.FS,
        GQ_MEAN=ds.info.GQ_MEAN,
        GQ_STDDEV=ds.info.GQ_STDDEV,
        HWP=ds.info.HWP,
        HaplotypeScore=ds.info.HaplotypeScore,
        InbreedingCoeff=ds.info.InbreedingCoeff,
        MLEAC=ds.info.MLEAC,
        MLEAF=ds.info.MLEAF,
        MQ=ds.info.MQ,
        MQ0=ds.info.MQ0,
        MQRankSum=ds.info.MQRankSum,
        NCC=ds.info.NCC,
        NEGATIVE_TRAIN_SITE=ds.info.NEGATIVE_TRAIN_SITE,
        POSITIVE_TRAIN_SITE=ds.info.POSITIVE_TRAIN_SITE,
        QD=ds.info.QD,
        ReadPosRankSum=ds.info.ReadPosRankSum,
        VQSLOD=ds.info.VQSLOD,
        culprit=ds.info.culprit,
        DP_HIST=ds.info.DP_HIST,
        GQ_HIST=ds.info.GQ_HIST,
        DOUBLETON_DIST=ds.info.DOUBLETON_DIST,
        AC_CONSANGUINEOUS=ds.info.AC_CONSANGUINEOUS,
        AN_CONSANGUINEOUS=ds.info.AN_CONSANGUINEOUS,
        Hom_CONSANGUINEOUS=ds.info.Hom_CONSANGUINEOUS,
        AGE_HISTOGRAM_HET=ds.info.AGE_HISTOGRAM_HET,
        AGE_HISTOGRAM_HOM=ds.info.AGE_HISTOGRAM_HOM,
        AC_POPMAX=ds.info.AC_POPMAX,
        AN_POPMAX=ds.info.AN_POPMAX,
        POPMAX=ds.info.POPMAX,
        K1_RUN=ds.info.K1_RUN,
        K2_RUN=ds.info.K2_RUN,
        K3_RUN=ds.info.K3_RUN,
        ESP_AF_POPMAX=ds.info.ESP_AF_POPMAX,
        ESP_AF_GLOBAL=ds.info.ESP_AF_GLOBAL,
        ESP_AC=ds.info.ESP_AC,
        KG_AF_POPMAX=ds.info.KG_AF_POPMAX,
        KG_AF_GLOBAL=ds.info.KG_AF_GLOBAL,
        KG_AC=ds.info.KG_AC,
        AC_FEMALE=ds.info.AC_FEMALE,
        AN_FEMALE=ds.info.AN_FEMALE,
        AC_MALE=ds.info.AC_MALE,
        AN_MALE=ds.info.AN_MALE,
        populations=hl.struct(
            **{
                pop_id: hl.struct(
                    AC=ds.info[f"AC_{pop_id}"],
                    AN=ds.info[f"AN_{pop_id}"],
                    hemi=ds.info[f"Hemi_{pop_id}"],
                    hom=ds.info[f"Hom_{pop_id}"],
                )
                for pop_id in
                ["AFR", "AMR", "EAS", "FIN", "NFE", "OTH", "SAS"]
            }),
        colocated_variants=hl.bind(
            lambda this_variant_id: variant_ids(ds.old_locus, ds.old_alleles)
                .filter(lambda v_id: v_id != this_variant_id),
            variant_id(ds.locus, ds.alleles),
        ),
        variant_id=variant_id(ds.locus, ds.alleles),
        xpos=x_position(ds.locus),
    )

    print("\n=== Writing table ===")

    ds.write(args.output_url)
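Several helpers in the `main` pipeline (`split_position_start`, `split_position_end`, `consequence_term_rank`, `variant_id`, `variant_ids`, `x_position`) come from the snippet's home module and are not shown above. The sketches below are hedged reconstructions consistent with the inline comments, not the authoritative implementations; the severity list in particular is a truncated illustration of Ensembl's full consequence ordering.

import hail as hl

def split_position_start(position_str):
    # Fields like cDNA_position hold "start-end" or, when start == end, "start".
    return hl.or_missing(hl.is_defined(position_str),
                         hl.int(position_str.split("-")[0]))

def split_position_end(position_str):
    # Negative indexing takes the last element, covering both formats above.
    return hl.or_missing(hl.is_defined(position_str),
                         hl.int(position_str.split("-")[-1]))

SEVERITY_ORDER = [  # truncated, illustrative ordering
    "transcript_ablation",
    "splice_acceptor_variant",
    "stop_gained",
    "missense_variant",
    "synonymous_variant",
    "intergenic_variant",
]
CONSEQUENCE_TERM_RANK = hl.literal({t: i for i, t in enumerate(SEVERITY_ORDER)})

def consequence_term_rank(term):
    # Lower rank = more severe; unknown terms rank as missing.
    return CONSEQUENCE_TERM_RANK.get(term)

def variant_id(locus, alleles):
    # e.g. "1-55516888-G-GA"
    return (locus.contig + "-" + hl.str(locus.position) +
            "-" + alleles[0] + "-" + alleles[1])

def variant_ids(locus, alleles):
    # One ID per alt allele at the original (possibly multiallelic) site.
    return alleles[1:].map(
        lambda alt: variant_id(locus, hl.array([alleles[0], alt])))

def x_position(locus):
    # One common convention: numeric contig * 1e9 + position yields a single
    # sortable genome-wide coordinate (X/Y/MT need a numeric mapping, elided here).
    return hl.int64(hl.int(locus.contig)) * 1_000_000_000 + locus.position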