Esempio n. 1
0
def get_separation(iso_df, args, cdata):
    """Measure how well isolated subtype scores separate from wild-types.

    Args:
        iso_df: DataFrame with one row per subtype; the row labels are
            passed to `cdata.train_pheno`. Each cell must hold an
            array-like of scores, since the selected cells are joined with
            `np.concatenate`.
            NOTE(review): the columns are assumed to be ordered like the
            output of `cdata.train_pheno` -- confirm against the caller.
        args: object with a `genes` attribute listing the genes that make
            up the module under study.
        cdata: cohort object exposing `train_pheno`, which is assumed to
            return a boolean status per sample (it is negated and combined
            with `&` below).

    Returns:
        A tuple `(auc_list, sep_dict, prop_list)` where
        * `auc_list` is a Series indexed like `iso_df` holding the share of
          (wild-type score, subtype score) pairs in which the wild-type
          score is strictly smaller; ties are not counted,
        * `sep_dict` maps each gene to a Series with the same statistic
          computed for the samples mutated in that gene but not carrying
          the subtype; subtypes without such samples get no entry,
        * `prop_list` is a Series with the number of subtype samples
          divided by the number of samples mutated in any module gene.
    """
    base_phenos = {
        gene: np.array(cdata.train_pheno(MuType({('Gene', gene): None})))
        for gene in args.genes
        }
    module_pheno = np.array(
        cdata.train_pheno(MuType({('Gene', tuple(args.genes)): None})))

    # the `np.float` alias was deprecated in NumPy 1.20 and removed in 1.24;
    # the builtin `float` selects the very same float64 dtype
    auc_list = pd.Series(index=iso_df.index, dtype=float)
    sep_dict = {gene: pd.Series(dtype=float) for gene in args.genes}
    prop_list = pd.Series(index=iso_df.index, dtype=float)

    for mtype, iso_vals in iso_df.iterrows():
        cur_pheno = np.array(cdata.train_pheno(mtype))

        # scores of samples with no mutation in any of the module genes
        # versus scores of samples carrying the current subtype
        none_vals = np.concatenate(iso_vals[~module_pheno].values)
        cur_vals = np.concatenate(iso_vals[cur_pheno].values)
        auc_list[mtype] = np.less.outer(none_vals, cur_vals).mean()
        prop_list[mtype] = np.sum(cur_pheno) / np.sum(module_pheno)

        for gene, base_pheno in base_phenos.items():
            # samples mutated in this gene that do not carry the subtype
            rest_stat = base_pheno & ~cur_pheno

            if np.any(rest_stat):
                rest_vals = np.concatenate(iso_vals[rest_stat].values)
                sep_dict[gene][mtype] = np.less.outer(
                    none_vals, rest_vals).mean()

    return auc_list, sep_dict, prop_list
Esempio n. 2
0
def get_separation(iso_df, args, cdata):
    """Compute tie-adjusted separation statistics for isolated subtypes.

    Args:
        iso_df: DataFrame with one row per subtype; the row labels are
            passed to `cdata.train_pheno` and must expose `get_samples`.
            Each cell must hold an array-like of scores, since the selected
            cells are joined with `np.concatenate`.
            NOTE(review): the columns are assumed to be ordered like the
            output of `cdata.train_pheno` -- confirm against the caller.
        args: not used by this function; kept for interface compatibility.
        cdata: cohort object exposing `train_pheno` and `train_mut`.

    Returns:
        A tuple `(auc_list, sep_list, prop_list)` of Series indexed by the
        subtypes that were kept:
        * `auc_list`: share of (wild-type, subtype) score pairs where the
          wild-type score is smaller, with ties counted as one half,
        * `sep_list`: the same statistic for the mutated samples that do
          not carry the subtype,
        * `prop_list`: subtype sample count over the mutated sample count.
    """
    base_pheno = np.array(cdata.train_pheno(MuType(cdata.train_mut.allkey())))
    base_count = np.sum(base_pheno)

    # get the mutation status of the samples in the cohort for each of the
    # tested subtypes, remove subtypes that include all mutated samples
    mtype_phenos = {
        mtype: np.array(cdata.train_pheno(mtype))
        for mtype in iso_df.index
        if len(mtype.get_samples(cdata.train_mut)) < base_count
    }
    use_mtypes = list(mtype_phenos)

    # the `np.float` alias was deprecated in NumPy 1.20 and removed in 1.24;
    # the builtin `float` selects the very same float64 dtype
    prop_list = pd.Series(index=use_mtypes, dtype=float)
    auc_list = pd.Series(index=use_mtypes, dtype=float)
    sep_list = pd.Series(index=use_mtypes, dtype=float)

    for mtype, cur_pheno in mtype_phenos.items():
        prop_list[mtype] = np.sum(cur_pheno) / base_count

        none_vals = np.concatenate(iso_df.loc[mtype, ~base_pheno].values)
        cur_vals = np.concatenate(iso_df.loc[mtype, cur_pheno].values)
        rest_vals = np.concatenate(iso_df.loc[mtype,
                                              base_pheno & ~cur_pheno].values)

        # strict wins count fully, ties count as one half
        auc_list[mtype] = np.less.outer(none_vals, cur_vals).mean()
        auc_list[mtype] += np.equal.outer(none_vals, cur_vals).mean() / 2
        sep_list[mtype] = np.less.outer(none_vals, rest_vals).mean()
        sep_list[mtype] += np.equal.outer(none_vals, rest_vals).mean() / 2

    return auc_list, sep_list, prop_list
Esempio n. 3
0
def plot_subtype_stability(out_data, args, cdata, use_levels):
    """Plot score mean against score spread for wild-types and subtypes.

    The per-sample mean and standard deviation of `out_data` are taken
    along axis 1. A density contour of these two quantities is drawn for
    the samples without any mutation and for every subtype found at
    `use_levels` that is carried by at least fifteen samples. The figure is
    written as a PNG file under `plot_dir`.

    Args:
        out_data: array-like of scores reduced along axis 1.
            NOTE(review): rows are assumed to be samples ordered like the
            output of `cdata.train_pheno` -- confirm.
        args: object providing `model_name`, `solve_method`, `cohort` and
            `gene`, all of which are used in the output file name.
        cdata: cohort object exposing `train_mut` and `train_pheno`.
        use_levels: sequence of strings handed to `branchtypes` and joined
            into the file name.
    """
    fig, ax = plt.subplots(figsize=(13, 8))

    # collect the sample status of each sufficiently frequent subtype
    use_phenos = dict()
    for mtype in cdata.train_mut.branchtypes(sub_levels=use_levels):
        pheno = np.array(cdata.train_pheno(mtype))

        if np.sum(pheno) >= 15:
            use_phenos[mtype] = pheno

    subtype_cmap = sns.hls_palette(len(use_phenos), l=.57, s=.89)

    out_means = np.mean(out_data, axis=1)
    out_sds = np.std(out_data, axis=1)
    x_bound = np.max(np.absolute(out_means)) * 1.25

    # wild-type samples are those carrying none of the cohort's mutations
    wt_pheno = ~np.array(
        cdata.train_pheno(MuType(cdata.train_mut.allkey())))
    wt_means = out_means[wt_pheno]
    wt_sds = out_sds[wt_pheno]

    ax = sns.kdeplot(wt_means, wt_sds,
                     cmap=sns.light_palette('0.5', as_cmap=True),
                     linewidths=1.7, alpha=0.7, gridsize=500, n_levels=32)
    ax.text(np.min(wt_means), np.min(wt_sds),
            'Wild-Type', size=15, color='0.5')

    # one coarse contour set per subtype, each in its own colour
    for pheno, use_clr in zip(use_phenos.values(), subtype_cmap):
        sns.kdeplot(out_means[pheno], out_sds[pheno],
                    cmap=sns.light_palette(use_clr, as_cmap=True),
                    linewidths=3.6, alpha=0.4, gridsize=500, n_levels=3)

    plt.xlim(-x_bound, x_bound)
    plt.ylim(0, np.max(out_sds) * 1.03)
    plt.xlabel('Mutation Score CV Mean', fontsize=17, weight='semibold')
    plt.ylabel('Mutation Score CV SD', fontsize=17, weight='semibold')

    legend_lines = [Line2D([0], [0], color=use_clr, lw=6.8)
                    for use_clr in subtype_cmap]
    legend_lbls = [str(mtype) for mtype in use_phenos]
    plt.legend(legend_lines, legend_lbls,
               fontsize=11, ncol=2, frameon=False)

    out_name = 'stability__{}-{}_{}-{}__levels_{}.png'.format(
        args.model_name, args.solve_method,
        args.cohort, args.gene, '__'.join(use_levels)
        )
    fig.savefig(os.path.join(plot_dir, out_name),
                dpi=250, bbox_inches='tight')

    plt.close()
Esempio n. 4
0
def plot_subtype_violins(out_data, args, cdata, use_levels):
    """Prepare the figure and sample groups for a subtype violin plot.

    Collects the sample status of every subtype found at `use_levels` that
    is carried by at least five samples, picks one colour per subtype,
    opens a figure whose width grows with the number of subtypes, and
    computes the status of samples carrying any mutation.

    NOTE(review): this excerpt stops after computing `all_pheno`; nothing
    is drawn, saved or returned in the visible code, and `out_data` and
    `args` are not used here.
    """
    use_phenos = dict()
    for mtype in cdata.train_mut.branchtypes(sub_levels=use_levels):
        pheno = np.array(cdata.train_pheno(mtype))

        if np.sum(pheno) >= 5:
            use_phenos[mtype] = pheno

    subtype_cmap = sns.hls_palette(len(use_phenos), l=.57, s=.89)
    fig_width = 1.55 + len(use_phenos) * 0.68
    fig, ax = plt.subplots(figsize=(fig_width, 8))

    all_mtype = MuType(cdata.train_mut.allkey())
    all_pheno = np.array(cdata.train_pheno(all_mtype))
Esempio n. 5
0
def plot_label_stability(out_data, args, cdata):
    """Plot score mean against score spread for wild-types and mutants.

    The per-sample mean and standard deviation of `out_data` are taken
    along axis 1 and drawn as two density contour sets: one for samples
    without a mutation of `args.gene` and one for samples with it. Each set
    gets a text label, and the figure is written as a PNG under `plot_dir`.

    Args:
        out_data: array-like of scores reduced along axis 1.
            NOTE(review): rows are assumed to be samples ordered like the
            output of `cdata.train_pheno` -- confirm.
        args: object providing `gene`, `model_name`, `solve_method` and
            `cohort`.
        cdata: cohort object exposing `train_pheno`.
    """
    fig, ax = plt.subplots(figsize=(13, 8))

    wt_cmap = sns.light_palette(wt_clr, as_cmap=True)
    mut_cmap = sns.light_palette(mut_clr, as_cmap=True)

    mtype_stat = np.array(
        cdata.train_pheno(MuType({('Gene', args.gene): None})))
    out_means = np.mean(out_data, axis=1)
    out_sds = np.std(out_data, axis=1)
    x_bound = np.max(np.absolute(out_means)) * 1.1

    # (sample mask, colour map, label text, x-percentile of label, colour)
    plot_groups = (
        (~mtype_stat, wt_cmap, "Wild-Type", 39, wt_clr),
        (mtype_stat, mut_cmap, "{} Mutant".format(args.gene), 61, mut_clr),
        )

    for stat, use_cmap, use_lbl, lbl_qnt, use_clr in plot_groups:
        grp_means = out_means[stat]
        grp_sds = out_sds[stat]

        ax = sns.kdeplot(grp_means, grp_sds, cmap=use_cmap,
                         linewidths=2.1, alpha=0.8,
                         gridsize=1000, n_levels=34)

        # the label sits near the top of the group's spread
        ax.text(np.percentile(grp_means, q=lbl_qnt),
                np.percentile(grp_sds, q=99.3),
                use_lbl, size=17, color=use_clr)

    plt.xlim(-x_bound, x_bound)
    plt.ylim(0, np.max(out_sds) * 1.05)
    plt.xlabel('Mutation Score CV Mean', fontsize=19, weight='semibold')
    plt.ylabel('Mutation Score CV SD', fontsize=19, weight='semibold')

    out_name = 'stability__{}-{}_{}-{}.png'.format(
        args.model_name, args.solve_method, args.cohort, args.gene)
    fig.savefig(os.path.join(plot_dir, out_name),
                dpi=250, bbox_inches='tight')

    plt.close()
Esempio n. 6
0
def plot_gene_clustering(trans_dict, use_gene, cdata, use_comps=(0, 1)):
    """Scatter two components of each transformed dataset by gene status.

    One panel is drawn per transformed dataset. Samples without a mutation
    of `use_gene` are drawn as small grey points and mutated samples as
    larger red-tinted points. The figure is written as a PNG under
    `plot_dir`.

    Args:
        trans_dict: sized iterable of `(label, array)` pairs (despite the
            name it is unpacked as pairs, not as a mapping); each array is
            indexed as `array[:, use_comps]`.
            NOTE(review): rows are assumed to be samples ordered like the
            output of `cdata.train_pheno` -- confirm.
        use_gene: the gene whose mutation status colours the points.
        cdata: cohort object exposing `train_pheno`.
        use_comps: the pair of component indices to plot.
    """
    # `squeeze=False` always yields a 2-D array of axes, so a single
    # transform no longer hands back a bare Axes object that can neither
    # be reshaped nor indexed
    fig, axarr = plt.subplots(nrows=1, ncols=len(trans_dict),
                              figsize=(21, 7), squeeze=False)
    axarr = axarr[0]

    # extracts the given pair of components from each transformed dataset
    use_comps = np.array(use_comps)
    trans_dict = [(trs_lbl, trans_expr[:, use_comps])
                  for trs_lbl, trans_expr in trans_dict]

    # turn off the axis tick labels
    for ax in axarr:
        ax.set_xticklabels([])
        ax.set_yticklabels([])

    base_mtype = MuType({('Gene', use_gene): None})
    base_pheno = np.array(cdata.train_pheno(base_mtype))
    mut_clr = sns.light_palette((1 / 3, 0, 0),
                                input="rgb",
                                n_colors=5,
                                reverse=True)[1]

    for ax, (trs_lbl, trans_expr) in zip(axarr, trans_dict):
        ax.set_title(trs_lbl, size=24, weight='semibold')

        # plot the wild-type points
        ax.scatter(trans_expr[~base_pheno, 0],
                   trans_expr[~base_pheno, 1],
                   marker='o',
                   s=6,
                   c='0.5',
                   alpha=0.15,
                   edgecolor='none')

        # plot the mutated points
        ax.scatter(trans_expr[base_pheno, 0],
                   trans_expr[base_pheno, 1],
                   marker='o',
                   s=10,
                   c=mut_clr,
                   alpha=0.3,
                   edgecolor='none')

    fig.tight_layout(w_pad=1.1)
    fig.savefig(os.path.join(
        plot_dir,
        "clustering-gene_{}__comps_{}-{}.png".format(use_gene, use_comps[0],
                                                     use_comps[1])),
                dpi=300,
                bbox_inches='tight')

    plt.close()
Esempio n. 7
0
def get_aucs(iso_df, args, cdata):
    """Compute copy-number and point-mutation AUCs for each cutoff pair.

    For every `(low_ctf, high_ctf)` pair in the index of `iso_df`, the
    scores of samples with only a copy number alteration and the scores of
    samples with only a point mutation are each compared against the scores
    of wild-type samples. Each AUC is the share of score pairs in which the
    altered sample scores strictly higher; ties are not counted.

    Args:
        iso_df: DataFrame indexed by `(low_ctf, high_ctf)` pairs; each cell
            must hold an array-like of scores, since the selected cells are
            joined with `np.concatenate`.
            NOTE(review): the columns are assumed to be ordered like the
            output of `cdata.train_pheno` -- confirm against the caller.
        args: object providing the `gene` under study.
        cdata: cohort object exposing `train_pheno`, called here both with
            a `MuType` and with copy number query dictionaries.

    Returns:
        A tuple `(loss_df, gain_df)` of DataFrames with columns 'CNA' and
        'Mut'. `loss_df` holds the cutoff pairs with a negative lower
        cutoff and `gain_df` holds all the others.
    """
    base_pheno = np.array(cdata.train_pheno(
        MuType({('Gene', args.gene): None})))

    # The result dictionaries are filled inside the loop, using the same
    # predicate that picks the loss or gain branch. Seeding them up front
    # with `ctf[0] > 0` left cutoff pairs whose lower cutoff is not
    # strictly positive without a slot and raised a KeyError on assignment.
    loss_aucs = dict()
    gain_aucs = dict()

    for low_ctf, high_ctf in iso_df.index:
        use_vals = iso_df.loc[(low_ctf, high_ctf), :].values
        is_loss = low_ctf < 0

        if is_loss:
            cna_type, cna_ctf = 'Loss', low_ctf
            wt_range = (high_ctf, -high_ctf)

        else:
            cna_type, cna_ctf = 'Gain', high_ctf
            wt_range = (-low_ctf, low_ctf)

        cna_pheno = np.array(cdata.train_pheno(
            {'Gene': args.gene, 'CNA': cna_type, 'Cutoff': cna_ctf}))

        # wild-types have no point mutation and a copy number value inside
        # the symmetric range around zero set by the inner cutoff
        wt_stat = ~base_pheno & np.array(cdata.train_pheno(
            {'Gene': args.gene, 'CNA': 'Range', 'Cutoff': wt_range}))

        wt_vals = np.concatenate(use_vals[wt_stat])
        cna_vals = np.concatenate(use_vals[cna_pheno & ~base_pheno])
        mut_vals = np.concatenate(use_vals[~cna_pheno & base_pheno])

        auc_dict = loss_aucs if is_loss else gain_aucs
        auc_dict[low_ctf, high_ctf] = {
            'CNA': np.greater.outer(cna_vals, wt_vals).mean(),
            'Mut': np.greater.outer(mut_vals, wt_vals).mean(),
            }

    loss_df = pd.DataFrame.from_dict(loss_aucs, orient='index')
    gain_df = pd.DataFrame.from_dict(gain_aucs, orient='index')

    return loss_df, gain_df
Esempio n. 8
0
def get_similarities(iso_df, base_genes, cdata):
    """Compute pairwise similarities between isolated subtype classifiers.

    For each ordered pair `(cur_mtype, other_mtype)`, the scores in the
    `cur_mtype` row of `iso_df` are used to compare the samples of the
    other subtype against the wild-type samples and against the samples of
    the current subtype.

    Args:
        iso_df: DataFrame with one row per subtype; the row labels are
            passed to `cdata.train_pheno`. Each cell must hold an
            array-like of scores, since the selected cells are joined with
            `np.concatenate`.
            NOTE(review): the columns are assumed to be ordered like the
            output of `cdata.train_pheno` -- confirm against the caller.
        base_genes: the genes whose mutated samples are excluded from the
            wild-type group.
        cdata: cohort object exposing `train_pheno`.

    Returns:
        A tuple `(simil_df, auc_list)`. `simil_df` is a square DataFrame
        whose diagonal is 1.0 and whose off-diagonal entries are
        `(P(other > cur) - P(none > other)) / (0.5 - P(none > cur))`.
        `auc_list` is a Series with the share of (wild-type, subtype) score
        pairs in which the wild-type score is strictly smaller.
    """
    base_pheno = np.array(
        cdata.train_pheno(MuType({('Gene', tuple(base_genes)): None})))

    # the `np.float` alias was deprecated in NumPy 1.20 and removed in 1.24;
    # the builtin `float` selects the very same float64 dtype
    simil_df = pd.DataFrame(index=iso_df.index,
                            columns=iso_df.index,
                            dtype=float)
    auc_list = pd.Series(index=iso_df.index, dtype=float)

    # look up each subtype's sample status once instead of once per pair
    phenos = {mtype: np.array(cdata.train_pheno(mtype))
              for mtype in iso_df.index}

    for cur_mtype in iso_df.index:
        cur_pheno = phenos[cur_mtype]
        none_vals = np.concatenate(iso_df.loc[cur_mtype, ~base_pheno].values)

        for other_mtype in iso_df.index:
            if cur_mtype == other_mtype:
                simil_df.loc[cur_mtype, other_mtype] = 1.0
                cur_vals = np.concatenate(
                    iso_df.loc[cur_mtype, cur_pheno].values)
                auc_list[cur_mtype] = np.less.outer(
                    none_vals, cur_vals).mean()
                continue

            other_pheno = phenos[other_mtype]
            cur_only = cur_pheno & ~other_pheno
            other_only = ~cur_pheno & other_pheno

            # when one subtype's samples are contained in the other's, the
            # contained subtype is represented by all of its samples;
            # otherwise each subtype is represented by its exclusive samples
            if not np.any(other_only):
                cur_stat, other_stat = cur_only, other_pheno

            elif not np.any(cur_only):
                cur_stat, other_stat = cur_pheno, other_only

            else:
                cur_stat, other_stat = cur_only, other_only

            cur_vals = np.concatenate(iso_df.loc[cur_mtype, cur_stat].values)
            other_vals = np.concatenate(
                iso_df.loc[cur_mtype, other_stat].values)

            other_none_prob = np.greater.outer(none_vals, other_vals).mean()
            other_cur_prob = np.greater.outer(other_vals, cur_vals).mean()
            cur_none_prob = np.greater.outer(none_vals, cur_vals).mean()

            simil_df.loc[cur_mtype,
                         other_mtype] = ((other_cur_prob - other_none_prob) /
                                         (0.5 - cur_none_prob))

    return simil_df, auc_list
Esempio n. 9
0
def plot_subtype_violins(out_data, args, cdata, use_levels):
    """Draw violins of median scores for wild-types and each subtype.

    The per-sample median of `out_data` is taken along axis 1. One violin
    is drawn for the samples without any mutation and one for every subtype
    found at `use_levels` that is carried by at least five samples. Violins
    are ordered by the mean of their medians, and the figure is written as
    a PNG under `plot_dir`.

    Args:
        out_data: array-like of scores reduced along axis 1.
            NOTE(review): rows are assumed to be samples ordered like the
            output of `cdata.train_pheno` -- confirm.
        args: object providing `model_name`, `solve_method`, `cohort` and
            `gene`, all of which are used in the output file name.
        cdata: cohort object exposing `train_mut` and `train_pheno`.
        use_levels: sequence of strings handed to `branchtypes` and joined
            into the file name.
    """
    use_phenos = {
        mtype: np.array(cdata.train_pheno(mtype))
        for mtype in cdata.train_mut.branchtypes(sub_levels=use_levels)
        }

    use_phenos = {mtype: use_pheno for mtype, use_pheno in use_phenos.items()
                  if np.sum(use_pheno) >= 5}
    subtype_cmap = sns.hls_palette(len(use_phenos), l=.57, s=.89)
    fig, ax = plt.subplots(figsize=(1.55 + len(use_phenos) * 0.68, 8))

    all_mtype = MuType(cdata.train_mut.allkey())
    all_pheno = np.array(cdata.train_pheno(all_mtype))

    # Every group carries its own colour, so that sorting the groups by
    # score keeps grey on the wild-types and each hue on its own subtype.
    # A palette fixed in construction order is applied to the violins in
    # sorted order and so colours the wrong groups.
    out_meds = np.percentile(out_data, q=50, axis=1)
    mtype_meds = [('Wild-Type ({} samples)'.format(np.sum(~all_pheno)),
                   out_meds[~all_pheno], '0.5')]

    mtype_meds += [('{} ({} samples)'.format(mtype, np.sum(use_pheno)),
                    out_meds[use_pheno], use_clr)
                   for (mtype, use_pheno), use_clr in zip(use_phenos.items(),
                                                          subtype_cmap)]

    mtype_meds = sorted(mtype_meds, key=lambda x: np.mean(x[1]))
    med_df = pd.concat([pd.DataFrame({'Subtype': lbl, 'Score': meds})
                        for lbl, meds, _ in mtype_meds])
    ax = sns.violinplot(data=med_df, x='Subtype', y='Score',
                        palette=[use_clr for _, _, use_clr in mtype_meds],
                        width=0.96)

    plt.xlabel('Mutation Type', size=18, weight='semibold')
    plt.ylabel('Inferred Mutation Score', size=18, weight='semibold')
    plt.xticks(rotation=45, ha='right', size=10)
    plt.yticks(size=15)

    fig.savefig(
        os.path.join(plot_dir,
                     'violins__{}-{}_{}-{}__levels_{}.png'.format(
                         args.model_name, args.solve_method,
                         args.cohort, args.gene, '__'.join(use_levels)
                        )),
        dpi=250, bbox_inches='tight'
        )

    plt.close()
Esempio n. 10
0
def main():
    """Tune, fit and test a mutation classifier for this task's genes.

    Reads the gene list prepared for the given expression source, cohort
    and sample cutoff, loads the cohort, and for every gene assigned to
    this task (gene position modulo `--task_count` equal to `--task_id`)
    tunes and fits the chosen classifier, then scores it on the test
    samples. The AUCs, AUPRs, tuned parameters and fitting times are
    pickled into the output directory.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('expr_source', type=str,
                        choices=['Firehose', 'toil', 'toil_tx'],
                        help='which TCGA expression data source to use')
    parser.add_argument('cohort', type=str, help="which TCGA cohort to use")

    parser.add_argument(
        'syn_root', type=str,
        help="the root cache directory for data downloaded from Synapse"
        )

    parser.add_argument(
        'samp_cutoff', type=int,
        help="minimum number of mutated samples needed to test a gene"
        )

    parser.add_argument('classif', type=str,
                        help='the name of a mutation classifier')

    parser.add_argument(
        '--cv_id', type=int, default=6732,
        help='the random seed to use for cross-validation draws'
        )

    parser.add_argument(
        '--task_count', type=int, default=10,
        help='how many parallel tasks the list of types to test is split into'
        )
    parser.add_argument('--task_id', type=int, default=0,
                        help='the subset of subtypes to assign to this task')

    parser.add_argument('--verbose', '-v', action='store_true',
                        help='turns on diagnostic messages')

    # parse command-line arguments, create directory where to save results
    args = parser.parse_args()
    out_path = os.path.join(
        base_dir, 'output', args.expr_source,
        '{}__samps-{}'.format(args.cohort, args.samp_cutoff), args.classif
        )
    os.makedirs(out_path, exist_ok=True)

    gene_file = os.path.join(
        base_dir, "setup",
        "genes-list_{}__{}__samps-{}.p".format(
            args.expr_source, args.cohort, args.samp_cutoff)
        )

    # NOTE(review): unpickling executes arbitrary code; this file is
    # expected to come from the experiment's own setup step
    with open(gene_file, 'rb') as fl:
        gene_list = pickle.load(fl)

    # log into Synapse using locally stored credentials
    syn = synapseclient.Synapse()
    syn.cache.cache_root_dir = args.syn_root
    syn.login()

    # the first column of the sources table is the source name and the
    # second one is the directory where its expression data is kept
    with open(os.path.join(base_dir, 'expr_sources.txt'), 'r') as fl:
        expr_dir = pd.read_csv(
            fl, sep='\t', header=None, index_col=0
            ).loc[args.expr_source].iloc[0]

    cdata = MutationCohort(
        cohort=args.cohort, mut_genes=gene_list, mut_levels=['Gene'],
        expr_source=args.expr_source, expr_dir=expr_dir, var_source='mc3',
        syn=syn, cv_prop=0.75, cv_seed=2079 + 57 * args.cv_id
        )

    # the classifier is given as `<module name>__<class name>`
    clf_info = args.classif.split('__')
    clf_module = import_module(
        'HetMan.experiments.gene_baseline.models.{}'.format(clf_info[0]))
    mut_clf = getattr(clf_module, clf_info[1].capitalize())

    # only the genes assigned to this task get an entry in the output
    out_auc = dict()
    out_aupr = dict()
    out_params = dict()
    out_time = dict()

    for i, mut_gene in enumerate(gene_list):
        if (i % args.task_count) != args.task_id:
            continue

        if args.verbose:
            print("Testing {} ...".format(mut_gene))

        clf = mut_clf()
        mtype = MuType({('Gene', mut_gene): None})

        clf.tune_coh(cdata, mtype, exclude_genes={mut_gene},
                     tune_splits=4, test_count=24, parallel_jobs=16)
        clf_params = clf.get_params()
        out_params[mut_gene] = {par: clf_params[par]
                                for par, _ in mut_clf.tune_priors}

        # only the final fit is timed, the tuning step is not
        t_start = time.time()
        clf.fit_coh(cdata, mtype, exclude_genes={mut_gene})
        t_end = time.time()
        out_time[mut_gene] = t_end - t_start

        test_omics, test_pheno = cdata.test_data(
            mtype, exclude_genes={mut_gene})
        pred_scores = clf.predict_omic(test_omics)

        if len(set(test_pheno)) == 2:
            out_auc[mut_gene] = roc_auc_score(test_pheno, pred_scores)
            out_aupr[mut_gene] = average_precision_score(
                test_pheno, pred_scores)

        # with only one class among the test samples the metrics are
        # undefined, so fall back to the chance AUC and to the share of
        # mutated training samples as the AUPR
        else:
            out_auc[mut_gene] = 0.5
            out_aupr[mut_gene] = len(mtype.get_samples(cdata.train_mut))
            out_aupr[mut_gene] /= len(cdata.train_samps)

    out_file = os.path.join(
        out_path, 'out__cv-{}_task-{}.p'.format(args.cv_id, args.task_id))

    with open(out_file, 'wb') as fl:
        pickle.dump(
            {'AUC': out_auc, 'AUPR': out_aupr,
             'Clf': mut_clf, 'Params': out_params, 'Time': out_time},
            fl
            )
Esempio n. 11
0
def plot_label_stability(out_data, args, cdata):
    """Plot inferred mutation scores for wild-type and mutated samples.

    NOTE(review): this body looks like two plotting routines spliced
    together. The first half builds and saves a violin plot and reads
    `all_pheno`, `use_phenos`, `subtype_cmap` and `use_levels`, none of
    which are defined in this function; unless they exist at module level
    the call fails with a NameError. The second half draws density contours
    of the score means and standard deviations, but the figure is closed
    without being saved afterwards.

    Args:
        out_data: array-like of scores reduced along axis 1.
        args: object providing `gene`, `model_name`, `solve_method` and
            `cohort`.
        cdata: cohort object exposing `train_pheno`.
    """
    fig, ax = plt.subplots(figsize=(13, 8))
    # NOTE(review): `kern_bw` and `mut_clr` are not used in the visible code
    kern_bw = (np.max(out_data) - np.min(out_data)) / 40
    mut_clr = sns.hls_palette(1, l=.4, s=.9)[0]
    
    # a gene given as `<gene>_<type>` selects a subtype of the gene, looked
    # up in `mtype_list`, which is not defined in this function
    if '_' in args.gene:
        mut_info = args.gene.split('_')
        use_mtype = MuType({('Gene', mut_info[0]): mtype_list[mut_info[1]]})

    else:
        use_mtype = MuType({('Gene', args.gene): None})

    mtype_stat = np.array(cdata.train_pheno(use_mtype))
    out_meds = np.percentile(out_data, q=50, axis=1)
    # NOTE(review): `all_pheno` is not defined in this function
    mtype_meds = [('Wild-Type ({} samples)'.format(np.sum(~all_pheno)),
                   out_meds[~all_pheno])]

    # NOTE(review): `use_phenos` is not defined in this function
    mtype_meds += [('{} ({} samples)'.format(mtype, np.sum(use_pheno)),
                    out_meds[use_pheno])
                   for mtype, use_pheno in use_phenos.items()]

    # order the groups by the mean of their median scores
    mtype_meds = sorted(mtype_meds, key=lambda x: np.mean(x[1]))
    med_df = pd.concat(pd.DataFrame({'Subtype': mtype, 'Score': meds})
                       for mtype, meds in mtype_meds)
    # NOTE(review): `subtype_cmap` is not defined in this function
    ax = sns.violinplot(data=med_df, x='Subtype', y='Score',
                        palette=['0.5'] + subtype_cmap, width=0.96)

    plt.xlabel('Mutation Type', size=18, weight='semibold')
    plt.ylabel('Inferred Mutation Score', size=18, weight='semibold')
    plt.xticks(rotation=45, ha='right', size=10)
    plt.yticks(size=15)

    # NOTE(review): the file is named `violins__...` although the function
    # is called `plot_label_stability`; `use_levels` is not defined here
    fig.savefig(
        os.path.join(plot_dir,
                     'violins__{}-{}_{}-{}__levels_{}.png'.format(
                         args.model_name, args.solve_method,
                         args.cohort, args.gene, '__'.join(use_levels)
                        )),
        dpi=250, bbox_inches='tight'
        )
    wt_cmap = sns.light_palette('0.07', as_cmap=True)
    mut_cmap = sns.light_palette(sns.hls_palette(1, l=.33, s=.95)[0],
                                 as_cmap=True)

    # same mutation type lookup as at the top of the function
    if '_' in args.gene:
        mut_info = args.gene.split('_')
        use_mtype = MuType({('Gene', mut_info[0]): mtype_list[mut_info[1]]})

    else:
        use_mtype = MuType({('Gene', args.gene): None})

    mtype_stat = np.array(cdata.train_pheno(use_mtype))
    out_means = np.mean(out_data, axis=1)
    out_sds = np.std(out_data, axis=1)

    # contours for the samples without the mutation ...
    ax = sns.kdeplot(out_means[~mtype_stat], out_sds[~mtype_stat],
                     cmap=wt_cmap, linewidths=2.7, alpha=0.5,
                     gridsize=1000, shade_lowest=False, n_levels=15,
                     label='Wild-Type')

    # ... and for the samples carrying it
    ax = sns.kdeplot(out_means[mtype_stat], out_sds[mtype_stat],
                     cmap=mut_cmap, linewidths=2.7, alpha=0.5,
                     gridsize=1000, shade_lowest=False, n_levels=15,
                     label='{} Mutant'.format(args.gene))

    plt.xlabel('Mutation Score CV Mean', fontsize=20)
    plt.ylabel('Mutation Score CV SD', fontsize=20)

    # NOTE(review): the figure is closed here without a second `savefig`
    plt.close()
Esempio n. 12
0
def plot_tuning_gene(cdata, args, tune_params, pca_comps=(0, 1)):
    """Plot a grid of transformed cohorts over two tuned parameters.

    The transform named by `args.transform` is fitted once for every
    combination of values of the first two entries of `tune_params`. Each
    result gets its own panel, with rows following the first parameter and
    columns the second. Samples without any mutation are drawn in grey and
    mutated samples in a red tint. The figure is written as a PNG under
    `plot_dir`.

    Args:
        cdata: cohort object exposing `train_mut` and `train_pheno`, and
            accepted by the transform's `fit_transform_coh`.
        args: object providing `transform`, `gene` and `cohort`.
        tune_params: sequence of `(parameter name, values)` pairs; the
            first two define the grid and the values of the last one are
            put into the file name.
        pca_comps: the pair of component indices to plot.
    """
    tune_size1 = len(tune_params[0][1])
    tune_size2 = len(tune_params[1][1])

    # `squeeze=False` keeps the axes in a 2-D array even when a parameter
    # has a single value; the row/column indexing below needs that
    fig, axarr = plt.subplots(nrows=tune_size1,
                              ncols=tune_size2,
                              figsize=(tune_size2 * 5 - 1, tune_size1 * 5),
                              squeeze=False)
    fig.tight_layout(pad=1.6)

    for ax in axarr.reshape(-1):
        ax.set_xticklabels([])
        ax.set_yticklabels([])

    pca_comps = np.array(pca_comps)
    trans_dict = dict()
    base_pheno = np.array(cdata.train_pheno(MuType(cdata.train_mut.allkey())))

    mut_clr = sns.light_palette((1 / 3, 0, 0),
                                input="rgb",
                                n_colors=5,
                                reverse=True)[1]

    for prms in product(*[[(x[0], y) for y in x[1]] for x in tune_params[:2]]):

        # NOTE(review): `eval` runs whatever expression `args.transform`
        # holds; it must only ever be given trusted input
        mut_trans = eval(args.transform)().set_params(
            **dict(prms + (('fit__random_state', 903), )))
        trans_dict[prms] = mut_trans.fit_transform_coh(cdata)[:, pca_comps]

    # label the rows with the first parameter's values ...
    for i in range(tune_size1):
        axarr[i, 0].set_ylabel('{}: {}'.format(tune_params[0][0],
                                               tune_params[0][1][i]),
                               size=21)

    # ... and the bottom row's columns with the second parameter's values
    for j in range(tune_size2):
        axarr[tune_size1 - 1,
              j].set_xlabel('{}: {}'.format(tune_params[1][0],
                                            tune_params[1][1][j]),
                            size=21)

    for i, j in product(range(tune_size1), range(tune_size2)):
        trans_expr = trans_dict[((tune_params[0][0], tune_params[0][1][i]),
                                 (tune_params[1][0], tune_params[1][1][j]))]

        axarr[i, j].scatter(trans_expr[~base_pheno, 0],
                            trans_expr[~base_pheno, 1],
                            marker='o',
                            s=12,
                            c='0.4',
                            alpha=0.25,
                            edgecolor='none')

        axarr[i, j].scatter(trans_expr[base_pheno, 0],
                            trans_expr[base_pheno, 1],
                            marker='o',
                            s=35,
                            c=mut_clr,
                            alpha=0.4,
                            edgecolor='none')

    fig.savefig(os.path.join(
        plot_dir,
        "{}__gene_{}_{}__comps_{}-{}__{}.png".format(args.transform, args.gene,
                                                     args.cohort, pca_comps[0],
                                                     pca_comps[1],
                                                     tune_params[-1][1])),
                dpi=250,
                bbox_inches='tight')

    plt.close()
Esempio n. 13
0
def main():
    """Isolate pairs of mutation sub-types with a Stan multitask model.

    Loads the pairs of sub-types stored in the given pickle file, works out
    which annotation levels and genes they involve, builds the requested
    Stan pipeline and loads the cohort. For every pair assigned to this
    task (pair position modulo `--task_count` equal to `--task_id`) the
    model is tuned, fitted and used to infer scores. The inferred scores,
    tuned parameters and fitted variable means are pickled into
    `out__task-<task_id>.p` inside the output directory.

    Raises:
        ValueError: if no genes are given and the pairs do not all have
            <Gene> as their top level, or if the solve method is not one of
            'optim', 'variat' or 'sampl'.
    """
    parser = argparse.ArgumentParser(
        description=("Test a Stan multitask model's inferred mutation scores "
                     "for a pair of mutation sub-types it is trained to "
                     "classify and compare them to the scores it infers for "
                     "the remaining mutated and wild-type samples for a "
                     "given gene in a TCGA cohort."))

    parser.add_argument('mtype_file',
                        type=str,
                        help='the pickle file where sub-types are stored')
    parser.add_argument('out_dir',
                        type=str,
                        help='where to save the output of testing sub-types')

    # the help strings of the next two arguments used to be copies of the
    # cohort's help string
    parser.add_argument('cohort', type=str, help='a TCGA cohort')
    parser.add_argument('model_name', type=str,
                        help='the name of the Stan model module to use')
    parser.add_argument('solve_method', type=str,
                        help=("how to fit the Stan model: one of 'optim', "
                              "'variat', or 'sampl'"))

    parser.add_argument('--use_genes',
                        type=str,
                        default=None,
                        nargs='+',
                        help='specify which gene(s) to isolate against')

    parser.add_argument(
        '--cv_id',
        type=int,
        default=8807,
        help='the random seed to use for cross-validation draws')

    parser.add_argument(
        '--task_count',
        type=int,
        default=10,
        help='how many parallel tasks the list of types to test is split into')
    parser.add_argument('--task_id',
                        type=int,
                        default=0,
                        help='the subset of subtypes to assign to this task')

    parser.add_argument(
        '--tune_splits',
        type=int,
        default=4,
        help='how many training cohort splits to use for tuning')
    parser.add_argument(
        '--test_count',
        type=int,
        default=16,
        help='how many hyper-parameter values to test in each tuning split')

    parser.add_argument(
        '--infer_splits',
        type=int,
        default=20,
        help='how many cohort splits to use for inference bootstrapping')
    parser.add_argument(
        '--infer_folds',
        type=int,
        default=4,
        help=('how many parts to split the cohort into in each inference '
              'cross-validation run'))

    parser.add_argument(
        '--parallel_jobs',
        type=int,
        default=4,
        help='how many parallel CPUs to allocate the tuning tests across')

    parser.add_argument('--verbose',
                        '-v',
                        action='store_true',
                        help='turns on diagnostic messages')

    args = parser.parse_args()
    out_file = os.path.join(args.out_dir,
                            'out__task-{}.p'.format(args.task_id))

    if args.verbose:
        print("Starting multi-task isolation with Stan model <{}> for "
              "sub-types in\n{}\nthe results of which will be stored "
              "in\n{}\n".format(args.model_name, args.mtype_file,
                                args.out_dir))

    # NOTE(review): unpickling executes arbitrary code; only point this at
    # files produced by the experiment's own setup step
    with open(args.mtype_file, 'rb') as fl:
        pair_list = pickle.load(fl)

    or_list = [mtype1 | mtype2 for mtype1, mtype2 in pair_list]
    use_lvls = []

    # gather the annotation levels used by any pair, keeping the order in
    # which they are first seen
    for lvls in reduce(or_,
                       [{mtype.get_sorted_levels()} for mtype in or_list]):
        for lvl in lvls:
            if lvl not in use_lvls:
                use_lvls.append(lvl)

    if args.use_genes is None:
        if set(mtype.cur_level for mtype in or_list) == {'Gene'}:
            use_genes = reduce(or_, [
                set(gn for gn, _ in mtype.subtype_list()) for mtype in or_list
            ])

        else:
            raise ValueError(
                "A gene to isolate against must be given or the pairs of "
                "subtypes listed must each have <Gene> as their top level!")

    else:
        use_genes = set(args.use_genes)

    if args.verbose:
        print("Subtypes at mutation annotation levels {} will be isolated "
              "against genes:\n{}".format(use_lvls, use_genes))

    use_module = import_module('HetMan.experiments.subvariant_multi'
                               '.models.{}'.format(args.model_name))
    UsePipe = getattr(use_module, 'UsePipe')

    # which class of the model module implements each solving method
    solver_names = {'optim': 'UseOptimizing',
                    'variat': 'UseVariational',
                    'sampl': 'UseSampling'}

    if args.solve_method not in solver_names:
        raise ValueError("Unrecognized <solve_method> argument!")

    use_solver = getattr(use_module, solver_names[args.solve_method])
    clf_stan = UsePipe(
        use_solver(model_code=getattr(use_module, 'use_model')))

    if args.verbose:
        print('Using the following Stan model:\n\n{}'.format(
            clf_stan.named_steps['fit'].model_code))

    # log into Synapse using locally stored credentials
    syn = synapseclient.Synapse()
    syn.cache.cache_root_dir = ('/home/exacloud/lustre1/CompBio'
                                '/mgrzad/input-data/synapse')
    syn.login()

    cdata = MutationCohort(cohort=args.cohort,
                           mut_genes=list(use_genes),
                           mut_levels=use_lvls,
                           expr_source='Firehose',
                           expr_dir=firehose_dir,
                           syn=syn,
                           cv_seed=9099,
                           cv_prop=1.0)

    if args.verbose:
        print("Loaded {} pairs of subtypes of which roughly {} will be "
              "isolated in cohort {} with {} samples.".format(
                  len(pair_list),
                  len(pair_list) // args.task_count, args.cohort,
                  len(cdata.samples)))

    # only the pairs assigned to this task get an entry in the output
    out_multi = dict()
    out_par = dict()
    out_vars = dict()

    base_mtype = MuType({('Gene', tuple(use_genes)): None})
    base_samps = base_mtype.get_samples(cdata.train_mut)

    # for each sub-variant, check if it has been assigned to this task
    for i, (mtype1, mtype2) in enumerate(pair_list):
        if (i % args.task_count) != args.task_id:
            continue

        if args.verbose:
            print("Isolating {} and {} ...".format(mtype1, mtype2))

        # samples mutated in the genes of interest that carry neither of
        # the two sub-types are kept out of tuning and fitting
        ex_samps = base_samps - (mtype1.get_samples(cdata.train_mut)
                                 | mtype2.get_samples(cdata.train_mut))

        clf_stan.tune_coh(cdata, [mtype1, mtype2],
                          exclude_genes=use_genes,
                          exclude_samps=ex_samps,
                          tune_splits=args.tune_splits,
                          test_count=args.test_count,
                          parallel_jobs=args.parallel_jobs)

        clf_stan.fit_coh(cdata, [mtype1, mtype2],
                         exclude_genes=use_genes,
                         exclude_samps=ex_samps)
        clf_params = clf_stan.get_params()

        out_par[(mtype1, mtype2)] = {
            par: clf_params[par]
            for par, _ in clf_stan.tune_priors
        }
        out_vars[(mtype1,
                  mtype2)] = (clf_stan.named_steps['fit'].get_var_means())

        out_multi[(mtype1, mtype2)] = clf_stan.infer_coh(
            cdata, [mtype1, mtype2],
            exclude_genes=use_genes,
            force_test_samps=ex_samps,
            infer_splits=args.infer_splits,
            infer_folds=args.infer_folds,
            parallel_jobs=args.parallel_jobs)

    with open(out_file, 'wb') as fl:
        pickle.dump(
            {
                'Infer': out_multi,
                'Par': out_par,
                'Vars': out_vars,
                'Info': {
                    'TunePriors': clf_stan.tune_priors,
                    'TuneSplits': args.tune_splits,
                    'TestCount': args.test_count
                }
            }, fl)
Esempio n. 14
0
def plot_cna_scores(iso_vals, args, cdata):
    """Plot inferred CNA scores against ranked GISTIC scores for one gene.

    Samples without a mutation of ``args.gene`` are drawn as two-dimensional
    density layers, one per copy number class (loss, gap, wild-type, gain)
    as defined by the pair of cutoffs stored in ``iso_vals.name``; mutated
    samples are overlaid as individual points. The figure is saved as a PNG
    file under ``plot_dir``.

    Args:
        iso_vals: inferred scores whose ``name`` is a ``(low_ctf, high_ctf)``
            pair and whose entries are each reduced with ``np.mean``
            (presumably a pandas Series with one entry per cohort sample, in
            the same order as the phenotypes -- confirm against the caller).
        args: parsed arguments providing ``cohort``, ``gene`` and ``classif``.
        cdata: cohort object providing ``copy_data``, ``subset_samps()`` and
            ``train_pheno()``.
    """
    fig, ax = plt.subplots(figsize=(16, 11))

    low_ctf, high_ctf = iso_vals.name
    cna_vals = cdata.copy_data.loc[cdata.subset_samps(), args.gene]
    iso_means = iso_vals.apply(np.mean).values
    y_bound = np.max(np.absolute(np.percentile(iso_means, q=(0.1, 99.9))))

    use_mtype = MuType({('Gene', args.gene): None})
    mut_stat = np.array(cdata.train_pheno(use_mtype))
    gap_stat = cdata.train_pheno({
        'Gene': args.gene,
        'CNA': 'Range',
        'Cutoff': (low_ctf, high_ctf)
    })

    # position of each sample's copy number score on a [0, 1] rank scale
    zero_qnt = np.mean(cna_vals < 0)
    cna_qnts = quantile_transform(
        cna_vals.copy().values.reshape(-1, 1)).flatten()

    # a negative lower cutoff means the tested interval lies on the loss
    # side; the remaining cutoffs are obtained by mirroring around zero
    if low_ctf < 0:
        loss_ctf = low_ctf
        wt_ctf = high_ctf
        gain_ctf = -high_ctf
        wt_range = (high_ctf, -high_ctf)

    else:
        loss_ctf = -low_ctf
        wt_ctf = low_ctf
        gain_ctf = high_ctf
        wt_range = (-low_ctf, low_ctf)

    wt_stat = cdata.train_pheno({
        'Gene': args.gene,
        'CNA': 'Range',
        'Cutoff': wt_range
    })
    loss_stat = cdata.train_pheno({
        'Gene': args.gene,
        'CNA': 'Loss',
        'Cutoff': loss_ctf
    })
    gain_stat = cdata.train_pheno({
        'Gene': args.gene,
        'CNA': 'Gain',
        'Cutoff': gain_ctf
    })

    loss_qnt = np.mean(cna_vals < loss_ctf)
    wt_qnt = np.mean(cna_vals < wt_ctf)
    gain_qnt = np.mean(cna_vals < gain_ctf)

    # (sample status, colour map, kernel bandwidth, skip when status is empty)
    # NOTE(review): the gap and wild-type bandwidths are derived from the
    # loss-side layout; for a gain-side cutoff pair they look swapped --
    # confirm whether that is intended.
    density_layers = [
        (loss_stat, loss_cmap, loss_qnt / 7, True),
        (gap_stat, gap_cmap, (wt_qnt - loss_qnt) / 7, True),
        (wt_stat, wt_cmap, (gain_qnt - wt_qnt) / 7, False),
        (gain_stat, gain_cmap, (1 - gain_qnt) / 7, True),
    ]

    for stat, cmap, bw, skip_if_empty in density_layers:
        if skip_if_empty and not np.any(stat):
            continue

        # only samples without a mutation contribute to the densities
        use_stat = stat & ~mut_stat
        sns.kdeplot(cna_qnts[use_stat],
                    iso_means[use_stat],
                    cmap=cmap,
                    shade=True,
                    shade_lowest=False,
                    alpha=0.73,
                    bw=bw,
                    gridsize=250,
                    n_levels=11,
                    cut=0)

    ax.scatter(cna_qnts[mut_stat],
               iso_means[mut_stat],
               s=13,
               c='#550000',
               alpha=0.37)

    plt.xlabel("{} {} Gistic Score Rank".format(args.cohort, args.gene),
               fontsize=23,
               weight='semibold')
    plt.ylabel("Inferred {} CNA Score".format(args.gene),
               fontsize=23,
               weight='semibold')

    # label each tick with the cutoff that actually defines its position, so
    # that the labels are also correct for gain-side cutoff pairs
    plt.xticks([loss_qnt, wt_qnt, zero_qnt, gain_qnt],
               ['Loss Cutoff ({:+.3f})'.format(loss_ctf),
                'WT Cutoff ({:+.3f})'.format(wt_ctf),
                'GISTIC = 0',
                'Gain Cutoff ({:+.3f})'.format(gain_ctf)],
               fontsize=15,
               ha='right',
               rotation=38)

    # axvline's ymin/ymax are fractions of the axes height, not data values,
    # so the defaults (0 and 1) are what makes each guide span the full plot
    guide_lines = [
        (loss_qnt, '--', 3.3, loss_cmap(50)),
        (wt_qnt, '--', 3.3, wt_cmap(50)),
        (zero_qnt, ':', 0.9, 'black'),
        (gain_qnt, '--', 3.3, gain_cmap(50)),
    ]

    for qnt, line_style, line_width, line_clr in guide_lines:
        ax.axvline(x=qnt, ls=line_style, lw=line_width, c=line_clr)

    ax.grid(False, which='major', axis='x')
    plt.ylim(-y_bound * 1.29, y_bound * 1.29)

    # legend entries follow the order of the patches created below:
    # loss, wild-type, gap, gain
    if low_ctf < 0:
        lgnd_lbls = [
            'CNA Loss used in training', 'CNA WT used in training',
            'CNA Loss withheld in training', 'CNA Gain withheld in training'
        ]

    else:
        lgnd_lbls = [
            'CNA Loss withheld in training', 'CNA WT used in training',
            'CNA Gain withheld in training', 'CNA Gain used in training'
        ]

    lgnd_lbls += ['Mutants withheld in training']
    lgnd_handles = [
        Patch(color=cmap(100), alpha=0.73)
        for cmap in (loss_cmap, wt_cmap, gap_cmap, gain_cmap)
    ]
    lgnd_handles += [
        Line2D([0], [0],
               lw=0,
               marker='o',
               markersize=14,
               alpha=0.57,
               markerfacecolor='#550000',
               markeredgecolor='#550000')
    ]

    plt.legend(lgnd_handles, lgnd_lbls, fontsize=17, loc=8, ncol=2)

    plt.savefig(os.path.join(
        plot_dir,
        "{}_{}_{}__ctfs_{:.3f}_{:.3f}.png".format(args.cohort, args.gene,
                                                  args.classif, low_ctf,
                                                  high_ctf)),
                dpi=300,
                bbox_inches='tight')

    plt.close()
Esempio n. 15
0
def main():
    """Plot how the isolated signatures of a gene module's subtypes relate.

    Loads the cohort and the saved inference output for the given genes,
    keeps the subtypes whose AUC is above 0.6, picks a set of subtypes to
    plot (the simplest ones of each gene plus further ones whose samples
    differ enough from those already picked), and passes their similarities
    to ``plot_gene_ordering``.
    """
    # the text is a description, not the program name, so pass it by keyword
    parser = argparse.ArgumentParser(
        description=(
            "Plot the ordering of the simplest subtypes within a module of "
            "genes in a given cohort based on how their isolated expression "
            "signatures classify one another."
            )
        )

    parser.add_argument('cohort', help='a TCGA cohort')
    parser.add_argument('classif', help='a mutation classifier')
    parser.add_argument('mut_levels',
                        type=str,
                        help='a set of mutation annotation levels')
    parser.add_argument('genes',
                        type=str,
                        nargs='+',
                        help='a list of mutated genes')
    parser.add_argument('--samp_cutoff', type=int, default=25)

    # parse command-line arguments, create directory where plots will be saved
    args = parser.parse_args()
    os.makedirs(plot_dir, exist_ok=True)

    # log into Synapse using locally stored credentials
    syn = synapseclient.Synapse()
    syn.cache.cache_root_dir = ("/home/exacloud/lustre1/CompBio/"
                                "mgrzad/input-data/synapse")
    syn.login()

    cdata = MutationCohort(cohort=args.cohort,
                           mut_genes=args.genes,
                           mut_levels=['Gene'] + args.mut_levels.split('__'),
                           expr_source='Firehose',
                           expr_dir=firehose_dir,
                           syn=syn,
                           cv_prop=1.0)

    infer_df = load_infer_output(
        os.path.join(base_dir, 'output', args.cohort,
                     '_'.join(sorted(args.genes)), args.classif,
                     'samps_{}'.format(args.samp_cutoff), args.mut_levels))

    # only subtypes that are separated from the rest with an AUC above 0.6
    # are considered for plotting
    base_pheno = np.array(
        cdata.train_pheno(MuType({('Gene', tuple(args.genes)): None})))
    auc_list = get_aucs(infer_df, base_pheno,
                        cdata).sort_values(ascending=False)
    auc_list = auc_list[auc_list > 0.6]

    # order the candidate subtypes by their number of sub-keys
    mtype_lens = {mtype: len(mtype.subkeys()) for mtype in auc_list.index}
    mtype_list = sorted(auc_list.index, key=lambda mtype: mtype_lens[mtype])

    mtype_genes = {
        mtype: mtype.subtype_list()[0][0]
        for mtype in auc_list.index
    }
    mtype_samps = {
        mtype: mtype.get_samples(cdata.train_mut)
        for mtype in auc_list.index
    }

    # start with the first three candidate subtypes of each gene
    plot_mtypes = reduce(or_, [
        set([mtype for mtype in mtype_list if mtype_genes[mtype] == gene][:3])
        for gene in args.genes
    ])

    # add candidates whose samples differ enough from every subtype picked so
    # far, relaxing the threshold after each pass over the candidates; stop
    # once more than 15 subtypes are picked or the candidates are exhausted
    # (without the extra checks the loop indexes past the end of the list or
    # takes the minimum of an empty sequence)
    ovlp_threshold = 0.5
    i = j = 1
    while plot_mtypes and len(plot_mtypes) <= 15 and i < len(mtype_list):
        ovlp_score = min(
            len(mtype_samps[mtype_list[i]] ^ mtype_samps[plot_mtype]) /
            max(len(mtype_samps[mtype_list[i]]), len(mtype_samps[plot_mtype]))
            for plot_mtype in plot_mtypes)

        if ovlp_score >= ovlp_threshold:
            plot_mtypes |= {mtype_list[i]}

        i += 1
        if i >= len(mtype_list):
            j += 1
            i = j
            ovlp_threshold **= 4 / 3

    # recent pandas versions reject a set as an indexer, so pass a list
    simil_df = get_similarities(infer_df.loc[list(plot_mtypes), :],
                                base_pheno, cdata)
    plot_gene_ordering(simil_df, auc_list, args, cdata)
Esempio n. 16
0
def main():
    """Enumerate the copy number cutoff pairs to test for one gene.

    Candidate cutoffs are taken from quantiles of the negative and of the
    positive copy number scores of the samples flagged as not mutated. Each
    pair of loss-side cutoffs and each pair of gain-side cutoffs is kept when
    it leaves at least ``min_samps`` altered samples and at least
    ``min_samps`` wild-type samples. The kept pairs and their count are
    written under ``base_dir/setup``.
    """
    # the text is a description, not the program name, so pass it by keyword
    parser = argparse.ArgumentParser(
        description=(
            "Set up the copy number alteration expression effect isolation "
            "experiment by enumerating alteration score thresholds to be "
            "tested."
            )
        )

    # create command line arguments
    parser.add_argument('cohort', type=str, help="which TCGA cohort to use")
    parser.add_argument('gene', type=str, help="which gene to consider")
    parser.add_argument('--verbose',
                        '-v',
                        action='store_true',
                        help='turns on diagnostic messages')

    # parse command line arguments, create directory where found thresholds
    # and threshold counts will be stored
    args = parser.parse_args()
    os.makedirs(os.path.join(base_dir, 'setup', 'ctf_lists'), exist_ok=True)
    os.makedirs(os.path.join(base_dir, 'setup', 'ctf_counts'), exist_ok=True)

    # log into Synapse using locally stored credentials
    syn = synapseclient.Synapse()
    syn.cache.cache_root_dir = ("/home/exacloud/lustre1/CompBio/"
                                "mgrzad/input-data/synapse")
    syn.login()

    # load expression, variant call, and copy number alteration data for
    # the given TCGA cohort and mutated gene
    cdata = MutationCohort(cohort=args.cohort,
                           mut_genes=[args.gene],
                           mut_levels=['Gene'],
                           expr_source='Firehose',
                           var_source='mc3',
                           expr_dir=firehose_dir,
                           copy_source='Firehose',
                           copy_dir=copy_dir,
                           copy_discrete=False,
                           cv_prop=1.0,
                           syn=syn)

    # smallest group size, also used to space the candidate quantiles
    min_samps = 20

    def get_quantile_ctfs(vals):
        """Returns the sorted unique quantile cutoffs of a set of scores."""
        # without any scores there are no cutoffs (and no step to divide by)
        if len(vals) == 0:
            return np.array([])

        qnt_step = min_samps / len(vals)
        return np.unique(vals.quantile(np.arange(qnt_step, 1, qnt_step)))

    ctf_list = []
    mut_stat = np.array(cdata.train_mut.status(cdata.copy_data.index))
    mut_pheno = np.array(cdata.train_pheno(MuType({('Gene', args.gene):
                                                   None})))

    copy_vals = cdata.copy_data.loc[~mut_stat, args.gene]
    loss_ctfs = get_quantile_ctfs(copy_vals[copy_vals < 0])

    # gain cutoffs are reversed so that each pair below is (higher, lower)
    gain_ctfs = get_quantile_ctfs(copy_vals[copy_vals > 0])[::-1]

    for low_ctf, high_ctf in combn(loss_ctfs, 2):
        cna_stat = (~mut_pheno
                    & cdata.train_pheno({
                        'Gene': args.gene,
                        'CNA': 'Loss',
                        'Cutoff': low_ctf
                    }))

        # wild-type samples are neither mutated, nor in the gap between the
        # two cutoffs, nor gained beyond the mirrored upper cutoff
        wt_stat = (~mut_pheno
                   & ~cdata.train_pheno({
                       'Gene': args.gene,
                       'CNA': 'Range',
                       'Cutoff': (low_ctf, high_ctf)
                   })
                   & ~cdata.train_pheno({
                       'Gene': args.gene,
                       'CNA': 'Gain',
                       'Cutoff': -high_ctf
                   }))

        if np.sum(cna_stat) >= min_samps and np.sum(wt_stat) >= min_samps:
            ctf_list += [(low_ctf, high_ctf)]

    for high_ctf, low_ctf in combn(gain_ctfs, 2):
        cna_stat = (~mut_pheno
                    & cdata.train_pheno({
                        'Gene': args.gene,
                        'CNA': 'Gain',
                        'Cutoff': high_ctf
                    }))

        # wild-type samples are neither mutated, nor in the gap between the
        # two cutoffs, nor lost beyond the mirrored lower cutoff
        wt_stat = (~mut_pheno
                   & ~cdata.train_pheno({
                       'Gene': args.gene,
                       'CNA': 'Range',
                       'Cutoff': (low_ctf, high_ctf)
                   })
                   & ~cdata.train_pheno({
                       'Gene': args.gene,
                       'CNA': 'Loss',
                       'Cutoff': -low_ctf
                   }))

        if np.sum(cna_stat) >= min_samps and np.sum(wt_stat) >= min_samps:
            ctf_list += [(low_ctf, high_ctf)]

    # save the sorted list of found cutoff pairs to file
    ctf_file = os.path.join(base_dir, 'setup', 'ctf_lists',
                            '{}_{}.p'.format(args.cohort, args.gene))
    with open(ctf_file, 'wb') as fl:
        pickle.dump(sorted(ctf_list), fl)

    # save the number of found cutoff pairs to file
    with open(
            os.path.join(base_dir, 'setup', 'ctf_counts',
                         '{}_{}.txt'.format(args.cohort, args.gene)),
            'w') as fl:

        fl.write(str(len(ctf_list)))
Esempio n. 17
0
def plot_label_distribution(out_data, args, cdata):
    """Plot the densities of inferred scores for wild-types and mutants.

    Draws one outlined density per group for the scores' medians across the
    columns of ``out_data``, one faint filled density per group for every
    individual column, annotates the plot with the quartiles of the
    per-column AUCs, and saves the figure under ``plot_dir``.

    Args:
        out_data: two-dimensional array of inferred scores, indexed as
            ``[sample, run]``.
        args: parsed arguments providing ``gene``, ``cohort``,
            ``model_name`` and ``solve_method``.
        cdata: cohort object providing ``train_pheno()``.
    """
    fig, ax = plt.subplots(figsize=(13, 8))

    # per-sample medians across runs set the kernel width; the largest
    # absolute score sets the symmetric x-axis range
    med_scores = np.percentile(out_data, q=50, axis=1)
    bandwidth = (np.max(med_scores) - np.min(med_scores)) / 38
    x_bound = np.max(np.absolute(out_data)) * 1.1

    # mutation status of the cohort's samples for the given gene
    is_mut = np.array(cdata.train_pheno(MuType({('Gene', args.gene): None})))
    is_wt = ~is_mut

    def run_auc(vals):
        """Proportion of (mutant, wild-type) pairs ranked correctly."""
        return np.greater.outer(vals[is_mut], vals[is_wt]).mean()

    # one AUC for each column of the score matrix
    run_aucs = np.apply_along_axis(run_auc, axis=0, arr=out_data)

    # outlined densities of the medians, wild-types first and then mutants
    median_layers = (
        (is_wt, wt_clr, 'Wild-Type'),
        (is_mut, mut_clr, '{} Mutant'.format(args.gene)),
        )

    for stat, clr, lbl in median_layers:
        ax = sns.kdeplot(med_scores[stat],
                         color=clr,
                         alpha=0.7,
                         shade=False,
                         linewidth=3.4,
                         bw=bandwidth,
                         gridsize=1000,
                         label=lbl)

    # faint filled densities for each individual run
    for run_indx in range(out_data.shape[1]):
        for stat, clr in ((is_wt, wt_clr), (is_mut, mut_clr)):
            ax = sns.kdeplot(out_data[stat, run_indx],
                             shade=True,
                             alpha=0.04,
                             linewidth=0,
                             color=clr,
                             bw=bandwidth,
                             gridsize=1000)

    # annotate with the first and third quartiles of the per-run AUCs
    auc_q1, auc_q3 = np.percentile(run_aucs, q=(25, 75))
    ax.text(-x_bound * 0.96,
            ax.get_ylim()[1] * 0.92,
            "AUCs: {:.3f} - {:.3f}".format(auc_q1, auc_q3),
            size=15)

    # legend and axis cosmetics
    plt.legend(frameon=False, prop={'size': 17})
    plt.xlim(-x_bound, x_bound)
    plt.xlabel('Inferred Mutation Score', fontsize=19, weight='semibold')
    plt.ylabel('Density', fontsize=19, weight='semibold')

    out_name = 'distribution__{}-{}_{}-{}.png'.format(
        args.model_name, args.solve_method, args.cohort, args.gene)
    fig.savefig(os.path.join(plot_dir, out_name),
                dpi=250,
                bbox_inches='tight')

    plt.close()
Esempio n. 18
0
def plot_mtype_positions(prob_series, args, cdata):
    """Plot where other mutation types fall along a subtype's inferred scores.

    The upper panel shows the score densities of samples that are wild-type
    for ``args.gene``, samples carrying the subtype named by
    ``prob_series.name``, and samples carrying other mutations of the gene.
    The lower panel shows violins of the scores of samples carrying the
    mutation types whose scores differ most significantly, according to
    two-sample Kolmogorov-Smirnov tests, from the rest of their group.

    Args:
        prob_series: inferred scores, indexable with boolean sample masks,
            whose ``name`` is the subtype of ``args.gene`` being plotted.
        args: parsed arguments providing ``gene``, ``cohort``, ``classif``,
            ``mut_levels`` and ``samp_cutoff``.
        cdata: cohort object providing ``train_pheno()`` and ``train_mut``.
    """
    kern_bw = (np.max(prob_series) - np.min(prob_series)) / 29

    fig, (ax1, ax2) = plt.subplots(nrows=2,
                                   ncols=1,
                                   figsize=(13, 18),
                                   sharex=True,
                                   sharey=False,
                                   gridspec_kw={'height_ratios': [1, 3.41]})

    base_mtype = MuType({('Gene', args.gene): None})
    cur_mtype = MuType({('Gene', args.gene): prob_series.name})
    base_pheno = np.array(cdata.train_pheno(base_mtype))
    cur_pheno = np.array(cdata.train_pheno(cur_mtype))

    wt_lbl = '{} Wild-Type'.format(args.gene)
    mut_lbl = '{} Mutants'.format(args.gene)

    # mutation types that do not overlap with any mutation of the gene
    without_phenos = {
        mtype: np.array(cdata.train_pheno(mtype))
        for mtype in cdata.train_mut.branchtypes(min_size=args.samp_cutoff)
        if (mtype & base_mtype).is_empty()
    }

    # subtypes of the gene that do not overlap with the plotted subtype
    within_mtypes = {
        MuType({('Gene', args.gene): mtype})
        for mtype in cdata.train_mut[args.gene].combtypes(
            comb_sizes=(1, 2), min_type_size=args.samp_cutoff)
        if (mtype & prob_series.name).is_empty()
    }

    within_phenos = {
        mtype: np.array(cdata.train_pheno(mtype))
        for mtype in within_mtypes
    }

    # (sample status, colour, opacity, line width, legend label)
    density_layers = [
        (~base_pheno, '0.4', 0.45, 2.8, wt_lbl),
        (cur_pheno, (0.267, 0.137, 0.482), 0.45, 2.8,
         '{} Mutant'.format(prob_series.name)),
        (base_pheno & ~cur_pheno, (0.698, 0.329, 0.616), 0.3, 1.0,
         'Other {} Mutants'.format(args.gene)),
    ]

    for stat, clr, alpha_val, line_width, lbl in density_layers:
        sns.kdeplot(prob_series[stat],
                    ax=ax1,
                    cut=0,
                    color=clr,
                    alpha=alpha_val,
                    linewidth=line_width,
                    bw=kern_bw,
                    gridsize=250,
                    shade=True,
                    label=lbl)

    ax1.set_ylabel('Density', size=23, weight='semibold')
    ax1.yaxis.set_tick_params(labelsize=14)

    # among samples wild-type for the gene, compare those with and without
    # each of the other mutation types
    without_tests = {
        mtype: {
            'pval':
            ks_2samp(prob_series[~base_pheno & ~pheno],
                     prob_series[~base_pheno & pheno]).pvalue,
            'diff': (np.mean(prob_series[~base_pheno & pheno]) -
                     np.mean(prob_series[~base_pheno & ~pheno]))
        }
        for mtype, pheno in without_phenos.items()
    }

    # keep up to eight significant types with scores above the other samples
    without_tests = sorted([(mtype, tests)
                            for mtype, tests in without_tests.items()
                            if tests['pval'] < 0.05 and tests['diff'] > 0],
                           key=lambda x: x[1]['pval'])[:8]

    # among samples with other mutations of the gene, compare those with and
    # without each of the gene's remaining subtypes
    within_tests = {
        mtype: {
            'pval':
            ks_2samp(prob_series[base_pheno & ~cur_pheno & ~pheno],
                     prob_series[base_pheno & ~cur_pheno & pheno]).pvalue,
            'diff': (np.mean(prob_series[base_pheno & ~cur_pheno & pheno]) -
                     np.mean(prob_series[base_pheno & ~cur_pheno & ~pheno]))
        }
        for mtype, pheno in within_phenos.items()
    }

    within_tests = sorted(
        [(mtype, tests)
         for mtype, tests in within_tests.items() if tests['pval'] < 0.1],
        key=lambda x: x[1]['pval'])[:8]

    score_frames = [
        pd.DataFrame({
            'Mtype': repr(mtype).replace(' WITH ', '\n'),
            'Type': wt_lbl,
            'Scores': prob_series[~base_pheno
                                  & without_phenos[mtype]]
        }) for mtype, _ in without_tests
    ] + [
        pd.DataFrame({
            'Mtype':
            repr(mtype).replace('Gene IS {}'.format(args.gene), '').replace(
                ' WITH ', '\n'),
            'Type':
            mut_lbl,
            'Scores':
            prob_series[base_pheno & within_phenos[mtype]]
        }) for mtype, _ in within_tests
    ]

    # pd.concat fails on an empty list, so the violins are only drawn when
    # at least one mutation type passed the tests above
    if score_frames:
        subtype_df = pd.concat(score_frames)

        # order the violins by their mean score
        plt_order = subtype_df.groupby(
            ['Mtype'])['Scores'].mean().sort_values().index
        subtype_df['Mtype'] = subtype_df['Mtype'].astype(
            'category').cat.reorder_categories(plt_order)

        # draw on the lower panel explicitly instead of relying on it being
        # the current axes
        sns.violinplot(data=subtype_df,
                       x='Scores',
                       y='Mtype',
                       hue='Type',
                       palette={
                           wt_lbl: '0.5',
                           mut_lbl: (0.812, 0.518, 0.745)
                       },
                       alpha=0.3,
                       linewidth=1.3,
                       bw=kern_bw,
                       dodge=False,
                       cut=0,
                       gridsize=500,
                       legend=False,
                       ax=ax2)

    ax2.set_ylabel('Mutation Type', size=23, weight='semibold')
    ax2.yaxis.set_tick_params(labelsize=12)

    ax2.xaxis.set_tick_params(labelsize=18)
    ax2.set_xlabel('Inferred {} Score'.format(prob_series.name),
                   size=23,
                   weight='semibold')

    # replace characters of the subtype's name that are unsuitable for a
    # file name; the pattern is a raw string so `\.` is a valid escape
    file_lbl = re.sub(r'/|\.|:', '_', str(prob_series.name))

    fig.tight_layout()
    fig.savefig(os.path.join(
        plot_dir, args.cohort, args.gene,
        "{}_positions__{}_{}__{}__levels__{}.png".format(
            file_lbl, args.cohort, args.gene, args.classif,
            args.mut_levels)),
                dpi=250,
                bbox_inches='tight')

    plt.close()
Esempio n. 19
0
def main():
    """Enumerate the subtypes to isolate for a set of mutated genes.

    For each given gene, finds combinations of mutation subtypes at the
    given annotation levels, keeps those whose number of samples is at most
    the cohort size minus ``--samp_cutoff``, and writes the sorted list of
    subtypes and its length under ``base_dir/setup``.
    """
    # the text is a description, not the program name, so pass it by keyword
    parser = argparse.ArgumentParser(
        description=(
            "Set up the paired-gene subtype expression effect isolation "
            "experiment by enumerating the subtypes to be tested."
            )
        )

    # create positional command line arguments
    parser.add_argument('cohort', type=str, help="which TCGA cohort to use")
    parser.add_argument('mut_levels',
                        type=str,
                        help="the mutation property levels to consider")
    parser.add_argument('genes',
                        type=str,
                        nargs='+',
                        help="a list of mutated genes")

    # create optional command line arguments
    parser.add_argument('--samp_cutoff',
                        type=int,
                        default=25,
                        help='subtype sample frequency threshold')
    parser.add_argument('--verbose',
                        '-v',
                        action='store_true',
                        help='turns on diagnostic messages')

    # parse command line arguments, create directory where found subtypes
    # will be stored
    args = parser.parse_args()
    use_lvls = args.mut_levels.split('__')
    out_path = os.path.join(base_dir, 'setup', args.cohort,
                            '_'.join(args.genes))
    os.makedirs(out_path, exist_ok=True)

    # log into Synapse using locally stored credentials
    syn = synapseclient.Synapse()
    syn.cache.cache_root_dir = ("/home/exacloud/lustre1/CompBio/"
                                "mgrzad/input-data/synapse")
    syn.login()

    cdata = MutationCohort(cohort=args.cohort,
                           mut_genes=args.genes,
                           mut_levels=['Gene'] + use_lvls,
                           expr_source='Firehose',
                           var_source='mc3',
                           expr_dir=firehose_dir,
                           cv_prop=1.0,
                           syn=syn)

    iso_mtypes = set()
    for gene in args.genes:
        if args.verbose:
            print("Looking for combinations of subtypes of mutations in gene "
                  "{} present in at least {} of the samples in TCGA cohort "
                  "{} at annotation levels {}.\n".format(
                      gene, args.samp_cutoff, args.cohort, use_lvls))

        # the budget of subtypes is shared evenly between the genes
        # NOTE(review): `max_types` is passed as a float here -- confirm
        # that `find_unique_subtypes` accepts a non-integer value
        gene_mtypes = cdata.train_mut[gene].find_unique_subtypes(
            max_types=1500 / len(args.genes),
            max_combs=4,
            verbose=2,
            sub_levels=use_lvls,
            min_type_size=args.samp_cutoff)

        if args.verbose:
            print("\nFound {} subtypes of gene {} to isolate!".format(
                len(gene_mtypes), gene))

        # skip subtypes that leave fewer than `samp_cutoff` other samples
        iso_mtypes |= {
            MuType({('Gene', gene): mtype})
            for mtype in gene_mtypes
            if (len(mtype.get_samples(cdata.train_mut[gene])) <= (
                len(cdata.samples) - args.samp_cutoff))
        }

    if args.verbose:
        print("\nFound {} total sub-types to isolate!".format(len(iso_mtypes)))

    # save the list of found non-duplicate sub-types to file, closing the
    # file once the list has been written
    list_file = os.path.join(
        out_path, 'mtypes_list__samps_{}__levels_{}.p'.format(
            args.samp_cutoff, args.mut_levels))
    with open(list_file, 'wb') as fl:
        pickle.dump(sorted(iso_mtypes), fl)

    # save the number of found sub-types to file
    with open(
            os.path.join(
                out_path, 'mtypes_count__samps_{}__levels_{}.txt'.format(
                    args.samp_cutoff, args.mut_levels)), 'w') as fl:

        fl.write(str(len(iso_mtypes)))
Esempio n. 20
0
def plot_subtype_clustering(trans_expr,
                            args,
                            cdata,
                            use_gene,
                            pca_comps=(0, 1)):
    """Plot a grid of scatter plots highlighting a gene's mutation subtypes.

    The first of the twelve panels highlights all samples carrying a mutation
    of ``use_gene``; the remaining panels highlight the subtypes of the gene
    with the most samples, topped up with pairwise combinations of subtypes
    when fewer than eleven single subtypes are available. The figure is saved
    under ``plot_dir``.

    Args:
        trans_expr: two-dimensional array of transformed expression values,
            indexed as ``[sample, component]``.
        args: parsed arguments providing ``cohort``, ``transform`` and
            ``mut_levels``.
        cdata: cohort object providing ``train_pheno()`` and ``train_mut``.
        use_gene: the gene whose mutations are highlighted.
        pca_comps: the two columns of ``trans_expr`` to plot against
            one another.
    """
    fig, axarr = plt.subplots(nrows=3,
                              ncols=4,
                              figsize=(18, 13),
                              sharex=True,
                              sharey=True)
    fig.tight_layout(pad=2.4, w_pad=2.1, h_pad=5.4)

    panels = axarr.reshape(-1)
    for panel in panels:
        panel.set_xticklabels([])
        panel.set_yticklabels([])

    pca_comps = np.array(pca_comps)
    trans_expr = trans_expr[:, pca_comps]
    mut_clr = sns.light_palette((1 / 3, 0, 0),
                                input="rgb",
                                n_colors=5,
                                reverse=True)[1]

    def draw_status(panel, stat):
        """Draws samples without the status in grey, then those with it."""
        without_stat = ~stat

        panel.scatter(trans_expr[without_stat, 0],
                      trans_expr[without_stat, 1],
                      marker='o',
                      s=14,
                      c='0.4',
                      alpha=0.25,
                      edgecolor='none')

        panel.scatter(trans_expr[stat, 0],
                      trans_expr[stat, 1],
                      marker='o',
                      s=45,
                      c=mut_clr,
                      alpha=0.4,
                      edgecolor='none')

    def rank_by_size(mtypes):
        """Pairs each type with its status, most frequent types first."""
        return sorted([(mtype, np.array(cdata.train_pheno(mtype)))
                       for mtype in mtypes],
                      key=lambda x: np.sum(x[1]),
                      reverse=True)

    # the first panel shows every mutation of the gene
    base_pheno = np.array(cdata.train_pheno(MuType({('Gene', use_gene):
                                                    None})))
    axarr[0, 0].set_title(use_gene, size=23)
    draw_status(axarr[0, 0], base_pheno)

    plot_mtypes = {
        MuType({('Gene', use_gene): mtype})
        for mtype in cdata.train_mut[use_gene].branchtypes(min_size=20)
    }
    plot_phenos = rank_by_size(plot_mtypes)

    # fill the panels left over with combinations of two subtypes
    open_panels = 11 - len(plot_mtypes)
    if open_panels > 0:
        comb_mtypes = {
            MuType({('Gene', use_gene): mtype})
            for mtype in cdata.train_mut[use_gene].combtypes(min_type_size=25,
                                                             comb_sizes=(2, ))
        }

        plot_phenos += rank_by_size(comb_mtypes)[:open_panels]

    for panel, (mtype, pheno) in zip(panels[1:], plot_phenos[:11]):
        panel_lbl = repr(mtype).replace(' WITH ', '\n').replace(' OR ', '\n')
        panel.set_title(panel_lbl, size=16)
        draw_status(panel, pheno)

    # shared axis labels for the whole grid
    fig.text(0.5,
             0.02,
             'Component {}'.format(pca_comps[0] + 1),
             size=24,
             weight='semibold',
             ha='center')
    fig.text(0.02,
             0.5,
             'Component {}'.format(pca_comps[1] + 1),
             size=24,
             weight='semibold',
             va='center',
             rotation='vertical')

    out_name = "{}_clustering_comps_{}-{}__{}_{}__levels__{}.png".format(
        args.cohort, pca_comps[0], pca_comps[1], use_gene, args.transform,
        args.mut_levels)
    fig.savefig(os.path.join(plot_dir, out_name),
                dpi=250,
                bbox_inches='tight')

    plt.close()
Esempio n. 21
0
def main():
    """Tune, fit and run a Stan mutation classifier for one gene and cohort.

    The classifier's tuned parameters (if it has any tuning priors), its
    inferred scores and its variable means are pickled to
    ``out__cv-<cv_id>.p`` in the output directory of the given model,
    solve method, cohort and gene.

    Raises:
        ValueError: if ``solve_method`` is not one of ``optim``, ``variat``
            or ``sampl``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('model_name', type=str,
                        help='the name of a Stan model')
    parser.add_argument(
        'solve_method', type=str,
        help='the method used for optimizing the parameters of the Stan model'
        )

    parser.add_argument('cohort', type=str, help='a TCGA cohort')
    parser.add_argument('gene', type=str, help='a gene with mutated samples')

    parser.add_argument('cv_id', type=int,
                        help='a random seed used for cross-validation')
    parser.add_argument('--verbose', '-v', action='store_true',
                        help='turns on diagnostic messages')

    args = parser.parse_args()
    out_path = os.path.join(base_dir, 'output', args.model_name,
                            args.solve_method, args.cohort, args.gene)

    if args.verbose:
        print("Starting distribution testing for Stan model {} using "
              "optimization method {} on mutated gene {} in TCGA cohort {} "
              "for cross-validation ID {} ...".format(
                  args.model_name, args.solve_method,
                  args.cohort, args.gene, args.cv_id
                ))

    use_module = import_module('HetMan.experiments.stan_test'
                               '.distr.models.{}'.format(args.model_name))

    # name of the model module's solver class for each solve method
    solver_names = {
        'optim': 'UseOptimizing',
        'variat': 'UseVariational',
        'sampl': 'UseSampling',
        }

    if args.solve_method not in solver_names:
        raise ValueError("Unrecognized <solve_method> argument!")

    # make sure the output can be saved before any fitting is done
    os.makedirs(out_path, exist_ok=True)

    use_solver = getattr(use_module, solver_names[args.solve_method])
    clf_stan = use_module.UsePipe(use_solver(model_code=use_module.use_model))

    # a gene given as <gene>_<key> selects one of the subtypes in
    # `mtype_list`, which is expected to be defined at module level
    if '_' in args.gene:
        mut_info = args.gene.split('_')
        use_mtype = MuType({('Gene', mut_info[0]): mtype_list[mut_info[1]]})

    else:
        use_mtype = MuType({('Gene', args.gene): None})

    # NOTE(review): this lookup replaces the pipeline built above for the
    # requested solve method, so <solve_method> only affects the output path;
    # confirm which of the two classifiers is meant to be used. It is a plain
    # lookup rather than an `eval` so that the command-line argument is never
    # executed as code. `model_dict` is expected to be defined at module level.
    clf_stan = model_dict[args.model_name]

    # NOTE(review): `syn` is not created in this function and is expected to
    # be defined at module level -- confirm.
    cdata = MutationCohort(
        cohort=args.cohort, mut_genes=[args.gene], mut_levels=['Gene'],
        expr_source='Firehose', expr_dir=firehose_dir, var_source='mc3',
        syn=syn, cv_prop=1.0, cv_seed=1298 + 93 * args.cv_id
        )

    clf_stan.tune_coh(cdata, use_mtype, exclude_genes={args.gene},
                      tune_splits=4, test_count=24, parallel_jobs=12)
    clf_stan.fit_coh(cdata, use_mtype, exclude_genes={args.gene})

    # there are no tuned parameters to report without tuning priors
    if clf_stan.tune_priors:
        clf_params = clf_stan.get_params()
    else:
        clf_params = None

    infer_mat = clf_stan.infer_coh(
        cdata, use_mtype, exclude_genes={args.gene},
        infer_splits=12, infer_folds=4, parallel_jobs=12
        )

    # close the output file once the results have been written
    out_file = os.path.join(out_path, 'out__cv-{}.p'.format(args.cv_id))
    with open(out_file, 'wb') as fl:
        pickle.dump(
            {'Params': clf_params, 'Infer': infer_mat,
             'Vars': clf_stan.named_steps['fit'].get_var_means()},
            fl
            )
Esempio n. 22
0
def main():
    """Runs the experiment.

    Enumerates the sub-types of two mutated genes in a TCGA cohort, keeps
    every cross-gene pair of sub-types in which each member has at least
    `samp_cutoff` samples not carrying the other member, and saves the
    retained pairs, their mutual exclusivity test results, the cohort's
    samples, and the number of pairs to the setup directory for the cohort
    and gene pair.
    """

    parser = argparse.ArgumentParser(
        description='Set up touring for sub-types to detect.')

    parser.add_argument('cohort', type=str, help="which TCGA cohort to use")
    parser.add_argument('gene1', type=str, help="which gene to consider")
    parser.add_argument('gene2', type=str, help="which gene to consider")

    parser.add_argument(
        'mut_levels',
        type=str,
        help='the mutation property levels to consider, in addition to `Gene`')

    parser.add_argument('--samp_cutoff',
                        type=int,
                        default=20,
                        help='subtype sample frequency threshold')
    parser.add_argument('--verbose',
                        '-v',
                        action='store_true',
                        help='turns on diagnostic messages')

    # parse the command line arguments, get the directory where found sub-types
    # will be saved for future use
    args = parser.parse_args()
    out_path = os.path.join(base_dir, 'setup', args.cohort,
                            '{}_{}'.format(args.gene1, args.gene2))

    os.makedirs(out_path, exist_ok=True)
    use_lvls = args.mut_levels.split('__')

    # label shared by the names of all the output files specific to this
    # combination of sample cutoff and mutation levels
    file_lbl = 'samps_{}__levels_{}'.format(args.samp_cutoff,
                                            args.mut_levels)

    # log into Synapse using locally stored credentials
    syn = synapseclient.Synapse()
    syn.cache.cache_root_dir = ("/home/exacloud/lustre1/CompBio/"
                                "mgrzad/input-data/synapse")
    syn.login()

    cdata = MutationCohort(cohort=args.cohort,
                           mut_genes=[args.gene1, args.gene2],
                           mut_levels=['Gene'] + use_lvls,
                           expr_source='Firehose',
                           var_source='mc3',
                           expr_dir=firehose_dir,
                           cv_prop=1.0,
                           syn=syn)

    # the same search settings are used for the sub-types of both genes
    search_args = dict(max_types=40,
                       max_combs=50,
                       verbose=2,
                       sub_levels=use_lvls,
                       min_type_size=args.samp_cutoff)

    cross_mtypes1 = cdata.train_mut[args.gene1].find_unique_subtypes(
        **search_args)
    cross_mtypes2 = cdata.train_mut[args.gene2].find_unique_subtypes(
        **search_args)

    if args.verbose:
        print("Found {} sub-types of {} and {} sub-types of {} "
              "to cross!".format(len(cross_mtypes1), args.gene1,
                                 len(cross_mtypes2), args.gene2))

    # nest each found sub-type under its parent gene
    cross_mtypes1 = {
        MuType({('Gene', args.gene1): mtype})
        for mtype in cross_mtypes1
    }
    cross_mtypes2 = {
        MuType({('Gene', args.gene2): mtype})
        for mtype in cross_mtypes2
    }

    samps1 = {
        mtype: mtype.get_samples(cdata.train_mut)
        for mtype in cross_mtypes1
    }
    samps2 = {
        mtype: mtype.get_samples(cdata.train_mut)
        for mtype in cross_mtypes2
    }

    # keep the pairs where each sub-type has enough samples that do not
    # carry the other sub-type
    use_pairs = sorted(
        (mtype1, mtype2)
        for mtype1, mtype2 in product(cross_mtypes1, cross_mtypes2)
        if (len(samps1[mtype1] - samps2[mtype2]) >= args.samp_cutoff
            and len(samps2[mtype2] - samps1[mtype1]) >= args.samp_cutoff))

    if args.verbose:
        print("\nSaving {} pairs with sufficient "
              "exclusivity...".format(len(use_pairs)))

    # every output file is opened in a context manager so that its handle is
    # flushed and closed even if serialization fails part-way through
    with open(os.path.join(out_path,
                           'pairs_list__{}.p'.format(file_lbl)),
              'wb') as fl:
        pickle.dump(use_pairs, fl)

    # run the tests before opening the output file so that a failure here
    # does not leave an empty file behind
    mutex_dict = {(mtype1, mtype2): cdata.mutex_test(mtype1, mtype2)
                  for mtype1, mtype2 in use_pairs}

    with open(os.path.join(out_path,
                           'pairs_mutex__{}.p'.format(file_lbl)),
              'wb') as fl:
        pickle.dump(mutex_dict, fl)

    with open(os.path.join(out_path, 'cohort_info.p'), 'wb') as fl:
        pickle.dump({'Samps': cdata.samples}, fl)

    with open(os.path.join(out_path,
                           'pairs_count__{}.txt'.format(file_lbl)),
              'w') as fl:
        fl.write(str(len(use_pairs)))
Esempio n. 23
0
def main():
    """Enumerates the pairs of mutated genes to test in a TCGA cohort.

    A pair of genes is kept if each gene is mutated in at least
    `samp_cutoff` samples that do not carry a mutation of the other gene,
    and at least `samp_cutoff` samples in the cohort carry a mutation of
    neither gene. The sorted list of pairs and the number of pairs are saved
    to the cohort's setup directory.
    """

    # the description has to be passed by keyword: the first positional
    # argument of ArgumentParser is `prog`, the program name shown in usage
    parser = argparse.ArgumentParser(
        description=(
            "Set up the paired gene expression effect isolation experiment "
            "by enumerating the dyads of genes to be tested."
            )
        )

    parser.add_argument('cohort', type=str, help="which TCGA cohort to use")
    parser.add_argument('--samp_cutoff',
                        type=int,
                        default=40,
                        help='subtype sample frequency threshold')
    parser.add_argument('--verbose',
                        '-v',
                        action='store_true',
                        help='turns on diagnostic messages')

    # parse command line arguments, create directory where found pairs
    # will be stored
    args = parser.parse_args()
    out_path = os.path.join(base_dir, 'setup', args.cohort)
    os.makedirs(out_path, exist_ok=True)

    # log into Synapse using locally stored credentials
    syn = synapseclient.Synapse()
    syn.cache.cache_root_dir = ("/home/exacloud/lustre1/CompBio/"
                                "mgrzad/input-data/synapse")
    syn.login()

    cdata = MutationCohort(cohort=args.cohort,
                           mut_genes=None,
                           mut_levels=['Gene'],
                           expr_source='Firehose',
                           var_source='mc3',
                           expr_dir=firehose_dir,
                           samp_cutoff=args.samp_cutoff,
                           cv_prop=1.0,
                           syn=syn)

    if args.verbose:
        print("Looking for pairs of mutated genes present in at least {} of "
              "the samples in TCGA cohort {} with {} total samples.".format(
                  args.samp_cutoff, args.cohort, len(cdata.samples)))

    # the largest number of samples that can carry either mutation while
    # still leaving enough samples with neither
    max_union = len(cdata.samples) - args.samp_cutoff

    gene_pairs = {
        (MuType({('Gene', gn1): None}), MuType({('Gene', gn2): None}))
        for (gn1, muts1), (gn2, muts2) in combn(cdata.train_mut, r=2)
        if (len(muts1 - muts2) >= args.samp_cutoff
            and len(muts2 - muts1) >= args.samp_cutoff
            and len(muts1 | muts2) <= max_union)
    }

    if args.verbose:
        print("Found {} pairs of genes to isolate!".format(len(gene_pairs)))

    # sort before opening the output file, and write through a context
    # manager so that the file handle is always closed
    pairs_list = sorted(gene_pairs)
    with open(os.path.join(out_path,
                           'pairs_list__samps_{}.p'.format(
                               args.samp_cutoff)),
              'wb') as fl:
        pickle.dump(pairs_list, fl)

    with open(os.path.join(out_path,
                           'pairs_count__samps_{}.txt'.format(
                               args.samp_cutoff)),
              'w') as fl:
        fl.write(str(len(gene_pairs)))
Esempio n. 24
0
def main():
    """Runs the experiment.

    Loads a list of mutation sub-types from a pickle file and, for each
    sub-type assigned to this task, tunes a classifier and infers mutation
    scores for it. Samples that carry a mutation of the genes isolated
    against but not the sub-type itself are excluded from tuning and forced
    into the testing set during inference. The inferred scores and the
    tuning settings are saved to `out__task-<task_id>.p` in the output
    directory.

    Raises:
        ValueError: if no genes to isolate against are given and the listed
            sub-types do not all have <Gene> as their top level.
    """

    # the description has to be passed by keyword: the first positional
    # argument of ArgumentParser is `prog`, the program name shown in usage
    parser = argparse.ArgumentParser(
        description=(
            "Isolate the expression signature of mutation subtypes from "
            "their parent gene(s)' signature or that of a list of genes in "
            "a given TCGA cohort."
            )
        )

    # positional command line arguments for where input data and output
    # data is to be stored
    parser.add_argument('mtype_file',
                        type=str,
                        help='the pickle file where sub-types are stored')
    parser.add_argument('out_dir',
                        type=str,
                        help='where to save the output of testing sub-types')

    # positional arguments for which cohort of samples and which mutation
    # classifier to use for testing
    parser.add_argument('cohort', type=str, help='a TCGA cohort')
    parser.add_argument('classif',
                        type=str,
                        help='a classifier in HetMan.predict.classifiers')

    parser.add_argument('--use_genes',
                        type=str,
                        default=None,
                        nargs='+',
                        help='specify which gene(s) to isolate against')

    parser.add_argument(
        '--cv_id',
        type=int,
        default=6732,
        help='the random seed to use for cross-validation draws')

    parser.add_argument(
        '--task_count',
        type=int,
        default=10,
        help='how many parallel tasks the list of types to test is split into')
    parser.add_argument('--task_id',
                        type=int,
                        default=0,
                        help='the subset of subtypes to assign to this task')

    # optional arguments controlling how classifier tuning is to be performed
    parser.add_argument('--tune_splits',
                        type=int,
                        default=4,
                        help='how many cohort splits to use for tuning')
    parser.add_argument(
        '--test_count',
        type=int,
        default=16,
        help='how many hyper-parameter values to test in each tuning split')

    parser.add_argument(
        '--infer_splits',
        type=int,
        default=20,
        help='how many cohort splits to use for inference bootstrapping')
    parser.add_argument(
        '--infer_folds',
        type=int,
        default=4,
        help=('how many parts to split the cohort into in each inference '
              'cross-validation run'))

    parser.add_argument(
        '--parallel_jobs',
        type=int,
        default=4,
        help='how many parallel CPUs to allocate the tuning tests across')

    parser.add_argument('--verbose',
                        '-v',
                        action='store_true',
                        help='turns on diagnostic messages')

    args = parser.parse_args()
    out_file = os.path.join(args.out_dir,
                            'out__task-{}.p'.format(args.task_id))

    if args.verbose:
        print("Starting isolation for sub-types in\n{}\nthe results of "
              "which will be stored in\n{}\nwith classifier <{}>.".format(
                  args.mtype_file, args.out_dir, args.classif))

    # NOTE: unpickling executes arbitrary code, so <mtype_file> must come
    # from a trusted source
    with open(args.mtype_file, 'rb') as fl:
        mtype_list = pickle.load(fl)

    # collect the mutation annotation levels used by any of the sub-types,
    # keeping the order in which they are first seen
    use_lvls = []
    for lvls in reduce(or_,
                       [{mtype.get_sorted_levels()} for mtype in mtype_list]):
        for lvl in lvls:
            if lvl not in use_lvls:
                use_lvls.append(lvl)

    # if no genes were given, use the genes that the sub-types belong to
    if args.use_genes is None:
        if set(mtype.cur_level for mtype in mtype_list) == {'Gene'}:
            use_genes = reduce(or_, [
                set(gn for gn, _ in mtype.subtype_list())
                for mtype in mtype_list
            ])

        else:
            raise ValueError(
                "A gene to isolate against must be given or the subtypes "
                "listed must have <Gene> as their top level!")

    else:
        use_genes = set(args.use_genes)

    if args.verbose:
        print("Subtypes at mutation annotation levels {} will be isolated "
              "against genes:\n{}".format(use_lvls, use_genes))

    if args.classif.startswith('Stan__'):
        use_module = import_module('HetMan.experiments.utilities'
                                   '.stan_models.{}'.format(
                                       args.classif.split('Stan__')[1]))
        mut_clf = getattr(use_module, 'UsePipe')

    else:
        # NOTE: eval runs whatever expression is given on the command line;
        # only use this script with trusted <classif> arguments
        mut_clf = eval(args.classif)

    # log into Synapse using locally stored credentials
    syn = synapseclient.Synapse()
    syn.cache.cache_root_dir = ("/home/exacloud/lustre1/CompBio/"
                                "mgrzad/input-data/synapse")
    syn.login()

    # loads the expression data and gene mutation data for the given TCGA
    # cohort, with the training/testing cohort split defined by the
    # cross-validation ID for this task
    cdata = MutationCohort(cohort=args.cohort,
                           mut_genes=list(use_genes),
                           mut_levels=use_lvls,
                           expr_source='Firehose',
                           expr_dir=firehose_dir,
                           syn=syn,
                           cv_seed=args.cv_id,
                           cv_prop=1.0)

    if args.verbose:
        print("Loaded {} subtypes of which roughly {} will be isolated in "
              "cohort {} with {} samples.".format(
                  len(mtype_list),
                  len(mtype_list) // args.task_count, args.cohort,
                  len(cdata.samples)))

    base_mtype = MuType({('Gene', tuple(use_genes)): None})
    base_samps = base_mtype.get_samples(cdata.train_mut)

    # results are only stored for the sub-types assigned to this task; this
    # avoids deleting entries from a pre-filled dictionary, which fails (or
    # throws away a computed result) if the list contains a duplicate
    out_iso = dict()

    # for each subtype, check if it has been assigned to this task
    for i, mtype in enumerate(mtype_list):
        if (i % args.task_count) != args.task_id:
            continue

        if args.verbose:
            print("Isolating {} ...".format(mtype))

        clf = mut_clf()

        # samples with a mutation of the isolated genes other than the
        # current sub-type
        ex_samps = base_samps - mtype.get_samples(cdata.train_mut)

        clf.tune_coh(cdata,
                     mtype,
                     exclude_genes=use_genes,
                     exclude_samps=ex_samps,
                     tune_splits=args.tune_splits,
                     test_count=args.test_count,
                     parallel_jobs=args.parallel_jobs)

        out_iso[mtype] = clf.infer_coh(cdata,
                                       mtype,
                                       exclude_genes=use_genes,
                                       force_test_samps=ex_samps,
                                       infer_splits=args.infer_splits,
                                       infer_folds=args.infer_folds,
                                       parallel_jobs=args.parallel_jobs)

    out_data = {
        'Infer': out_iso,
        'Info': {
            'TunePriors': mut_clf.tune_priors,
            'TuneSplits': args.tune_splits,
            'TestCount': args.test_count,
            'InferFolds': args.infer_folds
        }
    }

    # write through a context manager so the output file is always closed
    with open(out_file, 'wb') as fl:
        pickle.dump(out_data, fl)
Esempio n. 25
0
def plot_position(infer_vals, args, cdata, mtype1, mtype2):
    """Plots the joint density of the scores inferred for two sub-types.

    The samples are divided into four groups -- those without a mutation of
    `args.gene`, those carrying `mtype2`, those carrying `mtype1`, and the
    remaining mutants of the gene -- and a two-dimensional kernel density
    estimate of each group's inferred scores is drawn on a shared pair of
    axes. The plot is annotated with an AUC for each sub-type measured
    against the samples without a mutation of the gene and saved as a .png
    file under `plot_dir`.

    Args:
        infer_vals: the inferred scores, indexed by the boolean sample
            phenotypes returned by `cdata.train_pheno`.
        args: parsed arguments; `gene`, `cohort`, `mut_levels`, `model_name`
            and `solve_method` are read.
        cdata: the cohort supplying `train_pheno`.
        mtype1: the sub-type whose scores go on the x-axis.
        mtype2: the sub-type whose scores go on the y-axis.
    """
    fig, ax = plt.subplots(figsize=(15, 14))

    # mutation status of the samples for the gene and for each sub-type
    gene_stat = np.array(
        cdata.train_pheno(MuType({('Gene', args.gene): None})))
    stat1 = np.array(cdata.train_pheno(mtype1))
    stat2 = np.array(cdata.train_pheno(mtype2))

    # inferred values of each of the four groups of samples
    wt_vals = np.concatenate(infer_vals[~gene_stat].values)
    mtype2_vals = np.concatenate(infer_vals[stat2].values)
    mtype1_vals = np.concatenate(infer_vals[stat1].values)
    rest_vals = np.concatenate(
        infer_vals[gene_stat & (~stat1 & ~stat2)].values)

    # how often each sub-type's samples score higher than wild-type samples
    # in the sub-type's own dimension
    auc_mtype1 = np.greater.outer(mtype1_vals[:, 0], wt_vals[:, 0]).mean()
    auc_mtype2 = np.greater.outer(mtype2_vals[:, 1], wt_vals[:, 1]).mean()

    # symmetric limit used for both axes
    low_qnt = infer_vals.apply(np.min).quantile(0.01)
    high_qnt = infer_vals.apply(np.max).quantile(0.99)
    plt_bound = np.max(np.absolute([low_qnt, high_qnt, 1.1]))

    wt_clr, mtype2_clr, mtype1_clr, rest_clr = (
        '0.5', '#9B5500', '#044063', '#5C0165')

    # (values, colour, line width, alpha, contour levels, shading)
    plot_groups = [
        (wt_vals, wt_clr, 2.1, 0.43, 19, False),
        (mtype2_vals, mtype2_clr, None, 0.55, 5, True),
        (mtype1_vals, mtype1_clr, None, 0.55, 5, True),
        (rest_vals, rest_clr, 2.4, 0.65, 14, False),
        ]

    for grp_vals, grp_clr, grp_lw, grp_alpha, grp_lvls, grp_shd in plot_groups:
        sns.kdeplot(grp_vals[:, 0], grp_vals[:, 1],
                    cmap=sns.light_palette(grp_clr, as_cmap=True),
                    shade=grp_shd, alpha=grp_alpha, shade_lowest=False,
                    linewidths=grp_lw, bw=plt_bound / 37, gridsize=250,
                    n_levels=grp_lvls)

    # annotate the plot with the AUC of each sub-type
    ax.text(plt_bound * -0.86, plt_bound * 0.27,
            "AUC: {:.3f}".format(auc_mtype2),
            size=24, color=mtype2_clr, alpha=0.76)
    ax.text(plt_bound * 0.51, plt_bound * -0.53,
            "AUC: {:.3f}".format(auc_mtype1),
            size=24, color=mtype1_clr, alpha=0.76)

    lgnd_lines = [Line2D([0], [0], color=clr, lw=9.2)
                  for clr in (wt_clr, mtype2_clr, mtype1_clr, rest_clr)]
    lgnd_lbls = ["Wild-Type", str(mtype2), str(mtype1),
                 "Remaining {} Mutants".format(args.gene)]
    plt.legend(lgnd_lines, lgnd_lbls, fontsize=21, loc=4)

    plt.xlim(-plt_bound, plt_bound)
    plt.ylim(-plt_bound, plt_bound)

    plt.xlabel('Inferred {} Score'.format(mtype1), size=28, weight='semibold')
    plt.ylabel('Inferred {} Score'.format(mtype2), size=28, weight='semibold')

    out_fl = "{}__xx__{}__{}__{}.png".format(
        mtype1.get_label(), mtype2.get_label(),
        args.model_name, args.solve_method
        )
    plt.savefig(
        os.path.join(plot_dir, args.cohort, args.gene, args.mut_levels,
                     out_fl),
        dpi=300, bbox_inches='tight'
        )

    plt.close()
Esempio n. 26
0
def main():
    """Tunes a classifier for a gene's mutations and infers mutation scores.

    The cohort is loaded with `CancerCohort`, which is given both a
    directory of toil expression data and a directory of patient RNAseq
    abundances. Samples whose names do not begin with `TCGA` are excluded
    from classifier tuning and forced into the testing set during
    inference. The inferred scores and the tuning settings are saved to
    `<classif>__cv-<cv_id>.p` in the output directory for the cohort and
    gene.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('cohort', type=str, help='a TCGA cohort')
    parser.add_argument('gene', type=str, help='a mutated gene')
    parser.add_argument('classif', type=str,
                        help='the mutation classifier to use')

    parser.add_argument(
        'toil_dir',
        type=str,
        help='the directory where toil expression data is saved')
    parser.add_argument('syn_root',
                        type=str,
                        help='Synapse cache root directory')
    parser.add_argument(
        'patient_dir',
        type=str,
        help='directory where SMMART patient RNAseq abundances are stored')

    parser.add_argument(
        '--tune_splits',
        type=int,
        default=4,
        help='how many training cohort splits to use for tuning')
    parser.add_argument(
        '--test_count',
        type=int,
        default=16,
        help='how many hyper-parameter values to test in each tuning split')

    parser.add_argument(
        '--infer_splits',
        type=int,
        default=20,
        help='how many cohort splits to use for inference bootstrapping')
    parser.add_argument(
        '--infer_folds',
        type=int,
        default=4,
        help=('how many parts to split the cohort into in each inference '
              'cross-validation run'))

    parser.add_argument(
        '--parallel_jobs',
        type=int,
        default=4,
        help='how many parallel CPUs to allocate the tuning tests across')

    parser.add_argument('--cv_id', type=int, default=0)
    parser.add_argument('--verbose',
                        '-v',
                        action='store_true',
                        help='turns on diagnostic messages')

    args = parser.parse_args()
    out_dir = os.path.join(base_dir, 'output', 'gene_models', args.cohort,
                           args.gene)
    os.makedirs(out_dir, exist_ok=True)
    out_file = os.path.join(out_dir,
                            '{}__cv-{}.p'.format(args.classif, args.cv_id))

    if args.classif.startswith('Stan__'):
        use_module = import_module('HetMan.experiments.utilities'
                                   '.stan_models.{}'.format(
                                       args.classif.split('Stan__')[1]))
        mut_clf = getattr(use_module, 'UsePipe')

    else:
        # NOTE: eval runs whatever expression is given on the command line;
        # only use this script with trusted <classif> arguments
        mut_clf = eval(args.classif)

    base_mtype = MuType({('Gene', args.gene): None})
    clf = mut_clf()

    # log into Synapse using locally stored credentials
    syn = synapseclient.Synapse()
    syn.cache.cache_root_dir = args.syn_root
    syn.login()

    cdata = CancerCohort(cancer=args.cohort,
                         mut_genes=[args.gene],
                         mut_levels=['Gene'],
                         tcga_dir=args.toil_dir,
                         patient_dir=args.patient_dir,
                         syn=syn,
                         collapse_txs=True,
                         cv_seed=(args.cv_id * 59) + 121,
                         cv_prop=1.0)

    # samples that did not come from TCGA are kept out of classifier tuning
    # and are always treated as testing samples during inference
    smrt_samps = {samp for samp in cdata.samples if samp[:4] != 'TCGA'}

    clf.tune_coh(cdata,
                 base_mtype,
                 exclude_genes={args.gene},
                 exclude_samps=smrt_samps,
                 tune_splits=args.tune_splits,
                 test_count=args.test_count,
                 parallel_jobs=args.parallel_jobs)

    # record the values chosen for each of the tuned hyper-parameters
    clf_params = clf.get_params()
    tuned_params = {par: clf_params[par] for par, _ in mut_clf.tune_priors}

    infer_mat = clf.infer_coh(cdata,
                              base_mtype,
                              force_test_samps=smrt_samps,
                              exclude_genes={args.gene},
                              infer_splits=args.infer_splits,
                              infer_folds=args.infer_folds)

    out_data = {
        'Infer': infer_mat,
        'Info': {
            'TunePriors': mut_clf.tune_priors,
            'TuneSplits': args.tune_splits,
            'TestCount': args.test_count,
            'TunedParams': tuned_params
        }
    }

    # write through a context manager so the output file is always closed
    with open(out_file, 'wb') as fl:
        pickle.dump(out_data, fl)
Esempio n. 27
0
def plot_label_distribution(infer_vals, args, cdata):
    """Plots inferred mutation scores of TCGA samples and SMMART patients.

    The mean inferred score of each sample is calculated, and the means of
    the samples with `TCGA` in their name are drawn as two vertical kernel
    density estimates, split by mutation status. Each remaining sample is
    drawn as a labelled horizontal dashed line whose colour depends on
    whether a mutation of `args.gene` was found for it. The plot is
    annotated with the AUPR and AUC of the TCGA scores and saved as a .png
    file in `plot_dir`.

    Args:
        infer_vals: the inferred scores, with one row per sample; each row's
            entries are concatenated before being averaged. Rows are
            presumably in the order given by `cdata.subset_samps()` --
            TODO confirm.
        args: parsed arguments; `gene`, `cohort`, `classif` and
            `patient_dir` are read.
        cdata: the cohort supplying `subset_samps` and `train_mut`.
    """
    fig, ax = plt.subplots(figsize=(7, 14))

    # mean score of each sample across all of the values inferred for it
    samp_list = cdata.subset_samps()
    infer_means = np.apply_along_axis(
        lambda x: np.mean(np.concatenate(x)), 1, infer_vals)

    # split the mean scores according to whether the sample is from TCGA
    tcga_means = pd.Series(
        {samp: val for samp, val in zip(samp_list, infer_means)
         if 'TCGA' in samp}
        )

    # the remaining samples are sorted by score so that the labels of
    # neighbouring samples can be nudged apart further below
    smrt_means = sorted(
        [(samp, val) for samp, val in zip(samp_list, infer_means)
         if 'TCGA' not in samp],
        key=itemgetter(1)
        )

    # use a y-axis starting at zero if no scores are negative, and an axis
    # symmetric about zero otherwise
    if np.all(infer_means >= 0):
        plt_ymin, plt_ymax = 0, max(np.max(infer_means) * 1.09, 1)

    else:
        plt_ymax = np.max([np.max(np.absolute(infer_means)) * 1.09, 1.1])
        plt_ymin = -plt_ymax

    # the x-axis limits are read here, before anything has been drawn
    plt.ylim(plt_ymin, plt_ymax)
    plt_xmin, plt_xmax = plt.xlim()
    lbl_pad = (plt_ymax - plt_ymin) / 79

    # NOTE(review): use_mtype is created but never used below, and status()
    # is called without it -- confirm whether it was meant to be passed
    use_mtype = MuType({('Gene', args.gene): None})
    mtype_stat = np.array(cdata.train_mut.status(tcga_means.index))
    kern_bw = (plt_ymax - plt_ymin) / 47

    # densities of the TCGA scores, drawn separately by mutation status
    ax = sns.kdeplot(tcga_means[~mtype_stat], color=wt_clr, vertical=True,
                     shade=True, alpha=0.36, linewidth=0, bw=kern_bw, cut=0,
                     gridsize=1000, label='Wild-Type')

    ax = sns.kdeplot(tcga_means[mtype_stat], color=mut_clrs[0], vertical=True,
                     shade=True, alpha=0.36, linewidth=0, bw=kern_bw, cut=0,
                     gridsize=1000, label='{} Mutant'.format(args.gene))

    # for each SMMART patient, check if they have a mutation of the given gene
    for i, (patient, val) in enumerate(smrt_means):
        if patient in cdata.train_mut.get_samples():

            # list the labels of the gene's mutation branches that include
            # this sample
            mut_list = []
            for lbl, muts in cdata.train_mut[args.gene]:
                if patient in muts:
                    mut_list += [lbl]

            plt_str = '{} ({})'.format(patient, '+'.join(mut_list))
            plt_clr = mut_clrs[1]
            plt_lw = 3.1

        # if the patient's RNAseq sample did not have any mutations, check all
        # the samples associated with the patient
        else:
            mut_dir = os.path.join(
                args.patient_dir, "16113-{}".format(patient.split(' ---')[0]),
                'output', 'cancer_exome'
                )

            # check if any mutation calling was done for this patient
            # NOTE(review): this command and the grep below are built by
            # string formatting and run through the shell, so the directory
            # and gene names must not contain shell metacharacters
            mut_files = subprocess.run(
                'find {}'.format(mut_dir), shell=True,
                stdout=subprocess.PIPE, stderr=subprocess.PIPE
                ).stdout.decode('utf-8')

            # if calling was done for any sample associated with this patient,
            # check for mutations of the given gene
            if mut_files:
                mut_grep = subprocess.run(
                    'grep "^{}" {}'.format(
                        args.gene, os.path.join(
                            mut_dir, "*SMMART_Cancer_Exome*", "*.maf")
                        ),
                    shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
                    ).stdout.decode('utf-8')

                if mut_grep:
                    # the ninth tab-separated field of each matching line is
                    # used as the mutation's label
                    mut_list = []
                    for mut_match in mut_grep.split('\n'):
                        if mut_match:
                            mut_list += [mut_match.split('\t')[8]]

                    plt_str = '{} ({})'.format(
                        patient, '+'.join(np.unique(mut_list)))
                    plt_clr = mut_clrs[3]
                    plt_lw = 2.6

                # calls were made but none matched the gene
                else:
                    plt_str = '{}'.format(patient)
                    plt_clr = wt_clr
                    plt_lw = 1.7

            # nothing was found in the patient's mutation call directory
            else:
                plt_str = '{}'.format(patient)
                plt_clr = '#2E6AF3'
                plt_lw = 1.7

        ax.axhline(y=val, xmin=0, xmax=plt_xmax * 0.22,
                   c=plt_clr, ls='--', lw=plt_lw)

        # shift the label's vertical alignment when the sample before or
        # after it in sorted order has a score within lbl_pad of this one
        if i > 0 and smrt_means[i - 1][1] > (val - lbl_pad):
            txt_va = 'bottom'

        elif (i < (len(smrt_means) - 1)
              and smrt_means[i + 1][1] < (val + lbl_pad)):
            txt_va = 'top'

        else:
            txt_va = 'center'

        ax.text(plt_xmax * 0.32, val, plt_str, size=9, ha='left', va=txt_va)

    # calculate the accuracy of the mutation scores inferred across
    # validation runs in predicting mutation status
    # (tcga_f1 holds an average precision score despite its name)
    tcga_f1 = average_precision_score(mtype_stat, tcga_means)
    tcga_auc = np.greater.outer(tcga_means[mtype_stat],
                                tcga_means[~mtype_stat]).mean()

    # add annotation about the mutation scores' accuracy to the plot
    ax.text(ax.get_xlim()[1] * 0.91, plt_ymax * 0.82, size=18, ha='right',
            s="TCGA AUPR:{:8.3f}".format(tcga_f1))
    ax.text(ax.get_xlim()[1] * 0.91, plt_ymax * 0.88, size=18, ha='right',
            s="TCGA AUC:{:8.3f}".format(tcga_auc))

    plt.xlabel('TCGA-{} Density'.format(args.cohort),
               fontsize=21, weight='semibold')
    plt.ylabel('Inferred {} Mutation Score'.format(args.gene),
               fontsize=21, weight='semibold')

    # legend entries are listed in the same order as their labels below
    plt.legend([Line2D([0], [0], color=mut_clrs[1], lw=3.7, ls='--'),
                Line2D([0], [0], color=mut_clrs[3], lw=3.7, ls='--'),
                Patch(color=mut_clrs[0], alpha=0.36),
                Line2D([0], [0], color=wt_clr, lw=3.7, ls='--'),
                Line2D([0], [0], color='#2E6AF3', lw=3.7, ls='--'),
                Patch(color=wt_clr, alpha=0.36)],
               ["Sample {} Mutant".format(args.gene),
                "Patient {} Mutant".format(args.gene),
                "TCGA {} Mutants".format(args.gene), "SMMART Wild-Type",
                "No Mutation Calls", "TCGA Wild-Types"],
               fontsize=13, loc=8, ncol=2)

    fig.savefig(
        os.path.join(plot_dir,
                     'labels__{}-{}-{}.png'.format(
                         args.cohort, args.gene, args.classif)),
        dpi=300, bbox_inches='tight'
        )

    plt.close()