示例#1
0
 def test_storage_error_handling(self, mock_save_fails):
     """Check that if `suppress_storage_errors` is True, AxClient won't
     visibly fail if encountered storage errors.

     `mock_save_fails` is injected by a decorator that is not visible here;
     presumably it patches a storage save call so that it raises (TODO
     confirm). The test drives three full trials through a client with
     `suppress_storage_errors=True`. It then checks that every trial is
     still tracked on the client's experiment, so that "didn't crash" is
     not the only thing being verified.
     """
     init_test_engine_and_session_factory(force_init=True)
     config = SQAConfig()
     encoder = Encoder(config=config)
     decoder = Decoder(config=config)
     db_settings = DBSettings(encoder=encoder, decoder=decoder)
     ax_client = AxClient(db_settings=db_settings,
                          suppress_storage_errors=True)
     ax_client.create_experiment(
         name="test_experiment",
         parameters=[
             {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
             {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
         ],
         minimize=True,
     )
     num_trials = 3
     for _ in range(num_trials):
         parameters, trial_index = ax_client.get_next_trial()
         ax_client.complete_trial(trial_index=trial_index,
                                  raw_data=branin(*parameters.values()))
     # Reaching this line means no storage error escaped. Additionally confirm
     # the trials were all recorded despite any suppressed storage errors.
     self.assertEqual(len(ax_client.experiment.trials), num_trials)
示例#2
0
    def test_save_load_experiment(self):
        """Round-trip an experiment through the DB, and check that loading an
        experiment built by `get_simple_experiment` fails with a
        "Service API only" ValueError.
        """
        exp = get_experiment()
        # Pass `force_init=True` like every other storage test in this suite, so
        # an engine/session factory set up by an earlier test is re-created
        # instead of being reused (assumed semantics of `force_init` — confirm).
        init_test_engine_and_session_factory(force_init=True)
        db_settings = DBSettings(url="sqlite://")
        save_experiment(exp, db_settings)
        load_experiment(exp.name, db_settings)

        simple_experiment = get_simple_experiment()
        save_experiment(simple_experiment, db_settings)
        # Saving succeeds, but loading this kind of experiment must be refused.
        with self.assertRaisesRegex(ValueError, "Service API only"):
            load_experiment(simple_experiment.name, db_settings)
示例#3
0
 def test_sqa_storage(self):
     """Save an AxClient experiment to SQA storage and reload it in a new
     client. Check that the restored generation strategy equals the original,
     and that the stored experiment cannot be overwritten through
     `create_experiment`, with or without `overwrite_existing_experiment`.
     """

     def branin_parameters():
         # Build a fresh list on every call so no `create_experiment` call can
         # observe mutations made during an earlier one.
         return [
             {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
             {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
         ]

     init_test_engine_and_session_factory(force_init=True)
     config = SQAConfig()
     encoder = Encoder(config=config)
     decoder = Decoder(config=config)
     db_settings = DBSettings(encoder=encoder, decoder=decoder)
     ax_client = AxClient(db_settings=db_settings)
     ax_client.create_experiment(
         name="test_experiment",
         parameters=branin_parameters(),
         minimize=True,
     )
     for _ in range(5):
         parameters, trial_index = ax_client.get_next_trial()
         ax_client.complete_trial(
             trial_index=trial_index, raw_data=branin(*parameters.values())
         )
     gs = ax_client.generation_strategy
     ax_client = AxClient(db_settings=db_settings)
     ax_client.load_experiment_from_database("test_experiment")
     # Trial #4 was completed after the last time the generation strategy
     # generated candidates, so the pre-save generation strategy was not
     # "aware" that trial #4 completed. The post-restoration generation
     # strategy is aware of it, because it is restored with the most
     # up-to-date experiment data. So we add trial #4 to the seen completed
     # trials of the pre-storage GS; otherwise the two would not compare equal.
     gs._seen_trial_indices_by_status[TrialStatus.COMPLETED].add(4)
     self.assertEqual(gs, ax_client.generation_strategy)
     with self.assertRaises(ValueError):
         # Overwriting existing experiment.
         ax_client.create_experiment(
             name="test_experiment",
             parameters=branin_parameters(),
             minimize=True,
         )
     with self.assertRaises(ValueError):
         # Overwriting existing experiment with overwrite flag with present
         # DB settings. This should fail as we no longer allow overwriting
         # experiments stored in the DB.
         ax_client.create_experiment(
             name="test_experiment",
             parameters=[{"name": "x", "type": "range", "bounds": [-5.0, 10.0]}],
             overwrite_existing_experiment=True,
         )
     # Original experiment should still be in DB and not have been overwritten.
     self.assertEqual(len(ax_client.experiment.trials), 5)
示例#4
0
    def setUp(self):
        # Build a generation strategy that is attached to an experiment, and
        # expose both on the test case.
        gs = get_generation_strategy(with_experiment=True)
        self.generation_strategy = gs
        self.experiment = gs.experiment

        # Point storage at an in-memory SQLite DB, then persist the experiment
        # and the generation strategy so each test starts from stored state.
        init_test_engine_and_session_factory(force_init=True)
        settings = DBSettings(url="sqlite://")
        self.with_db_settings = WithDBSettingsBase(db_settings=settings)
        encoder = self.with_db_settings.db_settings.encoder
        _save_experiment(self.experiment, encoder=encoder)
        _save_generation_strategy(generation_strategy=gs, encoder=encoder)
示例#5
0
    def test_save_load_experiment(self):
        """Round-trip an experiment using explicitly configured encoder and
        decoder. Then check that loading a simple experiment raises the
        "Service API only" ValueError.
        """
        experiment = get_experiment()
        init_test_engine_and_session_factory(force_init=True)
        # Each codec receives its own SQAConfig instance.
        encoder = Encoder(config=SQAConfig())
        decoder = Decoder(config=SQAConfig())
        db_settings = DBSettings(encoder=encoder, decoder=decoder, creator=None)
        save_experiment(experiment, db_settings)
        load_experiment(experiment.name, db_settings)

        simple = get_simple_experiment()
        save_experiment(simple, db_settings)
        with self.assertRaisesRegex(ValueError, "Service API only"):
            load_experiment(simple.name, db_settings)
示例#6
0
 def test_suppress_all_storage_errors(self, mock_save_exp, _):
     """Scheduler construction must not raise when storage errors are
     suppressed after retries.

     `mock_save_exp` comes from a decorator that is not shown here; presumably
     it makes experiment saving fail (TODO confirm). The test expects exactly
     three save attempts.
     """
     init_test_engine_and_session_factory(force_init=True)
     sqa_config = SQAConfig()
     db_settings = DBSettings(
         encoder=Encoder(config=sqa_config),
         decoder=Decoder(config=sqa_config),
     )
     options = SchedulerOptions(
         init_seconds_between_polls=0.1,  # Short between polls so test is fast.
         suppress_storage_errors_after_retries=True,
     )
     BareBonesTestScheduler(
         experiment=self.branin_experiment,  # Has runner and metrics.
         generation_strategy=self.two_sobol_steps_GS,
         options=options,
         db_settings=db_settings,
     )
     self.assertEqual(mock_save_exp.call_count, 3)
示例#7
0
 def test_sqa_storage(self):
     """Persist an AxClient experiment and reload it into a new client. Verify
     the stored experiment cannot be replaced through `create_experiment`.
     """

     def search_space():
         # Two-parameter Branin domain, rebuilt on every call.
         return [
             {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
             {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
         ]

     init_test_engine_and_session_factory(force_init=True)
     sqa_config = SQAConfig()
     db_settings = DBSettings(
         encoder=Encoder(config=sqa_config),
         decoder=Decoder(config=sqa_config),
     )
     client = AxClient(db_settings=db_settings)
     client.create_experiment(
         name="test_experiment", parameters=search_space(), minimize=True
     )
     trial_count = 5
     for _ in range(trial_count):
         params, idx = client.get_next_trial()
         client.complete_trial(trial_index=idx, raw_data=branin(*params.values()))
     original_gs = client.generation_strategy

     # A brand-new client reading the same DB should restore an equal GS.
     client = AxClient(db_settings=db_settings)
     client.load_experiment_from_database("test_experiment")
     self.assertEqual(original_gs, client.generation_strategy)

     # Overwriting existing experiment.
     with self.assertRaises(ValueError):
         client.create_experiment(
             name="test_experiment", parameters=search_space(), minimize=True
         )
     # Overwriting existing experiment with overwrite flag with present
     # DB settings. This should fail as we no longer allow overwriting
     # experiments stored in the DB.
     with self.assertRaises(ValueError):
         client.create_experiment(
             name="test_experiment",
             parameters=[{"name": "x", "type": "range", "bounds": [-5.0, 10.0]}],
             overwrite_existing_experiment=True,
         )
     # Original experiment should still be in DB and not have been overwritten.
     self.assertEqual(len(client.experiment.trials), trial_count)
示例#8
0
  def __init__(self, name: str, db_path: str,
               parameters: Optional[List[TParameterization]] = None,
               objective_name: Optional[str] = None):
    """Wrap an AxClient whose experiment is stored in a SQLite file.

    Args:
      name: Experiment name to create, or to load if creation fails.
      db_path: Filesystem path of the SQLite database file.
      parameters: Forwarded to `AxClient.create_experiment`.
      objective_name: Forwarded to `AxClient.create_experiment`.
    """
    db_url = f'sqlite:///{db_path}'
    if not os.path.isfile(db_path):
      # The DB file does not exist yet: initialise the engine and create the
      # schema tables.
      init_engine_and_session_factory(url=db_url)
      create_all_tables(get_engine())

    self.name = name
    self.ax = AxClient(enforce_sequential_optimization=False,
                       verbose_logging=False,
                       db_settings=DBSettings(url=db_url))

    if self.ax._experiment is not None:
      return
    try:
      self.ax.create_experiment(name=name, parameters=parameters,
                                objective_name=objective_name)
    except ValueError:
      # Presumably raised because an experiment called `name` already exists
      # in the DB, so fall back to loading it (TODO confirm that no other
      # ValueError source should be handled differently).
      self.ax.load_experiment_from_database(name)
示例#9
0
 def test_no_sqa(self):
     """`AxClient` must reject `db_settings` in two cases: when storage
     support is unavailable (`DBSettings` patched to None in `ax_client`),
     and when the argument is not a `DBSettings` instance.
     """
     # Pretend we couldn't import sqa_store.structs (this could happen when
     # SQLAlchemy is not installed). Using `patch` as a context manager
     # guarantees the patch is undone even if an assertion inside fails.
     # Previously a failure here skipped `patcher.stop()` and leaked the patch
     # into every later test.
     with patch("ax.service.ax_client.DBSettings", None):
         with self.assertRaises(ModuleNotFoundError):
             # NOTE(review): this path names a top-level `ax_client` package,
             # which likely does not exist, so this check passes regardless of
             # the patch — confirm the intended module path.
             import ax_client.storage.sqa_store.structs  # noqa F401
         AxClient()  # Make sure we still can instantiate client w/o db settings.
         # Even with correctly typed DBSettings, `AxClient` instantiation should
         # fail here, because `DBSettings` are mocked to None in `ax_client`.
         db_settings = DBSettings()
         self.assertIsInstance(db_settings, DBSettings)
         with self.assertRaisesRegex(ValueError, "`db_settings` argument should "):
             AxClient(db_settings=db_settings)
     # DBSettings should be defined in `ax_client` now, but incorrectly typed
     # `db_settings` argument should still make instantiation fail.
     with self.assertRaisesRegex(ValueError, "`db_settings` argument should "):
         AxClient(db_settings="badly_typed_db_settings")
示例#10
0
 def test_sqa_storage(self):
     """Persist an AxClient experiment and reload it into a new client. Check
     that a plain re-create is refused, and that an explicit overwrite
     replaces the stored experiment with a fresh, trial-less one.
     """

     def two_dim_space():
         # Branin domain named x1/x2; a new list is built on each call.
         return [
             {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
             {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
         ]

     init_test_engine_and_session_factory(force_init=True)
     sqa_config = SQAConfig()
     db_settings = DBSettings(
         encoder=Encoder(config=sqa_config),
         decoder=Decoder(config=sqa_config),
     )
     client = AxClient(db_settings=db_settings)
     client.create_experiment(
         name="test_experiment", parameters=two_dim_space(), minimize=True
     )
     for _ in range(5):
         params, idx = client.get_next_trial()
         client.complete_trial(trial_index=idx, raw_data=branin(*params.values()))
     saved_gs = client.generation_strategy

     # Restore into a new client and compare generation strategies.
     client = AxClient(db_settings=db_settings)
     client.load_experiment_from_database("test_experiment")
     self.assertEqual(saved_gs, client.generation_strategy)

     # Overwriting existing experiment.
     with self.assertRaises(ValueError):
         client.create_experiment(
             name="test_experiment", parameters=two_dim_space(), minimize=True
         )
     # Overwriting existing experiment with overwrite flag.
     client.create_experiment(
         name="test_experiment",
         parameters=[{"name": "x1", "type": "range", "bounds": [-5.0, 10.0]}],
         overwrite_existing_experiment=True,
     )
     # There should be no trials, as we just put in a fresh experiment.
     self.assertEqual(len(client.experiment.trials), 0)
示例#11
0
 def test_sqa_storage(self):
     """Run five trials, then reload the experiment from storage into a new
     client and check the generation strategy compares equal."""
     init_test_engine_and_session_factory(force_init=True)
     sqa_config = SQAConfig()
     db_settings = DBSettings(
         encoder=Encoder(config=sqa_config),
         decoder=Decoder(config=sqa_config),
     )
     search_space = [
         {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
         {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
     ]
     first_client = AxClient(db_settings=db_settings)
     first_client.create_experiment(
         name="test_experiment", parameters=search_space, minimize=True
     )
     for _ in range(5):
         params, idx = first_client.get_next_trial()
         first_client.complete_trial(
             trial_index=idx, raw_data=branin(*params.values())
         )

     second_client = AxClient(db_settings=db_settings)
     second_client.load_experiment_from_database("test_experiment")
     self.assertEqual(
         first_client.generation_strategy, second_client.generation_strategy
     )
示例#12
0
 def test_sqa_storage(self):
     """Check Scheduler persistence:

     - the experiment and GS are saved on construction and again after
       running;
     - they can be reloaded, including with `reduced_state=True`;
     - a new scheduler can be resumed via `from_stored_experiment`, and it
       records exactly one resume timestamp.
     """
     init_test_engine_and_session_factory(force_init=True)
     config = SQAConfig()
     encoder = Encoder(config=config)
     decoder = Decoder(config=config)
     db_settings = DBSettings(encoder=encoder, decoder=decoder)
     experiment = self.branin_experiment
     # Scheduler currently requires that the experiment be pre-saved, which
     # needs a name. Clear the name *before* entering `assertRaisesRegex`, so
     # only the constructor call is under test.
     experiment._name = None
     with self.assertRaisesRegex(ValueError, ".* must specify a name"):
         TestScheduler(
             experiment=experiment,
             generation_strategy=self.two_sobol_steps_GS,
             options=SchedulerOptions(total_trials=1),
             db_settings=db_settings,
         )
     experiment._name = "test_experiment"
     NUM_TRIALS = 5
     scheduler = TestScheduler(
         experiment=experiment,
         generation_strategy=self.two_sobol_steps_GS,
         options=SchedulerOptions(
             total_trials=NUM_TRIALS,
             init_seconds_between_polls=0,  # No wait between polls so test is fast.
         ),
         db_settings=db_settings,
     )
     # Check that experiment and GS were saved.
     exp, gs = scheduler._load_experiment_and_generation_strategy(
         experiment.name)
     self.assertEqual(exp, experiment)
     self.assertEqual(gs, self.two_sobol_steps_GS)
     scheduler.run_all_trials()
     # Check that experiment and GS were saved and test reloading with reduced state.
     exp, gs = scheduler._load_experiment_and_generation_strategy(
         experiment.name, reduced_state=True)
     self.assertEqual(len(exp.trials), NUM_TRIALS)
     self.assertEqual(len(gs._generator_runs), NUM_TRIALS)
     # Test `from_stored_experiment`.
     new_scheduler = TestScheduler.from_stored_experiment(
         experiment_name=experiment.name,
         options=SchedulerOptions(
             total_trials=NUM_TRIALS + 1,
             init_seconds_between_polls=0,  # No wait between polls so test is fast.
         ),
         db_settings=db_settings,
     )
     resumed_key = ExperimentStatusProperties.RESUMED_FROM_STORAGE_TIMESTAMPS
     # Hack "resumed from storage timestamp" into `exp` to make sure all other fields
     # are equal, since difference in resumed from storage timestamps is expected.
     exp._properties[resumed_key] = new_scheduler.experiment._properties[
         resumed_key]
     self.assertEqual(new_scheduler.experiment, exp)
     self.assertEqual(new_scheduler.generation_strategy, gs)
     self.assertEqual(
         len(new_scheduler.experiment._properties[resumed_key]),
         1,
     )
示例#13
0
 def setUp(self):
     """Create a test experiment and save it to an in-memory SQLite DB."""
     self.exp = get_experiment()
     # Every other call site spells this flag `force_init`. `force` does not
     # match that parameter name, so the forced re-initialisation was never
     # requested: depending on the helper's signature, it either raised
     # TypeError or was passed through as an unrelated kwarg.
     init_test_engine_and_session_factory(force_init=True)
     self.db_settings = DBSettings(url="sqlite://")
     save_experiment(self.exp, self.db_settings)