# Training ======================================================================
import os
import time

import numpy as np
import torchtuples as tt
import wandb
from pycox.evaluation import EvalSurv

batch_size = args.batch_size
lr_finder = model.lr_finder(x_train, y_train, batch_size, tolerance=3)
best = lr_finder.get_best_lr()
model.optimizer.set_lr(best)

epochs = args.epochs
callbacks = [tt.callbacks.EarlyStopping()]
verbose = True
log = model.fit(x_train, y_train, batch_size, epochs, callbacks, verbose,
                val_data=val)

# Evaluation ===================================================================
surv = model.interpolate(10).predict_surv_df(x_test)
ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')
ctd = ev.concordance_td('antolini')

time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)
ibs = ev.integrated_brier_score(time_grid)
nbll = ev.integrated_nbll(time_grid)
val_loss = min(log.monitors['val_'].scores['loss']['score'])

# The metric keys below are assumed; the original logging call was truncated.
wandb.log({'ctd': ctd, 'ibs': ibs, 'nbll': nbll, 'val_loss': val_loss})
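# Data preparation (usage sketch) ==============================================
# Illustrative only, not from the original source: one way the arrays assumed
# above (x_train, y_train, val, x_test, durations_test, events_test) could be
# built with pycox's label transform on synthetic data. The DeepHitSingle
# choice and every name below are assumptions.
from pycox.models import DeepHitSingle

rng = np.random.default_rng(0)
n = 1000
x_all = rng.standard_normal((n, 5)).astype('float32')
durations = rng.exponential(10.0, n).astype('float32')
events = rng.binomial(1, 0.7, n).astype('float32')

labtrans = DeepHitSingle.label_transform(20)   # 20 discrete time points
idx = rng.permutation(n)
tr, va, te = idx[:700], idx[700:850], idx[850:]
x_train, x_val, x_test = x_all[tr], x_all[va], x_all[te]
y_train = labtrans.fit_transform(durations[tr], events[tr])
val = (x_val, labtrans.transform(durations[va], events[va]))
durations_test, events_test = durations[te], events[te]

net = tt.practical.MLPVanilla(x_train.shape[1], [32, 32],
                              labtrans.out_features,
                              batch_norm=True, dropout=0.1)
model = DeepHitSingle(net, tt.optim.Adam, alpha=0.2, sigma=0.1,
                      duration_index=labtrans.cuts)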
model_filename = \
    os.path.join(output_dir, 'models',
                 '%s_%s_exp%d_bs%d_nep%d_nla%d_nno%d_'
                 % (init_survival_estimator_name, dataset, experiment_idx,
                    init_batch_size, init_n_epochs, n_layers, n_nodes)
                 + 'lr%f_a%f_s%f_nd%d_test.pt'
                 % (init_lr, init_alpha, init_sigma, init_num_durations))
time_elapsed_filename = model_filename[:-3] + '_time.txt'
print('*** Pre-training...')
if not os.path.isfile(model_filename):
    tic = time.time()
    surv_model.fit(X_train_std, y_train_discrete,
                   init_batch_size, init_n_epochs, verbose=False)
    elapsed = time.time() - tic
    np.savetxt(time_elapsed_filename, np.array(elapsed).reshape(1, -1))
    surv_model.save_net(model_filename)
else:
    surv_model.load_net(model_filename)
    elapsed = float(np.loadtxt(time_elapsed_filename))
surv_model.net.train()
model_filename = \
    os.path.join(output_dir, 'models',
                 '%s_%s_exp%d_bs%d_nep%d_nla%d_nno%d_'
                 % (survival_estimator_name, dataset, experiment_idx,
                    batch_size, n_epochs, n_layers, n_nodes)
                 + 'lr%f_a%f_s%f_nd%d_cv%d.pt'
                 % (lr, alpha, sigma, num_durations, cross_val_idx))
time_elapsed_filename = model_filename[:-3] + '_time.txt'
if not os.path.isfile(model_filename):
    tic = time.time()
    surv_model.fit(fold_X_train_std, fold_y_train_discrete,
                   batch_size, n_epochs, verbose=False)
    elapsed = time.time() - tic
    print('Time elapsed: %f second(s)' % elapsed)
    np.savetxt(time_elapsed_filename, np.array(elapsed).reshape(1, -1))
    surv_model.save_net(model_filename)
else:
    surv_model.load_net(model_filename)
    elapsed = float(np.loadtxt(time_elapsed_filename))
    print('Time elapsed (from previous fitting): %f second(s)' % elapsed)
surv_df = \
model_filename = \
    os.path.join(output_dir, 'models',
                 '%s_%s_exp%d_%s_bs%d_nep%d_nla%d_nno%d_'
                 % (survival_estimator_name, dataset, experiment_idx,
                    val_string, batch_size, n_epochs, n_layers, n_nodes)
                 + 'lr%f_a%f_s%f_nd%d_fold%d.pt'
                 % (lr, alpha, sigma, num_durations, fold_idx))
time_elapsed_filename = model_filename[:-3] + '_time.txt'
if not os.path.isfile(model_filename):
    tic = time.time()
    if use_early_stopping and not use_cross_val:
        surv_model.fit(fold_X_train_std, fold_y_train_discrete,
                       batch_size, n_epochs,
                       [tt.callbacks.EarlyStopping()],
                       val_data=(fold_X_val_std, fold_y_val_discrete),
                       verbose=False)
    else:
        surv_model.fit(fold_X_train_std, fold_y_train_discrete,
                       batch_size, n_epochs, verbose=False)
    elapsed = time.time() - tic
    print('Time elapsed: %f second(s)' % elapsed)
    np.savetxt(time_elapsed_filename, np.array(elapsed).reshape(1, -1))
    surv_model.save_net(model_filename)
else:
    surv_model.load_net(model_filename)
    elapsed = float(np.loadtxt(time_elapsed_filename))
    print('Time elapsed (from previous fitting): %f second(s)' % elapsed)
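# Fit-or-load helper (refactor sketch) =========================================
# The three blocks above repeat a fit-or-load-with-timing pattern; a possible
# refactor is sketched below. This is not part of the original code, and
# fit_fn is a hypothetical zero-argument callable that runs the appropriate
# surv_model.fit(...) call.
def fit_or_load(surv_model, model_filename, fit_fn):
    """Fit and cache surv_model, or load it if a cached copy exists."""
    time_elapsed_filename = model_filename[:-3] + '_time.txt'
    if not os.path.isfile(model_filename):
        tic = time.time()
        fit_fn()
        elapsed = time.time() - tic
        np.savetxt(time_elapsed_filename, np.array(elapsed).reshape(1, -1))
        surv_model.save_net(model_filename)
    else:
        surv_model.load_net(model_filename)
        elapsed = float(np.loadtxt(time_elapsed_filename))
    print('Time elapsed: %f second(s)' % elapsed)
    return elapsed

# example:
# fit_or_load(surv_model, model_filename,
#             lambda: surv_model.fit(fold_X_train_std, fold_y_train_discrete,
#                                    batch_size, n_epochs, verbose=False))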
import pandas as pd
import torch
from pycox.models import DeepHitSingle
from sklearn.preprocessing import StandardScaler
from sklearn_pandas import DataFrameMapper


class DeepHit_pycox:
    def __init__(self, nodes_per_layer, layers, dropout, weight_decay,
                 batch_size, num_durations, alpha, sigma, lr=0.0001, seed=47):
        # set seed
        np.random.seed(seed)
        _ = torch.manual_seed(seed)

        self.standardalizer = None
        self.standardize_data = True
        self._duration_col = "duration"
        self._event_col = "event"
        self.in_features = None
        self.batch_norm = True
        self.output_bias = False
        self.activation = torch.nn.ReLU
        self.epochs = 512
        self.num_workers = 2
        self.callbacks = [tt.callbacks.EarlyStopping()]

        # tuned hyperparameters
        self.alpha = alpha
        self.sigma = 10 ** sigma  # sigma is passed as a base-10 exponent
        self.num_nodes = [int(nodes_per_layer) for _ in range(int(layers))]
        self.dropout = dropout
        self.weight_decay = weight_decay
        self.lr = lr
        self.batch_size = int(batch_size)
        self.num_durations = int(num_durations)
        self.labtrans = DeepHitSingle.label_transform(self.num_durations)

    def set_standardize(self, standardize_bool):
        self.standardize_data = standardize_bool

    def _format_to_pycox(self, X, Y, F):
        # from numpy to pandas df
        df = pd.DataFrame(data=X, columns=F)
        if Y is not None:
            df[self._duration_col] = Y[:, 0]
            df[self._event_col] = Y[:, 1]
        return df

    def _standardize_df(self, df, flag):
        # if flag == 'test', the df passed in does not contain Y labels
        if self.standardize_data:
            # standardize x
            df_x = df if flag == 'test' else df.drop(
                columns=[self._duration_col, self._event_col])
            if flag == "train":
                cols_leave = []
                cols_standardize = []
                for column in df_x.columns:
                    # leave binary (0/1) columns untouched
                    if set(pd.unique(df[column])) == set([0, 1]):
                        cols_leave.append(column)
                    else:
                        cols_standardize.append(column)
                standardize = [([col], StandardScaler())
                               for col in cols_standardize]
                leave = [(col, None) for col in cols_leave]
                self.standardalizer = DataFrameMapper(standardize + leave)

                x = self.standardalizer.fit_transform(df_x).astype('float32')
                # discretize y onto the label-transform grid
                y = self.labtrans.fit_transform(
                    df[self._duration_col].values.astype('float32'),
                    df[self._event_col].values.astype('float32'))
            elif flag == "val":
                x = self.standardalizer.transform(df_x).astype('float32')
                y = self.labtrans.transform(
                    df[self._duration_col].values.astype('float32'),
                    df[self._event_col].values.astype('float32'))
            elif flag == "test":
                x = self.standardalizer.transform(df_x).astype('float32')
                y = None
            else:
                raise NotImplementedError
            return x, y
        else:
            raise NotImplementedError

    def fit(self, X, y, column_names):
        # format data
        self.column_names = column_names
        full_df = self._format_to_pycox(X, y, self.column_names)
        val_df = full_df.sample(frac=0.2)
        train_df = full_df.drop(val_df.index)
        train_x, train_y = self._standardize_df(train_df, "train")
        val_x, val_y = self._standardize_df(val_df, "val")

        # configure model
        self.in_features = train_x.shape[1]
        self.out_features = self.labtrans.out_features
        net = tt.practical.MLPVanilla(
            in_features=self.in_features,
            num_nodes=self.num_nodes,
            out_features=self.out_features,
            batch_norm=self.batch_norm,
            dropout=self.dropout,
            activation=self.activation,
            output_bias=self.output_bias,
            w_init_=lambda w: torch.nn.init.xavier_normal_(w))
        self.model = DeepHitSingle(
            net,
            tt.optim.Adam(lr=self.lr, weight_decay=self.weight_decay),
            alpha=self.alpha, sigma=self.sigma,
            duration_index=self.labtrans.cuts)

        n_train = train_x.shape[0]
        while n_train % self.batch_size == 1:
            # a final batch of size 1 breaks batch norm
            self.batch_size += 1

        self.model.fit(train_x, train_y, self.batch_size, self.epochs,
                       self.callbacks, verbose=True,
                       val_data=(val_x, val_y),
                       val_batch_size=self.batch_size,
                       num_workers=self.num_workers)

    def predict(self, test_x, time_list):
        # format data (time_list is unused; the interpolation grid returned
        # by pycox is used instead)
        test_df = self._format_to_pycox(test_x, None, self.column_names)
        test_x, _ = self._standardize_df(test_df, "test")
        proba_matrix_ = self.model.interpolate().predict_surv_df(test_x)
        time_list_ = list(proba_matrix_.index)
        proba_matrix = np.transpose(proba_matrix_.values)
        pred_medians = []
        for test_idx, survival_proba in enumerate(proba_matrix):
            # default: if the predicted probability never drops below 0.5,
            # predict the largest time on the grid (reset per test point)
            median_time = max(time_list_)
            # survival_proba is non-increasing in time
            for col, proba in enumerate(survival_proba):
                if proba > 0.5:
                    continue
                if proba == 0.5 or col == 0:
                    median_time = time_list_[col]
                else:
                    median_time = (time_list_[col - 1] + time_list_[col]) / 2
                break
            pred_medians.append(median_time)
        return np.array(pred_medians), proba_matrix_
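# DeepHit_pycox usage (sketch) =================================================
# Illustrative only; the synthetic data and all hyperparameter values below
# are assumptions, not from the original source.
rng = np.random.default_rng(0)
n, d = 500, 4
X = rng.standard_normal((n, d))
Y = np.column_stack([rng.exponential(5.0, n),    # durations
                     rng.binomial(1, 0.8, n)])   # event indicators
feature_names = ['f%d' % j for j in range(d)]

dh = DeepHit_pycox(nodes_per_layer=32, layers=2, dropout=0.1,
                   weight_decay=1e-4, batch_size=64, num_durations=20,
                   alpha=0.5, sigma=0.0)  # sigma is an exponent: 10**0 == 1
dh.fit(X, Y, feature_names)
pred_medians, surv_df = dh.predict(X[:10], time_list=None)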
import torchtuples as ttup


def _train_dht(x, t, e, folds, params):
    """Helper function to train a DeepHit model (Lee et al., AAAI 2018).

    Args:
      x: numpy array of input features (training data).
      t: numpy vector of event times (training data).
      e: numpy vector of event indicators (1 if event occurred,
         0 otherwise) (training data).
      folds: vector of the training CV folds.
      params: dict of hyperparameters ('num_nodes', 'lr', 'bs'),
         or None for defaults.

    Returns:
      Dict mapping each fold index to a trained pycox DeepHitSingle model.
    """
    if params is None:
        num_nodes = [100, 100]
        lr = 1e-3
        bs = 128
    else:
        num_nodes = params['num_nodes']
        lr = params['lr']
        bs = params['bs']

    x = x.astype('float32')
    t = t.astype('float32')
    e = e.astype('int32')

    num_durations = int(max(t))
    print("num_durations:", num_durations)
    labtrans = DeepHitSingle.label_transform(num_durations, scheme='quantiles')

    in_features = x.shape[1]
    batch_norm = False
    dropout = 0.0
    output_bias = False  # note: not passed to the network builder below

    fold_model = {}
    for f in set(folds):
        # hold out fold f; split the remainder into train/validation
        xf = x[folds != f]
        tf = t[folds != f]
        ef = e[folds != f]

        validx = sorted(np.random.choice(len(xf), size=int(0.15 * len(xf)),
                                         replace=False))
        vidx = np.array([False] * len(xf))
        vidx[validx] = True

        y_train = labtrans.fit_transform(tf[~vidx], ef[~vidx])
        y_val = labtrans.transform(tf[vidx], ef[vidx])
        out_features = labtrans.out_features

        net = ttup.practical.MLPVanilla(in_features, num_nodes, out_features,
                                        batch_norm, dropout)
        model = DeepHitSingle(net, ttup.optim.Adam, alpha=0.5, sigma=1,
                              duration_index=labtrans.cuts)

        y_train = y_train[0].astype('int64'), y_train[1].astype('float32')
        y_val = y_val[0].astype('int64'), y_val[1].astype('float32')
        val = xf[vidx], y_val

        batch_size = bs
        model.optimizer.set_lr(lr)
        epochs = 10
        callbacks = [ttup.callbacks.EarlyStopping()]
        model.fit(xf[~vidx], y_train, batch_size, epochs, callbacks,
                  True, val_data=val)
        fold_model[f] = model
    return fold_model
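# _train_dht usage (sketch) ====================================================
# Illustrative only; the synthetic data, fold assignments, and params dict
# below are assumptions, not from the original source.
x = np.random.randn(600, 8)
t = np.random.exponential(20.0, 600) + 1.0
e = np.random.binomial(1, 0.7, 600)
folds = np.random.choice([0, 1, 2], size=600)

fold_models = _train_dht(x, t, e, folds,
                         params={'num_nodes': [64, 64], 'lr': 1e-3, 'bs': 64})
# evaluate each fold's model on its held-out fold
for f, m in fold_models.items():
    surv = m.interpolate(10).predict_surv_df(x[folds == f].astype('float32'))
    print('fold', f, 'survival matrix shape:', surv.shape)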