def test_multiagent_cycle_env_runner_interact_episode_render_gif():
    """Requesting a GIF makes the runner capture a rendered frame each iteration.

    ``test_task.render`` is swapped for a mock that returns a fixed 3x3 frame;
    over a 10-iteration episode the runner is expected to store 10 images and
    to call ``render`` 10 times.
    """
    # Assign
    fake_frame = [[0, 0, 1], [0, 1, 0], [1, 1, 0]]
    test_task.render = mock.MagicMock(return_value=fake_frame)
    runner = MultiAgentCycleEnvRunner(test_task, test_agent, max_iterations=10)

    # Act
    runner.interact_episode(render_gif=True)

    # Assert
    captured_images = runner._images
    assert len(captured_images) == 10
    assert test_task.render.call_count == 10
def test_multiagent_cycle_env_runner_interact_episode_debug_log():
    """With ``debug_log=True`` the runner keeps per-key action/done/reward histories.

    Each of ``_actions``, ``_dones`` and ``_rewards`` is a mapping (presumably
    keyed by agent) whose values must each hold one entry per iteration (10).
    """
    # Assign
    multi_sync_env_runner = MultiAgentCycleEnvRunner(test_task, test_agent, max_iterations=10, debug_log=True)

    # Act
    multi_sync_env_runner.interact_episode()

    # Assert
    histories = (
        multi_sync_env_runner._actions,
        multi_sync_env_runner._dones,
        multi_sync_env_runner._rewards,
    )
    for history in histories:
        # all() over an empty mapping is True, so without this guard the test
        # would pass even if debug logging recorded nothing at all.
        assert history
        assert all(len(entries) == 10 for entries in history.values())
def test_multiagent_cycle_env_runner_interact_episode_log_interaction_without_data_logger():
    """With no data logger, ``interact_episode`` still triggers interaction logging.

    ``log_data_interaction`` is mocked out, so this checks only two things:
    it is called once per iteration when ``log_interaction_freq=1``, and
    ``interact_episode`` itself never calls ``agent.log_metrics``.
    """
    # Assign
    test_agent.log_metrics = mock.MagicMock()
    runner = MultiAgentCycleEnvRunner(test_task, test_agent, max_iterations=10)
    runner.log_data_interaction = mock.MagicMock()

    # Act
    runner.interact_episode(log_interaction_freq=1)

    # Assert
    interaction_logs = runner.log_data_interaction.call_count
    metric_logs = test_agent.log_metrics.call_count
    assert interaction_logs == 10
    assert metric_logs == 0
def test_multiagent_cycle_env_runner_log_data_interaction_debug_log(mock_data_logger):
    """An explicit ``log_data_interaction`` flushes debug-log histories to the logger.

    The episode runs 10 iterations with ``debug_log=True`` and periodic logging
    disabled (``log_interaction_freq=None``). One manual ``log_data_interaction``
    call should then:

    - log agent metrics exactly once;
    - emit every recorded action/reward/done through ``log_values_dict``;
    - make no ``log_value`` calls.
    """
    # Assign
    test_agent.log_metrics = mock.MagicMock()
    env_runner = MultiAgentCycleEnvRunner(test_task, test_agent, data_logger=mock_data_logger, debug_log=True)

    # Act
    env_runner.interact_episode(eps=0.1, max_iterations=10, log_interaction_freq=None)
    env_runner.log_data_interaction()

    # Assert
    test_agent.log_metrics.assert_called_once_with(mock_data_logger, 10, full_log=False)
    # 3 agents x (actions + rewards + dones) x 10 interactions
    assert mock_data_logger.log_values_dict.call_count == 90
    # Every per-interaction value goes through log_values_dict; log_value is unused.
    mock_data_logger.log_value.assert_not_called()
def test_multiagent_cycle_env_runner_interact_episode_log_interaction(mock_data_logger):
    """With a data logger and ``log_interaction_freq=1``, metrics are logged every iteration.

    ``debug_log`` is off, so no per-interaction values should reach the
    logger through either ``log_value`` or ``log_values_dict``.
    """
    # Assign
    test_agent.log_metrics = mock.MagicMock()
    multi_sync_env_runner = MultiAgentCycleEnvRunner(test_task, test_agent, data_logger=mock_data_logger, max_iterations=10)

    # Act
    multi_sync_env_runner.interact_episode(log_interaction_freq=1)

    # Assert
    assert test_agent.log_metrics.call_count == 10
    # assert_called_with inspects only the most recent (last) call.
    test_agent.log_metrics.assert_called_with(mock_data_logger, 10, full_log=False)
    mock_data_logger.log_value.assert_not_called()
    # The logger method is ``log_values_dict``. A MagicMock silently creates a
    # misspelled attribute such as ``log_value_dict``, and assert_not_called() on
    # that attribute can never fail.
    mock_data_logger.log_values_dict.assert_not_called()
def test_multiagent_cycle_env_runner_interact_episode_override_max_iteractions():
    """``max_iterations`` passed to ``interact_episode`` wins over the constructor value."""
    # Assign
    test_task.render = mock.MagicMock()
    runner = MultiAgentCycleEnvRunner(test_task, test_agent, max_iterations=10)

    # Act
    _episode_rewards, iteration_count = runner.interact_episode(max_iterations=20)

    # Assert
    assert iteration_count == 20
def test_multiagent_cycle_env_runner_run():
    """``run`` calls ``interact_episode`` once per episode and collects each result.

    ``interact_episode`` is mocked to return a reward of 1 for every agent name,
    plus 10 iterations.
    """
    # Assign
    per_agent_reward = dict.fromkeys(test_agent.agents, 1)
    runner = MultiAgentCycleEnvRunner(test_task, test_agent)
    runner.interact_episode = mock.MagicMock(return_value=(per_agent_reward, 10))

    # Act
    episodes = runner.run(max_episodes=5)

    # Assert
    assert runner.interact_episode.call_count == 5
    assert len(episodes) == 5
    first_episode = episodes[0]
    assert len(first_episode) == test_agent.num_agents
def test_multiagent_cycle_env_runner_interact_episode():
    """A plain episode returns ``(rewards, iterations)`` and buffers nothing.

    Neither ``render_gif`` nor ``debug_log`` is requested, so the image and
    debug buffers must stay empty and the task must never be rendered.
    """
    # Assign
    test_task.render = mock.MagicMock()
    runner = MultiAgentCycleEnvRunner(test_task, test_agent, max_iterations=10)

    # Act
    output = runner.interact_episode()

    # Assert
    assert len(output) == 2  # (rewards, iterations)
    rewards, iterations = output
    assert isinstance(rewards, dict)
    assert len(rewards) == test_agent.num_agents
    assert iterations > 1

    untouched_buffers = (runner._images, runner._actions, runner._rewards, runner._dones)
    for buffer in untouched_buffers:
        assert len(buffer) == 0
    test_task.render.assert_not_called()