Example #1
    def testExperimentUpdateTrial(self):
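        # This test exercises `update_trial` across three kinds of in-place
        # changes (trial status, run metadata, attached data), then on a
        # freshly created batch trial persisted via `save_new_trial`.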
        save_experiment(self.experiment)

        trial = self.experiment.trials[0]
        trial.mark_staged()
        update_trial(experiment=self.experiment, trial=trial)

        loaded_experiment = load_experiment(self.experiment.name)
        self.assertEqual(trial, loaded_experiment.trials[0])

        trial._run_metadata = {"foo": "bar"}
        update_trial(experiment=self.experiment, trial=trial)

        loaded_experiment = load_experiment(self.experiment.name)
        self.assertEqual(trial, loaded_experiment.trials[0])

        self.experiment.attach_data(get_data(trial_index=trial.index))
        update_trial(experiment=self.experiment, trial=trial)

        loaded_experiment = load_experiment(self.experiment.name)
        self.assertEqual(self.experiment, loaded_experiment)

        trial = self.experiment.new_batch_trial(
            generator_run=get_generator_run())
        save_new_trial(experiment=self.experiment, trial=trial)
        self.experiment.attach_data(get_data(trial_index=trial.index))
        update_trial(experiment=self.experiment, trial=trial)

        loaded_experiment = load_experiment(self.experiment.name)
        self.assertEqual(self.experiment, loaded_experiment)
Example #2
    def test_update_generation_strategy(self):
        _, generation_strategy = self.init_experiment_and_generation_strategy()

        generator_run = get_generator_run()
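        # The "..._if_possible" helper reports via its return value whether
        # the strategy (and its new generator runs) could actually be
        # persisted; presumably it returns False when no storage is set up.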
        updated = self.with_db_settings._update_generation_strategy_in_db_if_possible(
            generation_strategy, [generator_run])
        self.assertTrue(updated)
Example #3
    def testExperimentGeneratorRunUpdates(self):
        experiment = get_experiment_with_batch_trial()
        save_experiment(experiment)
        # one main generator run, one for the status quo
        self.assertEqual(get_session().query(SQAGeneratorRun).count(), 2)

        # add an arm
        # this will create one wrapper generator run
        # this will also replace the status quo generator run,
        # since the weight of the status quo will have changed
        trial = experiment.trials[0]
        trial.add_arm(get_arm())
        save_experiment(experiment)
        self.assertEqual(get_session().query(SQAGeneratorRun).count(), 3)

        generator_run = get_generator_run()  # TODO[Lena, T46190605]: remove
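        # Strip the model-related fields so this GR presumably mimics one
        # created without model metadata (see the TODO above about removing
        # this workaround).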
        generator_run._model_key = None
        generator_run._model_kwargs = None
        generator_run._bridge_kwargs = None
        trial.add_generator_run(generator_run=generator_run, multiplier=0.5)
        save_experiment(experiment)
        self.assertEqual(get_session().query(SQAGeneratorRun).count(), 4)

        loaded_experiment = load_experiment(experiment.name)
        self.assertEqual(experiment, loaded_experiment)
Example #4
    def test_get_candidate_metadata_from_all_generator_runs(self):
        gr_1 = get_generator_run()
        gr_2 = get_generator_run2()
        self.batch.add_generator_run(gr_1)
        # Arms are named when adding a GR to a trial, so reassign to get a GR
        # that has named arms.
        gr_1 = self.batch._generator_run_structs[-1].generator_run
        self.batch.add_generator_run(gr_2)
        gr_2 = self.batch._generator_run_structs[-1].generator_run
        # gr_2 has no candidate metadata; all candidate metadata should come from gr_1
        cand_metadata_expected = {
            a.name: gr_1.candidate_metadata_by_arm_signature[a.signature]
            for a in gr_1.arms
        }
        self.assertEqual(
            self.batch._get_candidate_metadata_from_all_generator_runs(),
            cand_metadata_expected,
        )
        for arm in self.batch.arms:
            self.assertEqual(
                cand_metadata_expected[arm.name],
                self.batch._get_candidate_metadata(arm.name),
            )
        self.assertRaises(ValueError, self.batch._get_candidate_metadata,
                          "this_is_not_an_arm")

        # Check that if we add another GR (gr_3) with cand. metadata, its
        # metadata will appear in cand. metadata for the batch.
        gr_3 = get_generator_run2()
        new_cand_metadata = {
            a.signature: {
                "md_key": f"md_val_{a.signature}"
            }
            for a in gr_3.arms
        }
        gr_3._candidate_metadata_by_arm_signature = new_cand_metadata
        self.batch.add_generator_run(gr_3)
        gr_3 = self.batch._generator_run_structs[-1].generator_run
        cand_metadata_expected.update({
            a.name: gr_3.candidate_metadata_by_arm_signature[a.signature]
            for a in gr_3.arms
        })
        self.assertEqual(
            self.batch._get_candidate_metadata_from_all_generator_runs(),
            cand_metadata_expected,
        )
        for arm in self.batch.arms:
            self.assertEqual(
                cand_metadata_expected[arm.name],
                self.batch._get_candidate_metadata(arm.name),
            )
Example #5
    def testExperimentNewTrial(self):
        save_experiment(self.experiment)
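        # `self.experiment` comes out of setup with one pre-existing trial,
        # so the new batch trial below lands at index 1 (see the assertions).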
        trial = self.experiment.new_batch_trial()
        save_new_trial(experiment=self.experiment, trial=trial)

        loaded_experiment = load_experiment(self.experiment.name)
        self.assertEqual(len(loaded_experiment.trials), 2)
        self.assertEqual(trial, loaded_experiment.trials[1])

        trial = self.experiment.new_batch_trial(generator_run=get_generator_run())
        save_new_trial(experiment=self.experiment, trial=trial)

        loaded_experiment = load_experiment(self.experiment.name)
        self.assertEqual(len(loaded_experiment.trials), 3)
        self.assertEqual(trial, loaded_experiment.trials[2])
Example #6
    def testSortable(self):
        new_batch_trial = self.experiment.new_batch_trial()
        self.assertTrue(self.batch < new_batch_trial)
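        # Trials appear to order by index, so the newer trial sorts last.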

        abandoned_arm = get_abandoned_arm()
        abandoned_arm_2 = get_abandoned_arm()
        abandoned_arm_2.name = "0_1"
        self.assertTrue(abandoned_arm < abandoned_arm_2)

        generator_run = get_generator_run()
        generator_run_struct = GeneratorRunStruct(generator_run=generator_run,
                                                  weight=1.0)
        generator_run_struct_2 = GeneratorRunStruct(
            generator_run=generator_run, weight=2.0)
        self.assertTrue(generator_run_struct < generator_run_struct_2)
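
A minimal sketch of what the orderings above enable (assuming the same
`GeneratorRunStruct` and `get_generator_run` fixtures used in the example):
since `<` is defined, these objects can be passed straight to Python's
built-in `sorted`.

    # Sketch only, not part of the test suite; assumes the fixtures above.
    gr = get_generator_run()
    structs = [
        GeneratorRunStruct(generator_run=gr, weight=2.0),
        GeneratorRunStruct(generator_run=gr, weight=1.0),
    ]
    assert sorted(structs)[0].weight == 1.0  # Lower weight sorts first.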
Example #7
    def test_update_generation_strategy_mini_batches(self):
        _, generation_strategy = self.init_experiment_and_generation_strategy()

        # Check with 1 GR.
        generator_run = get_generator_run()
        self.assertIsNone(generator_run.db_id)
        updated = self.with_db_settings._update_generation_strategy_in_db_if_possible(
            generation_strategy, [generator_run])
        self.assertTrue(updated)
        self.assertIsNotNone(generator_run.db_id)

        # Check with multiple GRs whose count is not a multiple of the mini-batch size.
        grs = [generator_run.clone() for _ in range(5)]
        for gr in grs:
            self.assertIsNone(gr._db_id)
        updated = self.with_db_settings._update_generation_strategy_in_db_if_possible(
            generation_strategy, grs)
        self.assertTrue(updated)
        for gr in grs:
            self.assertIsNotNone(gr.db_id)
Example #8
class TestAxScheduler(TestCase):
    """Tests base `Scheduler` functionality."""
    def setUp(self):
        self.branin_experiment = get_branin_experiment()
        self.branin_experiment._properties[
            Keys.IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF] = True
        self.branin_experiment_no_impl_metrics = Experiment(
            search_space=get_branin_search_space(),
            optimization_config=OptimizationConfig(objective=Objective(
                metric=Metric(name="branin"))),
        )
        self.sobol_GPEI_GS = choose_generation_strategy(
            search_space=get_branin_search_space())
        # Contrived GS to ensure that `DataRequiredError` is properly handled
        # in the scheduler. This error is raised when not enough trials have
        # been observed to proceed to the next generation step.
        self.two_sobol_steps_GS = GenerationStrategy(
            steps=[
                GenerationStep(
                    model=Models.SOBOL,
                    num_trials=5,
                    min_trials_observed=3,
                    max_parallelism=2,
                ),
                GenerationStep(model=Models.SOBOL,
                               num_trials=-1,
                               max_parallelism=3),
            ])
        # GS to force the scheduler to poll completed trials after each trial is run.
        self.sobol_GS_no_parallelism = GenerationStrategy(steps=[
            GenerationStep(
                model=Models.SOBOL, num_trials=-1, max_parallelism=1)
        ])

    def test_init(self):
        with self.assertRaisesRegex(UnsupportedError,
                                    ".* metrics .* implemented fetching"):
            scheduler = BareBonesTestScheduler(
                experiment=self.branin_experiment_no_impl_metrics,
                generation_strategy=self.sobol_GPEI_GS,
                options=SchedulerOptions(total_trials=10),
            )
        scheduler = BareBonesTestScheduler(
            experiment=self.branin_experiment,
            generation_strategy=self.sobol_GPEI_GS,
            options=SchedulerOptions(
                total_trials=0,
                tolerated_trial_failure_rate=0.2,
                init_seconds_between_polls=10,
            ),
        )
        self.assertEqual(scheduler.experiment, self.branin_experiment)
        self.assertEqual(scheduler.generation_strategy, self.sobol_GPEI_GS)
        self.assertEqual(scheduler.options.total_trials, 0)
        self.assertEqual(scheduler.options.tolerated_trial_failure_rate, 0.2)
        self.assertEqual(scheduler.options.init_seconds_between_polls, 10)
        self.assertIsNone(scheduler._latest_optimization_start_timestamp)
        for status_prop in ExperimentStatusProperties:
            self.assertEqual(
                scheduler.experiment._properties[status_prop.value], [])
        scheduler.run_all_trials()  # Runs no trials since total trials is 0.
        # `_latest_optimization_start_timestamp` should be set now.
        self.assertLessEqual(
            scheduler._latest_optimization_start_timestamp,
            current_timestamp_in_millis(),
        )

    def test_repr(self):
        scheduler = BareBonesTestScheduler(
            experiment=self.branin_experiment,
            generation_strategy=self.sobol_GPEI_GS,
            options=SchedulerOptions(
                total_trials=0,
                tolerated_trial_failure_rate=0.2,
                init_seconds_between_polls=10,
            ),
        )
        self.assertEqual(
            f"{scheduler}",
            ("BareBonesTestScheduler(experiment=Experiment(branin_test_experiment), "
             "generation_strategy=GenerationStrategy(name='Sobol+GPEI', "
             "steps=[Sobol for 5 trials, GPEI for subsequent trials]), "
             "options=SchedulerOptions(trial_type=<class 'ax.core.trial.Trial'>, "
             "total_trials=0, tolerated_trial_failure_rate=0.2, log_filepath=None, "
             "logging_level=20, ttl_seconds_for_trials=None, init_seconds_between_"
             "polls=10, min_seconds_before_poll=1.0, seconds_between_polls_backoff_"
             "factor=1.5, run_trials_in_batches=False, "
             "debug_log_run_metadata=False))"),
        )

    def test_validate_runners_if_required(self):
        # `BareBonesTestScheduler` does not have a runner or metrics, so it
        # cannot run on an experiment that does not specify those (or that
        # specifies the base `Metric`, which does not implement data-fetching
        # logic).
        scheduler = BareBonesTestScheduler(
            experiment=self.branin_experiment,
            generation_strategy=self.sobol_GPEI_GS,
            options=SchedulerOptions(total_trials=10),
        )
        self.branin_experiment.runner = None
        with self.assertRaisesRegex(NotImplementedError,
                                    ".* runner is required"):
            scheduler.run_all_trials()

    @patch(
        f"{GenerationStrategy.__module__}.GenerationStrategy._gen_multiple",
        return_value=[get_generator_run()],
    )
    def test_run_multi_arm_generator_run_error(self, mock_gen):
        scheduler = TestScheduler(
            experiment=self.branin_experiment,
            generation_strategy=self.sobol_GPEI_GS,
            options=SchedulerOptions(total_trials=1),
        )
        with self.assertRaisesRegex(SchedulerInternalError,
                                    ".* only one was expected"):
            scheduler.run_all_trials()

    @patch(
        # Record calls to function, but still execute it.
        (f"{Scheduler.__module__}."
         "get_pending_observation_features_based_on_trial_status"),
        side_effect=get_pending_observation_features_based_on_trial_status,
    )
    def test_run_all_trials_using_runner_and_metrics(self, mock_get_pending):
        # With runners & metrics, `BareBonesTestScheduler.run_all_trials` should run.
        scheduler = BareBonesTestScheduler(
            experiment=self.branin_experiment,  # Has runner and metrics.
            generation_strategy=self.two_sobol_steps_GS,
            options=SchedulerOptions(
                total_trials=8,
                # Short wait between polls so test is fast.
                init_seconds_between_polls=0.1,
            ),
        )
        scheduler.run_all_trials()
        # Check that we fetched pending features at least 8 times (once for
        # each new trial, and possibly more where we tried to generate trials
        # but hit the parallelism limit; polling of trial statuses is
        # randomized in `BareBonesTestScheduler`, so some trials might not
        # have come back yet).
        self.assertGreaterEqual(len(mock_get_pending.call_args_list), 8)
        self.assertTrue(  # Make sure all trials got to complete.
            all(t.completed_successfully
                for t in scheduler.experiment.trials.values()))
        self.assertEqual(len(scheduler.experiment.trials), 8)
        # Check that all the data, fetched during optimization, was attached to the
        # experiment.
        dat = scheduler.experiment.fetch_data().df
        self.assertEqual(set(dat["trial_index"].values), set(range(8)))
        self.assertEqual(
            scheduler.experiment._properties[
                ExperimentStatusProperties.RUN_TRIALS_STATUS],
            ["started", "success"],
        )
        self.assertEqual(
            scheduler.experiment._properties[
                ExperimentStatusProperties.NUM_TRIALS_RUN_PER_CALL],
            [8],
        )
        self.assertEqual(
            scheduler.experiment._properties[
                ExperimentStatusProperties.RESUMED_FROM_STORAGE_TIMESTAMPS],
            [],
        )

    def test_run_n_trials(self):
        # With runners & metrics, `BareBonesTestScheduler.run_all_trials` should run.
        scheduler = BareBonesTestScheduler(
            experiment=self.branin_experiment,  # Has runner and metrics.
            generation_strategy=self.two_sobol_steps_GS,
            options=SchedulerOptions(
                # Short wait between polls so test is fast.
                init_seconds_between_polls=0.1,
            ),
        )
        scheduler.run_n_trials(max_trials=1)
        self.assertEqual(len(scheduler.experiment.trials), 1)
        scheduler.run_n_trials(max_trials=10)
        self.assertTrue(  # Make sure all trials got to complete.
            all(t.completed_successfully
                for t in scheduler.experiment.trials.values()))
        # Check that all the data, fetched during optimization, was attached to the
        # experiment.
        dat = scheduler.experiment.fetch_data().df
        self.assertEqual(set(dat["trial_index"].values), set(range(11)))

    def test_stop_trial(self):
        # With runners & metrics, `BareBonesTestScheduler.run_all_trials` should run.
        scheduler = BareBonesTestScheduler(
            experiment=self.branin_experiment,  # Has runner and metrics.
            generation_strategy=self.two_sobol_steps_GS,
            options=SchedulerOptions(
                # Short wait between polls so test is fast.
                init_seconds_between_polls=0.1,
            ),
        )
        with patch.object(scheduler.experiment.runner,
                          "stop",
                          return_value=None) as mock_runner_stop:
            scheduler.run_n_trials(max_trials=1)
            scheduler.stop_trial_run(scheduler.experiment.trials[0])
            mock_runner_stop.assert_called_once()

    def test_run_all_trials_not_using_runner(self):
        # `TestScheduler` has `run_trial` and `fetch_trial_data` logic, so runner &
        # implemented metrics are not required.
        scheduler = TestScheduler(
            experiment=self.branin_experiment,
            generation_strategy=self.two_sobol_steps_GS,
            options=SchedulerOptions(
                total_trials=8,
                # No wait between polls so test is fast.
                init_seconds_between_polls=0,
            ),
        )
        self.branin_experiment.runner = None
        scheduler.run_all_trials()
        self.assertTrue(  # Make sure all trials got to complete.
            all(t.completed_successfully
                for t in scheduler.experiment.trials.values()))
        self.assertEqual(len(scheduler.experiment.trials), 8)
        # Check that all the data, fetched during optimization, was attached to the
        # experiment.
        dat = scheduler.experiment.fetch_data().df
        self.assertEqual(set(dat["trial_index"].values), set(range(8)))

    @patch(f"{Scheduler.__module__}.MAX_SECONDS_BETWEEN_POLLS", 2)
    def test_stop_at_MAX_SECONDS_BETWEEN_POLLS(self):
        class InfinitePollScheduler(BareBonesTestScheduler):
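            # Never reports any status changes, so trials never complete and
            # the scheduler keeps polling until it hits the timeout.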
            def poll_trial_status(self):
                return {}

        scheduler = InfinitePollScheduler(
            experiment=self.branin_experiment,  # Has runner and metrics.
            generation_strategy=self.two_sobol_steps_GS,
            options=SchedulerOptions(
                total_trials=8,
                # Poll every second so the test hits the patched cap quickly.
                init_seconds_between_polls=1,
            ),
        )
        with patch.object(scheduler,
                          "wait_for_completed_trials_and_report_results",
                          return_value=None) as mock_await_trials:
            # 4-second timeout.
            scheduler.run_all_trials(timeout_hours=1 / 60 / 15)
            # We should be calling `wait_for_completed_trials_and_report_results`
            # N = total runtime / `MAX_SECONDS_BETWEEN_POLLS` times.
            self.assertEqual(
                mock_await_trials.call_count,
                2,  # Timeout (4s) / MAX_SECONDS_BETWEEN_POLLS (2s as patched).
            )

    def test_timeout(self):
        # `TestScheduler` has `run_trial` and `fetch_trial_data` logic, so runner &
        # implemented metrics are not required.
        scheduler = TestScheduler(
            experiment=self.branin_experiment,
            generation_strategy=self.two_sobol_steps_GS,
            options=SchedulerOptions(
                total_trials=8,
                # No wait between polls so test is fast.
                init_seconds_between_polls=0,
            ),
        )
        # Force the optimization to time out.
        scheduler.run_all_trials(timeout_hours=0)
        self.assertEqual(len(scheduler.experiment.trials), 0)

    def test_logging(self):
        with NamedTemporaryFile() as temp_file:
            BareBonesTestScheduler(
                experiment=self.branin_experiment,
                generation_strategy=self.sobol_GPEI_GS,
                options=SchedulerOptions(
                    total_trials=1,
                    # No wait between polls so test is fast.
                    init_seconds_between_polls=0,
                    log_filepath=temp_file.name,
                ),
            ).run_all_trials()
            self.assertGreater(os.stat(temp_file.name).st_size, 0)
            self.assertIn("Running trials [0]", str(temp_file.readline()))
            temp_file.close()

    def test_logging_level(self):
        # This test produces no warnings, so logging at WARNING level should
        # not write anything to the log file.
        with NamedTemporaryFile() as temp_file:
            BareBonesTestScheduler(
                experiment=self.branin_experiment,
                generation_strategy=self.sobol_GPEI_GS,
                options=SchedulerOptions(
                    total_trials=3,
                    # No wait between polls so test is fast.
                    init_seconds_between_polls=0,
                    log_filepath=temp_file.name,
                    logging_level=WARNING,
                ),
            ).run_all_trials()
            # Ensure that the temp file remains empty
            self.assertEqual(os.stat(temp_file.name).st_size, 0)
            temp_file.close()

    def test_retries(self):
        # Check that retries will be performed for a retriable error.
        class BrokenSchedulerRuntimeError(BareBonesTestScheduler):

            run_trial_call_count = 0

            def run_trial(self, trial: BaseTrial) -> Dict[str, Any]:
                self.run_trial_call_count += 1
                raise RuntimeError("Failing for testing purposes.")

        scheduler = BrokenSchedulerRuntimeError(
            experiment=self.branin_experiment,
            generation_strategy=self.two_sobol_steps_GS,
            options=SchedulerOptions(total_trials=1),
        )
        # Should raise after 3 retries. Note the call-count check must live
        # outside the `with` block: code after the raising call inside it
        # would never execute.
        with self.assertRaisesRegex(RuntimeError, ".* testing .*"):
            scheduler.run_all_trials()
        self.assertEqual(scheduler.run_trial_call_count, 3)

    def test_retries_nonretriable_error(self):
        # Check that no retries will be performed for `ValueError`, since we
        # exclude it from the retriable errors.
        class BrokenSchedulerValueError(BareBonesTestScheduler):

            run_trial_call_count = 0

            def run_trial(self, trial: BaseTrial) -> Dict[str, Any]:
                self.run_trial_call_count += 1
                raise ValueError("Failing for testing purposes.")

        scheduler = BrokenSchedulerValueError(
            experiment=self.branin_experiment,
            generation_strategy=self.two_sobol_steps_GS,
            options=SchedulerOptions(total_trials=1),
        )
        # Should raise right away since ValueError is non-retriable. As above,
        # the call-count check must live outside the `with` block.
        with self.assertRaisesRegex(ValueError, ".* testing .*"):
            scheduler.run_all_trials()
        self.assertEqual(scheduler.run_trial_call_count, 1)

    def test_set_ttl(self):
        scheduler = TestScheduler(
            experiment=self.branin_experiment,
            generation_strategy=self.two_sobol_steps_GS,
            options=SchedulerOptions(
                total_trials=2,
                ttl_seconds_for_trials=1,
                # No wait between polls so test is fast.
                init_seconds_between_polls=0,
                min_seconds_before_poll=0.0,
            ),
        )
        scheduler.run_all_trials()
        self.assertTrue(
            all(t.ttl_seconds == 1
                for t in scheduler.experiment.trials.values()))

    @patch(
        f"{Scheduler.__module__}.START_CHECKING_FAILURE_RATE_AFTER_N_TRIALS",
        3)
    def test_failure_rate(self):
        class SchedulerWithFrequentFailedTrials(TestScheduler):

            poll_failed_next_time = True

            def poll_trial_status(self) -> Dict[TrialStatus, Set[int]]:
                running = [t.index for t in self.running_trials]
                status = (TrialStatus.FAILED if self.poll_failed_next_time else
                          TrialStatus.COMPLETED)
                # Poll different status next time.
                self.poll_failed_next_time = not self.poll_failed_next_time
                return {status: {running[randint(0, len(running) - 1)]}}

        scheduler = SchedulerWithFrequentFailedTrials(
            experiment=self.branin_experiment,
            generation_strategy=self.sobol_GS_no_parallelism,
            options=SchedulerOptions(
                total_trials=8,
                tolerated_trial_failure_rate=0.5,
                # No wait between polls so test is fast.
                init_seconds_between_polls=0,
            ),
        )
        scheduler.run_all_trials()
        # Trials will have statuses: 0, 2, 4 - FAILED, 1, 3 - COMPLETED. Failure rate
        # is 0.5, and we start checking failure rate after first 3 trials.
        # Therefore, failure rate should be exceeded after trial #4.
        self.assertEqual(len(scheduler.experiment.trials), 5)

        # If we set a slightly lower failure rate, it will be reached after 4 trials.
        num_preexisting_trials = len(scheduler.experiment.trials)
        scheduler = SchedulerWithFrequentFailedTrials(
            experiment=self.branin_experiment,
            generation_strategy=self.sobol_GS_no_parallelism,
            options=SchedulerOptions(
                total_trials=8,
                tolerated_trial_failure_rate=0.49,
                # No wait between polls so test is fast.
                init_seconds_between_polls=0,
            ),
        )
        self.assertEqual(scheduler._num_preexisting_trials,
                         num_preexisting_trials)
        scheduler.run_all_trials()
        self.assertEqual(len(scheduler.experiment.trials),
                         num_preexisting_trials + 4)

    def test_sqa_storage(self):
        init_test_engine_and_session_factory(force_init=True)
        config = SQAConfig()
        encoder = Encoder(config=config)
        decoder = Decoder(config=config)
        db_settings = DBSettings(encoder=encoder, decoder=decoder)
        experiment = self.branin_experiment
        # Scheduler currently requires that the experiment be pre-saved.
        with self.assertRaisesRegex(ValueError, ".* must specify a name"):
            experiment._name = None
            scheduler = TestScheduler(
                experiment=experiment,
                generation_strategy=self.two_sobol_steps_GS,
                options=SchedulerOptions(total_trials=1),
                db_settings=db_settings,
            )
        experiment._name = "test_experiment"
        NUM_TRIALS = 5
        scheduler = TestScheduler(
            experiment=experiment,
            generation_strategy=self.two_sobol_steps_GS,
            options=SchedulerOptions(
                total_trials=NUM_TRIALS,
                # No wait between polls so test is fast.
                init_seconds_between_polls=0,
            ),
            db_settings=db_settings,
        )
        # Check that experiment and GS were saved.
        exp, gs = scheduler._load_experiment_and_generation_strategy(
            experiment.name)
        self.assertEqual(exp, experiment)
        self.assertEqual(gs, self.two_sobol_steps_GS)
        scheduler.run_all_trials()
        # Check that experiment and GS were saved and test reloading with reduced state.
        exp, gs = scheduler._load_experiment_and_generation_strategy(
            experiment.name, reduced_state=True)
        self.assertEqual(len(exp.trials), NUM_TRIALS)
        self.assertEqual(len(gs._generator_runs), NUM_TRIALS)
        # Test `from_stored_experiment`.
        new_scheduler = TestScheduler.from_stored_experiment(
            experiment_name=experiment.name,
            options=SchedulerOptions(
                total_trials=NUM_TRIALS + 1,
                # No wait between polls so test is fast.
                init_seconds_between_polls=0,
            ),
            db_settings=db_settings,
        )
        # Hack "resumed from storage timestamp" into `exp` to make sure all other fields
        # are equal, since difference in resumed from storage timestamps is expected.
        exp._properties[
            ExperimentStatusProperties.
            RESUMED_FROM_STORAGE_TIMESTAMPS] = new_scheduler.experiment._properties[
                ExperimentStatusProperties.RESUMED_FROM_STORAGE_TIMESTAMPS]
        self.assertEqual(new_scheduler.experiment, exp)
        self.assertEqual(new_scheduler.generation_strategy, gs)
        self.assertEqual(
            len(new_scheduler.experiment._properties[
                ExperimentStatusProperties.RESUMED_FROM_STORAGE_TIMESTAMPS]),
            1,
        )

    def test_run_trials_and_yield_results(self):
        total_trials = 3
        scheduler = BareBonesTestScheduler(
            experiment=self.branin_experiment,  # Has runner and metrics.
            generation_strategy=self.two_sobol_steps_GS,
            options=SchedulerOptions(init_seconds_between_polls=0, ),
        )
        # `BareBonesTestScheduler.poll_trial_status` is written to mark one
        # trial as `COMPLETED` at a time, so we should be obtaining results
        # as many times as `total_trials` and yielding from generator after
        # obtaining each new result.
        res_list = list(
            scheduler.run_trials_and_yield_results(max_trials=total_trials))
        self.assertEqual(len(res_list), total_trials)
        self.assertIsInstance(res_list, list)
        self.assertDictEqual(
            res_list[0], {"trials_completed_so_far": set(range(total_trials))})

    def test_run_trials_and_yield_results_with_early_stopper(self):
        class EarlyStopsInsteadOfNormalCompletionScheduler(
                BareBonesTestScheduler):
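            # Never completes trials via polling; instead reports every trial
            # it is asked about as early-stopped (marked `COMPLETED`).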
            def poll_trial_status(self):
                return {}

            def should_stop_trials_early(self, trial_indices: Set[int]):
                return {TrialStatus.COMPLETED: trial_indices}

        total_trials = 3
        scheduler = EarlyStopsInsteadOfNormalCompletionScheduler(
            experiment=self.branin_experiment,  # Has runner and metrics.
            generation_strategy=self.two_sobol_steps_GS,
            options=SchedulerOptions(init_seconds_between_polls=0.1, ),
        )
        # All trials should be marked complete after one run.
        with patch.object(
                scheduler,
                "should_stop_trials_early",
                wraps=scheduler.should_stop_trials_early,
        ) as mock_should_stop_trials_early:
            res_list = list(
                scheduler.run_trials_and_yield_results(
                    max_trials=total_trials))
            # Two steps complete the experiment given parallelism.
            expected_num_polls = 2
            self.assertEqual(len(res_list), expected_num_polls)
            self.assertIsInstance(res_list, list)
            self.assertDictEqual(
                res_list[0],
                {"trials_completed_so_far": set(range(total_trials))})
            self.assertEqual(mock_should_stop_trials_early.call_count,
                             expected_num_polls)

    def test_run_trials_in_batches(self):
        with self.assertRaisesRegex(UnsupportedError,
                                    "only if `poll_available_capacity`"):
            scheduler = BareBonesTestScheduler(
                experiment=self.branin_experiment,  # Has runner and metrics.
                generation_strategy=self.two_sobol_steps_GS,
                options=SchedulerOptions(
                    init_seconds_between_polls=0,
                    run_trials_in_batches=True,
                ),
            )
            scheduler.run_n_trials(max_trials=3)

        class PollAvailableCapacityScheduler(BareBonesTestScheduler):
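            # Advertises capacity for two trials at a time, so running three
            # trials requires two dispatch batches (asserted below).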
            def poll_available_capacity(self):
                return 2

        scheduler = PollAvailableCapacityScheduler(
            experiment=self.branin_experiment,  # Has runner and metrics.
            generation_strategy=self.two_sobol_steps_GS,
            options=SchedulerOptions(
                init_seconds_between_polls=0,
                run_trials_in_batches=True,
            ),
        )

        with patch.object(scheduler,
                          "run_trials",
                          side_effect=scheduler.run_trials) as mock_run_trials:
            scheduler.run_n_trials(max_trials=3)
            # Trials should be dispatched twice, since a total of three trials
            # must be dispatched but capacity is limited to 2 per batch.
            self.assertEqual(mock_run_trials.call_count, ceil(3 / 2))

    def test_base_report_results(self):
        class NoReportResultsScheduler(Scheduler):
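            # On ~3 of 4 polls, completes one randomly chosen running trial;
            # otherwise reports no status changes.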
            def poll_trial_status(self) -> Dict[TrialStatus, Set[int]]:
                if randint(0, 3) > 0:
                    running = [t.index for t in self.running_trials]
                    return {
                        TrialStatus.COMPLETED: {running[randint(0, len(running) - 1)]}
                    }
                return {}

        scheduler = NoReportResultsScheduler(
            experiment=self.branin_experiment,  # Has runner and metrics.
            generation_strategy=self.two_sobol_steps_GS,
            options=SchedulerOptions(init_seconds_between_polls=0, ),
        )
        self.assertEqual(scheduler.run_n_trials(max_trials=3),
                         OptimizationResult())
Example #9
class TestAxScheduler(TestCase):
    """Tests base `Scheduler` functionality."""
    def setUp(self):
        self.branin_experiment = get_branin_experiment()
        self.runner = SyntheticRunnerWithStatusPolling()
        self.branin_experiment.runner = self.runner
        self.branin_experiment._properties[
            Keys.IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF] = True
        self.branin_experiment_no_impl_runner_or_metrics = Experiment(
            search_space=get_branin_search_space(),
            optimization_config=OptimizationConfig(objective=Objective(
                metric=Metric(name="branin"))),
        )
        self.sobol_GPEI_GS = choose_generation_strategy(
            search_space=get_branin_search_space())
        # Contrived GS to ensure that `DataRequiredError` is properly handled
        # in the scheduler. This error is raised when not enough trials have
        # been observed to proceed to the next generation step.
        self.two_sobol_steps_GS = GenerationStrategy(
            steps=[
                GenerationStep(
                    model=Models.SOBOL,
                    num_trials=5,
                    min_trials_observed=3,
                    max_parallelism=2,
                ),
                GenerationStep(model=Models.SOBOL,
                               num_trials=-1,
                               max_parallelism=3),
            ])
        # GS to force the scheduler to poll completed trials after each trial is run.
        self.sobol_GS_no_parallelism = GenerationStrategy(steps=[
            GenerationStep(
                model=Models.SOBOL, num_trials=-1, max_parallelism=1)
        ])

    def test_init(self):
        with self.assertRaisesRegex(
                UnsupportedError,
                "`Scheduler` requires that experiment specifies a `Runner`.",
        ):
            scheduler = Scheduler(
                experiment=self.branin_experiment_no_impl_runner_or_metrics,
                generation_strategy=self.sobol_GPEI_GS,
                options=SchedulerOptions(total_trials=10),
            )
        self.branin_experiment_no_impl_runner_or_metrics.runner = self.runner
        with self.assertRaisesRegex(
                UnsupportedError,
                ".*Metrics {'branin'} do not implement fetching logic.",
        ):
            scheduler = Scheduler(
                experiment=self.branin_experiment_no_impl_runner_or_metrics,
                generation_strategy=self.sobol_GPEI_GS,
                options=SchedulerOptions(total_trials=10),
            )
        scheduler = Scheduler(
            experiment=self.branin_experiment,
            generation_strategy=self.sobol_GPEI_GS,
            options=SchedulerOptions(
                total_trials=0,
                tolerated_trial_failure_rate=0.2,
                init_seconds_between_polls=10,
            ),
        )
        self.assertEqual(scheduler.experiment, self.branin_experiment)
        self.assertEqual(scheduler.generation_strategy, self.sobol_GPEI_GS)
        self.assertEqual(scheduler.options.total_trials, 0)
        self.assertEqual(scheduler.options.tolerated_trial_failure_rate, 0.2)
        self.assertEqual(scheduler.options.init_seconds_between_polls, 10)
        self.assertIsNone(scheduler._latest_optimization_start_timestamp)
        for status_prop in ExperimentStatusProperties:
            self.assertEqual(
                scheduler.experiment._properties[status_prop.value], [])
        scheduler.run_all_trials()  # Runs no trials since total trials is 0.
        # `_latest_optimization_start_timestamp` should be set now.
        self.assertLessEqual(
            scheduler._latest_optimization_start_timestamp,
            current_timestamp_in_millis(),
        )

    def test_repr(self):
        scheduler = Scheduler(
            experiment=self.branin_experiment,
            generation_strategy=self.sobol_GPEI_GS,
            options=SchedulerOptions(
                total_trials=0,
                tolerated_trial_failure_rate=0.2,
                init_seconds_between_polls=10,
            ),
        )
        self.assertEqual(
            f"{scheduler}",
            ("Scheduler(experiment=Experiment(branin_test_experiment), "
             "generation_strategy=GenerationStrategy(name='Sobol+GPEI', "
             "steps=[Sobol for 5 trials, GPEI for subsequent trials]), "
             "options=SchedulerOptions(max_pending_trials=10, "
             "trial_type=<class 'ax.core.trial.Trial'>, "
             "total_trials=0, tolerated_trial_failure_rate=0.2, "
             "min_failed_trials_for_failure_rate_check=5, log_filepath=None, "
             "logging_level=20, ttl_seconds_for_trials=None, init_seconds_between_"
             "polls=10, min_seconds_before_poll=1.0, seconds_between_polls_backoff_"
             "factor=1.5, run_trials_in_batches=False, "
             "debug_log_run_metadata=False, early_stopping_strategy=None, "
             "suppress_storage_errors_after_retries=False))"),
        )

    def test_validate_early_stopping_strategy(self):
        class DummyEarlyStoppingStrategy(BaseEarlyStoppingStrategy):
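            # A no-op strategy that never early-stops trials; used below to
            # check that scheduler validation rejects early stopping when
            # metrics cannot be fetched while trials are running.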
            def should_stop_trials_early(
                self,
                trial_indices: Set[int],
                experiment: Experiment,
                **kwargs: Dict[str, Any],
            ) -> Set[int]:
                return set()  # Empty set: no trials to stop.

        with patch(
                f"{BraninMetric.__module__}.BraninMetric.is_available_while_running",
                return_value=False,
        ), self.assertRaises(ValueError):
            Scheduler(
                experiment=self.branin_experiment,
                generation_strategy=self.sobol_GPEI_GS,
                options=SchedulerOptions(
                    early_stopping_strategy=DummyEarlyStoppingStrategy()),
            )

        # should not error
        Scheduler(
            experiment=self.branin_experiment,
            generation_strategy=self.sobol_GPEI_GS,
            options=SchedulerOptions(
                early_stopping_strategy=DummyEarlyStoppingStrategy()),
        )

    @patch(
        f"{GenerationStrategy.__module__}.GenerationStrategy._gen_multiple",
        return_value=[get_generator_run()],
    )
    def test_run_multi_arm_generator_run_error(self, mock_gen):
        scheduler = Scheduler(
            experiment=self.branin_experiment,
            generation_strategy=self.sobol_GPEI_GS,
            options=SchedulerOptions(total_trials=1),
        )
        with self.assertRaisesRegex(SchedulerInternalError,
                                    ".* only one was expected"):
            scheduler.run_all_trials()

    @patch(
        # Record calls to function, but still execute it.
        (f"{Scheduler.__module__}."
         "get_pending_observation_features_based_on_trial_status"),
        side_effect=get_pending_observation_features_based_on_trial_status,
    )
    def test_run_all_trials_using_runner_and_metrics(self, mock_get_pending):
        # With runners & metrics, `Scheduler.run_all_trials` should run.
        scheduler = Scheduler(
            experiment=self.branin_experiment,  # Has runner and metrics.
            generation_strategy=self.two_sobol_steps_GS,
            options=SchedulerOptions(
                total_trials=8,
                # Short wait between polls so test is fast.
                init_seconds_between_polls=0.1,
            ),
        )
        scheduler.run_all_trials()
        # Check that we fetched pending features at least 8 times (once for
        # each new trial, and possibly more where we tried to generate trials
        # but hit the parallelism limit; polling of trial statuses is
        # randomized, so some trials might not have come back yet).
        self.assertGreaterEqual(len(mock_get_pending.call_args_list), 8)
        self.assertTrue(  # Make sure all trials got to complete.
            all(t.completed_successfully
                for t in scheduler.experiment.trials.values()))
        self.assertEqual(len(scheduler.experiment.trials), 8)
        # Check that all the data, fetched during optimization, was attached to the
        # experiment.
        dat = scheduler.experiment.fetch_data().df
        self.assertEqual(set(dat["trial_index"].values), set(range(8)))
        self.assertEqual(
            scheduler.experiment._properties[
                ExperimentStatusProperties.RUN_TRIALS_STATUS],
            ["started", "success"],
        )
        self.assertEqual(
            scheduler.experiment._properties[
                ExperimentStatusProperties.NUM_TRIALS_RUN_PER_CALL],
            [8],
        )
        self.assertEqual(
            scheduler.experiment._properties[
                ExperimentStatusProperties.RESUMED_FROM_STORAGE_TIMESTAMPS],
            [],
        )

    def test_run_n_trials(self):
        # With runners & metrics, `Scheduler.run_all_trials` should run.
        scheduler = Scheduler(
            experiment=self.branin_experiment,  # Has runner and metrics.
            generation_strategy=self.two_sobol_steps_GS,
            options=SchedulerOptions(
                # Short wait between polls so test is fast.
                init_seconds_between_polls=0.1,
            ),
        )
        scheduler.run_n_trials(max_trials=1)
        self.assertEqual(len(scheduler.experiment.trials), 1)
        scheduler.run_n_trials(max_trials=10)
        self.assertTrue(  # Make sure all trials got to complete.
            all(t.completed_successfully
                for t in scheduler.experiment.trials.values()))
        # Check that all the data, fetched during optimization, was attached to the
        # experiment.
        dat = scheduler.experiment.fetch_data().df
        self.assertEqual(set(dat["trial_index"].values), set(range(11)))

    def test_run_preattached_trials_only(self):
        # assert that pre-attached trials run when max_trials = number of
        # pre-attached trials
        scheduler = Scheduler(
            experiment=self.branin_experiment,  # Has runner and metrics.
            generation_strategy=self.two_sobol_steps_GS,
            options=SchedulerOptions(
                # Short wait between polls so test is fast.
                init_seconds_between_polls=0.1,
            ),
        )
        trial = scheduler.experiment.new_trial()
        parameter_dict = {"x1": 5, "x2": 5}
        trial.add_arm(Arm(parameters=parameter_dict))
        with self.assertRaisesRegex(
                UserInputError,
                "number of pre-attached candidate trials .* is greater than"):
            scheduler.run_n_trials(max_trials=0)
        scheduler.run_n_trials(max_trials=1)
        self.assertEqual(len(scheduler.experiment.trials), 1)
        self.assertDictEqual(scheduler.experiment.trials[0].arm.parameters,
                             parameter_dict)
        self.assertTrue(  # Make sure all trials got to complete.
            all(t.completed_successfully
                for t in scheduler.experiment.trials.values()))

    def test_stop_trial(self):
        # With runners & metrics, `Scheduler.run_all_trials` should run.
        scheduler = Scheduler(
            experiment=self.branin_experiment,  # Has runner and metrics.
            generation_strategy=self.two_sobol_steps_GS,
            options=SchedulerOptions(
                # Short wait between polls so test is fast.
                init_seconds_between_polls=0.1,
            ),
        )
        with patch.object(scheduler.experiment.runner,
                          "stop",
                          return_value=None) as mock_runner_stop:
            scheduler.run_n_trials(max_trials=1)
            scheduler.stop_trial_runs(trials=[scheduler.experiment.trials[0]])
            mock_runner_stop.assert_called_once()

    @patch(f"{Scheduler.__module__}.MAX_SECONDS_BETWEEN_POLLS", 2)
    def test_stop_at_MAX_SECONDS_BETWEEN_POLLS(self):
        self.branin_experiment.runner = InfinitePollRunner()
        scheduler = Scheduler(
            experiment=self.branin_experiment,  # Has runner and metrics.
            generation_strategy=self.two_sobol_steps_GS,
            options=SchedulerOptions(
                total_trials=8,
                # Poll every second so the test hits the patched cap quickly.
                init_seconds_between_polls=1,
            ),
        )
        with patch.object(scheduler,
                          "wait_for_completed_trials_and_report_results",
                          return_value=None) as mock_await_trials:
            # 4-second timeout.
            scheduler.run_all_trials(timeout_hours=1 / 60 / 15)
            # We should be calling `wait_for_completed_trials_and_report_results`
            # N = total runtime / `MAX_SECONDS_BETWEEN_POLLS` times.
            self.assertEqual(
                mock_await_trials.call_count,
                2,  # Timeout (4s) / MAX_SECONDS_BETWEEN_POLLS (2s as patched).
            )

    def test_timeout(self):
        scheduler = Scheduler(
            experiment=self.branin_experiment,
            generation_strategy=self.two_sobol_steps_GS,
            options=SchedulerOptions(
                total_trials=8,
                # No wait between polls so test is fast.
                init_seconds_between_polls=0,
            ),
        )
        # Force the optimization to time out.
        scheduler.run_all_trials(timeout_hours=0)
        self.assertEqual(len(scheduler.experiment.trials), 0)
        self.assertIn("aborted",
                      scheduler.experiment._properties["run_trials_success"])

    def test_logging(self):
        with NamedTemporaryFile() as temp_file:
            Scheduler(
                experiment=self.branin_experiment,
                generation_strategy=self.sobol_GPEI_GS,
                options=SchedulerOptions(
                    total_trials=1,
                    # No wait between polls so test is fast.
                    init_seconds_between_polls=0,
                    log_filepath=temp_file.name,
                ),
            ).run_all_trials()
            self.assertGreater(os.stat(temp_file.name).st_size, 0)
            self.assertIn("Running trials [0]", str(temp_file.readline()))
            temp_file.close()

    def test_logging_level(self):
        # This test produces no warnings, so logging at WARNING level should
        # not write anything to the log file.
        with NamedTemporaryFile() as temp_file:
            Scheduler(
                experiment=self.branin_experiment,
                generation_strategy=self.sobol_GPEI_GS,
                options=SchedulerOptions(
                    total_trials=3,
                    # No wait between polls so test is fast.
                    init_seconds_between_polls=0,
                    log_filepath=temp_file.name,
                    logging_level=WARNING,
                ),
            ).run_all_trials()
            # Ensure that the temp file remains empty
            self.assertEqual(os.stat(temp_file.name).st_size, 0)
            temp_file.close()

    def test_retries(self):
        # Check that retries will be performed for a retriable error.
        self.branin_experiment.runner = BrokenRunnerRuntimeError()
        scheduler = Scheduler(
            experiment=self.branin_experiment,
            generation_strategy=self.two_sobol_steps_GS,
            options=SchedulerOptions(total_trials=1),
        )
        # Should raise after 3 retries. The count check must live outside the
        # `with` block (code after the raising call would never execute); it
        # assumes `BrokenRunnerRuntimeError` tracks `run_trial_call_count`
        # like its scheduler counterpart in the earlier example.
        with self.assertRaisesRegex(RuntimeError, ".* testing .*"):
            scheduler.run_all_trials()
        self.assertEqual(
            self.branin_experiment.runner.run_trial_call_count, 3)

    def test_retries_nonretriable_error(self):
        # Check that no retries will be performed for `ValueError`, since we
        # exclude it from the retriable errors.
        self.branin_experiment.runner = BrokenRunnerValueError()
        scheduler = Scheduler(
            experiment=self.branin_experiment,
            generation_strategy=self.two_sobol_steps_GS,
            options=SchedulerOptions(total_trials=1),
        )
        # Should raise right away since ValueError is non-retriable. As above,
        # assumes the broken runner tracks `run_trial_call_count`.
        with self.assertRaisesRegex(ValueError, ".* testing .*"):
            scheduler.run_all_trials()
        self.assertEqual(
            self.branin_experiment.runner.run_trial_call_count, 1)

    def test_set_ttl(self):
        scheduler = Scheduler(
            experiment=self.branin_experiment,
            generation_strategy=self.two_sobol_steps_GS,
            options=SchedulerOptions(
                total_trials=2,
                ttl_seconds_for_trials=1,
                # No wait between polls so test is fast.
                init_seconds_between_polls=0,
                min_seconds_before_poll=0.0,
            ),
        )
        scheduler.run_all_trials()
        self.assertTrue(
            all(t.ttl_seconds == 1
                for t in scheduler.experiment.trials.values()))

    def test_failure_rate(self):
        options = SchedulerOptions(
            total_trials=8,
            tolerated_trial_failure_rate=0.5,
            # No wait between polls so test is fast.
            init_seconds_between_polls=0,
            min_failed_trials_for_failure_rate_check=2,
        )

        self.branin_experiment.runner = RunnerWithFrequentFailedTrials()
        scheduler = Scheduler(
            experiment=self.branin_experiment,
            generation_strategy=self.sobol_GS_no_parallelism,
            options=options,
        )
        with self.assertRaises(FailureRateExceededError):
            scheduler.run_all_trials()
        # Trials will have statuses: 0, 2 - FAILED, 1 - COMPLETED. Failure rate
        # is 0.5, and so if 2 of the first 3 trials are failed, we can fail
        # immediately.
        self.assertEqual(len(scheduler.experiment.trials), 3)

        # If all trials fail, we can be certain that the sweep will
        # fail after only 2 trials.
        num_preexisting_trials = len(scheduler.experiment.trials)
        self.branin_experiment.runner = RunnerWithAllFailedTrials()
        scheduler = Scheduler(
            experiment=self.branin_experiment,
            generation_strategy=self.sobol_GS_no_parallelism,
            options=options,
        )
        self.assertEqual(scheduler._num_preexisting_trials,
                         num_preexisting_trials)
        with self.assertRaises(FailureRateExceededError):
            scheduler.run_all_trials()
        self.assertEqual(len(scheduler.experiment.trials),
                         num_preexisting_trials + 2)

    def test_sqa_storage(self):
        init_test_engine_and_session_factory(force_init=True)
        config = SQAConfig()
        encoder = Encoder(config=config)
        decoder = Decoder(config=config)
        db_settings = DBSettings(encoder=encoder, decoder=decoder)
        experiment = self.branin_experiment
        # Scheduler currently requires that the experiment be pre-saved.
        with self.assertRaisesRegex(ValueError, ".* must specify a name"):
            experiment._name = None
            scheduler = Scheduler(
                experiment=experiment,
                generation_strategy=self.two_sobol_steps_GS,
                options=SchedulerOptions(total_trials=1),
                db_settings=db_settings,
            )
        experiment._name = "test_experiment"
        NUM_TRIALS = 5
        scheduler = Scheduler(
            experiment=experiment,
            generation_strategy=self.two_sobol_steps_GS,
            options=SchedulerOptions(
                total_trials=NUM_TRIALS,
                # No wait between polls so test is fast.
                init_seconds_between_polls=0,
            ),
            db_settings=db_settings,
        )
        # Check that experiment and GS were saved.
        exp, gs = scheduler._load_experiment_and_generation_strategy(
            experiment.name)
        self.assertEqual(exp, experiment)
        self.assertEqual(gs, self.two_sobol_steps_GS)
        scheduler.run_all_trials()
        # Check that experiment and GS were saved and test reloading with reduced state.
        exp, gs = scheduler._load_experiment_and_generation_strategy(
            experiment.name, reduced_state=True)
        self.assertEqual(len(exp.trials), NUM_TRIALS)
        self.assertEqual(len(gs._generator_runs), NUM_TRIALS)
        # Test `from_stored_experiment`.
        new_scheduler = Scheduler.from_stored_experiment(
            experiment_name=experiment.name,
            options=SchedulerOptions(
                total_trials=NUM_TRIALS + 1,
                # No wait between polls so test is fast.
                init_seconds_between_polls=0,
            ),
            db_settings=db_settings,
        )
        # Hack "resumed from storage timestamp" into `exp` to make sure all other fields
        # are equal, since difference in resumed from storage timestamps is expected.
        exp._properties[
            ExperimentStatusProperties.
            RESUMED_FROM_STORAGE_TIMESTAMPS] = new_scheduler.experiment._properties[
                ExperimentStatusProperties.RESUMED_FROM_STORAGE_TIMESTAMPS]
        self.assertEqual(new_scheduler.experiment, exp)
        self.assertEqual(new_scheduler.generation_strategy, gs)
        self.assertEqual(
            len(new_scheduler.experiment._properties[
                ExperimentStatusProperties.RESUMED_FROM_STORAGE_TIMESTAMPS]),
            1,
        )

    def test_run_trials_and_yield_results(self):
        total_trials = 3
        scheduler = TestScheduler(
            experiment=self.branin_experiment,  # Has runner and metrics.
            generation_strategy=self.two_sobol_steps_GS,
            options=SchedulerOptions(init_seconds_between_polls=0, ),
        )
        # Trial-status polling here is set up to mark one
        # trial as `COMPLETED` at a time, so we should be obtaining results
        # as many times as `total_trials` and yielding from generator after
        # obtaining each new result.
        res_list = list(
            scheduler.run_trials_and_yield_results(max_trials=total_trials))
        self.assertEqual(len(res_list), total_trials + 1)
        self.assertEqual(len(res_list[0]["trials_completed_so_far"]), 1)
        self.assertEqual(len(res_list[1]["trials_completed_so_far"]), 2)
        self.assertEqual(len(res_list[2]["trials_completed_so_far"]), 3)

    def test_run_trials_and_yield_results_with_early_stopper(self):
        total_trials = 3
        self.branin_experiment.runner = InfinitePollRunner()
        scheduler = EarlyStopsInsteadOfNormalCompletionScheduler(
            experiment=self.branin_experiment,  # Has runner and metrics.
            generation_strategy=self.two_sobol_steps_GS,
            options=SchedulerOptions(init_seconds_between_polls=0.1, ),
        )
        # All trials should be marked complete after one run.
        with patch.object(
                scheduler,
                "should_stop_trials_early",
                wraps=scheduler.should_stop_trials_early,
        ) as mock_should_stop_trials_early, patch.object(
                scheduler, "stop_trial_runs",
                return_value=None) as mock_stop_trial_runs:
            res_list = list(
                scheduler.run_trials_and_yield_results(
                    max_trials=total_trials))
            # Two steps complete the experiment given parallelism.
            expected_num_polls = 2
            self.assertEqual(len(res_list), expected_num_polls + 1)
            # Both trials in first batch of parallelism will be early stopped
            self.assertEqual(len(res_list[0]["trials_early_stopped_so_far"]),
                             2)
            # Third trial in second batch of parallelism will be early stopped
            self.assertEqual(len(res_list[1]["trials_early_stopped_so_far"]),
                             3)
            self.assertEqual(mock_should_stop_trials_early.call_count,
                             expected_num_polls)
            self.assertEqual(mock_stop_trial_runs.call_count,
                             expected_num_polls)

    def test_scheduler_with_odd_index_early_stopping_strategy(self):
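        """Check that a custom early-stopping strategy passed via
        `SchedulerOptions` is applied: trials with odd indices are early
        stopped, while the others run to completion."""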
        total_trials = 3

        class OddIndexEarlyStoppingStrategy(BaseEarlyStoppingStrategy):
            """Dummy strategy that early-stops trials with odd indices; with
            3 total trials, only trial #1 should be early stopped."""
            def should_stop_trials_early(
                self,
                trial_indices: Set[int],
                experiment: Experiment,
                **kwargs: Any,
            ) -> Dict[int, Optional[str]]:
                # Make sure that we can look up data for the trial, even
                # though we won't use it in this dummy strategy.
                data = experiment.lookup_data(trial_indices=trial_indices)
                if data.df.empty:
                    raise Exception(
                        f"No data found for trials {trial_indices}; "
                        "can't determine whether or not to stop early.")
                return {idx: None for idx in trial_indices if idx % 2 == 1}

        self.branin_experiment.runner = RunnerWithEarlyStoppingStrategy()
        scheduler = TestScheduler(
            experiment=self.branin_experiment,
            generation_strategy=self.two_sobol_steps_GS,
            options=SchedulerOptions(
                init_seconds_between_polls=0.1,
                early_stopping_strategy=OddIndexEarlyStoppingStrategy(),
            ),
        )
        with patch.object(scheduler, "stop_trial_runs",
                          return_value=None) as mock_stop_trial_runs:
            res_list = list(
                scheduler.run_trials_and_yield_results(
                    max_trials=total_trials))
            expected_num_steps = 2
            self.assertEqual(len(res_list), expected_num_steps + 1)
            # Trial #1 early stopped in first step
            self.assertEqual(res_list[0]["trials_early_stopped_so_far"], {1})
            # All trials completed by end of second step
            self.assertEqual(res_list[1]["trials_early_stopped_so_far"], {1})
            self.assertEqual(res_list[1]["trials_completed_so_far"], {0, 2})
            self.assertEqual(mock_stop_trial_runs.call_count,
                             expected_num_steps)

    def test_run_trials_in_batches(self):
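        """Check that with `run_trials_in_batches=True`, trials are
        dispatched in batches whose size is limited by
        `poll_available_capacity`."""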
        # TODO[drfreund]: Use `Runner` instead when `poll_available_capacity`
        # is moved to `Runner`
        class PollAvailableCapacityScheduler(Scheduler):
            def poll_available_capacity(self):
                return 2

        scheduler = PollAvailableCapacityScheduler(
            experiment=self.branin_experiment,  # Has runner and metrics.
            generation_strategy=self.two_sobol_steps_GS,
            options=SchedulerOptions(
                init_seconds_between_polls=0,
                run_trials_in_batches=True,
            ),
        )

        with patch.object(scheduler,
                          "run_trials",
                          side_effect=scheduler.run_trials) as mock_run_trials:
            scheduler.run_n_trials(max_trials=3)
            # Trials should be dispatched twice, as total of three trials
            # should be dispatched but capacity is limited to 2.
            self.assertEqual(mock_run_trials.call_count, ceil(3 / 2))

    def test_base_report_results(self):
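        """Check that `run_n_trials` returns a default `OptimizationResult`
        when using a runner that does not report results."""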
        self.branin_experiment.runner = NoReportResultsRunner()
        scheduler = Scheduler(
            experiment=self.branin_experiment,  # Has runner and metrics.
            generation_strategy=self.two_sobol_steps_GS,
            options=SchedulerOptions(init_seconds_between_polls=0),
        )
        self.assertEqual(scheduler.run_n_trials(max_trials=3),
                         OptimizationResult())

    @patch(
        f"{GenerationStrategy.__module__}.GenerationStrategy._gen_multiple",
        side_effect=OptimizationComplete("test error"),
    )
    def test_optimization_complete(self, _):
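        """Check that no trials are run when generation raises
        `OptimizationComplete`."""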
        # With runners & metrics, `Scheduler.run_n_trials` should run.
        scheduler = Scheduler(
            experiment=self.branin_experiment,  # Has runner and metrics.
            generation_strategy=self.two_sobol_steps_GS,
            options=SchedulerOptions(
                max_pending_trials=100,
                # Short wait between polls so test is fast.
                init_seconds_between_polls=0.1,
            ),
        )
        scheduler.run_n_trials(max_trials=1)
        # No trials should run if `_gen_multiple` raises `OptimizationComplete`.
        self.assertEqual(len(scheduler.experiment.trials), 0)

    @patch((f"{WithDBSettingsBase.__module__}.WithDBSettingsBase."
            "_save_generation_strategy_to_db_if_possible"))
    @patch(f"{WithDBSettingsBase.__module__}._save_experiment",
           side_effect=StaleDataError)
    def test_suppress_all_storage_errors(self, mock_save_exp, _):
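        """Check that with `suppress_storage_errors_after_retries=True`,
        storage errors such as `StaleDataError` are retried and then
        suppressed instead of raised."""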
        init_test_engine_and_session_factory(force_init=True)
        config = SQAConfig()
        encoder = Encoder(config=config)
        decoder = Decoder(config=config)
        db_settings = DBSettings(encoder=encoder, decoder=decoder)
        Scheduler(
            experiment=self.branin_experiment,  # Has runner and metrics.
            generation_strategy=self.two_sobol_steps_GS,
            options=SchedulerOptions(
                max_pending_trials=100,
                # Short wait between polls so test is fast.
                init_seconds_between_polls=0.1,
                suppress_storage_errors_after_retries=True,
            ),
            db_settings=db_settings,
        )
        self.assertEqual(mock_save_exp.call_count, 3)

    def test_max_pending_trials(self):
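        """Check that `max_pending_trials=1` limits parallelism so that
        trials are scheduled and completed one at a time."""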
        # With runners & metrics, `TestScheduler.run_trials_and_yield_results`
        # should run.
        scheduler = TestScheduler(
            experiment=self.branin_experiment,  # Has runner and metrics.
            generation_strategy=self.sobol_GPEI_GS,
            options=SchedulerOptions(
                max_pending_trials=1,
                # Short wait between polls so test is fast.
                init_seconds_between_polls=0.1,
            ),
        )
        for idx, _ in enumerate(
                scheduler.run_trials_and_yield_results(max_trials=3)):
            # Trials should be scheduled one-at-a-time w/ parallelism limit of 1.
            self.assertEqual(len(self.branin_experiment.trials),
                             idx + 1 if idx < 3 else idx)
            # Trials also should be getting completed one-at-a-time.
            self.assertEqual(
                len(self.branin_experiment.trial_indices_by_status[
                    TrialStatus.COMPLETED]),
                idx + 1 if idx < 3 else idx,
            )