Example #1
0
def test_greedy_decision_for_all_return_estimators(mock_env, mock_bstep,
                                                   estimate_fn, x_action,
                                                   logits):
    """Checks the agent greedily picks the expected action per estimator."""
    # Set up
    agent = agents.ShootingAgent(
        estimate_fn=estimate_fn,
        batch_stepper_class=mock_bstep,
        n_rollouts=1,
    )
    # One-hot vector over the 3-action space with a 1 at the expected action.
    expected_histogram = np.eye(3)[x_action]

    # Run
    initial_obs = mock_env.reset()
    testing.run_with_dummy_network_response(
        agent.reset(mock_env, initial_obs)
    )
    actual_action, agent_info = testing.run_with_constant_network_prediction(
        agent.act(None),
        logits=logits
    )

    # Test
    assert actual_action == x_action
    np.testing.assert_array_equal(
        agent_info['action_histogram'], expected_histogram
    )
Example #2
0
def test_rollout_time_limit(mock_env, rollout_time_limit):
    """Verifies that rollouts are truncated at ``rollout_time_limit`` steps."""
    # Set up
    rollout_max_len = 10  # Must be greater than rollout_time_limit!
    mock_env.action_space.sample.return_value = 0
    # The mocked episode terminates (done=True) only on the last transition.
    transitions = [('d', 0, False, {})] * (rollout_max_len - 1)
    transitions.append(('d', 0, True, {}))
    mock_env.step.side_effect = transitions
    mock_env.clone_state.return_value = 's'
    mock_env.restore_state.return_value = 'o'

    # With no explicit limit, the rollout runs until the episode ends.
    expected_rollout_time_limit = (
        rollout_max_len if rollout_time_limit is None else rollout_time_limit
    )

    def _aggregate_fn(_, episodes):
        # Test - the collected rollout length equals the expected limit.
        assert len(episodes[0].transition_batch.done) == \
            expected_rollout_time_limit

    with mock.patch('alpacka.agents.shooting.type') as mock_type:
        mock_type.return_value = lambda: mock_env
        agent = agents.ShootingAgent(
            action_space=mock_env.action_space,
            aggregate_fn=_aggregate_fn,
            rollout_time_limit=rollout_time_limit,
            n_rollouts=1,
        )

        # Run
        observation = mock_env.reset()
        testing.run_without_suspensions(agent.reset(mock_env, observation))
        testing.run_without_suspensions(agent.act(None))
Example #3
0
def test_integration_with_cartpole():
    """Smoke test: the agent runs an episode on CartPole end to end."""
    # Set up
    agent = agents.ShootingAgent(n_rollouts=1)
    cartpole = envs.CartPole()

    # Run
    episode = testing.run_with_dummy_network_response(agent.solve(cartpole))

    # Test - at least one transition was collected.
    n_transitions = episode.transition_batch.observation.shape[0]  # pylint: disable=no-member
    assert n_transitions
Example #4
0
def test_integration_with_cartpole():
    """Smoke test: the agent runs an episode on CartPole end to end."""
    # Set up
    cartpole = envs.CartPole()
    agent = agents.ShootingAgent(
        action_space=cartpole.action_space, n_rollouts=1
    )

    # Run
    episode = testing.run_without_suspensions(agent.solve(cartpole))

    # Test - at least one transition was collected.
    n_transitions = episode.transition_batch.observation.shape[0]  # pylint: disable=no-member
    assert n_transitions
Example #5
0
def test_act_doesnt_change_env_state():
    """The agent must restore the env to its original state after planning."""
    # Set up
    env = envs.CartPole()
    agent = agents.ShootingAgent(n_rollouts=10)
    first_obs = env.reset()
    testing.run_with_dummy_network_response(agent.reset(env, first_obs))

    # Run - snapshot the env state on both sides of the act() call.
    snapshot_before = env.clone_state()
    testing.run_without_suspensions(agent.act(first_obs))
    snapshot_after = env.clone_state()

    # Test
    np.testing.assert_equal(snapshot_before, snapshot_after)
Example #6
0
def test_number_of_simulations(mock_env, mock_bstep):
    """The batch stepper runs ceil(n_rollouts / n_envs) episode batches."""
    # Set up
    n_rollouts, n_envs = 7, 2
    agent = agents.ShootingAgent(
        batch_stepper_class=mock_bstep,
        n_rollouts=n_rollouts,
        n_envs=n_envs,
    )

    # Run
    observation = mock_env.reset()
    testing.run_with_dummy_network_response(agent.reset(mock_env, observation))
    testing.run_without_suspensions(agent.act(None))

    # Test - 7 rollouts over 2 envs means 4 batches.
    expected_n_batches = math.ceil(n_rollouts / n_envs)
    assert mock_bstep.return_value.run_episode_batch.call_count == \
        expected_n_batches
Example #7
0
def test_greedy_decision_for_all_aggregators(mock_env, mock_bstep_class,
                                             aggregate_fn, expected_action):
    """The chosen action matches the aggregator's expected greedy pick."""
    # Set up
    agent = agents.ShootingAgent(
        action_space=mock_env.action_space,
        aggregate_fn=aggregate_fn,
        batch_stepper_class=mock_bstep_class,
        n_rollouts=1,
    )

    # Run
    init_obs = mock_env.reset()
    testing.run_without_suspensions(agent.reset(mock_env, init_obs))
    actual_action, _ = testing.run_without_suspensions(agent.act(None))

    # Test
    assert actual_action == expected_action
Example #8
0
def test_rollout_time_limit(mock_env, rollout_time_limit):
    """Verifies that rollout length equals ``rollout_time_limit``.

    When ``rollout_time_limit`` is None, rollouts should run until the
    episode terminates (``rollout_max_len`` steps here).
    """
    # Set up
    rollout_max_len = 10  # It must be greater than rollout_time_limit!
    mock_env.action_space.sample.return_value = 0
    # The mocked episode terminates (done=True) only on the last transition.
    mock_env.step.side_effect = \
        [('d', 0, False, {})] * (rollout_max_len - 1) + [('d', 0, True, {})]
    mock_env.clone_state.return_value = 's'
    mock_env.restore_state.return_value = 'o'

    if rollout_time_limit is None:
        x_rollout_time_limit = rollout_max_len
    else:
        x_rollout_time_limit = rollout_time_limit

    # Native coroutine instead of the deprecated ``@asyncio.coroutine``
    # decorator (deprecated in Python 3.8, removed in 3.11); the agent
    # awaits it the same way.
    async def _estimate_fn(episodes, discount):
        del discount
        # Test - every rollout must respect the time limit.
        for episode in episodes:
            actual_rollout_time_limit = len(episode.transition_batch.done)
            assert actual_rollout_time_limit == x_rollout_time_limit

        return [1.] * len(episodes)

    with mock.patch('alpacka.agents.mc_simulation.type') as mock_type:
        mock_type.return_value = lambda: mock_env
        agent = agents.ShootingAgent(
            n_rollouts=1,
            rollout_time_limit=rollout_time_limit,
            estimate_fn=_estimate_fn,
            n_envs=1,
        )

        # Run
        observation = mock_env.reset()
        testing.run_with_dummy_network_response(
            agent.reset(mock_env, observation)
        )
        testing.run_without_suspensions(agent.act(None))