def test_invalid_heartbeat_interval_and_grace_period(storage_mode: str) -> None:
    """Negative ``heartbeat_interval`` or ``grace_period`` must raise ``ValueError``."""
    for bad_kwargs in ({"heartbeat_interval": -1}, {"grace_period": -1}):
        with pytest.raises(ValueError):
            with StorageSupplier(storage_mode, **bad_kwargs):
                pass
def test_run_trial_exception(storage_mode: str) -> None:
    """``_run_trial`` lets the objective's ``ValueError`` escape when it is not caught.

    Checked both with an empty ``catch`` tuple and with one that lists only an
    unrelated exception type (``ArithmeticError``).
    """
    for catch in ((), (ArithmeticError,)):
        with StorageSupplier(storage_mode) as storage:
            study = create_study(storage=storage)
            with pytest.raises(ValueError):
                _optimize._run_trial(study, fail_objective, catch)
def test_pytorch_lightning_pruning_callback_ddp_monitor(
    storage_mode: str,
) -> None:
    """The Lightning pruning callback reports ``accuracy`` under DDP and obeys the pruner."""

    def objective(trial: optuna.trial.Trial) -> float:
        pruning_callback = PyTorchLightningPruningCallback(trial, monitor="accuracy")
        trainer = pl.Trainer(
            max_epochs=2,
            accelerator="ddp_cpu",
            num_processes=2,
            enable_checkpointing=False,
            callbacks=[pruning_callback],
        )
        trainer.fit(ModelDDP())
        return 1.0

    with StorageSupplier(storage_mode) as storage:
        # Always-prune pruner: only step 0 gets reported before the trial is pruned.
        pruned_study = optuna.create_study(storage=storage, pruner=DeterministicPruner(True))
        pruned_study.optimize(objective, n_trials=1)
        pruned_trial = pruned_study.trials[0]
        assert pruned_trial.state == optuna.trial.TrialState.PRUNED
        assert list(pruned_trial.intermediate_values.keys()) == [0]
        np.testing.assert_almost_equal(pruned_trial.intermediate_values[0], 0.45)

        # Never-prune pruner: both epochs report and the objective's value is kept.
        completed_study = optuna.create_study(storage=storage, pruner=DeterministicPruner(False))
        completed_study.optimize(objective, n_trials=1)
        completed_trial = completed_study.trials[0]
        assert completed_trial.state == optuna.trial.TrialState.COMPLETE
        assert completed_trial.value == 1.0
        assert list(completed_trial.intermediate_values.keys()) == [0, 1]
        for step in (0, 1):
            np.testing.assert_almost_equal(completed_trial.intermediate_values[step], 0.45)
def test_run_trial(storage_mode: str, caplog: LogCaptureFixture) -> None:
    """Completed trials store the objective value and log it, including +/- infinity."""
    cases = [
        (lambda _: 1.0, 1.0, "Trial 0 finished with value: 1.0 and parameters"),
        (lambda _: float("inf"), float("inf"), "Trial 1 finished with value: inf and parameters"),
        (
            lambda _: -float("inf"),
            -float("inf"),
            "Trial 2 finished with value: -inf and parameters",
        ),
    ]
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        for objective, expected_value, expected_log in cases:
            caplog.clear()
            frozen_trial = _optimize._run_trial(study, objective, catch=())
            assert frozen_trial.state == TrialState.COMPLETE
            assert frozen_trial.value == expected_value
            assert expected_log in caplog.text
def test_get_param_importance_target_is_none_and_study_is_multi_obj(
    storage_mode: str,
    evaluator_init_func: Callable[[], BaseImportanceEvaluator],
) -> None:
    """Without a ``target``, importances of a multi-objective study raise ``ValueError``."""

    def objective(trial: Trial) -> Tuple[float, float]:
        x1 = trial.suggest_float("x1", 0.1, 3)
        x2 = trial.suggest_float("x2", 0.1, 3, log=True)
        x3 = trial.suggest_float("x3", 0, 3, step=1)
        x4 = trial.suggest_int("x4", -3, 3)
        x5 = trial.suggest_int("x5", 1, 5, log=True)
        x6 = trial.suggest_categorical("x6", [1.0, 1.1, 1.2])
        # Conditional parameters are ignored unless `params` is specified and is not `None`.
        x7 = trial.suggest_float("x7", 0.1, 3) if trial.number % 2 == 0 else 0.0
        assert isinstance(x6, float)
        return x1**4 + x2 + x3 - x4**2 - x5 + x6 + x7, 0.0

    with StorageSupplier(storage_mode) as storage:
        study = create_study(directions=["minimize", "minimize"], storage=storage)
        study.optimize(objective, n_trials=3)

        with pytest.raises(ValueError):
            get_param_importances(study, evaluator=evaluator_init_func())
def test_trials_dataframe_with_failure(storage_mode: str) -> None:
    """Failed trials appear in ``trials_dataframe`` with params, user attrs and timings."""

    def f(trial: Trial) -> float:
        # Record two params and a user attribute, then always fail.
        trial.suggest_int("x", 1, 1)
        trial.suggest_categorical("y", (2.5,))
        trial.set_user_attr("train_loss", 3)
        raise ValueError()

    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        study.optimize(f, n_trials=3, catch=(ValueError,))
        df = study.trials_dataframe()
        # Index rows by trial number so each column can be looked up as df.<col>[number].
        df.set_index("number", inplace=True, drop=False)
        assert len(df) == 3
        # non-nested: 6, params: 2, user_attrs: 1, system_attrs: 0
        assert len(df.columns) == 9
        for number in range(3):
            assert df.number[number] == number
            assert df.state[number] == "FAIL"
            assert df.value[number] is None
            assert isinstance(df.datetime_start[number], pd.Timestamp)
            assert isinstance(df.datetime_complete[number], pd.Timestamp)
            assert isinstance(df.duration[number], pd.Timedelta)
            assert df.params_x[number] == 1
            assert df.params_y[number] == 2.5
            assert df.user_attrs_train_loss[number] == 3
def test_get_n_trials_state_option(storage_mode: str) -> None:
    """``get_n_trials`` with a state filter counts only the trials in that state."""
    with StorageSupplier(storage_mode) as storage:
        study_id = storage.create_new_study()
        storage.set_study_directions(study_id, (StudyDirection.MAXIMIZE,))

        rng = random.Random(51)
        for state in (TrialState.COMPLETE, TrialState.COMPLETE, TrialState.PRUNED):
            template = _generate_trial(rng)
            template.state = state
            storage.create_new_trial(study_id, template_trial=template)

        assert storage.get_n_trials(study_id, TrialState.COMPLETE) == 2
        assert storage.get_n_trials(study_id, TrialState.PRUNED) == 1

        # Every remaining state must have no trials.
        populated_states = {TrialState.COMPLETE, TrialState.PRUNED}
        for state in ALL_STATES:
            if state not in populated_states:
                assert storage.get_n_trials(study_id, state) == 0
def test_set_and_get_study_system_attrs(storage_mode: str) -> None:
    """Study system attrs round-trip through storage; unknown study ids raise ``KeyError``."""
    with StorageSupplier(storage_mode) as storage:
        study_id = storage.create_new_study()

        def set_then_expect(key: str, value: Any) -> None:
            storage.set_study_system_attr(study_id, key, value)
            assert storage.get_study_system_attrs(study_id)[key] == value

        # Store every example attribute, then compare the full mapping.
        for key, value in EXAMPLE_ATTRS.items():
            set_then_expect(key, value)
        assert storage.get_study_system_attrs(study_id) == EXAMPLE_ATTRS

        # Write "dataset" again with a new value (overwrite path).
        set_then_expect("dataset", "ImageNet")

        # Both the getter and the setter must reject a study id that does not exist.
        missing_study_id = study_id + 1
        with pytest.raises(KeyError):
            storage.get_study_system_attrs(missing_study_id)
        with pytest.raises(KeyError):
            storage.set_study_system_attr(missing_study_id, "key", "value")
def test_set_trial_user_attr(storage_mode: str) -> None:
    """Trial user attrs can be set and overwritten, but not on missing or finished trials."""
    with StorageSupplier(storage_mode) as storage:
        first_trial = storage.create_new_trial(storage.create_new_study())

        def set_then_expect(trial_id: int, key: str, value: Any) -> None:
            storage.set_trial_user_attr(trial_id, key, value)
            assert storage.get_trial(trial_id).user_attrs[key] == value

        # Store every example attribute, then compare the full mapping.
        for key, value in EXAMPLE_ATTRS.items():
            set_then_expect(first_trial, key, value)
        assert storage.get_trial(first_trial).user_attrs == EXAMPLE_ATTRS

        # Write "dataset" again with a new value (overwrite path).
        set_then_expect(first_trial, "dataset", "ImageNet")

        # A trial in another study keeps its own, independent attrs.
        second_trial = storage.create_new_trial(storage.create_new_study())
        set_then_expect(second_trial, "baseline_score", 0.001)
        second_attrs = storage.get_trial(second_trial).user_attrs
        assert len(second_attrs) == 1
        assert second_attrs["baseline_score"] == 0.001

        # Unknown trial ids are rejected.
        missing_trial = max(first_trial, second_trial) + 1
        with pytest.raises(KeyError):
            storage.set_trial_user_attr(missing_trial, "key", "value")

        # Once a trial is COMPLETE its user attrs can no longer be written.
        storage.set_trial_state_values(first_trial, state=TrialState.COMPLETE)
        with pytest.raises(RuntimeError):
            storage.set_trial_user_attr(first_trial, "key", "value")
def test_get_trial_param_and_get_trial_params(storage_mode: str) -> None:
    """Stored params match the generated trials; bad trial ids or keys raise ``KeyError``."""
    with StorageSupplier(storage_mode) as storage:
        _, study_to_trials = _setup_studies(storage, n_study=2, n_trial=5, seed=1)

        for trials in study_to_trials.values():
            for trial_id, expected in trials.items():
                assert storage.get_trial_params(trial_id) == expected.params
                # ``get_trial_param`` returns the internal representation of a value.
                for name, external_value in expected.params.items():
                    internal_value = expected.distributions[name].to_internal_repr(
                        external_value
                    )
                    assert storage.get_trial_param(trial_id, name) == internal_value

        all_trial_ids = [tid for trials in study_to_trials.values() for tid in trials]
        missing_trial_id = max(all_trial_ids) + 1
        with pytest.raises(KeyError):
            storage.get_trial_params(missing_trial_id)
        with pytest.raises(KeyError):
            storage.get_trial_param(missing_trial_id, "paramA")

        # The largest existing trial id, queried for a key it never had.
        existing_trial_id = missing_trial_id - 1
        with pytest.raises(KeyError):
            storage.get_trial_param(existing_trial_id, "dummy-key")
def test_failed_trial_callback(storage_mode: str) -> None:
    """The failed-trial callback runs once for a stale trial and sees study and trial attrs."""
    heartbeat_interval = 1
    grace_period = 2

    def check_attrs(study: Study, trial: FrozenTrial) -> None:
        assert study.system_attrs["test"] == "A"
        assert trial.system_attrs["test"] == "B"

    # Wrapping in a Mock keeps the assertions and also records calls.
    callback_spy = Mock(wraps=check_attrs)

    with StorageSupplier(
        storage_mode,
        heartbeat_interval=heartbeat_interval,
        grace_period=grace_period,
        failed_trial_callback=callback_spy,
    ) as storage:
        assert is_heartbeat_enabled(storage)
        assert isinstance(storage, BaseHeartbeat)

        study = optuna.create_study(storage=storage)
        study.set_system_attr("test", "A")
        with pytest.warns(UserWarning):
            trial = study.ask()
        trial.set_system_attr("test", "B")
        storage.record_heartbeat(trial._trial_id)
        # Wait past the grace period so the heartbeat goes stale.
        time.sleep(grace_period + 1)

        # Exceptions raised in spawned threads are caught by `_TestableThread`.
        with patch("optuna.storages._heartbeat.Thread", _TestableThread):
            study.optimize(lambda _: 1.0, n_trials=1)
        callback_spy.assert_called_once()
def test_fail_stale_trials(storage_mode: str, grace_period: Optional[int]) -> None:
    """``fail_stale_trials`` moves a trial whose heartbeat expired from RUNNING to FAIL."""
    heartbeat_interval = 1
    if grace_period is None:
        # With no explicit grace period, wait as if it were twice the heartbeat interval.
        effective_grace_period = heartbeat_interval * 2
    else:
        effective_grace_period = grace_period

    def failed_trial_callback(study: "optuna.Study", trial: FrozenTrial) -> None:
        assert study.system_attrs["test"] == "A"
        assert trial.system_attrs["test"] == "B"

    with StorageSupplier(storage_mode) as storage:
        assert isinstance(storage, (RDBStorage, RedisStorage))
        # The raw (possibly None) grace period is handed to the storage on purpose.
        storage.heartbeat_interval = heartbeat_interval
        storage.grace_period = grace_period
        storage.failed_trial_callback = failed_trial_callback

        study = optuna.create_study(storage=storage)
        study.set_system_attr("test", "A")
        with pytest.warns(UserWarning):
            trial = study.ask()
        trial.set_system_attr("test", "B")
        storage.record_heartbeat(trial._trial_id)
        time.sleep(effective_grace_period + 1)

        assert study.trials[0].state is TrialState.RUNNING
        optuna.storages.fail_stale_trials(study)
        assert study.trials[0].state is TrialState.FAIL  # type: ignore [comparison-overlap]
def test_run_trial_pruned(storage_mode: str, caplog: LogCaptureFixture) -> None:
    """Pruned trials are logged; their value is the reported intermediate, or None.

    No report and a NaN report both leave the trial value as ``None``.
    """

    def gen_func(intermediate: Optional[float] = None) -> Callable[[Trial], float]:
        def func(trial: Trial) -> float:
            if intermediate is not None:
                trial.report(step=1, value=intermediate)
            raise TrialPruned

        return func

    expectations = [
        (gen_func(), None, "Trial 0 pruned."),
        (gen_func(intermediate=1), 1, "Trial 1 pruned."),
        (gen_func(intermediate=float("nan")), None, "Trial 2 pruned."),
    ]

    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        for objective, expected_value, expected_log in expectations:
            caplog.clear()
            frozen_trial = _optimize._run_trial(study, objective, catch=())
            assert frozen_trial.state == TrialState.PRUNED
            if expected_value is None:
                assert frozen_trial.value is None
            else:
                assert frozen_trial.value == expected_value
            assert expected_log in caplog.text
def test_fail_stale_trials_with_optimize(storage_mode: str) -> None:
    """Optimizing one study fails only its own stale trial; another study's stays RUNNING."""
    heartbeat_interval = 1
    grace_period = 2

    with StorageSupplier(
        storage_mode, heartbeat_interval=heartbeat_interval, grace_period=grace_period
    ) as storage:
        assert is_heartbeat_enabled(storage)
        assert isinstance(storage, BaseHeartbeat)

        study1 = optuna.create_study(storage=storage)
        study2 = optuna.create_study(storage=storage)

        with pytest.warns(UserWarning):
            trial1 = study1.ask()
            trial2 = study2.ask()
        for trial in (trial1, trial2):
            storage.record_heartbeat(trial._trial_id)
        # Wait past the grace period so both heartbeats go stale.
        time.sleep(grace_period + 1)

        for study in (study1, study2):
            assert study.trials[0].state is TrialState.RUNNING

        # Exceptions raised in spawned threads are caught by `_TestableThread`.
        with patch("optuna.storages._heartbeat.Thread", _TestableThread):
            study1.optimize(lambda _: 1.0, n_trials=1)

        assert study1.trials[0].state is TrialState.FAIL  # type: ignore [comparison-overlap]
        assert study2.trials[0].state is TrialState.RUNNING
def test_check_distribution_suggest_float(storage_mode: str) -> None:
    """Float suggest aliases agree with ``suggest_float``; incompatible re-suggests raise.

    ``suggest_uniform``, ``suggest_loguniform`` and ``suggest_discrete_uniform`` must
    return the value already drawn for the same name by the equivalent ``suggest_float``
    call. ``step`` combined with ``log`` and re-suggesting a float name as an int raise
    ``ValueError``.
    """
    sampler = samplers.RandomSampler()
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage, sampler=sampler)
        trial = Trial(study, study._storage.create_new_trial(study._study_id))

        assert trial.suggest_float("x1", 10, 20) == trial.suggest_uniform("x1", 10, 20)
        assert trial.suggest_float("x2", 1e-5, 1e-3, log=True) == trial.suggest_loguniform(
            "x2", 1e-5, 1e-3
        )
        assert trial.suggest_float("x3", 10, 20, step=1.0) == trial.suggest_discrete_uniform(
            "x3", 10, 20, 1.0
        )

        with pytest.raises(ValueError):
            trial.suggest_float("x4", 1e-5, 1e-2, step=1e-5, log=True)

        with pytest.raises(ValueError):
            trial.suggest_int("x1", 10, 20)

        # Even a fresh trial of the same study may not reuse "x1" as an int.
        trial = Trial(study, study._storage.create_new_trial(study._study_id))
        with pytest.raises(ValueError):
            trial.suggest_int("x1", 10, 20)
def test_suggest_loguniform(storage_mode: str) -> None:
    """Log-scale float distributions validate bounds; a param is sampled once per name."""
    # low > high and low == 0 are both invalid for a log distribution.
    for low, high in ((1.0, 0.9), (0.0, 0.9)):
        with pytest.raises(ValueError):
            FloatDistribution(low=low, high=high, log=True)

    sample_mock = Mock(side_effect=[1.0, 2.0])
    sampler = samplers.RandomSampler()

    with patch.object(sampler, "sample_independent", sample_mock) as patched:
        with StorageSupplier(storage_mode) as storage:
            study = create_study(storage=storage, sampler=sampler)
            trial = Trial(study, study._storage.create_new_trial(study._study_id))
            distribution = FloatDistribution(low=0.1, high=4.0, log=True)

            # First suggestion of "x" consumes the first mocked sample.
            assert trial._suggest("x", distribution) == 1.0
            # Asking for "x" again reuses the stored value instead of sampling.
            assert trial._suggest("x", distribution) == 1.0
            # A new name consumes the next mocked sample.
            assert trial._suggest("y", distribution) == 2.0

            assert trial.params == {"x": 1.0, "y": 2.0}
            assert patched.call_count == 2
def test_updates_properties(storage_mode: str) -> None:
    """Check for any distributed deadlock following a property read.

    Rank 0 creates the study and the real trial, and other ranks wrap ``None``. After
    suggesting a few params, ranks 0 and 1 take turns reading every ``property`` of
    ``TorchDistributedTrial``, with a barrier after each turn. Presumably, a property
    read that needed the other rank's participation would hang here.
    """
    with StorageSupplier(storage_mode) as storage:
        if dist.get_rank() == 0:  # type: ignore
            study = optuna.create_study(storage=storage)
            trial = TorchDistributedTrial(study.ask())
        else:
            trial = TorchDistributedTrial(None)

        trial.suggest_float("f", 0, 1)
        trial.suggest_int("i", 0, 1)
        trial.suggest_categorical("c", ("a", "b", "c"))

        property_names = [
            p
            for p in dir(TorchDistributedTrial)
            if isinstance(getattr(TorchDistributedTrial, p), property)
        ]

        # Rank 0 reads while rank 1 waits at the barrier, then the roles swap.
        for reading_rank in (0, 1):
            if dist.get_rank() == reading_rank:  # type: ignore
                # A plain loop: the reads are performed only for their side effects.
                for name in property_names:
                    getattr(trial, name)
            dist.barrier()  # type: ignore
def test_get_param_importances_with_params(
    storage_mode: str,
    params: List[str],
    evaluator_init_func: Callable[[], BaseImportanceEvaluator],
) -> None:
    """Restricting ``params`` yields importances for exactly those params, all finite and >= 0."""

    def objective(trial: Trial) -> float:
        x1 = trial.suggest_float("x1", 0.1, 3)
        x2 = trial.suggest_float("x2", 0.1, 3, log=True)
        x3 = trial.suggest_float("x3", 0, 3, step=1)
        # "x4" is a conditional parameter, only present on even-numbered trials.
        x4 = trial.suggest_float("x4", 0.1, 3) if trial.number % 2 == 0 else 0.0
        return x1**4 + x2 + x3 + x4

    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        study.optimize(objective, n_trials=10)

        importances = get_param_importances(
            study, evaluator=evaluator_init_func(), params=params
        )

        assert isinstance(importances, OrderedDict)
        assert len(importances) == len(params)
        assert set(params).issubset(importances)
        for name, importance in importances.items():
            assert isinstance(name, str)
            assert isinstance(importance, float)

        # Sanity check for param importances
        assert all(0 <= x < float("inf") for x in importances.values())
def test_get_trial_id_from_study_id_trial_number(storage_mode: str) -> None:
    """(study_id, trial_number) maps to the right trial id and raises when absent."""
    with StorageSupplier(storage_mode) as storage:
        # No study exists yet.
        with pytest.raises(KeyError):
            storage.get_trial_id_from_study_id_trial_number(study_id=0, trial_number=0)

        study_id = storage.create_new_study()

        # The study exists but has no trial with number 0.
        with pytest.raises(KeyError):
            storage.get_trial_id_from_study_id_trial_number(study_id, trial_number=0)

        trial_id = storage.create_new_trial(study_id)
        found = storage.get_trial_id_from_study_id_trial_number(study_id, trial_number=0)
        assert found == trial_id

        # Trial ids are unique across the whole storage, but trial numbers only within a
        # study, so number 0 of a second study must map to that study's own trial.
        other_study_id = storage.create_new_study()
        other_trial_id = storage.create_new_trial(other_study_id)
        found = storage.get_trial_id_from_study_id_trial_number(other_study_id, trial_number=0)
        assert found == other_trial_id
def test_set_and_get_study_directions(storage_mode: str) -> None:
    """Study directions can be set once and re-set identically, but never changed or unset.

    For each ``(target, opposite)`` pair a fresh study is created. Unknown study ids
    must raise ``KeyError`` from both the getter and the setter, and the study-id check
    must take precedence over direction validation.
    """
    with StorageSupplier(storage_mode) as storage:

        def check_set_and_get(study_id: int, directions: Sequence[StudyDirection]) -> None:
            # ``study_id`` is passed explicitly rather than closed over from the loop below.
            storage.set_study_directions(study_id, directions)
            got_directions = storage.get_study_directions(study_id)

            # The stored directions are compared against a list, so say "list".
            assert got_directions == list(
                directions
            ), "Direction of a study should be a list of `StudyDirection` objects."

        for target, opposite in [
            ((StudyDirection.MINIMIZE,), (StudyDirection.MAXIMIZE,)),
            ((StudyDirection.MAXIMIZE,), (StudyDirection.MINIMIZE,)),
            (
                (StudyDirection.MINIMIZE, StudyDirection.MAXIMIZE),
                (StudyDirection.MAXIMIZE, StudyDirection.MINIMIZE),
            ),
            (
                [StudyDirection.MINIMIZE, StudyDirection.MAXIMIZE],
                [StudyDirection.MAXIMIZE, StudyDirection.MINIMIZE],
            ),
        ]:
            study_id = storage.create_new_study()

            # A new study starts with a single NOT_SET direction.
            directions = storage.get_study_directions(study_id)
            assert len(directions) == 1
            assert directions[0] == StudyDirection.NOT_SET

            # Test setting value.
            check_set_and_get(study_id, target)

            # Test overwriting value to the same direction.
            storage.set_study_directions(study_id, target)

            # Test overwriting value to the opposite direction.
            with pytest.raises(ValueError):
                storage.set_study_directions(study_id, opposite)

            # Test overwriting value to the not set.
            with pytest.raises(ValueError):
                storage.set_study_directions(study_id, (StudyDirection.NOT_SET,))

            # Test non-existent study.
            non_existent_study_id = study_id + 1

            with pytest.raises(KeyError):
                storage.get_study_directions(non_existent_study_id)

            # Test non-existent study.
            with pytest.raises(KeyError):
                storage.set_study_directions(non_existent_study_id, opposite)

            # Test non-existent study is checked before directions.
            with pytest.raises(KeyError):
                storage.set_study_directions(non_existent_study_id, (StudyDirection.NOT_SET,))
def test_run_trial_catch_exception(storage_mode: str) -> None:
    """An exception listed in ``catch`` marks the trial FAIL instead of propagating."""
    with StorageSupplier(storage_mode) as storage:
        frozen_trial = _optimize._run_trial(
            create_study(storage=storage), fail_objective, catch=(ValueError,)
        )
        assert frozen_trial.state == TrialState.FAIL
        # The study-tell warning key must not be recorded on the failed trial.
        assert STUDY_TELL_WARNING_KEY not in frozen_trial.system_attrs
def test_suggest_int_log_invalid_range(storage_mode: str) -> None:
    """``suggest_int(log=True)`` rejects a non-integer ``low`` and ``step=2``.

    Each case runs against its own fresh storage.
    """
    sampler = samplers.RandomSampler()

    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage, sampler=sampler)
        trial_id = study._storage.create_new_trial(study._study_id)
        trial = Trial(study, trial_id)
        with warnings.catch_warnings():
            # A UserWarning is raised because [0.5, 10] is not divisible by 1;
            # silence it so only the ValueError is checked.
            warnings.simplefilter("ignore", category=UserWarning)
            with pytest.raises(ValueError):
                trial.suggest_int("z", 0.5, 10, log=True)  # type: ignore

    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage, sampler=sampler)
        trial_id = study._storage.create_new_trial(study._study_id)
        trial = Trial(study, trial_id)
        with pytest.raises(ValueError):
            trial.suggest_int("w", 1, 3, step=2, log=True)
def test_number(storage_mode: str) -> None:
    """Every rank sees trial number 0 for the first trial of a fresh study."""
    with StorageSupplier(storage_mode) as storage:
        if dist.get_rank() != 0:  # type: ignore
            # Non-zero ranks hold no local trial and rely on rank 0's.
            trial = TorchDistributedTrial(None)
        else:
            study = optuna.create_study(storage=storage)
            trial = TorchDistributedTrial(study.ask())

        assert trial.number == 0
def test_datetime_start(storage_mode: str) -> None:
    """Every rank gets a ``datetime.datetime`` for the trial's ``datetime_start``."""
    with StorageSupplier(storage_mode) as storage:
        is_rank_zero = dist.get_rank() == 0  # type: ignore
        if not is_rank_zero:
            trial = TorchDistributedTrial(None)
        else:
            study = optuna.create_study(storage=storage)
            trial = TorchDistributedTrial(study.ask())

        assert isinstance(trial.datetime_start, datetime.datetime)
def test_system_attrs_with_exception() -> None:
    """Setting a tensor as a system attr raises ``TypeError`` on every rank."""
    with StorageSupplier("sqlite") as storage:
        if dist.get_rank() != 0:  # type: ignore
            trial = TorchDistributedTrial(None)
        else:
            study = optuna.create_study(storage=storage)
            trial = TorchDistributedTrial(study.ask())

        with pytest.raises(TypeError):
            trial.set_system_attr("not serializable", torch.Tensor([1, 2]))
def test_should_prune(storage_mode: str, is_pruning: bool) -> None:
    """``should_prune`` on every rank mirrors the deterministic pruner's fixed decision."""
    with StorageSupplier(storage_mode) as storage:
        if dist.get_rank() != 0:  # type: ignore
            trial = TorchDistributedTrial(None)
        else:
            pruner = DeterministicPruner(is_pruning)
            study = optuna.create_study(storage=storage, pruner=pruner)
            trial = TorchDistributedTrial(study.ask())

        trial.report(1, 0)
        assert trial.should_prune() == is_pruning
def test_create_new_trial(storage_mode: str) -> None:
    """New trials start RUNNING and empty, are numbered per study, and get unique ids.

    The start-time checks use ``<=``. With a coarse system clock, ``datetime.now()`` can
    return the same value as the timestamp the storage records. Strict ``<`` comparisons
    could therefore fail spuriously.
    """

    def _check_trials(
        trials: List[FrozenTrial],
        idx: int,
        trial_id: int,
        time_before_creation: datetime,
        time_after_creation: datetime,
    ) -> None:
        # ``idx`` is the 0-based index of the trial just created, so idx + 1 trials exist.
        n_expected = idx + 1
        trial_ids = {t._trial_id for t in trials}
        assert len(trials) == n_expected
        assert len(trial_ids) == n_expected
        assert trial_id in trial_ids
        assert {t.number for t in trials} == set(range(n_expected))
        assert all(t.state == TrialState.RUNNING for t in trials)
        assert all(t.params == {} for t in trials)
        assert all(t.intermediate_values == {} for t in trials)
        assert all(t.user_attrs == {} for t in trials)
        assert all(t.system_attrs == {} for t in trials)
        # Earlier trials started no later than the moment the new trial was requested.
        assert all(
            t.datetime_start <= time_before_creation
            for t in trials
            if t._trial_id != trial_id and t.datetime_start is not None
        )
        # The new trial started within the (inclusive) creation window.
        assert all(
            time_before_creation <= t.datetime_start <= time_after_creation
            for t in trials
            if t._trial_id == trial_id and t.datetime_start is not None
        )
        assert all(t.datetime_complete is None for t in trials)
        assert all(t.value is None for t in trials)

    with StorageSupplier(storage_mode) as storage:
        study_id = storage.create_new_study()
        n_trial_in_study = 3
        for i in range(n_trial_in_study):
            time_before_creation = datetime.now()
            trial_id = storage.create_new_trial(study_id)
            time_after_creation = datetime.now()

            trials = storage.get_all_trials(study_id)
            _check_trials(trials, i, trial_id, time_before_creation, time_after_creation)

        # Create trial in non-existent study.
        with pytest.raises(KeyError):
            storage.create_new_trial(study_id + 1)

        study_id2 = storage.create_new_study()
        for i in range(n_trial_in_study):
            storage.create_new_trial(study_id2)
            trials = storage.get_all_trials(study_id2)
            # Check that the offset of trial.number is zero.
            assert {t.number for t in trials} == set(range(i + 1))

        trials = storage.get_all_trials(study_id) + storage.get_all_trials(study_id2)
        # Check trial_ids are unique across studies.
        assert len({t._trial_id for t in trials}) == 2 * n_trial_in_study
def test_group_decomposed_search_space_with_different_studies() -> None:
    """``_GroupDecomposedSearchSpace`` rejects a second study after calculating a first."""
    search_space = _GroupDecomposedSearchSpace()
    with StorageSupplier("sqlite") as storage:
        first_study = create_study(storage=storage)
        second_study = create_study(storage=storage)

        search_space.calculate(first_study)
        # `_GroupDecomposedSearchSpace` isn't supposed to be used for multiple studies.
        with pytest.raises(ValueError):
            search_space.calculate(second_study)
def test_create_new_study_with_name(storage_mode: str) -> None:
    """A named study can be found by id, and creating another study with that name fails."""
    with StorageSupplier(storage_mode) as storage:
        # Build a study name unique to this test function and storage mode.
        study_name = f"{test_create_new_study_with_name.__name__}/{storage_mode}"

        study_id = storage.create_new_study(study_name)
        assert storage.get_study_name_from_id(study_id) == study_name

        with pytest.raises(optuna.exceptions.DuplicatedStudyError):
            storage.create_new_study(study_name)
def test_get_n_trials(storage_mode: str) -> None:
    """``get_n_trials`` counts every trial in each study and raises for unknown study ids."""
    n_trial = 7
    with StorageSupplier(storage_mode) as storage:
        study_id_to_frozen_studies, _ = _setup_studies(
            storage, n_study=2, n_trial=n_trial, seed=50
        )
        for study_id in study_id_to_frozen_studies:
            assert storage.get_n_trials(study_id) == n_trial

        non_existent_study_id = max(study_id_to_frozen_studies.keys()) + 1
        # Only the exception matters here, so the call is not wrapped in ``assert``.
        with pytest.raises(KeyError):
            storage.get_n_trials(non_existent_study_id)