Example #1
def check_s3(expect_shutdown=False, print_error=False):
    out = None
    try:
        # wait for port to be opened
        wait_for_port_open(DEFAULT_PORT_S3_BACKEND)
        # check S3
        out = aws_stack.connect_to_service(service_name='s3').list_buckets()
    except Exception as e:
        if print_error:
            LOGGER.error('S3 health check failed: %s %s' % (e, traceback.format_exc()))
    if expect_shutdown:
        assert out is None
    else:
        assert isinstance(out['Buckets'], list)
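Both health checks here block on wait_for_port_open before calling the service. The real helper lives in localstack's utility module and also supports HTTP-level options (http_path, expect_success, sleep_time, as seen in later examples); as a rough mental model only, not the actual implementation, it behaves like a TCP-connect retry loop. A minimal, purely illustrative sketch:

import socket
import time

def wait_for_port_open_sketch(port, host="127.0.0.1", retries=10, sleep_time=0.5):
    """Hypothetical stand-in for localstack's wait_for_port_open: retry a TCP
    connection until it succeeds or the retries are exhausted."""
    last_error = None
    for _ in range(retries):
        try:
            with socket.create_connection((host, port), timeout=1):
                return  # port accepts connections
        except OSError as e:  # connection refused / timed out
            last_error = e
            time.sleep(sleep_time)
    raise TimeoutError(f"port {port} not reachable: {last_error}")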
Example #2
def check_dynamodb(expect_shutdown=False, print_error=False):
    out = None
    try:
        # wait for port to be opened
        wait_for_port_open(DEFAULT_PORT_DYNAMODB_BACKEND)
        # check DynamoDB
        out = aws_stack.connect_to_service(service_name='dynamodb').list_tables()
    except Exception as e:
        if print_error:
            LOGGER.error('DynamoDB health check failed: %s %s' % (e, traceback.format_exc()))
    if expect_shutdown:
        assert out is None
    else:
        assert isinstance(out['TableNames'], list)
Example #3
    def test_redrive_policy_http_subscription(self):
        self.unsubscribe_all_from_sns()

        # create HTTP endpoint and connect it to SNS topic
        class MyUpdateListener(ProxyListener):
            def forward_request(self, method, path, data, headers):
                records.append((json.loads(to_str(data)), headers))
                return 200

        records = []
        local_port = get_free_tcp_port()
        proxy = start_proxy(local_port,
                            backend_url=None,
                            update_listener=MyUpdateListener())
        wait_for_port_open(local_port)
        http_endpoint = '%s://localhost:%s' % (get_service_protocol(),
                                               local_port)

        subscription = self.sns_client.subscribe(TopicArn=self.topic_arn,
                                                 Protocol='http',
                                                 Endpoint=http_endpoint)
        self.sns_client.set_subscription_attributes(
            SubscriptionArn=subscription['SubscriptionArn'],
            AttributeName='RedrivePolicy',
            AttributeValue=json.dumps({
                'deadLetterTargetArn':
                aws_stack.sqs_queue_arn(TEST_QUEUE_DLQ_NAME)
            }))

        proxy.stop()

        self.sns_client.publish(TopicArn=self.topic_arn,
                                Message=json.dumps(
                                    {'message': 'test_redrive_policy'}))

        def receive_dlq():
            result = self.sqs_client.receive_message(
                QueueUrl=self.dlq_url, MessageAttributeNames=['All'])
            self.assertGreater(len(result['Messages']), 0)
            self.assertEqual(
                json.loads(
                    json.loads(result['Messages'][0]['Body'])['Message'][0])
                ['message'], 'test_redrive_policy')

        retry(receive_dlq, retries=10, sleep=2)
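The test above polls the DLQ through retry(receive_dlq, retries=10, sleep=2); the same helper appears throughout the later examples, sometimes with extra keyword arguments forwarded to the callable (e.g. q_url in Example #18). A minimal sketch of the assumed behavior, not the actual localstack helper:

import time

def retry_sketch(function, retries=3, sleep=1.0, **kwargs):
    """Hypothetical stand-in for localstack's retry helper: re-invoke the
    callable until it returns without raising, or re-raise after the last try."""
    for attempt in range(retries):
        try:
            return function(**kwargs)
        except Exception:
            if attempt == retries - 1:
                raise
            time.sleep(sleep)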
Example #4
def start_kms_local(port=None, backend_port=None, asynchronous=None, update_listener=None):
    port = port or config.PORT_KMS
    backend_port = get_free_tcp_port()
    kms_binary = INSTALL_PATH_KMS_BINARY_PATTERN.replace("<arch>", get_os())
    log_startup_message("KMS")
    start_proxy_for_service("kms", port, backend_port, update_listener)
    env_vars = {
        "PORT": str(backend_port),
        "KMS_REGION": config.DEFAULT_REGION,
        "REGION": config.DEFAULT_REGION,
        "KMS_ACCOUNT_ID": TEST_AWS_ACCOUNT_ID,
        "ACCOUNT_ID": TEST_AWS_ACCOUNT_ID,
    }
    if config.dirs.data:
        env_vars["KMS_DATA_PATH"] = config.dirs.data
    result = do_run(kms_binary, asynchronous, env_vars=env_vars)
    wait_for_port_open(backend_port)
    return result
Example #5
def start_kms(port=None, backend_port=None, asynchronous=None, update_listener=None):
    port = port or config.PORT_KMS
    backend_port = get_free_tcp_port()
    kms_binary = INSTALL_PATH_KMS_BINARY_PATTERN.replace('<arch>', get_arch())
    log_startup_message('KMS')
    start_proxy_for_service('kms', port, backend_port, update_listener)
    env_vars = {
        'PORT': str(backend_port),
        'KMS_REGION': config.DEFAULT_REGION,
        'REGION': config.DEFAULT_REGION,
        'KMS_ACCOUNT_ID': TEST_AWS_ACCOUNT_ID,
        'ACCOUNT_ID': TEST_AWS_ACCOUNT_ID
    }
    if config.DATA_DIR:
        env_vars['KMS_DATA_PATH'] = config.DATA_DIR
    result = do_run(kms_binary, asynchronous, env_vars=env_vars)
    wait_for_port_open(backend_port)
    return result
Example #6
def check_dynamodb(expect_shutdown=False, print_error=False):
    out = None
    try:
        # wait for backend port to be opened
        wait_for_port_open(PORT_DYNAMODB_BACKEND,
                           http_path='/',
                           expect_success=False,
                           sleep_time=1)
        # check DynamoDB
        out = aws_stack.connect_to_service('dynamodb').list_tables()
    except Exception as e:
        if print_error:
            LOGGER.error('DynamoDB health check failed: %s %s' %
                         (e, traceback.format_exc()))
    if expect_shutdown:
        assert out is None
    else:
        assert isinstance(out['TableNames'], list)
Example #7
def start_kms(port=None,
              backend_port=None,
              asynchronous=None,
              update_listener=None):
    port = port or config.PORT_KMS
    backend_port = get_free_tcp_port()
    kms_binary = INSTALL_PATH_KMS_BINARY_PATTERN.replace('<arch>', get_arch())
    print('Starting mock KMS service on %s ...' % edge_ports_info())
    start_proxy_for_service('kms', port, backend_port, update_listener)
    env_vars = {
        'PORT': str(backend_port),
        'KMS_REGION': config.DEFAULT_REGION,
        'REGION': config.DEFAULT_REGION,
        'KMS_ACCOUNT_ID': TEST_AWS_ACCOUNT_ID,
        'ACCOUNT_ID': TEST_AWS_ACCOUNT_ID
    }
    result = do_run(kms_binary, asynchronous, env_vars=env_vars)
    wait_for_port_open(backend_port)
    return result
Example #8
def start_server_process(port):
    if '__server__' in API_SERVERS:
        return API_SERVERS['__server__']['thread']
    port = port or MULTI_SERVER_PORT
    API_SERVERS['__server__'] = config = {'port': port}
    LOG.info('Starting multi API server process on port %s' % port)
    if RUN_SERVER_IN_PROCESS:
        cmd = '"%s" "%s" %s' % (sys.executable, __file__, port)
        env_vars = {
            'PYTHONPATH': '.:%s' % constants.LOCALSTACK_ROOT_FOLDER
        }
        thread = ShellCommandThread(cmd, outfile=subprocess.PIPE, env_vars=env_vars, inherit_cwd=True)
        thread.start()
    else:
        thread = start_server(port, asynchronous=True)

    TMP_THREADS.append(thread)
    config['thread'] = thread
    wait_for_port_open(port, retries=20, sleep_time=1)
    return thread
Example #9
    def test_subscribe_http_endpoint(self):
        # create HTTP endpoint and connect it to SNS topic
        class MyUpdateListener(ProxyListener):
            def forward_request(self, method, path, data, headers):
                records.append((json.loads(to_str(data)), headers))
                return 200

        records = []
        local_port = get_free_tcp_port()
        proxy = start_proxy(local_port, backend_url=None, update_listener=MyUpdateListener())
        wait_for_port_open(local_port)
        queue_arn = '%s://localhost:%s' % (get_service_protocol(), local_port)
        self.sns_client.subscribe(TopicArn=self.topic_arn, Protocol='http', Endpoint=queue_arn)

        def received():
            assert records[0][0]['Type'] == 'SubscriptionConfirmation'
            assert records[0][1]['x-amz-sns-message-type'] == 'SubscriptionConfirmation'

        retry(received, retries=5, sleep=1)
        proxy.stop()
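Examples #3 and #9 share one pattern: stand up a throwaway local HTTP endpoint with ProxyListener/start_proxy, subscribe it to an SNS topic, and assert on the captured requests. If the localstack proxy helpers were not available, an equivalent capture endpoint could be built with only the standard library; the names below are illustrative and not the ProxyListener API:

import json
import threading
from http.server import BaseHTTPRequestHandler, HTTPServer

records = []  # (parsed_body, headers) tuples, as in the tests above

class CaptureHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        # read and record the request body plus headers, then acknowledge
        body = self.rfile.read(int(self.headers.get("Content-Length", 0)))
        records.append((json.loads(body or b"{}"), dict(self.headers)))
        self.send_response(200)
        self.end_headers()

def start_capture_server(port):
    server = HTTPServer(("127.0.0.1", port), CaptureHandler)
    threading.Thread(target=server.serve_forever, daemon=True).start()
    return server  # server.shutdown() plays the role of proxy.stop()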
Example #10
def check_lambda(expect_shutdown=False, print_error=False):
    out = None
    try:
        from localstack.services.infra import PROXY_LISTENERS
        from localstack.utils.aws import aws_stack
        from localstack.utils.common import wait_for_port_open

        # wait for port to be opened
        port = PROXY_LISTENERS.get("lambda")[1]
        wait_for_port_open(port)  # TODO get lambda port in a cleaner way

        endpoint_url = f"http://127.0.0.1:{port}"
        out = aws_stack.connect_to_service(
            service_name="lambda", endpoint_url=endpoint_url).list_functions()
    except Exception:
        if print_error:
            LOG.exception("Lambda health check failed")
    if expect_shutdown:
        assert out is None
    else:
        assert out and isinstance(out.get("Functions"), list)
Example #11
def start_server_process(port):
    if "__server__" in API_SERVERS:
        return API_SERVERS["__server__"]["thread"]
    port = port or get_multi_server_port()
    API_SERVERS["__server__"] = config = {"port": port}
    LOG.info("Starting multi API server process on port %s" % port)
    if RUN_SERVER_IN_PROCESS:
        cmd = '"%s" "%s" %s' % (sys.executable, __file__, port)
        env_vars = {"PYTHONPATH": ".:%s" % constants.LOCALSTACK_ROOT_FOLDER}
        thread = ShellCommandThread(cmd,
                                    outfile=subprocess.PIPE,
                                    env_vars=env_vars,
                                    inherit_cwd=True)
        thread.start()
    else:
        thread = start_server(port, asynchronous=True)

    TMP_THREADS.append(thread)
    config["thread"] = thread
    wait_for_port_open(port, retries=20, sleep_time=1)
    return thread
Example #12
def test_ssl_proxy_server():
    class MyListener(ProxyListener):
        def forward_request(self, *args, **kwargs):
            invocations.append((args, kwargs))
            return {"foo": "bar"}

    invocations = []

    # start SSL-enabled backend server
    listener = MyListener()
    port = get_free_tcp_port()
    server = start_proxy_server(port, update_listener=listener, use_ssl=True)
    wait_for_port_open(port)

    # start SSL proxy
    proxy_port = get_free_tcp_port()
    proxy = start_ssl_proxy(proxy_port, port, asynchronous=True, fix_encoding=True)
    wait_for_port_open(proxy_port)

    # invoke SSL proxy server
    url = f"https://{LOCALHOST_HOSTNAME}:{proxy_port}"
    num_requests = 3
    for i in range(num_requests):
        response = requests.get(url, verify=False)
        assert response.status_code == 200

    # assert backend server has been invoked
    assert len(invocations) == num_requests

    # invoke SSL proxy server with gzip response
    for encoding in ["gzip", "gzip, deflate"]:
        headers = {HEADER_ACCEPT_ENCODING: encoding}
        response = requests.get(url, headers=headers, verify=False, stream=True)
        result = response.raw.read()
        assert to_str(gzip.decompress(result)) == json.dumps({"foo": "bar"})

    # clean up
    proxy.stop()
    server.stop()
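Nearly every example allocates its listener port with get_free_tcp_port. Assuming the usual technique (bind to port 0 and let the OS pick an unused port), a minimal illustrative sketch, not the actual localstack helper, is:

import socket

def get_free_tcp_port_sketch():
    """Hypothetical equivalent of get_free_tcp_port: bind to port 0 so the OS
    assigns a free port, then release the socket and return the port number."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("127.0.0.1", 0))
        return sock.getsockname()[1]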
Example #13
def check_secretsmanager(expect_shutdown=False, print_error=False):
    out = None

    # noinspection PyBroadException
    try:
        wait_for_port_open(PORT_SECRETS_MANAGER_BACKEND,
                           http_path="/",
                           expect_success=False)
        endpoint_url = f"http://127.0.0.1:{PORT_SECRETS_MANAGER_BACKEND}"
        out = aws_stack.connect_to_service(
            service_name="secretsmanager",
            endpoint_url=endpoint_url).list_secrets()
    except Exception:
        if print_error:
            logger = logging.getLogger(__name__)
            logger.exception("Secretsmanager health check failed")

    if expect_shutdown:
        assert out is None
        return

    assert isinstance(out["SecretList"], list)
Example #14
    def test_static_route(self):
        class MyListener(ProxyListener):
            def forward_request(self, method, path, *args, **kwargs):
                return {"method": method, "path": path}

        # start proxy server
        listener = MyListener()
        port = get_free_tcp_port()
        server = start_proxy_server(port, update_listener=listener)
        wait_for_port_open(port)

        # request a /static/... path from the server and assert result
        url = f"http://{LOCALHOST_HOSTNAME}:{port}/static/index.html"
        response = requests.get(url, verify=False)
        assert response.ok
        assert json.loads(to_str(response.content)) == {
            "method": "GET",
            "path": "/static/index.html",
        }

        # clean up
        server.stop()
Example #15
def check_kinesis(expect_shutdown=False, print_error=False):
    if expect_shutdown is False and kinesis_stopped.is_set():
        raise AssertionError("kinesis backend has stopped")

    out = None
    try:
        # check Kinesis
        wait_for_port_open(PORT_KINESIS_BACKEND,
                           http_path="/",
                           expect_success=False,
                           sleep_time=1)
        endpoint_url = f"http://127.0.0.1:{PORT_KINESIS_BACKEND}"
        out = aws_stack.connect_to_service(
            service_name="kinesis", endpoint_url=endpoint_url).list_streams()
    except Exception:
        if print_error:
            LOGGER.exception("Kinesis health check failed")

    if expect_shutdown:
        assert out is None or kinesis_stopped.is_set()
    else:
        assert not kinesis_stopped.is_set()
        assert out and isinstance(out.get("StreamNames"), list)
Example #16
    def test_subscribe_http_endpoint(self):
        # create HTTP endpoint and connect it to SNS topic
        class MyUpdateListener(ProxyListener):
            def forward_request(self, method, path, data, headers):
                records.append((json.loads(to_str(data)), headers))
                return 200

        records = []
        local_port = get_free_tcp_port()
        proxy = start_proxy(local_port,
                            backend_url=None,
                            update_listener=MyUpdateListener())
        wait_for_port_open(local_port)
        queue_arn = "%s://localhost:%s" % (get_service_protocol(), local_port)
        self.sns_client.subscribe(TopicArn=self.topic_arn,
                                  Protocol="http",
                                  Endpoint=queue_arn)

        def received():
            self.assertEqual(records[0][0]["Type"], "SubscriptionConfirmation")
            self.assertEqual(records[0][1]["x-amz-sns-message-type"],
                             "SubscriptionConfirmation")

            token = records[0][0]["Token"]
            subscribe_url = records[0][0]["SubscribeURL"]

            self.assertEqual(
                subscribe_url,
                "%s/?Action=ConfirmSubscription&TopicArn=%s&Token=%s" %
                (external_service_url("sns"), self.topic_arn, token),
            )

            self.assertIn("Signature", records[0][0])
            self.assertIn("SigningCertURL", records[0][0])

        retry(received, retries=5, sleep=1)
        proxy.stop()
Example #17
    def test_multiple_subscriptions_http_endpoint(self):
        self.unsubscribe_all_from_sns()

        # create HTTP endpoint and connect it to SNS topic
        class MyUpdateListener(ProxyListener):
            def forward_request(self, method, path, data, headers):
                records.append((json.loads(to_str(data)), headers))
                return 429

        number_of_subscriptions = 4
        records = []
        proxies = []

        for _ in range(number_of_subscriptions):
            local_port = get_free_tcp_port()
            proxies.append(
                start_proxy(local_port,
                            backend_url=None,
                            update_listener=MyUpdateListener()))
            wait_for_port_open(local_port)
            http_endpoint = "%s://localhost:%s" % (get_service_protocol(),
                                                   local_port)
            self.sns_client.subscribe(TopicArn=self.topic_arn,
                                      Protocol="http",
                                      Endpoint=http_endpoint)

        # fetch subscription information
        subscription_list = self.sns_client.list_subscriptions()
        self.assertEqual(
            subscription_list["ResponseMetadata"]["HTTPStatusCode"], 200)
        self.assertEqual(len(subscription_list["Subscriptions"]),
                         number_of_subscriptions)

        self.assertEqual(number_of_subscriptions, len(records))

        for proxy in proxies:
            proxy.stop()
Example #18
    def test_scheduled_expression_events(self):
        class HttpEndpointListener(ProxyListener):
            def forward_request(self, method, path, data, headers):
                event = json.loads(to_str(data))
                events.append(event)
                return 200

        local_port = get_free_tcp_port()
        proxy = start_proxy(local_port,
                            backend_url=None,
                            update_listener=HttpEndpointListener())
        wait_for_port_open(local_port)

        topic_name = 'topic-{}'.format(short_uid())
        queue_name = 'queue-{}'.format(short_uid())
        rule_name = 'rule-{}'.format(short_uid())
        endpoint = '{}://{}:{}'.format(get_service_protocol(),
                                       config.LOCALSTACK_HOSTNAME, local_port)
        sm_role_arn = aws_stack.role_arn('sfn_role')
        sm_name = 'state-machine-{}'.format(short_uid())
        topic_target_id = 'target-{}'.format(short_uid())
        sm_target_id = 'target-{}'.format(short_uid())
        queue_target_id = 'target-{}'.format(short_uid())

        events = []
        state_machine_definition = """
        {
            "StartAt": "Hello",
            "States": {
                "Hello": {
                    "Type": "Pass",
                    "Result": "World",
                    "End": true
                }
            }
        }
        """

        state_machine_arn = self.sfn_client.create_state_machine(
            name=sm_name,
            definition=state_machine_definition,
            roleArn=sm_role_arn)['stateMachineArn']

        topic_arn = self.sns_client.create_topic(Name=topic_name)['TopicArn']
        self.sns_client.subscribe(TopicArn=topic_arn,
                                  Protocol='http',
                                  Endpoint=endpoint)

        queue_url = self.sqs_client.create_queue(
            QueueName=queue_name)['QueueUrl']
        queue_arn = aws_stack.sqs_queue_arn(queue_name)

        event = {'env': 'testing'}

        self.events_client.put_rule(Name=rule_name,
                                    ScheduleExpression='rate(1 minutes)')

        self.events_client.put_targets(Rule=rule_name,
                                       Targets=[{
                                           'Id': topic_target_id,
                                           'Arn': topic_arn,
                                           'Input': json.dumps(event)
                                       }, {
                                           'Id': sm_target_id,
                                           'Arn': state_machine_arn,
                                           'Input': json.dumps(event)
                                       }, {
                                           'Id': queue_target_id,
                                           'Arn': queue_arn,
                                           'Input': json.dumps(event)
                                       }])

        def received(q_url):
            # state machine got executed
            executions = self.sfn_client.list_executions(
                stateMachineArn=state_machine_arn)['executions']
            self.assertGreaterEqual(len(executions), 1)

            # http endpoint got events
            self.assertGreaterEqual(len(events), 2)
            notifications = [
                event['Message'] for event in events
                if event['Type'] == 'Notification'
            ]
            self.assertGreaterEqual(len(notifications), 1)

            # get state machine execution detail
            execution_arn = executions[0]['executionArn']
            execution_input = self.sfn_client.describe_execution(
                executionArn=execution_arn)['input']

            # get message from queue
            msgs = self.sqs_client.receive_message(QueueUrl=q_url).get(
                'Messages', [])
            self.assertGreaterEqual(len(msgs), 1)

            return execution_input, notifications[0], msgs[0]

        execution_input, notification, msg_received = retry(received,
                                                            retries=5,
                                                            sleep=15,
                                                            q_url=queue_url)
        self.assertEqual(json.loads(notification), event)
        self.assertEqual(json.loads(execution_input), event)
        self.assertEqual(json.loads(msg_received['Body']), event)

        proxy.stop()

        self.events_client.remove_targets(Rule=rule_name,
                                          Ids=[topic_target_id, sm_target_id],
                                          Force=True)
        self.events_client.delete_rule(Name=rule_name, Force=True)

        self.sns_client.delete_topic(TopicArn=topic_arn)
        self.sfn_client.delete_state_machine(stateMachineArn=state_machine_arn)

        self.sqs_client.delete_queue(QueueUrl=queue_url)
Example #19
    def test_api_destinations(self, events_client):

        token = short_uid()
        bearer = "Bearer %s" % token

        class HttpEndpointListener(ProxyListener):
            def forward_request(self, method, path, data, headers):
                event = json.loads(to_str(data))
                events.append(event)
                paths_list.append(path)
                auth = headers.get("Api") or headers.get("Authorization")
                if auth not in headers_list:
                    headers_list.append(auth)

                if headers.get("target_header"):
                    headers_list.append(headers.get("target_header"))

                if "client_id" in event:
                    oauth_data.update(
                        {
                            "client_id": event.get("client_id"),
                            "client_secret": event.get("client_secret"),
                            "header_value": headers.get("oauthheader"),
                            "body_value": event.get("oauthbody"),
                            "path": path,
                        }
                    )

                return requests_response(
                    {
                        "access_token": token,
                        "token_type": "Bearer",
                        "expires_in": 86400,
                    }
                )

        events = []
        paths_list = []
        headers_list = []
        oauth_data = {}

        local_port = get_free_tcp_port()
        proxy = start_proxy(local_port, update_listener=HttpEndpointListener())
        wait_for_port_open(local_port)
        url = f"http://localhost:{local_port}"

        auth_types = [
            {
                "type": "BASIC",
                "key": "BasicAuthParameters",
                "parameters": {"Username": "******", "Password": "******"},
            },
            {
                "type": "API_KEY",
                "key": "ApiKeyAuthParameters",
                "parameters": {"ApiKeyName": "Api", "ApiKeyValue": "apikey_secret"},
            },
            {
                "type": "OAUTH_CLIENT_CREDENTIALS",
                "key": "OAuthParameters",
                "parameters": {
                    "AuthorizationEndpoint": url,
                    "ClientParameters": {"ClientID": "id", "ClientSecret": "password"},
                    "HttpMethod": "put",
                    "OAuthHttpParameters": {
                        "BodyParameters": [{"Key": "oauthbody", "Value": "value1"}],
                        "HeaderParameters": [{"Key": "oauthheader", "Value": "value2"}],
                        "QueryStringParameters": [{"Key": "oauthquery", "Value": "value3"}],
                    },
                },
            },
        ]

        for auth in auth_types:
            connection_name = "c-%s" % short_uid()
            connection_arn = events_client.create_connection(
                Name=connection_name,
                AuthorizationType=auth.get("type"),
                AuthParameters={
                    auth.get("key"): auth.get("parameters"),
                    "InvocationHttpParameters": {
                        "BodyParameters": [
                            {"Key": "key", "Value": "value", "IsValueSecret": False}
                        ],
                        "HeaderParameters": [
                            {"Key": "key", "Value": "value", "IsValueSecret": False}
                        ],
                        "QueryStringParameters": [
                            {"Key": "key", "Value": "value", "IsValueSecret": False}
                        ],
                    },
                },
            )["ConnectionArn"]

            # create api destination
            dest_name = "d-%s" % short_uid()
            result = events_client.create_api_destination(
                Name=dest_name,
                ConnectionArn=connection_arn,
                InvocationEndpoint=url,
                HttpMethod="POST",
            )

            # create rule and target
            rule_name = "r-%s" % short_uid()
            target_id = "target-{}".format(short_uid())
            pattern = json.dumps({"source": ["source-123"], "detail-type": ["type-123"]})
            events_client.put_rule(Name=rule_name, EventPattern=pattern)
            events_client.put_targets(
                Rule=rule_name,
                Targets=[
                    {
                        "Id": target_id,
                        "Arn": result["ApiDestinationArn"],
                        "Input": '{"target_value":"value"}',
                        "HttpParameters": {
                            "PathParameterValues": ["target_path"],
                            "HeaderParameters": {"target_header": "target_header_value"},
                            "QueryStringParameters": {"target_query": "t_query"},
                        },
                    }
                ],
            )

            entries = [
                {
                    "Source": "source-123",
                    "DetailType": "type-123",
                    "Detail": '{"i": %s}' % 0,
                }
            ]
            events_client.put_events(Entries=entries)

            # cleaning
            events_client.delete_connection(Name=connection_name)
            events_client.delete_api_destination(Name=dest_name)
            events_client.delete_rule(Name=rule_name, Force=True)

        # assert that all events have been received in the HTTP server listener

        def check():
            assert len(events) >= len(auth_types)
            assert "key" in paths_list[0] and "value" in paths_list[0]
            assert "target_query" in paths_list[0] and "t_query" in paths_list[0]
            assert "target_path" in paths_list[0]
            assert events[0].get("key") == "value"
            assert events[0].get("target_value") == "value"

            assert oauth_data.get("client_id") == "id"
            assert oauth_data.get("client_secret") == "password"
            assert oauth_data.get("header_value") == "value2"
            assert oauth_data.get("body_value") == "value1"
            assert "oauthquery" in oauth_data.get("path")
            assert "value3" in oauth_data.get("path")

            user_pass = to_str(base64.b64encode(b"user:pass"))
            assert f"Basic {user_pass}" in headers_list
            assert "apikey_secret" in headers_list
            assert bearer in headers_list
            assert "target_header_value" in headers_list

        retry(check, sleep=0.5, retries=5)

        # clean up
        proxy.stop()
Example #20
    def test_scheduled_expression_events(
        self, stepfunctions_client, sns_client, sqs_client, events_client
    ):
        class HttpEndpointListener(ProxyListener):
            def forward_request(self, method, path, data, headers):
                event = json.loads(to_str(data))
                events.append(event)
                return 200

        local_port = get_free_tcp_port()
        proxy = start_proxy(local_port, update_listener=HttpEndpointListener())
        wait_for_port_open(local_port)

        topic_name = "topic-{}".format(short_uid())
        queue_name = "queue-{}".format(short_uid())
        fifo_queue_name = "queue-{}.fifo".format(short_uid())
        rule_name = "rule-{}".format(short_uid())
        endpoint = "{}://{}:{}".format(
            get_service_protocol(), config.LOCALSTACK_HOSTNAME, local_port
        )
        sm_role_arn = aws_stack.role_arn("sfn_role")
        sm_name = "state-machine-{}".format(short_uid())
        topic_target_id = "target-{}".format(short_uid())
        sm_target_id = "target-{}".format(short_uid())
        queue_target_id = "target-{}".format(short_uid())
        fifo_queue_target_id = "target-{}".format(short_uid())

        events = []
        state_machine_definition = """
        {
            "StartAt": "Hello",
            "States": {
                "Hello": {
                    "Type": "Pass",
                    "Result": "World",
                    "End": true
                }
            }
        }
        """

        state_machine_arn = stepfunctions_client.create_state_machine(
            name=sm_name, definition=state_machine_definition, roleArn=sm_role_arn
        )["stateMachineArn"]

        topic_arn = sns_client.create_topic(Name=topic_name)["TopicArn"]
        sns_client.subscribe(TopicArn=topic_arn, Protocol="http", Endpoint=endpoint)

        queue_url = sqs_client.create_queue(QueueName=queue_name)["QueueUrl"]
        fifo_queue_url = sqs_client.create_queue(
            QueueName=fifo_queue_name,
            Attributes={"FifoQueue": "true", "ContentBasedDeduplication": "true"},
        )["QueueUrl"]
        queue_arn = aws_stack.sqs_queue_arn(queue_name)
        fifo_queue_arn = aws_stack.sqs_queue_arn(fifo_queue_name)

        event = {"env": "testing"}

        events_client.put_rule(Name=rule_name, ScheduleExpression="rate(1 minutes)")

        events_client.put_targets(
            Rule=rule_name,
            Targets=[
                {"Id": topic_target_id, "Arn": topic_arn, "Input": json.dumps(event)},
                {
                    "Id": sm_target_id,
                    "Arn": state_machine_arn,
                    "Input": json.dumps(event),
                },
                {"Id": queue_target_id, "Arn": queue_arn, "Input": json.dumps(event)},
                {
                    "Id": fifo_queue_target_id,
                    "Arn": fifo_queue_arn,
                    "Input": json.dumps(event),
                    "SqsParameters": {"MessageGroupId": "123"},
                },
            ],
        )

        def received(q_urls):
            # state machine got executed
            executions = stepfunctions_client.list_executions(stateMachineArn=state_machine_arn)[
                "executions"
            ]
            assert len(executions) >= 1

            # http endpoint got events
            assert len(events) >= 2
            notifications = [
                event["Message"] for event in events if event["Type"] == "Notification"
            ]
            assert len(notifications) >= 1

            # get state machine execution detail
            execution_arn = executions[0]["executionArn"]
            execution_input = stepfunctions_client.describe_execution(executionArn=execution_arn)[
                "input"
            ]

            all_msgs = []
            # get message from queue
            for url in q_urls:
                msgs = sqs_client.receive_message(QueueUrl=url).get("Messages", [])
                assert len(msgs) >= 1
                all_msgs.append(msgs[0])

            return execution_input, notifications[0], all_msgs

        execution_input, notification, msgs_received = retry(
            received, retries=5, sleep=15, q_urls=[queue_url, fifo_queue_url]
        )
        assert json.loads(notification) == event
        assert json.loads(execution_input) == event
        for msg_received in msgs_received:
            assert json.loads(msg_received["Body"]) == event

        # clean up
        proxy.stop()
        self.cleanup(
            None, rule_name, target_ids=[topic_target_id, sm_target_id], queue_url=queue_url
        )
        sns_client.delete_topic(TopicArn=topic_arn)
        stepfunctions_client.delete_state_machine(stateMachineArn=state_machine_arn)
Example #21
def test_firehose_http(lambda_processor_enabled: bool):
    class MyUpdateListener(ProxyListener):
        def forward_request(self, method, path, data, headers):
            data_received = dict(json.loads(data.decode("utf-8")))
            records.append(data_received)
            return 200

    if lambda_processor_enabled:
        # create processor func
        func_name = f"proc-{short_uid()}"
        testutil.create_lambda_function(handler_file=PROCESSOR_LAMBDA,
                                        func_name=func_name)

    # define firehose configs
    local_port = get_free_tcp_port()
    endpoint = "{}://{}:{}".format(get_service_protocol(),
                                   config.LOCALSTACK_HOSTNAME, local_port)
    records = []
    http_destination_update = {
        "EndpointConfiguration": {
            "Url": endpoint,
            "Name": "test_update"
        }
    }
    http_destination = {
        "EndpointConfiguration": {
            "Url": endpoint
        },
        "S3BackupMode": "FailedDataOnly",
        "S3Configuration": {
            "RoleARN": "arn:.*",
            "BucketARN": "arn:.*",
            "Prefix": "",
            "ErrorOutputPrefix": "",
            "BufferingHints": {
                "SizeInMBs": 1,
                "IntervalInSeconds": 60
            },
        },
    }

    if lambda_processor_enabled:
        http_destination["ProcessingConfiguration"] = {
            "Enabled":
            True,
            "Processors": [{
                "Type":
                "Lambda",
                "Parameters": [{
                    "ParameterName":
                    "LambdaArn",
                    "ParameterValue":
                    lambda_function_arn(func_name),
                }],
            }],
        }

    # start proxy server
    start_proxy(local_port,
                backend_url=None,
                update_listener=MyUpdateListener())
    wait_for_port_open(local_port)

    # create firehose stream with http destination
    firehose = aws_stack.create_external_boto_client("firehose")
    stream_name = "firehose_" + short_uid()
    stream = firehose.create_delivery_stream(
        DeliveryStreamName=stream_name,
        HttpEndpointDestinationConfiguration=http_destination,
    )
    assert stream
    stream_description = firehose.describe_delivery_stream(
        DeliveryStreamName=stream_name)
    stream_description = stream_description["DeliveryStreamDescription"]
    destination_description = stream_description["Destinations"][0][
        "HttpEndpointDestinationDescription"]
    assert len(stream_description["Destinations"]) == 1
    assert (destination_description["EndpointConfiguration"]["Url"] ==
            f"http://localhost:{local_port}")

    # put record
    msg_text = "Hello World!"
    firehose.put_record(DeliveryStreamName=stream_name,
                        Record={"Data": msg_text})

    # wait for the result to arrive with proper content
    def _assert_record():
        received_record = records[0]["records"][0]
        received_record_data = to_str(
            base64.b64decode(to_bytes(received_record["data"])))
        assert (
            received_record_data ==
            f"{msg_text}{'-processed' if lambda_processor_enabled else ''}")

    retry(_assert_record, retries=5, sleep=1)

    # update stream destination
    destination_id = stream_description["Destinations"][0]["DestinationId"]
    version_id = stream_description["VersionId"]
    firehose.update_destination(
        DeliveryStreamName=stream_name,
        DestinationId=destination_id,
        CurrentDeliveryStreamVersionId=version_id,
        HttpEndpointDestinationUpdate=http_destination_update,
    )
    stream_description = firehose.describe_delivery_stream(
        DeliveryStreamName=stream_name)
    stream_description = stream_description["DeliveryStreamDescription"]
    destination_description = stream_description["Destinations"][0][
        "HttpEndpointDestinationDescription"]
    assert destination_description["EndpointConfiguration"][
        "Name"] == "test_update"

    # delete stream
    stream = firehose.delete_delivery_stream(DeliveryStreamName=stream_name)
    assert stream["ResponseMetadata"]["HTTPStatusCode"] == 200
Example #22
    def test_firehose_http(self):
        class MyUpdateListener(ProxyListener):
            def forward_request(self, method, path, data, headers):
                data_received = dict(json.loads(data.decode('utf-8')))
                records.append(data_received)
                return 200

        firehose = aws_stack.connect_to_service('firehose')
        local_port = get_free_tcp_port()
        endpoint = '{}://{}:{}'.format(get_service_protocol(),
                                       config.LOCALSTACK_HOSTNAME, local_port)
        records = []
        http_destination_update = {
            'EndpointConfiguration': {
                'Url': endpoint,
                'Name': 'test_update'
            }
        }
        http_destination = {
            'EndpointConfiguration': {
                'Url': endpoint
            },
            'S3BackupMode': 'FailedDataOnly',
            'S3Configuration': {
                'RoleARN': 'arn:.*',
                'BucketARN': 'arn:.*',
                'Prefix': '',
                'ErrorOutputPrefix': '',
                'BufferingHints': {
                    'SizeInMBs': 1,
                    'IntervalInSeconds': 60
                }
            }
        }

        # start proxy server
        start_proxy(local_port,
                    backend_url=None,
                    update_listener=MyUpdateListener())
        wait_for_port_open(local_port)

        # create firehose stream with http destination
        stream = firehose.create_delivery_stream(
            DeliveryStreamName=TEST_STREAM_NAME,
            HttpEndpointDestinationConfiguration=http_destination)
        self.assertTrue(stream)
        stream_description = firehose.describe_delivery_stream(
            DeliveryStreamName=TEST_STREAM_NAME)
        stream_description = stream_description['DeliveryStreamDescription']
        destination_description = stream_description['Destinations'][0][
            'HttpEndpointDestinationDescription']
        self.assertEquals(1, len(stream_description['Destinations']))
        self.assertEquals(
            f'http://localhost:{local_port}',
            destination_description['EndpointConfiguration']['Url'])

        # put record
        firehose.put_record(DeliveryStreamName=TEST_STREAM_NAME,
                            Record={'Data': 'Hello World!'})
        record_received = to_str(
            base64.b64decode(to_bytes(records[0]['records'][0]['data'])))
        # wait for the result to arrive with proper content
        retry(lambda: self.assertEquals('Hello World!', record_received),
              retries=5,
              sleep=1)

        # update stream destination
        destination_id = stream_description['Destinations'][0]['DestinationId']
        version_id = stream_description['VersionId']
        firehose.update_destination(
            DeliveryStreamName=TEST_STREAM_NAME,
            DestinationId=destination_id,
            CurrentDeliveryStreamVersionId=version_id,
            HttpEndpointDestinationUpdate=http_destination_update)
        stream_description = firehose.describe_delivery_stream(
            DeliveryStreamName=TEST_STREAM_NAME)
        stream_description = stream_description['DeliveryStreamDescription']
        destination_description = stream_description['Destinations'][0][
            'HttpEndpointDestinationDescription']
        self.assertEquals(
            'test_update',
            destination_description['EndpointConfiguration']['Name'])

        # delete stream
        stream = firehose.delete_delivery_stream(
            DeliveryStreamName=TEST_STREAM_NAME)
        self.assertEquals(200, stream['ResponseMetadata']['HTTPStatusCode'])
Example #23
    def test_api_destinations(self):
        class HttpEndpointListener(ProxyListener):
            def forward_request(self, method, path, data, headers):
                event = json.loads(to_str(data))
                events.append(event)
                return 200

        events = []
        local_port = get_free_tcp_port()
        proxy = start_proxy(local_port, update_listener=HttpEndpointListener())
        wait_for_port_open(local_port)

        events_client = aws_stack.connect_to_service("events")
        connection_arn = events_client.create_connection(
            Name="TestConnection",
            AuthorizationType="BASIC",
            AuthParameters={
                "BasicAuthParameters": {
                    "Username": "******",
                    "Password": "******"
                }
            },
        )["ConnectionArn"]

        # create api destination
        dest_name = "d-%s" % short_uid()
        url = "http://localhost:%s" % local_port
        result = self.events_client.create_api_destination(
            Name=dest_name,
            ConnectionArn=connection_arn,
            InvocationEndpoint=url,
            HttpMethod="POST",
        )

        # create rule and target
        rule_name = "r-%s" % short_uid()
        target_id = "target-{}".format(short_uid())
        pattern = json.dumps({
            "source": ["source-123"],
            "detail-type": ["type-123"]
        })
        self.events_client.put_rule(Name=rule_name, EventPattern=pattern)
        self.events_client.put_targets(
            Rule=rule_name,
            Targets=[{
                "Id": target_id,
                "Arn": result["ApiDestinationArn"]
            }],
        )

        # put events, to trigger rules
        num_events = 5
        for i in range(num_events):
            entries = [{
                "Source": "source-123",
                "DetailType": "type-123",
                "Detail": '{"i": %s}' % i,
            }]
            self.events_client.put_events(Entries=entries)

        # assert that all events have been received in the HTTP server listener
        def check():
            self.assertEqual(len(events), num_events)

        retry(check, sleep=0.5, retries=5)

        # clean up
        proxy.stop()
Example #24
    def test_firehose_http(self):
        class MyUpdateListener(ProxyListener):
            def forward_request(self, method, path, data, headers):
                data_received = dict(json.loads(data.decode("utf-8")))
                records.append(data_received)
                return 200

        firehose = aws_stack.connect_to_service("firehose")
        local_port = get_free_tcp_port()
        endpoint = "{}://{}:{}".format(
            get_service_protocol(), config.LOCALSTACK_HOSTNAME, local_port
        )
        records = []
        http_destination_update = {
            "EndpointConfiguration": {"Url": endpoint, "Name": "test_update"}
        }
        http_destination = {
            "EndpointConfiguration": {"Url": endpoint},
            "S3BackupMode": "FailedDataOnly",
            "S3Configuration": {
                "RoleARN": "arn:.*",
                "BucketARN": "arn:.*",
                "Prefix": "",
                "ErrorOutputPrefix": "",
                "BufferingHints": {"SizeInMBs": 1, "IntervalInSeconds": 60},
            },
        }

        # start proxy server
        start_proxy(local_port, backend_url=None, update_listener=MyUpdateListener())
        wait_for_port_open(local_port)

        # create firehose stream with http destination
        stream = firehose.create_delivery_stream(
            DeliveryStreamName=TEST_STREAM_NAME,
            HttpEndpointDestinationConfiguration=http_destination,
        )
        self.assertTrue(stream)
        stream_description = firehose.describe_delivery_stream(DeliveryStreamName=TEST_STREAM_NAME)
        stream_description = stream_description["DeliveryStreamDescription"]
        destination_description = stream_description["Destinations"][0][
            "HttpEndpointDestinationDescription"
        ]
        self.assertEqual(1, len(stream_description["Destinations"]))
        self.assertEqual(
            f"http://localhost:{local_port}",
            destination_description["EndpointConfiguration"]["Url"],
        )

        # put record
        firehose.put_record(DeliveryStreamName=TEST_STREAM_NAME, Record={"Data": "Hello World!"})
        record_received = to_str(base64.b64decode(to_bytes(records[0]["records"][0]["data"])))
        # wait for the result to arrive with proper content
        retry(
            lambda: self.assertEqual("Hello World!", record_received),
            retries=5,
            sleep=1,
        )

        # update stream destination
        destination_id = stream_description["Destinations"][0]["DestinationId"]
        version_id = stream_description["VersionId"]
        firehose.update_destination(
            DeliveryStreamName=TEST_STREAM_NAME,
            DestinationId=destination_id,
            CurrentDeliveryStreamVersionId=version_id,
            HttpEndpointDestinationUpdate=http_destination_update,
        )
        stream_description = firehose.describe_delivery_stream(DeliveryStreamName=TEST_STREAM_NAME)
        stream_description = stream_description["DeliveryStreamDescription"]
        destination_description = stream_description["Destinations"][0][
            "HttpEndpointDestinationDescription"
        ]
        self.assertEqual("test_update", destination_description["EndpointConfiguration"]["Name"])

        # delete stream
        stream = firehose.delete_delivery_stream(DeliveryStreamName=TEST_STREAM_NAME)
        self.assertEqual(200, stream["ResponseMetadata"]["HTTPStatusCode"])
Example #25
    def test_api_destinations(self):

        token = short_uid()
        bearer = "Bearer %s" % token

        class HttpEndpointListener(ProxyListener):
            def forward_request(self, method, path, data, headers):
                event = json.loads(to_str(data))
                events.append(event)
                paths_list.append(path)
                auth = headers.get("Api") or headers.get("Authorization")
                if auth not in headers_list:
                    headers_list.append(auth)

                return requests_response({
                    "access_token": token,
                    "token_type": "Bearer",
                    "expires_in": 86400,
                })

        events = []
        paths_list = []
        headers_list = []

        local_port = get_free_tcp_port()
        proxy = start_proxy(local_port, update_listener=HttpEndpointListener())
        wait_for_port_open(local_port)
        events_client = aws_stack.create_external_boto_client("events")
        url = "http://localhost:%s" % local_port

        auth_types = [
            {
                "type": "BASIC",
                "key": "BasicAuthParameters",
                "parameters": {
                    "Username": "******",
                    "Password": "******"
                },
            },
            {
                "type": "API_KEY",
                "key": "ApiKeyAuthParameters",
                "parameters": {
                    "ApiKeyName": "Api",
                    "ApiKeyValue": "apikey_secret"
                },
            },
            {
                "type": "OAUTH_CLIENT_CREDENTIALS",
                "key": "OAuthParameters",
                "parameters": {
                    "AuthorizationEndpoint": url,
                    "ClientParameters": {
                        "ClientID": "id",
                        "ClientSecret": "password"
                    },
                    "HttpMethod": "put",
                },
            },
        ]

        for auth in auth_types:
            connection_name = "c-%s" % short_uid()
            connection_arn = events_client.create_connection(
                Name=connection_name,
                AuthorizationType=auth.get("type"),
                AuthParameters={
                    auth.get("key"): auth.get("parameters"),
                    "InvocationHttpParameters": {
                        "BodyParameters": [{
                            "Key": "key",
                            "Value": "value",
                            "IsValueSecret": False
                        }],
                        "HeaderParameters": [{
                            "Key": "key",
                            "Value": "value",
                            "IsValueSecret": False
                        }],
                        "QueryStringParameters": [{
                            "Key": "key",
                            "Value": "value",
                            "IsValueSecret": False
                        }],
                    },
                },
            )["ConnectionArn"]

            # create api destination
            dest_name = "d-%s" % short_uid()
            result = self.events_client.create_api_destination(
                Name=dest_name,
                ConnectionArn=connection_arn,
                InvocationEndpoint=url,
                HttpMethod="POST",
            )

            # create rule and target
            rule_name = "r-%s" % short_uid()
            target_id = "target-{}".format(short_uid())
            pattern = json.dumps({
                "source": ["source-123"],
                "detail-type": ["type-123"]
            })
            self.events_client.put_rule(Name=rule_name, EventPattern=pattern)
            self.events_client.put_targets(
                Rule=rule_name,
                Targets=[{
                    "Id": target_id,
                    "Arn": result["ApiDestinationArn"]
                }],
            )

            entries = [{
                "Source": "source-123",
                "DetailType": "type-123",
                "Detail": '{"i": %s}' % 0,
            }]
            self.events_client.put_events(Entries=entries)

            # cleaning
            self.events_client.delete_connection(Name=connection_name)
            self.events_client.delete_api_destination(Name=dest_name)
            self.events_client.delete_rule(Name=rule_name, Force=True)

        # assert that all events have been received in the HTTP server listener
        def check():
            self.assertTrue(len(events) >= len(auth_types))
            self.assertTrue("key" in paths_list[0]
                            and "value" in paths_list[0])
            self.assertTrue(events[0].get("key") == "value")

            # TODO examine behavior difference between LS pro/community
            # Pro seems to (correctly) use base64 for basic authentication instead of plaintext
            user_pass = to_str(base64.b64encode(b"user:pass"))
            self.assertTrue("Basic user:pass" in headers_list
                            or f"Basic {user_pass}" in headers_list)
            self.assertTrue("apikey_secret" in headers_list)
            self.assertTrue(bearer in headers_list)

        retry(check, sleep=0.5, retries=5)

        # clean up
        proxy.stop()