def test_rollout_evaluator():
    env = SequentialVectorEnv([lambda: TimeLimitWrapper.wrap(build_dummy_maze_env(), max_episode_steps=2)] * 2)
    policy = flatten_concat_probabilistic_policy_for_env(build_dummy_maze_env())
    model_selection = _MockModelSelection()

    evaluator = RolloutEvaluator(eval_env=env, n_episodes=3, model_selection=model_selection)
    for i in range(2):
        evaluator.evaluate(policy)
        increment_log_step()

    assert model_selection.update_count == 2
    assert evaluator.eval_env.get_stats_value(
        BaseEnvEvents.reward,
        LogStatsLevel.EPOCH,
        name="total_episode_count"
    ) >= 2 * 3

def _get_alg_config(env_name: str, runner_type: str) -> A2CAlgorithmConfig:
    """
    Returns algorithm config used in tests.
    :param env_name: Env name for rollout evaluator.
    :param runner_type: Runner type. "dev" or "local".
    :return: A2CAlgorithmConfig instance.
    """
    env_factory = lambda: GymMazeEnv(env_name)

    return A2CAlgorithmConfig(
        n_epochs=1,
        epoch_length=25,
        patience=15,
        critic_burn_in_epochs=0,
        n_rollout_steps=100,
        lr=0.0005,
        gamma=0.98,
        gae_lambda=1.0,
        policy_loss_coef=1.0,
        value_loss_coef=0.5,
        entropy_coef=0.00025,
        max_grad_norm=0.0,
        device='cpu',
        rollout_evaluator=RolloutEvaluator(
            eval_env=SubprocVectorEnv([env_factory]) if runner_type == "local" else SequentialVectorEnv([env_factory]),
            n_episodes=1,
            model_selection=None,
            deterministic=True))

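# Hedged usage sketch (not part of the original source): building the test config with an
# in-process evaluation env; passing runner_type="local" would switch to SubprocVectorEnv.
def example_build_test_alg_config() -> A2CAlgorithmConfig:
    """Illustrative only."""
    return _get_alg_config(env_name="CartPole-v0", runner_type="dev")
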
def _algorithm_config():
    eval_env = SequentialVectorEnv([_env_factory for _ in range(2)], logging_prefix='eval')

    return ImpalaAlgorithmConfig(
        n_epochs=2,
        epoch_length=2,
        queue_out_of_sync_factor=2,
        patience=15,
        n_rollout_steps=20,
        lr=0.0005,
        gamma=0.98,
        policy_loss_coef=1.0,
        value_loss_coef=0.5,
        entropy_coef=0.0,
        max_grad_norm=0.0,
        device="cpu",
        vtrace_clip_pg_rho_threshold=1,
        vtrace_clip_rho_threshold=1,
        num_actors=1,
        actors_batch_size=5,
        critic_burn_in_epochs=0,
        rollout_evaluator=RolloutEvaluator(eval_env=eval_env, n_episodes=1, model_selection=None, deterministic=True))

def train_function(n_epochs: int, distributed_env_cls) -> A2C:
    """Trains the cart pole environment with the multi-step a2c implementation."""

    # initialize distributed env
    envs = distributed_env_cls([lambda: GymMazeEnv(env="CartPole-v0") for _ in range(2)])

    # initialize the env and enable statistics collection
    eval_env = distributed_env_cls([lambda: GymMazeEnv(env="CartPole-v0") for _ in range(2)], logging_prefix='eval')

    # init distribution mapper
    env = GymMazeEnv(env="CartPole-v0")
    distribution_mapper = DistributionMapper(action_space=env.action_space, distribution_mapper_config={})

    # initialize policies
    policies = {0: FlattenConcatPolicyNet({'observation': (4,)}, {'action': (2,)}, hidden_units=[16], non_lin=nn.Tanh)}

    # initialize critic
    critics = {0: FlattenConcatStateValueNet({'observation': (4,)}, hidden_units=[16], non_lin=nn.Tanh)}

    # algorithm configuration
    algorithm_config = A2CAlgorithmConfig(
        n_epochs=n_epochs,
        epoch_length=2,
        patience=10,
        critic_burn_in_epochs=0,
        n_rollout_steps=20,
        lr=0.0005,
        gamma=0.98,
        gae_lambda=1.0,
        policy_loss_coef=1.0,
        value_loss_coef=0.5,
        entropy_coef=0.0,
        max_grad_norm=0.0,
        device="cpu",
        rollout_evaluator=RolloutEvaluator(eval_env=eval_env, n_episodes=1, model_selection=None, deterministic=True))

    # initialize actor critic model
    model = TorchActorCritic(
        policy=TorchPolicy(networks=policies,
                           distribution_mapper=distribution_mapper,
                           device=algorithm_config.device),
        critic=TorchSharedStateCritic(networks=critics,
                                      obs_spaces_dict=env.observation_spaces_dict,
                                      device=algorithm_config.device,
                                      stack_observations=False),
        device=algorithm_config.device)

    a2c = A2C(rollout_generator=RolloutGenerator(envs),
              algorithm_config=algorithm_config,
              evaluator=algorithm_config.rollout_evaluator,
              model=model,
              model_selection=None)

    # train agent
    a2c.train()

    return a2c

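# Hedged usage sketch (not part of the original source): the A2C helper above can be driven
# with any of Maze's vector env classes; SequentialVectorEnv is assumed to be imported at
# module level, as in the other snippets here. The single epoch is illustrative only.
def example_run_a2c_cartpole() -> None:
    """Illustrative only: runs one training epoch with sequential (in-process) workers."""
    trained_a2c = train_function(n_epochs=1, distributed_env_cls=SequentialVectorEnv)
    assert isinstance(trained_a2c, A2C)
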
def _generate_inconsistency_type_2_configs() -> Tuple[Dict, Dict, Dict, A2CAlgorithmConfig, Dict]:
    """
    Returns configs for tests of inconsistencies of type 2.
    :return: es_dev_runner_config, a2c_dev_runner_config, invalid_a2c_dev_runner_config, a2c_alg_config,
        default_overrides.
    """
    gym_env_name = "CartPole-v0"

    es_dev_runner_config = {
        'state_dict_dump_file': 'state_dict.pt',
        'spaces_config_dump_file': 'spaces_config.pkl',
        'normalization_samples': 1,
        '_target_': 'maze.train.trainers.es.ESDevRunner',
        'n_eval_rollouts': 1,
        'shared_noise_table_size': 10,
        "dump_interval": None
    }
    a2c_dev_runner_config = {
        'state_dict_dump_file': 'state_dict.pt',
        'spaces_config_dump_file': 'spaces_config.pkl',
        'normalization_samples': 1,
        '_target_': 'maze.train.trainers.common.actor_critic.actor_critic_runners.ACDevRunner',
        "trainer_class": "maze.train.trainers.a2c.a2c_trainer.A2C",
        'concurrency': 1,
        "dump_interval": None,
        "eval_concurrency": 1
    }
    invalid_a2c_dev_runner_config = copy.deepcopy(a2c_dev_runner_config)
    invalid_a2c_dev_runner_config["trainer_class"] = "maze.train.trainers.es.es_trainer.ESTrainer"

    a2c_alg_config = A2CAlgorithmConfig(
        n_epochs=1,
        epoch_length=25,
        patience=15,
        critic_burn_in_epochs=0,
        n_rollout_steps=100,
        lr=0.0005,
        gamma=0.98,
        gae_lambda=1.0,
        policy_loss_coef=1.0,
        value_loss_coef=0.5,
        entropy_coef=0.00025,
        max_grad_norm=0.0,
        device='cpu',
        rollout_evaluator=RolloutEvaluator(
            eval_env=SequentialVectorEnv([lambda: GymMazeEnv(gym_env_name)]),
            n_episodes=1,
            model_selection=None,
            deterministic=True))

    default_overrides = {"env.name": gym_env_name}

    return es_dev_runner_config, a2c_dev_runner_config, invalid_a2c_dev_runner_config, a2c_alg_config, default_overrides

def test_does_not_carry_over_stats_from_unfinished_episodes():
    policy = flatten_concat_probabilistic_policy_for_env(build_dummy_maze_env())

    # Wrap envs in a time-limit wrapper
    env = SequentialVectorEnv([lambda: TimeLimitWrapper.wrap(build_dummy_maze_env())] * 2)

    # Make one env slower than the other
    env.envs[0].set_max_episode_steps(2)
    env.envs[1].set_max_episode_steps(10)

    evaluator = RolloutEvaluator(eval_env=env, n_episodes=1, model_selection=None)
    for i in range(2):
        evaluator.evaluate(policy)
        increment_log_step()

    # We should get just one episode counted in stats
    assert evaluator.eval_env.get_stats_value(
        BaseEnvEvents.reward,
        LogStatsLevel.EPOCH,
        name="episode_count"
    ) == 1

def test_autoresolving_proxy_attribute():
    """
    Tests auto-resolving proxy attributes like critic (see :py:class:`maze.api.utils._ATTRIBUTE_PROXIES` for more
    info).
    """
    cartpole_env_factory = lambda: GymMazeEnv(env=gym.make("CartPole-v0"))
    _, _, critic_composer, _, _ = _get_cartpole_setup_components()

    alg_config = A2CAlgorithmConfig(
        n_epochs=1,
        epoch_length=25,
        patience=15,
        critic_burn_in_epochs=0,
        n_rollout_steps=100,
        lr=0.0005,
        gamma=0.98,
        gae_lambda=1.0,
        policy_loss_coef=1.0,
        value_loss_coef=0.5,
        entropy_coef=0.00025,
        max_grad_norm=0.0,
        device='cpu',
        rollout_evaluator=RolloutEvaluator(
            eval_env=SequentialVectorEnv([cartpole_env_factory]),
            n_episodes=1,
            model_selection=None,
            deterministic=True))
    default_overrides = {"runner.normalization_samples": 1, "runner.concurrency": 1}

    rc = run_context.RunContext(env=cartpole_env_factory,
                                silent=True,
                                algorithm=alg_config,
                                critic=critic_composer,
                                runner="dev",
                                overrides=default_overrides)
    rc.train(n_epochs=1)
    assert isinstance(rc._runners[RunMode.TRAINING][0].model_composer.critic, TorchSharedStateCritic)

    rc = run_context.RunContext(env=cartpole_env_factory,
                                silent=True,
                                algorithm=alg_config,
                                critic="template_state",
                                runner="dev",
                                overrides=default_overrides)
    rc.train(n_epochs=1)
    assert isinstance(rc._runners[RunMode.TRAINING][0].model_composer.critic, TorchStepStateCritic)

def test_evaluation():
    """Tests evaluation."""

    # Test with ES: No rollout evaluator in config.
    rc = run_context.RunContext(
        env=lambda: GymMazeEnv(env=gym.make("CartPole-v0")),
        silent=True,
        configuration="test",
        overrides={"runner.normalization_samples": 1, "runner.shared_noise_table_size": 10})
    rc.train(1)
    stats = rc.evaluate(n_episodes=5)
    assert len(stats) == 1
    assert stats[0][(BaseEnvEvents.reward, "episode_count", None)] in (5, 6)

    # Test with A2C: Partially specified rollout evaluator in config.
    rc = run_context.RunContext(
        env=lambda: GymMazeEnv(env=gym.make("CartPole-v0")),
        silent=True,
        algorithm="a2c",
        configuration="test",
        overrides={"runner.concurrency": 1})
    rc.train(1)
    stats = rc.evaluate(n_episodes=2)
    assert len(stats) == 1
    assert stats[0][(BaseEnvEvents.reward, "episode_count", None)] in (2, 3)

    # Test with A2C and instantiated RolloutEvaluator.
    rc = run_context.RunContext(
        env=lambda: GymMazeEnv(env=gym.make("CartPole-v0")),
        silent=True,
        algorithm="a2c",
        configuration="test",
        overrides={
            "runner.concurrency": 1,
            "algorithm.rollout_evaluator": RolloutEvaluator(
                eval_env=SequentialVectorEnv([lambda: GymMazeEnv("CartPole-v0")]),
                n_episodes=1,
                model_selection=None,
                deterministic=True)
        })
    rc.train(1)
    stats = rc.evaluate(n_episodes=5)
    assert len(stats) == 1
    assert stats[0][(BaseEnvEvents.reward, "episode_count", None)] in (1, 2)

def test_inconsistency_identification_type_3() -> None:
    """Tests identification of inconsistency due to derived config group."""

    es_dev_runner_config = {
        'state_dict_dump_file': 'state_dict.pt',
        'spaces_config_dump_file': 'spaces_config.pkl',
        'normalization_samples': 10000,
        '_target_': 'maze.train.trainers.es.ESDevRunner',
        'n_eval_rollouts': 10,
        'shared_noise_table_size': 100000000
    }
    a2c_alg_config = A2CAlgorithmConfig(
        n_epochs=1,
        epoch_length=25,
        patience=15,
        critic_burn_in_epochs=0,
        n_rollout_steps=100,
        lr=0.0005,
        gamma=0.98,
        gae_lambda=1.0,
        policy_loss_coef=1.0,
        value_loss_coef=0.5,
        entropy_coef=0.00025,
        max_grad_norm=0.0,
        device='cpu',
        rollout_evaluator=RolloutEvaluator(
            eval_env=SequentialVectorEnv([lambda: GymMazeEnv(env="CartPole-v0")]),
            n_episodes=1,
            model_selection=None,
            deterministic=True))
    default_overrides = {"runner.normalization_samples": 1, "runner.concurrency": 1}

    rc = run_context.RunContext(algorithm=a2c_alg_config,
                                env=lambda: GymMazeEnv(env="CartPole-v0"),
                                silent=True,
                                runner="dev",
                                overrides=default_overrides)
    rc.train(1)

    run_context.RunContext(env=lambda: GymMazeEnv(env="CartPole-v0"),
                           runner=es_dev_runner_config,
                           silent=True,
                           overrides=default_overrides)
    rc.train(1)

def setup(self, cfg: DictConfig) -> None:
    """See :py:meth:`~maze.train.trainers.common.training_runner.TrainingRunner.setup`."""
    super().setup(cfg)

    env = self.env_factory()

    with SwitchWorkingDirectoryToInput(cfg.input_dir):
        dataset = Factory(base_type=Dataset).instantiate(self.dataset, conversion_env_factory=self.env_factory)

    assert len(dataset) > 0, "Expected to find trajectory data, but did not find any. Please check that " \
                             "the path you supplied is correct."
    size_in_byte, size_in_gbyte = getsize(dataset)
    BColors.print_colored(f'Size of loaded dataset: {size_in_byte} -> {size_in_gbyte} GB', BColors.OKBLUE)

    validation, train = self._split_dataset(dataset, cfg.algorithm.validation_percentage,
                                            self.maze_seeding.generate_env_instance_seed())

    # Create data loaders
    torch_generator = torch.Generator().manual_seed(self.maze_seeding.generate_env_instance_seed())
    train_data_loader = DataLoader(train,
                                   shuffle=True,
                                   batch_size=cfg.algorithm.batch_size,
                                   generator=torch_generator,
                                   num_workers=self.dataset.n_workers)

    policy = TorchPolicy(
        networks=self._model_composer.policy.networks,
        distribution_mapper=self._model_composer.distribution_mapper,
        device=cfg.algorithm.device,
        substeps_with_separate_agent_nets=self._model_composer.policy.substeps_with_separate_agent_nets)
    policy.seed(self.maze_seeding.agent_global_seed)

    self._model_selection = BestModelSelection(self.state_dict_dump_file, policy, dump_interval=self.dump_interval)
    optimizer = Factory(Optimizer).instantiate(cfg.algorithm.optimizer, params=policy.parameters())
    loss = BCLoss(action_spaces_dict=env.action_spaces_dict, entropy_coef=cfg.algorithm.entropy_coef)

    self._trainer = BCTrainer(algorithm_config=self._cfg.algorithm,
                              data_loader=train_data_loader,
                              policy=policy,
                              optimizer=optimizer,
                              loss=loss)

    # initialize model from input_dir
    self._init_trainer_from_input_dir(trainer=self._trainer,
                                      state_dict_dump_file=self.state_dict_dump_file,
                                      input_dir=cfg.input_dir)

    # evaluate using the validation set
    self.evaluators = []
    if len(validation) > 0:
        validation_data_loader = DataLoader(validation,
                                            shuffle=True,
                                            batch_size=cfg.algorithm.batch_size,
                                            generator=torch_generator,
                                            num_workers=self.dataset.n_workers)
        self.evaluators += [
            BCValidationEvaluator(
                data_loader=validation_data_loader,
                loss=loss,
                logging_prefix="eval-validation",
                model_selection=self._model_selection  # use the validation set evaluation to select the best model
            )
        ]

    # if evaluation episodes are set, perform additional evaluation by policy rollout
    if cfg.algorithm.n_eval_episodes > 0:
        eval_env = self.create_distributed_eval_env(self.env_factory,
                                                    self.eval_concurrency,
                                                    logging_prefix="eval-rollout")
        eval_env_instance_seeds = [self.maze_seeding.generate_env_instance_seed()
                                   for _ in range(self.eval_concurrency)]
        eval_env.seed(eval_env_instance_seeds)
        self.evaluators += [
            RolloutEvaluator(eval_env, n_episodes=cfg.algorithm.n_eval_episodes, model_selection=None)
        ]

def train_function(n_epochs: int, epoch_length: int, deterministic_eval: bool, eval_repeats: int,
                   distributed_env_cls, split_rollouts_into_transitions: bool) -> SAC:
    """Implements the lunar lander continuous env and performs tests on it w.r.t. the sac trainer."""

    # initialize distributed env
    env_factory = lambda: GymMazeEnv(env="LunarLanderContinuous-v2")

    # initialize the env and enable statistics collection
    eval_env = distributed_env_cls([env_factory for _ in range(2)], logging_prefix='eval')
    env = env_factory()

    # init distribution mapper
    distribution_mapper = DistributionMapper(
        action_space=env.action_space,
        distribution_mapper_config=[{
            'action_space': 'gym.spaces.Box',
            'distribution': 'maze.distributions.squashed_gaussian.SquashedGaussianProbabilityDistribution'
        }])

    action_shapes = {
        step_key: {
            action_head: tuple(distribution_mapper.required_logits_shape(action_head))
            for action_head in env.action_spaces_dict[step_key].spaces.keys()
        }
        for step_key in env.action_spaces_dict.keys()
    }
    obs_shapes = observation_spaces_to_in_shapes(env.observation_spaces_dict)

    # initialize policies
    policies = {
        ii: PolicyNet(obs_shapes=obs_shapes[ii], action_logits_shapes=action_shapes[ii], non_lin=nn.Tanh)
        for ii in obs_shapes.keys()
    }

    for key, value in env.action_spaces_dict.items():
        for act_key, act_space in value.spaces.items():
            obs_shapes[key][act_key] = act_space.sample().shape

    # initialize critic
    critics = {
        ii: QCriticNetContinuous(obs_shapes[ii], non_lin=nn.Tanh, action_spaces_dict=env.action_spaces_dict)
        for ii in obs_shapes.keys()
    }

    # initialize optimizer
    algorithm_config = SACAlgorithmConfig(
        n_rollout_steps=5,
        lr=0.001,
        entropy_coef=0.2,
        gamma=0.99,
        max_grad_norm=0.5,
        batch_size=100,
        num_actors=2,
        tau=0.005,
        target_update_interval=1,
        entropy_tuning=False,
        device='cpu',
        replay_buffer_size=10000,
        initial_buffer_size=100,
        initial_sampling_policy={'_target_': 'maze.core.agent.random_policy.RandomPolicy'},
        rollouts_per_iteration=1,
        split_rollouts_into_transitions=split_rollouts_into_transitions,
        entropy_coef_lr=0.0007,
        num_batches_per_iter=1,
        n_epochs=n_epochs,
        epoch_length=epoch_length,
        rollout_evaluator=RolloutEvaluator(eval_env=eval_env,
                                           n_episodes=eval_repeats,
                                           model_selection=None,
                                           deterministic=deterministic_eval),
        patience=50,
        target_entropy_multiplier=1.0)

    actor_policy = TorchPolicy(networks=policies, distribution_mapper=distribution_mapper, device='cpu')

    replay_buffer = UniformReplayBuffer(buffer_size=algorithm_config.replay_buffer_size, seed=1234)
    SACRunner.init_replay_buffer(
        replay_buffer=replay_buffer,
        initial_sampling_policy=algorithm_config.initial_sampling_policy,
        initial_buffer_size=algorithm_config.initial_buffer_size,
        replay_buffer_seed=1234,
        split_rollouts_into_transitions=split_rollouts_into_transitions,
        n_rollout_steps=algorithm_config.n_rollout_steps,
        env_factory=env_factory)

    distributed_actors = DummyDistributedWorkersWithBuffer(
        env_factory=env_factory,
        worker_policy=actor_policy,
        n_rollout_steps=algorithm_config.n_rollout_steps,
        n_workers=algorithm_config.num_actors,
        batch_size=algorithm_config.batch_size,
        rollouts_per_iteration=algorithm_config.rollouts_per_iteration,
        split_rollouts_into_transitions=split_rollouts_into_transitions,
        env_instance_seeds=list(range(algorithm_config.num_actors)),
        replay_buffer=replay_buffer)

    critics_policy = TorchStepStateActionCritic(networks=critics,
                                                num_policies=1,
                                                device='cpu',
                                                only_discrete_spaces={0: False},
                                                action_spaces_dict=env.action_spaces_dict)

    learner_model = TorchActorCritic(policy=actor_policy, critic=critics_policy, device='cpu')

    # initialize trainer
    sac = SAC(learner_model=learner_model,
              distributed_actors=distributed_actors,
              algorithm_config=algorithm_config,
              evaluator=algorithm_config.rollout_evaluator,
              model_selection=None)

    # train agent
    sac.train(n_epochs=algorithm_config.n_epochs)

    return sac

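# Hedged usage sketch (not part of the original source): exercising the SAC helper above with
# sequential workers and transition-level replay. Argument values are illustrative only, and
# SequentialVectorEnv is assumed to be importable in this module.
def example_run_sac_lunar_lander() -> None:
    """Illustrative only: trains for one short epoch and evaluates deterministically."""
    sac = train_function(n_epochs=1,
                         epoch_length=2,
                         deterministic_eval=True,
                         eval_repeats=1,
                         distributed_env_cls=SequentialVectorEnv,
                         split_rollouts_into_transitions=True)
    assert isinstance(sac, SAC)
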
def main(n_epochs: int, rnn_steps: int) -> None:
    """Trains the cart pole environment with the multi-step a2c implementation."""
    env_name = "CartPole-v0"

    # initialize distributed env
    envs = SequentialVectorEnv(
        [lambda: to_rnn_dict_space_environment(env=env_name, rnn_steps=rnn_steps) for _ in range(4)],
        logging_prefix="train")

    # initialize the env and enable statistics collection
    eval_env = SequentialVectorEnv(
        [lambda: to_rnn_dict_space_environment(env=env_name, rnn_steps=rnn_steps) for _ in range(4)],
        logging_prefix="eval")

    # map observations to a modality
    obs_modalities_mappings = {"observation": "feature"}

    # define how to process a modality
    modality_config = dict()
    modality_config["feature"] = {
        "block_type": "maze.perception.blocks.DenseBlock",
        "block_params": {
            "hidden_units": [32, 32],
            "non_lin": "torch.nn.Tanh"
        }
    }
    modality_config["hidden"] = {
        "block_type": "maze.perception.blocks.DenseBlock",
        "block_params": {
            "hidden_units": [64],
            "non_lin": "torch.nn.Tanh"
        }
    }
    modality_config["recurrence"] = {}
    if rnn_steps > 0:
        modality_config["recurrence"] = {
            "block_type": "maze.perception.blocks.LSTMLastStepBlock",
            "block_params": {
                "hidden_size": 8,
                "num_layers": 1,
                "bidirectional": False,
                "non_lin": "torch.nn.Tanh"
            }
        }

    template_builder = TemplateModelComposer(
        action_spaces_dict=envs.action_spaces_dict,
        observation_spaces_dict=envs.observation_spaces_dict,
        agent_counts_dict=envs.agent_counts_dict,
        distribution_mapper_config={},
        model_builder=ConcatModelBuilder(modality_config, obs_modalities_mappings, None),
        policy={'_target_': 'maze.perception.models.policies.ProbabilisticPolicyComposer'},
        critic={'_target_': 'maze.perception.models.critics.StateCriticComposer'})

    algorithm_config = A2CAlgorithmConfig(
        n_epochs=n_epochs,
        epoch_length=10,
        patience=10,
        critic_burn_in_epochs=0,
        n_rollout_steps=20,
        lr=0.0005,
        gamma=0.98,
        gae_lambda=1.0,
        policy_loss_coef=1.0,
        value_loss_coef=0.5,
        entropy_coef=0.0,
        max_grad_norm=0.0,
        device="cpu",
        rollout_evaluator=RolloutEvaluator(eval_env=eval_env, n_episodes=1, model_selection=None, deterministic=True))

    model = TorchActorCritic(
        policy=TorchPolicy(networks=template_builder.policy.networks,
                           distribution_mapper=template_builder.distribution_mapper,
                           device=algorithm_config.device),
        critic=template_builder.critic,
        device=algorithm_config.device)

    a2c = A2C(rollout_generator=RolloutGenerator(envs),
              evaluator=algorithm_config.rollout_evaluator,
              algorithm_config=algorithm_config,
              model=model,
              model_selection=None)

    setup_logging(job_config=None)

    # train agent
    a2c.train()

    # final evaluation run
    print("Final Evaluation Run:")
    a2c.evaluate()

def test_concepts_and_structures_run_context_overview():
    """Tests snippets in docs/source/concepts_and_structure/run_context_overview.rst."""

    # Default overrides for faster tests. Shouldn't change functionality.
    ac_overrides = {"runner.concurrency": 1}
    es_overrides = {"algorithm.n_epochs": 1, "algorithm.n_rollouts_per_update": 1}

    # Training
    # --------

    rc = RunContext(
        algorithm="a2c",
        overrides={"env.name": "CartPole-v0", **ac_overrides},
        model="vector_obs",
        critic="template_state",
        runner="dev",
        configuration="test")
    rc.train(n_epochs=1)

    alg_config = A2CAlgorithmConfig(
        n_epochs=1,
        epoch_length=25,
        patience=15,
        critic_burn_in_epochs=0,
        n_rollout_steps=100,
        lr=0.0005,
        gamma=0.98,
        gae_lambda=1.0,
        policy_loss_coef=1.0,
        value_loss_coef=0.5,
        entropy_coef=0.00025,
        max_grad_norm=0.0,
        device='cpu',
        rollout_evaluator=RolloutEvaluator(
            eval_env=SequentialVectorEnv([lambda: GymMazeEnv("CartPole-v0")]),
            n_episodes=1,
            model_selection=None,
            deterministic=True))

    rc = RunContext(
        algorithm=alg_config,
        overrides={"env.name": "CartPole-v0", **ac_overrides},
        model="vector_obs",
        critic="template_state",
        runner="dev",
        configuration="test")
    rc.train(n_epochs=1)

    rc = RunContext(env=lambda: GymMazeEnv('CartPole-v0'), overrides=es_overrides, runner="dev", configuration="test")
    rc.train(n_epochs=1)

    policy_composer_config = {
        '_target_': 'maze.perception.models.policies.ProbabilisticPolicyComposer',
        'networks': [{
            '_target_': 'maze.perception.models.built_in.flatten_concat.FlattenConcatPolicyNet',
            'non_lin': 'torch.nn.Tanh',
            'hidden_units': [256, 256]
        }],
        "substeps_with_separate_agent_nets": [],
        "agent_counts_dict": {0: 1}
    }
    rc = RunContext(
        overrides={"model.policy": policy_composer_config, **es_overrides},
        runner="dev",
        configuration="test")
    rc.train(n_epochs=1)

    env = GymMazeEnv('CartPole-v0')
    policy_composer = ProbabilisticPolicyComposer(
        action_spaces_dict=env.action_spaces_dict,
        observation_spaces_dict=env.observation_spaces_dict,
        distribution_mapper=DistributionMapper(action_space=env.action_space, distribution_mapper_config={}),
        networks=[{
            '_target_': 'maze.perception.models.built_in.flatten_concat.FlattenConcatPolicyNet',
            'non_lin': 'torch.nn.Tanh',
            'hidden_units': [222, 222]
        }],
        substeps_with_separate_agent_nets=[],
        agent_counts_dict={0: 1})
    rc = RunContext(overrides={"model.policy": policy_composer, **es_overrides}, runner="dev", configuration="test")
    rc.train(n_epochs=1)

    rc = RunContext(algorithm=alg_config, overrides=ac_overrides, runner="dev", configuration="test")
    rc.train(n_epochs=1)
    rc.train()

    # Rollout
    # -------

    obs = env.reset()
    for i in range(10):
        action = rc.compute_action(obs)
        obs, rewards, dones, info = env.step(action)

    # Evaluation
    # ----------

    env.reset()
    evaluator = RolloutEvaluator(
        # Environment has to have statistics logging capabilities for RolloutEvaluator.
        eval_env=LogStatsWrapper.wrap(env, logging_prefix="eval"),
        n_episodes=1,
        model_selection=None)
    evaluator.evaluate(rc.policy)

def train(n_epochs: int) -> int:
    """
    Trains agent in pure Python.
    :param n_epochs: Number of epochs to train.
    :return: 0 if successful.
    """

    # Environment setup
    # -----------------
    env = cartpole_env_factory()

    # Algorithm setup
    # ---------------
    algorithm_config = A2CAlgorithmConfig(
        n_epochs=5,
        epoch_length=25,
        patience=15,
        critic_burn_in_epochs=0,
        n_rollout_steps=100,
        lr=0.0005,
        gamma=0.98,
        gae_lambda=1.0,
        policy_loss_coef=1.0,
        value_loss_coef=0.5,
        entropy_coef=0.00025,
        max_grad_norm=0.0,
        device='cpu',
        rollout_evaluator=RolloutEvaluator(
            eval_env=SequentialVectorEnv([cartpole_env_factory]),
            n_episodes=1,
            model_selection=None,
            deterministic=True))

    # Custom model setup
    # ------------------

    # Policy customization
    # ^^^^^^^^^^^^^^^^^^^^

    # Policy network.
    policy_net = CartpolePolicyNet(
        obs_shapes={'observation': env.observation_space.spaces['observation'].shape},
        action_logit_shapes={'action': (env.action_space.spaces['action'].n,)})
    policy_networks = [policy_net]

    # Policy distribution.
    distribution_mapper = DistributionMapper(action_space=env.action_space, distribution_mapper_config={})

    # Policy composer.
    policy_composer = ProbabilisticPolicyComposer(
        action_spaces_dict=env.action_spaces_dict,
        observation_spaces_dict=env.observation_spaces_dict,
        # Derive distribution from environment's action space.
        distribution_mapper=distribution_mapper,
        networks=policy_networks,
        # We have only one agent and network, thus this is an empty list.
        substeps_with_separate_agent_nets=[],
        # We have only one step and one agent.
        agent_counts_dict={0: 1})

    # Critic customization
    # ^^^^^^^^^^^^^^^^^^^^

    # Value networks.
    value_networks = {
        0: TorchModelBlock(
            in_keys='observation',
            out_keys='value',
            in_shapes=env.observation_space.spaces['observation'].shape,
            in_num_dims=[2],
            out_num_dims=2,
            net=CartpoleValueNet({'observation': env.observation_space.spaces['observation'].shape}))
    }

    # Critic composer.
    critic_composer = SharedStateCriticComposer(
        observation_spaces_dict=env.observation_spaces_dict,
        agent_counts_dict={0: 1},
        networks=value_networks,
        stack_observations=True)

    # Training
    # ^^^^^^^^
    rc = run_context.RunContext(
        env=cartpole_env_factory,
        algorithm=algorithm_config,
        policy=policy_composer,
        critic=critic_composer,
        runner="dev")
    rc.train(n_epochs=n_epochs)

    # Distributed training
    # ^^^^^^^^^^^^^^^^^^^^
    algorithm_config.rollout_evaluator.eval_env = SubprocVectorEnv([cartpole_env_factory])
    rc = run_context.RunContext(
        env=cartpole_env_factory,
        algorithm=algorithm_config,
        policy=policy_composer,
        critic=critic_composer,
        runner="local")
    rc.train(n_epochs=n_epochs)

    # Evaluation
    # ^^^^^^^^^^
    print("-----------------")
    evaluator = RolloutEvaluator(
        eval_env=LogStatsWrapper.wrap(cartpole_env_factory(), logging_prefix="eval"),
        n_episodes=1,
        model_selection=None)
    evaluator.evaluate(rc.policy)

    return 0

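# Hedged usage sketch (not part of the original source): how the pure-Python training entry
# point above might be invoked; the single epoch is illustrative only.
def example_run_pure_python_training() -> None:
    """Illustrative only: a successful run returns 0."""
    assert train(n_epochs=1) == 0
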
def main(n_epochs: int) -> None:
    """Trains the cart pole environment with the multi-step a2c implementation."""

    # initialize distributed env
    envs = SequentialVectorEnv([lambda: GymMazeEnv(env="CartPole-v0") for _ in range(8)], logging_prefix="train")

    # initialize the env and enable statistics collection
    eval_env = SequentialVectorEnv([lambda: GymMazeEnv(env="CartPole-v0") for _ in range(8)], logging_prefix="eval")

    # init distribution mapper
    env = GymMazeEnv(env="CartPole-v0")

    # init default distribution mapper
    distribution_mapper = DistributionMapper(action_space=env.action_space, distribution_mapper_config={})

    # initialize policies
    policies = {0: PolicyNet({'observation': (4,)}, {'action': (2,)}, non_lin=nn.Tanh)}

    # initialize critic
    critics = {0: ValueNet({'observation': (4,)})}

    # initialize optimizer
    algorithm_config = A2CAlgorithmConfig(
        n_epochs=n_epochs,
        epoch_length=10,
        patience=10,
        critic_burn_in_epochs=0,
        n_rollout_steps=20,
        lr=0.0005,
        gamma=0.98,
        gae_lambda=1.0,
        policy_loss_coef=1.0,
        value_loss_coef=0.5,
        entropy_coef=0.0,
        max_grad_norm=0.0,
        device="cpu",
        rollout_evaluator=RolloutEvaluator(eval_env=eval_env, n_episodes=1, model_selection=None, deterministic=True))

    # initialize actor critic model
    model = TorchActorCritic(
        policy=TorchPolicy(networks=policies,
                           distribution_mapper=distribution_mapper,
                           device=algorithm_config.device),
        critic=TorchSharedStateCritic(networks=critics,
                                      obs_spaces_dict=env.observation_spaces_dict,
                                      device=algorithm_config.device,
                                      stack_observations=False),
        device=algorithm_config.device)

    a2c = A2C(rollout_generator=RolloutGenerator(envs),
              evaluator=algorithm_config.rollout_evaluator,
              algorithm_config=algorithm_config,
              model=model,
              model_selection=None)

    setup_logging(job_config=None)

    # train agent
    a2c.train()

    # final evaluation run
    print("Final Evaluation Run:")
    a2c.evaluate()

def train(n_epochs):
    # Instantiate one environment. This will be used for convenient access to observation
    # and action spaces.
    env = cartpole_env_factory()
    observation_space = env.observation_space
    action_space = env.action_space

    # Policy Setup
    # ------------

    # Policy Network
    # ^^^^^^^^^^^^^^
    # Instantiate policy with the correct shapes of observation and action spaces.
    policy_net = CartpolePolicyNet(
        obs_shapes={'observation': observation_space.spaces['observation'].shape},
        action_logit_shapes={'action': (action_space.spaces['action'].n,)})

    maze_wrapped_policy_net = TorchModelBlock(
        in_keys='observation',
        out_keys='action',
        in_shapes=observation_space.spaces['observation'].shape,
        in_num_dims=[2],
        out_num_dims=2,
        net=policy_net)

    policy_networks = {0: maze_wrapped_policy_net}

    # Policy Distribution
    # ^^^^^^^^^^^^^^^^^^^
    distribution_mapper = DistributionMapper(action_space=action_space, distribution_mapper_config={})

    # Optionally, you can specify a different distribution with the distribution_mapper_config argument. Using a
    # Categorical distribution for a discrete action space would be done via
    distribution_mapper = DistributionMapper(
        action_space=action_space,
        distribution_mapper_config=[{
            "action_space": gym.spaces.Discrete,
            "distribution": "maze.distributions.categorical.CategoricalProbabilityDistribution"}])

    # Instantiating the Policy
    # ^^^^^^^^^^^^^^^^^^^^^^^^
    torch_policy = TorchPolicy(networks=policy_networks, distribution_mapper=distribution_mapper, device='cpu')

    # Value Function Setup
    # --------------------

    # Value Network
    # ^^^^^^^^^^^^^
    value_net = CartpoleValueNet(obs_shapes={'observation': observation_space.spaces['observation'].shape})

    maze_wrapped_value_net = TorchModelBlock(
        in_keys='observation',
        out_keys='value',
        in_shapes=observation_space.spaces['observation'].shape,
        in_num_dims=[2],
        out_num_dims=2,
        net=value_net)

    value_networks = {0: maze_wrapped_value_net}

    # Instantiate the Value Function
    # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    torch_critic = TorchSharedStateCritic(networks=value_networks,
                                          obs_spaces_dict=env.observation_spaces_dict,
                                          device='cpu',
                                          stack_observations=False)

    # Initializing the ActorCritic Model
    # ----------------------------------
    actor_critic_model = TorchActorCritic(policy=torch_policy, critic=torch_critic, device='cpu')

    # Instantiating the Trainer
    # =========================
    algorithm_config = A2CAlgorithmConfig(
        n_epochs=n_epochs,
        epoch_length=25,
        patience=15,
        critic_burn_in_epochs=0,
        n_rollout_steps=100,
        lr=0.0005,
        gamma=0.98,
        gae_lambda=1.0,
        policy_loss_coef=1.0,
        value_loss_coef=0.5,
        entropy_coef=0.00025,
        max_grad_norm=0.0,
        device='cpu',
        rollout_evaluator=RolloutEvaluator(
            eval_env=SequentialVectorEnv([cartpole_env_factory]),
            n_episodes=1,
            model_selection=None,
            deterministic=True))

    # Distributed Environments
    # ------------------------
    # In order to use the distributed trainers, the previously created env factory is supplied to one of Maze's
    # distribution classes:
    train_envs = SequentialVectorEnv([cartpole_env_factory for _ in range(2)], logging_prefix="train")
    eval_envs = SequentialVectorEnv([cartpole_env_factory for _ in range(2)], logging_prefix="eval")

    # Initialize best model selection.
    model_selection = BestModelSelection(dump_file="params.pt", model=actor_critic_model)

    a2c_trainer = A2C(rollout_generator=RolloutGenerator(train_envs),
                      evaluator=algorithm_config.rollout_evaluator,
                      algorithm_config=algorithm_config,
                      model=actor_critic_model,
                      model_selection=model_selection)

    # Train the Agent
    # ===============
    # Before starting the training, we will enable logging by calling
    log_dir = '.'
    setup_logging(job_config=None, log_dir=log_dir)

    # Now, we can train the agent.
    a2c_trainer.train()

    return 0

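# Hedged usage sketch (not part of the original source): the tutorial-style train() above
# returns 0 on success and writes logs to the current working directory via setup_logging.
def example_run_a2c_tutorial_training() -> None:
    """Illustrative only: trains for a single epoch."""
    assert train(n_epochs=1) == 0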