def offline_eval_WAE(exp_dir, evalconfig, startwith, model_name2, check):
    """Offline evaluation of trained (witness-complex) autoencoder runs.

    Scans ``exp_dir`` for run folders whose name starts with ``startwith``,
    re-instantiates each run's dataset and model from its ``config.json`` and
    ``model_state.pth``, then writes latent-space CSV/NPZ dumps, scatter
    plots, and quantitative metrics into the run folder and appends an
    aggregate row to ``<exp_dir>/eval_metrics_all.csv``.

    Args:
        exp_dir: experiment root directory; one sub-folder per run.
        evalconfig: evaluation settings object (attributes used here:
            eval_manifold, save_eval_latent, save_train_latent, quant_eval,
            k_min, k_max, k_step, evaluate_on).
        startwith: prefix filter applied to run-folder names.
        model_name2: 'topoae_ext' wraps the autoencoder in a
            WitnessComplexAutoencoder; 'vanilla_ae' evaluates it directly.
        check: if True, runs already listed with a 'test_rmse' metric in
            eval_metrics_all.csv skip the (expensive) quantitative eval.

    Raises:
        ValueError: if ``model_name2`` is not one of the two known modes.
    """
    if check:
        # Runs already carrying a 'test_rmse' row were evaluated before.
        df_exist = pd.read_csv(os.path.join(exp_dir, "eval_metrics_all.csv"))
        uid_exist = list(df_exist.loc[df_exist['metric'] == 'test_rmse'].uid)
        print('passed')
    else:
        uid_exist = []
        print('other pass')

    # NOTE: the original filter also tested `and f`, which is always truthy
    # for an os.DirEntry and was dropped.
    subfolders = [
        f.path for f in os.scandir(exp_dir)
        if f.is_dir() and f.path.split('/')[-1].startswith(startwith)
    ]

    for run_dir in subfolders:
        exp = run_dir.split('/')[-1]
        # continue2 gates the quantitative evaluation further below.
        continue2 = not (exp in uid_exist and check)

        # Remove stale metrics left over from a previous (aborted) evaluation.
        try:
            os.remove(os.path.join(run_dir, "metrics.json"))
        except FileNotFoundError:
            print('File does not exist')

        with open(os.path.join(run_dir, 'config.json'), 'r') as f:
            json_file = json.load(f)
        config = json_file['config']

        # Re-instantiate the dataset class from its jsonpickle 'py/object' path.
        data_set_str = config['dataset']['py/object']
        mod_name, dataset_name = data_set_str.rsplit('.', 1)
        mod = importlib.import_module(mod_name)
        dataset = getattr(mod, dataset_name)()

        X_test, y_test = dataset.sample(**config['sampling_kwargs'], train=False)
        selected_dataset = TensorDataset(torch.Tensor(X_test), torch.Tensor(y_test))
        X_train, y_train = dataset.sample(**config['sampling_kwargs'], train=True)
        train_dataset = TensorDataset(torch.Tensor(X_train), torch.Tensor(y_train))

        # Re-instantiate the autoencoder class from its jsonpickle 'py/type' path.
        model_str = config['model_class']['py/type']
        mod_name2, model_name = model_str.rsplit('.', 1)
        mod2 = importlib.import_module(mod_name2)
        autoencoder = getattr(mod2, model_name)(**config['model_kwargs'])

        if model_name2 == 'topoae_ext':
            model = WitnessComplexAutoencoder(autoencoder)
        elif model_name2 == 'vanilla_ae':
            model = autoencoder
        else:
            raise ValueError("Model {} not defined.".format(model_name2))

        # Load trained weights.  (The original retried the identical
        # torch.load a second time inside the except handler; the redundant
        # retry is removed here.)
        continue_ = False
        try:
            state_dict = torch.load(os.path.join(run_dir, 'model_state.pth'),
                                    map_location=torch.device('cpu'))
            continue_ = True
        except Exception:
            print('WARNING: model {} not complete'.format(exp))

        if not continue_:
            # Weights missing/corrupt: park the run for manual inspection.
            shutil.move(run_dir, os.path.join(exp_dir, 'not_evaluated'))
            continue

        if 'latent' not in state_dict and model_name2 == 'topoae_ext':
            # Older checkpoints lack the latent normalisation constant.
            state_dict['latent_norm'] = torch.Tensor([1.0]).float()
        model.load_state_dict(state_dict)
        model.eval()

        dataloader_eval = torch.utils.data.DataLoader(
            selected_dataset, batch_size=config['batch_size'],
            pin_memory=True, drop_last=False)
        X_eval, Y_eval, Z_eval = get_latentspace_representation(
            model, dataloader_eval, device='cpu')

        result = dict()
        if evalconfig.eval_manifold:
            # Compare pairwise distances on the true manifold with those in
            # the latent space, both min-max normalised per coordinate.
            manifold_eval_train = False
            try:
                Z_manifold, X_transformed, labels = dataset.sample_manifold(
                    **config['sampling_kwargs'], train=manifold_eval_train)
                dataset_test = TensorDataset(torch.Tensor(X_transformed),
                                             torch.Tensor(labels))
                dataloader_eval = torch.utils.data.DataLoader(
                    dataset_test, batch_size=config['batch_size'],
                    pin_memory=True, drop_last=False)
                X_eval, Y_eval, Z_latent = get_latentspace_representation(
                    model, dataloader_eval, device='cpu')

                # Min-max normalise the first two coordinates of both spaces.
                for Z in (Z_manifold, Z_latent):
                    for d in (0, 1):
                        Z[:, d] = (Z[:, d] - Z[:, d].min()) / (Z[:, d].max() - Z[:, d].min())

                pwd_Z = pairwise_distances(Z_latent, Z_latent, n_jobs=1)
                pwd_Ztrue = pairwise_distances(Z_manifold, Z_manifold, n_jobs=1)
                # Normalize distance matrices to [0, 1] before comparing.
                pairwise_distances_manifold = (pwd_Ztrue - pwd_Ztrue.min()) / (
                    pwd_Ztrue.max() - pwd_Ztrue.min())
                pairwise_distances_Z = (pwd_Z - pwd_Z.min()) / (
                    pwd_Z.max() - pwd_Z.min())

                # Save comparison figure.
                plot_distcomp_Z_manifold(
                    Z_manifold=Z_manifold, Z_latent=Z_latent,
                    pwd_manifold=pairwise_distances_manifold,
                    pwd_Z=pairwise_distances_Z, labels=labels,
                    path_to_save=run_dir, name='manifold_Z_distcomp',
                    fontsize=24, show=False)

                # NOTE(review): despite the name this is the mean *squared*
                # error (no sqrt) -- kept as-is for metric continuity.
                rmse_manifold = (np.square(pairwise_distances_manifold
                                           - pairwise_distances_Z)).mean()
                result.update(dict(rmse_manifold_Z=rmse_manifold))
            except AttributeError as err:
                # Dataset does not implement sample_manifold.
                print(err)
                print('Manifold not evaluated!')

        if run_dir and evalconfig.save_eval_latent:
            df = pd.DataFrame(Z_eval)
            df['labels'] = Y_eval
            df.to_csv(os.path.join(run_dir, 'latents.csv'), index=False)
            # BUG FIX: the original passed latents=Y_eval, labels=Z_eval
            # (swapped), contradicting the CSV written just above.
            np.savez(os.path.join(run_dir, 'latents.npz'),
                     latents=Z_eval, labels=Y_eval)
            plot_2Dscatter(
                Z_eval, Y_eval,
                path_to_save=os.path.join(run_dir, 'test_latent_visualization.png'),
                dpi=100, title=None, show=False)

        if run_dir and evalconfig.save_train_latent:
            dataloader_train = torch.utils.data.DataLoader(
                train_dataset, batch_size=config['batch_size'],
                pin_memory=True, drop_last=False)
            X_train, Y_train, Z_train = get_latentspace_representation(
                model, dataloader_train, device='cpu')
            df = pd.DataFrame(Z_train)
            df['labels'] = Y_train
            df.to_csv(os.path.join(run_dir, 'train_latents.csv'), index=False)
            # BUG FIX: originally written to 'latents.npz', silently
            # overwriting the test latents saved above.
            np.savez(os.path.join(run_dir, 'train_latents.npz'),
                     latents=Z_train, labels=Y_train)
            # Visualize latent space
            plot_2Dscatter(
                Z_train, Y_train,
                path_to_save=os.path.join(run_dir, 'train_latent_visualization.png'),
                dpi=100, title=None, show=False)

        if evalconfig.quant_eval and continue2:
            print('QUANT EVAL....')
            ks = list(range(evalconfig.k_min,
                            evalconfig.k_max + evalconfig.k_step,
                            evalconfig.k_step))
            evaluator = Multi_Evaluation(dataloader=dataloader_eval,
                                         seed=config['seed'], model=model)
            ev_result = evaluator.get_multi_evals(X_eval, Z_eval, Y_eval, ks=ks)
            prefixed_ev_result = {
                evalconfig.evaluate_on + '_' + key: value
                for key, value in ev_result.items()
            }
            result.update(prefixed_ev_result)

            s = json.dumps(result, default=default)
            # Context manager so the handle is closed deterministically
            # (original used open(...).write(s) and leaked it).
            with open(os.path.join(run_dir, 'eval_metrics.json'), "w") as fh:
                fh.write(s)

            result_avg = avg_array_in_dict(result)
            df = pd.DataFrame.from_dict(result_avg, orient='index').reset_index()
            df.columns = ['metric', 'value']
            id_dict = dict(
                uid=exp,
                seed=config['seed'],
                batch_size=config['batch_size'],
            )
            for key, value in id_dict.items():
                df[key] = value
            # (The original called df.set_index('uid') and discarded the
            # result -- a no-op, removed.)
            df = df[COLS_DF_RESULT]
            df.to_csv(os.path.join(exp_dir, 'eval_metrics_all.csv'),
                      mode='a', header=False)
        else:
            print('skipped quant eval')
# NOTE(review): script fragment -- relies on names defined elsewhere in the
# file (path_model, MODEL_vae2_1, autoencoder, path_norm2, exp_norm2, images).
# Indentation reconstructed from a whitespace-collapsed source; statement
# grouping below is the most natural reading -- confirm against the original.
if path_model == MODEL_vae2_1:
    # Vanilla autoencoder checkpoint: load its weights directly on CPU.
    state_dict = torch.load(os.path.join(path_model, 'model_state.pth'), map_location=torch.device('cpu'))
    model = autoencoder
    print('Loading Passed')
else:
    # Witness-complex wrapper: weights come from path_model, and the latent
    # normalisation constant is borrowed (scaled by 0.1) from a second run.
    model = WitnessComplexAutoencoder(autoencoder)
    state_dict = torch.load(os.path.join(path_model, 'model_state.pth'), map_location=torch.device('cpu'))
    state_dict2 = torch.load(os.path.join(os.path.join(path_norm2, exp_norm2), 'model_state.pth'))
    if 'latent' not in state_dict:
        # Presumably 'latent' here checks for the 'latent_norm' entry -- TODO confirm.
        state_dict['latent_norm'] = state_dict2['latent_norm'] * 0.1
        print('passed')
model.load_state_dict(state_dict)
model.eval()
# Encode the (externally defined) image batch into latent space.
z = model.encode(images.float())
debugging = True
# NOTE(review): hard-coded absolute path; '.np' extension -- np.save will
# append '.npy' to it.
np.save('/Users/simons/PycharmProjects/MT-VAEs-TDA/src/datasets/simulated/xy_trans_l_newpers/latents_wae.np', z.detach().numpy())
# plot_2Dscatter(z.detach().numpy(), labels, path_to_save=os.path.join(path_save, 'pos_{}.pdf'.format(
#     time.strftime("%Y%m%d-%H%M%S"))), title=None, show=True, palette = 'custom')
def train_TopoAE_ext(_run, _seed, _rnd, config: ConfigWCAE, experiment_dir,
                     experiment_root, device, num_threads, verbose):
    """Train a WitnessComplexAutoencoder for one experiment configuration.

    Samples train/test data from ``config.dataset``, builds the model with a
    dataset-dependent input-distance normalisation constant, optionally warm
    starts from a pre-trained checkpoint, runs ``train``, and appends the
    resulting metrics to ``<experiment_root>/eval_metrics_all.csv``.

    Args:
        _run, _seed, _rnd: sacred-style run object, seed, and RNG.
        config: experiment configuration (ConfigWCAE).
        experiment_dir: directory for this run's artifacts.
        experiment_root: root directory holding the aggregate results CSV.
        device: 'cpu' or a CUDA device for training.
        num_threads: optional torch CPU thread cap (only applied on CPU).
        verbose: if False, training runs quietly.
    """
    if device == 'cpu' and num_threads is not None:
        torch.set_num_threads(num_threads)

    # exist_ok=True replaces the original bare try/except: pass, which also
    # swallowed real errors such as permission failures.
    os.makedirs(experiment_dir, exist_ok=True)
    os.makedirs(experiment_root, exist_ok=True)

    # Create the aggregate results CSV (header only) on first use.
    csv_path = os.path.join(experiment_root, 'eval_metrics_all.csv')
    if not os.path.isfile(csv_path):
        pd.DataFrame(columns=COLS_DF_RESULT).to_csv(csv_path)

    # Sample data
    dataset = config.dataset
    X_train, y_train = dataset.sample(**config.sampling_kwargs, train=True)
    dataset_train = TensorDataset(torch.Tensor(X_train), torch.Tensor(y_train))
    X_test, y_test = dataset.sample(**config.sampling_kwargs, train=False)
    dataset_test = TensorDataset(torch.Tensor(X_test), torch.Tensor(y_test))

    torch.manual_seed(_seed)

    # Initialize model: norm_X rescales input-space distances.
    if isinstance(dataset, (MNIST, MNIST_offline)):
        # -> dimension of images is 28x28, max delta per pixel is 1,
        # since data is normalized.
        norm_X = 28**2 / 10
    elif isinstance(dataset, (Unity_Rotblock, Unity_RotCorgi)):
        norm_X = 180
    elif X_train.shape[0] > 4096:
        # todo fix somehow -- subsample to keep the pairwise cdist tractable.
        inds = random.sample(range(X_train.shape[0]), 2048)
        # dataset_train[inds] yields (X[inds], y[inds]); [...][0] picks X.
        norm_X = torch.cdist(dataset_train[inds][:][0],
                             dataset_train[inds][:][0]).max()
    else:
        norm_X = torch.cdist(dataset_train[:][:][0],
                             dataset_train[:][:][0]).max()

    model_class = config.model_class
    autoencoder = model_class(**config.model_kwargs)
    model = WitnessComplexAutoencoder(autoencoder,
                                      lam_r=config.rec_loss_weight,
                                      lam_t=config.top_loss_weight,
                                      toposig_kwargs=config.toposig_kwargs,
                                      norm_X=norm_X,
                                      device=config.device)

    # Optional warm start from a pre-trained checkpoint.
    if config.method_args['pre_trained_model'] is not None:
        state_dict = torch.load(
            os.path.join(config.method_args['pre_trained_model'],
                         'model_state.pth'))
        model.load_state_dict(state_dict)
    model.to(device)

    # Train and evaluate model
    result = train(model=model,
                   data_train=dataset_train,
                   data_test=dataset_test,
                   config=config,
                   device=device,
                   quiet=not verbose,  # was operator.not_(verbose)
                   val_size=config.method_args['val_size'],
                   _seed=_seed,
                   _rnd=_rnd,
                   _run=_run,
                   rundir=experiment_dir)

    # Format experiment data and append one row per metric to the shared CSV.
    df = pd.DataFrame.from_dict(result, orient='index').reset_index()
    df.columns = ['metric', 'value']
    id_dict = config.create_id_dict()
    for key, value in id_dict.items():
        df[key] = value
    # (The original called df.set_index('uid') and discarded the result --
    # a no-op, removed.)
    df = df[COLS_DF_RESULT]
    df.to_csv(csv_path, mode='a', header=False)