def rebuild_partitions(table, bucket, config):
    """Rebuild an Athena table's partitions

    Steps:
      - Get the list of current partitions
      - Destroy the existing table
      - Re-create the table
      - Re-create the partitions

    Args:
        table (str): The name of the table being rebuilt
        bucket (str): The S3 bucket to be used as the location for Athena data
        config (CLIConfig): Loaded StreamAlert CLI config
    """
    sanitized_table_name = FirehoseClient.firehose_log_name(table)

    athena_client = get_athena_client(config)

    # Get the current set of partitions
    partitions = athena_client.get_table_partitions(sanitized_table_name)
    if not partitions:
        LOGGER_CLI.info('No partitions to rebuild for %s, nothing to do', sanitized_table_name)
        return

    # Drop the table
    LOGGER_CLI.info('Dropping table %s', sanitized_table_name)
    success = athena_client.drop_table(sanitized_table_name)
    if not success:
        return

    LOGGER_CLI.info('Creating table %s', sanitized_table_name)

    # Re-create the table with previous partitions
    create_table(table, bucket, config)

    new_partitions_statement = helpers.add_partition_statement(
        partitions, bucket, sanitized_table_name)

    # Make sure our new alter table statement is within the query API limits
    if len(new_partitions_statement) > MAX_QUERY_LENGTH:
        LOGGER_CLI.error('Partition statement too large, writing to local file')
        with open('partitions_{}.txt'.format(sanitized_table_name), 'w') as partition_file:
            partition_file.write(new_partitions_statement)
        return

    LOGGER_CLI.info('Creating %d new partitions for %s',
                    len(partitions), sanitized_table_name)

    success = athena_client.run_query(query=new_partitions_statement)
    if not success:
        LOGGER_CLI.error('Error re-creating new partitions for %s', sanitized_table_name)
        return

    LOGGER_CLI.info('Successfully rebuilt partitions for %s', sanitized_table_name)
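
# Illustrative only: the exact statement text is produced by helpers.add_partition_statement,
# not shown here. Assuming Hive-style 'dt' partitions under the standard StreamAlert data
# prefix, the ALTER TABLE statement checked against MAX_QUERY_LENGTH above might look roughly
# like the following (bucket name and partition values are hypothetical):
example_partitions_statement = (
    "ALTER TABLE unit_test_simple_log ADD IF NOT EXISTS "
    "PARTITION (dt = '2018-01-01-01') "
    "LOCATION 's3://example-prefix.streamalert.data/unit_test_simple_log/2018/01/01/01/'"
)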
def setup_mock_firehose_delivery_streams(config):
    """Mock Kinesis Firehose Streams for rule testing

    Args:
        config (CLIConfig): The StreamAlert config
    """
    firehose_config = config['global']['infrastructure'].get('firehose')
    if not firehose_config:
        return

    enabled_logs = FirehoseClient.load_enabled_log_sources(
        firehose_config, config['logs'])

    for log_type in enabled_logs:
        stream_name = 'streamalert_data_{}'.format(log_type)
        prefix = '{}/'.format(log_type)
        create_delivery_stream(config['global']['account']['region'],
                               stream_name, prefix)
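
# A minimal usage sketch, assuming moto's kinesis mock and a config loaded the same way as in
# the unit tests below (the conf path and decorator usage here are illustrative, not prescriptive):
from moto import mock_kinesis

@mock_kinesis
def _example_mock_streams():
    config = load_config('tests/unit/conf')
    setup_mock_firehose_delivery_streams(config)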
def create_table(table, bucket, config, schema_override=None):
    """Create a 'streamalert' Athena table

    Args:
        table (str): The name of the table being created
        bucket (str): The S3 bucket to be used as the location for Athena data
        config (CLIConfig): Loaded StreamAlert CLI config
        schema_override (set): An optional set of key=value pairs to be used for
            overriding the configured column_name=value_type.
    """
    enabled_logs = FirehoseClient.load_enabled_log_sources(
        config['global']['infrastructure']['firehose'],
        config['logs']
    )

    # Convert special characters in schema name to underscores
    sanitized_table_name = FirehoseClient.firehose_log_name(table)

    # Check that the log type is enabled via Firehose
    if sanitized_table_name != 'alerts' and sanitized_table_name not in enabled_logs:
        LOGGER_CLI.error('Table name %s missing from configuration or '
                         'is not enabled.', sanitized_table_name)
        return

    athena_client = get_athena_client(config)

    # Check if the table exists
    if athena_client.check_table_exists(sanitized_table_name):
        LOGGER_CLI.info('The \'%s\' table already exists.', sanitized_table_name)
        return

    if table == 'alerts':
        # Get a fake alert so we can get the keys needed and their types
        alert = Alert('temp_rule_name', {}, {})
        output = alert.output_dict()
        schema = record_to_schema(output)
        athena_schema = helpers.logs_schema_to_athena_schema(schema)

        query = _construct_create_table_statement(
            schema=athena_schema, table_name=table, bucket=bucket)

    else:  # all other tables are log types
        log_info = config['logs'][table.replace('_', ':', 1)]

        schema = dict(log_info['schema'])
        sanitized_schema = FirehoseClient.sanitize_keys(schema)

        athena_schema = helpers.logs_schema_to_athena_schema(sanitized_schema)

        # Add envelope keys to the Athena schema
        configuration_options = log_info.get('configuration')
        if configuration_options:
            envelope_keys = configuration_options.get('envelope_keys')
            if envelope_keys:
                sanitized_envelope_key_schema = FirehoseClient.sanitize_keys(envelope_keys)
                # Note: this key is wrapped in backticks to be Hive compliant
                athena_schema['`streamalert:envelope_keys`'] = helpers.logs_schema_to_athena_schema(
                    sanitized_envelope_key_schema)

        # Handle schema overrides
        # This is useful when an Athena schema needs to differ from the normal log schema
        if schema_override:
            for override in schema_override:
                column_name, column_type = override.split('=')
                if not all([column_name, column_type]):
                    LOGGER_CLI.error('Invalid schema override [%s], use column_name=type format',
                                     override)
                    continue  # skip malformed overrides instead of applying an empty value

                # Columns are escaped to avoid Hive issues with special characters
                column_name = '`{}`'.format(column_name)
                if column_name in athena_schema:
                    athena_schema[column_name] = column_type
                    LOGGER_CLI.info('Applied schema override: %s:%s', column_name, column_type)
                else:
                    LOGGER_CLI.error(
                        'Schema override column %s not found in Athena Schema, skipping',
                        column_name)

        query = _construct_create_table_statement(
            schema=athena_schema, table_name=sanitized_table_name, bucket=bucket)

    success = athena_client.run_query(query=query)
    if not success:
        LOGGER_CLI.error('The %s table could not be created', sanitized_table_name)
        return

    # Update the CLI config
    if (table != 'alerts' and
            bucket not in config['lambda']['athena_partition_refresh_config']['buckets']):
        config['lambda']['athena_partition_refresh_config']['buckets'][bucket] = 'data'
        config.write()

    LOGGER_CLI.info('The %s table was successfully created!', sanitized_table_name)
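
# Usage sketch with hypothetical values: create the Athena table for an enabled log type,
# forcing one configured column to a different Hive type. Table names map back to log names
# by replacing the first underscore with a colon (e.g. 'cloudwatch_events' -> 'cloudwatch:events'),
# and an override only applies if the column already exists in the generated Athena schema.
create_table(
    table='cloudwatch_events',                  # hypothetical enabled log type
    bucket='example-prefix.streamalert.data',   # hypothetical data bucket
    config=config,                              # a loaded CLIConfig
    schema_override={'detail=string'}           # hypothetical column_name=type override
)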
def generate_firehose(logging_bucket, main_dict, config):
    """Generate the Firehose Terraform modules

    Args:
        logging_bucket (str): The name of the global logging bucket
        main_dict (infinitedict): The dict to marshal to a file
        config (CLIConfig): The loaded StreamAlert config
    """
    if not config['global']['infrastructure'].get('firehose', {}).get('enabled'):
        return

    firehose_config = config['global']['infrastructure']['firehose']
    firehose_s3_bucket_suffix = firehose_config.get('s3_bucket_suffix', 'streamalert.data')
    firehose_s3_bucket_name = '{}.{}'.format(
        config['global']['account']['prefix'], firehose_s3_bucket_suffix)

    # Firehose Setup module
    main_dict['module']['kinesis_firehose_setup'] = {
        'source': 'modules/tf_stream_alert_kinesis_firehose_setup',
        'account_id': config['global']['account']['aws_account_id'],
        'prefix': config['global']['account']['prefix'],
        'region': config['global']['account']['region'],
        's3_logging_bucket': logging_bucket,
        's3_bucket_name': firehose_s3_bucket_name,
        'kms_key_id': '${aws_kms_key.server_side_encryption.key_id}'
    }

    enabled_logs = FirehoseClient.load_enabled_log_sources(
        config['global']['infrastructure']['firehose'],
        config['logs'],
        force_load=True)

    log_alarms_config = config['global']['infrastructure']['firehose'].get('enabled_logs', {})

    # Add the Delivery Streams individually
    for log_stream_name, log_type_name in enabled_logs.iteritems():
        module_dict = {
            'source': 'modules/tf_stream_alert_kinesis_firehose_delivery_stream',
            'buffer_size': config['global']['infrastructure']['firehose'].get(
                'buffer_size', 64),
            'buffer_interval': config['global']['infrastructure']['firehose'].get(
                'buffer_interval', 300),
            'compression_format': config['global']['infrastructure']['firehose'].get(
                'compression_format', 'GZIP'),
            'log_name': log_stream_name,
            'role_arn': '${module.kinesis_firehose_setup.firehose_role_arn}',
            's3_bucket_name': firehose_s3_bucket_name,
            'kms_key_arn': '${aws_kms_key.server_side_encryption.arn}'
        }

        # Try to get alarm info for this specific log type
        alarm_info = log_alarms_config.get(log_type_name)
        if not alarm_info and ':' in log_type_name:
            # Fall back to looking for alarm info for the parent log type
            alarm_info = log_alarms_config.get(log_type_name.split(':')[0])

        if alarm_info and alarm_info.get('enable_alarm'):
            module_dict['enable_alarm'] = True

            # There are defaults for these defined in the terraform module, so do
            # not set the variable values unless explicitly specified
            if alarm_info.get('log_min_count_threshold'):
                module_dict['alarm_threshold'] = alarm_info.get('log_min_count_threshold')

            if alarm_info.get('evaluation_periods'):
                module_dict['evaluation_periods'] = alarm_info.get('evaluation_periods')

            if alarm_info.get('period_seconds'):
                module_dict['period_seconds'] = alarm_info.get('period_seconds')

            if alarm_info.get('alarm_actions'):
                if not isinstance(alarm_info.get('alarm_actions'), list):
                    module_dict['alarm_actions'] = [alarm_info.get('alarm_actions')]
                else:
                    module_dict['alarm_actions'] = alarm_info.get('alarm_actions')
            else:
                module_dict['alarm_actions'] = [monitoring_topic_arn(config)]

        main_dict['module']['kinesis_firehose_{}'.format(log_stream_name)] = module_dict
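
# For reference, a hypothetical 'firehose' block of the global infrastructure config that
# would drive the generation above. The keys mirror the lookups performed in this function;
# every value shown is an example only, not a recommended setting:
example_firehose_config = {
    'enabled': True,
    's3_bucket_suffix': 'streamalert.data',
    'buffer_size': 64,
    'buffer_interval': 300,
    'compression_format': 'GZIP',
    'enabled_logs': {
        'cloudwatch': {},  # no alarm configured for this log type
        'unit_test_simple_log': {
            'enable_alarm': True,
            'log_min_count_threshold': 100,
            'evaluation_periods': 10,
            'period_seconds': 3600,
            'alarm_actions': 'arn:aws:sns:us-east-1:123456789012:example-topic'
        }
    }
}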
class StreamAlert(object):
    """Wrapper class for handling StreamAlert classification and processing"""
    config = {}

    def __init__(self, context):
        """Initializer

        Args:
            context (dict): An AWS context object which provides metadata on the
                currently executing lambda function.
        """
        # Load the config. Validation occurs during load, which will
        # raise exceptions on any ConfigError
        StreamAlert.config = StreamAlert.config or config.load_config(validate=True)

        # Load the environment from the context arn
        self.env = config.parse_lambda_arn(context.invoked_function_arn)

        # Instantiate the AlertForwarder here to handle sending triggered alerts to the
        # alert processor
        self.alert_forwarder = AlertForwarder()

        # Instantiate a classifier that is used for this run
        self.classifier = StreamClassifier(config=self.config)

        self._failed_record_count = 0
        self._processed_record_count = 0
        self._processed_size = 0
        self._alerts = []

        rule_import_paths = [
            item for location in {'rule_locations', 'matcher_locations'}
            for item in self.config['global']['general'][location]
        ]

        # Create an instance of the RulesEngine class that gets cached in the
        # StreamAlert class as an instance property
        self._rules_engine = RulesEngine(self.config, *rule_import_paths)

        # Firehose client attribute
        self._firehose_client = None

    def run(self, event):
        """StreamAlert Lambda function handler.

        Loads the configuration for the StreamAlert function which contains
        available data sources, log schemas, normalized types, and outputs.
        Classifies incoming logs into a parsed type. Matches records against rules.

        Args:
            event (dict): An AWS event mapped to a specific source/entity
                containing data read by Lambda.

        Returns:
            bool: True if all logs being parsed match a schema
        """
        records = event.get('Records', [])
        LOGGER.debug('Number of incoming records: %d', len(records))
        if not records:
            return False

        firehose_config = self.config['global'].get('infrastructure', {}).get('firehose', {})
        if firehose_config.get('enabled'):
            self._firehose_client = FirehoseClient(
                self.env['region'],
                firehose_config=firehose_config,
                log_sources=self.config['logs'])

        payload_with_normalized_records = []
        for raw_record in records:
            # Get the service and entity from the payload. If the service/entity
            # is not in our config, log an error and go on to the next record
            service, entity = self.classifier.extract_service_and_entity(raw_record)
            if not service:
                LOGGER.error(
                    'No valid service found in payload\'s raw record. Skipping '
                    'record: %s', raw_record)
                continue

            if not entity:
                LOGGER.error(
                    'Unable to extract entity from payload\'s raw record for service %s. '
                    'Skipping record: %s', service, raw_record)
                continue

            # Cache the log sources for this service and entity on the classifier
            if not self.classifier.load_sources(service, entity):
                continue

            # Create the StreamPayload to use for encapsulating parsed info
            payload = load_stream_payload(service, entity, raw_record)
            if not payload:
                continue

            payload_with_normalized_records.extend(self._process_alerts(payload))

        # Log normalized records metric
        MetricLogger.log_metric(FUNCTION_NAME, MetricLogger.NORMALIZED_RECORDS,
                                len(payload_with_normalized_records))

        # Apply Threat Intel to normalized records at the end of the Rule Processor invocation
        record_alerts = self._rules_engine.threat_intel_match(payload_with_normalized_records)
        self._alerts.extend(record_alerts)
        if record_alerts:
            self.alert_forwarder.send_alerts(record_alerts)

        MetricLogger.log_metric(FUNCTION_NAME, MetricLogger.TOTAL_RECORDS,
                                self._processed_record_count)

        MetricLogger.log_metric(FUNCTION_NAME, MetricLogger.TOTAL_PROCESSED_SIZE,
                                self._processed_size)

        LOGGER.debug('Invalid record count: %d', self._failed_record_count)

        MetricLogger.log_metric(FUNCTION_NAME, MetricLogger.FAILED_PARSES,
                                self._failed_record_count)

        LOGGER.debug('%s alerts triggered', len(self._alerts))

        MetricLogger.log_metric(FUNCTION_NAME, MetricLogger.TRIGGERED_ALERTS, len(self._alerts))

        # Check if debug logging is on before json dumping alerts since
        # this can be time consuming if there are a lot of alerts
        if self._alerts and LOGGER_DEBUG_ENABLED:
            LOGGER.debug(
                'Alerts:\n%s',
                json.dumps([alert.output_dict() for alert in self._alerts],
                           indent=2, sort_keys=True))

        if self._firehose_client:
            self._firehose_client.send()

        # Only log rule info here if this is not running tests.
        # During testing, this gets logged at the end and printing here could be confusing
        # since stress testing calls this method multiple times
        if self.env['qualifier'] != 'development':
            stats.print_rule_stats(True)

        return self._failed_record_count == 0

    @property
    def alerts(self):
        """Returns list of Alert instances (useful for testing)."""
        return self._alerts

    def _process_alerts(self, payload):
        """Run the record through the rules, saving any alerts and forwarding them to Dynamo.

        Args:
            payload (StreamPayload): StreamAlert payload object being processed
        """
        payload_with_normalized_records = []
        for record in payload.pre_parse():
            # Increment the processed size using the length of this record
            self._processed_size += len(record.pre_parsed_record)
            self.classifier.classify_record(record)
            if not record.valid:
                if self.env['qualifier'] != 'development':
                    LOGGER.error('Record does not match any defined schemas: %s\n%s',
                                 record, record.pre_parsed_record)

                self._failed_record_count += 1
                continue

            # Increment the total processed records to get an accurate assessment of throughput
            self._processed_record_count += len(record.records)

            LOGGER.debug(
                'Classified and Parsed Payload: <Valid: %s, Log Source: %s, Entity: %s>',
                record.valid, record.log_source, record.entity)

            record_alerts, normalized_records = self._rules_engine.run(record)

            payload_with_normalized_records.extend(normalized_records)

            LOGGER.debug('Processed %d valid record(s) that resulted in %d alert(s).',
                         len(payload.records), len(record_alerts))

            # Add all parsed records to the categorized payload dict only if Firehose is enabled
            if self._firehose_client:
                self._firehose_client.add_payload_records(payload.log_source, payload.records)

            if not record_alerts:
                continue

            # Extend the list of alerts with any new ones so they can be returned
            self._alerts.extend(record_alerts)

            self.alert_forwarder.send_alerts(record_alerts)

        return payload_with_normalized_records
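
# A minimal sketch of how this class is typically wired into a Lambda entry point.
# The handler name and module layout here are illustrative, not necessarily the project's
# actual entry point; only StreamAlert(context).run(event) is taken from the class above.
def handler(event, context):
    """Hypothetical rule processor Lambda handler"""
    return StreamAlert(context).run(event)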
class TestFirehoseClient(object):
    """Test class for FirehoseClient"""
    # pylint: disable=protected-access,no-self-use,attribute-defined-outside-init

    def setup(self):
        """Setup before each method"""
        self.sa_firehose = FirehoseClient(region='us-east-1')

    def teardown(self):
        """Teardown after each method"""
        FirehoseClient._ENABLED_LOGS.clear()

    @staticmethod
    def _sample_categorized_payloads():
        return {
            'unit_test_simple_log': [{
                'unit_key_01': 1,
                'unit_key_02': 'test'
            }, {
                'unit_key_01': 2,
                'unit_key_02': 'test'
            }],
            'test_log_type_json_nested': [{
                'date': 'January 01, 3005',
                'unixtime': '32661446400',
                'host': 'my-host.name.website.com',
                'data': {
                    'super': 'secret'
                }
            }]
        }

    @mock_kinesis
    def _mock_delivery_streams(self, delivery_stream_names):
        """Mock Kinesis Delivery Streams for tests"""
        for delivery_stream in delivery_stream_names:
            self.sa_firehose._client.create_delivery_stream(
                DeliveryStreamName=delivery_stream,
                S3DestinationConfiguration={
                    'RoleARN': 'arn:aws:iam::123456789012:role/firehose_delivery_role',
                    'BucketARN': 'arn:aws:s3:::kinesis-test',
                    'Prefix': '{}/'.format(delivery_stream),
                    'BufferingHints': {
                        'SizeInMBs': 123,
                        'IntervalInSeconds': 124
                    },
                    'CompressionFormat': 'Snappy',
                })

    @patch('stream_alert.rule_processor.firehose.LOGGER')
    @mock_kinesis
    def test_record_delivery_failed_put_count(self, mock_logging):
        """FirehoseClient - Record Delivery - Failed Put Count"""
        # Add sample categorized payloads
        for payload_type, logs in self._sample_categorized_payloads().iteritems():
            self.sa_firehose._categorized_payloads[payload_type].extend(logs)

        # Setup mocked Delivery Streams
        self._mock_delivery_streams(
            ['streamalert_data_test_log_type_json_nested',
             'streamalert_data_unit_test_simple_log'])

        with patch.object(self.sa_firehose._client, 'put_record_batch') as firehose_mock:
            firehose_mock.side_effect = [{
                'FailedPutCount': 3,
                'RequestResponses': [{
                    'ErrorCode': 'ServiceUnavailableException',
                    'ErrorMessage': 'Slow down.'
                }, {
                    'ErrorCode': 'ServiceUnavailableException',
                    'ErrorMessage': 'Slow down.'
                }, {
                    'ErrorCode': 'ServiceUnavailableException',
                    'ErrorMessage': 'Slow down.'
                }]
            }, {
                'FailedPutCount': 3,
                'RequestResponses': [{
                    'ErrorCode': 'ServiceUnavailableException',
                    'ErrorMessage': 'Slow down.'
                }, {
                    'ErrorCode': 'ServiceUnavailableException',
                    'ErrorMessage': 'Slow down.'
                }, {
                    'ErrorCode': 'ServiceUnavailableException',
                    'ErrorMessage': 'Slow down.'
                }]
            }, {
                'FailedPutCount': 0,
                'RequestResponses': [{
                    'RecordId': '12345678910',
                    'ErrorCode': 'None',
                    'ErrorMessage': 'None'
                }, {
                    'RecordId': '12345678910',
                    'ErrorCode': 'None',
                    'ErrorMessage': 'None'
                }, {
                    'RecordId': '12345678910',
                    'ErrorCode': 'None',
                    'ErrorMessage': 'None'
                }]
            }]

            self.sa_firehose.send()

            firehose_mock.assert_called()
            assert_true(mock_logging.info.called)

    @patch('stream_alert.rule_processor.firehose.LOGGER')
    @mock_kinesis
    def test_record_delivery(self, mock_logging):
        """FirehoseClient - Record Delivery"""
        # Add sample categorized payloads
        for payload_type, logs in self._sample_categorized_payloads().iteritems():
            self.sa_firehose._categorized_payloads[payload_type].extend(logs)

        # Setup mocked Delivery Streams
        self._mock_delivery_streams(
            ['streamalert_data_test_log_type_json_nested',
             'streamalert_data_unit_test_simple_log'])

        # Send the records
        with patch.object(self.sa_firehose._client, 'put_record_batch') as firehose_mock:
            firehose_mock.return_value = {'FailedPutCount': 0}
            self.sa_firehose.send()

            firehose_mock.assert_called()
            assert_true(mock_logging.info.called)

    @patch('stream_alert.rule_processor.firehose.LOGGER')
    @mock_kinesis
    def test_record_delivery_failure(self, mock_logging):
        """FirehoseClient - Record Delivery - Failed PutRecord"""
        # Add sample categorized payloads
        for payload_type, logs in self._sample_categorized_payloads().iteritems():
            self.sa_firehose._categorized_payloads[payload_type].extend(logs)

        # Setup mocked Delivery Streams
        self._mock_delivery_streams(
            ['streamalert_data_test_log_type_json_nested',
             'streamalert_data_unit_test_simple_log'])

        # Send the records
        with patch.object(self.sa_firehose._client, 'put_record_batch') as firehose_mock:
            firehose_mock.return_value = {
                'FailedPutCount': 3,
                'RequestResponses': [
                    {
                        'RecordId': '12345',
                        'ErrorCode': '300',
                        'ErrorMessage': 'Bad message!!!'
                    },
                ]
            }
            self.sa_firehose.send()

            firehose_mock.assert_called()
            assert_true(mock_logging.error.called)

    @patch('stream_alert.rule_processor.firehose.LOGGER')
    @mock_kinesis
    def test_record_delivery_client_error(self, mock_logging):
        """FirehoseClient - Record Delivery - Client Error"""
        test_events = [
            # unit_test_simple_log
            {
                'unit_key_01': 2,
                'unit_key_02': 'testtest'
            } for _ in range(10)
        ]

        self.sa_firehose._firehose_request_helper('invalid_stream', test_events)

        # The expected boto3 client error is kept here for reference; only the fact that an
        # error was logged is asserted, since Mock objects have no `called_with` assertion
        missing_stream_message = (
            'Client Error ... An error occurred '
            '(ResourceNotFoundException) when calling the PutRecordBatch '
            'operation: Stream invalid_stream under account 123456789012 not found.')
        assert_true(mock_logging.error.called)

    @mock_kinesis
    def test_load_enabled_sources(self):
        """FirehoseClient - Load Enabled Sources"""
        config = load_config('tests/unit/conf')
        firehose_config = {
            'enabled_logs': [
                'json:regex_key_with_envelope',
                'test_cloudtrail',
                'cloudwatch'  # expands to 2 logs
            ]
        }

        enabled_logs = FirehoseClient.load_enabled_log_sources(firehose_config, config['logs'])

        assert_equal(len(enabled_logs), 4)
        # Make sure the substitution works properly
        assert_true(all([':' not in log for log in enabled_logs]))
        assert_false(FirehoseClient.enabled_log_source('test_inspec'))

    @patch('stream_alert.rule_processor.firehose.LOGGER.error')
    @mock_kinesis
    def test_load_enabled_sources_invalid_log(self, mock_logging):
        """FirehoseClient - Load Enabled Sources - Invalid Log"""
        config = load_config('tests/unit/conf')
        firehose_config = {'enabled_logs': ['log-that-doesnt-exist']}

        sa_firehose = FirehoseClient(
            region='us-east-1', firehose_config=firehose_config, log_sources=config['logs'])

        assert_equal(len(sa_firehose._ENABLED_LOGS), 0)
        mock_logging.assert_called_with(
            'Enabled Firehose log %s not declared in logs.json',
            'log-that-doesnt-exist'
        )

    def test_strip_successful_records(self):
        """FirehoseClient - Strip Successful Records"""
        batch = [{'test': 'success'}, {'test': 'data'}, {'other': 'failure'}, {'other': 'info'}]
        response = {
            'FailedPutCount': 1,
            'RequestResponses': [
                {'RecordId': 'rec_id_00'},
                {'RecordId': 'rec_id_01'},
                {'ErrorCode': 10, 'ErrorMessage': 'foo'},
                {'RecordId': 'rec_id_03'}
            ]
        }

        expected_batch = [{'other': 'failure'}]
        FirehoseClient._strip_successful_records(batch, response)

        assert_equal(batch, expected_batch)

    def test_segment_records_by_size(self):
        """FirehoseClient - Segment Large Records"""
        record_batch = [
            # unit_test_simple_log
            {
                'unit_key_01': 2,
                'unit_key_02': 'testtest' * 10000
            } for _ in range(100)
        ]

        sized_batches = []
        for sized_batch in FirehoseClient._segment_records_by_size(record_batch):
            sized_batches.append(sized_batch)

        assert_true(len(str(sized_batches[0])) < 4000000)
        assert_equal(len(sized_batches), 4)
        assert_true(isinstance(sized_batches[3][0], dict))

    def test_sanitize_keys(self):
        """FirehoseClient - Sanitize Keys"""
        # test_log_type_json_nested
        test_event = {
            'date': 'January 01, 3005',
            'unixtime': '32661446400',
            'host': 'my-host.name.website.com',
            'data': {
                'super-duper': 'secret',
                'sanitize_me': 1,
                'example-key': 1,
                'moar**data': 2,
                'even.more': 3
            }
        }

        expected_sanitized_event = {
            'date': 'January 01, 3005',
            'unixtime': '32661446400',
            'host': 'my-host.name.website.com',
            'data': {
                'super_duper': 'secret',
                'sanitize_me': 1,
                'example_key': 1,
                'moar__data': 2,
                'even_more': 3
            }
        }

        sanitized_event = FirehoseClient.sanitize_keys(test_event)
        assert_equal(sanitized_event, expected_sanitized_event)

    @patch('stream_alert.rule_processor.firehose.LOGGER')
    def test_limit_record_size(self, mock_logging):
        """FirehoseClient - Record Size Check"""
        test_events = [
            # unit_test_simple_log
            {
                'unit_key_01': 1,
                'unit_key_02': 'test' * 250001  # is 4 bytes higher than max
            },
            {
                'unit_key_01': 2,
                'unit_key_02': 'test'
            },
            # test_log_type_json_nested
            {
                'date': 'January 01, 3005',
                'unixtime': '32661446400',
                'host': 'my-host.name.website.com',
                'data': {
                    'super': 'secret'
                }
            },
            # add another unit_test_simple_log to verify in a different position
            {
                'unit_key_01': 1,
                'unit_key_02': 'test' * 250001  # is 4 bytes higher than max
            },
            {
                'test': 1
            }
        ]

        FirehoseClient._limit_record_size(test_events)

        assert_equal(len(test_events), 3)
        assert_true(mock_logging.error.called)
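
# For context on the sanitize-keys behavior exercised above: the test expects every
# non-alphanumeric, non-underscore character in a key to become an underscore, applied
# recursively to nested dicts. A minimal sketch of that behavior (not necessarily the
# actual FirehoseClient.sanitize_keys implementation) could look like this:
import re

def _example_sanitize_keys(record):
    """Recursively replace special characters in dict keys with underscores"""
    sanitized = {}
    for key, value in record.items():
        clean_key = re.sub(r'\W', '_', key)  # 'super-duper' -> 'super_duper', 'even.more' -> 'even_more'
        if isinstance(value, dict):
            value = _example_sanitize_keys(value)
        sanitized[clean_key] = value
    return sanitized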