示例#1
0
 def test_get_EHVI_input_validation_errors(self):
     """get_EHVI must raise ValueError when called without X_observed.

     The call below supplies only a model, objective weights and objective
     thresholds, and is expected to fail with the "no feasible observed
     points" message.
     """
     mock_model = MockModel(MockPosterior())
     obj_weights = torch.ones(2)
     obj_thresholds = torch.zeros(2)
     expected_message = "There are no feasible observed points."
     with self.assertRaisesRegex(ValueError, expected_message):
         get_EHVI(
             model=mock_model,
             objective_weights=obj_weights,
             objective_thresholds=obj_thresholds,
         )
示例#2
0
 def test_get_ehvi(self):
     """get_EHVI should forward the expected kwargs to the patched factory.

     The callables at GET_ACQF_PATH, GET_CONSTRAINT_PATH and GET_OBJ_PATH
     are patched, get_EHVI is invoked once, and the mock installed at
     GET_ACQF_PATH is checked for a single call with the expected arguments.
     """
     # Arrange: inputs handed to get_EHVI (random draws kept in this order).
     obj_weights = torch.tensor([0.0, 1.0, 1.0])
     observed_x = torch.rand(4, 3)
     pending_x = torch.rand(1, 3)
     outcome_cons = (torch.tensor([1.0, 0.0, 0.0]), torch.tensor([[10.0]]))
     posterior_mean = torch.rand(4, 3)
     mock_model = MockModel(MockPosterior(mean=posterior_mean))
     thresholds = torch.arange(3, dtype=torch.float)

     # Values the patched helpers will hand back to get_EHVI.
     objective_pair = get_weighted_mc_objective_and_objective_thresholds(
         objective_weights=obj_weights,
         objective_thresholds=thresholds,
     )
     expected_objective, expected_thresholds = objective_pair
     constraint_tfs = get_outcome_constraint_transforms(outcome_cons)

     # The expected seed is the first randint drawn under manual_seed(0).
     with manual_seed(0):
         expected_seed = torch.randint(1, 10000, (1, )).item()

     expected_kwargs = {
         "acquisition_function_name": "qEHVI",
         "model": mock_model,
         "objective": expected_objective,
         "X_observed": observed_x,
         "X_pending": pending_x,
         "constraints": constraint_tfs,
         "mc_samples": 128,
         "qmc": True,
         "alpha": 0.0,
         "seed": expected_seed,
         "ref_point": expected_thresholds.tolist(),
         "Y": posterior_mean,
     }

     # Act / assert inside the patches, entered in the original order.
     with mock.patch(GET_ACQF_PATH) as acqf_factory, \
             mock.patch(GET_CONSTRAINT_PATH, return_value=constraint_tfs), \
             mock.patch(GET_OBJ_PATH, return_value=objective_pair), \
             manual_seed(0):
         get_EHVI(
             model=mock_model,
             objective_weights=obj_weights,
             outcome_constraints=outcome_cons,
             objective_thresholds=thresholds,
             X_observed=observed_x,
             X_pending=pending_x,
         )
         acqf_factory.assert_called_once_with(**expected_kwargs)
示例#3
0
 def test_get_EHVI_input_validation_errors(self):
     """get_EHVI must raise ValueError for the two incomplete calls below.

     First call: no X_observed, expecting the "no feasible observed points"
     message. Second call: X_observed supplied, expecting the "requires Ys
     argument" message.
     """
     botorch_model = MultiObjectiveBotorchModel()
     observed_x = torch.zeros(2, 2)
     # Keyword arguments shared by both failing calls.
     shared_kwargs = {
         "model": botorch_model,
         "objective_weights": torch.ones(2),
         "ref_point": torch.zeros(2),
     }

     no_points_msg = "There are no feasible observed points."
     with self.assertRaisesRegex(ValueError, no_points_msg):
         get_EHVI(**shared_kwargs)

     missing_ys_msg = "Expected Hypervolume Improvement requires Ys argument"
     with self.assertRaisesRegex(ValueError, missing_ys_msg):
         get_EHVI(X_observed=observed_x, **shared_kwargs)