Пример #1
0
    def __init__(
        self,
        trial: optuna.Trial,
        config_file: str,
        serialization_dir: str,
        metrics: str = "best_validation_accuracy",
        *,
        include_package: Optional[Union[str, List[str]]] = None,
        force: bool = False,
        file_friendly_logging: bool = False,
    ):
        """Store run settings for ``trial`` and publish its study info for the pruning callback.

        Args:
            trial: Trial to run. Its ``params`` are snapshotted into ``self._params``, and its
                study name, trial id and study pruner are published via ``_VariableManager``.
            config_file: Stored as ``self._config_file``.
            serialization_dir: Stored as ``self._serialization_dir``.
            metrics: Stored as ``self._metrics`` and also published under ``"monitor"``.
            include_package: Extra package name(s). A single string is wrapped in a list, and
                ``"optuna.integration.allennlp"`` is always appended.
            force: Stored as ``self._force``.
            file_friendly_logging: Stored as ``self._file_friendly_logging``.
        """
        _imports.check()

        self._params = trial.params
        self._config_file = config_file
        self._serialization_dir = serialization_dir
        self._metrics = metrics
        self._force = force
        self._file_friendly_logging = file_friendly_logging

        # Normalise ``include_package`` to a list, then always add this integration's package.
        if include_package is None:
            packages = []  # type: List[str]
        elif isinstance(include_package, str):
            packages = [include_package]
        else:
            packages = include_package
        self._include_package = packages + ["optuna.integration.allennlp"]

        # Resolve a storage URL from which the pruning callback can reload the study.
        # Anything other than RDB, Redis, or an RDB-backed ``_CachedStorage`` (for example,
        # in-memory storage) yields an empty string.
        storage = trial.study._storage
        url = ""
        if isinstance(storage, optuna.storages.RDBStorage):
            url = storage.url
        elif isinstance(storage, optuna.storages.RedisStorage):
            url = storage._url
        elif isinstance(storage, optuna.storages._CachedStorage):
            backend = storage._backend
            assert isinstance(backend, optuna.storages.RDBStorage)
            url = backend.url

        # The manager is created for this process's parent pid; the pruning callback builds
        # its own ``_VariableManager`` to read these values back.
        variable_manager = _VariableManager(psutil.Process().ppid())

        pruner_kwargs = _fetch_pruner_config(trial)
        for key, value in (
            ("study_name", trial.study.study_name),
            ("trial_id", trial._trial_id),
            ("storage_name", url),
            ("monitor", metrics),
        ):
            variable_manager.set_value(key, value)

        pruner = trial.study.pruner
        if pruner is not None:
            variable_manager.set_value("pruner_class", type(pruner).__name__)
            variable_manager.set_value("pruner_kwargs", pruner_kwargs)
Пример #2
0
def test_allennlp_pruning_callback_with_executor(
    pruner_class: Type[optuna.pruners.BasePruner],
    pruner_kwargs: Dict[str, Union[int, float]],
    input_config_file: str,
) -> None:
    """Check that ``AllenNLPExecutor`` publishes enough information to rebuild the pruner.

    The executor runs one trial of a SQLite-backed study that uses ``pruner_class``. The
    pruner is then recreated from the values stored in ``_VariableManager`` and compared
    attribute by attribute against ``pruner_kwargs``.
    """

    def run_executor(
        pruner: optuna.pruners.BasePruner, storage: str, serialization_dir: str
    ) -> None:
        # Create a fresh trial by hand and let the executor run it.
        study = optuna.create_study(direction="maximize", pruner=pruner, storage=storage)
        trial_id = study._storage.create_new_trial(study._study_id)
        trial = optuna.trial.Trial(study, trial_id)
        trial.suggest_float("DROPOUT", 0.0, 0.5)
        executor = optuna.integration.AllenNLPExecutor(
            trial, input_config_file, serialization_dir
        )
        executor.run()

    with tempfile.TemporaryDirectory() as tmp_dir:
        work_dir = os.path.join(tmp_dir, pruner_class.__name__)
        os.mkdir(work_dir)
        storage = "sqlite:///" + os.path.join(work_dir, "result.db")
        serialization_dir = os.path.join(work_dir, "allennlp")

        pruner = pruner_class(**pruner_kwargs)  # type: ignore
        run_executor(pruner, storage, serialization_dir)

        # Read back what the executor published and rebuild the pruner from it.
        manager = _VariableManager(psutil.Process().ppid())
        restored = _create_pruner(
            manager.get_value("pruner_class"),
            manager.get_value("pruner_kwargs"),
        )

        assert isinstance(restored, pruner_class)
        for name, expected in pruner_kwargs.items():
            assert getattr(restored, "_{}".format(name)) == expected
Пример #3
0
    def __init__(
        self,
        trial: Optional[Trial] = None,
        monitor: Optional[str] = None,
    ):
        """Bind the callback to a trial, either explicitly or via ``AllenNLPExecutor``.

        When both ``trial`` and ``monitor`` are given, they are used directly. Otherwise
        (including when only one of them is given), the study name, trial id, monitor,
        storage URL and pruner settings are read from the ``_VariableManager`` that the
        executor populated. The study is then reloaded from that storage.

        Args:
            trial: Trial to report to, when used from a Python script.
            monitor: Name of the metric to monitor, when used from a Python script.

        Raises:
            ImportError: If the installed AllenNLP is older than v2.0.0.
            RuntimeError: If the executor-published values are missing, or if the published
                storage URL is empty (the study was not backed by RDB or Redis storage).
        """
        _imports.check()

        if version.parse(allennlp.__version__) < version.parse("2.0.0"):
            raise ImportError(
                "`AllenNLPPruningCallback` requires AllenNLP>=v2.0.0. "
                "If you want to use a callback with an old version of AllenNLP, "
                "please install Optuna v2.5.0 by executing `pip install 'optuna==2.5.0'`."
            )

        if trial is not None and monitor is not None:
            # Instantiated directly in a Python script.
            self._trial = trial
            self._monitor = monitor
        else:
            # Used with `AllenNLPExecutor`: `trial` and `monitor` are not supplied, so read
            # the values the executor published through `_VariableManager`.
            current_process = psutil.Process()

            if os.getenv(OPTUNA_ALLENNLP_DISTRIBUTED_FLAG) == "1":
                # Consume the flag and look up the grandparent's pid instead. Presumably
                # distributed workers run one process level deeper; TODO confirm.
                os.environ.pop(OPTUNA_ALLENNLP_DISTRIBUTED_FLAG)
                target_pid = current_process.parent().ppid()
            else:
                target_pid = current_process.ppid()

            variable_manager = _VariableManager(target_pid)

            study_name = variable_manager.get_value("study_name")
            trial_id = variable_manager.get_value("trial_id")
            monitor = variable_manager.get_value("monitor")
            storage = variable_manager.get_value("storage_name")

            if study_name is None or trial_id is None or monitor is None or storage is None:
                raise RuntimeError(
                    "Failed to load study. Perhaps you attempt to use `AllenNLPPruningCallback`"
                    " without `AllenNLPExecutor`. If you want to use a callback"
                    " without an executor, you have to instantiate a callback with"
                    " `trial` and `monitor`. Please see the Optuna example: https://github.com/"
                    "optuna/optuna-examples/tree/main/allennlp/allennlp_simple.py."
                )

            # An empty storage URL despite the other values being present means the
            # executor's study used a storage without a URL (e.g. in-memory).
            # `AllenNLPPruningCallback` needs RDB or Redis storage to reload the study.
            if storage == "":
                raise RuntimeError(
                    "If you want to use AllenNLPExecutor and AllenNLPPruningCallback,"
                    " you have to use RDB or Redis storage."
                )

            pruner = _create_pruner(
                variable_manager.get_value("pruner_class"),
                variable_manager.get_value("pruner_kwargs"),
            )

            study = load_study(
                study_name,
                storage,
                pruner=pruner,
            )
            self._trial = Trial(study, trial_id)
            self._monitor = monitor