Ejemplo n.º 1
0
def _min_max_scale(values):
    """Rescale ``values`` linearly so its minimum maps to 0 and its maximum to 1."""
    lowest = values.min()
    return (values - lowest) / (values.max() - lowest)


def offline_eval_WAE(exp_dir, evalconfig, startwith, model_name2, check):
    """Re-evaluate stored runs found directly below ``exp_dir``.

    Every sub-directory of ``exp_dir`` whose name starts with ``startwith`` is
    treated as one run: its ``config.json`` is read, the dataset and model
    classes named in it are imported and instantiated, ``model_state.pth`` is
    loaded on the CPU and the evaluations enabled in ``evalconfig`` are run.
    Runs whose checkpoint cannot be loaded are moved to
    ``<exp_dir>/not_evaluated``.

    Args:
        exp_dir: Directory holding the run folders and ``eval_metrics_all.csv``.
        evalconfig: Object providing the attributes ``eval_manifold``,
            ``save_eval_latent``, ``save_train_latent``, ``quant_eval``,
            ``k_min``, ``k_max``, ``k_step`` and ``evaluate_on``.
        startwith: Name prefix a run folder must have to be processed.
        model_name2: ``'topoae_ext'`` wraps the autoencoder in a
            ``WitnessComplexAutoencoder``; ``'vanilla_ae'`` uses it directly.
        check: When truthy, ``eval_metrics_all.csv`` is read and the
            quantitative evaluation is skipped for runs that already have a
            ``test_rmse`` row.

    Raises:
        ValueError: If ``model_name2`` is neither of the two supported names.
    """
    if check:
        df_exist = pd.read_csv(os.path.join(exp_dir, "eval_metrics_all.csv"))
        uid_exist = set(df_exist.loc[df_exist['metric'] == 'test_rmse'].uid)
        print('passed')
    else:
        uid_exist = set()
        print('other pass')

    subfolders = [
        entry.path for entry in os.scandir(exp_dir)
        if entry.is_dir() and entry.name.startswith(startwith)
    ]

    for run_dir in subfolders:
        exp = os.path.basename(run_dir)

        # ``uid_exist`` is empty unless ``check`` is set, so membership alone
        # decides whether the quantitative evaluation still has to run.
        run_quant_eval = exp not in uid_exist

        # Best effort: a missing metrics.json is not an error.
        try:
            os.remove(os.path.join(run_dir, "metrics.json"))
        except OSError:
            print('File does not exist')

        with open(os.path.join(run_dir, 'config.json'), 'r') as f:
            config = json.load(f)['config']

        # Resolve the dataset class from its dotted path and draw both splits.
        dataset_module, dataset_name = config['dataset']['py/object'].rsplit('.', 1)
        dataset = getattr(importlib.import_module(dataset_module), dataset_name)()

        X_test, y_test = dataset.sample(**config['sampling_kwargs'],
                                        train=False)
        selected_dataset = TensorDataset(torch.Tensor(X_test),
                                         torch.Tensor(y_test))

        X_train, y_train = dataset.sample(**config['sampling_kwargs'],
                                          train=True)
        train_dataset = TensorDataset(torch.Tensor(X_train),
                                      torch.Tensor(y_train))

        # Resolve and build the autoencoder the same way.
        model_module, model_name = config['model_class']['py/type'].rsplit('.', 1)
        autoencoder_cls = getattr(importlib.import_module(model_module), model_name)
        autoencoder = autoencoder_cls(**config['model_kwargs'])

        if model_name2 == 'topoae_ext':
            model = WitnessComplexAutoencoder(autoencoder)
        elif model_name2 == 'vanilla_ae':
            model = autoencoder
        else:
            raise ValueError("Model {} not defined.".format(model_name2))

        # A run without a loadable checkpoint is set aside instead of evaluated.
        try:
            state_dict = torch.load(os.path.join(run_dir, 'model_state.pth'),
                                    map_location=torch.device('cpu'))
        except Exception:
            print('WARNING: model {} not complete'.format(exp))
            shutil.move(run_dir, os.path.join(exp_dir, 'not_evaluated'))
            continue

        # Supply a default only when the checkpoint really lacks the entry.
        # The previous test looked for the key 'latent', which never matches
        # 'latent_norm', so a stored value was always overwritten.
        if 'latent_norm' not in state_dict and model_name2 == 'topoae_ext':
            state_dict['latent_norm'] = torch.Tensor([1.0]).float()

        model.load_state_dict(state_dict)
        model.eval()

        dataloader_eval = torch.utils.data.DataLoader(
            selected_dataset,
            batch_size=config['batch_size'],
            pin_memory=True,
            drop_last=False)

        X_eval, Y_eval, Z_eval = get_latentspace_representation(
            model, dataloader_eval, device='cpu')

        result = {}
        if evalconfig.eval_manifold:
            # Compare pairwise distances on the true manifold with those in
            # the latent space. Dedicated names are used here so that the test
            # split results (X_eval, Y_eval, dataloader_eval) stay intact for
            # the steps further down.
            try:
                Z_manifold, X_transformed, labels = dataset.sample_manifold(
                    **config['sampling_kwargs'], train=False)

                manifold_loader = torch.utils.data.DataLoader(
                    TensorDataset(torch.Tensor(X_transformed),
                                  torch.Tensor(labels)),
                    batch_size=config['batch_size'],
                    pin_memory=True,
                    drop_last=False)
                _, _, Z_latent = get_latentspace_representation(
                    model, manifold_loader, device='cpu')

                # Scale the first two coordinates of both embeddings to [0, 1].
                for col in (0, 1):
                    Z_manifold[:, col] = _min_max_scale(Z_manifold[:, col])
                    Z_latent[:, col] = _min_max_scale(Z_latent[:, col])

                # Normalised pairwise distance matrices.
                pairwise_distances_Z = _min_max_scale(
                    pairwise_distances(Z_latent, Z_latent, n_jobs=1))
                pairwise_distances_manifold = _min_max_scale(
                    pairwise_distances(Z_manifold, Z_manifold, n_jobs=1))

                # Save comparison figure.
                plot_distcomp_Z_manifold(
                    Z_manifold=Z_manifold,
                    Z_latent=Z_latent,
                    pwd_manifold=pairwise_distances_manifold,
                    pwd_Z=pairwise_distances_Z,
                    labels=labels,
                    path_to_save=run_dir,
                    name='manifold_Z_distcomp',
                    fontsize=24,
                    show=False)

                # NOTE(review): no square root is taken, so this is a mean
                # squared error despite the key name.
                result['rmse_manifold_Z'] = np.square(
                    pairwise_distances_manifold - pairwise_distances_Z).mean()
            except AttributeError as err:
                print(err)
                print('Manifold not evaluated!')

        if evalconfig.save_eval_latent:
            df = pd.DataFrame(Z_eval)
            df['labels'] = Y_eval
            df.to_csv(os.path.join(run_dir, 'latents.csv'), index=False)
            # Z_eval holds the latents and Y_eval the labels; the two keyword
            # arguments used to be swapped.
            np.savez(os.path.join(run_dir, 'latents.npz'),
                     latents=Z_eval,
                     labels=Y_eval)
            plot_2Dscatter(Z_eval,
                           Y_eval,
                           path_to_save=os.path.join(
                               run_dir, 'test_latent_visualization.png'),
                           dpi=100,
                           title=None,
                           show=False)

        if evalconfig.save_train_latent:
            dataloader_train = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=config['batch_size'],
                pin_memory=True,
                drop_last=False)
            _, Y_train, Z_train = get_latentspace_representation(
                model, dataloader_train, device='cpu')

            df = pd.DataFrame(Z_train)
            df['labels'] = Y_train
            df.to_csv(os.path.join(run_dir, 'train_latents.csv'),
                      index=False)
            # NOTE(review): this reuses the file name of the test latents and
            # replaces them when both options are enabled; confirm whether
            # 'train_latents.npz' was intended before renaming.
            np.savez(os.path.join(run_dir, 'latents.npz'),
                     latents=Z_train,
                     labels=Y_train)
            # Visualize latent space.
            plot_2Dscatter(Z_train,
                           Y_train,
                           path_to_save=os.path.join(
                               run_dir, 'train_latent_visualization.png'),
                           dpi=100,
                           title=None,
                           show=False)

        if evalconfig.quant_eval and run_quant_eval:
            print('QUANT EVAL....')
            # The range end is extended by one step so that k_max is included.
            ks = list(
                range(evalconfig.k_min,
                      evalconfig.k_max + evalconfig.k_step,
                      evalconfig.k_step))

            evaluator = Multi_Evaluation(dataloader=dataloader_eval,
                                         seed=config['seed'],
                                         model=model)
            ev_result = evaluator.get_multi_evals(X_eval,
                                                  Z_eval,
                                                  Y_eval,
                                                  ks=ks)
            result.update({
                evalconfig.evaluate_on + '_' + key: value
                for key, value in ev_result.items()
            })

            # Serialise first so a failure does not leave a truncated file.
            serialized = json.dumps(result, default=default)
            with open(os.path.join(run_dir, 'eval_metrics.json'), "w") as metrics_file:
                metrics_file.write(serialized)

            df = pd.DataFrame.from_dict(avg_array_in_dict(result),
                                        orient='index').reset_index()
            df.columns = ['metric', 'value']
            df['uid'] = exp
            df['seed'] = config['seed']
            df['batch_size'] = config['batch_size']

            # Append to the shared result table of the experiment.
            df[COLS_DF_RESULT].to_csv(
                os.path.join(exp_dir, 'eval_metrics_all.csv'),
                mode='a',
                header=False)
        else:
            print('skipped quant eval')

    # NOTE(review): the statements below run once after the loop and use names
    # that this function never defines (path_model, MODEL_vae2_1, path_norm2,
    # exp_norm2, images). Unless they exist as module-level globals this raises
    # NameError. `autoencoder` is whatever the last loop iteration left behind.
    # The logic is kept as found; confirm whether it belongs here at all.
    if path_model == MODEL_vae2_1:
        state_dict = torch.load(os.path.join(path_model, 'model_state.pth'),
                                map_location=torch.device('cpu'))
        model = autoencoder
        print('Loading Passed')
    else:
        model = WitnessComplexAutoencoder(autoencoder)
        state_dict = torch.load(os.path.join(path_model, 'model_state.pth'),
                                map_location=torch.device('cpu'))

        state_dict2 = torch.load(
            os.path.join(os.path.join(path_norm2, exp_norm2), 'model_state.pth'))
        # NOTE(review): tests for 'latent' but writes 'latent_norm'; left
        # unchanged because the intent of the 0.1 scaling is unclear.
        if 'latent' not in state_dict:
            state_dict['latent_norm'] = state_dict2['latent_norm'] * 0.1

        print('passed')

    model.load_state_dict(state_dict)
    model.eval()

    z = model.encode(images.float())

    # np.save appends '.npy' because the name does not already end with it.
    np.save('/Users/simons/PycharmProjects/MT-VAEs-TDA/src/datasets/simulated/xy_trans_l_newpers/latents_wae.np', z.detach().numpy())



    # plot_2Dscatter(z.detach().numpy(), labels, path_to_save=os.path.join(path_save, 'pos_{}.pdf'.format(
    #     time.strftime("%Y%m%d-%H%M%S"))), title=None, show=True, palette = 'custom')


Ejemplo n.º 3
0
def train_TopoAE_ext(_run, _seed, _rnd, config: ConfigWCAE, experiment_dir,
                     experiment_root, device, num_threads, verbose):
    """Train a ``WitnessComplexAutoencoder`` and append its metrics to the result table.

    The function samples a train and a test split from ``config.dataset``,
    derives the input normalisation constant ``norm_X``, builds the model,
    optionally loads pre-trained weights, calls ``train`` and appends the
    returned metrics to ``<experiment_root>/eval_metrics_all.csv``.

    Args:
        _run, _rnd: Forwarded unchanged to ``train``.
        _seed: Seed passed to ``torch.manual_seed`` and forwarded to ``train``.
        config: Experiment configuration. The attributes read here are
            ``dataset``, ``sampling_kwargs``, ``model_class``,
            ``model_kwargs``, ``rec_loss_weight``, ``top_loss_weight``,
            ``toposig_kwargs``, ``device``, ``method_args`` and
            ``create_id_dict()``.
        experiment_dir: Directory of this run; created if missing and passed
            to ``train`` as ``rundir``.
        experiment_root: Directory that holds ``eval_metrics_all.csv``;
            created if missing.
        device: Device the model is moved to before training.
        num_threads: Number of torch CPU threads; only applied when
            ``device == 'cpu'`` and the value is not ``None``.
        verbose: ``train`` is called with ``quiet=not verbose``.
    """
    if device == 'cpu' and num_threads is not None:
        torch.set_num_threads(num_threads)

    # Existing directories are fine; empty/None paths are skipped as before,
    # while real failures (e.g. missing permissions) are no longer hidden.
    for directory in (experiment_dir, experiment_root):
        if directory:
            os.makedirs(directory, exist_ok=True)

    # Create the shared result table with its header on first use.
    metrics_csv = os.path.join(experiment_root, 'eval_metrics_all.csv')
    if not os.path.isfile(metrics_csv):
        pd.DataFrame(columns=COLS_DF_RESULT).to_csv(metrics_csv)

    # Sample data
    dataset = config.dataset
    X_train, y_train = dataset.sample(**config.sampling_kwargs, train=True)
    dataset_train = TensorDataset(torch.Tensor(X_train), torch.Tensor(y_train))

    X_test, y_test = dataset.sample(**config.sampling_kwargs, train=False)
    dataset_test = TensorDataset(torch.Tensor(X_test), torch.Tensor(y_test))

    torch.manual_seed(_seed)

    # Normalisation constant for distances in input space.
    if isinstance(dataset, (MNIST, MNIST_offline)):
        # Images are 28x28 and, since the data is normalized, the maximum
        # difference per pixel is 1.
        norm_X = 28**2 / 10
    elif isinstance(dataset, (Unity_Rotblock, Unity_RotCorgi)):
        norm_X = 180
    elif X_train.shape[0] > 4096:
        # TODO: fix somehow. For large datasets the maximum pairwise distance
        # is taken over a random subset of 2048 points. The subset comes from
        # the `random` module, which torch.manual_seed above does not seed.
        inds = random.sample(range(X_train.shape[0]), 2048)
        X_subset = dataset_train[inds][0]
        norm_X = torch.cdist(X_subset, X_subset).max()
    else:
        X_all = dataset_train[:][0]
        norm_X = torch.cdist(X_all, X_all).max()

    # Initialize model
    autoencoder = config.model_class(**config.model_kwargs)
    model = WitnessComplexAutoencoder(autoencoder,
                                      lam_r=config.rec_loss_weight,
                                      lam_t=config.top_loss_weight,
                                      toposig_kwargs=config.toposig_kwargs,
                                      norm_X=norm_X,
                                      device=config.device)

    pre_trained_dir = config.method_args['pre_trained_model']
    if pre_trained_dir is not None:
        # Load onto the CPU first so a checkpoint written on a GPU can also be
        # used on a CPU-only host; the model is moved to `device` right after.
        state_dict = torch.load(os.path.join(pre_trained_dir, 'model_state.pth'),
                                map_location=torch.device('cpu'))
        model.load_state_dict(state_dict)
    model.to(device)

    # Train and evaluate model
    result = train(model=model,
                   data_train=dataset_train,
                   data_test=dataset_test,
                   config=config,
                   device=device,
                   quiet=not verbose,
                   val_size=config.method_args['val_size'],
                   _seed=_seed,
                   _rnd=_rnd,
                   _run=_run,
                   rundir=experiment_dir)

    # Format experiment data: one row per metric plus the identifying columns.
    df = pd.DataFrame.from_dict(result, orient='index').reset_index()
    df.columns = ['metric', 'value']
    for key, value in config.create_id_dict().items():
        df[key] = value

    df[COLS_DF_RESULT].to_csv(metrics_csv, mode='a', header=False)