def test_greedy_decision_for_all_return_estimators(mock_env, mock_bstep, estimate_fn, x_action, logits):
    """The agent should pick the action favored by the return estimator.

    Also checks that the reported action histogram is one-hot on that action.
    """
    # Set up
    agent = agents.ShootingAgent(
        estimate_fn=estimate_fn,
        batch_stepper_class=mock_bstep,
        n_rollouts=1,
    )
    # Expected histogram: one-hot over the 3 actions, peaked at x_action.
    expected_histogram = np.zeros(3)
    expected_histogram[x_action] = 1

    # Run
    observation = mock_env.reset()
    testing.run_with_dummy_network_response(
        agent.reset(mock_env, observation)
    )
    (chosen_action, info) = testing.run_with_constant_network_prediction(
        agent.act(None), logits=logits
    )

    # Test
    assert chosen_action == x_action
    np.testing.assert_array_equal(
        info['action_histogram'], expected_histogram
    )
def test_rollout_time_limit(mock_env, rollout_time_limit):
    """Rollouts should be truncated at rollout_time_limit.

    When no limit is given, the rollout runs until the episode ends.
    """
    # NOTE(review): another test with this exact name appears later in this
    # file — pytest only collects the last definition; consider renaming one.
    # Set up
    rollout_max_len = 10  # Must be greater than rollout_time_limit!
    mock_env.action_space.sample.return_value = 0
    # Episode terminates (done=True) only on the rollout_max_len-th step.
    mock_env.step.side_effect = (
        [('d', 0, False, {})] * (rollout_max_len - 1) + [('d', 0, True, {})]
    )
    mock_env.clone_state.return_value = 's'
    mock_env.restore_state.return_value = 'o'

    expected_rollout_time_limit = (
        rollout_max_len if rollout_time_limit is None else rollout_time_limit
    )

    def _aggregate_fn(_, episodes):
        # Test: the rollout length equals the expected limit.
        assert len(episodes[0].transition_batch.done) == \
            expected_rollout_time_limit

    with mock.patch('alpacka.agents.shooting.type') as mock_type:
        # Make the agent's internally-constructed env be our mock.
        mock_type.return_value = lambda: mock_env
        agent = agents.ShootingAgent(
            action_space=mock_env.action_space,
            aggregate_fn=_aggregate_fn,
            rollout_time_limit=rollout_time_limit,
            n_rollouts=1,
        )

        # Run
        observation = mock_env.reset()
        testing.run_without_suspensions(agent.reset(mock_env, observation))
        testing.run_without_suspensions(agent.act(None))
def test_integration_with_cartpole():
    """Smoke test: the agent solves a CartPole episode end to end."""
    # NOTE(review): a test with this exact name appears later in the file;
    # pytest collects only the last definition, so this one is shadowed —
    # consider renaming one of them.
    # Set up
    env = envs.CartPole()
    agent = agents.ShootingAgent(n_rollouts=1)

    # Run
    episode = testing.run_with_dummy_network_response(agent.solve(env))

    # Test: at least one transition was collected.
    assert episode.transition_batch.observation.shape[0]  # pylint: disable=no-member
def test_integration_with_cartpole():
    """Smoke test: the agent solves a CartPole episode end to end."""
    # NOTE(review): a test with this exact name appears earlier in the file;
    # pytest collects only this (last) definition — consider renaming one.
    # Set up
    env = envs.CartPole()
    agent = agents.ShootingAgent(action_space=env.action_space, n_rollouts=1)

    # Run
    episode = testing.run_without_suspensions(agent.solve(env))

    # Test: at least one transition was collected.
    assert episode.transition_batch.observation.shape[0]  # pylint: disable=no-member
def test_act_doesnt_change_env_state():
    """act() must leave the environment in its pre-call state."""
    # Set up
    env = envs.CartPole()
    agent = agents.ShootingAgent(n_rollouts=10)
    observation = env.reset()
    testing.run_with_dummy_network_response(agent.reset(env, observation))

    # Run: snapshot the env state around the act() call.
    state_before = env.clone_state()
    testing.run_without_suspensions(agent.act(observation))
    state_after = env.clone_state()

    # Test
    np.testing.assert_equal(state_before, state_after)
def test_number_of_simulations(mock_env, mock_bstep):
    """The batch stepper should run ceil(n_rollouts / n_envs) batches."""
    # Set up
    n_rollouts = 7
    n_envs = 2
    agent = agents.ShootingAgent(
        batch_stepper_class=mock_bstep,
        n_rollouts=n_rollouts,
        n_envs=n_envs,
    )

    # Run
    observation = mock_env.reset()
    testing.run_with_dummy_network_response(agent.reset(mock_env, observation))
    testing.run_without_suspensions(agent.act(None))

    # Test: rollouts are batched n_envs at a time, so the last batch may be
    # partial — hence the ceiling division.
    expected_batches = math.ceil(n_rollouts / n_envs)
    assert mock_bstep.return_value.run_episode_batch.call_count == \
        expected_batches
def test_greedy_decision_for_all_aggregators(mock_env, mock_bstep_class, aggregate_fn, expected_action):
    """Each aggregator should make the agent pick its expected action."""
    # Set up
    agent = agents.ShootingAgent(
        action_space=mock_env.action_space,
        aggregate_fn=aggregate_fn,
        batch_stepper_class=mock_bstep_class,
        n_rollouts=1,
    )

    # Run
    observation = mock_env.reset()
    testing.run_without_suspensions(agent.reset(mock_env, observation))
    (chosen_action, _) = testing.run_without_suspensions(agent.act(None))

    # Test
    assert chosen_action == expected_action
def test_rollout_time_limit(mock_env, rollout_time_limit):
    """Rollouts should be truncated at rollout_time_limit.

    When no limit is given, the rollout runs until the episode ends.
    """
    # NOTE(review): another test with this exact name appears earlier in this
    # file — pytest only collects the last definition; consider renaming one.
    # Set up
    rollout_max_len = 10  # It must be greater than rollout_time_limit!
    mock_env.action_space.sample.return_value = 0
    # Episode terminates (done=True) only on the rollout_max_len-th step.
    mock_env.step.side_effect = \
        [('d', 0, False, {})] * (rollout_max_len - 1) + [('d', 0, True, {})]
    mock_env.clone_state.return_value = 's'
    mock_env.restore_state.return_value = 'o'

    if rollout_time_limit is None:
        x_rollout_time_limit = rollout_max_len
    else:
        x_rollout_time_limit = rollout_time_limit

    # Native coroutine instead of the deprecated @asyncio.coroutine
    # decorator (deprecated since Python 3.8, removed in 3.11); the returned
    # object is awaitable either way.
    async def _estimate_fn(episodes, discount):
        del discount
        # Test: every rollout's length equals the expected limit.
        for episode in episodes:
            actual_rollout_time_limit = len(episode.transition_batch.done)
            assert actual_rollout_time_limit == x_rollout_time_limit
        return [1.] * len(episodes)

    with mock.patch('alpacka.agents.mc_simulation.type') as mock_type:
        # Make the agent's internally-constructed env be our mock.
        mock_type.return_value = lambda: mock_env
        agent = agents.ShootingAgent(
            n_rollouts=1,
            rollout_time_limit=rollout_time_limit,
            estimate_fn=_estimate_fn,
            n_envs=1,
        )

        # Run
        observation = mock_env.reset()
        testing.run_with_dummy_network_response(
            agent.reset(mock_env, observation)
        )
        testing.run_without_suspensions(agent.act(None))