Ejemplo n.º 1
0
def compute_fisher_exact(tb: hl.Table,
                         n_cases_col: str,
                         n_control_col: str,
                         total_cases_col: str,
                         total_controls_col: str,
                         correct_total_counts: bool,
                         root_col_name: str,
                         extra_fields: dict) -> hl.Table:
    """
    Run Hail's Fisher exact test row-wise and annotate the table with the result.

    The four cells passed to ``hl.fisher_exact_test`` are, in order: affected
    cases, affected controls, and then either the raw totals or the totals
    minus the affected counts (see ``correct_total_counts``). All four values
    are cast to int32 before the test.

    :param tb: Hail Table
    :param n_cases_col: field name with number of (affected) cases
    :param n_control_col: field name with number of (affected) controls
    :param total_cases_col: field name with total number of cases
    :param total_controls_col: field name with total number of controls
    :param correct_total_counts: if True, subtract the affected counts from the
        totals so that affected samples are not counted twice in the 2x2 table
    :param root_col_name: field to be annotated with test results; the table is
        flattened afterwards, so nested result fields become top-level fields
    :param extra_fields: extra fields (a dict of name -> expression) to annotate
        after flattening; an empty dict or None adds nothing
    :return: Hail Table with Fisher exact test results.
    """
    n_cases = hl.int32(tb[n_cases_col])
    n_controls = hl.int32(tb[n_control_col])
    total_cases = hl.int32(tb[total_cases_col])
    total_controls = hl.int32(tb[total_controls_col])

    if correct_total_counts:
        # Turn the totals into "unaffected" counts so each sample lands in
        # exactly one cell of the contingency table.
        total_cases = total_cases - n_cases
        total_controls = total_controls - n_controls

    fet = hl.fisher_exact_test(c1=n_cases,
                               c2=n_controls,
                               c3=total_cases,
                               c4=total_controls)

    tb = tb.annotate(**{root_col_name: fet}).flatten()

    # Truthiness covers both an empty dict and None (the latter used to raise
    # TypeError from len()).
    if not extra_fields:
        return tb
    return tb.annotate(**extra_fields)
Ejemplo n.º 2
0
    def test(self):
        """Annotate a one-row table with many Hail expression functions.

        Builds a single-row table, applies the annotations below in one
        ``annotate`` call, takes the first row and passes it through
        ``convert_struct_to_dict``.
        """
        row_type = hl.tstruct(
            a=hl.tint32,
            b=hl.tint32,
            c=hl.tint32,
            d=hl.tint32,
            e=hl.tstr,
            f=hl.tarray(hl.tint32),
            g=hl.tarray(hl.tstruct(x=hl.tint32, y=hl.tint32, z=hl.tstr)),
            h=hl.tstruct(a=hl.tint32, b=hl.tint32, c=hl.tstr),
            i=hl.tbool,
            j=hl.tstruct(x=hl.tint32, y=hl.tint32, z=hl.tstr),
        )

        record = {
            'a': 4,
            'b': 1,
            'c': 3,
            'd': 5,
            'e': "hello",
            'f': [1, 2, 3],
            'g': [hl.Struct(x=1, y=5, z='banana')],
            'h': hl.Struct(a=5, b=3, c='winter'),
            'i': True,
            'j': hl.Struct(x=3, y=2, z='summer'),
        }

        table = hl.Table.parallelize([record], row_type)

        # Keys are the annotation field names; insertion order is the order in
        # which the fields are added to the table.
        annotations = {
            'chisq': hl.chisq(table.a, table.b, table.c, table.d),
            'ctt': hl.ctt(table.a, table.b, table.c, table.d, 5),
            'dict': hl.dict(hl.zip([table.a, table.b], [table.c, table.d])),
            'dpois': hl.dpois(4, table.a),
            'drop': table.h.drop('b', 'c'),
            'exp': hl.exp(table.c),
            'fet': hl.fisher_exact_test(table.a, table.b, table.c, table.d),
            'hwe': hl.hardy_weinberg_p(1, 2, 1),
            'index': hl.index(table.g, 'z'),
            'is_defined': hl.is_defined(table.i),
            'is_missing': hl.is_missing(table.i),
            'is_nan': hl.is_nan(hl.float64(table.a)),
            'json': hl.json(table.g),
            'log': hl.log(table.a, table.b),
            'log10': hl.log10(table.c),
            'or_else': hl.or_else(table.a, 5),
            'or_missing': hl.or_missing(table.i, table.j),
            'pchisqtail': hl.pchisqtail(table.a, table.b),
            'pcoin': hl.rand_bool(0.5),
            'pnorm': hl.pnorm(0.2),
            'pow': 2.0 ** table.b,
            'ppois': hl.ppois(table.a, table.b),
            'qchisqtail': hl.qchisqtail(table.a, table.b),
            'range': hl.range(0, 5, table.b),
            'rnorm': hl.rand_norm(0.0, table.b),
            'rpois': hl.rand_pois(table.a),
            'runif': hl.rand_unif(table.b, table.a),
            'select': table.h.select('c', 'b'),
            'sqrt': hl.sqrt(table.a),
            'to_str': [hl.str(5), hl.str(table.a), hl.str(table.g)],
            'where': hl.cond(table.i, 5, 10),
        }

        first_row = table.annotate(**annotations).take(1)[0]
        result = convert_struct_to_dict(first_row)
Ejemplo n.º 3
0
 def fet_expr(het_count_exp: hl.expr.Int64Expression,
              hom_count_expr: hl.expr.Int64Expression):
     """Build an aggregation expression with dominant and recessive Fisher tests.

     Rows are grouped by ``mt.is_case`` (``mt`` comes from the enclosing
     scope) and, within each group, counted by
     ``min(2, het_count_exp + 2 * hom_count_expr)``, giving classes 0, 1, 2.

     :param het_count_exp: heterozygote count expression
     :param hom_count_expr: homozygote count expression
     :return: struct with ``counts`` (a 2x3 int32 table: row 0 is the
         ``is_case == False`` group, row 1 the ``is_case == True`` group),
         ``dominant`` (class 0 vs. classes 1+2) and ``recessive``
         (classes 0+1 vs. class 2) Fisher exact test results
     """
     def _class_counts(by_status, status):
         # One int32 per class; 0 when the group or the class is absent.
         return [
             hl.int32(
                 hl.cond(by_status.contains(status),
                         by_status[status].get(allele_class, 0),
                         0))
             for allele_class in range(3)
         ]

     def _with_tests(table):
         non_cases, cases = table[0], table[1]
         return hl.struct(
             counts=table,
             dominant=hl.fisher_exact_test(
                 non_cases[0], non_cases[1] + non_cases[2],
                 cases[0], cases[1] + cases[2]),
             recessive=hl.fisher_exact_test(
                 non_cases[0] + non_cases[1], non_cases[2],
                 cases[0] + cases[1], cases[2]))

     per_status_counter = hl.agg.group_by(
         mt.is_case,
         hl.agg.counter(hl.min(2, het_count_exp + 2 * hom_count_expr)))

     count_table = hl.bind(
         lambda by_status: [_class_counts(by_status, False),
                            _class_counts(by_status, True)],
         per_status_counter)

     return hl.bind(_with_tests, count_table)
Ejemplo n.º 4
0
def fs_from_sb(
    sb: Union[hl.expr.ArrayNumericExpression, hl.expr.ArrayExpression],
    normalize: bool = True,
    min_cell_count: int = 200,
    min_count: int = 4,
    min_p_value: float = 1e-320,
) -> hl.expr.Int64Expression:
    """
    Compute the `FS` (Fisher strand balance) annotation from the `SB` (strand balance table) field.

    `FS` is the phred-scaled p-value of a test on the 2x2 strand-balance table.

    With the default arguments (documented by the original author as the GATK values):

    - If sum(counts) > 2 * `min_cell_count`, every cell is divided by
      sum(counts) / `min_cell_count` and truncated to an integer before the test
    - If sum(counts) <= `min_count`, missing is returned
    - Any p-value < `min_p_value` is raised to `min_p_value` before phred-scaling

    With `normalize=False` the counts are not rescaled; instead
    `hl.contingency_table_test` is called with `min_cell_count`, which chooses
    between a chi-squared and a Fisher exact test.

    .. note::

        `sb` can be either

        - an array of length four: [ref fwd, ref rev, alt fwd, alt rev]
        - a two dimensional array: [[ref fwd, ref rev], [alt fwd, alt rev]]
          (flattened internally)

    GATK code here: https://github.com/broadinstitute/gatk/blob/master/src/main/java/org/broadinstitute/hellbender/tools/walkers/annotator/FisherStrand.java

    :param sb: Count of ref/alt reads on each strand
    :param normalize: Whether to rescale large tables and run a FET (True), or pass the raw counts to the contingency table test (False)
    :param min_cell_count: Target table size for normalization, or the cell-count threshold passed to the contingency table test
    :param min_count: Total count that must be exceeded for FS to be output (otherwise missing)
    :param min_p_value: Lower bound applied to the p-value before phred-scaling
    :return: FS value (NOTE(review): annotated as Int64Expression although it is produced by `to_phred`; confirm the actual expression type)
    """
    # Anything that is not already a flat numeric array is assumed to be nested.
    if not isinstance(sb, hl.expr.ArrayNumericExpression):
        sb = hl.bind(lambda nested: hl.flatten(nested), sb)

    sb_sum = hl.bind(lambda counts: hl.sum(counts), sb)

    if normalize:
        # Shrink the table when the total exceeds twice the target size.
        def _rescale(counts, total):
            return hl.cond(
                total <= 2 * min_cell_count,
                counts,
                counts.map(lambda cell: hl.int(cell / (total / min_cell_count))),
            )

        cells = hl.bind(_rescale, sb, sb_sum)
        p_value = hl.fisher_exact_test(
            cells[0], cells[1], cells[2], cells[3]
        ).p_value
    else:
        p_value = hl.contingency_table_test(
            sb[0], sb[1], sb[2], sb[3], min_cell_count=min_cell_count
        ).p_value

    fs_expr = to_phred(hl.max(p_value, min_p_value))

    # Missing unless the total exceeds `min_count`; max(0, .) avoids -0.0.
    return hl.or_missing(sb_sum > min_count, hl.max(0, fs_expr))
Ejemplo n.º 5
0
    e_cases = filtered_e.aggregate(hl.agg.sum(filtered_e.nontopmed_cases))
    e_homs = filtered_e.aggregate(hl.agg.max(filtered_e.nontopmed_hom))
    e_controls = e_max_AN - e_cases - (e_homs * 2)

    g_max_AN = filtered_g.aggregate(hl.agg.max(filtered_g.nontopmed_AN))
    g_cases = filtered_g.aggregate(hl.agg.sum(filtered_g.nontopmed_cases))
    g_homs = filtered_g.aggregate(hl.agg.max(filtered_g.nontopmed_hom))
    g_controls = g_max_AN - g_cases - (g_homs * 2)

    print("Exomes max AN: {}".format(e_max_AN), file=f)
    print("Exomes cases: {}".format(e_cases), file=f)
    print("Exome Homs: {}".format(e_homs), file=f)
    print("Exome Controls: {}".format(e_controls), file=f)

    print("Genomes max AN: {}".format(g_max_AN), file=f)
    print("Genomes cases: {}".format(g_cases), file=f)
    print("Genomes Homs: {}".format(g_homs), file=f)
    print("Genome controls: {}".format(g_controls), file=f)

    print("Running Fisher Exact Test")
    tot_cases = int(e_cases + g_cases)
    tot_controls = int(e_controls + g_controls)
    print("Combined controls: {}".format(tot_controls), file=f)
    print("Combined cases: {}".format(tot_cases), file=f)

# Fisher exact test on a 2x2 table: the literals 3 and 45 fill the first two
# cells, the combined tot_cases / tot_controls computed earlier fill the last two.
# NOTE(review): 3 and 45 are presumably the study cohort's own counts — confirm.
result = hl.fisher_exact_test(3, 45, tot_cases, tot_controls)
print("Exporting Fisher Exact Test Results...")
# Write the test result to the path below.
result.export(
    '/gpfs/ycga/project/kahle/sp2349/moyamoya/case_control/diaph1_cc.tsv')
print("Done!")
def gnomad_het_case_control(exp_cases, exp_controls, chrm='0', start=0, end=0, out_dir=os.getcwd(), name=""):
    """Perform a case-control test for the given coordinates on the combined gnomAD table.

    Reads the per-chromosome combined exome/genome table, keeps rows inside
    ``[start, end]`` that pass the frequency and functional filters below,
    derives aggregate counts from the remaining rows, and runs
    ``hl.fisher_exact_test(exp_cases, exp_controls, cases, controls)``.

    Everything is written under ``<out_dir>/gnomad_case_control/<name>/``:
    a ``.out`` log, the filtered table as ``.ht`` and ``.csv``, and the test
    result as ``.tsv`` (all named ``results_<chrm>_<start>-<end>``).

    :param exp_cases: first cell passed to the Fisher exact test
    :param exp_controls: second cell passed to the Fisher exact test
    :param chrm: chromosome (int or str); selects the input table and builds the loci
    :param start: first position of the interval (inclusive)
    :param end: last position of the interval (inclusive)
    :param out_dir: root output directory; note that the default is the working
        directory at the time this function is defined, not at call time
    :param name: output sub-directory; defaults to ``results_<chrm>_<start>-<end>``
    :return: None
    """
    # str() leaves strings untouched, so one unconditional conversion handles
    # both int and str input.
    chrm = str(chrm)

    base_name = "results_{}_{}-{}".format(chrm, start, end)
    if name == "":
        name = base_name

    # exist_ok avoids the check-then-create race.
    parent_dir = os.path.join(out_dir, "gnomad_case_control", name)
    os.makedirs(parent_dir, exist_ok=True)

    out_f = os.path.join(parent_dir, base_name + ".out")
    with open(out_f, "w") as f:
        # Import gnomad combined exome and genome table
        gnomad_e = hl.read_table('/gpfs/ycga/scratch60/kahle/sp2349/combined_weilai_mts/combined.filtered.final.gnomad.r2.1.1.sites.{}.mt'.format(chrm))

        # Keep rows whose `bravo` value is at most 0.0005
        filtered_e = gnomad_e.filter(gnomad_e.bravo <= 0.0005)

        # Restrict to the requested interval (both ends inclusive)
        filtered_e = filtered_e.filter((filtered_e.locus >= hl.locus(chrm, start)) &
                                       (filtered_e.locus <= hl.locus(chrm, end)))

        print("Exomes count filtered on gene: {}".format(filtered_e.count()), file=f)

        # Drop variants with AC == 0
        filtered_e = filtered_e.filter(filtered_e.combined_nontopmed_AC > 0)

        # Keep LoF variants; missense variants only if CADD1.6 score >= 20 or MetaSVM == D
        filtered_e = filtered_e.filter(((filtered_e["Exonic_refGene"] == 'nonsynonymous_SNV') & (filtered_e.CADD16snv_PHRED >= 20)) |
                                       ((filtered_e["Exonic_refGene"] == 'nonsynonymous_SNV') & (filtered_e.MetaSVM_pred == "D")) |
                                       (filtered_e["Func_refGene"] == 'splicing') |
                                       (filtered_e["Exonic_refGene"] == 'frameshift_deletion') |
                                       (filtered_e["Exonic_refGene"] == 'frameshift_insertion') |
                                       (filtered_e["Exonic_refGene"] == 'stopgain') |
                                       (filtered_e["Exonic_refGene"] == 'stoploss') |
                                       (filtered_e["Func_refGene"] == 'exonic\\x3bsplicing'))

        print("Lof var rows kept: {}".format(filtered_e.count()), file=f)

        # Per-variant values: cases = AC - 2 * homozygote count.
        filtered_e = filtered_e.annotate(
            nontopmed_cases=filtered_e.combined_nontopmed_AC - (filtered_e.combined_nontopmed_nhomalt * 2),
            nontopmed_AN=filtered_e.combined_nontopmed_AN,
            nontopmed_hom=filtered_e.combined_nontopmed_nhomalt)

        # One aggregation pass instead of three separate ones.
        stats = filtered_e.aggregate(hl.struct(
            max_an=hl.agg.max(filtered_e.nontopmed_AN),
            cases=hl.agg.sum(filtered_e.nontopmed_cases),
            homs=hl.agg.sum(filtered_e.nontopmed_hom)))
        e_max_AN = stats.max_an
        e_cases = stats.cases
        e_homs = stats.homs
        # Controls = max AN - cases - 2 * homozygotes
        e_controls = e_max_AN - e_cases - (e_homs * 2)

        print("Combined max AN: {}".format(e_max_AN), file=f)
        print("Combined cases: {}".format(e_cases), file=f)
        print("Combined Homs: {}".format(e_homs), file=f)
        print("Combined Controls: {}".format(e_controls), file=f)

        print("Running Fisher Exact Test")

        # Persist the filtered variants both as a Hail table and as CSV.
        filtered_e.write(os.path.join(parent_dir, base_name + ".ht"), overwrite=True)
        df = filtered_e.to_pandas()
        df.to_csv(os.path.join(parent_dir, base_name + ".csv"))

        result = hl.fisher_exact_test(exp_cases, exp_controls, e_cases, e_controls)
        print("Exporting Fisher Exact Test Results...")
        result.export(os.path.join(parent_dir, base_name + ".tsv"))
        print("Done!")
        (filtered.combined_nontopmed_nhomalt * 2))
    #filtered = filtered.annotate(nontopmed_AN=filtered.combined_nontopmed_AN)
    #filtered = filtered.annotate(nontopmed_hom=filtered.combined_nontopmed_nhomalt)

    max_AN = filtered.aggregate(hl.agg.max(filtered.combined_nontopmed_AN))
    cases = filtered.aggregate(hl.agg.sum(filtered.combined_nontopmed_cases))
    homs = filtered.aggregate(hl.agg.sum(filtered.combined_nontopmed_nhomalt))
    controls = max_AN - cases - (homs * 2)

    print("Combined max AN: {}".format(max_AN), file=f)
    print("Combined cases: {}".format(cases), file=f)
    print("Combined Homs: {}".format(homs), file=f)
    print("Combined Controls: {}".format(controls), file=f)

    print("Running Fisher Exact Test")

    filtered.write(
        "/gpfs/ycga/project/kahle/sp2349/moyamoya/case_control/cc_output/diaph1_gnomad-combined_sam.ht",
        overwrite=True)
    df = filtered.to_pandas()
    df.to_csv(
        "/gpfs/ycga/project/kahle/sp2349/moyamoya/case_control/cc_output/diaph1_gnomad-combined_sam.csv"
    )

# Fisher exact test on a 2x2 table: the literals 3 and 45 fill the first two
# cells, the aggregated `cases` / `controls` computed earlier fill the last two.
# NOTE(review): 3 and 45 are presumably the study cohort's own counts — confirm.
result = hl.fisher_exact_test(3, 45, cases, controls)
print("Exporting Fisher Exact Test Results...")
# Write the test result to the path below.
result.export(
    '/gpfs/ycga/project/kahle/sp2349/moyamoya/case_control/cc_output/diaph1_gnomad-combined_sam_cc.tsv'
)
print("Done!")