Example #1
 def test_GPyTorchPosterior(self, cuda=False):
     device = torch.device("cuda") if cuda else torch.device("cpu")
     for dtype in (torch.float, torch.double):
         mean = torch.rand(3, dtype=dtype, device=device)
         variance = 1 + torch.rand(3, dtype=dtype, device=device)
         covar = variance.diag()
         mvn = MultivariateNormal(mean, lazify(covar))
         posterior = GPyTorchPosterior(mvn=mvn)
         # basics
         self.assertEqual(posterior.device.type, device.type)
         self.assertTrue(posterior.dtype == dtype)
         self.assertEqual(posterior.event_shape, torch.Size([3, 1]))
         self.assertTrue(torch.equal(posterior.mean, mean.unsqueeze(-1)))
         self.assertTrue(
             torch.equal(posterior.variance, variance.unsqueeze(-1)))
         # rsample
         samples = posterior.rsample()
         self.assertEqual(samples.shape, torch.Size([1, 3, 1]))
         samples = posterior.rsample(sample_shape=torch.Size([4]))
         self.assertEqual(samples.shape, torch.Size([4, 3, 1]))
         samples2 = posterior.rsample(sample_shape=torch.Size([4, 2]))
         self.assertEqual(samples2.shape, torch.Size([4, 2, 3, 1]))
         # rsample w/ base samples
         base_samples = torch.randn(4, 3, 1, device=device, dtype=dtype)
         # incompatible shapes
         with self.assertRaises(RuntimeError):
             posterior.rsample(sample_shape=torch.Size([3]),
                               base_samples=base_samples)
         samples_b1 = posterior.rsample(sample_shape=torch.Size([4]),
                                        base_samples=base_samples)
         samples_b2 = posterior.rsample(sample_shape=torch.Size([4]),
                                        base_samples=base_samples)
         self.assertTrue(torch.allclose(samples_b1, samples_b2))
         base_samples2 = torch.randn(4, 2, 3, 1, device=device, dtype=dtype)
         samples2_b1 = posterior.rsample(sample_shape=torch.Size([4, 2]),
                                         base_samples=base_samples2)
         samples2_b2 = posterior.rsample(sample_shape=torch.Size([4, 2]),
                                         base_samples=base_samples2)
         self.assertTrue(torch.allclose(samples2_b1, samples2_b2))
         # collapse_batch_dims
         b_mean = torch.rand(2, 3, dtype=dtype, device=device)
         b_variance = 1 + torch.rand(2, 3, dtype=dtype, device=device)
         b_covar = b_variance.unsqueeze(-1) * torch.eye(3).type_as(
             b_variance)
         b_mvn = MultivariateNormal(b_mean, lazify(b_covar))
         b_posterior = GPyTorchPosterior(mvn=b_mvn)
         b_base_samples = torch.randn(4, 1, 3, 1, device=device, dtype=dtype)
         b_samples = b_posterior.rsample(sample_shape=torch.Size([4]),
                                         base_samples=b_base_samples)
         self.assertEqual(b_samples.shape, torch.Size([4, 2, 3, 1]))
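
Every example on this page follows the same pattern: wrap a GPyTorch MultivariateNormal (or MultitaskMultivariateNormal) in a GPyTorchPosterior, then draw reparameterized samples, optionally passing pre-drawn base_samples so repeated draws are deterministic. Below is a minimal standalone sketch of that pattern; the import paths and the base_samples keyword of rsample assume the older BoTorch/GPyTorch releases these tests were written for (newer GPyTorch versions replace lazify with the linear_operator API).

# Minimal sketch of the construction/sampling pattern the tests exercise.
# Assumes an older BoTorch/GPyTorch release where `lazify` and the
# `base_samples` keyword of `rsample` are available.
import torch
from gpytorch.distributions import MultivariateNormal
from gpytorch.lazy import lazify
from botorch.posteriors.gpytorch import GPyTorchPosterior

mean = torch.rand(3)
variance = 1 + torch.rand(3)
mvn = MultivariateNormal(mean, lazify(variance.diag()))
posterior = GPyTorchPosterior(mvn=mvn)

print(posterior.event_shape)  # torch.Size([3, 1]): 3 points, 1 output

# rsample prepends the sample shape to the event shape.
samples = posterior.rsample(sample_shape=torch.Size([4]))  # shape (4, 3, 1)

# Re-using the same base samples makes repeated draws identical.
base = torch.randn(4, 3, 1)
s1 = posterior.rsample(sample_shape=torch.Size([4]), base_samples=base)
s2 = posterior.rsample(sample_shape=torch.Size([4]), base_samples=base)
assert torch.allclose(s1, s2)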
Example #2
 def test_GPyTorchPosterior(self, cuda=False):
     device = torch.device("cuda") if cuda else torch.device("cpu")
     for dtype in (torch.float, torch.double):
         mean = torch.rand(3, dtype=dtype, device=device)
         variance = 1 + torch.rand(3, dtype=dtype, device=device)
         covar = variance.diag()
         mvn = MultivariateNormal(mean, lazify(covar))
         posterior = GPyTorchPosterior(mvn=mvn)
         # basics
         self.assertEqual(posterior.device.type, device.type)
         self.assertTrue(posterior.dtype == dtype)
         self.assertEqual(posterior.event_shape, torch.Size([3, 1]))
         self.assertTrue(torch.equal(posterior.mean, mean.unsqueeze(-1)))
         self.assertTrue(torch.equal(posterior.variance, variance.unsqueeze(-1)))
         # rsample
         samples = posterior.rsample()
         self.assertEqual(samples.shape, torch.Size([1, 3, 1]))
         samples = posterior.rsample(sample_shape=torch.Size([4]))
         self.assertEqual(samples.shape, torch.Size([4, 3, 1]))
         samples2 = posterior.rsample(sample_shape=torch.Size([4, 2]))
         self.assertEqual(samples2.shape, torch.Size([4, 2, 3, 1]))
         # rsample w/ base samples
         base_samples = torch.randn(4, 3, 1, device=device, dtype=dtype)
         # incompatible shapes
         with self.assertRaises(RuntimeError):
             posterior.rsample(
                 sample_shape=torch.Size([3]), base_samples=base_samples
             )
         samples_b1 = posterior.rsample(
             sample_shape=torch.Size([4]), base_samples=base_samples
         )
         samples_b2 = posterior.rsample(
             sample_shape=torch.Size([4]), base_samples=base_samples
         )
         self.assertTrue(torch.allclose(samples_b1, samples_b2))
         base_samples2 = torch.randn(4, 2, 3, 1, device=device, dtype=dtype)
         samples2_b1 = posterior.rsample(
             sample_shape=torch.Size([4, 2]), base_samples=base_samples2
         )
         samples2_b2 = posterior.rsample(
             sample_shape=torch.Size([4, 2]), base_samples=base_samples2
         )
         self.assertTrue(torch.allclose(samples2_b1, samples2_b2))
         # collapse_batch_dims
         b_mean = torch.rand(2, 3, dtype=dtype, device=device)
         b_variance = 1 + torch.rand(2, 3, dtype=dtype, device=device)
         b_covar = b_variance.unsqueeze(-1) * torch.eye(3).type_as(b_variance)
         b_mvn = MultivariateNormal(b_mean, lazify(b_covar))
         b_posterior = GPyTorchPosterior(mvn=b_mvn)
         b_base_samples = torch.randn(4, 1, 3, 1, device=device, dtype=dtype)
         b_samples = b_posterior.rsample(
             sample_shape=torch.Size([4]), base_samples=b_base_samples
         )
         self.assertEqual(b_samples.shape, torch.Size([4, 2, 3, 1]))
Example #3
 def test_transformed_posterior(self):
     for dtype in (torch.float, torch.double):
         for m in (1, 2):
             shape = torch.Size([3, m])
             mean = torch.rand(shape, dtype=dtype, device=self.device)
             variance = 1 + torch.rand(
                 shape, dtype=dtype, device=self.device)
             if m == 1:
                 covar = torch.diag_embed(variance.squeeze(-1))
                 mvn = MultivariateNormal(mean.squeeze(-1), lazify(covar))
             else:
                 covar = torch.diag_embed(
                     variance.view(*variance.shape[:-2], -1))
                 mvn = MultitaskMultivariateNormal(mean, lazify(covar))
             p_base = GPyTorchPosterior(mvn=mvn)
             p_tf = TransformedPosterior(  # dummy transforms
                 posterior=p_base,
                 sample_transform=lambda s: s + 2,
                 mean_transform=lambda m, v: 2 * m + v,
                 variance_transform=lambda m, v: m + 2 * v,
             )
             # mean, variance
             self.assertEqual(p_tf.device.type, self.device.type)
             self.assertTrue(p_tf.dtype == dtype)
             self.assertEqual(p_tf.event_shape, shape)
             self.assertEqual(p_tf.base_sample_shape, shape)
             self.assertTrue(torch.equal(p_tf.mean, 2 * mean + variance))
             self.assertTrue(torch.equal(p_tf.variance,
                                         mean + 2 * variance))
             # rsample
             samples = p_tf.rsample()
             self.assertEqual(samples.shape, torch.Size([1]) + shape)
             samples = p_tf.rsample(sample_shape=torch.Size([4]))
             self.assertEqual(samples.shape, torch.Size([4]) + shape)
             samples2 = p_tf.rsample(sample_shape=torch.Size([4, 2]))
             self.assertEqual(samples2.shape, torch.Size([4, 2]) + shape)
             # rsample w/ base samples
             base_samples = torch.randn(4, *shape, device=self.device, dtype=dtype)
             # incompatible shapes
             with self.assertRaises(RuntimeError):
                 p_tf.rsample(sample_shape=torch.Size([3]),
                              base_samples=base_samples)
             # make sure sample transform is applied correctly
             samples_base = p_base.rsample(sample_shape=torch.Size([4]),
                                           base_samples=base_samples)
             samples_tf = p_tf.rsample(sample_shape=torch.Size([4]),
                                       base_samples=base_samples)
             self.assertTrue(torch.equal(samples_tf, samples_base + 2))
             # check error handling
             p_tf_2 = TransformedPosterior(posterior=p_base,
                                           sample_transform=lambda s: s + 2)
             with self.assertRaises(NotImplementedError):
                 p_tf_2.mean
             with self.assertRaises(NotImplementedError):
                 p_tf_2.variance
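
The test above only checks plumbing with dummy shift transforms. In practice TransformedPosterior is given closed-form moment transforms; the sketch below is an illustrative assumption, not taken from the test: it exponentiates samples and uses the standard log-normal moment formulas for the mean and variance transforms. Import paths carry the same version hedge as the earlier sketch.

# Sketch: a log-normal style wrapper around a GPyTorchPosterior.
# The transform choice is illustrative; only the constructor signature
# (posterior, sample_transform, mean_transform, variance_transform)
# is taken from the test above.
import torch
from gpytorch.distributions import MultivariateNormal
from gpytorch.lazy import lazify
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.posteriors.transformed import TransformedPosterior

mean = torch.rand(3)
variance = 1 + torch.rand(3)
p_base = GPyTorchPosterior(mvn=MultivariateNormal(mean, lazify(variance.diag())))

# If X ~ N(m, v), then E[exp(X)] = exp(m + v/2) and
# Var[exp(X)] = (exp(v) - 1) * exp(2m + v).
p_exp = TransformedPosterior(
    posterior=p_base,
    sample_transform=lambda s: s.exp(),
    mean_transform=lambda m, v: torch.exp(m + 0.5 * v),
    variance_transform=lambda m, v: (torch.exp(v) - 1) * torch.exp(2 * m + v),
)
samples = p_exp.rsample(sample_shape=torch.Size([4]))  # exponentiated draws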
Example #4
    def test_degenerate_GPyTorchPosterior(self, cuda=False):
        device = torch.device("cuda") if cuda else torch.device("cpu")
        for dtype in (torch.float, torch.double):
            # singular covariance matrix
            degenerate_covar = torch.tensor(
                [[1, 1, 0], [1, 1, 0], [0, 0, 2]], dtype=dtype, device=device
            )
            mean = torch.rand(3, dtype=dtype, device=device)
            mvn = MultivariateNormal(mean, lazify(degenerate_covar))
            posterior = GPyTorchPosterior(mvn=mvn)
            # basics
            self.assertEqual(posterior.device.type, device.type)
            self.assertTrue(posterior.dtype == dtype)
            self.assertEqual(posterior.event_shape, torch.Size([3, 1]))
            self.assertTrue(torch.equal(posterior.mean, mean.unsqueeze(-1)))
            variance_exp = degenerate_covar.diag().unsqueeze(-1)
            self.assertTrue(torch.equal(posterior.variance, variance_exp))

            # rsample
            with warnings.catch_warnings(record=True) as w:
                # we check that the p.d. warning is emitted - this only
                # happens once per posterior, so we need to check only once
                samples = posterior.rsample(sample_shape=torch.Size([4]))
                self.assertEqual(len(w), 1)
                self.assertTrue(issubclass(w[-1].category, RuntimeWarning))
                self.assertTrue("not p.d." in str(w[-1].message))
            self.assertEqual(samples.shape, torch.Size([4, 3, 1]))
            samples2 = posterior.rsample(sample_shape=torch.Size([4, 2]))
            self.assertEqual(samples2.shape, torch.Size([4, 2, 3, 1]))
            # rsample w/ base samples
            base_samples = torch.randn(4, 3, 1, device=device, dtype=dtype)
            samples_b1 = posterior.rsample(
                sample_shape=torch.Size([4]), base_samples=base_samples
            )
            samples_b2 = posterior.rsample(
                sample_shape=torch.Size([4]), base_samples=base_samples
            )
            self.assertTrue(torch.allclose(samples_b1, samples_b2))
            base_samples2 = torch.randn(4, 2, 3, 1, device=device, dtype=dtype)
            samples2_b1 = posterior.rsample(
                sample_shape=torch.Size([4, 2]), base_samples=base_samples2
            )
            samples2_b2 = posterior.rsample(
                sample_shape=torch.Size([4, 2]), base_samples=base_samples2
            )
            self.assertTrue(torch.allclose(samples2_b1, samples2_b2))
            # collapse_batch_dims
            b_mean = torch.rand(2, 3, dtype=dtype, device=device)
            b_degenerate_covar = degenerate_covar.expand(2, *degenerate_covar.shape)
            b_mvn = MultivariateNormal(b_mean, lazify(b_degenerate_covar))
            b_posterior = GPyTorchPosterior(mvn=b_mvn)
            b_base_samples = torch.randn(4, 2, 3, 1, device=device, dtype=dtype)
            with warnings.catch_warnings(record=True) as w:
                b_samples = b_posterior.rsample(
                    sample_shape=torch.Size([4]), base_samples=b_base_samples
                )
                self.assertEqual(len(w), 1)
                self.assertTrue(issubclass(w[-1].category, RuntimeWarning))
                self.assertTrue("not p.d." in str(w[-1].message))
            self.assertEqual(b_samples.shape, torch.Size([4, 2, 3, 1]))
Example #5
 def test_degenerate_GPyTorchPosterior_Multitask(self):
     for dtype in (torch.float, torch.double):
         # singular covariance matrix
         degenerate_covar = torch.tensor(
             [[1, 1, 0], [1, 1, 0], [0, 0, 2]], dtype=dtype, device=self.device
         )
         mean = torch.rand(3, dtype=dtype, device=self.device)
         mvn = MultivariateNormal(mean, lazify(degenerate_covar))
         mvn = MultitaskMultivariateNormal.from_independent_mvns([mvn, mvn])
         posterior = GPyTorchPosterior(mvn=mvn)
         # basics
         self.assertEqual(posterior.device.type, self.device.type)
         self.assertTrue(posterior.dtype == dtype)
         self.assertEqual(posterior.event_shape, torch.Size([3, 2]))
         mean_exp = mean.unsqueeze(-1).repeat(1, 2)
         self.assertTrue(torch.equal(posterior.mean, mean_exp))
         variance_exp = degenerate_covar.diag().unsqueeze(-1).repeat(1, 2)
         self.assertTrue(torch.equal(posterior.variance, variance_exp))
         # rsample
         with warnings.catch_warnings(record=True) as ws:
             # we check that the p.d. warning is emitted - this only
             # happens once per posterior, so we need to check only once
             samples = posterior.rsample(sample_shape=torch.Size([4]))
             self.assertTrue(any(issubclass(w.category, RuntimeWarning) for w in ws))
             self.assertTrue(any("not p.d" in str(w.message) for w in ws))
         self.assertEqual(samples.shape, torch.Size([4, 3, 2]))
         samples2 = posterior.rsample(sample_shape=torch.Size([4, 2]))
         self.assertEqual(samples2.shape, torch.Size([4, 2, 3, 2]))
         # rsample w/ base samples
         base_samples = torch.randn(4, 3, 2, device=self.device, dtype=dtype)
         samples_b1 = posterior.rsample(
             sample_shape=torch.Size([4]), base_samples=base_samples
         )
         samples_b2 = posterior.rsample(
             sample_shape=torch.Size([4]), base_samples=base_samples
         )
         self.assertTrue(torch.allclose(samples_b1, samples_b2))
         base_samples2 = torch.randn(4, 2, 3, 2, device=self.device, dtype=dtype)
         samples2_b1 = posterior.rsample(
             sample_shape=torch.Size([4, 2]), base_samples=base_samples2
         )
         samples2_b2 = posterior.rsample(
             sample_shape=torch.Size([4, 2]), base_samples=base_samples2
         )
         self.assertTrue(torch.allclose(samples2_b1, samples2_b2))
         # collapse_batch_dims
         b_mean = torch.rand(2, 3, dtype=dtype, device=self.device)
         b_degenerate_covar = degenerate_covar.expand(2, *degenerate_covar.shape)
         b_mvn = MultivariateNormal(b_mean, lazify(b_degenerate_covar))
         b_mvn = MultitaskMultivariateNormal.from_independent_mvns([b_mvn, b_mvn])
         b_posterior = GPyTorchPosterior(mvn=b_mvn)
         b_base_samples = torch.randn(4, 1, 3, 2, device=self.device, dtype=dtype)
         with warnings.catch_warnings(record=True) as ws:
             b_samples = b_posterior.rsample(
                 sample_shape=torch.Size([4]), base_samples=b_base_samples
             )
             self.assertTrue(any(issubclass(w.category, RuntimeWarning) for w in ws))
             self.assertTrue(any("not p.d" in str(w.message) for w in ws))
         self.assertEqual(b_samples.shape, torch.Size([4, 2, 3, 2]))
Example #6
 def test_GPyTorchPosterior_Multitask(self):
     for dtype in (torch.float, torch.double):
         mean = torch.rand(3, 2, dtype=dtype, device=self.device)
         variance = 1 + torch.rand(3, 2, dtype=dtype, device=self.device)
         covar = variance.view(-1).diag()
         mvn = MultitaskMultivariateNormal(mean, lazify(covar))
         posterior = GPyTorchPosterior(mvn=mvn)
         # basics
         self.assertEqual(posterior.device.type, self.device.type)
         self.assertTrue(posterior.dtype == dtype)
         self.assertEqual(posterior.event_shape, torch.Size([3, 2]))
         self.assertTrue(torch.equal(posterior.mean, mean))
         self.assertTrue(torch.equal(posterior.variance, variance))
         # rsample
         samples = posterior.rsample(sample_shape=torch.Size([4]))
         self.assertEqual(samples.shape, torch.Size([4, 3, 2]))
         samples2 = posterior.rsample(sample_shape=torch.Size([4, 2]))
         self.assertEqual(samples2.shape, torch.Size([4, 2, 3, 2]))
         # rsample w/ base samples
         base_samples = torch.randn(4, 3, 2, device=self.device, dtype=dtype)
         samples_b1 = posterior.rsample(sample_shape=torch.Size([4]),
                                        base_samples=base_samples)
         samples_b2 = posterior.rsample(sample_shape=torch.Size([4]),
                                        base_samples=base_samples)
         self.assertTrue(torch.allclose(samples_b1, samples_b2))
         base_samples2 = torch.randn(4, 2, 3, 2, device=self.device, dtype=dtype)
         samples2_b1 = posterior.rsample(sample_shape=torch.Size([4, 2]),
                                         base_samples=base_samples2)
         samples2_b2 = posterior.rsample(sample_shape=torch.Size([4, 2]),
                                         base_samples=base_samples2)
         self.assertTrue(torch.allclose(samples2_b1, samples2_b2))
         # collapse_batch_dims
         b_mean = torch.rand(2, 3, 2, dtype=dtype, device=self.device)
         b_variance = 1 + torch.rand(2, 3, 2, dtype=dtype, device=self.device)
         b_covar = b_variance.view(2, 6, 1) * torch.eye(6).type_as(b_variance)
         b_mvn = MultitaskMultivariateNormal(b_mean, lazify(b_covar))
         b_posterior = GPyTorchPosterior(mvn=b_mvn)
         b_base_samples = torch.randn(4, 1, 3, 2, device=self.device, dtype=dtype)
         b_samples = b_posterior.rsample(sample_shape=torch.Size([4]),
                                         base_samples=b_base_samples)
         self.assertEqual(b_samples.shape, torch.Size([4, 2, 3, 2]))
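
The collapse_batch_dims block above relies on base samples with a singleton batch dimension being broadcast across the posterior's batch shape. The sketch below isolates just that broadcasting, with the same shapes as Example #6 (a batch of 2 posteriors over 3 points and 2 tasks); imports carry the same version assumption as the earlier sketches.

# Sketch: base samples of shape (sample, 1, n, t) are expanded across the
# posterior's batch dimension, yielding samples of shape (sample, batch, n, t).
import torch
from gpytorch.distributions import MultitaskMultivariateNormal
from gpytorch.lazy import lazify
from botorch.posteriors.gpytorch import GPyTorchPosterior

b_mean = torch.rand(2, 3, 2)
b_variance = 1 + torch.rand(2, 3, 2)
b_covar = b_variance.view(2, 6, 1) * torch.eye(6)  # batch of diagonal covariances
b_posterior = GPyTorchPosterior(
    mvn=MultitaskMultivariateNormal(b_mean, lazify(b_covar))
)

base = torch.randn(4, 1, 3, 2)  # singleton batch dim is broadcast to 2
samples = b_posterior.rsample(sample_shape=torch.Size([4]), base_samples=base)
assert samples.shape == torch.Size([4, 2, 3, 2])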
Example #7
 def test_GPyTorchPosterior_Multitask(self, cuda=False):
     device = torch.device("cuda") if cuda else torch.device("cpu")
     for dtype in (torch.float, torch.double):
         mean = torch.rand(3, 2, dtype=dtype, device=device)
         variance = 1 + torch.rand(3, 2, dtype=dtype, device=device)
         covar = variance.view(-1).diag()
         mvn = MultitaskMultivariateNormal(mean, lazify(covar))
         posterior = GPyTorchPosterior(mvn=mvn)
         # basics
         self.assertEqual(posterior.device.type, device.type)
         self.assertTrue(posterior.dtype == dtype)
         self.assertEqual(posterior.event_shape, torch.Size([3, 2]))
         self.assertTrue(torch.equal(posterior.mean, mean))
         self.assertTrue(torch.equal(posterior.variance, variance))
         # rsample
         samples = posterior.rsample(sample_shape=torch.Size([4]))
         self.assertEqual(samples.shape, torch.Size([4, 3, 2]))
         samples2 = posterior.rsample(sample_shape=torch.Size([4, 2]))
         self.assertEqual(samples2.shape, torch.Size([4, 2, 3, 2]))
         # rsample w/ base samples
         base_samples = torch.randn(4, 3, 2, device=device, dtype=dtype)
         samples_b1 = posterior.rsample(
             sample_shape=torch.Size([4]), base_samples=base_samples
         )
         samples_b2 = posterior.rsample(
             sample_shape=torch.Size([4]), base_samples=base_samples
         )
         self.assertTrue(torch.allclose(samples_b1, samples_b2))
         base_samples2 = torch.randn(4, 2, 3, 2, device=device, dtype=dtype)
         samples2_b1 = posterior.rsample(
             sample_shape=torch.Size([4, 2]), base_samples=base_samples2
         )
         samples2_b2 = posterior.rsample(
             sample_shape=torch.Size([4, 2]), base_samples=base_samples2
         )
         self.assertTrue(torch.allclose(samples2_b1, samples2_b2))
         # collapse_batch_dims
         b_mean = torch.rand(2, 3, 2, dtype=dtype, device=device)
         b_variance = 1 + torch.rand(2, 3, 2, dtype=dtype, device=device)
         b_covar = b_variance.view(2, 6, 1) * torch.eye(6).type_as(b_variance)
         b_mvn = MultitaskMultivariateNormal(b_mean, lazify(b_covar))
         b_posterior = GPyTorchPosterior(mvn=b_mvn)
         b_base_samples = torch.randn(4, 1, 3, 2, device=device, dtype=dtype)
         b_samples = b_posterior.rsample(
             sample_shape=torch.Size([4]), base_samples=b_base_samples
         )
         self.assertEqual(b_samples.shape, torch.Size([4, 2, 3, 2]))
Example #8
    def test_GPyTorchPosterior(self):
        for dtype in (torch.float, torch.double):
            n = 3
            mean = torch.rand(n, dtype=dtype, device=self.device)
            variance = 1 + torch.rand(n, dtype=dtype, device=self.device)
            covar = variance.diag()
            mvn = MultivariateNormal(mean, lazify(covar))
            posterior = GPyTorchPosterior(mvn=mvn)
            # basics
            self.assertEqual(posterior.device.type, self.device.type)
            self.assertTrue(posterior.dtype == dtype)
            self.assertEqual(posterior.event_shape, torch.Size([n, 1]))
            self.assertTrue(torch.equal(posterior.mean, mean.unsqueeze(-1)))
            self.assertTrue(
                torch.equal(posterior.variance, variance.unsqueeze(-1)))
            # rsample
            samples = posterior.rsample()
            self.assertEqual(samples.shape, torch.Size([1, n, 1]))
            for sample_shape in ([4], [4, 2]):
                samples = posterior.rsample(
                    sample_shape=torch.Size(sample_shape))
                self.assertEqual(samples.shape,
                                 torch.Size(sample_shape + [n, 1]))
            # check enabling of approximate root decomposition
            with ExitStack() as es:
                mock_func = es.enter_context(
                    mock.patch(ROOT_DECOMP_PATH,
                               return_value=torch.cholesky(covar)))
                es.enter_context(gpt_settings.max_cholesky_size(0))
                es.enter_context(
                    gpt_settings.fast_computations(
                        covar_root_decomposition=True))
                # need to clear cache, cannot re-use previous objects
                mvn = MultivariateNormal(mean, lazify(covar))
                posterior = GPyTorchPosterior(mvn=mvn)
                posterior.rsample(sample_shape=torch.Size([4]))
                mock_func.assert_called_once()

            # rsample w/ base samples
            base_samples = torch.randn(4, 3, 1, device=self.device, dtype=dtype)
            # incompatible shapes
            with self.assertRaises(RuntimeError):
                posterior.rsample(sample_shape=torch.Size([3]),
                                  base_samples=base_samples)
            # ensure consistent result
            for sample_shape in ([4], [4, 2]):
                base_samples = torch.randn(
                    *sample_shape, 3, 1, device=self.device, dtype=dtype
                )
                samples = [
                    posterior.rsample(sample_shape=torch.Size(sample_shape),
                                      base_samples=base_samples)
                    for _ in range(2)
                ]
                self.assertTrue(torch.allclose(*samples))
            # collapse_batch_dims
            b_mean = torch.rand(2, 3, dtype=dtype, device=self.device)
            b_variance = 1 + torch.rand(2, 3, dtype=dtype, device=self.device)
            b_covar = torch.diag_embed(b_variance)
            b_mvn = MultivariateNormal(b_mean, lazify(b_covar))
            b_posterior = GPyTorchPosterior(mvn=b_mvn)
            b_base_samples = torch.randn(4, 1, 3, 1, device=self.device, dtype=dtype)
            b_samples = b_posterior.rsample(sample_shape=torch.Size([4]),
                                            base_samples=b_base_samples)
            self.assertEqual(b_samples.shape, torch.Size([4, 2, 3, 1]))
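
Example #8 forces GPyTorch's approximate root decomposition by stacking several context managers with contextlib.ExitStack and mocking the decomposition to verify it is hit. Outside a test, the same settings can simply be nested in a with block; the sketch below is an assumption about how to exercise that code path directly, using the gpytorch.settings flags that appear in the test.

# Sketch: sample through the (approximate) root-decomposition path by
# disallowing exact Cholesky via max_cholesky_size(0), as in Example #8.
import torch
import gpytorch.settings as gpt_settings
from gpytorch.distributions import MultivariateNormal
from gpytorch.lazy import lazify
from botorch.posteriors.gpytorch import GPyTorchPosterior

mean, variance = torch.rand(3), 1 + torch.rand(3)
posterior = GPyTorchPosterior(mvn=MultivariateNormal(mean, lazify(variance.diag())))

with gpt_settings.max_cholesky_size(0), gpt_settings.fast_computations(
    covar_root_decomposition=True
):
    samples = posterior.rsample(sample_shape=torch.Size([4]))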
Example #9
    def test_degenerate_GPyTorchPosterior(self, cuda=False):
        device = torch.device("cuda") if cuda else torch.device("cpu")
        for dtype in (torch.float, torch.double):
            # singular covariance matrix
            degenerate_covar = torch.tensor([[1, 1, 0], [1, 1, 0], [0, 0, 2]],
                                            dtype=dtype,
                                            device=device)
            mean = torch.rand(3, dtype=dtype, device=device)
            mvn = MultivariateNormal(mean, lazify(degenerate_covar))
            posterior = GPyTorchPosterior(mvn=mvn)
            # basics
            self.assertEqual(posterior.device.type, device.type)
            self.assertTrue(posterior.dtype == dtype)
            self.assertEqual(posterior.event_shape, torch.Size([3, 1]))
            self.assertTrue(torch.equal(posterior.mean, mean.unsqueeze(-1)))
            variance_exp = degenerate_covar.diag().unsqueeze(-1)
            self.assertTrue(torch.equal(posterior.variance, variance_exp))

            # rsample
            with warnings.catch_warnings(record=True) as w:
                # we check that the p.d. warning is emitted - this only
                # happens once per posterior, so we need to check only once
                samples = posterior.rsample(sample_shape=torch.Size([4]))
                self.assertEqual(len(w), 1)
                self.assertTrue(issubclass(w[-1].category, RuntimeWarning))
                self.assertTrue("not p.d." in str(w[-1].message))
            self.assertEqual(samples.shape, torch.Size([4, 3, 1]))
            samples2 = posterior.rsample(sample_shape=torch.Size([4, 2]))
            self.assertEqual(samples2.shape, torch.Size([4, 2, 3, 1]))
            # rsample w/ base samples
            base_samples = torch.randn(4, 3, 1, device=device, dtype=dtype)
            samples_b1 = posterior.rsample(sample_shape=torch.Size([4]),
                                           base_samples=base_samples)
            samples_b2 = posterior.rsample(sample_shape=torch.Size([4]),
                                           base_samples=base_samples)
            self.assertTrue(torch.allclose(samples_b1, samples_b2))
            base_samples2 = torch.randn(4, 2, 3, 1, device=device, dtype=dtype)
            samples2_b1 = posterior.rsample(sample_shape=torch.Size([4, 2]),
                                            base_samples=base_samples2)
            samples2_b2 = posterior.rsample(sample_shape=torch.Size([4, 2]),
                                            base_samples=base_samples2)
            self.assertTrue(torch.allclose(samples2_b1, samples2_b2))
            # collapse_batch_dims
            b_mean = torch.rand(2, 3, dtype=dtype, device=device)
            b_degenerate_covar = degenerate_covar.expand(
                2, *degenerate_covar.shape)
            b_mvn = MultivariateNormal(b_mean, lazify(b_degenerate_covar))
            b_posterior = GPyTorchPosterior(mvn=b_mvn)
            b_base_samples = torch.randn(4, 2, 3, 1, device=device, dtype=dtype)
            with warnings.catch_warnings(record=True) as w:
                b_samples = b_posterior.rsample(sample_shape=torch.Size([4]),
                                                base_samples=b_base_samples)
                self.assertEqual(len(w), 1)
                self.assertTrue(issubclass(w[-1].category, RuntimeWarning))
                self.assertTrue("not p.d." in str(w[-1].message))
            self.assertEqual(b_samples.shape, torch.Size([4, 2, 3, 1]))
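
Examples #4, #5, and #9 assert that sampling from a posterior whose covariance is singular emits a "not p.d." RuntimeWarning (the tests note it fires once per posterior). The same behaviour can be observed outside a test harness with the standard-library warnings module; the sketch below reuses the singular covariance from those examples and carries the same version assumptions as the earlier sketches.

# Sketch: observe the "not p.d." RuntimeWarning emitted when sampling from a
# posterior backed by a singular covariance matrix.
import warnings
import torch
from gpytorch.distributions import MultivariateNormal
from gpytorch.lazy import lazify
from botorch.posteriors.gpytorch import GPyTorchPosterior

degenerate_covar = torch.tensor(
    [[1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 2.0]]
)
mean = torch.rand(3)
posterior = GPyTorchPosterior(mvn=MultivariateNormal(mean, lazify(degenerate_covar)))

with warnings.catch_warnings(record=True) as ws:
    warnings.simplefilter("always")
    samples = posterior.rsample(sample_shape=torch.Size([4]))

print([str(w.message) for w in ws])  # expect a RuntimeWarning mentioning "not p.d."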