Ejemplo n.º 1
0
def test_delete_study_after_create_multiple_studies():
    # type: () -> None
    """Deleting one study must not affect the other studies in the storage.

    Three studies are created in an in-memory SQLite storage and the middle
    one is deleted; the study summaries must then list the first and third
    studies but not the deleted one.
    """

    storage = RDBStorage('sqlite:///:memory:')
    first, second, third = [storage.create_new_study() for _ in range(3)]

    storage.delete_study(second)

    remaining_ids = set(
        summary.study_id for summary in storage.get_all_study_summaries())
    assert first in remaining_ids
    assert second not in remaining_ids
    assert third in remaining_ids
Ejemplo n.º 2
0
def test_study_optimize_command(options):
    # type: (List[str]) -> None
    """`optuna study optimize` runs trials against an existing study.

    ``options`` decides, via ``_add_option``, whether ``--storage`` and/or
    ``--config`` are passed on the command line (presumably each flag is
    appended only when its name is in ``options``).
    """

    with StorageConfigSupplier(TEST_CONFIG_TEMPLATE) as (storage_url, config_path):
        storage = RDBStorage(storage_url)
        new_study_id = storage.create_new_study()
        study_name = storage.get_study_name_from_id(new_study_id)

        args = ['optuna', 'study', 'optimize', '--study', study_name,
                '--n-trials', '10', __file__, 'objective_func']
        use_storage = 'storage' in options
        use_config = 'config' in options
        args = _add_option(args, '--storage', storage_url, use_storage)
        args = _add_option(args, '--config', config_path, use_config)
        subprocess.check_call(args)

        study = optuna.load_study(storage=storage_url, study_name=study_name)
        assert len(study.trials) == 10
        assert 'x' in study.best_params

        # The study was created without an explicit name, so the name kept in
        # the storage should carry the default prefix.
        stored_name = storage.get_study_name_from_id(study.study_id)
        assert stored_name.startswith(DEFAULT_STUDY_NAME_PREFIX)
Ejemplo n.º 3
0
def test_study_set_user_attr_command(options):
    # type: (List[str]) -> None
    """`optuna study set-user-attr` stores user attributes on a study.

    Each key/value pair of ``example_attrs`` is set through a separate CLI
    call, then the study's user attributes are read back from the storage and
    must equal ``example_attrs`` exactly. ``options`` decides, via
    ``_add_option``, whether ``--storage`` and/or ``--config`` are passed.
    """

    with StorageConfigSupplier(TEST_CONFIG_TEMPLATE) as (storage_url,
                                                         config_path):
        storage = RDBStorage(storage_url)

        # Create study.
        study_name = storage.get_study_name_from_id(storage.create_new_study())

        base_command = [
            'optuna', 'study', 'set-user-attr', '--study', study_name
        ]
        base_command = _add_option(base_command, '--storage', storage_url,
                                   'storage' in options)
        base_command = _add_option(base_command, '--config', config_path,
                                   'config' in options)

        example_attrs = {'architecture': 'ResNet', 'baselen_score': '0.002'}
        for key, value in example_attrs.items():
            subprocess.check_call(base_command +
                                  ['--key', key, '--value', value])

        # Attrs should be stored in storage. A single dict comparison checks
        # both the key set and every value, and yields a readable diff on
        # failure instead of a KeyError for a missing key.
        study_id = storage.get_study_id_from_name(study_name)
        study_user_attrs = storage.get_study_user_attrs(study_id)
        assert study_user_attrs == example_attrs
Ejemplo n.º 4
0
def test_dashboard_command(options):
    # type: (List[str]) -> None
    """`optuna dashboard --out FILE` writes an HTML report for a study.

    The report is written to a temporary file, which is then read back and
    checked for an HTML ``<body>`` and a reference to ``bokeh``. ``options``
    decides, via ``_add_option``, whether ``--storage`` and/or ``--config``
    are passed.
    """

    with StorageConfigSupplier(TEST_CONFIG_TEMPLATE) as (storage_url, config_path):
        with tempfile.NamedTemporaryFile('r') as tf_report:
            storage = RDBStorage(storage_url)
            study_id = storage.create_new_study()
            study_name = storage.get_study_name_from_id(study_id)

            base_args = ['optuna', 'dashboard', '--study', study_name,
                         '--out', tf_report.name]
            args = _add_option(base_args, '--storage', storage_url,
                               'storage' in options)
            args = _add_option(args, '--config', config_path,
                               'config' in options)
            subprocess.check_call(args)

            report_html = tf_report.read()
            for fragment in ('<body>', 'bokeh'):
                assert fragment in report_html
Ejemplo n.º 5
0
def test_dashboard_command_with_allow_websocket_origin(origins):
    # type: (List[str]) -> None
    """`optuna dashboard` accepts repeated ``--allow-websocket-origin`` flags.

    One ``--allow-websocket-origin`` flag is passed per entry in ``origins``,
    and the written HTML report must still contain a ``<body>`` and a
    reference to ``bokeh``.
    """

    with StorageConfigSupplier(TEST_CONFIG_TEMPLATE) as (storage_url, _):
        with tempfile.NamedTemporaryFile('r') as tf_report:
            storage = RDBStorage(storage_url)
            study_name = storage.get_study_name_from_id(storage.create_new_study())

            origin_args = [
                arg for origin in origins
                for arg in ('--allow-websocket-origin', origin)
            ]
            args = ['optuna', 'dashboard', '--study', study_name,
                    '--out', tf_report.name, '--storage', storage_url] + origin_args
            subprocess.check_call(args)

            report_html = tf_report.read()
            assert '<body>' in report_html
            assert 'bokeh' in report_html