def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('cohort', type=str, help="which TCGA cohort to use")
    parser.add_argument('gene', type=str, help="a mutated gene")
    parser.add_argument('--setup_dir', type=str, default=base_dir)

    args = parser.parse_args()
    out_path = os.path.join(args.setup_dir, 'setup')
    cdata = get_cohort_data(args.cohort, args.gene)
    vars_list = set()

    with open(os.path.join(out_path, "cohort-data.p"), 'wb') as f:
        pickle.dump(cdata, f)

    # find subsets of point mutations with enough affected samples for each
    # mutated gene in the cohort
    if 'Point' in dict(cdata.mtree) and len(cdata.mtree['Point']) >= 20:
        vars_list |= cdata.mtree['Point'].branchtypes(min_size=20)

        if len(dict(cdata.mtree['Point'])) > 1:
            vars_list |= {MuType({('Scale', 'Point'): None})}

    # add deep copy number alterations if enough samples are affected
    if 'Copy' in dict(cdata.mtree):
        if 'DeepDel' in dict(cdata.mtree['Copy']):
            if len(cdata.mtree['Copy']['DeepDel']) >= 20:
                vars_list |= {MuType({('Copy', 'DeepDel'): None})}

        if 'DeepGain' in dict(cdata.mtree['Copy']):
            if len(cdata.mtree['Copy']['DeepGain']) >= 20:
                vars_list |= {MuType({('Copy', 'DeepGain'): None})}

    # filter out mutations that do not have enough wild-type samples
    vars_list = {mtype for mtype in vars_list
                 if (len(mtype.get_samples(cdata.mtree))
                     <= (len(cdata.get_samples()) - 20))}

    # remove mutations that are functionally equivalent to another mutation
    vars_list -= {mtype1 for mtype1, mtype2 in product(vars_list, repeat=2)
                  if (mtype1 != mtype2 and mtype1.is_supertype(mtype2)
                      and (mtype1.get_samples(cdata.mtree)
                           == mtype2.get_samples(cdata.mtree)))}

    # save the enumerated mutations, and the number of such mutations, to file
    with open(os.path.join(out_path, "vars-list.p"), 'wb') as fl:
        pickle.dump(sorted(vars_list), fl)
    with open(os.path.join(out_path, "vars-count.txt"), 'w') as f:
        f.write(str(len(vars_list)))
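# A toy illustration of the dedup rule above, using plain dicts of sample
# sets rather than the real MuType/mtree API (all names here are
# hypothetical): a broader mutation type is dropped whenever a narrower type
# it contains covers exactly the same set of samples.
from itertools import product

toy_samps = {'any_missense': {'s1', 's2', 's3'},
             'R175H': {'s1', 's2', 's3'},     # same samples as its supertype
             'R273C': {'s4', 's5'}}
toy_supertypes = {('any_missense', 'R175H'), ('any_missense', 'R273C')}

toy_kept = set(toy_samps) - {t1 for t1, t2 in product(toy_samps, repeat=2)
                             if (t1 != t2 and (t1, t2) in toy_supertypes
                                 and toy_samps[t1] == toy_samps[t2])}
assert toy_kept == {'R175H', 'R273C'}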
def get_fancy_label(mtype, scale_link=None, pnt_link=None, phrase_link=None):
    if scale_link is None:
        scale_link = ' or '
    if pnt_link is None:
        pnt_link = ' or '
    if phrase_link is None:
        phrase_link = ' '

    if mtype.cur_level == 'Scale':
        sub_dict = dict(mtype.subtype_iter())
        use_lbls = []

        if 'Copy' in sub_dict:
            use_lbls += [
                copy_lbls[MuType({('Scale', 'Copy'): sub_dict['Copy']})]]

        if 'Point' in sub_dict:
            if sub_dict['Point'] is None:
                use_lbls += ["any point mutation"]
            else:
                use_lbls += [
                    nest_label(sub_dict['Point'], pnt_link, phrase_link)]

    else:
        use_lbls = [nest_label(mtype, pnt_link, phrase_link)]

    return scale_link.join(use_lbls)
def compare_scores(iso_df, samps, muts_dict,
                   get_similarities=True, all_mtype=None):
    base_muts = tuple(muts_dict.values())[0]

    if all_mtype is None:
        all_mtype = MuType(base_muts.allkey())

    pheno_dict = {mtype: np.array(muts_dict[lvls].status(samps, mtype))
                  if lvls in muts_dict
                  else np.array(base_muts.status(samps, mtype))
                  for lvls, mtype in iso_df.index}

    simil_df = pd.DataFrame(0.0, index=pheno_dict.keys(),
                            columns=pheno_dict.keys(), dtype=float)
    auc_df = pd.DataFrame(index=pheno_dict.keys(), columns=['All', 'Iso'],
                          dtype=float)

    all_pheno = np.array(base_muts.status(samps, all_mtype))
    pheno_dict['Wild-Type'] = ~all_pheno

    for (_, cur_mtype), iso_vals in iso_df.iterrows():
        simil_df.loc[cur_mtype, cur_mtype] = 1.0

        none_vals = np.concatenate(iso_vals[~all_pheno].values)
        wt_vals = np.concatenate(iso_vals[~pheno_dict[cur_mtype]].values)
        cur_vals = np.concatenate(iso_vals[pheno_dict[cur_mtype]].values)

        # AUCs relative to wild-type samples and to samples carrying no
        # mutation of any kind, using the pairwise-comparison formulation
        auc_df.loc[cur_mtype, 'All'] = np.greater.outer(
            cur_vals, wt_vals).mean()
        auc_df.loc[cur_mtype, 'All'] += np.equal.outer(
            cur_vals, wt_vals).mean() / 2

        auc_df.loc[cur_mtype, 'Iso'] = np.greater.outer(
            cur_vals, none_vals).mean()
        auc_df.loc[cur_mtype, 'Iso'] += np.equal.outer(
            cur_vals, none_vals).mean() / 2

        if get_similarities:
            cur_diff = np.subtract.outer(cur_vals, none_vals).mean()

            if cur_diff != 0:
                for other_mtype in set(simil_df.index) - {cur_mtype}:
                    other_vals = np.concatenate(
                        iso_vals[pheno_dict[other_mtype]].values)

                    other_diff = np.subtract.outer(
                        other_vals, none_vals).mean()
                    simil_df.loc[cur_mtype, other_mtype] = other_diff
                    simil_df.loc[cur_mtype, other_mtype] /= cur_diff

    return pheno_dict, auc_df, simil_df
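# A minimal sketch of the similarity score computed above, with made-up
# numbers instead of experiment output: the mean shift of another subtype's
# scores relative to mutation-free samples, scaled by the shift of the
# subtype the classifier was trained on.
import numpy as np

toy_none = np.array([0.1, 0.2, 0.15])    # samples carrying no mutation at all
toy_cur = np.array([0.8, 0.9])           # samples with the trained-on subtype
toy_other = np.array([0.5, 0.4])         # samples with a different subtype

toy_cur_diff = np.subtract.outer(toy_cur, toy_none).mean()
toy_other_diff = np.subtract.outer(toy_other, toy_none).mean()
print(toy_other_diff / toy_cur_diff)     # ~0.43: partial transcriptomic overlap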
def get_pie_counts(base_mtype, use_mtree):
    all_type = MuType(use_mtree.allkey())

    pie_mtypes = [ExMcomb(copy_mtype, base_mtype),
                  ExMcomb(all_type, gains_mtype, base_mtype),
                  ExMcomb(all_type, dels_mtype, base_mtype)]

    return [len(mcomb.get_samples(use_mtree)) for mcomb in pie_mtypes]
def get_all_mtype(mtype, gene, use_mtrees, lvls_dict=None, base_lvls=None):
    _, sub_type = tuple(mtype.subtype_iter())[0]

    if lvls_dict is not None and sub_type in lvls_dict:
        use_lvls = lvls_dict[sub_type]
    elif base_lvls is not None:
        use_lvls = tuple(base_lvls)
    else:
        use_lvls = sorted(use_mtrees.keys())[0]

    return MuType({('Gene', gene): use_mtrees[use_lvls][gene].allkey()})
def get_samples(self, *mtrees):
    samps = self.mtype_apply(
        lambda mtype: mtype.get_samples(*mtrees), and_)

    if self.not_mtype is not None:
        if self.cur_level == 'Gene':
            for gn, sub_type in self.not_mtype.subtype_iter():
                samps -= MuType(
                    {('Gene', gn): sub_type}).get_samples(*mtrees)

        else:
            samps -= self.not_mtype.get_samples(*mtrees)

    return samps
def main(): data_dir = os.path.join(base_dir, "resources") expr_data = pd.read_csv(os.path.join(data_dir, "expr.txt.gz"), sep='\t', index_col=0) mut_data = pd.read_csv(os.path.join(data_dir, "variants.txt.gz"), sep='\t', index_col=0) cdata = BaseMutationCohort(expr_data, mut_data, mut_levels=['Gene'], mut_genes=['TP53'], cv_seed=987, test_prop=0.8) test_mtype = MuType({('Gene', 'TP53'): None}) ovp_clf = StanOverlap() ovp_clf.fit_coh(cdata, test_mtype) train_auc = ovp_clf.eval_coh(cdata, test_mtype, use_train=True) print("Stan logistic model training AUC: {:.3f}".format(train_auc)) assert train_auc >= 0.7, ( "Stan logistic model did not obtain a training AUC of at least 0.7!") test_auc = ovp_clf.eval_coh(cdata, test_mtype, use_train=False) print("Stan logistic model testing AUC: {:.3f}".format(test_auc)) assert test_auc >= 0.7, ( "Stan logistic model did not obtain a testing AUC of at least 0.7!") ovp_clf = StanOverlap() ovp_clf.fit_coh(cdata, test_mtype) train_auc = ovp_clf.eval_coh(cdata, test_mtype, use_train=True) print("Stan overlap model training AUC: {:.3f}".format(train_auc)) assert train_auc >= 0.7, ( "Stan overlap model did not obtain a training AUC of at least 0.7!") test_auc = ovp_clf.eval_coh(cdata, test_mtype, use_train=False) print("Stan overlap model testing AUC: {:.3f}".format(test_auc)) assert test_auc >= 0.7, ( "Stan overlap model did not obtain a testing AUC of at least 0.7!") log_clf = StanVarit() log_clf.fit_coh(cdata, test_mtype) train_auc = log_clf.eval_coh(cdata, test_mtype, use_train=True) print("All Stan learning tests passed successfully!")
def plot_gene_scores(out_df, args, cdata):
    fig, axarr = plt.subplots(figsize=(13, 10), nrows=2, ncols=6,
                              sharex=True, sharey=True)

    samp_list = cdata.subset_samps()
    tcga_indx = np.array([samp.split('-')[0] == "TCGA"
                          for samp in samp_list])
    pt_indx = [i for i, samp in enumerate(samp_list)
               if samp.split(' --- ')[0] == args.patient][0]

    for ax, (gene, vals) in zip(axarr.flatten(), out_df.items()):
        cur_mtype = MuType({('Gene', gene): None})
        cur_pheno = np.array(cdata.train_pheno(cur_mtype))

        tcga_wt = vals[~cur_pheno & tcga_indx].quantile(q=(0.25, 0.75))
        tcga_mut = vals[cur_pheno & tcga_indx].quantile(q=(0.25, 0.75))
        norm_val = tcga_mut[0.75] - tcga_wt[0.25]

        # interquartile ranges of the TCGA wild-type and mutant scores,
        # rescaled so that the wild-type floor sits at zero
        ax.add_patch(Rectangle(
            (0.2, 0), 0.6, (tcga_wt[0.75] - tcga_wt[0.25]) / norm_val,
            color=wt_clr, ec=wt_clr, alpha=0.31, lw=2.1
            ))

        ax.add_patch(Rectangle(
            (0.2, (tcga_mut[0.25] - tcga_wt[0.25]) / norm_val), 0.6,
            (tcga_mut[0.75] - tcga_mut[0.25]) / norm_val,
            color=mut_clr, ec=mut_clr, alpha=0.31, lw=2.1
            ))

        # the patient's own score on the same normalized scale
        pt_val = (vals[pt_indx] - tcga_wt[0.25]) / norm_val
        ax.axhline(np.clip(pt_val, 0, 1), xmin=0.13, xmax=0.87,
                   color='black', lw=2.9)

        ax.axes.get_xaxis().set_visible(False)
        ax.set_ylim(-0.03, 1.03)
        ax.set_yticks([0, 0.25, 0.5, 0.75, 1], minor=False)
        ax.set_yticklabels(['WT', '', '', '', 'Mut'])
        ax.set_title(gene)

    plt.tight_layout(w_pad=1.7, h_pad=1.7)
    fig.savefig(
        os.path.join(plot_dir, 'patient_{}__{}-{}.png'.format(
            args.patient, args.cohort, args.classif)),
        dpi=300, bbox_inches='tight'
        )

    plt.close()
def plot_coef_divergence(coef_df, auc_vals, pheno_dict, args):
    fig, ax = plt.subplots(figsize=(13, 8))

    coef_means = coef_df.groupby(level=0, axis=1).mean()
    base_mtype = MuType({('Gene', args.gene): pnt_mtype})

    for mtype, coef_vals in coef_means.iterrows():
        use_clr = choose_subtype_colour(mtype)
        corr_val = spearmanr(coef_means.loc[base_mtype], coef_vals)[0]

        ax.scatter(auc_vals[mtype], 1 - corr_val,
                   facecolor=[use_clr], s=751 * np.mean(pheno_dict[mtype]),
                   alpha=0.31, edgecolors='none')

    x_lims = ax.get_xlim()
    y_lims = [-ax.get_ylim()[1] / 91, ax.get_ylim()[1]]

    ax.plot(x_lims, [0, 0], color='black', linewidth=1.6, alpha=0.71)
    ax.plot([0.5, 0.5], [0, y_lims[1]],
            color='black', linewidth=1.4, linestyle=':', alpha=0.61)

    ax.tick_params(axis='both', which='major', labelsize=17)
    plt.grid(alpha=0.37, linewidth=0.9)
    ax.set_xlim(x_lims)
    ax.set_ylim(y_lims)

    plt.xlabel("Task AUC", fontsize=23, weight='semibold')
    plt.ylabel("Signature Divergence\nfrom Gene-Wide Task",
               fontsize=23, weight='semibold')

    plt.savefig(os.path.join(
        plot_dir, '__'.join([args.expr_source, args.cohort]),
        "{}_coef-divergence_{}.svg".format(args.gene, args.classif)),
        bbox_inches='tight', format='svg')

    plt.close()
def compare_scores(infer_df, cdata):
    use_gene = {mtype.get_labels()[0] for mtype in infer_df['All'].index
                if not isinstance(mtype, (ExMcomb, Mcomb, RandomType))}

    assert len(use_gene) == 1, ("Mutations to merge are not all associated "
                                "with the same gene!")
    use_gene = tuple(use_gene)[0]

    base_lvls = 'Gene', 'Scale', 'Copy', 'Exon', 'Location', 'Protein'
    if base_lvls not in cdata.mtrees:
        cdata.add_mut_lvls(base_lvls)

    pheno_dict = {mtype: np.array(cdata.train_pheno(mtype))
                  for mtype in infer_df['All'].index}
    pheno_dict['Wild-Type'] = ~np.array(
        cdata.train_pheno(MuType(cdata.mtrees[base_lvls].allkey())))

    # compute AUCs and subtype similarities for each mutation in parallel
    siml_vals = dict(zip(
        infer_df['All'].index,
        Parallel(backend='threading', n_jobs=12, pre_dispatch=12)(
            delayed(calculate_simls)(
                pheno_dict, cur_mtype,
                infer_df['All'].loc[cur_mtype].values,
                infer_df['Iso'].loc[cur_mtype].values)
            for cur_mtype in infer_df['All'].index)
        ))

    auc_df = pd.DataFrame(
        {mut: [all_auc, iso_auc]
         for mut, (all_auc, iso_auc, _) in siml_vals.items()},
        index=['All', 'Iso']).transpose()

    simil_df = pd.DataFrame(
        {mut: siml_dict
         for mut, (_, _, siml_dict) in siml_vals.items()}).transpose()

    return pheno_dict, auc_df, simil_df
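# The merging step above fans the per-mutation similarity calculations out
# with joblib; a minimal sketch of that pattern using a stand-in worker
# (calculate_simls is project code, so a toy square() is used here instead).
from joblib import Parallel, delayed


def toy_square(x):
    return x * x


toy_results = dict(zip(
    range(5),
    Parallel(backend='threading', n_jobs=2)(
        delayed(toy_square)(i) for i in range(5))
    ))
# -> {0: 0, 1: 1, 2: 4, 3: 9, 4: 16}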
def main():
    data_dir = os.path.join(base_dir, "resources")

    expr_data = pd.read_csv(os.path.join(data_dir, "expr_txs.txt.gz"),
                            sep='\t', index_col=0, header=[0, 1])
    mut_data = pd.read_csv(os.path.join(data_dir, "variants.txt.gz"),
                           sep='\t', index_col=0)

    cdata = BaseMutationCohort(expr_data, mut_data, mut_levels=['Gene'],
                               mut_genes=['TP53'], cv_seed=987, test_prop=0.2)
    test_mtype = MuType({('Gene', 'TP53'): None})

    tx_clf = StanTranscripts()
    tx_clf.tune_coh(cdata, test_mtype,
                    test_count=4, tune_splits=1, parallel_jobs=1)
    print(tx_clf)

    tx_clf.fit_coh(cdata, test_mtype)
    train_auc = tx_clf.eval_coh(cdata, test_mtype, use_train=True)
    print("Stan transcript model training AUC: {:.3f}".format(train_auc))
    assert train_auc >= 0.7, (
        "Stan transcript model did not obtain a training AUC of at least 0.7!")

    test_auc = tx_clf.eval_coh(cdata, test_mtype, use_train=False)
    print("Stan transcript model testing AUC: {:.3f}".format(test_auc))
    assert test_auc >= 0.7, (
        "Stan transcript model did not obtain a testing AUC of at least 0.7!")

    print("All Stan transcript tests passed successfully!")
def main(): parser = argparse.ArgumentParser( "Plot how isolating subvariants affects classification performance " "within and between cohorts for a gene in a transfer experiment.") parser.add_argument('gene', type=str, help="a mutated gene") parser.add_argument('classif', type=str, help="the mutation classification algorithm used") parser.add_argument('ex_mtype', type=str) parser.add_argument('cohorts', type=str, nargs='+', help="which TCGA cohort to use") parser.add_argument('--samp_cutoff', default=20, help='subtype sample frequency threshold') args = parser.parse_args() out_tag = "{}__samps-{}".format('__'.join(args.cohorts), args.samp_cutoff) out_files = glob( os.path.join(base_dir, out_tag, "out-data__*_{}_{}.p".format(args.classif, args.ex_mtype))) out_list = [ pickle.load(open(out_file, 'rb'))['Infer'] for out_file in out_files ] all_df = pd.concat([ols['All'] for ols in out_list]) iso_df = pd.concat([ols['Iso'] for ols in out_list]) if not any(mtype.subtype_list()[0][0] == args.gene for _, mtype in all_df.index): raise ValueError("No mutations associated with gene {} were " "included in this version of the " "experiment!".format(args.gene)) os.makedirs(os.path.join(plot_dir, out_tag, args.gene), exist_ok=True) out_mdls = [ out_file.split("out-data__")[1].split(".p")[0] for out_file in out_files ] # load expression and mutation data for each of the cohorts considered cdata_dict = { lvl: merge_cohort_data(os.path.join(base_dir, out_tag), use_lvl=lvl) for lvl in [mdl.split('_{}_'.format(args.classif))[0] for mdl in out_mdls] } cdata = tuple(cdata_dict.values())[0] use_samps = sorted(cdata.train_samps) copy_dict = {False: dict(), True: dict()} for norml in [False, True]: for coh in args.cohorts: copy_dict[norml][coh] = get_copies_firehose(coh.split('_')[0], copy_dir, discrete=False, normalize=norml) copy_samps = { old_smp: new_smp for old_smp, new_smp in match_tcga_samples( copy_dict[norml][coh].index)[0].items() if new_smp in cdata.cohort_samps[coh.split('_')[0]] } copy_dict[norml][coh] = copy_dict[norml][coh].loc[ copy_samps.keys(), cdata.genes] copy_dict[norml][coh].index = copy_samps.values() copy_dict[norml] = pd.concat(list( copy_dict[norml].values())).loc[use_samps] coh_stat = { cohort: np.array([ samp in cdata.cohort_samps[cohort.split('_')[0]] for samp in use_samps ]) for cohort in args.cohorts } auc_dict = { smps: { 'Reg': dict(), 'Oth': dict(), 'Hld': dict() } for smps in ['All', 'Iso'] } stab_dict = {'All': dict(), 'Iso': dict()} type_dict = dict() stat_dict = { copy_lbl: np.array(cdata.train_mut[args.gene].status( use_samps, MuType({('Scale', 'Copy'): { ('Copy', copy_lbl): None }}))) for copy_lbl in ['ShalGain', 'ShalDel'] } for (coh, mtype) in all_df.index: if mtype.subtype_list()[0][0] == args.gene: if mtype not in type_dict: use_type = mtype.subtype_list()[0][1] if (isinstance(use_type, ExMcomb) or isinstance(use_type, Mcomb)): if len(use_type.mtypes) == 1: use_subtype = tuple(use_type.mtypes)[0] mtype_lvls = use_subtype.get_sorted_levels()[1:] else: mtype_lvls = None else: use_subtype = use_type mtype_lvls = use_type.get_sorted_levels()[1:] if mtype_lvls == ('Copy', ): copy_type = use_subtype.subtype_list()[0][1].\ subtype_list()[0][0] if copy_type == 'DeepGain': type_dict[mtype] = 'Gain' elif copy_type == 'DeepDel': type_dict[mtype] = 'Loss' else: type_dict[mtype] = 'Other' else: type_dict[mtype] = 'Point' if mtype not in auc_dict['All']['Reg']: for smps in ['All', 'Iso']: stab_dict[smps][mtype] = dict() for auc_type in ['Reg', 'Oth', 'Hld']: 
auc_dict[smps][auc_type][mtype] = dict() use_gene, use_type = mtype.subtype_list()[0] mtype_lvls = use_type.get_sorted_levels()[1:] if '__'.join(mtype_lvls) in cdata_dict: use_lvls = '__'.join(mtype_lvls) elif not mtype_lvls or mtype_lvls == ('Copy', ): use_lvls = 'Location__Protein' mtype_stat = np.array(cdata_dict[use_lvls].train_mut.status( use_samps, mtype)) all_vals = all_df.loc[[(coh, mtype)]].values[0] iso_vals = iso_df.loc[[(coh, mtype)]].values[0] gene_muts = cdata_dict[use_lvls].train_mut[args.gene] gene_mtype = MuType(gene_muts.allkey()) - ex_mtypes[args.ex_mtype] gene_stat = np.array(gene_muts.status(use_samps, gene_mtype)) for tst_coh in args.cohorts: use_stat = coh_stat[tst_coh] & mtype_stat if np.sum(use_stat) >= 20: stat_dict[tst_coh, mtype] = use_stat stab_dict['All'][mtype][coh, tst_coh] = np.mean( [np.std(vals) for vals in all_vals[coh_stat[tst_coh]]]) stab_dict['All'][mtype][coh, tst_coh] /= np.std([ np.mean(vals) for vals in all_vals[coh_stat[tst_coh]] ]) stab_dict['Iso'][mtype][coh, tst_coh] = np.mean( [np.std(vals) for vals in iso_vals[coh_stat[tst_coh]]]) stab_dict['Iso'][mtype][coh, tst_coh] /= np.std([ np.mean(vals) for vals in iso_vals[coh_stat[tst_coh]] ]) wt_stat = coh_stat[tst_coh] & ~mtype_stat wt_vals = np.concatenate(all_vals[wt_stat]) none_stat = coh_stat[tst_coh] & ~gene_stat none_vals = np.concatenate(iso_vals[none_stat]) if tst_coh == coh: cv_count = 30 else: cv_count = 120 cur_stat = coh_stat[tst_coh] & mtype_stat cur_all_vals = np.concatenate(all_vals[cur_stat]) cur_iso_vals = np.concatenate(iso_vals[cur_stat]) auc_dict['All']['Reg'][mtype][(coh, tst_coh)] = np.\ greater.outer(cur_all_vals, wt_vals).mean() auc_dict['All']['Reg'][mtype][(coh, tst_coh)] += np.\ equal.outer(cur_all_vals, wt_vals).mean() / 2 auc_dict['Iso']['Reg'][mtype][(coh, tst_coh)] = np.\ greater.outer(cur_iso_vals, none_vals).mean() auc_dict['Iso']['Reg'][mtype][(coh, tst_coh)] += np.\ equal.outer(cur_iso_vals, none_vals).mean() / 2 plot_auc_comparisons(auc_dict, stat_dict, type_dict, args) for copy_norml in [False, True]: for use_cohort in args.cohorts: plot_copy_calls(use_cohort, all_df, iso_df, copy_dict, copy_norml, auc_dict, stat_dict, coh_stat, type_dict, args)
def main(): data_dir = os.path.join(base_dir, "resources") expr_data = pd.read_csv(os.path.join(data_dir, "expr.txt.gz"), sep='\t', index_col=0) mut_data = pd.read_csv(os.path.join(data_dir, "variants.txt.gz"), sep='\t', index_col=0) expr_dict = {'C1': expr_data.iloc[::2, :], 'C2': expr_data.iloc[1::2, :]} mut_dict = { coh: mut_data.loc[mut_data.Sample.isin(expr.index), :].copy() for coh, expr in expr_dict.items() } sing_clf = SingleTransfer() mult_clf = MultiTransfer() multmult_clf = MultiMultiTransfer() sing_mtype = MuType({('Gene', 'TP53'): None}) mult_mtypes = [ MuType({('Gene', 'TP53'): None}), MuType({('Gene', 'GATA3'): None}) ] uni_cdata = BaseMutationCohort(expr_data, mut_data, mut_levels=[['Gene']], mut_genes=['TP53', 'GATA3'], cv_seed=None, test_prop=0.3) uni_cdata.update_split(new_seed=101) sing_clf.tune_coh(uni_cdata, mult_mtypes, test_count=4, tune_splits=2, parallel_jobs=1) print(sing_clf) sing_clf.fit_coh(uni_cdata, mult_mtypes) train_auc = sing_clf.eval_coh(uni_cdata, mult_mtypes, use_train=True) print("Multi-pheno single-domain " "KBTL model training AUC: {:.3f}".format(train_auc)) assert train_auc >= 0.6, ( "KBTL model did not obtain a training AUC of at least 0.6!") test_auc = sing_clf.eval_coh(uni_cdata, mult_mtypes, use_train=False) print("Multi-pheno single-domain " "KBTL model testing AUC: {:.3f}".format(test_auc)) assert test_auc >= 0.6, ( "KBTL model did not obtain a testing AUC of at least 0.6!") trs_cdata = BaseTransferMutationCohort(expr_dict, mut_dict, mut_levels=[['Gene']], mut_genes=['TP53', 'GATA3'], cv_seed=None, test_prop=0.3) trs_cdata.update_split(new_seed=101) mult_clf.tune_coh(trs_cdata, sing_mtype, test_count=4, tune_splits=2, parallel_jobs=1) print(mult_clf) mult_clf.fit_coh(trs_cdata, sing_mtype) train_auc = mult_clf.eval_coh(trs_cdata, sing_mtype, use_train=True) print("Single-pheno multi-domain " "KBTL model training AUC: {:.3f}".format(train_auc)) assert train_auc >= 0.6, ( "KBTL model did not obtain a training AUC of at least 0.6!") test_auc = mult_clf.eval_coh(trs_cdata, sing_mtype, use_train=False) print("Single-pheno multi-domain " "KBTL model testing AUC: {:.3f}".format(test_auc)) assert test_auc >= 0.6, ( "KBTL model did not obtain a testing AUC of at least 0.6!") multmult_clf.tune_coh(trs_cdata, mult_mtypes, test_count=4, tune_splits=2, parallel_jobs=1) print(multmult_clf) multmult_clf.fit_coh(trs_cdata, mult_mtypes) train_auc = multmult_clf.eval_coh(trs_cdata, mult_mtypes, use_train=True) print("Multi-pheno multi-domain " "KBTL model training AUC: {:.3f}".format(train_auc)) assert train_auc >= 0.6, ( "KBTL model did not obtain a training AUC of at least 0.6!") test_auc = multmult_clf.eval_coh(trs_cdata, mult_mtypes, use_train=False) print("Multi-pheno multi-domain " "KBTL model testing AUC: {:.3f}".format(test_auc)) assert test_auc >= 0.6, ( "KBTL model did not obtain a testing AUC of at least 0.6!") print("All transfer learning tests passed successfully!")
def plot_subcopy_symmetry(pred_dfs, pheno_dict, auc_dfs, cdata, args, cna_lbl, use_src, use_coh, siml_metric): fig, ax = plt.subplots(figsize=(8.43, 9)) cna_mtype = cna_mtypes[cna_lbl] use_combs = { mut for mut, auc_val in auc_dfs['Iso'].iteritems() if (isinstance(mut, ExMcomb) and auc_val >= 0.6 and not (mut.all_mtype & shal_mtype).is_empty()) } plt_combs = { mcomb for mcomb in use_combs if (set(mcomb.mtypes) == {MuType({('Gene', args.gene): cna_mtype})}) } assert len(plt_combs) <= 1, ( "Too many exclusive {} CNAs found!".format(cna_lbl)) if len(plt_combs) == 1: plt_comb = tuple(plt_combs)[0] else: return None use_combs = remove_pheno_dups( { mcomb for mcomb in use_combs if (all( (cna_mtype & tuple(mtp.subtype_iter())[0][1]).is_empty() for mtp in mcomb.mtypes) or not (pheno_dict[plt_comb] & pheno_dict[mcomb]).any()) }, pheno_dict) if len(use_combs) == 0: print("No mutation-copy mutual similarity pairs found!") return None use_mtree = tuple(cdata.mtrees.values())[0][args.gene] all_mtype = MuType({('Gene', args.gene): use_mtree.allkey()}) all_phn = np.array(cdata.train_pheno(all_mtype)) train_samps = cdata.get_train_samples() map_args = [] ex_indx = [] use_preds = pred_dfs['Iso'].loc[use_combs | plt_combs, train_samps] wt_vals = { mcomb: pred_vals[~all_phn] for mcomb, pred_vals in use_preds.iterrows() } mut_vals = { mcomb: pred_vals[pheno_dict[mcomb]] for mcomb, pred_vals in use_preds.iterrows() } if siml_metric == 'mean': wt_means = {mcomb: vals.mean() for mcomb, vals in wt_vals.items()} mut_means = {mcomb: vals.mean() for mcomb, vals in mut_vals.items()} map_args += [(wt_vals[mcomb1], mut_vals[mcomb1], use_preds.loc[mcomb1, pheno_dict[mcomb2]], wt_means[mcomb1], mut_means[mcomb1], None) for mcomb in use_combs for mcomb1, mcomb2 in permt([mcomb, plt_comb])] elif siml_metric == 'ks': base_dists = { mcomb: ks_2samp(wt_vals[mcomb], mut_vals[mcomb], alternative='greater').statistic for mcomb in use_preds.index } map_args += [(wt_vals[mcomb1], mut_vals[mcomb1], use_preds.loc[mcomb1, pheno_dict[mcomb2]], base_dists[mcomb1]) for mcomb in use_combs for mcomb1, mcomb2 in permt([mcomb, plt_comb])] if siml_metric == 'mean': chunk_size = int(len(map_args) / args.cores) + 1 elif siml_metric == 'ks': chunk_size = int(len(map_args) / (23 * args.cores)) + 1 pool = mp.Pool(args.cores) siml_list = pool.starmap(siml_fxs[siml_metric], map_args, chunk_size) pool.close() siml_vals = dict(zip(use_combs, zip(siml_list[::2], siml_list[1::2]))) plt_lims = min(siml_list) - 0.19, max(max(siml_list) + 0.19, 1.03) size_mult = 20307 * len(map_args)**(-5 / 13) clr_norm = colors.Normalize(vmin=-1, vmax=2) ax.plot(plt_lims, [0, 0], color='black', linewidth=1.37, linestyle=':', alpha=0.53) ax.plot([0, 0], plt_lims, color='black', linewidth=1.37, linestyle=':', alpha=0.53) ax.plot(plt_lims, plt_lims, color='#550000', linewidth=1.43, linestyle='--', alpha=0.41) for siml_val in [-1, 1, 2]: ax.plot(plt_lims, [siml_val] * 2, color=simil_cmap(clr_norm(siml_val)), linewidth=4.1, linestyle=':', alpha=0.37) ax.plot([siml_val] * 2, plt_lims, color=simil_cmap(clr_norm(siml_val)), linewidth=4.1, linestyle=':', alpha=0.37) plt_lctr = plt.MaxNLocator(7, steps=[1, 2, 5]) ax.xaxis.set_major_locator(plt_lctr) ax.yaxis.set_major_locator(plt_lctr) for mcomb in use_combs: plt_sz = size_mult * np.mean(pheno_dict[mcomb]) if len(mcomb.mtypes) == 1: plt_clr = choose_subtype_colour( tuple(reduce(or_, mcomb.mtypes).subtype_iter())[0][1]) ax.scatter(*siml_vals[mcomb], s=plt_sz, c=[plt_clr], alpha=13 / 71, edgecolor='none') else: for i, (plt_half, 
mtype) in enumerate(zip(['left', 'right'], mcomb.mtypes)): mrk_style = MarkerStyle('o', fillstyle=plt_half) plt_clr = choose_subtype_colour( tuple(mtype.subtype_iter())[0][1]) ax.scatter(*siml_vals[mcomb], marker=mrk_style, s=plt_sz, facecolor=plt_clr, alpha=13 / 71, edgecolor='none') ax.set_xlabel("{} Similarity to Subgrouping".format(cna_lbl), size=23, weight='bold') ax.set_ylabel( "Subgrouping Similarity to\nAll {} Alterations".format(cna_lbl), size=23, weight='bold') ax.grid(alpha=0.47, linewidth=0.9) ax.set_xlim(*plt_lims) ax.set_ylim(*plt_lims) plt.savefig(os.path.join( plot_dir, args.gene, "{}__{}-sub{}-symmetry_{}_{}.svg".format(use_coh, siml_metric, cna_lbl, args.classif, use_src)), bbox_inches='tight', format='svg') plt.close()
def plot_score_symmetry(pred_dfs, pheno_dict, auc_dfs, cdata, args, use_src, use_coh, siml_metric): fig, (iso_ax, ish_ax) = plt.subplots(figsize=(15, 8), nrows=1, ncols=2) use_mtree = tuple(cdata.mtrees.values())[0][args.gene] all_mtypes = {'Iso': MuType({('Gene', args.gene): use_mtree.allkey()})} all_mtypes['IsoShal'] = all_mtypes['Iso'] - MuType( {('Gene', args.gene): shal_mtype}) all_phns = { ex_lbl: np.array(cdata.train_pheno(all_mtype)) for ex_lbl, all_mtype in all_mtypes.items() } train_samps = cdata.get_train_samples() iso_combs = remove_pheno_dups( { mut for mut, auc_val in auc_dfs['Iso'].iteritems() if (isinstance(mut, ExMcomb) and auc_val >= args.auc_cutoff and not (mut.all_mtype & shal_mtype).is_empty()) }, pheno_dict) ish_combs = remove_pheno_dups( { mut for mut, auc_val in auc_dfs['IsoShal'].iteritems() if (isinstance(mut, ExMcomb) and auc_val >= args.auc_cutoff and ( mut.all_mtype & shal_mtype).is_empty() and all( (mtp & shal_mtype).is_empty() for mtp in mut.mtypes)) }, pheno_dict) pairs_dict = { ex_lbl: [(mcomb1, mcomb2) for mcomb1, mcomb2 in combn(use_combs, 2) if (all( (mtp1 & mtp2).is_empty() for mtp1, mtp2 in product(mcomb1.mtypes, mcomb2.mtypes)) or not (pheno_dict[mcomb1] & pheno_dict[mcomb2]).any())] for ex_lbl, use_combs in [('Iso', iso_combs), ('IsoShal', ish_combs)] } if args.verbose: for ex_lbl, use_combs in zip(['Iso', 'IsoShal'], [iso_combs, ish_combs]): pair_strs = [ "\n#########\n" "{}: {}({}) {} pairs from {} types".format( use_coh, args.gene, ex_lbl, len(pairs_dict[ex_lbl]), len(use_combs)) ] if pairs_dict[ex_lbl]: pair_strs += ['----------'] pair_strs += [ '\txxxxx\t'.join([str(mcomb) for mcomb in pair]) for pair in pairs_dict[ex_lbl][::( len(pairs_dict[ex_lbl]) // (args.verbose * 7) + 1)] ] combs_dict = { ex_lbl: set(reduce(add, use_pairs)) for ex_lbl, use_pairs in pairs_dict.items() if use_pairs } if not combs_dict: return None map_args = [] ex_indx = [] for ex_lbl, pair_combs in combs_dict.items(): ex_indx += [(ex_lbl, mcombs) for mcombs in pairs_dict[ex_lbl]] use_preds = pred_dfs[ex_lbl].loc[pair_combs, train_samps] wt_vals = { mcomb: use_preds.loc[mcomb, ~all_phns[ex_lbl]] for mcomb in pair_combs } mut_vals = { mcomb: use_preds.loc[mcomb, pheno_dict[mcomb]] for mcomb in pair_combs } if siml_metric == 'mean': wt_means = {mcomb: vals.mean() for mcomb, vals in wt_vals.items()} mut_means = { mcomb: vals.mean() for mcomb, vals in mut_vals.items() } map_args += [(wt_vals[mcomb1], mut_vals[mcomb1], use_preds.loc[mcomb1, pheno_dict[mcomb2]], wt_means[mcomb1], mut_means[mcomb1], None) for mcombs in pairs_dict[ex_lbl] for mcomb1, mcomb2 in permt(mcombs)] elif siml_metric == 'ks': base_dists = { mcomb: ks_2samp(wt_vals[mcomb], mut_vals[mcomb], alternative='greater').statistic for mcomb in pair_combs } map_args += [ (wt_vals[mcomb1], mut_vals[mcomb1], use_preds.loc[mcomb1, pheno_dict[mcomb2]], base_dists[mcomb1]) for mcombs in pairs_dict[ex_lbl] for mcomb1, mcomb2 in permt(mcombs) ] if siml_metric == 'mean': chunk_size = int(len(map_args) / (41 * args.cores)) + 1 elif siml_metric == 'ks': chunk_size = int(len(map_args) / (23 * args.cores)) + 1 pool = mp.Pool(args.cores) siml_list = pool.starmap(siml_fxs[siml_metric], map_args, chunk_size) pool.close() siml_vals = dict(zip(ex_indx, zip(siml_list[::2], siml_list[1::2]))) #TODO: scale by plot ranges or leave as is and thus make sizes # relative to "true" plotting area? 
plt_lims = min(siml_list) - 0.19, max(siml_list) + 0.19 size_mult = 18301 * len(map_args)**(-5 / 13) clr_norm = colors.Normalize(vmin=-1, vmax=2) for ax, ex_lbl in zip([iso_ax, ish_ax], ['Iso', 'IsoShal']): ax.grid(alpha=0.47, linewidth=0.9) ax.plot(plt_lims, [0, 0], color='black', linewidth=1.37, linestyle=':', alpha=0.53) ax.plot([0, 0], plt_lims, color='black', linewidth=1.37, linestyle=':', alpha=0.53) ax.plot(plt_lims, plt_lims, color='#550000', linewidth=1.43, linestyle='--', alpha=0.41) for siml_val in [-1, 1, 2]: ax.plot(plt_lims, [siml_val] * 2, color=simil_cmap(clr_norm(siml_val)), linewidth=4.1, linestyle=':', alpha=0.37) ax.plot([siml_val] * 2, plt_lims, color=simil_cmap(clr_norm(siml_val)), linewidth=4.1, linestyle=':', alpha=0.37) plt_lctr = plt.MaxNLocator(7, steps=[1, 2, 5]) ax.xaxis.set_major_locator(plt_lctr) ax.yaxis.set_major_locator(plt_lctr) for mcomb1, mcomb2 in pairs_dict[ex_lbl]: plt_sz = size_mult * (np.mean(pheno_dict[mcomb1]) * np.mean(pheno_dict[mcomb2]))**0.5 for i, (plt_half, mcomb) in enumerate( zip(['left', 'right'], [mcomb1, mcomb2])): mrk_style = MarkerStyle('o', fillstyle=plt_half) plt_clr = choose_subtype_colour( tuple(reduce(or_, mcomb.mtypes).subtype_iter())[0][1]) ax.scatter(*siml_vals[ex_lbl, (mcomb1, mcomb2)], s=plt_sz, facecolor=plt_clr, marker=mrk_style, alpha=13 / 71, edgecolor='none') if ex_lbl == 'IsoShal': ax.text(1, 0, "AUC >= {:.2f}".format(args.auc_cutoff), size=19, ha='right', va='bottom', transform=ax.transAxes, fontstyle='italic') iso_ax.set_title( "Similarities Computed Treating\nShallow CNAs as Mutant\n", size=23, weight='bold') ish_ax.set_title( "Similarities Computed Treating\nShallow CNAs as Wild-Type\n", size=23, weight='bold') for ax in [iso_ax, ish_ax]: ax.set_xlim(*plt_lims) ax.set_ylim(*plt_lims) plt.tight_layout(w_pad=3.1) plt.savefig(os.path.join( plot_dir, args.gene, "{}__{}-siml-symmetry_{}_{}.svg".format(use_coh, siml_metric, args.classif, use_src)), bbox_inches='tight', format='svg') plt.close()
def plot_distr_comparisons(auc_vals, conf_vals, pheno_dict, args): gene_dict = dict() conf_list = conf_vals[[ not isinstance(mtype, RandomType) and (tuple(mtype.subtype_iter())[0][1] & copy_mtype).is_empty() for mtype in conf_vals.index ]] for gene, conf_vec in conf_list.apply( lambda confs: np.percentile(confs, 25)).groupby( lambda mtype: tuple(mtype.label_iter())[0]): if len(conf_vec) > 1: base_mtype = MuType({('Gene', gene): pnt_mtype}) base_indx = conf_vec.index.get_loc(base_mtype) best_subtype = conf_vec[:base_indx].append( conf_vec[(base_indx + 1):]).idxmax() if conf_vec[best_subtype] > 0.7: gene_dict[gene] = ( choose_label_colour(gene), base_mtype, best_subtype, calc_conf(conf_list[best_subtype], conf_list[base_mtype]) ) plt_size = min(len(gene_dict), 12) ymin = 0.47 fig, axarr = plt.subplots(figsize=(0.5 + 1.5 * plt_size, 7), nrows=1, ncols=plt_size, sharey=True) for i, (gene, (gene_clr, base_mtype, best_subtype, conf_sc)) in enumerate( sorted(gene_dict.items(), key=lambda x: auc_vals[x[1][2]], reverse=True)[:plt_size] ): axarr[i].set_title(gene, size=21, weight='semibold') plt_df = pd.concat([ pd.DataFrame({'Type': 'Base', 'Conf': conf_list[base_mtype]}), pd.DataFrame({'Type': 'Subg', 'Conf': conf_list[best_subtype]}) ]) sns.violinplot(x=plt_df.Type, y=plt_df.Conf, ax=axarr[i], order=['Subg', 'Base'], palette=[gene_clr, gene_clr], cut=0, linewidth=0, width=0.93) axarr[i].scatter(0, auc_vals[best_subtype], s=47, c=[gene_clr], edgecolor='0.31', alpha=0.93) axarr[i].scatter(1, auc_vals[base_mtype], s=47, c=[gene_clr], edgecolor='0.31', alpha=0.41) axarr[i].get_children()[0].set_alpha(0.71) axarr[i].get_children()[2].set_alpha(0.29) if conf_sc == 1: conf_lbl = "1" elif 0.9995 < conf_sc < 1: conf_lbl = ">0.999" else: conf_lbl = "{:.3f}".format(conf_sc) axarr[i].text(0.5, 1 / 97, conf_lbl, size=17, ha='center', va='bottom', transform=axarr[i].transAxes) axarr[i].plot([-0.5, 1.5], [0.5, 0.5], color='black', linewidth=2.3, linestyle=':', alpha=0.83) axarr[i].plot([-0.5, 1.5], [1, 1], color='black', linewidth=1.7, alpha=0.83) axarr[i].set_xlabel('') axarr[i].set_xticklabels([]) ymin = min(ymin, min(conf_list[base_mtype]) - 0.04, min(conf_list[best_subtype]) - 0.04) if i == 0: axarr[i].set_ylabel('AUCs', size=21, weight='semibold') else: axarr[i].set_ylabel('') fig.text(plt_size ** 0.71 / 97, 1 / 19, "conf.\nscore", fontsize=15, weight='semibold', ha='right', va='bottom') if 0.463 < ymin < 0.513: ymin = 0.453 for ax in axarr: ax.set_ylim([ymin, 1 + (1 - ymin) / 23]) plt.savefig( os.path.join(plot_dir, '__'.join([args.expr_source, args.cohort]), "distr-comparisons_{}.svg".format(args.classif)), bbox_inches='tight', format='svg' ) plt.close()
def plot_mtype_distributions(mtypes, infer_dict, cdata, args): fig, axarr = plt.subplots(figsize=(0.1 + 2.8 * len(mtypes), 9), nrows=2, ncols=len(mtypes)) gene_stat = np.array(cdata.train_pheno(MuType({('Gene', args.gene): None}))) all_mtype = MuType(cdata.train_mut.allkey()) for j, cur_mtype in enumerate(mtypes): rest_stat = np.array(cdata.train_pheno(all_mtype - cur_mtype)) infer_df = pd.DataFrame({ 'Value': infer_dict['All'].loc[cur_mtype], 'cStat': np.array(cdata.train_pheno(cur_mtype)), 'rStat': np.array(cdata.train_pheno(all_mtype - cur_mtype)) }) mtype_str = str(cur_mtype).split(':')[-1] axarr[0, j].text(0.5, 1.01, "{}({}) mutations\n({} affected samples)".format( args.gene, mtype_str, np.sum(infer_df.cStat)), size=12, ha='center', va='bottom', transform=axarr[0, j].transAxes) vals_min, vals_max = infer_df.Value.quantile(q=[0, 1]) vals_rng = (vals_max - vals_min) / 31 axarr[0, j].set_ylim(vals_min - 6 * vals_rng, vals_max + 3 * vals_rng) sns.violinplot(data=infer_df[~infer_df.cStat], y='Value', palette=[mut_clrs['Wild-Type']], linewidth=0, cut=0, ax=axarr[0, j]) sns.violinplot(data=infer_df[infer_df.cStat], y='Value', palette=[mut_clrs['Mutant']], linewidth=0, cut=0, ax=axarr[0, j]) axarr[0, j].text(0.5, 0.97, "{} AUC: {:.3f}".format( args.classif, calc_auc(infer_df.Value, infer_df.cStat)), size=10, ha='center', va='top', transform=axarr[0, j].transAxes) axarr[0, j].legend([ Patch(color=mut_clrs['Mutant'], alpha=0.36), Patch(color=mut_clrs['Wild-Type'], alpha=0.36) ], ["{} Mutants".format(mtype_str), "{} Wild-Types".format(mtype_str)], fontsize=11, ncol=1, loc=8, bbox_to_anchor=( 0.5, -0.01)).get_frame().set_linewidth(0.0) sns.violinplot(data=infer_df[~infer_df.cStat], x='cStat', y='Value', hue='rStat', palette=[mut_clrs['Wild-Type']], hue_order=[False, True], split=True, linewidth=0, cut=0, ax=axarr[1, j]) sns.violinplot(data=infer_df[infer_df.cStat], x='cStat', y='Value', hue='rStat', palette=[mut_clrs['Mutant']], hue_order=[False, True], split=True, linewidth=0, cut=0, ax=axarr[1, j]) axarr[1, j].get_legend().remove() axarr[1, j].set_ylim(vals_min - 2 * vals_rng, vals_max + 2 * vals_rng) axarr[1, j].axvline(x=0, ymin=-1, ymax=2, color='black', linewidth=1.1, alpha=0.61) axarr[1, j].text(0.09, 0.98, "AUC: {:.3f}".format( calc_auc(infer_df.Value[~infer_df.rStat], infer_df.cStat[~infer_df.rStat])), size=9, ha='left', va='top', transform=axarr[1, j].transAxes) axarr[1, j].text(0.91, 0.98, "AUC: {:.3f}".format( calc_auc(infer_df.Value[infer_df.rStat], infer_df.cStat[infer_df.rStat])), size=9, ha='right', va='top', transform=axarr[1, j].transAxes) vio_mesh = (('wtWO', np.stack( [[0.45] * 40, np.arange(0.1, 0.5, 0.01)], axis=1)), ('mutWO', np.stack([[0.45] * 30, np.arange(0.57, 0.865, 0.01)], axis=1)), ('wtW', np.stack( [[0.55] * 40, np.arange(0.1, 0.5, 0.01)], axis=1)), ('mutW', np.stack([[0.55] * 30, np.arange(0.57, 0.865, 0.01)], axis=1))) vio_indx = {'wtWO': 0, 'wtW': 1, 'mutWO': 3, 'mutW': 4} pos_deflt = {'wtWO': 0.06, 'wtW': 0.06, 'mutWO': 0.88, 'mutW': 0.88} lbl_txt = { 'wtWO': "{} wt w/o\nother {} muts".format(mtype_str, args.gene), 'mutWO': "{} mut w/o\nother {} muts".format(mtype_str, args.gene), 'wtW': "{} wt w/\nother {} muts".format(mtype_str, args.gene), 'mutW': "{} mut w/\nother {} muts".format(mtype_str, args.gene) } for i in range(2): axarr[i, j].set_xticks([]) axarr[i, j].xaxis.label.set_visible(False) axarr[i, j].yaxis.label.set_visible(False) axarr[i, j].set_xticklabels([]) axarr[i, j].set_yticklabels([]) for art in axarr[i, j].get_children(): if isinstance(art, 
PolyCollection): art.set_alpha(0.41) for lbl, mesh in vio_mesh: vio_ovlp = axarr[1, j].get_children()[ vio_indx[lbl]].get_paths()[0].contains_points( axarr[1, j].transAxes.transform(mesh), transform=axarr[i, j].transData) if np.all(vio_ovlp): ypos = pos_deflt[lbl] else: if lbl[:2] == 'wt': ypos = np.max(mesh[:, 1][~vio_ovlp]) else: ypos = np.min(mesh[:, 1][~vio_ovlp]) if lbl[:2] == 'wt': str_clr = mut_clrs['Wild-Type'] str_va = 'top' else: str_clr = mut_clrs['Mutant'] str_va = 'bottom' if lbl[-2:] == 'WO': str_ha = 'right' else: str_ha = 'left' axarr[1, j].text(mesh[0, 0], ypos, lbl_txt[lbl], color=str_clr, size=7, ha=str_ha, va=str_va, transform=axarr[1, j].transAxes) plt.tight_layout() plt.savefig(os.path.join( plot_dir, args.cohort, "mtype-distributions_{}_{}_{}_samps-{}.png".format( args.gene, args.mut_levels.replace('__', '-'), args.classif, args.samp_cutoff)), dpi=300, bbox_inches='tight') plt.close()
mpl.use('Agg')
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from matplotlib.collections import PolyCollection

mut_clrs = {'Mutant': sns.light_palette('#C50000', reverse=True)[0],
            'Wild-Type': '0.29'}

variant_mtypes = (
    ('Loss Alterations',
     MuType({('Scale', 'Copy'): {('Copy', ('ShalDel', 'DeepDel')): None}})),
    ('Other\nPoint Mutations', MuType({('Scale', 'Point'): None})),
    ('Gain Alterations',
     MuType({('Scale', 'Copy'): {('Copy', ('ShalGain', 'DeepGain')): None}})),
    )


def calc_auc(vals, stat):
    return (np.greater.outer(vals[stat], vals[~stat]).mean()
            + 0.5 * np.equal.outer(vals[stat], vals[~stat]).mean())
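# A quick check (assuming numpy and scikit-learn are available) that the
# pairwise-comparison formula in calc_auc above agrees with the conventional
# ROC AUC on synthetic scores.
import numpy as np
from sklearn.metrics import roc_auc_score

_rng = np.random.RandomState(0)
_stat = _rng.rand(60) > 0.7                 # mutation status labels
_vals = _rng.rand(60) + 0.5 * _stat         # scores, shifted up for mutants

assert np.isclose(calc_auc(_vals, _stat), roc_auc_score(_stat, _vals))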
def main(): parser = argparse.ArgumentParser( 'plot_point', description="Compares point mutation subgroupings with a cohort.") parser.add_argument('classif', help="a mutation classifier") parser.add_argument('--data_cache', type=Path) parser.add_argument('--filters', nargs='+', default=['Point']) args = parser.parse_args() if not args.data_cache or not Path.exists(args.data_cache): out_datas = tuple( Path(base_dir).glob( os.path.join( "*", "out-aucs__*__semideep__{}.p.gz".format(args.classif)))) out_list = pd.DataFrame([{ 'Source': '__'.join(out_data.parts[-2].split('__')[:-1]), 'Cohort': out_data.parts[-2].split('__')[-1], 'Levels': '__'.join(out_data.parts[-1].split('__')[1:-2]), 'File': out_data } for out_data in out_datas]).groupby('Cohort').filter( lambda outs: 'Consequence__Exon' in set(outs.Levels)) if len(out_list) == 0: raise ValueError("No completed experiments found for this " "combination of parameters!") out_list = out_list[out_list.Cohort.isin(train_cohorts)] use_iter = out_list.groupby(['Source', 'Cohort', 'Levels'])['File'] out_dirs = {(src, coh): Path(base_dir, '__'.join([src, coh])) for src, coh, _ in use_iter.groups} out_tags = { fl: '__'.join(fl.parts[-1].split('__')[1:]) for fl in out_list.File } phn_dicts = {(src, coh): dict() for src, coh, _ in use_iter.groups} auc_dfs = { (src, coh): {ex_lbl: pd.DataFrame() for ex_lbl in ['All', 'Iso', 'IsoShal']} for src, coh, _ in use_iter.groups } for (src, coh, lvls), out_files in use_iter: out_aucs = list() for out_file in out_files: with bz2.BZ2File( Path(out_dirs[src, coh], '__'.join(["out-pheno", out_tags[out_file]])), 'r') as f: phn_dicts[src, coh].update(pickle.load(f)) with bz2.BZ2File( Path(out_dirs[src, coh], '__'.join(["out-aucs", out_tags[out_file]])), 'r') as f: out_aucs += [pickle.load(f)] mtypes_comp = np.greater_equal.outer( *([[set(auc_dict['All'].index) for auc_dict in out_aucs]] * 2)) super_comp = np.apply_along_axis(all, 1, mtypes_comp) # if there is not a subgrouping set that contains all the others, # concatenate the output of all sets... 
if not super_comp.any(): for ex_lbl in ['All', 'Iso', 'IsoShal']: auc_dfs[src, coh][ex_lbl] = auc_dfs[src, coh][ex_lbl].append( pd.concat([aucs[ex_lbl] for aucs in out_aucs])) # ...otherwise, use the "superset" else: super_indx = super_comp.argmax() for ex_lbl in ['All', 'Iso', 'IsoShal']: auc_dfs[src, coh][ex_lbl] = auc_dfs[src, coh][ex_lbl].append( out_aucs[super_indx][ex_lbl]) # filter out duplicate subgroupings due to overlapping search criteria for src, coh, _ in use_iter.groups: for ex_lbl in ['All', 'Iso', 'IsoShal']: auc_dfs[src, coh][ex_lbl].sort_index(inplace=True) auc_dfs[src, coh][ex_lbl] = auc_dfs[src, coh][ex_lbl].loc[~auc_dfs[ src, coh][ex_lbl].index.duplicated()] auc_dict = dict() for ex_lbl in ['All', 'Iso', 'IsoShal']: auc_dict[ex_lbl] = pd.DataFrame({ (src, coh, mut): auc_vals for (src, coh), auc_df in auc_dfs.items() for mut, auc_vals in auc_df[ex_lbl].iterrows() }).transpose() auc_dict[ex_lbl]['mean'] = auc_dict[ex_lbl]['mean'].astype(float) auc_dict[ex_lbl]['all'] = auc_dict[ex_lbl]['all'].astype(float) if args.data_cache: with bz2.BZ2File(args.data_cache, 'w') as f: pickle.dump((phn_dicts, auc_dict), f, protocol=-1) else: with bz2.BZ2File(args.data_cache, 'r') as f: phn_dicts, auc_dict = pickle.load(f) os.makedirs(plot_dir, exist_ok=True) plot_auc_comparisons(auc_dict, phn_dicts, args) for filter_lbl in args.filters: cv_dict = dict() acc_dict = dict() filter_fx, base_subtype = sub_filters[filter_lbl] use_aucs = auc_dict['Iso'][[ get_mut_ex(mut) == 'Iso' and filter_fx(mut) for _, _, mut in auc_dict['Iso'].index ]]['mean'] for (src, coh, gene), auc_vec in use_aucs.groupby( lambda x: (x[0], x[1], get_label(x[2]))): base_mtype = MuType({('Gene', gene): base_subtype}) sub_aucs = auc_vec[[ mut != base_mtype for _, _, mut in auc_vec.index ]] if len(sub_aucs) == 0: cv_dict[coh, gene] = -1 else: cv_dict[coh, gene] = max( (np.array(auc_dict['Iso']['CV'][comb]) > np.array( auc_dict['All']['CV'][src, coh, base_mtype])).sum() for comb in sub_aucs.index) acc_dict[coh, gene, base_mtype] = auc_dict['All']['mean'][src, coh, base_mtype] for comb in sub_aucs.index: acc_dict[coh, gene, comb] = auc_dict['Iso']['mean'][comb] for ex_lbl, auc_df in auc_dict.items(): for filter_lbl in args.filters: plot_sub_comparisons(auc_df, phn_dicts, args, ex_lbl, filter_lbl, auc_dict['All'])
def plot_sub_comparisons(conf_vals, pheno_dict, args): fig, ax = plt.subplots(figsize=(11, 11)) pnt_dict = dict() clr_dict = dict() conf_list = conf_vals[[ not isinstance(mtype, RandomType) and not (mtype.subtype_list()[0][1] != pnt_mtype and pheno_dict[mtype].sum() == pheno_dict[MuType( {('Gene', mtype.get_labels()[0]): pnt_mtype})].sum()) for mtype in conf_vals.index ]] conf_list = conf_list.apply(lambda confs: np.percentile(confs, 25)) for gene, conf_vec in conf_list.groupby( lambda mtype: mtype.get_labels()[0]): if len(conf_vec) > 1: base_mtype = MuType({('Gene', gene): pnt_mtype}) base_indx = conf_vec.index.get_loc(base_mtype) best_subtype = conf_vec[:base_indx].append( conf_vec[(base_indx + 1):]).idxmax() best_indx = conf_vec.index.get_loc(best_subtype) if conf_vec[best_indx] > 0.6: clr_dict[gene] = choose_gene_colour(gene) base_size = np.mean(pheno_dict[base_mtype]) best_prop = np.mean(pheno_dict[best_subtype]) / base_size conf_sc = np.greater.outer(conf_list[best_subtype], conf_list[base_mtype]).mean() if conf_sc > 0.9: pnt_dict[conf_vec[base_indx], conf_vec[best_indx]] = ( base_size**0.53, (gene, get_fancy_label(best_subtype))) elif conf_vec[base_indx] > 0.7 or conf_vec[best_indx] > 0.7: pnt_dict[conf_vec[base_indx], conf_vec[best_indx]] = (base_size**0.53, (gene, '')) else: pnt_dict[conf_vec[base_indx], conf_vec[best_indx]] = (base_size**0.53, ('', '')) pie_ax = inset_axes(ax, width=base_size**0.5, height=base_size**0.5, bbox_to_anchor=(conf_vec[base_indx], conf_vec[best_indx]), bbox_transform=ax.transData, loc=10, axes_kwargs=dict(aspect='equal'), borderpad=0) pie_ax.pie(x=[best_prop, 1 - best_prop], explode=[0.29, 0], colors=[ clr_dict[gene] + (0.77, ), clr_dict[gene] + (0.29, ) ]) lbl_pos = place_labels(pnt_dict) for (pnt_x, pnt_y), pos in lbl_pos.items(): ax.text(pos[0][0], pos[0][1] + 700**-1, pnt_dict[pnt_x, pnt_y][1][0], size=13, ha=pos[1], va='bottom') ax.text(pos[0][0], pos[0][1] - 700**-1, pnt_dict[pnt_x, pnt_y][1][1], size=9, ha=pos[1], va='top') x_delta = pnt_x - pos[0][0] y_delta = pnt_y - pos[0][1] ln_lngth = np.sqrt((x_delta**2) + (y_delta**2)) # if the label is sufficiently far away from its point... if ln_lngth > (0.021 + pnt_dict[pnt_x, pnt_y][0] / 31): use_clr = clr_dict[pnt_dict[pnt_x, pnt_y][1][0]] pnt_gap = pnt_dict[pnt_x, pnt_y][0] / (29 * ln_lngth) lbl_gap = 0.006 / ln_lngth ax.plot([pnt_x - pnt_gap * x_delta, pos[0][0] + lbl_gap * x_delta], [ pnt_y - pnt_gap * y_delta, pos[0][1] + lbl_gap * y_delta + 0.008 + 0.004 * np.sign(y_delta) ], c=use_clr, linewidth=2.3, alpha=0.27) ax.plot([0.48, 1.0005], [1, 1], color='black', linewidth=1.9, alpha=0.89) ax.plot([1, 1], [0.48, 1.0005], color='black', linewidth=1.9, alpha=0.89) ax.plot([0.49, 0.997], [0.49, 0.997], linewidth=2.1, linestyle='--', color='#550000', alpha=0.41) ax.set_xlim([0.48, 1.01]) ax.set_ylim([0.48, 1.01]) ax.set_xlabel( "1st quartile of down-sampled AUCs" "\nusing all point mutations", size=21, weight='semibold') ax.set_ylabel( "1st quartile of down-sampled AUCs" "\nof best found subgrouping", size=21, weight='semibold') plt.savefig(os.path.join(plot_dir, '__'.join([args.expr_source, args.cohort]), "sub-comparisons_{}.svg".format(args.classif)), bbox_inches='tight', format='svg') plt.close()
def plot_sub_comparisons(auc_df, pheno_dicts, args, ex_lbl, use_filter, all_aucs, add_lgnd=True): fig, ax = plt.subplots(figsize=(10.3, 11)) filter_fx, base_subtype = sub_filters[use_filter] use_aucs = auc_df[[ get_mut_ex(mut) == ex_lbl and filter_fx(mut) for _, _, mut in auc_df.index ]]['mean'] plot_dict = dict() line_dict = dict() plt_min = 0.57 for (src, coh, gene), auc_vec in use_aucs.groupby( lambda x: (x[0], x[1], get_label(x[2]))): base_mtype = MuType({('Gene', gene): base_subtype}) sub_aucs = auc_vec[[mut != base_mtype for _, _, mut in auc_vec.index]] if len(sub_aucs) == 0: continue best_subtype = sub_aucs.idxmax()[2] auc_tupl = (auc_df.loc[(src, coh, base_mtype), 'mean'], auc_vec[src, coh, best_subtype]) plt_min = min(plt_min, auc_tupl[0] - 0.03, auc_tupl[1] - 0.029) base_size = np.mean(pheno_dicts[src, coh][base_mtype]) best_prop = np.mean(pheno_dicts[src, coh][best_subtype]) best_prop /= base_size plt_size = 0.07 * base_size**0.5 plot_dict[auc_tupl] = [plt_size, ('', '')] line_dict[auc_tupl] = dict(c=choose_label_colour(gene)) cv_sig = (np.array(auc_df['CV'][src, coh, best_subtype]) > np.array( auc_df['CV'][src, coh, base_mtype])).all() # ...and if we are sure that the optimal subgrouping AUC is # better than the point mutation AUC then add a label with the # gene name and a description of the best found subgrouping... if auc_vec.max() >= 0.7: gene_lbl = "{} in {}".format(gene, get_cohort_label(coh)) if cv_sig: if isinstance(best_subtype, MuType): plot_dict[auc_tupl][1] = (gene_lbl, get_fancy_label( get_subtype(best_subtype), pnt_link='\nor ', phrase_link=' ')) else: plot_dict[auc_tupl][1] = (gene_lbl, get_mcomb_lbl(best_subtype)) pie_bbox = (auc_tupl[0] - plt_size / 2, auc_tupl[1] - plt_size / 2, plt_size, plt_size) pie_ax = inset_axes(ax, width='100%', height='100%', bbox_to_anchor=pie_bbox, bbox_transform=ax.transData, axes_kwargs=dict(aspect='equal'), borderpad=0) pie_ax.pie(x=[best_prop, 1 - best_prop], colors=[ line_dict[auc_tupl]['c'] + (0.77, ), line_dict[auc_tupl]['c'] + (0.29, ) ], explode=[0.29, 0], startangle=90) plt_lims = plt_min, 1 + (1 - plt_min) / 181 ax.grid(linewidth=0.83, alpha=0.41) ax.plot(plt_lims, [0.5, 0.5], color='black', linewidth=1.3, linestyle=':', alpha=0.71) ax.plot([0.5, 0.5], plt_lims, color='black', linewidth=1.3, linestyle=':', alpha=0.71) ax.plot(plt_lims, [1, 1], color='black', linewidth=1.9, alpha=0.89) ax.plot([1, 1], plt_lims, color='black', linewidth=1.9, alpha=0.89) ax.plot(plt_lims, plt_lims, color='#550000', linewidth=2.1, linestyle='--', alpha=0.41) ax.set_xlabel("Accuracy of Gene-Wide Classifier", size=23, weight='semibold') ax.set_ylabel("Accuracy of Best Subgrouping Classifier", size=23, weight='semibold') if add_lgnd: ax, plot_dict = add_scatterpie_legend(ax, plot_dict, plt_min, pnt_mtype, args) if plot_dict: lbl_pos = place_scatter_labels(plot_dict, ax, plt_lims=[[plt_min + 0.01, 0.99]] * 2, line_dict=line_dict) ax.set_xlim(plt_lims) ax.set_ylim(plt_lims) plt.savefig(os.path.join( plot_dir, "{}-sub{}-comparisons_{}.svg".format(ex_lbl, use_filter, args.classif)), bbox_inches='tight', format='svg') plt.close()
def plot_phantm_scores(phantm_scrs, pred_df, cdata, args): fig, ax = plt.subplots(figsize=(13, 8)) plt_mtypes = { mtype for mtype in cdata.mtrees[mtree_k]['TP53'].branchtypes(mtype=MuType({ ('Scale', 'Point'): { ('Consequence', ('missense_variant', 'stop_gained', 'synonymous_variant')): None } }), ) if 'HGVSp' in mtype.get_levels() } pred_scrs = pred_df.loc[base_mtype].apply(np.mean) ls_phn = np.array(cdata.train_pheno(ls_mtype)) plt_dict = dict() for mtype in plt_mtypes: mtype_lbl = get_fancy_label(mtype) if mtype_lbl in phantm_scrs: mtype_phn = np.array(cdata.train_pheno(mtype)) & ls_phn if mtype_phn.any(): plt_dict[mtype] = mtype_lbl, mtype_phn else: print("Could not find `{}` in PHANTM table!".format(mtype_lbl)) for mtype, (lbl, phn) in plt_dict.items(): mtype_scrs = pred_scrs[phn].mean() plt_sz = 71003 * np.mean(phn) mtype_cnsq = tuple(tuple(mtype.subtype_iter())[0][1].label_iter())[0] if mtype_cnsq == 'missense_variant': plt_clr = form_clrs['Missense_Mutation'] elif mtype_cnsq == 'stop_gained': plt_clr = form_clrs['Nonsense_Mutation'] elif mtype_cnsq == 'synonymous_variant': plt_clr = form_clrs['Silent'] else: raise ValueError( "Unknown mutation consequence `{}`!".format(mtype_cnsq)) ax.scatter(pred_scrs[phn].mean(), phantm_scrs[lbl], c=[plt_clr], s=plt_sz, alpha=0.31, edgecolor='none') ax.grid(linewidth=0.71, alpha=0.37) ax.tick_params(labelsize=15) coh_lbl = get_cohort_label(args.cohort) ax.set_xlabel("Predicted TP53 Scores\nin {}".format(coh_lbl), fontsize=27, weight='semibold') ax.set_ylabel("PHANTM Combined\nPhenotype Score", fontsize=27, weight='semibold') plt.tight_layout(h_pad=1.7) fig.savefig(os.path.join(plot_dir, '__'.join([args.expr_source, args.cohort]), "pred-scores_{}.svg".format(args.classif)), bbox_inches='tight', format='svg') plt.close()
def plot_symmetry_decomposition(pred_df, pheno_dict, auc_vals, cdata, args, plt_gene, ex_lbl, siml_metric): use_mtree = tuple(cdata.mtrees.values())[0][plt_gene] use_combs = auc_vals.index.tolist() use_pairs = [(mcomb1, mcomb2) for mcomb1, mcomb2 in combn(use_combs, 2) if (all( (mtp1 & mtp2).is_empty() for mtp1, mtp2 in product(mcomb1.mtypes, mcomb2.mtypes)) or not (pheno_dict[mcomb1] & pheno_dict[mcomb2]).any())] if not use_pairs: print("no suitable pairs found among {} possible " "mutations for: {}({}) !".format(len(use_combs), plt_gene, ex_lbl)) return True if len(use_pairs) > PLOT_MAX: print("found {} suitable pairs for {}({}), only plotting " "the top {} by max AUC!".format(len(use_pairs), plt_gene, ex_lbl, PLOT_MAX)) use_pairs = pd.Series({ tuple(mcombs): max(auc_vals[mcomb] for mcomb in mcombs) for mcombs in use_pairs }).sort_values()[-(PLOT_MAX):].index.tolist() mcomb_clx = {mcomb: classify_mcomb(mcomb) for mcomb in use_combs} cls_counts = pd.Series( reduce(add, [[mcomb_clx[mcomb] for mcomb in use_pair] for use_pair in use_pairs])).value_counts() if len(cls_counts) == 1: print("only one partition found, cannot plot decomposition " "for {}({}) !".format(plt_gene, ex_lbl)) return True fig, axarr = plt.subplots( figsize=(1.5 + 3 * len(cls_counts), 1 + 3 * len(cls_counts)), nrows=1 + len(cls_counts), ncols=1 + len(cls_counts), gridspec_kw=dict(width_ratios=[1] + [2] * len(cls_counts), height_ratios=[7] * len(cls_counts) + [2])) all_mtype = MuType({('Gene', plt_gene): use_mtree.allkey()}) if ex_lbl == 'IsoShal': all_mtype -= MuType({('Gene', plt_gene): shal_mtype}) pair_combs = set(reduce(add, use_pairs)) train_samps = cdata.get_train_samples() use_preds = pred_df.loc[pair_combs, train_samps] all_phn = np.array(cdata.train_pheno(all_mtype)) wt_vals = {mcomb: use_preds.loc[mcomb, ~all_phn] for mcomb in pair_combs} mut_vals = { mcomb: use_preds.loc[mcomb, pheno_dict[mcomb]] for mcomb in pair_combs } if siml_metric == 'mean': chunk_size = int(0.91 * len(use_pairs) / args.cores) + 1 wt_means = {mcomb: vals.mean() for mcomb, vals in wt_vals.items()} mut_means = {mcomb: vals.mean() for mcomb, vals in mut_vals.items()} map_args = [(wt_vals[mcomb1], mut_vals[mcomb1], use_preds.loc[mcomb1, pheno_dict[mcomb2]], wt_means[mcomb1], mut_means[mcomb1], None) for mcombs in use_pairs for mcomb1, mcomb2 in permt(mcombs)] elif siml_metric == 'ks': chunk_size = int(0.91 * len(use_pairs) / args.cores) + 1 base_dists = { mcomb: ks_2samp(wt_vals[mcomb], mut_vals[mcomb], alternative='greater').statistic for mcomb in pair_combs } map_args = [(wt_vals[mcomb1], mut_vals[mcomb1], use_preds.loc[mcomb1, pheno_dict[mcomb2]], base_dists[mcomb1]) for mcombs in use_pairs for mcomb1, mcomb2 in permt(mcombs)] pool = mp.Pool(args.cores) siml_list = pool.starmap(siml_fxs[siml_metric], map_args, chunk_size) pool.close() siml_vals = dict(zip(use_pairs, zip(siml_list[::2], siml_list[1::2]))) size_mult = max(727 - math.log(len(use_pairs), 1 + 1 / 77), 31) PAIR_CLRS = ['#0DAAFF', '#FF8B00'] acc_norm = colors.Normalize(vmin=args.auc_cutoff, vmax=auc_vals.max()) acc_cmap = sns.cubehelix_palette(start=1.07, rot=1.31, gamma=0.83, light=0.19, dark=0.73, reverse=True, as_cmap=True) plt_sizes = { (mcomb1, mcomb2): size_mult * (np.mean(pheno_dict[mcomb1]) * np.mean(pheno_dict[mcomb2]))**0.5 for mcomb1, mcomb2 in use_pairs } for (i, cls1), (j, cls2) in combn(enumerate(cls_counts.index), 2): pair_count = len(plt_sizes) for (mcomb1, mcomb2), plt_sz in plt_sizes.items(): if mcomb_clx[mcomb1] == cls2 and mcomb_clx[mcomb2] == cls1: use_clr, 
use_alpha = PAIR_CLRS[0], 1 / 6.1 elif mcomb_clx[mcomb1] == cls1 and mcomb_clx[mcomb2] == cls2: use_clr, use_alpha = PAIR_CLRS[1], 1 / 6.1 else: use_clr, use_alpha = '0.61', 1 / 17 pair_count -= 1 axarr[i, j + 1].scatter(*siml_vals[mcomb1, mcomb2], c=[use_clr], s=plt_sz, alpha=use_alpha, edgecolor='none') if use_clr in PAIR_CLRS: axarr[j, i + 1].scatter(*siml_vals[mcomb1, mcomb2], c=[acc_cmap(acc_norm(auc_vals[mcomb1]))], s=plt_sz, alpha=use_alpha, edgecolor='none') if pair_count == 1: pair_lbl = "1 pair" else: pair_lbl = "{} pairs".format(pair_count) axarr[j, i + 1].text(0.01, 1, pair_lbl, size=13, ha='left', va='bottom', fontstyle='italic', transform=axarr[j, i + 1].transAxes) axarr[i, j + 1].text(0.99, 1, "({})".format(pair_count), size=13, ha='right', va='bottom', fontstyle='italic', transform=axarr[i, j + 1].transAxes) plt_lims = min(siml_list) - 0.07, max(siml_list) + 0.07 plt_gap = (plt_lims[1] - plt_lims[0]) / 53 cls_counts: pd.Series clx_counts = pd.Series(mcomb_clx).value_counts() for i, (cls, cls_count) in enumerate(cls_counts.iteritems()): axarr[-1, i + 1].text(0.5, 13 / 17, cls, size=23, ha='center', va='top', fontweight='semibold', transform=axarr[-1, i + 1].transAxes) if clx_counts[cls] == 1: count_lbl = "1 subgrouping" else: count_lbl = "{} subgroupings".format(clx_counts[cls]) axarr[-1, i + 1].text(0.5, -1 / 7, count_lbl, size=19, ha='center', va='bottom', fontstyle='italic', transform=axarr[-1, i + 1].transAxes) for (mcomb1, mcomb2), plt_sz in plt_sizes.items(): if mcomb_clx[mcomb1] == cls and mcomb_clx[mcomb2] == cls: use_clr, use_alpha = 'black', 0.37 elif mcomb_clx[mcomb1] == cls: use_clr, use_alpha = PAIR_CLRS[0], 0.19 elif mcomb_clx[mcomb2] == cls: use_clr, use_alpha = PAIR_CLRS[1], 0.19 else: use_clr, use_alpha = '0.73', 1 / 6.1 axarr[i, i + 1].scatter(*siml_vals[mcomb1, mcomb2], c=[use_clr], s=plt_sz, alpha=use_alpha, edgecolor='none') if cls_count == 1: cls_lbl = "1 total pair" else: cls_lbl = "{} total pairs".format(cls_count) axarr[i, i + 1].text(0.99, 1, cls_lbl, size=13, ha='right', va='bottom', fontstyle='italic', transform=axarr[i, i + 1].transAxes) axarr[-2, i + 1].add_patch( ptchs.Rectangle((0.02, -0.23), 0.96, 0.061, facecolor=PAIR_CLRS[0], alpha=0.61, edgecolor='none', transform=axarr[-2, i + 1].transAxes, clip_on=False)) clr_ax = axarr[-2, 0].inset_axes(bounds=(1 / 3, -3 / 17, 4 / 7, 43 / 23), clip_on=False, in_layout=False) clr_bar = ColorbarBase(ax=clr_ax, cmap=acc_cmap, norm=acc_norm, ticklocation='left') clr_ax.set_title("AUC", size=21, fontweight='bold') clr_ax.yaxis.set_major_locator(plt.MaxNLocator(7, steps=[1, 2, 4, 5])) tcks_loc = clr_ax.get_yticks().tolist() clr_ax.yaxis.set_major_locator(mpl.ticker.FixedLocator(tcks_loc)) clr_bar.ax.set_yticklabels( [format(tick, '.2f').lstrip('0') for tick in tcks_loc], size=15, fontweight='semibold') siml_norm = colors.Normalize(vmin=-1, vmax=2) plt_lctr = plt.MaxNLocator(5, steps=[1, 2, 5]) for ax in axarr[:-1, 1:].flatten(): ax.grid(alpha=0.47, linewidth=0.7) ax.plot(plt_lims, [0, 0], color='black', linewidth=0.83, linestyle=':', alpha=0.47) ax.plot([0, 0], plt_lims, color='black', linewidth=0.83, linestyle=':', alpha=0.47) ax.plot(plt_lims, plt_lims, color='#550000', linewidth=1.13, linestyle='--', alpha=0.37) for siml_val in [-1, 1, 2]: ax.plot(plt_lims, [siml_val] * 2, color=simil_cmap(siml_norm(siml_val)), linewidth=2.7, linestyle=':', alpha=0.31) ax.plot([siml_val] * 2, plt_lims, color=simil_cmap(siml_norm(siml_val)), linewidth=2.7, linestyle=':', alpha=0.31) ax.set_xlim(*plt_lims) 
ax.set_ylim(*plt_lims) ax.xaxis.set_major_locator(plt_lctr) ax.yaxis.set_major_locator(plt_lctr) for i in range(len(cls_counts)): for j in range(1, len(cls_counts) + 1): if i != (j - 1): axarr[i, j].set_xticklabels([]) axarr[i, j].set_yticklabels([]) else: axarr[i, j].tick_params(labelsize=11) for ax in axarr[:, 0].tolist() + axarr[-1, :].tolist(): ax.axis('off') plt.tight_layout(w_pad=2 / 7, h_pad=2 / 7) plt.savefig(os.path.join( plot_dir, '__'.join([args.expr_source, args.cohort]), "{}_{}_{}-symm-decomposition_{}.svg".format(plt_gene, ex_lbl, siml_metric, args.classif)), bbox_inches='tight', format='svg') plt.close()
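# -- Illustrative sketch, not part of the original module --------------------
# plot_symmetry_decomposition() defers the actual similarity computation to
# siml_fxs[siml_metric], which is defined elsewhere in this codebase.  The two
# functions below are hypothetical stand-ins whose signatures match the
# argument tuples assembled above: a "mean-shift" similarity that places the
# other subgrouping's mean score between the wild-type and mutant means, and a
# KS-based variant that scales the other subgrouping's separation from
# wild-type by this subgrouping's own separation.

import numpy as np
from scipy.stats import ks_2samp


def mean_siml(wt_vals, mut_vals, other_vals,
              wt_mean=None, mut_mean=None, _=None):
    wt_mean = np.mean(wt_vals) if wt_mean is None else wt_mean
    mut_mean = np.mean(mut_vals) if mut_mean is None else mut_mean

    # 0 ~ scores look like wild-type, 1 ~ scores look like this subgrouping
    return (np.mean(other_vals) - wt_mean) / (mut_mean - wt_mean)


def ks_siml(wt_vals, mut_vals, other_vals, base_dist=None):
    if base_dist is None:
        base_dist = ks_2samp(wt_vals, mut_vals,
                             alternative='greater').statistic

    other_dist = ks_2samp(wt_vals, other_vals,
                          alternative='greater').statistic
    return other_dist / base_dist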
def main(): parser = argparse.ArgumentParser() parser.add_argument('cohort', type=str, help="which TCGA cohort to use") parser.add_argument( 'samp_cutoff', type=int, help="minimum number of mutated samples needed to test a gene") # parse command line arguments, identify directory where intermediate # files are to be stored, load cohort expression and mutation data parser.add_argument('--setup_dir', type=str, default=base_dir) args = parser.parse_args() out_path = os.path.join(args.setup_dir, 'setup') cdata = get_cohort_data(args.cohort) # save cohort data to file for use by future tasks with open(os.path.join(out_path, "cohort-data.p"), 'wb') as cdata_fl: pickle.dump(cdata, cdata_fl) # find subsets of point mutations with enough affected samples for each # mutated gene in the cohort vars_list = reduce(or_, [{ MuType({('Gene', gene): mtype}) for mtype in muts['Point'].branchtypes(min_size=args.samp_cutoff) } for gene, muts in cdata.mtree if ('Scale', 'Point') in muts.allkey()], set()) # add copy number deletions for each gene if enough samples are affected vars_list |= { MuType({('Gene', gene): { ('Copy', 'DeepDel'): None }}) for gene, muts in cdata.mtree if (('Scale', 'Copy') in muts.allkey() and ('Copy', 'DeepDel') in muts['Copy'].allkey() and len(muts['Copy']['DeepDel']) >= args.samp_cutoff) } # add copy number amplifications for each gene vars_list |= { MuType({('Gene', gene): { ('Copy', 'DeepGain'): None }}) for gene, muts in cdata.mtree if (('Scale', 'Copy') in muts.allkey() and ( 'Copy', 'DeepGain') in muts['Copy'].allkey() and len(muts['Copy']['DeepGain']) >= args.samp_cutoff) } # add all point mutations as a single mutation type for each gene if it # contains more than one type of point mutation vars_list |= { MuType({('Gene', gene): { ('Scale', 'Point'): None }}) for gene, muts in cdata.mtree if (('Scale', 'Point') in muts.allkey() and len(muts['Point'].allkey()) > 1 and len(muts['Point']) >= args.samp_cutoff) } # filter out mutations that do not have enough wild-type samples vars_list = { mtype for mtype in vars_list if (len(mtype.get_samples(cdata.mtree)) <= (len(cdata.get_samples()) - args.samp_cutoff)) } # remove mutations that are functionally equivalent to another mutation vars_list -= { mtype1 for mtype1, mtype2 in product(vars_list, repeat=2) if (mtype1 != mtype2 and mtype1.is_supertype(mtype2) and ( mtype1.get_samples(cdata.mtree) == mtype2.get_samples(cdata.mtree)) ) } # find the pairs of remaining mutations that do not have overlapping # definitions and have enough samples with exactly one mutation in the pair samp_dict = {mtype: mtype.get_samples(cdata.mtree) for mtype in vars_list} pairs_list = { tuple(sorted([mtype1, mtype2])) for (mtype1, samps1), (mtype2, samps2) in combn(samp_dict.items(), 2) if (len(samps1 - samps2) >= args.samp_cutoff and len(samps2 - samps1) >= args.samp_cutoff and (mtype1 & mtype2).is_empty()) } # save the enumerate pairs to file along with a count of the pairs with open(os.path.join(out_path, "pairs-list.p"), 'wb') as f: pickle.dump(sorted(pairs_list), f) with open(os.path.join(out_path, "pairs-count.txt"), 'w') as fl: fl.write(str(len(pairs_list)))
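# -- Illustrative sketch, not part of the original module --------------------
# A toy version of the pair filter used above, with plain sample sets standing
# in for MuType.get_samples() output.  The real filter additionally requires
# the two mutation types' definitions not to overlap
# ((mtype1 & mtype2).is_empty()), which is omitted here since it relies on the
# MuType algebra.

from itertools import combinations

samp_cutoff = 2
samp_dict = {                          # hypothetical mutation -> sample sets
    'TP53 point': {'s1', 's2', 's3', 's4'},
    'TP53 deep deletion': {'s3', 's4', 's5', 's6'},
    'KRAS point': {'s7', 's8', 's9'},
}

pairs_list = {
    tuple(sorted([mtype1, mtype2]))
    for (mtype1, samps1), (mtype2, samps2) in combinations(samp_dict.items(), 2)
    if (len(samps1 - samps2) >= samp_cutoff
        and len(samps2 - samps1) >= samp_cutoff)
}

# every retained pair has at least `samp_cutoff` samples unique to each member
print(sorted(pairs_list))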
def plot_gene_isolation(mtypes, infer_dict, cdata, args): fig, axarr = plt.subplots(figsize=(0.1 + 3 * len(mtypes), 9.2), nrows=3, ncols=len(mtypes), sharex=True) all_mtype = MuType(cdata.train_mut.allkey()) for j, cur_mtype in enumerate(mtypes): cur_mcombs = { 'All': cur_mtype, 'Ex': ExMcomb(cdata.train_mut, cur_mtype) } mtype_str = str(cur_mtype).split(':')[-1] cur_stats = { lbl: np.array(cdata.train_pheno(mtp)) for lbl, mtp in cur_mcombs.items() } axarr[0, j].text(0.5, 1.03, "{} mutations\n({} affected samples)".format( mtype_str, np.sum(cur_stats['All'])), size=13, ha='center', va='bottom', transform=axarr[0, j].transAxes) for xval in [0, 2]: axarr[2, j].text(xval, -0.05, "all other\nsamps", ha='center', va='top', size=9) for xval in [1, 3]: axarr[2, j].text(xval, -0.05, "only {}\nWT samps".format(args.gene), ha='center', va='top', size=9) axarr[2, j].text(0.5, -0.18, "all samps w/\n{} muts".format(mtype_str), ha='center', va='top', size=9) axarr[2, j].text(2.5, -0.18, "samps w/ only\n{} muts".format(mtype_str), ha='center', va='top', size=9) infer_vals = pd.DataFrame({ "{}_{}".format(lbl, smps): vals.loc[cur_mcombs[lbl]] for (lbl, stat), ( smps, vals) in product(cur_stats.items(), infer_dict.items()) }) infer_vals = ((infer_vals - infer_vals.min()) / (infer_vals.max() - infer_vals.min())) for lbl, stat in cur_stats.items(): infer_vals["cStat_{}".format(lbl)] = stat iso_mtypes = [(lbl, (all_mtype & mtype) - cur_mtype) if mtype.is_supertype(cur_mtype) else (lbl, mtype) for lbl, mtype in variant_mtypes] for i, (lbl, other_mtype) in enumerate(iso_mtypes): val_df = infer_vals.copy() val_df['oStat'] = np.array(cdata.train_pheno(other_mtype)) if j == 0: axarr[i, j].text(-0.04, 0.5, "{}\n({} samples)".format( lbl, np.sum(val_df['oStat'])), ha='right', va='center', size=11, transform=axarr[i, j].transAxes) val_df = pd.melt(val_df, id_vars=['cStat_All', 'cStat_Ex', 'oStat'], var_name='Samps', value_name='Val') val_df['cStat'] = np.where( val_df.Samps.str.split('_').apply(itemgetter(0)) == 'All', val_df.cStat_All, val_df.cStat_Ex) sns.violinplot(data=val_df[~val_df.cStat], x='Samps', y='Val', hue='oStat', palette=[mut_clrs['Wild-Type']], linewidth=0, bw=0.11, split=True, cut=0, order=['All_All', 'All_Iso', 'Ex_All', 'Ex_Iso'], ax=axarr[i, j]) sns.violinplot(data=val_df[val_df.cStat], x='Samps', y='Val', hue='oStat', palette=[mut_clrs['Mutant']], linewidth=0, bw=0.11, split=True, cut=0, order=['All_All', 'All_Iso', 'Ex_All', 'Ex_Iso'], ax=axarr[i, j]) axarr[i, j].xaxis.label.set_visible(False) axarr[i, j].yaxis.label.set_visible(False) axarr[i, j].set_xticklabels([]) axarr[i, j].set_yticklabels([]) axarr[i, j].set_ylim(-0.02, 1.02) for xval in range(4): axarr[i, j].axvline(x=xval, ymin=-1, ymax=2, color='black', linewidth=1.3, alpha=0.61) axarr[i, j].get_legend().remove() for art in axarr[i, j].get_children(): if isinstance(art, PolyCollection): art.set_alpha(0.41) axarr[2, j].legend([ Patch(color=mut_clrs['Mutant'], alpha=0.36), Patch(color=mut_clrs['Wild-Type'], alpha=0.36) ], ["{} Mutants".format(mtype_str), "{} Wild-Types".format(mtype_str)], fontsize=11, ncol=1, loc=9, bbox_to_anchor=( 0.5, -0.26)).get_frame().set_linewidth(0.0) axarr[2, 0].text(-1, -0.05, "Negative\nClassification Set", ha='right', va='top', size=10) axarr[2, 0].text(-1, -0.17, "Positive\nClassification Set", ha='right', va='top', size=10) plt.tight_layout() plt.savefig(os.path.join( plot_dir, args.cohort, "gene-isolation_{}_{}_{}_samps-{}.png".format( args.gene, args.mut_levels.replace('__', '-'), args.classif, 
args.samp_cutoff)), dpi=300, bbox_inches='tight') plt.close()
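# -- Illustrative sketch, not part of the original module --------------------
# The panels above are built by melting a per-sample table of inferred scores
# into long form and drawing split violins keyed on a boolean mutation status.
# A minimal standalone version of that pattern, with random scores standing in
# for the classifier's inferred values:

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
val_df = pd.DataFrame({
    'All_All': rng.random(80), 'All_Iso': rng.random(80),
    'Ex_All': rng.random(80), 'Ex_Iso': rng.random(80),
    'oStat': rng.random(80) < 0.3,     # hypothetical "other mutation" status
})

long_df = pd.melt(val_df, id_vars=['oStat'],
                  var_name='Samps', value_name='Val')

ax = sns.violinplot(data=long_df, x='Samps', y='Val', hue='oStat',
                    split=True, cut=0, linewidth=0,
                    order=['All_All', 'All_Iso', 'Ex_All', 'Ex_Iso'])
ax.get_legend().remove()
plt.show()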
def main(): parser = argparse.ArgumentParser( 'setup_test', description="Load datasets and enumerate subgroupings to be tested." ) parser.add_argument('expr_source', help="a source of expression datasets") parser.add_argument('cohort', help="a tumour sample -omic dataset") parser.add_argument( 'samp_cutoff', type=int, help="minimum number of affected samples needed to test a mutation" ) parser.add_argument('mut_levels', type=str, help="a combination of mutation attribute levels") parser.add_argument('out_dir', type=str, help="the working directory for this experiment") # parse command line arguments, figure out where output will be stored, # get the mutation attributes and cancer genes that will be used args = parser.parse_args() out_path = os.path.join(args.out_dir, 'setup') lvl_list = ('Gene', 'Scale', 'Copy') + tuple(args.mut_levels.split('__')) use_genes = get_gene_list(min_sources=2) # load and process the -omic datasets for this cohort cdata = get_cohort_data(args.cohort, args.expr_source, lvl_list, vep_cache_dir, out_path, use_genes) with bz2.BZ2File(os.path.join(out_path, "cohort-data.p.gz"), 'w') as f: pickle.dump(cdata, f, protocol=-1) # get the maximum number of samples allowed per subgrouping, initialize # the list of enumerated subgroupings max_samps = len(cdata.get_samples()) - args.samp_cutoff use_mtypes = set() # for each gene with enough samples harbouring its point mutations in the # cohort, find the subgroupings composed of at most two branches for gene, mtree in cdata.mtrees[lvl_list]: if len(pnt_mtype.get_samples(mtree)) >= args.samp_cutoff: pnt_mtypes = { mtype for mtype in mtree['Point'].combtypes( comb_sizes=(1, 2), min_type_size=args.samp_cutoff) if (args.samp_cutoff <= len(mtype.get_samples(mtree)) <= max_samps) } # remove subgroupings that have only one child subgrouping # containing all of their samples pnt_mtypes -= { mtype1 for mtype1, mtype2 in product(pnt_mtypes, repeat=2) if mtype1 != mtype2 and mtype1.is_supertype(mtype2) and mtype1.get_samples(mtree) == mtype2.get_samples(mtree) } # remove groupings that contain all of the gene's point mutations pnt_mtypes = {MuType({('Scale', 'Point'): mtype}) for mtype in pnt_mtypes if (len(mtype.get_samples(mtree['Point'])) < len(mtree['Point'].get_samples()))} # check if this gene had at least five samples with deep gains or # deletions that weren't all already carrying point mutations copy_mtypes = { mtype for mtype in [dup_mtype, loss_mtype] if ((5 <= len(mtype.get_samples(mtree)) <= (len(cdata.get_samples()) - 5)) and not (mtype.get_samples(mtree) <= mtree['Point'].get_samples()) and not (mtree['Point'].get_samples() <= mtype.get_samples(mtree))) } # find the enumerated point mutations for this gene that can be # combined with CNAs to produce a novel set of mutated samples dyad_mtypes = { pt_mtype | cp_mtype for pt_mtype, cp_mtype in product(pnt_mtypes, copy_mtypes) if ((pt_mtype.get_samples(mtree) - cp_mtype.get_samples(mtree)) and (cp_mtype.get_samples(mtree) - pt_mtype.get_samples(mtree))) } # if we are using the base list of mutation attributes, add the # gene-wide set of all point mutations... gene_mtypes = pnt_mtypes | dyad_mtypes if args.mut_levels == 'Consequence__Exon': gene_mtypes |= {pnt_mtype} # ...as well as CNA-only subgroupings... 
gene_mtypes |= { mtype for mtype in copy_mtypes if (args.samp_cutoff <= len(mtype.get_samples(mtree)) <= max_samps) } # ...and finally the CNA + all point mutations subgroupings gene_mtypes |= { pnt_mtype | mtype for mtype in copy_mtypes if (args.samp_cutoff <= len((pnt_mtype | mtype).get_samples(mtree)) <= max_samps) } use_mtypes |= {MuType({('Gene', gene): mtype}) for mtype in gene_mtypes} # set a random seed for use in picking random subgroupings lvls_seed = np.prod([(ord(char) % 7 + 3) for i, char in enumerate(args.mut_levels) if (i % 5) == 1]) # makes sure random subgroupings are the same between different runs # of this experiment mtype_list = sorted(use_mtypes) random.seed((88701 * lvls_seed + 1313) % (2 ** 17)) random.shuffle(mtype_list) # generate random subgroupings chosen from all samples in the cohort use_mtypes |= { RandomType(size_dist=len(mtype.get_samples(cdata.mtrees[lvl_list])), seed=(lvls_seed * (i + 3751) + 19207) % (2 ** 26)) for i, (mtype, _) in enumerate(product(mtype_list, range(5))) if (mtype & copy_mtype).is_empty() } # generate random subgroupings chosen from samples mutated for each gene use_mtypes |= { RandomType( size_dist=len(mtype.get_samples(cdata.mtrees[lvl_list])), base_mtype=MuType({ ('Gene', tuple(mtype.label_iter())[0]): pnt_mtype}), seed=(lvls_seed * (i + 1021) + 7391) % (2 ** 23) ) for i, (mtype, _) in enumerate(product(mtype_list, range(5))) if ((mtype & copy_mtype).is_empty() and tuple(mtype.subtype_iter())[0][1] != pnt_mtype) } # save enumerated subgroupings and number of subgroupings to file with open(os.path.join(out_path, "muts-list.p"), 'wb') as f: pickle.dump(sorted(use_mtypes), f, protocol=-1) with open(os.path.join(out_path, "muts-count.txt"), 'w') as fl: fl.write(str(len(use_mtypes))) # get list of available cohorts for transference of classifiers coh_list = list_cohorts('Firehose', expr_dir=expr_sources['Firehose'], copy_dir=expr_sources['Firehose']) coh_list -= {args.cohort} coh_list |= {'METABRIC', 'beatAML', 'CCLE'} # initiate set of genetic expression features, reset random seed coh_dir = os.path.join(args.out_dir.split('subgrouping_test')[0], 'subgrouping_test', 'setup') use_feats = set(cdata.get_features()) random.seed() # for each cohort used for transferring... for coh in random.sample(coh_list, k=len(coh_list)): coh_base = coh.split('_')[0] # ...choose a default source of expression data if coh_base in {'METABRIC', 'CCLE'}: use_src = 'microarray' elif coh_base in {'beatAML'}: use_src = 'toil__gns' else: use_src = 'Firehose' # ...figure out where to store its pickled representation coh_tag = "cohort-data__{}__{}.p".format(use_src, coh) coh_path = os.path.join(coh_dir, coh_tag) # load and process the cohort's -omic datasets, update the list of # expression features common across all cohorts trnsf_cdata = load_cohort(coh, use_src, lvl_list, vep_cache_dir, coh_path, out_path, use_genes) use_feats &= set(trnsf_cdata.get_features()) with open(coh_path, 'wb') as f: pickle.dump(trnsf_cdata, f, protocol=-1) copy_prc = subprocess.run(['cp', coh_path, os.path.join(out_path, coh_tag)], check=True) with open(os.path.join(out_path, "feat-list.p"), 'wb') as f: pickle.dump(use_feats, f, protocol=-1)
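# -- Illustrative sketch, not part of the original module --------------------
# The random subgroupings above are made reproducible by deriving an integer
# seed from the mutation-attribute string itself, so re-running setup for the
# same `mut_levels` shuffles and samples identically.  A minimal standalone
# version of that derivation, using the same constants as the function above:

import random
import numpy as np


def derive_lvls_seed(mut_levels):
    # combine every fifth character (offset 1) of the levels string
    return int(np.prod([(ord(char) % 7 + 3)
                        for i, char in enumerate(mut_levels)
                        if (i % 5) == 1]))


lvls_seed = derive_lvls_seed('Consequence__Exon')
random.seed((88701 * lvls_seed + 1313) % (2 ** 17))

sizes = [35, 22, 61, 48]        # stand-ins for subgrouping sample counts
random.shuffle(sizes)           # same order on every run for these levels
print(lvls_seed, sizes)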
def plot_sub_comparisons(conf_vals, pheno_dict, args): fig, ax = plt.subplots(figsize=(10.3, 11)) plot_dict = dict() clr_dict = dict() plt_min = 0.57 conf_list = conf_vals[[ not isinstance(mtype, RandomType) and (tuple(mtype.subtype_iter())[0][1] & copy_mtype).is_empty() for mtype in conf_vals.index ]].apply(lambda confs: np.percentile(confs, 25)) for gene, conf_vec in conf_list.groupby( lambda mtype: tuple(mtype.label_iter())[0]): if len(conf_vec) > 1: base_mtype = MuType({('Gene', gene): pnt_mtype}) base_indx = conf_vec.index.get_loc(base_mtype) best_subtype = conf_vec[:base_indx].append( conf_vec[(base_indx + 1):]).idxmax() if conf_vec[best_subtype] > 0.6: auc_tupl = conf_vec[base_mtype], conf_vec[best_subtype] clr_dict[auc_tupl] = choose_label_colour(gene) base_size = np.mean(pheno_dict[base_mtype]) plt_size = 0.07 * base_size ** 0.5 plot_dict[auc_tupl] = [plt_size, ('', '')] plt_min = min(plt_min, conf_vec[base_indx] - 0.053, conf_vec[best_subtype] - 0.029) best_prop = np.mean(pheno_dict[best_subtype]) / base_size conf_sc = calc_conf(conf_vals[best_subtype], conf_vals[base_mtype]) if conf_sc > 0.8: plot_dict[auc_tupl][1] = gene, get_fancy_label( tuple(best_subtype.subtype_iter())[0][1], pnt_link='\n', phrase_link=' ' ) elif auc_tupl[0] > 0.7 or auc_tupl[1] > 0.7: plot_dict[auc_tupl][1] = gene, '' auc_bbox = (auc_tupl[0] - plt_size / 2, auc_tupl[1] - plt_size / 2, plt_size, plt_size) pie_ax = inset_axes( ax, width='100%', height='100%', bbox_to_anchor=auc_bbox, bbox_transform=ax.transData, axes_kwargs=dict(aspect='equal'), borderpad=0 ) pie_ax.pie(x=[best_prop, 1 - best_prop], colors=[clr_dict[auc_tupl] + (0.77,), clr_dict[auc_tupl] + (0.29,)], explode=[0.29, 0], startangle=90) plt_lims = plt_min, 1 + (1 - plt_min) / 61 ax.plot(plt_lims, [0.5, 0.5], color='black', linewidth=1.3, linestyle=':', alpha=0.71) ax.plot([0.5, 0.5], plt_lims, color='black', linewidth=1.3, linestyle=':', alpha=0.71) ax.plot(plt_lims, [1, 1], color='black', linewidth=1.9, alpha=0.89) ax.plot([1, 1], plt_lims, color='black', linewidth=1.9, alpha=0.89) ax.plot(plt_lims, plt_lims, color='#550000', linewidth=2.1, linestyle='--', alpha=0.41) ax.set_xlabel("1st quartile of down-sampled AUCs" "\nusing all point mutations", size=21, weight='semibold') ax.set_ylabel("1st quartile of down-sampled AUCs" "\nof best found subgrouping", size=21, weight='semibold') if plot_dict: lbl_pos = place_scatter_labels(plot_dict, ax, plt_lims=[plt_lims, plt_lims]) ax.set_xlim(plt_lims) ax.set_ylim(plt_lims) plt.savefig( os.path.join(plot_dir, '__'.join([args.expr_source, args.cohort]), "sub-comparisons_{}.svg".format(args.classif)), bbox_inches='tight', format='svg' ) plt.close()
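# -- Illustrative sketch, not part of the original module --------------------
# plot_sub_comparisons() calls calc_conf() to decide whether a subgrouping's
# down-sampled AUC distribution reliably beats that of the gene-wide point
# mutation task.  calc_conf() lives elsewhere in the codebase; a plausible
# form, following the same pairwise-comparison convention this project uses
# for AUCs, is the probability that a draw from the first distribution exceeds
# a draw from the second, with ties counted as one half:

import numpy as np


def calc_conf(sub_confs, base_confs):
    sub_confs, base_confs = np.asarray(sub_confs), np.asarray(base_confs)

    return (np.greater.outer(sub_confs, base_confs).mean()
            + np.equal.outer(sub_confs, base_confs).mean() / 2)


# e.g. the subgrouping wins in most of the down-sampled comparisons
print(calc_conf([0.83, 0.85, 0.88, 0.84], [0.78, 0.80, 0.84, 0.79]))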
def main(): parser = argparse.ArgumentParser() parser.add_argument( '--tune_splits', type=int, default=4, help='how many training cohort splits to use for tuning') parser.add_argument( '--test_count', type=int, default=24, help='how many hyper-parameter values to test in each tuning split') parser.add_argument( '--infer_splits', type=int, default=24, help='how many cohort splits to use for inference bootstrapping') parser.add_argument( '--infer_folds', type=int, default=4, help=('how many parts to split the cohort into in each inference ' 'cross-validation run')) parser.add_argument( '--parallel_jobs', type=int, default=12, help='how many parallel CPUs to allocate the tuning tests across') parser.add_argument('--cv_id', type=int, default=0) parser.add_argument('--verbose', '-v', action='store_true', help='turns on diagnostic messages') args = parser.parse_args() out_dir = os.path.join(base_dir, 'output', 'gene_models') os.makedirs(out_dir, exist_ok=True) out_file = os.path.join(out_dir, 'cv-{}.p'.format(args.cv_id)) #test_dir = os.path.join(base_dir, '..', 'mut_baseline', 'output', # 'Firehose', 'PAAD__samps-25') #test_models = os.listdir(test_dir) #test_dict = dict() #for test_model in test_models: # test_fls = [ # test_fl for test_fl in os.listdir( # os.path.join(test_dir, test_model)) # if 'out__' in test_fl # ] # log_fls = [ # log_fl for log_fl in os.listdir(os.path.join( # test_dir, test_model, 'slurm')) # if 'fit-' in log_fl # ] # if len(log_fls) > 0 and len(log_fls) == (len(test_fls) * 2): # test_dict[test_model] = load_baseline('Firehose', 'PAAD', 25, # test_model)[0] # log into Synapse using locally stored credentials syn = synapseclient.Synapse() syn.cache.cache_root_dir = syn_root syn.login() mut_clf = StanPipe() test_genes = ['KRAS', 'SMAD4', 'TP53'] cdata = PatientMutationCohort(patient_expr=load_bcc_expression(bcc_dir), patient_muts=None, tcga_cohort='PAAD', mut_genes=test_genes, mut_levels=['Gene', 'Form'], expr_source='toil', var_source='mc3', copy_source='Firehose', annot_file=annot_file, expr_dir=toil_dir, copy_dir=copy_dir, cv_seed=(args.cv_id * 59) + 121, cv_prop=1.0, collapse_txs=True, syn=syn) tuned_params = {gene: None for gene in test_genes} infer_mats = {gene: None for gene in test_genes} for gene in test_genes: base_mtype = MuType({('Gene', gene): None}) mut_clf.tune_coh(cdata, base_mtype, exclude_genes={gene}, exclude_samps=cdata.patient_samps, tune_splits=args.tune_splits, test_count=args.test_count, parallel_jobs=args.parallel_jobs) clf_params = mut_clf.get_params() tuned_params[gene] = { par: clf_params[par] for par, _ in StanPipe.tune_priors } print(tuned_params) infer_mats[gene] = mut_clf.infer_coh(cdata, base_mtype, force_test_samps=bcc_samps, exclude_genes={gene}, infer_splits=args.infer_splits, infer_folds=args.infer_folds) pickle.dump({ 'Infer': infer_mats, 'Tune': tuned_params }, open(out_file, 'wb'))
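# -- Illustrative sketch, not part of the original module --------------------
# The TCGA variant calls used above are fetched through Synapse, so the script
# needs an authenticated client before the cohort object is built.  A minimal
# standalone login, assuming credentials are already stored locally (for
# instance in a ~/.synapseConfig file) and using a hypothetical cache path:

import synapseclient

syn = synapseclient.Synapse()
syn.cache.cache_root_dir = '/path/to/synapse_cache'   # hypothetical location
syn.login()   # with no arguments, falls back to the locally stored credentials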
def main():
    data_dir = os.path.join(base_dir, "resources")
    expr_data = pd.read_csv(os.path.join(data_dir, "expr.txt.gz"),
                            sep='\t', index_col=0)
    mut_data = pd.read_csv(os.path.join(data_dir, "variants.txt.gz"),
                           sep='\t', index_col=0)

    cdata = BaseMutationCohort(expr_data, mut_data, mut_levels=[['Gene']],
                               mut_genes=['GATA3'], cv_seed=101,
                               test_prop=0.3)
    test_mtype = MuType({('Gene', 'GATA3'): None})

    # tune a Lasso classifier and check that the hyper-parameter values it was
    # left with are the best-scoring ones found during tuning
    clf = Lasso()
    clf, cvs = clf.tune_coh(cdata, test_mtype,
                            test_count=16, tune_splits=4, parallel_jobs=1)

    best_indx = np.argmax(cvs['mean_test_score'] - cvs['std_test_score'])
    for param, _ in clf.tune_priors:
        assert clf.get_params()[param] == cvs['params'][best_indx][param]

    # fit and predict using only every third expression feature...
    use_feats = cdata.get_features()[::3]
    clf.fit_coh(cdata, test_mtype, include_feats=use_feats)
    train_preds = clf.predict_train(cdata, include_feats=use_feats)
    test_preds = clf.predict_test(cdata, include_feats=use_feats)
    test_preds = clf.predict_test(cdata, include_feats=use_feats,
                                  lbl_type='prob')

    # ...and then using the full feature set
    clf.fit_coh(cdata, test_mtype)
    train_preds = clf.predict_train(cdata)
    test_preds = clf.predict_test(cdata, lbl_type='raw')

    tuned_coefs = np.floor(expr_data.shape[1]
                           * (clf.named_steps['feat'].mean_perc / 100))
    assert tuned_coefs == len(clf.named_steps['fit'].coef_[0]), (
        "Tuned feature selection step does not match number of features that "
        "were fit over!")

    assert len(clf.get_coef()) <= len(clf.expr_genes), (
        "Pipeline produced more gene coefficients than genes "
        "it was originally given!")

    train_auc = clf.eval_coh(cdata, test_mtype, use_train=True)
    print("Lasso model training AUC: {:.3f}".format(train_auc))
    assert train_auc >= 0.6, (
        "Lasso model did not obtain a training AUC of at least 0.6!")

    test_auc = clf.eval_coh(cdata, test_mtype, use_train=False)
    print("Lasso model testing AUC: {:.3f}".format(test_auc))
    assert test_auc >= 0.6, (
        "Lasso model did not obtain a testing AUC of at least 0.6!")

    infer_mat = clf.infer_coh(cdata, test_mtype,
                              infer_splits=8, infer_folds=4, parallel_jobs=1)
    assert len(infer_mat) == 143, (
        "Pipeline inference did not produce scores for each sample!")

    # repeat the tuning/fitting/inference cycle with a second classifier
    clf = RidgeWhite()
    clf.tune_coh(cdata, test_mtype,
                 test_count=4, tune_splits=2, parallel_jobs=1)
    clf.fit_coh(cdata, test_mtype)

    infer_mat = clf.infer_coh(cdata, test_mtype,
                              infer_splits=12, infer_folds=4, parallel_jobs=2)
    assert len(infer_mat) == len(cdata.get_train_samples())
    assert {len(vals) for vals in infer_mat} == {3}
    assert {hasattr(v, '__len__') for vals in infer_mat for v in vals} == {False}
    assert not (0 <= np.concatenate(infer_mat)).all()

    # ...and with a kernel SVM classifier
    clf = SVCrbf()
    clf.tune_coh(cdata, test_mtype,
                 test_count=4, tune_splits=2, parallel_jobs=1)
    clf.fit_coh(cdata, test_mtype)

    print("All pipeline tests passed successfully!")
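# -- Illustrative sketch, not part of the original module --------------------
# The first assertion above checks that tuning kept the hyper-parameter set
# whose mean cross-validation score minus its variability was highest.  The
# selection rule itself, applied to toy results shaped like the `cvs` dict
# returned by tune_coh():

import numpy as np

cvs = {
    'mean_test_score': np.array([0.71, 0.78, 0.76]),
    'std_test_score': np.array([0.02, 0.09, 0.01]),
    'params': [{'alpha': 1.0}, {'alpha': 0.1}, {'alpha': 0.01}],
}

# penalize unstable candidates: 0.76 - 0.01 beats 0.78 - 0.09
best_indx = np.argmax(cvs['mean_test_score'] - cvs['std_test_score'])
print(cvs['params'][best_indx])       # -> {'alpha': 0.01}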
def plot_dyad_comparisons(auc_vals, pheno_dict, conf_vals, use_src, use_coh, args): fig, (gain_ax, loss_ax) = plt.subplots(figsize=(17, 8), nrows=1, ncols=2) pnt_aucs = auc_vals[[ not isinstance(mtype, (Mcomb, ExMcomb)) and (tuple(mtype.subtype_iter())[0][1] & copy_mtype).is_empty() for mtype in auc_vals.index ]] plot_df = pd.DataFrame(index=pnt_aucs.index, columns=pd.MultiIndex.from_product( [['gain', 'loss'], ['all', 'deep']]), dtype=float) for pnt_type, (copy_indx, copy_type) in product( pnt_aucs.index, zip(plot_df.columns, [gains_mtype, dup_mtype, dels_mtype, loss_mtype])): dyad_type = MuType({('Gene', args.gene): copy_type}) | pnt_type if dyad_type in auc_vals.index: plot_df.loc[pnt_type, copy_indx] = auc_vals[dyad_type] plt_min = 0.83 for ax, copy_lbl in zip([gain_ax, loss_ax], ['gain', 'loss']): for dpth_lbl in ['all', 'deep']: copy_aucs = plot_df[copy_lbl, dpth_lbl] copy_aucs = copy_aucs[~copy_aucs.isnull()] for pnt_type, copy_auc in copy_aucs.iteritems(): plt_min = min(plt_min, pnt_aucs[pnt_type] - 0.03, copy_auc - 0.03) mtype_sz = 1003 * np.mean(pheno_dict[pnt_type]) plt_clr = choose_subtype_colour( tuple(pnt_type.subtype_iter())[0][1]) if dpth_lbl == 'all': dpth_clr = plt_clr edg_lw = 0 else: dpth_clr = 'none' edg_lw = mtype_sz**0.5 / 4.7 ax.scatter(pnt_aucs[pnt_type], copy_auc, facecolor=dpth_clr, s=mtype_sz, alpha=0.21, edgecolor=plt_clr, linewidths=edg_lw) for copy_lbl, copy_type, copy_ax, copy_lw in zip( ['All Gains', 'Deep Gains', 'All Losses', 'Deep Losses'], [gains_mtype, dup_mtype, dels_mtype, loss_mtype], [gain_ax, gain_ax, loss_ax, loss_ax], [3.1, 4.3, 3.1, 4.3]): gene_copy = MuType({('Gene', args.gene): copy_type}) if gene_copy in auc_vals.index: copy_auc = auc_vals[gene_copy] copy_clr = choose_subtype_colour(copy_type) use_lbl = ' '.join( [copy_lbl.split(' ')[0], args.gene, copy_lbl.split(' ')[1]]) copy_ax.text(max(plt_min, 0.51), copy_auc + (1 - copy_auc) / 173, use_lbl, c=copy_clr, size=13, ha='left', va='bottom') copy_ax.plot([plt_min, 1], [copy_auc, copy_auc], color=copy_clr, linewidth=copy_lw, linestyle=':', alpha=0.83) plt_lims = plt_min, 1 + (1 - plt_min) / 131 for ax in (gain_ax, loss_ax): ax.grid(linewidth=0.83, alpha=0.41) ax.set_xlim(plt_lims) ax.set_ylim(plt_lims) ax.plot(plt_lims, [0.5, 0.5], color='black', linewidth=1.1, linestyle=':', alpha=0.71) ax.plot([0.5, 0.5], plt_lims, color='black', linewidth=1.1, linestyle=':', alpha=0.71) ax.plot(plt_lims, [1, 1], color='black', linewidth=1.7, alpha=0.89) ax.plot([1, 1], plt_lims, color='black', linewidth=1.7, alpha=0.89) ax.plot(plt_lims, plt_lims, color='#550000', linewidth=1.9, linestyle='--', alpha=0.41) ax.set_xlabel("Accuracy of Subgrouping Classifier", size=23, weight='bold') gain_ax.set_ylabel("Accuracy of\n(Subgrouping or CNAs) Classifier", size=23, weight='bold') gain_ax.set_title("Gain CNAs", size=27, weight='bold') loss_ax.set_title("Loss CNAs", size=27, weight='bold') plt.tight_layout(w_pad=3.1) plt.savefig(os.path.join( plot_dir, args.gene, "{}__dyad-comparisons_{}_{}.svg".format(use_coh, args.classif, use_src)), bbox_inches='tight', format='svg') plt.close()
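# -- Illustrative sketch, not part of the original module --------------------
# The dyad AUCs above are collected into a frame whose columns pair a CNA
# direction with a depth threshold, leaving NaN wherever the corresponding
# (subgrouping | CNA) task was never enumerated; those cells are simply
# skipped when plotting.  A toy version with hypothetical subgrouping labels:

import numpy as np
import pandas as pd

pnt_index = ['TP53: exon 5 muts', 'TP53: missense muts']   # hypothetical
plot_df = pd.DataFrame(
    index=pnt_index, dtype=float,
    columns=pd.MultiIndex.from_product([['gain', 'loss'], ['all', 'deep']]))

plot_df.loc['TP53: exon 5 muts', ('loss', 'deep')] = 0.87
plot_df.loc['TP53: missense muts', ('loss', 'all')] = 0.81

for copy_lbl in ['gain', 'loss']:
    for dpth_lbl in ['all', 'deep']:
        copy_aucs = plot_df[copy_lbl, dpth_lbl]
        print(copy_lbl, dpth_lbl, copy_aucs.dropna().to_dict())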