コード例 #1
0
def test_examples_part2():
    """
    Second batch of snippet tests for maze/docs/source/concepts_and_structure/run_context_overview.rst.

    The override dictionaries defined below are performance-related configuration only and are not meant to
    change what the snippets do. These snippets live in their own test function for runtime reasons.
    """

    # Performance-only settings shared by the snippets below.
    single_worker_overrides = {"runner.concurrency": 1}
    quick_es_overrides = {"algorithm.n_epochs": 1, "algorithm.n_rollouts_per_update": 1}
    make_env = lambda: GymMazeEnv('CartPole-v0')
    algorithm_config = _get_alg_config('CartPole-v0', "dev")

    # --- Snippet: algorithm passed directly as a configuration object -----------------------------

    run_context = RunContext(
        algorithm=algorithm_config,
        runner="dev",
        overrides=single_worker_overrides,
        configuration="test",
    )
    run_context.train(n_epochs=1)

    # --- Snippet: algorithm passed through `overrides` ---------------------------------------------
    # This combination is expected to fail with a ConfigAttributeError somewhere in the two statements.

    with pytest.raises(omegaconf.errors.ConfigAttributeError):
        run_context = RunContext(overrides={"algorithm": algorithm_config}, runner="dev")
        run_context.train(n_epochs=1)

    # --- Snippet: environment factory, training, then a short manual rollout -----------------------

    run_context = RunContext(
        env=lambda: make_env(),
        overrides=quick_es_overrides,
        runner="dev",
        configuration="test",
    )
    run_context.train(n_epochs=1)

    # Run trained policy.
    rollout_env = make_env()
    observation = rollout_env.reset()
    for _ in range(10):
        observation, _rewards, _dones, _info = rollout_env.step(run_context.compute_action(observation))

    # --- Snippet: train with configured defaults, then evaluate ------------------------------------

    run_context = RunContext(
        env=lambda: GymMazeEnv('CartPole-v0'),
        overrides=quick_es_overrides,
        runner="dev",
        configuration="test",
    )
    run_context.train()
    run_context.evaluate()
コード例 #2
0
def test_readme():
    """
    Smoke test for the snippet shown in readme.md: build a RunContext for CartPole-v0, call
    train(n_epochs=1) and take a single step with the resulting policy.
    """

    run_context = RunContext(env=lambda: GymMazeEnv('CartPole-v0'))
    run_context.train(n_epochs=1)

    # Run trained policy.
    cartpole = GymMazeEnv('CartPole-v0')
    observation = cartpole.reset()
    episode_over = False

    # Same loop shape as a full-episode rollout, but the unconditional break stops it after one step.
    while not episode_over:
        observation, _reward, episode_over, _info = cartpole.step(run_context.compute_action(observation))
        break
コード例 #3
0
def test_concepts_and_structures_run_context_overview():
    """
    Runs the snippets from docs/source/concepts_and_structure/run_context_overview.rst in sequence:
    several ways of configuring and training a RunContext, a short manual rollout, and an evaluation.
    """

    # Default overrides for faster tests. They are not supposed to change functionality.
    a2c_speedups = {"runner.concurrency": 1}
    es_speedups = {"algorithm.n_epochs": 1, "algorithm.n_rollouts_per_update": 1}

    # Training
    # --------

    # Algorithm, model and critic selected by name.
    run_context = RunContext(
        algorithm="a2c",
        overrides={"env.name": "CartPole-v0", **a2c_speedups},
        model="vector_obs",
        critic="template_state",
        runner="dev",
        configuration="test",
    )
    run_context.train(n_epochs=1)

    # Algorithm given as an explicitly constructed configuration object.
    a2c_rollout_evaluator = RolloutEvaluator(
        eval_env=SequentialVectorEnv([lambda: GymMazeEnv("CartPole-v0")]),
        n_episodes=1,
        model_selection=None,
        deterministic=True,
    )
    a2c_config = A2CAlgorithmConfig(
        n_epochs=1,
        epoch_length=25,
        patience=15,
        critic_burn_in_epochs=0,
        n_rollout_steps=100,
        lr=0.0005,
        gamma=0.98,
        gae_lambda=1.0,
        policy_loss_coef=1.0,
        value_loss_coef=0.5,
        entropy_coef=0.00025,
        max_grad_norm=0.0,
        device='cpu',
        rollout_evaluator=a2c_rollout_evaluator,
    )

    run_context = RunContext(
        algorithm=a2c_config,
        overrides={"env.name": "CartPole-v0", **a2c_speedups},
        model="vector_obs",
        critic="template_state",
        runner="dev",
        configuration="test",
    )
    run_context.train(n_epochs=1)

    # Environment given as a factory function.
    run_context = RunContext(
        env=lambda: GymMazeEnv('CartPole-v0'),
        overrides=es_speedups,
        runner="dev",
        configuration="test",
    )
    run_context.train(n_epochs=1)

    # Policy given as a configuration dictionary inside the overrides.
    policy_as_dict = {
        '_target_': 'maze.perception.models.policies.ProbabilisticPolicyComposer',
        'networks': [
            {
                '_target_': 'maze.perception.models.built_in.flatten_concat.FlattenConcatPolicyNet',
                'non_lin': 'torch.nn.Tanh',
                'hidden_units': [256, 256],
            }
        ],
        "substeps_with_separate_agent_nets": [],
        "agent_counts_dict": {0: 1},
    }
    run_context = RunContext(
        overrides={"model.policy": policy_as_dict, **es_speedups},
        runner="dev",
        configuration="test",
    )
    run_context.train(n_epochs=1)

    # Policy given as an already instantiated composer inside the overrides.
    cartpole = GymMazeEnv('CartPole-v0')
    policy_as_object = ProbabilisticPolicyComposer(
        action_spaces_dict=cartpole.action_spaces_dict,
        observation_spaces_dict=cartpole.observation_spaces_dict,
        distribution_mapper=DistributionMapper(action_space=cartpole.action_space, distribution_mapper_config={}),
        networks=[
            {
                '_target_': 'maze.perception.models.built_in.flatten_concat.FlattenConcatPolicyNet',
                'non_lin': 'torch.nn.Tanh',
                'hidden_units': [222, 222],
            }
        ],
        substeps_with_separate_agent_nets=[],
        agent_counts_dict={0: 1},
    )
    run_context = RunContext(
        overrides={"model.policy": policy_as_object, **es_speedups},
        runner="dev",
        configuration="test",
    )
    run_context.train(n_epochs=1)

    # Train once with an explicit epoch count and once relying on the configured value.
    run_context = RunContext(algorithm=a2c_config, overrides=a2c_speedups, runner="dev", configuration="test")
    run_context.train(n_epochs=1)
    run_context.train()

    # Rollout
    # -------

    observation = cartpole.reset()
    for _ in range(10):
        observation, _rewards, _dones, _info = cartpole.step(run_context.compute_action(observation))

    # Evaluation
    # ----------

    cartpole.reset()
    # Environment has to have statistics logging capabilities for RolloutEvaluator.
    stats_logging_env = LogStatsWrapper.wrap(cartpole, logging_prefix="eval")
    evaluator = RolloutEvaluator(
        eval_env=stats_logging_env,
        n_episodes=1,
        model_selection=None,
    )
    evaluator.evaluate(run_context.policy)