def sample_outputs(self, posteriors, input_batch_list, n_samples):
    """Draw latent samples from the posteriors and decode them with base_map.

    posteriors is a pair (means, log-covariances) of the latent Gaussians;
    input_batch_list holds per-task input tensors shaped
    (num_tasks, num_per_task, dim).  Returns decoded outputs shaped
    (num_tasks, n_samples, num_per_task, out_dim).  Only the
    non-siamese, deterministic base_map case is implemented.
    """
    mu, log_cov = posteriors[0], posteriors[1]
    latents = sample_diag_gaussians(mu, torch.exp(log_cov), n_samples)
    # Repeat each latent once per datapoint in a task.
    latents = local_repeat(latents, input_batch_list[0].size(1))

    n_tasks = input_batch_list[0].size(0)
    n_per_task = input_batch_list[0].size(1)
    # Flatten task/datapoint dims, then repeat inputs once per latent sample.
    flat_inputs = [
        local_repeat(inp.contiguous().view(-1, inp.size(2)), n_samples)
        for inp in input_batch_list
    ]

    # Guard clause: anything but the deterministic non-siamese decoder is
    # unsupported (same contract as the original if/else).
    if self.base_map.siamese_output or not self.base_map.deterministic:
        raise NotImplementedError

    decoded = self.base_map(latents, flat_inputs)[0]
    return decoded.view(n_tasks, n_samples, n_per_task, decoded.size(-1))
def compute_cond_log_likelihood(self, posteriors, batch, mode='train'):
    """Compute E[log p(y|x,z)] up to constant additive factors.

    Arrays in the batch are N_tasks x N_samples x dim.  In 'eval' mode the
    posterior mean is used as the latent instead of sampling.  Only the
    single-sample case (n_samples == 1) is handled right now.
    """
    n_samples = 1  # not dealing with more than 1 case right now
    input_batch_list = batch['input_batch_list']
    output_batch_list = batch['output_batch_list']
    mask = batch['mask'].view(-1, 1)

    mu, log_cov = posteriors[0], posteriors[1]
    if mode == 'eval':
        latents = mu
    else:
        latents = sample_diag_gaussians(mu, torch.exp(log_cov), n_samples)
    # Repeat each latent once per datapoint in a task.
    latents = local_repeat(latents, input_batch_list[0].size(1))

    # Flatten the task/datapoint dims before decoding.
    flat_inputs = [inp.view(-1, inp.size(2)) for inp in input_batch_list]
    flat_targets = [out.view(-1, out.size(2)) for out in output_batch_list]
    preds = self.base_map(latents, flat_inputs)

    if self.base_map.siamese_output:
        # One log-prob term per (prediction, target) pair, summed.
        log_prob = 0.0
        for pred, target in zip(preds, flat_targets):
            if self.base_map.deterministic:
                log_prob += compute_spherical_log_prob(
                    pred, target, mask, n_samples)
            else:
                log_prob += compute_diag_log_prob(
                    pred[0], pred[1], target, mask, n_samples)
    elif self.base_map.deterministic:
        log_prob = compute_spherical_log_prob(
            preds[0], flat_targets[0], mask, n_samples)
    else:
        pred_mean, pred_log_cov = preds[0][0], preds[0][1]
        log_prob = compute_diag_log_prob(
            pred_mean, pred_log_cov, flat_targets[0], mask, n_samples)
    return log_prob
# One ELBO optimization step: encode (x, y) pairs into per-task latents,
# decode predictions through base_map, and minimize -(log-likelihood - KL).
encoder_optim.zero_grad()
r_to_z_map_optim.zero_grad()
base_map_optim.zero_grad()
# Encode each (x, y) pair, then aggregate the per-sample representations
# within a task into a single r.
r = encoder(torch.cat([X, Y], 1))
r_dim = r.size(-1)
r = r.view(N_tasks, N_samples, r_dim)
# r = torch.mean(r, 1)
r = torch.sum(r, 1)  # aggregation is a sum here (mean variant kept above)
mean, log_cov = r_to_z_map(r)
cov = torch.exp(log_cov)
# NOTE(review): the noise is scaled by cov = exp(log_cov) rather than by the
# standard deviation exp(0.5 * log_cov).  The KL below treats log_cov as a
# log-variance, so if that convention holds this reparameterization is
# inconsistent — confirm which convention r_to_z_map uses.
if not sample_one_z_per_sample:
    # One z per task, repeated across that task's samples.
    z = Variable(torch.randn(mean.size())) * cov + mean
    z = local_repeat(z, N_samples)
else:
    # Independent z per (task, sample) pair.
    rep_mean = local_repeat(mean, N_samples)
    rep_cov = local_repeat(cov, N_samples)
    z = Variable(torch.randn(rep_mean.size())) * rep_cov + rep_mean
Y_pred = base_map(z, X)
# KL(q(z|r) || N(0, I)) for a diagonal Gaussian with log-variance log_cov.
KL = -0.5 * torch.sum(1.0 + log_cov - mean**2 - cov)
# Spherical-Gaussian log-likelihood up to constants.
cond_log_likelihood = -0.5 * torch.sum((Y_pred - Y)**2)
neg_elbo = -1.0 * (cond_log_likelihood - KL) / float(N_tasks)
neg_elbo.backward()
# NOTE(review): only base_map_optim is stepped here even though encoder_optim
# and r_to_z_map_optim were zeroed above — confirm their .step() calls occur
# after this chunk; otherwise the encoder and r_to_z_map never update.
base_map_optim.step()
# Variant of the training step with a fixed linear decoder (Y_pred = X * z),
# apparently a sanity-check setup; the learned-decoder and z2w alternatives
# are kept commented out below.
encoder_optim.zero_grad()
r_to_z_map_optim.zero_grad()
# z2w_optim.zero_grad()
# base_map_optim.zero_grad()
r = encoder(torch.cat([X, Y], 1))
r_dim = r.size(-1)
r = r.view(N_tasks, N_samples, r_dim)
r = torch.mean(r, 1)  # aggregation is a mean here (sum variant kept below)
# r = torch.sum(r, 1)
mean, log_cov = r_to_z_map(r)
cov = torch.exp(log_cov)
# NOTE(review): the noise is scaled by cov = exp(log_cov) rather than by the
# standard deviation exp(0.5 * log_cov); the KL below treats log_cov as a
# log-variance — confirm the intended convention.
z = Variable(torch.randn(mean.size())) * cov + mean
z = local_repeat(z, N_samples)
# Y_pred = base_map(z, X)
Y_pred = X * z
# w1, w2, w3 = z2w(z)
# w1 = w1.view(-1, 1, w1.size(1))
# w3 = w3.view(-1, w1.size(-1), 1)
# Y_pred = torch.matmul(X.view(-1,1,1), w1)
# Y_pred = torch.matmul(Y_pred, w3)
# Y_pred = Y_pred.view(-1,1)
# KL(q(z|r) || N(0, I)) for a diagonal Gaussian with log-variance log_cov.
KL = -0.5 * torch.sum(1.0 + log_cov - mean**2 - cov)
# Spherical-Gaussian log-likelihood up to constants; the backward/step calls
# presumably follow after this chunk.
cond_log_likelihood = -0.5 * torch.sum((Y_pred - Y)**2)