Beispiel #1
0
def _save_experiment(experiment: Experiment, encoder: Encoder) -> None:
    """Persist an experiment with the supplied Encoder instance.

    Steps:
      1) Query the DB for an experiment row with the same name.
      2) Pass the experiment and that row (or None) to
         ``encoder.validate_experiment_metadata``.
      3) Encode the Ax experiment as a SQLAlchemy object.
      4) If a stored row was found, merge the freshly encoded object into it
         and persist the stored row; otherwise persist the encoded object.
    """
    sqa_cls = encoder.config.class_to_sqa_class[Experiment]
    with session_scope() as session:
        name_query = session.query(sqa_cls).filter_by(name=experiment.name)
        stored_experiment = name_query.one_or_none()

    encoder.validate_experiment_metadata(
        experiment, existing_sqa_experiment=stored_experiment)

    # The conversion from the user-facing class to SQA is done outside of
    # any session scope to avoid timeouts.
    encoded_experiment = encoder.experiment_to_sqa(experiment)

    if stored_experiment is None:
        to_persist = encoded_experiment
    else:
        # The merge also happens outside of session scope to avoid timeouts.
        # The stored object is detached from its session but still carries
        # a database identity marker, so the `session.add` below results in
        # an update rather than an insert.
        stored_experiment.update(encoded_experiment)
        to_persist = stored_experiment

    with session_scope() as session:
        session.add(to_persist)
Beispiel #2
0
 def setUp(self):
     """Initialize the test engine/session factory and shared fixtures."""
     init_test_engine_and_session_factory(force_init=True)
     sqa_config = SQAConfig()
     self.config = sqa_config
     # Encoder and decoder are built on the very same config object.
     self.encoder = Encoder(config=sqa_config)
     self.decoder = Decoder(config=sqa_config)
     self.experiment = get_experiment_with_batch_trial()
     first_parameter = get_range_parameter()  # w
     second_parameter = get_range_parameter2()  # x
     self.dummy_parameters = [first_parameter, second_parameter]
Beispiel #3
0
def _update_trials(experiment: Experiment, trials: List[BaseTrial],
                   encoder: Encoder) -> None:
    """Update already-stored trials and attach their data.

    The stored rows for the given trial indices are fetched in one session,
    merged with freshly encoded trials outside of any session, and then
    added back (together with any newly encoded data) in a second session.

    Raises:
        ValueError: If one of ``trials`` has no stored row for the
            experiment.
    """
    sqa_trial_cls = encoder.config.class_to_sqa_class[Trial]
    wanted_indices = [t.index for t in trials]

    with session_scope() as session:
        experiment_id = _get_experiment_id(
            experiment=experiment, encoder=encoder, session=session)
        stored_trials = (
            session.query(sqa_trial_cls)
            .filter_by(experiment_id=experiment_id)
            .filter(sqa_trial_cls.index.in_(wanted_indices))  # pyre-ignore
            .all()
        )

    stored_by_index = {stored.index: stored for stored in stored_trials}

    # (Ax object, SQA object) pairs handed to `_set_db_ids` at the end.
    obj_to_sqa = []
    merged_trials = []
    data_rows = []
    for trial in trials:
        stored_trial = stored_by_index.get(trial.index)
        if stored_trial is None:
            raise ValueError(
                f"Trial {trial.index} is not attached to the experiment.")

        encoded_trial, trial_pairs = encoder.trial_to_sqa(trial)
        obj_to_sqa.extend(trial_pairs)
        stored_trial.update(encoded_trial)
        merged_trials.append(stored_trial)

        data, ts = experiment.lookup_data_for_trial(trial_index=trial.index)
        if ts == -1:
            # A timestamp of -1 means there is no data to encode here.
            continue
        sqa_data = encoder.data_to_sqa(
            data=data, trial_index=trial.index, timestamp=ts)
        obj_to_sqa.append((data, sqa_data))
        sqa_data.experiment_id = experiment_id
        data_rows.append(sqa_data)

    with session_scope() as session:
        session.add_all(merged_trials)
        session.add_all(data_rows)
        session.flush()

    _set_db_ids(obj_to_sqa=obj_to_sqa)
Beispiel #4
0
def save_new_trial(experiment: Experiment,
                   trial: BaseTrial,
                   config: Optional[SQAConfig] = None) -> None:
    """Add a new trial to the experiment.

    A default ``SQAConfig`` is used when no (truthy) ``config`` is given;
    the work itself is delegated to ``_save_new_trial``.
    """
    trial_encoder = Encoder(config=config or SQAConfig())
    _save_new_trial(experiment=experiment, trial=trial, encoder=trial_encoder)
Beispiel #5
0
def _update_generation_strategy(
    generation_strategy: GenerationStrategy,
    generator_runs: List[GeneratorRun],
    encoder: Encoder,
) -> None:
    """Update generation strategy's current step and attach generator runs.

    Args:
        generation_strategy: Strategy whose stored row is updated; it must
            already carry a ``_db_id``.
        generator_runs: Runs to encode and add to the session, each linked
            to the strategy through ``generation_strategy_id``.
        encoder: Encoder used to resolve the SQA class and encode the runs.

    Raises:
        ValueError: If the strategy has no ``_db_id`` yet, or if no stored
            row exists for that id.
    """
    gs_sqa_class = encoder.config.class_to_sqa_class[GenerationStrategy]

    gs_id = generation_strategy._db_id
    if gs_id is None:
        raise ValueError(
            "GenerationStrategy must be saved before being updated.")

    with session_scope() as session:
        experiment_id = _get_experiment_id(
            experiment=generation_strategy.experiment,
            encoder=encoder,
            session=session)
        gs_sqa = session.query(gs_sqa_class).get(gs_id)
        if gs_sqa is None:
            # `Query.get` yields None for an unknown primary key; fail with
            # a clear message instead of an AttributeError on the next line.
            raise ValueError(
                f"GenerationStrategy with ID {gs_id} was not found in the "
                "database.")
        gs_sqa.curr_index = generation_strategy._curr.index  # pyre-fixme
        gs_sqa.experiment_id = experiment_id  # pyre-ignore

        session.add(gs_sqa)
        for generator_run in generator_runs:
            gr_sqa = encoder.generator_run_to_sqa(generator_run=generator_run)
            # Link the encoded run to its generation strategy row.
            gr_sqa.generation_strategy_id = gs_id
            session.add(gr_sqa)
Beispiel #6
0
def update_trial(experiment: Experiment,
                 trial: BaseTrial,
                 config: Optional[SQAConfig] = None) -> None:
    """Update a trial and attach its data.

    A default ``SQAConfig`` is used when no (truthy) ``config`` is given;
    the work itself is delegated to ``_update_trial``.
    """
    trial_encoder = Encoder(config=config or SQAConfig())
    _update_trial(experiment=experiment, trial=trial, encoder=trial_encoder)
Beispiel #7
0
 def test_storage_error_handling(self, mock_save_fails):
     """With `suppress_storage_errors=True`, AxClient should not visibly
     fail when it runs into storage errors.
     """
     init_test_engine_and_session_factory(force_init=True)
     sqa_config = SQAConfig()
     db_settings = DBSettings(
         encoder=Encoder(config=sqa_config),
         decoder=Decoder(config=sqa_config),
     )
     client = AxClient(db_settings=db_settings, suppress_storage_errors=True)
     search_space = [
         {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
         {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
     ]
     client.create_experiment(
         name="test_experiment",
         parameters=search_space,
         minimize=True,
     )
     # Run a few generate/complete cycles; none of them should raise.
     for _ in range(3):
         suggested, index = client.get_next_trial()
         outcome = branin(*suggested.values())
         client.complete_trial(trial_index=index, raw_data=outcome)
Beispiel #8
0
def save_or_update_trials(
    experiment: Experiment,
    trials: List[BaseTrial],
    config: Optional[SQAConfig] = None,
    batch_size: Optional[int] = None,
    reduce_state_generator_runs: bool = False,
) -> None:
    """Add new trials to the experiment, or update those that already exist.

    A default ``SQAConfig`` is used when no (truthy) ``config`` is given;
    the encoder and decoder are built on that same config object and all
    remaining arguments are forwarded to ``_save_or_update_trials``.

    Note that new data objects (whether attached to existing or new trials)
    will also be added to the experiment, but existing data objects in the
    database will *not* be updated or removed.
    """
    shared_config = config if config else SQAConfig()
    _save_or_update_trials(
        experiment=experiment,
        trials=trials,
        encoder=Encoder(config=shared_config),
        decoder=Decoder(config=shared_config),
        batch_size=batch_size,
        reduce_state_generator_runs=reduce_state_generator_runs,
    )
Beispiel #9
0
def _save_generation_strategy(generation_strategy: GenerationStrategy,
                              encoder: Encoder) -> int:
    """Insert or update the stored row for a generation strategy.

    Args:
        generation_strategy: Strategy to persist. When its ``_db_id`` is
            None a new row is added and the generated id is written back to
            ``_db_id``; otherwise the row with that id is updated.
        encoder: Encoder used to build the SQA object and to resolve the
            SQA class for the lookup.

    Returns:
        The ``_db_id`` of the generation strategy after saving.

    Raises:
        ValueError: If ``_db_id`` is set but no stored row has that id.
    """
    # If the generation strategy has not yet generated anything, there will be no
    # experiment set on it.
    if generation_strategy._experiment is None:
        experiment_id = None
    else:
        # Experiment was set on the generation strategy, so need to check whether
        # it has been saved and create a relationship b/w GS and experiment if so.
        experiment_id = _get_experiment_id(
            experiment=generation_strategy._experiment, encoder=encoder)

    gs_sqa = encoder.generation_strategy_to_sqa(
        generation_strategy=generation_strategy, experiment_id=experiment_id)

    with session_scope() as session:
        if generation_strategy._db_id is None:
            session.add(gs_sqa)
            session.flush()  # Ensures generation strategy id is set.
            generation_strategy._db_id = gs_sqa.id
        else:
            gs_sqa_class = encoder.config.class_to_sqa_class[
                GenerationStrategy]
            existing_gs_sqa = session.query(gs_sqa_class).get(
                generation_strategy._db_id)
            if existing_gs_sqa is None:
                # `Query.get` yields None for an unknown primary key; fail
                # clearly instead of with an AttributeError below.
                raise ValueError(
                    f"GenerationStrategy with ID {generation_strategy._db_id} "
                    "was not found in the database.")
            existing_gs_sqa.update(gs_sqa)
            # our update logic ignores foreign keys, i.e. fields ending in _id,
            # because we want SQLAlchemy to handle those relationships for us
            # however, generation_strategy.experiment_id is an exception, so we
            # need to update that manually
            existing_gs_sqa.experiment_id = gs_sqa.experiment_id

    return generation_strategy._db_id
Beispiel #10
0
def _save_experiment(
    experiment: Experiment,
    encoder: Encoder,
    return_sqa: bool = False,
    validation_kwargs: Optional[Dict[str, Any]] = None,
) -> Optional[SQABase]:
    """Save experiment, using given Encoder instance.

    1) Determine if there is an existing experiment with that name in the DB.
    2) Validate the experiment's metadata against the existing row (if any).
    3) Convert Ax object to SQLAlchemy object.
    4) If there is no existing row, create a new one.
    5) If there is, update the old one.
        The update works by merging the new SQLAlchemy object into the
        existing SQLAlchemy object, and then letting SQLAlchemy handle the
        actual DB updates.

    Args:
        experiment: The Ax experiment to save.
        encoder: Encoder used for validation and for the SQA conversion.
        return_sqa: If True, return the SQLAlchemy object that was added to
            the session; otherwise return None.
        validation_kwargs: Extra keyword arguments forwarded to
            ``encoder.validate_experiment_metadata``.

    Returns:
        The saved SQLAlchemy object when ``return_sqa`` is True, else None.
    """
    # Look up the stored experiment (by name) in its own short session.
    exp_sqa_class = encoder.config.class_to_sqa_class[Experiment]
    with session_scope() as session:
        existing_sqa_experiment = (session.query(exp_sqa_class).filter_by(
            name=experiment.name).one_or_none())
    encoder.validate_experiment_metadata(
        experiment,
        # pyre-fixme[6]: Expected
        #  `Optional[ax.storage.sqa_store.sqa_classes.SQAExperiment]` for 2nd param but
        #  got `Optional[ax.storage.sqa_store.db.SQABase]`.
        existing_sqa_experiment=existing_sqa_experiment,
        **(validation_kwargs or {}),
    )
    # Convert user-facing class to SQA outside of session scope to avoid timeouts
    new_sqa_experiment, obj_to_sqa = encoder.experiment_to_sqa(experiment)

    if existing_sqa_experiment is not None:
        # Update the SQA object outside of session scope to avoid timeouts.
        # This object is detached from the session, but contains a database
        # identity marker, so when we do `session.add` below, SQA knows to
        # perform an update rather than an insert.
        # pyre-fixme[6]: Expected `SQABase` for 1st param but got `SQAExperiment`.
        existing_sqa_experiment.update(new_sqa_experiment)
        new_sqa_experiment = existing_sqa_experiment

    with session_scope() as session:
        session.add(new_sqa_experiment)
        session.flush()

    # NOTE(review): presumably copies the generated DB ids back onto the Ax
    # objects paired in `obj_to_sqa` -- confirm against `_set_db_ids`.
    _set_db_ids(obj_to_sqa=obj_to_sqa)

    return checked_cast(SQABase, new_sqa_experiment) if return_sqa else None
Beispiel #11
0
def save_or_update_trial(experiment: Experiment,
                         trial: BaseTrial,
                         config: Optional[SQAConfig] = None) -> None:
    """Add a new trial to the experiment, or update it if it already exists.

    A default ``SQAConfig`` is used when no (truthy) ``config`` is given;
    the work itself is delegated to ``_save_or_update_trial``.
    """
    trial_encoder = Encoder(config=config or SQAConfig())
    _save_or_update_trial(
        experiment=experiment, trial=trial, encoder=trial_encoder)
Beispiel #12
0
Datei: save.py Projekt: jlin27/Ax
def save_generation_strategy(generation_strategy: GenerationStrategy,
                             config: Optional[SQAConfig] = None) -> int:
    """Save generation strategy (using default SQAConfig if no config is
    specified). If the generation strategy has an experiment set, the experiment
    will be saved first.

    Args:
        generation_strategy: Strategy to persist. When its ``_db_id`` is
            None a new row is added and the generated id is written back to
            ``_db_id``; otherwise the row with that id is updated.
        config: SQA configuration; a default ``SQAConfig`` is used if falsy.

    Returns:
        The ID of the saved generation strategy.

    Raises:
        ValueError: If ``_db_id`` is set but no stored row has that id.
    """

    # Start up SQA encoder.
    config = config or SQAConfig()
    encoder = Encoder(config=config)

    # If the generation strategy has not yet generated anything, there will be no
    # experiment set on it.
    if generation_strategy._experiment is None:
        experiment_id = None
    else:
        # Experiment was set on the generation strategy, so we need to save it first.
        save_experiment(experiment=generation_strategy._experiment,
                        config=config)
        experiment_id = _get_experiment_id(
            experiment=generation_strategy._experiment, encoder=encoder)

    gs_sqa = encoder.generation_strategy_to_sqa(
        generation_strategy=generation_strategy, experiment_id=experiment_id)

    with session_scope() as session:
        if generation_strategy._db_id is None:
            session.add(gs_sqa)
            session.flush()  # Ensures generation strategy id is set.
            generation_strategy._db_id = gs_sqa.id
        else:
            existing_gs_sqa = session.query(SQAGenerationStrategy).get(
                generation_strategy._db_id)
            if existing_gs_sqa is None:
                # `Query.get` yields None for an unknown primary key; fail
                # clearly instead of with an AttributeError below.
                raise ValueError(
                    f"GenerationStrategy with ID {generation_strategy._db_id} "
                    "was not found in the database.")
            existing_gs_sqa.update(gs_sqa)
            # our update logic ignores foreign keys, i.e. fields ending in _id,
            # because we want SQLAlchemy to handle those relationships for us
            # however, generation_strategy.experiment_id is an exception, so we
            # need to update that manually
            existing_gs_sqa.experiment_id = gs_sqa.experiment_id

    return generation_strategy._db_id
Beispiel #13
0
class DBSettings(NamedTuple):
    """
    Defines behavior for loading/saving experiment to/from db.
    Either creator or url must be specified as a way to connect to the SQL db.

    Note that the default ``decoder`` and ``encoder`` below are constructed
    once, when this class body is evaluated, so every ``DBSettings`` that
    does not override them shares the same two instances.
    """

    # Optional callable used to connect to the DB.
    # NOTE(review): presumably handed to the SQL engine as its connection
    # factory -- confirm against the code that consumes DBSettings.
    creator: Optional[Callable] = None
    # Decoder for reading from the DB; defaults to one built on a fresh
    # default SQAConfig.
    decoder: Decoder = Decoder(config=SQAConfig())
    # Encoder for writing to the DB; defaults to one built on its own
    # default SQAConfig (a separate instance from the decoder's).
    encoder: Encoder = Encoder(config=SQAConfig())
    # Optional DB URL string, the alternative to `creator`.
    url: Optional[str] = None
Beispiel #14
0
def save_experiment(experiment: Experiment, config: Optional[SQAConfig] = None) -> None:
    """Save an experiment, using a default SQAConfig when none is given.

    Raises:
        ValueError: If ``experiment`` is not an ``Experiment`` instance, or
            if it has no name set.
    """
    if not isinstance(experiment, Experiment):
        raise ValueError("Can only save instances of Experiment")
    if not experiment.has_name:
        raise ValueError("Experiment name must be set prior to saving.")

    experiment_encoder = Encoder(config=config or SQAConfig())
    return _save_experiment(experiment=experiment, encoder=experiment_encoder)
Beispiel #15
0
def _update_trial(experiment: Experiment, trial: BaseTrial,
                  encoder: Encoder) -> None:
    """Update a stored trial and attach its data, using the given Encoder.

    Raises:
        ValueError: If the experiment has no stored row yet, or if the
            stored experiment has no trial with ``trial.index``.
    """
    sqa_cls = encoder.config.class_to_sqa_class[Experiment]
    with session_scope() as session:
        name_query = session.query(sqa_cls).filter_by(name=experiment.name)
        stored_experiment = name_query.one_or_none()

    if stored_experiment is None:
        raise ValueError("Must save experiment before updating a trial.")

    matching_trials = [
        stored for stored in stored_experiment.trials
        if stored.index == trial.index
    ]
    if not matching_trials:
        raise ValueError(
            f"Trial {trial.index} is not attached to the experiment.")

    # There should only be one existing trial with the same index
    assert len(matching_trials) == 1
    stored_trial = matching_trials[0]

    stored_trial.update(encoder.trial_to_sqa(trial))

    with session_scope() as session:
        session.add(stored_trial)

    data, ts = experiment.lookup_data_for_trial(trial_index=trial.index)
    if ts == -1:
        # A timestamp of -1 means there is no data to attach.
        return

    sqa_data = encoder.data_to_sqa(
        data=data, trial_index=trial.index, timestamp=ts)
    with session_scope() as session:
        stored_experiment.data.append(sqa_data)
        session.add(stored_experiment)
Beispiel #16
0
def _save_experiment(
    experiment: Experiment,
    encoder: Encoder,
    decoder: Decoder,
    return_sqa: bool = False,
    validation_kwargs: Optional[Dict[str, Any]] = None,
) -> Optional[SQABase]:
    """Save experiment, using the given Encoder and Decoder instances.

    Steps:
      1) Query the DB for the id of an experiment row with the same name.
      2) Pass the experiment and that id (or None) to
         ``encoder.validate_experiment_metadata``, together with any
         ``validation_kwargs``.
      3) Hand the experiment to ``_merge_into_session`` along with the
         encode and decode functions.

    Returns:
        The object returned by ``_merge_into_session`` when ``return_sqa``
        is True, else None.
    """
    sqa_cls = encoder.config.class_to_sqa_class[Experiment]
    with session_scope() as session:
        id_row = (
            # pyre-ignore Undefined attribute [16]: `SQABase` has no attribute `id`
            session.query(sqa_cls.id)
            .filter_by(name=experiment.name)
            .one_or_none()
        )
    # The query yields a one-element row (or None); unwrap the id if found.
    existing_id = id_row[0] if id_row else id_row

    extra_validation_kwargs = validation_kwargs or {}
    encoder.validate_experiment_metadata(
        experiment,
        existing_sqa_experiment_id=existing_id,
        **extra_validation_kwargs,
    )

    merged_sqa = _merge_into_session(
        obj=experiment,
        encode_func=encoder.experiment_to_sqa,
        decode_func=decoder.experiment_from_sqa,
    )

    if return_sqa:
        return checked_cast(SQABase, merged_sqa)
    return None
Beispiel #17
0
 def test_sqa_storage(self):
     """Round-trip an AxClient experiment through SQA storage and check that
     an experiment stored in the DB cannot be overwritten.
     """

     def _range_param(name, lower, upper):
         # Fresh parameter spec on every call, so nothing is shared.
         return {"name": name, "type": "range", "bounds": [lower, upper]}

     init_test_engine_and_session_factory(force_init=True)
     sqa_config = SQAConfig()
     db_settings = DBSettings(
         encoder=Encoder(config=sqa_config),
         decoder=Decoder(config=sqa_config),
     )
     client = AxClient(db_settings=db_settings)
     client.create_experiment(
         name="test_experiment",
         parameters=[
             _range_param("x", -5.0, 10.0),
             _range_param("y", 0.0, 15.0),
         ],
         minimize=True,
     )
     for _ in range(5):
         suggested, index = client.get_next_trial()
         outcome = branin(*suggested.values())
         client.complete_trial(trial_index=index, raw_data=outcome)
     saved_gs = client.generation_strategy

     restored_client = AxClient(db_settings=db_settings)
     restored_client.load_experiment_from_database("test_experiment")
     # Trial #4 was completed after the last time the generation strategy
     # generated candidates, so the pre-save generation strategy was not
     # "aware" of the completion of trial #4. The restored generation
     # strategy is aware of it, since it gets restored with the most
     # up-to-date experiment data. So add trial #4 to the seen completed
     # trials of the pre-storage GS; the two should be equal otherwise.
     saved_gs._seen_trial_indices_by_status[TrialStatus.COMPLETED].add(4)
     self.assertEqual(saved_gs, restored_client.generation_strategy)

     # Re-creating an experiment under the existing name must fail.
     with self.assertRaises(ValueError):
         restored_client.create_experiment(
             name="test_experiment",
             parameters=[
                 _range_param("x", -5.0, 10.0),
                 _range_param("y", 0.0, 15.0),
             ],
             minimize=True,
         )
     # It must also fail with the overwrite flag while DB settings are
     # present, as overwriting experiments stored in the DB is no longer
     # allowed.
     with self.assertRaises(ValueError):
         restored_client.create_experiment(
             name="test_experiment",
             parameters=[_range_param("x", -5.0, 10.0)],
             overwrite_existing_experiment=True,
         )
     # Original experiment should still be in DB and not have been overwritten.
     self.assertEqual(len(restored_client.experiment.trials), 5)
Beispiel #18
0
def update_generation_strategy(
    generation_strategy: GenerationStrategy,
    generator_runs: List[GeneratorRun],
    config: Optional[SQAConfig] = None,
) -> None:
    """Update generation strategy's current step and attach generator runs.

    A default ``SQAConfig`` is used when no (truthy) ``config`` is given;
    the work itself is delegated to ``_update_generation_strategy``.
    """
    gs_encoder = Encoder(config=config or SQAConfig())
    _update_generation_strategy(
        generation_strategy=generation_strategy,
        generator_runs=generator_runs,
        encoder=gs_encoder,
    )
Beispiel #19
0
def update_trials(experiment: Experiment, trials: List[BaseTrial],
                  encoder: Encoder) -> None:
    """Update already-stored trials and attach their data.

    Everything happens inside a single session: the stored rows for the
    given trial indices are fetched, merged with freshly encoded trials and
    re-added, and any available data for each trial is encoded and added.

    Raises:
        ValueError: If one of ``trials`` has no stored row for the
            experiment.
    """
    sqa_trial_cls = encoder.config.class_to_sqa_class[Trial]
    with session_scope() as session:
        experiment_id = _get_experiment_id(
            experiment=experiment, encoder=encoder, session=session)
        wanted_indices = [t.index for t in trials]
        stored_trials = (
            session.query(sqa_trial_cls)
            .filter_by(experiment_id=experiment_id)
            .filter(sqa_trial_cls.index.in_(wanted_indices))  # pyre-ignore
            .all()
        )
        stored_by_index = {stored.index: stored for stored in stored_trials}

        for trial in trials:
            stored_trial = stored_by_index.get(trial.index)
            if stored_trial is None:
                raise ValueError(
                    f"Trial {trial.index} is not attached to the experiment.")

            stored_trial.update(encoder.trial_to_sqa(trial))
            session.add(stored_trial)

            data, ts = experiment.lookup_data_for_trial(
                trial_index=trial.index)
            if ts == -1:
                # A timestamp of -1 means there is no data to encode here.
                continue
            sqa_data = encoder.data_to_sqa(
                data=data, trial_index=trial.index, timestamp=ts)
            sqa_data.experiment_id = experiment_id
            session.add(sqa_data)
Beispiel #20
0
Datei: save.py Projekt: ekilic/Ax
def save_experiment(experiment: Experiment,
                    config: Optional[SQAConfig] = None,
                    overwrite: bool = False) -> None:
    """Save an experiment, using a default SQAConfig when none is given.

    ``overwrite`` is forwarded unchanged to ``_save_experiment``.

    Raises:
        ValueError: If ``experiment`` is not an ``Experiment`` instance, or
            if it has no name set.
    """
    # pyre-fixme[25]: Assertion will always fail.
    if not isinstance(experiment, Experiment):
        raise ValueError("Can only save instances of Experiment")
    if not experiment.has_name:
        raise ValueError("Experiment name must be set prior to saving.")

    experiment_encoder = Encoder(config=config or SQAConfig())
    _save_experiment(
        experiment=experiment,
        encoder=experiment_encoder,
        overwrite=overwrite,
    )
Beispiel #21
0
def save_generation_strategy(generation_strategy: GenerationStrategy,
                             config: Optional[SQAConfig] = None) -> int:
    """Save a generation strategy.

    An encoder is built on ``config`` (or on a default ``SQAConfig`` if no
    truthy config is given) and the work is delegated to
    ``_save_generation_strategy``.

    Returns:
        The value returned by ``_save_generation_strategy``, i.e. the ID of
        the saved generation strategy.
    """
    gs_encoder = Encoder(config=config or SQAConfig())
    saved_id = _save_generation_strategy(
        generation_strategy=generation_strategy, encoder=gs_encoder)
    return saved_id
Beispiel #22
0
    def test_save_load_experiment(self):
        """Save and reload an experiment; loading a simple experiment is
        expected to raise a ValueError mentioning "Service API only".
        """
        regular_experiment = get_experiment()
        init_test_engine_and_session_factory(force_init=True)
        # Encoder and decoder each get their own default config.
        sqa_encoder = Encoder(config=SQAConfig())
        sqa_decoder = Decoder(config=SQAConfig())
        settings = DBSettings(
            encoder=sqa_encoder, decoder=sqa_decoder, creator=None)

        save_experiment(regular_experiment, settings)
        load_experiment(regular_experiment.name, settings)

        simple = get_simple_experiment()
        save_experiment(simple, settings)
        with self.assertRaisesRegex(ValueError, "Service API only"):
            load_experiment(simple.name, settings)
Beispiel #23
0
 def test_suppress_all_storage_errors(self, mock_save_exp, _):
     """Constructing a scheduler with storage-error suppression enabled
     should result in exactly three calls to the mocked save.
     """
     init_test_engine_and_session_factory(force_init=True)
     sqa_config = SQAConfig()
     db_settings = DBSettings(
         encoder=Encoder(config=sqa_config),
         decoder=Decoder(config=sqa_config),
     )
     scheduler_options = SchedulerOptions(
         init_seconds_between_polls=0.1,  # Short between polls so test is fast.
         suppress_storage_errors_after_retries=True,
     )
     BareBonesTestScheduler(
         experiment=self.branin_experiment,  # Has runner and metrics.
         generation_strategy=self.two_sobol_steps_GS,
         options=scheduler_options,
         db_settings=db_settings,
     )
     self.assertEqual(mock_save_exp.call_count, 3)
Beispiel #24
0
 def test_sqa_storage(self):
     """Round-trip an AxClient experiment through SQA storage and check that
     an experiment stored in the DB cannot be overwritten.
     """

     def _range_param(name, lower, upper):
         # Fresh parameter spec on every call, so nothing is shared.
         return {"name": name, "type": "range", "bounds": [lower, upper]}

     init_test_engine_and_session_factory(force_init=True)
     sqa_config = SQAConfig()
     db_settings = DBSettings(
         encoder=Encoder(config=sqa_config),
         decoder=Decoder(config=sqa_config),
     )
     client = AxClient(db_settings=db_settings)
     client.create_experiment(
         name="test_experiment",
         parameters=[
             _range_param("x", -5.0, 10.0),
             _range_param("y", 0.0, 15.0),
         ],
         minimize=True,
     )
     for _ in range(5):
         suggested, index = client.get_next_trial()
         outcome = branin(*suggested.values())
         client.complete_trial(trial_index=index, raw_data=outcome)
     saved_gs = client.generation_strategy

     # A fresh client restored from the DB should carry an equal strategy.
     restored_client = AxClient(db_settings=db_settings)
     restored_client.load_experiment_from_database("test_experiment")
     self.assertEqual(saved_gs, restored_client.generation_strategy)

     # Re-creating an experiment under the existing name must fail.
     with self.assertRaises(ValueError):
         restored_client.create_experiment(
             name="test_experiment",
             parameters=[
                 _range_param("x", -5.0, 10.0),
                 _range_param("y", 0.0, 15.0),
             ],
             minimize=True,
         )
     # It must also fail with the overwrite flag while DB settings are
     # present, as overwriting experiments stored in the DB is no longer
     # allowed.
     with self.assertRaises(ValueError):
         restored_client.create_experiment(
             name="test_experiment",
             parameters=[_range_param("x", -5.0, 10.0)],
             overwrite_existing_experiment=True,
         )
     # Original experiment should still be in DB and not have been overwritten.
     self.assertEqual(len(restored_client.experiment.trials), 5)
Beispiel #25
0
def save_or_update_trials(
    experiment: Experiment,
    trials: List[BaseTrial],
    config: Optional[SQAConfig] = None,
    batch_size: Optional[int] = None,
) -> None:
    """Add new trials to the experiment, or update those that already exist.

    A default ``SQAConfig`` is used when no (truthy) ``config`` is given;
    the encoder and decoder are built on that same config object and all
    remaining arguments are forwarded to ``_save_or_update_trials``.
    """
    shared_config = config if config else SQAConfig()
    _save_or_update_trials(
        experiment=experiment,
        trials=trials,
        encoder=Encoder(config=shared_config),
        decoder=Decoder(config=shared_config),
        batch_size=batch_size,
    )
Beispiel #26
0
def _update_generation_strategy(
    generation_strategy: GenerationStrategy,
    generator_runs: List[GeneratorRun],
    encoder: Encoder,
) -> None:
    """Update generation strategy's current step and attach generator runs.

    The stored strategy row is updated in one session; the generator runs
    are then encoded outside of any session and added in a second one.

    Raises:
        ValueError: If the generation strategy has no ``db_id``, if any of
            ``generator_runs`` already has a truthy ``db_id``, or if the
            strategy's experiment has no ``db_id``.
    """
    gs_sqa_class = encoder.config.class_to_sqa_class[GenerationStrategy]

    gs_id = generation_strategy.db_id
    if gs_id is None:
        raise ValueError(
            "GenerationStrategy must be saved before being updated.")

    already_saved = any(run.db_id for run in generator_runs)
    if already_saved:
        raise ValueError("Can only save new GeneratorRuns.")

    experiment_id = generation_strategy.experiment.db_id
    if experiment_id is None:
        raise ValueError(  # pragma: no cover
            f"Experiment {generation_strategy.experiment.name} "
            "should be saved before generation strategy.")

    with session_scope() as session:
        gs_rows = session.query(gs_sqa_class).filter_by(id=gs_id)
        gs_rows.update({
            "curr_index": generation_strategy._curr.index,
            "experiment_id": experiment_id,
        })

    # (Ax object, SQA object) pairs handed to `_set_db_ids` at the end.
    obj_to_sqa = []
    encoded_runs = []
    for run in generator_runs:
        run_sqa, run_pairs = encoder.generator_run_to_sqa(generator_run=run)
        obj_to_sqa.extend(run_pairs)
        # Link the encoded run to its generation strategy row.
        run_sqa.generation_strategy_id = gs_id
        encoded_runs.append(run_sqa)

    with session_scope() as session:
        session.add_all(encoded_runs)

    _set_db_ids(obj_to_sqa=obj_to_sqa)
Beispiel #27
0
def _save_experiment(
    experiment: Experiment, encoder: Encoder, overwrite: bool = False
) -> None:
    """Persist an experiment with the supplied Encoder instance.

    Steps:
      1) Encode the Ax experiment as a SQLAlchemy object.
      2) Query the DB for an experiment row with the same name.
      3) If none exists, persist the encoded object as a new row.
      4) If one exists, merge the encoded object into it and persist the
         stored row, letting SQLAlchemy handle the actual DB updates.

    Raises:
        Exception: If a stored experiment with the same name has a different
            creation time and ``overwrite`` is False.
    """
    # Convert user-facing class to SQA outside of session scope to avoid timeouts
    new_sqa_experiment = encoder.experiment_to_sqa(experiment)
    sqa_cls = encoder.config.class_to_sqa_class[Experiment]
    with session_scope() as session:
        name_query = session.query(sqa_cls).filter_by(name=experiment.name)
        stored_experiment = name_query.one_or_none()

    if stored_experiment is None:
        to_persist = new_sqa_experiment
    else:
        same_creation_time = datetime_equals(
            stored_experiment.time_created, new_sqa_experiment.time_created
        )
        if not same_creation_time and overwrite is False:
            raise Exception(
                "An experiment already exists with the name "
                f"{new_sqa_experiment.name}. To overwrite, specify "
                "`overwrite = True` when calling `save_experiment`."
            )

        # The merge happens outside of session scope to avoid timeouts.
        # The stored object is detached from its session but still carries
        # a database identity marker, so the `session.add` below results in
        # an update rather than an insert.
        stored_experiment.update(new_sqa_experiment)
        to_persist = stored_experiment

    with session_scope() as session:
        session.add(to_persist)
Beispiel #28
0
def _save_generation_strategy(generation_strategy: GenerationStrategy,
                              encoder: Encoder) -> int:
    """Save generation strategy, using given Encoder instance.

    If the generation strategy already carries a `_db_id`, the stored object
    with that id is loaded, updated from the freshly encoded object and added
    back to the session; otherwise the freshly encoded object is added as-is.

    Args:
        generation_strategy: Generation strategy to persist. Its `_db_id` is
            set to the id of the saved SQLAlchemy object.
        encoder: Encoder used to convert the generation strategy to its
            SQLAlchemy counterpart and to resolve the SQLAlchemy class.

    Returns:
        The DB id of the saved generation strategy.

    Raises:
        ValueError: If the generation strategy has a `_db_id` but no stored
            generation strategy with that id can be found.
    """
    # If the generation strategy has not yet generated anything, there will be no
    # experiment set on it.
    experiment = generation_strategy._experiment
    if experiment is None:
        experiment_id = None
    else:
        # Experiment was set on the generation strategy, so need to check whether
        # it has been saved and create a relationship b/w GS and experiment if so.
        experiment_id = _get_experiment_id(experiment=experiment, encoder=encoder)

    gs_sqa = encoder.generation_strategy_to_sqa(
        generation_strategy=generation_strategy, experiment_id=experiment_id)

    if generation_strategy._db_id is not None:
        gs_sqa_class = encoder.config.class_to_sqa_class[GenerationStrategy]
        with session_scope() as session:
            existing_gs_sqa = session.query(gs_sqa_class).get(
                generation_strategy._db_id)

        # `Query.get` yields `None` when no row has this primary key; fail
        # with a clear message instead of an `AttributeError` further down.
        if existing_gs_sqa is None:
            raise ValueError(
                "Generation strategy has DB id "
                f"{generation_strategy._db_id}, but no generation strategy "
                "with that id was found in the database.")

        existing_gs_sqa.update(gs_sqa)
        # our update logic ignores foreign keys, i.e. fields ending in _id,
        # because we want SQLAlchemy to handle those relationships for us
        # however, generation_strategy.experiment_id is an exception, so we
        # need to update that manually
        existing_gs_sqa.experiment_id = gs_sqa.experiment_id
        gs_sqa = existing_gs_sqa

    with session_scope() as session:
        session.add(gs_sqa)
        session.flush()  # Ensures generation strategy id is set.

    generation_strategy._db_id = gs_sqa.id
    # pyre-fixme[7]: Expected `int` but got `Optional[int]`.
    return generation_strategy._db_id
Beispiel #29
0
 def test_sqa_storage(self):
     """Exercise AxClient persistence through the SQA store.

     Creates an experiment, completes five trials, reloads the experiment
     in a fresh client and checks that the generation strategy compares
     equal. Then checks that re-creating an experiment under the same name
     raises `ValueError` unless `overwrite_existing_experiment=True` is
     passed, in which case the resulting experiment has no trials.
     """

     def range_param(name, lower, upper):
         # Build a fresh parameter dict on every call so that no dict or
         # bounds list is shared between `create_experiment` calls.
         return {"name": name, "type": "range", "bounds": [lower, upper]}

     init_test_engine_and_session_factory(force_init=True)
     sqa_config = SQAConfig()
     sqa_encoder = Encoder(config=sqa_config)
     sqa_decoder = Decoder(config=sqa_config)
     settings = DBSettings(encoder=sqa_encoder, decoder=sqa_decoder)

     client = AxClient(db_settings=settings)
     client.create_experiment(
         name="test_experiment",
         parameters=[
             range_param("x1", -5.0, 10.0),
             range_param("x2", 0.0, 15.0),
         ],
         minimize=True,
     )
     for _ in range(5):
         parameterization, index = client.get_next_trial()
         outcome = branin(*parameterization.values())
         client.complete_trial(trial_index=index, raw_data=outcome)
     saved_strategy = client.generation_strategy

     # A brand new client must be able to restore the same state from the DB.
     client = AxClient(db_settings=settings)
     client.load_experiment_from_database("test_experiment")
     self.assertEqual(saved_strategy, client.generation_strategy)

     # Re-using the name without the overwrite flag is rejected.
     with self.assertRaises(ValueError):
         client.create_experiment(
             name="test_experiment",
             parameters=[
                 range_param("x1", -5.0, 10.0),
                 range_param("x2", 0.0, 15.0),
             ],
             minimize=True,
         )

     # With the overwrite flag, the stored experiment gets replaced.
     client.create_experiment(
         name="test_experiment",
         parameters=[range_param("x1", -5.0, 10.0)],
         overwrite_existing_experiment=True,
     )
     # The replacement is a fresh experiment, so it has no trials yet.
     self.assertEqual(len(client.experiment.trials), 0)
Beispiel #30
0
def _save_generation_strategy(generation_strategy: GenerationStrategy,
                              encoder: Encoder) -> int:
    """Save generation strategy, using given Encoder instance.

    If the generation strategy already carries a `_db_id`, the stored object
    with that id is loaded, updated from the freshly encoded object and added
    back to the session; otherwise the freshly encoded object is added as-is.

    Args:
        generation_strategy: Generation strategy to persist. If it has an
            experiment set, that experiment must already have a `db_id`.
        encoder: Encoder used to convert the generation strategy to its
            SQLAlchemy counterpart and to resolve the SQLAlchemy class.

    Returns:
        The `db_id` of the generation strategy after saving.

    Raises:
        ValueError: If the experiment set on the generation strategy has no
            `db_id`, or if the generation strategy has a `_db_id` but no
            stored generation strategy with that id can be found.
    """
    # If the generation strategy has not yet generated anything, there will be no
    # experiment set on it.
    experiment = generation_strategy._experiment
    if experiment is None:
        experiment_id = None
    else:
        # Experiment was set on the generation strategy, so need to check whether
        # it has been saved and create a relationship b/w GS and experiment if so.
        experiment_id = experiment.db_id
        if experiment_id is None:
            raise ValueError(  # pragma: no cover
                f"Experiment {experiment.name} should be saved before "
                "generation strategy.")

    gs_sqa, obj_to_sqa = encoder.generation_strategy_to_sqa(
        generation_strategy=generation_strategy, experiment_id=experiment_id)

    if generation_strategy._db_id is not None:
        gs_sqa_class = encoder.config.class_to_sqa_class[GenerationStrategy]
        with session_scope() as session:
            existing_gs_sqa = session.query(gs_sqa_class).get(
                generation_strategy._db_id)

        # `Query.get` yields `None` when no row has this primary key; fail
        # with a clear message instead of an `AttributeError` further down.
        if existing_gs_sqa is None:
            raise ValueError(
                "Generation strategy has DB id "
                f"{generation_strategy._db_id}, but no generation strategy "
                "with that id was found in the database.")

        existing_gs_sqa.update(gs_sqa)
        # our update logic ignores foreign keys, i.e. fields ending in _id,
        # because we want SQLAlchemy to handle those relationships for us
        # however, generation_strategy.experiment_id is an exception, so we
        # need to update that manually
        existing_gs_sqa.experiment_id = gs_sqa.experiment_id
        gs_sqa = existing_gs_sqa

    with session_scope() as session:
        session.add(gs_sqa)
        session.flush()  # Ensures generation strategy id is set.

    # NOTE(review): presumably copies the DB ids from the saved SQA objects
    # back onto the corresponding Ax objects -- confirm against `_set_db_ids`.
    _set_db_ids(obj_to_sqa=obj_to_sqa)

    return not_none(generation_strategy.db_id)