Example #1
 def test_train(self):
     model_config = core.ModelConfig(_lineworld_name)
     tc = core.PpoTrainContext()
     ppo_agent = tfagents.TfPpoAgent(model_config=model_config)
     ppo_agent.train(train_context=tc,
                     callbacks=[duration.Fast(),
                                log.Iteration()])
Example #2
 def test_ppo_train(self):
     model_config = core.ModelConfig("CartPole-v0")
     tc = core.PpoTrainContext()
     ppo_agent = tforce.TforcePpoAgent(model_config=model_config)
     ppo_agent.train(
         train_context=tc,
         callbacks=[log.Iteration(),
                    log.Agent(),
                    duration.Fast()])
Example #3
    def test_ppo_train(self):
        from easyagents.backends import tforce

        model_config = core.ModelConfig(_cartpole_name)
        tc = core.PpoTrainContext()
        tc.num_iterations = 20
        ppo_agent = tforce.TforcePpoAgent(model_config=model_config)
        ppo_agent.train(train_context=tc, callbacks=[log.Iteration(), log.Agent()])
        (min_r, avg_r, max_r) = tc.eval_rewards[tc.episodes_done_in_training]
        assert avg_r > 100
Example #4
 def test_train_cartpole(self):
     for backend in get_backends(PpoAgent):
         ppo = PpoAgent(gym_env_name="CartPole-v0", backend=backend)
         tc = core.PpoTrainContext()
         tc.num_iterations = 3
         tc.num_episodes_per_iteration = 10
         tc.max_steps_per_episode = 500
         tc.num_epochs_per_iteration = 5
         tc.num_iterations_between_eval = 2
         tc.num_episodes_per_eval = 5
         ppo.train([log.Iteration()], train_context=tc)
Example #5
 def test_train(self):
     agents.seed = 0
     for backend in get_backends(PpoAgent):
         ppo = PpoAgent(gym_env_name=_cartpole_name, backend=backend)
         tc = core.PpoTrainContext()
         tc.num_iterations = 10
         tc.num_episodes_per_iteration = 10
         tc.max_steps_per_episode = 200
         tc.num_epochs_per_iteration = 5
         tc.num_iterations_between_eval = 5
         tc.num_episodes_per_eval = 5
         ppo.train([log.Iteration()], train_context=tc, default_plots=False)
         assert max_avg_rewards(tc) >= 50
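Example #5 asserts on a max_avg_rewards helper that is not part of the excerpt. A minimal sketch of such a helper, assuming eval_rewards maps an episode count to a (min, avg, max) reward tuple as the indexing in Example #3 suggests:

    def max_avg_rewards(tc):
        # Hypothetical reconstruction, not the library's own helper: return the
        # best average reward recorded across all evaluations in the train context.
        return max(avg_r for (_, avg_r, _) in tc.eval_rewards.values())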
Example #6
    def test_save_(self):
        from easyagents.backends import tforce

        model_config = core.ModelConfig(_cartpole_name)
        tc = core.PpoTrainContext()
        tc.num_iterations = 3
        ppo_agent = tforce.TforcePpoAgent(model_config=model_config)
        ppo_agent.train(train_context=tc, callbacks=[log.Iteration(), log.Agent()])
        tempdir = bcore._get_temp_path()
        bcore._mkdir(tempdir)
        ppo_agent.save(tempdir, [])

        loaded_agent = tforce.TforcePpoAgent(model_config=model_config)
        loaded_agent.load(tempdir, [])
        bcore._rmpath(tempdir)
Example #7
    def train(self,
              callbacks: Union[List[core.AgentCallback], core.AgentCallback,
                               None] = None,
              num_iterations: int = 100,
              num_episodes_per_iteration: int = 10,
              max_steps_per_episode: int = 500,
              num_epochs_per_iteration: int = 10,
              num_iterations_between_eval: int = 5,
              num_episodes_per_eval: int = 10,
              learning_rate: float = 0.001,
              train_context: core.PpoTrainContext = None,
              default_plots: bool = None):
        """Trains a new model using the gym environment passed during instantiation.

        Args:
            callbacks: list of callbacks called during training and evaluation
            num_iterations: number of times the training is repeated (with additional data)
            num_episodes_per_iteration: number of episodes played per training iteration
            max_steps_per_episode: maximum number of steps per episode
            num_epochs_per_iteration: number of times the data collected for the current iteration
                is used to retrain the current policy
            num_iterations_between_eval: number of training iterations before the current policy is evaluated.
                If 0, no evaluation is performed.
            num_episodes_per_eval: number of episodes played to estimate the average return and steps
            learning_rate: the learning rate used in the next iteration's policy training, in (0, 1]
            train_context: training configuration to be used. If set, it overrides all other training context arguments.
            default_plots: if set, adds a set of default callbacks (plot.State, plot.Rewards, plot.Loss, ...).
                If None, default callbacks are only added if the callbacks list is empty.

        Returns:
            train_context: the training configuration containing the loss and sum of rewards encountered
                during training
        """
        if train_context is None:
            train_context = core.PpoTrainContext()
            train_context.num_iterations = num_iterations
            train_context.num_episodes_per_iteration = num_episodes_per_iteration
            train_context.max_steps_per_episode = max_steps_per_episode
            train_context.num_epochs_per_iteration = num_epochs_per_iteration
            train_context.num_iterations_between_eval = num_iterations_between_eval
            train_context.num_episodes_per_eval = num_episodes_per_eval
            train_context.learning_rate = learning_rate

        super().train(train_context=train_context,
                      callbacks=callbacks,
                      default_plots=default_plots)
        return train_context
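For reference, a minimal usage sketch of the keyword-argument path documented above. It assumes PpoAgent and the log callback module are importable as in the surrounding examples (from easyagents import PpoAgent, from easyagents.callbacks import log) and that omitting backend selects the default backend:

    from easyagents import PpoAgent
    from easyagents.callbacks import log

    ppo = PpoAgent(gym_env_name="CartPole-v0")
    # No train_context is passed, so the keyword arguments below are copied into
    # a fresh PpoTrainContext (the None branch above); the populated context is
    # returned for inspection once training finishes.
    tc = ppo.train(
        callbacks=[log.Iteration()],
        num_iterations=10,
        num_episodes_per_iteration=10,
        num_iterations_between_eval=5)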
Example #8
 def test_save_load(self):
     model_config = core.ModelConfig(_lineworld_name)
     tc = core.PpoTrainContext()
     ppo_agent = tfagents.TfPpoAgent(model_config=model_config)
     ppo_agent.train(
         train_context=tc,
         callbacks=[duration._SingleIteration(),
                    log.Iteration()])
     tempdir = bcore._get_temp_path()
     bcore._mkdir(tempdir)
     ppo_agent.save(tempdir, [])
     ppo_agent = tfagents.TfPpoAgent(model_config=model_config)
     ppo_agent.load(tempdir, [])
     pc = core.PlayContext()
     pc.max_steps_per_episode = 10
     pc.num_episodes = 1
     ppo_agent.play(play_context=pc, callbacks=[])
     bcore._rmpath(tempdir)