def _add_partition(self, s3_buckets_and_keys):
    """Execute a Hive Add Partition command on a given Athena table

    Args:
        s3_buckets_and_keys (dict): Buckets and unique keys to add partitions

    Returns:
        (bool): Whether or not the partitions were added successfully
    """
    partitions = self._get_partitions_from_keys(s3_buckets_and_keys)
    if not partitions:
        LOGGER.error('No partitions to add')
        return False

    for athena_table in partitions:
        partition_statement = ' '.join([
            'PARTITION {0} LOCATION {1}'.format(partition, location)
            for partition, location in partitions[athena_table].iteritems()
        ])
        query = ('ALTER TABLE {athena_table} '
                 'ADD IF NOT EXISTS {partition_statement};'.format(
                     athena_table=athena_table,
                     partition_statement=partition_statement))

        success = self._athena_client.run_query(query=query)
        if not success:
            raise AthenaRefreshError(
                'The add hive partition query has failed:\n{}'.format(query))

        LOGGER.info('Successfully added the following partitions:\n%s',
                    json.dumps({athena_table: partitions[athena_table]}, indent=4))

    return True
def _delete_messages_from_queue():
    # Determine the message batch for SQS message deletion
    len_processed_messages = len(self.processed_messages)
    batch = len_processed_messages if len_processed_messages < 10 else 10

    # Pop processed records from the list to be deleted
    message_batch = [self.processed_messages.pop() for _ in range(batch)]

    # Try to delete the batch
    resp = self.sqs_client.delete_message_batch(
        QueueUrl=self.athena_sqs_url,
        Entries=[{'Id': message['MessageId'],
                  'ReceiptHandle': message['ReceiptHandle']}
                 for message in message_batch])

    # Handle successful deletions
    if resp.get('Successful'):
        self.deleted_message_count += len(resp['Successful'])

    # Handle failed deletions
    if resp.get('Failed'):
        LOGGER.error(('Failed to delete the messages with following (%d) '
                      'error messages:\n%s'),
                     len(resp['Failed']), json.dumps(resp['Failed']))

        # Add the failed messages back to the processed_messages attribute
        # to be retried via backoff
        failed_message_ids = [message['Id'] for message in resp['Failed']]
        push_back_messages = [message for message in message_batch
                              if message['MessageId'] in failed_message_ids]
        self.processed_messages.extend(push_back_messages)

    return len(self.processed_messages)
def _delete_messages_from_queue():
    # Determine the message batch for SQS message deletion
    len_processed_messages = len(self.processed_messages)
    batch = len_processed_messages if len_processed_messages < 10 else 10
    message_batch = [self.processed_messages.pop() for _ in range(batch)]

    # Try to delete the batch
    resp = self.sqs_client.delete_message_batch(
        QueueUrl=self.athena_sqs_url,
        Entries=[{'Id': message['MessageId'],
                  'ReceiptHandle': message['ReceiptHandle']}
                 for message in message_batch])

    # Handle successful deletions; 'Successful' is absent if every delete failed
    self.deleted_messages += len(resp.get('Successful', []))

    # Handle failed deletions
    if resp.get('Failed'):
        LOGGER.error('Failed to delete the following (%d) messages:\n%s',
                     len(resp['Failed']), json.dumps(resp['Failed']))

        # Add the failed messages back to the processed_messages attribute
        failed_from_batch = [message for failed_message in resp['Failed']
                             for message in message_batch
                             if message['MessageId'] == failed_message['Id']]
        self.processed_messages.extend(failed_from_batch)

    return len(self.processed_messages)
def _giveup_handler(details):
    """Backoff logging handler for when backoff gives up.

    Args:
        details (dict): Backoff context containing the number of tries,
            target function currently executing, kwargs, args, value,
            and wait time.
    """
    LOGGER.debug('[Backoff]: Exiting after %d tries calling %s',
                 details['tries'], details['target'].__name__)
def delete_messages(self):
    """Delete messages off the queue once processed"""
    if not self.processed_messages:
        LOGGER.error('No processed messages to delete')
        return

    @backoff.on_predicate(backoff.fibo,
                          lambda len_messages: len_messages > 0,
                          max_value=10,
                          max_tries=self.SQS_BACKOFF_MAX_RETRIES,
                          jitter=backoff.full_jitter,
                          on_backoff=backoff_handler(),
                          on_success=success_handler())
    def _delete_messages_from_queue():
        # Determine the message batch for SQS message deletion
        len_processed_messages = len(self.processed_messages)
        batch = len_processed_messages if len_processed_messages < 10 else 10

        # Pop processed records from the list to be deleted
        message_batch = [self.processed_messages.pop() for _ in range(batch)]

        # Try to delete the batch
        resp = self.sqs_client.delete_message_batch(
            QueueUrl=self.athena_sqs_url,
            Entries=[{'Id': message['MessageId'],
                      'ReceiptHandle': message['ReceiptHandle']}
                     for message in message_batch])

        # Handle successful deletions
        if resp.get('Successful'):
            self.deleted_messages += len(resp['Successful'])

        # Handle failed deletions
        if resp.get('Failed'):
            LOGGER.error(('Failed to delete the messages with following (%d) '
                          'error messages:\n%s'),
                         len(resp['Failed']), json.dumps(resp['Failed']))

            # Add the failed messages back to the processed_messages attribute
            # to be retried via backoff
            failed_message_ids = [message['Id'] for message in resp['Failed']]
            push_back_messages = [message for message in message_batch
                                  if message['MessageId'] in failed_message_ids]
            self.processed_messages.extend(push_back_messages)

        return len(self.processed_messages)

    _delete_messages_from_queue()
def _backoff_handler(details):
    """Backoff logging handler for when polling occurs.

    Args:
        details (dict): Backoff context containing the number of tries,
            target function currently executing, kwargs, args, value,
            and wait time.
    """
    LOGGER.debug('[Backoff]: Trying again in %f seconds after %d tries calling %s',
                 details['wait'], details['tries'], details['target'].__name__)
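These handlers are passed to the `backoff` decorators used by the SQS and Athena clients below. A minimal sketch of that wiring, assuming `_backoff_handler` and `_giveup_handler` from above are in scope; the `_poll_stub` function and its counter are hypothetical, added only to make the example runnable:

import backoff

# Hypothetical poller: reports 0 messages for the first two calls, then 3
_poll_counter = {'calls': 0}

def _poll_stub():
    _poll_counter['calls'] += 1
    return 0 if _poll_counter['calls'] < 3 else 3

@backoff.on_predicate(backoff.fibo,
                      lambda num_messages: num_messages == 0,  # retry while nothing was received
                      max_value=10,
                      max_tries=5,
                      jitter=backoff.full_jitter,
                      on_backoff=_backoff_handler,  # logs the wait time and try count on each retry
                      on_giveup=_giveup_handler)    # logs when the retries are exhausted
def _poll_until_messages():
    return _poll_stub()

_poll_until_messages()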
def check_table_exists(self, table_name):
    """Verify a given StreamAlert Athena table exists."""
    query_success, query_resp = self.run_athena_query(
        query='SHOW TABLES LIKE \'{}\';'.format(table_name),
        database=self.DATABASE_STREAMALERT)

    if query_success and query_resp['ResultSet']['Rows']:
        return True

    LOGGER.info('The streamalert table \'%s\' does not exist.', table_name)
    LOGGER.info('For help with creating tables: '
                '$ python manage.py athena create-table --help')
    return False
def run(self):
    """Poll the SQS queue for messages and create partitions for new data"""
    # Check that the database being used exists before running queries
    if not self._athena_client.check_database_exists():
        raise AthenaRefreshError(
            'The \'{}\' database does not exist'.format(self._athena_client.database))

    # Get the first batch of messages from SQS. If there are no
    # messages, this will exit early.
    self._sqs_client.get_messages(max_tries=2)

    if not self._sqs_client.received_messages:
        LOGGER.info('No SQS messages received, exiting')
        return

    # If the max amount of messages was initially returned,
    # then get the next batch of messages. The max is determined based
    # on (number of tries) * (number of possible max messages returned)
    if len(self._sqs_client.received_messages) == 20:
        self._sqs_client.get_messages(max_tries=8)

    s3_buckets_and_keys = self._sqs_client.unique_s3_buckets_and_keys()
    if not s3_buckets_and_keys:
        LOGGER.error('No new Athena partitions to add, exiting')
        return

    if not self._add_partition(s3_buckets_and_keys):
        LOGGER.error('Failed to add hive partition(s)')
        return

    self._sqs_client.delete_messages()
    LOGGER.info('Deleted %d messages from SQS',
                self._sqs_client.deleted_message_count)
def check_table_exists(self, table_name):
    """Verify a given StreamAlert Athena table exists."""
    query_success, query_resp = self.run_athena_query(
        query='SHOW TABLES LIKE \'{}\';'.format(table_name),
        database=self.DATABASE_STREAMALERT)

    if query_success and query_resp['ResultSet']['Rows']:
        return True

    LOGGER.info('The streamalert table \'%s\' does not exist. '
                'For alert buckets, create it with the following command: \n'
                '$ python manage.py athena create-table '
                '--type alerts --bucket s3.bucket.id', table_name)
    return False
def delete_messages(self):
    """Delete messages off the queue once processed"""
    if not self.processed_messages:
        LOGGER.error('No processed messages to delete')
        return

    @backoff.on_predicate(backoff.fibo,
                          lambda len_messages: len_messages > 0,
                          max_value=10,
                          jitter=backoff.full_jitter,
                          on_backoff=_backoff_handler,
                          on_success=_success_handler)
    def _delete_messages_from_queue():
        # Determine the message batch for SQS message deletion
        len_processed_messages = len(self.processed_messages)
        batch = len_processed_messages if len_processed_messages < 10 else 10
        message_batch = [self.processed_messages.pop() for _ in range(batch)]

        # Try to delete the batch
        resp = self.sqs_client.delete_message_batch(
            QueueUrl=self.athena_sqs_url,
            Entries=[{'Id': message['MessageId'],
                      'ReceiptHandle': message['ReceiptHandle']}
                     for message in message_batch])

        # Handle successful deletions; 'Successful' is absent if every delete failed
        self.deleted_messages += len(resp.get('Successful', []))

        # Handle failed deletions
        if resp.get('Failed'):
            LOGGER.error('Failed to delete the following (%d) messages:\n%s',
                         len(resp['Failed']), json.dumps(resp['Failed']))

            # Add the failed messages back to the processed_messages attribute
            failed_from_batch = [message for failed_message in resp['Failed']
                                 for message in message_batch
                                 if message['MessageId'] == failed_message['Id']]
            self.processed_messages.extend(failed_from_batch)

        return len(self.processed_messages)

    _delete_messages_from_queue()
def run_athena_query(self, **kwargs):
    """Helper function to run Athena queries

    Keyword Args:
        query (str): The SQL query to execute
        database (str): The database context to execute the query in
        async (bool): If True, run the query asynchronously and return
            without backing off until completion.

    Returns:
        bool, dict: query success, query result response
    """
    LOGGER.debug('Executing query: %s', kwargs['query'])
    query_execution_resp = self.athena_client.start_query_execution(
        QueryString=kwargs['query'],
        QueryExecutionContext={
            'Database': kwargs.get('database', self.DATABASE_DEFAULT)
        },
        ResultConfiguration={
            'OutputLocation': '{}/{}'.format(self.athena_results_bucket,
                                             self.athena_results_key)
        })

    # If asynchronous invocation is enabled and a valid query
    # execution ID was returned, return the response immediately.
    if kwargs.get('async') and query_execution_resp.get('QueryExecutionId'):
        return True, query_execution_resp

    execution_id = query_execution_resp['QueryExecutionId']
    query_execution_result = self.check_query_status(execution_id)

    state = query_execution_result['QueryExecution']['Status']['State']
    if state != 'SUCCEEDED':
        reason = query_execution_result['QueryExecution']['Status']['StateChangeReason']
        LOGGER.error('Query %s %s with reason %s, exiting!', execution_id, state, reason)
        LOGGER.error('Full query:\n%s', kwargs['query'])
        return False, {}

    query_results_resp = self.athena_client.get_query_results(
        QueryExecutionId=execution_id)

    # The idea here is to leave the processing logic to the calling functions.
    # No data being returned isn't always an indication that something is wrong.
    # When handling the query result data, iterate over each element in the Row,
    # and parse the Data key.
    # Reference: https://bit.ly/2tWOQ2N
    if not query_results_resp['ResultSet']['Rows']:
        LOGGER.debug('The query %s returned empty rows of data', kwargs['query'])

    return True, query_results_resp
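As the comment above notes, callers are expected to walk `ResultSet.Rows` and read each cell under the `Data` key. A small, self-contained sketch of that processing, using a hand-made response in the shape returned by the boto3 Athena GetQueryResults API (the table names are hypothetical):

# Sketch: consuming a (success, response) pair from run_athena_query
success, response = True, {
    'ResultSet': {
        'Rows': [
            {'Data': [{'VarCharValue': 'alerts'}]},
            {'Data': [{'VarCharValue': 'cloudtrail_data'}]},
        ]
    }
}

if success:
    for row in response['ResultSet']['Rows']:
        # Each row holds a list of cells under 'Data'; string cells are
        # exposed via the 'VarCharValue' key
        values = [cell.get('VarCharValue') for cell in row['Data']]
        print(values)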
def get_messages(self, **kwargs):
    """Poll the SQS queue for new messages

    Keyword Args:
        max_tries (int): The number of times to backoff
        max_value (int): The max wait interval between backoffs
        max_messages (int): The max number of messages to get from SQS
    """
    start_message_count = len(self.received_messages)

    # Backoff up to 5 times to limit the time spent in this operation
    # relative to the entire Lambda duration.
    max_tries = kwargs.get('max_tries', 5)

    # This value restricts the max time of backoff each try.
    # This means the total backoff time for one function call is:
    #   max_tries (attempts) * max_value (seconds)
    max_value = kwargs.get('max_value', 5)

    # Number of messages to poll from the stream.
    max_messages = kwargs.get('max_messages', self.MAX_SQS_GET_MESSAGE_COUNT)
    if max_messages > self.MAX_SQS_GET_MESSAGE_COUNT:
        LOGGER.error('SQS can only request up to 10 messages in one request')
        return

    @backoff.on_predicate(backoff.fibo,
                          max_tries=max_tries,
                          max_value=max_value,
                          jitter=backoff.full_jitter,
                          on_backoff=backoff_handler(),
                          on_success=success_handler(True),
                          on_giveup=giveup_handler(True))
    def _receive_messages():
        polled_messages = self.sqs_client.receive_message(
            QueueUrl=self.athena_sqs_url,
            MaxNumberOfMessages=max_messages)

        if 'Messages' not in polled_messages:
            return False

        self.received_messages.extend(polled_messages['Messages'])

    _receive_messages()

    batch_count = len(self.received_messages) - start_message_count
    LOGGER.info('Received %d message(s) from SQS', batch_count)
def check_database_exists(self, **kwargs):
    """Verify the StreamAlert Athena database exists.

    Keyword Args:
        database (str): The database name to execute the query under
    """
    database = kwargs.get('database', self.sa_database)
    query_success, query_resp = self.run_athena_query(
        query='SHOW DATABASES LIKE \'{}\';'.format(database))

    if query_success and query_resp['ResultSet']['Rows']:
        return True

    LOGGER.error('The \'%s\' database does not exist. '
                 'Create it with the following command: \n'
                 '$ python manage.py athena create-db', database)
    return False
def get_messages(self, max_tries=5, max_value=5, max_messages=MAX_SQS_GET_MESSAGE_COUNT):
    """Poll the SQS queue for new messages

    Keyword Args:
        max_tries (int): The number of times to backoff
            Backoff up to 5 times to limit the time spent in this operation
            relative to the entire Lambda duration.
        max_value (int): The max wait interval between backoffs
            This value restricts the max time of backoff each try.
            This means the total backoff time for one function call is:
                max_tries (attempts) * max_value (seconds)
        max_messages (int): The max number of messages to get from SQS
    """
    start_message_count = len(self.received_messages)

    # Number of messages to poll from the stream.
    if max_messages > self.MAX_SQS_GET_MESSAGE_COUNT:
        LOGGER.error('The requested number of messages exceeds the SQS per-request '
                     'limit. Setting max messages to %d',
                     self.MAX_SQS_GET_MESSAGE_COUNT)
        max_messages = self.MAX_SQS_GET_MESSAGE_COUNT

    @backoff.on_predicate(backoff.fibo,
                          max_tries=max_tries,
                          max_value=max_value,
                          jitter=backoff.full_jitter,
                          on_backoff=backoff_handler(),
                          on_success=success_handler(True),
                          on_giveup=giveup_handler(True))
    def _receive_messages():
        polled_messages = self.sqs_client.receive_message(
            QueueUrl=self.athena_sqs_url,
            MaxNumberOfMessages=max_messages)

        if 'Messages' not in polled_messages:
            return True  # return True to stop polling

        self.received_messages.extend(polled_messages['Messages'])

    _receive_messages()

    batch_count = len(self.received_messages) - start_message_count
    LOGGER.info('Received %d message(s) from SQS', batch_count)
def repair_hive_table(self, unique_buckets):
    """Execute a MSCK REPAIR TABLE on a given Athena table

    Args:
        unique_buckets (list): S3 buckets to repair

    Returns:
        (bool): If the repair was successful or not
    """
    athena_config = self.config['lambda']['athena_partition_refresh_config']
    repair_hive_table_config = athena_config['refresh_type']['repair_hive_table']

    LOGGER.info('Processing Hive repair table...')
    for data_bucket in unique_buckets:
        athena_table = repair_hive_table_config.get(data_bucket)
        if not athena_table:
            LOGGER.warning('%s not found in repair_hive_table config. '
                           'Please update your configuration accordingly.',
                           data_bucket)
            continue

        query_success, query_resp = self.run_athena_query(
            query='MSCK REPAIR TABLE {};'.format(athena_table),
            database=self.DATABASE_STREAMALERT)

        if query_success:
            LOGGER.info('Query results:')
            for row in query_resp['ResultSet']['Rows']:
                LOGGER.info(row['Data'])
        else:
            LOGGER.error('Partition refresh of the Athena table %s has failed.',
                         athena_table)
            return False

    return True
def run(self, event):
    """Take the messages from the SQS queue and create partitions for new data in S3

    Args:
        event (dict): Lambda input event containing SQS messages. Each SQS message
            should contain one or more S3 bucket notification messages.
    """
    # Check that the database being used exists before running queries
    if not self._athena_client.check_database_exists():
        raise AthenaRefreshError(
            'The \'{}\' database does not exist'.format(self._athena_client.database))

    for sqs_rec in event['Records']:
        LOGGER.debug('Processing event with message ID \'%s\' and SentTimestamp %s',
                     sqs_rec['messageId'],
                     sqs_rec['attributes']['SentTimestamp'])

        body = json.loads(sqs_rec['body'])
        if body.get('Event') == 's3:TestEvent':
            LOGGER.debug('Skipping S3 bucket notification test event')
            continue

        for s3_rec in body['Records']:
            if 's3' not in s3_rec:
                LOGGER.info('Skipping non-s3 bucket notification message: %s', s3_rec)
                continue

            bucket_name = s3_rec['s3']['bucket']['name']

            # Account for special characters in the S3 object key
            # Example: Usage of '=' in the key name
            object_key = urllib.unquote_plus(s3_rec['s3']['object']['key']).decode('utf8')

            LOGGER.debug('Received notification for object \'%s\' in bucket \'%s\'',
                         object_key, bucket_name)

            self._s3_buckets_and_keys[bucket_name].add(object_key)

    if not self._add_partitions():
        raise AthenaRefreshError(
            'Failed to add partitions: {}'.format(dict(self._s3_buckets_and_keys)))
def _get_partitions_from_keys(self, s3_buckets_and_keys):
    """Map each configured Athena table to the partitions and S3 locations to add"""
    partitions = defaultdict(dict)

    LOGGER.info('Processing new Hive partitions...')
    for bucket, keys in s3_buckets_and_keys.iteritems():
        athena_table = self._athena_buckets.get(bucket)
        if not athena_table:
            # TODO(jacknagz): Add this as a metric
            LOGGER.error('%s not found in \'buckets\' config. Please add this '
                         'bucket to enable additions of Hive partitions.',
                         bucket)
            continue

        # Iterate over each key
        for key in keys:
            match = None
            for pattern in (self.FIREHOSE_REGEX, self.STREAMALERTS_REGEX):
                match = pattern.search(key)
                if match:
                    break

            if not match:
                LOGGER.error('The key %s does not match any regex, skipping', key)
                continue

            # Get the path to the objects in S3
            path = posixpath.dirname(key)

            # The config does not need to store all possible tables
            # for enabled log types because this can be inferred from
            # the incoming S3 bucket notification. Only enabled
            # log types will be sending data to Firehose.
            # This logic extracts out the name of the table from the
            # first element in the S3 path, as that's how log types
            # are configured to send to Firehose.
            if athena_table != 'alerts':
                athena_table = path.split('/')[0]

            # Example:
            # PARTITION (dt = '2017-01-01-01') LOCATION 's3://bucket/path/'
            partition = '(dt = \'{year}-{month}-{day}-{hour}\')'.format(**match.groupdict())
            location = '\'s3://{bucket}/{path}\''.format(bucket=bucket, path=path)

            # By using the partition as the dict key, this ensures that
            # Athena will not try to add the same partition twice.
            # TODO(jacknagz): Write this dictionary to SSM/DynamoDb
            # to increase idempotence of this Lambda function
            partitions[athena_table][partition] = location

    return partitions
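For reference, a small, self-contained sketch of the partition and location strings this method builds. The sample bucket, key, and date parts below are hypothetical; in the real code they come from the FIREHOSE_REGEX/STREAMALERTS_REGEX match groups, which are not shown here:

import posixpath

# Hypothetical S3 bucket and object key, for illustration only
bucket = 'example-streamalert-data'
key = 'cloudtrail_data/2017/01/01/01/example-object.json'

# Stand-in for the regex match groups (year/month/day/hour)
match_dict = {'year': '2017', 'month': '01', 'day': '01', 'hour': '01'}

path = posixpath.dirname(key)
partition = '(dt = \'{year}-{month}-{day}-{hour}\')'.format(**match_dict)
location = '\'s3://{bucket}/{path}\''.format(bucket=bucket, path=path)

# Prints:
# PARTITION (dt = '2017-01-01-01') LOCATION 's3://example-streamalert-data/cloudtrail_data/2017/01/01/01'
print('PARTITION {0} LOCATION {1}'.format(partition, location))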
def handler(*_):
    """Athena Partition Refresher Handler Function"""
    config = _load_config()

    # Initialize the SQS client and receive messages
    stream_alert_sqs = StreamAlertSQSClient(config)

    # Get the first batch of messages from SQS. If there are no
    # messages, this will exit early.
    stream_alert_sqs.get_messages(max_tries=2)

    if not stream_alert_sqs.received_messages:
        LOGGER.info('No SQS messages received, exiting')
        return

    # If the max amount of messages was initially returned,
    # then get the next batch of messages. The max is determined based
    # on (number of tries) * (number of possible max messages returned)
    if len(stream_alert_sqs.received_messages) == 20:
        stream_alert_sqs.get_messages(max_tries=8)

    s3_buckets_and_keys = stream_alert_sqs.unique_s3_buckets_and_keys()
    if not s3_buckets_and_keys:
        LOGGER.error('No new Athena partitions to add, exiting')
        return

    # Initialize the Athena client and run queries
    stream_alert_athena = StreamAlertAthenaClient(config)

    # Check that the 'streamalert' database exists before running queries
    if not stream_alert_athena.check_database_exists():
        raise AthenaPartitionRefreshError(
            'The \'{}\' database does not exist'.format(stream_alert_athena.sa_database))

    if not stream_alert_athena.add_partition(s3_buckets_and_keys):
        LOGGER.error('Failed to add hive partition(s)')
        return

    stream_alert_sqs.delete_messages()
    LOGGER.info('Deleted %d messages from SQS', stream_alert_sqs.deleted_messages)
def unique_s3_buckets_and_keys(self):
    """Filter unique S3 buckets and S3 keys from event notifications

    Returns:
        (dict): Keys of bucket names, and values of unique S3 keys
    """
    s3_buckets_and_keys = defaultdict(set)

    if not self.received_messages:
        LOGGER.error('No messages to filter, fetch the messages with get_messages()')
        return

    for message in self.received_messages:
        if 'Body' not in message:
            LOGGER.error('Missing \'Body\' key in SQS message, skipping')
            continue

        loaded_message = json.loads(message['Body'])

        # From AWS documentation: http://amzn.to/2w4fcSq
        # When you configure an event notification on a bucket,
        # Amazon S3 sends the following test message:
        # {
        #    "Service":"Amazon S3",
        #    "Event":"s3:TestEvent",
        #    "Time":"2014-10-13T15:57:02.089Z",
        #    "Bucket":"bucketname",
        #    "RequestId":"5582815E1AEA5ADF",
        #    "HostId":"8cLeGAmw098X5cv4Zkwcmo8vvZa3eH3eKxsPzbB9wrR+YstdA6Knx4Ip8EXAMPLE"
        # }
        if loaded_message.get('Event') == 's3:TestEvent':
            LOGGER.debug('Skipping S3 bucket notification test event')
            continue

        if 'Records' not in loaded_message:
            LOGGER.error('Missing \'Records\' key in SQS message, skipping:\n%s',
                         json.dumps(loaded_message, indent=4))
            continue

        for record in loaded_message['Records']:
            if 's3' not in record:
                LOGGER.info('Skipping non-s3 bucket notification message')
                LOGGER.debug(record)
                continue

            bucket_name = record['s3']['bucket']['name']

            # Account for special characters in the S3 object key
            # Example: Usage of '=' in the key name
            object_key = urllib.unquote(record['s3']['object']['key']).decode('utf8')

            s3_buckets_and_keys[bucket_name].add(object_key)

        # Add to a new list to track successfully processed messages from the queue
        self.processed_messages.append(message)

    return s3_buckets_and_keys
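For context, a self-contained sketch of the notification body shape this filter parses, modeled on the AWS S3 event notification format and mirroring the Python 2 urllib.unquote call used above; the bucket and key values are hypothetical:

import json
import urllib
from collections import defaultdict

# Hypothetical S3 event notification body, trimmed to the fields used above
message_body = json.dumps({
    'Records': [{
        's3': {
            'bucket': {'name': 'example-streamalert-data'},
            'object': {'key': 'cloudtrail_data/2017/01/01/01/example%3Dobject.json'}
        }
    }]
})

s3_buckets_and_keys = defaultdict(set)
for record in json.loads(message_body)['Records']:
    bucket_name = record['s3']['bucket']['name']
    # Undo the URL-encoding applied to object keys in notifications ('%3D' -> '=')
    object_key = urllib.unquote(record['s3']['object']['key'])
    s3_buckets_and_keys[bucket_name].add(object_key)

# e.g. {'example-streamalert-data': set(['cloudtrail_data/2017/01/01/01/example=object.json'])}
print(dict(s3_buckets_and_keys))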
def add_hive_partition(self, s3_buckets_and_keys):
    """Execute a Hive Add Partition command on a given Athena table

    Args:
        s3_buckets_and_keys (dict): Buckets and unique keys to add partitions

    Returns:
        (bool): Whether or not the partitions were added successfully
    """
    athena_config = self.config['lambda']['athena_partition_refresh_config']
    add_hive_partition_config = athena_config['refresh_type']['add_hive_partition']
    partitions = {}

    LOGGER.info('Processing new Hive partitions...')
    for bucket, keys in s3_buckets_and_keys.iteritems():
        athena_table = add_hive_partition_config.get(bucket)
        if not athena_table:
            LOGGER.error('%s not found in \'add_hive_partition\' config. '
                         'Please add this bucket to enable additions '
                         'of Hive partitions.', bucket)
            continue

        # Gather all of the partitions to add per bucket
        s3_key_regex = self.STREAMALERTS_REGEX if athena_table == 'alerts' \
            else self.FIREHOSE_REGEX

        # Iterate over each key
        for key in keys:
            match = s3_key_regex.search(key)
            if not match:
                LOGGER.error('The key %s does not match the regex %s, skipping',
                             key, s3_key_regex.pattern)
                continue

            # Convert the match groups to a dict for easy access
            match_dict = match.groupdict()

            # Get the path to the objects in S3
            path = os.path.dirname(key)

            # PARTITION (dt = '2017-01-01-01') LOCATION 's3://bucket/path/'
            partition = '(dt = \'{year}-{month}-{day}-{hour}\')'.format(
                year=match_dict['year'],
                month=match_dict['month'],
                day=match_dict['day'],
                hour=match_dict['hour'])
            location = '\'s3://{bucket}/{path}\''.format(bucket=bucket, path=path)

            # By using the partition as the dict key, this ensures that
            # Athena will not try to add the same partition twice.
            partitions[partition] = location

    if not partitions:
        LOGGER.error('No partitions to add')
        return False

    partition_statement = ' '.join(
        ['PARTITION {0} LOCATION {1}'.format(partition, location)
         for partition, location in partitions.iteritems()])
    query = ('ALTER TABLE {athena_table} '
             'ADD IF NOT EXISTS {partition_statement};'.format(
                 athena_table=athena_table,
                 partition_statement=partition_statement))

    query_success, _ = self.run_athena_query(
        query=query,
        database=self.DATABASE_STREAMALERT)

    if not query_success:
        LOGGER.error('The add hive partition query has failed:\n%s', query)
        return False

    LOGGER.info('Successfully added the following partitions:\n%s',
                '\n'.join(partitions))
    return True
def add_partition(self, s3_buckets_and_keys):
    """Execute a Hive Add Partition command on a given Athena table

    Args:
        s3_buckets_and_keys (dict): Buckets and unique keys to add partitions

    Returns:
        (bool): Whether or not the partitions were added successfully
    """
    athena_config = self.config['lambda']['athena_partition_refresh_config']
    athena_buckets = athena_config['buckets']
    partitions = defaultdict(dict)

    LOGGER.info('Processing new Hive partitions...')
    for bucket, keys in s3_buckets_and_keys.iteritems():
        athena_table = athena_buckets.get(bucket)
        if not athena_table:
            # TODO(jacknagz): Add this as a metric
            LOGGER.error('%s not found in \'buckets\' config. '
                         'Please add this bucket to enable additions '
                         'of Hive partitions.', bucket)
            continue

        # Iterate over each key
        for key in keys:
            for pattern in (self.FIREHOSE_REGEX, self.STREAMALERTS_REGEX):
                match = pattern.search(key)
                if match:
                    break

            if not match:
                LOGGER.error('The key %s does not match any regex, skipping', key)
                continue

            # Convert the match groups to a dict for easy access
            match_dict = match.groupdict()

            # Get the path to the objects in S3
            path = os.path.dirname(key)

            # The config does not need to store all possible tables
            # for enabled log types because this can be inferred from
            # the incoming S3 bucket notification. Only enabled
            # log types will be sending data to Firehose.
            # This logic extracts out the name of the table from the
            # first element in the S3 path, as that's how log types
            # are configured to send to Firehose.
            if athena_table != 'alerts':
                athena_table = path.split('/')[0]

            # Example:
            # PARTITION (dt = '2017-01-01-01') LOCATION 's3://bucket/path/'
            partition = '(dt = \'{year}-{month}-{day}-{hour}\')'.format(
                year=match_dict['year'],
                month=match_dict['month'],
                day=match_dict['day'],
                hour=match_dict['hour'])
            location = '\'s3://{bucket}/{path}\''.format(bucket=bucket, path=path)

            # By using the partition as the dict key, this ensures that
            # Athena will not try to add the same partition twice.
            # TODO(jacknagz): Write this dictionary to SSM/DynamoDb
            # to increase idempotence of this Lambda function
            partitions[athena_table][partition] = location

    if not partitions:
        LOGGER.error('No partitions to add')
        return False

    for athena_table in partitions:
        partition_statement = ' '.join([
            'PARTITION {0} LOCATION {1}'.format(partition, location)
            for partition, location in partitions[athena_table].iteritems()
        ])
        query = ('ALTER TABLE {athena_table} '
                 'ADD IF NOT EXISTS {partition_statement};'.format(
                     athena_table=athena_table,
                     partition_statement=partition_statement))

        query_success, _ = self.run_athena_query(query=query,
                                                 database=self.sa_database)
        if not query_success:
            raise AthenaPartitionRefreshError(
                'The add hive partition query has failed:\n{}'.format(query))

        LOGGER.info('Successfully added the following partitions:\n%s',
                    json.dumps({athena_table: partitions[athena_table]}, indent=4))

    return True