Exemplo n.º 1
0
def main():
    """Enumerate the paired-gene subtypes to test in the isolation experiment.

    Parses the command line, loads the mutation data of the given TCGA cohort
    for the given genes, and for every gene finds:

    * point-mutation and copy-number subtypes whose samples carry no other
      mutation of the gene and no mutation of the other given genes in at
      least ``samp_cutoff`` samples,
    * pairs of disjoint such subtypes co-occurring exclusively in at least
      ``samp_cutoff`` samples.

    It also adds combinations of the given genes whose samples are mutated
    without the remaining genes, as long as both the mutated and wild-type
    classes keep ``samp_cutoff`` samples. The sorted list of found subtypes
    and their count are written under ``<base_dir>/setup/<cohort>/<genes>``.
    """
    parser = argparse.ArgumentParser(
        "Set up the paired-gene subtype expression effect isolation "
        "experiment by enumerating the subtypes to be tested.")

    # create positional command line arguments
    parser.add_argument('cohort', type=str, help="which TCGA cohort to use")
    parser.add_argument('mut_levels',
                        type=str,
                        help="the mutation property levels to consider")
    parser.add_argument('genes',
                        type=str,
                        nargs='+',
                        help="a list of mutated genes")

    # create optional command line arguments
    parser.add_argument('--samp_cutoff',
                        type=int,
                        default=20,
                        help='subtype sample frequency threshold')
    parser.add_argument('--verbose',
                        '-v',
                        action='store_true',
                        help='turns on diagnostic messages')

    # parse command line arguments, create directory where found subtypes
    # will be stored
    args = parser.parse_args()
    use_lvls = args.mut_levels.split('__')
    out_path = os.path.join(base_dir, 'setup', args.cohort,
                            '_'.join(args.genes))
    os.makedirs(out_path, exist_ok=True)

    # log into Synapse using locally stored credentials
    syn = synapseclient.Synapse()
    syn.cache.cache_root_dir = syn_root
    syn.login()

    cdata = MutationCohort(cohort=args.cohort,
                           mut_genes=args.genes,
                           mut_levels=['Gene'] + use_lvls,
                           expr_source='Firehose',
                           var_source='mc3',
                           copy_source='Firehose',
                           annot_file=annot_file,
                           expr_dir=expr_dir,
                           domain_dir=domain_dir,
                           cv_prop=1.0,
                           syn=syn)

    # a subtype may not cover so many samples that the wild-type class
    # would be left with fewer than the cutoff
    max_size = len(cdata.samples) - args.samp_cutoff

    iso_mtypes = set()
    for gene in args.genes:
        gene_muts = cdata.train_mut[gene]

        # samples mutated in any of the other given genes; the empty-set
        # initializer keeps this working when only one gene was given
        other_samps = reduce(or_, [
            cdata.train_mut[other_gn].get_samples()
            for other_gn in set(args.genes) - {gene}
        ], set())

        if args.verbose:
            print("Looking for combinations of subtypes of mutations in gene "
                  "{} present in at least {} of the samples in TCGA cohort "
                  "{} at annotation levels {}.\n".format(
                      gene, args.samp_cutoff, args.cohort, use_lvls))

        pnt_mtypes = gene_muts['Point'].find_unique_subtypes(
            max_types=500,
            max_combs=2,
            verbose=2,
            sub_levels=use_lvls,
            min_type_size=args.samp_cutoff)

        # filter out the subtypes that appear in too many samples for there to
        # be a wild-type class of sufficient size for classification
        pnt_mtypes = {
            MuType({('Scale', 'Point'): mtype})
            for mtype in pnt_mtypes
            if len(mtype.get_samples(gene_muts['Point'])) <= max_size
        }
        pnt_mtypes |= {MuType({('Scale', 'Point'): None})}

        cna_mtypes = gene_muts['Copy'].branchtypes(min_size=args.samp_cutoff)
        cna_mtypes |= {MuType({('Copy', ('HetGain', 'HomGain')): None})}
        cna_mtypes |= {MuType({('Copy', ('HetDel', 'HomDel')): None})}

        cna_mtypes = {
            MuType({('Scale', 'Copy'): mtype})
            for mtype in cna_mtypes
            if len(mtype.get_samples(gene_muts['Copy'])) <= max_size
        }

        all_mtype = MuType(gene_muts.allkey())
        use_mtypes = pnt_mtypes | cna_mtypes

        # subtypes whose samples carry no other mutation of this gene and no
        # mutation of the other given genes
        only_mtypes = {
            (MuType({('Gene', gene): mtype}), )
            for mtype in use_mtypes
            if len(mtype.get_samples(gene_muts)
                   - (all_mtype - mtype).get_samples(gene_muts)
                   - other_samps) >= args.samp_cutoff
        }

        # pairs of disjoint subtypes co-occurring in samples that carry no
        # other mutation of this gene and no mutation of the other genes
        comb_mtypes = set()
        for mtype1, mtype2 in combn(use_mtypes, 2):
            if not (mtype1 & mtype2).is_empty():
                continue

            both_samps = (mtype1.get_samples(gene_muts)
                          & mtype2.get_samples(gene_muts))
            rest_samps = (all_mtype - mtype1 - mtype2).get_samples(gene_muts)

            if len(both_samps - rest_samps - other_samps) >= args.samp_cutoff:
                comb_mtypes.add((MuType({('Gene', gene): mtype1}),
                                 MuType({('Gene', gene): mtype2})))

        iso_mtypes |= only_mtypes | comb_mtypes
        if args.verbose:
            print(
                "\nFound {} exclusive sub-types and {} combination sub-types "
                "to isolate!".format(len(only_mtypes), len(comb_mtypes)))

    # combinations of the given genes mutated without the remaining genes
    for cur_genes in chain.from_iterable(
            combn(args.genes, r) for r in range(1, len(args.genes))):
        gene_mtype = MuType({('Gene', cur_genes): None})
        rest_mtype = MuType({
            ('Gene', tuple(set(args.genes) - set(cur_genes))): None
        })

        excl_size = len(gene_mtype.get_samples(cdata.train_mut)
                        - rest_mtype.get_samples(cdata.train_mut))
        if args.samp_cutoff <= excl_size <= max_size:
            iso_mtypes.add((gene_mtype, ))

    if args.verbose:
        print("\nFound {} total sub-types to isolate!".format(len(iso_mtypes)))

    # save the list of found non-duplicate sub-types to file
    sorted_mtypes = sorted(iso_mtypes)
    list_file = os.path.join(
        out_path, 'mtypes_list__samps_{}__levels_{}.p'.format(
            args.samp_cutoff, args.mut_levels))
    with open(list_file, 'wb') as fl:
        pickle.dump(sorted_mtypes, fl)

    count_file = os.path.join(
        out_path, 'mtypes_count__samps_{}__levels_{}.txt'.format(
            args.samp_cutoff, args.mut_levels))
    with open(count_file, 'w') as fl:
        fl.write(str(len(iso_mtypes)))
Exemplo n.º 2
0
def plot_overlap_divergence(pred_dfs, pheno_dicts, auc_lists, cdata_dict, args,
                            siml_metric):
    """Plot how overlapping mutations diverge from singletons in each dataset.

    For every dataset, the subgroupings kept are the ``ExMcomb`` objects that
    meet all of these criteria:

    * task AUC of at least ``args.auc_cutoff``;
    * exclusivity label ``args.ex_lbl``;
    * every member subtype is a subset of point mutations.

    Each subgrouping's predicted scores are summarized with
    ``siml_fxs[siml_metric]``, applied to three groups of samples:

    * samples without any of the gene's mutations;
    * the subgrouping's own samples;
    * comparison samples outside the subgrouping. For single-subtype
      subgroupings these carry several point mutations of the gene. For
      multi-subtype subgroupings they carry exactly one.

    The similarities are plotted against AUC in two panels. Single-subtype
    subgroupings go on the top panel and all others on the bottom. The figure
    is saved as an SVG in ``plot_dir``.

    Args:
        pred_dfs: per-(source, cohort) prediction tables, indexed by
            subgrouping with one column per training sample.
        pheno_dicts: per-(source, cohort) dicts mapping each subgrouping to a
            phenotype vector aligned with the training samples.
        auc_lists: per-(source, cohort) Series of task AUCs by subgrouping.
        cdata_dict: per-(source, cohort) cohort datasets.
        args: parsed command line; ``auc_cutoff``, ``ex_lbl`` and ``classif``
            are read.
        siml_metric: key into ``siml_fxs`` selecting the similarity function.
    """

    def _panel(mcomb):
        # single-subtype subgroupings go on the top panel (1), all others on
        # the bottom (2); used everywhere a panel is chosen so that points,
        # gene labels and marker sizing stay consistent
        return 1 if len(mcomb.mtypes) == 1 else 2

    _, (sngl_ax, mult_ax) = plt.subplots(figsize=(12, 14), nrows=2)
    siml_dicts = {(src, coh): dict() for src, coh in auc_lists}

    # for each dataset, find the subgroupings meeting the minimum task AUC
    # that are exclusively defined and subsets of point mutations...
    for (src, coh), auc_list in auc_lists.items():
        # Series.iteritems() was removed in pandas 2.0; items() is equivalent
        use_combs = remove_pheno_dups(
            {
                mut
                for mut, auc_val in auc_list.items()
                if (isinstance(mut, ExMcomb) and auc_val >= args.auc_cutoff
                    and get_mut_ex(mut) == args.ex_lbl and all(
                        pnt_mtype.is_supertype(get_subtype(mtype))
                        for mtype in mut.mtypes))
            }, pheno_dicts[src, coh])

        # skip this dataset for plotting if we cannot find any such pairs
        if not use_combs:
            continue

        # get sample order used in the cohort and a breakdown of mutations
        # in which each individual mutation can be uniquely identified
        train_samps = cdata_dict[src, coh].get_train_samples()
        use_mtree = cdata_dict[src, coh].mtrees['Gene', 'Scale', 'Copy',
                                                'Exon', 'Position', 'HGVSp']
        use_genes = {get_label(mcomb) for mcomb in use_combs}
        cmp_phns = {gene: {'Sngl': None, 'Mult': None} for gene in use_genes}

        # get the samples carrying a single point mutation or multiple
        # mutations of each gene with at least one mutation in the cohort
        for gene in use_genes:
            gene_tree = use_mtree[gene]['Point']

            if args.ex_lbl == 'Iso':
                gene_cpy = MuType({('Gene', gene): copy_mtype})
            else:
                gene_cpy = MuType({('Gene', gene): deep_mtype})

            # samples with the copy alteration type above are never counted
            cpy_samps = gene_cpy.get_samples(use_mtree)
            samp_counts = {
                samp: 0
                for samp in (gene_tree.get_samples() - cpy_samps)
            }

            for subk in MuType(gene_tree.allkey()).leaves():
                for samp in MuType(subk).get_samples(gene_tree):
                    if samp in samp_counts:
                        samp_counts[samp] += 1

            for samp in train_samps:
                if samp not in samp_counts:
                    samp_counts[samp] = 0

            cmp_phns[gene]['Sngl'] = np.array(
                [samp_counts[samp] == 1 for samp in train_samps])
            cmp_phns[gene]['Mult'] = np.array(
                [samp_counts[samp] > 1 for samp in train_samps])

        all_mtypes = {
            gene: MuType({('Gene', gene): use_mtree[gene].allkey()})
            for gene in use_genes
        }

        if args.ex_lbl == 'IsoShal':
            for gene in use_genes:
                all_mtypes[gene] -= MuType({('Gene', gene): shal_mtype})

        all_phns = {
            gene: np.array(cdata_dict[src, coh].train_pheno(all_mtype))
            for gene, all_mtype in all_mtypes.items()
        }

        # for each subgrouping, compare its predictions with those of the
        # wild-type and comparison samples of the same gene
        for mcomb in use_combs:
            cur_gene = get_label(mcomb)
            use_preds = pred_dfs[src, coh].loc[mcomb, train_samps]

            # find the samples carrying one or multiple point mutations of
            # this gene not belonging to this subgrouping
            cmp_phn = ~pheno_dicts[src, coh][mcomb]
            if len(mcomb.mtypes) == 1:
                cmp_phn &= cmp_phns[cur_gene]['Mult']
            else:
                cmp_phn &= cmp_phns[cur_gene]['Sngl']

            if cmp_phn.sum() >= 1:
                siml_dicts[src, coh][mcomb] = siml_fxs[siml_metric](
                    use_preds.loc[~all_phns[cur_gene]],
                    use_preds.loc[pheno_dicts[src, coh][mcomb]],
                    use_preds.loc[cmp_phn])

    # nothing to plot: avoid building axis limits from an empty table
    if not any(siml_dicts.values()):
        plt.close()
        return

    plt_df = pd.DataFrame(
        {'Siml': pd.DataFrame.from_records(siml_dicts).stack()})
    plt_df['AUC'] = [
        auc_lists[src, coh][mcomb] for mcomb, (src, coh) in plt_df.index
    ]

    gene_means = plt_df.groupby(
        lambda x: (get_label(x[0]), _panel(x[0]))).mean()
    clr_dict = {
        gene: choose_label_colour(gene)
        for gene, _ in gene_means.index
    }
    size_mult = plt_df.groupby(
        lambda x: _panel(x[0])).Siml.count().max()**-0.23

    xlims = [
        args.auc_cutoff - (1 - args.auc_cutoff) / 47,
        1 + (1 - args.auc_cutoff) / 277
    ]
    ymin, ymax = plt_df.Siml.quantile(q=[0, 1])
    yrng = ymax - ymin
    ylims = [ymin - yrng / 23, ymax + yrng / 23]

    # per-panel label anchors: gene means get a gene label, individual
    # points and reference lines only reserve space for label placement
    plot_dicts = {
        mcomb_i:
        {(auc_val, siml_val): [0.0001, (gene, '')]
         for (gene, mcomb_indx), (siml_val, auc_val) in gene_means.iterrows()
         if mcomb_indx == mcomb_i}
        for mcomb_i in [1, 2]
    }

    for (mcomb, (src, coh)), (siml_val, auc_val) in plt_df.iterrows():
        cur_gene = get_label(mcomb)
        mcomb_i = _panel(mcomb)

        plt_size = size_mult * np.mean(pheno_dicts[src, coh][mcomb])
        plot_dicts[mcomb_i][auc_val, siml_val] = [0.19 * plt_size, ('', '')]

        use_ax = sngl_ax if mcomb_i == 1 else mult_ax
        use_ax.scatter(auc_val,
                       siml_val,
                       s=3751 * plt_size,
                       c=[clr_dict[cur_gene]],
                       alpha=0.25,
                       edgecolor='none')

    for ax, mcomb_i in zip([sngl_ax, mult_ax], [1, 2]):
        ax.grid(alpha=0.47, linewidth=0.9)
        ax.plot([1, 1], ylims, color='black', linewidth=1.7, alpha=0.83)

        # reference lines at similarity values of 0 and 1, drawn only when
        # they fall inside the plotted similarity (y-axis) range
        for yval in [0, 1]:
            if ylims[0] < yval < ylims[1]:
                ax.plot(xlims, [yval, yval],
                        color='black',
                        linewidth=1.11,
                        linestyle='--',
                        alpha=0.67)

                for k in np.linspace(args.auc_cutoff, 0.99, 200):
                    if (k, yval) not in plot_dicts[mcomb_i]:
                        plot_dicts[mcomb_i][k, yval] = [1 / 703, ('', '')]

        line_dict = {
            k: {
                'c': clr_dict[v[1][0]]
            }
            for k, v in plot_dicts[mcomb_i].items() if v[1][0]
        }
        font_dict = {
            k: {
                'c': v['c'],
                'weight': 'bold'
            }
            for k, v in line_dict.items()
        }

        place_scatter_labels(plot_dicts[mcomb_i],
                             ax,
                             plt_lims=[xlims, ylims],
                             line_dict=line_dict,
                             font_dict=font_dict,
                             font_size=19)

        ax.xaxis.set_major_locator(plt.MaxNLocator(5, steps=[1, 2, 5]))
        ax.yaxis.set_major_locator(plt.MaxNLocator(7, steps=[1, 2, 5]))
        ax.set_xlim(xlims)
        ax.set_ylim(ylims)

    mult_ax.set_xlabel("Subgrouping Classification Accuracy",
                       size=21,
                       weight='bold')
    sngl_ax.set_ylabel("Overlaps' Similarity to Singletons",
                       size=21,
                       weight='bold')
    mult_ax.set_ylabel("Singletons' Similarity to Overlaps",
                       size=21,
                       weight='bold')

    plt.savefig(os.path.join(
        plot_dir,
        "{}_{}-overlap-divergence_{}.svg".format(args.ex_lbl, siml_metric,
                                                 args.classif)),
                bbox_inches='tight',
                format='svg')

    plt.close()
Exemplo n.º 3
0
def plot_overlap_aucs(pheno_dict, auc_list, cdata, data_tag, args, siml_metric,
                      use_gene):
    """Plot subgrouping AUC against mean point-mutation count for one gene.

    The subgroupings plotted are the single-subtype ``ExMcomb`` objects of
    ``use_gene`` that meet all of these criteria:

    * task AUC of at least ``args.auc_cutoff``;
    * exclusivity label ``args.ex_lbl``;
    * defined by a subset of point mutations.

    Each point's x-value is the task AUC. Its y-value is the average number
    of point mutations of the gene per sample in the subgrouping. Samples
    with the excluded copy-number type are counted as having none. Marker
    area scales with the fraction of training samples in the subgrouping.

    The figure is saved as an SVG under ``plot_dir/<data_tag>/<use_gene>``.
    If no subgrouping qualifies, nothing is plotted or saved.

    Args:
        pheno_dict: maps each subgrouping to a phenotype vector aligned with
            the cohort's training samples.
        auc_list: Series of task AUCs indexed by subgrouping.
        cdata: the cohort dataset.
        data_tag: dataset tag; the part after the first ``'__'`` is the
            cohort shown on the plot.
        args: parsed command line; ``auc_cutoff``, ``ex_lbl`` and ``classif``
            are read.
        siml_metric: similarity metric name, used only in the file name.
        use_gene: the gene whose subgroupings are plotted.
    """
    # Series.iteritems() was removed in pandas 2.0; items() is equivalent
    use_combs = remove_pheno_dups(
        {
            mut
            for mut, auc_val in auc_list.items()
            if (isinstance(mut, ExMcomb) and auc_val >= args.auc_cutoff
                and get_mut_ex(mut) == args.ex_lbl and len(mut.mtypes) == 1
                and get_label(mut) == use_gene and all(
                    pnt_mtype.is_supertype(get_subtype(mtype))
                    for mtype in mut.mtypes))
        }, pheno_dict)

    # nothing qualifies: marker sizing below would raise on an empty table
    if not use_combs:
        return

    fig, ax = plt.subplots(figsize=(13, 7))
    train_samps = cdata.get_train_samples()
    use_mtree = cdata.mtrees['Gene', 'Scale', 'Copy', 'Exon', 'Position',
                             'HGVSp']
    gene_tree = use_mtree[use_gene]['Point']

    if args.ex_lbl == 'Iso':
        gene_cpy = MuType({('Gene', use_gene): copy_mtype})
    else:
        gene_cpy = MuType({('Gene', use_gene): deep_mtype})

    # count each sample's point mutations of the gene, leaving samples with
    # the copy-number type above at zero
    cpy_samps = gene_cpy.get_samples(use_mtree)
    samp_counts = {samp: 0 for samp in (gene_tree.get_samples() - cpy_samps)}

    for subk in MuType(gene_tree.allkey()).leaves():
        for samp in MuType(subk).get_samples(gene_tree):
            if samp in samp_counts:
                samp_counts[samp] += 1

    for samp in train_samps:
        if samp not in samp_counts:
            samp_counts[samp] = 0

    # pandas does not accept a set as an indexer, so pass a list
    plt_df = pd.DataFrame({'AUC': auc_list[list(use_combs)]})
    plt_df['Muts'] = [
        np.mean([
            samp_counts[samp]
            for samp, phn in zip(train_samps, pheno_dict[mcomb]) if phn
        ]) for mcomb in plt_df.index
    ]

    plt_df['Size'] = [np.mean(pheno_dict[mcomb]) for mcomb in plt_df.index]
    plt_clr = choose_label_colour(use_gene)
    size_mult = plt_df.shape[0]**-0.23

    xlims = [
        args.auc_cutoff - (1 - args.auc_cutoff) / 47,
        1 + (1 - args.auc_cutoff) / 277
    ]
    ymin, ymax = [1, plt_df.Muts.max()]
    yrng = ymax - ymin
    ylims = [ymin - yrng / 23, ymax + yrng / 23]

    for mcomb, (auc_val, mut_val, size_val) in plt_df.iterrows():
        ax.scatter(auc_val,
                   mut_val,
                   s=3751 * size_mult * size_val,
                   c=[plt_clr],
                   alpha=0.31,
                   edgecolor='none')

    ax.grid(alpha=0.47, linewidth=0.9)
    ax.plot([1, 1], ylims, color='black', linewidth=1.7, alpha=0.83)
    ax.plot(xlims, [1, 1], color='black', linewidth=1.3, alpha=0.83)

    ax.xaxis.set_major_locator(plt.MaxNLocator(5, steps=[1, 2, 5]))
    ax.yaxis.set_major_locator(plt.MaxNLocator(7, steps=[1, 2, 5]))
    ax.set_xlim(xlims)
    ax.set_ylim(ylims)

    ax.set_xlabel("Subgrouping Classification Accuracy",
                  size=23,
                  weight='bold')
    ax.set_ylabel("Average # of {} Point Mutations"
                  "\nper Subgrouping Sample".format(use_gene),
                  size=23,
                  weight='bold')

    ax.text(0.97,
            0.07,
            get_cohort_label(data_tag.split('__')[1]),
            size=25,
            style='italic',
            ha='right',
            va='bottom',
            transform=ax.transAxes)

    plt.savefig(os.path.join(
        plot_dir, data_tag, use_gene,
        "{}_{}-overlap-aucs_{}.svg".format(args.ex_lbl, siml_metric,
                                           args.classif)),
                bbox_inches='tight',
                format='svg')

    plt.close()
Exemplo n.º 4
0
def main():
    """Enumerate mutation-threshold subgroupings for the threshold experiment.

    Reads the bootstrapped AUCs of completed ``subgrouping_test`` runs for
    the given cohort and classifier. Genes are selected in two steps:

    * Point: the gene's best non-base subgrouping has a mean AUC above 0.65.
    * Gain/Loss: for those genes, a subgrouping overlapping copy gains
      (losses) beats the base point-mutation type in more than 75% of
      bootstrap comparisons.

    For each selected gene, ``MutThresh`` subgroupings are built over the
    VAF, PolyPhen, SIFT and depth annotations of its point mutations. These
    are built on its base point type and any selected gain/loss extensions.

    The following are written under ``<out_dir>/setup``:

    * the cohort dataset;
    * the subgrouping list and count;
    * a copy of each other Firehose cohort's dataset;
    * the features shared by all of these cohorts.

    Raises:
        ValueError: if no completed ``subgrouping_test`` output using the
            ``Consequence__Exon`` mutation levels can be found.
    """
    parser = argparse.ArgumentParser(
        'setup_threshold',
        description="Load datasets and enumerate subgroupings to be tested.")

    parser.add_argument('cohort', type=str, help="a tumour cohort")
    parser.add_argument('classif', type=str, help="a mutation classifier")
    parser.add_argument('out_dir', type=str)
    parser.add_argument('test_dir', type=str)

    args = parser.parse_args()
    use_coh = args.cohort.split('_')[0]
    use_source = choose_source(use_coh)

    base_path = os.path.join(
        args.out_dir.split('subgrouping_threshold')[0],
        'subgrouping_threshold')
    coh_dir = os.path.join(base_path, 'setup')
    out_path = os.path.join(args.out_dir, 'setup')

    # find all the subvariant enumeration experiments that have run to
    # completion using the given combination of cohort and mutation classifier
    test_outs = Path(os.path.join(args.test_dir, 'subgrouping_test')).glob(
        os.path.join("{}__{}__samps-*".format(use_source, args.cohort),
                     "out-trnsf__*__{}.p.gz".format(args.classif)))

    # parse the enumeration experiment output files to find the minimum sample
    # occurence threshold used for each mutation annotation level tested
    out_datas = [Path(out_file).parts[-2:] for out_file in test_outs]
    out_df = pd.DataFrame([{
        'Samps':
        int(out_data[0].split('__samps-')[1]),
        'Levels':
        '__'.join(out_data[1].split('out-trnsf__')[1].split('__')[:-1])
    } for out_data in out_datas])

    # an empty table has no `Levels` column, so check for it explicitly
    if out_df.empty or 'Consequence__Exon' not in set(out_df.Levels):
        raise ValueError("Cannot infer subvariant behaviour until the "
                         "`subgrouping_test` experiment is run "
                         "with mutation levels `Consequence__Exon` "
                         "which tests genes' base mutations!")

    # load bootstrapped AUCs for enumerated subgrouping mutations
    conf_dict = dict()
    for lvls, ctf in out_df.groupby('Levels')['Samps']:
        conf_fl = os.path.join(
            args.test_dir, 'subgrouping_test',
            "{}__{}__samps-{}".format(use_source, args.cohort, ctf.values[0]),
            "out-conf__{}__{}.p.gz".format(lvls, args.classif))

        with bz2.BZ2File(conf_fl, 'r') as f:
            conf_dict[lvls] = pickle.load(f)

    conf_vals = pd.concat(conf_dict.values())
    conf_vals = conf_vals[[
        not isinstance(mtype, RandomType) for mtype in conf_vals.index
    ]]

    test_genes = {'Point': set(), 'Gain': set(), 'Loss': set()}
    for gene, conf_vec in conf_vals.groupby(
            lambda mtype: tuple(mtype.label_iter())[0]):
        if len(conf_vec) <= 1:
            continue

        auc_vec = conf_vec.apply(np.mean)
        base_mtype = MuType({('Gene', gene): pnt_mtype})
        base_indx = auc_vec.index.get_loc(base_mtype)

        # every subgrouping except the gene's base point-mutation type;
        # Series.append was removed in pandas 2.0, hence pd.concat
        sub_aucs = pd.concat([auc_vec.iloc[:base_indx],
                              auc_vec.iloc[(base_indx + 1):]])
        best_subtype = sub_aucs.idxmax()

        if auc_vec[best_subtype] <= 0.65:
            continue
        test_genes['Point'] |= {gene}

        for test_mtype in sub_aucs.index:
            mtype_sub = tuple(test_mtype.subtype_iter())[0][1]

            if ((gene not in test_genes['Gain'])
                    and not (mtype_sub & dup_mtype).is_empty()):
                test_indx = 'Gain'
            elif ((gene not in test_genes['Loss'])
                  and not (mtype_sub & loss_mtype).is_empty()):
                test_indx = 'Loss'
            else:
                test_indx = None

            if test_indx is not None:
                # fraction of bootstrap pairs in which this subgrouping
                # outperforms the gene's base point-mutation type
                conf_sc = np.greater.outer(
                    conf_vec[test_mtype], conf_vec[base_mtype]).mean()

                if conf_sc > 0.75:
                    test_genes[test_indx] |= {gene}

    use_genes = list(reduce(or_, test_genes.values()))
    use_mtypes = set()
    use_ctf = int(out_df.Samps.min())
    mtree_k = ('Gene', 'Scale', 'Copy')
    use_lfs = ['ref_count', 'alt_count', 'PolyPhen', 'SIFT', 'depth']

    cdata = get_cohort_data(args.cohort,
                            use_source, [mtree_k],
                            vep_cache_dir,
                            out_path,
                            use_genes,
                            leaf_annot=use_lfs)
    with bz2.BZ2File(os.path.join(out_path, "cohort-data.p.gz"), 'w') as f:
        pickle.dump(cdata, f, protocol=-1)

    # upper bound (exclusive) that leaves at least `use_ctf` wild-type samples
    max_size = len(cdata.get_samples()) - use_ctf + 1

    for gene, mtree in cdata.mtrees[mtree_k]:
        base_mtypes = {MuType({('Gene', gene): pnt_mtype})}

        if gene in test_genes['Gain']:
            base_mtypes |= {MuType({('Gene', gene): pnt_mtype | dup_mtype})}
        if gene in test_genes['Loss']:
            base_mtypes |= {MuType({('Gene', gene): pnt_mtype | loss_mtype})}

        # candidate thresholds depend only on the gene's point mutations, so
        # find them once instead of once per base subgrouping
        thresh_vals = {
            'VAF': {
                max(alt_cnt / (alt_cnt + ref_cnt) for alt_cnt, ref_cnt in
                    zip(vals['alt_count'], vals['ref_count']))
                for vals in pnt_mtype.get_leaf_annot(
                    mtree, ['ref_count', 'alt_count']).values()
            }
        }
        for lf_annt in ['PolyPhen', 'SIFT', 'depth']:
            thresh_vals[lf_annt] = {
                annt_val
                for annt_val in (
                    max(vals[lf_annt]) for vals in
                    pnt_mtype.get_leaf_annot(mtree, [lf_annt]).values())
                if annt_val > 0
            }

        for base_mtype in base_mtypes:
            base_size = len(base_mtype.get_samples(cdata.mtrees[mtree_k]))

            gene_mtypes = {
                MutThresh(lf_annt, annt_val, base_mtype)
                for lf_annt, annt_vals in thresh_vals.items()
                for annt_val in annt_vals
            }

            use_mtypes |= {
                mtype
                for mtype in gene_mtypes
                if (use_ctf <= len(mtype.get_samples(cdata.mtrees[mtree_k]))
                    < min(base_size, max_size))
            }

    with open(os.path.join(out_path, "muts-list.p"), 'wb') as f:
        pickle.dump(sorted(use_mtypes), f, protocol=-1)
    with open(os.path.join(out_path, "muts-count.txt"), 'w') as fl:
        fl.write(str(len(use_mtypes)))

    # get list of available cohorts for this source of expression data
    coh_list = list_cohorts('Firehose',
                            expr_dir=expr_sources['Firehose'],
                            copy_dir=expr_sources['Firehose'])
    coh_list -= {args.cohort}
    use_feats = set(cdata.get_features())
    random.seed()

    # random.sample() rejects sets since Python 3.11, so sort into a sequence
    for coh in random.sample(sorted(coh_list), k=len(coh_list)):
        coh_tag = "cohort-data__{}__{}.p".format('Firehose', coh)
        coh_path = os.path.join(coh_dir, coh_tag)

        trnsf_cdata = load_cohort(coh,
                                  'Firehose', [mtree_k],
                                  vep_cache_dir,
                                  coh_path,
                                  out_path,
                                  use_genes,
                                  leaf_annot=use_lfs)
        use_feats &= set(trnsf_cdata.get_features())

        with open(coh_path, 'wb') as f:
            pickle.dump(trnsf_cdata, f, protocol=-1)

        subprocess.run(
            ['cp', coh_path, os.path.join(out_path, coh_tag)], check=True)

    with open(os.path.join(out_path, "feat-list.p"), 'wb') as f:
        pickle.dump(use_feats, f, protocol=-1)
Exemplo n.º 5
0
def main():
    """Tune and run a classifier for recurrent gene mutations in LAML.

    Genes flagged 'Yes' in more than one of four cancer gene lists are
    loaded from TCGA-LAML. The cohort's patient expression comes from
    BeatAML.

    Mutation types with at least ``min_size`` training samples are modelled:

    * homozygous deletions and gains;
    * pooled losses and gains;
    * all point mutations;
    * point-mutation branches.

    For each type, a ``StanPipe`` classifier is tuned with the patient
    samples excluded. Scores are then inferred with the patient samples
    forced into the test set. Genes on the mutated gene's chromosome are
    excluded from the features. The tuned parameters and inferred scores are
    pickled to ``<base_dir>/output/gene_models/cv-<cv_id>.p``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        '--tune_splits', type=int, default=4,
        help='how many training cohort splits to use for tuning'
        )
    parser.add_argument(
        '--test_count', type=int, default=12,
        help='how many hyper-parameter values to test in each tuning split'
        )

    parser.add_argument(
        '--infer_splits', type=int, default=12,
        help='how many cohort splits to use for inference bootstrapping'
        )
    parser.add_argument(
        '--infer_folds', type=int, default=4,
        help=('how many parts to split the cohort into in each inference '
              'cross-validation run')
        )

    parser.add_argument(
        '--parallel_jobs', type=int, default=12,
        help='how many parallel CPUs to allocate the tuning tests across'
        )

    parser.add_argument('--cv_id', type=int, default=0)
    parser.add_argument('--verbose', '-v', action='store_true',
                        help='turns on diagnostic messages')

    args = parser.parse_args()
    out_dir = os.path.join(base_dir, 'output', 'gene_models')
    os.makedirs(out_dir, exist_ok=True)
    out_file = os.path.join(out_dir,
                            'cv-{}.p'.format(args.cv_id))

    # log into Synapse using locally stored credentials
    syn = synapseclient.Synapse()
    syn.cache.cache_root_dir = syn_root
    syn.login()

    mut_clf = StanPipe()
    gene_df = pd.read_csv(gene_list, sep='\t', skiprows=1, index_col=0)

    # keep genes flagged by at least two of the four cancer gene lists
    use_genes = gene_df.index[
        (gene_df.loc[:, ['Vogelstein', 'Sanger CGC',
                         'Foundation One', 'MSK-IMPACT']]
         == 'Yes').sum(axis=1) > 1
        ]

    cdata = PatientMutationCohort(
        patient_expr=load_beat_expression(beataml_expr), patient_muts=None,
        tcga_cohort='LAML', mut_genes=use_genes, mut_levels=['Gene', 'Form'],
        expr_source='toil', var_source='mc3', copy_source='Firehose',
        annot_file=annot_file, expr_dir=toil_dir, copy_dir=copy_dir,
        cv_seed=(args.cv_id * 59) + 121, cv_prop=1.0,
        collapse_txs=False, syn=syn
        )

    # minimum number of training samples a mutation type needs to be modelled
    min_size = 15

    # copy number types pooling homozygous and heterozygous events
    loss_mtype = MuType({('Copy', ('HomDel', 'HetDel')): None})
    gain_mtype = MuType({('Copy', ('HomGain', 'HetGain')): None})

    use_mtypes = set()
    for gene, muts in cdata.train_mut:
        if len(muts) < min_size:
            continue

        if 'Copy' in dict(muts) and len(muts['Copy']) >= min_size:
            copy_lvls = dict(muts['Copy'])

            for hom_lbl in ['HomDel', 'HomGain']:
                if (hom_lbl in copy_lvls
                        and len(muts['Copy'][hom_lbl]) >= min_size):
                    use_mtypes |= {MuType({('Gene', gene): {
                        ('Scale', 'Copy'): {('Copy', hom_lbl): None}}})}

            for het_lbl, cna_mtype in [('HetDel', loss_mtype),
                                       ('HetGain', gain_mtype)]:
                if (het_lbl in copy_lvls and len(
                        cna_mtype.get_samples(muts['Copy'])) >= min_size):
                    use_mtypes |= {MuType({('Gene', gene): {
                        ('Scale', 'Copy'): cna_mtype}})}

        if 'Point' in dict(muts) and len(muts['Point']) >= min_size:
            use_mtypes |= {MuType({('Gene', gene): {
                ('Scale', 'Point'): None}})}

            use_mtypes |= {
                MuType({('Gene', gene): {('Scale', 'Point'): mtype}})
                for mtype in muts['Point'].branchtypes(min_size=min_size)
                }

    tuned_params = {mtype: None for mtype in use_mtypes}
    infer_mats = {mtype: None for mtype in use_mtypes}

    for mtype in use_mtypes:
        # exclude genes on the same chromosome as the mutated gene
        mut_gene = mtype.subtype_list()[0][0]
        mut_chr = cdata.gene_annot[mut_gene]['chr']
        ex_genes = {gene for gene, annot in cdata.gene_annot.items()
                    if annot['chr'] == mut_chr}

        mut_clf.tune_coh(
            cdata, mtype,
            exclude_genes=ex_genes, exclude_samps=cdata.patient_samps,
            tune_splits=args.tune_splits, test_count=args.test_count,
            parallel_jobs=args.parallel_jobs
            )

        # the tuned classifier is a diagnostic, shown only when asked for
        if args.verbose:
            print(mut_clf)

        clf_params = mut_clf.get_params()
        tuned_params[mtype] = {par: clf_params[par]
                               for par, _ in StanPipe.tune_priors}

        infer_mats[mtype] = mut_clf.infer_coh(
            cdata, mtype,
            force_test_samps=cdata.patient_samps, exclude_genes=ex_genes,
            infer_splits=args.infer_splits, infer_folds=args.infer_folds
            )

    with open(out_file, 'wb') as out_f:
        pickle.dump({'Infer': infer_mats, 'Tune': tuned_params}, out_f)
Exemplo n.º 6
0
def main():
    """Runs the experiment.

    This task handles every ``task_count``-th ``(cohort, mtype)`` pair listed
    in ``<use_dir>/setup/muts-list.p``, starting at ``task_id``. For each
    pair, the classifier named by ``classif`` is tuned and used to infer
    scores twice:

    * 'All': tuned on the cohort's samples.
    * 'Iso': certain samples are hidden from tuning. These carry mutations
      of the same gene (other than ``ex_mtypes[ex_mtype]``) but not
      ``mtype``.

    In both approaches, samples outside the cohort are held out as test
    samples.

    The tuned parameters, the inferred scores and the last-fitted classifier
    are pickled to ``<use_dir>/output/out_task-<task_id>.p``.
    """

    parser = argparse.ArgumentParser(
        "Use a classifier to infer scores for mutations using a naive "
        "approach and an isolation approach, then test how well this "
        "classifier transfers across TCGA cohorts."
        )

    parser.add_argument('classif', type=str,
                        help="a classifier in HetMan.predict.classifiers")
    parser.add_argument('ex_mtype', type=str, choices=list(ex_mtypes.keys()))

    parser.add_argument('--use_dir', type=str, default=base_dir)
    parser.add_argument(
        '--task_count', type=int, default=10,
        help='how many parallel tasks the list of types to test is split into'
        )
    parser.add_argument('--task_id', type=int, default=0,
                        help='the subset of subtypes to assign to this task')

    args = parser.parse_args()
    setup_dir = os.path.join(args.use_dir, 'setup')
    out_dir = os.path.join(args.use_dir, 'output')

    # load expression and mutation data for the cohorts used
    with open(os.path.join(setup_dir, "cohort-data.p"), 'rb') as cdata_f:
        cdata = pickle.load(cdata_f)

    # load the list of mutations to create inferred scores for
    with open(os.path.join(setup_dir, "muts-list.p"), 'rb') as muts_f:
        mtype_list = pickle.load(muts_f)

    # load the classifier used to produce the mutation scores
    # NOTE(review): `eval` executes arbitrary code taken from the command
    # line; tolerable only because the operator supplies `classif` — a lookup
    # table of known classifier classes would be safer
    clf = eval(args.classif)
    clf.predict_proba = clf.calc_pred_labels
    mut_clf = clf()

    # only keep the (cohort, mutation) pairs assigned to this task
    task_mtypes = [test for i, test in enumerate(mtype_list)
                   if (i % args.task_count) == args.task_id]

    out_tune = {test: {smps: {par: None for par, _ in mut_clf.tune_priors}
                       for smps in ['All', 'Iso']}
                for test in task_mtypes}
    out_inf = {test: {'All': None, 'Iso': None} for test in task_mtypes}

    for cohort, mtype in task_mtypes:
        print("Isolating {} in cohort {} ...".format(mtype, cohort))

        # get the gene associated with this mutation, and the genes
        # appearing on the same chromosome for exclusion in classification
        use_gene = mtype.subtype_list()[0][0]
        use_chr = cdata.gene_annot[use_gene]['Chr']
        ex_genes = {gene for gene, annot in cdata.gene_annot.items()
                    if annot['Chr'] == use_chr}

        # get the set of mutations from the same gene that will be hidden
        # from the classifier in the isolation approach
        base_mtype = MuType(cdata.train_mut[use_gene].allkey())
        base_mtype -= ex_mtypes[args.ex_mtype]

        # get the samples that will be hidden in the isolation approach
        coh_samps = cdata.cohort_samps[cohort.split('_')[0]]
        base_samps = base_mtype.get_samples(cdata.train_mut[use_gene])
        base_samps &= coh_samps
        ex_samps = base_samps - mtype.get_samples(cdata.train_mut)

        # samples outside this cohort are always held out for testing
        out_samps = cdata.samples - coh_samps

        # tune the classifier on the default approach task
        mut_clf, cv_output = mut_clf.tune_coh(
            cdata, mtype, include_samps=coh_samps, exclude_genes=ex_genes,
            tune_splits=4, test_count=48, parallel_jobs=12
            )

        # get the tuned parameters for the default approach task
        clf_params = mut_clf.get_params()
        for par, _ in mut_clf.tune_priors:
            out_tune[(cohort, mtype)]['All'][par] = clf_params[par]

        out_inf[(cohort, mtype)]['All'] = mut_clf.infer_coh(
            cdata, mtype, exclude_genes=ex_genes,
            force_test_samps=out_samps,
            infer_splits=120, infer_folds=4, parallel_jobs=12
            )

        # tune the classifier on the isolation approach task
        mut_clf, cv_output = mut_clf.tune_coh(
            cdata, mtype, exclude_genes=ex_genes,
            exclude_samps=ex_samps | out_samps,
            tune_splits=4, test_count=48, parallel_jobs=12
            )

        # get the tuned parameters for the isolation approach task
        clf_params = mut_clf.get_params()
        for par, _ in mut_clf.tune_priors:
            out_tune[(cohort, mtype)]['Iso'][par] = clf_params[par]

        out_inf[(cohort, mtype)]['Iso'] = mut_clf.infer_coh(
            cdata, mtype, exclude_genes=ex_genes,
            force_test_samps=ex_samps | out_samps,
            infer_splits=120, infer_folds=4, parallel_jobs=12
            )

    # save the experiment results for this subtask to file
    os.makedirs(out_dir, exist_ok=True)
    out_fl = os.path.join(out_dir, "out_task-{}.p".format(args.task_id))
    with open(out_fl, 'wb') as out_f:
        pickle.dump({'Infer': out_inf, 'Tune': out_tune, 'Clf': mut_clf},
                    out_f)