Exemplo n.º 1
0
    def on_epoch_end(self, model, dataset, img, epoch, **kwargs):
        """Plot and save the latent-space scatter for selected epochs.

        Runs at epoch 1 and every 10th epoch: embeds the data loader's
        samples, pairs each latent point with its nearest neighbours and
        writes a PDF scatter plot into ``self.path``.
        """
        # Only visualize at epoch 1 and every 10th epoch thereafter.
        if epoch != 1 and epoch % 10 != 0:
            return

        model.eval()
        _, labels, latents = get_latentspace_representation(
            model, self.data_loader, device=self.device, bs=True)

        # Pair each latent point with its 3 nearest neighbours
        # (distances returned by kneighbors are not needed).
        neighbours = NearestNeighbors(n_neighbors=3).fit(latents)
        _, pairings = neighbours.kneighbors()

        target = os.path.join(self.path, f'latent_epoch_{epoch}.pdf')
        plot_2Dscatter(latents, labels, pairings,
                       path_to_save=target, title=None, show=False)
Exemplo n.º 2
0
    def on_epoch_end(self, model, data, labels, img, epoch,get_latent, get_reconst, **kwargs):
        """Save latent-space and/or reconstruction scatter plots.

        Encodes ``data`` and, depending on the flags, writes PNG and SVG
        scatter plots of the latent codes and of their reconstructions
        under ``self.path`` (assumes ``self.path`` ends with a separator).
        """
        model.eval()
        latent = model.encode(data)

        def _dump(points, stem):
            # One PNG and one SVG per plot, mirroring the original layout:
            # <path><ext>/<stem>_epoch_<epoch>.<ext>
            for ext in ('png', 'svg'):
                target = self.path + ext + '/' + '{}_epoch_{}'.format(stem, epoch) + '.' + ext
                plot_2Dscatter(points, labels, path_to_save=target, title=None, show=False)

        if get_latent:
            _dump(latent, 'latent')

        if get_reconst:
            _dump(model.decode(latent), 'reconst')
Exemplo n.º 3
0
def train(model, data_train, data_test, config, device, quiet, val_size, _seed,
          _rnd, _run, rundir):
    """Sacred-wrapped training entry point.

    Splits ``data_train`` into train/validation sets, trains ``model`` with
    logging callbacks, reloads the checkpoint written during training, plots
    loss/metric curves and optionally runs latent-space evaluation.

    Args:
        model: Autoencoder model (torch ``nn.Module``-like).
        data_train: Full training dataset; split into train/validation.
        data_test: Held-out test dataset.
        config: Experiment configuration with an ``eval`` sub-config.
        device: Torch device used for training and evaluation.
        quiet: If True, suppress progress printing.
        val_size: Size/fraction of the validation split.
        _seed: Sacred-injected seed (used for quantitative evaluation).
        _rnd: Sacred-injected random state (used for the validation split).
        _run: Sacred run object receiving logged values.
        rundir: Output directory for artifacts, or None to disable saving.

    Returns:
        Dict of averaged logged values plus mean per-epoch runtime.
    """
    if rundir is not None:
        # exist_ok replaces the old bare try/except around makedirs.
        os.makedirs(rundir, exist_ok=True)

    train_dataset, validation_dataset = split_validation(
        data_train, val_size, _rnd)
    test_dataset = data_test

    callbacks = [
        LogTrainingLoss(_run, print_progress=not quiet),
        LogDatasetLoss('validation',
                       validation_dataset,
                       _run,
                       method_args=config.method_args,
                       print_progress=not quiet,
                       batch_size=config.batch_size,
                       early_stopping=config.early_stopping,
                       save_path=rundir,
                       device=device),
        LogDatasetLoss('testing',
                       test_dataset,
                       _run,
                       method_args=config.method_args,
                       print_progress=not quiet,
                       batch_size=config.batch_size,
                       device=device),
    ]

    if not quiet:
        callbacks.append(NewlineCallback())

    # If we are logging this run, save latent visualizations during training.
    if rundir is not None and config.eval.online_visualization:
        callbacks.append(
            SaveLatentRepresentation(test_dataset,
                                     rundir,
                                     batch_size=64,
                                     device=device))

    training_loop = TrainingLoop(model,
                                 train_dataset,
                                 config.n_epochs,
                                 config.batch_size,
                                 config.learning_rate,
                                 config.method_args,
                                 config.weight_decay,
                                 device,
                                 callbacks,
                                 verbose=not quiet,
                                 num_threads=config.num_threads)

    # Run training
    epoch, run_times_epoch = training_loop()

    if rundir:
        # Reload the checkpoint written during training (e.g. by early
        # stopping) so evaluation uses the saved — not the last — weights.
        if not quiet:
            print('Loading model checkpoint prior to evaluation...')
        state_dict = torch.load(os.path.join(rundir, 'model_state.pth'))
        model.load_state_dict(state_dict)
    model.eval()

    # Split the logged series into losses and metrics for separate plots.
    logged_averages = callbacks[0].logged_averages
    logged_stds = callbacks[0].logged_stds
    loss_averages = {
        key: value
        for key, value in logged_averages.items() if 'loss' in key
    }
    loss_stds = {
        key: value
        for key, value in logged_stds.items() if 'loss' in key
    }
    metric_averages = {
        key: value
        for key, value in logged_averages.items() if 'metric' in key
    }
    metric_stds = {
        key: value
        for key, value in logged_stds.items() if 'metric' in key
    }

    if rundir:
        plot_losses(
            loss_averages,
            loss_stds,
            save_file=os.path.join(rundir, 'loss.pdf'),
        )
        plot_losses(metric_averages,
                    metric_stds,
                    save_file=os.path.join(rundir, 'metrics.pdf'),
                    pairs_axes=True)

    # Final (last-epoch) value of every logged series.
    result = {key: values[-1] for key, values in logged_averages.items()}

    if config.eval.active:
        evaluate_on = config.eval.evaluate_on
        if evaluate_on == 'validation':
            selected_dataset = validation_dataset
        else:
            selected_dataset = test_dataset

        dataloader_eval = torch.utils.data.DataLoader(
            selected_dataset,
            batch_size=config.batch_size,
            pin_memory=True,
            drop_last=False)

        X_eval, Y_eval, Z_eval = get_latentspace_representation(
            model, dataloader_eval, device=device)

        if config.eval.eval_manifold:
            # Sample the true manifold; use the train split only when
            # evaluating on the validation set.
            manifold_eval_train = evaluate_on == 'validation'
            try:
                dataset = config.dataset
                Z_manifold, X_transformed, labels = dataset.sample_manifold(
                    **config.sampling_kwargs, train=manifold_eval_train)

                dataset_test = TensorDataset(torch.Tensor(X_transformed),
                                             torch.Tensor(labels))
                dataloader_eval = torch.utils.data.DataLoader(
                    dataset_test,
                    batch_size=config.batch_size,
                    pin_memory=True,
                    drop_last=False)
                X_eval, Y_eval, Z_latent = get_latentspace_representation(
                    model, dataloader_eval, device=device)

                # Min-max normalize both embeddings per coordinate so the
                # pairwise-distance comparison below is scale-invariant.
                # NOTE: this mutates Z_manifold / Z_latent in place.
                for dim in (0, 1):
                    col = Z_manifold[:, dim]
                    Z_manifold[:, dim] = (col - col.min()) / (col.max() -
                                                              col.min())
                    col = Z_latent[:, dim]
                    Z_latent[:, dim] = (col - col.min()) / (col.max() -
                                                            col.min())

                pwd_Z = pairwise_distances(Z_latent, Z_latent, n_jobs=1)
                pwd_Ztrue = pairwise_distances(Z_manifold,
                                               Z_manifold,
                                               n_jobs=1)

                # normalize distances
                pairwise_distances_manifold = (pwd_Ztrue - pwd_Ztrue.min()) / (
                    pwd_Ztrue.max() - pwd_Ztrue.min())
                pairwise_distances_Z = (pwd_Z - pwd_Z.min()) / (pwd_Z.max() -
                                                                pwd_Z.min())

                # save comparison fig
                plot_distcomp_Z_manifold(
                    Z_manifold=Z_manifold,
                    Z_latent=Z_latent,
                    pwd_manifold=pairwise_distances_manifold,
                    pwd_Z=pairwise_distances_Z,
                    labels=labels,
                    path_to_save=rundir,
                    name='manifold_Z_distcomp',
                    fontsize=24,
                    show=False)

                rmse_manifold = (np.square(pairwise_distances_manifold -
                                           pairwise_distances_Z)).mean()
                result.update(dict(rmse_manifold_Z=rmse_manifold))
            except AttributeError as err:
                # Datasets without sample_manifold() are skipped gracefully.
                print(err)
                print('Manifold not evaluated!')

        if rundir and config.eval.save_eval_latent:
            df = pd.DataFrame(Z_eval)
            df['labels'] = Y_eval
            df.to_csv(os.path.join(rundir, 'latents.csv'), index=False)
            # BUGFIX: the latents/labels keys were previously swapped here
            # (latents=Y_eval, labels=Z_eval), unlike the train branch below.
            np.savez(os.path.join(rundir, 'latents.npz'),
                     latents=Z_eval,
                     labels=Y_eval)
            plot_2Dscatter(Z_eval,
                           Y_eval,
                           path_to_save=os.path.join(
                               rundir, 'test_latent_visualization.pdf'),
                           title=None,
                           show=False)

        if rundir and config.eval.save_train_latent:
            dataloader_train = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=config.batch_size,
                pin_memory=True,
                drop_last=False)
            X_train, Y_train, Z_train = get_latentspace_representation(
                model, dataloader_train, device=device)

            df = pd.DataFrame(Z_train)
            df['labels'] = Y_train
            df.to_csv(os.path.join(rundir, 'train_latents.csv'), index=False)
            np.savez(os.path.join(rundir, 'latents.npz'),
                     latents=Z_train,
                     labels=Y_train)
            # Visualize latent space
            plot_2Dscatter(Z_train,
                           Y_train,
                           path_to_save=os.path.join(
                               rundir, 'train_latent_visualization.pdf'),
                           title=None,
                           show=False)

        if config.eval.quant_eval:
            ks = list(
                range(config.eval.k_min,
                      config.eval.k_max + config.eval.k_step,
                      config.eval.k_step))

            evaluator = Multi_Evaluation(dataloader=dataloader_eval,
                                         seed=_seed,
                                         model=model)
            ev_result = evaluator.get_multi_evals(X_eval,
                                                  Z_eval,
                                                  Y_eval,
                                                  ks=ks)
            prefixed_ev_result = {
                config.eval.evaluate_on + '_' + key: value
                for key, value in ev_result.items()
            }
            result.update(prefixed_ev_result)
            s = json.dumps(result, default=default)
            # Context manager ensures the file handle is closed.
            with open(os.path.join(rundir, 'eval_metrics.json'), "w") as fp:
                fp.write(s)

    result_avg = avg_array_in_dict(result)
    result_avg.update({'run_times_epoch': statistics.mean(run_times_epoch)})

    return result_avg
        # NOTE(review): fragment — the enclosing function/conditional starts
        # above this excerpt; `path_norm2`, `exp_norm2`, `state_dict`,
        # `model`, `images`, `x_position`, `y_position` and `path_save` are
        # presumably defined there.
        # Borrow the latent normalization value from a second checkpoint,
        # scaled by 0.1, when the first state dict lacks it.
        state_dict2 = torch.load(
            os.path.join(os.path.join(path_norm2, exp_norm2),
                         'model_state.pth'))
        # NOTE(review): this tests for the key 'latent', but the key written
        # below is 'latent_norm' — so the branch is almost always taken.
        # Looks deliberate elsewhere in this codebase; confirm.
        if 'latent' not in state_dict:
            state_dict['latent_norm'] = state_dict2['latent_norm'] * 0.1

        print('passed')

    model.load_state_dict(state_dict)
    model.eval()

    # Encode the images and scatter-plot the latent codes twice, coloured by
    # x- and by y-position respectively, with timestamped output filenames.
    z = model.encode(images.float())
    debugging = True  # unused flag, apparently left over from debugging

    plot_2Dscatter(z.detach().numpy(),
                   x_position.detach().numpy(),
                   path_to_save=os.path.join(
                       path_save,
                       'x_vis_{}.pdf'.format(time.strftime("%Y%m%d-%H%M%S"))),
                   title=None,
                   show=True,
                   palette='x')
    plot_2Dscatter(z.detach().numpy(),
                   y_position.detach().numpy(),
                   path_to_save=os.path.join(
                       path_save,
                       'y_vis_{}.pdf'.format(time.strftime("%Y%m%d-%H%M%S"))),
                   title=None,
                   show=True,
                   palette='y')
Exemplo n.º 5
0
    # NOTE(review): fragment — the script opening (including `topoae_2d_kl01`,
    # used in the loop below) lies above this excerpt.
    # Hard-coded checkpoint directories for three TopoAE minimizer variants.
    topoae_2d_std = '/Users/simons/MT_data/eval_data/MNIST_FINAL/TopoAE/std_minimizer/MNIST_offline-seed1988-DeepAE_MNIST-default-lr1_1000-bs64-nep1000-rlw1-tlw1_256-seed1988-1ae25b75'
    topoae_2d_cont = '/Users/simons/MT_data/eval_data/MNIST_FINAL/TopoAE/cont_minimizer/MNIST_offline-seed579-DeepAE_MNIST-default-lr1_1000-bs128-nep1000-rlw1-tlw1_128-seed579-3c78a835'
    topoae_2d_rec = '/Users/simons/MT_data/eval_data/MNIST_FINAL/TopoAE/rec_minimizer/MNIST_offline-seed838-DeepAE_MNIST-default-lr1_1000-bs256-nep1000-rlw1-tlw1_16-seed838-67de8d97'

    # For each trained model: load its weights, embed the MNIST test split
    # and save a 2D latent scatter plot into the experiment directory.
    for path_exp in [
            topoae_2d_kl01, topoae_2d_std, topoae_2d_cont, topoae_2d_rec
    ]:
        # get model
        autoencoder = DeepAE_MNIST()

        model = TopologicallyRegularizedAutoencoder(autoencoder)
        # Checkpoints were trained elsewhere; map to CPU for local use.
        state_dict = torch.load(os.path.join(path_exp, 'model_state.pth'),
                                map_location=torch.device('cpu'))

        model.load_state_dict(state_dict)
        model.eval()

        dataset = MNIST_offline()
        data, labels = dataset.sample(train=False)

        z = model.encode(torch.Tensor(data).float())

        plot_2Dscatter(z.detach().numpy(),
                       labels,
                       path_to_save=os.path.join(
                           path_exp,
                           '{}_latent_visualization.pdf'.format('final')),
                       title=None,
                       show=False,
                       palette='custom2')
Exemplo n.º 6
0
        # get reconstructed images
        # Happy path (the matching try: starts above this excerpt): latents
        # were read from disk as a DataFrame with columns '0', '1', 'labels'.
        labels = latents['labels']
        latents_tensor = torch.tensor(latents[['0', '1']].values)
    except:
        # NOTE(review): bare except inherited from the original snippet — it
        # swallows any failure of the CSV path above, not just a missing file.
        # Fallback: recompute latents by pushing the saved training
        # dataloader through the model on CPU.
        dataloarder_train = torch.load(
            os.path.join(
                '/Users/simons/PycharmProjects/MT-VAEs-TDA/src/datasets/simulated/openai_rotating',
                'dataloader_train.pt'))

        X_eval, Y_eval, Z_eval = get_latentspace_representation(
            model, dataloarder_train, device='cpu')

        plot_2Dscatter(Z_eval,
                       Y_eval,
                       palette='hsv',
                       path_to_save=os.path.join(
                           path_source, '{}.pdf'.format('latentes_cyclic')),
                       title=None,
                       show=True)
        latents_tensor = torch.tensor(Z_eval)

    # Decode the latent codes back into images.
    x_hat = model.decode(latents_tensor.float())

    #x_hat = model.decode(torch.tensor(latents_temp[:][['0', '1']].values).float())
    trans = transforms.ToPILImage()

    # Show and save every 20th reconstruction (12 frames; filenames suggest
    # one per 20 degrees of rotation — confirm against the dataset).
    for i in range(12):
        ii = 20 * i
        plt.imshow(trans(x_hat[ii][:][:][:]))
        plt.savefig(os.path.join(path_source, '{}deg.pdf'.format(ii)))
        plt.show()
Exemplo n.º 7
0
    # Ensure the output subdirectory exists (errors ignored if it already
    # does — bare except inherited from the original snippet).
    try:
        os.mkdir(os.path.join(path_source, 'latents_inproc'))
    except:
        pass

    # Build the witness-complex autoencoder and load its trained weights,
    # mapped to CPU.
    autoencoder = ConvAE_Unity480320()
    model = WitnessComplexAutoencoder(autoencoder)
    state_dict = torch.load(os.path.join(path_source, 'model_state.pth'),
                            map_location=torch.device('cpu'))

    # Borrow latent_norm from a second experiment's checkpoint (scaled 0.1)
    # when missing. NOTE(review): the check is for key 'latent' while the
    # key written is 'latent_norm' — confirm this is intended.
    state_dict2 = torch.load(
        os.path.join(os.path.join(root_path, exp1), 'model_state.pth'))
    if 'latent' not in state_dict:
        state_dict['latent_norm'] = state_dict2['latent_norm'] * 0.1

    model.load_state_dict(state_dict)
    model.eval()

    # `dataloarder_train` (sic) is presumably defined above this excerpt.
    X_eval, Y_eval, Z_eval = get_latentspace_representation(model,
                                                            dataloarder_train,
                                                            device='cpu')

    # Scatter-plot the latent codes with a timestamped filename.
    plot_2Dscatter(Z_eval,
                   Y_eval,
                   path_to_save=os.path.join(
                       path_source, 'latents_inproc',
                       '{}.pdf'.format(time.strftime("%Y%m%d-%H%M%S"))),
                   title=None,
                   show=True)
Exemplo n.º 8
0
def offline_eval_WAE(exp_dir, evalconfig, startwith, model_name2, check):
    """Offline evaluation of saved (witness-complex) autoencoder runs.

    Iterates over run subfolders of ``exp_dir`` whose names start with
    ``startwith``. For each run it rebuilds the dataset and model classes
    from the run's ``config.json``, loads the trained weights and —
    depending on ``evalconfig`` — computes manifold RMSE, saves latent
    visualizations and appends quantitative metrics to
    ``eval_metrics_all.csv``. Runs with a missing/corrupt checkpoint are
    moved to ``<exp_dir>/not_evaluated``.

    Args:
        exp_dir: Directory containing one subfolder per run.
        evalconfig: Evaluation config (eval_manifold, save_eval_latent,
            save_train_latent, quant_eval, k_min/k_max/k_step, evaluate_on).
        startwith: Prefix filter for run subfolder names.
        model_name2: Wrapper selection: 'topoae_ext' or 'vanilla_ae'.
        check: If True, skip quantitative eval for runs already present in
            eval_metrics_all.csv.
    """
    if check:
        # Collect UIDs of runs already evaluated so they can be skipped.
        df_exist = pd.read_csv(os.path.join(exp_dir, "eval_metrics_all.csv"))
        uid_exist = list(df_exist.loc[df_exist['metric'] == 'test_rmse'].uid)
        print('passed')
    else:
        uid_exist = []
        print('other pass')

    subfolders = [
        f.path for f in os.scandir(exp_dir)
        if (f.is_dir() and f.path.split('/')[-1].startswith(startwith))
    ]

    for run_dir in subfolders:
        exp = run_dir.split('/')[-1]

        # Quantitative eval only for runs not already recorded (when check).
        continue2 = not (exp in uid_exist and check)

        # Remove stale sacred metrics file, if present.
        try:
            os.remove(os.path.join(run_dir, "metrics.json"))
        except OSError:
            print('File does not exist')

        with open(os.path.join(run_dir, 'config.json'), 'r') as f:
            json_file = json.load(f)

        config = json_file['config']

        # Rebuild the dataset class from its dotted path in the config.
        data_set_str = config['dataset']['py/object']
        mod_name, dataset_name = data_set_str.rsplit('.', 1)
        mod = importlib.import_module(mod_name)
        dataset = getattr(mod, dataset_name)

        dataset = dataset()
        X_test, y_test = dataset.sample(**config['sampling_kwargs'],
                                        train=False)
        selected_dataset = TensorDataset(torch.Tensor(X_test),
                                         torch.Tensor(y_test))

        X_train, y_train = dataset.sample(**config['sampling_kwargs'],
                                          train=True)
        train_dataset = TensorDataset(torch.Tensor(X_train),
                                      torch.Tensor(y_train))

        # Rebuild the autoencoder class the same way.
        model_str = config['model_class']['py/type']
        mod_name2, model_name = model_str.rsplit('.', 1)
        mod2 = importlib.import_module(mod_name2)
        autoencoder = getattr(mod2, model_name)

        autoencoder = autoencoder(**config['model_kwargs'])

        if model_name2 == 'topoae_ext':
            model = WitnessComplexAutoencoder(autoencoder)
        elif model_name2 == 'vanilla_ae':
            model = autoencoder
        else:
            raise ValueError("Model {} not defined.".format(model_name2))

        # Load trained weights. (The original code performed this identical
        # load twice in a row; the duplicate block was removed.)
        continue_ = False
        try:
            state_dict = torch.load(os.path.join(run_dir, 'model_state.pth'),
                                    map_location=torch.device('cpu'))
            continue_ = True
        except Exception:
            print('WARNING: model {} not complete'.format(exp))

        if continue_:
            if 'latent' not in state_dict and model_name2 == 'topoae_ext':
                # Older checkpoints lack the latent_norm buffer; default 1.0.
                state_dict['latent_norm'] = torch.Tensor([1.0]).float()

            model.load_state_dict(state_dict)
            model.eval()

            dataloader_eval = torch.utils.data.DataLoader(
                selected_dataset,
                batch_size=config['batch_size'],
                pin_memory=True,
                drop_last=False)

            X_eval, Y_eval, Z_eval = get_latentspace_representation(
                model, dataloader_eval, device='cpu')

            result = dict()
            if evalconfig.eval_manifold:
                # sample true manifold (test split)
                manifold_eval_train = False
                try:
                    Z_manifold, X_transformed, labels = dataset.sample_manifold(
                        **config['sampling_kwargs'], train=manifold_eval_train)

                    dataset_test = TensorDataset(torch.Tensor(X_transformed),
                                                 torch.Tensor(labels))
                    dataloader_eval = torch.utils.data.DataLoader(
                        dataset_test,
                        batch_size=config['batch_size'],
                        pin_memory=True,
                        drop_last=False)
                    X_eval, Y_eval, Z_latent = get_latentspace_representation(
                        model, dataloader_eval, device='cpu')

                    # Min-max normalize both embeddings per coordinate so the
                    # pairwise-distance comparison is scale-invariant.
                    # NOTE: mutates Z_manifold / Z_latent in place.
                    for dim in (0, 1):
                        col = Z_manifold[:, dim]
                        Z_manifold[:, dim] = (col - col.min()) / (
                            col.max() - col.min())
                        col = Z_latent[:, dim]
                        Z_latent[:, dim] = (col - col.min()) / (
                            col.max() - col.min())

                    pwd_Z = pairwise_distances(Z_latent, Z_latent, n_jobs=1)
                    pwd_Ztrue = pairwise_distances(Z_manifold,
                                                   Z_manifold,
                                                   n_jobs=1)

                    # normalize distances
                    pairwise_distances_manifold = (
                        pwd_Ztrue - pwd_Ztrue.min()) / (pwd_Ztrue.max() -
                                                        pwd_Ztrue.min())
                    pairwise_distances_Z = (pwd_Z - pwd_Z.min()) / (
                        pwd_Z.max() - pwd_Z.min())

                    # save comparison fig
                    plot_distcomp_Z_manifold(
                        Z_manifold=Z_manifold,
                        Z_latent=Z_latent,
                        pwd_manifold=pairwise_distances_manifold,
                        pwd_Z=pairwise_distances_Z,
                        labels=labels,
                        path_to_save=run_dir,
                        name='manifold_Z_distcomp',
                        fontsize=24,
                        show=False)

                    rmse_manifold = (np.square(pairwise_distances_manifold -
                                               pairwise_distances_Z)).mean()
                    result.update(dict(rmse_manifold_Z=rmse_manifold))
                except AttributeError as err:
                    # Datasets without sample_manifold() are skipped.
                    print(err)
                    print('Manifold not evaluated!')

            if run_dir and evalconfig.save_eval_latent:
                df = pd.DataFrame(Z_eval)
                df['labels'] = Y_eval
                df.to_csv(os.path.join(run_dir, 'latents.csv'), index=False)
                # BUGFIX: latents/labels keys were previously swapped here
                # (latents=Y_eval, labels=Z_eval), unlike the train branch.
                np.savez(os.path.join(run_dir, 'latents.npz'),
                         latents=Z_eval,
                         labels=Y_eval)
                plot_2Dscatter(Z_eval,
                               Y_eval,
                               path_to_save=os.path.join(
                                   run_dir, 'test_latent_visualization.png'),
                               dpi=100,
                               title=None,
                               show=False)

            if run_dir and evalconfig.save_train_latent:
                dataloader_train = torch.utils.data.DataLoader(
                    train_dataset,
                    batch_size=config['batch_size'],
                    pin_memory=True,
                    drop_last=False)
                X_train, Y_train, Z_train = get_latentspace_representation(
                    model, dataloader_train, device='cpu')

                df = pd.DataFrame(Z_train)
                df['labels'] = Y_train
                df.to_csv(os.path.join(run_dir, 'train_latents.csv'),
                          index=False)
                np.savez(os.path.join(run_dir, 'latents.npz'),
                         latents=Z_train,
                         labels=Y_train)
                # Visualize latent space
                plot_2Dscatter(Z_train,
                               Y_train,
                               path_to_save=os.path.join(
                                   run_dir, 'train_latent_visualization.png'),
                               dpi=100,
                               title=None,
                               show=False)

            if evalconfig.quant_eval and continue2:
                print('QUANT EVAL....')
                ks = list(
                    range(evalconfig.k_min,
                          evalconfig.k_max + evalconfig.k_step,
                          evalconfig.k_step))

                evaluator = Multi_Evaluation(dataloader=dataloader_eval,
                                             seed=config['seed'],
                                             model=model)
                ev_result = evaluator.get_multi_evals(X_eval,
                                                      Z_eval,
                                                      Y_eval,
                                                      ks=ks)
                prefixed_ev_result = {
                    evalconfig.evaluate_on + '_' + key: value
                    for key, value in ev_result.items()
                }
                result.update(prefixed_ev_result)
                s = json.dumps(result, default=default)
                # Context manager ensures the file handle is closed.
                with open(os.path.join(run_dir, 'eval_metrics.json'),
                          "w") as fp:
                    fp.write(s)

                result_avg = avg_array_in_dict(result)

                df = pd.DataFrame.from_dict(result_avg,
                                            orient='index').reset_index()
                df.columns = ['metric', 'value']

                id_dict = dict(
                    uid=exp,
                    seed=config['seed'],
                    batch_size=config['batch_size'],
                )
                for key, value in id_dict.items():
                    df[key] = value
                # NOTE(review): set_index returns a new frame; the result is
                # discarded here (no-op). Preserved as-is so the appended CSV
                # layout stays unchanged.
                df.set_index('uid')

                df = df[COLS_DF_RESULT]
                df.to_csv(os.path.join(exp_dir, 'eval_metrics_all.csv'),
                          mode='a',
                          header=False)
            else:
                print('skipped quant eval')
        else:
            # Checkpoint missing or unreadable: set the run aside.
            shutil.move(run_dir, os.path.join(exp_dir, 'not_evaluated'))
Exemplo n.º 9
0
import time

from sklearn.manifold import SpectralEmbedding

from src.datasets.datasets import SwissRoll, Spheres
from src.utils.plots import plot_2Dscatter

if __name__ == "__main__":
    # Time a 2D spectral embedding of a sampled Swiss roll and display it.
    swiss_roll = SwissRoll()
    points, colours = swiss_roll.sample(n_samples=2560)

    tic = time.time()
    spectral = SpectralEmbedding(n_components=2, n_jobs=1, n_neighbors=90)
    embedded = spectral.fit_transform(points)
    toc = time.time()
    print('It took: {}'.format(toc - tic))

    plot_2Dscatter(data=embedded, labels=colours, path_to_save=None,
                   title=None, show=True)
import os

import torch
import seaborn as sns
import numpy as np
import pandas as pd

from src.datasets.datasets import MNIST_offline
from src.models.WitnessComplexAE.wc_ae import WitnessComplexAutoencoder
from src.models.autoencoder.autoencoders import DeepAE_MNIST
from src.utils.plots import plot_2Dscatter

if __name__ == "__main__":
    # Directory of a finished t-SNE run whose train latents were saved as CSV.
    tsne_rmsez_path = '/Users/simons/MT_data/sync/euler_sync_scratch/schsimo/output/mnist_tsne/MNIST_offline-n_samples10000-tSNE--n_jobs1-perplexity5-seed1318-b2b38aea'

    frame = pd.read_csv(os.path.join(tsne_rmsez_path, 'train_latents.csv'))

    # Columns '0'/'1' hold the 2D embedding, 'labels' the class labels.
    embedding = frame[['0', '1']][:].to_numpy()
    class_labels = frame['labels'].tolist()

    target = os.path.join(tsne_rmsez_path,
                          '{}_latent_visualization.pdf'.format('final'))
    plot_2Dscatter(embedding,
                   class_labels,
                   path_to_save=target,
                   title=None,
                   show=False,
                   palette='custom2')
Exemplo n.º 11
0
def eval(result, Z_manifold, X, Z, Y, rundir, config, train=True):
    """Evaluate a 2D latent embedding and merge metrics into ``result``.

    Saves latents (CSV/NPZ/scatter plot), optionally compares pairwise
    distances of ``Z`` against the true manifold ``Z_manifold`` (RMSE), and
    computes k-dependent/independent quality measures, writing everything to
    ``rundir`` under a 'train'/'test' prefix.

    NOTE(review): this function shadows the builtin ``eval``; renaming would
    break existing callers, so it is kept.
    WARNING: when ``config.eval.eval_manifold`` is set, columns 0 and 1 of
    ``Z`` and ``Z_manifold`` are min-max normalized IN PLACE, mutating the
    caller's arrays.

    Args:
        result: Dict of previously collected results (updated and returned).
        Z_manifold: True-manifold 2D coordinates (same ordering as Z).
        X: Original data points.
        Z: Latent 2D embedding of X.
        Y: Labels for the points.
        rundir: Output directory for CSV/NPZ/plot/JSON artifacts.
        config: Config with an ``eval`` sub-config (k_min/k_max/k_step, ...).
        train: Selects the 'train' vs 'test' name prefix and save flag.

    Returns:
        ``avg_array_in_dict(result)`` — results with arrays averaged.
    """
    if train:
        name_prefix = 'train'
        save_latent = config.eval.save_train_latent
    else:
        name_prefix = 'test'
        save_latent = config.eval.save_eval_latent

    # NOTE(review): the CSV is written unconditionally — unlike the NPZ/plot
    # below this does not check ``rundir``; confirm rundir is always set.
    df = pd.DataFrame(Z)
    df['labels'] = Y
    df.to_csv(os.path.join(rundir, '{}_latents.csv'.format(name_prefix)),
              index=False)
    if rundir and save_latent:
        np.savez(os.path.join(rundir, '{}_latents.npz'.format(name_prefix)),
                 latents=Z,
                 labels=Y)
        plot_2Dscatter(Z,
                       Y,
                       path_to_save=os.path.join(
                           rundir,
                           '{}_latent_visualization.pdf'.format(name_prefix)),
                       title=None,
                       show=False,
                       palette=sns.color_palette("muted"))

    if config.eval.eval_manifold:
        try:
            # Min-max normalize both embeddings per coordinate (in place) so
            # the pairwise-distance comparison is scale-invariant.
            Z_manifold[:, 0] = (Z_manifold[:, 0] - Z_manifold[:, 0].min()) / (
                Z_manifold[:, 0].max() - Z_manifold[:, 0].min())
            Z_manifold[:, 1] = (Z_manifold[:, 1] - Z_manifold[:, 1].min()) / (
                Z_manifold[:, 1].max() - Z_manifold[:, 1].min())
            Z[:,
              0] = (Z[:, 0] - Z[:, 0].min()) / (Z[:, 0].max() - Z[:, 0].min())
            Z[:,
              1] = (Z[:, 1] - Z[:, 1].min()) / (Z[:, 1].max() - Z[:, 1].min())
            # compute RMSE
            pwd_Z = pairwise_distances(Z, Z, n_jobs=1)
            pwd_Ztrue = pairwise_distances(Z_manifold, Z_manifold, n_jobs=1)

            # Normalize the distance matrices to [0, 1] before comparing.
            pairwise_distances_manifold = (pwd_Ztrue - pwd_Ztrue.min()) / (
                pwd_Ztrue.max() - pwd_Ztrue.min())
            pairwise_distances_Z = (pwd_Z - pwd_Z.min()) / (pwd_Z.max() -
                                                            pwd_Z.min())
            rmse_manifold = (np.square(pairwise_distances_manifold -
                                       pairwise_distances_Z)).mean(axis=None)
            result.update(dict(rmse_manifold_Z=rmse_manifold))
            # save comparison fig
            plot_distcomp_Z_manifold(Z_manifold=Z_manifold,
                                     Z_latent=Z,
                                     pwd_manifold=pairwise_distances_manifold,
                                     pwd_Z=pairwise_distances_Z,
                                     labels=Y,
                                     path_to_save=rundir,
                                     name='manifold_Z_distcomp',
                                     fontsize=24,
                                     show=False)
        except AttributeError as err:
            # Inputs without the required array interface are skipped.
            print(err)
            print('Manifold not evaluated!')

    # Neighbourhood sizes for the k-dependent quality measures.
    ks = list(
        range(config.eval.k_min, config.eval.k_max + config.eval.k_step,
              config.eval.k_step))

    calc = MeasureCalculator(X, Z, max(ks))

    indep_measures = calc.compute_k_independent_measures()
    dep_measures = calc.compute_measures_for_ks(ks)
    mean_dep_measures = {
        'mean_' + key: values.mean()
        for key, values in dep_measures.items()
    }

    # Merge all measure dicts into one flat result dict.
    ev_result = {
        key: value
        for key, value in itertools.chain(indep_measures.items(
        ), dep_measures.items(), mean_dep_measures.items())
    }

    prefixed_ev_result = {
        name_prefix + '_' + key: value
        for key, value in ev_result.items()
    }
    result.update(prefixed_ev_result)
    s = json.dumps(result, default=default)
    open(os.path.join(rundir, '{}_eval_metrics.json'.format(name_prefix)),
         "w").write(s)

    return avg_array_in_dict(result)