Ejemplo n.º 1
0
 def test_exponential_shape_scalar_param(self):
     """Check batch/event/sample/log_prob shapes for a scalar-rate Exponential."""
     dist = Exponential(1.)
     # A scalar rate is expected to give empty batch and event shapes.
     self.assertEqual(dist._batch_shape, torch.Size())
     self.assertEqual(dist._event_shape, torch.Size())
     # Sample shapes with and without an explicit sample_shape.
     # NOTE(review): newer torch releases return torch.Size([]) for a bare
     # sample() of a scalar-parameter distribution -- confirm the pinned version.
     single_draw = dist.sample()
     self.assertEqual(single_draw.size(), torch.Size((1, )))
     grid_draws = dist.sample((3, 2))
     self.assertEqual(grid_draws.size(), torch.Size((3, 2)))
     # log_prob must reject the scalar fixture ...
     self.assertRaises(ValueError, dist.log_prob, self.scalar_sample)
     # ... and keep the shape of each tensor fixture.
     first_scores = dist.log_prob(self.tensor_sample_1)
     self.assertEqual(first_scores.size(), torch.Size((3, 2)))
     second_scores = dist.log_prob(self.tensor_sample_2)
     self.assertEqual(second_scores.size(), torch.Size((3, 2, 3)))
Ejemplo n.º 2
0
    def jump(self, xi, t):
        """Possibly apply a random jump to the coordinate ``xi`` at time ``t``.

        The jump probability is the product of two factors:

        * a multivariate Gaussian density evaluated at ``xi``;
        * the probability that an exponential waiting time with rate
          ``self.lambd`` has elapsed after ``self.last_jump`` steps.

        A single Bernoulli trial with that probability decides whether
        ``self.jump_event`` is applied.

        Args:
            xi: Current coordinate. Its first dimension ``n`` sets the
                dimensionality of the Gaussian.
            t: Current time, forwarded to ``jump_event`` and ``log_jump``.

        Returns:
            The coordinate after the trial: the result of ``jump_event`` if a
            jump happened, otherwise ``xi`` unchanged.

        Side effects:
            After a jump, calls ``self.log_jump`` and resets ``self.last_jump``
            to 0. Otherwise increments ``self.last_jump``.
        """
        n = xi.size(0)
        # Spatial multivariate Gaussian.
        # NOTE(review): std_jump scales the *covariance* matrix, so it acts as a
        # variance rather than a standard deviation -- confirm this is intended.
        m_gauss = MultivariateNormal(self.mu_jump * torch.ones(n),
                                     self.std_jump * torch.eye(n))
        # Poisson process: the probability of at least one arrival within the
        # elapsed waiting time is the exponential CDF, 1 - exp(-lambd * s).
        # (``1 - pdf`` would match this only for lambd == 1 and is negative
        # for lambd > 1, which Bernoulli rejects.)
        exp_d = Exponential(self.lambd)
        # Distribution methods validate that their argument is a tensor, so
        # the integer step counter is converted first.
        elapsed = torch.as_tensor(self.last_jump,
                                  dtype=torch.get_default_dtype())
        # Independent events: multiply the probabilities.
        # NOTE(review): the Gaussian factor is a density and can exceed 1 for
        # a small covariance, in which case Bernoulli would raise.
        p = torch.exp(m_gauss.log_prob(xi)) * exp_d.cdf(elapsed)

        # One sample from a Bernoulli trial.
        event = Bernoulli(p).sample([1])

        if event:
            coord_before = xi
            xi = self.jump_event(xi, t)  # flatten resulting sampled location
            coord_after = xi
            # Save the jump coordinate info.
            self.log_jump(t, coord_before, coord_after)

            self.last_jump = 0
        else:
            # No jump: increase the counter used by the next Bernoulli trial.
            self.last_jump += 1

        return xi
Ejemplo n.º 3
0
 def test_exponential_shape_tensor_param(self):
     """Check shapes for an Exponential built from a length-2 rate tensor."""
     rates = torch.Tensor([1, 1])
     dist = Exponential(rates)
     # A length-2 rate vector is expected to give batch shape (2,) and no
     # event dimensions.
     self.assertEqual(dist._batch_shape, torch.Size((2,)))
     self.assertEqual(dist._event_shape, torch.Size(()))
     # The batch shape trails any requested sample shape.
     self.assertEqual(dist.sample().size(), torch.Size((2,)))
     self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 2)))
     # log_prob on tensor_sample_1 yields (3, 2); tensor_sample_2 must raise.
     log_density = dist.log_prob(self.tensor_sample_1)
     self.assertEqual(log_density.size(), torch.Size((3, 2)))
     self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
Ejemplo n.º 4
0
    def likelihood(self, x, z_obj):
        """Evaluate likelihood of x under model.

        The per-image score is the sum of three terms:

        * the background log-likelihood from ``self.bg_spn``;
        * the area-scaled patch log-likelihoods from ``self.obj_spn``, summed
          over objects;
        * the Exponential overlap penalty, summed over objects.

        Args:
            x (torch.Tensor), (n, T, c, w, h): The given sequence of observations.
            z_obj (torch.Tensor), (nTO, 4): Samples from z distribution.

        Returns:
            log_p_xz (torch.Tensor), (nT): Image likelihood.
            prop_dict (dict): ``self.prop_dict``. On print/plot steps with
                ``self.c.debug`` set, it is filled with detached debug tensors.

        """
        # The SPNs score single images, so batch and sequence axes are merged
        # into one pseudo-batch of shape (n4, c, w, h).
        images = x.flatten(end_dim=1)

        # --- Background term ---
        # (n4, o, 4) view of the object latents for the mask computation.
        per_image_z = z_obj.view(-1, self.c.num_obj, 4)
        patch_mask, bg_mask, overlap_ratios = self.masks_from_z(per_image_z)
        flat_images = images.flatten(start_dim=1)
        flat_bg_mask = bg_mask.flatten(start_dim=1)
        # SPN output (n4, 1) -> (n4).
        bg_loglik = self.bg_spn.forward(flat_images, flat_bg_mask)[:, 0]

        # --- Patch term ---
        # Patches of shape (n4o, c, w, h) extracted by the transformer.
        patches = self.patches_from_z(images, z_obj)
        flat_patches = patches.flatten(start_dim=1)
        marginalise_flat = patch_mask.flatten(start_dim=1)
        patch_scores = self.obj_spn.forward(flat_patches, marginalise_flat)[:, 0]
        # Scale by patch size (presumably z_obj[:, 0] and z_obj[:, 1] are the
        # patch extents) to calibrate likelihoods across sizes.
        patch_scores = patch_scores * z_obj[:, 0] * z_obj[:, 1]
        # Regroup (n4o) -> (n4, o).
        patch_scores = patch_scores.view(-1, self.c.num_obj)

        # --- Exponential overlap penalty ---
        overlap_log_liks = Exponential(self.c.overlap_beta).log_prob(
            overlap_ratios)

        # --- Assemble E_q(z|x)[log p(x, z)] from a single sample ---
        patches_loglik = patch_scores.sum(1)
        overlap_log_liks = overlap_log_liks.sum(1)
        log_p_xz = torch.stack(
            [bg_loglik, patches_loglik, overlap_log_liks], -1).sum(-1)

        report_step = (self.step_counter % self.c.print_every == 0
                       or self.step_counter % self.c.plot_every == 0)
        if report_step and self.c.debug:
            summary = {
                'bg': bg_loglik.mean().detach(),
                'patch': patches_loglik.mean().detach(),
                'overlap': overlap_log_liks.mean().detach(),
            }
            for key, value in summary.items():
                self.prop_dict[key] = value
            if self.c.debug_extend_plots:
                extended = {
                    'overlap_ratios': overlap_ratios.detach(),
                    'patches': patches.detach(),
                    'marginalise_flat': marginalise_flat.detach(),
                    'patches_loglik': patches_loglik.detach(),
                    'marginalise_bg': bg_mask.detach(),
                    'bg_loglik': bg_loglik.detach(),
                }
                for key, value in extended.items():
                    self.prop_dict[key] = value

        return log_p_xz, self.prop_dict