def test_multivariate_normal_batch_non_lazy(self, cuda=False):
    device = torch.device("cuda") if cuda else torch.device("cpu")
    mean = torch.tensor([0, 1, 2], dtype=torch.float, device=device)
    covmat = torch.diag(torch.tensor([1, 0.75, 1.5], device=device))
    mvn = MultivariateNormal(
        mean=mean.repeat(2, 1),
        covariance_matrix=covmat.repeat(2, 1, 1),
        validate_args=True,
    )
    self.assertTrue(torch.is_tensor(mvn.covariance_matrix))
    self.assertIsInstance(mvn.lazy_covariance_matrix, LazyTensor)
    self.assertTrue(approx_equal(mvn.variance, covmat.diag().repeat(2, 1)))
    self.assertTrue(approx_equal(mvn.scale_tril, torch.diag(covmat.diag().sqrt()).repeat(2, 1, 1)))

    # Shifting and scaling the distribution should act on mean and covariance accordingly
    mvn_plus1 = mvn + 1
    self.assertTrue(torch.equal(mvn_plus1.mean, mvn.mean + 1))
    self.assertTrue(torch.equal(mvn_plus1.covariance_matrix, mvn.covariance_matrix))
    mvn_times2 = mvn * 2
    self.assertTrue(torch.equal(mvn_times2.mean, mvn.mean * 2))
    self.assertTrue(torch.equal(mvn_times2.covariance_matrix, mvn.covariance_matrix * 4))
    mvn_divby2 = mvn / 2
    self.assertTrue(torch.equal(mvn_divby2.mean, mvn.mean / 2))
    self.assertTrue(torch.equal(mvn_divby2.covariance_matrix, mvn.covariance_matrix / 4))

    self.assertTrue(approx_equal(mvn.entropy(), 4.3157 * torch.ones(2, device=device)))
    logprob = mvn.log_prob(torch.zeros(2, 3, device=device))
    logprob_expected = -4.8157 * torch.ones(2, device=device)
    self.assertTrue(approx_equal(logprob, logprob_expected))
    logprob = mvn.log_prob(torch.zeros(2, 2, 3, device=device))
    logprob_expected = -4.8157 * torch.ones(2, 2, device=device)
    self.assertTrue(approx_equal(logprob, logprob_expected))

    conf_lower, conf_upper = mvn.confidence_region()
    self.assertTrue(approx_equal(conf_lower, mvn.mean - 2 * mvn.stddev))
    self.assertTrue(approx_equal(conf_upper, mvn.mean + 2 * mvn.stddev))

    # Sample shapes: sample_shape + batch_shape (2,) + event_shape (3,)
    self.assertTrue(mvn.sample().shape == torch.Size([2, 3]))
    self.assertTrue(mvn.sample(torch.Size([2])).shape == torch.Size([2, 2, 3]))
    self.assertTrue(mvn.sample(torch.Size([2, 4])).shape == torch.Size([2, 4, 2, 3]))
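# Where the constants 4.3157 and -4.8157 used throughout these tests come from
# (a quick check using the Gaussian closed forms; this snippet is illustrative,
# not part of the test suite): entropy = 0.5 * log((2*pi*e)^d * |Sigma|) and
# log p(0) = -0.5 * (mu^T Sigma^-1 mu + log|Sigma| + d * log(2*pi)).
import math

var = [1.0, 0.75, 1.5]                            # diagonal of Sigma
mu = [0.0, 1.0, 2.0]
logdet = sum(math.log(v) for v in var)            # log|Sigma| ~= 0.1178
quad = sum(m * m / v for m, v in zip(mu, var))    # mu^T Sigma^-1 mu = 4.0
entropy = 0.5 * (3 * math.log(2 * math.pi * math.e) + logdet)  # ~= 4.3157
logprob0 = -0.5 * (quad + logdet + 3 * math.log(2 * math.pi))  # ~= -4.8157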
def test_multivariate_normal_lazy(self, cuda=False):
    device = torch.device("cuda") if cuda else torch.device("cpu")
    mean = torch.tensor([0, 1, 2], dtype=torch.float, device=device)
    covmat = torch.diag(torch.tensor([1, 0.75, 1.5], device=device))
    mvn = MultivariateNormal(mean=mean, covariance_matrix=NonLazyTensor(covmat))
    self.assertTrue(torch.is_tensor(mvn.covariance_matrix))
    self.assertIsInstance(mvn.lazy_covariance_matrix, LazyTensor)
    self.assertTrue(torch.equal(mvn.variance, torch.diag(covmat)))
    self.assertTrue(torch.equal(mvn.covariance_matrix, covmat))

    mvn_plus1 = mvn + 1
    self.assertTrue(torch.equal(mvn_plus1.mean, mvn.mean + 1))
    self.assertTrue(torch.equal(mvn_plus1.covariance_matrix, mvn.covariance_matrix))
    mvn_times2 = mvn * 2
    self.assertTrue(torch.equal(mvn_times2.mean, mvn.mean * 2))
    self.assertTrue(torch.equal(mvn_times2.covariance_matrix, mvn.covariance_matrix * 4))
    mvn_divby2 = mvn / 2
    self.assertTrue(torch.equal(mvn_divby2.mean, mvn.mean / 2))
    self.assertTrue(torch.equal(mvn_divby2.covariance_matrix, mvn.covariance_matrix / 4))

    # TODO: Add tests for entropy, log_prob, etc. This is an issue b/c those
    # methods use root_decomposition, which is not very reliable.
    # self.assertAlmostEqual(mvn.entropy().item(), 4.3157, places=4)
    # self.assertAlmostEqual(mvn.log_prob(torch.zeros(3)).item(), -4.8157, places=4)
    # self.assertTrue(
    #     approx_equal(mvn.log_prob(torch.zeros(2, 3)), -4.8157 * torch.ones(2))
    # )

    conf_lower, conf_upper = mvn.confidence_region()
    self.assertTrue(approx_equal(conf_lower, mvn.mean - 2 * mvn.stddev))
    self.assertTrue(approx_equal(conf_upper, mvn.mean + 2 * mvn.stddev))

    self.assertTrue(mvn.sample().shape == torch.Size([3]))
    self.assertTrue(mvn.sample(torch.Size([2])).shape == torch.Size([2, 3]))
    self.assertTrue(mvn.sample(torch.Size([2, 4])).shape == torch.Size([2, 4, 3]))
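# A minimal sketch (not part of the test above) of one way to exercise log_prob
# with a lazy covariance deterministically: the gpytorch.settings.fast_computations
# context manager can route log_prob through Cholesky instead of the stochastic
# root decomposition. Assumes a recent GPyTorch; the tolerance is illustrative.
import torch
import gpytorch
from gpytorch.distributions import MultivariateNormal
from gpytorch.lazy import NonLazyTensor

mean = torch.tensor([0.0, 1.0, 2.0])
covmat = torch.diag(torch.tensor([1.0, 0.75, 1.5]))
mvn = MultivariateNormal(mean, NonLazyTensor(covmat))
with gpytorch.settings.fast_computations(log_prob=False):
    assert abs(mvn.log_prob(torch.zeros(3)).item() + 4.8157) < 1e-3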
def test_multivariate_normal_batch_correlated_samples(self, cuda=False):
    device = torch.device("cuda") if cuda else torch.device("cpu")
    mean = torch.tensor([0, 1, 2], dtype=torch.float, device=device)
    covmat = torch.diag(torch.tensor([1, 0.75, 1.5], device=device))
    mvn = MultivariateNormal(
        mean=mean.repeat(2, 1),
        covariance_matrix=NonLazyTensor(covmat).repeat(2, 1, 1),
    )
    base_samples = mvn.get_base_samples(torch.Size((3, 4)))
    self.assertTrue(mvn.sample(base_samples=base_samples).shape == torch.Size([3, 4, 2, 3]))
    base_samples = mvn.get_base_samples()
    self.assertTrue(mvn.sample(base_samples=base_samples).shape == torch.Size([2, 3]))
def test_multivariate_normal_batch_lazy(self, cuda=False):
    device = torch.device("cuda") if cuda else torch.device("cpu")
    for dtype in (torch.float, torch.double):
        mean = torch.tensor([0, 1, 2], device=device, dtype=dtype).repeat(2, 1)
        covmat = torch.diag(torch.tensor([1, 0.75, 1.5], device=device, dtype=dtype)).repeat(2, 1, 1)
        covmat_chol = torch.linalg.cholesky(covmat)
        mvn = MultivariateNormal(mean=mean, covariance_matrix=NonLazyTensor(covmat))
        self.assertTrue(torch.is_tensor(mvn.covariance_matrix))
        self.assertIsInstance(mvn.lazy_covariance_matrix, LazyTensor)
        self.assertAllClose(mvn.variance, torch.diagonal(covmat, dim1=-2, dim2=-1))
        self.assertAllClose(mvn._unbroadcasted_scale_tril, covmat_chol)

        mvn_plus1 = mvn + 1
        self.assertAllClose(mvn_plus1.mean, mvn.mean + 1)
        self.assertAllClose(mvn_plus1.covariance_matrix, mvn.covariance_matrix)
        self.assertAllClose(mvn_plus1._unbroadcasted_scale_tril, covmat_chol)
        mvn_times2 = mvn * 2
        self.assertAllClose(mvn_times2.mean, mvn.mean * 2)
        self.assertAllClose(mvn_times2.covariance_matrix, mvn.covariance_matrix * 4)
        self.assertAllClose(mvn_times2._unbroadcasted_scale_tril, covmat_chol * 2)
        mvn_divby2 = mvn / 2
        self.assertAllClose(mvn_divby2.mean, mvn.mean / 2)
        self.assertAllClose(mvn_divby2.covariance_matrix, mvn.covariance_matrix / 4)
        self.assertAllClose(mvn_divby2._unbroadcasted_scale_tril, covmat_chol / 2)

        # TODO: Add tests for entropy, log_prob, etc. This is an issue b/c those
        # methods use root_decomposition, which is not very reliable.
        # self.assertTrue(torch.allclose(mvn.entropy(), 4.3157 * torch.ones(2)))
        # self.assertTrue(
        #     torch.allclose(mvn.log_prob(torch.zeros(2, 3)), -4.8157 * torch.ones(2))
        # )
        # self.assertTrue(
        #     torch.allclose(mvn.log_prob(torch.zeros(2, 2, 3)), -4.8157 * torch.ones(2, 2))
        # )

        conf_lower, conf_upper = mvn.confidence_region()
        self.assertAllClose(conf_lower, mvn.mean - 2 * mvn.stddev)
        self.assertAllClose(conf_upper, mvn.mean + 2 * mvn.stddev)

        self.assertTrue(mvn.sample().shape == torch.Size([2, 3]))
        self.assertTrue(mvn.sample(torch.Size([2])).shape == torch.Size([2, 2, 3]))
        self.assertTrue(mvn.sample(torch.Size([2, 4])).shape == torch.Size([2, 4, 2, 3]))
def test_multivariate_normal_non_lazy(self, cuda=False):
    device = torch.device("cuda") if cuda else torch.device("cpu")
    for dtype in (torch.float, torch.double):
        mean = torch.tensor([0, 1, 2], device=device, dtype=dtype)
        covmat = torch.diag(torch.tensor([1, 0.75, 1.5], device=device, dtype=dtype))
        mvn = MultivariateNormal(mean=mean, covariance_matrix=covmat, validate_args=True)
        self.assertTrue(torch.is_tensor(mvn.covariance_matrix))
        self.assertIsInstance(mvn.lazy_covariance_matrix, LazyTensor)
        self.assertTrue(torch.allclose(mvn.variance, torch.diag(covmat)))
        self.assertTrue(torch.allclose(mvn.scale_tril, covmat.sqrt()))

        mvn_plus1 = mvn + 1
        self.assertTrue(torch.equal(mvn_plus1.mean, mvn.mean + 1))
        self.assertTrue(torch.equal(mvn_plus1.covariance_matrix, mvn.covariance_matrix))
        mvn_times2 = mvn * 2
        self.assertTrue(torch.equal(mvn_times2.mean, mvn.mean * 2))
        self.assertTrue(torch.equal(mvn_times2.covariance_matrix, mvn.covariance_matrix * 4))
        mvn_divby2 = mvn / 2
        self.assertTrue(torch.equal(mvn_divby2.mean, mvn.mean / 2))
        self.assertTrue(torch.equal(mvn_divby2.covariance_matrix, mvn.covariance_matrix / 4))

        self.assertAlmostEqual(mvn.entropy().item(), 4.3157, places=4)
        self.assertAlmostEqual(
            mvn.log_prob(torch.zeros(3, device=device, dtype=dtype)).item(), -4.8157, places=4
        )
        logprob = mvn.log_prob(torch.zeros(2, 3, device=device, dtype=dtype))
        logprob_expected = torch.tensor([-4.8157, -4.8157], device=device, dtype=dtype)
        self.assertTrue(torch.allclose(logprob, logprob_expected))

        conf_lower, conf_upper = mvn.confidence_region()
        self.assertTrue(torch.allclose(conf_lower, mvn.mean - 2 * mvn.stddev))
        self.assertTrue(torch.allclose(conf_upper, mvn.mean + 2 * mvn.stddev))

        self.assertTrue(mvn.sample().shape == torch.Size([3]))
        self.assertTrue(mvn.sample(torch.Size([2])).shape == torch.Size([2, 3]))
        self.assertTrue(mvn.sample(torch.Size([2, 4])).shape == torch.Size([2, 4, 3]))
def test_multivariate_normal_correlated_samples(self, cuda=False):
    device = torch.device("cuda") if cuda else torch.device("cpu")
    for dtype in (torch.float, torch.double):
        mean = torch.tensor([0, 1, 2], device=device, dtype=dtype)
        covmat = torch.diag(torch.tensor([1, 0.75, 1.5], device=device, dtype=dtype))
        mvn = MultivariateNormal(mean=mean, covariance_matrix=NonLazyTensor(covmat))
        base_samples = mvn.get_base_samples(torch.Size([3, 4]))
        self.assertTrue(mvn.sample(base_samples=base_samples).shape == torch.Size([3, 4, 3]))
        base_samples = mvn.get_base_samples()
        self.assertTrue(mvn.sample(base_samples=base_samples).shape == torch.Size([3]))
def forward(self, x):
    # Propagate x through the hidden layers: each layer samples `dim` functions
    # from its GP at the current inputs, and the transposed samples become the
    # next layer's inputs.
    for d in range(self.depth - 1):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        mvn = MultivariateNormal(mean_x, covar_x)
        x = mvn.sample(sample_shape=torch.Size([self.dim]))
        x = x.t()
        if self.collect:
            self.collector[d] = x.detach().numpy()
    # last layer with single output
    mean_x = self.mean_module(x)
    covar_x = self.covar_module(x)
    mvn = MultivariateNormal(mean_x, covar_x)
    x = mvn.sample()
    if self.collect:
        self.collector[self.depth - 1] = x.detach().numpy()
    return x
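# A standalone sketch (illustrative, not this repo's API) of a single hidden-layer
# step from forward() above, assuming standard GPyTorch means/kernels: draw `dim`
# joint samples from the layer's GP at the current inputs, then transpose so the
# samples become the next layer's inputs.
import torch
import gpytorch
from gpytorch.distributions import MultivariateNormal

mean_module = gpytorch.means.ZeroMean()
covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
x = torch.randn(10, 2)                         # n=10 inputs of width 2
mvn = MultivariateNormal(mean_module(x), covar_module(x))
h = mvn.sample(sample_shape=torch.Size([3]))   # (dim, n) = (3, 10)
h = h.t()                                      # (n, dim): next layer's inputs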
def test_missing_value_inference(self):
    """
    samples = mvn samples + noise samples

    In this test, we try to recover the noise parameters when some elements
    in 'samples' are missing at random.
    """
    torch.manual_seed(self.seed)
    mu = torch.zeros(2, 3)
    sigma = torch.tensor([[[1, 0.999, -0.999], [0.999, 1, -0.999], [-0.999, -0.999, 1]]] * 2).float()
    mvn = MultivariateNormal(mu, sigma)
    samples = mvn.sample(torch.Size([10000]))  # mvn samples

    noise_sd = 0.5
    noise_dist = torch.distributions.Normal(0, noise_sd)
    samples += noise_dist.sample(samples.shape)  # noise

    # Mark roughly a third of the entries as missing at random
    missing_prop = 0.33
    missing_idx = torch.distributions.Binomial(1, missing_prop).sample(samples.shape).bool()
    samples[missing_idx] = float("nan")

    likelihood = GaussianLikelihoodWithMissingObs()

    # check that the missing value fill doesn't impact the likelihood
    likelihood.MISSING_VALUE_FILL = 999.0
    like_init_plus = likelihood.log_marginal(samples, mvn).sum().data
    likelihood.MISSING_VALUE_FILL = -999.0
    like_init_minus = likelihood.log_marginal(samples, mvn).sum().data
    torch.testing.assert_allclose(like_init_plus, like_init_minus)

    # check that the correct noise sd is recovered
    opt = torch.optim.Adam(likelihood.parameters(), lr=0.05)
    for _ in range(100):
        opt.zero_grad()
        loss = -likelihood.log_marginal(samples, mvn).sum()
        loss.backward()
        opt.step()
    assert abs(float(likelihood.noise.sqrt()) - 0.5) < 0.02

    # check that log_marginal also works on a single sample
    likelihood.log_marginal(samples[0], mvn)
def test_natgrad(self, D=5):
    mu = torch.randn(D)
    cov = torch.randn(D, D).tril_()
    dist = MultivariateNormal(mu, CholLazyTensor(TriangularLazyTensor(cov)))
    sample = dist.sample()

    v_dist = NaturalVariationalDistribution(D)
    v_dist.initialize_variational_distribution(dist)
    mu = v_dist().mean.detach()
    v_dist().log_prob(sample).squeeze().backward()

    # Expectation parameters of the same Gaussian: eta1 = mu, eta2 = mu mu^T + Sigma
    eta1 = mu.clone().requires_grad_(True)
    eta2 = (mu[:, None] * mu + cov @ cov.t()).requires_grad_(True)
    L = torch.linalg.cholesky(eta2 - eta1[:, None] * eta1)
    dist2 = MultivariateNormal(eta1, CholLazyTensor(TriangularLazyTensor(L)))
    dist2.log_prob(sample).squeeze().backward()

    # The .grad fields of the natural parameters should match the gradients
    # w.r.t. the expectation parameters
    assert torch.allclose(v_dist.natural_vec.grad, eta1.grad)
    assert torch.allclose(v_dist.natural_mat.grad, eta2.grad)
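# Reading note for the assertions above (standard exponential-family duality,
# not repo documentation): for a Gaussian with natural parameters
# theta = (Sigma^-1 mu, -0.5 * Sigma^-1) and expectation parameters
# eta = (mu, mu mu^T + Sigma), the natural gradient w.r.t. theta equals the
# ordinary gradient w.r.t. eta. NaturalVariationalDistribution arranges its
# backward pass so that .grad on natural_vec/natural_mat holds exactly that
# eta-gradient, which is why it is compared against eta1.grad and eta2.grad.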
def test_natgrad(self, D=5):
    mu = torch.randn(D)
    cov = torch.randn(D, D)
    cov = cov @ cov.t()
    dist = MultivariateNormal(mu, CholLazyTensor(TriangularLazyTensor(torch.linalg.cholesky(cov))))
    sample = dist.sample()

    v_dist = TrilNaturalVariationalDistribution(D, mean_init_std=0.0)
    v_dist.initialize_variational_distribution(dist)
    v_dist().log_prob(sample).squeeze().backward()
    dout_dnat1 = v_dist.natural_vec.grad
    dout_dnat2 = v_dist.natural_tril_mat.grad

    # mean_init_std=0.0 because we need to ensure both have the same distribution
    v_dist_ref = NaturalVariationalDistribution(D, mean_init_std=0.0)
    v_dist_ref.initialize_variational_distribution(dist)
    v_dist_ref().log_prob(sample).squeeze().backward()
    dout_dnat1_noforward_ref = v_dist_ref.natural_vec.grad
    dout_dnat2_noforward_ref = v_dist_ref.natural_mat.grad

    def f(natural_vec, natural_tril_mat):
        "Transform (natural_vec, natural_mat) into the tril parameterization"
        Sigma = torch.inverse(-2 * natural_tril_mat)
        mu = natural_vec
        return mu, torch.linalg.cholesky(Sigma).inverse().tril()

    (mu_ref, natural_tril_mat_ref), (dout_dmu_ref, dout_dnat2_ref) = jvp(
        f,
        (v_dist_ref.natural_vec.detach(), v_dist_ref.natural_mat.detach()),
        (dout_dnat1_noforward_ref, dout_dnat2_noforward_ref),
    )

    assert torch.allclose(natural_tril_mat_ref, v_dist.natural_tril_mat), "Sigma transformation"
    assert torch.allclose(dout_dnat2_ref, dout_dnat2), "Sigma gradient"
    assert torch.allclose(mu_ref, v_dist.natural_vec), "mu transformation"
    assert torch.allclose(dout_dmu_ref, dout_dnat1), "mu gradient"
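# Reading note for the jvp check above (assuming `jvp` is
# torch.autograd.functional.jvp): it pushes the reference eta-space gradients
# forward through f, the map from (natural_vec, natural_mat) into the tril
# parameterization, so the resulting tangents can be compared directly with
# the .grad fields of TrilNaturalVariationalDistribution.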
def _initialize_latents(
    self,
    latent_init: str,
    num_latent_dims: List[int],
    learn_latent_pars: bool,
    device: torch.device,
    dtype: torch.dtype,
):
    self.latent_parameters = ParameterList()
    if latent_init == "default":
        # Uniform random initialization of the latent features
        for dim_num in range(len(self.covar_modules) - 1):
            self.latent_parameters.append(
                Parameter(
                    torch.rand(
                        *self._aug_batch_shape,
                        self.target_shape[dim_num],
                        num_latent_dims[dim_num],
                        device=device,
                        dtype=dtype,
                    ),
                    requires_grad=learn_latent_pars,
                )
            )
    elif latent_init == "gp":
        # Initialize the latent features with draws from a GP prior on [0, 1]
        for dim_num, covar in enumerate(self.covar_modules[1:]):
            latent_covar = covar(
                torch.linspace(0.0, 1.0, self.target_shape[dim_num], device=device, dtype=dtype)
            ).add_jitter(1e-4)
            latent_dist = MultivariateNormal(
                torch.zeros(self.target_shape[dim_num], device=device, dtype=dtype),
                latent_covar,
            )
            sample_shape = torch.Size((*self._aug_batch_shape, num_latent_dims[dim_num]))
            latent_sample = latent_dist.sample(sample_shape=sample_shape)
            latent_sample = latent_sample.reshape(
                *self._aug_batch_shape,
                self.target_shape[dim_num],
                num_latent_dims[dim_num],
            )
            self.latent_parameters.append(
                Parameter(latent_sample, requires_grad=learn_latent_pars)
            )
            self.register_prior(
                "latent_parameters_" + str(dim_num),
                MultivariateNormalPrior(
                    latent_dist.loc,
                    latent_dist.covariance_matrix.detach().clone(),
                ),
                lambda module, dim_num=dim_num: self.latent_parameters[dim_num],
            )
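# A minimal sketch of the "gp" initialization branch above, assuming a standard
# GPyTorch kernel (names and sizes here are illustrative): latent features are
# drawn from a GP prior on a uniform grid, so neighboring rows of the latent
# matrix start out smoothly correlated rather than i.i.d. uniform.
import torch
import gpytorch
from gpytorch.distributions import MultivariateNormal

kernel = gpytorch.kernels.MaternKernel()
grid = torch.linspace(0.0, 1.0, 8)                # 8 targets on [0, 1]
covar = kernel(grid).add_jitter(1e-4)             # jitter keeps the prior PD
latent_dist = MultivariateNormal(torch.zeros(8), covar)
latent = latent_dist.sample(torch.Size([2])).t()  # (8 targets, 2 latent dims)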