Example #1
def array_qc(path_to_mt: str, output_root_directory: str, population: str,
             array: str, cohort_labels: list, mt_1kg_eur_path: str, r2: float,
             pihat: float, sample_call_rate_thresh: float,
             sample_call_rate_thresh_chr: float,
             variant_call_rate_thresh: float, maf_thresh: float,
             hwe_thresh: float, ass_thresh: float):
    """
        This function performs array-level QC including the following QC filters:
        - sample call rate
        - sample call rate for each chromosome
        - variant call rate
        - minor allele frequencies
        - Hardy-Weinberg equilibrium
        - pseudo case-control GWAS (tagging samples from one cohort as cases, samples from the other cohorts as controls)
        - pseudo case-control GWAS against 1KG
        ----
        :param str path_to_mt: path to MatrixTable
        :param str output_root_directory: root directory for QC outputs
        :param str population: population label
        :param str array: genotyping array label
        :param list cohort_labels: list of cohort labels
        :param str mt_1kg_eur_path: path to 1KG EUR population
        :param float r2: LD pruning threshold
        :param float pihat: pi-hat threshold for filtering related samples
        :param float sample_call_rate_thresh: sample call rate threshold
        :param float sample_call_rate_thresh_chr: per-chromosome sample call rate threshold
        :param float variant_call_rate_thresh: variant call rate threshold
        :param float maf_thresh: minor allele frequency threshold
        :param float hwe_thresh: Hardy-Weinberg equilibrium test p-value threshold
        :param float ass_thresh: p-value threshold for pseudo GWAS
        """

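    # Load the cohort MatrixTable and the 1KG EUR reference; keep only the
    # mainland-European reference samples for projection and pseudo case-control.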
    mt = hl.read_matrix_table(path_to_mt)
    mt_1kg_eur = hl.read_matrix_table(mt_1kg_eur_path)
    mt_1kg_eur = mt_1kg_eur.filter_cols(
        mt_1kg_eur.population == 'eur(mainland)')

    # Build Directory structure
    directory_structure = build_array_qc_directory_structure(
        output_root_directory, population=population, array=array)

    # Count number of samples and variants before QC
    n_variants_before_qc, n_samples_before_qc = mt.count()

    # Compute sample- and variant-level QC metrics, plus per-chromosome sample call rate
    mt = hl.sample_qc(mt)
    mt = hl.variant_qc(mt)
    mt = calculate_per_chr_sample_call_rate(mt=mt,
                                            sample_call_rate_col='scr_chr')
    mt = mt.cache()

    # Calculate number of samples & variants that fail each filter
    n_sample_call_rate_thresh = mt.aggregate_cols(
        hl.agg.count_where(mt.sample_qc.call_rate < sample_call_rate_thresh))
    n_sample_call_rate_thresh_chr = mt.aggregate_cols(
        hl.agg.count_where(mt.scr_chr < sample_call_rate_thresh_chr))
    n_variant_call_rate_thresh = mt.aggregate_rows(
        hl.agg.count_where(mt.variant_qc.call_rate < variant_call_rate_thresh))
    n_mafsh = mt.aggregate_rows(
        hl.agg.count_where((mt.variant_qc.AF[0] < maf_thresh)
                           | (mt.variant_qc.AF[1] < maf_thresh)))
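    # het_freq_hwe / p_value are assumed to be pre-annotated HWE fields on the
    # input MatrixTable (they are not computed in this function).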
    n_hwesh = mt.aggregate_rows(
        hl.agg.count_where((mt.het_freq_hwe < hwe_thresh)
                           | (mt.p_value < hwe_thresh)))

    # Filter samples & variants
    mt = mt.filter_cols((mt.sample_qc.call_rate > sample_call_rate_thresh) &
                        (mt.scr_chr > sample_call_rate_thresh_chr),
                        keep=True)
    mt = mt.filter_rows(
        (mt.variant_qc.call_rate > variant_call_rate_thresh) &
        (mt.variant_qc.AF[0] > maf_thresh) & (mt.variant_qc.AF[1] > maf_thresh)
        & (mt.het_freq_hwe > hwe_thresh) & (mt.p_value > hwe_thresh),
        keep=True)

    # LD pruning
    pruned_variants_list = ld_prune(mt, r2=r2, pruned_variants_list=True)
    pruned_mt = mt.filter_rows(hl.is_defined(pruned_variants_list[mt.row_key]),
                               keep=True)
    samples_to_remove = identify_related_samples(pruned_mt,
                                                 pihat_threshold=pihat)

    # Filter related samples
    n_related = samples_to_remove.count()
    mt = mt.filter_cols(hl.is_defined(samples_to_remove[mt.col_key]),
                        keep=False)

    # PCA
    scores_ht, _ = pca(mt=pruned_mt, n_evecs=6, remove_outliers=False)
    scores_ht.write(directory_structure["pca"] + "/scores.ht", overwrite=True)

    # Pseudo Case-Control Analysis
    n_pseudo_case_control = 0
    pruned_1kg_eur_mt = mt_1kg_eur.filter_rows(hl.is_defined(
        pruned_variants_list[mt_1kg_eur.row_key]),
                                               keep=True)
    scores_1kg_eur_ht, loadings_1kg_eur_ht = pca(mt=pruned_1kg_eur_mt,
                                                 n_evecs=6,
                                                 remove_outliers=False)
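    # Project cohort samples onto the 1KG EUR principal components
    # (shrinkage-corrected) and attach the projected scores to each sample.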
    scores_ht = pca_project(mt=mt,
                            loadings_ht=loadings_1kg_eur_ht,
                            correct_shrinkage=True)
    mt = mt.annotate_cols(scores=scores_ht[mt.col_key].scores)
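    # Cross-cohort pseudo case-control GWAS: each cohort is coded in turn as
    # "cases" (1) against all remaining cohorts (0) and tested per variant,
    # with the first six projected PCs as covariates.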
    if len(cohort_labels) > 1:
        for cohort in cohort_labels:
            mt = mt.annotate_cols(is_case=hl.cond(mt.cohort == cohort, 1, 0))
            result_ht = hl.linear_regression_rows(
                y=mt.is_case,
                x=mt.GT.n_alt_alleles(),
                covariates=[
                    1, mt.scores[0], mt.scores[1], mt.scores[2], mt.scores[3],
                    mt.scores[4], mt.scores[5]
                ])
            mt = mt.annotate_rows(
                **{'pvalue_' + cohort: result_ht[mt.row_key].p_value})
            mt = mt.drop('is_case')

        # Calculate the minimum p-value across all cross-cohort comparisons
        mt = mt.annotate_rows(pvalues_assoc=hl.min(
            [mt['pvalue_' + str(cohort)] for cohort in cohort_labels]))

        # Count number of variants failing cross-cohort comparison
        n_pseudo_case_control = mt.aggregate_rows(
            hl.agg.count_where(mt.pvalues_assoc < ass_thresh))

        # Filter variants
        mt = mt.filter_rows(mt.pvalues_assoc > ass_thresh, keep=True)

    # Pseudo Case-Control Analysis against 1KG EUR
    n_pseudo_1kg_case_control = 0
    if population == "eur_mainland":
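        # Union the cohort with the 1KG EUR reference (dropping row/column
        # annotations), pool the projected PC scores, and code cohort samples
        # as cases (1) vs 1KG EUR samples as controls (0).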
        mt_merged = mt.select_cols().select_rows().union_cols(
            mt_1kg_eur.select_cols().select_rows())
        scores_ht = scores_ht.union(scores_1kg_eur_ht)

        mt_merged = mt_merged.annotate_cols(
            scores=scores_ht[mt_merged.col_key].scores,
            is_case=hl.cond(hl.is_defined(mt.index_cols(mt_merged.col_key)), 1,
                            0))
        result_ht = hl.linear_regression_rows(
            y=mt_merged.is_case,
            x=mt_merged.GT.n_alt_alleles(),
            covariates=[
                1, mt_merged.scores[0], mt_merged.scores[1],
                mt_merged.scores[2], mt_merged.scores[3], mt_merged.scores[4],
                mt_merged.scores[5]
            ])

        # Annotate p-values from comparison against 1KG EUR
        mt = mt.annotate_rows(pvalues_1kg_assoc=result_ht[mt.row_key].p_value)

        # Count number of variants failing comparison against 1KG EUR
        n_pseudo_1kg_case_control = mt.aggregate_rows(
            hl.agg.count_where(mt.pvalues_1kg_assoc < ass_thresh))

        # Filter variants
        mt = mt.filter_rows((mt.pvalues_1kg_assoc > ass_thresh) |
                            (hl.is_missing(mt.pvalues_1kg_assoc)),
                            keep=True)

    # Count number of samples and variants after QC
    n_variants_after_qc, n_samples_after_qc = mt.count()

    # Write QC'ed data to google bucket
    mt.write(directory_structure["share"] + "/final.mt", overwrite=True)

    # Write QC meta information
    meta = {
        "Number of Samples before QC": n_samples_before_qc,
        "Number of Samples after QC:": n_samples_after_qc,
        "Number of Variants before QC": n_variants_before_qc,
        "Number of Variants after QC": n_variants_after_qc,
        "Genotyping Array": array,
        "Population": population,
        "Sample QC Summary": {
            "IDs: call rate < %s" % sample_call_rate_thresh:
            n_sample_call_rate_thresh,
            "IDs: minimum per-chromosome call rate < %s" % sample_call_rate_thresh_chr:
            n_sample_call_rate_thresh_chr,
            "IDs: pi-hat > %s" % pihat:
            n_related
        },
        "Variant QC Summary": {
            "SNPs: call rate < %s" % variant_call_rate_thresh:
            n_variant_call_rate_thresh,
            "SNPs: minor allele frequency < %s" % maf_thresh:
            n_mafsh,
            "SNPs: HWE p-values < %s" % hwe_thresh:
            n_hwesh,
            "SNPs: Cross cohorts pseudo case-control across p-value < %s" % ass_thresh:
            n_pseudo_case_ontrol + n_pseudo_1kg_case_control
        }
    }
    with hl.hadoop_open(directory_structure["summary"] + "/meta.json",
                        'w') as outfile:
        json.dump(meta, outfile)
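
# A minimal invocation sketch for array_qc (paths and thresholds below are
# illustrative placeholders, not values from any real run):
#
# array_qc(path_to_mt="gs://<bucket>/eur_mainland/array1/merged.mt",
#          output_root_directory="gs://<bucket>/array_qc",
#          population="eur_mainland", array="array1",
#          cohort_labels=["cohort_a", "cohort_b"],
#          mt_1kg_eur_path="gs://<bucket>/1kg_eur.mt",
#          r2=0.2, pihat=0.0625,
#          sample_call_rate_thresh=0.98, sample_call_rate_thresh_chr=0.50,
#          variant_call_rate_thresh=0.98, maf_thresh=0.01,
#          hwe_thresh=1e-4, ass_thresh=1e-4)
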
def cross_array_comparison(path_to_mt: str, output_root_directory: str,
                           population: str, array: str, cohort_labels: list,
                           mt_1kg_eur_path: str, r2: float, pihat: float,
                           ass_thresh: float):
    """
        This function performs post-imputation QC including the following QC filters:
        - pseudo case-control GWAS (tagging samples from one cohort as cases, samples from the other cohorts as controls)
        - pseudo case-control GWAS against 1KG
        ----
        :param str path_to_mt: path to MatrixTable
        :param str output_root_directory: root directory for QC outputs
        :param str population: population label
        :param str array: genotyping array label
        :param list cohort_labels: list of cohort labels
        :param str mt_1kg_eur_path: path to 1KG EUR population
        :param float r2: LD pruning threshold
        :param float pihat: pi-hat threshold for filtering related samples
        :param float ass_thresh: p-value threshold for pseudo GWAS
        """

    mt = hl.read_matrix_table(path_to_mt)
    mt_1kg_eur = hl.read_matrix_table(mt_1kg_eur_path)
    mt_1kg_eur = mt_1kg_eur.filter_cols(
        mt_1kg_eur.population == 'eur(mainland)')

    # Build Directory structure
    directory_structure = build_array_qc_directory_structure(
        output_root_directory, population=population, array=array)

    # Count number of samples and variants before QC
    n_variants_before_qc, n_samples_before_qc = mt.count()

    # LD pruning
    pruned_variants_list = ld_prune(mt, r2=r2, pruned_variants_list=True)
    pruned_mt = mt.filter_rows(hl.is_defined(pruned_variants_list[mt.row_key]),
                               keep=True)
    samples_to_remove = identify_related_samples(pruned_mt,
                                                 pihat_threshold=pihat)

    # Filter related samples
    n_related = samples_to_remove.count()
    mt = mt.filter_cols(hl.is_defined(samples_to_remove[mt.col_key]),
                        keep=False)

    # PCA
    scores_ht, _ = pca(mt=pruned_mt, n_evecs=6, remove_outliers=False)
    scores_ht.write(directory_structure["pca"] + "/scores.ht", overwrite=True)

    # Pseudo Case-Control Analysis
    n_pseudo_case_control = 0
    pruned_1kg_eur_mt = mt_1kg_eur.filter_rows(hl.is_defined(
        pruned_variants_list[mt_1kg_eur.row_key]),
                                               keep=True)
    scores_1kg_eur_ht, loadings_1kg_eur_ht = pca(mt=pruned_1kg_eur_mt,
                                                 n_evecs=6,
                                                 remove_outliers=False)
    scores_ht = pca_project(mt=mt,
                            loadings_ht=loadings_1kg_eur_ht,
                            correct_shrinkage=True)
    mt = mt.annotate_cols(scores=scores_ht[mt.col_key].scores)
    if len(cohort_labels) > 1:
        for cohort in cohort_labels:
            mt = mt.annotate_cols(is_case=hl.cond(mt.cohort == cohort, 1, 0))
            result_ht = hl.linear_regression_rows(
                y=mt.is_case,
                x=mt.GT.n_alt_alleles(),
                covariates=[
                    1, mt.scores[0], mt.scores[1], mt.scores[2], mt.scores[3],
                    mt.scores[4], mt.scores[5]
                ])
            mt = mt.annotate_rows(
                **{'pvalue_' + cohort: result_ht[mt.row_key].p_value})
            mt = mt.drop('is_case')

        # Calculate the minimum p-value across all cross-cohort comparisons
        mt = mt.annotate_rows(pvalues_assoc=hl.min(
            [mt['pvalue_' + str(cohort)] for cohort in cohort_labels]))

        # Count number of variants failing cross-cohort comparison
        n_pseudo_case_control = mt.aggregate_rows(
            hl.agg.count_where(mt.pvalues_assoc < ass_thresh))

        # Filter variants
        mt = mt.filter_rows(mt.pvalues_assoc > ass_thresh, keep=True)

    # Pseudo Case-Control Analysis against 1KG EUR
    n_pseudo_1kg_case_control = 0
    if population == "eur_mainland":
        mt_merged = mt.select_cols().select_rows().union_cols(
            mt_1kg_eur.select_cols().select_rows())
        scores_ht = scores_ht.union(scores_1kg_eur_ht)

        mt_merged = mt_merged.annotate_cols(
            scores=scores_ht[mt_merged.col_key].scores,
            is_case=hl.cond(hl.is_defined(mt.index_cols(mt_merged.col_key)), 1,
                            0))
        result_ht = hl.linear_regression_rows(
            y=mt_merged.is_case,
            x=mt_merged.GT.n_alt_alleles(),
            covariates=[
                1, mt_merged.scores[0], mt_merged.scores[1],
                mt_merged.scores[2], mt_merged.scores[3], mt_merged.scores[4],
                mt_merged.scores[5]
            ])

        # Annotate p-values from comparison against 1KG EUR
        mt = mt.annotate_rows(pvalues_1kg_assoc=result_ht[mt.row_key].p_value)

        # Count number of variants failing comparison against 1KG EUR
        n_pseudo_1kg_case_control = mt.aggregate_rows(
            hl.agg.count_where(mt.pvalues_1kg_assoc < ass_thresh))

        # Filter variants
        mt = mt.filter_rows((mt.pvalues_1kg_assoc > ass_thresh) |
                            (hl.is_missing(mt.pvalues_1kg_assoc)),
                            keep=True)

    # Count number of samples and variants after QC
    n_variants_after_qc, n_samples_after_qc = mt.count()

    # Write QC'ed data to google bucket
    mt.write(directory_structure["share"] + "/final.mt", overwrite=True)

    # Write QC meta information
    meta = {
        "Number of Samples before QC": n_samples_before_qc,
        "Number of Samples after QC:": n_samples_after_qc,
        "Number of Variants before QC": n_variants_before_qc,
        "Number of Variants after QC": n_variants_after_qc,
        "Variant QC Summary": {
            "SNPs: Cross cohorts pseudo case-control across p-value < %s" % ass_thresh:
            n_pseudo_case_ontrol + n_pseudo_1kg_case_control
        }
    }
    with hl.hadoop_open(directory_structure["summary"] + "/cross_array.json",
                        'w') as outfile:
        json.dump(meta, outfile)
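
# Illustrative call for cross_array_comparison (all paths and thresholds are
# placeholder values):
#
# cross_array_comparison(path_to_mt="gs://<bucket>/array_qc/share/final.mt",
#                        output_root_directory="gs://<bucket>/cross_array",
#                        population="eur_mainland", array="array1",
#                        cohort_labels=["cohort_a", "cohort_b"],
#                        mt_1kg_eur_path="gs://<bucket>/1kg_eur.mt",
#                        r2=0.2, pihat=0.0625, ass_thresh=1e-4)
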
Example #3
def match_ancestry(path_to_mt: str, mt_1kg_path: str, mt_1kg_eur_path: str,
                   output_root_directory: str, r2: float, pihat: float,
                   hwe_thresh: float, n_partitions: int = 200):
    """  
        This function uses a random forest classifier trained from 1000 genomes to predict the continental ancestry for each sample.
        Within European population, uses another random forest classifier trained from 1000 genomes to predict whether they're from mainland European,
        Finland or Ashkenazi Jewish samples
        Note
        ----
        :param str path_to_mt: path to matrixtable
        :param str mt_1kg_path: path to 1000 genomes reference
        :param str mt_1kg_eur_path: path to 1000 genomes European reference
        :param str output_root_directory
        :param float r2: LD pruning threshold
        :param float hwe_thresh: hardy weinberg equilibrium test pvalues threshold
        :param n_partitions: number of partitions
        """

    mt = hl.read_matrix_table(path_to_mt)
    mt_1kg = hl.read_matrix_table(mt_1kg_path)
    mt_1kg_eur = hl.read_matrix_table(mt_1kg_eur_path)
    directory_structure = build_cohort_qc_directory_structure(output_root_directory)
    n_variants_before_qc, n_samples_before_qc = mt.count()

    # LD pruning & Identify related samples
    pruned_variants_list = ld_prune(mt, r2=r2, pruned_variants_list=True)
    pruned_mt = mt.filter_rows(hl.is_defined(pruned_variants_list[mt.row_key]), keep=True)
    samples_to_remove = identify_related_samples(pruned_mt, pihat_threshold=pihat)

    # Filter related samples
    n_related = samples_to_remove.count()
    mt = mt.filter_cols(hl.is_defined(samples_to_remove[mt.col_key]), keep=False)

    # Project onto 1000 genomes
    pruned_1kg_mt = mt_1kg.filter_rows(hl.is_defined(pruned_variants_list[mt_1kg.row_key]), keep=True)
    scores_1kg_ht, loadings_1kg_ht = pca(mt=pruned_1kg_mt, n_evecs=6, remove_outliers=False)
    scores_ht = pca_project(mt=mt, loadings_ht=loadings_1kg_ht, correct_shrinkage=True)
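    # Pool reference and projected scores: 1KG samples carry a known
    # super_population label, cohort samples do not and will receive a
    # predicted label from the classifier.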
    scores_ht = scores_1kg_ht.union(scores_ht)
    scores_ht = scores_ht.annotate(super_population=mt_1kg.index_cols(scores_ht.key).super_population)

    # Use Random forest classifier to assign ancestry
    pops_ht, pop_clf = assign_population_pcs(
        pop_pca_scores=scores_ht, pc_cols=scores_ht.scores,
        known_col='super_population', min_prob=0.9, prop_train=0.8)

    # Flatten the first six projected PCs into separate columns and attach the predicted population to each sample
    pops_ht = pops_ht.annotate(PC1=pops_ht.pca_scores[0],
                               PC2=pops_ht.pca_scores[1],
                               PC3=pops_ht.pca_scores[2],
                               PC4=pops_ht.pca_scores[3],
                               PC5=pops_ht.pca_scores[4],
                               PC6=pops_ht.pca_scores[5])
    mt = mt.annotate_cols(pop=pops_ht[mt.col_key].pop)

    # For samples assigned to the EUR population, project them onto 1KG EUR
    # and classify into eur(mainland), FIN and AJ
    if mt.aggregate_cols(hl.agg.count_where(mt.pop == 'eur')) > 0:
        mt_eur = mt.filter_cols(mt.pop == 'eur', keep=True)
        mt_1kg_eur_mt = mt_1kg_eur.filter_rows(hl.is_defined(pruned_variants_list[mt_1kg_eur.row_key]), keep=True)
        scores_1kg_eur_ht, loadings_1kg_eur_ht = pca(mt=mt_1kg_eur_mt, n_evecs=6, remove_outliers=False)
        scores_ht = pca_project(mt=mt_eur, loadings_ht=loadings_1kg_eur_ht, correct_shrinkage=True)
        scores_ht = scores_1kg_eur_ht.union(scores_ht)
        scores_ht = scores_ht.annotate(population=mt_1kg_eur.index_cols(scores_ht.key).population)

        # Use Random forest classifier to assign ancestry
        pops_ht, pop_clf = assign_population_pcs(
            pop_pca_scores=scores_ht, pc_cols=scores_ht.scores,
            known_col='population', min_prob=0.9, prop_train=0.8)

        pops_ht.write(directory_structure['pca']+'/scores_1kg_eur.kt', overwrite=True)
        mt = mt.transmute_cols(pop=hl.cond(mt.pop == 'eur', pops_ht[mt.col_key].pop, mt.pop))

    # Filter out samples that are not assigned to any populations
    mt = mt.filter_cols(mt.pop == 'oth', keep=False)

    # Count number of samples in each population
    pops = mt.aggregate_cols(hl.agg.counter(mt.pop))

    # Calculate p-values of Hardy Weinberg Equilibrium within each population
    for pop, count in pops.items():
        mt = mt.annotate_rows(**{"pop_"+pop:hl.agg.filter(mt.pop == pop, hl.agg.hardy_weinberg_test(mt.GT))})

    # Calculate minimum hwe p-values across populations
    mt = mt.annotate_rows(het_freq_hwe=hl.min([mt['pop_'+pop].het_freq_hwe for pop, count in pops.items()]),
                          p_value=hl.min([mt['pop_'+pop].p_value for pop, count in pops.items()]))

    # Count number of variants failing HWE filter
    n_hwe_variants = mt.aggregate_rows(hl.agg.count_where((mt.het_freq_hwe < hwe_thresh) | (mt.p_value < hwe_thresh)))

    # Filter out variants that fail HWE
    mt = mt.filter_rows((mt.het_freq_hwe > hwe_thresh) & (mt.p_value > hwe_thresh), keep=True)

    # Count samples & variants after QC
    n_variants_after_qc, n_samples_after_qc = mt.count()

    # create a dictionary storing the meta information
    meta = {"Number of Samples before QC": n_samples_before_qc,
            "Number of Samples after QC:": n_samples_after_qc,
            "Number of Variants before QC": n_variants_before_qc,
            "Number of Variants after QC": n_variants_after_qc,
            "SNPs: HWE p-values < %s" % hwe_thresh: n_hwe_variants,
            "Population Assignment": {
                "EUR (mainland)": pops.get("eur(mainland)", 0),
                "Ashkenazi Jew": pops.get("aj", 0),
                "FIN": pops.get("fin", 0),
                "EAS": pops.get("asn", 0),
                "AFR": pops.get("afr", 0),
                "AMR": pops.get("amr", 0),
                "SAS": pops.get("sas", 0)
            }}
    mt.write(directory_structure["share"]+"/all.mt", overwrite = True)

    # Export meta
    with hl.hadoop_open(directory_structure["summary"] + "/match_ancestry_summary.json", 'w') as outfile:
        json.dump(meta, outfile)

    # Write hail MatrixTable
    for pop, count in pops.items():
        pop_mt = mt.filter_cols(mt.pop == pop)
        pop_mt.write(directory_structure["share"]+"/{pop}.mt".format(pop=pop), overwrite=True)
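
# Illustrative call for match_ancestry (placeholder paths; thresholds mirror
# the values described in the cohort-level QC docstrings):
#
# match_ancestry(path_to_mt="gs://<bucket>/cohort_qc/share/all.mt",
#                mt_1kg_path="gs://<bucket>/1kg.mt",
#                mt_1kg_eur_path="gs://<bucket>/1kg_eur.mt",
#                output_root_directory="gs://<bucket>/ancestry",
#                r2=0.2, pihat=0.0625, hwe_thresh=1e-4)
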
Example #4
                        (mt.variant_qc.AF[0] > mafrsh) &
                        (mt.variant_qc.AF[1] > mafrsh) &
                        (mt.het_freq_hwe > hwersh) &
                        (mt.p_value > hwersh), keep=True)

    # LD pruning
    pruned_variants_list = ld_prune(mt, r2=r2, pruned_variants_list=True)
    pruned_mt = mt.filter_rows(hl.is_defined(pruned_variants_list[mt.row_key]), keep=True)
    samples_to_remove = identify_related_samples(pruned_mt, pihat_threshold=pirsh)

    # Filter related samples
    n_related = samples_to_remove.count()
    mt = mt.filter_cols(hl.is_defined(samples_to_remove[mt.col_key]), keep=False)

    # PCA
    scores_ht, _ = pca(mt=pruned_mt, n_evecs=6, remove_outliers=False)
    scores_ht.write(directory_structure["pca"] + "/scores.ht", overwrite=True)

    # Plot PCs with cohort label
    scores_ht = scores_ht.annotate(cohort=mt.index_cols(scores_ht.key).cohort)
    scores_ht = scores_ht.annotate(PC1=scores_ht.scores[0],
                                   PC2=scores_ht.scores[1],
                                   PC3=scores_ht.scores[2],
                                   PC4=scores_ht.scores[3],
                                   PC5=scores_ht.scores[4],
                                   PC6=scores_ht.scores[5])
    scatter(ht=scores_ht, x_location='PC1', y_location='PC2', color_location='cohort',
            plot_path=directory_structure['pca'] + '/SCATTER_PC1_PC2_BEFORE_QC.png')
    scatter(ht=scores_ht, x_location='PC1', y_location='PC3', color_location='cohort',
            plot_path=directory_structure['pca'] + '/SCATTER_PC1_PC3_BEFORE_QC.png')
    scatter(ht=scores_ht, x_location='PC2', y_location='PC3', color_location='cohort',
            plot_path=directory_structure['pca'] + '/SCATTER_PC2_PC3_BEFORE_QC.png')

mt_AJ = mt_AJ.annotate_cols(population='aj')

# merge
mt_merged = mt_eur.select_cols('population').union_cols(
    mt_AJ.select_cols('population'))
pruned_variants_list = ld_prune(mt_merged, r2=0.2, pruned_variants_list=True)
pruned_variants_list.write(
    "gs://unicorn-resources/Ashkenazi_Jewish_PCA/ld_pruned_variants.kt",
    overwrite=True)
pruned_variants_list = hl.read_table(
    "gs://unicorn-resources/Ashkenazi_Jewish_PCA/ld_pruned_variants.kt")
mt_pruned = mt_merged.filter_rows(hl.is_defined(
    pruned_variants_list[mt_merged.row_key]),
                                  keep=True)

scores_ht, _ = pca(mt=mt_pruned, n_evecs=6, remove_outliers=False)
# Show ancestry assignment result
scores_ht = scores_ht.annotate(PC1=scores_ht.scores[0],
                               PC2=scores_ht.scores[1],
                               PC3=scores_ht.scores[2],
                               PC4=scores_ht.scores[3],
                               PC5=scores_ht.scores[4],
                               PC6=scores_ht.scores[5],
                               pop=mt_merged.index_cols(
                                   scores_ht.key).population)
scatter(
    ht=scores_ht,
    x_location='PC1',
    y_location='PC2',
    color_location='pop',
    plot_path=
def __main__(bfile: str,
             mt_1kg_path: str,
             mt_1kg_eur_path: str,
             output_root_directory: str,
             r2: float,
             pirsh: float,
             scrsh: float,
             scrsh_chr: float,
             vcrsh: float,
             mafrsh: float,
             hwersh: float,
             srh: int = 50,
             n_partitions: int = 200):
    """
        This function performs cohort-level QC including the following QC filters:
        - sample call rate > 0.98
        - sample call rate for each chromosome > 0.50
        - inbreeding coefficient should be less than 3 std from its mean
        - remove related samples (pi-hat threshold = 0.0625)
        - remove sex-error samples
        - variant call rate > 0.98
        - minor allele frequencies > 0.01
        - within each population, p-hwe > 1e-4
        - autosome only
        This function uses a random forest classifier trained on 1000 Genomes (AMR, AFR, EUR, EAS) to predict the continental ancestry of each sample.
        The detailed procedure includes:
            - filter to variants that are common to both datasets & LD prune
            - PCA with 1000 genomes (the input 1000 genomes is Stephan's cleaned version of 1KG)
            - train random forest classifier with first 6 PCs
            - project genotyping data onto 1KG
            - predict the continental ancestry label using the projected PCs
        Within the European population, a second random forest classifier trained on 1000 Genomes is used to classify samples as
        mainland European, Finnish, or Ashkenazi Jewish.
        Note
        ----
        :param str bfile: path to input PLINK bfile
        :param str mt_1kg_path: path to 1000 genomes reference
        :param str mt_1kg_eur_path: path to 1000 genomes European reference
        :param str output_root_directory: root directory for QC outputs
        :param float r2: LD pruning threshold
        :param float pirsh: pi-hat threshold for filtering related samples
        :param float scrsh: sample call rate threshold
        :param float scrsh_chr: per-chromosome sample call rate threshold
        :param float vcrsh: variant call rate threshold
        :param float mafrsh: minor allele frequency threshold
        :param float hwersh: Hardy-Weinberg equilibrium test p-value threshold
        :param int srh: cut-off for minimum sample size (also the minimum per-population sample count for the HWE test)
        :param int n_partitions: number of partitions
        """

    mt = hl.import_plink(bed=bfile + '.bed',
                         bim=bfile + '.bim',
                         fam=bfile + '.fam',
                         min_partitions=n_partitions)
    mt_1kg = hl.read_matrix_table(mt_1kg_path)
    mt_1kg_eur = hl.read_matrix_table(mt_1kg_eur_path)
    directory_structure = build_cohort_qc_directory_structure(
        output_root_directory)
    n_variants_before_qc, n_samples_before_qc = mt.count()

    # If sample size is less than a threshold, skip
    if n_samples_before_qc < srh: return

    # Perform sample QC
    mt = hl.sample_qc(mt)
    mt = hl.variant_qc(mt)
    mt = calculate_per_chr_scrt(
        mt=mt, scrt_col='scr_chr'
    )  # calculate per-chromosome call rate for each sample
    mt = calculate_inbreeding_coefficients(
        mt=mt, ib_col='ib',
        ib_global='ib_stats')  # calculate inbreeding coefficient
    mt = mt.cache()

    # Calculate number of samples & variants that fail each filter
    n_scrsh = mt.aggregate_cols(
        hl.agg.count_where(mt.sample_qc.call_rate < scrsh))
    n_scrsh_chr = mt.aggregate_cols(hl.agg.count_where(mt.scr_chr < scrsh_chr))
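    # Inbreeding-coefficient outliers: samples whose F statistic lies more
    # than 3 standard deviations from the cohort mean.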
    n_ib = mt.aggregate_cols(
        hl.agg.count_where(
            hl.abs(mt.ib.f_stat - mt.ib_stats.mean) > 3 * mt.ib_stats.stdev))
    n_vcrsh = mt.aggregate_rows(
        hl.agg.count_where(mt.variant_qc.call_rate < vcrsh))
    n_mafsh = mt.aggregate_rows(
        hl.agg.count_where((mt.variant_qc.AF[0] < mafrsh)
                           | (mt.variant_qc.AF[1] < mafrsh)))

    # Histogram of QC metrics before QC
    histogram(mt.cols(),
              location='sample_qc.call_rate',
              plot_path=directory_structure["summary"] +
              "/hist_before_QC_sample_call_rate.png")
    histogram(mt.cols(),
              location='ib.f_stat',
              plot_path=directory_structure["summary"] +
              "/hist_before_QC_inbreeding_coefficient.png")
    histogram(mt.rows(),
              location='variant_qc.call_rate',
              plot_path=directory_structure["summary"] +
              "/hist_before_QC_variant_call_rate.png")

    # Filter samples & variants
    mt = mt.filter_cols(
        (mt.sample_qc.call_rate > scrsh) &
        (hl.abs(mt.ib.f_stat - mt.ib_stats.mean) <= 3 * mt.ib_stats.stdev) &
        (mt.scr_chr > scrsh_chr),
        keep=True)
    mt = mt.filter_rows(
        (mt.variant_qc.call_rate > vcrsh) & (mt.variant_qc.AF[0] > mafrsh) &
        (mt.variant_qc.AF[1] > mafrsh) & (mt.locus.in_autosome()),
        keep=True)

    # Histogram of QC metrics after QC
    histogram(mt.cols(),
              location='sample_qc.call_rate',
              plot_path=directory_structure["summary"] +
              "/hist_after_QC_sample_call_rate.png")
    histogram(mt.cols(),
              location='ib.f_stat',
              plot_path=directory_structure["summary"] +
              "/hist_after_QC_inbreeding_coefficient.png")
    histogram(mt.rows(),
              location='variant_qc.call_rate',
              plot_path=directory_structure["summary"] +
              "/hist_after_QC_variant_call_rate.png")

    # LD pruning & Identify related samples
    pruned_variants_list = ld_prune(mt, r2=r2, pruned_variants_list=True)
    pruned_variants_list.write(directory_structure["pca"] +
                               "/ld_pruned_variants.kt",
                               overwrite=True)
    pruned_variants_list = hl.read_table(directory_structure["pca"] +
                                         "/ld_pruned_variants.kt")
    pruned_mt = mt.filter_rows(hl.is_defined(pruned_variants_list[mt.row_key]),
                               keep=True)
    samples_to_remove = identify_related_samples(pruned_mt,
                                                 pihat_threshold=pirsh)

    # Filter related samples
    n_related = samples_to_remove.count()
    mt = mt.filter_cols(hl.is_defined(samples_to_remove[mt.col_key]),
                        keep=False)

    # Project onto 1000 genomes
    pruned_1kg_mt = mt_1kg.filter_rows(hl.is_defined(
        pruned_variants_list[mt_1kg.row_key]),
                                       keep=True)
    scores_1kg_ht, loadings_1kg_ht = pca(mt=pruned_1kg_mt,
                                         n_evecs=6,
                                         remove_outliers=False)
    scores_ht = pca_project(mt=mt,
                            loadings_ht=loadings_1kg_ht,
                            correct_shrinkage=True)
    scores_ht = scores_1kg_ht.union(scores_ht)
    scores_ht = scores_ht.annotate(
        super_population=mt_1kg.index_cols(scores_ht.key).super_population)

    # Use Random forest classifier to assign ancestry
    pops_ht, pop_clf = assign_population_pcs(pop_pca_scores=scores_ht,
                                             pc_cols=scores_ht.scores,
                                             known_col='super_population',
                                             min_prob=0.9,
                                             prop_train=0.8)

    # Plot ancestry assignment result
    pops_ht = pops_ht.annotate(PC1=pops_ht.pca_scores[0],
                               PC2=pops_ht.pca_scores[1],
                               PC3=pops_ht.pca_scores[2],
                               PC4=pops_ht.pca_scores[3],
                               PC5=pops_ht.pca_scores[4],
                               PC6=pops_ht.pca_scores[5])
    scatter(ht=pops_ht,
            x_location='PC1',
            y_location='PC2',
            color_location='pop',
            plot_path=directory_structure['pca'] + '/SCATTER_1KG_PC1_PC2.png')
    scatter(ht=pops_ht,
            x_location='PC1',
            y_location='PC3',
            color_location='pop',
            plot_path=directory_structure['pca'] + '/SCATTER_1KG_PC1_PC3.png')
    scatter(ht=pops_ht,
            x_location='PC2',
            y_location='PC3',
            color_location='pop',
            plot_path=directory_structure['pca'] + '/SCATTER_1KG_PC2_PC3.png')
    scatter(ht=pops_ht,
            x_location='PC3',
            y_location='PC4',
            color_location='pop',
            plot_path=directory_structure['pca'] + '/SCATTER_1KG_PC3_PC4.png')
    scatter(ht=pops_ht,
            x_location='PC4',
            y_location='PC5',
            color_location='pop',
            plot_path=directory_structure['pca'] + '/SCATTER_1KG_PC4_PC5.png')
    scatter(ht=pops_ht,
            x_location='PC5',
            y_location='PC6',
            color_location='pop',
            plot_path=directory_structure['pca'] + '/SCATTER_1KG_PC5_PC6.png')
    pops_ht.write(directory_structure['pca'] + '/scores_1kg.kt',
                  overwrite=True)
    mt = mt.annotate_cols(pop=pops_ht[mt.col_key].pop)

    # For samples assigned to the EUR population, project them onto 1KG EUR
    # and classify into eur(mainland), FIN and AJ
    if mt.aggregate_cols(hl.agg.count_where(mt.pop == 'eur')) > 0:
        mt_eur = mt.filter_cols(mt.pop == 'eur', keep=True)
        mt_1kg_eur_mt = mt_1kg_eur.filter_rows(hl.is_defined(
            pruned_variants_list[mt_1kg_eur.row_key]),
                                               keep=True)
        scores_1kg_eur_ht, loadings_1kg_eur_ht = pca(mt=mt_1kg_eur_mt,
                                                     n_evecs=6,
                                                     remove_outliers=False)
        scores_ht = pca_project(mt=mt_eur,
                                loadings_ht=loadings_1kg_eur_ht,
                                correct_shrinkage=True)
        scores_ht = scores_1kg_eur_ht.union(scores_ht)
        scores_ht = scores_ht.annotate(
            population=mt_1kg_eur.index_cols(scores_ht.key).population)

        # Use Random forest classifier to assign ancestry
        pops_ht, pop_clf = assign_population_pcs(pop_pca_scores=scores_ht,
                                                 pc_cols=scores_ht.scores,
                                                 known_col='population',
                                                 min_prob=0.9,
                                                 prop_train=0.8)

        # Plot ancestry assignment result
        pops_ht = pops_ht.annotate(PC1=pops_ht.pca_scores[0],
                                   PC2=pops_ht.pca_scores[1],
                                   PC3=pops_ht.pca_scores[2],
                                   PC4=pops_ht.pca_scores[3],
                                   PC5=pops_ht.pca_scores[4],
                                   PC6=pops_ht.pca_scores[5])
        scatter(ht=pops_ht,
                x_location='PC1',
                y_location='PC2',
                color_location='pop',
                plot_path=directory_structure['pca'] +
                '/SCATTER_1KG_EUR_PC1_PC2.png')
        scatter(ht=pops_ht,
                x_location='PC1',
                y_location='PC3',
                color_location='pop',
                plot_path=directory_structure['pca'] +
                '/SCATTER_1KG_EUR_PC1_PC3.png')
        scatter(ht=pops_ht,
                x_location='PC2',
                y_location='PC3',
                color_location='pop',
                plot_path=directory_structure['pca'] +
                '/SCATTER_1KG_EUR_PC2_PC3.png')
        scatter(ht=pops_ht,
                x_location='PC3',
                y_location='PC4',
                color_location='pop',
                plot_path=directory_structure['pca'] +
                '/SCATTER_1KG_EUR_PC3_PC4.png')
        scatter(ht=pops_ht,
                x_location='PC4',
                y_location='PC5',
                color_location='pop',
                plot_path=directory_structure['pca'] +
                '/SCATTER_1KG_EUR_PC4_PC5.png')
        scatter(ht=pops_ht,
                x_location='PC5',
                y_location='PC6',
                color_location='pop',
                plot_path=directory_structure['pca'] +
                '/SCATTER_1KG_EUR_PC5_PC6.png')
        pops_ht.write(directory_structure['pca'] + '/scores_1kg_eur.kt',
                      overwrite=True)
        mt = mt.transmute_cols(
            pop=hl.cond(mt.pop == 'eur', pops_ht[mt.col_key].pop, mt.pop))

    # Filter out samples that are not assigned to any populations
    mt = mt.filter_cols(mt.pop == 'oth', keep=False)

    # Count number of samples in each population
    pops = mt.aggregate_cols(hl.agg.counter(mt.pop))

    # Calculate p-values of Hardy Weinberg Equilibrium within each population
    for pop, count in pops.items():
        if count < srh:  # skip the HWE test if this population has too few samples
            continue
        mt = mt.annotate_rows(
            **{
                "pop_" + pop:
                hl.agg.filter(mt.pop == pop, hl.agg.hardy_weinberg_test(mt.GT))
            })

    # Calculate minimum hwe p-values across populations
    mt = mt.annotate_rows(
        het_freq_hwe=hl.min([
            mt['pop_' + pop].het_freq_hwe
            for pop, count in pops.items() if count >= srh
        ]),
        p_value=hl.min([
            mt['pop_' + pop].p_value
            for pop, count in pops.items() if count >= srh
        ]))

    # Count number of variants failing HWE filter
    n_hrsh = mt.aggregate_rows(
        hl.agg.count_where((mt.het_freq_hwe < hwersh) | (mt.p_value < hwersh)))

    # Filter out variants that fail HWE
    mt = mt.filter_rows((mt.het_freq_hwe > hwersh) & (mt.p_value > hwersh),
                        keep=True)

    # Count samples & variants after QC
    n_variants_after_qc, n_samples_after_qc = mt.count()

    # create a dictionary storing the meta information
    meta = {
        "Number of Samples before QC": n_samples_before_qc,
        "Number of Samples after QC:": n_samples_after_qc,
        "Number of Variants before QC": n_variants_before_qc,
        "Number of Variants after QC": n_variants_after_qc,
        "Sample QC Summary": {
            "IDs: call rate < %s" % scrsh: n_scrsh,
            "IDs: minimum per-chromosome call rate < %s" % scrsh_chr:
            n_scrsh_chr,
            "IDs: pi-hat > %s" % pirsh: n_related,
            "IDs: inbreeding coefficient is 3 stdev from its mean": n_ib
        },
        "Variant QC Summary": {
            "SNPs: call rate < %s" % vcrsh: n_vcrsh,
            "SNPs: minor allele frequency < %s" % mafrsh: n_mafsh,
            "SNPs: HWE p-values < %s" % hwersh: n_hrsh
        },
        "Population Assignment": {
            "EUR (mainland)": pops.get("eur(mainland)", 0),
            "Ashkenazi Jew": pops.get("aj", 0),
            "FIN": pops.get("fin", 0),
            "EAS": pops.get("asn", 0),
            "AFR": pops.get("afr", 0),
            "AMR": pops.get("amr", 0)
        }
    }
    mt.write(directory_structure["share"] + "/all.mt", overwrite=True)

    # Export meta
    with hl.hadoop_open(directory_structure["summary"] + "/meta.json",
                        'w') as outfile:
        json.dump(meta, outfile)

    # Write hail MatrixTable
    for pop, count in pops.items():
        pop_mt = mt.filter_cols(mt.pop == pop)
        pop_mt.write(directory_structure["share"] +
                     "/{pop}.mt".format(pop=pop),
                     overwrite=True)
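
# The PC-flattening blocks above (PC1..PC6 pulled out of a `scores` /
# `pca_scores` array field) repeat across functions; a small helper along
# these lines could replace them. This is a sketch, not part of the original
# module:
def annotate_pc_columns(ht, source_field: str = 'scores', n_pcs: int = 6):
    """Flatten the first `n_pcs` entries of an array field into PC1..PCn columns."""
    return ht.annotate(
        **{'PC%d' % (i + 1): ht[source_field][i] for i in range(n_pcs)})
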
Example #7
def __main__(path_to_mt: str, output_root_directory: str, population: str,
             cohort_label: str, mt_1kg_eur_path: str, r2: float, pirsh: float,
             scrsh: float, scrsh_chr: float, vcrsh: float, mafrsh: float,
             hwersh: float, assrsh: float):
    """
        This function performs cohort-level QC including the following QC filters:
        - sample call rate > 0.98
        - sample call rate for each chromosome > 0.50
        - remove related samples (pi-hat threshold = 0.0625)
        - variant call rate > 0.98
        - minor allele frequencies > 0.01
        - within each population, p-hwe > 1e-4
        - pseudo case-control GWAS (tagging samples from one cohort as cases, from the other cohorts as controls) p-values > 1e-4
        - pseudo case-control GWAS against 1KG p-values > 1e-4 (currently it is EUR only)
        Note
        In order to convert VCF to zipped VCF, run the following commands:
            sudo apt-get update
            sudo apt-get install vcftools -y
            sudo apt-get install tabix -y
        ----
        :param str path_to_mt: path to MatrixTable (cohorts from the same population & genotyping array)
        :param str output_root_directory: root directory for QC outputs
        :param str population: population label
        :param str cohort_label: cohort / genotyping-array label
        :param str mt_1kg_eur_path: path to 1KG EUR population
        :param float r2: LD pruning threshold
        :param float pirsh: pi-hat threshold for filtering related samples
        :param float scrsh: sample call rate threshold
        :param float scrsh_chr: per-chromosome sample call rate threshold
        :param float vcrsh: variant call rate threshold
        :param float mafrsh: minor allele frequency threshold
        :param float hwersh: Hardy-Weinberg equilibrium test p-value threshold
        :param float assrsh: p-value threshold for pseudo GWAS
        """

    # Read the merged MatrixTable (cohorts from the same population and genotyping array)
    mt = hl.read_matrix_table(path_to_mt)

    mt_1kg_eur = hl.read_matrix_table(mt_1kg_eur_path)
    mt_1kg_eur = mt_1kg_eur.filter_cols(
        mt_1kg_eur.population == 'eur(mainland)')

    # Build Directory structure
    directory_structure = build_array_qc_directory_structure(
        output_root_directory, population=population, array=cohort_label)

    # Count number of samples and variants before QC
    n_variants_before_qc, n_samples_before_qc = mt.count()

    # Perform sample QC
    mt = hl.sample_qc(mt)
    mt = hl.variant_qc(mt)
    mt = calculate_per_chr_scrt(mt=mt, scrt_col='scr_chr')
    mt = mt.cache()

    # Calculate number of samples & variants that fail each filter
    n_scrsh = mt.aggregate_cols(
        hl.agg.count_where(mt.sample_qc.call_rate < scrsh))
    n_scrsh_chr = mt.aggregate_cols(hl.agg.count_where(mt.scr_chr < scrsh_chr))
    n_vcrsh = mt.aggregate_rows(
        hl.agg.count_where(mt.variant_qc.call_rate < vcrsh))
    n_mafsh = mt.aggregate_rows(
        hl.agg.count_where((mt.variant_qc.AF[0] < mafrsh)
                           | (mt.variant_qc.AF[1] < mafrsh)))
    n_hwesh = mt.aggregate_rows(
        hl.agg.count_where((mt.het_freq_hwe < hwersh) | (mt.p_value < hwersh)))

    # Filter samples & variants
    mt = mt.filter_cols(
        (mt.sample_qc.call_rate > scrsh) & (mt.scr_chr > scrsh_chr), keep=True)
    mt = mt.filter_rows(
        (mt.variant_qc.call_rate > vcrsh) & (mt.variant_qc.AF[0] > mafrsh) &
        (mt.variant_qc.AF[1] > mafrsh) & (mt.het_freq_hwe > hwersh) &
        (mt.p_value > hwersh),
        keep=True)

    # LD pruning
    pruned_variants_list = ld_prune(mt, r2=r2, pruned_variants_list=True)
    pruned_mt = mt.filter_rows(hl.is_defined(pruned_variants_list[mt.row_key]),
                               keep=True)
    samples_to_remove = identify_related_samples(pruned_mt,
                                                 pihat_threshold=pirsh)

    # Filter related samples
    n_related = samples_to_remove.count()
    mt = mt.filter_cols(hl.is_defined(samples_to_remove[mt.col_key]),
                        keep=False)

    # Project onto 1KG EUR principal components (used by the pseudo case-control analysis below)
    n_cscohort = 0
    pruned_1kg_eur_mt = mt_1kg_eur.filter_rows(hl.is_defined(
        pruned_variants_list[mt_1kg_eur.row_key]),
                                               keep=True)
    scores_1kg_eur_ht, loadings_1kg_eur_ht = pca(mt=pruned_1kg_eur_mt,
                                                 n_evecs=6,
                                                 remove_outliers=False)
    scores_ht = pca_project(mt=mt,
                            loadings_ht=loadings_1kg_eur_ht,
                            correct_shrinkage=True)
    mt = mt.annotate_cols(scores=scores_ht[mt.col_key].scores)

    # Pseudo Case-Control Analysis against 1KG EUR
    n_cs1kg = 0
    if population == "eur_mainland":
        mt_merged = mt.select_cols().select_rows().union_cols(
            mt_1kg_eur.select_cols().select_rows())
        scores_ht = scores_ht.union(scores_1kg_eur_ht)

        mt_merged = mt_merged.annotate_cols(
            scores=scores_ht[mt_merged.col_key].scores,
            is_case=hl.cond(hl.is_defined(mt.index_cols(mt_merged.col_key)), 1,
                            0))
        result_ht = hl.linear_regression_rows(
            y=mt_merged.is_case,
            x=mt_merged.GT.n_alt_alleles(),
            covariates=[
                1, mt_merged.scores[0], mt_merged.scores[1],
                mt_merged.scores[2], mt_merged.scores[3], mt_merged.scores[4],
                mt_merged.scores[5]
            ])

        # Annotate p-values from comparison against 1KG EUR
        mt = mt.annotate_rows(pvalues_1kg_assoc=result_ht[mt.row_key].p_value)

        # Count number of variants failing comparison against 1KG EUR
        n_cs1kg = mt.aggregate_rows(
            hl.agg.count_where(mt.pvalues_1kg_assoc < assrsh))

        # Filter variants
        mt = mt.filter_rows((mt.pvalues_1kg_assoc > assrsh) |
                            (hl.is_missing(mt.pvalues_1kg_assoc)),
                            keep=True)

        # PLOT projecting onto 1KG EUR
        scores_ht = scores_ht.annotate(PC1=scores_ht.scores[0],
                                       PC2=scores_ht.scores[1],
                                       PC3=scores_ht.scores[2],
                                       PC4=scores_ht.scores[3],
                                       PC5=scores_ht.scores[4],
                                       PC6=scores_ht.scores[5])
        scatter(ht=scores_ht,
                x_location='PC1',
                y_location='PC2',
                plot_path=directory_structure['pca'] +
                '/SCATTER_1KG_EUR_PC1_PC2.png')
        scatter(ht=scores_ht,
                x_location='PC1',
                y_location='PC3',
                plot_path=directory_structure['pca'] +
                '/SCATTER_1KG_EUR_PC1_PC3.png')
        scatter(ht=scores_ht,
                x_location='PC2',
                y_location='PC3',
                plot_path=directory_structure['pca'] +
                '/SCATTER_1KG_EUR_PC2_PC3.png')
        scatter(ht=scores_ht,
                x_location='PC3',
                y_location='PC4',
                plot_path=directory_structure['pca'] +
                '/SCATTER_1KG_EUR_PC3_PC4.png')
        scatter(ht=scores_ht,
                x_location='PC4',
                y_location='PC5',
                plot_path=directory_structure['pca'] +
                '/SCATTER_1KG_EUR_PC4_PC5.png')
        scatter(ht=scores_ht,
                x_location='PC5',
                y_location='PC6',
                plot_path=directory_structure['pca'] +
                '/SCATTER_1KG_EUR_PC5_PC6.png')
        scores_ht.write(directory_structure['pca'] + '/scores_1kg.kt',
                        overwrite=True)

    # Restrict to the previously LD-pruned variants for PCA
    pruned_mt = mt.filter_rows(hl.is_defined(pruned_variants_list[mt.row_key]),
                               keep=True)

    # PCA
    scores_ht, _ = pca(mt=pruned_mt,
                       n_evecs=10,
                       remove_outliers=True,
                       sigma_thresh=5,
                       n_outlieriters=3)
    scores_ht.write(directory_structure["pca"] + "/scores.ht", overwrite=True)

    # Plot PCs with cohort label
    scores_ht = scores_ht.annotate(PC1=scores_ht.scores[0],
                                   PC2=scores_ht.scores[1],
                                   PC3=scores_ht.scores[2],
                                   PC4=scores_ht.scores[3],
                                   PC5=scores_ht.scores[4],
                                   PC6=scores_ht.scores[5])
    scatter(ht=scores_ht,
            x_location='PC1',
            y_location='PC2',
            plot_path=directory_structure['pca'] +
            '/SCATTER_PC1_PC2_AFTER_QC.png')
    scatter(ht=scores_ht,
            x_location='PC1',
            y_location='PC3',
            plot_path=directory_structure['pca'] +
            '/SCATTER_PC1_PC3_AFTER_QC.png')
    scatter(ht=scores_ht,
            x_location='PC2',
            y_location='PC3',
            plot_path=directory_structure['pca'] +
            '/SCATTER_PC2_PC3_AFTER_QC.png')
    scatter(ht=scores_ht,
            x_location='PC3',
            y_location='PC4',
            plot_path=directory_structure['pca'] +
            '/SCATTER_PC3_PC4_AFTER_QC.png')
    scatter(ht=scores_ht,
            x_location='PC4',
            y_location='PC5',
            plot_path=directory_structure['pca'] +
            '/SCATTER_PC4_PC5_AFTER_QC.png')
    scatter(ht=scores_ht,
            x_location='PC5',
            y_location='PC6',
            plot_path=directory_structure['pca'] +
            '/SCATTER_PC5_PC6_AFTER_QC.png')

    # Count PCA outliers
    n_pca_outliers = mt.aggregate_cols(
        hl.agg.count_where(~hl.is_defined(scores_ht[mt.col_key])))

    # Filter PCA outliers
    mt = mt.filter_cols(hl.is_defined(scores_ht[mt.col_key]), keep=True)

    # Count number of samples and variants after QC
    n_variants_after_qc, n_samples_after_qc = mt.count()

    # Write QC'ed data to google bucket
    mt.write(directory_structure["share"] + "/final.mt", overwrite=True)
    #mt = hl.read_matrix_table(directory_structure["share"]+"/final.mt")

    # Write QC meta information
    meta = {
        "Number of Samples before QC": n_samples_before_qc,
        "Number of Samples after QC:": n_samples_after_qc,
        "Number of Variants before QC": n_variants_before_qc,
        "Number of Variants after QC": n_variants_after_qc,
        "Genotyping Array": cohort_label,
        "Population": population,
        "Sample QC Summary": {
            "IDs: call rate < %s" % scrsh: n_scrsh,
            "IDs: minimum per-chromosome call rate < %s" % scrsh_chr:
            n_scrsh_chr,
            "IDs: pi-hat > %s" % pirsh: n_related,
            "IDs: PCA outliers": n_pca_outliers
        },
        "Variant QC Summary": {
            "SNPs: call rate < %s" % vcrsh:
            n_vcrsh,
            "SNPs: minor allele frequency < %s" % mafrsh:
            n_mafsh,
            "SNPs: HWE p-values < %s" % hwersh:
            n_hwesh,
            "SNPs: Cross cohorts pseudo case-control across p-value < %s" % assrsh:
            n_cscohort + n_cs1kg
        }
    }
    with hl.hadoop_open(directory_structure["summary"] + "/meta.json",
                        'w') as outfile:
        json.dump(meta, outfile)

    # Convert to VCF files
    convert_to_vcf(mt=mt,
                   output_root_for_unzip_vcf=directory_structure['vcf'] +
                   "/unzip",
                   output_root_for_zip_vcf=directory_structure['vcf'] + "/zip",
                   basename='final')
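
# Illustrative call for the cohort-level QC entry point above (placeholder
# paths; thresholds follow the values listed in the docstring):
#
# __main__(path_to_mt="gs://<bucket>/eur_mainland/cohort_a.mt",
#          output_root_directory="gs://<bucket>/cohort_qc",
#          population="eur_mainland", cohort_label="cohort_a",
#          mt_1kg_eur_path="gs://<bucket>/1kg_eur.mt",
#          r2=0.2, pirsh=0.0625, scrsh=0.98, scrsh_chr=0.50,
#          vcrsh=0.98, mafrsh=0.01, hwersh=1e-4, assrsh=1e-4)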