from pathlib import Path

import matplotlib.pyplot as plt
import torch

import sys
sys.path.append('../')

from model.vae_gm import GMVAE
from dataset.cfd import build_datasets
# from dataset.celeba import build_datasets
# from dataset.celeba_single import build_datasets
# from dataset.mnist import build_datasets

IM_DIMS = (218, 178)
TOTAL_IMAGES = 202599
MODEL_PATH = Path('../save/gmvae/final.pt')
DATA_PATH = Path('../data')
# IM_PATH = DATA_PATH / 'img'
IM_PATH = DATA_PATH / 'cfd'

# <codecell>
train_ds, test_ds = build_datasets(IM_PATH)

# <codecell>
model = GMVAE()
ckpt = torch.load(MODEL_PATH, map_location=torch.device('cpu'))
model.load_state_dict(ckpt['model_state_dict'])

model.eval()

# <codecell>
idx = 5
samp_im, _ = test_ds[idx]   # cfd samples are (image, feature) pairs
samp_im = samp_im.numpy().transpose(1, 2, 0)

plt.imshow(samp_im)
plt.show()
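# <codecell>
# Hedged sketch (not part of the original script): push the same sample through
# the GMVAE encoder and inspect the latent statistics. The indexing mirrors the
# adaptation noted later in this file (`mu, var = model.encode(im)[0]` for the
# GMVAE); treat that call signature as an assumption and adapt if it differs.
with torch.no_grad():
    samp_tensor, _ = test_ds[idx]
    mu, var = model.encode(samp_tensor.unsqueeze(0))[0]
print('latent mu shape:', tuple(mu.shape))
print('latent var shape:', tuple(var.shape))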
# <codecell>
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
import torch
from scipy.spatial import distance
from tqdm import tqdm

import sys
sys.path.append('../')

from dataset.cfd import build_datasets
from model.vae import VAE
from util import *

out_path = Path('../save/vae/grid_search')
_, test_ds = build_datasets(Path('../data/cfd'), train_test_split=1)


class ModelData:
    def __init__(self, name: str, save_path: str, params: dict):
        self.name = name
        self.save_path = save_path
        self.params = params


# TODO: retrain VAE models with 2^n - 1 width
configs = [
    ModelData('vae64', '../save/vae/vae64.pt', {'latent_dims': 64}),
    ModelData('vae128', '../save/vae/vae128.pt', {'latent_dims': 129}),
    ModelData('vae256', '../save/vae/vae256.pt', {'latent_dims': 257}),
    ModelData('vae512', '../save/vae/vae512.pt', {'latent_dims': 512}),
]
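# Hedged sketch (not the original grid-search code): one way these configs might
# be consumed. Assumes VAE accepts the params dict as keyword arguments and that
# each checkpoint follows the {'model_state_dict': ...} layout used elsewhere in
# this file; adjust if the constructor or checkpoint format differs.
grid_models = {}
for cfg in configs:
    m = VAE(**cfg.params)
    state = torch.load(cfg.save_path, map_location=torch.device('cpu'))
    m.load_state_dict(state['model_state_dict'])
    m.eval()
    grid_models[cfg.name] = m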
# <codecell>
data_path = Path('../data/cfd')
# data_path = Path('../data/utk')
model_path = Path('../save/vae/vae_jan19_final.pt')
save_path = Path('../save/vae/cfd/latent')

# <codecell>
if not save_path.exists():
    save_path.mkdir(parents=True)

model = VAE()
ckpt = torch.load(model_path, map_location=torch.device('cpu'))
model.load_state_dict(ckpt['model_state_dict'])
model.eval()

_, test_ds = build_datasets(data_path, train_test_split=1)

# test_ds = [test_ds[i] for i in range(10)]
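# Hedged sketch (not part of the original script): encode a single sample first
# to confirm the (mu, var) shapes before looping over the full test set. Uses
# the same model.encode(im) call as the extraction loop below.
with torch.no_grad():
    im0, feat0 = test_ds[0]
    mu0, var0 = model.encode(im0.unsqueeze(0))
print('mu shape:', tuple(mu0.shape), 'var shape:', tuple(var0.shape))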

# <codecell>
test_len = len(test_ds)
ldims = model.latent_dims
mu_points = np.zeros((test_len, ldims))
var_points = np.zeros((test_len, ldims))
feats = []

for i in tqdm(range(test_len)):
    im, feat = test_ds[i]
    im = im.unsqueeze(0)
    mu, var = model.encode(im)
    # mu, var = model.encode(im)[0]   # necessary adaptation for GMVAE