Esempio n. 1
0
def plot_aupr_comparisons(pred_df, pheno_dict, auc_df, conf_vals, args):
    """Compare AUPRs of gene-wide predictions to best-subgrouping predictions.

    For every gene with more than one task in `auc_df['mean']`, the gene-wide
    task (`MuType({('Gene', gene): None})`) is paired with the gene's other
    task that has the highest mean AUC. Pairs whose best subgrouping has an
    AUC above 0.6 are drawn as small pie charts on two panels: the left panel
    scores predictions against the gene-wide phenotype, the right panel
    against the best subgrouping's phenotype. On both panels the x-coordinate
    is the AUPR of the gene-wide task's predictions and the y-coordinate is
    the AUPR of the best subgrouping task's predictions. The figure is saved
    as an SVG under `plot_dir`.

    Args:
        pred_df: table indexed by mutation type; every entry of a selected
            row is reduced with `np.mean` to obtain one score per column.
        pheno_dict: maps mutation types to phenotype vectors (presumably
            boolean mutation calls per sample -- confirm with the caller).
        auc_df: table with a 'mean' column indexed by mutation type.
        conf_vals: maps mutation types to the values given to `calc_conf`.
        args: namespace with `expr_source`, `cohort` and `classif`, used
            only to build the output file path.
    """
    fig, (base_ax, subg_ax) = plt.subplots(figsize=(17, 8), nrows=1, ncols=2)

    plot_dicts = {'Base': dict(), 'Subg': dict()}
    line_dicts = {'Base': dict(), 'Subg': dict()}
    plt_max = 0.53

    for gene, auc_vec in auc_df['mean'].groupby(get_label):
        # a comparison needs the gene-wide task and at least one subgrouping
        if len(auc_vec) <= 1:
            continue

        base_mtype = MuType({('Gene', gene): None})
        base_indx = auc_vec.index.get_loc(base_mtype)

        # select every task except the gene-wide one by position; this
        # replaces `Series.append`, which no longer exists in pandas >= 2.0
        other_pos = np.delete(np.arange(len(auc_vec)), base_indx)
        best_subtype = auc_vec.iloc[other_pos].idxmax()

        # explicit positional lookup instead of relying on the deprecated
        # integer fallback of `Series.__getitem__`
        base_auc = auc_vec.iloc[base_indx]
        best_auc = auc_vec[best_subtype]

        if not best_auc > 0.6:
            continue

        # pie area scales with how often the gene is mutated; the exploded
        # slice shows the share of those mutations in the best subgrouping
        base_size = np.mean(pheno_dict[base_mtype])
        plt_size = 0.07 * base_size ** 0.5
        best_prop = np.mean(pheno_dict[best_subtype]) / base_size

        base_infr = pred_df.loc[base_mtype].apply(np.mean)
        best_infr = pred_df.loc[best_subtype].apply(np.mean)

        # each pair is (AUPR of gene-wide scores, AUPR of subgrouping scores)
        base_auprs = (aupr_score(pheno_dict[base_mtype], base_infr),
                      aupr_score(pheno_dict[base_mtype], best_infr))
        subg_auprs = (aupr_score(pheno_dict[best_subtype], base_infr),
                      aupr_score(pheno_dict[best_subtype], best_infr))

        conf_sc = calc_conf(conf_vals[best_subtype],
                            conf_vals[base_mtype])

        base_lbl = '', ''
        subg_lbl = '', ''
        min_diff = np.log2(1.25)

        mtype_lbl = get_fancy_label(get_subtype(best_subtype),
                                    pnt_link='\nor ', phrase_link=' ')

        # decide how much annotation each point gets: full labels for
        # confident improvements, the gene name for strong classifiers, and
        # otherwise the gene name only where the two AUPRs differ by > 25%
        if conf_sc > 0.9:
            base_lbl = gene, mtype_lbl
            subg_lbl = gene, mtype_lbl

        elif base_auc > 0.75 or best_auc > 0.75:
            base_lbl = gene, ''
            subg_lbl = gene, ''

        elif base_auc > 0.6 or best_auc > 0.6:
            if abs(np.log2(base_auprs[1] / base_auprs[0])) > min_diff:
                base_lbl = gene, ''
            if abs(np.log2(subg_auprs[1] / subg_auprs[0])) > min_diff:
                subg_lbl = gene, ''

        # points are keyed by their AUPR coordinates on each panel
        for lbl, auprs, pnt_lbl in zip(['Base', 'Subg'],
                                       (base_auprs, subg_auprs),
                                       [base_lbl, subg_lbl]):
            plot_dicts[lbl][auprs] = plt_size, pnt_lbl
            line_dicts[lbl][auprs] = dict(c=choose_label_colour(gene))

        for ax, lbl, (base_aupr, subg_aupr) in zip(
                [base_ax, subg_ax], ['Base', 'Subg'],
                [base_auprs, subg_auprs]
                ):
            # grow the plotting area to fit this point, capped just above 1
            plt_max = min(1.005,
                          max(plt_max,
                              base_aupr + 0.11, subg_aupr + 0.11))

            auc_bbox = (base_aupr - plt_size / 2,
                        subg_aupr - plt_size / 2, plt_size, plt_size)

            pie_ax = inset_axes(
                ax, width='100%', height='100%',
                bbox_to_anchor=auc_bbox, bbox_transform=ax.transData,
                axes_kwargs=dict(aspect='equal'), borderpad=0
                )

            # colours are extended with an alpha value, so they are expected
            # to be RGB tuples
            use_clr = line_dicts[lbl][base_aupr, subg_aupr]['c']
            pie_ax.pie(x=[best_prop, 1 - best_prop],
                       colors=[use_clr + (0.77,),
                               use_clr + (0.29,)],
                       explode=[0.29, 0], startangle=90)

    base_ax.set_title("AUPR on all point mutations",
                      size=21, weight='semibold')
    subg_ax.set_title("AUPR on best subgrouping mutations",
                      size=21, weight='semibold')

    for ax, lbl in zip([base_ax, subg_ax], ['Base', 'Subg']):
        ax.grid(alpha=0.41, linewidth=0.9)

        # frame the unit square and draw the diagonal of equal performance
        ax.plot([0, plt_max], [0, 0],
                color='black', linewidth=1.5, alpha=0.89)
        ax.plot([0, 0], [0, plt_max],
                color='black', linewidth=1.5, alpha=0.89)

        ax.plot([0, plt_max], [1, 1],
                color='black', linewidth=1.5, alpha=0.89)
        ax.plot([1, 1], [0, plt_max],
                color='black', linewidth=1.5, alpha=0.89)

        ax.plot([0, plt_max], [0, plt_max],
                color='#550000', linewidth=1.7, linestyle='--', alpha=0.37)

        ax.set_xlabel("using gene-wide task's predictions",
                      size=19, weight='semibold')
        ax.set_ylabel("using best subgrouping task's predicted scores",
                      size=19, weight='semibold')

        # called for its effect on the axis; the returned positions are
        # not needed here
        place_scatter_labels(plot_dicts[lbl], ax,
                             plt_lims=[[plt_max / 67, plt_max]] * 2,
                             line_dict=line_dicts[lbl])

        ax.set_xlim([-plt_max / 181, plt_max])
        ax.set_ylim([-plt_max / 181, plt_max])

    plt.savefig(
        os.path.join(plot_dir, '__'.join([args.expr_source, args.cohort]),
                     "aupr-comparisons_{}.svg".format(args.classif)),
        bbox_inches='tight', format='svg'
        )

    plt.close()
Esempio n. 2
0
def main():
    """Tune, fit and score a classifier on this task's mutation combinations.

    The pickled list of mutation type combinations for the chosen expression
    source, cohort and sample cutoff is split round-robin across
    `--task_count` tasks; only the combinations whose position modulo
    `--task_count` equals `--task_id` are tested. For each of them the
    classifier is tuned and fitted with genes on the same chromosomes as the
    mutated genes excluded, and AUC and AUPR are computed on the training
    and testing samples for each of the two phenotype columns. Results are
    pickled to `out__cv-<cv_id>_task-<task_id>.p` in the output directory.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('expr_source', type=str,
                        choices=list(expr_sources.keys()),
                        help='which TCGA expression data source to use')
    parser.add_argument('cohort', type=str, help="which TCGA cohort to use")

    parser.add_argument(
        'samp_cutoff', type=int,
        help="minimum number of mutated samples needed to test a gene"
        )

    parser.add_argument('classif', type=str,
                        help='the name of a mutation classifier')

    parser.add_argument(
        '--cv_id', type=int, default=6732,
        help='the random seed to use for cross-validation draws'
        )

    parser.add_argument(
        '--task_count', type=int, default=10,
        help='how many parallel tasks the list of types to test is split into'
        )
    parser.add_argument('--task_id', type=int, default=0,
                        help='the subset of subtypes to assign to this task')

    parser.add_argument('--verbose', '-v', action='store_true',
                        help='turns on diagnostic messages')

    # parse command-line arguments, create directory where to save results
    args = parser.parse_args()
    out_path = os.path.join(
        base_dir, 'output', args.expr_source,
        '{}__samps-{}'.format(args.cohort, args.samp_cutoff), args.classif
        )

    # make sure the results can be written before any expensive work starts
    os.makedirs(out_path, exist_ok=True)

    combs_file = os.path.join(
        base_dir, "setup",
        "combs-list_{}__{}__samps-{}.p".format(
            args.expr_source, args.cohort, args.samp_cutoff)
        )
    with open(combs_file, 'rb') as fl:
        comb_list = pickle.load(fl)

    cdata = get_cohort_data(args.expr_source, args.cohort, args.samp_cutoff,
                            cv_prop=0.75, cv_seed=2079 + 57 * args.cv_id)

    # the classifier is named as <module>__<class>, with the class name
    # stored uncapitalized
    clf_info = args.classif.split('__')
    clf_module = import_module(
        'HetMan.experiments.multi_baseline.models.{}'.format(clf_info[0]))
    mut_clf = getattr(clf_module, clf_info[1].capitalize())

    # only keep output slots for the combinations assigned to this task
    task_combs = [mtypes for i, mtypes in enumerate(comb_list)
                  if (i % args.task_count) == args.task_id]

    out_auc = {mtypes: {'train': [None] * 2, 'test': [None] * 2}
               for mtypes in task_combs}
    out_aupr = {mtypes: {'train': [None] * 2, 'test': [None] * 2}
                for mtypes in task_combs}

    out_params = {mtypes: None for mtypes in task_combs}
    out_time = {mtypes: None for mtypes in task_combs}

    for mtypes in task_combs:
        clf = mut_clf()

        if args.verbose:
            print("Testing {} ...".format(' with '.join(
                [str(mtype) for mtype in mtypes])))

        # exclude every gene located on a chromosome that carries one of
        # the mutated genes from the expression features
        mut_genes = {mtype.subtype_list()[0][0] for mtype in mtypes}
        ex_chroms = {cdata.gene_annot[mut_gene]['chr']
                     for mut_gene in mut_genes}
        ex_genes = {gene for gene, annot in cdata.gene_annot.items()
                    if annot['chr'] in ex_chroms}

        clf.tune_coh(cdata, mtypes, exclude_genes=ex_genes,
                     tune_splits=4, test_count=36, parallel_jobs=12)
        out_params[mtypes] = {par: clf.get_params()[par]
                              for par, _ in mut_clf.tune_priors}

        t_start = time.time()
        clf.fit_coh(cdata, mtypes, exclude_genes=ex_genes)
        t_end = time.time()
        out_time[mtypes] = t_end - t_start

        pheno_list = dict()
        train_omics, pheno_list['train'] = cdata.train_data(
            mtypes, exclude_genes=ex_genes)
        test_omics, pheno_list['test'] = cdata.test_data(
            mtypes, exclude_genes=ex_genes)

        pred_scores = {
            'train': clf.parse_preds(clf.predict_omic(train_omics)),
            'test': clf.parse_preds(clf.predict_omic(test_omics))
            }

        # proportion of mutated samples per mutation type, used as the
        # AUPR of a classifier that cannot be evaluated
        samp_sizes = {
            'train': [(len(mtype.get_samples(cdata.train_mut))
                       / len(cdata.train_samps)) for mtype in mtypes],
            'test': [(len(mtype.get_samples(cdata.test_mut))
                      / len(cdata.test_samps)) for mtype in mtypes]
            }

        for samp_set, scores in pred_scores.items():
            for j in range(2):
                if len(set(pheno_list[samp_set][:, j])) == 2:
                    out_auc[mtypes][samp_set][j] = roc_auc_score(
                        pheno_list[samp_set][:, j], scores[:, j])
                    out_aupr[mtypes][samp_set][j] = aupr_score(
                        pheno_list[samp_set][:, j], scores[:, j])

                else:
                    # only one class present: fall back to chance-level
                    # scores for this particular mutation type
                    out_auc[mtypes][samp_set][j] = 0.5
                    out_aupr[mtypes][samp_set][j] = samp_sizes[samp_set][j]

    out_file = os.path.join(
        out_path, 'out__cv-{}_task-{}.p'.format(args.cv_id, args.task_id))
    with open(out_file, 'wb') as fl:
        pickle.dump(
            {'AUC': out_auc, 'AUPR': out_aupr,
             'Clf': mut_clf, 'Params': out_params, 'Time': out_time},
            fl
            )
Esempio n. 3
0
def main():
    """Tune, fit and score a classifier on transfer tasks between cohorts.

    The pickled setup file maps cohort groupings to mutation types; these
    are flattened into a sorted list of (cohorts, mutation type) pairs and
    divided into `--task_count` contiguous chunks, with leftover pairs at
    the end of the list handed one each to the lowest task ids. For each
    cohort grouping assigned to this task a `TransferMutationCohort` is
    built once, and every mutation type is tuned, fitted and scored with
    genes on the mutated gene's chromosome excluded. Per-cohort AUC and AUPR
    on the training and testing samples are pickled to
    `out__cv-<cv_id>_task-<task_id>.p` in the output directory.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('expr_source',
                        type=str,
                        choices=['Firehose', 'toil', 'toil_tx'],
                        help='which TCGA expression data source to use')

    parser.add_argument(
        'samp_cutoff',
        type=int,
        help="minimum number of mutated samples needed to test a gene")

    parser.add_argument('classif',
                        type=str,
                        help='the name of a mutation classifier')

    parser.add_argument(
        '--cv_id',
        type=int,
        default=6732,
        help='the random seed to use for cross-validation draws')

    parser.add_argument(
        '--task_count',
        type=int,
        default=10,
        help='how many parallel tasks the list of types to test is split into')
    parser.add_argument('--task_id',
                        type=int,
                        default=0,
                        help='the subset of subtypes to assign to this task')

    parser.add_argument('--verbose',
                        '-v',
                        action='store_true',
                        help='turns on diagnostic messages')

    # parse command-line arguments, create directory where to save results
    args = parser.parse_args()
    out_path = os.path.join(
        base_dir, 'output', '{}__samps-{}'.format(args.expr_source,
                                                  args.samp_cutoff),
        args.classif)

    # make sure the results can be written before any expensive work starts
    os.makedirs(out_path, exist_ok=True)

    combs_file = os.path.join(
        base_dir, "setup",
        "combs-list_{}__samps-{}.p".format(args.expr_source,
                                           args.samp_cutoff))
    with open(combs_file, 'rb') as fl:
        comb_list = pickle.load(fl)

    comb_list = sorted((cohs, mtype) for cohs, mtypes in comb_list.items()
                       for mtype in mtypes)
    task_size = len(comb_list) // args.task_count

    # each task gets one contiguous chunk; the pairs left over at the end
    # of the list are distributed one each to the lowest task ids
    combs_use = comb_list[(args.task_id * task_size):((args.task_id + 1) *
                                                      task_size)]
    if args.task_id < (len(comb_list) % args.task_count):
        combs_use += [comb_list[-(args.task_id + 1)]]

    syn = synapseclient.Synapse()
    syn.cache.cache_root_dir = syn_root
    syn.login()

    # the classifier is named as <module>__<class>, with the class name
    # stored uncapitalized
    clf_info = args.classif.split('__')
    clf_module = import_module(
        'HetMan.experiments.transfer_baseline.models.{}'.format(clf_info[0]))
    mut_clf = getattr(clf_module, clf_info[1].capitalize())

    out_auc = {comb: {'train': dict(), 'test': dict()} for comb in combs_use}
    out_aupr = {comb: {'train': dict(), 'test': dict()} for comb in combs_use}
    out_params = {comb: None for comb in combs_use}
    out_time = {comb: None for comb in combs_use}

    for cur_cohs in {cohs for cohs, _ in combs_use}:
        if args.verbose:
            print("Transferring between cohort {} "
                  "and cohort {} ...".format(*cur_cohs))

        # build the cohort data once for all mutation types tested on it
        combs_cur = [(cohs, mtype) for cohs, mtype in combs_use
                     if cohs == cur_cohs]
        cur_genes = {mtype.subtype_list()[0][0] for _, mtype in combs_cur}

        cdata = TransferMutationCohort(
            cohorts=cur_cohs,
            mut_genes=list(cur_genes),
            mut_levels=['Gene', 'Form_base', 'Protein'],
            expr_sources=args.expr_source,
            var_sources='mc3',
            copy_sources='Firehose',
            annot_file=annot_file,
            expr_dir=expr_sources[args.expr_source],
            copy_dir=copy_dir,
            syn=syn,
            cv_prop=0.75,
            cv_seed=2079 + 57 * args.cv_id)

        for _, mtype in combs_cur:
            clf = mut_clf()
            if args.verbose:
                print("Testing {} ...".format(mtype))

            # exclude every gene on the mutated gene's chromosome from the
            # expression features
            mut_gene = mtype.subtype_list()[0][0]
            ex_genes = {
                gene
                for gene, annot in cdata.gene_annot.items()
                if annot['chr'] == cdata.gene_annot[mut_gene]['chr']
            }

            clf.tune_coh(cdata,
                         mtype,
                         exclude_genes=ex_genes,
                         tune_splits=4,
                         test_count=36,
                         parallel_jobs=12)
            out_params[cur_cohs, mtype] = {
                par: clf.get_params()[par]
                for par, _ in mut_clf.tune_priors
            }

            t_start = time.time()
            clf.fit_coh(cdata, mtype, exclude_genes=ex_genes)
            t_end = time.time()
            out_time[cur_cohs, mtype] = t_end - t_start

            pheno_list = dict()
            train_omics, pheno_list['train'] = cdata.train_data(
                mtype, exclude_genes=ex_genes)
            test_omics, pheno_list['test'] = cdata.test_data(
                mtype, exclude_genes=ex_genes)

            pred_scores = {
                'train': clf.predict_omic(train_omics),
                'test': clf.predict_omic(test_omics)
            }

            # proportion of mutated samples per cohort, used as the AUPR of
            # a classifier that cannot be evaluated
            samp_sizes = {
                'train': {
                    coh: (len(mtype.get_samples(cdata.train_mut[coh])) /
                          len(cdata.train_samps[coh]))
                    for coh in cur_cohs
                },
                'test': {
                    coh: (len(mtype.get_samples(cdata.test_mut[coh])) /
                          len(cdata.test_samps[coh]))
                    for coh in cur_cohs
                }
            }

            for samp_set, scores in pred_scores.items():
                for coh in cur_cohs:
                    if len(set(pheno_list[samp_set][coh])) == 2:
                        out_auc[cur_cohs, mtype][samp_set][coh] = auc_score(
                            pheno_list[samp_set][coh], scores[coh])
                        out_aupr[cur_cohs, mtype][samp_set][coh] = aupr_score(
                            pheno_list[samp_set][coh], scores[coh])

                    else:
                        # only one class present: use chance-level scores
                        out_auc[cur_cohs, mtype][samp_set][coh] = 0.5
                        out_aupr[cur_cohs, mtype][samp_set][coh] = (
                            samp_sizes[samp_set][coh])

    out_file = os.path.join(
        out_path, 'out__cv-{}_task-{}.p'.format(args.cv_id, args.task_id))
    with open(out_file, 'wb') as fl:
        pickle.dump(
            {
                'AUC': out_auc,
                'AUPR': out_aupr,
                'Clf': mut_clf,
                'Params': out_params,
                'Time': out_time
            }, fl)
Esempio n. 4
0
def main():
    """Tune, fit and score a classifier on this task's share of variants.

    Reads the feature list, cohort data and variant list pickled in
    `<use_dir>/setup`, then tests every variant whose position in the list
    modulo `--task_count` equals `--task_id`. The cross-validation id picks
    the sampling scheme:

        0-24:  a random split with 20% of samples held out for testing
        25-49: samples are shuffled with a seed shared by groups of five
               consecutive ids and every fifth sample is held out
        50-74: `update_split` is called without a testing set and scores
               are obtained through `infer_coh` instead

    Any other id raises a ValueError once a variant is tested. The fitted
    classifier is also applied to the training data of every
    `*__cohort-data.p` cohort found in the setup directory. Results are
    pickled to `<use_dir>/output/out__cv-<cv_id>_task-<task_id>.p`.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('classif',
                        type=str,
                        help="the name of a mutation classifier")
    parser.add_argument('--use_dir', type=str, default=base_dir)

    parser.add_argument(
        '--task_count',
        type=int,
        default=10,
        help="how many parallel tasks the list of types to test is split into")
    parser.add_argument('--task_id',
                        type=int,
                        default=0,
                        help="the subset of subtypes to assign to this task")
    parser.add_argument('--cv_id',
                        type=int,
                        default=0,
                        help="the seed to use for random sampling")

    args = parser.parse_args()
    out_path = os.path.join(args.use_dir, 'setup')
    use_seed = 2079 + 57 * args.cv_id

    with open(os.path.join(out_path, "feat-list.p"), 'rb') as fl:
        feat_list = pickle.load(fl)
    with open(os.path.join(out_path, "cohort-data.p"), 'rb') as fl:
        cdata = pickle.load(fl)
    with open(os.path.join(out_path, "vars-list.p"), 'rb') as fl:
        vars_list = pickle.load(fl)

    # the classifier is named as <module>__<class>, with the class name
    # stored uncapitalized
    clf_info = args.classif.split('__')
    clf_module = import_module(
        'HetMan.experiments.variant_baseline.models.{}'.format(clf_info[0]))
    mut_clf = getattr(clf_module, clf_info[1].capitalize())()

    out_acc = {
        mtype: {
            'tune': {
                'mean': None,
                'std': None
            },
            'train': {
                'AUC': None,
                'AUPR': None
            },
            'test': {
                'AUC': None,
                'AUPR': None
            }
        }
        for mtype in vars_list
    }

    out_params = {mtype: None for mtype in vars_list}
    out_scores = {mtype: None for mtype in vars_list}

    out_time = {
        mtype: {
            'tune': {
                'fit': dict(),
                'score': dict()
            },
            'final': {
                'fit': None,
                'score': None
            }
        }
        for mtype in vars_list
    }

    # transfer cohorts are stored as <cohort>__cohort-data.p; take the name
    # from the file's basename so that the lookup does not depend on how
    # the setup directory path is spelled
    coh_files = glob(os.path.join(out_path, "*__cohort-data.p"))
    coh_dict = {
        os.path.basename(coh_fl).split('__')[0]: coh_fl
        for coh_fl in coh_files
    }
    out_trnsf = {mtype: dict() for mtype in vars_list}

    def load_transfer_omics(coh_fl, mtype, use_genes):
        # read one transfer cohort, closing its file, and return the
        # features of its training samples
        with open(coh_fl, 'rb') as fl:
            trnsf_cdata = pickle.load(fl)

        return trnsf_cdata.train_data(mtype, include_feats=use_genes)[0]

    for i, mtype in enumerate(vars_list):
        if (i % args.task_count) != args.task_id:
            # this variant belongs to another task, so drop its placeholders
            for out_dict in (out_acc, out_params, out_scores,
                             out_time, out_trnsf):
                del out_dict[mtype]

            continue

        print("Testing {} ...".format(mtype))

        if (args.cv_id // 25) == 2:
            cdata.update_split(use_seed)

        elif (args.cv_id // 25) == 1:
            cdata_samps = cdata.get_samples()

            # ids within the same group of five share a shuffle and each
            # take a different fifth of the samples for testing
            random.seed((args.cv_id // 5) * 1811 + 9)
            random.shuffle(cdata_samps)
            cdata.update_split(use_seed,
                               test_samps=cdata_samps[(args.cv_id % 5)::5])

        elif (args.cv_id // 25) == 0:
            cdata.update_split(use_seed, test_prop=0.2)

        else:
            raise ValueError("Invalid cross-validation id!")

        # get the gene that the variant is associated with and the list
        # of genes on the same chromosome as that gene
        var_gene = mtype.subtype_list()[0][0]
        ex_genes = {
            gene
            for gene, annot in cdata.gene_annot.items()
            if annot['Chr'] == cdata.gene_annot[var_gene]['Chr']
        }
        use_genes = feat_list - ex_genes

        mut_clf, cv_output = mut_clf.tune_coh(cdata,
                                              mtype,
                                              include_feats=use_genes,
                                              tune_splits=4,
                                              test_count=36,
                                              parallel_jobs=8)

        out_time[mtype]['tune']['fit']['avg'] = cv_output['mean_fit_time']
        out_time[mtype]['tune']['fit']['std'] = cv_output['std_fit_time']
        out_time[mtype]['tune']['score']['avg'] = cv_output[
            'mean_score_time']
        out_time[mtype]['tune']['score']['std'] = cv_output[
            'std_score_time']

        out_acc[mtype]['tune']['mean'] = cv_output['mean_test_score']
        out_acc[mtype]['tune']['std'] = cv_output['std_test_score']
        out_params[mtype] = {
            par: mut_clf.get_params()[par]
            for par, _ in mut_clf.tune_priors
        }

        t_start = time.time()
        mut_clf.fit_coh(cdata, mtype, include_feats=use_genes)
        t_end = time.time()
        out_time[mtype]['final']['fit'] = t_end - t_start

        if (args.cv_id // 25) < 2:
            pheno_list = dict()
            train_omics, pheno_list['train'] = cdata.train_data(
                mtype, include_feats=use_genes)
            test_omics, pheno_list['test'] = cdata.test_data(
                mtype, include_feats=use_genes)

            t_start = time.time()
            pred_scores = {
                'train':
                mut_clf.parse_preds(mut_clf.predict_omic(train_omics)),
                'test':
                mut_clf.parse_preds(mut_clf.predict_omic(test_omics))
            }

            out_time[mtype]['final']['score'] = time.time() - t_start
            out_scores[mtype] = pred_scores['test']

            # proportion of mutated samples, used as the AUPR of a
            # classifier that cannot be evaluated
            samp_sizes = {
                'train': (sum(cdata.train_pheno(mtype)) /
                          len(cdata.get_train_samples())),
                'test': (sum(cdata.test_pheno(mtype)) /
                         len(cdata.get_test_samples()))
            }

            for smp_set, scores in pred_scores.items():
                if len(set(pheno_list[smp_set])) == 2:
                    out_acc[mtype][smp_set]['AUC'] = roc_auc_score(
                        pheno_list[smp_set], scores)
                    out_acc[mtype][smp_set]['AUPR'] = aupr_score(
                        pheno_list[smp_set], scores)

                else:
                    # only one class present: use chance-level scores
                    out_acc[mtype][smp_set]['AUC'] = 0.5
                    out_acc[mtype][smp_set]['AUPR'] = samp_sizes[smp_set]

        else:
            out_scores[mtype] = [
                mut_clf.parse_preds(vals)[0]
                for vals in mut_clf.infer_coh(cdata,
                                              mtype,
                                              include_feats=use_genes,
                                              infer_splits=5,
                                              infer_folds=5,
                                              parallel_jobs=5)
            ]

        # apply the fitted classifier to every transfer cohort
        trnsf_preds = Parallel(n_jobs=8, pre_dispatch=8)(
            delayed(mut_clf.predict_omic)
            (load_transfer_omics(coh_fl, mtype, use_genes))
            for coh_fl in coh_dict.values())

        out_trnsf[mtype] = dict(
            zip(coh_dict.keys(),
                [mut_clf.parse_preds(vals) for vals in trnsf_preds]))

    with open(
            os.path.join(
                args.use_dir, 'output',
                "out__cv-{}_task-{}.p".format(args.cv_id, args.task_id)),
            'wb') as fl:
        pickle.dump(
            {
                'Acc': out_acc,
                'Clf': mut_clf.__class__,
                'Params': out_params,
                'Time': out_time,
                'Scores': out_scores,
                'Transfer': out_trnsf
            }, fl)
Esempio n. 5
0
def main():
    """Benchmark one mutation task under each of three fitting methods.

    Loads the variant list and cohort data pickled in `<use_dir>/setup`,
    picks the variant at position `--task_id` of the sorted list and, for
    each of the 'optim', 'varit' and 'sampl' fit methods, tunes and fits
    the chosen classifier with genes on the same chromosome as `use_gene`
    excluded. Tuning statistics, timings, tuned parameters, predicted
    scores and AUC/AUPR on the training and testing samples are pickled to
    `<use_dir>/output/out__cv-<cv_id>_task-<task_id>.p`, together with the
    fit summary of the 'sampl' classifier.
    """
    fit_methods = ('optim', 'varit', 'sampl')

    parser = argparse.ArgumentParser()
    parser.add_argument(
        'use_gene', type=str,
        help="the gene whose mutations are being classified")
    parser.add_argument(
        'model', type=str, help="the name of a mutation classifier")

    parser.add_argument('--use_dir', type=str, default=base_dir)
    parser.add_argument(
        '--task_id', type=int, default=0,
        help="the subset of subtypes to assign to this task")
    parser.add_argument(
        '--cv_id', type=int, default=6072,
        help="the seed to use for random sampling")

    args = parser.parse_args()
    setup_dir = os.path.join(args.use_dir, 'setup')

    with open(os.path.join(setup_dir, "vars-list.p"), 'rb') as vars_fl:
        vars_list = pickle.load(vars_fl)

    with open(os.path.join(setup_dir, "cohort-data.p"), 'rb') as coh_fl:
        cdata = pickle.load(coh_fl)
    cdata.update_seed(2079 + 57 * args.cv_id, test_prop=0.25)

    # the model is named as <module>__<class>, with the class name stored
    # uncapitalized
    module_name, class_name = args.model.split('__')[:2]
    clf_module = import_module(
        'HetMan.experiments.stan_baseline.models.{}'.format(module_name))
    use_clf = getattr(clf_module, class_name.capitalize())

    # empty result slots for every fitting method
    out_acc = dict()
    out_time = dict()
    out_scores = dict()

    for fit_method in fit_methods:
        out_acc[fit_method] = {
            'tune': {'mean': None, 'std': None},
            'train': {'AUC': None, 'AUPR': None},
            'test': {'AUC': None, 'AUPR': None},
        }
        out_time[fit_method] = {
            'tune': {'fit': dict(), 'score': dict()},
            'final': {'fit': None, 'score': None},
        }
        out_scores[fit_method] = {'train': None, 'test': None}

    out_params = dict.fromkeys(fit_methods)

    # the task to test, and the genes sharing a chromosome with its gene
    use_mtype = sorted(vars_list)[args.task_id]
    ex_genes = set()
    for gene, annot in cdata.gene_annot.items():
        if annot['Chr'] == cdata.gene_annot[args.use_gene]['Chr']:
            ex_genes.add(gene)

    for fmth in fit_methods:
        mut_clf = use_clf(fit_method=fmth)
        mut_clf, cv_output = mut_clf.tune_coh(
            cdata, use_mtype, exclude_feats=ex_genes,
            tune_splits=4, test_count=24, parallel_jobs=12
        )

        tune_time = out_time[fmth]['tune']
        tune_time['fit']['avg'] = cv_output['mean_fit_time']
        tune_time['fit']['std'] = cv_output['std_fit_time']
        tune_time['score']['avg'] = cv_output['mean_score_time']
        tune_time['score']['std'] = cv_output['std_score_time']

        tune_acc = out_acc[fmth]['tune']
        tune_acc['mean'] = cv_output['mean_test_score']
        tune_acc['std'] = cv_output['std_test_score']

        out_params[fmth] = {
            par_name: mut_clf.get_params()[par_name]
            for par_name, _ in mut_clf.tune_priors
        }

        fit_start = time.time()
        mut_clf.fit_coh(cdata, use_mtype, exclude_feats=ex_genes)
        fit_end = time.time()
        out_time[fmth]['final']['fit'] = fit_end - fit_start

        train_omics, train_pheno = cdata.train_data(
            use_mtype, exclude_feats=ex_genes)
        test_omics, test_pheno = cdata.test_data(
            use_mtype, exclude_feats=ex_genes)
        pheno_list = {'train': train_pheno, 'test': test_pheno}

        score_start = time.time()
        train_preds = mut_clf.parse_preds(mut_clf.predict_omic(train_omics))
        test_preds = mut_clf.parse_preds(mut_clf.predict_omic(test_omics))
        pred_scores = {'train': train_preds, 'test': test_preds}
        out_time[fmth]['final']['score'] = time.time() - score_start

        # proportion of mutated samples, used as the AUPR of a classifier
        # that cannot be evaluated
        train_prop = (sum(cdata.train_pheno(use_mtype))
                      / len(cdata.get_train_samples()))
        test_prop = (sum(cdata.test_pheno(use_mtype))
                     / len(cdata.get_test_samples()))
        samp_sizes = {'train': train_prop, 'test': test_prop}

        for samp_set, scores in pred_scores.items():
            out_scores[fmth][samp_set] = scores
            labels = pheno_list[samp_set]
            set_acc = out_acc[fmth][samp_set]

            if len(set(labels)) != 2:
                # only one class present: use chance-level scores
                set_acc['AUC'] = 0.5
                set_acc['AUPR'] = samp_sizes[samp_set]

            else:
                set_acc['AUC'] = roc_auc_score(labels, scores)
                set_acc['AUPR'] = aupr_score(labels, scores)

        if fmth == 'sampl':
            out_sampl = mut_clf.named_steps['fit'].fit_obj.summary()

    out_file = os.path.join(
        args.use_dir, 'output',
        "out__cv-{}_task-{}.p".format(args.cv_id, args.task_id)
    )
    results = {
        'Acc': out_acc,
        'Clf': use_clf,
        'Scores': out_scores,
        'Params': out_params,
        'Time': out_time,
        'Sampl': out_sampl
    }

    with open(out_file, 'wb') as out_fl:
        pickle.dump(results, out_fl)