def __init__(self):
    config = load_config()
    prefix = config['global']['account']['prefix']

    # Create the rule table class for getting staging information
    self._rule_table = RuleTable('{}_streamalert_rules'.format(prefix))

    athena_config = config['lambda']['athena_partition_refresh_config']

    # Get the name of the athena database to access
    db_name = athena_config.get('database_name',
                                self.STREAMALERT_DATABASE.format(prefix))

    # Get the S3 bucket to store Athena query results
    results_bucket = athena_config.get(
        'results_bucket',
        's3://{}.streamalert.athena-results'.format(prefix)
    )

    self._athena_client = AthenaClient(db_name, results_bucket, self.ATHENA_S3_PREFIX)
    self._current_time = datetime.utcnow()

    # Store the SNS topic arn to send alert stat information to
    self._publisher = StatsPublisher(config, self._athena_client, self._current_time)

    self._staging_stats = dict()
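
# Hedged example of the two optional config keys read above: both fall back
# to prefix-derived defaults, so a deployment with prefix 'acme' (a made-up
# value) would resolve to the following even if the keys were omitted from
# the athena_partition_refresh_config block.
example_athena_config = {
    'database_name': 'acme_streamalert',
    'results_bucket': 's3://acme.streamalert.athena-results',
}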
@classmethod
def _create_client(cls, db_name, results_bucket):
    if cls._ATHENA_CLIENT:
        return  # Client already created/cached

    cls._ATHENA_CLIENT = AthenaClient(db_name, results_bucket, cls.ATHENA_S3_PREFIX)

    # Check if the database exists when the client is created
    if not cls._ATHENA_CLIENT.check_database_exists():
        raise AthenaRefreshError('The \'{}\' database does not exist'.format(db_name))
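
# _create_client caches the AthenaClient on the class so that warm Lambda
# invocations reuse one client and only pay the database-existence check
# once per container. A minimal standalone sketch of the same memoization
# pattern, with illustrative names that are not StreamAlert's:
class _CachedClientExample(object):
    _CLIENT = None

    @classmethod
    def client(cls, factory):
        if cls._CLIENT is None:
            cls._CLIENT = factory()  # built once; reused on later calls
        return cls._CLIENT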
def get_athena_client(config):
    """Get an athena client using the current config settings

    Args:
        config (CLIConfig): Loaded StreamAlert config

    Returns:
        AthenaClient: instantiated client for performing athena actions
    """
    prefix = config['global']['account']['prefix']
    athena_config = config['lambda']['athena_partition_refresh_config']

    db_name = athena_config.get(
        'database_name',
        AthenaRefresher.STREAMALERT_DATABASE.format(prefix)
    )

    # Get the S3 bucket to store Athena query results
    results_bucket = athena_config.get(
        'results_bucket',
        's3://{}.streamalert.athena-results'.format(prefix)
    )

    return AthenaClient(db_name, results_bucket, 'stream_alert_cli')
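
# Hedged usage sketch for get_athena_client. It assumes a config object has
# already been loaded by the CLI's own loader (that call is not shown in
# this snippet); check_database_exists() is a real AthenaClient method, as
# exercised elsewhere in this document.
def example_usage(config):
    athena_client = get_athena_client(config)
    if not athena_client.check_database_exists():
        raise RuntimeError('Athena database is missing')
    return athena_client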
class TestAthenaClient(object):
    """Test class for AthenaClient"""

    @patch.dict(os.environ, {'AWS_DEFAULT_REGION': 'us-west-1'})
    @patch('boto3.client', Mock(side_effect=lambda c: MockAthenaClient()))
    def setup(self):
        """Setup the AthenaClient tests"""
        self._db_name = 'test_database'
        config = load_config('tests/unit/conf/')
        prefix = config['global']['account']['prefix']

        self.client = AthenaClient(
            self._db_name,
            's3://{}.streamalert.athena-results'.format(prefix),
            'unit-testing'
        )

    @patch('stream_alert.shared.athena.datetime')
    def test_init_fix_bucket_path(self, date_mock):
        """Athena - Fix Bucket Path"""
        date_now = datetime.utcnow()
        date_mock.utcnow.return_value = date_now
        date_format = date_now.strftime('%Y/%m/%d')
        expected_path = 's3://test.streamalert.athena-results/unit-testing/{}'.format(date_format)
        with patch.dict(os.environ, {'AWS_DEFAULT_REGION': 'us-west-1'}):
            client = AthenaClient(self._db_name, 'test.streamalert.athena-results', 'unit-testing')
            assert_equal(client._s3_results_path, expected_path)

    def test_unique_values_from_query(self):
        """Athena - Unique Values from Query"""
        query = {
            'ResultSet': {
                'Rows': [
                    {'Data': [{'VarCharValue': 'foobar'}]},
                    {'Data': [{'VarCharValue': 'barfoo'}]},
                    {'Data': [{'VarCharValue': 'barfoo'}]},
                    {'Data': [{'VarCharValue': 'foobarbaz'}]},
                ]
            }
        }
        expected_result = {'foobar', 'barfoo', 'foobarbaz'}

        result = self.client._unique_values_from_query(query)
        assert_items_equal(result, expected_result)

    def test_check_database_exists(self):
        """Athena - Check Database Exists"""
        self.client._client.results = [{'Data': [{'VarCharValue': self._db_name}]}]
        assert_true(self.client.check_database_exists())

    def test_check_database_exists_invalid(self):
        """Athena - Check Database Exists - Does Not Exist"""
        self.client._client.results = None
        assert_false(self.client.check_database_exists())

    def test_check_table_exists(self):
        """Athena - Check Table Exists"""
        self.client._client.results = [{'Data': [{'VarCharValue': 'test_table'}]}]
        assert_true(self.client.check_table_exists('test_table'))

    def test_check_table_exists_invalid(self):
        """Athena - Check Table Exists - Does Not Exist"""
        self.client._client.results = None
        assert_false(self.client.check_table_exists('test_table'))

    def test_get_table_partitions(self):
        """Athena - Get Table Partitions"""
        self.client._client.results = [
            {'Data': [{'VarCharValue': 'dt=2018-12-10-10'}]},
            {'Data': [{'VarCharValue': 'dt=2018-12-09-10'}]},
            {'Data': [{'VarCharValue': 'dt=2018-12-09-10'}]},
            {'Data': [{'VarCharValue': 'dt=2018-12-11-10'}]},
        ]
        expected_result = {'dt=2018-12-10-10', 'dt=2018-12-09-10', 'dt=2018-12-11-10'}
        result = self.client.get_table_partitions('test_table')
        assert_items_equal(result, expected_result)

    def test_get_table_partitions_error(self):
        """Athena - Get Table Partitions, Exception"""
        self.client._client.raise_exception = True
        assert_raises(AthenaQueryExecutionError, self.client.get_table_partitions, 'test_table')

    def test_drop_table(self):
        """Athena - Drop Table, Success"""
        assert_true(self.client.drop_table('test_table'))

    def test_drop_table_failure(self):
        """Athena - Drop Table, Failure"""
        self.client._client.raise_exception = True
        assert_raises(AthenaQueryExecutionError, self.client.drop_table, 'test_table')

    @patch('stream_alert.shared.athena.AthenaClient.drop_table')
    def test_drop_all_tables(self, drop_table_mock):
        """Athena - Drop All Tables, Success"""
        self.client._client.results = [
            {'Data': [{'VarCharValue': 'table_01'}]},
            {'Data': [{'VarCharValue': 'table_02'}]},
            {'Data': [{'VarCharValue': 'table_02'}]},
        ]
        assert_true(self.client.drop_all_tables())
        assert_equal(drop_table_mock.call_count, 2)

    @patch('stream_alert.shared.athena.AthenaClient.drop_table')
    def test_drop_all_tables_failure(self, drop_table_mock):
        """Athena - Drop All Tables, Failure"""
        self.client._client.results = [
            {'Data': [{'VarCharValue': 'table_01'}]},
            {'Data': [{'VarCharValue': 'table_02'}]},
            {'Data': [{'VarCharValue': 'table_03'}]},
        ]
        drop_table_mock.side_effect = [True, True, False]
        assert_false(self.client.drop_all_tables())

    def test_drop_all_tables_exception(self):
        """Athena - Drop All Tables, Exception"""
        self.client._client.raise_exception = True
        assert_raises(AthenaQueryExecutionError, self.client.drop_all_tables)

    def test_execute_query(self):
        """Athena - Execute Query"""
        self.client._client.raise_exception = True
        assert_raises(AthenaQueryExecutionError, self.client._execute_query, 'BAD SQL')

    def test_execute_and_wait(self):
        """Athena - Execute and Wait"""
        self.client._client.results = [
            {'Data': [{'VarCharValue': 'result'}]},
        ]
        result = self.client._execute_and_wait('SQL query')
        assert_true(result in self.client._client.query_executions)

    def test_execute_and_wait_failed(self):
        """Athena - Execute and Wait, Failed"""
        query = 'SQL query'
        self.client._client.result_state = 'FAILED'
        assert_raises(AthenaQueryExecutionError, self.client._execute_and_wait, query)

    def test_query_result_paginator(self):
        """Athena - Query Result Paginator"""
        data = {'Data': [{'VarCharValue': 'result'}]}
        self.client._client.results = [data]

        items = list(self.client.query_result_paginator('test query'))
        assert_items_equal(items, [{'ResultSet': {'Rows': [data]}}] * 4)

    @raises(AthenaQueryExecutionError)
    def test_query_result_paginator_error(self):
        """Athena - Query Result Paginator, Exception"""
        self.client._client.raise_exception = True
        list(self.client.query_result_paginator('test query'))

    def test_run_async_query(self):
        """Athena - Run Async Query, Success"""
        assert_true(self.client.run_async_query('test query'))

    def test_run_async_query_failure(self):
        """Athena - Run Async Query, Failure"""
        self.client._client.raise_exception = True
        assert_raises(AthenaQueryExecutionError, self.client.run_async_query, 'test query')
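
# The tests above patch boto3.client with a MockAthenaClient from the test
# helpers (not shown here). As a minimal sketch of the surface those tests
# poke (results, result_state, raise_exception, query_executions), a mock
# might look like the following; this is an illustration under those
# assumptions, not the project's actual helper.
import uuid

class MockAthenaClientSketch(object):
    def __init__(self):
        self.results = []           # rows returned by get_query_results
        self.result_state = 'SUCCEEDED'
        self.raise_exception = False
        self.query_executions = {}  # execution id -> submitted query string

    def start_query_execution(self, **kwargs):
        if self.raise_exception:
            raise Exception('raised for tests expecting query failure')
        execution_id = str(uuid.uuid4())
        self.query_executions[execution_id] = kwargs.get('QueryString')
        return {'QueryExecutionId': execution_id}

    def get_query_execution(self, **kwargs):
        return {'QueryExecution': {'Status': {'State': self.result_state}}}

    def get_query_results(self, **kwargs):
        return {'ResultSet': {'Rows': self.results or []}}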
class RulePromoter(object):
    """Run queries to generate statistics on alerts."""

    ATHENA_S3_PREFIX = 'rule_promoter'
    STREAMALERT_DATABASE = '{}_streamalert'

    def __init__(self):
        self._config = load_config()
        prefix = self._config['global']['account']['prefix']

        # Create the rule table class for getting staging information
        self._rule_table = RuleTable('{}_streamalert_rules'.format(prefix))

        athena_config = self._config['lambda']['athena_partition_refresh_config']

        # Get the name of the athena database to access
        db_name = athena_config.get('database_name',
                                    self.STREAMALERT_DATABASE.format(prefix))

        # Get the S3 bucket to store Athena query results
        results_bucket = athena_config.get(
            'results_bucket',
            's3://{}.streamalert.athena-results'.format(prefix)
        )

        self._athena_client = AthenaClient(db_name, results_bucket, self.ATHENA_S3_PREFIX)
        self._current_time = datetime.utcnow()

        self._staging_stats = dict()

    def _get_staging_info(self):
        """Query the Rule table for rule staging info needed to count each rule's alerts

        Example of rule metadata returned by RuleTable.remote_rule_info():
            {
                'example_rule_name': {
                    'Staged': True,
                    'StagedAt': datetime.datetime object,
                    'StagedUntil': '2018-04-21T02:23:13.332223Z'
                }
            }
        """
        for rule in sorted(self._rule_table.remote_rule_info):
            info = self._rule_table.remote_rule_info[rule]

            # If the rule is not staged, do not get stats on it
            if not info['Staged']:
                continue

            self._staging_stats[rule] = StagingStatistic(
                info['StagedAt'],
                info['StagedUntil'],
                self._current_time,
                rule
            )

        return len(self._staging_stats) != 0

    def _update_alert_count(self):
        """Transform Athena query results into alert counts for staged rules

        Runs a compound Athena count query covering every staged rule and
        stores the resulting alert count (int) on each rule's StagingStatistic.
        """
        query = StagingStatistic.construct_compound_count_query(self._staging_stats.values())
        LOGGER.debug('Running compound query for alert count: \'%s\'', query)

        for page, results in enumerate(self._athena_client.query_result_paginator(query)):
            for i, row in enumerate(results['ResultSet']['Rows']):
                if page == 0 and i == 0:  # skip header row included in first page only
                    continue

                row_values = [data.values()[0] for data in row['Data']]
                rule_name, alert_count = row_values[0], int(row_values[1])

                LOGGER.debug('Found %d alerts for rule \'%s\'', alert_count, rule_name)

                self._staging_stats[rule_name].alert_count = alert_count

    def run(self, send_digest):
        """Perform statistic analysis of currently staged rules

        Args:
            send_digest (bool): True if the staging statistics digest should be
                published, False otherwise
        """
        if not self._get_staging_info():
            LOGGER.debug('No staged rules to promote')
            return

        self._update_alert_count()

        self._promote_rules()

        if send_digest:
            publisher = StatsPublisher(self._config, self._athena_client, self._current_time)
            publisher.publish(self._staging_stats.values())
        else:
            LOGGER.debug('Staging statistics digest will not be sent')

    def _promote_rules(self):
        """Promote any rule that has not resulted in any alerts since being staged"""
        for rule in self._rules_to_be_promoted:
            LOGGER.info('Promoting rule \'%s\' at %s', rule, self._current_time)
            self._rule_table.toggle_staged_state(rule, False)

    @property
    def _rules_to_be_promoted(self):
        """Returns a list of rules that are eligible for promotion"""
        return [
            rule for rule, stat in self._staging_stats.iteritems()
            if self._current_time > stat.staged_until and stat.alert_count == 0
        ]

    @property
    def _rules_failing_promotion(self):
        """Returns a list of rules that are ineligible for promotion"""
        return [
            rule for rule, stat in self._staging_stats.iteritems()
            if stat.alert_count != 0
        ]
class AthenaRefresher(object):
    """Handle polling an SQS queue and running Athena queries for updating tables"""

    STREAMALERTS_REGEX = re.compile(r'alerts/dt=(?P<year>\d{4})'
                                    r'\-(?P<month>\d{2})'
                                    r'\-(?P<day>\d{2})'
                                    r'\-(?P<hour>\d{2})'
                                    r'\/.*\.json')
    FIREHOSE_REGEX = re.compile(r'(?P<year>\d{4})'
                                r'\/(?P<month>\d{2})'
                                r'\/(?P<day>\d{2})'
                                r'\/(?P<hour>\d{2})\/.*')

    STREAMALERT_DATABASE = '{}_streamalert'
    ATHENA_S3_PREFIX = 'athena_partition_refresh'

    def __init__(self):
        config = load_config(include={'lambda.json', 'global.json'})
        prefix = config['global']['account']['prefix']
        athena_config = config['lambda']['athena_partition_refresh_config']

        self._athena_buckets = athena_config['buckets']

        db_name = athena_config.get('database_name',
                                    self.STREAMALERT_DATABASE.format(prefix))

        # Get the S3 bucket to store Athena query results
        results_bucket = athena_config.get(
            'results_bucket',
            's3://{}.streamalert.athena-results'.format(prefix)
        )

        self._athena_client = AthenaClient(db_name, results_bucket, self.ATHENA_S3_PREFIX)

        self._s3_buckets_and_keys = defaultdict(set)

    def _get_partitions_from_keys(self):
        """Get the partitions that need to be added for the Athena tables

        Returns:
            (dict): representation of tables, partitions and locations to be added
                Example:
                    {
                        'alerts': {
                            '(dt = \'2018-08-01-01\')': 's3://streamalert.alerts/2018/08/01/01'
                        }
                    }
        """
        partitions = defaultdict(dict)

        LOGGER.info('Processing new Hive partitions...')
        for bucket, keys in self._s3_buckets_and_keys.iteritems():
            athena_table = self._athena_buckets.get(bucket)
            if not athena_table:
                # TODO(jacknagz): Add this as a metric
                LOGGER.error('\'%s\' not found in \'buckets\' config. Please add this '
                             'bucket to enable additions of Hive partitions.',
                             bucket)
                continue

            # Iterate over each key
            for key in keys:
                match = None
                for pattern in (self.FIREHOSE_REGEX, self.STREAMALERTS_REGEX):
                    match = pattern.search(key)
                    if match:
                        break

                if not match:
                    LOGGER.error('The key %s does not match any regex, skipping', key)
                    continue

                # Get the path to the objects in S3
                path = posixpath.dirname(key)
                # The config does not need to store all possible tables
                # for enabled log types because this can be inferred from
                # the incoming S3 bucket notification. Only enabled
                # log types will be sending data to Firehose.
                # This logic extracts out the name of the table from the
                # first element in the S3 path, as that's how log types
                # are configured to send to Firehose.
                if athena_table != 'alerts':
                    athena_table = path.split('/')[0]

                # Example:
                # PARTITION (dt = '2017-01-01-01') LOCATION 's3://bucket/path/'
                partition = '(dt = \'{year}-{month}-{day}-{hour}\')'.format(**match.groupdict())
                location = '\'s3://{bucket}/{path}\''.format(bucket=bucket, path=path)
                # By using the partition as the dict key, this ensures that
                # Athena will not try to add the same partition twice.
                # TODO(jacknagz): Write this dictionary to SSM/DynamoDb
                # to increase idempotence of this Lambda function
                partitions[athena_table][partition] = location

        return partitions

    def _add_partitions(self):
        """Execute a Hive Add Partition command for the given Athena tables and partitions

        Returns:
            (bool): True if the partitions were added successfully, False otherwise
        """
        partitions = self._get_partitions_from_keys()
        if not partitions:
            LOGGER.error('No partitions to add')
            return False

        for athena_table in partitions:
            partition_statement = ' '.join([
                'PARTITION {0} LOCATION {1}'.format(partition, location)
                for partition, location in partitions[athena_table].iteritems()
            ])
            query = ('ALTER TABLE {athena_table} '
                     'ADD IF NOT EXISTS {partition_statement};'.format(
                         athena_table=athena_table,
                         partition_statement=partition_statement))

            success = self._athena_client.run_query(query=query)
            if not success:
                raise AthenaRefreshError(
                    'The add hive partition query has failed:\n{}'.format(query)
                )

            LOGGER.info('Successfully added the following partitions:\n%s',
                        json.dumps({athena_table: partitions[athena_table]}))
        return True

    def run(self, event):
        """Take the messages from the SQS queue and create partitions for new data in S3

        Args:
            event (dict): Lambda input event containing SQS messages. Each SQS
                message should contain one (or more) S3 bucket notification message.
        """
        # Check that the database being used exists before running queries
        if not self._athena_client.check_database_exists():
            raise AthenaRefreshError(
                'The \'{}\' database does not exist'.format(self._athena_client.database)
            )

        for sqs_rec in event['Records']:
            LOGGER.debug('Processing event with message ID \'%s\' and SentTimestamp %s',
                         sqs_rec['messageId'],
                         sqs_rec['attributes']['SentTimestamp'])

            body = json.loads(sqs_rec['body'])
            if body.get('Event') == 's3:TestEvent':
                LOGGER.debug('Skipping S3 bucket notification test event')
                continue

            for s3_rec in body['Records']:
                if 's3' not in s3_rec:
                    LOGGER.info('Skipping non-s3 bucket notification message: %s', s3_rec)
                    continue

                bucket_name = s3_rec['s3']['bucket']['name']

                # Account for special characters in the S3 object key
                # Example: Usage of '=' in the key name
                object_key = urllib.unquote_plus(s3_rec['s3']['object']['key']).decode('utf8')

                LOGGER.debug('Received notification for object \'%s\' in bucket \'%s\'',
                             object_key, bucket_name)

                self._s3_buckets_and_keys[bucket_name].add(object_key)

        if not self._add_partitions():
            raise AthenaRefreshError(
                'Failed to add partitions: {}'.format(dict(self._s3_buckets_and_keys))
            )
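
# Hedged example of the Lambda `event` that run() expects: SQS records whose
# bodies are JSON-encoded S3 bucket notifications, matching the fields the
# loop above reads. Every identifier, bucket and key name below is
# fabricated for illustration.
import json

example_event = {
    'Records': [
        {
            'messageId': '11111111-2222-3333-4444-555555555555',
            'attributes': {'SentTimestamp': '1533081600000'},
            'body': json.dumps({
                'Records': [
                    {
                        's3': {
                            'bucket': {'name': 'example-bucket'},
                            'object': {'key': 'alerts/dt=2018-08-01-01/rule.json'}
                        }
                    }
                ]
            })
        }
    ]
}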
class AthenaRefresher(object):
    """Handle polling an SQS queue and running Athena queries for updating tables"""

    STREAMALERTS_REGEX = re.compile(r'alerts/dt=(?P<year>\d{4})'
                                    r'\-(?P<month>\d{2})'
                                    r'\-(?P<day>\d{2})'
                                    r'\-(?P<hour>\d{2})'
                                    r'\/.*\.json')
    FIREHOSE_REGEX = re.compile(r'(?P<year>\d{4})'
                                r'\/(?P<month>\d{2})'
                                r'\/(?P<day>\d{2})'
                                r'\/(?P<hour>\d{2})\/.*')

    STREAMALERT_DATABASE = '{}_streamalert'
    ATHENA_S3_PREFIX = 'athena_partition_refresh'

    def __init__(self):
        config = load_config(include={'lambda.json', 'global.json'})
        prefix = config['global']['account']['prefix']
        athena_config = config['lambda']['athena_partition_refresh_config']

        self._athena_buckets = athena_config['buckets']

        db_name = athena_config.get('database_name',
                                    self.STREAMALERT_DATABASE.format(prefix))

        # Get the S3 bucket to store Athena query results
        results_bucket = athena_config.get(
            'results_bucket',
            's3://{}.streamalert.athena-results'.format(prefix)
        )

        self._athena_client = AthenaClient(db_name, results_bucket, self.ATHENA_S3_PREFIX)

        # Initialize the SQS client used to receive messages
        self._sqs_client = StreamAlertSQSClient(config)

    def _get_partitions_from_keys(self, s3_buckets_and_keys):
        """Get the partitions that need to be added for the Athena tables

        Args:
            s3_buckets_and_keys (dict): Buckets mapped to sets of unique S3 keys

        Returns:
            (dict): representation of tables, partitions and locations to be added
        """
        partitions = defaultdict(dict)

        LOGGER.info('Processing new Hive partitions...')
        for bucket, keys in s3_buckets_and_keys.iteritems():
            athena_table = self._athena_buckets.get(bucket)
            if not athena_table:
                # TODO(jacknagz): Add this as a metric
                LOGGER.error('\'%s\' not found in \'buckets\' config. Please add this '
                             'bucket to enable additions of Hive partitions.',
                             bucket)
                continue

            # Iterate over each key
            for key in keys:
                match = None
                for pattern in (self.FIREHOSE_REGEX, self.STREAMALERTS_REGEX):
                    match = pattern.search(key)
                    if match:
                        break

                if not match:
                    LOGGER.error('The key %s does not match any regex, skipping', key)
                    continue

                # Get the path to the objects in S3
                path = posixpath.dirname(key)
                # The config does not need to store all possible tables
                # for enabled log types because this can be inferred from
                # the incoming S3 bucket notification. Only enabled
                # log types will be sending data to Firehose.
                # This logic extracts out the name of the table from the
                # first element in the S3 path, as that's how log types
                # are configured to send to Firehose.
                if athena_table != 'alerts':
                    athena_table = path.split('/')[0]

                # Example:
                # PARTITION (dt = '2017-01-01-01') LOCATION 's3://bucket/path/'
                partition = '(dt = \'{year}-{month}-{day}-{hour}\')'.format(**match.groupdict())
                location = '\'s3://{bucket}/{path}\''.format(bucket=bucket, path=path)
                # By using the partition as the dict key, this ensures that
                # Athena will not try to add the same partition twice.
                # TODO(jacknagz): Write this dictionary to SSM/DynamoDb
                # to increase idempotence of this Lambda function
                partitions[athena_table][partition] = location

        return partitions

    def _add_partition(self, s3_buckets_and_keys):
        """Execute a Hive Add Partition command on a given Athena table

        Args:
            s3_buckets_and_keys (dict): Buckets and unique keys to add partitions

        Returns:
            (bool): True if the partitions were added successfully, False otherwise
        """
        partitions = self._get_partitions_from_keys(s3_buckets_and_keys)
        if not partitions:
            LOGGER.error('No partitions to add')
            return False

        for athena_table in partitions:
            partition_statement = ' '.join([
                'PARTITION {0} LOCATION {1}'.format(partition, location)
                for partition, location in partitions[athena_table].iteritems()
            ])
            query = ('ALTER TABLE {athena_table} '
                     'ADD IF NOT EXISTS {partition_statement};'.format(
                         athena_table=athena_table,
                         partition_statement=partition_statement))

            success = self._athena_client.run_query(query=query)
            if not success:
                raise AthenaRefreshError(
                    'The add hive partition query has failed:\n{}'.format(query)
                )

            LOGGER.info('Successfully added the following partitions:\n%s',
                        json.dumps({athena_table: partitions[athena_table]}, indent=4))
        return True

    def run(self):
        """Poll the SQS queue for messages and create partitions for new data"""
        # Check that the database being used exists before running queries
        if not self._athena_client.check_database_exists():
            raise AthenaRefreshError(
                'The \'{}\' database does not exist'.format(self._athena_client.database)
            )

        # Get the first batch of messages from SQS. If there are no
        # messages, this will exit early.
        self._sqs_client.get_messages(max_tries=2)

        if not self._sqs_client.received_messages:
            LOGGER.info('No SQS messages received, exiting')
            return

        # If the max number of messages was initially returned,
        # then get the next batch of messages. The max is determined based
        # on (number of tries) * (number of possible max messages returned)
        if len(self._sqs_client.received_messages) == 20:
            self._sqs_client.get_messages(max_tries=8)

        s3_buckets_and_keys = self._sqs_client.unique_s3_buckets_and_keys()
        if not s3_buckets_and_keys:
            LOGGER.error('No new Athena partitions to add, exiting')
            return

        if not self._add_partition(s3_buckets_and_keys):
            LOGGER.error('Failed to add hive partition(s)')
            return

        self._sqs_client.delete_messages()
        LOGGER.info('Deleted %d messages from SQS',
                    self._sqs_client.deleted_message_count)
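
# For a concrete sense of the partition assembly: given this hypothetical
# return value from unique_s3_buckets_and_keys(), with the bucket mapped to
# a 'cloudwatch' table in the 'buckets' config (every name below is made
# up for illustration)...
s3_buckets_and_keys = {
    'example-data-bucket': {'cloudwatch/2018/08/01/01/run-1.json'},
}
# ...the key matches FIREHOSE_REGEX, the table name comes from the first
# path element, and _add_partition would submit roughly this query via
# run_query():
expected_query = (
    "ALTER TABLE cloudwatch ADD IF NOT EXISTS "
    "PARTITION (dt = '2018-08-01-01') "
    "LOCATION 's3://example-data-bucket/cloudwatch/2018/08/01/01';"
)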