def test_clone_reset(self):
    """Check that `clone_reset` yields a strategy whose current step is the first one."""
    strategy = GenerationStrategy(
        steps=[
            GenerationStep(model=Models.FACTORIAL, num_arms=1),
            GenerationStep(model=Models.THOMPSON, num_arms=2),
        ]
    )
    # Manually move the strategy onto its second step before cloning.
    strategy._curr = strategy._steps[1]
    self.assertEqual(strategy._curr.index, 1)
    reset_copy = strategy.clone_reset()
    self.assertEqual(reset_copy._curr.index, 0)
def ax_client_with_explicit_strategy(num_random, num_computed):
    """Build an `AxClient` whose generation strategy is assembled from step counts.

    A Sobol step with `num_arms=num_random` is included only when
    `num_random > 0`; a GPEI step with `num_arms=-1` is included only when
    `num_computed > 0`. The client is created with
    `enforce_sequential_optimization=False`.
    """
    # (gate value, model, num_arms) for each candidate step, in strategy order.
    candidate_steps = (
        (num_random, Models.SOBOL, num_random),
        (num_computed, Models.GPEI, -1),
    )
    steps = [
        GenerationStep(model=model, num_arms=arm_count)
        for gate, model, arm_count in candidate_steps
        if gate > 0
    ]
    strategy = GenerationStrategy(steps)
    return AxClient(
        enforce_sequential_optimization=False,
        generation_strategy=strategy,
    )
def test_string_representation(self):
    """Check the `str()` rendering of a two-step Sobol -> GPEI strategy."""
    steps = [
        GenerationStep(model=Models.SOBOL, num_arms=5),
        GenerationStep(model=Models.GPEI, num_arms=-1),
    ]
    expected = (
        "GenerationStrategy(name='Sobol+GPEI', steps=[Sobol for 5 arms,"
        " GPEI for subsequent arms], generated 0 arm(s))"
    )
    self.assertEqual(str(GenerationStrategy(steps=steps)), expected)
def test_sobol_GPEI_strategy(self):
    """Run a 5-trial Sobol + 2-trial GPEI strategy to completion.

    For the first five generator runs the recorded model key, model kwargs,
    bridge kwargs and post-gen model state are checked; for the last two the
    mocked torch model bridge must have been called. One more `gen` call
    must raise `GenerationStrategyCompleted`.
    """
    experiment = get_branin_experiment()
    strategy = GenerationStrategy(
        name="Sobol+GPEI",
        steps=[
            GenerationStep(
                model=Models.SOBOL,
                num_trials=5,
                model_kwargs=self.step_model_kwargs,
            ),
            GenerationStep(
                model=Models.GPEI,
                num_trials=2,
                model_kwargs=self.step_model_kwargs,
            ),
        ],
    )
    self.assertEqual(strategy.name, "Sobol+GPEI")
    self.assertEqual(strategy.model_transitions, [5])

    # Bridge kwargs expected on every Sobol generator run.
    expected_bridge_kwargs = {
        "optimization_config": None,
        "status_quo_features": None,
        "status_quo_name": None,
        "transform_configs": None,
        "transforms": Cont_X_trans,
        "fit_out_of_design": False,
        "fit_abandoned": False,
    }
    for i in range(7):
        run = strategy.gen(experiment)
        experiment.new_trial(generator_run=run).run()
        self.assertEqual(len(strategy._generator_runs), i + 1)
        if i > 4:
            # Runs past the fifth belong to the GPEI step.
            self.mock_torch_model_bridge.assert_called()
            continue
        self.assertEqual(run._model_key, "Sobol")
        expected_model_kwargs = {
            "seed": None,
            "deduplicate": False,
            "init_position": i,
            "scramble": True,
            "generated_points": None,
        }
        self.assertEqual(run._model_kwargs, expected_model_kwargs)
        self.assertEqual(run._bridge_kwargs, expected_bridge_kwargs)
        self.assertEqual(run._model_state_after_gen, {"init_position": i + 1})

    # Check completeness error message when GS should be done.
    with self.assertRaises(GenerationStrategyCompleted):
        strategy.gen(experiment)
def test_restore_from_generator_run(self):
    """Check `_restore_model_from_generator_run` before and after a `gen` call."""
    strategy = GenerationStrategy(
        steps=[
            GenerationStep(model=Models.SOBOL, num_arms=5),
            GenerationStep(model=Models.GPEI, num_arms=-1),
        ]
    )
    # Before any `gen` call, restoring is expected to raise.
    with self.assertRaises(ValueError):
        strategy._restore_model_from_generator_run()

    strategy.gen(experiment=get_branin_experiment())
    model_before_restore = strategy.model
    strategy._restore_model_from_generator_run()
    # Restoring should leave the strategy with a different model object.
    self.assertIsNot(model_before_restore, strategy.model)
def test_sobol_GPEI_strategy(self, mock_GPEI_gen, mock_GPEI_update, mock_GPEI_init):
    """Walk a 5-arm Sobol + 2-arm GPEI strategy through all of its arms.

    Each `gen` call is fed the data of the previous trial. The eighth call
    must fail with the strategy-completed message, and re-passing all of
    the experiment's data must fail with the "seen data" message.
    """
    experiment = get_branin_experiment()
    strategy = GenerationStrategy(
        name="Sobol+GPEI",
        steps=[
            GenerationStep(model=Models.SOBOL, num_arms=5),
            GenerationStep(model=Models.GPEI, num_arms=2),
        ],
    )
    self.assertEqual(strategy.name, "Sobol+GPEI")
    self.assertEqual(strategy.model_transitions, [5])
    experiment.new_trial(generator_run=strategy.gen(experiment)).run()

    # Bridge kwargs expected on every Sobol generator run.
    expected_bridge_kwargs = {
        "optimization_config": None,
        "status_quo_features": None,
        "status_quo_name": None,
        "transform_configs": None,
        "transforms": Cont_X_trans,
        "fit_out_of_design": False,
    }
    for i in range(1, 8):
        if i == 7:
            # Check completeness error message.
            with self.assertRaisesRegex(ValueError, "Generation strategy"):
                strategy.gen(
                    experiment,
                    experiment._fetch_trial_data(trial_index=i - 1),
                )
        else:
            run = strategy.gen(
                experiment,
                experiment._fetch_trial_data(trial_index=i - 1),
            )
            experiment.new_trial(generator_run=run).run()
        if i > 4:
            self.assertIsInstance(strategy.model, TorchModelBridge)
        else:
            self.assertEqual(run._model_key, "Sobol")
            expected_model_kwargs = {
                "seed": None,
                "deduplicate": False,
                "init_position": i + 1,
                "scramble": True,
            }
            self.assertEqual(run._model_kwargs, expected_model_kwargs)
            self.assertEqual(run._bridge_kwargs, expected_bridge_kwargs)

    # Check for "seen data" error message.
    with self.assertRaisesRegex(ValueError, "Data for arm"):
        strategy.gen(experiment, experiment.fetch_data())
def test_min_observed(self):
    """Transitioning to the next model must fail without enough observed data."""
    experiment = get_branin_experiment()
    sobol_step = GenerationStep(
        model=Models.SOBOL, num_arms=5, min_arms_observed=5
    )
    gpei_step = GenerationStep(model=Models.GPEI, num_arms=1)
    strategy = GenerationStrategy(steps=[sobol_step, gpei_step])

    # Use up the Sobol step; nothing is attached as observed data here.
    for _ in range(5):
        strategy.gen(experiment)
    # The sixth call would need the GPEI step and is expected to raise.
    with self.assertRaises(ValueError):
        strategy.gen(experiment)
def test_equality(self):
    """Identically configured strategies compare equal; a reset clone does not."""

    def build_strategy():
        # Fresh Sobol -> GPEI strategy with the same configuration each time.
        return GenerationStrategy(
            steps=[
                GenerationStep(model=Models.SOBOL, num_arms=5),
                GenerationStep(model=Models.GPEI, num_arms=-1),
            ]
        )

    first = build_strategy()
    second = build_strategy()
    self.assertEqual(first, second)

    # Clone_reset() doesn't clone exactly, so they won't be equal.
    reset_clone = first.clone_reset()
    self.assertNotEqual(first, reset_clone)
def test_current_generator_run_limit(self):
    """Check `current_generator_run_limit` across rounds of generate-and-complete.

    Each round asks the strategy how many trials it may generate, generates
    exactly that many as running trials, then attaches data and completes
    them before the next round.
    """
    num_init_trials = 5
    second_step_parallelism = 3
    num_rounds = 4
    experiment = get_branin_experiment()
    strategy = GenerationStrategy(
        steps=[
            GenerationStep(
                model=Models.SOBOL,
                num_trials=num_init_trials,
                min_trials_observed=3,
            ),
            GenerationStep(
                model=Models.SOBOL,
                num_trials=-1,
                max_parallelism=second_step_parallelism,
            ),
        ]
    )
    strategy._experiment = experiment

    limits_per_round = []
    for _ in range(num_rounds):
        limit, optimization_complete = strategy.current_generator_run_limit()
        self.assertFalse(optimization_complete)
        limits_per_round.append(limit)

        running_trials = []
        for _ in range(limit):
            generator_run = strategy.gen(
                experiment=experiment,
                pending_observations=get_pending(experiment=experiment),
            )
            new_trial = experiment.new_trial(generator_run)
            running_trials.append(
                new_trial.mark_running(no_runner_required=True)
            )

        for running_trial in running_trials:
            experiment.attach_data(
                get_branin_data(trial_indices=[running_trial.index])
            )
            running_trial.mark_completed()

    # We expect trials from first generation step + trials from remaining rounds in
    # batches limited by parallelism setting in the second step.
    later_rounds = num_rounds - 1
    self.assertEqual(
        len(experiment.trials),
        num_init_trials + later_rounds * second_step_parallelism,
    )
    self.assertTrue(
        all(t.status.is_completed for t in experiment.trials.values())
    )
    self.assertEqual(
        limits_per_round,
        [num_init_trials] + [second_step_parallelism] * later_rounds,
    )
def test_min_observed(self):
    """Transitioning to the next model must fail without enough observed data.

    Five trials are created from the Sobol step (which requires five
    observed trials) without any data being attached; the next `gen` call
    is expected to raise `DataRequiredError`.
    """
    # Build the experiment with a plain call, like the other tests here;
    # the original passed a second, throwaway experiment as an argument.
    exp = get_branin_experiment()
    gs = GenerationStrategy(
        steps=[
            GenerationStep(
                model=Models.SOBOL, num_trials=5, min_trials_observed=5
            ),
            GenerationStep(model=Models.GPEI, num_trials=1),
        ]
    )
    self.assertFalse(gs.uses_non_registered_models)
    for _ in range(5):
        exp.new_trial(gs.gen(exp))
    with self.assertRaises(DataRequiredError):
        gs.gen(exp)
def choose_generation_strategy(
    search_space: SearchSpace,
    arms_per_trial: int = 1,
    enforce_sequential_optimization: bool = True,
    random_seed: Optional[int] = None,
) -> GenerationStrategy:
    """Select an appropriate generation strategy based on the properties of
    the search space.

    Args:
        search_space: Search space whose parameters are inspected. Range
            parameters are counted one each; choice parameters contribute
            their number of values.
        arms_per_trial: The number of Sobol arms is rounded up to a
            multiple of this value.
        enforce_sequential_optimization: Passed as `enforce_num_arms` on
            the Sobol step of the "Sobol+GPEI" strategy.
        random_seed: When not None, passed to the Sobol model as its
            `seed` keyword argument.

    Returns:
        A "Sobol+GPEI" strategy when there are at least as many range
        parameters as discrete choices, otherwise a Sobol-only strategy.
    """
    model_kwargs = {"seed": random_seed} if (random_seed is not None) else None
    num_continuous_parameters, num_discrete_choices = 0, 0
    for parameter in search_space.parameters.values():
        if isinstance(parameter, ChoiceParameter):
            num_discrete_choices += len(parameter.values)
        if isinstance(parameter, RangeParameter):
            num_continuous_parameters += 1

    # If there are more discrete choices than continuous parameters, Sobol
    # will do better than GP+EI.
    if num_continuous_parameters < num_discrete_choices:
        logger.info("Using Sobol generation strategy.")
        return GenerationStrategy(
            name="Sobol",
            steps=[
                GenerationStep(
                    model=Models.SOBOL, num_arms=-1, model_kwargs=model_kwargs
                )
            ],
        )

    # Ensure that number of arms per model is divisible by batch size.
    sobol_arms = (
        ceil(max(5, len(search_space.parameters)) / arms_per_trial)
        * arms_per_trial
    )
    logger.info(
        "Using Bayesian Optimization generation strategy. Iterations after "
        f"{sobol_arms} will take longer to generate due to model-fitting."
    )
    return GenerationStrategy(
        name="Sobol+GPEI",
        steps=[
            GenerationStep(
                model=Models.SOBOL,
                num_arms=sobol_arms,
                min_arms_observed=ceil(sobol_arms / 2),
                enforce_num_arms=enforce_sequential_optimization,
                model_kwargs=model_kwargs,
            ),
            GenerationStep(
                model=Models.GPEI, num_arms=-1, recommended_max_parallelism=3
            ),
        ],
    )
def test_string_representation(self):
    """Check `str()` for a two-step strategy and for a single open-ended step."""
    two_step = GenerationStrategy(
        steps=[
            GenerationStep(model=Models.SOBOL, num_trials=5),
            GenerationStep(model=Models.GPEI, num_trials=-1),
        ]
    )
    expected_two_step = (
        "GenerationStrategy(name='Sobol+GPEI', steps=[Sobol for 5 trials,"
        " GPEI for subsequent trials])"
    )
    self.assertEqual(str(two_step), expected_two_step)

    sobol_only = GenerationStrategy(
        steps=[GenerationStep(model=Models.SOBOL, num_trials=-1)]
    )
    expected_sobol_only = (
        "GenerationStrategy(name='Sobol', steps=[Sobol for all trials])"
    )
    self.assertEqual(str(sobol_only), expected_sobol_only)
def generation_step_from_json(generation_step_json: Dict[str, Any]) -> GenerationStep:
    """Load generation step from JSON.

    Keys are first passed through the backwards-compatibility converter and
    then popped off the resulting dict one by one. "model" and "num_trials"
    are required; the remaining keys fall back to defaults. Model kwargs
    are decoded only when present and non-empty, otherwise `None` is used.
    """
    payload = _convert_generation_step_keys_for_backwards_compatibility(
        generation_step_json
    )
    raw_model_kwargs = payload.pop("model_kwargs", None)
    raw_gen_kwargs = payload.pop("model_gen_kwargs", None)

    # Constructor arguments are gathered in the original evaluation order.
    step_args = {}
    step_args["model"] = object_from_json(payload.pop("model"))
    step_args["num_trials"] = payload.pop("num_trials")
    step_args["min_trials_observed"] = payload.pop("min_trials_observed", 0)
    step_args["max_parallelism"] = payload.pop("max_parallelism", None)
    step_args["use_update"] = payload.pop("use_update", False)
    step_args["enforce_num_trials"] = payload.pop("enforce_num_trials", True)

    if raw_model_kwargs:
        step_args["model_kwargs"] = _decode_callables_from_references(
            object_from_json(raw_model_kwargs)
        )
    else:
        step_args["model_kwargs"] = None

    if raw_gen_kwargs:
        step_args["model_gen_kwargs"] = _decode_callables_from_references(
            object_from_json(raw_gen_kwargs)
        )
    else:
        step_args["model_gen_kwargs"] = None

    step_args["index"] = payload.pop("index", -1)
    step_args["should_deduplicate"] = payload.pop("should_deduplicate", False)
    return GenerationStep(**step_args)
def run_branin_and_gramacy_100_benchmarks(rep):
    """Benchmark Sobol, ALEBO, REMBO and HeSBO on `branin_100` and `gramacy_100`.

    Runs a single replication of 50 trials and writes the serialized result
    to a JSON file under `results/` whose name includes `rep`.

    Args:
        rep: Replication index; used as the Sobol seed (`rep + 1`) and in
            the output file name.
    """
    strategy0 = GenerationStrategy(
        name="Sobol",
        steps=[
            GenerationStep(
                model=Models.SOBOL, num_arms=-1, model_kwargs={'seed': rep + 1}
            )
        ],
    )
    strategy1 = ALEBOStrategy(D=100, d=4, init_size=10)
    strategy2 = REMBOStrategy(D=100, d=2, init_per_proj=2)
    # Plain string: the name has no placeholders, so no f-prefix is needed.
    strategy3 = HeSBOStrategy(D=100, d=4, init_per_proj=10, name="HeSBO, d=2d")

    all_benchmarks = full_benchmark_run(
        num_replications=1,
        num_trials=50,
        batch_size=1,
        methods=[strategy0, strategy1, strategy2, strategy3],
        problems=[branin_100, gramacy_100],
    )
    with open(
        f'results/branin_gramacy_100_alebo_rembo_hesbo_sobol_rep_{rep}.json', "w"
    ) as fout:
        json.dump(object_to_json(all_benchmarks), fout)
def test_annotate_exception(self, _):
    """`run_trial` must raise a RuntimeError carrying the Cholesky hint text."""
    sobol_strategy = GenerationStrategy(
        name="Sobol",
        steps=[GenerationStep(model=Models.SOBOL, num_trials=-1)],
    )
    x1_config = {
        "name": "x1",
        "type": "range",
        "bounds": [-5.0, 10.0],
        "value_type": "float",
        "log_scale": False,
    }
    x2_config = {"name": "x2", "type": "range", "bounds": [0.0, 10.0]}
    loop = OptimizationLoop.with_evaluation_function(
        parameters=[x1_config, x2_config],
        experiment_name="test",
        objective_name="branin",
        minimize=True,
        evaluation_function=_branin_evaluation_function,
        total_trials=6,
        generation_strategy=sobol_strategy,
    )
    with self.assertRaisesRegex(RuntimeError, "Cholesky errors typically occur"):
        loop.run_trial()
def test_custom_gs(self) -> None:
    """Run the managed loop end to end with a user-supplied Sobol strategy."""
    sobol_strategy = GenerationStrategy(
        name="Sobol",
        steps=[GenerationStep(model=Models.SOBOL, num_trials=-1)],
    )
    x1_config = {
        "name": "x1",
        "type": "range",
        "bounds": [-5.0, 10.0],
        "value_type": "float",
        "log_scale": False,
    }
    x2_config = {"name": "x2", "type": "range", "bounds": [0.0, 10.0]}
    loop = OptimizationLoop.with_evaluation_function(
        parameters=[x1_config, x2_config],
        experiment_name="test",
        objective_name="branin",
        minimize=True,
        evaluation_function=_branin_evaluation_function,
        total_trials=6,
        generation_strategy=sobol_strategy,
    )
    best_parameters, _ = loop.full_run().get_best_point()
    # Both parameters must appear in the best point.
    for parameter_name in ("x1", "x2"):
        self.assertIn(parameter_name, best_parameters)
def test_basic(self):
    """Exercise the full benchmarking loop on three Branin problems with Sobol."""
    problems = [
        SimpleBenchmarkProblem(branin, noise_sd=0.4),
        BenchmarkProblem(
            name="Branin",
            search_space=get_branin_search_space(),
            optimization_config=get_branin_optimization_config(),
        ),
        BenchmarkProblem(
            search_space=get_branin_search_space(),
            optimization_config=get_optimization_config(),
        ),
    ]
    methods = [
        GenerationStrategy(
            steps=[GenerationStep(model=Models.SOBOL, num_trials=-1)]
        )
    ]
    results = full_benchmark_run(
        problem_groups={self.CATEGORY_NAME: problems},
        method_groups={self.CATEGORY_NAME: methods},
        num_replications=3,
        num_trials=5,
        # Just to have it be more telling if something is broken
        raise_all_exceptions=True,
        batch_size=[[1], [3], [1]],
    )
    sobol_on_branin = results["Branin"]["Sobol"]
    self.assertEqual(len(sobol_on_branin), 3)
def test_sobol_GPEI_strategy_keep_generating(self):
    """A strategy ending in an unlimited GPEI step keeps generating past Sobol."""
    experiment = get_branin_experiment()
    strategy = GenerationStrategy(
        steps=[
            GenerationStep(model=Models.SOBOL, num_arms=5),
            GenerationStep(model=Models.GPEI, num_arms=-1),
        ]
    )
    self.assertEqual(strategy.name, "Sobol+GPEI")
    self.assertEqual(strategy.model_transitions, [5])

    first_run = strategy.gen(experiment)
    experiment.new_trial(generator_run=first_run).run()
    for i in range(1, 15):
        next_run = strategy.gen(experiment, experiment.fetch_data())
        experiment.new_trial(generator_run=next_run).run()
        if i <= 4:
            continue
        # After the five Sobol arms the model must be a torch model bridge.
        self.assertIsInstance(strategy.model, TorchModelBridge)
def test_store_experiment(self):
    """`gen` should record the experiment on the strategy's `_experiment`."""
    experiment = get_branin_experiment()
    sobol_step = GenerationStep(model=Models.SOBOL, num_arms=5)
    strategy = GenerationStrategy(steps=[sobol_step])
    # Nothing is stored until the first `gen` call.
    self.assertIsNone(strategy._experiment)
    strategy.gen(experiment)
    self.assertIsNotNone(strategy._experiment)
def test_sobol(self):
    """Run the benchmarking suite with a Sobol strategy, then add two more runs."""
    suite = BOBenchmarkingSuite()
    sobol_strategy = GenerationStrategy(
        [GenerationStep(model=Models.SOBOL, num_arms=10)]
    )
    runner = suite.run(
        num_runs=1,
        total_iterations=5,
        batch_size=2,
        bo_strategies=[sobol_strategy],
        bo_problems=[branin],
    )
    # If run_benchmarking_trial fails, corresponding trial in '_runs' is None.
    self.assertTrue(all(run is not None for run in runner._runs.values()))
    # Make sure no errors came up in running trials.
    self.assertEqual(len(runner.errors), 0)
    report = suite.generate_report()
    self.assertIsInstance(report, str)

    # Add a trial twice; each addition should register under the next index.
    setup = BenchmarkSetup(problem=branin, total_iterations=10, batch_size=1)
    for run_index in (0, 1):
        suite.add_run(setup=setup, strategy_name="strategy_name")
        self.assertTrue(
            ("Branin", "strategy_name", run_index) in suite._runner._runs
        )
def run_hartmann6_benchmarks(D, rep, random_subspace=False):
    """Benchmark Sobol, ALEBO, REMBO and HeSBO on a Hartmann6 problem.

    Runs a single replication of 200 trials and writes the serialized
    result to a JSON file under `results/` whose name includes `D`, `rep`
    and, when `random_subspace` is truthy, a `random_subspace_` marker.

    Args:
        D: Problem selector, also passed as `D` to the embedding
            strategies. Must be 100 or 1000.
        rep: Replication index; used as the Sobol seed (`rep + 1`) and in
            the output file name.
        random_subspace: With `D == 1000`, selects the random-subspace
            problem instead of the plain one. It does not affect problem
            choice when `D == 100`.

    Raises:
        ValueError: If `D` is neither 100 nor 1000.
    """
    if D == 100:
        problem = hartmann6_100
    elif D == 1000:
        problem = (
            hartmann6_random_subspace_1000 if random_subspace else hartmann6_1000
        )
    else:
        # Previously `problem` stayed unbound here and the failure only
        # surfaced later as an UnboundLocalError.
        raise ValueError(f"Unsupported D={D!r}; expected 100 or 1000.")

    strategy0 = GenerationStrategy(
        name="Sobol",
        steps=[
            GenerationStep(
                model=Models.SOBOL, num_arms=-1, model_kwargs={'seed': rep + 1}
            )
        ],
    )
    strategy1 = ALEBOStrategy(D=D, d=12, init_size=10)
    strategy2 = REMBOStrategy(D=D, d=6, init_per_proj=2)
    # Plain strings: these names have no placeholders, so no f-prefix is needed.
    strategy3 = HeSBOStrategy(D=D, d=6, init_per_proj=10, name="HeSBO, d=d")
    strategy4 = HeSBOStrategy(D=D, d=12, init_per_proj=10, name="HeSBO, d=2d")

    all_benchmarks = full_benchmark_run(
        num_replications=1,  # Running them 1 at a time for distributed
        num_trials=200,
        batch_size=1,
        methods=[strategy0, strategy1, strategy2, strategy3, strategy4],
        problems=[problem],
    )
    rs_str = 'random_subspace_' if random_subspace else ''
    with open(
        f'results/hartmann6_{rs_str}{D}_alebo_rembo_hesbo_sobol_rep_{rep}.json', "w"
    ) as fout:
        json.dump(object_to_json(all_benchmarks), fout)
def test_raise_all_exceptions(self):
    """An exception raised inside the replication callable must propagate
    out of `full_benchmark_run` when `raise_all_exceptions` is True.
    """

    def broken_benchmark_replication(*args, **kwargs) -> Experiment:
        # Stand-in replication that always fails.
        raise ValueError("Oh, exception!")

    problems = {self.CATEGORY_NAME: [SimpleBenchmarkProblem(branin, noise_sd=0.4)]}
    methods = {
        self.CATEGORY_NAME: [
            GenerationStrategy(
                steps=[GenerationStep(model=Models.SOBOL, num_trials=-1)]
            )
        ]
    }
    with self.assertRaisesRegex(ValueError, "Oh, exception!"):
        full_benchmark_run(
            problem_groups=problems,
            method_groups=methods,
            num_replications=3,
            num_trials=5,
            raise_all_exceptions=True,
            benchmark_replication=broken_benchmark_replication,
        )
def test_optimize_graceful_exit_on_exception(self) -> None:
    """Single-call optimization where the strategy offers fewer trials (3)
    than requested (6); the usual return values must still be present.
    """

    def booth(p):
        # Booth function.
        first_term = (p["x1"] + 2 * p["x2"] - 7) ** 2
        second_term = (2 * p["x1"] + p["x2"] - 5) ** 2
        return (first_term + second_term, None)

    best, vals, exp, model = optimize(
        parameters=[
            # pyre-fixme[6]
            {"name": "x1", "type": "range", "bounds": [-10.0, 10.0]},
            {"name": "x2", "type": "range", "bounds": [-10.0, 10.0]},
        ],
        evaluation_function=booth,
        minimize=True,
        total_trials=6,
        generation_strategy=GenerationStrategy(
            name="Sobol",
            steps=[GenerationStep(model=Models.SOBOL, num_trials=3)],
        ),
    )
    self.assertEqual(len(exp.trials), 3)  # Check that we stopped at 3 trials.
    # All the regular return values should still be present.
    for parameter_name in ("x1", "x2"):
        self.assertIn(parameter_name, best)
    self.assertIsNotNone(vals)
    first_values, second_values = vals[0], vals[1]
    self.assertIn("objective", first_values)
    self.assertIn("objective", second_values)
    self.assertIn("objective", second_values["objective"])
def test_use_update(self, mock_fetch_trials_data, mock_update):
    """Check data fetching and `update` calls for a step with `use_update=True`."""
    experiment = get_branin_experiment()
    strategy = GenerationStrategy(
        steps=[GenerationStep(model=Models.SOBOL, num_trials=-1, use_update=True)]
    )

    # Try without passing data (generation strategy fetches data from experiment).
    first_run = strategy.gen(experiment=experiment)
    trial = experiment.new_trial(generator_run=first_run)
    mock_update.assert_not_called()
    trial._status = TrialStatus.COMPLETED

    for i in range(3):
        next_run = strategy.gen(experiment=experiment)
        trial = experiment.new_trial(generator_run=next_run)
        # Data should have been requested for the most recently completed trial.
        fetch_kwargs = mock_fetch_trials_data.call_args[1]
        self.assertEqual(fetch_kwargs.get("trial_indices"), {i})
        trial._status = TrialStatus.COMPLETED

    # Try with passing data.
    strategy.gen(
        experiment=experiment,
        data=get_branin_data(trial_indices=range(4)),
    )
    # Only the data for the last completed trial should be considered new and passed
    # to `update`.
    new_data = mock_update.call_args[1].get("new_data")
    self.assertEqual(set(new_data.df["trial_index"].values), {3})
def test_validation(self):
    """Check strategy validation of `num_trials` and a failing `gen` call."""
    # First pair: num_trials can be positive or -1.
    # Second pair: only last num_trials can be -1.
    invalid_trial_counts = ((5, -10), (-1, 10))
    for sobol_trials, gpei_trials in invalid_trial_counts:
        # Steps are built inside the context so that any error raised while
        # constructing them is caught as well.
        with self.assertRaises(ValueError):
            GenerationStrategy(
                steps=[
                    GenerationStep(model=Models.SOBOL, num_trials=sobol_trials),
                    GenerationStep(model=Models.GPEI, num_trials=gpei_trials),
                ]
            )

    choice_only_space = SearchSpace(parameters=[get_choice_parameter()])
    experiment = Experiment(name="test", search_space=choice_only_space)
    strategy = GenerationStrategy(
        steps=[
            GenerationStep(model=Models.FACTORIAL, num_trials=1),
            GenerationStep(model=Models.THOMPSON, num_trials=2),
        ]
    )
    self.assertTrue(strategy._uses_registered_models)
    self.assertFalse(strategy.uses_non_registered_models)
    with self.assertRaises(ValueError):
        strategy.gen(experiment)

    # A plain callable used as a model reports its own name.
    callable_step = GenerationStep(model=sum, num_trials=1)
    self.assertEqual(callable_step.model_name, "sum")
def test_kwargs_passed(self):
    """`model_kwargs` given on a step should reach the underlying model."""
    sobol_step = GenerationStep(
        model=Models.SOBOL,
        num_arms=1,
        model_kwargs={"scramble": False},
    )
    strategy = GenerationStrategy(steps=[sobol_step])
    experiment = get_branin_experiment()
    strategy.gen(experiment, experiment.fetch_data())
    fitted_model = strategy._model.model
    self.assertFalse(fitted_model.scramble)
def test_do_not_enforce_min_observations(self):
    """With `enforce_num_arms=False`, generation continues past the step's
    `num_arms` without enough observed data, still using the previous model.
    """
    experiment = get_branin_experiment()
    lenient_sobol_step = GenerationStep(
        model=Models.SOBOL,
        num_arms=1,
        min_arms_observed=5,
        enforce_num_arms=False,
    )
    gpei_step = GenerationStep(model=Models.GPEI, num_arms=1)
    strategy = GenerationStrategy(steps=[lenient_sobol_step, gpei_step])

    # Two calls: one more than the Sobol step's `num_arms`.
    for _ in range(2):
        strategy.gen(experiment)
    # The model in use after the second call must still be a random model bridge.
    self.assertIsInstance(strategy._model, RandomModelBridge)
def test_custom_callables_for_models(self):
    """A step may take a factory callable (`get_sobol`) instead of a `Models` member."""
    experiment = get_branin_experiment()
    factory_step = GenerationStep(model=get_sobol, num_trials=-1)
    strategy = GenerationStrategy(steps=[factory_step])
    # Both flags should report that a non-registered model is in use.
    self.assertFalse(strategy._uses_registered_models)
    self.assertTrue(strategy.uses_non_registered_models)
    generator_run = strategy.gen(experiment=experiment, n=1)
    self.assertEqual(len(generator_run.arms), 1)
def test_max_parallelism_reached(self):
    """With `max_parallelism=1` and one running trial, `gen` must raise."""
    experiment = get_branin_experiment()
    limited_step = GenerationStep(
        model=Models.SOBOL, num_trials=5, max_parallelism=1
    )
    strategy = GenerationStrategy(steps=[limited_step])

    # Put one trial into the running state to hit the parallelism limit.
    first_run = strategy.gen(experiment=experiment)
    running_trial = experiment.new_trial(generator_run=first_run)
    running_trial.mark_running(no_runner_required=True)

    with self.assertRaises(MaxParallelismReachedException):
        strategy.gen(experiment=experiment)
def test_use_update(self, mock_lookup_data, mock_update):
    """Check data lookup, seen-trial tracking and `update` calls for a step
    with `use_update=True`.
    """
    experiment = get_branin_experiment()
    strategy = GenerationStrategy(
        steps=[GenerationStep(model=Models.SOBOL, num_trials=-1, use_update=True)]
    )
    strategy._experiment = experiment
    self.assertEqual(strategy._find_trials_completed_since_last_gen(), set())

    # `BraninMetric` is available while running by default, which should
    # raise an error when use with `use_update=True` on a generation step, as we
    # have not yet properly addressed that edge case (for lack of use case).
    with self.assertRaises(NotImplementedError):
        strategy.gen(experiment=experiment)

    stubs_module = get_branin_experiment.__module__
    availability_target = (
        f"{stubs_module}.BraninMetric.is_available_while_running"
    )
    # Every remaining `gen` call runs with the metric patched to report
    # that it is not available while running.
    with patch(availability_target, return_value=False):
        # Try without passing data (GS looks up data on experiment).
        first_run = strategy.gen(experiment=experiment)
        trial = experiment.new_trial(generator_run=first_run)
        mock_update.assert_not_called()
        trial._status = TrialStatus.COMPLETED

        for i in range(3):
            next_run = strategy.gen(experiment=experiment)
            lookup_kwargs = mock_lookup_data.call_args[1]
            self.assertEqual(lookup_kwargs.get("trial_indices"), {i})
            trial = experiment.new_trial(generator_run=next_run)
            trial._status = TrialStatus.COMPLETED

        # `_seen_trial_indices_by_status` is set during `gen`, to the experiment's
        # `trial_indices_by_Status` at the time of candidate generation.
        self.assertNotEqual(
            strategy._seen_trial_indices_by_status,
            experiment.trial_indices_by_status,
        )

        # Try with passing data.
        strategy.gen(
            experiment=experiment,
            data=get_branin_data(trial_indices=range(4)),
        )
        # Now `_seen_trial_indices_by_status` should be set to experiment's.
        self.assertEqual(
            strategy._seen_trial_indices_by_status,
            experiment.trial_indices_by_status,
        )
        # Only the data for the last completed trial should be considered new and
        # passed to `update`.
        new_data = mock_update.call_args[1].get("new_data")
        self.assertEqual(set(new_data.df["trial_index"].values), {3})

        # Try with passing same data as before; no update should be performed.
        with patch.object(strategy, "_update_current_model") as mock_model_update:
            strategy.gen(
                experiment=experiment,
                data=get_branin_data(trial_indices=range(4)),
            )
            mock_model_update.assert_not_called()