def test_env_runner_log_episode_metrics(mock_data_logger, mock_task, mock_agent):
    """Check that every supplied per-episode metric reaches the data logger.

    For each episode number, ``log_value`` must have been called with the
    metric's tag, that episode's value, and the episode number itself.
    """
    # Assign
    episodes = [1, 2]
    epsilons = [0.2, 0.1]
    mean_scores = [0.5, 1]
    scores = [1.5, 5]
    iterations = [10, 10]
    env_runner = EnvRunner(mock_task, mock_agent, data_logger=mock_data_logger)

    # Act
    env_runner.log_episode_metrics(
        episodes=episodes,
        epsilons=epsilons,
        mean_scores=mean_scores,
        iterations=iterations,
        scores=scores,
    )

    # Assert: walk the parallel metric lists in lockstep, one episode at a time.
    log_value = mock_data_logger.log_value
    per_episode = zip(episodes, epsilons, mean_scores, scores, iterations)
    for episode, epsilon, mean_score, score, n_iterations in per_episode:
        log_value.assert_any_call("episode/epsilon", epsilon, episode)
        log_value.assert_any_call("episode/avg_score", mean_score, episode)
        log_value.assert_any_call("episode/score", score, episode)
        log_value.assert_any_call("episode/iterations", n_iterations, episode)
def test_env_runner_log_episode_metrics_values_missing(mock_data_logger, mock_task, mock_agent):
    """Check that nothing is logged when only episode numbers are supplied.

    With no metric lists passed alongside ``episodes``, the data logger's
    ``log_value`` must never be invoked.
    """
    # Assign
    env_runner = EnvRunner(mock_task, mock_agent, data_logger=mock_data_logger)

    # Act
    env_runner.log_episode_metrics(episodes=[1, 2])

    # Assert
    mock_data_logger.log_value.assert_not_called()