def test_multivariate_normal_batch_non_lazy(self, cuda=False):
    """Check a batched MultivariateNormal built from plain (dense) tensors.

    A 3-dimensional mean and a diagonal covariance are each repeated twice
    to form a batch of two distributions. The test covers the covariance,
    variance and scale_tril accessors, scalar ``+``, ``*`` and ``/``,
    entropy, log_prob, confidence_region and the shapes returned by sample.

    Args:
        cuda: If True, all tensors are created on the CUDA device,
            otherwise on the CPU.
    """
    device = torch.device("cuda" if cuda else "cpu")
    base_mean = torch.tensor([0, 1, 2], dtype=torch.float, device=device)
    base_covar = torch.diag(torch.tensor([1, 0.75, 1.5], device=device))
    mvn = MultivariateNormal(
        mean=base_mean.repeat(2, 1),
        covariance_matrix=base_covar.repeat(2, 1, 1),
        validate_args=True,
    )

    # Both the dense and the lazy view of the covariance must be available.
    self.assertTrue(torch.is_tensor(mvn.covariance_matrix))
    self.assertIsInstance(mvn.lazy_covariance_matrix, LazyTensor)

    # The covariance is diagonal, so its Cholesky factor is the
    # element-wise square root of the diagonal.
    expected_variance = base_covar.diag().repeat(2, 1)
    self.assertTrue(approx_equal(mvn.variance, expected_variance))
    expected_scale_tril = torch.diag(base_covar.diag().sqrt()).repeat(2, 1, 1)
    self.assertTrue(approx_equal(mvn.scale_tril, expected_scale_tril))

    # Adding a scalar shifts the mean and leaves the covariance untouched.
    shifted = mvn + 1
    self.assertTrue(torch.equal(shifted.mean, mvn.mean + 1))
    self.assertTrue(torch.equal(shifted.covariance_matrix, mvn.covariance_matrix))

    # Scaling by c scales the mean by c and the covariance by c ** 2.
    doubled = mvn * 2
    self.assertTrue(torch.equal(doubled.mean, mvn.mean * 2))
    self.assertTrue(torch.equal(doubled.covariance_matrix, mvn.covariance_matrix * 4))
    halved = mvn / 2
    self.assertTrue(torch.equal(halved.mean, mvn.mean / 2))
    self.assertTrue(torch.equal(halved.covariance_matrix, mvn.covariance_matrix / 4))

    expected_entropy = 4.3157 * torch.ones(2, device=device)
    self.assertTrue(approx_equal(mvn.entropy(), expected_entropy))

    # log_prob at the origin, first with one and then with two leading
    # batch dimensions on the evaluation points.
    for leading_shape in ((2,), (2, 2)):
        logprob = mvn.log_prob(torch.zeros(*leading_shape, 3, device=device))
        logprob_expected = -4.8157 * torch.ones(*leading_shape, device=device)
        self.assertTrue(approx_equal(logprob, logprob_expected))

    # The confidence region is mean +/- two standard deviations.
    conf_lower, conf_upper = mvn.confidence_region()
    self.assertTrue(approx_equal(conf_lower, mvn.mean - 2 * mvn.stddev))
    self.assertTrue(approx_equal(conf_upper, mvn.mean + 2 * mvn.stddev))

    # Sample shape is sample_shape + batch_shape + event_shape.
    draw = mvn.sample()
    self.assertTrue(draw.shape == torch.Size([2, 3]))
    draw = mvn.sample(torch.Size([2]))
    self.assertTrue(draw.shape == torch.Size([2, 2, 3]))
    draw = mvn.sample(torch.Size([2, 4]))
    self.assertTrue(draw.shape == torch.Size([2, 4, 2, 3]))
def test_multivariate_normal_lazy(self, cuda=False):
    """Check a single MultivariateNormal whose covariance is a NonLazyTensor.

    Covers the covariance and variance accessors, scalar ``+``, ``*`` and
    ``/``, confidence_region and the shapes returned by sample.

    Args:
        cuda: If True, all tensors are created on the CUDA device,
            otherwise on the CPU.
    """
    device = torch.device("cuda" if cuda else "cpu")
    mean = torch.tensor([0, 1, 2], dtype=torch.float, device=device)
    covmat = torch.diag(torch.tensor([1, 0.75, 1.5], device=device))
    lazy_covar = NonLazyTensor(covmat)
    mvn = MultivariateNormal(mean=mean, covariance_matrix=lazy_covar)

    # Both the dense and the lazy view of the covariance must be available.
    self.assertTrue(torch.is_tensor(mvn.covariance_matrix))
    self.assertIsInstance(mvn.lazy_covariance_matrix, LazyTensor)
    self.assertTrue(torch.equal(mvn.variance, torch.diag(covmat)))
    self.assertTrue(torch.equal(mvn.covariance_matrix, covmat))

    # Adding a scalar shifts the mean and leaves the covariance untouched.
    shifted = mvn + 1
    self.assertTrue(torch.equal(shifted.mean, mvn.mean + 1))
    self.assertTrue(torch.equal(shifted.covariance_matrix, mvn.covariance_matrix))

    # Scaling by c scales the mean by c and the covariance by c ** 2.
    doubled = mvn * 2
    self.assertTrue(torch.equal(doubled.mean, mvn.mean * 2))
    self.assertTrue(torch.equal(doubled.covariance_matrix, mvn.covariance_matrix * 4))
    halved = mvn / 2
    self.assertTrue(torch.equal(halved.mean, mvn.mean / 2))
    self.assertTrue(torch.equal(halved.covariance_matrix, mvn.covariance_matrix / 4))

    # TODO: Add tests for entropy, log_prob, etc. This is an issue because
    # it uses root_decomposition, which is not very reliable.
    # self.assertAlmostEqual(mvn.entropy().item(), 4.3157, places=4)
    # self.assertAlmostEqual(mvn.log_prob(torch.zeros(3)).item(), -4.8157, places=4)
    # self.assertTrue(
    #     approx_equal(mvn.log_prob(torch.zeros(2, 3)), -4.8157 * torch.ones(2))
    # )

    # The confidence region is mean +/- two standard deviations.
    conf_lower, conf_upper = mvn.confidence_region()
    self.assertTrue(approx_equal(conf_lower, mvn.mean - 2 * mvn.stddev))
    self.assertTrue(approx_equal(conf_upper, mvn.mean + 2 * mvn.stddev))

    # Sample shape is sample_shape + event_shape (there is no batch here).
    draw = mvn.sample()
    self.assertTrue(draw.shape == torch.Size([3]))
    draw = mvn.sample(torch.Size([2]))
    self.assertTrue(draw.shape == torch.Size([2, 3]))
    draw = mvn.sample(torch.Size([2, 4]))
    self.assertTrue(draw.shape == torch.Size([2, 4, 3]))
def gp_posterior(ax, x: torch.Tensor, preds: MultivariateNormal, ewma_alpha: float = 0.0, label: Optional[str] = None, sort: bool = True, fill_alpha: float = 0.05, **kwargs) -> None:
    """Draw a posterior mean line and its confidence band on ``ax``.

    ``x``, ``preds.mean`` and both bounds returned by
    ``preds.confidence_region()`` are flattened with ``view(-1)``, optionally
    reordered by ascending ``x``, passed through the module-level helpers
    ``n`` and ``ewma``, and then plotted.

    Args:
        ax: Object providing ``plot`` and ``fill_between`` (presumably a
            matplotlib ``Axes`` -- confirm against callers).
        x: Input locations; flattened to one dimension.
        preds: Predictive distribution supplying ``mean`` and
            ``confidence_region()``.
        ewma_alpha: Second argument forwarded to ``ewma`` for the mean and
            both bounds (presumably a smoothing factor -- see ``ewma``).
        label: If not None, set as the label of the mean line.
        sort: If True, points are reordered by ``x.argsort()`` before
            plotting; if False they are plotted in the given order.
        fill_alpha: ``alpha`` passed to ``ax.fill_between`` for the band.
        **kwargs: Forwarded to the ``ax.plot`` call that draws the mean line.
    """
    x = x.view(-1)

    # Default to "take everything in the given order"; only switch to an
    # index tensor when sorting is requested and actually changes the order.
    order = slice(None)
    if sort:
        perm = x.argsort()
        # Build the identity permutation on the same device as ``perm`` so the
        # comparison does not mix devices when ``x`` is not on the CPU.
        identity = torch.arange(perm.size(0), device=perm.device)
        if not perm.equal(identity):
            order = perm

    # ``n`` is defined elsewhere in this module (presumably converts a tensor
    # into something matplotlib accepts, e.g. a NumPy array -- confirm).
    x = n(x[order])
    mean = ewma(n(preds.mean.view(-1)[order]), ewma_alpha)

    line, *_ = ax.plot(x, mean, **kwargs)
    if label is not None:
        line.set_label(label)
    # Reuse the colour of the mean line for the band and its outline.
    color = line.get_color()

    lower, upper = (bound.view(-1) for bound in preds.confidence_region())
    lower = ewma(n(lower[order]), ewma_alpha)
    upper = ewma(n(upper[order]), ewma_alpha)

    ax.fill_between(x, lower, upper, alpha=fill_alpha, color=color)
    ax.plot(x, lower, color=color, linewidth=0.5)
    ax.plot(x, upper, color=color, linewidth=0.5)
def test_multivariate_normal_batch_lazy(self, cuda=False):
    """Check a batched MultivariateNormal whose covariance is a NonLazyTensor.

    For both ``torch.float`` and ``torch.double``, a 3-dimensional mean and a
    diagonal covariance are repeated twice to form a batch of two
    distributions. Covers the covariance, variance and
    ``_unbroadcasted_scale_tril`` accessors, scalar ``+``, ``*`` and ``/``,
    confidence_region and the shapes returned by sample.

    Args:
        cuda: If True, all tensors are created on the CUDA device,
            otherwise on the CPU.
    """
    device = torch.device("cuda" if cuda else "cpu")
    for dtype in (torch.float, torch.double):
        tensor_kwargs = {"device": device, "dtype": dtype}
        mean = torch.tensor([0, 1, 2], **tensor_kwargs).repeat(2, 1)
        covar_diag = torch.tensor([1, 0.75, 1.5], **tensor_kwargs)
        covmat = torch.diag(covar_diag).repeat(2, 1, 1)
        chol = torch.cholesky(covmat)
        mvn = MultivariateNormal(mean=mean, covariance_matrix=NonLazyTensor(covmat))

        # Both the dense and the lazy view of the covariance must be available.
        self.assertTrue(torch.is_tensor(mvn.covariance_matrix))
        self.assertIsInstance(mvn.lazy_covariance_matrix, LazyTensor)
        expected_variance = torch.diagonal(covmat, dim1=-2, dim2=-1)
        self.assertAllClose(mvn.variance, expected_variance)
        self.assertAllClose(mvn._unbroadcasted_scale_tril, chol)

        # Adding a scalar shifts the mean; covariance and its factor stay put.
        shifted = mvn + 1
        self.assertAllClose(shifted.mean, mvn.mean + 1)
        self.assertAllClose(shifted.covariance_matrix, mvn.covariance_matrix)
        self.assertAllClose(shifted._unbroadcasted_scale_tril, chol)

        # Scaling by c scales mean and Cholesky factor by c, covariance by c ** 2.
        doubled = mvn * 2
        self.assertAllClose(doubled.mean, mvn.mean * 2)
        self.assertAllClose(doubled.covariance_matrix, mvn.covariance_matrix * 4)
        self.assertAllClose(doubled._unbroadcasted_scale_tril, chol * 2)
        halved = mvn / 2
        self.assertAllClose(halved.mean, mvn.mean / 2)
        self.assertAllClose(halved.covariance_matrix, mvn.covariance_matrix / 4)
        self.assertAllClose(halved._unbroadcasted_scale_tril, chol / 2)

        # TODO: Add tests for entropy, log_prob, etc. This is an issue because
        # it uses root_decomposition, which is not very reliable.
        # self.assertTrue(torch.allclose(mvn.entropy(), 4.3157 * torch.ones(2)))
        # self.assertTrue(
        #     torch.allclose(mvn.log_prob(torch.zeros(2, 3)), -4.8157 * torch.ones(2))
        # )
        # self.assertTrue(
        #     torch.allclose(mvn.log_prob(torch.zeros(2, 2, 3)), -4.8157 * torch.ones(2, 2))
        # )

        # The confidence region is mean +/- two standard deviations.
        conf_lower, conf_upper = mvn.confidence_region()
        self.assertAllClose(conf_lower, mvn.mean - 2 * mvn.stddev)
        self.assertAllClose(conf_upper, mvn.mean + 2 * mvn.stddev)

        # Sample shape is sample_shape + batch_shape + event_shape.
        draw = mvn.sample()
        self.assertTrue(draw.shape == torch.Size([2, 3]))
        draw = mvn.sample(torch.Size([2]))
        self.assertTrue(draw.shape == torch.Size([2, 2, 3]))
        draw = mvn.sample(torch.Size([2, 4]))
        self.assertTrue(draw.shape == torch.Size([2, 4, 2, 3]))
def test_multivariate_normal_non_lazy(self, cuda=False):
    """Check a single MultivariateNormal built from plain (dense) tensors.

    Runs for both ``torch.float`` and ``torch.double``. Covers the
    covariance, variance and scale_tril accessors, scalar ``+``, ``*`` and
    ``/``, entropy, log_prob, confidence_region and the shapes returned by
    sample.

    Args:
        cuda: If True, all tensors are created on the CUDA device,
            otherwise on the CPU.
    """
    device = torch.device("cuda" if cuda else "cpu")
    for dtype in (torch.float, torch.double):
        tensor_kwargs = {"device": device, "dtype": dtype}
        mean = torch.tensor([0, 1, 2], **tensor_kwargs)
        covmat = torch.diag(torch.tensor([1, 0.75, 1.5], **tensor_kwargs))
        mvn = MultivariateNormal(
            mean=mean, covariance_matrix=covmat, validate_args=True
        )

        # Both the dense and the lazy view of the covariance must be available.
        self.assertTrue(torch.is_tensor(mvn.covariance_matrix))
        self.assertIsInstance(mvn.lazy_covariance_matrix, LazyTensor)
        self.assertTrue(torch.allclose(mvn.variance, torch.diag(covmat)))
        # The covariance is diagonal, so its Cholesky factor is the
        # element-wise square root.
        self.assertTrue(torch.allclose(mvn.scale_tril, covmat.sqrt()))

        # Adding a scalar shifts the mean and leaves the covariance untouched.
        shifted = mvn + 1
        self.assertTrue(torch.equal(shifted.mean, mvn.mean + 1))
        self.assertTrue(torch.equal(shifted.covariance_matrix, mvn.covariance_matrix))

        # Scaling by c scales the mean by c and the covariance by c ** 2.
        doubled = mvn * 2
        self.assertTrue(torch.equal(doubled.mean, mvn.mean * 2))
        self.assertTrue(torch.equal(doubled.covariance_matrix, mvn.covariance_matrix * 4))
        halved = mvn / 2
        self.assertTrue(torch.equal(halved.mean, mvn.mean / 2))
        self.assertTrue(torch.equal(halved.covariance_matrix, mvn.covariance_matrix / 4))

        self.assertAlmostEqual(mvn.entropy().item(), 4.3157, places=4)

        # log_prob at the origin: a single point, then a batch of two points.
        origin = torch.zeros(3, **tensor_kwargs)
        self.assertAlmostEqual(mvn.log_prob(origin).item(), -4.8157, places=4)
        logprob = mvn.log_prob(torch.zeros(2, 3, **tensor_kwargs))
        logprob_expected = torch.tensor([-4.8157, -4.8157], **tensor_kwargs)
        self.assertTrue(torch.allclose(logprob, logprob_expected))

        # The confidence region is mean +/- two standard deviations.
        conf_lower, conf_upper = mvn.confidence_region()
        self.assertTrue(torch.allclose(conf_lower, mvn.mean - 2 * mvn.stddev))
        self.assertTrue(torch.allclose(conf_upper, mvn.mean + 2 * mvn.stddev))

        # Sample shape is sample_shape + event_shape (there is no batch here).
        draw = mvn.sample()
        self.assertTrue(draw.shape == torch.Size([3]))
        draw = mvn.sample(torch.Size([2]))
        self.assertTrue(draw.shape == torch.Size([2, 3]))
        draw = mvn.sample(torch.Size([2, 4]))
        self.assertTrue(draw.shape == torch.Size([2, 4, 3]))