def reconstruct(self, x_orig, device, epoch):
    """Log a grid comparing input images with their reconstructions."""
    # Keep only the first 32 datapoints so the output grid stays a
    # reasonable size for visualization.
    x_orig = x_orig[:32, :, :, :].to(device)

    # Encode: sample z from the bottom (zi = 1) inference model.
    mu, scale = self.infer(0)(given=x_orig)
    eps = random.logistic_eps(mu.shape, device=device)
    z = random.transform(eps, mu, scale)

    # Decode: sample a continuous x from the bottom (zi = 1) generative model.
    mu, scale = self.generate(0)(given=z)
    x_eps = random.logistic_eps(mu.shape, device=device)
    x_cont = random.transform(x_eps, mu, scale)

    # Map from [-1, 1] to [0, 255] and clamp into the valid pixel range.
    x_sample = torch.clamp((x_cont * 127.5) + 127.5, 0, 255)

    # Normalize both originals and reconstructions to [0, 1] floats.
    x_sample = x_sample.float() / 255.
    x_orig = x_orig.float() / 255.

    # Concatenate originals with their reconstructions for side-by-side
    # comparison, reshape to image layout, and render one grid.
    x_with_recon = torch.cat((x_orig, x_sample))
    x_with_recon = x_with_recon.view((2 * x_orig.shape[0],) + self.xs)
    x_grid = utils.make_grid(x_with_recon)

    self.logger.add_image('x_reconstruct', x_grid, epoch)
def sample(self, device, epoch, num=64):
    """Draw `num` samples from the model and log them as an image grid."""
    # Sample "num" latent variables from the (standard Logistic) prior.
    z = random.logistic_eps(((num,) + self.zdim), device=device)

    # Ancestral sampling: walk the generative models from the top layer
    # down to the bottom one, re-sampling at every level.
    for i in reversed(range(self.nz)):
        mu, scale = self.generate(i)(given=z)
        eps = random.logistic_eps(mu.shape, device=device)
        z = random.transform(eps, mu, scale)

    # Map from [-1, 1] to [0, 255] and clamp into the valid pixel range.
    x = torch.clamp((z * 127.5) + 127.5, 0, 255)

    # Normalize to [0, 1] and reshape into "num" images.
    x_sample = (x.float() / 255.).view((num,) + self.xs)

    # Render the samples as a single grid and log it.
    x_grid = utils.make_grid(x_sample)
    self.logger.add_image('x_sample', x_grid, epoch)
def loss(self, x):
    """Compute the per-layer ELBO terms for a batch `x`, in bits.

    Returns:
        logrecon: mean reconstruction log-likelihood (discretized Logistic).
        logdec:   per-layer generative-model log-likelihoods (incl. prior).
        logenc:   per-layer inference-model log-likelihoods.
        zsamples: flattened latent sample per layer.
    """
    # Tensor to store inference model losses.
    logenc = torch.zeros((self.nz, x.shape[0], self.zdim[0]), device=x.device)
    # Tensor to store the generative model losses.
    logdec = torch.zeros((self.nz, x.shape[0], self.zdim[0]), device=x.device)
    # Tensor to store the latent samples.
    zsamples = torch.zeros((self.nz, x.shape[0], np.prod(self.zdim)), device=x.device)

    for i in range(self.nz):
        # Inference model: parameters of distribution i given x (if i == 0)
        # or the previous latent sample z (otherwise).
        mu, scale = self.infer(i)(given=x if i == 0 else z)
        # Reparameterization trick: sample eps ~ Logistic(0, 1), then
        # transform using the obtained parameters.
        eps = random.logistic_eps(mu.shape, device=mu.device)
        z_next = random.transform(eps, mu, scale)
        zsamples[i] = z_next.flatten(1)

        # Store the inference model loss.
        logq = torch.sum(random.logistic_logp(mu, scale, z_next), dim=2)
        logenc[i] += logq

        # Generative model: parameters of distribution i given z_next.
        mu, scale = self.generate(i)(given=z_next)
        if i == 0:
            # Bottom (zi = 1) generative model: evaluate x under a
            # discretized Logistic distribution.
            logrecon = torch.sum(random.discretized_logistic_logp(mu, scale, x), dim=1)
        else:
            # Higher layers: evaluate the previous latent sample z.
            # (The original `x if i == 0 else z` here was dead code — this
            # branch only runs when i != 0, so the operand is always z.)
            logp = torch.sum(random.logistic_logp(mu, scale, z), dim=2)
            logdec[i - 1] += logp

        z = z_next

    # Prior loss: evaluate the top latent under a standard Logistic prior.
    logp = torch.sum(random.logistic_logp(torch.zeros(1, device=x.device),
                                          torch.ones(1, device=x.device), z), dim=2)
    logdec[self.nz - 1] += logp

    # Convert from "nats" to bits.
    logenc = torch.mean(logenc, dim=1) * self.bitsscale
    logdec = torch.mean(logdec, dim=1) * self.bitsscale
    logrecon = torch.mean(logrecon) * self.bitsscale

    return logrecon, logdec, logenc, zsamples