def _test_log_batch_helper_success(metric_entities, param_entities, tag_entities,
                                   expected_metrics=None, expected_params=None,
                                   expected_tags=None):
    """
    Simulates a LogBatch API request using the provided metrics/params/tags,
    asserting that it succeeds & that the backing store contains either the set
    of expected metrics/params/tags (if provided) or, by default, the
    metrics/params/tags used in the API request.

    :param metric_entities: Metric entities to serialize into the request.
    :param param_entities: Param entities to serialize into the request.
    :param tag_entities: Tag entities to serialize into the request.
    :param expected_metrics: If not None, metrics expected in the store instead
        of ``metric_entities``. An empty list is honored as "expect none".
    :param expected_params: If not None, params expected in the store.
    :param expected_tags: If not None, tags expected in the store.
    """
    with mlflow.start_run() as active_run:
        run_id = active_run.info.run_uuid
        # NOTE(review): `mock_get_request_message` is not a parameter of this
        # helper — presumably a module-level mock patched elsewhere in this
        # test file; verify against the surrounding fixtures.
        mock_get_request_message.return_value = LogBatch(
            run_id=run_id,
            metrics=[m.to_proto() for m in metric_entities],
            params=[p.to_proto() for p in param_entities],
            tags=[t.to_proto() for t in tag_entities])
        response = _log_batch()
        # A successful LogBatch call returns HTTP 200 with an empty JSON body.
        assert response.status_code == 200
        json_response = json.loads(response.get_data())
        assert json_response == {}
        # Fix: use explicit `is not None` checks rather than the previous
        # `expected or entities` idiom, which treated an explicitly-provided
        # empty expected list as "not provided" (empty list is falsy) and
        # silently fell back to the request entities.
        _assert_logged_entities(
            run_id,
            metric_entities if expected_metrics is None else expected_metrics,
            param_entities if expected_params is None else expected_params,
            tag_entities if expected_tags is None else expected_tags)
def test_log_batch_api_req(mock_get_request_json):
    """
    Verify that a LogBatch request whose JSON payload exceeds
    MAX_BATCH_LOG_REQUEST_SIZE is rejected with HTTP 400, an
    INVALID_PARAMETER_VALUE error code, and a message naming the size limit.
    """
    # One byte over the limit is enough to trip the size check.
    oversized_payload = "a" * (MAX_BATCH_LOG_REQUEST_SIZE + 1)
    mock_get_request_json.return_value = oversized_payload
    response = _log_batch()
    assert response.status_code == 400
    body = json.loads(response.get_data())
    assert body["error_code"] == ErrorCode.Name(INVALID_PARAMETER_VALUE)
    expected_fragment = ("Batched logging API requests must be at most %s bytes"
                         % MAX_BATCH_LOG_REQUEST_SIZE)
    assert expected_fragment in body["message"]