Example #1
    def __init__(
        self,
        model: Model,
        X_baseline: Tensor,
        sampler: Optional[MCSampler] = None,
        objective: Optional[MCAcquisitionObjective] = None,
        X_pending: Optional[Tensor] = None,
        prune_baseline: bool = False,
        **kwargs: Any,
    ) -> None:
        r"""q-Noisy Expected Improvement.

        Args:
            model: A fitted model.
            X_baseline: A `batch_shape x r x d`-dim Tensor of `r` design points
                that have already been observed. These points are considered
                candidates for the best design point.
            sampler: The sampler used to draw base samples. Defaults to
                `SobolQMCNormalSampler(num_samples=500, collapse_batch_dims=True)`.
            objective: The MCAcquisitionObjective under which the samples are
                evaluated. Defaults to `IdentityMCObjective()`.
            X_pending: A `batch_shape x m x d`-dim Tensor of `m` design points
                that have been submitted for function evaluation but have not
                yet been evaluated. Concatenated into `X` upon forward call.
                Copied and set to have no gradient.
            prune_baseline: If True, remove points in `X_baseline` that are
                highly unlikely to be the best point. This can significantly
                improve performance and is generally recommended. In order to
                customize pruning parameters, instead manually call
                `botorch.acquisition.utils.prune_inferior_points` on `X_baseline`
                before instantiating the acquisition function.
        """
        super().__init__(
            model=model, sampler=sampler, objective=objective, X_pending=X_pending
        )
        if prune_baseline:
            X_baseline = prune_inferior_points(
                model=model,
                X=X_baseline,
                objective=objective,
                marginalize_dim=kwargs.get("marginalize_dim"),
            )
        self.register_buffer("X_baseline", X_baseline)
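
A minimal usage sketch for this version of the constructor (the model setup, data,
and tensor shapes below are illustrative assumptions, not part of the example above):

    import torch
    from botorch.models import SingleTaskGP
    from botorch.fit import fit_gpytorch_model
    from gpytorch.mlls import ExactMarginalLogLikelihood
    from botorch.acquisition.monte_carlo import qNoisyExpectedImprovement

    # hypothetical training data: 20 observed points in a 2-d design space
    train_X = torch.rand(20, 2, dtype=torch.double)
    train_Y = train_X.sum(dim=-1, keepdim=True)
    model = SingleTaskGP(train_X, train_Y)
    fit_gpytorch_model(ExactMarginalLogLikelihood(model.likelihood, model))

    # prune_baseline=True drops observed points that are highly unlikely
    # to be the best before they enter the qNEI computation
    acqf = qNoisyExpectedImprovement(model=model, X_baseline=train_X, prune_baseline=True)
    values = acqf(torch.rand(5, 1, 2, dtype=torch.double))  # t-batch of 5 candidates with q=1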
Example #2
    def __init__(
        self,
        model: Model,
        X_baseline: Tensor,
        sampler: Optional[MCSampler] = None,
        objective: Optional[MCAcquisitionObjective] = None,
        posterior_transform: Optional[PosteriorTransform] = None,
        X_pending: Optional[Tensor] = None,
        prune_baseline: bool = False,
        cache_root: bool = True,
        **kwargs: Any,
    ) -> None:
        r"""q-Noisy Expected Improvement.

        Args:
            model: A fitted model.
            X_baseline: A `batch_shape x r x d`-dim Tensor of `r` design points
                that have already been observed. These points are considered
                candidates for the best design point.
            sampler: The sampler used to draw base samples. Defaults to
                `SobolQMCNormalSampler(num_samples=512, collapse_batch_dims=True)`.
            objective: The MCAcquisitionObjective under which the samples are
                evaluated. Defaults to `IdentityMCObjective()`.
            posterior_transform: A PosteriorTransform (optional).
            X_pending: A `batch_shape x m x d`-dim Tensor of `m` design points
                that have been submitted for function evaluation but have not
                yet been evaluated. Concatenated into `X` upon forward call.
                Copied and set to have no gradient.
            prune_baseline: If True, remove points in `X_baseline` that are
                highly unlikely to be the best point. This can significantly
                improve performance and is generally recommended. In order to
                customize pruning parameters, instead manually call
                `botorch.acquisition.utils.prune_inferior_points` on `X_baseline`
                before instantiating the acquisition function.
            cache_root: A boolean indicating whether to cache the root
                decomposition over `X_baseline` and use low-rank updates.

        TODO: similar to qNEHVI, when we are using sequential greedy candidate
        selection, we could incorporate pending points X_baseline and compute
        the incremental qNEI from the new point. This would greatly increase
        efficiency for large batches.
        """
        super().__init__(
            model=model,
            sampler=sampler,
            objective=objective,
            posterior_transform=posterior_transform,
            X_pending=X_pending,
        )
        self._setup(model=model, sampler=self.sampler, cache_root=cache_root)
        self.base_sampler = deepcopy(self.sampler)
        if prune_baseline:
            X_baseline = prune_inferior_points(
                model=model,
                X=X_baseline,
                objective=objective,
                posterior_transform=posterior_transform,
                marginalize_dim=kwargs.get("marginalize_dim"),
            )
        self.register_buffer("X_baseline", X_baseline)

        if self._cache_root:
            self.q = -1
            # set baseline samples
            with torch.no_grad():
                posterior = self.model.posterior(X_baseline)
                baseline_samples = self.base_sampler(posterior)
            baseline_obj = self.objective(baseline_samples, X=X_baseline)
            self.register_buffer("baseline_samples", baseline_samples)
            self.register_buffer("baseline_obj_max_values",
                                 baseline_obj.max(dim=-1).values)
            self._cache_root_decomposition(posterior=posterior)
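
As the docstring suggests, pruning can also be customized by calling
`prune_inferior_points` directly before instantiation. A sketch, reusing the
hypothetical `model` and `train_X` from the note above (the `max_frac` value is
illustrative, not a default):

    from botorch.acquisition.monte_carlo import qNoisyExpectedImprovement
    from botorch.acquisition.utils import prune_inferior_points

    # keep at most a quarter of the baseline points, ranked by their
    # estimated probability of being the best
    X_pruned = prune_inferior_points(model=model, X=train_X, max_frac=0.25)
    acqf = qNoisyExpectedImprovement(
        model=model,
        X_baseline=X_pruned,
        prune_baseline=False,  # already pruned manually
        cache_root=True,       # cache the root decomposition over X_baseline
    )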
Example #3
 def test_prune_inferior_points(self):
     for dtype in (torch.float, torch.double):
         X = torch.rand(3, 2, device=self.device, dtype=dtype)
         # the event shape is `q x t` = 3 x 1
         samples = torch.tensor([[-1.0], [0.0], [1.0]],
                                device=self.device,
                                dtype=dtype)
         mm = MockModel(MockPosterior(samples=samples))
         # test that a batched X raises errors
         with self.assertRaises(UnsupportedError):
             prune_inferior_points(model=mm, X=X.expand(2, 3, 2))
         # test that a batched model raises errors (event shape is `q x t` = 3 x 1)
         mm2 = MockModel(MockPosterior(samples=samples.expand(2, 3, 1)))
         with self.assertRaises(UnsupportedError):
             prune_inferior_points(model=mm2, X=X)
         # test that invalid max_frac is checked properly
         with self.assertRaises(ValueError):
             prune_inferior_points(model=mm, X=X, max_frac=1.1)
         # test basic behaviour
         X_pruned = prune_inferior_points(model=mm, X=X)
         self.assertTrue(torch.equal(X_pruned, X[[-1]]))
         # test custom objective
         neg_id_obj = GenericMCObjective(lambda Y, X: -(Y.squeeze(-1)))
         X_pruned = prune_inferior_points(model=mm,
                                          X=X,
                                          objective=neg_id_obj)
         self.assertTrue(torch.equal(X_pruned, X[[0]]))
         # test non-repeated samples (requires mocking out MockPosterior's rsample)
         samples = torch.tensor(
             [[[3.0], [0.0], [0.0]], [[0.0], [2.0], [0.0]],
              [[0.0], [0.0], [1.0]]],
             device=self.device,
             dtype=dtype,
         )
         with mock.patch.object(MockPosterior,
                                "rsample",
                                return_value=samples):
             mm = MockModel(MockPosterior(samples=samples))
             X_pruned = prune_inferior_points(model=mm, X=X)
         self.assertTrue(torch.equal(X_pruned, X))
         # test max_frac limiting
         with mock.patch.object(MockPosterior,
                                "rsample",
                                return_value=samples):
             mm = MockModel(MockPosterior(samples=samples))
             X_pruned = prune_inferior_points(model=mm, X=X, max_frac=2 / 3)
         if self.device == torch.device("cuda"):
             # sorting has different order on cuda
             self.assertTrue(
                 torch.equal(X_pruned, torch.stack([X[2], X[1]], dim=0)))
         else:
             self.assertTrue(torch.equal(X_pruned, X[:2]))
         # test that zero-probability is in fact pruned
         samples[2, 0, 0] = 10
         with mock.patch.object(MockPosterior,
                                "rsample",
                                return_value=samples):
             mm = MockModel(MockPosterior(samples=samples))
             X_pruned = prune_inferior_points(model=mm, X=X)
         self.assertTrue(torch.equal(X_pruned, X[:2]))
         # test high-dim sampling
         with ExitStack() as es:
             mock_event_shape = es.enter_context(
                 mock.patch(
                     "botorch.utils.testing.MockPosterior.base_sample_shape",
                     new_callable=mock.PropertyMock,
                 ))
             mock_event_shape.return_value = torch.Size(
                 [1, 1, torch.quasirandom.SobolEngine.MAXDIM + 1])
             es.enter_context(
                 mock.patch.object(MockPosterior,
                                   "rsample",
                                   return_value=samples))
             mm = MockModel(MockPosterior(samples=samples))
             with warnings.catch_warnings(
                     record=True) as ws, settings.debug(True):
                 prune_inferior_points(model=mm, X=X)
                 self.assertTrue(
                     issubclass(ws[-1].category, SamplingWarning))
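
The behaviour these tests pin down can be summarized in a short standalone
sketch of the pruning criterion (an illustration of the idea, not the library
implementation): draw Monte Carlo samples of the objective at the `n` baseline
points, estimate each point's probability of being the samplewise maximum, drop
points that never win, and cap the survivors at `ceil(max_frac * n)`:

    import math
    import torch

    def prune_sketch(obj_samples: torch.Tensor, max_frac: float = 1.0) -> torch.Tensor:
        """obj_samples: a `num_samples x n` tensor of objective values at n points."""
        winners = obj_samples.argmax(dim=-1)  # index of the best point per MC sample
        counts = torch.bincount(winners, minlength=obj_samples.shape[-1])
        probs = counts.double() / obj_samples.shape[0]  # P(point is the best)
        order = probs.argsort(descending=True)
        keep = order[: math.ceil(max_frac * obj_samples.shape[-1])]
        return keep[probs[keep] > 0]  # indices of the surviving points

    # mirrors the "non-repeated samples" case above: each point wins exactly
    # once, so with the default max_frac all three survive (tie order may vary)
    samples = torch.tensor([[3.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 1.0]])
    print(prune_sketch(samples))  # e.g. tensor([0, 1, 2])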
Example #4
 def test_prune_inferior_points(self):
     for dtype in (torch.float, torch.double):
         X = torch.rand(3, 2, device=self.device, dtype=dtype)
         # the event shape is `q x t` = 3 x 1
         samples = torch.tensor(
             [[-1.0], [0.0], [1.0]], device=self.device, dtype=dtype
         )
         mm = MockModel(MockPosterior(samples=samples))
         # test that a batched X raises errors
         with self.assertRaises(UnsupportedError):
             prune_inferior_points(model=mm, X=X.expand(2, 3, 2))
         # test that a batched model raises errors (event shape is `q x t` = 3 x 1)
         mm2 = MockModel(MockPosterior(samples=samples.expand(2, 3, 1)))
         with self.assertRaises(UnsupportedError):
             prune_inferior_points(model=mm2, X=X)
         # test that invalid max_frac is checked properly
         with self.assertRaises(ValueError):
             prune_inferior_points(model=mm, X=X, max_frac=1.1)
         # test basic behaviour
         X_pruned = prune_inferior_points(model=mm, X=X)
         self.assertTrue(torch.equal(X_pruned, X[[-1]]))
         # test custom objective
          neg_id_obj = GenericMCObjective(lambda Y: -Y.squeeze(-1))
         X_pruned = prune_inferior_points(model=mm, X=X, objective=neg_id_obj)
         self.assertTrue(torch.equal(X_pruned, X[[0]]))
         # test non-repeated samples (requires mocking out MockPosterior's rsample)
         samples = torch.tensor(
             [[[3.0], [0.0], [0.0]], [[0.0], [2.0], [0.0]], [[0.0], [0.0], [1.0]]],
             device=self.device,
             dtype=dtype,
         )
         with mock.patch.object(MockPosterior, "rsample", return_value=samples):
             mm = MockModel(MockPosterior(samples=samples))
             X_pruned = prune_inferior_points(model=mm, X=X)
         self.assertTrue(torch.equal(X_pruned, X))
         # test max_frac limiting
         with mock.patch.object(MockPosterior, "rsample", return_value=samples):
             mm = MockModel(MockPosterior(samples=samples))
             X_pruned = prune_inferior_points(model=mm, X=X, max_frac=2 / 3)
         if self.device == torch.device("cuda"):
             # sorting has different order on cuda
             self.assertTrue(torch.equal(X_pruned, torch.stack([X[2], X[1]], dim=0)))
         else:
             self.assertTrue(torch.equal(X_pruned, X[:2]))
         # test that zero-probability is in fact pruned
         samples[2, 0, 0] = 10
         with mock.patch.object(MockPosterior, "rsample", return_value=samples):
             mm = MockModel(MockPosterior(samples=samples))
             X_pruned = prune_inferior_points(model=mm, X=X)
         self.assertTrue(torch.equal(X_pruned, X[:2]))