def test_additional_request_post_headers(monkeypatch):
    """
    Verify that `TrinoRequest.post` accepts additional headers and merges
    them with the request's existing headers before sending the request.

    NOTE(review): a function with this same name also appears in this view;
    if both live in one module, only the later definition is collected by
    pytest — confirm which one is intended to run.
    """
    recorder = ArgumentsRecorder()
    monkeypatch.setattr(TrinoRequest.http.Session, "post", recorder)

    request = TrinoRequest(
        host="coordinator",
        port=8080,
        user="******",
        source="test",
        catalog="test",
        schema="test",
        http_scheme="http",
        session_properties={},
    )

    extra_headers = {
        'X-Trino-Fake-1': 'one',
        'X-Trino-Fake-2': 'two',
    }
    # Build the expectation from the request's own headers plus the extras.
    expected_headers = request.http_headers
    expected_headers.update(extra_headers)

    request.post('select 1', extra_headers)

    # The recorded POST must carry both the default and the additional headers.
    assert recorder.kwargs['headers'] == expected_headers
def test_request_timeout():
    """
    Verify that `TrinoRequest` raises `requests.exceptions.Timeout` for both
    GET and POST when the server answers slower than `request_timeout`.

    Both a scalar timeout and a (connect, read) tuple timeout are exercised,
    with `max_attempts=1` so no retry masks the timeout.
    """
    timeout = 0.1
    http_scheme = "http"
    host = "coordinator"
    port = 8080
    url = f"{http_scheme}://{host}:{port}{constants.URL_STATEMENT_PATH}"

    def long_call(request, uri, headers):
        # Respond only after twice the configured timeout so every call times out.
        time.sleep(timeout * 2)
        return (200, headers, "delayed success")

    httpretty.enable()
    try:
        for method in [httpretty.POST, httpretty.GET]:
            httpretty.register_uri(method, url, body=long_call)

        # timeout without retry
        for request_timeout in [timeout, (timeout, timeout)]:
            req = TrinoRequest(
                host=host,
                port=port,
                user="******",
                http_scheme=http_scheme,
                max_attempts=1,
                request_timeout=request_timeout,
            )

            with pytest.raises(requests.exceptions.Timeout):
                req.get(url)

            with pytest.raises(requests.exceptions.Timeout):
                req.post("select 1")
    finally:
        # Always restore real sockets and drop registrations, even when an
        # assertion above fails, so later tests are not hit by the fake server.
        httpretty.disable()
        httpretty.reset()
def test_additional_request_post_headers(mock_get_and_post):
    """
    Verify that extra headers passed to `TrinoRequest.post` are merged into
    the request's own headers for the outgoing POST call.
    """
    _get_mock, post_mock = mock_get_and_post

    request = TrinoRequest(
        host="coordinator",
        port=8080,
        user="******",
        source="test",
        catalog="test",
        schema="test",
        http_scheme="http",
        session_properties={},
    )

    extra_headers = {
        'X-Trino-Fake-1': 'one',
        'X-Trino-Fake-2': 'two',
    }
    expected_headers = request.http_headers
    expected_headers.update(extra_headers)

    request.post('select 1', extra_headers)

    # The mocked session POST must have been called with the merged headers.
    _call_args, call_kwargs = post_mock.call_args
    assert call_kwargs['headers'] == expected_headers
def test_authentication_fail_retry(monkeypatch):
    """
    Verify that a failing Kerberos exchange is retried `max_attempts` times
    for both POST and GET before `KerberosExchangeError` propagates.
    """
    post_retry = RetryRecorder(error=KerberosExchangeError())
    monkeypatch.setattr(TrinoRequest.http.Session, "post", post_retry)

    get_retry = RetryRecorder(error=KerberosExchangeError())
    monkeypatch.setattr(TrinoRequest.http.Session, "get", get_retry)

    attempts = 3
    kerberos_auth = KerberosAuthentication()
    req = TrinoRequest(
        host="coordinator",
        port=8080,
        user="******",
        http_scheme=constants.HTTPS,
        auth=kerberos_auth,
        max_attempts=attempts,
    )

    with pytest.raises(KerberosExchangeError):
        req.post("URL")
    assert post_retry.retry_count == attempts

    with pytest.raises(KerberosExchangeError):
        req.get("URL")
    # Check the GET recorder here; re-checking post_retry would never
    # detect a missing retry on the GET path.
    assert get_retry.retry_count == attempts
def test_enabling_https_automatically_when_using_port_443(mock_get_and_post):
    """Verify that using the default TLS port without a scheme yields an HTTPS URL."""
    _get_mock, post_mock = mock_get_and_post

    request = TrinoRequest(
        host="coordinator",
        port=constants.DEFAULT_TLS_PORT,
        user="******",
    )
    request.post("SELECT 1")

    # The first positional argument of the POST call is the target URL.
    call_args, _call_kwargs = post_mock.call_args
    scheme = urlparse(call_args[0]).scheme
    assert scheme == constants.HTTPS
def test_extra_credential_value_encoding(mock_get_and_post):
    """Verify that extra-credential values containing spaces and non-ASCII text are URL-encoded."""
    _get_mock, post_mock = mock_get_and_post

    request = TrinoRequest(
        host="coordinator",
        port=constants.DEFAULT_TLS_PORT,
        user="******",
        extra_credential=[("foo", "bar 的")],
    )
    request.post("SELECT 1")

    _call_args, call_kwargs = post_mock.call_args
    sent_headers = call_kwargs["headers"]
    assert constants.HEADER_EXTRA_CREDENTIAL in sent_headers
    # Space -> '+', and the CJK character -> its percent-encoded UTF-8 bytes.
    assert sent_headers[constants.HEADER_EXTRA_CREDENTIAL] == "foo=bar+%E7%9A%84"
def test_extra_credential(mock_get_and_post):
    """Verify that multiple extra credentials are joined into one comma-separated header."""
    _get_mock, post_mock = mock_get_and_post

    credentials = [("a.username", "foo"), ("b.password", "bar")]
    request = TrinoRequest(
        host="coordinator",
        port=constants.DEFAULT_TLS_PORT,
        user="******",
        extra_credential=credentials,
    )
    request.post("SELECT 1")

    _call_args, call_kwargs = post_mock.call_args
    sent_headers = call_kwargs["headers"]
    assert constants.HEADER_EXTRA_CREDENTIAL in sent_headers
    expected = "a.username=foo, b.password=bar"
    assert sent_headers[constants.HEADER_EXTRA_CREDENTIAL] == expected
def test_http_scheme_with_port(mock_get_and_post):
    """Verify that an explicit HTTP scheme is honored even on the default TLS port."""
    _get_mock, post_mock = mock_get_and_post

    request = TrinoRequest(
        host="coordinator",
        port=constants.DEFAULT_TLS_PORT,
        user="******",
        http_scheme=constants.HTTP,
    )
    request.post("SELECT 1")

    call_args, _call_kwargs = post_mock.call_args
    target = urlparse(call_args[0])
    assert target.scheme == constants.HTTP
    assert target.port == constants.DEFAULT_TLS_PORT
def test_request_headers(monkeypatch):
    """
    Verify that catalog, schema, source, user, session and custom HTTP
    headers are sent on both POST and GET, and that the total header count
    is exactly 8.

    NOTE(review): a function with this same name also appears in this view;
    if both live in one module, only the later definition is collected by
    pytest — confirm which one is intended to run.
    """
    post_recorder = ArgumentsRecorder()
    monkeypatch.setattr(TrinoRequest.http.Session, "post", post_recorder)

    get_recorder = ArgumentsRecorder()
    monkeypatch.setattr(TrinoRequest.http.Session, "get", get_recorder)

    catalog = "test_catalog"
    schema = "test_schema"
    user = "******"
    source = "test_source"
    accept_encoding_header = "accept-encoding"
    accept_encoding_value = "identity,deflate,gzip"
    client_info_header = constants.HEADER_CLIENT_INFO
    client_info_value = "some_client_info"

    req = TrinoRequest(
        host="coordinator",
        port=8080,
        user=user,
        source=source,
        catalog=catalog,
        schema=schema,
        http_scheme="http",
        session_properties={},
        http_headers={
            accept_encoding_header: accept_encoding_value,
            client_info_header: client_info_value,
        },
        redirect_handler=None,
    )

    def assert_headers(headers):
        # Pairs (not a dict) so every expectation is checked, in order.
        expected_pairs = (
            (constants.HEADER_CATALOG, catalog),
            (constants.HEADER_SCHEMA, schema),
            (constants.HEADER_SOURCE, source),
            (constants.HEADER_USER, user),
            (constants.HEADER_SESSION, ""),
            (accept_encoding_header, accept_encoding_value),
            (client_info_header, client_info_value),
        )
        for name, value in expected_pairs:
            assert headers[name] == value
        assert len(headers) == 8

    req.post("URL")
    assert_headers(post_recorder.kwargs["headers"])

    req.get("URL")
    assert_headers(get_recorder.kwargs["headers"])
def test_request_headers(mock_get_and_post):
    """
    Verify that catalog, schema, source, user, session and custom HTTP
    headers are sent on both POST and GET, and that the total header count
    is exactly 8.
    """
    get_mock, post_mock = mock_get_and_post

    catalog = "test_catalog"
    schema = "test_schema"
    user = "******"
    source = "test_source"
    accept_encoding_header = "accept-encoding"
    accept_encoding_value = "identity,deflate,gzip"
    client_info_header = constants.HEADER_CLIENT_INFO
    client_info_value = "some_client_info"

    req = TrinoRequest(
        host="coordinator",
        port=8080,
        user=user,
        source=source,
        catalog=catalog,
        schema=schema,
        http_scheme="http",
        session_properties={},
        http_headers={
            accept_encoding_header: accept_encoding_value,
            client_info_header: client_info_value,
        },
        redirect_handler=None,
    )

    def assert_headers(headers):
        # Pairs (not a dict) so every expectation is checked, in order.
        expected_pairs = (
            (constants.HEADER_CATALOG, catalog),
            (constants.HEADER_SCHEMA, schema),
            (constants.HEADER_SOURCE, source),
            (constants.HEADER_USER, user),
            (constants.HEADER_SESSION, ""),
            (accept_encoding_header, accept_encoding_value),
            (client_info_header, client_info_value),
        )
        for name, value in expected_pairs:
            assert headers[name] == value
        assert len(headers) == 8

    req.post("URL")
    _post_args, post_kwargs = post_mock.call_args
    assert_headers(post_kwargs["headers"])

    req.get("URL")
    _get_args, get_kwargs = get_mock.call_args
    assert_headers(get_kwargs["headers"])
def test_oauth2_authentication_missing_headers(header, error):
    """
    Verify that a 401 response whose WWW-Authenticate header is `header`
    makes `TrinoRequest.post` raise `TrinoAuthError` with message `error`.
    """
    # Fake the statement endpoint so every POST is rejected with 401.
    httpretty.register_uri(
        method=httpretty.POST,
        uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
        adding_headers={'WWW-Authenticate': header},
        status=401,
    )

    oauth2_auth = trino.auth.OAuth2Authentication(
        redirect_auth_url_handler=RedirectHandler())
    request = TrinoRequest(
        host="coordinator",
        port=constants.DEFAULT_TLS_PORT,
        user="******",
        http_scheme=constants.HTTPS,
        auth=oauth2_auth,
    )

    with pytest.raises(trino.exceptions.TrinoAuthError) as exc_info:
        request.post("select 1")

    assert str(exc_info.value) == error
def test_oauth2_authentication_fail_token_server(http_status, sample_post_response_data):
    """
    Verify that a non-success status from the token server aborts the OAuth2
    flow with a `TrinoAuthError` that reports the status code and body.
    """
    token = str(uuid.uuid4())
    challenge_id = str(uuid.uuid4())

    redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}"
    token_server = f"{TOKEN_RESOURCE}/{challenge_id}"

    post_statement_callback = PostStatementCallback(
        redirect_server, token_server, [token], sample_post_response_data)

    # Statement endpoint issues the OAuth2 challenge.
    httpretty.register_uri(
        method=httpretty.POST,
        uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
        body=post_statement_callback,
    )

    # Token endpoint always fails with the parametrized status.
    httpretty.register_uri(
        method=httpretty.GET,
        uri=token_server,
        status=http_status,
        body="error",
    )

    redirect_handler = RedirectHandler()
    oauth2_auth = trino.auth.OAuth2Authentication(
        redirect_auth_url_handler=redirect_handler)
    request = TrinoRequest(
        host="coordinator",
        port=constants.DEFAULT_TLS_PORT,
        user="******",
        http_scheme=constants.HTTPS,
        auth=oauth2_auth,
    )

    with pytest.raises(trino.exceptions.TrinoAuthError) as exc_info:
        request.post("select 1")

    assert redirect_handler.redirect_server == redirect_server
    expected_message = (
        f"Error while getting the token response status code: {http_status}, body: error"
    )
    assert str(exc_info.value) == expected_message
    assert len(_post_statement_requests()) == 1
    assert len(_get_token_requests(challenge_id)) == 1
def test_503_error_retry(monkeypatch):
    """
    Verify that a 503 response is retried `max_attempts` times for both
    POST and GET.
    """
    http_resp = TrinoRequest.http.Response()
    http_resp.status_code = 503

    post_retry = RetryRecorder(result=http_resp)
    monkeypatch.setattr(TrinoRequest.http.Session, "post", post_retry)

    get_retry = RetryRecorder(result=http_resp)
    monkeypatch.setattr(TrinoRequest.http.Session, "get", get_retry)

    attempts = 3
    req = TrinoRequest(
        host="coordinator",
        port=8080,
        user="******",
        max_attempts=attempts,
    )

    req.post("URL")
    assert post_retry.retry_count == attempts

    req.get("URL")
    # Check the GET recorder here; re-checking post_retry would never
    # detect a missing retry on the GET path.
    assert get_retry.retry_count == attempts
def test_oauth2_exceed_max_attempts(attempts, sample_post_response_data):
    """
    Verify that polling the token server stops after
    `_OAuth2TokenBearer.MAX_OAUTH_ATTEMPTS` tries and raises `TrinoAuthError`.
    """
    token = str(uuid.uuid4())
    challenge_id = str(uuid.uuid4())

    redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}"
    token_server = f"{TOKEN_RESOURCE}/{challenge_id}"

    post_statement_callback = PostStatementCallback(
        redirect_server, token_server, [token], sample_post_response_data)

    # Statement endpoint issues the OAuth2 challenge.
    httpretty.register_uri(
        method=httpretty.POST,
        uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
        body=post_statement_callback,
    )

    # Token endpoint needs `attempts` polls before it hands out the token.
    get_token_callback = GetTokenCallback(token_server, token, attempts)
    httpretty.register_uri(
        method=httpretty.GET,
        uri=token_server,
        body=get_token_callback,
    )

    redirect_handler = RedirectHandler()
    oauth2_auth = trino.auth.OAuth2Authentication(
        redirect_auth_url_handler=redirect_handler)
    request = TrinoRequest(
        host="coordinator",
        port=constants.DEFAULT_TLS_PORT,
        user="******",
        http_scheme=constants.HTTPS,
        auth=oauth2_auth,
    )

    with pytest.raises(trino.exceptions.TrinoAuthError) as exc_info:
        request.post("select 1")

    max_tries = _OAuth2TokenBearer.MAX_OAUTH_ATTEMPTS
    assert str(exc_info.value) == "Exceeded max attempts while getting the token"
    assert redirect_handler.redirect_server == redirect_server
    assert get_token_callback.attempts == attempts - max_tries
    assert len(_post_statement_requests()) == 1
    assert len(_get_token_requests(challenge_id)) == max_tries
def test_request_session_properties_headers(mock_get_and_post):
    """
    Verify that session properties are serialized into the session header,
    with '=' and ',' inside values percent-encoded, on both POST and GET.
    """
    get_mock, post_mock = mock_get_and_post

    properties = {
        "a": "1",
        "b": "2",
        "c": "more=v1,v2",
    }
    req = TrinoRequest(
        host="coordinator",
        port=8080,
        user="******",
        session_properties=properties,
    )
    expected_session = "a=1,b=2,c=more%3Dv1%2Cv2"

    req.post("URL")
    _post_args, post_kwargs = post_mock.call_args
    assert post_kwargs["headers"][constants.HEADER_SESSION] == expected_session

    req.get("URL")
    _get_args, get_kwargs = get_mock.call_args
    assert get_kwargs["headers"][constants.HEADER_SESSION] == expected_session
def run(self) -> None:
    # Issue ten statement POSTs and remember the bearer token used by the last one.
    # `auth` is a free variable taken from the enclosing scope (not visible here).
    request = TrinoRequest(
        host="coordinator",
        port=constants.DEFAULT_TLS_PORT,
        user="******",
        http_scheme=constants.HTTPS,
        auth=auth,
    )
    for _ in range(10):
        # HTTPretty in the current version is not thread-safe
        # (https://github.com/gabrielfalcao/HTTPretty/issues/209),
        # so requests are serialized across threads with a shared lock.
        with RunningThread.lock:
            response = request.post("select 1")
        authorization = response.request.headers["Authorization"]
        self.token = authorization.replace("Bearer ", "")
def test_oauth2_authentication_flow(attempts, sample_post_response_data):
    """
    Verify the happy-path OAuth2 flow: the client follows the redirect, polls
    the token server until the token is issued, and then replays the
    statement POST with a bearer Authorization header.
    """
    token = str(uuid.uuid4())
    challenge_id = str(uuid.uuid4())

    redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}"
    token_server = f"{TOKEN_RESOURCE}/{challenge_id}"

    post_statement_callback = PostStatementCallback(
        redirect_server, token_server, [token], sample_post_response_data)

    # Statement endpoint issues the OAuth2 challenge.
    httpretty.register_uri(
        method=httpretty.POST,
        uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
        body=post_statement_callback,
    )

    # Token endpoint needs `attempts` polls before it hands out the token.
    get_token_callback = GetTokenCallback(token_server, token, attempts)
    httpretty.register_uri(
        method=httpretty.GET,
        uri=token_server,
        body=get_token_callback,
    )

    redirect_handler = RedirectHandler()
    oauth2_auth = trino.auth.OAuth2Authentication(
        redirect_auth_url_handler=redirect_handler)
    request = TrinoRequest(
        host="coordinator",
        port=constants.DEFAULT_TLS_PORT,
        user="******",
        http_scheme=constants.HTTPS,
        auth=oauth2_auth,
    )

    response = request.post("select 1")

    assert response.request.headers['Authorization'] == f"Bearer {token}"
    assert redirect_handler.redirect_server == redirect_server
    assert get_token_callback.attempts == 0
    # One POST that triggered the challenge, one replay with the token.
    assert len(_post_statement_requests()) == 2
    assert len(_get_token_requests(challenge_id)) == attempts