Example #1
def test_kl_divergence():
    # Complementary masks must partition the KL exactly: masked-out sites
    # contribute zero, so the two masked KLs sum to the unmasked one.
    mask = torch.tensor([[0, 1], [1, 1]]).bool()
    p = Normal(torch.randn(2, 2), torch.randn(2, 2).exp())
    q = Normal(torch.randn(2, 2), torch.randn(2, 2).exp())
    expected = kl_divergence(p.to_event(2), q.to_event(2))
    actual = (kl_divergence(p.mask(mask).to_event(2),
                            q.mask(mask).to_event(2)) +
              kl_divergence(p.mask(~mask).to_event(2),
                            q.mask(~mask).to_event(2)))
    assert_equal(actual, expected)
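
Why the identity holds: .mask() zeroes the log-density, and hence the KL contribution, at masked-out sites, so two complementary masks partition the elementwise KL exactly. A minimal self-contained sketch, assuming Pyro and PyTorch are installed:

import torch
from torch.distributions import kl_divergence
from pyro.distributions import Normal

p = Normal(torch.randn(2, 2), torch.randn(2, 2).exp())
q = Normal(torch.randn(2, 2), torch.randn(2, 2).exp())
mask = torch.tensor([[False, True], [True, True]])

# Masked-out entries contribute zero, so the two halves sum to the whole.
whole = kl_divergence(p, q)
halves = (kl_divergence(p.mask(mask), q.mask(mask)) +
          kl_divergence(p.mask(~mask), q.mask(~mask)))
assert torch.allclose(whole, halves)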
Example #2
    def generate(self, num_to_sample: int = 1):
        # Draw latents from a standard-normal prior on the prediction
        # device, then decode them into outputs.
        cuda_device = self._get_prediction_device()
        prior_mean = nn_util.move_to_device(
            torch.zeros((num_to_sample, self._latent_dim)), cuda_device)
        prior_stddev = torch.ones_like(prior_mean)
        prior = Normal(prior_mean, prior_stddev)
        latent = prior.sample()
        generated = self._decoder.generate(latent)

        return self.decode(generated)
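
The same sampling pattern in self-contained form (a sketch with plain PyTorch; the linear layer is a hypothetical stand-in for the model's real decoder):

import torch
from torch.distributions import Normal

latent_dim, num_to_sample = 16, 4
decoder = torch.nn.Linear(latent_dim, 32)  # hypothetical stand-in decoder

# Standard-normal prior over the latent space, one row per sample.
prior = Normal(torch.zeros(num_to_sample, latent_dim),
               torch.ones(num_to_sample, latent_dim))
latent = prior.sample()      # (num_to_sample, latent_dim), no gradient
generated = decoder(latent)  # decode the sampled latents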
Example #3
def test_broadcast(event_shape, dist_shape, mask_shape):
    # Masking broadcasts the mask's shape against the distribution's batch
    # shape; the event shape must be left untouched.
    mask = torch.empty(torch.Size(mask_shape)).bernoulli_(0.5).bool()
    base_dist = Normal(torch.zeros(dist_shape + event_shape), 1.)
    base_dist = base_dist.to_event(len(event_shape))
    assert base_dist.batch_shape == dist_shape
    assert base_dist.event_shape == event_shape

    d = base_dist.mask(mask)
    d_shape = broadcast_shape(mask.shape, base_dist.batch_shape)
    assert d.batch_shape == d_shape
    assert d.event_shape == event_shape
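
A concrete instantiation of the property under test (a sketch; broadcast_shape is Pyro's shape helper from pyro.distributions.util): a mask with a larger batch shape broadcasts the distribution's batch shape while leaving the event shape untouched:

import torch
from pyro.distributions import Normal
from pyro.distributions.util import broadcast_shape

base = Normal(torch.zeros(2, 3), 1.).to_event(1)  # batch (2,), event (3,)
mask = torch.empty(5, 2).bernoulli_(0.5).bool()   # mask shape (5, 2)

d = base.mask(mask)
assert d.batch_shape == broadcast_shape(mask.shape, base.batch_shape)  # (5, 2)
assert d.event_shape == (3,)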
Example #4
    def forward(
        self, source_tokens: Dict[str, torch.LongTensor]
    ) -> Dict[str, Normal]:
        # pylint: disable=arguments-differ
        """
        Make a forward pass of the encoder, then return the prior and
        posterior distributions over the latent space.
        """
        final_state = self.encode(source_tokens)
        mean = self._latent_to_mean(final_state)
        logvar = self._latent_to_logvar(final_state)
        # Standard-normal prior with one row per batch element, on the same
        # device as the encoder output.
        prior = Normal(
            torch.zeros((mean.size(0), self.latent_dim), device=mean.device),
            torch.ones((mean.size(0), self.latent_dim), device=mean.device))
        # exp(0.5 * logvar) converts the predicted log-variance to a stddev.
        posterior = Normal(mean, (0.5 * logvar).exp())
        return {
            'prior': prior,
            'posterior': posterior,
        }
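
How the returned pair is typically consumed (a sketch with illustrative shapes): the KL term of the ELBO between the diagonal-Gaussian posterior and the standard-normal prior:

import torch
from torch.distributions import Normal, kl_divergence

mean = torch.randn(8, 16)    # (batch, latent_dim), as produced above
logvar = torch.randn(8, 16)

prior = Normal(torch.zeros_like(mean), torch.ones_like(mean))
posterior = Normal(mean, (0.5 * logvar).exp())

# Sum per-dimension KL over the latent dimension, average over the batch.
kl = kl_divergence(posterior, prior).sum(-1).mean()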
Example #5
    def propose_log_prob(self, value):
        # Log-density of the proposal's push-forward: the sampler draws
        # x ~ Normal(0, 1) and returns value = d * (1 + c * x)**3, so we
        # invert the transform and subtract the log-Jacobian
        # log(3 * d * c * y**2), where y = (value / d)**(1/3).
        v = value / self._d
        result = -self._d.log()
        y = v.pow(1 / 3)
        result -= torch.log(3 * y**2)
        x = (y - 1) / self._c
        result -= self._c.log()
        result += Normal(torch.zeros_like(self.concentration),
                         torch.ones_like(self.concentration)).log_prob(x)
        return result
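
The change of variables in standalone form (a sketch with made-up values for the sampler's _c and _d parameters): drawing x ~ Normal(0, 1) and returning value = d * (1 + c*x)**3 gives exactly the log-density computed above:

import torch
from torch.distributions import Normal

c, d = torch.tensor(0.3), torch.tensor(2.0)  # made-up proposal parameters

def propose_log_prob(value):
    # Invert value = d * (1 + c * x)**3, then subtract the log-Jacobian.
    y = (value / d).pow(1 / 3)
    x = (y - 1) / c
    jacobian = 3 * d * c * y ** 2
    return Normal(0., 1.).log_prob(x) - jacobian.log()

x = torch.tensor(0.5)
value = d * (1 + c * x) ** 3
print(propose_log_prob(value))  # log-density of `value` under the proposal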
Example #6
def reparametrize(prior: Normal,
                  posterior: Normal,
                  temperature: float = 1.0) -> torch.Tensor:
    """
    Create the latent vector using the reparameterization trick.
    """
    mean = posterior.mean
    std = posterior.stddev
    # The prior is standard normal, so rsample() yields the noise eps;
    # scale and shift it to the posterior: z = eps * std * T + mean.
    eps = prior.rsample()
    return eps.mul(std * temperature).add_(mean)
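
For comparison (a sketch): Normal.rsample() applies the same eps * std + mean construction internally, so at temperature 1.0 the helper above matches a direct reparameterized draw, with gradients flowing back to the posterior's parameters:

import torch
from torch.distributions import Normal

mean = torch.randn(4, 8, requires_grad=True)
std = torch.rand(4, 8) + 0.1

posterior = Normal(mean, std)
z = posterior.rsample()  # differentiable sample: eps * std + mean
z.sum().backward()       # gradients reach `mean` through the sample
assert mean.grad is not None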
Example #7
File: test_mask.py Project: www3cam/pyro
def test_kl_divergence_type(p_mask, q_mask):
    # Masks may be Python bools or tensors; the effective mask is their
    # conjunction, and a plain-False mask collapses the KL to a scalar 0.
    p = Normal(torch.randn(2, 2), torch.randn(2, 2).exp())
    q = Normal(torch.randn(2, 2), torch.randn(2, 2).exp())
    mask = (
        (torch.tensor(p_mask) if isinstance(p_mask, bool) else p_mask) &
        (torch.tensor(q_mask) if isinstance(q_mask, bool) else q_mask)).expand(
            2, 2)

    expected = kl_divergence(p, q)
    expected[~mask] = 0

    actual = kl_divergence(p.mask(p_mask), q.mask(q_mask))
    if p_mask is False or q_mask is False:
        assert isinstance(actual, float) and actual == 0.
    else:
        assert_equal(actual, expected)
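
The bool fast path in standalone form (a sketch, assuming Pyro): when either side is masked with plain False, the KL short-circuits to the scalar 0.0 instead of a tensor, which is exactly what the branch above asserts:

import torch
from torch.distributions import kl_divergence
from pyro.distributions import Normal

p = Normal(torch.randn(2, 2), torch.randn(2, 2).exp())
q = Normal(torch.randn(2, 2), torch.randn(2, 2).exp())

actual = kl_divergence(p.mask(False), q.mask(True))
assert isinstance(actual, float) and actual == 0.0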
Example #8
File: test_mask.py Project: www3cam/pyro
def test_mask_type(mask):
    # The mask may be a Python bool or a tensor; log_prob and every tensor
    # in score_parts must be zeroed wherever the mask is False.
    p = Normal(torch.randn(2, 2), torch.randn(2, 2).exp())
    p_masked = p.mask(mask)
    if isinstance(mask, bool):
        mask = torch.tensor(mask)

    x = p.sample()
    actual = p_masked.log_prob(x)
    expected = p.log_prob(x) * mask.float()
    assert_equal(actual, expected)

    actual = p_masked.score_parts(x)
    expected = p.score_parts(x)
    for a, e in zip(actual, expected):
        if isinstance(e, torch.Tensor):
            e = e * mask.float()
        assert_equal(a, e)
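
The same zeroing behavior, seen directly (a sketch, assuming Pyro): a plain-False mask turns the log-density into zeros of the same shape, while a True mask is a no-op:

import torch
from pyro.distributions import Normal

p = Normal(torch.randn(2, 2), torch.randn(2, 2).exp())
x = p.sample()

assert (p.mask(False).log_prob(x) == 0).all()
assert torch.equal(p.mask(True).log_prob(x), p.log_prob(x))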