def test_dispatch_common_service_exception():
    """A CommonServiceException raised by a dispatched handler is returned as a serialized error response."""

    def delete_queue(_context: RequestContext, _request: ServiceRequest):
        raise CommonServiceException("NonExistentQueue", "No such queue")

    sqs_service = load_service("sqs")
    dispatch_table: DispatchTable = {"DeleteQueue": delete_queue}
    skeleton = Skeleton(sqs_service, dispatch_table)

    request_body = (
        "Action=DeleteQueue&Version=2012-11-05&"
        "QueueUrl=http%3A%2F%2Flocalhost%3A4566%2F000000000000%2Ftf-acc-test-queue"
    )
    context = RequestContext()
    context.account = "test"
    context.region = "us-west-1"
    context.service = sqs_service
    context.request = {
        "method": "POST",
        "path": "/",
        "body": request_body,
        "headers": _get_sqs_request_headers(),
    }

    serialized = skeleton.invoke(context)

    # parse the serialized response with botocore's own parser
    # NOTE(review): the SendMessage output shape is used although DeleteQueue was invoked;
    # presumably irrelevant since only the error part of the response is asserted — confirm
    botocore_parser = create_parser(sqs_service.protocol)
    output_shape = sqs_service.operation_model("SendMessage").output_shape
    parsed = botocore_parser.parse(serialized, output_shape)

    expected_error = {"Code": "NonExistentQueue", "Message": "No such queue"}
    assert "Error" in parsed
    assert parsed["Error"] == expected_error
def test_dispatch_missing_method_returns_internal_failure():
    """Invoking an operation without a handler in the dispatch table yields an InternalFailure error."""
    sqs_service = load_service("sqs")
    empty_table: DispatchTable = {}
    skeleton = Skeleton(sqs_service, empty_table)

    request_body = (
        "Action=DeleteQueue&Version=2012-11-05&"
        "QueueUrl=http%3A%2F%2Flocalhost%3A4566%2F000000000000%2Ftf-acc-test-queue"
    )
    context = RequestContext()
    context.account = "test"
    context.region = "us-west-1"
    context.service = sqs_service
    context.request = {
        "method": "POST",
        "path": "/",
        "body": request_body,
        "headers": _get_sqs_request_headers(),
    }

    serialized = skeleton.invoke(context)

    # parse the serialized response with botocore's own parser
    botocore_parser = create_parser(sqs_service.protocol)
    output_shape = sqs_service.operation_model("SendMessage").output_shape
    parsed = botocore_parser.parse(serialized, output_shape)

    expected_error = {
        "Code": "InternalFailure",
        "Message": "API action 'DeleteQueue' for service 'sqs' not yet implemented",
    }
    assert "Error" in parsed
    assert parsed["Error"] == expected_error
def test_create_op_router_works_for_every_service(service):
    """Building a router for the given service and matching "GET /" must not fail (NotFound is acceptable)."""
    operation_router = RestServiceOperationRouter(load_service(service))
    root_request = Request("GET", "/")
    try:
        operation_router.match(root_request)
    except NotFound:
        # no operation on "/" is a valid outcome; only unexpected errors should fail the test
        pass
def _botocore_error_serializer_integration_test(
    service: str,
    action: str,
    exception: ServiceException,
    code: str,
    status_code: int,
    message: Optional[str],
    is_sender_fault: bool = False,
):
    """
    Performs an integration test for the error serialization using botocore as parser.
    It executes the following steps:
    - Load the given service (f.e. "sqs")
    - Serialize the _error_ response with the appropriate serializer from the AWS Service Framework
    - Parse the serialized error response using the botocore parser
    - Checks that the metadata is correct (status code, requestID,...)
    - Checks if the parsed error response content is correct

    :param service: to load the correct service specification, serializer, and parser
    :param action: to load the correct service specification, serializer, and parser
    :param exception: which should be serialized and tested against
    :param code: expected "code" of the exception (i.e. the AWS specific exception ID, f.e.
                 "CloudFrontOriginAccessIdentityAlreadyExists")
    :param status_code: expected HTTP response status code
    :param message: expected error message
    :param is_sender_fault: if True, the parsed error must have the type "Sender";
                            if False, the parsed error must not carry any type at all
    :return: None
    """
    # Load the appropriate service
    service = load_service(service)

    # Use our serializer to serialize the response
    response_serializer = create_serializer(service)
    serialized_response = response_serializer.serialize_error_to_response(
        exception, service.operation_model(action)
    )

    # Use the parser from botocore to parse the serialized response
    response_parser: ResponseParser = create_parser(service.protocol)
    parsed_response = response_parser.parse(
        serialized_response.to_readonly_response_dict(),
        service.operation_model(action).output_shape,
    )

    # Check if the result is equal to the initial response params
    assert "Error" in parsed_response
    assert "Code" in parsed_response["Error"]
    assert "Message" in parsed_response["Error"]
    assert parsed_response["Error"]["Code"] == code
    assert parsed_response["Error"]["Message"] == message

    assert "ResponseMetadata" in parsed_response
    assert "RequestId" in parsed_response["ResponseMetadata"]
    assert len(parsed_response["ResponseMetadata"]["RequestId"]) == 52
    assert "HTTPStatusCode" in parsed_response["ResponseMetadata"]
    assert parsed_response["ResponseMetadata"]["HTTPStatusCode"] == status_code

    # named error_type to avoid shadowing the builtin "type"
    error_type = parsed_response["Error"].get("Type")
    if is_sender_fault:
        assert error_type == "Sender"
    else:
        assert error_type is None
def test_query_parser_non_flattened_list_structure():
    """Parse a CloudFormation CreateChangeSet request containing a non-flattened list of structures."""
    stack_arn = (
        "arn:aws:cloudformation:us-east-1:123456789012:stack/SampleStack/"
        "1a2345b6-0000-00a0-a123-00abc0abc000"
    )
    query_params = [
        "Action=CreateChangeSet",
        "ChangeSetName=SampleChangeSet",
        "Parameters.member.1.ParameterKey=KeyName",
        "Parameters.member.1.UsePreviousValue=true",
        "Parameters.member.2.ParameterKey=Purpose",
        "Parameters.member.2.ParameterValue=production",
        f"StackName={stack_arn}",
        "UsePreviousTemplate=true",
        "Version=2010-05-15",
        "X-Amz-Algorithm=AWS4-HMAC-SHA256",
        "X-Amz-Credential=[Access-key-ID-and-scope]",
        "X-Amz-Date=20160316T233349Z",
        "X-Amz-SignedHeaders=content-type;host",
        "X-Amz-Signature=[Signature]",
    ]
    request = HttpRequest(
        body=to_bytes("&".join(query_params)),
        method="POST",
        headers={},
        path="",
    )
    parser = QueryRequestParser(load_service("cloudformation"))

    operation, params = parser.parse(request)

    assert operation.name == "CreateChangeSet"
    assert params == {
        "ChangeSetName": "SampleChangeSet",
        "Parameters": [
            {"ParameterKey": "KeyName", "UsePreviousValue": True},
            {"ParameterKey": "Purpose", "ParameterValue": "production"},
        ],
        "StackName": stack_arn,
        "UsePreviousTemplate": True,
    }
def test_query_parser_non_flattened_list_structure_changed_name():
    """
    Parse a non-flattened list whose serialized name differs from the shape's name
    (CloudWatch PutMetricData).
    """
    query_params = [
        "Action=PutMetricData",
        "Version=2010-08-01",
        "Namespace=TestNamespace",
        "MetricData.member.1.MetricName=buffers",
        "MetricData.member.1.Unit=Bytes",
        "MetricData.member.1.Value=231434333",
        "MetricData.member.1.Dimensions.member.1.Name=InstanceType",
        "MetricData.member.1.Dimensions.member.1.Value=m1.small",
        "AUTHPARAMS",
    ]
    request = HttpRequest(
        body=to_bytes("&".join(query_params)),
        method="POST",
        headers={},
        path="",
    )
    parser = QueryRequestParser(load_service("cloudwatch"))

    operation, params = parser.parse(request)

    expected_metric = {
        "Dimensions": [{"Name": "InstanceType", "Value": "m1.small"}],
        "MetricName": "buffers",
        "Unit": "Bytes",
        "Value": 231434333.0,
    }
    assert operation.name == "PutMetricData"
    assert params == {"MetricData": [expected_metric], "Namespace": "TestNamespace"}
def generate_code(service_name: str, doc: bool = False) -> str:
    """
    Generate the types and API stub source code for the given AWS service.

    The generated code is post-processed with black, autoflake, and isort when they are available.
    Each tool is applied independently and best-effort: a missing or failing tool is skipped and
    the code produced so far is kept. Previously a single missing tool disabled all of them.

    :param service_name: name of the service specification to load (e.g. "sqs")
    :param doc: whether documentation should be included in the generated code
    :return: the generated (and, where possible, formatted) source code
    """
    model = load_service(service_name)
    output = io.StringIO()
    generate_service_types(output, model, doc=doc)
    generate_service_api(output, model, doc=doc)
    code = output.getvalue()

    # try to format with black
    try:
        from black import FileMode, format_str

        code = format_str(code, mode=FileMode(line_length=100))
    except Exception:
        pass

    # try to remove unused imports
    try:
        import autoflake

        code = autoflake.fix_code(code, remove_all_unused_imports=True)
    except Exception:
        pass

    # try to sort imports
    try:
        import isort

        code = isort.code(code, config=isort.Config(profile="black", line_length=100))
    except Exception:
        pass

    return code
def _botocore_parser_integration_test(service: str, action: str, method: str = None, request_uri: str = None, headers: dict = None, expected: dict = None, **kwargs):
    """
    Serialize a request with botocore's serializer and check that our parser reconstructs its parameters.

    :param service: name of the service specification to load
    :param action: the operation whose request is serialized and parsed
    :param method: HTTP method set on the serialized request
    :param request_uri: path set on the serialized request
    :param headers: headers set on the serialized request
    :param expected: expected parsed parameters; defaults to ``kwargs`` only when not given (None)
    :param kwargs: request parameters passed to botocore's serializer
    """
    # Load the appropriate service
    service = load_service(service)

    # Use the serializer from botocore to serialize the request params
    serializer = create_serializer(service.protocol)
    serialized_request = serializer.serialize_to_request(kwargs, service.operation_model(action))
    serialized_request["path"] = request_uri
    serialized_request["method"] = method
    serialized_request["headers"] = headers

    if service.protocol in ["query", "ec2"]:
        # Serialize the body as query parameter
        serialized_request["body"] = urlencode(serialized_request["body"])

    # Use our parser to parse the serialized body
    parser = create_parser(service)
    _, parsed_request = parser.parse(serialized_request)

    # Compare against "expected" if it has been set (even if it is an empty dict), otherwise against the kwargs.
    # "expected or kwargs" would silently fall back to kwargs for an explicitly passed empty dict.
    expected = kwargs if expected is None else expected
    assert parsed_request == expected
def test_query_parser_flattened_list_structure():
    """Parse an SQS DeleteMessageBatch request containing a flattened list of structures."""
    query_params = [
        "Action=DeleteMessageBatch",
        "Version=2012-11-05",
        "QueueUrl=http%3A%2F%2Flocalhost%3A4566%2F000000000000%2Ftf-acc-test-queue",
        "DeleteMessageBatchRequestEntry.1.Id=bar",
        "DeleteMessageBatchRequestEntry.1.ReceiptHandle=foo",
        "DeleteMessageBatchRequestEntry.2.Id=bar",
        "DeleteMessageBatchRequestEntry.2.ReceiptHandle=foo",
    ]
    request = HttpRequest(
        body=to_bytes("&".join(query_params)),
        method="POST",
        headers={},
        path="",
    )
    parser = QueryRequestParser(load_service("sqs"))

    operation, params = parser.parse(request)

    assert operation.name == "DeleteMessageBatch"
    assert params == {
        "QueueUrl": "http://localhost:4566/000000000000/tf-acc-test-queue",
        "Entries": [{"Id": "bar", "ReceiptHandle": "foo"} for _ in range(2)],
    }
def test_s3_head_request():
    """On the same object path, GET routes to GetObject and HEAD routes to HeadObject."""
    router = RestServiceOperationRouter(load_service("s3"))
    cases = (("GET", "GetObject"), ("HEAD", "HeadObject"))
    for http_method, expected_operation in cases:
        matched_operation, _ = router.match(Request(http_method, "/my-bucket/my-key/"))
        assert matched_operation.name == expected_operation
def _botocore_serializer_integration_test(
    service: str,
    action: str,
    response: dict,
    status_code=200,
    expected_response_content: dict = None,
):
    """
    Performs an integration test for the serializer using botocore as parser.
    It executes the following steps:
    - Load the given service (f.e. "sqs")
    - Serialize the response with the appropriate serializer from the AWS Service Framework
    - Parse the serialized response using the botocore parser
    - Checks if the metadata is correct (status code, requestID,...)
    - Checks if the parsed response content is equal to the input to the serializer

    :param service: to load the correct service specification, serializer, and parser
    :param action: to load the correct service specification, serializer, and parser
    :param response: which should be serialized and tested against
    :param status_code: Optional - expected status code of the response - defaults to 200
    :param expected_response_content: Optional - if the input data ("response") differs from the actually
                                      expected data (because f.e. it contains None values). Pass the
                                      ``_skip_assert`` sentinel to skip the content comparison entirely.
    :return: a deep copy of the parsed response (including its "ResponseMetadata"), taken before any
             assertion, so callers can perform additional checks on it
    """

    # Load the appropriate service
    service = load_service(service)

    # Use our serializer to serialize the response
    response_serializer = create_serializer(service)
    # The serializer changes the incoming dict, therefore copy it before passing it to the serializer
    response_to_parse = copy.deepcopy(response)
    serialized_response = response_serializer.serialize_to_response(
        response_to_parse, service.operation_model(action)
    )

    # Use the parser from botocore to parse the serialized response
    response_parser = create_parser(service.protocol)
    parsed_response = response_parser.parse(
        serialized_response.to_readonly_response_dict(),
        service.operation_model(action).output_shape,
    )

    # keep an untouched copy for the caller; "ResponseMetadata" is deleted from parsed_response below
    return_response = copy.deepcopy(parsed_response)

    # Check if the result is equal to the initial response params
    assert "ResponseMetadata" in parsed_response
    assert "HTTPStatusCode" in parsed_response["ResponseMetadata"]
    assert parsed_response["ResponseMetadata"]["HTTPStatusCode"] == status_code
    assert "RequestId" in parsed_response["ResponseMetadata"]
    assert len(parsed_response["ResponseMetadata"]["RequestId"]) == 52
    del parsed_response["ResponseMetadata"]

    if expected_response_content is None:
        expected_response_content = response
    if expected_response_content is not _skip_assert:
        assert parsed_response == expected_response_content

    return return_response
def test_json_protocol_content_type_1_1():
    """CloudWatch Logs declares jsonVersion 1.1, so responses must use application/x-amz-json-1.1."""
    logs_service = load_service("logs")
    serializer = create_serializer(logs_service)
    delete_log_group = logs_service.operation_model("DeleteLogGroup")

    result: Response = serializer.serialize_to_response({}, delete_log_group)

    assert result is not None
    content_type = result.content_type
    assert content_type is not None
    assert content_type == "application/x-amz-json-1.1"
def test_json_protocol_content_type_1_0():
    """AppRunner declares jsonVersion 1.0, so responses must use application/x-amz-json-1.0."""
    apprunner_service = load_service("apprunner")
    serializer = create_serializer(apprunner_service)
    delete_connection = apprunner_service.operation_model("DeleteConnection")

    result: Response = serializer.serialize_to_response({}, delete_connection)

    assert result is not None
    content_type = result.content_type
    assert content_type is not None
    assert content_type == "application/x-amz-json-1.0"
def test_serializer_error_on_protocol_error():
    """Serializing invalid data (here: an unsupported exception type) raises ProtocolSerializerError."""
    send_message = load_service("sqs").operation_model("SendMessage")
    query_serializer = QueryResponseSerializer()

    # A plain NotImplementedError is neither a CommonServiceException nor a generated service
    # exception, which is a known protocol error for the serializer.
    unsupported_error = NotImplementedError()
    with pytest.raises(ProtocolSerializerError):
        query_serializer.serialize_error_to_response(unsupported_error, send_message)
def test_s3_virtual_host_addressing():
    """
    Test virtual-host style addressing: a PUT request whose Host header carries the bucket name
    (and which has no path) is parsed as CreateBucket, with the bucket name taken from the host.
    """
    request = HttpRequest(
        method="PUT", headers={"host": s3_utils.get_bucket_hostname("test-bucket")}
    )
    parser = create_parser(load_service("s3"))
    parsed_operation_model, parsed_request = parser.parse(request)
    assert parsed_operation_model.name == "CreateBucket"
    assert "Bucket" in parsed_request
    assert parsed_request["Bucket"] == "test-bucket"
def test_restjson_parser_path_params_with_slashes():
    """A rest-json path parameter containing slashes (a QLDB ledger ARN) is parsed as a whole."""
    ledger_arn = "arn:aws:qldb:eu-central-1:000000000000:ledger/c-c67c827a"
    request = HttpRequest(
        body=b"",
        method="GET",
        headers={},
        path=f"/tags/{ledger_arn}",
    )
    qldb_parser = RestJSONRequestParser(load_service("qldb"))

    operation, params = qldb_parser.parse(request)

    assert operation.name == "ListTagsForResource"
    assert params == {"ResourceArn": ledger_arn}
def _validate_actions(self, actions: ActionNameList):
    """
    Validate that each action name is an operation of the service (or the wildcard "*").

    The service specification is loaded from ``self.service`` / ``self.version``.

    :param actions: the action names to validate
    :raises InvalidParameterValue: for the first action that is neither a known operation nor "*"
    """
    service = load_service(service=self.service, version=self.version)
    # FIXME: this is a bit of a heuristic as it will also include actions like "ListQueues" which is not
    #  associated with an action on a queue
    # use a set so each membership check is O(1) instead of scanning a list per action
    valid = set(service.operation_names)
    valid.add("*")

    for action in actions:
        if action not in valid:
            raise InvalidParameterValue(
                f"Value SQS:{action} for parameter ActionName is invalid. Reason: Please refer to the appropriate "
                "WSDL for a list of valid actions. "
            )
def _botocore_parser_integration_test(service: str, action: str, headers: dict = None, expected: dict = None, **kwargs):
    """
    Serialize a request with botocore's serializer and check that our parser reconstructs the
    same operation and parameters.

    :param service: name of the service specification to load
    :param action: the operation whose request is serialized and parsed
    :param headers: headers to use instead of the serialized request's headers (if given)
    :param expected: expected parsed parameters; defaults to ``kwargs`` only when not given (None)
    :param kwargs: request parameters passed to botocore's serializer
    """
    # Load the appropriate service
    service = load_service(service)

    # Use the serializer from botocore to serialize the request params
    serializer = create_serializer(service.protocol)
    operation_model = service.operation_model(action)
    serialized_request = serializer.serialize_to_request(kwargs, operation_model)
    # presumably fills in the "url" entry read below (botocore.awsrequest.prepare_request_dict)
    prepare_request_dict(serialized_request, "")
    split_url = urlsplit(serialized_request.get("url"))
    path = split_url.path
    query_string = split_url.query
    body = serialized_request["body"]
    # use custom headers (if provided), or headers from serialized request as default
    headers = serialized_request.get("headers") if headers is None else headers

    if service.protocol in ["query", "ec2"]:
        # Serialize the body as query parameter
        body = urlencode(serialized_request["body"])

    # Use our parser to parse the serialized body
    parser = create_parser(service)
    parsed_operation_model, parsed_request = parser.parse(
        HttpRequest(
            method=serialized_request.get("method") or "GET",
            path=unquote(path),
            query_string=to_str(query_string),
            headers=headers,
            body=body,
            raw_path=path,
        )
    )

    # Check if the determined operation_model is correct
    assert parsed_operation_model == operation_model

    # Compare against "expected" if it has been set (even if it is an empty dict), otherwise against the kwargs.
    # "expected or kwargs" would silently fall back to kwargs for an explicitly passed empty dict.
    expected = kwargs if expected is None else expected
    # The parser adds None for none-existing members on purpose. Remove those for the assert
    expected = {key: value for key, value in expected.items() if value is not None}
    parsed_request = {key: value for key, value in parsed_request.items() if value is not None}
    assert parsed_request == expected
def create_aws_request_context(
    service_name: str,
    action: str,
    parameters: Mapping[str, Any] = None,
    region: str = None,
    endpoint_url: Optional[str] = None,
) -> RequestContext:
    """
    Build a RequestContext for an AWS operation call without sending any HTTP request.

    This mirrors, in stripped-down form, what a botocore client does when serializing a call.
    For example, ``boto3.client("sqs").create_queue(QueueName="myqueue")`` corresponds to
    ``create_aws_request_context("sqs", "CreateQueue", {"QueueName": "myqueue"})``.

    :param service_name: the AWS service
    :param action: the action (operation name) to invoke
    :param parameters: the invocation parameters (an empty dict if omitted)
    :param region: the region name (``config.AWS_REGION_US_EAST_1`` if omitted)
    :param endpoint_url: the endpoint to call (passed to ``aws_stack.connect_to_service``)
    :return: a RequestContext describing the request
    """
    parameters = {} if parameters is None else parameters
    region = config.AWS_REGION_US_EAST_1 if region is None else region

    service = load_service(service_name)
    operation = service.operation_model(action)

    # reuse botocore client internals to serialize the HTTP request without sending it
    client = aws_stack.connect_to_service(service_name, endpoint_url=endpoint_url, region_name=region)
    request_dict = client._convert_to_request_dict(
        parameters,
        operation,
        context={
            "client_region": region,
            "has_streaming_input": operation.has_streaming_input,
            "auth_type": operation.auth_type,
        },
    )
    aws_request = client._endpoint.create_request(request_dict, operation)
    http_request = create_http_request(aws_request)

    context = RequestContext()
    context.service = service
    context.operation = operation
    context.region = region
    context.request = http_request
    return context
def test_restjson_operation_detection_with_subpath():
    """
    Operation lookup must fail for a path that is only a sub-path of a known operation's path,
    e.g. a URL that is routed through API Gateway.
    """
    user_request_path = "/restapis/cmqinv79uh/local/_user_request_/"
    apigateway_parser = create_parser(load_service("apigateway"))
    request = HttpRequest(method="GET", path=user_request_path, raw_path=user_request_path)

    with pytest.raises(OperationNotFoundParserError):
        apigateway_parser.parse(request)
def test_serializer_error_on_unknown_error():
    """An unexpected error inside the serializer is surfaced as UnknownSerializerError."""
    send_message = load_service("sqs").operation_model("SendMessage")
    query_serializer = QueryResponseSerializer()

    # Genuinely unknown errors cannot be triggered on purpose (they would be fixed otherwise),
    # so the internal serialization step is replaced by one that always fails.
    def _always_fail(*_args, **_kwargs):
        raise NotImplementedError()

    query_serializer._serialize_response = _always_fail

    with pytest.raises(UnknownSerializerError):
        query_serializer.serialize_to_response({}, send_message)
def test_parser_error_on_protocol_error():
    """Parsing a request for an operation the service does not define raises ProtocolParserError."""
    query_params = [
        "Action=UnknownOperation",
        "Version=2012-11-05",
        "QueueUrl=http%3A%2F%2Flocalhost%3A4566%2F000000000000%2Ftf-acc-test-queue",
        "MessageBody=%7B%22foo%22%3A+%22bared%22%7D",
        "DelaySeconds=2",
    ]
    request = HttpRequest(
        body=to_bytes("&".join(query_params)),
        method="POST",
        headers={},
        path="",
    )
    sqs_parser = QueryRequestParser(load_service("sqs"))

    with pytest.raises(ProtocolParserError):
        sqs_parser.parse(request)
def test_missing_required_field_restjson(self):
    """Validation reports the missing required member TagList of an OpenSearch AddTags request."""
    request = HttpRequest(
        "POST",
        "/2021-01-01/tags",
        body='{"ARN":"somearn"}',
    )
    opensearch_parser = create_parser(load_service("opensearch"))
    operation, parameters = opensearch_parser.parse(request)

    with pytest.raises(MissingRequiredField) as exc_info:
        validate_request(operation, parameters).raise_first()

    raised = exc_info.value
    assert raised.error.reason == "missing required field"
    assert raised.required_name == "TagList"
def test_trailing_slashes_are_not_strict():
    """Routing ignores a trailing slash, matching AWS behavior (verified against AWS)."""
    router = RestServiceOperationRouter(load_service("lambda"))
    expectations = [
        ("GET", "/2015-03-31/functions", "ListFunctions"),
        ("GET", "/2015-03-31/functions/", "ListFunctions"),
        ("POST", "/2015-03-31/functions", "CreateFunction"),
        ("POST", "/2015-03-31/functions/", "CreateFunction"),
    ]
    for http_method, path, expected_operation in expectations:
        matched_operation, _ = router.match(Request(http_method, path))
        assert matched_operation.name == expected_operation
def test_missing_required_field_restxml(self):
    """Validation reports the missing required member CallerReference of a Route53 CreateHostedZone request."""
    xml_body = "<CreateHostedZoneRequest><Name>foobar.com</Name></CreateHostedZoneRequest>"
    request = HttpRequest("POST", "/2013-04-01/hostedzone", body=xml_body)
    route53_parser = create_parser(load_service("route53"))
    operation, parameters = route53_parser.parse(request)

    with pytest.raises(MissingRequiredField) as exc_info:
        validate_request(operation, parameters).raise_first()

    raised = exc_info.value
    assert raised.error.reason == "missing required field"
    assert raised.required_name == "CallerReference"
def generate(service: str, doc: bool, save: bool):
    """
    Generate types and API stubs for a given AWS service.

    SERVICE is the service to generate the stubs for (e.g., sqs, or cloudformation)
    """
    # NOTE(review): the docstring above presumably doubles as the click command's help text
    # (decorators are not visible here), so it is intentionally left unchanged.
    from click import ClickException

    try:
        model = load_service(service)
    except UnknownServiceError as e:
        # chain explicitly so the botocore error is kept as the cause
        raise ClickException("unknown service %s" % service) from e

    output = io.StringIO()
    generate_service_types(output, model, doc=doc)
    generate_service_api(output, model, doc=doc)

    code = output.getvalue()

    try:
        # try to format with black (best-effort: keep the unformatted code if black is unavailable or fails)
        from black import FileMode, format_str

        code = format_str(code, mode=FileMode())
    except Exception:
        pass

    if not save:
        # either just print the code to stdout
        click.echo(code)
        return

    # or find the file path and write the code to that location
    here = os.path.dirname(__file__)
    path = os.path.join(here, "api", service)

    if not os.path.exists(path):
        click.echo("creating directory %s" % path)
        mkdir(path)

    file = os.path.join(path, "__init__.py")
    click.echo("writing to file %s" % file)
    # write UTF-8 explicitly: generated documentation may contain non-ASCII characters, and the
    # platform's locale default encoding (e.g. cp1252) could fail to encode them
    with open(file, "w", encoding="utf-8") as fd:
        fd.write(code)

    click.echo("done!")
def test_no_mutation_of_parameters():
    """Serializing a response must leave the caller's parameters dict unchanged."""
    appconfig = load_service("appconfig")
    serializer = create_serializer(appconfig)
    operation_model = appconfig.operation_model("CreateHostedConfigurationVersion")

    params = {
        "ApplicationId": "app_id",
        "ConfigurationProfileId": "conf_id",
        "VersionNumber": 1,
        "Content": b'{"Id":"foo"}',
        "ContentType": "application/json",
    }
    snapshot = dict(params)

    serializer.serialize_to_response(params, operation_model)

    assert params == snapshot
def test_query_protocol_error_serialization_plain():
    """Check the exact XML body and Content-Type of a serialized SQS (query protocol) error response."""

    # Specific error of the ChangeMessageVisibility operation in SQS as the scaffold would generate it
    class ReceiptHandleIsInvalid(ServiceException):
        pass

    exception = ReceiptHandleIsInvalid(
        'The input receipt handle "garbage" is not a valid receipt handle.'
    )

    # Load the SQS service
    service = load_service("sqs")

    # Use our serializer to serialize the response
    response_serializer = create_serializer(service)
    serialized_response = response_serializer.serialize_error_to_response(
        exception, service.operation_model("ChangeMessageVisibility")
    )
    serialized_response_dict = serialized_response.to_readonly_response_dict()
    # Replace the random request ID with a static value for comparison
    serialized_response_body = re.sub(
        "<RequestId>.*</RequestId>",
        "<RequestId>static_request_id</RequestId>",
        to_str(serialized_response_dict["body"]),
    )

    # This expected_response_body differs from the actual response in the following ways:
    # - The original response does not define an encoding.
    # - There is no newline after the XML declaration.
    # - The response does not contain a Type nor Detail tag (since they aren't contained in the spec).
    # - The original response uses double quotes for the xml declaration.
    # Most of these differences should be handled equally by parsing clients, however, we might adopt some of these
    # changes in the future.
    expected_response_body = (
        "<?xml version='1.0' encoding='utf-8'?>\n"
        '<ErrorResponse xmlns="http://queue.amazonaws.com/doc/2012-11-05/">'
        "<Error>"
        "<Code>ReceiptHandleIsInvalid</Code>"
        # the double quotes of the message are expected in XML-escaped form (&quot;)
        "<Message>The input receipt handle &quot;garbage&quot; is not a valid receipt handle."
        "</Message>"
        "</Error>"
        "<RequestId>static_request_id</RequestId>"
        "</ErrorResponse>"
    )
    assert serialized_response_body == expected_response_body
    assert serialized_response_dict["headers"].get("Content-Type") is not None
    assert serialized_response_dict["headers"]["Content-Type"] == "text/xml"
def test_query_parser():
    """Parse a simple SQS SendMessage request with the QueryRequestParser."""
    query_params = [
        "Action=SendMessage",
        "Version=2012-11-05",
        "QueueUrl=http%3A%2F%2Flocalhost%3A4566%2F000000000000%2Ftf-acc-test-queue",
        "MessageBody=%7B%22foo%22%3A+%22bared%22%7D",
        "DelaySeconds=2",
    ]
    request = HttpRequest(
        body=to_bytes("&".join(query_params)),
        method="POST",
        headers={},
        path="",
    )
    sqs_parser = QueryRequestParser(load_service("sqs"))

    operation, params = sqs_parser.parse(request)

    assert operation.name == "SendMessage"
    assert params == {
        "QueueUrl": "http://localhost:4566/000000000000/tf-acc-test-queue",
        "MessageBody": '{"foo": "bared"}',
        "DelaySeconds": 2,
    }
def test_invalid_length_query(self):
    """Validation of an STS AssumeRole request rejects a RoleArn shorter than its minimum length."""
    form_body = urlencode(
        query={
            "Action": "AssumeRole",
            "RoleArn": "arn:aws",  # 7 characters, below the minimum length of 8
            "RoleSessionName": "foobared",
        }
    )
    request = HttpRequest(
        "POST",
        "/",
        body=form_body,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
    )
    sts_parser = create_parser(load_service("sts"))
    operation, parameters = sts_parser.parse(request)

    with pytest.raises(InvalidLength) as exc_info:
        validate_request(operation, parameters).raise_first()

    exc_info.match("RoleArn")