def test_evaluation(model, params_dict, class_criterion, device, dl_test):
    """Evaluate ``model`` on ``dl_test`` and return averaged test metrics.

    Every batch is run through the model with ``return_path=True``.

    In regression mode (``params_dict["lambda"] == 0``), the predicted path is
    sampled at each batch's held-out times via ``data_utils.extract_from_path``.
    The masked ``data_utils.log_lik_gaussian`` value and the masked squared
    error are summed over batches and divided by the number of observed
    held-out entries (``M_val.sum()``). In classification mode these two
    metrics are returned as 0.0.

    Args:
        model: GRU-ODE-Bayes network. Called with ``return_path=True``; its
            10-value output is unpacked below.
        params_dict: hyper-parameters. Reads ``"lambda"``, ``"delta_t"`` and
            ``"T"``.
        class_criterion: classification loss. Not needed for any returned
            metric; kept so existing call sites keep working.
        device: device the batch tensors are moved to.
        dl_test: iterable of collated batch dicts (keys ``"times"``,
            ``"time_ptr"``, ``"X"``, ``"M"``, ``"obs_idx"``, ``"cov"``,
            ``"y"``, ``"X_val"``, ``"M_val"``, ``"times_val"``,
            ``"index_val"``).

    Returns:
        ``(loss_test, auc_test, mse_test)``.
        ``auc_test`` is the mean per-batch ROC AUC over the batches where it
        is defined, or NaN if there are none. In regression mode
        ``loss_test`` and ``mse_test`` are NaN when no held-out entry was
        observed.
    """
    is_regression = params_dict["lambda"] == 0
    delta_t = params_dict["delta_t"]
    # Number of decimals in delta_t, used to strip float noise from t_vec.
    # Clamped at 0: for an integer step, find('.') is -1, which would round
    # the time vector to tens.
    time_decimals = max(str(delta_t)[::-1].find("."), 0)

    loss_test = 0.0
    mse_test = 0.0
    num_obs = 0
    auc_sum = 0.0
    auc_batches = 0

    with torch.no_grad():
        model.eval()
        for b in dl_test:
            labels = b["y"].to(device)
            _, _, class_pred, t_vec, p_vec, _, _, _, _, _ = model(
                b["times"],
                b["time_ptr"],
                b["X"].to(device),
                b["M"].to(device),
                b["obs_idx"],
                delta_t=delta_t,
                T=params_dict["T"],
                cov=b["cov"].to(device),
                return_path=True,
            )

            try:
                auc_sum += roc_auc_score(labels.cpu(), torch.sigmoid(class_pred).cpu())
                auc_batches += 1
            except ValueError:
                # AUC is undefined for a single-class batch: leave it out of
                # the mean instead of averaging in a made-up value.
                print("Only one class. AUC is wrong")

            # Regression metrics need *this* batch's held-out targets; a batch
            # without them must not reuse tensors left over from an earlier one.
            if not is_regression or b["X_val"] is None:
                continue

            X_val = b["X_val"].to(device)
            M_val = b["M_val"].to(device)
            # Round floating points error in the time vector.
            t_vec = np.around(t_vec, time_decimals).astype(np.float32)
            p_val = data_utils.extract_from_path(t_vec, p_vec, b["times_val"], b["index_val"])
            m, v = torch.chunk(p_val, 2, dim=1)
            loss_test += (data_utils.log_lik_gaussian(X_val, m, v) * M_val).sum().cpu().numpy()
            mse_test += (torch.pow(X_val - m, 2) * M_val).sum().cpu().numpy()
            num_obs += M_val.sum().cpu().numpy()

    if is_regression:
        if num_obs > 0:
            loss_test /= num_obs
            mse_test /= num_obs
        else:
            loss_test = mse_test = float("nan")
    auc_test = auc_sum / auc_batches if auc_batches else float("nan")
    return (loss_test, auc_test, mse_test)
M, obs_idx, delta_t=delta_t, T=T, cov=cov, return_path=True) t_vec = np.around( t_vec, str(delta_t)[::-1].find('.')).astype( np.float32 ) #Round floating points error in the time vector. p_val = data_utils.extract_from_path(t_vec, p_vec, times_val, times_idx) m, v = torch.chunk(p_val, 2, dim=1) last_loss = (data_utils.log_lik_gaussian(X_val, m, v) * M_val).sum() mse_loss = (torch.pow(X_val - m, 2) * M_val).sum() loss_val += last_loss.cpu().numpy() mse_val += mse_loss.cpu().numpy() num_obs += M_val.sum().cpu().numpy() loss_val /= num_obs mse_val /= num_obs print( f"Mean validation loss at epoch {epoch}: nll={loss_val:.5f}, mse={mse_val:.5f} (num_obs={num_obs})" ) print(f"Last validation log likelihood : {loss_val}") print(f"Last validation MSE : {mse_val}")
def train_gruode_mimic(simulation_name, params_dict, device, train_idx, val_idx,
                       test_idx, epoch_max=40, binned60=False,
                       dir_path=r"D:/mimic_iii/clean_data/"):
    """Train a GRU-ODE-Bayes model on the preprocessed MIMIC-III CSV files.

    ``params_dict["lambda"] == 0`` selects regression mode. In that mode the
    validation and test datasets are built with ``validation=True`` and
    ``val_options={"T_val": 75, "max_val_samples": 3}``, and the best model is
    chosen by ``-loss_val``. Otherwise the best model is chosen by the mean
    per-batch validation AUC. Each time the validation metric improves, the
    model is saved as ``<dir_path>/trained_models/<simulation_name>_MAX.pt``
    and evaluated on the test split with ``test_evaluation``.

    Args:
        simulation_name: tag used in log and checkpoint file names.
        params_dict: hyper-parameters. Mutated: ``"input_size"`` and
            ``"cov_size"`` are filled in from the training dataset.
        device: torch device for the model and batches.
        train_idx, val_idx, test_idx: subject indices forwarded to
            ``data_utils.ODE_Dataset`` as ``idx``.
        epoch_max: number of training epochs.
        binned60: kept for interface compatibility. Both variants currently
            read the same files.
        dir_path: root folder holding the CSV inputs. ``Logs/`` and
            ``trained_models/`` are written under it.

    Returns:
        ``(info, val_metric_prev, test_loglik, test_auc, test_mse)``.
        ``info`` is the validation-metric dict of the last epoch (``{}`` if
        ``epoch_max == 0``). ``val_metric_prev`` is the best validation metric
        (``-inf`` if none was finite). The test metrics are those from the
        best epoch, or NaN if the metric never improved.
    """
    # binned60 used to switch between two file sets that are now identical.
    csv_file_path = os.path.join(dir_path, "GRU_ODE_Dataset.csv")
    csv_file_tags = os.path.join(dir_path, "GRU_ODE_death_tags.csv")
    csv_file_cov = None if params_dict["no_cov"] else os.path.join(dir_path, "GRU_ODE_covariates.csv")

    is_regression = params_dict["lambda"] == 0
    if is_regression:
        validation = True
        val_options = {"T_val": 75, "max_val_samples": 3}
        log_subdir = "Regression"
    else:
        validation = False
        val_options = None
        log_subdir = "Classification"
    logger = Logger(os.path.join(dir_path, "Logs", "gru_ode_bayse", log_subdir, simulation_name))

    data_files = {"csv_file": csv_file_path, "label_file": csv_file_tags, "cov_file": csv_file_cov}
    data_train = data_utils.ODE_Dataset(**data_files, idx=train_idx)
    data_val = data_utils.ODE_Dataset(**data_files, idx=val_idx,
                                      validation=validation, val_options=val_options)
    data_test = data_utils.ODE_Dataset(**data_files, idx=test_idx,
                                       validation=validation, val_options=val_options)

    collate = data_utils.custom_collate_fn
    dl = DataLoader(dataset=data_train, collate_fn=collate, shuffle=True,
                    batch_size=150, num_workers=5)
    # Evaluation metrics are averaged per batch, so evaluation loaders are not
    # shuffled. Otherwise model selection would drift with the RNG.
    dl_val = DataLoader(dataset=data_val, collate_fn=collate, shuffle=False, batch_size=150)
    dl_test = DataLoader(dataset=data_test, collate_fn=collate, shuffle=False, batch_size=150)

    params_dict["input_size"] = data_train.variable_num
    params_dict["cov_size"] = data_train.cov_dim

    model_dir = os.path.join(dir_path, "trained_models")
    os.makedirs(model_dir, exist_ok=True)
    np.save(os.path.join(model_dir, f"{simulation_name}_params.npy"), params_dict)

    nnfwobj = gru_ode_bayes.NNFOwithBayesianJumps(
        input_size=params_dict["input_size"],
        hidden_size=params_dict["hidden_size"],
        p_hidden=params_dict["p_hidden"],
        prep_hidden=params_dict["prep_hidden"],
        logvar=params_dict["logvar"],
        mixing=params_dict["mixing"],
        classification_hidden=params_dict["classification_hidden"],
        cov_size=params_dict["cov_size"],
        cov_hidden=params_dict["cov_hidden"],
        dropout_rate=params_dict["dropout_rate"],
        full_gru_ode=params_dict["full_gru_ode"],
        impute=params_dict["impute"],
    )
    nnfwobj.to(device)
    optimizer = torch.optim.Adam(nnfwobj.parameters(), lr=params_dict["lr"],
                                 weight_decay=params_dict["weight_decay"])
    class_criterion = torch.nn.BCEWithLogitsLoss(reduction="sum")

    delta_t = params_dict["delta_t"]
    # Decimals of delta_t, used to round float noise out of t_vec. Clamped at
    # 0 so an integer step does not round times to tens.
    time_decimals = max(str(delta_t)[::-1].find("."), 0)

    print("Start Training")
    val_metric_prev = float("-inf")
    # Returned as-is if the validation metric never improves (or no epochs run).
    info = {}
    test_loglik = test_auc = test_mse = float("nan")

    for epoch in range(epoch_max):
        # ---- training ----
        nnfwobj.train()
        total_train_loss = 0.0
        auc_total_train = 0.0
        n_train_batches = 0
        n_train_auc = 0
        for b in tqdm.tqdm(dl):
            optimizer.zero_grad()
            labels = b["y"].to(device)
            _, loss, class_pred, _, _ = nnfwobj(
                b["times"], b["time_ptr"], b["X"].to(device), b["M"].to(device), b["obs_idx"],
                delta_t=delta_t, T=params_dict["T"], cov=b["cov"].to(device))
            total_loss = (loss + params_dict["lambda"] * class_criterion(class_pred, labels)) / labels.size(0)
            # .item() detaches: summing the graph-carrying tensor would keep
            # every batch's autograd history alive for the whole epoch.
            total_train_loss += total_loss.item()
            n_train_batches += 1
            try:
                auc_total_train += roc_auc_score(labels.detach().cpu(),
                                                 torch.sigmoid(class_pred).detach().cpu())
                n_train_auc += 1
            except ValueError:
                print("Single CLASS ! AUC is erroneous")
            total_loss.backward()
            optimizer.step()

        info = {
            "training_loss": total_train_loss / max(n_train_batches, 1),
            # Average only over batches where AUC was defined.
            "AUC_training": auc_total_train / n_train_auc if n_train_auc else float("nan"),
        }
        for tag, value in info.items():
            logger.scalar_summary(tag, value, epoch)
        data_utils.adjust_learning_rate(optimizer, epoch, params_dict["lr"])

        # ---- validation ----
        with torch.no_grad():
            nnfwobj.eval()
            total_loss_val = 0.0
            auc_total_val = 0.0
            n_val_batches = 0
            n_val_auc = 0
            loss_val = 0.0
            mse_val = 0.0
            corr_val = 0
            pre_jump_loss = 0.0
            post_jump_loss = 0.0
            num_obs = 0
            for b in dl_val:
                labels = b["y"].to(device)
                _, loss, class_pred, t_vec, p_vec, _, _, _, loss1, loss2 = nnfwobj(
                    b["times"], b["time_ptr"], b["X"].to(device), b["M"].to(device), b["obs_idx"],
                    delta_t=delta_t, T=params_dict["T"], cov=b["cov"].to(device), return_path=True)
                total_loss = (loss + params_dict["lambda"] * class_criterion(class_pred, labels)) / labels.size(0)
                n_val_batches += 1
                try:
                    auc_total_val += roc_auc_score(labels.cpu(), torch.sigmoid(class_pred).cpu())
                    n_val_auc += 1
                except ValueError:
                    # Undefined AUC: excluded from the mean rather than imputed.
                    print("Only one class : AUC is erroneous")

                # Only use held-out targets that belong to this batch.
                if is_regression and b["X_val"] is not None:
                    X_val = b["X_val"].to(device)
                    M_val = b["M_val"].to(device)
                    # Round floating points error in the time vector.
                    t_vec = np.around(t_vec, time_decimals).astype(np.float32)
                    p_val = data_utils.extract_from_path(t_vec, p_vec, b["times_val"], b["index_val"])
                    m, v = torch.chunk(p_val, 2, dim=1)
                    loss_val += (data_utils.log_lik_gaussian(X_val, m, v) * M_val).sum().cpu().numpy()
                    mse_val += (torch.pow(X_val - m, 2) * M_val).sum().cpu().numpy()
                    # NOTE(review): per-batch correlations are summed, not
                    # averaged, so the correlation stats below scale with the
                    # number of batches. Confirm this is intended.
                    corr_val += data_utils.compute_corr(X_val, m, M_val).cpu().numpy()
                    num_obs += M_val.sum().cpu().numpy()

                pre_jump_loss += loss1.cpu().detach().numpy()
                post_jump_loss += loss2.cpu().detach().numpy()
                total_loss_val += total_loss.cpu().detach().numpy()

            if is_regression:
                if num_obs > 0:
                    loss_val /= num_obs
                    mse_val /= num_obs
                else:
                    loss_val = mse_val = float("nan")
            n_val_batches = max(n_val_batches, 1)
            auc_val_mean = auc_total_val / n_val_auc if n_val_auc else float("nan")

            info = {
                "validation_loss": total_loss_val / n_val_batches,
                "AUC_validation": auc_val_mean,
                "loglik_loss": loss_val,
                "validation_mse": mse_val,
                "correlation_mean": np.nanmean(corr_val),
                "correlation_max": np.nanmax(corr_val),
                "correlation_min": np.nanmin(corr_val),
                "pre_jump_loss": pre_jump_loss / n_val_batches,
                "post_jump_loss": post_jump_loss / n_val_batches,
            }
            for tag, value in info.items():
                logger.scalar_summary(tag, value, epoch)

            val_metric = -loss_val if is_regression else auc_val_mean
            if val_metric > val_metric_prev:
                print(f"New highest validation metric reached ! : {val_metric}")
                print("Saving Model")
                torch.save(nnfwobj.state_dict(), os.path.join(model_dir, f"{simulation_name}_MAX.pt"))
                val_metric_prev = val_metric
                test_loglik, test_auc, test_mse = test_evaluation(
                    nnfwobj, params_dict, class_criterion, device, dl_test)
                print(f"Test loglik loss at epoch {epoch} : {test_loglik}")
                print(f"Test AUC loss at epoch {epoch} : {test_auc}")
                print(f"Test MSE loss at epoch{epoch} : {test_mse}")
            elif epoch % 10 == 0:
                # Periodic checkpoint every 10 epochs. The former `epoch % 10`
                # was true on every epoch *except* multiples of 10.
                torch.save(nnfwobj.state_dict(), os.path.join(model_dir, f"{simulation_name}.pt"))

        print(f"Total validation loss at epoch {epoch}: {info['validation_loss']}")
        print(f"Validation AUC at epoch {epoch}: {auc_val_mean}")
        print(
            f"Validation loss (loglik) at epoch {epoch}: {loss_val:.5f}. MSE : {mse_val:.5f}. Correlation : {np.nanmean(corr_val):.5f}. Num obs = {num_obs}"
        )

    print(f"Finished training GRU-ODE for MIMIC. Saved in /trained_models/{simulation_name}")
    return (info, val_metric_prev, test_loglik, test_auc, test_mse)
p_val = data_utils.extract_from_path(t_vec, p_vec, eval_times=times) print(p_val.shape) m, v = torch.chunk(p_val, 2, dim=1) m = m[:, :, 0].squeeze(2).flatten() v = v[:, :, 0].squeeze(2).flatten() # predicted_mean = p_vec[:,:,0] # predicted_std = p_vec[:,:,-1] #validation할 time은 times에 대해 하면 되지! -> # t_vec = np.around(t_vec,str(delta_t)[::-1].find('.')).astype(np.float32) #Round floating points error in the time vector. # # m, v = torch.chunk(p_val,2,dim=1) last_loss = (data_utils.log_lik_gaussian(X, m, v)).sum() mse_loss = (torch.pow(X - m, 2)).sum() loss_val += last_loss.cpu().numpy() mse_val += mse_loss.cpu().numpy() num_obs += len(times) * batch_size #M_val.sum().cpu().numpy() loss_val /= num_obs mse_val /= num_obs print( f"Mean validation loss at epoch {epoch}: nll={loss_val:.5f}, mse={mse_val:.5f} (num_obs={num_obs})" ) print(f"Last validation log likelihood : {loss_val}") print(f"Last validation MSE : {mse_val}") df_file_name = "./../trained_models/1D-periodic.csv"