def test_batcher_with_prefix(self):
    """Limit batch operation to object keys which start with the given prefix."""
    bucket_keys = ['test-key-1', 'test-key-2', 'important/path', 'important-file']

    def fake_list_objects(**kwargs):
        """Mock for S3.list_objects_v2 which honors the object prefix."""
        matching = [{'Key': key} for key in bucket_keys
                    if key.startswith(kwargs['Prefix'])]
        return {'Contents': matching, 'IsTruncated': False}

    self.batcher_main.S3.list_objects_v2 = fake_list_objects

    with mock.patch.object(self.batcher_main, 'LOGGER') as mock_logger:
        num_keys = self.batcher_main.batch_lambda_handler(
            {}, common.MockLambdaContext())
        # Only the two keys under the 'important' prefix should be batched.
        self.assertEqual(2, num_keys)
        mock_logger.assert_has_calls([
            mock.call.info('Invoked with event %s', {}),
            mock.call.info('Restricting batch operation to prefix: %s', 'important'),
            mock.call.info('Finalize: sending last batch of keys'),
            mock.call.info('Sending SQS batch of %d keys: %s ... %s', 2,
                           'important/path', 'important-file')
        ])
def test_batcher_invoke_with_continuation(self):
    """Invoke the batcher with a continuation token."""
    def fake_list_objects(**kwargs):
        # Echo the continuation token back as the only key so the assertion
        # below proves the token was forwarded to S3.
        return {
            'Contents': [{'Key': kwargs['ContinuationToken']}],
            'IsTruncated': False
        }

    self.batcher_main.S3.list_objects_v2 = fake_list_objects

    with mock.patch.object(self.batcher_main, 'LOGGER'):
        num_keys = self.batcher_main.batch_lambda_handler(
            {'S3ContinuationToken': 'test-continuation-token'},
            common.MockLambdaContext())

    self.assertEqual(1, num_keys)

    expected_body = json.dumps({
        'Records': [{
            's3': {
                'bucket': {'name': 'test_s3_bucket'},
                'object': {'key': 'test-continuation-token'}
            }
        }]
    })
    self.batcher_main.SQS.assert_has_calls([
        mock.call.Queue().send_messages(
            Entries=[{'Id': '0', 'MessageBody': expected_body}])
    ])
def test_batcher_one_full_batch(self):
    """Batcher enqueues the configured maximum number of objects in a single SQS message."""
    self.batcher_main.S3.list_objects_v2 = lambda **kwargs: {
        'Contents': [{'Key': 'test-key-1'}, {'Key': 'test-key-2'}],
        'IsTruncated': False
    }

    with mock.patch.object(self.batcher_main, 'LOGGER') as mock_logger:
        num_keys = self.batcher_main.batch_lambda_handler(
            {}, common.MockLambdaContext())

    self.assertEqual(2, num_keys)
    mock_logger.assert_has_calls([
        mock.call.info('Invoked with event %s', {}),
        mock.call.info('Finalize: sending last batch of keys'),
        mock.call.info('Sending SQS batch of %d keys: %s ... %s', 2,
                       'test-key-1', 'test-key-2')
    ])

    def s3_record(key):
        """Build the S3 event record the batcher is expected to enqueue."""
        return {'s3': {'bucket': {'name': 'test_s3_bucket'},
                       'object': {'key': key}}}

    expected_body = json.dumps(
        {'Records': [s3_record('test-key-1'), s3_record('test-key-2')]})
    self.batcher_main.SQS.assert_has_calls([
        mock.call.Queue('test_queue'),
        mock.call.Queue().send_messages(
            Entries=[{'Id': '0', 'MessageBody': expected_body}])
    ])
    self.batcher_main.CLOUDWATCH.assert_not_called()
    self.batcher_main.LAMBDA.assert_not_called()
def test_batcher_empty_bucket(self):
    """Batcher does nothing for an empty bucket."""
    # An empty listing response: no 'Contents' key at all.
    self.batcher_main.S3.list_objects_v2 = lambda **kwargs: {}

    with mock.patch.object(self.batcher_main, 'LOGGER') as mock_logger:
        num_keys = self.batcher_main.batch_lambda_handler(
            {}, common.MockLambdaContext())

    self.assertEqual(0, num_keys)
    mock_logger.assert_has_calls([
        mock.call.info('Invoked with event %s', {}),
        mock.call.info('The S3 bucket is empty; nothing to do')
    ])
def test_dispatcher_no_messages(self):
    """Dispatcher doesn't do anything if there are no SQS messages."""
    self.dispatcher_main.SQS_QUEUE.receive_messages.return_value = []

    with mock.patch.object(self.dispatcher_main, 'LOGGER') as mock_logger:
        invocations = self.dispatcher_main.dispatch_lambda_handler(
            {}, common.MockLambdaContext(decrement_ms=10000))

    self.assertEqual(0, invocations)
    mock_logger.assert_has_calls([
        mock.call.info('No SQS messages found'),
        mock.call.info('Invoked %d total analyzers', 0)
    ])
def test_dispatcher_invokes_analyzer(self):
    """Dispatcher flattens multiple messages and invokes an analyzer."""
    def sqs_body(*keys):
        """Build an SQS message body containing one S3 record per key."""
        return json.dumps(
            {'Records': [{'s3': {'object': {'key': key}}} for key in keys]})

    self.dispatcher_main.SQS_QUEUE.receive_messages.return_value = [
        MockSQSMessage(body=sqs_body('test-key-1', 'test-key-2'),
                       receipt_handle='receipt1'),
        MockSQSMessage(body=sqs_body('test-key-3'), receipt_handle='receipt2')
    ]

    with mock.patch.object(self.dispatcher_main, 'LOGGER') as mock_logger:
        invocations = self.dispatcher_main.dispatch_lambda_handler(
            {}, common.MockLambdaContext(decrement_ms=10000))

    # All three keys from both messages are flattened into one invocation.
    self.assertEqual(1, invocations)
    mock_logger.assert_has_calls([
        mock.call.info('Sending %d object(s) to an analyzer: %s', 3,
                       ['test-key-1', 'test-key-2', 'test-key-3']),
        mock.call.info('Invoked %d total analyzers', 1)
    ])
    self.dispatcher_main.LAMBDA.assert_has_calls([
        mock.call.invoke(
            FunctionName='test-analyzer',
            InvocationType='Event',
            Payload=json.dumps({
                'S3Objects': ['test-key-1', 'test-key-2', 'test-key-3'],
                'SQSReceipts': ['receipt1', 'receipt2']
            }),
            Qualifier='Production')
    ])
def test_dispatch_handler(self):
    """Dispatch handler creates and starts processes."""
    with mock.patch.object(self.main, 'Process') as mock_process:
        self.main.dispatch_lambda_handler(None, common.MockLambdaContext())

    # One process per config, each started and then joined.
    expected_calls = [
        mock.call(target=self.main._sqs_poll, args=(self.config1, mock.ANY)),
        mock.call(target=self.main._sqs_poll, args=(self.config2, mock.ANY))
    ]
    expected_calls += [mock.call().start()] * 2
    expected_calls += [mock.call().join()] * 2
    mock_process.assert_has_calls(expected_calls)
def test_dispatcher_invalid_message(self):
    """Dispatcher discards invalid SQS messages."""
    # Neither body contains the expected 'Records' payload.
    invalid_messages = [
        MockSQSMessage(body=json.dumps({'InvalidKey': 'Value'}),
                       receipt_handle='receipt1'),
        MockSQSMessage(body=json.dumps({}), receipt_handle='receipt2'),
    ]
    self.dispatcher_main.SQS_QUEUE.receive_messages.return_value = invalid_messages

    with mock.patch.object(self.dispatcher_main, 'LOGGER') as mock_logger:
        invocations = self.dispatcher_main.dispatch_lambda_handler(
            {}, common.MockLambdaContext(decrement_ms=10000))

    self.assertEqual(0, invocations)
    mock_logger.assert_has_calls([
        mock.call.warning('Invalid SQS message body: %s', mock.ANY),
        mock.call.warning('Invalid SQS message body: %s', mock.ANY),
        mock.call.warning('Removing %d invalid messages', 2),
        mock.call.info('Invoked %d total analyzers', 0)
    ])
def test_batcher_one_object(self):
    """Batcher enqueues a single S3 object."""
    self.batcher_main.S3.list_objects_v2 = lambda **kwargs: {
        'Contents': [{'Key': 'test-key-1'}],
        'IsTruncated': False
    }

    with mock.patch.object(self.batcher_main, 'LOGGER') as mock_logger:
        num_keys = self.batcher_main.batch_lambda_handler(
            {}, common.MockLambdaContext())

    self.assertEqual(1, num_keys)
    mock_logger.assert_has_calls([
        mock.call.info('Invoked with event %s', {}),
        mock.call.info('Finalize: sending last batch of keys'),
        mock.call.info('Sending SQS batch of %d keys: %s ... %s', 1,
                       'test-key-1', 'test-key-1')
    ])
    # Fix: include the 'bucket' record in the expected message body, matching
    # the format asserted by every other batcher test in this file. The
    # original expectation omitted 'bucket' and so could not equal the
    # handler's actual MessageBody under send_messages' exact-match check.
    self.batcher_main.SQS.assert_has_calls([
        mock.call.Queue('test_queue'),
        mock.call.Queue().send_messages(Entries=[{
            'Id': '0',
            'MessageBody': json.dumps({
                'Records': [{
                    's3': {
                        'bucket': {'name': 'test_s3_bucket'},
                        'object': {'key': 'test-key-1'}
                    }
                }]
            })
        }])
    ])
    self.batcher_main.CLOUDWATCH.assert_not_called()  # No error metrics to report.
    self.batcher_main.LAMBDA.assert_not_called()  # Second batcher invocation not necessary.
def test_batcher_re_invoke(self):
    """If the batcher runs out of time, it has to re-invoke itself."""
    class NeverFinishedEnumerator(object):
        """Simple mock for S3BucketEnumerator which never finishes."""

        def __init__(self, *args):  # pylint: disable=unused-argument
            self.continuation_token = 'test-continuation-token'
            self.finished = False

    with mock.patch.object(self.batcher_main, 'S3BucketEnumerator',
                           NeverFinishedEnumerator), \
            mock.patch.object(self.batcher_main, 'LOGGER') as mock_logger:
        # A 1 ms time limit guarantees the batcher runs out of time.
        self.batcher_main.batch_lambda_handler(
            {}, common.MockLambdaContext(time_limit_ms=1))

    mock_logger.assert_has_calls([mock.call.info('Invoking another batcher')])
    self.batcher_main.LAMBDA.assert_has_calls([
        mock.call.invoke(
            FunctionName='test_batch_lambda_name',
            InvocationType='Event',
            Payload='{"S3ContinuationToken": "test-continuation-token"}',
            Qualifier='Production')
    ])
def test_sqs_poll(self):
    """Dispatcher invokes each of the Lambda targets with data from its respective queue."""
    for message_body in ('queue1-message1', 'queue1-message2'):
        self.config1.queue.send_message(MessageBody=message_body)

    with mock.patch.object(self.main, 'LOGGER') as mock_logger, \
            mock.patch.object(self.main, 'LAMBDA') as mock_lambda, \
            mock.patch.object(self.main, 'WAIT_TIME_SECONDS', 0):
        self.main._sqs_poll(self.config1, common.MockLambdaContext())

        mock_logger.assert_has_calls([
            mock.call.info('Polling process started: %s => lambda:%s:%s',
                           self.config1.queue.url, self.config1.lambda_name,
                           self.config1.lambda_qualifier),
            mock.call.info('Sending %d messages to %s:%s', 2,
                           'analyzer', 'production')
        ])
        # Both queued messages are forwarded in a single Lambda invocation.
        mock_lambda.invoke.assert_called_once_with(
            FunctionName='analyzer', InvocationType='Event',
            Payload=mock.ANY, Qualifier='production')
def test_batcher_multiple_messages(self):
    """Batcher enqueues 2 SQS messages."""
    first_page = {
        'Contents': [{'Key': 'test-key-1'}, {'Key': 'test-key-2'}],
        'IsTruncated': True,
        'NextContinuationToken': 'test-continuation-token'
    }
    second_page = {'Contents': [{'Key': 'test-key-3'}], 'IsTruncated': False}

    def fake_list_objects(**kwargs):
        """Mock for S3.list_objects_v2 which includes multiple pages of results."""
        return second_page if 'ContinuationToken' in kwargs else first_page

    self.batcher_main.S3.list_objects_v2 = fake_list_objects

    with mock.patch.object(self.batcher_main, 'LOGGER') as mock_logger:
        num_keys = self.batcher_main.batch_lambda_handler(
            {}, common.MockLambdaContext(time_limit_ms=50000,
                                         decrement_ms=10000))

    self.assertEqual(3, num_keys)
    mock_logger.assert_has_calls([
        mock.call.info('Invoked with event %s', {}),
        mock.call.info('Finalize: sending last batch of keys'),
        mock.call.info('Sending SQS batch of %d keys: %s ... %s', 3,
                       'test-key-1', 'test-key-3')
    ])

    def s3_record(key):
        """Build the S3 event record the batcher is expected to enqueue."""
        return {'s3': {'bucket': {'name': 'test_s3_bucket'},
                       'object': {'key': key}}}

    self.batcher_main.SQS.assert_has_calls([
        mock.call.Queue('test_queue'),
        mock.call.Queue().send_messages(Entries=[
            {
                'Id': '0',
                'MessageBody': json.dumps(
                    {'Records': [s3_record('test-key-1'),
                                 s3_record('test-key-2')]})
            },
            {
                'Id': '1',
                'MessageBody': json.dumps({'Records': [s3_record('test-key-3')]})
            }
        ])
    ])
    self.batcher_main.CLOUDWATCH.assert_not_called()
    self.batcher_main.LAMBDA.assert_not_called()
FILE_MODIFIED_TIME = 'test-last-modified' GOOD_FILE_CONTENTS = 'Hello, world!\n' GOOD_FILE_METADATA = {'filepath': 'win32'} GOOD_S3_OBJECT_KEY = 'space plus+file.test' EVIL_FILE_CONTENTS = 'Hello, evil world!\n' EVIL_FILE_METADATA = {'filepath': '/path/to/mock-evil.exe'} EVIL_S3_OBJECT_KEY = 'evil.exe' MOCK_DYNAMO_TABLE_NAME = 'mock-dynamo-table' MOCK_SNS_TOPIC_ARN = 's3:mock-sns-arn' MOCK_SQS_URL = 'https://sqs.mock.url' MOCK_SQS_RECEIPTS = ['sqs_receipt1', 'sqs_receipt2'] # Mimics minimal parts of S3:ObjectAdded event that triggers the lambda function. LAMBDA_VERSION = 1 TEST_CONTEXT = common.MockLambdaContext(LAMBDA_VERSION) class MockS3Object(object): """Simple mock for boto3.resource('s3').Object""" def __init__(self, bucket_name, object_key): self.name = bucket_name self.key = object_key def download_file(self, download_path): with open(download_path, 'w') as f: f.write(GOOD_FILE_CONTENTS if self.key == GOOD_S3_OBJECT_KEY else EVIL_FILE_CONTENTS) @property def last_modified(self):