def negative_iwae_bound(self, x, iw):
    """
    Computes the Importance Weighted Autoencoder Bound
    Additionally, we also compute the ELBO KL and reconstruction terms

    Args:
        x: tensor: (batch, dim): Observations
        iw: int: (): Number of importance weighted samples

    Returns:
        niwae: tensor: (): Negative IWAE bound
        kl: tensor: (): ELBO KL divergence to prior
        rec: tensor: (): ELBO Reconstruction term
    """
    ################################################################################
    # TODO: Modify/complete the code here
    # Compute niwae (negative IWAE) with iw importance samples, and the KL
    # and Rec decomposition of the Evidence Lower Bound
    #
    # Outputs should all be scalar
    ################################################################################
    batch = x.size()[0]
    # Plan: sample iw z's per observation, evaluate log p(x, z_i) and
    # log q(z_i | x) for each, then average the importance weights.
    phi_m, phi_v = self.enc.encode(x)  # (batch, z_dim)
    phi_m, phi_v = ut.duplicate(phi_m, iw), ut.duplicate(phi_v, iw)  # (batch*iw, z_dim)
    x_iw = ut.duplicate(x, iw)
    z_hat = ut.sample_gaussian(phi_m, phi_v)  # (batch*iw, z_dim)
    log_q_zx = ut.log_normal(z_hat, phi_m, phi_v)  # (batch*iw)
    log_p_z = ut.log_normal(z_hat, *self.z_prior)  # (batch*iw)
    log_p_xz = ut.log_bernoulli_with_logits(x_iw, self.dec.decode(z_hat))  # (batch*iw)

    # ut.duplicate stacks iw copies of the batch along dim 0, so unflatten
    # as (iw, batch) and transpose to (batch, iw)
    f = lambda t: t.reshape(iw, batch).transpose(1, 0)
    log_p_xz, log_q_zx, log_p_z = f(log_p_xz), f(log_q_zx), f(log_p_z)

    iwae = ut.log_mean_exp(log_p_xz - log_q_zx + log_p_z, -1)
    niwae = -iwae.mean(0)
    # The ELBO decomposition uses Monte Carlo averages (not log-mean-exp)
    # of the KL log-ratio and the reconstruction log-likelihood
    kl = (log_q_zx - log_p_z).mean(-1).mean(0)
    rec = -log_p_xz.mean(-1).mean(0)
    ################################################################################
    # End of code modification
    ################################################################################
    return niwae, kl, rec
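# Every snippet in this section leans on a small `ut` module. As a point of
# reference, here is a minimal sketch of the helper semantics the snippets
# assume -- these bodies are a reconstruction for illustration, not
# necessarily the exact course-provided utilities.

import math
import torch
import torch.nn.functional as F

def duplicate(x, rep):
    # Stack `rep` copies of the whole batch along dim 0: (batch, ...) -> (rep*batch, ...)
    return x.expand(rep, *x.shape).reshape(-1, *x.shape[1:])

def sample_gaussian(m, v):
    # Reparameterized sample z = m + sqrt(v) * eps with eps ~ N(0, I)
    return m + torch.sqrt(v) * torch.randn_like(m)

def log_normal(x, m, v):
    # Diagonal Gaussian log density, summed over the last dimension
    return (-0.5 * (math.log(2 * math.pi) + torch.log(v) + (x - m) ** 2 / v)).sum(-1)

def log_bernoulli_with_logits(x, logits):
    # Bernoulli log likelihood of x under `logits`, summed over the last dimension
    return -F.binary_cross_entropy_with_logits(logits, x, reduction='none').sum(-1)

def log_mean_exp(x, dim):
    # Numerically stable log of the mean of exponentials along `dim`
    return torch.logsumexp(x, dim) - math.log(x.size(dim))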
def negative_iwae_bound(self, x, iw):
    """
    Computes the Importance Weighted Autoencoder Bound
    Additionally, we also compute the ELBO KL and reconstruction terms

    Args:
        x: tensor: (batch, dim): Observations
        iw: int: (): Number of importance weighted samples

    Returns:
        niwae: tensor: (): Negative IWAE bound
        kl: tensor: (): ELBO KL divergence to prior
        rec: tensor: (): ELBO Reconstruction term
    """
    ################################################################################
    # TODO: Modify/complete the code here
    # Compute niwae (negative IWAE) with iw importance samples, and the KL
    # and Rec decomposition of the Evidence Lower Bound
    #
    # Outputs should all be scalar
    ################################################################################
    q_m, q_v = self.enc.encode(x)
    q_m_, q_v_ = ut.duplicate(q_m, rep=iw), ut.duplicate(q_v, rep=iw)
    z_given_x = ut.sample_gaussian(q_m_, q_v_)
    decoded_bernoulli_logits = self.dec.decode(z_given_x)

    # duplicate x to match the iw samples
    x_dup = ut.duplicate(x, rep=iw)
    rec_vec = ut.log_bernoulli_with_logits(x_dup, decoded_bernoulli_logits)

    # compute kl against the standard normal prior
    p_m, p_v = torch.zeros_like(q_m_), torch.ones_like(q_v_)
    log_q_phi = ut.log_normal(z_given_x, q_m_, q_v_)  # encoded distribution
    log_p = ut.log_normal(z_given_x, p_m, p_v)  # prior distribution
    kl_vec = log_q_phi - log_p

    # log importance weights; ut.duplicate stacks iw copies of the batch,
    # so reshape to (iw, batch) and take the log-mean over dim 0
    log_w = rec_vec - kl_vec
    niwae = -torch.mean(ut.log_mean_exp(log_w.reshape(iw, -1), dim=0))
    kl = torch.mean(kl_vec)
    rec = -torch.mean(rec_vec)
    ################################################################################
    # End of code modification
    ################################################################################
    return niwae, kl, rec
def negative_iwae_bound_for(self, x, y, c, iw):
    """
    Computes the Importance Weighted Autoencoder Bound
    Additionally, we also compute the ELBO KL and reconstruction terms

    Args:
        x: tensor: (batch, dim): Observations
        y: tensor: (batch, y_dim): whether observations contain EV
        c: tensor: (batch, c_dim): target mapping specification
        iw: int: (): Number of importance weighted samples

    Returns:
        niwae: tensor: (): Negative IWAE bound
        kl: tensor: (): ELBO KL divergence to prior
        rec: tensor: (): ELBO Reconstruction term
        rec_mse: tensor: (): Reconstruction MSE component
        rec_var: tensor: (): Reconstruction variance component
    """
    # encode
    qm, qv = self.enc.encode(x, y=y, c=c)
    # replicate qm, qv: (batch, ...) -> (batch, iw, ...)
    q_shape = list(qm.shape)
    qm = qm.unsqueeze(1).expand(q_shape[0], iw, *q_shape[1:])
    qv = qv.unsqueeze(1).expand(q_shape[0], iw, *q_shape[1:])
    # sample z(1)...z(iw) for the Monte Carlo estimate of p(x|z(i))
    z = ut.sample_gaussian(qm, qv)
    kl_elem = self.kl_elementwise(z, qm, qv)

    # reshape for LSTM: fold the iw dimension into the batch,
    # replicating z, x, y, c to match
    z_shape = list(z.shape)
    z = z.reshape(z_shape[0] * iw, *z_shape[2:])
    x_shape = list(x.shape)
    x = x.unsqueeze(1).expand(x_shape[0], iw, *x_shape[1:])
    x = x.reshape(x_shape[0] * iw, *x_shape[1:])
    if y is not None:
        y_shape = list(y.shape)
        y = y.unsqueeze(1).expand(y_shape[0], iw, *y_shape[1:])
        y = y.reshape(y_shape[0] * iw, *y_shape[1:])
    if c is not None:
        c_shape = list(c.shape)
        c = c.unsqueeze(1).expand(c_shape[0], iw, *c_shape[1:])
        c = c.reshape(c_shape[0] * iw, *c_shape[1:])

    # decode
    mu, var = self.dec.decode(z, y=y, c=c)
    nll, rec_mse, rec_var = ut.nlog_prob_normal(mu=mu, y=x, var=var,
                                                fixed_var=self.warmup,
                                                var_pen=self.var_pen)
    log_prob, rec_mse, rec_var = -nll, rec_mse.mean(), rec_var.mean()
    log_prob = log_prob.view(x_shape[0], iw)

    niwae = -ut.log_mean_exp(log_prob - kl_elem, dim=1).mean(-1)
    # reduce
    rec = -log_prob.mean(1).mean(-1)
    kl = kl_elem.mean(1).mean(-1)
    return niwae, kl, rec, rec_mse, rec_var
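# Two replication conventions appear across these snippets, and the reduction
# axis must match the one used: ut.duplicate stacks iw copies of the batch
# (iw-major, unflatten as (iw, batch)), while unsqueeze(1).expand keeps the
# batch outermost (batch-major, reduce along dim 1). A small self-contained
# check with hypothetical shapes, no model required:

import torch

batch, iw, dim = 3, 4, 2
x = torch.arange(batch * dim, dtype=torch.float32).reshape(batch, dim)

# Convention A (ut.duplicate): iw-major stacking -> unflatten as (iw, batch, ...)
x_a = x.expand(iw, *x.shape).reshape(-1, dim)   # (iw*batch, dim)
assert torch.equal(x_a.reshape(iw, batch, dim)[0], x)

# Convention B (unsqueeze/expand): batch-major -> reduce along dim 1
x_b = x.unsqueeze(1).expand(batch, iw, dim)     # (batch, iw, dim)
assert torch.equal(x_b[:, 0, :], x)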
def negative_iwae_bound(self, x, iw):
    """
    Computes the Importance Weighted Autoencoder Bound
    Additionally, we also compute the ELBO KL and reconstruction terms

    Args:
        x: tensor: (batch, dim): Observations
        iw: int: (): Number of importance weighted samples

    Returns:
        niwae: tensor: (): Negative IWAE bound
        kl: tensor: (): ELBO KL divergence to prior
        rec: tensor: (): ELBO Reconstruction term
    """
    ################################################################################
    # TODO: Modify/complete the code here
    # Compute niwae (negative IWAE) with iw importance samples, and the KL
    # and Rec decomposition of the Evidence Lower Bound
    #
    # Outputs should all be scalar
    ################################################################################
    # Compute the mixture of Gaussians prior
    pm, pv = ut.gaussian_parameters(self.z_pre, dim=1)

    # Generate iw samples and accumulate the per-sample terms
    qm, qv = self.enc.encode(x)
    niwaes = []
    recs = []
    kls = []
    for i in range(iw):
        z_sample = ut.sample_gaussian(qm, qv).view(-1, qm.shape[1])
        logits = self.dec.decode(z_sample)
        logptheta_x_g_z = ut.log_bernoulli_with_logits(x, logits)
        logptheta_z = ut.log_normal_mixture(z_sample, pm, pv)
        logqphi_z_g_x = ut.log_normal(z_sample, qm, qv)
        niwae = logptheta_x_g_z + logptheta_z - logqphi_z_g_x

        # ELBO decomposition terms for this sample
        rec = -logptheta_x_g_z
        kl = logqphi_z_g_x - logptheta_z
        niwaes.append(niwae)
        recs.append(rec)
        kls.append(kl)

    niwaes = torch.stack(niwaes, -1)
    niwae = ut.log_mean_exp(niwaes, -1)
    kl = torch.stack(kls, -1)
    rec = torch.stack(recs, -1)
    ################################################################################
    # End of code modification
    ################################################################################
    return -niwae.mean(), kl.mean(), rec.mean()
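# The mixture-of-Gaussians-prior variants (this one and several below)
# additionally assume ut.gaussian_parameters and ut.log_normal_mixture.
# A minimal sketch of the semantics these calls rely on -- again a
# reconstruction for illustration, reusing log_normal and log_mean_exp
# from the sketch above, not the exact course utilities:

import torch
import torch.nn.functional as F

def gaussian_parameters(h, dim=-1):
    # Split h into mean and variance halves along `dim`; softplus keeps v > 0
    m, h_v = torch.split(h, h.size(dim) // 2, dim=dim)
    v = F.softplus(h_v) + 1e-8
    return m, v

def log_normal_mixture(z, m, v):
    # z: (batch, dim); m, v: (batch, k, dim) components with uniform weights.
    # Returns log p(z) = log mean_k N(z; m_k, v_k), shape (batch,)
    z = z.unsqueeze(1)               # (batch, 1, dim), broadcasts against components
    log_probs = log_normal(z, m, v)  # (batch, k)
    return log_mean_exp(log_probs, dim=1)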
def negative_iwae_bound(self, x, iw):
    """
    Computes the Importance Weighted Autoencoder Bound
    Additionally, we also compute the ELBO KL and reconstruction terms

    Args:
        x: tensor: (batch, dim): Observations
        iw: int: (): Number of importance weighted samples

    Returns:
        niwae: tensor: (): Negative IWAE bound
        kl: tensor: (): ELBO KL divergence to prior
        rec: tensor: (): ELBO Reconstruction term
    """
    ################################################################################
    # TODO: Modify/complete the code here
    # Compute niwae (negative IWAE) with iw importance samples, and the KL
    # and Rec decomposition of the Evidence Lower Bound
    #
    # Outputs should all be scalar
    ################################################################################
    # Compute the mixture of Gaussians prior
    prior = ut.gaussian_parameters(self.z_pre, dim=1)

    m, v = self.enc.encode(x)
    dist = Normal(loc=m, scale=torch.sqrt(v))
    z_sample = dist.rsample(sample_shape=torch.Size([iw]))  # (iw, batch, z_dim)

    log_batch_z_sample_prob = []
    kl_batch_z_sample = []
    for i in range(iw):
        recon_logits = self.dec.decode(z_sample[i])
        log_batch_z_sample_prob.append(
            ut.log_bernoulli_with_logits(x, recon_logits))  # (batch,)
        kl_batch_z_sample.append(
            ut.log_normal(z_sample[i], m, v) -
            ut.log_normal_mixture(z_sample[i], prior[0], prior[1]))
    log_batch_z_sample_prob = torch.stack(log_batch_z_sample_prob, dim=1)  # (batch, iw)
    kl_batch_z_sample = torch.stack(kl_batch_z_sample, dim=1)  # (batch, iw)

    niwae = -ut.log_mean_exp(log_batch_z_sample_prob - kl_batch_z_sample,
                             dim=1).mean(dim=0)
    # Reduce to scalars over both the batch and sample dimensions
    rec = -torch.mean(log_batch_z_sample_prob)
    kl = torch.mean(kl_batch_z_sample)
    ################################################################################
    # End of code modification
    ################################################################################
    return niwae, kl, rec
def negative_iwae_bound(self, x, iw):
    """
    Computes the Importance Weighted Autoencoder Bound
    Additionally, we also compute the ELBO KL and reconstruction terms

    Args:
        x: tensor: (batch, dim): Observations
        iw: int: (): Number of importance weighted samples

    Returns:
        niwae: tensor: (): Negative IWAE bound
        kl: tensor: (): ELBO KL divergence to prior
        rec: tensor: (): ELBO Reconstruction term
    """
    m, v = self.enc.encode(x)
    batch_size, dim = m.shape
    # Duplicate
    m = ut.duplicate(m, iw)
    v = ut.duplicate(v, iw)
    x = ut.duplicate(x, iw)

    z = ut.sample_gaussian(m, v)
    logits = self.dec.decode(z)

    km = self.km.repeat(batch_size, 1, 1)
    kv = self.kv.repeat(batch_size, 1, 1)
    km = ut.duplicate(km, iw)
    kv = ut.duplicate(kv, iw)
    kl_vec = ut.log_normal(z, m, v) - ut.log_normal_mixture(z, km, kv)
    kl = torch.mean(kl_vec)

    rec_vec = ut.log_bernoulli_with_logits(x, logits)
    rec = -torch.mean(rec_vec)

    # The importance weights must be combined per example before averaging
    # over the batch: log w = log p(x|z) - (log q(z|x) - log p(z)).
    # ut.duplicate stacks iw copies of the batch, hence the (iw, batch) reshape.
    log_w = (rec_vec - kl_vec).reshape(iw, batch_size)
    niwae = -ut.log_mean_exp(log_w, dim=0).mean()
    return niwae, kl, rec
def negative_iwae_bound(self, x, iw):
    """
    Computes the Importance Weighted Autoencoder Bound
    Additionally, we also compute the ELBO KL and reconstruction terms

    Args:
        x: tensor: (batch, dim): Observations
        iw: int: (): Number of importance weighted samples

    Returns:
        niwae: tensor: (): Negative IWAE bound
        kl: tensor: (): ELBO KL divergence to prior
        rec: tensor: (): ELBO Reconstruction term
    """
    ################################################################################
    # TODO: Modify/complete the code here
    # Compute niwae (negative IWAE) with iw importance samples, and the KL
    # and Rec decomposition of the Evidence Lower Bound
    #
    # Outputs should all be scalar
    ################################################################################
    X_dupl = ut.duplicate(x, iw)  # input x is duplicated iw times
    (m, v) = self.enc.encode(X_dupl)  # compute the encoder output
    z = ut.sample_gaussian(m, v)  # sample a point from the multivariate Gaussian
    logits = self.dec.decode(z)  # pass the sampled z through the decoder

    # log probability of the output x given latent z
    ln_P_x_z = ut.log_bernoulli_with_logits(X_dupl, logits)
    # log p(z) under the prior
    ln_P_z = ut.log_normal(z, self.z_prior_m, self.z_prior_v)
    # log q(z|x), conditional probability of the latent given x
    ln_q_z_x = ut.log_normal(z, m, v)

    exponent = ln_P_x_z + ln_P_z - ln_q_z_x
    exponent = exponent.reshape(iw, -1)
    L_m_x = ut.log_mean_exp(exponent, 0)
    niwae = -torch.mean(L_m_x)
    # ELBO decomposition as Monte Carlo means over all samples
    kl = torch.mean(ln_q_z_x - ln_P_z)
    rec = -torch.mean(ln_P_x_z)
    ################################################################################
    # End of code modification
    ################################################################################
    return niwae, kl, rec
def negative_iwae_bound(self, x, iw):
    """
    Computes the Importance Weighted Autoencoder Bound
    Additionally, we also compute the ELBO KL and reconstruction terms

    Args:
        x: tensor: (batch, dim): Observations
        iw: int: (): Number of importance weighted samples

    Returns:
        niwae: tensor: (): Negative IWAE bound
        kl: tensor: (): ELBO KL divergence to prior
        rec: tensor: (): ELBO Reconstruction term
    """
    ################################################################################
    # TODO: Modify/complete the code here
    # Compute niwae (negative IWAE) with iw importance samples, and the KL
    # and Rec decomposition of the Evidence Lower Bound
    #
    # Outputs should all be scalar
    ################################################################################
    # Compute the mixture of Gaussians prior
    prior = ut.gaussian_parameters(self.z_pre, dim=1)

    q_m, q_v = self.enc.encode(x)
    q_m_, q_v_ = ut.duplicate(q_m, rep=iw), ut.duplicate(q_v, rep=iw)
    z_given_x = ut.sample_gaussian(q_m_, q_v_)
    decoded_bernoulli_logits = self.dec.decode(z_given_x)

    # duplicate x to match the iw samples
    x_dup = ut.duplicate(x, rep=iw)
    rec_vec = ut.log_bernoulli_with_logits(x_dup, decoded_bernoulli_logits)

    log_p_theta = ut.log_normal_mixture(z_given_x, prior[0], prior[1])
    log_q_phi = ut.log_normal(z_given_x, q_m_, q_v_)
    kl_vec = log_q_phi - log_p_theta

    log_w = rec_vec - kl_vec
    niwae = -torch.mean(ut.log_mean_exp(log_w.reshape(iw, -1), dim=0))
    # Reduce the ELBO terms to scalars
    kl = torch.mean(kl_vec)
    rec = -torch.mean(rec_vec)
    ################################################################################
    # End of code modification
    ################################################################################
    return niwae, kl, rec
def negative_iwae_bound(self, x, iw):
    """
    Computes the Importance Weighted Autoencoder Bound
    Additionally, we also compute the ELBO KL and reconstruction terms

    Args:
        x: tensor: (batch, dim): Observations
        iw: int: (): Number of importance weighted samples

    Returns:
        niwae: tensor: (): Negative IWAE bound
        kl: tensor: (): ELBO KL divergence to prior
        rec: tensor: (): ELBO Reconstruction term
    """
    ################################################################################
    # TODO: Modify/complete the code here
    # Compute niwae (negative IWAE) with iw importance samples, and the KL
    # and Rec decomposition of the Evidence Lower Bound
    #
    # Outputs should all be scalar
    ################################################################################
    m, v = self.enc.encode(x)  # m, v -> (batch, dim)
    m = ut.duplicate(m, iw)  # (batch, dim) -> (batch*iw, dim)
    v = ut.duplicate(v, iw)  # (batch, dim) -> (batch*iw, dim)
    x = ut.duplicate(x, iw)  # (batch, dim) -> (batch*iw, dim)

    z = ut.sample_gaussian(m, v)  # z -> (batch*iw, dim)
    logits = self.dec.decode(z)

    kl = ut.log_normal(z, m, v) - ut.log_normal(z, self.z_prior_m, self.z_prior_v)
    rec = -ut.log_bernoulli_with_logits(x, logits)
    nelbo = kl + rec

    niwae = -ut.log_mean_exp(-nelbo.reshape(iw, -1), dim=0)
    niwae, kl, rec = niwae.mean(), kl.mean(), rec.mean()
    ################################################################################
    # End of code modification
    ################################################################################
    return niwae, kl, rec
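# As iw grows, the IWAE bound tightens, so niwae should decrease toward
# -log p(x). A usage sketch, assuming a trained `model` exposing the method
# above and a batch `x` of binarized observations (both hypothetical names):

import torch

with torch.no_grad():
    for iw in [1, 10, 100]:
        niwae, kl, rec = model.negative_iwae_bound(x, iw)
        print(f"iw={iw:>4d}  NIWAE={niwae.item():.3f}  "
              f"KL={kl.item():.3f}  Rec={rec.item():.3f}")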
def negative_iwae_bound_for(self, x, x_hat, y, c, iw):
    """
    Computes the Importance Weighted Autoencoder Bound
    Additionally, we also compute the ELBO KL and reconstruction terms

    Args:
        x: tensor: (batch, dim): Observations
        x_hat: tensor: (batch, dim): Observations
        y: tensor: (batch, y_dim): whether observations contain EV
        c: tensor: (batch, c_dim): target mapping specification
        iw: int: (): Number of importance weighted samples

    Returns:
        niwae: tensor: (): Negative IWAE bound
        kl: tensor: (): ELBO KL divergence to prior
        rec: tensor: (): ELBO Reconstruction term
        rec_mse: tensor: (): Reconstruction MSE component
        rec_var: tensor: (): Reconstruction variance component
    """
    # encode
    qm, qv = self.enc.encode(x, y=y)
    # replicate qm, qv: (batch, ...) -> (batch, iw, ...)
    q_shape = list(qm.shape)
    qm = qm.unsqueeze(1).expand(q_shape[0], iw, *q_shape[1:])
    qv = qv.unsqueeze(1).expand(q_shape[0], iw, *q_shape[1:])

    # replicate x, y, c to match
    x_shape = list(x_hat.shape)
    x_hat = x_hat.unsqueeze(1).expand(x_shape[0], iw, *x_shape[1:])
    y_shape = list(y.shape)
    y = y.unsqueeze(1).expand(y_shape[0], iw, *y_shape[1:])
    c_shape = list(c.shape)
    c = c.unsqueeze(1).expand(c_shape[0], iw, *c_shape[1:])

    # sample z(1)...z(iw) for the Monte Carlo estimate of p(x|z(i))
    z = ut.sample_gaussian(qm, qv)
    kl_elem = self.kl_elem(z, qm, qv)

    # decode
    mu, var = self.dec.decode(z, y=y, c=c)
    nll, rec_mse, rec_var = ut.nlog_prob_normal(mu=mu, y=x_hat, var=var,
                                                fixed_var=self.warmup,
                                                var_pen=self.var_pen)
    log_prob, rec_mse, rec_var = -nll, rec_mse.mean(), rec_var.mean()

    niwae = -ut.log_mean_exp(log_prob - kl_elem, dim=1).mean(-1)
    # reduce
    rec = -log_prob.mean(1).mean(-1)
    kl = kl_elem.mean(1).mean(-1)
    return niwae, kl, rec, rec_mse, rec_var
def negative_iwae_bound(self, x, iw):
    """
    Computes the Importance Weighted Autoencoder Bound
    Additionally, we also compute the ELBO KL and reconstruction terms

    Args:
        x: tensor: (batch, dim): Observations
        iw: int: (): Number of importance weighted samples

    Returns:
        niwae: tensor: (): Negative IWAE bound
        kl: tensor: (): ELBO KL divergence to prior
        rec: tensor: (): ELBO Reconstruction term
    """
    ################################################################################
    # TODO: Modify/complete the code here
    # Compute niwae (negative IWAE) with iw importance samples, and the KL
    # and Rec decomposition of the Evidence Lower Bound
    #
    # Outputs should all be scalar
    ################################################################################
    N_batches, dims = x.size()
    x = ut.duplicate(x, iw)

    q_mu, q_var = self.enc.encode(x)
    z_samp = ut.sample_gaussian(q_mu, q_var)
    logits = self.dec.decode(z_samp)

    rec_vec = ut.log_bernoulli_with_logits(x, logits)
    # Analytic KL to the standard normal prior, used inside the importance
    # weights in place of the per-sample log-ratio (commented alternative)
    kl_vec = ut.kl_normal(q_mu, q_var, torch.zeros_like(q_mu), torch.ones_like(q_var))
    # kl_vec = ut.log_normal(z_samp, q_mu, q_var) - ut.log_normal(z_samp, torch.zeros_like(q_mu), torch.ones_like(q_var))

    # ut.duplicate stacks iw copies of the batch along dim 0, so unflatten
    # as (iw, N_batches) and take the log-mean over dim 0
    log_w = (rec_vec - kl_vec).reshape(iw, N_batches)
    niwae = -ut.log_mean_exp(log_w, 0).mean()
    kl = torch.mean(kl_vec)
    rec = -torch.mean(rec_vec)
    ################################################################################
    # End of code modification
    ################################################################################
    return niwae, kl, rec
def negative_iwae_bound(self, x, iw):
    """
    Computes the Importance Weighted Autoencoder Bound
    Additionally, we also compute the ELBO KL and reconstruction terms

    Args:
        x: tensor: (batch, dim): Observations
        iw: int: (): Number of importance weighted samples

    Returns:
        niwae: tensor: (): Negative IWAE bound
        kl: tensor: (): ELBO KL divergence to prior
        rec: tensor: (): ELBO Reconstruction term
        rec_mse: tensor: (): Reconstruction MSE component
        rec_var: tensor: (): Reconstruction variance component
    """
    # encode
    qm, qv = self.enc.encode(x)
    # replicate qm, qv: (batch, ...) -> (batch, iw, ...)
    q_shape = list(qm.shape)
    qm = qm.unsqueeze(1).expand(q_shape[0], iw, *q_shape[1:])
    qv = qv.unsqueeze(1).expand(q_shape[0], iw, *q_shape[1:])
    # replicate x
    x_shape = list(x.shape)
    x = x.unsqueeze(1).expand(x_shape[0], iw, *x_shape[1:])

    # sample z(1)...z(iw) for the Monte Carlo estimate of p(x|z(i))
    z = ut.sample_gaussian(qm, qv)

    # decode
    mu, var = self.dec.decode(z)
    kl_elem = self.kl_elem(z, qm, qv)
    nll, rec_mse, rec_var = ut.nlog_prob_normal(mu=mu, y=x, var=var,
                                                fixed_var=self.warmup,
                                                var_pen=self.var_pen)
    log_prob, rec_mse, rec_var = -nll, rec_mse.mean(), rec_var.mean()

    niwae = -ut.log_mean_exp(log_prob - kl_elem, dim=1).mean(-1)
    rec = -log_prob.mean(1).mean(-1)
    kl = kl_elem.mean(1).mean(-1)
    return niwae, kl, rec, rec_mse, rec_var
def negative_iwae_bound(self, x, iw):
    """
    Computes the Importance Weighted Autoencoder Bound
    Additionally, we also compute the ELBO KL and reconstruction terms

    Args:
        x: tensor: (batch, dim): Observations
        iw: int: (): Number of importance weighted samples

    Returns:
        niwae: tensor: (): Negative IWAE bound
        kl: tensor: (): ELBO KL divergence to prior
        rec: tensor: (): ELBO Reconstruction term
    """
    ################################################################################
    # TODO: Modify/complete the code here
    # Compute niwae (negative IWAE) with iw importance samples, and the KL
    # and Rec decomposition of the Evidence Lower Bound
    #
    # Outputs should all be scalar
    ################################################################################
    # Per-example loop: duplicate each observation iw times and estimate the
    # bound one example at a time
    niwae = 0
    kl = 0
    rec = 0
    for i in range(x.size()[0]):
        x_i = x[i].view(1, x.size()[1])
        x_i = ut.duplicate(x_i, iw)
        m, v = self.enc.encode(x_i)
        z = ut.sample_gaussian(m, v)
        x_hat = self.dec.decode(z)
        log_p_x_z = ut.log_bernoulli_with_logits(x_i, x_hat)
        log_p_z = ut.log_normal(z, self.z_prior_m.expand(m.size()),
                                self.z_prior_v.expand(v.size()))
        log_q_z_x = ut.log_normal(z, m, v)
        exponent = log_p_x_z + log_p_z - log_q_z_x
        niwae += -ut.log_mean_exp(exponent, 0).squeeze()
        # accumulate the ELBO decomposition for this example
        kl += (log_q_z_x - log_p_z).mean()
        rec += -log_p_x_z.mean()
    niwae = niwae / x.size()[0]
    kl = kl / x.size()[0]
    rec = rec / x.size()[0]
    ################################################################################
    # End of code modification
    ################################################################################
    return niwae, kl, rec
def negative_iwae_bound(self, x, iw):
    """
    Computes the Importance Weighted Autoencoder Bound
    Additionally, we also compute the ELBO KL and reconstruction terms

    Args:
        x: tensor: (batch, dim): Observations
        iw: int: (): Number of importance weighted samples

    Returns:
        niwae: tensor: (): Negative IWAE bound
        kl: tensor: (): ELBO KL divergence to prior
        rec: tensor: (): ELBO Reconstruction term
    """
    m, v = self.enc.encode(x)
    # Duplicate
    m = ut.duplicate(m, iw)
    v = ut.duplicate(v, iw)
    x = ut.duplicate(x, iw)

    z = ut.sample_gaussian(m, v)
    logits = self.dec.decode(z)

    # ELBO terms: analytic KL to the standard normal prior and the Bernoulli
    # reconstruction log-likelihood
    pm = torch.zeros_like(m)
    pv = torch.ones_like(v)
    kl_vec = ut.kl_normal(m, v, pm, pv)
    rec_vec = ut.log_bernoulli_with_logits(x, logits)
    kl = kl_vec.mean()
    rec = -rec_vec.mean()

    # IWAE: combine the per-sample log weights before the batch average;
    # ut.duplicate stacks iw copies of the batch, hence the (iw, -1) reshape
    log_w = (rec_vec - kl_vec).reshape(iw, -1)
    niwae = -ut.log_mean_exp(log_w, dim=0).mean()
    return niwae, kl, rec
def negative_iwae_bound(self, x, iw):
    """
    Computes the Importance Weighted Autoencoder Bound
    Additionally, we also compute the ELBO KL and reconstruction terms

    Args:
        x: tensor: (batch, dim): Observations
        iw: int: (): Number of importance weighted samples

    Returns:
        niwae: tensor: (): Negative IWAE bound
        kl: tensor: (): ELBO KL divergence to prior
        rec: tensor: (): ELBO Reconstruction term
    """
    ################################################################################
    # TODO: Modify/complete the code here
    # Compute niwae (negative IWAE) with iw importance samples, and the KL
    # and Rec decomposition of the Evidence Lower Bound
    #
    # Outputs should all be scalar
    ################################################################################
    ## E_{z(1),...,z(iw)} log mean_i exp{log p(x|z_i) + log p(z_i) - log q(z_i|x)}
    m, v = self.enc.encode(x)
    batch_m = m.unsqueeze(1).repeat(1, iw, 1)  # (batch, iw, z_dim)
    batch_v = v.unsqueeze(1).repeat(1, iw, 1)
    batch_x = x.unsqueeze(1).repeat(1, iw, 1)

    # log p(x|z)
    zs = ut.sample_gaussian(batch_m, batch_v)
    logits = self.dec.decode(zs)
    log_p_x_z = ut.log_bernoulli_with_logits(batch_x, logits)  # (batch, iw)

    # log p(z)
    batch_size = batch_m.shape[0]
    batch_z_prior_m = self.z_prior_m.view(1, 1, -1).repeat(batch_size, iw, 1)
    batch_z_prior_v = self.z_prior_v.view(1, 1, -1).repeat(batch_size, iw, 1)
    log_p_z = ut.log_normal(zs, batch_z_prior_m, batch_z_prior_v)  # (batch, iw)

    # log q(z|x)
    log_q_z_x = ut.log_normal(zs, batch_m, batch_v)  # (batch, iw)

    # The log-mean-exp must be taken over the combined per-sample log weights,
    # not over each term separately (the log of a mean does not distribute
    # over the sum of terms)
    log_w = log_p_x_z + log_p_z - log_q_z_x
    niwae = -torch.mean(ut.log_mean_exp(log_w, dim=-1))
    rec = -torch.mean(log_p_x_z)
    kl = torch.mean(log_q_z_x - log_p_z)
    ################################################################################
    # End of code modification
    ################################################################################
    return niwae, kl, rec
def negative_iwae_bound(self, x, iw):
    """
    Computes the Importance Weighted Autoencoder Bound
    Additionally, we also compute the ELBO KL and reconstruction terms

    Args:
        x: tensor: (batch, dim): Observations
        iw: int: (): Number of importance weighted samples

    Returns:
        niwae: tensor: (): Negative IWAE bound
        kl: tensor: (): ELBO KL divergence to prior
        rec: tensor: (): ELBO Reconstruction term
    """
    ################################################################################
    # TODO: Modify/complete the code here
    # Compute niwae (negative IWAE) with iw importance samples, and the KL
    # and Rec decomposition of the Evidence Lower Bound
    #
    # Outputs should all be scalar
    ################################################################################
    m, v = self.enc.encode(x)
    dist = Normal(loc=m, scale=torch.sqrt(v))
    z_iw = dist.rsample(sample_shape=torch.Size([iw]))  # (iw, batch, z_dim)

    log_z_batch, kl_z_batch = [], []
    # for each z sample
    for i in range(iw):
        recon_logits = self.dec.decode(z_iw[i])
        log_z_batch.append(ut.log_bernoulli_with_logits(x, recon_logits))  # (batch,)
        # analytic KL to the prior; constant across samples
        kl_z_batch.append(ut.kl_normal(m, v, torch.zeros_like(m), torch.ones_like(v)))

    # aggregate results together
    log_z = torch.stack(log_z_batch, dim=1)  # (batch, iw)
    kl_z = torch.stack(kl_z_batch, dim=1)

    niwae = -ut.log_mean_exp(log_z - kl_z, dim=1).mean(dim=0)
    # reduce to scalars over both the batch and sample dimensions
    rec_loss = -torch.mean(log_z)
    kl = torch.mean(kl_z)
    ################################################################################
    # End of code modification
    ################################################################################
    return niwae, kl, rec_loss
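# With iw = 1, log_mean_exp over a single weight is the weight itself, so the
# negative IWAE bound reduces to the NELBO. A sanity-check sketch, assuming a
# `model` that also exposes a negative_elbo_bound method and that both bounds
# draw their single z sample identically under the same seed (hypothetical
# names and assumptions):

import torch

torch.manual_seed(0)
niwae, kl_iw, rec_iw = model.negative_iwae_bound(x, iw=1)
torch.manual_seed(0)
nelbo, kl, rec = model.negative_elbo_bound(x)
assert torch.isclose(niwae, nelbo, atol=1e-4)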
def negative_iwae_bound(self, x, iw):
    """
    Computes the Importance Weighted Autoencoder Bound
    Additionally, we also compute the ELBO KL and reconstruction terms

    Args:
        x: tensor: (batch, dim): Observations
        iw: int: (): Number of importance weighted samples

    Returns:
        niwae: tensor: (): Negative IWAE bound
        kl: tensor: (): ELBO KL divergence to prior
        rec: tensor: (): ELBO Reconstruction term
    """
    ################################################################################
    # TODO: Modify/complete the code here
    # Compute niwae (negative IWAE) with iw importance samples, and the KL
    # and Rec decomposition of the Evidence Lower Bound
    #
    # Outputs should all be scalar
    ################################################################################
    # Compute the mixture of Gaussians prior
    prior = ut.gaussian_parameters(self.z_pre, dim=1)
    prior_m, prior_v = prior
    batch = x.shape[0]

    multi_x = ut.duplicate(x, iw)
    qm, qv = self.enc.encode(x)
    multi_qm = ut.duplicate(qm, iw)
    multi_qv = ut.duplicate(qv, iw)

    # z is (batch*iw, z_dim); the iw samples for a given x are non-contiguous
    z = ut.sample_gaussian(multi_qm, multi_qv)
    probs = self.dec.decode(z)
    recs = ut.log_bernoulli_with_logits(multi_x, probs)
    rec = -1.0 * torch.mean(recs)

    multi_m = prior_m.expand(batch * iw, *prior_m.shape[1:])
    multi_v = prior_v.expand(batch * iw, *prior_v.shape[1:])
    z_priors = ut.log_normal_mixture(z, multi_m, multi_v)
    x_posteriors = recs
    z_posteriors = ut.log_normal(z, multi_qm, multi_qv)
    kls = z_posteriors - z_priors
    kl = torch.mean(kls)

    log_ratios = z_priors + x_posteriors - z_posteriors
    # log_ratios is (batch*iw,); per-example ratios are non-contiguous, so
    # unflatten as (iw, batch) and take the log-mean over dim 0
    unflat_log_ratios = log_ratios.reshape(iw, batch)
    niwaes = ut.log_mean_exp(unflat_log_ratios, 0)
    niwae = -1.0 * torch.mean(niwaes)
    ################################################################################
    # End of code modification
    ################################################################################
    return niwae, kl, rec