def main():
    """Enumerate the subtypes to test in the paired-gene isolation experiment.

    Parses the command line, loads a `MutationCohort` for the given cohort
    and genes, and collects into one set:

    * for each gene, single subtypes and pairs of disjoint subtypes whose
      samples, after removing samples carrying any other mutation of the
      gene or any mutation of the other listed genes, number at least
      `--samp_cutoff`;
    * for each proper subset of the listed genes, the gene-level mutation
      type, when the samples mutated in the subset but not in the remaining
      genes number between `--samp_cutoff` and
      `len(cdata.samples) - --samp_cutoff`.

    The sorted subtypes are pickled to
    `mtypes_list__samps_<cutoff>__levels_<levels>.p` and their count is
    written to `mtypes_count__samps_<cutoff>__levels_<levels>.txt`, both
    under `<base_dir>/setup/<cohort>/<genes joined by '_'>`.
    """
    parser = argparse.ArgumentParser(
        "Set up the paired-gene subtype expression effect isolation "
        "experiment by enumerating the subtypes to be tested."
        )

    # create positional command line arguments
    parser.add_argument('cohort', type=str, help="which TCGA cohort to use")
    parser.add_argument('mut_levels', type=str,
                        help="the mutation property levels to consider")
    parser.add_argument('genes', type=str, nargs='+',
                        help="a list of mutated genes")

    # create optional command line arguments
    parser.add_argument('--samp_cutoff', type=int, default=20,
                        help='subtype sample frequency threshold')
    parser.add_argument('--verbose', '-v', action='store_true',
                        help='turns on diagnostic messages')

    # parse command line arguments, create directory where found subtypes
    # will be stored
    args = parser.parse_args()
    use_lvls = args.mut_levels.split('__')
    out_path = os.path.join(base_dir, 'setup', args.cohort,
                            '_'.join(args.genes))
    os.makedirs(out_path, exist_ok=True)

    # log into Synapse using locally stored credentials
    syn = synapseclient.Synapse()
    syn.cache.cache_root_dir = syn_root
    syn.login()

    cdata = MutationCohort(cohort=args.cohort, mut_genes=args.genes,
                           mut_levels=['Gene'] + use_lvls,
                           expr_source='Firehose', var_source='mc3',
                           copy_source='Firehose', annot_file=annot_file,
                           expr_dir=expr_dir, domain_dir=domain_dir,
                           cv_prop=1.0, syn=syn)

    iso_mtypes = set()
    for gene in args.genes:
        # samples carrying a mutation of any of the other listed genes
        # NOTE(review): `reduce` has no initial value, so this raises a
        # TypeError when only one gene is given -- confirm that at least two
        # genes are always expected
        other_samps = reduce(or_, [
            cdata.train_mut[other_gn].get_samples()
            for other_gn in set(args.genes) - {gene}
            ])

        if args.verbose:
            print("Looking for combinations of subtypes of mutations in gene "
                  "{} present in at least {} of the samples in TCGA cohort "
                  "{} at annotation levels {}.\n".format(
                      gene, args.samp_cutoff, args.cohort, use_lvls))

        pnt_mtypes = cdata.train_mut[gene]['Point'].find_unique_subtypes(
            max_types=500, max_combs=2, verbose=2,
            sub_levels=use_lvls, min_type_size=args.samp_cutoff
            )

        # filter out the subtypes that appear in too many samples for there to
        # be a wild-type class of sufficient size for classification
        pnt_mtypes = {
            MuType({('Scale', 'Point'): mtype}) for mtype in pnt_mtypes
            if (len(mtype.get_samples(cdata.train_mut[gene]['Point']))
                <= (len(cdata.samples) - args.samp_cutoff))
            }
        pnt_mtypes |= {MuType({('Scale', 'Point'): None})}

        cna_mtypes = cdata.train_mut[gene]['Copy'].branchtypes(
            min_size=args.samp_cutoff)
        cna_mtypes |= {MuType({('Copy', ('HetGain', 'HomGain')): None})}
        cna_mtypes |= {MuType({('Copy', ('HetDel', 'HomDel')): None})}

        cna_mtypes = {
            MuType({('Scale', 'Copy'): mtype}) for mtype in cna_mtypes
            if (len(mtype.get_samples(cdata.train_mut[gene]['Copy']))
                <= (len(cdata.samples) - args.samp_cutoff))
            }

        all_mtype = MuType(cdata.train_mut[gene].allkey())
        use_mtypes = pnt_mtypes | cna_mtypes

        # subtypes with enough samples once samples carrying any other
        # mutation of this gene or of the other genes are removed
        only_mtypes = {
            (MuType({('Gene', gene): mtype}), ) for mtype in use_mtypes
            if (len(mtype.get_samples(cdata.train_mut[gene])
                    - (all_mtype - mtype).get_samples(cdata.train_mut[gene])
                    - other_samps) >= args.samp_cutoff)
            }

        # pairs of non-overlapping subtypes with enough samples carrying
        # both of them and no other mutation of this gene or the other genes
        comb_mtypes = {
            (MuType({('Gene', gene): mtype1}),
             MuType({('Gene', gene): mtype2}))
            for mtype1, mtype2 in combn(use_mtypes, 2)
            if ((mtype1 & mtype2).is_empty()
                and (len((mtype1.get_samples(cdata.train_mut[gene])
                          & mtype2.get_samples(cdata.train_mut[gene]))
                         - (mtype1.get_samples(cdata.train_mut[gene])
                            ^ mtype2.get_samples(cdata.train_mut[gene]))
                         - (all_mtype - mtype1 - mtype2).get_samples(
                             cdata.train_mut[gene])
                         - other_samps) >= args.samp_cutoff))
            }

        iso_mtypes |= only_mtypes | comb_mtypes
        if args.verbose:
            print("\nFound {} exclusive sub-types and {} combination "
                  "sub-types to isolate!".format(
                      len(only_mtypes), len(comb_mtypes)))

    # gene-level mutation types for every proper subset of the listed genes
    for cur_genes in chain.from_iterable(
            combn(args.genes, r) for r in range(1, len(args.genes))):
        gene_mtype = MuType({('Gene', cur_genes): None})
        rest_mtype = MuType({
            ('Gene', tuple(set(args.genes) - set(cur_genes))): None})

        if (args.samp_cutoff
                <= len(gene_mtype.get_samples(cdata.train_mut)
                       - rest_mtype.get_samples(cdata.train_mut))
                <= (len(cdata.samples) - args.samp_cutoff)):
            iso_mtypes |= {(gene_mtype, )}

    if args.verbose:
        print("\nFound {} total sub-types to isolate!".format(
            len(iso_mtypes)))

    # save the list of found non-duplicate sub-types to file; the context
    # manager makes sure the pickle is flushed and the handle closed
    with open(os.path.join(out_path,
                           'mtypes_list__samps_{}__levels_{}.p'.format(
                               args.samp_cutoff, args.mut_levels)),
              'wb') as fl:
        pickle.dump(sorted(iso_mtypes), fl)

    with open(os.path.join(out_path,
                           'mtypes_count__samps_{}__levels_{}.txt'.format(
                               args.samp_cutoff, args.mut_levels)),
              'w') as fl:
        fl.write(str(len(iso_mtypes)))
def plot_overlap_divergence(pred_dfs, pheno_dicts, auc_lists,
                            cdata_dict, args, siml_metric):
    """Plot subgrouping AUC against a singleton/overlap similarity score.

    Draws two stacked scatter panels, the upper one for `ExMcomb`
    subgroupings with one mutation type and the lower one for those with
    more, and saves the figure as an SVG under `plot_dir`.

    Args:
        pred_dfs: mapping keyed by (source, cohort); each value is indexed
            as `.loc[subgrouping, train_samples]` to get task predictions.
        pheno_dicts: mapping keyed by (source, cohort) of mappings from
            subgrouping to a boolean sample array.
        auc_lists: mapping keyed by (source, cohort) of AUC collections
            supporting `.items()` and lookup by subgrouping (presumably
            pandas Series -- confirm against the caller).
        cdata_dict: mapping keyed by (source, cohort) of cohort objects
            providing `get_train_samples`, `mtrees` and `train_pheno`.
        args: namespace; `auc_cutoff`, `ex_lbl` and `classif` are read.
        siml_metric: key into `siml_fxs`, also used in the output file name.
    """
    fig, (sngl_ax, mult_ax) = plt.subplots(figsize=(12, 14), nrows=2)

    siml_dicts = {(src, coh): dict() for src, coh in auc_lists}
    # NOTE(review): `gn_dict` is filled below but never read in this function
    gn_dict = dict()

    # for each dataset, find the subgroupings meeting the minimum task AUC
    # that are exclusively defined and subsets of point mutations...
    for (src, coh), auc_list in auc_lists.items():
        # `.items()` rather than `.iteritems()`, which pandas 2.0 removed
        use_combs = remove_pheno_dups({
            mut for mut, auc_val in auc_list.items()
            if (isinstance(mut, ExMcomb) and auc_val >= args.auc_cutoff
                and get_mut_ex(mut) == args.ex_lbl
                and all(pnt_mtype.is_supertype(get_subtype(mtype))
                        for mtype in mut.mtypes))
            }, pheno_dicts[src, coh])

        # skip this dataset for plotting if we cannot find any such pairs
        if not use_combs:
            continue

        # get sample order used in the cohort and a breakdown of mutations
        # in which each individual mutation can be uniquely identified
        train_samps = cdata_dict[src, coh].get_train_samples()
        use_mtree = cdata_dict[src, coh].mtrees['Gene', 'Scale', 'Copy',
                                                'Exon', 'Position', 'HGVSp']
        use_genes = {get_label(mcomb) for mcomb in use_combs}
        cmp_phns = {gene: {'Sngl': None, 'Mult': None} for gene in use_genes}

        # get the samples carrying a single point mutation or multiple
        # mutations of each gene with at least one mutation in the cohort
        for gene in use_genes:
            gene_tree = use_mtree[gene]['Point']

            if args.ex_lbl == 'Iso':
                gene_cpy = MuType({('Gene', gene): copy_mtype})
            else:
                gene_cpy = MuType({('Gene', gene): deep_mtype})

            # count, for each point-mutated sample without the copy number
            # alterations above, how many mutation tree leaves it falls into
            cpy_samps = gene_cpy.get_samples(use_mtree)
            samp_counts = {samp: 0
                           for samp in (gene_tree.get_samples() - cpy_samps)}

            for subk in MuType(gene_tree.allkey()).leaves():
                for samp in MuType(subk).get_samples(gene_tree):
                    if samp in samp_counts:
                        samp_counts[samp] += 1

            for samp in train_samps:
                if samp not in samp_counts:
                    samp_counts[samp] = 0

            cmp_phns[gene]['Sngl'] = np.array(
                [samp_counts[samp] == 1 for samp in train_samps])
            cmp_phns[gene]['Mult'] = np.array(
                [samp_counts[samp] > 1 for samp in train_samps])

        all_mtypes = {
            gene: MuType({('Gene', gene): use_mtree[gene].allkey()})
            for gene in use_genes
            }

        if args.ex_lbl == 'IsoShal':
            for gene in use_genes:
                all_mtypes[gene] -= MuType({('Gene', gene): shal_mtype})

        all_phns = {
            gene: np.array(cdata_dict[src, coh].train_pheno(all_mtype))
            for gene, all_mtype in all_mtypes.items()
            }

        # for each subgrouping, find the subset of point mutations that
        # defines it, the gene it's associated with, and its task predictions
        for mcomb in use_combs:
            cur_gene = get_label(mcomb)
            use_preds = pred_dfs[src, coh].loc[mcomb, train_samps]

            # get the samples that carry any point mutation of this gene
            if (src, coh, cur_gene) not in gn_dict:
                gn_dict[src, coh, cur_gene] = np.array(
                    cdata_dict[src, coh].train_pheno(
                        MuType({('Gene', cur_gene): pnt_mtype}))
                    )

            # find the samples carrying one or multiple point mutations of
            # this gene not belonging to this subgrouping
            cmp_phn = ~pheno_dicts[src, coh][mcomb]

            if len(mcomb.mtypes) == 1:
                cmp_phn &= cmp_phns[cur_gene]['Mult']
            else:
                cmp_phn &= cmp_phns[cur_gene]['Sngl']

            if cmp_phn.sum() >= 1:
                siml_dicts[src, coh][mcomb] = siml_fxs[siml_metric](
                    use_preds.loc[~all_phns[cur_gene]],
                    use_preds.loc[pheno_dicts[src, coh][mcomb]],
                    use_preds.loc[cmp_phn]
                    )

    plt_df = pd.DataFrame(
        {'Siml': pd.DataFrame.from_records(siml_dicts).stack()})
    plt_df['AUC'] = [auc_lists[src, coh][mcomb]
                     for mcomb, (src, coh) in plt_df.index]

    gene_means = plt_df.groupby(
        lambda x: (get_label(x[0]), len(x[0].mtypes))).mean()
    clr_dict = {gene: choose_label_colour(gene)
                for gene, _ in gene_means.index}
    size_mult = plt_df.groupby(
        lambda x: len(x[0].mtypes)).Siml.count().max() ** -0.23

    xlims = [args.auc_cutoff - (1 - args.auc_cutoff) / 47,
             1 + (1 - args.auc_cutoff) / 277]

    ymin, ymax = plt_df.Siml.quantile(q=[0, 1])
    yrng = ymax - ymin
    ylims = [ymin - yrng / 23, ymax + yrng / 23]

    # label anchors for each panel: one entry per gene at its mean position
    plot_dicts = {
        mcomb_i: {(auc_val, siml_val): [0.0001, (gene, '')]
                  for (gene, mcomb_indx), (siml_val, auc_val)
                  in gene_means.iterrows()
                  if mcomb_indx == mcomb_i}
        for mcomb_i in [1, 2]
        }

    for (mcomb, (src, coh)), (siml_val, auc_val) in plt_df.iterrows():
        cur_gene = get_label(mcomb)
        plt_size = size_mult * np.mean(pheno_dicts[src, coh][mcomb])

        plot_dicts[(len(mcomb.mtypes) == 2) + 1][auc_val, siml_val] = [
            0.19 * plt_size, ('', '')]

        if len(mcomb.mtypes) == 1:
            use_ax = sngl_ax
        else:
            use_ax = mult_ax

        use_ax.scatter(auc_val, siml_val, s=3751 * plt_size,
                       c=[clr_dict[cur_gene]], alpha=0.25, edgecolor='none')

    for ax, mcomb_i in zip([sngl_ax, mult_ax], [1, 2]):
        ax.grid(alpha=0.47, linewidth=0.9)
        ax.plot([1, 1], ylims, color='black', linewidth=1.7, alpha=0.83)

        for yval in [0, 1]:
            # draw the horizontal reference line only when it falls inside
            # the plotted y-range; the earlier check against `xlims` meant
            # the y=0 line could never be drawn
            if ylims[0] < yval < ylims[1]:
                ax.plot(xlims, [yval, yval], color='black',
                        linewidth=1.11, linestyle='--', alpha=0.67)

                # reserve space along the line so labels do not overlap it
                for k in np.linspace(args.auc_cutoff, 0.99, 200):
                    if (k, yval) not in plot_dicts[mcomb_i]:
                        plot_dicts[mcomb_i][k, yval] = [1 / 703, ('', '')]

        line_dict = {k: {'c': clr_dict[v[1][0]]}
                     for k, v in plot_dicts[mcomb_i].items() if v[1][0]}
        font_dict = {k: {'c': v['c'], 'weight': 'bold'}
                     for k, v in line_dict.items()}

        place_scatter_labels(plot_dicts[mcomb_i], ax,
                             plt_lims=[xlims, ylims], line_dict=line_dict,
                             font_dict=font_dict, font_size=19)

        ax.xaxis.set_major_locator(plt.MaxNLocator(5, steps=[1, 2, 5]))
        ax.yaxis.set_major_locator(plt.MaxNLocator(7, steps=[1, 2, 5]))
        ax.set_xlim(xlims)
        ax.set_ylim(ylims)

    mult_ax.set_xlabel("Subgrouping Classification Accuracy",
                       size=21, weight='bold')
    sngl_ax.set_ylabel("Overlaps' Similarity to Singletons",
                       size=21, weight='bold')
    mult_ax.set_ylabel("Singletons' Similarity to Overlaps",
                       size=21, weight='bold')

    plt.savefig(
        os.path.join(plot_dir,
                     "{}_{}-overlap-divergence_{}.svg".format(
                         args.ex_lbl, siml_metric, args.classif)),
        bbox_inches='tight', format='svg'
        )

    plt.close()
def plot_overlap_aucs(pheno_dict, auc_list, cdata,
                      data_tag, args, siml_metric, use_gene):
    """Plot subgrouping AUC against mean point mutation count for one gene.

    Selects the single-mutation-type `ExMcomb` subgroupings of `use_gene`
    that meet `args.auc_cutoff`, and for each plots its AUC against the
    average number of mutation tree leaves its samples fall into. The figure
    is saved as an SVG under `plot_dir/<data_tag>/<use_gene>`.

    Args:
        pheno_dict: mapping from subgrouping to a boolean sample array
            ordered like `cdata.get_train_samples()`.
        auc_list: AUCs indexed by subgrouping, supporting `.items()` and
            lookup by a list of labels (presumably a pandas Series --
            confirm against the caller).
        cdata: cohort object providing `get_train_samples` and `mtrees`.
        data_tag: dataset label; the part after the first `'__'` is passed
            to `get_cohort_label`, and the tag names the output directory.
        args: namespace; `auc_cutoff`, `ex_lbl` and `classif` are read.
        siml_metric: only used in the output file name.
        use_gene: the gene whose subgroupings are plotted.
    """
    fig, ax = plt.subplots(figsize=(13, 7))

    # `.items()` rather than `.iteritems()`, which pandas 2.0 removed
    use_combs = remove_pheno_dups({
        mut for mut, auc_val in auc_list.items()
        if (isinstance(mut, ExMcomb) and auc_val >= args.auc_cutoff
            and get_mut_ex(mut) == args.ex_lbl
            and len(mut.mtypes) == 1 and get_label(mut) == use_gene
            and all(pnt_mtype.is_supertype(get_subtype(mtype))
                    for mtype in mut.mtypes))
        }, pheno_dict)

    train_samps = cdata.get_train_samples()
    use_mtree = cdata.mtrees['Gene', 'Scale', 'Copy',
                             'Exon', 'Position', 'HGVSp']
    gene_tree = use_mtree[use_gene]['Point']

    if args.ex_lbl == 'Iso':
        gene_cpy = MuType({('Gene', use_gene): copy_mtype})
    else:
        gene_cpy = MuType({('Gene', use_gene): deep_mtype})

    # count, for each point-mutated sample without the copy number
    # alterations above, how many mutation tree leaves it falls into
    cpy_samps = gene_cpy.get_samples(use_mtree)
    samp_counts = {samp: 0 for samp in (gene_tree.get_samples() - cpy_samps)}

    for subk in MuType(gene_tree.allkey()).leaves():
        for samp in MuType(subk).get_samples(gene_tree):
            if samp in samp_counts:
                samp_counts[samp] += 1

    for samp in train_samps:
        if samp not in samp_counts:
            samp_counts[samp] = 0

    # index with a list: pandas 2.0 no longer accepts a set as an indexer
    plt_df = pd.DataFrame({'AUC': auc_list[list(use_combs)]})
    plt_df['Muts'] = [
        np.mean([samp_counts[samp]
                 for samp, phn in zip(train_samps, pheno_dict[mcomb]) if phn])
        for mcomb in plt_df.index
        ]
    plt_df['Size'] = [np.mean(pheno_dict[mcomb]) for mcomb in plt_df.index]

    plt_clr = choose_label_colour(use_gene)
    size_mult = plt_df.shape[0] ** -0.23

    xlims = [args.auc_cutoff - (1 - args.auc_cutoff) / 47,
             1 + (1 - args.auc_cutoff) / 277]

    ymin, ymax = [1, plt_df.Muts.max()]
    yrng = ymax - ymin
    ylims = [ymin - yrng / 23, ymax + yrng / 23]

    for mcomb, (auc_val, mut_val, size_val) in plt_df.iterrows():
        ax.scatter(auc_val, mut_val, s=3751 * size_mult * size_val,
                   c=[plt_clr], alpha=0.31, edgecolor='none')

    ax.grid(alpha=0.47, linewidth=0.9)
    ax.plot([1, 1], ylims, color='black', linewidth=1.7, alpha=0.83)
    ax.plot(xlims, [1, 1], color='black', linewidth=1.3, alpha=0.83)

    ax.xaxis.set_major_locator(plt.MaxNLocator(5, steps=[1, 2, 5]))
    ax.yaxis.set_major_locator(plt.MaxNLocator(7, steps=[1, 2, 5]))
    ax.set_xlim(xlims)
    ax.set_ylim(ylims)

    ax.set_xlabel("Subgrouping Classification Accuracy",
                  size=23, weight='bold')
    ax.set_ylabel("Average # of {} Point Mutations"
                  "\nper Subgrouping Sample".format(use_gene),
                  size=23, weight='bold')

    ax.text(0.97, 0.07, get_cohort_label(data_tag.split('__')[1]),
            size=25, style='italic', ha='right', va='bottom',
            transform=ax.transAxes)

    plt.savefig(
        os.path.join(plot_dir, data_tag, use_gene,
                     "{}_{}-overlap-aucs_{}.svg".format(
                         args.ex_lbl, siml_metric, args.classif)),
        bbox_inches='tight', format='svg'
        )

    plt.close()
def main():
    """Load datasets and enumerate the thresholded subgroupings to test.

    Reads the outputs of completed `subgrouping_test` runs for the given
    cohort and classifier, picks the genes whose subgroupings performed well
    enough there, builds `MutThresh` subgroupings over the VAF, PolyPhen,
    SIFT and depth annotations of those genes' point mutations, and writes
    to `<out_dir>/setup`:

    * `cohort-data.p.gz`: the cohort object (bz2-compressed pickle);
    * `muts-list.p` and `muts-count.txt`: the sorted subgroupings and their
      count;
    * `cohort-data__Firehose__<cohort>.p`: one pickle per other Firehose
      cohort, also saved in the shared `subgrouping_threshold/setup`
      directory;
    * `feat-list.p`: the features common to all loaded cohorts.

    Raises:
        ValueError: if no completed enumeration run tested the
            `Consequence__Exon` mutation levels.
    """
    parser = argparse.ArgumentParser(
        'setup_threshold',
        description="Load datasets and enumerate subgroupings to be tested."
        )

    parser.add_argument('cohort', type=str, help="a tumour cohort")
    parser.add_argument('classif', type=str, help="a mutation classifier")
    parser.add_argument('out_dir', type=str, )
    parser.add_argument('test_dir', type=str, )

    args = parser.parse_args()
    use_coh = args.cohort.split('_')[0]
    use_source = choose_source(use_coh)

    base_path = os.path.join(
        args.out_dir.split('subgrouping_threshold')[0],
        'subgrouping_threshold'
        )
    coh_dir = os.path.join(base_path, 'setup')
    out_path = os.path.join(args.out_dir, 'setup')

    # find all the subvariant enumeration experiments that have run to
    # completion using the given combination of cohort and mutation classifier
    test_outs = Path(os.path.join(args.test_dir, 'subgrouping_test')).glob(
        os.path.join("{}__{}__samps-*".format(use_source, args.cohort),
                     "out-trnsf__*__{}.p.gz".format(args.classif))
        )

    # parse the enumeration experiment output files to find the minimum sample
    # occurrence threshold used for each mutation annotation level tested
    out_datas = [Path(out_file).parts[-2:] for out_file in test_outs]
    out_df = pd.DataFrame([
        {'Samps': int(out_data[0].split('__samps-')[1]),
         'Levels': '__'.join(
             out_data[1].split('out-trnsf__')[1].split('__')[:-1])}
        for out_data in out_datas
        ])

    # an empty search result has no `Levels` column at all, so check for it
    # explicitly to fail with the same informative error
    if not out_datas or 'Consequence__Exon' not in set(out_df.Levels):
        raise ValueError("Cannot infer subvariant behaviour until the "
                         "`subgrouping_test` experiment is run "
                         "with mutation levels `Consequence__Exon` "
                         "which tests genes' base mutations!")

    # load bootstrapped AUCs for enumerated subgrouping mutations
    conf_dict = dict()
    for lvls, ctf in out_df.groupby('Levels')['Samps']:
        conf_fl = os.path.join(
            args.test_dir, 'subgrouping_test',
            "{}__{}__samps-{}".format(use_source, args.cohort,
                                      ctf.values[0]),
            "out-conf__{}__{}.p.gz".format(lvls, args.classif)
            )

        with bz2.BZ2File(conf_fl, 'r') as f:
            conf_dict[lvls] = pickle.load(f)

    conf_vals = pd.concat(conf_dict.values())
    conf_vals = conf_vals[[not isinstance(mtype, RandomType)
                           for mtype in conf_vals.index]]

    test_genes = {'Point': set(), 'Gain': set(), 'Loss': set()}
    for gene, conf_vec in conf_vals.groupby(
            lambda mtype: tuple(mtype.label_iter())[0]):
        if len(conf_vec) > 1:
            auc_vec = conf_vec.apply(np.mean)
            base_mtype = MuType({('Gene', gene): pnt_mtype})

            # AUCs of every subgrouping of this gene except the base one;
            # `pd.concat` replaces `Series.append`, removed in pandas 2.0
            base_indx = auc_vec.index.get_loc(base_mtype)
            sub_aucs = pd.concat([auc_vec[:base_indx],
                                  auc_vec[(base_indx + 1):]])
            best_subtype = sub_aucs.idxmax()

            if auc_vec[best_subtype] > 0.65:
                test_genes['Point'] |= {gene}

            # NOTE(review): the source this was recovered from had lost its
            # indentation; this loop is assumed to run for every gene rather
            # than only those passing the 0.65 check above -- confirm
            for test_mtype in sub_aucs.index:
                mtype_sub = tuple(test_mtype.subtype_iter())[0][1]

                if ((gene not in test_genes['Gain'])
                        and not (mtype_sub & dup_mtype).is_empty()):
                    test_indx = 'Gain'
                elif ((gene not in test_genes['Loss'])
                        and not (mtype_sub & loss_mtype).is_empty()):
                    test_indx = 'Loss'
                else:
                    test_indx = None

                if test_indx is not None:
                    # fraction of bootstrap pairs in which the subgrouping
                    # outperforms the gene's base point mutations
                    conf_sc = np.greater.outer(
                        conf_vec[test_mtype], conf_vec[base_mtype]).mean()

                    if conf_sc > 0.75:
                        test_genes[test_indx] |= {gene}

    use_genes = list(reduce(or_, test_genes.values()))
    use_mtypes = set()
    use_ctf = int(out_df.Samps.min())

    mtree_k = ('Gene', 'Scale', 'Copy')
    use_lfs = ['ref_count', 'alt_count', 'PolyPhen', 'SIFT', 'depth']

    cdata = get_cohort_data(args.cohort, use_source, [mtree_k],
                            vep_cache_dir, out_path, use_genes,
                            leaf_annot=use_lfs)
    with bz2.BZ2File(os.path.join(out_path, "cohort-data.p.gz"), 'w') as f:
        pickle.dump(cdata, f, protocol=-1)

    for gene, mtree in cdata.mtrees[mtree_k]:
        base_mtypes = {MuType({('Gene', gene): pnt_mtype})}

        if gene in test_genes['Gain']:
            base_mtypes |= {MuType({('Gene', gene): pnt_mtype | dup_mtype})}
        if gene in test_genes['Loss']:
            base_mtypes |= {MuType({('Gene', gene): pnt_mtype | loss_mtype})}

        for base_mtype in base_mtypes:
            base_size = len(base_mtype.get_samples(cdata.mtrees[mtree_k]))

            # one threshold per distinct per-entry maximum allele frequency
            gene_mtypes = {
                MutThresh('VAF', vaf_val, base_mtype)
                for vaf_val in set(
                    max(alt_cnt / (alt_cnt + ref_cnt)
                        for alt_cnt, ref_cnt in zip(vals['alt_count'],
                                                    vals['ref_count']))
                    for vals in pnt_mtype.get_leaf_annot(
                        mtree, ['ref_count', 'alt_count']).values()
                    )
                }

            for lf_annt in ['PolyPhen', 'SIFT', 'depth']:
                gene_mtypes |= {
                    MutThresh(lf_annt, annt_val, base_mtype)
                    for annt_val in set(
                        max(vals[lf_annt])
                        for vals in pnt_mtype.get_leaf_annot(
                            mtree, [lf_annt]).values()
                        )
                    if annt_val > 0
                    }

            # keep thresholds with at least `use_ctf` samples that are
            # smaller than both the base mutation and the cohort size less
            # `use_ctf` plus one
            use_mtypes |= {
                mtype for mtype in gene_mtypes
                if (use_ctf <= len(mtype.get_samples(cdata.mtrees[mtree_k]))
                    < min(base_size, len(cdata.get_samples()) - use_ctf + 1))
                }

    with open(os.path.join(out_path, "muts-list.p"), 'wb') as f:
        pickle.dump(sorted(use_mtypes), f, protocol=-1)
    with open(os.path.join(out_path, "muts-count.txt"), 'w') as fl:
        fl.write(str(len(use_mtypes)))

    # get list of available cohorts for this source of expression data
    coh_list = list_cohorts('Firehose', expr_dir=expr_sources['Firehose'],
                            copy_dir=expr_sources['Firehose'])
    coh_list -= {args.cohort}
    use_feats = set(cdata.get_features())

    # visit the cohorts in a random order; `random.sample` needs a sequence,
    # as sampling directly from a set fails on Python 3.11 and later
    random.seed()
    for coh in random.sample(sorted(coh_list), k=len(coh_list)):
        coh_tag = "cohort-data__{}__{}.p".format('Firehose', coh)
        coh_path = os.path.join(coh_dir, coh_tag)

        trnsf_cdata = load_cohort(coh, 'Firehose', [mtree_k], vep_cache_dir,
                                  coh_path, out_path, use_genes,
                                  leaf_annot=use_lfs)
        use_feats &= set(trnsf_cdata.get_features())

        with open(coh_path, 'wb') as f:
            pickle.dump(trnsf_cdata, f, protocol=-1)

        subprocess.run(['cp', coh_path, os.path.join(out_path, coh_tag)],
                       check=True)

    with open(os.path.join(out_path, "feat-list.p"), 'wb') as f:
        pickle.dump(use_feats, f, protocol=-1)
def main():
    """Tune and run a mutation classifier for each frequently mutated gene.

    Builds a `PatientMutationCohort` over the genes flagged 'Yes' in more
    than one of the 'Vogelstein', 'Sanger CGC', 'Foundation One' and
    'MSK-IMPACT' columns of the gene list, collects the copy number and
    point mutation types with at least 15 entries, and for each of them
    tunes a `StanPipe` classifier and infers mutation scores. The tuned
    parameters and inferred scores are pickled to
    `<base_dir>/output/gene_models/cv-<cv_id>.p`.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        '--tune_splits', type=int, default=4,
        help='how many training cohort splits to use for tuning'
        )
    parser.add_argument(
        '--test_count', type=int, default=12,
        help='how many hyper-parameter values to test in each tuning split'
        )

    parser.add_argument(
        '--infer_splits', type=int, default=12,
        help='how many cohort splits to use for inference bootstrapping'
        )
    parser.add_argument(
        '--infer_folds', type=int, default=4,
        help=('how many parts to split the cohort into in each inference '
              'cross-validation run')
        )

    parser.add_argument(
        '--parallel_jobs', type=int, default=12,
        help='how many parallel CPUs to allocate the tuning tests across'
        )

    parser.add_argument('--cv_id', type=int, default=0)
    parser.add_argument('--verbose', '-v', action='store_true',
                        help='turns on diagnostic messages')

    args = parser.parse_args()
    out_dir = os.path.join(base_dir, 'output', 'gene_models')
    os.makedirs(out_dir, exist_ok=True)
    out_file = os.path.join(out_dir, 'cv-{}.p'.format(args.cv_id))

    # log into Synapse using locally stored credentials
    syn = synapseclient.Synapse()
    syn.cache.cache_root_dir = syn_root
    syn.login()

    mut_clf = StanPipe()
    gene_df = pd.read_csv(gene_list, sep='\t', skiprows=1, index_col=0)

    # use the genes flagged in more than one of the four listed columns
    use_genes = gene_df.index[
        (gene_df.loc[:, ['Vogelstein', 'Sanger CGC',
                         'Foundation One', 'MSK-IMPACT']]
         == 'Yes').sum(axis=1) > 1
        ]

    cdata = PatientMutationCohort(
        patient_expr=load_beat_expression(beataml_expr), patient_muts=None,
        tcga_cohort='LAML', mut_genes=use_genes, mut_levels=['Gene', 'Form'],
        expr_source='toil', var_source='mc3', copy_source='Firehose',
        annot_file=annot_file, expr_dir=toil_dir, copy_dir=copy_dir,
        cv_seed=(args.cv_id * 59) + 121, cv_prop=1.0,
        collapse_txs=False, syn=syn
        )

    # collect the mutation types of each gene that meet the size cutoff of 15
    use_mtypes = set()
    for gene, muts in cdata.train_mut:
        if len(muts) >= 15:

            if 'Copy' in dict(muts) and len(muts['Copy']) >= 15:
                if 'HomDel' in dict(muts['Copy']):
                    if len(muts['Copy']['HomDel']) >= 15:
                        use_mtypes |= {MuType({('Gene', gene): {
                            ('Scale', 'Copy'): {('Copy', 'HomDel'): None}}})}

                if 'HomGain' in dict(muts['Copy']):
                    if len(muts['Copy']['HomGain']) >= 15:
                        use_mtypes |= {MuType({('Gene', gene): {
                            ('Scale', 'Copy'): {('Copy', 'HomGain'): None}}})}

                # homozygous and heterozygous losses taken together
                loss_mtype = MuType({('Copy', ('HomDel', 'HetDel')): None})
                if 'HetDel' in dict(muts['Copy']):
                    if len(loss_mtype.get_samples(muts['Copy'])) >= 15:
                        use_mtypes |= {MuType({('Gene', gene): {
                            ('Scale', 'Copy'): loss_mtype}})}

                # homozygous and heterozygous gains taken together
                gain_mtype = MuType({('Copy', ('HomGain', 'HetGain')): None})
                if 'HetGain' in dict(muts['Copy']):
                    if len(gain_mtype.get_samples(muts['Copy'])) >= 15:
                        use_mtypes |= {MuType({('Gene', gene): {
                            ('Scale', 'Copy'): gain_mtype}})}

            if 'Point' in dict(muts) and len(muts['Point']) >= 15:
                use_mtypes |= {MuType({('Gene', gene): {
                    ('Scale', 'Point'): None}})}

                use_mtypes |= {
                    MuType({('Gene', gene): {('Scale', 'Point'): mtype}})
                    for mtype in muts['Point'].branchtypes(min_size=15)
                    }

    tuned_params = {mtype: None for mtype in use_mtypes}
    infer_mats = {mtype: None for mtype in use_mtypes}

    for mtype in use_mtypes:
        # exclude every gene annotated on the same chromosome as the mutated
        # gene from the classifier's features
        mut_gene = mtype.subtype_list()[0][0]
        ex_genes = {gene for gene, annot in cdata.gene_annot.items()
                    if annot['chr'] == cdata.gene_annot[mut_gene]['chr']}

        mut_clf.tune_coh(
            cdata, mtype,
            exclude_genes=ex_genes, exclude_samps=cdata.patient_samps,
            tune_splits=args.tune_splits, test_count=args.test_count,
            parallel_jobs=args.parallel_jobs
            )
        print(mut_clf)

        clf_params = mut_clf.get_params()
        tuned_params[mtype] = {par: clf_params[par]
                               for par, _ in StanPipe.tune_priors}

        infer_mats[mtype] = mut_clf.infer_coh(
            cdata, mtype,
            force_test_samps=cdata.patient_samps, exclude_genes=ex_genes,
            infer_splits=args.infer_splits, infer_folds=args.infer_folds
            )

    # the context manager makes sure the pickle is flushed and the handle
    # closed before the script exits
    with open(out_file, 'wb') as out_f:
        pickle.dump({'Infer': infer_mats, 'Tune': tuned_params}, out_f)
def main():
    """Runs the experiment.

    Loads the pickled cohort data and the list of (cohort, mutation) pairs
    from `<use_dir>/setup`, and for every pair assigned to this task tunes
    the chosen classifier and infers mutation scores twice: once with the
    default approach ('All') and once with the isolation approach ('Iso'),
    in which samples carrying other mutations of the same gene are hidden
    from training. The tuned parameters, inferred scores and last-used
    classifier are pickled to `<use_dir>/output/out_task-<task_id>.p`.
    """
    parser = argparse.ArgumentParser(
        "Use a classifier to infer scores for mutations using a naive "
        "approach and an isolation approach, then test how well this "
        "classifier transfers across TCGA cohorts."
        )

    parser.add_argument('classif', type=str,
                        help="a classifier in HetMan.predict.classifiers")
    parser.add_argument('ex_mtype', type=str, choices=list(ex_mtypes.keys()))
    parser.add_argument('--use_dir', type=str, default=base_dir)

    parser.add_argument(
        '--task_count', type=int, default=10,
        help='how many parallel tasks the list of types to test is split into'
        )
    parser.add_argument('--task_id', type=int, default=0,
                        help='the subset of subtypes to assign to this task')

    args = parser.parse_args()
    setup_dir = os.path.join(args.use_dir, 'setup')

    # make sure the output directory exists before starting the tuning runs
    # so the final write cannot fail on a missing folder
    out_dir = os.path.join(args.use_dir, 'output')
    os.makedirs(out_dir, exist_ok=True)

    # load expression and mutation data for the cohorts used
    with open(os.path.join(setup_dir, "cohort-data.p"), 'rb') as cdata_f:
        cdata = pickle.load(cdata_f)

    # load the list of mutations to create inferred scores for
    with open(os.path.join(setup_dir, "muts-list.p"), 'rb') as muts_f:
        mtype_list = pickle.load(muts_f)

    # load the classifier used to produce the mutation scores
    # NOTE(review): `eval` executes whatever expression is passed on the
    # command line -- only run this with trusted arguments, or replace it
    # with a lookup of known classifier names
    clf = eval(args.classif)
    clf.predict_proba = clf.calc_pred_labels
    mut_clf = clf()

    out_tune = {test: {smps: {par: None for par, _ in mut_clf.tune_priors}
                       for smps in ['All', 'Iso']}
                for test in mtype_list}
    out_inf = {test: {'All': None, 'Iso': None} for test in mtype_list}

    for i, (cohort, mtype) in enumerate(mtype_list):
        if (i % args.task_count) == args.task_id:
            print("Isolating {} in cohort {} ...".format(mtype, cohort))

            # get the gene associated with this mutation, and the genes
            # appearing on the same chromosome for exclusion in classification
            use_gene = mtype.subtype_list()[0][0]
            use_chr = cdata.gene_annot[use_gene]['Chr']
            ex_genes = {gene for gene, annot in cdata.gene_annot.items()
                        if annot['Chr'] == use_chr}

            # get the set of mutations from the same gene that will be hidden
            # from the classifier in the isolation approach
            base_mtype = MuType(cdata.train_mut[use_gene].allkey())
            base_mtype -= ex_mtypes[args.ex_mtype]

            # get the samples that will be hidden in the isolation approach
            coh_samps = cdata.cohort_samps[cohort.split('_')[0]]
            base_samps = base_mtype.get_samples(cdata.train_mut[use_gene])
            base_samps &= coh_samps
            ex_samps = base_samps - mtype.get_samples(cdata.train_mut)

            # tune the classifier on the default approach task
            mut_clf, cv_output = mut_clf.tune_coh(
                cdata, mtype,
                include_samps=coh_samps, exclude_genes=ex_genes,
                tune_splits=4, test_count=48, parallel_jobs=12
                )

            # get the tuned parameters for the default approach task
            clf_params = mut_clf.get_params()
            for par, _ in mut_clf.tune_priors:
                out_tune[(cohort, mtype)]['All'][par] = clf_params[par]

            out_inf[(cohort, mtype)]['All'] = mut_clf.infer_coh(
                cdata, mtype, exclude_genes=ex_genes,
                force_test_samps=cdata.samples - coh_samps,
                infer_splits=120, infer_folds=4, parallel_jobs=12
                )

            # tune the classifier on the isolation approach task
            mut_clf, cv_output = mut_clf.tune_coh(
                cdata, mtype, exclude_genes=ex_genes,
                exclude_samps=ex_samps | (cdata.samples - coh_samps),
                tune_splits=4, test_count=48, parallel_jobs=12
                )

            # get the tuned parameters for the isolation approach task
            clf_params = mut_clf.get_params()
            for par, _ in mut_clf.tune_priors:
                out_tune[(cohort, mtype)]['Iso'][par] = clf_params[par]

            out_inf[(cohort, mtype)]['Iso'] = mut_clf.infer_coh(
                cdata, mtype, exclude_genes=ex_genes,
                force_test_samps=ex_samps | (cdata.samples - coh_samps),
                infer_splits=120, infer_folds=4, parallel_jobs=12
                )

        else:
            # drop the placeholders of pairs handled by other tasks
            del(out_inf[(cohort, mtype)])
            del(out_tune[(cohort, mtype)])

    # save the experiment results for this subtask to file; the context
    # manager makes sure the pickle is flushed and the handle closed
    with open(os.path.join(out_dir, "out_task-{}.p".format(args.task_id)),
              'wb') as out_f:
        pickle.dump({'Infer': out_inf, 'Tune': out_tune, 'Clf': mut_clf},
                    out_f)