def test_multiagent_cycle_env_runner_interact_episode_render_gif():
    """With render_gif=True and max_iterations=10, the runner keeps 10 frames.

    ``test_task.render`` is replaced by a mock that always returns the same
    small 3x3 frame, so the test can count both stored images and render calls.
    """
    fake_frame = [[0, 0, 1], [0, 1, 0], [1, 1, 0]]
    test_task.render = mock.MagicMock(return_value=fake_frame)
    runner = MultiAgentCycleEnvRunner(test_task, test_agent, max_iterations=10)

    runner.interact_episode(render_gif=True)

    # One stored image and one render() call for each of the 10 iterations.
    assert len(runner._images) == 10
    assert test_task.render.call_count == 10
def test_multiagent_cycle_env_runner_interact_episode_debug_log():
    """With debug_log=True, each recorded history entry spans all 10 iterations.

    ``_actions``, ``_dones`` and ``_rewards`` are dict-like (they expose
    ``.values()``). ``all()`` over an empty mapping is vacuously true, so each
    history is also asserted to be non-empty; otherwise the test would pass
    even if debug logging recorded nothing at all.
    """
    # Assign
    multi_sync_env_runner = MultiAgentCycleEnvRunner(test_task, test_agent, max_iterations=10, debug_log=True)

    # Act
    multi_sync_env_runner.interact_episode()

    # Assert
    actions = multi_sync_env_runner._actions
    dones = multi_sync_env_runner._dones
    rewards = multi_sync_env_runner._rewards

    assert actions, "debug_log=True should record actions"
    assert all(len(agent_actions) == 10 for agent_actions in actions.values())

    assert dones, "debug_log=True should record dones"
    assert all(len(agent_dones) == 10 for agent_dones in dones.values())

    assert rewards, "debug_log=True should record rewards"
    assert all(len(agent_rewards) == 10 for agent_rewards in rewards.values())
def test_multiagent_cycle_env_runner_interact_episode_log_interaction_without_data_logger():
    """With log_interaction_freq=1 and no data logger, logging is attempted every iteration.

    ``log_data_interaction`` is replaced by a mock here. The zero count on
    ``test_agent.log_metrics`` therefore only shows that ``interact_episode``
    does not call ``log_metrics`` directly.
    """
    test_agent.log_metrics = mock.MagicMock()
    runner = MultiAgentCycleEnvRunner(test_task, test_agent, max_iterations=10)
    logging_spy = mock.MagicMock()
    runner.log_data_interaction = logging_spy

    runner.interact_episode(log_interaction_freq=1)

    # Frequency 1 over 10 iterations -> 10 logging calls.
    assert logging_spy.call_count == 10
    assert test_agent.log_metrics.call_count == 0
def test_multiagent_cycle_env_runner_log_data_interaction_debug_log(mock_data_logger):
    """With debug_log=True, one log_data_interaction() call flushes the recorded history.

    The episode runs with ``log_interaction_freq=None``. Because
    ``log_metrics`` is asserted to be called exactly once, every logger call
    checked below comes from the single explicit ``log_data_interaction()``.
    """
    # Assign
    test_agent.log_metrics = mock.MagicMock()
    env_runner = MultiAgentCycleEnvRunner(test_task, test_agent, data_logger=mock_data_logger, debug_log=True)

    # Act
    env_runner.interact_episode(eps=0.1, max_iterations=10, log_interaction_freq=None)
    env_runner.log_data_interaction()

    # Assert
    test_agent.log_metrics.assert_called_once_with(mock_data_logger, 10, full_log=False)
    assert mock_data_logger.log_values_dict.call_count == 90  # 3 agents x (A + R + D) x 10 interactions
    assert mock_data_logger.log_value.call_count == 0  # debug history goes only through log_values_dict
def test_multiagent_cycle_env_runner_interact_episode_log_interaction(mock_data_logger):
    """With a data logger and log_interaction_freq=1, agent metrics are logged every iteration.

    ``debug_log`` is not enabled, so no per-agent history should reach the
    data logger through ``log_value`` or ``log_values_dict``.
    """
    # Assign
    test_agent.log_metrics = mock.MagicMock()
    multi_sync_env_runner = MultiAgentCycleEnvRunner(test_task, test_agent, data_logger=mock_data_logger, max_iterations=10)

    # Act
    multi_sync_env_runner.interact_episode(log_interaction_freq=1)

    # Assert
    assert test_agent.log_metrics.call_count == 10
    test_agent.log_metrics.assert_called_with(mock_data_logger, 10, full_log=False)  # Last
    mock_data_logger.log_value.assert_not_called()
    # Must be the real method name ``log_values_dict``. On an unspecced mock a
    # misspelled attribute (e.g. ``log_value_dict``) is auto-created and is
    # trivially "not called", which makes the assertion vacuous.
    mock_data_logger.log_values_dict.assert_not_called()
def test_multiagent_cycle_env_runner_interact_episode_override_max_iteractions():
    """A max_iterations passed to interact_episode wins over the constructor's value."""
    test_task.render = mock.MagicMock()
    runner = MultiAgentCycleEnvRunner(test_task, test_agent, max_iterations=10)

    _episode_rewards, iteration_count = runner.interact_episode(max_iterations=20)

    # Constructor said 10, the call said 20; the call-level value is used.
    assert iteration_count == 20
def test_multiagent_cycle_env_runner_run():
    """run(max_episodes=5) calls interact_episode 5 times and returns one entry per episode.

    ``interact_episode`` is stubbed to return a reward of 1 for every agent
    and 10 iterations, so each returned entry is sized to the agent count.
    """
    stub_rewards = dict.fromkeys(test_agent.agents, 1)
    runner = MultiAgentCycleEnvRunner(test_task, test_agent)
    runner.interact_episode = mock.MagicMock(return_value=(stub_rewards, 10))

    episodes = runner.run(max_episodes=5)

    assert runner.interact_episode.call_count == 5
    assert len(episodes) == 5
    first_episode = episodes[0]
    assert len(first_episode) == test_agent.num_agents
def test_multiagent_cycle_env_runner_interact_episode():
    """A default interact_episode returns (rewards, iterations) and fills no debug buffers.

    Neither ``render_gif`` nor ``debug_log`` is requested, so the image,
    action, reward and done buffers stay empty and ``render`` is never called.
    """
    test_task.render = mock.MagicMock()
    runner = MultiAgentCycleEnvRunner(test_task, test_agent, max_iterations=10)

    result = runner.interact_episode()

    assert len(result) == 2  # (rewards, iterations)
    episode_rewards, iteration_count = result
    assert isinstance(episode_rewards, dict)
    assert len(episode_rewards) == test_agent.num_agents
    assert iteration_count > 1

    for buffer in (runner._images, runner._actions, runner._rewards, runner._dones):
        assert len(buffer) == 0

    test_task.render.assert_not_called()