def test_env_runner_log_data_interaction_no_data_logger(mock_task, mock_agent):
    """Without a data logger, ``log_data_interaction`` must not ask the agent to log metrics."""
    # Arrange: no ``data_logger`` argument is passed to the runner.
    runner = EnvRunner(mock_task, mock_agent)

    # Act
    runner.log_data_interaction()

    # Assert
    mock_agent.log_metrics.assert_not_called()
def test_env_runner_log_data_interaction(mock_data_logger, mock_task, mock_agent):
    """With a data logger, ``log_data_interaction`` forwards it to ``agent.log_metrics``.

    No interaction has happened yet, so the expected iteration argument is 0 and
    ``full_log`` is expected to be ``False``.
    """
    # Arrange
    runner = EnvRunner(mock_task, mock_agent, data_logger=mock_data_logger)

    # Act
    runner.log_data_interaction()

    # Assert
    mock_agent.log_metrics.assert_called_once_with(mock_data_logger, 0, full_log=False)
def test_env_runner_log_data_interaction_debug_log(mock_data_logger, mock_task, mock_agent):
    """In ``debug_log`` mode, each interaction step writes extra values to the data logger.

    The task never reports ``done``, so the episode runs for the full
    ``max_iterations``. Afterwards, ``log_metrics`` is expected to receive that
    iteration count, and each data-logger method is expected to be called twice
    per iteration.
    """
    # Arrange
    n_iterations = 10
    step_result = ([1, 0.1], -1, False, {})  # third element False: episode never terminates early
    mock_task.step.return_value = step_result
    mock_agent.act.return_value = 1
    runner = EnvRunner(mock_task, mock_agent, data_logger=mock_data_logger, debug_log=True)

    # Act
    runner.interact_episode(eps=0.1, max_iterations=n_iterations, log_interaction_freq=None)
    runner.log_data_interaction()

    # Assert
    mock_agent.log_metrics.assert_called_once_with(mock_data_logger, n_iterations, full_log=False)
    assert mock_data_logger.log_values_dict.call_count == 2 * n_iterations  # one per iter for states and one for actions
    assert mock_data_logger.log_value.call_count == 2 * n_iterations  # one per iter for rewards and one for dones