def _instantiate_MES(
    model: Model,
    candidate_set: Tensor,
    num_fantasies: int = 16,
    num_mv_samples: int = 10,
    num_y_samples: int = 128,
    use_gumbel: bool = True,
    X_pending: Optional[Tensor] = None,
    maximize: bool = True,
    num_trace_observations: int = 0,
    target_fidelities: Optional[Dict[int, float]] = None,
    fidelity_weights: Optional[Dict[int, float]] = None,
    cost_intercept: float = 1.0,
) -> qMaxValueEntropy:
    """Build a single- or multi-fidelity max-value entropy acquisition function.

    When ``target_fidelities`` is truthy, a ``qMultiFidelityMaxValueEntropy`` is
    returned. It uses an ``AffineFidelityCostModel`` wrapped in an
    ``InverseCostWeightedUtility``, plus ``project``/``expand`` closures built
    from ``target_fidelities`` and ``num_trace_observations``. Otherwise a plain
    ``qMaxValueEntropy`` is returned.

    Args:
        model: The model the acquisition function is built on.
        candidate_set: Candidate points forwarded as ``candidate_set``.
        num_fantasies: Forwarded as ``num_fantasies``.
        num_mv_samples: Forwarded as ``num_mv_samples``.
        num_y_samples: Forwarded as ``num_y_samples``.
        use_gumbel: Forwarded as ``use_gumbel`` to whichever acquisition
            function is constructed.
        X_pending: Forwarded as ``X_pending``.
        maximize: Forwarded as ``maximize``.
        num_trace_observations: Passed as ``num_trace_obs`` to
            ``expand_trace_observations`` (multi-fidelity path only).
        target_fidelities: Mapping keyed by fidelity dimension index; its
            sorted keys are passed as ``fidelity_dims``. ``None`` or an empty
            dict selects the single-fidelity path.
        fidelity_weights: ``fidelity_weights`` for the affine cost model.
            Defaults to a weight of 1.0 for every key of ``target_fidelities``.
        cost_intercept: ``fixed_cost`` of the affine cost model.

    Returns:
        A ``qMultiFidelityMaxValueEntropy`` if ``target_fidelities`` is truthy,
        otherwise a ``qMaxValueEntropy``.

    Raises:
        RuntimeError: If ``target_fidelities`` and ``fidelity_weights`` do not
            have the same set of keys.
    """
    if target_fidelities:
        if fidelity_weights is None:
            fidelity_weights = {f: 1.0 for f in target_fidelities}
        if set(target_fidelities) != set(fidelity_weights):
            raise RuntimeError(
                "Must provide the same indices for target_fidelities "
                f"({set(target_fidelities)}) and fidelity_weights "
                f" ({set(fidelity_weights)})."
            )

        cost_model = AffineFidelityCostModel(
            fidelity_weights=fidelity_weights, fixed_cost=cost_intercept
        )
        cost_aware_utility = InverseCostWeightedUtility(cost_model=cost_model)

        # Closures binding target_fidelities / num_trace_observations so the
        # acquisition function can call them with only ``X``.
        def project(X: Tensor) -> Tensor:
            return project_to_target_fidelity(X=X, target_fidelities=target_fidelities)

        def expand(X: Tensor) -> Tensor:
            return expand_trace_observations(
                X=X,
                fidelity_dims=sorted(target_fidelities),  # pyre-ignore: [6]
                num_trace_obs=num_trace_observations,
            )

        # use_gumbel must be forwarded explicitly; otherwise the caller's
        # choice would be silently replaced by the constructor default.
        return qMultiFidelityMaxValueEntropy(
            model=model,
            candidate_set=candidate_set,
            num_fantasies=num_fantasies,
            num_mv_samples=num_mv_samples,
            num_y_samples=num_y_samples,
            use_gumbel=use_gumbel,
            X_pending=X_pending,
            maximize=maximize,
            cost_aware_utility=cost_aware_utility,
            project=project,
            expand=expand,
        )

    return qMaxValueEntropy(
        model=model,
        candidate_set=candidate_set,
        num_fantasies=num_fantasies,
        num_mv_samples=num_mv_samples,
        num_y_samples=num_y_samples,
        use_gumbel=use_gumbel,
        X_pending=X_pending,
        maximize=maximize,
    )
def test_q_max_value_entropy(self):
    """Exercise construction, validation and evaluation of qMaxValueEntropy."""
    for dtype in (torch.float, torch.double):
        tkwargs = {"device": self.device, "dtype": dtype}
        torch.manual_seed(7)
        mm = MESMockModel()

        # A candidate set is a required argument.
        with self.assertRaises(TypeError):
            qMaxValueEntropy(mm)

        mm.train_inputs = (torch.rand(10, 2, **tkwargs),)
        candidate_set = torch.rand(1000, 2, **tkwargs)

        # More than one output is rejected; restore a single output afterwards.
        mm._num_outputs = 2
        with self.assertRaises(NotImplementedError):
            qMaxValueEntropy(mm, candidate_set, num_mv_samples=10)
        mm._num_outputs = 1

        # Construction without pending points, then check the defaults.
        acqf = qMaxValueEntropy(mm, candidate_set, num_mv_samples=10)
        self.assertEqual(acqf.num_fantasies, 16)
        self.assertEqual(acqf.num_mv_samples, 10)
        for sampler, n_samples in (
            (acqf.sampler, 128),
            (acqf.fantasies_sampler, 16),
        ):
            self.assertIsInstance(sampler, SobolQMCNormalSampler)
            self.assertEqual(sampler.sample_shape, torch.Size([n_samples]))
        self.assertEqual(acqf.use_gumbel, True)
        self.assertEqual(acqf.posterior_max_values.shape, torch.Size([10, 1]))

        # Evaluation on a single q=1 point yields a single value.
        X = torch.rand(1, 2, **tkwargs)
        self.assertEqual(acqf(X).shape, torch.Size([1]))

        # Same check without the Gumbel approximation.
        acqf = qMaxValueEntropy(
            mm, candidate_set, num_mv_samples=10, use_gumbel=False
        )
        self.assertEqual(acqf(X).shape, torch.Size([1]))

        # Supplying pending points must call ``fantasize`` exactly once.
        with mock.patch.object(
            MESMockModel, "fantasize", return_value=mm
        ) as patch_f:
            acqf = qMaxValueEntropy(
                mm,
                candidate_set,
                num_mv_samples=10,
                X_pending=torch.rand(1, 2, **tkwargs),
            )
            patch_f.assert_called_once()
def test_q_max_value_entropy(self):
    """Exercise validation, construction and evaluation of qMaxValueEntropy."""
    for dtype in (torch.float, torch.double):
        torch.manual_seed(7)
        mm = MESMockModel()
        with self.assertRaises(TypeError):
            qMaxValueEntropy(mm)

        candidate_set = torch.rand(1000, 2, device=self.device, dtype=dtype)

        # test error in case of batch GP model
        train_inputs = torch.rand(5, 10, 2, device=self.device, dtype=dtype)
        mm.train_inputs = (train_inputs,)
        with self.assertRaises(NotImplementedError):
            qMaxValueEntropy(mm, candidate_set, num_mv_samples=10)

        # test error when number of outputs > 1.
        # Patch the property on the class object directly instead of through
        # a hard-coded dotted module path, so the test does not break when
        # this file is moved or renamed.
        with mock.patch.object(
            MESMockModel, "num_outputs", new_callable=mock.PropertyMock
        ) as mock_num_outputs:
            mock_num_outputs.return_value = 2
            with self.assertRaises(UnsupportedError):
                qMaxValueEntropy(mm, candidate_set, num_mv_samples=10)

        # test with X_pending is None
        train_inputs = torch.rand(10, 2, device=self.device, dtype=dtype)
        mm.train_inputs = (train_inputs,)
        qMVE = qMaxValueEntropy(mm, candidate_set, num_mv_samples=10)

        # test initialization
        self.assertEqual(qMVE.num_fantasies, 16)
        self.assertEqual(qMVE.num_mv_samples, 10)
        self.assertIsInstance(qMVE.sampler, SobolQMCNormalSampler)
        self.assertEqual(qMVE.sampler.sample_shape, torch.Size([128]))
        self.assertIsInstance(qMVE.fantasies_sampler, SobolQMCNormalSampler)
        self.assertEqual(qMVE.fantasies_sampler.sample_shape, torch.Size([16]))
        self.assertEqual(qMVE.use_gumbel, True)
        self.assertEqual(qMVE.posterior_max_values.shape, torch.Size([10, 1]))

        # test evaluation
        X = torch.rand(1, 2, device=self.device, dtype=dtype)
        self.assertEqual(qMVE(X).shape, torch.Size([1]))

        # test that setting X_pending to None falls back to _init_model
        qMVE.set_X_pending(None)
        self.assertEqual(qMVE.model, qMVE._init_model)

        # test with use_gumbel = False
        qMVE = qMaxValueEntropy(
            mm, candidate_set, num_mv_samples=10, use_gumbel=False
        )
        self.assertEqual(qMVE(X).shape, torch.Size([1]))

        # test with X_pending is not None
        with mock.patch.object(
            MESMockModel, "fantasize", return_value=mm
        ) as patch_f:
            qMVE = qMaxValueEntropy(
                mm,
                candidate_set,
                num_mv_samples=10,
                X_pending=torch.rand(1, 2, device=self.device, dtype=dtype),
            )
            patch_f.assert_called_once()
def test_q_max_value_entropy(self):
    """Exercise validation, construction and evaluation of qMaxValueEntropy."""
    for dtype in (torch.float, torch.double):
        tkwargs = {"device": self.device, "dtype": dtype}
        torch.manual_seed(7)

        # A candidate set is a required argument.
        mm = MESMockModel()
        with self.assertRaises(TypeError):
            qMaxValueEntropy(mm)

        candidate_set = torch.rand(1000, 2, **tkwargs)

        # Batched models are rejected, whether the batch comes from the
        # model's batch_shape or from explicitly supplied train_inputs.
        batched_model = MESMockModel(batch_shape=torch.Size([2]))
        with self.assertRaises(NotImplementedError):
            qMaxValueEntropy(batched_model, candidate_set, num_mv_samples=10)
        mm = MESMockModel()
        batched_train_inputs = torch.rand(5, 10, 2, **tkwargs)
        with self.assertRaises(NotImplementedError):
            qMaxValueEntropy(
                mm,
                candidate_set,
                num_mv_samples=10,
                train_inputs=batched_train_inputs,
            )

        # Construction must still work when the model has no batch_shape.
        qMaxValueEntropy(
            NoBatchShapeMESMockModel(),
            candidate_set,
            num_mv_samples=10,
        )

        # Two outputs without a posterior transform are unsupported.
        multi_output_model = MESMockModel()
        multi_output_model._num_outputs = 2
        with self.assertRaises(UnsupportedError):
            qMaxValueEntropy(multi_output_model, candidate_set, num_mv_samples=10)

        # Single-output model without pending points; check the defaults.
        mm = MESMockModel()
        mm.train_inputs = (torch.rand(10, 2, **tkwargs),)
        acqf = qMaxValueEntropy(mm, candidate_set, num_mv_samples=10)
        self.assertEqual(acqf.num_fantasies, 16)
        self.assertEqual(acqf.num_mv_samples, 10)
        for sampler, n_samples in (
            (acqf.sampler, 128),
            (acqf.fantasies_sampler, 16),
        ):
            self.assertIsInstance(sampler, SobolQMCNormalSampler)
            self.assertEqual(sampler.sample_shape, torch.Size([n_samples]))
        self.assertEqual(acqf.use_gumbel, True)
        self.assertEqual(acqf.posterior_max_values.shape, torch.Size([10, 1]))

        # Evaluation on a single q=1 point yields a single value.
        X = torch.rand(1, 2, **tkwargs)
        self.assertEqual(acqf(X).shape, torch.Size([1]))

        # After set_X_pending(None) the model should equal _init_model.
        acqf.set_X_pending(None)
        self.assertEqual(acqf.model, acqf._init_model)

        # Same evaluation check without the Gumbel approximation.
        acqf = qMaxValueEntropy(
            mm, candidate_set, num_mv_samples=10, use_gumbel=False
        )
        self.assertEqual(acqf(X).shape, torch.Size([1]))

        # Supplying pending points must call ``fantasize`` exactly once.
        with mock.patch.object(
            MESMockModel, "fantasize", return_value=mm
        ) as patch_f:
            acqf = qMaxValueEntropy(
                mm,
                candidate_set,
                num_mv_samples=10,
                X_pending=torch.rand(1, 2, **tkwargs),
            )
            patch_f.assert_called_once()

        # A two-output model becomes usable with a scalarizing transform.
        mm = MESMockModel(num_outputs=2)
        pt = ScalarizedPosteriorTransform(weights=torch.ones(2, **tkwargs))
        for gumbel in (True, False):
            acqf = qMaxValueEntropy(
                mm,
                candidate_set,
                num_mv_samples=10,
                use_gumbel=gumbel,
                posterior_transform=pt,
            )
            self.assertEqual(acqf(X).shape, torch.Size([1]))