Code example #1
    def objective(self, trial: optuna.Trial) -> float:

        kwargs = self._hyperparams.copy()

        trial.model_class = None
        if self.algo == "her":
            trial.model_class = self._hyperparams.get("model_class", None)

        # Hack to use DDPG/TD3 noise sampler
        trial.n_actions = self.n_actions
        # Sample candidate hyperparameters
        kwargs.update(HYPERPARAMS_SAMPLER[self.algo](trial))

        model = ALGOS[self.algo](
            env=self.create_envs(self.n_envs, no_log=True),
            tensorboard_log=None,
            # We do not seed the trial
            seed=None,
            verbose=0,
            **kwargs,
        )

        model.trial = trial

        eval_env = self.create_envs(n_envs=1, eval_env=True)

        eval_freq = int(self.n_timesteps / self.n_evaluations)
        # Account for parallel envs
        eval_freq_ = max(eval_freq // model.get_env().num_envs, 1)
        # Use non-deterministic eval for Atari
        eval_callback = TrialEvalCallback(
            eval_env,
            trial,
            n_eval_episodes=self.n_eval_episodes,
            eval_freq=eval_freq_,
            deterministic=self.deterministic_eval,
        )

        try:
            model.learn(self.n_timesteps, callback=eval_callback)
            # Free memory
            model.env.close()
            eval_env.close()
        except AssertionError as e:
            # Sometimes, random hyperparams can generate NaN
            # Free memory
            model.env.close()
            eval_env.close()
            # Prune hyperparams that generate NaNs
            print(e)
            raise optuna.exceptions.TrialPruned()
        is_pruned = eval_callback.is_pruned
        reward = eval_callback.last_mean_reward

        del model.env, eval_env
        del model

        if is_pruned:
            raise optuna.exceptions.TrialPruned()

        return reward
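
Both code examples rely on a TrialEvalCallback that is not shown here. In rl-baselines3-zoo it is a small subclass of Stable-Baselines3's EvalCallback that reports each evaluation result to the Optuna trial and stops training when the pruner asks for it, which is what feeds the is_pruned and last_mean_reward values read above. Below is a minimal sketch along those lines; the exact constructor arguments and defaults are an assumption, not the verbatim upstream implementation.

import optuna
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.vec_env import VecEnv


class TrialEvalCallback(EvalCallback):
    """Evaluation callback that reports results to Optuna and supports pruning (sketch)."""

    def __init__(
        self,
        eval_env: VecEnv,
        trial: optuna.Trial,
        n_eval_episodes: int = 5,
        eval_freq: int = 10000,
        deterministic: bool = True,
        verbose: int = 0,
    ):
        super().__init__(
            eval_env=eval_env,
            n_eval_episodes=n_eval_episodes,
            eval_freq=eval_freq,
            deterministic=deterministic,
            verbose=verbose,
        )
        self.trial = trial
        self.eval_idx = 0
        self.is_pruned = False

    def _on_step(self) -> bool:
        if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
            # Run the regular evaluation; this updates self.last_mean_reward
            super()._on_step()
            self.eval_idx += 1
            # Report the intermediate result so the pruner can act on it
            self.trial.report(self.last_mean_reward, self.eval_idx)
            # Returning False stops training; objective() then raises TrialPruned
            if self.trial.should_prune():
                self.is_pruned = True
                return False
        return True
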
Code example #2
    def objective(self, trial: optuna.Trial) -> float:

        kwargs = self._hyperparams.copy()

        trial.model_class = None
        if self.algo == "her":
            trial.model_class = self._hyperparams.get("model_class", None)

        # Hack to use DDPG/TD3 noise sampler
        trial.n_actions = self._env.action_space.shape[0]

        # Sample candidate hyperparameters
        kwargs.update(HYPERPARAMS_SAMPLER[self.algo](trial))
        print(f"\nRunning a new trial with hyperparameters: {kwargs}")

        # Write hyperparameters into a file
        trial_params_path = os.path.join(self.params_path, "optimization")
        os.makedirs(trial_params_path, exist_ok=True)
        with open(
                os.path.join(trial_params_path,
                             f"hyperparameters_trial_{trial.number}.yml"),
                "w") as f:
            yaml.dump(kwargs, f)

        model = ALGOS[self.algo](
            env=self._env,
            # Note: Here, unlike the first example, I enable tensorboard logs
            tensorboard_log=self.tensorboard_log,
            # Note: Here I differ from the first example and seed the trial, so that all trials share the same starting conditions
            seed=self.seed,
            verbose=self.verbose,
            **kwargs,
        )

        # Pre-load replay buffer if enabled
        if self.preload_replay_buffer:
            if self.preload_replay_buffer.endswith('.pkl'):
                replay_buffer_path = self.preload_replay_buffer
            else:
                replay_buffer_path = os.path.join(self.preload_replay_buffer,
                                                  "replay_buffer.pkl")
            if os.path.exists(replay_buffer_path):
                print("Pre-loading replay buffer")
                if self.algo == "her":
                    model.load_replay_buffer(replay_buffer_path,
                                             self.truncate_last_trajectory)
                else:
                    model.load_replay_buffer(replay_buffer_path)
            else:
                raise Exception(f"Replay buffer {replay_buffer_path} "
                                "does not exist")

        model.trial = trial

        eval_freq = int(self.n_timesteps / self.n_evaluations)
        # Account for parallel envs
        eval_freq_ = max(eval_freq // model.get_env().num_envs, 1)
        # Use non-deterministic eval for Atari
        eval_callback = TrialEvalCallback(
            model.env,
            model.trial,
            n_eval_episodes=self.n_eval_episodes,
            eval_freq=eval_freq_,
            deterministic=self.deterministic_eval,
            verbose=self.verbose,
        )

        try:
            model.learn(self.n_timesteps, callback=eval_callback)
            # Reset env
            self._env.reset()
        except AssertionError as e:
            # Reset env
            self._env.reset()
            print('Trial stopped:', e)
            # Prune hyperparams that generate NaNs
            raise optuna.exceptions.TrialPruned()
        except Exception as err:
            exception_type = type(err).__name__
            print('Trial stopped due to raised exception:', exception_type,
                  err)
            # Prune also all other exceptions
            raise optuna.exceptions.TrialPruned()
        is_pruned = eval_callback.is_pruned
        reward = eval_callback.last_mean_reward

        print(
            f"\nFinished a trial with reward={reward}, is_pruned={is_pruned} "
            f"for hyperparameters: {kwargs}")

        del model

        if is_pruned:
            raise optuna.exceptions.TrialPruned()

        return reward
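
For completeness, either objective() is normally handed to an Optuna study that owns the sampler and the pruner; neither example shows that part. The following is a minimal driver sketch, assuming a TPE sampler, a median pruner, and a hypothetical manager object exposing the objective() method (all of these are assumptions, not part of the examples above).

import optuna
from optuna.pruners import MedianPruner
from optuna.samplers import TPESampler

# `manager` is a placeholder for the class instance that defines objective()
sampler = TPESampler(n_startup_trials=5)
pruner = MedianPruner(n_startup_trials=5, n_warmup_steps=3)

study = optuna.create_study(sampler=sampler, pruner=pruner, direction="maximize")
try:
    # Pruned trials raise optuna.exceptions.TrialPruned and are recorded as PRUNED
    study.optimize(manager.objective, n_trials=100)
except KeyboardInterrupt:
    pass

print("Best trial:", study.best_trial.number)
print("Best reward:", study.best_trial.value)
print("Best hyperparameters:", study.best_trial.params)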