Пример #1
0
 def test_init(self):
     """Verify MaxPosteriorSampling stores its constructor arguments and defaults."""
     model = MockModel(MockPosterior(mean=None))

     # Defaults: sampling with replacement and an identity objective.
     default_sampler = MaxPosteriorSampling(model)
     self.assertEqual(default_sampler.model, model)
     self.assertTrue(default_sampler.replacement)
     self.assertIsInstance(default_sampler.objective, IdentityMCObjective)

     # An explicit objective and replacement=False must be kept as given.
     objective = LinearMCObjective(torch.rand(2))
     custom_sampler = MaxPosteriorSampling(
         model, objective=objective, replacement=False
     )
     self.assertEqual(custom_sampler.objective, objective)
     self.assertFalse(custom_sampler.replacement)
Пример #2
0
def get_candidate(
    model,
    acq,
    full_train_Y,
    q,
    bounds,
    dim,
    *,
    num_restarts=15,
    raw_samples=5000,
    n_candidates=5000,
):
    """Propose ``q`` new points with EI optimization or Thompson sampling.

    Args:
        model: Fitted model passed to the acquisition function / sampler.
        acq: ``'EI'`` to maximize (q-)Expected Improvement with
            ``optimize_acqf``, or ``'TS'`` to Thompson-sample from a
            scrambled Sobol candidate set via ``MaxPosteriorSampling``.
        full_train_Y: Observed objective values. Its maximum is the EI
            incumbent ``best_f``. Its floating dtype and device are used for
            the bounds / candidate tensors so they match the training data.
        q: Number of points to propose.
        bounds: ``(lower, upper)`` scalars applied to every input dimension.
        dim: Input dimensionality.
        num_restarts: ``num_restarts`` for ``optimize_acqf`` (EI only).
        raw_samples: ``raw_samples`` for ``optimize_acqf`` (EI only).
        n_candidates: Number of Sobol candidates drawn (TS only). The former
            expression ``min(5000, max(20000, 2000 * dim))`` always evaluated
            to 5000 (the inner ``max`` is >= 20000), so 5000 is the default.

    Returns:
        ``(candidate, acq_function)``: the proposed points, and the EI
        acquisition object for ``'EI'`` or ``None`` for ``'TS'``.

    Raises:
        NotImplementedError: If ``acq`` is neither ``'EI'`` nor ``'TS'``.
    """
    if acq not in ('EI', 'TS'):
        raise NotImplementedError('Only TS and EI are implemented')

    # Match the training data's floating dtype/device. A hard-coded
    # float32/CPU tensor mismatches double-precision or GPU models.
    reference = torch.as_tensor(full_train_Y)
    tkwargs = {
        "device": reference.device,
        "dtype": reference.dtype if reference.is_floating_point() else torch.float,
    }
    lower, upper = bounds[0], bounds[1]

    if acq == 'EI':
        best_f = full_train_Y.max().item()
        if q == 1:
            acq_function = ExpectedImprovement(model, best_f)
        else:
            acq_function = qExpectedImprovement(model, best_f)

        bounds_t = torch.tensor([[lower] * dim, [upper] * dim], **tkwargs)
        candidate, _ = optimize_acqf(
            acq_function,
            bounds=bounds_t,
            q=q,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
        )
        return candidate, acq_function

    # Thompson sampling: score a quasi-random candidate set scaled from the
    # unit cube into [lower, upper]^dim and keep q maximizers, without
    # replacement.
    sobol = SobolEngine(dim, scramble=True)
    pert = sobol.draw(n_candidates).to(**tkwargs)
    X_cand = (upper - lower) * pert + lower
    thompson_sampling = MaxPosteriorSampling(model=model, replacement=False)
    candidate = thompson_sampling(X_cand, num_samples=q)
    return candidate, None
Пример #3
0
    def test_max_posterior_sampling(self) -> None:
        """Check which points MaxPosteriorSampling selects, with and without replacement.

        For every combination of batch shape, dtype, N, num_samples and d,
        ``MockPosterior.rsample`` is patched to return a fixed sample tensor,
        and ``MockModel.posterior`` is patched to return a ``MockPosterior``.
        The selected points are therefore known in advance.
        """
        batch_shapes = (torch.Size(), torch.Size([3]), torch.Size([3, 2]))
        dtypes = (torch.float, torch.double)
        for batch_shape, dtype, N, num_samples, d in itertools.product(
            batch_shapes, dtypes, (5, 6), (1, 2), (1, 2)
        ):
            tkwargs = {"device": self.device, "dtype": dtype}
            # X is `batch_shape x N x d`, with d in {1, 2}.
            X = torch.randn(*batch_shape, N, d, **tkwargs)
            # Posterior samples are `num_samples x batch_shape x N x 1` (one
            # output). Point 0 is 1.0 in every sample; all other points are 0.
            psamples = torch.zeros(num_samples, *batch_shape, N, 1, **tkwargs)
            psamples[..., 0, :] = 1.0

            # Default objective, with replacement: every draw should pick point 0.
            with mock.patch.object(MockPosterior, "rsample", return_value=psamples):
                mp = MockPosterior(None)
                with mock.patch.object(MockModel, "posterior", return_value=mp):
                    mm = MockModel(None)
                    MPS = MaxPosteriorSampling(mm)
                    s = MPS(X, num_samples=num_samples)
                    self.assertTrue(torch.equal(s, X[..., [0] * num_samples, :]))

            # ScalarizedObjective, with replacement. `forward` is patched to
            # return the same mock posterior `mp`, so the scalarization does not
            # change the samples. (Presumably the sampler applies this objective
            # to the posterior; verify against the library version in use.)
            with mock.patch.object(MockPosterior, "rsample", return_value=psamples):
                mp = MockPosterior(None)
                with mock.patch.object(MockModel, "posterior", return_value=mp):
                    mm = MockModel(None)
                    with mock.patch.object(
                        ScalarizedObjective, "forward", return_value=mp
                    ):
                        obj = ScalarizedObjective(torch.rand(2, **tkwargs))
                        MPS = MaxPosteriorSampling(mm, objective=obj)
                        s = MPS(X, num_samples=num_samples)
                        self.assertTrue(torch.equal(s, X[..., [0] * num_samples, :]))

            # Without replacement. Point 1 gets a small positive value (presumably
            # so it ranks second), so the expected picks are the first
            # `num_samples` points. NOTE: this mutates `psamples` in place, so
            # the final section below sees these values too. Batch shapes with
            # more than one dimension are expected to raise NotImplementedError.
            psamples[..., 1, 0] = 1e-6
            with mock.patch.object(MockPosterior, "rsample", return_value=psamples):
                mp = MockPosterior(None)
                with mock.patch.object(MockModel, "posterior", return_value=mp):
                    mm = MockModel(None)
                    MPS = MaxPosteriorSampling(mm, replacement=False)
                    if len(batch_shape) > 1:
                        with self.assertRaises(NotImplementedError):
                            MPS(X, num_samples=num_samples)
                    else:
                        s = MPS(X, num_samples=num_samples)
                        # order is not guaranteed, need to sort
                        self.assertTrue(
                            torch.equal(
                                torch.sort(s, dim=-2).values,
                                torch.sort(X[..., :num_samples, :], dim=-2).values,
                            )
                        )

            # ScalarizedObjective (forward patched as above), without replacement.
            with mock.patch.object(MockPosterior, "rsample", return_value=psamples):
                mp = MockPosterior(None)
                with mock.patch.object(MockModel, "posterior", return_value=mp):
                    mm = MockModel(None)
                    with mock.patch.object(
                        ScalarizedObjective, "forward", return_value=mp
                    ):
                        obj = ScalarizedObjective(torch.rand(2, **tkwargs))
                        MPS = MaxPosteriorSampling(mm, objective=obj, replacement=False)
                        if len(batch_shape) > 1:
                            with self.assertRaises(NotImplementedError):
                                MPS(X, num_samples=num_samples)
                        else:
                            s = MPS(X, num_samples=num_samples)
                            # order is not guaranteed, need to sort
                            self.assertTrue(
                                torch.equal(
                                    torch.sort(s, dim=-2).values,
                                    torch.sort(X[..., :num_samples, :], dim=-2).values,
                                )
                            )