Exemplo n.º 1
0
def test_rainbow_set_loss():
    """Check that assigning ``loss`` replaces the agent's default loss value.

    The ``loss`` attribute is read back as a dict keyed by ``'loss'``; a
    freshly constructed agent is expected to report infinity there.
    """
    # Assign
    expected_loss = 1
    default_loss = {"loss": float("inf")}
    rainbow = RainbowAgent(1, 1, device="cpu")
    assert rainbow.loss == default_loss  # Sanity check of the initial state

    # Act
    rainbow.loss = expected_loss

    # Assert
    assert rainbow.loss == {"loss": expected_loss}
Exemplo n.º 2
0
def test_rainbow_log_metrics(mock_data_logger):
    """Check that a default ``log_metrics`` call only logs the agent's loss.

    Without ``full_log`` the data logger should receive exactly one
    ``log_value`` call and neither dict values nor histograms.
    """
    # Assign
    rainbow = RainbowAgent(1, 1, device="cpu")
    rainbow.loss = 1
    current_step = 10

    # Act
    rainbow.log_metrics(mock_data_logger, current_step)

    # Assert
    logger = mock_data_logger
    logger.log_value.assert_called_once_with("loss/agent", rainbow._loss, current_step)
    logger.log_value_dict.assert_not_called()
    logger.create_histogram.assert_not_called()
Exemplo n.º 3
0
def test_rainbow_log_metrics_full_log(mock_data_logger):
    """Check ``log_metrics(full_log=True)`` on an agent that has not acted yet.

    No action distribution exists before the first ``act`` call, so only the
    loss value and the per-layer histograms are expected to be logged.
    """
    # Assign
    # A single hidden layer gives two layers in total: (I, H) -> (H, O)
    rainbow = RainbowAgent(1, 1, device="cpu", hidden_layers=(10,))
    rainbow.loss = 1
    current_step = 10
    histograms_per_layer = 4
    layer_count = 2

    # Act
    rainbow.log_metrics(mock_data_logger, current_step, full_log=True)

    # Assert
    assert rainbow.dist_probs is None
    logger = mock_data_logger
    logger.log_value.assert_called_once_with("loss/agent", rainbow._loss, current_step)
    logger.log_value_dict.assert_not_called()
    logger.add_histogram.assert_not_called()
    assert logger.create_histogram.call_count == histograms_per_layer * layer_count
Exemplo n.º 4
0
def test_rainbow_log_metrics_full_log_dist_prob(mock_data_logger):
    """Check ``log_metrics(full_log=True)`` after the agent has acted.

    Acting on a state means that there's a probability distribution created
    for each action, so the full log is expected to include the distribution
    data in addition to the loss and the per-layer histograms.
    """
    # Assign
    # A single hidden layer gives two layers in total: (I, H) -> (H, O)
    agent = RainbowAgent(1, 1, device='cpu', hidden_layers=(10,))
    step = 10
    agent.loss = 1

    # Act
    agent.act([0])
    agent.log_metrics(mock_data_logger, step, full_log=True)

    # Assert
    assert agent.dist_probs is not None
    # This comparison used to be a bare expression (no `assert`), so it was
    # evaluated and discarded and could never fail the test.
    assert mock_data_logger.log_value.call_count == 2  # 1x loss + 1x dist_prob
    mock_data_logger.log_value_dict.assert_not_called()
    mock_data_logger.add_histogram.assert_called_once()
    assert mock_data_logger.create_histogram.call_count == 4 * 2  # 4x per layer