def save_updated_trials(
    experiment: Experiment, trials: List[BaseTrial], db_settings: DBSettings
) -> None:
    """Persist several already-updated trials of ``experiment`` to the DB.

    NOTE: Data attached to the experiment for these trials is saved as well.

    Args:
        experiment: `Experiment` the trials belong to.
        trials: Trials to write (`BaseTrial` subclasses, i.e. `Trial` or
            `BatchTrial`).
        db_settings: Connection (creator/url) and encoder settings used for
            the DB write.
    """
    init_engine_and_session_factory(creator=db_settings.creator, url=db_settings.url)
    began_at = time.time()
    _update_trials(experiment=experiment, trials=trials, encoder=db_settings.encoder)
    # Build the log pieces up front; evaluation order matches the original
    # f-string (indices first, then the elapsed time).
    trial_indices = [t.index for t in trials]
    elapsed = _round_floats_for_logging(time.time() - began_at)
    logger.debug(f"Updated trials {trial_indices} in {elapsed} seconds.")
def save_experiment(
    experiment: Experiment,
    db_settings: DBSettings,
    overwrite_existing_experiment: bool = False,
) -> None:
    """Write ``experiment`` to the DB.

    Args:
        experiment: `Experiment` object to store.
        db_settings: Connection (creator/url) and encoder settings for the DB.
        overwrite_existing_experiment: When an experiment with the same name
            is already stored, controls whether it gets overwritten.
            Defaults to False.
    """
    creator, url = db_settings.creator, db_settings.url
    init_engine_and_session_factory(creator=creator, url=url)
    _save_experiment(
        experiment,
        encoder=db_settings.encoder,
        overwrite=overwrite_existing_experiment,
    )
def load_experiment(name: str, db_settings: DBSettings) -> Experiment:
    """Fetch the experiment called ``name`` from the DB.

    Only a full `Experiment` is accepted, because that is all the Service API
    supports.

    Args:
        name: Name of the stored experiment.
        db_settings: Connection (creator/url) and decoder settings for the DB.

    Returns:
        ax.core.Experiment: The experiment read from the DB.

    Raises:
        ValueError: If the loaded object is not an `Experiment`, or is a
            simple experiment.
    """
    init_engine_and_session_factory(creator=db_settings.creator, url=db_settings.url)
    t0 = time.time()
    loaded = _load_experiment(name, decoder=db_settings.decoder)
    # `and` short-circuits exactly like the original `or`, so
    # `is_simple_experiment` is only read on real `Experiment` instances.
    supported = isinstance(loaded, Experiment) and not loaded.is_simple_experiment
    if not supported:
        raise ValueError("Service API only supports Experiment")
    elapsed = _round_floats_for_logging(time.time() - t0)
    logger.info(f"Loaded experiment {name} in {elapsed} seconds")
    return loaded
def save_updated_trial(
    experiment: Experiment, trial: BaseTrial, db_settings: DBSettings
) -> None:
    """Save a single updated trial on an experiment in DB.

    NOTE: This function also saves data attached to experiment for this trial.

    Args:
        experiment: `Experiment` object the trial belongs to.
        trial: `BaseTrial` object whose updated state should be written.
        db_settings: Defines behavior for loading/saving experiment to/from db.
    """
    init_engine_and_session_factory(creator=db_settings.creator, url=db_settings.url)
    start_time = time.time()
    _update_trial(experiment=experiment, trial=trial, encoder=db_settings.encoder)
    logger.info(
        f"Saved trial {trial.index} in "
        f"{_round_floats_for_logging(time.time() - start_time)} seconds"
    )
def update_generation_strategy(
    generation_strategy: GenerationStrategy,
    generator_runs: List[GeneratorRun],
    db_settings: DBSettings,
) -> None:
    """Append newly produced generator runs to a stored generation strategy.

    Args:
        generation_strategy: The generation strategy being updated.
        generator_runs: Generator runs the strategy has produced since it was
            last saved.
        db_settings: Connection (creator/url) and encoder settings for the DB.
    """
    init_engine_and_session_factory(creator=db_settings.creator, url=db_settings.url)
    t0 = time.time()
    update_kwargs = dict(
        generation_strategy=generation_strategy,
        generator_runs=generator_runs,
        encoder=db_settings.encoder,
    )
    _update_generation_strategy(**update_kwargs)
    gs_name = generation_strategy.name
    elapsed = _round_floats_for_logging(time.time() - t0)
    logger.debug(f"Updated generation strategy {gs_name} in {elapsed} seconds.")
def testConnectionToDBWithCreator(self):
    """Engine built from a `creator` callable forwards echo/pool settings."""
    # Every call to the fake DB-API's `connect` hands back the same Mock
    # connection; the outer Mock records how often it was called.
    shared_connection = Mock()
    fake_dbapi = MagicMock(
        connect=Mock(side_effect=lambda *args, **kwargs: shared_connection)
    )
    init_engine_and_session_factory(
        creator=lambda: fake_dbapi.connect(),
        force_init=True,
        module=fake_dbapi,
        echo=True,
        pool_size=2,
        _initialize=False,
    )
    with session_scope() as session:
        bound_engine = session.bind
        bound_engine.connect()
        self.assertEqual(fake_dbapi.connect.call_count, 1)
        self.assertTrue(bound_engine.echo)
        self.assertEqual(bound_engine.pool.size(), 2)
def testConnectionToDBWithURL(self):
    """Initialising from an in-memory SQLite URL yields a usable engine."""
    init_engine_and_session_factory(url="sqlite://", force_init=True)
    # Without these checks the test only verified that initialisation does
    # not raise; confirm the session is actually bound to an SQLite engine.
    with session_scope() as session:
        engine = session.bind
        self.assertIsNotNone(engine)
        self.assertEqual(engine.url.drivername, "sqlite")
        # Open (and close) a real connection to prove the URL is usable.
        with engine.connect():
            pass