コード例 #1
0
    def test_validation(self):
        """Validate `num_arms` constraints on generation-strategy steps."""
        # A negative value other than -1 is rejected outright.
        with self.assertRaises(ValueError):
            GenerationStrategy(
                steps=[
                    GenerationStep(model=Models.SOBOL, num_arms=5),
                    GenerationStep(model=Models.GPEI, num_arms=-10),
                ]
            )

        # -1 (unlimited arms) is permitted only on the final step.
        with self.assertRaises(ValueError):
            GenerationStrategy(
                steps=[
                    GenerationStep(model=Models.SOBOL, num_arms=-1),
                    GenerationStep(model=Models.GPEI, num_arms=10),
                ]
            )

        # Generating from a factorial+Thompson strategy on this
        # choice-parameter search space should raise.
        choice_space = SearchSpace(parameters=[get_choice_parameter()])
        exp = Experiment(name="test", search_space=choice_space)
        strategy = GenerationStrategy(
            steps=[
                GenerationStep(model=Models.FACTORIAL, num_arms=1),
                GenerationStep(model=Models.THOMPSON, num_arms=2),
            ]
        )
        with self.assertRaises(ValueError):
            strategy.gen(exp)
コード例 #2
0
ファイル: benchmark.py プロジェクト: cristicmf/Ax
def _benchmark_replication_Dev_API(
    problem: BenchmarkProblem,
    method: GenerationStrategy,
    num_trials: int,
    experiment_name: str,
    batch_size: int = 1,
    raise_all_exceptions: bool = False,
    benchmark_trial: FunctionType = benchmark_trial,
    verbose_logging: bool = True,
    # Number of trials that need to fail for a replication to be considered failed.
    failed_trials_tolerated: int = 5,
    async_benchmark_options: Optional[AsyncBenchmarkOptions] = None,
) -> Tuple[Experiment, List[Exception]]:
    """Run one benchmark replication through the Developer API.

    Used when the problem is defined with raw Ax classes (e.g. to allow
    constraints or non-range parameters). Returns the populated experiment
    together with the exceptions swallowed along the way.
    """
    # Async execution is only available through other replication paths.
    if async_benchmark_options is not None:
        raise NonRetryableBenchmarkingError(
            "`async_benchmark_options` not supported when using the Dev API.")

    experiment = Experiment(
        name=experiment_name,
        search_space=problem.search_space,
        optimization_config=problem.optimization_config,
        runner=SyntheticRunner(),
    )
    exceptions: List[Exception] = []
    for idx in range(num_trials):
        try:
            generator_run = method.gen(experiment=experiment, n=batch_size)
            # Single-arm trials and batch trials use different APIs.
            if batch_size == 1:
                new_trial = experiment.new_trial(generator_run=generator_run)
            else:
                assert batch_size > 1
                new_trial = experiment.new_batch_trial(
                    generator_run=generator_run)
            new_trial.run()
            benchmark_trial(experiment=experiment, trial_index=idx)
            new_trial.mark_completed()
        except Exception as exc:  # TODO[T53975770]: test
            # Either surface the failure immediately or record it and move on.
            if raise_all_exceptions:
                raise
            exceptions.append(exc)
        if len(exceptions) > failed_trials_tolerated:
            raise RuntimeError(  # TODO[T53975770]: test
                f"More than {failed_trials_tolerated} failed for {experiment_name}."
            )
    return experiment, exceptions
コード例 #3
0
ファイル: test_generation_strategy.py プロジェクト: ekilic/Ax
 def test_trials_as_df(self):
     """`trials_as_df` is None before any trials and tracks status changes."""
     experiment = get_branin_experiment()
     strategy = GenerationStrategy(
         steps=[GenerationStep(model=Models.SOBOL, num_trials=5)])
     # Nothing generated yet, so there is nothing to tabulate.
     self.assertIsNone(strategy.trials_as_df)
     # After creating a trial, a CANDIDATE row should appear.
     new_trial = experiment.new_trial(strategy.gen(experiment=experiment))
     self.assertFalse(strategy.trials_as_df.empty)
     self.assertEqual(
         strategy.trials_as_df.head()["Trial Status"][0], "CANDIDATE")
     # Status changes should be picked up on the next property access.
     new_trial._status = TrialStatus.RUNNING
     self.assertEqual(
         strategy.trials_as_df.head()["Trial Status"][0], "RUNNING")
コード例 #4
0
    def test_with_factory_function(self):
        """Checks that generation strategy works with custom factory functions.
        No information about the model should be saved on generator run."""

        def get_sobol(search_space: SearchSpace) -> RandomModelBridge:
            # Minimal Sobol bridge built directly, bypassing the registry.
            return RandomModelBridge(
                search_space=search_space,
                model=SobolGenerator(),
                transforms=Cont_X_trans,
            )

        experiment = get_branin_experiment()
        strategy = GenerationStrategy(
            steps=[GenerationStep(model=get_sobol, num_arms=5)])
        generator_run = strategy.gen(experiment)
        self.assertIsInstance(strategy.model, RandomModelBridge)
        # Factory functions are not registered models, so no model metadata
        # should be persisted on the generator run.
        for attr in ("_model_key", "_model_kwargs", "_bridge_kwargs"):
            self.assertIsNone(getattr(generator_run, attr))
コード例 #5
0
def _benchmark_replication_Dev_API(
    problem: BenchmarkProblem,
    method: GenerationStrategy,
    num_trials: int,
    experiment_name: str,
    batch_size: int = 1,
    raise_all_exceptions: bool = False,
    benchmark_trial: FunctionType = benchmark_trial,
    verbose_logging: bool = True,
    # Number of trials that need to fail for a replication to be considered failed.
    failed_trials_tolerated: int = 5,
) -> Tuple[Experiment, List[Exception]]:
    """Run a benchmark replication via the Developer API because the problem was
    set up with Ax classes (likely to allow for additional complexity like
    adding constraints or non-range parameters).

    Returns the populated experiment together with any exceptions tolerated
    during the run.
    """
    exceptions: List[Exception] = []
    experiment = Experiment(
        # Name the experiment so stored results and failure messages can be
        # tied to this replication (previously `experiment_name` was only
        # used in the error message below and the experiment was unnamed).
        name=experiment_name,
        search_space=problem.search_space,
        optimization_config=problem.optimization_config,
        runner=SyntheticRunner(),
    )
    # Data from the previous trial, fed back into the strategy's `gen`.
    new_data = Data()
    for trial_idx in range(num_trials):
        try:
            gr = method.gen(experiment=experiment, new_data=new_data, n=batch_size)
            if batch_size == 1:
                experiment.new_trial(generator_run=gr).run()
            else:
                assert batch_size > 1
                experiment.new_batch_trial(generator_run=gr).run()
            # `benchmark_trial` evaluates the trial and returns its data.
            new_data = checked_cast(
                Data, benchmark_trial(experiment=experiment, trial_index=trial_idx)
            )
        except Exception as err:  # TODO[T53975770]: test
            if raise_all_exceptions:
                raise
            exceptions.append(err)
        if len(exceptions) > failed_trials_tolerated:
            raise RuntimeError(  # TODO[T53975770]: test
                f"More than {failed_trials_tolerated} failed for {experiment_name}."
            )
    return experiment, exceptions
コード例 #6
0
 def run_benchmark_run(
         self, setup: BenchmarkSetup,
         generation_strategy: GenerationStrategy) -> BenchmarkSetup:
     """Drive `setup` through its full iteration budget with the strategy."""
     iterations_left = setup.total_iterations
     trials_awaiting_data = []
     while iterations_left > 0:
         suggested = min(iterations_left, setup.batch_size)
         # Feed the data from the trials created in the previous pass
         # back into the strategy before generating the next candidates.
         generator_run = generation_strategy.gen(
             experiment=setup,
             new_data=Data.from_multiple_data(
                 [setup._fetch_trial_data(idx) for idx in trials_awaiting_data]),
             n=setup.batch_size,
         )
         trials_awaiting_data = []
         if setup.batch_size > 1:  # pragma: no cover
             trial = setup.new_batch_trial().add_generator_run(
                 generator_run).run()
         else:
             trial = setup.new_trial(generator_run=generator_run).run()
         trials_awaiting_data.append(trial.index)
         iterations_left -= suggested
     return setup
コード例 #7
0
 def test_sobol_GPEI_strategy_batches(
     self, mock_GPEI_gen, mock_GPEI_update, mock_GPEI_init
 ):
     """Sobol (5 arms) then GPEI (8 arms) strategy generating batch trials.

     The GPEI internals are replaced by the decorator-injected mocks, so
     only the strategy's bookkeeping and error paths are exercised here.
     """
     exp = get_branin_experiment()
     sobol_GPEI_generation_strategy = GenerationStrategy(
         name="Sobol+GPEI",
         steps=[
             GenerationStep(model=Models.SOBOL, num_arms=5),
             GenerationStep(model=Models.GPEI, num_arms=8),
         ],
     )
     self.assertEqual(sobol_GPEI_generation_strategy.name, "Sobol+GPEI")
     # Strategy switches generators after the first 5 arms.
     self.assertEqual(sobol_GPEI_generation_strategy.generator_changes, [5])
     exp.new_batch_trial(
         generator_run=sobol_GPEI_generation_strategy.gen(exp, n=2)
     ).run()
     for i in range(1, 8):
         if i == 2:
             # At this point only 1 Sobol arm remains, so asking for 2 fails;
             # a follow-up call without `n` succeeds.
             with self.assertRaisesRegex(ValueError, "Cannot generate 2 new"):
                 g = sobol_GPEI_generation_strategy.gen(
                     exp, exp._fetch_trial_data(trial_index=i - 1), n=2
                 )
             g = sobol_GPEI_generation_strategy.gen(
                 exp, exp._fetch_trial_data(trial_index=i - 1)
             )
         elif i == 7:
             # Check completeness error message.
             # NOTE(review): on this iteration `g` is NOT reassigned, so the
             # `new_batch_trial` call below reuses the generator run from the
             # previous iteration — looks intentional but worth confirming.
             with self.assertRaisesRegex(ValueError, "Generation strategy"):
                 g = sobol_GPEI_generation_strategy.gen(
                     exp, exp._fetch_trial_data(trial_index=i - 1), n=2
                 )
         else:
             g = sobol_GPEI_generation_strategy.gen(
                 exp, exp._fetch_trial_data(trial_index=i - 1), n=2
             )
         exp.new_batch_trial(generator_run=g).run()
     # Once the strategy is exhausted, any further `gen` raises.
     with self.assertRaises(ValueError):
         sobol_GPEI_generation_strategy.gen(exp, exp.fetch_data())
     self.assertIsInstance(sobol_GPEI_generation_strategy.model, TorchModelBridge)
コード例 #8
0
def create_new_experiment(input_data, runner, metric, saver_loader):
    """Create a new Ax experiment from request data, generate its first batch
    of arms, persist it, and return a JSON-friendly summary.

    Args:
        input_data: Dict with keys 'search_space', 'test_name',
            'test_description' (a dict expected to contain 'module'),
            'control_group', 'metrics_weights' (metric name -> weight), and
            'arms_to_generate' (-1 means "enumerate all factorial arms").
        runner: Runner class used to execute trials.
        metric: Metric class used to build the objective.
        saver_loader: Persistence helper with `save_full_experiment`.

    Returns:
        A metadata dict describing the created experiment, or an error string
        when an unsupported parameter type is requested.

    Raises:
        ValueError: If `test_description['module']` is not a supported module.
    """
    # Parse search space from its serialized representation.
    search_space = parse_search_space(input_data['search_space'])

    # Define the experiment shell.
    experiment = Experiment(name=input_data['test_name'],
                            search_space=search_space,
                            description=input_data['test_description'])

    # Attach a status-quo ("control") arm when one was provided.
    if input_data['control_group']:
        experiment.status_quo = Arm(name="control",
                                    parameters=input_data['control_group'])

    # Build a scalarized objective from the metric-name -> weight mapping.
    metrics = []
    weights = []
    for metric_name, weight in input_data['metrics_weights'].items():
        metrics.append(metric(name=metric_name, lower_is_better=False))
        weights.append(weight)

    main_objective = ScalarizedObjective(metrics=metrics,
                                         weights=weights,
                                         minimize=False)
    experiment.optimization_config = OptimizationConfig(
        objective=main_objective)

    # Step 0: enumerate all arms (FACTORIAL) when -1 was requested,
    # otherwise quasi-random SOBOL generation.
    if input_data['arms_to_generate'] == -1:
        generation_step0_model = Models.FACTORIAL
    else:
        generation_step0_model = Models.SOBOL

    # Step 1 model depends on the requested optimization module.
    module = input_data['test_description']['module']
    if module == 'bayesian_optimization':
        parameter_types = [
            spec['type']
            for spec in input_data['search_space']['parameters'].values()
        ]
        if 'choice' in parameter_types:
            return 'choice param not implemented for bayesian opt'
        generation_step1_model = Models.BOTORCH
    elif module == 'bandit':
        generation_step1_model = Models.THOMPSON
    else:
        # Previously an unsupported module fell through and crashed later
        # with a NameError on `generation_step1_model`; fail loudly instead.
        raise ValueError(f"Unsupported module: {module!r}")

    generation_strategy = GenerationStrategy(steps=[
        GenerationStep(model=generation_step0_model, num_trials=1),
        GenerationStep(model=generation_step1_model,
                       num_trials=-1,
                       model_kwargs={'min_weight': 0.01}),
    ])

    # Generate the primary arms for the first trial.
    generation_strategy.gen(experiment=experiment,
                            search_space=search_space,
                            n=input_data['arms_to_generate'])

    # Runners can also be manually added to a trial to override the experiment default.
    experiment.runner = runner()

    # Create the first trial with the starting arms; optimize arm weights
    # for statistical power only when a control arm exists.
    optimize_for_power = bool(input_data['control_group'])
    experiment.new_batch_trial(
        generator_run=generation_strategy.last_generator_run,
        optimize_for_power=optimize_for_power)

    # Persist experiment and strategy state.
    saver_loader.save_full_experiment(experiment, generation_strategy)

    # Build the JSON-friendly response payload.
    exp_json = object_to_json(experiment)
    experiment_metadata = {
        'experiment_name': exp_json['name'],
        'experiment_description': exp_json['description'],
        'search_space': exp_json['search_space'],
        'trial0_arms': {
            object_to_json(arm)['name']: {
                'parameters': object_to_json(arm)['parameters'],
                'weight': weight
            }
            for arm, weight in
            experiment.trials[0].normalized_arm_weights().items()
        },
        'optimization_config': exp_json['optimization_config'],
        'control_group': exp_json['status_quo'],
        'runner': exp_json['runner'],
        'time_created': exp_json['time_created']['value']
    }
    return experiment_metadata
コード例 #9
0
class TestGenerationStrategy(TestCase):
    def setUp(self):
        """Install mocks for the slow model bridges and build shared strategies.

        The mocked bridges and the patched `Models` registry are referenced by
        the individual tests via the instance attributes set here; all
        patchers are stopped in `tearDown`.
        """
        # Canned generator run that every mocked bridge's `gen` returns.
        self.gr = GeneratorRun(arms=[Arm(parameters={"x1": 1, "x2": 2})])

        # Mock out slow GPEI.
        self.torch_model_bridge_patcher = patch(
            f"{TorchModelBridge.__module__}.TorchModelBridge", spec=True)
        self.mock_torch_model_bridge = self.torch_model_bridge_patcher.start()
        self.mock_torch_model_bridge.return_value.gen.return_value = self.gr

        # Mock out slow TS.
        self.discrete_model_bridge_patcher = patch(
            f"{DiscreteModelBridge.__module__}.DiscreteModelBridge", spec=True)
        self.mock_discrete_model_bridge = self.discrete_model_bridge_patcher.start(
        )
        self.mock_discrete_model_bridge.return_value.gen.return_value = self.gr

        # Mock in `Models` registry: point the Factorial/Thompson/GPEI setups
        # at the mocked bridge classes created above.
        self.registry_setup_dict_patcher = patch.dict(
            f"{Models.__module__}.MODEL_KEY_TO_MODEL_SETUP",
            {
                "Factorial":
                MODEL_KEY_TO_MODEL_SETUP["Factorial"]._replace(
                    bridge_class=self.mock_discrete_model_bridge),
                "Thompson":
                MODEL_KEY_TO_MODEL_SETUP["Thompson"]._replace(
                    bridge_class=self.mock_discrete_model_bridge),
                "GPEI":
                MODEL_KEY_TO_MODEL_SETUP["GPEI"]._replace(
                    bridge_class=self.mock_torch_model_bridge),
            },
        )
        self.mock_in_registry = self.registry_setup_dict_patcher.start()

        # model bridges are mocked, which makes kwargs' validation difficult,
        # so for now we will skip it in the generation strategy tests.
        # NOTE: Starting with Python3.8 this is not a problem as `autospec=True`
        # ensures that the mocks have correct signatures, but in earlier
        # versions kwarg validation on mocks does not really work.
        self.step_model_kwargs = {"silently_filter_kwargs": True}
        self.hss_experiment = get_hierarchical_search_space_experiment()
        # Shared fixture: Sobol (5 trials) -> GPEI (2 trials).
        self.sobol_GPEI_GS = GenerationStrategy(
            name="Sobol+GPEI",
            steps=[
                GenerationStep(
                    model=Models.SOBOL,
                    num_trials=5,
                    model_kwargs=self.step_model_kwargs,
                ),
                GenerationStep(model=Models.GPEI,
                               num_trials=2,
                               model_kwargs=self.step_model_kwargs),
            ],
        )
        # Shared fixture: single unlimited Sobol step with deduplication.
        self.sobol_GS = GenerationStrategy(steps=[
            GenerationStep(
                Models.SOBOL,
                num_trials=-1,
                should_deduplicate=True,
            )
        ])

    def tearDown(self):
        """Stop every patcher started in `setUp`."""
        for patcher in (
            self.torch_model_bridge_patcher,
            self.discrete_model_bridge_patcher,
            self.registry_setup_dict_patcher,
        ):
            patcher.stop()

    def test_name(self):
        """Setting `name` on a strategy is reflected on read-back."""
        custom_name = "SomeGSName"
        self.sobol_GS.name = custom_name
        self.assertEqual(self.sobol_GS.name, custom_name)

    def test_validation(self):
        """Validate step configuration rules enforced at construction/gen time."""
        # num_trials can be positive or -1.
        with self.assertRaises(UserInputError):
            GenerationStrategy(steps=[
                GenerationStep(model=Models.SOBOL, num_trials=5),
                GenerationStep(model=Models.GPEI, num_trials=-10),
            ])

        # only last num_trials can be -1.
        with self.assertRaises(UserInputError):
            GenerationStrategy(steps=[
                GenerationStep(model=Models.SOBOL, num_trials=-1),
                GenerationStep(model=Models.GPEI, num_trials=10),
            ])

        exp = Experiment(
            name="test",
            search_space=SearchSpace(parameters=[get_choice_parameter()]))
        factorial_thompson_generation_strategy = GenerationStrategy(steps=[
            GenerationStep(model=Models.FACTORIAL, num_trials=1),
            GenerationStep(model=Models.THOMPSON, num_trials=2),
        ])
        # Registry-based steps are recognized as registered models.
        self.assertTrue(
            factorial_thompson_generation_strategy._uses_registered_models)
        self.assertFalse(
            factorial_thompson_generation_strategy.uses_non_registered_models)
        # Generating for this choice-parameter search space should raise.
        with self.assertRaises(ValueError):
            factorial_thompson_generation_strategy.gen(exp)
        # A plain callable's `__name__` becomes the step's model name.
        self.assertEqual(
            GenerationStep(model=sum, num_trials=1).model_name, "sum")
        # Negative max_parallelism is rejected with a descriptive message.
        with self.assertRaisesRegex(UserInputError,
                                    "Maximum parallelism should be"):
            GenerationStrategy(steps=[
                GenerationStep(
                    model=Models.SOBOL, num_trials=5, max_parallelism=-1),
                GenerationStep(model=Models.GPEI, num_trials=-1),
            ])

    def test_custom_callables_for_models(self):
        """A factory-function step counts as a non-registered model but
        can still generate candidates."""
        experiment = get_branin_experiment()
        factory_gs = GenerationStrategy(
            steps=[GenerationStep(model=get_sobol, num_trials=-1)]
        )
        # Factory functions are not part of the `Models` registry.
        self.assertFalse(factory_gs._uses_registered_models)
        self.assertTrue(factory_gs.uses_non_registered_models)
        generator_run = factory_gs.gen(experiment=experiment, n=1)
        self.assertEqual(len(generator_run.arms), 1)

    def test_string_representation(self):
        """`str()` lists every step together with its trial budget."""
        two_step = GenerationStrategy(steps=[
            GenerationStep(model=Models.SOBOL, num_trials=5),
            GenerationStep(model=Models.GPEI, num_trials=-1),
        ])
        expected_two_step = (
            "GenerationStrategy(name='Sobol+GPEI', steps=[Sobol for 5 trials,"
            " GPEI for subsequent trials])"
        )
        self.assertEqual(str(two_step), expected_two_step)

        # A single unlimited step reads as "for all trials".
        one_step = GenerationStrategy(
            steps=[GenerationStep(model=Models.SOBOL, num_trials=-1)]
        )
        self.assertEqual(
            str(one_step),
            "GenerationStrategy(name='Sobol', steps=[Sobol for all trials])")

    def test_equality(self):
        """Strategies with identical steps compare equal; so does a reset clone."""
        gs1 = GenerationStrategy(
            name="Sobol+GPEI",
            steps=[
                GenerationStep(model=Models.SOBOL, num_trials=5),
                GenerationStep(model=Models.GPEI, num_trials=-1),
            ],
        )
        gs2 = GenerationStrategy(
            name="Sobol+GPEI",
            steps=[
                GenerationStep(model=Models.SOBOL, num_trials=5),
                GenerationStep(model=Models.GPEI, num_trials=-1),
            ],
        )
        self.assertEqual(gs1, gs2)

        # `clone_reset()` preserves the steps, so the clone compares equal to
        # the original (equality does not depend on generation progress).
        gs3 = gs1.clone_reset()
        self.assertEqual(gs1, gs3)

    def test_restore_from_generator_run(self):
        """Restoring from the last generator run rebuilds an equivalent model
        and refreshes the seen-trial bookkeeping."""
        gs = GenerationStrategy(
            steps=[GenerationStep(model=Models.SOBOL, num_trials=5)])
        # No generator runs on GS, so can't restore from one.
        with self.assertRaises(ValueError):
            gs._restore_model_from_generator_run()
        exp = get_branin_experiment(with_batch=True)
        gs.gen(experiment=exp)
        model = gs.model
        # Create a copy of the generation strategy and check that when
        # we restore from last generator run, the model will be set
        # correctly and that `_seen_trial_indices_by_status` is filled.
        new_gs = GenerationStrategy(
            steps=[GenerationStep(model=Models.SOBOL, num_trials=5)])
        new_gs._experiment = exp
        new_gs._generator_runs = gs._generator_runs
        self.assertIsNone(new_gs._seen_trial_indices_by_status)
        new_gs._restore_model_from_generator_run()
        self.assertEqual(gs._seen_trial_indices_by_status,
                         exp.trial_indices_by_status)
        # Model should be reset, but it should be the same model with same data.
        self.assertIsNot(model, new_gs.model)
        self.assertEqual(model.__class__,
                         new_gs.model.__class__)  # Model bridge.
        self.assertEqual(model.model.__class__,
                         new_gs.model.model.__class__)  # Model.
        self.assertEqual(model._training_data, new_gs.model._training_data)

    def test_min_observed(self):
        """Transitioning to the next step must fail until `min_trials_observed`
        trials in the current step have observed data."""
        # NOTE: previously this passed `get_branin_experiment()` as a
        # positional argument to itself — a clear bug; the stub takes no such
        # experiment argument in sibling tests.
        exp = get_branin_experiment()
        gs = GenerationStrategy(steps=[
            GenerationStep(
                model=Models.SOBOL, num_trials=5, min_trials_observed=5),
            GenerationStep(model=Models.GPEI, num_trials=1),
        ])
        self.assertFalse(gs.uses_non_registered_models)
        # Exhaust the Sobol budget without attaching any data to the trials.
        for _ in range(5):
            exp.new_trial(gs.gen(exp))
        # With zero observed trials, moving on to GPEI must raise.
        with self.assertRaises(DataRequiredError):
            gs.gen(exp)

    def test_do_not_enforce_min_observations(self):
        # We should be able to move on to the next model if there is not
        # enough data observed if `enforce_num_trials` setting is False, in which
        # case the previous model should be used until there is enough data.
        exp = get_branin_experiment()
        gs = GenerationStrategy(steps=[
            GenerationStep(
                model=Models.SOBOL,
                num_trials=1,
                min_trials_observed=5,
                enforce_num_trials=False,
            ),
            GenerationStep(model=Models.GPEI, num_trials=1),
        ])
        for _ in range(2):
            gs.gen(exp)
        # Sobol should still be used for the second point: its `num_trials=1`
        # budget is exceeded, but `min_trials_observed=5` is not yet met.
        self.assertIsInstance(gs._model, RandomModelBridge)

    def test_sobol_GPEI_strategy(self):
        """Shared Sobol(5)+GPEI(2) strategy: Sobol metadata is recorded on the
        first 5 generator runs, then the mocked GPEI bridge takes over, and the
        strategy raises once both budgets are exhausted."""
        exp = get_branin_experiment()
        self.assertEqual(self.sobol_GPEI_GS.name, "Sobol+GPEI")
        self.assertEqual(self.sobol_GPEI_GS.model_transitions, [5])
        for i in range(7):
            g = self.sobol_GPEI_GS.gen(exp)
            exp.new_trial(generator_run=g).run()
            # One generator run is recorded per `gen` call.
            self.assertEqual(len(self.sobol_GPEI_GS._generator_runs), i + 1)
            if i > 4:
                # Past the Sobol budget, the (mocked) GPEI bridge is used.
                self.mock_torch_model_bridge.assert_called()
            else:
                # Sobol runs carry full model/bridge metadata and state.
                self.assertEqual(g._model_key, "Sobol")
                self.assertEqual(
                    g._model_kwargs,
                    {
                        "seed": None,
                        "deduplicate": False,
                        "init_position": i,
                        "scramble": True,
                        "generated_points": None,
                        "fallback_to_sample_polytope": False,
                    },
                )
                self.assertEqual(
                    g._bridge_kwargs,
                    {
                        "optimization_config": None,
                        "status_quo_features": None,
                        "status_quo_name": None,
                        "transform_configs": None,
                        "transforms": Cont_X_trans,
                        "fit_out_of_design": False,
                        "fit_abandoned": False,
                    },
                )
                # Sobol's position advances by one per generated point.
                self.assertEqual(g._model_state_after_gen,
                                 {"init_position": i + 1})
        # Check completeness error message when GS should be done.
        with self.assertRaises(GenerationStrategyCompleted):
            g = self.sobol_GPEI_GS.gen(exp)

    def test_sobol_GPEI_strategy_keep_generating(self):
        """A trailing `num_trials=-1` GPEI step keeps generating indefinitely."""
        exp = get_branin_experiment()
        strategy = GenerationStrategy(steps=[
            GenerationStep(
                model=Models.SOBOL,
                num_trials=5,
                model_kwargs=self.step_model_kwargs,
            ),
            GenerationStep(
                model=Models.GPEI,
                num_trials=-1,
                model_kwargs=self.step_model_kwargs,
            ),
        ])
        self.assertEqual(strategy.name, "Sobol+GPEI")
        self.assertEqual(strategy.model_transitions, [5])
        # First trial comes from Sobol.
        exp.new_trial(generator_run=strategy.gen(exp)).run()
        for trial_num in range(1, 15):
            generator_run = strategy.gen(exp)
            exp.new_trial(generator_run=generator_run).run()
            if trial_num > 4:
                # Past the Sobol budget, the GPEI (torch) bridge takes over
                # and never runs out.
                self.assertIsInstance(strategy.model, TorchModelBridge)

    def test_sobol_strategy(self):
        """A single Sobol step records exactly one generator run per `gen`."""
        exp = get_branin_experiment()
        strategy = GenerationStrategy(steps=[
            GenerationStep(
                model=Models.SOBOL,
                num_trials=5,
                max_parallelism=10,
                use_update=False,
                enforce_num_trials=False,
            )
        ])
        for call_count in range(1, 6):
            strategy.gen(exp, n=1)
            self.assertEqual(len(strategy._generator_runs), call_count)

    @patch(f"{Experiment.__module__}.Experiment.fetch_data",
           return_value=get_data())
    def test_factorial_thompson_strategy(self, _):
        """Factorial for the first batch, Thompson sampling thereafter; both
        bridges are mocked, so only the key passed to `_set_kwargs_to_save`
        is inspected to confirm which model generated each batch."""
        exp = get_branin_experiment()
        factorial_thompson_generation_strategy = GenerationStrategy(steps=[
            GenerationStep(
                model=Models.FACTORIAL,
                num_trials=1,
                model_kwargs=self.step_model_kwargs,
            ),
            GenerationStep(
                model=Models.THOMPSON,
                num_trials=-1,
                model_kwargs=self.step_model_kwargs,
            ),
        ])
        # Default name is derived from the step models.
        self.assertEqual(factorial_thompson_generation_strategy.name,
                         "Factorial+Thompson")
        self.assertEqual(
            factorial_thompson_generation_strategy.model_transitions, [1])
        # Both Factorial and Thompson are registered to this mocked bridge.
        mock_model_bridge = self.mock_discrete_model_bridge.return_value

        # Initial factorial batch.
        exp.new_batch_trial(
            factorial_thompson_generation_strategy.gen(experiment=exp))
        args, kwargs = mock_model_bridge._set_kwargs_to_save.call_args
        self.assertEqual(kwargs.get("model_key"), "Factorial")

        # Subsequent Thompson sampling batch.
        exp.new_batch_trial(
            factorial_thompson_generation_strategy.gen(experiment=exp))
        args, kwargs = mock_model_bridge._set_kwargs_to_save.call_args
        self.assertEqual(kwargs.get("model_key"), "Thompson")

    def test_clone_reset(self):
        """`clone_reset` returns a strategy pointing back at the first step."""
        strategy = GenerationStrategy(steps=[
            GenerationStep(model=Models.FACTORIAL, num_trials=1),
            GenerationStep(model=Models.THOMPSON, num_trials=2),
        ])
        # Manually advance to the second step.
        strategy._curr = strategy._steps[1]
        self.assertEqual(strategy._curr.index, 1)
        # The clone starts over from step 0.
        self.assertEqual(strategy.clone_reset()._curr.index, 0)

    def test_kwargs_passed(self):
        """`model_kwargs` declared on a step reach the underlying model."""
        strategy = GenerationStrategy(steps=[
            GenerationStep(
                model=Models.SOBOL,
                num_trials=1,
                model_kwargs={"scramble": False},
            )
        ])
        experiment = get_branin_experiment()
        strategy.gen(experiment)
        # The kwarg landed on the Sobol model inside the bridge.
        self.assertFalse(strategy._model.model.scramble)

    def test_sobol_GPEI_strategy_batches(self):
        """Sobol(1)+GPEI(6) strategy generating 2-arm batch trials; the mocked
        GPEI `gen` returns a fixed 2-arm generator run."""
        mock_GPEI_gen = self.mock_torch_model_bridge.return_value.gen
        mock_GPEI_gen.return_value = GeneratorRun(arms=[
            Arm(parameters={
                "x1": 1,
                "x2": 2
            }),
            Arm(parameters={
                "x1": 3,
                "x2": 4
            }),
        ])
        exp = get_branin_experiment()
        sobol_GPEI_generation_strategy = GenerationStrategy(
            name="Sobol+GPEI",
            steps=[
                GenerationStep(
                    model=Models.SOBOL,
                    num_trials=1,
                    model_kwargs=self.step_model_kwargs,
                ),
                GenerationStep(model=Models.GPEI,
                               num_trials=6,
                               model_kwargs=self.step_model_kwargs),
            ],
        )
        self.assertEqual(sobol_GPEI_generation_strategy.name, "Sobol+GPEI")
        # Strategy switches to GPEI after the first trial.
        self.assertEqual(sobol_GPEI_generation_strategy.model_transitions, [1])
        gr = sobol_GPEI_generation_strategy.gen(exp, n=2)
        exp.new_batch_trial(generator_run=gr).run()
        for i in range(1, 8):
            if i == 7:
                # Check completeness error message.
                # NOTE(review): `g` is not reassigned on this iteration, so the
                # `new_batch_trial` below reuses the run from iteration 6 —
                # presumably intentional; confirm.
                with self.assertRaises(GenerationStrategyCompleted):
                    g = sobol_GPEI_generation_strategy.gen(exp, n=2)
            else:
                g = sobol_GPEI_generation_strategy.gen(exp, n=2)
            exp.new_batch_trial(generator_run=g).run()
        self.assertIsInstance(sobol_GPEI_generation_strategy.model,
                              TorchModelBridge)

    def test_with_factory_function(self):
        """Checks that generation strategy works with custom factory functions.
        No information about the model should be saved on generator run."""

        def get_sobol(search_space: SearchSpace) -> RandomModelBridge:
            # Build a Sobol bridge directly, bypassing the `Models` registry.
            return RandomModelBridge(
                search_space=search_space,
                model=SobolGenerator(),
                transforms=Cont_X_trans,
            )

        experiment = get_branin_experiment()
        strategy = GenerationStrategy(
            steps=[GenerationStep(model=get_sobol, num_trials=5)])
        generator_run = strategy.gen(experiment)
        self.assertIsInstance(strategy.model, RandomModelBridge)
        # Factory-function models are unregistered, so no metadata is stored.
        for attr in ("_model_key", "_model_kwargs", "_bridge_kwargs"):
            self.assertIsNone(getattr(generator_run, attr))

    def test_store_experiment(self):
        """The strategy caches the experiment it first generates for."""
        experiment = get_branin_experiment()
        strategy = GenerationStrategy(
            steps=[GenerationStep(model=Models.SOBOL, num_trials=5)])
        # Unset until the first `gen` call.
        self.assertIsNone(strategy._experiment)
        strategy.gen(experiment)
        self.assertIsNotNone(strategy._experiment)

    def test_trials_as_df(self):
        """`trials_as_df` is None before any trials and tracks status changes."""
        experiment = get_branin_experiment()
        strategy = GenerationStrategy(
            steps=[GenerationStep(model=Models.SOBOL, num_trials=5)])
        # Nothing generated yet, so there is nothing to tabulate.
        self.assertIsNone(strategy.trials_as_df)
        # After creating a trial, a CANDIDATE row should appear.
        new_trial = experiment.new_trial(strategy.gen(experiment=experiment))
        self.assertFalse(strategy.trials_as_df.empty)
        self.assertEqual(
            strategy.trials_as_df.head()["Trial Status"][0], "CANDIDATE")
        # Status changes should be picked up on the next property access.
        new_trial._status = TrialStatus.RUNNING
        self.assertEqual(
            strategy.trials_as_df.head()["Trial Status"][0], "RUNNING")

    def test_max_parallelism_reached(self):
        """Exceeding a step's `max_parallelism` raises instead of generating."""
        experiment = get_branin_experiment()
        strategy = GenerationStrategy(steps=[
            GenerationStep(model=Models.SOBOL, num_trials=5, max_parallelism=1)
        ])
        first_trial = experiment.new_trial(
            generator_run=strategy.gen(experiment=experiment))
        first_trial.mark_running(no_runner_required=True)
        # One trial is already running, so a second concurrent one is refused.
        with self.assertRaises(MaxParallelismReachedException):
            strategy.gen(experiment=experiment)

    @patch(f"{RandomModelBridge.__module__}.RandomModelBridge.update")
    @patch(f"{Experiment.__module__}.Experiment.lookup_data")
    def test_use_update(self, mock_lookup_data, mock_update):
        """With `use_update=True`, `gen` should look up and pass only data for
        trials newly completed since the last generation, and skip the model
        update entirely when no new data has arrived.
        """
        exp = get_branin_experiment()
        sobol_gs_with_update = GenerationStrategy(steps=[
            GenerationStep(model=Models.SOBOL, num_trials=-1, use_update=True)
        ])
        sobol_gs_with_update._experiment = exp
        # Nothing generated yet, so no trials completed since last gen.
        self.assertEqual(
            sobol_gs_with_update._find_trials_completed_since_last_gen(),
            set(),
        )
        with self.assertRaises(NotImplementedError):
            # `BraninMetric` is available while running by default, which should
            # raise an error when used with `use_update=True` on a generation
            # step, as we have not yet properly addressed that edge case (for
            # lack of use case).
            sobol_gs_with_update.gen(experiment=exp)

        core_stubs_module = get_branin_experiment.__module__
        with patch(
                f"{core_stubs_module}.BraninMetric.is_available_while_running",
                return_value=False,
        ):
            # Try without passing data (GS looks up data on experiment).
            trial = exp.new_trial(generator_run=sobol_gs_with_update.gen(
                experiment=exp))
            mock_update.assert_not_called()
            trial._status = TrialStatus.COMPLETED
            for i in range(3):
                gr = sobol_gs_with_update.gen(experiment=exp)
                # Only the trial completed since the previous `gen` should
                # have its data looked up.
                self.assertEqual(
                    mock_lookup_data.call_args[1].get("trial_indices"), {i})
                trial = exp.new_trial(generator_run=gr)
                trial._status = TrialStatus.COMPLETED
            # `_seen_trial_indices_by_status` is set during `gen`, to the
            # experiment's `trial_indices_by_status` at the time of candidate
            # generation; the last completion above happened after `gen`.
            self.assertNotEqual(
                sobol_gs_with_update._seen_trial_indices_by_status,
                exp.trial_indices_by_status,
            )
            # Try with passing data.
            sobol_gs_with_update.gen(
                experiment=exp, data=get_branin_data(trial_indices=range(4)))
        # Now `_seen_trial_indices_by_status` should be set to experiment's.
        self.assertEqual(
            sobol_gs_with_update._seen_trial_indices_by_status,
            exp.trial_indices_by_status,
        )
        # Only the data for the last completed trial should be considered new
        # and passed to `update`.
        self.assertEqual(
            set(mock_update.call_args[1].get(
                "new_data").df["trial_index"].values), {3})
        # Try with passing same data as before; no update should be performed.
        # NOTE: this deliberately shadows the outer `mock_update` mock.
        with patch.object(sobol_gs_with_update,
                          "_update_current_model") as mock_update:
            sobol_gs_with_update.gen(
                experiment=exp, data=get_branin_data(trial_indices=range(4)))
            mock_update.assert_not_called()

    def test_deduplication(self):
        """With `should_deduplicate=True`, generation raises
        `GenerationStrategyRepeatedPoints` once the tiny search space is
        exhausted and only duplicate candidates remain.
        """
        # A search space with exactly two distinct arms: x1 fixed, x2 binary.
        parameters = [
            FixedParameter(
                name="x1",
                parameter_type=ParameterType.FLOAT,
                value=1.0,
            ),
            ChoiceParameter(
                name="x2",
                parameter_type=ParameterType.FLOAT,
                values=[float(x) for x in range(2)],
            ),
        ]
        search_space = SearchSpace(
            parameters=cast(List[Parameter], parameters))
        exp = get_branin_experiment(search_space=search_space)
        sobol = GenerationStrategy(
            name="Sobol",
            steps=[
                GenerationStep(
                    model=Models.SOBOL,
                    num_trials=-1,
                    model_kwargs=self.step_model_kwargs,
                    should_deduplicate=True,
                ),
            ],
        )
        # Both distinct arms can be generated without hitting duplicates.
        for _ in range(2):
            exp.new_trial(generator_run=sobol.gen(exp)).run()
        self.assertEqual(len(exp.arms_by_signature), 2)

        # The space is now exhausted: every further draw is a duplicate,
        # so generation fails after exceeding the draw budget.
        with self.assertRaisesRegex(GenerationStrategyRepeatedPoints,
                                    "exceeded `MAX_GEN_DRAWS`"):
            sobol.gen(exp)

    def test_current_generator_run_limit(self):
        """`current_generator_run_limit` reports completion once a bounded
        second step has produced all of its trials."""
        init_trials = 5
        second_step_parallelism = 3
        num_rounds = 4
        exp = get_branin_experiment()
        gs = GenerationStrategy(steps=[
            GenerationStep(
                model=Models.SOBOL,
                num_trials=init_trials,
                min_trials_observed=3,
            ),
            GenerationStep(
                model=Models.SOBOL,
                # Sized so the remaining rounds exactly exhaust this step.
                num_trials=(num_rounds - 1) * second_step_parallelism,
                max_parallelism=second_step_parallelism,
            ),
        ])
        gs._experiment = exp
        could_gen = self._run_GS_for_N_rounds(
            gs=gs, exp=exp, num_rounds=num_rounds)

        # Optimization should now be complete.
        num_trials_to_gen, opt_complete = gs.current_generator_run_limit()
        self.assertTrue(opt_complete)
        self.assertEqual(num_trials_to_gen, 0)

        # We expect trials from first generation step + trials from remaining
        # rounds in batches limited by parallelism setting in the second step.
        self.assertEqual(
            len(exp.trials),
            init_trials + (num_rounds - 1) * second_step_parallelism,
        )
        self.assertTrue(
            all(t.status.is_completed for t in exp.trials.values()))
        self.assertEqual(
            could_gen,
            [init_trials] + [second_step_parallelism] * (num_rounds - 1),
        )

    def test_current_generator_run_limit_unlimited_second_step(self):
        """An unlimited (`num_trials=-1`) second step still caps each round's
        generation at its `max_parallelism`."""
        init_trials = 5
        second_step_parallelism = 3
        num_rounds = 4
        exp = get_branin_experiment()
        gs = GenerationStrategy(steps=[
            GenerationStep(
                model=Models.SOBOL,
                num_trials=init_trials,
                min_trials_observed=3,
            ),
            GenerationStep(
                model=Models.SOBOL,
                num_trials=-1,  # unlimited trials in this step
                max_parallelism=second_step_parallelism,
            ),
        ])
        gs._experiment = exp
        could_gen = self._run_GS_for_N_rounds(
            gs=gs, exp=exp, num_rounds=num_rounds)

        # We expect trials from first generation step + trials from remaining
        # rounds in batches limited by parallelism setting in the second step.
        self.assertEqual(
            len(exp.trials),
            init_trials + (num_rounds - 1) * second_step_parallelism,
        )
        self.assertTrue(
            all(t.status.is_completed for t in exp.trials.values()))
        self.assertEqual(
            could_gen,
            [init_trials] + [second_step_parallelism] * (num_rounds - 1),
        )

    def test_hierarchical_search_space(self):
        """Generation on a hierarchical search space should pass the model
        observation features containing every parameter of the flattened
        search space (modulo transform-renamed parameters).
        """
        experiment = get_hierarchical_search_space_experiment()
        self.assertTrue(experiment.search_space.is_hierarchical)
        # First `gen` happens without data, outside the patched loop below.
        self.sobol_GS.gen(experiment=experiment)
        for _ in range(10):
            # During each iteration, check that all transformed observation features
            # contain all parameters of the flat search space.
            with patch.object(RandomModelBridge,
                              "_fit") as mock_model_fit, patch.object(
                                  RandomModelBridge, "gen"):
                self.sobol_GS.gen(experiment=experiment)
                mock_model_fit.assert_called_once()
                obs_feats = mock_model_fit.call_args[1].get(
                    "observation_features")
                all_parameter_names = (
                    experiment.search_space._all_parameter_names.copy())
                # One of the parameter names is modified by transforms (because it's
                # one-hot encoded).
                all_parameter_names.remove("model")
                all_parameter_names.add("model_OH_PARAM_")
                for obsf in obs_feats:
                    for p_name in all_parameter_names:
                        self.assertIn(p_name, obsf.parameters)

            # Generate, run, and complete a real (unpatched) trial, then
            # attach data for both metrics so the next iteration has
            # observations to fit against.
            trial = (experiment.new_trial(generator_run=self.sobol_GS.gen(
                experiment=experiment)).mark_running(
                    no_runner_required=True).mark_completed())
            experiment.attach_data(
                get_data(
                    metric_name="m1",
                    trial_index=trial.index,
                    num_non_sq_arms=1,
                    include_sq=False,
                ))
            experiment.attach_data(
                get_data(
                    metric_name="m2",
                    trial_index=trial.index,
                    num_non_sq_arms=1,
                    include_sq=False,
                ))

    # ------------- Testing helpers (put tests above this line) -------------

    def _run_GS_for_N_rounds(self, gs: GenerationStrategy, exp: Experiment,
                             num_rounds: int) -> List[int]:
        """Run `num_rounds` generate-then-complete cycles of `gs` on `exp`.

        Each round generates as many trials as the strategy currently allows,
        marks them running, then attaches data and completes them all before
        the next round. Returns the per-round allowed-trial counts.
        """
        allowed_per_round: List[int] = []
        for _ in range(num_rounds):
            n_allowed, done = gs.current_generator_run_limit()
            # The strategy must not be complete while rounds remain.
            self.assertFalse(done)
            allowed_per_round.append(n_allowed)

            round_trials = [
                exp.new_trial(
                    gs.gen(
                        experiment=exp,
                        pending_observations=get_pending(experiment=exp),
                    )
                ).mark_running(no_runner_required=True)
                for _ in range(n_allowed)
            ]

            # Complete the whole round before generating the next one.
            for trial in round_trials:
                exp.attach_data(get_branin_data(trial_indices=[trial.index]))
                trial.mark_completed()

        return allowed_per_round