def forward(self, x):
     """Compute the resulting batch-distribution."""
     return MultivariateNormal(self.mean(x), self.kernel(x))
Example #2
    def posterior(
        self,
        X: Tensor,
        output_indices: Optional[List[int]] = None,
        observation_noise: Union[bool, Tensor] = False,
        **kwargs: Any,
    ) -> GPyTorchPosterior:
        self.eval()  # make sure we're calling a posterior

        no_pred_variance = skip_posterior_variances._state

        with ExitStack() as es:
            es.enter_context(gpt_posterior_settings())
            es.enter_context(fast_pred_var(True))

            # we need to skip posterior variances here
            es.enter_context(skip_posterior_variances(True))
            mvn = self(X)
            if observation_noise is not False:
                # TODO: implement Kronecker + diagonal solves so that this is possible.
                # if torch.is_tensor(observation_noise):
                #     # TODO: Validate noise shape
                #     # make observation_noise `batch_shape x q x n`
                #     obs_noise = observation_noise.transpose(-1, -2)
                #     mvn = self.likelihood(mvn, X, noise=obs_noise)
                # elif isinstance(self.likelihood, FixedNoiseGaussianLikelihood):
                #     noise = self.likelihood.noise.mean().expand(X.shape[:-1])
                #     mvn = self.likelihood(mvn, X, noise=noise)
                # else:
                mvn = self.likelihood(mvn, X)

            # The lazy covariance matrix contains the interpolated version of the
            # full covariance matrix, so we can grab that instead.
            if X.ndimension() > self.train_inputs[0].ndimension():
                X_batch_shape = X.shape[:-2]
                train_inputs = self.train_inputs[0].reshape(
                    *[1] * len(X_batch_shape), *self.train_inputs[0].shape)
                train_inputs = train_inputs.repeat(
                    *X_batch_shape, *[1] * self.train_inputs[0].ndimension())
            else:
                train_inputs = self.train_inputs[0]
            full_covar = self.covar_modules[0](torch.cat((train_inputs, X),
                                                         dim=-2))

            if no_pred_variance:
                pred_variance = mvn.variance
            else:
                # We detach all of the latent-dimension posteriors, which precludes
                # computing gradients of posterior quantities with respect to the
                # latents, but reduces the memory overhead somewhat.
                # TODO: add these back in if necessary
                joint_covar = self._get_joint_covariance([X])
                pred_variance = self.make_posterior_variances(joint_covar)

                full_covar = KroneckerProductLazyTensor(
                    full_covar,
                    *[x.detach() for x in joint_covar.lazy_tensors[1:]])

            joint_covar_list = [self.covar_modules[0](X, train_inputs)]
            batch_shape = joint_covar_list[0].batch_shape
            for cm, param in zip(self.covar_modules[1:],
                                 self.latent_parameters):
                covar = cm(param).detach()
                if covar.batch_shape != batch_shape:
                    covar = BatchRepeatLazyTensor(covar, batch_shape)
                joint_covar_list.append(covar)

            test_train_covar = KroneckerProductLazyTensor(*joint_covar_list)

            # mean and variance get reshaped into the target shape
            new_mean = mvn.mean.reshape(*X.shape[:-1], *self.target_shape)
            if not no_pred_variance:
                new_variance = pred_variance.reshape(*X.shape[:-1],
                                                     *self.target_shape)
                new_variance = DiagLazyTensor(new_variance)
            else:
                new_variance = ZeroLazyTensor(*X.shape[:-1],
                                              *self.target_shape,
                                              self.target_shape[-1])

            mvn = MultivariateNormal(new_mean, new_variance)

            train_train_covar = self.prediction_strategy.lik_train_train_covar.detach()

            # return a specialized Posterior to allow for sampling
            # cloning the full covar allows backpropagation through it
            posterior = HigherOrderGPPosterior(
                mvn=mvn,
                train_targets=self.train_targets.unsqueeze(-1),
                train_train_covar=train_train_covar,
                test_train_covar=test_train_covar,
                joint_covariance_matrix=full_covar.clone(),
                output_shape=Size((
                    *X.shape[:-1],
                    *self.target_shape,
                )),
                num_outputs=self._num_outputs,
            )
            if hasattr(self, "outcome_transform"):
                posterior = self.outcome_transform.untransform_posterior(
                    posterior)

            return posterior
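The posterior above assembles its joint covariance as a Kronecker product over the data dimension and the latent dimensions. A minimal, stand-alone illustration of the KroneckerProductLazyTensor structure it relies on (small dense factors, purely for demonstration; not part of the model code):

import torch
from gpytorch.lazy import KroneckerProductLazyTensor, lazify

# Two small covariance factors: a 3 x 3 "data" block and a 2 x 2 "latent" block.
A = torch.eye(3) * 2.0
B = torch.tensor([[1.0, 0.5], [0.5, 1.0]])

# The Kronecker product is represented lazily; only the factors are stored.
K = KroneckerProductLazyTensor(lazify(A), lazify(B))
assert K.shape == torch.Size([6, 6])

# Evaluating materializes the dense matrix; entry (i*2 + k, j*2 + l) is A[i, j] * B[k, l].
dense = K.evaluate()
assert torch.isclose(dense[0, 1], A[0, 0] * B[0, 1])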
Example #3
    def test_expected_improvement_batch(self):
        for dtype in (torch.float, torch.double):
            mean = torch.tensor([-0.5, 0.0, 0.5],
                                device=self.device,
                                dtype=dtype).view(3, 1, 1)
            variance = torch.ones(3, 1, 1, device=self.device, dtype=dtype)
            mm = MockModel(MockPosterior(mean=mean, variance=variance))
            module = ExpectedImprovement(model=mm, best_f=0.0)
            X = torch.empty(3, 1, 1, device=self.device, dtype=dtype)  # dummy
            ei = module(X)
            ei_expected = torch.tensor([0.19780, 0.39894, 0.69780],
                                       device=self.device,
                                       dtype=dtype)
            self.assertTrue(torch.allclose(ei, ei_expected, atol=1e-4))
            # check for proper error if multi-output model
            mean2 = torch.rand(3, 1, 2, device=self.device, dtype=dtype)
            variance2 = torch.rand(3, 1, 2, device=self.device, dtype=dtype)
            mm2 = MockModel(MockPosterior(mean=mean2, variance=variance2))
            with self.assertRaises(UnsupportedError):
                ExpectedImprovement(model=mm2, best_f=0.0)

            # test objective (single-output)
            mean = torch.tensor([[[0.5]], [[0.25]]],
                                device=self.device,
                                dtype=dtype)
            covar = torch.tensor([[[[0.16]]], [[[0.125]]]],
                                 device=self.device,
                                 dtype=dtype)
            mvn = MultivariateNormal(mean, covar)
            p = GPyTorchPosterior(mvn)
            mm = MockModel(p)
            weights = torch.tensor([0.5], device=self.device, dtype=dtype)
            obj = ScalarizedObjective(weights)
            ei = ExpectedImprovement(model=mm, best_f=0.0, objective=obj)
            X = torch.rand(2, 1, 2, device=self.device, dtype=dtype)
            ei_expected = torch.tensor([[0.2601], [0.1500]],
                                       device=self.device,
                                       dtype=dtype)
            self.assertTrue(torch.allclose(ei(X), ei_expected, atol=1e-4))

            # test objective (multi-output)
            mean = torch.tensor([[[-0.25, 0.5]], [[0.2, -0.1]]],
                                device=self.device,
                                dtype=dtype)
            covar = torch.tensor(
                [[[0.5, 0.125], [0.125, 0.5]], [[0.25, -0.1], [-0.1, 0.25]]],
                device=self.device,
                dtype=dtype,
            )
            mvn = MultitaskMultivariateNormal(mean, covar)
            p = GPyTorchPosterior(mvn)
            mm = MockModel(p)
            weights = torch.tensor([2.0, 1.0], device=self.device, dtype=dtype)
            obj = ScalarizedObjective(weights)
            ei = ExpectedImprovement(model=mm, best_f=0.0, objective=obj)
            X = torch.rand(2, 1, 2, device=self.device, dtype=dtype)
            ei_expected = torch.tensor([0.6910, 0.5371],
                                       device=self.device,
                                       dtype=dtype)
            self.assertTrue(torch.allclose(ei(X), ei_expected, atol=1e-4))

        # test bad objective class
        with self.assertRaises(UnsupportedError):
            ExpectedImprovement(model=mm,
                                best_f=0.0,
                                objective=IdentityMCObjective())
Example #4
def _get_test_posterior_batched(device, dtype=torch.float):
    mean = torch.zeros(3, 2, device=device, dtype=dtype)
    cov = torch.eye(2, device=device, dtype=dtype).repeat(3, 1, 1)
    mvn = MultivariateNormal(mean, cov)
    return GPyTorchPosterior(mvn)
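A rough usage sketch of the helper above (the commented shapes follow BoTorch's convention of appending an output dimension to single-output posteriors and are expectations, not guarantees):

posterior = _get_test_posterior_batched(device=torch.device("cpu"))
print(posterior.mean.shape)   # expected: torch.Size([3, 2, 1])
samples = posterior.rsample(sample_shape=torch.Size([4]))
print(samples.shape)          # expected: torch.Size([4, 3, 2, 1])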
Example #5
 def forward(self, x):
     return MultivariateNormal(self.mean_module(x), self.covar_module(x))
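Forward methods like the one above normally live inside a gpytorch.models.ExactGP subclass. A minimal sketch for context (constant mean, RBF kernel; the class and data names are illustrative, not taken from the examples on this page):

import torch
import gpytorch
from gpytorch.distributions import MultivariateNormal


class SimpleExactGP(gpytorch.models.ExactGP):
    """Minimal exact GP regression model, purely illustrative."""

    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        # Return a MultivariateNormal over function values at the inputs x.
        return MultivariateNormal(self.mean_module(x), self.covar_module(x))


# Toy usage:
train_x = torch.linspace(0, 1, 10)
train_y = torch.sin(6.28 * train_x) + 0.1 * torch.randn(10)
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = SimpleExactGP(train_x, train_y, likelihood)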
Example #6
def scalarize_posterior(posterior: GPyTorchPosterior,
                        weights: Tensor,
                        offset: float = 0.0) -> GPyTorchPosterior:
    r"""Affine transformation of a multi-output posterior.

    Args:
        posterior: The posterior over `m` outcomes to be scalarized.
            Supports `t`-batching.
        weights: A tensor of weights of size `m`.
        offset: The offset of the affine transformation.

    Returns:
        The transformed (single-output) posterior. If the input posterior has
            mean `mu` and covariance matrix `Sigma`, this posterior has mean
            `weights^T * mu` and variance `weights^T Sigma weights`.

    Example:
        Example for a model with two outcomes:

        >>> X = torch.rand(1, 2)
        >>> posterior = model.posterior(X)
        >>> weights = torch.tensor([0.5, 0.25])
        >>> new_posterior = scalarize_posterior(posterior, weights=weights)
    """
    if weights.ndim > 1:
        raise BotorchTensorDimensionError("`weights` must be one-dimensional")
    mean = posterior.mean
    q, m = mean.shape[-2:]
    batch_shape = mean.shape[:-2]
    if m != weights.size(0):
        raise RuntimeError("Output shape not equal to that of weights")
    mvn = posterior.mvn
    cov = mvn.lazy_covariance_matrix if mvn.islazy else mvn.covariance_matrix

    if m == 1:  # just scaling, no scalarization necessary
        new_mean = offset + (weights[0] * mean).view(*batch_shape, q)
        new_cov = weights[0]**2 * cov
        new_mvn = MultivariateNormal(new_mean, new_cov)
        return GPyTorchPosterior(new_mvn)

    new_mean = offset + (mean @ weights).view(*batch_shape, q)

    if q == 1:
        new_cov = weights.unsqueeze(-2) @ (cov @ weights.unsqueeze(-1))
    else:
        # we need to handle potentially different representations of the multi-task mvn
        if mvn._interleaved:
            w_cov = weights.repeat(q).unsqueeze(0)
            sum_shape = batch_shape + torch.Size([q, m, q, m])
            sum_dims = (-1, -2)
        else:
            # special-case the independent setting
            if isinstance(cov, BlockDiagLazyTensor):
                new_cov = SumLazyTensor(*[
                    cov.base_lazy_tensor[..., i, :, :] * weights[i].pow(2)
                    for i in range(cov.base_lazy_tensor.size(-3))
                ])
                new_mvn = MultivariateNormal(new_mean, new_cov)
                return GPyTorchPosterior(new_mvn)

            w_cov = torch.repeat_interleave(weights, q).unsqueeze(0)
            sum_shape = batch_shape + torch.Size([m, q, m, q])
            sum_dims = (-2, -3)

        cov_scaled = w_cov * cov * w_cov.transpose(-1, -2)
        # TODO: Do not instantiate full covariance for lazy tensors (ideally we simplify
        # this in GPyTorch: https://github.com/cornellius-gp/gpytorch/issues/1055)
        if isinstance(cov_scaled, LazyTensor):
            cov_scaled = cov_scaled.evaluate()
        new_cov = cov_scaled.view(sum_shape).sum(dim=sum_dims[0]).sum(
            dim=sum_dims[1])

    new_mvn = MultivariateNormal(new_mean, new_cov)
    return GPyTorchPosterior(new_mvn)
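A self-contained check of the affine transform (building a two-outcome posterior by hand rather than from a model; the import path assumes BoTorch's botorch.posteriors.gpytorch module):

import torch
from gpytorch.distributions import MultitaskMultivariateNormal
from botorch.posteriors.gpytorch import GPyTorchPosterior, scalarize_posterior

# One test point (q=1) with m=2 outcomes: mean mu and covariance Sigma.
mu = torch.tensor([[0.5, 1.0]])                   # q x m
Sigma = torch.tensor([[0.2, 0.05], [0.05, 0.3]])  # (q*m) x (q*m)
posterior = GPyTorchPosterior(MultitaskMultivariateNormal(mu, Sigma))

weights = torch.tensor([0.5, 0.25])
new_posterior = scalarize_posterior(posterior, weights=weights, offset=0.1)

# Expected mean: offset + weights^T mu = 0.1 + 0.25 + 0.25 = 0.6
# Expected variance: weights^T Sigma weights = 0.05 + 0.0125 + 0.01875 = 0.08125
print(new_posterior.mean, new_posterior.variance)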
Example #7
 def test_construct_base_samples_from_posterior(self, cuda=False):
     device = torch.device("cuda") if cuda else torch.device("cpu")
     for dtype in (torch.float, torch.double):
         # single-output
         mean = torch.zeros(2, device=device, dtype=dtype)
         cov = torch.eye(2, device=device, dtype=dtype)
         mvn = MultivariateNormal(mean=mean, covariance_matrix=cov)
         posterior = GPyTorchPosterior(mvn=mvn)
         for sample_shape in (torch.Size([5]), torch.Size([5, 3])):
             for qmc in (False, True):
                 for seed in (None, 1234):
                     expected_shape = sample_shape + torch.Size([2, 1])
                     samples = construct_base_samples_from_posterior(
                         posterior=posterior,
                         sample_shape=sample_shape,
                         qmc=qmc,
                         seed=seed,
                     )
                     self.assertEqual(samples.shape, expected_shape)
                     self.assertEqual(samples.device.type, device.type)
                     self.assertEqual(samples.dtype, dtype)
         # single-output, batch mode
         mean = torch.zeros(2, 2, device=device, dtype=dtype)
         cov = torch.eye(2, device=device, dtype=dtype).expand(2, 2, 2)
         mvn = MultivariateNormal(mean=mean, covariance_matrix=cov)
         posterior = GPyTorchPosterior(mvn=mvn)
         for sample_shape in (torch.Size([5]), torch.Size([5, 3])):
             for qmc in (False, True):
                 for seed in (None, 1234):
                     for collapse_batch_dims in (False, True):
                         if collapse_batch_dims:
                             expected_shape = sample_shape + torch.Size([1, 2, 1])
                         else:
                             expected_shape = sample_shape + torch.Size([2, 2, 1])
                         samples = construct_base_samples_from_posterior(
                             posterior=posterior,
                             sample_shape=sample_shape,
                             qmc=qmc,
                             collapse_batch_dims=collapse_batch_dims,
                             seed=seed,
                         )
                         self.assertEqual(samples.shape, expected_shape)
                         self.assertEqual(samples.device.type, device.type)
                         self.assertEqual(samples.dtype, dtype)
         # multi-output
         mean = torch.zeros(2, 2, device=device, dtype=dtype)
         cov = torch.eye(4, device=device, dtype=dtype)
         mtmvn = MultitaskMultivariateNormal(mean=mean, covariance_matrix=cov)
         posterior = GPyTorchPosterior(mvn=mtmvn)
         for sample_shape in (torch.Size([5]), torch.Size([5, 3])):
             for qmc in (False, True):
                 for seed in (None, 1234):
                     expected_shape = sample_shape + torch.Size([2, 2])
                     samples = construct_base_samples_from_posterior(
                         posterior=posterior,
                         sample_shape=sample_shape,
                         qmc=qmc,
                         seed=seed,
                     )
                     self.assertEqual(samples.shape, expected_shape)
                     self.assertEqual(samples.device.type, device.type)
                     self.assertEqual(samples.dtype, dtype)
         # multi-output, batch mode
         mean = torch.zeros(2, 2, 2, device=device, dtype=dtype)
         cov = torch.eye(4, device=device, dtype=dtype).expand(2, 4, 4)
         mtmvn = MultitaskMultivariateNormal(mean=mean, covariance_matrix=cov)
         posterior = GPyTorchPosterior(mvn=mtmvn)
         for sample_shape in (torch.Size([5]), torch.Size([5, 3])):
             for qmc in (False, True):
                 for seed in (None, 1234):
                     for collapse_batch_dims in (False, True):
                         if collapse_batch_dims:
                             expected_shape = sample_shape + torch.Size([1, 2, 2])
                         else:
                             expected_shape = sample_shape + torch.Size([2, 2, 2])
                         samples = construct_base_samples_from_posterior(
                             posterior=posterior,
                             sample_shape=sample_shape,
                             qmc=qmc,
                             collapse_batch_dims=collapse_batch_dims,
                             seed=seed,
                         )
                         self.assertEqual(samples.shape, expected_shape)
                         self.assertEqual(samples.device.type, device.type)
                         self.assertEqual(samples.dtype, dtype)
Example #8
 def forward(self, x):
     mean = torch.zeros(torch.Size([x.size(0)]),
                        dtype=x.dtype, device=x.device)
     return MultivariateNormal(mean, gpytorch.lazy.RootLazyTensor(x))
Example #9
 def test_degenerate_GPyTorchPosterior_Multitask(self):
     for dtype in (torch.float, torch.double):
         # singular covariance matrix
         degenerate_covar = torch.tensor([[1, 1, 0], [1, 1, 0], [0, 0, 2]],
                                         dtype=dtype,
                                         device=self.device)
         mean = torch.rand(3, dtype=dtype, device=self.device)
         mvn = MultivariateNormal(mean, lazify(degenerate_covar))
         mvn = MultitaskMultivariateNormal.from_independent_mvns([mvn, mvn])
         posterior = GPyTorchPosterior(mvn=mvn)
         # basics
         self.assertEqual(posterior.device.type, self.device.type)
         self.assertTrue(posterior.dtype == dtype)
         self.assertEqual(posterior.event_shape, torch.Size([3, 2]))
         mean_exp = mean.unsqueeze(-1).repeat(1, 2)
         self.assertTrue(torch.equal(posterior.mean, mean_exp))
         variance_exp = degenerate_covar.diag().unsqueeze(-1).repeat(1, 2)
         self.assertTrue(torch.equal(posterior.variance, variance_exp))
         # rsample
         with warnings.catch_warnings(record=True) as w:
             # we check that the p.d. warning is emitted - this only
             # happens once per posterior, so we need to check only once
             samples = posterior.rsample(sample_shape=torch.Size([4]))
             self.assertEqual(len(w), 1)
             self.assertTrue(issubclass(w[-1].category, RuntimeWarning))
             self.assertTrue("not p.d." in str(w[-1].message))
         self.assertEqual(samples.shape, torch.Size([4, 3, 2]))
         samples2 = posterior.rsample(sample_shape=torch.Size([4, 2]))
         self.assertEqual(samples2.shape, torch.Size([4, 2, 3, 2]))
         # rsample w/ base samples
         base_samples = torch.randn(4,
                                    3,
                                    2,
                                    device=self.device,
                                    dtype=dtype)
         samples_b1 = posterior.rsample(sample_shape=torch.Size([4]),
                                        base_samples=base_samples)
         samples_b2 = posterior.rsample(sample_shape=torch.Size([4]),
                                        base_samples=base_samples)
         self.assertTrue(torch.allclose(samples_b1, samples_b2))
         base_samples2 = torch.randn(4,
                                     2,
                                     3,
                                     2,
                                     device=self.device,
                                     dtype=dtype)
         samples2_b1 = posterior.rsample(sample_shape=torch.Size([4, 2]),
                                         base_samples=base_samples2)
         samples2_b2 = posterior.rsample(sample_shape=torch.Size([4, 2]),
                                         base_samples=base_samples2)
         self.assertTrue(torch.allclose(samples2_b1, samples2_b2))
         # collapse_batch_dims
         b_mean = torch.rand(2, 3, dtype=dtype, device=self.device)
         b_degenerate_covar = degenerate_covar.expand(
             2, *degenerate_covar.shape)
         b_mvn = MultivariateNormal(b_mean, lazify(b_degenerate_covar))
         b_mvn = MultitaskMultivariateNormal.from_independent_mvns(
             [b_mvn, b_mvn])
         b_posterior = GPyTorchPosterior(mvn=b_mvn)
         b_base_samples = torch.randn(4,
                                      1,
                                      3,
                                      2,
                                      device=self.device,
                                      dtype=dtype)
         with warnings.catch_warnings(record=True) as w:
             b_samples = b_posterior.rsample(sample_shape=torch.Size([4]),
                                             base_samples=b_base_samples)
             self.assertEqual(len(w), 1)
             self.assertTrue(issubclass(w[-1].category, RuntimeWarning))
             self.assertTrue("not p.d." in str(w[-1].message))
         self.assertEqual(b_samples.shape, torch.Size([4, 2, 3, 2]))
Example #10
    def posterior(
        self,
        X: Tensor,
        output_indices: Optional[List[int]] = None,
        observation_noise: Union[bool, Tensor] = False,
        **kwargs: Any,
    ) -> GPyTorchPosterior:
        self.eval()  # make sure we're calling a posterior
        # input transforms are applied at `posterior` in `eval` mode, and at
        # `model.forward()` at training time
        X = self.transform_inputs(X)
        no_pred_variance = skip_posterior_variances._state

        with ExitStack() as es:
            es.enter_context(gpt_posterior_settings())
            es.enter_context(fast_pred_var(True))

            # we need to skip posterior variances here
            es.enter_context(skip_posterior_variances(True))
            mvn = self(X)
            if observation_noise is not False:
                # TODO: ensure that this still works for structured noise solves.
                mvn = self.likelihood(mvn, X)

            # The lazy covariance matrix contains the interpolated version of the
            # full covariance matrix, so we can grab that instead.
            if X.ndimension() > self.train_inputs[0].ndimension():
                X_batch_shape = X.shape[:-2]
                train_inputs = self.train_inputs[0].reshape(
                    *[1] * len(X_batch_shape), *self.train_inputs[0].shape
                )
                train_inputs = train_inputs.repeat(
                    *X_batch_shape, *[1] * self.train_inputs[0].ndimension()
                )
            else:
                train_inputs = self.train_inputs[0]

            # we now compute the data covariances for the training data, the testing
            # data, the joint covariances, and the test train cross-covariance
            train_train_covar = self.prediction_strategy.lik_train_train_covar.detach()
            base_train_train_covar = train_train_covar.lazy_tensor

            data_train_covar = base_train_train_covar.lazy_tensors[0]
            data_covar = self.covar_modules[0]
            data_train_test_covar = data_covar(X, train_inputs)
            data_test_test_covar = data_covar(X)
            data_joint_covar = data_train_covar.cat_rows(
                cross_mat=data_train_test_covar,
                new_mat=data_test_test_covar,
            )

            # we detach the latents so that they don't cause gradient errors
            # TODO: Can we enable backprop through the latent covariances?
            batch_shape = data_train_test_covar.batch_shape
            latent_covar_list = []
            for latent_covar in base_train_train_covar.lazy_tensors[1:]:
                if latent_covar.batch_shape != batch_shape:
                    latent_covar = BatchRepeatLazyTensor(latent_covar, batch_shape)
                latent_covar_list.append(latent_covar.detach())

            joint_covar = KroneckerProductLazyTensor(
                data_joint_covar, *latent_covar_list
            )
            test_train_covar = KroneckerProductLazyTensor(
                data_train_test_covar, *latent_covar_list
            )

            # compute the posterior variance if necessary
            if no_pred_variance:
                pred_variance = mvn.variance
            else:
                pred_variance = self.make_posterior_variances(joint_covar)

            # mean and variance get reshaped into the target shape
            new_mean = mvn.mean.reshape(*X.shape[:-1], *self.target_shape)
            if not no_pred_variance:
                new_variance = pred_variance.reshape(*X.shape[:-1], *self.target_shape)
                new_variance = DiagLazyTensor(new_variance)
            else:
                new_variance = ZeroLazyTensor(
                    *X.shape[:-1], *self.target_shape, self.target_shape[-1]
                )

            mvn = MultivariateNormal(new_mean, new_variance)

            # return a specialized Posterior to allow for sampling
            # cloning the full covar allows backpropagation through it
            posterior = HigherOrderGPPosterior(
                mvn=mvn,
                train_targets=self.train_targets.unsqueeze(-1),
                train_train_covar=train_train_covar,
                test_train_covar=test_train_covar,
                joint_covariance_matrix=joint_covar.clone(),
                output_shape=X.shape[:-1] + self.target_shape,
                num_outputs=self._num_outputs,
            )
            if hasattr(self, "outcome_transform"):
                posterior = self.outcome_transform.untransform_posterior(posterior)

            return posterior
Example #11
 def forward(self, state_input: Tensor) -> MultivariateNormal:
     """Forward call of GP class."""
     mean_x = self.mean_module(state_input)
     covar_x = self.covar_module(state_input)
     return MultivariateNormal(mean_x, covar_x)
Example #12
 def forward(self, x: torch.Tensor) -> MultivariateNormal:
     mean_x = self.mean_module(x)
     covar_x = self.covar_module(x)
     return MultivariateNormal(mean_x, covar_x)
Example #13
    def test_added_diag_lt(self, N=10000, p=20, use_cuda=False, seed=1):

        torch.manual_seed(seed)

        if torch.cuda.is_available() and use_cuda:
            print("Using cuda")
            device = torch.device("cuda")
            torch.cuda.manual_seed_all(seed)
        else:
            device = torch.device("cpu")

        D = torch.randn(N, p, device=device)
        A = torch.randn(N, device=device).abs() * 1e-3 + 0.1

        # this is a lazy tensor for DD'
        D_lt = RootLazyTensor(D)

        # this is a lazy tensor for diag(A)
        diag_term = DiagLazyTensor(A)

        # DD' + diag(A)
        lowrank_pdiag_lt = AddedDiagLazyTensor(diag_term, D_lt)

        # z \sim N(0,I), mean = 1
        z = torch.randn(N, device=device)
        mean = torch.ones(N, device=device)

        diff = mean - z

        logdet = lowrank_pdiag_lt.log_det()
        print(logdet)
        inv_matmul = lowrank_pdiag_lt.inv_matmul(diff.unsqueeze(1)).squeeze(1)
        inv_matmul_quad = torch.dot(diff, inv_matmul)
        """inv_matmul_quad_qld, logdet_qld = lowrank_pdiag_lt.inv_quad_log_det(inv_quad_rhs=diff.unsqueeze(1), log_det = True)
        
        """
        """from gpytorch.functions._inv_quad_log_det import InvQuadLogDet
        iqld_construct = InvQuadLogDet(gpytorch.lazy.lazy_tensor_representation_tree.LazyTensorRepresentationTree(lowrank_pdiag_lt),
                            matrix_shape=lowrank_pdiag_lt.matrix_shape,
                            dtype=lowrank_pdiag_lt.dtype,
                            device=lowrank_pdiag_lt.device,
                            inv_quad=True,
                            log_det=True,
                            preconditioner=lowrank_pdiag_lt._preconditioner()[0],
                            log_det_correction=lowrank_pdiag_lt._preconditioner()[1])
        inv_matmul_quad_qld, logdet_qld = iqld_construct(diff.unsqueeze(1))"""
        num_random_probes = gpytorch.settings.num_trace_samples.value()
        probe_vectors = torch.empty(
            lowrank_pdiag_lt.matrix_shape[-1],
            num_random_probes,
            dtype=lowrank_pdiag_lt.dtype,
            device=lowrank_pdiag_lt.device,
        )
        probe_vectors.bernoulli_().mul_(2).add_(-1)
        probe_vector_norms = torch.norm(probe_vectors, 2, dim=-2, keepdim=True)
        probe_vectors = probe_vectors.div(probe_vector_norms)

        # diff_norm = diff.norm()
        # diff = diff/diff_norm
        rhs = torch.cat([diff.unsqueeze(1), probe_vectors], dim=1)

        solves, t_mat = gpytorch.utils.linear_cg(
            lowrank_pdiag_lt.matmul,
            rhs,
            n_tridiag=num_random_probes,
            max_iter=gpytorch.settings.max_cg_iterations.value(),
            max_tridiag_iter=gpytorch.settings.max_lanczos_quadrature_iterations.value(),
            preconditioner=lowrank_pdiag_lt._preconditioner()[0],
        )
        # print(solves)
        inv_matmul_qld = solves[:, 0]  # * diff_norm

        diff_solve = gpytorch.utils.linear_cg(
            lowrank_pdiag_lt.matmul,
            diff.unsqueeze(1),
            max_iter=gpytorch.settings.max_cg_iterations.value(),
            preconditioner=lowrank_pdiag_lt._preconditioner()[0],
        )
        print("diff_solve_norm: ", diff_solve.norm())
        print(
            "diff between multiple linear_cg: ",
            (inv_matmul_qld.unsqueeze(1) - diff_solve).norm() /
            diff_solve.norm(),
        )

        eigenvalues, eigenvectors = gpytorch.utils.lanczos.lanczos_tridiag_to_diag(
            t_mat)
        slq = gpytorch.utils.StochasticLQ()
        log_det_term, = slq.evaluate(
            lowrank_pdiag_lt.matrix_shape,
            eigenvalues,
            eigenvectors,
            [lambda x: x.log()],
        )
        logdet_qld = log_det_term + lowrank_pdiag_lt._preconditioner()[1]

        print("Log det difference: ",
              (logdet - logdet_qld).norm() / logdet.norm())
        print(
            "inv matmul difference: ",
            (inv_matmul - inv_matmul_qld).norm() / inv_matmul_quad.norm(),
        )

        # N(1, DD' + diag(A))
        lazydist = MultivariateNormal(mean, lowrank_pdiag_lt)
        lazy_lprob = lazydist.log_prob(z)

        # exact log probability with Cholesky decomposition
        exact_dist = torch.distributions.MultivariateNormal(
            mean,
            lowrank_pdiag_lt.evaluate().float())
        exact_lprob = exact_dist.log_prob(z)

        print(lazy_lprob, exact_lprob)
        rel_error = torch.norm(lazy_lprob - exact_lprob) / exact_lprob.norm()

        self.assertLess(rel_error.cpu().item(), 0.01)
Example #14
File: ssm.py Project: zhidilin/GPSSMtorch
    def forward(self, *inputs: Tensor, **kwargs
                ) -> Tuple[List[MultivariateNormal], Tensor]:
        """Forward propagate the model.

        Parameters
        ----------
        inputs: Tuple[Tensor, Tensor].
            output_sequence: Tensor of output data
                [batch_size x sequence_length x dim_outputs].
            input_sequence: Tensor of input data
                [batch_size x sequence_length x dim_inputs].

        Returns
        -------
        outputs: List[MultivariateNormal].
            List of length sequence_length of output distributions of size
            [batch_size x dim_outputs x num_particles].
        loss: Tensor.
            Scalar training loss selected by ``self.loss_key``.
        """
        output_sequence, input_sequence = inputs
        num_particles = self.num_particles
        # dim_states = self.dim_states
        batch_size, sequence_length, dim_inputs = input_sequence.shape
        _, _, dim_outputs = output_sequence.shape

        ################################################################################
        # SAMPLE GP #
        ################################################################################
        self.forward_model.resample()
        self.backward_model.resample()

        ################################################################################
        # PERFORM Backward Pass #
        ################################################################################
        if self.training:
            output_distribution = self.backward(output_sequence, input_sequence)

        ################################################################################
        # Initial State #
        ################################################################################
        state = self.recognition(output_sequence[:, :self.recognition.length],
                                 input_sequence[:, :self.recognition.length],
                                 num_particles=num_particles)

        ################################################################################
        # PREDICT Outputs #
        ################################################################################
        outputs = []
        y_pred = self.emissions(state)
        outputs.append(MultivariateNormal(y_pred.loc.detach(),
                                          y_pred.covariance_matrix.detach()))

        ################################################################################
        # INITIALIZE losses #
        ################################################################################

        # entropy = torch.tensor(0.)
        if self.training:
            output_distribution.pop(0)
            # entropy += y_tilde.entropy().mean() / sequence_length

        y = output_sequence[:, 0].expand(num_particles, batch_size, dim_outputs
                                         ).permute(1, 2, 0)
        log_lik = y_pred.log_prob(y).sum(dim=1).mean()  # type: torch.Tensor
        l2 = ((y_pred.loc - y) ** 2).sum(dim=1).mean()  # type: torch.Tensor
        kl_cond = torch.tensor(0.)

        for t in range(sequence_length - 1):
            ############################################################################
            # PREDICT Next State #
            ############################################################################
            u = input_sequence[:, t].expand(num_particles, batch_size, dim_inputs)
            u = u.permute(1, 2, 0)  # Move the particle dimension to the end.
            state_samples = state.rsample()
            state_input = torch.cat((state_samples, u), dim=1)

            next_f = self.forward_model(state_input)
            next_state = self.transitions(next_f)
            next_state.loc += state_samples

            if self.independent_particles:
                next_state = diagonal_covariance(next_state)
            ############################################################################
            # CONDITION Next State #
            ############################################################################
            if self.training:
                y_tilde = output_distribution.pop(0)
                p_next_state = next_state
                next_state = self._condition(next_state, y_tilde)
                kl_cond += kl_divergence(next_state, p_next_state).mean()
            ############################################################################
            # RESAMPLE State #
            ############################################################################
            state = next_state

            ############################################################################
            # PREDICT Outputs #
            ############################################################################
            y_pred = self.emissions(state)
            outputs.append(y_pred)

            ############################################################################
            # COMPUTE Losses #
            ############################################################################
            y = output_sequence[:, t + 1].expand(
                num_particles, batch_size, dim_outputs).permute(1, 2, 0)
            log_lik += y_pred.log_prob(y).sum(dim=1).mean()
            l2 += ((y_pred.loc - y) ** 2).sum(dim=1).mean()
            # entropy += y_tilde.entropy().mean() / sequence_length

        assert len(outputs) == sequence_length

        # if self.training:
        #     del output_distribution
        ################################################################################
        # Compute model KL divergences #
        ################################################################################
        factor = 1  # batch_size / self.dataset_size
        kl_uf = self.forward_model.kl_divergence()
        kl_ub = self.backward_model.kl_divergence()

        if self.forward_model.independent:
            kl_uf *= sequence_length
        if self.backward_model.independent:
            kl_ub *= sequence_length

        kl_cond = kl_cond * self.loss_factors['kl_conditioning'] * factor
        kl_ub = kl_ub * self.loss_factors['kl_u'] * factor
        kl_uf = kl_uf * self.loss_factors['kl_u'] * factor

        if self.loss_key.lower() == 'loglik':
            loss = -log_lik
        elif self.loss_key.lower() == 'elbo':
            loss = -(log_lik - kl_uf - kl_ub - kl_cond)
            if kwargs.get('print', False):
                str_ = 'elbo: {}, log_lik: {}, kluf: {}, klub: {}, klcond: {}'
                print(str_.format(loss.item(), log_lik.item(), kl_uf.item(),
                                  kl_ub.item(), kl_cond.item()))
        elif self.loss_key.lower() == 'l2':
            loss = l2
        elif self.loss_key.lower() == 'rmse':
            loss = torch.sqrt(l2)
        else:
            raise NotImplementedError("Key {} not implemented".format(self.loss_key))

        return outputs, loss
Example #15
File: gp_torch.py Project: pnickl/reg
 def forward(self, input):
     mean = self.mean_module(input)
     covar = self.covar_module(input)
     return MultivariateNormal(mean, covar)
Example #16
    def test_GPyTorchPosterior(self):
        for dtype in (torch.float, torch.double):
            n = 3
            mean = torch.rand(n, dtype=dtype, device=self.device)
            variance = 1 + torch.rand(n, dtype=dtype, device=self.device)
            covar = variance.diag()
            mvn = MultivariateNormal(mean, lazify(covar))
            posterior = GPyTorchPosterior(mvn=mvn)
            # basics
            self.assertEqual(posterior.device.type, self.device.type)
            self.assertTrue(posterior.dtype == dtype)
            self.assertEqual(posterior.event_shape, torch.Size([n, 1]))
            self.assertTrue(torch.equal(posterior.mean, mean.unsqueeze(-1)))
            self.assertTrue(
                torch.equal(posterior.variance, variance.unsqueeze(-1)))
            # rsample
            samples = posterior.rsample()
            self.assertEqual(samples.shape, torch.Size([1, n, 1]))
            for sample_shape in ([4], [4, 2]):
                samples = posterior.rsample(
                    sample_shape=torch.Size(sample_shape))
                self.assertEqual(samples.shape,
                                 torch.Size(sample_shape + [n, 1]))
            # check enabling of approximate root decomposition
            with ExitStack() as es:
                mock_func = es.enter_context(
                    mock.patch(ROOT_DECOMP_PATH,
                               return_value=torch.linalg.cholesky(covar)))
                es.enter_context(gpt_settings.max_cholesky_size(0))
                es.enter_context(
                    gpt_settings.fast_computations(
                        covar_root_decomposition=True))
                # need to clear cache, cannot re-use previous objects
                mvn = MultivariateNormal(mean, lazify(covar))
                posterior = GPyTorchPosterior(mvn=mvn)
                posterior.rsample(sample_shape=torch.Size([4]))
                mock_func.assert_called_once()

            # rsample w/ base samples
            base_samples = torch.randn(4,
                                       3,
                                       1,
                                       device=self.device,
                                       dtype=dtype)
            # incompatible shapes
            with self.assertRaises(RuntimeError):
                posterior.rsample(sample_shape=torch.Size([3]),
                                  base_samples=base_samples)
            # ensure consistent result
            for sample_shape in ([4], [4, 2]):
                base_samples = torch.randn(*sample_shape,
                                           3,
                                           1,
                                           device=self.device,
                                           dtype=dtype)
                samples = [
                    posterior.rsample(sample_shape=torch.Size(sample_shape),
                                      base_samples=base_samples)
                    for _ in range(2)
                ]
                self.assertTrue(torch.allclose(*samples))
            # collapse_batch_dims
            b_mean = torch.rand(2, 3, dtype=dtype, device=self.device)
            b_variance = 1 + torch.rand(2, 3, dtype=dtype, device=self.device)
            b_covar = torch.diag_embed(b_variance)
            b_mvn = MultivariateNormal(b_mean, lazify(b_covar))
            b_posterior = GPyTorchPosterior(mvn=b_mvn)
            b_base_samples = torch.randn(4,
                                         1,
                                         3,
                                         1,
                                         device=self.device,
                                         dtype=dtype)
            b_samples = b_posterior.rsample(sample_shape=torch.Size([4]),
                                            base_samples=b_base_samples)
            self.assertEqual(b_samples.shape, torch.Size([4, 2, 3, 1]))
Example #17
 def forward(self, input):
     mean = self.mean_module(input)
     covar = self.covar_module(input)
     return MultitaskMultivariateNormal.from_batch_mvn(
         MultivariateNormal(mean, covar))
Example #18
 def forward(self, x):
     mean = self.mean(x)
     covar = self.covariance(x)
     return MultivariateNormal(mean, covar)
Example #19
 def forward(self, x):
     x_mean = self.mean(x)
     x_covar = self.covar(x)
     return MultivariateNormal(x_mean, x_covar)
Example #20
def gpnet(args, dataloader, test_x, prior_gp):
    N = len(dataloader.dataset)
    x_dim = 1
    prior_gp.train()

    if args.net == 'tangent':
        kernel = prior_gp.covar_module
        bnn_prev = FirstOrder([x_dim] + [args.n_hidden] * args.n_layer,
                              mvn=False)
        bnn = FirstOrder([x_dim] + [args.n_hidden] * args.n_layer, mvn=True)
    elif args.net == 'deep':
        kernel = prior_gp.covar_module
        bnn_prev = DeepKernel([x_dim] + [args.n_hidden] * args.n_layer,
                              mvn=False)
        bnn = DeepKernel([x_dim] + [args.n_hidden] * args.n_layer, mvn=True)
    elif args.net == 'rf':
        kernel = ScaleKernel(RBFKernel())
        kernel_prev = ScaleKernel(RBFKernel())
        bnn_prev = RFExpansion(x_dim,
                               args.n_hidden,
                               kernel_prev,
                               mvn=False,
                               fix_ls=args.fix_rf_ls,
                               residual=args.residual)
        bnn = RFExpansion(x_dim,
                          args.n_hidden,
                          kernel,
                          fix_ls=args.fix_rf_ls,
                          residual=args.residual)
        bnn_prev.load_state_dict(bnn.state_dict())
    else:
        raise NotImplementedError('Unknown inference net')
    bnn = bnn.to(args.device)
    bnn_prev = bnn_prev.to(args.device)
    prior_gp = prior_gp.to(args.device)

    infer_gpnet_optimizer = optim.Adam(bnn.parameters(), lr=args.learning_rate)
    hyper_opt_optimizer = optim.Adam(prior_gp.parameters(), lr=args.hyper_rate)

    x_min, x_max = dataloader.dataset.range

    bnn.train()
    bnn_prev.train()
    prior_gp.train()

    mb = master_bar(range(1, args.n_iters + 1))

    for t in mb:
        # Hyperparameter selection
        beta = args.beta0 * 1. / (1. + args.gamma * math.sqrt(t - 1))
        dl_bar = progress_bar(dataloader, parent=mb)
        for x, y in dl_bar:
            observed_size = x.size(0)
            x, y = x.to(args.device), y.to(args.device)
            x_star = torch.Tensor(args.measurement_size,
                                  x_dim).uniform_(x_min, x_max).to(args.device)
            # [Batch + Measurement Points x x_dims]
            xx = torch.cat([x, x_star], 0)

            infer_gpnet_optimizer.zero_grad()
            hyper_opt_optimizer.zero_grad()

            # inference net
            # Eq.(6) Prior p(f)
            # \mu_1=0, \Sigma_1
            mean_prior = torch.zeros(xx.size(0)).to(args.device)
            K_prior = kernel(xx, xx).add_jitter(1e-6)

            # q_{\gamma_t}(f_M, f_n) = Normal(mu_2, sigma_2|x_n, x_m)
            # \mu_2, \Sigma_2
            qff_mean_prev, K_prox = bnn_prev(xx)

            # Eq.(8) adapt prior; p(f)^\beta x q(f)^{1 - \beta}
            mean_adapt, K_adapt = product_gaussians(mu1=mean_prior,
                                                    sigma1=K_prior,
                                                    mu2=qff_mean_prev,
                                                    sigma2=K_prox,
                                                    beta=beta)

            # Eq.(8)
            (mean_n, mean_m), (Knn, Knm,
                               Kmm) = split_gaussian(mean_adapt, K_adapt,
                                                     observed_size)

            # Eq.(2) K_{D,D} + noise / (N\beta_t)
            Ky = Knn + torch.eye(observed_size).to(
                args.device) * prior_gp.likelihood.noise / (N / observed_size *
                                                            beta)
            Ky_tril = torch.cholesky(Ky)

            # Eq.(2)
            mean_target = Knm.t().mm(cholesky_solve(y - mean_n,
                                                    Ky_tril)) + mean_m
            mean_target = mean_target.squeeze(-1)
            K_target = gpytorch.add_jitter(
                Kmm - Knm.t().mm(cholesky_solve(Knm, Ky_tril)), 1e-6)
            # \hat{q}_{t+1} (f_M)
            target_pf_star = MultivariateNormal(mean_target, K_target)

            # q_\gamma (f_M)
            qf_star = bnn(x_star)

            # Eq. (11)
            kl_obj = kl_div(qf_star, target_pf_star).sum()

            kl_obj.backward(retain_graph=True)
            infer_gpnet_optimizer.step()

            # Hyperparameter update
            (mean_n_prior, _), (Kn_prior, _,
                                _) = split_gaussian(mean_prior, K_prior,
                                                    observed_size)
            pf = MultivariateNormal(mean_n_prior, Kn_prior)

            (qf_prev_mean, _), (Kn_prox, _,
                                _) = split_gaussian(qff_mean_prev, K_prox,
                                                    observed_size)
            qf_prev = MultivariateNormal(qf_prev_mean, Kn_prox)

            hyper_obj = -(prior_gp.likelihood.expected_log_prob(
                y.squeeze(-1), qf_prev) - kl_div(qf_prev, pf))
            hyper_obj.backward(retain_graph=True)
            hyper_opt_optimizer.step()

            mb.child.comment = "kl_obj = {:.3f}, obs_var={:.3f}".format(
                kl_obj.item(), prior_gp.likelihood.noise.item())

        # update q_{\gamma_t} to q_{\gamma_{t+1}}
        bnn_prev.load_state_dict(bnn.state_dict())
        if args.net == 'rf':
            kernel_prev.load_state_dict(kernel.state_dict())
        if t % 50 == 0:
            mb.write("Iter {}/{}, kl_obj = {:.4f}, noise = {:.4f}".format(
                t, args.n_iters, kl_obj.item(),
                prior_gp.likelihood.noise.item()))

    test_x = test_x.to(args.device)
    test_stats = evaluate(bnn, prior_gp.likelihood, test_x,
                          args.net == 'tangent')
    return test_stats
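product_gaussians and split_gaussian are project helpers not shown in this excerpt. A rough sketch of the Eq.(8) adaptation step p(f)^beta * q(f)^(1 - beta), assuming dense covariance matrices and a hypothetical signature matching the call above (a lazy kernel matrix would need .evaluate() first):

import torch


def product_gaussians(mu1, sigma1, mu2, sigma2, beta):
    """Sketch: combine N(mu1, sigma1)^beta * N(mu2, sigma2)^(1 - beta).

    The product of powered Gaussians is Gaussian; its precision is the convex
    combination of the two precisions and its mean is precision-weighted.
    """
    prec1 = torch.inverse(sigma1)
    prec2 = torch.inverse(sigma2)
    sigma = torch.inverse(beta * prec1 + (1.0 - beta) * prec2)
    mu = sigma @ (beta * (prec1 @ mu1) + (1.0 - beta) * (prec2 @ mu2))
    return mu, sigma


def split_gaussian(mean, cov, n):
    """Sketch: split a joint Gaussian over [f_n, f_m] at index n."""
    mean_n, mean_m = mean[:n], mean[n:]
    Knn, Knm, Kmm = cov[:n, :n], cov[:n, n:], cov[n:, n:]
    return (mean_n, mean_m), (Knn, Knm, Kmm)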
Example #21
def _get_test_posterior(device, dtype=torch.float):
    mean = torch.zeros(2, device=device, dtype=dtype)
    cov = torch.eye(2, device=device, dtype=dtype)
    mvn = MultivariateNormal(mean, cov)
    return GPyTorchPosterior(mvn)
Example #22
def gpnet_nonconj(args, dataloader, test_x, prior_gp):
    N = len(dataloader.dataset)
    x_dim = 1
    prior_gp.train()

    if args.net == 'tangent':
        kernel = prior_gp.covar_module
        bnn_prev = FirstOrder([x_dim] + [args.n_hidden] * args.n_layer,
                              mvn=False)
        bnn = FirstOrder([x_dim] + [args.n_hidden] * args.n_layer, mvn=True)
    elif args.net == 'deep':
        kernel = prior_gp.covar_module
        bnn_prev = DeepKernel([x_dim] + [args.n_hidden] * args.n_layer,
                              mvn=False)
        bnn = DeepKernel([x_dim] + [args.n_hidden] * args.n_layer, mvn=True)
    elif args.net == 'rf':
        kernel = ScaleKernel(RBFKernel())
        kernel_prev = ScaleKernel(RBFKernel())
        bnn_prev = RFExpansion(x_dim,
                               args.n_hidden,
                               kernel_prev,
                               mvn=False,
                               fix_ls=args.fix_rf_ls,
                               residual=args.residual)
        bnn = RFExpansion(x_dim,
                          args.n_hidden,
                          kernel,
                          fix_ls=args.fix_rf_ls,
                          residual=args.residual)
        bnn_prev.load_state_dict(bnn.state_dict())
    else:
        raise NotImplementedError('Unknown inference net')

    infer_gpnet_optimizer = optim.Adam(bnn.parameters(), lr=args.learning_rate)
    hyper_opt_optimizer = optim.Adam(prior_gp.parameters(), lr=args.hyper_rate)

    x_min, x_max = dataloader.dataset.range
    n = dataloader.batch_size

    bnn.train()
    bnn_prev.train()
    prior_gp.train()

    mb = master_bar(range(1, args.n_iters + 1))

    for t in mb:
        beta = args.beta0 * 1. / (1. + args.gamma * math.sqrt(t - 1))
        dl_bar = progress_bar(dataloader, parent=mb)
        for x, y in dl_bar:
            n = x.size(0)
            x_star = torch.Tensor(args.measurement_size,
                                  x_dim).uniform_(x_min, x_max)
            xx = torch.cat([x, x_star], 0)

            # inference net
            infer_gpnet_optimizer.zero_grad()
            hyper_opt_optimizer.zero_grad()

            qff = bnn(xx)
            qff_mean_prev, K_prox = bnn_prev(xx)
            qf_mean, qf_var = bnn(x, full_cov=False)

            # Eq.(8)
            K_prior = kernel(xx, xx).add_jitter(1e-6)
            pff = MultivariateNormal(torch.zeros(xx.size(0)), K_prior)

            f_term = torch.sum(
                expected_log_prob(prior_gp.likelihood, qf_mean, qf_var,
                                  y.squeeze(-1)))
            f_term *= N / x.size(0) * beta

            prior_term = -beta * cross_entropy(qff, pff)

            qff_prev = MultivariateNormal(qff_mean_prev, K_prox)
            prox_term = -(1 - beta) * cross_entropy(qff, qff_prev)

            entropy_term = entropy(qff)

            lower_bound = f_term + prior_term + prox_term + entropy_term
            loss = -lower_bound / n

            loss.backward(retain_graph=True)

            infer_gpnet_optimizer.step()

            # Hyper-parameter update
            Kn_prior = K_prior[:n, :n]
            pf = MultivariateNormal(torch.zeros(n), Kn_prior)
            Kn_prox = K_prox[:n, :n]
            qf_prev_mean = qff_mean_prev[:n]
            qf_prev_var = torch.diagonal(Kn_prox)
            qf_prev = MultivariateNormal(qf_prev_mean, Kn_prox)
            hyper_obj = expected_log_prob(
                prior_gp.likelihood, qf_prev_mean, qf_prev_var,
                y.squeeze(-1)).sum() - kl_div(qf_prev, pf)
            hyper_obj = -hyper_obj
            hyper_obj.backward()
            hyper_opt_optimizer.step()

        bnn_prev.load_state_dict(bnn.state_dict())
        if args.net == 'rf':
            kernel_prev.load_state_dict(kernel.state_dict())
        if t % 50 == 0:
            mb.write("Iter {}/{}, kl_obj = {:.4f}, noise = {:.4f}".format(
                t, args.n_iters, lower_bound.item(),
                prior_gp.likelihood.noise.item()))
    test_x = test_x.to(args.device)
    test_stats = evaluate(bnn, prior_gp.likelihood, test_x,
                          args.net == 'tangent')

    return test_stats
Example #23
 def forward(self, x):
     features = self.feature_extractor(x)
     mean_x = self.mean_module(features)
     covar_x = self.covar_module(features)
     return MultivariateNormal(mean_x, covar_x)
Example #24
    def posterior(
        self,
        X: Tensor,
        output_indices: Optional[List[int]] = None,
        observation_noise: Union[bool, Tensor] = False,
        **kwargs: Any,
    ) -> GPyTorchPosterior:
        r"""Computes the posterior over model outputs at the provided points.

        Args:
            X: A `(batch_shape) x q x d`-dim Tensor, where `d` is the dimension
                of the feature space and `q` is the number of points considered
                jointly.
            output_indices: A list of indices, corresponding to the outputs over
                which to compute the posterior (if the model is multi-output).
                Can be used to speed up computation if only a subset of the
                model's outputs are required for optimization. If omitted,
                computes the posterior over all model outputs.
            observation_noise: If True, add the observation noise from the
                likelihood to the posterior. If a Tensor, use it directly as the
                observation noise (must be of shape `(batch_shape) x q x m`).

        Returns:
            A `GPyTorchPosterior` object, representing `batch_shape` joint
            distributions over `q` points and the outputs selected by
            `output_indices` each. Includes observation noise if specified.
        """
        self.eval()  # make sure model is in eval mode
        with gpt_posterior_settings():
            # insert a dimension for the output dimension
            if self._num_outputs > 1:
                X, output_dim_idx = add_output_dim(
                    X=X, original_batch_shape=self._input_batch_shape)
            mvn = self(X)
            if observation_noise is not False:
                if torch.is_tensor(observation_noise):
                    # TODO: Validate noise shape
                    # make observation_noise `batch_shape x q x n`
                    obs_noise = observation_noise.transpose(-1, -2)
                    mvn = self.likelihood(mvn, X, noise=obs_noise)
                elif isinstance(self.likelihood, FixedNoiseGaussianLikelihood):
                    # Use the mean of the previous noise values (TODO: be smarter here).
                    noise = self.likelihood.noise.mean().expand(X.shape[:-1])
                    mvn = self.likelihood(mvn, X, noise=noise)
                else:
                    mvn = self.likelihood(mvn, X)
            if self._num_outputs > 1:
                mean_x = mvn.mean
                covar_x = mvn.covariance_matrix
                output_indices = output_indices or range(self._num_outputs)
                mvns = [
                    MultivariateNormal(
                        mean_x.select(dim=output_dim_idx, index=t),
                        lazify(covar_x.select(dim=output_dim_idx, index=t)),
                    ) for t in output_indices
                ]
                mvn = MultitaskMultivariateNormal.from_independent_mvns(
                    mvns=mvns)

        posterior = GPyTorchPosterior(mvn=mvn)
        if hasattr(self, "outcome_transform"):
            posterior = self.outcome_transform.untransform_posterior(posterior)
        return posterior
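An illustrative call of such a posterior method (model stands in for a fitted single-output BoTorch model with 2-dimensional inputs; shapes follow the docstring above and are expectations, not guarantees):

X_test = torch.rand(4, 3, 2)  # batch_shape=4, q=3, d=2
post = model.posterior(X_test, observation_noise=True)
print(post.mean.shape)        # expected: torch.Size([4, 3, 1]) for one output
samples = post.rsample(torch.Size([16]))
print(samples.shape)          # expected: torch.Size([16, 4, 3, 1])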
Example #25
 def _initialize_latents(
     self,
     latent_init: str,
     num_latent_dims: List[int],
     learn_latent_pars: bool,
     device: torch.device,
     dtype: torch.dtype,
 ):
     self.latent_parameters = ParameterList()
     if latent_init == "default":
         for dim_num in range(len(self.covar_modules) - 1):
             self.latent_parameters.append(
                 Parameter(
                     torch.rand(
                         *self._aug_batch_shape,
                         self.target_shape[dim_num],
                         num_latent_dims[dim_num],
                         device=device,
                         dtype=dtype,
                     ),
                     requires_grad=learn_latent_pars,
                 ))
     elif latent_init == "gp":
         for dim_num, covar in enumerate(self.covar_modules[1:]):
             latent_covar = covar(
                 torch.linspace(
                     0.0,
                     1.0,
                     self.target_shape[dim_num],
                     device=device,
                     dtype=dtype,
                 )).add_jitter(1e-4)
             latent_dist = MultivariateNormal(
                 torch.zeros(
                     self.target_shape[dim_num],
                     device=device,
                     dtype=dtype,
                 ),
                 latent_covar,
             )
             sample_shape = torch.Size((
                 *self._aug_batch_shape,
                 num_latent_dims[dim_num],
             ))
             latent_sample = latent_dist.sample(sample_shape=sample_shape)
             latent_sample = latent_sample.reshape(
                 *self._aug_batch_shape,
                 self.target_shape[dim_num],
                 num_latent_dims[dim_num],
             )
             self.latent_parameters.append(
                 Parameter(
                     latent_sample,
                     requires_grad=learn_latent_pars,
                 ))
             self.register_prior(
                 "latent_parameters_" + str(dim_num),
                 MultivariateNormalPrior(
                     latent_dist.loc,
                     latent_dist.covariance_matrix.detach().clone()),
                 lambda module, dim_num=dim_num: self.latent_parameters[
                     dim_num],
             )
Example #26
 def forward(self, x):
     x = map_box_ball(x, self.dim)
     mean_x = self.mean_module(x)
     covar_x = self.covar_module(x)
     return MultivariateNormal(mean_x, covar_x)
Example #27
 def _create_marginal_input(self, batch_shape=torch.Size()):
     mat = torch.randn(*batch_shape, 5, 5)
     eye = torch.diag_embed(torch.ones(*batch_shape, 5))
     return MultivariateNormal(torch.randn(*batch_shape, 5),
                               mat @ mat.transpose(-1, -2) + eye)
Example #28
 def forward(self, x):
     """ApproximateGPModelのforwardメソッド
     """
     mean_x = self.mean_module(x)
     covar_x = self.covar_module(x)
     return MultivariateNormal(mean_x, covar_x)
Example #29
 def forward(self, x):
     mean_x = self.mean_module(x)
     covar_x = self.covar_module(x)
     return MultivariateNormal(mean_x, covar_x)
Example #30
 def _create_marginal_input(self, batch_shape=torch.Size([])):
     mat = torch.randn(*batch_shape, 6, 5, 5)
     return MultivariateNormal(torch.randn(*batch_shape, 6, 5),
                               mat @ mat.transpose(-1, -2))