Example #1
0
def test_compose_affine(event_dims):
    transforms = [AffineTransform(torch.zeros((1,) * e), 1, event_dim=e) for e in event_dims]
    transform = ComposeTransform(transforms)
    assert transform.codomain.event_dim == max(event_dims)
    assert transform.domain.event_dim == max(event_dims)

    base_dist = Normal(0, 1)
    if transform.domain.event_dim:
        base_dist = base_dist.expand((1,) * transform.domain.event_dim)
    dist = TransformedDistribution(base_dist, transform.parts)
    assert dist.support.event_dim == max(event_dims)

    base_dist = Dirichlet(torch.ones(5))
    if transform.domain.event_dim > 1:
        base_dist = base_dist.expand((1,) * (transform.domain.event_dim - 1))
    dist = TransformedDistribution(base_dist, transforms)
    assert dist.support.event_dim == max(1, max(event_dims))
Example #2
0
def test_transformed_distribution(base_batch_dim, base_event_dim,
                                  transform_dim, num_transforms, sample_shape):
    shape = torch.Size([2, 3, 4, 5])
    base_dist = Normal(0, 1)
    base_dist = base_dist.expand(shape[4 - base_batch_dim - base_event_dim:])
    if base_event_dim:
        base_dist = Independent(base_dist, base_event_dim)
    transforms = [
        AffineTransform(torch.zeros(shape[4 - transform_dim:]), 1),
        ReshapeTransform((4, 5), (20, )),
        ReshapeTransform((3, 20), (6, 10))
    ]
    transforms = transforms[:num_transforms]
    transform = ComposeTransform(transforms)

    # Check validation in .__init__().
    if base_batch_dim + base_event_dim < transform.domain.event_dim:
        with pytest.raises(ValueError):
            TransformedDistribution(base_dist, transforms)
        return
    d = TransformedDistribution(base_dist, transforms)

    # Check sampling is sufficiently expanded.
    x = d.sample(sample_shape)
    assert x.shape == sample_shape + d.batch_shape + d.event_shape
    num_unique = len(set(x.reshape(-1).tolist()))
    assert num_unique >= 0.9 * x.numel()

    # Check log_prob shape on full samples.
    log_prob = d.log_prob(x)
    assert log_prob.shape == sample_shape + d.batch_shape

    # Check log_prob shape on partial samples.
    y = x
    while y.dim() > len(d.event_shape):
        y = y[0]
    log_prob = d.log_prob(y)
    assert log_prob.shape == d.batch_shape
    def compute_losses(self,
                       pd_dict: dict,
                       data_batch: dict,
                       recon_weight=1.0,
                       kl_weight=1.0,
                       **kwargs):
        """Compute the weighted reconstruction and KL-divergence losses.

        The reconstruction term is the negative log-likelihood of
        ``data_batch['image']`` under a per-element two-component mixture.
        One component is a foreground ``Normal`` centred on
        ``pd_dict['fg_recon']``, weighted by ``fg_mask``. The other is a
        background ``Normal`` centred on ``pd_dict['bg_recon']``, weighted by
        ``1 - fg_mask``. It is summed over dims 1-3 and averaged over the batch.

        Each KL term is added only when its prior entries are present in
        ``data_batch``. The prior is broadcast to the posterior's batch shape,
        and the KL is summed over every element belonging to one batch sample,
        then averaged over the batch.

        Args:
            pd_dict: model outputs. Always read: ``'fg_recon'``, ``'fg_mask'``,
                ``'bg_recon'``. Read only when the matching prior is given:
                ``'z_where_post'``, ``'z_what_post'``, ``'z_depth_post'``
                (distributions passed to ``kl_divergence`` against a
                ``Normal``) and ``'z_pres_p'`` (Bernoulli probabilities).
            data_batch: must contain ``'image'`` (a 4-D tensor; its first dim
                is the batch size), ``'fg_recon_scale_prior'`` and
                ``'bg_recon_scale_prior'``. Optional prior pairs/entries:
                ``'z_where_{loc,scale}_prior'``, ``'z_what_{loc,scale}_prior'``,
                ``'z_pres_p_prior'``, ``'z_depth_{loc,scale}_prior'``.
            recon_weight: multiplier for ``'recon_loss'``.
            kl_weight: multiplier for every KL loss.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            dict with ``'recon_loss'`` followed by whichever of
            ``'kl_where_loss'``, ``'kl_what_loss'``, ``'kl_pres_loss'``,
            ``'kl_depth_loss'`` apply, in that order.
        """
        loss_dict = {}
        image = data_batch['image']
        # Unpacking also enforces that the image is 4-D.
        b, _, _, _ = image.size()

        def batch_kl(post, prior):
            # KL(post || prior) with the prior broadcast to the posterior's
            # batch shape, then summed per batch sample -> shape (b,).
            # NOTE(review): reshape(b, -1) assumes the posterior's elements are
            # laid out sample-major (e.g. z_what rows as b * A * h1 * w1).
            kl = kl_divergence(post, prior.expand(post.batch_shape))
            return kl.reshape(b, -1).sum(1)

        # ---------------------------------------------------------------------------- #
        # Reconstruction
        # ---------------------------------------------------------------------------- #
        fg_recon = pd_dict['fg_recon']  # (b, c0, h0, w0)
        fg_mask = pd_dict['fg_mask']  # (b, 1, h0, w0), broadcast over channels
        bg_recon = pd_dict['bg_recon']  # (b, c0, h0, w0)

        # Clamping keeps log() finite when the mask saturates at 0 or 1.
        fg_recon_dist = Normal(fg_recon, scale=data_batch['fg_recon_scale_prior'])
        fg_recon_log_prob = fg_recon_dist.log_prob(image) + torch.log(fg_mask.clamp(min=self._eps))
        bg_recon_dist = Normal(bg_recon, scale=data_batch['bg_recon_scale_prior'])
        bg_recon_log_prob = bg_recon_dist.log_prob(image) + torch.log((1.0 - fg_mask).clamp(min=self._eps))
        # p(x|z) = p(x|fg, z) * p(fg|z) + p(x|bg, z) * p(bg|z), computed in log space.
        image_recon_log_prob = torch.stack([fg_recon_log_prob, bg_recon_log_prob], dim=1)
        image_recon_log_prob = torch.logsumexp(image_recon_log_prob, dim=1)  # (b, c0, h0, w0)

        observation_nll = - torch.sum(image_recon_log_prob, dim=[1, 2, 3])
        loss_dict['recon_loss'] = observation_nll.mean() * recon_weight

        # ---------------------------------------------------------------------------- #
        # KL divergence (z_where)
        # ---------------------------------------------------------------------------- #
        if 'z_where_loc_prior' in data_batch and 'z_where_scale_prior' in data_batch:
            z_where_prior = Normal(loc=data_batch['z_where_loc_prior'],
                                   scale=data_batch['z_where_scale_prior'])
            kl_where = batch_kl(pd_dict['z_where_post'], z_where_prior)
            loss_dict['kl_where_loss'] = kl_where.mean() * kl_weight

        # ---------------------------------------------------------------------------- #
        # KL divergence (z_what)
        # ---------------------------------------------------------------------------- #
        if 'z_what_loc_prior' in data_batch and 'z_what_scale_prior' in data_batch:
            z_what_prior = Normal(loc=data_batch['z_what_loc_prior'],
                                  scale=data_batch['z_what_scale_prior'])
            kl_what = batch_kl(pd_dict['z_what_post'], z_what_prior)
            loss_dict['kl_what_loss'] = kl_what.mean() * kl_weight

        # ---------------------------------------------------------------------------- #
        # KL divergence (z_pres)
        # ---------------------------------------------------------------------------- #
        if 'z_pres_p_prior' in data_batch:
            z_pres_post = Bernoulli(probs=pd_dict['z_pres_p'])
            z_pres_prior = Bernoulli(probs=data_batch['z_pres_p_prior'])
            kl_pres = batch_kl(z_pres_post, z_pres_prior)
            loss_dict['kl_pres_loss'] = kl_pres.mean() * kl_weight

        # ---------------------------------------------------------------------------- #
        # KL divergence (z_depth)
        # ---------------------------------------------------------------------------- #
        if 'z_depth_loc_prior' in data_batch and 'z_depth_scale_prior' in data_batch:
            z_depth_prior = Normal(loc=data_batch['z_depth_loc_prior'],
                                   scale=data_batch['z_depth_scale_prior'])
            kl_depth = batch_kl(pd_dict['z_depth_post'], z_depth_prior)
            loss_dict['kl_depth_loss'] = kl_depth.mean() * kl_weight

        return loss_dict
Example #4
0
 def expand(self, batch_shape, _instance=None):
     """Return a new distribution with ``loc``/``scale`` expanded to ``batch_shape``.

     Delegates to ``Normal.expand``. When ``_instance`` is None (the normal
     public call), a fresh object of this distribution's concrete class is
     used as the target, so ``self`` is left unchanged. ``_instance`` follows
     the torch.distributions ``expand`` convention and is only meant for
     further subclasses forwarding their own pre-allocated instance.
     """
     if _instance is None:
         # Never pass ``self`` as the target: Normal.expand assigns the
         # expanded loc/scale and batch shape onto the instance it is given,
         # which would mutate this distribution in place and return it.
         _instance = self.__new__(type(self))
         # Carry over any state a custom __init__ set (Normal.expand only
         # rewrites loc, scale, the shapes and _validate_args).
         _instance.__dict__.update(self.__dict__)
     return Normal.expand(self, batch_shape, _instance=_instance)