Example 1
def flip_text(base):
    """
    :param StringExpression base: Expression of a single base
    :return: StringExpression of flipped base
    :rtype: StringExpression
    """
    return (hl.switch(base).when('A', 'T').when('T', 'A').when('C', 'G').when(
        'G', 'C').default(base))
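A minimal usage sketch (hedged: the table and its fields are invented here, and hail is assumed to be imported and initialized):

import hail as hl

ht = hl.utils.range_table(1)
ht = ht.annotate(ref='A', alt='C')
# flip_text returns a StringExpression, so it composes directly with annotate
ht = ht.annotate(ref_flipped=flip_text(ht.ref),
                 alt_flipped=flip_text(ht.alt))
ht.show()  # ref_flipped: 'T', alt_flipped: 'G'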
Example 2
def fix_alleles(alleles):
    # Use the longest ref allele among the group as the shared reference.
    ref = alleles.map(lambda d: d.ref).fold(
        lambda s, t: hl.if_else(hl.len(s) > hl.len(t), s, t), '')
    # Right-pad each alt with the tail of the shared reference.
    alts = alleles.map(
        lambda a: hl.switch(hl.allele_type(a.ref, a.alt))
        .when('SNP', a.alt + ref[hl.len(a.alt):])
        .when('Insertion', a.alt + ref[hl.len(a.ref):])
        .when('Deletion', a.alt + ref[hl.len(a.ref):])
        .default(a.alt))
    return hl.array([ref]).extend(alts)
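For instance, fix_alleles can be evaluated on a literal array of ref/alt structs (a made-up input, assuming hail is imported as hl):

alleles = hl.array([
    hl.struct(ref='AT', alt='A'),  # deletion relative to 'AT'
    hl.struct(ref='A', alt='G'),   # SNP at the first base
])
# The longest ref ('AT') becomes the shared reference, and shorter alts
# are right-padded with its tail: ['AT', 'A', 'GT']
print(hl.eval(fix_alleles(alleles)))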
Example 3
def flip_base(base: hl.expr.StringExpression) -> hl.expr.StringExpression:
    """
    Returns the complement of a base

    :param StringExpression base: Base to be flipped
    :return: Complement of input base
    :rtype: StringExpression
    """
    return (hl.switch(base).when('A', 'T').when('T', 'A').when('G', 'C').when(
        'C', 'G').default(base))
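Because flip_base works per character, it composes into a reverse-complement helper; the name reverse_complement is hypothetical, and the split('')[:-1] idiom for exploding a string into characters mirrors Example 4 below:

def reverse_complement(seq):
    return hl.delimit(hl.reversed(seq.split('')[:-1]).map(flip_base), '')

print(hl.eval(reverse_complement(hl.str('AACG'))))  # 'CGTT'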
Example 4
def _encode_allele(allele: hl.expr.StringExpression) -> hl.expr.StringExpression:
    return hl.delimit(
        _grouped(
            # Convert string to array
            allele.split("")[:-1]
            # Convert letters to numbers
            .map(lambda letter: hl.switch(letter).when("A", 0).when("C", 1).when("G", 2).when("T", 3).or_missing()),
            3,  # Group into sets of 3
        )
        # Ensure each group has 3 elements
        .map(lambda g: g.extend(hl.range(3 - hl.len(g)).map(lambda _: 0)))
        # Bit shift and add group elements
        .map(lambda g: g[0] * 16 + g[1] * 4 + g[2])
        # Convert to letters
        .map(lambda n: _ENCODED_ALLELE_CHARACTERS[n]),
        "",
    )
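This example leans on two private helpers that are elided from the snippet. A plausible sketch of each, offered as an assumption rather than the source's actual definitions:

import string
import hail as hl

def _grouped(arr, group_size):
    # Split an ArrayExpression into consecutive groups of `group_size`
    # (the last group may be shorter; Hail slicing clamps to the array end).
    return hl.range(0, hl.len(arr), group_size).map(
        lambda i: arr[i:i + group_size])

# Each packed group encodes a value in 0..63, so any 64-character alphabet
# works; the alphabet actually used by the source may differ.
_ENCODED_ALLELE_CHARACTERS = hl.literal(
    string.ascii_uppercase + string.ascii_lowercase + string.digits + '+-')

With these in scope, hl.eval(_encode_allele(hl.str('ACGT'))) packs the four bases into a two-character code.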
Example 5
def main(args):
    hl.init(log='/assign_phecodes.log')

    # Read in the Phecode (v1.2b1) <-> ICD 9/10 codes mapping
    with hadoop_open(
            'gs://ukb-diverse-pops/phecode/UKB_Phecode_v1.2b1_ICD_Mapping.txt',
            'r') as f:
        df = pd.read_csv(f, delimiter='\t', dtype=str)
    list_of_icd_codes_to_include = [
        row.icd_codes.split(',') for _, row in df.iterrows()
    ]
    list_of_phecodes_to_exclude = [
        row.exclude_phecodes.split(',') for _, row in df.iterrows()
    ]
    df['icd_codes'] = list_of_icd_codes_to_include
    df['exclude_phecodes'] = list_of_phecodes_to_exclude

    # Convert it to HailTable
    phecode_ht = hl.Table.from_pandas(df)
    phecode_ht = phecode_ht.key_by('icd_codes')
    phecode_ht = phecode_ht.checkpoint(
        'gs://ukb-diverse-pops/phecode/UKB_Phecode_v1.2b1_ICD_Mapping.ht',
        overwrite=args.overwrite)

    # Retrieve UKB ICD MatrixTable and combine codes based on Phecode definitions
    icd_all = hl.read_matrix_table(get_ukb_pheno_mt_path('icd_all'))
    mt = combine_phenotypes(icd_all,
                            icd_all.icd_code,
                            icd_all.any_codes,
                            list_of_icd_codes_to_include,
                            new_col_name='icd_codes',
                            new_entry_name='include_to_cases')
    mt = mt.annotate_cols(
        phecode=phecode_ht[mt.icd_codes].phecode,
        phecode_sex=phecode_ht[mt.icd_codes].sex,
        phecode_description=phecode_ht[mt.icd_codes].description,
        phecode_group=phecode_ht[mt.icd_codes].group,
        exclude_phecodes=phecode_ht[mt.icd_codes].exclude_phecodes)

    # Annotate sex for sex-specific phenotypes
    ukb_pheno_ht = hl.read_table(get_ukb_pheno_ht_path())
    mt = mt.annotate_rows(isFemale=ukb_pheno_ht[mt.userId].sex == 0)
    mt = checkpoint_tmp(mt)

    # Compute phecode excluded from controls
    mt = mt.key_cols_by()
    exclude_mt = combine_phenotypes(mt,
                                    mt.phecode,
                                    mt.include_to_cases,
                                    list_of_phecodes_to_exclude,
                                    new_entry_name='exclude_from_controls')
    exclude_mt = checkpoint_tmp(exclude_mt)

    # Annotate exclusion
    mt = mt.key_cols_by('exclude_phecodes')
    mt = mt.annotate_entries(
        exclude_sex=(hl.switch(mt.phecode_sex).when("males", mt.isFemale).when(
            "females", ~mt.isFemale).default(False)),
        exclude_from_controls=hl.coalesce(
            exclude_mt[mt.userId, mt.exclude_phecodes].exclude_from_controls,
            False))

    # Compute final case/control status
    # `case_control` becomes missing (NA) if a sample 1) is excluded because of sex, or 2) is not a case and is excluded from controls.
    mt = mt.annotate_entries(case_control=hl.if_else(
        mt.exclude_sex | (~mt.include_to_cases & mt.exclude_from_controls),
        hl.missing(hl.tbool), mt.include_to_cases))

    mt = mt.key_cols_by('phecode')
    mt.describe()

    mt.write(get_ukb_pheno_mt_path('phecode'), overwrite=args.overwrite)
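The sex-exclusion rule above is this page's recurring hl.switch pattern with a .default fallback. A self-contained toy version (field names mirror the example, but the data is synthetic):

import hail as hl

mt = hl.utils.range_matrix_table(n_rows=2, n_cols=2)
mt = mt.annotate_cols(phecode_sex=hl.literal(['males', 'both'])[mt.col_idx])
mt = mt.annotate_rows(isFemale=hl.literal([True, False])[mt.row_idx])
mt = mt.annotate_entries(
    exclude_sex=hl.switch(mt.phecode_sex)
    .when('males', mt.isFemale)     # male-only phenotype: exclude females
    .when('females', ~mt.isFemale)  # female-only phenotype: exclude males
    .default(False))                # applies to both sexes: exclude no one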
Example 6
def main():
    args = parse_args()

    tables = []
    for i, path in enumerate(args.paths):

        ht = import_SJ_out_tab(path)
        ht = ht.key_by("chrom", "start_1based", "end_1based")

        if args.normalize_read_counts:
            ht = ht.annotate_globals(
                unique_reads_in_sample=ht.aggregate(hl.agg.sum(
                    ht.unique_reads)),
                multi_mapped_reads_in_sample=ht.aggregate(
                    hl.agg.sum(ht.multi_mapped_reads)),
            )

        # add 'interval' column
        #ht = ht.annotate(interval=hl.interval(
        #    hl.locus(ht.chrom, ht.start_1based, reference_genome=reference_genome),
        #    hl.locus(ht.chrom, ht.end_1based, reference_genome=reference_genome),))

        tables.append(ht)

    # compute mean
    if args.normalize_read_counts:
        mean_unique_reads_in_sample = sum(
            [hl.eval(ht.unique_reads_in_sample)
             for ht in tables]) / float(len(tables))
        mean_multi_mapped_reads_in_sample = sum(
            [hl.eval(ht.multi_mapped_reads_in_sample)
             for ht in tables]) / float(len(tables))
        print(
            f"mean_unique_reads_in_sample: {mean_unique_reads_in_sample:01f}, mean_multi_mapped_reads_in_sample: {mean_multi_mapped_reads_in_sample:01f}"
        )

    combined_ht = None
    for i, ht in enumerate(tables):
        print(f"Processing table #{i} out of {len(tables)}")

        if args.normalize_read_counts:
            unique_reads_multiplier = mean_unique_reads_in_sample / float(
                hl.eval(ht.unique_reads_in_sample))
            multi_mapped_reads_multiplier = mean_multi_mapped_reads_in_sample / float(
                hl.eval(ht.multi_mapped_reads_in_sample))
            print(
                f"unique_reads_multiplier: {unique_reads_multiplier:01f}, multi_mapped_reads_multiplier: {multi_mapped_reads_multiplier:01f}"
            )

        ht = ht.annotate(
            strand_counter=hl.or_else(
                hl.switch(ht.strand).when(1, 1).when(2, -1).or_missing(), 0),
            num_samples_with_this_junction=1,
        )

        if args.normalize_read_counts:
            ht = ht.annotate(
                unique_reads=hl.int32(ht.unique_reads *
                                      unique_reads_multiplier),
                multi_mapped_reads=hl.int32(ht.multi_mapped_reads *
                                            multi_mapped_reads_multiplier),
            )

        if combined_ht is None:
            combined_ht = ht
            continue

        print("----")
        print_stats(path, ht)

        combined_ht = combined_ht.join(ht, how="outer")
        combined_ht = combined_ht.transmute(
            strand=hl.or_else(
                combined_ht.strand, combined_ht.strand_1
            ),  ## in rare cases, the strand for the same junction may differ across samples, so use a 2-step process that assigns strand based on majority of samples
            strand_counter=hl.sum([
                combined_ht.strand_counter, combined_ht.strand_counter_1
            ]),  # samples vote on whether strand = 1 (e.g. '+') or 2 (e.g. '-')
            intron_motif=hl.or_else(combined_ht.intron_motif,
                                    combined_ht.intron_motif_1
                                    ),  ## double-check that left == right?
            known_splice_junction=hl.or_else(
                hl.if_else((combined_ht.known_splice_junction == 1) |
                           (combined_ht.known_splice_junction_1 == 1), 1, 0),
                0),  ## double-check that left == right?
            unique_reads=hl.sum(
                [combined_ht.unique_reads, combined_ht.unique_reads_1]),
            multi_mapped_reads=hl.sum([
                combined_ht.multi_mapped_reads,
                combined_ht.multi_mapped_reads_1
            ]),
            maximum_overhang=hl.max(
                [combined_ht.maximum_overhang,
                 combined_ht.maximum_overhang_1]),
            num_samples_with_this_junction=hl.sum([
                combined_ht.num_samples_with_this_junction,
                combined_ht.num_samples_with_this_junction_1
            ]),
        )

        combined_ht = combined_ht.checkpoint(
            f"checkpoint{i % 2}.ht", overwrite=True)  #, _read_if_exists=True)

    total_junctions_count = combined_ht.count()
    strand_conflicts_count = combined_ht.filter(
        hl.abs(combined_ht.strand_counter) /
        hl.float(combined_ht.num_samples_with_this_junction) < 0.1,
        keep=True).count()

    # set the final strand value to 1 (e.g. '+'), 2 (e.g. '-'), or 0 (unknown) based on the setting in the majority of samples
    combined_ht = combined_ht.annotate(
        strand=hl.case().when(combined_ht.strand_counter > 0, 1).when(
            combined_ht.strand_counter < 0, 2).default(0))

    combined_ht = combined_ht.annotate_globals(combined_tables=args.paths,
                                               n_combined_tables=len(
                                                   args.paths))

    if strand_conflicts_count:
        print(
            f"WARNING: Found {strand_conflicts_count} strand_conflicts out of {total_junctions_count} total_junctions"
        )

    # write as HT
    combined_ht = combined_ht.checkpoint(
        "combined.SJ.out.ht", overwrite=True)  #, _read_if_exists=True)

    ## write as tsv
    output_prefix = f"combined.{len(tables)}_samples{'.normalized_counts' if args.normalize_read_counts else ''}"
    combined_ht = combined_ht.key_by()
    combined_ht.export(f"{output_prefix}.with_header.combined.SJ.out.tab",
                       header=True)
    combined_ht = combined_ht.select(
        "chrom",
        "start_1based",
        "end_1based",
        "strand",
        "intron_motif",
        "known_splice_junction",
        "unique_reads",
        "multi_mapped_reads",
        "maximum_overhang",
    )
    combined_ht.export(f"{output_prefix}.SJ.out.tab", header=False)

    print(
        f"unique_reads_in combined table: {combined_ht.aggregate(hl.agg.sum(combined_ht.unique_reads))}"
    )
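The strand handling above is a two-step vote: each sample's or_else(switch...) contributes +1, -1, or 0 to strand_counter, and the final case expression takes its sign. The same idiom in isolation (vote values made up):

import hail as hl

votes = hl.literal([1, 1, 2, 0])  # per-sample strand values from SJ.out.tab
counter = hl.sum(votes.map(
    lambda s: hl.or_else(hl.switch(s).when(1, 1).when(2, -1).or_missing(), 0)))
consensus = hl.case().when(counter > 0, 1).when(counter < 0, 2).default(0)
print(hl.eval(consensus))  # 1: strand '+' wins two votes to one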