Example #1
 def test_gen_batch_initial_conditions_highdim(self):
     d = 120
     bounds = torch.stack([torch.zeros(d), torch.ones(d)])
     for dtype in (torch.float, torch.double):
         bounds = bounds.to(device=self.device, dtype=dtype)
         for nonnegative in (True, False):
             for seed in (None, 1234):
                 with warnings.catch_warnings(
                         record=True) as ws, settings.debug(True):
                     batch_initial_conditions = gen_batch_initial_conditions(
                         acq_function=MockAcquisitionFunction(),
                         bounds=bounds,
                         q=10,
                         num_restarts=1,
                         raw_samples=2,
                         options={
                             "nonnegative": nonnegative,
                             "eta": 0.01,
                             "alpha": 0.1,
                             "seed": seed,
                         },
                     )
                     self.assertTrue(
                         any(
                             issubclass(w.category, SamplingWarning)
                             for w in ws))
                 expected_shape = torch.Size([1, 10, d])
                 self.assertEqual(batch_initial_conditions.shape,
                                  expected_shape)
                 self.assertEqual(batch_initial_conditions.device,
                                  bounds.device)
                 self.assertEqual(batch_initial_conditions.dtype,
                                  bounds.dtype)
Example #2
 def test_gen_batch_initial_conditions_warning(self):
     for dtype in (torch.float, torch.double):
         bounds = torch.tensor([[0, 0], [1, 1]],
                               device=self.device,
                               dtype=dtype)
         samples = torch.zeros(10, 1, 2, device=self.device, dtype=dtype)
         with ExitStack() as es:
             ws = es.enter_context(warnings.catch_warnings(record=True))
             es.enter_context(settings.debug(True))
             es.enter_context(
                 mock.patch(
                     "botorch.optim.initializers.draw_sobol_samples",
                     return_value=samples,
                 ))
             batch_initial_conditions = gen_batch_initial_conditions(
                 acq_function=MockAcquisitionFunction(),
                 bounds=bounds,
                 q=1,
                 num_restarts=2,
                 raw_samples=10,
                 options={"seed": 1234},
             )
             self.assertEqual(len(ws), 1)
             self.assertTrue(
                 any(
                     issubclass(w.category, BadInitialCandidatesWarning)
                     for w in ws))
             self.assertTrue(
                 torch.equal(
                     batch_initial_conditions,
                     torch.zeros(2, 1, 2, device=self.device, dtype=dtype),
                 ))
Example #3
 def test_gen_batch_initial_conditions(self):
     for dtype in (torch.float, torch.double):
         bounds = torch.tensor([[0, 0], [1, 1]],
                               device=self.device,
                               dtype=dtype)
         for nonnegative in (True, False):
             for seed in (None, 1234):
                 batch_initial_conditions = gen_batch_initial_conditions(
                     acq_function=MockAcquisitionFunction(),
                     bounds=bounds,
                     q=1,
                     num_restarts=2,
                     raw_samples=10,
                     options={
                         "nonnegative": nonnegative,
                         "eta": 0.01,
                         "alpha": 0.1,
                         "seed": seed,
                     },
                 )
                 expected_shape = torch.Size([2, 1, 2])
                 self.assertEqual(batch_initial_conditions.shape,
                                  expected_shape)
                 self.assertEqual(batch_initial_conditions.device,
                                  bounds.device)
                 self.assertEqual(batch_initial_conditions.dtype,
                                  bounds.dtype)
Example #4
    def test_gen_batch_initial_conditions(self):
        bounds = torch.stack([torch.zeros(2), torch.ones(2)])
        mock_acqf = MockAcquisitionFunction()
        mock_acqf.objective = lambda y: y.squeeze(-1)
        for dtype in (torch.float, torch.double):
            bounds = bounds.to(device=self.device, dtype=dtype)
            mock_acqf.X_baseline = bounds  # for testing sample_around_best
            mock_acqf.model = MockModel(MockPosterior(mean=bounds[:, :1]))
            for nonnegative, seed, init_batch_limit, ffs, sample_around_best in product(
                [True, False], [None, 1234], [None, 1], [None, {0: 0.5}], [True, False]
            ):
                with mock.patch.object(
                    MockAcquisitionFunction,
                    "__call__",
                    wraps=mock_acqf.__call__,
                ) as mock_acqf_call:
                    batch_initial_conditions = gen_batch_initial_conditions(
                        acq_function=mock_acqf,
                        bounds=bounds,
                        q=1,
                        num_restarts=2,
                        raw_samples=10,
                        fixed_features=ffs,
                        options={
                            "nonnegative": nonnegative,
                            "eta": 0.01,
                            "alpha": 0.1,
                            "seed": seed,
                            "init_batch_limit": init_batch_limit,
                            "sample_around_best": sample_around_best,
                        },
                    )
                    expected_shape = torch.Size([2, 1, 2])
                    self.assertEqual(batch_initial_conditions.shape, expected_shape)
                    self.assertEqual(batch_initial_conditions.device, bounds.device)
                    self.assertEqual(batch_initial_conditions.dtype, bounds.dtype)
                    raw_samps = mock_acqf_call.call_args[0][0]
                    batch_shape = (
                        torch.Size([20 if sample_around_best else 10])
                        if init_batch_limit is None
                        else torch.Size([init_batch_limit])
                    )
                    expected_raw_samps_shape = batch_shape + torch.Size([1, 2])
                    self.assertEqual(raw_samps.shape, expected_raw_samps_shape)

                    if ffs is not None:
                        for idx, val in ffs.items():
                            self.assertTrue(
                                torch.all(batch_initial_conditions[..., idx] == val)
                            )
Example #5
 def optimize_acqf_and_get_observation(self, acq, seed):
     """Optimizes the acquisition function, and returns a new candidate."""
     init = initializers.gen_batch_initial_conditions(
         acq, self.bounds, options={"seed": seed}, **self.optim_kwargs)
     # optimize
     candidate, acq_value = optimize_acqf(acq,
                                          bounds=self.bounds,
                                          batch_initial_conditions=init,
                                          **self.optim_kwargs)
     # observe new value
     new_x = candidate.detach()  #self.scale_to_bounds(candidate.detach())
     return new_x
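The helper above relies on attributes defined elsewhere (self.bounds, self.optim_kwargs, self.scale_to_bounds). A self-contained sketch of the same pattern, assuming a toy SingleTaskGP model and an ExpectedImprovement acquisition function (standard BoTorch classes used here only as stand-ins for the example's own setup):

import torch
from botorch.acquisition import ExpectedImprovement
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf
from botorch.optim.initializers import gen_batch_initial_conditions

# Toy data and model; hyperparameter fitting is omitted for brevity.
train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
acqf = ExpectedImprovement(model, best_f=train_Y.max())

bounds = torch.stack([torch.zeros(2), torch.ones(2)]).to(torch.double)

# num_restarts x q x d tensor of heuristically chosen starting points.
init = gen_batch_initial_conditions(
    acq_function=acqf,
    bounds=bounds,
    q=1,
    num_restarts=4,
    raw_samples=64,
    options={"seed": 1234},
)

# Reuse those starting points instead of letting optimize_acqf draw its own.
candidate, acq_value = optimize_acqf(
    acq_function=acqf,
    bounds=bounds,
    q=1,
    num_restarts=4,
    raw_samples=64,
    batch_initial_conditions=init,
)
new_x = candidate.detach()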
Example #6
 def test_gen_batch_initial_conditions(self):
     for dtype in (torch.float, torch.double):
         bounds = torch.tensor([[0, 0], [1, 1]],
                               device=self.device,
                               dtype=dtype)
         for nonnegative in (True, False):
             for seed in (None, 1234):
                 for init_batch_limit in (None, 1):
                     mock_acqf = MockAcquisitionFunction()
                     with mock.patch.object(
                             MockAcquisitionFunction,
                             "__call__",
                             wraps=mock_acqf.__call__,
                     ) as mock_acqf_call:
                         batch_initial_conditions = gen_batch_initial_conditions(
                             acq_function=mock_acqf,
                             bounds=bounds,
                             q=1,
                             num_restarts=2,
                             raw_samples=10,
                             options={
                                 "nonnegative": nonnegative,
                                 "eta": 0.01,
                                 "alpha": 0.1,
                                 "seed": seed,
                                 "init_batch_limit": init_batch_limit,
                             },
                         )
                         expected_shape = torch.Size([2, 1, 2])
                         self.assertEqual(batch_initial_conditions.shape,
                                          expected_shape)
                         self.assertEqual(batch_initial_conditions.device,
                                          bounds.device)
                         self.assertEqual(batch_initial_conditions.dtype,
                                          bounds.dtype)
                         raw_samps = mock_acqf_call.call_args[0][0]
                         batch_shape = (torch.Size([10])
                                        if init_batch_limit is None else
                                        torch.Size([init_batch_limit]))
                         expected_raw_samps_shape = batch_shape + torch.Size(
                             [1, 2])
                         self.assertEqual(raw_samps.shape,
                                          expected_raw_samps_shape)
Example #7
    def test_gen_batch_initial_conditions_highdim(self):
        d = 2200  # 2200 * 10 (q) > 21201 (sobol max dim)
        bounds = torch.stack([torch.zeros(d), torch.ones(d)])
        ffs_map = {i: random() for i in range(0, d, 2)}
        mock_acqf = MockAcquisitionFunction()
        mock_acqf.objective = lambda y: y.squeeze(-1)
        for dtype in (torch.float, torch.double):
            bounds = bounds.to(device=self.device, dtype=dtype)
            mock_acqf.X_baseline = bounds  # for testing sample_around_best
            mock_acqf.model = MockModel(MockPosterior(mean=bounds[:, :1]))

            for nonnegative, seed, ffs, sample_around_best in product(
                [True, False], [None, 1234], [None, ffs_map], [True, False]):
                with warnings.catch_warnings(
                        record=True) as ws, settings.debug(True):
                    batch_initial_conditions = gen_batch_initial_conditions(
                        acq_function=MockAcquisitionFunction(),
                        bounds=bounds,
                        q=10,
                        num_restarts=1,
                        raw_samples=2,
                        fixed_features=ffs,
                        options={
                            "nonnegative": nonnegative,
                            "eta": 0.01,
                            "alpha": 0.1,
                            "seed": seed,
                            "sample_around_best": sample_around_best,
                        },
                    )
                    self.assertTrue(
                        any(
                            issubclass(w.category, SamplingWarning)
                            for w in ws))
                expected_shape = torch.Size([1, 10, d])
                self.assertEqual(batch_initial_conditions.shape,
                                 expected_shape)
                self.assertEqual(batch_initial_conditions.device,
                                 bounds.device)
                self.assertEqual(batch_initial_conditions.dtype, bounds.dtype)
                if ffs is not None:
                    for idx, val in ffs.items():
                        self.assertTrue(
                            torch.all(batch_initial_conditions[...,
                                                               idx] == val))
Example #8
    def get_acqui_fun_maximizer(self):

        logger.info(
            "Computing next candidate by maximizing the acquisition function ..."
        )
        batch_limit = 2
        # batch_limit = 50 # This is a super bad idea for GPCR.
        options = {
            "batch_limit": batch_limit,
            "maxiter": 300,
            "ftol": 1e-6,
            "method": "L-BFGS-B",
            "iprint": 2,
            "maxls": 20,
            "disp": self.disp_info_scipy_opti
        }

        # Get initial random restart points:
        logger.info("Generating random restarts ...")
        initial_conditions = gen_batch_initial_conditions(
            acq_function=self,
            bounds=self.bounds,
            q=1,
            num_restarts=self.Nrestarts,
            raw_samples=500,
            options=options)
        # logger.info("initial_conditions:" + str(initial_conditions))

        logger.info("Using nlopt ...")
        x_next, alpha_next = self.constrained_opt.run_optimization(
            initial_conditions.view((self.Nrestarts, self.dim)))

        # # TODO: Is this really needed?
        # prob_val = self.get_probability_of_safe_evaluation(x_next.unsqueeze(1))
        # if prob_val < self.rho_conserv:
        # 	logger.info("(Is this really needed????) x_next violates the probabilistic constraint...")
        # 	pdb.set_trace()

        logger.info("Done!")

        return x_next, alpha_next
Example #9
    def optimize_acqui_use_restarts_individually(self, options):

        # Get initial random restart points:
        self.my_print("[get_next_point()] Generating random restarts ...")
        initial_conditions = gen_batch_initial_conditions(
            acq_function=self,
            bounds=self.bounds,
            q=1,
            num_restarts=self.Nrestarts,
            raw_samples=500,
            options=options)

        self.my_print(
            "[get_next_point()] Optimizing acquisition function with {0:d} restarts ..."
            .format(self.Nrestarts))
        x_next_many = torch.zeros(size=(self.Nrestarts, 1, self.dim))
        alpha_next_many = torch.zeros(size=(self.Nrestarts, ))
        for k in range(self.Nrestarts):

            if (k + 1) % 5 == 0:
                self.my_print(
                    "[get_next_point()] restart {0:d} / {1:d}".format(
                        k + 1, self.Nrestarts))

            x_next_many[k, :], alpha_next_many[k] = gen_candidates_scipy(
                initial_conditions=initial_conditions[k, :].view(
                    (1, 1, self.dim)),
                acquisition_function=self,
                lower_bounds=0.0,
                upper_bounds=1.0,
                options=options)

        # Get the best:
        self.my_print("[get_next_point()] Getting best candidates ...")
        x_next = get_best_candidates(x_next_many, alpha_next_many).detach()
        alpha_next = self.forward(x_next).detach()

        return x_next, alpha_next
Example #10
    def _optimize_acqui_use_restarts_individually(self):

        # Get initial random restart points:
        logger.info("  Generating random restarts ...")
        options = {
            "maxiter": 200,
            "ftol": 1e-9,
            "method": "L-BFGS-B",
            "iprint": 2,
            "maxls": 20,
            "disp": self.disp_info_scipy_opti
        }
        bounds = torch.tensor(self.hyperpars_bounds,
                              device=device,
                              dtype=dtype)
        initial_conditions = gen_batch_initial_conditions(
            acq_function=self.mll_objective,
            bounds=bounds,
            q=1,
            num_restarts=self.Nrestarts,
            raw_samples=500,
            options=options)

        logger.info(
            "  Optimizing loss function with {0:d} restarts ...".format(
                self.Nrestarts))

        new_hyperpars, _ = self.opti_hyperpars.run_optimization(
            x_restarts=initial_conditions.view(self.Nrestarts,
                                               self.dim_hyperpars))

        logger.info("  Done!")

        return new_hyperpars
Example #11
    def gen(
        self,
        num_points: int,  # Current implementation only generates 1 point at a time
        model: MonotonicRejectionGP,
    ):
        """Query next point(s) to run by optimizing the acquisition function.
        Args:
            num_points (int, optional): Number of points to query.
            model (AEPsychMixin): Fitted model of the data.
        Returns:
            np.ndarray: Next set of point(s) to evaluate, [num_points x dim].
        """

        options = self.model_gen_options or {}
        num_restarts = options.get("num_restarts", 10)
        raw_samples = options.get("raw_samples", 1000)
        verbosity_freq = options.get("verbosity_freq", -1)
        lr = options.get("lr", 0.01)
        momentum = options.get("momentum", 0.9)
        nesterov = options.get("nesterov", True)
        epochs = options.get("epochs", 50)
        milestones = options.get("milestones", [25, 40])
        gamma = options.get("gamma", 0.1)
        loss_constraint_fun = options.get(
            "loss_constraint_fun", default_loss_constraint_fun
        )

        # Augment bounds with deriv indicator
        bounds = torch.cat((model.bounds_, torch.zeros(2, 1)), dim=1)
        # Fix deriv indicator to 0 during optimization
        fixed_features = {(bounds.shape[1] - 1): 0.0}
        # Fix explore features to random values
        if self.explore_features is not None:
            for idx in self.explore_features:
                val = (
                    bounds[0, idx]
                    + torch.rand(1, dtype=bounds.dtype)
                    * (bounds[1, idx] - bounds[0, idx])
                ).item()
                fixed_features[idx] = val
                bounds[0, idx] = val
                bounds[1, idx] = val

        acqf = self._instantiate_acquisition_fn(model)

        # Initialize
        batch_initial_conditions = gen_batch_initial_conditions(
            acq_function=acqf,
            bounds=bounds,
            q=1,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
        )
        clamped_candidates = columnwise_clamp(
            X=batch_initial_conditions, lower=bounds[0], upper=bounds[1]
        ).requires_grad_(True)
        candidates = fix_features(clamped_candidates, fixed_features)
        optimizer = torch.optim.SGD(
            params=[clamped_candidates], lr=lr, momentum=momentum, nesterov=nesterov
        )
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=milestones, gamma=gamma
        )

        # Optimize
        for epoch in range(epochs):
            loss = -acqf(candidates).sum()

            # adjust loss based on constraints on candidates
            loss = loss_constraint_fun(loss, candidates)

            if verbosity_freq > 0 and epoch % verbosity_freq == 0:
                logger.info("Iter: {} - Value: {:.3f}".format(epoch, -(loss.item())))

            def closure():
                optimizer.zero_grad()
                loss.backward(
                    retain_graph=True
                )  # Variational model requires retain_graph
                return loss

            optimizer.step(closure)
            clamped_candidates.data = columnwise_clamp(
                X=clamped_candidates, lower=bounds[0], upper=bounds[1]
            )
            candidates = fix_features(clamped_candidates, fixed_features)
            lr_scheduler.step()

        # Extract best point
        with torch.no_grad():
            batch_acquisition = acqf(candidates)
        best = torch.argmax(batch_acquisition.view(-1), dim=0)
        Xopt = candidates[best][:, :-1].detach()
        return Xopt
Example #12
 def test_gen_batch_initial_conditions_constraints(self):
     for dtype in (torch.float, torch.double):
         bounds = torch.tensor([[0, 0], [1, 1]],
                               device=self.device,
                               dtype=dtype)
         inequality_constraints = [(
             torch.tensor([1], device=self.device, dtype=dtype),
             torch.tensor([-4], device=self.device, dtype=dtype),
             torch.tensor(-3, device=self.device, dtype=dtype),
         )]
         equality_constraints = [(
             torch.tensor([0], device=self.device, dtype=dtype),
             torch.tensor([1], device=self.device, dtype=dtype),
             torch.tensor(0.5, device=self.device, dtype=dtype),
         )]
         for nonnegative in (True, False):
             for seed in (None, 1234):
                 for init_batch_limit in (None, 1):
                     mock_acqf = MockAcquisitionFunction()
                     with mock.patch.object(
                             MockAcquisitionFunction,
                             "__call__",
                             wraps=mock_acqf.__call__,
                     ) as mock_acqf_call:
                         batch_initial_conditions = gen_batch_initial_conditions(
                             acq_function=mock_acqf,
                             bounds=bounds,
                             q=1,
                             num_restarts=2,
                             raw_samples=10,
                             options={
                                 "nonnegative": nonnegative,
                                 "eta": 0.01,
                                 "alpha": 0.1,
                                 "seed": seed,
                                 "init_batch_limit": init_batch_limit,
                                 "thinning": 2,
                                 "n_burnin": 3,
                             },
                             inequality_constraints=inequality_constraints,
                             equality_constraints=equality_constraints,
                         )
                         expected_shape = torch.Size([2, 1, 2])
                         self.assertEqual(batch_initial_conditions.shape,
                                          expected_shape)
                         self.assertEqual(batch_initial_conditions.device,
                                          bounds.device)
                         self.assertEqual(batch_initial_conditions.dtype,
                                          bounds.dtype)
                         raw_samps = mock_acqf_call.call_args[0][0]
                         batch_shape = (torch.Size([10])
                                        if init_batch_limit is None else
                                        torch.Size([init_batch_limit]))
                         expected_raw_samps_shape = batch_shape + torch.Size(
                             [1, 2])
                         self.assertEqual(raw_samps.shape,
                                          expected_raw_samps_shape)
                         self.assertTrue((raw_samps[..., 0] == 0.5).all())
                         self.assertTrue(
                             (-4 * raw_samps[..., 1] >= -3).all())
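The constraint tuples above follow BoTorch's (indices, coefficients, rhs) encoding: sum_i coefficients[i] * x[indices[i]] >= rhs for an inequality and == rhs for an equality. Here they read -4 * x[1] >= -3 and x[0] == 0.5, which is exactly what the final two assertions check. A minimal sketch of passing the same encoding to optimize_acqf, reusing the toy model and acqf from the sketch after Example #5:

import torch
from botorch.optim import optimize_acqf

bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double)

# -4 * x[1] >= -3, i.e. x[1] <= 0.75
inequality_constraints = [
    (torch.tensor([1]), torch.tensor([-4.0], dtype=torch.double), -3.0)
]
# x[0] == 0.5
equality_constraints = [
    (torch.tensor([0]), torch.tensor([1.0], dtype=torch.double), 0.5)
]

candidate, acq_value = optimize_acqf(
    acq_function=acqf,  # toy ExpectedImprovement from the earlier sketch
    bounds=bounds,
    q=1,
    num_restarts=2,
    raw_samples=32,
    inequality_constraints=inequality_constraints,
    equality_constraints=equality_constraints,
)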
Example #13
    def get_safe_evaluation(self, rho_t):

        # Gather fmin samples, using the Frechet distribution:
        fmin_samples = get_fmin_samples_from_gp(
            model=self.model_list[0],
            Nsamples=self.Nsamples_fmin,
            eta=self.eta_c)  # This assumes self.eta has been updated
        self.update_u_vec(fmin_samples)

        self.which_mode = "safe"

        self.my_print(
            "[get_safe_evaluation()] Computing next candidate by maximizing the acquisition function ..."
        )
        options = {
            "batch_limit": 50,
            "maxiter": 300,
            "ftol": 1e-6,
            "method": self.method_safe,
            "iprint": 2,
            "maxls": 20,
            "disp": self.disp_info_scipy_opti
        }
        # x_next, alpha_next = optimize_acqf(acq_function=self,bounds=self.bounds,q=1,num_restarts=self.Nrestarts_safe,raw_samples=500,return_best_only=True,options=options)
        # pdb.set_trace()

        # Get initial random restart points:
        self.my_print("[get_safe_evaluation()] Generating random restarts ...")
        initial_conditions = gen_batch_initial_conditions(
            acq_function=self,
            bounds=self.bounds,
            q=1,
            num_restarts=self.Nrestarts_safe,
            raw_samples=500,
            options=options)
        # print("initial_conditions.shape:",initial_conditions.shape)

        # BoTorch does not support constrained optimization with non-linear constraints. As a work-around, it uses
        # a sigmoid to push the acquisition function to zero in regions where the probabilistic constraint is not
        # satisfied (i.e., areas where Pr(g(x) <= 0) < rho_t).
        self.my_print(
            "[get_safe_evaluation()] Optimizing acquisition function ...")
        x_next_many, alpha_next_many = gen_candidates_scipy(
            initial_conditions=initial_conditions,
            acquisition_function=self,
            lower_bounds=0.0,
            upper_bounds=1.0,
            options=options)
        # Get the best:
        self.my_print("[get_safe_evaluation()] Getting best candidates ...")
        x_next = get_best_candidates(x_next_many, alpha_next_many)

        # pdb.set_trace()

        # However, the above optimization does not guarantee that the constraint is satisfied: the sigmoid keeps a
        # small but non-zero mass in unsafe regions, so a maximum can land there when the acquisition function is
        # even closer to zero everywhere in the safe region. In that case we trigger a proper non-linear optimizer
        # that handles the constraint explicitly.
        if self.probabilistic_constraint(
                x_next
        ) > 1e-6:  # If the constraint is violated above a tolerance, use nlopt
            self.my_print(
                "[get_safe_evaluation()] scipy optimization recommended an unfeasible point. Re-run using nlopt ..."
            )
            self.use_nlopt = True
            x_next, alpha_next = self.constrained_opt.run_constrained_minimization(
                initial_conditions.view((self.Nrestarts_safe, self.dim)))
            self.use_nlopt = False
        else:
            self.my_print(
                "[get_safe_evaluation()] scipy optimization finished successfully!"
            )
            alpha_next = self.forward(x_next)

        self.my_print("Pr(g(x_next) <= 0): {0:2.8f}".format(
            self.get_probability_of_safe_evaluation(x_next).item()))

        # Using botorch optimizer:
        # x_next, alpha_next = optimize_acqf(acq_function=self,bounds=self.bounds,q=1,num_restarts=self.Nrestarts_safe,raw_samples=500,return_best_only=True,options=options)

        # # The code below spits out: Unknown solver options: constraints. Using nlopt instead
        # constraints = [dict(type="ineq",fun=self.probabilistic_constraint)]
        # options = {"batch_limit": 1, "maxiter": 200, "ftol": 1e-6, "method": self.method_risky, "constraints": constraints}
        # x_next,alpha_next = optimize_acqf(acq_function=self,bounds=self.bounds,q=1,num_restarts=self.Nrestarts,
        # 																	raw_samples=500,return_best_only=True,options=options,)

        self.my_print("Done!")

        return x_next, alpha_next
Example #14
    def gen(
        self,
        model_gen_options: Optional[Dict[str, Any]] = None,
        explore_features: Optional[List[int]] = None,
    ) -> Tuple[Tensor, Optional[List[Dict[str, Any]]]]:
        """Generate candidate by optimizing acquisition function.

        Args:
            model_gen_options: Dictionary with options for generating candidate, such as
                SGD parameters. See code for all options and their defaults.
            explore_features: List of features that will be selected randomly and then
                fixed for acquisition fn optimization.

        Returns:
            Xopt: (1 x d) tensor of the generated candidate
            candidate_metadata: List of dict of metadata for each candidate. Contains
                acquisition value for the candidate.
        """
        # Default optimization settings
        # TODO are these sufficiently robust? Can they be tuned better?
        options = model_gen_options or {}
        num_restarts = options.get("num_restarts", 10)
        raw_samples = options.get("raw_samples", 1000)
        verbosity_freq = options.get("verbosity_freq", -1)
        lr = options.get("lr", 0.01)
        momentum = options.get("momentum", 0.9)
        nesterov = options.get("nesterov", True)
        epochs = options.get("epochs", 50)
        milestones = options.get("milestones", [25, 40])
        gamma = options.get("gamma", 0.1)
        loss_constraint_fun = options.get(
            "loss_constraint_fun", default_loss_constraint_fun
        )

        acq_function = self._get_acquisition_fn()
        # Augment bounds with deriv indicator
        bounds = torch.cat((self.bounds_, torch.zeros(2, 1, dtype=self.dtype)), dim=1)
        # Fix deriv indicator to 0 during optimization
        fixed_features = {(bounds.shape[1] - 1): 0.0}
        # Fix explore features to random values
        if explore_features is not None:
            for idx in explore_features:
                val = (
                    bounds[0, idx]
                    + torch.rand(1, dtype=self.dtype)
                    * (bounds[1, idx] - bounds[0, idx])
                ).item()
                fixed_features[idx] = val
                bounds[0, idx] = val
                bounds[1, idx] = val

        # Initialize
        batch_initial_conditions = gen_batch_initial_conditions(
            acq_function=acq_function,
            bounds=bounds,
            q=1,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
        )
        clamped_candidates = columnwise_clamp(
            X=batch_initial_conditions, lower=bounds[0], upper=bounds[1]
        ).requires_grad_(True)
        candidates = fix_features(clamped_candidates, fixed_features)
        optimizer = torch.optim.SGD(
            params=[clamped_candidates], lr=lr, momentum=momentum, nesterov=nesterov
        )
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=milestones, gamma=gamma
        )

        # Optimize
        for epoch in range(epochs):
            loss = -acq_function(candidates).sum()

            # adjust loss based on constraints on candidates
            loss = loss_constraint_fun(loss, candidates)

            if verbosity_freq > 0 and epoch % verbosity_freq == 0:
                logger.info("Iter: {} - Value: {:.3f}".format(epoch, -(loss.item())))

            def closure():
                optimizer.zero_grad()
                loss.backward(
                    retain_graph=True
                )  # Variational model requires retain_graph
                return loss

            optimizer.step(closure)
            clamped_candidates.data = columnwise_clamp(
                X=clamped_candidates, lower=bounds[0], upper=bounds[1]
            )
            candidates = fix_features(clamped_candidates, fixed_features)
            lr_scheduler.step()

        # Extract best point
        with torch.no_grad():
            batch_acquisition = acq_function(candidates)
        best = torch.argmax(batch_acquisition.view(-1), dim=0)
        Xopt = candidates[best][:, :-1].detach()
        candidate_metadata = [{"acquisition_value": batch_acquisition[best].item()}]
        return Xopt, candidate_metadata