Exemple #1
0
    def run_experiment(self, environment, experiment_num=0):
        """Run one experiment of the configured agent on ``environment``.

        The environment is wrapped in ``RLgraphEnvironmentWrapper``, an agent
        is built from a copy of ``self.config`` (with the run-limit keys
        removed), and a ``SingleThreadedWorker`` executes either a fixed
        number of timesteps or a fixed number of episodes.

        Args:
            environment: The environment to wrap and run the agent on.
            experiment_num: Index of the experiment. Model data is loaded
                from ``self.load_model_file`` only when this is 0 and that
                attribute is truthy.

        Returns:
            A dict with the keys ``initial_reset_time`` (always 0),
            ``episode_rewards``, ``episode_timesteps`` and
            ``episode_end_times``; the last three are read from the worker.

        Raises:
            KeyError: If ``'max_episode_timesteps'`` is missing from the
                config (assuming ``self.config`` is dict-like).
        """
        environment = RLgraphEnvironmentWrapper(environment)
        environment.add_episode_end_callback(self.episode_finished,
                                             environment,
                                             runner_id=1)

        # Work on a copy so that popping the run-limit keys below does not
        # remove them from self.config.
        config = copy(self.config)

        # These keys are consumed here and must not reach Agent.from_spec.
        max_episodes = config.pop('max_episodes', None)
        max_timesteps = config.pop('max_timesteps', None)
        max_episode_timesteps = config.pop('max_episode_timesteps')

        agent = Agent.from_spec(
            spec=config,
            state_space=environment.state_space,
            action_space=environment.action_space,
        )

        if experiment_num == 0 and self.load_model_file:
            # Log the same attribute that is actually loaded below.
            logging.info("Loading model data from file: %s",
                         self.load_model_file)
            agent.load_model(self.load_model_file)

        runner = SingleThreadedWorker(agent=agent, environment=environment)

        environment.reset()
        agent.reset_buffers()

        # A truthy timestep budget takes precedence over the episode budget.
        if max_timesteps:
            runner.execute_timesteps(
                num_timesteps=max_timesteps,
                max_timesteps_per_episode=max_episode_timesteps)
        else:
            runner.execute_episodes(
                num_episodes=max_episodes,
                max_timesteps_per_episode=max_episode_timesteps)

        return dict(initial_reset_time=0,
                    episode_rewards=runner.episode_rewards,
                    episode_timesteps=runner.episode_steps,
                    episode_end_times=runner.episode_durations)
Exemple #2
0
    def test_episodes(self):
        """
        Smoke test for the episode execution loop: runs a few capped
        episodes with a random agent and sanity-checks the returned stats.
        """
        action_space = self.environment.action_space
        state_space = self.environment.state_space
        random_agent = RandomAgent(action_space=action_space,
                                   state_space=state_space)

        worker_kwargs = dict(
            env_spec=lambda: self.environment,
            agent=random_agent,
            frameskip=1,
            worker_executes_preprocessing=False,
        )
        episode_worker = SingleThreadedWorker(**worker_kwargs)

        num_episodes = 5
        episode_cap = 10
        stats = episode_worker.execute_episodes(
            num_episodes, max_timesteps_per_episode=episode_cap)

        # Neither timesteps nor env frames may exceed episodes * cap (5 * 10).
        upper_bound = num_episodes * episode_cap
        self.assertLessEqual(stats['timesteps_executed'], upper_bound)
        self.assertEqual(stats['episodes_executed'], num_episodes)
        self.assertLessEqual(stats['env_frames'], upper_bound)
        self.assertGreaterEqual(stats['runtime'], 0.0)