Example #1
def combine(ts):
    # pylint: disable=protected-access
    # Note: merge_alleles and renumber_entry are combiner helpers defined alongside this
    # function (a fuller version of combine appears later in this listing).
    tmp = ts.annotate(
        alleles=merge_alleles(ts.data.map(lambda d: d.alleles)),
        rsid=hl.find(hl.is_defined, ts.data.map(lambda d: d.rsid)),
        filters=hl.set(hl.flatten(ts.data.map(lambda d: hl.array(d.filters)))),
        info=hl.struct(
            DP=hl.sum(ts.data.map(lambda d: d.info.DP)),
            MQ_DP=hl.sum(ts.data.map(lambda d: d.info.MQ_DP)),
            QUALapprox=hl.sum(ts.data.map(lambda d: d.info.QUALapprox)),
            RAW_MQ=hl.sum(ts.data.map(lambda d: d.info.RAW_MQ)),
            VarDP=hl.sum(ts.data.map(lambda d: d.info.VarDP)),
            SB=hl.array([
                hl.sum(ts.data.map(lambda d: d.info.SB[0])),
                hl.sum(ts.data.map(lambda d: d.info.SB[1])),
                hl.sum(ts.data.map(lambda d: d.info.SB[2])),
                hl.sum(ts.data.map(lambda d: d.info.SB[3]))
            ])))
    tmp = tmp.annotate(
        __entries=hl.bind(
            lambda combined_allele_index:
            hl.range(0, hl.len(tmp.data)).flatmap(
                lambda i:
                hl.cond(hl.is_missing(tmp.data[i].__entries),
                        hl.range(0, hl.len(tmp.g[i].__cols))
                          .map(lambda _: hl.null(tmp.data[i].__entries.dtype.element_type)),
                        hl.bind(
                            lambda old_to_new: tmp.data[i].__entries.map(lambda e: renumber_entry(e, old_to_new)),
                            hl.range(0, hl.len(tmp.data[i].alleles)).map(
                                lambda j: combined_allele_index[tmp.data[i].alleles[j]])))),
            hl.dict(hl.range(0, hl.len(tmp.alleles)).map(
                lambda j: hl.tuple([tmp.alleles[j], j])))))
    tmp = tmp.annotate_globals(__cols=hl.flatten(tmp.g.map(lambda g: g.__cols)))

    return tmp.drop('data', 'g')
Example #2
File: helpers.py Project: jigold/hail
def create_all_values():
    return hl.struct(
        f32=hl.float32(3.14),
        i64=hl.int64(-9),
        m=hl.null(hl.tfloat64),
        astruct=hl.struct(a=hl.null(hl.tint32), b=5.5),
        mstruct=hl.null(hl.tstruct(x=hl.tint32, y=hl.tstr)),
        aset=hl.set(['foo', 'bar', 'baz']),
        mset=hl.null(hl.tset(hl.tfloat64)),
        d=hl.dict({hl.array(['a', 'b']): 0.5, hl.array(['x', hl.null(hl.tstr), 'z']): 0.3}),
        md=hl.null(hl.tdict(hl.tint32, hl.tstr)),
        h38=hl.locus('chr22', 33878978, 'GRCh38'),
        ml=hl.null(hl.tlocus('GRCh37')),
        i=hl.interval(
            hl.locus('1', 999),
            hl.locus('1', 1001)),
        c=hl.call(0, 1),
        mc=hl.null(hl.tcall),
        t=hl.tuple([hl.call(1, 2, phased=True), 'foo', hl.null(hl.tstr)]),
        mt=hl.null(hl.ttuple(hl.tlocus('GRCh37'), hl.tbool))
    )
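
A minimal usage sketch (not part of the original snippet), assuming hail is imported as hl and an initialized Hail context: the struct built by create_all_values can be evaluated locally and its fields inspected.

import hail as hl

values = hl.eval(create_all_values())  # returns a local hail Struct of concrete Python values
print(values.f32)                      # roughly 3.14 after the float32 round-trip
print(values.aset)                     # set containing 'foo', 'bar', 'baz'
print(values.c)                        # the het call 0/1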
Example #3
def create_all_values_datasets():
    all_values = hl.struct(
        f32=hl.float32(3.14),
        i64=hl.int64(-9),
        m=hl.null(hl.tfloat64),
        astruct=hl.struct(a=hl.null(hl.tint32), b=5.5),
        mstruct=hl.null(hl.tstruct(x=hl.tint32, y=hl.tstr)),
        aset=hl.set(['foo', 'bar', 'baz']),
        mset=hl.null(hl.tset(hl.tfloat64)),
        d=hl.dict({hl.array(['a', 'b']): 0.5, hl.array(['x', hl.null(hl.tstr), 'z']): 0.3}),
        md=hl.null(hl.tdict(hl.tint32, hl.tstr)),
        h38=hl.locus('chr22', 33878978, 'GRCh38'),
        ml=hl.null(hl.tlocus('GRCh37')),
        i=hl.interval(
            hl.locus('1', 999),
            hl.locus('1', 1001)),
        c=hl.call(0, 1),
        mc=hl.null(hl.tcall),
        t=hl.tuple([hl.call(1, 2, phased=True), 'foo', hl.null(hl.tstr)]),
        mt=hl.null(hl.ttuple(hl.tlocus('GRCh37'), hl.tbool))
    )

    def prefix(s, p):
        return hl.struct(**{p + k: s[k] for k in s})

    all_values_table = (hl.utils.range_table(5, n_partitions=3)
                        .annotate_globals(**prefix(all_values, 'global_'))
                        .annotate(**all_values)
                        .cache())

    all_values_matrix_table = (hl.utils.range_matrix_table(3, 2, n_partitions=2)
                               .annotate_globals(**prefix(all_values, 'global_'))
                               .annotate_rows(**prefix(all_values, 'row_'))
                               .annotate_cols(**prefix(all_values, 'col_'))
                               .annotate_entries(**prefix(all_values, 'entry_'))
                               .cache())

    return all_values_table, all_values_matrix_table
Example #4
def full_outer_join_mt(left: hl.MatrixTable, right: hl.MatrixTable) -> hl.MatrixTable:
    """Performs a full outer join on `left` and `right`.

    Replaces row, column, and entry fields with the following:

     - `left_row` / `right_row`: structs of row fields from left and right.
     - `left_col` / `right_col`: structs of column fields from left and right.
     - `left_entry` / `right_entry`: structs of entry fields from left and right.

    Examples
    --------

    The following creates and joins two random datasets with disjoint sample ids
    but non-disjoint variant sets. We use :func:`.or_else` to attempt to find a
    non-missing genotype. If neither genotype is non-missing, then the genotype
    is set to missing. In particular, note that samples `2` and `3` have missing
    genotypes for loci 1:1 and 1:2 because those loci are not present in `mt2`
    and those samples are not present in `mt1`.

    >>> hl.set_global_seed(0)
    >>> mt1 = hl.balding_nichols_model(1, 2, 3)
    >>> mt2 = hl.balding_nichols_model(1, 2, 3)
    >>> mt2 = mt2.key_rows_by(locus=hl.locus(mt2.locus.contig,
    ...                                      mt2.locus.position+2),
    ...                       alleles=mt2.alleles)
    >>> mt2 = mt2.key_cols_by(sample_idx=mt2.sample_idx+2)
    >>> mt1.show()
    +---------------+------------+------+------+
    | locus         | alleles    | 0.GT | 1.GT |
    +---------------+------------+------+------+
    | locus<GRCh37> | array<str> | call | call |
    +---------------+------------+------+------+
    | 1:1           | ["A","C"]  | 0/1  | 0/1  |
    | 1:2           | ["A","C"]  | 1/1  | 1/1  |
    | 1:3           | ["A","C"]  | 0/0  | 0/0  |
    +---------------+------------+------+------+
    <BLANKLINE>
    >>> mt2.show()  # doctest: +SKIP_OUTPUT_CHECK
    +---------------+------------+------+------+
    | locus         | alleles    | 0.GT | 1.GT |
    +---------------+------------+------+------+
    | locus<GRCh37> | array<str> | call | call |
    +---------------+------------+------+------+
    | 1:3           | ["A","C"]  | 0/1  | 1/1  |
    | 1:4           | ["A","C"]  | 0/1  | 0/1  |
    | 1:5           | ["A","C"]  | 1/1  | 0/0  |
    +---------------+------------+------+------+
    <BLANKLINE>
    >>> mt3 = hl.experimental.full_outer_join_mt(mt1, mt2)
    >>> mt3 = mt3.select_entries(GT=hl.or_else(mt3.left_entry.GT, mt3.right_entry.GT))
    >>> mt3.show()
    +---------------+------------+------+------+------+------+
    | locus         | alleles    | 0.GT | 1.GT | 2.GT | 3.GT |
    +---------------+------------+------+------+------+------+
    | locus<GRCh37> | array<str> | call | call | call | call |
    +---------------+------------+------+------+------+------+
    | 1:1           | ["A","C"]  | 0/1  | 0/1  | NA   | NA   |
    | 1:2           | ["A","C"]  | 1/1  | 1/1  | NA   | NA   |
    | 1:3           | ["A","C"]  | 0/0  | 0/0  | 0/1  | 1/1  |
    | 1:4           | ["A","C"]  | NA   | NA   | 0/1  | 0/1  |
    | 1:5           | ["A","C"]  | NA   | NA   | 1/1  | 0/0  |
    +---------------+------------+------+------+------+------+
    <BLANKLINE>

    Parameters
    ----------
    left : :class:`.MatrixTable`
    right : :class:`.MatrixTable`

    Returns
    -------
    :class:`.MatrixTable`
    """

    if [x.dtype for x in left.row_key.values()] != [x.dtype for x in right.row_key.values()]:
        raise ValueError(f"row key types do not match:\n"
                         f"  left:  {list(left.row_key.values())}\n"
                         f"  right: {list(right.row_key.values())}")

    if [x.dtype for x in left.col_key.values()] != [x.dtype for x in right.col_key.values()]:
        raise ValueError(f"column key types do not match:\n"
                         f"  left:  {list(left.col_key.values())}\n"
                         f"  right: {list(right.col_key.values())}")

    left = left.select_rows(left_row=left.row)
    left_t = left.localize_entries('left_entries', 'left_cols')
    right = right.select_rows(right_row=right.row)
    right_t = right.localize_entries('right_entries', 'right_cols')

    ht = left_t.join(right_t, how='outer')
    ht = ht.annotate_globals(
        left_keys=hl.group_by(
            lambda t: t[0],
            hl.enumerate(
                ht.left_cols.map(lambda x: hl.tuple([x[f] for f in left.col_key])), index_first=False)).map_values(
            lambda elts: elts.map(lambda t: t[1])),
        right_keys=hl.group_by(
            lambda t: t[0],
            hl.enumerate(
                ht.right_cols.map(lambda x: hl.tuple([x[f] for f in right.col_key])), index_first=False)).map_values(
            lambda elts: elts.map(lambda t: t[1])))
    ht = ht.annotate_globals(
        key_indices=hl.array(ht.left_keys.key_set().union(ht.right_keys.key_set()))
        .map(lambda k: hl.struct(k=k, left_indices=ht.left_keys.get(k), right_indices=ht.right_keys.get(k)))
        .flatmap(lambda s: hl.case()
                 .when(hl.is_defined(s.left_indices) & hl.is_defined(s.right_indices),
                       hl.range(0, s.left_indices.length()).flatmap(
                           lambda i: hl.range(0, s.right_indices.length()).map(
                               lambda j: hl.struct(k=s.k, left_index=s.left_indices[i],
                                                   right_index=s.right_indices[j]))))
                 .when(hl.is_defined(s.left_indices),
                       s.left_indices.map(
                           lambda elt: hl.struct(k=s.k, left_index=elt, right_index=hl.missing('int32'))))
                 .when(hl.is_defined(s.right_indices),
                       s.right_indices.map(
                           lambda elt: hl.struct(k=s.k, left_index=hl.missing('int32'), right_index=elt)))
                 .or_error('assertion error')))
    ht = ht.annotate(__entries=ht.key_indices.map(lambda s: hl.struct(left_entry=ht.left_entries[s.left_index],
                                                                      right_entry=ht.right_entries[s.right_index])))
    ht = ht.annotate_globals(__cols=ht.key_indices.map(
        lambda s: hl.struct(**{f: s.k[i] for i, f in enumerate(left.col_key)},
                            left_col=ht.left_cols[s.left_index],
                            right_col=ht.right_cols[s.right_index])))
    ht = ht.drop('left_entries', 'left_cols', 'left_keys', 'right_entries', 'right_cols', 'right_keys', 'key_indices')
    return ht._unlocalize_entries('__entries', '__cols', list(left.col_key))
Example #5
def compute_binned_truth_sample_concordance(ht: hl.Table,
                                            binned_score_ht: hl.Table,
                                            n_bins: int = 100) -> hl.Table:
    """
    Determines the concordance (TP, FP, FN) between a truth sample within the callset and the sample's truth data,
    grouped by bins computed using `compute_quantile_bin`.

    .. note::

        The input `ht` should contain three row fields:
            - score: value to use for quantile binning
            - GT: a CallExpression containing the genotype of the evaluation data for the sample
            - truth_GT: a CallExpression containing the genotype of the truth sample

        The input `binned_score_ht` should contain:
             - score: value used to bin the full callset
             - bin: the full callset quantile bin


    The table is grouped by global/truth sample bin and variant type and contains TP, FP and FN.

    :param ht: Input HT
    :param binned_score_ht: Table with an annotation of the quantile bin for each variant
    :param n_bins: Number of bins to bin the data into
    :return: Binned truth sample concordance HT
    """
    # Annotate score and global bin
    indexed_binned_score_ht = binned_score_ht[ht.key]
    ht = ht.annotate(score=indexed_binned_score_ht.score,
                     global_bin=indexed_binned_score_ht.bin)

    # Annotate the truth sample quantile bin
    bin_ht = compute_quantile_bin(
        ht,
        score_expr=ht.score,
        bin_expr={"truth_sample_bin": hl.expr.bool(True)},
        n_bins=n_bins,
    )
    ht = ht.join(bin_ht, how="left")

    # Explode the global and truth sample bins
    ht = ht.annotate(bin=[
        hl.tuple(["global_bin", ht.global_bin]),
        hl.tuple(["truth_sample_bin", ht.truth_sample_bin]),
    ])

    ht = ht.explode(ht.bin)
    ht = ht.annotate(bin_id=ht.bin[0], bin=hl.int(ht.bin[1]))

    # Compute TP, FP and FN by bin_id, variant type and bin
    return (ht.group_by("bin_id", "snv", "bin").aggregate(
        # TP => allele is found in both data sets
        tp=hl.agg.count_where(ht.GT.is_non_ref() & ht.truth_GT.is_non_ref()),
        # FP => allele is found only in test data set
        fp=hl.agg.count_where(ht.GT.is_non_ref()
                              & hl.or_else(ht.truth_GT.is_hom_ref(), True)),
        # FN => allele is found in truth data only
        fn=hl.agg.count_where(ht.GT.is_hom_ref()
                              & hl.or_else(ht.truth_GT.is_non_ref(), True)),
        min_score=hl.agg.min(ht.score),
        max_score=hl.agg.max(ht.score),
        n_alleles=hl.agg.count(),
    ).repartition(5))
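
A hedged usage sketch (table names and the output path are illustrative, not from the original): given a variant-keyed truth_ht carrying score, GT, truth_GT and a boolean snv row field, plus a binned_score_ht carrying score and bin, the binned concordance can be computed and summarized as follows.

conc_ht = compute_binned_truth_sample_concordance(truth_ht, binned_score_ht, n_bins=100)

# Per-bin precision: TP / (TP + FP) within each (bin_id, snv, bin) group.
conc_ht = conc_ht.annotate(precision=conc_ht.tp / hl.float64(conc_ht.tp + conc_ht.fp))
conc_ht.export('truth_sample_concordance.tsv')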
Example #6
def table_expr_take():
    ht = hl.read_table(resource('many_strings_table.ht'))
    hl.tuple([ht.f1, ht.f2]).take(100)
Example #7
# print(mt.aggregate_cols(agg.counter(mt.Gender_Classification)))  # {'0.0': 1291, '1.0': 1244}

pca_eigenvalues, pca_scores, _ = hl.hwe_normalized_pca(mt.GT, k=2)
mt = mt.annotate_cols(pca=pca_scores[mt.s])

x = pca_scores.scores[0]
y = pca_scores.scores[1]
label = mt.cols()[pca_scores.s].Super_Population
collect_all = False  # whether to collect all points instead of downsampling
n_divisions = 500    # downsampling factor used below

if isinstance(x, Expression) and isinstance(y, Expression):
    agg_f = x._aggregation_method()
    if isinstance(label, Expression):
        if collect_all:
            res = hail.tuple([x, y, label]).collect()
            label = [point[2] for point in res]
        else:
            res = agg_f(
                aggregators.downsample(x,
                                       y,
                                       label=label,
                                       n_divisions=n_divisions))
            label = [point[2][0] for point in res]

        x = [point[0] for point in res]
        y = [point[1] for point in res]
    else:
        if collect_all:
            res = hail.tuple([x, y]).collect()
        else:
Example #8
def compute_binned_truth_sample_concordance(
    ht: hl.Table,
    binned_score_ht: hl.Table,
    n_bins: int = 100,
    add_bins: Dict[str, hl.expr.BooleanExpression] = {},
) -> hl.Table:
    """
    Determine the concordance (TP, FP, FN) between a truth sample within the callset and the sample's truth data, grouped by bins computed using `compute_ranked_bin`.

    .. note::
        The input `ht` should contain three row fields:
            - score: value to use for binning
            - GT: a CallExpression containing the genotype of the evaluation data for the sample
            - truth_GT: a CallExpression containing the genotype of the truth sample
        The input `binned_score_ht` should contain:
             - score: value used to bin the full callset
             - bin: the full callset bin

    `add_bins` can be used to add additional global and truth sample binning to the final binned truth sample
    concordance HT. The keys in `add_bins` must be present in `binned_score_ht`, and the values in `add_bins`
    should be expressions on `ht` that define the subset of variants to bin in the truth sample. For example, to
    look at the global and truth sample binning on only bi-allelic variants, `add_bins` could be set to
    `{'biallelic_bin': ht.biallelic}`.

    The table is grouped by global/truth sample bin and variant type and contains TP, FP and FN.

    :param ht: Input HT
    :param binned_score_ht: Table with the bin annotation for each variant
    :param n_bins: Number of bins to bin the data into
    :param add_bins: Dictionary of additional global bin columns (key) and the expr to use for binning the truth sample (value)
    :return: Binned truth sample concordance HT
    """
    # Annotate score and global bin
    indexed_binned_score_ht = binned_score_ht[ht.key]
    ht = ht.annotate(
        **{
            f"global_{bin_id}": indexed_binned_score_ht[bin_id]
            for bin_id in add_bins
        },
        **{f"_{bin_id}": bin_expr
           for bin_id, bin_expr in add_bins.items()},
        score=indexed_binned_score_ht.score,
        global_bin=indexed_binned_score_ht.bin,
    )

    # Annotate the truth sample bin
    bin_ht = compute_ranked_bin(
        ht,
        score_expr=ht.score,
        bin_expr={
            "truth_sample_bin": hl.expr.bool(True),
            **{
                f"truth_sample_{bin_id}": ht[f"_{bin_id}"]
                for bin_id in add_bins
            },
        },
        n_bins=n_bins,
    )
    ht = ht.join(bin_ht, how="left")

    bin_list = [
        hl.tuple(["global_bin", ht.global_bin]),
        hl.tuple(["truth_sample_bin", ht.truth_sample_bin]),
    ]
    bin_list.extend([
        hl.tuple([f"global_{bin_id}", ht[f"global_{bin_id}"]])
        for bin_id in add_bins
    ])
    bin_list.extend([
        hl.tuple([f"truth_sample_{bin_id}", ht[f"truth_sample_{bin_id}"]])
        for bin_id in add_bins
    ])

    # Explode the global and truth sample bins
    ht = ht.annotate(bin=bin_list)

    ht = ht.explode(ht.bin)
    ht = ht.annotate(bin_id=ht.bin[0], bin=hl.int(ht.bin[1]))

    # Compute TP, FP and FN by bin_id, variant type and bin
    return (ht.group_by("bin_id", "snv", "bin").aggregate(
        # TP => allele is found in both data sets
        tp=hl.agg.count_where(ht.GT.is_non_ref() & ht.truth_GT.is_non_ref()),
        # FP => allele is found only in test data set
        fp=hl.agg.count_where(ht.GT.is_non_ref()
                              & hl.or_else(ht.truth_GT.is_hom_ref(), True)),
        # FN => allele is found in truth data only
        fn=hl.agg.count_where(
            hl.or_else(ht.GT.is_hom_ref(), True) & ht.truth_GT.is_non_ref()),
        min_score=hl.agg.min(ht.score),
        max_score=hl.agg.max(ht.score),
        n_alleles=hl.agg.count(),
    ).repartition(5))
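
A sketch of the add_bins mechanism described in the docstring (the biallelic field is hypothetical): if binned_score_ht also carries a biallelic_bin annotation and ht has a boolean biallelic row field, the call below adds global and truth sample bins restricted to that subset.

conc_ht = compute_binned_truth_sample_concordance(
    ht,
    binned_score_ht,
    n_bins=100,
    add_bins={"biallelic_bin": ht.biallelic},  # hypothetical boolean row field on ht
)
# The resulting bin_id column then includes 'global_biallelic_bin' and
# 'truth_sample_biallelic_bin' alongside 'global_bin' and 'truth_sample_bin'.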
Example #9
def table_aggregate_counter(ht_path):
    ht = hl.read_table(ht_path)
    ht.aggregate(hl.tuple([hl.agg.counter(ht[f'f{i}']) for i in range(8)]))
Example #10
def table_expr_take(ht_path):
    ht = hl.read_table(ht_path)
    hl.tuple([ht.f1, ht.f2]).take(100)
Example #11
def combine_datasets(dataset_ids):
    gene_models_path = f"{pipeline_config.get('output', 'staging_path')}/gene_models.ht"
    ds = hl.read_table(gene_models_path)

    ds = ds.annotate(gene_results=hl.struct(), variants=hl.struct())
    ds = ds.annotate_globals(
        meta=hl.struct(variant_fields=VARIANT_FIELDS, datasets=hl.struct()))

    for dataset_id in dataset_ids:
        dataset_path = os.path.join(
            pipeline_config.get("output", "staging_path"), dataset_id.lower())
        gene_results = hl.read_table(
            os.path.join(dataset_path, "gene_results.ht"))

        gene_group_result_field_names = gene_results.group_results.dtype.value_type.fields
        gene_group_result_field_types = [
            str(typ).rstrip("3264")
            for typ in gene_results.group_results.dtype.value_type.types
        ]
        gene_result_analysis_groups = list(
            gene_results.aggregate(
                hl.agg.explode(hl.agg.collect_as_set,
                               gene_results.group_results.keys())))

        gene_results = gene_results.annotate(group_results=hl.array([
            hl.tuple([
                gene_results.group_results.get(group)[field]
                for field in gene_group_result_field_names
            ]) for group in gene_result_analysis_groups
        ]))

        ds = ds.annotate(gene_results=ds.gene_results.annotate(
            **{dataset_id: gene_results[ds.gene_id]}))

        variant_results = hl.read_table(
            os.path.join(dataset_path, "variant_results.ht"))

        reference_genome = variant_results.locus.dtype.reference_genome.name
        variant_info_field_names = variant_results.info.dtype.fields
        variant_info_field_types = [
            str(typ).rstrip("3264") for typ in variant_results.info.dtype.types
        ]
        variant_group_result_field_names = variant_results.group_results.dtype.value_type.fields
        variant_group_result_field_types = [
            str(typ).rstrip("3264")
            for typ in variant_results.group_results.dtype.value_type.types
        ]
        variant_result_analysis_groups = list(
            variant_results.aggregate(
                hl.agg.explode(hl.agg.collect_as_set,
                               variant_results.group_results.keys())))

        variant_results = variant_results.annotate(
            info=hl.tuple([
                variant_results.info[field]
                for field in variant_info_field_names
            ]),
            group_results=hl.array([
                hl.rbind(
                    variant_results.group_results.get(group),
                    lambda group_result: hl.or_missing(
                        hl.is_defined(group_result),
                        hl.tuple([
                            group_result[field]
                            for field in variant_group_result_field_names
                        ]),
                    ),
                ) for group in variant_result_analysis_groups
            ]),
        )

        variant_results = variant_results.annotate(
            variant_id=variant_results.locus.contig.replace("^chr", "") + "-" +
            hl.str(variant_results.locus.position) + "-" +
            variant_results.alleles[0] + "-" + variant_results.alleles[1],
            pos=variant_results.locus.position,
        )

        variant_results = variant_results.annotate(variant=hl.tuple(
            [variant_results[field] for field in VARIANT_FIELDS]))
        variant_results = variant_results.group_by("gene_id").aggregate(
            variants=hl.agg.collect(variant_results.variant))
        ds = ds.annotate(variants=ds.variants.annotate(
            **{
                dataset_id:
                hl.or_else(
                    variant_results[ds.gene_id].variants,
                    hl.empty_array(
                        variant_results.variants.dtype.element_type),
                )
            }))

        ds = ds.annotate_globals(meta=ds.globals.meta.annotate(
            datasets=ds.globals.meta.datasets.annotate(**{
                dataset_id: hl.struct(
                    reference_genome=reference_genome,
                    gene_result_analysis_groups=gene_result_analysis_groups or hl.empty_array(hl.tstr),
                    gene_group_result_field_names=gene_group_result_field_names or hl.empty_array(hl.tstr),
                    gene_group_result_field_types=gene_group_result_field_types or hl.empty_array(hl.tstr),
                    variant_info_field_names=variant_info_field_names or hl.empty_array(hl.tstr),
                    variant_info_field_types=variant_info_field_types or hl.empty_array(hl.tstr),
                    variant_result_analysis_groups=variant_result_analysis_groups or hl.empty_array(hl.tstr),
                    variant_group_result_field_names=variant_group_result_field_names or hl.empty_array(hl.tstr),
                    variant_group_result_field_types=variant_group_result_field_types or hl.empty_array(hl.tstr),
                ),
            })))

    return ds
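
A hypothetical invocation sketch, assuming the staging path already holds gene_models.ht plus gene_results.ht and variant_results.ht tables for each dataset id:

ds = combine_datasets(["dataset_a", "dataset_b"])  # dataset ids are illustrative
ds.describe()
ds.write(f"{pipeline_config.get('output', 'staging_path')}/combined.ht", overwrite=True)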
Example #12
def median_impute_features(
        ht: hl.Table,
        strata: Optional[Dict[str, hl.expr.Expression]] = None) -> hl.Table:
    """
    Numerical features in the Table are median-imputed by Hail's `approx_median`.

    If a `strata` dict is given, imputation is done based on the median of each stratum.

    The annotations that are added to the Table are:
        - feature_imputed - A row annotation indicating if each numerical feature was imputed or not.
        - feature_medians - A global annotation containing the median of the numerical features. If `strata` is given,
          this struct will also be broken down by the given strata.
        - variants_by_strata - An additional global annotation with the variant counts by strata that will only be
          added if imputing by a given `strata`.

    :param ht: Table containing all samples and features for median imputation.
    :param strata: Optional dict of expressions to stratify by; if given, feature medians are computed per stratum.
    :return: Feature Table imputed using approximate median values.
    """

    logger.info(
        "Computing feature medians for imputation of missing numeric values")
    numerical_features = [
        k for k, v in ht.row.dtype.items() if v == hl.tint or v == hl.tfloat
    ]

    median_agg_expr = hl.struct(
        **{
            feature: hl.agg.approx_median(ht[feature])
            for feature in numerical_features
        })

    if strata:
        ht = ht.annotate_globals(
            feature_medians=ht.aggregate(
                hl.agg.group_by(hl.tuple([ht[x] for x in strata]),
                                median_agg_expr),
                _localize=False,
            ),
            variants_by_strata=ht.aggregate(hl.agg.counter(
                hl.tuple([ht[x] for x in strata])),
                                            _localize=False),
        )
        feature_median_expr = ht.feature_medians[hl.tuple(
            [ht[x] for x in strata])]
        logger.info("Variant count by strata:\n{}".format("\n".join([
            "{}: {}".format(k, v)
            for k, v in hl.eval(ht.variants_by_strata).items()
        ])))

    else:
        ht = ht.annotate_globals(
            feature_medians=ht.aggregate(median_agg_expr, _localize=False))
        feature_median_expr = ht.feature_medians

    ht = ht.annotate(
        **{
            f: hl.or_else(ht[f], feature_median_expr[f])
            for f in numerical_features
        },
        feature_imputed=hl.struct(
            **{f: hl.is_missing(ht[f])
               for f in numerical_features}),
    )

    return ht
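
A short usage sketch (the interval stratification field is hypothetical): impute all int/float features in a feature table ht, computing medians per stratum, then count how many rows were imputed for each feature.

ht = median_impute_features(ht, strata={"interval": ht.interval})

# Per-stratum medians live in the globals; per-row imputation flags in feature_imputed.
print(hl.eval(ht.feature_medians))
imputed_counts = ht.aggregate(hl.struct(**{
    f: hl.agg.count_where(ht.feature_imputed[f]) for f in ht.feature_imputed
}))
print(imputed_counts)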
Example #13
 def result(self):
     return hl.tuple(self.l)
Example #14
def manhattan(pvals,
              locus=None,
              title=None,
              size=4,
              hover_fields=None,
              collect_all=False,
              n_divisions=500,
              significance_line=5e-8):
    """Create a Manhattan plot. (https://en.wikipedia.org/wiki/Manhattan_plot)

    Parameters
    ----------
    pvals : :class:`.Float64Expression`
        P-values to be plotted.
    locus : :class:`.LocusExpression`
        Locus values to be plotted.
    title : str
        Title of the plot.
    size : int
        Size of markers in screen space units.
    hover_fields : Dict[str, :class:`.Expression`]
        Dictionary of field names and values to be shown in the HoverTool of the plot.
    collect_all : bool
        Whether to collect all values or downsample before plotting.
    n_divisions : int
        Factor by which to downsample (default value = 500). A lower input results in fewer output datapoints.
    significance_line : float, optional
        p-value at which to add a horizontal, dotted red line indicating
        genome-wide significance.  If ``None``, no line is added.

    Returns
    -------
    :class:`bokeh.plotting.figure.Figure`
    """
    def get_contig_index(x, starts):
        left = 0
        right = len(starts) - 1
        while left <= right:
            mid = (left + right) // 2
            if x < starts[mid]:
                if x >= starts[mid - 1]:
                    return mid - 1
                right = mid
            elif x >= starts[mid + 1]:
                left = mid + 1
            else:
                return mid

    if locus is None:
        locus = pvals._indices.source.locus

    if hover_fields is None:
        hover_fields = {}

    hover_fields['locus'] = hail.str(locus)

    pvals = -hail.log10(pvals)

    if collect_all:
        res = hail.tuple(
            [locus.global_position(), pvals,
             hail.struct(**hover_fields)]).collect()
        hf_struct = [point[2] for point in res]
        for key in hover_fields:
            hover_fields[key] = [item[key] for item in hf_struct]
    else:
        agg_f = pvals._aggregation_method()
        res = agg_f(
            aggregators.downsample(
                locus.global_position(),
                pvals,
                label=hail.array([hail.str(x) for x in hover_fields.values()]),
                n_divisions=n_divisions))
        fields = [point[2] for point in res]
        for idx, key in enumerate(list(hover_fields.keys())):
            hover_fields[key] = [field[idx] for field in fields]

    x = [point[0] for point in res]
    y = [point[1] for point in res]
    y_linear = [10**(-p) for p in y]
    hover_fields['p_value'] = y_linear

    ref = locus.dtype.reference_genome

    total_pos = 0
    start_points = []
    for i in range(0, len(ref.contigs)):
        start_points.append(total_pos)
        total_pos += ref.lengths.get(ref.contigs[i])
    start_points.append(total_pos)  # end point of all contigs

    observed_contigs = set()
    label = []
    for element in x:
        contig_index = get_contig_index(element, start_points)
        label.append(str(contig_index % 2))
        observed_contigs.add(ref.contigs[contig_index])

    labels = ref.contigs.copy()
    num_deleted = 0
    mid_points = []
    for i in range(0, len(ref.contigs)):
        if ref.contigs[i] in observed_contigs:
            length = ref.lengths.get(ref.contigs[i])
            mid = start_points[i] + length / 2
            if mid % 1 == 0:
                mid += 0.5
            mid_points.append(mid)
        else:
            del labels[i - num_deleted]
            num_deleted += 1

    p = scatter(x,
                y,
                label=label,
                title=title,
                xlabel='Chromosome',
                ylabel='P-value (-log10 scale)',
                size=size,
                legend=False,
                source_fields=hover_fields)

    p.xaxis.ticker = mid_points
    p.xaxis.major_label_overrides = dict(zip(mid_points, labels))
    p.width = 1000

    tooltips = [(key, "@{}".format(key)) for key in hover_fields]
    p.add_tools(HoverTool(tooltips=tooltips))

    if significance_line is not None:
        p.renderers.append(
            Span(location=-log10(significance_line),
                 dimension='width',
                 line_color='red',
                 line_dash='dashed',
                 line_width=1.5))

    return p
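
An illustrative call (the gwas table and its fields are hypothetical), assuming bokeh is available for display:

from bokeh.io import show

p = manhattan(gwas.p_value,
              locus=gwas.locus,
              title='GWAS of a simulated phenotype',
              hover_fields={'rsid': gwas.rsid},
              n_divisions=500,
              significance_line=5e-8)
show(p)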
Example #15
def full_outer_join_mt(left: hl.MatrixTable,
                       right: hl.MatrixTable) -> hl.MatrixTable:
    """Performs a full outer join on `left` and `right`.

    Replaces row, column, and entry fields with the following:

     - `left_row` / `right_row`: structs of row fields from left and right.
     - `left_col` / `right_col`: structs of column fields from left and right.
     - `left_entry` / `right_entry`: structs of entry fields from left and right.

    Parameters
    ----------
    left : :class:`.MatrixTable`
    right : :class:`.MatrixTable`

    Returns
    -------
    :class:`.MatrixTable`
    """

    if [x.dtype for x in left.row_key.values()] != [x.dtype for x in right.row_key.values()]:
        raise ValueError(f"row key types do not match:\n"
                         f"  left:  {list(left.row_key.values())}\n"
                         f"  right: {list(right.row_key.values())}")

    if [x.dtype for x in left.col_key.values()] != [x.dtype for x in right.col_key.values()]:
        raise ValueError(f"column key types do not match:\n"
                         f"  left:  {list(left.col_key.values())}\n"
                         f"  right: {list(right.col_key.values())}")

    left = left.select_rows(left_row=left.row)
    left_t = left.localize_entries('left_entries', 'left_cols')
    right = right.select_rows(right_row=right.row)
    right_t = right.localize_entries('right_entries', 'right_cols')

    ht = left_t.join(right_t, how='outer')
    ht = ht.annotate_globals(
        left_keys=hl.group_by(
            lambda t: t[0],
            hl.zip_with_index(
                ht.left_cols.map(lambda x: hl.tuple([x[f] for f in left.col_key])), index_first=False)).map_values(
            lambda elts: elts.map(lambda t: t[1])),
        right_keys=hl.group_by(
            lambda t: t[0],
            hl.zip_with_index(
                ht.right_cols.map(lambda x: hl.tuple([x[f] for f in right.col_key])), index_first=False)).map_values(
            lambda elts: elts.map(lambda t: t[1])))
    ht = ht.annotate_globals(
        key_indices=hl.array(ht.left_keys.key_set().union(ht.right_keys.key_set()))
        .map(lambda k: hl.struct(k=k, left_indices=ht.left_keys.get(k), right_indices=ht.right_keys.get(k)))
        .flatmap(lambda s: hl.case()
                 .when(hl.is_defined(s.left_indices) & hl.is_defined(s.right_indices),
                       hl.range(0, s.left_indices.length()).flatmap(
                           lambda i: hl.range(0, s.right_indices.length()).map(
                               lambda j: hl.struct(k=s.k, left_index=s.left_indices[i],
                                                   right_index=s.right_indices[j]))))
                 .when(hl.is_defined(s.left_indices),
                       s.left_indices.map(
                           lambda elt: hl.struct(k=s.k, left_index=elt, right_index=hl.null('int32'))))
                 .when(hl.is_defined(s.right_indices),
                       s.right_indices.map(
                           lambda elt: hl.struct(k=s.k, left_index=hl.null('int32'), right_index=elt)))
                 .or_error('assertion error')))
    ht = ht.annotate(__entries=ht.key_indices.map(
        lambda s: hl.struct(left_entry=ht.left_entries[s.left_index],
                            right_entry=ht.right_entries[s.right_index])))
    ht = ht.annotate_globals(__cols=ht.key_indices.map(
        lambda s: hl.struct(**{f: s.k[i]
                               for i, f in enumerate(left.col_key)},
                            left_col=ht.left_cols[s.left_index],
                            right_col=ht.right_cols[s.right_index])))
    ht = ht.drop('left_entries', 'left_cols', 'left_keys', 'right_entries',
                 'right_cols', 'right_keys', 'key_indices')
    return ht._unlocalize_entries('__entries', '__cols', list(left.col_key))
Example #16
File: plots.py Project: lfrancioli/hail
def manhattan(pvals, locus=None, title=None, size=4, hover_fields=None, collect_all=False, n_divisions=500):
    """Create a Manhattan plot. (https://en.wikipedia.org/wiki/Manhattan_plot)

    Parameters
    ----------
    pvals : :class:`.Float64Expression`
        P-values to be plotted.
    locus : :class:`.LocusExpression`
        Locus values to be plotted.
    title : str
        Title of the plot.
    size : int
        Size of markers in screen space units.
    hover_fields : Dict[str, :class:`.Expression`]
        Dictionary of field names and values to be shown in the HoverTool of the plot.
    collect_all : bool
        Whether to collect all values or downsample before plotting.
    n_divisions : int
        Factor by which to downsample (default value = 500). A lower input results in fewer output datapoints.

    Returns
    -------
    :class:`bokeh.plotting.figure.Figure`
    """
    def get_contig_index(x, starts):
        left = 0
        right = len(starts) - 1
        while left <= right:
            mid = (left + right) // 2
            if x < starts[mid]:
                if x >= starts[mid - 1]:
                    return mid - 1
                right = mid
            elif x >= starts[mid+1]:
                left = mid + 1
            else:
                return mid

    if locus is None:
        locus = pvals._indices.source.locus

    if hover_fields is None:
        hover_fields = {}

    hover_fields['locus'] = hail.str(locus)

    pvals = -hail.log10(pvals)

    if collect_all:
        res = hail.tuple([locus.global_position(), pvals, hail.struct(**hover_fields)]).collect()
        hf_struct = [point[2] for point in res]
        for key in hover_fields:
            hover_fields[key] = [item[key] for item in hf_struct]
    else:
        agg_f = pvals._aggregation_method()
        res = agg_f(aggregators.downsample(locus.global_position(), pvals,
                                           label=hail.array([hail.str(x) for x in hover_fields.values()]),
                                           n_divisions=n_divisions))
        fields = [point[2] for point in res]
        for idx, key in enumerate(list(hover_fields.keys())):
            hover_fields[key] = [field[idx] for field in fields]

    x = [point[0] for point in res]
    y = [point[1] for point in res]
    y_linear = [10 ** (-p) for p in y]
    hover_fields['p_value'] = y_linear

    ref = locus.dtype.reference_genome

    total_pos = 0
    start_points = []
    for i in range(0, len(ref.contigs)):
        start_points.append(total_pos)
        total_pos += ref.lengths.get(ref.contigs[i])
    start_points.append(total_pos)  # end point of all contigs

    observed_contigs = set()
    label = []
    for element in x:
        contig_index = get_contig_index(element, start_points)
        label.append(str(contig_index % 2))
        observed_contigs.add(ref.contigs[contig_index])

    labels = ref.contigs.copy()
    num_deleted = 0
    mid_points = []
    for i in range(0, len(ref.contigs)):
        if ref.contigs[i] in observed_contigs:
            length = ref.lengths.get(ref.contigs[i])
            mid = start_points[i] + length / 2
            if mid % 1 == 0:
                mid += 0.5
            mid_points.append(mid)
        else:
            del labels[i - num_deleted]
            num_deleted += 1

    p = scatter(x, y, label=label, title=title, xlabel='Chromosome', ylabel='P-value (-log10 scale)',
                size=size, legend=False, source_fields=hover_fields)

    p.xaxis.ticker = mid_points
    p.xaxis.major_label_overrides = dict(zip(mid_points, labels))
    p.width = 1000

    tooltips = [(key, "@{}".format(key)) for key in hover_fields]
    p.add_tools(HoverTool(
        tooltips=tooltips
    ))

    return p
Example #17
File: plots.py Project: lfrancioli/hail
def scatter(x, y, label=None, title=None, xlabel=None, ylabel=None, size=4, legend=True,
            collect_all=False, n_divisions=500, source_fields=None):
    """Create a scatterplot.

    Parameters
    ----------
    x : List[float] or :class:`.Float64Expression`
        List of x-values to be plotted.
    y : List[float] or :class:`.Float64Expression`
        List of y-values to be plotted.
    label : List[str] or :class:`.StringExpression`
        List of labels for x and y values, used to assign each point a label (e.g. population)
    title : str
        Title of the scatterplot.
    xlabel : str
        X-axis label.
    ylabel : str
        Y-axis label.
    size : int
        Size of markers in screen space units.
    legend : bool
        Whether or not to show the legend in the resulting figure.
    collect_all : bool
        Whether to collect all values or downsample before plotting.
        This parameter will be ignored if x and y are Python objects.
    n_divisions : int
        Factor by which to downsample (default value = 500). A lower input results in fewer output datapoints.
    source_fields : Dict[str, List[Any]]
        Extra fields for the ColumnDataSource of the plot.

    Returns
    -------
    :class:`bokeh.plotting.figure.Figure`
    """
    if isinstance(x, Expression) and isinstance(y, Expression):
        agg_f = x._aggregation_method()
        if isinstance(label, Expression):
            if collect_all:
                res = hail.tuple([x, y, label]).collect()
                label = [point[2] for point in res]
            else:
                res = agg_f(aggregators.downsample(x, y, label=label, n_divisions=n_divisions))
                label = [point[2][0] for point in res]

            x = [point[0] for point in res]
            y = [point[1] for point in res]
        else:
            if collect_all:
                res = hail.tuple([x, y]).collect()
            else:
                res = agg_f(aggregators.downsample(x, y, n_divisions=n_divisions))

            x = [point[0] for point in res]
            y = [point[1] for point in res]
    elif isinstance(x, Expression) or isinstance(y, Expression):
        raise TypeError('Invalid input: x and y must both be either Expressions or Python Lists.')
    else:
        if isinstance(label, Expression):
            label = label.collect()

    p = figure(title=title, x_axis_label=xlabel, y_axis_label=ylabel, background_fill_color='#EEEEEE')
    if label is not None:
        fields = dict(x=x, y=y, label=label)
        if source_fields is not None:
            for key, values in source_fields.items():
                fields[key] = values

        source = ColumnDataSource(fields)

        if legend:
            leg = 'label'
        else:
            leg = None

        factors = list(set(label))
        if len(factors) > len(palette):
            color_gen = cycle(palette)
            colors = []
            for i in range(0, len(factors)):
                colors.append(next(color_gen))
        else:
            colors = palette[0:len(factors)]

        color_mapper = CategoricalColorMapper(factors=factors, palette=colors)
        p.circle('x', 'y', alpha=0.5, source=source, size=size,
                 color={'field': 'label', 'transform': color_mapper}, legend=leg)
    else:
        p.circle(x, y, alpha=0.5, size=size)
    return p
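
A brief sketch tying this back to the PCA fragment above (column field names are hypothetical): plot the first two principal components colored by population.

p = scatter(mt.scores[0],
            mt.scores[1],
            label=mt.pop,
            title='PC1 vs PC2',
            xlabel='PC1',
            ylabel='PC2',
            collect_all=False,
            n_divisions=500)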
Example #18
def compute_from_vp_mt(chr20: bool, overwrite: bool):
    meta = get_gnomad_meta('exomes')
    vp_mt = hl.read_matrix_table(full_mt_path('exomes'))
    vp_mt = vp_mt.filter_cols(meta[vp_mt.col_key].release)
    ann_ht = hl.read_table(vp_ann_ht_path('exomes'))
    phase_ht = hl.read_table(phased_vp_count_ht_path('exomes'))

    if chr20:
        vp_mt, ann_ht, phase_ht = filter_to_chr20([vp_mt, ann_ht, phase_ht])

    vep1_expr = get_worst_gene_csq_code_expr(ann_ht.vep1)
    vep2_expr = get_worst_gene_csq_code_expr(ann_ht.vep2)
    ann_ht = ann_ht.select(
        'snv1',
        'snv2',
        is_singleton_vp=(ann_ht.freq1['all'].AC < 2) & (ann_ht.freq2['all'].AC < 2),
        pop_af=hl.dict(
            ann_ht.freq1.key_set().intersection(ann_ht.freq2.key_set())
                .map(
                lambda pop: hl.tuple([pop, hl.max(ann_ht.freq1[pop].AF, ann_ht.freq2[pop].AF)])
            )
        ),
        popmax_af=hl.max(ann_ht.popmax1.AF, ann_ht.popmax2.AF, filter_missing=False),
        filtered=(hl.len(ann_ht.filters1) > 0) | (hl.len(ann_ht.filters2) > 0),
        vep=vep1_expr.keys().filter(
            lambda k: vep2_expr.contains(k)
        ).map(
            lambda k: vep1_expr[k].annotate(
                csq=hl.max(vep1_expr[k].csq, vep2_expr[k].csq)
            )
        )
    )

    vp_mt = vp_mt.annotate_cols(
        pop=meta[vp_mt.col_key].pop
    )
    vp_mt = vp_mt.annotate_rows(
        **ann_ht[vp_mt.row_key],
        phase_info=phase_ht[vp_mt.row_key].phase_info
    )

    vp_mt = vp_mt.filter_rows(
        ~vp_mt.filtered
    )

    vp_mt = vp_mt.filter_entries(
        vp_mt.GT1.is_het() & vp_mt.GT2.is_het() & vp_mt.adj1 & vp_mt.adj2
    )

    vp_mt = vp_mt.select_entries(
        x=True
    )

    vp_mt = vp_mt.annotate_cols(
        pop=['all', vp_mt.pop]
    )
    vp_mt = vp_mt.explode_cols('pop')

    vp_mt = vp_mt.explode_rows('vep')
    vp_mt = vp_mt.transmute_rows(
        **vp_mt.vep
    )

    def get_grouped_phase_agg():
        return hl.agg.group_by(
            hl.case()
                .when(~vp_mt.is_singleton_vp & (vp_mt.phase_info[vp_mt.pop].em.adj.p_chet > CHET_THRESHOLD), 1)
                .when(~vp_mt.is_singleton_vp & (vp_mt.phase_info[vp_mt.pop].em.adj.p_chet < SAME_HAP_THRESHOLD), 2)
                .default(3)
            ,
            hl.agg.min(vp_mt.csq)
        )

    vp_mt = vp_mt.group_rows_by(
        'gene_id',
        'gene_symbol'
    ).aggregate(
        all=hl.agg.filter(
            vp_mt.x &
            hl.if_else(
                vp_mt.pop == 'all',
                hl.is_defined(vp_mt.popmax_af) &
                (vp_mt.popmax_af <= MAX_FREQ),
                vp_mt.pop_af[vp_mt.pop] <= MAX_FREQ
            ),
            get_grouped_phase_agg()
        ),
        af_le_0_001=hl.agg.filter(
            hl.if_else(
                vp_mt.pop == 'all',
                hl.is_defined(vp_mt.popmax_af) &
                (vp_mt.popmax_af <= 0.001),
                vp_mt.pop_af[vp_mt.pop] <= 0.001
            )
            & vp_mt.x,
            get_grouped_phase_agg()
        )
    )

    vp_mt = vp_mt.checkpoint('gs://gnomad-tmp/compound_hets/chet_per_gene{}.2.mt'.format(
        '.chr20' if chr20 else ''
    ), overwrite=True)

    gene_ht = vp_mt.annotate_rows(
        row_counts=hl.flatten([
            hl.array(
                hl.agg.group_by(
                    vp_mt.pop,
                    hl.struct(
                        csq=csq,
                        af=af,
                        # TODO: Review this
                        # These will only keep the worst csq -- now maybe it'd be better to keep either
                        # - the worst csq for chet or
                        # - the worst csq for both chet and same_hap
                        n_worst_chet=hl.agg.count_where(vp_mt[af].get(1) == csq_i),
                        n_chet=hl.agg.count_where((vp_mt[af].get(1) == csq_i) & (vp_mt[af].get(2, 9) >= csq_i) & (vp_mt[af].get(3, 9) >= csq_i)),
                        n_same_hap=hl.agg.count_where((vp_mt[af].get(2) == csq_i) & (vp_mt[af].get(1, 9) > csq_i) & (vp_mt[af].get(3, 9) >= csq_i)),
                        n_unphased=hl.agg.count_where((vp_mt[af].get(3) == csq_i) & (vp_mt[af].get(1, 9) > csq_i) & (vp_mt[af].get(2, 9) > csq_i))
                    )
                )
            ).filter(
                lambda x: (x[1].n_chet > 0) | (x[1].n_same_hap > 0) | (x[1].n_unphased > 0)
            ).map(
                lambda x: x[1].annotate(
                    pop=x[0]
                )
            )
            for csq_i, csq in enumerate(CSQ_CODES)
            for af in ['all', 'af_le_0_001']
        ])
    ).rows()

    gene_ht = gene_ht.explode('row_counts')
    gene_ht = gene_ht.select(
        **gene_ht.row_counts
    )

    gene_ht.describe()
    gene_ht = gene_ht.checkpoint(
        'gs://gnomad-lfran/compound_hets/chet_per_gene{}.ht'.format(
            '.chr20' if chr20 else ''
        ),
        overwrite=overwrite
    )

    gene_ht.flatten().export(
        'gs://gnomad-lfran/compound_hets/chet_per_gene{}.tsv.gz'.format(
            '.chr20' if chr20 else ''
        )
    )
Example #19
def combine(ts):
    def merge_alleles(alleles):
        from hail.expr.functions import _num_allele_type, _allele_ints
        return hl.rbind(
            alleles.map(lambda a: hl.or_else(a[0], ''))
                   .fold(lambda s, t: hl.cond(hl.len(s) > hl.len(t), s, t), ''),
            lambda ref: hl.rbind(
                alleles.map(
                    lambda al: hl.rbind(
                        al[0],
                        lambda r: hl.array([ref]).extend(
                            al[1:].map(
                                lambda a: hl.rbind(
                                    _num_allele_type(r, a),
                                    lambda at: hl.cond(
                                        (_allele_ints['SNP'] == at) |
                                        (_allele_ints['Insertion'] == at) |
                                        (_allele_ints['Deletion'] == at) |
                                        (_allele_ints['MNP'] == at) |
                                        (_allele_ints['Complex'] == at),
                                        a + ref[hl.len(r):],
                                        a)))))),
                lambda lal: hl.struct(
                    globl=hl.array([ref]).extend(hl.array(hl.set(hl.flatten(lal)).remove(ref))),
                    local=lal)))

    def renumber_entry(entry, old_to_new) -> StructExpression:
        # global index of alternate (non-ref) alleles
        return entry.annotate(LA=entry.LA.map(lambda lak: old_to_new[lak]))

    if (ts.row.dtype, ts.globals.dtype) not in _merge_function_map:
        f = hl.experimental.define_function(
            lambda row, gbl: hl.rbind(
                merge_alleles(row.data.map(lambda d: d.alleles)),
                lambda alleles: hl.struct(
                    locus=row.locus,
                    alleles=alleles.globl,
                    rsid=hl.find(hl.is_defined, row.data.map(lambda d: d.rsid)),
                    info=hl.struct(
                        MQ_DP=hl.sum(row.data.map(lambda d: d.info.MQ_DP)),
                        QUALapprox=hl.sum(row.data.map(lambda d: d.info.QUALapprox)),
                        RAW_MQ=hl.sum(row.data.map(lambda d: d.info.RAW_MQ)),
                        VarDP=hl.sum(row.data.map(lambda d: d.info.VarDP)),
                        SB_TABLE=hl.array([
                            hl.sum(row.data.map(lambda d: d.info.SB_TABLE[0])),
                            hl.sum(row.data.map(lambda d: d.info.SB_TABLE[1])),
                            hl.sum(row.data.map(lambda d: d.info.SB_TABLE[2])),
                            hl.sum(row.data.map(lambda d: d.info.SB_TABLE[3]))
                        ])),
                    __entries=hl.bind(
                        lambda combined_allele_index:
                        hl.range(0, hl.len(row.data)).flatmap(
                            lambda i: hl.cond(
                                hl.is_missing(row.data[i].__entries),
                                hl.range(0, hl.len(gbl.g[i].__cols))
                                  .map(lambda _: hl.null(row.data[i].__entries.dtype.element_type)),
                                hl.bind(
                                    lambda old_to_new: row.data[i].__entries.map(
                                        lambda e: renumber_entry(e, old_to_new)),
                                    hl.range(0, hl.len(alleles.local[i])).map(
                                        lambda j: combined_allele_index[alleles.local[i][j]])))),
                        hl.dict(hl.range(0, hl.len(alleles.globl)).map(
                            lambda j: hl.tuple([alleles.globl[j], j])))))),
            ts.row.dtype, ts.globals.dtype)
        _merge_function_map[(ts.row.dtype, ts.globals.dtype)] = f
    merge_function = _merge_function_map[(ts.row.dtype, ts.globals.dtype)]
    ts = Table(
        TableMapRows(
            ts._tir,
            Apply(merge_function._name, TopLevelReference('row'),
                  TopLevelReference('global'))))
    return ts.transmute_globals(
        __cols=hl.flatten(ts.g.map(lambda g: g.__cols)))
Example #20
 def result(self):
     return hl.tuple(self.l)
Example #21
File: misc.py Project: henrydavidge/hail
def locus_windows(locus_expr, radius, coord_expr=None):
    """Returns start and stop indices for window around each locus.

    Examples
    --------

    Windows with 2bp radius for one contig with positions 1, 2, 3, 4, 5:

    >>> starts, stops = hl.linalg.utils.locus_windows(
    ...     hl.balding_nichols_model(1, 5, 5).locus,
    ...     radius=2)
    >>> starts, stops
    (array([0, 0, 0, 1, 2]), array([3, 4, 5, 5, 5]))

    The following examples involve three contigs.

    >>> loci = [{'locus': hl.Locus('1', 1), 'cm': 1.0},
    ...         {'locus': hl.Locus('1', 2), 'cm': 3.0},
    ...         {'locus': hl.Locus('1', 4), 'cm': 4.0},
    ...         {'locus': hl.Locus('2', 1), 'cm': 2.0},
    ...         {'locus': hl.Locus('2', 1), 'cm': 2.0},
    ...         {'locus': hl.Locus('3', 3), 'cm': 5.0}]

    >>> ht = hl.Table.parallelize(
    ...         loci,
    ...         hl.tstruct(locus=hl.tlocus('GRCh37'), cm=hl.tfloat64),
    ...         key=['locus'])

    Windows with 1bp radius:

    >>> hl.linalg.utils.locus_windows(ht.locus, 1)
    (array([0, 0, 2, 3, 3, 5]), array([2, 2, 3, 5, 5, 6]))

    Windows with 1cm radius:

    >>> hl.linalg.utils.locus_windows(ht.locus, 1.0, coord_expr=ht.cm)
    (array([0, 1, 1, 3, 3, 5]), array([1, 3, 3, 5, 5, 6]))

    Notes
    -----
    This function returns two 1-dimensional ndarrays of integers,
    ``starts`` and ``stops``, each of size equal to the number of rows.

    By default, for all indices ``i``, ``[starts[i], stops[i])`` is the maximal
    range of row indices ``j`` such that ``contig[i] == contig[j]`` and
    ``position[i] - radius <= position[j] <= position[i] + radius``.

    If the :meth:`.global_position` on `locus_expr` is not in ascending order,
    this method will fail. Ascending order should hold for a matrix table keyed
    by locus or variant (and the associated row table), or for a table that has
    been ordered by `locus_expr`.

    Set `coord_expr` to use a value other than position to define the windows.
    This row-indexed numeric expression must be non-missing, non-``nan``, on the
    same source as `locus_expr`, and ascending with respect to locus
    position for each contig; otherwise the function will fail.

    The last example above uses centimorgan coordinates, so
    ``[starts[i], stops[i])`` is the maximal range of row indices ``j`` such
    that ``contig[i] == contig[j]`` and
    ``cm[i] - radius <= cm[j] <= cm[i] + radius``.

    Index ranges are start-inclusive and stop-exclusive. This function is
    especially useful in conjunction with
    :meth:`.BlockMatrix.sparsify_row_intervals`.

    Parameters
    ----------
    locus_expr : :class:`.LocusExpression`
        Row-indexed locus expression on a table or matrix table.
    radius: :obj:`int`
        Radius of window for row values.
    coord_expr: :class:`.Float64Expression`, optional
        Row-indexed numeric expression for the row value.
        Must be on the same table or matrix table as `locus_expr`.
        By default, the row value is given by the locus position.

    Returns
    -------
    (:class:`ndarray` of :obj:`int64`, :class:`ndarray` of :obj:`int64`)
        Tuple of start indices array and stop indices array.
    """
    if radius < 0:
        raise ValueError(
            f"locus_windows: 'radius' must be non-negative, found {radius}")
    check_row_indexed('locus_windows', locus_expr)
    if coord_expr is None:
        global_pos_list = locus_expr.global_position().collect()
        n_loci = len(global_pos_list)
        global_pos = np.zeros(n_loci, dtype=np.int64)
        for i, p in enumerate(global_pos_list):
            if p is None:
                raise ValueError(
                    f"locus_windows: missing value for 'locus_expr' global position at row {i}"
                )
            global_pos[i] = p
        coord = global_pos
        del global_pos_list
    else:
        check_row_indexed('locus_windows', coord_expr)
        global_pos_and_coord =\
            hl.tuple([locus_expr.global_position(), coord_expr]).collect()  # raises exception if sources differ
        n_loci = len(global_pos_and_coord)

        global_pos = np.zeros(n_loci, dtype=np.int64)
        coord = np.zeros(n_loci, dtype=np.float64)
        for i, x in enumerate(global_pos_and_coord):
            if x[0] is None:
                raise ValueError(
                    f"locus_windows: missing value for 'locus_expr' global position at row {i}"
                )
            global_pos[i] = x[0]
            if x[1] is None:
                raise ValueError(
                    f"locus_windows: missing value for 'coord_expr' at row {i}"
                )
            coord[i] = x[1]
        del global_pos_and_coord

    if n_loci == 0:
        return np.zeros(shape=0, dtype=np.int64), np.zeros(shape=0,
                                                           dtype=np.int64)

    contig_name = locus_expr.dtype.reference_genome.contigs
    contig_len = locus_expr.dtype.reference_genome.lengths
    contig_cum_len = np.cumsum([contig_len[name] for name in contig_name])

    assert (global_pos[-1] < contig_cum_len[-1])

    contig_start_idx = _compute_contig_start_idx(global_pos, contig_cum_len)
    n_contigs = len(contig_start_idx)
    contig_start_idx.append(n_loci)
    contig_bounds = [
        array_windows(coord[contig_start_idx[c]:contig_start_idx[c + 1]],
                      radius) for c in range(n_contigs)
    ]
    starts = np.concatenate(
        [contig_start_idx[c] + contig_bounds[c][0] for c in range(n_contigs)])
    stops = np.concatenate(
        [contig_start_idx[c] + contig_bounds[c][1] for c in range(n_contigs)])

    return starts, stops
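
# A minimal usage sketch for locus_windows, assuming a matrix table keyed by
# locus (so global positions are ascending); the path and radius below are
# hypothetical.
import hail as hl

mt = hl.read_matrix_table('gs://my-bucket/dataset.mt')  # hypothetical dataset
starts, stops = locus_windows(mt.locus, radius=1_000_000)
# For each row i, rows starts[i]..stops[i]-1 are the rows on the same contig
# within 1 Mb of row i, e.g. as input to BlockMatrix.sparsify_row_intervals.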
示例#22
0
def table_aggregate_take_by_strings(ht_path):
    ht = hl.read_table(ht_path)
    ht.aggregate(
        hl.tuple([
            hl.agg.take(ht['f18'], 25, ordering=ht[f'f{i}']) for i in range(18)
        ]))
示例#23
0
def assert_all_eval_to(*expr_and_expected):
    exprs, expecteds = zip(*expr_and_expected)
    assert_evals_to(hl.tuple(exprs), expecteds)
示例#24
0
def infer_families(
    relationship_ht: hl.Table,
    sex: Union[hl.Table, Dict[str, bool]],
    duplicate_samples_ht: hl.Table,
    i_col: str = "i",
    j_col: str = "j",
    relationship_col: str = "relationship",
) -> hl.Pedigree:
    """
    This function takes a Hail Table with a row for each pair of related individuals (i, j) in the data (unrelated samples may also be present).
    The `relationship_col` should be a column specifying the relationship between each pair of samples as defined in this module's constants.

    This function returns a pedigree containing trios inferred from the data. Family ID can be the same for multiple
    trios if one or more members of the trios are related (e.g. siblings, multi-generational family). Trios are ordered by family ID.

    .. note::

        This function only returns complete trios defined as: one child, one father and one mother (sex is required for both parents).

    :param relationship_ht: Input relationship table
    :param sex: A Table or dict giving the sex for each sample (`TRUE`=female, `FALSE`=male). If a Table is given, it should have a field `is_female`.
    :param duplicate_samples_ht: All duplicated samples to remove (required: the function assumes that each child has exactly two parents)
    :param i_col: Column containing the 1st sample of the pair in the relationship table
    :param j_col: Column containing the 2nd sample of the pair in the relationship table
    :param relationship_col: Column containing the relationship for the sample pair as defined in this module's constants.
    :return: Pedigree of complete trios
    """
    def group_parent_child_pairs_by_fam(
        parent_child_pairs: Iterable[Tuple[str, str]]
    ) -> List[List[Tuple[str, str]]]:
        """
        Takes all parent-children pairs and groups them by family.
        A family here is defined as a list of sample-pairs which all share at least one sample with at least one other sample-pair in the list.

        :param parent_child_pairs: All the parent-children pairs
        :return: A list of families, where each element of the list is a list of the parent-children pairs
        """
        fam_id = 1  # stores the current family id
        s_fam = dict()  # stores the family id for each sample
        fams = defaultdict(list)  # stores fam_id -> sample-pairs
        for pair in parent_child_pairs:
            if pair[0] in s_fam:
                if pair[1] in s_fam:
                    if (
                            s_fam[pair[0]] != s_fam[pair[1]]
                    ):  # If both samples are in different families, merge the families
                        new_fam_id = s_fam[pair[0]]
                        fam_id_to_merge = s_fam[pair[1]]
                        for s in s_fam:
                            if s_fam[s] == fam_id_to_merge:
                                s_fam[s] = new_fam_id
                        fams[new_fam_id].extend(fams.pop(fam_id_to_merge))
                else:  # If only the 1st sample in the pair is already in a family, assign the 2nd sample in the pair to the same family
                    s_fam[pair[1]] = s_fam[pair[0]]
                fams[s_fam[pair[0]]].append(pair)
            elif (
                    pair[1] in s_fam
            ):  # If only the 2nd sample in the pair is already in a family, assign the 1st sample in the pair to the same family
                s_fam[pair[0]] = s_fam[pair[1]]
                fams[s_fam[pair[1]]].append(pair)
            else:  # If none of the samples in the pair is already in a family, create a new family
                s_fam[pair[0]] = fam_id
                s_fam[pair[1]] = fam_id
                fams[fam_id].append(pair)
                fam_id += 1

        return list(fams.values())

    def get_trios(
        fam_id: str,
        parent_child_pairs: List[Tuple[str, str]],
        related_pairs: Dict[Tuple[str, str], str],
    ) -> List[hl.Trio]:
        """
        Generates trios from the list of parent-child pairs in the family and
        all related pairs in the data. Only complete parent/offspring trios are included in the results.

        The trios are assembled as follows:
        1. All pairs of unrelated samples with different sexes within the family are extracted as possible parent pairs
        2. For each possible parent pair, a list of all children is constructed (each child in the list has a parent-offspring pair with each parent)
        3. If there are multiple children for a given parent pair, all children should be siblings with each other
        4. Check that each child was only assigned a single pair of parents. If a child is found to have multiple parent pairs, they are ALL discarded.

        :param fam_id: The family ID
        :param parent_child_pairs: The parent-child pairs for this family
        :param related_pairs: All related sample pairs in the data
        :return: List of trios in the family
        """
        def get_possible_parents(samples: List[str]) -> List[Tuple[str, str]]:
            """
            1. All pairs of unrelated samples with different sexes within the family are extracted as possible parent pairs

            :param samples: All samples in the family
            :return: Possible parent pairs
            """
            possible_parents = []
            for i in range(len(samples)):
                for j in range(i + 1, len(samples)):
                    if (related_pairs.get(
                            tuple(sorted([samples[i], samples[j]]))) is None):
                        if sex.get(samples[i]) is False and sex.get(
                                samples[j]) is True:
                            possible_parents.append((samples[i], samples[j]))
                        elif (sex.get(samples[i]) is True
                              and sex.get(samples[j]) is False):
                            possible_parents.append((samples[j], samples[i]))
            return possible_parents

        def get_children(possible_parents: Tuple[str, str]) -> List[str]:
            """
            2. For a given possible parent pair, a list of all children is constructed (each child in the list has a parent-offspring pair with each parent)

            :param possible_parents: A pair of possible parents
            :return: The list of all children (if any) corresponding to the possible parents
            """
            possible_offsprings = defaultdict(
                set
            )  # stores sample -> set of parents from `possible_parents` with which the sample forms a pair in parent_child_pairs
            for pair in parent_child_pairs:
                if possible_parents[0] == pair[0]:
                    possible_offsprings[pair[1]].add(possible_parents[0])
                elif possible_parents[0] == pair[1]:
                    possible_offsprings[pair[0]].add(possible_parents[0])
                elif possible_parents[1] == pair[0]:
                    possible_offsprings[pair[1]].add(possible_parents[1])
                elif possible_parents[1] == pair[1]:
                    possible_offsprings[pair[0]].add(possible_parents[1])

            return [
                s for s, parents in possible_offsprings.items()
                if len(parents) == 2
            ]

        def check_sibs(children: List[str]) -> bool:
            """
            3. If there are multiple children for a given parent pair, all children should be siblings with each other

            :param children: List of all children for a given parent pair
            :return: Whether all children in the list are siblings
            """
            for i in range(len(children)):
                for j in range(i + 1, len(children)):
                    if (related_pairs[tuple(sorted([children[i], children[j]
                                                    ]))] != SIBLINGS):
                        return False
            return True

        def discard_multi_parents_children(trios: List[hl.Trio]):
            """
            4. Check that each child was only assigned a single pair of parents. If a child is found to have multiple parent pairs, they are ALL discarded.

            :param trios: All trios formed for this family
            :return: The list of trios for which each child has a single parent pair.
            """
            children_trios = defaultdict(list)
            for trio in trios:
                children_trios[trio.s].append(trio)

            for s, s_trios in children_trios.items():
                if len(s_trios) > 1:
                    logger.warning(
                        "Discarded duplicated child {0} found multiple in trios: {1}"
                        .format(s, ", ".join([str(trio) for trio in s_trios])))

            return [
                trios[0] for trios in children_trios.values()
                if len(trios) == 1
            ]

        # Get all possible pairs of parents in (father, mother) order
        all_possible_parents = get_possible_parents(
            list({s
                  for pair in parent_child_pairs for s in pair}))

        trios = []
        for possible_parents in all_possible_parents:
            children = get_children(possible_parents)
            if check_sibs(children):
                trios.extend([
                    hl.Trio(
                        s=s,
                        fam_id=fam_id,
                        pat_id=possible_parents[0],
                        mat_id=possible_parents[1],
                        is_female=sex.get(s),
                    ) for s in children
                ])
            else:
                logger.warning(
                    "Discarded family with same parents, and multiple offspring that weren't siblings:"
                    "\nMother: {}\nFather:{}\nChildren:{}".format(
                        possible_parents[0], possible_parents[1],
                        ", ".join(children)))

        return discard_multi_parents_children(trios)

    # Get all the relations we care about:
    # => Remove unrelateds and duplicates
    dups = duplicate_samples_ht.aggregate(
        hl.agg.explode(lambda dup: hl.agg.collect_as_set(dup),
                       duplicate_samples_ht.filtered),
        _localize=False,
    )
    relationship_ht = relationship_ht.filter(
        ~dups.contains(relationship_ht[i_col])
        & ~dups.contains(relationship_ht[j_col])
        & (relationship_ht[relationship_col] != UNRELATED))

    # Check relatedness table format
    if not relationship_ht[i_col].dtype == relationship_ht[j_col].dtype:
        logger.error(
            "i_col and j_col of the relatedness table need to be of the same type."
        )

    # If i_col and j_col aren't str, then convert them
    if not isinstance(relationship_ht[i_col], hl.expr.StringExpression):
        logger.warning(
            f"Pedigrees can only be constructed from string IDs, but your relatedness_ht ID column is of type: {relationship_ht[i_col].dtype}. Expression will be converted to string in Pedigrees."
        )
        if isinstance(relationship_ht[i_col], hl.expr.StructExpression):
            logger.warning(
                f"Struct fields {list(relationship_ht[i_col])} will be joined by underscores to use as sample names in Pedigree."
            )
            relationship_ht = relationship_ht.key_by(
                **{
                    i_col:
                    hl.delimit(
                        hl.array([
                            hl.str(relationship_ht[i_col][x])
                            for x in relationship_ht[i_col]
                        ]),
                        "_",
                    ),
                    j_col:
                    hl.delimit(
                        hl.array([
                            hl.str(relationship_ht[j_col][x])
                            for x in relationship_ht[j_col]
                        ]),
                        "_",
                    ),
                })
        else:
            raise NotImplementedError(
                "The `i_col` and `j_col` columns of the `relationship_ht` argument passed to infer_families are not of type StringExpression or Struct."
            )

    # If sex is a Table, extract sex information as a Dict
    if isinstance(sex, hl.Table):
        sex = dict(hl.tuple([sex.s, sex.is_female]).collect())

    # Collect all related sample pairs and
    # create a dictionary with pairs as keys and relationships as values
    # Sample-pairs are tuples ordered by sample name
    related_pairs = {
        tuple(sorted([i, j])): rel
        for i, j, rel in hl.tuple([
            relationship_ht.i, relationship_ht.j, relationship_ht.relationship
        ]).collect()
    }

    parent_child_pairs_by_fam = group_parent_child_pairs_by_fam(
        [pair for pair, rel in related_pairs.items() if rel == PARENT_CHILD])
    return hl.Pedigree([
        trio for fam_index, parent_child_pairs in enumerate(
            parent_child_pairs_by_fam) for trio in get_trios(
                str(fam_index), parent_child_pairs, related_pairs)
    ])
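
# A hedged usage sketch for infer_families; the paths and table layouts below are
# assumptions for illustration: the relatedness table carries i/j/relationship
# columns, the sex table has fields `s` and `is_female`, and the duplicates table
# has a `filtered` array of duplicate sample IDs.
import hail as hl

rel_ht = hl.read_table('gs://bucket/relatedness.ht')
sex_ht = hl.read_table('gs://bucket/sex.ht')
dups_ht = hl.read_table('gs://bucket/duplicates.ht')
ped = infer_families(rel_ht, sex_ht, dups_ht)
ped.write('gs://bucket/trios.fam')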
示例#25
0
import hail as hl
from gnomad_qc.v3.resources import get_full_mt

last_END_position_path = 'gs://gnomad/annotations/hail-0.2/ht/genomes_v3/gnomad_genomes_v3_last_END_positions.ht'

# END RESOURCES

mt = get_full_mt(False)
mt = mt.select_entries('END')
t = mt._localize_entries('__entries', '__cols')
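# For each row, scan over every sample's most recent defined (locus, END) pair,
# keep those whose reference block still spans the current locus (same contig,
# END >= current position), and take the minimum start position, defaulting to
# the row's own position when no earlier block overlaps it.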
t = t.select(last_END_position=hl.or_else(
    hl.min(
        hl.scan.array_agg(
            lambda entry: hl.scan._prev_nonnull(
                hl.or_missing(hl.is_defined(entry.END),
                              hl.tuple([t.locus, entry.END]))),
            t.__entries
        ).map(lambda x: hl.or_missing(
            (x[1] >= t.locus.position) & (x[0].contig == t.locus.contig),
            x[0].position))),
    t.locus.position))
t.write(last_END_position_path, overwrite=True)
示例#26
0
 def _all_summary_aggs(self):
     return hl.tuple((hl.agg.filter(hl.is_missing(self), hl.agg.count()),
                      hl.agg.filter(hl.is_defined(self),
                                    hl.agg.count()), self._summary_aggs()))
示例#27
0
def table_aggregate_counter():
    ht = hl.read_table(resource('many_strings_table.ht'))
    ht.aggregate(hl.tuple([hl.agg.counter(ht[f'f{i}']) for i in range(8)]))
示例#28
0
def compute_stratified_metrics_filter(
    ht: hl.Table,
    qc_metrics: Dict[str, hl.expr.NumericExpression],
    strata: Optional[Dict[str, hl.expr.Expression]] = None,
    lower_threshold: float = 4.0,
    upper_threshold: float = 4.0,
    metric_threshold: Optional[Dict[str, Tuple[float, float]]] = None,
    filter_name: str = "qc_metrics_filters",
) -> hl.Table:
    """
    Compute median, MAD, and upper and lower thresholds for each metric used in outlier filtering.

    :param ht: HT containing relevant sample QC metric annotations
    :param qc_metrics: Dictionary of metrics (name -> expression) for which to compute the critical values for filtering outliers
    :param strata: Dictionary of annotations used for stratification. These annotations should be of discrete types!
    :param lower_threshold: Lower MAD threshold
    :param upper_threshold: Upper MAD threshold
    :param metric_threshold: Can be used to specify different (lower, upper) thresholds for one or more metrics
    :param filter_name: Name of resulting filters annotation
    :return: Input Table annotated with a `fail_<metric>` boolean for each metric and the set of failed metrics in `filter_name`; medians, MADs, and thresholds (per stratum, if provided) are stored in the `qc_metrics_stats` global
    """
    _metric_threshold = {
        metric: (lower_threshold, upper_threshold)
        for metric in qc_metrics
    }
    if metric_threshold is not None:
        _metric_threshold.update(metric_threshold)

    def make_filters_expr(ht: hl.Table,
                          qc_metrics: Iterable[str]) -> hl.expr.SetExpression:
        return hl.set(
            hl.filter(
                lambda x: hl.is_defined(x),
                [
                    hl.or_missing(ht[f"fail_{metric}"], metric)
                    for metric in qc_metrics
                ],
            ))

    if strata is None:
        strata = {}

    ht = ht.select(**qc_metrics, **strata).key_by("s").persist()

    agg_expr = hl.struct(
        **{
            metric: hl.bind(
                lambda x: x.annotate(
                    lower=x.median - _metric_threshold[metric][0] * x.mad,
                    upper=x.median + _metric_threshold[metric][1] * x.mad,
                ),
                get_median_and_mad_expr(ht[metric]),
            )
            for metric in qc_metrics
        })

    if strata:
        ht = ht.annotate_globals(qc_metrics_stats=ht.aggregate(
            hl.agg.group_by(hl.tuple([ht[x] for x in strata]), agg_expr),
            _localize=False,
        ))
        metrics_stats_expr = ht.qc_metrics_stats[hl.tuple(
            [ht[x] for x in strata])]
    else:
        ht = ht.annotate_globals(
            qc_metrics_stats=ht.aggregate(agg_expr, _localize=False))
        metrics_stats_expr = ht.qc_metrics_stats

    fail_exprs = {
        f"fail_{metric}": (ht[metric] <= metrics_stats_expr[metric].lower)
        | (ht[metric] >= metrics_stats_expr[metric].upper)
        for metric in qc_metrics
    }
    ht = ht.transmute(**fail_exprs)
    stratified_filters = make_filters_expr(ht, qc_metrics)
    return ht.annotate(**{filter_name: stratified_filters})
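
# A hedged usage sketch: stratified outlier filtering on an assumed sample QC
# table; the path, field names ('pop', 'n_snp', 'r_ti_tv'), and thresholds are
# illustrative assumptions.
import hail as hl

sample_qc_ht = hl.read_table('gs://bucket/sample_qc.ht')  # hypothetical table with fields s, pop, n_snp, r_ti_tv
filter_ht = compute_stratified_metrics_filter(
    sample_qc_ht,
    qc_metrics={'n_snp': sample_qc_ht.n_snp, 'r_ti_tv': sample_qc_ht.r_ti_tv},
    strata={'pop': sample_qc_ht.pop},
    metric_threshold={'n_snp': (4.0, 8.0)},  # looser upper MAD bound for n_snp only
)
# Samples failing any metric carry those metric names in `qc_metrics_filters`.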
示例#29
0
def sample_qc(vds: 'VariantDataset', *, name='sample_qc', gq_bins: 'Sequence[int]' = (0, 20, 60)) -> 'Table':
    """Run sample_qc on dataset in the split :class:`.VariantDataset` representation.

    Parameters
    ----------
    vds : :class:`.VariantDataset`
        Dataset in VariantDataset representation.
    name : :obj:`str`
        Name for resulting field.
    gq_bins : :class:`tuple` of :obj:`int`
        Tuple containing cutoffs for genotype quality (GQ) scores.

    Returns
    -------
    :class:`.Table`
        Hail Table of results, keyed by sample.
    """

    require_first_key_field_locus(vds.reference_data, 'sample_qc')
    require_first_key_field_locus(vds.variant_data, 'sample_qc')

    from hail.expr.functions import _num_allele_type, _allele_types

    allele_types = _allele_types[:]
    allele_types.extend(['Transition', 'Transversion'])
    allele_enum = {i: v for i, v in enumerate(allele_types)}
    allele_ints = {v: k for k, v in allele_enum.items()}

    def allele_type(ref, alt):
        return hl.bind(
            lambda at: hl.if_else(at == allele_ints['SNP'],
                                  hl.if_else(hl.is_transition(ref, alt),
                                             allele_ints['Transition'],
                                             allele_ints['Transversion']),
                                  at),
            _num_allele_type(ref, alt)
        )

    variant_ac = Env.get_uid()
    variant_atypes = Env.get_uid()

    vmt = vds.variant_data
    if 'GT' not in vmt.entry:
        vmt = vmt.annotate_entries(GT=hl.experimental.lgt_to_gt(vmt.LGT, vmt.LA))

    vmt = vmt.annotate_rows(**{
        variant_ac: hl.agg.call_stats(vmt.GT, vmt.alleles).AC,
        variant_atypes: vmt.alleles[1:].map(lambda alt: allele_type(vmt.alleles[0], alt))
    })

    bound_exprs = {}

    bound_exprs['n_het'] = hl.agg.count_where(vmt['GT'].is_het())
    bound_exprs['n_hom_var'] = hl.agg.count_where(vmt['GT'].is_hom_var())
    bound_exprs['n_singleton'] = hl.agg.sum(
        hl.sum(hl.range(0, vmt['GT'].ploidy).map(lambda i: vmt[variant_ac][vmt['GT'][i]] == 1))
    )

    bound_exprs['allele_type_counts'] = hl.agg.explode(
        lambda allele_type: hl.tuple(
            hl.agg.count_where(allele_type == i) for i in range(len(allele_ints))
        ),
        (hl.range(0, vmt['GT'].ploidy)
         .map(lambda i: vmt['GT'][i])
         .filter(lambda allele_idx: allele_idx > 0)
         .map(lambda allele_idx: vmt[variant_atypes][allele_idx - 1]))
    )

    gq_exprs = hl.agg.filter(
        hl.is_defined(vmt.GT),
        hl.struct(**{f'gq_over_{x}': hl.agg.count_where(vmt.GQ > x) for x in gq_bins})
    )

    result_struct = hl.rbind(
        hl.struct(**bound_exprs),
        lambda x: hl.rbind(
            hl.struct(**{
                'gq_exprs': gq_exprs,
                'n_het': x.n_het,
                'n_hom_var': x.n_hom_var,
                'n_non_ref': x.n_het + x.n_hom_var,
                'n_singleton': x.n_singleton,
                'n_snp': (x.allele_type_counts[allele_ints['Transition']]
                          + x.allele_type_counts[allele_ints['Transversion']]),
                'n_insertion': x.allele_type_counts[allele_ints['Insertion']],
                'n_deletion': x.allele_type_counts[allele_ints['Deletion']],
                'n_transition': x.allele_type_counts[allele_ints['Transition']],
                'n_transversion': x.allele_type_counts[allele_ints['Transversion']],
                'n_star': x.allele_type_counts[allele_ints['Star']]
            }),
            lambda s: s.annotate(
                r_ti_tv=divide_null(hl.float64(s.n_transition), s.n_transversion),
                r_het_hom_var=divide_null(hl.float64(s.n_het), s.n_hom_var),
                r_insertion_deletion=divide_null(hl.float64(s.n_insertion), s.n_deletion)
            )
        )
    )
    variant_results = vmt.select_cols(**result_struct).cols()

    rmt = vds.reference_data
    ref_results = rmt.select_cols(
        gq_exprs=hl.struct(**{
            f'gq_over_{x}': hl.agg.filter(rmt.GQ > x, hl.agg.sum(1 + rmt.END - rmt.locus.position)) for x in gq_bins
        })
    ).cols()

    joined = ref_results[variant_results.key].gq_exprs
    joined_results = variant_results.transmute(**{
        f'gq_over_{x}': variant_results.gq_exprs[f'gq_over_{x}'] + joined[f'gq_over_{x}'] for x in gq_bins
    })
    return joined_results
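
# A hedged usage sketch (the VDS path is hypothetical): per-sample QC combining
# the variant and reference components of a split VariantDataset.
import hail as hl

vds = hl.vds.read_vds('gs://bucket/dataset.vds')
qc_ht = sample_qc(vds, gq_bins=(0, 20, 60))
qc_ht.show()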
示例#30
0
def create_binned_concordance(data_type: str, truth_sample: str, metric: str,
                              nbins: int, overwrite: bool) -> None:
    """
    Creates and writes a concordance table binned by rank (both absolute and relative) for a given data type, truth sample and metric.

    :param str data_type: One of 'exomes' or 'genomes'
    :param str truth_sample: Which truth sample concordance to load
    :param str metric: One of the evaluation metrics (or a RF hash)
    :param int nbins: Number of bins for the rank
    :param bool overwrite: Whether to overwrite existing table
    :return: Nothing -- just writes the table
    :rtype: None
    """

    if hl.hadoop_exists(
            binned_concordance_path(data_type, truth_sample, metric) +
            '/_SUCCESS') and not overwrite:
        logger.warn(
            f"Skipping binned concordance creation as {binned_concordance_path(data_type, truth_sample, metric)} exists and overwrite=False"
        )
    else:
        ht = hl.read_table(
            annotations_ht_path(data_type, f'{truth_sample}_concordance'))
        # Remove 1bp indels for syndip as they cannot be trusted
        if truth_sample == 'syndip':
            ht = ht.filter(
                hl.is_indel(ht.alleles[0], ht.alleles[1]) &
                (hl.abs(hl.len(ht.alleles[0]) - hl.len(ht.alleles[1])) == 1),
                keep=False)
            high_conf_intervals = hl.import_locus_intervals(
                syndip_high_conf_regions_bed_path)
        else:
            high_conf_intervals = hl.import_locus_intervals(
                NA12878_high_conf_regions_bed_path)

        lcr = hl.import_locus_intervals(lcr_intervals_path)
        segdup = hl.import_locus_intervals(segdup_intervals_path)
        ht = ht.filter(
            hl.is_defined(high_conf_intervals[ht.locus])
            & hl.is_missing(lcr[ht.locus]) & hl.is_missing(segdup[ht.locus]))

        if metric in ['vqsr', 'rf_2.0.2', 'rf_2.0.2_beta', 'cnn']:
            metric_ht = hl.read_table(score_ranking_path(data_type, metric))
        else:
            metric_ht = hl.read_table(
                rf_path(data_type, 'rf_result', run_hash=metric))

        metric_snvs, metrics_indels = metric_ht.aggregate([
            hl.agg.count_where(
                hl.is_snp(metric_ht.alleles[0], metric_ht.alleles[1])),
            hl.agg.count_where(
                ~hl.is_snp(metric_ht.alleles[0], metric_ht.alleles[1]))
        ])

        snvs, indels = ht.aggregate([
            hl.agg.count_where(hl.is_snp(ht.alleles[0], ht.alleles[1])),
            hl.agg.count_where(~hl.is_snp(ht.alleles[0], ht.alleles[1]))
        ])

        ht = ht.annotate_globals(global_counts=hl.struct(
            snvs=metric_snvs, indels=metrics_indels),
                                 counts=hl.struct(snvs=snvs, indels=indels))

        ht = ht.annotate(
            snv=hl.is_snp(ht.alleles[0], ht.alleles[1]),
            score=metric_ht[ht.key].score,
            global_rank=metric_ht[ht.key].rank,
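            # ht.concordance is the 5x5 genotype concordance matrix whose axes
            # follow hl.concordance's ordering: [no data, no call, hom ref, het,
            # hom var]; indices 3 and 4 therefore select the non-ref genotypes.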
            # TP => allele is found in both data sets
            n_tp=ht.concordance[3][3] + ht.concordance[3][4] +
            ht.concordance[4][3] + ht.concordance[4][4],
            # FP => allele is found only in test data set
            n_fp=hl.sum(ht.concordance[3][:2]) + hl.sum(ht.concordance[4][:2]),
            # FN => allele is found only in truth data set
            n_fn=hl.sum(ht.concordance[:2].map(lambda x: x[3] + x[4])))

        ht = add_rank(ht, -1.0 * ht.score)

        ht = ht.annotate(rank=[
            hl.tuple([
                'global_rank', (ht.global_rank + 1) /
                hl.cond(ht.snv, ht.globals.global_counts.snvs,
                        ht.globals.global_counts.indels)
            ]),
            hl.tuple([
                'truth_sample_rank', (ht.rank + 1) / hl.cond(
                    ht.snv, ht.globals.counts.snvs, ht.globals.counts.indels)
            ])
        ])

        ht = ht.explode(ht.rank)
        ht = ht.annotate(rank_name=ht.rank[0], bin=hl.int(ht.rank[1] * nbins))

        ht = ht.group_by('rank_name', 'snv', 'bin').aggregate(
            # Look at site-level metrics -> tp > fp > fn -- only important for multi-sample comparisons
            tp=hl.agg.count_where(ht.n_tp > 0),
            fp=hl.agg.count_where((ht.n_tp == 0) & (ht.n_fp > 0)),
            fn=hl.agg.count_where((ht.n_tp == 0) & (ht.n_fp == 0)
                                  & (ht.n_fn > 0)),
            min_score=hl.agg.min(ht.score),
            max_score=hl.agg.max(ht.score),
            n_alleles=hl.agg.count()).repartition(5)

        ht.write(binned_concordance_path(data_type, truth_sample, metric),
                 overwrite=overwrite)
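
# A hedged usage sketch; arguments follow the docstring above, and the resource
# paths used internally come from the surrounding (assumed) gnomAD codebase.
create_binned_concordance('exomes', 'syndip', 'vqsr', nbins=100, overwrite=False)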
示例#31
0
def combine(ts):
    def merge_alleles(alleles):
        from hail.expr.functions import _num_allele_type, _allele_ints
        return hl.rbind(
            alleles.map(lambda a: hl.or_else(a[0], ''))
                   .fold(lambda s, t: hl.cond(hl.len(s) > hl.len(t), s, t), ''),
            lambda ref:
            hl.rbind(
                alleles.map(
                    lambda al: hl.rbind(
                        al[0],
                        lambda r:
                        hl.array([ref]).extend(
                            al[1:].map(
                                lambda a:
                                hl.rbind(
                                    _num_allele_type(r, a),
                                    lambda at:
                                    hl.cond(
                                        (_allele_ints['SNP'] == at) |
                                        (_allele_ints['Insertion'] == at) |
                                        (_allele_ints['Deletion'] == at) |
                                        (_allele_ints['MNP'] == at) |
                                        (_allele_ints['Complex'] == at),
                                        a + ref[hl.len(r):],
                                        a)))))),
                lambda lal:
                hl.struct(
                    globl=hl.array([ref]).extend(hl.array(hl.set(hl.flatten(lal)).remove(ref))),
                    local=lal)))

    def renumber_entry(entry, old_to_new) -> StructExpression:
        # global index of alternate (non-ref) alleles
        return entry.annotate(LA=entry.LA.map(lambda lak: old_to_new[lak]))

    if (ts.row.dtype, ts.globals.dtype) not in _merge_function_map:
        f = hl.experimental.define_function(
            lambda row, gbl:
            hl.rbind(
                merge_alleles(row.data.map(lambda d: d.alleles)),
                lambda alleles:
                hl.struct(
                    locus=row.locus,
                    alleles=alleles.globl,
                    rsid=hl.find(hl.is_defined, row.data.map(lambda d: d.rsid)),
                    __entries=hl.bind(
                        lambda combined_allele_index:
                        hl.range(0, hl.len(row.data)).flatmap(
                            lambda i:
                            hl.cond(hl.is_missing(row.data[i].__entries),
                                    hl.range(0, hl.len(gbl.g[i].__cols))
                                      .map(lambda _: hl.null(row.data[i].__entries.dtype.element_type)),
                                    hl.bind(
                                        lambda old_to_new: row.data[i].__entries.map(
                                            lambda e: renumber_entry(e, old_to_new)),
                                        hl.range(0, hl.len(alleles.local[i])).map(
                                            lambda j: combined_allele_index[alleles.local[i][j]])))),
                        hl.dict(hl.range(0, hl.len(alleles.globl)).map(
                            lambda j: hl.tuple([alleles.globl[j], j])))))),
            ts.row.dtype, ts.globals.dtype)
        _merge_function_map[(ts.row.dtype, ts.globals.dtype)] = f
    merge_function = _merge_function_map[(ts.row.dtype, ts.globals.dtype)]
    ts = Table(TableMapRows(ts._tir, Apply(merge_function._name,
                                           TopLevelReference('row'),
                                           TopLevelReference('global'))))
    return ts.transmute_globals(__cols=hl.flatten(ts.g.map(lambda g: g.__cols)))
示例#32
0
 def _coerce(self, x: Expression):
     assert isinstance(x, hl.expr.TupleExpression)
     return hl.tuple(c.coerce(e) for c, e in zip(self.elements, x))
示例#34
0
def table_aggregate_take_by_strings():
    ht = hl.read_table(resource('many_strings_table.ht'))
    ht.aggregate(
        hl.tuple([
            hl.agg.take(ht['f18'], 25, ordering=ht[f'f{i}']) for i in range(18)
        ]))
示例#35
0
def full_outer_join_mt(left: hl.MatrixTable, right: hl.MatrixTable) -> hl.MatrixTable:
    """Performs a full outer join on `left` and `right`.

    Replaces row, column, and entry fields with the following:

     - `left_row` / `right_row`: structs of row fields from left and right.
     - `left_col` / `right_col`: structs of column fields from left and right.
     - `left_entry` / `right_entry`: structs of entry fields from left and right.

    Parameters
    ----------
    left : :class:`.MatrixTable`
    right : :class:`.MatrixTable`

    Returns
    -------
    :class:`.MatrixTable`
    """

    if [x.dtype for x in left.row_key.values()] != [x.dtype for x in right.row_key.values()]:
        raise ValueError(f"row key types do not match:\n"
                         f"  left:  {list(left.row_key.values())}\n"
                         f"  right: {list(right.row_key.values())}")

    if [x.dtype for x in left.col_key.values()] != [x.dtype for x in right.col_key.values()]: 
        raise ValueError(f"column key types do not match:\n"
                         f"  left:  {list(left.col_key.values())}\n"
                         f"  right: {list(right.col_key.values())}")

    left = left.select_rows(left_row=left.row)
    left_t = left.localize_entries('left_entries', 'left_cols')
    right = right.select_rows(right_row=right.row)
    right_t = right.localize_entries('right_entries', 'right_cols')

    ht = left_t.join(right_t, how='outer')
    ht = ht.annotate_globals(
        left_keys=hl.group_by(
            lambda t: t[0],
            hl.zip_with_index(
                ht.left_cols.map(lambda x: hl.tuple([x[f] for f in left.col_key])), index_first=False)).map_values(
            lambda elts: elts.map(lambda t: t[1])),
        right_keys=hl.group_by(
            lambda t: t[0],
            hl.zip_with_index(
                ht.right_cols.map(lambda x: hl.tuple([x[f] for f in right.col_key])), index_first=False)).map_values(
            lambda elts: elts.map(lambda t: t[1])))
    ht = ht.annotate_globals(
        key_indices=hl.array(ht.left_keys.key_set().union(ht.right_keys.key_set()))
            .map(lambda k: hl.struct(k=k, left_indices=ht.left_keys.get(k), right_indices=ht.right_keys.get(k)))
            .flatmap(lambda s: hl.case()
                     .when(hl.is_defined(s.left_indices) & hl.is_defined(s.right_indices),
                           hl.range(0, s.left_indices.length()).flatmap(
                               lambda i: hl.range(0, s.right_indices.length()).map(
                                   lambda j: hl.struct(k=s.k, left_index=s.left_indices[i],
                                                       right_index=s.right_indices[j]))))
                     .when(hl.is_defined(s.left_indices),
                           s.left_indices.map(
                               lambda elt: hl.struct(k=s.k, left_index=elt, right_index=hl.null('int32'))))
                     .when(hl.is_defined(s.right_indices),
                           s.right_indices.map(
                               lambda elt: hl.struct(k=s.k, left_index=hl.null('int32'), right_index=elt)))
                     .or_error('assertion error')))
    ht = ht.annotate(__entries=ht.key_indices.map(lambda s: hl.struct(left_entry=ht.left_entries[s.left_index],
                                                                      right_entry=ht.right_entries[s.right_index])))
    ht = ht.annotate_globals(__cols=ht.key_indices.map(
        lambda s: hl.struct(**{f: s.k[i] for i, f in enumerate(left.col_key)},
                            left_col=ht.left_cols[s.left_index],
                            right_col=ht.right_cols[s.right_index])))
    ht = ht.drop('left_entries', 'left_cols', 'left_keys', 'right_entries', 'right_cols', 'right_keys', 'key_indices')
    return ht._unlocalize_entries('__entries', '__cols', list(left.col_key))
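
# A hedged usage sketch with hypothetical paths; both matrix tables must have
# matching row-key and column-key types, as checked above.
import hail as hl

left = hl.read_matrix_table('gs://bucket/left.mt')
right = hl.read_matrix_table('gs://bucket/right.mt')
joined = full_outer_join_mt(left, right)
# Entry fields are now nested under left_entry / right_entry; for example,
# compare genotypes where both sides define GT (assumes both inputs carry GT):
joined = joined.annotate_entries(
    gt_match=joined.left_entry.GT == joined.right_entry.GT)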
示例#36
0
def test_ndarray_reshape():
    np_single = np.array([8])
    single = hl.nd.array([8])

    np_zero_dim = np.array(4)
    zero_dim = hl.nd.array(4)

    np_a = np.array([1, 2, 3, 4, 5, 6])
    a = hl.nd.array(np_a)

    np_cube = np.array([0, 1, 2, 3, 4, 5, 6, 7]).reshape((2, 2, 2))
    cube = hl.nd.array([0, 1, 2, 3, 4, 5, 6, 7]).reshape((2, 2, 2))
    cube_to_rect = cube.reshape((2, 4))
    np_cube_to_rect = np_cube.reshape((2, 4))
    cube_t_to_rect = cube.transpose((1, 0, 2)).reshape((2, 4))
    np_cube_t_to_rect = np_cube.transpose((1, 0, 2)).reshape((2, 4))

    np_hypercube = np.arange(3 * 5 * 7 * 9).reshape((3, 5, 7, 9))
    hypercube = hl.nd.array(np_hypercube)

    np_shape_zero = np.array([])
    shape_zero = hl.nd.array(np_shape_zero)

    assert_ndarrays_eq(
        (single.reshape(()), np_single.reshape(())),
        (zero_dim.reshape(()), np_zero_dim.reshape(())),
        (zero_dim.reshape((1,)), np_zero_dim.reshape((1,))),
        (a.reshape((6,)), np_a.reshape((6,))),
        (a.reshape((2, 3)), np_a.reshape((2, 3))),
        (a.reshape((3, 2)), np_a.reshape((3, 2))),
        (a.reshape((3, -1)), np_a.reshape((3, -1))),
        (a.reshape((-1, 2)), np_a.reshape((-1, 2))),
        (cube_to_rect, np_cube_to_rect),
        (cube_t_to_rect, np_cube_t_to_rect),
        (hypercube.reshape((5, 7, 9, 3)).reshape((7, 9, 3, 5)),
         np_hypercube.reshape((7, 9, 3, 5))),
        (hypercube.reshape(hl.tuple([5, 7, 9, 3])),
         np_hypercube.reshape((5, 7, 9, 3))),
        (shape_zero.reshape((0, 5)), np_shape_zero.reshape((0, 5))),
        (shape_zero.reshape((-1, 5)), np_shape_zero.reshape((-1, 5))))

    assert hl.eval(hl.null(hl.tndarray(hl.tfloat, 2)).reshape((4, 5))) is None
    assert hl.eval(
        hl.nd.array(hl.range(20)).reshape(
            hl.null(hl.ttuple(hl.tint64, hl.tint64)))) is None

    with pytest.raises(FatalError) as exc:
        hl.eval(hl.literal(np_cube).reshape((-1, -1)))
    assert "more than one -1" in str(exc)

    with pytest.raises(FatalError) as exc:
        hl.eval(hl.literal(np_cube).reshape((20, )))
    assert "requested shape is incompatible with number of elements" in str(
        exc)

    with pytest.raises(FatalError) as exc:
        hl.eval(a.reshape((3, )))
    assert "requested shape is incompatible with number of elements" in str(
        exc)

    with pytest.raises(FatalError) as exc:
        hl.eval(a.reshape(()))
    assert "requested shape is incompatible with number of elements" in str(
        exc)

    with pytest.raises(FatalError) as exc:
        hl.eval(hl.literal(np_cube).reshape((0, 2, 2)))
    assert "requested shape is incompatible with number of elements" in str(
        exc)

    with pytest.raises(FatalError) as exc:
        hl.eval(hl.literal(np_cube).reshape((2, 2, -2)))
    assert "must contain only nonnegative numbers or -1" in str(exc)

    with pytest.raises(FatalError) as exc:
        hl.eval(shape_zero.reshape((0, -1)))
    assert "Can't reshape" in str(exc)
示例#37
0
def write_data_files(table_path, output_directory, genes=None):
    if output_directory.startswith("gs://"):
        raise Exception("Cannot write output to Google Storage")

    ds = hl.read_table(table_path)

    os.makedirs(output_directory, exist_ok=True)

    with open(f"{output_directory}/metadata.json", "w") as output_file:
        output_file.write(hl.eval(hl.json(ds.globals.meta)))

    gene_search_terms = ds.select(data=hl.json(hl.tuple([ds.gene_id, ds.search_terms])))
    gene_search_terms.key_by().select("data").export(f"{output_directory}/gene_search_terms.json.txt", header=False)
    os.remove(f"{output_directory}/.gene_search_terms.json.txt.crc")

    ds = ds.drop("previous_symbols", "alias_symbols", "search_terms")

    os.makedirs(f"{output_directory}/results", exist_ok=True)
    for dataset in ds.globals.meta.datasets.dtype.fields:
        reference_genome = "GRCh38" if dataset == "bipex" else "GRCh37"
        gene_results = ds.filter(hl.is_defined(ds.gene_results[dataset]))
        gene_results = gene_results.select(
            result=hl.tuple(
                [
                    gene_results.gene_id,
                    gene_results.symbol,
                    gene_results.name,
                    gene_results[reference_genome].chrom,
                    (gene_results[reference_genome].start + gene_results[reference_genome].stop) // 2,
                    gene_results.gene_results[dataset].group_results,
                ]
            )
        )
        gene_results = gene_results.collect()

        gene_results = [r.result for r in gene_results]

        with open(f"{output_directory}/results/{dataset.lower()}.json", "w") as output_file:
            output_file.write(json.dumps({"results": gene_results}, cls=ResultEncoder))

    if genes:
        ds = ds.filter(hl.set(genes).contains(ds.gene_id))

    temp_file_name = "temp.tsv"
    n_rows = ds.count()
    ds.select(data=hl.json(ds.row)).export(f"{output_directory}/{temp_file_name}", header=False)

    csv.field_size_limit(sys.maxsize)
    os.makedirs(f"{output_directory}/genes", exist_ok=True)

    with multiprocessing.get_context("spawn").Pool() as pool:
        with open(f"{output_directory}/{temp_file_name}") as data_file:

            reader = csv.reader(data_file, delimiter="\t")
            for gene_id, gene_grch37, gene_grch38, all_variants in tqdm(pool.imap(split_data, reader), total=n_rows):
                num = int(gene_id.lstrip("ENSGR"))
                gene_dir = f"{output_directory}/genes/{str(num % 1000).zfill(3)}"
                os.makedirs(gene_dir, exist_ok=True)

                if gene_grch37:
                    with open(f"{gene_dir}/{gene_id}_GRCh37.json", "w") as out_file:
                        out_file.write(gene_grch37)

                if gene_grch38:
                    with open(f"{gene_dir}/{gene_id}_GRCh38.json", "w") as out_file:
                        out_file.write(gene_grch38)

                for dataset, dataset_variants in all_variants.items():
                    if dataset_variants:
                        with open(f"{gene_dir}/{gene_id}_{dataset.lower()}_variants.json", "w") as out_file:
                            out_file.write(dataset_variants)

    os.remove(f"{output_directory}/{temp_file_name}")
    os.remove(f"{output_directory}/.{temp_file_name}.crc")
示例#38
0
def scatter(x,
            y,
            label=None,
            title=None,
            xlabel=None,
            ylabel=None,
            size=4,
            legend=True,
            source_fields=None):
    """Create a scatterplot.

    Parameters
    ----------
    x : List[float] or :class:`.Float64Expression`
        List of x-values to be plotted.
    y : List[float] or :class:`.Float64Expression`
        List of y-values to be plotted.
    label : List[str] or :class:`.StringExpression`
        List of labels for x and y values, used to assign each point a label (e.g. population)
    title : str
        Title of the scatterplot.
    xlabel : str
        X-axis label.
    ylabel : str
        Y-axis label.
    size : int
        Size of markers in screen space units.
    legend : bool
        Whether or not to show the legend in the resulting figure.
    source_fields : Dict[str, List[Any]]
        Extra fields for the ColumnDataSource of the plot.

    Returns
    -------
    :class:`bokeh.plotting.figure.Figure`
    """
    if isinstance(x, Expression) and isinstance(y, Expression):
        if isinstance(label, Expression):
            res = hail.tuple([x, y, label]).collect()
            x = [point[0] for point in res]
            y = [point[1] for point in res]
            label = [point[2] for point in res]
        else:
            res = hail.tuple([x, y]).collect()
            x = [point[0] for point in res]
            y = [point[1] for point in res]
    elif isinstance(x, Expression) or isinstance(y, Expression):
        raise TypeError(
            'Invalid input: x and y must both be either Expressions or Python Lists.'
        )
    else:
        if isinstance(label, Expression):
            label = label.collect()

    p = figure(title=title,
               x_axis_label=xlabel,
               y_axis_label=ylabel,
               background_fill_color='#EEEEEE')
    if label is not None:
        fields = dict(x=x, y=y, label=label)
        if source_fields is not None:
            for key, values in source_fields.items():
                fields[key] = values

        source = ColumnDataSource(fields)

        if legend:
            leg = 'label'
        else:
            leg = None

        factors = list(set(label))
        if len(factors) > len(palette):
            color_gen = cycle(palette)
            colors = []
            for i in range(0, len(factors)):
                colors.append(next(color_gen))
        else:
            colors = palette[0:len(factors)]

        color_mapper = CategoricalColorMapper(factors=factors, palette=colors)
        p.circle('x',
                 'y',
                 alpha=0.5,
                 source=source,
                 size=size,
                 color={
                     'field': 'label',
                     'transform': color_mapper
                 },
                 legend=leg)
    else:
        p.circle(x, y, alpha=0.5, size=size)
    return p
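
# A hedged usage sketch, assuming a table with numeric fields `x` and `y` and a
# string field `pop`; the path and field names are illustrative. bokeh's `show`
# renders the returned figure.
import hail as hl
from bokeh.io import show

ht = hl.read_table('gs://bucket/samples.ht')
p = scatter(ht.x, ht.y, label=ht.pop,
            title='Sample QC', xlabel='x metric', ylabel='y metric')
show(p)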
示例#39
0
def annotate_unphased_pairs(unphased_ht: hl.Table, n_variant_pairs: int,
                            least_consequence: str, max_af: float):
    # unphased_ht = vp_ht.filter(hl.is_missing(vp_ht.all_phase))
    # unphased_ht = unphased_ht.key_by()

    # Explode variant pairs
    unphased_ht = unphased_ht.annotate(las=[
        hl.tuple([unphased_ht.locus1, unphased_ht.alleles1]),
        hl.tuple([unphased_ht.locus2, unphased_ht.alleles2])
    ]).explode('las', name='la')

    unphased_ht = unphased_ht.key_by(
        locus=unphased_ht.la[0], alleles=unphased_ht.la[1]).persist(
        )  # .checkpoint('gs://gnomad-tmp/vp_ht_unphased.ht')

    # Annotate single variants with gnomAD freq
    gnomad_ht = gnomad.public_release('exomes').ht()
    gnomad_ht = gnomad_ht.semi_join(unphased_ht).repartition(
        ceil(n_variant_pairs / 10000), shuffle=True).persist()

    missing_freq = hl.struct(
        AC=0,
        AF=0,
        AN=125748 * 2,  # set to no missing for now
        homozygote_count=0)

    logger.info(
        f"{gnomad_ht.count()}/{unphased_ht.count()} single variants from the unphased pairs found in gnomAD."
    )

    gnomad_indexed = gnomad_ht[unphased_ht.key]
    gnomad_freq = gnomad_indexed.freq
    unphased_ht = unphased_ht.annotate(
        adj_freq=hl.or_else(gnomad_freq[0], missing_freq),
        raw_freq=hl.or_else(gnomad_freq[1], missing_freq),
        vep_genes=vep_genes_expr(gnomad_indexed.vep, least_consequence),
        max_af_filter=gnomad_indexed.freq[0].AF <= max_af
        # pop_max_freq=hl.or_else(
        #     gnomad_exomes.popmax[0],
        #     missing_freq.annotate(
        #         pop=hl.null(hl.tstr)
        #     )
        # )
    )
    unphased_ht = unphased_ht.persist()
    # unphased_ht = unphased_ht.checkpoint('gs://gnomad-tmp/unphased_ann.ht', overwrite=True)

    loci_expr = hl.sorted(
        hl.agg.collect(
            hl.tuple([
                unphased_ht.locus,
                hl.struct(
                    adj_freq=unphased_ht.adj_freq,
                    raw_freq=unphased_ht.raw_freq,
                    # pop_max_freq=unphased_ht.pop_max_freq
                )
            ])),
        lambda x: x[0]  # sort by locus
    ).map(lambda x: x[1]  # get rid of locus
          )

    vp_freq_expr = hl.struct(v1=loci_expr[0], v2=loci_expr[1])

    # [AABB, AABb, AAbb, AaBB, AaBb, Aabb, aaBB, aaBb, aabb]
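    # These two-locus genotype counts are approximated from single-variant
    # frequencies alone: cells where both variants carry an alternate allele
    # (AaBb, Aabb, aaBb, aabb) are set to 0 because the pair itself was not
    # co-observed; the het/hom cells come from each variant's AC and
    # homozygote_count, and AABB is approximated by the smaller allele number.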
    def get_gt_counts(freq: str):
        return hl.array([
            hl.min(vp_freq_expr.v1[freq].AN, vp_freq_expr.v2[freq].AN),  # AABB
            vp_freq_expr.v2[freq].AC -
            (2 * vp_freq_expr.v2[freq].homozygote_count),  # AABb
            vp_freq_expr.v2[freq].homozygote_count,  # AAbb
            vp_freq_expr.v1[freq].AC -
            (2 * vp_freq_expr.v1[freq].homozygote_count),  # AaBB
            0,  # AaBb
            0,  # Aabb
            vp_freq_expr.v1[freq].homozygote_count,  # aaBB
            0,  # aaBb
            0  # aabb
        ])

    gt_counts_raw_expr = get_gt_counts('raw_freq')
    gt_counts_adj_expr = get_gt_counts('adj_freq')

    # gt_counts_pop_max_expr = get_gt_counts('pop_max_freq')
    unphased_ht = unphased_ht.group_by(
        unphased_ht.locus1, unphased_ht.alleles1, unphased_ht.locus2,
        unphased_ht.alleles2
    ).aggregate(
        pop='all',  # TODO Add option for multiple pops?
        phase_info=hl.struct(gt_counts=hl.struct(raw=gt_counts_raw_expr,
                                                 adj=gt_counts_adj_expr),
                             em=hl.struct(
                                 raw=get_em_expr(gt_counts_raw_expr),
                                 adj=get_em_expr(gt_counts_adj_expr))),
        vep_genes=hl.agg.collect(
            unphased_ht.vep_genes).filter(lambda x: hl.len(x) > 0),
        max_af_filter=hl.agg.all(unphased_ht.max_af_filter)

        # pop_max_gt_counts_adj=gt_counts_raw_expr,
        # pop_max_em_p_chet_adj=get_em_expr(gt_counts_raw_expr).p_chet,
    )  # .key_by()

    unphased_ht = unphased_ht.transmute(
        vep_filter=(hl.len(unphased_ht.vep_genes) > 1)
        & (hl.len(unphased_ht.vep_genes[0].intersection(
            unphased_ht.vep_genes[1])) > 0))

    max_af_filtered, vep_filtered = unphased_ht.aggregate([
        hl.agg.count_where(~unphased_ht.max_af_filter),
        hl.agg.count_where(~unphased_ht.vep_filter)
    ])
    if max_af_filtered > 0:
        logger.info(
            f"{max_af_filtered} variant-pairs excluded because the AF of at least one variant was > {max_af}"
        )
    if vep_filtered > 0:
        logger.info(
            f"{vep_filtered} variant-pairs excluded because the variants were not found within the same gene with a csq of at least {least_consequence}"
        )

    unphased_ht = unphased_ht.filter(unphased_ht.max_af_filter
                                     & unphased_ht.vep_filter)

    return unphased_ht.drop('max_af_filter', 'vep_filter')