Example #1
# Imports required to make this helper self-contained (BoTorch / PyTorch).
from typing import Dict, Optional

from torch import Tensor

from botorch.acquisition.cost_aware import InverseCostWeightedUtility
from botorch.acquisition.max_value_entropy_search import (
    qMaxValueEntropy,
    qMultiFidelityMaxValueEntropy,
)
from botorch.acquisition.utils import (
    expand_trace_observations,
    project_to_target_fidelity,
)
from botorch.models.cost import AffineFidelityCostModel
from botorch.models.model import Model


def _instantiate_MES(
    model: Model,
    candidate_set: Tensor,
    num_fantasies: int = 16,
    num_mv_samples: int = 10,
    num_y_samples: int = 128,
    use_gumbel: bool = True,
    X_pending: Optional[Tensor] = None,
    maximize: bool = True,
    num_trace_observations: int = 0,
    target_fidelities: Optional[Dict[int, float]] = None,
    fidelity_weights: Optional[Dict[int, float]] = None,
    cost_intercept: float = 1.0,
) -> qMaxValueEntropy:
    if target_fidelities:
        if fidelity_weights is None:
            fidelity_weights = {f: 1.0 for f in target_fidelities}
        if set(target_fidelities) != set(fidelity_weights):
            raise RuntimeError(
                "Must provide the same indices for target_fidelities "
                f"({set(target_fidelities)}) and fidelity_weights "
                f"({set(fidelity_weights)})."
            )
        cost_model = AffineFidelityCostModel(
            fidelity_weights=fidelity_weights, fixed_cost=cost_intercept
        )
        cost_aware_utility = InverseCostWeightedUtility(cost_model=cost_model)

        def project(X: Tensor) -> Tensor:
            return project_to_target_fidelity(X=X, target_fidelities=target_fidelities)

        def expand(X: Tensor) -> Tensor:
            return expand_trace_observations(
                X=X,
                fidelity_dims=sorted(target_fidelities),  # pyre-ignore: [6]
                num_trace_obs=num_trace_observations,
            )

        return qMultiFidelityMaxValueEntropy(
            model=model,
            candidate_set=candidate_set,
            num_fantasies=num_fantasies,
            num_mv_samples=num_mv_samples,
            num_y_samples=num_y_samples,
            X_pending=X_pending,
            maximize=maximize,
            cost_aware_utility=cost_aware_utility,
            project=project,
            expand=expand,
        )

    return qMaxValueEntropy(
        model=model,
        candidate_set=candidate_set,
        num_fantasies=num_fantasies,
        num_mv_samples=num_mv_samples,
        num_y_samples=num_y_samples,
        X_pending=X_pending,
        maximize=maximize,
    )
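
A minimal usage sketch, not part of the original example: it assumes BoTorch is installed and uses untrained SingleTaskGP models with random data as illustrative stand-ins for a fitted surrogate. Without target_fidelities the helper returns a plain qMaxValueEntropy; with target_fidelities it takes the cost-aware qMultiFidelityMaxValueEntropy branch.

# Hypothetical usage sketch (assumed models and data, not from the original project).
import torch
from botorch.models import SingleTaskGP

# Single-fidelity case: an untrained 2-d SingleTaskGP stands in for a fitted surrogate.
train_X = torch.rand(20, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True).sin()
model = SingleTaskGP(train_X, train_Y)
candidate_set = torch.rand(1000, 2, dtype=torch.double)  # discrete set for max-value sampling

mes = _instantiate_MES(model=model, candidate_set=candidate_set)
print(mes(torch.rand(1, 2, dtype=torch.double)).shape)  # torch.Size([1])

# Multi-fidelity case (assumed setup): column 2 of X is treated as the fidelity
# parameter and projected to 1.0, so the helper returns qMultiFidelityMaxValueEntropy
# wrapped with an inverse-cost-weighted utility.
train_X_mf = torch.rand(20, 3, dtype=torch.double)
train_Y_mf = train_X_mf[:, :2].sum(dim=-1, keepdim=True).sin()
mf_model = SingleTaskGP(train_X_mf, train_Y_mf)
candidate_set_mf = torch.rand(1000, 3, dtype=torch.double)

mf_mes = _instantiate_MES(
    model=mf_model,
    candidate_set=candidate_set_mf,
    target_fidelities={2: 1.0},
    cost_intercept=5.0,  # fixed-cost term of the affine cost model
)
print(mf_mes(torch.rand(1, 3, dtype=torch.double)).shape)  # torch.Size([1])
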
Example #2

    def test_q_max_value_entropy(self):
        for dtype in (torch.float, torch.double):
            torch.manual_seed(7)
            mm = MESMockModel()
            with self.assertRaises(TypeError):
                qMaxValueEntropy(mm)

            train_inputs = torch.rand(10, 2, device=self.device, dtype=dtype)
            mm.train_inputs = (train_inputs, )
            candidate_set = torch.rand(1000,
                                       2,
                                       device=self.device,
                                       dtype=dtype)

            # test error when number of outputs > 1
            with self.assertRaises(NotImplementedError):
                mm._num_outputs = 2
                qMaxValueEntropy(mm, candidate_set, num_mv_samples=10)

            # test with X_pending is None
            mm._num_outputs = 1  # reset the mock back to a single output
            qMVE = qMaxValueEntropy(mm, candidate_set, num_mv_samples=10)

            # test initialization
            self.assertEqual(qMVE.num_fantasies, 16)
            self.assertEqual(qMVE.num_mv_samples, 10)
            self.assertIsInstance(qMVE.sampler, SobolQMCNormalSampler)
            self.assertEqual(qMVE.sampler.sample_shape, torch.Size([128]))
            self.assertIsInstance(qMVE.fantasies_sampler,
                                  SobolQMCNormalSampler)
            self.assertEqual(qMVE.fantasies_sampler.sample_shape,
                             torch.Size([16]))
            self.assertEqual(qMVE.use_gumbel, True)
            self.assertEqual(qMVE.posterior_max_values.shape,
                             torch.Size([10, 1]))

            # test evaluation
            X = torch.rand(1, 2, device=self.device, dtype=dtype)
            self.assertEqual(qMVE(X).shape, torch.Size([1]))

            # test with use_gumbel = False
            qMVE = qMaxValueEntropy(mm,
                                    candidate_set,
                                    num_mv_samples=10,
                                    use_gumbel=False)
            self.assertEqual(qMVE(X).shape, torch.Size([1]))

            # test with X_pending is not None
            with mock.patch.object(MESMockModel, "fantasize",
                                   return_value=mm) as patch_f:
                qMVE = qMaxValueEntropy(
                    mm,
                    candidate_set,
                    num_mv_samples=10,
                    X_pending=torch.rand(1, 2, device=self.device,
                                         dtype=dtype),
                )
                patch_f.assert_called_once()
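
For orientation, the construction exercised by this test can also be run outside the test harness. The sketch below is a hedged stand-alone version: the SingleTaskGP and its random training data replace MESMockModel and are assumptions, while the checked defaults (16 fantasies, 128 y-samples) are the ones asserted above.

# Hypothetical stand-alone sketch mirroring the test above (assumed model and data).
import torch
from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy
from botorch.models import SingleTaskGP

train_X = torch.rand(10, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
candidate_set = torch.rand(1000, 2, dtype=torch.double)

qMVE = qMaxValueEntropy(model, candidate_set, num_mv_samples=10)
print(qMVE.num_fantasies, qMVE.sampler.sample_shape)      # 16 torch.Size([128])
print(qMVE(torch.rand(1, 2, dtype=torch.double)).shape)   # torch.Size([1])
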
Example #3
    def test_q_max_value_entropy(self):
        for dtype in (torch.float, torch.double):
            torch.manual_seed(7)
            mm = MESMockModel()
            with self.assertRaises(TypeError):
                qMaxValueEntropy(mm)

            candidate_set = torch.rand(1000,
                                       2,
                                       device=self.device,
                                       dtype=dtype)

            # test error in case of batch GP model
            train_inputs = torch.rand(5,
                                      10,
                                      2,
                                      device=self.device,
                                      dtype=dtype)
            mm.train_inputs = (train_inputs, )
            with self.assertRaises(NotImplementedError):
                qMaxValueEntropy(mm, candidate_set, num_mv_samples=10)

            # test error when number of outputs > 1
            no = ("test.acquisition.test_max_value_entropy_search"
                  ".MESMockModel.num_outputs")
            with mock.patch(
                    no, new_callable=mock.PropertyMock) as mock_num_outputs:
                mock_num_outputs.return_value = 2
                with self.assertRaises(UnsupportedError):
                    qMaxValueEntropy(mm, candidate_set, num_mv_samples=10)

            # test with X_pending is None
            train_inputs = torch.rand(10, 2, device=self.device, dtype=dtype)
            mm.train_inputs = (train_inputs, )
            qMVE = qMaxValueEntropy(mm, candidate_set, num_mv_samples=10)

            # test initialization
            self.assertEqual(qMVE.num_fantasies, 16)
            self.assertEqual(qMVE.num_mv_samples, 10)
            self.assertIsInstance(qMVE.sampler, SobolQMCNormalSampler)
            self.assertEqual(qMVE.sampler.sample_shape, torch.Size([128]))
            self.assertIsInstance(qMVE.fantasies_sampler,
                                  SobolQMCNormalSampler)
            self.assertEqual(qMVE.fantasies_sampler.sample_shape,
                             torch.Size([16]))
            self.assertEqual(qMVE.use_gumbel, True)
            self.assertEqual(qMVE.posterior_max_values.shape,
                             torch.Size([10, 1]))

            # test evaluation
            X = torch.rand(1, 2, device=self.device, dtype=dtype)
            self.assertEqual(qMVE(X).shape, torch.Size([1]))

            # test set X pending to None in case of _init_model exists
            qMVE.set_X_pending(None)
            self.assertEqual(qMVE.model, qMVE._init_model)

            # test with use_gumbel = False
            qMVE = qMaxValueEntropy(mm,
                                    candidate_set,
                                    num_mv_samples=10,
                                    use_gumbel=False)
            self.assertEqual(qMVE(X).shape, torch.Size([1]))

            # test with X_pending is not None
            with mock.patch.object(MESMockModel, "fantasize",
                                   return_value=mm) as patch_f:
                qMVE = qMaxValueEntropy(
                    mm,
                    candidate_set,
                    num_mv_samples=10,
                    X_pending=torch.rand(1, 2, device=self.device,
                                         dtype=dtype),
                )
                patch_f.assert_called_once()
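
The X_pending behaviour mocked above, and the _init_model round trip asserted in this version of the test, can be sketched directly. The snippet below reuses the assumed model and candidate_set from the stand-alone sketch after Example #2: passing X_pending fantasizes the surrogate at the pending points, and set_X_pending(None) restores the original model.

# Hypothetical sketch of the X_pending / _init_model round trip (assumed data).
X_pending = torch.rand(2, 2, dtype=torch.double)
qMVE = qMaxValueEntropy(model, candidate_set, num_mv_samples=10, X_pending=X_pending)
# The acquisition function now holds a fantasy model conditioned on X_pending.
qMVE.set_X_pending(None)
assert qMVE.model is qMVE._init_model  # the original surrogate is restored
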
Example #4

    def test_q_max_value_entropy(self):
        for dtype in (torch.float, torch.double):
            torch.manual_seed(7)
            mm = MESMockModel()
            with self.assertRaises(TypeError):
                qMaxValueEntropy(mm)

            candidate_set = torch.rand(1000,
                                       2,
                                       device=self.device,
                                       dtype=dtype)

            # test error in case of batch GP model
            mm = MESMockModel(batch_shape=torch.Size([2]))
            with self.assertRaises(NotImplementedError):
                qMaxValueEntropy(mm, candidate_set, num_mv_samples=10)
            mm = MESMockModel()
            train_inputs = torch.rand(5,
                                      10,
                                      2,
                                      device=self.device,
                                      dtype=dtype)
            with self.assertRaises(NotImplementedError):
                qMaxValueEntropy(mm,
                                 candidate_set,
                                 num_mv_samples=10,
                                 train_inputs=train_inputs)

            # test that init works if batch_shape is not implemented on the model
            mm = NoBatchShapeMESMockModel()
            qMaxValueEntropy(
                mm,
                candidate_set,
                num_mv_samples=10,
            )

            # test error when number of outputs > 1 and no transform is given.
            mm = MESMockModel()
            mm._num_outputs = 2
            with self.assertRaises(UnsupportedError):
                qMaxValueEntropy(mm, candidate_set, num_mv_samples=10)

            # test with X_pending is None
            mm = MESMockModel()
            train_inputs = torch.rand(10, 2, device=self.device, dtype=dtype)
            mm.train_inputs = (train_inputs, )
            qMVE = qMaxValueEntropy(mm, candidate_set, num_mv_samples=10)

            # test initialization
            self.assertEqual(qMVE.num_fantasies, 16)
            self.assertEqual(qMVE.num_mv_samples, 10)
            self.assertIsInstance(qMVE.sampler, SobolQMCNormalSampler)
            self.assertEqual(qMVE.sampler.sample_shape, torch.Size([128]))
            self.assertIsInstance(qMVE.fantasies_sampler,
                                  SobolQMCNormalSampler)
            self.assertEqual(qMVE.fantasies_sampler.sample_shape,
                             torch.Size([16]))
            self.assertEqual(qMVE.use_gumbel, True)
            self.assertEqual(qMVE.posterior_max_values.shape,
                             torch.Size([10, 1]))

            # test evaluation
            X = torch.rand(1, 2, device=self.device, dtype=dtype)
            self.assertEqual(qMVE(X).shape, torch.Size([1]))

            # test set X pending to None in case of _init_model exists
            qMVE.set_X_pending(None)
            self.assertEqual(qMVE.model, qMVE._init_model)

            # test with use_gumbel = False
            qMVE = qMaxValueEntropy(mm,
                                    candidate_set,
                                    num_mv_samples=10,
                                    use_gumbel=False)
            self.assertEqual(qMVE(X).shape, torch.Size([1]))

            # test with X_pending is not None
            with mock.patch.object(MESMockModel, "fantasize",
                                   return_value=mm) as patch_f:
                qMVE = qMaxValueEntropy(
                    mm,
                    candidate_set,
                    num_mv_samples=10,
                    X_pending=torch.rand(1, 2, device=self.device,
                                         dtype=dtype),
                )
                patch_f.assert_called_once()

            # Test with multi-output model w/ transform.
            mm = MESMockModel(num_outputs=2)
            pt = ScalarizedPosteriorTransform(
                weights=torch.ones(2, device=self.device, dtype=dtype))
            for gumbel in (True, False):
                qMVE = qMaxValueEntropy(
                    mm,
                    candidate_set,
                    num_mv_samples=10,
                    use_gumbel=gumbel,
                    posterior_transform=pt,
                )
                self.assertEqual(qMVE(X).shape, torch.Size([1]))
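
For the multi-output case, the sketch below shows what the ScalarizedPosteriorTransform used above does on its own: it collapses a two-output posterior into a single scalarized output, which is why supplying it as posterior_transform avoids the UnsupportedError. The two-output SingleTaskGP and its data are illustrative assumptions.

# Hypothetical sketch of the scalarizing transform used above (assumed model and data).
import torch
from botorch.acquisition.objective import ScalarizedPosteriorTransform
from botorch.models import SingleTaskGP

train_X = torch.rand(20, 2, dtype=torch.double)
train_Y = torch.stack([train_X.sum(-1), train_X.prod(-1)], dim=-1)  # two outputs
mo_model = SingleTaskGP(train_X, train_Y)

pt = ScalarizedPosteriorTransform(weights=torch.ones(2, dtype=torch.double))
test_X = torch.rand(4, 2, dtype=torch.double)
print(mo_model.posterior(test_X).mean.shape)                          # torch.Size([4, 2])
print(mo_model.posterior(test_X, posterior_transform=pt).mean.shape)  # torch.Size([4, 1])
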