Пример #1
0
def plotter(df, option, DURATION, EVENT, CategoricalDtype):
    """Plot one Kaplan-Meier survival curve per level of ``df[option]``.

    Args:
        df: DataFrame holding the grouping, duration and event columns.
        option: Name of the column whose distinct values define the groups.
        DURATION: Name of the column passed to the fitter as durations.
        EVENT: Name of the column passed to the fitter as ``event_observed``.
        CategoricalDtype: Dtype class used to detect a categorical grouping
            column (presumably ``pandas.CategoricalDtype`` - confirm with
            the caller).

    Returns:
        The ``plt`` module, with all curves drawn on a single 10x5 axes.
    """
    kmf = KaplanMeierFitter()
    fig, ax = plt.subplots(figsize=(10, 5))
    durations = df[DURATION]
    events = df[EVENT]
    groups = df[option]

    # Missing values are dropped before the levels are collected: NaN never
    # compares equal to itself, so a NaN "level" would select no rows at all.
    if isinstance(groups.dtype, CategoricalDtype):
        # Walking the categories in their declared order is the same as
        # walking the sorted category codes; keep only levels that occur.
        present = set(groups.dropna().unique())
        levels = [level for level in groups.cat.categories if level in present]
    else:
        levels = sorted(groups.dropna().unique())

    # The same fitter is refitted for every level and drawn on the shared axes.
    for level in levels:
        subset = (groups == level)
        kmf.fit(durations[subset], event_observed=events[subset], label=level)
        kmf.plot_survival_function(ax=ax)

    return plt
Пример #2
0
def kmftest(data, part):
    """Fit and plot a Kaplan-Meier curve for the rows whose ``sex`` is ``part``.

    Args:
        data: Input handed to ``pd.read_json``; the resulting frame must
            provide the ``sex``, ``time`` and ``status`` columns.
        part: Value of the ``sex`` column to keep.
    """
    # filter dataset on part
    dataframe = pd.read_json(data)
    dataframe = dataframe[dataframe["sex"] == part]
    # TODO: change to duration later
    T = dataframe['time']
    # NOTE(review): the original assigned E from 'censored' but never used it;
    # both of its fit calls read 'status', so that column is used here.
    # TODO: confirm this is the right event column for the dataset.
    E = dataframe['status']

    # Fit on the filtered frame. The original fitted on an undefined ``df``
    # and then indexed the raw ``data`` argument instead of the parsed frame.
    kmf = KaplanMeierFitter()
    kmf.fit(T, event_observed=E)
    kmf.plot_survival_function()
    def plot(self):
        """
        Plot side-by-side Kaplan-Meier curves of the two input datasets.

        Reads ``self.stats_original_`` and ``self.stats_synthetic_`` (each
        indexed by 'time', 'event' and 'group'), ``self.labels`` for the panel
        titles and ``self.group_column`` for the legend title. One curve is
        drawn per distinct group value; the two panels share the y axis.
        """

        figsize = (10, 5)
        fig, ax = plt.subplots(1, 2, figsize=figsize, sharey=True)

        palette = ['#0d3d56', '#006887', '#0098b5', '#00cbde', '#00ffff']

        datasets = [self.stats_original_, self.stats_synthetic_]
        for data, label, ax_cur in zip(datasets, self.labels, ax):
            t = data['time']
            e = data['event']

            kmf = KaplanMeierFitter()
            groups = np.sort(data['group'].unique())
            for idx, g in enumerate(groups):
                # Wrap around the palette so that groups beyond its length
                # are still drawn instead of being silently dropped.
                color = palette[idx % len(palette)]
                mask = (data['group'] == g)
                kmf.fit(t[mask], event_observed=e[mask], label=g)
                ax_cur = kmf.plot_survival_function(ax=ax_cur, color=color)

            # Decorate each panel once, after all of its curves are drawn.
            ax_cur.legend(title=self.group_column)
            ax_cur.set_title('Kaplan-Meier - {} data'.format(label))
            ax_cur.set_ylim(0, 1)
        plt.tight_layout()
def estimate_survival(df, plot=True, censoring='right', fig_path=None):
    """Fit a Kaplan-Meier estimator per ``domain_system`` and return the medians.

    Args:
        df: DataFrame with the columns ``domain_system``, ``time`` and
            ``spotted``.
        plot: When true, draw every fitted survival function on one figure
            and save it.
        censoring: ``'right'`` fits with ``fit`` and ``spotted`` as the event
            column; ``'left'`` fits with ``fit_left_censoring`` and the
            negated ``spotted`` column.
        fig_path: Where to save the figure; falls back to ``"survival.png"``.

    Returns:
        Dict mapping each system to the fitter's ``median_survival_time_``.

    Raises:
        ValueError: If ``censoring`` is neither ``'right'`` nor ``'left'``.
    """

    if censoring not in {'right', 'left'}:
        raise ValueError(f"unknown fit type: {censoring},"
                         f" use one of {{'left', 'right'}}")

    kmf = KaplanMeierFitter(alpha=1.0)  # disable confidence interval

    if plot:
        fig = plt.figure(figsize=(20, 10))
        ax = fig.add_subplot(111)

    medians = {}
    for system in sorted(df['domain_system'].unique()):
        # Select the rows of this system once and reuse them for both columns.
        rows = df.loc[df['domain_system'] == system]
        if censoring == 'right':
            kmf.fit(rows['time'], rows['spotted'], label=system)
        else:
            # censoring was validated above, so this is the 'left' case.
            # Cast to bool before negating: on an integer 0/1 column a bare
            # ``~`` is a bitwise NOT and yields -1/-2 rather than flipped flags.
            kmf.fit_left_censoring(
                rows['time'],
                ~rows['spotted'].astype(bool),
                label=system,
            )

        if plot:
            kmf.plot_survival_function(ax=ax)

        medians[system] = kmf.median_survival_time_

    if plot:
        plt.ylim(0.0, 1.0)
        plt.xlabel("Turns")
        plt.ylabel("Survival Probability")
        plt.title("Estimated Survival Function of different systems")
        save_path = fig_path or "survival.png"
        print(f'saving plot of estimated survival functions to: {save_path}')
        plt.savefig(save_path)

    return medians
 def _PlotFigure(self, s_mean, s_std, x, time):
     """Compare a Kaplan-Meier curve of matching patients with a given curve.

     Reads 'Data_Clinical_Encoded.csv', drops rows with missing values, and
     keeps the patients that match the first nine entries of ``self.params``
     (age within +/-7, the T/N/M/Grade indicator columns equal to 1, and
     equal ER, PR, Her2 and Invasion values). The Kaplan-Meier estimate of
     that cohort is drawn together with ``s_mean`` over ``time`` and a
     ``s_mean +/- s_std`` band. The figure is saved as 'survival_comparison'
     and then shown. ``x`` is not used in this method.
     """
     from lifelines import CoxPHFitter, KaplanMeierFitter

     clinical = pd.read_csv('Data_Clinical_Encoded.csv')
     clinical.dropna(axis=0, inplace=True)
     # Event flag: 0 where 'Count_as_OS' is 'N', 1 otherwise.
     clinical['event'] = 1 - (clinical['Count_as_OS'] == 'N')
     print(clinical.head(5))

     p = self.params
     age = clinical['Age_@_Dx']
     criteria = [
         (age > p[0] - 7) & (age < p[0] + 7),
         clinical['T' + str(p[1])] == 1,
         clinical['N' + str(p[2])] == 1,
         clinical['M' + str(p[3])] == 1,
         clinical['Grade' + str(p[4])] == 1,
         clinical['ER'] == p[5],
         clinical['PR'] == p[6],
         clinical['Her2'] == p[7],
         clinical['Invasion'] == p[8],
     ]

     # NOTE(review): the size and node-count masks are computed but are not
     # part of the cohort filter below - confirm whether that is intended.
     tumour_size = clinical['size_precise']
     size_bool = (tumour_size > p[9] - 5) & (tumour_size < p[9] + 5)
     positive_nodes = clinical['nodespos']
     nodes_bool = (positive_nodes > p[10] - 5) & (positive_nodes < p[10] + 5)

     # AND all criteria together, left to right.
     keep = criteria[0]
     for criterion in criteria[1:]:
         keep = keep & criterion
     cohort = clinical[keep]
     print(len(cohort), len(clinical))

     km_fitter = KaplanMeierFitter()
     km_fitter.fit(cohort['Time_OS'], cohort['event'], label="Train")
     km_fitter.plot_survival_function(show_censors=False,
                                      ci_show=True,
                                      at_risk_counts=True)

     title = (f"Age={p[0]!s}, T{p[1]!s}, N{p[2]!s}, M{p[3]!s}, "
              f"Grade{p[4]!s}, ER={p[5]!s}, PR={p[6]!s}, Her2={p[7]!s}, "
              f"Invasion={p[8]!s}")
     plt.title(title)

     # Overlay the supplied curve and its +/- one-std band.
     upper = s_mean + s_std
     lower = s_mean - s_std
     plt.plot(time, s_mean, 'r-', label='DeepSurv')
     plt.fill_between(time, lower, upper, alpha=0.4)
     plt.legend()
     plt.tight_layout()
     plt.savefig('survival_comparison')
     plt.show()
Пример #6
0
def SevenPlot(stimes, Np, aps):
    """Plot a Kaplan-Meier curve for each second-axis index 3 through 9.

    Args:
        stimes: Array indexed as ``stimes[0, x, :]``; row 0 of that slice
            holds the survival times for index ``x``.
        Np: Number of survival times per curve. Kept for backward
            compatibility; the event vector now follows the data length.
        aps: Sequence indexed by ``x`` and used in the curve labels.
    """
    for x in range(3, 10, 1):
        stime = stimes[0, x, :]
        durations = np.asarray(stime[0])

        # A time above 62831 counts as censored (E = 0), anything else as an
        # observed event (E = 1). Computed for every entry; the original
        # loop was hard-coded to the first 30.
        # NOTE(review): 62831 looks like the simulation cut-off time - confirm.
        observed = (~(durations > 62831)).astype(int)

        df = pd.DataFrame(data={'T': durations, 'E': observed})

        kmf = KaplanMeierFitter()
        kmf.fit(df['T'], df['E'], label="ap = {}".format(aps[x]))
        kmf.plot_survival_function(at_risk_counts=True)

    plt.tight_layout()

    plt.title("Survival functions as a function of Planetary Semimajor Axis (ap)")
    
                                              


        
        
        
Пример #7
0
from lifelines.datasets import load_waltons
from lifelines import KaplanMeierFitter
from lifelines.utils import median_survival_times

# Load the example dataset and print a quick overview of the columns used
# below: 'T' (passed as durations), 'E' (passed as event_observed), 'group'.
df = load_waltons()
print(df.head(),'\n')
print(df['T'].min(), df['T'].max(),'\n')  # smallest and largest duration
print(df['E'].value_counts(),'\n')  # rows per event value
print(df['group'].value_counts(),'\n')  # rows per group

# Fit a single Kaplan-Meier estimator on all rows; 'group' is not used here.
kmf = KaplanMeierFitter()
kmf.fit(df['T'], event_observed=df['E'])

kmf.plot_survival_function()

# NOTE(review): median_ is assigned but not used in the visible code.
median_ = kmf.median_survival_time_
# Median survival times computed from the fitted confidence-interval frame.
median_confidence_interval_ = median_survival_times(kmf.confidence_interval_)
print(median_confidence_interval_)
Пример #8
0

# %% [markdown]
# #### lifelines

# %%
from lifelines import KaplanMeierFitter
from lifelines.plotting import add_at_risk_counts

# Draw one Kaplan-Meier curve per distinct 'ER Status' value on a shared axes.
# NOTE(review): `dat` and `plt` are defined outside this excerpt.
ax = plt.subplot(111)
kmfs = []  # fitted estimators, collected for the at-risk table below
for status in dat['ER Status'].unique():
    mask = dat['ER Status']==status
    # A new fitter per group, so every fitted model can be kept in kmfs.
    kmf = KaplanMeierFitter()
    kmf.fit(dat["OS Time"][mask], event_observed = dat['OS event'][mask], label = status)
    ax = kmf.plot_survival_function(ax=ax,ci_show=False)
    kmfs.append(kmf)

add_at_risk_counts(*kmfs, ax=ax) # *kmfs expands to the elements of kmfs
plt.tight_layout()

# %% [markdown]
# Note that `lifelines` and `scikit-survival` build on `matplotlib`, so each layer has to be drawn inside a for loop. This is typical `matplotlib` behavior.
#
# The `kaplanmeier` function essentially wraps the `lifelines` Kaplan-Meier curve and makes some choices for you when creating the graph.

# %% [markdown]
# ## Question 1b

# %%
ax = plt.subplot(111)
                         SurvEst['S_est']) / (SurvEst['dt'] * SurvEst['S_est'])
    SurvEst['cumhazard'] = SurvEst['hazard'].cumsum()
    #dfC_KM['survived'] = len(dfC) - dfC_KM['E']  #total_events - dfC_KM['E']
    #dfC_KM['KMfactor'] = 1 - dfC_KM['E'] / dfC_KM['survived']
    #    S_est = np.empty((len(dfC_KM)+1,))
    #    S_est[0] = 1
    #    for i,val in enumerate(dfC_KM.KMfactor.values):
    #        S_est[i+1] = S_est[i] * val

    # Inspect result: compare to lifelines values
    S_est = SurvEst['S_est'].values
    print(np.allclose(kmf.survival_function_.KM_estimate.values, S_est))
    t_KM = np.hstack((0, dfC_KM.index.values))
    #plt.figure(figsize=(7,6))
    #plt.plot(t_KM, kmf.survival_function_.KM_estimate.values, 'm-')
    kmf.plot_survival_function(figsize=(7, 6))
    #plt.plot(dfC_KM['T'].values, S_est, 'r.', alpha=.5, label='manual')
    plt.plot(SurvEst['T'].values,
             SurvEst['S_est'].values,
             'r.:',
             alpha=.6,
             label='manual')
    plt.legend()

    kmf.cumulative_density_.plot(figsize=(7, 6))
    plt.plot(SurvEst['T'].values,
             (SurvEst['hazard'] * SurvEst['S_est']).cumsum().values,
             'r.:',
             alpha=.6,
             label='manual')
    plt.legend()
Пример #10
0
kmf_test = KaplanMeierFitter()

# In[26]:

# Overlay the train-set and test-set Kaplan-Meier curves on a single axes.
# NOTE(review): kmf_train, train and test are not defined in this excerpt;
# presumably they are created in earlier cells.
figure = plt.figure(figsize=(12, 8), tight_layout=False)

ax = figure.add_subplot(111)

# Shared timeline for both fits: 85 evenly spaced points from 0 to 84.
t = np.linspace(0, 84, 85)

kmf_train.fit(train['PFS'],
              event_observed=train['disease_progress'],
              timeline=t,
              label='Train set')
ax = kmf_train.plot_survival_function(show_censors=True,
                                      ci_force_lines=False,
                                      ci_show=False,
                                      ax=ax)

kmf_test.fit(test['PFS'],
             event_observed=test['disease_progress'],
             timeline=t,
             label='Test set')
ax = kmf_test.plot_survival_function(show_censors=True,
                                     ci_force_lines=False,
                                     ci_show=False,
                                     ax=ax)

# At-risk table for both fitted estimators, attached to the same axes.
add_at_risk_counts(kmf_train, kmf_test, ax=ax, fontsize=12)

ax.set_xlabel('Time (months)', fontsize=12, fontweight='bold')
ax.set_ylabel('Survival probability', fontsize=12, fontweight='bold')
Пример #11
0
using the lifelines package
'''

# single category
from lifelines.datasets import load_waltons
df = load_waltons()  # returns a Pandas DataFrame

T = df['T']
E = df['E']

# Fit one estimator on all rows and keep its fitted tables and plot axes.
from lifelines import KaplanMeierFitter
kmf = KaplanMeierFitter()
kmf.fit(T, E)
a = kmf.survival_function_
b = kmf.cumulative_density_
c = kmf.plot_survival_function(at_risk_counts=True)

# two categories
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
kmf = KaplanMeierFitter()
# T and E are rebound here: from now on they are lists holding one array of
# durations / event flags per group, in groupby order.
T, E = [], []
for name, grouped_df in df.groupby('group'):
    T.append(grouped_df['T'].values)
    E.append(grouped_df['E'].values)
    # The same fitter is refitted per group and drawn on the shared axes.
    kmf.fit(grouped_df["T"], grouped_df["E"], label=name)
    kmf.plot_survival_function(ax=ax)
from lifelines.statistics import logrank_test
# Only the first two groups are compared; any further group is ignored.
results = logrank_test(T[0], T[1], E[0], E[1])
ax.text(x=0.05,
        y=0.05,
Пример #12
0
def CoxRegressionModel(stimes, N, ap_s, ap_s2, ap_s3, ap_s5, ebb_, eap_, ebs,
                       aps_0, mu, stime, Np):
    """Plot a Kaplan-Meier curve together with a Cox model fitted on ``aps``.

    Args:
        stimes: Survival times, one per row of the regression frame.
        N: Expected number of rows. Kept for backward compatibility; the
            event vector now follows ``len(stimes)``.
        ap_s, ap_s2, ap_s3, ap_s5, ebb_, eap_: Covariate arrays stored as the
            columns 'aps', 'aps2', 'aps3', 'aps5', 'eb' and 'eap'.
        ebs, aps_0, mu, stime, Np: Forwarded to ``PlottingLL.PlottingLL``;
            ``ebs``, ``aps_0`` and ``mu`` also appear in the plot title.
    """
    # Notes carried over from the original author:
    # ebs = singular value = 0, 0.175, 0.35, 0.525, 0.7
    # stimes = 1050 values ( 5 eb values x 7 ap values x 30 survival times)
    # stime = 30 values for each aps_0 value
    # N = 1050
    # Np = 30
    # ap_s, ap_s2, ap_s3 = array of 1050 values used for the data frame

    # Event observation: a time above 62831 counts as censored (0), anything
    # else as an observed event (1).
    # NOTE(review): 62831 looks like the simulation cut-off time - confirm.
    observed = (~(np.asarray(stimes) > 62831)).astype(int)

    df = pd.DataFrame(data={
        'T': stimes,
        'E': observed,
        'aps': ap_s,
        'aps2': ap_s2,
        'aps3': ap_s3,
        'aps5': ap_s5,
        'eap': eap_,
        'eb': ebb_,
    })

    fig, axes = plt.subplots()

    axes.set_xscale('log')
    axes.set_ylabel("S(t)")

    KT, KE, _ = PlottingLL.PlottingLL(ebs, aps_0, mu, stime, Np)
    kmf = KaplanMeierFitter().fit(KT, KE, label='KaplanMeierFitter')
    kmf.plot_survival_function(ax=axes)

    cph = CoxPHFitter()
    cph.fit(df, duration_col='T', event_col='E', formula="aps + I(aps**3)")

    # plot_partial_effects_on_outcome requires the covariates to vary and the
    # values to evaluate them at; the original call passed neither. The
    # values used here come from the author's commented-out variant.
    cph.plot_partial_effects_on_outcome(covariates=['aps'],
                                        values=[round(aps_0, 3)],
                                        plot_baseline=False,
                                        ax=axes,
                                        cmap="coolwarm")

    # The original colour was f"C{i}", with ``i`` leaked from the event loop
    # (the last row index). The same value is spelled out explicitly here.
    baseline_color = f"C{len(observed) - 1}"
    cph.baseline_survival_.plot(ax=axes, ls=":", color=baseline_color)

    plt.title('Formula(aps vs. aps3 (1050 values)): (eb,ap,mu)={}'.format(
        (round(ebs, 3), round(aps_0, 3), mu)),
              fontsize=12)
Пример #13
0
def KM_age_groups(features, outcomes, icu=False):
    """Plot Kaplan-Meier curves for two patient subgroups and for all patients.

    Args:
        features: DataFrame with the column 'Leeftijd (jaren)'. That column
            is converted to float in place, as in the original code.
        outcomes: DataFrame with 'duration_mortality' and 'event_mortality'.
        icu: When true, split the patients by ICU stay; otherwise (default,
            matching the previously hard-coded behaviour) split at age 70.

    Returns:
        Tuple ``(kmf1, kmf2, kmf3)``: the fitters of the first subgroup, the
        second subgroup and all patients.

    NOTE(review): ``data`` and the colours ``c1``..``c5`` are not defined in
    this function; presumably they are module-level names - confirm.
    """
    T = outcomes['duration_mortality']
    E = outcomes['event_mortality']

    features['Leeftijd (jaren)'] = features['Leeftijd (jaren)'].astype(float)
    age_70_plus = features['Leeftijd (jaren)'] >= 70.
    print('Alle patiënten: n=', len(age_70_plus))
    print('Leeftijd ≥ 70: n=', sum(age_70_plus))
    print('Leeftijd < 70: n=', sum(~age_70_plus))

    # A patient counts for a group when any of the listed flag columns is set.
    was_not_on_icu = data[[
        'Levend ontslagen en niet heropgenomen - waarvan niet opgenomen geweest op IC',
        'Levend dag 21 maar nog in het ziekenhuis - niet op IC geweest',
        'Dood op dag 21 - niet op IC geweest'
    ]].any(axis=1)
    was_on_icu = data[[
        'Levend ontslagen en niet heropgenomen - waarvan opgenomen geweest op IC',
        'Levend dag 21 maar nog in het ziekenhuis - op IC geweest',
        'Levend dag 21 maar nog in het ziekenhuis - waarvan nu nog op IC',
        'Dood op dag 21 - op IC geweest'
    ]].any(axis=1)
    print('Alle patiënten: n=',
          (np.nansum(was_on_icu) + np.nansum(was_not_on_icu)))
    print('ICU yes: n=', np.nansum(was_on_icu))
    print('ICU no: n=', np.nansum(was_not_on_icu))

    fig, axes = plt.subplots(1, 1)

    if icu:
        kmf1 = KaplanMeierFitter().fit(T[was_on_icu],
                                       E[was_on_icu],
                                       label='Wel op ICU geweest')
        kmf2 = KaplanMeierFitter().fit(T[was_not_on_icu],
                                       E[was_not_on_icu],
                                       label='Niet op ICU geweest')
        colors = (c4, c5, c3)
        out_name = 'KM_survival_curve_ICU.png'
    else:  # age 70
        kmf1 = KaplanMeierFitter().fit(T[age_70_plus],
                                       E[age_70_plus],
                                       label='Leeftijd ≥ 70 jaar')
        kmf2 = KaplanMeierFitter().fit(T[~age_70_plus],
                                       E[~age_70_plus],
                                       label='Leeftijd < 70 jaar')
        colors = (c1, c2, c3)
        out_name = 'KM_survival_curve_DEATH.png'

    kmf3 = KaplanMeierFitter().fit(T, E, label='Alle patienten')

    # Draw on the axes created above instead of relying on the current axes.
    for fitter, color in zip((kmf1, kmf2, kmf3), colors):
        fitter.plot_survival_function(ax=axes, color=color)

    axes.set_xticks([1, 5, 9, 13, 17, 21])
    axes.set_xticklabels(['1', '5', '9', '13', '17', '21'])
    axes.set_xlabel('Aantal dagen sinds opnamedag')
    axes.set_ylabel('Proportie overlevend')

    axes.set_xlim(0, 21)
    axes.set_ylim(0, 1)
    plt.tight_layout()

    # ``figsize`` is not a savefig argument (the figure size is fixed when the
    # figure is created), so it is no longer passed here.
    plt.savefig(out_name,
                format='png',
                dpi=300,
                pad_inches=0,
                bbox_inches='tight')
    plt.show()

    return kmf1, kmf2, kmf3
Пример #14
0
    ofile.write("Marginal PGS Effect on Hematuria Risk P-value (LR Test): {0:.2e}\n".format(test_stats_PGS[1]))
    ofile.write("Marginal PGSxP/LP Effect on Hematuria Risk P-value (LR Test): {0:.2e}\n".format(test_stats_pgs_int[1]))
    ofile.write('\n\n')

    ofile.write("#"*5+" Final Model "+'#'*5+'\n')
    ofile.write(full_model.summary.to_string())



    f, axis = plt.subplots(1, 1,figsize=(10,8))
    f.tight_layout(pad=2)
    axis.spines['right'].set_visible(False)
    axis.spines['top'].set_visible(False)
    kmf=KaplanMeierFitter()
    kmf.fit(cph_table.loc[cph_table.Genotype=='Control']['ObsWindow'],event_observed=cph_table.loc[cph_table.Genotype=='Control']['Dx'], label='Control')
    kmf.plot_survival_function(ax=axis,color=color_list[0],lw=2.0)


    kmf.fit(cph_table.loc[cph_table.Genotype=='P/LP Carrier']['ObsWindow'],event_observed=cph_table.loc[cph_table.Genotype=='P/LP Carrier']['Dx'], label='P/LP Carrier')
    kmf.plot_survival_function(ax=axis,color=color_list[4],lw=2.0)


    axis.text(axis.get_xlim()[0]+0.25*(axis.get_xlim()[1]-axis.get_xlim()[0]),axis.get_ylim()[0]+0.35*(axis.get_ylim()[1]-axis.get_ylim()[0]),'P/LP P-value={0:.2e}'.format(test_stats_geno[1]),fontsize=20,fontweight='bold')

    axis.set_title('Pack-Year Test Dataset (N={0:d})'.format(cph_table.shape[0]),fontsize=20,fontweight='bold')
    axis.set_ylabel('Fraction of Patients\nUnaffected by\nPersistent Hematuria')
    axis.set_xlabel('Patient Age')
    plt.savefig(fig_direc+'Genotype_Hematuria_WithSmoking_KaplanMeier.svg')
    plt.close()

Пример #15
0
ax.set_yticks([0.5, 0.6, 0.7, 0.8, 0.9, 1.0])
ax.set_xticklabels([
    '4_feat.', "Prediction 1", 'Prediction 2', 'Prediction 3', 'Prediction 4'
],
                   ha='center')
#ax.set_xticklabels(['4_feat.', "Prediction 3", 'vd = 2%', 'vd = 2.5%',
#                    'vd = 3%', 'vd = 3.8%', 'vd = 5%'], ha = 'center')

#%%
# Kaplan-Meier fit with 'bio_rec_6_delay' as durations and 'bio_rec_6' as
# the event column. NOTE(review): `data` is defined outside this excerpt.
kmf = KaplanMeierFitter()
T = data['bio_rec_6_delay']
E = data['bio_rec_6']
kmf.fit(T, event_observed=E)

# The plot call is not given `ax`; presumably it draws on the axes created
# on the line above, since that is the current one - confirm.
fig, ax = plt.subplots()
kmf.plot_survival_function(at_risk_counts=True)
plt.tight_layout()

#%%
#cov = 'age'
#nameCov = 'Age'
#unitCov = 'years'

#cov = 'tum_vol'
#nameCov = 'Tumor volume'
#unitCov = 'mm$^3$'

#cov = 'ADC_ave'
#nameCov = 'Average ADC'
#unitCov = ''
Пример #16
0
    delimiter="\t")

# Encode the 'Class' column as 1 for 'YES' and 0 for 'NO'; it is used as the
# event column by both models below.
regression_dataset['Class'] = regression_dataset['Class'].map({
    'YES': 1,
    'NO': 0
})
# Add one indicator column per 'clusters' value, then drop the raw 'id' and
# 'clusters' columns so that they are not passed to the Cox model.
regression_dataset = regression_dataset.join(
    pd.get_dummies(regression_dataset['clusters']))
regression_dataset = regression_dataset.drop(['id', 'clusters'], axis=1)

print(regression_dataset.head(100))
print(regression_dataset.dtypes)

# Kaplan-Meier estimate over all rows, saved as a PNG.
kmf = KaplanMeierFitter()
kmf.fit(regression_dataset['Time'], event_observed=regression_dataset['Class'])

# NOTE(review): bare expression; its value is discarded when run as a script.
kmf.survival_function_
survival_function = kmf.plot_survival_function()  # or just kmf.plot()
survival_function.get_figure().savefig("KaplanMeierFitter.png")

# Using Cox Proportional Hazards model
# Every remaining column other than 'Time' and 'Class' is passed to the fit.
cph = CoxPHFitter(penalizer=0.1)
cph.fit(regression_dataset, 'Time', event_col='Class')
cph.print_summary()

# predict=cph.predict_survival_function(regression_dataset).plot()
# predict.get_figure().savefig("predicted.png")

ax = cph.plot()
ax.get_figure().savefig("CoxPHFitter.png")
Пример #17
0
# %% Grouped km using lifelines
from lifelines.plotting import add_at_risk_counts

# Durations per group. NOTE(review): `df` is defined outside this excerpt.
T1 = df.query('group==1')['time']
T2 = df.query('group==2')['time']

# 'Died' is passed to the fits below as the event-observed column.
C1 = df.query('group==1')['Died']
C2 = df.query('group==2')['Died']

# One fitter per group, so both fitted models stay available for the
# at-risk table.
fig, ax = plt.subplots()
kmf1 = KaplanMeierFitter()
kmf1.fit(T1, C1,label='Group 1',)
kmf2 = KaplanMeierFitter()
kmf2.fit(T2, C2,label = 'Group 2',)

# Both curves go on the same axes, without confidence intervals.
ax = kmf1.plot_survival_function(ax = ax,  ci_show=False)
ax = kmf2.plot_survival_function(ax = ax, 
    label = 'Group 2', ci_show=False)
add_at_risk_counts(kmf1, kmf2, ax=ax)

plt.tight_layout();
# %% [markdown]
# ## Kaplan-Meier curve using _lifelines_
#
# We can also format the axes to show percentages rather than proportions. 
#
# We'll demo this with a single KM curve
# %%
import matplotlib.ticker as mtick

ax = kmf.plot()
Пример #18
0
def pipeline(rem_zeros,
             compute,
             make_plot,
             seed,
             dir,
             method='ttest',
             test_size=0.15):

    # loads all data
    os.chdir(dir)
    labels = pd.read_csv('Label_selected.csv')
    ndx = labels.index[
        labels['label'] > -1].tolist()  #gets indices of patients to use
    lbl = labels.iloc[ndx, [
        0, 4
    ]]  #makes a new dataframe of patients to use (their IDs and survival response)
    surv = labels.iloc[ndx, :]
    genes = pd.read_csv('pr_coding_feats.csv')
    genes = genes.set_index('ensembl_gene_id')

    os.chdir(dir)
    gene = pd.read_csv('gene.csv', index_col=0)
    miRNA = pd.read_csv('miRNA.csv', index_col=0)
    meth = pd.read_csv('meth.csv', index_col=0)
    CNV = pd.read_csv('CNV.csv', index_col=0)

    os.chdir(dir)

    # optionally removes rows (features) that are all 0 across patients
    if rem_zeros == True:
        gene = gene.loc[~(gene == 0).all(axis=1)]
        miRNA = miRNA.loc[~(miRNA == 0).all(axis=1)]

    # splitting labels into train set and validation set
    train_labels, test_labels, train_class, test_class = train_test_split(
        lbl['case_id'], lbl, test_size=test_size, random_state=seed)

    # removes features (rows) that have any na in them
    meth = meth.dropna(axis='rows')
    miRNA = miRNA.dropna(axis='rows')
    gene = gene.dropna(axis='rows')
    CNV = CNV.dropna(axis='rows')

    # divides individual modalities into train and test sets based on same patient splits and transposes
    miRNA_train = miRNA[train_labels].T
    miRNA_test = miRNA[test_labels].T
    gene_train = gene[train_labels].T
    gene_test = gene[test_labels].T
    CNV_train = CNV[train_labels].T
    CNV_test = CNV[test_labels].T
    meth_train = meth[train_labels].T
    meth_test = meth[test_labels].T

    # normalizing gene expression and miRNA datasets
    miRNA_train_copy = pd.DataFrame(miRNA_train,
                                    copy=True)  # copies the original dataframe
    miRNA_scaler = preprocessing.MinMaxScaler().fit(miRNA_train)
    miRNA_train = miRNA_scaler.transform(miRNA_train)
    miRNA_train = pd.DataFrame(miRNA_train,
                               columns=list(miRNA_train_copy)).set_index(
                                   miRNA_train_copy.index.values)

    miRNA_test_copy = pd.DataFrame(miRNA_test,
                                   copy=True)  # copies the original dataframe
    miRNA_test = miRNA_scaler.transform(miRNA_test)
    miRNA_test = pd.DataFrame(miRNA_test,
                              columns=list(miRNA_test_copy)).set_index(
                                  miRNA_test_copy.index.values)

    gene_train_copy = pd.DataFrame(gene_train,
                                   copy=True)  # copies the original dataframe
    gene_scaler = preprocessing.MinMaxScaler().fit(gene_train)
    gene_train = gene_scaler.transform(gene_train)
    gene_train = pd.DataFrame(gene_train,
                              columns=list(gene_train_copy)).set_index(
                                  gene_train_copy.index.values)

    gene_test_copy = pd.DataFrame(gene_test,
                                  copy=True)  # copies the original dataframe
    gene_test = gene_scaler.transform(gene_test)
    gene_test = pd.DataFrame(gene_test,
                             columns=list(gene_test_copy)).set_index(
                                 gene_test_copy.index.values)

    train_class = train_class.set_index(
        'case_id')  # changes first column to be indices
    test_class = test_class.set_index(
        'case_id')  # changes first column to be indices

    # makes copies of the y dataframe because tr_ind alters it
    train_class_copy1,train_class_copy2,train_class_copy3,train_class_copy4,train_class_copy5 = pd.DataFrame(train_class, copy=True),\
                                                                                 pd.DataFrame(train_class, copy=True),\
                                                                                 pd.DataFrame(train_class, copy=True),\
                                                                                 pd.DataFrame(train_class, copy=True),\
                                                                                 pd.DataFrame(train_class, copy=True)

    if make_plot == 'valplot':
        gen_curve(gene, lbl, 'gene', 10)
        gen_curve(miRNA, lbl, 'miRNA', 10)
        gen_curve(meth, lbl, 'meth', 10)
        gen_curve(CNV, lbl, 'CNV', 10)
        return

    # makes copies of training data
    miRNA_train_copy2 = pd.DataFrame(miRNA_train, copy=True)
    meth_train_copy2 = pd.DataFrame(meth_train, copy=True)
    CNV_train_copy2 = pd.DataFrame(CNV_train, copy=True)
    gene_train_copy2 = pd.DataFrame(gene_train, copy=True)

    if make_plot == 'fplot':
        make_fplot = True
    else:
        make_fplot = False

    if compute == 'recompute':
        # Runs CV script to generate clfs w best parameters
        clf_gene, fea_gene, _ = do_cv(gene_train, train_class_copy1, gene_test,
                                      test_class, 'ttest', 'gene', 40, 2,
                                      make_fplot)
        clf_miRNA, fea_miRNA, _ = do_cv(miRNA_train, train_class_copy2,
                                        miRNA_test, test_class, 'minfo',
                                        'miRNA', 24, 2, make_fplot)
        clf_meth, fea_meth, _ = do_cv(meth_train, train_class_copy3, meth_test,
                                      test_class, 'minfo', 'meth', 60, 2,
                                      make_fplot)
        clf_CNV, fea_CNV, _ = do_cv(CNV_train, train_class_copy4, CNV_test,
                                    test_class, 'minfo', 'CNV', 50, 2,
                                    make_fplot)
    elif compute == 'custom':
        clf_gene, fea_gene, _ = do_cv(gene_train, train_class_copy1, gene_test,
                                      test_class, method, 'gene', 100, 2,
                                      make_fplot)
        clf_miRNA, fea_miRNA, _ = do_cv(miRNA_train, train_class_copy2,
                                        miRNA_test, test_class, method,
                                        'miRNA', 100, 2, make_fplot)
        clf_meth, fea_meth, _ = do_cv(meth_train, train_class_copy3, meth_test,
                                      test_class, method, 'meth', 100, 2,
                                      make_fplot)
        clf_CNV, fea_CNV, _ = do_cv(CNV_train, train_class_copy4, CNV_test,
                                    test_class, method, 'CNV', 100, 2,
                                    make_fplot)
    elif compute == 'precomputed':
        # loads up the precomputed best classifiers and selected features
        clf_gene = load('clf_gene4.joblib')
        fea_gene = load('fea_gene4.joblib')
        clf_meth = load('clf_meth4.joblib')
        fea_meth = load('fea_meth4.joblib')
        clf_CNV = load('clf_CNV4.joblib')
        fea_CNV = load('fea_CNV4.joblib')
        clf_miRNA = load('clf_miRNA4.joblib')
        fea_miRNA = load('fea_miRNA4.joblib')
    else:
        return "enter a valid parameter for compute"

    # shrinks test feature matrix to contain only selected features
    miRNA_test = miRNA_test[fea_miRNA]
    gene_test = gene_test[fea_gene]
    meth_test = meth_test[fea_meth]
    CNV_test = CNV_test[fea_CNV]

    # gets acc results from predicting on test set with individual modalities
    gene_ind_res = clf_gene.score(gene_test, test_class)
    meth_ind_res = clf_meth.score(meth_test, test_class)
    CNV_ind_res = clf_CNV.score(CNV_test, test_class)
    miRNA_ind_res = clf_miRNA.score(miRNA_test, test_class)

    # calculates auc for each modality
    c1_gene, c2_gene, _ = roc_curve(
        test_class.values.ravel(),
        clf_gene.decision_function(gene_test).ravel())
    c1_miRNA, c2_miRNA, _ = roc_curve(
        test_class.values.ravel(),
        clf_miRNA.decision_function(miRNA_test).ravel())
    c1_CNV, c2_CNV, _ = roc_curve(test_class.values.ravel(),
                                  clf_CNV.decision_function(CNV_test).ravel())
    c1_meth, c2_meth, _ = roc_curve(
        test_class.values.ravel(),
        clf_meth.decision_function(meth_test).ravel())
    area_gene = auc(c1_gene, c2_gene)
    area_miRNA = auc(c1_miRNA, c2_miRNA)
    area_CNV = auc(c1_CNV, c2_CNV)
    area_meth = auc(c1_meth, c2_meth)

    if make_plot == 'ROC_gene' or make_plot == 'ROC_miRNA' or make_plot == 'ROC_meth' or make_plot == 'ROC_CNV':
        if make_plot == 'ROC_gene':
            c1, c2, area = c1_gene, c2_gene, area_gene
        elif make_plot == 'ROC_miRNA':
            c1, c2, area = c1_miRNA, c2_miRNA, area_miRNA
        elif make_plot == 'ROC_meth':
            c1, c2, area = c1_meth, c2_meth, area_meth
        elif make_plot == 'ROC_CNV':
            c1, c2, area = c1_CNV, c2_CNV, area_CNV

        plt.title('Receiver Operating Characteristic')
        plt.plot(c1, c2, 'b', label='AUC = %0.2f' %
                 area)  # change params to change modality
        plt.legend(loc='lower right')
        plt.plot([0, 1], [0, 1], 'r--')
        plt.xlim([0, 1])
        plt.ylim([0, 1])
        plt.ylabel('True Positive Rate')
        plt.xlabel('False Positive Rate')
        plt.show()

    # START OF DATA INTEGRATION

    fins = []
    areas = []
    fin1s = []
    combinations = [["meth"], ["miRNA"], ["gene"], ["CNV"], ["meth", "miRNA"],
                    ["meth", "gene"], ["meth", "CNV"], ["miRNA", "gene"],
                    ["miRNA", "CNV"], ["gene", "CNV"],
                    ["meth", "miRNA", "gene"], ["meth", "miRNA", "CNV"],
                    ["miRNA", "gene", "CNV"], ["meth", "gene", "CNV"],
                    ["meth", "miRNA", "gene", "CNV"]]

    # shrink training feature matricies to selected features
    miRNA_train_copy2 = miRNA_train_copy2[fea_miRNA]
    gene_train_copy2 = gene_train_copy2[fea_gene]
    meth_train_copy2 = meth_train_copy2[fea_meth]
    CNV_train_copy2 = CNV_train_copy2[fea_CNV]

    # gets prediction probabilities for samples in the training data
    pred_miRNA = clf_miRNA.predict_proba(miRNA_train_copy2)[:, 0]
    pred_gene = clf_gene.predict_proba(gene_train_copy2)[:, 0]
    pred_CNV = clf_CNV.predict_proba(CNV_train_copy2)[:, 0]
    pred_meth = clf_meth.predict_proba(meth_train_copy2)[:, 0]

    # creates dataframe with the prediction probabilities
    new_feats = {
        'sample': miRNA_train.index.values,
        'miRNA': pred_miRNA,
        'gene': pred_gene,
        'meth': pred_meth,
        'CNV': pred_CNV
    }
    new_feats = pd.DataFrame(data=new_feats)
    new_feats = new_feats.set_index('sample')

    clfs = []
    cvals = []
    for com in combinations:
        print(com)
        clf = tr_comb_grid(new_feats[com], train_class_copy5)
        # clf = bayes(new_feats[com],train_class_copy5)
        clfs.append(clf)

        pred_miRNA = clf_miRNA.predict_proba(miRNA_test)[:, 0]
        pred_gene = clf_gene.predict_proba(gene_test)[:, 0]
        pred_CNV = clf_CNV.predict_proba(CNV_test)[:, 0]
        pred_meth = clf_meth.predict_proba(meth_test)[:, 0]

        # creates new feature matrix for test predictions
        new_feats_val = {
            'sample': miRNA_test.index.values,
            'miRNA': pred_miRNA,
            'gene': pred_gene,
            'meth': pred_meth,
            'CNV': pred_CNV
        }
        new_feats_val = pd.DataFrame(data=new_feats_val)
        new_feats_val = new_feats_val.set_index('sample')

        fin = clf.score(new_feats_val[com], test_class)
        pred = clf.decision_function(new_feats_val[com])
        c1, c2, _ = roc_curve(test_class.values.ravel(), pred.ravel())
        area = auc(c1, c2)
        cvals.append([c1, c2, area])
        print('auc: ', area)
        print('acc: ', fin)
        areas.append(area)
        fins.append(fin)

    # substitutes list entries for single modalities that were changed during integration
    fins[0] = meth_ind_res
    fins[1] = miRNA_ind_res
    fins[2] = gene_ind_res
    fins[3] = CNV_ind_res
    areas[0] = area_meth
    areas[1] = area_miRNA
    areas[2] = area_gene
    areas[3] = area_CNV
    clfs[0] = clf_meth
    clfs[1] = clf_miRNA
    clfs[2] = clf_gene
    clfs[3] = clf_CNV

    indx = np.argmax(fins)
    tr_score = clfs[indx].score(new_feats[combinations[indx]],
                                train_class_copy5)
    te_score = clfs[indx].score(new_feats_val[combinations[indx]], test_class)

    clf = clfs[indx]

    if make_plot == 'barplot':
        n_groups = 15

        fig, ax = plt.subplots()

        index = np.arange(n_groups)
        bar_width = 0.35

        opacity = 0.4

        rects2 = ax.bar(index,
                        areas,
                        bar_width,
                        alpha=opacity,
                        color='r',
                        label='auc')

        rects3 = ax.bar(index + bar_width,
                        fins,
                        bar_width,
                        alpha=opacity,
                        color='b',
                        label='score')

        ax.set_xlabel('Combination')
        ax.set_ylabel('Scores')

        ax.set_xticks(index + bar_width / 2)
        ax.set_xticklabels([
            'meth', 'miRNA', 'gene', 'CNV', 'meth\nmiRNA', 'meth\ngene',
            'meth\nCNV', 'miRNA\ngene', 'miRNA\nCNV', 'gene\nCNV',
            'meth\nmiRNA\ngene', 'meth\nmiRNA\nCNV', 'miRNA\ngene\nCNV',
            'meth\ngene\nCNV', 'meth\nmiRNA\ngene\nCNV'
        ])
        ax.legend()

        fig.tight_layout()
        plt.show()
    elif make_plot == 'ROC_int':
        plt.title('Receiver Operating Characteristic')
        plt.plot(cvals[indx][0],
                 cvals[indx][1],
                 'b',
                 label='AUC = %0.2f' % cvals[indx][2])
        plt.legend(loc='lower right')
        plt.plot([0, 1], [0, 1], 'r--')
        plt.xlim([0, 1])
        plt.ylim([0, 1])
        plt.ylabel('True Positive Rate')
        plt.xlabel('False Positive Rate')
        plt.show()
    elif make_plot == 'KM' or make_plot == 'KM Gene' or make_plot == 'KM miRNA'\
            or make_plot == 'KM Meth' or make_plot == 'KM CNV' or make_plot == 'KM Integrated':
        ax = plt.subplot(111)
        kmf = KaplanMeierFitter()
        m = max(surv["days_to_death"])
        fill_max = {"days_to_death": m}
        surv = surv.fillna(value=fill_max)
        T = surv["days_to_death"]
        surv = surv.replace("alive", False)
        surv = surv.replace("dead", True)
        E = surv["vital_status"]
        if make_plot == 'KM':
            kmf.fit(T, event_observed=E)
            kmf.plot()
            class0 = (surv["label"] == 0)
            kmf.fit(T[class0],
                    event_observed=E[class0],
                    label="Low Survival (<5 years)")
            kmf.plot_survival_function(ax=ax, ci_show=False, fontsize=20)
            kmf.fit(T[~class0],
                    event_observed=E[~class0],
                    label="High Survival (>= 5 years)")
            kmf.plot_survival_function(ax=ax, ci_show=False, fontsize=20)
            ax.set_xlabel("Duration (days)", fontsize=20)
            ax.set_ylabel("Percent Alive", fontsize=20)
            ax.set_title("Breast Cancer Kaplan Meier Survival Curve",
                         fontsize=32)
        elif make_plot == "Kaplan Meier Integrated":
            surv2 = surv.copy(True)
            surv2 = surv2.loc[surv2["case_id"].isin(test_labels.values)]
            surv2 = surv2.set_index("case_id")
            surv2 = surv2.reindex(test_class.index.values)
            prd = clf.predict(new_feats_val[["meth", "miRNA", "gene"]])
            surv2["label_new"] = prd

            count = 0
            preds = clf.predict(new_feats_val[["meth", "miRNA", "gene"]])
            for i in range(test_class.shape[0]):
                if preds[i] == test_class.iloc[i, 0]:
                    count += 1
            print(count)

            T2 = surv2["days_to_death"]
            E2 = surv2["vital_status"]
            class02 = (surv2["label_new"] == 0)
            kmf.fit(T2[class02],
                    event_observed=E2[class02],
                    label="Low Survival")
            kmf.plot_survival_function(ax=ax, ci_show=False)
            kmf.fit(T2[~class02],
                    event_observed=E2[~class02],
                    label="High Survival")
            kmf.plot_survival_function(ax=ax, ci_show=False)
            ax.set_xlabel("Duration (days)")
            ax.set_ylabel("Percent Alive")
            ax.set_title("Integrated Classifier Kaplan Meier Survival Curve")
            ax.set_ylim((0, 1))

            plt.show()
        elif make_plot == 'KM Gene':
            surv2 = surv.copy(True)
            surv2 = surv2.loc[surv2["case_id"].isin(test_labels.values)]
            surv2 = surv2.set_index("case_id")
            surv2 = surv2.reindex(test_class.index.values)
            prd = clf_gene.predict(gene_test.loc[:, fea_gene])
            surv2["label_new"] = prd

            count = 0
            preds = clf_gene.predict(gene_test.loc[:, fea_gene])
            for i in range(test_class.shape[0]):
                if preds[i] == test_class.iloc[i, 0]:
                    count += 1
            print(count)

            T2 = surv2["days_to_death"]
            E2 = surv2["vital_status"]
            class02 = (surv2["label_new"] == 0)
            kmf.fit(T2[class02],
                    event_observed=E2[class02],
                    label="Low Survival")
            kmf.plot_survival_function(ax=ax, ci_show=False)
            kmf.fit(T2[~class02],
                    event_observed=E2[~class02],
                    label="High Survival")
            kmf.plot_survival_function(ax=ax, ci_show=False)
            ax.set_xlabel("Duration (days)")
            ax.set_ylabel("Percent Alive")
            ax.set_title("Gene Expression Kaplan Meier Survival Curve")
            ax.set_ylim((0, 1))

            plt.show()
        elif make_plot == 'KM miRNA':
            surv2 = surv.copy(True)
            surv2 = surv2.loc[surv2["case_id"].isin(test_labels.values)]
            surv2 = surv2.set_index("case_id")
            surv2 = surv2.reindex(test_class.index.values)
            prd = clf_miRNA.predict(miRNA_test.loc[:, fea_miRNA])
            surv2["label_new"] = prd

            count = 0
            preds = clf_miRNA.predict(miRNA_test.loc[:, fea_miRNA])
            for i in range(test_class.shape[0]):
                if preds[i] == test_class.iloc[i, 0]:
                    count += 1
            print(count)

            T2 = surv2["days_to_death"]
            E2 = surv2["vital_status"]
            class02 = (surv2["label_new"] == 0)
            kmf.fit(T2[class02],
                    event_observed=E2[class02],
                    label="Low Survival")
            kmf.plot_survival_function(ax=ax, ci_show=False)
            kmf.fit(T2[~class02],
                    event_observed=E2[~class02],
                    label="High Survival")
            kmf.plot_survival_function(ax=ax, ci_show=False)
            ax.set_xlabel("Duration (days)")
            ax.set_ylabel("Percent Alive")
            ax.set_title("miRNA Expression Kaplan Meier Survival Curve")
            ax.set_ylim((0, 1))

            plt.show()
        elif make_plot == 'KM Meth':
            surv2 = surv.copy(True)
            surv2 = surv2.loc[surv2["case_id"].isin(test_labels.values)]
            surv2 = surv2.set_index("case_id")
            surv2 = surv2.reindex(test_class.index.values)
            prd = clf_meth.predict(meth_test.loc[:, fea_meth])
            surv2["label_new"] = prd

            count = 0
            preds = clf_meth.predict(meth_test.loc[:, fea_meth])
            for i in range(test_class.shape[0]):
                if preds[i] == test_class.iloc[i, 0]:
                    count += 1
            print(count)

            T2 = surv2["days_to_death"]
            E2 = surv2["vital_status"]
            class02 = (surv2["label_new"] == 0)
            kmf.fit(T2[class02],
                    event_observed=E2[class02],
                    label="Low Survival")
            kmf.plot_survival_function(ax=ax, ci_show=False)
            kmf.fit(T2[~class02],
                    event_observed=E2[~class02],
                    label="High Survival")
            kmf.plot_survival_function(ax=ax, ci_show=False)
            ax.set_xlabel("Duration (days)")
            ax.set_ylabel("Percent Alive")
            ax.set_title("DNA Methylation Kaplan Meier Survival Curve")
            ax.set_ylim((0, 1))

            plt.show()
        elif make_plot == 'KM CNV':
            surv2 = surv.copy(True)
            surv2 = surv2.loc[surv2["case_id"].isin(test_labels.values)]
            surv2 = surv2.set_index("case_id")
            surv2 = surv2.reindex(test_class.index.values)
            prd = clf_CNV.predict(CNV_test.loc[:, fea_CNV])
            surv2["label_new"] = prd

            count = 0
            preds = clf_CNV.predict(CNV_test.loc[:, fea_CNV])
            for i in range(test_class.shape[0]):
                if preds[i] == test_class.iloc[i, 0]:
                    count += 1
            print(count)

            T2 = surv2["days_to_death"]
            E2 = surv2["vital_status"]
            class02 = (surv2["label_new"] == 0)
            kmf.fit(T2[class02],
                    event_observed=E2[class02],
                    label="Low Survival")
            kmf.plot_survival_function(ax=ax, ci_show=False)
            kmf.fit(T2[~class02],
                    event_observed=E2[~class02],
                    label="High Survival")
            kmf.plot_survival_function(ax=ax, ci_show=False)
            ax.set_xlabel("Duration (days)")
            ax.set_ylabel("Percent Alive")
            ax.set_title("CNV Kaplan Meier Survival Curve")
            ax.set_ylim((0, 1))

            plt.show()
    elif make_plot == 'Gene Distribution' or make_plot == 'miRNA Distribution' or make_plot == 'Meth Distribution' \
            or make_plot == 'CNV Distribution':
        if make_plot == 'Gene Distribution':
            lbl = lbl.set_index("case_id")
            boxframe = gene.loc[fea_gene, :].T
            boxframe = boxframe.iloc[:, 0:10]
            cols = boxframe.columns
            g = genes.loc[cols, :]
            boxframe = pd.concat([boxframe, lbl[["label"]]], axis=1)
            c1_gene = boxframe[boxframe["label"] == 1]
            c1_gene = c1_gene.iloc[:, 0:10]
            c2_gene = boxframe[boxframe["label"] == 0]
            c2_gene = c2_gene.iloc[:, 0:10]
            for i in range(0, 10):
                cl = ['High', 'Low']
                ind = [1, 2]
                plt.subplot(2, 5, i + 1)
                p1 = c1_gene.iloc[:, i]
                p2 = c2_gene.iloc[:, i]
                genes_to_plot = (p1, p2)
                plt.boxplot(genes_to_plot)
                plt.ylabel('Gene Expression (FPKM)', fontsize=16)
                if i > 4:
                    plt.xlabel('Survival Class', fontsize=16)
                plt.xticks(ind, cl, fontsize=16)
                max = 0
                max1 = (p1.values.max())
                max2 = (p2.values.max())
                if max1 > max2:
                    max = max1
                else:
                    max = max2
                plt.yticks([0, max], rotation='vertical')
                # print(cols[i])
                # title_str = 'Survival Breakdown of CNV Features by Gain/Loss: {}'
                # title_str = title_str.format(cols[i])
                plt.title(genes.loc[cols[i], 'hgnc_symbol'], fontsize=18)
            plt.suptitle('Top 10 Gene Feature Distributions by Survival Class',
                         fontsize=22)
            plt.show()
        elif make_plot == 'miRNA Distribution':
            lbl = lbl.set_index("case_id")
            boxframe = miRNA.loc[fea_miRNA, :].T
            boxframe = boxframe.iloc[:, 0:10]
            cols = boxframe.columns
            boxframe = pd.concat([boxframe, lbl[["label"]]], axis=1)
            c1_miRNA = boxframe[boxframe["label"] == 1]
            c1_miRNA = c1_miRNA.iloc[:, 0:10]
            c2_miRNA = boxframe[boxframe["label"] == 0]
            c2_miRNA = c2_miRNA.iloc[:, 0:10]
            for i in range(0, 10):
                cl = ['High', 'Low']
                ind = [1, 2]
                plt.subplot(2, 5, i + 1)
                p1 = c1_miRNA.iloc[:, i]
                p2 = c2_miRNA.iloc[:, i]
                genes_to_plot = (p1, p2)
                plt.boxplot(genes_to_plot)
                plt.ylabel('miRNA Expression (RPM)', fontsize=16)
                if i > 4:
                    plt.xlabel('Survival Class', fontsize=16)
                plt.xticks(ind, cl, fontsize=16)
                max = 0
                max1 = (p1.values.max())
                max2 = (p2.values.max())
                if max1 > max2:
                    max = max1
                else:
                    max = max2
                plt.yticks([0, max], rotation='vertical')
                plt.title(cols[i], fontsize=18)
            plt.suptitle(
                'Top 10 miRNA Feature Distributions by Survival Class',
                fontsize=22)
            plt.show()
        elif make_plot == 'Meth Distribution':
            lbl = lbl.set_index("case_id")
            boxframe = meth.loc[fea_meth, :].T
            boxframe = boxframe.iloc[:, 0:10]
            cols = boxframe.columns
            boxframe = pd.concat([boxframe, lbl[["label"]]], axis=1)
            c1_meth = boxframe[boxframe["label"] == 1]
            c1_mmeth = c1_meth.iloc[:, 0:10]
            c2_meth = boxframe[boxframe["label"] == 0]
            c2_meth = c2_meth.iloc[:, 0:10]
            for i in range(0, 10):
                cl = ['High', 'Low']
                ind = [1, 2]
                plt.subplot(2, 5, i + 1)
                p1 = c1_meth.iloc[:, i]
                p2 = c2_meth.iloc[:, i]
                genes_to_plot = (p1, p2)
                plt.boxplot(genes_to_plot)
                plt.ylabel('DNA Methylation (Beta Value)', fontsize=16)
                if i > 4:
                    plt.xlabel('Survival Class', fontsize=16)
                plt.xticks(ind, cl, fontsize=16)
                max = 0
                max1 = (p1.values.max())
                max2 = (p2.values.max())
                if max1 > max2:
                    max = max1
                else:
                    max = max2
                plt.yticks([0, max], rotation='vertical')
                plt.title(cols[i], fontsize=18)
            plt.suptitle(
                'Top 10 DNA Methylation Feature Distributions by Survival Class',
                fontsize=22)
            plt.show()
        elif make_plot == 'CNV Distribution':
            lbl = lbl.set_index("case_id")
            cnv_lbl = pd.concat([CNV.loc[fea_CNV, :].T, lbl["label"]], axis=1)
            c1_CNV = cnv_lbl[cnv_lbl["label"] == 1]
            c1_CNV = c1_CNV.iloc[:, 0:16]
            cols = c1_CNV.columns.values
            c2_CNV = cnv_lbl[cnv_lbl["label"] == 0]
            c2_CNV = c2_CNV.iloc[:, 0:16]
            for i in range(0, 10):
                c1_CNV_1 = c1_CNV.iloc[:, i][c1_CNV.iloc[:,
                                                         i] == 1].sum().sum()
                c1_CNV_neg1 = -1 * (
                    c1_CNV.iloc[:, i][c1_CNV.iloc[:, i] == -1].sum().sum())
                c1_CNV_0 = (247 * 1) - (c1_CNV_1 + c1_CNV_neg1)
                c1_tot = (247 * 1)
                c1_pct_1 = c1_CNV_1 / c1_tot
                c1_pct_neg1 = c1_CNV_neg1 / c1_tot
                c1_pct_0 = c1_CNV_0 / c1_tot

                c2_CNV_1 = c2_CNV.iloc[:, i][c2_CNV.iloc[:,
                                                         i] == 1].sum().sum()
                c2_CNV_neg1 = -1 * (
                    c2_CNV.iloc[:, i][c2_CNV.iloc[:, i] == -1].sum().sum())
                c2_CNV_0 = (95 * 1) - (c2_CNV_1 + c2_CNV_neg1)
                c2_tot = (95 * 1)
                c2_pct_1 = c2_CNV_1 / c2_tot
                c2_pct_neg1 = c2_CNV_neg1 / c2_tot
                c2_pct_0 = c2_CNV_0 / c2_tot
                ones = (c1_pct_1, c2_pct_1)
                negones = (c1_pct_neg1, c2_pct_neg1)
                zeros = (c1_pct_0, c2_pct_0)
                hold = tuple(map(operator.add, zeros, negones))

                cl = ['High', 'Low']
                ind = [1, 2]
                width = 0.5

                plt.subplot(2, 5, i + 1)
                p1 = plt.bar(ind, negones, width)
                p2 = plt.bar(ind, zeros, width, bottom=negones)
                p3 = plt.bar(ind, ones, width, bottom=hold)
                if i == 0 or i == 5:
                    plt.ylabel('Percent of Top Features', fontsize=16)
                plt.xticks(ind, cl, fontsize=16)
                if i > 4:
                    plt.xlabel('Survival Class', fontsize=16)
                plt.yticks(np.arange(0, 1, 0.1))
                plt.legend((p1[0], p2[0], p3[0]), ('-1', '0', '1'))
                plt.title(genes.loc[cols[i][0:15], 'hgnc_symbol'], fontsize=18)
            plt.suptitle('Top 10 CNV Feature Distributions by Survival Class',
                         fontsize=22)
            plt.show()
    elif make_plot == 'Heat Map Gene' or make_plot == 'Heat Map miRNA' or make_plot == 'Heat Map Meth' \
        or make_plot == 'Heat Map CNV':
        if make_plot == 'Heat Map Gene':
            clf_gene_h, fea_gene_h, heat_gene = do_cv(gene_train, train_class,
                                                      gene_test, test_class,
                                                      'ttest', 'gene', 100, 2)
            n_c = 27
            n_feat = 49
            up_bound = 100
            lin_kern = np.empty(
                [n_c,
                 n_feat])  # c vals ( see in cv_2 ), # feats (n_max - 2 / 2)
            poly_kern = np.empty([n_c, n_feat])
            rbf_kern = np.empty([n_c, n_feat])
            sig_kern = np.empty([n_c, n_feat])

            linvec = []
            polvec = []
            rbfvec = []
            sigvec = []
            heat_gene_out = heat_gene[1:len(heat_gene), :]
            for n, i in enumerate(heat_gene_out):
                if n % 4 == 0:
                    linvec.append(i[3])
                elif n % 4 == 1:
                    polvec.append(i[3])
                elif n % 4 == 2:
                    rbfvec.append(i[3])
                elif n % 4 == 3:
                    sigvec.append(i[3])
            for n in range(0, n_feat):
                for c in range(0, n_c):
                    lin_kern[c][n] = linvec[c + (27 * n)]
            for n in range(0, n_feat):
                for c in range(0, n_c):
                    poly_kern[c][n] = polvec[c + (27 * n)]
            for n in range(0, n_feat):
                for c in range(0, n_c):
                    rbf_kern[c][n] = rbfvec[c + (27 * n)]
            for n in range(0, n_feat):
                for c in range(0, n_c):
                    sig_kern[c][n] = sigvec[c + (27 * n)]
            ax = plt.subplot(111)
            chosen = []
            out_kern = None
            if clf_gene_h.get_params(False)['kernel'] == 'linear':
                im = plt.imshow(lin_kern, cmap=plt.cm.get_cmap('coolwarm', 50))
                chosen = linvec
                out_kern = lin_kern.copy(True)
            elif clf_gene_h.get_params(False)['kernel'] == 'rbf':
                im = plt.imshow(rbf_kern, cmap=plt.cm.get_cmap('coolwarm', 50))
                chosen = rbfvec
                out_kern = rbf_kern.copy(True)
            elif clf_gene_h.get_params(False)['kernel'] == 'sigmoid':
                im = plt.imshow(sig_kern, cmap=plt.cm.get_cmap('coolwarm', 50))
                chosen = sigvec
                out_kern = sig_kern.copy(True)
            elif clf_gene_h.get_params(False)['kernel'] == 'poly':
                im = plt.imshow(poly_kern,
                                cmap=plt.cm.get_cmap('coolwarm', 50))
                chosen = polvec
                out_kern = poly_kern.copy(True)
            plt.colorbar()
            plt.clim(min(chosen), max(chosen))

            ax.set_xticks(np.arange(n_feat))
            ax.set_yticks(np.arange(n_c))
            ax.set_xticklabels(range(2, up_bound, 2), fontsize=12)
            ax.set_yticklabels([
                .001, .005, 0.1, .5, 1, 1.5, 2, 2.5, 3, 4, 5, 6, 7, 8, 9, 10,
                11, 12, 15, 20, 25, 50, 75, 100, 150, 200, 250
            ],
                               fontsize=14)
            ax.set_xlabel("Number of features", fontsize=20)
            ax.set_ylabel("C value", fontsize=20)
            plt.setp(ax.get_xticklabels(), rotation=90, rotation_mode="anchor")
            for j in range(0, n_c):
                for k in range(0, n_feat):
                    text = ax.text(k,
                                   j,
                                   '%.2f' % out_kern[j, k],
                                   ha="center",
                                   va="center",
                                   color="w",
                                   fontsize=6)
            ax.set_title(
                "Gene Expression Parameter Sensitivity (Accuracy) - %s kernel"
                % clf_gene_h.get_params(False)['kernel'],
                fontsize=24)
            plt.show()
        elif make_plot == 'Heat Map miRNA':
            clf_miRNA_h, fea_miRNA_h, heat_miRNA = do_cv(
                miRNA_train, train_class, miRNA_test, test_class, 'minfo',
                'miRNA', 100, 2)
            n_c = 27
            n_feat = 49
            up_bound = 100
            lin_kern = np.empty(
                [n_c,
                 n_feat])  # c vals ( see in cv_2 ), # feats (n_max - 2 / 2)
            poly_kern = np.empty([n_c, n_feat])
            rbf_kern = np.empty([n_c, n_feat])
            sig_kern = np.empty([n_c, n_feat])

            linvec = []
            polvec = []
            rbfvec = []
            sigvec = []
            heat_miRNA_out = heat_miRNA[1:len(heat_miRNA), :]
            for n, i in enumerate(heat_miRNA_out):
                if n % 4 == 0:
                    linvec.append(i[3])
                elif n % 4 == 1:
                    polvec.append(i[3])
                elif n % 4 == 2:
                    rbfvec.append(i[3])
                elif n % 4 == 3:
                    sigvec.append(i[3])
            for n in range(0, n_feat):
                for c in range(0, n_c):
                    lin_kern[c][n] = linvec[c + (27 * n)]
            for n in range(0, n_feat):
                for c in range(0, n_c):
                    poly_kern[c][n] = polvec[c + (27 * n)]
            for n in range(0, n_feat):
                for c in range(0, n_c):
                    rbf_kern[c][n] = rbfvec[c + (27 * n)]
            for n in range(0, n_feat):
                for c in range(0, n_c):
                    sig_kern[c][n] = sigvec[c + (27 * n)]
            ax = plt.subplot(111)
            chosen = []
            out_kern = None
            if clf_miRNA_h.get_params(False)['kernel'] == 'linear':
                im = plt.imshow(lin_kern, cmap=plt.cm.get_cmap('coolwarm', 50))
                chosen = linvec
                out_kern = lin_kern.copy(True)
            elif clf_miRNA_h.get_params(False)['kernel'] == 'rbf':
                im = plt.imshow(rbf_kern, cmap=plt.cm.get_cmap('coolwarm', 50))
                chosen = rbfvec
                out_kern = rbf_kern.copy(True)
            elif clf_miRNA_h.get_params(False)['kernel'] == 'sigmoid':
                im = plt.imshow(sig_kern, cmap=plt.cm.get_cmap('coolwarm', 50))
                chosen = sigvec
                out_kern = sig_kern.copy(True)
            elif clf_miRNA_h.get_params(False)['kernel'] == 'poly':
                im = plt.imshow(poly_kern,
                                cmap=plt.cm.get_cmap('coolwarm', 50))
                chosen = polvec
                out_kern = poly_kern.copy(True)
            plt.colorbar()
            plt.clim(min(chosen), max(chosen))

            ax.set_xticks(np.arange(n_feat))
            ax.set_yticks(np.arange(n_c))
            ax.set_xticklabels(range(2, up_bound, 2), fontsize=12)
            ax.set_yticklabels([
                .001, .005, 0.1, .5, 1, 1.5, 2, 2.5, 3, 4, 5, 6, 7, 8, 9, 10,
                11, 12, 15, 20, 25, 50, 75, 100, 150, 200, 250
            ],
                               fontsize=14)
            ax.set_xlabel("Number of features", fontsize=20)
            ax.set_ylabel("C value", fontsize=20)
            plt.setp(ax.get_xticklabels(), rotation=90, rotation_mode="anchor")
            for j in range(0, n_c):
                for k in range(0, n_feat):
                    text = ax.text(k,
                                   j,
                                   '%.2f' % out_kern[j, k],
                                   ha="center",
                                   va="center",
                                   color="w",
                                   fontsize=6)
            ax.set_title(
                "miRNA Expression Parameter Sensitivity (Accuracy) - %s kernel"
                % clf_miRNA_h.get_params(False)['kernel'],
                fontsize=24)
            plt.show()
        elif make_plot == 'Heat Map Meth':
            clf_meth_h, fea_meth_h, heat_meth = do_cv(meth_train, train_class,
                                                      meth_test, test_class,
                                                      'minfo', 'meth', 100, 2)
            n_c = 27
            n_feat = 49
            up_bound = 100
            lin_kern = np.empty(
                [n_c,
                 n_feat])  # c vals ( see in cv_2 ), # feats (n_max - 2 / 2)
            poly_kern = np.empty([n_c, n_feat])
            rbf_kern = np.empty([n_c, n_feat])
            sig_kern = np.empty([n_c, n_feat])

            linvec = []
            polvec = []
            rbfvec = []
            sigvec = []
            heat_meth_out = heat_meth[1:len(heat_meth), :]
            for n, i in enumerate(heat_meth_out):
                if n % 4 == 0:
                    linvec.append(i[3])
                elif n % 4 == 1:
                    polvec.append(i[3])
                elif n % 4 == 2:
                    rbfvec.append(i[3])
                elif n % 4 == 3:
                    sigvec.append(i[3])
            for n in range(0, n_feat):
                for c in range(0, n_c):
                    lin_kern[c][n] = linvec[c + (27 * n)]
            for n in range(0, n_feat):
                for c in range(0, n_c):
                    poly_kern[c][n] = polvec[c + (27 * n)]
            for n in range(0, n_feat):
                for c in range(0, n_c):
                    rbf_kern[c][n] = rbfvec[c + (27 * n)]
            for n in range(0, n_feat):
                for c in range(0, n_c):
                    sig_kern[c][n] = sigvec[c + (27 * n)]
            ax = plt.subplot(111)
            chosen = []
            out_kern = None
            if clf_meth_h.get_params(False)['kernel'] == 'linear':
                im = plt.imshow(lin_kern, cmap=plt.cm.get_cmap('coolwarm', 50))
                chosen = linvec
                out_kern = lin_kern.copy(True)
            elif clf_meth_h.get_params(False)['kernel'] == 'rbf':
                im = plt.imshow(rbf_kern, cmap=plt.cm.get_cmap('coolwarm', 50))
                chosen = rbfvec
                out_kern = rbf_kern.copy(True)
            elif clf_meth_h.get_params(False)['kernel'] == 'sigmoid':
                im = plt.imshow(sig_kern, cmap=plt.cm.get_cmap('coolwarm', 50))
                chosen = sigvec
                out_kern = sig_kern.copy(True)
            elif clf_meth_h.get_params(False)['kernel'] == 'poly':
                im = plt.imshow(poly_kern,
                                cmap=plt.cm.get_cmap('coolwarm', 50))
                chosen = polvec
                out_kern = poly_kern.copy(True)
            plt.colorbar()
            plt.clim(min(chosen), max(chosen))

            ax.set_xticks(np.arange(n_feat))
            ax.set_yticks(np.arange(n_c))
            ax.set_xticklabels(range(2, up_bound, 2), fontsize=12)
            ax.set_yticklabels([
                .001, .005, 0.1, .5, 1, 1.5, 2, 2.5, 3, 4, 5, 6, 7, 8, 9, 10,
                11, 12, 15, 20, 25, 50, 75, 100, 150, 200, 250
            ],
                               fontsize=14)
            ax.set_xlabel("Number of features", fontsize=20)
            ax.set_ylabel("C value", fontsize=20)
            plt.setp(ax.get_xticklabels(), rotation=90, rotation_mode="anchor")
            for j in range(0, n_c):
                for k in range(0, n_feat):
                    text = ax.text(k,
                                   j,
                                   '%.2f' % out_kern[j, k],
                                   ha="center",
                                   va="center",
                                   color="w",
                                   fontsize=6)
            ax.set_title(
                "DNA Methylation Parameter Sensitivity (Accuracy) - %s kernel"
                % clf_meth_h.get_params(False)['kernel'],
                fontsize=24)
            plt.show()
        elif make_plot == 'Heat Map CNV':
            clf_CNV_h, fea_CNV_h, heat_CNV = do_cv(CNV_train, train_class,
                                                   CNV_test, test_class,
                                                   'minfo', 'CNV', 100, 2)
            n_c = 27
            n_feat = 49
            up_bound = 100
            lin_kern = np.empty(
                [n_c,
                 n_feat])  # c vals ( see in cv_2 ), # feats (n_max - 2 / 2)
            poly_kern = np.empty([n_c, n_feat])
            rbf_kern = np.empty([n_c, n_feat])
            sig_kern = np.empty([n_c, n_feat])

            linvec = []
            polvec = []
            rbfvec = []
            sigvec = []
            heat_CNV_out = heat_CNV[1:len(heat_CNV), :]
            for n, i in enumerate(heat_CNV_out):
                if n % 4 == 0:
                    linvec.append(i[3])
                elif n % 4 == 1:
                    polvec.append(i[3])
                elif n % 4 == 2:
                    rbfvec.append(i[3])
                elif n % 4 == 3:
                    sigvec.append(i[3])
            for n in range(0, n_feat):
                for c in range(0, n_c):
                    lin_kern[c][n] = linvec[c + (27 * n)]
            for n in range(0, n_feat):
                for c in range(0, n_c):
                    poly_kern[c][n] = polvec[c + (27 * n)]
            for n in range(0, n_feat):
                for c in range(0, n_c):
                    rbf_kern[c][n] = rbfvec[c + (27 * n)]
            for n in range(0, n_feat):
                for c in range(0, n_c):
                    sig_kern[c][n] = sigvec[c + (27 * n)]
            ax = plt.subplot(111)
            chosen = []
            out_kern = None
            if clf_CNV_h.get_params(False)['kernel'] == 'linear':
                im = plt.imshow(lin_kern, cmap=plt.cm.get_cmap('coolwarm', 50))
                chosen = linvec
                out_kern = lin_kern.copy(True)
            elif clf_CNV_h.get_params(False)['kernel'] == 'rbf':
                im = plt.imshow(rbf_kern, cmap=plt.cm.get_cmap('coolwarm', 50))
                chosen = rbfvec
                out_kern = rbf_kern.copy(True)
            elif clf_CNV_h.get_params(False)['kernel'] == 'sigmoid':
                im = plt.imshow(sig_kern, cmap=plt.cm.get_cmap('coolwarm', 50))
                chosen = sigvec
                out_kern = sig_kern.copy(True)
            elif clf_CNV_h.get_params(False)['kernel'] == 'poly':
                im = plt.imshow(poly_kern,
                                cmap=plt.cm.get_cmap('coolwarm', 50))
                chosen = polvec
                out_kern = poly_kern.copy(True)
            plt.colorbar()
            plt.clim(min(chosen), max(chosen))

            ax.set_xticks(np.arange(n_feat))
            ax.set_yticks(np.arange(n_c))
            ax.set_xticklabels(range(2, up_bound, 2), fontsize=12)
            ax.set_yticklabels([
                .001, .005, 0.1, .5, 1, 1.5, 2, 2.5, 3, 4, 5, 6, 7, 8, 9, 10,
                11, 12, 15, 20, 25, 50, 75, 100, 150, 200, 250
            ],
                               fontsize=14)
            ax.set_xlabel("Number of features", fontsize=20)
            ax.set_ylabel("C value", fontsize=20)
            plt.setp(ax.get_xticklabels(), rotation=90, rotation_mode="anchor")
            for j in range(0, n_c):
                for k in range(0, n_feat):
                    text = ax.text(k,
                                   j,
                                   '%.2f' % out_kern[j, k],
                                   ha="center",
                                   va="center",
                                   color="w",
                                   fontsize=6)
            ax.set_title(
                "Copy Number Variation Parameter Sensitivity (Accuracy) - %s kernel"
                % clf_CNV_h.get_params(False)['kernel'],
                fontsize=24)
            plt.show()
    elif make_plot == None:
        pass

    if compute == 'custom':
        return clf_gene, fea_gene, clf_miRNA, fea_miRNA, clf_meth, fea_meth, clf_CNV, fea_CNV, clf

    return tr_score, te_score