Пример #1
0
def test_integration_with_cartpole():
    """Smoke-test DeterministicMCTSAgent solving a CartPole episode.

    The agent's ``solve`` call is driven by
    ``testing.run_with_dummy_network``; the test passes when the first
    dimension of the recorded observation batch is non-zero.
    """
    cartpole = envs.CartPole()
    mcts_agent = agents.DeterministicMCTSAgent(
        action_space=cartpole.action_space, n_passes=2
    )

    solving = mcts_agent.solve(cartpole)
    finished_episode = testing.run_with_dummy_network(solving)

    observations = finished_episode.transition_batch.observation  # pylint: disable=no-member
    assert observations.shape[0]
Пример #2
0
def test_integration_with_cartpole():
    """Smoke-test DeterministicMCTSAgent on CartPole with a network signature.

    The signature is obtained from the agent for the environment's
    observation and action spaces and handed to
    ``testing.run_with_dummy_network_prediction`` together with the
    ``solve`` call. The test passes when the first dimension of the
    recorded observation batch is non-zero.
    """
    cartpole = envs.CartPole()
    mcts_agent = agents.DeterministicMCTSAgent(n_passes=2)

    signature = mcts_agent.network_signature(
        cartpole.observation_space, cartpole.action_space
    )
    solving = mcts_agent.solve(cartpole)
    finished_episode = testing.run_with_dummy_network_prediction(
        solving, signature
    )

    observations = finished_episode.transition_batch.observation  # pylint: disable=no-member
    assert observations.shape[0]
Пример #3
0
def test_integration_with_cartpole():
    """Smoke-test ShootingAgent (one rollout) solving a CartPole episode.

    The ``solve`` call is driven by
    ``testing.run_with_dummy_network_response``; the test passes when the
    first dimension of the recorded observation batch is non-zero.
    """
    # Arrange.
    cartpole = envs.CartPole()
    shooting_agent = agents.ShootingAgent(n_rollouts=1)

    # Act.
    solving = shooting_agent.solve(cartpole)
    finished_episode = testing.run_with_dummy_network_response(solving)

    # Assert.
    observations = finished_episode.transition_batch.observation  # pylint: disable=no-member
    assert observations.shape[0]
Пример #4
0
def test_integration_with_cartpole():
    """Smoke-test ShootingAgent (one rollout) solving a CartPole episode.

    The agent is built with the environment's action space and its
    ``solve`` call is driven by ``testing.run_without_suspensions``. The
    test passes when the first dimension of the recorded observation batch
    is non-zero.
    """
    # Arrange.
    cartpole = envs.CartPole()
    shooting_agent = agents.ShootingAgent(
        action_space=cartpole.action_space,
        n_rollouts=1,
    )

    # Act.
    solving = shooting_agent.solve(cartpole)
    finished_episode = testing.run_without_suspensions(solving)

    # Assert.
    observations = finished_episode.transition_batch.observation  # pylint: disable=no-member
    assert observations.shape[0]
Пример #5
0
def test_integration_with_cartpole():
    """Smoke-test StochasticMCTSAgent with a rollout leaf rater on CartPole.

    New leaves are rated by ``RolloutNewLeafRater`` with
    ``rollout_time_limit=2`` bound in advance. The ``solve`` call is driven
    by ``testing.run_without_suspensions``; the test passes when the first
    dimension of the recorded observation batch is non-zero.
    """
    cartpole = envs.CartPole()

    rollout_rater_class = functools.partial(
        agents.stochastic_mcts.RolloutNewLeafRater,
        rollout_time_limit=2,
    )
    mcts_agent = agents.StochasticMCTSAgent(
        n_passes=2,
        new_leaf_rater_class=rollout_rater_class,
    )

    solving = mcts_agent.solve(cartpole)
    finished_episode = testing.run_without_suspensions(solving)

    observations = finished_episode.transition_batch.observation  # pylint: disable=no-member
    assert observations.shape[0]
def test_integration_with_cartpole(graph_mode):
    """Smoke-test StochasticMCTSAgent with rollout-based leaf rating.

    Args:
        graph_mode: Forwarded unchanged to the agent's ``graph_mode``
            argument (supplied by the test harness).

    New leaves are rated by ``rate_new_leaves_with_rollouts`` with
    ``rollout_time_limit=2`` bound in advance. The ``solve`` call is driven
    by ``testing.run_without_suspensions``; the test passes when the first
    dimension of the recorded observation batch is non-zero.
    """
    cartpole = envs.CartPole()

    rate_with_rollouts = functools.partial(
        agents.stochastic_mcts.rate_new_leaves_with_rollouts,
        rollout_time_limit=2,
    )
    mcts_agent = agents.StochasticMCTSAgent(
        action_space=cartpole.action_space,
        n_passes=2,
        rate_new_leaves_fn=rate_with_rollouts,
        graph_mode=graph_mode,
    )

    solving = mcts_agent.solve(cartpole)
    finished_episode = testing.run_without_suspensions(solving)

    observations = finished_episode.transition_batch.observation  # pylint: disable=no-member
    assert observations.shape[0]
Пример #7
0
def test_act_doesnt_change_env_state():
    """Check that ShootingAgent.act leaves the CartPole state untouched.

    The environment state is cloned immediately before and immediately
    after a single ``act`` call and the two snapshots are compared with
    ``np.testing.assert_equal``.
    """
    # Arrange: a reset environment and an agent reset against it.
    cartpole = envs.CartPole()
    shooting_agent = agents.ShootingAgent(n_rollouts=10)
    first_observation = cartpole.reset()
    resetting = shooting_agent.reset(cartpole, first_observation)
    testing.run_with_dummy_network_response(resetting)

    # Act: snapshot the state on both sides of one act() call.
    snapshot_before = cartpole.clone_state()
    acting = shooting_agent.act(first_observation)
    testing.run_without_suspensions(acting)
    snapshot_after = cartpole.clone_state()

    # Assert: both snapshots must be equal.
    np.testing.assert_equal(snapshot_before, snapshot_after)
def test_batch_steppers_different_requests_single_batch(batch_stepper_cls):
    """Run one episode batch in which agents yield three kinds of requests.

    Args:
        batch_stepper_cls: Batch stepper class under test (supplied by the
            test harness).

    Nine CartPole environments are paired with agents in three equal
    groups: ``_TestAgentYieldingNetworkRequest``,
    ``_TestAgentYieldingAgentPredictRequest`` and
    ``_TestAgentYieldingModelPredictRequest``. The test has no explicit
    assertion; it passes when ``run_episode_batch`` returns without raising.
    """
    dummy_params = 'params'
    n_envs = 9
    n_requests = 100
    group_size = n_envs // 3

    agent_network_fn = functools.partial(
        _TestPredictNetwork, request_type=RequestType.AGENT_PREDICTION)
    model_network_fn = functools.partial(
        _TestPredictNetwork, request_type=RequestType.MODEL_PREDICTION)

    # The env/agent/network arguments below are placeholders only: the
    # stepper's env-agent pairs are overwritten further down.
    stepper = batch_stepper_cls(
        env_class=envs.CartPole,
        agent_class=RandomAgent,
        network_fn=agent_network_fn,
        model_network_fn=model_network_fn,
        n_envs=n_envs,
        output_dir=None,
    )

    # One factory per agent kind, in the order the groups are laid out.
    agent_factories = (
        lambda: _TestAgentYieldingNetworkRequest(
            agent_network_fn, dummy_params, n_requests),
        lambda: _TestAgentYieldingAgentPredictRequest(n_requests),
        lambda: _TestAgentYieldingModelPredictRequest(n_requests),
    )
    mixed_pairs = []
    for make_agent in agent_factories:
        for _ in range(group_size):
            cartpole = envs.CartPole()
            mixed_pairs.append((cartpole, make_agent()))

    # The pairs are swapped in by hand because LocalBatchStepper does not
    # support running different types of agents simultaneously.
    stepper._envs_and_agents = mixed_pairs  # pylint: disable=protected-access
    stepper.run_episode_batch(dummy_params)
def test_act_doesnt_change_env_state(graph_mode, rate_new_leaves_fn):
    """Check that StochasticMCTSAgent.act leaves the CartPole state untouched.

    Args:
        graph_mode: Forwarded unchanged to the agent's ``graph_mode``
            argument (supplied by the test harness).
        rate_new_leaves_fn: Forwarded unchanged to the agent's
            ``rate_new_leaves_fn`` argument (supplied by the test harness).

    The environment state is cloned immediately before and immediately
    after a single ``act`` call and the two snapshots are compared with
    ``np.testing.assert_equal``.
    """
    cartpole = envs.CartPole()
    mcts_agent = agents.StochasticMCTSAgent(
        action_space=cartpole.action_space,
        n_passes=2,
        rate_new_leaves_fn=rate_new_leaves_fn,
        graph_mode=graph_mode,
    )
    first_observation = cartpole.reset()
    resetting = mcts_agent.reset(cartpole, first_observation)
    testing.run_without_suspensions(resetting)

    snapshot_before = cartpole.clone_state()
    acting = mcts_agent.act(first_observation)
    testing.run_with_dummy_network(acting)
    snapshot_after = cartpole.clone_state()

    np.testing.assert_equal(snapshot_before, snapshot_after)
Пример #10
0
def test_act_doesnt_change_env_state(new_leaf_rater_class):
    """Check that StochasticMCTSAgent.act leaves the CartPole state untouched.

    Args:
        new_leaf_rater_class: Forwarded unchanged to the agent's
            ``new_leaf_rater_class`` argument (supplied by the test
            harness).

    The environment state is cloned immediately before and immediately
    after a single ``act`` call, which is driven by
    ``testing.run_with_dummy_network_prediction`` using the agent's network
    signature. The two snapshots are compared with
    ``np.testing.assert_equal``.
    """
    cartpole = envs.CartPole()
    mcts_agent = agents.StochasticMCTSAgent(
        n_passes=2, new_leaf_rater_class=new_leaf_rater_class
    )
    first_observation = cartpole.reset()
    resetting = mcts_agent.reset(cartpole, first_observation)
    testing.run_without_suspensions(resetting)

    snapshot_before = cartpole.clone_state()
    # The signature is requested after the first snapshot, so the
    # comparison below also covers this call.
    signature = mcts_agent.network_signature(
        cartpole.observation_space, cartpole.action_space
    )
    acting = mcts_agent.act(first_observation)
    testing.run_with_dummy_network_prediction(acting, signature)
    snapshot_after = cartpole.clone_state()

    np.testing.assert_equal(snapshot_before, snapshot_after)