Example #1
0
    def fix_openapi_definitions(cls, template):
        """
        Postprocess SAM-managed Api resources so the swagger doc version matches the one specified on
        the resource with flag OpenApiVersion.

        This is done as a globals postprocess because the implicit api plugin runs before globals, and at
        that point the global flags aren't applied on each resource, so we do not know whether the
        OpenApiVersion flag is specified. Running the globals plugin before implicit api was a risky
        change, so we decided to postprocess the openapi version here.

        To make sure we don't modify customer defined swagger, only resources that also carry the
        __MANAGE_SWAGGER flag are touched. Resources whose ``Properties`` (or ``DefinitionBody``) is
        missing or not a dict are skipped instead of raising.

        :param dict template: SAM template; matching resources are modified in place
        :return: None (the template is updated in place)
        """
        resources = template.get("Resources", {})

        for resource in resources.values():
            if not isinstance(resource, dict) or resource.get("Type") != cls._API_TYPE:
                continue

            properties = resource.get("Properties")
            if not isinstance(properties, dict):
                continue

            if cls._OPENAPIVERSION not in properties or cls._MANAGE_SWAGGER not in properties:
                continue

            if not SwaggerEditor.safe_compare_regex_with_string(
                    SwaggerEditor.get_openapi_version_3_regex(), properties[cls._OPENAPIVERSION]):
                continue

            # Non-string flag values (e.g. a number) are normalised to str before being copied into
            # the document's `openapi` field.
            if not isinstance(properties[cls._OPENAPIVERSION], string_types):
                properties[cls._OPENAPIVERSION] = str(properties[cls._OPENAPIVERSION])

            definition_body = properties.get("DefinitionBody")
            if isinstance(definition_body, dict):
                definition_body["openapi"] = properties[cls._OPENAPIVERSION]
                # The document is now OpenAPI 3, so drop any Swagger 2.0 version marker.
                if definition_body.get("swagger"):
                    del definition_body["swagger"]
    def _add_auth(self):
        """
        Add the API's Auth configuration to its inline Swagger definition.

        Returns immediately when no Auth is configured. Otherwise the Auth keys are validated, the
        authorizers are added to the Swagger document together with the default authorizer, and the
        edited Swagger is written back to ``self.definition_body``.

        :raises InvalidResourceException: when Auth is used without an inline DefinitionBody, when Auth
            contains unrecognized keys, or when DefinitionBody is not valid Swagger
        """
        if not self.auth:
            return

        if not self.definition_body:
            raise InvalidResourceException(
                self.logical_id,
                "Auth works only with inline Swagger specified in 'DefinitionBody' property")

        has_unknown_key = any(key not in AuthProperties._fields for key in self.auth.keys())
        if has_unknown_key:
            raise InvalidResourceException(self.logical_id, "Invalid value for 'Auth' property")

        if not SwaggerEditor.is_valid(self.definition_body):
            raise InvalidResourceException(
                self.logical_id,
                "Unable to add Auth configuration because 'DefinitionBody' does not contain a valid Swagger")

        editor = SwaggerEditor(self.definition_body)
        props = AuthProperties(**self.auth)
        authorizers = self._get_authorizers(props.Authorizers)

        if authorizers:
            editor.add_authorizers(authorizers)
            self._set_default_authorizer(editor, authorizers, props.DefaultAuthorizer)

        # Write the edited Swagger back onto the resource
        self.definition_body = editor.swagger
Example #3
0
    def _add_auth(self):
        """
        Attach Auth settings (authorizers plus default authorizer) to the inline Swagger body.

        No-op when ``self.auth`` is empty.

        :raises InvalidResourceException: if Auth is set but there is no inline DefinitionBody, if the
            Auth dict has keys not in ``AuthProperties._fields``, or if DefinitionBody is not valid Swagger
        """
        auth_config = self.auth
        if not auth_config:
            return

        if not self.definition_body:
            raise InvalidResourceException(
                self.logical_id,
                "Auth works only with inline Swagger specified in 'DefinitionBody' property")

        known_fields = AuthProperties._fields
        if any(key not in known_fields for key in auth_config.keys()):
            raise InvalidResourceException(self.logical_id, "Invalid value for 'Auth' property")

        if not SwaggerEditor.is_valid(self.definition_body):
            raise InvalidResourceException(
                self.logical_id,
                "Unable to add Auth configuration because 'DefinitionBody' does not contain a valid Swagger")

        editor = SwaggerEditor(self.definition_body)
        parsed_auth = AuthProperties(**auth_config)
        authorizers = self._get_authorizers(parsed_auth.Authorizers)

        if authorizers:
            editor.add_authorizers(authorizers)
            self._set_default_authorizer(editor, authorizers, parsed_auth.DefaultAuthorizer)

        # Store the edited Swagger on the resource
        self.definition_body = editor.swagger
    def setUp(self):
        # A fresh integration dict per operation, so no two operations share the same object.
        def make_integration():
            return {_X_INTEGRATION: {"a": "b"}}

        self.original_swagger = {
            "swagger": "2.0",
            "paths": {
                "/foo": {"get": make_integration(), "post": make_integration()},
                "/bar": {"get": make_integration()},
            },
        }
        self.editor = SwaggerEditor(self.original_swagger)
    def test_allow_headers_is_skipped_with_no_value(self):
        """Without allowed headers, no Access-Control-Allow-Headers entry appears in the OPTIONS response."""
        origins = "origins"
        methods = "methods"
        headers = None  # deliberately no value
        allow_credentials = True

        prefix = "method.response.header."
        expected = {
            prefix + "Access-Control-Allow-Credentials": _ALLOW_CREDENTALS_TRUE,
            prefix + "Access-Control-Allow-Methods": methods,
            prefix + "Access-Control-Allow-Origin": origins,
        }
        expected_headers = {
            header_name: {"type": "string"}
            for header_name in (
                "Access-Control-Allow-Credentials",
                "Access-Control-Allow-Methods",
                "Access-Control-Allow-Origin",
            )
        }

        editor = SwaggerEditor(SwaggerEditor.gen_skeleton())
        options_config = editor._options_method_response_for_cors(
            origins, headers, methods, allow_credentials=allow_credentials)

        default_response = options_config[_X_INTEGRATION]["responses"]["default"]
        self.assertEqual(expected, default_response["responseParameters"])
        self.assertEqual(expected_headers, options_config["responses"]["200"]["headers"])
Example #6
0
    def _openapi_postprocess(self, definition_body):
        """
        Convert definitions to OpenAPI 3 in the definition body if the OpenApiVersion flag is specified.

        If a ``swagger`` field is defined in the definition body, the body is treated as a Swagger spec
        and no OpenAPI 3 changes are made to it. When no OpenApiVersion was configured, the body's own
        ``openapi`` value is adopted. For an OpenAPI 3 version, ``securityDefinitions`` and
        ``definitions`` are moved under ``components`` as ``securitySchemes`` and ``schemas``.

        :param dict definition_body: Swagger/OpenAPI document, modified in place
        :return: the same ``definition_body`` object
        """
        if definition_body.get('swagger') is not None:
            return definition_body

        declared_version = definition_body.get('openapi')
        if declared_version is not None and self.open_api_version is None:
            self.open_api_version = declared_version

        is_openapi3 = bool(self.open_api_version) and SwaggerEditor.safe_compare_regex_with_string(
            SwaggerEditor.get_openapi_version_3_regex(), self.open_api_version)
        if not is_openapi3:
            return definition_body

        # Swagger 2.0 top-level section -> OpenAPI 3 `components` sub-section, processed in this order.
        relocations = (('securityDefinitions', 'securitySchemes'), ('definitions', 'schemas'))
        for swagger_key, components_key in relocations:
            if definition_body.get(swagger_key):
                components = definition_body.get('components', {})
                components[components_key] = definition_body[swagger_key]
                definition_body['components'] = components
                del definition_body[swagger_key]

        return definition_body
    def _add_models(self):
        """
        Add Model definitions to the Swagger file, if necessary.

        Does nothing when no Models are configured. The models are added through ``SwaggerEditor`` and
        the result is post-processed for OpenAPI 3 before being stored back in ``self.definition_body``.

        :raises InvalidResourceException: if Models is set without an inline DefinitionBody, if
            DefinitionBody is not a valid Swagger definition, or if Models is not a mapping whose values
            are all dicts
        """
        if not self.models:
            return

        if not self.definition_body:
            raise InvalidResourceException(
                self.logical_id, "Models works only with inline Swagger specified in 'DefinitionBody' property."
            )

        if not SwaggerEditor.is_valid(self.definition_body):
            raise InvalidResourceException(
                self.logical_id,
                "Unable to add Models definitions because "
                "'DefinitionBody' does not contain a valid Swagger definition.",
            )

        # A non-mapping value (e.g. a list) would otherwise fail with AttributeError on `.values()`
        # instead of a user-facing validation error.
        if not isinstance(self.models, dict) or not all(
                isinstance(model, dict) for model in self.models.values()):
            raise InvalidResourceException(self.logical_id, "Invalid value for 'Models' property")

        swagger_editor = SwaggerEditor(self.definition_body)
        swagger_editor.add_models(self.models)

        # Assign the Swagger back to template
        self.definition_body = self._openapi_postprocess(swagger_editor.swagger)
Example #8
0
    def _add_swagger_integration(self, api, function):
        """Adds the path and method for this Api event source to the Swagger body of the provided Api.

        Does nothing when the Api has no ``DefinitionBody``. Otherwise a Lambda proxy integration URI is
        built from the function's ``arn`` runtime attribute and added for ``self.Path``/``self.Method``.
        The edited Swagger then replaces ``api["DefinitionBody"]``.

        :param dict api: dict-like holding the Api's properties (presumably the resource's Properties);
            its ``DefinitionBody`` is read and replaced
        :param function: the Lambda function resource whose ``arn`` runtime attribute is the integration target
        :raises InvalidEventException: if an integration already exists for this path and method
        """
        swagger_body = api.get("DefinitionBody")
        if swagger_body is None:
            return

        function_arn = function.get_runtime_attr('arn')
        partition = ArnGenerator.get_partition_name()
        uri = fnSub(
            'arn:' + partition +
            ':apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/' +
            make_shorthand(function_arn) + '/invocations')

        editor = SwaggerEditor(swagger_body)

        if editor.has_integration(self.Path, self.Method):
            # Cannot add the Lambda Integration, if it is already present
            raise InvalidEventException(
                self.relative_id,
                'API method "{method}" defined multiple times for path "{path}".'
                .format(method=self.Method, path=self.Path))

        editor.add_lambda_integration(self.Path, self.Method, uri)
        api["DefinitionBody"] = editor.swagger
    def test_must_add_body_parameter_to_method_openapi_required_true(self):
        # NOTE(review): despite the test name, the model below does not set 'Required' and the expected
        # requestBody has no 'required' key -- confirm this is intended.
        path, method = '/foo', 'get'
        integration = {'x-amazon-apigateway-integration': {'test': 'must have integration'}}
        editor = SwaggerEditor({"openapi": "3.0.1", "paths": {path: {method: integration}}})

        editor.add_request_model_to_method(path, method, {'Model': 'User'})

        schema_ref = {'$ref': '#/components/schemas/user'}
        expected = {'content': {'application/json': {'schema': schema_ref}}}
        self.assertEqual(expected, editor.swagger['paths'][path][method]['requestBody'])
Example #10
0
    def _construct_rest_api(self):
        """Build the ApiGateway RestApi resource this SAM Api corresponds to.

        Copies binary media types and minimum compression size, configures the endpoint (falling back to
        REGIONAL where EDGE is unsupported), validates the definition and OpenApiVersion settings, applies
        CORS/Auth/gateway responses/binary media types/models, then sets the body (or S3 body location),
        name and description.

        :returns: the RestApi to which this SAM Api corresponds
        :rtype: model.apigateway.ApiGatewayRestApi
        :raises InvalidResourceException: if both DefinitionUri and DefinitionBody are given, or if
            OpenApiVersion does not match a supported version
        """
        rest_api = ApiGatewayRestApi(
            self.logical_id, depends_on=self.depends_on, attributes=self.resource_attributes)
        # NOTE: For backwards compatibility we need to retain BinaryMediaTypes on the CloudFormation Property
        # Removing this and only setting x-amazon-apigateway-binary-media-types results in other issues.
        rest_api.BinaryMediaTypes = self.binary_media
        rest_api.MinimumCompressionSize = self.minimum_compression_size

        endpoint_configuration = self.endpoint_configuration
        if not endpoint_configuration and not RegionConfiguration.is_apigw_edge_configuration_supported():
            # This region does not support EDGE, so REGIONAL is the only supported config.
            endpoint_configuration = "REGIONAL"
        if endpoint_configuration:
            self._set_endpoint_configuration(rest_api, endpoint_configuration)

        if self.definition_uri and self.definition_body:
            raise InvalidResourceException(
                self.logical_id,
                "Specify either 'DefinitionUri' or 'DefinitionBody' property and not both.")

        if self.open_api_version and not SwaggerEditor.safe_compare_regex_with_string(
                SwaggerEditor.get_openapi_versions_supported_regex(), self.open_api_version):
            raise InvalidResourceException(
                self.logical_id, "The OpenApiVersion value must be of the format '3.0.0'.")

        # These steps may rewrite self.definition_body; keep them in this order.
        for apply_step in (self._add_cors, self._add_auth, self._add_gateway_responses,
                           self._add_binary_media_types, self._add_models):
            apply_step()

        if self.definition_uri:
            rest_api.BodyS3Location = self._construct_body_s3_dict()
        elif self.definition_body:
            # Post-process OpenApi Auth settings before attaching the body
            self.definition_body = self._openapi_postprocess(self.definition_body)
            rest_api.Body = self.definition_body

        if self.name:
            rest_api.Name = self.name
        if self.description:
            rest_api.Description = self.description

        return rest_api
    def _openapi_postprocess(self, definition_body):
        """
        Convert definitions to OpenAPI 3 in the definition body if the OpenApiVersion flag is specified.

        If a ``swagger`` field is defined in the definition body, we treat it as a Swagger spec and do not
        make any OpenAPI 3 changes to it. When no OpenApiVersion was configured, the body's own ``openapi``
        value is adopted. For an OpenAPI 3 version this:

        - moves ``securityDefinitions`` to ``components.securitySchemes``;
        - moves ``definitions`` to ``components.schemas``;
        - removes the ``produces``/``consumes`` fields of every ``options`` (CORS) operation;
        - wraps each ``responses.200.headers`` entry of an ``options`` operation as ``{"schema": <header>}``.

        Path items, operations, responses and headers that are not dicts are skipped rather than raising.
        Headers that already contain a ``schema`` key are left untouched, so running this post-processing
        more than once on the same body does not double-wrap them.

        :param dict definition_body: Swagger/OpenAPI document, modified in place
        :return: the same ``definition_body`` object
        """
        if definition_body.get("swagger") is not None:
            return definition_body

        if definition_body.get("openapi") is not None and self.open_api_version is None:
            self.open_api_version = definition_body.get("openapi")

        if not (
            self.open_api_version
            and SwaggerEditor.safe_compare_regex_with_string(
                SwaggerEditor.get_openapi_version_3_regex(), self.open_api_version
            )
        ):
            return definition_body

        if definition_body.get("securityDefinitions"):
            components = definition_body.get("components", {})
            components["securitySchemes"] = definition_body["securityDefinitions"]
            definition_body["components"] = components
            del definition_body["securityDefinitions"]
        if definition_body.get("definitions"):
            components = definition_body.get("components", {})
            components["schemas"] = definition_body["definitions"]
            definition_body["components"] = components
            del definition_body["definitions"]

        # Removes `consumes` and `produces` options for CORS in openapi3 and
        # adds `schema` for the headers in responses for openapi3.
        paths = definition_body.get("paths")
        if not paths:
            return definition_body

        for path_item in paths.values():
            if not isinstance(path_item, dict):
                continue
            options = path_item.get("options")
            if not options or not isinstance(options, dict):
                continue

            # produces/consumes are not supported on OpenAPI 3 operations
            options.pop("produces", None)
            options.pop("consumes", None)

            responses = options.get("responses")
            ok_response = responses.get("200") if isinstance(responses, dict) else None
            headers = ok_response.get("headers") if isinstance(ok_response, dict) else None
            if not headers or not isinstance(headers, dict):
                continue

            # Only existing keys are reassigned, so iterating over the same dict is safe.
            for header_name, header_value in headers.items():
                if isinstance(header_value, dict) and "schema" in header_value:
                    continue  # already in OpenAPI 3 form
                headers[header_name] = {"schema": header_value}

        return definition_body
    def setUp(self):
        paths = {"/foo": {}, "/bar": {}}
        paths["/baz"] = "some value"  # a non-dict path value
        self.original_swagger = {"swagger": "2.0", "paths": paths}
        self.editor = SwaggerEditor(self.original_swagger)
    def _add_auth(self):
        """
        Add Auth configuration (authorizers, API key requirement, resource policy) to the inline Swagger.

        Returns immediately when no Auth is configured. The edited Swagger is post-processed for
        OpenAPI 3 and stored back in ``self.definition_body``.

        :raises InvalidResourceException: when Auth is set without an inline DefinitionBody, when Auth
            has unrecognized keys, or when DefinitionBody is not valid Swagger
        """
        if not self.auth:
            return

        if not self.definition_body:
            raise InvalidResourceException(
                self.logical_id,
                "Auth works only with inline Swagger specified in 'DefinitionBody' property")

        if any(key not in AuthProperties._fields for key in self.auth.keys()):
            raise InvalidResourceException(self.logical_id, "Invalid value for 'Auth' property")

        if not SwaggerEditor.is_valid(self.definition_body):
            raise InvalidResourceException(
                self.logical_id,
                "Unable to add Auth configuration because 'DefinitionBody' does not contain a valid Swagger",
            )

        editor = SwaggerEditor(self.definition_body)
        auth_props = AuthProperties(**self.auth)
        default_authorizer = auth_props.DefaultAuthorizer
        authorizers = self._get_authorizers(auth_props.Authorizers, default_authorizer)

        if authorizers:
            editor.add_authorizers_security_definitions(authorizers)
            self._set_default_authorizer(
                editor,
                authorizers,
                default_authorizer,
                auth_props.AddDefaultAuthorizerToCorsPreflight,
                auth_props.Authorizers,
            )

        if auth_props.ApiKeyRequired:
            editor.add_apikey_security_definition()
            self._set_default_apikey_required(editor)

        policy = auth_props.ResourcePolicy
        if policy:
            for path in editor.iter_on_path():
                editor.add_resource_policy(policy, path, self.logical_id, self.stage_name)

        self.definition_body = self._openapi_postprocess(editor.swagger)
    def test_must_not_add_body_parameter_to_method_without_integration(self):
        editor = SwaggerEditor({"swagger": "2.0", "paths": {"/foo": {'get': {}}}})

        editor.add_request_model_to_method('/foo', 'get', {'Model': 'User', 'Required': True})

        # The operation has no integration, so it must stay empty
        self.assertEqual({}, editor.swagger['paths']['/foo']['get'])
    def setUp(self):
        get_operation = {'x-amazon-apigateway-integration': {'test': 'must have integration'}}
        self.original_swagger = {"swagger": "2.0", "paths": {"/foo": {'get': get_operation}}}
        self.editor = SwaggerEditor(self.original_swagger)
    def test_must_return_copy_of_swagger(self):
        source = {"swagger": "2.0", "paths": {}}
        editor = SwaggerEditor(source)

        # Equal in content, but the editor must work on its own copy of the input
        self.assertEqual(source, editor.swagger)
        source["swagger"] = "3"
        self.assertEqual("2.0", editor.swagger["swagger"])

        # Paths added through the editor must not leak into the caller's dict
        editor.add_path("/foo", "get")
        self.assertEqual({"/foo": {"get": {}}}, editor.swagger["paths"])
        self.assertEqual({}, source["paths"])
    def test_must_return_copy_of_swagger(self):
        given = {"swagger": "2.0", "paths": {}}
        editor = SwaggerEditor(given)
        self.assertEqual(given, editor.swagger)  # same content...

        given["swagger"] = "3"
        self.assertEqual("2.0", editor.swagger["swagger"])  # ...but a separate copy

        editor.add_path("/foo", "get")
        expected_paths = {"/foo": {"get": {}}}
        self.assertEqual(expected_paths, editor.swagger["paths"])
        self.assertEqual({}, given["paths"])  # caller's dict untouched
    def _add_gateway_responses(self):
        """
        Add Gateway Response configuration to the Swagger file, if necessary.

        Each ``GatewayResponses`` entry is checked against ``GatewayResponseProperties`` and converted into
        an ``ApiGatewayResponse``. The entries are added to the inline Swagger, which is then stored back
        in ``self.definition_body``.

        :raises InvalidResourceException: if GatewayResponses is set without an inline DefinitionBody, if
            GatewayResponses or one of its entries is not a mapping, if an entry uses an unknown property,
            or if DefinitionBody is not a valid Swagger definition
        """
        if not self.gateway_responses:
            return

        if not self.definition_body:
            raise InvalidResourceException(
                self.logical_id,
                "GatewayResponses works only with inline Swagger specified in 'DefinitionBody' property.",
            )

        # Reject malformed values with a user-facing error rather than an AttributeError on .items()/.keys()
        if not isinstance(self.gateway_responses, dict):
            raise InvalidResourceException(self.logical_id, "Invalid value for 'GatewayResponses' property.")

        # Make sure keys in the dict are recognized
        for responses_key, responses_value in self.gateway_responses.items():
            if not isinstance(responses_value, dict):
                raise InvalidResourceException(
                    self.logical_id,
                    "Invalid value for 'GatewayResponses' property '{}'.".format(responses_key),
                )
            for response_key in responses_value.keys():
                if response_key not in GatewayResponseProperties:
                    raise InvalidResourceException(
                        self.logical_id,
                        "Invalid property '{}' in 'GatewayResponses' property '{}'.".format(
                            response_key, responses_key
                        ),
                    )

        if not SwaggerEditor.is_valid(self.definition_body):
            # Previously (copy-paste) reported as "Unable to add Auth configuration".
            raise InvalidResourceException(
                self.logical_id,
                "Unable to add Gateway Responses configuration because "
                "'DefinitionBody' does not contain a valid Swagger definition.",
            )

        swagger_editor = SwaggerEditor(self.definition_body)

        gateway_responses = {}
        for response_type, response in self.gateway_responses.items():
            gateway_responses[response_type] = ApiGatewayResponse(
                api_logical_id=self.logical_id,
                response_parameters=response.get("ResponseParameters", {}),
                response_templates=response.get("ResponseTemplates", {}),
                status_code=response.get("StatusCode", None),
            )

        if gateway_responses:
            swagger_editor.add_gateway_responses(gateway_responses)

        # Assign the Swagger back to template
        self.definition_body = swagger_editor.swagger
    def on_before_transform_template(self, template_dict):
        """
        Hook method that gets called before the SAM template is processed.
        The template has passed the validation and is guaranteed to contain a non-empty "Resources" section.

        Every Api and HttpApi resource without a ``DefinitionBody`` or ``DefinitionUri`` receives a
        skeleton definition body (Swagger for Api, OpenAPI for HttpApi). It is then flagged with
        ``__MANAGE_SWAGGER`` so later steps know SAM owns that definition.

        :param dict template_dict: Dictionary of the SAM template
        :return: Nothing
        """
        template = SamTemplate(template_dict)

        for api_type in [SamResourceType.Api.value, SamResourceType.HttpApi.value]:
            for logical_id, api in template.iterate({api_type}):
                if api.properties.get("DefinitionBody") or api.properties.get("DefinitionUri"):
                    continue

                # Compare type strings by value; `is` relies on object identity, which is not guaranteed.
                if api_type == SamResourceType.HttpApi.value:
                    # If "Properties" is not set in the template, set them here
                    if not api.properties:
                        template.set(logical_id, api)
                    api.properties["DefinitionBody"] = OpenApiEditor.gen_skeleton()
                elif api_type == SamResourceType.Api.value:
                    api.properties["DefinitionBody"] = SwaggerEditor.gen_skeleton()

                api.properties["__MANAGE_SWAGGER"] = True
    def test_allow_headers_is_skipped_with_no_value(self):
        """No Access-Control-Allow-Headers entry is produced when headers is None."""
        headers = None  # no value
        methods, origins = "methods", "origins"
        allow_credentials = True

        expected = {}
        expected["method.response.header.Access-Control-Allow-Credentials"] = _ALLOW_CREDENTALS_TRUE
        expected["method.response.header.Access-Control-Allow-Methods"] = methods
        expected["method.response.header.Access-Control-Allow-Origin"] = origins

        string_type = lambda: {"type": "string"}
        expected_headers = {
            "Access-Control-Allow-Credentials": string_type(),
            "Access-Control-Allow-Methods": string_type(),
            "Access-Control-Allow-Origin": string_type(),
        }

        skeleton_editor = SwaggerEditor(SwaggerEditor.gen_skeleton())
        options_config = skeleton_editor._options_method_response_for_cors(
            origins, headers, methods, allow_credentials=allow_credentials)

        integration_responses = options_config[_X_INTEGRATION]["responses"]
        self.assertEqual(expected, integration_responses["default"]["responseParameters"])
        self.assertEqual(expected_headers, options_config["responses"]["200"]["headers"])
    def setUp(self):
        paths = {
            "/foo": {},
            "/withoptions": {"options": {"some": "value"}},
            "/bad": "some value",
        }
        self.original_swagger = {"swagger": "2.0", "paths": paths}
        self.editor = SwaggerEditor(self.original_swagger)
    def test_must_succeed_on_valid_openapi3(self):
        editor = SwaggerEditor({"openapi": "3.0.1", "paths": {"/foo": {}, "/bar": {}}})
        self.assertIsNotNone(editor)

        expected_paths = {"/foo": {}, "/bar": {}}
        self.assertEqual(editor.paths, expected_paths)
 def setUp(self):
     # Paths with mixed-case method names, the _X_ANY_METHOD key, and no methods at all
     paths = {
         "/foo": {"get": {}, "POST": {}, "DeLeTe": {}},
         "/withany": {"head": {}, _X_ANY_METHOD: {}},
         "/nothing": {},
     }
     self.editor = SwaggerEditor({"swagger": "2.0", "paths": paths})
    def setUp(self):
        foo_item = {"get": {"a": "b"}}
        self.original_swagger = {
            "swagger": "2.0",
            "paths": {"/foo": foo_item, "/bar": {}, "/badpath": "string value"},
        }
        self.editor = SwaggerEditor(self.original_swagger)
    def test_must_not_add_parameter_to_method_without_integration(self):
        editor = SwaggerEditor({"swagger": "2.0", "paths": {"/foo": {'get': {}}}})
        authorization_header = {
            'Name': 'method.request.header.Authorization',
            'Required': True,
            'Caching': True,
        }

        editor.add_request_parameters_to_method('/foo', 'get', [authorization_header])

        # Without an integration the operation must be left empty
        self.assertEqual({}, editor.swagger['paths']['/foo']['get'])
Example #26
0
    def _get_permission(self, resources_to_link, stage, suffix):
        """Build the Lambda permission allowing this API path/method on ``stage`` to invoke the function.

        :param dict resources_to_link: linked resources; ``resources_to_link['function']`` is the target
        :param stage: stage name substituted into the source ARN
        :param suffix: suffix passed through to ``_construct_permission``
        :raises RuntimeError: if ``stage`` or ``suffix`` is empty
        """
        # API Gateway strips trailing slashes from paths (#665), which isn't documented, so drop a
        # trailing slash here to keep the permission's source ARN matching.
        trimmed_path = re.sub(r'^(.+)/$', r'\1', self.Path)

        if not (stage and suffix):
            raise RuntimeError("Could not add permission to lambda function.")

        trimmed_path = SwaggerEditor.get_path_without_trailing_slash(trimmed_path)
        http_method = self.Method.upper() if self.Method.lower() != 'any' else '*'

        api_id = self.RestApiId

        # RestApiId may be a plain string or an intrinsic such as !Ref; Fn::Sub handles both cases.
        arn_resource = '${__ApiId__}/${__Stage__}/' + http_method + trimmed_path
        partition = ArnGenerator.get_partition_name()
        unsubstituted_arn = ArnGenerator.generate_arn(
            partition=partition, service='execute-api', resource=arn_resource)
        source_arn = fnSub(unsubstituted_arn, {"__ApiId__": api_id, "__Stage__": stage})

        return self._construct_permission(
            resources_to_link['function'], source_arn=source_arn, suffix=suffix)
class TestSwaggerEditor_is_valid(TestCase):
    """Cases for SwaggerEditor.is_valid: documents with `swagger` and dict `paths` pass, others fail."""

    @parameterized.expand([
        param(SwaggerEditor.gen_skeleton()),
        # Extra, unrecognized top-level properties are allowed
        param({"swagger": "anyvalue", "paths": {}, "foo": "bar", "baz": "bar"}),
    ])
    def test_must_work_on_valid_values(self, swagger):
        self.assertTrue(SwaggerEditor.is_valid(swagger))

    @parameterized.expand([
        ({}, "empty dictionary"),
        ([1, 2, 3], "array data type"),
        ({"paths": {}}, "missing swagger property"),
        ({"swagger": "hello"}, "missing paths property"),
        ({"swagger": "hello", "paths": [1, 2, 3]}, "array value for paths property"),
    ])
    def test_must_fail_for_invalid_values(self, data, case):
        failure_message = "Swagger dictionary with {} must not be valid".format(case)
        self.assertFalse(SwaggerEditor.is_valid(data), failure_message)
class TestSwaggerEditor_make_cors_allowed_methods_for_path(TestCase):
    """Checks how the editor derives the CORS allowed-methods string for a path."""

    def setUp(self):
        # Method keys deliberately use mixed case; "/withany" uses the ANY method key
        paths = {
            "/foo": {"get": {}, "POST": {}, "DeLeTe": {}},
            "/withany": {"head": {}, _X_ANY_METHOD: {}},
            "/nothing": {},
        }
        self.editor = SwaggerEditor({"swagger": "2.0", "paths": paths})

    def _assert_allowed_methods(self, path, expected):
        """Compare the computed allowed-methods string for ``path`` with ``expected``."""
        self.assertEqual(expected, self.editor._make_cors_allowed_methods_for_path(path))

    def test_must_return_all_defined_methods(self):
        # Result is upper-cased, sorted alphabetically and includes OPTIONS
        self._assert_allowed_methods("/foo", "DELETE,GET,OPTIONS,POST")

    def test_must_work_for_any_method(self):
        # The ANY key expands to the full method list (sorted alphabetically)
        self._assert_allowed_methods("/withany", "DELETE,GET,HEAD,OPTIONS,PATCH,POST,PUT")

    def test_must_work_with_no_methods(self):
        self._assert_allowed_methods("/nothing", "OPTIONS")

    def test_must_skip_non_existent_path(self):
        self._assert_allowed_methods("/no-path", "")
    def _add_binary_media_types(self):
        """
        Register the configured binary media types on the inline Swagger document.

        This is a no-op when no binary media types are configured or when there is no
        'DefinitionBody' to edit. The missing-body case is deliberately silent rather than an
        error, because raising would be backwards incompatible.
        """
        # Both pieces are required; otherwise silently skip (see docstring for why no error)
        if not (self.binary_media and self.definition_body):
            return

        swagger_editor = SwaggerEditor(self.definition_body)
        swagger_editor.add_binary_media_types(self.binary_media)

        # Write the edited Swagger back onto the resource
        self.definition_body = swagger_editor.swagger
    def setUp(self):
        foo_methods = {"get": {}, "somemethod": {}}
        bar_methods = {"post": {}, _X_ANY_METHOD: {}}
        self.swagger = {
            "swagger": "2.0",
            "paths": {
                "/foo": foo_methods,
                "/bar": bar_methods,
                # Deliberately malformed: the path value is a string, not a method mapping
                "badpath": "string value",
            },
        }
        self.editor = SwaggerEditor(self.swagger)
class TestSwaggerEditor_make_cors_allowed_methods_for_path(TestCase):
    """Verifies the comma-separated method list the editor builds for CORS preflight responses."""

    def setUp(self):
        document = {"swagger": "2.0", "paths": {}}
        document["paths"]["/foo"] = {"get": {}, "POST": {}, "DeLeTe": {}}
        document["paths"]["/withany"] = {"head": {}, _X_ANY_METHOD: {}}
        document["paths"]["/nothing"] = {}
        self.editor = SwaggerEditor(document)

    def test_must_return_all_defined_methods(self):
        allowed = self.editor._make_cors_allowed_methods_for_path("/foo")
        # Sorted alphabetically and upper-cased, with OPTIONS included
        self.assertEqual("DELETE,GET,OPTIONS,POST", allowed)

    def test_must_work_for_any_method(self):
        allowed = self.editor._make_cors_allowed_methods_for_path("/withany")
        # ANY expands to every method in the list below, sorted alphabetically
        self.assertEqual("DELETE,GET,HEAD,OPTIONS,PATCH,POST,PUT", allowed)

    def test_must_work_with_no_methods(self):
        allowed = self.editor._make_cors_allowed_methods_for_path("/nothing")
        self.assertEqual("OPTIONS", allowed)

    def test_must_skip_non_existent_path(self):
        allowed = self.editor._make_cors_allowed_methods_for_path("/no-path")
        self.assertEqual("", allowed)
class TestSwaggerEditor_add_path(TestCase):
    """Tests for ``SwaggerEditor.add_path``."""

    def setUp(self):

        self.original_swagger = {
            "swagger": "2.0",
            "paths": {
                "/foo": {
                    "get": {
                        "a": "b"
                    }
                },
                # Existing path without any methods
                "/bar": {},
                # Malformed: the path value is not a dict of methods
                "/badpath": "string value"
            }
        }

        self.editor = SwaggerEditor(self.original_swagger)

    @parameterized.expand([
        param("/new", "get", "new path, new method"),
        param("/foo", "new method", "existing path, new method"),
        param("/bar", "get", "existing path, new method"),
    ])
    def test_must_add_new_path_and_method(self, path, method, case):
        # A missing path/method pair must be created with an empty definition

        self.assertFalse(self.editor.has_path(path, method))
        self.editor.add_path(path, method)

        self.assertTrue(self.editor.has_path(path, method),
                        "must add for " + case)
        self.assertEqual(self.editor.swagger["paths"][path][method], {})

    def test_must_raise_non_dict_path_values(self):
        """Adding a method under a path whose value is not a dict must raise."""

        path = "/badpath"
        method = "get"

        with self.assertRaises(InvalidDocumentException):
            self.editor.add_path(path, method)

    def test_must_skip_existing_path(self):
        """
        Given an existing path/method, add_path must leave the existing method
        definition unchanged.
        """

        path = "/foo"
        method = "get"
        original_value = copy.deepcopy(
            self.original_swagger["paths"][path][method])

        self.editor.add_path(path, method)
        modified_swagger = self.editor.swagger
        self.assertEqual(original_value,
                         modified_swagger["paths"][path][method])
    def test_allow_origins_is_not_skipped_with_no_value(self):
        # Every optional CORS input is unset / falsy
        origins = headers = methods = None
        allow_credentials = False

        editor = SwaggerEditor(SwaggerEditor.gen_skeleton())
        options_config = editor._options_method_response_for_cors(
            origins, headers, methods, allow_credentials=allow_credentials)
        default_response = options_config[_X_INTEGRATION]["responses"]["default"]

        # Allow-Origin is ALWAYS emitted (here with a None value): it is the minimum CORS requirement
        self.assertEqual(
            {"method.response.header.Access-Control-Allow-Origin": origins},
            default_response["responseParameters"])
    def setUp(self):
        def with_integration():
            # Build a fresh dict on every call so no two methods alias the same object
            return {_X_INTEGRATION: {"a": "b"}}

        def no_value():
            return {"Ref": "AWS::NoValue"}

        foo_methods = {
            "get": with_integration(),
            # Integration sits in the first (condition-true) branch of Fn::If
            "post": {"Fn::If": ["Condition", with_integration(), no_value()]},
            # Integration sits in the second (condition-false) branch of Fn::If
            "delete": {"Fn::If": ["Condition", no_value(), with_integration()]},
            "somemethod": {"foo": "value"},
            "emptyintegration": {_X_INTEGRATION: {}},
            "badmethod": "string value",
        }
        self.swagger = {"swagger": "2.0", "paths": {"/foo": foo_methods}}

        self.editor = SwaggerEditor(self.swagger)
    def test_correct_value_is_returned(self):
        """The full OPTIONS method definition must echo every supplied CORS value."""
        self.maxDiff = None
        headers = "foo"
        methods = {"a": "b"}
        origins = [1, 2, 3]
        max_age = 60

        expected = {
            "summary": "CORS support",
            "consumes": ["application/json"],
            "produces": ["application/json"],
            _X_INTEGRATION: {
                "type": "mock",
                "requestTemplates": {
                    "application/json": "{\n  \"statusCode\" : 200\n}\n"
                },
                "responses": {
                    "default": {
                        "statusCode": "200",
                        "responseParameters": {
                            "method.response.header.Access-Control-Allow-Headers": headers,
                            "method.response.header.Access-Control-Allow-Methods": methods,
                            "method.response.header.Access-Control-Allow-Origin": origins,
                            "method.response.header.Access-Control-Max-Age": max_age
                        },
                        "responseTemplates": {
                            "application/json": "{}\n"
                        }
                    }
                }
            },
            "responses": {
                "200": {
                    "description": "Default response for CORS method",
                    "headers": {
                        "Access-Control-Allow-Headers": {
                            "type": "string"
                        },
                        "Access-Control-Allow-Methods": {
                            "type": "string"
                        },
                        "Access-Control-Allow-Origin": {
                            "type": "string"
                        },
                        "Access-Control-Max-Age": {
                            "type": "integer"
                        }
                    }
                }
            }
        }

        actual = SwaggerEditor(SwaggerEditor.gen_skeleton())._options_method_response_for_cors(
            origins, headers, methods, max_age)
        # assertEquals is a deprecated alias that was removed in Python 3.12
        self.assertEqual(expected, actual)
class TestSwaggerEditor_add_path(TestCase):
    """Tests for ``SwaggerEditor.add_path``."""

    def setUp(self):

        self.original_swagger = {
            "swagger": "2.0",
            "paths": {
                "/foo": {
                    "get": {"a": "b"}
                },
                # Existing path without any methods
                "/bar": {},
                # Malformed: the path value is not a dict of methods
                "/badpath": "string value"
            }
        }

        self.editor = SwaggerEditor(self.original_swagger)

    @parameterized.expand([
        param("/new", "get", "new path, new method"),
        param("/foo", "new method", "existing path, new method"),
        param("/bar", "get", "existing path, new method"),
    ])
    def test_must_add_new_path_and_method(self, path, method, case):
        # A missing path/method pair must be created with an empty definition

        self.assertFalse(self.editor.has_path(path, method))
        self.editor.add_path(path, method)

        self.assertTrue(self.editor.has_path(path, method), "must add for "+case)
        self.assertEqual(self.editor.swagger["paths"][path][method], {})

    def test_must_raise_non_dict_path_values(self):
        """Adding a method under a path whose value is not a dict must raise."""

        path = "/badpath"
        method = "get"

        with self.assertRaises(InvalidDocumentException):
            self.editor.add_path(path, method)

    def test_must_skip_existing_path(self):
        """
        Given an existing path/method, add_path must leave the existing method
        definition unchanged.
        """

        path = "/foo"
        method = "get"
        original_value = copy.deepcopy(self.original_swagger["paths"][path][method])

        self.editor.add_path(path, method)
        modified_swagger = self.editor.swagger
        self.assertEqual(original_value, modified_swagger["paths"][path][method])
    def setUp(self):
        # "/baz" maps to a plain string instead of a method mapping
        self.original_swagger = dict(
            swagger="2.0",
            paths={"/foo": {}, "/bar": {}, "/baz": "some value"},
        )
        self.editor = SwaggerEditor(self.original_swagger)
    def _add_cors(self):
        """
        Add a CORS configuration to every path of the inline Swagger document, if 'Cors' is set.

        'Cors' may be a string or an intrinsic function, which is used as AllowOrigin with all other
        settings left at their defaults. It may also be a dict whose keys must all be CorsProperties fields.

        :raises InvalidResourceException: if 'Cors' is set but there is no 'DefinitionBody', if 'Cors'
            has an unsupported type or unknown keys, or if 'DefinitionBody' is not valid Swagger
        """

        INVALID_ERROR = "Invalid value for 'Cors' property"

        cors = self.cors
        if not cors:
            return

        if not self.definition_body:
            raise InvalidResourceException(self.logical_id,
                                           "Cors works only with inline Swagger specified in "
                                           "'DefinitionBody' property")

        if isinstance(cors, string_types) or is_instrinsic(cors):
            # Only the origin is given; every other CORS setting keeps its default
            properties = CorsProperties(AllowOrigin=cors)
        elif isinstance(cors, dict):
            unknown_keys = [key for key in cors if key not in CorsProperties._fields]
            if unknown_keys:
                raise InvalidResourceException(self.logical_id, INVALID_ERROR)
            properties = CorsProperties(**cors)
        else:
            raise InvalidResourceException(self.logical_id, INVALID_ERROR)

        if not SwaggerEditor.is_valid(self.definition_body):
            raise InvalidResourceException(self.logical_id, "Unable to add Cors configuration because "
                                                            "'DefinitionBody' does not contain a valid Swagger")

        swagger_editor = SwaggerEditor(self.definition_body)
        for swagger_path in swagger_editor.iter_on_path():
            swagger_editor.add_cors(swagger_path, properties.AllowOrigin, properties.AllowHeaders,
                                    properties.AllowMethods, max_age=properties.MaxAge)

        # Write the edited Swagger back onto the resource
        self.definition_body = swagger_editor.swagger
    def setUp(self):
        paths = {}
        paths["/foo"] = {}
        # Path that already defines an OPTIONS method
        paths["/withoptions"] = {"options": {"some": "value"}}
        # Malformed: the path value is a string, not a method mapping
        paths["/bad"] = "some value"

        self.original_swagger = {"swagger": "2.0", "paths": paths}
        self.editor = SwaggerEditor(self.original_swagger)
    def setUp(self):
        self.original_swagger = dict(
            swagger="2.0",
            paths={
                "/foo": {"get": {"a": "b"}},
                "/bar": {},  # path without any methods
                "/badpath": "string value",  # malformed: not a method mapping
            },
        )

        self.editor = SwaggerEditor(self.original_swagger)
    def test_allow_methods_is_skipped_with_no_value(self):
        """A None AllowMethods must not produce an Allow-Methods response parameter."""
        headers = "headers"
        methods = None  # No value
        origins = "origins"

        expected = {
            "method.response.header.Access-Control-Allow-Headers": headers,
            "method.response.header.Access-Control-Allow-Origin": origins
        }

        options_config = SwaggerEditor(SwaggerEditor.gen_skeleton())._options_method_response_for_cors(
            origins, headers, methods)

        actual = options_config[_X_INTEGRATION]["responses"]["default"]["responseParameters"]
        # assertEquals is a deprecated alias that was removed in Python 3.12
        self.assertEqual(expected, actual)
    def test_allow_origins_is_not_skipped_with_no_value(self):
        """Allow-Origin must be emitted even when every CORS input is None."""
        headers = None
        methods = None
        origins = None

        expected = {
            # We will ALWAYS set AllowOrigin. This is a minimum requirement for CORS
            "method.response.header.Access-Control-Allow-Origin": origins
        }

        options_config = SwaggerEditor(SwaggerEditor.gen_skeleton())._options_method_response_for_cors(
            origins, headers, methods)

        actual = options_config[_X_INTEGRATION]["responses"]["default"]["responseParameters"]
        # assertEquals is a deprecated alias that was removed in Python 3.12
        self.assertEqual(expected, actual)
class TestSwaggerEditor_has_integration(TestCase):
    """Tests for ``SwaggerEditor.has_integration``."""

    def setUp(self):
        methods = {
            "get": {_X_INTEGRATION: {"a": "b"}},
            "somemethod": {"foo": "value"},
            # Integration key present but empty: must not count as an integration
            "emptyintegration": {_X_INTEGRATION: {}},
            # Malformed: the method value is not a dict
            "badmethod": "string value",
        }
        self.swagger = {"swagger": "2.0", "paths": {"/foo": methods}}

        self.editor = SwaggerEditor(self.swagger)

    def _has(self, method):
        """Shortcut for checking an integration on the "/foo" path."""
        return self.editor.has_integration("/foo", method)

    def test_must_find_integration(self):
        self.assertTrue(self._has("get"))

    def test_must_not_find_integration(self):
        self.assertFalse(self._has("somemethod"))

    def test_must_not_find_empty_integration(self):
        self.assertFalse(self._has("emptyintegration"))

    def test_must_handle_bad_value_for_method(self):
        self.assertFalse(self._has("badmethod"))
    def _add_swagger_integration(self, api, function):
        """Adds the path and method for this Api event source to the Swagger body of the given API.

        Nothing happens when the API has no 'DefinitionBody'. Otherwise a Lambda integration, and optionally
        the event's Auth, is added for this event's Path/Method, and the edited Swagger is written back to
        ``api["DefinitionBody"]``.

        :param dict api: mapping for the API targeted by this event (presumably the API resource's
            Properties - TODO confirm). 'DefinitionBody' and 'Auth' are read from it, and 'DefinitionBody'
            is replaced.
        :param function: the function resource to integrate. Its 'arn' runtime attribute is used to build
            the Lambda invocation URI.
        :raises InvalidEventException: if an integration already exists for this path and method; if the
            event sets an Authorizer while the API defines no Authorizers; if that Authorizer is not among
            the API's Authorizers; or if it is 'NONE' while the API has no DefaultAuthorizer.
        """
        swagger_body = api.get("DefinitionBody")
        if swagger_body is None:
            return

        function_arn = function.get_runtime_attr('arn')
        partition = ArnGenerator.get_partition_name()
        uri = fnSub('arn:'+partition+':apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/'
                    + make_shorthand(function_arn) + '/invocations')

        editor = SwaggerEditor(swagger_body)

        if editor.has_integration(self.Path, self.Method):
            # Cannot add the Lambda Integration, if it is already present
            raise InvalidEventException(
                self.relative_id,
                'API method "{method}" defined multiple times for path "{path}".'.format(
                    method=self.Method, path=self.Path))

        editor.add_lambda_integration(self.Path, self.Method, uri)

        if self.Auth:
            method_authorizer = self.Auth.get('Authorizer')

            if method_authorizer:
                api_auth = api.get('Auth')
                api_authorizers = api_auth and api_auth.get('Authorizers')

                if not api_authorizers:
                    raise InvalidEventException(
                        self.relative_id,
                        'Unable to set Authorizer [{authorizer}] on API method [{method}] for path [{path}] because '
                        'the related API does not define any Authorizers.'.format(
                            authorizer=method_authorizer, method=self.Method, path=self.Path))

                if method_authorizer != 'NONE' and not api_authorizers.get(method_authorizer):
                    raise InvalidEventException(
                        self.relative_id,
                        'Unable to set Authorizer [{authorizer}] on API method [{method}] for path [{path}] because it '
                        'wasn\'t defined in the API\'s Authorizers.'.format(
                            authorizer=method_authorizer, method=self.Method, path=self.Path))

                # api_auth is non-empty here, because api_authorizers was truthy above
                if method_authorizer == 'NONE' and not api_auth.get('DefaultAuthorizer'):
                    raise InvalidEventException(
                        self.relative_id,
                        'Unable to set Authorizer on API method [{method}] for path [{path}] because \'NONE\' '
                        'is only a valid value when a DefaultAuthorizer on the API is specified.'.format(
                            method=self.Method, path=self.Path))

            editor.add_auth_to_method(api=api, path=self.Path, method_name=self.Method, auth=self.Auth)

        api["DefinitionBody"] = editor.swagger
    def on_before_transform_template(self, template_dict):
        """
        Give each SAM Api resource that has no Swagger definition an empty, plugin-managed one.

        Runs before the SAM template is processed. The template has already passed validation and is
        guaranteed to contain a non-empty "Resources" section.

        :param dict template_dict: Dictionary of the SAM template
        :return: Nothing
        """
        sam_template = SamTemplate(template_dict)

        for _, api in sam_template.iterate(SamResourceType.Api.value):
            has_definition = api.properties.get('DefinitionBody') or api.properties.get('DefinitionUri')
            if has_definition:
                # A user-supplied definition (inline or by URI) is left alone
                continue

            api.properties['DefinitionBody'] = SwaggerEditor.gen_skeleton()
            # Flag the Swagger as generated here rather than written by the user
            api.properties['__MANAGE_SWAGGER'] = True
 def setUp(self):
     # Method names deliberately use mixed case; "/withany" uses the ANY method key
     foo_methods = dict(get={}, POST={}, DeLeTe={})
     withany_methods = {"head": {}, _X_ANY_METHOD: {}}
     self.editor = SwaggerEditor({
         "swagger": "2.0",
         "paths": {"/foo": foo_methods, "/withany": withany_methods, "/nothing": {}},
     })
    def test_max_age_can_be_set_to_zero(self):
        """A MaxAge of 0 (falsy) must still be emitted; a None AllowHeaders is omitted."""
        headers = None
        methods = "methods"
        origins = "origins"
        max_age = 0

        expected = {
            "method.response.header.Access-Control-Allow-Methods": methods,
            "method.response.header.Access-Control-Allow-Origin": origins,
            "method.response.header.Access-Control-Max-Age": max_age
        }

        options_config = SwaggerEditor(SwaggerEditor.gen_skeleton())._options_method_response_for_cors(
            origins, headers, methods, max_age)

        actual = options_config[_X_INTEGRATION]["responses"]["default"]["responseParameters"]
        # assertEquals is a deprecated alias that was removed in Python 3.12
        self.assertEqual(expected, actual)
    def setUp(self):
        self.swagger = {"swagger": "2.0", "paths": {}}
        swagger_paths = self.swagger["paths"]
        swagger_paths["/foo"] = {"get": {}, "somemethod": {}}
        swagger_paths["/bar"] = {"post": {}, _X_ANY_METHOD: {}}
        # Malformed entry: a string instead of a method mapping
        swagger_paths["badpath"] = "string value"

        self.editor = SwaggerEditor(self.swagger)
    def test_allow_credentials_is_skipped_with_false_value(self):
        origins, headers, methods = "origins", "headers", "methods"

        editor = SwaggerEditor(SwaggerEditor.gen_skeleton())
        options_config = editor._options_method_response_for_cors(
            origins, headers, methods, allow_credentials=False)
        response_parameters = options_config[_X_INTEGRATION]["responses"]["default"]["responseParameters"]

        # With allow_credentials=False, only these three parameters must be present
        self.assertEqual(
            {
                "method.response.header.Access-Control-Allow-Headers": headers,
                "method.response.header.Access-Control-Allow-Methods": methods,
                "method.response.header.Access-Control-Allow-Origin": origins,
            },
            response_parameters,
        )
    def setUp(self):
        # "post" carries the integration in the first Fn::If branch, "delete" in the second
        conditional_post = {"Fn::If": ["Condition", {_X_INTEGRATION: {"a": "b"}}, {"Ref": "AWS::NoValue"}]}
        conditional_delete = {"Fn::If": ["Condition", {"Ref": "AWS::NoValue"}, {_X_INTEGRATION: {"a": "b"}}]}

        self.swagger = {
            "swagger": "2.0",
            "paths": {
                "/foo": {
                    "get": {_X_INTEGRATION: {"a": "b"}},
                    "post": conditional_post,
                    "delete": conditional_delete,
                    "somemethod": {"foo": "value"},
                    "emptyintegration": {_X_INTEGRATION: {}},
                    "badmethod": "string value",
                },
            },
        }

        self.editor = SwaggerEditor(self.swagger)
class TestSwaggerEditor_iter_on_path(TestCase):
    """Tests for ``SwaggerEditor.iter_on_path``."""

    def setUp(self):

        self.original_swagger = {
            "swagger": "2.0",
            "paths": {
                "/foo": {},
                "/bar": {},
                # Not a method mapping, but still a path key
                "/baz": "some value"
            }
        }

        self.editor = SwaggerEditor(self.original_swagger)

    def test_must_iterate_on_paths(self):
        # Every key under "paths" is yielded, including one whose value is not a dict
        expected = {"/foo", "/bar", "/baz"}
        # Build the set straight from the iterator; no intermediate list comprehension needed
        actual = set(self.editor.iter_on_path())

        self.assertEqual(expected, actual)
    def setUp(self):
        self.swagger = {"swagger": "2.0", "paths": {}}
        self.swagger["paths"]["/foo"] = {
            "get": {_X_INTEGRATION: {"a": "b"}},
            "somemethod": {"foo": "value"},
            # Integration key present but with an empty value
            "emptyintegration": {_X_INTEGRATION: {}},
            # Malformed: the method value is not a dict
            "badmethod": "string value",
        }

        self.editor = SwaggerEditor(self.swagger)
    def setUp(self):
        # Existing method with data but no integration yet
        existing_post = {"a": [1, 2, "b"], "responses": {"something": "is already here"}}
        # Existing method that already has an integration
        integrated_get = {_X_INTEGRATION: {"a": "b"}}

        self.original_swagger = {
            "swagger": "2.0",
            "paths": {"/foo": {"post": existing_post}, "/bar": {"get": integrated_get}},
        }

        self.editor = SwaggerEditor(self.original_swagger)
class TestSwaggerEditor_add_lambda_integration(TestCase):
    """Tests for ``SwaggerEditor.add_lambda_integration``."""

    def setUp(self):
        post_method = {
            "a": [1, 2, "b"],
            "responses": {"something": "is already here"},
        }
        # "/bar" GET already carries an integration
        get_method = {_X_INTEGRATION: {"a": "b"}}

        self.original_swagger = {
            "swagger": "2.0",
            "paths": {
                "/foo": {"post": post_method},
                "/bar": {"get": get_method},
            },
        }

        self.editor = SwaggerEditor(self.original_swagger)

    @staticmethod
    def _proxy_integration(uri):
        """Expected integration object for an aws_proxy Lambda integration pointing at ``uri``."""
        return {"type": "aws_proxy", "httpMethod": "POST", "uri": uri}

    def test_must_add_new_integration_to_new_path(self):
        route, verb, uri = "/newpath", "get", "something"

        self.editor.add_lambda_integration(route, verb, uri)

        self.assertTrue(self.editor.has_path(route, verb))
        self.assertEqual({"responses": {}, _X_INTEGRATION: self._proxy_integration(uri)},
                         self.editor.swagger["paths"][route][verb])

    def test_must_add_new_integration_with_conditions_to_new_path(self):
        route, verb, uri = "/newpath", "get", "something"
        condition = "condition"
        no_value = {"Ref": "AWS::NoValue"}
        # Both the method definition and its URI end up wrapped in Fn::If on the condition
        conditional_uri = {"Fn::If": ["condition", uri, no_value]}
        expected = {
            "Fn::If": [
                "condition",
                {"responses": {}, _X_INTEGRATION: self._proxy_integration(conditional_uri)},
                no_value,
            ]
        }

        self.editor.add_lambda_integration(route, verb, uri, condition=condition)

        self.assertTrue(self.editor.has_path(route, verb))
        self.assertEqual(expected, self.editor.swagger["paths"][route][verb])

    def test_must_add_new_integration_to_existing_path(self):
        route, verb, uri = "/foo", "post", "something"
        expected = {
            # Keys already on the method *MUST* be preserved, "responses" included
            "a": [1, 2, "b"],
            "responses": {"something": "is already here"},
            # ...and only the integration is added
            _X_INTEGRATION: self._proxy_integration(uri),
        }

        # Sanity check: the test really runs against an existing path/method
        self.assertTrue(self.editor.has_path(route, verb))

        self.editor.add_lambda_integration(route, verb, uri)

        self.assertEqual(expected, self.editor.swagger["paths"][route][verb])

    def test_must_raise_on_existing_integration(self):
        self.assertRaises(ValueError, self.editor.add_lambda_integration, "/bar", "get", "integrationUri")

    def test_must_add_credentials_to_the_integration(self):
        route, verb, uri = "/newpath", "get", "something"
        api_auth_config = {"DefaultAuthorizer": "AWS_IAM", "InvokeRole": "CALLER_CREDENTIALS"}

        self.editor.add_lambda_integration(route, verb, uri, None, api_auth_config)

        credentials = self.editor.swagger["paths"][route][verb][_X_INTEGRATION]['credentials']
        # AWS_IAM with CALLER_CREDENTIALS is expected to yield the wildcard IAM user ARN
        self.assertEqual('arn:aws:iam::*:user/*', credentials)

    def test_must_add_credentials_to_the_integration_overrides(self):
        route, verb, uri = "/newpath", "get", "something"
        invoke_role = 'arn:aws:iam::*:role/xxxxxx'
        api_auth_config = {"DefaultAuthorizer": "MyAuth"}
        method_auth_config = {"Authorizer": "AWS_IAM", "InvokeRole": invoke_role}

        self.editor.add_lambda_integration(route, verb, uri, method_auth_config, api_auth_config)

        credentials = self.editor.swagger["paths"][route][verb][_X_INTEGRATION]['credentials']
        # The method-level InvokeRole is used verbatim
        self.assertEqual(invoke_role, credentials)
class TestSwaggerEditor_has_path(TestCase):
    """Tests for ``SwaggerEditor.has_path`` lookups by path and optional method."""

    def setUp(self):
        self.swagger = {
            "swagger": "2.0",
            "paths": {
                "/foo": {"get": {}, "somemethod": {}},
                "/bar": {"post": {}, _X_ANY_METHOD: {}},
                # Malformed entry: not a method mapping
                "badpath": "string value",
            },
        }

        self.editor = SwaggerEditor(self.swagger)

    def test_must_find_path_and_method(self):
        present = [("/foo",), ("/foo", "get"), ("/foo", "somemethod"), ("/bar",), ("/bar", "post")]
        for lookup in present:
            self.assertTrue(self.editor.has_path(*lookup))

    def test_must_find_with_method_case_insensitive(self):
        self.assertTrue(self.editor.has_path("/foo", "GeT"))
        self.assertTrue(self.editor.has_path("/bar", "POST"))

        # Method matching ignores case, but the path itself is case sensitive
        self.assertFalse(self.editor.has_path("/FOO"))

    def test_must_work_with_any_method(self):
        """
        "ANY" is a special method name: it must be mapped to the x-amazon style key before lookup.
        """
        for any_spelling in ("any", "AnY", _X_ANY_METHOD):
            self.assertTrue(self.editor.has_path("/bar", any_spelling))
        self.assertFalse(self.editor.has_path("/foo", "any"))

    def test_must_not_find_path(self):
        for missing_path in ("/foo/other", "/bar/xyz", "/abc"):
            self.assertFalse(self.editor.has_path(missing_path))

    def test_must_not_find_path_and_method(self):
        for route, verb in (("/foo", "post"), ("/foo", "abc"), ("/bar", "get"), ("/bar", "xyz")):
            self.assertFalse(self.editor.has_path(route, verb))

    def test_must_not_fail_on_bad_path(self):
        # The malformed path is reported as present, but no method is found under it
        self.assertTrue(self.editor.has_path("badpath"))
        self.assertFalse(self.editor.has_path("badpath", "somemethod"))
class TestSwaggerEditor_add_cors(TestCase):
    """
    Tests for ``SwaggerEditor.add_cors``.

    Several tests replace the editor's options-response builder with a Mock. Those tests check which
    arguments ``add_cors`` forwards to the builder and where its result is stored, not the response body.
    """

    def setUp(self):
        self.original_swagger = {
            "swagger": "2.0",
            "paths": {
                "/foo": {},
                "/withoptions": {"options": {"some": "value"}},
                "/bad": "some value",
            },
        }
        self.editor = SwaggerEditor(self.original_swagger)

    def _stub_cors_response_builder(self, canned_response):
        """Swap the editor's options-response builder for a Mock returning ``canned_response``; return the Mock."""
        stub = Mock(return_value=canned_response)
        self.editor._options_method_response_for_cors = stub
        return stub

    def _options_of(self, path):
        """Return the "options" operation currently stored for ``path`` in the edited document."""
        return self.editor.swagger["paths"][path]["options"]

    def test_must_add_options_to_new_path(self):
        """A path without an "options" method receives the builder's output, built from the given settings."""
        path = "/foo"
        origins = "origins"
        headers = ["headers", "2"]
        methods = {"key": "methods"}
        cors_response = {"some cors": "return value"}
        builder = self._stub_cors_response_builder(cors_response)

        self.editor.add_cors(path, origins, headers, methods, 60, True)

        self.assertEqual(cors_response, self._options_of(path))
        builder.assert_called_with(origins, headers, methods, 60, True)

    def test_must_skip_existing_path(self):
        """An "options" method that already exists must be left untouched."""
        path = "/withoptions"
        before = copy.deepcopy(self.original_swagger["paths"][path]["options"])

        self.editor.add_cors(path, "origins", "headers", "methods")

        self.assertEqual(before, self._options_of(path))

    def test_must_fail_with_bad_values_for_path(self):
        """A path whose value is not a dictionary makes add_cors raise InvalidDocumentException."""
        with self.assertRaises(InvalidDocumentException):
            self.editor.add_cors("/bad", "origins", "headers", "methods")

    def test_must_fail_for_invalid_allowed_origin(self):
        """A missing (None) allowed origin is rejected with ValueError."""
        with self.assertRaises(ValueError):
            self.editor.add_cors("/foo", None, "headers", "methods")

    def test_must_work_for_optional_allowed_headers(self):
        """allowed_headers may be None; the None is passed straight through to the response builder."""
        path = "/foo"
        cors_response = {"some cors": "return value"}
        builder = self._stub_cors_response_builder(cors_response)

        self.editor.add_cors(path, "origins", None, "methods", 60, True)

        self.assertEqual(cors_response, self._options_of(path))
        builder.assert_called_with("origins", None, "methods", 60, True)

    def test_must_make_default_value_with_optional_allowed_methods(self):
        """
        When allowed_methods is None, the value produced by _make_cors_allowed_methods_for_path must reach
        the response builder wrapped in single quotes.
        """
        path = "/foo"
        default_methods = "some default value"
        cors_response = {"some cors": "return value"}
        self.editor._make_cors_allowed_methods_for_path = Mock(return_value=default_methods)
        builder = self._stub_cors_response_builder(cors_response)

        self.editor.add_cors(path, "origins", "headers", None, 60, True)

        self.assertEqual(cors_response, self._options_of(path))
        # Derived default must be forwarded quoted, i.e. 'some default value'
        builder.assert_called_with("origins", "headers", "'{}'".format(default_methods), 60, True)

    def test_must_accept_none_allow_credentials(self):
        """allow_credentials=None is accepted and forwarded to the builder as False."""
        path = "/foo"
        headers = ["headers", "2"]
        methods = {"key": "methods"}
        cors_response = {"some cors": "return value"}
        builder = self._stub_cors_response_builder(cors_response)

        self.editor.add_cors(path, "origins", headers, methods, 60, None)

        self.assertEqual(cors_response, self._options_of(path))
        builder.assert_called_with("origins", headers, methods, 60, False)
 def test_must_normalize(self, input, expected, msg):
     """
     Check that ``SwaggerEditor._normalize_method_name(input)`` equals ``expected``; ``msg`` labels a failure.

     NOTE(review): this def sits at a one-space indent directly after a class body, which Python rejects
     with an IndentationError. Its extra parameters suggest that a parameterized-test decorator and the
     enclosing TestCase class were lost when this file was assembled. Confirm against upstream and restore them.
     """
     self.assertEqual(expected, SwaggerEditor._normalize_method_name(input), msg)
 def test_must_work_on_valid_values(self, swagger):
     """
     ``SwaggerEditor.is_valid`` must accept the supplied ``swagger`` dictionary.

     NOTE(review): the one-space indent after a class body is an IndentationError, and the ``swagger``
     parameter implies a missing parameterized-test decorator and enclosing TestCase. Restore from upstream.
     """
     self.assertTrue(SwaggerEditor.is_valid(swagger))
 def test_must_fail_for_invalid_values(self, data, case):
     """
     ``SwaggerEditor.is_valid`` must reject ``data``; ``case`` describes the defect and is used in the failure message.

     NOTE(review): the one-space indent after a class body is an IndentationError, and the ``data``/``case``
     parameters imply a missing parameterized-test decorator and enclosing TestCase. Restore from upstream.
     """
     self.assertFalse(SwaggerEditor.is_valid(data), "Swagger dictionary with {} must not be valid".format(case))
class TestSwaggerEditor_add_lambda_integration(TestCase):

    def setUp(self):
        """
        Build an editor over a document with a POST on /foo that has no integration yet,
        and a GET on /bar that already carries one.
        """
        foo_post = {
            "a": [1, 2, "b"],
            "responses": {"something": "is already here"},
        }
        bar_get = {_X_INTEGRATION: {"a": "b"}}

        self.original_swagger = {
            "swagger": "2.0",
            "paths": {
                "/foo": {"post": foo_post},
                "/bar": {"get": bar_get},
            },
        }
        self.editor = SwaggerEditor(self.original_swagger)

    def test_must_add_new_integration_to_new_path(self):
        """
        Adding an integration to a path that does not exist yet must create the path/method pair.
        The new operation has an empty "responses" map and an aws_proxy integration that calls the URI via POST.
        """
        path = "/newpath"
        method = "get"
        integration_uri = "something"
        expected = {
            "responses": {},
            _X_INTEGRATION: {
                "type": "aws_proxy",
                "httpMethod": "POST",
                "uri": integration_uri
            }
        }

        self.editor.add_lambda_integration(path, method, integration_uri)

        self.assertTrue(self.editor.has_path(path, method))
        actual = self.editor.swagger["paths"][path][method]
        self.assertEqual(expected, actual)

    def test_must_add_new_integration_to_existing_path(self):
        """
        Adding an integration to an existing operation must keep every key already present.
        That includes the existing "responses" map; the integration entry is added alongside them.
        """
        path = "/foo"
        method = "post"
        integration_uri = "something"
        expected = {
            # Current values present in the dictionary *MUST* be preserved
            "a": [1, 2, "b"],

            # Responses key must be untouched
            "responses": {
                "something": "is already here"
            },

            # New values must be added
            _X_INTEGRATION: {
                "type": "aws_proxy",
                "httpMethod": "POST",
                "uri": integration_uri
            }
        }

        # Just make sure test is working on an existing path
        self.assertTrue(self.editor.has_path(path, method))

        self.editor.add_lambda_integration(path, method, integration_uri)

        actual = self.editor.swagger["paths"][path][method]
        self.assertEqual(expected, actual)

    def test_must_raise_on_existing_integration(self):
        """GET /bar already has an integration, so adding another must raise ValueError."""
        editor = self.editor
        with self.assertRaises(ValueError):
            editor.add_lambda_integration("/bar", "get", "integrationUri")