コード例 #1
0
ファイル: test_sampler.py プロジェクト: sile/optuna
def test_sample_relative_handle_unsuccessful_states(
    state: optuna.trial.TrialState,
) -> None:
    """Check that trials in ``state`` change the multivariate TPE suggestion.

    Two studies are each filled with 99 stored trials. The first uses the default
    trial states of ``frozen_trial_factory``; the second takes its states from
    ``build_state_fn(state)``. Sampling ``"param-a"`` for a fresh trial (number 100)
    with identically seeded samplers must yield different suggestions.
    """
    dist = optuna.distributions.UniformDistribution(1.0, 100.0)

    def sample_with(make_state_fn=None):
        # Build a study, optionally with per-trial states, and sample for trial 100.
        study = optuna.create_study()
        factory_kwargs = {"dist": dist}
        if make_state_fn is not None:
            # Only pass ``state_fn`` when requested so the baseline keeps the factory default.
            factory_kwargs["state_fn"] = make_state_fn()
        for number in range(1, 100):
            template = frozen_trial_factory(number, **factory_kwargs)
            study._storage.create_new_trial(study._study_id, template_trial=template)
        target = frozen_trial_factory(100)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", optuna.exceptions.ExperimentalWarning)
            sampler = TPESampler(n_startup_trials=5, seed=0, multivariate=True)
        return sampler.sample_relative(study, target, {"param-a": dist})

    all_success_suggestion = sample_with()
    partial_unsuccessful_suggestion = sample_with(lambda: build_state_fn(state))

    assert partial_unsuccessful_suggestion != all_success_suggestion
コード例 #2
0
ファイル: test_sampler.py プロジェクト: sile/optuna
def test_sample_relative_log_uniform_distributions() -> None:
    """Test that log-uniform sampling stays in range and differs from uniform sampling.

    A reference suggestion is first drawn from seven past trials over a uniform
    distribution with bounds 1.0 and 100.0. The same setup is then repeated with a
    log-uniform distribution over the same bounds. Its suggestion must satisfy
    ``1.0 <= x < 100.0`` and must not equal the uniform suggestion.
    """

    study = optuna.create_study()

    # Reference sample from the uniform distribution.
    uni_dist = optuna.distributions.UniformDistribution(1.0, 100.0)
    past_trials = [frozen_trial_factory(i, dist=uni_dist) for i in range(1, 8)]
    trial = frozen_trial_factory(8)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", optuna.exceptions.ExperimentalWarning)
        sampler = TPESampler(n_startup_trials=5, seed=0, multivariate=True)
    with patch.object(study._storage, "get_all_trials", return_value=past_trials):
        uniform_suggestion = sampler.sample_relative(study, trial, {"param-a": uni_dist})

    # Test sample from log-uniform is different from uniform.
    log_dist = optuna.distributions.LogUniformDistribution(1.0, 100.0)
    past_trials = [frozen_trial_factory(i, dist=log_dist) for i in range(1, 8)]
    trial = frozen_trial_factory(8)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", optuna.exceptions.ExperimentalWarning)
        sampler = TPESampler(n_startup_trials=5, seed=0, multivariate=True)
    with patch.object(study._storage, "get_all_trials", return_value=past_trials):
        loguniform_suggestion = sampler.sample_relative(study, trial, {"param-a": log_dist})
    assert 1.0 <= loguniform_suggestion["param-a"] < 100.0
    assert uniform_suggestion["param-a"] != loguniform_suggestion["param-a"]
コード例 #3
0
def test_sample_relative_misc_arguments() -> None:
    """Check that non-default TPE options change the multivariate suggestion.

    A baseline suggestion is drawn with ``n_startup_trials=5, seed=0,
    multivariate=True`` from 39 patched-in past trials. Each variant adds one extra
    option (``n_ei_candidates``, ``gamma`` or ``weights``) to otherwise identical
    arguments. Each variant's suggestion must differ from the baseline.
    """
    study = optuna.create_study()
    dist = optuna.distributions.UniformDistribution(1.0, 100.0)
    past_trials = [frozen_trial_factory(i, dist=dist) for i in range(1, 40)]

    # Prepare a trial and a sample for later checks.
    trial = frozen_trial_factory(40)

    def sample(**extra_options):
        # Build a seeded sampler with ``extra_options`` and sample against the patched trials.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", optuna.exceptions.ExperimentalWarning)
            sampler = TPESampler(
                n_startup_trials=5, seed=0, multivariate=True, **extra_options
            )
        with patch.object(study._storage, "get_all_trials", return_value=past_trials):
            return sampler.sample_relative(study, trial, {"param-a": dist})

    baseline = sample()

    # Test misc. parameters.
    variants = [
        {"n_ei_candidates": 13},
        {"gamma": lambda _: 5},
        {"weights": lambda n: np.asarray([i**2 + 1 for i in range(n)])},
    ]
    for extra_options in variants:
        assert sample(**extra_options) != baseline
コード例 #4
0
ファイル: test_sampler.py プロジェクト: not522/optuna
def test_sample_relative_n_startup_trial() -> None:
    """Check the ``n_startup_trials`` threshold of multivariate TPE.

    With ``n_startup_trials=5``, four patched-in observations must produce an empty
    suggestion. Seven observations must produce a value for ``"param-a"``.
    """
    study = optuna.create_study()
    dist = optuna.distributions.UniformDistribution(1.0, 100.0)
    past_trials = [frozen_trial_factory(i, dist=dist) for i in range(1, 8)]
    trial = frozen_trial_factory(8)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", optuna.exceptions.ExperimentalWarning)
        sampler = TPESampler(n_startup_trials=5, seed=0, multivariate=True)

    def sample_given(observed):
        # Sample while the storage reports exactly ``observed`` as the study's trials.
        with patch.object(study._storage, "get_all_trials", return_value=observed):
            return sampler.sample_relative(study, trial, {"param-a": dist})

    # Four observations: below the threshold, so nothing is suggested.
    assert sample_given(past_trials[:4]) == {}
    # Seven observations: enough for a "param-a" suggestion.
    assert "param-a" in sample_given(past_trials)
コード例 #5
0
def test_sample_relative_pruned_state() -> None:
    """Tests PRUNED state is treated differently from both FAIL and COMPLETE.

    For each of the three states, a study is filled with 39 trials whose states come
    from ``build_state_fn(state)``. A fresh trial (number 40) is then sampled with an
    identically seeded sampler. The three ``"param-a"`` suggestions must be pairwise
    distinct.
    """
    dist = optuna.distributions.UniformDistribution(1.0, 100.0)
    states = (
        optuna.trial.TrialState.COMPLETE,
        optuna.trial.TrialState.FAIL,
        optuna.trial.TrialState.PRUNED,
    )

    def suggest_for(state):
        # Populate a new study using ``state`` and return the sampled "param-a" value.
        study = optuna.create_study()
        state_fn = build_state_fn(state)
        for number in range(1, 40):
            template = frozen_trial_factory(number, dist=dist, state_fn=state_fn)
            study._storage.create_new_trial(study._study_id, template_trial=template)
        target = frozen_trial_factory(40)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", optuna.exceptions.ExperimentalWarning)
            sampler = TPESampler(n_startup_trials=5, seed=0, multivariate=True)
        return sampler.sample_relative(study, target, {"param-a": dist})["param-a"]

    distinct_suggestions = {suggest_for(state) for state in states}
    assert len(distinct_suggestions) == 3
コード例 #6
0
def test_sample_relative_ignored_states() -> None:
    """Tests FAIL, RUNNING, and WAITING states are treated equally.

    For each of the three states, a study is filled with 29 trials whose states come
    from ``build_state_fn(state)``. A fresh trial (number 30) is then sampled with an
    identically seeded sampler. All three ``"param-a"`` suggestions must be identical.
    """

    dist = optuna.distributions.UniformDistribution(1.0, 100.0)

    suggestions = []
    for state in [
        optuna.trial.TrialState.FAIL,
        optuna.trial.TrialState.RUNNING,
        optuna.trial.TrialState.WAITING,
    ]:
        study = optuna.create_study()
        state_fn = build_state_fn(state)
        for i in range(1, 30):
            trial = frozen_trial_factory(i, dist=dist, state_fn=state_fn)
            study._storage.create_new_trial(study._study_id, template_trial=trial)
        # Sample for a fresh trial, as the sibling tests do. Reusing the last stored
        # template would pass a trial whose state comes from ``state_fn`` and may
        # therefore differ between iterations.
        trial = frozen_trial_factory(30)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", optuna.exceptions.ExperimentalWarning)
            sampler = TPESampler(n_startup_trials=5, seed=0, multivariate=True)
        suggestions.append(sampler.sample_relative(study, trial, {"param-a": dist})["param-a"])

    assert len(set(suggestions)) == 1
コード例 #7
0
def test_sample_relative_int_loguniform_distributions() -> None:
    """Test sampling from an int log-uniform distribution returns an in-range integer.

    Seven past trials get deterministic integer values in ``[1, 100]``. The suggestion
    drawn for a fresh trial must be an ``int`` within the distribution bounds.
    """

    study = optuna.create_study()

    def int_value_fn(idx: int) -> int:
        random.seed(idx)
        # Draw from [1, 100], not [0, 100]: ``intlog_dist`` starts at 1, and 0 has no
        # logarithm, so a 0 would be an observation outside the distribution's domain.
        return random.randint(1, 100)

    intlog_dist = optuna.distributions.IntLogUniformDistribution(1, 100)
    past_trials = [
        frozen_trial_factory(i, dist=intlog_dist, value_fn=int_value_fn)
        for i in range(1, 8)
    ]
    trial = frozen_trial_factory(8)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", optuna.exceptions.ExperimentalWarning)
        sampler = TPESampler(n_startup_trials=5, seed=0, multivariate=True)
    with patch.object(study._storage, "get_all_trials", return_value=past_trials):
        intlog_suggestion = sampler.sample_relative(study, trial, {"param-a": intlog_dist})
    assert 1 <= intlog_suggestion["param-a"] <= 100
    assert isinstance(intlog_suggestion["param-a"], int)
コード例 #8
0
def test_sample_relative_categorical_distributions() -> None:
    """Check that a categorical suggestion is one of the declared choices.

    The choices are 330 floats ``1.0, 1.3, 1.6, ...``. Seven past trials pick a
    choice via a per-index seeded RNG, and the sampled value must be in the list.
    """

    study = optuna.create_study()
    categories = [i * 0.3 + 1.0 for i in range(330)]

    def cat_value_fn(idx: int) -> float:
        # Deterministic per-trial pick so the test is reproducible.
        random.seed(idx)
        position = random.randint(0, len(categories) - 1)
        return categories[position]

    cat_dist = optuna.distributions.CategoricalDistribution(categories)
    past_trials = [
        frozen_trial_factory(number, dist=cat_dist, value_fn=cat_value_fn)
        for number in range(1, 8)
    ]
    target = frozen_trial_factory(8)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", optuna.exceptions.ExperimentalWarning)
        sampler = TPESampler(n_startup_trials=5, seed=0, multivariate=True)
    with patch.object(study._storage, "get_all_trials", return_value=past_trials):
        picked = sampler.sample_relative(study, target, {"param-a": cat_dist})

    assert picked["param-a"] in categories
コード例 #9
0
def test_sample_relative_disrete_uniform_distributions() -> None:
    """Test samples from a discrete uniform distribution lie in range and on its grid.

    Seven past trials get deterministic values on the 0.1-step grid inside
    ``[1.0, 100.0]``. The suggestion for a fresh trial must lie within the bounds,
    and ten times the suggestion must be (almost) an integer.
    """

    study = optuna.create_study()
    disc_dist = optuna.distributions.DiscreteUniformDistribution(1.0, 100.0, 0.1)

    def value_fn(idx: int) -> float:
        random.seed(idx)
        # Grid points 1.0, 1.1, ..., 99.9. These stay inside the distribution's
        # [1.0, 100.0] bounds; the earlier [0.0, 99.9] range could fall below low.
        return round(1.0 + int(random.random() * 990) * 0.1, 1)

    past_trials = [
        frozen_trial_factory(i, dist=disc_dist, value_fn=value_fn)
        for i in range(1, 8)
    ]
    trial = frozen_trial_factory(8)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", optuna.exceptions.ExperimentalWarning)
        sampler = TPESampler(n_startup_trials=5, seed=0, multivariate=True)
    with patch.object(study._storage, "get_all_trials", return_value=past_trials):
        discrete_uniform_suggestion = sampler.sample_relative(
            study, trial, {"param-a": disc_dist}
        )
    assert 1.0 <= discrete_uniform_suggestion["param-a"] <= 100.0
    scaled = discrete_uniform_suggestion["param-a"] * 10
    # Compare against the nearest integer. ``int()`` truncates, so a float such as
    # 132.99999999 would become 132 and spuriously fail the check.
    np.testing.assert_almost_equal(round(scaled), scaled)
コード例 #10
0
def test_sample_relative_empty_input(multivariate: bool) -> None:
    """Check that an empty search space yields an empty suggestion.

    The trial is a ``Mock`` with an empty spec, so any attribute access on it would
    fail. The sampler must return ``{}`` without touching it.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", optuna.exceptions.ExperimentalWarning)
        sampler = TPESampler(multivariate=multivariate)
    study = optuna.create_study()
    untouchable_trial = Mock(spec=[])
    result = sampler.sample_relative(study, untouchable_trial, {})
    assert result == {}
コード例 #11
0
def test_sample_relative_prior() -> None:
    """Check that prior-related TPE options change the multivariate suggestion.

    A baseline suggestion is drawn with ``n_startup_trials=5, seed=0,
    multivariate=True`` from seven patched-in past trials. Disabling the prior
    (``consider_prior=False``) or changing ``prior_weight`` to 0.2 must each yield a
    different suggestion.
    """
    study = optuna.create_study()
    dist = optuna.distributions.UniformDistribution(1.0, 100.0)
    past_trials = [frozen_trial_factory(i, dist=dist) for i in range(1, 8)]

    # Prepare a trial and a sample for later checks.
    trial = frozen_trial_factory(8)

    def sample(**extra_options):
        # Build a seeded sampler with ``extra_options`` and sample against the patched trials.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", optuna.exceptions.ExperimentalWarning)
            sampler = TPESampler(
                n_startup_trials=5, seed=0, multivariate=True, **extra_options
            )
        with patch.object(study._storage, "get_all_trials", return_value=past_trials):
            return sampler.sample_relative(study, trial, {"param-a": dist})

    baseline = sample()
    for extra_options in ({"consider_prior": False}, {"prior_weight": 0.2}):
        assert sample(**extra_options) != baseline
コード例 #12
0
def test_sample_relative() -> None:
    """Check that an empty search space gives ``{}`` without using the study or trial.

    Both the study and the trial are ``Mock`` objects with an empty spec, so any
    attribute access on them would fail.
    """
    sampler = TPESampler()
    study, frozen_trial = Mock(spec=[]), Mock(spec=[])
    result = sampler.sample_relative(study, frozen_trial, {})
    assert result == {}