def test_examples_part2():
    """
    Tests snippets in maze/docs/source/concepts_and_structure/run_context_overview.rst. Adds some performance-specific
    configuration that should not influence snippets' functionality. Split for runtime reasons.
    """
    a2c_overrides = {"runner.concurrency": 1}
    es_overrides = {"algorithm.n_epochs": 1, "algorithm.n_rollouts_per_update": 1}
    env_factory = lambda: GymMazeEnv('CartPole-v0')
    alg_config = _get_alg_config('CartPole-v0', "dev")

    # ------------------------------------------------------------------

    rc = RunContext(algorithm=alg_config, runner="dev", overrides=a2c_overrides, configuration="test")
    rc.train(n_epochs=1)

    # ------------------------------------------------------------------

    with pytest.raises(omegaconf.errors.ConfigAttributeError):
        rc = RunContext(overrides={"algorithm": alg_config}, runner="dev")
        rc.train(n_epochs=1)

    # ------------------------------------------------------------------

    rc = RunContext(env=lambda: env_factory(), overrides=es_overrides, runner="dev", configuration="test")
    rc.train(n_epochs=1)

    # Run trained policy.
    env = env_factory()
    obs = env.reset()
    for i in range(10):
        action = rc.compute_action(obs)
        obs, rewards, dones, info = env.step(action)

    # ------------------------------------------------------------------

    rc = RunContext(env=lambda: GymMazeEnv('CartPole-v0'), overrides=es_overrides, runner="dev", configuration="test")
    rc.train()
    rc.evaluate()
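
# A minimal sketch of a helper like the _get_alg_config used above, assuming it assembles an
# A2CAlgorithmConfig analogous to the one built inline in test_concepts_and_structures_run_context_overview
# below. The parameter names and the meaning of the second argument ("dev") are assumptions, and the
# sketch relies on the module-level imports already used by these tests; the actual helper is defined
# elsewhere in this test module.
def _get_alg_config_sketch(env_name: str, configuration: str) -> A2CAlgorithmConfig:
    return A2CAlgorithmConfig(
        n_epochs=1,
        epoch_length=25,
        patience=15,
        critic_burn_in_epochs=0,
        n_rollout_steps=100,
        lr=0.0005,
        gamma=0.98,
        gae_lambda=1.0,
        policy_loss_coef=1.0,
        value_loss_coef=0.5,
        entropy_coef=0.00025,
        max_grad_norm=0.0,
        device='cpu',
        rollout_evaluator=RolloutEvaluator(
            eval_env=SequentialVectorEnv([lambda: GymMazeEnv(env_name)]),
            n_episodes=1,
            model_selection=None,
            deterministic=True
        )
    )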
def test_readme():
    """ Tests snippets in readme.md. """
    rc = RunContext(env=lambda: GymMazeEnv('CartPole-v0'))
    rc.train(n_epochs=1)

    # Run trained policy.
    env = GymMazeEnv('CartPole-v0')
    obs = env.reset()
    done = False
    while not done:
        action = rc.compute_action(obs)
        obs, reward, done, info = env.step(action)
        break
def test_concepts_and_structures_run_context_overview():
    """
    Tests snippets in docs/source/concepts_and_structure/run_context_overview.rst.
    """
    # Default overrides for faster tests. Shouldn't change functionality.
    ac_overrides = {"runner.concurrency": 1}
    es_overrides = {"algorithm.n_epochs": 1, "algorithm.n_rollouts_per_update": 1}

    # Training
    # --------

    rc = RunContext(
        algorithm="a2c",
        overrides={"env.name": "CartPole-v0", **ac_overrides},
        model="vector_obs",
        critic="template_state",
        runner="dev",
        configuration="test"
    )
    rc.train(n_epochs=1)

    alg_config = A2CAlgorithmConfig(
        n_epochs=1,
        epoch_length=25,
        patience=15,
        critic_burn_in_epochs=0,
        n_rollout_steps=100,
        lr=0.0005,
        gamma=0.98,
        gae_lambda=1.0,
        policy_loss_coef=1.0,
        value_loss_coef=0.5,
        entropy_coef=0.00025,
        max_grad_norm=0.0,
        device='cpu',
        rollout_evaluator=RolloutEvaluator(
            eval_env=SequentialVectorEnv([lambda: GymMazeEnv("CartPole-v0")]),
            n_episodes=1,
            model_selection=None,
            deterministic=True
        )
    )

    rc = RunContext(
        algorithm=alg_config,
        overrides={"env.name": "CartPole-v0", **ac_overrides},
        model="vector_obs",
        critic="template_state",
        runner="dev",
        configuration="test"
    )
    rc.train(n_epochs=1)

    rc = RunContext(env=lambda: GymMazeEnv('CartPole-v0'), overrides=es_overrides, runner="dev", configuration="test")
    rc.train(n_epochs=1)

    policy_composer_config = {
        '_target_': 'maze.perception.models.policies.ProbabilisticPolicyComposer',
        'networks': [{
            '_target_': 'maze.perception.models.built_in.flatten_concat.FlattenConcatPolicyNet',
            'non_lin': 'torch.nn.Tanh',
            'hidden_units': [256, 256]
        }],
        "substeps_with_separate_agent_nets": [],
        "agent_counts_dict": {0: 1}
    }
    rc = RunContext(
        overrides={"model.policy": policy_composer_config, **es_overrides}, runner="dev", configuration="test"
    )
    rc.train(n_epochs=1)

    env = GymMazeEnv('CartPole-v0')
    policy_composer = ProbabilisticPolicyComposer(
        action_spaces_dict=env.action_spaces_dict,
        observation_spaces_dict=env.observation_spaces_dict,
        distribution_mapper=DistributionMapper(action_space=env.action_space, distribution_mapper_config={}),
        networks=[{
            '_target_': 'maze.perception.models.built_in.flatten_concat.FlattenConcatPolicyNet',
            'non_lin': 'torch.nn.Tanh',
            'hidden_units': [222, 222]
        }],
        substeps_with_separate_agent_nets=[],
        agent_counts_dict={0: 1}
    )
    rc = RunContext(overrides={"model.policy": policy_composer, **es_overrides}, runner="dev", configuration="test")
    rc.train(n_epochs=1)

    rc = RunContext(algorithm=alg_config, overrides=ac_overrides, runner="dev", configuration="test")
    rc.train(n_epochs=1)
    rc.train()

    # Rollout
    # -------

    obs = env.reset()
    for i in range(10):
        action = rc.compute_action(obs)
        obs, rewards, dones, info = env.step(action)

    # Evaluation
    # ----------

    env.reset()
    evaluator = RolloutEvaluator(
        # Environment has to have statistics logging capabilities for RolloutEvaluator.
        eval_env=LogStatsWrapper.wrap(env, logging_prefix="eval"),
        n_episodes=1,
        model_selection=None
    )
    evaluator.evaluate(rc.policy)