def test_setup(self):
    mean = torch.zeros(1, 1)
    variance = torch.ones(1, 1)
    mm = MockModel(MockPosterior(mean=mean, variance=variance))
    # basic test
    sampler = IIDNormalSampler(1)
    acqf = DummyCachedCholeskyAcqf(model=mm, sampler=sampler)
    acqf._setup(model=mm, sampler=sampler)
    self.assertFalse(acqf._is_mt)
    self.assertFalse(acqf._is_deterministic)
    self.assertFalse(acqf._uses_matheron)
    self.assertFalse(acqf._cache_root)
    acqf._setup(model=mm, sampler=sampler, cache_root=True)
    self.assertTrue(acqf._cache_root)
    # test check_sampler
    with warnings.catch_warnings(record=True) as ws, settings.debug(True):
        acqf._setup(model=mm, sampler=sampler, check_sampler=True)
    self.assertEqual(len(ws), 0)
    # test collapse_batch_dims=False
    sampler = IIDNormalSampler(1, collapse_batch_dims=False)
    acqf = DummyCachedCholeskyAcqf(model=mm, sampler=sampler)
    with self.assertRaises(UnsupportedError):
        acqf._setup(model=mm, sampler=sampler, check_sampler=True)
    # test warning if base_samples is not None
    sampler = IIDNormalSampler(1)
    sampler.base_samples = torch.zeros(1, 1)
    acqf = DummyCachedCholeskyAcqf(model=mm, sampler=sampler)
    with warnings.catch_warnings(record=True) as ws, settings.debug(True):
        acqf._setup(model=mm, sampler=sampler, check_sampler=True)
    self.assertTrue(issubclass(ws[-1].category, BotorchWarning))
    # test that the base_samples are set to None
    self.assertIsNone(acqf.sampler.base_samples)
    # test model that uses Matheron's rule and sampler.batch_range != (0, -1)
    hogp = HigherOrderGP(torch.zeros(1, 1), torch.zeros(1, 1, 1)).eval()
    acqf = DummyCachedCholeskyAcqf(model=hogp, sampler=sampler)
    with self.assertRaises(RuntimeError):
        acqf._setup(model=hogp, sampler=sampler, cache_root=True)
    self.assertTrue(acqf._uses_matheron)
    self.assertTrue(acqf._is_mt)
    self.assertFalse(acqf._is_deterministic)
    # test deterministic model
    model = GenericDeterministicModel(f=lambda X: X)
    acqf = DummyCachedCholeskyAcqf(model=model, sampler=sampler)
    acqf._setup(model=model, sampler=sampler, cache_root=True)
    self.assertTrue(acqf._is_deterministic)
    self.assertFalse(acqf._uses_matheron)
    self.assertFalse(acqf._is_mt)
    self.assertFalse(acqf._cache_root)
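# NOTE: `DummyCachedCholeskyAcqf` is defined elsewhere in this test module.
# As a rough, hedged sketch (an assumption, not the actual definition), it is
# presumably a minimal MC acquisition function that mixes in the
# cached-Cholesky machinery so that `_setup` can be exercised in isolation,
# e.g. something along these lines:
#
#     class DummyCachedCholeskyAcqf(
#         MCAcquisitionFunction, CachedCholeskyMCAcquisitionFunction
#     ):
#         def forward(self, X):
#             # trivial forward; only `_setup` behavior is under test here
#             return X[..., 0, 0]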
def test_cache_root(self):
    sample_cached_path = (
        "botorch.acquisition.cached_cholesky.sample_cached_cholesky"
    )
    raw_state_dict = {
        "likelihood.noise_covar.raw_noise": torch.tensor(
            [[0.0895], [0.2594]], dtype=torch.float64
        ),
        "mean_module.constant": torch.tensor(
            [[-0.4545], [-0.1285]], dtype=torch.float64
        ),
        "covar_module.raw_outputscale": torch.tensor(
            [1.4876, 1.4897], dtype=torch.float64
        ),
        "covar_module.base_kernel.raw_lengthscale": torch.tensor(
            [[[-0.7202, -0.2868]], [[-0.8794, -1.2877]]], dtype=torch.float64
        ),
    }
    # test batched models (e.g. for MCMC)
    for train_batch_shape, m, dtype in product(
        (torch.Size([]), torch.Size([3])), (1, 2), (torch.float, torch.double)
    ):
        state_dict = deepcopy(raw_state_dict)
        for k, v in state_dict.items():
            if m == 1:
                v = v[0]
            if len(train_batch_shape) > 0:
                v = v.unsqueeze(0).expand(*train_batch_shape, *v.shape)
            state_dict[k] = v
        tkwargs = {"device": self.device, "dtype": dtype}
        if m == 2:
            objective = GenericMCObjective(lambda Y, X: Y.sum(dim=-1))
        else:
            objective = None
        for k, v in state_dict.items():
            state_dict[k] = v.to(**tkwargs)
        all_close_kwargs = (
            {"atol": 1e-1, "rtol": 0.0}
            if dtype == torch.float
            else {"atol": 1e-4, "rtol": 0.0}
        )
        torch.manual_seed(1234)
        train_X = torch.rand(*train_batch_shape, 3, 2, **tkwargs)
        train_Y = (
            torch.sin(train_X * 2 * pi)
            + torch.randn(*train_batch_shape, 3, 2, **tkwargs)
        )[..., :m]
        train_Y = standardize(train_Y)
        model = SingleTaskGP(train_X, train_Y)
        if len(train_batch_shape) > 0:
            X_baseline = train_X[0]
        else:
            X_baseline = train_X
        model.load_state_dict(state_dict, strict=False)
        # test sampler with collapse_batch_dims=False
        sampler = IIDNormalSampler(5, seed=0, collapse_batch_dims=False)
        with self.assertRaises(UnsupportedError):
            qNoisyExpectedImprovement(
                model=model,
                X_baseline=X_baseline,
                sampler=sampler,
                objective=objective,
                prune_baseline=False,
                cache_root=True,
            )
        sampler = IIDNormalSampler(5, seed=0)
        torch.manual_seed(0)
        acqf = qNoisyExpectedImprovement(
            model=model,
            X_baseline=X_baseline,
            sampler=sampler,
            objective=objective,
            prune_baseline=False,
            cache_root=True,
        )
        orig_base_samples = acqf.base_sampler.base_samples.detach().clone()
        sampler2 = IIDNormalSampler(5, seed=0)
        sampler2.base_samples = orig_base_samples
        torch.manual_seed(0)
        acqf_no_cache = qNoisyExpectedImprovement(
            model=model,
            X_baseline=X_baseline,
            sampler=sampler2,
            objective=objective,
            prune_baseline=False,
            cache_root=False,
        )
        for q, batch_shape in product(
            (1, 3), (torch.Size([]), torch.Size([3]), torch.Size([4, 3]))
        ):
            test_X = (
                0.3 + 0.05 * torch.randn(*batch_shape, q, 2, **tkwargs)
            ).requires_grad_(True)
            with mock.patch(
                sample_cached_path, wraps=sample_cached_cholesky
            ) as mock_sample_cached:
                torch.manual_seed(0)
                val = acqf(test_X)
                mock_sample_cached.assert_called_once()
            val.sum().backward()
            base_samples = acqf.sampler.base_samples.detach().clone()
            X_grad = test_X.grad.clone()
            test_X2 = test_X.detach().clone().requires_grad_(True)
            acqf_no_cache.sampler.base_samples = base_samples
            with mock.patch(
                sample_cached_path, wraps=sample_cached_cholesky
            ) as mock_sample_cached:
                torch.manual_seed(0)
                val2 = acqf_no_cache(test_X2)
                mock_sample_cached.assert_not_called()
            self.assertTrue(torch.allclose(val, val2, **all_close_kwargs))
            val2.sum().backward()
            self.assertTrue(
                torch.allclose(X_grad, test_X2.grad, **all_close_kwargs)
            )
        # test we fall back to standard sampling for
        # ill-conditioned covariances
        acqf._baseline_L = torch.zeros_like(acqf._baseline_L)
        with warnings.catch_warnings(record=True) as ws, settings.debug(True):
            with torch.no_grad():
                acqf(test_X)
        self.assertEqual(len(ws), 1)
        self.assertTrue(issubclass(ws[-1].category, BotorchWarning))
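# For reference, a hedged usage sketch (not part of the test): the pattern
# exercised above is how `cache_root` is typically used in practice. With
# `cache_root=True`, qNoisyExpectedImprovement caches a Cholesky root of the
# baseline posterior covariance at construction time, so each evaluation only
# needs to draw posterior samples for the new `q` candidate points.
#
#     model = SingleTaskGP(train_X, standardize(train_Y))
#     acqf = qNoisyExpectedImprovement(
#         model=model,
#         X_baseline=train_X,
#         sampler=IIDNormalSampler(128, seed=0),
#         prune_baseline=True,
#         cache_root=True,  # cache the baseline Cholesky factor
#     )
#     value = acqf(candidate_X)  # candidate_X has shape batch x q x d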
def test_sample_cached_cholesky(self):
    torch.manual_seed(0)
    tkwargs = {"device": self.device}
    for dtype in (torch.float, torch.double):
        tkwargs["dtype"] = dtype
        train_X = torch.rand(10, 2, **tkwargs)
        train_Y = torch.randn(10, 2, **tkwargs)
        for m in (1, 2):
            model_list_values = (True, False) if m == 2 else (False,)
            for use_model_list in model_list_values:
                if use_model_list:
                    model = ModelListGP(
                        SingleTaskGP(train_X, train_Y[..., :1]),
                        SingleTaskGP(train_X, train_Y[..., 1:]),
                    )
                else:
                    model = SingleTaskGP(train_X, train_Y[:, :m])
                sampler = IIDNormalSampler(3)
                base_sampler = IIDNormalSampler(3)
                for q in (1, 3, 9):
                    # test batched baseline_L
                    for train_batch_shape in (
                        torch.Size([]),
                        torch.Size([3]),
                        torch.Size([3, 2]),
                    ):
                        # test batched test points
                        for test_batch_shape in (
                            torch.Size([]),
                            torch.Size([4]),
                            torch.Size([4, 2]),
                        ):
                            if len(train_batch_shape) > 0:
                                train_X_ex = train_X.unsqueeze(0).expand(
                                    train_batch_shape + train_X.shape
                                )
                            else:
                                train_X_ex = train_X
                            if len(test_batch_shape) > 0:
                                test_X = train_X_ex.unsqueeze(0).expand(
                                    test_batch_shape + train_X_ex.shape
                                )
                            else:
                                test_X = train_X_ex
                            with torch.no_grad():
                                base_posterior = model.posterior(
                                    train_X_ex[..., :-q, :]
                                )
                                mvn = base_posterior.mvn
                                lazy_covar = mvn.lazy_covariance_matrix
                                if m == 2:
                                    lazy_covar = lazy_covar.base_lazy_tensor
                                baseline_L = lazy_covar.root_decomposition()
                                baseline_L = baseline_L.root.evaluate()
                            test_X = test_X.clone().requires_grad_(True)
                            new_posterior = model.posterior(test_X)
                            samples = sampler(new_posterior)
                            samples[..., -q:, :].sum().backward()
                            test_X2 = test_X.detach().clone().requires_grad_(True)
                            new_posterior2 = model.posterior(test_X2)
                            q_samples = sample_cached_cholesky(
                                posterior=new_posterior2,
                                baseline_L=baseline_L,
                                q=q,
                                base_samples=sampler.base_samples.detach().clone(),
                                sample_shape=sampler.sample_shape,
                            )
                            q_samples.sum().backward()
                            all_close_kwargs = (
                                {"atol": 1e-4, "rtol": 1e-2}
                                if dtype == torch.float
                                else {}
                            )
                            self.assertTrue(
                                torch.allclose(
                                    q_samples.detach(),
                                    samples[..., -q:, :].detach(),
                                    **all_close_kwargs,
                                )
                            )
                            self.assertTrue(
                                torch.allclose(
                                    test_X2.grad[..., -q:, :],
                                    test_X.grad[..., -q:, :],
                                    **all_close_kwargs,
                                )
                            )
                            # Test that adding a new point and base_sample
                            # did not change posterior samples for previous
                            # points. This tests that we properly account
                            # for not interleaving.
                            base_sampler.base_samples = (
                                sampler.base_samples[..., :-q, :].detach().clone()
                            )
                            baseline_samples = base_sampler(base_posterior)
                            new_batch_shape = samples.shape[
                                1 : -baseline_samples.ndim + 1
                            ]
                            expanded_baseline_samples = baseline_samples.view(
                                baseline_samples.shape[0],
                                *[1] * len(new_batch_shape),
                                *baseline_samples.shape[1:],
                            ).expand(
                                baseline_samples.shape[0],
                                *new_batch_shape,
                                *baseline_samples.shape[1:],
                            )
                            self.assertTrue(
                                torch.allclose(
                                    expanded_baseline_samples,
                                    samples[..., :-q, :],
                                    **all_close_kwargs,
                                )
                            )
                            # test nans
                            with torch.no_grad():
                                test_posterior = model.posterior(test_X2)
                            test_posterior.mvn.loc = torch.full_like(
                                test_posterior.mvn.loc, float("nan")
                            )
                            with self.assertRaises(NanError):
                                sample_cached_cholesky(
                                    posterior=test_posterior,
                                    baseline_L=baseline_L,
                                    q=q,
                                    base_samples=sampler.base_samples.detach().clone(),
                                    sample_shape=sampler.sample_shape,
                                )
                            # test infs
                            test_posterior.mvn.loc = torch.full_like(
                                test_posterior.mvn.loc, float("inf")
                            )
                            with self.assertRaises(NanError):
                                sample_cached_cholesky(
                                    posterior=test_posterior,
                                    baseline_L=baseline_L,
                                    q=q,
                                    base_samples=sampler.base_samples.detach().clone(),
                                    sample_shape=sampler.sample_shape,
                                )
                            # test triangular solve raising RuntimeError
                            test_posterior.mvn.loc = torch.full_like(
                                test_posterior.mvn.loc, 0.0
                            )
                            base_samples = sampler.base_samples.detach().clone()
                            with mock.patch(
                                "botorch.utils.low_rank.torch.triangular_solve",
                                side_effect=RuntimeError("singular"),
                            ):
                                with self.assertRaises(NotPSDError):
                                    sample_cached_cholesky(
                                        posterior=test_posterior,
                                        baseline_L=baseline_L,
                                        q=q,
                                        base_samples=base_samples,
                                        sample_shape=sampler.sample_shape,
                                    )
                            with mock.patch(
                                "botorch.utils.low_rank.torch.triangular_solve",
                                side_effect=RuntimeError(""),
                            ):
                                with self.assertRaises(RuntimeError):
                                    sample_cached_cholesky(
                                        posterior=test_posterior,
                                        baseline_L=baseline_L,
                                        q=q,
                                        base_samples=base_samples,
                                        sample_shape=sampler.sample_shape,
                                    )
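# For reference, a hedged sketch (not part of the test) of calling
# `sample_cached_cholesky` directly, as exercised above: given a posterior
# over the joint set of baseline and new points, plus a cached root
# `baseline_L` of the baseline covariance, it returns samples for only the
# trailing `q` points while reusing the cached factor for the rest.
#
#     joint_X = torch.cat([X_baseline, X_new], dim=-2)
#     joint_posterior = model.posterior(joint_X)
#     new_samples = sample_cached_cholesky(
#         posterior=joint_posterior,
#         baseline_L=baseline_L,             # cached root of baseline covariance
#         q=X_new.shape[-2],                 # number of new (trailing) points
#         base_samples=sampler.base_samples,
#         sample_shape=sampler.sample_shape,
#     )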