Example #1
 def test_multivariate_normal_batch_non_lazy(self, cuda=False):
     device = torch.device("cuda") if cuda else torch.device("cpu")
     mean = torch.tensor([0, 1, 2], dtype=torch.float, device=device)
     covmat = torch.diag(torch.tensor([1, 0.75, 1.5], device=device))
     mvn = MultivariateNormal(mean=mean.repeat(2, 1), covariance_matrix=covmat.repeat(2, 1, 1), validate_args=True)
     self.assertTrue(torch.is_tensor(mvn.covariance_matrix))
     self.assertIsInstance(mvn.lazy_covariance_matrix, LazyTensor)
     self.assertTrue(approx_equal(mvn.variance, covmat.diag().repeat(2, 1)))
     self.assertTrue(approx_equal(mvn.scale_tril, torch.diag(covmat.diag().sqrt()).repeat(2, 1, 1)))
     mvn_plus1 = mvn + 1
     self.assertTrue(torch.equal(mvn_plus1.mean, mvn.mean + 1))
     self.assertTrue(torch.equal(mvn_plus1.covariance_matrix, mvn.covariance_matrix))
     mvn_times2 = mvn * 2
     self.assertTrue(torch.equal(mvn_times2.mean, mvn.mean * 2))
     self.assertTrue(torch.equal(mvn_times2.covariance_matrix, mvn.covariance_matrix * 4))
     mvn_divby2 = mvn / 2
     self.assertTrue(torch.equal(mvn_divby2.mean, mvn.mean / 2))
     self.assertTrue(torch.equal(mvn_divby2.covariance_matrix, mvn.covariance_matrix / 4))
     self.assertTrue(approx_equal(mvn.entropy(), 4.3157 * torch.ones(2, device=device)))
     logprob = mvn.log_prob(torch.zeros(2, 3, device=device))
     logprob_expected = -4.8157 * torch.ones(2, device=device)
     self.assertTrue(approx_equal(logprob, logprob_expected))
     logprob = mvn.log_prob(torch.zeros(2, 2, 3, device=device))
     logprob_expected = -4.8157 * torch.ones(2, 2, device=device)
     self.assertTrue(approx_equal(logprob, logprob_expected))
     conf_lower, conf_upper = mvn.confidence_region()
     self.assertTrue(approx_equal(conf_lower, mvn.mean - 2 * mvn.stddev))
     self.assertTrue(approx_equal(conf_upper, mvn.mean + 2 * mvn.stddev))
     self.assertTrue(mvn.sample().shape == torch.Size([2, 3]))
     self.assertTrue(mvn.sample(torch.Size([2])).shape == torch.Size([2, 2, 3]))
     self.assertTrue(mvn.sample(torch.Size([2, 4])).shape == torch.Size([2, 4, 2, 3]))
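These snippets are lifted from GPyTorch's test suite and omit their imports. A minimal preamble that should make them runnable (a sketch: exact module paths vary by version, and GPyTorch >= 1.9 moved the lazy-tensor classes into the separate linear_operator package; approx_equal is shown as a stand-in for the suite's tolerance helper):

    import torch
    from gpytorch.distributions import MultivariateNormal
    from gpytorch.lazy import LazyTensor, NonLazyTensor  # pre-1.9 layout

    def approx_equal(a, b, epsilon=1e-4):
        # Stand-in for the test helper: all elements equal within epsilon.
        return torch.max((a - b).abs()) <= epsilon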
Example #2
 def test_multivariate_normal_lazy(self, cuda=False):
     device = torch.device("cuda") if cuda else torch.device("cpu")
     mean = torch.tensor([0, 1, 2], dtype=torch.float, device=device)
     covmat = torch.diag(torch.tensor([1, 0.75, 1.5], device=device))
     mvn = MultivariateNormal(mean=mean, covariance_matrix=NonLazyTensor(covmat))
     self.assertTrue(torch.is_tensor(mvn.covariance_matrix))
     self.assertIsInstance(mvn.lazy_covariance_matrix, LazyTensor)
     self.assertTrue(torch.equal(mvn.variance, torch.diag(covmat)))
     self.assertTrue(torch.equal(mvn.covariance_matrix, covmat))
     mvn_plus1 = mvn + 1
     self.assertTrue(torch.equal(mvn_plus1.mean, mvn.mean + 1))
     self.assertTrue(torch.equal(mvn_plus1.covariance_matrix, mvn.covariance_matrix))
     mvn_times2 = mvn * 2
     self.assertTrue(torch.equal(mvn_times2.mean, mvn.mean * 2))
     self.assertTrue(torch.equal(mvn_times2.covariance_matrix, mvn.covariance_matrix * 4))
     mvn_divby2 = mvn / 2
     self.assertTrue(torch.equal(mvn_divby2.mean, mvn.mean / 2))
     self.assertTrue(torch.equal(mvn_divby2.covariance_matrix, mvn.covariance_matrix / 4))
     # TODO: Add tests for entropy, log_prob, etc. - this is an issue b/c
     # it uses root_decomposition, which is not very reliable
     # self.assertAlmostEqual(mvn.entropy().item(), 4.3157, places=4)
     # self.assertAlmostEqual(mvn.log_prob(torch.zeros(3)).item(), -4.8157, places=4)
     # self.assertTrue(
     #     approx_equal(mvn.log_prob(torch.zeros(2, 3)), -4.8157 * torch.ones(2))
     # )
     conf_lower, conf_upper = mvn.confidence_region()
     self.assertTrue(approx_equal(conf_lower, mvn.mean - 2 * mvn.stddev))
     self.assertTrue(approx_equal(conf_upper, mvn.mean + 2 * mvn.stddev))
     self.assertTrue(mvn.sample().shape == torch.Size([3]))
     self.assertTrue(mvn.sample(torch.Size([2])).shape == torch.Size([2, 3]))
     self.assertTrue(mvn.sample(torch.Size([2, 4])).shape == torch.Size([2, 4, 3]))
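What the lazy wrapper buys, in two lines (assuming the older LazyTensor API, where .evaluate() forces a dense tensor; newer versions call this .to_dense()):

    lazy_cov = mvn.lazy_covariance_matrix  # still a LazyTensor, nothing materialized
    dense_cov = lazy_cov.evaluate()        # forces the dense torch.Tensor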
Example #3
    def test_multivariate_normal_batch_correlated_samples(self, cuda=False):
        device = torch.device("cuda") if cuda else torch.device("cpu")
        mean = torch.tensor([0, 1, 2], dtype=torch.float, device=device)
        covmat = torch.diag(torch.tensor([1, 0.75, 1.5], device=device))
        mvn = MultivariateNormal(mean=mean.repeat(2, 1), covariance_matrix=NonLazyTensor(covmat).repeat(2, 1, 1))

        base_samples = mvn.get_base_samples(torch.Size((3, 4)))
        self.assertTrue(mvn.sample(base_samples=base_samples).shape == torch.Size([3, 4, 2, 3]))

        base_samples = mvn.get_base_samples()
        self.assertTrue(mvn.sample(base_samples=base_samples).shape == torch.Size([2, 3]))
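get_base_samples draws the standard-normal inputs once, so reusing them makes sampling deterministic; that is what makes the draws "correlated" across calls. A quick check of that property (using the preamble imports above):

    mean = torch.zeros(3)
    covmat = torch.diag(torch.tensor([1.0, 0.75, 1.5]))
    mvn = MultivariateNormal(mean, NonLazyTensor(covmat))
    base = mvn.get_base_samples(torch.Size([4]))
    # same base samples -> identical draws
    assert torch.equal(mvn.sample(base_samples=base), mvn.sample(base_samples=base))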
Example #4
 def test_multivariate_normal_batch_lazy(self, cuda=False):
     device = torch.device("cuda") if cuda else torch.device("cpu")
     for dtype in (torch.float, torch.double):
         mean = torch.tensor([0, 1, 2], device=device,
                             dtype=dtype).repeat(2, 1)
         covmat = torch.diag(
             torch.tensor([1, 0.75, 1.5], device=device,
                          dtype=dtype)).repeat(2, 1, 1)
         covmat_chol = torch.linalg.cholesky(covmat)
         mvn = MultivariateNormal(mean=mean,
                                  covariance_matrix=NonLazyTensor(covmat))
         self.assertTrue(torch.is_tensor(mvn.covariance_matrix))
         self.assertIsInstance(mvn.lazy_covariance_matrix, LazyTensor)
         self.assertAllClose(mvn.variance,
                             torch.diagonal(covmat, dim1=-2, dim2=-1))
         self.assertAllClose(mvn._unbroadcasted_scale_tril, covmat_chol)
         mvn_plus1 = mvn + 1
         self.assertAllClose(mvn_plus1.mean, mvn.mean + 1)
         self.assertAllClose(mvn_plus1.covariance_matrix,
                             mvn.covariance_matrix)
         self.assertAllClose(mvn_plus1._unbroadcasted_scale_tril,
                             covmat_chol)
         mvn_times2 = mvn * 2
         self.assertAllClose(mvn_times2.mean, mvn.mean * 2)
         self.assertAllClose(mvn_times2.covariance_matrix,
                             mvn.covariance_matrix * 4)
         self.assertAllClose(mvn_times2._unbroadcasted_scale_tril,
                             covmat_chol * 2)
         mvn_divby2 = mvn / 2
         self.assertAllClose(mvn_divby2.mean, mvn.mean / 2)
         self.assertAllClose(mvn_divby2.covariance_matrix,
                             mvn.covariance_matrix / 4)
         self.assertAllClose(mvn_divby2._unbroadcasted_scale_tril,
                             covmat_chol / 2)
         # TODO: Add tests for entropy, log_prob, etc. - this is an issue b/c
         # it uses root_decomposition, which is not very reliable
         # self.assertTrue(torch.allclose(mvn.entropy(), 4.3157 * torch.ones(2)))
         # self.assertTrue(
         #     torch.allclose(mvn.log_prob(torch.zeros(2, 3)), -4.8157 * torch.ones(2))
         # )
         # self.assertTrue(
         #     torch.allclose(mvn.log_prob(torch.zeros(2, 2, 3)), -4.8157 * torch.ones(2, 2))
         # )
         conf_lower, conf_upper = mvn.confidence_region()
         self.assertAllClose(conf_lower, mvn.mean - 2 * mvn.stddev)
         self.assertAllClose(conf_upper, mvn.mean + 2 * mvn.stddev)
         self.assertTrue(mvn.sample().shape == torch.Size([2, 3]))
         self.assertTrue(
             mvn.sample(torch.Size([2])).shape == torch.Size([2, 2, 3]))
         self.assertTrue(
             mvn.sample(torch.Size([2, 4])).shape == torch.Size(
                 [2, 4, 2, 3]))
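Why the test expects _unbroadcasted_scale_tril to scale linearly while the covariance scales quadratically: if Cov(x) = L L^T, then Cov(c*x) = c^2 L L^T = (cL)(cL)^T. A two-line numeric check:

    covmat = torch.diag(torch.tensor([1.0, 0.75, 1.5]))
    chol = torch.linalg.cholesky(covmat)
    assert torch.allclose(torch.linalg.cholesky(4.0 * covmat), 2.0 * chol)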
Example #5
 def test_multivariate_normal_non_lazy(self, cuda=False):
     device = torch.device("cuda") if cuda else torch.device("cpu")
     for dtype in (torch.float, torch.double):
         mean = torch.tensor([0, 1, 2], device=device, dtype=dtype)
         covmat = torch.diag(
             torch.tensor([1, 0.75, 1.5], device=device, dtype=dtype))
         mvn = MultivariateNormal(mean=mean,
                                  covariance_matrix=covmat,
                                  validate_args=True)
         self.assertTrue(torch.is_tensor(mvn.covariance_matrix))
         self.assertIsInstance(mvn.lazy_covariance_matrix, LazyTensor)
         self.assertTrue(torch.allclose(mvn.variance, torch.diag(covmat)))
         self.assertTrue(torch.allclose(mvn.scale_tril, covmat.sqrt()))
         mvn_plus1 = mvn + 1
         self.assertTrue(torch.equal(mvn_plus1.mean, mvn.mean + 1))
         self.assertTrue(
             torch.equal(mvn_plus1.covariance_matrix,
                         mvn.covariance_matrix))
         mvn_times2 = mvn * 2
         self.assertTrue(torch.equal(mvn_times2.mean, mvn.mean * 2))
         self.assertTrue(
             torch.equal(mvn_times2.covariance_matrix,
                         mvn.covariance_matrix * 4))
         mvn_divby2 = mvn / 2
         self.assertTrue(torch.equal(mvn_divby2.mean, mvn.mean / 2))
         self.assertTrue(
             torch.equal(mvn_divby2.covariance_matrix,
                         mvn.covariance_matrix / 4))
         self.assertAlmostEqual(mvn.entropy().item(), 4.3157, places=4)
         self.assertAlmostEqual(
             mvn.log_prob(torch.zeros(3, device=device, dtype=dtype)).item(),
             -4.8157, places=4)
         logprob = mvn.log_prob(
             torch.zeros(2, 3, device=device, dtype=dtype))
         logprob_expected = torch.tensor([-4.8157, -4.8157],
                                         device=device,
                                         dtype=dtype)
         self.assertTrue(torch.allclose(logprob, logprob_expected))
         conf_lower, conf_upper = mvn.confidence_region()
         self.assertTrue(
             torch.allclose(conf_lower, mvn.mean - 2 * mvn.stddev))
         self.assertTrue(
             torch.allclose(conf_upper, mvn.mean + 2 * mvn.stddev))
         self.assertTrue(mvn.sample().shape == torch.Size([3]))
         self.assertTrue(
             mvn.sample(torch.Size([2])).shape == torch.Size([2, 3]))
         self.assertTrue(
             mvn.sample(torch.Size([2, 4])).shape == torch.Size([2, 4, 3]))
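Where the hard-coded 4.3157 and -4.8157 come from (a worked check, not part of the original test): with Sigma = diag(1, 0.75, 1.5) and mean (0, 1, 2), entropy is 0.5 * log det(2*pi*e*Sigma), and log_prob at zero adds the Mahalanobis term:

    import math
    var = [1.0, 0.75, 1.5]
    logdet = sum(math.log(v) for v in var)                                # ~0.1178
    entropy = 0.5 * (3 * math.log(2 * math.pi * math.e) + logdet)         # ~4.3157
    mahalanobis = 0.0 / 1.0 + 1.0 / 0.75 + 4.0 / 1.5                      # = 4.0
    log_prob = -0.5 * (3 * math.log(2 * math.pi) + logdet + mahalanobis)  # ~-4.8157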
Example #6
 def test_multivariate_normal_correlated_samples(self, cuda=False):
     device = torch.device("cuda") if cuda else torch.device("cpu")
     for dtype in (torch.float, torch.double):
         mean = torch.tensor([0, 1, 2], device=device, dtype=dtype)
         covmat = torch.diag(
             torch.tensor([1, 0.75, 1.5], device=device, dtype=dtype))
         mvn = MultivariateNormal(mean=mean,
                                  covariance_matrix=NonLazyTensor(covmat))
         base_samples = mvn.get_base_samples(torch.Size([3, 4]))
         self.assertTrue(
             mvn.sample(
                 base_samples=base_samples).shape == torch.Size([3, 4, 3]))
         base_samples = mvn.get_base_samples()
         self.assertTrue(
             mvn.sample(base_samples=base_samples).shape == torch.Size([3]))
Example #7
    def forward(self, x):
        for d in range(self.depth - 1):
            mean_x = self.mean_module(x)
            covar_x = self.covar_module(x)
            mvn = MultivariateNormal(mean_x, covar_x)
            x = mvn.sample(sample_shape=torch.Size([self.dim]))
            x = x.t()
            if self.collect:
                self.collector[d] = x.detach().numpy()

        # last layer with single output
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        mvn = MultivariateNormal(mean_x, covar_x)
        x = mvn.sample()
        if self.collect:
            self.collector[self.depth - 1] = x.detach().numpy()
        return x
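Each pass through the loop above draws self.dim GP samples at the current inputs, giving a (dim, n) tensor that the transpose turns back into an (n, dim) input for the next layer; the final layer draws one sample of shape (n,). A shape sketch with hypothetical sizes (preamble imports assumed):

    n, dim = 10, 3
    mvn = MultivariateNormal(torch.zeros(n), NonLazyTensor(torch.eye(n)))
    x = mvn.sample(sample_shape=torch.Size([dim])).t()  # (dim, n) -> (n, dim)
    assert x.shape == torch.Size([n, dim])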
Example #8
    def test_missing_value_inference(self):
        """
        samples = mvn samples + noise samples
        In this test, we try to recover noise parameters when some elements in
        'samples' are missing at random.
        """

        torch.manual_seed(self.seed)

        mu = torch.zeros(2, 3)
        sigma = torch.tensor([[[1, 0.999, -0.999], [0.999, 1, -0.999], [-0.999, -0.999, 1]]] * 2).float()
        mvn = MultivariateNormal(mu, sigma)
        samples = mvn.sample(torch.Size([10000]))  # mvn samples

        noise_sd = 0.5
        noise_dist = torch.distributions.Normal(0, noise_sd)
        samples += noise_dist.sample(samples.shape)  # noise

        missing_prop = 0.33
        missing_idx = torch.distributions.Binomial(1, missing_prop).sample(samples.shape).bool()
        samples[missing_idx] = float("nan")

        likelihood = GaussianLikelihoodWithMissingObs()

        # check that the missing value fill doesn't impact the likelihood

        likelihood.MISSING_VALUE_FILL = 999.0
        like_init_plus = likelihood.log_marginal(samples, mvn).sum().data

        likelihood.MISSING_VALUE_FILL = -999.0
        like_init_minus = likelihood.log_marginal(samples, mvn).sum().data

        torch.testing.assert_allclose(like_init_plus, like_init_minus)

        # check that the correct noise sd is recovered

        opt = torch.optim.Adam(likelihood.parameters(), lr=0.05)

        for _ in range(100):
            opt.zero_grad()
            loss = -likelihood.log_marginal(samples, mvn).sum()
            loss.backward()
            opt.step()

        assert abs(float(likelihood.noise.sqrt()) - 0.5) < 0.02

        # Check log marginal works
        likelihood.log_marginal(samples[0], mvn)
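The fill-value check makes sense because GaussianLikelihoodWithMissingObs is expected to mask NaN entries before scoring, so whatever constant replaces them never reaches the log marginal. The masking idea in isolation (a sketch, not the library's internals):

    x = torch.tensor([1.0, float("nan"), 2.0])
    observed = ~torch.isnan(x)
    filled = torch.where(observed, x, torch.full_like(x, 999.0))  # fill is arbitrary
    # only filled[observed] should ever contribute to the likelihood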
Example #9
    def test_natgrad(self, D=5):
        mu = torch.randn(D)
        cov = torch.randn(D, D).tril_()
        dist = MultivariateNormal(mu, CholLazyTensor(TriangularLazyTensor(cov)))
        sample = dist.sample()

        v_dist = NaturalVariationalDistribution(D)
        v_dist.initialize_variational_distribution(dist)
        mu = v_dist().mean.detach()

        v_dist().log_prob(sample).squeeze().backward()

        eta1 = mu.clone().requires_grad_(True)
        eta2 = (mu[:, None] * mu + cov @ cov.t()).requires_grad_(True)
        L = torch.linalg.cholesky(eta2 - eta1[:, None] * eta1)
        dist2 = MultivariateNormal(eta1, CholLazyTensor(TriangularLazyTensor(L)))
        dist2.log_prob(sample).squeeze().backward()

        assert torch.allclose(v_dist.natural_vec.grad, eta1.grad)
        assert torch.allclose(v_dist.natural_mat.grad, eta2.grad)
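A note on what eta1/eta2 above really are: they are the Gaussian's expectation parameters, m1 = E[x] = mu and m2 = E[x x^T] = mu mu^T + Sigma, so cholesky(eta2 - eta1 eta1^T) recovers the scale factor. The test leans on the exponential-family duality: gradients left on the natural parameters should equal ordinary gradients taken with respect to these expectation parameters. The reconstruction step on its own:

    import torch
    mu, Sigma = torch.randn(5), torch.eye(5)
    m1 = mu                          # E[x]
    m2 = mu[:, None] * mu + Sigma    # E[x x^T]
    L = torch.linalg.cholesky(m2 - m1[:, None] * m1)  # recovers chol(Sigma)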
Example #10
    def test_natgrad(self, D=5):
        mu = torch.randn(D)
        cov = torch.randn(D, D)
        cov = cov @ cov.t()
        dist = MultivariateNormal(
            mu,
            CholLazyTensor(TriangularLazyTensor(torch.linalg.cholesky(cov))))
        sample = dist.sample()

        v_dist = TrilNaturalVariationalDistribution(D, mean_init_std=0.0)
        v_dist.initialize_variational_distribution(dist)
        v_dist().log_prob(sample).squeeze().backward()
        dout_dnat1 = v_dist.natural_vec.grad
        dout_dnat2 = v_dist.natural_tril_mat.grad

        # mean_init_std=0. because we need to ensure both have the same distribution
        v_dist_ref = NaturalVariationalDistribution(D, mean_init_std=0.0)
        v_dist_ref.initialize_variational_distribution(dist)
        v_dist_ref().log_prob(sample).squeeze().backward()
        dout_dnat1_noforward_ref = v_dist_ref.natural_vec.grad
        dout_dnat2_noforward_ref = v_dist_ref.natural_mat.grad

        def f(natural_vec, natural_tril_mat):
            "Transform natural_tril_mat to L"
            Sigma = torch.inverse(-2 * natural_tril_mat)
            mu = natural_vec
            return mu, torch.linalg.cholesky(Sigma).inverse().tril()

        (mu_ref, natural_tril_mat_ref), (dout_dmu_ref, dout_dnat2_ref) = jvp(
            f,
            (v_dist_ref.natural_vec.detach(), v_dist_ref.natural_mat.detach()),
            (dout_dnat1_noforward_ref, dout_dnat2_noforward_ref),
        )

        assert torch.allclose(natural_tril_mat_ref,
                              v_dist.natural_tril_mat), "Sigma transformation"
        assert torch.allclose(dout_dnat2_ref, dout_dnat2), "Sigma gradient"

        assert torch.allclose(mu_ref, v_dist.natural_vec), "mu transformation"
        assert torch.allclose(dout_dmu_ref, dout_dnat1), "mu gradient"
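jvp here is presumably torch.autograd.functional.jvp (its import is omitted above), which returns both f(inputs) and the Jacobian-vector product along the given tangents. A toy call showing the contract:

    import torch
    from torch.autograd.functional import jvp
    out, jv = jvp(lambda x: x ** 2, (torch.tensor(3.0),), (torch.tensor(1.0),))
    # out = 9.0 (function value), jv = 6.0 (derivative 2x times the tangent 1.0)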
Example #11
 def _initialize_latents(
     self,
     latent_init: str,
     num_latent_dims: List[int],
     learn_latent_pars: bool,
     device: torch.device,
     dtype: torch.dtype,
 ):
     self.latent_parameters = ParameterList()
     if latent_init == "default":
         for dim_num in range(len(self.covar_modules) - 1):
             self.latent_parameters.append(
                 Parameter(
                     torch.rand(
                         *self._aug_batch_shape,
                         self.target_shape[dim_num],
                         num_latent_dims[dim_num],
                         device=device,
                         dtype=dtype,
                     ),
                     requires_grad=learn_latent_pars,
                 )
             )
     elif latent_init == "gp":
         for dim_num, covar in enumerate(self.covar_modules[1:]):
             latent_covar = covar(
                 torch.linspace(
                     0.0,
                     1.0,
                     self.target_shape[dim_num],
                     device=device,
                     dtype=dtype,
                 )
             ).add_jitter(1e-4)
             latent_dist = MultivariateNormal(
                 torch.zeros(
                     self.target_shape[dim_num],
                     device=device,
                     dtype=dtype,
                 ),
                 latent_covar,
             )
             sample_shape = torch.Size(
                 (
                     *self._aug_batch_shape,
                     num_latent_dims[dim_num],
                 )
             )
             latent_sample = latent_dist.sample(sample_shape=sample_shape)
             latent_sample = latent_sample.reshape(
                 *self._aug_batch_shape,
                 self.target_shape[dim_num],
                 num_latent_dims[dim_num],
             )
             self.latent_parameters.append(
                 Parameter(
                     latent_sample,
                     requires_grad=learn_latent_pars,
                 )
             )
             self.register_prior(
                 "latent_parameters_" + str(dim_num),
                 MultivariateNormalPrior(
                     latent_dist.loc, latent_dist.covariance_matrix.detach().clone()
                 ),
                 lambda module, dim_num=dim_num: self.latent_parameters[dim_num],
             )
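The "gp" branch above initializes each latent block by sampling a GP prior evaluated on a uniform grid in [0, 1]. A stripped-down sketch of that idea; MaternKernel is a stand-in here, since the real covar_modules come from the surrounding model:

    import torch
    from gpytorch.kernels import MaternKernel
    from gpytorch.distributions import MultivariateNormal

    grid = torch.linspace(0.0, 1.0, 8)
    latent_covar = MaternKernel()(grid).add_jitter(1e-4)
    latent_dist = MultivariateNormal(torch.zeros(8), latent_covar)
    latents = latent_dist.sample(torch.Size([2]))  # two latent columns, shape (2, 8)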