コード例 #1
0
ファイル: utilities.py プロジェクト: Wandercraft/jiminy
def train(train_agent: BaseAlgorithm, max_timesteps: int) -> bool:
    """Train a model on a specific environment using a given agent.

    When the evaluation environment declares a reward threshold, an
    `EvalCallback` (10 evaluation episodes, `eval_freq=10000`) is attached and
    triggers `StopOnReward` on every new best mean reward, which may end
    training early. Without a spec or without a threshold, no callback is used.

    :param train_agent: Training agent. Its `eval_env` is expected to expose
                        `envs[0].spec` (vectorized environment).
    :param max_timesteps: Number of maximum training timesteps.

    :returns: Whether training stopped before `max_timesteps` were reached,
              i.e. whether the threshold reward has been exceeded in average
              over 10 episodes (as decided by `StopOnReward`).
    """
    # Get testing environment spec. Environments not created through a
    # registry may have no spec at all: early stopping is then disabled
    # instead of failing with an AttributeError.
    spec = train_agent.eval_env.envs[0].spec
    reward_threshold = None if spec is None else spec.reward_threshold

    # Create callback to stop learning early if reward threshold is exceeded
    eval_callback = None
    if reward_threshold is not None:
        callback_reward = StopOnReward(reward_threshold=reward_threshold)
        eval_callback = EvalCallback(train_agent.eval_env,
                                     callback_on_new_best=callback_reward,
                                     eval_freq=10000,
                                     n_eval_episodes=10,
                                     verbose=False)

    # Run the learning process. `learn` resets the timestep counter by
    # default, so `num_timesteps` below only counts this run.
    train_agent.learn(total_timesteps=max_timesteps, callback=eval_callback)

    # Stopping short of the budget means the callback ended training.
    return train_agent.num_timesteps < max_timesteps
コード例 #2
0
    def learn(self, model: BaseAlgorithm) -> None:
        """
        Train ``model`` for ``self.n_timesteps`` steps with the configured options.

        Optional keyword arguments for ``model.learn`` are gathered first:
        the logging interval (only when ``self.log_interval > -1``), the
        registered callbacks (only when there is at least one) and, for ARS
        with more than one env, an ``AsyncEval`` helper. A KeyboardInterrupt
        ends training quietly so the caller can still save the model, and the
        model's environment is closed however training ended.

        :param model: an initialized RL model
        """
        learn_kwargs = {}
        if self.log_interval > -1:
            learn_kwargs["log_interval"] = self.log_interval

        if len(self.callbacks):
            learn_kwargs["callback"] = self.callbacks

        # Special case for ARS: AsyncEval receives ``self.n_envs`` factories,
        # each building a single (non-logged) environment.
        if self.algo == "ars" and self.n_envs > 1:
            single_env_factories = [
                lambda: self.create_envs(n_envs=1, no_log=True)
                for _ in range(self.n_envs)
            ]
            learn_kwargs["async_eval"] = AsyncEval(
                single_env_factories, model.policy
            )

        try:
            model.learn(self.n_timesteps, **learn_kwargs)
        except KeyboardInterrupt:
            # Interrupted by the user: swallow it so the model can be saved
            pass
        finally:
            # Release resources whatever happened; EOFError on close is ignored
            try:
                model.env.close()
            except EOFError:
                pass
コード例 #3
0
    def learn(self, model: BaseAlgorithm) -> None:
        """
        Train ``model`` for ``self.n_timesteps`` steps with the configured options.

        Optional keyword arguments for ``model.learn`` are gathered first:
        the logging interval (only when ``self.log_interval > -1``) and the
        registered callbacks (only when there is at least one). When
        ``self.continue_training`` is set, the model's timestep counter is not
        reset and its environment is reset before training. A
        KeyboardInterrupt ends training quietly so the caller can still save
        the model, and the model's environment is closed however training
        ended.

        :param model: an initialized RL model
        """
        learn_kwargs = {}
        if self.log_interval > -1:
            learn_kwargs["log_interval"] = self.log_interval

        if len(self.callbacks):
            learn_kwargs["callback"] = self.callbacks

        if self.continue_training:
            # Ask SB3 not to reset the timestep counter of the resumed model
            learn_kwargs["reset_num_timesteps"] = False
            model.env.reset()

        try:
            model.learn(self.n_timesteps, **learn_kwargs)
        except KeyboardInterrupt:
            # Interrupted by the user: swallow it so the model can be saved
            pass
        finally:
            # Release resources whatever happened; EOFError on close is ignored
            try:
                model.env.close()
            except EOFError:
                pass
コード例 #4
0
def train(train_agent: BaseAlgorithm,
          max_timesteps: int,
          verbose: bool = True) -> str:
    """Train a model on a specific environment using a given agent.

    Note that the agent is associated with a given reinforcement learning
    algorithm, and instantiated for a specific environment and neural network
    model. Thus, it already wraps all the required information to actually
    perform training.

    Training continues from the agent's current timestep counter
    (`reset_num_timesteps=False`) for at most `max_timesteps` additional
    steps. When the evaluation environment declares a reward threshold, an
    `EvalCallback` (100 evaluation episodes, `eval_freq=5000`) triggers
    `StopOnReward` on every new best mean reward, which may end training
    early.

    .. note::
        This function can be terminated early using CTRL+C.

    :param train_agent: Training agent.
    :param max_timesteps: Number of maximum training timesteps.
    :param verbose: Whether or not to print information about what is going on.
                    Optional: True by default.

    :returns: Fullpath of agent's final state dump. Note that it also contains
              the trained neural network model. The file is created in
              `train_agent.tensorboard_log` (or the default temporary
              directory if that is None), prefixed by the environment id.
    """
    # Get testing environment spec
    spec = train_agent.eval_env.envs[0].spec

    # Create callback to stop learning early if reward threshold is exceeded
    if spec.reward_threshold is not None:
        callback_reward = StopOnReward(reward_threshold=spec.reward_threshold)
        eval_callback = EvalCallback(train_agent.eval_env,
                                     callback_on_new_best=callback_reward,
                                     eval_freq=5000,
                                     n_eval_episodes=100)
    else:
        eval_callback = None

    # With `reset_num_timesteps=False`, `learn` trains `max_timesteps` on top
    # of the agent's current counter, so progress must be measured relative
    # to the counter value before training (it is 0 for a fresh agent).
    initial_timesteps = train_agent.num_timesteps

    try:
        # Run the learning process
        train_agent.learn(total_timesteps=max_timesteps,
                          log_interval=5,
                          reset_num_timesteps=False,
                          callback=eval_callback)
    except KeyboardInterrupt:
        if verbose:
            print("Interrupting training...")
    else:
        # Stopping short of the budget means the callback ended training.
        elapsed_timesteps = train_agent.num_timesteps - initial_timesteps
        if verbose and elapsed_timesteps < max_timesteps:
            print("Problem solved successfully!")

    # Dump the final agent state, whether or not training was interrupted
    fd, checkpoint_path = tempfile.mkstemp(dir=train_agent.tensorboard_log,
                                           prefix=spec.id,
                                           suffix='.zip')
    os.close(fd)
    train_agent.save(checkpoint_path)

    return checkpoint_path
コード例 #5
0
def train(
    model: BaseAlgorithm,
    timesteps: int,
    eval_env: GymEnv,
    model_path: Path,
    reward_threshold: float = 1,
) -> None:
    """
    Train the agent in its environment.

    Learning finishes when the agent has performed the given number of
    timesteps, or earlier when the evaluation callback finds a new best mean
    reward that reaches ``reward_threshold``.

    :param model: RL agent
    :param timesteps: total number of steps to take (through all episodes)
    :param eval_env: evaluation environment used by ``MlflowEvalCallback``
    :param model_path: location where the model will be saved (handed to
        ``MlflowCallback``)
    :param reward_threshold: mean evaluation reward at which training stops
        early; defaults to 1, the value previously hard-coded here
    """
    mlflow_callback = MlflowCallback(model_path)
    # Invoked by the evaluation callback on each new best mean reward;
    # stops training once that reward reaches the threshold.
    reward_threshold_callback = StopTrainingOnRewardThreshold(
        reward_threshold=reward_threshold
    )
    eval_callback = MlflowEvalCallback(
        eval_env=eval_env, callback_on_new_best=reward_threshold_callback
    )
    callbacks = CallbackList([mlflow_callback, eval_callback])

    model.learn(total_timesteps=timesteps, callback=callbacks)