def compute_cond_log_likelihood(self, posteriors, batch, mode='train'):
    '''
    Compute E[log p(y|x,z)] up to constant additive factors.

    Args:
        posteriors: pair ``(z_means, z_log_covs)`` of the diagonal Gaussian
            posterior over z (log-covariances are exponentiated here).
        batch: dict with keys ``'input_batch_list'``, ``'output_batch_list'``
            and ``'mask'``. Every input/output array is 3-D,
            N_tasks x N_points_per_task x dim.
        mode: ``'eval'`` uses the posterior means as z; any other value draws
            one sample per task from the posterior.

    Returns:
        The value produced by ``compute_spherical_log_prob`` (deterministic
        base map) or ``compute_diag_log_prob`` (stochastic base map). For a
        siamese base map, the sum of those terms over all output heads.
    '''
    # Only a single posterior sample per task is supported right now.
    n_samples = 1
    input_batch_list = batch['input_batch_list']
    output_batch_list = batch['output_batch_list']
    # .contiguous() so the flattening .view() also works on sliced or
    # transposed tensors (a bare .view() raises on those).
    mask = batch['mask'].contiguous().view(-1, 1)

    z_means, z_log_covs = posteriors[0], posteriors[1]
    if mode == 'eval':
        z_samples = z_means
    else:
        z_samples = sample_diag_gaussians(
            z_means, torch.exp(z_log_covs), n_samples)
    # One copy of each task's z per data point of that task. This assumes
    # local_repeat repeats each row consecutively, so the rows line up with
    # the task-major flattened inputs below -- confirm against local_repeat.
    z_samples = local_repeat(z_samples, input_batch_list[0].size(1))

    input_batch_list = [
        inp.contiguous().view(-1, inp.size(2)) for inp in input_batch_list
    ]
    output_batch_list = [
        out.contiguous().view(-1, out.size(2)) for out in output_batch_list
    ]

    preds = self.base_map(z_samples, input_batch_list)
    deterministic = self.base_map.deterministic

    def head_log_prob(pred, output):
        # Log-likelihood contribution of a single output head.
        if deterministic:
            return compute_spherical_log_prob(pred, output, mask, n_samples)
        # Stochastic heads predict a (mean, log_cov) pair.
        return compute_diag_log_prob(pred[0], pred[1], output, mask, n_samples)

    if self.base_map.siamese_output:
        log_prob = 0.0
        for pred, output in zip(preds, output_batch_list):
            log_prob += head_log_prob(pred, output)
        return log_prob
    # Non-siamese: only the first prediction / output pair is scored.
    return head_log_prob(preds[0], output_batch_list[0])
def sample_outputs(self, posteriors, input_batch_list, n_samples):
    '''
    Decode every input point of each task under n_samples posterior draws.

    Args:
        posteriors: pair ``(z_means, z_log_covs)`` of the diagonal Gaussian
            posterior over z (log-covariances are exponentiated here).
        input_batch_list: list of 3-D arrays,
            num_tasks x num_per_task x dim.
        n_samples: number of z samples drawn per task.

    Returns:
        Tensor of shape num_tasks x n_samples x num_per_task x out_dim.
        ``outputs[t, s, p]`` is the base map's prediction for input point
        ``p`` of task ``t`` under the ``s``-th z sample of task ``t``.

    Raises:
        NotImplementedError: unless the base map is deterministic and not
            siamese.
    '''
    # Bail out before doing any sampling work for unsupported base maps.
    if self.base_map.siamese_output or not self.base_map.deterministic:
        raise NotImplementedError

    z_means, z_log_covs = posteriors[0], posteriors[1]
    num_tasks = input_batch_list[0].size(0)
    num_per_task = input_batch_list[0].size(1)

    # Presumably rows ordered (task, sample); the final view relies on this.
    z_samples = sample_diag_gaussians(
        z_means, torch.exp(z_log_covs), n_samples)
    # Each z row repeated once per data point -> rows (task, sample, point),
    # assuming local_repeat repeats each row consecutively.
    z_samples = local_repeat(z_samples, num_per_task)

    # Tile each task's block of points once per sample so the inputs are
    # also laid out (task, sample, point) and pair row-for-row with
    # z_samples. local_repeat(inp, n_samples) would instead give
    # (task, point, sample) and pair z samples with the wrong points
    # whenever n_samples > 1 and num_per_task > 1.
    input_batch_list = [
        inp.repeat(1, n_samples, 1).view(-1, inp.size(2))
        for inp in input_batch_list
    ]

    outputs = self.base_map(z_samples, input_batch_list)[0]
    return outputs.view(num_tasks, n_samples, num_per_task, outputs.size(-1))