def test_run_threat_intel_enabled(self, mock_threat_intel, mock_query):  # pylint: disable=no-self-use
    """StreamAlert Class - Run SA when threat intel enabled"""
    @rule(datatypes=['sourceAddress'], outputs=['s3:sample_bucket'])
    def match_ipaddress(_):  # pylint: disable=unused-variable
        """Testing dummy rule"""
        return True

    mock_threat_intel.return_value = StreamThreatIntel('test_table_name', 'us-east-1')
    mock_query.return_value = ([], [])

    sa_handler = StreamAlert(get_mock_context(), False)
    event = {
        'account': 123456,
        'region': '123456123456',
        'source': '1.1.1.2',
        'detail': {
            'eventName': 'ConsoleLogin',
            'sourceIPAddress': '1.1.1.2',
            'recipientAccountId': '654321'
        }
    }
    events = []
    for i in range(10):
        event['source'] = '1.1.1.{}'.format(i)
        # Copy the event so each record keeps its own 'source' value; appending
        # the mutable dict directly would alias every entry to the last update
        events.append(dict(event))

    kinesis_events = {
        'Records': [
            make_kinesis_raw_record('test_kinesis_stream', json.dumps(event))
            for event in events
        ]
    }

    passed = sa_handler.run(kinesis_events)
    assert_true(passed)
    assert_equal(mock_query.call_count, 1)
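# A minimal sketch of the `make_kinesis_raw_record` helper used above, written
# from what the surrounding tests imply (base64-encoded 'kinesis.data', an
# 'aws:kinesis' event source). The real helper lives in the shared test helpers
# module, so the exact record shape here is an assumption.
import base64


def make_kinesis_raw_record_sketch(stream_name, data):
    """Build a record shaped like a raw AWS Kinesis event record (illustrative)"""
    return {
        'eventSource': 'aws:kinesis',
        'eventSourceARN': 'arn:aws:kinesis:us-east-1:123456789012:stream/{}'.format(stream_name),
        'kinesis': {
            # Kinesis delivers payloads base64-encoded, which is why the
            # invalid-data tests below call base64.b64encode on raw JSON
            'data': base64.b64encode(data)
        }
    }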
def test_rule(self, rule_name, test_record, formatted_record):
    """Feed formatted records into StreamAlert and check for alerts

    Args:
        rule_name (str): The rule name being tested
        test_record (dict): A single record to test
        formatted_record (dict): A dictionary that includes the 'data' from the
            test record, formatted into a structure that resembles how an
            incoming record from a service would be formatted.
            See test/integration/templates for examples of how each service
            formats records.

    Returns:
        list: alerts that hit for this rule
        int: count of expected alerts for this rule
        bool: False if errors occurred during processing
    """
    event = {'Records': [formatted_record]}

    expected_alert_count = test_record.get('trigger_count')
    if not expected_alert_count:
        expected_alert_count = 1 if test_record['trigger'] else 0

    # Run the rule processor. Passing a mocked context object with fake
    # values and False for suppressing sending of alerts
    processor = StreamAlert(self.context, False)
    all_records_matched_schema = processor.run(event)

    if not all_records_matched_schema:
        payload = StreamPayload(raw_record=formatted_record)
        classifier = StreamClassifier(config=load_config())
        classifier.map_source(payload)
        logs = classifier._log_metadata()
        self.analyze_record_delta(logs, rule_name, test_record)

    alerts = processor.get_alerts()

    # We only want alerts for the specific rule being tested
    alerts = [alert for alert in alerts if alert['rule_name'] == rule_name]

    return alerts, expected_alert_count, all_records_matched_schema
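# Illustrative call pattern for test_rule above. The rule name, fixture values,
# and the `tester` variable are made up for demonstration; real test events live
# under test/integration and are formatted by helpers.format_lambda_test_record,
# which this repo's test runner already uses.
def demo_test_rule_usage(tester):
    test_record = {
        'trigger': True,
        'service': 'kinesis',
        'source': 'unit_test_default_stream',
        'description': 'example record that should fire the rule',
        'data': '{"key": "value"}'
    }
    formatted_record = helpers.format_lambda_test_record(test_record)
    alerts, expected_count, matched = tester.test_rule(
        'example_rule', test_record, formatted_record)
    assert matched and len(alerts) == expected_count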
class TestStreamAlert(object):
    """Test class for StreamAlert class"""

    def __init__(self):
        self.__sa_handler = None

    @patch('stream_alert.rule_processor.handler.load_config',
           lambda: load_config('tests/unit/conf/'))
    def setup(self):
        """Setup before each method"""
        self.__sa_handler = StreamAlert(get_mock_context(), False)

    def test_run_no_records(self):
        """StreamAlert Class - Run, No Records"""
        passed = self.__sa_handler.run({'Records': []})
        assert_false(passed)

    def test_get_alerts(self):
        """StreamAlert Class - Get Alerts"""
        default_list = ['alert1', 'alert2']
        self.__sa_handler._alerts = default_list

        assert_list_equal(self.__sa_handler.get_alerts(), default_list)

    @patch('stream_alert.rule_processor.handler.StreamClassifier.load_sources')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_no_sources(self, extract_mock, load_sources_mock):
        """StreamAlert Class - Run, No Loaded Sources"""
        extract_mock.return_value = ('lambda', 'entity')
        load_sources_mock.return_value = None
        self.__sa_handler.run({'Records': ['record']})

        load_sources_mock.assert_called_with('lambda', 'entity')

    @patch('logging.Logger.error')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_bad_service(self, extract_mock, log_mock):
        """StreamAlert Class - Run, Bad Service"""
        extract_mock.return_value = ('', 'entity')
        self.__sa_handler.run({'Records': ['record']})

        log_mock.assert_called_with(
            'No valid service found in payload\'s raw record. '
            'Skipping record: %s', 'record')

    @patch('logging.Logger.error')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_bad_entity(self, extract_mock, log_mock):
        """StreamAlert Class - Run, Bad Entity"""
        extract_mock.return_value = ('kinesis', '')
        self.__sa_handler.run({'Records': ['record']})

        log_mock.assert_called_with(
            'Unable to extract entity from payload\'s raw record for '
            'service %s. Skipping record: %s', 'kinesis', 'record')

    @patch('stream_alert.rule_processor.handler.load_stream_payload')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.load_sources')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_load_payload_bad(self, extract_mock, load_sources_mock,
                                  load_payload_mock):
        """StreamAlert Class - Run, Loaded Payload Fail"""
        extract_mock.return_value = ('lambda', 'entity')
        load_sources_mock.return_value = True

        self.__sa_handler.run({'Records': ['record']})

        load_payload_mock.assert_called_with('lambda', 'entity', 'record')

    @patch('stream_alert.rule_processor.handler.StreamRules.process')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_with_alert(self, extract_mock, rules_mock):
        """StreamAlert Class - Run, With Alert"""
        extract_mock.return_value = ('kinesis', 'unit_test_default_stream')
        rules_mock.return_value = ['success!!']

        passed = self.__sa_handler.run(get_valid_event())

        assert_true(passed)

    @patch('logging.Logger.debug')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_no_alerts(self, extract_mock, log_mock):
        """StreamAlert Class - Run, With No Alerts"""
        extract_mock.return_value = ('kinesis', 'unit_test_default_stream')
        self.__sa_handler.run(get_valid_event())

        calls = [
            call('Processed %d valid record(s) that resulted in %d alert(s).', 1, 0),
            call('Invalid record count: %d', 0),
            call('%s alerts triggered', 0)
        ]

        log_mock.assert_has_calls(calls)

    @patch('logging.Logger.error')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_invalid_data(self, extract_mock, log_mock):
        """StreamAlert Class - Run, Invalid Data"""
        extract_mock.return_value = ('kinesis', 'unit_test_default_stream')

        event = get_valid_event()

        # Replace the good log data with bad data
        event['Records'][0]['kinesis']['data'] = base64.b64encode('{"bad": "data"}')

        # Swap out the alias so the logging occurs
        self.__sa_handler.env['lambda_alias'] = 'production'

        self.__sa_handler.run(event)

        assert_equal(log_mock.call_args[0][0],
                     'Record does not match any defined schemas: %s\n%s')
        assert_equal(log_mock.call_args[0][2], '{"bad": "data"}')

    @patch('stream_alert.rule_processor.sink.StreamSink.sink')
    @patch('stream_alert.rule_processor.handler.StreamRules.process')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_send_alerts(self, extract_mock, rules_mock, sink_mock):
        """StreamAlert Class - Run, Send Alert"""
        extract_mock.return_value = ('kinesis', 'unit_test_default_stream')
        rules_mock.return_value = ['success!!']

        # Set send_alerts to True so the sink happens
        self.__sa_handler.enable_alert_processor = True

        # Swap out the alias so the logging occurs
        self.__sa_handler.env['lambda_alias'] = 'production'

        self.__sa_handler.run(get_valid_event())

        sink_mock.assert_called_with(['success!!'])

    @patch('logging.Logger.debug')
    @patch('stream_alert.rule_processor.handler.StreamRules.process')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_debug_log_alert(self, extract_mock, rules_mock, log_mock):
        """StreamAlert Class - Run, Debug Log Alert"""
        extract_mock.return_value = ('kinesis', 'unit_test_default_stream')
        rules_mock.return_value = ['success!!']

        # Cache the logger level
        log_level = LOGGER.getEffectiveLevel()

        # Increase the logger level to debug
        LOGGER.setLevel(logging.DEBUG)

        self.__sa_handler.run(get_valid_event())

        # Reset the logger level
        LOGGER.setLevel(log_level)

        log_mock.assert_called_with('Alerts:\n%s', '[\n "success!!"\n]')

    @patch('stream_alert.rule_processor.handler.load_stream_payload')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.load_sources')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_no_payload_class(self, extract_mock, load_sources_mock,
                                  load_payload_mock):
        """StreamAlert Class - Run, No Payload Class"""
        extract_mock.return_value = ('blah', 'entity')
        load_sources_mock.return_value = True
        load_payload_mock.return_value = None

        self.__sa_handler.run({'Records': ['record']})

        load_payload_mock.assert_called()

    @patch('stream_alert.rule_processor.handler.LOGGER')
    @mock_kinesis
    def test_firehose_record_delivery(self, mock_logging):
        """StreamAlert Class - Firehose Record Delivery"""
        self.__sa_handler.firehose_client = boto3.client(
            'firehose', region_name='us-east-1')

        test_event = convert_events_to_kinesis([
            # unit_test_simple_log
            {
                'unit_key_01': 1,
                'unit_key_02': 'test'
            },
            {
                'unit_key_01': 2,
                'unit_key_02': 'test'
            },
            # test_log_type_json_nested
            {
                'date': 'January 01, 3005',
                'unixtime': '32661446400',
                'host': 'my-host.name.website.com',
                'data': {
                    'super': 'secret'
                }
            }
        ])

        delivery_stream_names = [
            'streamalert_data_test_log_type_json_nested',
            'streamalert_data_unit_test_simple_log'
        ]

        # Setup mock delivery streams
        for delivery_stream in delivery_stream_names:
            self.__sa_handler.firehose_client.create_delivery_stream(
                DeliveryStreamName=delivery_stream,
                S3DestinationConfiguration={
                    'RoleARN': 'arn:aws:iam::123456789012:role/firehose_delivery_role',
                    'BucketARN': 'arn:aws:s3:::kinesis-test',
                    'Prefix': '{}/'.format(delivery_stream),
                    'BufferingHints': {
                        'SizeInMBs': 123,
                        'IntervalInSeconds': 124
                    },
                    'CompressionFormat': 'Snappy',
                })

        with patch.object(self.__sa_handler.firehose_client,
                          'put_record_batch') as firehose_mock:
            firehose_mock.return_value = {'FailedPutCount': 0}
            self.__sa_handler.run(test_event)

            firehose_mock.assert_called()
            assert_true(mock_logging.info.called)

    @patch('stream_alert.rule_processor.handler.LOGGER')
    def test_firehose_limit_record_size(self, mock_logging):
        """StreamAlert Class - Firehose - Record Size Check"""
        test_events = [
            # unit_test_simple_log
            {
                'unit_key_01': 1,
                'unit_key_02': 'test' * 250001  # is 4 bytes higher than max
            },
            {
                'unit_key_01': 2,
                'unit_key_02': 'test'
            },
            # test_log_type_json_nested
            {
                'date': 'January 01, 3005',
                'unixtime': '32661446400',
                'host': 'my-host.name.website.com',
                'data': {
                    'super': 'secret'
                }
            }
        ]

        self.__sa_handler._limit_record_size(test_events)

        # The oversized record should have been dropped from the list
        assert_equal(len(test_events), 2)
        assert_true(mock_logging.error.called)

    @patch('stream_alert.rule_processor.handler.LOGGER')
    @mock_kinesis
    def test_firehose_record_delivery_failure(self, mock_logging):
        """StreamAlert Class - Firehose Record Delivery - Failed PutRecord"""

        class MockFirehoseClient(object):
            @staticmethod
            def put_record_batch(**kwargs):
                return {
                    'FailedPutCount': len(kwargs.get('Records')),
                    'RequestResponses': [
                        {
                            'RecordId': '12345',
                            'ErrorCode': '300',
                            'ErrorMessage': 'Bad message!!!'
                        },
                    ]
                }

        self.__sa_handler.firehose_client = MockFirehoseClient()

        test_event = convert_events_to_kinesis([
            # unit_test_simple_log
            {
                'unit_key_01': 1,
                'unit_key_02': 'test'
            },
            {
                'unit_key_01': 2,
                'unit_key_02': 'test'
            },
            # test_log_type_json_nested
            {
                'date': 'January 01, 3005',
                'unixtime': '32661446400',
                'host': 'my-host.name.website.com',
                'data': {
                    'super': 'secret'
                }
            }
        ])

        self.__sa_handler.run(test_event)
        assert_true(mock_logging.error.called)

    @patch('stream_alert.rule_processor.handler.StreamThreatIntel.load_intelligence')
    def test_do_not_invoke_threat_intel(self, load_intelligence_mock):
        """StreamAlert Class - Invoke load_intelligence"""
        self.__sa_handler = StreamAlert(get_mock_context(), False)
        load_intelligence_mock.assert_called()

    def test_firehose_sanitize_keys(self):
        """StreamAlert Class - Firehose - Sanitize Keys"""
        # test_log_type_json_nested
        test_event = {
            'date': 'January 01, 3005',
            'unixtime': '32661446400',
            'host': 'my-host.name.website.com',
            'data': {
                'super-duper': 'secret',
                'sanitize_me': 1,
                'example-key': 1,
                'moar**data': 2,
                'even.more': 3
            }
        }

        expected_sanitized_event = {
            'date': 'January 01, 3005',
            'unixtime': '32661446400',
            'host': 'my-host.name.website.com',
            'data': {
                'super_duper': 'secret',
                'sanitize_me': 1,
                'example_key': 1,
                'moar__data': 2,
                'even_more': 3
            }
        }

        sanitized_event = self.__sa_handler.sanitize_keys(test_event)
        assert_equal(sanitized_event, expected_sanitized_event)

    def test_firehose_segment_records_by_size(self):
        """StreamAlert Class - Firehose - Segment Large Records"""
        record_batch = [
            # unit_test_simple_log
            {
                'unit_key_01': 2,
                'unit_key_02': 'testtest' * 10000
            } for _ in range(100)
        ]

        sized_batches = []
        for sized_batch in self.__sa_handler._segment_records_by_size(record_batch):
            sized_batches.append(sized_batch)

        assert_true(len(str(sized_batches[0])) < 4000000)
        assert_equal(len(sized_batches), 4)
        assert_true(isinstance(sized_batches[3][0], dict))

    @mock_kinesis
    def test_firehose_record_delivery_disabled_logs(self):
        """StreamAlert Class - Firehose Record Delivery - Disabled Logs"""
        self.__sa_handler.firehose_client = boto3.client(
            'firehose', region_name='us-east-1')

        test_event = convert_events_to_kinesis([
            # unit_test_simple_log
            {
                'unit_key_01': 2,
                'unit_key_02': 'testtest'
            } for _ in range(10)
        ])

        delivery_stream_names = ['streamalert_data_unit_test_simple_log']

        # Setup mock delivery streams
        for delivery_stream in delivery_stream_names:
            self.__sa_handler.firehose_client.create_delivery_stream(
                DeliveryStreamName=delivery_stream,
                S3DestinationConfiguration={
                    'RoleARN': 'arn:aws:iam::123456789012:role/firehose_delivery_role',
                    'BucketARN': 'arn:aws:s3:::kinesis-test',
                    'Prefix': '{}/'.format(delivery_stream),
                    'BufferingHints': {
                        'SizeInMBs': 123,
                        'IntervalInSeconds': 124
                    },
                    'CompressionFormat': 'Snappy',
                })

        with patch.object(self.__sa_handler.firehose_client,
                          'put_record_batch') as firehose_mock:
            firehose_mock.return_value = {'FailedPutCount': 0}
            self.__sa_handler.config['global']['infrastructure']['firehose'] = {
                'disabled_logs': ['unit_test_simple_log']
            }
            self.__sa_handler.run(test_event)
            firehose_mock.assert_not_called()

    @patch('stream_alert.rule_processor.handler.LOGGER')
    @mock_kinesis
    def test_firehose_record_delivery_client_error(self, mock_logging):
        """StreamAlert Class - Firehose Record Delivery - Client Error"""
        self.__sa_handler.firehose_client = boto3.client(
            'firehose', region_name='us-east-1')

        test_events = [
            # unit_test_simple_log
            {
                'unit_key_01': 2,
                'unit_key_02': 'testtest'
            } for _ in range(10)
        ]

        self.__sa_handler._firehose_request_helper('invalid_stream', test_events)

        # The logged error resembles: 'Client Error ... An error occurred
        # (ResourceNotFoundException) when calling the PutRecordBatch operation:
        # Stream invalid_stream under account 123456789012 not found.'
        assert_true(mock_logging.error.called)
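# The two sketches below are simplified stand-ins for the handler helpers the
# tests above exercise (_segment_records_by_size and sanitize_keys). They are
# written from the behavior the assertions imply, not from the handler itself,
# so the names, signatures, and exact limits are assumptions.
import re


def segment_records_by_size_sketch(records, max_batch_bytes=4000000, max_batch_count=500):
    """Yield batches that stay under Firehose PutRecordBatch limits.

    Firehose caps a PutRecordBatch request at 4 MB and 500 records, which is
    why the segmentation test asserts each batch serializes to < 4000000 bytes.
    """
    batch, batch_size = [], 0
    for record in records:
        record_size = len(str(record))
        if batch and (batch_size + record_size > max_batch_bytes
                      or len(batch) == max_batch_count):
            yield batch
            batch, batch_size = [], 0
        batch.append(record)
        batch_size += record_size
    if batch:
        yield batch


def sanitize_keys_sketch(record):
    """Recursively replace special characters in keys with underscores.

    Matches the expectations in test_firehose_sanitize_keys above, e.g.
    'super-duper' -> 'super_duper' and 'moar**data' -> 'moar__data'.
    """
    sanitized = {}
    for key, value in record.items():
        clean_key = re.sub(r'[^a-zA-Z0-9_]', '_', key)
        sanitized[clean_key] = (sanitize_keys_sketch(value)
                                if isinstance(value, dict) else value)
    return sanitized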
class TestStreamAlert(object):
    """Test class for StreamAlert class"""

    def __init__(self):
        self.__sa_handler = None

    @patch('stream_alert.rule_processor.handler.load_config',
           lambda: load_config('tests/unit/conf/'))
    def setup(self):
        """Setup before each method"""
        self.__sa_handler = StreamAlert(get_mock_context(), False)

    def test_run_no_records(self):
        """StreamAlert Class - Run, No Records"""
        passed = self.__sa_handler.run({'Records': []})
        assert_false(passed)

    @staticmethod
    @raises(ConfigError)
    def test_run_config_error():
        """StreamAlert Class - Run, Config Error"""
        mock = mock_open(read_data='non-json string that will raise an exception')
        with patch('__builtin__.open', mock):
            StreamAlert(get_mock_context())

    def test_get_alerts(self):
        """StreamAlert Class - Get Alerts"""
        default_list = ['alert1', 'alert2']
        self.__sa_handler._alerts = default_list

        assert_list_equal(self.__sa_handler.get_alerts(), default_list)

    @patch('stream_alert.rule_processor.handler.StreamClassifier.load_sources')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_no_sources(self, extract_mock, load_sources_mock):
        """StreamAlert Class - Run, No Loaded Sources"""
        extract_mock.return_value = ('lambda', 'entity')
        load_sources_mock.return_value = None
        self.__sa_handler.run({'Records': ['record']})

        load_sources_mock.assert_called_with('lambda', 'entity')

    @patch('logging.Logger.error')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_bad_service(self, extract_mock, log_mock):
        """StreamAlert Class - Run, Bad Service"""
        extract_mock.return_value = ('', 'entity')
        self.__sa_handler.run({'Records': ['record']})

        log_mock.assert_called_with(
            'No valid service found in payload\'s raw record. '
            'Skipping record: %s', 'record')

    @patch('logging.Logger.error')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_bad_entity(self, extract_mock, log_mock):
        """StreamAlert Class - Run, Bad Entity"""
        extract_mock.return_value = ('kinesis', '')
        self.__sa_handler.run({'Records': ['record']})

        log_mock.assert_called_with(
            'Unable to extract entity from payload\'s raw record for '
            'service %s. Skipping record: %s', 'kinesis', 'record')

    @patch('stream_alert.rule_processor.handler.load_stream_payload')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.load_sources')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_load_payload_bad(self, extract_mock, load_sources_mock,
                                  load_payload_mock):
        """StreamAlert Class - Run, Loaded Payload Fail"""
        extract_mock.return_value = ('lambda', 'entity')
        load_sources_mock.return_value = True

        self.__sa_handler.run({'Records': ['record']})

        load_payload_mock.assert_called_with('lambda', 'entity', 'record')

    @patch('stream_alert.rule_processor.handler.StreamRules.process')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_with_alert(self, extract_mock, rules_mock):
        """StreamAlert Class - Run, With Alert"""
        extract_mock.return_value = ('kinesis', 'unit_test_default_stream')
        rules_mock.return_value = ['success!!']

        passed = self.__sa_handler.run(get_valid_event())

        assert_true(passed)

    @patch('logging.Logger.debug')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_no_alerts(self, extract_mock, log_mock):
        """StreamAlert Class - Run, With No Alerts"""
        extract_mock.return_value = ('kinesis', 'unit_test_default_stream')
        self.__sa_handler.run(get_valid_event())

        calls = [
            call('Processed %d valid record(s) that resulted in %d alert(s).', 1, 0),
            call('Invalid record count: %d', 0),
            call('%s alerts triggered', 0)
        ]

        log_mock.assert_has_calls(calls)

    @patch('logging.Logger.error')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_invalid_data(self, extract_mock, log_mock):
        """StreamAlert Class - Run, Invalid Data"""
        extract_mock.return_value = ('kinesis', 'unit_test_default_stream')

        event = get_valid_event()

        # Replace the good log data with bad data
        event['Records'][0]['kinesis']['data'] = base64.b64encode('{"bad": "data"}')

        # Swap out the alias so the logging occurs
        self.__sa_handler.env['lambda_alias'] = 'production'

        self.__sa_handler.run(event)

        assert_equal(log_mock.call_args[0][0],
                     'Record does not match any defined schemas: %s\n%s')
        assert_equal(log_mock.call_args[0][2], '{"bad": "data"}')

    @patch('stream_alert.rule_processor.sink.StreamSink.sink')
    @patch('stream_alert.rule_processor.handler.StreamRules.process')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_send_alerts(self, extract_mock, rules_mock, sink_mock):
        """StreamAlert Class - Run, Send Alert"""
        extract_mock.return_value = ('kinesis', 'unit_test_default_stream')
        rules_mock.return_value = ['success!!']

        # Set send_alerts to True so the sink happens
        self.__sa_handler.enable_alert_processor = True

        # Swap out the alias so the logging occurs
        self.__sa_handler.env['lambda_alias'] = 'production'

        self.__sa_handler.run(get_valid_event())

        sink_mock.assert_called_with(['success!!'])

    @patch('logging.Logger.debug')
    @patch('stream_alert.rule_processor.handler.StreamRules.process')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_debug_log_alert(self, extract_mock, rules_mock, log_mock):
        """StreamAlert Class - Run, Debug Log Alert"""
        extract_mock.return_value = ('kinesis', 'unit_test_default_stream')
        rules_mock.return_value = ['success!!']

        # Cache the logger level
        log_level = LOGGER.getEffectiveLevel()

        # Increase the logger level to debug
        LOGGER.setLevel(logging.DEBUG)

        self.__sa_handler.run(get_valid_event())

        # Reset the logger level
        LOGGER.setLevel(log_level)

        log_mock.assert_called_with('Alerts:\n%s', '[\n "success!!"\n]')

    @patch('stream_alert.rule_processor.handler.load_stream_payload')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.load_sources')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_no_payload_class(self, extract_mock, load_sources_mock,
                                  load_payload_mock):
        """StreamAlert Class - Run, No Payload Class"""
        extract_mock.return_value = ('blah', 'entity')
        load_sources_mock.return_value = True
        load_payload_mock.return_value = None

        self.__sa_handler.run({'Records': ['record']})

        load_payload_mock.assert_called()
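# A minimal sketch of the behavior test_run_config_error depends on. The real
# load_config lives in the shared config module, so the file path and the exact
# ConfigError wrapping are assumptions here: parsing a conf file that is not
# valid JSON should surface as a ConfigError rather than a bare ValueError.
def load_config_sketch(path='conf/config.json'):
    with open(path) as config_file:
        try:
            return json.load(config_file)
        except ValueError as err:
            # mock_open supplies 'non-json string...' as read_data above, so
            # json.load raises ValueError, which is re-raised as ConfigError
            raise ConfigError('Invalid JSON in \'{}\': {}'.format(path, err))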
class RuleProcessorTester(object):
    """Class to encapsulate testing the rule processor"""

    def __init__(self, context, config, print_output):
        """RuleProcessorTester initializer

        Args:
            print_output (bool): Whether this processor test should print results
                to stdout. This is set to False when the alert processor is
                explicitly being tested alone, and set to True for rule processor
                tests and end-to-end tests. Warnings and errors captured during
                rule processor testing will still be written to stdout regardless
                of this setting.
        """
        # Create the RuleProcessor. Passing a mocked context object with fake
        # values and False for suppressing sending of alerts to alert processor
        self.processor = StreamAlert(context)
        self.cli_config = config
        # Use a list of status_messages to store pass/fail/warning info
        self.status_messages = []
        self.total_tests = 0
        self.all_tests_passed = True
        self.print_output = print_output

        # Configure mocks for Firehose and DDB
        helpers.setup_mock_firehose_delivery_streams(config)
        helpers.setup_mock_dynamodb_ioc_table(config)

        # Create a cache map of parsers to parser classes
        self.parsers = {}

        # Patch the tmp shredding so as to not slow down testing
        patch('stream_alert.rule_processor.payload.S3Payload._shred_temp_directory').start()

        # Patch random_bool to always return True
        patch('helpers.base.random_bool', return_value=True).start()

    def test_processor(self, rules_filter, files_filter, validate_only):
        """Perform integration tests for the 'rule' Lambda function

        Args:
            rules_filter (set): A collection of rules to filter on, passed in
                by the user via the CLI using the --test-rules option.
            files_filter (set): A collection of files to filter on, passed in
                by the user via the CLI using the --test-files option.
            validate_only (bool): If True, validation of test records will occur
                without the rules engine being applied to events.

        Yields:
            tuple (bool, list) or None: If testing rules, this yields a tuple
                containing a boolean of test status and a list of alerts to run
                through the alert processor. If validating test records only,
                this does not yield.
        """
        test_file_info = self._filter_files(
            helpers.get_rule_test_files(TEST_EVENTS_DIR), files_filter)

        for name in sorted(test_file_info):
            path = test_file_info[name]

            events, error = helpers.load_test_file(path)
            if error is not None:
                self.all_tests_passed = False
                self.status_messages.append(StatusMessage(StatusMessage.WARNING, error))
                continue

            print_header = True
            for test_event in events:
                self.total_tests += 1

                if self._detect_old_test_event(test_event):
                    self.all_tests_passed = False
                    message = (
                        'Detected old format for test event in file \'{}.json\'. '
                        'Please visit https://streamalert.io/rule-testing.html '
                        'for information on the new format and update your '
                        'test events accordingly.'.format(name))
                    self.status_messages.append(
                        StatusMessage(StatusMessage.FAILURE, message))
                    continue

                if not self.check_keys(test_event):
                    self.all_tests_passed = False
                    continue

                # Check if there are any rule filters in place, and if the
                # current test event should be executed per the filter
                if rules_filter and set(test_event['trigger_rules']).isdisjoint(rules_filter):
                    self.total_tests -= 1
                    continue

                self.apply_helpers(test_event)

                if 'override_record' in test_event:
                    self.apply_template(test_event)

                formatted_record = helpers.format_lambda_test_record(test_event)

                # If this test is to validate the schema only, continue the loop
                # and do not yield results on the rule tests below
                if validate_only or test_event.get('validate_schema_only'):
                    if self._validate_test_record(
                            name, test_event, formatted_record, print_header) is False:
                        self.all_tests_passed = False
                else:
                    yield self._run_rule_tests(
                        name, test_event, formatted_record, print_header)

                print_header = False

        # Report on the final test results
        self.report_output_summary()

    def _filter_files(self, file_info, files_filter):
        """Filter the test files based on input from the user

        Args:
            file_info (dict): Information about test files on disk, where the
                key is the base name of the file and the value is the relative
                path to the file
            files_filter (set): A collection of files to filter tests on

        Returns:
            dict: A modified version of the `file_info` arg with pared down values
        """
        if not files_filter:
            return file_info

        files_filter = {os.path.splitext(name)[0] for name in files_filter}

        file_info = {
            name: path
            for name, path in file_info.iteritems()
            if os.path.splitext(name)[0] in files_filter
        }

        filter_diff = set(files_filter).difference(set(file_info))
        message_template = 'No test events file found with base name \'{}\''
        for missing_file in filter_diff:
            self.status_messages.append(
                StatusMessage(StatusMessage.WARNING,
                              message_template.format(missing_file)))

        return file_info

    def _validate_test_record(self, file_name, test_event, formatted_record,
                              print_header_line):
        """Function to validate test records and log any errors

        Args:
            file_name (str): The base name of the test event file.
            test_event (dict): A single test event containing the record and
                other detail
            formatted_record (dict): A dictionary that includes the 'data' from
                the test record, formatted into a structure that resembles how
                an incoming record from a service would be formatted.
                See test/integration/templates for examples of how each service
                formats records.
            print_header_line (bool): Indicates if this is the first record from
                a test file, and therefore we should print some header information
        """
        service, entity = self.processor.classifier.extract_service_and_entity(
            formatted_record)

        if not self.processor.classifier.load_sources(service, entity):
            return False

        # Create the StreamPayload to use for encapsulating parsed info
        payload = load_stream_payload(service, entity, formatted_record)
        if not payload:
            return False

        if print_header_line:
            print '\n{}'.format(file_name)

        for record in payload.pre_parse():
            self.processor.classifier.classify_record(record)
            if not record.valid:
                self.all_tests_passed = False
                self.analyze_record_delta(file_name, test_event)

            report_output(record.valid, [
                '[log=\'{}\']'.format(record.log_source or 'unknown'),
                'validation',
                record.service(),
                test_event['description']
            ])

    def _run_rule_tests(self, file_name, test_event, formatted_record,
                        print_header_line):
        """Run tests on a test record for a given rule

        Args:
            file_name (str): The base name of the test event file.
            test_event (dict): The loaded test event from json
            formatted_record (dict): A dictionary that includes the 'data' from
                the test record, formatted into a structure that resembles how
                an incoming record from a service would be formatted.
                See test/integration/templates for examples of how each service
                formats records.
            print_header_line (bool): Indicates if this is the first record from
                a test file, and therefore we should print some header information

        Returns:
            list: alerts that were generated from this test event
        """
        event = {'Records': [formatted_record]}

        expected_alert_count = len(test_event['trigger_rules'])

        # Run tests on the formatted record
        alerts, all_records_matched_schema = self.test_rule(event)

        # Get a list of any rules that triggered but are not defined
        # in the 'trigger_rules'
        unexpected_alerts = []
        disabled_rules = [
            item for item in test_event['trigger_rules']
            if rule.Rule.get_rule(item).disabled
        ]

        expected_alert_count -= len(disabled_rules)

        triggers = set(test_event['trigger_rules']) - set(disabled_rules)

        # We only want alerts for the specific rule being tested
        # (if trigger_rules are defined)
        if triggers:
            unexpected_alerts = [
                alert for alert in alerts if alert.rule_name not in triggers
            ]
            alerts = [alert for alert in alerts if alert.rule_name in triggers]

        alerted_properly = (len(alerts) == expected_alert_count) and not unexpected_alerts
        current_test_passed = alerted_properly and all_records_matched_schema
        self.all_tests_passed = current_test_passed and self.all_tests_passed

        # Print rule name for section header, but only if we get to a point
        # where there is a record to actually be tested. This avoids
        # potentially blank sections
        if print_header_line and (alerts or self.print_output):
            print '\n{}'.format(file_name)

        if self.print_output:
            disabled_output = ''
            if disabled_rules:
                disabled_output = ',disabled={}'.format(len(disabled_rules))
            report_output(current_test_passed, [
                '[trigger={}{}]'.format(expected_alert_count, disabled_output),
                'rule',
                test_event['service'],
                test_event['description']
            ])

        # Add the status of the rule to the messages list
        if not all_records_matched_schema:
            self.analyze_record_delta(file_name, test_event)
        elif not alerted_properly:
            message = ('Test failure: [{}.json] Test event with description '
                       '\'{}\'').format(file_name, test_event['description'])
            if alerts and not triggers:
                # If there was a failure due to alerts triggering for a test event
                # that does not have any trigger_rules configured
                context = 'is triggering the following rules but should not trigger at all: {}'
                trigger_rules = ', '.join('\'{}\''.format(alert.rule_name)
                                          for alert in alerts)
                message = '{} {}'.format(message, context.format(trigger_rules))
            elif unexpected_alerts:
                # If there was a failure due to alerts triggering for other rules
                # outside of the rules defined in the trigger_rules list for the event
                context = 'is triggering the following rules but should not be: {}'
                bad_rules = ', '.join('\'{}\''.format(alert.rule_name)
                                      for alert in unexpected_alerts)
                message = '{} {}'.format(message, context.format(bad_rules))
            elif expected_alert_count != len(alerts):
                # If there was a failure due to alerts NOT triggering for 1+ rules
                # defined in the trigger_rules list for the event
                context = 'did not trigger the following rules: {}'
                non_triggered_rules = ', '.join(
                    '\'{}\''.format(rule_name) for rule_name in triggers
                    if rule_name not in [alert.rule_name for alert in alerts])
                message = '{} {}'.format(message, context.format(non_triggered_rules))
            else:
                # If there was a failure for some other reason, use a default message
                message = 'Rule failure: [{}.json] {}'.format(
                    file_name, test_event['description'])

            self.status_messages.append(StatusMessage(StatusMessage.FAILURE, message))

        # Return the alerts back to caller
        return alerts

    @staticmethod
    def _detect_old_test_event(test_event):
        """Check if the test event uses the old format

        Args:
            test_event (dict): The loaded test event from json

        Returns:
            bool: True if a legacy test file is detected, False otherwise
        """
        record_keys = set(test_event)
        return (not {'log', 'trigger_rules'}.issubset(record_keys)
                and 'trigger' in record_keys)

    def check_keys(self, test_event):
        """Check if the test event contains the required keys

        Args:
            test_event (dict): The loaded test event from json

        Returns:
            bool: True if the proper keys are present
        """
        required_keys = {'description', 'log', 'service', 'source', 'trigger_rules'}

        record_keys = set(test_event)
        if not required_keys.issubset(record_keys):
            req_key_diff = required_keys.difference(record_keys)
            missing_keys = ', '.join('\'{}\''.format(key) for key in req_key_diff)
            message = 'Missing required key(s) in log: {}'.format(missing_keys)
            self.status_messages.append(StatusMessage(StatusMessage.FAILURE, message))
            return False

        input_data_keys = {'data', 'override_record'}
        if not record_keys & input_data_keys:
            missing_keys = ', '.join('\'{}\''.format(key) for key in input_data_keys)
            message = 'Missing one of the following keys in log: {}'.format(missing_keys)
            self.status_messages.append(StatusMessage(StatusMessage.FAILURE, message))
            return False

        optional_keys = {'compress', 'validate_schema_only'}

        key_diff = record_keys.difference(required_keys | optional_keys | input_data_keys)

        # Log a warning if there are extra keys declared in the test log
        if key_diff:
            extra_keys = ', '.join('\'{}\''.format(key) for key in key_diff)
            message = 'Additional unnecessary keys in log: {}'.format(extra_keys)
            # Remove the key(s) and just warn the user that they are extra
            record_keys.difference_update(key_diff)
            self.status_messages.append(StatusMessage(StatusMessage.WARNING, message))

        return record_keys.issubset(required_keys | optional_keys | input_data_keys)

    def apply_template(self, test_event):
        """Apply default values to the given test event

        Args:
            test_event (dict): The loaded test event
        """
        event_log = self.cli_config['logs'].get(test_event['log'])
        parser = event_log['parser']
        schema = event_log['schema']
        configuration = event_log.get('configuration', {})

        # Add envelope keys
        schema.update(configuration.get('envelope_keys', {}))

        # Setup the parser to access default optional values
        self.parsers[parser] = self.parsers.get(parser, get_parser(parser))

        # Apply default values based on the declared schema
        default_test_event = {
            key: self.parsers[parser].default_optional_values(value)
            for key, value in schema.iteritems()
        }

        # Fill in the fields left out of the 'override_record' field,
        # and update the test event with a full 'data' key
        default_test_event.update(test_event['override_record'])

        test_event['data'] = default_test_event

    @staticmethod
    def apply_helpers(test_record):
        """Detect and apply helper functions to test event data

        Helpers are declared in test fixtures via the following keyword:
        "<helper:helper_name>"

        Supported helper functions:
            last_hour: return the current epoch time minus 60 seconds to pass
                the last_hour rule helper.

        Args:
            test_record (dict): loaded fixture file JSON as a dict.
        """
        # Declare all helper functions here; they should always return a string
        record_helpers = {'last_hour': lambda: str(int(time.time()) - 60)}
        helper_regex = re.compile(r'<helper:(?P<helper>\w+)>')

        def find_and_apply_helpers(test_record):
            """Apply any helpers to the passed in test_record"""
            for key, value in test_record.iteritems():
                if isinstance(value, (str, unicode)):
                    test_record[key] = re.sub(
                        helper_regex,
                        lambda match: record_helpers[match.group('helper')](),
                        test_record[key])
                elif isinstance(value, dict):
                    find_and_apply_helpers(test_record[key])

        find_and_apply_helpers(test_record)

    def report_output_summary(self):
        """Helper function to print the summary results of all tests"""
        failure_messages = [
            item for item in self.status_messages
            if item.type == StatusMessage.FAILURE
        ]
        warning_messages = [
            item for item in self.status_messages
            if item.type == StatusMessage.WARNING
        ]
        passed_tests = self.total_tests - len(failure_messages)

        # Print some lines at the bottom of output to make it more readable
        # This occurs here so there is always space and not only when the
        # successful test info prints
        print '\n\n'

        # Only print success info if we explicitly want to print output
        # but always print any errors or warnings below
        if self.print_output:
            # Print a message indicating how many of the total tests passed
            LOGGER_CLI.info('%s(%d/%d) Successful Tests%s', COLOR_GREEN,
                            passed_tests, self.total_tests, COLOR_RESET)

        # Check if there were failed tests and report on them appropriately
        if failure_messages:
            # Print a message indicating how many of the total tests failed
            LOGGER_CLI.error('%s(%d/%d) Failures%s', COLOR_RED,
                             len(failure_messages), self.total_tests, COLOR_RESET)

            # Iterate over the messages in the failed list and report on them
            for index, failure in enumerate(failure_messages, start=1):
                LOGGER_CLI.error('%s(%d/%d) %s%s', COLOR_RED, index,
                                 len(failure_messages), failure.message, COLOR_RESET)

        # Check if there were any warnings and report on them
        if warning_messages:
            warning_count = len(warning_messages)
            LOGGER_CLI.warn('%s%d Warning%s%s', COLOR_YELLOW, warning_count,
                            ('s' if warning_count > 1 else ''), COLOR_RESET)

            for index, warning in enumerate(warning_messages, start=1):
                LOGGER_CLI.warn('%s(%d/%d) %s%s', COLOR_YELLOW, index,
                                warning_count, warning.message, COLOR_RESET)

    def test_rule(self, record):
        """Feed formatted records into StreamAlert and check for alerts

        Args:
            record (dict): A formatted event that reflects the structure
                expected as input to the Lambda function.

        Returns:
            list: alerts that hit for this rule
            bool: False if errors occurred during processing
        """
        # Clear out any old alerts or errors from the previous test run
        # pylint: disable=protected-access
        del self.processor._alerts[:]
        self.processor._failed_record_count = 0

        # Run the rule processor
        all_records_matched_schema = self.processor.run(record)

        return self.processor.alerts, all_records_matched_schema

    def check_log_declared_in_sources(self, base_message, test_event):
        """A simple check to see if this log type is defined in the sources
        for the service

        Args:
            base_message (str): Base error message to be reported with extra context
            test_event (dict): Actual record data being tested

        Returns:
            bool: False if the log type is not in the sources list, True if it is
        """
        source = test_event['source']
        service = test_event['service']
        log = test_event['log'].split(':')[0]
        if log not in self.cli_config['sources'][service][source]['logs']:
            message = (
                'The \'sources.json\' file does not include the log type \'{}\' '
                'in the list of logs for this service & entity (\'{}:{}\').')
            message = '{} {}'.format(base_message,
                                     message.format(log, service, source))
            self.status_messages.append(StatusMessage(StatusMessage.FAILURE, message))
            return False

        return True

    def analyze_record_delta(self, file_name, test_event):
        """Provide some additional context on why this test failed. This will
        perform some analysis of the test record to determine which keys are
        missing or which unnecessary keys are causing the test to fail. Any
        errors are appended to a list of errors so they can be printed at
        the end of the test run.

        Args:
            file_name (str): Name of the file containing the test event
            test_event (dict): Actual record data being tested
        """
        base_message = ('Invalid test event in file \'{}.json\' with description '
                        '\'{}\'.'.format(file_name, test_event['description']))
        if not self.check_log_declared_in_sources(base_message, test_event):
            return

        log_type = test_event['log']
        if log_type not in self.cli_config['logs']:
            message = ('{} Log (\'{}\') declared in test event does not exist '
                       'in logs.json'.format(base_message, log_type))
            self.status_messages.append(StatusMessage(StatusMessage.FAILURE, message))
            return

        config_log_info = self.cli_config['logs'][log_type]
        schema_keys = config_log_info['schema']

        envelope_keys = config_log_info.get('configuration', {}).get('envelope_keys')
        if envelope_keys:
            if self.report_envelope_key_error(base_message, envelope_keys,
                                              test_event['data']):
                return

        # Check if a json path is used for nested records
        json_path = config_log_info.get('configuration', {}).get('json_path')
        if json_path:
            records_jsonpath = jsonpath_rw.parse(json_path)
            for match in records_jsonpath.find(test_event['data']):
                self.report_record_delta(base_message, log_type, schema_keys,
                                         match.value)
            return

        self.report_record_delta(base_message, log_type, schema_keys,
                                 test_event['data'])

    def report_envelope_key_error(self, base_message, envelope_keys, test_record):
        """Provide context for failures related to envelope key issues.

        Args:
            base_message (str): Base error message to be reported with extra context
            envelope_keys (list): A collection of the envelope keys for this
                nested schema
            test_record (dict): Actual record being tested - this could be one
                of many records extracted using jsonpath_rw

        Returns:
            bool: True if an envelope key error was reported, False otherwise
        """
        missing_env_key_list = set(envelope_keys).difference(set(test_record))
        if missing_env_key_list:
            missing_key_list = ', '.join('\'{}\''.format(key)
                                         for key in missing_env_key_list)
            message = ('{} Data is invalid due to missing envelope key(s) in '
                       'test record: {}.'.format(base_message, missing_key_list))
            self.status_messages.append(StatusMessage(StatusMessage.FAILURE, message))
            return True

        return False

    def report_record_delta(self, base_message, log_type, schema_keys, test_record):
        """Provide context on why this specific record failed.

        Args:
            base_message (str): Base error message to be reported with extra context
            log_type (str): Type of log being tested
            schema_keys (set): A collection of the keys from the schema
            test_record (dict): Actual record being tested - this could be one
                of many records extracted using jsonpath_rw
        """
        optional_keys = set(
            self.cli_config['logs'][log_type].get('configuration', {}).get(
                'optional_top_level_keys', {}))

        min_req_record_schema_keys = set(schema_keys).difference(optional_keys)

        test_record_keys = set(test_record)

        schema_diff = min_req_record_schema_keys.difference(test_record_keys)
        if schema_diff:
            missing_key_list = ', '.join('\'{}\''.format(key) for key in schema_diff)
            message = ('{} Data is invalid due to missing key(s) in test record: '
                       '{}.'.format(base_message, missing_key_list))
            self.status_messages.append(StatusMessage(StatusMessage.FAILURE, message))
            return

        unexpected_keys = test_record_keys.difference(schema_keys)
        if unexpected_keys:
            unexpected_key_list = ', '.join('\'{}\''.format(key)
                                            for key in unexpected_keys)
            message = ('{} Data is invalid due to unexpected key(s) in test record: '
                       '{}.'.format(base_message, unexpected_key_list))
            self.status_messages.append(StatusMessage(StatusMessage.FAILURE, message))
            return

        # Add a generic error message if we cannot determine what the issue is
        message = '{} Please look for any errors above.'.format(base_message)
        self.status_messages.append(StatusMessage(StatusMessage.FAILURE, message))
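# Example of the substitution apply_helpers performs (the record contents here
# are illustrative): a fixture value like "<helper:last_hour>" is replaced with
# a recent epoch timestamp so time-based rules still match at test time.
def demo_apply_helpers():
    import re
    import time

    record_helpers = {'last_hour': lambda: str(int(time.time()) - 60)}
    helper_regex = re.compile(r'<helper:(?P<helper>\w+)>')

    record = {'eventTime': '<helper:last_hour>', 'eventName': 'ConsoleLogin'}
    record['eventTime'] = helper_regex.sub(
        lambda match: record_helpers[match.group('helper')](),
        record['eventTime'])
    # record['eventTime'] is now an epoch string from roughly 60 seconds ago
    return record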
class RuleProcessorTester(object):
    """Class to encapsulate testing the rule processor"""

    def __init__(self, context, print_output):
        """RuleProcessorTester initializer

        Args:
            print_output (bool): Whether this processor test should print results
                to stdout. This is set to False when the alert processor is
                explicitly being tested alone, and set to True for rule processor
                tests and end-to-end tests. Warnings and errors captured during
                rule processor testing will still be written to stdout regardless
                of this setting.
        """
        # Create the RuleProcessor. Passing a mocked context object with fake
        # values and False for suppressing sending of alerts to alert processor
        self.processor = StreamAlert(context, False)
        # Use a list of status_messages to store pass/fail/warning info
        self.status_messages = []
        self.total_tests = 0
        self.all_tests_passed = True
        self.print_output = print_output

    def test_processor(self, filter_rules, validate_only=False):
        """Perform integration tests for the 'rule' Lambda function

        Args:
            filter_rules (list|None): Specific rule names (or None) to restrict
                testing to. This is passed in from the CLI using the --rules option.
            validate_only (bool): If True, validation of test records will occur
                without the rules engine being applied to events.

        Yields:
            tuple (bool, list) or None: If testing rules, this yields a tuple
                containing a boolean of test status and a list of alerts to run
                through the alert processor. If validating test records only,
                this does not yield.
        """
        for rule_name, contents in self._get_rule_test_files(
                filter_rules, validate_only):
            # Go over the records and test the applicable rule
            for index, test_record in enumerate(contents.get('records')):
                self.total_tests += 1

                if not self.check_keys(rule_name, test_record):
                    self.all_tests_passed = False
                    continue

                self.apply_helpers(test_record)

                print_header_line = index == 0

                formatted_record = helpers.format_lambda_test_record(test_record)

                if validate_only:
                    self._validate_test_records(rule_name, test_record,
                                                formatted_record, print_header_line)
                    continue

                yield self._run_rule_tests(rule_name, test_record,
                                           formatted_record, print_header_line)

        # Report on the final test results
        self.report_output_summary()

    def _validate_test_records(self, rule_name, test_record, formatted_record,
                               print_header_line):
        """Function to validate test records and log any errors

        Args:
            rule_name (str): The rule name being tested
            test_record (dict): A single record to test
            formatted_record (dict): A dictionary that includes the 'data' from
                the test record, formatted into a structure that resembles how
                an incoming record from a service would be formatted.
                See test/integration/templates for examples of how each service
                formats records.
        """
        service, entity = self.processor.classifier.extract_service_and_entity(
            formatted_record)

        if not self.processor.classifier.load_sources(service, entity):
            self.all_tests_passed = False
            return

        # Create the StreamPayload to use for encapsulating parsed info
        payload = load_stream_payload(service, entity, formatted_record)
        if not payload:
            self.all_tests_passed = False
            return

        if print_header_line:
            print '\n{}'.format(rule_name)

        for record in payload.pre_parse():
            self.processor.classifier.classify_record(record)
            if not record.valid:
                self.all_tests_passed = False
                self.analyze_record_delta(rule_name, test_record)

            report_output(record.valid, [
                '[log=\'{}\']'.format(record.log_source or 'unknown'),
                'validation',
                record.service(),
                test_record['description']
            ])

    def _run_rule_tests(self, rule_name, test_record, formatted_record,
                        print_header_line):
        """Run tests on a test record for a given rule

        Args:
            rule_name (str): The name of the rule being tested.
            test_record (dict): The loaded test event from json
            formatted_record (dict): A dictionary that includes the 'data' from
                the test record, formatted into a structure that resembles how
                an incoming record from a service would be formatted.
                See test/integration/templates for examples of how each service
                formats records.
            print_header_line (bool): Indicates if this is the first record from
                a test file, and therefore we should print some header information

        Returns:
            list: alerts that were generated from this test event
        """
        event = {'Records': [formatted_record]}

        # Run tests on the formatted record
        alerts, expected_alerts, all_records_matched_schema = self.test_rule(
            rule_name, test_record, event)

        alerted_properly = (len(alerts) == expected_alerts)
        current_test_passed = alerted_properly and all_records_matched_schema
        self.all_tests_passed = current_test_passed and self.all_tests_passed

        # Print rule name for section header, but only if we get to a point
        # where there is a record to actually be tested. This avoids
        # potentially blank sections
        if print_header_line and (alerts or self.print_output):
            print '\n{}'.format(rule_name)

        if self.print_output:
            report_output(current_test_passed, [
                '[trigger={}]'.format(expected_alerts),
                'rule',
                test_record['service'],
                test_record['description']
            ])

        # Add the status of the rule to the messages list
        if not all_records_matched_schema:
            self.analyze_record_delta(rule_name, test_record)
        elif not alerted_properly:
            message = 'Rule failure: {}'.format(test_record['description'])
            self.status_messages.append(
                StatusMessage(StatusMessage.FAILURE, rule_name, message))

        # Return the alerts back to caller
        return alerts

    def _get_rule_test_files(self, filter_rules, validate_only):
        """Helper to get rule files to be tested

        Args:
            filter_rules (list|None): List of specific rule names or file names
                (or None) that has been fed in from the CLI to restrict testing to

        Yields:
            str: rule name
            dict: loaded json contents of the respective test event file
        """
        # Since filter_rules can be either a list of rule names or rule files,
        # we should check to see if there is a '.json' extension and just use
        # the base filename. This approach avoids two functions that do largely
        # the same thing
        if filter_rules:
            for index, rule in enumerate(filter_rules):
                parts = os.path.splitext(rule)
                if parts[1] == '.json':
                    filter_rules[index] = parts[0]

            # Create a copy of the filtered rules that can be altered
            filter_rules_copy = filter_rules[:]

        for _, _, test_rule_files in os.walk(DIR_RULES):
            for rule_file in test_rule_files:
                rule_name = os.path.splitext(rule_file)[0]

                # If only specific rules are being tested,
                # skip files that do not match those rules
                if filter_rules:
                    if rule_name not in filter_rules:
                        continue
                    filter_rules_copy.remove(rule_name)

                with open(os.path.join(DIR_RULES, rule_file), 'r') as rule_file_handle:
                    try:
                        contents = json.load(rule_file_handle)
                    except (ValueError, TypeError) as err:
                        self.all_tests_passed = False
                        message = 'Improperly formatted file ({}): {}'.format(
                            rule_file, err.message)
                        self.status_messages.append(
                            StatusMessage(StatusMessage.WARNING, rule_name, message))
                        continue

                if not contents.get('records'):
                    self.all_tests_passed = False
                    self.status_messages.append(
                        StatusMessage(StatusMessage.WARNING, rule_name,
                                      'No records to test in file'))
                    continue

                yield rule_name, contents

        # Print any of the filtered rules that remain in the copied list
        # This means that there are no tests configured for them
        if filter_rules and filter_rules_copy:
            self.all_tests_passed = False
            message = 'No test events configured for designated rule'
            for filter_rule in filter_rules_copy:
                if validate_only:
                    message = ('Designated file ({}.json) does not exist within '
                               '\'{}\''.format(filter_rule, DIR_RULES))
                self.status_messages.append(
                    StatusMessage(StatusMessage.WARNING, filter_rule, message))

    def check_keys(self, rule_name, test_record):
        """Check that the test_record contains the required keys

        Args:
            rule_name (str): The name of the rule being tested. This is passed
                in here strictly for reporting any errors with key checks.
            test_record (dict): The raw test record being processed

        Returns:
            bool: True if the proper keys are present
        """
        required_keys = {'data', 'description', 'service', 'source', 'trigger'}

        record_keys = set(test_record.keys())
        if not required_keys.issubset(record_keys):
            req_key_diff = required_keys.difference(record_keys)
            missing_keys = ','.join('\'{}\''.format(key) for key in req_key_diff)
            message = 'Missing required key(s) in log: {}'.format(missing_keys)
            self.status_messages.append(
                StatusMessage(StatusMessage.FAILURE, rule_name, message))
            return False

        optional_keys = {'trigger_count', 'compress'}

        key_diff = record_keys.difference(required_keys | optional_keys)

        # Log a warning if there are extra keys declared in the test log
        if key_diff:
            extra_keys = ','.join('\'{}\''.format(key) for key in key_diff)
            message = 'Additional unnecessary keys in log: {}'.format(extra_keys)
            # Remove the key(s) and just warn the user that they are extra
            record_keys.difference_update(key_diff)
            self.status_messages.append(
                StatusMessage(StatusMessage.WARNING, rule_name, message))

        return record_keys.issubset(required_keys | optional_keys)

    @staticmethod
    def apply_helpers(test_record):
        """Detect and apply helper functions to test fixtures

        Helpers are declared in test fixtures via the following keyword:
        "<helper:helper_name>"

        Supported helper functions:
            last_hour: return the current epoch time minus 60 seconds to pass
                the last_hour rule helper.

        Args:
            test_record (dict): loaded fixture file JSON as a dict.
        """
        # Declare all helper functions here; they should always return a string
        record_helpers = {'last_hour': lambda: str(int(time.time()) - 60)}
        helper_regex = re.compile(r'<helper:(?P<helper>\w+)>')

        def find_and_apply_helpers(test_record):
            """Apply any helpers to the passed in test_record"""
            for key, value in test_record.iteritems():
                if isinstance(value, (str, unicode)):
                    test_record[key] = re.sub(
                        helper_regex,
                        lambda match: record_helpers[match.group('helper')](),
                        test_record[key])
                elif isinstance(value, dict):
                    find_and_apply_helpers(test_record[key])

        find_and_apply_helpers(test_record)

    def report_output_summary(self):
        """Helper function to print the summary results of all tests"""
        failure_messages = [
            item for item in self.status_messages
            if item.type == StatusMessage.FAILURE
        ]
        warning_messages = [
            item for item in self.status_messages
            if item.type == StatusMessage.WARNING
        ]
        passed_tests = self.total_tests - len(failure_messages)

        # Print some lines at the bottom of output to make it more readable
        # This occurs here so there is always space and not only when the
        # successful test info prints
        print '\n\n'

        # Only print success info if we explicitly want to print output
        # but always print any errors or warnings below
        if self.print_output:
            # Print a message indicating how many of the total tests passed
            LOGGER_CLI.info('%s(%d/%d) Successful Tests%s', COLOR_GREEN,
                            passed_tests, self.total_tests, COLOR_RESET)

        # Check if there were failed tests and report on them appropriately
        if failure_messages:
            # Print a message indicating how many of the total tests failed
            LOGGER_CLI.error('%s(%d/%d) Failures%s', COLOR_RED,
                             len(failure_messages), self.total_tests, COLOR_RESET)

            # Iterate over the rule_name values in the failed list and report on them
            for index, failure in enumerate(failure_messages, start=1):
                LOGGER_CLI.error('%s(%d/%d) [%s] %s%s', COLOR_RED, index,
                                 len(failure_messages), failure.rule,
                                 failure.message, COLOR_RESET)

        # Check if there were any warnings and report on them
        if warning_messages:
            warning_count = len(warning_messages)
            LOGGER_CLI.warn('%s%d Warning%s%s', COLOR_YELLOW, warning_count,
                            ('s' if warning_count > 1 else ''), COLOR_RESET)

            for index, warning in enumerate(warning_messages, start=1):
                LOGGER_CLI.warn('%s(%d/%d) [%s] %s%s', COLOR_YELLOW, index,
                                warning_count, warning.rule, warning.message,
                                COLOR_RESET)

    def test_rule(self, rule_name, test_record, event):
        """Feed formatted records into StreamAlert and check for alerts

        Args:
            rule_name (str): The rule name being tested
            test_record (dict): A single raw record to test
            event (dict): A formatted event that reflects the structure
                expected as input to the Lambda function.

        Returns:
            list: alerts that hit for this rule
            int: count of expected alerts for this rule
            bool: False if errors occurred during processing
        """
        # Clear out any old alerts or errors from the previous test run
        # pylint: disable=protected-access
        del self.processor._alerts[:]
        self.processor._failed_record_count = 0

        expected_alert_count = test_record.get('trigger_count')
        if not expected_alert_count:
            expected_alert_count = 1 if test_record['trigger'] else 0

        # Run the rule processor
        all_records_matched_schema = self.processor.run(event)

        alerts = self.processor.get_alerts()

        # We only want alerts for the specific rule being tested
        alerts = [alert for alert in alerts if alert['rule_name'] == rule_name]

        return alerts, expected_alert_count, all_records_matched_schema

    def analyze_record_delta(self, rule_name, test_record):
        """Provide some additional context on why this test failed. This will
        perform some analysis of the test record to determine which keys are
        missing or which unnecessary keys are causing the test to fail. Any
        errors are appended to a list of errors so they can be printed at
        the end of the test run.

        Args:
            rule_name (str): Name of the rule being tested
            test_record (dict): Actual record data being tested
        """
        logs = self.processor.classifier.get_log_info_for_source()
        rule_info = StreamRules.get_rules()[rule_name]
        test_record_keys = set(test_record['data'])
        for log in rule_info.logs:
            if log not in logs:
                message = 'Log declared in rule ({}) does not exist in logs.json'.format(log)
                self.status_messages.append(
                    StatusMessage(StatusMessage.FAILURE, rule_name, message))
                continue

            all_record_schema_keys = set(logs[log]['schema'])
            optional_keys = set(
                logs[log].get('configuration', {}).get('optional_top_level_keys', {}))

            min_req_record_schema_keys = all_record_schema_keys.difference(optional_keys)

            schema_diff = min_req_record_schema_keys.difference(test_record_keys)
            if schema_diff:
                message = ('Data is invalid due to missing key(s) in test record: {}. '
                           'Rule: \'{}\'. Description: \'{}\''.format(
                               ', '.join('\'{}\''.format(key) for key in schema_diff),
                               rule_info.rule_name,
                               test_record['description']))
                self.status_messages.append(
                    StatusMessage(StatusMessage.FAILURE, rule_name, message))
                continue

            unexpected_record_keys = test_record_keys.difference(all_record_schema_keys)
            if unexpected_record_keys:
                message = ('Data is invalid due to unexpected key(s) in test record: {}. '
                           'Rule: \'{}\'. Description: \'{}\''.format(
                               ', '.join('\'{}\''.format(key)
                                         for key in unexpected_record_keys),
                               rule_info.rule_name,
                               test_record['description']))
                self.status_messages.append(
                    StatusMessage(StatusMessage.FAILURE, rule_name, message))
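# Worked example of the set arithmetic analyze_record_delta relies on (the key
# names here are illustrative): the minimum required keys are the schema keys
# minus the optional ones; keys missing from that set, or present but unknown
# to the schema, explain why classification of the test record failed.
def demo_record_delta():
    schema_keys = {'date', 'unixtime', 'host', 'data'}
    optional_keys = {'unixtime'}
    test_record_keys = {'date', 'host', 'extra_key'}

    min_required = schema_keys.difference(optional_keys)   # {'date', 'host', 'data'}
    missing = min_required.difference(test_record_keys)    # {'data'} -> missing key failure
    unexpected = test_record_keys.difference(schema_keys)  # {'extra_key'} -> unexpected key
    return missing, unexpected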
class TestStreamAlert(object):
    """Test class for StreamAlert class"""

    @patch('stream_alert.rule_processor.handler.load_config',
           lambda: load_config('tests/unit/conf/'))
    def setup(self):
        """Setup before each method"""
        self.__sa_handler = StreamAlert(get_mock_context(), False)

    def test_run_no_records(self):
        """StreamAlert Class - Run, No Records"""
        passed = self.__sa_handler.run({'Records': []})
        assert_false(passed)

    def test_get_alerts(self):
        """StreamAlert Class - Get Alerts"""
        default_list = ['alert1', 'alert2']
        self.__sa_handler._alerts = default_list

        assert_list_equal(self.__sa_handler.get_alerts(), default_list)

    @patch('stream_alert.rule_processor.handler.StreamClassifier.load_sources')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_no_sources(self, extract_mock, load_sources_mock):
        """StreamAlert Class - Run, No Loaded Sources"""
        extract_mock.return_value = ('lambda', 'entity')
        load_sources_mock.return_value = None
        self.__sa_handler.run({'Records': ['record']})

        load_sources_mock.assert_called_with('lambda', 'entity')

    @patch('logging.Logger.error')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_bad_service(self, extract_mock, log_mock):
        """StreamAlert Class - Run, Bad Service"""
        extract_mock.return_value = ('', 'entity')
        self.__sa_handler.run({'Records': ['record']})

        log_mock.assert_called_with('No valid service found in payload\'s raw record. '
                                    'Skipping record: %s', 'record')

    @patch('logging.Logger.error')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_bad_entity(self, extract_mock, log_mock):
        """StreamAlert Class - Run, Bad Entity"""
        extract_mock.return_value = ('kinesis', '')
        self.__sa_handler.run({'Records': ['record']})

        log_mock.assert_called_with(
            'Unable to extract entity from payload\'s raw record for '
            'service %s. Skipping record: %s', 'kinesis', 'record')
    @patch('stream_alert.rule_processor.handler.load_stream_payload')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.load_sources')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_load_payload_bad(self, extract_mock, load_sources_mock,
                                  load_payload_mock):
        """StreamAlert Class - Run, Loaded Payload Fail"""
        extract_mock.return_value = ('lambda', 'entity')
        load_sources_mock.return_value = True

        self.__sa_handler.run({'Records': ['record']})

        load_payload_mock.assert_called_with('lambda', 'entity', 'record')

    @patch('stream_alert.rule_processor.handler.StreamRules.process')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_with_alert(self, extract_mock, rules_mock):
        """StreamAlert Class - Run, With Alert"""
        extract_mock.return_value = ('kinesis', 'unit_test_default_stream')
        rules_mock.return_value = (['success!!'], ['normalized_records'])

        passed = self.__sa_handler.run(get_valid_event())
        assert_true(passed)

    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_alert_count(self, extract_mock):
        """StreamAlert Class - Run, Check Count With 4 Logs"""
        count = 4
        extract_mock.return_value = ('kinesis', 'unit_test_default_stream')
        self.__sa_handler.run(get_valid_event(count))

        assert_equal(self.__sa_handler._processed_record_count, count)

    @patch('logging.Logger.debug')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_no_alerts(self, extract_mock, log_mock):
        """StreamAlert Class - Run, With No Alerts"""
        extract_mock.return_value = ('kinesis', 'unit_test_default_stream')
        self.__sa_handler.run(get_valid_event())

        calls = [call('Processed %d valid record(s) that resulted in %d alert(s).', 1, 0),
                 call('Invalid record count: %d', 0),
                 call('%s alerts triggered', 0)]

        log_mock.assert_has_calls(calls)

    @patch('logging.Logger.error')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_invalid_data(self, extract_mock, log_mock):
        """StreamAlert Class - Run, Invalid Data"""
        extract_mock.return_value = ('kinesis', 'unit_test_default_stream')

        event = get_valid_event()

        # Replace the good log data with bad data
        event['Records'][0]['kinesis']['data'] = base64.b64encode('{"bad": "data"}')

        # Swap out the alias so the logging occurs
        self.__sa_handler.env['lambda_alias'] = 'production'
        self.__sa_handler.run(event)

        assert_equal(log_mock.call_args[0][0],
                     'Record does not match any defined schemas: %s\n%s')

        assert_equal(log_mock.call_args[0][2], '{"bad": "data"}')

    @patch('stream_alert.rule_processor.sink.StreamSink.sink')
    @patch('stream_alert.rule_processor.handler.StreamRules.process')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_send_alerts(self, extract_mock, rules_mock, sink_mock):
        """StreamAlert Class - Run, Send Alert"""
        extract_mock.return_value = ('kinesis', 'unit_test_default_stream')
        rules_mock.return_value = (['success!!'], ['normalized_records'])

        # Enable the alert processor so the sink is called
        self.__sa_handler.enable_alert_processor = True

        # Swap out the alias so the logging occurs
        self.__sa_handler.env['lambda_alias'] = 'production'
        self.__sa_handler.run(get_valid_event())

        sink_mock.assert_called_with(['success!!'])

    @patch('logging.Logger.debug')
    @patch('stream_alert.rule_processor.handler.StreamRules.process')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_debug_log_alert(self, extract_mock, rules_mock, log_mock):
        """StreamAlert Class - Run, Debug Log Alert"""
        extract_mock.return_value = ('kinesis', 'unit_test_default_stream')
        rules_mock.return_value = (['success!!'], ['normalized_records'])

        # Cache the logger level
        log_level = LOGGER.getEffectiveLevel()

        # Increase the logger level to debug
        LOGGER.setLevel(logging.DEBUG)

        self.__sa_handler.run(get_valid_event())

        # Reset the logger level
        LOGGER.setLevel(log_level)

        log_mock.assert_called_with('Alerts:\n%s', '[\n "success!!"\n]')

    @patch('stream_alert.rule_processor.handler.load_stream_payload')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.load_sources')
    @patch('stream_alert.rule_processor.handler.StreamClassifier.extract_service_and_entity')
    def test_run_no_payload_class(self, extract_mock, load_sources_mock,
                                  load_payload_mock):
        """StreamAlert Class - Run, No Payload Class"""
        extract_mock.return_value = ('blah', 'entity')
        load_sources_mock.return_value = True
        load_payload_mock.return_value = None
        self.__sa_handler.run({'Records': ['record']})

        load_payload_mock.assert_called()

    @mock_kinesis
    def test_firehose_record_delivery_disabled_logs(self):
        """StreamAlert Class - Firehose Record Delivery - Disabled Logs"""
        self.__sa_handler.firehose_client = boto3.client(
            'firehose', region_name='us-east-1')

        test_event = convert_events_to_kinesis([
            # unit_test_simple_log
            {'unit_key_01': 2, 'unit_key_02': 'testtest'} for _ in range(10)])

        delivery_stream_names = ['streamalert_data_unit_test_simple_log']

        # Setup mock delivery streams
        for delivery_stream in delivery_stream_names:
            self.__sa_handler.firehose_client.create_delivery_stream(
                DeliveryStreamName=delivery_stream,
                S3DestinationConfiguration={
                    'RoleARN': 'arn:aws:iam::123456789012:role/firehose_delivery_role',
                    'BucketARN': 'arn:aws:s3:::kinesis-test',
                    'Prefix': '{}/'.format(delivery_stream),
                    'BufferingHints': {
                        'SizeInMBs': 123,
                        'IntervalInSeconds': 124
                    },
                    'CompressionFormat': 'Snappy',
                }
            )

        with patch.object(self.__sa_handler.firehose_client,
                          'put_record_batch') as firehose_mock:
            firehose_mock.return_value = {'FailedPutCount': 0}
            self.__sa_handler.config['global']['infrastructure']['firehose'] = {
                'disabled_logs': ['unit_test_simple_log']}
            self.__sa_handler.run(test_event)
            firehose_mock.assert_not_called()

    @patch('stream_alert.rule_processor.threat_intel.StreamThreatIntel._query')
    @patch('stream_alert.rule_processor.threat_intel.StreamThreatIntel.load_from_config')
    def test_run_threat_intel_enabled(self, mock_threat_intel, mock_query):  # pylint: disable=no-self-use
        """StreamAlert Class - Run SA when threat intel enabled"""
        @rule(datatypes=['sourceAddress'], outputs=['s3:sample_bucket'])
        def match_ipaddress(_):  # pylint: disable=unused-variable
            """Testing dummy rule"""
            return True

        mock_threat_intel.return_value = StreamThreatIntel('test_table_name', 'us-east-1')
        mock_query.return_value = ([], [])

        sa_handler = StreamAlert(get_mock_context(), False)
        event = {
            'account': 123456,
            'region': '123456123456',
            'source': '1.1.1.2',
            'detail': {
                'eventName': 'ConsoleLogin',
                'sourceIPAddress': '1.1.1.2',
                'recipientAccountId': '654321'
            }
        }
        events = []
        for i in range(10):
            # Copy the event per iteration; appending the same dict would alias
            # every list entry to the final mutation ('1.1.1.9')
            events.append(dict(event, source='1.1.1.{}'.format(i)))

        kinesis_events = {
            'Records': [make_kinesis_raw_record('test_kinesis_stream', json.dumps(event))
                        for event in events]
        }

        passed = sa_handler.run(kinesis_events)
        assert_true(passed)
        assert_equal(mock_query.call_count, 1)
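# Aside on the per-iteration copy in test_run_threat_intel_enabled above:
# appending the same dict object would alias every list entry to the final
# mutation. A minimal sketch of the pitfall:
#
#   >>> event = {'source': None}
#   >>> events = [event for _ in range(3)]
#   >>> event['source'] = '1.1.1.2'
#   >>> [e['source'] for e in events]
#   ['1.1.1.2', '1.1.1.2', '1.1.1.2']
#
# dict(event, source=...) makes a shallow copy per record, which is enough
# here because only the top-level 'source' key differs between records.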