Example #1
def generate_data_table_schema(config, table, schema_override=None):
    """Generate the schema for data table in terraform

    Args:
        config (CLIConfig): Loaded StreamAlert config
        table (string): The name of data table

    Returns:
        athena_schema (dict): Equivalent Athena schema used for generating create table statement
    """
    enabled_logs = FirehoseClient.load_enabled_log_sources(
        config['global']['infrastructure']['firehose'], config['logs'])

    # Convert special characters in schema name to underscores
    sanitized_table_name = FirehoseClient.sanitized_value(table)

    # Check that the log type is enabled via Firehose
    if sanitized_table_name not in enabled_logs:
        LOGGER.error(
            'Table name %s missing from configuration or '
            'is not enabled.', sanitized_table_name)
        return None

    log_info = config['logs'][enabled_logs.get(sanitized_table_name)]

    schema = dict(log_info['schema'])
    sanitized_schema = FirehoseClient.sanitize_keys(schema)

    athena_schema = logs_schema_to_athena_schema(sanitized_schema, False)

    # Add envelope keys to Athena Schema
    configuration_options = log_info.get('configuration')
    if configuration_options:
        envelope_keys = configuration_options.get('envelope_keys')
        if envelope_keys:
            sanitized_envelope_key_schema = FirehoseClient.sanitize_keys(
                envelope_keys)
            # Note: unlike create_table, this key is not wrapped in backticks here
            athena_schema[
                'streamalert:envelope_keys'] = logs_schema_to_athena_schema(
                    sanitized_envelope_key_schema, False)

    # Handle Schema overrides
    #   This is useful when an Athena schema needs to differ from the normal log schema
    if schema_override:
        for override in schema_override:
            column_name, column_type = override.split('=')
            # Unlike create_table, column names here are not backtick-escaped
            if column_name in athena_schema:
                athena_schema[column_name] = column_type
                LOGGER.info('Applied schema override: %s:%s', column_name,
                            column_type)
            else:
                LOGGER.error(
                    'Schema override column %s not found in Athena Schema, skipping',
                    column_name)

    return format_schema_tf(athena_schema)
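
The schema_override entries handled above are plain 'column=type' strings. A minimal standalone sketch of that parsing, using hypothetical column names rather than a real StreamAlert schema:

# Standalone sketch of the override handling above; column names and types are hypothetical
athena_schema = {'host': 'string', 'count': 'bigint'}

for override in ['count=string', 'missing_col=int']:
    column_name, column_type = override.split('=')
    if column_name in athena_schema:
        athena_schema[column_name] = column_type  # override applied
    # unknown columns are skipped, mirroring the LOGGER.error branch above

print(athena_schema)  # {'host': 'string', 'count': 'string'}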
Example #2
    @patch('logging.Logger.info')
    def test_finalize_success(self, log_mock):
        """FirehoseClient - Finalize, Success"""
        request_id = 'success_id'
        stream_name = 'stream_name'
        count = 3
        response = {'ResponseMetadata': {'RequestId': request_id}}

        FirehoseClient._finalize(response, stream_name, count,
                                 'test_function_name')
        log_mock.assert_called_with(
            'Successfully sent %d message(s) to firehose %s with RequestId \'%s\'',
            count, stream_name, request_id)
Example #3
    def setup(self):
        """Setup before each method"""
        with patch('boto3.client'):
            ArtifactExtractor._firehose_client = FirehoseClient(
                prefix='unit-test')

        self._artifact_extractor = ArtifactExtractor('unit_test_dst_fh_arn')
Example #4
    def test_sanitize_keys(self):
        """FirehoseClient - Sanitize Keys"""
        test_event = {
            'date': 'January 01, 3005',
            'data': {
                'super-duper': 'secret',
                'do_not_sanitize_me': 1,
                'example-key': 2,
                'moar**data': 3,
                'even.more': 4
            }
        }

        expected_sanitized_event = {
            'date': 'January 01, 3005',
            'data': {
                'super_duper': 'secret',
                'do_not_sanitize_me': 1,
                'example_key': 2,
                'moar__data': 3,
                'even_more': 4
            }
        }

        sanitized_event = FirehoseClient.sanitize_keys(test_event)
        assert_equal(sanitized_event, expected_sanitized_event)
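
The test data above suggests sanitize_keys replaces any character outside [a-zA-Z0-9_] with an underscore and recurses into nested dicts. A rough standalone sketch under that assumption (not the library implementation):

import re

def sanitize_keys_sketch(event):
    """Rough re-implementation inferred from the test above; recurses into nested dicts."""
    sanitized = {}
    for key, value in event.items():
        clean_key = re.sub(r'[^a-zA-Z0-9_]', '_', key)
        sanitized[clean_key] = sanitize_keys_sketch(value) if isinstance(value, dict) else value
    return sanitized

print(sanitize_keys_sketch({'moar**data': 3, 'even.more': {'super-duper': 'secret'}}))
# {'moar__data': 3, 'even_more': {'super_duper': 'secret'}}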
Example #5
def create_log_tables(config):
    """Create all tables needed for historical search
    Args:
        config (CLIConfig): Loaded StreamAlert config
    Returns:
        bool: False if errors occurred, True otherwise
    """
    if not config['global']['infrastructure'].get('firehose',
                                                  {}).get('enabled'):
        return True

    firehose_config = config['global']['infrastructure']['firehose']
    firehose_s3_bucket_suffix = firehose_config.get('s3_bucket_suffix',
                                                    'streamalert-data')
    firehose_s3_bucket_name = '{}-{}'.format(
        config['global']['account']['prefix'], firehose_s3_bucket_suffix)

    enabled_logs = FirehoseClient.load_enabled_log_sources(
        config['global']['infrastructure']['firehose'], config['logs'])

    for log_stream_name in enabled_logs:
        if not create_table(log_stream_name, firehose_s3_bucket_name, config):
            return False

    return True
Example #6
def rebuild_partitions(table, bucket, config):
    """Rebuild an Athena table's partitions

    Steps:
      - Get the list of current partitions
      - Destroy existing table
      - Re-create the table
      - Re-create partitions

    Args:
        table (str): The name of the table being rebuilt
        bucket (str): The s3 bucket to be used as the location for Athena data
        config (CLIConfig): Loaded StreamAlert config

    Returns:
        bool: False if errors occurred, True otherwise
    """
    sanitized_table_name = FirehoseClient.sanitized_value(table)

    athena_client = get_athena_client(config)

    # Get the current set of partitions
    partitions = athena_client.get_table_partitions(sanitized_table_name)
    if not partitions:
        LOGGER.info('No partitions to rebuild for %s, nothing to do',
                    sanitized_table_name)
        return False

    # Drop the table
    LOGGER.info('Dropping table %s', sanitized_table_name)
    if not athena_client.drop_table(sanitized_table_name):
        return False

    LOGGER.info('Creating table %s', sanitized_table_name)

    # Re-create the table with previous partitions
    if not create_table(table, bucket, config):
        return False

    new_partitions_statements = helpers.add_partition_statements(
        partitions, bucket, sanitized_table_name)

    LOGGER.info('Creating total %d new partitions for %s', len(partitions),
                sanitized_table_name)

    for idx, statement in enumerate(new_partitions_statements):
        success = athena_client.run_query(query=statement)
        LOGGER.info('Rebuilt partitions part %d', idx + 1)
        if not success:
            LOGGER.error('Error re-creating new partitions for %s',
                         sanitized_table_name)
            write_partitions_statements(new_partitions_statements,
                                        sanitized_table_name)
            return False

    LOGGER.info('Successfully rebuilt all partitions for %s',
                sanitized_table_name)
    return True
Example #7
    def test_load_from_config(self):
        """FirehoseClient - Load From Config"""
        with patch('boto3.client'):  # patch to speed up unit tests slightly
            client = FirehoseClient.load_from_config(
                prefix='unit-test',
                firehose_config={'enabled': True},
                log_sources=None)
            assert_equal(isinstance(client, FirehoseClient), True)
Example #8
    def test_records_to_json_list(self):
        """FirehoseClient - Records JSON Lines"""
        records = self._sample_raw_records()

        expected_result = ['{"key_0":"value_0"}\n', '{"key_1":"value_1"}\n']

        result = FirehoseClient._records_to_json_list(records)
        assert_equal(result, expected_result)
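
The expected strings above are compact (no spaces after separators) and newline-terminated, suggesting serialization roughly like this sketch (not necessarily the exact library call):

import json

def records_to_json_list_sketch(records):
    # Compact separators and a trailing newline match the expected_result above
    return [json.dumps(record, separators=(',', ':')) + '\n' for record in records]

print(records_to_json_list_sketch([{'key_0': 'value_0'}, {'key_1': 'value_1'}]))
# ['{"key_0":"value_0"}\n', '{"key_1":"value_1"}\n']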
Example #9
    @patch.object(FirehoseClient, '_log_failed')
    def test_record_batches_rec_too_large(self, failure_mock):
        """FirehoseClient - Record Batches, Record Too Large"""
        records = [{'key': 'test' * 1000 * 1000}]

        result = list(
            FirehoseClient._record_batches(records, 'test_function_name'))
        assert_equal(result, [])
        failure_mock.assert_called_with(1, 'test_function_name')
Example #10
    @patch.object(FirehoseClient, '_log_failed')
    def test_finalize_failures(self, failure_mock):
        """FirehoseClient - Finalize, With Failures"""
        response = {
            'FailedPutCount':
            1,
            'RequestResponses': [{
                'RecordId': 'rec_id_01'
            }, {
                'ErrorCode': 10,
                'ErrorMessage': 'foo'
            }, {
                'RecordId': 'rec_id_03'
            }]
        }

        FirehoseClient._finalize(response, 'stream_name', 3,
                                 'test_function_name')
        failure_mock.assert_called_with(1, 'test_function_name')
Example #11
    def test_record_batches(self):
        """FirehoseClient - Record Batches"""
        records = self._sample_raw_records()

        expected_result = [['{"key_0":"value_0"}\n', '{"key_1":"value_1"}\n']]

        result = list(
            FirehoseClient._record_batches(records, 'test_function_name'))
        assert_equal(result, expected_result)
Example #12
    def test_record_batches_max_batch_count(self):
        """FirehoseClient - Record Batches, Max Batch Count"""
        records = self._sample_raw_records(count=501)

        result = list(
            FirehoseClient._record_batches(records, 'test_function_name'))
        assert_equal(len(result), 2)
        assert_equal(len(result[0]), 500)
        assert_equal(len(result[1]), 1)
Example #13
    @patch('logging.Logger.error')
    def test_load_enabled_sources_invalid_log_subtype(self, log_mock):
        """FirehoseClient - Load Enabled Log Sources, Invalid Log Sub-type"""
        logs_config = {'log_type_01:sub_type_01': {}}
        log_type = 'log_type_01:sub_type_02'
        firehose_config = {'enabled_logs': [log_type]}

        enabled_logs = FirehoseClient.load_enabled_log_sources(
            firehose_config, logs_config)
        assert_equal(enabled_logs, dict())
        log_mock.assert_called_with(
            'Enabled Firehose log %s not declared in logs.json', log_type)
Example #14
    def test_strip_successful_records(self):
        """FirehoseClient - Strip Successful Records"""
        batch = [{'test': 'success'}, {'other': 'failure'}, {'other': 'info'}]
        response = {
            'FailedPutCount':
            1,
            'RequestResponses': [{
                'RecordId': 'rec_id_01'
            }, {
                'ErrorCode': 10,
                'ErrorMessage': 'foo'
            }, {
                'RecordId': 'rec_id_03'
            }]
        }

        expected_batch = [{'other': 'failure'}]
        FirehoseClient._strip_successful_records(batch, response)

        assert_equal(batch, expected_batch)
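
The test above implies that responses carrying a RecordId mark successfully delivered records, which are dropped from the batch in place so only failures remain. One way to do that, sketched under that assumption:

def strip_successful_records_sketch(batch, response):
    """Remove successfully delivered records from the batch, in place (sketch)."""
    success_indices = [
        idx for idx, rec in enumerate(response['RequestResponses'])
        if rec.get('RecordId')
    ]
    for idx in sorted(success_indices, reverse=True):
        del batch[idx]  # delete from the end so earlier indices stay valid

batch = [{'test': 'success'}, {'other': 'failure'}, {'other': 'info'}]
response = {
    'FailedPutCount': 1,
    'RequestResponses': [
        {'RecordId': 'rec_id_01'},
        {'ErrorCode': 10, 'ErrorMessage': 'foo'},
        {'RecordId': 'rec_id_03'},
    ]
}
strip_successful_records_sketch(batch, response)
print(batch)  # [{'other': 'failure'}]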
Example #15
    def __init__(self, artifacts_fh_stream_name):
        self._dst_firehose_stream_name = artifacts_fh_stream_name
        self._artifacts = list()

        ArtifactExtractor._config = ArtifactExtractor._config or config.load_config(
            validate=True)

        ArtifactExtractor._firehose_client = (
            ArtifactExtractor._firehose_client or FirehoseClient.get_client(
                prefix=self.config['global']['account']['prefix'],
                artifact_extractor_config=self.config['global'].get(
                    'infrastructure', {}).get('artifact_extractor', {})))
Example #16
    def test_record_batches_max_batch_size(self):
        """FirehoseClient - Record Batches, Max Batch Size"""
        records = [{'key_{}'.format(i): 'test' * 100000} for i in range(10)]
        result = list(
            FirehoseClient._record_batches(records, 'test_function_name'))
        assert_equal(len(result), 2)
        assert_equal(len(result[0]), 9)
        assert_equal(len(result[1]), 1)
        batch_size_01 = sum(len(rec) for rec in result[0])
        batch_size_02 = sum(len(rec) for rec in result[1])
        assert_equal(batch_size_01 < FirehoseClient.MAX_BATCH_SIZE, True)
        assert_equal(batch_size_02 < FirehoseClient.MAX_BATCH_SIZE, True)
        assert_equal(
            batch_size_01 + batch_size_02 > FirehoseClient.MAX_BATCH_SIZE,
            True)
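
These batching tests line up with the Firehose PutRecordBatch quotas of 500 records and roughly 4 MB per call. A simplified sketch of batching under those two limits; the real _record_batches also serializes records and drops (and logs) records that are individually too large:

MAX_BATCH_COUNT = 500              # PutRecordBatch record-count limit
MAX_BATCH_SIZE = 4 * 1000 * 1000   # approximate per-call size limit implied by the tests

def record_batches_sketch(json_records):
    """Yield lists of already-serialized records that stay under both batch limits (sketch)."""
    batch, batch_size = [], 0
    for record in json_records:
        size = len(record)
        if batch and (len(batch) == MAX_BATCH_COUNT or batch_size + size > MAX_BATCH_SIZE):
            yield batch
            batch, batch_size = [], 0
        batch.append(record)
        batch_size += size
    if batch:
        yield batch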
Example #17
    def __init__(self):
        # Create some objects to be cached if they have not already been created
        Classifier._config = Classifier._config or config.load_config(validate=True)
        Classifier._firehose_client = (
            Classifier._firehose_client or FirehoseClient.load_from_config(
                prefix=self.config['global']['account']['prefix'],
                firehose_config=self.config['global'].get('infrastructure', {}).get('firehose', {}),
                log_sources=self.config['logs']
            )
        )
        Classifier._sqs_client = Classifier._sqs_client or SQSClient()

        # Setup the normalization logic
        Normalizer.load_from_config(self.config)
        self._cluster = os.environ['CLUSTER']
        self._payloads = []
        self._failed_record_count = 0
        self._processed_size = 0
Example #18
def generate_artifact_extractor(config):
    """Generate Terraform for the Artifact Extractor Lambda function
    Args:
        config (dict): The loaded config from the 'conf/' directory
    Returns:
        dict: Artifact Extractor Terraform definition to be marshaled to JSON
    """
    result = infinitedict()

    if not artifact_extractor_enabled(config):
        return

    ae_config = config['global']['infrastructure']['artifact_extractor']
    stream_name = FirehoseClient.artifacts_firehose_stream_name(config)

    # Set variables for the artifact extractor module
    result['module']['artifact_extractor'] = {
        'source': './modules/tf_artifact_extractor',
        'account_id': config['global']['account']['aws_account_id'],
        'prefix': config['global']['account']['prefix'],
        'region': config['global']['account']['region'],
        'glue_catalog_db_name': get_database_name(config),
        'glue_catalog_table_name': ae_config.get('table_name', DEFAULT_ARTIFACTS_TABLE_NAME),
        's3_bucket_name': firehose_data_bucket(config),
        'stream_name': stream_name,
        'buffer_size': ae_config.get('firehose_buffer_size', 128),
        'buffer_interval': ae_config.get('firehose_buffer_interval', 900),
        'kms_key_arn': '${aws_kms_key.server_side_encryption.arn}',
        'schema': generate_artifacts_table_schema()
    }

    return result
Example #19
    @patch.object(FirehoseClient, '_send_batch')
    def test_send_no_prefixing(self, send_batch_mock):
        """FirehoseClient - Send, No Prefixing"""
        FirehoseClient._ENABLED_LOGS = {
            'log_type_01_sub_type_01': 'log_type_01:sub_type_01'
        }
        expected_batch = [
            '{"unit_key_01":1,"unit_key_02":"test"}\n',
            '{"unit_key_01":2,"unit_key_02":"test"}\n'
        ]

        client = FirehoseClient.load_from_config(prefix='unit-test',
                                                 firehose_config={
                                                     'enabled': True,
                                                     'use_prefix': False
                                                 },
                                                 log_sources=None)

        client.send(self._sample_payloads)
        send_batch_mock.assert_called_with(
            'streamalert_log_type_01_sub_type_01', expected_batch,
            'classifier')
Example #20
    @patch.object(FirehoseClient, '_send_batch')
    def test_send_long_log_name(self, send_batch_mock):
        """FirehoseClient - Send data when the log name is very long"""
        FirehoseClient._ENABLED_LOGS = {
            'very_very_very_long_log_stream_name_abcdefg_hijklmn_70_characters_long':
            {}
        }
        expected_batch = [
            '{"unit_key_01":1,"unit_key_02":"test"}\n',
            '{"unit_key_01":2,"unit_key_02":"test"}\n'
        ]

        client = FirehoseClient.load_from_config(prefix='unit-test',
                                                 firehose_config={
                                                     'enabled': True,
                                                     'use_prefix': False
                                                 },
                                                 log_sources=None)

        client.send(self._sample_payloads_long_log_name)
        send_batch_mock.assert_called_with(
            'streamalert_very_very_very_long_log_stream_name_abcdefg_7c88167b',
            expected_batch, 'classifier')
Example #21
    def test_load_enabled_sources(self):
        """FirehoseClient - Load Enabled Log Sources"""
        logs_config = {
            'log_type_01:sub_type_01': {},
            'log_type_01:sub_type_02': {},  # This log type is not enabled
            'log_type_02:sub_type_01': {},
            'log_type_02:sub_type_02': {},
        }
        firehose_config = {
            'enabled_logs': [
                'log_type_01:sub_type_01',  # One log for log_type_01
                'log_type_02'  # All of log_type_02
            ]
        }
        expected_result = {
            'log_type_01_sub_type_01': 'log_type_01:sub_type_01',
            'log_type_02_sub_type_01': 'log_type_02:sub_type_01',
            'log_type_02_sub_type_02': 'log_type_02:sub_type_02'
        }

        enabled_logs = FirehoseClient.load_enabled_log_sources(
            firehose_config, logs_config)
        assert_equal(enabled_logs, expected_result)
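
The expected_result above shows two behaviors: a bare parent type in enabled_logs enables every configured sub-type, and the colon in each name is sanitized to an underscore in the resulting key. A rough sketch of that expansion (the real classmethod also logs errors for undeclared log types and caches the result on the class):

import re

def load_enabled_log_sources_sketch(firehose_config, logs_config):
    """Rough sketch of the enabled-log expansion exercised by the test above."""
    enabled = {}
    for entry in firehose_config.get('enabled_logs', []):
        # A bare parent type such as 'log_type_02' enables every 'log_type_02:*' sub-type
        matches = [log for log in logs_config
                   if log == entry or log.split(':')[0] == entry]
        for log in matches:
            enabled[re.sub(r'[^a-zA-Z0-9_]', '_', log)] = log
    return enabled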
Example #22
    def setup(self):
        """Setup before each method"""
        with patch('boto3.client'):  # patch to speed up unit tests slightly
            self._client = FirehoseClient(prefix='unit-test',
                                          firehose_config={'use_prefix': True})
Example #23
def create_table(table, bucket, config, schema_override=None):
    """Create a 'streamalert' Athena table

    Args:
        table (str): The name of the table being created
        bucket (str): The s3 bucket to be used as the location for Athena data
        config (CLIConfig): Loaded StreamAlert config
        schema_override (set): An optional set of key=value pairs to be used for
            overriding the configured column_name=value_type.

    Returns:
        bool: False if errors occurred, True otherwise
    """
    enabled_logs = FirehoseClient.load_enabled_log_sources(
        config['global']['infrastructure']['firehose'], config['logs'])

    # Convert special characters in schema name to underscores
    sanitized_table_name = FirehoseClient.sanitized_value(table)

    # Check that the log type is enabled via Firehose
    if sanitized_table_name != 'alerts' and sanitized_table_name not in enabled_logs:
        LOGGER.error(
            'Table name %s missing from configuration or '
            'is not enabled.', sanitized_table_name)
        return False

    athena_client = get_athena_client(config)

    # Check if the table exists
    if athena_client.check_table_exists(sanitized_table_name):
        LOGGER.info('The \'%s\' table already exists.', sanitized_table_name)
        return True

    if table == 'alerts':
        # get a fake alert so we can get the keys needed and their types
        alert = Alert('temp_rule_name', {}, {})
        output = alert.output_dict()
        schema = record_to_schema(output)
        athena_schema = helpers.logs_schema_to_athena_schema(schema)

        # Use the bucket if supplied, otherwise use the default alerts bucket
        bucket = bucket or firehose_alerts_bucket(config)

        query = _construct_create_table_statement(
            schema=athena_schema,
            table_name=table,
            bucket=bucket,
            file_format=get_data_file_format(config))

    else:  # all other tables are log types

        config_data_bucket = firehose_data_bucket(config)
        if not config_data_bucket:
            LOGGER.warning(
                'The \'firehose\' module is not enabled in global.json')
            return False

        # Use the bucket if supplied, otherwise use the default data bucket
        bucket = bucket or config_data_bucket

        log_info = config['logs'][enabled_logs.get(sanitized_table_name)]

        schema = dict(log_info['schema'])
        sanitized_schema = FirehoseClient.sanitize_keys(schema)

        athena_schema = helpers.logs_schema_to_athena_schema(sanitized_schema)

        # Add envelope keys to Athena Schema
        configuration_options = log_info.get('configuration')
        if configuration_options:
            envelope_keys = configuration_options.get('envelope_keys')
            if envelope_keys:
                sanitized_envelope_key_schema = FirehoseClient.sanitize_keys(
                    envelope_keys)
                # Note: this key is wrapped in backticks to be Hive compliant
                athena_schema[
                    '`streamalert:envelope_keys`'] = helpers.logs_schema_to_athena_schema(
                        sanitized_envelope_key_schema)

        # Handle Schema overrides
        #   This is useful when an Athena schema needs to differ from the normal log schema
        if schema_override:
            for override in schema_override:
                column_name, column_type = override.split('=')
                # Columns are escaped to avoid Hive issues with special characters
                column_name = '`{}`'.format(column_name)
                if column_name in athena_schema:
                    athena_schema[column_name] = column_type
                    LOGGER.info('Applied schema override: %s:%s', column_name,
                                column_type)
                else:
                    LOGGER.error(
                        'Schema override column %s not found in Athena Schema, skipping',
                        column_name)

        query = _construct_create_table_statement(
            schema=athena_schema,
            table_name=sanitized_table_name,
            bucket=bucket,
            file_format=get_data_file_format(config))

    success = athena_client.run_query(query=query)
    if not success:
        LOGGER.error('The %s table could not be created', sanitized_table_name)
        return False

    # Update the CLI config
    if table != 'alerts' and bucket != config_data_bucket:
        # Only add buckets to the config if they are not one of the default/configured buckets
        # Ensure 'buckets' exists in the config (since it is not required)
        config['lambda']['athena_partitioner_config']['buckets'] = (
            config['lambda']['athena_partitioner_config'].get('buckets', {}))
        if bucket not in config['lambda']['athena_partitioner_config'][
                'buckets']:
            config['lambda']['athena_partitioner_config']['buckets'][
                bucket] = 'data'
            config.write()

    LOGGER.info('The %s table was successfully created!', sanitized_table_name)

    return True
Example #24
    def test_load_from_config_disabled(self):
        """FirehoseClient - Load From Config, Disabled"""
        client = FirehoseClient.load_from_config(prefix='unit-test',
                                                 firehose_config={},
                                                 log_sources=None)
        assert_equal(client, None)
Example #25
class TestFirehoseClient:
    """Test class for FirehoseClient"""

    # pylint: disable=protected-access,no-self-use,attribute-defined-outside-init

    def setup(self):
        """Setup before each method"""
        with patch('boto3.client'):  # patch to speed up unit tests slightly
            self._client = FirehoseClient(prefix='unit-test',
                                          firehose_config={'use_prefix': True})

    def teardown(self):
        """Teardown after each method"""
        FirehoseClient._ENABLED_LOGS.clear()

    @property
    def _sample_payloads(self):
        return [
            Mock(log_schema_type='log_type_01_sub_type_01',
                 parsed_records=[{
                     'unit_key_01': 1,
                     'unit_key_02': 'test'
                 }, {
                     'unit_key_01': 2,
                     'unit_key_02': 'test'
                 }]),
            Mock(log_schema_type='log_type_02_sub_type_01',
                 parsed_records=[{
                     'date': 'January 01, 3005',
                     'unixtime': '32661446400',
                     'host': 'my-host.name.website.com',
                     'data': {
                         'super': 'secret'
                     }
                 }])
        ]

    @classmethod
    def _sample_raw_records(cls, count=2):
        return [{
            'key_{}'.format(i): 'value_{}'.format(i)
        } for i in range(count)]

    def test_records_to_json_list(self):
        """FirehoseClient - Records JSON Lines"""
        records = self._sample_raw_records()

        expected_result = ['{"key_0":"value_0"}\n', '{"key_1":"value_1"}\n']

        result = FirehoseClient._records_to_json_list(records)
        assert_equal(result, expected_result)

    def test_record_batches(self):
        """FirehoseClient - Record Batches"""
        records = self._sample_raw_records()

        expected_result = [['{"key_0":"value_0"}\n', '{"key_1":"value_1"}\n']]

        result = list(
            FirehoseClient._record_batches(records, 'test_function_name'))
        assert_equal(result, expected_result)

    @patch.object(FirehoseClient, '_log_failed')
    def test_record_batches_rec_too_large(self, failure_mock):
        """FirehoseClient - Record Batches, Record Too Large"""
        records = [{'key': 'test' * 1000 * 1000}]

        result = list(
            FirehoseClient._record_batches(records, 'test_function_name'))
        assert_equal(result, [])
        failure_mock.assert_called_with(1, 'test_function_name')

    def test_record_batches_max_batch_count(self):
        """FirehoseClient - Record Batches, Max Batch Count"""
        records = self._sample_raw_records(count=501)

        result = list(
            FirehoseClient._record_batches(records, 'test_function_name'))
        assert_equal(len(result), 2)
        assert_equal(len(result[0]), 500)
        assert_equal(len(result[1]), 1)

    def test_record_batches_max_batch_size(self):
        """FirehoseClient - Record Batches, Max Batch Size"""
        records = [{'key_{}'.format(i): 'test' * 100000} for i in range(10)]
        result = list(
            FirehoseClient._record_batches(records, 'test_function_name'))
        assert_equal(len(result), 2)
        assert_equal(len(result[0]), 9)
        assert_equal(len(result[1]), 1)
        batch_size_01 = sum(len(rec) for rec in result[0])
        batch_size_02 = sum(len(rec) for rec in result[1])
        assert_equal(batch_size_01 < FirehoseClient.MAX_BATCH_SIZE, True)
        assert_equal(batch_size_02 < FirehoseClient.MAX_BATCH_SIZE, True)
        assert_equal(
            batch_size_01 + batch_size_02 > FirehoseClient.MAX_BATCH_SIZE,
            True)

    def test_sanitize_keys(self):
        """FirehoseClient - Sanitize Keys"""
        test_event = {
            'date': 'January 01, 3005',
            'data': {
                'super-duper': 'secret',
                'do_not_sanitize_me': 1,
                'example-key': 2,
                'moar**data': 3,
                'even.more': 4
            }
        }

        expected_sanitized_event = {
            'date': 'January 01, 3005',
            'data': {
                'super_duper': 'secret',
                'do_not_sanitize_me': 1,
                'example_key': 2,
                'moar__data': 3,
                'even_more': 4
            }
        }

        sanitized_event = FirehoseClient.sanitize_keys(test_event)
        assert_equal(sanitized_event, expected_sanitized_event)

    def test_strip_successful_records(self):
        """FirehoseClient - Strip Successful Records"""
        batch = [{'test': 'success'}, {'other': 'failure'}, {'other': 'info'}]
        response = {
            'FailedPutCount':
            1,
            'RequestResponses': [{
                'RecordId': 'rec_id_01'
            }, {
                'ErrorCode': 10,
                'ErrorMessage': 'foo'
            }, {
                'RecordId': 'rec_id_03'
            }]
        }

        expected_batch = [{'other': 'failure'}]
        FirehoseClient._strip_successful_records(batch, response)

        assert_equal(batch, expected_batch)

    def test_categorize_records(self):
        """FirehoseClient - Categorize Records"""
        FirehoseClient._ENABLED_LOGS = {
            'log_type_01_sub_type_01': 'log_type_01:sub_type_01',
            'log_type_02_sub_type_01': 'log_type_02:sub_type_01'
        }

        payloads = self._sample_payloads

        result = self._client._categorize_records(payloads)
        expected_result = {
            'log_type_01_sub_type_01': payloads[0].parsed_records,
            'log_type_02_sub_type_01': payloads[1].parsed_records
        }
        assert_equal(dict(result), expected_result)

    def test_categorize_records_none_enabled(self):
        """FirehoseClient - Categorize Records, None Enabled"""
        payloads = self._sample_payloads
        result = self._client._categorize_records(payloads)

        assert_equal(dict(result), dict())

    def test_categorize_records_subset_enabled(self):
        """FirehoseClient - Categorize Records, Subset Enabled"""
        FirehoseClient._ENABLED_LOGS = {
            'log_type_01_sub_type_01': 'log_type_01:sub_type_01'
        }

        payloads = self._sample_payloads

        result = self._client._categorize_records(payloads)
        expected_result = {
            'log_type_01_sub_type_01': payloads[0].parsed_records
        }
        assert_equal(dict(result), expected_result)

    @patch.object(FirehoseClient, '_log_failed')
    def test_finalize_failures(self, failure_mock):
        """FirehoseClient - Finalize, With Failures"""
        response = {
            'FailedPutCount':
            1,
            'RequestResponses': [{
                'RecordId': 'rec_id_01'
            }, {
                'ErrorCode': 10,
                'ErrorMessage': 'foo'
            }, {
                'RecordId': 'rec_id_03'
            }]
        }

        FirehoseClient._finalize(response, 'stream_name', 3,
                                 'test_function_name')
        failure_mock.assert_called_with(1, 'test_function_name')

    @patch('logging.Logger.info')
    def test_finalize_success(self, log_mock):
        """FirehoseClient - Finalize, Success"""
        request_id = 'success_id'
        stream_name = 'stream_name'
        count = 3
        response = {'ResponseMetadata': {'RequestId': request_id}}

        FirehoseClient._finalize(response, stream_name, count,
                                 'test_function_name')
        log_mock.assert_called_with(
            'Successfully sent %d message(s) to firehose %s with RequestId \'%s\'',
            count, stream_name, request_id)

    def test_send_batch(self):
        """FirehoseClient - Send Batch"""
        records = [
            '{"unit_key_02":"test","unit_key_01":1}\n',
            '{"unit_key_02":"test","unit_key_01":2}\n'
        ]

        stream_name = 'test_stream_name'
        expected_second_call = [{'Data': records[1]}]
        with patch.object(self._client, '_client') as boto_mock:
            boto_mock.put_record_batch.side_effect = [{
                'FailedPutCount':
                1,
                'RequestResponses': [{
                    'RecordId': 'rec_id_01'
                }, {
                    'ErrorCode': 10,
                    'ErrorMessage': 'foo'
                }]
            }, {
                'FailedPutCount':
                0,
                'RequestResponses': [
                    {
                        'RecordId': 'rec_id_02'
                    },
                ]
            }]

            self._client._send_batch(stream_name, records,
                                     'test_function_name')

            boto_mock.put_record_batch.assert_called_with(
                DeliveryStreamName=stream_name, Records=expected_second_call)

    @patch('logging.Logger.exception')
    @patch.object(FirehoseClient, 'MAX_BACKOFF_ATTEMPTS', 1)
    def test_send_batch_error(self, log_mock):
        """FirehoseClient - Send Batch, Error"""
        stream_name = 'test_stream_name'
        with patch.object(self._client, '_client') as boto_mock:
            error = ClientError({'Error': {
                'Code': 10
            }}, 'InvalidRequestException')
            boto_mock.put_record_batch.side_effect = error

            self._client._send_batch(stream_name, ['data'],
                                     'test_function_name')

            log_mock.assert_called_with('Firehose request failed')

    def test_sanitized_value(self):
        """FirehoseClient - Sanitized Value"""
        expected_result = 'test_log_type_name'
        result = FirehoseClient.sanitized_value('test*log.type-name')
        assert_equal(result, expected_result)

    def test_enabled_log_source(self):
        """FirehoseClient - Enabled Log Source"""
        log = 'enabled_log'
        FirehoseClient._ENABLED_LOGS = {log: 'enabled:log'}
        assert_equal(FirehoseClient.enabled_log_source(log), True)

    def test_enabled_log_source_false(self):
        """FirehoseClient - Enabled Log Source, False"""
        log = 'enabled_log'
        assert_equal(FirehoseClient.enabled_log_source(log), False)

    def test_load_enabled_sources(self):
        """FirehoseClient - Load Enabled Log Sources"""
        logs_config = {
            'log_type_01:sub_type_01': {},
            'log_type_01:sub_type_02': {},  # This log type is not enabled
            'log_type_02:sub_type_01': {},
            'log_type_02:sub_type_02': {},
        }
        firehose_config = {
            'enabled_logs': [
                'log_type_01:sub_type_01',  # One log for log_type_01
                'log_type_02'  # All of log_type_02
            ]
        }
        expected_result = {
            'log_type_01_sub_type_01': 'log_type_01:sub_type_01',
            'log_type_02_sub_type_01': 'log_type_02:sub_type_01',
            'log_type_02_sub_type_02': 'log_type_02:sub_type_02'
        }

        enabled_logs = FirehoseClient.load_enabled_log_sources(
            firehose_config, logs_config)
        assert_equal(enabled_logs, expected_result)

    @patch('logging.Logger.error')
    def test_load_enabled_sources_invalid_log(self, log_mock):
        """FirehoseClient - Load Enabled Log Sources, Invalid Log Type"""
        logs_config = {
            'log_type_01:sub_type_01': {},
            'log_type_01:sub_type_02': {}
        }
        log_type = 'log_type_03'
        firehose_config = {'enabled_logs': [log_type]}

        enabled_logs = FirehoseClient.load_enabled_log_sources(
            firehose_config, logs_config)
        assert_equal(enabled_logs, dict())
        log_mock.assert_called_with(
            'Enabled Firehose log %s not declared in logs.json', log_type)

    @patch('logging.Logger.error')
    def test_load_enabled_sources_invalid_log_subtype(self, log_mock):
        """FirehoseClient - Load Enabled Log Sources, Invalid Log Sub-type"""
        logs_config = {'log_type_01:sub_type_01': {}}
        log_type = 'log_type_01:sub_type_02'
        firehose_config = {'enabled_logs': [log_type]}

        enabled_logs = FirehoseClient.load_enabled_log_sources(
            firehose_config, logs_config)
        assert_equal(enabled_logs, dict())
        log_mock.assert_called_with(
            'Enabled Firehose log %s not declared in logs.json', log_type)

    def test_load_from_config(self):
        """FirehoseClient - Load From Config"""
        with patch('boto3.client'):  # patch to speed up unit tests slightly
            client = FirehoseClient.load_from_config(
                prefix='unit-test',
                firehose_config={'enabled': True},
                log_sources=None)
            assert_equal(isinstance(client, FirehoseClient), True)

    def test_load_from_config_disabled(self):
        """FirehoseClient - Load From Config, Disabled"""
        client = FirehoseClient.load_from_config(prefix='unit-test',
                                                 firehose_config={},
                                                 log_sources=None)
        assert_equal(client, None)

    @patch.object(FirehoseClient, '_send_batch')
    def test_send(self, send_batch_mock):
        """FirehoseClient - Send"""
        FirehoseClient._ENABLED_LOGS = {
            'log_type_01_sub_type_01': 'log_type_01:sub_type_01'
        }
        expected_batch = [
            '{"unit_key_01":1,"unit_key_02":"test"}\n',
            '{"unit_key_01":2,"unit_key_02":"test"}\n'
        ]
        self._client.send(self._sample_payloads)
        send_batch_mock.assert_called_with(
            'unit_test_streamalert_log_type_01_sub_type_01', expected_batch,
            'classifier')

    @patch.object(FirehoseClient, '_send_batch')
    def test_send_no_prefixing(self, send_batch_mock):
        """FirehoseClient - Send, No Prefixing"""
        FirehoseClient._ENABLED_LOGS = {
            'log_type_01_sub_type_01': 'log_type_01:sub_type_01'
        }
        expected_batch = [
            '{"unit_key_01":1,"unit_key_02":"test"}\n',
            '{"unit_key_01":2,"unit_key_02":"test"}\n'
        ]

        client = FirehoseClient.load_from_config(prefix='unit-test',
                                                 firehose_config={
                                                     'enabled': True,
                                                     'use_prefix': False
                                                 },
                                                 log_sources=None)

        client.send(self._sample_payloads)
        send_batch_mock.assert_called_with(
            'streamalert_log_type_01_sub_type_01', expected_batch,
            'classifier')

    @property
    def _sample_payloads_long_log_name(self):
        return [
            Mock(log_schema_type=(
                'very_very_very_long_log_stream_name_abcdefg_hijklmn_70_characters_long'
            ),
                 parsed_records=[{
                     'unit_key_01': 1,
                     'unit_key_02': 'test'
                 }, {
                     'unit_key_01': 2,
                     'unit_key_02': 'test'
                 }])
        ]

    @patch.object(FirehoseClient, '_send_batch')
    def test_send_long_log_name(self, send_batch_mock):
        """FirehoseClient - Send data when the log name is very long"""
        FirehoseClient._ENABLED_LOGS = {
            'very_very_very_long_log_stream_name_abcdefg_hijklmn_70_characters_long':
            {}
        }
        expected_batch = [
            '{"unit_key_01":1,"unit_key_02":"test"}\n',
            '{"unit_key_01":2,"unit_key_02":"test"}\n'
        ]

        client = FirehoseClient.load_from_config(prefix='unit-test',
                                                 firehose_config={
                                                     'enabled': True,
                                                     'use_prefix': False
                                                 },
                                                 log_sources=None)

        client.send(self._sample_payloads_long_log_name)
        send_batch_mock.assert_called_with(
            'streamalert_very_very_very_long_log_stream_name_abcdefg_7c88167b',
            expected_batch, 'classifier')

    def test_generate_firehose_name(self):
        """FirehoseClient - Test helper to generate firehose stream name when prefix disabled"""
        log_names = [
            'logstreamname', 'log_stream_name',
            'very_very_long_log_stream_name_ab_52_characters_long',
            'very_very_very_long_log_stream_name_abcdefg_abcdefg_70_characters_long'
        ]

        expected_results = [
            'streamalert_logstreamname', 'streamalert_log_stream_name',
            'streamalert_very_very_long_log_stream_name_ab_52_characters_long',
            'streamalert_very_very_very_long_log_stream_name_abcdefg_272fa762'
        ]
        results = [
            self._client.generate_firehose_name('', log_name)
            for log_name in log_names
        ]

        assert_equal(expected_results, results)

    def test_generate_firehose_name_prefix(self):
        """FirehoseClient - Test helper to generate firehose stream name with prefix"""
        log_names = [
            'logstreamname', 'log_stream_name',
            'very_very_long_log_stream_name_ab_52_characters_long',
            'very_very_very_long_log_stream_name_abcdefg_abcdefg_70_characters_long'
        ]

        expected_results = [
            'prefix_streamalert_logstreamname',
            'prefix_streamalert_log_stream_name',
            'prefix_streamalert_very_very_long_log_stream_name_ab_52_63bd84dc',
            'prefix_streamalert_very_very_very_long_log_stream_name_a0c91e099'
        ]
        results = [
            self._client.generate_firehose_name('prefix', log_name)
            for log_name in log_names
        ]

        assert_equal(expected_results, results)

    def test_artifacts_firehose_stream_name(self):
        """FirehoseClient - Test generate artifacts firehose stream name"""
        config_data = {
            'global': {
                'account': {
                    'prefix': 'unittest'
                }
            },
            'lambda': {
                'artifact_extractor_config': {}
            }
        }

        assert_equal(self._client.artifacts_firehose_stream_name(config_data),
                     'unittest_streamalert_artifacts')

        config_data['lambda']['artifact_extractor_config'][
            'firehose_stream_name'] = ('test_artifacts_fh_name')

        assert_equal(self._client.artifacts_firehose_stream_name(config_data),
                     'test_artifacts_fh_name')
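
The expected names in the generate_firehose_name tests above respect Firehose's 64-character stream-name limit: long names are truncated and then completed with what looks like the first eight characters of a hash of the full name. A plausible sketch of that scheme; the exact hash function used by FirehoseClient is an assumption here:

import hashlib

FIREHOSE_NAME_MAX_LEN = 64  # AWS Firehose delivery stream name limit

def generate_firehose_name_sketch(prefix, log_stream_name):
    """Sketch only: truncate long names and append a short hash to keep them unique."""
    base = '_'.join(filter(None, [prefix, 'streamalert', log_stream_name]))
    if len(base) <= FIREHOSE_NAME_MAX_LEN:
        return base
    suffix = hashlib.md5(base.encode()).hexdigest()[:8]  # hash choice is an assumption
    return base[:FIREHOSE_NAME_MAX_LEN - 8] + suffix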
Example #26
    def test_enabled_log_source_false(self):
        """FirehoseClient - Enabled Log Source, False"""
        log = 'enabled_log'
        assert_equal(FirehoseClient.enabled_log_source(log), False)
Example #27
    def test_enabled_log_source(self):
        """FirehoseClient - Enabled Log Source"""
        log = 'enabled_log'
        FirehoseClient._ENABLED_LOGS = {log: 'enabled:log'}
        assert_equal(FirehoseClient.enabled_log_source(log), True)
Example #28
    def test_sanitized_value(self):
        """FirehoseClient - Sanitized Value"""
        expected_result = 'test_log_type_name'
        result = FirehoseClient.sanitized_value('test*log.type-name')
        assert_equal(result, expected_result)
Example #29
def generate_firehose(logging_bucket, main_dict, config):
    """Generate the Firehose Terraform modules

    Args:
        config (CLIConfig): The loaded StreamAlert Config
        main_dict (infinitedict): The Dict to marshal to a file
        logging_bucket (str): The name of the global logging bucket
    """
    if not config['global']['infrastructure'].get('firehose', {}).get('enabled'):
        return

    prefix = config['global']['account']['prefix']

    # This can return False, but the check above ensures that should never happen
    firehose_s3_bucket_name = firehose_data_bucket(config)

    firehose_conf = config['global']['infrastructure']['firehose']

    # Firehose Setup module
    main_dict['module']['kinesis_firehose_setup'] = {
        'source': './modules/tf_kinesis_firehose_setup',
        'account_id': config['global']['account']['aws_account_id'],
        'prefix': prefix,
        'region': config['global']['account']['region'],
        's3_logging_bucket': logging_bucket,
        's3_bucket_name': firehose_s3_bucket_name,
        'kms_key_id': '${aws_kms_key.server_side_encryption.key_id}'
    }

    enabled_logs = FirehoseClient.load_enabled_log_sources(
        firehose_conf,
        config['logs'],
        force_load=True
    )

    log_alarms_config = firehose_conf.get('enabled_logs', {})

    db_name = get_database_name(config)

    firehose_prefix = prefix if firehose_conf.get('use_prefix', True) else ''

    # Add the Delivery Streams individually
    for log_stream_name, log_type_name in enabled_logs.items():
        module_dict = {
            'source': './modules/tf_kinesis_firehose_delivery_stream',
            'buffer_size': (
                firehose_conf.get('buffer_size')
            ),
            'buffer_interval': (
                firehose_conf.get('buffer_interval', 300)
            ),
            'file_format': get_data_file_format(config),
            'stream_name': FirehoseClient.generate_firehose_name(firehose_prefix, log_stream_name),
            'role_arn': '${module.kinesis_firehose_setup.firehose_role_arn}',
            's3_bucket_name': firehose_s3_bucket_name,
            'kms_key_arn': '${aws_kms_key.server_side_encryption.arn}',
            'glue_catalog_db_name': db_name,
            'glue_catalog_table_name': log_stream_name,
            'schema': generate_data_table_schema(config, log_type_name)
        }

        # Try to get alarm info for this specific log type
        alarm_info = log_alarms_config.get(log_type_name)
        if not alarm_info and ':' in log_type_name:
            # Fallback on looking for alarm info for the parent log type
            alarm_info = log_alarms_config.get(log_type_name.split(':')[0])

        if alarm_info and alarm_info.get('enable_alarm'):
            module_dict['enable_alarm'] = True

            # There are defaults of these defined in the terraform module, so do
            # not set the variable values unless explicitly specified
            if alarm_info.get('log_min_count_threshold'):
                module_dict['alarm_threshold'] = alarm_info.get('log_min_count_threshold')

            if alarm_info.get('evaluation_periods'):
                module_dict['evaluation_periods'] = alarm_info.get('evaluation_periods')

            if alarm_info.get('period_seconds'):
                module_dict['period_seconds'] = alarm_info.get('period_seconds')

            if alarm_info.get('alarm_actions'):
                if not isinstance(alarm_info.get('alarm_actions'), list):
                    module_dict['alarm_actions'] = [alarm_info.get('alarm_actions')]
                else:
                    module_dict['alarm_actions'] = alarm_info.get('alarm_actions')
            else:
                module_dict['alarm_actions'] = [monitoring_topic_arn(config)]

        main_dict['module']['kinesis_firehose_{}'.format(log_stream_name)] = module_dict
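
Here enabled_logs is read back as a mapping so per-log alarm settings can be attached to each delivery stream module. A hypothetical firehose configuration illustrating the shape this alarm handling expects (log names and values below are made up):

# Hypothetical 'firehose' settings, shown as the dict this code would read
firehose_conf = {
    'enabled': True,
    'use_prefix': True,
    'buffer_interval': 300,
    'enabled_logs': {
        'cloudwatch:events': {},  # enabled, no alarm
        'osquery': {              # parent type: applies to all osquery sub-types
            'enable_alarm': True,
            'log_min_count_threshold': 100,
            'evaluation_periods': 2,
            'period_seconds': 3600,
            # a bare string is wrapped into a list by the alarm_actions handling above
            'alarm_actions': 'arn:aws:sns:us-east-1:123456789012:example-topic'
        }
    }
}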