def get_ALEBO(
    experiment: Experiment,
    search_space: SearchSpace,
    data: Data,
    B: torch.Tensor,
    **model_kwargs: Any,
) -> TorchModelBridge:
    """Build a ``TorchModelBridge`` wrapping an ``ALEBO`` model.

    Args:
        experiment: Experiment handed to the bridge.
        search_space: Search space handed to the bridge. If ``None`` is passed,
            ``experiment.search_space`` is used instead.
            NOTE(review): the annotation says ``SearchSpace`` although ``None``
            is accepted; consider ``Optional[SearchSpace]``.
        data: Data handed to the bridge.
        B: Matrix passed to ``ALEBO(B=...)``. Its dtype and device also become
            the bridge's ``torch_dtype`` and ``torch_device``.
        **model_kwargs: Extra keyword arguments forwarded to ``ALEBO``.

    Returns:
        A ``TorchModelBridge`` whose transforms are ``ALEBO_X_trans`` followed
        by ``Derelativize`` and ``StandardizeY``.
    """
    space = experiment.search_space if search_space is None else search_space
    alebo_model = ALEBO(B=B, **model_kwargs)
    return TorchModelBridge(
        experiment=experiment,
        search_space=space,
        data=data,
        model=alebo_model,
        transforms=ALEBO_X_trans + [Derelativize, StandardizeY],  # pyre-ignore
        torch_dtype=B.dtype,
        torch_device=B.device,
    )
def fit_and_predict_map(B, train_X, train_Y, train_Yvar, test_X, mu, sigma):
    """Fit ALEBO using a MAP estimate and predict at ``test_X``.

    Args:
        B: Matrix passed to ``ALEBO(B=...)``.
        train_X, train_Y, train_Yvar: Training inputs, outcomes and observation
            variances for a single outcome (each is wrapped in a one-item list).
        test_X: Points at which to predict.
        mu, sigma: Offset and scale used to undo standardization of the
            outcomes (presumably the mean/std applied to ``train_Y``; confirm
            against the caller).

    Returns:
        ``(mean, variance)`` on the original scale, computed as
        ``f.squeeze() * sigma + mu`` and ``var.squeeze() * sigma**2``.
    """
    # laplace_nsamp=1 uses the MAP estimate rather than Laplace sampling.
    model = ALEBO(B=B, laplace_nsamp=1)
    model.fit(
        Xs=[train_X],
        Ys=[train_Y],
        Yvars=[train_Yvar],
        bounds=[],
        task_features=[],
        feature_names=[],
        metric_names=[],
        fidelity_features=[],
    )
    mean_std, var_std = model.predict(test_X)
    # Map predictions back from the standardized scale.
    mean = mean_std.squeeze() * sigma + mu
    variance = var_std.squeeze() * sigma**2
    return mean, variance
def testALEBO(self):
    """Exercise the ALEBO model API end to end on a tiny 5-D problem.

    Covers constructor attributes, fit, predict, best_point (expected to raise
    NotImplementedError), gen with ``optimize_acqf`` mocked out (with and
    without clipping), update, get_and_fit_model and cross_validate.
    """
    # B is 2 x 5; the m.Xs checks below expect stored inputs to be
    # (B @ X.t()).t(), i.e. 5-D points projected to 2-D.
    B = torch.tensor(
        [[1.0, 2.0, 3.0, 4.0, 5.0], [2.0, 3.0, 4.0, 5.0, 6.0]], dtype=torch.double
    )
    train_X = torch.tensor(
        [
            [0.0, 0.0, 0.0, 0.0, 0.0],
            [1.0, 1.0, 1.0, 1.0, 1.0],
            [2.0, 2.0, 2.0, 2.0, 2.0],
        ],
        dtype=torch.double,
    )
    train_Y = torch.tensor([[1.0], [2.0], [3.0]], dtype=torch.double)
    train_Yvar = 0.1 * torch.ones(3, 1, dtype=torch.double)
    m = ALEBO(B=B, laplace_nsamp=5, fit_restarts=1)
    # The constructor stores B and the given settings; the remaining flags are
    # expected to hold these values when not passed explicitly.
    self.assertTrue(torch.equal(B, m.B))
    self.assertEqual(m.laplace_nsamp, 5)
    self.assertEqual(m.fit_restarts, 1)
    self.assertEqual(m.refit_on_update, True)
    self.assertEqual(m.refit_on_cv, False)
    self.assertEqual(m.warm_start_refitting, False)
    # Test fit
    # Two (identical) outcomes, so a ModelListGP is expected.
    m.fit(
        Xs=[train_X, train_X],
        Ys=[train_Y, train_Y],
        Yvars=[train_Yvar, train_Yvar],
        bounds=[(-1, 1)] * 5,
        task_features=[],
        feature_names=[],
        metric_names=[],
        fidelity_features=[],
    )
    self.assertIsInstance(m.model, ModelListGP)
    self.assertTrue(torch.allclose(m.Xs[0], (B @ train_X.t()).t()))
    # Test predict
    # The two rows of B serve as test points: 2 points x 2 outcomes.
    f, cov = m.predict(X=B)
    self.assertEqual(f.shape, torch.Size([2, 2]))
    self.assertEqual(cov.shape, torch.Size([2, 2, 2]))
    # Test best point
    objective_weights = torch.tensor([1.0, 0.0], dtype=torch.double)
    with self.assertRaises(NotImplementedError):
        m.best_point(bounds=[(-1, 1)] * 5, objective_weights=objective_weights)
    # Test gen
    # With clipping
    # NOTE(review): the mocked optimizer returns the embedded training points
    # (m.Xs[0]); mapped back to 5-D these presumably fall outside [-1, 1], so
    # gen is expected to clip them into the bounds.
    with mock.patch(
        "ax.models.torch.alebo.optimize_acqf",
        autospec=True,
        return_value=(m.Xs[0], torch.tensor([])),
    ):
        Xopt, _, _ = m.gen(
            n=1,
            bounds=[(-1, 1)] * 5,
            objective_weights=torch.tensor([1.0, 0.0], dtype=torch.double),
        )
    self.assertFalse(torch.allclose(Xopt, train_X))
    self.assertTrue(Xopt.min() >= -1)
    self.assertTrue(Xopt.max() <= 1)
    # Without clipping
    # The expected point is the minimum-norm solution x of B @ x = [1, 1]
    # (i.e. pinv(B) @ [1, 1]), which lies inside the bounds; presumably gen
    # maps embedded candidates back to 5-D via the pseudo-inverse of B.
    with mock.patch(
        "ax.models.torch.alebo.optimize_acqf",
        autospec=True,
        return_value=(torch.ones(1, 2, dtype=torch.double), torch.tensor([])),
    ):
        Xopt, _, _ = m.gen(
            n=1,
            bounds=[(-1, 1)] * 5,
            objective_weights=torch.tensor([1.0, 0.0], dtype=torch.double),
        )
    self.assertTrue(
        torch.allclose(
            Xopt, torch.tensor([[-0.2, -0.1, 0.0, 0.1, 0.2]], dtype=torch.double)))
    # Test update
    train_X2 = torch.tensor(
        [
            [3.0, 3.0, 3.0, 3.0, 3.0],
            [1.0, 1.0, 1.0, 1.0, 1.0],
            [2.0, 2.0, 2.0, 2.0, 2.0],
        ],
        dtype=torch.double,
    )
    m.update(
        Xs=[train_X, train_X2],
        Ys=[train_Y, train_Y],
        Yvars=[train_Yvar, train_Yvar],
    )
    # Each outcome's stored inputs should be its own B-projected training set.
    self.assertTrue(torch.allclose(m.Xs[0], (B @ train_X.t()).t()))
    self.assertTrue(torch.allclose(m.Xs[1], (B @ train_X2.t()).t()))
    # Exercise the update path with refitting disabled (nothing is asserted
    # afterwards; this only checks that it runs).
    m.refit_on_update = False
    m.update(
        Xs=[train_X, train_X2],
        Ys=[train_Y, train_Y],
        Yvars=[train_Yvar, train_Yvar],
    )
    # Test get_and_fit with single metric
    # Inputs are passed already projected by B.
    gp = m.get_and_fit_model(
        Xs=[(B @ train_X.t()).t()], Ys=[train_Y], Yvars=[train_Yvar]
    )
    self.assertIsInstance(gp, ALEBOGP)
    # Test cross_validate
    f, cov = m.cross_validate(
        Xs_train=[train_X],
        Ys_train=[train_Y],
        Yvars_train=[train_Yvar],
        X_test=train_X2,
    )
    self.assertEqual(f.shape, torch.Size([3, 1]))
    self.assertEqual(cov.shape, torch.Size([3, 1, 1]))
    # Repeat with refitting enabled for cross-validation.
    m.refit_on_cv = True
    f, cov = m.cross_validate(
        Xs_train=[train_X],
        Ys_train=[train_Y],
        Yvars_train=[train_Yvar],
        X_test=train_X2,
    )
    self.assertEqual(f.shape, torch.Size([3, 1]))
    self.assertEqual(cov.shape, torch.Size([3, 1, 1]))
def fit_and_predict_alebo(B, train_X, train_Y, train_Yvar, test_X, mu, sigma):
    """Fit ALEBO with its default settings and predict at ``test_X``.

    Args:
        B: Matrix passed to ``ALEBO(B=...)``.
        train_X, train_Y, train_Yvar: Training inputs, outcomes and observation
            variances for a single outcome (each is wrapped in a one-item list).
        test_X: Points at which to predict.
        mu, sigma: Offset and scale used to undo standardization of the
            outcomes (presumably the mean/std applied to ``train_Y``; confirm
            against the caller).

    Returns:
        ``(mean, variance)`` on the original scale, computed as
        ``f.squeeze() * sigma + mu`` and ``var.squeeze() * sigma**2``.
    """
    model = ALEBO(B=B)
    model.fit(
        Xs=[train_X],
        Ys=[train_Y],
        Yvars=[train_Yvar],
        bounds=[],
        task_features=[],
        feature_names=[],
        metric_names=[],
        fidelity_features=[],
    )
    mean_std, var_std = model.predict(test_X)
    # Map predictions back from the standardized scale.
    mean = mean_std.squeeze() * sigma + mu
    variance = var_std.squeeze() * sigma**2
    return mean, variance