def test_predict_json(model):
    """POST a JSON ``instances`` payload to ``model`` and check the predictions.

    ``model`` is presumably a pytest fixture/parameter naming the served model
    (TODO confirm against the conftest). The expected outputs equal
    0.5 * x + 3 for each input x.
    """
    # NOTE(review): make_headers() with no arguments presumably defaults to a
    # JSON content type — confirm against its definition.
    request_headers = make_headers()
    payload = '{"instances": [1.0, 2.0, 5.0]}'
    url = INVOCATION_URL.format(model)
    reply = requests.post(url, data=payload, headers=request_headers)
    expected = {"predictions": [3.5, 4.0, 5.5]}
    assert reply.json() == expected
def test_csv_input():
    """POST one CSV row to ``MODEL_NAME`` and verify the JSON predictions returned."""
    csv_headers = make_headers(content_type="text/csv")
    body = "1.0,2.0,5.0"
    endpoint = INVOCATION_URL.format(MODEL_NAME)
    result = requests.post(endpoint, data=body, headers=csv_headers).json()
    expected = {"predictions": [3.5, 4.0, 5.5]}
    assert result == expected
def test_specific_versions():
    """Check that each listed model version returns the same CSV predictions.

    ``make_headers(version=...)`` presumably selects the model version via a
    request header (TODO confirm). Every assertion message names the version
    under test, so a failure inside the loop identifies which one broke.
    """
    expected = {"predictions": [3.5, 4.0, 5.5]}
    for version in ("123", "124"):
        headers = make_headers(content_type="text/csv", version=version)
        data = "1.0,2.0,5.0"
        response = requests.post(
            INVOCATION_URL.format(MODEL_NAME), data=data, headers=headers
        )
        # Check the status before decoding: an error reply would otherwise show
        # up as an opaque JSON decode error without the HTTP status or body.
        assert response.status_code == 200, (
            f"version {version}: HTTP {response.status_code}: {response.text}"
        )
        assert response.json() == expected, (
            f"version {version}: unexpected response {response.text}"
        )
def test_unsupported_content_type():
    """An unrecognised content type yields HTTP 500 and an explanatory message."""
    # NOTE(review): the second positional argument ("predict") to make_headers
    # is presumably the invocation method — confirm against its definition.
    bad_headers = make_headers("unsupported-type", "predict")
    encoded_body = "aW1hZ2UgYnl0ZXM="  # base64 encoding of the text "image bytes"
    resp = requests.post(
        INVOCATION_URL.format(MODEL_NAME), data=encoded_body, headers=bad_headers
    )
    assert resp.status_code == 500
    assert "unsupported content type" in resp.text
def test_zero_content():
    """An empty request body yields HTTP 500 with a "document is empty" error."""
    default_headers = make_headers()
    empty_body = ""
    reply = requests.post(
        INVOCATION_URL.format(MODEL_NAME), data=empty_body, headers=default_headers
    )
    assert reply.status_code == 500
    assert "document is empty" in reply.text
def test_large_input():
    """POST a large CSV file and check the number of predictions returned.

    753936 is presumably the number of input values in the fixture file
    (TODO confirm against ``test-large.csv``).
    """
    data_file = "test/resources/inputs/test-large.csv"
    # An explicit encoding keeps the read independent of the platform locale.
    with open(data_file, "r", encoding="utf-8") as csv_file:
        payload = csv_file.read()
    headers = make_headers(content_type="text/csv")
    response = requests.post(
        INVOCATION_URL.format(MODEL_NAME), data=payload, headers=headers
    )
    # Check the status first. If the server rejects the large payload, calling
    # .json() directly would fail with an opaque decode error instead of
    # reporting the HTTP status and (truncated) error body.
    assert response.status_code == 200, (
        f"HTTP {response.status_code}: {response.text[:500]}"
    )
    predictions = response.json()["predictions"]
    assert len(predictions) == 753936