def test_rainbow_set_loss():
    """Assigning a scalar to ``loss`` is reflected in the ``loss`` dict."""
    # Assign
    rainbow = RainbowAgent(1, 1, device='cpu')
    assigned_loss = 1
    # A freshly built agent reports an infinite loss until one is set.
    default_loss = {'loss': float('inf')}
    assert rainbow.loss == default_loss

    # Act
    rainbow.loss = assigned_loss

    # Assert
    expected = {'loss': assigned_loss}
    assert rainbow.loss == expected
def test_rainbow_log_metrics(mock_data_logger):
    """Without ``full_log`` only the agent loss is logged; no dicts or histograms."""
    # Assign
    rainbow = RainbowAgent(1, 1, device='cpu')
    global_step = 10
    rainbow.loss = 1

    # Act
    rainbow.log_metrics(mock_data_logger, global_step)

    # Assert
    # The scalar is read from the private ``_loss`` because the public
    # ``loss`` property returns a dict wrapper.
    mock_data_logger.log_value.assert_called_once_with("loss/agent", rainbow._loss, global_step)
    for untouched in (mock_data_logger.log_value_dict, mock_data_logger.create_histogram):
        untouched.assert_not_called()
def test_rainbow_log_metrics_full_log(mock_data_logger):
    """With ``full_log`` but no prior ``act``, layer histograms are created and no
    distribution histogram is added."""
    # Assign
    # Only 2 layers: (I, H) -> (H, O)
    rainbow = RainbowAgent(1, 1, device='cpu', hidden_layers=(10,))
    global_step = 10
    rainbow.loss = 1

    # Act
    rainbow.log_metrics(mock_data_logger, global_step, full_log=True)

    # Assert
    # The agent never acted, so no action distribution has been computed yet.
    assert rainbow.dist_probs is None
    mock_data_logger.log_value.assert_called_once_with("loss/agent", rainbow._loss, global_step)
    mock_data_logger.log_value_dict.assert_not_called()
    mock_data_logger.add_histogram.assert_not_called()
    histograms_per_layer, layer_count = 4, 2
    assert mock_data_logger.create_histogram.call_count == histograms_per_layer * layer_count
def test_rainbow_log_metrics_full_log_dist_prob(mock_data_logger):
    """Acting on a state means that there's a prob dist created for each action."""
    # Assign
    agent = RainbowAgent(1, 1, device='cpu', hidden_layers=(10,))  # Only 2 layers: (I, H) -> (H, O)
    step = 10
    agent.loss = 1

    # Act
    agent.act([0])
    agent.log_metrics(mock_data_logger, step, full_log=True)

    # Assert
    assert agent.dist_probs is not None
    # Previously this comparison lacked `assert`, so it was evaluated and discarded.
    assert mock_data_logger.log_value.call_count == 2  # 1x loss + 1x dist_prob
    mock_data_logger.log_value_dict.assert_not_called()
    mock_data_logger.add_histogram.assert_called_once()
    assert mock_data_logger.create_histogram.call_count == 4 * 2  # 4x per layer