Example #1
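This fragment assumes the imports, the datasets (data_train, data_val), and a params_dict defined earlier in the script. A hypothetical sketch of that params_dict (illustrative values only, not taken from the original):

# Illustrative values only; the original script defines these elsewhere.
params_dict = {
    "input_size": 30, "hidden_size": 50, "p_hidden": 25, "prep_hidden": 10,
    "logvar": True, "mixing": 1e-4, "full_gru_ode": True,
    "solver": "euler", "impute": False,
}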
dl = DataLoader(dataset=data_train,
                collate_fn=data_utils.custom_collate_fn,
                shuffle=True,
                batch_size=100,
                num_workers=4)
dl_val = DataLoader(dataset=data_val,
                    collate_fn=data_utils.custom_collate_fn,
                    shuffle=False,
                    batch_size=len(data_val),
                    num_workers=1)

## the Neural Negative Feedback ODE with Bayesian observation jumps
model = gru_ode_bayes.NNFOwithBayesianJumps(
    input_size=params_dict["input_size"],
    hidden_size=params_dict["hidden_size"],
    p_hidden=params_dict["p_hidden"],
    prep_hidden=params_dict["prep_hidden"],
    logvar=params_dict["logvar"],
    mixing=params_dict["mixing"],
    full_gru_ode=params_dict["full_gru_ode"],
    solver=params_dict["solver"],
    impute=params_dict["impute"])
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0005)
epoch_max = 200

# Training
for epoch in range(epoch_max):
    model.train()
    optimizer.zero_grad()
    # ... the forward pass, loss computation, loss.backward(), and
    # optimizer.step() would follow here, as in Example #2 below.

Example #2

def train_gruode_mimic(simulation_name,
                       params_dict,
                       device,
                       train_idx,
                       val_idx,
                       test_idx,
                       epoch_max=40,
                       binned60=False):

    dir_path = r"D:/mimic_iii/clean_data/"

    # NOTE: the binned60 flag currently has no effect here; both branches of
    # the original resolved to the same cleaned files.
    csv_file_path = dir_path + r"GRU_ODE_Dataset.csv"
    csv_file_tags = dir_path + r"GRU_ODE_death_tags.csv"

    if params_dict["no_cov"]:
        csv_file_cov = None
    else:
        # The covariate file is likewise the same regardless of binned60.
        csv_file_cov = dir_path + r"GRU_ODE_covariates.csv"

    N = pd.read_csv(csv_file_tags)["ID"].nunique()

    if params_dict["lambda"] == 0:
        validation = True
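        # Hold out observations after time T_val (at most max_val_samples
        # per patient) as forecasting targets during validation.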
        val_options = {"T_val": 75, "max_val_samples": 3}
    else:
        validation = False
        val_options = None

    if params_dict["lambda"] == 0:
        #logger = Logger(f'../../Logs/Regression/{simulation_name}')
        logger = Logger(
            f'D:/mimic_iii/clean_data/Logs/gru_ode_bayes/Regression/{simulation_name}'
        )
    else:
        #logger = Logger(f'../../Logs/Classification/{simulation_name}')
        logger = Logger(
            f'D:/mimic_iii/clean_data/Logs/gru_ode_bayes/Classification/{simulation_name}'
        )

    data_train = data_utils.ODE_Dataset(csv_file=csv_file_path,
                                        label_file=csv_file_tags,
                                        cov_file=csv_file_cov,
                                        idx=train_idx)
    data_val = data_utils.ODE_Dataset(csv_file=csv_file_path,
                                      label_file=csv_file_tags,
                                      cov_file=csv_file_cov,
                                      idx=val_idx,
                                      validation=validation,
                                      val_options=val_options)
    data_test = data_utils.ODE_Dataset(csv_file=csv_file_path,
                                       label_file=csv_file_tags,
                                       cov_file=csv_file_cov,
                                       idx=test_idx,
                                       validation=validation,
                                       val_options=val_options)

    dl = DataLoader(dataset=data_train,
                    collate_fn=data_utils.custom_collate_fn,
                    shuffle=True,
                    batch_size=150,
                    num_workers=5)
    dl_val = DataLoader(dataset=data_val,
                        collate_fn=data_utils.custom_collate_fn,
                        shuffle=True,
                        batch_size=150)
    dl_test = DataLoader(dataset=data_test,
                         collate_fn=data_utils.custom_collate_fn,
                         shuffle=True,
                         batch_size=150)  #len(test_idx))

    params_dict["input_size"] = data_train.variable_num
    params_dict["cov_size"] = data_train.cov_dim

    model_dir = dir_path + "trained_models/"

    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    np.save(model_dir + f"{simulation_name}_params.npy", params_dict)

    nnfwobj = gru_ode_bayes.NNFOwithBayesianJumps(
        input_size=params_dict["input_size"],
        hidden_size=params_dict["hidden_size"],
        p_hidden=params_dict["p_hidden"],
        prep_hidden=params_dict["prep_hidden"],
        logvar=params_dict["logvar"],
        mixing=params_dict["mixing"],
        classification_hidden=params_dict["classification_hidden"],
        cov_size=params_dict["cov_size"],
        cov_hidden=params_dict["cov_hidden"],
        dropout_rate=params_dict["dropout_rate"],
        full_gru_ode=params_dict["full_gru_ode"],
        impute=params_dict["impute"])
    nnfwobj.to(device)

    optimizer = torch.optim.Adam(nnfwobj.parameters(),
                                 lr=params_dict["lr"],
                                 weight_decay=params_dict["weight_decay"])
    class_criterion = torch.nn.BCEWithLogitsLoss(reduction='sum')
    print("Start Training")
    val_metric_prev = -1000
    for epoch in range(epoch_max):
        nnfwobj.train()
        total_train_loss = 0
        auc_total_train = 0
        for i, b in enumerate(tqdm.tqdm(dl)):

            optimizer.zero_grad()
            times = b["times"]
            time_ptr = b["time_ptr"]
            X = b["X"].to(device)
            M = b["M"].to(device)
            obs_idx = b["obs_idx"]
            cov = b["cov"].to(device)
            labels = b["y"].to(device)
            batch_size = labels.size(0)

            h0 = 0  # torch.zeros(labels.shape[0], params_dict["hidden_size"]).to(device)
            hT, loss, class_pred, loss_pre, loss_post = nnfwobj(
                times,
                time_ptr,
                X,
                M,
                obs_idx,
                delta_t=params_dict["delta_t"],
                T=params_dict["T"],
                cov=cov)

            total_loss = (loss + params_dict["lambda"] *
                          class_criterion(class_pred, labels)) / batch_size
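            # params_dict["lambda"] weights the death-label classification
            # term against the model's sequence log-likelihood loss.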
            total_train_loss += total_loss.detach()  # detach so the graph is freed each step

            try:
                auc_total_train += roc_auc_score(
                    labels.detach().cpu(),
                    torch.sigmoid(class_pred).detach().cpu())
            except ValueError:
                print("Only one class in this batch; skipping AUC update.")

            total_loss.backward()
            optimizer.step()

        info = {
            'training_loss': total_train_loss.cpu().numpy() / (i + 1),
            'AUC_training': auc_total_train / (i + 1)
        }
        for tag, value in info.items():
            logger.scalar_summary(tag, value, epoch)

        data_utils.adjust_learning_rate(optimizer, epoch, params_dict["lr"])

        with torch.no_grad():
            nnfwobj.eval()
            total_loss_val = 0
            auc_total_val = 0
            loss_val = 0
            mse_val = 0
            corr_val = 0
            pre_jump_loss = 0
            post_jump_loss = 0
            num_obs = 0
            for i, b in enumerate(dl_val):
                times = b["times"]
                time_ptr = b["time_ptr"]
                X = b["X"].to(device)
                M = b["M"].to(device)
                obs_idx = b["obs_idx"]
                cov = b["cov"].to(device)
                labels = b["y"].to(device)
                batch_size = labels.size(0)

                if b["X_val"] is not None:
                    X_val = b["X_val"].to(device)
                    M_val = b["M_val"].to(device)
                    times_val = b["times_val"]
                    times_idx = b["index_val"]

                h0 = 0  #torch.zeros(labels.shape[0], params_dict["hidden_size"]).to(device)
                hT, loss, class_pred, t_vec, p_vec, h_vec, _, _, loss1, loss2 = nnfwobj(
                    times,
                    time_ptr,
                    X,
                    M,
                    obs_idx,
                    delta_t=params_dict["delta_t"],
                    T=params_dict["T"],
                    cov=cov,
                    return_path=True)
                total_loss = (loss + params_dict["lambda"] *
                              class_criterion(class_pred, labels)) / batch_size

                try:
                    auc_val = roc_auc_score(labels.cpu(),
                                            torch.sigmoid(class_pred).cpu())
                except ValueError:
                    auc_val = 0.5
                    print("Only one class in this batch; AUC defaulted to 0.5.")

                if params_dict["lambda"] == 0:
                    # Round away floating-point error in the time vector
                    # (to the number of decimal places in delta_t).
                    n_decimals = str(params_dict["delta_t"])[::-1].find('.')
                    t_vec = np.around(t_vec, n_decimals).astype(np.float32)
                    p_val = data_utils.extract_from_path(
                        t_vec, p_vec, times_val, times_idx)
                    m, v = torch.chunk(p_val, 2, dim=1)
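                    # p_val stacks predicted means and (log-)variances along
                    # dim 1; split them into m and v.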
                    last_loss = (data_utils.log_lik_gaussian(X_val, m, v) *
                                 M_val).sum()
                    mse_loss = (torch.pow(X_val - m, 2) * M_val).sum()
                    corr_val_loss = data_utils.compute_corr(X_val, m, M_val)

                    loss_val += last_loss.cpu().numpy()
                    num_obs += M_val.sum().cpu().numpy()
                    mse_val += mse_loss.cpu().numpy()
                    corr_val += corr_val_loss.cpu().numpy()
                else:
                    num_obs = 1

                pre_jump_loss += loss1.cpu().detach().numpy()
                post_jump_loss += loss2.cpu().detach().numpy()

                total_loss_val += total_loss.cpu().detach().numpy()
                auc_total_val += auc_val

            loss_val /= num_obs
            mse_val /= num_obs
            info = {
                'validation_loss': total_loss_val / (i + 1),
                'AUC_validation': auc_total_val / (i + 1),
                'loglik_loss': loss_val,
                'validation_mse': mse_val,
                'correlation_mean': np.nanmean(corr_val),
                'correlation_max': np.nanmax(corr_val),
                'correlation_min': np.nanmin(corr_val),
                'pre_jump_loss': pre_jump_loss / (i + 1),
                'post_jump_loss': post_jump_loss / (i + 1)
            }
            for tag, value in info.items():
                logger.scalar_summary(tag, value, epoch)

            if params_dict["lambda"] == 0:
                val_metric = -loss_val
            else:
                val_metric = auc_total_val / (i + 1)

            if val_metric > val_metric_prev:
                print(f"New best validation metric reached: {val_metric}")
                print("Saving model")
                torch.save(
                    nnfwobj.state_dict(),
                    dir_path + f"trained_models/{simulation_name}_MAX.pt")
                val_metric_prev = val_metric
                test_loglik, test_auc, test_mse = test_evaluation(
                    nnfwobj, params_dict, class_criterion, device, dl_test)
                print(f"Test loglik at epoch {epoch}: {test_loglik}")
                print(f"Test AUC at epoch {epoch}: {test_auc}")
                print(f"Test MSE at epoch {epoch}: {test_mse}")
            else:
                if epoch % 10 == 0:  # periodic checkpoint every 10 epochs
                    torch.save(
                        nnfwobj.state_dict(),
                        dir_path + f"trained_models/{simulation_name}.pt")

        print(
            f"Total validation loss at epoch {epoch}: {total_loss_val/(i+1)}")
        print(f"Validation AUC at epoch {epoch}: {auc_total_val/(i+1)}")
        print(
            f"Validation loss (loglik) at epoch {epoch}: {loss_val:.5f}. MSE : {mse_val:.5f}. Correlation : {np.nanmean(corr_val):.5f}. Num obs = {num_obs}"
        )

    print(
        f"Finished training GRU-ODE for MIMIC. Saved in /trained_models/{simulation_name}"
    )

    return (info, val_metric_prev, test_loglik, test_auc, test_mse)
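
A minimal invocation sketch. The split logic and every hyperparameter value below are illustrative assumptions, not taken from the original:

# Hypothetical usage of train_gruode_mimic; all values are illustrative.
params_dict = {
    "no_cov": False, "lambda": 0, "hidden_size": 50, "p_hidden": 25,
    "prep_hidden": 10, "logvar": True, "mixing": 1e-4,
    "classification_hidden": 2, "cov_hidden": 50, "dropout_rate": 0.2,
    "full_gru_ode": True, "impute": False,
    "lr": 0.001, "weight_decay": 0.0005, "delta_t": 0.1, "T": 100,
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ids = pd.read_csv("D:/mimic_iii/clean_data/GRU_ODE_death_tags.csv")["ID"].unique()
train_idx, val_idx, test_idx = np.split(
    np.random.permutation(ids), [int(0.8 * len(ids)), int(0.9 * len(ids))])
info, best_val, test_loglik, test_auc, test_mse = train_gruode_mimic(
    "example_run", params_dict, device, train_idx, val_idx, test_idx)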

Example #3

def plot_trained_model(model_name="paper_random_r",
                       format_image="pdf",
                       random_r=True,
                       max_lag=0,
                       jitter=0,
                       random_theta=False,
                       data_type="double_OU"):

    style = "fill"

    summary_dict = np.load(f"./../trained_models/{model_name}_params.npy",
                           allow_pickle=True).item()

    params_dict = summary_dict["model_params"]
    metadata = summary_dict["metadata"]

    if isinstance(params_dict, np.ndarray):
        ## unwrap the 0-d object array back into the stored dictionary
        params_dict = params_dict.tolist()

    #Loading model
    model = gru_ode_bayes.NNFOwithBayesianJumps(
        input_size=params_dict["input_size"],
        hidden_size=params_dict["hidden_size"],
        p_hidden=params_dict["p_hidden"],
        prep_hidden=params_dict["prep_hidden"],
        logvar=params_dict["logvar"],
        mixing=params_dict["mixing"],
        full_gru_ode=params_dict["full_gru_ode"],
        impute=params_dict["impute"],
        solver=params_dict["solver"],
        store_hist=True)

    model.load_state_dict(torch.load(f"./../trained_models/{model_name}.pt"))
    model.eval()

    #Test data :
    N = 10
    T = metadata["T"]
    delta_t = metadata["delta_t"]
    theta = metadata.pop("theta", None)
    sigma = metadata["sigma"]
    rho = metadata["rho"]
    r_mu = metadata.pop("r_mu", None)
    sample_rate = 1  # overrides metadata["sample_rate"] for this plot
    dual_sample_rate = metadata["dual_sample_rate"]
    r_std = metadata.pop("r_std", None)
    #print(f"R std :{r_std}")
    max_lag = metadata.pop("max_lag", None)

    if data_type == "double_OU":
        T = 6
        df = double_OU.OU_sample(T=T,
                                 dt=delta_t,
                                 N=N,
                                 sigma=0.1,  # hardcoded, overriding metadata["sigma"]
                                 theta=theta,
                                 r_mu=r_mu,
                                 r_std=r_std,
                                 rho=rho,
                                 sample_rate=sample_rate,
                                 dual_sample_rate=dual_sample_rate,
                                 max_lag=max_lag,
                                 random_theta=random_theta,
                                 full=True,
                                 seed=432)

        ## 10 observation times in total across the two dimensions
        times_1 = [1.0, 2.0, 4.0, 5.0, 7.0, 7.5]
        times_2 = [2.0, 3.0, 4.0, 6.0]
    else:
        df = gru_ode_bayes.datasets.BXLator.datagen.BXL_sample(
            T=metadata["T"],
            dt=metadata["delta_t"],
            N=N,
            sigma=metadata["sigma"],
            a=0.3,
            b=1.4,
            rho=metadata["rho"],
            sample_rate=10,
            dual_sample_rate=1,
            full=True)

        ## 17 observation times in total across the two dimensions
        times_1 = [2.0, 5.0, 12.0, 15.0, 23.0, 32.0, 35.0, 41.0, 43.0]
        times_2 = [1.0, 7.0, 12.0, 15.0, 25.0, 32.0, 38.0, 45.0]

    times = np.union1d(times_1, times_2)
    obs = df.loc[df["Time"].isin(times)].copy()
    obs[["Mask_1", "Mask_2"]] = 0
    obs.loc[df["Time"].isin(times_1), "Mask_1"] = 1
    obs.loc[df["Time"].isin(times_2), "Mask_2"] = 1

    # Build the dataset directly from the in-memory DataFrame; a non-zero
    # jitter_time presumably perturbs observation times slightly to break
    # ties between simultaneous observations.
    data = data_utils.ODE_Dataset(panda_df=obs, jitter_time=jitter)
    dl = DataLoader(dataset=data,
                    collate_fn=data_utils.custom_collate_fn,
                    shuffle=False,
                    batch_size=1)

    with torch.no_grad():
        for sample, b in enumerate(dl):
            times = b["times"]
            time_ptr = b["time_ptr"]
            X = b["X"]
            M = b["M"]
            obs_idx = b["obs_idx"]
            cov = b["cov"]

            y = b["y"]
            hT, loss, _, t_vec, p_vec, _, eval_times, eval_vals = model(
                times,
                time_ptr,
                X,
                M,
                obs_idx,
                delta_t=delta_t,
                T=T,
                cov=cov,
                return_path=True)

            if params_dict["solver"] == "dopri5":
                p_vec = eval_vals
                t_vec = eval_times.cpu().numpy()

            observations = X.detach().numpy()
            m, v = torch.chunk(p_vec[:, 0, :], 2, dim=1)

            if params_dict["logvar"]:
                up = m + torch.exp(0.5 * v) * 1.96
                down = m - torch.exp(0.5 * v) * 1.96
            else:
                up = m + torch.sqrt(v) * 1.96
                down = m - torch.sqrt(v) * 1.96
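            # up/down bound a 95% confidence band: mean +/- 1.96 std. dev.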

            plots_dict = dict()
            plots_dict["t_vec"] = t_vec
            plots_dict["up"] = up.numpy()
            plots_dict["down"] = down.numpy()
            plots_dict["m"] = m.numpy()
            plots_dict["observations"] = observations
            plots_dict["mask"] = M.cpu().numpy()

            fill_colors = [cm.Blues(0.25), cm.Greens(0.25)]

            line_colors = [cm.Blues(0.6), cm.Greens(0.6)]
            colors = ["blue", "green"]

            ## ground-truth SDE trajectory for this sample
            df_i = df.query(f"ID == {sample}")

            plt.figure(figsize=(6.4, 4.8))
            if style == "fill":
                for dim in range(2):
                    plt.fill_between(x=t_vec,
                                     y1=down[:, dim].numpy(),
                                     y2=up[:, dim].numpy(),
                                     facecolor=fill_colors[dim],
                                     alpha=1.0,
                                     zorder=1)
                    plt.plot(t_vec,
                             m[:, dim].numpy(),
                             color=line_colors[dim],
                             linewidth=2,
                             zorder=2,
                             label=f"Dimension {dim+1}")
                    observed_idx = np.where(plots_dict["mask"][:, dim] == 1)[0]
                    plt.scatter(times[observed_idx],
                                observations[observed_idx, dim],
                                color=colors[dim],
                                alpha=0.5,
                                s=60)
                    plt.plot(df_i.Time,
                             df_i[f"Value_{dim+1}"],
                             ":",
                             color=colors[dim],
                             linewidth=1.5,
                             alpha=0.8,
                             label="_nolegend_")
            else:
                for dim in range(2):
                    plt.plot(t_vec,
                             up[:, dim].numpy(),
                             "--",
                             color="red",
                             linewidth=2)
                    plt.plot(t_vec,
                             down[:, dim].numpy(),
                             "--",
                             color="red",
                             linewidth=2)
                    plt.plot(t_vec,
                             m[:, dim].numpy(),
                             color=colors[dim],
                             linewidth=2)
                    observed_idx = np.where(plots_dict["mask"][:, dim] == 1)[0]
                    plt.scatter(times[observed_idx],
                                observations[observed_idx, dim],
                                color=colors[dim],
                                alpha=0.5,
                                s=60)
                    plt.plot(df_i.Time,
                             df_i[f"Value_{dim+1}"],
                             ":",
                             color=colors[dim],
                             linewidth=1.5,
                             alpha=0.8)

            #plt.title("Test trajectory of a double OU process")
            plt.xlabel("Time")
            plt.grid()
            plt.legend(loc="lower right")
            plt.ylabel("Predicton (+/- 1.96 st. dev)")
            fname = f"{model_name}_sample{sample}_{style}.{format_image}"
            plt.tight_layout()
            plt.savefig(fname)
            plt.close()
            print(f"Saved sample into '{fname}'.")