def output_sim_data(model, surv, X_train, df_train, X_test, df_test):
    """
    Compute the output of the model on the test set.

    # Arguments
        model: neural network model trained with final parameters.
        surv: survival function estimates for the test set.
        X_train: input variables of the training set.
        df_train: training dataset.
        X_test: input variables of the test set.
        df_test: test dataset.

    # Returns
        res: Uno C-index at the median survival time and the Integrated Brier Score.
    """
    time_grid = np.linspace(np.percentile(df_test['yy'], 10),
                            np.percentile(df_test['yy'], 90), 100)
    median_time = np.percentile(df_test['yy'], 50)
    data_train = skSurv.from_arrays(event=df_train['status'], time=df_train['yy'])
    data_test = skSurv.from_arrays(event=df_test['status'], time=df_test['yy'])
    c_med = concordance_index_ipcw(data_train, data_test,
                                   np.array(-determine_surv_prob(surv, median_time)),
                                   median_time)[0]
    ev = EvalSurv(surv, np.array(df_test['yy']), np.array(df_test['status']),
                  censor_surv='km')
    ibs = ev.integrated_brier_score(time_grid)
    res = pd.DataFrame([c_med, ibs]).T
    res.columns = ['c_median', 'ibs']
    return res
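# `determine_surv_prob` is referenced above but not shown. A minimal sketch of
# what it presumably does, assuming `surv` is a pycox-style DataFrame indexed
# by time with one column per subject (hypothetical helper, not the original):
def determine_surv_prob(surv, t):
    # survival probability of each subject at time t, read off the step
    # function (the last grid time <= t)
    idx = max(np.searchsorted(surv.index.values, t, side='right') - 1, 0)
    return surv.iloc[idx].values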
def evaluate(self, predictions, predictions_interpolated):
    """
    Compute the evaluation metrics (time-dependent concordance, integrated
    Brier score) and log them to MLflow via the Logger module.
    """
    self.logger.info("Evaluating model strength..\n")
    test_targets = dataset_2(self.dataset_filtered, phase='test', targets=True)[0]
    targets, events, survs = test_targets[0], test_targets[1], test_targets[2]
    censored = [event != 1 for event in events]  # True where the subject is censored
    self.logger.info("Calculating concordance and Brier score\n")
    ev = EvalSurv(predictions_interpolated, survs, events, censor_surv='km')
    concordance = ev.concordance_td()
    time_grid = np.linspace(0, survs.max())
    integrated_brier_score = ev.integrated_brier_score(time_grid)
    self.logger.info("Concordance: {}, Integrated Brier Score: {}\n".format(
        concordance, integrated_brier_score))
    logger = Logger(self.mlflow_url, "CNN_Predictions")
    logger.log_predictions(predictions, "{}".format(self.mode), concordance,
                           integrated_brier_score, self.params, targets, self.cuts)
    self.logger.info("Logged evaluation metrics to MLflow\n")
def metric_fn_par(args):
    predicted_test_times, predicted_survival_functions, observed_y = args
    # predicted survival function dim: n_train * n_test
    obs_test_times = observed_y[:, 0].astype('float32')
    obs_test_events = observed_y[:, 1].astype('float32')
    results = [0, 0, 0, 0, 0]
    ev = EvalSurv(predicted_survival_functions, obs_test_times, obs_test_events,
                  censor_surv='km')
    results[0] = ev.concordance_td('antolini')  # concordance_antolini
    results[1] = concordance_index(obs_test_times, predicted_test_times,
                                   obs_test_events.astype(bool))  # concordance_median
    # ignore Brier scores at the highest test times, where the estimate becomes unstable
    time_grid = np.linspace(obs_test_times.min(), obs_test_times.max(), 100)[:80]
    results[2] = ev.integrated_brier_score(time_grid)  # integrated_brier
    if sum(obs_test_events) > 0:
        # only uncensored samples are used for the RMSE/MAE calculation
        observed_mask = obs_test_events.astype(bool)
        pred_obs_differences = (predicted_test_times[observed_mask]
                                - obs_test_times[observed_mask])
        results[3] = np.sqrt(np.mean(pred_obs_differences ** 2))  # rmse
        results[4] = np.mean(np.abs(pred_obs_differences))        # mae
    else:
        print("[WARNING] All samples are censored.")
        results[3] = 0
        results[4] = 0
    return results
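# Hypothetical usage sketch for metric_fn_par: evaluating several replicates in
# parallel with multiprocessing. The `replicates` list of
# (pred_times, surv_fns, observed_y) tuples is an assumption, not from the original:
from multiprocessing import Pool

def evaluate_replicates(replicates, n_workers=4):
    with Pool(n_workers) as pool:
        # each entry: [ctd_antolini, c_median, integrated_brier, rmse, mae]
        return pool.map(metric_fn_par, replicates)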
def Coxnnet_evaluate(model, x1, x2, durations, events):
    # durations = durations.cpu().numpy()
    # events = events.cpu().numpy()
    _ = model.compute_baseline_hazards()
    if x2 is not None:
        surv = model.predict_surv_df((x1, x2))
    else:
        surv = model.predict_surv_df(x1)
    ev = EvalSurv(surv, durations, events, censor_surv='km')
    return ev.concordance_td()
def get_metrics(val_data, surv_pred, time_grid):
    ev = EvalSurv(surv_pred, val_data['t'], val_data['y'], censor_surv='km')
    return pd.DataFrame([{
        'dt_c_index': ev.concordance_td('antolini'),
        'int_brier_score': ev.integrated_brier_score(time_grid),
        'int_nbll': ev.integrated_nbll(time_grid)
    }])
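# Example call (a sketch; assumes `val_data` is dict-like with durations under
# 't' and event indicators under 'y', and `surv_pred` is a pycox-style
# survival DataFrame):
time_grid = np.linspace(val_data['t'].min(), val_data['t'].max(), 100)
print(get_metrics(val_data, surv_pred, time_grid))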
def test_quality(t_true, y_true, pred,
                 time_grid=np.linspace(0, 300, 30, dtype=int),
                 concordance_at_t=None, plot=False):
    # get survival probabilities over time_grid from the Weibull parameters
    # (pred[:, 0] is the scale, pred[:, 1] the shape): S(t) = exp(-(t / scale) ** shape)
    all_surv_time = pd.DataFrame()
    for t in time_grid:
        surv_prob = np.exp(-1 * np.power(t / (pred[:, 0] + 1e-6), pred[:, 1]))
        all_surv_time = pd.concat([all_surv_time, pd.DataFrame(surv_prob).T])
    all_surv_time.index = time_grid
    ev = EvalSurv(surv=all_surv_time, durations=t_true, events=y_true,
                  censor_surv='km')
    dt_c_index = ev.concordance_td('antolini')
    int_brier_score = ev.integrated_brier_score(time_grid)
    int_nbll = ev.integrated_nbll(time_grid)
    if plot:
        fig, ax = plt.subplots(1, 3, figsize=(20, 7))
        d = all_surv_time.sample(5, axis=1).loc[1:]
        for o in d.columns:
            ax[0].plot(d.index, d[o])
        ax[0].set_xlabel('Time')
        ax[0].set_title("Sample survival curves")
        nb = ev.nbll(time_grid)
        ax[1].plot(time_grid, nb)
        ax[1].set_title('NBLL')
        ax[1].set_xlabel('Time')
        br = ev.brier_score(time_grid)
        ax[2].plot(time_grid, br)
        ax[2].set_title('Brier score')
        ax[2].set_xlabel('Time')
        plt.show()
    if concordance_at_t is not None:
        harrell_c_index = concordance_index(
            predicted_scores=all_surv_time.loc[concordance_at_t, :].values,
            event_times=t_true, event_observed=y_true)
        return pd.DataFrame([{
            'harrell_c_index': harrell_c_index,
            'dt_c_index': dt_c_index,
            'int_brier_score': int_brier_score,
            'int_nbll': int_nbll
        }])
    return pd.DataFrame([{
        'dt_c_index': dt_c_index,
        'int_brier_score': int_brier_score,
        'int_nbll': int_nbll
    }])
def Bootstrap(self, surv, event: list, duration: list):
    np.random.seed(42)  # control reproducibility
    cindex, brier, nbll = [], [], []
    n = surv.shape[1]
    for _ in range(self.bootstrap_n):
        # sample subjects with replacement; use np.random so the seed above
        # actually applies (random.choices is not affected by np.random.seed)
        sampled_index = np.random.choice(n, size=n, replace=True)
        sampled_surv = surv.iloc[:, sampled_index]
        sampled_event = [event[i] for i in sampled_index]
        sampled_duration = [duration[i] for i in sampled_index]
        ev = EvalSurv(sampled_surv, np.array(sampled_duration),
                      np.array(sampled_event).astype(int), censor_surv='km')
        time_grid = np.linspace(min(sampled_duration), max(sampled_duration), 100)
        cindex.append(ev.concordance_td('antolini'))
        brier.append(ev.integrated_brier_score(time_grid))
        nbll.append(ev.integrated_nbll(time_grid))
    return cindex, brier, nbll
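# A sketch of turning the bootstrap samples into point estimates with 95%
# percentile intervals (`evaluator` stands in for an instance of the class above):
cindex, brier, nbll = evaluator.Bootstrap(surv, event, duration)
for name, vals in [('Ctd', cindex), ('IBS', brier), ('NBLL', nbll)]:
    lo, hi = np.percentile(vals, [2.5, 97.5])
    print('%s: %.3f (95%% CI %.3f-%.3f)' % (name, np.mean(vals), lo, hi))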
def run_baseline(runs=10):
    concordance = []
    ibs = []
    for i in tqdm(range(runs)):
        df_train, df_test, df_val = load_data("./summaries/survival_data")
        x_mapper, labtrans, train, val, x_test, durations_test, events_test, pca = \
            transform_data(df_train, df_test, df_val, 'LogisticHazard', "standard",
                           cols_standardize, log_columns, num_durations=100)
        x_train, y_train = train
        pc_col = ['PC' + str(i) for i in range(x_train.shape[1])]
        cox_train = pd.DataFrame(x_train, columns=pc_col)
        cox_test = pd.DataFrame(x_test, columns=pc_col)
        cox_train.loc[:, 'duration'] = y_train[0]
        cox_train.loc[:, 'event'] = y_train[1]
        cox_train = cox_train.dropna()
        cox_test = cox_test.dropna()
        cph = CoxPHFitter().fit(cox_train, 'duration', 'event')
        # cph.print_summary()
        surv = cph.predict_survival_function(cox_test)
        ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')
        concordance.append(ev.concordance_td('antolini'))
        time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)
        ibs.append(ev.integrated_brier_score(time_grid))
    print("Average concordance: %s" % np.mean(concordance))
    print("Average IBS: %s" % np.mean(ibs))
    plot_survival(cox_train, pc_col, cph, './survival/cox', baseline=True)
def main():
    parser = setup_parser()
    args = parser.parse_args()
    if args.which_gpu != 'none':
        os.environ["CUDA_VISIBLE_DEVICES"] = args.which_gpu
    # save setting
    if not os.path.exists(os.path.join(args.save_path, args.model_name)):
        os.mkdir(os.path.join(args.save_path, args.model_name))
    # label transform
    labtrans = DeepHitSingle.label_transform(args.durations)
    # data reading settings
    signal_data_path = args.signal_dataset_path
    table_path = args.table_path
    time_col = 'SurvivalDays'
    event_col = 'Mortality'
    # dataset
    data_pathes, times, events = read_dataset(signal_data_path, table_path,
                                              time_col, event_col, args.sample_ratio)
    data_pathes_train, data_pathes_test, times_train, times_test, events_train, events_test = \
        train_test_split(data_pathes, times, events, test_size=0.3, random_state=369)
    data_pathes_train, data_pathes_val, times_train, times_val, events_train, events_val = \
        train_test_split(data_pathes_train, times_train, events_train,
                         test_size=0.2, random_state=369)
    labels_train = label_transfer(times_train, events_train)
    target_train = labtrans.fit_transform(*labels_train)
    dataset_train = VsDatasetBatch(data_pathes_train, *target_train)
    dl_train = tt.data.DataLoaderBatch(dataset_train, args.train_batch_size,
                                       shuffle=True)
    labels_val = label_transfer(times_val, events_val)
    target_val = labtrans.transform(*labels_val)
    dataset_val = VsDatasetBatch(data_pathes_val, *target_val)
    dl_val = tt.data.DataLoaderBatch(dataset_val, args.train_batch_size,
                                     shuffle=True)
    labels_test = label_transfer(times_test, events_test)
    dataset_test_x = VsTestInput(data_pathes_test)
    dl_test_x = DataLoader(dataset_test_x, args.test_batch_size, shuffle=False)
    net = resnet18(args)
    model = DeepHitSingle(net,
                          tt.optim.Adam(lr=args.lr, betas=(0.9, 0.999), eps=1e-08,
                                        weight_decay=5e-4, amsgrad=False),
                          duration_index=labtrans.cuts)
    # callbacks = [tt.cb.EarlyStopping(patience=15)]
    callbacks = [
        tt.cb.BestWeights(file_path=os.path.join(args.save_path, args.model_name,
                                                 args.model_name + '_bestWeight'),
                          rm_file=False)
    ]
    verbose = True
    model_log = model.fit_dataloader(dl_train, args.epochs, callbacks, verbose,
                                     val_dataloader=dl_val)
    save_args(os.path.join(args.save_path, args.model_name), args)
    model_log.to_pandas().to_csv(os.path.join(args.save_path, args.model_name,
                                              'loss.csv'), index=False)
    model.save_net(path=os.path.join(args.save_path, args.model_name,
                                     args.model_name + '_final'))
    surv = model.predict_surv_df(dl_test_x)
    surv.to_csv(os.path.join(args.save_path, args.model_name, 'test_sur_df.csv'),
                index=False)
    ev = EvalSurv(surv, *labels_test, censor_surv='km')
    print(ev.concordance_td())
    save_cindex(os.path.join(args.save_path, args.model_name), ev.concordance_td())
    print('done')
def pycox_deep(filename, Y_train, Y_test, opt, choice):
    # choice = {'lr_rate': l, 'batch': b, 'decay': 0, 'weighted_decay': wd,
    #           'nodes': nodes, 'net': net, 'index': index}
    X_train, X_test = enc_using_trained_ae(filename, TARGET=opt, ALPHA=0.01,
                                           N_ITER=100, L1R=-9999)
    path = './models/analysis/'
    check = 0
    savename = 'model_check_autoen_m5_test_batch+dropout+wd.csv'
    # r=root, d=directories, f=files
    for r, d, f in os.walk(path):
        for file in f:
            if savename in file:
                check = 1

    x_train = X_train
    x_test = X_test
    x_train['SVDTEPC_G'] = Y_train['SVDTEPC_G']
    x_train['PC_YN'] = Y_train['PC_YN']
    x_test['SVDTEPC_G'] = Y_test['SVDTEPC_G']
    x_test['PC_YN'] = Y_test['PC_YN']

    ## DataFrameMapper ##
    cols_standardize = list(X_train.columns)
    cols_standardize.remove('SVDTEPC_G')
    cols_standardize.remove('PC_YN')
    standardize = [(col, None) for col in cols_standardize]
    x_mapper = DataFrameMapper(standardize)
    _ = x_mapper.fit_transform(X_train).astype('float32')
    X_train = x_mapper.transform(X_train).astype('float32')
    X_test = x_mapper.transform(X_test).astype('float32')

    get_target = lambda df_: (df_['SVDTEPC_G'].values, df_['PC_YN'].values)
    y_train = get_target(x_train)
    durations_test, events_test = get_target(x_test)

    in_features = X_train.shape[1]
    print(in_features)
    num_nodes = choice['nodes']
    out_features = 1
    batch_norm = True  # False to disable batch normalization
    dropout = 0.01
    output_bias = False
    net = tt.practical.MLPVanilla(in_features, num_nodes, out_features,
                                  batch_norm, dropout, output_bias=output_bias)

    print("training")
    model = CoxPH(net, tt.optim.Adam)
    # lrfinder = model.lr_finder(X_train, y_train, batch_size)
    # lr_best = lrfinder.get_best_lr()
    model.optimizer.set_lr(choice['lr_rate'])
    weighted_decay = choice['weighted_decay']
    verbose = True
    batch_size = choice['batch']
    epochs = 100
    if weighted_decay == 0:
        callbacks = [tt.callbacks.EarlyStopping(patience=epochs)]
    else:
        callbacks = [tt.callbacks.DecoupledWeightDecay(weight_decay=choice['decay'])]

    datas = tt.tuplefy(X_train, y_train).to_tensor()
    print(datas)
    make_dataset = tt.data.DatasetTuple
    DataLoader = tt.data.DataLoaderBatch
    dataset = make_dataset(*datas)
    dataloader = DataLoader(dataset, batch_size, False,
                            sampler=StratifiedSampler(datas, batch_size))
    model.fit_dataloader(dataloader, epochs, callbacks, verbose)

    print("predicting")
    baseline_hazards = model.compute_baseline_hazards(datas[0], datas[1])
    baseline_hazards = df(baseline_hazards)
    surv = model.predict_surv_df(X_test)
    surv = 1 - surv
    ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')

    print("scoring")
    c_index = ev.concordance_td()
    print("c-index(", opt, "): ", c_index)
    index = choice['index']  # run index used in the output file names
    if int(c_index * 10) == 0:
        hazardname = 'pycox_model_hazard_m5_v2_' + opt + '_0'
        netname = 'pycox_model_net_m5_v2_' + opt + '_0'
        weightname = 'pycox_model_weight_m5_v2_' + opt + '_0'
    else:
        hazardname = 'pycox_model_hazard_m5_' + opt + '_'
        netname = 'pycox_model_net_m5_' + opt + '_'
        weightname = 'pycox_model_weight_m5_' + opt + '_'
    baseline_hazards.to_csv('./test/' + hazardname + str(int(c_index * 100))
                            + '_' + str(index) + '.csv', index=False)
    netname = netname + str(int(c_index * 100)) + '_' + str(index) + '.sav'
    weightname = weightname + str(int(c_index * 100)) + '_' + str(index) + '.sav'
    model.save_net('./test/' + netname)
    model.save_model_weights('./test/' + weightname)

    pred = df(surv)
    pred = pred.transpose()
    surv_final = []
    pred_final = []
    for i in range(len(pred)):
        pred_final.append(float(1 - pred[Y_test['SVDTEPC_G'][i]][i]))
        surv_final.append(float(pred[Y_test['SVDTEPC_G'][i]][i]))
    Y_test_cox = CoxformY(Y_test)
    c_cox, concordant, discordant, _, _ = concordance_index_censored(
        Y_test_cox['PC_YN'], Y_test_cox['SVDTEPC_G'], surv_final)
    c_cox_pred = concordance_index_censored(Y_test_cox['PC_YN'],
                                            Y_test_cox['SVDTEPC_G'],
                                            pred_final)[0]
    print("c-index(", opt, ") - sksurv: ", round(c_cox, 4))
    print("cox-concordant(", opt, ") - sksurv: ", concordant)
    print("cox-discordant(", opt, ") - sksurv: ", discordant)
    print("c-index_pred(", opt, ") - sksurv: ", round(c_cox_pred, 4))
    fpr, tpr, _ = metrics.roc_curve(Y_test['PC_YN'], pred_final)
    auc = metrics.auc(fpr, tpr)
    print("auc(", opt, "): ", round(auc, 4))

    if check == 1:
        model_check = pd.read_csv(path + savename)
    else:
        model_check = df(columns=['option', 'gender', 'c-td', 'c-index', 'auc'])
    line_append = {'option': str(choice), 'gender': opt,
                   'c-td': round(c_index, 4), 'c-index': round(c_cox_pred, 4),
                   'auc': round(auc, 4)}
    # DataFrame.append was removed in pandas 2.0; use pd.concat instead
    model_check = pd.concat([model_check, pd.DataFrame([line_append])],
                            ignore_index=True)
    model_check.to_csv(path + savename, index=False)
    del X_train
    del X_test
    return surv_final
def train_deepsurv(data_df, r_splits):
    epochs = 100
    verbose = True
    num_nodes = [32]
    out_features = 1
    batch_norm = True
    dropout = 0.6
    output_bias = False

    c_index_at = []
    c_index_30 = []
    time_auc = {30: [], 60: [], 365: []}

    for i in range(len(r_splits)):
        print("\nIteration %s" % (i))
        # DATA PREP
        df_train, df_val, df_test, df_test_30 = prepare_datasets(
            data_df, r_splits[i][2], r_splits[i][1], r_splits[i][0])
        cols_standardize = [c for c in df_train.columns
                            if c not in ("subject_id", "event", "duration")]
        standardize = [([col], StandardScaler()) for col in cols_standardize]
        x_mapper = DataFrameMapper(standardize)
        x_train = x_mapper.fit_transform(df_train).astype('float32')
        x_val = x_mapper.transform(df_val).astype('float32')
        x_test = x_mapper.transform(df_test).astype('float32')
        x_test_30 = x_mapper.transform(df_test_30).astype('float32')
        labtrans = CoxTime.label_transform()
        get_target = lambda df: (df['duration'].values, df['event'].values)
        y_train = labtrans.fit_transform(*get_target(df_train))
        y_val = labtrans.transform(*get_target(df_val))
        durations_test, events_test = get_target(df_test)
        durations_test_30, events_test_30 = get_target(df_test_30)
        val = tt.tuplefy(x_val, y_val)
        (train_x, train_y), (val_x, val_y), (test_x, test_y), _ = df2array(
            data_df, df_train, df_val, df_test, df_test_30)

        # MODEL
        in_features = x_train.shape[1]
        callbacks = [tt.callbacks.EarlyStopping()]
        net = tt.practical.MLPVanilla(in_features, num_nodes, out_features,
                                      batch_norm, dropout,
                                      output_bias=output_bias)
        model = CoxPH(net, tt.optim.Adam)
        model.optimizer.set_lr(0.0001)
        # odd sample counts get an odd batch size, presumably to avoid a
        # final batch of size 1 (problematic with batch norm)
        batch_size = 255 if x_train.shape[0] % 2 else 256
        log = model.fit(x_train, y_train, batch_size, epochs, callbacks,
                        val_data=val, val_batch_size=batch_size)
        model.compute_baseline_hazards()

        surv = model.predict_surv_df(x_test)
        ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')
        c_index_at.append(ev.concordance_td())
        surv_30 = model.predict_surv_df(x_test_30)
        ev_30 = EvalSurv(surv_30, durations_test_30, events_test_30,
                         censor_surv='km')
        c_index_30.append(ev_30.concordance_td())

        for time_x in [30, 60, 365]:
            va_auc, va_mean_auc = cumulative_dynamic_auc(
                train_y, test_y, model.predict(x_test).flatten(), time_x)
            time_auc[time_x].append(va_auc[0])

        print("C-index_30:", c_index_30[i])
        print("C-index_AT:", c_index_at[i])
        print("time_auc_30", time_auc[30][i])
        print("time_auc_60", time_auc[60][i])
        print("time_auc_365", time_auc[365][i])

    return c_index_at, c_index_30, time_auc[30], time_auc[60], time_auc[365]
def train_LSTMCox(data_df, r_splits):
    epochs = 100
    verbose = True
    in_features = 768
    out_features = 1
    batch_norm = True
    dropout = 0.6
    output_bias = False

    c_index_at = []
    c_index_30 = []
    time_auc = {30: [], 60: [], 365: []}

    for i in range(len(r_splits)):
        print("\nIteration %s" % (i))
        # DATA PREP
        df_train, df_val, df_test, df_test_30 = prepare_datasets(
            data_df, r_splits[i][2], r_splits[i][1], r_splits[i][0])
        x_train = np.array(df_train["x0"].tolist()).astype("float32")
        x_val = np.array(df_val["x0"].tolist()).astype("float32")
        x_test = np.array(df_test["x0"].tolist()).astype("float32")
        x_test_30 = np.array(df_test_30["x0"].tolist()).astype("float32")
        labtrans = CoxTime.label_transform()
        get_target = lambda df: (df['duration'].values, df['event'].values)
        y_train = labtrans.fit_transform(*get_target(df_train))
        y_val = labtrans.transform(*get_target(df_val))
        durations_test, events_test = get_target(df_test)
        durations_test_30, events_test_30 = get_target(df_test_30)
        val = tt.tuplefy(x_val, y_val)
        (train_x, train_y), (val_x, val_y), (test_x, test_y), _ = df2array(
            data_df, df_train, df_val, df_test, df_test_30)

        # MODEL
        callbacks = [tt.callbacks.EarlyStopping()]
        net = LSTMCox(768, 32, 1, 1)
        model = CoxPH(net, tt.optim.Adam)
        model.optimizer.set_lr(0.0001)
        # odd sample counts get an odd batch size, presumably to avoid a
        # final batch of size 1 (problematic with batch norm)
        batch_size = 255 if x_train.shape[0] % 2 else 256
        log = model.fit(x_train, y_train, batch_size, epochs, callbacks,
                        val_data=val, val_batch_size=batch_size)
        model.compute_baseline_hazards()

        surv = model.predict_surv_df(x_test)
        ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')
        c_index_at.append(ev.concordance_td())
        surv_30 = model.predict_surv_df(x_test_30)
        ev_30 = EvalSurv(surv_30, durations_test_30, events_test_30,
                         censor_surv='km')
        c_index_30.append(ev_30.concordance_td())

        for time_x in [30, 60, 365]:
            va_auc, va_mean_auc = cumulative_dynamic_auc(
                train_y, test_y, model.predict(x_test).flatten(), time_x)
            time_auc[time_x].append(va_auc[0])

        print("C-index_30:", c_index_30[i])
        print("C-index_AT:", c_index_at[i])
        print("time_auc_30", time_auc[30][i])
        print("time_auc_60", time_auc[60][i])
        print("time_auc_365", time_auc[365][i])

    return c_index_at, c_index_30, time_auc[30], time_auc[60], time_auc[365]
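# A small sketch (not from the original) to summarise the per-fold lists that
# train_deepsurv/train_LSTMCox return as mean +/- standard deviation
# (`data_df` and `r_splits` are assumed to be in scope):
def summarize_cv(name, values):
    values = np.asarray(values, dtype=float)
    print("%s: %.3f +/- %.3f" % (name, values.mean(), values.std()))

c_index_at, c_index_30, auc_30, auc_60, auc_365 = train_LSTMCox(data_df, r_splits)
for name, vals in zip(["C-index_AT", "C-index_30", "AUC@30", "AUC@60", "AUC@365"],
                      [c_index_at, c_index_30, auc_30, auc_60, auc_365]):
    summarize_cv(name, vals)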
y_pred_train_surv = np.cumprod((1 - ypred_surv_train_NN), axis=1)
oneyr_surv_train = y_pred_train_surv[:, 50]
oneyr_surv_valid = y_pred_valid_surv[:, 50]

surv_valid = pd.DataFrame(np.transpose(y_pred_valid_surv))
surv_valid.index = interval_l
surv_train = pd.DataFrame(np.transpose(y_pred_train_surv))
surv_train.index = interval_l

dict_cv_cindex_train[key] = concordance_index(dataTrain.time, oneyr_surv_train)
dict_cv_cindex_valid[key] = concordance_index(dataValid.time, oneyr_surv_valid)

ev_valid = EvalSurv(surv_valid, time_valid, dead_valid, censor_surv='km')
scores_test += ev_valid.concordance_td()
ev_train = EvalSurv(surv_train, dataTrain['time'].values,
                    dataTrain['dead'].values, censor_surv='km')
scores_train += ev_train.concordance_td()

cta.append(concordance_index(dataTrain.time, oneyr_surv_train))
cte.append(concordance_index(dataValid.time, oneyr_surv_valid))
ctda.append(ev_train.concordance_td())
ctde.append(ev_valid.concordance_td())

save_loss_png = "output/loss_cv_" + str(h) + "_KIRC.png"
import seaborn as sns
epochs = args.epochs
callbacks = [tt.callbacks.EarlyStopping(patience=patience)]
verbose = True
log = model.fit(x_train, y_train_transformed, batch_size, epochs, callbacks,
                verbose, val_data=val_transformed, val_batch_size=batch_size)

# Evaluation ===================================================================
surv = get_surv(model, x_test)
ev = EvalSurv(surv, durations_test_transformed, events_test, censor_surv='km')
# ctd = ev.concordance_td()
ctd = concordance_index(event_times=durations_test_transformed,
                        predicted_scores=model.predict(x_test).reshape(-1),
                        event_observed=events_test)
time_grid = np.linspace(durations_test_transformed.min(),
                        durations_test_transformed.max(), 100)
ibs = ev.integrated_brier_score(time_grid)
nbll = ev.integrated_nbll(time_grid)
val_loss = min(log.monitors['val_'].scores['loss']['score'])
wandb.log({'val_loss': val_loss, 'ctd': ctd, 'ibs': ibs, 'nbll': nbll})
wandb.finish()
def _survLDA_train_variational_EM(self, word_count_matrix,
                                  survival_or_censoring_times,
                                  censoring_indicator_variables,
                                  val_word_count_matrix,
                                  val_survival_or_censoring_times,
                                  val_censoring_indicator_variables):
    """
    Input:
    - K: number of topics
    - alpha0: hyperparameter for the Dirichlet prior

    Output:
    - tau: 2D numpy array; the g-th row is the word distribution for topic g
    - beta: 1D numpy array with length equal to the number of topics
      (Cox regression coefficients)
    - h0_reformatted: 2D numpy array; the first row is the sorted unique times
      and the second row is the discretized version of log(h0)
    - gamma: 2D numpy array; the i-th row is the variational Dirichlet
      parameter for the i-th subject (length is the number of topics)
    - phi: list with one entry per subject; the i-th element is a 2D numpy
      array with one row per word in the i-th subject's document and one
      column per topic (variational parameter for how much each word of each
      subject belongs to the different topics)
    - rmse: estimated median survival time RMSE computed on the validation data
    - mae: estimated median survival time MAE computed on the validation data
    - cindex: concordance index computed on the validation data
    - stop_iter: iteration number after stopping
    """
    # *********************************************************************
    # 1. SET UP TOPIC MODEL PARAMETERS BASED ON TRAINING DATA
    self.word_count_matrix = word_count_matrix
    # train size and vocabulary size
    num_subjects, num_words = self.word_count_matrix.shape
    # the length of each patient's document (i.e. sum of all words, including duplicates)
    doc_length = self.word_count_matrix.sum(axis=1).astype(np.int64)

    # tau tells us what the word distribution is per topic:
    # - each row corresponds to a topic
    # - each row is a probability distribution, initialized either at random
    #   or uniformly across the words (each entry 1 / `num_words`)
    if self.random_init_tau:
        self.tau = np.random.rand(self.n_topics, num_words)
        self.tau /= self.tau.sum(axis=1)[:, np.newaxis]  # normalize tau
    else:
        self.tau = np.full((self.n_topics, num_words), 1. / num_words,
                           dtype=np.float64)

    # the variational parameter gamma is the Dirichlet parameter for each
    # subject's topic distribution; initialize to all ones, corresponding to
    # a uniform distribution prior
    self.gamma = np.ones((num_subjects, self.n_topics), dtype=np.float64)

    # the variational parameter phi gives the probabilities of each subject's
    # words coming from each of the K topics; initialized uniformly (1/K)
    self.W, self.phi = self._word_count_matrix_to_word_vectors_phi(
        word_count_matrix)
    val_W, _ = self._word_count_matrix_to_word_vectors_phi(
        val_word_count_matrix)

    # *********************************************************************
    # 2. SET UP COX BASELINE HAZARD FUNCTION
    # the Cox baseline hazard function h0 can be represented as a finite 1D vector
    death_counter = {}
    for t in survival_or_censoring_times:
        if t not in death_counter:
            death_counter[t] = 1
        else:
            death_counter[t] += 1
    sorted_unique_times = np.sort(list(death_counter.keys()))
    self.sorted_unique_times = sorted_unique_times
    num_unique_times = len(sorted_unique_times)
    self.log_h0_discretized = np.zeros(num_unique_times)
    for r, t in enumerate(sorted_unique_times):
        self.log_h0_discretized[r] = np.log(death_counter[t])
    log_H0 = []
    for r in range(num_unique_times):
        log_H0.append(logsumexp(self.log_h0_discretized[:(r + 1)]))
    log_H0 = np.array(log_H0)
    time_map = {t: r for r, t in enumerate(sorted_unique_times)}
    time_order = np.array([time_map[t] for t in survival_or_censoring_times],
                          dtype=np.int64)

    # *********************************************************************
    # 3. EM MAIN LOOP
    pool = multiprocessing.Pool(4)
    stop_iter = 0
    val_censoring_indicator_variables = \
        val_censoring_indicator_variables.astype(bool)
    for EM_iter_idx in range(self.max_iter):
        # ------------------------------------------------------------------
        # E-step (update gamma, phi; uses helper variables psi, xi)
        if self.verbose:
            print('[Variational EM iteration %d]' % (EM_iter_idx + 1))
            print('  Running E-step...', end='', flush=True)
        tic = time()

        # update gamma (dimensions: `num_subjects` by `K`)
        for i, phi_i in enumerate(self.phi):
            self.gamma[i] = self.alpha + phi_i.sum(axis=0)

        # compute psi (dimensions: `num_subjects` by `K`)
        psi = digamma(self.gamma) - digamma(
            self.gamma.sum(axis=1))[:, np.newaxis]

        # update phi (this normalizes already)
        self.phi = pool.map(
            SurvLDA._compute_updated_phi_i_pmap_helper,
            [(i, self.phi[i], psi[i], self.W[i], self.tau, self.beta,
              np.exp(log_H0[time_order[i]]),
              censoring_indicator_variables[i], doc_length[i], self.n_topics)
             for i in range(num_subjects)])

        toc = time()
        if self.verbose:
            print(' Done. Time elapsed: %f second(s).' % (toc - tic),
                  flush=True)

        # --------------------------------------------------------------------
        # M-step (update tau, beta, h0, H0; uses helper variable phi_bar)
        if self.verbose:
            print('  Running M-step...', end='', flush=True)
        tic = time()

        # update tau
        tau = np.zeros((self.n_topics, num_words), dtype=np.float64)
        for i in range(num_subjects):
            for j in range(doc_length[i]):
                word = self.W[i][j]
                for k in range(self.n_topics):
                    tau[k][word] += self.phi[i][j, k]
        # normalize the new tau and store it (the original divided the old
        # self.tau by the new normalizer, which was a bug)
        self.tau = tau / tau.sum(axis=1)[:, np.newaxis]

        phi_bar = np.zeros((num_subjects, self.n_topics), dtype=np.float64)
        for i in range(num_subjects):
            phi_bar[i] = self.phi[i].sum(axis=0) / doc_length[i]
        y_r = np.vstack((survival_or_censoring_times,
                         censoring_indicator_variables)).T
        beta_phi_bar = np.dot(phi_bar, self.beta)
        log_h0_discretized_ = np.zeros(num_unique_times)
        log_H0_ = []
        for r, t in enumerate(sorted_unique_times):
            R_t = np.where(survival_or_censoring_times >= t, True, False)
            log_h0_discretized_[r] = np.log(death_counter[t]) - logsumexp(
                beta_phi_bar[R_t])
            log_H0_.append(logsumexp(log_h0_discretized_[:(r + 1)]))

        # update beta
        def obj_fun(beta_):
            beta_phi_bar_ = np.dot(phi_bar, beta_)
            fun_val = np.dot(censoring_indicator_variables, beta_phi_bar_)
            # finally, we add in the third term
            for i in range(num_subjects):
                product_terms = np.dot(self.phi[i],
                                       np.exp(beta_ / doc_length[i]))
                fun_val -= np.exp(log_H0_[time_order[i]]
                                  + np.sum(np.log(product_terms)))
            return -fun_val

        self.beta = minimize(obj_fun, self.beta).x

        # compute the dot product <beta, phi_bar[i]> for each patient i
        # (resulting in a 1D array of length `num_subjects`)
        beta_phi_bar = np.dot(phi_bar, self.beta)

        # update h0
        for r, t in enumerate(sorted_unique_times):
            R_t = np.where(survival_or_censoring_times >= t, True, False)
            self.log_h0_discretized[r] = np.log(
                death_counter[t]) - logsumexp(beta_phi_bar[R_t])
        log_H0 = []
        for r in range(num_unique_times):
            log_H0.append(logsumexp(self.log_h0_discretized[:(r + 1)]))
        log_H0 = np.array(log_H0)

        toc = time()
        if self.verbose:
            print(' Done. Time elapsed: %f second(s).' % (toc - tic),
                  flush=True)
            print('  Beta:', self.beta)

        # validation metrics, used as the convergence criterion to decide
        # whether to break out of the loop early
        surv_estimates = []
        median_estimates = []
        for i in range(len(val_W)):
            preds = self._predict_survival(val_W[i])
            surv_estimates.append(preds[1])
            median_estimates.append(preds[0])
        surv_estimates = np.array(surv_estimates)
        median_estimates = np.array(median_estimates)
        val_ev = EvalSurv(pd.DataFrame(np.transpose(surv_estimates),
                                       index=sorted_unique_times),
                          val_survival_or_censoring_times,
                          val_censoring_indicator_variables,
                          censor_surv='km')
        val_cindex = val_ev.concordance_td('antolini')
        val_rmse = np.sqrt(np.mean(
            (median_estimates[val_censoring_indicator_variables]
             - val_survival_or_censoring_times[val_censoring_indicator_variables]) ** 2))
        val_mae = np.mean(np.abs(
            median_estimates[val_censoring_indicator_variables]
            - val_survival_or_censoring_times[val_censoring_indicator_variables]))
        if self.verbose:
            print('  Validation set survival time c-index: %f' % val_cindex)
            print('  Validation set survival time rmse: %f' % val_rmse)
            print('  Validation set survival time mae: %f' % val_mae)
        stop_iter += 1
    pool.close()
    if self.verbose:
        print('  EM algorithm completed training at iteration', stop_iter)
    return
        out_survival = utils.predict_survival_lognormal(
            model, X.to(device), times)
    elif arg.distribution == "exponient":  # (sic) spelling matches the CLI argument
        out_survival = utils.predict_survival_exponient(
            model, X.to(device), times)
    elif arg.distribution == "weibull":
        out_survival = utils.predict_survival_weibull(
            model, X.to(device), times)
    elif arg.distribution == "combine":
        out_survival = utils.predict_survival_multiple_distributions(
            model, X.to(device), times)

    surv = pd.DataFrame(out_survival, index=times)
    durations_test, events_test = T.detach().numpy(), E.detach().numpy()
    ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')
    c_index_train = ev.concordance_td()

##### valid
model.eval()
batch_size = 8000
with torch.no_grad():
    for X, T, E in dataloader.load_data(arg.dataname, arg.path_clinical_val,
                                        'test'):
        T = T / ratio
        # pdf = misc.cal_pdf(T)
        if arg.distribution == "lognormal":
verbose = True

# %%time  # magic command, Jupyter notebooks only
log = model.fit(x_train, y_train, batch_size, epochs, callbacks, verbose,
                val_data=val)
_ = log.plot()

# Evaluation
surv = model.predict_surv_df(x_test)
ev = EvalSurv(surv, durations_test, events_test != 0, censor_surv='km')
ev.concordance_td()
ev.integrated_brier_score(np.linspace(0, durations_test.max(), 100))

# cause-specific evaluation: one EvalSurv per competing risk, built from the CIFs
cif = model.predict_cif(x_test)
cif1 = pd.DataFrame(cif[0], model.duration_index)
cif2 = pd.DataFrame(cif[1], model.duration_index)
ev1 = EvalSurv(1 - cif1, durations_test, events_test == 1, censor_surv='km')
ev2 = EvalSurv(1 - cif2, durations_test, events_test == 2, censor_surv='km')
ev1.concordance_td()
ev2.concordance_td()
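# The same pattern extends to cause-specific Brier scores (a sketch; it reuses
# the two EvalSurv objects above with a shared time grid):
time_grid = np.linspace(0, durations_test.max(), 100)
ibs_cause1 = ev1.integrated_brier_score(time_grid)
ibs_cause2 = ev2.integrated_brier_score(time_grid)
print(ibs_cause1, ibs_cause2)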
               event_col='status',
               show_progress=False,
               step_size=.1)
elapsed = time.time() - tic
print('Time elapsed: %f second(s)' % elapsed)
np.savetxt(time_elapsed_filename, np.array(elapsed).reshape(1, -1))

# ---------------------------------------------------------------------
# evaluation
sorted_y_test = np.unique(y_test[:, 0])
surv_df = surv_model.predict_survival_function(X_test_std, sorted_y_test)
surv = surv_df.values.T
ev = EvalSurv(surv_df, y_test[:, 0], y_test[:, 1], censor_surv='km')

cindex_td = ev.concordance_td('antolini')
print('c-index (td):', cindex_td)

linear_predictors = surv_model.predict_log_partial_hazard(X_test_std)
cindex = concordance_index(y_test[:, 0], -linear_predictors, y_test[:, 1])
print('c-index:', cindex)

time_grid = np.linspace(sorted_y_test[0], sorted_y_test[-1], 100)
integrated_brier = ev.integrated_brier_score(time_grid)
print('Integrated Brier score:', integrated_brier, flush=True)
test_set_metrics = [cindex_td, integrated_brier]
ypred_train_NN = model.predict_proba(x_train_NN)
ypred_test_NN = model.predict_proba(x_valid_NN)
y_pred_valid_surv = np.cumprod((1 - ypred_test_NN), axis=1)
y_pred_train_surv = np.cumprod((1 - ypred_train_NN), axis=1)
oneyr_surv_train = y_pred_train_surv[:, 50]
oneyr_surv_valid = y_pred_valid_surv[:, 50]

surv_valid = pd.DataFrame(np.transpose(y_pred_valid_surv))
surv_valid.index = interval_l
surv_train = pd.DataFrame(np.transpose(y_pred_train_surv))
surv_train.index = interval_l

dict_cv_cindex_train[key] = concordance_index(dataTrain.time, oneyr_surv_train)
dict_cv_cindex_valid[key] = concordance_index(dataValid.time, oneyr_surv_valid)

ev_valid = EvalSurv(surv_valid, dataValid['time'].values,
                    dataValid['dead'].values, censor_surv='km')
scores_test += ev_valid.concordance_td()
ev_train = EvalSurv(surv_train, dataTrain['time'].values,
                    dataTrain['dead'].values, censor_surv='km')
scores_train += ev_train.concordance_td()

cta.append(concordance_index(dataTrain.time, oneyr_surv_train))
cte.append(concordance_index(dataValid.time, oneyr_surv_valid))
ctda.append(ev_train.concordance_td())
ctde.append(ev_valid.concordance_td())

save_loss_png = "loss_cv_" + str(h) + "_KIRC_vK.png"
# Training ======================================================================
epochs = args.epochs
callbacks = [tt.callbacks.EarlyStopping()]
verbose = True
log = model.fit(x_train, y_train, batch_size, epochs, callbacks, verbose,
                val_data=val, val_batch_size=batch_size)
# log = model.fit(x_train, y_train_transformed, batch_size, epochs, callbacks,
#                 verbose, val_data=val_transformed, val_batch_size=batch_size)

# Evaluation ===================================================================
_ = model.compute_baseline_hazards()
surv = model.predict_surv_df(x_test)
ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')
ctd = ev.concordance_td()
time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)
ibs = ev.integrated_brier_score(time_grid)
nbll = ev.integrated_nbll(time_grid)
val_loss = min(log.monitors['val_'].scores['loss']['score'])
wandb.log({'val_loss': val_loss, 'ctd': ctd, 'ibs': ibs, 'nbll': nbll})
wandb.finish()
x_train, y_train = train
for dim in hiddens:
    for lr in lrs:
        outpath = "./survival/%s_%s_%s_%s_%s" % (mod, scale, dim, lr, seed)
        if not os.path.exists(outpath):
            os.mkdir(outpath)
        in_features = x_train.shape[1]
        model = initialize_model(dim, labtrans, in_features)
        model.optimizer.set_lr(0.001)
        callbacks = [tt.callbacks.EarlyStopping()]
        log = model.fit(x_train, y_train, batch_size, epochs, callbacks,
                        val_data=val)
        surv = model.predict_surv_df(x_test)
        ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')
        result = pd.DataFrame([[0] * 8],
                              columns=["random", "model", "hiddens", "lr",
                                       "scaler", "c-index", "brier", "nll"])
        result["c-index"] = ev.concordance_td('antolini')
        print(ev.concordance_td('antolini'))
        time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)
        result["brier"] = ev.integrated_brier_score(time_grid)
        print(ev.integrated_brier_score(time_grid))
        result["nll"] = ev.integrated_nbll(time_grid)
        result["lr"] = lr
        result["model"] = mod
        result["scaler"] = scale
        result["random"] = seed
        result["hiddens"] = dim
def _run_training_loop(self, num_epochs, scheduler, info_freq, log_dir):
    logger = SummaryWriter(log_dir)
    log_info = True

    if info_freq is not None:
        def print_header():
            sub_header = ' Epoch     Loss     Ctd       Loss     Ctd'
            print('-' * (len(sub_header) + 2))
            print('              Training          Validation')
            print('           ------------      ------------')
            print(sub_header)
            print('-' * (len(sub_header) + 2))
            print()
        print_header()

    for epoch in range(1, num_epochs + 1):
        if info_freq is None:
            print_info = False
        else:
            print_info = epoch == 1 or epoch % info_freq == 0

        for phase in ['train', 'val']:
            if phase == 'train':
                self.model.train()
            else:
                self.model.eval()

            running_losses = []
            if print_info or log_info:
                running_durations = torch.FloatTensor().to(self.device)
                running_censors = torch.LongTensor().to(self.device)
                running_risks = torch.FloatTensor().to(self.device)

            # Iterate over data
            for data in self.dataloaders[phase]:
                batch_result = self._process_data_batch(data, phase)
                loss, risk, time, event = batch_result

                # Stats
                running_losses.append(loss.item())
                running_durations = torch.cat(
                    (running_durations, time.data.float()))
                running_censors = torch.cat(
                    (running_censors, event.long().data))
                running_risks = torch.cat((running_risks, risk.detach()))

            epoch_loss = torch.mean(torch.tensor(running_losses))
            surv_probs = self._predictions_to_pycox(running_risks,
                                                    time_points=None)
            running_durations = running_durations.cpu().numpy()
            running_censors = running_censors.cpu().numpy()
            epoch_concord = EvalSurv(
                surv_probs, running_durations, running_censors,
                censor_surv='km').concordance_td('adj_antolini')

            if print_info:
                # the message built in the 'train' phase is extended in the
                # following 'val' phase before being printed
                if phase == 'train':
                    message = f' {epoch}/{num_epochs}'
                space = 10 if phase == 'train' else 27
                message += ' ' * (space - len(message))
                message += f'{epoch_loss:.4f}'
                space = 19 if phase == 'train' else 36
                message += ' ' * (space - len(message))
                message += f'{epoch_concord:.3f}'
                if phase == 'val':
                    print(message)

            if log_info:
                self._log_info(phase=phase, logger=logger, epoch=epoch,
                               epoch_loss=epoch_loss,
                               epoch_concord=epoch_concord)

            if phase == 'val':
                if scheduler:
                    scheduler.step(epoch_concord)

                # Record current performance
                k = list(self.current_perf.keys())[0]
                self.current_perf['epoch' + str(epoch)] = \
                    self.current_perf.pop(k)
                self.current_perf['epoch' + str(epoch)] = epoch_concord

                # Deep copy the model when this epoch beats the best so far
                for k, v in self.best_perf.items():
                    if epoch_concord >= v:
                        self.best_perf['epoch' + str(epoch)] = \
                            self.best_perf.pop(k)
                        self.best_perf['epoch' + str(epoch)] = epoch_concord
                        self.best_wts['epoch' + str(epoch)] = \
                            self.best_wts.pop(k)
                        self.best_wts['epoch' + str(epoch)] = copy.deepcopy(
                            self.model.state_dict())
                        break
_ = plt.xlabel('Time')

# Interpolate the survival estimates: so far they are only defined at the 10
# times in the discretization grid, so they form a step function rather than
# a continuous one.
surv = model.interpolate(10).predict_surv_df(x_test)
surv.iloc[:, :5].plot(drawstyle='steps-post')
plt.ylabel('S(t | x)')
_ = plt.xlabel('Time')

# The EvalSurv class contains some useful evaluation criteria for
# time-to-event prediction. We set censor_surv='km' to state that we want to
# use Kaplan-Meier for estimating the censoring distribution.
ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')
ev.concordance_td('antolini')

# Brier score
# We can plot the IPCW Brier score for a given set of times; here we use 100
# time points between the min and max duration in the test set. Note that the
# score becomes unstable at the highest times, so it is common to disregard
# the rightmost part of the graph.
time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)
ev.brier_score(time_grid).plot()
plt.ylabel('Brier score')
_ = plt.xlabel('Time')

# Negative binomial log-likelihood
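# Presumably what follows the heading above, mirroring the Brier-score plot:
# the IPCW negative binomial log-likelihood over the same time grid, then both
# curves summarised by their integrated scores (a sketch in the tutorial's style).
ev.nbll(time_grid).plot()
plt.ylabel('NBLL')
_ = plt.xlabel('Time')

# Integrated scores
ev.integrated_brier_score(time_grid)
ev.integrated_nbll(time_grid)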
                   fold_y_train_discrete,
                   batch_size, n_epochs, verbose=False)
    elapsed = time.time() - tic
    print('Time elapsed: %f second(s)' % elapsed)
    np.savetxt(time_elapsed_filename, np.array(elapsed).reshape(1, -1))
    surv_model.save_net(model_filename)
else:
    surv_model.load_net(model_filename)
    elapsed = float(np.loadtxt(time_elapsed_filename))
    print('Time elapsed (from previous fitting): %f second(s)' % elapsed)

surv_model.sub = 10
surv_df = surv_model.predict_surv_df(fold_X_val_std)
ev = EvalSurv(surv_df, fold_y_val[:, 0], fold_y_val[:, 1], censor_surv='km')

sorted_fold_y_val = np.sort(np.unique(fold_y_val[:, 0]))
time_grid = np.linspace(sorted_fold_y_val[0], sorted_fold_y_val[-1], 100)
surv = surv_df.to_numpy().T

cindex_scores.append(ev.concordance_td('antolini'))
integrated_brier_scores.append(ev.integrated_brier_score(time_grid))
print('  c-index (td):', cindex_scores[-1])
print('  Integrated Brier score:', integrated_brier_scores[-1])

cross_val_cindex = np.mean(cindex_scores)
def main(data_root, cancer_type, anatomical_location, fold):
    # Import the RDF graph for the PPI network
    f = open('seen.pkl', 'rb')
    seen = pickle.load(f)
    f.close()
    #####################
    f = open('ei.pkl', 'rb')
    ei = pickle.load(f)
    f.close()

    global device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # device = torch.device('cpu')

    cancer_type_vector = np.zeros((33,), dtype=np.float32)
    cancer_type_vector[cancer_type] = 1
    cancer_subtype_vector = np.zeros((25,), dtype=np.float32)
    for i in CANCER_SUBTYPES[cancer_type]:
        cancer_subtype_vector[i] = 1
    anatomical_location_vector = np.zeros((52,), dtype=np.float32)
    anatomical_location_vector[anatomical_location] = 1
    cell_type_vector = np.zeros((10,), dtype=np.float32)
    cell_type_vector[CELL_TYPES[cancer_type]] = 1
    pt_tensor_cancer_type = torch.FloatTensor(cancer_type_vector).to(device)
    pt_tensor_cancer_subtype = torch.FloatTensor(cancer_subtype_vector).to(device)
    pt_tensor_anatomical_location = torch.FloatTensor(
        anatomical_location_vector).to(device)
    pt_tensor_cell_type = torch.FloatTensor(cell_type_vector).to(device)
    edge_index = torch.LongTensor(ei).to(device)

    # Import a dictionary that maps proteins to their corresponding genes via
    # the Ensembl database
    f = open('ens_dic.pkl', 'rb')
    dicty = pickle.load(f)
    f.close()
    dic = {}
    for d in dicty:
        key = dicty[d]
        if key not in dic:
            dic[key] = {}
        dic[key][d] = 1

    # Build a dictionary from ENSG -- ENST
    d = {}
    with open('data1/prot_names1.txt') as f:
        for line in f:
            tok = line.split()
            d[tok[1]] = tok[0]

    # clinical data (i.e. number of days to survive: days to death for dead
    # patients and days to last follow-up for alive patients)
    clin = []
    # list of lists ([[patient1], [patient2], ..., [patientN]]) where
    # [patientX] = [gene_expression_value, diff_gene_expression_value,
    # methylation_value, diff_methylation_value, VCF_value, CNV_value]
    feat_vecs = []
    # whether a patient is alive or dead (0 for dead, 1 for alive)
    suv_time = []
    can_types = ["BRCA_v2"]
    data_root = '/ibex/scratch/projects/c2014/sara/'
    for i in range(len(can_types)):
        # file that maps patient IDs to their corresponding 6 file names
        # (gene_expression, diff_gene_expression, methylation,
        #  diff_methylation, VCF and CNV)
        f = open(data_root + can_types[i] + '.txt')
        lines = f.read().splitlines()
        f.close()
        lines = lines[1:]
        feat_vecs = np.zeros((len(lines), 17186 * 6), dtype=np.float32)
        i = 0
        for l in tqdm(lines):
            l = l.split('\t')
            clinical_file = l[6]
            surv_file = l[2]
            myth_file = 'myth/' + l[3]
            diff_myth_file = 'diff_myth/' + l[1]
            exp_norm_file = 'exp_count/' + l[-1]
            diff_exp_norm_file = 'diff_exp/' + l[0]
            cnv_file = 'cnv/' + l[4] + '.txt'
            vcf_file = ('vcf/' + 'OutputAnnoFile_' + l[5]
                        + '.hg38_multianno.txt.dat')
            # Check that all 6 files exist for a patient (for some patients
            # the survival time is not reported)
            all_files = [myth_file, diff_exp_norm_file, diff_myth_file,
                         exp_norm_file, cnv_file, vcf_file]
            for fname in all_files:
                if not os.path.exists(fname):
                    print('File ' + fname + ' does not exist!')
                    sys.exit(1)
            clin.append(clinical_file)
            suv_time.append(surv_file)
            temp_myth = myth_data(myth_file, seen, d, dic)
            vec = np.array(get_data(exp_norm_file, diff_exp_norm_file,
                                    diff_myth_file, cnv_file, vcf_file,
                                    temp_myth, seen, dic), dtype=np.float32)
            vec = vec.flatten()
            feat_vecs[i, :] = vec
            i += 1

    min_max_scaler = MinMaxScaler(clip=True)
    labels_days = []
    labels_surv = []
    for days, surv in zip(clin, suv_time):
        labels_days.append(float(days))
        labels_surv.append(float(surv))

    # Train by batch
    dataset = feat_vecs
    labels_days = np.array(labels_days)
    labels_surv = np.array(labels_surv)
    censored_index = []
    uncensored_index = []
    for i in range(len(dataset)):
        if labels_surv[i] == 1:
            censored_index.append(i)
        else:
            uncensored_index.append(i)
    model = CoxPH(MyNet(edge_index).to(device), tt.optim.Adam(0.0001))
    censored_index = np.array(censored_index)
    uncensored_index = np.array(uncensored_index)

    # Censored split
    n = len(censored_index)
    index = np.arange(n)
    i = n // 5
    np.random.seed(seed=0)
    np.random.shuffle(index)
    if fold < 4:
        ctest_idx = index[fold * i:fold * i + i]
        # numpy index arrays must be concatenated, not added elementwise
        ctrain_idx = np.concatenate((index[:fold * i], index[fold * i + i:]))
    else:
        ctest_idx = index[fold * i:]
        ctrain_idx = index[:fold * i]
    ctrain_n = len(ctrain_idx)
    cvalid_n = ctrain_n // 10
    cvalid_idx = ctrain_idx[:cvalid_n]
    ctrain_idx = ctrain_idx[cvalid_n:]

    # Uncensored split
    n = len(uncensored_index)
    index = np.arange(n)
    i = n // 5
    np.random.seed(seed=0)
    np.random.shuffle(index)
    if fold < 4:
        utest_idx = index[fold * i:fold * i + i]
        utrain_idx = np.concatenate((index[:fold * i], index[fold * i + i:]))
    else:
        utest_idx = index[fold * i:]
        utrain_idx = index[:fold * i]
    utrain_n = len(utrain_idx)
    uvalid_n = utrain_n // 10
    uvalid_idx = utrain_idx[:uvalid_n]
    utrain_idx = utrain_idx[uvalid_n:]

    # np.concatenate takes a tuple of arrays
    train_idx = np.concatenate((censored_index[ctrain_idx],
                                uncensored_index[utrain_idx]))
    np.random.seed(seed=0)
    np.random.shuffle(train_idx)
    valid_idx = np.concatenate((censored_index[cvalid_idx],
                                uncensored_index[uvalid_idx]))
    np.random.seed(seed=0)
    np.random.shuffle(valid_idx)
    test_idx = np.concatenate((censored_index[ctest_idx],
                               uncensored_index[utest_idx]))
    np.random.seed(seed=0)
    np.random.shuffle(test_idx)

    train_data = dataset[train_idx]
    train_data = min_max_scaler.fit_transform(train_data)
    train_labels_days = labels_days[train_idx]
    train_labels_surv = labels_surv[train_idx]
    train_labels = (train_labels_days, train_labels_surv)
    val_data = dataset[valid_idx]
    val_data = min_max_scaler.transform(val_data)
    val_labels_days = labels_days[valid_idx]
    val_labels_surv = labels_surv[valid_idx]
    test_data = dataset[test_idx]
    test_data = min_max_scaler.transform(test_data)
    test_labels_days = labels_days[test_idx]
    test_labels_surv = labels_surv[test_idx]
    val_labels = (val_labels_days, val_labels_surv)
    print(val_labels)

    callbacks = [tt.callbacks.EarlyStopping()]
    batch_size = 16
    epochs = 100
    val = (val_data, val_labels)
    log = model.fit(train_data, train_labels, batch_size, epochs, callbacks,
                    True, val_data=val, val_batch_size=batch_size)
    log.plot()
    plt.show()
    # print(model.partial_log_likelihood(*val).mean())
    train = train_data, train_labels

    # Compute the evaluation measures
    model.compute_baseline_hazards(*train)
    surv = model.predict_surv_df(test_data)
    print(surv)
    ev = EvalSurv(surv, test_labels_days, test_labels_surv)
    print(ev.concordance_td())