示例#1
0
    def test_strip_successful_records(self):
        """StreamAlertFirehose - Strip Successful Records"""
        # Four records; the response below reports only the third one as failed
        batch = [
            {'test': 'success'},
            {'test': 'data'},
            {'other': 'failure'},
            {'other': 'info'},
        ]
        response = {
            'FailedPutCount': 1,
            'RequestResponses': [
                {'RecordId': 'rec_id_00'},
                {'RecordId': 'rec_id_01'},
                {'ErrorCode': 10, 'ErrorMessage': 'foo'},
                {'RecordId': 'rec_id_03'},
            ],
        }

        StreamAlertFirehose._strip_successful_records(batch, response)

        # The same list object is inspected afterwards, so only the failed
        # record is expected to remain in it
        assert_equal(batch, [{'other': 'failure'}])
示例#2
0
    def test_limit_record_size(self, mock_logging):
        """StreamAlertFirehose - Record Size Check"""
        test_events = [
            # unit_test_simple_log
            {
                'unit_key_01': 1,
                'unit_key_02': 'test' * 250001  # is 4 bytes higher than max
            },
            {
                'unit_key_01': 2,
                'unit_key_02': 'test'
            },
            # test_log_type_json_nested
            {
                'date': 'January 01, 3005',
                'unixtime': '32661446400',
                'host': 'my-host.name.website.com',
                'data': {
                    'super': 'secret'
                }
            }
        ]

        # The oversized record is expected to be dropped from the list in place
        StreamAlertFirehose._limit_record_size(test_events)

        # assert_true(len(...), 2) used 2 as the failure *message* and passed for
        # any non-empty list; compare the length explicitly instead
        assert_equal(len(test_events), 2)
        assert_true(mock_logging.error.called)
示例#3
0
    def test_record_delivery(self, mock_logging):
        """StreamAlertFirehose - Record Delivery"""
        sa_firehose = StreamAlertFirehose(region='us-east-1',
                                          firehose_config={},
                                          log_sources={})
        # _mock_delivery_streams reads the instance from this attribute
        self.__sa_firehose = sa_firehose

        # Queue up the sample payloads, grouped by log type
        payloads = self._sample_categorized_payloads()
        for log_type in payloads:
            sa_firehose.categorized_payloads[log_type].extend(payloads[log_type])

        # Create the delivery streams the records are destined for
        self._mock_delivery_streams([
            'streamalert_data_test_log_type_json_nested',
            'streamalert_data_unit_test_simple_log'
        ])

        # Stub the batch put with a response reporting no failures
        with patch.object(sa_firehose._firehose_client,
                          'put_record_batch') as put_batch:
            put_batch.return_value = {'FailedPutCount': 0}
            sa_firehose.send()

            put_batch.assert_called()
            assert_true(mock_logging.info.called)
示例#4
0
    def test_record_delivery_failure(self, mock_logging):
        """StreamAlertFirehose - Record Delivery - Failed PutRecord"""
        sa_firehose = StreamAlertFirehose(region='us-east-1',
                                          firehose_config={},
                                          log_sources={})
        # _mock_delivery_streams reads the instance from this attribute
        self.__sa_firehose = sa_firehose

        # Queue up the sample payloads, grouped by log type
        payloads = self._sample_categorized_payloads()
        for log_type in payloads:
            sa_firehose.categorized_payloads[log_type].extend(payloads[log_type])

        # Create the delivery streams the records are destined for
        self._mock_delivery_streams([
            'streamalert_data_test_log_type_json_nested',
            'streamalert_data_unit_test_simple_log'
        ])

        # Every batch put reports failed records; an error should be logged
        failed_response = {
            'FailedPutCount': 3,
            'RequestResponses': [{
                'RecordId': '12345',
                'ErrorCode': '300',
                'ErrorMessage': 'Bad message!!!'
            }]
        }
        with patch.object(sa_firehose._firehose_client,
                          'put_record_batch') as put_batch:
            put_batch.return_value = failed_response
            sa_firehose.send()

            put_batch.assert_called()
            assert_true(mock_logging.error.called)
示例#5
0
    def test_load_enabled_sources(self):
        """StreamAlertFirehose - Load Enabled Sources"""
        log_sources = load_config('tests/unit/conf')['logs']

        # 'cloudwatch' expands to 2 logs, so these three entries are expected
        # to yield 4 enabled logs in total
        enabled = ['json:regex_key_with_envelope', 'test_cloudtrail', 'cloudwatch']
        sa_firehose = StreamAlertFirehose(region='us-east-1',
                                          firehose_config={'enabled_logs': enabled},
                                          log_sources=log_sources)

        assert_equal(len(sa_firehose._enabled_logs), 4)
        # The ':' separator must have been substituted in every enabled log name
        assert_true(all(':' not in log_name for log_name in sa_firehose.enabled_logs))
        assert_false(sa_firehose.enabled_log_source('test_inspec'))
示例#6
0
    def test_sanitize_keys(self):
        """StreamAlertFirehose - Sanitize Keys"""
        # Top-level keys shared by the raw and the sanitized event
        # (shape of test_log_type_json_nested)
        common = {
            'date': 'January 01, 3005',
            'unixtime': '32661446400',
            'host': 'my-host.name.website.com',
        }
        raw_event = dict(common, data={
            'super-duper': 'secret',
            'sanitize_me': 1,
            'example-key': 1,
            'moar**data': 2,
            'even.more': 3
        })
        # Each special character in a nested key is expected to become '_'
        expected_event = dict(common, data={
            'super_duper': 'secret',
            'sanitize_me': 1,
            'example_key': 1,
            'moar__data': 2,
            'even_more': 3
        })

        assert_equal(StreamAlertFirehose.sanitize_keys(raw_event), expected_event)
示例#7
0
    def test_firehose_reset(self, mock_logging):
        """StreamAlertFirehose - Test Reset Firehose Client"""
        def test_func():
            pass

        sa_firehose = StreamAlertFirehose(region='us-east-1',
                                          firehose_config={},
                                          log_sources={})

        # Hold a reference to the original client. If it were garbage collected
        # during the reset, CPython may hand the same id() to the new client and
        # the inequality check below would fail spuriously.
        original_client = sa_firehose._firehose_client
        sa_firehose._backoff_handler_firehose_reset({
            'target': test_func,
            'wait': '0.134315135',
            'tries': 3
        })

        assert_true(mock_logging.info.called)
        assert_not_equal(id(original_client), id(sa_firehose._firehose_client))
示例#8
0
    def test_record_delivery_client_error(self, mock_logging):
        """StreamAlertFirehose - Record Delivery - Client Error"""
        sa_firehose = StreamAlertFirehose(region='us-east-1',
                                          firehose_config={},
                                          log_sources={})

        test_events = [
            # unit_test_simple_log
            {
                'unit_key_01': 2,
                'unit_key_02': 'testtest'
            } for _ in range(10)
        ]

        # No delivery stream named 'invalid_stream' exists, so the request is
        # expected to fail with something like:
        #   Client Error ... An error occurred (ResourceNotFoundException) when
        #   calling the PutRecordBatch operation: Stream invalid_stream under
        #   account 123456789012 not found.
        sa_firehose._firehose_request_helper('invalid_stream', test_events)

        # `error.called_with(...)` is not an assertion: it auto-creates a child
        # Mock, which is always truthy, so the previous check could never fail.
        assert_true(mock_logging.error.called)
示例#9
0
    def test_load_enabled_sources_invalid_log(self, mock_logging):
        """StreamAlertFirehose - Load Enabled Sources - Invalid Log"""
        log_sources = load_config('tests/unit/conf')['logs']

        # Enable a single log type that is not defined in the log config
        sa_firehose = StreamAlertFirehose(
            region='us-east-1',
            firehose_config={'enabled_logs': ['log-that-doesnt-exist']},
            log_sources=log_sources)

        # The unknown log should be rejected and reported as an error
        assert_equal(len(sa_firehose._enabled_logs), 0)
        assert_true(mock_logging.error.called)
示例#10
0
    def test_segment_records_by_size(self):
        """StreamAlertFirehose - Segment Large Records"""
        sa_firehose = StreamAlertFirehose(region='us-east-1',
                                          firehose_config={},
                                          log_sources={})

        # 100 separate unit_test_simple_log records, each carrying an
        # 80,000 character value
        record_batch = [{'unit_key_01': 2, 'unit_key_02': 'testtest' * 10000}
                        for _ in range(100)]

        # Collect every segment produced for the batch
        sized_batches = list(sa_firehose._segment_records_by_size(record_batch))

        assert_true(len(str(sized_batches[0])) < 4000000)
        assert_equal(len(sized_batches), 4)
        assert_true(isinstance(sized_batches[3][0], dict))
示例#11
0
def generate_firehose(config, main_dict, logging_bucket):
    """Generate the Firehose Terraform modules

    Adds a 'kinesis_firehose_setup' module plus one
    'kinesis_firehose_<log name>' delivery stream module per enabled log.
    Nothing is added when Firehose is not enabled in the global config.

    Args:
        config (CLIConfig): The loaded StreamAlert Config
        main_dict (infinitedict): The Dict to marshal to a file
        logging_bucket (str): The name of the global logging bucket
    """
    firehose_config = config['global']['infrastructure'].get('firehose', {})
    if not firehose_config.get('enabled'):
        return

    account_config = config['global']['account']

    sa_firehose = StreamAlertFirehose(account_config['region'],
                                      firehose_config, config['logs'])

    firehose_s3_bucket_suffix = firehose_config.get('s3_bucket_suffix',
                                                    'streamalert.data')
    firehose_s3_bucket_name = '{}.{}'.format(account_config['prefix'],
                                             firehose_s3_bucket_suffix)

    # Firehose Setup module
    main_dict['module']['kinesis_firehose_setup'] = {
        'source': 'modules/tf_stream_alert_kinesis_firehose_setup',
        'account_id': account_config['aws_account_id'],
        'prefix': account_config['prefix'],
        'region': account_config['region'],
        's3_logging_bucket': logging_bucket,
        's3_bucket_name': firehose_s3_bucket_name,
        'kms_key_id': '${aws_kms_key.server_side_encryption.key_id}'
    }

    # These settings are identical for every delivery stream, so read them once
    buffer_size = firehose_config.get('buffer_size', 64)
    buffer_interval = firehose_config.get('buffer_interval', 300)
    compression_format = firehose_config.get('compression_format', 'GZIP')

    # Add the Delivery Streams individually
    for enabled_log in sa_firehose.enabled_logs:
        main_dict['module']['kinesis_firehose_{}'.format(enabled_log)] = {
            'source': 'modules/tf_stream_alert_kinesis_firehose_delivery_stream',
            'buffer_size': buffer_size,
            'buffer_interval': buffer_interval,
            'compression_format': compression_format,
            'log_name': enabled_log,
            'role_arn': '${module.kinesis_firehose_setup.firehose_role_arn}',
            's3_bucket_name': firehose_s3_bucket_name,
            'kms_key_arn': '${aws_kms_key.server_side_encryption.arn}'
        }
示例#12
0
def setup_mock_firehose_delivery_streams(config):
    """Mock one Kinesis Firehose delivery stream per enabled log for rule testing

    Returns without doing anything when the global infrastructure config has
    no (or an empty) 'firehose' section.

    Args:
        config (CLIConfig): The StreamAlert config
    """
    infrastructure = config['global']['infrastructure']
    firehose_config = infrastructure.get('firehose')
    if firehose_config:
        region = config['global']['account']['region']
        sa_firehose = StreamAlertFirehose(region, firehose_config, config['logs'])
        for log_name in sa_firehose.enabled_logs:
            # Streams are named after the log and use a '<log name>/' prefix
            create_delivery_stream(region,
                                   'streamalert_data_{}'.format(log_name),
                                   '{}/'.format(log_name))
示例#13
0
    def test_record_delivery_failed_put_count(self, mock_logging):
        """StreamAlertFirehose - Record Delivery - Failed Put Count"""
        sa_firehose = StreamAlertFirehose(region='us-east-1',
                                          firehose_config={},
                                          log_sources={})
        # _mock_delivery_streams reads the instance from this attribute
        self.__sa_firehose = sa_firehose

        # Queue up the sample payloads, grouped by log type
        payloads = self._sample_categorized_payloads()
        for log_type in payloads:
            sa_firehose.categorized_payloads[log_type].extend(payloads[log_type])

        # Create the delivery streams the records are destined for
        self._mock_delivery_streams([
            'streamalert_data_test_log_type_json_nested',
            'streamalert_data_unit_test_simple_log'
        ])

        def throttled_response():
            """Build a fresh response in which all three records were throttled"""
            return {
                'FailedPutCount': 3,
                'RequestResponses': [{
                    "ErrorCode": "ServiceUnavailableException",
                    "ErrorMessage": "Slow down."
                } for _ in range(3)]
            }

        accepted_response = {
            'FailedPutCount': 0,
            'RequestResponses': [{
                "RecordId": "12345678910",
                "ErrorCode": "None",
                "ErrorMessage": "None"
            } for _ in range(3)]
        }

        # Two throttled attempts are followed by a fully successful one
        with patch.object(sa_firehose._firehose_client,
                          'put_record_batch') as put_batch:
            put_batch.side_effect = [
                throttled_response(),
                throttled_response(),
                accepted_response
            ]
            sa_firehose.send()

            put_batch.assert_called()
            assert_true(mock_logging.info.called)
示例#14
0
def rebuild_partitions(table, bucket, config):
    """Rebuild an Athena table's partitions

    Steps:
      - Get the list of current partitions
      - Destroy existing table
      - Re-create tables
      - Re-create partitions

    If the generated partition statement is longer than MAX_QUERY_LENGTH, it
    is written to a local 'partitions_<table>.txt' file instead of being run.

    Args:
        table (str): The name of the table being rebuilt
        bucket (str): The s3 bucket to be used as the location for Athena data
        config (CLIConfig): Loaded StreamAlert CLI
    """
    athena_client = get_athena_client(config)

    sa_firehose = StreamAlertFirehose(
        config['global']['account']['region'],
        config['global']['infrastructure']['firehose'], config['logs'])

    sanitized_table_name = sa_firehose.firehose_log_name(table)

    # Get the current set of partitions
    partitions = athena_client.get_table_partitions(sanitized_table_name)
    if not partitions:
        LOGGER_CLI.info('No partitions to rebuild for %s, nothing to do',
                        sanitized_table_name)
        return

    # Drop the table
    LOGGER_CLI.info('Dropping table %s', sanitized_table_name)
    success = athena_client.drop_table(sanitized_table_name)
    if not success:
        # Previously this returned silently, giving no hint why the rebuild stopped
        LOGGER_CLI.error('An error occurred when dropping the %s table',
                         sanitized_table_name)
        return

    LOGGER_CLI.info('Creating table %s', sanitized_table_name)

    # Re-create the table with previous partitions.
    # NOTE(review): create_table reports no success value, so a failure here is
    # only visible in its own log output while the old table is already dropped.
    create_table(table, bucket, config)

    new_partitions_statement = helpers.add_partition_statement(
        partitions, bucket, sanitized_table_name)

    # Make sure our new alter table statement is within the query API limits
    if len(new_partitions_statement) > MAX_QUERY_LENGTH:
        LOGGER_CLI.error(
            'Partition statement too large, writing to local file')
        with open('partitions_{}.txt'.format(sanitized_table_name),
                  'w') as partition_file:
            partition_file.write(new_partitions_statement)
        return

    LOGGER_CLI.info('Creating %d new partitions for %s', len(partitions),
                    sanitized_table_name)

    success = athena_client.run_query(query=new_partitions_statement)
    if not success:
        LOGGER_CLI.error('Error re-creating new partitions for %s',
                         sanitized_table_name)
        return

    LOGGER_CLI.info('Successfully rebuilt partitions for %s',
                    sanitized_table_name)
示例#15
0
def _apply_schema_overrides(athena_schema, schema_override):
    """Apply 'column_name=type' overrides to an Athena schema, in place.

    Malformed overrides, and overrides naming a column that is not in the
    schema, are logged and skipped.

    Args:
        athena_schema (dict): The Athena schema to modify; column keys are
            expected to be backtick-escaped, e.g. '`column`'
        schema_override (iterable): 'column_name=type' strings to apply
    """
    for override in schema_override:
        parts = override.split('=')
        # Require exactly one non-empty name and one non-empty type. Previously a
        # missing '=' raised ValueError, and empty parts were logged but applied.
        if len(parts) != 2 or not all(parts):
            LOGGER_CLI.error(
                'Invalid schema override [%s], use column_name=type format',
                override)
            continue

        column_name, column_type = parts
        # Columns are escaped to avoid Hive issues with special characters
        column_name = '`{}`'.format(column_name)
        if column_name in athena_schema:
            athena_schema[column_name] = column_type
            LOGGER_CLI.info('Applied schema override: %s:%s',
                            column_name, column_type)
        else:
            LOGGER_CLI.error(
                'Schema override column %s not found in Athena Schema, skipping',
                column_name)


def create_table(table, bucket, config, schema_override=None):
    """Create a 'streamalert' Athena table

    Args:
        table (str): The name of the table being created: either 'alerts' or
            the name of a log type that is enabled via Firehose
        bucket (str): The s3 bucket to be used as the location for Athena data
        config (CLIConfig): Loaded StreamAlert CLI
        schema_override (set): An optional set of key=value pairs to be used for
            overriding the configured column_name=value_type.
    """
    athena_client = StreamAlertAthenaClient(
        config, results_key_prefix='stream_alert_cli')

    sa_firehose = StreamAlertFirehose(
        config['global']['account']['region'],
        config['global']['infrastructure']['firehose'], config['logs'])

    # Convert special characters in schema name to underscores
    sanitized_table_name = sa_firehose.firehose_log_name(table)

    # Check that the log type is enabled via Firehose
    if sanitized_table_name != 'alerts' and sanitized_table_name not in sa_firehose.enabled_logs:
        LOGGER_CLI.error(
            'Table name %s missing from configuration or '
            'is not enabled.', sanitized_table_name)
        return

    # Check if the table exists
    if athena_client.check_table_exists(sanitized_table_name, True):
        LOGGER_CLI.info('The \'%s\' table already exists.',
                        sanitized_table_name)
        return

    if table == 'alerts':
        # get a fake alert so we can get the keys needed and their types
        alert = Alert('temp_rule_name', {}, {})
        output = alert.output_dict()
        schema = record_to_schema(output)
        athena_schema = handler_helpers.to_athena_schema(schema)

        query = _construct_create_table_statement(schema=athena_schema,
                                                  table_name=table,
                                                  bucket=bucket)

    else:  # all other tables are log types

        # Map the sanitized table name back to its configured log name. A blind
        # table.replace('_', ':', 1) breaks for log names without a ':' (e.g.
        # 'test_cloudtrail' would become 'test:cloudtrail'), so it is only the
        # fallback when no configured log sanitizes to this table name.
        log_name = next(
            (name for name in config['logs']
             if sa_firehose.firehose_log_name(name) == sanitized_table_name),
            table.replace('_', ':', 1))
        log_info = config['logs'][log_name]

        schema = dict(log_info['schema'])
        sanitized_schema = StreamAlertFirehose.sanitize_keys(schema)

        athena_schema = handler_helpers.to_athena_schema(sanitized_schema)

        # Add envelope keys to Athena Schema
        configuration_options = log_info.get('configuration')
        if configuration_options:
            envelope_keys = configuration_options.get('envelope_keys')
            if envelope_keys:
                sanitized_envelope_key_schema = StreamAlertFirehose.sanitize_keys(
                    envelope_keys)
                # Note: this key is wrapped in backticks to be Hive compliant
                athena_schema['`streamalert:envelope_keys`'] = (
                    handler_helpers.to_athena_schema(sanitized_envelope_key_schema))

        # Handle Schema overrides
        #   This is useful when an Athena schema needs to differ from the normal log schema
        if schema_override:
            _apply_schema_overrides(athena_schema, schema_override)

        query = _construct_create_table_statement(
            schema=athena_schema,
            table_name=sanitized_table_name,
            bucket=bucket)

    create_table_success, _ = athena_client.run_athena_query(
        query=query, database=athena_client.sa_database)

    if not create_table_success:
        LOGGER_CLI.error('The %s table could not be created',
                         sanitized_table_name)
        return

    # Update the CLI config so the partition refresh function knows about a
    # newly-used data bucket
    if (table != 'alerts' and bucket not in config['lambda']
            ['athena_partition_refresh_config']['buckets']):
        config['lambda']['athena_partition_refresh_config']['buckets'][
            bucket] = 'data'
        config.write()

    LOGGER_CLI.info('The %s table was successfully created!',
                    sanitized_table_name)
示例#16
0
class TestStreamAlertFirehose(object):
    """Test class for StreamAlertFirehose"""
    def __init__(self):
        self.__sa_firehose = None

    def teardown(self):
        """Teardown after each method"""
        self.__sa_firehose = None

    @staticmethod
    def _sample_categorized_payloads():
        """Sample records for two log types, keyed by log type name"""
        return {
            'unit_test_simple_log': [{
                'unit_key_01': 1,
                'unit_key_02': 'test'
            }, {
                'unit_key_01': 2,
                'unit_key_02': 'test'
            }],
            'test_log_type_json_nested': [{
                'date': 'January 01, 3005',
                'unixtime': '32661446400',
                'host': 'my-host.name.website.com',
                'data': {
                    'super': 'secret'
                }
            }]
        }

    @staticmethod
    def _throttled_response():
        """Build a fresh put_record_batch response with all three records throttled"""
        return {
            'FailedPutCount': 3,
            'RequestResponses': [{
                "ErrorCode": "ServiceUnavailableException",
                "ErrorMessage": "Slow down."
            } for _ in range(3)]
        }

    @mock_kinesis
    def _mock_delivery_streams(self, delivery_stream_names):
        """Mock Kinesis Delivery Streams for tests"""
        for delivery_stream in delivery_stream_names:
            self.__sa_firehose._firehose_client.create_delivery_stream(
                DeliveryStreamName=delivery_stream,
                S3DestinationConfiguration={
                    'RoleARN':
                    'arn:aws:iam::123456789012:role/firehose_delivery_role',
                    'BucketARN': 'arn:aws:s3:::kinesis-test',
                    'Prefix': '{}/'.format(delivery_stream),
                    'BufferingHints': {
                        'SizeInMBs': 123,
                        'IntervalInSeconds': 124
                    },
                    'CompressionFormat': 'Snappy',
                })

    def _prepare_firehose_and_streams(self):
        """Create the StreamAlertFirehose under test, queue the sample payloads
        on it and mock the delivery streams those payloads are sent to"""
        self.__sa_firehose = StreamAlertFirehose(region='us-east-1',
                                                 firehose_config={},
                                                 log_sources={})

        # Add sample categorized payloads
        for payload_type, logs in self._sample_categorized_payloads(
        ).iteritems():
            self.__sa_firehose.categorized_payloads[payload_type].extend(logs)

        # Setup mocked Delivery Streams
        self._mock_delivery_streams([
            'streamalert_data_test_log_type_json_nested',
            'streamalert_data_unit_test_simple_log'
        ])

    @patch('stream_alert.rule_processor.firehose.LOGGER')
    @mock_kinesis
    def test_record_delivery_failed_put_count(self, mock_logging):
        """StreamAlertFirehose - Record Delivery - Failed Put Count"""
        self._prepare_firehose_and_streams()

        successful_response = {
            'FailedPutCount': 0,
            'RequestResponses': [{
                "RecordId": "12345678910",
                "ErrorCode": "None",
                "ErrorMessage": "None"
            } for _ in range(3)]
        }

        # Two throttled attempts are followed by a fully successful one
        with patch.object(self.__sa_firehose._firehose_client,
                          'put_record_batch') as firehose_mock:
            firehose_mock.side_effect = [
                self._throttled_response(),
                self._throttled_response(),
                successful_response
            ]
            self.__sa_firehose.send()

            firehose_mock.assert_called()
            assert_true(mock_logging.info.called)

    @patch('stream_alert.rule_processor.firehose.LOGGER')
    @mock_kinesis
    def test_record_delivery(self, mock_logging):
        """StreamAlertFirehose - Record Delivery"""
        self._prepare_firehose_and_streams()

        # Send the records
        with patch.object(self.__sa_firehose._firehose_client,
                          'put_record_batch') as firehose_mock:
            firehose_mock.return_value = {'FailedPutCount': 0}
            self.__sa_firehose.send()

            firehose_mock.assert_called()
            assert_true(mock_logging.info.called)

    @patch('stream_alert.rule_processor.firehose.LOGGER')
    @mock_kinesis
    def test_record_delivery_failure(self, mock_logging):
        """StreamAlertFirehose - Record Delivery - Failed PutRecord"""
        self._prepare_firehose_and_streams()

        # Send the records
        with patch.object(self.__sa_firehose._firehose_client,
                          'put_record_batch') as firehose_mock:
            firehose_mock.return_value = {
                'FailedPutCount':
                3,
                'RequestResponses': [
                    {
                        'RecordId': '12345',
                        'ErrorCode': '300',
                        'ErrorMessage': 'Bad message!!!'
                    },
                ]
            }
            self.__sa_firehose.send()

            firehose_mock.assert_called()
            assert_true(mock_logging.error.called)

    @patch('stream_alert.rule_processor.firehose.LOGGER')
    @mock_kinesis
    def test_record_delivery_client_error(self, mock_logging):
        """StreamAlertFirehose - Record Delivery - Client Error"""
        sa_firehose = StreamAlertFirehose(region='us-east-1',
                                          firehose_config={},
                                          log_sources={})

        test_events = [
            # unit_test_simple_log
            {
                'unit_key_01': 2,
                'unit_key_02': 'testtest'
            } for _ in range(10)
        ]

        # No delivery stream named 'invalid_stream' exists, so the request is
        # expected to fail with something like:
        #   Client Error ... An error occurred (ResourceNotFoundException) when
        #   calling the PutRecordBatch operation: Stream invalid_stream under
        #   account 123456789012 not found.
        sa_firehose._firehose_request_helper('invalid_stream', test_events)

        # `error.called_with(...)` is not an assertion: it auto-creates a child
        # Mock, which is always truthy, so the previous check could never fail.
        assert_true(mock_logging.error.called)

    @mock_kinesis
    def test_load_enabled_sources(self):
        """StreamAlertFirehose - Load Enabled Sources"""
        config = load_config('tests/unit/conf')
        firehose_config = {
            'enabled_logs':
            ['json:regex_key_with_envelope', 'test_cloudtrail', 'cloudwatch']
        }  # 'cloudwatch' expands to 2 logs, for 4 enabled logs in total

        sa_firehose = StreamAlertFirehose(region='us-east-1',
                                          firehose_config=firehose_config,
                                          log_sources=config['logs'])

        assert_equal(len(sa_firehose._enabled_logs), 4)
        # Make sure the substitution works properly
        assert_true(all(':' not in log for log in sa_firehose.enabled_logs))
        assert_false(sa_firehose.enabled_log_source('test_inspec'))

    @patch('stream_alert.rule_processor.firehose.LOGGER')
    @mock_kinesis
    def test_load_enabled_sources_invalid_log(self, mock_logging):
        """StreamAlertFirehose - Load Enabled Sources - Invalid Log"""
        config = load_config('tests/unit/conf')
        firehose_config = {'enabled_logs': ['log-that-doesnt-exist']}

        sa_firehose = StreamAlertFirehose(region='us-east-1',
                                          firehose_config=firehose_config,
                                          log_sources=config['logs'])

        assert_equal(len(sa_firehose._enabled_logs), 0)
        assert_true(mock_logging.error.called)

    def test_segment_records_by_size(self):
        """StreamAlertFirehose - Segment Large Records"""
        sa_firehose = StreamAlertFirehose(region='us-east-1',
                                          firehose_config={},
                                          log_sources={})

        record_batch = [
            # unit_test_simple_log
            {
                'unit_key_01': 2,
                'unit_key_02': 'testtest' * 10000
            } for _ in range(100)
        ]

        sized_batches = list(sa_firehose._segment_records_by_size(record_batch))

        assert_true(len(str(sized_batches[0])) < 4000000)
        assert_equal(len(sized_batches), 4)
        assert_true(isinstance(sized_batches[3][0], dict))

    def test_sanitize_keys(self):
        """StreamAlertFirehose - Sanitize Keys"""
        # test_log_type_json_nested
        test_event = {
            'date': 'January 01, 3005',
            'unixtime': '32661446400',
            'host': 'my-host.name.website.com',
            'data': {
                'super-duper': 'secret',
                'sanitize_me': 1,
                'example-key': 1,
                'moar**data': 2,
                'even.more': 3
            }
        }

        expected_sanitized_event = {
            'date': 'January 01, 3005',
            'unixtime': '32661446400',
            'host': 'my-host.name.website.com',
            'data': {
                'super_duper': 'secret',
                'sanitize_me': 1,
                'example_key': 1,
                'moar__data': 2,
                'even_more': 3
            }
        }

        sanitized_event = StreamAlertFirehose.sanitize_keys(test_event)
        assert_equal(sanitized_event, expected_sanitized_event)

    @patch('stream_alert.rule_processor.firehose.LOGGER')
    def test_limit_record_size(self, mock_logging):
        """StreamAlertFirehose - Record Size Check"""
        test_events = [
            # unit_test_simple_log
            {
                'unit_key_01': 1,
                'unit_key_02': 'test' * 250001  # is 4 bytes higher than max
            },
            {
                'unit_key_01': 2,
                'unit_key_02': 'test'
            },
            # test_log_type_json_nested
            {
                'date': 'January 01, 3005',
                'unixtime': '32661446400',
                'host': 'my-host.name.website.com',
                'data': {
                    'super': 'secret'
                }
            }
        ]

        # The oversized record is expected to be dropped from the list in place
        StreamAlertFirehose._limit_record_size(test_events)

        # assert_true(len(...), 2) used 2 as the failure *message* and passed for
        # any non-empty list; compare the length explicitly instead
        assert_equal(len(test_events), 2)
        assert_true(mock_logging.error.called)
示例#17
0
def rebuild_partitions(table, bucket, config):
    """Rebuild an Athena table's partitions

    Steps:
      - Get the list of current partitions
      - Destroy existing table
      - Re-create tables
      - Re-create partitions

    If there are no existing partitions the table is left untouched. If the
    generated partition statement exceeds MAX_QUERY_LENGTH, it is written to
    'partitions_<table>.txt' in the current directory instead of being run.

    Args:
        table (str): The name of the table being rebuilt; it is converted with
            StreamAlertFirehose.firehose_log_name before being queried
        bucket (str): The s3 bucket to be used as the location for Athena data
        config (CLIConfig): Loaded StreamAlert CLI
    """
    athena_client = StreamAlertAthenaClient(
        config, results_key_prefix='stream_alert_cli')

    sa_firehose = StreamAlertFirehose(
        config['global']['account']['region'],
        config['global']['infrastructure']['firehose'], config['logs'])

    sanitized_table_name = sa_firehose.firehose_log_name(table)

    # Get the current set of partitions
    partition_success, partitions = athena_client.run_athena_query(
        query='SHOW PARTITIONS {}'.format(sanitized_table_name),
        database=athena_client.sa_database)
    if not partition_success:
        LOGGER_CLI.error('An error occurred when loading partitions for %s',
                         sanitized_table_name)
        return

    unique_partitions = athena_helpers.unique_values_from_query(partitions)

    # Nothing to restore, so avoid dropping the table at all
    if not unique_partitions:
        LOGGER_CLI.info('No partitions to rebuild for %s, nothing to do',
                        sanitized_table_name)
        return

    # Drop the table
    LOGGER_CLI.info('Dropping table %s', sanitized_table_name)
    drop_success, _ = athena_client.run_athena_query(
        query='DROP TABLE {}'.format(sanitized_table_name),
        database=athena_client.sa_database)
    if not drop_success:
        LOGGER_CLI.error('An error occurred when dropping the %s table',
                         sanitized_table_name)
        return

    LOGGER_CLI.info('Dropped table %s', sanitized_table_name)

    LOGGER_CLI.info('Creating table %s', sanitized_table_name)

    # Re-create the table before handling partitions so it exists even if the
    # partition statement below cannot be executed directly
    create_table(table, bucket, config)

    new_partitions_statement = athena_helpers.partition_statement(
        unique_partitions, bucket, sanitized_table_name)

    # Make sure our new alter table statement is within the query API limits;
    # otherwise persist it locally so it can be applied manually
    if len(new_partitions_statement) > MAX_QUERY_LENGTH:
        LOGGER_CLI.error(
            'Partition statement too large, writing to local file')
        with open('partitions_{}.txt'.format(sanitized_table_name),
                  'w') as partition_file:
            partition_file.write(new_partitions_statement)
        return

    LOGGER_CLI.info('Creating %d new partitions for %s',
                    len(unique_partitions), sanitized_table_name)
    new_part_success, _ = athena_client.run_athena_query(
        query=new_partitions_statement, database=athena_client.sa_database)
    if not new_part_success:
        LOGGER_CLI.error('Error re-creating new partitions for %s',
                         sanitized_table_name)
        return

    LOGGER_CLI.info('Successfully rebuilt partitions for %s',
                    sanitized_table_name)
示例#18
0
def create_table(athena_client, options, config):
    """Create a 'streamalert' Athena table

    Supported table types (options.type) are 'data' and 'alerts'; any other
    value is logged as an error and no query is run. On success the CLI
    config's athena_partition_refresh_config is updated and written.

    Args:
        athena_client (boto3.client): Instantiated CLI AthenaClient
        options (namedtuple): The parsed args passed from the CLI
        config (CLIConfig): Loaded StreamAlert CLI
    """
    sa_firehose = StreamAlertFirehose(
        config['global']['account']['region'],
        config['global']['infrastructure']['firehose'], config['logs'])

    if not options.bucket:
        LOGGER_CLI.error('Missing command line argument --bucket')
        return

    if not options.refresh_type:
        LOGGER_CLI.error('Missing command line argument --refresh_type')
        return

    if options.type == 'data':
        if not options.table_name:
            LOGGER_CLI.error('Missing command line argument --table_name')
            return

        # Convert special characters in schema name to underscores
        sanitized_table_name = sa_firehose.firehose_log_name(
            options.table_name)

        # Check that the log type is enabled via Firehose
        if sanitized_table_name not in sa_firehose.enabled_logs:
            LOGGER_CLI.error(
                'Table name %s missing from configuration or '
                'is not enabled.', sanitized_table_name)
            return

        # Check if the table exists
        if athena_client.check_table_exists(sanitized_table_name):
            LOGGER_CLI.info('The \'%s\' table already exists.',
                            sanitized_table_name)
            return

        log_info = config['logs'][options.table_name.replace('_', ':', 1)]

        schema = dict(log_info['schema'])
        sanitized_schema = StreamAlertFirehose.sanitize_keys(schema)

        athena_schema = handler_helpers.to_athena_schema(sanitized_schema)

        # Add envelope keys to Athena Schema
        configuration_options = log_info.get('configuration')
        if configuration_options:
            envelope_keys = configuration_options.get('envelope_keys')
            if envelope_keys:
                sanitized_envelope_key_schema = StreamAlertFirehose.sanitize_keys(
                    envelope_keys)
                # Note: this key is wrapped in backticks to be Hive compliant
                athena_schema[
                    '`streamalert:envelope_keys`'] = handler_helpers.to_athena_schema(
                        sanitized_envelope_key_schema)

        # Handle Schema overrides
        #   This is useful when an Athena schema needs to differ from the normal log schema
        if options.schema_override:
            for override in options.schema_override:
                # Exactly one '=' is required: none means no type was given, and
                # more than one would make the two-name unpacking below raise
                if override.count('=') != 1:
                    LOGGER_CLI.error(
                        'Invalid schema override [%s], use column_name=type format',
                        override)
                    return

                column_name, column_type = override.split('=')
                if not all([column_name, column_type]):
                    LOGGER_CLI.error(
                        'Invalid schema override [%s], use column_name=type format',
                        override)
                    return

                # Columns are escaped to avoid Hive issues with special characters
                column_name = '`{}`'.format(column_name)
                if column_name in athena_schema:
                    athena_schema[column_name] = column_type
                    LOGGER_CLI.info('Applied schema override: %s:%s',
                                    column_name, column_type)
                else:
                    LOGGER_CLI.error(
                        'Schema override column %s not found in Athena Schema, skipping',
                        column_name)

        query = _construct_create_table_statement(
            schema=athena_schema,
            table_name=sanitized_table_name,
            bucket=options.bucket)
        table_name = sanitized_table_name

    elif options.type == 'alerts':
        if athena_client.check_table_exists(options.type):
            LOGGER_CLI.info('The \'alerts\' table already exists.')
            return
        query = ALERTS_TABLE_STATEMENT.format(bucket=options.bucket)
        table_name = options.type

    else:
        # Previously fell through to an unbound 'query' (UnboundLocalError)
        LOGGER_CLI.error('Unsupported table type: %s', options.type)
        return

    if not query:
        LOGGER_CLI.error('Unable to build a CREATE TABLE statement for %s',
                         table_name)
        return

    create_table_success, _ = athena_client.run_athena_query(
        query=query, database='streamalert')
    if not create_table_success:
        LOGGER_CLI.error('The %s table could not be created', table_name)
        return

    # Update the CLI config
    refresh_config = config['lambda']['athena_partition_refresh_config']
    refresh_config['refresh_type'][options.refresh_type][options.bucket] = options.type
    config.write()

    LOGGER_CLI.info('The %s table was successfully created!', table_name)
示例#19
0
class StreamAlert(object):
    """Wrapper class for handling StreamAlert classification and processing"""
    config = {}

    def __init__(self, context):
        """Initializer

        Args:
            context (dict): An AWS context object which provides metadata on the currently
                executing lambda function.
        """
        # Load the config. Validation occurs during load, which will
        # raise exceptions on any ConfigErrors
        StreamAlert.config = StreamAlert.config or load_config()

        # Load the environment from the context arn
        self.env = load_env(context)

        # Instantiate the send_alerts here to handle sending the triggered alerts to the
        # alert processor
        self.alert_forwarder = AlertForwarder()

        # Instantiate a classifier that is used for this run
        self.classifier = StreamClassifier(config=self.config)

        self._failed_record_count = 0
        self._processed_record_count = 0
        self._processed_size = 0
        self._alerts = []

        # Iterate an ordered tuple (not a set literal) so rule and matcher
        # import paths are always collected in the same deterministic order
        rule_import_paths = [
            item for location in ('rule_locations', 'matcher_locations')
            for item in self.config['global']['general'][location]
        ]

        # Create an instance of the StreamRules class that gets cached in the
        # StreamAlert class as an instance property
        self._rules_engine = RulesEngine(self.config, *rule_import_paths)

        # Firehose client attribute
        self._firehose_client = None

    def run(self, event):
        """StreamAlert Lambda function handler.

        Loads the configuration for the StreamAlert function which contains
        available data sources, log schemas, normalized types, and outputs.
        Classifies logs sent into a parsed type.
        Matches records against rules.

        Args:
            event (dict): An AWS event mapped to a specific source/entity
                containing data read by Lambda.

        Returns:
            bool: True if all logs being parsed match a schema; False when the
                event carries no records
        """
        records = event.get('Records', [])
        LOGGER.debug('Number of incoming records: %d', len(records))
        if not records:
            return False

        firehose_config = self.config['global'].get('infrastructure',
                                                    {}).get('firehose', {})
        if firehose_config.get('enabled'):
            self._firehose_client = StreamAlertFirehose(
                self.env['lambda_region'], firehose_config,
                self.config['logs'])

        payload_with_normalized_records = []
        for raw_record in records:
            # Get the service and entity from the payload. If the service/entity
            # is not in our config, log an error and go onto the next record
            service, entity = self.classifier.extract_service_and_entity(
                raw_record)
            if not service:
                LOGGER.error(
                    'No valid service found in payload\'s raw record. Skipping '
                    'record: %s', raw_record)
                continue

            if not entity:
                LOGGER.error(
                    'Unable to extract entity from payload\'s raw record for service %s. '
                    'Skipping record: %s', service, raw_record)
                continue

            # Cache the log sources for this service and entity on the classifier
            if not self.classifier.load_sources(service, entity):
                continue

            # Create the StreamPayload to use for encapsulating parsed info
            payload = load_stream_payload(service, entity, raw_record)
            if not payload:
                continue

            payload_with_normalized_records.extend(
                self._process_alerts(payload))

        # Log normalized records metric
        MetricLogger.log_metric(FUNCTION_NAME, MetricLogger.NORMALIZED_RECORDS,
                                len(payload_with_normalized_records))

        # Apply Threat Intel to normalized records in the end of Rule Processor invocation
        record_alerts = self._rules_engine.threat_intel_match(
            payload_with_normalized_records)
        self._alerts.extend(record_alerts)
        if record_alerts:
            self.alert_forwarder.send_alerts(record_alerts)

        MetricLogger.log_metric(FUNCTION_NAME, MetricLogger.TOTAL_RECORDS,
                                self._processed_record_count)

        MetricLogger.log_metric(FUNCTION_NAME,
                                MetricLogger.TOTAL_PROCESSED_SIZE,
                                self._processed_size)

        LOGGER.debug('Invalid record count: %d', self._failed_record_count)

        MetricLogger.log_metric(FUNCTION_NAME, MetricLogger.FAILED_PARSES,
                                self._failed_record_count)

        LOGGER.debug('%s alerts triggered', len(self._alerts))

        MetricLogger.log_metric(FUNCTION_NAME, MetricLogger.TRIGGERED_ALERTS,
                                len(self._alerts))

        # Check if debugging logging is on before json dumping alerts since
        # this can be time consuming if there are a lot of alerts
        if self._alerts and LOGGER.isEnabledFor(LOG_LEVEL_DEBUG):
            LOGGER.debug(
                'Alerts:\n%s',
                json.dumps([alert.output_dict() for alert in self._alerts],
                           indent=2,
                           sort_keys=True))

        if self._firehose_client:
            self._firehose_client.send()

        # Only log rule info here if this is not running tests
        # During testing, this gets logged at the end and printing here could be confusing
        # since stress testing calls this method multiple times
        if self.env['lambda_alias'] != 'development':
            stats.print_rule_stats(True)

        return self._failed_record_count == 0

    @property
    def alerts(self):
        """Returns list of Alert instances (useful for testing)."""
        return self._alerts

    def _process_alerts(self, payload):
        """Run each record of the payload through the rules engine.

        Records that fail classification are counted as failures and skipped.
        Alerts produced for valid records are stored on the instance and passed
        to the alert forwarder. When a Firehose client is configured and the
        record's log source is enabled, the parsed records are queued for it.

        Args:
            payload (StreamPayload): StreamAlert payload object being processed

        Returns:
            list: Normalized records returned by the rules engine, later used
                by run() for threat intel matching
        """
        payload_with_normalized_records = []
        for record in payload.pre_parse():
            # Increment the processed size using the length of this record
            self._processed_size += len(record.pre_parsed_record)
            self.classifier.classify_record(record)
            if not record.valid:
                if self.env['lambda_alias'] != 'development':
                    LOGGER.error(
                        'Record does not match any defined schemas: %s\n%s',
                        record, record.pre_parsed_record)

                self._failed_record_count += 1
                continue

            # Increment the total processed records to get an accurate assessment of throughput
            self._processed_record_count += len(record.records)

            LOGGER.debug(
                'Classified and Parsed Payload: <Valid: %s, Log Source: %s, Entity: %s>',
                record.valid, record.log_source, record.entity)

            record_alerts, normalized_records = self._rules_engine.run(record)

            payload_with_normalized_records.extend(normalized_records)

            # Use the record currently being processed (as the rest of this loop
            # does) rather than the enclosing payload
            LOGGER.debug(
                'Processed %d valid record(s) that resulted in %d alert(s).',
                len(record.records), len(record_alerts))

            # Add all parsed records to the categorized payload dict only if Firehose is enabled
            if self._firehose_client:
                # Only send records with enabled log sources
                if self._firehose_client.enabled_log_source(record.log_source):
                    self._firehose_client.categorized_payloads[
                        record.log_source].extend(record.records)

            if not record_alerts:
                continue

            # Extend the list of alerts with any new ones so they can be returned
            self._alerts.extend(record_alerts)

            self.alert_forwarder.send_alerts(record_alerts)

        return payload_with_normalized_records
示例#20
0
    def run(self, event):
        """Classify every record in a Lambda event and evaluate it against rules.

        For each incoming record the service and entity are resolved, the
        matching log sources are loaded, and the resulting payload is
        classified and run through the rules engine. Threat intel matching is
        then applied to all normalized records, metrics are emitted, and any
        queued Firehose data is sent.

        Args:
            event (dict): An AWS event mapped to a specific source/entity
                containing data read by Lambda.

        Returns:
            bool: True when no record failed to match a schema; False when the
                event contains no records at all
        """
        incoming_records = event.get('Records', [])
        LOGGER.debug('Number of incoming records: %d', len(incoming_records))
        if not incoming_records:
            return False

        infrastructure = self.config['global'].get('infrastructure', {})
        firehose_settings = infrastructure.get('firehose', {})
        if firehose_settings.get('enabled'):
            self._firehose_client = StreamAlertFirehose(
                self.env['lambda_region'], firehose_settings, self.config['logs'])

        def _emit_metric(metric_name, metric_value):
            """Log one metric for this function"""
            MetricLogger.log_metric(FUNCTION_NAME, metric_name, metric_value)

        normalized_records = []
        for raw_record in incoming_records:
            # Records whose service or entity cannot be determined are logged and skipped
            service, entity = self.classifier.extract_service_and_entity(raw_record)
            if not service:
                LOGGER.error(
                    "No valid service found in payload's raw record. Skipping record: %s",
                    raw_record)
                continue
            if not entity:
                LOGGER.error(
                    "Unable to extract entity from payload's raw record for service %s. "
                    "Skipping record: %s", service, raw_record)
                continue

            # Skip anything the classifier has no log sources for
            if not self.classifier.load_sources(service, entity):
                continue

            stream_payload = load_stream_payload(service, entity, raw_record)
            if stream_payload:
                normalized_records.extend(self._process_alerts(stream_payload))

        _emit_metric(MetricLogger.NORMALIZED_RECORDS, len(normalized_records))

        # Threat intel runs once, over every normalized record from this invocation
        intel_alerts = self._rules_engine.threat_intel_match(normalized_records)
        self._alerts.extend(intel_alerts)
        if intel_alerts:
            self.alert_forwarder.send_alerts(intel_alerts)

        _emit_metric(MetricLogger.TOTAL_RECORDS, self._processed_record_count)
        _emit_metric(MetricLogger.TOTAL_PROCESSED_SIZE, self._processed_size)

        LOGGER.debug('Invalid record count: %d', self._failed_record_count)
        _emit_metric(MetricLogger.FAILED_PARSES, self._failed_record_count)

        LOGGER.debug('%s alerts triggered', len(self._alerts))
        _emit_metric(MetricLogger.TRIGGERED_ALERTS, len(self._alerts))

        # Serializing many alerts is expensive, so only do it when debug output is on
        if self._alerts and LOGGER.isEnabledFor(LOG_LEVEL_DEBUG):
            alert_dicts = [alert.output_dict() for alert in self._alerts]
            LOGGER.debug('Alerts:\n%s',
                         json.dumps(alert_dicts, indent=2, sort_keys=True))

        if self._firehose_client:
            self._firehose_client.send()

        # Rule stats are printed once at the end during tests (stress tests call
        # this repeatedly), so skip them in the development alias
        if self.env['lambda_alias'] != 'development':
            stats.print_rule_stats(True)

        all_records_valid = self._failed_record_count == 0
        return all_records_valid
示例#21
0
def rebuild_partitions(athena_client, options, config):
    """Rebuild an Athena table's partitions

    Steps:
      - Get the list of current partitions
      - Destroy existing table
      - Re-create tables
      - Re-create partitions

    Only the 'data' table type is supported. If the table has no partitions,
    nothing is dropped. If the partition statement exceeds MAX_QUERY_LENGTH,
    the table is still re-created and the statement is written to
    'partitions_<table>.txt' in the current directory for manual application.

    Args:
        athena_client (boto3.client): Instantiated CLI AthenaClient
        options (namedtuple): The parsed args passed from the CLI
        config (CLIConfig): Loaded StreamAlert CLI
    """
    if not options.table_name:
        LOGGER_CLI.error('Missing command line argument --table_name')
        return

    if not options.bucket:
        LOGGER_CLI.error('Missing command line argument --bucket')
        return

    sa_firehose = StreamAlertFirehose(
        config['global']['account']['region'],
        config['global']['infrastructure']['firehose'], config['logs'])
    sanitized_table_name = sa_firehose.firehose_log_name(options.table_name)

    if options.type != 'data':
        LOGGER_CLI.info('Refreshing alerts tables unsupported')
        return

    # Get the current set of partitions
    partition_success, partitions = athena_client.run_athena_query(
        query='SHOW PARTITIONS {}'.format(sanitized_table_name),
        database='streamalert')
    if not partition_success:
        LOGGER_CLI.error('An error occured when loading partitions for %s',
                         sanitized_table_name)
        return

    unique_partitions = athena_helpers.unique_values_from_query(partitions)

    # Nothing to restore: avoid dropping the table and then running an
    # empty partition statement against it
    if not unique_partitions:
        LOGGER_CLI.info('No partitions to rebuild for %s, nothing to do',
                        sanitized_table_name)
        return

    # Drop the table
    LOGGER_CLI.info('Dropping table %s', sanitized_table_name)
    drop_success, _ = athena_client.run_athena_query(
        query='DROP TABLE {}'.format(sanitized_table_name),
        database='streamalert')
    if not drop_success:
        LOGGER_CLI.error('An error occured when dropping the %s table',
                         sanitized_table_name)
        return

    LOGGER_CLI.info('Dropped table %s', sanitized_table_name)

    # Re-create the table right after dropping it. Previously an oversized
    # partition statement returned before this point, leaving the table deleted.
    options.refresh_type = 'add_hive_partition'
    create_table(athena_client, options, config)

    new_partitions_statement = athena_helpers.partition_statement(
        unique_partitions, options.bucket, sanitized_table_name)

    # Make sure our new alter table statement is within the query API limits;
    # otherwise persist it locally so it can be applied manually
    if len(new_partitions_statement) > MAX_QUERY_LENGTH:
        LOGGER_CLI.error(
            'Partition statement too large, writing to local file')
        with open('partitions_{}.txt'.format(sanitized_table_name),
                  'w') as partition_file:
            partition_file.write(new_partitions_statement)
        return

    LOGGER_CLI.info('Creating %d new partitions for %s',
                    len(unique_partitions), sanitized_table_name)
    new_part_success, _ = athena_client.run_athena_query(
        query=new_partitions_statement, database='streamalert')
    if not new_part_success:
        LOGGER_CLI.error('Error re-creating new partitions for %s',
                         sanitized_table_name)
        return

    LOGGER_CLI.info('Successfully rebuilt partitions for %s',
                    sanitized_table_name)
示例#22
0
def create_table(athena_client, options, config):
    """Create a 'streamalert' Athena table

    Supported table types (options.type) are 'data' and 'alerts'; any other
    value is logged as an error and no query is run. On success the CLI
    config's athena_partition_refresh_config is updated and written.

    Args:
        athena_client (boto3.client): Instantiated CLI AthenaClient
        options (namedtuple): The parsed args passed from the CLI
        config (CLIConfig): Loaded StreamAlert CLI
    """
    sa_firehose = StreamAlertFirehose(
        config['global']['account']['region'],
        config['global']['infrastructure']['firehose'], config['logs'])

    if not options.bucket:
        LOGGER_CLI.error('Missing command line argument --bucket')
        return

    if not options.refresh_type:
        LOGGER_CLI.error('Missing command line argument --refresh_type')
        return

    if options.type == 'data':
        if not options.table_name:
            LOGGER_CLI.error('Missing command line argument --table_name')
            return

        sanitized_table_name = sa_firehose.firehose_log_name(
            options.table_name)

        if sanitized_table_name not in sa_firehose.enabled_logs:
            LOGGER_CLI.error(
                'Table name %s missing from configuration or '
                'is not enabled.', sanitized_table_name)
            return

        if athena_client.check_table_exists(sanitized_table_name):
            LOGGER_CLI.info('The \'%s\' table already exists.',
                            sanitized_table_name)
            return

        log_info = config['logs'][options.table_name.replace('_', ':', 1)]
        schema = dict(log_info['schema'])

        sanitized_schema = StreamAlertFirehose.sanitize_keys(schema)
        athena_schema = {}

        _add_to_athena_schema(sanitized_schema, athena_schema)

        # Support envelope keys
        configuration_options = log_info.get('configuration')
        if configuration_options:
            envelope_keys = configuration_options.get('envelope_keys')
            if envelope_keys:
                sanitized_envelope_key_schema = StreamAlertFirehose.sanitize_keys(
                    envelope_keys)
                # Note: this key is wrapped in backticks to be Hive compliant
                _add_to_athena_schema(sanitized_envelope_key_schema,
                                      athena_schema,
                                      '`streamalert:envelope_keys`')

        # Build one definition per column and join them. The previous approach
        # appended ', ' after structs but ',' after flat columns and then trimmed
        # a single character, leaving a dangling comma (invalid DDL) whenever a
        # struct column happened to be last.
        column_definitions = []
        for key_name, key_type in athena_schema.items():
            # Account for nested structs
            if isinstance(key_type, dict):
                struct_schema = ','.join(
                    '{0}:{1}'.format(sub_key, sub_type)
                    for sub_key, sub_type in key_type.items())
                column_definitions.append(
                    '{0} struct<{1}>'.format(key_name, struct_schema))
            else:
                column_definitions.append('{0} {1}'.format(key_name, key_type))
        schema_statement = ', '.join(column_definitions)

        query = (
            'CREATE EXTERNAL TABLE {table_name} ({schema}) '
            'PARTITIONED BY (dt string) '
            'ROW FORMAT SERDE \'org.openx.data.jsonserde.JsonSerDe\' '
            'WITH SERDEPROPERTIES ( \'ignore.malformed.json\' = \'true\') '
            'LOCATION \'s3://{bucket}/{table_name}/\''.format(
                table_name=sanitized_table_name,
                schema=schema_statement,
                bucket=options.bucket))
        table_name = sanitized_table_name

    elif options.type == 'alerts':
        if athena_client.check_table_exists(options.type):
            LOGGER_CLI.info('The \'alerts\' table already exists.')
            return

        # Clauses are separated by spaces for readability of the generated DDL
        query = ('CREATE EXTERNAL TABLE alerts ('
                 'log_source string,'
                 'log_type string,'
                 'outputs array<string>,'
                 'record string,'
                 'rule_description string,'
                 'rule_name string,'
                 'source_entity string,'
                 'source_service string) '
                 'PARTITIONED BY (dt string) '
                 'ROW FORMAT SERDE \'org.openx.data.jsonserde.JsonSerDe\' '
                 'LOCATION \'s3://{bucket}/alerts/\''.format(
                     bucket=options.bucket))
        table_name = options.type

    else:
        # Previously fell through to an unbound 'query' (UnboundLocalError)
        LOGGER_CLI.error('Unsupported table type: %s', options.type)
        return

    create_table_success, _ = athena_client.run_athena_query(
        query=query, database='streamalert')
    if not create_table_success:
        LOGGER_CLI.error('The %s table could not be created', table_name)
        return

    # Update the CLI config
    refresh_config = config['lambda']['athena_partition_refresh_config']
    refresh_config['refresh_type'][options.refresh_type][options.bucket] = options.type
    config.write()

    LOGGER_CLI.info('The %s table was successfully created!', table_name)