def run_experiment(p, out_dir):
    """
    Train an MCRNNVAE on synthetic multi-channel curves and report its losses.

    For every entry of p["curves"] a SinDataGenerator produces a training set
    of p["nsamples"] samples and a test set of int(p["nsamples"] * 0.8)
    samples. Each set is padded along the time axis, turned into a
    (data, mask) pair and fed to the model. After fitting, the model is saved,
    loss curves and 2-D latent projections are plotted, and the final losses
    are returned.

    Args:
        p: dict with all the hyperparameters. Keys read here: "seed",
            "curves", "ntp", "noise", "nsamples", "h_size", "hidden",
            "n_layers", "z_dim", "clip", "n_epochs", "batch_size",
            "n_channels", "ch_type", "n_feats", "phi_layers", "sig_mean",
            "dropout", "drop_th" and "learning_rate". All must be present.
        out_dir: output directory, created if missing. File and directory
            names are appended with plain string concatenation, so it is
            expected to end with a path separator.

    Returns:
        dict with the keys "mse_train", "mse_test", "loss_total", "loss_kl"
        and "loss_ll". The two "mse_*" entries hold the "mae" value returned
        by model.recon_loss; the key names are kept unchanged for callers.
    """
    os.makedirs(out_dir, exist_ok=True)

    # Seed
    torch.manual_seed(p["seed"])
    np.random.seed(p["seed"])

    # Save parameters to the out dir
    with open(out_dir + "params.txt", "w") as f:
        f.write(str(p))

    # Device: first CUDA device when available, CPU otherwise
    DEVICE_ID = 0
    DEVICE = torch.device(
        'cuda:' + str(DEVICE_ID) if torch.cuda.is_available() else 'cpu')
    if torch.cuda.is_available():
        torch.cuda.set_device(DEVICE_ID)

    def _pad_and_mask(sequences):
        """Pad the sequences along time; return (zero-filled data, mask)."""
        tensors = [torch.FloatTensor(seq) for seq in sequences]
        # Pad with NaN first so padded positions (and any NaN already in the
        # data) can be told apart from real values, then replace them by 0.
        padded = nn.utils.rnn.pad_sequence(tensors,
                                           batch_first=False,
                                           padding_value=np.nan)
        mask = ~torch.isnan(padded)
        padded[~mask] = 0
        return padded.to(DEVICE), mask.to(DEVICE)

    # One tensor per channel: the number of features can differ across
    # channels, so channels are kept in lists instead of being stacked.
    X_train_pad = []
    X_test_pad = []

    mask_train_tensor = []
    mask_test_tensor = []

    # Generate the data, and the mask corresponding to each channel
    for ch_curves in p['curves']:
        gen_model = SinDataGenerator(ch_curves,
                                     p["ntp"],
                                     p["noise"],
                                     variable_tp=True)

        # The samples have a variable number of time points, so they are
        # converted one by one instead of going through a single (ragged)
        # numpy array.
        train_samples = gen_model.generate_n_samples(p["nsamples"])
        padded, mask = _pad_and_mask([y for (_, y) in train_samples])
        X_train_pad.append(padded)
        mask_train_tensor.append(mask)

        # Do the same for the testing set
        test_samples = gen_model.generate_n_samples(int(p["nsamples"] * 0.8))
        padded, mask = _pad_and_mask([y for (_, y) in test_samples])
        X_test_pad.append(padded)
        mask_test_tensor.append(mask)

    # Define model and optimizer
    model = rnnvae.MCRNNVAE(p["h_size"],
                            p["hidden"],
                            p["n_layers"],
                            p["hidden"],
                            p["n_layers"],
                            p["hidden"],
                            p["n_layers"],
                            p["z_dim"],
                            p["hidden"],
                            p["n_layers"],
                            p["clip"],
                            p["n_epochs"],
                            p["batch_size"],
                            p["n_channels"],
                            p["ch_type"],
                            p["n_feats"],
                            DEVICE,
                            print_every=100,
                            phi_layers=p["phi_layers"],
                            sigmoid_mean=p["sig_mean"],
                            dropout=p["dropout"],
                            dropout_threshold=p["drop_th"])

    optimizer = torch.optim.Adam(model.parameters(), lr=p["learning_rate"])
    model.optimizer = optimizer

    model = model.to(DEVICE)
    # Fit the model
    model.fit(X_train_pad, X_test_pad, mask_train_tensor, mask_test_tensor)

    # After training, save the model
    model.save(out_dir, 'model.pt')

    # Predict the reconstructions for the test and train sets
    X_test_fwd = model.predict(X_test_pad, nt=p["ntp"])
    X_train_fwd = model.predict(X_train_pad, nt=p["ntp"])

    # Plot training and validation losses
    plot_total_loss(model.loss['total'], model.val_loss['total'], "Total loss",
                    out_dir, "total_loss.png")
    plot_total_loss(model.loss['kl'], model.val_loss['kl'], "kl_loss", out_dir,
                    "kl_loss.png")
    plot_total_loss(model.loss['ll'], model.val_loss['ll'], "ll_loss", out_dir,
                    "ll_loss.png")

    # Compute MAE and reconstruction loss over both sets
    test_loss = model.recon_loss(X_test_fwd,
                                 target=X_test_pad,
                                 mask=mask_test_tensor)
    train_loss = model.recon_loss(X_train_fwd,
                                  target=X_train_pad,
                                  mask=mask_train_tensor)

    print('MAE over the train set: ' + str(train_loss["mae"]))
    print('Reconstruction loss over the train set: ' +
          str(train_loss["rec_loss"]))

    print('MAE over the test set: ' + str(test_loss["mae"]))
    print('Reconstruction loss over the test set: ' +
          str(test_loss["rec_loss"]))

    # Latent space: swap the first two axes of every channel's z
    z_train = [np.array(x).swapaxes(0, 1) for x in X_train_fwd['z']]

    # Dir for projections
    proj_path = 'z_proj/'
    os.makedirs(out_dir + proj_path, exist_ok=True)

    # Plot every pair (dim0 < dim1) of latent dimensions for each channel
    for ch in range(p["n_channels"]):
        for dim0 in range(p["z_dim"]):
            for dim1 in range(dim0 + 1, p["z_dim"]):
                plot_z_time_2d(z_train[ch],
                               p["ntp"], [dim0, dim1],
                               out_dir + proj_path,
                               out_name=f'z_ch_{ch}_d{dim0}_d{dim1}')

    # Dir for sampled projections
    sampling_path = 'z_proj_sampling/'
    os.makedirs(out_dir + sampling_path, exist_ok=True)

    # Inputs prepared for the (still pending) latent space plot below
    qzx = [np.array(x) for x in X_train_fwd['qzx']]

    # Classificator labels: each time point index repeated nsamples times
    classif = np.array([
        str(tp) for tp in range(p["ntp"]) for _ in range(p["nsamples"])
    ])
    print("on_classif")
    print(p["ntp"])
    print(len(classif))

    out_dir_sample = out_dir + 'test_zspace_function/'
    os.makedirs(out_dir_sample, exist_ok=True)

    # TODO: adapt plot_latent_space to this model and call it with
    # qzx, p["ntp"], classificator=classif and out_dir=out_dir_sample.

    loss = {
        "mse_train": train_loss["mae"],
        "mse_test": test_loss["mae"],
        "loss_total": model.loss['total'][-1],
        "loss_kl": model.loss['kl'][-1],
        "loss_ll": model.loss['ll'][-1]
    }

    return loss
Exemple #2
0
def run_experiment(p, csv_path, out_dir, data_cols=[]):
    """
    Train an MCRNNVAE on multimodal data with a single train/test split.

    The data is loaded with load_multimodal_data (90% train), every channel
    is padded along the time axis and paired with a validity mask, and the
    model is fitted and saved. The function then evaluates:

    * the reconstruction loss over the train and test sets,
    * the prediction of the last time point of every test subject for the
      first three channels (the model only sees the earlier time points),
    * when there is more than one channel, the reconstruction of each channel
      from all the other ones,

    and finally plots the latent space coloured by diagnosis and by time
    point.

    Args:
        p: dict with all the hyperparameters. Keys read here: "seed",
            "ch_type", "h_size", "hidden", "n_layers", "z_dim", "clip",
            "n_epochs", "batch_size", "n_channels", "phi_layers", "sig_mean",
            "dropout", "drop_th", "ch_names" and "learning_rate". All must be
            present. p["n_feats"] is overwritten with the number of features
            of each loaded channel.
        csv_path: path of the data file handed to load_multimodal_data.
        out_dir: output directory, created if missing. File and directory
            names are appended with plain string concatenation, so it is
            expected to end with a path separator.
        data_cols: columns handed to load_multimodal_data.

    Returns:
        dict with the train/test losses ("mae_train", "rec_train",
        "mae_test", "loss_total", "loss_kl", "loss_ll"), "dropout_comps" when
        p["dropout"] is truthy, and one "pred_<ch>_mae" / "recon_<ch>_mae"
        entry per evaluated channel. The "recon_*" entries stay empty lists
        when there is a single channel.
    """
    os.makedirs(out_dir, exist_ok=True)

    # Seed
    torch.manual_seed(p["seed"])
    np.random.seed(p["seed"])

    # Save parameters to the out dir
    with open(out_dir + "params.txt", "w") as f:
        f.write(str(p))

    # Device: first CUDA device when available, CPU otherwise
    DEVICE_ID = 0
    DEVICE = torch.device(
        'cuda:' + str(DEVICE_ID) if torch.cuda.is_available() else 'cpu')
    if torch.cuda.is_available():
        torch.cuda.set_device(DEVICE_ID)

    def _pad_channels(channels):
        """Pad every channel along time; return (data list, mask list)."""
        data_list = []
        mask_list = []
        for x_ch in channels:
            tensors = [torch.FloatTensor(t) for t in x_ch]
            # Pad with NaN first so padded positions (and any NaN already in
            # the data) can be told apart from real values, then zero them.
            padded = nn.utils.rnn.pad_sequence(tensors,
                                               batch_first=False,
                                               padding_value=np.nan)
            mask = ~torch.isnan(padded)
            padded[~mask] = 0
            data_list.append(padded.to(DEVICE))
            mask_list.append(mask.to(DEVICE))
        return data_list, mask_list

    # LOAD DATA
    # No separate validation data: the test split is used for validation
    X_train, X_test, Y_train, Y_test, _mri_col = load_multimodal_data(
        csv_path,
        data_cols,
        p["ch_type"],
        train_set=0.9,
        normalize=True,
        return_covariates=True)

    p["n_feats"] = [x[0].shape[1] for x in X_train]

    print('Length of train/test')
    print(len(X_train[0]))
    print(len(X_test[0]))

    # For each channel, pad and create the mask
    X_train_list, mask_train_list = _pad_channels(X_train)
    X_test_list, mask_test_list = _pad_channels(X_test)

    # Longest sequence over BOTH sets, so test subjects with more time
    # points than any training subject are not cut short.
    ntp = max(max(x.shape[0] for x in X_train_list),
              max(x.shape[0] for x in X_test_list))

    model = rnnvae_h.MCRNNVAE(p["h_size"],
                              p["hidden"],
                              p["n_layers"],
                              p["hidden"],
                              p["n_layers"],
                              p["hidden"],
                              p["n_layers"],
                              p["z_dim"],
                              p["hidden"],
                              p["n_layers"],
                              p["clip"],
                              p["n_epochs"],
                              p["batch_size"],
                              p["n_channels"],
                              p["ch_type"],
                              p["n_feats"],
                              DEVICE,
                              print_every=100,
                              phi_layers=p["phi_layers"],
                              sigmoid_mean=p["sig_mean"],
                              dropout=p["dropout"],
                              dropout_threshold=p["drop_th"])

    model.ch_name = p["ch_names"]

    optimizer = torch.optim.Adam(model.parameters(), lr=p["learning_rate"])
    model.optimizer = optimizer

    model = model.to(DEVICE)

    # Fit the model
    model.fit(X_train_list, X_test_list, mask_train_list, mask_test_list)

    if p["dropout"]:
        print("Print the dropout")
        print(model.dropout_comp)

    # After training, save the model
    model.save(out_dir, 'model.pt')

    # Predict the reconstructions for the train and test sets
    X_train_fwd = model.predict(X_train_list, mask_train_list, nt=ntp)
    X_test_fwd = model.predict(X_test_list, mask_test_list, nt=ntp)

    # Plot training and validation losses
    plot_total_loss(model.loss['total'], model.val_loss['total'], "Total loss",
                    out_dir, "total_loss.png")
    plot_total_loss(model.loss['kl'], model.val_loss['kl'], "kl_loss", out_dir,
                    "kl_loss.png")
    plot_total_loss(model.loss['ll'], model.val_loss['ll'], "ll_loss", out_dir,
                    "ll_loss.png")

    # Compute MAE and reconstruction loss over both sets
    train_loss = model.recon_loss(X_train_fwd,
                                  target=X_train_list,
                                  mask=mask_train_list)
    test_loss = model.recon_loss(X_test_fwd,
                                 target=X_test_list,
                                 mask=mask_test_list)

    print('MAE over the train set: ' + str(train_loss["mae"]))
    print('Reconstruction loss over the train set: ' +
          str(train_loss["rec_loss"]))

    print('MAE over the test set: ' + str(test_loss["mae"]))
    print('Reconstruction loss over the test set: ' +
          str(test_loss["rec_loss"]))

    # Prediction results for the first three channels, reconstruction
    # results for every channel; entries start as empty lists.
    results = {
        **{f"pred_{ch_name}_mae": [] for ch_name in p["ch_names"][:3]},
        **{f"recon_{ch_name}_mae": [] for ch_name in p["ch_names"]},
    }

    ######################
    ## Prediction of last time point
    ######################

    # Test data without the last time point of every subject
    X_test_list_minus, mask_test_list_minus = _pad_channels(
        [[t[:-1, :] for t in x_ch] for x_ch in X_test])

    # Run prediction
    X_test_fwd_minus = model.predict(X_test_list_minus,
                                     mask_test_list_minus,
                                     nt=ntp)
    X_test_xnext = X_test_fwd_minus["xnext"]

    for ch_idx, (X_ch, ch) in enumerate(zip(X_test[:3], p["ch_names"][:3])):
        print(f'testing for {ch}')
        y_true = []
        y_pred = []
        for subj_idx, x in enumerate(X_ch):
            # Index of the last time point of the original sequence
            last_tp = len(x) - 1
            if last_tp < 1:
                continue  # ignore subjects with only baseline
            y_true.append(x[-1])
            y_pred.append(X_test_xnext[ch_idx][last_tp, subj_idx, :])

        results[f'pred_{ch}_mae'] = mean_absolute_error(y_true, y_pred)

    ############################
    ## Test reconstruction for each channel, using the other ones
    ############################
    if p["n_channels"] > 1:

        for i in range(len(X_test)):
            curr_name = p["ch_names"][i]
            # Every channel but the one being reconstructed is available
            av_ch = list(range(len(X_test)))
            av_ch.remove(i)
            ch_recon = model.predict(X_test_list,
                                     mask_test_list,
                                     nt=ntp,
                                     av_ch=av_ch,
                                     task='recon')

            y_true = X_test[i]
            # Swap the first two dims to iterate over subjects, then keep
            # only as many time points as each subject really has
            y_pred = np.transpose(ch_recon["xnext"][i], (1, 0, 2))
            y_pred = [
                x_pred[:len(x_true)]
                for (x_pred, x_true) in zip(y_pred, y_true)
            ]

            # Flatten to one entry per (subject, time point)
            y_pred = [tp for subj in y_pred for tp in subj]
            y_true = [tp for subj in y_true for tp in subj]

            # MAE for that specific channel over all time points
            results[f"recon_{curr_name}_mae"] = mean_absolute_error(
                y_true, y_pred)

    loss = {
        "mae_train": train_loss["mae"],
        "rec_train": train_loss["rec_loss"],
        "mae_test": test_loss["mae"],
        "loss_total": model.loss['total'][-1],
        "loss_kl": model.loss['kl'][-1],
        "loss_ll": model.loss['ll'][-1],
    }

    if p["dropout"]:
        loss["dropout_comps"] = model.dropout_comp

    loss = {**loss, **results}
    print(loss)

    # Dir for projections
    proj_path = 'z_proj/'
    os.makedirs(out_dir + proj_path, exist_ok=True)

    qzx_train = [np.array(x) for x in X_train_fwd['qzx']]
    qzx_test = [np.array(x) for x in X_test_fwd['qzx']]

    def _plot_latent(qzx, labels, palette, plt_tp, plot_dir, mask):
        """Call plot_latent_space with the options shared by all plots."""
        plot_latent_space(model,
                          qzx,
                          ntp,
                          classificator=labels,
                          pallete_dict=palette,
                          plt_tp=plt_tp,
                          all_plots=True,
                          uncertainty=False,
                          savefig=True,
                          out_dir=plot_dir,
                          mask=mask)

    # Diagnosis labels as plain nested lists
    DX_train = [list(elem) for elem in Y_train["DX"]]
    DX_test = [list(elem) for elem in Y_test["DX"]]

    # Define colors
    pallete_dict = {"CN": "#2a9e1e", "MCI": "#bfbc1a", "AD": "#af1f1f"}

    # Colour by diagnosis, all time points
    out_dir_sample = out_dir + 'zcomp_ch_dx/'
    os.makedirs(out_dir_sample, exist_ok=True)

    _plot_latent(qzx_test, DX_test, pallete_dict, 'all',
                 out_dir_sample + '_test', mask_test_list)
    _plot_latent(qzx_train, DX_train, pallete_dict, 'all',
                 out_dir_sample + '_train', mask_train_list)

    # Colour by diagnosis, first time point only
    out_dir_sample_t0 = out_dir + 'zcomp_ch_dx_t0/'
    os.makedirs(out_dir_sample_t0, exist_ok=True)

    _plot_latent(qzx_train, DX_train, pallete_dict, [0],
                 out_dir_sample_t0 + '_train', mask_train_list)
    _plot_latent(qzx_test, DX_test, pallete_dict, [0],
                 out_dir_sample_t0 + '_test', mask_test_list)

    # Now plot color by timepoint
    out_dir_sample = out_dir + 'zcomp_ch_tp/'
    os.makedirs(out_dir_sample, exist_ok=True)

    # The label of every visit is its position within the subject
    classif_train = [[idx for (idx, _) in enumerate(elem)]
                     for elem in Y_train["DX"]]
    classif_test = [[idx for (idx, _) in enumerate(elem)]
                    for elem in Y_test["DX"]]

    pallete = sns.color_palette("viridis", ntp)
    pallete_dict = {i: value for (i, value) in enumerate(pallete)}

    _plot_latent(qzx_train, classif_train, pallete_dict, 'all',
                 out_dir_sample + '_train', mask_train_list)
    _plot_latent(qzx_test, classif_test, pallete_dict, 'all',
                 out_dir_sample + '_test', mask_test_list)

    return loss
Exemple #3
0
def run_experiment(p, csv_path, out_dir, data_cols=[]):
    """
    Train and evaluate an MCRNNVAE with cross-validation.

    The folds come from load_multimodal_data_cv (nsplit=10); only the first
    two folds are actually run. For every fold the channels are padded along
    the time axis and paired with a validity mask, a fresh model is fitted and
    saved under out_dir + '_fold_<n>/', and the reconstruction, last time
    point prediction (first three channels) and cross-channel reconstruction
    errors are recorded. The per-fold values are averaged at the end.

    While the function runs, sys.stdout is redirected first to
    out_dir + 'output.out' and then to the 'output.out' of each fold. The
    log files are closed and the caller's sys.stdout is restored before
    returning, also when an exception is raised.

    Args:
        p: dict with all the hyperparameters. Keys read here: "seed",
            "ch_type", "ch_names", "long_to_bl", "h_size", "x_hidden",
            "x_n_layers", "z_hidden", "z_n_layers", "enc_hidden",
            "enc_n_layers", "z_dim", "dec_hidden", "dec_n_layers", "clip",
            "n_epochs", "batch_size", "n_channels", "c_z", "phi_layers",
            "sig_mean", "dropout", "drop_th" and "learning_rate". All must be
            present. p["n_feats"] is overwritten on every fold.
        csv_path: path of the data file handed to load_multimodal_data_cv.
        out_dir: output directory, created if missing. File and directory
            names are appended with plain string concatenation, so it is
            expected to end with a path separator.
        data_cols: columns handed to load_multimodal_data_cv.

    Returns:
        dict mapping every metric name ("mae_train", "rec_train", "mae_test",
        "loss_total", "loss_total_val", "loss_kl", "loss_ll",
        "pred_<ch>_mae", "recon_<ch>_mae") to its mean over the folds run.
    """
    os.makedirs(out_dir, exist_ok=True)

    # Seed
    torch.manual_seed(p["seed"])
    np.random.seed(p["seed"])

    # Only this many folds are run, to keep the experiment fast
    max_folds = 2

    saved_stdout = sys.stdout
    active_log = None

    def _redirect_stdout(path):
        """Send stdout to a new log file and close the one it replaces."""
        nonlocal active_log
        new_log = open(path, 'w')
        sys.stdout = new_log
        if active_log is not None:
            active_log.close()
        active_log = new_log

    def _pad_channels(channels, device):
        """Pad every channel along time; return (data list, mask list)."""
        data_list = []
        mask_list = []
        for x_ch in channels:
            tensors = [torch.FloatTensor(t) for t in x_ch]
            # Pad with NaN first so padded positions (and any NaN already in
            # the data) can be told apart from real values, then zero them.
            padded = nn.utils.rnn.pad_sequence(tensors,
                                               batch_first=False,
                                               padding_value=np.nan)
            mask = ~torch.isnan(padded)
            padded[~mask] = 0
            data_list.append(padded.to(device))
            mask_list.append(mask.to(device))
        return data_list, mask_list

    try:
        # Redirect output to the out dir
        _redirect_stdout(out_dir + 'output.out')

        # Save parameters to the out dir
        with open(out_dir + "params.txt", "w") as f:
            f.write(str(p))

        # Device: first CUDA device when available, CPU otherwise
        DEVICE_ID = 0
        DEVICE = torch.device(
            'cuda:' + str(DEVICE_ID) if torch.cuda.is_available() else 'cpu')
        if torch.cuda.is_available():
            torch.cuda.set_device(DEVICE_ID)

        # Begin on the CV data
        gen = load_multimodal_data_cv(csv_path,
                                      data_cols,
                                      p["ch_type"],
                                      nsplit=10,
                                      normalize=True)

        # One list of per-fold values for every metric
        loss = {
            "mae_train": [],
            "rec_train": [],
            "mae_test": [],
            "loss_total": [],
            "loss_total_val": [],
            "loss_kl": [],
            "loss_ll": [],
        }
        for ch_name in p["ch_names"][:3]:
            loss[f"pred_{ch_name}_mae"] = []
        for ch_name in p["ch_names"]:
            loss[f"recon_{ch_name}_mae"] = []

        for fold_n, (X_train, X_test, _Y_train, _Y_test,
                     _mri_col) in enumerate(gen):
            # Create output dir for the fold
            out_dir_cv = out_dir + f'_fold_{fold_n}/'
            os.makedirs(out_dir_cv, exist_ok=True)

            # Redirect output to the fold specific folder
            _redirect_stdout(out_dir_cv + 'output.out')

            p["n_feats"] = [x[0].shape[1] for x in X_train]

            print('Length of train/test')
            print(len(X_train[0]))
            print(len(X_test[0]))

            # Longest sequence over BOTH sets, so test subjects with more
            # time points than any training subject are not cut short.
            ntp = max(
                len(xi) for x_set in (X_train, X_test) for x in x_set
                for xi in x)

            if p["long_to_bl"]:
                # Turn baseline channels into longitudinal ones by repeating
                # the values at t0 for ntp time points
                for ch_idx, ch_kind in enumerate(p["ch_type"]):
                    if ch_kind != 'bl':
                        continue
                    for channel in (X_train[ch_idx], X_test[ch_idx]):
                        for j in range(len(channel)):
                            channel[j] = np.array([channel[j][0]] * ntp)

            # For each channel, pad and create the mask
            X_train_list, mask_train_list = _pad_channels(X_train, DEVICE)
            X_test_list, mask_test_list = _pad_channels(X_test, DEVICE)

            model = rnnvae_h.MCRNNVAE(p["h_size"],
                                      p["x_hidden"],
                                      p["x_n_layers"],
                                      p["z_hidden"],
                                      p["z_n_layers"],
                                      p["enc_hidden"],
                                      p["enc_n_layers"],
                                      p["z_dim"],
                                      p["dec_hidden"],
                                      p["dec_n_layers"],
                                      p["clip"],
                                      p["n_epochs"],
                                      p["batch_size"],
                                      p["n_channels"],
                                      p["ch_type"],
                                      p["n_feats"],
                                      p["c_z"],
                                      DEVICE,
                                      print_every=100,
                                      phi_layers=p["phi_layers"],
                                      sigmoid_mean=p["sig_mean"],
                                      dropout=p["dropout"],
                                      dropout_threshold=p["drop_th"])
            model.ch_name = p["ch_names"]

            optimizer = torch.optim.Adam(model.parameters(),
                                         lr=p["learning_rate"])
            model.optimizer = optimizer

            model = model.to(DEVICE)

            # Fit the model
            model.fit(X_train_list, X_test_list, mask_train_list,
                      mask_test_list)

            if p["dropout"]:
                print("Print the dropout")
                print(model.dropout_comp)

            # After training, save the model
            model.save(out_dir_cv, 'model.pt')

            # Predict the reconstructions for the train and test sets
            X_train_fwd = model.predict(X_train_list, mask_train_list, nt=ntp)
            X_test_fwd = model.predict(X_test_list, mask_test_list, nt=ntp)

            # Plot training and validation losses
            plot_total_loss(model.loss['total'], model.val_loss['total'],
                            "Total loss", out_dir_cv, "total_loss.png")
            plot_total_loss(model.loss['kl'], model.val_loss['kl'], "kl_loss",
                            out_dir_cv, "kl_loss.png")
            plot_total_loss(model.loss['ll'], model.val_loss['ll'], "ll_loss",
                            out_dir_cv, "ll_loss.png")

            # Compute MAE and reconstruction loss over both sets
            train_loss = model.recon_loss(X_train_fwd,
                                          target=X_train_list,
                                          mask=mask_train_list)
            test_loss = model.recon_loss(X_test_fwd,
                                         target=X_test_list,
                                         mask=mask_test_list)

            print('MAE over the train set: ' + str(train_loss["mae"]))
            print('Reconstruction loss over the train set: ' +
                  str(train_loss["rec_loss"]))

            print('MAE over the test set: ' + str(test_loss["mae"]))
            print('Reconstruction loss over the test set: ' +
                  str(test_loss["rec_loss"]))

            ######################
            ## Prediction of last time point
            ######################
            pred_ch = list(range(3))
            print(pred_ch)
            t_pred = 1
            res = eval_prediction(model, X_test, t_pred, pred_ch, DEVICE)

            # res is indexed by position within pred_ch
            pred_names = [
                name for (idx, name) in enumerate(p["ch_names"])
                if idx in pred_ch
            ]
            for res_idx, ch in enumerate(pred_names):
                loss[f'pred_{ch}_mae'].append(res[res_idx])

            ############################
            ## Test reconstruction for each channel, using the other ones
            ############################
            if p["n_channels"] > 1:
                for i in range(len(X_test)):
                    curr_name = p["ch_names"][i]
                    # Every channel but the reconstructed one is available
                    av_ch = list(range(len(X_test)))
                    av_ch.remove(i)
                    mae_rec = eval_reconstruction(model, X_test, X_test_list,
                                                  mask_test_list, av_ch, i)
                    # MAE for that specific channel over all time points
                    loss[f"recon_{curr_name}_mae"].append(mae_rec)

            # Save results in the loss object
            loss["mae_train"].append(train_loss["mae"])
            loss["rec_train"].append(train_loss["rec_loss"])
            loss["mae_test"].append(test_loss["mae"])
            loss["loss_total"].append(model.loss['total'][-1])
            loss["loss_total_val"].append(model.val_loss['total'][-1])
            loss["loss_kl"].append(model.loss['kl'][-1])
            loss["loss_ll"].append(model.loss['ll'][-1])

            # Stop before asking the generator for another fold
            if fold_n + 1 == max_folds:
                break

        # Compute the mean for every metric in the loss dict
        loss = {key: np.mean(values) for (key, values) in loss.items()}

        # Still redirected: this ends up in the log of the last fold run
        print(loss)
        return loss
    finally:
        # Give stdout back to the caller and release the last log file
        sys.stdout = saved_stdout
        if active_log is not None:
            active_log.close()
Exemple #4
0
def _pad_and_mask(channels, device, drop_last=False):
    """Pad every channel's sequences to a common length and build masks.

    Args:
        channels: iterable of channels; each channel is an iterable of
            per-subject 2-D arrays accepted by ``torch.FloatTensor``.
        device: torch device the padded tensors and masks are moved to.
        drop_last: when True, the last first-axis entry of every subject
            (``t[:-1, :]``) is removed before padding.

    Returns:
        ``(padded_list, mask_list)`` with one tensor per channel. A mask is
        True wherever the padded tensor is not NaN; every NaN entry (the
        padding value, plus any NaN already in the data) is set to 0.
    """
    padded_list = []
    mask_list = []
    for x_ch in channels:
        seqs = [torch.FloatTensor(t[:-1, :] if drop_last else t) for t in x_ch]
        # Pad with NaN so that padding can be told apart from real values.
        padded = nn.utils.rnn.pad_sequence(seqs,
                                           batch_first=False,
                                           padding_value=np.nan)
        mask = ~torch.isnan(padded)
        mask_list.append(mask.to(device))
        padded[~mask] = 0
        padded_list.append(padded.to(device))
    return padded_list, mask_list


def run_experiment(p, csv_path, out_dir, data_cols=None):
    """Train a multi-channel RNN-VAE on multimodal data and evaluate it.

    The function loads the data, pads/masks every channel, fits
    ``rnnvae_drop.MCRNNVAE``, saves the model and the loss plots, evaluates
    reconstruction, last-timepoint prediction and cross-channel
    reconstruction on the test split, and plots the latent space.

    Args:
        p: dict of hyperparameters. All keys that are read are assumed to be
            present ("seed", "h_size", "hidden", "n_layers", "z_dim", "clip",
            "n_epochs", "batch_size", "n_channels", "ch_names",
            "learning_rate"). ``p["n_feats"]`` is overwritten in place.
        csv_path: path forwarded to ``load_multimodal_data``.
        out_dir: output directory, used as a string prefix (``out_dir +
            "params.txt"``), so it is expected to end with a path separator.
        data_cols: column selection forwarded to ``load_multimodal_data``;
            defaults to an empty list.

    Returns:
        dict with "mae_train", "rec_train", "mae_test", "mae_last_tp",
        "loss_total", "loss_kl", "loss_ll" and one ``recon_<channel>_mae``
        entry per channel.
    """
    # Avoid a mutable default argument shared between calls.
    if data_cols is None:
        data_cols = []

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    # Seed
    torch.manual_seed(p["seed"])
    np.random.seed(p["seed"])

    # Save parameters to the out dir (written before p["n_feats"] is set).
    with open(out_dir + "params.txt", "w") as f:
        f.write(str(p))

    # DEVICE: first GPU when CUDA is available, CPU otherwise.
    DEVICE_ID = 0
    DEVICE = torch.device(
        'cuda:' + str(DEVICE_ID) if torch.cuda.is_available() else 'cpu')
    if torch.cuda.is_available():
        torch.cuda.set_device(DEVICE_ID)

    # LOAD DATA
    # No separate validation split: the test split is used for validation.
    # mri_col is returned by the loader but not used here.
    X_train, X_test, Y_train, Y_test, mri_col = load_multimodal_data(
        csv_path,
        data_cols,
        train_set=0.8,
        normalize=True,
        return_covariates=True)

    # Number of features of each channel, taken from its first subject.
    p["n_feats"] = [x[0].shape[1] for x in X_train]

    print('Length of train/test')
    print(len(X_train[0]))
    print(len(X_test[0]))

    # For each channel: pad, create the mask, and move to DEVICE.
    X_train_list, mask_train_list = _pad_and_mask(X_train, DEVICE)
    X_test_list, mask_test_list = _pad_and_mask(X_test, DEVICE)

    # Longest padded sequence over BOTH splits. The previous code took the
    # maximum over the train tensors twice and ignored the test tensors.
    ntp = max(max(x.shape[0] for x in X_train_list),
              max(x.shape[0] for x in X_test_list))

    model = rnnvae_drop.MCRNNVAE(p["h_size"],
                                 p["hidden"],
                                 p["n_layers"],
                                 p["hidden"],
                                 p["n_layers"],
                                 p["hidden"],
                                 p["n_layers"],
                                 p["z_dim"],
                                 p["hidden"],
                                 p["n_layers"],
                                 p["clip"],
                                 p["n_epochs"],
                                 p["batch_size"],
                                 p["n_channels"],
                                 p["n_feats"],
                                 DEVICE,
                                 0.3,  # presumably the dropout setting - TODO confirm against MCRNNVAE
                                 print_every=100)

    model.ch_name = p["ch_names"]

    optimizer = torch.optim.Adam(model.parameters(), lr=p["learning_rate"])
    model.optimizer = optimizer

    model = model.to(DEVICE)
    # Fit the model
    model.fit(X_train_list, X_test_list, mask_train_list, mask_test_list)

    print("Print the dropout")
    print(model.dropout)

    # After training, save the model.
    model.save(out_dir, 'model.pt')

    # Forward pass over both splits.
    X_train_fwd = model.predict(X_train_list, nt=ntp)
    X_test_fwd = model.predict(X_test_list, nt=ntp)

    # Training vs. validation curves.
    plot_total_loss(model.loss['total'], model.val_loss['total'], "Total loss",
                    out_dir, "total_loss.png")
    plot_total_loss(model.loss['kl'], model.val_loss['kl'], "kl_loss", out_dir,
                    "kl_loss.png")
    plot_total_loss(model.loss['ll'], model.val_loss['ll'], "ll_loss", out_dir,
                    "ll_loss.png")

    # MAE and reconstruction loss over each split.
    train_loss = model.recon_loss(X_train_fwd,
                                  target=X_train_list,
                                  mask=mask_train_list)
    test_loss = model.recon_loss(X_test_fwd,
                                 target=X_test_list,
                                 mask=mask_test_list)

    # The labels previously said "MSE" and "train set" for the test values.
    print('MAE over the train set: ' + str(train_loss["mae"]))
    print('Reconstruction loss over the train set: ' +
          str(train_loss["rec_loss"]))

    print('MAE over the test set: ' + str(test_loss["mae"]))
    print('Reconstruction loss over the test set: ' +
          str(test_loss["rec_loss"]))

    ######################
    ## Prediction of last time point
    ######################

    # Full test sequences (they keep the last timepoint) are the targets;
    # the model input is the same data with the last timepoint removed.
    X_test_tensors = [[torch.FloatTensor(t) for t in x_ch] for x_ch in X_test]
    X_test_list_minus, _ = _pad_and_mask(X_test, DEVICE, drop_last=True)

    X_test_fwd_minus = model.predict(X_test_list_minus, nt=ntp)
    X_test_xnext = X_test_fwd_minus["xnext"]

    # NOTE(review): this accumulates mean_squared_error but is returned under
    # the key "mae_last_tp" - confirm which metric is intended.
    last_tp_mse = 0
    for ch_idx, ch_subjects in enumerate(X_test_tensors):
        last_tp_mse_ch = 0
        for subj_idx, subj_seq in enumerate(ch_subjects):
            # Index of the subject's last real timepoint.
            tp = len(subj_seq) - 1
            last_tp_mse_ch += mean_squared_error(
                subj_seq[tp, :], X_test_xnext[ch_idx][tp, subj_idx, :])
        # Mean over the subjects of this channel, summed over channels.
        last_tp_mse += last_tp_mse_ch / len(ch_subjects)

    ############################
    ## Test reconstruction for each channel, using the other ones
    ############################
    n_test_channels = len(X_test_list)
    rec_results = {}
    for i in range(n_test_channels):
        curr_name = p["ch_names"][i]
        # Every channel except the one being reconstructed.
        av_ch = [ch for ch in range(n_test_channels) if ch != i]
        ch_recon = model.predict(X_test_list, nt=ntp, av_ch=av_ch)

        # Sum the masked MAE of channel i over all padded timepoints.
        mae_loss = 0
        for t in range(len(mask_test_list[i])):
            mask_channel = mask_test_list[i][t, :, 0]
            mae_loss += rnnvae_drop.mae(target=X_test_list[i][t].cpu(),
                                        predicted=ch_recon["xnext"][i][t],
                                        mask=mask_channel)

        rec_results[f"recon_{curr_name}_mae"] = mae_loss.item()

    # Dir for projections
    proj_path = 'z_proj/'
    if not os.path.exists(out_dir + proj_path):
        os.makedirs(out_dir + proj_path)

    qzx_train = [np.array(x) for x in X_train_fwd['qzx']]
    qzx_test = [np.array(x) for x in X_test_fwd['qzx']]

    def _plot_latent(qzx, labels, palette, plt_tp, target, mask):
        # All latent-space plots share every option except these six.
        plot_latent_space(model,
                          qzx,
                          ntp,
                          classificator=labels,
                          pallete_dict=palette,
                          plt_tp=plt_tp,
                          all_plots=True,
                          uncertainty=False,
                          savefig=True,
                          out_dir=target,
                          mask=mask)

    # Per-subject diagnosis labels as plain lists.
    DX_train = [list(elem) for elem in Y_train["DX"]]
    DX_test = [list(elem) for elem in Y_test["DX"]]

    # Define colors
    pallete_dict = {"CN": "#2a9e1e", "MCI": "#bfbc1a", "AD": "#af1f1f"}

    # Latent space colored by diagnosis, all timepoints.
    out_dir_sample = out_dir + 'zcomp_ch_dx/'
    if not os.path.exists(out_dir_sample):
        os.makedirs(out_dir_sample)

    _plot_latent(qzx_test, DX_test, pallete_dict, 'all',
                 out_dir_sample + '_test', mask_test_list)
    _plot_latent(qzx_train, DX_train, pallete_dict, 'all',
                 out_dir_sample + '_train', mask_train_list)

    # Latent space colored by diagnosis, first timepoint only.
    out_dir_sample_t0 = out_dir + 'zcomp_ch_dx_t0/'
    if not os.path.exists(out_dir_sample_t0):
        os.makedirs(out_dir_sample_t0)

    _plot_latent(qzx_train, DX_train, pallete_dict, [0],
                 out_dir_sample_t0 + '_train', mask_train_list)
    _plot_latent(qzx_test, DX_test, pallete_dict, [0],
                 out_dir_sample_t0 + '_test', mask_test_list)

    # Now plot color by timepoint
    out_dir_sample = out_dir + 'zcomp_ch_tp/'
    if not os.path.exists(out_dir_sample):
        os.makedirs(out_dir_sample)

    # The label of each visit is its index within the subject's sequence.
    classif_train = [[i for (i, _) in enumerate(elem)]
                     for elem in Y_train["DX"]]
    classif_test = [[i for (i, _) in enumerate(elem)]
                    for elem in Y_test["DX"]]

    pallete = sns.color_palette("viridis", ntp)
    pallete_dict = {i: value for (i, value) in enumerate(pallete)}

    _plot_latent(qzx_train, classif_train, pallete_dict, 'all',
                 out_dir_sample + '_train', mask_train_list)
    _plot_latent(qzx_test, classif_test, pallete_dict, 'all',
                 out_dir_sample + '_test', mask_test_list)

    loss = {
        "mae_train": train_loss["mae"],
        "rec_train": train_loss["rec_loss"],
        "mae_test": test_loss["mae"],
        "mae_last_tp": last_tp_mse,
        "loss_total": model.loss['total'][-1],
        "loss_kl": model.loss['kl'][-1],
        "loss_ll": model.loss['ll'][-1]
    }

    loss = {**loss, **rec_results}

    return loss
Exemple #5
0

# Build the multi-channel RNN-VAE from the hyperparameter dict `p`.
# The (hidden, n_layers) pair is passed four times, once per positional slot
# of the constructor that takes it.
model = rnnvae_h.MCRNNVAE(
    p["h_size"],
    p["hidden"], p["n_layers"],
    p["hidden"], p["n_layers"],
    p["hidden"], p["n_layers"],
    p["z_dim"],
    p["hidden"], p["n_layers"],
    p["clip"],
    p["n_epochs"],
    p["batch_size"],
    p["n_channels"],
    p["ch_type"],
    p["n_feats"],
    DEVICE,
    print_every=100,
    phi_layers=p["phi_layers"],
    sigmoid_mean=p["sig_mean"],
    dropout=p["dropout"],
    dropout_threshold=p["drop_th"],
)

# The model carries its own optimizer as an attribute.
optimizer = torch.optim.Adam(model.parameters(), lr=p["learning_rate"])
model.optimizer = optimizer
model = model.to(DEVICE)

# Train; the training tensors and masks are also passed as the validation
# arguments. Then persist the fitted model.
model.fit(X_train_list, X_train_list, mask_train_list, mask_train_list)
model.save(out_dir, 'model.pt')

if p["dropout"]:
    print("Print the dropout")
    print(model.dropout_comp)

# Forward pass over the training data.
X_train_fwd = model.predict(X_train_list, mask_train_list, nt=ntp)

# One training-vs-validation figure per loss term.
for _loss_key, _plot_title, _file_name in (
        ('total', "Total loss", "total_loss.png"),
        ('kl', "kl_loss", "kl_loss.png"),
        ('ll', "ll_loss", "ll_loss.png"),
):
    plot_total_loss(model.loss[_loss_key], model.val_loss[_loss_key],
                    _plot_title, out_dir, _file_name)
Exemple #6
0
def _run_mri_experiment(p, csv_path, out_dir):
    """Body of the MRI experiment; see ``run_experiment`` for the contract.

    Runs with ``sys.stdout`` already redirected by the caller.
    """
    # Seed
    torch.manual_seed(p["seed"])
    np.random.seed(p["seed"])

    # Save parameters to the out dir (written before p["x_size"] is set).
    with open(out_dir + "params.txt", "w") as f:
        f.write(str(p))

    # DEVICE: first GPU when CUDA is available, CPU otherwise.
    DEVICE_ID = 0
    DEVICE = torch.device(
        'cuda:' + str(DEVICE_ID) if torch.cuda.is_available() else 'cpu')
    if torch.cuda.is_available():
        torch.cuda.set_device(DEVICE_ID)

    # LOAD DATA
    X_train, X_test, Y_train, Y_test = open_MRI_data(csv_path,
                                                     train_set=0.9,
                                                     normalize=True,
                                                     return_covariates=True)

    # Combine train and test covariates (train first) for the latent plots.
    Y = {k: Y_train[k] + Y_test[k] for k in Y_train}

    # Number of features, taken from the first training subject.
    p["x_size"] = X_train[0].shape[1]
    print(p["x_size"])

    # Permute so that dimension is (tp, nbatch, feat)
    X_train_tensor = torch.FloatTensor(X_train).permute((1, 0, 2))
    X_test_tensor = torch.FloatTensor(X_test).permute((1, 0, 2))

    # Define model and optimizer
    model = rnnvae.ModelRNNVAE(p["x_size"], p["h_size"], p["hidden"],
                               p["n_layers"], p["hidden"], p["n_layers"],
                               p["hidden"], p["n_layers"], p["z_dim"],
                               p["hidden"], p["n_layers"], p["clip"],
                               p["n_epochs"], p["batch_size"], DEVICE)

    optimizer = torch.optim.Adam(model.parameters(), lr=p["learning_rate"])
    model.optimizer = optimizer

    model = model.to(DEVICE)
    # Fit the model; the test split is used as validation data.
    model.fit(X_train_tensor.to(DEVICE), X_test_tensor.to(DEVICE))

    # After training, save the model.
    model.save(out_dir, 'model.pt')

    # Forward pass over both splits.
    X_test_fwd = model.predict(X_test_tensor.to(DEVICE))
    X_train_fwd = model.predict(X_train_tensor.to(DEVICE))

    # Swap the first two axes back so subjects come first again.
    X_train_fwd['xnext'] = np.array(X_train_fwd['xnext']).swapaxes(0, 1)
    X_train_fwd['z'] = np.array(X_train_fwd['z']).swapaxes(0, 1)
    X_test_fwd['xnext'] = np.array(X_test_fwd['xnext']).swapaxes(0, 1)
    X_test_fwd['z'] = np.array(X_test_fwd['z']).swapaxes(0, 1)

    X_test_hat = X_test_fwd["xnext"]
    X_train_hat = X_train_fwd["xnext"]

    # Mean absolute error averaged over all sequences. The variables and the
    # returned keys keep their historical "mse" names for compatibility, but
    # the printed label now names the metric that is actually computed.
    mse_train = np.mean([
        mean_absolute_error(xval, xhat)
        for (xval, xhat) in zip(X_train, X_train_hat)
    ])
    print('MAE over the train set: ' + str(mse_train))

    mse_test = np.mean([
        mean_absolute_error(xval, xhat)
        for (xval, xhat) in zip(X_test, X_test_hat)
    ])
    print('MAE over the test set: ' + str(mse_test))

    # Training vs. validation curves.
    plot_total_loss(model.loss['total'], model.val_loss['total'], "Total loss",
                    out_dir, "total_loss.png")
    plot_total_loss(model.loss['kl'], model.val_loss['kl'], "kl_loss", out_dir,
                    "kl_loss.png")
    plot_total_loss(model.loss['ll'], model.val_loss['ll'], "ll_loss", out_dir,
                    "ll_loss.png")

    # Visualization of trajectories for one fixed subject / feature.
    subj = 6
    feature = 12

    # For train
    plot_trajectory(
        X_train, X_train_hat, subj, 'all', out_dir,
        f'traj_train_s_{subj}_f_all')  # testing for a given subject
    plot_trajectory(
        X_train, X_train_hat, subj, feature, out_dir,
        f'traj_train_s_{subj}_f_{feature}')  # testing for a given feature

    # For test
    plot_trajectory(X_test, X_test_hat, subj, 'all', out_dir,
                    f'traj_test_s_{subj}_f_all')  # testing for a given subject
    plot_trajectory(
        X_test, X_test_hat, subj, feature, out_dir,
        f'traj_test_s_{subj}_f_{feature}')  # testing for a given feature

    # Latent trajectories of all subjects, train first (same order as Y).
    z = list(X_train_fwd['z']) + list(X_test_fwd['z'])

    # Every unordered pair of distinct latent dimensions, dim0 < dim1.
    dim_pairs = [(dim0, dim1)
                 for dim0 in range(p["z_dim"])
                 for dim1 in range(dim0 + 1, p["z_dim"])]

    def _plot_projections(latents, subdir, **plot_kwargs):
        # Create out_dir + subdir if needed and draw one 2-D plot per pair.
        target = out_dir + subdir
        if not os.path.exists(target):
            os.makedirs(target)
        for (dim0, dim1) in dim_pairs:
            plot_z_time_2d(latents,
                           p["ntp"], [dim0, dim1],
                           target,
                           out_name=f'z_d{dim0}_d{dim1}',
                           **plot_kwargs)

    # Plain projections, then the same projections with c='DX' and c='AGE'.
    _plot_projections(z, 'z_proj/')
    _plot_projections(z, 'z_proj_dx/', c='DX', Y=Y)
    _plot_projections(z, 'z_proj_age/', c='AGE', Y=Y)

    # Sampling
    # TODO: THIS NEEDS TO BE UPDATED
    nt = p['ntp']
    nsamples = 1000
    X_sample = model.sample_latent(nsamples, nt)

    # Same axis swap as for the predictions above.
    X_sample['xnext'] = np.array(X_sample['xnext']).swapaxes(0, 1)
    X_sample['z'] = np.array(X_sample['z']).swapaxes(0, 1)

    _plot_projections(X_sample['z'], 'z_proj_sampling/')

    loss = {
        "mse_train": mse_train,
        "mse_test": mse_test,
        "loss_total": model.loss['total'][-1],
        "loss_kl": model.loss['kl'][-1],
        "loss_ll": model.loss['ll'][-1]
    }

    return loss


def run_experiment(p, csv_path, out_dir):
    """Train a single-channel RNN-VAE on MRI data and evaluate it.

    Everything printed while the experiment runs is written to
    ``out_dir + 'output.out'``. The log file is closed and ``sys.stdout`` is
    restored when the function returns or raises; previously the file was
    never closed and stdout stayed redirected for the rest of the process.

    Args:
        p: dict of hyperparameters. All keys that are read are assumed to be
            present ("seed", "h_size", "hidden", "n_layers", "z_dim", "clip",
            "n_epochs", "batch_size", "learning_rate", "ntp").
            ``p["x_size"]`` is overwritten in place.
        csv_path: path forwarded to ``open_MRI_data``.
        out_dir: output directory, used as a string prefix (``out_dir +
            'output.out'``), so it is expected to end with a path separator.

    Returns:
        dict with "mse_train" and "mse_test" (both hold mean absolute errors;
        the key names are historical), "loss_total", "loss_kl" and "loss_ll".
    """
    import contextlib

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    # Redirect output to the out dir for the duration of the run only.
    with open(out_dir + 'output.out', 'w') as log_file:
        with contextlib.redirect_stdout(log_file):
            return _run_mri_experiment(p, csv_path, out_dir)
def _run_synthetic_experiment(p, out_dir):
    """Body of the synthetic experiment; see ``run_experiment``.

    Runs with ``sys.stdout`` already redirected by the caller.
    """
    # Seed
    torch.manual_seed(p["seed"])
    np.random.seed(p["seed"])

    # Save parameters to the out dir
    with open(out_dir + "params.txt", "w") as f:
        f.write(str(p))

    # DEVICE: first GPU when CUDA is available, CPU otherwise.
    DEVICE_ID = 0
    DEVICE = torch.device(
        'cuda:' + str(DEVICE_ID) if torch.cuda.is_available() else 'cpu')
    if torch.cuda.is_available():
        torch.cuda.set_device(DEVICE_ID)

    # Generate the data: keep the second item of every generated sample.
    gen_model = SinDataGenerator(p["curves"], p["ntp"], p["noise"])
    samples = gen_model.generate_n_samples(p["nsamples"])
    X_train = np.asarray([y for (_, y) in samples])
    # Swap the first two axes for the model input.
    X_train_tensor = torch.FloatTensor(X_train).permute((1, 0, 2))

    # The test set holds 0.8 * nsamples freshly generated samples.
    samples = gen_model.generate_n_samples(int(p["nsamples"] * 0.8))
    X_test = np.asarray([y for (_, y) in samples])
    X_test_tensor = torch.FloatTensor(X_test).permute((1, 0, 2))

    # Define model and optimizer
    model = rnnvae.ModelRNNVAE(p["x_size"], p["h_size"], p["hidden"],
                               p["n_layers"], p["hidden"], p["n_layers"],
                               p["hidden"], p["n_layers"], p["z_dim"],
                               p["hidden"], p["n_layers"], p["clip"],
                               p["n_epochs"], p["batch_size"], DEVICE)

    optimizer = torch.optim.Adam(model.parameters(), lr=p["learning_rate"])
    model.optimizer = optimizer

    model = model.to(DEVICE)
    # Fit the model; the test set is used as validation data.
    model.fit(X_train_tensor.to(DEVICE), X_test_tensor.to(DEVICE))

    # After training, save the model.
    model.save(out_dir, 'model.pt')

    # Forward pass over both sets.
    X_test_fwd = model.predict(X_test_tensor.to(DEVICE))
    X_train_fwd = model.predict(X_train_tensor.to(DEVICE))

    # Swap the first two axes back so samples come first again.
    X_train_fwd['xnext'] = np.array(X_train_fwd['xnext']).swapaxes(0, 1)
    X_train_fwd['z'] = np.array(X_train_fwd['z']).swapaxes(0, 1)
    X_test_fwd['xnext'] = np.array(X_test_fwd['xnext']).swapaxes(0, 1)
    X_test_fwd['z'] = np.array(X_test_fwd['z']).swapaxes(0, 1)

    X_test_hat = X_test_fwd["xnext"]
    X_train_hat = X_train_fwd["xnext"]

    # Mean absolute error averaged over all sequences. The variables and the
    # returned keys keep their historical "mse" names for compatibility, but
    # the printed label now names the metric that is actually computed.
    mse_train = np.mean([
        mean_absolute_error(xval, xhat)
        for (xval, xhat) in zip(X_train, X_train_hat)
    ])
    print('MAE over the train set: ' + str(mse_train))

    mse_test = np.mean([
        mean_absolute_error(xval, xhat)
        for (xval, xhat) in zip(X_test, X_test_hat)
    ])
    print('MAE over the test set: ' + str(mse_test))

    # Training vs. validation curves.
    plot_total_loss(model.loss['total'], model.val_loss['total'], "Total loss",
                    out_dir, "total_loss.png")
    plot_total_loss(model.loss['kl'], model.val_loss['kl'], "kl_loss", out_dir,
                    "kl_loss.png")
    plot_total_loss(model.loss['ll'], model.val_loss['ll'], "ll_loss", out_dir,
                    "ll_loss.png")

    # Visualization of trajectories for one fixed sample / feature.
    subj = 6
    feature = 0

    # For train
    plot_trajectory(
        X_train, X_train_hat, subj, 'all', out_dir,
        f'traj_train_s_{subj}_f_all')  # testing for a given subject
    plot_trajectory(
        X_train, X_train_hat, subj, feature, out_dir,
        f'traj_train_s_{subj}_f_{feature}')  # testing for a given feature

    # For test
    plot_trajectory(X_test, X_test_hat, subj, 'all', out_dir,
                    f'traj_test_s_{subj}_f_all')  # testing for a given subject
    plot_trajectory(
        X_test, X_test_hat, subj, feature, out_dir,
        f'traj_test_s_{subj}_f_{feature}')  # testing for a given feature

    # Latent trajectories of all samples, train first.
    z = list(X_train_fwd['z']) + list(X_test_fwd['z'])

    # Every unordered pair of distinct latent dimensions, dim0 < dim1.
    dim_pairs = [(dim0, dim1)
                 for dim0 in range(p["z_dim"])
                 for dim1 in range(dim0 + 1, p["z_dim"])]

    def _plot_projections(latents, subdir):
        # Create out_dir + subdir if needed and draw one 2-D plot per pair.
        target = out_dir + subdir
        if not os.path.exists(target):
            os.makedirs(target)
        for (dim0, dim1) in dim_pairs:
            plot_z_time_2d(latents,
                           p["ntp"], [dim0, dim1],
                           target,
                           out_name=f'z_d{dim0}_d{dim1}')

    _plot_projections(z, 'z_proj/')

    # Sampling: generate 500 new samples, keep only their first timepoint,
    # and let the model predict the sequence for p['ntp'] timepoints.
    gen_model = SinDataGenerator(p["curves"], p["ntp"], p["noise"])
    samples = gen_model.generate_n_samples(500)
    X_samples = np.asarray([y[:1] for (_, y) in samples])
    X_samples = torch.FloatTensor(X_samples).permute((1, 0, 2))

    X_sample = model.sequence_predict(X_samples.to(DEVICE), p['ntp'])

    # Same axis swap as for the predictions above.
    X_sample['xnext'] = np.array(X_sample['xnext']).swapaxes(0, 1)
    X_sample['z'] = np.array(X_sample['z']).swapaxes(0, 1)

    # Plot the samples over time
    plot_many_trajectories(X_sample['xnext'], 'all', p["ntp"], out_dir,
                           'x_samples')

    _plot_projections(X_sample['z'], 'z_proj_sampling/')

    loss = {
        "mse_train": mse_train,
        "mse_test": mse_test,
        "loss_total": model.loss['total'][-1],
        "loss_kl": model.loss['kl'][-1],
        "loss_ll": model.loss['ll'][-1]
    }

    return loss


def run_experiment(p, out_dir):
    """Train a single-channel RNN-VAE on synthetic data and evaluate it.

    Data comes from ``SinDataGenerator``. Everything printed while the
    experiment runs is written to ``out_dir + 'output.out'``. The log file is
    closed and ``sys.stdout`` is restored when the function returns or
    raises; previously the file was never closed and stdout stayed
    redirected for the rest of the process.

    Args:
        p: dict of hyperparameters. All keys that are read are assumed to be
            present ("seed", "curves", "ntp", "noise", "nsamples", "x_size",
            "h_size", "hidden", "n_layers", "z_dim", "clip", "n_epochs",
            "batch_size", "learning_rate").
        out_dir: output directory, used as a string prefix (``out_dir +
            'output.out'``), so it is expected to end with a path separator.

    Returns:
        dict with "mse_train" and "mse_test" (both hold mean absolute errors;
        the key names are historical), "loss_total", "loss_kl" and "loss_ll".
    """
    import contextlib

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    # Redirect output to the out dir for the duration of the run only.
    with open(out_dir + 'output.out', 'w') as log_file:
        with contextlib.redirect_stdout(log_file):
            return _run_synthetic_experiment(p, out_dir)
Exemple #8
0
def run_experiment(p, csv_path, out_dir, data_cols=[]):
    """
    Train a multi-channel RNN-VAE on the data in csv_path and evaluate it.

    The same model is fitted ``ntimes`` (20) consecutive times. After every
    repetition the model is saved, the loss curves are re-plotted and the
    prediction / cross-channel reconstruction metrics are recomputed, so the
    files in out_dir and the returned dict always describe the latest
    repetition.

    p: dict with the hyperparameters. Keys read here: seed, ch_type,
       ch_names, n_channels, h_size, x_hidden, x_n_layers, z_hidden,
       z_n_layers, enc_hidden, enc_n_layers, z_dim, dec_hidden,
       dec_n_layers, clip, n_epochs, batch_size, phi_layers, sig_mean,
       dropout, drop_th, learning_rate. p["n_feats"] is overwritten with
       the per-channel feature count found in the loaded data.
    csv_path: forwarded to load_multimodal_data.
    out_dir: output directory. File names are appended with plain string
       concatenation, so it has to end with a path separator.
    data_cols: forwarded to load_multimodal_data.

    Returns a dict for the last repetition with the keys mae_train,
    rec_train, mae_test, loss_total, loss_kl, loss_ll, dropout_comps (only
    when p["dropout"] is truthy), pred_<ch>_mae (one-element list) for the
    first three channels and recon_<ch>_mae for every channel (an empty list
    when p["n_channels"] <= 1).
    """
    os.makedirs(out_dir, exist_ok=True)

    #Seed
    torch.manual_seed(p["seed"])
    np.random.seed(p["seed"])

    #save parameters to the out dir
    with open(out_dir + "params.txt", "w") as f:
        f.write(str(p))

    # DEVICE
    device_id = 0
    use_cuda = torch.cuda.is_available()
    device = torch.device(f'cuda:{device_id}' if use_cuda else 'cpu')
    if use_cuda:
        torch.cuda.set_device(device_id)

    # LOAD DATA
    # Only the per-channel data is used here; the covariates and the last
    # returned value are not needed.
    X_train, X_test, _, _, _ = load_multimodal_data(csv_path,
                                                    data_cols,
                                                    p["ch_type"],
                                                    train_set=0.9,
                                                    normalize=True,
                                                    return_covariates=True)

    p["n_feats"] = [x[0].shape[1] for x in X_train]

    print('Length of train/test')
    print(len(X_train[0]))
    print(len(X_test[0]))

    #For each channel, pad, create the mask, and append
    X_train_list, mask_train_list = _pad_and_mask(X_train, device)
    X_test_list, mask_test_list = _pad_and_mask(X_test, device)

    # Longest padded sequence over BOTH splits (the original looked at the
    # train tensors twice and never at the test ones).
    ntp = max(x.shape[0] for x in X_train_list + X_test_list)

    model = rnnvae_h.MCRNNVAE(p["h_size"],
                              p["x_hidden"],
                              p["x_n_layers"],
                              p["z_hidden"],
                              p["z_n_layers"],
                              p["enc_hidden"],
                              p["enc_n_layers"],
                              p["z_dim"],
                              p["dec_hidden"],
                              p["dec_n_layers"],
                              p["clip"],
                              p["n_epochs"],
                              p["batch_size"],
                              p["n_channels"],
                              p["ch_type"],
                              p["n_feats"],
                              device,
                              print_every=100,
                              phi_layers=p["phi_layers"],
                              sigmoid_mean=p["sig_mean"],
                              dropout=p["dropout"],
                              dropout_threshold=p["drop_th"])

    model.ch_name = p["ch_names"]
    model.optimizer = torch.optim.Adam(model.parameters(),
                                       lr=p["learning_rate"])
    model = model.to(device)

    # Channels whose last time point is predicted, and how many time points
    # are predicted (forwarded as-is to eval_prediction).
    pred_ch = list(range(3))
    t_pred = 1
    pred_names = [
        name for (k, name) in enumerate(p["ch_names"]) if k in pred_ch
    ]

    # Fit the model
    # FIT IT FOR THE NUMBER OF EPOCHS, X TIMES
    ntimes = 20
    for nrep in range(ntimes):
        print(nrep)
        model.fit(X_train_list, X_test_list, mask_train_list, mask_test_list)

        if p["dropout"]:
            print("Print the dropout")
            print(model.dropout_comp)

        ### After training, save the model!
        model.save(out_dir, 'model.pt')

        # Predict the reconstructions from X_test and X_train
        X_train_fwd = model.predict(X_train_list, mask_train_list, nt=ntp)
        X_test_fwd = model.predict(X_test_list, mask_test_list, nt=ntp)

        #plot training and validation curves
        plot_total_loss(model.loss['total'], model.val_loss['total'],
                        "Total loss", out_dir, "total_loss.png")
        plot_total_loss(model.loss['kl'], model.val_loss['kl'], "kl_loss",
                        out_dir, "kl_loss.png")
        plot_total_loss(model.loss['ll'], model.val_loss['ll'], "ll_loss",
                        out_dir, "ll_loss.png")  #Negative to see downard curve

        #Compute mae and reconstruction loss over each split
        train_loss = model.recon_loss(X_train_fwd,
                                      target=X_train_list,
                                      mask=mask_train_list)
        test_loss = model.recon_loss(X_test_fwd,
                                     target=X_test_list,
                                     mask=mask_test_list)

        # The labels now name the quantity that is actually printed.
        print('MAE over the train set: ' + str(train_loss["mae"]))
        print('Reconstruction loss over the train set: ' +
              str(train_loss["rec_loss"]))

        print('MAE over the test set: ' + str(test_loss["mae"]))
        print('Reconstruction loss over the test set: ' +
              str(test_loss["rec_loss"]))

        # Fresh per-repetition containers: prediction entries first, then
        # the reconstruction ones.
        results = {f"pred_{name}_mae": [] for name in p["ch_names"][:3]}
        results.update({f"recon_{name}_mae": [] for name in p["ch_names"]})

        ######################
        ## Prediction of last time point
        ######################
        print(pred_ch)
        res = eval_prediction(model, X_test, t_pred, pred_ch, device)

        # The original appended to `loss`, which is only bound further
        # down, so the first repetition raised UnboundLocalError. The
        # per-channel prediction error belongs in `results`.
        # NOTE(review): assumes res is indexable by position in pred_ch.
        for (k, name) in enumerate(pred_names):
            results[f'pred_{name}_mae'].append(res[k])

        ############################
        ## Test reconstruction for each channel, using the other ones
        ############################
        if p["n_channels"] > 1:
            n_ch = len(X_test)
            for i in range(n_ch):
                curr_name = p["ch_names"][i]
                av_ch = [k for k in range(n_ch) if k != i]
                mae_rec = eval_reconstruction(model, X_test, X_test_list,
                                              mask_test_list, av_ch, i)
                # Get MAE result for that specific channel over all timepoints
                results[f"recon_{curr_name}_mae"] = mae_rec

        loss = {
            "mae_train": train_loss["mae"],
            "rec_train": train_loss["rec_loss"],
            "mae_test": test_loss["mae"],
            "loss_total": model.loss['total'][-1],
            "loss_kl": model.loss['kl'][-1],
            "loss_ll": model.loss['ll'][-1],
        }

        if p["dropout"]:
            loss["dropout_comps"] = model.dropout_comp

        loss = {**loss, **results}
        print(results)

    # NOTE: the latent-space plots (plot_latent_space coloured by diagnosis
    # and by time point) were already disabled in the original version and
    # still have to be adapted before they can be re-enabled.
    return loss


def _pad_and_mask(channels, device):
    """
    Pad the subjects of every channel to a common length and build the masks.

    channels: iterable with one entry per channel, each a sequence of
       per-subject arrays accepted by torch.FloatTensor.
    device: torch device the resulting tensors are moved to.

    Returns (padded_list, mask_list). Each padded tensor comes from
    pad_sequence(batch_first=False), i.e. the sequence dimension is first.
    Positions that are NaN after padding (the padding itself and any NaN
    already present in the data) are False in the mask and set to 0 in the
    padded tensor.
    """
    padded_list = []
    mask_list = []
    for x_ch in channels:
        tensors = [torch.FloatTensor(t) for t in x_ch]
        padded = nn.utils.rnn.pad_sequence(tensors,
                                           batch_first=False,
                                           padding_value=np.nan)
        observed = ~torch.isnan(padded)
        padded[~observed] = 0
        mask_list.append(observed.to(device))
        padded_list.append(padded.to(device))
    return padded_list, mask_list
def run_experiment(p, csv_path, out_dir, data_cols=[]):
    """
    Train a multi-channel RNN-VAE on the data in csv_path (single fit).

    p: dict with the hyperparameters. Keys read here: seed, ch_type,
       ch_names, n_channels, long_to_bl, h_size, x_hidden, x_n_layers,
       z_hidden, z_n_layers, enc_hidden, enc_n_layers, z_dim, dec_hidden,
       dec_n_layers, clip, n_epochs, batch_size, c_z, phi_layers, sig_mean,
       dropout, drop_th, learning_rate. p["n_feats"] is overwritten with
       the per-channel feature count found in the loaded data.
    csv_path: forwarded to load_multimodal_data.
    out_dir: output directory. File names are appended with plain string
       concatenation, so it has to end with a path separator.
    data_cols: forwarded to load_multimodal_data.

    Returns a dict with the keys mae_train, rec_train, loss_total, loss_kl,
    loss_ll and, when p["dropout"] is truthy, dropout_comps.
    """
    os.makedirs(out_dir, exist_ok=True)

    #Seed
    torch.manual_seed(p["seed"])
    np.random.seed(p["seed"])

    #save parameters to the out dir
    with open(out_dir + "params.txt", "w") as f:
        f.write(str(p))

    # DEVICE
    device_id = 0
    use_cuda = torch.cuda.is_available()
    device = torch.device(f'cuda:{device_id}' if use_cuda else 'cpu')
    if use_cuda:
        torch.cuda.set_device(device_id)

    # LOAD DATA
    # Only the per-channel data is used here; the covariates and the last
    # returned value are not needed.
    X_train, X_test, _, _, _ = load_multimodal_data(csv_path,
                                                    data_cols,
                                                    p["ch_type"],
                                                    train_set=0.95,
                                                    normalize=True,
                                                    return_covariates=True)

    p["n_feats"] = [x[0].shape[1] for x in X_train]

    print('Length of train/test')
    print(len(X_train[0]))

    # Longest raw sequence over BOTH splits. The original inspected X_train
    # twice, so a longer test sequence was ignored when tiling the baseline
    # channels below. The builtin max also avoids building a nested array
    # out of the length lists.
    ntp = max(
        len(seq) for split in (X_train, X_test) for ch in split
        for seq in ch)

    if p["long_to_bl"]:
        # HERE, change bl to long and repeat the values at t0 for ntp
        for (i, ch_type) in enumerate(p["ch_type"]):
            if ch_type != 'bl':
                continue
            for split in (X_train, X_test):
                for (j, seq) in enumerate(split[i]):
                    split[i][j] = np.array([seq[0]] * ntp)

    #For each channel, pad, create the mask, and append
    X_train_list, mask_train_list = _pad_and_mask(X_train, device)
    X_test_list, mask_test_list = _pad_and_mask(X_test, device)

    # From here on ntp is the longest padded sequence over both splits.
    ntp = max(x.shape[0] for x in X_train_list + X_test_list)

    model = rnnvae_h.MCRNNVAE(p["h_size"],
                              p["x_hidden"],
                              p["x_n_layers"],
                              p["z_hidden"],
                              p["z_n_layers"],
                              p["enc_hidden"],
                              p["enc_n_layers"],
                              p["z_dim"],
                              p["dec_hidden"],
                              p["dec_n_layers"],
                              p["clip"],
                              p["n_epochs"],
                              p["batch_size"],
                              p["n_channels"],
                              p["ch_type"],
                              p["n_feats"],
                              p["c_z"],
                              device,
                              print_every=100,
                              phi_layers=p["phi_layers"],
                              sigmoid_mean=p["sig_mean"],
                              dropout=p["dropout"],
                              dropout_threshold=p["drop_th"])

    model.ch_name = p["ch_names"]
    model.optimizer = torch.optim.Adam(model.parameters(),
                                       lr=p["learning_rate"])
    model = model.to(device)

    # Fit the model
    model.fit(X_train_list, X_test_list, mask_train_list, mask_test_list)

    if p["dropout"]:
        print("Print the dropout")
        print(model.dropout_comp)

    ### After training, save the model!
    model.save(out_dir, 'model.pt')

    # Predict the reconstructions from X_train
    X_train_fwd = model.predict(X_train_list, mask_train_list, nt=ntp)

    #plot training and validation curves
    plot_total_loss(model.loss['total'], model.val_loss['total'], "Total loss",
                    out_dir, "total_loss.png")
    plot_total_loss(model.loss['kl'], model.val_loss['kl'], "kl_loss", out_dir,
                    "kl_loss.png")
    plot_total_loss(model.loss['ll'], model.val_loss['ll'], "ll_loss", out_dir,
                    "ll_loss.png")  #Negative to see downard curve

    #Compute mae and reconstruction loss over the train set
    train_loss = model.recon_loss(X_train_fwd,
                                  target=X_train_list,
                                  mask=mask_train_list)

    # The label now names the quantity that is actually printed.
    print('MAE over the train set: ' + str(train_loss["mae"]))
    print('Reconstruction loss over the train set: ' +
          str(train_loss["rec_loss"]))

    loss = {
        "mae_train": train_loss["mae"],
        "rec_train": train_loss["rec_loss"],
        "loss_total": model.loss['total'][-1],
        "loss_kl": model.loss['kl'][-1],
        "loss_ll": model.loss['ll'][-1],
    }

    if p["dropout"]:
        loss["dropout_comps"] = model.dropout_comp

    print(loss)

    return loss


def _pad_and_mask(channels, device):
    """
    Pad the subjects of every channel to a common length and build the masks.

    channels: iterable with one entry per channel, each a sequence of
       per-subject arrays accepted by torch.FloatTensor.
    device: torch device the resulting tensors are moved to.

    Returns (padded_list, mask_list). Each padded tensor comes from
    pad_sequence(batch_first=False), i.e. the sequence dimension is first.
    Positions that are NaN after padding (the padding itself and any NaN
    already present in the data) are False in the mask and set to 0 in the
    padded tensor.
    """
    padded_list = []
    mask_list = []
    for x_ch in channels:
        tensors = [torch.FloatTensor(t) for t in x_ch]
        padded = nn.utils.rnn.pad_sequence(tensors,
                                           batch_first=False,
                                           padding_value=np.nan)
        observed = ~torch.isnan(padded)
        padded[~observed] = 0
        mask_list.append(observed.to(device))
        padded_list.append(padded.to(device))
    return padded_list, mask_list
Exemple #10
0
def run_experiment(p,
                   out_dir,
                   gen_data=True,
                   data_suffix=None,
                   output_to_file=False):
    """
    Train a multi-channel RNN-VAE on synthetic data and plot its latent space.

    p: dict with the hyperparameters. Keys read here: seed, ntp, noise,
       lat_dim, n_channels, n_feats, nsamples, h_size, enc_hidden,
       enc_n_layers, z_dim, dec_hidden, dec_n_layers, clip, n_epochs,
       batch_size, ch_type, c_z, phi_layers, sig_mean, dropout, drop_th,
       ch_names, learning_rate. p["n_feats"] is replaced by a list that
       repeats its value once per channel.
    out_dir: output directory. File names are appended with plain string
       concatenation, so it has to end with a path separator.
    gen_data: when truthy the data is LOADED from data/synth_data/ using
       data_suffix; when falsy it is generated with LatentTemporalGenerator.
       NOTE(review): this looks inverted with respect to the parameter name
       (the defaults try to load "ztrainNone"), but it is kept as-is so
       existing callers keep working - confirm the intended meaning.
    data_suffix: suffix of the pickled files, only used when gen_data is
       truthy.
    output_to_file: when truthy everything printed during the run goes to
       out_dir + 'output.out'. The file is closed and sys.stdout restored
       when the run ends (the original left sys.stdout pointing at a file
       that was never closed).

    Returns a dict with the last training values of loss_total, loss_kl and
    loss_ll.
    """
    import contextlib

    os.makedirs(out_dir, exist_ok=True)

    if not output_to_file:
        return _run_synth_experiment(p, out_dir, gen_data, data_suffix)

    #Redirect output to the out dir, only for the duration of the run
    with open(out_dir + 'output.out', 'w') as log_file, \
            contextlib.redirect_stdout(log_file):
        return _run_synth_experiment(p, out_dir, gen_data, data_suffix)


def _run_synth_experiment(p, out_dir, gen_data, data_suffix):
    """
    Body of the synthetic-data experiment; see run_experiment for arguments.

    out_dir must already exist.
    """
    #Seed
    torch.manual_seed(p["seed"])
    np.random.seed(p["seed"])

    #save parameters to the out dir
    with open(out_dir + "params.txt", "w") as f:
        f.write(str(p))

    # DEVICE
    device_id = 0
    use_cuda = torch.cuda.is_available()
    device = torch.device(f'cuda:{device_id}' if use_cuda else 'cpu')
    if use_cuda:
        torch.cuda.set_device(device_id)

    if gen_data:
        # See the NOTE(review) in run_experiment: this branch loads from disk.
        synth_dir = 'data/synth_data/'
        Z_train = pickle_load(synth_dir + f"ztrain{data_suffix}")
        Z_test = pickle_load(synth_dir + f"ztest{data_suffix}")
        X_train = pickle_load(synth_dir + f"xtrain{data_suffix}")
        X_test = pickle_load(synth_dir + f"xtest{data_suffix}")
    else:
        lat_gen = LatentTemporalGenerator(p["ntp"], p["noise"], p["lat_dim"],
                                          p["n_channels"], p["n_feats"])
        Z_train, X_train = lat_gen.generate_samples(p["nsamples"])
        Z_test, X_test = lat_gen.generate_samples(int(p["nsamples"] * 0.2),
                                                  train=False)

    # Save the data used to the output disk
    to_save = [Z_train, X_train, Z_test, X_test]
    filenames = ["ztrain", "xtrain", "ztest", "xtest"]
    for (payload, stem) in zip(to_save, filenames):
        pickle_dump(payload, out_dir + stem)

    #pad the data and build the mask corresponding to each channel
    X_train_list, mask_train_list = _pad_and_mask(X_train[:p["n_channels"]],
                                                  device)
    X_test_list, mask_test_list = _pad_and_mask(X_test[:p["n_channels"]],
                                                device)

    # One feature count per channel.
    # NOTE: this rewrites the caller's dict, so running again with the same
    # p nests the list one level deeper.
    p["n_feats"] = [p["n_feats"] for _ in range(p["n_channels"])]

    # Prepare model
    # Define model and optimizer
    model = rnnvae_s.MCRNNVAE(p["h_size"],
                              p["enc_hidden"],
                              p["enc_n_layers"],
                              p["z_dim"],
                              p["dec_hidden"],
                              p["dec_n_layers"],
                              p["clip"],
                              p["n_epochs"],
                              p["batch_size"],
                              p["n_channels"],
                              p["ch_type"],
                              p["n_feats"],
                              p["c_z"],
                              device,
                              print_every=100,
                              phi_layers=p["phi_layers"],
                              sigmoid_mean=p["sig_mean"],
                              dropout=p["dropout"],
                              dropout_threshold=p["drop_th"])

    model.ch_name = p["ch_names"]
    model.optimizer = torch.optim.Adam(model.parameters(),
                                       lr=p["learning_rate"])
    model = model.to(device)

    # Fit the model
    model.fit(X_train_list, X_test_list, mask_train_list, mask_test_list)

    if p["dropout"]:
        print("Print the dropout")
        print(model.dropout_comp)

    ### After training, save the model!
    model.save(out_dir, 'model.pt')

    # Predict the reconstructions from X_test and X_train. The test output is
    # not used further; it gets its own name instead of overwriting the
    # padded test inputs as the original did.
    X_test_fwd = model.predict(X_test_list, mask_test_list, nt=p["ntp"])
    X_train_fwd = model.predict(X_train_list, mask_train_list, nt=p["ntp"])

    #plot training and validation curves
    plot_total_loss(model.loss['total'], model.val_loss['total'], "Total loss",
                    out_dir, "total_loss.png")
    plot_total_loss(model.loss['kl'], model.val_loss['kl'], "kl_loss", out_dir,
                    "kl_loss.png")
    plot_total_loss(model.loss['ll'], model.val_loss['ll'], "ll_loss", out_dir,
                    "ll_loss.png")  #Negative to see downard curve

    ##Latent space
    # Swap the first two axes of every channel's latent codes before plotting.
    z_train = [np.array(x).swapaxes(0, 1) for x in X_train_fwd['z']]

    # Dir for projections
    proj_path = 'z_proj/'
    os.makedirs(out_dir + proj_path, exist_ok=True)

    #plot latent space: every unordered pair of latent dimensions per channel
    for ch in range(p["n_channels"]):
        for dim0 in range(p["z_dim"]):
            for dim1 in range(dim0 + 1, p["z_dim"]):
                plot_z_time_2d(z_train[ch],
                               p["ntp"], [dim0, dim1],
                               out_dir + proj_path,
                               out_name=f'z_ch_{ch}_d{dim0}_d{dim1}')

    # Dir for projections of samples (created, nothing is written to it here)
    sampling_path = 'z_proj_sampling/'
    os.makedirs(out_dir + sampling_path, exist_ok=True)

    # Test the new function of latent space
    qzx = [np.array(x) for x in X_train_fwd['qzx']]

    # Get classificator labels, for n time points: the time index as a
    # string, repeated nsamples times for each time point.
    classif = np.array(
        [str(tp) for tp in range(p["ntp"]) for _ in range(p["nsamples"])])
    print("on_classif")
    print(p["ntp"])
    print(len(classif))

    out_dir_sample = out_dir + 'test_zspace_function/'
    os.makedirs(out_dir_sample, exist_ok=True)

    # TODO: adapt plot_latent_space to this model; once done it should be
    # called here with qzx, classif and out_dir_sample.

    loss = {
        "loss_total": model.loss['total'][-1],
        "loss_kl": model.loss['kl'][-1],
        "loss_ll": model.loss['ll'][-1]
    }

    return loss


def _pad_and_mask(channels, device):
    """
    Pad the subjects of every channel to a common length and build the masks.

    channels: iterable with one entry per channel, each a sequence of
       per-subject arrays accepted by torch.FloatTensor.
    device: torch device the resulting tensors are moved to.

    Returns (padded_list, mask_list). Each padded tensor comes from
    pad_sequence(batch_first=False), i.e. the sequence dimension is first.
    Positions that are NaN after padding (the padding itself and any NaN
    already present in the data) are False in the mask and set to 0 in the
    padded tensor.
    """
    padded_list = []
    mask_list = []
    for x_ch in channels:
        tensors = [torch.FloatTensor(t) for t in x_ch]
        padded = nn.utils.rnn.pad_sequence(tensors,
                                           batch_first=False,
                                           padding_value=np.nan)
        observed = ~torch.isnan(padded)
        padded[~observed] = 0
        mask_list.append(observed.to(device))
        padded_list.append(padded.to(device))
    return padded_list, mask_list
Exemple #11
0
def run_experiment(p, csv_path, out_dir, data_cols=[]):
    """
    Cross-validated training and evaluation of a multi-channel RNN-VAE.

    The splits come from load_multimodal_data_cv(nsplit=10), but only the
    first 5 folds are run. For every fold a model is trained, saved and
    evaluated in its own directory (out_dir + '_fold_<n>/'), and everything
    printed while the fold runs is written to 'output.out' in that directory.

    p: dict with the hyperparameters. Keys read here: seed, ch_type,
       ch_names, n_channels, h_size, hidden, n_layers, z_dim, clip,
       n_epochs, batch_size, phi_layers, sig_mean, dropout, drop_th,
       learning_rate. p["n_feats"] is overwritten in every fold with the
       per-channel feature count of that fold's data.
    csv_path: forwarded to load_multimodal_data_cv.
    out_dir: output directory. File names are appended with plain string
       concatenation, so it has to end with a path separator.
    data_cols: forwarded to load_multimodal_data_cv.

    Returns a dict with the mean over the executed folds of mae_train,
    rec_train, mae_test, loss_total, loss_kl, loss_ll, pred_<ch>_mae (first
    three channels) and recon_<ch>_mae (every channel). The dict is also
    printed to the caller's stdout: sys.stdout is restored after each fold
    instead of being left bound to the last fold's (never closed) log file.
    """
    import contextlib

    os.makedirs(out_dir, exist_ok=True)

    #Seed
    torch.manual_seed(p["seed"])
    np.random.seed(p["seed"])

    #save parameters to the out dir
    with open(out_dir + "params.txt", "w") as f:
        f.write(str(p))

    # DEVICE
    device_id = 0
    use_cuda = torch.cuda.is_available()
    device = torch.device(f'cuda:{device_id}' if use_cuda else 'cpu')
    if use_cuda:
        torch.cuda.set_device(device_id)

    # Begin on the CV data
    gen = load_multimodal_data_cv(csv_path,
                                  data_cols,
                                  p["ch_type"],
                                  nsplit=10,
                                  normalize=True)

    # One list per metric; every fold appends its value.
    loss = {
        "mae_train": [],
        "rec_train": [],
        "mae_test": [],
        "loss_total": [],
        "loss_kl": [],
        "loss_ll": [],
    }
    for ch_name in p["ch_names"][:3]:
        loss[f"pred_{ch_name}_mae"] = []
    for ch_name in p["ch_names"]:
        loss[f"recon_{ch_name}_mae"] = []

    # break at 5 iterations, need to do it faster
    max_folds = 5
    fold_n = 0

    # Only the per-channel data of each split is used; the covariates and
    # the last yielded value are ignored.
    for (X_train, X_test, _, _, _) in gen:
        #Create output dir for the fold
        out_dir_cv = out_dir + f'_fold_{fold_n}/'
        os.makedirs(out_dir_cv, exist_ok=True)

        #Redirect output to specific folder, only while the fold runs
        with open(out_dir_cv + 'output.out', 'w') as log_file, \
                contextlib.redirect_stdout(log_file):
            _run_cv_fold(p, X_train, X_test, out_dir_cv, device, loss)

        fold_n += 1
        if fold_n == max_folds:
            break

    # Compute the mean for every param in the loss dict
    loss = {k: np.mean(v) for (k, v) in loss.items()}

    print(loss)
    return loss


def _run_cv_fold(p, X_train, X_test, out_dir_cv, device, loss):
    """
    Train and evaluate one cross-validation fold, appending to ``loss``.

    p: hyperparameter dict (see run_experiment); p["n_feats"] is overwritten.
    X_train, X_test: per-channel lists of per-subject arrays of the fold.
    out_dir_cv: existing output directory of the fold.
    device: torch device used for the tensors and the model.
    loss: dict of lists; one value per metric is appended in place.
    """
    p["n_feats"] = [x[0].shape[1] for x in X_train]

    print('Length of train/test')
    print(len(X_train[0]))
    print(len(X_test[0]))

    #For each channel, pad, create the mask, and append
    X_train_list, mask_train_list = _pad_and_mask(X_train, device)
    X_test_list, mask_test_list = _pad_and_mask(X_test, device)

    # Longest padded sequence over BOTH splits (the original looked at the
    # train tensors twice and never at the test ones).
    ntp = max(x.shape[0] for x in X_train_list + X_test_list)

    # The same hidden size / layer count is used for every sub-network.
    model = rnnvae_h.MCRNNVAE(p["h_size"],
                              p["hidden"],
                              p["n_layers"],
                              p["hidden"],
                              p["n_layers"],
                              p["hidden"],
                              p["n_layers"],
                              p["z_dim"],
                              p["hidden"],
                              p["n_layers"],
                              p["clip"],
                              p["n_epochs"],
                              p["batch_size"],
                              p["n_channels"],
                              p["ch_type"],
                              p["n_feats"],
                              device,
                              print_every=100,
                              phi_layers=p["phi_layers"],
                              sigmoid_mean=p["sig_mean"],
                              dropout=p["dropout"],
                              dropout_threshold=p["drop_th"])

    model.ch_name = p["ch_names"]
    model.optimizer = torch.optim.Adam(model.parameters(),
                                       lr=p["learning_rate"])
    model = model.to(device)

    # Fit the model
    model.fit(X_train_list, X_test_list, mask_train_list, mask_test_list)

    if p["dropout"]:
        print("Print the dropout")
        print(model.dropout_comp)

    ### After training, save the model!
    model.save(out_dir_cv, 'model.pt')

    # Predict the reconstructions from X_test and X_train
    X_train_fwd = model.predict(X_train_list, mask_train_list, nt=ntp)
    X_test_fwd = model.predict(X_test_list, mask_test_list, nt=ntp)

    #plot training and validation curves
    plot_total_loss(model.loss['total'], model.val_loss['total'],
                    "Total loss", out_dir_cv, "total_loss.png")
    plot_total_loss(model.loss['kl'], model.val_loss['kl'], "kl_loss",
                    out_dir_cv, "kl_loss.png")
    plot_total_loss(model.loss['ll'], model.val_loss['ll'], "ll_loss",
                    out_dir_cv,
                    "ll_loss.png")  #Negative to see downard curve

    #Compute mae and reconstruction loss over each split
    train_loss = model.recon_loss(X_train_fwd,
                                  target=X_train_list,
                                  mask=mask_train_list)
    test_loss = model.recon_loss(X_test_fwd,
                                 target=X_test_list,
                                 mask=mask_test_list)

    # The labels now name the quantity that is actually printed.
    print('MAE over the train set: ' + str(train_loss["mae"]))
    print('Reconstruction loss over the train set: ' +
          str(train_loss["rec_loss"]))

    print('MAE over the test set: ' + str(test_loss["mae"]))
    print('Reconstruction loss over the test set: ' +
          str(test_loss["rec_loss"]))

    ######################
    ## Prediction of last time point
    ######################

    # Same test data with the last time point of every subject removed.
    X_test_minus = [[t[:-1, :] for t in x_ch] for x_ch in X_test]
    X_test_list_minus, mask_test_list_minus = _pad_and_mask(
        X_test_minus, device)

    # Run prediction
    X_test_fwd_minus = model.predict(X_test_list_minus,
                                     mask_test_list_minus,
                                     nt=ntp)
    X_test_xnext = X_test_fwd_minus["xnext"]

    # Compare the held-out last time point with the prediction at that
    # position, for the first three channels.
    for (ch_idx, (X_ch, ch)) in enumerate(zip(X_test[:3],
                                              p["ch_names"][:3])):
        print(f'testing for {ch}')
        y_true = []
        y_pred = []
        for (subj_idx, seq) in enumerate(X_ch):
            last_tp = len(seq) - 1
            if last_tp < 1:
                continue  # ignore subjects with only baseline
            y_true.append(seq[-1])
            y_pred.append(X_test_xnext[ch_idx][last_tp, subj_idx, :])

        #save the result
        loss[f'pred_{ch}_mae'].append(mean_absolute_error(y_true, y_pred))

    ############################
    ## Test reconstruction for each channel, using the other ones
    ############################
    if p["n_channels"] > 1:
        n_ch = len(X_test)
        for i in range(n_ch):
            curr_name = p["ch_names"][i]
            av_ch = [k for k in range(n_ch) if k != i]
            # try to reconstruct it from the other ones
            ch_recon = model.predict(X_test_list,
                                     mask_test_list,
                                     nt=ntp,
                                     av_ch=av_ch,
                                     task='recon')

            # swap dims to iterate over subjects, then keep only the time
            # points that exist in the original data of each subject
            recon_by_subj = np.transpose(ch_recon["xnext"][i], (1, 0, 2))
            y_pred = [
                tp for (subj_pred, subj_true) in zip(recon_by_subj, X_test[i])
                for tp in subj_pred[:len(subj_true)]
            ]
            y_true = [tp for subj_true in X_test[i] for tp in subj_true]

            # Get MAE result for that specific channel over all timepoints
            loss[f"recon_{curr_name}_mae"].append(
                mean_absolute_error(y_true, y_pred))

    # Save results in the loss object
    loss["mae_train"].append(train_loss["mae"])
    loss["rec_train"].append(train_loss["rec_loss"])
    # The original stored the TRAIN mae under "mae_test".
    loss["mae_test"].append(test_loss["mae"])
    loss["loss_total"].append(model.loss['total'][-1])
    loss["loss_kl"].append(model.loss['kl'][-1])
    loss["loss_ll"].append(model.loss['ll'][-1])


def _pad_and_mask(channels, device):
    """
    Pad the subjects of every channel to a common length and build the masks.

    channels: iterable with one entry per channel, each a sequence of
       per-subject arrays accepted by torch.FloatTensor.
    device: torch device the resulting tensors are moved to.

    Returns (padded_list, mask_list). Each padded tensor comes from
    pad_sequence(batch_first=False), i.e. the sequence dimension is first.
    Positions that are NaN after padding (the padding itself and any NaN
    already present in the data) are False in the mask and set to 0 in the
    padded tensor.
    """
    padded_list = []
    mask_list = []
    for x_ch in channels:
        tensors = [torch.FloatTensor(t) for t in x_ch]
        padded = nn.utils.rnn.pad_sequence(tensors,
                                           batch_first=False,
                                           padding_value=np.nan)
        observed = ~torch.isnan(padded)
        padded[~observed] = 0
        mask_list.append(observed.to(device))
        padded_list.append(padded.to(device))
    return padded_list, mask_list
Exemple #12
0
def run_experiment(p, csv_path, out_dir, data_cols=None):
    """
    Train a multi-channel RNN-VAE on the whole dataset and plot its latent space.

    All the data returned by load_multimodal_data is used for training
    (train_set=1.0, no validation split). The fitted model, the parameters,
    the loss curves and the latent space plots are written under out_dir.

    p contains all the hyperparameters needed to run the experiment; we assume
    that all of them are present. Keys read here: "seed", "h_size", "hidden",
    "n_layers", "z_dim", "clip", "n_epochs", "batch_size", "n_channels",
    "ch_names" and "learning_rate". p["n_feats"] is overwritten with the number
    of features of each loaded channel.
    csv_path is forwarded to load_multimodal_data.
    out_dir is the output directory, created if missing. A trailing path
    separator is added when absent.
    data_cols is forwarded to load_multimodal_data; defaults to an empty list.

    Returns a dict with keys "mse_train", "rec_train", "loss_total", "loss_kl"
    and "loss_ll". Note that "mse_train" holds train_loss["mae"].
    """
    # Fresh list per call instead of a mutable default shared across calls.
    if data_cols is None:
        data_cols = []

    # Every path below is built by string concatenation, so make sure out_dir
    # ends with a separator (a no-op when it already does).
    out_dir = os.path.join(out_dir, "")
    os.makedirs(out_dir, exist_ok=True)

    #Seed
    torch.manual_seed(p["seed"])
    np.random.seed(p["seed"])

    #save parameters to the out dir (written before p["n_feats"] is set)
    with open(out_dir + "params.txt", "w") as f:
        f.write(str(p))

    # DEVICE
    ## First GPU when CUDA is available, CPU otherwise.
    DEVICE_ID = 0
    cuda_available = torch.cuda.is_available()
    DEVICE = torch.device(f'cuda:{DEVICE_ID}' if cuda_available else 'cpu')
    if cuda_available:
        torch.cuda.set_device(DEVICE_ID)

    # LOAD DATA
    # No validation data for now: everything goes to the train split.
    # X_train is a list with one entry per channel.
    X_train, _, Y_train, _, _ = load_multimodal_data(
        csv_path,
        data_cols,
        train_set=1.0,
        normalize=True,
        return_covariates=True)

    # Number of features of each channel, taken from its first sequence.
    p["n_feats"] = [x[0].shape[1] for x in X_train]

    X_train_list = []
    mask_train_list = []

    #For each channel, pad, create the mask, and append
    for x_ch in X_train:
        X_train_tensor = [torch.FloatTensor(t) for t in x_ch]
        # Pad with NaN so that padded positions can be told apart afterwards.
        X_train_pad = nn.utils.rnn.pad_sequence(X_train_tensor,
                                                batch_first=False,
                                                padding_value=np.nan)
        mask_train = ~torch.isnan(X_train_pad)
        mask_train_list.append(mask_train.to(DEVICE))
        # The model gets zeros instead of NaN; the mask keeps track of them.
        X_train_pad[torch.isnan(X_train_pad)] = 0
        X_train_list.append(X_train_pad.to(DEVICE))

    # Longest padded sequence length across channels.
    ntp = max(x_pad.shape[0] for x_pad in X_train_list)

    model = rnnvae.MCRNNVAE(p["h_size"], p["hidden"], p["n_layers"],
                            p["hidden"], p["n_layers"], p["hidden"],
                            p["n_layers"], p["z_dim"], p["hidden"],
                            p["n_layers"], p["clip"], p["n_epochs"],
                            p["batch_size"], p["n_channels"], p["n_feats"],
                            DEVICE)

    model.ch_name = p["ch_names"]

    optimizer = torch.optim.Adam(model.parameters(), lr=p["learning_rate"])
    model.optimizer = optimizer

    model = model.to(DEVICE)
    # Fit the model. The train data is also passed in the validation slots,
    # as there is no held-out set in this experiment.
    model.fit(X_train_list, X_train_list, mask_train_list, mask_train_list)

    ### After training, save the model!
    model.save(out_dir, 'model.pt')

    # Predict the reconstructions for the train set
    X_train_fwd = model.predict(X_train_list, nt=ntp)

    #plot train and validation loss curves
    plot_total_loss(model.loss['total'], model.val_loss['total'], "Total loss",
                    out_dir, "total_loss.png")
    plot_total_loss(model.loss['kl'], model.val_loss['kl'], "kl_loss", out_dir,
                    "kl_loss.png")
    plot_total_loss(model.loss['ll'], model.val_loss['ll'], "ll_loss", out_dir,
                    "ll_loss.png")

    #Compute reconstruction losses over the train set
    train_loss = model.recon_loss(X_train_fwd,
                                  target=X_train_list,
                                  mask=mask_train_list)

    # NOTE(review): the printed label says MSE but the value is the "mae"
    # entry of recon_loss; strings are kept as they were.
    print('MSE over the train set: ' + str(train_loss["mae"]))
    print('Reconstruction loss over the train set: ' +
          str(train_loss["rec_loss"]))

    # TODO: masking of the per-channel z values and the plot_z_time_2d
    # projections are not working for the multi-channel case; the directory
    # is still created, as before.
    proj_path = 'z_proj/'
    os.makedirs(out_dir + proj_path, exist_ok=True)

    # Latent space plots (plot_latent_space still needs adapting)
    qzx = [np.array(x) for x in X_train_fwd['qzx']]

    print('len qzx')
    print(len(qzx))

    # Get classificator labels, for n time points
    out_dir_sample = out_dir + 'zcomp_ch_dx/'
    os.makedirs(out_dir_sample, exist_ok=True)

    # NOTE(review): dx_dict is never applied below; presumably the labels in
    # Y_train["DX"] already use the CN/MCI/AD names -- confirm.
    dx_dict = {
        "NL": "CN",
        "MCI": "MCI",
        "MCI to NL": "CN",
        "Dementia": "AD",
        "Dementia to MCI": "MCI",
        "NL to MCI": "MCI",
        "NL to Dementia": "AD",
        "MCI to Dementia": "AD"
    }
    # Plain nested lists with one label per time point
    DX = [list(elem) for elem in Y_train["DX"]]

    #Define colors
    pallete_dict = {"CN": "#2a9e1e", "MCI": "#bfbc1a", "AD": "#af1f1f"}

    # Color by DX, all timepoints
    plot_latent_space(model,
                      qzx,
                      ntp,
                      classificator=DX,
                      pallete_dict=pallete_dict,
                      plt_tp='all',
                      all_plots=True,
                      uncertainty=False,
                      savefig=True,
                      out_dir=out_dir_sample,
                      mask=mask_train_list)

    out_dir_sample_t0 = out_dir + 'zcomp_ch_dx_t0/'
    os.makedirs(out_dir_sample_t0, exist_ok=True)

    # Color by DX, first timepoint only
    plot_latent_space(model,
                      qzx,
                      ntp,
                      classificator=DX,
                      pallete_dict=pallete_dict,
                      plt_tp=[0],
                      all_plots=True,
                      uncertainty=False,
                      savefig=True,
                      out_dir=out_dir_sample_t0,
                      mask=mask_train_list)

    # Now plot color by timepoint
    out_dir_sample = out_dir + 'zcomp_ch_tp/'
    os.makedirs(out_dir_sample, exist_ok=True)

    # The label of every visit is its position in the subject sequence.
    classif = [[i for (i, _) in enumerate(elem)] for elem in Y_train["DX"]]
    pallete = sns.color_palette("viridis", ntp)
    pallete_dict = dict(enumerate(pallete))

    plot_latent_space(model,
                      qzx,
                      ntp,
                      classificator=classif,
                      pallete_dict=pallete_dict,
                      plt_tp='all',
                      all_plots=True,
                      uncertainty=False,
                      savefig=True,
                      out_dir=out_dir_sample,
                      mask=mask_train_list)

    # Final values of the training curves plus the reconstruction metrics
    loss = {
        "mse_train": train_loss["mae"],
        "rec_train": train_loss["rec_loss"],
        "loss_total": model.loss['total'][-1],
        "loss_kl": model.loss['kl'][-1],
        "loss_ll": model.loss['ll'][-1]
    }

    return loss
Exemple #13
0
def run_experiment(p, csv_path, out_dir, data_cols='_mri_vol'):
    """
    Train a single-channel RNN-VAE on MRI data and evaluate it.

    The data is loaded with open_MRI_data_var (train_set=0.9). The last time
    point of every train subject is held out from training and used at the end
    as the target of a one-step-ahead prediction. The fitted model, the
    parameters, the loss curves and the latent space projections are written
    under out_dir.

    p contains all the hyperparameters needed to run the experiment; we assume
    that all of them are present. Keys read here: "seed", "h_size", "hidden",
    "n_layers", "z_dim", "clip", "n_epochs", "batch_size" and "learning_rate".
    p["x_size"] and p["ntp"] are overwritten from the loaded data.
    csv_path is forwarded to open_MRI_data_var.
    out_dir is the output directory, created if missing. A trailing path
    separator is added when absent.
    data_cols is forwarded to open_MRI_data_var.

    Returns a dict with keys "mse_train", "mse_test", "mse_predict",
    "loss_total", "loss_kl" and "loss_ll". Note that "mse_train" and
    "mse_test" are computed with mean_absolute_error, and that "mse_predict"
    is NaN when no train subject has at least two time points.
    """
    # Every path below is built by string concatenation, so make sure out_dir
    # ends with a separator (a no-op when it already does).
    out_dir = os.path.join(out_dir, "")
    os.makedirs(out_dir, exist_ok=True)

    #Seed
    torch.manual_seed(p["seed"])
    np.random.seed(p["seed"])

    #save parameters to the out dir (written before x_size and ntp are set)
    with open(out_dir + "params.txt", "w") as f:
        f.write(str(p))

    # DEVICE
    ## First GPU when CUDA is available, CPU otherwise.
    DEVICE_ID = 0
    cuda_available = torch.cuda.is_available()
    DEVICE = torch.device(f'cuda:{DEVICE_ID}' if cuda_available else 'cpu')
    if cuda_available:
        torch.cuda.set_device(DEVICE_ID)

    # LOAD DATA
    X_train, X_test, Y_train, Y_test, _ = open_MRI_data_var(
        csv_path,
        train_set=0.9,
        normalize=True,
        return_covariates=True,
        data_cols=data_cols)

    #Combine test and train Y for later, in the same order as z below
    Y = {k: Y_train[k] + Y_test[k] for k in Y_train}

    # List of (nt, nfeatures) numpy objects
    p["x_size"] = X_train[0].shape[1]
    print(p["x_size"])

    # Apply padding to both X_train and X_test
    # The last point of each train individual is removed: it is the target of
    # the future timepoint prediction at the end.
    X_train_tensor = [torch.FloatTensor(t[:-1, :]) for t in X_train]
    X_train_pad = nn.utils.rnn.pad_sequence(X_train_tensor,
                                            batch_first=False,
                                            padding_value=np.nan)
    X_test_tensor = [torch.FloatTensor(t) for t in X_test]
    X_test_pad = nn.utils.rnn.pad_sequence(X_test_tensor,
                                           batch_first=False,
                                           padding_value=np.nan)

    p["ntp"] = max(X_train_pad.size(0), X_test_pad.size(0))

    # Those datasets are of size [Tmax, Batch_size, nfeatures]
    # Save mask to unpad later when testing
    mask_train = ~torch.isnan(X_train_pad)
    mask_test = ~torch.isnan(X_test_pad)

    # convert to tensor
    mask_train_tensor = torch.BoolTensor(mask_train)
    mask_test_tensor = torch.BoolTensor(mask_test)

    #convert those NaN to zeros
    X_train_pad[torch.isnan(X_train_pad)] = 0
    X_test_pad[torch.isnan(X_test_pad)] = 0

    # Define model and optimizer
    model = rnnvae.ModelRNNVAE(p["x_size"], p["h_size"], p["hidden"],
                               p["n_layers"], p["hidden"], p["n_layers"],
                               p["hidden"], p["n_layers"], p["z_dim"],
                               p["hidden"], p["n_layers"], p["clip"],
                               p["n_epochs"], p["batch_size"], DEVICE)

    optimizer = torch.optim.Adam(model.parameters(), lr=p["learning_rate"])
    model.optimizer = optimizer

    model = model.to(DEVICE)
    # Fit the model, the test split is used as validation data
    model.fit(X_train_pad.to(DEVICE), X_test_pad.to(DEVICE),
              mask_train_tensor.to(DEVICE), mask_test_tensor.to(DEVICE))

    ### After training, save the model!
    model.save(out_dir, 'model.pt')

    # Predict the reconstructions from X_test and X_train
    X_test_fwd = model.predict(X_test_pad.to(DEVICE))
    X_train_fwd = model.predict(X_train_pad.to(DEVICE))

    #Reformulate things: swap the first two axes to iterate over subjects
    X_train_fwd['xnext'] = np.array(X_train_fwd['xnext']).swapaxes(0, 1)
    X_train_fwd['z'] = np.array(X_train_fwd['z']).swapaxes(0, 1)
    X_test_fwd['xnext'] = np.array(X_test_fwd['xnext']).swapaxes(0, 1)
    X_test_fwd['z'] = np.array(X_test_fwd['z']).swapaxes(0, 1)

    X_test_hat = X_test_fwd["xnext"]
    X_train_hat = X_train_fwd["xnext"]

    # Unpad using the masks
    #after masking, need to reshape to (nt, nfeat)
    X_test_hat = [
        X[mask_test[:, i, :]].reshape((-1, p["x_size"]))
        for (i, X) in enumerate(X_test_hat)
    ]
    X_train_hat = [
        X[mask_train[:, i, :]].reshape((-1, p["x_size"]))
        for (i, X) in enumerate(X_train_hat)
    ]

    #Compute mean absolute error over all sequences
    # (the held-out last point is dropped from the train target as well)
    mse_train = np.mean([
        mean_absolute_error(xval[:-1, :], xhat)
        for (xval, xhat) in zip(X_train, X_train_hat)
    ])
    print('MSE over the train set: ' + str(mse_train))

    #Compute mean absolute error over all sequences
    mse_test = np.mean([
        mean_absolute_error(xval, xhat)
        for (xval, xhat) in zip(X_test, X_test_hat)
    ])
    print('MSE over the test set: ' + str(mse_test))

    #plot train and validation loss curves
    plot_total_loss(model.loss['total'], model.val_loss['total'], "Total loss",
                    out_dir, "total_loss.png")
    plot_total_loss(model.loss['kl'], model.val_loss['kl'], "kl_loss", out_dir,
                    "kl_loss.png")
    plot_total_loss(model.loss['ll'], model.val_loss['ll'], "ll_loss", out_dir,
                    "ll_loss.png")

    # NOTE: the plot_trajectory visualizations of single subjects/features
    # are disabled for now.

    z_train = X_train_fwd['z']
    z_test = X_test_fwd['z']

    # select only the existing time points
    # The mask of the first feature is repeated for each latent dimension:
    # tile it p["z_dim"] times and transpose it
    z_test = [
        X[np.tile(mask_test[:, i, 0], (p["z_dim"], 1)).T].reshape(
            (-1, p["z_dim"])) for (i, X) in enumerate(z_test)
    ]
    z_train = [
        X[np.tile(mask_train[:, i, 0], (p["z_dim"], 1)).T].reshape(
            (-1, p["z_dim"])) for (i, X) in enumerate(z_train)
    ]
    z = z_train + z_test

    # Every unordered pair of distinct latent dimensions
    dim_pairs = [(dim0, dim1) for dim0 in range(p["z_dim"])
                 for dim1 in range(dim0 + 1, p["z_dim"])]

    # Dir for projections
    proj_path = out_dir + 'z_proj/'
    os.makedirs(proj_path, exist_ok=True)

    #plot latent space
    for (dim0, dim1) in dim_pairs:
        plot_z_time_2d(z,
                       p["ntp"], [dim0, dim1],
                       proj_path,
                       out_name=f'z_d{dim0}_d{dim1}')

    #plot latent space colored by covariate, one directory per covariate
    for (covariate, subdir) in (('DX', 'z_proj_dx/'), ('AGE', 'z_proj_age/')):
        sampling_path = out_dir + subdir
        os.makedirs(sampling_path, exist_ok=True)

        for (dim0, dim1) in dim_pairs:
            plot_z_time_2d(z,
                           p["ntp"], [dim0, dim1],
                           sampling_path,
                           c=covariate,
                           Y=Y,
                           out_name=f'z_d{dim0}_d{dim1}')

    # Compute MSE of the prediction of the held-out last time point
    # Predict for max+1 and select only the position that we are interested in
    # NOTE: this sequence predict does not work well
    # Y_true and Y_pred are filled together so that they stay aligned when a
    # subject without history is skipped.
    Y_true = []
    Y_pred = []

    for subj in X_train:
        x = torch.FloatTensor(subj[:-1, :]).unsqueeze(1)
        tp = x.size(0)  # observed time points (and timepoint to predict)
        if tp == 0:
            # Single-visit subject: nothing to predict from.
            continue
        X_fwd = model.sequence_predict(x.to(DEVICE), tp + 1)
        Y_pred.append(X_fwd['xnext'][tp, 0, :])  #get predicted point
        Y_true.append(subj[-1, :])

    #Compute mse
    if Y_pred:
        mse_predict = mean_squared_error(Y_true, Y_pred)
    else:
        mse_predict = float("nan")
    print('MSE over a future timepoint prediction: ' + str(mse_predict))

    # TODO: the sampling procedure (model.sample_latent plus projections in
    # 'z_proj_sampling/') needs to be updated before it is enabled again.

    loss = {
        "mse_train": mse_train,
        "mse_test": mse_test,
        "mse_predict": mse_predict,
        "loss_total": model.loss['total'][-1],
        "loss_kl": model.loss['kl'][-1],
        "loss_ll": model.loss['ll'][-1]
    }

    return loss