Beispiel #1
0
def save_updated_trials(
    experiment: Experiment, trials: List[BaseTrial], db_settings: DBSettings
) -> None:
    """Write updates for several trials of an experiment to the DB.

    NOTE: Data attached to the experiment for these trials is saved too.

    Args:
        experiment: `Experiment` object.
        trials: Trials to update (subclasses of `BaseTrial`: `Trial` or
            `BatchTrial`).
        db_settings: Defines behavior for loading/saving experiment to/from db.
    """
    init_engine_and_session_factory(creator=db_settings.creator, url=db_settings.url)
    began = time.time()
    _update_trials(experiment=experiment, trials=trials, encoder=db_settings.encoder)
    # Indices are collected after the update, matching the original ordering.
    indices = [t.index for t in trials]
    elapsed = _round_floats_for_logging(time.time() - began)
    logger.debug(f"Updated trials {indices} in {elapsed} seconds.")


def save_experiment(
    experiment: Experiment,
    db_settings: DBSettings,
    overwrite_existing_experiment: bool = False,
) -> None:
    """
    Store an experiment in the DB.

    Args:
        experiment: `Experiment` object to store.
        db_settings: Defines behavior for loading/saving experiment to/from db.
        overwrite_existing_experiment: When an experiment with the same name is
            already stored, controls whether that stored experiment gets
            overwritten. Defaults to False.
    """
    init_engine_and_session_factory(creator=db_settings.creator, url=db_settings.url)
    encoder = db_settings.encoder
    _save_experiment(
        experiment,
        encoder=encoder,
        overwrite=overwrite_existing_experiment,
    )
Beispiel #3
0
def load_experiment(name: str, db_settings: DBSettings) -> Experiment:
    """
    Fetch an experiment by name from the DB for use with the Service API.

    Only `Experiment` objects are accepted. Anything else, or an experiment
    whose `is_simple_experiment` flag is set, is rejected.

    Args:
        name: Experiment name.
        db_settings: Defines behavior for loading/saving experiment to/from db.

    Returns:
        ax.core.Experiment: Loaded experiment.

    Raises:
        ValueError: If the loaded object is not a supported `Experiment`.
    """
    init_engine_and_session_factory(creator=db_settings.creator, url=db_settings.url)
    t0 = time.time()
    loaded = _load_experiment(name, decoder=db_settings.decoder)
    # `and` short-circuits, so `is_simple_experiment` is only read on Experiments.
    is_supported = isinstance(loaded, Experiment) and not loaded.is_simple_experiment
    if not is_supported:
        raise ValueError("Service API only supports Experiment")
    elapsed = _round_floats_for_logging(time.time() - t0)
    logger.info(f"Loaded experiment {name} in {elapsed} seconds")
    return loaded
Beispiel #4
0
def save_updated_trial(
    experiment: Experiment, trial: BaseTrial, db_settings: DBSettings
) -> None:
    """Save an updated trial on an experiment in DB.

    NOTE: This function also saves data attached to experiment
    for this trial.

    Args:
        experiment: `Experiment` object.
        trial: `BaseTrial` object whose updates should be saved.
        db_settings: Defines behavior for loading/saving experiment to/from db.
    """
    init_engine_and_session_factory(creator=db_settings.creator, url=db_settings.url)
    start_time = time.time()
    _update_trial(experiment=experiment, trial=trial, encoder=db_settings.encoder)
    logger.info(
        f"Saved trial {trial.index} in "
        f"{_round_floats_for_logging(time.time() - start_time)} seconds"
    )
Beispiel #5
0
def update_generation_strategy(
    generation_strategy: GenerationStrategy,
    generator_runs: List[GeneratorRun],
    db_settings: DBSettings,
) -> None:
    """Persist newly produced generator runs of a generation strategy to the DB.

    Args:
        generation_strategy: Corresponding generation strategy.
        generator_runs: Generator runs the strategy produced since it was last
            saved.
        db_settings: Defines behavior for loading/saving experiment to/from db.
    """
    init_engine_and_session_factory(creator=db_settings.creator, url=db_settings.url)
    began = time.time()
    _update_generation_strategy(
        generation_strategy=generation_strategy,
        generator_runs=generator_runs,
        encoder=db_settings.encoder,
    )
    # Read the name before measuring elapsed time, as the original f-string did.
    strategy_name = generation_strategy.name
    elapsed = _round_floats_for_logging(time.time() - began)
    logger.debug(
        f"Updated generation strategy {strategy_name} in {elapsed} seconds."
    )
Beispiel #6
0
    def testConnectionToDBWithCreator(self):
        """Check that a custom `creator` and engine options reach the engine.

        A fake DBAPI module is built whose `connect` always returns the same
        mock connection, so the test can count how often it is invoked.
        """
        fake_connection = Mock()
        dbapi = MagicMock(connect=Mock(return_value=fake_connection))

        init_engine_and_session_factory(
            creator=lambda: dbapi.connect(),
            force_init=True,
            # NOTE(review): presumably forwarded to SQLAlchemy's create_engine
            # so no real DB driver is imported; confirm in the helper.
            module=dbapi,
            echo=True,
            pool_size=2,
            _initialize=False,
        )
        with session_scope() as session:
            engine = session.bind
            engine.connect()
            # Exactly one connection should have been requested from the creator.
            self.assertEqual(dbapi.connect.call_count, 1)
            self.assertTrue(engine.echo)
            self.assertEqual(engine.pool.size(), 2)
Beispiel #7
0
 def testConnectionToDBWithURL(self):
     """Initialize the engine and session factory from an SQLite URL.

     ``sqlite://`` carries no database path, which SQLAlchemy treats as an
     in-memory SQLite database. ``force_init=True`` presumably forces
     re-initialization even if an engine already exists; TODO confirm.
     """
     init_engine_and_session_factory(url="sqlite://", force_init=True)