]) Censored = Data['Censored'].reshape([ N, ]) fnames = Data[dtype + '_Symbs'] fnames = [j.split(' ')[0] for j in fnames] Data = None #%% # Get result files #============================================================================== # Getting at-risk groups t_batch, o_batch, at_risk_batch, x_batch = \ sUtils.calc_at_risk(Survival, 1-Censored, Features) sys.exit() #%% # Get mask (to be multiplied by Pij) ****************** n_batch = t_batch.shape[0] Pij_mask = np.zeros((n_batch, n_batch)) # Get difference in outcomes between all cases if mask_type == 'observed': outcome_diff = np.abs(t_batch[None, :] - t_batch[:, None]) for idx in range(n_batch):
# Demo / prototyping script: builds a small random survival dataset and
# wraps it in TensorFlow variables.
import sys
# NOTE(review): hard-coded absolute path to the project's Codes directory;
# this only resolves on the original author's machine.
sys.path.append('/home/mohamed/Desktop/CooperLab_Research/KNN_Survival/Codes')
import SurvivalUtils as sUtils
import tensorflow as tf
import numpy as np

#%%
#
# Generate simulated data
#
# n samples x d features drawn uniformly from [0, 1), integer survival
# times in [0, 300) and random 0/1 censorship flags.
n = 30; d = 140
X_input = np.random.rand(n, d)
T = np.random.randint(0, 300, [n,])
C = np.random.randint(0, 2, [n,])

# 1-C turns censorship flags into "observed" flags before computing the
# at-risk groups. Presumably calc_at_risk also reorders T/O/X by time --
# TODO confirm against SurvivalUtils.
T, O, at_risk, X_input = sUtils.calc_at_risk(T, 1-C, X_input)

#%%
# -----------------------------
# Add to graph (for demo)

# Start from an empty default graph (TF1-style API).
tf.reset_default_graph()

X_input = tf.Variable(X_input)
T = tf.Variable(T, dtype='float32')
O = tf.Variable(O, dtype='float32')
at_risk = tf.Variable(at_risk)

# for now, let's assume we already NCA_transformed X
X_transformed = X_input

# no of feats and split size
def predict(self, neighbor_idxs, Survival_train, Censored_train,
            Survival_test=None, Censored_test=None,
            K=15, Method='non-cumulative'):
    """
    Predict testing set using 'prototype' (i.e. training) set using KNN

    neighbor_idxs - indices of nearest neighbors; (N_test, N_train)
    Survival_train - training sample time-to-event; (N,) np array
    Censored_train - training sample censorship status; (N,) np array
    Survival_test - optional testing sample time-to-event; (N_test,)
                    np array. When given, a c-index is computed.
    Censored_test - testing sample censorship status; required whenever
                    Survival_test is given
    K - number of nearest-neighbours to use, int (fewer are used when
        neighbor_idxs has fewer than K columns)
    Method - 'non-cumulative' or 'cumulative' probability of survival

    Returns (T_test, CI):
        T_test - (N_test,) np array of predicted survival "times"
        CI - c-index of T_test vs. the test outcomes, or 0 when
             Survival_test is None

    Raises ValueError for an unknown Method.
    """
    # Validate up front so a bad Method fails before any work is done
    if Method not in ('non-cumulative', 'cumulative'):
        raise ValueError(
            "Method is either 'cumulative' or 'non-cumulative'.")

    # Keep only desired K neighbors (may be fewer than K if N_train < K)
    neighbor_idxs = neighbor_idxs[:, 0:K]

    # Initialize
    N_test = neighbor_idxs.shape[0]
    T_test = np.zeros([N_test])

    if Method == 'non-cumulative':

        # Convert outcomes to "alive status" at each time point
        alive_train = sUtils.getAliveStatus(Survival_train, Censored_train)

        # Get survival prediction for each patient
        for idx in range(N_test):

            # Advanced indexing returns a copy, so the in-place edits
            # below do not modify alive_train.
            status = alive_train[neighbor_idxs[idx, :], :]

            # Negative entries are treated as "unknown" status
            totalKnown = np.sum(status >= 0, axis=0)
            status[status < 0] = 0

            # remove timepoints where there are no known statuses
            known = totalKnown != 0
            status = status[:, known]
            totalKnown = totalKnown[known]

            # get "average" predicted survival at each remaining timepoint
            status = np.sum(status, axis=0) / totalKnown

            # now get overall time prediction
            T_test[idx] = np.sum(status)

    else:  # Method == 'cumulative'

        for idx in range(N_test):

            # Get at-risk groups for each time point for nearest neighbors
            T = Survival_train[neighbor_idxs[idx, :]]
            O = 1 - Censored_train[neighbor_idxs[idx, :]]
            T, O, at_risk, _ = sUtils.calc_at_risk(T, O)

            # at_risk[i] is the index (into the returned T) of the first
            # sample still at risk at T[i], so the number at risk is the
            # actual neighbour count minus that index. Using K here would
            # overcount whenever fewer than K neighbours are available.
            n_neighbors = T.shape[0]
            N_at_risk = n_neighbors - at_risk

            # Calculate cumulative probability of survival
            P = np.cumprod((N_at_risk - O) / N_at_risk)

            # now get overall time prediction
            T_test[idx] = np.sum(P)

    # Get c-index
    #======================================================================
    CI = 0
    if Survival_test is not None:
        assert (Censored_test is not None)
        CI = sUtils.c_index(T_test, Survival_test, Censored_test,
                            prediction_type='survival_time')

    return T_test, CI
# Censorship flags as int32, reshaped to (N,)
Censored = np.int32(Data['Censored']).reshape([ N, ])
fnames = Data['Integ_Symbs']
#fnames = Data['Gene_Symbs']

# remove zero-variance features
# (boolean-mask indexing below requires fnames to be a numpy array)
fvars = np.std(Features, 0)
keep = fvars > 0
Features = Features[:, keep]
fnames = fnames[keep]
N, D = Features.shape # after feature removal

# Getting at-risk groups (training set)
# NOTE(review): this version of calc_at_risk is called as
# (features, survival, observed) and returns features first -- confirm
# against the SurvivalUtils version in use.
Features, Survival, Observed, at_risk = \
    sUtils.calc_at_risk(Features, Survival, 1-Censored)

## Limit N (for prototyping) ----
#n = 100
#Features = Features[0:n, :]
#Survival = Survival[0:n]
#Observed = Observed[0:n]
#at_risk = at_risk[0:n]
#--------------------------------

# *************************************************************
# Z-scoring survival to prevent numerical errors.
# This is a positive affine transform, so the time ordering used for the
# at-risk groups above is unchanged; resulting times can be negative.
# NOTE(review): divides by zero if all survival times are identical.
Survival = (Survival - np.mean(Survival)) / np.std(Survival)
# *************************************************************

#%%============================================================================
def train(self, features, survival, censored,
          features_valid=None, survival_valid=None, censored_valid=None,
          COMPUT_GRAPH_PARAMS=None, BATCH_SIZE=20, PLOT_STEP=10,
          MODEL_SAVE_STEP=10, MAX_ITIR=100):
    """
    Train a survivalNCA model.

    features - (N, D) np array (training set)
    survival and censored - (N,) np arrays (training set)
    features_valid, survival_valid, censored_valid - optional validation
        set; when features_valid is given the other two are required
    COMPUT_GRAPH_PARAMS - dict of computational-graph parameters. It is
        copied; 'dim_input' is set to D on the copy, never on the caller's
        dict.
    BATCH_SIZE - mini-batch size; a single batch of all samples is used
        when BATCH_SIZE is not smaller than the number of samples
    PLOT_STEP - call the progress monitor every PLOT_STEP epochs
    MODEL_SAVE_STEP - currently unused (periodic saving is disabled)
    MAX_ITIR - maximum number of epochs

    Ctrl-C (KeyboardInterrupt) ends training early; the final monitoring
    and weight fetching/saving below still run.

    Returns the learned W, also saved to
    RESULTPATH + 'model/' + description + 'featWeights.npy'.
    """
    # Copy: the old `={}` default was mutated below and shared across calls
    COMPUT_GRAPH_PARAMS = dict(COMPUT_GRAPH_PARAMS or {})

    # Initial preprocessing and sanity checks
    #======================================================================
    assert len(features.shape) == 2
    assert len(survival.shape) == 1
    assert len(censored.shape) == 1

    USE_VALID = features_valid is not None
    if USE_VALID:
        assert (features_valid.shape[1] == features.shape[1])
        assert (survival_valid is not None)
        assert (censored_valid is not None)

    # normalize (for numeric stability): scale by T_MAX, then add a small
    # epsilon (presumably to keep times away from zero)
    epsilon = 1e-10
    survival = (survival / self.T_MAX) + epsilon
    if USE_VALID:
        survival_valid = (survival_valid / self.T_MAX) + epsilon

    # Define computational graph
    #======================================================================
    COMPUT_GRAPH_PARAMS['dim_input'] = features.shape[1]
    graph = self._build_computational_graph(COMPUT_GRAPH_PARAMS)

    def _get_batches(cens):
        """Balanced mini-batch indices, or one batch holding every sample."""
        n = len(cens)
        if BATCH_SIZE < n:
            return dm.get_balanced_batches(cens, BATCH_SIZE=BATCH_SIZE)
        return [np.arange(n)]

    def _batch_feed(x, t, c, batch):
        """Build the graph feed for one batch of (features, time, censor)."""
        # Getting at-risk groups
        t_batch, o_batch, at_risk_batch, x_batch = \
            sUtils.calc_at_risk(t[batch], 1 - c[batch], x[batch, :])
        return {
            graph.X_input: x_batch,
            graph.T: t_batch,
            graph.O: o_batch,
            graph.At_Risk: at_risk_batch,
        }

    # Begin session
    #======================================================================
    with tf.Session() as sess:

        # op to save/restore all the variables
        saver = tf.train.Saver()

        if "checkpoint" in os.listdir(self.WEIGHTPATH):
            # load existing weights
            saver.restore(sess, self.WEIGHTPATH + self.description + ".ckpt")
        else:
            # start a new model
            sess.run(tf.global_variables_initializer())

        def _saveTFmodel():
            """Save model attributes (TF checkpoint saving is disabled)."""
            #save_path = saver.save(sess, \
            #    self.WEIGHTPATH + self.description + ".ckpt")
            self.save()

        def _monitorProgress():
            """Assemble epoch-level costs (plotting is currently disabled)."""
            cs = np.array(self.Costs_epochLevel_train)
            if len(cs) == 0:
                # nothing recorded yet (e.g. interrupted before 1st epoch)
                return
            epoch_no = np.arange(len(cs))
            # reshape handles both scalar and (1,)-shaped per-epoch costs
            cs = np.concatenate(
                (epoch_no[:, None], cs.reshape(len(cs), -1)), axis=1)
            cs_valid = None
            if USE_VALID:
                cs_valid = np.array(self.Costs_epochLevel_valid)
            #self._plotMonitor(arr= cs, arr2= cs_valid,
            #                  title= "cost vs. epoch",
            #                  xlab= "epoch", ylab= "cost",
            #                  savename= self.RESULTPATH + "plots/" +
            #                  self.description + "cost_" + timestamp + ".svg")

        # Begin epochs
        #==================================================================
        # Defined up front so the final W fetch works even if interrupted
        # before the first batch.
        feed_dict = None

        try:
            itir = 0
            while itir < MAX_ITIR:

                itir += 1
                cost_tot = 0
                cost_tot_valid = 0

                # Shuffle so that training batches differ every epoch
                idxs = np.arange(features.shape[0])
                np.random.shuffle(idxs)
                features = features[idxs, :]
                survival = survival[idxs]
                censored = censored[idxs]

                # Divide into balanced batches (if relevant)
                batchIdxs = _get_batches(censored)
                if USE_VALID:
                    batchIdxs_valid = _get_batches(censored_valid)

                # Run over training set
                for batch in batchIdxs:
                    feed_dict = _batch_feed(features, survival, censored,
                                            batch)
                    _, cost = sess.run([graph.optimizer, graph.cost],
                                       feed_dict=feed_dict)
                    # normalize cost for sample size
                    cost_tot += cost / len(batch)

                # Run over validation set (cost only, no optimizer step).
                # Uses the *validation* arrays, not the training ones.
                if USE_VALID:
                    for batch in batchIdxs_valid:
                        feed_dict = _batch_feed(features_valid,
                                                survival_valid,
                                                censored_valid, batch)
                        cost = sess.run(graph.cost, feed_dict=feed_dict)
                        cost_tot_valid += cost / len(batch)

                # update epochs and append costs
                self.EPOCHS_RUN += 1
                self.Costs_epochLevel_train.append(cost_tot)
                if USE_VALID:
                    self.Costs_epochLevel_valid.append(cost_tot_valid)

                # periodically save model (disabled)
                #if (self.EPOCHS_RUN % MODEL_SAVE_STEP) == 0:
                #    _saveTFmodel()

                # periodically monitor progress
                if (self.EPOCHS_RUN % PLOT_STEP == 0) and \
                        (self.EPOCHS_RUN > 0):
                    _monitorProgress()

        except KeyboardInterrupt:
            # Deliberate: Ctrl-C stops training early; fall through to the
            # final monitoring and weight saving below.
            pass

        _monitorProgress()

        # save learned weights
        W = sess.run(graph.W, feed_dict=feed_dict)
        np.save(self.RESULTPATH + 'model/' + self.description +
                'featWeights.npy', W)

        return W
def train(self, features, survival, censored,
          features_valid=None, survival_valid=None, censored_valid=None,
          graph_hyperparams=None, BATCH_SIZE=20, PLOT_STEP=10,
          MODEL_SAVE_STEP=10, MAX_ITIR=100, MODEL_BUFFER=4,
          EARLY_STOPPING=False, MONITOR=True, PLOT=True,
          K=35, Method='cumulative-time', norm=2):
    """
    Train a survivalNCA model.

    features - (N, D) np array
    survival and censored - (N,) np arrays
    features_valid, survival_valid, censored_valid - optional validation
        set; when features_valid is given the other two are required
    graph_hyperparams - dict merged with self.default_graph_hyperparams;
        after merging it must supply ALPHA, LAMBDA, SIGMA and
        DROPOUT_FRACTION (all are fed to the graph every batch)
    BATCH_SIZE - mini-batch size; full-batch GD when BATCH_SIZE >= N
    PLOT_STEP - call the progress monitor every PLOT_STEP epochs
    MODEL_SAVE_STEP - currently unused (periodic saving is disabled)
    MAX_ITIR - maximum number of epochs
    MODEL_BUFFER - number of recent W snapshots kept for early stopping
    EARLY_STOPPING - stop once the mean validation c-index of the last
        MODEL_BUFFER epochs falls below that of the MODEL_BUFFER epochs
        before it, and return the oldest buffered snapshot; requires a
        validation set
    MONITOR - print per-epoch progress and save cost / c-index files
    PLOT - also save plots whenever progress is monitored
    K, Method, norm - forwarded to the KNN model used for c-indices

    Ctrl-C (KeyboardInterrupt) ends training early; final monitoring and
    saving still run.

    Returns the learned NCA matrix, also saved to
    RESULTPATH + 'model/' + description + timestamp + 'NCA_matrix.npy'.
    """
    # Fresh dict per call (the old `={}` default was shared across calls)
    if graph_hyperparams is None:
        graph_hyperparams = {}

    # Initial preprocessing and sanity checks
    #======================================================================
    # Check rank before indexing shape[1]
    assert len(features.shape) == 2
    assert len(survival.shape) == 1
    assert len(censored.shape) == 1
    D = features.shape[1]

    USE_VALID = features_valid is not None
    if USE_VALID:
        assert (features_valid.shape[1] == D)
        assert (survival_valid is not None)
        assert (censored_valid is not None)

    if EARLY_STOPPING:
        assert USE_VALID

    # Merge user hyperparameters with defaults
    #======================================================================
    graph_hyperparams = \
        pUtils.Merge_dict_with_default(
            dict_given=graph_hyperparams,
            dict_default=self.default_graph_hyperparams,
            keys_Needed=self.userspecified_graph_hyperparams)

    # NOTE(review): Method (default 'cumulative-time') is passed straight to
    # knn.SurvivalKNN.predict -- confirm that version accepts it.

    # Begin session
    #======================================================================
    with tf.Session() as sess:

        # op to save/restore all the variables
        saver = tf.train.Saver()

        if "checkpoint" in os.listdir(self.WEIGHTPATH):
            # load existing weights
            saver.restore(sess, self.WEIGHTPATH + self.description + ".ckpt")
        else:
            # start a new model
            sess.run(tf.global_variables_initializer())

        def _saveTFmodel():
            """Save model attributes (TF checkpoint saving is disabled)."""
            #save_path = saver.save(sess, \
            #    self.WEIGHTPATH + self.description + ".ckpt")
            self.save()

        def _monitorProgress(snapshot_idx=None):
            """Save cost / c-index history as text files and, if PLOT, plots."""
            # find min epochs to display in case of keyboard interrupt
            max_epoch = np.min([
                len(self.Costs_epochLevel_train),
                len(self.CIs_train),
                len(self.CIs_valid),
            ])

            costs = np.array(self.Costs_epochLevel_train[0:max_epoch])
            cis_train = np.array(self.CIs_train[0:max_epoch])
            if USE_VALID:
                cis_valid = np.array(self.CIs_valid[0:max_epoch])
            else:
                cis_valid = None

            # prepend an epoch-number column
            epoch_no = np.arange(max_epoch)
            costs = np.concatenate((epoch_no[:, None], costs[:, None]),
                                   axis=1)
            cis_train = np.concatenate(
                (epoch_no[:, None], cis_train[:, None]), axis=1)

            # Saving raw numbers for later reference
            savename = (self.RESULTPATH + "plots/" + self.description +
                        self.timestamp)

            with open(savename + '_costs.txt', 'wb') as f:
                np.savetxt(f, costs, fmt='%s', delimiter='\t')
            with open(savename + '_cis_train.txt', 'wb') as f:
                np.savetxt(f, cis_train, fmt='%s', delimiter='\t')
            if USE_VALID:
                with open(savename + '_cis_valid.txt', 'wb') as f:
                    np.savetxt(f, cis_valid, fmt='%s', delimiter='\t')

            # Note: plotting does not work when running under `screen`
            # (no X display available)
            if PLOT:
                self._plotMonitor(arr=costs,
                                  title="Cost vs. epoch",
                                  xlab="epoch", ylab="Cost",
                                  savename=savename + "_costs.svg")
                self._plotMonitor(arr=cis_train, arr2=cis_valid,
                                  title="C-index vs. epoch",
                                  xlab="epoch", ylab="C-index",
                                  savename=savename + "_Ci.svg",
                                  snapshot_idx=snapshot_idx)

        # Begin epochs
        #==================================================================
        itir = 0
        stopped_early = False  # True only when early stopping triggers

        try:
            if MONITOR:
                print("\n\tepoch\tcost\tCi_train\tCi_valid")
                print("\t----------------------------------------------")

            knnmodel = knn.SurvivalKNN(self.RESULTPATH,
                                       description=self.description)

            # Weights buffer for early stopping: each "channel" in the 3rd
            # dim holds one snapshot of the model.
            if USE_VALID:
                Ws = np.zeros((D, D, MODEL_BUFFER))
                Cis = []

            while itir < MAX_ITIR:

                itir += 1
                cost_tot = 0
                self._update_timestamp()

                # Divide into balanced batches
                #==========================================================
                n = censored.shape[0]
                if BATCH_SIZE < n:
                    # Shuffle so that training batches differ every epoch
                    idxs = np.arange(features.shape[0])
                    np.random.shuffle(idxs)
                    features = features[idxs, :]
                    survival = survival[idxs]
                    censored = censored[idxs]

                    # stochastic mini-batch GD
                    batchIdxs = dm.get_balanced_batches(
                        censored, BATCH_SIZE=BATCH_SIZE)
                else:
                    # Global GD
                    batchIdxs = [np.arange(n)]

                # Run over training set
                #==========================================================
                for batch in batchIdxs:

                    # Getting at-risk groups
                    t_batch, o_batch, at_risk_batch, x_batch = \
                        sUtils.calc_at_risk(survival[batch],
                                            1-censored[batch],
                                            features[batch, :])

                    # Get at-risk mask (to be multiplied by Pij): row idx is
                    # set from column at_risk_batch[idx] onward, observed
                    # cases only.
                    n_batch = t_batch.shape[0]
                    Pij_mask = np.zeros((n_batch, n_batch))
                    for idx in range(n_batch):
                        if o_batch[idx] == 1:
                            Pij_mask[idx, at_risk_batch[idx]:] = 1

                    # run optimizer and fetch cost
                    feed_dict = {
                        self.graph.X_input: x_batch,
                        self.graph.Pij_mask: Pij_mask,
                        self.graph.ALPHA: graph_hyperparams['ALPHA'],
                        self.graph.LAMBDA: graph_hyperparams['LAMBDA'],
                        self.graph.SIGMA: graph_hyperparams['SIGMA'],
                        self.graph.DROPOUT_FRACTION:
                            graph_hyperparams['DROPOUT_FRACTION'],
                    }
                    _, cost = sess.run(
                        [self.graph.optimizer, self.graph.cost],
                        feed_dict=feed_dict)

                    # normalize cost for sample size
                    cost_tot += cost / len(batch)

                # Now get final NCA matrix (without dropout)
                #==========================================================
                feed_dict[self.graph.DROPOUT_FRACTION] = 0
                W_grabbed = self.graph.W.eval(feed_dict=feed_dict)

                # Get Ci for training/validation set
                #==========================================================
                # transform
                x_train_transformed = np.dot(features, W_grabbed)
                if USE_VALID:
                    x_valid_transformed = np.dot(features_valid, W_grabbed)

                # get neighbor indices
                neighbor_idxs_train = knnmodel._get_neighbor_idxs(
                    x_train_transformed, x_train_transformed, norm=norm)
                if USE_VALID:
                    neighbor_idxs_valid = knnmodel._get_neighbor_idxs(
                        x_valid_transformed, x_train_transformed, norm=norm)

                # Predict training/validation set
                _, Ci_train = knnmodel.predict(
                    neighbor_idxs_train,
                    Survival_train=survival, Censored_train=censored,
                    Survival_test=survival, Censored_test=censored,
                    K=K, Method=Method)
                if USE_VALID:
                    _, Ci_valid = knnmodel.predict(
                        neighbor_idxs_valid,
                        Survival_train=survival, Censored_train=censored,
                        Survival_test=survival_valid,
                        Censored_test=censored_valid,
                        K=K, Method=Method)
                else:
                    Ci_valid = 0

                if MONITOR:
                    print("\t{}\t{}\t{}\t{}".format(
                        self.EPOCHS_RUN, round(cost_tot, 3),
                        round(Ci_train, 3), round(Ci_valid, 3)))

                # Update and save
                #==========================================================
                self.EPOCHS_RUN += 1
                self.Costs_epochLevel_train.append(cost_tot)
                self.CIs_train.append(Ci_train)
                self.CIs_valid.append(Ci_valid)

                # periodically save model (disabled)
                #if (self.EPOCHS_RUN % MODEL_SAVE_STEP) == 0:
                #    _saveTFmodel()

                # periodically monitor progress
                if MONITOR:
                    if (self.EPOCHS_RUN % PLOT_STEP == 0) and \
                            (self.EPOCHS_RUN > 0):
                        _monitorProgress()

                # Early stopping
                #==========================================================
                if EARLY_STOPPING:

                    # Save snapshot (ring buffer keyed by iteration number)
                    Ws[:, :, itir % MODEL_BUFFER] = W_grabbed
                    Cis.append(Ci_valid)

                    # Stop when overfitting starts to occur
                    if len(Cis) > (2 * MODEL_BUFFER):
                        ci_new = np.mean(Cis[-MODEL_BUFFER:])
                        ci_old = np.mean(Cis[-2 * MODEL_BUFFER:-MODEL_BUFFER])
                        if ci_new < ci_old:
                            # Revert to the oldest snapshot still buffered
                            # (iteration itir - MODEL_BUFFER + 1)
                            snapshot_idx = \
                                (itir - MODEL_BUFFER + 1) % MODEL_BUFFER
                            W_grabbed = Ws[:, :, snapshot_idx]
                            stopped_early = True
                            break

        except KeyboardInterrupt:
            # Deliberate: Ctrl-C stops training early; fall through to the
            # final monitoring and saving below.
            pass

        if MONITOR:
            # plot costs; mark the restored snapshot only when early
            # stopping actually fired (otherwise the last W is returned)
            snapshot = itir - MODEL_BUFFER if stopped_early else None
            _monitorProgress(snapshot_idx=snapshot)

        # save learned weights
        np.save(self.RESULTPATH + 'model/' + self.description +
                self.timestamp + 'NCA_matrix.npy', W_grabbed)

        return W_grabbed
#============================================================================== if GETLOGS == True: # Separate out validation set N_tot = np.size(Features, 0) Features_valid = Features[int(PERC_TRAIN * N_tot):N_tot, :] Survival_valid = Survival[int(PERC_TRAIN * N_tot):N_tot] Censored_valid = Censored[int(PERC_TRAIN * N_tot):N_tot] Features = Features[0:int(PERC_TRAIN * N_tot), :] Survival = Survival[0:int(PERC_TRAIN * N_tot)] Censored = Censored[0:int(PERC_TRAIN * N_tot)] # Getting at-risk groups (validation set) Features_valid, Survival_valid, Observed_valid, at_risk_valid = \ sUtils.calc_at_risk(Features_valid, Survival_valid, 1-Censored_valid) # Getting at-risk groups (trainign set) Features, Survival, Observed, at_risk = \ sUtils.calc_at_risk(Features, Survival, 1-Censored) #%%============================================================================ # Setting params and other stuff #============================================================================== # Convert to integer/bool (important for BayesOpt to work properly since it # tries float values) EPOCHS = int(EPOCHS) DEPTH = int(DEPTH) MAXWIDTH = int(MAXWIDTH)