def train_and_assert(self, agent_type, is_v1: bool, num_iterations=100):
    """Train ``agent_type`` on CartPole-v0 for each selected backend and check eval results.

    Args:
        agent_type: agent class to instantiate per backend (annotated locally as CemAgent).
        is_v1: if True, test only the backends returned by ``get_backends(agent_type)``
            that are NOT returned by ``get_backends(agent_type, skip_v1=True)``;
            otherwise test only the latter.
        num_iterations: number of training iterations per backend.

    Raises:
        AssertionError: if the evaluation at the end of training has a max step count
            below 100 or an average step count below 50.
    """
    logger = logging.warning
    v2_backends = list(get_backends(agent_type, skip_v1=True))
    # v1 backends: everything reported without skip_v1 that is not already a v2 backend.
    v1_backends = [b for b in get_backends(agent_type) if b not in v2_backends]
    backends = v1_backends if is_v1 else v2_backends
    for backend in backends:
        logger(f'backend={backend} agent={agent_type}, num_iterations={num_iterations}')
        cem_agent: CemAgent = agent_type('CartPole-v0', fc_layers=(100,), backend=backend)
        tc: core.TrainContext = cem_agent.train(
            [log.Duration(), log.Iteration(eval_only=True), log.Agent()],
            num_iterations=num_iterations,
            num_iterations_between_eval=10,
            max_steps_per_episode=200,
            default_plots=False)
        # eval_steps is indexed by the number of training episodes done; each entry
        # unpacks to (min, avg, max) steps. Only avg and max are checked here.
        _, avg_steps, max_steps = tc.eval_steps[tc.episodes_done_in_training]
        assert max_steps >= 100, f'backend={backend} agent={agent_type} max_steps={max_steps}'
        assert avg_steps >= 50, f'backend={backend} agent={agent_type} avg_steps={avg_steps}'
def train_and_assert(self, agent_type, is_v1: bool, num_iterations=10000):
    """Train ``agent_type`` with each selected non-default backend and check the eval score.

    Args:
        agent_type: agent class passed through to ``self.train_and_eval``.
        is_v1: if True, test only backends returned by ``get_backends(agent_type)``
            that are NOT returned by ``get_backends(agent_type, skip_v1=True)``;
            otherwise test only the latter. The 'default' backend is always excluded.
        num_iterations: base number of training iterations (tripled for tensorforce).

    Raises:
        AssertionError: if ``self.train_and_eval`` returns a value below 80 for any backend.
    """
    logger = logging.warning
    v2_backends = [b for b in get_backends(agent_type, skip_v1=True) if b != 'default']
    v1_backends = [b for b in get_backends(agent_type)
                   if b not in v2_backends and b != 'default']
    backends = v1_backends if is_v1 else v2_backends
    for backend in backends:
        # tensorforce is given 3x the iteration budget (presumably it needs more
        # iterations to reach the threshold - TODO confirm).
        current_num_iterations = num_iterations * 3 if backend == 'tensorforce' else num_iterations
        logger(f'backend={backend} agent={agent_type}, num_iterations={current_num_iterations}')
        max_avg_steps = self.train_and_eval(agent_type=agent_type, backend=backend,
                                            num_iterations=current_num_iterations)
        # Report the iteration count actually used, not the base value.
        assert max_avg_steps >= 80, \
            f'agent_type={agent_type} backend={backend} num_iterations={current_num_iterations}'
def test_train(self):
    """Run a short SacAgent training on CartPole-v0 with every backend from get_backends."""
    for backend_name in get_backends(SacAgent):
        agent: SacAgent = SacAgent('CartPole-v0', backend=backend_name)
        callbacks = [log.Duration(), log.Iteration(), log.Agent()]
        agent.train(callbacks,
                    num_iterations=10,
                    max_steps_per_episode=200,
                    default_plots=False)
def test_train_single_episode(self):
    """Train a PpoAgent with a _SingleEpisode duration per backend and check callback counts.

    Checks that the gym init begin and end counters are both 1, that every step begin
    has a matching step end, and that steps stay below 10 plus the number of resets.
    """
    for backend in get_backends(PpoAgent):
        agent = agents.PpoAgent(gym_env_name=_env_name, backend=backend)
        counter = log._CallbackCounts()
        agent.train([log.Agent(), counter, duration._SingleEpisode()])
        assert counter.gym_init_begin_count == counter.gym_init_end_count == 1
        steps_begun = counter.gym_step_begin_count
        assert steps_begun == counter.gym_step_end_count
        assert steps_begun < 10 + counter.gym_reset_begin_count
def test_train(self):
    """Train a ReinforceAgent briefly on CartPole-v0 for each of its backends and check eval steps.

    Raises:
        AssertionError: if the average step count of the evaluation at the end of
            training is below 10.
    """
    # Query the backends of the agent under test (ReinforceAgent), not RandomAgent.
    for backend in get_backends(ReinforceAgent):
        reinforce_agent: ReinforceAgent = ReinforceAgent('CartPole-v0', backend=backend)
        tc: core.TrainContext = reinforce_agent.train(
            [log.Duration(), log.Iteration()],
            num_iterations=10,
            max_steps_per_episode=200,
            default_plots=False)
        # Each eval_steps entry unpacks to (min, avg, max) steps; only avg is checked.
        _, avg_steps, _ = tc.eval_steps[tc.episodes_done_in_training]
        assert avg_steps >= 10, f'backend={backend} avg_steps={avg_steps}'
def test_play_single_episode(self):
    """Train a PpoAgent for a single episode, then play a single episode and check callback counts.

    The counter is only attached to the play call; it must see exactly one gym init and
    matching step begin/end counts of at most 10.
    """
    for backend in get_backends(PpoAgent):
        ppo = agents.PpoAgent(gym_env_name=_env_name, backend=backend)
        play_counter = log._CallbackCounts()
        play_callbacks = [log.Agent(), play_counter, duration._SingleEpisode()]
        ppo.train(duration._SingleEpisode())
        ppo.play(play_callbacks)
        inits_begun = play_counter.gym_init_begin_count
        assert inits_begun == play_counter.gym_init_end_count == 1
        steps_begun = play_counter.gym_step_begin_count
        assert steps_begun == play_counter.gym_step_end_count <= 10
def test_train_cartpole(self):
    """Train a PpoAgent on CartPole-v0 for each backend using an explicitly configured PpoTrainContext."""
    for backend_name in get_backends(PpoAgent):
        agent = PpoAgent(gym_env_name="CartPole-v0", backend=backend_name)
        context = core.PpoTrainContext()
        # Applied in this order onto the fresh train context.
        overrides = {
            'num_iterations': 3,
            'num_episodes_per_iteration': 10,
            'max_steps_per_episode': 500,
            'num_epochs_per_iteration': 5,
            'num_iterations_between_eval': 2,
            'num_episodes_per_eval': 5,
        }
        for attribute, value in overrides.items():
            setattr(context, attribute, value)
        agent.train([log.Iteration()], train_context=context)
def test_getbackends_randomagent(self):
    """Check that RandomAgent reports the 'default', 'tfagents' and 'tensorforce' backends."""
    assert agents._backends is not None
    random_backends = agents.get_backends(agents.RandomAgent)
    for expected in ('default', 'tfagents', 'tensorforce'):
        assert expected in random_backends
def test_getbackends(self):
    """Check that agents._backends is set and agents.get_backends() with no agent returns a value."""
    registry = agents._backends
    assert registry is not None
    all_backends = agents.get_backends()
    assert all_backends is not None
def test_callback_single(self):
    """Train a PpoAgent for a single episode per backend and check _StepCountEnv.reset_count stays <= 2."""
    for backend in get_backends(PpoAgent):
        # Reset the env's counters so each backend is measured independently.
        env._StepCountEnv.clear()
        PpoAgent(_env_name, backend=backend).train(duration._SingleEpisode())
        resets = env._StepCountEnv.reset_count
        assert resets <= 2
def test_getbackends_ppoagent(self):
    """Check that PpoAgent reports both the 'default' and 'tfagents' backends."""
    assert agents._backends is not None
    ppo_backends = agents.get_backends(agents.PpoAgent)
    for expected in ('default', 'tfagents'):
        assert expected in ppo_backends
def get_backends(agent: Optional[Type[EasyAgent]] = None):
    """Return the backends reported by ``agents.get_backends(agent)``, excluding 'default'.

    Args:
        agent: agent class passed unchanged to ``agents.get_backends``; defaults to None.

    Returns:
        list of backend names, in the order ``agents.get_backends`` reported them.

    Raises:
        AssertionError: if no backend other than 'default' remains.
    """
    candidates = agents.get_backends(agent)
    non_default = []
    for name in candidates:
        if name == 'default':
            continue
        non_default.append(name)
    assert non_default, f'no backend found for agent {agent}.'
    return non_default