Example #1
0
    def compute_cond_log_likelihood(self, posteriors, batch, mode='train'):
        '''
            Compute E[log p(y|x,z)] up to additive constants.

            Args:
                posteriors: pair (z_means, z_log_covs) for the latent
                    posterior; z_log_covs is exponentiated to obtain the
                    covariances handed to sample_diag_gaussians.
                batch: dict with keys 'input_batch_list' and
                    'output_batch_list' (lists of 3-D tensors,
                    N_tasks x N_samples x dim; dim 1 is the per-task count
                    used to repeat z) and 'mask' (flattened to a column).
                mode: 'eval' uses the posterior means as z; any other value
                    draws z from the posterior.

            Returns:
                When base_map.siamese_output is set, the sum of the
                log-likelihoods of every (prediction, target) pair;
                otherwise the log-likelihood of the first pair only.
        '''
        # Only a single latent sample per task is supported right now.
        n_samples = 1

        input_batch_list = batch['input_batch_list']
        output_batch_list = batch['output_batch_list']
        # .contiguous() makes .view() safe for tensors that arrive as
        # non-contiguous views (slices/transposes) and is a no-op otherwise;
        # this matches the flattening used in sample_outputs.
        mask = batch['mask'].contiguous().view(-1, 1)

        z_means, z_log_covs = posteriors[0], posteriors[1]
        if mode == 'eval':
            # Deterministic evaluation: use the posterior mean as the latent.
            z_samples = z_means
        else:
            z_samples = sample_diag_gaussians(z_means, torch.exp(z_log_covs),
                                              n_samples)
        # Repeat each task's latent once per point; size(1) must be read
        # before the inputs are flattened below.
        z_samples = local_repeat(z_samples, input_batch_list[0].size(1))

        input_batch_list = [
            inp.contiguous().view(-1, inp.size(2)) for inp in input_batch_list
        ]
        output_batch_list = [
            out.contiguous().view(-1, out.size(2)) for out in output_batch_list
        ]

        preds = self.base_map(z_samples, input_batch_list)
        siamese = self.base_map.siamese_output
        deterministic = self.base_map.deterministic

        def pair_log_prob(pred, output):
            # Deterministic heads yield a point prediction scored with a
            # spherical likelihood; otherwise pred is (mean, log_cov).
            if deterministic:
                return compute_spherical_log_prob(pred, output, mask,
                                                  n_samples)
            return compute_diag_log_prob(pred[0], pred[1], output, mask,
                                         n_samples)

        if siamese:
            # Siamese heads: accumulate over every prediction/target pair.
            log_prob = 0.0
            for pred, output in zip(preds, output_batch_list):
                log_prob += pair_log_prob(pred, output)
            return log_prob

        # Single head: only the first prediction/target pair is scored.
        return pair_log_prob(preds[0], output_batch_list[0])
Example #2
0
    def sample_outputs(self, posteriors, input_batch_list, n_samples):
        '''
            Draw n_samples latents per task from the posterior and decode them.

            Args:
                posteriors: pair (means, log-covariances) of the latent
                    posterior; the log-covariances are exponentiated before
                    being passed to sample_diag_gaussians.
                input_batch_list: list of 3-D tensors; dims 0 and 1 give the
                    number of tasks and the number of points per task.
                n_samples: number of latent draws per task.

            Returns:
                The first output of base_map reshaped to
                (num_tasks, n_samples, num_per_task, out_dim).

            Raises:
                NotImplementedError: unless base_map is deterministic and
                    not siamese.
        '''
        means, log_covs = posteriors[0], posteriors[1]
        draws = sample_diag_gaussians(means, torch.exp(log_covs), n_samples)

        first = input_batch_list[0]
        num_tasks, num_per_task = first.size(0), first.size(1)
        # One copy of every latent draw per input point.
        draws = local_repeat(draws, num_per_task)

        # Flatten each input to 2-D, then give every row one copy per draw.
        # NOTE(review): z rows are expanded per point after sampling while
        # input rows are expanded per sample, and the final view assumes
        # (task, sample, point) row order. Whether these line up depends on
        # the ordering conventions of sample_diag_gaussians/local_repeat
        # (defined elsewhere) -- verify when n_samples > 1 and
        # num_per_task > 1.
        repeated_inputs = [
            local_repeat(x.contiguous().view(-1, x.size(2)), n_samples)
            for x in input_batch_list
        ]

        base_map = self.base_map
        if base_map.siamese_output or not base_map.deterministic:
            raise NotImplementedError

        flat = base_map(draws, repeated_inputs)[0]
        return flat.view(num_tasks, n_samples, num_per_task, flat.size(-1))