Example #1
class AthenaRefresher(object):
    """Handle polling an SQS queue and running Athena queries for updating tables"""

    STREAMALERTS_REGEX = re.compile(r'alerts/dt=(?P<year>\d{4})'
                                    r'\-(?P<month>\d{2})'
                                    r'\-(?P<day>\d{2})'
                                    r'\-(?P<hour>\d{2})'
                                    r'\/.*.json')
    FIREHOSE_REGEX = re.compile(r'(?P<year>\d{4})'
                                r'\/(?P<month>\d{2})'
                                r'\/(?P<day>\d{2})'
                                r'\/(?P<hour>\d{2})\/.*')
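    # Illustrative object keys these patterns are intended to match
    # (hypothetical file names, assuming the layouts above):
    #   STREAMALERTS_REGEX: 'alerts/dt=2018-08-01-01/abcd-1234.json'
    #   FIREHOSE_REGEX:     'log_name/2018/08/01/01/run-1-part-0000.gz'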

    STREAMALERT_DATABASE = '{}_streamalert'
    ATHENA_S3_PREFIX = 'athena_partition_refresh'

    def __init__(self):
        config = load_config(include={'lambda.json', 'global.json'})
        prefix = config['global']['account']['prefix']
        athena_config = config['lambda']['athena_partition_refresh_config']

        self._athena_buckets = athena_config['buckets']

        db_name = athena_config.get('database_name',
                                    self.STREAMALERT_DATABASE.format(prefix))

        # Get the S3 bucket to store Athena query results
        results_bucket = athena_config.get(
            'results_bucket',
            's3://{}.streamalert.athena-results'.format(prefix))

        self._athena_client = AthenaClient(db_name, results_bucket,
                                           self.ATHENA_S3_PREFIX)

        self._s3_buckets_and_keys = defaultdict(set)

    def _get_partitions_from_keys(self):
        """Get the partitions that need to be added for the Athena tables

        Returns:
            (dict): representation of tables, partitions and locations to be added
                Example:
                    {
                        'alerts': {
                            '(dt = \'2018-08-01-01\')': 's3://streamalert.alerts/2018/08/01/01'
                        }
                    }
        """
        partitions = defaultdict(dict)

        LOGGER.info('Processing new Hive partitions...')
        for bucket, keys in self._s3_buckets_and_keys.iteritems():
            athena_table = self._athena_buckets.get(bucket)
            if not athena_table:
                # TODO(jacknagz): Add this as a metric
                LOGGER.error(
                    '\'%s\' not found in \'buckets\' config. Please add this '
                    'bucket to enable additions of Hive partitions.', bucket)
                continue

            # Iterate over each key
            for key in keys:
                match = None
                for pattern in (self.FIREHOSE_REGEX, self.STREAMALERTS_REGEX):
                    match = pattern.search(key)
                    if match:
                        break

                if not match:
                    LOGGER.error(
                        'The key %s does not match any regex, skipping', key)
                    continue

                # Get the path to the objects in S3
                path = posixpath.dirname(key)
                # The config does not need to store all possible tables
                # for enabled log types because this can be inferred from
                # the incoming S3 bucket notification.  Only enabled
                # log types will be sending data to Firehose.
                # This logic extracts out the name of the table from the
                # first element in the S3 path, as that's how log types
                # are configured to send to Firehose.
                if athena_table != 'alerts':
                    athena_table = path.split('/')[0]
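                # e.g. a hypothetical key 'cloudtrail/2018/08/01/01/file.gz'
                # gives path 'cloudtrail/2018/08/01/01', so the table name
                # becomes 'cloudtrail'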

                # Example:
                # PARTITION (dt = '2017-01-01-01') LOCATION 's3://bucket/path/'
                partition = '(dt = \'{year}-{month}-{day}-{hour}\')'.format(
                    **match.groupdict())
                location = '\'s3://{bucket}/{path}\''.format(bucket=bucket,
                                                             path=path)
                # By using the partition as the dict key, this ensures that
                # Athena will not try to add the same partition twice.
                # TODO(jacknagz): Write this dictionary to SSM/DynamoDb
                # to increase idempotence of this Lambda function
                partitions[athena_table][partition] = location

        return partitions

    def _add_partitions(self):
        """Execute a Hive Add Partition command for the given Athena tables and partitions

        Returns:
            (bool): True if the partitions were added successfully, False otherwise
        """
        partitions = self._get_partitions_from_keys()
        if not partitions:
            LOGGER.error('No partitions to add')
            return False

        for athena_table in partitions:
            partition_statement = ' '.join([
                'PARTITION {0} LOCATION {1}'.format(partition, location) for
                partition, location in partitions[athena_table].iteritems()
            ])
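            # The resulting query looks like the following (illustrative table,
            # bucket and date values):
            #   ALTER TABLE cloudtrail ADD IF NOT EXISTS
            #   PARTITION (dt = '2018-08-01-01') LOCATION 's3://bucket/cloudtrail/2018/08/01/01';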
            query = ('ALTER TABLE {athena_table} '
                     'ADD IF NOT EXISTS {partition_statement};'.format(
                         athena_table=athena_table,
                         partition_statement=partition_statement))

            success = self._athena_client.run_query(query=query)
            if not success:
                raise AthenaRefreshError(
                    'The add hive partition query has failed:\n{}'.format(
                        query))

            LOGGER.info('Successfully added the following partitions:\n%s',
                        json.dumps({athena_table: partitions[athena_table]}))
        return True

    def run(self, event):
        """Take the messages from the SQS queue and create partitions for new data in S3

        Args:
            event (dict): Lambda input event containing SQS messages. Each SQS message
                should contain one or more S3 bucket notification messages.
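
        Example event (abridged to just the fields read below; values are
        hypothetical):
            {
                'Records': [{
                    'messageId': '...',
                    'attributes': {'SentTimestamp': '1534380000000'},
                    'body': '<JSON string of an S3 bucket notification>'
                }]
            }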
        """
        # Check that the database being used exists before running queries
        if not self._athena_client.check_database_exists():
            raise AthenaRefreshError(
                'The \'{}\' database does not exist'.format(
                    self._athena_client.database))

        for sqs_rec in event['Records']:
            LOGGER.debug(
                'Processing event with message ID \'%s\' and SentTimestamp %s',
                sqs_rec['messageId'], sqs_rec['attributes']['SentTimestamp'])

            body = json.loads(sqs_rec['body'])
            if body.get('Event') == 's3:TestEvent':
                LOGGER.debug('Skipping S3 bucket notification test event')
                continue

            for s3_rec in body['Records']:
                if 's3' not in s3_rec:
                    LOGGER.info(
                        'Skipping non-s3 bucket notification message: %s',
                        s3_rec)
                    continue

                bucket_name = s3_rec['s3']['bucket']['name']

                # Account for special characters in the S3 object key
                # Example: Usage of '=' in the key name
                object_key = urllib.unquote_plus(
                    s3_rec['s3']['object']['key']).decode('utf8')
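                # e.g. a hypothetical key 'alerts/dt%3D2018-08-01-01/file.json'
                # decodes to 'alerts/dt=2018-08-01-01/file.json'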

                LOGGER.debug(
                    'Received notification for object \'%s\' in bucket \'%s\'',
                    object_key, bucket_name)

                self._s3_buckets_and_keys[bucket_name].add(object_key)

        if not self._add_partitions():
            raise AthenaRefreshError('Failed to add partitions: {}'.format(
                dict(self._s3_buckets_and_keys)))
Example #2
class TestAthenaClient(object):
    """Test class for AthenaClient"""
    @patch.dict(os.environ, {'AWS_DEFAULT_REGION': 'us-west-1'})
    @patch('boto3.client', Mock(side_effect=lambda c: MockAthenaClient()))
    def setup(self):
        """Setup the AthenaClient tests"""

        self._db_name = 'test_database'
        config = load_config('tests/unit/conf/')
        prefix = config['global']['account']['prefix']

        self.client = AthenaClient(
            self._db_name, 's3://{}.streamalert.athena-results'.format(prefix),
            'unit-testing')

    @patch('stream_alert.shared.athena.datetime')
    def test_init_fix_bucket_path(self, date_mock):
        """Athena - Fix Bucket Path"""
        date_now = datetime.utcnow()
        date_mock.utcnow.return_value = date_now
        date_format = date_now.strftime('%Y/%m/%d')
        expected_path = 's3://test.streamalert.athena-results/unit-testing/{}'.format(
            date_format)
        with patch.dict(os.environ, {'AWS_DEFAULT_REGION': 'us-west-1'}):
            client = AthenaClient(self._db_name,
                                  'test.streamalert.athena-results',
                                  'unit-testing')
            assert_equal(client._s3_results_path, expected_path)

    def test_unique_values_from_query(self):
        """Athena - Unique Values from Query"""
        query = {
            'ResultSet': {
                'Rows': [
                    {
                        'Data': [{
                            'VarCharValue': 'foobar'
                        }]
                    },
                    {
                        'Data': [{
                            'VarCharValue': 'barfoo'
                        }]
                    },
                    {
                        'Data': [{
                            'VarCharValue': 'barfoo'
                        }]
                    },
                    {
                        'Data': [{
                            'VarCharValue': 'foobarbaz'
                        }]
                    },
                ]
            }
        }
        expected_result = {'foobar', 'barfoo', 'foobarbaz'}

        result = self.client._unique_values_from_query(query)
        assert_items_equal(result, expected_result)

    def test_check_database_exists(self):
        """Athena - Check Database Exists"""
        self.client._client.results = [{
            'Data': [{
                'VarCharValue': self._db_name
            }]
        }]

        assert_true(self.client.check_database_exists())

    def test_check_database_exists_invalid(self):
        """Athena - Check Database Exists - Does Not Exist"""
        self.client._client.results = None

        assert_false(self.client.check_database_exists())

    def test_check_table_exists(self):
        """Athena - Check Table Exists"""
        self.client._client.results = [{
            'Data': [{
                'VarCharValue': 'test_table'
            }]
        }]

        assert_true(self.client.check_table_exists('test_table'))

    def test_check_table_exists_invalid(self):
        """Athena - Check Table Exists - Does Not Exist"""
        self.client._client.results = None

        assert_false(self.client.check_table_exists('test_table'))

    def test_get_table_partitions(self):
        """Athena - Get Table Partitions"""
        self.client._client.results = [
            {
                'Data': [{
                    'VarCharValue': 'dt=2018-12-10-10'
                }]
            },
            {
                'Data': [{
                    'VarCharValue': 'dt=2018-12-09-10'
                }]
            },
            {
                'Data': [{
                    'VarCharValue': 'dt=2018-12-09-10'
                }]
            },
            {
                'Data': [{
                    'VarCharValue': 'dt=2018-12-11-10'
                }]
            },
        ]

        expected_result = {
            'dt=2018-12-10-10', 'dt=2018-12-09-10', 'dt=2018-12-11-10'
        }

        result = self.client.get_table_partitions('test_table')
        assert_items_equal(result, expected_result)

    def test_get_table_partitions_error(self):
        """Athena - Get Table Partitions, Exception"""
        self.client._client.raise_exception = True
        assert_raises(AthenaQueryExecutionError,
                      self.client.get_table_partitions, 'test_table')

    def test_drop_table(self):
        """Athena - Drop Table, Success"""
        assert_true(self.client.drop_table('test_table'))

    def test_drop_table_failure(self):
        """Athena - Drop Table, Failure"""
        self.client._client.raise_exception = True
        assert_raises(AthenaQueryExecutionError, self.client.drop_table,
                      'test_table')

    @patch('stream_alert.shared.athena.AthenaClient.drop_table')
    def test_drop_all_tables(self, drop_table_mock):
        """Athena - Drop All Tables, Success"""
        self.client._client.results = [
            {
                'Data': [{
                    'VarCharValue': 'table_01'
                }]
            },
            {
                'Data': [{
                    'VarCharValue': 'table_02'
                }]
            },
            {
                'Data': [{
                    'VarCharValue': 'table_02'
                }]
            },
        ]
        assert_true(self.client.drop_all_tables())
        assert_equal(drop_table_mock.call_count, 2)

    @patch('stream_alert.shared.athena.AthenaClient.drop_table')
    def test_drop_all_tables_failure(self, drop_table_mock):
        """Athena - Drop All Tables, Failure"""
        self.client._client.results = [
            {
                'Data': [{
                    'VarCharValue': 'table_01'
                }]
            },
            {
                'Data': [{
                    'VarCharValue': 'table_02'
                }]
            },
            {
                'Data': [{
                    'VarCharValue': 'table_03'
                }]
            },
        ]
        drop_table_mock.side_effect = [True, True, False]
        assert_false(self.client.drop_all_tables())

    def test_drop_all_tables_exception(self):
        """Athena - Drop All Tables, Exception"""
        self.client._client.raise_exception = True
        assert_raises(AthenaQueryExecutionError, self.client.drop_all_tables)

    def test_execute_query(self):
        """Athena - Execute Query"""
        self.client._client.raise_exception = True
        assert_raises(AthenaQueryExecutionError, self.client._execute_query,
                      'BAD SQL')

    def test_execute_and_wait(self):
        """Athena - Execute and Wait"""
        self.client._client.results = [
            {
                'Data': [{
                    'VarCharValue': 'result'
                }]
            },
        ]
        result = self.client._execute_and_wait('SQL query')
        assert_true(result in self.client._client.query_executions)

    def test_execute_and_wait_failed(self):
        """Athena - Execute and Wait, Failed"""
        query = 'SQL query'
        self.client._client.result_state = 'FAILED'
        assert_raises(AthenaQueryExecutionError, self.client._execute_and_wait,
                      query)

    def test_query_result_paginator(self):
        """Athena - Query Result Paginator"""
        data = {'Data': [{'VarCharValue': 'result'}]}
        self.client._client.results = [
            data,
        ]

        items = list(self.client.query_result_paginator('test query'))
        assert_items_equal(items, [{'ResultSet': {'Rows': [data]}}] * 4)

    @raises(AthenaQueryExecutionError)
    def test_query_result_paginator_error(self):
        """Athena - Query Result Paginator, Exception"""
        self.client._client.raise_exception = True
        list(self.client.query_result_paginator('test query'))

    def test_run_async_query(self):
        """Athena - Run Async Query, Success"""
        assert_true(self.client.run_async_query('test query'))

    def test_run_async_query_failure(self):
        """Athena - Run Async Query, Failure"""
        self.client._client.raise_exception = True
        assert_raises(AthenaQueryExecutionError, self.client.run_async_query,
                      'test query')
Example #3
class AthenaRefresher(object):
    """Handle polling an SQS queue and running Athena queries for updating tables"""

    STREAMALERTS_REGEX = re.compile(r'alerts/dt=(?P<year>\d{4})'
                                    r'\-(?P<month>\d{2})'
                                    r'\-(?P<day>\d{2})'
                                    r'\-(?P<hour>\d{2})'
                                    r'\/.*.json')
    FIREHOSE_REGEX = re.compile(r'(?P<year>\d{4})'
                                r'\/(?P<month>\d{2})'
                                r'\/(?P<day>\d{2})'
                                r'\/(?P<hour>\d{2})\/.*')

    STREAMALERT_DATABASE = '{}_streamalert'
    ATHENA_S3_PREFIX = 'athena_partition_refresh'

    def __init__(self):
        config = load_config(include={'lambda.json', 'global.json'})
        prefix = config['global']['account']['prefix']
        athena_config = config['lambda']['athena_partition_refresh_config']

        self._athena_buckets = athena_config['buckets']

        db_name = athena_config.get('database_name',
                                    self.STREAMALERT_DATABASE.format(prefix))

        # Get the S3 bucket to store Athena query results
        results_bucket = athena_config.get(
            'results_bucket',
            's3://{}.streamalert.athena-results'.format(prefix))

        self._athena_client = AthenaClient(db_name, results_bucket,
                                           self.ATHENA_S3_PREFIX)

        # Initialize the SQS client and receive messages
        self._sqs_client = StreamAlertSQSClient(config)

    def _get_partitions_from_keys(self, s3_buckets_and_keys):
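        """Get the partitions that need to be added for the Athena tables

        Args:
            s3_buckets_and_keys (dict): S3 bucket names mapped to sets of unique
                object keys for which partitions should be added

        Returns:
            (dict): representation of tables, partitions and locations to be added
        """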
        partitions = defaultdict(dict)

        LOGGER.info('Processing new Hive partitions...')
        for bucket, keys in s3_buckets_and_keys.iteritems():
            athena_table = self._athena_buckets.get(bucket)
            if not athena_table:
                # TODO(jacknagz): Add this as a metric
                LOGGER.error(
                    '\'%s\' not found in \'buckets\' config. Please add this '
                    'bucket to enable additions of Hive partitions.', bucket)
                continue

            # Iterate over each key
            for key in keys:
                match = None
                for pattern in (self.FIREHOSE_REGEX, self.STREAMALERTS_REGEX):
                    match = pattern.search(key)
                    if match:
                        break

                if not match:
                    LOGGER.error(
                        'The key %s does not match any regex, skipping', key)
                    continue

                # Get the path to the objects in S3
                path = posixpath.dirname(key)
                # The config does not need to store all possible tables
                # for enabled log types because this can be inferred from
                # the incoming S3 bucket notification.  Only enabled
                # log types will be sending data to Firehose.
                # This logic extracts out the name of the table from the
                # first element in the S3 path, as that's how log types
                # are configured to send to Firehose.
                if athena_table != 'alerts':
                    athena_table = path.split('/')[0]

                # Example:
                # PARTITION (dt = '2017-01-01-01') LOCATION 's3://bucket/path/'
                partition = '(dt = \'{year}-{month}-{day}-{hour}\')'.format(
                    **match.groupdict())
                location = '\'s3://{bucket}/{path}\''.format(bucket=bucket,
                                                             path=path)
                # By using the partition as the dict key, this ensures that
                # Athena will not try to add the same partition twice.
                # TODO(jacknagz): Write this dictionary to SSM/DynamoDb
                # to increase idempotence of this Lambda function
                partitions[athena_table][partition] = location

        return partitions

    def _add_partition(self, s3_buckets_and_keys):
        """Execute a Hive Add Partition command on a given Athena table

        Args:
            s3_buckets_and_keys (dict): S3 bucket names mapped to sets of unique
                object keys for which partitions should be added

        Returns:
            (bool): True if the partitions were added successfully, False otherwise
        """
        partitions = self._get_partitions_from_keys(s3_buckets_and_keys)
        if not partitions:
            LOGGER.error('No partitions to add')
            return False

        for athena_table in partitions:
            partition_statement = ' '.join([
                'PARTITION {0} LOCATION {1}'.format(partition, location) for
                partition, location in partitions[athena_table].iteritems()
            ])
            query = ('ALTER TABLE {athena_table} '
                     'ADD IF NOT EXISTS {partition_statement};'.format(
                         athena_table=athena_table,
                         partition_statement=partition_statement))

            success = self._athena_client.run_query(query=query)
            if not success:
                raise AthenaRefreshError(
                    'The add hive partition query has failed:\n{}'.format(
                        query))

            LOGGER.info(
                'Successfully added the following partitions:\n%s',
                json.dumps({athena_table: partitions[athena_table]}, indent=4))
        return True

    def run(self):
        """Poll the SQS queue for messages and create partitions for new data"""
        # Check that the database being used exists before running queries
        if not self._athena_client.check_database_exists():
            raise AthenaRefreshError(
                'The \'{}\' database does not exist'.format(
                    self._athena_client.database))

        # Get the first batch of messages from SQS.  If there are no
        # messages, this will exit early.
        self._sqs_client.get_messages(max_tries=2)

        if not self._sqs_client.received_messages:
            LOGGER.info('No SQS messages received, exiting')
            return

        # If the max amount of messages was initially returned,
        # then get the next batch of messages.  The max is determined based
        # on (number of tries) * (number of possible max messages returned)
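        # e.g. 2 tries * 10 messages per SQS receive call (the API maximum) = 20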
        if len(self._sqs_client.received_messages) == 20:
            self._sqs_client.get_messages(max_tries=8)

        s3_buckets_and_keys = self._sqs_client.unique_s3_buckets_and_keys()
        if not s3_buckets_and_keys:
            LOGGER.error('No new Athena partitions to add, exiting')
            return

        if not self._add_partition(s3_buckets_and_keys):
            LOGGER.error('Failed to add hive partition(s)')
            return

        self._sqs_client.delete_messages()
        LOGGER.info('Deleted %d messages from SQS',
                    self._sqs_client.deleted_message_count)