Beispiel #1
0
 def test_deprecated_save_load_method_errors(self):
     """Check that AxClient's save/load entry points raise NotImplementedError.

     A single default-constructed AxClient is used; each call is checked in
     its own ``assertRaises`` context, in the order save, load,
     load_experiment.
     """
     client = AxClient()
     # Lambdas defer both the attribute lookup and the call until we are
     # inside the assertRaises context, exactly as the inline form does.
     deprecated_calls = (
         lambda: client.save(),
         lambda: client.load(),
         lambda: client.load_experiment("test_experiment"),
     )
     for invoke in deprecated_calls:
         with self.assertRaises(NotImplementedError):
             invoke()
Beispiel #2
0
 def test_sqa_storage(self):
     """Round-trip an AxClient experiment through SQA-backed storage.

     Creates "test_experiment" on an AxClient configured with ``db_settings``
     and completes five trials evaluated with ``branin``. It then builds a
     second AxClient with the same ``db_settings``, loads the experiment by
     name, and asserts that the reloaded generation strategy equals the one
     held by the first client.
     """
     # Re-initialize the test SQL engine/session factory so this test does
     # not depend on state left behind by other tests.
     init_test_engine_and_session_factory(force_init=True)
     config = SQAConfig()
     encoder = Encoder(config=config)
     decoder = Decoder(config=config)
     db_settings = DBSettings(encoder=encoder, decoder=decoder)

     ax_client = AxClient(db_settings=db_settings)
     ax_client.create_experiment(
         name="test_experiment",
         parameters=[
             {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
             {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
         ],
         minimize=True,
     )
     for _ in range(5):
         parameters, trial_index = ax_client.get_next_trial()
         # Pass branin's arguments by parameter name instead of unpacking
         # ``parameters.values()`` positionally, so correctness does not hinge
         # on the iteration order of the returned dict.
         ax_client.complete_trial(
             trial_index=trial_index,
             raw_data=branin(parameters["x1"], parameters["x2"]),
         )
     original_gs = ax_client.generation_strategy

     # A fresh client sharing the same db_settings; the assertion below
     # expects load_experiment to restore an equal generation strategy.
     reloaded_client = AxClient(db_settings=db_settings)
     reloaded_client.load_experiment("test_experiment")
     self.assertEqual(original_gs, reloaded_client.generation_strategy)