Exemplo n.º 1
0
from optuna import distributions
from optuna import samplers
from optuna import storages
from optuna.study import create_study
from optuna.testing.sampler import DeterministicRelativeSampler
from optuna.trial import FixedTrial
from optuna.trial import Trial
from optuna import type_checking

if type_checking.TYPE_CHECKING:
    import typing  # NOQA

# Zero-argument factories, one per storage backend under test; each test calls
# its factory to obtain a fresh storage instance.
_storage_factories = [
    storages.InMemoryStorage,
    lambda: storages.RDBStorage('sqlite:///:memory:'),
]
parametrize_storage = pytest.mark.parametrize('storage_init_func', _storage_factories)


@parametrize_storage
def test_suggest_uniform(storage_init_func):
    # type: (typing.Callable[[], storages.BaseStorage]) -> None
    """Set up a trial whose sampler returns mocked values for a uniform distribution.

    Parametrized over the storage backends in ``parametrize_storage``;
    ``storage_init_func`` is called once to build a fresh storage.
    """

    # Successive calls to the mock return 1.0, 2.0, 3.0 in that order.
    mock = Mock()
    mock.side_effect = [1., 2., 3.]
    sampler = samplers.RandomSampler()

    # sampler.sample_independent is replaced by the mock only inside this block.
    with patch.object(sampler, 'sample_independent', mock) as mock_object:
        study = create_study(storage_init_func(), sampler=sampler)
        trial = Trial(study, study.storage.create_new_trial_id(study.study_id))
        distribution = distributions.UniformDistribution(low=0., high=3.)
        # NOTE(review): the test ends here in this excerpt -- ``trial``,
        # ``distribution`` and ``mock_object`` are never used and nothing is
        # asserted, so it cannot fail. The body looks truncated; restore the
        # suggestion calls and assertions from the upstream source.
Exemplo n.º 2
0
from optuna.study import create_study
from optuna.testing.integration import DeterministicPruner
from optuna.testing.sampler import DeterministicRelativeSampler
from optuna.trial import FixedTrial
from optuna.trial import Trial
from optuna import type_checking

if type_checking.TYPE_CHECKING:
    from datetime import datetime  # NOQA
    import typing  # NOQA

# Storage backends every decorated test is run against. Each entry is a
# zero-argument factory so that each test case builds its own storage.
_storage_factories = [
    storages.InMemoryStorage,
    lambda: storages.RDBStorage("sqlite:///:memory:"),
]
parametrize_storage = pytest.mark.parametrize("storage_init_func", _storage_factories)


@parametrize_storage
def test_check_distribution_suggest_float(storage_init_func):
    # type: (typing.Callable[[], storages.BaseStorage]) -> None
    """Re-suggesting an existing parameter returns its stored value.

    ``suggest_float`` first samples "x1". A later ``suggest_uniform`` call on
    the same name with the same bounds must reuse that value instead of
    sampling again. Runs once per storage backend in ``parametrize_storage``.
    """

    sampler = samplers.RandomSampler()
    study = create_study(storage_init_func(), sampler=sampler)
    trial = Trial(study, study._storage.create_new_trial(study._study_id))

    x1 = trial.suggest_float("x1", 10, 20)
    x2 = trial.suggest_uniform("x1", 10, 20)

    # Both calls target "x1" with identical bounds, so the second must return
    # the value already recorded for the first.
    assert x1 == x2
Exemplo n.º 3
0
from optuna import samplers
from optuna import storages
from optuna.study import create_study
from optuna.testing.integration import DeterministicPruner
from optuna.testing.sampler import DeterministicRelativeSampler
from optuna.trial._frozen import create_trial
from optuna.trial import BaseTrial
from optuna.trial import FixedTrial
from optuna.trial import FrozenTrial
from optuna.trial import Trial
from optuna.trial import TrialState


# One zero-argument factory per storage backend; each test invokes its
# factory to get an independent storage instance.
_storage_factories = [
    storages.InMemoryStorage,
    lambda: storages.RDBStorage("sqlite:///:memory:"),
]
parametrize_storage = pytest.mark.parametrize("storage_init_func", _storage_factories)


@parametrize_storage
def test_check_distribution_suggest_float(
    storage_init_func: Callable[[], storages.BaseStorage]
) -> None:
    """Re-suggesting an existing parameter returns its stored value.

    ``suggest_float`` first samples "x1". A later ``suggest_uniform`` call on
    the same name with the same bounds must reuse that value instead of
    sampling again. Runs once per storage backend in ``parametrize_storage``.
    """

    sampler = samplers.RandomSampler()
    study = create_study(storage_init_func(), sampler=sampler)
    trial = Trial(study, study._storage.create_new_trial(study._study_id))

    x1 = trial.suggest_float("x1", 10, 20)
    x2 = trial.suggest_uniform("x1", 10, 20)

    # Both calls target "x1" with identical bounds, so the second must return
    # the value already recorded for the first.
    assert x1 == x2
- Removes old training-related tables (e.g., predictions, training_summaries, evaluation_metrics)
- Sets up fresh Optuna tables ready for new studies

"""

# Every table dropped before the schema is re-initialised, in drop order.
training_tables = [
    # Tables matching Optuna's RDB storage schema (plus Alembic's
    # migration-tracking table).
    "studies",
    "version_info",
    "study_user_attributes",
    "study_system_attributes",
    "trials",
    "trial_user_attributes",
    "trial_system_attributes",
    "trial_params",
    "trial_values",
    "alembic_version",
    # Pipeline-specific training outputs.
    "predictions",
    "training_summaries",
    "best_trials",
    "evaluation_metrics",
    "shap_values",
]

DBNAME = snakemake.params["dbname"]
DBUSER = snakemake.params["dbuser"]
DBSCHEMA = snakemake.params["dbschema"]
# search_path makes unqualified table names (including the ones RDBStorage
# creates) resolve inside DBSCHEMA.
psql_url = f"postgresql://{DBUSER}@dbserver/{DBNAME}?options=-c%20search_path={DBSCHEMA}"

engine = create_engine(psql_url)
try:
    # engine.begin() runs all DROPs in one transaction, committing when the
    # block exits normally and rolling back on error. Embedding
    # "BEGIN; ... COMMIT;" in the SQL text instead fights the transaction that
    # SQLAlchemy/the driver already opened on this connection.
    with engine.begin() as connection:
        for table in training_tables:
            # NOTE(review): schema/table names are interpolated unquoted; this
            # is acceptable for pipeline-controlled config, not untrusted input.
            connection.execute(
                text(f"DROP TABLE IF EXISTS {DBSCHEMA}.{table} CASCADE"))
finally:
    # Release pooled connections before RDBStorage opens its own engine.
    engine.dispose()

# Create the empty Optuna tables only after the DROPs have committed and their
# connection has been released, so no lock from the drop transaction can
# interfere with table creation.
storages.RDBStorage(url=psql_url, engine_kwargs={"pool_size": 0})

# Write a timestamp to the rule's output file (presumably a completion marker
# for Snakemake).
with open(snakemake.output[0], "w") as f:
    f.write(str(datetime.now()))