def test_kl_divergence():
    mask = torch.tensor([[0, 1], [1, 1]]).bool()
    p = Normal(torch.randn(2, 2), torch.randn(2, 2).exp())
    q = Normal(torch.randn(2, 2), torch.randn(2, 2).exp())
    expected = kl_divergence(p.to_event(2), q.to_event(2))
    # The KL over the full event is a sum of elementwise KLs, so it splits into
    # the contribution of a mask and the contribution of its complement.
    actual = (kl_divergence(p.mask(mask).to_event(2), q.mask(mask).to_event(2)) +
              kl_divergence(p.mask(~mask).to_event(2), q.mask(~mask).to_event(2)))
    assert_equal(actual, expected)
def test_kl_divergence_type(p_mask, q_mask):
    # p_mask and q_mask are supplied by the test harness and may be Python bools
    # or boolean tensors.
    p = Normal(torch.randn(2, 2), torch.randn(2, 2).exp())
    q = Normal(torch.randn(2, 2), torch.randn(2, 2).exp())
    mask = ((torch.tensor(p_mask) if isinstance(p_mask, bool) else p_mask) &
            (torch.tensor(q_mask) if isinstance(q_mask, bool) else q_mask)).expand(2, 2)
    expected = kl_divergence(p, q)
    expected[~mask] = 0
    actual = kl_divergence(p.mask(p_mask), q.mask(q_mask))
    if p_mask is False or q_mask is False:
        # A statically False mask short-circuits to a plain Python float zero.
        assert isinstance(actual, float) and actual == 0.
    else:
        assert_equal(actual, expected)
def forward(self,
            source_tokens: Dict[str, torch.LongTensor]
            ) -> Dict[str, Normal]:  # pylint: disable=arguments-differ
    """
    Make a forward pass of the encoder and return the prior and the
    approximate posterior over the latent variable.
    """
    final_state = self.encode(source_tokens)
    mean = self._latent_to_mean(final_state)
    logvar = self._latent_to_logvar(final_state)
    prior = Normal(
        torch.zeros((mean.size(0), self.latent_dim), device=mean.device),
        torch.ones((mean.size(0), self.latent_dim), device=mean.device))
    posterior = Normal(mean, (0.5 * logvar).exp())
    return {
        'prior': prior,
        'posterior': posterior,
    }
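# A dict like the one returned above is typically consumed by a standard VAE
# objective: draw a reparameterized sample from the posterior, decode it, and
# penalize the KL to the prior. A minimal sketch of such a consumer; the
# function and the `decoder_nll` callable are hypothetical and not part of the
# model above.
from torch.distributions import kl_divergence


def elbo_loss(model, source_tokens, target_tokens, decoder_nll):
    dists = model(source_tokens)
    prior, posterior = dists['prior'], dists['posterior']
    # rsample() keeps the latent differentiable w.r.t. mean and logvar.
    latent = posterior.rsample()
    # Elementwise KL summed over the latent dimension gives a per-example term.
    kl = kl_divergence(posterior, prior).sum(dim=-1)
    reconstruction = decoder_nll(latent, target_tokens)
    return (reconstruction + kl).mean()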
def propose_log_prob(self, value):
    # Log density of the proposal under the change of variables
    # value = d * y**3 with y = 1 + c * x and x ~ Normal(0, 1):
    # log N(x; 0, 1) minus the log absolute Jacobian d * 3 * y**2 * c.
    v = value / self._d
    result = -self._d.log()
    y = v.pow(1 / 3)
    result -= torch.log(3 * y ** 2)
    x = (y - 1) / self._c
    result -= self._c.log()
    result += Normal(torch.zeros_like(self.concentration),
                     torch.ones_like(self.concentration)).log_prob(x)
    return result
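# The density above corresponds to drawing a standard normal variate and pushing
# it through the cubic transform value = d * (1 + c * x)**3. A minimal sketch of
# that sampling step, included only to make the change of variables concrete; it
# is an assumption about the proposal, not necessarily the library's actual
# propose() implementation.
def propose_sketch(self, sample_shape=torch.Size()):
    shape = sample_shape + self.concentration.shape
    # x ~ Normal(0, 1), then the transform that propose_log_prob inverts.
    x = torch.randn(shape, dtype=self.concentration.dtype,
                    device=self.concentration.device)
    y = 1 + self._c * x
    return self._d * y ** 3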
def generate(self, num_to_sample: int = 1):
    cuda_device = self._get_prediction_device()
    prior_mean = nn_util.move_to_device(
        torch.zeros((num_to_sample, self._latent_dim)), cuda_device)
    prior_stddev = torch.ones_like(prior_mean)
    prior = Normal(prior_mean, prior_stddev)
    latent = prior.sample()
    generated = self._decoder.generate(latent)
    return self.decode(generated)
def test_broadcast(event_shape, dist_shape, mask_shape):
    mask = torch.empty(torch.Size(mask_shape)).bernoulli_(0.5).bool()
    base_dist = Normal(torch.zeros(dist_shape + event_shape), 1.)
    base_dist = base_dist.to_event(len(event_shape))
    assert base_dist.batch_shape == dist_shape
    assert base_dist.event_shape == event_shape
    d = base_dist.mask(mask)
    d_shape = broadcast_shape(mask.shape, base_dist.batch_shape)
    assert d.batch_shape == d_shape
    assert d.event_shape == event_shape
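# The masked distribution's batch shape follows ordinary broadcasting between
# the mask's shape and the base batch shape, which is what the assertions above
# check. A small illustration, assuming broadcast_shape is the helper from
# pyro.distributions.util (it returns a plain tuple, which compares equal to a
# torch.Size):
import torch
from pyro.distributions.util import broadcast_shape

assert broadcast_shape(torch.Size([3, 1]), torch.Size([4])) == (3, 4)
assert broadcast_shape(torch.Size([]), torch.Size([2, 2])) == (2, 2)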
def test_mask_type(mask):
    p = Normal(torch.randn(2, 2), torch.randn(2, 2).exp())
    p_masked = p.mask(mask)
    if isinstance(mask, bool):
        mask = torch.tensor(mask)
    x = p.sample()

    # log_prob is zeroed wherever the mask is False.
    actual = p_masked.log_prob(x)
    expected = p.log_prob(x) * mask.float()
    assert_equal(actual, expected)

    # Tensor-valued score_parts entries are masked the same way.
    actual = p_masked.score_parts(x)
    expected = p.score_parts(x)
    for a, e in zip(actual, expected):
        if isinstance(e, torch.Tensor):
            e = e * mask.float()
        assert_equal(a, e)
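# In practice, this masking behavior is what lets a model drop the likelihood
# contribution of missing observations. A minimal, illustrative Pyro model under
# that assumption; it is not taken from the code above.
import torch
import pyro
import pyro.distributions as dist


def partially_observed_model(data, obs_mask):
    # obs_mask is a BoolTensor; masked-out entries contribute zero log-density,
    # exactly the behavior exercised by test_mask_type.
    loc = pyro.sample("loc", dist.Normal(0., 1.))
    with pyro.plate("data", data.shape[0]):
        pyro.sample("obs", dist.Normal(loc, 1.).mask(obs_mask), obs=data)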