class TestStreamAlertSQSClient(object):
    """Test class for StreamAlertSQSClient"""
    def setup(self):
        """Add a fake message to the queue."""
        self.mock_sqs = mock_sqs()
        self.mock_sqs.start()

        sqs = boto3.resource('sqs', region_name=TEST_REGION)
        self.queue = sqs.create_queue(QueueName=StreamAlertSQSClient.QUEUENAME)
        self.client = StreamAlertSQSClient(CONFIG_DATA)

        # Create a fake s3 notification message to send
        bucket = 'unit-testing.streamalerts'
        test_s3_notification = {
            'Records': [
                {
                    'eventVersion': '2.0',
                    'eventSource': 'aws:s3',
                    'awsRegion': 'us-east-1',
                    'eventTime': '2017-08-07T18:26:30.956Z',
                    'eventName': 'S3:PutObject',
                    'userIdentity': {
                        'principalId': 'AWS:AAAAAAAAAAAAAAA'
                    },
                    'requestParameters': {
                        'sourceIPAddress': '127.0.0.1'
                    },
                    'responseElements': {
                        'x-amz-request-id': 'FOO',
                        'x-amz-id-2': 'BAR'
                    },
                    's3': {
                        's3SchemaVersion': '1.0',
                        'configurationId': 'queue',
                        'bucket': {
                            'name': bucket,
                            'ownerIdentity': {
                                'principalId': 'AAAAAAAAAAAAAAA'
                            },
                            'arn': 'arn:aws:s3:::{}'.format(bucket)
                        },
                        'object': {
                            'key':
                            'alerts/dt=2017-08-26-14-02/rule_name_alerts-1304134918401.json',
                            'size': 1494,
                            'eTag': '12214134141431431',
                            'versionId': 'asdfasdfasdf.dfadCJkj1',
                            'sequencer': '1212312321312321321'
                        }
                    }
                },
                {
                    'eventVersion': '2.0',
                    'eventSource': 'aws:s3',
                    'awsRegion': 'us-east-1',
                    'eventTime': '2017-08-07T18:26:30.956Z',
                    'eventName': 'S3:GetObject',
                    'userIdentity': {
                        'principalId': 'AWS:AAAAAAAAAAAAAAA'
                    },
                    'requestParameters': {
                        'sourceIPAddress': '127.0.0.1'
                    },
                    'responseElements': {
                        'x-amz-request-id': 'FOO',
                        'x-amz-id-2': 'BAR'
                    },
                    's3': {
                        's3SchemaVersion': '1.0',
                        'configurationId': 'queue',
                        'bucket': {
                            'name': bucket,
                            'ownerIdentity': {
                                'principalId': 'AAAAAAAAAAAAAAA'
                            },
                            'arn': 'arn:aws:s3:::{}'.format(bucket)
                        },
                        'object': {
                            # Different day than the above record
                            'key':
                            'alerts/dt=2017-08-27-14-02/rule_name_alerts-1304134918401.json',
                            'size': 1494,
                            'eTag': '12214134141431431',
                            'versionId': 'asdfasdfasdf.dfadCJkj1',
                            'sequencer': '1212312321312321321'
                        }
                    }
                }
            ]
        }
        self.queue.send_message(MessageBody=json.dumps(test_s3_notification),
                                QueueUrl=self.client.athena_sqs_url)

    def teardown(self):
        """Purge the Queue and reset the client between runs"""
        self.client.sqs_client.purge_queue(QueueUrl=self.client.athena_sqs_url)
        self.client = None

    @patch('stream_alert.athena_partition_refresh.main.LOGGER')
    def test_delete_messages_none_received(self, mock_logging):
        """Athena SQS - Delete Messages - No Receieved Messages"""
        self.client.delete_messages()

        assert_true(mock_logging.error.called)

    # The return value is not being mocked successfully
    @nottest
    @patch('stream_alert.athena_partition_refresh.main.StreamAlertSQSClient')
    @patch('stream_alert.athena_partition_refresh.main.LOGGER')
    def test_delete_messages_failure(self, mock_logging, mock_sqs_client):
        """Athena SQS - Delete Messages - Failure Response"""
        instance = mock_sqs_client.return_value
        instance.sqs_client.delete_message_batch.return_value = {
            'Successful': [{
                'Id': '2'
            }],
            'Failed': [{
                'Id': '1'
            }]
        }

        self.client.get_messages()
        self.client.unique_s3_buckets_and_keys()
        self.client.delete_messages()

        assert_true(mock_logging.error.called)

    @patch('stream_alert.athena_partition_refresh.main.LOGGER')
    def test_delete_messages_none_processed(self, mock_logging):
        """Athena SQS - Delete Messages - No Processed Messages"""
        self.client.processed_messages = []
        result = self.client.delete_messages()

        assert_true(mock_logging.error.called)
        assert_false(result)

    @patch('stream_alert.athena_partition_refresh.main.LOGGER')
    def test_delete_messages(self, mock_logging):
        """Athena SQS - Delete Messages"""
        self.client.get_messages(max_tries=1)
        self.client.unique_s3_buckets_and_keys()
        self.client.delete_messages()

        assert_true(mock_logging.info.called)

    @patch('stream_alert.athena_partition_refresh.main.LOGGER')
    def test_get_messages_invalid_max_messages(self, mock_logging):
        """Athena SQS - Invalid Max Message Request"""
        resp = self.client.get_messages(max_messages=100)

        assert_true(mock_logging.error.called)
        assert_is_none(resp)

    @patch('stream_alert.athena_partition_refresh.main.LOGGER')
    def test_get_messages(self, mock_logging):
        """Athena SQS - Get Valid Messages"""
        self.client.get_messages(max_tries=1)

        assert_equal(len(self.client.received_messages), 1)
        assert_true(mock_logging.info.called)

    def test_unique_s3_buckets_and_keys(self):
        """Athena SQS - Get Unique Bucket Ids"""
        self.client.get_messages(max_tries=1)
        unique_buckets = self.client.unique_s3_buckets_and_keys()

        assert_equal(
            unique_buckets, {
                'unit-testing.streamalerts':
                set([
                    'alerts/dt=2017-08-26-14-02/rule_name_alerts-1304134918401.json',
                    'alerts/dt=2017-08-27-14-02/rule_name_alerts-1304134918401.json',
                ])
            })
        assert_equal(len(self.client.processed_messages), 2)

    @patch('stream_alert.athena_partition_refresh.main.LOGGER')
    def test_unique_s3_buckets_and_keys_invalid_sqs(self, mock_logging):
        """Athena SQS - Unique Buckets - Invalid SQS Message"""
        self.client.received_messages = ['wrong-format-test']
        unique_buckets = self.client.unique_s3_buckets_and_keys()

        assert_false(unique_buckets)
        assert_true(mock_logging.error.called)

    @patch('stream_alert.athena_partition_refresh.main.LOGGER')
    def test_unique_s3_buckets_and_keys_s3_test_event(self, mock_logging):
        """Athena SQS - Unique Buckets - S3 Test Event"""
        s3_test_event = {
            'Body':
            json.dumps({
                'HostId':
                '8cLeGAmw098X5cv4Zkwcmo8vvZa3eH3eKxsPzbB9wrR+YstdA6Knx4Ip8EXAMPLE',
                'Service': 'Amazon S3',
                'Bucket': 'bucketname',
                'RequestId': '5582815E1AEA5ADF',
                'Time': '2014-10-13T15:57:02.089Z',
                'Event': 's3:TestEvent'
            })
        }
        self.client.received_messages = [s3_test_event]
        unique_buckets = self.client.unique_s3_buckets_and_keys()

        assert_false(unique_buckets)
        assert_true(
            mock_logging.debug.called_with(
                'Skipping S3 bucket notification test event'))

    @patch('stream_alert.athena_partition_refresh.main.LOGGER')
    def test_unique_s3_buckets_and_keys_invalid_record(self, mock_logging):
        """Athena SQS - Unique Buckets - Missing Records Key in SQS Message"""
        self.client.received_messages = [{
            'Body': '{"missing-records-key": 1}'
        }]
        unique_buckets = self.client.unique_s3_buckets_and_keys()

        assert_false(unique_buckets)
        assert_true(mock_logging.error.called)

    @patch('stream_alert.athena_partition_refresh.main.LOGGER')
    def test_unique_s3_buckets_and_keys_non_s3_notification(
            self, mock_logging):
        """Athena SQS - Unique Buckets - Non S3 Notification"""
        self.client.received_messages = [{
            'Body':
            '{"Records": [{"kinesis": 1}]}'
        }]
        unique_buckets = self.client.unique_s3_buckets_and_keys()

        assert_false(unique_buckets)
        assert_true(mock_logging.info.called)
        assert_true(mock_logging.debug.called)

    @patch('stream_alert.athena_partition_refresh.main.LOGGER')
    def test_unique_s3_buckets_and_keys_no_mesages(self, mock_logging):
        """Athena SQS - Unique Buckets - No Receieved Messages"""
        self.client.received_messages = []
        unique_buckets = self.client.unique_s3_buckets_and_keys()

        assert_is_none(unique_buckets)
        assert_true(mock_logging.error.called)
예제 #2
0
class TestStreamAlertSQSClient(object):
    """Test class for StreamAlertSQSClient"""

    @patch.dict(os.environ, {'AWS_DEFAULT_REGION': 'us-west-1'})
    def setup(self):
        """Add a fake message to the queue."""
        self.mock_sqs = mock_sqs()
        self.mock_sqs.start()

        sqs = boto3.resource('sqs')

        config = load_config('tests/unit/conf/')
        prefix = config['global']['account']['prefix']
        name = StreamAlertSQSClient.DEFAULT_QUEUE_NAME.format(prefix)

        self.queue = sqs.create_queue(QueueName=name)

        # Create a fake s3 notification message to send
        bucket = 'unit-testing.streamalerts'
        test_s3_notification = {
            'Records': [
                {
                    'eventVersion': '2.0',
                    'eventSource': 'aws:s3',
                    'awsRegion': 'us-east-1',
                    'eventTime': '2017-08-07T18:26:30.956Z',
                    'eventName': 'S3:PutObject',
                    'userIdentity': {
                        'principalId': 'AWS:AAAAAAAAAAAAAAA'
                    },
                    'requestParameters': {
                        'sourceIPAddress': '127.0.0.1'
                    },
                    'responseElements': {
                        'x-amz-request-id': 'FOO',
                        'x-amz-id-2': 'BAR'
                    },
                    's3': {
                        's3SchemaVersion': '1.0',
                        'configurationId': 'queue',
                        'bucket': {
                            'name': bucket,
                            'ownerIdentity': {
                                'principalId': 'AAAAAAAAAAAAAAA'
                            },
                            'arn': 'arn:aws:s3:::{}'.format(bucket)
                        },
                        'object': {
                            'key': ('alerts/dt=2017-08-2{}-14-02/rule_name_alerts-'
                                    '1304134918401.json'.format(day)),
                            'size': 1494,
                            'eTag': '12214134141431431',
                            'versionId': 'asdfasdfasdf.dfadCJkj1',
                            'sequencer': '1212312321312321321'
                        }
                    }
                }
                for day in {6, 7}
            ]
        }
        self.queue.send_message(MessageBody=json.dumps(test_s3_notification))

        self.client = StreamAlertSQSClient(config)

    def teardown(self):
        """Purge the Queue and reset the client between runs"""
        self.mock_sqs.stop()

    @patch('logging.Logger.error')
    def test_delete_messages_none_received(self, mock_logging):
        """Athena SQS - Delete Messages - No Receieved Messages"""
        self.client.delete_messages()
        assert_true(mock_logging.called)

    @patch('logging.Logger.error')
    def test_delete_messages_failure_retries(self, log_mock):
        """Athena SQS - Delete Messages - Failure Response and push back messages to queue"""
        with patch.object(self.client.sqs_client, 'delete_message_batch') as sqs_mock:
            sqs_mock.return_value = {'Failed': [{'Id': '1'}]}

            self.client.processed_messages = [{'MessageId': '1', 'ReceiptHandle': 'handle1'},
                                              {'MessageId': '2', 'ReceiptHandle': 'handle2'}]
            self.client.delete_messages()
            for message in self.client.processed_messages:
                assert_is_instance(message, dict)

            assert_true(log_mock.called_with('Failed to delete the messages with following'))

    @patch('logging.Logger.error')
    def test_delete_messages_none_processed(self, log_mock):
        """Athena SQS - Delete Messages - No Processed Messages"""
        self.client.processed_messages = []
        result = self.client.delete_messages()

        assert_true(log_mock.called)
        assert_false(result)

    @patch('logging.Logger.info')
    def test_delete_messages(self, log_mock):
        """Athena SQS - Delete Messages"""
        self.client.get_messages(max_tries=1)
        self.client.unique_s3_buckets_and_keys()
        self.client.delete_messages()

        assert_true(log_mock.called)

    @patch('logging.Logger.error')
    def test_get_messages_invalid_max_messages(self, log_mock):
        """Athena SQS - Invalid Max Message Request"""
        resp = self.client.get_messages(max_messages=100)

        assert_true(log_mock.called)
        assert_is_none(resp)

    @patch('logging.Logger.info')
    def test_get_messages(self, log_mock):
        """Athena SQS - Get Valid Messages"""
        self.client.get_messages(max_tries=1)

        assert_equal(len(self.client.received_messages), 1)
        assert_true(log_mock.called)

    def test_unique_s3_buckets_and_keys(self):
        """Athena SQS - Get Unique Bucket Ids"""
        self.client.get_messages(max_tries=1)
        unique_buckets = self.client.unique_s3_buckets_and_keys()

        assert_equal(unique_buckets, {
            'unit-testing.streamalerts': set([
                'alerts/dt=2017-08-26-14-02/rule_name_alerts-1304134918401.json',
                'alerts/dt=2017-08-27-14-02/rule_name_alerts-1304134918401.json',
            ])
        })
        assert_equal(len(self.client.processed_messages), 2)

    @patch('logging.Logger.error')
    def test_unique_s3_buckets_and_keys_invalid_sqs(self, log_mock):
        """Athena SQS - Unique Buckets - Invalid SQS Message"""
        self.client.received_messages = ['wrong-format-test']
        unique_buckets = self.client.unique_s3_buckets_and_keys()

        assert_false(unique_buckets)
        assert_true(log_mock.called)

    @patch('logging.Logger.debug')
    def test_unique_s3_buckets_and_keys_s3_test_event(self, log_mock):
        """Athena SQS - Unique Buckets - S3 Test Event"""
        s3_test_event = {'Body': json.dumps({
            'HostId': '8cLeGAmw098X5cv4Zkwcmo8vvZa3eH3eKxsPzbB9wrR+YstdA6Knx4Ip8EXAMPLE',
            'Service': 'Amazon S3',
            'Bucket': 'bucketname',
            'RequestId': '5582815E1AEA5ADF',
            'Time': '2014-10-13T15:57:02.089Z',
            'Event': 's3:TestEvent'})}
        self.client.received_messages = [s3_test_event]
        unique_buckets = self.client.unique_s3_buckets_and_keys()

        assert_false(unique_buckets)
        assert_true(log_mock.called_with(
            'Skipping S3 bucket notification test event'))

    @patch('logging.Logger.error')
    def test_unique_s3_buckets_and_keys_invalid_record(self, log_mock):
        """Athena SQS - Unique Buckets - Missing Records Key in SQS Message"""
        self.client.received_messages = [{'Body': '{"missing-records-key": 1}'}]
        unique_buckets = self.client.unique_s3_buckets_and_keys()

        assert_false(unique_buckets)
        assert_true(log_mock.called)

    @patch('logging.Logger.info')
    @patch('logging.Logger.debug')
    def test_unique_s3_buckets_and_keys_non_s3_notification(self, log_debug_mock, log_info_mock):
        """Athena SQS - Unique Buckets - Non S3 Notification"""
        self.client.received_messages = [{'Body': '{"Records": [{"kinesis": 1}]}'}]
        unique_buckets = self.client.unique_s3_buckets_and_keys()

        assert_false(unique_buckets)
        assert_true(log_debug_mock.called)
        log_info_mock.assert_called_with('Skipping non-s3 bucket notification message')

    @patch('logging.Logger.error')
    def test_unique_s3_buckets_and_keys_no_mesages(self, log_mock):
        """Athena SQS - Unique Buckets - No Receieved Messages"""
        self.client.received_messages = []
        unique_buckets = self.client.unique_s3_buckets_and_keys()

        assert_is_none(unique_buckets)
        assert_true(log_mock.called)