class VAE(nn.Module):
    """Variational autoencoder for 28x28 images, written against the legacy Pyro API.

    The generative model is p(x|z)p(z): a standard-normal prior on the latent
    code z and a Bernoulli likelihood whose means are produced by ``Decoder``.
    The guide is the variational distribution q(z|x), parameterised by
    ``Encoder``.

    Args:
        z_dim: dimensionality of the latent space.
        hidden_dim: hidden width forwarded to both ``Encoder`` and ``Decoder``.
        enc_kernel1, enc_kernel2: forwarded to ``Encoder`` (their meaning is
            defined by ``Encoder``; presumably convolution kernel sizes).
        use_cuda: if True, move all encoder/decoder parameters to the GPU.
    """

    def __init__(self, z_dim=50, hidden_dim=400, enc_kernel1=5, enc_kernel2=5,
                 use_cuda=False):
        super(VAE, self).__init__()
        # create the encoder and decoder networks
        self.encoder = Encoder(z_dim, hidden_dim, enc_kernel1, enc_kernel2)
        self.decoder = Decoder(z_dim, hidden_dim)
        if use_cuda:
            # calling cuda() here will put all the parameters of
            # the encoder and decoder networks into gpu memory
            self.cuda()
        self.use_cuda = use_cuda
        self.z_dim = z_dim

    def model(self, x):
        """Generative model p(x|z)p(z), scored against the observed images ``x``.

        ``x`` is flattened to 784 (= 28 * 28) values per image before scoring.
        """
        # register PyTorch module `decoder` with Pyro
        pyro.module("decoder", self.decoder)
        # setup hyperparameters for prior p(z);
        # type_as ensures we get cuda tensors if x is on the gpu
        z_mu = ng_zeros([x.size(0), self.z_dim], type_as=x.data)
        z_sigma = ng_ones([x.size(0), self.z_dim], type_as=x.data)
        # sample from prior (value will be sampled by guide when computing the ELBO)
        z = pyro.sample("latent", dist.normal, z_mu, z_sigma)
        # decode the latent code z; call the module (not .forward) so hooks run
        mu_img = self.decoder(z)
        # score against actual images
        pyro.observe("obs", dist.bernoulli, x.view(-1, 784), mu_img)

    def guide(self, x):
        """Variational distribution q(z|x), parameterised by the encoder."""
        # register PyTorch module `encoder` with Pyro
        pyro.module("encoder", self.encoder)
        # use the encoder to get the parameters used to define q(z|x)
        z_mu, z_sigma = self.encoder(x)
        # sample the latent code z
        pyro.sample("latent", dist.normal, z_mu, z_sigma)

    def reconstruct_img(self, x):
        """Encode one image, sample z from q(z|x) and return the decoded pixel means.

        ``x`` must hold exactly 28 * 28 values; it is reshaped to (1, 1, 28, 28).
        No sampling is done in image space: the decoder output is returned as-is.
        """
        x = x.view(1, 1, 28, 28)
        z_mu, z_sigma = self.encoder(x)
        # sample in latent space
        z = dist.normal(z_mu, z_sigma)
        # decode the image (note we don't sample in image space)
        return self.decoder(z)

    def model_sample(self, batch_size=1):
        """Draw ``batch_size`` images from the generative model.

        Returns:
            ``(xs, mu)``: the Bernoulli image samples and the decoder means
            they were drawn from.
        """
        # sample the handwriting style from the constant prior distribution
        prior_mu = Variable(torch.zeros([batch_size, self.z_dim]))
        prior_sigma = Variable(torch.ones([batch_size, self.z_dim]))
        if self.use_cuda:
            # the decoder parameters were moved to the GPU in __init__, so the
            # prior must live there too or decoding fails with a device mismatch
            prior_mu = prior_mu.cuda()
            prior_sigma = prior_sigma.cuda()
        zs = pyro.sample("z", dist.normal, prior_mu, prior_sigma)
        mu = self.decoder(zs)
        xs = pyro.sample("sample", dist.bernoulli, mu)
        return xs, mu
class VAE(Module):
    """Conditional VAE over flattened square images and categorical labels.

    The model defines p(x | z, y) p(z): a standard-normal prior on the latent
    code z and a Bernoulli likelihood whose means the decoder produces from z
    and the one-hot label y. The guide defines the approximate posterior
    q(z | x, y) using the encoder.

    Args:
        latents_sizes: number of classes of each label attribute (one entry per
            column of ``label``).
        latents_names: name of each label attribute, in the same order as
            ``latents_sizes``; used to name the Pyro label sample sites.
        img_dim: length of the flattened image vector. The plotting helpers
            assume a square image (e.g. 4096 = 64 * 64).
        label_dim: dimension of the label vector fed to encoder and decoder
            (presumably ``sum(latents_sizes)`` — confirm).
        latent_dim: dimension of the latent space Z.
        use_CUDA: if True, move the encoder/decoder parameters to the GPU.
    """

    def __init__(self, latents_sizes, latents_names, img_dim=4096,
                 label_dim=114, latent_dim=200, use_CUDA=False):
        super(VAE, self).__init__()
        # creating networks
        self.encoder = Encoder(img_dim, label_dim, latent_dim)
        self.decoder = Decoder(img_dim, label_dim, latent_dim)
        self.img_dim = img_dim
        self.label_dim = label_dim
        self.latent_dim = latent_dim
        self.latents_sizes = latents_sizes
        self.latents_names = latents_names
        if use_CUDA:
            self.cuda()
        self.use_CUDA = use_CUDA

    def label_variable(self, label):
        """Register each label column as an observed one-hot categorical site.

        Args:
            label: tensor of shape (batch, len(latents_sizes)) whose column ``i``
                holds the class index of attribute ``i``.

        Returns:
            float32 tensor on ``label.device`` with the one-hot encodings of all
            attributes concatenated along the last dimension.
        """
        options = {'device': label.device, 'dtype': label.dtype}
        one_hots = []
        for i, length in enumerate(self.latents_sizes):
            n_classes = int(length)
            # uniform prior over the classes of attribute i
            prior = torch.ones(label.shape[0], length, **options) / (1.0 * length)
            # .long() casts in place of torch.tensor(<tensor>), which copies and warns
            observed = one_hot(label[:, i].long(), n_classes)
            one_hots.append(
                pyro.sample("label_" + str(self.latents_names[i]),
                            OneHotCategorical(prior), obs=observed))
        return torch.cat(one_hots, -1).to(device=label.device, dtype=torch.float32)

    def model(self, img, label):
        """Generative model p(x | z, y) p(z) for a batch of images and labels."""
        pyro.module("decoder", self.decoder)
        options = {'device': img.device, 'dtype': img.dtype}
        with pyro.plate("data", img.shape[0]):
            # standard-normal prior; Normal's second argument is a std-dev (scale)
            z_loc = torch.zeros(img.shape[0], self.latent_dim, **options)
            z_scale = torch.ones(img.shape[0], self.latent_dim, **options)
            z_sample = pyro.sample("latent", Normal(z_loc, z_scale).to_event(1))
            image = self.decoder(z_sample, self.label_variable(label))
            pyro.sample("obs", Bernoulli(image).to_event(1), obs=img)

    def guide(self, img, label):
        """Approximate posterior q(z | x, y) computed by the encoder."""
        pyro.module("encoder", self.encoder)
        with pyro.plate("data", img.shape[0]):
            # NOTE(review): label_variable() also registers observed "label_*"
            # sites inside the guide — confirm this is intended.
            # The encoder's second output is used as Normal's scale (std-dev).
            z_loc, z_scale = self.encoder(img, self.label_variable(label))
            pyro.sample("latent", Normal(z_loc, z_scale).to_event(1))

    def _image_side(self):
        """Return the side length of the square image described by ``img_dim``."""
        side = int(round(self.img_dim ** 0.5))
        if side * side != self.img_dim:
            raise ValueError(
                "img_dim={} is not a perfect square; cannot plot it as an image"
                .format(self.img_dim))
        return side

    @staticmethod
    def _plot_row(n_rows, row, title, image, side):
        """Draw ``image`` as a side x side picture in 1-based ``row`` of a one-column grid."""
        axis = plt.subplot(n_rows, 1, row)
        axis.set_title(title)
        plt.imshow(image.reshape(side, side))
        return axis

    def run_img(self, img, label, num=1):
        """Plot ``img`` followed by ``num`` reconstructions decoded from z ~ q(z | x, y).

        Args:
            img: a single flattened image (``img_dim`` values).
            label: the image's label row (numpy array, one entry per attribute).
            num: number of latent samples to decode and plot.
        """
        label = label.reshape(1, -1)
        dummy_label = dummy_from_label(label)
        img = tensor(img.reshape(-1, self.img_dim)).to(torch.float32)
        side = self._image_side()
        n_rows = num + 1  # the original image plus `num` reconstructions
        with torch.no_grad():
            mean, scale = self.encoder(img, dummy_label)
            plt.figure(figsize=(4, n_rows * 5))
            self._plot_row(n_rows, 1, 'Original image', img, side)
            # rows 2..num+1 hold samples 1..num (range(1, num) dropped the last one)
            for i in range(1, num + 1):
                z_sample = Normal(mean, scale).sample()
                vae_img = self.decoder(z_sample, dummy_label)
                self._plot_row(n_rows, i + 1, str(i) + ' - sample of latent space',
                               vae_img.cpu().numpy(), side)
        plt.show()

    def change_attribute(self, img, label, attribute=1):
        """Plot ``img``, a reconstruction, and a reconstruction with one attribute changed.

        A new value for column ``attribute`` of ``label`` is drawn uniformly at
        random until it differs from the original.

        Raises:
            ValueError: if the attribute has fewer than two possible values
                (the redraw loop could never terminate).
        """
        print('Attribute changed was ' + str(self.latents_names[attribute]))
        n_values = int(self.latents_sizes[attribute])
        if n_values < 2:
            raise ValueError(
                "attribute {!r} has fewer than two values; nothing to change to"
                .format(self.latents_names[attribute]))
        label = label.reshape(1, -1)
        new_label = np.copy(label)
        # redraw until the attribute actually takes a different value
        while (new_label == label).all():
            new_label[0, attribute] = np.random.choice(n_values)
        dummy_label = dummy_from_label(label)
        new_dummy = dummy_from_label(new_label)
        img = tensor(img.reshape(-1, self.img_dim)).to(torch.float32)
        side = self._image_side()
        with torch.no_grad():
            mean, scale = self.encoder(img, dummy_label)
            plt.figure(figsize=(4, 15))
            self._plot_row(3, 1, 'Original image', img, side)
            original = self.decoder(Normal(mean, scale).sample(), dummy_label)
            self._plot_row(3, 2, 'Sample with original attribute',
                           original.cpu().numpy(), side)
            changed = self.decoder(Normal(mean, scale).sample(), new_dummy)
            self._plot_row(3, 3, 'Sample with changed attribute',
                           changed.cpu().numpy(), side)
        plt.show()