예제 #1
0
def test_env_runner_log_data_interaction_no_data_logger(mock_task, mock_agent):
    """Check that the agent's ``log_metrics`` is not called without a data logger.

    The runner is built from the task and agent only (no ``data_logger``
    argument), then ``log_data_interaction`` is invoked once.
    """
    # Assign: construct the runner without passing any data logger
    runner = EnvRunner(mock_task, mock_agent)

    # Act
    runner.log_data_interaction()

    # Assert: the agent was never asked to log metrics
    log_metrics = mock_agent.log_metrics
    log_metrics.assert_not_called()
예제 #2
0
def test_env_runner_log_data_interaction(mock_data_logger, mock_task,
                                         mock_agent):
    """Check that ``log_data_interaction`` forwards to the agent's ``log_metrics``.

    With a data logger supplied and no episode run beforehand, the agent is
    expected to receive exactly one ``log_metrics`` call with the data logger,
    ``0`` as the second positional argument, and ``full_log=False``.
    """
    # Assign
    runner_options = {"data_logger": mock_data_logger}
    runner = EnvRunner(mock_task, mock_agent, **runner_options)

    # Act
    runner.log_data_interaction()

    # Assert
    expected_args = (mock_data_logger, 0)
    expected_kwargs = {"full_log": False}
    mock_agent.log_metrics.assert_called_once_with(*expected_args,
                                                   **expected_kwargs)
예제 #3
0
def test_env_runner_log_data_interaction_debug_log(mock_data_logger, mock_task,
                                                   mock_agent):
    """Check what gets logged after one episode when ``debug_log`` is enabled.

    An episode capped at 10 iterations is run with a task whose ``step`` never
    reports done, followed by ``log_data_interaction``. The agent must receive
    a single ``log_metrics`` call (second positional argument ``10``), and the
    data logger must see 20 ``log_values_dict`` calls and 20 ``log_value``
    calls.
    """
    # Assign
    iterations = 10
    # Two logged series per iteration for each logger method
    # (states + actions via log_values_dict, rewards + dones via log_value).
    expected_log_calls = 2 * iterations

    mock_agent.act.return_value = 1
    mock_task.step.return_value = ([1, 0.1], -1, False, {})
    runner = EnvRunner(
        mock_task,
        mock_agent,
        data_logger=mock_data_logger,
        debug_log=True,
    )

    # Act
    runner.interact_episode(
        eps=0.1,
        max_iterations=iterations,
        log_interaction_freq=None,
    )
    runner.log_data_interaction()

    # Assert
    mock_agent.log_metrics.assert_called_once_with(
        mock_data_logger,
        iterations,
        full_log=False,
    )
    values_dict_calls = mock_data_logger.log_values_dict.call_count
    value_calls = mock_data_logger.log_value.call_count
    assert values_dict_calls == expected_log_calls
    assert value_calls == expected_log_calls