Code example #1
    def test_run_threat_intel_enabled(self, mock_threat_intel, mock_query): # pylint: disable=no-self-use
        """StreamAlert Class - Run SA when threat intel enabled"""
        @rule(datatypes=['sourceAddress'], outputs=['s3:sample_bucket'])
        def match_ipaddress(_): # pylint: disable=unused-variable
            """Testing dummy rule"""
            return True

        mock_threat_intel.return_value = StreamThreatIntel('test_table_name', 'us-east-1')
        mock_query.return_value = ([], [])

        sa_handler = StreamAlert(get_mock_context(), False)
        event = {
            'account': 123456,
            'region': '123456123456',
            'source': '1.1.1.2',
            'detail': {
                'eventName': 'ConsoleLogin',
                'sourceIPAddress': '1.1.1.2',
                'recipientAccountId': '654321'
            }
        }
        events = []
        for i in range(10):
            # Copy the event so each record keeps a distinct source address;
            # appending the same dict would yield ten identical records
            event_copy = dict(event)
            event_copy['source'] = '1.1.1.{}'.format(i)
            events.append(event_copy)

        kinesis_events = {
            'Records': [make_kinesis_raw_record('test_kinesis_stream', json.dumps(event))
                        for event in events]
        }

        passed = sa_handler.run(kinesis_events)
        assert_true(passed)

        assert_equal(mock_query.call_count, 1)
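The `make_kinesis_raw_record` helper is not shown in this snippet. A minimal sketch of what such a helper might produce, assuming the standard shape of a Kinesis record inside a Lambda event (the exact fields in the project's test helper may differ):

import base64

def make_kinesis_raw_record_sketch(stream_name, data):
    """Build a dict shaped like one entry of a Kinesis Lambda event."""
    return {
        'eventSource': 'aws:kinesis',
        'eventSourceARN': 'arn:aws:kinesis:us-east-1:123456789012:stream/{}'.format(stream_name),
        'kinesis': {
            # Lambda delivers Kinesis payloads base64-encoded
            'data': base64.b64encode(data)
        }
    }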
Code example #2
File: test.py Project: etsangsplk/streamalert
    def __init__(self, context, config, print_output):
        """RuleProcessorTester initializer

        Args:
            context: A mocked Lambda context object
            config (CLIConfig): Loaded StreamAlert CLI configuration
            print_output (bool): Whether this processor test
                should print results to stdout. This is set to False when the
                alert processor is explicitly being tested alone, and set to
                True for rule processor tests and end-to-end tests.
                Warnings and errors captured during rule processor testing
                will still be written to stdout regardless of this setting.
        """
        # Create the RuleProcessor, passing a mocked context object
        # with fake values
        self.processor = StreamAlert(context)
        self.cli_config = config
        # Use a list of status_messages to store pass/fail/warning info
        self.status_messages = []
        self.total_tests = 0
        self.all_tests_passed = True
        self.print_output = print_output
        # Configure mocks for Firehose and DDB
        helpers.setup_mock_firehose_delivery_streams(config)
        helpers.setup_mock_dynamodb_ioc_table(config)
        # Create a cache map of parsers to parser classes
        self.parsers = {}

        # Patch the tmp shredding so as not to slow down testing
        patch(
            'stream_alert.rule_processor.payload.S3Payload._shred_temp_directory'
        ).start()

        # Patch random_bool to always return true
        patch('helpers.base.random_bool', return_value=True).start()
Code example #3
def test_rule(rule_name, test_record, formatted_record):
    """Feed formatted records into StreamAlert and check for alerts
    Args:
        rule_name: The rule name being tested
        test_record: A single record to test
        formatted_record: A properly formatted version of record for the service to be tested

    Returns:
        boolean indicating if this rule passed
    """
    event = {'Records': [formatted_record]}

    trigger_count = test_record.get('trigger_count')
    if trigger_count:
        expected_alert_count = trigger_count
    else:
        expected_alert_count = (0, 1)[test_record['trigger']]

    alerts = StreamAlert(return_alerts=True).run(event, None)
    # we only want alerts for the specific rule passed in
    matched_alert_count = len(
        [x for x in alerts if x['rule_name'] == rule_name])

    report_output([test_record['service'], test_record['description']],
                  matched_alert_count != expected_alert_count)

    return matched_alert_count == expected_alert_count
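The `(0, 1)[test_record['trigger']]` pattern indexes a two-element tuple with a boolean (False selects index 0, True selects index 1). It is equivalent to the conditional expression that example #7 below uses for the same purpose:

expected_alert_count = 1 if test_record['trigger'] else 0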
Code example #4
def handler(event, context):
    """Main Lambda handler function"""
    try:
        StreamAlert(context).run(event)
    except Exception:
        LOGGER.error('Invocation event: %s', json.dumps(event))
        raise
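A minimal local smoke test for a handler like this might look as follows; `get_mock_context` and the empty event shape are assumptions borrowed from the surrounding test code, not imports present in this snippet:

if __name__ == '__main__':
    # Invoke the handler with an empty record set and a mocked Lambda context
    fake_event = {'Records': []}
    handler(fake_event, get_mock_context())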
Code example #5
File: test.py Project: mchaffie/streamalert
    def test_rule(rule_name, test_record, formatted_record):
        """Feed formatted records into StreamAlert and check for alerts
        Args:
            rule_name [str]: The rule name being tested
            test_record [dict]: A single record to test
            formatted_record [dict]: A properly formatted version of
                record for the service to be tested

        Returns:
            [tuple] alerts that matched this rule and the expected alert count
        """
        event = {'Records': [formatted_record]}

        expected_alert_count = test_record.get('trigger_count')
        if not expected_alert_count:
            expected_alert_count = (0, 1)[test_record['trigger']]

        # Run the rule processor. Passing 'None' for context
        # will load a mocked object later
        alerts = StreamAlert(None, True).run(event)

        # we only want alerts for the specific rule being tested
        alerts = [
            alert for alert in alerts
            if alert['metadata']['rule_name'] == rule_name
        ]

        return alerts, expected_alert_count
Code example #6
    def __init__(self, context, print_output):
        """RuleProcessorTester initializer

        Args:
            context: A mocked Lambda context object
            print_output (bool): Whether this processor test
                should print results to stdout. This is set to False when the
                alert processor is explicitly being tested alone, and set to
                True for rule processor tests and end-to-end tests.
                Warnings and errors captured during rule processor testing
                will still be written to stdout regardless of this setting.
        """
        # Create the RuleProcessor. Passing a mocked context object with fake
        # values and False for suppressing sending of alerts to alert processor
        self.processor = StreamAlert(context, False)
        # Use a list of status_messages to store pass/fail/warning info
        self.status_messages = []
        self.total_tests = 0
        self.all_tests_passed = True
        self.print_output = print_output
Code example #7
File: test.py Project: VVMichaelSawyer/streamalert
    def test_rule(self, rule_name, test_record, formatted_record):
        """Feed formatted records into StreamAlert and check for alerts
        Args:
            rule_name [str]: The rule name being tested
            test_record [dict]: A single record to test
            formatted_record [dict]: A dictionary that includes the 'data' from the
                test record, formatted into a structure that resembles how
                an incoming record from a service would be formatted.
                See test/integration/templates for examples of how each service
                formats records.

        Returns:
            [list] alerts that hit for this rule
            [integer] count of expected alerts for this rule
            [bool] boolean where False indicates errors occurred during processing
        """
        event = {'Records': [formatted_record]}

        expected_alert_count = test_record.get('trigger_count')
        if not expected_alert_count:
            expected_alert_count = 1 if test_record['trigger'] else 0

        # Run the rule processor. Passing mocked context object with fake
        # values and False for suppressing sending of alerts
        processor = StreamAlert(self.context, False)
        all_records_matched_schema = processor.run(event)

        if not all_records_matched_schema:
            payload = StreamPayload(raw_record=formatted_record)
            classifier = StreamClassifier(config=load_config())
            classifier.map_source(payload)
            logs = classifier._log_metadata()
            self.analyze_record_delta(logs, rule_name, test_record)

        alerts = processor.get_alerts()

        # we only want alerts for the specific rule being tested
        alerts = [alert for alert in alerts
                  if alert['rule_name'] == rule_name]

        return alerts, expected_alert_count, all_records_matched_schema
Code example #8
File: test.py Project: cabecada/streamalert
def test_rule(rule_name, test_record, formatted_record):
    """Feed formatted records into StreamAlert and check for alerts
    Args:
        rule_name: The rule name being tested
        test_record: A single record to test
        formatted_record: A properly formatted version of record for the service to be tested

    Returns:
        boolean indicating if this rule passed
    """
    event = {'Records': [formatted_record]}

    trigger_count = test_record.get('trigger_count')
    if trigger_count:
        expected_alert_count = trigger_count
    else:
        expected_alert_count = (0, 1)[test_record['trigger']]

    # Start mocked sns
    BOTO_MOCKER_SNS.start()

    # Create the topic used for the mocking of alert sending
    boto3.client('sns', region_name='us-east-1').create_topic(Name='test_streamalerts')

    # Run the rule processor. Passing 'None' for context will load a mocked object later
    alerts = StreamAlert(None, True).run(event)

    # Stop mocked sns
    BOTO_MOCKER_SNS.stop()

    # we only want alerts for the specific rule passed in
    matched_alert_count = len([x for x in alerts if x['metadata']['rule_name'] == rule_name])

    report_output([test_record['service'], test_record['description']],
                  matched_alert_count != expected_alert_count)

    return matched_alert_count == expected_alert_count
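`BOTO_MOCKER_SNS` is defined outside this snippet. One plausible definition, assuming the moto library provides the mock (as the start/stop usage suggests):

from moto import mock_sns

# Module-level mocker that can be started and stopped around test code
BOTO_MOCKER_SNS = mock_sns()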
Code example #9
    def test_do_not_invoke_threat_intel(self, load_intelligence_mock):
        """StreamAlert Class - Invoke load_intelligence"""
        self.__sa_handler = StreamAlert(get_mock_context(), False)
        load_intelligence_mock.assert_called()
Code example #10
class TestStreamAlert(object):
    """Test class for StreamAlert class"""
    def __init__(self):
        self.__sa_handler = None

    @patch('stream_alert.rule_processor.handler.load_config',
           lambda: load_config('tests/unit/conf/'))
    def setup(self):
        """Setup before each method"""
        self.__sa_handler = StreamAlert(get_mock_context(), False)

    def test_run_no_records(self):
        """StreamAlert Class - Run, No Records"""
        passed = self.__sa_handler.run({'Records': []})
        assert_false(passed)

    def test_get_alerts(self):
        """StreamAlert Class - Get Alerts"""
        default_list = ['alert1', 'alert2']
        self.__sa_handler._alerts = default_list

        assert_list_equal(self.__sa_handler.get_alerts(), default_list)

    @patch('stream_alert.rule_processor.handler.StreamClassifier.load_sources')
    @patch(
        'stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity'
    )
    def test_run_no_sources(self, extract_mock, load_sources_mock):
        """StreamAlert Class - Run, No Loaded Sources"""
        extract_mock.return_value = ('lambda', 'entity')
        load_sources_mock.return_value = None

        self.__sa_handler.run({'Records': ['record']})

        load_sources_mock.assert_called_with('lambda', 'entity')

    @patch('logging.Logger.error')
    @patch(
        'stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity'
    )
    def test_run_bad_service(self, extract_mock, log_mock):
        """StreamAlert Class - Run, Bad Service"""
        extract_mock.return_value = ('', 'entity')

        self.__sa_handler.run({'Records': ['record']})

        log_mock.assert_called_with(
            'No valid service found in payload\'s raw record. '
            'Skipping record: %s', 'record')

    @patch('logging.Logger.error')
    @patch(
        'stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity'
    )
    def test_run_bad_entity(self, extract_mock, log_mock):
        """StreamAlert Class - Run, Bad Entity"""
        extract_mock.return_value = ('kinesis', '')

        self.__sa_handler.run({'Records': ['record']})

        log_mock.assert_called_with(
            'Unable to extract entity from payload\'s raw record for '
            'service %s. Skipping record: %s', 'kinesis', 'record')

    @patch('stream_alert.rule_processor.handler.load_stream_payload')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.load_sources')
    @patch(
        'stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity'
    )
    def test_run_load_payload_bad(self, extract_mock, load_sources_mock,
                                  load_payload_mock):
        """StreamAlert Class - Run, Loaded Payload Fail"""
        extract_mock.return_value = ('lambda', 'entity')
        load_sources_mock.return_value = True

        self.__sa_handler.run({'Records': ['record']})

        load_payload_mock.assert_called_with('lambda', 'entity', 'record')

    @patch('stream_alert.rule_processor.handler.StreamRules.process')
    @patch(
        'stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity'
    )
    def test_run_with_alert(self, extract_mock, rules_mock):
        """StreamAlert Class - Run, With Alert"""
        extract_mock.return_value = ('kinesis', 'unit_test_default_stream')
        rules_mock.return_value = ['success!!']

        passed = self.__sa_handler.run(get_valid_event())

        assert_true(passed)

    @patch('logging.Logger.debug')
    @patch(
        'stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity'
    )
    def test_run_no_alerts(self, extract_mock, log_mock):
        """StreamAlert Class - Run, With No Alerts"""
        extract_mock.return_value = ('kinesis', 'unit_test_default_stream')
        self.__sa_handler.run(get_valid_event())

        calls = [
            call('Processed %d valid record(s) that resulted in %d alert(s).',
                 1, 0),
            call('Invalid record count: %d', 0),
            call('%s alerts triggered', 0)
        ]

        log_mock.assert_has_calls(calls)

    @patch('logging.Logger.error')
    @patch(
        'stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity'
    )
    def test_run_invalid_data(self, extract_mock, log_mock):
        """StreamAlert Class - Run, Invalid Data"""
        extract_mock.return_value = ('kinesis', 'unit_test_default_stream')
        event = get_valid_event()

        # Replace the good log data with bad data
        event['Records'][0]['kinesis']['data'] = base64.b64encode(
            '{"bad": "data"}')

        # Swap out the alias so the logging occurs
        self.__sa_handler.env['lambda_alias'] = 'production'
        self.__sa_handler.run(event)

        assert_equal(log_mock.call_args[0][0],
                     'Record does not match any defined schemas: %s\n%s')
        assert_equal(log_mock.call_args[0][2], '{"bad": "data"}')

    @patch('stream_alert.rule_processor.sink.StreamSink.sink')
    @patch('stream_alert.rule_processor.handler.StreamRules.process')
    @patch(
        'stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity'
    )
    def test_run_send_alerts(self, extract_mock, rules_mock, sink_mock):
        """StreamAlert Class - Run, Send Alert"""
        extract_mock.return_value = ('kinesis', 'unit_test_default_stream')
        rules_mock.return_value = ['success!!']

        # Enable the alert processor so alerts are sent to the sink
        self.__sa_handler.enable_alert_processor = True

        # Swap out the alias so the logging occurs
        self.__sa_handler.env['lambda_alias'] = 'production'

        self.__sa_handler.run(get_valid_event())

        sink_mock.assert_called_with(['success!!'])

    @patch('logging.Logger.debug')
    @patch('stream_alert.rule_processor.handler.StreamRules.process')
    @patch(
        'stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity'
    )
    def test_run_debug_log_alert(self, extract_mock, rules_mock, log_mock):
        """StreamAlert Class - Run, Debug Log Alert"""
        extract_mock.return_value = ('kinesis', 'unit_test_default_stream')
        rules_mock.return_value = ['success!!']

        # Cache the logger level
        log_level = LOGGER.getEffectiveLevel()

        # Increase the logger level to debug
        LOGGER.setLevel(logging.DEBUG)

        self.__sa_handler.run(get_valid_event())

        # Reset the logger level
        LOGGER.setLevel(log_level)

        log_mock.assert_called_with('Alerts:\n%s', '[\n  "success!!"\n]')

    @patch('stream_alert.rule_processor.handler.load_stream_payload')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.load_sources')
    @patch(
        'stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity'
    )
    def test_run_no_payload_class(self, extract_mock, load_sources_mock,
                                  load_payload_mock):
        """StreamAlert Class - Run, No Payload Class"""
        extract_mock.return_value = ('blah', 'entity')
        load_sources_mock.return_value = True
        load_payload_mock.return_value = None

        self.__sa_handler.run({'Records': ['record']})

        load_payload_mock.assert_called()

    @patch('stream_alert.rule_processor.handler.LOGGER')
    @mock_kinesis
    def test_firehose_record_delivery(self, mock_logging):
        """StreamAlert Class - Firehose Record Delivery"""
        self.__sa_handler.firehose_client = boto3.client(
            'firehose', region_name='us-east-1')

        test_event = convert_events_to_kinesis([
            # unit_test_simple_log
            {
                'unit_key_01': 1,
                'unit_key_02': 'test'
            },
            {
                'unit_key_01': 2,
                'unit_key_02': 'test'
            },
            # test_log_type_json_nested
            {
                'date': 'January 01, 3005',
                'unixtime': '32661446400',
                'host': 'my-host.name.website.com',
                'data': {
                    'super': 'secret'
                }
            }
        ])

        delivery_stream_names = [
            'streamalert_data_test_log_type_json_nested',
            'streamalert_data_unit_test_simple_log'
        ]

        # Setup mock delivery streams
        for delivery_stream in delivery_stream_names:
            self.__sa_handler.firehose_client.create_delivery_stream(
                DeliveryStreamName=delivery_stream,
                S3DestinationConfiguration={
                    'RoleARN':
                    'arn:aws:iam::123456789012:role/firehose_delivery_role',
                    'BucketARN': 'arn:aws:s3:::kinesis-test',
                    'Prefix': '{}/'.format(delivery_stream),
                    'BufferingHints': {
                        'SizeInMBs': 123,
                        'IntervalInSeconds': 124
                    },
                    'CompressionFormat': 'Snappy',
                })

        with patch.object(self.__sa_handler.firehose_client,
                          'put_record_batch') as firehose_mock:
            firehose_mock.return_value = {'FailedPutCount': 0}
            self.__sa_handler.run(test_event)

            firehose_mock.assert_called()
            assert_true(mock_logging.info.called)

    @patch('stream_alert.rule_processor.handler.LOGGER')
    def test_firehose_limit_record_size(self, mock_logging):
        """StreamAlert Class - Firehose - Record Size Check"""
        test_events = [
            # unit_test_simple_log
            {
                'unit_key_01': 1,
                'unit_key_02': 'test' * 250001  # is 4 bytes higher than max
            },
            {
                'unit_key_01': 2,
                'unit_key_02': 'test'
            },
            # test_log_type_json_nested
            {
                'date': 'January 01, 3005',
                'unixtime': '32661446400',
                'host': 'my-host.name.website.com',
                'data': {
                    'super': 'secret'
                }
            }
        ]

        self.__sa_handler._limit_record_size(test_events)

        assert_equal(len(test_events), 2)
        assert_true(mock_logging.error.called)

    @patch('stream_alert.rule_processor.handler.LOGGER')
    @mock_kinesis
    def test_firehose_record_delivery_failure(self, mock_logging):
        """StreamAlert Class - Firehose Record Delivery - Failed PutRecord"""
        class MockFirehoseClient(object):
            @staticmethod
            def put_record_batch(**kwargs):
                return {
                    'FailedPutCount':
                    len(kwargs.get('Records')),
                    'RequestResponses': [
                        {
                            'RecordId': '12345',
                            'ErrorCode': '300',
                            'ErrorMessage': 'Bad message!!!'
                        },
                    ]
                }

        self.__sa_handler.firehose_client = MockFirehoseClient()

        test_event = convert_events_to_kinesis([
            # unit_test_simple_log
            {
                'unit_key_01': 1,
                'unit_key_02': 'test'
            },
            {
                'unit_key_01': 2,
                'unit_key_02': 'test'
            },
            # test_log_type_json_nested
            {
                'date': 'January 01, 3005',
                'unixtime': '32661446400',
                'host': 'my-host.name.website.com',
                'data': {
                    'super': 'secret'
                }
            }
        ])

        self.__sa_handler.run(test_event)
        assert_true(mock_logging.error.called)

    @patch(
        'stream_alert.rule_processor.handler.StreamThreatIntel.load_intelligence'
    )
    def test_do_not_invoke_threat_intel(self, load_intelligence_mock):
        """StreamAlert Class - Invoke load_intelligence"""
        self.__sa_handler = StreamAlert(get_mock_context(), False)
        load_intelligence_mock.assert_called()

    def test_firehose_sanitize_keys(self):
        """StreamAlert Class - Firehose - Sanitize Keys"""
        # test_log_type_json_nested
        test_event = {
            'date': 'January 01, 3005',
            'unixtime': '32661446400',
            'host': 'my-host.name.website.com',
            'data': {
                'super-duper': 'secret',
                'sanitize_me': 1,
                'example-key': 1,
                'moar**data': 2,
                'even.more': 3
            }
        }

        expected_sanitized_event = {
            'date': 'January 01, 3005',
            'unixtime': '32661446400',
            'host': 'my-host.name.website.com',
            'data': {
                'super_duper': 'secret',
                'sanitize_me': 1,
                'example_key': 1,
                'moar__data': 2,
                'even_more': 3
            }
        }

        sanitized_event = self.__sa_handler.sanitize_keys(test_event)
        assert_equal(sanitized_event, expected_sanitized_event)

    def test_firehose_segment_records_by_size(self):
        """StreamAlert Class - Firehose - Segment Large Records"""

        record_batch = [
            # unit_test_simple_log
            {
                'unit_key_01': 2,
                'unit_key_02': 'testtest' * 10000
            } for _ in range(100)
        ]

        sized_batches = []

        for sized_batch in self.__sa_handler._segment_records_by_size(
                record_batch):
            sized_batches.append(sized_batch)

        assert_true(len(str(sized_batches[0])) < 4000000)
        assert_equal(len(sized_batches), 4)
        assert_true(isinstance(sized_batches[3][0], dict))

    @mock_kinesis
    def test_firehose_record_delivery_disabled_logs(self):
        """StreamAlert Class - Firehose Record Delivery - Disabled Logs"""
        self.__sa_handler.firehose_client = boto3.client(
            'firehose', region_name='us-east-1')

        test_event = convert_events_to_kinesis([
            # unit_test_simple_log
            {
                'unit_key_01': 2,
                'unit_key_02': 'testtest'
            } for _ in range(10)
        ])

        delivery_stream_names = ['streamalert_data_unit_test_simple_log']

        # Setup mock delivery streams
        for delivery_stream in delivery_stream_names:
            self.__sa_handler.firehose_client.create_delivery_stream(
                DeliveryStreamName=delivery_stream,
                S3DestinationConfiguration={
                    'RoleARN':
                    'arn:aws:iam::123456789012:role/firehose_delivery_role',
                    'BucketARN': 'arn:aws:s3:::kinesis-test',
                    'Prefix': '{}/'.format(delivery_stream),
                    'BufferingHints': {
                        'SizeInMBs': 123,
                        'IntervalInSeconds': 124
                    },
                    'CompressionFormat': 'Snappy',
                })

        with patch.object(self.__sa_handler.firehose_client,
                          'put_record_batch') as firehose_mock:
            firehose_mock.return_value = {'FailedPutCount': 0}

            self.__sa_handler.config['global']['infrastructure'][
                'firehose'] = {
                    'disabled_logs': ['unit_test_simple_log']
                }
            self.__sa_handler.run(test_event)

            firehose_mock.assert_not_called()

    @patch('stream_alert.rule_processor.handler.LOGGER')
    @mock_kinesis
    def test_firehose_record_delivery_client_error(self, mock_logging):
        """StreamAlert Class - Firehose Record Delivery - Client Error"""
        self.__sa_handler.firehose_client = boto3.client(
            'firehose', region_name='us-east-1')

        test_events = [
            # unit_test_simple_log
            {
                'unit_key_01': 2,
                'unit_key_02': 'testtest'
            } for _ in range(10)
        ]

        self.__sa_handler._firehose_request_helper('invalid_stream',
                                                   test_events)

        missing_stream_message = 'Client Error ... An error occurred ' \
        '(ResourceNotFoundException) when calling the PutRecordBatch ' \
        'operation: Stream invalid_stream under account 123456789012 not found.'
        mock_logging.error.assert_called_with(missing_stream_message)
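Two pieces of the Firehose path exercised by this class are referenced but not defined in the snippet: the size-based batching behind `_segment_records_by_size` and the delivery call behind `_firehose_request_helper`. Hedged sketches of both follow; the real methods' thresholds, retry behavior, and logging may differ (the segmentation test above expects four batches, so the actual size accounting is evidently stricter than this sketch):

import json

def segment_records_by_size_sketch(records, max_batch_size=4000000):
    """Yield lists of records whose combined serialized size stays under the cap."""
    batch, batch_size = [], 0
    for record in records:
        record_size = len(str(record))
        if batch and batch_size + record_size > max_batch_size:
            yield batch
            batch, batch_size = [], 0
        batch.append(record)
        batch_size += record_size
    if batch:
        yield batch

def put_records_sketch(firehose_client, stream_name, records):
    """Send records to a Firehose delivery stream and report partial failures."""
    response = firehose_client.put_record_batch(
        DeliveryStreamName=stream_name,
        Records=[{'Data': json.dumps(record) + '\n'} for record in records])

    if response['FailedPutCount'] > 0:
        for result in response['RequestResponses']:
            if 'ErrorCode' in result:
                print('Firehose error {}: {}'.format(
                    result['ErrorCode'], result.get('ErrorMessage')))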
Code example #11
def athena_handler(options):
    """Handle Athena operations"""
    athena_client = StreamAlertAthenaClient(
        CONFIG, results_key_prefix='stream_alert_cli')

    if options.subcommand == 'init':
        CONFIG.generate_athena()

    elif options.subcommand == 'enable':
        CONFIG.set_athena_lambda_enable()

    elif options.subcommand == 'create-db':
        if athena_client.check_database_exists():
            LOGGER_CLI.info(
                'The \'streamalert\' database already exists, nothing to do')
            return

        create_db_success, create_db_result = athena_client.run_athena_query(
            query='CREATE DATABASE streamalert')

        if create_db_success and create_db_result['ResultSet'].get('Rows'):
            LOGGER_CLI.info('streamalert database successfully created!')
            LOGGER_CLI.info('results: %s',
                            create_db_result['ResultSet']['Rows'])

    elif options.subcommand == 'create-table':
        if not options.bucket:
            LOGGER_CLI.error('Missing command line argument --bucket')
            return

        if not options.refresh_type:
            LOGGER_CLI.error('Missing command line argument --refresh_type')
            return

        if options.type == 'data':
            if not options.table_name:
                LOGGER_CLI.error('Missing command line argument --table_name')
                return

            if options.table_name not in enabled_firehose_logs(CONFIG):
                LOGGER_CLI.error(
                    'Table name %s missing from configuration or '
                    'is not enabled.', options.table_name)
                return

            if athena_client.check_table_exists(options.table_name):
                LOGGER_CLI.info('The \'%s\' table already exists.',
                                options.table_name)
                return

            log_info = CONFIG['logs'][options.table_name.replace('_', ':', 1)]
            schema = dict(log_info['schema'])
            schema_statement = ''

            sanitized_schema = StreamAlert.sanitize_keys(schema)

            athena_schema = {}
            schema_type_mapping = {
                'string': 'string',
                'integer': 'int',
                'boolean': 'boolean',
                'float': 'decimal',
                dict: 'map<string, string>',
                list: 'array<string>'
            }

            def add_to_athena_schema(schema, root_key=''):
                """Helper function to add sanitized schemas to the Athena table schema"""
                # Setup the root_key dict
                if root_key and not athena_schema.get(root_key):
                    athena_schema[root_key] = {}

                for key_name, key_type in schema.iteritems():
                    # When using special characters in the beginning or end
                    # of a column name, they have to be wrapped in backticks
                    key_name = '`{}`'.format(key_name)

                    special_key = None
                    # Transform the {} or [] into hashable types
                    if key_type == {}:
                        special_key = dict
                    elif key_type == []:
                        special_key = list
                    # Cast nested dict as a string for now
                    # TODO(jacknagz): support recursive schemas
                    elif isinstance(key_type, dict):
                        special_key = 'string'

                    # Account for envelope keys
                    if root_key:
                        if special_key is not None:
                            athena_schema[root_key][
                                key_name] = schema_type_mapping[special_key]
                        else:
                            athena_schema[root_key][
                                key_name] = schema_type_mapping[key_type]
                    else:
                        if special_key is not None:
                            athena_schema[key_name] = schema_type_mapping[
                                special_key]
                        else:
                            athena_schema[key_name] = schema_type_mapping[
                                key_type]

            add_to_athena_schema(sanitized_schema)

            # Support envelope keys
            configuration_options = log_info.get('configuration')
            if configuration_options:
                envelope_keys = configuration_options.get('envelope_keys')
                if envelope_keys:
                    sanitized_envelope_keys = StreamAlert.sanitize_keys(
                        envelope_keys)
                    # Note: this key is wrapped in backticks to be Hive compliant
                    add_to_athena_schema(sanitized_envelope_keys,
                                         '`streamalert:envelope_keys`')

            for key_name, key_type in athena_schema.iteritems():
                # Account for nested structs
                if isinstance(key_type, dict):
                    struct_schema = ''.join([
                        '{0}:{1},'.format(sub_key, sub_type)
                        for sub_key, sub_type in key_type.iteritems()
                    ])
                    nested_schema_statement = '{0} struct<{1}>, '.format(
                        key_name,
                        # Use the minus index to remove the last comma
                        struct_schema[:-1])
                    schema_statement += nested_schema_statement
                else:
                    schema_statement += '{0} {1},'.format(key_name, key_type)

            query = (
                'CREATE EXTERNAL TABLE {table_name} ({schema}) '
                'PARTITIONED BY (dt string) '
                'ROW FORMAT SERDE \'org.openx.data.jsonserde.JsonSerDe\' '
                'LOCATION \'s3://{bucket}/{table_name}/\''.format(
                    table_name=options.table_name,
                    # Use the minus index to remove the last comma
                    schema=schema_statement[:-1],
                    bucket=options.bucket))

        elif options.type == 'alerts':
            if athena_client.check_table_exists(options.type):
                LOGGER_CLI.info('The \'alerts\' table already exists.')
                return

            query = ('CREATE EXTERNAL TABLE alerts ('
                     'log_source string,'
                     'log_type string,'
                     'outputs array<string>,'
                     'record string,'
                     'rule_description string,'
                     'rule_name string,'
                     'source_entity string,'
                     'source_service string)'
                     'PARTITIONED BY (dt string)'
                     'ROW FORMAT SERDE \'org.openx.data.jsonserde.JsonSerDe\''
                     'LOCATION \'s3://{bucket}/alerts/\''.format(
                         bucket=options.bucket))

        if query:
            create_table_success, _ = athena_client.run_athena_query(
                query=query, database='streamalert')

            if create_table_success:
                CONFIG['lambda']['athena_partition_refresh_config'] \
                      ['refresh_type'][options.refresh_type][options.bucket] = options.type
                CONFIG.write()
                table_name = options.type if options.type == 'alerts' else options.table_name
                LOGGER_CLI.info('The %s table was successfully created!',
                                table_name)
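To make the schema translation above concrete, here is roughly what it would produce for a small sanitized schema (names and ordering illustrative only, since dict iteration order is arbitrary in this Python 2 code):

# Input (sanitized) schema from the logs config
schema = {'name': 'string', 'count': 'integer', 'tags': []}

# After type mapping, athena_schema would hold:
#   {'`name`': 'string', '`count`': 'int', '`tags`': 'array<string>'}
#
# and the generated query (trailing comma stripped) would read:
#   CREATE EXTERNAL TABLE my_table (`name` string,`count` int,`tags` array<string>)
#   PARTITIONED BY (dt string)
#   ROW FORMAT SERDE 'org.openx.data.jsonserde.JsonSerDe'
#   LOCATION 's3://my-bucket/my_table/'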
Code example #12
File: runner.py Project: ryandeivert/streamalert
def athena_handler(options):
    """Handle Athena operations"""
    athena_client = StreamAlertAthenaClient(
        CONFIG, results_key_prefix='stream_alert_cli')

    if options.subcommand == 'init':
        CONFIG.generate_athena()

    elif options.subcommand == 'enable':
        CONFIG.set_athena_lambda_enable()

    elif options.subcommand == 'create-db':
        if athena_client.check_database_exists():
            LOGGER_CLI.info(
                'The \'streamalert\' database already exists, nothing to do')
            return

        create_db_success, create_db_result = athena_client.run_athena_query(
            query='CREATE DATABASE streamalert')

        if create_db_success and create_db_result['ResultSet'].get('Rows'):
            LOGGER_CLI.info('streamalert database successfully created!')
            LOGGER_CLI.info('results: %s',
                            create_db_result['ResultSet']['Rows'])

    elif options.subcommand == 'create-table':
        if not options.bucket:
            LOGGER_CLI.error('Missing command line argument --bucket')
            return

        if not options.refresh_type:
            LOGGER_CLI.error('Missing command line argument --refresh_type')
            return

        if options.type == 'data':
            if not options.table_name:
                LOGGER_CLI.error('Missing command line argument --table_name')
                return

            if options.table_name not in enabled_firehose_logs(CONFIG):
                LOGGER_CLI.error(
                    'Table name %s missing from configuration or '
                    'is not enabled.', options.table_name)
                return

            if athena_client.check_table_exists(options.table_name):
                LOGGER_CLI.info('The \'%s\' table already exists.',
                                options.table_name)
                return

            schema = CONFIG['logs'][options.table_name.replace('_',
                                                               ':')]['schema']
            sanitized_schema = StreamAlert.sanitize_keys(schema)

            athena_schema = {}
            schema_type_mapping = {
                'string': 'string',
                'integer': 'int',
                'boolean': 'boolean',
                'float': 'decimal',
                dict: 'map<string, string>',
                list: 'array<string>'
            }

            for key_name, key_type in sanitized_schema.iteritems():
                # Transform the {} or [] into hashable types
                if key_type == {}:
                    key_type = dict
                elif key_type == []:
                    key_type = list

                athena_schema[key_name] = schema_type_mapping[key_type]

            schema_statement = ''.join([
                '{0} {1},'.format(key_name, key_type)
                for key_name, key_type in athena_schema.iteritems()
            ])[:-1]
            query = ('CREATE EXTERNAL TABLE {table_name} ({schema})'
                     'PARTITIONED BY (dt string)'
                     'ROW FORMAT SERDE \'org.openx.data.jsonserde.JsonSerDe\''
                     'LOCATION \'s3://{bucket}/{table_name}/\''.format(
                         table_name=options.table_name,
                         schema=schema_statement,
                         bucket=options.bucket))

        elif options.type == 'alerts':
            if athena_client.check_table_exists(options.type):
                LOGGER_CLI.info('The \'alerts\' table already exists.')
                return

            query = ('CREATE EXTERNAL TABLE alerts ('
                     'log_source string,'
                     'log_type string,'
                     'outputs array<string>,'
                     'record string,'
                     'rule_description string,'
                     'rule_name string,'
                     'source_entity string,'
                     'source_service string)'
                     'PARTITIONED BY (dt string)'
                     'ROW FORMAT SERDE \'org.openx.data.jsonserde.JsonSerDe\''
                     'LOCATION \'s3://{bucket}/alerts/\''.format(
                         bucket=options.bucket))

        if query:
            create_table_success, _ = athena_client.run_athena_query(
                query=query, database='streamalert')

            if create_table_success:
                CONFIG['lambda']['athena_partition_refresh_config'] \
                      ['refresh_type'][options.refresh_type][options.bucket] = options.type
                CONFIG.write()
                LOGGER_CLI.info('The %s table was successfully created!',
                                options.type)
Code example #13
def create_table(athena_client, options, config):
    """Create a 'streamalert' Athena table

    Args:
        athena_client (boto3.client): Instantiated CLI AthenaClient
        options (namedtuple): The parsed args passed from the CLI
        config (CLIConfig): Loaded StreamAlert CLI configuration
    """
    if not options.bucket:
        LOGGER_CLI.error('Missing command line argument --bucket')
        return

    if not options.refresh_type:
        LOGGER_CLI.error('Missing command line argument --refresh_type')
        return

    if options.type == 'data':
        if not options.table_name:
            LOGGER_CLI.error('Missing command line argument --table_name')
            return

        if options.table_name not in terraform_cli_helpers.enabled_firehose_logs(
                config):
            LOGGER_CLI.error(
                'Table name %s missing from configuration or '
                'is not enabled.', options.table_name)
            return

        if athena_client.check_table_exists(options.table_name):
            LOGGER_CLI.info('The \'%s\' table already exists.',
                            options.table_name)
            return

        log_info = config['logs'][options.table_name.replace('_', ':', 1)]
        schema = dict(log_info['schema'])
        schema_statement = ''

        sanitized_schema = StreamAlert.sanitize_keys(schema)
        athena_schema = {}

        _add_to_athena_schema(sanitized_schema, athena_schema)

        # Support envelope keys
        configuration_options = log_info.get('configuration')
        if configuration_options:
            envelope_keys = configuration_options.get('envelope_keys')
            if envelope_keys:
                sanitized_envelope_key_schema = StreamAlert.sanitize_keys(
                    envelope_keys)
                # Note: this key is wrapped in backticks to be Hive compliant
                _add_to_athena_schema(sanitized_envelope_key_schema,
                                      athena_schema,
                                      '`streamalert:envelope_keys`')

        for key_name, key_type in athena_schema.iteritems():
            # Account for nested structs
            if isinstance(key_type, dict):
                struct_schema = ''.join([
                    '{0}:{1},'.format(sub_key, sub_type)
                    for sub_key, sub_type in key_type.iteritems()
                ])
                nested_schema_statement = '{0} struct<{1}>, '.format(
                    key_name,
                    # Use the minus index to remove the last comma
                    struct_schema[:-1])
                schema_statement += nested_schema_statement
            else:
                schema_statement += '{0} {1},'.format(key_name, key_type)

        query = (
            'CREATE EXTERNAL TABLE {table_name} ({schema}) '
            'PARTITIONED BY (dt string) '
            'ROW FORMAT SERDE \'org.openx.data.jsonserde.JsonSerDe\' '
            'WITH SERDEPROPERTIES ( \'ignore.malformed.json\' = \'true\') '
            'LOCATION \'s3://{bucket}/{table_name}/\''.format(
                table_name=options.table_name,
                # Use the minus index to remove the last comma
                schema=schema_statement[:-1],
                bucket=options.bucket))

    elif options.type == 'alerts':
        if athena_client.check_table_exists(options.type):
            LOGGER_CLI.info('The \'alerts\' table already exists.')
            return

        query = ('CREATE EXTERNAL TABLE alerts ('
                 'log_source string,'
                 'log_type string,'
                 'outputs array<string>,'
                 'record string,'
                 'rule_description string,'
                 'rule_name string,'
                 'source_entity string,'
                 'source_service string)'
                 'PARTITIONED BY (dt string)'
                 'ROW FORMAT SERDE \'org.openx.data.jsonserde.JsonSerDe\''
                 'LOCATION \'s3://{bucket}/alerts/\''.format(
                     bucket=options.bucket))

    if query:
        create_table_success, _ = athena_client.run_athena_query(
            query=query, database='streamalert')

        if create_table_success:
            # Update the CLI config
            config['lambda']['athena_partition_refresh_config'] \
                  ['refresh_type'][options.refresh_type][options.bucket] = options.type
            config.write()

            table_name = options.type if options.type == 'alerts' else options.table_name
            LOGGER_CLI.info('The %s table was successfully created!',
                            table_name)
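`StreamAlert.sanitize_keys`, used both here and in the Firehose tests earlier, is not shown in any of these snippets. A sketch that reproduces the behavior the sanitize test in example #10 expects; the actual method may differ:

import re

def sanitize_keys_sketch(record):
    """Replace disallowed characters in keys, recursing into nested dicts."""
    sanitized = {}
    for key, value in record.items():
        # No '+' in the pattern: each bad character becomes its own
        # underscore, so 'moar**data' becomes 'moar__data'
        new_key = re.sub(r'[^a-zA-Z0-9_]', '_', key)
        if isinstance(value, dict):
            value = sanitize_keys_sketch(value)
        sanitized[new_key] = value
    return sanitized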
Code example #14
def test_run_config_error():
    """StreamAlert Class - Run, Config Error"""
    mock = mock_open(
        read_data='non-json string that will raise an exception')
    with patch('__builtin__.open', mock):
        StreamAlert(get_mock_context())
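`mock_open` comes from the mock library; patching `__builtin__.open` (this is Python 2 — on Python 3 the target would be `builtins.open`) makes any open() inside StreamAlert's config loading return the bad string, so the constructor should raise while parsing it. The same pattern in isolation, assuming the config is read via open() and parsed as JSON (the file path here is hypothetical):

import json
from mock import mock_open, patch

with patch('__builtin__.open', mock_open(read_data='not json')):
    try:
        json.load(open('conf/example.json'))
    except ValueError:
        print('config parsing failed as expected')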
Code example #15
    def setup(self):
        """Setup before each method"""
        self.__sa_handler = StreamAlert(get_mock_context(), False)
Code example #16
File: test.py Project: etsangsplk/streamalert
class RuleProcessorTester(object):
    """Class to encapsulate testing the rule processor"""
    def __init__(self, context, config, print_output):
        """RuleProcessorTester initializer

        Args:
            context: A mocked Lambda context object
            config (CLIConfig): Loaded StreamAlert CLI configuration
            print_output (bool): Whether this processor test
                should print results to stdout. This is set to False when the
                alert processor is explicitly being tested alone, and set to
                True for rule processor tests and end-to-end tests.
                Warnings and errors captured during rule processor testing
                will still be written to stdout regardless of this setting.
        """
        # Create the RuleProcessor, passing a mocked context object
        # with fake values
        self.processor = StreamAlert(context)
        self.cli_config = config
        # Use a list of status_messages to store pass/fail/warning info
        self.status_messages = []
        self.total_tests = 0
        self.all_tests_passed = True
        self.print_output = print_output
        # Configure mocks for Firehose and DDB
        helpers.setup_mock_firehose_delivery_streams(config)
        helpers.setup_mock_dynamodb_ioc_table(config)
        # Create a cache map of parsers to parser classes
        self.parsers = {}

        # Patch the tmp shredding so as not to slow down testing
        patch(
            'stream_alert.rule_processor.payload.S3Payload._shred_temp_directory'
        ).start()

        # Patch random_bool to always return true
        patch('helpers.base.random_bool', return_value=True).start()

    def test_processor(self, rules_filter, files_filter, validate_only):
        """Perform integration tests for the 'rule' Lambda function

        Args:
            rules_filter (set): A collection of rules to filter on, passed in by the user
                via the CLI using the --test-rules option.
            files_filter (set): A collection of files to filter on, passed in by the user
                via the CLI using the --test-files option.
            validate_only (bool): If true, validation of test records will occur
                without the rules engine being applied to events.

        Yields:
            tuple (bool, list) or None: If testing rules, this yields a tuple containing a
                boolean of test status and a list of alerts to run through the alert
                processor. If validating test records only, this does not yield.
        """
        test_file_info = self._filter_files(
            helpers.get_rule_test_files(TEST_EVENTS_DIR), files_filter)

        for name in sorted(test_file_info):
            path = test_file_info[name]

            events, error = helpers.load_test_file(path)
            if error is not None:
                self.all_tests_passed = False
                self.status_messages.append(
                    StatusMessage(StatusMessage.WARNING, error))
                continue

            print_header = True
            for test_event in events:
                self.total_tests += 1
                if self._detect_old_test_event(test_event):
                    self.all_tests_passed = False
                    message = (
                        'Detected old format for test event in file \'{}.json\'. '
                        'Please visit https://streamalert.io/rule-testing.html '
                        'for information on the new format and update your '
                        'test events accordingly.'.format(name))
                    self.status_messages.append(
                        StatusMessage(StatusMessage.FAILURE, message))
                    continue

                if not self.check_keys(test_event):
                    self.all_tests_passed = False
                    continue

                # Check if there are any rule filters in place, and if the current test event
                # should be executed per the filter
                if rules_filter and set(
                        test_event['trigger_rules']).isdisjoint(rules_filter):
                    self.total_tests -= 1
                    continue

                self.apply_helpers(test_event)

                if 'override_record' in test_event:
                    self.apply_template(test_event)

                formatted_record = helpers.format_lambda_test_record(
                    test_event)

                # If this test is to validate the schema only, continue the loop and
                # do not yield results on the rule tests below
                if validate_only or test_event.get('validate_schema_only'):
                    if self._validate_test_record(name, test_event,
                                                  formatted_record,
                                                  print_header) is False:
                        self.all_tests_passed = False
                else:
                    yield self._run_rule_tests(name, test_event,
                                               formatted_record, print_header)

                print_header = False

        # Report on the final test results
        self.report_output_summary()

    def _filter_files(self, file_info, files_filter):
        """Filter the test files based in input from the user

        Args:
            file_info (dict): Information about test files on disk, where the key is the
                base name of the file and the value is the relative path to the file
            files_filter (set): A collection of files to filter tests on

        Returns:
            dict: A modified version of the `file_info` arg with pared down values
        """
        if not files_filter:
            return file_info

        files_filter = {os.path.splitext(name)[0] for name in files_filter}

        file_info = {
            name: path
            for name, path in file_info.iteritems()
            if os.path.splitext(name)[0] in files_filter
        }

        filter_diff = set(files_filter).difference(set(file_info))
        message_template = 'No test events file found with base name \'{}\''
        for missing_file in filter_diff:
            self.status_messages.append(
                StatusMessage(StatusMessage.WARNING,
                              message_template.format(missing_file)))

        return file_info

    def _validate_test_record(self, file_name, test_event, formatted_record,
                              print_header_line):
        """Function to validate test records and log any errors

        Args:
            file_name (str): The base name of the test event file.
            test_event (dict): A single test event containing the record and other detail
            formatted_record (dict): A dictionary that includes the 'data' from the
                test record, formatted into a structure that resembles how
                an incoming record from a service would be formatted.
                See test/integration/templates for examples of how each service
                formats records.
            print_header_line (bool): Indicates if this is the first record from
                a test file, and therefore we should print some header information
        """
        service, entity = self.processor.classifier.extract_service_and_entity(
            formatted_record)

        if not self.processor.classifier.load_sources(service, entity):
            return False

        # Create the StreamPayload to use for encapsulating parsed info
        payload = load_stream_payload(service, entity, formatted_record)
        if not payload:
            return False

        if print_header_line:
            print '\n{}'.format(file_name)

        for record in payload.pre_parse():
            self.processor.classifier.classify_record(record)

            if not record.valid:
                self.all_tests_passed = False
                self.analyze_record_delta(file_name, test_event)

            report_output(record.valid, [
                '[log=\'{}\']'.format(record.log_source or 'unknown'),
                'validation',
                record.service(), test_event['description']
            ])

    def _run_rule_tests(self, file_name, test_event, formatted_record,
                        print_header_line):
        """Run tests on a test record for a given rule

        Args:
            file_name (str): The base name of the test event file.
            test_event (dict): The loaded test event from json
            formatted_record (dict): A dictionary that includes the 'data' from the
                test record, formatted into a structure that resembles how
                an incoming record from a service would be formatted.
                See test/integration/templates for examples of how each service
                formats records.
            print_header_line (bool): Indicates if this is the first record from
                a test file, and therefore we should print some header information

        Returns:
            list: alerts that were generated from this test event
        """
        event = {'Records': [formatted_record]}

        expected_alert_count = len(test_event['trigger_rules'])

        # Run tests on the formatted record
        alerts, all_records_matched_schema = self.test_rule(event)

        # Get a list of any rules that triggered but are not defined in the 'trigger_rules'
        unexpected_alerts = []

        disabled_rules = [
            item for item in test_event['trigger_rules']
            if rule.Rule.get_rule(item).disabled
        ]

        expected_alert_count -= len(disabled_rules)

        triggers = set(test_event['trigger_rules']) - set(disabled_rules)
        # we only want alerts for the specific rule being tested (if trigger_rules are defined)
        if triggers:
            unexpected_alerts = [
                alert for alert in alerts if alert.rule_name not in triggers
            ]

            alerts = [alert for alert in alerts if alert.rule_name in triggers]

        alerted_properly = (len(alerts)
                            == expected_alert_count) and not unexpected_alerts
        current_test_passed = alerted_properly and all_records_matched_schema

        self.all_tests_passed = current_test_passed and self.all_tests_passed

        # Print rule name for section header, but only if we get
        # to a point where there is a record to actually be tested.
        # This avoids potentially blank sections
        if print_header_line and (alerts or self.print_output):
            print '\n{}'.format(file_name)

        if self.print_output:
            disabled_output = ''
            if disabled_rules:
                disabled_output = ',disabled={}'.format(len(disabled_rules))
            report_output(current_test_passed, [
                '[trigger={}{}]'.format(expected_alert_count, disabled_output),
                'rule', test_event['service'], test_event['description']
            ])

        # Add the status of the rule to messages list
        if not all_records_matched_schema:
            self.analyze_record_delta(file_name, test_event)
        elif not alerted_properly:
            message = ('Test failure: [{}.json] Test event with description '
                       '\'{}\'').format(file_name, test_event['description'])
            if alerts and not triggers:
                # If there was a failure due to alerts triggering for a test event
                # that does not have any trigger_rules configured
                context = 'is triggering the following rules but should not trigger at all: {}'
                trigger_rules = ', '.join('\'{}\''.format(alert.rule_name)
                                          for alert in alerts)
                message = '{} {}'.format(message,
                                         context.format(trigger_rules))
            elif unexpected_alerts:
                # If there was a failure due to alerts triggering for other rules outside
                # of the rules defined in the trigger_rules list for the event
                context = 'is triggering the following rules but should not be: {}'
                bad_rules = ', '.join('\'{}\''.format(alert.rule_name)
                                      for alert in unexpected_alerts)
                message = '{} {}'.format(message, context.format(bad_rules))
            elif expected_alert_count != len(alerts):
                # If there was a failure due to alerts NOT triggering for 1+ rules
                # defined in the trigger_rules list for the event
                context = 'did not trigger the following rules: {}'
                non_triggered_rules = ', '.join(
                    '\'{}\''.format(rule) for rule in triggers
                    if rule not in [alert.rule_name for alert in alerts])
                message = '{} {}'.format(message,
                                         context.format(non_triggered_rules))
            else:
                # If there was a failure for some other reason, just use a default message
                message = 'Rule failure: [{}.json] {}'.format(
                    file_name, test_event['description'])
            self.status_messages.append(
                StatusMessage(StatusMessage.FAILURE, message))

        # Return the alerts back to caller
        return alerts

    @staticmethod
    def _detect_old_test_event(test_event):
        """Check if the test event contains the old format used

        Args:
            test_event (dict): The loaded test event from json

        Returns:
            bool: True if a legacy test file is detected, False otherwise
        """
        record_keys = set(test_event)
        return ('trigger' in record_keys
                and not {'log', 'trigger_rules'}.issubset(record_keys))
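
    # Illustration only (hypothetical events): the legacy format used a boolean
    # 'trigger' key, e.g. {"trigger": true, "data": {...}}, while the current
    # format names the log and the expected rules, e.g.
    # {"log": "some_log_type", "trigger_rules": ["some_rule"], "data": {...}}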

    def check_keys(self, test_event):
        """Check if the test event contains the required keys

        Args:
            test_event (dict): The loaded test event from json

        Returns:
            bool: True if the proper keys are present
        """
        required_keys = {
            'description', 'log', 'service', 'source', 'trigger_rules'
        }

        record_keys = set(test_event)
        if not required_keys.issubset(record_keys):
            req_key_diff = required_keys.difference(record_keys)
            missing_keys = ', '.join('\'{}\''.format(key)
                                     for key in req_key_diff)
            message = 'Missing required key(s) in log: {}'.format(missing_keys)
            self.status_messages.append(
                StatusMessage(StatusMessage.FAILURE, message))
            return False

        input_data_keys = {'data', 'override_record'}
        if not record_keys & input_data_keys:
            missing_keys = ', '.join('\'{}\''.format(key)
                                     for key in input_data_keys)
            message = 'Missing one of the following keys in log: {}'.format(
                missing_keys)
            self.status_messages.append(
                StatusMessage(StatusMessage.FAILURE, message))
            return False

        optional_keys = {'compress', 'validate_schema_only'}

        key_diff = record_keys.difference(required_keys | optional_keys
                                          | input_data_keys)

        # Log a warning if there are extra keys declared in the test log
        if key_diff:
            extra_keys = ', '.join('\'{}\''.format(key) for key in key_diff)
            message = 'Additional unnecessary keys in log: {}'.format(
                extra_keys)
            # Remove the key(s) and just warn the user that they are extra
            record_keys.difference_update(key_diff)
            self.status_messages.append(
                StatusMessage(StatusMessage.WARNING, message))

        return record_keys.issubset(required_keys | optional_keys
                                    | input_data_keys)
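
    # Illustration only (hypothetical values): a minimal event that passes
    # check_keys() -- every required key plus one of 'data'/'override_record':
    #     {
    #         "description": "CloudTrail console login",
    #         "log": "cloudwatch:events",
    #         "service": "kinesis",
    #         "source": "test_kinesis_stream",
    #         "trigger_rules": ["console_login_rule"],
    #         "override_record": {"detail": {"eventName": "ConsoleLogin"}}
    #     }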

    def apply_template(self, test_event):
        """Apply default values to the given test event

        Args:
            test_event (dict): The loaded test event
        """
        event_log = self.cli_config['logs'].get(test_event['log'])

        parser = event_log['parser']
        schema = event_log['schema']
        configuration = event_log.get('configuration', {})

        # Add envelope keys
        schema.update(configuration.get('envelope_keys', {}))

        # Setup the parser to access default optional values
        self.parsers[parser] = self.parsers.get(parser, get_parser(parser))

        # Apply default values based on the declared schema
        default_test_event = {
            key: self.parsers[parser].default_optional_values(value)
            for key, value in schema.iteritems()
        }

        # Fill in the fields left out in the 'override_record' field,
        # and update the test event with a full 'data' key
        default_test_event.update(test_event['override_record'])
        test_event['data'] = default_test_event
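
        # Illustration only (hypothetical schema): with a schema of
        # {'name': 'string', 'count': 'integer'} and
        # override_record == {'name': 'example'}, the merge above back-fills
        # 'count' with the parser's default optional value and stores the full
        # record under test_event['data'] (the exact defaults depend on each
        # parser's default_optional_values implementation)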

    @staticmethod
    def apply_helpers(test_record):
        """Detect and apply helper functions to test event data

        Helpers are declared in test fixtures via the following keyword:
        "<helpers:helper_name>"

        Supported helper functions:
            last_hour: return the current epoch time minus 60 seconds to pass the
                       last_hour rule helper.

        Args:
            test_record (dict): loaded fixture file JSON as a dict.
        """
        # declare all helper functions here, they should always return a string
        record_helpers = {'last_hour': lambda: str(int(time.time()) - 60)}
        helper_regex = re.compile(r'<helper:(?P<helper>\w+)>')

        def find_and_apply_helpers(test_record):
            """Apply any helpers to the passed in test_record"""
            for key, value in test_record.iteritems():
                if isinstance(value, (str, unicode)):
                    test_record[key] = re.sub(
                        helper_regex,
                        lambda match: record_helpers[match.group('helper')](),
                        test_record[key])
                elif isinstance(value, dict):
                    find_and_apply_helpers(test_record[key])

        find_and_apply_helpers(test_record)

    def report_output_summary(self):
        """Helper function to print the summary results of all tests"""
        failure_messages = [
            item for item in self.status_messages
            if item.type == StatusMessage.FAILURE
        ]
        warning_messages = [
            item for item in self.status_messages
            if item.type == StatusMessage.WARNING
        ]
        passed_tests = self.total_tests - len(failure_messages)
        # Print some lines at the bottom of output to make it more readable
        # This occurs here so there is always space and not only when the
        # successful test info prints
        print '\n\n'

        # Only print success info if we explicitly want to print output
        # but always print any errors or warnings below
        if self.print_output:
            # Print a message indicating how many of the total tests passed
            LOGGER_CLI.info('%s(%d/%d) Successful Tests%s', COLOR_GREEN,
                            passed_tests, self.total_tests, COLOR_RESET)

        # Check if there were failed tests and report on them appropriately
        if failure_messages:
            # Print a message indicating how many of the total tests failed
            LOGGER_CLI.error('%s(%d/%d) Failures%s', COLOR_RED,
                             len(failure_messages), self.total_tests,
                             COLOR_RESET)

            # Iterate over the rule_name values in the failed list and report on them
            for index, failure in enumerate(failure_messages, start=1):
                LOGGER_CLI.error('%s(%d/%d) %s%s', COLOR_RED, index,
                                 len(failure_messages), failure.message,
                                 COLOR_RESET)

        # Check if there were any warnings and report on them
        if warning_messages:
            warning_count = len(warning_messages)
            LOGGER_CLI.warn('%s%d Warning%s%s', COLOR_YELLOW, warning_count,
                            ('s' if warning_count > 1 else ''), COLOR_RESET)

            for index, warning in enumerate(warning_messages, start=1):
                LOGGER_CLI.warn('%s(%d/%d) %s%s', COLOR_YELLOW, index,
                                warning_count, warning.message, COLOR_RESET)

    def test_rule(self, record):
        """Feed formatted records into StreamAlert and check for alerts

        Args:
            record (dict): A formatted event that reflects the structure expected
                as input to the Lambda function.

        Returns:
            list: alerts that hit for this rule
            bool: False if errors occurred during processing
        """
        # Clear out any old alerts or errors from the previous test run
        # pylint: disable=protected-access
        del self.processor._alerts[:]
        self.processor._failed_record_count = 0

        # Run the rule processor
        all_records_matched_schema = self.processor.run(record)

        return self.processor.alerts, all_records_matched_schema

    def check_log_declared_in_sources(self, base_message, test_event):
        """A simple check to see if this log type is defined in the sources for the service

        Args:
            base_message (str): Base error message to be reported with extra context
            test_event (dict): Actual record data being tested

        Returns:
            bool: False if the log type is not in the sources list, True if it is
        """
        source = test_event['source']
        service = test_event['service']
        log = test_event['log'].split(':')[0]
        if log not in self.cli_config['sources'][service][source]['logs']:
            message = (
                'The \'sources.json\' file does not include the log type \'{}\' '
                'in the list of logs for this service & entity (\'{}:{}\').')
            message = '{} {}'.format(base_message,
                                     message.format(log, service, source))
            self.status_messages.append(
                StatusMessage(StatusMessage.FAILURE, message))
            return False

        return True

    def analyze_record_delta(self, file_name, test_event):
        """Provide some additional context on why this test failed. This will
        perform some analysis of the test record to determine which keys are
        missing or which unnecessary keys are causing the test to fail. Any
        errors are appended to a list of errors so they can be printed at
        the end of the test run.

        Args:
            file_name (str): Name of file containing the test event
            test_event (dict): Actual record data being tested
        """
        base_message = (
            'Invalid test event in file \'{}.json\' with description '
            '\'{}\'.'.format(file_name, test_event['description']))

        if not self.check_log_declared_in_sources(base_message, test_event):
            return

        log_type = test_event['log']
        if log_type not in self.cli_config['logs']:
            message = (
                '{} Log (\'{}\') declared in test event does not exist in '
                'logs.json'.format(base_message, log_type))

            self.status_messages.append(
                StatusMessage(StatusMessage.FAILURE, message))
            return

        config_log_info = self.cli_config['logs'][log_type]
        schema_keys = config_log_info['schema']

        envelope_keys = config_log_info.get('configuration',
                                            {}).get('envelope_keys')
        if envelope_keys:
            if self.report_envelope_key_error(base_message, envelope_keys,
                                              test_event['data']):
                return

        # Check if a json_path is used for nested records
        json_path = config_log_info.get('configuration', {}).get('json_path')
        if json_path:
            records_jsonpath = jsonpath_rw.parse(json_path)
            for match in records_jsonpath.find(test_event['data']):
                self.report_record_delta(base_message, log_type, schema_keys,
                                         match.value)

            return

        self.report_record_delta(base_message, log_type, schema_keys,
                                 test_event['data'])
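
    # Illustration only: jsonpath_rw extracts each nested record before the
    # per-record delta check above runs, e.g.
    #     matches = jsonpath_rw.parse('records[*]').find({'records': [{'a': 1}]})
    #     [match.value for match in matches]  # => [{'a': 1}]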

    def report_envelope_key_error(self, base_message, envelope_keys,
                                  test_record):
        """Provide context failures related to envelope key issues.

        Args:
            base_message (str): Base error message to be reported with extra context
            envelope_keys (list): A collection of the envelope keys for this nested schema
            test_record (dict): Actual record being tested - this could be one of
                many records extracted using jsonpath_rw
        """
        missing_env_key_list = set(envelope_keys).difference(set(test_record))
        if missing_env_key_list:
            missing_key_list = ', '.join('\'{}\''.format(key)
                                         for key in missing_env_key_list)
            message = (
                '{} Data is invalid due to missing envelope key(s) in test record: '
                '{}.'.format(base_message, missing_key_list))

            self.status_messages.append(
                StatusMessage(StatusMessage.FAILURE, message))
            return True

        return False

    def report_record_delta(self, base_message, log_type, schema_keys,
                            test_record):
        """Provide context on why this specific record failed.

        Args:
            base_message (str): Base error message to be reported with extra context
            log_type (str): Type of log being tested
            schema_keys (set): A collection of the keys from the schema
            test_record (dict): Actual record being tested - this could be one of
                many records extracted using jsonpath_rw
        """
        optional_keys = set(
            self.cli_config['logs'][log_type].get(
                'configuration', {}).get('optional_top_level_keys', {}))

        min_req_record_schema_keys = set(schema_keys).difference(optional_keys)

        test_record_keys = set(test_record)

        schema_diff = min_req_record_schema_keys.difference(test_record_keys)
        if schema_diff:
            missing_key_list = ', '.join('\'{}\''.format(key)
                                         for key in schema_diff)
            message = (
                '{} Data is invalid due to missing key(s) in test record: '
                '{}.'.format(base_message, missing_key_list))

            self.status_messages.append(
                StatusMessage(StatusMessage.FAILURE, message))
            return

        unexpected_keys = test_record_keys.difference(schema_keys)
        if unexpected_keys:
            unexpected_key_list = ', '.join('\'{}\''.format(key)
                                            for key in unexpected_keys)
            message = (
                '{} Data is invalid due to unexpected key(s) in test record: '
                '{}.'.format(base_message, unexpected_key_list))

            self.status_messages.append(
                StatusMessage(StatusMessage.FAILURE, message))
            return

        # Add a generic error message if we can not determine what the issue is
        message = '{} Please look for any errors above.'.format(base_message)
        self.status_messages.append(
            StatusMessage(StatusMessage.FAILURE, message))
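
The '<helper:...>' keyword substitution performed by apply_helpers above can be
reproduced in isolation. A minimal sketch of the same regex-driven technique
(all names here are illustrative, not part of the project):

import re
import time

RECORD_HELPERS = {'last_hour': lambda: str(int(time.time()) - 60)}
HELPER_REGEX = re.compile(r'<helper:(?P<helper>\w+)>')

def substitute_helpers(value):
    """Replace every '<helper:name>' keyword in a string with its helper output"""
    return HELPER_REGEX.sub(
        lambda match: RECORD_HELPERS[match.group('helper')](), value)

# '<helper:last_hour>' becomes an epoch timestamp from 60 seconds ago
print(substitute_helpers('{"eventTime": "<helper:last_hour>"}'))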
Code Example #17
class RuleProcessorTester(object):
    """Class to encapsulate testing the rule processor"""
    def __init__(self, context, print_output):
        """RuleProcessorTester initializer

        Args:
            print_output (bool): Whether this processor test
                should print results to stdout. This is set to false when the
                alert processor is explicitly being tested alone, and set to
                true for rule processor tests and end-to-end tests.
                Warnings and errors captured during rule processor testing
                will still be written to stdout regardless of this setting.
        """
        # Create the RuleProcessor. Passing a mocked context object with fake
        # values and False for suppressing sending of alerts to alert processor
        self.processor = StreamAlert(context, False)
        # Use a list of status_messages to store pass/fail/warning info
        self.status_messages = []
        self.total_tests = 0
        self.all_tests_passed = True
        self.print_output = print_output

    def test_processor(self, filter_rules, validate_only=False):
        """Perform integration tests for the 'rule' Lambda function

        Args:
            filter_rules (list|None): Specific rule names (or None) to restrict
                testing to. This is passed in from the CLI using the --rules option.
            validate_only (bool): If true, validation of test records will occur
                without the rules engine being applied to events.

        Yields:
            list: If testing rules, this yields the list of alerts generated by
                each test event, which can be run through the alert processor.
                If validating test records only, this does not yield.
        """
        for rule_name, contents in self._get_rule_test_files(
                filter_rules, validate_only):
            # Go over the records and test the applicable rule
            for index, test_record in enumerate(contents.get('records')):
                self.total_tests += 1

                if not self.check_keys(rule_name, test_record):
                    self.all_tests_passed = False
                    continue

                self.apply_helpers(test_record)

                print_header_line = index == 0

                formatted_record = helpers.format_lambda_test_record(
                    test_record)

                if validate_only:
                    self._validate_test_records(rule_name, test_record,
                                                formatted_record,
                                                print_header_line)
                    continue

                yield self._run_rule_tests(rule_name, test_record,
                                           formatted_record, print_header_line)

        # Report on the final test results
        self.report_output_summary()
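
    # Illustration only (hypothetical caller): test_processor() is a generator,
    # so a driver has to iterate it for the tests to actually run, e.g.
    #     collected_alerts = []
    #     for alerts in tester.test_processor(filter_rules=None):
    #         collected_alerts.extend(alerts)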

    def _validate_test_records(self, rule_name, test_record, formatted_record,
                               print_header_line):
        """Function to validate test records and log any errors

        Args:
            rule_name (str): The rule name being tested
            test_record (dict): A single record to test
            formatted_record (dict): A dictionary that includes the 'data' from the
                test record, formatted into a structure resembling how
                an incoming record from a service would format it.
                See test/integration/templates for examples of how each
                service formats records.
        """
        service, entity = self.processor.classifier.extract_service_and_entity(
            formatted_record)

        if not self.processor.classifier.load_sources(service, entity):
            self.all_tests_passed = False
            return

        # Create the StreamPayload to use for encapsulating parsed info
        payload = load_stream_payload(service, entity, formatted_record)
        if not payload:
            self.all_tests_passed = False
            return

        if print_header_line:
            print '\n{}'.format(rule_name)

        for record in payload.pre_parse():
            self.processor.classifier.classify_record(record)

            if not record.valid:
                self.all_tests_passed = False
                self.analyze_record_delta(rule_name, test_record)

            report_output(record.valid, [
                '[log=\'{}\']'.format(record.log_source or 'unknown'),
                'validation',
                record.service(), test_record['description']
            ])

    def _run_rule_tests(self, rule_name, test_record, formatted_record,
                        print_header_line):
        """Run tests on a test record for a given rule

        Args:
            rule_name (str): The name of the rule being tested.
            test_record (dict): The loaded test event from json
            formatted_record (dict): A dictionary that includes the 'data' from the
                test record, formatted into a structure resembling how
                an incoming record from a service would format it.
                See test/integration/templates for examples of how each
                service formats records.
            print_header_line (bool): Indicates if this is the first record from
                a test file, and therefore we should print some header information

        Returns:
            list: alerts that were generated from this test event
        """
        event = {'Records': [formatted_record]}
        # Run tests on the formatted record
        alerts, expected_alerts, all_records_matched_schema = self.test_rule(
            rule_name, test_record, event)

        alerted_properly = (len(alerts) == expected_alerts)
        current_test_passed = alerted_properly and all_records_matched_schema

        self.all_tests_passed = current_test_passed and self.all_tests_passed

        # Print rule name for section header, but only if we get
        # to a point where there is a record to actually be tested.
        # This avoids potentially blank sections
        if print_header_line and (alerts or self.print_output):
            print '\n{}'.format(rule_name)

        if self.print_output:
            report_output(current_test_passed, [
                '[trigger={}]'.format(expected_alerts), 'rule',
                test_record['service'], test_record['description']
            ])

        # Add the status of the rule to messages list
        if not all_records_matched_schema:
            self.analyze_record_delta(rule_name, test_record)
        elif not alerted_properly:
            message = 'Rule failure: {}'.format(test_record['description'])
            self.status_messages.append(
                StatusMessage(StatusMessage.FAILURE, rule_name, message))

        # Return the alerts back to caller
        return alerts

    def _get_rule_test_files(self, filter_rules, validate_only):
        """Helper to get rule files to be tested

        Args:
            filter_rules (list|None): List of specific rule names or file names
                (or None) that have been fed in from the CLI to restrict testing to
            validate_only (bool): If true, report missing test files as missing
                files instead of as rules without configured test events

        Yields:
            str: rule name
            dict: loaded json contents of the respective test event file
        """
        # Since filter_rules can be either a list of rule names or rule files,
        # we should check to see if there is a '.json' extension and just use the
        # base filename. This approach avoids two functions that do largely the same thing
        if filter_rules:
            for index, rule in enumerate(filter_rules):
                parts = os.path.splitext(rule)
                if parts[1] == '.json':
                    filter_rules[index] = parts[0]

            # Create a copy of the filtered rules that can be altered
            filter_rules_copy = filter_rules[:]

        for _, _, test_rule_files in os.walk(DIR_RULES):
            for rule_file in test_rule_files:
                rule_name = os.path.splitext(rule_file)[0]

                # If only specific rules are being tested,
                # skip files that do not match those rules
                if filter_rules:
                    if rule_name not in filter_rules:
                        continue

                    filter_rules_copy.remove(rule_name)

                with open(os.path.join(DIR_RULES, rule_file),
                          'r') as rule_file_handle:
                    try:
                        contents = json.load(rule_file_handle)
                    except (ValueError, TypeError) as err:
                        self.all_tests_passed = False
                        message = 'Improperly formatted file ({}): {}'.format(
                            rule_file, err.message)
                        self.status_messages.append(
                            StatusMessage(StatusMessage.WARNING, rule_name,
                                          message))
                        continue

                if not contents.get('records'):
                    self.all_tests_passed = False
                    self.status_messages.append(
                        StatusMessage(StatusMessage.WARNING, rule_name,
                                      'No records to test in file'))
                    continue

                yield rule_name, contents

        # Print any of the filtered rules that remain in the list
        # This means that there are not tests configured for them
        if filter_rules and filter_rules_copy:
            self.all_tests_passed = False
            message = 'No test events configured for designated rule'
            for filter_rule in filter_rules_copy:
                if validate_only:
                    message = 'Designated file ({}.json) does not exist within \'{}\''.format(
                        filter_rule, DIR_RULES)
                self.status_messages.append(
                    StatusMessage(StatusMessage.WARNING, filter_rule, message))

    def check_keys(self, rule_name, test_record):
        """Check the test_record contains the required keys

        Args:
            rule_name (str): The name of the rule being tested. This is passed in
                here strictly for reporting any errors with key checks.
            test_record (dict): The raw test record being processed

        Returns:
            bool: True if the proper keys are present
        """
        required_keys = {'data', 'description', 'service', 'source', 'trigger'}

        record_keys = set(test_record.keys())
        if not required_keys.issubset(record_keys):
            req_key_diff = required_keys.difference(record_keys)
            missing_keys = ','.join('\'{}\''.format(key)
                                    for key in req_key_diff)
            message = 'Missing required key(s) in log: {}'.format(missing_keys)
            self.status_messages.append(
                StatusMessage(StatusMessage.FAILURE, rule_name, message))
            return False

        optional_keys = {'trigger_count', 'compress'}

        key_diff = record_keys.difference(required_keys | optional_keys)

        # Log a warning if there are extra keys declared in the test log
        if key_diff:
            extra_keys = ','.join('\'{}\''.format(key) for key in key_diff)
            message = 'Additional unnecessary keys in log: {}'.format(
                extra_keys)
            # Remove the key(s) and just warn the user that they are extra
            record_keys.difference_update(key_diff)
            self.status_messages.append(
                StatusMessage(StatusMessage.WARNING, rule_name, message))

        return record_keys.issubset(required_keys | optional_keys)

    @staticmethod
    def apply_helpers(test_record):
        """Detect and apply helper functions to test fixtures
        Helpers are declared in test fixtures via the following keyword:
        "<helpers:helper_name>"

        Supported helper functions:
            last_hour: return the current epoch time minus 60 seconds to pass the
                       last_hour rule helper.

        Args:
            test_record (dict): loaded fixture file JSON as a dict.
        """
        # declare all helper functions here, they should always return a string
        record_helpers = {'last_hour': lambda: str(int(time.time()) - 60)}
        helper_regex = re.compile(r'<helper:(?P<helper>\w+)>')

        def find_and_apply_helpers(test_record):
            """Apply any helpers to the passed in test_record"""
            for key, value in test_record.iteritems():
                if isinstance(value, (str, unicode)):
                    test_record[key] = re.sub(
                        helper_regex,
                        lambda match: record_helpers[match.group('helper')](),
                        test_record[key])
                elif isinstance(value, dict):
                    find_and_apply_helpers(test_record[key])

        find_and_apply_helpers(test_record)

    def report_output_summary(self):
        """Helper function to print the summary results of all tests"""
        failure_messages = [
            item for item in self.status_messages
            if item.type == StatusMessage.FAILURE
        ]
        warning_messages = [
            item for item in self.status_messages
            if item.type == StatusMessage.WARNING
        ]
        passed_tests = self.total_tests - len(failure_messages)
        # Print some lines at the bottom of output to make it more readable
        # This occurs here so there is always space and not only when the
        # successful test info prints
        print '\n\n'

        # Only print success info if we explicitly want to print output
        # but always print any errors or warnings below
        if self.print_output:
            # Print a message indicating how many of the total tests passed
            LOGGER_CLI.info('%s(%d/%d) Successful Tests%s', COLOR_GREEN,
                            passed_tests, self.total_tests, COLOR_RESET)

        # Check if there were failed tests and report on them appropriately
        if failure_messages:
            # Print a message indicating how many of the total tests failed
            LOGGER_CLI.error('%s(%d/%d) Failures%s', COLOR_RED,
                             len(failure_messages), self.total_tests,
                             COLOR_RESET)

            # Iterate over the rule_name values in the failed list and report on them
            for index, failure in enumerate(failure_messages, start=1):
                LOGGER_CLI.error('%s(%d/%d) [%s] %s%s', COLOR_RED, index,
                                 len(failure_messages), failure.rule,
                                 failure.message, COLOR_RESET)

        # Check if there were any warnings and report on them
        if warning_messages:
            warning_count = len(warning_messages)
            LOGGER_CLI.warn('%s%d Warning%s%s', COLOR_YELLOW, warning_count,
                            ('s' if warning_count > 1 else ''), COLOR_RESET)

            for index, warning in enumerate(warning_messages, start=1):
                LOGGER_CLI.warn('%s(%d/%d) [%s] %s%s', COLOR_YELLOW, index,
                                warning_count, warning.rule, warning.message,
                                COLOR_RESET)

    def test_rule(self, rule_name, test_record, event):
        """Feed formatted records into StreamAlert and check for alerts

        Args:
            rule_name (str): The rule name being tested
            test_record (dict): A single raw record to test
            event (dict): A formatted event that reflects the structure expected
                as input to the Lambda function.

        Returns:
            list: alerts that hit for this rule
            int: count of expected alerts for this rule
            bool: False if errors occurred during processing
        """
        # Clear out any old alerts or errors from the previous test run
        # pylint: disable=protected-access
        del self.processor._alerts[:]
        self.processor._failed_record_count = 0

        expected_alert_count = test_record.get('trigger_count')
        if not expected_alert_count:
            expected_alert_count = 1 if test_record['trigger'] else 0

        # Run the rule processor
        all_records_matched_schema = self.processor.run(event)

        alerts = self.processor.get_alerts()

        # we only want alerts for the specific rule being tested
        alerts = [alert for alert in alerts if alert['rule_name'] == rule_name]

        return alerts, expected_alert_count, all_records_matched_schema

    def analyze_record_delta(self, rule_name, test_record):
        """Provide some additional context on why this test failed. This will
        perform some analysis of the test record to determine which keys are
        missing or which unnecessary keys are causing the test to fail. Any
        errors are appended to a list of errors so they can be printed at
        the end of the test run.

        Args:
            rule_name (str): Name of rule being tested
            test_record (dict): Actual record data being tested
        """
        logs = self.processor.classifier.get_log_info_for_source()
        rule_info = StreamRules.get_rules()[rule_name]
        test_record_keys = set(test_record['data'])
        for log in rule_info.logs:
            if log not in logs:
                message = 'Log declared in rule ({}) does not exist in logs.json'.format(
                    log)
                self.status_messages.append(
                    StatusMessage(StatusMessage.FAILURE, rule_name, message))
                continue
            all_record_schema_keys = set(logs[log]['schema'])
            optional_keys = set(logs[log].get('configuration', {}).get(
                'optional_top_level_keys', {}))

            min_req_record_schema_keys = all_record_schema_keys.difference(
                optional_keys)

            schema_diff = min_req_record_schema_keys.difference(
                test_record_keys)
            if schema_diff:
                message = (
                    'Data is invalid due to missing key(s) in test record: {}. '
                    'Rule: \'{}\'. Description: \'{}\''.format(
                        ', '.join('\'{}\''.format(key) for key in schema_diff),
                        rule_info.rule_name, test_record['description']))

                self.status_messages.append(
                    StatusMessage(StatusMessage.FAILURE, rule_name, message))
                continue

            unexpected_record_keys = test_record_keys.difference(
                all_record_schema_keys)
            if unexpected_record_keys:
                message = (
                    'Data is invalid due to unexpected key(s) in test record: {}. '
                    'Rule: \'{}\'. Description: \'{}\''.format(
                        ', '.join('\'{}\''.format(key)
                                  for key in unexpected_record_keys),
                        rule_info.rule_name, test_record['description']))

                self.status_messages.append(
                    StatusMessage(StatusMessage.FAILURE, rule_name, message))
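
The delta analysis above reduces to two set differences. A standalone sketch of
the same technique (the function name and sample values are illustrative):

def record_delta_sketch(schema_keys, optional_keys, record):
    """Return (missing, unexpected) keys for a record checked against a schema"""
    required_keys = set(schema_keys) - set(optional_keys)
    record_keys = set(record)
    missing = required_keys - record_keys
    unexpected = record_keys - set(schema_keys)
    return missing, unexpected

# 'name' is required but absent; 'extra' does not appear in the schema at all
missing, unexpected = record_delta_sketch(
    {'name', 'id', 'note'}, {'note'}, {'id': 1, 'extra': True})
assert missing == {'name'} and unexpected == {'extra'}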
Code Example #18
class TestStreamAlert(object):
    """Test class for StreamAlert class"""

    @patch('stream_alert.rule_processor.handler.load_config',
           lambda: load_config('tests/unit/conf/'))
    def setup(self):
        """Setup before each method"""
        self.__sa_handler = StreamAlert(get_mock_context(), False)

    def test_run_no_records(self):
        """StreamAlert Class - Run, No Records"""
        passed = self.__sa_handler.run({'Records': []})
        assert_false(passed)

    def test_get_alerts(self):
        """StreamAlert Class - Get Alerts"""
        default_list = ['alert1', 'alert2']
        self.__sa_handler._alerts = default_list

        assert_list_equal(self.__sa_handler.get_alerts(), default_list)

    @patch('stream_alert.rule_processor.handler.StreamClassifier.load_sources')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_no_sources(self, extract_mock, load_sources_mock):
        """StreamAlert Class - Run, No Loaded Sources"""
        extract_mock.return_value = ('lambda', 'entity')
        load_sources_mock.return_value = None

        self.__sa_handler.run({'Records': ['record']})

        load_sources_mock.assert_called_with('lambda', 'entity')

    @patch('logging.Logger.error')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_bad_service(self, extract_mock, log_mock):
        """StreamAlert Class - Run, Bad Service"""
        extract_mock.return_value = ('', 'entity')

        self.__sa_handler.run({'Records': ['record']})

        log_mock.assert_called_with('No valid service found in payload\'s raw record. '
                                    'Skipping record: %s', 'record')

    @patch('logging.Logger.error')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_bad_entity(self, extract_mock, log_mock):
        """StreamAlert Class - Run, Bad Entity"""
        extract_mock.return_value = ('kinesis', '')

        self.__sa_handler.run({'Records': ['record']})

        log_mock.assert_called_with(
            'Unable to extract entity from payload\'s raw record for '
            'service %s. Skipping record: %s', 'kinesis', 'record')

    @patch('stream_alert.rule_processor.handler.load_stream_payload')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.load_sources')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_load_payload_bad(
            self,
            extract_mock,
            load_sources_mock,
            load_payload_mock):
        """StreamAlert Class - Run, Loaded Payload Fail"""
        extract_mock.return_value = ('lambda', 'entity')
        load_sources_mock.return_value = True

        self.__sa_handler.run({'Records': ['record']})

        load_payload_mock.assert_called_with(
            'lambda',
            'entity',
            'record'
        )

    @patch('stream_alert.rule_processor.handler.StreamRules.process')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_with_alert(self, extract_mock, rules_mock):
        """StreamAlert Class - Run, With Alert"""
        extract_mock.return_value = ('kinesis', 'unit_test_default_stream')
        rules_mock.return_value = (['success!!'], ['normalized_records'])

        passed = self.__sa_handler.run(get_valid_event())

        assert_true(passed)

    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_alert_count(self, extract_mock):
        """StreamAlert Class - Run, Check Count With 4 Logs"""
        count = 4
        extract_mock.return_value = ('kinesis', 'unit_test_default_stream')
        self.__sa_handler.run(get_valid_event(count))
        assert_equal(self.__sa_handler._processed_record_count, count)

    @patch('logging.Logger.debug')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_no_alerts(self, extract_mock, log_mock):
        """StreamAlert Class - Run, With No Alerts"""
        extract_mock.return_value = ('kinesis', 'unit_test_default_stream')
        self.__sa_handler.run(get_valid_event())

        calls = [call('Processed %d valid record(s) that resulted in %d alert(s).', 1, 0),
                 call('Invalid record count: %d', 0),
                 call('%s alerts triggered', 0)]

        log_mock.assert_has_calls(calls)

    @patch('logging.Logger.error')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_invalid_data(self, extract_mock, log_mock):
        """StreamAlert Class - Run, Invalid Data"""
        extract_mock.return_value = ('kinesis', 'unit_test_default_stream')
        event = get_valid_event()

        # Replace the good log data with bad data
        event['Records'][0]['kinesis']['data'] = base64.b64encode('{"bad": "data"}')

        # Swap out the alias so the logging occurs
        self.__sa_handler.env['lambda_alias'] = 'production'
        self.__sa_handler.run(event)

        assert_equal(
            log_mock.call_args[0][0],
            'Record does not match any defined schemas: %s\n%s')
        assert_equal(log_mock.call_args[0][2], '{"bad": "data"}')

    @patch('stream_alert.rule_processor.sink.StreamSink.sink')
    @patch('stream_alert.rule_processor.handler.StreamRules.process')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_send_alerts(self, extract_mock, rules_mock, sink_mock):
        """StreamAlert Class - Run, Send Alert"""
        extract_mock.return_value = ('kinesis', 'unit_test_default_stream')
        rules_mock.return_value = (['success!!'], ['normalized_records'])

        # Set send_alerts to true so the sink happens
        self.__sa_handler.enable_alert_processor = True

        # Swap out the alias so the logging occurs
        self.__sa_handler.env['lambda_alias'] = 'production'

        self.__sa_handler.run(get_valid_event())

        sink_mock.assert_called_with(['success!!'])

    @patch('logging.Logger.debug')
    @patch('stream_alert.rule_processor.handler.StreamRules.process')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_debug_log_alert(self, extract_mock, rules_mock, log_mock):
        """StreamAlert Class - Run, Debug Log Alert"""
        extract_mock.return_value = ('kinesis', 'unit_test_default_stream')
        rules_mock.return_value = (['success!!'], ['normalized_records'])

        # Cache the logger level
        log_level = LOGGER.getEffectiveLevel()

        # Increase the logger level to debug
        LOGGER.setLevel(logging.DEBUG)

        self.__sa_handler.run(get_valid_event())

        # Reset the logger level
        LOGGER.setLevel(log_level)

        log_mock.assert_called_with('Alerts:\n%s', '[\n  "success!!"\n]')

    @patch('stream_alert.rule_processor.handler.load_stream_payload')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.load_sources')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_no_payload_class(
            self,
            extract_mock,
            load_sources_mock,
            load_payload_mock):
        """StreamAlert Class - Run, No Payload Class"""
        extract_mock.return_value = ('blah', 'entity')
        load_sources_mock.return_value = True
        load_payload_mock.return_value = None

        self.__sa_handler.run({'Records': ['record']})

        load_payload_mock.assert_called()

    @mock_kinesis
    def test_firehose_record_delivery_disabled_logs(self):
        """StreamAlert Class - Firehose Record Delivery - Disabled Logs"""
        self.__sa_handler.firehose_client = boto3.client(
            'firehose', region_name='us-east-1')

        test_event = convert_events_to_kinesis([
            # unit_test_simple_log
            {'unit_key_01': 2, 'unit_key_02': 'testtest'}
            for _ in range(10)])

        delivery_stream_names = ['streamalert_data_unit_test_simple_log']

        # Setup mock delivery streams
        for delivery_stream in delivery_stream_names:
            self.__sa_handler.firehose_client.create_delivery_stream(
                DeliveryStreamName=delivery_stream,
                S3DestinationConfiguration={
                    'RoleARN': 'arn:aws:iam::123456789012:role/firehose_delivery_role',
                    'BucketARN': 'arn:aws:s3:::kinesis-test',
                    'Prefix': '{}/'.format(delivery_stream),
                    'BufferingHints': {
                        'SizeInMBs': 123,
                        'IntervalInSeconds': 124
                    },
                    'CompressionFormat': 'Snappy',
                }
            )

        with patch.object(self.__sa_handler.firehose_client, 'put_record_batch') as firehose_mock:
            firehose_mock.return_value = {'FailedPutCount': 0}

            self.__sa_handler.config['global']['infrastructure']['firehose'] = {
                'disabled_logs': ['unit_test_simple_log']}
            self.__sa_handler.run(test_event)

            firehose_mock.assert_not_called()

    @patch('stream_alert.rule_processor.threat_intel.StreamThreatIntel._query')
    @patch('stream_alert.rule_processor.threat_intel.StreamThreatIntel.load_from_config')
    def test_run_threat_intel_enabled(self, mock_threat_intel, mock_query): # pylint: disable=no-self-use
        """StreamAlert Class - Run SA when threat intel enabled"""
        @rule(datatypes=['sourceAddress'], outputs=['s3:sample_bucket'])
        def match_ipaddress(_): # pylint: disable=unused-variable
            """Testing dummy rule"""
            return True

        mock_threat_intel.return_value = StreamThreatIntel('test_table_name', 'us-east-1')
        mock_query.return_value = ([], [])

        sa_handler = StreamAlert(get_mock_context(), False)
        event = {
            'account': 123456,
            'region': '123456123456',
            'source': '1.1.1.2',
            'detail': {
                'eventName': 'ConsoleLogin',
                'sourceIPAddress': '1.1.1.2',
                'recipientAccountId': '654321'
            }
        }
        events = []
        for i in range(10):
            event['source'] = '1.1.1.{}'.format(i)
            # Append a copy so each record keeps its own source address instead
            # of ten references to the same mutated dict
            events.append(dict(event))

        kinesis_events = {
            'Records': [make_kinesis_raw_record('test_kinesis_stream', json.dumps(event))
                        for event in events]
        }

        passed = sa_handler.run(kinesis_events)
        assert_true(passed)

        assert_equal(mock_query.call_count, 1)
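
make_kinesis_raw_record above is a helper from the project's test suite.
Assuming it simply wraps an already-serialized payload in the shape of a
Kinesis record as delivered to Lambda, a minimal stand-in could look like this
(the ARN and account number are placeholders):

import base64

def make_kinesis_raw_record_sketch(stream_name, data):
    """Wrap a serialized payload like a Kinesis record inside a Lambda event"""
    return {
        'eventSource': 'aws:kinesis',
        'eventSourceARN': ('arn:aws:kinesis:us-east-1:123456789012:'
                           'stream/{}'.format(stream_name)),
        'kinesis': {
            # Kinesis record data always arrives base64-encoded
            'data': base64.b64encode(data)
        }
    }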
Code Example #19
def handler(event, context):
    """Main Lambda handler function"""
    StreamAlert(context).run(event)
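
To exercise the handler locally, a stubbed Lambda context is enough. A minimal
sketch (the attribute names and ARN are assumptions modeled on what the tests'
get_mock_context helper provides):

from collections import namedtuple

# Fake context exposing the attributes the rule processor reads from Lambda
MockContext = namedtuple('MockContext',
                         ['invoked_function_arn', 'function_name'])

context = MockContext(
    'arn:aws:lambda:us-east-1:123456789012:'
    'function:test_streamalert_rule_processor:development',
    'test_streamalert_rule_processor')

handler({'Records': []}, context)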
Code Example #20
class TestStreamAlert(object):
    """Test class for StreamAlert class"""
    def __init__(self):
        self.__sa_handler = None

    @patch('stream_alert.rule_processor.handler.load_config',
           lambda: load_config('tests/unit/conf/'))
    def setup(self):
        """Setup before each method"""
        self.__sa_handler = StreamAlert(get_mock_context(), False)

    def test_run_no_records(self):
        """StreamAlert Class - Run, No Records"""
        passed = self.__sa_handler.run({'Records': []})
        assert_false(passed)

    @staticmethod
    @raises(ConfigError)
    def test_run_config_error():
        """StreamAlert Class - Run, Config Error"""
        mock = mock_open(
            read_data='non-json string that will raise an exception')
        with patch('__builtin__.open', mock):
            StreamAlert(get_mock_context())

    def test_get_alerts(self):
        """StreamAlert Class - Get Alerts"""
        default_list = ['alert1', 'alert2']
        self.__sa_handler._alerts = default_list

        assert_list_equal(self.__sa_handler.get_alerts(), default_list)

    @patch('stream_alert.rule_processor.handler.StreamClassifier.load_sources')
    @patch(
        'stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity'
    )
    def test_run_no_sources(self, extract_mock, load_sources_mock):
        """StreamAlert Class - Run, No Loaded Sources"""
        extract_mock.return_value = ('lambda', 'entity')
        load_sources_mock.return_value = None

        self.__sa_handler.run({'Records': ['record']})

        load_sources_mock.assert_called_with('lambda', 'entity')

    @patch('logging.Logger.error')
    @patch(
        'stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity'
    )
    def test_run_bad_service(self, extract_mock, log_mock):
        """StreamAlert Class - Run, Bad Service"""
        extract_mock.return_value = ('', 'entity')

        self.__sa_handler.run({'Records': ['record']})

        log_mock.assert_called_with(
            'No valid service found in payload\'s raw record. '
            'Skipping record: %s', 'record')

    @patch('logging.Logger.error')
    @patch(
        'stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity'
    )
    def test_run_bad_entity(self, extract_mock, log_mock):
        """StreamAlert Class - Run, Bad Entity"""
        extract_mock.return_value = ('kinesis', '')

        self.__sa_handler.run({'Records': ['record']})

        log_mock.assert_called_with(
            'Unable to extract entity from payload\'s raw record for '
            'service %s. Skipping record: %s', 'kinesis', 'record')

    @patch('stream_alert.rule_processor.handler.load_stream_payload')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.load_sources')
    @patch(
        'stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity'
    )
    def test_run_load_payload_bad(self, extract_mock, load_sources_mock,
                                  load_payload_mock):
        """StreamAlert Class - Run, Loaded Payload Fail"""
        extract_mock.return_value = ('lambda', 'entity')
        load_sources_mock.return_value = True

        self.__sa_handler.run({'Records': ['record']})

        load_payload_mock.assert_called_with('lambda', 'entity', 'record')

    @patch('stream_alert.rule_processor.handler.StreamRules.process')
    @patch(
        'stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity'
    )
    def test_run_with_alert(self, extract_mock, rules_mock):
        """StreamAlert Class - Run, With Alert"""
        extract_mock.return_value = ('kinesis', 'unit_test_default_stream')
        rules_mock.return_value = ['success!!']

        passed = self.__sa_handler.run(get_valid_event())

        assert_true(passed)

    @patch('logging.Logger.debug')
    @patch(
        'stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity'
    )
    def test_run_no_alerts(self, extract_mock, log_mock):
        """StreamAlert Class - Run, With No Alerts"""
        extract_mock.return_value = ('kinesis', 'unit_test_default_stream')
        self.__sa_handler.run(get_valid_event())

        calls = [
            call('Processed %d valid record(s) that resulted in %d alert(s).',
                 1, 0),
            call('Invalid record count: %d', 0),
            call('%s alerts triggered', 0)
        ]

        log_mock.assert_has_calls(calls)

    @patch('logging.Logger.error')
    @patch(
        'stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity'
    )
    def test_run_invalid_data(self, extract_mock, log_mock):
        """StreamAlert Class - Run, Invalid Data"""
        extract_mock.return_value = ('kinesis', 'unit_test_default_stream')
        event = get_valid_event()

        # Replace the good log data with bad data
        event['Records'][0]['kinesis']['data'] = base64.b64encode(
            '{"bad": "data"}')

        # Swap out the alias so the logging occurs
        self.__sa_handler.env['lambda_alias'] = 'production'
        self.__sa_handler.run(event)

        assert_equal(log_mock.call_args[0][0],
                     'Record does not match any defined schemas: %s\n%s')
        assert_equal(log_mock.call_args[0][2], '{"bad": "data"}')

    @patch('stream_alert.rule_processor.sink.StreamSink.sink')
    @patch('stream_alert.rule_processor.handler.StreamRules.process')
    @patch(
        'stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity'
    )
    def test_run_send_alerts(self, extract_mock, rules_mock, sink_mock):
        """StreamAlert Class - Run, Send Alert"""
        extract_mock.return_value = ('kinesis', 'unit_test_default_stream')
        rules_mock.return_value = ['success!!']

        # Set send_alerts to true so the sink happens
        self.__sa_handler.enable_alert_processor = True

        # Swap out the alias so the logging occurs
        self.__sa_handler.env['lambda_alias'] = 'production'

        self.__sa_handler.run(get_valid_event())

        sink_mock.assert_called_with(['success!!'])

    @patch('logging.Logger.debug')
    @patch('stream_alert.rule_processor.handler.StreamRules.process')
    @patch(
        'stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity'
    )
    def test_run_debug_log_alert(self, extract_mock, rules_mock, log_mock):
        """StreamAlert Class - Run, Debug Log Alert"""
        extract_mock.return_value = ('kinesis', 'unit_test_default_stream')
        rules_mock.return_value = ['success!!']

        # Cache the logger level
        log_level = LOGGER.getEffectiveLevel()

        # Increase the logger level to debug
        LOGGER.setLevel(logging.DEBUG)

        self.__sa_handler.run(get_valid_event())

        # Reset the logger level
        LOGGER.setLevel(log_level)

        log_mock.assert_called_with('Alerts:\n%s', '[\n  "success!!"\n]')

    @patch('stream_alert.rule_processor.handler.load_stream_payload')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.load_sources')
    @patch(
        'stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity'
    )
    def test_run_no_payload_class(self, extract_mock, load_sources_mock,
                                  load_payload_mock):
        """StreamAlert Class - Run, No Payload Class"""
        extract_mock.return_value = ('blah', 'entity')
        load_sources_mock.return_value = True
        load_payload_mock.return_value = None

        self.__sa_handler.run({'Records': ['record']})

        load_payload_mock.assert_called()