Exemplo n.º 1
0
 def __init__(self):
     """Build the DataParallel-wrapped VAE, load its checkpoint and create the sample dir.

     Configuration is read from the module-level ``args`` namespace
     (``z_dim``, ``model_dim``, ``img_size``, ``img_channels``,
     ``n_res_blocks``, ``device_ids``, ``checkpoint_dir``, ``sample_path``).
     The checkpoint at ``{checkpoint_dir}/VAE.pth`` is expected to hold the
     state dict under the ``'model'`` key. The model is left in eval mode.
     """
     self.device = torch.device(
         'cuda:0' if torch.cuda.is_available() else 'cpu')
     # Place the wrapped model on the device chosen above.
     self.vae = torch.nn.DataParallel(VAE(args.z_dim, args.model_dim,
                                          args.img_size, args.img_channels,
                                          args.n_res_blocks),
                                      device_ids=args.device_ids).to(self.device)
     # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
     checkpoint = torch.load(f"{args.checkpoint_dir}/VAE.pth",
                             map_location=self.device)
     self.vae.load_state_dict(checkpoint['model'])
     self.vae.eval()
     Path(args.sample_path).mkdir(parents=True, exist_ok=True)
Exemplo n.º 2
0
def VAE_interpolating_experiment(device):
    """Compare latent-space interpolation against pixel-space blending.

    Two random latent vectors ``z[0]``, ``z[1]`` are decoded with
    ``model.generate`` into images ``x_0``, ``x_1``. For 11 evenly spaced
    weights ``a`` in [0, 1] it builds:

    * the decodings of the latent mixes ``a*z[0] + (1-a)*z[1]``, and
    * the pixel-space mixes ``a*x_0 + (1-a)*x_1``.

    Both rows are saved separately. A combined grid (pixel mixes, latent
    decodings, and their difference) is also saved. All files go under
    ``vae/results/interpolated/``.

    Args:
        device: torch device that holds the model and tensors.
    """
    batch_size = 2
    latent_dim = 100
    z = torch.randn(batch_size, latent_dim).to(device)

    model = VAE(latent_dim).to(device)
    # NOTE(review): strict=False silently ignores missing/unexpected keys;
    # confirm the checkpoint at ``vae_path`` matches this architecture.
    model.load_state_dict(torch.load(vae_path, map_location=device), strict=False)
    # Inference only: disable train-mode behaviour such as dropout.
    model.eval()

    # Mixing weights 0.0, 0.1, ..., 1.0 (linspace avoids float-step arange drift).
    alphas = np.linspace(0.0, 1.0, 11)
    n_steps = len(alphas)

    # Nothing is trained here, so do not build autograd graphs.
    with torch.no_grad():
        x = model.generate(z).to(device)
        x_0, x_1 = x[0], x[1]

        z_interp = torch.stack([a * z[0] + (1 - a) * z[1] for a in alphas])
        # Stacking keeps the per-image shape produced by ``model.generate``.
        x_interp = torch.stack([a * x_0 + (1 - a) * x_1 for a in alphas])

        zh_y = model.generate(z_interp).to(device)

    out_dir = 'vae/results/interpolated'
    save_images(zh_y, f'{out_dir}/VAE_interpolated_zs.png', nrow=len(zh_y))
    save_images(x_interp, f'{out_dir}/VAE_interpolated_xs.png', nrow=len(x_interp))

    # One grid with three rows: pixel blends, latent decodings, and their difference.
    grid = torch.cat((x_interp, zh_y, x_interp - zh_y), dim=0)
    save_images(grid, f'{out_dir}/VAE_interpolated_xs_zs.png', nrow=n_steps)
Exemplo n.º 3
0
def VAE_disentangled_representation_experiment(device):
    """Find the latent dimensions whose perturbation changes the output most.

    A single random latent vector is decoded as a reference image. Then, for
    every latent dimension ``d``, ``make_interpolation(noise, dim=d)`` is
    decoded. The summed absolute difference from the reference ranks the
    dimensions. Two grids are written under ``vae/results/interpolated/``:

    * the reference followed by the top-10 images plus their differences
      from the reference;
    * all per-dimension decodings.

    Args:
        device: torch device that holds the model and tensors.
    """
    batch_size = 1
    latent_dim = 100
    top_k = 10
    noise = torch.randn(batch_size, latent_dim).to(device)

    model = VAE(latent_dim).to(device)
    # NOTE(review): strict=False silently ignores missing/unexpected keys;
    # confirm the checkpoint at ``vae_path`` matches this architecture.
    model.load_state_dict(torch.load(vae_path, map_location=device), strict=False)
    # Inference only: disable train-mode behaviour such as dropout.
    model.eval()

    with torch.no_grad():
        z_y = model.generate(noise).to(device)
        outputs = []
        for d in range(latent_dim):
            zh = make_interpolation(noise, dim=d).view(batch_size, latent_dim)
            outputs.append(model.generate(zh).to(device))
        outputs = torch.cat(outputs, dim=0)

    # Total absolute change caused by perturbing each latent dimension
    # (one row per dimension, since batch_size is 1).
    per_dim_change = torch.abs(outputs - z_y).view(latent_dim, -1)
    sum_dif = torch.sum(per_dim_change, dim=1).cpu().numpy()
    # Indices of the top_k largest changes, in ascending order of change.
    top_idx = np.argsort(sum_dif)[-top_k:]

    top_k_images = outputs[top_idx].to(device).view(len(top_idx), -1)
    # Prepend the unperturbed reference so it appears first in the grid.
    top_k_images = torch.cat((z_y.view(1, -1), top_k_images))
    difference = top_k_images - z_y.view(1, -1)

    path = 'vae/results/interpolated/VAE_top_disentangleds.png'
    save_images(torch.cat((top_k_images, difference), dim=0), path,
                nrow=top_k + 1)

    path = 'vae/results/interpolated/VAE_disentangleds_all.png'
    save_images(outputs, path, nrow=10)
Exemplo n.º 4
0
if __name__ == '__main__':
    # Each run logs to its own sub-directory, named by the launch time in microseconds.
    run_stamp = int(datetime.now().timestamp() * 1e6)
    writer = SummaryWriter(args.log_dir + f'/{run_stamp}')
    for out_dir in (args.checkpoint_dir, args.log_dir):
        Path(out_dir).mkdir(parents=True, exist_ok=True)

    loader = get_celeba_loaders(args.data_path, args.img_ext, args.crop_size,
                                args.img_size, args.batch_size, args.download)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Model: the "plus" variant (no residual-block count) or the residual VAE.
    model = (
        VAE_Plus(args.z_dim, args.model_dim, args.img_size, args.img_channels)
        if args.plus
        else VAE(args.z_dim, args.model_dim, args.img_size,
                 args.img_channels, args.n_res_blocks)
    ).to(device)

    if args.data_parallel:
        model = torch.nn.DataParallel(
            model, device_ids=args.device_ids).to(device)

    # Custom weight initialisation, then optimiser, LR schedule and loss.
    model.apply(initialize_modules)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=args.betas)
    # Multiply the learning rate by a constant 0.995 on every scheduler step.
    scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
        optimizer, lambda _epoch: 0.995)
    criterion = VAELoss(args.recon, args.beta, args.reduction)

    # Latent batch kept fixed so samples can be compared across training.
    fixed_z = torch.randn(args.sample_size, args.z_dim).to(device)
Exemplo n.º 5
0
torch.manual_seed(args.seed)
np.random.seed(args.seed)

# Agent
# NOTE(review): ``device`` is computed but nothing below moves the model or
# data onto it; the loop below runs on whatever device the model loads to.
device = torch.device("cuda" if args.cuda else "cpu")

dataset1 = ReplayMemoryDataset(args.agent_memory1)

state_size = env.observation_space.shape[0]
encoder = Encoder(state_size,
                  hidden_dim=args.hidden_dim,
                  latent_dim=args.latent_dim)
decoder = Decoder(args.latent_dim,
                  hidden_dim=args.hidden_dim,
                  output_dim=state_size)
model = VAE(encoder, decoder)

model.load_model(args.model_path)
# Evaluation only: put the model in inference mode.
model.eval()

# Evaluation loop
total_numsteps = 0
avg_reward = 0.

state = env.reset()

# For every stored transition, reconstruct its state with the VAE, set the
# environment to the reconstruction and record the rendered frame.
# no_grad: nothing is trained here, so skip building autograd graphs.
with imageio.get_writer(args.video_file_name, fps=30) as video, torch.no_grad():
    for state, action, reward, next_state, done in dataset1:
        x_hat, _, _ = model(state)
        env.set_to_observation(x_hat.detach().numpy())
        video.append_data(env.render('rgb_array'))
validate_x, validate_cond = vae_make_one_hot_data(validate_data)

print(train_x.shape)
print(train_cond.shape)

# Inspect one training sample: ``a`` from train_x and ``b`` from train_cond.
sindex = 999
a = train_x[sindex]
b = train_cond[sindex]
print(a.shape)
print(b.shape)
# Reshape into (-1, 32, 130) and (-1, 32, 12); presumably 32 time steps with
# 130-dim and 12-dim features respectively — TODO confirm.
a = np.reshape(a, (-1, 32, 130))
b = np.reshape(b, (-1, 32, 12))
print(a.shape)
print(b.shape)
# initialize model
model = VAE(130, 2048, 3, 12, 128, 128, 32)
model.eval()
# Load onto the CPU first so a GPU-saved checkpoint also works on CPU-only hosts.
dic = torch.load("vae/tr_chord.pt", map_location="cpu")
# Strip the leading 'module.' that DataParallel adds to parameter names.
# Only the prefix is removed; the dict is edited in place.
prefix = 'module.'
for name in list(dic.keys()):
    if name.startswith(prefix):
        dic[name[len(prefix):]] = dic.pop(name)
model.load_state_dict(dic)
if torch.cuda.is_available():
    model = model.cuda()
print(model)

# Inputs must live on the same device as the model's parameters.
model_device = next(model.parameters()).device
a = torch.from_numpy(a).float().to(model_device)
b = torch.from_numpy(b).float().to(model_device)
with torch.no_grad():
    res = model.encoder(a, b)
# Copy to the host before converting, since .numpy() rejects CUDA tensors.
z1 = res[0].loc.detach().cpu().numpy()
z2 = res[1].loc.detach().cpu().numpy()
print(z1)
Exemplo n.º 7
0
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
import numpy as np
from tqdm import trange

from sa.model import MnistClassifier
from vae.model import VAE

img_size = 28*28*1
# Nothing is trained in this script, so disable autograd globally.
# A bare ``torch.no_grad()`` expression has no effect; it only works as a
# context manager or decorator.
torch.set_grad_enabled(False)


### Prep (e.g. Load Models) ###
vae = VAE(img_size=img_size, h_dim=1600, z_dim=400)
vae.load_state_dict(torch.load('./vae/models/MNIST_EnD.pth'))
vae.eval()  # inference only, matching the classifier below
vae.cuda()

classifier = MnistClassifier(img_size=img_size)
classifier.load_state_dict(torch.load('./sa/models/MNIST_conv_classifier.pth'))
classifier.eval()
classifier.cuda()
print("models loaded...")

test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor(), download=False)
test_data_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1, shuffle=True)
print("Data loader ready...")

### GA Params ###
gen_num = 500
Exemplo n.º 8
0
# Seed every RNG source used below, in order: env, action sampling, torch, numpy.
for seed_fn in (env.seed, env.action_space.seed, torch.manual_seed, np.random.seed):
    seed_fn(args.seed)

dataset1 = ReplayMemoryDataset(args.agent_memory1)

# Encoder takes a state_size-dim input and produces latent_dim outputs;
# the decoder mirrors it back to state_size.
state_size = env.observation_space.shape[0]
encoder = Encoder(state_size, hidden_dim=args.hidden_dim, latent_dim=args.latent_dim)
decoder = Decoder(args.latent_dim, hidden_dim=args.hidden_dim, output_dim=state_size)
model = VAE(encoder, decoder)

# TensorBoard: one run directory per launch, tagged with start time and env name.
datetime_st = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
log_dir = f'runs/{datetime_st}_VAE_{args.env_name}'
writer = SummaryWriter(log_dir)

dataloader = DataLoader(dataset1, batch_size=args.batch_size, shuffle=True, num_workers=0)


def loss_function(x, x_hat, mean, log_var):
    reproduction_loss = F.mse_loss(x_hat, x)
    KLD = -0.5 * torch.mean(