def negative_iwae_bound(self, x, iw):
    """
    Computes the Importance Weighted Autoencoder Bound
    Additionally, we also compute the ELBO KL and reconstruction terms

    Args:
        x: tensor: (batch, dim): Observations
        iw: int: (): Number of importance weighted samples

    Returns:
        niwae: tensor: (): Negative IWAE bound
        kl: tensor: (): ELBO KL divergence to prior
        rec: tensor: (): ELBO Reconstruction term
    """
    ################################################################################
    # TODO: Modify/complete the code here
    # Compute niwae (negative IWAE) with iw importance samples, and the KL
    # and Rec decomposition of the Evidence Lower Bound
    #
    # Outputs should all be scalar
    ################################################################################
    # Compute the mixture of Gaussian prior
    prior = ut.gaussian_parameters(self.z_pre, dim=1)
    ################################################################################
    # End of code modification
    ################################################################################
    return niwae, kl, rec
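# A minimal sketch of one way the TODO above could be completed, mirroring the
# duplicated-sample IWAE variants later in this file. The `ut` helpers
# (ut.duplicate, ut.sample_gaussian, ut.log_bernoulli_with_logits, ut.log_normal,
# ut.log_normal_mixture, ut.log_mean_exp) are assumed to behave as they do
# elsewhere here; this is an illustrative sketch, not the only valid solution.
def negative_iwae_bound_sketch(self, x, iw):
    pm, pv = ut.gaussian_parameters(self.z_pre, dim=1)
    qm, qv = self.enc.encode(x)
    # Duplicate the batch iw times so all importance samples are handled at once.
    x_dup = ut.duplicate(x, iw)
    qm_dup, qv_dup = ut.duplicate(qm, iw), ut.duplicate(qv, iw)
    z = ut.sample_gaussian(qm_dup, qv_dup)
    logits = self.dec.decode(z)
    rec = -ut.log_bernoulli_with_logits(x_dup, logits)
    kl = ut.log_normal(z, qm_dup, qv_dup) - ut.log_normal_mixture(z, pm, pv)
    # IWAE bound: log-mean-exp over the iw importance weights, then batch mean.
    niwae = -ut.log_mean_exp(-(kl + rec).reshape(iw, -1), dim=0).mean()
    return niwae, kl.mean(), rec.mean()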
def encode(self, x, y=None, c=None):
    xy = x if y is None else torch.cat((x, y), dim=-1)
    xyc = xy if c is None else torch.cat((xy, c), dim=-1)
    states = self.elu(self.initializer(xyc)).view(
        x.size()[0], self.NUM_LAYERS * (1 + self.BIDIRECTIONAL),
        2 * self.HIDDEN_DIM)
    (h_0, c_0) = states.split(self.HIDDEN_DIM, dim=-1)
    x = x.unsqueeze(2)
    position = torch.arange(self.x_dim, dtype=torch.float) / (0.5 * (self.x_dim - 1)) - 1.0
    position = position.view(1, self.x_dim).expand(
        x.size()[0], self.x_dim).unsqueeze(2)
    x = torch.cat((x, position), dim=2)
    if y is not None:
        y = y.unsqueeze(1).expand(x.size()[0], self.x_dim, self.y_dim)
        x = torch.cat((x, y), dim=2)
    if c is not None:
        c = c.unsqueeze(1).expand(x.size()[0], self.x_dim, self.c_dim)
        x = torch.cat((x, c), dim=2)
    input = self.elu(self.embedder(x))
    # print(x.size())
    # print(input.size())
    _, (h_n, _) = self.rnn(input.transpose(1, 0),
                           (h_0.transpose(1, 0), c_0.transpose(1, 0)))
    # h_n of shape (num_layers * num_directions, batch, hidden_size)
    # reshape to (batch, last_hidden_outputs)
    h = h_n.transpose(1, 0).reshape(
        -1, self.NUM_LAYERS * (1 + self.BIDIRECTIONAL) * self.HIDDEN_DIM)
    out = self.regressor(h)
    m, v = ut.gaussian_parameters(out, dim=-1)
    return m, v
def encode(self, x, y=None):
    xy = x if y is None else torch.cat((x, y), dim=1)
    xy = xy.view(-1, self.channel * 96 * 96)
    h = self.net(xy)
    m, v = ut.gaussian_parameters(h, dim=1)
    # print(self.z_dim, m.size(), v.size())
    return m, v
def decode(self, z, y=None, c=None):
    # Note: not designed for IW!!
    assert len(z.size()) <= 2
    zy = z if y is None else torch.cat((z, y), dim=-1)
    zyc = zy if c is None else torch.cat((zy, c), dim=-1)
    states = self.elu(self.initializer(zyc)).view(
        z.size()[0], self.NUM_LAYERS * (1 + self.BIDIRECTIONAL),
        2 * self.HIDDEN_DIM)
    (h_0, c_0) = states.split(self.HIDDEN_DIM, dim=-1)
    input = zyc.unsqueeze(1).expand(
        *zyc.size()[0:-1], self.x_dim, zyc.size()[-1])
    position = torch.arange(
        self.x_dim, dtype=torch.float) / (0.5 * (self.x_dim - 1)) - 1.0
    position = position.view(1, self.x_dim).expand(
        z.size()[0], self.x_dim).unsqueeze(2)
    input = torch.cat((input, position), dim=2)
    output, (_, _) = self.rnn(input.transpose(1, 0),
                              (h_0.transpose(1, 0), c_0.transpose(1, 0)))
    out = self.regressor(output.transpose(1, 0)).transpose(-2, -1).reshape(
        z.size()[0], -1)
    m, v = ut.gaussian_parameters(out, dim=-1)
    return m, v
def mse_rnn(model, inputs, targets, n_samples):
    pred = model.forward(inputs).detach()
    if not model.constant_var:
        mean, var = ut.gaussian_parameters(pred, dim=-1)
    else:
        mean = pred
        var = model.pred_var
    return ((targets - mean) ** 2).sum(-1).sum(0)
def wse_rnn(model, inputs, targets):
    pred = model.forward(inputs).detach()
    if not model.constant_var:
        mean, var = ut.gaussian_parameters(pred, dim=-1)
    else:
        mean = pred
        var = model.pred_var
    sample_trajs = ut.sample_gaussian(mean, var)
    return ((targets - sample_trajs) ** 2).sum(-1).sum(0)
def conditional_encode(self, x, l):
    x = x.view(-1, self.channel * 96 * 96)
    x = F.elu(self.fc1(x))
    l = l.view(-1, 4)
    x = F.elu(self.fc2(torch.cat([x, l], dim=1)))
    x = F.elu(self.fc3(x))
    x = self.fc4(x)
    m, v = ut.gaussian_parameters(x, dim=1)
    return m, v
def get_mse(model, full_true_trajs, n_samples=100):
    """
    Mean squared error (MSE) of the mean predicted trajectory
    against the ground-truth future trajectory.
    """
    n_seqs = full_true_trajs.shape[1]
    inputs = full_true_trajs[:model.n_input_steps, :, :].detach()
    targets = full_true_trajs[model.n_input_steps:, :, :2].detach()

    if model.BBB:
        for i in range(n_samples):
            # not using sharpening
            pred = model.forward(inputs)  # one output sample
            pred = pred.detach()
            if i == 0:
                pred_list = pred.unsqueeze(-1)
            else:
                pred = pred.unsqueeze(-1)
                pred_list = torch.cat((pred_list, pred), dim=-1)
        if model.constant_var:
            mean_pred = pred_list.mean(dim=-1)
            std_pred = pred_list.std(dim=-1)
        else:
            mean_pred = pred_list[:, :, :-1, :].mean(dim=-1)
            std_pred = pred_list[:, :, :-1, :].std(dim=-1)
    else:
        pred = model.forward(inputs)
        pred = pred.detach()
        if not model.constant_var:
            mean, var = ut.gaussian_parameters(pred, dim=-1)
        else:
            mean = pred
            var = model.pred_var
        for i in range(n_samples):
            sample_trajs = ut.sample_gaussian(mean, var)
            # The sampled trajectories are replaced by the mean, so pred_list
            # holds identical copies of the mean prediction.
            sample_trajs = mean
            if i == 0:
                pred_list = sample_trajs.unsqueeze(-1)
            else:
                sample_trajs = sample_trajs.unsqueeze(-1)
                pred_list = torch.cat((pred_list, sample_trajs), dim=-1)
        if model.constant_var:
            mean_pred = pred_list.mean(dim=-1)
            std_pred = pred_list.std(dim=-1)
        else:
            mean_pred = pred_list[:, :, :-1, :].mean(dim=-1)
            std_pred = pred_list[:, :, :-1, :].std(dim=-1)

    mse = ((mean_pred - targets) ** 2).sum() / n_seqs
    return mse
def sample_z(self, batch):
    m, v = ut.gaussian_parameters(self.z_pre.squeeze(0), dim=0)
    # Sample `batch` latents from the mixture-of-Gaussians prior: for each
    # sample, the mixture component is first drawn from a categorical
    # distribution over the mixing weights self.pi.
    idx = torch.distributions.categorical.Categorical(self.pi).sample((batch, ))
    m, v = m[idx], v[idx]
    return ut.sample_gaussian(m, v)
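# Hypothetical usage sketch for sample_z (the names `model`, `z_dim`, and the
# decoder interface are illustrative assumptions): draw latents from the
# mixture prior and decode them into Bernoulli means for visualization.
z = model.sample_z(batch=16)     # (16, z_dim) latents from the mixture prior
logits = model.dec.decode(z)     # Bernoulli logits for each pixel
x_mean = torch.sigmoid(logits)   # expected images under the decoder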
def kl_elem(self, z, qm, qv):
    # Compute the mixture of Gaussian prior
    prior_m, prior_v = ut.gaussian_parameters(self.z_pre, dim=1)
    log_prob_net = ut.log_normal(z, qm, qv)
    log_prob_prior = ut.log_normal_mixture(z, prior_m, prior_v)
    # print("log_prob_net:", log_prob_net.mean(), "log_prob_prior:", log_prob_prior.mean())
    kl_elem = log_prob_net - log_prob_prior
    return kl_elem
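# kl_elem returns a single-sample Monte Carlo estimate of
# KL(q(z|x) || p(z)) = E_q[log q(z|x) - log p(z)] under the mixture prior.
# A minimal sketch of how it could back an ELBO, assuming the same
# encoder/decoder interface and `ut` helpers as the surrounding methods:
def negative_elbo_from_kl_elem(self, x):
    qm, qv = self.enc.encode(x)
    z = ut.sample_gaussian(qm, qv)
    kl = self.kl_elem(z, qm, qv).mean()
    rec = -ut.log_bernoulli_with_logits(x, self.dec.decode(z)).mean()
    return kl + rec, kl, rec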
def negative_iwae_bound(self, x, iw):
    """
    Computes the Importance Weighted Autoencoder Bound
    Additionally, we also compute the ELBO KL and reconstruction terms

    Args:
        x: tensor: (batch, dim): Observations
        iw: int: (): Number of importance weighted samples

    Returns:
        niwae: tensor: (): Negative IWAE bound
        kl: tensor: (): ELBO KL divergence to prior
        rec: tensor: (): ELBO Reconstruction term
    """
    ################################################################################
    # TODO: Modify/complete the code here
    # Compute niwae (negative IWAE) with iw importance samples, and the KL
    # and Rec decomposition of the Evidence Lower Bound
    #
    # Outputs should all be scalar
    ################################################################################
    # Compute the mixture of Gaussian prior
    prior = ut.gaussian_parameters(self.z_pre, dim=1)

    m, v = self.enc.encode(x)
    dist = Normal(loc=m, scale=torch.sqrt(v))
    z_sample = dist.rsample(sample_shape=torch.Size([iw]))
    log_batch_z_sample_prob = []
    kl_batch_z_sample = []
    for i in range(iw):
        recon_logits = self.dec.decode(z_sample[i])
        log_batch_z_sample_prob.append(
            ut.log_bernoulli_with_logits(x, recon_logits))  # [batch, z_sample]
        kl_batch_z_sample.append(
            ut.log_normal(z_sample[i], m, v)
            - ut.log_normal_mixture(z_sample[i], prior[0], prior[1]))
    log_batch_z_sample_prob = torch.stack(log_batch_z_sample_prob, dim=1)  # (batch, iw)
    kl_batch_z_sample = torch.stack(kl_batch_z_sample, dim=1)  # (batch, iw)
    niwae = -ut.log_mean_exp(log_batch_z_sample_prob - kl_batch_z_sample,
                             dim=1).mean(dim=0)
    # Average over both the batch and the importance samples so kl and rec
    # are scalars, as the docstring specifies.
    rec = -torch.mean(log_batch_z_sample_prob)
    kl = torch.mean(kl_batch_z_sample)
    ################################################################################
    # End of code modification
    ################################################################################
    return niwae, kl, rec
def negative_iwae_bound(self, x, iw):
    """
    Computes the Importance Weighted Autoencoder Bound
    Additionally, we also compute the ELBO KL and reconstruction terms

    Args:
        x: tensor: (batch, dim): Observations
        iw: int: (): Number of importance weighted samples

    Returns:
        niwae: tensor: (): Negative IWAE bound
        kl: tensor: (): ELBO KL divergence to prior
        rec: tensor: (): ELBO Reconstruction term
    """
    ################################################################################
    # TODO: Modify/complete the code here
    # Compute niwae (negative IWAE) with iw importance samples, and the KL
    # and Rec decomposition of the Evidence Lower Bound
    #
    # Outputs should all be scalar
    ################################################################################
    # Compute the mixture of Gaussian prior
    pm, pv = ut.gaussian_parameters(self.z_pre, dim=1)
    #
    # Generate samples.
    qm, qv = self.enc.encode(x)
    niwaes = []
    recs = []
    kls = []
    for i in range(iw):
        z_sample = ut.sample_gaussian(qm, qv).view(-1, qm.shape[1])
        rec = self.dec.decode(z_sample)
        logptheta_x_g_z = ut.log_bernoulli_with_logits(x, rec)
        logptheta_z = ut.log_normal_mixture(z_sample, pm, pv)
        logqphi_z_g_x = ut.log_normal(z_sample, qm, qv)
        niwae = logptheta_x_g_z + logptheta_z - logqphi_z_g_x
        #
        # Normal variables.
        rec = -ut.log_bernoulli_with_logits(x, rec)
        kl = ut.log_normal(z_sample, qm, qv) - ut.log_normal_mixture(
            z_sample, pm, pv)
        niwaes.append(niwae)
        recs.append(rec)
        kls.append(kl)
    niwaes = torch.stack(niwaes, -1)
    niwae = ut.log_mean_exp(niwaes, -1)
    kl = torch.stack(kls, -1)
    rec = torch.stack(recs, -1)
    ################################################################################
    # End of code modification
    ################################################################################
    return -niwae.mean(), kl.mean(), rec.mean()
def negative_elbo_bound(self, x):
    """
    Computes the Evidence Lower Bound, KL, and Reconstruction costs

    Args:
        x: tensor: (batch, dim): Observations

    Returns:
        nelbo: tensor: (): Negative evidence lower bound
        kl: tensor: (): ELBO KL divergence to prior
        rec: tensor: (): ELBO Reconstruction term
    """
    ################################################################################
    # TODO: Modify/complete the code here
    # Compute negative Evidence Lower Bound and its KL and Rec decomposition
    #
    # To help you start, we have computed the mixture of Gaussians prior
    # prior = (m_mixture, v_mixture) for you, where
    # m_mixture and v_mixture each have shape (1, self.k, self.z_dim)
    #
    # Note that nelbo = kl + rec
    #
    # Outputs should all be scalar
    ################################################################################
    # Compute the mixture of Gaussian prior
    (m, v) = self.enc.encode(x)  # compute the encoder output
    # print(" ***** \n")
    # print("x shape ", x.shape)
    # print("m and v shapes = ", m.shape, v.shape)
    prior = ut.gaussian_parameters(self.z_pre, dim=1)
    # print("prior shapes = ", prior[0].shape, prior[1].shape)
    z = ut.sample_gaussian(m, v)  # sample a point from the multivariate Gaussian
    # print("shape of z = ", z.shape)
    logits = self.dec.decode(z)  # pass the sampled "Z" through the decoder
    # print("logits shape = ", logits.shape)
    rec = -torch.mean(ut.log_bernoulli_with_logits(x, logits), -1)  # calculate log prob of the output

    log_prob = ut.log_normal(z, m, v)
    log_prob -= ut.log_normal_mixture(z, prior[0], prior[1])
    kl = torch.mean(log_prob)
    rec = torch.mean(rec)
    nelbo = kl + rec
    ################################################################################
    # End of code modification
    ################################################################################
    return nelbo, kl, rec
def negative_elbo_bound(self, x):
    """
    Computes the Evidence Lower Bound, KL, and Reconstruction costs

    Args:
        x: tensor: (batch, dim): Observations

    Returns:
        nelbo: tensor: (): Negative evidence lower bound
        kl: tensor: (): ELBO KL divergence to prior
        rec: tensor: (): ELBO Reconstruction term
    """
    ################################################################################
    # TODO: Modify/complete the code here
    # Compute negative Evidence Lower Bound and its KL and Rec decomposition
    #
    # To help you start, we have computed the mixture of Gaussians prior
    # prior = (m_mixture, v_mixture) for you, where
    # m_mixture and v_mixture each have shape (1, self.k, self.z_dim)
    #
    # Note that nelbo = kl + rec
    #
    # Outputs should all be scalar
    ################################################################################
    # Compute the mixture of Gaussian prior
    prior = ut.gaussian_parameters(self.z_pre, dim=1)
    q_m, q_v = self.enc.encode(x)
    # print("q_m", q_m.size())
    z_given_x = ut.sample_gaussian(q_m, q_v)
    decoded_bernoulli_logits = self.dec.decode(z_given_x)
    rec = -ut.log_bernoulli_with_logits(x, decoded_bernoulli_logits)
    # rec = -torch.mean(rec)

    # terms for KL divergence
    log_q_phi = ut.log_normal(z_given_x, q_m, q_v)
    # print("log_q_phi", log_q_phi.size())
    log_p_theta = ut.log_normal_mixture(z_given_x, prior[0], prior[1])
    # print("log_p_theta", log_p_theta.size())
    kl = log_q_phi - log_p_theta
    # print("kl", kl.size())

    nelbo = torch.mean(kl + rec)
    rec = torch.mean(rec)
    kl = torch.mean(kl)
    ################################################################################
    # End of code modification
    ################################################################################
    return nelbo, kl, rec
def get_nll(self, outputs, targets):
    """
    :return: negative log-likelihood of a minibatch
    """
    if self.likelihood_cost_form == 'mse':
        # This method is not validated
        return self.mse_fn(outputs, targets)
    elif self.likelihood_cost_form == 'gaussian':
        if not self.constant_var:
            mean, var = ut.gaussian_parameters(outputs, dim=-1)
            return -torch.mean(ut.log_normal(targets, mean, var))
        else:
            var = self.pred_var * torch.ones_like(outputs)
            return -torch.mean(ut.log_normal(targets, outputs, var))
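# For reference, the 'gaussian' branch above is a diagonal-Gaussian negative
# log-likelihood. A minimal numerical sanity check against torch.distributions,
# assuming ut.log_normal is the standard diagonal Gaussian log-density summed
# over the last dimension (an assumption, not verified here):
import torch
from torch.distributions import Normal

targets = torch.randn(4, 2)
outputs = torch.randn(4, 2)
var = 0.1 * torch.ones_like(outputs)
nll_ref = -Normal(outputs, var.sqrt()).log_prob(targets).sum(-1).mean()
# nll_ref should agree with -torch.mean(ut.log_normal(targets, outputs, var))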
def negative_iwae_bound(self, x, iw):
    """
    Computes the Importance Weighted Autoencoder Bound
    Additionally, we also compute the ELBO KL and reconstruction terms

    Args:
        x: tensor: (batch, dim): Observations
        iw: int: (): Number of importance weighted samples

    Returns:
        niwae: tensor: (): Negative IWAE bound
        kl: tensor: (): ELBO KL divergence to prior
        rec: tensor: (): ELBO Reconstruction term
    """
    ################################################################################
    # TODO: Modify/complete the code here
    # Compute niwae (negative IWAE) with iw importance samples, and the KL
    # and Rec decomposition of the Evidence Lower Bound
    #
    # Outputs should all be scalar
    ################################################################################
    # Compute the mixture of Gaussian prior
    prior = ut.gaussian_parameters(self.z_pre, dim=1)
    q_m, q_v = self.enc.encode(x)
    q_m_, q_v_ = ut.duplicate(q_m, rep=iw), ut.duplicate(q_v, rep=iw)
    z_given_x = ut.sample_gaussian(q_m_, q_v_)
    decoded_bernoulli_logits = self.dec.decode(z_given_x)
    # duplicate x
    x = ut.duplicate(x, rep=iw)
    rec = ut.log_bernoulli_with_logits(x, decoded_bernoulli_logits)
    log_p_theta = ut.log_normal_mixture(z_given_x, prior[0], prior[1])
    log_q_phi = ut.log_normal(z_given_x, q_m_, q_v_)
    kl = log_q_phi - log_p_theta
    niwae = rec - kl
    niwae = ut.log_mean_exp(niwae.reshape(iw, -1), dim=0)
    niwae = -torch.mean(niwae)  # yay!
    # Reduce kl and rec to scalars (negating rec so it is the ELBO
    # reconstruction cost), as the docstring specifies.
    kl = torch.mean(kl)
    rec = -torch.mean(rec)
    ################################################################################
    # End of code modification
    ################################################################################
    return niwae, kl, rec
def negative_elbo_bound(self, x):
    """
    Computes the Evidence Lower Bound, KL, and Reconstruction costs

    Args:
        x: tensor: (batch, dim): Observations

    Returns:
        nelbo: tensor: (): Negative evidence lower bound
        kl: tensor: (): ELBO KL divergence to prior
        rec: tensor: (): ELBO Reconstruction term
    """
    ################################################################################
    # TODO: Modify/complete the code here
    # Compute negative Evidence Lower Bound and its KL and Rec decomposition
    #
    # To help you start, we have computed the mixture of Gaussians prior
    # prior = (m_mixture, v_mixture) for you, where
    # m_mixture and v_mixture each have shape (1, self.k, self.z_dim)
    #
    # Note that nelbo = kl + rec
    #
    # Outputs should all be scalar
    ################################################################################
    # Compute the mixture of Gaussian prior
    prior = ut.gaussian_parameters(self.z_pre, dim=1)
    prior_m, prior_v = prior
    batch = x.shape[0]
    qm, qv = self.enc.encode(x)
    # Now draw Zs from the posterior qm/qv
    z = ut.sample_gaussian(qm, qv)
    l_posterior = ut.log_normal(z, qm, qv)
    multi_m = prior_m.expand(batch, *prior_m.shape[1:])
    multi_v = prior_v.expand(batch, *prior_v.shape[1:])
    l_prior = ut.log_normal_mixture(z, multi_m, multi_v)
    kls = l_posterior - l_prior
    kl = torch.mean(kls)
    probs = self.dec.decode(z)
    recs = ut.log_bernoulli_with_logits(x, probs)
    rec = -1.0 * torch.mean(recs)
    nelbo = kl + rec
    ################################################################################
    # End of code modification
    ################################################################################
    return nelbo, kl, rec
def negative_elbo_bound(self, x):
    """
    Computes the Evidence Lower Bound, KL, and Reconstruction costs

    Args:
        x: tensor: (batch, dim): Observations

    Returns:
        nelbo: tensor: (): Negative evidence lower bound
        kl: tensor: (): ELBO KL divergence to prior
        rec: tensor: (): ELBO Reconstruction term
    """
    ################################################################################
    # TODO: Modify/complete the code here
    # Compute negative Evidence Lower Bound and its KL and Rec decomposition
    #
    # To help you start, we have computed the mixture of Gaussians prior
    # prior = (m_mixture, v_mixture) for you, where
    # m_mixture and v_mixture each have shape (1, self.k, self.z_dim)
    #
    # Note that nelbo = kl + rec
    #
    # Outputs should all be scalar
    ################################################################################
    #
    # Compute the mixture of Gaussian prior
    pm, pv = ut.gaussian_parameters(self.z_pre, dim=1)
    #
    # Generate samples.
    qm, qv = self.enc.encode(x)
    z_sample = ut.sample_gaussian(qm, qv)
    rec = self.dec.decode(z_sample)
    #
    # Compute loss.
    # KL divergence between the latent distribution and the prior.
    rec = -ut.log_bernoulli_with_logits(x, rec)
    # kl = ut.kl_normal(qm, qv, pm, pv)
    kl = ut.log_normal(z_sample, qm, qv) - ut.log_normal_mixture(
        z_sample, pm, pv)
    #
    # The likelihood of reproducing the sample image given the parameters.
    # Would need to take the average of this otherwise.
    nelbo = (kl + rec).mean()
    # NELBO: 89.24684143066406. KL: 10.346451759338379. Rec: 78.90038299560547
    ################################################################################
    # End of code modification
    ################################################################################
    return nelbo, kl.mean(), rec.mean()
def negative_iwae_bound(self, x, iw):
    """
    Computes the Importance Weighted Autoencoder Bound
    Additionally, we also compute the ELBO KL and reconstruction terms

    Args:
        x: tensor: (batch, dim): Observations
        iw: int: (): Number of importance weighted samples

    Returns:
        niwae: tensor: (): Negative IWAE bound
        kl: tensor: (): ELBO KL divergence to prior
        rec: tensor: (): ELBO Reconstruction term
    """
    ################################################################################
    # TODO: Modify/complete the code here
    # Compute niwae (negative IWAE) with iw importance samples, and the KL
    # and Rec decomposition of the Evidence Lower Bound
    #
    # Outputs should all be scalar
    ################################################################################
    # Compute the mixture of Gaussian prior
    prior = ut.gaussian_parameters(self.z_pre, dim=1)
    N_batches, dims = x.size()
    x = ut.duplicate(x, iw)
    q_mu, q_var = self.enc.encode(x)
    z_samp = ut.sample_gaussian(q_mu, q_var)
    logits = self.dec.decode(z_samp)
    probs = ut.log_bernoulli_with_logits(x, logits)
    kl_vals = (-ut.log_normal(z_samp, q_mu, q_var)
               + ut.log_normal_mixture(z_samp, *prior))
    probs = probs + kl_vals
    # ut.duplicate stacks the iw copies of the batch along dim 0, so reshape to
    # (iw, N_batches) and take log-mean-exp over dim 0 to average the
    # importance weights per observation, as in the other IWAE implementations
    # in this file.
    niwae = torch.mean(-ut.log_mean_exp(probs.reshape(iw, N_batches), 0))
    # The KL and reconstruction terms are not tracked separately here.
    kl = torch.tensor(0)
    rec = torch.tensor(0)
    ################################################################################
    # End of code modification
    ################################################################################
    return niwae, kl, rec
def get_rwse(model, full_true_trajs, n_samples=100):
    """
    Root-weighted square error (RWSE) captures the deviation of a model's
    probability mass from real-world trajectories.
    """
    n_seqs = full_true_trajs.shape[1]
    inputs = full_true_trajs[:model.n_input_steps, :, :].detach()
    targets = full_true_trajs[model.n_input_steps:, :, :2].detach()

    if model.BBB:
        for i in range(n_samples):
            # not using sharpening
            pred = model.forward(inputs)
            pred = pred.detach()
            if not model.constant_var:
                pred = pred[:, :, :-1]
            mean_sq_err = ((targets - pred) ** 2).sum() / n_seqs
            if i == 0:
                mean_sq_err_list = mean_sq_err.unsqueeze(-1)
            else:
                mean_sq_err = mean_sq_err.unsqueeze(-1)
                mean_sq_err_list = torch.cat((mean_sq_err_list, mean_sq_err), dim=-1)
    else:
        pred = model.forward(inputs)
        pred = pred.detach()
        if not model.constant_var:
            mean, var = ut.gaussian_parameters(pred, dim=-1)
        else:
            mean = pred
            var = model.pred_var
        for i in range(n_samples):
            sample_trajs = ut.sample_gaussian(mean, var)
            mean_sq_err = ((targets - sample_trajs) ** 2).sum() / n_seqs
            if i == 0:
                mean_sq_err_list = mean_sq_err.unsqueeze(-1)
            else:
                mean_sq_err = mean_sq_err.unsqueeze(-1)
                mean_sq_err_list = torch.cat((mean_sq_err_list, mean_sq_err), dim=-1)

    mean_rwse = mean_sq_err_list.mean().sqrt()
    return mean_rwse
def negative_elbo_bound(self, x):
    """
    Computes the Evidence Lower Bound, KL, and Reconstruction costs

    Args:
        x: tensor: (batch, dim): Observations

    Returns:
        nelbo: tensor: (): Negative evidence lower bound
        kl: tensor: (): ELBO KL divergence to prior
        rec: tensor: (): ELBO Reconstruction term
    """
    ################################################################################
    # TODO: Modify/complete the code here
    # Compute negative Evidence Lower Bound and its KL and Rec decomposition
    #
    # To help you start, we have computed the mixture of Gaussians prior
    # prior = (m_mixture, v_mixture) for you, where
    # m_mixture and v_mixture each have shape (1, self.k, self.z_dim)
    #
    # Note that nelbo = kl + rec
    #
    # Outputs should all be scalar
    ################################################################################
    # Compute the mixture of Gaussian prior
    prior = ut.gaussian_parameters(self.z_pre, dim=1)
    N_samp, dim = x.size()
    q_mu, q_var = self.enc.encode(x)
    z_samp = ut.sample_gaussian(q_mu, q_var)
    logits = self.dec.decode(z_samp)
    rec = -torch.mean(ut.log_bernoulli_with_logits(x, logits))
    kl = torch.mean(ut.log_normal(z_samp, q_mu, q_var)
                    - ut.log_normal_mixture(z_samp, *prior))
    nelbo = kl + rec
    ################################################################################
    # End of code modification
    ################################################################################
    return nelbo, kl, rec
def negative_iwae_bound(self, x, iw):
    """
    Computes the Importance Weighted Autoencoder Bound
    Additionally, we also compute the ELBO KL and reconstruction terms

    Args:
        x: tensor: (batch, dim): Observations
        iw: int: (): Number of importance weighted samples

    Returns:
        niwae: tensor: (): Negative IWAE bound
        kl: tensor: (): ELBO KL divergence to prior
        rec: tensor: (): ELBO Reconstruction term
    """
    ################################################################################
    # TODO: Modify/complete the code here
    # Compute niwae (negative IWAE) with iw importance samples, and the KL
    # and Rec decomposition of the Evidence Lower Bound
    #
    # Outputs should all be scalar
    ################################################################################
    # Compute the mixture of Gaussian prior
    prior = ut.gaussian_parameters(self.z_pre, dim=1)
    m, v = self.enc.encode(x)
    m = ut.duplicate(m, iw)
    v = ut.duplicate(v, iw)
    x = ut.duplicate(x, iw)
    z = ut.sample_gaussian(m, v)
    logits = self.dec.decode(z)
    kl = ut.log_normal(z, m, v) - ut.log_normal_mixture(z, *prior)
    rec = -ut.log_bernoulli_with_logits(x, logits)
    nelbo = kl + rec
    niwae = -ut.log_mean_exp(-nelbo.reshape(iw, -1), dim=0)
    niwae, kl, rec = niwae.mean(), kl.mean(), rec.mean()
    ################################################################################
    # End of code modification
    ################################################################################
    return niwae, kl, rec
def decode(self, z, y=None, c=None):
    zy = z if y is None else torch.cat((z, y), dim=-1)
    zyc = zy if c is None else torch.cat((zy, c), dim=-1)
    h = self.net(zyc)
    m, v = ut.gaussian_parameters(h, dim=-1)
    return m, v
def encode(self, x, y=None, c=None):
    xy = x if y is None else torch.cat((x, y), dim=-1)
    xyc = xy if c is None else torch.cat((xy, c), dim=-1)
    h = self.net(xyc)
    m, v = ut.gaussian_parameters(h, dim=-1)
    return m, v
def encode(self, x, y=None):
    xy = x if y is None else torch.cat((x, y), dim=1)
    h = self.net(xy.float())
    m, v = ut.gaussian_parameters(h, dim=1)
    return m, v
def encode_simple(self, x):
    x = self.conv6(x)
    m, v = ut.gaussian_parameters(x, dim=1)
    # print(m.size())
    return m, v
def sample_z(self, batch):
    m, v = ut.gaussian_parameters(self.z_pre.squeeze(0), dim=0)
    idx = torch.distributions.categorical.Categorical(self.pi).sample((batch, ))
    m, v = m[idx], v[idx]
    return ut.sample_gaussian(m, v)
def train(model, train_data, batch_size, n_batches, lr=1e-3, clip_grad=None,
          iter_max=np.inf, iter_save=np.inf, iter_plot=np.inf,
          reinitialize=False, kernel=None):
    # Optimization
    if reinitialize:
        model.apply(ut.reset_weights)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    mse = nn.MSELoss()
    #
    # Model
    # hidden = model.init_hidden(batch_size)
    i = 0  # i is num of gradient steps taken by end of loop iteration
    loss_list = []
    mse_list = []

    with tqdm.tqdm(total=iter_max) as pbar:
        while True:
            for batch in train_data:
                i += 1
                # print(psutil.virtual_memory())
                optimizer.zero_grad()
                inputs = batch[:model.n_input_steps, :, :]
                targets = batch[model.n_input_steps:, :, :2]
                # Since the data is not continued from batch to batch,
                # reinit hidden every batch. (using zeros)
                outputs = model.forward(inputs, targets=targets)
                batch_mean_nll, KL, KL_sharp = model.get_loss(outputs, targets)
                # print(batch_mean_nll, KL, KL_sharp)
                #
                # Re-weighting for minibatches
                NLL_term = batch_mean_nll * model.n_pred_steps
                # Here B = n_batches, C = 1 (since each sequence is complete)
                KL_term = KL / n_batches
                loss = NLL_term + KL_term
                if model.sharpen:
                    KL_sharp /= n_batches
                    loss += KL_sharp
                loss_list.append(loss.cpu().detach())

                # Print progress
                if model.likelihood_cost_form == 'gaussian':
                    if model.constant_var:
                        mse_val = mse(outputs, targets) * model.n_pred_steps
                    else:
                        if model.rnn_cell_type == 'FF':
                            mean, var = ut.gaussian_parameters_ff(outputs, dim=0)
                        else:
                            mean, var = ut.gaussian_parameters(outputs, dim=-1)
                        mse_val = mse(mean, targets) * model.n_pred_steps
                elif model.likelihood_cost_form == 'mse':
                    mse_val = batch_mean_nll * model.n_pred_steps
                mse_list.append(mse_val.cpu().detach())

                if i % iter_plot == 0:
                    with torch.no_grad():
                        model.eval()
                        if model.input_feat_dim <= 2:
                            ut.test_plot(model, i, kernel)
                        elif model.input_feat_dim == 4:
                            rand_idx = random.sample(range(batch.shape[1]), 4)
                            full_true_traj = batch[:, rand_idx, :]
                            if not model.BBB:
                                if model.constant_var:
                                    pred_traj = outputs[:, rand_idx, :]
                                    std_pred = None
                                else:
                                    pred_traj = mean[:, rand_idx, :]
                                    std_pred = var.sqrt()
                                ut.plot_highd_traj(model, i, full_true_traj,
                                                   pred_traj, std_pred=std_pred)
                            else:
                                # resample a few forward passes
                                ut.plot_highd_traj_BBB(model, i, full_true_traj,
                                                       n_resample_weights=10)
                        ut.plot_history(model, loss_list, i, obj='loss')
                        ut.plot_history(model, mse_list, i, obj='mse')
                        model.train()

                # loss.backward(retain_graph=True)
                loss.backward()
                # Clip gradients after backprop so the freshly computed
                # gradients are the ones being clipped.
                if clip_grad is not None:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
                optimizer.step()

                pbar.set_postfix(loss='{:.2e}'.format(loss),
                                 mse='{:.2e}'.format(mse_val))
                pbar.update(1)

                # Save model
                if i % iter_save == 0:
                    ut.save_model_by_name(model, i, only_latest=True)
                if i == iter_max:
                    return
def negative_iwae_bound(self, x, iw):
    """
    Computes the Importance Weighted Autoencoder Bound
    Additionally, we also compute the ELBO KL and reconstruction terms

    Args:
        x: tensor: (batch, dim): Observations
        iw: int: (): Number of importance weighted samples

    Returns:
        niwae: tensor: (): Negative IWAE bound
        kl: tensor: (): ELBO KL divergence to prior
        rec: tensor: (): ELBO Reconstruction term
    """
    ################################################################################
    # TODO: Modify/complete the code here
    # Compute niwae (negative IWAE) with iw importance samples, and the KL
    # and Rec decomposition of the Evidence Lower Bound
    #
    # Outputs should all be scalar
    ################################################################################
    # Compute the mixture of Gaussian prior
    prior = ut.gaussian_parameters(self.z_pre, dim=1)
    prior_m, prior_v = prior
    batch = x.shape[0]
    multi_x = ut.duplicate(x, iw)
    qm, qv = self.enc.encode(x)
    multi_qm = ut.duplicate(qm, iw)
    multi_qv = ut.duplicate(qv, iw)
    # z will be (batch*iw x z_dim)
    # with sampled z's for a given x non-contiguous!
    z = ut.sample_gaussian(multi_qm, multi_qv)
    probs = self.dec.decode(z)
    recs = ut.log_bernoulli_with_logits(multi_x, probs)
    rec = -1.0 * torch.mean(recs)
    multi_m = prior_m.expand(batch * iw, *prior_m.shape[1:])
    multi_v = prior_v.expand(batch * iw, *prior_v.shape[1:])
    z_priors = ut.log_normal_mixture(z, multi_m, multi_v)
    x_posteriors = recs
    z_posteriors = ut.log_normal(z, multi_qm, multi_qv)
    kls = z_posteriors - z_priors
    kl = torch.mean(kls)
    log_ratios = z_priors + x_posteriors - z_posteriors
    # Should be (batch*iw, z_dim), batch ratios non contiguous
    unflat_log_ratios = log_ratios.reshape(iw, batch)
    niwaes = ut.log_mean_exp(unflat_log_ratios, 0)
    niwae = -1.0 * torch.mean(niwaes)
    ################################################################################
    # End of code modification
    ################################################################################
    return niwae, kl, rec
def encode(self, x):
    h = self.net(x)
    m, v = ut.gaussian_parameters(h, dim=1)
    return m, v