Example #1
def test_upload_lambda_from_s3():

    s3_client = aws_stack.connect_to_service('s3')
    lambda_client = aws_stack.connect_to_service('lambda')

    lambda_name = 'test_lambda_%s' % short_uid()
    bucket_name = 'test_bucket_lambda'
    bucket_key = 'test_lambda.zip'

    # upload zip file to S3
    zip_file = testutil.create_lambda_archive(load_file(TEST_LAMBDA_PYTHON), get_content=True,
        libs=TEST_LAMBDA_LIBS, runtime=LAMBDA_RUNTIME_PYTHON27)
    s3_client.create_bucket(Bucket=bucket_name)
    s3_client.upload_fileobj(BytesIO(zip_file), bucket_name, bucket_key)

    # create lambda function
    lambda_client.create_function(
        FunctionName=lambda_name, Handler='handler.handler',
        Runtime=lambda_api.LAMBDA_RUNTIME_PYTHON27, Role='r1',
        Code={
            'S3Bucket': bucket_name,
            'S3Key': bucket_key
        }
    )

    # invoke lambda function
    data_before = b'{"foo": "bar"}'
    result = lambda_client.invoke(FunctionName=lambda_name, Payload=data_before)
    data_after = result['Payload'].read()
    assert json.loads(to_str(data_before)) == json.loads(to_str(data_after))
Example #2
    def return_response(self, method, path, data, headers, response):
        action = headers.get('X-Amz-Target')
        data = json.loads(to_str(data))

        records = []
        if action in (ACTION_CREATE_STREAM, ACTION_DELETE_STREAM):
            event_type = (event_publisher.EVENT_KINESIS_CREATE_STREAM if action == ACTION_CREATE_STREAM
                else event_publisher.EVENT_KINESIS_DELETE_STREAM)
            event_publisher.fire_event(event_type, payload={'n': event_publisher.get_hash(data.get('StreamName'))})
        elif action == ACTION_PUT_RECORD:
            response_body = json.loads(to_str(response.content))
            event_record = {
                'data': data['Data'],
                'partitionKey': data['PartitionKey'],
                'sequenceNumber': response_body.get('SequenceNumber')
            }
            event_records = [event_record]
            stream_name = data['StreamName']
            lambda_api.process_kinesis_records(event_records, stream_name)
        elif action == ACTION_PUT_RECORDS:
            event_records = []
            response_body = json.loads(to_str(response.content))
            response_records = response_body['Records']
            records = data['Records']
            for index, record in enumerate(records):
                event_record = {
                    'data': record['Data'],
                    'partitionKey': record['PartitionKey'],
                    'sequenceNumber': response_records[index].get('SequenceNumber')
                }
                event_records.append(event_record)
            stream_name = data['StreamName']
            lambda_api.process_kinesis_records(event_records, stream_name)
Example #3
    def return_response(self, method, path, data, headers, response):
        req_data = None
        if method == 'POST' and path == '/':
            req_data = urlparse.parse_qs(to_str(data))
            action = req_data.get('Action')[0]

        if req_data:
            if action == 'DescribeStackResources':
                if response.status_code < 300:
                    response_dict = xmltodict.parse(response.content)['DescribeStackResourcesResponse']
                    resources = response_dict['DescribeStackResourcesResult']['StackResources']
                    if not resources:
                        # Check if stack exists
                        stack_name = req_data.get('StackName')[0]
                        cloudformation_client = aws_stack.connect_to_service('cloudformation')
                        try:
                            cloudformation_client.describe_stacks(StackName=stack_name)
                        except Exception:
                            return error_response('Stack with id %s does not exist' % stack_name, code=404)
            if action == 'DescribeStackResource':
                if response.status_code >= 500:
                    # fix an error in moto where it fails with 500 if the stack does not exist
                    return error_response('Stack resource does not exist', code=404)
            if action == 'ListStackResources':
                response_dict = xmltodict.parse(response.content, force_list=('member',))['ListStackResourcesResponse']
                resources = response_dict['ListStackResourcesResult']['StackResourceSummaries']
                if resources:
                    sqs_client = aws_stack.connect_to_service('sqs')
                    content_str = content_str_original = to_str(response.content)
                    new_response = Response()
                    new_response.status_code = response.status_code
                    new_response.headers = response.headers
                    for resource in resources['member']:
                        if resource['ResourceType'] == 'AWS::SQS::Queue':
                            try:
                                queue_name = resource['PhysicalResourceId']
                                queue_url = sqs_client.get_queue_url(QueueName=queue_name)['QueueUrl']
                            except Exception:
                                stack_name = req_data.get('StackName')[0]
                                return error_response('Stack with id %s does not exist' % stack_name, code=404)
                            # plain string replacement, since the resource ID may contain regex metacharacters
                            content_str = content_str.replace(resource['PhysicalResourceId'], queue_url)
                    new_response._content = content_str
                    if content_str_original != new_response._content:
                        # if changes have been made, return patched response
                        new_response.headers['content-length'] = len(new_response._content)
                        return new_response
            elif action in ('CreateStack', 'UpdateStack'):
                if response.status_code >= 400:
                    return response
                # run the actual deployment
                template = template_deployer.template_to_json(req_data.get('TemplateBody')[0])
                template_deployer.deploy_template(template, req_data.get('StackName')[0])
Example #4
    def forward_request(self, method, path, data, headers):
        req_data = None
        if method == 'POST' and path == '/':
            req_data = urlparse.parse_qs(to_str(data))
            action = req_data.get('Action')[0]

        if req_data:
            if action == 'CreateStack':
                return create_stack(req_data)
            if action == 'CreateChangeSet':
                return create_change_set(req_data)
            elif action == 'DescribeChangeSet':
                return describe_change_set(req_data)
            elif action == 'ExecuteChangeSet':
                return execute_change_set(req_data)
            elif action == 'UpdateStack' and req_data.get('TemplateURL'):
                # Temporary fix until the moto CF backend can handle TemplateURL (currently fails)
                url = re.sub(r'https?://s3\.amazonaws\.com', aws_stack.get_local_service_url('s3'),
                    req_data.get('TemplateURL')[0])
                # parse_qs values are lists, so wrap the fetched body for urlencode(doseq=True)
                req_data['TemplateBody'] = [requests.get(url).content]
                modified_data = urlparse.urlencode(req_data, doseq=True)
                return Request(data=modified_data, headers=headers, method=method)
            elif action == 'ValidateTemplate':
                return validate_template(req_data)

        return True
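Note: the TemplateURL rewrite above relies on the parse_qs/urlencode round trip — parse_qs yields a dict of lists, and urlencode(..., doseq=True) flattens those lists back into repeated form parameters. A minimal stdlib-only sketch of that round trip (the sample body is made up for illustration):

from urllib.parse import parse_qs, urlencode

# decode the form-encoded POST body into a dict of lists
req_data = parse_qs('Action=UpdateStack&StackName=s1&TemplateURL=https%3A%2F%2Fs3.amazonaws.com%2Fb%2Ft.json')
assert req_data['Action'] == ['UpdateStack']

# patch one field; keep the value wrapped in a list so doseq=True treats it like the others
req_data['TemplateBody'] = ['{"Resources": {}}']

# re-encode into a body that can be forwarded to the backend
modified_data = urlencode(req_data, doseq=True)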
Example #5
def update_function_configuration(function):
    """ Update the configuration of an existing function
        ---
        operationId: 'updateFunctionConfiguration'
        parameters:
            - name: 'request'
              in: body
    """
    data = json.loads(to_str(request.data))
    arn = func_arn(function)

    # Stop/remove any containers that this arn uses.
    LAMBDA_EXECUTOR.cleanup(arn)

    lambda_details = arn_to_lambda[arn]
    if data.get('Handler'):
        lambda_details.handler = data['Handler']
    if data.get('Runtime'):
        lambda_details.runtime = data['Runtime']
    if data.get('Environment'):
        lambda_details.envvars = data.get('Environment', {}).get('Variables', {})
    if data.get('Timeout'):
        lambda_details.timeout = data['Timeout']
    result = {}
    return jsonify(result)
Example #6
def invoke_function(function):
    """ Invoke an existing function
        ---
        operationId: 'invokeFunction'
        parameters:
            - name: 'request'
              in: body
    """
    arn = func_arn(function)
    if arn not in arn_to_lambda:
        return error_response('Function does not exist: %s' % arn, 404, error_type='ResourceNotFoundException')
    qualifier = request.args.get('Qualifier', '$LATEST')
    if not arn_to_lambda.get(arn).qualifier_exists(qualifier):
        return error_response('Function does not exist: {0}:{1}'.format(arn, qualifier), 404,
                              error_type='ResourceNotFoundException')
    data = None
    if request.data:
        try:
            data = json.loads(to_str(request.data))
        except Exception:
            return error_response('The payload is not JSON', 415, error_type='UnsupportedMediaTypeException')
    # NOTE: 'async' became a reserved keyword in Python 3.7 and can no longer be used as an
    # identifier; this assumes run_lambda's parameter was renamed accordingly
    asynchronous = False
    if 'HTTP_X_AMZ_INVOCATION_TYPE' in request.environ:
        asynchronous = request.environ['HTTP_X_AMZ_INVOCATION_TYPE'] == 'Event'
    result = run_lambda(asynchronous=asynchronous, func_arn=arn, event=data, context={}, version=qualifier)
    if isinstance(result, dict):
        return jsonify(result)
    if result:
        return result
    return make_response('', 200)
Example #7
def get_metric_statistics(Namespace, MetricName, Dimensions,
        Period=60, StartTime=None, EndTime=None, Statistics=None):
    if not StartTime:
        StartTime = datetime.now() - timedelta(minutes=5)
    if not EndTime:
        EndTime = datetime.now()
    if Statistics is None:
        Statistics = ['Sum']
    cloudwatch_url = aws_stack.get_local_service_url('cloudwatch')
    url = '%s/?Action=GetMetricValues' % cloudwatch_url
    all_metrics = make_http_request(url)
    assert all_metrics.status_code == 200
    datapoints = []
    for datapoint in json.loads(to_str(all_metrics.content)):
        if datapoint['Namespace'] == Namespace and datapoint['Name'] == MetricName:
            dp_dimensions = datapoint['Dimensions']
            all_present = all(m in dp_dimensions for m in Dimensions)
            no_additional = all(m in Dimensions for m in dp_dimensions)
            if all_present and no_additional:
                datapoints.append(datapoint)
    result = {
        'Label': '%s/%s' % (Namespace, MetricName),
        'Datapoints': datapoints
    }
    return result
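Note: the two all() checks above amount to an order-insensitive set-equality test between the requested Dimensions and a datapoint's dimensions. A standalone sketch of the same predicate (sample dimensions made up for illustration):

def dimensions_match(requested, actual):
    # every requested dimension is present, and the datapoint has no extras
    return all(d in actual for d in requested) and all(d in requested for d in actual)

requested = [{'Name': 'FunctionName', 'Value': 'f1'}]
assert dimensions_match(requested, [{'Name': 'FunctionName', 'Value': 'f1'}])
assert not dimensions_match(requested, [{'Name': 'FunctionName', 'Value': 'f1'}, {'Name': 'Resource', 'Value': 'r1'}])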
Example #8
def test_firehose_s3():

    s3_resource = aws_stack.connect_to_resource('s3')
    firehose = aws_stack.connect_to_service('firehose')

    s3_prefix = '/testdata'
    test_data = '{"test": "firehose_data_%s"}' % short_uid()
    # create Firehose stream
    stream = firehose.create_delivery_stream(
        DeliveryStreamName=TEST_FIREHOSE_NAME,
        S3DestinationConfiguration={
            'RoleARN': aws_stack.iam_resource_arn('firehose'),
            'BucketARN': aws_stack.s3_bucket_arn(TEST_BUCKET_NAME),
            'Prefix': s3_prefix
        }
    )
    assert stream
    assert TEST_FIREHOSE_NAME in firehose.list_delivery_streams()['DeliveryStreamNames']
    # create target S3 bucket
    s3_resource.create_bucket(Bucket=TEST_BUCKET_NAME)

    # put records
    firehose.put_record(
        DeliveryStreamName=TEST_FIREHOSE_NAME,
        Record={
            'Data': to_bytes(test_data)
        }
    )
    # check records in target bucket
    all_objects = testutil.list_all_s3_objects()
    testutil.assert_objects(json.loads(to_str(test_data)), all_objects)
Example #9
def get_machine_id():
    global MACHINE_ID
    if MACHINE_ID:
        return MACHINE_ID

    # determine MACHINE_ID from config files
    configs_map = {}
    config_file_tmp = get_config_file_tempdir()
    config_file_home = get_config_file_homedir()
    for config_file in (config_file_home, config_file_tmp):
        if config_file:
            local_configs = load_file(config_file)
            local_configs = json.loads(to_str(local_configs))
            configs_map[config_file] = local_configs
            if 'machine_id' in local_configs:
                MACHINE_ID = local_configs['machine_id']
                break

    # if we can neither find NOR create the config files, fall back to process id
    if not configs_map:
        return PROCESS_ID

    # assign default id if empty
    if not MACHINE_ID:
        MACHINE_ID = short_uid()

    # update MACHINE_ID in all config files
    for config_file, configs in configs_map.items():
        configs['machine_id'] = MACHINE_ID
        save_file(config_file, json.dumps(configs))

    return MACHINE_ID
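Note: load_file/save_file and the config path helpers come from the surrounding codebase; the underlying pattern is a plain JSON read-modify-write. A stdlib-only sketch of that pattern, with the file path made up for illustration:

import json
import os

config_file = '/tmp/machine_config.json'  # hypothetical path

# read the existing config, tolerating a missing file
configs = {}
if os.path.exists(config_file):
    with open(config_file) as f:
        configs = json.load(f)

# assign a default id if absent, then persist the update
configs.setdefault('machine_id', 'dev-machine-1')
with open(config_file, 'w') as f:
    json.dump(configs, f)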
Example #10
def post_request():
    action = request.headers.get('x-amz-target')
    data = json.loads(to_str(request.data))
    result = {}
    kinesis = aws_stack.connect_to_service('kinesis')
    if action == '%s.ListStreams' % ACTION_HEADER_PREFIX:
        result = {
            'Streams': list(DDB_STREAMS.values()),
            'LastEvaluatedStreamArn': 'TODO'
        }
    elif action == '%s.DescribeStream' % ACTION_HEADER_PREFIX:
        for stream in DDB_STREAMS.values():
            if stream['StreamArn'] == data['StreamArn']:
                result = {
                    'StreamDescription': stream
                }
                # get stream details
                dynamodb = aws_stack.connect_to_service('dynamodb')
                table_name = table_name_from_stream_arn(stream['StreamArn'])
                stream_name = get_kinesis_stream_name(table_name)
                stream_details = kinesis.describe_stream(StreamName=stream_name)
                table_details = dynamodb.describe_table(TableName=table_name)
                stream['KeySchema'] = table_details['Table']['KeySchema']

                # Replace Kinesis ShardIDs with ones that mimic actual
                # DynamoDBStream ShardIDs.
                stream_shards = stream_details['StreamDescription']['Shards']
                for shard in stream_shards:
                    shard['ShardId'] = shard_id(stream_name, shard['ShardId'])
                stream['Shards'] = stream_shards
                break
        if not result:
            return error_response('Requested resource not found', error_type='ResourceNotFoundException')
    elif action == '%s.GetShardIterator' % ACTION_HEADER_PREFIX:
        # forward request to Kinesis API
        stream_name = stream_name_from_stream_arn(data['StreamArn'])
        stream_shard_id = kinesis_shard_id(data['ShardId'])
        result = kinesis.get_shard_iterator(StreamName=stream_name,
            ShardId=stream_shard_id, ShardIteratorType=data['ShardIteratorType'])
    elif action == '%s.GetRecords' % ACTION_HEADER_PREFIX:
        kinesis_records = kinesis.get_records(**data)
        result = {'Records': []}
        for record in kinesis_records['Records']:
            result['Records'].append(json.loads(to_str(record['Data'])))
    else:
        print('WARNING: Unknown operation "%s"' % action)
    return jsonify(result)
Example #11
def set_lifecycle(bucket_name, lifecycle):
    # TODO: check if bucket exists, otherwise return 404-like error
    if isinstance(to_str(lifecycle), six.string_types):
        lifecycle = xmltodict.parse(lifecycle)
    BUCKET_LIFECYCLE[bucket_name] = lifecycle
    response = Response()
    response.status_code = 200
    return response
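Note: set_lifecycle stores the parsed lifecycle as a dict, with xmltodict.parse doing the XML-to-dict conversion. A short sketch of what that parse produces, assuming xmltodict is installed (sample XML made up for illustration):

import xmltodict

lifecycle_xml = """
<LifecycleConfiguration>
    <Rule>
        <ID>expire-tmp</ID>
        <Status>Enabled</Status>
        <Expiration><Days>7</Days></Expiration>
    </Rule>
</LifecycleConfiguration>
"""
parsed = xmltodict.parse(lifecycle_xml)
assert parsed['LifecycleConfiguration']['Rule']['ID'] == 'expire-tmp'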
Example #12
def create_domain():
    data = json.loads(to_str(request.data))
    domain_name = data['DomainName']
    if domain_name in ES_DOMAINS:
        return error_response(error_type='ResourceAlreadyExistsException')
    ES_DOMAINS[domain_name] = data
    result = get_domain_status(domain_name)
    return jsonify(result)
Example #13
def post_request():
    action = request.headers.get('x-amz-target')
    data = json.loads(to_str(request.data))
    response = None
    if action == '%s.ListDeliveryStreams' % ACTION_HEADER_PREFIX:
        response = {
            'DeliveryStreamNames': get_delivery_stream_names(),
            'HasMoreDeliveryStreams': False
        }
    elif action == '%s.CreateDeliveryStream' % ACTION_HEADER_PREFIX:
        stream_name = data['DeliveryStreamName']
        response = create_stream(stream_name,
            s3_destination=data.get('S3DestinationConfiguration'),
            elasticsearch_destination=data.get('ElasticsearchDestinationConfiguration'))
    elif action == '%s.DeleteDeliveryStream' % ACTION_HEADER_PREFIX:
        stream_name = data['DeliveryStreamName']
        response = delete_stream(stream_name)
    elif action == '%s.DescribeDeliveryStream' % ACTION_HEADER_PREFIX:
        stream_name = data['DeliveryStreamName']
        response = get_stream(stream_name)
        if not response:
            return error_not_found(stream_name)
        response = {
            'DeliveryStreamDescription': response
        }
    elif action == '%s.PutRecord' % ACTION_HEADER_PREFIX:
        stream_name = data['DeliveryStreamName']
        record = data['Record']
        put_record(stream_name, record)
        response = {
            'RecordId': str(uuid.uuid4())
        }
    elif action == '%s.PutRecordBatch' % ACTION_HEADER_PREFIX:
        stream_name = data['DeliveryStreamName']
        records = data['Records']
        put_records(stream_name, records)
        response = {
            'FailedPutCount': 0,
            'RequestResponses': []
        }
    elif action == '%s.UpdateDestination' % ACTION_HEADER_PREFIX:
        stream_name = data['DeliveryStreamName']
        version_id = data['CurrentDeliveryStreamVersionId']
        destination_id = data['DestinationId']
        s3_update = data.get('S3DestinationUpdate')
        update_destination(stream_name=stream_name, destination_id=destination_id,
            s3_update=s3_update, version_id=version_id)
        es_update = data.get('ESDestinationUpdate')
        update_destination(stream_name=stream_name, destination_id=destination_id,
            es_update=es_update, version_id=version_id)
        response = {}
    else:
        response = error_response('Unknown action "%s"' % action, code=400, error_type='InvalidAction')

    if isinstance(response, dict):
        response = jsonify(response)
    return response
Example #14
def receive_assert_delete(queue_url, assertions, sqs_client=None, required_subject=None):
    if not sqs_client:
        sqs_client = aws_stack.connect_to_service('sqs')

    response = sqs_client.receive_message(QueueUrl=queue_url)

    messages = [json.loads(to_str(m['Body'])) for m in response['Messages']]
    testutil.assert_objects(assertions, messages)
    for message in response['Messages']:
        sqs_client.delete_message(QueueUrl=queue_url, ReceiptHandle=message['ReceiptHandle'])
Example #15
def test_message_transformation():
    template = APIGATEWAY_TRANSFORMATION_TEMPLATE
    records = [
        {
            'data': {
                'foo': 'foo1',
                'bar': 'bar2'
            }
        },
        {
            'data': {
                'foo': 'foo1',
                'bar': 'bar2'
            },
            'partitionKey': 'key123'
        }
    ]
    context = {
        'records': records
    }
    # try rendering the template
    result = render_velocity_template(template, context, as_json=True)
    result_decoded = json.loads(to_str(base64.b64decode(result['Records'][0]['Data'])))
    assert result_decoded == records[0]['data']
    assert len(result['Records'][0]['PartitionKey']) > 0
    assert result['Records'][1]['PartitionKey'] == 'key123'
    # try again with context as string
    context = json.dumps(context)
    result = render_velocity_template(template, context, as_json=True)
    result_decoded = json.loads(to_str(base64.b64decode(result['Records'][0]['Data'])))
    assert result_decoded == records[0]['data']
    assert len(result['Records'][0]['PartitionKey']) > 0
    assert result['Records'][1]['PartitionKey'] == 'key123'

    # test with empty array
    records = []
    context = {
        'records': records
    }
    # try rendering the template
    result = render_velocity_template(template, context, as_json=True)
    assert result['Records'] == []
Example #16
def update_function_code(function):
    """ Update the code of an existing function
        ---
        operationId: 'updateFunctionCode'
        parameters:
            - name: 'request'
              in: body
    """
    data = json.loads(to_str(request.data))
    result = set_function_code(data, function)
    return jsonify(result or {})
Example #17
    def forward_request(self, method, path, data, headers):
        data = json.loads(to_str(data))

        if random.random() < config.DYNAMODB_ERROR_PROBABILITY:
            return error_response_throughput()

        action = headers.get('X-Amz-Target')
        if action in ('%s.PutItem' % ACTION_PREFIX, '%s.UpdateItem' % ACTION_PREFIX, '%s.DeleteItem' % ACTION_PREFIX):
            # find an existing item and store it in a thread-local, so we can access it in return_response,
            # in order to determine whether an item already existed (MODIFY) or not (INSERT)
            ProxyListenerDynamoDB.thread_local.existing_item = find_existing_item(data)
        elif action == '%s.UpdateTimeToLive' % ACTION_PREFIX:
            # TODO: TTL status is maintained/mocked but no real expiry is happening for items
            response = Response()
            response.status_code = 200
            self._table_ttl_map[data['TableName']] = {
                'AttributeName': data['TimeToLiveSpecification']['AttributeName'],
                'Status': data['TimeToLiveSpecification']['Enabled']
            }
            response._content = json.dumps({'TimeToLiveSpecification': data['TimeToLiveSpecification']})
            fix_headers_for_updated_response(response)
            return response
        elif action == '%s.DescribeTimeToLive' % ACTION_PREFIX:
            response = Response()
            response.status_code = 200
            if data['TableName'] in self._table_ttl_map:
                if self._table_ttl_map[data['TableName']]['Status']:
                    ttl_status = 'ENABLED'
                else:
                    ttl_status = 'DISABLED'
                response._content = json.dumps({
                    'TimeToLiveDescription': {
                        'AttributeName': self._table_ttl_map[data['TableName']]['AttributeName'],
                        'TimeToLiveStatus': ttl_status
                    }
                })
            else:  # TTL for dynamodb table not set
                response._content = json.dumps({'TimeToLiveDescription': {'TimeToLiveStatus': 'DISABLED'}})
            fix_headers_for_updated_response(response)
            return response
        elif action == '%s.TagResource' % ACTION_PREFIX or action == '%s.UntagResource' % ACTION_PREFIX:
            response = Response()
            response.status_code = 200
            response._content = ''  # returns an empty body on success.
            fix_headers_for_updated_response(response)
            return response
        elif action == '%s.ListTagsOfResource' % ACTION_PREFIX:
            response = Response()
            response.status_code = 200
            response._content = json.dumps({'Tags': []})  # TODO: mocked and returns an empty list of tags for now.
            fix_headers_for_updated_response(response)
            return response

        return True
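Note: several listeners in this collection short-circuit the proxy by hand-building a requests Response and assigning its private _content attribute. A minimal sketch of that pattern in isolation — it leans on requests internals (_content is not public API), which is why it only appears in test/mock code like this:

import json
from requests.models import Response

response = Response()
response.status_code = 200
response._content = json.dumps({'Tags': []})

# downstream consumers can read it like any real response
assert json.loads(response.content) == {'Tags': []}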
Example #18
def create_event_source_mapping():
    """ Create new event source mapping
        ---
        operationId: 'createEventSourceMapping'
        parameters:
            - name: 'request'
              in: body
    """
    data = json.loads(to_str(request.data))
    mapping = add_event_source(data['FunctionName'], data['EventSourceArn'])
    return jsonify(mapping)
Example #19
 def return_response(self, method, path, data, headers, response):
     # This method is executed by the proxy after we've already received a
     # response from the backend, hence we can utilize the "response" variable here
     if method == 'POST' and path == '/':
         req_data = urlparse.parse_qs(to_str(data))
         req_action = req_data['Action'][0]
         if req_action == 'Subscribe' and response.status_code < 400:
             response_data = xmltodict.parse(response.content)
             topic_arn = (req_data.get('TargetArn') or req_data.get('TopicArn'))[0]
             sub_arn = response_data['SubscribeResponse']['SubscribeResult']['SubscriptionArn']
             do_subscribe(topic_arn, req_data['Endpoint'][0], req_data['Protocol'][0], sub_arn)
Example #20
def test_api_gateway_http_integration():
    test_port = 12123
    backend_url = 'http://localhost:%s%s' % (test_port, API_PATH_HTTP_BACKEND)

    # create target HTTP backend
    class TestListener(ProxyListener):

        def forward_request(self, **kwargs):
            response = Response()
            response.status_code = 200
            response._content = kwargs.get('data') or '{}'
            return response

    proxy = GenericProxy(test_port, update_listener=TestListener())
    proxy.start()

    # create API Gateway and connect it to the HTTP backend
    result = connect_api_gateway_to_http('test_gateway2', backend_url, path=API_PATH_HTTP_BACKEND)

    url = INBOUND_GATEWAY_URL_PATTERN.format(api_id=result['id'],
        stage_name=TEST_STAGE_NAME, path=API_PATH_HTTP_BACKEND)

    # make sure CORS headers are present
    origin = 'localhost'
    result = requests.options(url, headers={'origin': origin})
    assert result.status_code == 200
    assert re.match(result.headers['Access-Control-Allow-Origin'].replace('*', '.*'), origin)
    assert 'POST' in result.headers['Access-Control-Allow-Methods']

    # make test request to gateway
    result = requests.get(url)
    assert result.status_code == 200
    assert to_str(result.content) == '{}'
    data = {'data': 123}
    result = requests.post(url, data=json.dumps(data))
    assert result.status_code == 200
    assert json.loads(to_str(result.content)) == data

    # clean up
    proxy.stop()
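Note: the CORS assertion above turns the Access-Control-Allow-Origin header into a regex by replacing '*' with '.*'. A standalone sketch of that check (header values made up for illustration):

import re

def origin_allowed(allow_origin_header, origin):
    # treat '*' in the header as a wildcard over any characters
    return re.match(allow_origin_header.replace('*', '.*'), origin) is not None

assert origin_allowed('*', 'localhost')
assert origin_allowed('http://*.example.com', 'http://api.example.com')
assert not origin_allowed('http://other.com', 'http://api.example.com')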
Example #21
    def forward_request(self, method, path, data, headers):

        if method == 'POST' and path == '/':
            req_data = urlparse.parse_qs(to_str(data))
            if 'QueueName' in req_data:
                if '.' in req_data['QueueName'][0]:
                    # ElasticMQ currently does not support "." in the queue name, e.g., for *.fifo queues
                    # TODO: remove this once *.fifo queues are supported in ElasticMQ
                    req_data['QueueName'][0] = req_data['QueueName'][0].replace('.', '_')
                    modified_data = urlencode(req_data, doseq=True)
                    request = Request(data=modified_data, headers=headers, method=method)
                    return request

        return True
Example #22
def deserialize_event(event):
    # Deserialize into Python dictionary and extract the "NewImage" (the new version of the full ddb document)
    ddb = event.get('dynamodb')
    if ddb:
        ddb_deserializer = TypeDeserializer()
        result = ddb_deserializer.deserialize({'M': ddb.get('NewImage')})
        result['__action_type'] = event.get('eventName')
        return result
    kinesis = event.get('kinesis')
    if kinesis:
        assert kinesis['sequenceNumber']
        kinesis['data'] = json.loads(to_str(base64.b64decode(kinesis['data'])))
        return kinesis
    return event.get('Sns')
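Note: TypeDeserializer comes from boto3 and converts DynamoDB's typed attribute representation into plain Python values; wrapping the image in {'M': ...} deserializes it as a map, exactly as done above. A small sketch with a made-up NewImage:

from decimal import Decimal
from boto3.dynamodb.types import TypeDeserializer

new_image = {'id': {'S': 'item-1'}, 'count': {'N': '5'}}
result = TypeDeserializer().deserialize({'M': new_image})
assert result == {'id': 'item-1', 'count': Decimal('5')}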
Example #23
def create_function():
    """ Create new function
        ---
        operationId: 'createFunction'
        parameters:
            - name: 'request'
              in: body
    """
    arn = 'n/a'
    try:
        data = json.loads(to_str(request.data))
        lambda_name = data['FunctionName']
        event_publisher.fire_event(event_publisher.EVENT_LAMBDA_CREATE_FUNC,
            payload={'n': event_publisher.get_hash(lambda_name)})
        arn = func_arn(lambda_name)
        if arn in arn_to_lambda:
            return error_response('Function already exists: %s' %
                lambda_name, 409, error_type='ResourceConflictException')
        arn_to_lambda[arn] = func_details = LambdaFunction(arn)
        func_details.versions = {'$LATEST': {'CodeSize': 50}}
        func_details.handler = data['Handler']
        func_details.runtime = data['Runtime']
        func_details.envvars = data.get('Environment', {}).get('Variables', {})
        func_details.timeout = data.get('Timeout')
        result = set_function_code(data['Code'], lambda_name)
        if isinstance(result, Response):
            del arn_to_lambda[arn]
            return result
        result.update({
            'DeadLetterConfig': data.get('DeadLetterConfig'),
            'Description': data.get('Description'),
            'Environment': {'Error': {}, 'Variables': func_details.envvars},
            'FunctionArn': arn,
            'FunctionName': lambda_name,
            'Handler': func_details.handler,
            'MemorySize': data.get('MemorySize'),
            'Role': data.get('Role'),
            'Runtime': func_details.runtime,
            'Timeout': data.get('Timeout'),
            'TracingConfig': {},
            'VpcConfig': {'SecurityGroupIds': [None], 'SubnetIds': [None], 'VpcId': None}
        })
        return jsonify(result or {})
    except Exception as e:
        arn_to_lambda.pop(arn, None)  # avoid a KeyError if the function was never registered
        return error_response('Unknown error: %s %s' % (e, traceback.format_exc()))
Example #24
def kinesis_get_latest_records(stream_name, shard_id, count=10, env=None):
    kinesis = connect_to_service('kinesis', env=env)
    result = []
    response = kinesis.get_shard_iterator(StreamName=stream_name, ShardId=shard_id,
        ShardIteratorType='TRIM_HORIZON')
    shard_iterator = response['ShardIterator']
    while shard_iterator:
        records_response = kinesis.get_records(ShardIterator=shard_iterator)
        records = records_response['Records']
        for record in records:
            try:
                record['Data'] = to_str(record['Data'])
            except Exception:
                pass
        result.extend(records)
        shard_iterator = records_response['NextShardIterator'] if records else False
        while len(result) > count:
            result.pop(0)
    return result
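Note: the inner pop(0) loop keeps only the trailing count records; collections.deque with maxlen expresses the same sliding window more directly. A sketch of the equivalent trimming (an alternative, not the function's actual implementation):

from collections import deque

count = 3
window = deque(maxlen=count)  # automatically evicts the oldest item
for record in ['r1', 'r2', 'r3', 'r4', 'r5']:
    window.append(record)
assert list(window) == ['r3', 'r4', 'r5']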
Example #25
def test_api_gateway_kinesis_integration():
    # create target Kinesis stream
    aws_stack.create_kinesis_stream(TEST_STREAM_KINESIS_API_GW)

    # create API Gateway and connect it to the target stream
    result = connect_api_gateway_to_kinesis('test_gateway1', TEST_STREAM_KINESIS_API_GW)

    # generate test data
    test_data = {'records': [
        {'data': '{"foo": "bar1"}'},
        {'data': '{"foo": "bar2"}'},
        {'data': '{"foo": "bar3"}'}
    ]}

    url = INBOUND_GATEWAY_URL_PATTERN.format(api_id=result['id'],
        stage_name=TEST_STAGE_NAME, path=API_PATH_DATA_INBOUND)
    result = requests.post(url, data=json.dumps(test_data))
    result = json.loads(to_str(result.content))
    assert result['FailedRecordCount'] == 0
    assert len(result['Records']) == len(test_data['records'])
Example #26
def test_api_gateway_lambda_proxy_integration():
    # create lambda function
    zip_file = testutil.create_lambda_archive(load_file(TEST_LAMBDA_PYTHON), get_content=True,
        libs=TEST_LAMBDA_LIBS, runtime=LAMBDA_RUNTIME_PYTHON27)
    testutil.create_lambda_function(func_name=TEST_LAMBDA_PROXY_BACKEND,
        zip_file=zip_file, runtime=LAMBDA_RUNTIME_PYTHON27)

    # create API Gateway and connect it to the Lambda proxy backend
    lambda_uri = aws_stack.lambda_function_arn(TEST_LAMBDA_PROXY_BACKEND)
    target_uri = 'arn:aws:apigateway:%s:lambda:path/2015-03-31/functions/%s/invocations' % (DEFAULT_REGION, lambda_uri)
    result = connect_api_gateway_to_http_with_lambda_proxy('test_gateway2', target_uri,
        path=API_PATH_LAMBDA_PROXY_BACKEND)

    # make test request to gateway and check response
    path = API_PATH_LAMBDA_PROXY_BACKEND.replace('{test_param1}', 'foo1')
    url = INBOUND_GATEWAY_URL_PATTERN.format(api_id=result['id'], stage_name=TEST_STAGE_NAME, path=path)
    data = {'return_status_code': 203, 'return_headers': {'foo': 'bar123'}}
    result = requests.post(url, data=json.dumps(data))
    assert result.status_code == 203
    assert result.headers.get('foo') == 'bar123'
    parsed_body = json.loads(to_str(result.content))
    assert parsed_body.get('return_status_code') == 203
    assert parsed_body.get('return_headers') == {'foo': 'bar123'}
    assert parsed_body.get('pathParameters') == {'test_param1': 'foo1'}
Example #27
def get_payload(request):
    return json.loads(common.to_str(request.data))
Example #28
 def _fix_error_codes(method, data, response):
     if method == 'POST' and 'Action=CreateRole' in to_str(data) and response.status_code >= 400:
         content = to_str(response.content)
         flags = re.MULTILINE | re.DOTALL
         # remove the <Errors> wrapper element, as this breaks AWS Java SDKs (issue #2231)
         response._content = re.sub(r'<Errors>\s*(<Error>(\s|.)*</Error>)\s*</Errors>', r'\1', content, flags=flags)
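Note: in re.sub the fourth positional parameter is count, not flags — the original snippet passed the flags bitmask positionally, silently treating it as a replacement count, so it is passed by keyword above. A standalone check of the unwrapping on a made-up error body:

import re

content = '<Errors>\n  <Error><Code>EntityAlreadyExists</Code></Error>\n</Errors>'
flags = re.MULTILINE | re.DOTALL
fixed = re.sub(r'<Errors>\s*(<Error>(\s|.)*</Error>)\s*</Errors>', r'\1', content, flags=flags)
assert fixed == '<Error><Code>EntityAlreadyExists</Code></Error>'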
Example #29
    def return_response(self, method, path, data, headers, response, request_handler):

        if method == 'POST' and path == '/':
            req_data = urlparse.parse_qs(to_str(data))
            action = req_data.get('Action', [None])[0]
            event_type = None
            queue_url = None
            if action == 'CreateQueue':
                event_type = event_publisher.EVENT_SQS_CREATE_QUEUE
                response_data = xmltodict.parse(response.content)
                if 'CreateQueueResponse' in response_data:
                    queue_url = response_data['CreateQueueResponse']['CreateQueueResult']['QueueUrl']
            elif action == 'DeleteQueue':
                event_type = event_publisher.EVENT_SQS_DELETE_QUEUE
                queue_url = req_data.get('QueueUrl', [None])[0]

            if event_type and queue_url:
                event_publisher.fire_event(event_type, payload={'u': event_publisher.get_hash(queue_url)})

            # patch the response and return the correct endpoint URLs
            if action in ('CreateQueue', 'GetQueueUrl', 'ListQueues'):
                content_str = content_str_original = to_str(response.content)
                new_response = Response()
                new_response.status_code = response.status_code
                new_response.headers = response.headers
                if config.USE_SSL and '<QueueUrl>http://' in content_str:
                    # return https://... if we're supposed to use SSL
                    content_str = re.sub(r'<QueueUrl>\s*http://', r'<QueueUrl>https://', content_str)
                # expose external hostname:port
                external_port = get_external_port(headers, request_handler)
                content_str = re.sub(r'<QueueUrl>\s*([a-z]+)://[^<]*:([0-9]+)/([^<]*)\s*</QueueUrl>',
                    r'<QueueUrl>\1://%s:%s/\3</QueueUrl>' % (HOSTNAME_EXTERNAL, external_port), content_str)
                new_response._content = content_str
                if content_str_original != new_response._content:
                    # if changes have been made, return patched response
                    new_response.headers['content-length'] = len(new_response._content)
                    return new_response

            # Since the following 2 API calls are not implemented in ElasticMQ, we're mocking them
            # and letting them return an empty response
            if action == 'TagQueue':
                new_response = Response()
                new_response.status_code = 200
                new_response._content = (
                    '<?xml version="1.0"?>'
                    '<TagQueueResponse>'
                        '<ResponseMetadata>'  # noqa: W291
                            '<RequestId>{}</RequestId>'  # noqa: W291
                        '</ResponseMetadata>'  # noqa: W291
                    '</TagQueueResponse>'
                ).format(uuid.uuid4())
                return new_response
            elif action == 'ListQueueTags':
                new_response = Response()
                new_response.status_code = 200
                new_response._content = (
                    '<?xml version="1.0"?>'
                    '<ListQueueTagsResponse xmlns="{}">'
                        '<ListQueueTagsResult/>'  # noqa: W291
                        '<ResponseMetadata>'  # noqa: W291
                            '<RequestId>{}</RequestId>'  # noqa: W291
                        '</ResponseMetadata>'  # noqa: W291
                    '</ListQueueTagsResponse>'
                ).format(XMLNS_SQS, uuid.uuid4())
                return new_response
Example #30
    def forward_request(self, method, path, data, headers):

        # check region
        try:
            aws_stack.check_valid_region(headers)
        except Exception as e:
            return make_error(message=str(e), code=400)

        if method == 'POST' and path == '/':
            req_data = urlparse.parse_qs(to_str(data))
            req_action = req_data['Action'][0]
            topic_arn = req_data.get('TargetArn') or req_data.get('TopicArn')

            if topic_arn:
                topic_arn = topic_arn[0]
                do_create_topic(topic_arn)

            if req_action == 'SetSubscriptionAttributes':
                sub = get_subscription_by_arn(req_data['SubscriptionArn'][0])
                if not sub:
                    return make_error(
                        message='Unable to find subscription for given ARN',
                        code=400)
                attr_name = req_data['AttributeName'][0]
                attr_value = req_data['AttributeValue'][0]
                sub[attr_name] = attr_value
                return make_response(req_action)
            elif req_action == 'GetSubscriptionAttributes':
                sub = get_subscription_by_arn(req_data['SubscriptionArn'][0])
                if not sub:
                    return make_error(
                        message='Unable to find subscription for given ARN',
                        code=400)
                content = '<Attributes>'
                for key, value in sub.items():
                    content += '<entry><key>%s</key><value>%s</value></entry>\n' % (
                        key, value)
                content += '</Attributes>'
                return make_response(req_action, content=content)
            elif req_action == 'Subscribe':
                if 'Endpoint' not in req_data:
                    return make_error(
                        message='Endpoint not specified in subscription',
                        code=400)
            elif req_action == 'Unsubscribe':
                if 'SubscriptionArn' not in req_data:
                    return make_error(
                        message='SubscriptionArn not specified in unsubscribe request',
                        code=400)
                do_unsubscribe(req_data.get('SubscriptionArn')[0])
            elif req_action == 'DeleteTopic':
                do_delete_topic(topic_arn)

            elif req_action == 'Publish':
                message = req_data['Message'][0]
                sqs_client = aws_stack.connect_to_service('sqs')
                for subscriber in SNS_SUBSCRIPTIONS.get(topic_arn, []):
                    filter_policy = json.loads(
                        subscriber.get('FilterPolicy', '{}'))
                    message_attributes = get_message_attributes(req_data)
                    if check_filter_policy(filter_policy, message_attributes):
                        if subscriber['Protocol'] == 'sqs':
                            endpoint = subscriber['Endpoint']
                            if 'sqs_queue_url' in subscriber:
                                queue_url = subscriber.get('sqs_queue_url')
                            elif '://' in endpoint:
                                queue_url = endpoint
                            else:
                                queue_name = endpoint.split(':')[5]
                                queue_url = aws_stack.get_sqs_queue_url(
                                    queue_name)
                                subscriber['sqs_queue_url'] = queue_url
                            try:
                                sqs_client.send_message(
                                    QueueUrl=queue_url,
                                    MessageBody=create_sns_message_body(subscriber, req_data),
                                    MessageAttributes=create_sqs_message_attributes(
                                        subscriber, message_attributes))
                            except Exception as exc:
                                return make_error(message=str(exc), code=400)
                        elif subscriber['Protocol'] == 'lambda':
                            lambda_api.process_sns_notification(
                                subscriber['Endpoint'],
                                topic_arn,
                                message,
                                subject=req_data.get('Subject', [None])[0])
                        elif subscriber['Protocol'] in ['http', 'https']:
                            try:
                                message_body = create_sns_message_body(
                                    subscriber, req_data)
                            except Exception as exc:
                                return make_error(message=str(exc), code=400)
                            requests.post(
                                subscriber['Endpoint'],
                                headers={
                                    'Content-Type': 'text/plain',
                                    'x-amz-sns-message-type': 'Notification'
                                },
                                data=message_body)
                        else:
                            LOGGER.warning('Unexpected protocol "%s" for SNS subscription'
                                % subscriber['Protocol'])
                # return response here because we do not want the request to be forwarded to SNS
                return make_response(req_action)

        return True
Example #31
    def return_response(self, method, path, data, headers, response,
                        request_handler):
        if method == 'OPTIONS' and path == '/':
            # Allow CORS preflight requests to succeed.
            return 200

        if method != 'POST':
            return

        region_name = extract_region_from_auth_header(headers)
        req_data = urlparse.parse_qs(to_str(data))
        action = req_data.get('Action', [None])[0]
        content_str = content_str_original = to_str(response.content)

        self._fire_event(req_data, response)

        # patch the response and add missing attributes
        if action == 'GetQueueAttributes':
            content_str = self._add_queue_attributes(path, req_data,
                                                     content_str, headers)

        # patch the response and return the correct endpoint URLs / ARNs
        if action in ('CreateQueue', 'GetQueueUrl', 'ListQueues',
                      'GetQueueAttributes'):
            if config.USE_SSL and '<QueueUrl>http://' in content_str:
                # return https://... if we're supposed to use SSL
                content_str = re.sub(r'<QueueUrl>\s*http://',
                                     r'<QueueUrl>https://', content_str)
            # expose external hostname:port
            external_port = SQS_PORT_EXTERNAL or get_external_port(
                headers, request_handler)
            content_str = re.sub(
                r'<QueueUrl>\s*([a-z]+)://[^<]*:([0-9]+)/([^<]*)\s*</QueueUrl>',
                r'<QueueUrl>\1://%s:%s/\3</QueueUrl>' %
                (HOSTNAME_EXTERNAL, external_port), content_str)
            # fix queue ARN
            content_str = re.sub(
                r'<([a-zA-Z0-9]+)>\s*arn:aws:sqs:elasticmq:([^<]+)</([a-zA-Z0-9]+)>',
                r'<\1>arn:aws:sqs:%s:\2</\3>' % (region_name), content_str)

        if content_str_original != content_str:
            # if changes have been made, return patched response
            new_response = Response()
            new_response.status_code = response.status_code
            new_response.headers = response.headers
            new_response._content = content_str
            new_response.headers['content-length'] = len(new_response._content)
            return new_response

        # Since the following 2 API calls are not implemented in ElasticMQ, we're mocking them
        # and letting them return an empty response
        if action == 'TagQueue':
            new_response = Response()
            new_response.status_code = 200
            new_response._content = ("""
                <?xml version="1.0"?>
                <TagQueueResponse>
                    <ResponseMetadata>
                        <RequestId>{}</RequestId>
                    </ResponseMetadata>
                </TagQueueResponse>
            """).strip().format(uuid.uuid4())
            return new_response
        elif action == 'ListQueueTags':
            new_response = Response()
            new_response.status_code = 200
            new_response._content = ("""
                <?xml version="1.0"?>
                <ListQueueTagsResponse xmlns="{}">
                    <ListQueueTagsResult/>
                    <ResponseMetadata>
                        <RequestId>{}</RequestId>
                    </ResponseMetadata>
                </ListQueueTagsResponse>
            """).strip().format(XMLNS_SQS, uuid.uuid4())
            return new_response
Example #32
    def return_response(self, method, path, data, headers, response):
        data = json.loads(to_str(data))

        # update table definitions
        if data and 'TableName' in data and 'KeySchema' in data:
            TABLE_DEFINITIONS[data['TableName']] = data

        if response._content:
            # fix the table ARN (DynamoDBLocal hardcodes "ddblocal" as the region)
            content_replaced = re.sub(
                r'"TableArn"\s*:\s*"arn:aws:dynamodb:ddblocal:([^"]+)"',
                r'"TableArn": "arn:aws:dynamodb:%s:\1"' %
                aws_stack.get_local_region(), to_str(response._content))
            if content_replaced != response._content:
                response._content = content_replaced
                fix_headers_for_updated_response(response)

        action = headers.get('X-Amz-Target')
        if not action:
            return

        record = {
            'eventID': '1',
            'eventVersion': '1.0',
            'dynamodb': {
                'StreamViewType': 'NEW_AND_OLD_IMAGES',
                'SizeBytes': -1
            },
            'awsRegion': DEFAULT_REGION,
            'eventSource': 'aws:dynamodb'
        }
        records = [record]

        if action == '%s.UpdateItem' % ACTION_PREFIX:
            updated_item = find_existing_item(data)
            if not updated_item:
                return
            record['eventName'] = 'MODIFY'
            record['dynamodb']['Keys'] = data['Key']
            record['dynamodb']['OldImage'] = ProxyListenerDynamoDB.thread_local.existing_item
            record['dynamodb']['NewImage'] = updated_item
            record['dynamodb']['SizeBytes'] = len(json.dumps(updated_item))
        elif action == '%s.BatchWriteItem' % ACTION_PREFIX:
            records = []
            # renamed from 'requests' to avoid shadowing the requests library
            for table_name, table_requests in data['RequestItems'].items():
                for request in table_requests:
                    put_request = request.get('PutRequest')
                    if put_request:
                        keys = dynamodb_extract_keys(item=put_request['Item'],
                                                     table_name=table_name)
                        if isinstance(keys, Response):
                            return keys
                        new_record = clone(record)
                        new_record['eventName'] = 'INSERT'
                        new_record['dynamodb']['Keys'] = keys
                        new_record['dynamodb']['NewImage'] = put_request['Item']
                        new_record['eventSourceARN'] = aws_stack.dynamodb_table_arn(table_name)
                        records.append(new_record)
        elif action == '%s.PutItem' % ACTION_PREFIX:
            existing_item = ProxyListenerDynamoDB.thread_local.existing_item
            ProxyListenerDynamoDB.thread_local.existing_item = None
            record['eventName'] = 'INSERT' if not existing_item else 'MODIFY'
            keys = dynamodb_extract_keys(item=data['Item'],
                                         table_name=data['TableName'])
            if isinstance(keys, Response):
                return keys
            record['dynamodb']['Keys'] = keys
            record['dynamodb']['NewImage'] = data['Item']
            record['dynamodb']['SizeBytes'] = len(json.dumps(data['Item']))
        elif action == '%s.GetItem' % ACTION_PREFIX:
            if response.status_code == 200:
                content = json.loads(to_str(response.content))
                # make sure we append 'ConsumedCapacity', which is properly
                # returned by dynalite, but not by AWS's DynamoDBLocal
                if 'ConsumedCapacity' not in content and data.get(
                        'ReturnConsumedCapacity') in ('TOTAL', 'INDEXES'):
                    content['ConsumedCapacity'] = {
                        'CapacityUnits': 0.5,  # TODO hardcoded
                        'TableName': data['TableName']
                    }
                    response._content = json.dumps(content)
                    fix_headers_for_updated_response(response)
        elif action == '%s.DeleteItem' % ACTION_PREFIX:
            record['eventName'] = 'REMOVE'
            record['dynamodb']['Keys'] = data['Key']
        elif action == '%s.CreateTable' % ACTION_PREFIX:
            if 'StreamSpecification' in data:
                create_dynamodb_stream(data)
            event_publisher.fire_event(
                event_publisher.EVENT_DYNAMODB_CREATE_TABLE,
                payload={'n': event_publisher.get_hash(data['TableName'])})
            return
        elif action == '%s.DeleteTable' % ACTION_PREFIX:
            event_publisher.fire_event(
                event_publisher.EVENT_DYNAMODB_DELETE_TABLE,
                payload={'n': event_publisher.get_hash(data['TableName'])})
            return
        elif action == '%s.UpdateTable' % ACTION_PREFIX:
            if 'StreamSpecification' in data:
                create_dynamodb_stream(data)
            return
        else:
            # nothing to do
            return

        if len(records) > 0 and 'eventName' in records[0]:
            if 'TableName' in data:
                records[0]['eventSourceARN'] = aws_stack.dynamodb_table_arn(data['TableName'])
            forward_to_lambda(records)
            forward_to_ddb_stream(records)
Example #33
 def decode_content(self, data):
     try:
         return json.loads(to_str(data))
     except UnicodeDecodeError:
         return cbor2.loads(data)
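Note: decode_content first tries JSON and falls back to CBOR, which some Kinesis clients send over the wire. A self-contained sketch of both paths, assuming the cbor2 package is available (the fallback here also catches ValueError, slightly broader than the method above):

import json
import cbor2

def decode_content(data):
    try:
        return json.loads(data.decode('utf-8') if isinstance(data, bytes) else data)
    except (UnicodeDecodeError, ValueError):
        return cbor2.loads(data)

assert decode_content(b'{"a": 1}') == {'a': 1}
assert decode_content(cbor2.dumps({'a': 1})) == {'a': 1}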
Example #34
 def _escape(val):
     try:
         return val and escape(to_str(val))
     except Exception:
         return val
Example #35
    def return_response(self, method, path, data, headers, response):
        action = headers.get('X-Amz-Target')
        data = json.loads(to_str(data))

        print('ProxyListenerKinesis return_response data: [', data,
              '] action [', action, ']')

        records = []
        if action in (ACTION_CREATE_STREAM, ACTION_DELETE_STREAM):
            event_type = (event_publisher.EVENT_KINESIS_CREATE_STREAM
                          if action == ACTION_CREATE_STREAM else
                          event_publisher.EVENT_KINESIS_DELETE_STREAM)
            payload = {'n': event_publisher.get_hash(data.get('StreamName'))}
            if action == ACTION_CREATE_STREAM:
                payload['s'] = data.get('ShardCount')
            print(
                'ProxyListenerKinesis return_response ACTION_CREATE_STREAM, ACTION_DELETE_STREAM '
                + 'event_publisher.fire_event event_type: [', event_type,
                '] payload[', payload, ']')
            event_publisher.fire_event(event_type, payload=payload)
        elif action == ACTION_PUT_RECORD:
            response_body = json.loads(to_str(response.content))
            event_record = {
                'data': data['Data'],
                'partitionKey': data['PartitionKey'],
                'sequenceNumber': response_body.get('SequenceNumber')
            }
            event_records = [event_record]
            stream_name = data['StreamName']
            print(
                'ProxyListenerKinesis return_response ACTION_PUT_RECORD: ' +
                'lambda_api.process_kinesis_records event_record[',
                event_record, 'stream_name[', stream_name, ']')
            lambda_api.process_kinesis_records(event_records, stream_name)
        elif action == ACTION_PUT_RECORDS:
            event_records = []
            response_body = json.loads(to_str(response.content))
            if 'Records' in response_body:
                response_records = response_body['Records']
                records = data['Records']
                for index, record in enumerate(records):
                    event_record = {
                        'data': record['Data'],
                        'partitionKey': record['PartitionKey'],
                        'sequenceNumber': response_records[index].get('SequenceNumber')
                    }
                    event_records.append(event_record)
                stream_name = data['StreamName']
                print(
                    'ProxyListenerKinesis return_response ACTION_PUT_RECORDS: '
                    + 'lambda_api.process_kinesis_records event_record[',
                    event_records, 'stream_name[', stream_name, ']')
                lambda_api.process_kinesis_records(event_records, stream_name)
        elif action == ACTION_UPDATE_SHARD_COUNT:
            # Currently kinesalite, which backs the Kinesis implementation for localstack, does
            # not support UpdateShardCount:
            # https://github.com/mhart/kinesalite/issues/61
            #
            # [Terraform](https://www.terraform.io) makes the call to UpdateShardCount when it
            # applies Kinesis resources. A Terraform run fails when this is not present.
            #
            # The code that follows just returns a successful response, bypassing the 400
            # response that kinesalite returns.
            #
            response = Response()
            response.status_code = 200
            content = {
                'CurrentShardCount': 1,
                'StreamName': data['StreamName'],
                'TargetShardCount': data['TargetShardCount']
            }
            response.encoding = 'UTF-8'
            response._content = json.dumps(content)
            print(
                'ProxyListenerKinesis return_response ACTION_UPDATE_SHARD_COUNT:  response[',
                response, ']')
            return response
Example #36
    def return_response(self, method, path, data, headers, response,
                        request_handler):
        # persist requests to disk
        super(ProxyListenerSQS, self).return_response(
            method, path, data, headers, response, request_handler)

        if method == 'OPTIONS' and path == '/':
            # Allow CORS preflight requests to succeed.
            return 200

        if method != 'POST':
            return

        region_name = aws_stack.get_region()
        req_data = parse_request_data(method, path, data)
        action = req_data.get('Action')
        content_str = content_str_original = to_str(response.content)

        if response.status_code >= 400:
            return response

        _fire_event(req_data, response)

        # patch the response and add missing attributes
        if action == 'GetQueueAttributes':
            content_str = _add_queue_attributes(path, req_data, content_str,
                                                headers)

        # patch the response and return the correct endpoint URLs / ARNs
        if action in ('CreateQueue', 'GetQueueUrl', 'ListQueues',
                      'GetQueueAttributes'):
            if config.USE_SSL and '<QueueUrl>http://' in content_str:
                # return https://... if we're supposed to use SSL
                content_str = re.sub(r'<QueueUrl>\s*http://',
                                     r'<QueueUrl>https://', content_str)
            # expose external hostname:port
            external_port = SQS_PORT_EXTERNAL or get_external_port(
                headers, request_handler)
            content_str = re.sub(
                r'<QueueUrl>\s*([a-z]+)://[^<]*:([0-9]+)/([^<]*)\s*</QueueUrl>',
                r'<QueueUrl>\1://%s:%s/\3</QueueUrl>' %
                (HOSTNAME_EXTERNAL, external_port), content_str)
            # encode account ID in queue URL
            content_str = re.sub(
                r'<QueueUrl>\s*([a-z]+)://([^/]+)/queue/([^<]*)\s*</QueueUrl>',
                r'<QueueUrl>\1://\2/%s/\3</QueueUrl>' %
                constants.TEST_AWS_ACCOUNT_ID, content_str)
            # fix queue ARN
            content_str = re.sub(
                r'<([a-zA-Z0-9]+)>\s*arn:aws:sqs:elasticmq:([^<]+)</([a-zA-Z0-9]+)>',
                r'<\1>arn:aws:sqs:%s:\2</\3>' % region_name, content_str)

            if action == 'CreateQueue':
                queue_url = re.match(r'.*<QueueUrl>(.*)</QueueUrl>',
                                     content_str, re.DOTALL).group(1)
                _set_queue_attributes(queue_url, req_data)

        elif action == 'SendMessageBatch':
            if validate_empty_message_batch(data, req_data):
                msg = 'There should be at least one SendMessageBatchRequestEntry in the request.'
                return make_requests_error(code=404,
                                           code_string='EmptyBatchRequest',
                                           message=msg)

        # instruct listeners to fetch new SQS message
        if action in ('SendMessage', 'SendMessageBatch'):
            _process_sent_message(path, req_data, headers)

        if content_str_original != content_str:
            # if changes have been made, return patched response
            response.headers['content-length'] = str(len(content_str))
            return requests_response(content_str,
                                     headers=response.headers,
                                     status_code=response.status_code)
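
To make the QueueUrl rewriting above concrete, here is the external-hostname substitution applied to a sample response body in isolation; the hostname and port are made-up stand-ins for HOSTNAME_EXTERNAL and the external port:

import re

content = '<QueueUrl>http://localhost:4576/queue/my-queue</QueueUrl>'
content = re.sub(r'<QueueUrl>\s*([a-z]+)://[^<]*:([0-9]+)/([^<]*)\s*</QueueUrl>',
                 r'<QueueUrl>\1://%s:%s/\3</QueueUrl>' % ('sqs.example.com', 4576),
                 content)
print(content)  # <QueueUrl>http://sqs.example.com:4576/queue/my-queue</QueueUrl>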
Example #37
def validate_empty_message_batch(data, req_data):
    data = to_str(data).split('Entries=')
    if len(data) > 1 and not req_data.get('Entries'):
        return True
    return False
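
A quick illustration of the two cases, assuming the helper above is in scope: the raw POST body mentions 'Entries=' while the parsed request data carries no entries, which is exactly the empty-batch condition the function detects.

raw_data = 'Action=SendMessageBatch&Entries='
assert validate_empty_message_batch(raw_data, {'Action': 'SendMessageBatch'})
assert not validate_empty_message_batch('Action=SendMessage', {})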
Example #38
    def forward_request(self, method, path, data, headers):
        event = json.loads(to_str(data))
        events.append(event)
        return 200
Example #39
    def test_apigateway_with_lambda_integration(self):
        apigw_client = aws_stack.connect_to_service('apigateway')

        # create Lambda function
        lambda_name = 'apigw-lambda-%s' % short_uid()
        self.create_lambda_function(lambda_name)
        lambda_uri = aws_stack.lambda_function_arn(lambda_name)
        target_uri = aws_stack.apigateway_invocations_arn(lambda_uri)

        # create REST API
        api = apigw_client.create_rest_api(name='test-api', description='')
        api_id = api['id']
        root_res_id = apigw_client.get_resources(restApiId=api_id)['items'][0]['id']
        api_resource = apigw_client.create_resource(restApiId=api_id, parentId=root_res_id, pathPart='test')

        apigw_client.put_method(
            restApiId=api_id,
            resourceId=api_resource['id'],
            httpMethod='GET',
            authorizationType='NONE'
        )

        rs = apigw_client.put_integration(
            restApiId=api_id,
            resourceId=api_resource['id'],
            httpMethod='GET',
            integrationHttpMethod='POST',
            type='AWS',
            uri=target_uri,
            timeoutInMillis=3000,
            contentHandling='CONVERT_TO_BINARY',
            requestTemplates={
                'application/json': '{"param1": "$input.params(\'param1\')"}'
            }
        )
        integration_keys = ['httpMethod', 'type', 'passthroughBehavior', 'cacheKeyParameters', 'uri', 'cacheNamespace',
            'timeoutInMillis', 'contentHandling', 'requestParameters']
        self.assertEqual(rs['ResponseMetadata']['HTTPStatusCode'], 200)
        for key in integration_keys:
            self.assertIn(key, rs)
        self.assertNotIn('responseTemplates', rs)

        apigw_client.create_deployment(restApiId=api_id, stageName=self.TEST_STAGE_NAME)

        rs = apigw_client.get_integration(
            restApiId=api_id,
            resourceId=api_resource['id'],
            httpMethod='GET'
        )
        self.assertEqual(rs['ResponseMetadata']['HTTPStatusCode'], 200)
        self.assertEqual(rs['type'], 'AWS')
        self.assertEqual(rs['httpMethod'], 'POST')
        self.assertEqual(rs['uri'], target_uri)

        # invoke the gateway endpoint
        url = gateway_request_url(api_id=api_id, stage_name=self.TEST_STAGE_NAME, path='/test')
        response = requests.get('%s?param1=foobar' % url)
        self.assertLess(response.status_code, 400)
        content = json.loads(to_str(response.content))
        self.assertEqual(content.get('httpMethod'), 'GET')
        self.assertEqual(content.get('requestContext', {}).get('resourceId'), api_resource['id'])
        self.assertEqual(content.get('requestContext', {}).get('stage'), self.TEST_STAGE_NAME)
        self.assertEqual(content.get('body'), '{"param1": "foobar"}')

        # delete integration
        rs = apigw_client.delete_integration(
            restApiId=api_id,
            resourceId=api_resource['id'],
            httpMethod='GET',
        )
        self.assertEqual(rs['ResponseMetadata']['HTTPStatusCode'], 200)

        with self.assertRaises(ClientError) as ctx:
            # This call should not be successful as the integration is deleted
            apigw_client.get_integration(
                restApiId=api_id,
                resourceId=api_resource['id'],
                httpMethod='GET'
            )
        self.assertEqual(ctx.exception.response['Error']['Code'], 'BadRequestException')

        # clean up
        lambda_client = aws_stack.connect_to_service('lambda')
        lambda_client.delete_function(FunctionName=lambda_name)
        apigw_client.delete_rest_api(restApiId=api_id)
Example #40
def test_lambda_runtimes():

    lambda_client = aws_stack.connect_to_service('lambda')

    # deploy and invoke lambda - Python 2.7
    zip_file = testutil.create_lambda_archive(load_file(TEST_LAMBDA_PYTHON),
                                              get_content=True,
                                              libs=TEST_LAMBDA_LIBS,
                                              runtime=LAMBDA_RUNTIME_PYTHON27)
    testutil.create_lambda_function(func_name=TEST_LAMBDA_NAME_PY,
                                    zip_file=zip_file,
                                    runtime=LAMBDA_RUNTIME_PYTHON27)

    # Invocation Type not set
    result = lambda_client.invoke(FunctionName=TEST_LAMBDA_NAME_PY,
                                  Payload=b'{}')
    assert result['StatusCode'] == 200
    result_data = result['Payload'].read()
    assert to_str(result_data).strip() == '{}'

    # Invocation Type - RequestResponse
    result = lambda_client.invoke(FunctionName=TEST_LAMBDA_NAME_PY,
                                  Payload=b'{}',
                                  InvocationType='RequestResponse')
    assert result['StatusCode'] == 200
    result_data = result['Payload'].read()
    assert to_str(result_data).strip() == '{}'

    # Invocation Type - Event
    result = lambda_client.invoke(FunctionName=TEST_LAMBDA_NAME_PY,
                                  Payload=b'{}',
                                  InvocationType='Event')
    assert result['StatusCode'] == 202

    # Invocation Type - DryRun
    result = lambda_client.invoke(FunctionName=TEST_LAMBDA_NAME_PY,
                                  Payload=b'{}',
                                  InvocationType='DryRun')
    assert result['StatusCode'] == 204

    if use_docker():
        # deploy and invoke lambda - Python 3.6
        zip_file = testutil.create_lambda_archive(
            load_file(TEST_LAMBDA_PYTHON3),
            get_content=True,
            libs=TEST_LAMBDA_LIBS,
            runtime=LAMBDA_RUNTIME_PYTHON36)
        testutil.create_lambda_function(func_name=TEST_LAMBDA_NAME_PY3,
                                        zip_file=zip_file,
                                        runtime=LAMBDA_RUNTIME_PYTHON36)
        result = lambda_client.invoke(FunctionName=TEST_LAMBDA_NAME_PY3,
                                      Payload=b'{}')
        assert result['StatusCode'] == 200
        result_data = result['Payload'].read()
        assert to_str(result_data).strip() == '{}'

    # deploy and invoke lambda - Java
    if not os.path.exists(TEST_LAMBDA_JAVA):
        mkdir(os.path.dirname(TEST_LAMBDA_JAVA))
        download(TEST_LAMBDA_JAR_URL, TEST_LAMBDA_JAVA)
    # Lambda supports single JAR deployments without the zip, so we upload the JAR directly.
    test_java_jar = load_file(TEST_LAMBDA_JAVA, mode='rb')
    assert test_java_jar is not None
    testutil.create_lambda_function(
        func_name=TEST_LAMBDA_NAME_JAVA,
        zip_file=test_java_jar,
        runtime=LAMBDA_RUNTIME_JAVA8,
        handler='cloud.localstack.sample.LambdaHandler')
    result = lambda_client.invoke(FunctionName=TEST_LAMBDA_NAME_JAVA,
                                  Payload=b'{}')
    assert result['StatusCode'] == 200
    result_data = result['Payload'].read()
    assert 'LinkedHashMap' in to_str(result_data)

    # test SNSEvent
    result = lambda_client.invoke(
        FunctionName=TEST_LAMBDA_NAME_JAVA,
        InvocationType='Event',
        Payload=b'{"Records": [{"Sns": {"Message": "{}"}}]}')
    assert result['StatusCode'] == 202

    # test DDBEvent
    result = lambda_client.invoke(
        FunctionName=TEST_LAMBDA_NAME_JAVA,
        InvocationType='Event',
        Payload=b'{"Records": [{"dynamodb": {"Message": "{}"}}]}')
    assert result['StatusCode'] == 202

    # test KinesisEvent
    result = lambda_client.invoke(
        FunctionName=TEST_LAMBDA_NAME_JAVA,
        Payload=
        b'{"Records": [{"Kinesis": {"Data": "data", "PartitionKey": "partition"}}]}'
    )
    assert result['StatusCode'] == 200
    result_data = result['Payload'].read()
    assert 'KinesisEvent' in to_str(result_data)

    # deploy and invoke lambda - Java with stream handler
    testutil.create_lambda_function(
        func_name=TEST_LAMBDA_NAME_JAVA_STREAM,
        zip_file=test_java_jar,
        runtime=LAMBDA_RUNTIME_JAVA8,
        handler='cloud.localstack.sample.LambdaStreamHandler')
    result = lambda_client.invoke(FunctionName=TEST_LAMBDA_NAME_JAVA_STREAM,
                                  Payload=b'{}')
    assert result['StatusCode'] == 200
    result_data = result['Payload'].read()
    assert to_str(result_data).strip() == '{}'

    # deploy and invoke lambda - Java with serializable input object
    testutil.create_lambda_function(
        func_name=TEST_LAMBDA_NAME_JAVA_SERIALIZABLE,
        zip_file=test_java_jar,
        runtime=LAMBDA_RUNTIME_JAVA8,
        handler='cloud.localstack.sample.SerializedInputLambdaHandler')
    result = lambda_client.invoke(
        FunctionName=TEST_LAMBDA_NAME_JAVA_SERIALIZABLE,
        Payload=b'{"bucket": "test_bucket", "key": "test_key"}')
    assert result['StatusCode'] == 200
    result_data = result['Payload'].read()
    assert json.loads(to_str(result_data)) == {
        'validated': True,
        'bucket': 'test_bucket',
        'key': 'test_key'
    }

    if use_docker():
        # deploy and invoke lambda - Node.js
        zip_file = testutil.create_zip_file(TEST_LAMBDA_NODEJS,
                                            get_content=True)
        testutil.create_lambda_function(func_name=TEST_LAMBDA_NAME_JS,
                                        zip_file=zip_file,
                                        handler='lambda_integration.handler',
                                        runtime=LAMBDA_RUNTIME_NODEJS)
        result = lambda_client.invoke(FunctionName=TEST_LAMBDA_NAME_JS,
                                      Payload=b'{}')
        assert result['StatusCode'] == 200
        result_data = result['Payload'].read()
        assert to_str(result_data).strip() == '{}'

        # deploy and invoke - .NET Core 2.0. It's already a zip
        zip_file = TEST_LAMBDA_DOTNETCORE2
        zip_file_content = None
        with open(zip_file, 'rb') as file_obj:
            zip_file_content = file_obj.read()
        testutil.create_lambda_function(
            func_name=TEST_LAMBDA_NAME_DOTNETCORE2,
            zip_file=zip_file_content,
            handler='DotNetCore2::DotNetCore2.Lambda.Function::SimpleFunctionHandler',
            runtime=LAMBDA_RUNTIME_DOTNETCORE2)
        result = lambda_client.invoke(
            FunctionName=TEST_LAMBDA_NAME_DOTNETCORE2, Payload=b'{}')
        assert result['StatusCode'] == 200
        result_data = result['Payload'].read()
        assert to_str(result_data).strip() == '{}'
Example #41
def _fix_next_token_request(data):
    # Fix for https://github.com/localstack/localstack/issues/1527
    pattern = r'"nextToken":\s*"([0-9]+)"'
    replacement = r'"nextToken": \1'
    return re.sub(pattern, replacement, to_str(data))
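
For example, a request body with a quoted numeric token is rewritten so the token becomes a bare integer (the input values here are made up):

body = '{"logGroupName": "my-group", "nextToken": "4711"}'
print(_fix_next_token_request(body))
# {"logGroupName": "my-group", "nextToken": 4711}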
Example #42
    def return_response(self, method, path, data, headers, response):
        if path.startswith('/shell') or method == 'GET':
            return

        data = json.loads(to_str(data))

        # update table definitions
        if data and 'TableName' in data and 'KeySchema' in data:
            TABLE_DEFINITIONS[data['TableName']] = data

        if response._content:
            # fix the table and latest stream ARNs (DynamoDBLocal hardcodes "ddblocal" as the region)
            content_replaced = re.sub(
                r'("TableArn"|"LatestStreamArn"|"StreamArn")\s*:\s*"arn:aws:dynamodb:ddblocal:([^"]+)"',
                r'\1: "arn:aws:dynamodb:%s:\2"' % aws_stack.get_region(),
                to_str(response._content))
            if content_replaced != response._content:
                response._content = content_replaced
                fix_headers_for_updated_response(response)

        action = headers.get('X-Amz-Target')
        if not action:
            return

        # upgrade event version to 1.1
        record = {
            'eventID': '1',
            'eventVersion': '1.1',
            'dynamodb': {
                'ApproximateCreationDateTime': time.time(),
                'StreamViewType': 'NEW_AND_OLD_IMAGES',
                'SizeBytes': -1
            },
            'awsRegion': aws_stack.get_region(),
            'eventSource': 'aws:dynamodb'
        }
        records = [record]

        streams_enabled_cache = {}
        table_name = data.get('TableName')
        event_sources_or_streams_enabled = has_event_sources_or_streams_enabled(
            table_name, streams_enabled_cache)

        if action == '%s.UpdateItem' % ACTION_PREFIX:
            if response.status_code == 200 and event_sources_or_streams_enabled:
                existing_item = self._thread_local('existing_item')
                record['eventName'] = 'INSERT' if not existing_item else 'MODIFY'

                updated_item = find_existing_item(data)
                if not updated_item:
                    return
                record['dynamodb']['Keys'] = data['Key']
                if existing_item:
                    record['dynamodb']['OldImage'] = existing_item
                record['dynamodb']['NewImage'] = updated_item
                record['dynamodb']['SizeBytes'] = len(json.dumps(updated_item))
        elif action == '%s.BatchWriteItem' % ACTION_PREFIX:
            records = self.prepare_batch_write_item_records(record, data)
            for record in records:
                event_sources_or_streams_enabled = (
                    event_sources_or_streams_enabled
                    or has_event_sources_or_streams_enabled(
                        record['eventSourceARN'], streams_enabled_cache))

        elif action == '%s.TransactWriteItems' % ACTION_PREFIX:
            records = self.prepare_transact_write_item_records(record, data)
            for record in records:
                event_sources_or_streams_enabled = (
                    event_sources_or_streams_enabled
                    or has_event_sources_or_streams_enabled(
                        record['eventSourceARN'], streams_enabled_cache))

        elif action == '%s.PutItem' % ACTION_PREFIX:
            if response.status_code == 200:
                keys = dynamodb_extract_keys(item=data['Item'],
                                             table_name=table_name)
                if isinstance(keys, Response):
                    return keys
                # fix response
                if response._content == '{}':
                    response._content = update_put_item_response_content(
                        data, response._content)
                    fix_headers_for_updated_response(response)
                if event_sources_or_streams_enabled:
                    existing_item = self._thread_local('existing_item')
                    record['eventName'] = 'INSERT' if not existing_item else 'MODIFY'
                    # prepare record keys
                    record['dynamodb']['Keys'] = keys
                    record['dynamodb']['NewImage'] = data['Item']
                    record['dynamodb']['SizeBytes'] = len(
                        json.dumps(data['Item']))
                    if existing_item:
                        record['dynamodb']['OldImage'] = existing_item

        elif action in [
                '%s.GetItem' % ACTION_PREFIX,
                '%s.Query' % ACTION_PREFIX
        ]:
            if response.status_code == 200:
                content = json.loads(to_str(response.content))
                # make sure we append 'ConsumedCapacity', which is properly
                # returned by dynalite, but not by AWS's DynamoDBLocal
                if 'ConsumedCapacity' not in content and data.get(
                        'ReturnConsumedCapacity') in ['TOTAL', 'INDEXES']:
                    content['ConsumedCapacity'] = {
                        'TableName': table_name,
                        'CapacityUnits': 5,  # TODO hardcoded
                        'ReadCapacityUnits': 2,
                        'WriteCapacityUnits': 3
                    }
                    response._content = json.dumps(content)
                    fix_headers_for_updated_response(response)

        elif action == '%s.DeleteItem' % ACTION_PREFIX:
            if response.status_code == 200 and event_sources_or_streams_enabled:
                old_item = self._thread_local('existing_item')
                record['eventName'] = 'REMOVE'
                record['dynamodb']['Keys'] = data['Key']
                record['dynamodb']['OldImage'] = old_item

        elif action == '%s.CreateTable' % ACTION_PREFIX:
            if 'StreamSpecification' in data:
                if response.status_code == 200:
                    content = json.loads(to_str(response._content))
                    create_dynamodb_stream(
                        data,
                        content['TableDescription'].get('LatestStreamLabel'))

            event_publisher.fire_event(
                event_publisher.EVENT_DYNAMODB_CREATE_TABLE,
                payload={'n': event_publisher.get_hash(table_name)})

            if data.get('Tags') and response.status_code == 200:
                table_arn = json.loads(
                    response._content)['TableDescription']['TableArn']
                TABLE_TAGS[table_arn] = {
                    tag['Key']: tag['Value']
                    for tag in data['Tags']
                }

            return

        elif action == '%s.DeleteTable' % ACTION_PREFIX:
            if response.status_code == 200:
                table_arn = json.loads(response._content).get(
                    'TableDescription', {}).get('TableArn')
                event_publisher.fire_event(
                    event_publisher.EVENT_DYNAMODB_DELETE_TABLE,
                    payload={'n': event_publisher.get_hash(table_name)})
                self.delete_all_event_source_mappings(table_arn)
                dynamodbstreams_api.delete_streams(table_arn)
                TABLE_TAGS.pop(table_arn, None)
            return

        elif action == '%s.UpdateTable' % ACTION_PREFIX:
            if 'StreamSpecification' in data:
                if response.status_code == 200:
                    content = json.loads(to_str(response._content))
                    create_dynamodb_stream(
                        data,
                        content['TableDescription'].get('LatestStreamLabel'))
            return

        elif action == '%s.TagResource' % ACTION_PREFIX:
            table_arn = data['ResourceArn']
            if table_arn not in TABLE_TAGS:
                TABLE_TAGS[table_arn] = {}
            TABLE_TAGS[table_arn].update(
                {tag['Key']: tag['Value']
                 for tag in data.get('Tags', [])})
            return

        elif action == '%s.UntagResource' % ACTION_PREFIX:
            table_arn = data['ResourceArn']
            for tag_key in data.get('TagKeys', []):
                TABLE_TAGS.get(table_arn, {}).pop(tag_key, None)
            return

        else:
            # nothing to do
            return

        if event_sources_or_streams_enabled and records and 'eventName' in records[0]:
            if 'TableName' in data:
                records[0]['eventSourceARN'] = aws_stack.dynamodb_table_arn(table_name)

            forward_to_lambda(records)
            forward_to_ddb_stream(records)
Example #43
    def _test_api_gateway_lambda_proxy_integration(self, fn_name, path):

        self.create_lambda_function(fn_name)
        # create API Gateway and connect it to the Lambda proxy backend
        lambda_uri = aws_stack.lambda_function_arn(fn_name)
        invocation_uri = 'arn:aws:apigateway:%s:lambda:path/2015-03-31/functions/%s/invocations'
        target_uri = invocation_uri % (DEFAULT_REGION, lambda_uri)

        result = self.connect_api_gateway_to_http_with_lambda_proxy(
            'test_gateway2', target_uri, path=path)

        api_id = result['id']
        path_map = get_rest_api_paths(api_id)
        _, resource = get_resource_for_path('/lambda/foo1', path_map)

        # make test request to gateway and check response
        path = path.replace('{test_param1}', 'foo1')
        path = path + '?foo=foo&bar=bar&bar=baz'

        url = INBOUND_GATEWAY_URL_PATTERN.format(
            api_id=api_id, stage_name=self.TEST_STAGE_NAME, path=path)

        data = {'return_status_code': 203, 'return_headers': {'foo': 'bar123'}}
        result = requests.post(
            url,
            data=json.dumps(data),
            headers={'User-Agent': 'python-requests/testing'})

        self.assertEqual(result.status_code, 203)
        self.assertEqual(result.headers.get('foo'), 'bar123')

        parsed_body = json.loads(to_str(result.content))
        self.assertEqual(parsed_body.get('return_status_code'), 203)
        self.assertDictEqual(parsed_body.get('return_headers'),
                             {'foo': 'bar123'})
        self.assertDictEqual(parsed_body.get('queryStringParameters'), {
            'foo': 'foo',
            'bar': ['bar', 'baz']
        })

        request_context = parsed_body.get('requestContext')
        source_ip = request_context['identity'].pop('sourceIp')

        self.assertTrue(
            re.match(r'^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$', source_ip))

        self.assertEqual(request_context['path'], '/lambda/foo1')
        self.assertEqual(request_context['accountId'], TEST_AWS_ACCOUNT_ID)
        self.assertEqual(request_context['resourceId'], resource.get('id'))
        self.assertEqual(request_context['stage'], self.TEST_STAGE_NAME)
        self.assertEqual(request_context['identity']['userAgent'],
                         'python-requests/testing')

        result = requests.delete(url, data=json.dumps(data))
        self.assertEqual(result.status_code, 404)

        # send message with non-ASCII chars
        body_msg = '🙀 - 参よ'
        result = requests.post(url,
                               data=json.dumps({'return_raw_body': body_msg}))
        self.assertEqual(to_str(result.content), body_msg)
Example #44
    def forward_request(self, method, path, data, headers):

        if method == 'POST' and path == '/':
            req_data = urlparse.parse_qs(to_str(data))
            req_action = req_data['Action'][0]
            topic_arn = req_data.get('TargetArn') or req_data.get('TopicArn')

            if topic_arn:
                topic_arn = topic_arn[0]
                do_create_topic(topic_arn)

            if req_action == 'SetSubscriptionAttributes':
                sub = get_subscription_by_arn(req_data['SubscriptionArn'][0])
                if not sub:
                    return make_error(message='Unable to find subscription for given ARN', code=400)
                attr_name = req_data['AttributeName'][0]
                attr_value = req_data['AttributeValue'][0]
                sub[attr_name] = attr_value
                return make_response(req_action)
            elif req_action == 'GetSubscriptionAttributes':
                sub = get_subscription_by_arn(req_data['SubscriptionArn'][0])
                if not sub:
                    return make_error(message='Unable to find subscription for given ARN', code=400)
                content = '<Attributes>'
                for key, value in sub.items():
                    content += '<entry><key>%s</key><value>%s</value></entry>\n' % (key, value)
                content += '</Attributes>'
                return make_response(req_action, content=content)
            elif req_action == 'Subscribe':
                if 'Endpoint' not in req_data:
                    return make_error(message='Endpoint not specified in subscription', code=400)
            elif req_action == 'Unsubscribe':
                if 'SubscriptionArn' not in req_data:
                    return make_error(message='SubscriptionArn not specified in unsubscribe request', code=400)
                do_unsubscribe(req_data.get('SubscriptionArn')[0])

            elif req_action == 'Publish':
                message = req_data['Message'][0]
                sqs_client = aws_stack.connect_to_service('sqs')
                for subscriber in SNS_SUBSCRIPTIONS[topic_arn]:
                    filter_policy = json.loads(subscriber.get('FilterPolicy', '{}'))
                    message_attributes = get_message_attributes(req_data)
                    if check_filter_policy(filter_policy, message_attributes):
                        if subscriber['Protocol'] == 'sqs':
                            endpoint = subscriber['Endpoint']
                            if 'sqs_queue_url' in subscriber:
                                queue_url = subscriber.get('sqs_queue_url')
                            elif '://' in endpoint:
                                queue_url = endpoint
                            else:
                                queue_name = endpoint.split(':')[5]
                                queue_url = aws_stack.get_sqs_queue_url(queue_name)
                                subscriber['sqs_queue_url'] = queue_url
                            try:
                                sqs_client.send_message(
                                    QueueUrl=queue_url,
                                    MessageBody=create_sns_message_body(subscriber, req_data)
                                )
                            except Exception as exc:
                                return make_error(message=str(exc), code=400)
                        elif subscriber['Protocol'] == 'lambda':
                            lambda_api.process_sns_notification(
                                subscriber['Endpoint'],
                                topic_arn, message, subject=req_data.get('Subject', [None])[0]
                            )
                        elif subscriber['Protocol'] in ['http', 'https']:
                            try:
                                message_body = create_sns_message_body(subscriber, req_data)
                            except Exception as exc:
                                return make_error(message=str(exc), code=400)
                            requests.post(
                                subscriber['Endpoint'],
                                headers={
                                    'Content-Type': 'text/plain',
                                    'x-amz-sns-message-type': 'Notification'
                                },
                                data=message_body
                            )
                        else:
                            LOGGER.warning('Unexpected protocol "%s" for SNS subscription' % subscriber['Protocol'])
                # return response here because we do not want the request to be forwarded to SNS
                return make_response(req_action)

        return True
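
A hedged end-to-end sketch of the Publish branch above: an SQS queue subscribed to a topic receives the published message. The endpoint URLs and resource names are illustrative LocalStack defaults and may differ per setup:

import boto3

sns = boto3.client('sns', endpoint_url='http://localhost:4575', region_name='us-east-1')
sqs = boto3.client('sqs', endpoint_url='http://localhost:4576', region_name='us-east-1')

topic_arn = sns.create_topic(Name='orders')['TopicArn']
queue_url = sqs.create_queue(QueueName='orders-queue')['QueueUrl']
queue_arn = sqs.get_queue_attributes(
    QueueUrl=queue_url, AttributeNames=['QueueArn'])['Attributes']['QueueArn']

sns.subscribe(TopicArn=topic_arn, Protocol='sqs', Endpoint=queue_arn)
sns.publish(TopicArn=topic_arn, Message='{"hello": "world"}')
messages = sqs.receive_message(QueueUrl=queue_url).get('Messages', [])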
Example #45
def download_s3_object(s3, bucket, path):
    with tempfile.SpooledTemporaryFile() as tmpfile:
        s3.Bucket(bucket).download_fileobj(path, tmpfile)
        tmpfile.seek(0)
        return to_str(tmpfile.read())
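
Note that the first argument must be a boto3 resource (the function calls s3.Bucket(...)), not a low-level client. A possible call site, with a placeholder endpoint, bucket, and object path that are assumed to exist:

import boto3

s3 = boto3.resource('s3', endpoint_url='http://localhost:4572')
body = download_s3_object(s3, 'my-bucket', 'path/to/object.txt')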
Example #46
    def forward_request(self, method, path, data, headers):
        records.append((json.loads(to_str(data)), headers))
        return 429
Example #47
def post_request():
    action = request.headers.get('x-amz-target')
    data = json.loads(to_str(request.data))
    response = None
    if action == '%s.ListDeliveryStreams' % ACTION_HEADER_PREFIX:
        response = {
            'DeliveryStreamNames': get_delivery_stream_names(),
            'HasMoreDeliveryStreams': False
        }
    elif action == '%s.CreateDeliveryStream' % ACTION_HEADER_PREFIX:
        stream_name = data['DeliveryStreamName']
        region_name = extract_region_from_auth_header(request.headers)
        response = create_stream(
            stream_name,
            delivery_stream_type=data.get('DeliveryStreamType'),
            delivery_stream_type_configuration=data.get(
                'KinesisStreamSourceConfiguration'),
            s3_destination=data.get('S3DestinationConfiguration'),
            elasticsearch_destination=data.get(
                'ElasticsearchDestinationConfiguration'),
            tags=data.get('Tags'),
            region_name=region_name)
    elif action == '%s.DeleteDeliveryStream' % ACTION_HEADER_PREFIX:
        stream_name = data['DeliveryStreamName']
        response = delete_stream(stream_name)
    elif action == '%s.DescribeDeliveryStream' % ACTION_HEADER_PREFIX:
        stream_name = data['DeliveryStreamName']
        response = get_stream(stream_name)
        if not response:
            return error_not_found(stream_name)
        response = {'DeliveryStreamDescription': response}
    elif action == '%s.PutRecord' % ACTION_HEADER_PREFIX:
        stream_name = data['DeliveryStreamName']
        record = data['Record']
        put_record(stream_name, record)
        response = {'RecordId': str(uuid.uuid4())}
    elif action == '%s.PutRecordBatch' % ACTION_HEADER_PREFIX:
        stream_name = data['DeliveryStreamName']
        records = data['Records']
        put_records(stream_name, records)
        request_responses = []
        for i in records:
            request_responses.append({'RecordId': str(uuid.uuid4())})
        response = {'FailedPutCount': 0, 'RequestResponses': request_responses}
    elif action == '%s.UpdateDestination' % ACTION_HEADER_PREFIX:
        stream_name = data['DeliveryStreamName']
        version_id = data['CurrentDeliveryStreamVersionId']
        destination_id = data['DestinationId']
        s3_update = data.get('S3DestinationUpdate')
        update_destination(stream_name=stream_name,
                           destination_id=destination_id,
                           s3_update=s3_update,
                           version_id=version_id)
        es_update = data.get('ESDestinationUpdate')
        update_destination(stream_name=stream_name,
                           destination_id=destination_id,
                           elasticsearch_update=es_update,
                           version_id=version_id)
        response = {}
    elif action == '%s.ListTagsForDeliveryStream' % ACTION_HEADER_PREFIX:
        response = get_delivery_stream_tags(data['DeliveryStreamName'],
                                            data.get('ExclusiveStartTagKey'),
                                            data.get('Limit', 50))
    else:
        response = error_response('Unknown action "%s"' % action,
                                  code=400,
                                  error_type='InvalidAction')

    if isinstance(response, dict):
        response = jsonify(response)
    return response
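
A rough sketch of driving the handler above through boto3; the endpoint URL, role/bucket ARNs, and stream name are assumptions for illustration only:

import boto3

firehose = boto3.client('firehose', endpoint_url='http://localhost:4573',
                        region_name='us-east-1')
firehose.create_delivery_stream(
    DeliveryStreamName='my-stream',
    S3DestinationConfiguration={
        'RoleARN': 'arn:aws:iam::000000000000:role/firehose-role',
        'BucketARN': 'arn:aws:s3:::my-bucket',
    })
firehose.put_record(DeliveryStreamName='my-stream',
                    Record={'Data': b'{"event": 1}\n'})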
Example #48
def _replace(response, pattern, replacement):
    content = to_str(response.content)
    response._content = re.sub(pattern, replacement, content)
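
A minimal illustration, using types.SimpleNamespace as a stand-in for the requests Response object (the helper reads .content and writes ._content); it assumes _replace and LocalStack's to_str are in scope:

from types import SimpleNamespace

response = SimpleNamespace(content='arn:aws:dynamodb:ddblocal:000:table/t')
_replace(response, r'ddblocal', 'us-east-1')
print(response._content)  # arn:aws:dynamodb:us-east-1:000:table/t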
Example #49
    def forward_request(self, method, path, data, headers):
        if method == "OPTIONS":
            return 200

        # check region
        try:
            aws_stack.check_valid_region(headers)
            aws_stack.set_default_region_in_headers(headers)
        except Exception as e:
            return make_error(message=str(e), code=400)

        if method == "POST":
            # parse payload and extract fields
            req_data = parse_qs(to_str(data), keep_blank_values=True)

            # parse data from query path
            if not req_data:
                parsed_path = urlparse(path)
                req_data = parse_qs(parsed_path.query, keep_blank_values=True)

            req_action = req_data["Action"][0]
            topic_arn = (
                req_data.get("TargetArn") or req_data.get("TopicArn") or req_data.get("ResourceArn")
            )
            if topic_arn:
                topic_arn = topic_arn[0]
                topic_arn = aws_stack.fix_account_id_in_arns(topic_arn)
            if req_action == "SetSubscriptionAttributes":
                sub = get_subscription_by_arn(req_data["SubscriptionArn"][0])
                if not sub:
                    return make_error(message="Unable to find subscription for given ARN", code=400)

                attr_name = req_data["AttributeName"][0]
                attr_value = req_data["AttributeValue"][0]
                sub[attr_name] = attr_value
                return make_response(req_action)

            elif req_action == "GetSubscriptionAttributes":
                sub = get_subscription_by_arn(req_data["SubscriptionArn"][0])
                if not sub:
                    return make_error(
                        message="Subscription with arn {0} not found".format(
                            req_data["SubscriptionArn"][0]
                        ),
                        code=404,
                        code_string="NotFound",
                    )

                content = "<Attributes>"
                for key, value in sub.items():
                    if key in HTTP_SUBSCRIPTION_ATTRIBUTES:
                        continue
                    content += "<entry><key>%s</key><value>%s</value></entry>\n" % (
                        key,
                        value,
                    )
                content += "</Attributes>"
                return make_response(req_action, content=content)

            elif req_action == "Subscribe":
                if "Endpoint" not in req_data:
                    return make_error(message="Endpoint not specified in subscription", code=400)

                if req_data["Protocol"][0] not in SNS_PROTOCOLS:
                    return make_error(
                        message=f"Invalid parameter: Amazon SNS does not support this protocol string: "
                        f"{req_data['Protocol'][0]}",
                        code=400,
                    )

                if ".fifo" in req_data["Endpoint"][0] and ".fifo" not in topic_arn:
                    return make_error(
                        message="FIFO SQS Queues can not be subscribed to standard SNS topics",
                        code=400,
                        code_string="InvalidParameter",
                    )

            elif req_action == "ConfirmSubscription":
                if "TopicArn" not in req_data:
                    return make_error(
                        message="TopicArn not specified in confirm subscription request",
                        code=400,
                    )

                if "Token" not in req_data:
                    return make_error(
                        message="Token not specified in confirm subscription request",
                        code=400,
                    )

                do_confirm_subscription(req_data.get("TopicArn")[0], req_data.get("Token")[0])

            elif req_action == "Unsubscribe":
                if "SubscriptionArn" not in req_data:
                    return make_error(
                        message="SubscriptionArn not specified in unsubscribe request",
                        code=400,
                    )

                do_unsubscribe(req_data.get("SubscriptionArn")[0])

            elif req_action == "DeleteTopic":
                do_delete_topic(topic_arn)

            elif req_action == "Publish":
                if req_data.get("Subject") == [""]:
                    return make_error(code=400, code_string="InvalidParameter", message="Subject")
                if not req_data.get("Message") or all(
                    not message for message in req_data.get("Message")
                ):
                    return make_error(
                        code=400, code_string="InvalidParameter", message="Empty message"
                    )

                if topic_arn and ".fifo" in topic_arn and not req_data.get("MessageGroupId"):
                    return make_error(
                        code=400,
                        code_string="InvalidParameter",
                        message="The MessageGroupId parameter is required for FIFO topics",
                    )

                sns_backend = SNSBackend.get()
                # No need to create a topic to send SMS or single push notifications with SNS,
                # but we can't mock the actual sending, so we simply report success
                if "PhoneNumber" not in req_data and "TargetArn" not in req_data:
                    if topic_arn not in sns_backend.sns_subscriptions:
                        return make_error(
                            code=404,
                            code_string="NotFound",
                            message="Topic does not exist",
                        )

                message_id = publish_message(topic_arn, req_data, headers)

                # return response here because we do not want the request to be forwarded to SNS backend
                return make_response(req_action, message_id=message_id)

            elif req_action == "PublishBatch":
                entries = parse_urlencoded_data(
                    req_data, "PublishBatchRequestEntries.member", "MessageAttributes.entry"
                )

                if len(entries) > 10:
                    return make_error(
                        message="The batch request contains more entries than permissible",
                        code=400,
                        code_string="TooManyEntriesInBatchRequest",
                    )
                ids = [entry["Id"] for entry in entries]

                if len(set(ids)) != len(entries):
                    return make_error(
                        message="Two or more batch entries in the request have the same Id",
                        code=400,
                        code_string="BatchEntryIdsNotDistinct",
                    )

                if topic_arn and ".fifo" in topic_arn:
                    if not all(["MessageGroupId" in entry for entry in entries]):
                        return make_error(
                            message="The MessageGroupId parameter is required for FIFO topics",
                            code=400,
                            code_string="InvalidParameter",
                        )

                response = publish_batch(topic_arn, entries, headers)
                return requests_response_xml(
                    req_action, response, xmlns="http://sns.amazonaws.com/doc/2010-03-31/"
                )

            elif req_action == "ListTagsForResource":
                tags = do_list_tags_for_resource(topic_arn)
                content = "<Tags/>"
                if len(tags) > 0:
                    content = "<Tags>"
                    for tag in tags:
                        content += "<member>"
                        content += "<Key>%s</Key>" % tag["Key"]
                        content += "<Value>%s</Value>" % tag["Value"]
                        content += "</member>"
                    content += "</Tags>"
                return make_response(req_action, content=content)

            elif req_action == "CreateTopic":
                sns_backend = SNSBackend.get()
                topic_arn = aws_stack.sns_topic_arn(req_data["Name"][0])
                tag_resource_success = self._extract_tags(topic_arn, req_data, True, sns_backend)
                sns_backend.sns_subscriptions[topic_arn] = (
                    sns_backend.sns_subscriptions.get(topic_arn) or []
                )
                # if tagging failed, return an error; otherwise continue as expected
                if not tag_resource_success:
                    return make_error(
                        code=400,
                        code_string="InvalidParameter",
                        message="Topic already exists with different tags",
                    )

            elif req_action == "TagResource":
                sns_backend = SNSBackend.get()
                self._extract_tags(topic_arn, req_data, False, sns_backend)
                return make_response(req_action)

            elif req_action == "UntagResource":
                tags_to_remove = []
                req_tags = {k: v for k, v in req_data.items() if k.startswith("TagKeys.member.")}
                req_tags = req_tags.values()
                for tag in req_tags:
                    tags_to_remove.append(tag[0])
                do_untag_resource(topic_arn, tags_to_remove)
                return make_response(req_action)

            data = self._reset_account_id(data)
            return Request(data=data, headers=headers, method=method)

        return True
Example #50
def _fix_next_token_response(response):
    # Fix for https://github.com/localstack/localstack/issues/1527
    pattern = r'"nextToken":\s*([0-9]+)'
    replacement = r'"nextToken": "\1"'
    response._content = re.sub(pattern, replacement, to_str(response.content))
Example #51
        raise Exception('Not implemented.')

    def startup(self):
        pass

    def cleanup(self, arn=None):
        pass

    def run_lambda_executor(self, cmd, env_vars={}, asynchronous=False):
        # 'async' became a reserved keyword in Python 3.7, so the flag is
        # named 'asynchronous' here (and in the run() helper call below)
        process = run(cmd, asynchronous=True, stderr=subprocess.PIPE, outfile=subprocess.PIPE, env_vars=env_vars)
        if asynchronous:
            result = '{"async": "%s"}' % asynchronous
            log_output = 'Lambda executed asynchronously'
        else:
            return_code = process.wait()
            result = to_str(process.stdout.read())
            log_output = to_str(process.stderr.read())

            if return_code != 0:
                raise Exception('Lambda process returned error status code: %s. Output:\n%s' %
                    (return_code, log_output))
        return result, log_output


# holds information about an existing container.
class ContainerInfo:
    """
    Contains basic information about a docker container.
    """
    def __init__(self, name, entry_point):
        self.name = name
        self.entry_point = entry_point
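
A hedged usage sketch, assuming `executor` is an instance of a concrete subclass of the executor class above (the class name is not visible in this excerpt); the command line is an arbitrary example, not the invocation LocalStack actually builds:

# synchronous call: blocks until the process exits, returns stdout and stderr
result, log_output = executor.run_lambda_executor(
    cmd='python /tmp/handler.py', env_vars={'AWS_REGION': 'us-east-1'})

# asynchronous call: returns immediately with a small marker payload
marker, _ = executor.run_lambda_executor(cmd='python /tmp/handler.py',
                                         asynchronous=True)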
Example #52
def invoke_function(function):
    """ Invoke an existing function
        ---
        operationId: 'invokeFunction'
        parameters:
            - name: 'request'
              in: body
    """
    # function here can either be an arn or a function name
    arn = func_arn(function)

    # arn can also contain a qualifier, extract it from there if so
    m = re.match('(arn:aws:lambda:.*:.*:function:[a-zA-Z0-9-_]+)(:.*)?', arn)
    if m and m.group(2):
        qualifier = m.group(2)[1:]
        arn = m.group(1)
    else:
        qualifier = request.args.get('Qualifier')

    data = request.get_data()
    if data:
        data = to_str(data)
        try:
            data = json.loads(data)
        except Exception:
            try:
                # try to read chunked content
                data = json.loads(parse_chunked_data(data))
            except Exception:
                return error_response(
                    'The payload is not JSON: %s' % data,
                    415,
                    error_type='UnsupportedMediaTypeException')

    # Default invocation type is RequestResponse
    invocation_type = request.environ.get('HTTP_X_AMZ_INVOCATION_TYPE',
                                          'RequestResponse')

    def _create_response(result, status_code=200):
        """ Create the final response for the given invocation result """
        if isinstance(result, Response):
            return result
        details = {'StatusCode': status_code, 'Payload': result, 'Headers': {}}
        if isinstance(result, dict):
            for key in ('StatusCode', 'Payload', 'FunctionError'):
                if result.get(key):
                    details[key] = result[key]
        # Try to parse the payload as JSON
        payload = details['Payload']
        if payload and isinstance(
                payload, (str, bytes)) and payload[0] in ('[', '{', '"'):
            try:
                details['Payload'] = json.loads(details['Payload'])
            except Exception:
                pass
        # Set error headers
        if details.get('FunctionError'):
            details['Headers']['X-Amz-Function-Error'] = str(
                details['FunctionError'])
        # Construct response object
        response_obj = details['Payload']
        if isinstance(response_obj,
                      (dict, list, bool)) or is_number(response_obj):
            # Assume this is a JSON response
            response_obj = jsonify(response_obj)
        else:
            response_obj = str(response_obj)
            details['Headers']['Content-Type'] = 'text/plain'
        return response_obj, details['StatusCode'], details['Headers']

    # check if this lambda function exists
    not_found = None
    if arn not in arn_to_lambda:
        not_found = not_found_error(arn)
    elif qualifier and not arn_to_lambda.get(arn).qualifier_exists(qualifier):
        not_found = not_found_error('{0}:{1}'.format(arn, qualifier))

    if not_found:
        forward_result = forward_to_fallback_url(arn, data)
        if forward_result is not None:
            return _create_response(forward_result)
        return not_found

    if invocation_type == 'RequestResponse':
        result = run_lambda(asynchronous=False,
                            func_arn=arn,
                            event=data,
                            context={},
                            version=qualifier)
        return _create_response(result)
    elif invocation_type == 'Event':
        run_lambda(asynchronous=True,
                   func_arn=arn,
                   event=data,
                   context={},
                   version=qualifier)
        return _create_response('', status_code=202)
    elif invocation_type == 'DryRun':
        # Assume the dry run always passes.
        return _create_response('', status_code=204)
    return error_response(
        'Invocation type not one of: RequestResponse, Event or DryRun',
        code=400,
        error_type='InvalidParameterValueException')
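
The qualifier handling above can be illustrated in isolation; the ARN is a made-up example, and a version or alias suffix is split off while the base ARN is kept:

import re

arn = 'arn:aws:lambda:us-east-1:000000000000:function:my-func:prod'
m = re.match('(arn:aws:lambda:.*:.*:function:[a-zA-Z0-9-_]+)(:.*)?', arn)
print(m.group(1))      # arn:aws:lambda:us-east-1:000000000000:function:my-func
print(m.group(2)[1:])  # prod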
Example #53
def test_lambda_runtimes():

    lambda_client = aws_stack.connect_to_service('lambda')

    # deploy and invoke lambda - Python 2.7
    zip_file = testutil.create_lambda_archive(load_file(TEST_LAMBDA_PYTHON), get_content=True,
        libs=TEST_LAMBDA_LIBS, runtime=LAMBDA_RUNTIME_PYTHON27)
    testutil.create_lambda_function(func_name=TEST_LAMBDA_NAME_PY,
        zip_file=zip_file, runtime=LAMBDA_RUNTIME_PYTHON27)
    result = lambda_client.invoke(FunctionName=TEST_LAMBDA_NAME_PY, Payload=b'{}')
    assert result['StatusCode'] == 200
    result_data = result['Payload'].read()
    assert to_str(result_data).strip() == '{}'

    if use_docker():
        # deploy and invoke lambda - Python 3.6
        zip_file = testutil.create_lambda_archive(load_file(TEST_LAMBDA_PYTHON3), get_content=True,
            libs=TEST_LAMBDA_LIBS, runtime=LAMBDA_RUNTIME_PYTHON36)
        testutil.create_lambda_function(func_name=TEST_LAMBDA_NAME_PY3,
            zip_file=zip_file, runtime=LAMBDA_RUNTIME_PYTHON36)
        result = lambda_client.invoke(FunctionName=TEST_LAMBDA_NAME_PY3, Payload=b'{}')
        assert result['StatusCode'] == 200
        result_data = result['Payload'].read()
        assert to_str(result_data).strip() == '{}'

    # deploy and invoke lambda - Java
    if not os.path.exists(TEST_LAMBDA_JAVA):
        mkdir(os.path.dirname(TEST_LAMBDA_JAVA))
        download(TEST_LAMBDA_JAR_URL, TEST_LAMBDA_JAVA)
    zip_file = testutil.create_zip_file(TEST_LAMBDA_JAVA, get_content=True)
    testutil.create_lambda_function(func_name=TEST_LAMBDA_NAME_JAVA, zip_file=zip_file,
        runtime=LAMBDA_RUNTIME_JAVA8, handler='cloud.localstack.sample.LambdaHandler')
    result = lambda_client.invoke(FunctionName=TEST_LAMBDA_NAME_JAVA, Payload=b'{}')
    assert result['StatusCode'] == 200
    result_data = result['Payload'].read()
    assert 'LinkedHashMap' in to_str(result_data)

    # test SNSEvent
    result = lambda_client.invoke(FunctionName=TEST_LAMBDA_NAME_JAVA, InvocationType='Event',
                                  Payload=b'{"Records": [{"Sns": {"Message": "{}"}}]}')
    assert result['StatusCode'] == 200
    result_data = result['Payload'].read()
    assert json.loads(to_str(result_data)) == {'async': 'True'}

    # test DDBEvent
    result = lambda_client.invoke(FunctionName=TEST_LAMBDA_NAME_JAVA, InvocationType='Event',
                                  Payload=b'{"Records": [{"dynamodb": {"Message": "{}"}}]}')
    assert result['StatusCode'] == 200
    result_data = result['Payload'].read()
    assert json.loads(to_str(result_data)) == {'async': 'True'}

    # test KinesisEvent
    result = lambda_client.invoke(FunctionName=TEST_LAMBDA_NAME_JAVA,
                                  Payload=b'{"Records": [{"Kinesis": {"Data": "data", "PartitionKey": "partition"}}]}')
    assert result['StatusCode'] == 200
    result_data = result['Payload'].read()
    assert 'KinesisEvent' in to_str(result_data)

    # deploy and invoke lambda - Java with stream handler
    testutil.create_lambda_function(func_name=TEST_LAMBDA_NAME_JAVA_STREAM, zip_file=zip_file,
        runtime=LAMBDA_RUNTIME_JAVA8, handler='cloud.localstack.sample.LambdaStreamHandler')
    result = lambda_client.invoke(FunctionName=TEST_LAMBDA_NAME_JAVA_STREAM, Payload=b'{}')
    assert result['StatusCode'] == 200
    result_data = result['Payload'].read()
    assert to_str(result_data).strip() == '{}'

    # deploy and invoke lambda - Java with serializable input object
    testutil.create_lambda_function(func_name=TEST_LAMBDA_NAME_JAVA_SERIALIZABLE, zip_file=zip_file,
        runtime=LAMBDA_RUNTIME_JAVA8, handler='cloud.localstack.sample.SerializedInputLambdaHandler')
    result = lambda_client.invoke(FunctionName=TEST_LAMBDA_NAME_JAVA_SERIALIZABLE,
                                  Payload=b'{"bucket": "test_bucket", "key": "test_key"}')
    assert result['StatusCode'] == 200
    result_data = result['Payload'].read()
    assert json.loads(to_str(result_data)) == {'validated': True, 'bucket': 'test_bucket', 'key': 'test_key'}

    if use_docker():
        # deploy and invoke lambda - Node.js
        zip_file = testutil.create_zip_file(TEST_LAMBDA_NODEJS, get_content=True)
        testutil.create_lambda_function(func_name=TEST_LAMBDA_NAME_JS,
            zip_file=zip_file, handler='lambda_integration.handler', runtime=LAMBDA_RUNTIME_NODEJS)
        result = lambda_client.invoke(FunctionName=TEST_LAMBDA_NAME_JS, Payload=b'{}')
        assert result['StatusCode'] == 200
        result_data = result['Payload'].read()
        assert to_str(result_data).strip() == '{}'

        # deploy and invoke - .NET Core 2.0. It's already a zip
        zip_file = TEST_LAMBDA_DOTNETCORE2
        zip_file_content = None
        with open(zip_file, 'rb') as file_obj:
            zip_file_content = file_obj.read()
        testutil.create_lambda_function(func_name=TEST_LAMBDA_NAME_DOTNETCORE2, zip_file=zip_file_content,
            handler='DotNetCore2::DotNetCore2.Lambda.Function::SimpleFunctionHandler',
            runtime=LAMBDA_RUNTIME_DOTNETCORE2)
        result = lambda_client.invoke(FunctionName=TEST_LAMBDA_NAME_DOTNETCORE2, Payload=b'{}')
        assert result['StatusCode'] == 200
        result_data = result['Payload'].read()
        assert to_str(result_data).strip() == '{}'
Example #54
    def return_response(self, method, path, data, headers, response):

        parsed = urlparse.urlparse(path)
        # TODO: consider the case of hostname-based (as opposed to path-based) bucket addressing
        bucket_name = parsed.path.split('/')[1]

        # POST requests to S3 may include a success_action_redirect field,
        # which should be used to redirect a client to a new location.
        if method == 'POST':
            key, redirect_url = multipart_content.find_multipart_redirect_url(
                data, headers)
            if key and redirect_url:
                response.status_code = 303
                response.headers['Location'] = expand_redirect_url(
                    redirect_url, key, bucket_name)
                LOGGER.debug('S3 POST {} to {}'.format(
                    response.status_code, response.headers['Location']))

        # get subscribers and send bucket notifications
        if method in ('PUT', 'DELETE') and '/' in path[1:]:
            # check if this is an actual put object request, because it could also be
            # a put bucket request with a path like this: /bucket_name/
            if len(path[1:].split('/')[1]) > 0:
                parts = parsed.path[1:].split('/', 1)
                # ignore bucket notification configuration requests
                if parsed.query != 'notification' and parsed.query != 'lifecycle':
                    object_path = parts[1] if parts[1][0] == '/' else '/%s' % parts[1]
                    send_notifications(method, bucket_name, object_path)

        # publish event for creation/deletion of buckets:
        if method in ('PUT', 'DELETE') and ('/' not in path[1:] or
                                            len(path[1:].split('/')[1]) <= 0):
            event_type = (event_publisher.EVENT_S3_CREATE_BUCKET if method
                          == 'PUT' else event_publisher.EVENT_S3_DELETE_BUCKET)
            event_publisher.fire_event(
                event_type,
                payload={'n': event_publisher.get_hash(bucket_name)})

        # fix an upstream issue in moto S3 (see https://github.com/localstack/localstack/issues/382)
        if method == 'PUT' and parsed.query == 'policy':
            response._content = ''
            response.status_code = 204
            return response

        # append CORS headers to response
        if response:
            append_cors_headers(bucket_name,
                                request_method=method,
                                request_headers=headers,
                                response=response)

            response_content_str = None
            try:
                response_content_str = to_str(response._content)
            except Exception:
                pass

            # we need to un-pretty-print the XML, otherwise we run into this issue with Spark:
            # https://github.com/jserver/mock-s3/pull/9/files
            # https://github.com/localstack/localstack/issues/183
            # Note: we still need to keep the newline after the first line: <?xml ...>\n
            if response_content_str and response_content_str.startswith('<'):
                is_bytes = isinstance(response._content, six.binary_type)
                response._content = re.sub(r'([^\?])>\n\s*<', r'\1><',
                    response_content_str, flags=re.MULTILINE)
                if is_bytes:
                    response._content = to_bytes(response._content)
            # update content-length headers (fix https://github.com/localstack/localstack/issues/541)
            if isinstance(response._content, (six.string_types, six.binary_type)):
                response.headers['content-length'] = len(response._content)
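To see what the un-pretty-printing above does, here is a quick stdlib-only illustration (the sample XML is an assumption): inner whitespace between tags is collapsed, while the newline after the <?xml ...?> declaration survives because the pattern requires the character before > not to be ?.

import re

pretty = '<?xml version="1.0"?>\n<Result>\n  <Name>demo</Name>\n</Result>'
compact = re.sub(r'([^\?])>\n\s*<', r'\1><', pretty, flags=re.MULTILINE)
print(compact)  # '<?xml version="1.0"?>\n<Result><Name>demo</Name></Result>'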
Example #55
0
    def forward_request(self, method, path, data, headers):
        result = handle_special_request(method, path, data, headers)
        if result is not None:
            return result

        if not data:
            data = '{}'
        data = json.loads(to_str(data))
        ddb_client = aws_stack.connect_to_service('dynamodb')
        action = headers.get('X-Amz-Target')

        if self.should_throttle(action):
            return error_response_throughput()

        ProxyListenerDynamoDB.thread_local.existing_item = None

        if action == '%s.CreateTable' % ACTION_PREFIX:
            # Check if table exists, to avoid error log output from DynamoDBLocal
            if self.table_exists(ddb_client, data['TableName']):
                return error_response(message='Table already created',
                                      error_type='ResourceInUseException',
                                      code=400)

        if action == '%s.CreateGlobalTable' % ACTION_PREFIX:
            return create_global_table(data)

        elif action == '%s.DescribeGlobalTable' % ACTION_PREFIX:
            return describe_global_table(data)

        elif action == '%s.ListGlobalTables' % ACTION_PREFIX:
            return list_global_tables(data)

        elif action == '%s.UpdateGlobalTable' % ACTION_PREFIX:
            return update_global_table(data)

        elif action in ('%s.PutItem' % ACTION_PREFIX,
                        '%s.UpdateItem' % ACTION_PREFIX,
                        '%s.DeleteItem' % ACTION_PREFIX):
            # find an existing item and store it in a thread-local, so we can access it in return_response,
            # in order to determine whether an item already existed (MODIFY) or not (INSERT)
            try:
                if has_event_sources_or_streams_enabled(data['TableName']):
                    ProxyListenerDynamoDB.thread_local.existing_item = find_existing_item(data)
            except Exception as e:
                if 'ResourceNotFoundException' in str(e):
                    return get_table_not_found_error()
                raise

            # Fix incorrect values if ReturnValues==ALL_OLD and ReturnConsumedCapacity is
            # empty, see https://github.com/localstack/localstack/issues/2049
            if ((data.get('ReturnValues') == 'ALL_OLD') or (not data.get('ReturnValues'))) \
                    and not data.get('ReturnConsumedCapacity'):
                data['ReturnConsumedCapacity'] = 'TOTAL'
                return Request(data=json.dumps(data),
                               method=method,
                               headers=headers)

        elif action == '%s.DescribeTable' % ACTION_PREFIX:
            # Check if table exists, to avoid error log output from DynamoDBLocal
            if not self.table_exists(ddb_client, data['TableName']):
                return get_table_not_found_error()

        elif action == '%s.DeleteTable' % ACTION_PREFIX:
            # Check if table exists, to avoid error log output from DynamoDBLocal
            if not self.table_exists(ddb_client, data['TableName']):
                return get_table_not_found_error()

        elif action == '%s.BatchWriteItem' % ACTION_PREFIX:
            existing_items = []
            for table_name in sorted(data['RequestItems'].keys()):
                for request in data['RequestItems'][table_name]:
                    for key in ['PutRequest', 'DeleteRequest']:
                        inner_request = request.get(key)
                        if inner_request:
                            existing_items.append(
                                find_existing_item(inner_request, table_name))
            ProxyListenerDynamoDB.thread_local.existing_items = existing_items

        elif action == '%s.Query' % ACTION_PREFIX:
            if data.get('IndexName'):
                if not is_index_query_valid(to_str(data['TableName']), data.get('Select')):
                    return error_response(
                        message='One or more parameter values were invalid: Select type '
                                'ALL_ATTRIBUTES is not supported for global secondary index id-index '
                                'because its projection type is not ALL',
                        error_type='ValidationException', code=400)

        elif action == '%s.TransactWriteItems' % ACTION_PREFIX:
            existing_items = []
            for item in data['TransactItems']:
                for key in ['Put', 'Update', 'Delete']:
                    inner_item = item.get(key)
                    if inner_item:
                        existing_items.append(find_existing_item(inner_item))
            ProxyListenerDynamoDB.thread_local.existing_items = existing_items

        elif action == '%s.UpdateTimeToLive' % ACTION_PREFIX:
            # TODO: TTL status is maintained/mocked but no real expiry is happening for items
            response = Response()
            response.status_code = 200
            self._table_ttl_map[data['TableName']] = {
                'AttributeName': data['TimeToLiveSpecification']['AttributeName'],
                'Status': data['TimeToLiveSpecification']['Enabled']
            }
            response._content = json.dumps(
                {'TimeToLiveSpecification': data['TimeToLiveSpecification']})
            fix_headers_for_updated_response(response)
            return response

        elif action == '%s.DescribeTimeToLive' % ACTION_PREFIX:
            response = Response()
            response.status_code = 200
            if data['TableName'] in self._table_ttl_map:
                if self._table_ttl_map[data['TableName']]['Status']:
                    ttl_status = 'ENABLED'
                else:
                    ttl_status = 'DISABLED'
                response._content = json.dumps({
                    'TimeToLiveDescription': {
                        'AttributeName': self._table_ttl_map[data['TableName']]['AttributeName'],
                        'TimeToLiveStatus': ttl_status
                    }
                })
            else:  # TTL for dynamodb table not set
                response._content = json.dumps({
                    'TimeToLiveDescription': {
                        'TimeToLiveStatus': 'DISABLED'
                    }
                })

            fix_headers_for_updated_response(response)
            return response

        elif action in ('%s.TagResource' % ACTION_PREFIX, '%s.UntagResource' % ACTION_PREFIX):
            response = Response()
            response.status_code = 200
            response._content = ''  # returns an empty body on success.
            fix_headers_for_updated_response(response)
            return response

        elif action == '%s.ListTagsOfResource' % ACTION_PREFIX:
            response = Response()
            response.status_code = 200
            response._content = json.dumps({
                'Tags': [{
                    'Key': k,
                    'Value': v
                } for k, v in TABLE_TAGS.get(data['ResourceArn'], {}).items()]
            })
            fix_headers_for_updated_response(response)
            return response

        return True
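A minimal sketch of the thread-local pattern used above to hand the pre-write item from forward_request to return_response (class and field names here are illustrative): each proxy thread sees its own copy, so concurrent requests do not clobber each other's state.

import threading

class ProxyListenerSketch(object):
    thread_local = threading.local()

    def forward_request(self, data):
        # remember what the item looked like before the write
        self.thread_local.existing_item = data.get('Item')

    def return_response(self, data):
        # INSERT if no item existed before the write, MODIFY otherwise
        existed = getattr(self.thread_local, 'existing_item', None) is not None
        return 'MODIFY' if existed else 'INSERT'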
Example #56
0
def fix_hardcoded_creation_date(response):
    search = '<CreationTime>2011-05-23T15:47:44Z</CreationTime>'
    replace = '<CreationTime>%s</CreationTime>' % timestamp()
    response._content = to_str(response._content or '').replace(search, replace)
    response.headers['Content-Length'] = str(len(response._content))
Example #57
0
    def forward_request(self, method, path, data, headers):
        if method == 'OPTIONS':
            return 200

        # check region
        try:
            aws_stack.check_valid_region(headers)
            aws_stack.set_default_region_in_headers(headers)
        except Exception as e:
            return make_error(message=str(e), code=400)

        if method == 'POST':
            # parse payload and extract fields
            req_data = urlparse.parse_qs(to_str(data), keep_blank_values=True)

            # parse data from query path
            if not req_data:
                parsed_path = urlparse.urlparse(path)
                req_data = urlparse.parse_qs(parsed_path.query,
                                             keep_blank_values=True)

            req_action = req_data['Action'][0]
            topic_arn = req_data.get('TargetArn') or req_data.get('TopicArn') or req_data.get('ResourceArn')

            if topic_arn:
                topic_arn = topic_arn[0]
                topic_arn = aws_stack.fix_account_id_in_arns(topic_arn)

            if req_action == 'SetSubscriptionAttributes':
                sub = get_subscription_by_arn(req_data['SubscriptionArn'][0])
                if not sub:
                    return make_error(
                        message='Unable to find subscription for given ARN',
                        code=400)

                attr_name = req_data['AttributeName'][0]
                attr_value = req_data['AttributeValue'][0]
                sub[attr_name] = attr_value
                return make_response(req_action)

            elif req_action == 'GetSubscriptionAttributes':
                sub = get_subscription_by_arn(req_data['SubscriptionArn'][0])
                if not sub:
                    return make_error(
                        message='Unable to find subscription for given ARN',
                        code=400)

                content = '<Attributes>'
                for key, value in sub.items():
                    content += '<entry><key>%s</key><value>%s</value></entry>\n' % (key, value)
                content += '</Attributes>'
                return make_response(req_action, content=content)

            elif req_action == 'Subscribe':
                if 'Endpoint' not in req_data:
                    return make_error(
                        message='Endpoint not specified in subscription',
                        code=400)

            elif req_action == 'ConfirmSubscription':
                if 'TopicArn' not in req_data:
                    return make_error(
                        message='TopicArn not specified in confirm subscription request',
                        code=400)

                if 'Token' not in req_data:
                    return make_error(
                        message='Token not specified in confirm subscription request',
                        code=400)

                do_confirm_subscription(req_data.get('TopicArn')[0], req_data.get('Token')[0])

            elif req_action == 'Unsubscribe':
                if 'SubscriptionArn' not in req_data:
                    return make_error(
                        message='SubscriptionArn not specified in unsubscribe request',
                        code=400)

                do_unsubscribe(req_data.get('SubscriptionArn')[0])

            elif req_action == 'DeleteTopic':
                do_delete_topic(topic_arn)

            elif req_action == 'Publish':
                if req_data.get('Subject') == ['']:
                    return make_error(code=400,
                                      code_string='InvalidParameter',
                                      message='Subject')

                # No need to create a topic to send SMS or single push notifications with SNS,
                # but we cannot mock the actual sending, so we simply report success
                if 'PhoneNumber' not in req_data and 'TargetArn' not in req_data:
                    if topic_arn not in SNS_SUBSCRIPTIONS.keys():
                        return make_error(code=404,
                                          code_string='NotFound',
                                          message='Topic does not exist')

                publish_message(topic_arn, req_data)

                # return the response here, because we do not want the request to be forwarded to the SNS backend
                return make_response(req_action)

            elif req_action == 'ListTagsForResource':
                tags = do_list_tags_for_resource(topic_arn)
                content = '<Tags/>'
                if len(tags) > 0:
                    content = '<Tags>'
                    for tag in tags:
                        content += '<member>'
                        content += '<Key>%s</Key>' % tag['Key']
                        content += '<Value>%s</Value>' % tag['Value']
                        content += '</member>'
                    content += '</Tags>'
                return make_response(req_action, content=content)

            elif req_action == 'CreateTopic':
                topic_arn = aws_stack.sns_topic_arn(req_data['Name'][0])
                tag_resource_success = self._extract_tags(topic_arn, req_data, True)
                # return an error if tagging failed; otherwise continue as expected
                if not tag_resource_success:
                    return make_error(
                        code=400,
                        code_string='InvalidParameter',
                        message='Topic already exists with different tags')

            elif req_action == 'TagResource':
                self._extract_tags(topic_arn, req_data, False)
                return make_response(req_action)

            elif req_action == 'UntagResource':
                tags_to_remove = []
                req_tags = {k: v for k, v in req_data.items() if k.startswith('TagKeys.member.')}
                for tag in req_tags.values():
                    tags_to_remove.append(tag[0])
                do_untag_resource(topic_arn, tags_to_remove)
                return make_response(req_action)

            data = self._reset_account_id(data)
            return Request(data=data, headers=headers, method=method)

        return True
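As a reminder of the request format the listener above deals with, here is a small stdlib-only sketch (the payload is an assumption): parse_qs returns a dict of lists, which is why the handler indexes [0] everywhere, and keep_blank_values=True preserves empty fields such as a blank Subject.

from urllib import parse as urlparse

data = 'Action=Publish&TopicArn=arn%3Aaws%3Asns%3Aus-east-1%3A000000000000%3Ademo&Subject=&Message=hi'
req_data = urlparse.parse_qs(data, keep_blank_values=True)
print(req_data['Action'][0])    # 'Publish'
print(req_data.get('Subject'))  # [''] -> rejected above with InvalidParameter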
Example #58
0
    def return_response(self, method, path, data, headers, response):

        bucket_name = get_bucket_name(path, headers)

        # No path-based bucket name? Fall back to host-based addressing
        hostname_parts = headers['host'].split('.')
        if not bucket_name and len(hostname_parts) > 1:
            bucket_name = hostname_parts[0]

        # POST requests to S3 may include a success_action_redirect field,
        # which should be used to redirect a client to a new location.
        key = None
        if method == 'POST':
            key, redirect_url = multipart_content.find_multipart_redirect_url(data, headers)

            if key and redirect_url:
                response.status_code = 303
                response.headers['Location'] = expand_redirect_url(redirect_url, key, bucket_name)
                LOGGER.debug('S3 POST {} to {}'.format(response.status_code, response.headers['Location']))

        parsed = urlparse.urlparse(path)

        bucket_name_in_host = headers['host'].startswith(bucket_name)

        should_send_notifications = all([
            method in ('PUT', 'POST', 'DELETE'),
            '/' in path[1:] or bucket_name_in_host,
            # check if this is an actual put object request, because it could also be
            # a put bucket request with a path like this: /bucket_name/
            bucket_name_in_host or (len(path[1:].split('/')) > 1 and len(path[1:].split('/')[1]) > 0),
            self.is_query_allowable(method, parsed.query)
        ])

        # get subscribers and send bucket notifications
        if should_send_notifications:
            # if we already have a good key, use it, otherwise examine the path
            if key:
                object_path = '/' + key
            elif bucket_name_in_host:
                object_path = parsed.path
            else:
                parts = parsed.path[1:].split('/', 1)
                object_path = parts[1] if parts[1][0] == '/' else '/%s' % parts[1]

            send_notifications(method, bucket_name, object_path)

        # publish event for creation/deletion of buckets:
        if method in ('PUT', 'DELETE') and ('/' not in path[1:] or len(path[1:].split('/')[1]) <= 0):
            event_type = (event_publisher.EVENT_S3_CREATE_BUCKET if method == 'PUT'
                else event_publisher.EVENT_S3_DELETE_BUCKET)
            event_publisher.fire_event(event_type, payload={'n': event_publisher.get_hash(bucket_name)})

        # fix an upstream issue in moto S3 (see https://github.com/localstack/localstack/issues/382)
        if method == 'PUT' and parsed.query == 'policy':
            response._content = ''
            response.status_code = 204
            return response

        if response:
            # append CORS headers to response
            append_cors_headers(bucket_name, request_method=method, request_headers=headers, response=response)

            response_content_str = None
            try:
                response_content_str = to_str(response._content)
            except Exception:
                pass

            # we need to un-pretty-print the XML, otherwise we run into this issue with Spark:
            # https://github.com/jserver/mock-s3/pull/9/files
            # https://github.com/localstack/localstack/issues/183
            # Note: we still need to keep the newline after the first line: <?xml ...>\n
            if response_content_str and response_content_str.startswith('<'):
                is_bytes = isinstance(response._content, six.binary_type)
                response._content = re.sub(r'([^\?])>\n\s*<', r'\1><', response_content_str, flags=re.MULTILINE)
                if is_bytes:
                    response._content = to_bytes(response._content)
                # fix content-type: https://github.com/localstack/localstack/issues/618
                #                   https://github.com/localstack/localstack/issues/549
                if 'text/html' in response.headers.get('Content-Type', ''):
                    response.headers['Content-Type'] = 'application/xml; charset=utf-8'

                response.headers['content-length'] = len(response._content)

            # update content-length headers (fix https://github.com/localstack/localstack/issues/541)
            if method == 'DELETE':
                response.headers['content-length'] = len(response._content)
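The object-path branch above has to reconcile two addressing styles; here is a small sketch of that branch in isolation (the function name and inputs are assumptions for illustration).

def object_path_for(parsed_path, bucket_name, host):
    # virtual-hosted style: the whole path is the object key
    if host.startswith(bucket_name):
        return parsed_path
    # path style: strip the leading /bucket segment
    parts = parsed_path[1:].split('/', 1)
    return parts[1] if parts[1].startswith('/') else '/%s' % parts[1]

print(object_path_for('/key/a.txt', 'demo', 'demo.s3.localhost'))  # '/key/a.txt'
print(object_path_for('/demo/key/a.txt', 'demo', 's3.localhost'))  # '/key/a.txt'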
Example #59
0
    def _execute(self, func_arn, func_details, event, context=None, version=None):
        lambda_cwd = func_details.cwd
        runtime = func_details.runtime
        handler = func_details.handler
        environment = self._prepare_environment(func_details)

        # configure USE_SSL in environment
        if config.USE_SSL:
            environment['USE_SSL'] = '1'

        # prepare event body
        if not event:
            LOG.warning('Empty event body specified for invocation of Lambda "%s"' % func_arn)
            event = {}
        event_body = json.dumps(json_safe(event))
        stdin = self.prepare_event(environment, event_body)

        main_endpoint = get_main_endpoint_from_container()

        environment['LOCALSTACK_HOSTNAME'] = main_endpoint
        environment['_HANDLER'] = handler
        if os.environ.get('HTTP_PROXY'):
            environment['HTTP_PROXY'] = os.environ['HTTP_PROXY']
        if func_details.timeout:
            environment['AWS_LAMBDA_FUNCTION_TIMEOUT'] = str(func_details.timeout)
        if context:
            environment['AWS_LAMBDA_FUNCTION_NAME'] = context.function_name
            environment['AWS_LAMBDA_FUNCTION_VERSION'] = context.function_version
            environment['AWS_LAMBDA_FUNCTION_INVOKED_ARN'] = context.invoked_function_arn
            environment['AWS_LAMBDA_COGNITO_IDENTITY'] = json.dumps(context.cognito_identity or {})
            if context.client_context is not None:
                environment['AWS_LAMBDA_CLIENT_CONTEXT'] = json.dumps(to_str(
                    base64.b64decode(to_bytes(context.client_context))))

        # custom command to execute in the container
        command = ''
        events_file = ''

        if USE_CUSTOM_JAVA_EXECUTOR and is_java_lambda(runtime):
            # if running a Java Lambda with our custom executor, set up classpath arguments
            java_opts = Util.get_java_opts()
            stdin = None
            # copy executor jar into temp directory
            target_file = os.path.join(lambda_cwd, os.path.basename(LAMBDA_EXECUTOR_JAR))
            if not os.path.exists(target_file):
                cp_r(LAMBDA_EXECUTOR_JAR, target_file)
            # TODO cleanup once we have custom Java Docker image
            taskdir = '/var/task'
            events_file = '_lambda.events.%s.json' % short_uid()
            save_file(os.path.join(lambda_cwd, events_file), event_body)
            classpath = Util.get_java_classpath(target_file)
            command = ("bash -c 'cd %s; java %s -cp \"%s\" \"%s\" \"%s\" \"%s\"'" %
                (taskdir, java_opts, classpath, LAMBDA_EXECUTOR_CLASS, handler, events_file))

        # accept any self-signed certificates for outgoing calls from the Lambda
        if is_nodejs_runtime(runtime):
            environment['NODE_TLS_REJECT_UNAUTHORIZED'] = '0'

        # determine the command to be executed (implemented by subclasses)
        cmd = self.prepare_execution(func_arn, environment, runtime, command, handler, lambda_cwd)

        # lambci writes the Lambda result to stdout and logs to stderr, fetch it from there!
        LOG.info('Running lambda cmd: %s' % cmd)
        result = self.run_lambda_executor(cmd, stdin, env_vars=environment, func_details=func_details)

        # clean up events file
        if events_file and os.path.exists(events_file):
            rm_rf(events_file)

        return result
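One detail worth spelling out is the client-context handling above: the value arrives base64-encoded, and the decoded JSON text is passed through json.dumps again, so the environment variable ends up holding a JSON-quoted string. A stdlib-only sketch with an assumed sample value:

import base64
import json

client_context = base64.b64encode(json.dumps({'custom': {'foo': 'bar'}}).encode()).decode()
decoded = base64.b64decode(client_context.encode()).decode()
environment = {'AWS_LAMBDA_CLIENT_CONTEXT': json.dumps(decoded)}
print(environment['AWS_LAMBDA_CLIENT_CONTEXT'])  # "{\"custom\": {\"foo\": \"bar\"}}"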
Example #60
0
def response_regex_replace(response, search, replace):
    response._content = re.sub(search, replace, to_str(response._content), flags=re.DOTALL | re.MULTILINE)
    response.headers['Content-Length'] = str(len(response._content))
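A quick usage sketch for the helper above (the stub response object is an assumption, and this runs in the same module context since the helper relies on to_str); the DOTALL flag also lets patterns span the newlines of pretty-printed bodies.

from types import SimpleNamespace

resp = SimpleNamespace(_content='<Owner>\n  <ID>remote-id</ID>\n</Owner>', headers={})
response_regex_replace(resp, '<ID>.*?</ID>', '<ID>local-id</ID>')
print(resp._content)                   # '<Owner>\n  <ID>local-id</ID>\n</Owner>'
print(resp.headers['Content-Length'])  # recomputed from the new body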