def test_setup(self):
        """Exercise ``_setup`` on ``DummyCachedCholeskyAcqf`` and check its flags.

        Covers a mock model with and without ``cache_root=True``, sampler
        validation via ``check_sampler=True``, a ``HigherOrderGP`` and a
        ``GenericDeterministicModel``. The flags inspected are ``_is_mt``,
        ``_is_deterministic``, ``_uses_matheron`` and ``_cache_root``.
        """
        mean = torch.zeros(1, 1)
        variance = torch.ones(1, 1)
        mm = MockModel(MockPosterior(mean=mean, variance=variance))
        # basic test
        sampler = IIDNormalSampler(1)
        acqf = DummyCachedCholeskyAcqf(model=mm, sampler=sampler)
        acqf._setup(model=mm, sampler=sampler)
        self.assertFalse(acqf._is_mt)
        self.assertFalse(acqf._is_deterministic)
        self.assertFalse(acqf._uses_matheron)
        self.assertFalse(acqf._cache_root)
        acqf._setup(model=mm, sampler=sampler, cache_root=True)
        self.assertTrue(acqf._cache_root)

        # test check_sampler
        with warnings.catch_warnings(record=True) as ws, settings.debug(True):
            acqf._setup(model=mm, sampler=sampler, check_sampler=True)
            # Only BotorchWarnings matter here; unrelated warnings emitted by
            # dependencies must not make this check fail.
            self.assertFalse(
                any(issubclass(w.category, BotorchWarning) for w in ws))

        # test collapse_batch_dims=False
        sampler = IIDNormalSampler(1, collapse_batch_dims=False)
        acqf = DummyCachedCholeskyAcqf(model=mm, sampler=sampler)
        with self.assertRaises(UnsupportedError):
            acqf._setup(model=mm, sampler=sampler, check_sampler=True)
        # test warning if base_samples is not None
        sampler = IIDNormalSampler(1)
        sampler.base_samples = torch.zeros(1, 1)
        acqf = DummyCachedCholeskyAcqf(model=mm, sampler=sampler)
        with warnings.catch_warnings(record=True) as ws, settings.debug(True):
            acqf._setup(model=mm, sampler=sampler, check_sampler=True)
            # Look through every recorded warning rather than only the last
            # one: this neither depends on warning order nor raises
            # IndexError when nothing was recorded.
            self.assertTrue(
                any(issubclass(w.category, BotorchWarning) for w in ws))
        # test the base_samples are set to None
        self.assertIsNone(acqf.sampler.base_samples)
        # test model that uses matheron's rule and sampler.batch_range != (0, -1)
        hogp = HigherOrderGP(torch.zeros(1, 1), torch.zeros(1, 1, 1)).eval()
        acqf = DummyCachedCholeskyAcqf(model=hogp, sampler=sampler)
        with self.assertRaises(RuntimeError):
            acqf._setup(model=hogp, sampler=sampler, cache_root=True)
        # The flags are asserted after the failed call above.
        self.assertTrue(acqf._uses_matheron)
        self.assertTrue(acqf._is_mt)
        self.assertFalse(acqf._is_deterministic)

        # test deterministic model
        model = GenericDeterministicModel(f=lambda X: X)
        acqf = DummyCachedCholeskyAcqf(model=model, sampler=sampler)
        acqf._setup(model=model, sampler=sampler, cache_root=True)
        self.assertTrue(acqf._is_deterministic)
        self.assertFalse(acqf._uses_matheron)
        self.assertFalse(acqf._is_mt)
        # cache_root=True was requested, yet the flag is expected to be off.
        self.assertFalse(acqf._cache_root)
    def test_cache_root(self):
        """Compare ``qNoisyExpectedImprovement`` with and without ``cache_root``.

        For every combination of training batch shape, number of outputs and
        dtype, the test checks that:

        * a sampler built with ``collapse_batch_dims=False`` is rejected with
          ``UnsupportedError`` when ``cache_root=True``;
        * the cached acquisition function calls ``sample_cached_cholesky``
          exactly once per evaluation and the non-cached one never calls it;
        * values and input gradients of both variants agree within tolerance;
        * zeroing ``_baseline_L`` produces exactly one ``BotorchWarning``.
        """
        sample_cached_path = (
            "botorch.acquisition.cached_cholesky.sample_cached_cholesky")
        raw_state_dict = {
            "likelihood.noise_covar.raw_noise":
            torch.tensor([[0.0895], [0.2594]], dtype=torch.float64),
            "mean_module.constant":
            torch.tensor([[-0.4545], [-0.1285]], dtype=torch.float64),
            "covar_module.raw_outputscale":
            torch.tensor([1.4876, 1.4897], dtype=torch.float64),
            "covar_module.base_kernel.raw_lengthscale":
            torch.tensor([[[-0.7202, -0.2868]], [[-0.8794, -1.2877]]],
                         dtype=torch.float64),
        }
        # test batched models (e.g. for MCMC)
        for train_batch_shape, m, dtype in product(
            (torch.Size([]), torch.Size([3])), (1, 2),
            (torch.float, torch.double)):
            tkwargs = {"device": self.device, "dtype": dtype}
            # Adapt the stored parameters to this configuration in one pass:
            # keep the first output only when m == 1, prepend the training
            # batch dims, then move to the target device/dtype.
            state_dict = deepcopy(raw_state_dict)
            for k, v in state_dict.items():
                if m == 1:
                    v = v[0]
                if len(train_batch_shape) > 0:
                    v = v.unsqueeze(0).expand(*train_batch_shape, *v.shape)
                state_dict[k] = v.to(**tkwargs)
            # With two outputs the objective sums over the output dimension.
            objective = (
                GenericMCObjective(lambda Y, X: Y.sum(dim=-1))
                if m == 2 else None)
            # Single precision gets a much looser absolute tolerance.
            all_close_kwargs = ({
                "atol": 1e-1,
                "rtol": 0.0,
            } if dtype == torch.float else {
                "atol": 1e-4,
                "rtol": 0.0
            })
            torch.manual_seed(1234)
            train_X = torch.rand(*train_batch_shape, 3, 2, **tkwargs)
            train_Y = (
                torch.sin(train_X * 2 * pi) +
                torch.randn(*train_batch_shape, 3, 2, **tkwargs))[..., :m]
            train_Y = standardize(train_Y)
            model = SingleTaskGP(
                train_X,
                train_Y,
            )
            # X_baseline is taken without the training batch dims.
            if len(train_batch_shape) > 0:
                X_baseline = train_X[0]
            else:
                X_baseline = train_X
            model.load_state_dict(state_dict, strict=False)
            # test sampler with collapse_batch_dims=False
            sampler = IIDNormalSampler(5, seed=0, collapse_batch_dims=False)
            with self.assertRaises(UnsupportedError):
                qNoisyExpectedImprovement(
                    model=model,
                    X_baseline=X_baseline,
                    sampler=sampler,
                    objective=objective,
                    prune_baseline=False,
                    cache_root=True,
                )
            sampler = IIDNormalSampler(5, seed=0)
            torch.manual_seed(0)
            acqf = qNoisyExpectedImprovement(
                model=model,
                X_baseline=X_baseline,
                sampler=sampler,
                objective=objective,
                prune_baseline=False,
                cache_root=True,
            )

            # Give the non-cached acquisition function the same baseline
            # base samples as the cached one.
            orig_base_samples = acqf.base_sampler.base_samples.detach().clone()
            sampler2 = IIDNormalSampler(5, seed=0)
            sampler2.base_samples = orig_base_samples
            torch.manual_seed(0)
            acqf_no_cache = qNoisyExpectedImprovement(
                model=model,
                X_baseline=X_baseline,
                sampler=sampler2,
                objective=objective,
                prune_baseline=False,
                cache_root=False,
            )
            for q, batch_shape in product(
                (1, 3), (torch.Size([]), torch.Size([3]), torch.Size([4, 3]))):
                test_X = (0.3 +
                          0.05 * torch.randn(*batch_shape, q, 2, **tkwargs)
                          ).requires_grad_(True)
                # wraps= keeps the real implementation running while the
                # mock records the calls.
                with mock.patch(
                        sample_cached_path,
                        wraps=sample_cached_cholesky) as mock_sample_cached:
                    torch.manual_seed(0)
                    val = acqf(test_X)
                    mock_sample_cached.assert_called_once()
                val.sum().backward()
                base_samples = acqf.sampler.base_samples.detach().clone()
                X_grad = test_X.grad.clone()
                # Separate leaf tensor so the gradients of the two
                # evaluations do not accumulate into each other.
                test_X2 = test_X.detach().clone().requires_grad_(True)
                acqf_no_cache.sampler.base_samples = base_samples
                with mock.patch(
                        sample_cached_path,
                        wraps=sample_cached_cholesky) as mock_sample_cached:
                    torch.manual_seed(0)
                    val2 = acqf_no_cache(test_X2)
                mock_sample_cached.assert_not_called()
                self.assertTrue(torch.allclose(val, val2, **all_close_kwargs))
                val2.sum().backward()
                self.assertTrue(
                    torch.allclose(X_grad, test_X2.grad, **all_close_kwargs))
            # test we fall back to standard sampling for
            # ill-conditioned covariances
            # (test_X is the tensor left over from the last loop iteration)
            acqf._baseline_L = torch.zeros_like(acqf._baseline_L)
            with warnings.catch_warnings(
                    record=True) as ws, settings.debug(True):
                with torch.no_grad():
                    acqf(test_X)
            # Count BotorchWarnings only, so unrelated warnings emitted by
            # dependencies cannot change the outcome of this check.
            botorch_warnings = [
                w for w in ws if issubclass(w.category, BotorchWarning)
            ]
            self.assertEqual(len(botorch_warnings), 1)
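    # NOTE(review): "Example #3" / "0" -- these look like leftover listing
    # markers from wherever this snippet was copied; not part of the tests.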
    def test_sample_cached_cholesky(self):
        """Check ``sample_cached_cholesky`` against direct posterior sampling.

        For both dtypes, one and two outputs, ``ModelListGP`` vs.
        ``SingleTaskGP`` (two outputs only), several ``q`` values and several
        train/test batch shapes, the test verifies that:

        * the samples returned by ``sample_cached_cholesky`` match the last
          ``q`` samples drawn by ``IIDNormalSampler`` from the full posterior;
        * the gradients w.r.t. the last ``q`` inputs match as well;
        * the samples for the first points equal samples drawn from the
          baseline posterior alone;
        * NaN/inf posterior means raise ``NanError``;
        * a ``RuntimeError("singular")`` from ``triangular_solve`` surfaces as
          ``NotPSDError`` while one with an empty message stays a
          ``RuntimeError``.
        """
        torch.manual_seed(0)
        tkwargs = {"device": self.device}
        for dtype in (torch.float, torch.double):
            tkwargs["dtype"] = dtype
            train_X = torch.rand(10, 2, **tkwargs)
            train_Y = torch.randn(10, 2, **tkwargs)
            for m in (1, 2):
                # A model list is only exercised in the two-output case.
                model_list_values = (True, False) if m == 2 else (False, )
                for use_model_list in model_list_values:
                    if use_model_list:
                        # one single-output GP per column of train_Y
                        model = ModelListGP(
                            SingleTaskGP(
                                train_X,
                                train_Y[..., :1],
                            ),
                            SingleTaskGP(
                                train_X,
                                train_Y[..., 1:],
                            ),
                        )
                    else:
                        model = SingleTaskGP(
                            train_X,
                            train_Y[:, :m],
                        )
                    # `sampler` draws from the full posterior; `base_sampler`
                    # is reused below to draw from the baseline posterior.
                    sampler = IIDNormalSampler(3)
                    base_sampler = IIDNormalSampler(3)
                    for q in (1, 3, 9):
                        # test batched baseline_L
                        for train_batch_shape in (
                                torch.Size([]),
                                torch.Size([3]),
                                torch.Size([3, 2]),
                        ):
                            # test batched test points
                            for test_batch_shape in (
                                    torch.Size([]),
                                    torch.Size([4]),
                                    torch.Size([4, 2]),
                            ):

                                # Prepend the train batch dims, then the test
                                # batch dims, by expanding train_X.
                                if len(train_batch_shape) > 0:
                                    train_X_ex = train_X.unsqueeze(0).expand(
                                        train_batch_shape + train_X.shape)
                                else:
                                    train_X_ex = train_X
                                if len(test_batch_shape) > 0:
                                    test_X = train_X_ex.unsqueeze(0).expand(
                                        test_batch_shape + train_X_ex.shape)
                                else:
                                    test_X = train_X_ex
                                # Root of the covariance over all but the
                                # last q points; computed without gradients.
                                with torch.no_grad():
                                    base_posterior = model.posterior(
                                        train_X_ex[..., :-q, :])
                                    mvn = base_posterior.mvn
                                    lazy_covar = mvn.lazy_covariance_matrix
                                    # NOTE(review): presumably the two-output
                                    # covariance wraps the joint one and has to
                                    # be unwrapped here -- confirm.
                                    if m == 2:
                                        lazy_covar = lazy_covar.base_lazy_tensor
                                    baseline_L = lazy_covar.root_decomposition(
                                    )
                                    baseline_L = baseline_L.root.evaluate()
                                # Reference path: sample the full posterior
                                # and backpropagate through the last q points.
                                test_X = test_X.clone().requires_grad_(True)
                                new_posterior = model.posterior(test_X)
                                samples = sampler(new_posterior)
                                samples[..., -q:, :].sum().backward()
                                # Cached path: separate leaf tensor so its
                                # gradient is tracked independently of test_X.
                                test_X2 = test_X.detach().clone(
                                ).requires_grad_(True)
                                new_posterior2 = model.posterior(test_X2)
                                q_samples = sample_cached_cholesky(
                                    posterior=new_posterior2,
                                    baseline_L=baseline_L,
                                    q=q,
                                    base_samples=sampler.base_samples.detach().
                                    clone(),
                                    sample_shape=sampler.sample_shape,
                                )
                                q_samples.sum().backward()
                                # Explicit tolerances for float; allclose
                                # defaults for double.
                                all_close_kwargs = ({
                                    "atol": 1e-4,
                                    "rtol": 1e-2,
                                } if dtype == torch.float else {})
                                self.assertTrue(
                                    torch.allclose(
                                        q_samples.detach(),
                                        samples[..., -q:, :].detach(),
                                        **all_close_kwargs,
                                    ))
                                self.assertTrue(
                                    torch.allclose(
                                        test_X2.grad[..., -q:, :],
                                        test_X.grad[..., -q:, :],
                                        **all_close_kwargs,
                                    ))
                                # Test that adding a new point and base_sample
                                # did not change posterior samples for previous points.
                                # This tests that we properly account for not
                                # interleaving.
                                base_sampler.base_samples = (
                                    sampler.base_samples[
                                        ..., :-q, :].detach().clone())

                                baseline_samples = base_sampler(base_posterior)
                                # Dims of `samples` between the sample dim and
                                # the trailing dims shared with
                                # `baseline_samples`.
                                new_batch_shape = samples.shape[
                                    1:-baseline_samples.ndim + 1]
                                # Insert singleton dims and expand so the
                                # baseline samples line up with `samples`.
                                expanded_baseline_samples = baseline_samples.view(
                                    baseline_samples.shape[0],
                                    *[1] * len(new_batch_shape),
                                    *baseline_samples.shape[1:],
                                ).expand(
                                    baseline_samples.shape[0],
                                    *new_batch_shape,
                                    *baseline_samples.shape[1:],
                                )
                                self.assertTrue(
                                    torch.allclose(
                                        expanded_baseline_samples,
                                        samples[..., :-q, :],
                                        **all_close_kwargs,
                                    ))
                            # The checks below run once per train batch shape
                            # and reuse test_X2 / baseline_L from the last
                            # test_batch_shape iteration above.
                            # test nans
                            with torch.no_grad():
                                test_posterior = model.posterior(test_X2)
                            test_posterior.mvn.loc = torch.full_like(
                                test_posterior.mvn.loc, float("nan"))
                            with self.assertRaises(NanError):
                                sample_cached_cholesky(
                                    posterior=test_posterior,
                                    baseline_L=baseline_L,
                                    q=q,
                                    base_samples=sampler.base_samples.detach().
                                    clone(),
                                    sample_shape=sampler.sample_shape,
                                )
                            # test infs
                            test_posterior.mvn.loc = torch.full_like(
                                test_posterior.mvn.loc, float("inf"))
                            with self.assertRaises(NanError):
                                sample_cached_cholesky(
                                    posterior=test_posterior,
                                    baseline_L=baseline_L,
                                    q=q,
                                    base_samples=sampler.base_samples.detach().
                                    clone(),
                                    sample_shape=sampler.sample_shape,
                                )
                            # test triangular solve raising RuntimeError
                            # (the mean is reset to zeros so that only the
                            # patched solve can fail)
                            test_posterior.mvn.loc = torch.full_like(
                                test_posterior.mvn.loc, 0.0)
                            base_samples = sampler.base_samples.detach().clone(
                            )
                            # A "singular" RuntimeError is expected to surface
                            # as NotPSDError.
                            with mock.patch(
                                    "botorch.utils.low_rank.torch.triangular_solve",
                                    side_effect=RuntimeError("singular"),
                            ):
                                with self.assertRaises(NotPSDError):
                                    sample_cached_cholesky(
                                        posterior=test_posterior,
                                        baseline_L=baseline_L,
                                        q=q,
                                        base_samples=base_samples,
                                        sample_shape=sampler.sample_shape,
                                    )
                            # A RuntimeError with an empty message is expected
                            # to be raised as a RuntimeError.
                            with mock.patch(
                                    "botorch.utils.low_rank.torch.triangular_solve",
                                    side_effect=RuntimeError(""),
                            ):
                                with self.assertRaises(RuntimeError):
                                    sample_cached_cholesky(
                                        posterior=test_posterior,
                                        baseline_L=baseline_L,
                                        q=q,
                                        base_samples=base_samples,
                                        sample_shape=sampler.sample_shape,
                                    )