def test_save_updated_trial(self):
    """Re-saving an already-persisted trial updates its stored status.

    A trial is created on a copy of the experiment reloaded by name and first
    persisted as CANDIDATE via `_save_or_update_trials`. It is then marked
    RUNNING and re-saved through `_save_or_update_trial_in_db_if_possible`.
    Reloading must show exactly one trial, now RUNNING.
    """
    experiment, _ = self.init_experiment_and_generation_strategy(
        save_generation_strategy=False
    )
    db_settings = self.with_db_settings.db_settings
    reloaded = _load_experiment(experiment.name, decoder=db_settings.decoder)
    new_trial = reloaded.new_trial()

    # The trial lives on the reloaded copy but is saved against the original
    # experiment object; both refer to the same experiment name.
    _save_or_update_trials(
        experiment=experiment,
        trials=[new_trial],
        encoder=db_settings.encoder,
        decoder=db_settings.decoder,
    )
    self.assertEqual(new_trial.status, TrialStatus.CANDIDATE)

    new_trial.mark_running(True)
    self.assertTrue(
        self.with_db_settings._save_or_update_trial_in_db_if_possible(
            reloaded, new_trial
        )
    )

    reloaded = _load_experiment(experiment.name, decoder=db_settings.decoder)
    self.assertEqual(len(reloaded.trials), 1)
    self.assertEqual(reloaded.trials[0].status, TrialStatus.RUNNING)
def test_save_new_trial(self):
    """A freshly created trial is persisted with CANDIDATE status.

    The trial is created on a copy of the experiment reloaded by name and saved
    through `_save_new_trial_to_db_if_possible`. Reloading afterwards must
    yield exactly one trial, still CANDIDATE.
    """
    experiment, _ = self.init_experiment_and_generation_strategy(
        save_generation_strategy=False
    )
    decoder = self.with_db_settings.db_settings.decoder
    reloaded = _load_experiment(experiment.name, decoder=decoder)
    candidate = reloaded.new_trial()

    self.assertTrue(
        self.with_db_settings._save_new_trial_to_db_if_possible(reloaded, candidate)
    )

    reloaded = _load_experiment(experiment.name, decoder=decoder)
    self.assertEqual(len(reloaded.trials), 1)
    self.assertEqual(reloaded.trials[0].status, TrialStatus.CANDIDATE)
def test_save_experiment(self):
    """A random experiment survives a save/load round trip unchanged."""
    original = self.get_random_experiment()
    self.assertTrue(
        self.with_db_settings._save_experiment_to_db_if_possible(original)
    )

    round_tripped = _load_experiment(
        original.name, decoder=self.with_db_settings.db_settings.decoder
    )
    self.assertIsNotNone(round_tripped)
    self.assertEqual(original, round_tripped)
def test_update_experiment_properties_in_db(self):
    """Properties changed in memory are visible after reloading from the DB."""
    experiment, _ = self.init_experiment_and_generation_strategy(
        save_generation_strategy=False
    )
    experiment._properties["test_property"] = True
    self.with_db_settings._update_experiment_properties_in_db(
        experiment_with_updated_properties=experiment
    )

    reloaded = _load_experiment(
        experiment.name, decoder=self.with_db_settings.db_settings.decoder
    )
    expected_properties = {"test_property": True}
    self.assertEqual(reloaded._properties, expected_properties)
def test_updated_trials_mini_batch(self):
    """Saving/updating trials in mini-batches persists every trial's status.

    First a single new trial is saved. Then five more new trials plus the
    first one (now marked RUNNING) are saved together, so the batch size is
    not a multiple of the mini-batch size. The persisted state is verified
    by reloading the experiment from the DB.
    """
    experiment, _ = self.init_experiment_and_generation_strategy(
        save_generation_strategy=False
    )
    decoder = self.with_db_settings.db_settings.decoder

    # Check with 1 trial.
    trial = experiment.new_trial()
    self.assertIsNone(trial.db_id)
    self.with_db_settings._save_or_update_trials_in_db_if_possible(
        experiment=experiment,
        trials=[trial],
    )
    loaded_experiment = _load_experiment(experiment.name, decoder=decoder)
    self.assertEqual(
        loaded_experiment.trials.get(trial.index).status, TrialStatus.CANDIDATE
    )
    self.assertIsNotNone(trial.db_id)

    # Check with multiple trials, where their number % mini batch size is not 0.
    trials = [experiment.new_trial() for _ in range(5)]
    for t in trials:
        self.assertIsNone(t.db_id)
    trial.mark_running(no_runner_required=True)
    trials.append(trial)
    self.with_db_settings._save_or_update_trials_in_db_if_possible(
        experiment=experiment,
        trials=trials,
    )
    loaded_experiment = _load_experiment(experiment.name, decoder=decoder)

    # Verify against the *reloaded* experiment: the in-memory trials always
    # carry the status we just set, so asserting on them proves nothing about
    # what was written. `trials` holds every trial of the experiment.
    self.assertEqual(len(loaded_experiment.trials), len(trials))
    # All trials except for the one we marked as running should be candidates.
    for t in trials:
        self.assertIsNotNone(t.db_id)
        expected_status = (
            TrialStatus.RUNNING if t.index == trial.index else TrialStatus.CANDIDATE
        )
        self.assertEqual(loaded_experiment.trials[t.index].status, expected_status)
def load_experiment(name: str, db_settings: DBSettings) -> Experiment:
    """Fetch the experiment called `name` from the database.

    The Service API only supports `Experiment`. Any other loaded object,
    including an experiment flagged as a simple experiment, is rejected.

    Args:
        name: Experiment name.
        db_settings: Defines behavior for loading/saving experiment to/from db.

    Returns:
        ax.core.Experiment: Loaded experiment.

    Raises:
        ValueError: If the loaded object is not an `Experiment` or is a simple
            experiment.
    """
    init_engine_and_session_factory(creator=db_settings.creator, url=db_settings.url)
    loaded = _load_experiment(name, decoder=db_settings.decoder)
    is_supported = isinstance(loaded, Experiment) and not loaded.is_simple_experiment
    if not is_supported:
        raise ValueError("Service API only supports Experiment")
    return loaded
def load_experiment(name: str, db_settings: DBSettings) -> Experiment:
    """Read the experiment named `name` out of the database.

    Args:
        name: Experiment name.
        db_settings: Specifies decoder and xdb tier where experiment is stored.

    Returns:
        ax.core.Experiment: Loaded experiment.

    Raises:
        ValueError: If the stored object is not an `Experiment`, which is the
            only type the Service API supports.
    """
    initialize_db(db_settings)
    loaded = _load_experiment(name, db_settings.decoder)
    if isinstance(loaded, Experiment):
        return loaded
    raise ValueError("Service API only supports Experiment")
def _load_experiment_and_generation_strategy(
    self, experiment_name: str
) -> Tuple[Optional[Experiment], Optional[GenerationStrategy]]:
    """Loads experiment and its corresponding generation strategy from database
    if DB settings are set on this `WithDBSettingsBase` instance.

    Args:
        experiment_name: Name of the experiment to load, used as unique
            identifier by which to find the experiment.

    Returns:
        - Tuple of `Experiment` and `None` if experiment exists but does not
          have a generation strategy attached to it.
        - Tuple of `Experiment` and `GenerationStrategy` if experiment exists
          and has a generation strategy attached to it.

    Raises:
        ValueError: If DB settings are not set; if the loaded object is not an
            `Experiment` or is a simple experiment; or if loading the
            generation strategy fails with a `ValueError` other than the
            "does not have a generation strategy" case.

    NOTE(review): the return annotation still allows a `None` experiment, but
    this method never returns one. If `_load_experiment` returns `None`, the
    `isinstance` check raises `ValueError`; if it raises, that propagates.
    The annotation is kept for backward compatibility.
    """
    if not self.db_settings_set:
        raise ValueError("Cannot load from DB in absence of DB settings.")

    start_time = time.time()
    experiment = _load_experiment(experiment_name, decoder=self.db_settings.decoder)
    if not isinstance(experiment, Experiment) or experiment.is_simple_experiment:
        raise ValueError("Service API only supports `Experiment`.")
    logger.debug(
        f"Loaded experiment {experiment_name} in "
        f"{_round_floats_for_logging(time.time() - start_time)} seconds."
    )

    start_time = time.time()
    # Keep the `try` narrow: only the generation-strategy load is expected to
    # signal "no generation strategy" via `ValueError`.
    try:
        generation_strategy = _load_generation_strategy_by_experiment_name(
            experiment_name=experiment_name, decoder=self.db_settings.decoder
        )
    except ValueError as err:
        if "does not have a generation strategy" in str(err):
            return experiment, None
        raise  # `ValueError` here could signify more than just absence of GS.
    logger.debug(
        f"Loaded generation strategy for experiment {experiment_name} in "
        f"{_round_floats_for_logging(time.time() - start_time)} seconds."
    )
    return experiment, generation_strategy
def test_update_reduced_state_generator_runs(self):
    """Saving with `reduce_state_generator_runs=True` strips older GR state.

    Five trials each receive a generator run from the generation strategy, and
    everything is saved with reduced state enabled. After reloading, only the
    last generator run, both on the experiment's trials and on the generation
    strategy, may still carry the large model attributes.
    """
    experiment, generation_strategy = self.init_experiment_and_generation_strategy(
        save_generation_strategy=True
    )
    trials = [experiment.new_trial() for _ in range(5)]
    generator_runs = []
    for new_trial in trials:
        generator_run = generation_strategy.gen(experiment)
        generator_runs.append(generator_run)
        new_trial.add_generator_run(generator_run)

    self.with_db_settings._save_or_update_trials_and_generation_strategy_if_possible(  # noqa E501
        experiment=experiment,
        trials=trials,
        generation_strategy=generation_strategy,
        new_generator_runs=generator_runs,
        reduce_state_generator_runs=True,
    )

    decoder = self.with_db_settings.db_settings.decoder
    large_attr_names = [f"_{attr.key}" for attr in GR_LARGE_MODEL_ATTRS]

    # Only the last trial's generator run should have large model attributes.
    loaded_experiment = _load_experiment(experiment.name, decoder=decoder)
    last_trial_idx = len(loaded_experiment.trials) - 1
    for idx, loaded_trial in loaded_experiment.trials.items():
        for attr_name in large_attr_names:
            value = getattr(loaded_trial.generator_run, attr_name)
            if idx < last_trial_idx:
                self.assertIsNone(value)
            else:
                self.assertIsNotNone(value)

    # Only the last generator run should have large model attributes.
    loaded_gs = _load_generation_strategy_by_experiment_name(
        experiment.name, decoder=decoder
    )
    loaded_grs = loaded_gs._generator_runs
    last_gr_idx = len(loaded_grs) - 1
    for idx, loaded_gr in enumerate(loaded_grs):
        for attr_name in large_attr_names:
            value = getattr(loaded_gr, attr_name)
            if idx < last_gr_idx:
                self.assertIsNone(value)
            else:
                self.assertIsNotNone(value)
def load_experiment(name: str, db_settings: DBSettings) -> Experiment:
    """Fetch the experiment called `name` from the database, timing the load.

    The Service API only supports `Experiment`. Other loaded objects, including
    experiments flagged as simple experiments, are rejected. The time spent in
    `_load_experiment` is logged at debug level.

    Args:
        name: Experiment name.
        db_settings: Defines behavior for loading/saving experiment to/from db.

    Returns:
        ax.core.Experiment: Loaded experiment.

    Raises:
        ValueError: If the loaded object is not an `Experiment` or is a simple
            experiment.
    """
    init_engine_and_session_factory(creator=db_settings.creator, url=db_settings.url)
    load_started_at = time.time()
    loaded = _load_experiment(name, decoder=db_settings.decoder)
    if isinstance(loaded, Experiment) and not loaded.is_simple_experiment:
        elapsed = _round_floats_for_logging(time.time() - load_started_at)
        logger.debug(f"Loaded experiment {name} in {elapsed} seconds.")
        return loaded
    raise ValueError("Service API only supports `Experiment`.")
def _load_experiment_and_generation_strategy(
    self,
    experiment_name: str,
    reduced_state: bool = False,
) -> Tuple[Optional[Experiment], Optional[GenerationStrategy]]:
    """Loads experiment and its corresponding generation strategy from database
    if DB settings are set on this `WithDBSettingsBase` instance.

    Args:
        experiment_name: Name of the experiment to load, used as unique
            identifier by which to find the experiment.
        reduced_state: Whether to load experiment and generation strategy with
            a slightly reduced state (without abandoned arms on experiment and
            model state on each generator run in experiment and generation
            strategy; last generator run on generation strategy will still
            have its model state).

    Returns:
        - Tuple of `Experiment` and `None` if experiment exists but does not
          have a generation strategy attached to it.
        - Tuple of `Experiment` and `GenerationStrategy` if experiment exists
          and has a generation strategy attached to it.

    Raises:
        ValueError: If DB settings are not set; if the loaded object is not an
            `Experiment`; or if loading the generation strategy fails with a
            `ValueError` other than the "does not have a generation strategy"
            case.

    NOTE(review): the return annotation still allows a `None` experiment, but
    this method never returns one. If `_load_experiment` returns `None`, the
    `isinstance` check raises `ValueError`; if it raises, that propagates.
    The annotation is kept for backward compatibility.
    """
    if not self.db_settings_set:
        raise ValueError("Cannot load from DB in absence of DB settings.")
    logger.info(
        "Loading experiment and generation strategy (with reduced state: "
        f"{reduced_state})..."
    )

    start_time = time.time()
    experiment = _load_experiment(
        experiment_name,
        decoder=self.db_settings.decoder,
        reduced_state=reduced_state,
        load_trials_in_batches_of_size=LOADING_MINI_BATCH_SIZE,
    )
    if not isinstance(experiment, Experiment):
        raise ValueError("Service API only supports `Experiment`.")
    logger.info(
        f"Loaded experiment {experiment_name} in "
        f"{_round_floats_for_logging(time.time() - start_time)} seconds, "
        f"loading trials in mini-batches of {LOADING_MINI_BATCH_SIZE}."
    )

    start_time = time.time()
    # Keep the `try` narrow: only the generation-strategy load is expected to
    # signal "no generation strategy" via `ValueError`.
    try:
        generation_strategy = _load_generation_strategy_by_experiment_name(
            experiment_name=experiment_name,
            decoder=self.db_settings.decoder,
            experiment=experiment,
            reduced_state=reduced_state,
        )
    except ValueError as err:
        if "does not have a generation strategy" in str(err):
            return experiment, None
        raise  # `ValueError` here could signify more than just absence of GS.
    logger.info(
        f"Loaded generation strategy for experiment {experiment_name} in "
        f"{_round_floats_for_logging(time.time() - start_time)} seconds."
    )
    return experiment, generation_strategy