Example #1
    def run(self):
        """Poll the SQS queue for messages and create partitions for new data"""
        # Check that the database being used exists before running queries
        if not self._athena_client.check_database_exists():
            raise AthenaRefreshError(
                'The \'{}\' database does not exist'.format(
                    self._athena_client.database))

        # Get the first batch of messages from SQS.  If there are no
        # messages, this will exit early.
        self._sqs_client.get_messages(max_tries=2)

        if not self._sqs_client.received_messages:
            LOGGER.info('No SQS messages received, exiting')
            return

        # If the maximum number of messages was returned on the first call,
        # then get the next batch of messages.  The max is determined by
        # (number of tries) * (max number of messages returned per request)
        if len(self._sqs_client.received_messages) == 20:
            self._sqs_client.get_messages(max_tries=8)

        s3_buckets_and_keys = self._sqs_client.unique_s3_buckets_and_keys()
        if not s3_buckets_and_keys:
            LOGGER.error('No new Athena partitions to add, exiting')
            return

        if not self._add_partition(s3_buckets_and_keys):
            LOGGER.error('Failed to add hive partition(s)')
            return

        self._sqs_client.delete_messages()
        LOGGER.info('Deleted %d messages from SQS',
                    self._sqs_client.deleted_message_count)
Example #2
    def _add_partition(self, s3_buckets_and_keys):
        """Execute a Hive Add Partition command on a given Athena table

        Args:
            s3_buckets_and_keys (dict): Buckets and unique keys to add partitions

        Returns:
            (bool): Whether the partitions were added successfully
        """
        partitions = self._get_partitions_from_keys(s3_buckets_and_keys)
        if not partitions:
            LOGGER.error('No partitions to add')
            return False

        for athena_table in partitions:
            partition_statement = ' '.join([
                'PARTITION {0} LOCATION {1}'.format(partition, location) for
                partition, location in partitions[athena_table].iteritems()
            ])
            query = ('ALTER TABLE {athena_table} '
                     'ADD IF NOT EXISTS {partition_statement};'.format(
                         athena_table=athena_table,
                         partition_statement=partition_statement))

            success = self._athena_client.run_query(query=query)
            if not success:
                raise AthenaRefreshError(
                    'The add hive partition query has failed:\n{}'.format(
                        query))

            LOGGER.info(
                'Successfully added the following partitions:\n%s',
                json.dumps({athena_table: partitions[athena_table]}, indent=4))
        return True
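
For reference, a minimal standalone sketch of the statement this method assembles. The table name, bucket, and partition values below are made up for illustration:

# Hypothetical input: one table with two partitions (all names are placeholders)
partitions = {
    'cloudwatch_events': {
        "(dt = '2017-01-01-01')": "'s3://example-data-bucket/cloudwatch_events/2017/01/01/01'",
        "(dt = '2017-01-01-02')": "'s3://example-data-bucket/cloudwatch_events/2017/01/01/02'",
    }
}

for athena_table, locations in partitions.items():
    partition_statement = ' '.join(
        'PARTITION {0} LOCATION {1}'.format(partition, location)
        for partition, location in locations.items()
    )
    query = ('ALTER TABLE {athena_table} '
             'ADD IF NOT EXISTS {partition_statement};'.format(
                 athena_table=athena_table,
                 partition_statement=partition_statement))
    # Prints one ALTER TABLE ... ADD IF NOT EXISTS statement per table, with a
    # PARTITION ... LOCATION ... clause for each entry (clause order may vary)
    print(query)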
Example #3
    def unique_s3_buckets_and_keys(self):
        """Filter a list of unique s3 buckets and S3 keys from event notifications

        Returns:
            (dict): Keys of bucket names, and values of unique S3 keys
        """
        s3_buckets_and_keys = defaultdict(set)

        if not self.received_messages:
            LOGGER.error(
                'No messages to filter, fetch the messages with get_messages()'
            )
            return

        for message in self.received_messages:
            if 'Body' not in message:
                LOGGER.error('Missing \'Body\' key in SQS message, skipping')
                continue

            loaded_message = json.loads(message['Body'])

            # From AWS documentation: http://amzn.to/2w4fcSq
            # When you configure an event notification on a bucket,
            # Amazon S3 sends the following test message:
            # {
            #    "Service":"Amazon S3",
            #    "Event":"s3:TestEvent",
            #    "Time":"2014-10-13T15:57:02.089Z",
            #    "Bucket":"bucketname",
            #    "RequestId":"5582815E1AEA5ADF",
            #    "HostId":"8cLeGAmw098X5cv4Zkwcmo8vvZa3eH3eKxsPzbB9wrR+YstdA6Knx4Ip8EXAMPLE"
            # }
            if loaded_message.get('Event') == 's3:TestEvent':
                LOGGER.debug('Skipping S3 bucket notification test event')
                continue

            if 'Records' not in loaded_message:
                LOGGER.error(
                    'Missing \'Records\' key in SQS message, skipping:\n%s',
                    json.dumps(loaded_message, indent=4))
                continue

            for record in loaded_message['Records']:
                if 's3' not in record:
                    LOGGER.info('Skipping non-s3 bucket notification message')
                    LOGGER.debug(record)
                    continue

                bucket_name = record['s3']['bucket']['name']
                # Account for special characters in the S3 object key
                # Example: Usage of '=' in the key name
                object_key = urllib.unquote(
                    record['s3']['object']['key']).decode('utf8')
                s3_buckets_and_keys[bucket_name].add(object_key)

                # Add to a new list to track successfully processed messages from the queue
                self.processed_messages.append(message)

        return s3_buckets_and_keys
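
The SQS message bodies this method parses are S3 event notifications. A self-contained sketch with a fabricated notification (the bucket and key names are invented) shows the bucket-to-keys mapping it produces:

import json
from collections import defaultdict

# Fabricated SQS message body carrying a single S3 ObjectCreated record
sample_body = json.dumps({
    'Records': [{
        's3': {
            'bucket': {'name': 'example-streamalert-data'},
            'object': {'key': 'cloudwatch_events/2017/08/01/14/example-file.json'}
        }
    }]
})

s3_buckets_and_keys = defaultdict(set)
for record in json.loads(sample_body)['Records']:
    bucket_name = record['s3']['bucket']['name']
    object_key = record['s3']['object']['key']
    s3_buckets_and_keys[bucket_name].add(object_key)

# Maps 'example-streamalert-data' to a set containing the single object key
print(dict(s3_buckets_and_keys))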
Example #4
    def _get_partitions_from_keys(self, s3_buckets_and_keys):
        partitions = defaultdict(dict)

        LOGGER.info('Processing new Hive partitions...')
        for bucket, keys in s3_buckets_and_keys.iteritems():
            athena_table = self._athena_buckets.get(bucket)
            if not athena_table:
                # TODO(jacknagz): Add this as a metric
                LOGGER.error(
                    '%s not found in \'buckets\' config. Please add this '
                    'bucket to enable additions of Hive partitions.',
                    bucket)
                continue

            # Iterate over each key
            for key in keys:
                match = None
                for pattern in (self.FIREHOSE_REGEX, self.STREAMALERTS_REGEX):
                    match = pattern.search(key)
                    if match:
                        break

                if not match:
                    LOGGER.error(
                        'The key %s does not match any regex, skipping', key)
                    continue

                # Get the path to the objects in S3
                path = posixpath.dirname(key)
                # The config does not need to store all possible tables
                # for enabled log types because this can be inferred from
                # the incoming S3 bucket notification.  Only enabled
                # log types will be sending data to Firehose.
                # This logic extracts out the name of the table from the
                # first element in the S3 path, as that's how log types
                # are configured to send to Firehose.
                if athena_table != 'alerts':
                    athena_table = path.split('/')[0]

                # Example:
                # PARTITION (dt = '2017-01-01-01') LOCATION 's3://bucket/path/'
                partition = '(dt = \'{year}-{month}-{day}-{hour}\')'.format(
                    **match.groupdict())
                location = '\'s3://{bucket}/{path}\''.format(bucket=bucket,
                                                             path=path)
                # By using the partition as the dict key, this ensures that
                # Athena will not try to add the same partition twice.
                # TODO(jacknagz): Write this dictionary to SSM/DynamoDb
                # to increase idempotence of this Lambda function
                partitions[athena_table][partition] = location

        return partitions
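
FIREHOSE_REGEX and STREAMALERTS_REGEX are not shown in this example, so the sketch below substitutes a simplified, hypothetical pattern purely to illustrate how the named groups feed the PARTITION and LOCATION strings:

import posixpath
import re

# Simplified stand-in for the class regexes; the real patterns are more specific
KEY_REGEX = re.compile(r'(?P<year>\d{4})/(?P<month>\d{2})/(?P<day>\d{2})/(?P<hour>\d{2})/')

bucket = 'example-streamalert-data'
key = 'cloudwatch_events/2017/08/01/14/example-file.json'

match = KEY_REGEX.search(key)
if match:
    path = posixpath.dirname(key)
    partition = '(dt = \'{year}-{month}-{day}-{hour}\')'.format(**match.groupdict())
    location = '\'s3://{bucket}/{path}\''.format(bucket=bucket, path=path)
    # PARTITION (dt = '2017-08-01-14') LOCATION
    #   's3://example-streamalert-data/cloudwatch_events/2017/08/01/14'
    print('PARTITION {0} LOCATION {1}'.format(partition, location))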
Example #5
    def check_table_exists(self, table_name):
        """Verify a given StreamAlert Athena table exists."""
        query_success, query_resp = self.run_athena_query(
            query='SHOW TABLES LIKE \'{}\';'.format(table_name),
            database=self.DATABASE_STREAMALERT)

        if query_success and query_resp['ResultSet']['Rows']:
            return True

        LOGGER.info('The streamalert table \'%s\' does not exist.', table_name)
        LOGGER.info('For help with creating tables: '
                    '$ python manage.py athena create-table --help')
        return False
Example #6
    def check_table_exists(self, table_name):
        """Verify a given StreamAlert Athena table exists."""
        query_success, query_resp = self.run_athena_query(
            query='SHOW TABLES LIKE \'{}\';'.format(table_name),
            database=self.DATABASE_STREAMALERT)

        if query_success and query_resp['ResultSet']['Rows']:
            return True

        LOGGER.info(
            'The streamalert table \'%s\' does not exist. '
            'For alert buckets, create it with the following command: \n'
            '$ python manage.py athena create-table '
            '--type alerts --bucket s3.bucket.id', table_name)
        return False
Example #7
    def run(self, event):
        """Take the messages from the SQS queue and create partitions for new data in S3

        Args:
            event (dict): Lambda input event containing SQS messages. Each SQS message
                should contain one (or possibly more) S3 bucket notification messages.
        """
        # Check that the database being used exists before running queries
        if not self._athena_client.check_database_exists():
            raise AthenaRefreshError(
                'The \'{}\' database does not exist'.format(
                    self._athena_client.database))

        for sqs_rec in event['Records']:
            LOGGER.debug(
                'Processing event with message ID \'%s\' and SentTimestamp %s',
                sqs_rec['messageId'], sqs_rec['attributes']['SentTimestamp'])

            body = json.loads(sqs_rec['body'])
            if body.get('Event') == 's3:TestEvent':
                LOGGER.debug('Skipping S3 bucket notification test event')
                continue

            for s3_rec in body['Records']:
                if 's3' not in s3_rec:
                    LOGGER.info(
                        'Skipping non-s3 bucket notification message: %s',
                        s3_rec)
                    continue

                bucket_name = s3_rec['s3']['bucket']['name']

                # Account for special characters in the S3 object key
                # Example: Usage of '=' in the key name
                object_key = urllib.unquote_plus(
                    s3_rec['s3']['object']['key']).decode('utf8')

                LOGGER.debug(
                    'Received notification for object \'%s\' in bucket \'%s\'',
                    object_key, bucket_name)

                self._s3_buckets_and_keys[bucket_name].add(object_key)

        if not self._add_partitions():
            raise AthenaRefreshError('Failed to add partitions: {}'.format(
                dict(self._s3_buckets_and_keys)))
Example #8
    def get_messages(self, **kwargs):
        """Poll the SQS queue for new messages

        Keyword Args:
            max_tries (int): The number of times to backoff
            max_value (int): The max wait interval between backoffs
            max_messages (int): The max number of messages to get from SQS
        """
        start_message_count = len(self.received_messages)

        # Backoff up to 5 times to limit the time spent in this operation
        # relative to the entire Lambda duration.
        max_tries = kwargs.get('max_tries', 5)

        # This value restricts the max time of backoff each try.
        # This means the total backoff time for one function call is:
        #   max_tries (attempts) * max_value (seconds)
        max_value = kwargs.get('max_value', 5)

        # Number of messages to poll from the stream.
        max_messages = kwargs.get('max_messages',
                                  self.MAX_SQS_GET_MESSAGE_COUNT)
        if max_messages > self.MAX_SQS_GET_MESSAGE_COUNT:
            LOGGER.error(
                'SQS can only request up to 10 messages in one request')
            return

        @backoff.on_predicate(backoff.fibo,
                              max_tries=max_tries,
                              max_value=max_value,
                              jitter=backoff.full_jitter,
                              on_backoff=backoff_handler(),
                              on_success=success_handler(True),
                              on_giveup=giveup_handler(True))
        def _receive_messages():
            polled_messages = self.sqs_client.receive_message(
                QueueUrl=self.athena_sqs_url, MaxNumberOfMessages=max_messages)

            if 'Messages' not in polled_messages:
                return False
            self.received_messages.extend(polled_messages['Messages'])

        _receive_messages()
        batch_count = len(self.received_messages) - start_message_count
        LOGGER.info('Received %d message(s) from SQS', batch_count)
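
The retry logic above comes from the backoff library: on_predicate retries the decorated function while it returns a falsy value, waiting according to the Fibonacci generator (capped at max_value seconds) with full jitter. A minimal sketch with a stub in place of the real SQS call:

import backoff

ATTEMPTS = {'count': 0}

@backoff.on_predicate(backoff.fibo, max_tries=3, max_value=2, jitter=backoff.full_jitter)
def _poll():
    # Stub standing in for sqs_client.receive_message(); falsy returns trigger a retry
    ATTEMPTS['count'] += 1
    return ATTEMPTS['count'] >= 3

_poll()
print(ATTEMPTS['count'])  # 3 -- two falsy returns, then a truthy one stops the polling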
Example #9
def handler(*_):
    """Athena Partition Refresher Handler Function"""
    config = _load_config()

    # Initialize the SQS client and receive messages
    stream_alert_sqs = StreamAlertSQSClient(config)
    # Get the first batch of messages from SQS.  If there are no
    # messages, this will exit early.
    stream_alert_sqs.get_messages(max_tries=2)

    if not stream_alert_sqs.received_messages:
        LOGGER.info('No SQS messages received, exiting')
        return

    # If the maximum number of messages was returned on the first call,
    # then get the next batch of messages.  The max is determined by
    # (number of tries) * (max number of messages returned per request)
    if len(stream_alert_sqs.received_messages) == 20:
        stream_alert_sqs.get_messages(max_tries=8)

    s3_buckets_and_keys = stream_alert_sqs.unique_s3_buckets_and_keys()
    if not s3_buckets_and_keys:
        LOGGER.error('No new Athena partitions to add, exiting')
        return

    # Initialize the Athena client and run queries
    stream_alert_athena = StreamAlertAthenaClient(config)

    # Check that the 'streamalert' database exists before running queries
    if not stream_alert_athena.check_database_exists():
        raise AthenaPartitionRefreshError(
            'The \'{}\' database does not exist'.format(
                stream_alert_athena.sa_database))

    if not stream_alert_athena.add_partition(s3_buckets_and_keys):
        LOGGER.error('Failed to add hive partition(s)')
        return

    stream_alert_sqs.delete_messages()
    LOGGER.info('Deleted %d messages from SQS',
                stream_alert_sqs.deleted_messages)
Example #10
    def get_messages(self, max_tries=5, max_value=5, max_messages=MAX_SQS_GET_MESSAGE_COUNT):
        """Poll the SQS queue for new messages

        Keyword Args:
            max_tries (int): The number of times to backoff
                Backoff up to 5 times to limit the time spent in this operation
                relative to the entire Lambda duration.
            max_value (int): The max wait interval between backoffs
                This value restricts the max time of backoff each try.
                This means the total backoff time for one function call is:
                    max_tries (attempts) * max_value (seconds)
            max_messages (int): The max number of messages to get from SQS
        """
        start_message_count = len(self.received_messages)

        # Number of messages to poll from the stream.
        if max_messages > self.MAX_SQS_GET_MESSAGE_COUNT:
            LOGGER.error('The requested number of messages exceeds the SQS per-request limit. '
                         'Setting max messages to %d', self.MAX_SQS_GET_MESSAGE_COUNT)
            max_messages = self.MAX_SQS_GET_MESSAGE_COUNT

        @backoff.on_predicate(backoff.fibo,
                              max_tries=max_tries,
                              max_value=max_value,
                              jitter=backoff.full_jitter,
                              on_backoff=backoff_handler(),
                              on_success=success_handler(True),
                              on_giveup=giveup_handler(True))
        def _receive_messages():
            polled_messages = self.sqs_client.receive_message(
                QueueUrl=self.athena_sqs_url,
                MaxNumberOfMessages=max_messages
            )
            if 'Messages' not in polled_messages:
                return True  # A truthy return value stops the backoff polling

            self.received_messages.extend(polled_messages['Messages'])

        _receive_messages()
        batch_count = len(self.received_messages) - start_message_count
        LOGGER.info('Received %d message(s) from SQS', batch_count)
Example #11
    def repair_hive_table(self, unique_buckets):
        """Execute a MSCK REPAIR TABLE on a given Athena table

        Args:
            unique_buckets (list): S3 buckets to repair

        Returns:
            (bool): Whether the repair was successful or not
        """
        athena_config = self.config['lambda'][
            'athena_partition_refresh_config']
        repair_hive_table_config = athena_config['refresh_type'][
            'repair_hive_table']

        LOGGER.info('Processing Hive repair table...')
        for data_bucket in unique_buckets:
            athena_table = repair_hive_table_config.get(data_bucket)
            if not athena_table:
                LOGGER.warning(
                    '%s not found in repair_hive_table config. '
                    'Please update your configuration accordingly.',
                    data_bucket)
                continue

            query_success, query_resp = self.run_athena_query(
                query='MSCK REPAIR TABLE {};'.format(athena_table),
                database=self.DATABASE_STREAMALERT)

            if query_success:
                LOGGER.info('Query results:')
                for row in query_resp['ResultSet']['Rows']:
                    LOGGER.info(row['Data'])
            else:
                LOGGER.error(
                    'Partition refresh of the Athena table '
                    '%s has failed.', athena_table)
                return False

        return True
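
The run_athena_query helper is not included in this example. For context, a rough, hypothetical equivalent can be built on boto3's Athena client by starting the query, polling its state, and fetching the results; the results_bucket parameter and the 2-second poll interval below are assumptions:

import time

import boto3

def run_athena_query(query, database, results_bucket):
    """Hypothetical stand-in returning a (success, response) tuple like the helper above"""
    athena = boto3.client('athena')
    execution_id = athena.start_query_execution(
        QueryString=query,
        QueryExecutionContext={'Database': database},
        ResultConfiguration={'OutputLocation': 's3://{}/'.format(results_bucket)}
    )['QueryExecutionId']

    # Poll until Athena reports a terminal state for the query
    while True:
        state = athena.get_query_execution(
            QueryExecutionId=execution_id)['QueryExecution']['Status']['State']
        if state not in ('QUEUED', 'RUNNING'):
            break
        time.sleep(2)

    if state != 'SUCCEEDED':
        return False, {}

    return True, athena.get_query_results(QueryExecutionId=execution_id)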
Example #12
    def add_partition(self, s3_buckets_and_keys):
        """Execute a Hive Add Partition command on a given Athena table

        Args:
            s3_buckets_and_keys (dict): Buckets and unique keys to add partitions

        Returns:
            (bool): Whether the partitions were added successfully
        """
        athena_config = self.config['lambda'][
            'athena_partition_refresh_config']
        athena_buckets = athena_config['buckets']
        partitions = defaultdict(dict)

        LOGGER.info('Processing new Hive partitions...')
        for bucket, keys in s3_buckets_and_keys.iteritems():
            athena_table = athena_buckets.get(bucket)
            if not athena_table:
                # TODO(jacknagz): Add this as a metric
                LOGGER.error(
                    '%s not found in \'buckets\' config. '
                    'Please add this bucket to enable additions '
                    'of Hive partitions.', bucket)
                continue

            # Iterate over each key
            for key in keys:
                for pattern in (self.FIREHOSE_REGEX, self.STREAMALERTS_REGEX):
                    match = pattern.search(key)
                    if match:
                        break

                if not match:
                    LOGGER.error(
                        'The key %s does not match any regex, skipping', key)
                    continue

                # Convert the match groups to a dict for easy access
                match_dict = match.groupdict()
                # Get the path to the objects in S3
                path = os.path.dirname(key)
                # The config does not need to store all possible tables
                # for enabled log types because this can be inferred from
                # the incoming S3 bucket notification.  Only enabled
                # log types will be sending data to Firehose.
                # This logic extracts out the name of the table from the
                # first element in the S3 path, as that's how log types
                # are configured to send to Firehose.
                if athena_table != 'alerts':
                    athena_table = path.split('/')[0]

                # Example:
                # PARTITION (dt = '2017-01-01-01') LOCATION 's3://bucket/path/'
                partition = '(dt = \'{year}-{month}-{day}-{hour}\')'.format(
                    year=match_dict['year'],
                    month=match_dict['month'],
                    day=match_dict['day'],
                    hour=match_dict['hour'])
                location = '\'s3://{bucket}/{path}\''.format(bucket=bucket,
                                                             path=path)
                # By using the partition as the dict key, this ensures that
                # Athena will not try to add the same partition twice.
                # TODO(jacknagz): Write this dictionary to SSM/DynamoDb
                # to increase idempotence of this Lambda function
                partitions[athena_table][partition] = location

        if not partitions:
            LOGGER.error('No partitions to add')
            return False

        for athena_table in partitions:
            partition_statement = ' '.join([
                'PARTITION {0} LOCATION {1}'.format(partition, location) for
                partition, location in partitions[athena_table].iteritems()
            ])
            query = ('ALTER TABLE {athena_table} '
                     'ADD IF NOT EXISTS {partition_statement};'.format(
                         athena_table=athena_table,
                         partition_statement=partition_statement))

            query_success, _ = self.run_athena_query(query=query,
                                                     database=self.sa_database)

            if not query_success:
                raise AthenaPartitionRefreshError(
                    'The add hive partition query has failed:\n{}'.format(
                        query))

            LOGGER.info(
                'Successfully added the following partitions:\n%s',
                json.dumps({athena_table: partitions[athena_table]}, indent=4))
        return True
Example #13
    def add_hive_partition(self, s3_buckets_and_keys):
        """Execute a Hive Add Partition command on a given Athena table

        Args:
            s3_buckets_and_keys (dict): Buckets and unique keys to add partitions

        Returns:
            (bool): Whether the partitions were added successfully
        """
        athena_config = self.config['lambda']['athena_partition_refresh_config']
        add_hive_partition_config = athena_config['refresh_type']['add_hive_partition']
        partitions = {}

        LOGGER.info('Processing new Hive partitions...')
        for bucket, keys in s3_buckets_and_keys.iteritems():
            athena_table = add_hive_partition_config.get(bucket)
            if not athena_table:
                LOGGER.error('%s not found in \'add_hive_partition\' config. '
                             'Please add this bucket to enable additions '
                             'of Hive partitions.',
                             bucket)
                continue

            # Gather all of the partitions to add per bucket
            s3_key_regex = self.STREAMALERTS_REGEX if athena_table == 'alerts' \
                                                   else self.FIREHOSE_REGEX
            # Iterate over each key
            for key in keys:
                match = s3_key_regex.search(key)
                if not match:
                    LOGGER.error('The key %s does not match the regex %s, skipping',
                                 key, s3_key_regex.pattern)
                    continue

                # Convert the match groups to a dict for easy access
                match_dict = match.groupdict()
                # Get the path to the objects in S3
                path = os.path.dirname(key)

                # PARTITION (dt = '2017-01-01-01') LOCATION 's3://bucket/path/'
                partition = '(dt = \'{year}-{month}-{day}-{hour}\')'.format(
                    year=match_dict['year'],
                    month=match_dict['month'],
                    day=match_dict['day'],
                    hour=match_dict['hour'])
                location = '\'s3://{bucket}/{path}\''.format(
                    bucket=bucket,
                    path=path)
                # By using the partition as the dict key, this ensures that
                # Athena will not try to add the same partition twice.
                partitions[partition] = location

        if not partitions:
            LOGGER.error('No partitions to add')
            return False

        partition_statement = ' '.join(
            ['PARTITION {0} LOCATION {1}'.format(
                partition, location) for partition, location in partitions.iteritems()])
        query = ('ALTER TABLE {athena_table} '
                 'ADD IF NOT EXISTS {partition_statement};'.format(
                     athena_table=athena_table,
                     partition_statement=partition_statement))

        query_success, _ = self.run_athena_query(
            query=query,
            database=self.DATABASE_STREAMALERT
        )

        if not query_success:
            LOGGER.error('The add hive partition query has failed:\n%s', query)
            return False

        LOGGER.info('Successfully added the following partitions:\n%s',
                    '\n'.join(partitions))
        return True