def test_log_prob(self, cuda=False):
    device = torch.device("cuda") if cuda else torch.device("cpu")
    for dtype in (torch.float, torch.double):
        mean = torch.randn(4, device=device, dtype=dtype)
        var = torch.randn(4, device=device, dtype=dtype).abs_()
        values = torch.randn(4, device=device, dtype=dtype)

        res = MultivariateNormal(mean, DiagLazyTensor(var)).log_prob(values)
        actual = TMultivariateNormal(mean, torch.eye(4, device=device, dtype=dtype) * var).log_prob(values)
        self.assertLess((res - actual).div(res).abs().item(), 1e-2)

        mean = torch.randn(3, 4, device=device, dtype=dtype)
        var = torch.randn(3, 4, device=device, dtype=dtype).abs_()
        values = torch.randn(3, 4, device=device, dtype=dtype)

        res = MultivariateNormal(mean, DiagLazyTensor(var)).log_prob(values)
        actual = TMultivariateNormal(
            mean, var.unsqueeze(-1) * torch.eye(4, device=device, dtype=dtype).repeat(3, 1, 1)
        ).log_prob(values)
        self.assertLess((res - actual).div(res).abs().norm(), 1e-2)

def _initialize_caches(targets, noise_diagonal, wmat, create_w_cache=True):
    if len(noise_diagonal.shape) > 2:
        noise_diagonal = noise_diagonal.squeeze(-1)
    noise_diagonal = DiagLazyTensor(noise_diagonal)

    # reshape the targets so that we have niceness in the batch dimensions
    if targets.ndimension() == 2:
        targets = targets.unsqueeze(-1)
    if targets.ndimension() <= 3:
        targets = targets.transpose(-2, -3)

    dinv_y = noise_diagonal.inv_matmul(targets)
    cache_dict = {
        "response_cache": targets.transpose(-1, -2) @ dinv_y,
        "interpolation_cache": wmat @ dinv_y,
    }
    if create_w_cache:
        cache_dict["WtW"] = UpdatedRootLazyTensor(
            wmat @ (noise_diagonal.inv_matmul(wmat.transpose(-1, -2))),
            initial_is_root=False,
        )
    cache_dict["D_logdet"] = noise_diagonal.logdet()

    # trim tails in the case of large batch dim
    if targets.ndimension() > 3:
        cache_dict["response_cache"] = cache_dict["response_cache"].squeeze(-1)

    return cache_dict

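# A minimal sketch (illustration only, not part of the original code) of the solves
# the caches above rely on: for a diagonal noise matrix D, `inv_matmul` is just an
# elementwise division, so D^{-1}y and W D^{-1} W^T are cheap to form. Assumes
# gpytorch's DiagLazyTensor API.
import torch
from gpytorch.lazy import DiagLazyTensor

d = torch.rand(5) + 0.1                                        # strictly positive diagonal
y = torch.randn(5, 2)
D = DiagLazyTensor(d)
assert torch.allclose(D.inv_matmul(y), y / d.unsqueeze(-1))    # D^{-1} y, elementwise
assert torch.allclose(D.logdet(), d.log().sum())               # log|D| = sum of log-diagonal
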
def forward(self, x1, x2, diag=False, are_equal=True, **params):
    batch_shape = self.batch_shape
    leading_dim = x1.size()[:-2]
    if self.constant_noise:
        _are_equal = x1.shape == x2.shape
    else:
        _are_equal = torch.equal(x1, x2) and are_equal

    if _are_equal:
        noise_var = torch.exp(self.noise_log_var).expand(-1, x1.size(-2))
        K = DiagLazyTensor(noise_var)
    else:
        K = ZeroLazyTensor(*leading_dim, x1.size(-2), x2.size(-2), dtype=x1.dtype, device=x1.device)

    if diag:
        K = K.diag()
        if not leading_dim:
            K = K.unsqueeze(0)
        # return a torch.Tensor rather than a lazy tensor, consistent with other kernels' behavior
        return K
    return K

def test_dirichlet_classification_likelihood(self, cuda=False):
    device = torch.device("cuda") if cuda else torch.device("cpu")
    for dtype in (torch.float, torch.double):
        noise = torch.rand(6, device=device, dtype=dtype) > 0.5
        noise = noise.long()
        lkhd = DirichletClassificationLikelihood(noise, dtype=dtype)

        # test basics
        self.assertIsInstance(lkhd.noise_covar, FixedGaussianNoise)
        noise = torch.rand(6, device=device, dtype=dtype) > 0.5
        noise = noise.long()
        new_noise, _, _ = lkhd._prepare_targets(noise, dtype=dtype)
        lkhd.noise = new_noise
        self.assertTrue(torch.equal(lkhd.noise, new_noise))

        # test __call__
        mean = torch.zeros(6, device=device, dtype=dtype)
        covar = DiagLazyTensor(torch.ones(6, device=device, dtype=dtype))
        mvn = MultivariateNormal(mean, covar)
        out = lkhd(mvn)
        self.assertTrue(torch.allclose(out.variance, 1 + new_noise))

        # things should break if dimensions mismatch
        mean = torch.zeros(5, device=device, dtype=dtype)
        covar = DiagLazyTensor(torch.ones(5, device=device, dtype=dtype))
        mvn = MultivariateNormal(mean, covar)
        with self.assertWarns(UserWarning):
            lkhd(mvn)

        # test __call__ w/ new targets
        obs_noise = 0.1 + torch.rand(5, device=device, dtype=dtype)
        obs_noise = (obs_noise > 0.5).long()
        out = lkhd(mvn, targets=obs_noise)
        obs_targets, _, _ = lkhd._prepare_targets(obs_noise, dtype=dtype)
        self.assertTrue(torch.allclose(out.variance, 1.0 + obs_targets))

def test_fixed_noise_gaussian_likelihood(self, cuda=False):
    device = torch.device("cuda") if cuda else torch.device("cpu")
    for dtype in (torch.float, torch.double):
        noise = 0.1 + torch.rand(4, device=device, dtype=dtype)
        lkhd = FixedNoiseGaussianLikelihood(noise=noise)

        # test basics
        self.assertIsInstance(lkhd.noise_covar, FixedGaussianNoise)
        self.assertTrue(torch.equal(noise, lkhd.noise))
        new_noise = 0.1 + torch.rand(4, device=device, dtype=dtype)
        lkhd.noise = new_noise
        self.assertTrue(torch.equal(lkhd.noise, new_noise))

        # test __call__
        mean = torch.zeros(4, device=device, dtype=dtype)
        covar = DiagLazyTensor(torch.ones(4, device=device, dtype=dtype))
        mvn = MultivariateNormal(mean, covar)
        out = lkhd(mvn)
        self.assertTrue(torch.allclose(out.variance, 1 + new_noise))

        # things should break if dimensions mismatch
        mean = torch.zeros(5, device=device, dtype=dtype)
        covar = DiagLazyTensor(torch.ones(5, device=device, dtype=dtype))
        mvn = MultivariateNormal(mean, covar)
        with self.assertWarns(UserWarning):
            lkhd(mvn)

        # test __call__ w/ observation noise
        obs_noise = 0.1 + torch.rand(5, device=device, dtype=dtype)
        out = lkhd(mvn, noise=obs_noise)
        self.assertTrue(torch.allclose(out.variance, 1 + obs_noise))

def test_from_independent_mvns(self, cuda=False):
    device = torch.device("cuda") if cuda else torch.device("cpu")
    for dtype in (torch.float, torch.double):
        # Test non-batch mode mvns
        n_tasks = 2
        n = 4
        mvns = [
            MultivariateNormal(
                mean=torch.randn(n, device=device, dtype=dtype),
                covariance_matrix=DiagLazyTensor(torch.randn(n, device=device, dtype=dtype).abs_()),
            )
            for _ in range(n_tasks)
        ]
        mvn = MultitaskMultivariateNormal.from_independent_mvns(mvns=mvns)
        expected_mean_shape = [n, n_tasks]
        expected_covar_shape = [n * n_tasks] * 2
        self.assertEqual(list(mvn.mean.shape), expected_mean_shape)
        self.assertEqual(list(mvn.covariance_matrix.shape), expected_covar_shape)

        # Test batch mode mvns
        b = 3
        mvns = [
            MultivariateNormal(
                mean=torch.randn(b, n, device=device, dtype=dtype),
                covariance_matrix=DiagLazyTensor(torch.randn(b, n, device=device, dtype=dtype).abs_()),
            )
            for _ in range(n_tasks)
        ]
        mvn = MultitaskMultivariateNormal.from_independent_mvns(mvns=mvns)
        self.assertEqual(list(mvn.mean.shape), [b] + expected_mean_shape)
        self.assertEqual(list(mvn.covariance_matrix.shape), [b] + expected_covar_shape)

def create_lazy_tensor(self):
    a = torch.tensor([4.0, 1.0, 2.0], dtype=torch.float)
    b = torch.tensor([3.0, 1.3], dtype=torch.float)
    c = torch.tensor([1.75, 1.95, 1.05, 0.25], dtype=torch.float)
    a.requires_grad_(True)
    b.requires_grad_(True)
    c.requires_grad_(True)
    kp_lazy_tensor = KroneckerProductDiagLazyTensor(DiagLazyTensor(a), DiagLazyTensor(b), DiagLazyTensor(c))
    return kp_lazy_tensor

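# A minimal sketch (an illustration, not part of the test suite): the diagonal of a
# Kronecker product of diagonal matrices is the Kronecker product of the diagonals,
# which is what KroneckerProductDiagLazyTensor exploits. Assumes torch >= 1.8 for
# torch.kron and gpytorch's lazy-tensor API.
import torch
from gpytorch.lazy import DiagLazyTensor, KroneckerProductDiagLazyTensor

a = torch.tensor([4.0, 1.0, 2.0])
b = torch.tensor([3.0, 1.3])
kp = KroneckerProductDiagLazyTensor(DiagLazyTensor(a), DiagLazyTensor(b))
assert torch.allclose(kp.diag(), torch.kron(a, b))  # 6-element diagonal, never materialized as 6 x 6
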
def test_sample(self):
    res = DiagLazyTensor(diag)
    actual = res.evaluate()
    with gpytorch.settings.max_root_decomposition_size(1000):
        samples = res.zero_mean_mvn_samples(10000)
        sample_covar = samples.unsqueeze(-1).matmul(samples.unsqueeze(-2)).mean(0)
    self.assertLess(((sample_covar - actual).abs() / actual.abs().clamp(1, 1e5)).max().item(), 4e-1)

def test_batch_getitem(self):
    # 2d
    diag_lv = DiagLazyTensor(diag.repeat(5, 1))
    diag_ev = diag_lv.evaluate()
    self.assertTrue(torch.equal(diag_lv[0, 0:2].evaluate(), diag_ev[0, 0:2]))
    self.assertTrue(torch.equal(diag_lv[0, 0:2, :3].evaluate(), diag_ev[0, 0:2, :3]))
    self.assertTrue(torch.equal(diag_lv[:, 0:2, :3].evaluate(), diag_ev[:, 0:2, :3]))

def forward(self, x1, x2):
    if self.training and torch.equal(x1, x2):
        # Reshape into a batch of batch_size diagonal matrices, each of which is
        # (data_size * task_size) x (data_size * task_size)
        return DiagLazyTensor(self.variances.view(self.variances.size(0), -1))
    elif x1.size(-2) == x2.size(-2) and x1.size(-2) == self.variances.size(1) and torch.equal(x1, x2):
        return DiagLazyTensor(self.variances.view(self.variances.size(0), -1))
    else:
        return ZeroLazyTensor(x1.size(-3), x1.size(-2), x2.size(-2))

def _covar_diag(self, inputs):
    if inputs.ndimension() == 1:
        inputs = inputs.unsqueeze(1)
    # Get diagonal of covar
    covar_diag = delazify(self.base_kernel(inputs, diag=True))
    return DiagLazyTensor(covar_diag)

def test_log_prob(self):
    mean = torch.randn(4)
    var = torch.randn(4).abs_()
    values = torch.randn(4)

    res = MultivariateNormal(mean, DiagLazyTensor(var)).log_prob(values)
    actual = TMultivariateNormal(mean, torch.eye(4) * var).log_prob(values)
    self.assertLess((res - actual).div(res).abs().item(), 1e-2)

    mean = torch.randn(3, 4)
    var = torch.randn(3, 4).abs_()
    values = torch.randn(3, 4)

    res = MultivariateNormal(mean, DiagLazyTensor(var)).log_prob(values)
    actual = TMultivariateNormal(mean, var.unsqueeze(-1) * torch.eye(4).repeat(3, 1, 1)).log_prob(values)
    self.assertLess((res - actual).div(res).abs().norm(), 1e-2)

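# A minimal sketch (illustration only): for a diagonal covariance the MVN log-density
# reduces to a sum of univariate terms,
#   log p(x) = -0.5 * (n*log(2*pi) + sum(log var) + sum((x - mu)^2 / var)),
# which is what the lazy DiagLazyTensor path computes without forming the n x n matrix.
import math
import torch
from gpytorch.distributions import MultivariateNormal
from gpytorch.lazy import DiagLazyTensor

mean, var = torch.randn(4), torch.rand(4) + 0.1
x = torch.randn(4)
res = MultivariateNormal(mean, DiagLazyTensor(var)).log_prob(x)
closed_form = -0.5 * (4 * math.log(2 * math.pi) + var.log().sum() + ((x - mean) ** 2 / var).sum())
assert torch.allclose(res, closed_form, atol=1e-4)
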
def untransform_posterior(self, posterior: Posterior) -> Posterior:
    r"""Un-standardize the posterior.

    Args:
        posterior: A posterior in the standardized space.

    Returns:
        The un-standardized posterior. If the input posterior is an MVN,
        the transformed posterior is again an MVN.
    """
    if self._outputs is not None:
        raise NotImplementedError(
            "Standardize does not yet support output selection for untransform_posterior"
        )
    if not self._m == posterior.event_shape[-1]:
        raise RuntimeError(
            "Incompatible output dimensions encountered for transform "
            f"{self._m} and posterior {posterior.event_shape[-1]}"
        )
    if not isinstance(posterior, GPyTorchPosterior):
        # fall back to TransformedPosterior
        return TransformedPosterior(
            posterior=posterior,
            sample_transform=lambda s: self.means + self.stdvs * s,
            mean_transform=lambda m, v: self.means + self.stdvs * m,
            variance_transform=lambda m, v: self._stdvs_sq * v,
        )
    # GPyTorchPosterior (TODO: Should we lazy-evaluate the mean here as well?)
    mvn = posterior.mvn
    offset = self.means
    scale_fac = self.stdvs
    if not posterior._is_mt:
        mean_tf = offset.squeeze(-1) + scale_fac.squeeze(-1) * mvn.mean
        scale_fac = scale_fac.squeeze(-1).expand_as(mean_tf)
    else:
        mean_tf = offset + scale_fac * mvn.mean
        reps = mean_tf.shape[-2:].numel() // scale_fac.size(-1)
        scale_fac = scale_fac.squeeze(-2)
        if mvn._interleaved:
            scale_fac = scale_fac.repeat(*[1 for _ in scale_fac.shape[:-1]], reps)
        else:
            scale_fac = torch.repeat_interleave(scale_fac, reps, dim=-1)

    if (
        not mvn.islazy
        # TODO: Figure out attribute naming weirdness here
        or mvn._MultivariateNormal__unbroadcasted_scale_tril is not None
    ):
        # if scale_tril is already computed, reusing it saves a lot of time
        covar_tf = CholLazyTensor(mvn.scale_tril * scale_fac.unsqueeze(-1))
    else:
        lcv = mvn.lazy_covariance_matrix
        # allow batch-evaluation of the model
        scale_mat = DiagLazyTensor(scale_fac.expand(lcv.shape[:-1]))
        covar_tf = scale_mat @ lcv @ scale_mat

    kwargs = {"interleaved": mvn._interleaved} if posterior._is_mt else {}
    mvn_tf = mvn.__class__(mean=mean_tf, covariance_matrix=covar_tf, **kwargs)
    return GPyTorchPosterior(mvn_tf)

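# A minimal sketch (illustration only, not BoTorch code) of the identity the method
# above applies on the lazy branch: if y = a + S x with S = diag(s) and
# x ~ N(mu, Sigma), then y ~ N(a + S mu, S Sigma S). The lazy product
# DiagLazyTensor @ covariance @ DiagLazyTensor matches the dense S Sigma S.
import torch
from gpytorch.lazy import DiagLazyTensor, NonLazyTensor

torch.manual_seed(0)
L = torch.randn(3, 3)
Sigma = L @ L.t() + torch.eye(3)                           # SPD covariance
s = torch.rand(3) + 0.5
scale_mat = DiagLazyTensor(s)
covar_tf = scale_mat @ NonLazyTensor(Sigma) @ scale_mat    # lazy S Sigma S
assert torch.allclose(covar_tf.evaluate(), torch.diag(s) @ Sigma @ torch.diag(s), atol=1e-5)
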
def create_lazy_tensor(self):
    tensor = torch.randn(3, 5, 5)
    tensor = tensor.transpose(-1, -2).matmul(tensor).detach()
    diag = torch.tensor(
        [[1.0, 2.0, 4.0, 2.0, 3.0], [2.0, 1.0, 2.0, 1.0, 4.0], [1.0, 2.0, 2.0, 3.0, 4.0]],
        requires_grad=True,
    )
    return AddedDiagLazyTensor(NonLazyTensor(tensor), DiagLazyTensor(diag))

def forward(self, X, Y=None):
    # Propagate samples through the layers.
    F = X
    for i, gp in enumerate(self.gps):
        if isinstance(F, MultivariateNormal):
            # Sample from independent Gaussian marginals
            m = F.mean.reshape(-1, gp.input_dims)
            s = F.stddev.reshape(-1, gp.input_dims)
            F = Normal(loc=m, scale=s).rsample()
        if i > 0 and self.add_input:
            # Add input to layers after the first one
            assert F.shape == X.shape
            F.add_(X)
        # Get posterior distribution
        F = gp(F, Y=Y, likelihood=self.likelihood)

    if isinstance(self.likelihood, MultitaskGaussianLikelihood):
        D = self.likelihood.num_tasks
    else:
        D = 1

    # If our likelihood is one-dimensional but we are outputting multiple
    # dimensions, we sum the dimensions together
    if isinstance(F, MultitaskMultivariateNormal) and D == 1:
        mean = F.mean.sum(axis=-1)
        variance = F.variance.sum(axis=-1)
        F = MultivariateNormal(mean, DiagLazyTensor(variance))
    return F

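# A small sketch (illustration only) of the reparameterized sampling used between the
# layers above: Normal(...).rsample() draws m + s * eps, so gradients flow back to the
# Gaussian's parameters, which is what lets the deep GP train end-to-end.
import torch
from torch.distributions import Normal

m = torch.randn(5, 2, requires_grad=True)
s = torch.rand(5, 2) + 0.1
f = Normal(loc=m, scale=s).rsample()   # differentiable sample
f.sum().backward()
assert torch.allclose(m.grad, torch.ones_like(m))   # d(sum(m + s*eps))/dm = 1
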
def block_logdet(self, var, cov_mat_root):
    var = flatten(var)
    cov_mat_lt = RootLazyTensor(cov_mat_root.t())
    var_lt = DiagLazyTensor(var + 1e-6)  # small jitter for numerical stability
    covar_lt = AddedDiagLazyTensor(var_lt, cov_mat_lt)
    return covar_lt.log_det()

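# A minimal sketch (illustration only) of the structure block_logdet exploits:
# for D diagonal and a low-rank root R, the matrix determinant lemma gives
#   log|D + R R^T| = log|I + R^T D^{-1} R| + log|D|,
# so the logdet of the large covariance needs only a small k x k determinant.
import torch

n, k = 50, 3
d = torch.rand(n) + 0.5
R = torch.randn(n, k)
full = torch.logdet(torch.diag(d) + R @ R.t())
lemma = torch.logdet(torch.eye(k) + R.t() @ (R / d.unsqueeze(-1))) + d.log().sum()
assert torch.allclose(full, lemma, atol=1e-2)
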
def prior_distribution(self):
    zeros = torch.zeros(
        self._variational_distribution.shape(),
        dtype=self._variational_distribution.dtype,
        device=self._variational_distribution.device,
    )
    ones = torch.ones_like(zeros)
    res = MultivariateNormal(zeros, DiagLazyTensor(ones))
    return res

def test_unsupported_dimension(self):
    sampler = SobolQMCNormalSampler(num_samples=2)
    mean = torch.zeros(1112)
    cov = DiagLazyTensor(torch.ones(1112))
    mvn = MultivariateNormal(mean, cov)
    posterior = GPyTorchPosterior(mvn)
    with self.assertRaises(UnsupportedError) as e:
        sampler(posterior)
    # the assertion must sit outside the `with` block, or it would never run
    self.assertIn("Requested: 1112", str(e.exception))

def test_batch_function_factory(self):
    # 2d
    diag_var1 = diag.repeat(5, 1).requires_grad_(True)
    diag_var2 = diag.repeat(5, 1).requires_grad_(True)
    test_mat = torch.eye(3).repeat(5, 1, 1)

    diag_lv = DiagLazyTensor(diag_var1)
    diag_ev = DiagLazyTensor(diag_var2).evaluate()

    # Forward
    res = diag_lv.matmul(test_mat)
    actual = torch.matmul(diag_ev, test_mat)
    self.assertLess(torch.norm(res - actual), 1e-4)

    # Backward
    res.sum().backward()
    actual.sum().backward()
    self.assertLess(torch.norm(diag_var1.grad - diag_var2.grad), 1e-3)

def test_log_prob(self):
    mean = torch.randn(4, 3)
    var = torch.randn(12).abs_()
    values = mean + 0.5
    diffs = (values - mean).view(-1)

    res = MultitaskMultivariateNormal(mean, DiagLazyTensor(var)).log_prob(values)
    actual = -0.5 * (math.log(math.pi * 2) * 12 + var.log().sum() + (diffs / var * diffs).sum())
    self.assertLess((res - actual).div(res).abs().item(), 1e-2)

    mean = torch.randn(3, 4, 3)
    var = torch.randn(3, 12).abs_()
    values = mean + 0.5
    diffs = (values - mean).view(3, -1)

    res = MultitaskMultivariateNormal(mean, DiagLazyTensor(var)).log_prob(values)
    actual = -0.5 * (math.log(math.pi * 2) * 12 + var.log().sum(-1) + (diffs / var * diffs).sum(-1))
    self.assertLess((res - actual).div(res).abs().norm(), 1e-2)

def test_unsupported_dimension(self):
    sampler = SobolQMCNormalSampler(num_samples=2)
    maxdim = torch.quasirandom.SobolEngine.MAXDIM + 1
    mean = torch.zeros(maxdim)
    cov = DiagLazyTensor(torch.ones(maxdim))
    mvn = MultivariateNormal(mean, cov)
    posterior = GPyTorchPosterior(mvn)
    with self.assertRaises(UnsupportedError) as e:
        sampler(posterior)
    # the assertion must sit outside the `with` block, or it would never run
    self.assertIn(f"Requested: {maxdim}", str(e.exception))

def _matmul(self, rhs):
    # We decompose the result into two parts: diag_res and off_diag_res.
    # Approximately:
    #   diag_res is the result of eye(p) @ rhs
    #   off_diag_res is the result of diag(self._off_diag) @ rhs
    # To get the correct result, all matrices in these two products may be
    # masked or concatenated with rows or columns of zeros depending on
    # self.square and self.upper.
    from gpytorch.lazy import DiagLazyTensor

    is_vector = rhs.ndimension() == 1
    if is_vector:
        rhs = rhs.unsqueeze(-1)

    alldim = slice(None, None, None)  # alias to shorten code
    batch_size = rhs.shape[:-2]
    batch_slices = tuple(alldim for _ in range(len(batch_size)))

    # Off-diagonal part
    if self.upper:
        extract_off_diag = (slice(1, None, None), alldim)
    elif self.square:
        extract_off_diag = (slice(None, -1, None), alldim)
    else:
        extract_off_diag = (alldim, alldim)
    extract_off_diag = batch_slices + extract_off_diag
    off_diag_rhs = rhs[extract_off_diag]
    off_diag_res = DiagLazyTensor(self._off_diag).matmul(off_diag_rhs)
    if self.square or not self.upper:
        zero_row = torch.zeros(*batch_size, 1, rhs.size(-1), dtype=rhs.dtype, device=rhs.device)
        to_cat = (off_diag_res, zero_row) if self.upper else (zero_row, off_diag_res)
        off_diag_res = torch.cat(to_cat, dim=-2)

    # Diagonal part
    if self.square:
        diag_res = rhs
    elif self.upper:
        extract_diag = batch_slices + (slice(None, -1, None), alldim)
        diag_res = rhs[extract_diag]
    else:
        zero_row = torch.zeros(*batch_size, 1, rhs.size(-1), dtype=rhs.dtype, device=rhs.device)
        diag_res = torch.cat((rhs, zero_row), dim=-2)

    res = diag_res + off_diag_res
    if is_vector:
        res = res.squeeze(-1)
    return res

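# A minimal standalone sketch (illustration only) of the decomposition _matmul
# implements, for the square, lower (upper=False) case: a unit-diagonal bidiagonal
# matrix B acting on rhs equals rhs plus the off-diagonal weights applied to a
# shifted slice of rhs, with a zero row prepended.
import torch

p = 5
off = torch.randn(p - 1)
B = torch.eye(p)
B[torch.arange(1, p), torch.arange(p - 1)] = off   # lower bidiagonal: square=True, upper=False
rhs = torch.randn(p, 2)
zero_row = torch.zeros(1, 2)
off_diag_res = torch.cat((zero_row, off.unsqueeze(-1) * rhs[:-1]), dim=-2)
assert torch.allclose(B @ rhs, rhs + off_diag_res)
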
def test_precond_solve(self):
    seed = 4
    torch.random.manual_seed(seed)

    tensor = torch.randn(1000, 800)
    diag = torch.abs(torch.randn(1000))
    standard_lt = AddedDiagLazyTensor(RootLazyTensor(tensor), DiagLazyTensor(diag))
    evals, evecs = standard_lt.symeig(eigenvectors=True)

    # this preconditioner is a simple example of near deflation
    def nonstandard_preconditioner(self):
        top_100_evecs = evecs[:, :100]
        top_100_evals = evals[:100] + 0.2 * torch.randn(100)
        precond_lt = RootLazyTensor(top_100_evecs @ torch.diag(top_100_evals ** 0.5))
        logdet = top_100_evals.log().sum()

        def precond_closure(rhs):
            rhs2 = top_100_evecs.t() @ rhs
            return top_100_evecs @ torch.diag(1.0 / top_100_evals) @ rhs2

        return precond_closure, precond_lt, logdet

    overrode_lt = AddedDiagLazyTensor(
        RootLazyTensor(tensor), DiagLazyTensor(diag), preconditioner_override=nonstandard_preconditioner
    )

    # compute a solve - mostly to make sure that we can actually perform the solve
    rhs = torch.randn(1000, 1)
    standard_solve = standard_lt.inv_matmul(rhs)
    overrode_solve = overrode_lt.inv_matmul(rhs)

    # sanity-check that our preconditioner is not breaking anything
    self.assertEqual(standard_solve.shape, overrode_solve.shape)
    self.assertLess(torch.norm(standard_solve - overrode_solve) / standard_solve.norm(), 1.0)

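# A minimal sketch (illustration only) of why a closure like the one above acts as
# an approximate inverse: for eigenpairs (V_k, lambda_k) of A, the rank-k operator
# V_k diag(1/lambda_k) V_k^T inverts A exactly on span(V_k).
import torch

torch.manual_seed(0)
M = torch.randn(50, 50)
A = M @ M.t() + 50 * torch.eye(50)
evals, evecs = torch.linalg.eigh(A)            # ascending eigenvalues
Vk, lk = evecs[:, -10:], evals[-10:]           # 10 largest eigenpairs
x = Vk @ torch.randn(10)                       # a vector in span(Vk)
precond = lambda r: Vk @ ((Vk.t() @ r) / lk)   # rank-k approximate inverse
assert torch.allclose(precond(A @ x), x, atol=1e-2)
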
def create_lazy_tensor(self):
    a = torch.tensor([[4, 0, 2], [0, 3, -1], [2, -1, 3]], dtype=torch.float)
    b = torch.tensor([[2, 1], [1, 2]], dtype=torch.float)
    c = torch.tensor(
        [[4, 0.5, 1, 0], [0.5, 4, -1, 0], [1, -1, 3, 0], [0, 0, 0, 4]], dtype=torch.float
    )
    d = 0.5 * torch.rand(24, dtype=torch.float)
    a.requires_grad_(True)
    b.requires_grad_(True)
    c.requires_grad_(True)
    d.requires_grad_(True)
    kp_lazy_tensor = KroneckerProductLazyTensor(NonLazyTensor(a), NonLazyTensor(b), NonLazyTensor(c))
    diag_lazy_tensor = DiagLazyTensor(d)
    return KroneckerProductAddedDiagLazyTensor(kp_lazy_tensor, diag_lazy_tensor)

def test_kl_divergence(self):
    mean0 = torch.randn(4)
    mean1 = mean0 + 1
    var0 = torch.randn(4).abs_()
    var1 = var0 * math.exp(2)
    dist_a = MultivariateNormal(mean0, DiagLazyTensor(var0))
    dist_b = MultivariateNormal(mean1, DiagLazyTensor(var0))
    dist_c = MultivariateNormal(mean0, DiagLazyTensor(var1))

    res = torch.distributions.kl.kl_divergence(dist_a, dist_a)
    actual = 0.0
    self.assertLess((res - actual).abs().item(), 1e-2)

    res = torch.distributions.kl.kl_divergence(dist_b, dist_a)
    actual = var0.reciprocal().sum().div(2.0)
    self.assertLess((res - actual).div(res).abs().item(), 1e-2)

    res = torch.distributions.kl.kl_divergence(dist_a, dist_c)
    actual = 0.5 * (8 - 4 + 4 * math.exp(-2))
    self.assertLess((res - actual).div(res).abs().item(), 1e-2)

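# A minimal sketch (illustration only) of the closed form the expected values above
# come from: for diagonal Gaussians p = N(mu0, diag(v0)) and q = N(mu1, diag(v1)),
#   KL(p || q) = 0.5 * sum( v0/v1 + (mu1 - mu0)^2 / v1 - 1 + log(v1/v0) ).
# With mu1 = mu0 + 1 and v1 = v0 this reduces to 0.5 * sum(1/v0), the second check
# in the test. Verified here with plain torch.distributions.
import torch

mu0 = torch.randn(4)
v0 = torch.rand(4) + 0.1
p = torch.distributions.Normal(mu0, v0.sqrt())
q = torch.distributions.Normal(mu0 + 1, v0.sqrt())
kl = torch.distributions.kl.kl_divergence(p, q).sum()   # independent dims add up
assert torch.allclose(kl, 0.5 * v0.reciprocal().sum(), atol=1e-5)
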
def test_multitask_from_repeat(self):
    mean = torch.randn(2, 3)
    variance = torch.randn(2, 3).clamp_min(1e-6)
    mvn = MultivariateNormal(mean, DiagLazyTensor(variance))
    mmvn = MultitaskMultivariateNormal.from_repeated_mvn(mvn, num_tasks=4)
    self.assertTrue(isinstance(mmvn, MultitaskMultivariateNormal))
    self.assertEqual(mmvn.batch_shape, torch.Size([2]))
    self.assertEqual(mmvn.event_shape, torch.Size([3, 4]))
    self.assertEqual(mmvn.covariance_matrix.shape, torch.Size([2, 12, 12]))
    for i in range(4):
        # compare tensors with torch.equal; assertEqual cannot handle multi-element tensors
        self.assertTrue(torch.equal(mmvn.mean[..., i], mean))
        self.assertTrue(torch.equal(mmvn.variance[..., i], variance))

def compute_ll_for_block(self, vec, mean, var, cov_mat_root):
    vec = flatten(vec)
    mean = flatten(mean)
    var = flatten(var)

    cov_mat_lt = RootLazyTensor(cov_mat_root.t())
    var_lt = DiagLazyTensor(var + 1e-6)
    covar_lt = AddedDiagLazyTensor(var_lt, cov_mat_lt)
    qdist = MultivariateNormal(mean, covar_lt)

    # NB: `with A and B:` only enters B (`and` evaluates to its second operand),
    # so both settings must be listed in a single comma-separated `with` statement.
    with gpytorch.settings.num_trace_samples(1), gpytorch.settings.max_cg_iterations(25):
        return qdist.log_prob(vec)

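# A tiny standalone demonstration (not part of the original code) of the pitfall
# fixed above: `A and B` evaluates to B, so only B's __enter__/__exit__ ever run.
from contextlib import contextmanager

entered = []

@contextmanager
def setting(name):
    entered.append(name)   # runs on __enter__, not on creation
    yield

with setting("a") and setting("b"):   # `and` returns its second operand
    pass
assert entered == ["b"]               # "a" was created but never entered

with setting("a"), setting("b"):      # the correct form enters both
    pass
assert entered == ["b", "a", "b"]
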
def test_multitask_from_batch(self):
    mean = torch.randn(2, 3)
    variance = torch.randn(2, 3).clamp_min(1e-6)
    mvn = MultivariateNormal(mean, DiagLazyTensor(variance))
    mmvn = MultitaskMultivariateNormal.from_batch_mvn(mvn, task_dim=-1)
    self.assertTrue(isinstance(mmvn, MultitaskMultivariateNormal))
    self.assertEqual(mmvn.batch_shape, torch.Size([]))
    self.assertEqual(mmvn.event_shape, torch.Size([3, 2]))
    self.assertEqual(mmvn.covariance_matrix.shape, torch.Size([6, 6]))
    # compare tensors with torch.equal; assertEqual cannot handle multi-element tensors
    self.assertTrue(torch.equal(mmvn.mean, mean.transpose(-1, -2)))
    self.assertTrue(torch.equal(mmvn.variance, variance.transpose(-1, -2)))

    mean = torch.randn(2, 4, 3)
    variance = torch.randn(2, 4, 3).clamp_min(1e-6)
    mvn = MultivariateNormal(mean, DiagLazyTensor(variance))
    mmvn = MultitaskMultivariateNormal.from_batch_mvn(mvn, task_dim=0)
    self.assertTrue(isinstance(mmvn, MultitaskMultivariateNormal))
    self.assertEqual(mmvn.batch_shape, torch.Size([4]))
    self.assertEqual(mmvn.event_shape, torch.Size([3, 2]))
    self.assertEqual(mmvn.covariance_matrix.shape, torch.Size([4, 6, 6]))
    self.assertTrue(torch.equal(mmvn.mean, mean.permute(1, 2, 0)))
    self.assertTrue(torch.equal(mmvn.variance, variance.permute(1, 2, 0)))

def forward(self, indices=None):
    """
    Return the variational posterior for the latent variables
    pertaining to the provided indices.
    """
    if indices is None:
        ms = self.variational_mean
        vs = self.variational_variance
    else:
        ms = self.variational_mean[indices]
        vs = self.variational_variance[indices]
    vs = vs.expand(len(vs), self.output_dims)
    if self.output_dims == 1:
        (m,) = ms
        (v,) = vs
        return MultivariateNormal(m, DiagLazyTensor(v))
    else:
        mvns = [MultivariateNormal(m, DiagLazyTensor(v)) for m, v in zip(ms.T, vs.T)]
        return MultitaskMultivariateNormal.from_independent_mvns(mvns)

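# A minimal usage sketch (illustration only): from_independent_mvns, as called above,
# stacks per-task MVNs into one multitask distribution whose event shape is
# (n, num_tasks), consistent with the shape checks in test_from_independent_mvns.
import torch
from gpytorch.distributions import MultivariateNormal, MultitaskMultivariateNormal
from gpytorch.lazy import DiagLazyTensor

n = 3
mvns = [MultivariateNormal(torch.randn(n), DiagLazyTensor(torch.rand(n) + 0.1)) for _ in range(2)]
mt = MultitaskMultivariateNormal.from_independent_mvns(mvns)
assert mt.event_shape == torch.Size([n, 2])
assert mt.mean.shape == torch.Size([n, 2])
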
def test_kl_divergence(self, cuda=False):
    device = torch.device("cuda") if cuda else torch.device("cpu")
    for dtype in (torch.float, torch.double):
        mean0 = torch.randn(4, device=device, dtype=dtype)
        mean1 = mean0 + 1
        var0 = torch.randn(4, device=device, dtype=dtype).abs_()
        var1 = var0 * math.exp(2)
        dist_a = MultivariateNormal(mean0, DiagLazyTensor(var0))
        dist_b = MultivariateNormal(mean1, DiagLazyTensor(var0))
        dist_c = MultivariateNormal(mean0, DiagLazyTensor(var1))

        res = torch.distributions.kl.kl_divergence(dist_a, dist_a)
        actual = 0.0
        self.assertLess((res - actual).abs().item(), 1e-2)

        res = torch.distributions.kl.kl_divergence(dist_b, dist_a)
        actual = var0.reciprocal().sum().div(2.0)
        self.assertLess((res - actual).div(res).abs().item(), 1e-2)

        res = torch.distributions.kl.kl_divergence(dist_a, dist_c)
        actual = 0.5 * (8 - 4 + 4 * math.exp(-2))
        self.assertLess((res - actual).div(res).abs().item(), 1e-2)