dl = DataLoader(dataset=data_train,
                collate_fn=data_utils.custom_collate_fn,
                shuffle=True,
                batch_size=100,
                num_workers=4)
dl_val = DataLoader(dataset=data_val,
                    collate_fn=data_utils.custom_collate_fn,
                    shuffle=False,
                    batch_size=len(data_val),
                    num_workers=1)

## the neural negative feedback with observation jumps
model = gru_ode_bayes.NNFOwithBayesianJumps(
    input_size=params_dict["input_size"],
    hidden_size=params_dict["hidden_size"],
    p_hidden=params_dict["p_hidden"],
    prep_hidden=params_dict["prep_hidden"],
    logvar=params_dict["logvar"],
    mixing=params_dict["mixing"],
    full_gru_ode=params_dict["full_gru_ode"],
    solver=params_dict["solver"],
    impute=params_dict["impute"])
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0005)
epoch_max = 200

# Training
for epoch in range(epoch_max):
    model.train()
    optimizer.zero_grad()
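    # Illustrative sketch (an assumption, not the script's verbatim loop body,
    # which is truncated here): one plausible training step, mirroring the
    # batch format and forward signature of train_gruode_mimic below.
    # "delta_t" and "T" are assumed to be entries of params_dict, as in that
    # function.
    for i, b in enumerate(dl):
        times = b["times"]
        time_ptr = b["time_ptr"]
        X = b["X"].to(device)
        M = b["M"].to(device)
        obs_idx = b["obs_idx"]
        cov = b["cov"].to(device)
        # Forward pass: propagate the hidden state with the ODE between
        # observations and apply Bayesian jumps at observation times.
        hT, loss, _, _, _ = model(times, time_ptr, X, M, obs_idx,
                                  delta_t=params_dict["delta_t"],
                                  T=params_dict["T"], cov=cov)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()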
def train_gruode_mimic(simulation_name, params_dict, device, train_idx,
                       val_idx, test_idx, epoch_max=40, binned60=False):
    dir_path = r"D:/mimic_iii/clean_data/"
    if binned60:
        ##csv_file_path = "../../Datasets/MIMIC/Processed_MIMIC60.csv"
        ##csv_file_tags = "../../Datasets/MIMIC/MIMIC_tags60.csv"
        csv_file_path = dir_path + r"GRU_ODE_Dataset.csv"
        csv_file_tags = dir_path + r"GRU_ODE_death_tags.csv"
    else:
        ##csv_file_path = "../../Datasets/MIMIC/Processed_MIMIC.csv"
        ##csv_file_tags = "../../Datasets/MIMIC/MIMIC_tags.csv"
        #print("Todo: make simplified Dataset")
        #sys.exit()
        csv_file_path = dir_path + r"GRU_ODE_Dataset.csv"
        csv_file_tags = dir_path + r"GRU_ODE_death_tags.csv"

    if params_dict["no_cov"]:
        csv_file_cov = None
    else:
        if binned60:
            #csv_file_cov = "../../Datasets/MIMIC/MIMIC_covs60.csv"
            csv_file_cov = dir_path + r"GRU_ODE_covariates.csv"
        else:
            #csv_file_cov = "../../Datasets/MIMIC/MIMIC_covs.csv"
            csv_file_cov = dir_path + r"GRU_ODE_covariates.csv"

    N = pd.read_csv(csv_file_tags)["ID"].nunique()

    # lambda == 0 means pure regression (forecasting); a non-zero lambda also
    # trains the binary mortality classifier.
    if params_dict["lambda"] == 0:
        validation = True
        val_options = {"T_val": 75, "max_val_samples": 3}
        #logger = Logger(f'../../Logs/Regression/{simulation_name}')
        logger = Logger(
            f'D:/mimic_iii/clean_data/Logs/gru_ode_bayse/Regression/{simulation_name}')
    else:
        validation = False
        val_options = None
        #logger = Logger(f'../../Logs/Classification/{simulation_name}')
        logger = Logger(
            f'D:/mimic_iii/clean_data/Logs/gru_ode_bayse/Classification/{simulation_name}')

    data_train = data_utils.ODE_Dataset(csv_file=csv_file_path,
                                        label_file=csv_file_tags,
                                        cov_file=csv_file_cov,
                                        idx=train_idx)
    data_val = data_utils.ODE_Dataset(csv_file=csv_file_path,
                                      label_file=csv_file_tags,
                                      cov_file=csv_file_cov,
                                      idx=val_idx,
                                      validation=validation,
                                      val_options=val_options)
    data_test = data_utils.ODE_Dataset(csv_file=csv_file_path,
                                       label_file=csv_file_tags,
                                       cov_file=csv_file_cov,
                                       idx=test_idx,
                                       validation=validation,
                                       val_options=val_options)

    dl = DataLoader(dataset=data_train,
                    collate_fn=data_utils.custom_collate_fn,
                    shuffle=True,
                    batch_size=150,
                    num_workers=5)
    dl_val = DataLoader(dataset=data_val,
                        collate_fn=data_utils.custom_collate_fn,
                        shuffle=True,
                        batch_size=150)
    dl_test = DataLoader(dataset=data_test,
                         collate_fn=data_utils.custom_collate_fn,
                         shuffle=True,
                         batch_size=150)  #len(test_idx))

    params_dict["input_size"] = data_train.variable_num
    params_dict["cov_size"] = data_train.cov_dim

    model_dir = dir_path + "trained_models/"
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    np.save(model_dir + f"/{simulation_name}_params.npy", params_dict)

    nnfwobj = gru_ode_bayes.NNFOwithBayesianJumps(
        input_size=params_dict["input_size"],
        hidden_size=params_dict["hidden_size"],
        p_hidden=params_dict["p_hidden"],
        prep_hidden=params_dict["prep_hidden"],
        logvar=params_dict["logvar"],
        mixing=params_dict["mixing"],
        classification_hidden=params_dict["classification_hidden"],
        cov_size=params_dict["cov_size"],
        cov_hidden=params_dict["cov_hidden"],
        dropout_rate=params_dict["dropout_rate"],
        full_gru_ode=params_dict["full_gru_ode"],
        impute=params_dict["impute"])
    nnfwobj.to(device)

    optimizer = torch.optim.Adam(nnfwobj.parameters(),
                                 lr=params_dict["lr"],
                                 weight_decay=params_dict["weight_decay"])
    class_criterion = torch.nn.BCEWithLogitsLoss(reduction='sum')

    print("Start Training")
    val_metric_prev = -1000
    for epoch in range(epoch_max):
        nnfwobj.train()
        total_train_loss = 0
        auc_total_train = 0
        for i, b in enumerate(tqdm.tqdm(dl)):
            optimizer.zero_grad()
            times = b["times"]
            time_ptr = b["time_ptr"]
            X = b["X"].to(device)
            M = b["M"].to(device)
            obs_idx = b["obs_idx"]
            cov = b["cov"].to(device)
            labels = b["y"].to(device)
            batch_size = labels.size(0)

            h0 = 0  # torch.zeros(labels.shape[0], params_dict["hidden_size"]).to(device)
            hT, loss, class_pred, loss_pre, loss_post = nnfwobj(
                times, time_ptr, X, M, obs_idx,
                delta_t=params_dict["delta_t"], T=params_dict["T"], cov=cov)
            total_loss = (loss + params_dict["lambda"] *
                          class_criterion(class_pred, labels)) / batch_size
            total_train_loss += total_loss
            try:
                auc_total_train += roc_auc_score(
                    labels.detach().cpu(),
                    torch.sigmoid(class_pred).detach().cpu())
            except ValueError:
                print("Single class in the batch: AUC is erroneous")

            total_loss.backward()
            optimizer.step()

        info = {
            'training_loss': total_train_loss.detach().cpu().numpy() / (i + 1),
            'AUC_training': auc_total_train / (i + 1)
        }
        for tag, value in info.items():
            logger.scalar_summary(tag, value, epoch)

        data_utils.adjust_learning_rate(optimizer, epoch, params_dict["lr"])

        with torch.no_grad():
            nnfwobj.eval()
            total_loss_val = 0
            auc_total_val = 0
            loss_val = 0
            mse_val = 0
            corr_val = 0
            pre_jump_loss = 0
            post_jump_loss = 0
            num_obs = 0
            for i, b in enumerate(dl_val):
                times = b["times"]
                time_ptr = b["time_ptr"]
                X = b["X"].to(device)
                M = b["M"].to(device)
                obs_idx = b["obs_idx"]
                cov = b["cov"].to(device)
                labels = b["y"].to(device)
                batch_size = labels.size(0)
                if b["X_val"] is not None:
                    X_val = b["X_val"].to(device)
                    M_val = b["M_val"].to(device)
                    times_val = b["times_val"]
                    times_idx = b["index_val"]

                h0 = 0  # torch.zeros(labels.shape[0], params_dict["hidden_size"]).to(device)
                hT, loss, class_pred, t_vec, p_vec, h_vec, _, _, loss1, loss2 = nnfwobj(
                    times, time_ptr, X, M, obs_idx,
                    delta_t=params_dict["delta_t"], T=params_dict["T"],
                    cov=cov, return_path=True)
                total_loss = (loss + params_dict["lambda"] *
                              class_criterion(class_pred, labels)) / batch_size
                try:
                    auc_val = roc_auc_score(labels.cpu(),
                                            torch.sigmoid(class_pred).cpu())
                except ValueError:
                    auc_val = 0.5
                    print("Only one class in the batch: AUC is erroneous")

                if params_dict["lambda"] == 0:
                    # Round floating point errors in the time vector.
                    t_vec = np.around(
                        t_vec,
                        str(params_dict["delta_t"])[::-1].find('.')).astype(np.float32)
                    p_val = data_utils.extract_from_path(t_vec, p_vec,
                                                         times_val, times_idx)
                    m, v = torch.chunk(p_val, 2, dim=1)
                    last_loss = (data_utils.log_lik_gaussian(X_val, m, v) *
                                 M_val).sum()
                    mse_loss = (torch.pow(X_val - m, 2) * M_val).sum()
                    corr_val_loss = data_utils.compute_corr(X_val, m, M_val)

                    loss_val += last_loss.cpu().numpy()
                    num_obs += M_val.sum().cpu().numpy()
                    mse_val += mse_loss.cpu().numpy()
                    corr_val += corr_val_loss.cpu().numpy()
                else:
                    num_obs = 1

                pre_jump_loss += loss1.cpu().detach().numpy()
                post_jump_loss += loss2.cpu().detach().numpy()
                total_loss_val += total_loss.cpu().detach().numpy()
                auc_total_val += auc_val

            loss_val /= num_obs
            mse_val /= num_obs
            info = {
                'validation_loss': total_loss_val / (i + 1),
                'AUC_validation': auc_total_val / (i + 1),
                'loglik_loss': loss_val,
                'validation_mse': mse_val,
                'correlation_mean': np.nanmean(corr_val),
                'correlation_max': np.nanmax(corr_val),
                'correlation_min': np.nanmin(corr_val),
                'pre_jump_loss': pre_jump_loss / (i + 1),
                'post_jump_loss': post_jump_loss / (i + 1)
            }
            for tag, value in info.items():
                logger.scalar_summary(tag, value, epoch)

            if params_dict["lambda"] == 0:
                val_metric = -loss_val
            else:
                val_metric = auc_total_val / (i + 1)

            if val_metric > val_metric_prev:
                print(f"New highest validation metric reached: {val_metric}")
                print("Saving Model")
                torch.save(nnfwobj.state_dict(),
                           dir_path + f"trained_models/{simulation_name}_MAX.pt")
                val_metric_prev = val_metric
                test_loglik, test_auc, test_mse = test_evaluation(
                    nnfwobj, params_dict, class_criterion, device, dl_test)
                print(f"Test loglik loss at epoch {epoch} : {test_loglik}")
                print(f"Test AUC loss at epoch {epoch} : {test_auc}")
                print(f"Test MSE loss at epoch {epoch} : {test_mse}")
            else:
                if epoch % 10:
                    torch.save(nnfwobj.state_dict(),
                               dir_path + f"trained_models/{simulation_name}.pt")

            print(f"Total validation loss at epoch {epoch}: {total_loss_val/(i+1)}")
            print(f"Validation AUC at epoch {epoch}: {auc_total_val/(i+1)}")
            print(f"Validation loss (loglik) at epoch {epoch}: {loss_val:.5f}. "
                  f"MSE : {mse_val:.5f}. Correlation : {np.nanmean(corr_val):.5f}. "
                  f"Num obs = {num_obs}")

    print(f"Finished training GRU-ODE for MIMIC. Saved in /trained_models/{simulation_name}")
    return (info, val_metric_prev, test_loglik, test_auc, test_mse)
def plot_trained_model(model_name="paper_random_r", format_image="pdf",
                       random_r=True, max_lag=0, jitter=0,
                       random_theta=False, data_type="double_OU"):
    style = "fill"
    summary_dict = np.load(f"./../trained_models/{model_name}_params.npy",
                           allow_pickle=True).item()
    params_dict = summary_dict["model_params"]
    metadata = summary_dict["metadata"]
    if isinstance(params_dict, np.ndarray):
        ## converting np array to dictionary:
        params_dict = params_dict.tolist()

    # Loading model
    model = gru_ode_bayes.NNFOwithBayesianJumps(
        input_size=params_dict["input_size"],
        hidden_size=params_dict["hidden_size"],
        p_hidden=params_dict["p_hidden"],
        prep_hidden=params_dict["prep_hidden"],
        logvar=params_dict["logvar"],
        mixing=params_dict["mixing"],
        full_gru_ode=params_dict["full_gru_ode"],
        impute=params_dict["impute"],
        solver=params_dict["solver"],
        store_hist=True)
    model.load_state_dict(torch.load(f"./../trained_models/{model_name}.pt"))
    model.eval()

    # Test data:
    N = 10
    T = metadata["T"]
    delta_t = metadata["delta_t"]
    theta = metadata.pop("theta", None)
    sigma = metadata["sigma"]
    rho = metadata["rho"]
    r_mu = metadata.pop("r_mu", None)
    sample_rate = metadata["sample_rate"]
    sample_rate = 1  # override the stored sample rate for plotting
    dual_sample_rate = metadata["dual_sample_rate"]
    r_std = metadata.pop("r_std", None)
    #print(f"R std :{r_std}")
    max_lag = metadata.pop("max_lag", None)

    if data_type == "double_OU":
        T = 6
        df = double_OU.OU_sample(T=T,
                                 dt=delta_t,
                                 N=N,
                                 sigma=0.1,
                                 theta=theta,
                                 r_mu=r_mu,
                                 r_std=r_std,
                                 rho=rho,
                                 sample_rate=sample_rate,
                                 dual_sample_rate=dual_sample_rate,
                                 max_lag=max_lag,
                                 random_theta=random_theta,
                                 full=True,
                                 seed=432)
        ## 10 observation time-points, split across the two dimensions
        times_1 = [1.0, 2.0, 4.0, 5.0, 7.0, 7.5]
        times_2 = [2.0, 3.0, 4.0, 6.0]
    else:
        df = gru_ode_bayes.datasets.BXLator.datagen.BXL_sample(
            T=metadata["T"],
            dt=metadata["delta_t"],
            N=N,
            sigma=metadata["sigma"],
            a=0.3,
            b=1.4,
            rho=metadata["rho"],
            sample_rate=10,
            dual_sample_rate=1,
            full=True)
        ## observation time-points for the two dimensions
        times_1 = [2.0, 5.0, 12.0, 15.0, 23.0, 32.0, 35.0, 41.0, 43.0]
        times_2 = [1.0, 7.0, 12.0, 15.0, 25.0, 32.0, 38.0, 45.0]

    times = np.union1d(times_1, times_2)
    obs = df.loc[df["Time"].isin(times)].copy()
    obs[["Mask_1", "Mask_2"]] = 0
    obs.loc[df["Time"].isin(times_1), "Mask_1"] = 1
    obs.loc[df["Time"].isin(times_2), "Mask_2"] = 1

    data = data_utils.ODE_Dataset(panda_df=obs, jitter_time=jitter)
    dl = DataLoader(dataset=data,
                    collate_fn=data_utils.custom_collate_fn,
                    shuffle=False,
                    batch_size=1)

    with torch.no_grad():
        for sample, b in enumerate(dl):
            times = b["times"]
            time_ptr = b["time_ptr"]
            X = b["X"]
            M = b["M"]
            obs_idx = b["obs_idx"]
            cov = b["cov"]
            y = b["y"]

            hT, loss, _, t_vec, p_vec, _, eval_times, eval_vals = model(
                times, time_ptr, X, M, obs_idx,
                delta_t=delta_t, T=T, cov=cov, return_path=True)
            if params_dict["solver"] == "dopri5":
                p_vec = eval_vals
                t_vec = eval_times.cpu().numpy()

            observations = X.detach().numpy()
            m, v = torch.chunk(p_vec[:, 0, :], 2, dim=1)
            if params_dict["logvar"]:
                up = m + torch.exp(0.5 * v) * 1.96
                down = m - torch.exp(0.5 * v) * 1.96
            else:
                up = m + torch.sqrt(v) * 1.96
                down = m - torch.sqrt(v) * 1.96

            plots_dict = dict()
            plots_dict["t_vec"] = t_vec
            plots_dict["up"] = up.numpy()
            plots_dict["down"] = down.numpy()
            plots_dict["m"] = m.numpy()
            plots_dict["observations"] = observations
            plots_dict["mask"] = M.cpu().numpy()

            fill_colors = [cm.Blues(0.25), cm.Greens(0.25)]
            line_colors = [cm.Blues(0.6), cm.Greens(0.6)]
            colors = ["blue", "green"]

            ## sde trajectory
            df_i = df.query(f"ID == {sample}")

            plt.figure(figsize=(6.4, 4.8))
            if style == "fill":
                for dim in range(2):
                    plt.fill_between(x=t_vec,
                                     y1=down[:, dim].numpy(),
                                     y2=up[:, dim].numpy(),
                                     facecolor=fill_colors[dim],
                                     alpha=1.0,
                                     zorder=1)
                    plt.plot(t_vec, m[:, dim].numpy(),
                             color=line_colors[dim], linewidth=2, zorder=2,
                             label=f"Dimension {dim+1}")
                    observed_idx = np.where(plots_dict["mask"][:, dim] == 1)[0]
                    plt.scatter(times[observed_idx],
                                observations[observed_idx, dim],
                                color=colors[dim], alpha=0.5, s=60)
                    plt.plot(df_i.Time, df_i[f"Value_{dim+1}"], ":",
                             color=colors[dim], linewidth=1.5, alpha=0.8,
                             label="_nolegend_")
            else:
                for dim in range(2):
                    plt.plot(t_vec, up[:, dim].numpy(), "--",
                             color="red", linewidth=2)
                    plt.plot(t_vec, down[:, dim].numpy(), "--",
                             color="red", linewidth=2)
                    plt.plot(t_vec, m[:, dim].numpy(),
                             color=colors[dim], linewidth=2)
                    observed_idx = np.where(plots_dict["mask"][:, dim] == 1)[0]
                    plt.scatter(times[observed_idx],
                                observations[observed_idx, dim],
                                color=colors[dim], alpha=0.5, s=60)
                    plt.plot(df_i.Time, df_i[f"Value_{dim+1}"], ":",
                             color=colors[dim], linewidth=1.5, alpha=0.8)

            #plt.title("Test trajectory of a double OU process")
            plt.xlabel("Time")
            plt.grid()
            plt.legend(loc="lower right")
            plt.ylabel("Prediction (+/- 1.96 st. dev)")
            fname = f"{model_name}_sample{sample}_{style}.{format_image}"
            plt.tight_layout()
            plt.savefig(fname)
            plt.close()
            print(f"Saved sample into '{fname}'.")
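if __name__ == "__main__":
    # Illustrative usage sketch (an assumption, not part of the original
    # script): plot the filtered trajectories of a previously trained
    # double-OU model. This presumes the "{model_name}_params.npy" and
    # "{model_name}.pt" files saved during training exist under
    # ./../trained_models/.
    plot_trained_model(model_name="paper_random_r",
                       format_image="pdf",
                       jitter=0,
                       data_type="double_OU")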