import numpy as np
import torchtuples as tt
from sklearn.preprocessing import StandardScaler
from sklearn_pandas import DataFrameMapper
from pycox.models import CoxPH, CoxTime
from pycox.evaluation import EvalSurv
from sksurv.metrics import cumulative_dynamic_auc

# prepare_datasets and df2array are project-local helpers defined elsewhere in the repo.


def train_deepsurv(data_df, r_splits):
    epochs = 100
    verbose = True
    num_nodes = [32]
    out_features = 1
    batch_norm = True
    dropout = 0.6
    output_bias = False

    c_index_at = []
    c_index_30 = []
    time_auc_30 = []
    time_auc_60 = []
    time_auc_365 = []

    for i in range(len(r_splits)):
        print("\nIteration %s" % i)

        # DATA PREP
        df_train, df_val, df_test, df_test_30 = prepare_datasets(
            data_df, r_splits[i][2], r_splits[i][1], r_splits[i][0])

        # all columns except identifiers and labels are standardized features
        xcols = list(df_train.columns)
        for col_name in ["subject_id", "event", "duration"]:
            if col_name in xcols:
                xcols.remove(col_name)

        cols_standardize = xcols
        standardize = [([col], StandardScaler()) for col in cols_standardize]
        x_mapper = DataFrameMapper(standardize)

        x_train = x_mapper.fit_transform(df_train).astype('float32')
        x_val = x_mapper.transform(df_val).astype('float32')
        x_test = x_mapper.transform(df_test).astype('float32')
        x_test_30 = x_mapper.transform(df_test_30).astype('float32')

        labtrans = CoxTime.label_transform()
        get_target = lambda df: (df['duration'].values, df['event'].values)
        y_train = labtrans.fit_transform(*get_target(df_train))
        y_val = labtrans.transform(*get_target(df_val))
        durations_test, events_test = get_target(df_test)
        durations_test_30, events_test_30 = get_target(df_test_30)
        val = tt.tuplefy(x_val, y_val)

        (train_x, train_y), (val_x, val_y), (test_x, test_y), _ = df2array(
            data_df, df_train, df_val, df_test, df_test_30)

        # MODEL
        in_features = x_train.shape[1]
        callbacks = [tt.callbacks.EarlyStopping()]
        net = tt.practical.MLPVanilla(in_features, num_nodes, out_features,
                                      batch_norm, dropout, output_bias=output_bias)
        model = CoxPH(net, tt.optim.Adam)
        model.optimizer.set_lr(0.0001)

        # tweak the batch size based on whether the number of training samples is even or odd
        batch_size = 255 if x_train.shape[0] % 2 else 256

        log = model.fit(x_train, y_train, batch_size, epochs, callbacks, verbose,
                        val_data=val, val_batch_size=batch_size)

        # EVALUATION
        model.compute_baseline_hazards()

        surv = model.predict_surv_df(x_test)
        ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')
        c_index_at.append(ev.concordance_td())

        surv_30 = model.predict_surv_df(x_test_30)
        ev_30 = EvalSurv(surv_30, durations_test_30, events_test_30, censor_surv='km')
        c_index_30.append(ev_30.concordance_td())

        # cumulative/dynamic AUC at fixed horizons (30, 60, 365 days)
        time_aucs = {30: time_auc_30, 60: time_auc_60, 365: time_auc_365}
        for time_x, auc_list in time_aucs.items():
            va_auc, va_mean_auc = cumulative_dynamic_auc(
                train_y, test_y, model.predict(x_test).flatten(), time_x)
            auc_list.append(va_auc[0])

        print("C-index_30:", c_index_30[i])
        print("C-index_AT:", c_index_at[i])
        print("time_auc_30", time_auc_30[i])
        print("time_auc_60", time_auc_60[i])
        print("time_auc_365", time_auc_365[i])

    return c_index_at, c_index_30, time_auc_30, time_auc_60, time_auc_365
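# Usage sketch (assumption, not part of the original pipeline): judging from the
# prepare_datasets call above, each entry of r_splits is an index triple ordered
# (test_ids, val_ids, train_ids). "cohort.csv" and load_splits below are
# hypothetical placeholders for whatever produces the cohort table and the splits.
if __name__ == "__main__":
    import pandas as pd
    data_df = pd.read_csv("cohort.csv")    # hypothetical cohort table with 'duration' and 'event' columns
    r_splits = load_splits("splits.pkl")   # hypothetical helper returning (test, val, train) index triples
    c_at, c_30, auc_30, auc_60, auc_365 = train_deepsurv(data_df, r_splits)
    print("mean C-index (all-time): %.3f" % np.mean(c_at))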
# Learning-rate search and final fit ==========================================
lrfinder = model.lr_finder(x_train, y_train, batch_size, tolerance=10)
best = lrfinder.get_best_lr()
model.optimizer.set_lr(best)

epochs = args.epochs
callbacks = [tt.callbacks.EarlyStopping(patience=patience)]
verbose = True
log = model.fit(x_train, y_train_transformed, batch_size, epochs, callbacks, verbose,
                val_data=val_transformed, val_batch_size=batch_size)

# Evaluation ===================================================================
val_loss = min(log.monitors['val_'].scores['loss']['score'])

# get Ctd
ctd = concordance_index(event_times=durations_test_transformed,
                        predicted_scores=model.predict(x_test).reshape(-1),
                        event_observed=events_test)

# set time grid for numerical integration to get IBS and IBLL
if durations_test.min() > 0:
    time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)
else:
    # skip a zero (or negative) minimum duration: start the grid at the second-smallest time
    durations_test_copy = durations_test.copy()
    durations_test_copy.sort()
    time_grid = np.linspace(durations_test_copy[1], durations_test.max(), 100)
# time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)

# transform time grid into DSAFT scale for fair comparison
# pdb.set_trace()
time_grid = np.exp(scaler_train.transform(np.log(time_grid.reshape(-1, 1)))).reshape(-1)

# grid interval for numerical integration
ds = np.array(time_grid - np.array([0.0] + time_grid[:-1].tolist()))
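# Sketch (assumption, not from the original fragment): a grid like time_grid is
# typically consumed by pycox's EvalSurv to obtain the integrated Brier score
# (IBS) and integrated binomial log-likelihood (IBLL). predict_surv_df is the
# pycox convention used in train_deepsurv above; if the model in this fragment
# exposes a different survival-prediction method, substitute it here. The
# durations/events names mirror the fragment.
surv = model.predict_surv_df(x_test)
ev = EvalSurv(surv, durations_test_transformed, events_test, censor_surv='km')
ibs = ev.integrated_brier_score(time_grid)   # numerically integrates the Brier score over time_grid
ibll = ev.integrated_nbll(time_grid)         # integrated negative binomial log-likelihood
# Alternatively, ds above can serve as the step widths of a manual Riemann-sum
# integration of per-time scores over the same grid.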
def train_LSTMCox(data_df, r_splits):
    epochs = 100
    verbose = True
    in_features = 768
    out_features = 1
    batch_norm = True
    dropout = 0.6
    output_bias = False

    c_index_at = []
    c_index_30 = []
    time_auc_30 = []
    time_auc_60 = []
    time_auc_365 = []

    for i in range(len(r_splits)):
        print("\nIteration %s" % i)

        # DATA PREP
        df_train, df_val, df_test, df_test_30 = prepare_datasets(
            data_df, r_splits[i][2], r_splits[i][1], r_splits[i][0])

        # stack the per-row "x0" feature arrays into dense float32 matrices
        x_train = np.array(df_train["x0"].tolist()).astype("float32")
        x_val = np.array(df_val["x0"].tolist()).astype("float32")
        x_test = np.array(df_test["x0"].tolist()).astype("float32")
        x_test_30 = np.array(df_test_30["x0"].tolist()).astype("float32")

        labtrans = CoxTime.label_transform()
        get_target = lambda df: (df['duration'].values, df['event'].values)
        y_train = labtrans.fit_transform(*get_target(df_train))
        y_val = labtrans.transform(*get_target(df_val))
        durations_test, events_test = get_target(df_test)
        durations_test_30, events_test_30 = get_target(df_test_30)
        val = tt.tuplefy(x_val, y_val)

        (train_x, train_y), (val_x, val_y), (test_x, test_y), _ = df2array(
            data_df, df_train, df_val, df_test, df_test_30)

        # MODEL
        callbacks = [tt.callbacks.EarlyStopping()]
        net = LSTMCox(in_features, 32, 1, 1)
        model = CoxPH(net, tt.optim.Adam)
        model.optimizer.set_lr(0.0001)

        # tweak the batch size based on whether the number of training samples is even or odd
        batch_size = 255 if x_train.shape[0] % 2 else 256

        log = model.fit(x_train, y_train, batch_size, epochs, callbacks, verbose,
                        val_data=val, val_batch_size=batch_size)

        # EVALUATION
        model.compute_baseline_hazards()

        surv = model.predict_surv_df(x_test)
        ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')
        c_index_at.append(ev.concordance_td())

        surv_30 = model.predict_surv_df(x_test_30)
        ev_30 = EvalSurv(surv_30, durations_test_30, events_test_30, censor_surv='km')
        c_index_30.append(ev_30.concordance_td())

        # cumulative/dynamic AUC at fixed horizons (30, 60, 365 days)
        time_aucs = {30: time_auc_30, 60: time_auc_60, 365: time_auc_365}
        for time_x, auc_list in time_aucs.items():
            va_auc, va_mean_auc = cumulative_dynamic_auc(
                train_y, test_y, model.predict(x_test).flatten(), time_x)
            auc_list.append(va_auc[0])

        print("C-index_30:", c_index_30[i])
        print("C-index_AT:", c_index_at[i])
        print("time_auc_30", time_auc_30[i])
        print("time_auc_60", time_auc_60[i])
        print("time_auc_365", time_auc_365[i])

    return c_index_at, c_index_30, time_auc_30, time_auc_60, time_auc_365
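# The LSTMCox network used above is not defined in this section. The class below
# is a minimal, assumed sketch consistent with the call LSTMCox(768, 32, 1, 1):
# an LSTM encoder whose last hidden state feeds a linear layer producing the
# single risk score expected by CoxPH. The real implementation in the repo may
# differ.
import torch
import torch.nn as nn


class LSTMCox(nn.Module):
    def __init__(self, in_features, hidden_size, num_layers, out_features):
        super().__init__()
        self.lstm = nn.LSTM(in_features, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, out_features)

    def forward(self, x):
        # accept both (batch, features) and (batch, seq_len, features) inputs
        if x.dim() == 2:
            x = x.unsqueeze(1)
        out, _ = self.lstm(x)
        return self.fc(out[:, -1, :])  # risk score from the last time step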