Ejemplo n.º 1
0
    def test_put_metric_data_gzip(self, cloudwatch_client):
        """Send a gzip-compressed PutMetricData query request and check that the metric is listed.

        The request is built by hand with ``Content-Encoding: GZIP`` and a Node.js SDK user agent,
        then the metric is looked up via the regular CloudWatch client.
        """
        metric_name = "test-metric"
        namespace = "namespace"
        data = ("Action=PutMetricData&MetricData.member.1."
                "MetricName=%s&MetricData.member.1.Value=1&"
                "Namespace=%s&Version=2010-08-01" % (metric_name, namespace))
        bytes_data = bytes(data, encoding="utf-8")
        encoded_data = gzip.compress(bytes_data)

        url = config.get_edge_url()
        headers = aws_stack.mock_aws_request_headers("cloudwatch")
        # CloudWatch's signing name is "monitoring"; take the Authorization header scoped to it
        authorization = aws_stack.mock_aws_request_headers(
            "monitoring")["Authorization"]

        headers.update({
            "Content-Type": "application/x-www-form-urlencoded; charset=utf-8",
            "Content-Length": len(encoded_data),
            "Content-Encoding": "GZIP",
            "User-Agent": "aws-sdk-nodejs/2.819.0 linux/v12.18.2 callback",
            "Authorization": authorization,
        })
        request = Request(url, encoded_data, headers, method="POST")
        # use the response as a context manager so the underlying connection is closed
        with urlopen(request) as response:
            assert response.status == 200

        rs = cloudwatch_client.list_metrics(Namespace=namespace,
                                            MetricName=metric_name)
        assert 1 == len(rs["Metrics"])
        assert namespace == rs["Metrics"][0]["Namespace"]
Ejemplo n.º 2
0
    def test_response_content_type(self):
        """STS answers in XML by default and in JSON when ``Accept: application/json`` is sent.

        Both representations must carry the same data once ``ResponseMetadata`` and XML
        namespaces are removed.
        """
        edge_url = config.get_edge_url()
        payload = {"Action": "GetCallerIdentity", "Version": "2011-06-15"}

        # default request: the body is XML and must not be parseable as JSON
        xml_response = requests.post(edge_url, data=payload,
                                     headers=aws_stack.mock_aws_request_headers("sts"))
        assert xml_response
        xml_text = to_str(xml_response.content)
        with pytest.raises(json.decoder.JSONDecodeError):
            json.loads(xml_text)
        xml_doc = xmltodict.parse(xml_text)
        xml_result = xml_doc["GetCallerIdentityResponse"]["GetCallerIdentityResult"]
        assert xml_result["Account"] == TEST_AWS_ACCOUNT_ID

        # same request, explicitly asking for JSON
        json_headers = aws_stack.mock_aws_request_headers("sts")
        json_headers["Accept"] = APPLICATION_JSON
        json_response = requests.post(edge_url, data=payload, headers=json_headers)
        assert json_response
        json_doc = json.loads(to_str(json_response.content))
        json_result = json_doc["GetCallerIdentityResponse"]["GetCallerIdentityResult"]
        assert json_result["Account"] == TEST_AWS_ACCOUNT_ID

        # metadata (request ids etc.) differs between calls, so drop it before comparing
        for doc in (xml_doc, json_doc):
            doc.get("GetCallerIdentityResponse", {}).pop("ResponseMetadata", None)
        assert strip_xmlns(xml_doc) == json_doc
Ejemplo n.º 3
0
    def test_cdk_bootstrap_redeploy(self, is_change_set_finished,
                                    cleanup_stacks, cleanup_changesets):
        """Replay the HTTP requests CDK issues when 'cdk bootstrap' is run twice.

        The recorded requests are loaded from ``files/cdk-bootstrap-requests.json`` (one directory
        above this test module) and sent to the edge URL. After each ``ExecuteChangeSet`` call the
        test waits for the change set to finish. The change set and stack are always cleaned up.
        """
        this_dir = os.path.dirname(os.path.realpath(__file__))
        requests_file = os.path.join(this_dir, "..", "files", "cdk-bootstrap-requests.json")
        recorded_operations = json.loads(load_file(requests_file))

        change_set_name = "cdk-deploy-change-set-a4b98b18"
        stack_name = "CDKToolkit-a4b98b18"
        try:
            headers = aws_stack.mock_aws_request_headers("cloudformation")
            base_url = config.get_edge_url()
            for operation in recorded_operations:
                body = operation["data"]
                requests.request(
                    method=operation["method"],
                    url=f"{base_url}{operation['path']}",
                    headers=headers,
                    data=body,
                )
                if "Action=ExecuteChangeSet" not in body:
                    continue
                # block until the executed change set has been applied
                assert wait_until(
                    is_change_set_finished(change_set_name),
                    _max_wait=20,
                    strategy="linear",
                )
        finally:
            cleanup_changesets([change_set_name])
            cleanup_stacks([stack_name])
Ejemplo n.º 4
0
def test_no_arn_partition_rewriting_in_internal_response():
    """Partitions should not be rewritten for _responses_ of _internal_ requests."""
    listener = ArnPartitionRewriteListener()
    response = Response()
    body_content = json.dumps({
        "some-data-with-arn":
        "arn:aws:apigateway:us-gov-west-1::/restapis/arn-in-body/*"
    })
    response._content = body_content
    # set the public attribute: assuming Response is requests' Response (it uses ``_content``),
    # ``_status_code`` is never read, so the status code previously stayed None
    response.status_code = 200
    response_header_content = {
        "some-header-with-arn":
        "arn:aws:apigateway:us-gov-west-1::/restapis/arn-in-header/*"
    }
    response.headers = response_header_content

    # mimic an internal request (internal access key)
    request_headers = mock_aws_request_headers(
        region_name="us-gov-west-1", access_key=INTERNAL_AWS_ACCESS_KEY_ID)

    result = listener.return_response(method="POST",
                                      path="/",
                                      data="ignored",
                                      headers=request_headers,
                                      response=response)

    # None means the listener left the response untouched
    assert result is None
Ejemplo n.º 5
0
    def test_publish_by_path_parameters(self):
        """Publish to an SNS topic using only query-string parameters and check delivery to SQS.

        An SQS queue is subscribed to a new topic, a ``Publish`` action is sent as URL parameters
        of a POST request, and the notification received by the queue is verified. Topic and
        queue are deleted even if an assertion fails.
        """
        topic_name = 'topic-{}'.format(short_uid())
        queue_name = 'queue-{}'.format(short_uid())

        message = 'test message {}'.format(short_uid())
        topic_arn = self.sns_client.create_topic(Name=topic_name)['TopicArn']
        queue_url = None
        try:
            base_url = '{}://{}:{}'.format(get_service_protocol(), config.LOCALSTACK_HOSTNAME, config.PORT_SNS)
            # let requests URL-encode the parameters: the message contains spaces, the ARN colons
            params = {
                'Action': 'Publish',
                'Version': '2010-03-31',
                'TopicArn': topic_arn,
                'Message': message,
            }

            queue_url = self.sqs_client.create_queue(QueueName=queue_name)['QueueUrl']
            queue_arn = aws_stack.sqs_queue_arn(queue_name)

            self.sns_client.subscribe(TopicArn=topic_arn, Protocol='sqs', Endpoint=queue_arn)

            r = requests.post(
                url='{}/'.format(base_url),
                params=params,
                headers=aws_stack.mock_aws_request_headers('sns'),
            )
            self.assertEqual(r.status_code, 200)

            def get_notification(q_url):
                resp = self.sqs_client.receive_message(QueueUrl=q_url)
                return json.loads(resp['Messages'][0]['Body'])

            notification = retry(get_notification, retries=3, sleep=2, q_url=queue_url)
            self.assertEqual(notification['TopicArn'], topic_arn)
            self.assertEqual(notification['Message'], message)
        finally:
            # clean up even when one of the assertions above fails
            self.sns_client.delete_topic(TopicArn=topic_arn)
            if queue_url:
                self.sqs_client.delete_queue(QueueUrl=queue_url)
Ejemplo n.º 6
0
def update_apigateway(method,
                      path,
                      data,
                      headers,
                      response=None,
                      return_forward_info=False):
    """Proxy hook for API Gateway requests.

    Only acts when ``return_forward_info`` is true (otherwise ``None`` is returned):

    * ``POST /restapis/<id>/deployments`` is answered with HTTP 200 directly.
    * ``POST /restapis/<api_id>/<segment>/<PATH_USER_REQUEST>/<sub_path>`` renders the method's
      ``application/json`` request template and forwards it as a Kinesis ``PutRecords`` call
      to ``TEST_KINESIS_URL``; 200 is returned regardless of the Kinesis result.
    * any other request returns ``True``.

    The incoming ``headers`` are not forwarded, and ``response`` is unused.
    """
    if not return_forward_info:
        return None
    if method != 'POST':
        return True

    if re.match(r'^/restapis/[A-Za-z0-9\-]+/deployments$', path):
        # this is a request to deploy the API gateway, simply return HTTP code 200
        return 200

    user_request_regex = r'^/restapis/([A-Za-z0-9_\-]+)/([A-Za-z0-9_\-]+)/%s/([^/]+)$' % PATH_USER_REQUEST
    # match once and reuse the match object instead of re-running the regex per group
    match = re.match(user_request_regex, path)
    if not match:
        return True

    api_id = match.group(1)
    sub_path = '/%s' % match.group(3)
    integration = aws_stack.get_apigateway_integration(api_id, method, sub_path)
    template = integration['requestTemplates'][APPLICATION_JSON]
    new_request = aws_stack.render_velocity_template(template, data)

    # forward records to our main kinesis stream
    # TODO check whether the target of this API method is 'kinesis'
    kinesis_headers = aws_stack.mock_aws_request_headers(service='kinesis')
    kinesis_headers['X-Amz-Target'] = KINESIS_ACTION_PUT_RECORDS
    common.make_http_request(url=TEST_KINESIS_URL,
                             method='POST',
                             data=new_request,
                             headers=kinesis_headers)
    return 200
Ejemplo n.º 7
0
    def test_put_metric_data_gzip(self):
        """Send a gzip-compressed PutMetricData query request and check that the metric is listed."""
        metric_name = 'test-metric'
        namespace = 'namespace'
        data = ('Action=PutMetricData&MetricData.member.1.'
                'MetricName=%s&MetricData.member.1.Value=1&'
                'Namespace=%s&Version=2010-08-01' % (metric_name, namespace))
        bytes_data = bytes(data, encoding='utf-8')
        encoded_data = gzip.compress(bytes_data)

        url = config.get_edge_url()
        headers = aws_stack.mock_aws_request_headers('cloudwatch')

        # derive the Authorization header (signing name "monitoring") from the mock helper instead of
        # a hard-coded literal: the literal pinned region us-east-1 in its credential scope, which
        # may differ from the region used by the client below
        authorization = aws_stack.mock_aws_request_headers('monitoring')['Authorization']

        headers.update({
            'Content-Type': 'application/x-www-form-urlencoded; charset=utf-8',
            'Content-Length': len(encoded_data),
            'Content-Encoding': 'GZIP',
            'User-Agent': 'aws-sdk-nodejs/2.819.0 linux/v12.18.2 callback',
            'Authorization': authorization,
        })
        request = Request(url, encoded_data, headers, method='POST')
        # use the response as a context manager so the underlying connection is closed
        with urlopen(request) as response:
            self.assertEqual(response.status, 200)

        client = aws_stack.connect_to_service('cloudwatch')
        rs = client.list_metrics(Namespace=namespace, MetricName=metric_name)
        self.assertEqual(len(rs['Metrics']), 1)
        self.assertEqual(rs['Metrics'][0]['Namespace'], namespace)
Ejemplo n.º 8
0
    def test_put_metric_data_gzip(self):
        """Send a gzip-compressed PutMetricData query request and check that the metric is listed."""
        metric_name = "test-metric"
        namespace = "namespace"
        data = ("Action=PutMetricData&MetricData.member.1."
                "MetricName=%s&MetricData.member.1.Value=1&"
                "Namespace=%s&Version=2010-08-01" % (metric_name, namespace))
        bytes_data = bytes(data, encoding="utf-8")
        encoded_data = gzip.compress(bytes_data)

        url = config.get_edge_url()
        headers = aws_stack.mock_aws_request_headers("cloudwatch")

        # derive the Authorization header (signing name "monitoring") from the mock helper instead of
        # a hard-coded literal: the literal pinned region us-east-1 in its credential scope, which
        # may differ from the region used by the client below
        authorization = aws_stack.mock_aws_request_headers("monitoring")["Authorization"]

        headers.update({
            "Content-Type": "application/x-www-form-urlencoded; charset=utf-8",
            "Content-Length": len(encoded_data),
            "Content-Encoding": "GZIP",
            "User-Agent": "aws-sdk-nodejs/2.819.0 linux/v12.18.2 callback",
            "Authorization": authorization,
        })
        request = Request(url, encoded_data, headers, method="POST")
        # use the response as a context manager so the underlying connection is closed
        with urlopen(request) as response:
            self.assertEqual(200, response.status)

        client = aws_stack.connect_to_service("cloudwatch")
        rs = client.list_metrics(Namespace=namespace, MetricName=metric_name)
        self.assertEqual(1, len(rs["Metrics"]))
        self.assertEqual(namespace, rs["Metrics"][0]["Namespace"])
Ejemplo n.º 9
0
    def test_get_records(self, kinesis_client, kinesis_create_stream,
                         wait_for_stream_ready):
        """GetRecords must return the same record attributes via JSON (boto3) and via CBOR (raw HTTP)."""
        stream_name = "test-%s" % short_uid()

        kinesis_create_stream(StreamName=stream_name, ShardCount=1)
        wait_for_stream_ready(stream_name)

        record = {"Data": "SGVsbG8gd29ybGQ=", "PartitionKey": "1"}
        kinesis_client.put_records(StreamName=stream_name, Records=[record])

        # 1) regular client call (JSON encoding)
        shard_iterator = self._get_shard_iterator(stream_name, kinesis_client)
        json_records = kinesis_client.get_records(ShardIterator=shard_iterator).get("Records")
        assert 1 == len(json_records)
        assert "Data" in json_records[0]

        # 2) hand-built HTTP call with a CBOR body
        shard_iterator = self._get_shard_iterator(stream_name, kinesis_client)
        edge_url = config.get_edge_url()
        cbor_headers = aws_stack.mock_aws_request_headers("kinesis")
        cbor_headers["Content-Type"] = constants.APPLICATION_AMZ_CBOR_1_1
        cbor_headers["X-Amz-Target"] = "Kinesis_20131202.GetRecords"
        cbor_payload = cbor2.dumps({"ShardIterator": shard_iterator})
        http_response = requests.post(edge_url, cbor_payload, headers=cbor_headers)
        assert 200 == http_response.status_code
        cbor_body = cbor2.loads(http_response.content)

        compared_keys = ("Data", "EncryptionType", "PartitionKey", "SequenceNumber")
        expected = select_attributes(json_records[0], compared_keys)
        assert expected == select_attributes(cbor_body["Records"][0], compared_keys)
Ejemplo n.º 10
0
def configure_region_for_current_request(region_name: str, service_name: str):
    """Manually configure (potentially overwrite) the region in the current request context. This may be
    used by API endpoints that are invoked directly by the user (without specifying AWS Authorization
    headers), to still enable transparent region lookup via aws_stack.get_region() ...

    The region is injected by rewriting the request's ``Authorization`` header: every
    ``/<region returned by aws_stack.get_region()>/`` segment is replaced by ``/<region_name>/``.
    If no ``Authorization`` header is present, a mock one for ``service_name`` is used as the base.
    Without a request context, only an info message is logged and nothing is changed.

    :param region_name: region to set for the current request
    :param service_name: service used to generate a mock Authorization header if none is present
    """

    # TODO: leaving import here for now, to avoid circular dependency
    from localstack.utils.aws import aws_stack

    request_context = get_request_context()
    if not request_context:
        LOG.info(
            "Unable to set region '%s' in undefined request context: %s",
            region_name,
            request_context,
        )
        return

    headers = request_context.headers
    auth_header = headers.get("Authorization")
    # fall back to a mock Authorization header for this service if the request has none
    auth_header = auth_header or aws_stack.mock_aws_request_headers(
        service_name)["Authorization"]
    # swap the region segment of the header (str.replace rewrites every occurrence)
    auth_header = auth_header.replace("/%s/" % aws_stack.get_region(),
                                      "/%s/" % region_name)
    try:
        headers["Authorization"] = auth_header
    except Exception as e:
        # some header containers reject item assignment; only errors mentioning "immutable"
        # are handled here, anything else is re-raised
        if "immutable" not in str(e):
            raise
        # replace the whole headers object on the thread's proxy request instead
        # NOTE(review): `request` is not defined in this block - presumably a module-level
        # request proxy (e.g. Flask's); confirm against the module imports
        _context_to_update = get_proxy_request_for_thread() or request
        _context_to_update.headers = CaseInsensitiveDict({
            **headers, "Authorization":
            auth_header
        })
Ejemplo n.º 11
0
def send_dynamodb_request(path, action, request_body):
    """PUT a raw DynamoDB request body for ``action`` to ``$TEST_DYNAMODB_URL/<path>``.

    Uses a mock Authorization header and disables TLS certificate verification.
    Returns the ``requests`` response.
    """
    target = "DynamoDB_20120810.{}".format(action)
    mock_auth = aws_stack.mock_aws_request_headers("dynamodb")["Authorization"]
    request_headers = {
        "Host": "dynamodb.amazonaws.com",
        "x-amz-target": target,
        "Authorization": mock_auth,
    }
    endpoint = os.getenv("TEST_DYNAMODB_URL")
    return requests.put("{}/{}".format(endpoint, path), data=request_body,
                        headers=request_headers, verify=False)
Ejemplo n.º 12
0
def send_dynamodb_request(path, action, request_body):
    """PUT a raw DynamoDB request body for ``action`` to ``$TEST_DYNAMODB_URL/<path>``.

    Uses a mock Authorization header and disables TLS certificate verification.
    """
    target_header = 'DynamoDB_20120810.{}'.format(action)
    auth_header = aws_stack.mock_aws_request_headers('dynamodb')['Authorization']
    headers = {'Host': 'dynamodb.amazonaws.com', 'x-amz-target': target_header, 'Authorization': auth_header}
    base_url = os.getenv('TEST_DYNAMODB_URL')
    target_url = '{}/{}'.format(base_url, path)
    return requests.put(target_url, data=request_body, headers=headers, verify=False)
Ejemplo n.º 13
0
def send_dynamodb_request(path, action, request_body):
    """PUT a raw DynamoDB request body for ``action`` to ``<dynamodb service URL>/<path>``.

    Uses a mock Authorization header and disables TLS certificate verification.
    """
    target = "DynamoDB_20120810.{}".format(action)
    mock_auth = aws_stack.mock_aws_request_headers("dynamodb")["Authorization"]
    headers = {"Host": "dynamodb.amazonaws.com", "x-amz-target": target, "Authorization": mock_auth}
    service_endpoint = config.service_url("dynamodb")
    return requests.put("{}/{}".format(service_endpoint, path), data=request_body,
                        headers=headers, verify=False)
Ejemplo n.º 14
0
    def forward_request(self, method, path, data, headers):
        """Pre-process an SQS request before it is forwarded to the SQS backend.

        Returns 200 for OPTIONS requests; an error response, a custom response, or a rebuilt
        ``Request`` for some actions; and ``True`` otherwise (presumably "forward the original
        request unchanged" - confirm with the proxy framework).
        """
        if method == 'OPTIONS':
            return 200

        req_data = parse_request_data(method, path, data)

        # a plain GET on a queue URL is turned into a GetQueueUrl request for the queue named by
        # the last path segment; a mock Authorization header is added if none is present
        if is_sqs_queue_url(path) and method == 'GET':
            if not headers.get('Authorization'):
                headers['Authorization'] = aws_stack.mock_aws_request_headers(service='sqs')['Authorization']
            method = 'POST'
            req_data = {'Action': 'GetQueueUrl', 'Version': API_VERSION, 'QueueName': path.split('/')[-1]}

        if req_data:
            action = req_data.get('Action')

            if action in ('SendMessage', 'SendMessageBatch') and SQS_BACKEND_IMPL == 'moto':
                # check message contents
                # (every request parameter value is checked against MSG_CONTENT_REGEX, not only the body)
                for key, value in req_data.items():
                    if not re.match(MSG_CONTENT_REGEX, str(value)):
                        return make_requests_error(code=400, code_string='InvalidMessageContents',
                            message='Message contains invalid characters')

            elif action == 'SetQueueAttributes':
                # TODO remove this function if we stop using ElasticMQ entirely
                queue_url = _queue_url(path, req_data, headers)
                if SQS_BACKEND_IMPL == 'elasticmq':
                    forward_attrs = _set_queue_attributes(queue_url, req_data)
                    if len(req_data) != len(forward_attrs):
                        # make sure we only forward the supported attributes to the backend
                        return _get_attributes_forward_request(method, path, headers, req_data, forward_attrs)

            elif action == 'CreateQueue':
                # if the DLQ ARN in the attributes was changed, forward a rebuilt request
                changed_attrs = _fix_dlq_arn_in_attributes(req_data)
                if changed_attrs:
                    return _get_attributes_forward_request(method, path, headers, req_data, changed_attrs)

            elif action == 'DeleteQueue':
                # drop locally cached attributes and SNS subscriptions of the deleted queue
                queue_url = _queue_url(path, req_data, headers)
                QUEUE_ATTRIBUTES.pop(queue_url, None)
                sns_listener.unsubscribe_sqs_queue(queue_url)

            elif action == 'ListDeadLetterSourceQueues':
                # TODO remove this function if we stop using ElasticMQ entirely
                queue_url = _queue_url(path, req_data, headers)
                if SQS_BACKEND_IMPL == 'elasticmq':
                    # answered locally; note that `headers` is rebound to the response headers here
                    headers = {'content-type': 'application/xhtml+xml'}
                    content_str = _list_dead_letter_source_queues(QUEUE_ATTRIBUTES, queue_url)
                    return requests_response(content_str, headers=headers)

            # re-encode the (possibly rewritten) parameters: as the body for POST,
            # or as the query string of the URL for GET
            if 'QueueName' in req_data:
                encoded_data = urlencode(req_data, doseq=True) if method == 'POST' else ''
                modified_url = None
                if method == 'GET':
                    base_path = path.partition('?')[0]
                    modified_url = '%s?%s' % (base_path, urlencode(req_data, doseq=True))
                return Request(data=encoded_data, url=modified_url, headers=headers, method=method)

        return True
Ejemplo n.º 15
0
def test_fix_region_in_headers():
    """``prepare_request_headers`` must replace the bogus regions sent by NoSQL Workbench."""
    # the NoSQL Workbench sends "localhost" or "local" as the region name
    # TODO: this may need to be updated once we migrate DynamoDB to ASF
    for bogus_region in ("local", "localhost"):
        request_headers = aws_stack.mock_aws_request_headers("dynamodb", region_name=bogus_region)
        # before: the configured region does not appear in the Authorization header
        assert aws_stack.get_region() not in request_headers.get("Authorization")
        ProxyListenerDynamoDB.prepare_request_headers(request_headers)
        # after: the header has been rewritten to the configured region
        assert aws_stack.get_region() in request_headers.get("Authorization")
Ejemplo n.º 16
0
    def forward_request(self, method, path, data, headers):
        """Route an incoming edge request to the backend port of the target API.

        ``/health`` and ``POST /graph`` are served directly. Otherwise the API and port are
        derived from the request headers/path, falling back to custom rules. A negative port
        yields 404. If no port is found, OPTIONS requests get 200 and all others a 404 response.
        A mock ``Authorization`` header is added when the API is known but the header is missing.
        """
        if path.split('?')[0] == '/health':
            return serve_health_endpoint(method, path, data)
        if method == 'POST' and path == '/graph':
            return serve_resource_graph(data)

        # kill the process if we receive this header
        # (explicit `if` instead of relying on `and` short-circuiting for a side effect)
        if headers.get(HEADER_KILL_SIGNAL):
            os._exit(0)

        target = headers.get('x-amz-target', '')
        auth_header = headers.get('authorization', '')
        host = headers.get('host', '')
        headers[HEADER_LOCALSTACK_EDGE_URL] = 'https://%s' % host

        # extract API details
        api, port, path, host = get_api_from_headers(headers, path)

        set_default_region_in_headers(headers)

        if port and int(port) < 0:
            return 404

        if not port:
            api, port = get_api_from_custom_rules(method, path, data,
                                                  headers) or (api, port)

        if not port:
            if method == 'OPTIONS':
                return 200

            if api in ['', None, '_unknown_']:
                truncated = truncate(data)
                # NOTE(review): this logs the raw Authorization header at INFO level; consider redacting it
                LOG.info('Unable to find forwarding rule for host "%s", path "%s", '
                         'target header "%s", auth header "%s", data "%s"',
                         host, path, target, auth_header, truncated)
            else:
                LOG.info('Unable to determine forwarding port for API "%s" - please '
                         'make sure this API is enabled via the SERVICES configuration', api)
            response = Response()
            response.status_code = 404
            response._content = '{"status": "running"}'
            return response

        if api and not headers.get('Authorization'):
            headers['Authorization'] = aws_stack.mock_aws_request_headers(
                api)['Authorization']

        headers['Host'] = host
        if isinstance(data, dict):
            data = json.dumps(data)

        return do_forward_request(api, port, method, path, data, headers)
Ejemplo n.º 17
0
 def test_expiration_date_format(self):
     """GetSessionToken with a JSON response must report ``Credentials.Expiration`` as a number."""
     edge_url = config.get_edge_url()
     request_data = {"Action": "GetSessionToken", "Version": "2011-06-15"}
     sts_headers = aws_stack.mock_aws_request_headers("sts")
     sts_headers["Accept"] = APPLICATION_JSON
     http_response = requests.post(edge_url, data=request_data, headers=sts_headers)
     assert http_response
     body = json.loads(to_str(http_response.content))
     # Expiration field should be numeric (tested against AWS)
     credentials = body["GetSessionTokenResponse"]["GetSessionTokenResult"]["Credentials"]
     assert is_number(credentials["Expiration"])
Ejemplo n.º 18
0
    def test_request_with_custom_host_header(self):
        """Lambda ListFunctions must succeed whatever host/port the ``Host`` header names."""
        functions_url = f"{config.get_edge_url()}/2015-03-31/functions"

        headers = aws_stack.mock_aws_request_headers("lambda")

        # every combination of host name and (optional) port, tried in a plain loop for simplicity
        host_headers = [
            f"{host}{port}"
            for host in ("localhost", "example.com")
            for port in ("", ":123", f":{config.EDGE_PORT}")
        ]
        for host_header in host_headers:
            headers["Host"] = host_header
            listing = requests.get(functions_url, headers=headers)
            assert listing
            assert "Functions" in json.loads(to_str(listing.content))
Ejemplo n.º 19
0
 def test_create_bucket_via_host_name(self):
     """Create a bucket via a virtual-host style ``Host`` header and check its location can be read."""
     body = """<?xml version="1.0" encoding="UTF-8"?>
         <CreateBucketConfiguration xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
             <LocationConstraint>eu-central-1</LocationConstraint>
         </CreateBucketConfiguration>"""
     headers = aws_stack.mock_aws_request_headers('s3')
     bucket_name = 'test-%s' % short_uid()
     # the bucket name is carried only in the Host header (virtual-host style addressing)
     headers['Host'] = '%s.s3.amazonaws.com' % bucket_name
     response = requests.put(config.TEST_S3_URL, data=body, headers=headers, verify=False)
     # assertEqual: the deprecated alias assertEquals was removed in Python 3.12
     self.assertEqual(response.status_code, 200)
     response = self.s3_client.get_bucket_location(Bucket=bucket_name)
     self.assertEqual(response['ResponseMetadata']['HTTPStatusCode'], 200)
     self.assertIn('LocationConstraint', response)
Ejemplo n.º 20
0
    def __call__(self, chain: HandlerChain, context: RequestContext, response: Response):
        """Add a mock ``Authorization`` header to requests for a resolved service that lack one.

        Requests whose ``context.service`` is not set are left untouched.
        """
        # FIXME: this is needed for allowing access to resources via plain URLs where access is typically restricted (
        #  e.g., GET requests on S3 URLs or apigateway routes). this should probably be part of a general IAM middleware
        #  (that allows access to restricted resources by default)
        if context.service:
            from localstack.utils.aws import aws_stack

            service_name = context.service.service_name
            request_headers = context.request.headers
            if not request_headers.get("Authorization"):
                mock_headers = aws_stack.mock_aws_request_headers(service_name)
                request_headers["Authorization"] = mock_headers["Authorization"]
Ejemplo n.º 21
0
 def invoke(cls, invocation_context: ApiInvocationContext):
     """Render the SNS integration's request template and POST the payload to the SNS endpoint.

     The region is read from the integration URI as its 4th ``:``-separated field (the region
     position of an ARN), and mock AWS request headers for that region are used.
     """
     try:
         request_templates = RequestTemplates()
         payload = request_templates.render(invocation_context)
     except Exception as e:
         # the message needs a %s placeholder for ``e``: without it the logging module cannot
         # format the record and reports a formatting error instead of this warning
         LOG.warning("Failed to apply template for SNS integration: %s", e)
         raise
     uri = (invocation_context.integration.get("uri")
            or invocation_context.integration.get("integrationUri") or "")
     # NOTE(review): raises IndexError if the URI has fewer than four ':'-separated fields
     region_name = uri.split(":")[3]
     headers = aws_stack.mock_aws_request_headers(service="sns",
                                                  region_name=region_name)
     return make_http_request(config.service_url("sns"),
                              method="POST",
                              headers=headers,
                              data=payload)
Ejemplo n.º 22
0
def update_apigateway(method,
                      path,
                      data,
                      headers,
                      response=None,
                      return_forward_info=False):
    """Proxy hook for API Gateway requests.

    Only acts when ``return_forward_info`` is true (otherwise ``None`` is returned):

    * ``POST /restapis/<id>/deployments`` is answered with HTTP 200 directly.
    * ``POST /restapis/<api_id>/<segment>/<PATH_USER_REQUEST>/<sub_path>`` looks up the method's
      integration. For an ``AWS`` integration whose URI ends with ``kinesis:action/PutRecords``,
      the ``application/json`` request template is rendered and sent to ``TEST_KINESIS_URL`` as
      a Kinesis ``PutRecords`` call, and that HTTP result is returned. Other integrations are
      logged as not implemented and answered with 200.
    * any other request returns ``True``.

    The incoming ``headers`` are not forwarded, and ``response`` is unused.
    """
    if not return_forward_info:
        return None
    if method != 'POST':
        return True

    if re.match(r'^/restapis/[A-Za-z0-9\-]+/deployments$', path):
        # this is a request to deploy the API gateway, simply return HTTP code 200
        return 200

    user_request_regex = r'^/restapis/([A-Za-z0-9_\-]+)/([A-Za-z0-9_\-]+)/%s/([^/]+)$' % PATH_USER_REQUEST
    # match once and reuse the match object instead of re-running the regex per group
    match = re.match(user_request_regex, path)
    if not match:
        return True

    api_id = match.group(1)
    sub_path = '/%s' % match.group(3)
    integration = aws_stack.get_apigateway_integration(api_id, method, sub_path)

    if integration['type'] != 'AWS':
        LOGGER.warning('API Gateway integration type "%s" not yet implemented', integration['type'])
        return 200
    if not integration['uri'].endswith('kinesis:action/PutRecords'):
        LOGGER.warning('API Gateway action uri "%s" not yet implemented', integration['uri'])
        return 200

    template = integration['requestTemplates'][APPLICATION_JSON]
    new_request = aws_stack.render_velocity_template(template, data)

    # forward records to target kinesis stream
    kinesis_headers = aws_stack.mock_aws_request_headers(service='kinesis')
    kinesis_headers['X-Amz-Target'] = KINESIS_ACTION_PUT_RECORDS
    return common.make_http_request(url=TEST_KINESIS_URL,
                                    method='POST',
                                    data=new_request,
                                    headers=kinesis_headers)
Ejemplo n.º 23
0
def test_arn_partition_rewriting_in_request(internal_call, encoding,
                                            origin_partition):
    """ARNs of the origin partition are rewritten to ``aws`` in incoming requests.

    The partition has to be replaced in the query string, the body and the
    headers of the forwarded request. In contrast to responses, requests are
    rewritten for internal and external calls alike.

    Parameterized by ``internal_call`` (build mock headers with the internal
    access key), ``encoding`` (applied to the JSON body) and ``origin_partition``.
    """
    rewriting_listener = ArnPartitionRewriteListener()
    request_body = encoding(
        json.dumps({
            "some-data-with-arn":
            f"arn:{origin_partition}:apigateway:us-gov-west-1::/restapis/arn-in-body/*"
        }))

    # internal calls carry mock AWS headers built with the internal access key;
    # external calls start out without any headers
    if not internal_call:
        request_headers = {}
    else:
        request_headers = mock_aws_request_headers(
            region_name=origin_partition,
            access_key=INTERNAL_AWS_ACCESS_KEY_ID)
    request_headers["some-header-with-arn"] = (
        f"arn:{origin_partition}:apigateway:us-gov-west-1::/restapis/arn-in-header/*")

    request_path = (
        f"/?arn=arn%3A{origin_partition}%3Aapigateway%3Aus-gov-west-1%3A%3A%2Frestapis%2Farn-in-path%2F%2A&"
        f"arn2=arn%3A{origin_partition}%3Aapigateway%3Aus-gov-west-1%3A%3A%2Frestapis%2Farn-in-path2%2F%2A"
    )
    forwarded = rewriting_listener.forward_request(
        method="POST",
        path=request_path,
        data=request_body,
        headers=request_headers,
    )

    expected_path = (
        "/?arn=arn%3Aaws%3Aapigateway%3Aus-gov-west-1%3A%3A%2Frestapis%2Farn-in-path%2F%2A&"
        "arn2=arn%3Aaws%3Aapigateway%3Aus-gov-west-1%3A%3A%2Frestapis%2Farn-in-path2%2F%2A"
    )
    expected_body = encoding(
        json.dumps({
            "some-data-with-arn":
            "arn:aws:apigateway:us-gov-west-1::/restapis/arn-in-body/*"
        }))
    expected_header = "arn:aws:apigateway:us-gov-west-1::/restapis/arn-in-header/*"

    assert forwarded.method == "POST"
    assert forwarded.path == expected_path
    assert forwarded.data == expected_body
    assert forwarded.headers["some-header-with-arn"] == expected_header
Ejemplo n.º 24
0
    def test_get_records(self):
        """Records can be read back with both JSON and CBOR encoded requests.

        Creates a one-shard stream and puts a single record into it. The record is
        read once through the boto client (JSON) and once through a raw CBOR
        ``GetRecords`` request against the edge URL. The selected record attributes
        of both results must agree. The stream is deleted afterwards even if an
        assertion fails.
        """
        client = aws_stack.create_external_boto_client("kinesis")
        stream_name = "test-%s" % short_uid()

        client.create_stream(StreamName=stream_name, ShardCount=1)
        try:
            # presumably gives the new stream time to become usable - TODO consider
            # the boto `stream_exists` waiter instead of a fixed sleep
            sleep(1.5)
            client.put_records(
                StreamName=stream_name,
                Records=[{
                    "Data": "SGVsbG8gd29ybGQ=",
                    "PartitionKey": "1"
                }],
            )

            # get records with JSON encoding
            iterator = self._get_shard_iterator(stream_name)
            response = client.get_records(ShardIterator=iterator)
            json_records = response.get("Records")
            self.assertEqual(1, len(json_records))
            self.assertIn("Data", json_records[0])

            # get records with CBOR encoding (fresh iterator, raw HTTP request)
            iterator = self._get_shard_iterator(stream_name)
            url = config.get_edge_url()
            headers = aws_stack.mock_aws_request_headers("kinesis")
            headers["Content-Type"] = constants.APPLICATION_AMZ_CBOR_1_1
            headers["X-Amz-Target"] = "Kinesis_20131202.GetRecords"
            data = cbor2.dumps({"ShardIterator": iterator})
            result = requests.post(url, data, headers=headers)
            self.assertEqual(200, result.status_code)
            result = cbor2.loads(result.content)
            attrs = ("Data", "EncryptionType", "PartitionKey", "SequenceNumber")
            self.assertEqual(
                select_attributes(json_records[0], attrs),
                select_attributes(result["Records"][0], attrs),
            )
        finally:
            # clean up even when an assertion above fails, so the stream does not leak
            client.delete_stream(StreamName=stream_name)
Ejemplo n.º 25
0
    def test_disable_cors_headers(self, monkeypatch):
        """Test DISABLE_CORS_HEADERS=1 (most restrictive setting, not sending any CORS headers).

        First sends a cross-origin SNS ``ListTopics`` request before the flag is
        patched. It checks that the response echoes the request ``Origin`` and
        allows the ``Authorization`` header and the ``GET`` method. Then it sets
        ``config.DISABLE_CORS_HEADERS`` to True and checks that the same request
        still succeeds, but without any ``Access-Control-*`` response headers.
        """
        headers = aws_stack.mock_aws_request_headers("sns")
        headers["Origin"] = "https://app.localstack.cloud"
        url = config.get_edge_url()
        data = {"Action": "ListTopics", "Version": "2010-03-31"}
        response = requests.post(url, headers=headers, data=data)
        assert response.status_code == 200
        assert response.headers["access-control-allow-origin"] == headers[
            "Origin"]
        assert "authorization" in response.headers[
            "access-control-allow-headers"].lower()
        assert "GET" in response.headers["access-control-allow-methods"].split(
            ",")
        assert "<ListTopicsResponse" in to_str(response.content)

        monkeypatch.setattr(config, "DISABLE_CORS_HEADERS", True)
        response = requests.post(url, headers=headers, data=data)
        assert response.status_code == 200
        assert "<ListTopicsResponse" in to_str(response.content)
        assert not response.headers.get("access-control-allow-headers")
        assert not response.headers.get("access-control-allow-methods")
        assert not response.headers.get("access-control-allow-origin")
        assert not response.headers.get("access-control-allow-credentials")
Ejemplo n.º 26
0
    def forward_request(self, method, path, data, headers):
        """Handle API Gateway user requests and authorizer calls.

        Requests to ``/restapis/<api_id>/<stage>/<PATH_USER_REQUEST>/<path>`` are
        resolved to their method integration and executed here: Kinesis
        ``PutRecords`` for ``POST`` to ``AWS`` integrations, Lambda for
        ``AWS_PROXY`` integrations, and a plain HTTP call for ``HTTP``
        integrations. Requests matching ``PATH_REGEX_AUTHORIZERS`` are delegated to
        ``handle_authorizers``. Any other request returns ``True``.

        :param data: raw request body; JSON-decoded here when non-empty.
        :returns: a response object, an error from ``make_error``, or ``True``.
        """
        data = data and json.loads(to_str(data))

        # Paths to match
        regex2 = r'^/restapis/([A-Za-z0-9_\-]+)/([A-Za-z0-9_\-]+)/%s/(.*)$' % PATH_USER_REQUEST

        search_match = re.match(regex2, path)
        if search_match:
            api_id = search_match.group(1)
            relative_path = '/%s' % search_match.group(3)

            # only set when the resource is resolved via the path map below; it is
            # needed to extract path parameters for Lambda proxy integrations
            extracted_path = None
            try:
                integration = aws_stack.get_apigateway_integration(api_id, method, path=relative_path)
            except Exception:
                integration = None

            if not integration:
                # if we have no exact match, try to find an API resource that contains path parameters
                path_map = get_rest_api_paths(rest_api_id=api_id)
                try:
                    extracted_path, resource = get_resource_for_path(path=relative_path, path_map=path_map)
                except Exception:
                    return make_error('Unable to find path %s' % path, 404)

                integrations = resource.get('resourceMethods', {})
                integration = integrations.get(method, {})
                integration = integration.get('methodIntegration')
                if not integration:

                    if method == 'OPTIONS' and 'Origin' in headers:
                        # default to returning CORS headers if this is an OPTIONS request
                        return get_cors_response(headers)

                    return make_error('Unable to find integration for path %s' % path, 404)

            # an integration without URI falls through to the "not yet implemented" errors
            uri = integration.get('uri') or ''
            if method == 'POST' and integration['type'] == 'AWS':
                if uri.endswith('kinesis:action/PutRecords'):
                    template = integration['requestTemplates'][APPLICATION_JSON]
                    new_request = aws_stack.render_velocity_template(template, data)

                    # forward records to target kinesis stream
                    headers = aws_stack.mock_aws_request_headers(service='kinesis')
                    headers['X-Amz-Target'] = kinesis_listener.ACTION_PUT_RECORDS
                    result = common.make_http_request(url=TEST_KINESIS_URL,
                        method='POST', data=new_request, headers=headers)
                    return result
                else:
                    msg = 'API Gateway action uri "%s" not yet implemented' % uri
                    LOGGER.warning(msg)
                    return make_error(msg, 404)

            elif integration['type'] == 'AWS_PROXY':
                if uri.startswith('arn:aws:apigateway:') and ':lambda:path' in uri:
                    func_arn = uri.split(':lambda:path')[1].split('functions/')[1].split('/invocations')[0]
                    data_str = json.dumps(data) if isinstance(data, dict) else data

                    # path parameters can only be extracted for resources resolved via the path map
                    path_params = {}
                    if extracted_path is not None:
                        try:
                            path_params = extract_path_params(path=relative_path, extracted_path=extracted_path)
                        except Exception:
                            path_params = {}
                    result = lambda_api.process_apigateway_invocation(func_arn, relative_path, data_str,
                        headers, path_params=path_params, method=method, resource_path=path)

                    if isinstance(result, FlaskResponse):
                        return flask_to_requests_response(result)

                    response = Response()
                    parsed_result = result if isinstance(result, dict) else json.loads(result)
                    parsed_result = common.json_safe(parsed_result)
                    response.status_code = int(parsed_result.get('statusCode', 200))
                    response.headers.update(parsed_result.get('headers', {}))
                    try:
                        if isinstance(parsed_result['body'], dict):
                            response._content = json.dumps(parsed_result['body'])
                        else:
                            response._content = parsed_result['body']
                    except Exception:
                        response._content = '{}'
                    return response
                else:
                    msg = 'API Gateway action uri "%s" not yet implemented' % uri
                    LOGGER.warning(msg)
                    return make_error(msg, 404)

            elif integration['type'] == 'HTTP':
                function = getattr(requests, method.lower())
                if isinstance(data, dict):
                    data = json.dumps(data)
                result = function(integration['uri'], data=data, headers=headers)
                return result

            else:
                msg = ('API Gateway integration type "%s" for method "%s" not yet implemented' %
                    (integration['type'], method))
                LOGGER.warning(msg)
                return make_error(msg, 404)

        if re.match(PATH_REGEX_AUTHORIZERS, path):
            return handle_authorizers(method, path, data, headers)

        return True
Ejemplo n.º 27
0
    def forward_request(self, method, path, data, headers):
        """Pre-process an incoming SQS request before it is forwarded.

        ``OPTIONS`` requests are answered with 200. A ``GET`` on a queue URL is
        rewritten into a ``GetQueueUrl`` action, and mock authorization is added
        if it is missing. Depending on the action, the request is validated,
        rewritten or answered directly (see the inline comments). Requests that
        carry a ``QueueName`` are re-encoded into a new ``Request``. Anything else
        returns ``True``.
        """
        if method == "OPTIONS":
            return 200

        req_data = parse_request_data(method, path, data)

        if is_sqs_queue_url(path) and method == "GET":
            if not headers.get("Authorization"):
                headers["Authorization"] = aws_stack.mock_aws_request_headers(
                    service="sqs")["Authorization"]
            method = "POST"
            req_data = {
                "Action": "GetQueueUrl",
                "Version": API_VERSION,
                "QueueName": path.split("/")[-1],
            }

        if req_data:
            action = req_data.get("Action")

            if action in ("SendMessage",
                          "SendMessageBatch") and SQS_BACKEND_IMPL == "moto":
                # check message contents: every request parameter value must match MSG_CONTENT_REGEX
                if any(not re.match(MSG_CONTENT_REGEX, str(value))
                       for value in req_data.values()):
                    return make_requests_error(
                        code=400,
                        code_string="InvalidMessageContents",
                        message="Message contains invalid characters",
                    )

            elif action == "SetQueueAttributes":
                # TODO remove this function if we stop using ElasticMQ
                queue_url = _queue_url(path, req_data, headers)
                if SQS_BACKEND_IMPL == "elasticmq":
                    forward_attrs = _set_queue_attributes(queue_url, req_data)
                    if len(req_data) != len(forward_attrs):
                        # make sure we only forward the supported attributes to the backend
                        return _get_attributes_forward_request(
                            method, path, headers, req_data, forward_attrs)

            elif action == "TagQueue":
                req_data = self.fix_missing_tag_values(req_data)

            elif action == "CreateQueue":
                req_data = self.fix_missing_tag_values(req_data)

                def _is_fifo():
                    # presumably attributes come as "...Name"/"...Value" key pairs;
                    # a FifoQueue name without a matching value counts as not FIFO
                    for k, v in req_data.items():
                        if v == "FifoQueue":
                            return req_data.get(k.replace(
                                "Name", "Value"), "").lower() == "true"
                    return False

                # a missing QueueName is left for the backend to reject
                queue_name = req_data.get("QueueName") or ""
                if queue_name.endswith(".fifo") and not _is_fifo():
                    msg = "Can only include alphanumeric characters, hyphens, or underscores. 1 to 80 in length"
                    return make_requests_error(
                        code=400,
                        code_string="InvalidParameterValue",
                        message=msg)
                changed_attrs = _fix_dlq_arn_in_attributes(req_data)
                if changed_attrs:
                    return _get_attributes_forward_request(
                        method, path, headers, req_data, changed_attrs)

            elif action == "DeleteQueue":
                queue_url = _queue_url(path, req_data, headers)
                QUEUE_ATTRIBUTES.pop(queue_url, None)
                sns_listener.unsubscribe_sqs_queue(queue_url)

            elif action == "ListDeadLetterSourceQueues":
                # TODO remove this function if we stop using ElasticMQ entirely
                queue_url = _queue_url(path, req_data, headers)
                if SQS_BACKEND_IMPL == "elasticmq":
                    response_headers = {"content-type": "application/xhtml+xml"}
                    content_str = _list_dead_letter_source_queues(
                        QUEUE_ATTRIBUTES, queue_url)
                    return requests_response(content_str, headers=response_headers)

            if "QueueName" in req_data:
                # re-encode the (possibly modified) parameters into the body (POST)
                # or the query string (GET)
                encoded_data = urlencode(
                    req_data, doseq=True) if method == "POST" else ""
                modified_url = None
                if method == "GET":
                    base_path = path.partition("?")[0]
                    modified_url = "%s?%s" % (
                        base_path,
                        urlencode(req_data, doseq=True),
                    )
                return Request(data=encoded_data,
                               url=modified_url,
                               headers=headers,
                               method=method)

        return True
Ejemplo n.º 28
0
def update_apigateway(method,
                      path,
                      data,
                      headers,
                      response=None,
                      return_forward_info=False):
    """Forward API Gateway user requests to their configured integration.

    Only active when ``return_forward_info`` is truthy; otherwise returns ``None``.
    Requests to ``/restapis/<api_id>/<stage>/<PATH_USER_REQUEST>/<sub_path>``
    (``sub_path`` is a single path segment) are dispatched to the integration
    registered for ``method`` and ``sub_path``. ``POST`` requests to ``AWS``
    integrations use Kinesis ``PutRecords``; ``HTTP`` integrations get a plain
    HTTP call. Other paths return ``True``.

    :param response: not used by this function.
    :returns: the integration result, an error from ``make_error``, ``True`` or
        ``None``.
    """
    if not return_forward_info:
        return None

    regex2 = r'^/restapis/([A-Za-z0-9_\-]+)/([A-Za-z0-9_\-]+)/%s/([^/]*)$' % PATH_USER_REQUEST
    search_match = re.match(regex2, path)
    if not search_match:
        return True

    api_id = search_match.group(1)
    sub_path = '/%s' % search_match.group(3)
    try:
        integration = aws_stack.get_apigateway_integration(
            api_id, method, sub_path)
    except Exception:
        # a failed lookup is reported like a missing integration below
        integration = None
    if not integration:
        msg = ('API Gateway endpoint "%s" for method "%s" not found' %
               (path, method))
        LOGGER.warning(msg)
        return make_error(msg, 404)

    if method == 'POST' and integration['type'] == 'AWS':
        if integration['uri'].endswith('kinesis:action/PutRecords'):
            template = integration['requestTemplates'][APPLICATION_JSON]
            new_request = aws_stack.render_velocity_template(template, data)

            # forward records to target kinesis stream
            headers = aws_stack.mock_aws_request_headers(service='kinesis')
            headers['X-Amz-Target'] = KINESIS_ACTION_PUT_RECORDS
            return common.make_http_request(url=TEST_KINESIS_URL,
                                            method='POST',
                                            data=new_request,
                                            headers=headers)

        msg = 'API Gateway action uri "%s" not yet implemented' % integration[
            'uri']
        LOGGER.warning(msg)
        return make_error(msg, 404)

    if integration['type'] == 'HTTP':
        function = getattr(requests, method.lower())
        if isinstance(data, dict):
            data = json.dumps(data)
        return function(integration['uri'], data=data, headers=headers)

    msg = (
        'API Gateway integration type "%s" for method "%s" not yet implemented'
        % (integration['type'], method))
    LOGGER.warning(msg)
    return make_error(msg, 404)
Ejemplo n.º 29
0
def invoke_rest_api(api_id,
                    stage,
                    method,
                    invocation_path,
                    data,
                    headers,
                    path=None):
    """Invoke the integration configured for a REST API resource and method.

    The function proceeds in these steps:

    1. Split the query string off ``invocation_path``.
    2. Run the gateway authorizers.
    3. Resolve the API resource and check the API key.
    4. Dispatch the request to the method integration of the resource, falling
       back to an ``ANY`` method.

    Supported integrations:

    * Lambda (``AWS``/``AWS_PROXY`` with a ``:lambda:path`` URI)
    * Kinesis ``PutRecords``/``ListStreams`` and SQS for ``POST`` (``AWS``)
    * DynamoDB ``PutItem`` for ``PUT`` (``AWS_PROXY``)
    * plain HTTP calls (``HTTP``/``HTTP_PROXY``)

    ``OPTIONS`` requests that no integration handles may be answered with CORS
    headers. Anything unsupported yields a 404 from ``make_error_response``.

    :param path: path used in error messages and passed to Lambda as
        ``resource_path``; defaults to ``invocation_path``.
    :returns: a response object.
    """
    path = path or invocation_path
    relative_path, query_string_params = extract_query_string_params(
        path=invocation_path)

    # run gateway authorizers for this request
    authorize_invocation(api_id, headers)
    path_map = helpers.get_rest_api_paths(rest_api_id=api_id)
    try:
        extracted_path, resource = get_resource_for_path(path=relative_path,
                                                         path_map=path_map)
    except Exception:
        return make_error_response('Unable to find path %s' % path, 404)

    api_key_required = resource.get('resourceMethods',
                                    {}).get(method, {}).get('apiKeyRequired')
    if not is_api_key_valid(api_key_required, headers, stage):
        return make_error_response('Access denied - invalid API key', 403)

    integrations = resource.get('resourceMethods', {})
    integration = integrations.get(method, {})
    if not integration:
        integration = integrations.get('ANY', {})
    integration = integration.get('methodIntegration')
    if not integration:
        if method == 'OPTIONS' and 'Origin' in headers:
            # default to returning CORS headers if this is an OPTIONS request
            return get_cors_response(headers)
        return make_error_response(
            'Unable to find integration for path %s' % path, 404)

    uri = integration.get('uri') or ''
    integration_type = integration['type'].upper()

    if uri.startswith('arn:aws:apigateway:') and ':lambda:path' in uri:
        if integration_type in ['AWS', 'AWS_PROXY']:
            func_arn = uri.split(':lambda:path')[1].split(
                'functions/')[1].split('/invocations')[0]
            data_str = json.dumps(data) if isinstance(data,
                                                      (dict,
                                                       list)) else to_str(data)
            account_id = uri.split(':lambda:path')[1].split(
                ':function:')[0].split(':')[-1]

            # use the second-to-last X-Forwarded-For hop as the client IP
            # (presumably the last hop is appended by the local edge proxy - TODO
            # confirm); tolerate a missing header or a single hop instead of raising
            forwarded_hops = [
                hop.strip()
                for hop in (headers.get('X-Forwarded-For') or '').split(',')
                if hop.strip()
            ]
            if len(forwarded_hops) > 1:
                source_ip = forwarded_hops[-2]
            else:
                source_ip = forwarded_hops[0] if forwarded_hops else ''

            try:
                path_params = extract_path_params(
                    path=relative_path, extracted_path=extracted_path)
            except Exception:
                path_params = {}

            # apply custom request template
            data_str = apply_template(integration,
                                      'request',
                                      data_str,
                                      path_params=path_params,
                                      query_params=query_string_params,
                                      headers=headers)

            # Sample request context:
            # https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-create-api-as-simple-proxy-for-lambda.html#api-gateway-create-api-as-simple-proxy-for-lambda-test
            request_context = {
                # adding stage to the request context path.
                # https://github.com/localstack/localstack/issues/2210
                'path': '/' + stage + relative_path,
                'accountId': account_id,
                'resourceId': resource.get('id'),
                'stage': stage,
                'identity': {
                    'accountId': account_id,
                    'sourceIp': source_ip,
                    'userAgent': headers.get('User-Agent'),
                },
                'httpMethod': method,
                'protocol': 'HTTP/1.1',
                'requestTime': datetime.datetime.utcnow(),
                'requestTimeEpoch': int(time.time() * 1000),
            }

            result = lambda_api.process_apigateway_invocation(
                func_arn,
                relative_path,
                data_str,
                stage,
                api_id,
                headers,
                path_params=path_params,
                query_string_params=query_string_params,
                method=method,
                resource_path=path,
                request_context=request_context)

            if isinstance(result, FlaskResponse):
                response = flask_to_requests_response(result)
            elif isinstance(result, Response):
                response = result
            else:
                response = LambdaResponse()
                if isinstance(result, dict):
                    parsed_result = result
                else:
                    raw_result = result or '{}'
                    # decode bytes explicitly; str() would render them as "b'...'"
                    if isinstance(raw_result, bytes):
                        raw_result = to_str(raw_result)
                    parsed_result = json.loads(str(raw_result))
                parsed_result = common.json_safe(parsed_result)
                parsed_result = {} if parsed_result is None else parsed_result
                response.status_code = int(parsed_result.get(
                    'statusCode', 200))
                parsed_headers = parsed_result.get('headers', {})
                if parsed_headers is not None:
                    response.headers.update(parsed_headers)
                try:
                    if isinstance(parsed_result['body'], dict):
                        response._content = to_bytes(
                            json.dumps(parsed_result['body']))
                    else:
                        response._content = to_bytes(parsed_result['body'])
                except Exception:
                    # keep the body type consistent with the branches above (bytes)
                    response._content = b'{}'
                update_content_length(response)
                response.multi_value_headers = parsed_result.get(
                    'multiValueHeaders') or {}

            # apply custom response template
            response._content = apply_template(integration, 'response',
                                               response._content)
            response.headers['Content-Length'] = str(
                len(response.content or ''))

            return response

        msg = 'API Gateway AWS integration action URI "%s", method "%s" not yet implemented' % (
            uri, method)
        LOGGER.warning(msg)
        return make_error_response(msg, 404)

    elif integration_type == 'AWS':
        if 'kinesis:action/' in uri:
            target = None
            if uri.endswith('kinesis:action/PutRecords'):
                target = kinesis_listener.ACTION_PUT_RECORDS
            elif uri.endswith('kinesis:action/ListStreams'):
                target = kinesis_listener.ACTION_LIST_STREAMS

            # other Kinesis actions fall through to the "not yet implemented" error below
            if target is not None:
                template = integration['requestTemplates'][APPLICATION_JSON]
                new_request = aws_stack.render_velocity_template(template, data)
                # forward records to target kinesis stream
                headers = aws_stack.mock_aws_request_headers(service='kinesis')
                headers['X-Amz-Target'] = target
                result = common.make_http_request(url=TEST_KINESIS_URL,
                                                  method='POST',
                                                  data=new_request,
                                                  headers=headers)
                # TODO apply response template..?
                return result

        if method == 'POST':
            if uri.startswith('arn:aws:apigateway:') and ':sqs:path' in uri:
                template = integration['requestTemplates'][APPLICATION_JSON]
                account_id, queue = uri.split('/')[-2:]
                region_name = uri.split(':')[3]

                new_request = '%s&QueueName=%s' % (
                    aws_stack.render_velocity_template(template, data), queue)
                headers = aws_stack.mock_aws_request_headers(
                    service='sqs', region_name=region_name)

                url = urljoin(TEST_SQS_URL,
                              '%s/%s' % (TEST_AWS_ACCOUNT_ID, queue))
                result = common.make_http_request(url,
                                                  method='POST',
                                                  headers=headers,
                                                  data=new_request)
                return result

        msg = 'API Gateway AWS integration action URI "%s", method "%s" not yet implemented' % (
            uri, method)
        LOGGER.warning(msg)
        return make_error_response(msg, 404)

    elif integration_type == 'AWS_PROXY':
        if uri.startswith('arn:aws:apigateway:') and ':dynamodb:action' in uri:
            # arn:aws:apigateway:us-east-1:dynamodb:action/PutItem&Table=MusicCollection
            table_name = uri.split(':dynamodb:action')[1].split('&Table=')[1]
            action = uri.split(':dynamodb:action')[1].split('&Table=')[0]

            if 'PutItem' in action and method == 'PUT':
                response_template = path_map.get(relative_path, {}).get('resourceMethods', {})\
                    .get(method, {}).get('methodIntegration', {}).\
                    get('integrationResponses', {}).get('200', {}).get('responseTemplates', {})\
                    .get('application/json', None)

                if response_template is None:
                    msg = 'Invalid response template defined in integration response.'
                    return make_error_response(msg, 404)

                response_template = json.loads(response_template)
                if response_template['TableName'] != table_name:
                    msg = 'Invalid table name specified in integration response template.'
                    return make_error_response(msg, 404)

                dynamo_client = aws_stack.connect_to_resource('dynamodb')
                table = dynamo_client.Table(table_name)

                event_data = {}
                data_dict = json.loads(data)
                for key, _ in response_template['Item'].items():
                    event_data[key] = data_dict[key]

                table.put_item(Item=event_data)
                response = requests_response(
                    event_data, headers=aws_stack.mock_aws_request_headers())
                return response
        else:
            msg = 'API Gateway action uri "%s" not yet implemented' % uri
            LOGGER.warning(msg)
            return make_error_response(msg, 404)

    elif integration_type in ['HTTP_PROXY', 'HTTP']:
        function = getattr(requests, method.lower())

        # apply custom request template
        data = apply_template(integration, 'request', data)

        if isinstance(data, dict):
            data = json.dumps(data)

        result = function(integration['uri'], data=data, headers=headers)

        # apply custom response template to the backend response body; the body
        # and its length are only touched if the template actually changed it
        templated_content = apply_template(integration, 'response',
                                           result._content)
        if templated_content != result._content:
            result._content = templated_content
            result.headers['Content-Length'] = str(len(result.content or ''))

        return result

    elif integration_type == 'MOCK':
        # TODO: add logic for MOCK responses
        pass

    if method == 'OPTIONS':
        # fall back to returning CORS headers if this is an OPTIONS request
        return get_cors_response(headers)

    msg = (
        'API Gateway integration type "%s", method "%s", URI "%s" not yet implemented'
        % (integration['type'], method, integration.get('uri')))
    LOGGER.warning(msg)
    return make_error_response(msg, 404)
Ejemplo n.º 30
0
def invoke_rest_api(api_id, stage, method, invocation_path, data, headers, path=None):
    """Invoke the integration configured for a REST API resource and method.

    Splits the query string off ``invocation_path`` and resolves the API resource
    (including resources with path parameters). It then dispatches the request to
    the method integration of the resource, falling back to an ``ANY`` method:

    * ``POST`` to ``AWS`` integrations: Kinesis ``PutRecords`` or SQS
    * ``AWS_PROXY`` integrations with a ``:lambda:path`` URI: Lambda invocation
    * ``HTTP`` integrations: plain HTTP call to the integration URI

    A missing integration for an ``OPTIONS`` request with an ``Origin`` header is
    answered with CORS headers. Everything unsupported yields a 404 from
    ``make_error``.

    :param path: path used in error messages and passed to Lambda as
        ``resource_path``; defaults to ``invocation_path``.
    :returns: a response object.
    """
    path = path or invocation_path
    relative_path, query_string_params = extract_query_string_params(path=invocation_path)

    path_map = helpers.get_rest_api_paths(rest_api_id=api_id)
    try:
        extracted_path, resource = get_resource_for_path(path=relative_path, path_map=path_map)
    except Exception:
        return make_error('Unable to find path %s' % path, 404)

    integrations = resource.get('resourceMethods', {})
    integration = integrations.get(method, {})
    if not integration:
        integration = integrations.get('ANY', {})
    integration = integration.get('methodIntegration')
    if not integration:
        if method == 'OPTIONS' and 'Origin' in headers:
            # default to returning CORS headers if this is an OPTIONS request
            return get_cors_response(headers)
        return make_error('Unable to find integration for path %s' % path, 404)

    # an integration without URI falls through to the "not yet implemented" errors
    uri = integration.get('uri') or ''
    if method == 'POST' and integration['type'] == 'AWS':
        if uri.endswith('kinesis:action/PutRecords'):
            template = integration['requestTemplates'][APPLICATION_JSON]
            new_request = aws_stack.render_velocity_template(template, data)

            # forward records to target kinesis stream
            headers = aws_stack.mock_aws_request_headers(service='kinesis')
            headers['X-Amz-Target'] = kinesis_listener.ACTION_PUT_RECORDS
            result = common.make_http_request(url=TEST_KINESIS_URL,
                method='POST', data=new_request, headers=headers)
            return result

        elif uri.startswith('arn:aws:apigateway:') and ':sqs:path' in uri:
            template = integration['requestTemplates'][APPLICATION_JSON]
            account_id, queue = uri.split('/')[-2:]
            region_name = uri.split(':')[3]

            new_request = aws_stack.render_velocity_template(template, data) + '&QueueName=%s' % queue
            headers = aws_stack.mock_aws_request_headers(service='sqs', region_name=region_name)

            url = urljoin(TEST_SQS_URL, '%s/%s' % (account_id, queue))
            result = common.make_http_request(url, method='POST', headers=headers, data=new_request)
            return result

        else:
            msg = 'API Gateway action uri "%s" not yet implemented' % uri
            LOGGER.warning(msg)
            return make_error(msg, 404)

    elif integration['type'] == 'AWS_PROXY':
        if uri.startswith('arn:aws:apigateway:') and ':lambda:path' in uri:
            func_arn = uri.split(':lambda:path')[1].split('functions/')[1].split('/invocations')[0]
            data_str = json.dumps(data) if isinstance(data, (dict, list)) else data
            account_id = uri.split(':lambda:path')[1].split(':function:')[0].split(':')[-1]

            # use the second-to-last X-Forwarded-For hop as the client IP
            # (presumably the last hop is appended by the local edge proxy - TODO
            # confirm); tolerate a missing header or a single hop instead of raising
            forwarded_hops = [
                hop.strip()
                for hop in (headers.get('X-Forwarded-For') or '').split(',')
                if hop.strip()
            ]
            if len(forwarded_hops) > 1:
                source_ip = forwarded_hops[-2]
            else:
                source_ip = forwarded_hops[0] if forwarded_hops else ''

            # Sample request context:
            # https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-create-api-as-simple-proxy-for-lambda.html#api-gateway-create-api-as-simple-proxy-for-lambda-test
            request_context = {
                'path': relative_path,
                'accountId': account_id,
                'resourceId': resource.get('id'),
                'stage': stage,
                'identity': {
                    'accountId': account_id,
                    'sourceIp': source_ip,
                    'userAgent': headers.get('User-Agent'),
                }
            }

            try:
                path_params = extract_path_params(path=relative_path, extracted_path=extracted_path)
            except Exception:
                path_params = {}

            result = lambda_api.process_apigateway_invocation(func_arn, relative_path, data_str,
                headers, path_params=path_params, query_string_params=query_string_params,
                method=method, resource_path=path, request_context=request_context)

            if isinstance(result, FlaskResponse):
                return flask_to_requests_response(result)
            if isinstance(result, Response):
                return result

            response = Response()
            # json.loads accepts str and bytes; an empty/None result is treated as "{}"
            parsed_result = result if isinstance(result, dict) else json.loads(result or '{}')
            parsed_result = common.json_safe(parsed_result)
            parsed_result = {} if parsed_result is None else parsed_result
            response.status_code = int(parsed_result.get('statusCode', 200))
            # the Lambda result may carry "headers": null
            response.headers.update(parsed_result.get('headers') or {})
            try:
                if isinstance(parsed_result['body'], dict):
                    response._content = to_bytes(json.dumps(parsed_result['body']))
                else:
                    response._content = to_bytes(parsed_result['body'])
            except Exception:
                response._content = b'{}'
            # header values must be strings
            response.headers['Content-Length'] = str(len(response._content))
            return response
        else:
            msg = 'API Gateway action uri "%s" not yet implemented' % uri
            LOGGER.warning(msg)
            return make_error(msg, 404)

    elif integration['type'] == 'HTTP':
        function = getattr(requests, method.lower())
        if isinstance(data, dict):
            data = json.dumps(data)
        result = function(integration['uri'], data=data, headers=headers)
        return result

    else:
        msg = ('API Gateway integration type "%s" for method "%s" not yet implemented' %
               (integration['type'], method))
        LOGGER.warning(msg)
        return make_error(msg, 404)
Ejemplo n.º 31
0
def invoke_rest_api_integration(api_id,
                                stage,
                                integration,
                                method,
                                path,
                                invocation_path,
                                data,
                                headers,
                                resource_path,
                                context=None,
                                resource_id=None,
                                response_templates=None):
    """Invoke the backend integration configured for a REST API method.

    Dispatches on the integration URI and type: Lambda functions (``AWS`` /
    ``AWS_PROXY`` with a Lambda URI), ``AWS`` service actions (Kinesis,
    Step Functions ``StartExecution``, SQS via POST), the ``AWS_PROXY``
    DynamoDB ``PutItem`` action (PUT only), and ``HTTP`` / ``HTTP_PROXY``
    pass-through. ``OPTIONS`` requests not handled by any integration fall
    back to a CORS response.

    :param api_id: REST API id (used to look up stage variables).
    :param stage: deployment stage name.
    :param integration: integration definition; reads ``type`` /
        ``integrationType``, ``uri`` / ``integrationUri`` and
        ``requestTemplates``.
    :param method: HTTP method of the incoming request.
    :param path: request path (forwarded to Lambda as ``resource_path``).
    :param invocation_path: request path including the query string.
    :param data: request body.
    :param headers: request headers.
    :param resource_path: resource path template used to extract path params.
    :param context: event context forwarded to Lambda; defaults to ``{}``.
    :param resource_id: API resource id added to the Lambda request context.
    :param response_templates: integration response templates keyed by
        content type (used by the DynamoDB ``PutItem`` proxy); defaults to
        ``{}``.
    :return: the integration's response object, or an error response for
        unsupported or malformed integrations.
    """
    # None defaults avoid sharing one mutable dict between calls
    context = {} if context is None else context
    response_templates = response_templates or {}

    relative_path, query_string_params = extract_query_string_params(
        path=invocation_path)
    integration_type = integration.get('type') or integration.get(
        'integrationType')
    # some integrations (e.g. MOCK) carry no URI; use '' so the string
    # checks below do not raise AttributeError on None
    uri = integration.get('uri') or integration.get('integrationUri') or ''

    if (uri.startswith('arn:aws:apigateway:')
            and ':lambda:path' in uri) or uri.startswith('arn:aws:lambda'):
        if integration_type in ['AWS', 'AWS_PROXY']:
            func_arn = uri
            if ':lambda:path' in uri:
                func_arn = uri.split(':lambda:path')[1].split(
                    'functions/')[1].split('/invocations')[0]
            data_str = json.dumps(data) if isinstance(data,
                                                      (dict,
                                                       list)) else to_str(data)

            try:
                path_params = extract_path_params(path=relative_path,
                                                  extracted_path=resource_path)
            except Exception:
                path_params = {}

            # apply custom request template
            data_str = apply_template(integration,
                                      'request',
                                      data_str,
                                      path_params=path_params,
                                      query_params=query_string_params,
                                      headers=headers)

            # Sample request context:
            # https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-create-api-as-simple-proxy-for-lambda.html#api-gateway-create-api-as-simple-proxy-for-lambda-test
            request_context = get_lambda_event_request_context(
                method,
                path,
                data,
                headers,
                integration_uri=uri,
                resource_id=resource_id)
            stage_variables = get_stage_variables(api_id, stage)

            result = lambda_api.process_apigateway_invocation(
                func_arn,
                relative_path,
                data_str,
                stage,
                api_id,
                headers,
                path_params=path_params,
                query_string_params=query_string_params,
                method=method,
                resource_path=path,
                request_context=request_context,
                event_context=context,
                stage_variables=stage_variables)

            if isinstance(result, FlaskResponse):
                response = flask_to_requests_response(result)
            elif isinstance(result, Response):
                response = result
            else:
                response = LambdaResponse()
                if isinstance(result, dict):
                    parsed_result = result
                else:
                    raw_result = result or '{}'
                    # decode bytes explicitly: str(b'..') yields "b'..'",
                    # which is not valid JSON
                    if isinstance(raw_result, bytes):
                        raw_result = to_str(raw_result)
                    parsed_result = json.loads(str(raw_result))
                parsed_result = common.json_safe(parsed_result)
                parsed_result = {} if parsed_result is None else parsed_result
                response.status_code = int(parsed_result.get(
                    'statusCode', 200))
                parsed_headers = parsed_result.get('headers', {})
                if parsed_headers is not None:
                    response.headers.update(parsed_headers)
                try:
                    if isinstance(parsed_result['body'], dict):
                        response._content = json.dumps(parsed_result['body'])
                    else:
                        response._content = to_bytes(parsed_result['body'])
                except Exception:
                    response._content = '{}'
                update_content_length(response)
                response.multi_value_headers = parsed_result.get(
                    'multiValueHeaders') or {}

            # apply custom response template
            response._content = apply_template(integration, 'response',
                                               response._content)
            response.headers['Content-Length'] = str(
                len(response.content or ''))

            return response

        msg = 'API Gateway AWS integration action URI "%s", method "%s" not yet implemented' % (
            uri, method)
        LOGGER.warning(msg)
        return make_error_response(msg, 404)

    elif integration_type == 'AWS':
        if 'kinesis:action/' in uri:
            if uri.endswith('kinesis:action/PutRecords'):
                target = kinesis_listener.ACTION_PUT_RECORDS
            elif uri.endswith('kinesis:action/ListStreams'):
                target = kinesis_listener.ACTION_LIST_STREAMS
            else:
                # only the two Kinesis actions above are supported
                msg = 'API Gateway AWS integration action URI "%s", method "%s" not yet implemented' % (
                    uri, method)
                LOGGER.warning(msg)
                return make_error_response(msg, 404)

            template = integration['requestTemplates'][APPLICATION_JSON]
            new_request = aws_stack.render_velocity_template(template, data)
            # forward records to target kinesis stream
            headers = aws_stack.mock_aws_request_headers(service='kinesis')
            headers['X-Amz-Target'] = target
            result = common.make_http_request(url=TEST_KINESIS_URL,
                                              method='POST',
                                              data=new_request,
                                              headers=headers)
            # TODO apply response template..?
            return result

        elif 'states:action/' in uri:
            # every states action is served as StartExecution
            # to_str accepts both bytes and str request bodies
            decoded_data = to_str(data)
            request_templates = integration.get('requestTemplates') or {}
            payload = None
            if 'stateMachineArn' in decoded_data and 'input' in decoded_data:
                payload = json.loads(decoded_data)
            elif APPLICATION_JSON in request_templates:
                template = request_templates[APPLICATION_JSON]
                payload = aws_stack.render_velocity_template(template,
                                                             data,
                                                             as_json=True)
            if payload is None:
                msg = ('Unable to determine Step Functions StartExecution payload '
                       'for integration URI "%s"' % uri)
                LOGGER.warning(msg)
                return make_error_response(msg, 400)

            client = aws_stack.connect_to_service('stepfunctions')

            kwargs = {'name': payload['name']} if 'name' in payload else {}
            result = client.start_execution(
                stateMachineArn=payload['stateMachineArn'],
                input=payload['input'],
                **kwargs)
            response = requests_response(
                content={
                    'executionArn': result['executionArn'],
                    'startDate': str(result['startDate'])
                },
                headers=aws_stack.mock_aws_request_headers())
            return response

        if method == 'POST':
            if uri.startswith('arn:aws:apigateway:') and ':sqs:path' in uri:
                template = integration['requestTemplates'][APPLICATION_JSON]
                # the account id embedded in the URI is not used; the queue
                # URL is always built with TEST_AWS_ACCOUNT_ID
                _account_id, queue = uri.split('/')[-2:]
                region_name = uri.split(':')[3]

                new_request = '%s&QueueName=%s' % (
                    aws_stack.render_velocity_template(template, data), queue)
                headers = aws_stack.mock_aws_request_headers(
                    service='sqs', region_name=region_name)

                url = urljoin(TEST_SQS_URL,
                              '%s/%s' % (TEST_AWS_ACCOUNT_ID, queue))
                result = common.make_http_request(url,
                                                  method='POST',
                                                  headers=headers,
                                                  data=new_request)
                return result

        msg = 'API Gateway AWS integration action URI "%s", method "%s" not yet implemented' % (
            uri, method)
        LOGGER.warning(msg)
        return make_error_response(msg, 404)

    elif integration_type == 'AWS_PROXY':
        if uri.startswith('arn:aws:apigateway:') and ':dynamodb:action' in uri:
            # arn:aws:apigateway:us-east-1:dynamodb:action/PutItem&Table=MusicCollection
            table_name = uri.split(':dynamodb:action')[1].split('&Table=')[1]
            action = uri.split(':dynamodb:action')[1].split('&Table=')[0]

            if 'PutItem' in action and method == 'PUT':
                response_template = response_templates.get(
                    'application/json', None)

                if response_template is None:
                    msg = 'Invalid response template defined in integration response.'
                    return make_error_response(msg, 404)

                response_template = json.loads(response_template)
                if response_template['TableName'] != table_name:
                    msg = 'Invalid table name specified in integration response template.'
                    return make_error_response(msg, 404)

                dynamo_client = aws_stack.connect_to_resource('dynamodb')
                table = dynamo_client.Table(table_name)

                # copy only the attributes named in the template's Item
                data_dict = json.loads(data)
                event_data = {
                    key: data_dict[key] for key in response_template['Item']
                }

                table.put_item(Item=event_data)
                response = requests_response(
                    event_data, headers=aws_stack.mock_aws_request_headers())
                return response
        else:
            msg = 'API Gateway action uri "%s", integration type %s not yet implemented' % (
                uri, integration_type)
            LOGGER.warning(msg)
            return make_error_response(msg, 404)

    elif integration_type in ['HTTP_PROXY', 'HTTP']:
        function = getattr(requests, method.lower())

        # apply custom request template
        data = apply_template(integration, 'request', data)

        if isinstance(data, dict):
            data = json.dumps(data)

        result = function(uri, data=data, headers=headers)

        # apply the custom response template to the backend response body;
        # proxy integrations return the backend response unmodified
        if integration_type == 'HTTP':
            templated = apply_template(integration, 'response', result.content)
            if templated != result.content:
                result._content = templated
                result.headers['Content-Length'] = str(len(templated or ''))

        return result

    elif integration_type == 'MOCK':
        # TODO: add logic for MOCK responses
        pass

    if method == 'OPTIONS':
        # fall back to returning CORS headers if this is an OPTIONS request
        return get_cors_response(headers)

    msg = (
        'API Gateway integration type "%s", method "%s", URI "%s" not yet implemented'
        % (integration_type, method, uri))
    LOGGER.warning(msg)
    return make_error_response(msg, 404)