Exemplo n.º 1
0
def get_storage(storage: Union[None, str, BaseStorage]) -> BaseStorage:
    """Resolve ``storage`` into a concrete storage object.

    Args:
        storage:
            ``None`` yields a fresh ``InMemoryStorage``. A string is treated as a
            storage URL: one starting with ``"redis"`` yields a ``RedisStorage``,
            any other string an ``RDBStorage`` wrapped in ``_CachedStorage``. An
            ``RDBStorage`` instance is wrapped in ``_CachedStorage``; any other
            object is returned unchanged.

    Returns:
        The storage object to use.
    """
    if storage is None:
        return InMemoryStorage()

    if not isinstance(storage, str):
        # Already a storage object: only RDB backends get the cache layer.
        return _CachedStorage(storage) if isinstance(storage, RDBStorage) else storage

    # A URL string: choose the backend by its prefix.
    if storage.startswith("redis"):
        return RedisStorage(storage)
    return _CachedStorage(RDBStorage(storage))
Exemplo n.º 2
0
def get_storage(storage: Union[None, str, BaseStorage]) -> BaseStorage:
    """Return a storage object for ``storage``.

    Only for internal usage. It might be deprecated in the future.

    ``None`` maps to a new ``InMemoryStorage``; an existing ``RDBStorage`` is
    wrapped in ``_CachedStorage``; a string starting with ``"redis"`` maps to
    ``RedisStorage``; any other string maps to an ``RDBStorage`` wrapped in
    ``_CachedStorage``; any other object is passed through untouched.
    """
    if storage is None:
        return InMemoryStorage()
    if isinstance(storage, RDBStorage):
        return _CachedStorage(storage)
    if not isinstance(storage, str):
        return storage
    if storage.startswith("redis"):
        return RedisStorage(storage)
    return _CachedStorage(RDBStorage(storage))
Exemplo n.º 3
0
def test_cached_set() -> None:
    """Setting a param or values on a trial must not write through to the backend.

    For both modifications the test asserts that neither the backend's
    ``_update_trial`` nor the backend's matching setter is called, i.e.
    ``_CachedStorage`` is expected not to flush these updates.
    """
    backend = RDBStorage("sqlite:///:memory:")
    cached = _CachedStorage(backend)
    study_id = cached.create_new_study("test-study")

    param_trial_id = cached.create_new_trial(study_id)
    with patch.object(backend, "_update_trial", return_value=True) as update_spy:
        with patch.object(backend, "set_trial_param", return_value=True) as setter_spy:
            cached.set_trial_param(
                param_trial_id,
                "paramA",
                1.2,
                optuna.distributions.UniformDistribution(-0.2, 2.3),
            )
            assert update_spy.call_count == 0
            assert setter_spy.call_count == 0

    values_trial_id = cached.create_new_trial(study_id)
    with patch.object(backend, "_update_trial", return_value=True) as update_spy:
        with patch.object(backend, "set_trial_values", return_value=None) as setter_spy:
            cached.set_trial_values(values_trial_id, (0.3,))
            assert update_spy.call_count == 0
            assert setter_spy.call_count == 0
Exemplo n.º 4
0
def test_uncached_set() -> None:
    """Setting state, intermediate values, or attrs must flush to the backend.

    For each modification the test asserts that the backend's ``_update_trial``
    is called exactly once while the backend's matching setter is not called.
    """
    backend = RDBStorage("sqlite:///:memory:")
    cached = _CachedStorage(backend)
    study_id = cached.create_new_study("test-study")

    # (setter name, extra positional args after trial_id, mocked return value)
    cases = [
        ("set_trial_state", (state,), True)
        for state in (
            TrialState.COMPLETE,
            TrialState.PRUNED,
            TrialState.FAIL,
            TrialState.WAITING,
        )
    ]
    cases += [
        ("set_trial_intermediate_value", (3, 0.3), None),
        ("set_trial_system_attr", ("attrA", "foo"), None),
        ("set_trial_user_attr", ("attrB", "bar"), None),
    ]

    for setter_name, extra_args, mocked_return in cases:
        # A fresh trial per case so earlier modifications cannot interfere.
        trial_id = cached.create_new_trial(study_id)
        with patch.object(backend, "_update_trial", return_value=True) as update_spy:
            with patch.object(backend, setter_name, return_value=mocked_return) as setter_spy:
                getattr(cached, setter_name)(trial_id, *extra_args)
                assert update_spy.call_count == 1
                assert setter_spy.call_count == 0
Exemplo n.º 5
0
def test_uncached_set() -> None:
    """Every trial modification must be forwarded to the backend immediately.

    For each setter on ``_CachedStorage`` the test asserts that the
    corresponding backend method is invoked exactly once.
    """
    backend = RDBStorage("sqlite:///:memory:")
    cached = _CachedStorage(backend)
    study_id = cached.create_new_study("test-study")

    def expect_one_forward(backend_method, mocked_return, invoke):
        # Create a fresh trial, patch the backend method, run the setter, count calls.
        trial_id = cached.create_new_trial(study_id)
        with patch.object(backend, backend_method, return_value=mocked_return) as forwarded:
            invoke(trial_id)
            assert forwarded.call_count == 1

    # This case reads the trial's current state before patching, so it is kept inline.
    first_trial_id = cached.create_new_trial(study_id)
    current_trial = cached.get_trial(first_trial_id)
    with patch.object(backend, "set_trial_state_values", return_value=True) as forwarded:
        cached.set_trial_state_values(first_trial_id, state=current_trial.state, values=(0.3,))
        assert forwarded.call_count == 1

    expect_one_forward(
        "_check_and_set_param_distribution",
        True,
        lambda tid: cached.set_trial_param(
            tid, "paramA", 1.2, optuna.distributions.FloatDistribution(-0.2, 2.3)
        ),
    )
    expect_one_forward(
        "set_trial_param",
        True,
        lambda tid: cached.set_trial_param(
            tid, "paramA", 1.2, optuna.distributions.FloatDistribution(-0.2, 2.3)
        ),
    )

    for state in (TrialState.COMPLETE, TrialState.PRUNED, TrialState.FAIL, TrialState.WAITING):
        expect_one_forward(
            "set_trial_state_values",
            True,
            lambda tid, target=state: cached.set_trial_state_values(tid, state=target),
        )

    expect_one_forward(
        "set_trial_intermediate_value",
        None,
        lambda tid: cached.set_trial_intermediate_value(tid, 3, 0.3),
    )
    expect_one_forward(
        "set_trial_system_attr",
        None,
        lambda tid: cached.set_trial_system_attr(tid, "attrA", "foo"),
    )
    expect_one_forward(
        "set_trial_user_attr",
        None,
        lambda tid: cached.set_trial_user_attr(tid, "attrB", "bar"),
    )
Exemplo n.º 6
0
def test_set_trial_state_values() -> None:
    """A state update made through the cache reads back identically from the backend."""
    backend = RDBStorage("sqlite:///:memory:")
    cached = _CachedStorage(backend)
    trial_id = cached.create_new_trial(cached.create_new_study("test-study"))
    cached.set_trial_state_values(trial_id, state=TrialState.COMPLETE)

    assert cached.get_trial(trial_id) == backend.get_trial(trial_id)
Exemplo n.º 7
0
def test_read_trials_from_remote_storage() -> None:
    """Reading trials succeeds for an existing study and raises KeyError otherwise."""
    cached = _CachedStorage(RDBStorage("sqlite:///:memory:"))
    existing_study_id = cached.create_new_study("test-study")

    cached.read_trials_from_remote_storage(existing_study_id)

    # No study with this id has been created.
    missing_study_id = existing_study_id + 1
    with pytest.raises(KeyError):
        cached.read_trials_from_remote_storage(missing_study_id)
Exemplo n.º 8
0
def test_create_trial() -> None:
    """Create a trial through a patched backend, then one through the real backend.

    During the first ``create_new_trial`` call, the backend's ``_create_new_trial``
    is patched to return a hand-built ``FrozenTrial`` (number 1, trial_id 1). The
    second call is unpatched. The test contains no assertions, so it passes as
    long as neither call raises.
    """
    backend = RDBStorage("sqlite:///:memory:")
    cached = _CachedStorage(backend)
    study_id = cached.create_new_study("test-study")

    stub_trial = optuna.trial.FrozenTrial(
        trial_id=1,
        number=1,
        state=TrialState.RUNNING,
        value=None,
        params={},
        distributions={},
        user_attrs={},
        system_attrs={},
        intermediate_values={},
        datetime_start=None,
        datetime_complete=None,
    )
    with patch.object(backend, "_create_new_trial", return_value=stub_trial):
        cached.create_new_trial(study_id)

    # The second call goes to the real, unpatched backend method.
    cached.create_new_trial(study_id)