Exemplo n.º 1
0
def run_experiment(p, csv_path, out_dir, data_cols=None):
    """
    Train and evaluate the model with cross-validation.

    p: dict with all the hyperparameters needed to run the experiment.
        Every key read below is assumed to be present. p["n_feats"] is
        overwritten for each fold with the feature count of each channel.
    csv_path: path of the csv handed to load_multimodal_data_cv.
    out_dir: output directory. It is combined with file names by plain
        string concatenation, so a trailing path separator is needed for the
        files to land inside the directory.
    data_cols: column names handed to load_multimodal_data_cv. Defaults to an
        empty list.

    Returns a dict mapping every metric name to its mean over the folds that
    were actually run.

    Side effects: sys.stdout is redirected to an 'output.out' file (first in
    out_dir, then in each fold directory) and is left pointing at the log of
    the last fold when the function returns.
    """
    # A fresh list per call: a mutable default would be shared between calls.
    if data_cols is None:
        data_cols = []

    os.makedirs(out_dir, exist_ok=True)

    # Seed
    torch.manual_seed(p["seed"])
    np.random.seed(p["seed"])

    # Redirect output to the out dir. The handle is tracked so that it can be
    # closed once a fold-specific log replaces it.
    log_file = open(out_dir + 'output.out', 'w')
    sys.stdout = log_file

    # Save parameters to the out dir
    with open(out_dir + "params.txt", "w") as f:
        f.write(str(p))

    # DEVICE: first GPU when CUDA is available, CPU otherwise.
    DEVICE_ID = 0
    DEVICE = torch.device(
        'cuda:' + str(DEVICE_ID) if torch.cuda.is_available() else 'cpu')
    if torch.cuda.is_available():
        torch.cuda.set_device(DEVICE_ID)

    def pad_and_mask(channels):
        """Pad each channel with NaN, build its mask and zero-fill the NaNs."""
        padded_list = []
        mask_list = []
        for x_ch in channels:
            tensors = [torch.FloatTensor(t) for t in x_ch]
            padded = nn.utils.rnn.pad_sequence(tensors,
                                               batch_first=False,
                                               padding_value=np.nan)
            # The mask is True wherever a real (non-NaN) value is present.
            mask = ~torch.isnan(padded)
            mask_list.append(mask.to(DEVICE))
            padded[torch.isnan(padded)] = 0
            padded_list.append(padded.to(DEVICE))
        return padded_list, mask_list

    # Generator over the CV folds
    gen = load_multimodal_data_cv(csv_path,
                                  data_cols,
                                  p["ch_type"],
                                  nsplit=10,
                                  normalize=True)

    # One list per metric; each fold appends one value.
    loss = {
        "mae_train": [],
        "rec_train": [],
        "mae_test": [],
        "loss_total": [],
        "loss_total_val": [],
        "loss_kl": [],
        "loss_ll": [],
    }
    # Prediction is evaluated on the first three channels only.
    for ch_name in p["ch_names"][:3]:
        loss[f"pred_{ch_name}_mae"] = []
    for ch_name in p["ch_names"]:
        loss[f"recon_{ch_name}_mae"] = []

    # Counter marking the fold
    fold_n = 0

    for X_train, X_test, Y_train, Y_test, mri_col in gen:
        # Create output dir for the fold
        out_dir_cv = out_dir + f'_fold_{fold_n}/'
        os.makedirs(out_dir_cv, exist_ok=True)

        # Redirect output to the fold folder and close the previous log so
        # that file handles do not pile up across folds.
        fold_log = open(out_dir_cv + 'output.out', 'w')
        sys.stdout = fold_log
        log_file.close()
        log_file = fold_log

        p["n_feats"] = [x[0].shape[1] for x in X_train]

        print('Length of train/test')
        print(len(X_train[0]))
        print(len(X_test[0]))

        # Longest sequence over BOTH splits (the original looked at the
        # training split twice and ignored the test split).
        ntp = max(np.max([[len(xi) for xi in x] for x in X_train]),
                  np.max([[len(xi) for xi in x] for x in X_test]))

        if p["long_to_bl"]:
            # Turn baseline ('bl') channels into longitudinal ones by
            # repeating the value at t0 for ntp time points.
            for i in range(len(p["ch_type"])):
                if p["ch_type"][i] == 'bl':
                    for j in range(len(X_train[i])):
                        X_train[i][j] = np.array([X_train[i][j][0]] * ntp)

                    for j in range(len(X_test[i])):
                        X_test[i][j] = np.array([X_test[i][j][0]] * ntp)

        # For each channel: pad, create the mask, and move to the device
        X_train_list, mask_train_list = pad_and_mask(X_train)
        X_test_list, mask_test_list = pad_and_mask(X_test)

        model = rnnvae_h.MCRNNVAE(p["h_size"],
                                  p["x_hidden"],
                                  p["x_n_layers"],
                                  p["z_hidden"],
                                  p["z_n_layers"],
                                  p["enc_hidden"],
                                  p["enc_n_layers"],
                                  p["z_dim"],
                                  p["dec_hidden"],
                                  p["dec_n_layers"],
                                  p["clip"],
                                  p["n_epochs"],
                                  p["batch_size"],
                                  p["n_channels"],
                                  p["ch_type"],
                                  p["n_feats"],
                                  p["c_z"],
                                  DEVICE,
                                  print_every=100,
                                  phi_layers=p["phi_layers"],
                                  sigmoid_mean=p["sig_mean"],
                                  dropout=p["dropout"],
                                  dropout_threshold=p["drop_th"])
        model.ch_name = p["ch_names"]

        optimizer = torch.optim.Adam(model.parameters(), lr=p["learning_rate"])
        model.optimizer = optimizer

        model = model.to(DEVICE)

        # Fit the model
        model.fit(X_train_list, X_test_list, mask_train_list, mask_test_list)

        if p["dropout"]:
            print("Print the dropout")
            print(model.dropout_comp)

        # After training, save the model
        model.save(out_dir_cv, 'model.pt')

        # Forward pass over both splits
        X_train_fwd = model.predict(X_train_list, mask_train_list, nt=ntp)
        X_test_fwd = model.predict(X_test_list, mask_test_list, nt=ntp)

        # Plot training vs. validation curves
        plot_total_loss(model.loss['total'], model.val_loss['total'],
                        "Total loss", out_dir_cv, "total_loss.png")
        plot_total_loss(model.loss['kl'], model.val_loss['kl'], "kl_loss",
                        out_dir_cv, "kl_loss.png")
        plot_total_loss(model.loss['ll'], model.val_loss['ll'], "ll_loss",
                        out_dir_cv, "ll_loss.png")

        # Reconstruction metrics over each split
        train_loss = model.recon_loss(X_train_fwd,
                                      target=X_train_list,
                                      mask=mask_train_list)
        test_loss = model.recon_loss(X_test_fwd,
                                     target=X_test_list,
                                     mask=mask_test_list)

        print('MSE over the train set: ' + str(train_loss["mae"]))
        print('Reconstruction loss over the train set: ' +
              str(train_loss["rec_loss"]))

        print('MSE over the test set: ' + str(test_loss["mae"]))
        # Label fixed: this value belongs to the test set, not the train set.
        print('Reconstruction loss over the test set: ' +
              str(test_loss["rec_loss"]))

        ######################
        ## Prediction of last time point
        ######################
        pred_ch = list(range(3))
        print(pred_ch)
        t_pred = 1
        res = eval_prediction(model, X_test, t_pred, pred_ch, DEVICE)

        # res is indexed by position within the predicted channels.
        pred_names = [
            name for (idx, name) in enumerate(p["ch_names"]) if idx in pred_ch
        ]
        for (idx, ch) in enumerate(pred_names):
            loss[f'pred_{ch}_mae'].append(res[idx])

        ############################
        ## Test reconstruction for each channel, using the other ones
        ############################
        if p["n_channels"] > 1:
            for i in range(len(X_test)):
                curr_name = p["ch_names"][i]
                # Every channel except the one being evaluated
                av_ch = list(range(len(X_test)))
                av_ch.remove(i)
                mae_rec = eval_reconstruction(model, X_test, X_test_list,
                                              mask_test_list, av_ch, i)
                loss[f"recon_{curr_name}_mae"].append(mae_rec)

        # Save results in the loss object
        loss["mae_train"].append(train_loss["mae"])
        loss["rec_train"].append(train_loss["rec_loss"])
        # Fixed: the original stored the train MAE under "mae_test".
        loss["mae_test"].append(test_loss["mae"])
        loss["loss_total"].append(model.loss['total'][-1])
        loss["loss_total_val"].append(model.val_loss['total'][-1])
        loss["loss_kl"].append(model.loss['kl'][-1])
        loss["loss_ll"].append(model.loss['ll'][-1])

        fold_n += 1
        # Only the first two folds are run, to keep the experiment fast.
        if fold_n == 2:
            break

    # Compute the mean for every metric in the loss dict
    for k in loss.keys():
        loss[k] = np.mean(loss[k])

    print(loss)
    return loss
Exemplo n.º 2
0
def run_eval(out_dir,
             test_csv,
             data_cols,
             dropout_threshold_test,
             output_to_file=False):
    """
    Main function to evaluate a model.

    Loads the hyperparameters and the trained model stored in out_dir, runs
    it over the test data and writes tables and figures back to out_dir.

    out_dir: directory where the model is and the results will be stored. It
        is combined with file names by plain string concatenation, so a
        trailing path separator is needed.
    test_csv: where the csv with the test data is stored.
    data_cols: name of channels.
    dropout_threshold_test: dropout threshold set on the model when
        p["dropout"] is truthy.
    output_to_file: when True, sys.stdout is redirected to
        out_dir + 'output.out' and left redirected on return.
    """

    ch_bl = []  # channels that were 'bl' and are expanded to longitudinal

    # Redirect output to the out dir
    if output_to_file:
        sys.stdout = open(out_dir + 'output.out', 'w')

    # Load parameters.
    # SECURITY NOTE(review): eval() executes whatever params.txt contains, so
    # out_dir must come from a trusted source. It is kept because the file is
    # written with str(p), which is not guaranteed to be a pure literal.
    with open(out_dir + "params.txt") as f:
        p = eval(f.read())

    # Whether baseline channels are expanded to longitudinal ones
    long_to_bl = p["long_to_bl"]

    # DEVICE: first GPU when CUDA is available, CPU otherwise.
    DEVICE_ID = 0
    DEVICE = torch.device(
        'cuda:' + str(DEVICE_ID) if torch.cuda.is_available() else 'cpu')
    if torch.cuda.is_available():
        torch.cuda.set_device(DEVICE_ID)

    X_test, _, Y_test, _, col_lists = load_multimodal_data(
        test_csv,
        data_cols,
        p["ch_type"],
        train_set=1.0,
        normalize=True,
        return_covariates=True)
    p["n_feats"] = [x[0].shape[1] for x in X_test]

    # Longest sequence over all channels and subjects
    ntp = np.max([[len(xi) for xi in x] for x in X_test])

    if long_to_bl:
        # Change bl to long by repeating the values at t0 for ntp points
        for i in range(len(p["ch_type"])):
            if p["ch_type"][i] == 'bl':
                for j in range(len(X_test[i])):
                    X_test[i][j] = np.array([X_test[i][j][0]] * ntp)
                ch_bl.append(i)

    X_test_list = []
    mask_test_list = []

    # Process test set: pad with NaN, build the mask, then zero-fill
    for x_ch in X_test:
        X_test_tensor = [torch.FloatTensor(t) for t in x_ch]
        X_test_pad = nn.utils.rnn.pad_sequence(X_test_tensor,
                                               batch_first=False,
                                               padding_value=np.nan)
        mask_test = ~torch.isnan(X_test_pad)
        mask_test_list.append(mask_test.to(DEVICE))
        X_test_pad[torch.isnan(X_test_pad)] = 0
        X_test_list.append(X_test_pad.to(DEVICE))

    model = rnnvae_s.MCRNNVAE(p["h_size"],
                              p["enc_hidden"],
                              p["enc_n_layers"],
                              p["z_dim"],
                              p["dec_hidden"],
                              p["dec_n_layers"],
                              p["clip"],
                              p["n_epochs"],
                              p["batch_size"],
                              p["n_channels"],
                              p["ch_type"],
                              p["n_feats"],
                              p["c_z"],
                              DEVICE,
                              print_every=100,
                              phi_layers=p["phi_layers"],
                              sigmoid_mean=p["sig_mean"],
                              dropout=p["dropout"],
                              dropout_threshold=p["drop_th"])

    model = model.to(DEVICE)
    model.load(out_dir + 'model.pt')
    if p["dropout"]:
        # Show the dropout components and apply the evaluation threshold
        print(model.dropout_comp)
        model.dropout_threshold = dropout_threshold_test

    ## TEST
    X_test_fwd = model.predict(X_test_list, mask_test_list, nt=ntp)

    ######################
    ## Prediction of last time point
    ######################
    pred_ch = list(range(3))
    t_pred = 1
    res = eval_prediction(model, X_test, t_pred, pred_ch, DEVICE)

    # res is indexed by position within the predicted channels.
    pred_names = [
        name for (idx, name) in enumerate(p["ch_names"]) if idx in pred_ch
    ]
    for (idx, ch) in enumerate(pred_names):
        print(f'pred_{ch}_mae: {res[idx]}')

    ############################
    ## Test reconstruction for each channel, using each single channel
    ############################
    n_ch = len(X_test)
    # Row i / column j holds the result for channel i when only channel j is
    # passed as the available channel.
    results = np.zeros((n_ch, n_ch))

    for i in range(n_ch):
        recon_name = p["ch_names"][i]
        for j in range(n_ch):
            source_name = p["ch_names"][j]
            mae_rec = eval_reconstruction(model, X_test, X_test_list,
                                          mask_test_list, [j], i)
            results[i, j] = mae_rec
            print(f"recon_{recon_name}_from{source_name}_mae: {mae_rec}")

    df_crossrec = pd.DataFrame(data=results,
                               index=p["ch_names"],
                               columns=p["ch_names"])
    sns.heatmap(df_crossrec, annot=True, fmt=".2f", vmin=0, vmax=1)
    # tight_layout only has an effect once the heatmap has been drawn.
    plt.tight_layout()
    plt.savefig(out_dir + "figure_crossrecon.png")
    plt.close()
    df_crossrec.to_latex(out_dir + "table_crossrecon.tex")

    ############################
    ## Test reconstruction for each channel, using the rest
    ############################
    results = np.zeros((n_ch, 1))

    for i in range(n_ch):
        # list.remove() works in place and returns None, so the list has to
        # be built first (the original passed None as the channel list).
        av_ch = list(range(n_ch))
        av_ch.remove(i)
        to_recon = p["ch_names"][i]
        mae_rec = eval_reconstruction(model, X_test, X_test_list,
                                      mask_test_list, av_ch, i)
        results[i] = mae_rec
        print(f"recon_{to_recon}_fromall_mae: {mae_rec}")

    df_totalrec = pd.DataFrame(data=results.T, columns=p["ch_names"])
    df_totalrec.to_latex(out_dir + "table_totalrecon.tex")

    ###############################################################
    # PLOTTING, FIRST GENERAL PLOTTING AND THEN SPECIFIC PLOTTING #
    ###############################################################

    qzx_test = [np.array(x) for x in X_test_fwd['qzx']]

    if long_to_bl:
        # For expanded baseline channels keep only the first entry and
        # replace the repeated ones with None.
        for i in ch_bl:
            qzx_test[i] = np.array(
                [qzx if j == 0 else None for j, qzx in enumerate(qzx_test[i])])

    # --- Colour by age ---
    out_dir_sample = out_dir + 'zcomp_ch_age/'
    os.makedirs(out_dir_sample, exist_ok=True)

    # Bin the ages into 8 evenly spaced edges and label each visit with the
    # edge of the bin it falls in.
    age_full = [x for elem in Y_test["AGE_demog"] for x in elem]
    bins = np.linspace(min(age_full), max(age_full), 8)
    age_digitized = [np.digitize(y, bins) for y in Y_test["AGE_demog"]]

    classif_test = [[bins[x - 1] for x in elem] for elem in age_digitized]

    pallete = sns.color_palette("viridis", 8)
    pallete_dict = {bins[i]: value for (i, value) in enumerate(pallete)}

    # If dropout is active, plot only the components kept by the model
    if model.dropout:
        kept_comp = model.kept_components
    else:
        kept_comp = None

    print(kept_comp)

    # Arguments shared by every latent-space plot below
    common_plot_args = dict(all_plots=True,
                            uncertainty=False,
                            comp=kept_comp,
                            savefig=True,
                            mask=mask_test_list)

    plot_latent_space(model,
                      qzx_test,
                      ntp,
                      classificator=classif_test,
                      pallete_dict=pallete_dict,
                      plt_tp='all',
                      out_dir=out_dir_sample + '_test',
                      **common_plot_args)

    # --- Colour by diagnosis ---
    DX_test = [[x for x in elem] for elem in Y_test["DX"]]

    pallete_dict = {"CN": "#2a9e1e", "MCI": "#bfbc1a", "AD": "#af1f1f"}
    out_dir_sample = out_dir + 'zcomp_ch_dx/'
    os.makedirs(out_dir_sample, exist_ok=True)

    plot_latent_space(model,
                      qzx_test,
                      ntp,
                      classificator=DX_test,
                      pallete_dict=pallete_dict,
                      plt_tp='all',
                      out_dir=out_dir_sample + '_test',
                      **common_plot_args)

    # --- Colour by diagnosis, first time point only ---
    out_dir_sample_t0 = out_dir + 'zcomp_ch_dx_t0/'
    os.makedirs(out_dir_sample_t0, exist_ok=True)

    plot_latent_space(model,
                      qzx_test,
                      ntp,
                      classificator=DX_test,
                      pallete_dict=pallete_dict,
                      plt_tp=[0],
                      out_dir=out_dir_sample_t0 + '_test',
                      **common_plot_args)

    # --- Colour by time point ---
    out_dir_sample = out_dir + 'zcomp_ch_tp/'
    os.makedirs(out_dir_sample, exist_ok=True)

    # Each visit is labelled with its index within the subject's sequence
    classif_test = [[tp for (tp, _) in enumerate(elem)]
                    for elem in Y_test["DX"]]

    pallete = sns.color_palette("viridis", ntp)
    pallete_dict = {i: value for (i, value) in enumerate(pallete)}

    plot_latent_space(model,
                      qzx_test,
                      ntp,
                      classificator=classif_test,
                      pallete_dict=pallete_dict,
                      plt_tp='all',
                      out_dir=out_dir_sample + '_test',
                      **common_plot_args)
Exemplo n.º 3
0
def run_experiment(p, csv_path, out_dir, data_cols=None):
    """
    Train the model on a single train/test split and evaluate it.

    The model is fitted 20 times in a row; after every fit it is saved,
    the loss curves are plotted and the evaluation metrics are recomputed.

    p: dict with all the hyperparameters needed to run the experiment.
        Every key read below is assumed to be present. p["n_feats"] is
        overwritten with the feature count of each channel.
    csv_path: path of the csv handed to load_multimodal_data.
    out_dir: output directory. It is combined with file names by plain
        string concatenation, so a trailing path separator is needed for the
        files to land inside the directory.
    data_cols: column names handed to load_multimodal_data. Defaults to an
        empty list.

    Returns the metrics dict computed after the last fit.
    """
    # A fresh list per call: a mutable default would be shared between calls.
    if data_cols is None:
        data_cols = []

    os.makedirs(out_dir, exist_ok=True)

    # Seed
    torch.manual_seed(p["seed"])
    np.random.seed(p["seed"])

    # Save parameters to the out dir
    with open(out_dir + "params.txt", "w") as f:
        f.write(str(p))

    # DEVICE: first GPU when CUDA is available, CPU otherwise.
    DEVICE_ID = 0
    DEVICE = torch.device(
        'cuda:' + str(DEVICE_ID) if torch.cuda.is_available() else 'cpu')
    if torch.cuda.is_available():
        torch.cuda.set_device(DEVICE_ID)

    def pad_and_mask(channels):
        """Pad each channel with NaN, build its mask and zero-fill the NaNs."""
        padded_list = []
        mask_list = []
        for x_ch in channels:
            tensors = [torch.FloatTensor(t) for t in x_ch]
            padded = nn.utils.rnn.pad_sequence(tensors,
                                               batch_first=False,
                                               padding_value=np.nan)
            # The mask is True wherever a real (non-NaN) value is present.
            mask = ~torch.isnan(padded)
            mask_list.append(mask.to(DEVICE))
            padded[torch.isnan(padded)] = 0
            padded_list.append(padded.to(DEVICE))
        return padded_list, mask_list

    # LOAD DATA: 90% of the data for training, no separate validation split
    X_train, X_test, Y_train, Y_test, mri_col = load_multimodal_data(
        csv_path,
        data_cols,
        p["ch_type"],
        train_set=0.9,
        normalize=True,
        return_covariates=True)

    p["n_feats"] = [x[0].shape[1] for x in X_train]

    print('Length of train/test')
    print(len(X_train[0]))
    print(len(X_test[0]))

    # For each channel: pad, create the mask, and move to the device
    X_train_list, mask_train_list = pad_and_mask(X_train)
    X_test_list, mask_test_list = pad_and_mask(X_test)

    # Longest padded sequence over BOTH splits (the original looked at the
    # training tensors twice and ignored the test tensors).
    ntp = max(max([x.shape[0] for x in X_train_list]),
              max([x.shape[0] for x in X_test_list]))

    # NOTE(review): no p["c_z"] argument is passed here; confirm against the
    # MCRNNVAE signature in rnnvae_h.
    model = rnnvae_h.MCRNNVAE(p["h_size"],
                              p["x_hidden"],
                              p["x_n_layers"],
                              p["z_hidden"],
                              p["z_n_layers"],
                              p["enc_hidden"],
                              p["enc_n_layers"],
                              p["z_dim"],
                              p["dec_hidden"],
                              p["dec_n_layers"],
                              p["clip"],
                              p["n_epochs"],
                              p["batch_size"],
                              p["n_channels"],
                              p["ch_type"],
                              p["n_feats"],
                              DEVICE,
                              print_every=100,
                              phi_layers=p["phi_layers"],
                              sigmoid_mean=p["sig_mean"],
                              dropout=p["dropout"],
                              dropout_threshold=p["drop_th"])

    model.ch_name = p["ch_names"]

    optimizer = torch.optim.Adam(model.parameters(), lr=p["learning_rate"])
    model.optimizer = optimizer

    model = model.to(DEVICE)

    # Fit the model for the configured number of epochs, ntimes in a row
    ntimes = 20
    for nrep in range(ntimes):
        print(nrep)
        model.fit(X_train_list, X_test_list, mask_train_list, mask_test_list)

        if p["dropout"]:
            print("Print the dropout")
            print(model.dropout_comp)

        # After training, save the model
        model.save(out_dir, 'model.pt')

        # Forward pass over both splits
        X_train_fwd = model.predict(X_train_list, mask_train_list, nt=ntp)
        X_test_fwd = model.predict(X_test_list, mask_test_list, nt=ntp)

        # Plot training vs. validation curves
        plot_total_loss(model.loss['total'], model.val_loss['total'],
                        "Total loss", out_dir, "total_loss.png")
        plot_total_loss(model.loss['kl'], model.val_loss['kl'], "kl_loss",
                        out_dir, "kl_loss.png")
        plot_total_loss(model.loss['ll'], model.val_loss['ll'], "ll_loss",
                        out_dir, "ll_loss.png")

        # Reconstruction metrics over each split
        train_loss = model.recon_loss(X_train_fwd,
                                      target=X_train_list,
                                      mask=mask_train_list)
        test_loss = model.recon_loss(X_test_fwd,
                                     target=X_test_list,
                                     mask=mask_test_list)

        print('MSE over the train set: ' + str(train_loss["mae"]))
        print('Reconstruction loss over the train set: ' +
              str(train_loss["rec_loss"]))

        print('MSE over the test set: ' + str(test_loss["mae"]))
        # Label fixed: this value belongs to the test set, not the train set.
        print('Reconstruction loss over the test set: ' +
              str(test_loss["rec_loss"]))

        # Per-channel results of this repetition. Entries start as empty
        # lists and are replaced by the computed value when available.
        results = {}
        for ch_name in p["ch_names"][:3]:
            results[f"pred_{ch_name}_mae"] = []
        for ch_name in p["ch_names"]:
            results[f"recon_{ch_name}_mae"] = []

        ######################
        ## Prediction of last time point
        ######################
        pred_ch = list(range(3))
        print(pred_ch)
        t_pred = 1
        res = eval_prediction(model, X_test, t_pred, pred_ch, DEVICE)

        # Fixed: the original appended to `loss` here, which is not assigned
        # until the end of the iteration (UnboundLocalError on the first
        # repetition). The prediction results belong in `results`, like the
        # reconstruction results below.
        pred_names = [
            name for (idx, name) in enumerate(p["ch_names"]) if idx in pred_ch
        ]
        for (idx, ch) in enumerate(pred_names):
            results[f'pred_{ch}_mae'] = res[idx]

        ############################
        ## Test reconstruction for each channel, using the other ones
        ############################
        if p["n_channels"] > 1:
            for i in range(len(X_test)):
                curr_name = p["ch_names"][i]
                # Every channel except the one being evaluated
                av_ch = list(range(len(X_test)))
                av_ch.remove(i)
                mae_rec = eval_reconstruction(model, X_test, X_test_list,
                                              mask_test_list, av_ch, i)
                results[f"recon_{curr_name}_mae"] = mae_rec

        loss = {
            "mae_train": train_loss["mae"],
            "rec_train": train_loss["rec_loss"],
            "mae_test": test_loss["mae"],
            "loss_total": model.loss['total'][-1],
            "loss_kl": model.loss['kl'][-1],
            "loss_ll": model.loss['ll'][-1],
        }

        if p["dropout"]:
            loss["dropout_comps"] = model.dropout_comp

        loss = {**loss, **results}
        print(results)

    # NOTE: a block of latent-space plotting code (coloured by diagnosis and
    # by time point) used to follow here inside a bare string literal, i.e.
    # it never ran. It was removed as dead code.
    return loss
Exemplo n.º 4
0
# IF DROPOUT, CHECK THE COMPONENTS AND THRESHOLD AND CHANGE IT
####################################

##TEST
X_test_fwd = model.predict(X_test_list, mask_test_list, nt=ntp)

# Test the reconstruction and prediction

######################
## Prediction of last time point
######################
# Test data without last timepoint
# X_test_tensors do have the last timepoint
pred_ch = list(range(3))
t_pred = 1
res = eval_prediction(model, X_test, t_pred, pred_ch, DEVICE)

for (i, ch) in enumerate(
    [x for (i, x) in enumerate(p["ch_names"]) if i in pred_ch]):
    print(f'pred_{ch}_mae: {res[i]}')

############################
## Test reconstruction for each channel, using the other one
############################
# For each channel
results = np.zeros(
    (len(X_test), len(X_test)))  #store the results, will save later

for i in range(len(X_test)):
    for j in range(len(X_test)):
        curr_name = p["ch_names"][j]