Example #1
    def _send_batch(self, stream_name, record_batch):
        """Send record batches to Firehose

        Args:
            stream_name (str): The name of the Delivery Stream to send to
            record_batch (list): The records to send
        """
        @backoff.on_predicate(backoff.fibo,
                              lambda resp: resp['FailedPutCount'] > 0,
                              max_tries=self.MAX_BACKOFF_ATTEMPTS,
                              max_value=self.MAX_BACKOFF_FIBO_VALUE,
                              on_backoff=backoff_handler(debug_only=False),
                              on_success=success_handler(),
                              on_giveup=giveup_handler())
        @backoff.on_exception(backoff.fibo,
                              self.EXCEPTIONS_TO_BACKOFF,
                              max_tries=self.MAX_BACKOFF_ATTEMPTS,
                              on_backoff=backoff_handler(debug_only=False),
                              on_success=success_handler(),
                              on_giveup=giveup_handler())
        def _firehose_request_helper(data):
            """Firehose request wrapper to use with backoff"""
            # Use the current length of data here so we can track failed records that are retried
            LOGGER.debug('Sending %d records to firehose %s', len(data), stream_name)

            response = self._client.put_record_batch(DeliveryStreamName=stream_name, Records=data)

            # Log this as a warning for now so it can be picked up in logs
            if response['FailedPutCount'] > 0:
                LOGGER.warning('Received non-zero FailedPutCount: %d', response['FailedPutCount'])
                # Strip out the successful records so only the failed ones are retried. This mutates
                # the list of dictionaries in place, so the retried call sees the updated list
                self._strip_successful_records(data, response)

            return response

        # The record here already contains a newline, so do not append one
        records_data = [
            {'Data': record}
            for record in record_batch
        ]

        # The try/except here is to catch the raised error at the end of the backoff
        try:
            return _firehose_request_helper(records_data)
        except self.EXCEPTIONS_TO_BACKOFF:
            LOGGER.exception('Firehose request failed')
            # Use the current length of the records_data in case some records were
            # successful but others were not
            self._log_failed(len(records_data))
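
Several of these Firehose senders call a _strip_successful_records helper that is not shown above. A minimal sketch of what it might look like for the Firehose case, assuming the standard PutRecordBatch response shape (one entry in RequestResponses per submitted record, with ErrorCode set only on failures):

def _strip_successful_records(records, response):
    """Remove records that Firehose accepted, leaving only the failures

    The list is mutated in place, so the backoff-decorated request helper
    retries with only the records that still need to be sent.
    """
    failed_indexes = [
        idx for idx, result in enumerate(response['RequestResponses'])
        if result.get('ErrorCode')  # only failed records carry an ErrorCode
    ]

    # Rebuild the list in place with just the failed records
    records[:] = [records[idx] for idx in failed_indexes]
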
Example #2
    def check_query_status(self, query_execution_id):
        """Check in on the running query, back off if the job is running or queued

        Args:
            query_execution_id (str): Athena query execution ID

        Returns:
            bool: True if the query state is SUCCEEDED, False otherwise
                Reference https://bit.ly/2uuRtda.
        """
        states_to_backoff = {'QUEUED', 'RUNNING'}
        @backoff.on_predicate(backoff.fibo,
                              lambda resp: \
                              resp['QueryExecution']['Status']['State'] in states_to_backoff,
                              max_value=10,
                              jitter=backoff.full_jitter,
                              on_backoff=backoff_handler(),
                              on_success=success_handler(True))
        def _check_status(query_execution_id):
            return self._client.get_query_execution(
                QueryExecutionId=query_execution_id)

        execution_result = _check_status(query_execution_id)
        state = execution_result['QueryExecution']['Status']['State']
        if state != 'SUCCEEDED':
            reason = execution_result['QueryExecution']['Status'][
                'StateChangeReason']
            LOGGER.error('Query %s %s with reason %s, exiting!',
                         query_execution_id, state, reason)
            return False

        return True
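
Every example here passes backoff_handler(), success_handler(), and giveup_handler() as backoff callbacks without showing their definitions. A hedged sketch of the general pattern such factories tend to follow, returning a callable that receives the backoff details dict (the exact log wording is an assumption):

import logging

LOGGER = logging.getLogger(__name__)

def backoff_handler(debug_only=True):
    """Return an on_backoff callback that logs each retry"""
    def _handler(details):
        message = ('Backing off %(wait)0.1f seconds after %(tries)d tries '
                   'calling %(target)s') % details
        if debug_only:
            LOGGER.debug(message)
        else:
            LOGGER.info(message)
    return _handler

def success_handler(debug_only=False):
    """Return an on_success callback that logs the completed call"""
    def _handler(details):
        log = LOGGER.debug if debug_only else LOGGER.info
        log('Completed %(target)s after %(tries)d tries', details)
    return _handler

def giveup_handler(debug_only=False):
    """Return an on_giveup callback that logs the abandoned call"""
    def _handler(details):
        log = LOGGER.debug if debug_only else LOGGER.error
        log('Giving up calling %(target)s after %(tries)d tries', details)
    return _handler
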
Example #3
    def _send_message(self, records):
        """Send a single message with a blob of records to CSIRT SQS
        Args:
            records (list): Pre-serialized JSON records that are joined into a single
                message body to be sent to the Rules Engine function

        Returns:
            string|bool: The MessageId if the request was successful, False otherwise
        """
        @backoff.on_exception(backoff.expo,
                              self.EXCEPTIONS_TO_BACKOFF,
                              max_tries=self.MAX_BACKOFF_ATTEMPTS,
                              on_backoff=backoff_handler(debug_only=False),
                              on_success=success_handler(),
                              on_giveup=giveup_handler())
        def _send_message_helper(request):
            """Inner helper function for sending a single message with backoff

            Args:
                request (dict): Keyword arguments for send_message, including the
                    MessageBody to send
            """
            return self.queue.send_message(**request)

        # Prepare the request now to save time during retries
        request = {'MessageBody': '[{}]'.format(','.join(records))}

        # The try/except here is to catch any raised errors at the end of the backoff
        try:
            response = _send_message_helper(request)
            return response['MessageId']
        except self.EXCEPTIONS_TO_BACKOFF:
            LOGGER.exception('SQS request failed')
            return False
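
The class constants referenced above (EXCEPTIONS_TO_BACKOFF, MAX_BACKOFF_ATTEMPTS) are defined elsewhere on the owning class. A plausible sketch, with the class name hypothetical and the exact values assumed:

from botocore.exceptions import ClientError

class SQSClient(object):  # hypothetical stand-in for the class that owns _send_message
    # Exceptions that warrant a retry with backoff rather than an immediate failure
    EXCEPTIONS_TO_BACKOFF = (ClientError,)

    # Cap the retries so a persistently failing queue cannot consume the Lambda duration
    MAX_BACKOFF_ATTEMPTS = 5
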
Example #4
    def _dispatch(self, alert, descriptor):
        """Send alert to a Kinesis Firehose Delivery Stream

        Publishing:
            By default this output sends the current publication in JSON to Kinesis.
            There is no "magic" field to "override" it: Simply publish what you want to send!

        Args:
            alert (Alert): Alert instance which triggered a rule
            descriptor (str): Output descriptor

        Returns:
            bool: True if alert was sent successfully, False otherwise
        """
        @backoff.on_exception(backoff.fibo,
                              ClientError,
                              max_tries=self.MAX_BACKOFF_ATTEMPTS,
                              jitter=backoff.full_jitter,
                              on_backoff=backoff_handler(),
                              on_success=success_handler(),
                              on_giveup=giveup_handler())
        def _firehose_request_wrapper(json_alert, delivery_stream):
            """Make the PutRecord request to Kinesis Firehose with backoff

            Args:
                json_alert (str): The JSON dumped alert body
                delivery_stream (str): The Firehose Delivery Stream to send to

            Returns:
                dict: Firehose response in the format below
                    {'RecordId': 'string'}
            """
            return self.__aws_client__.put_record(
                DeliveryStreamName=delivery_stream,
                Record={'Data': json_alert}
            )

        if self.__aws_client__ is None:
            self.__aws_client__ = boto3.client('firehose', region_name=self.region)

        publication = compose_alert(alert, self, descriptor)

        json_alert = json.dumps(publication, separators=(',', ':')) + '\n'
        if len(json_alert) > self.MAX_RECORD_SIZE:
            LOGGER.error('Alert too large to send to Firehose: \n%s...', json_alert[0:1000])
            return False

        delivery_stream = self.config[self.__service__][descriptor]
        LOGGER.info('Sending %s to aws-firehose:%s', alert, delivery_stream)

        _firehose_request_wrapper(json_alert, delivery_stream)
        LOGGER.info('%s successfully sent to aws-firehose:%s', alert, delivery_stream)

        return True
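
The MAX_RECORD_SIZE check above guards the Firehose PutRecord limit of 1,000 KiB per record. A sketch of the constants this output presumably defines; the exact values are assumptions:

class KinesisFirehoseOutput(object):  # abbreviated; the real class extends OutputDispatcher
    __service__ = 'aws-firehose'

    # Firehose rejects a single record larger than 1,000 KiB; staying a couple of
    # bytes under that limit leaves room for the trailing newline appended above
    MAX_RECORD_SIZE = 1000 * 1000 - 2

    # Number of fibonacci backoff attempts before giving up on a ClientError
    MAX_BACKOFF_ATTEMPTS = 3
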
Example #5
    def delete_messages(self):
        """Delete messages off the queue once processed"""
        if not self.processed_messages:
            LOGGER.error('No processed messages to delete')
            return

        @backoff.on_predicate(backoff.fibo,
                              lambda len_messages: len_messages > 0,
                              max_value=10,
                              max_tries=self.SQS_BACKOFF_MAX_RETRIES,
                              jitter=backoff.full_jitter,
                              on_backoff=backoff_handler(),
                              on_success=success_handler())
        def _delete_messages_from_queue():
            # Determine the message batch for SQS message deletion
            len_processed_messages = len(self.processed_messages)
            batch = len_processed_messages if len_processed_messages < 10 else 10
            # Pop processed records from the list to be deleted
            message_batch = [
                self.processed_messages.pop() for _ in range(batch)
            ]

            # Try to delete the batch
            resp = self.sqs_client.delete_message_batch(
                QueueUrl=self.athena_sqs_url,
                Entries=[{
                    'Id': message['MessageId'],
                    'ReceiptHandle': message['ReceiptHandle']
                } for message in message_batch])

            # Handle successful deletions
            if resp.get('Successful'):
                self.deleted_messages += len(resp['Successful'])
            # Handle failed deletions
            if resp.get('Failed'):
                LOGGER.error(
                    ('Failed to delete the messages with the following (%d) '
                     'error messages:\n%s'), len(resp['Failed']),
                    json.dumps(resp['Failed']))
                # Add the failed messages back to the processed_messages attribute
                # to be retried via backoff
                failed_message_ids = [
                    message['Id'] for message in resp['Failed']
                ]
                push_back_messages = [
                    message for message in message_batch
                    if message['MessageId'] in failed_message_ids
                ]

                self.processed_messages.extend(push_back_messages)

            return len(self.processed_messages)

        _delete_messages_from_queue()
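
The deletion loop above relies on backoff.on_predicate semantics: the decorated function is retried (with a backoff sleep in between) whenever the predicate returns True for its return value, and it stops as soon as the predicate is falsy or the retry budget runs out. A minimal standalone illustration of that behavior:

import backoff

work_queue = ['a', 'b', 'c', 'd', 'e']

@backoff.on_predicate(backoff.fibo,
                      lambda remaining: remaining > 0,  # retry while items remain
                      max_value=1,
                      max_tries=10)
def drain_queue():
    # Process up to two items per attempt, mirroring the batched deletion above
    for _ in range(min(2, len(work_queue))):
        work_queue.pop()
    return len(work_queue)  # a value greater than zero triggers another attempt

drain_queue()
assert not work_queue  # everything was processed across the retried attempts
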
Example #6
    def real_decorator(func):
        """Actual decorator to retry on exceptions"""
        @backoff.on_exception(backoff.expo,
                              exceptions,  # This is a tuple with exceptions
                              max_tries=OutputDispatcher.MAX_RETRY_ATTEMPTS,
                              jitter=backoff.full_jitter,
                              on_backoff=backoff_handler(),
                              on_success=success_handler(),
                              on_giveup=giveup_handler())
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)
        return wrapper
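
Example #6 is only the inner half of a decorator factory; the exceptions tuple is closed over from an enclosing function that is not shown. A hedged sketch of how the full factory and its usage might look, reusing the names from the snippet above; retry_on_exceptions and _post_request are hypothetical names:

def retry_on_exceptions(exceptions):
    """Hypothetical outer factory that supplies the exceptions tuple to real_decorator"""
    def real_decorator(func):
        """Actual decorator to retry on exceptions"""
        @backoff.on_exception(backoff.expo,
                              exceptions,
                              max_tries=OutputDispatcher.MAX_RETRY_ATTEMPTS,
                              jitter=backoff.full_jitter,
                              on_backoff=backoff_handler(),
                              on_success=success_handler(),
                              on_giveup=giveup_handler())
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)
        return wrapper
    return real_decorator

# Usage: retry any transient requests failure raised while posting to an external service
@retry_on_exceptions((requests.exceptions.Timeout, requests.exceptions.ConnectionError))
def _post_request(url, data):
    return requests.post(url, json=data, timeout=3.05)
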
Example #7
    def dispatch(self, alert, descriptor):
        """Send alert to a Kinesis Firehose Delivery Stream

        Args:
            alert (Alert): Alert instance which triggered a rule
            descriptor (str): Output descriptor

        Returns:
            bool: True if alert was sent successfully, False otherwise
        """
        @backoff.on_exception(backoff.fibo,
                              ClientError,
                              max_tries=self.MAX_BACKOFF_ATTEMPTS,
                              jitter=backoff.full_jitter,
                              on_backoff=backoff_handler(),
                              on_success=success_handler(),
                              on_giveup=giveup_handler())
        def _firehose_request_wrapper(json_alert, delivery_stream):
            """Make the PutRecord request to Kinesis Firehose with backoff

            Args:
                json_alert (str): The JSON dumped alert body
                delivery_stream (str): The Firehose Delivery Stream to send to

            Returns:
                dict: Firehose response in the format below
                    {'RecordId': 'string'}
            """
            return self.__aws_client__.put_record(DeliveryStreamName=delivery_stream,
                                                  Record={'Data': json_alert})

        if self.__aws_client__ is None:
            self.__aws_client__ = boto3.client('firehose', region_name=self.region)

        json_alert = json.dumps(alert.output_dict(), separators=(',', ':')) + '\n'
        if len(json_alert) > self.MAX_RECORD_SIZE:
            LOGGER.error('Alert too large to send to Firehose: \n%s...', json_alert[0:1000])
            return False

        delivery_stream = self.config[self.__service__][descriptor]
        LOGGER.info('Sending %s to aws-firehose:%s', alert, delivery_stream)

        resp = _firehose_request_wrapper(json_alert, delivery_stream)

        if resp.get('RecordId'):
            LOGGER.info('%s successfully sent to aws-firehose:%s with RecordId:%s',
                        alert, delivery_stream, resp['RecordId'])

        return self._log_status(resp, descriptor)
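
The dispatch above ends with self._log_status(resp, descriptor), a helper on the shared dispatcher base class that is not shown here (the PagerDuty output in Example #18 uses it as well). A hedged sketch of the idea, assuming it simply logs the outcome and normalizes it to a bool:

import logging

LOGGER = logging.getLogger(__name__)

class OutputDispatcher(object):  # abbreviated; only the assumed helper is shown
    def _log_status(self, success, descriptor):
        """Log the result of a dispatch and return it as a bool"""
        if success:
            LOGGER.info('Successfully sent alert to %s:%s', self.__service__, descriptor)
        else:
            LOGGER.error('Failed to send alert to %s:%s', self.__service__, descriptor)
        return bool(success)
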
Example #8
    def get_messages(self, **kwargs):
        """Poll the SQS queue for new messages

        Keyword Args:
            max_tries (int): The number of times to back off
            max_value (int): The max wait interval between backoffs
            max_messages (int): The max number of messages to get from SQS
        """
        start_message_count = len(self.received_messages)

        # Backoff up to 5 times to limit the time spent in this operation
        # relative to the entire Lambda duration.
        max_tries = kwargs.get('max_tries', 5)

        # This value restricts the max time of backoff each try.
        # This means the total backoff time for one function call is:
        #   max_tries (attempts) * max_value (seconds)
        max_value = kwargs.get('max_value', 5)

        # Number of messages to poll from the stream.
        max_messages = kwargs.get('max_messages',
                                  self.MAX_SQS_GET_MESSAGE_COUNT)
        if max_messages > self.MAX_SQS_GET_MESSAGE_COUNT:
            LOGGER.error(
                'SQS can only request up to 10 messages in one request')
            return

        @backoff.on_predicate(backoff.fibo,
                              max_tries=max_tries,
                              max_value=max_value,
                              jitter=backoff.full_jitter,
                              on_backoff=backoff_handler(),
                              on_success=success_handler(True),
                              on_giveup=giveup_handler(True))
        def _receive_messages():
            polled_messages = self.sqs_client.receive_message(
                QueueUrl=self.athena_sqs_url, MaxNumberOfMessages=max_messages)

            if 'Messages' not in polled_messages:
                return False
            self.received_messages.extend(polled_messages['Messages'])

        _receive_messages()
        batch_count = len(self.received_messages) - start_message_count
        LOGGER.info('Received %d message(s) from SQS', batch_count)
Example #9
    def _query(self, values):
        """Instance method to query DynamoDB table

        Args:
            values (list): A list of strings containing IOC values

        Returns:
            A tuple(list, dict)
            list: A list of dicts returned from the DynamoDB
                table query, in the format of
                    [
                        {'sub_type': 'c2_domain', 'ioc_value': 'evil.com'},
                        {'sub_type': 'mal_ip', 'ioc_value': '1.1.1.2'},
                    ]
            dict: A dict containing unprocessed keys.
        """
        @backoff.on_exception(backoff.expo,
                              self.EXCEPTIONS_TO_BACKOFF,
                              max_tries=self.BACKOFF_MAX_RETRIES,
                              giveup=exceptions_to_giveup,
                              on_backoff=backoff_handler(),
                              on_success=success_handler(),
                              on_giveup=giveup_handler())
        def _query(query_keys):
            response = self.dynamodb.batch_get_item(
                RequestItems={
                    self._table: {
                        'Keys': query_keys,
                        'ProjectionExpression': PROJECTION_EXPRESSION
                    }
                },
            )

            result = []
            if response.get('Responses'):
                result.extend(self._deserialize(response['Responses'].get(self._table)))

            return result, response.get('UnprocessedKeys')

        query_keys = [{PRIMARY_KEY: {'S': ioc}} for ioc in values if ioc]

        return _query(query_keys)
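
Example #9 passes giveup=exceptions_to_giveup, a callable that backoff invokes with the raised exception and that returns True when a retry would be pointless. That callable is not shown; a hedged sketch of what it might check for DynamoDB (the specific error codes are assumptions):

from botocore.exceptions import ClientError

def exceptions_to_giveup(err):
    """Return True for client errors that will not succeed on retry"""
    # Validation, permission, and missing-table problems are not transient
    non_retryable = {'ValidationException', 'AccessDeniedException', 'ResourceNotFoundException'}
    if isinstance(err, ClientError):
        return err.response['Error']['Code'] in non_retryable
    return False
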
Example #10
    def get_messages(self, max_tries=5, max_value=5, max_messages=MAX_SQS_GET_MESSAGE_COUNT):
        """Poll the SQS queue for new messages

        Keyword Args:
            max_tries (int): The number of times to back off
                Backoff up to 5 times to limit the time spent in this operation
                relative to the entire Lambda duration.
            max_value (int): The max wait interval between backoffs
                This value restricts the max time of backoff each try.
                This means the total backoff time for one function call is:
                    max_tries (attempts) * max_value (seconds)
            max_messages (int): The max number of messages to get from SQS
        """
        start_message_count = len(self.received_messages)

        # Number of messages to poll from the stream.
        if max_messages > self.MAX_SQS_GET_MESSAGE_COUNT:
            LOGGER.error('The maximum requested messages exceeds the SQS limitation per request. '
                         'Setting max messages to %d', self.MAX_SQS_GET_MESSAGE_COUNT)
            max_messages = self.MAX_SQS_GET_MESSAGE_COUNT

        @backoff.on_predicate(backoff.fibo,
                              max_tries=max_tries,
                              max_value=max_value,
                              jitter=backoff.full_jitter,
                              on_backoff=backoff_handler(),
                              on_success=success_handler(True),
                              on_giveup=giveup_handler(True))
        def _receive_messages():
            polled_messages = self.sqs_client.receive_message(
                QueueUrl=self.athena_sqs_url,
                MaxNumberOfMessages=max_messages
            )
            if 'Messages' not in polled_messages:
                return True # return True to stop polling

            self.received_messages.extend(polled_messages['Messages'])

        _receive_messages()
        batch_count = len(self.received_messages) - start_message_count
        LOGGER.info('Received %d message(s) from SQS', batch_count)
Example #11
    def check_query_status(self, execution_id):
        """Check in on the running query, back off if the job is running or queued

        Args:
            execution_id (str): Athena query execution ID

        Returns:
            None: Returns without a value if the query state is SUCCEEDED
                Reference https://bit.ly/2uuRtda.

        Raises:
            AthenaQueryExecutionError: If any failure occurs while checking the status of the
                query, this exception will be raised
        """
        LOGGER.debug('Checking status of query with execution ID: %s',
                     execution_id)

        states_to_backoff = {'QUEUED', 'RUNNING'}
        @backoff.on_predicate(backoff.fibo,
                              lambda resp: \
                              resp['QueryExecution']['Status']['State'] in states_to_backoff,
                              max_value=10,
                              jitter=backoff.full_jitter,
                              on_backoff=backoff_handler(),
                              on_success=success_handler(True))
        def _check_status(query_execution_id):
            return self._client.get_query_execution(
                QueryExecutionId=query_execution_id)

        execution_result = _check_status(execution_id)
        state = execution_result['QueryExecution']['Status']['State']
        if state == 'SUCCEEDED':
            return

        # When the state is not SUCCEEDED, something bad must have occurred, so raise an exception
        reason = execution_result['QueryExecution']['Status'][
            'StateChangeReason']
        raise AthenaQueryExecutionError(
            'Query \'{}\' {} with reason \'{}\', exiting'.format(
                execution_id, state, reason))
Example #12
    def check_query_status(self, query_execution_id):
        """Check in on the running query, back off if the job is running or queued

        Args:
            query_execution_id (str): The Athena query execution ID

        Returns:
            dict: The GetQueryExecution response for this query. Its final state
                (QueryExecution.Status.State) can be SUCCEEDED, FAILED, or CANCELLED.
                Reference https://bit.ly/2uuRtda.
        """
        states_to_backoff = {'QUEUED', 'RUNNING'}
        @backoff.on_predicate(backoff.fibo,
                              lambda resp: \
                              resp['QueryExecution']['Status']['State'] in states_to_backoff,
                              max_value=10,
                              jitter=backoff.full_jitter,
                              on_backoff=backoff_handler(),
                              on_success=success_handler(True))
        def _check_status(query_execution_id):
            return self.athena_client.get_query_execution(
                QueryExecutionId=query_execution_id)

        return _check_status(query_execution_id)
Example #13
    def _firehose_request_helper(self, stream_name, record_batch):
        """Send record batches to Firehose

        Args:
            stream_name (str): The name of the Delivery Stream to send to
            record_batch (list): The records to send
        """
        exceptions_to_backoff = (ClientError, ConnectionError, Timeout)

        @backoff.on_predicate(backoff.fibo,
                              lambda resp: resp['FailedPutCount'] > 0,
                              max_tries=self.MAX_BACKOFF_ATTEMPTS,
                              max_value=self.MAX_BACKOFF_FIBO_VALUE,
                              jitter=backoff.full_jitter,
                              on_backoff=backoff_handler(debug_only=False),
                              on_success=success_handler(),
                              on_giveup=giveup_handler())
        @backoff.on_exception(backoff.fibo,
                              exceptions_to_backoff,
                              max_tries=self.MAX_BACKOFF_ATTEMPTS,
                              jitter=backoff.full_jitter,
                              on_backoff=backoff_handler(debug_only=False),
                              on_success=success_handler(),
                              on_giveup=giveup_handler())
        def firehose_request_wrapper(data):
            """Firehose request wrapper to use with backoff"""
            # Use the current length of data here so we can track failed records that are retried
            LOGGER.info('[Firehose] Sending %d records to %s', len(data), stream_name)

            response = self._client.put_record_batch(DeliveryStreamName=stream_name, Records=data)

            # Log this as an error for now so it can be picked up in logs
            if response['FailedPutCount'] > 0:
                LOGGER.error('Received non-zero FailedPutCount: %d', response['FailedPutCount'])
                # Strip out the successful records so only the failed ones are retried. This mutates
                # the list of dictionaries in place, so the retried call sees the updated list
                self._strip_successful_records(data, response)

            return response

        original_batch_size = len(record_batch)

        # The newline at the end is required by Firehose,
        # otherwise all records will be on a single line and
        # unsearchable in Athena.
        records_data = [
            {'Data': json.dumps(self.sanitize_keys(record), separators=(",", ":")) + '\n'}
            for record in record_batch
        ]

        # The try/except here is to catch the raised error at the end of the backoff
        try:
            resp = firehose_request_wrapper(records_data)
        except exceptions_to_backoff as firehose_err:
            LOGGER.error(firehose_err)
            # Use the current length of the records_data in case some records were
            # successful but others were not
            MetricLogger.log_metric(FUNCTION_NAME,
                                    MetricLogger.FIREHOSE_FAILED_RECORDS,
                                    len(records_data))
            return

        # Error handle if failures occurred in PutRecordBatch after
        # several backoff attempts
        if resp.get('FailedPutCount') > 0:
            failed_records = [failed
                              for failed
                              in resp['RequestResponses']
                              if failed.get('ErrorCode')]
            MetricLogger.log_metric(FUNCTION_NAME,
                                    MetricLogger.FIREHOSE_FAILED_RECORDS,
                                    resp['FailedPutCount'])
            # Only print the first 100 failed records to Cloudwatch logs
            LOGGER.error('[Firehose] The following records failed to put to '
                         'the Delivery Stream %s: %s',
                         stream_name,
                         json.dumps(failed_records[:100], indent=2))
        else:
            MetricLogger.log_metric(FUNCTION_NAME,
                                    MetricLogger.FIREHOSE_RECORDS_SENT,
                                    original_batch_size)
            LOGGER.info('[Firehose] Successfully sent %d messages to %s with RequestId [%s]',
                        original_batch_size,
                        stream_name,
                        resp.get('ResponseMetadata', {}).get('RequestId', ''))
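
Examples #13 and #15 run every record through self.sanitize_keys before serializing it, so the resulting Firehose data stays queryable in Athena. That helper is not shown; a minimal sketch of the idea, written as a standalone function and assuming the goal is simply to replace characters Athena cannot handle in keys (the exact substitution rules are assumptions):

import re

def sanitize_keys(record):
    """Recursively replace problematic characters in dictionary keys

    Athena column names are easiest to query when they contain only letters,
    digits, and underscores, so anything else is rewritten to an underscore.
    """
    sanitized = {}
    for key, value in record.items():
        new_key = re.sub(r'[^a-zA-Z0-9_]', '_', key)
        if isinstance(value, dict):
            value = sanitize_keys(value)  # sanitize nested objects as well
        sanitized[new_key] = value
    return sanitized
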
Example #14
    def _send_messages(self, batched_messages):
        """Send new formatted messages to CSIRT SQS
        Args:
            batched_messages (list): A list of messages that are already serialized to json
                to be sent to the Rules Engine function
        Returns:
            bool: True if the request was successful, False otherwise
        """
        @backoff.on_predicate(backoff.fibo,
                              lambda resp: len(resp.get('Failed', [])) > 0,
                              max_tries=self.MAX_BACKOFF_ATTEMPTS,
                              max_value=self.MAX_BACKOFF_FIBO_VALUE,
                              on_backoff=backoff_handler(debug_only=False),
                              on_success=success_handler(),
                              on_giveup=giveup_handler())
        @backoff.on_exception(backoff.expo, self.EXCEPTIONS_TO_BACKOFF,
                              max_tries=self.MAX_BACKOFF_ATTEMPTS,
                              on_backoff=backoff_handler(debug_only=False),
                              on_success=success_handler(),
                              on_giveup=giveup_handler())
        def _send_messages_helper(entries):
            """Inner helper function for sending messages with backoff_handler

            Args:
                entries (list<dict>): List of SQS SendMessageBatchRequestEntry items
            """
            LOGGER.info('Sending %d message(s) to %s', len(entries), self.queue.url)

            response = self.queue.send_messages(Entries=entries)

            if response.get('Successful'):
                LOGGER.info(
                    'Successfully sent %d message(s) to %s with MessageIds %s',
                    len(response['Successful']),
                    self.queue.url,
                    ', '.join(
                        '\'{}\''.format(resp['MessageId'])
                        for resp in response['Successful']
                    )
                )

            if response.get('Failed'):
                self._check_failures(response)  # Raise an exception if this is our fault
                self._strip_successful_records(entries, response)

            return response

        message_entries = [
            {
                'Id': str(idx),
                'MessageBody': message
            } for idx, message in enumerate(batched_messages)
        ]

        # The try/except here is to catch any raised errors at the end of the backoff
        try:
            return _send_messages_helper(message_entries)
        except self.EXCEPTIONS_TO_BACKOFF:
            LOGGER.exception('SQS request failed')
            # Use the current length of the message_entries in case some records were
            # successful but others were not
            self._log_failed(len(message_entries))
            return
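
When send_messages reports failures, the helper above first calls self._check_failures(response) to raise if the problem is on the sender's side rather than a transient service error. SQS marks such entries with a SenderFault flag in the Failed list; a hedged sketch of that check, written as a standalone function with an assumed exception type:

class SQSClientError(Exception):
    """Raised when an SQS batch failure was caused by the request itself"""

def _check_failures(response):
    """Raise if any failed entry was rejected due to a sender-side problem"""
    sender_faults = [
        failure for failure in response.get('Failed', [])
        if failure.get('SenderFault')  # True when the request, not the service, is at fault
    ]
    if sender_faults:
        raise SQSClientError('Failed to send {} message(s) due to sender fault: {}'.format(
            len(sender_faults), sender_faults))
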
Example #15
    def _firehose_request_helper(self, stream_name, record_batch):
        """Send record batches to Firehose

        Args:
            stream_name (str): The name of the Delivery Stream to send to
            record_batch (list): The records to send
        """
        resp = {}
        record_batch_size = len(record_batch)
        exceptions_to_backoff = (ClientError, ConnectionError)

        @backoff.on_predicate(backoff.fibo,
                              lambda resp: resp['FailedPutCount'] > 0,
                              max_tries=self.MAX_BACKOFF_ATTEMPTS,
                              max_value=self.MAX_BACKOFF_FIBO_VALUE,
                              jitter=backoff.full_jitter,
                              on_backoff=backoff_handler(),
                              on_success=success_handler(),
                              on_giveup=giveup_handler())
        @backoff.on_exception(backoff.fibo,
                              exceptions_to_backoff,
                              max_tries=self.MAX_BACKOFF_ATTEMPTS,
                              jitter=backoff.full_jitter,
                              on_backoff=backoff_handler(),
                              on_success=success_handler(),
                              on_giveup=giveup_handler())
        def firehose_request_wrapper(data):
            """Firehose request wrapper to use with backoff"""
            LOGGER.info('[Firehose] Sending %d records to %s',
                        record_batch_size, stream_name)
            return self._firehose_client.put_record_batch(
                DeliveryStreamName=stream_name, Records=data)

        # The newline at the end is required by Firehose,
        # otherwise all records will be on a single line and
        # unsearchable in Athena.
        records_data = [{
            'Data':
            json.dumps(self.sanitize_keys(record), separators=(",", ":")) +
            '\n'
        } for record in record_batch]

        # The try/except here is to catch the raised error at the
        # end of the backoff.
        try:
            resp = firehose_request_wrapper(records_data)
        except exceptions_to_backoff as firehose_err:
            LOGGER.error(firehose_err)
            MetricLogger.log_metric(FUNCTION_NAME,
                                    MetricLogger.FIREHOSE_FAILED_RECORDS,
                                    record_batch_size)
            return

        # Error handle if failures occurred in PutRecordBatch after
        # several backoff attempts
        if resp.get('FailedPutCount') > 0:
            failed_records = [
                failed for failed in resp['RequestResponses']
                if failed.get('ErrorCode')
            ]
            MetricLogger.log_metric(FUNCTION_NAME,
                                    MetricLogger.FIREHOSE_FAILED_RECORDS,
                                    resp['FailedPutCount'])
            # Only print the first 100 failed records to Cloudwatch logs
            LOGGER.error(
                '[Firehose] The following records failed to put to '
                'the Delivery Stream %s: %s', stream_name,
                json.dumps(failed_records[:100], indent=2))
        else:
            MetricLogger.log_metric(FUNCTION_NAME,
                                    MetricLogger.FIREHOSE_RECORDS_SENT,
                                    record_batch_size)
            LOGGER.info(
                '[Firehose] Successfully sent %d messages to %s with RequestId [%s]',
                record_batch_size, stream_name,
                resp.get('ResponseMetadata', {}).get('RequestId', ''))
Example #16
def test_backoff_handler_info(log_mock):
    """Backoff Handlers - Backoff Handler, Info"""
    on_backoff = backoff_handler(False)
    on_backoff(_get_details(True))
    log_mock.assert_called()
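
The unit test above builds its input with a _get_details helper that is not shown. Backoff invokes its callbacks with a details dict containing keys such as target, args, kwargs, tries, and elapsed, plus wait for on_backoff callbacks; a hedged sketch of a helper that fabricates one for tests:

def _get_details(include_wait=False):
    """Build a fake backoff details dict for exercising the handlers in tests"""
    details = {
        'target': lambda: None,  # stands in for the decorated callable
        'args': (),
        'kwargs': {},
        'tries': 3,
        'elapsed': 1.2,
    }
    if include_wait:
        details['wait'] = 0.5  # only present for on_backoff callbacks
    return details
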
Example #17
class ThreatStream(object):
    """Class to retrieve IOCs from ThreatStream.com and store them in DynamoDB"""
    _API_URL = 'https://api.threatstream.com'
    _API_RESOURCE = 'intelligence'
    _IOC_STATUS = 'active'
    # Max number of IOC objects received from one API call (the API default of 0 is equivalent to 1000)
    _API_MAX_LIMIT = 1000
    _API_MAX_INDEX = 500000
    # Remaining time in seconds before lambda termination
    _END_TIME_BUFFER = 5
    CRED_PARAMETER_NAME = 'threat_intel_downloader_api_creds'

    EXCEPTIONS_TO_BACKOFF = (requests.exceptions.Timeout,
                             requests.exceptions.ConnectionError,
                             requests.exceptions.ChunkedEncodingError,
                             ThreatStreamRequestsError)
    BACKOFF_MAX_RETRIES = 3

    def __init__(self, function_arn, timing_func):
        self._config = self._load_config(function_arn)
        self.timing_func = timing_func
        self.api_user = None
        self.api_key = None

    @staticmethod
    def _load_config(function_arn):
        """Load the Threat Intel Downloader configuration from conf/lambda.json file

        Returns:
            (dict): Configuration for Threat Intel Downloader

        Raises:
            ConfigError: For invalid or missing configuration files.
        """

        base_config = parse_lambda_arn(function_arn)
        config = load_config(include={'lambda.json'})['lambda']
        base_config.update(config.get('threat_intel_downloader_config', {}))
        return base_config

    def _load_api_creds(self):
        """Retrieve ThreatStream API credentials from Parameter Store"""
        if self.api_user and self.api_key:
            return  # credentials already loaded from SSM

        try:
            ssm = boto3.client('ssm', self.region)
            response = ssm.get_parameter(Name=self.CRED_PARAMETER_NAME,
                                         WithDecryption=True)
        except ClientError:
            LOGGER.exception('Failed to get SSM parameters')
            raise

        if not response:
            raise ThreatStreamCredsError('Invalid response')

        try:
            decoded_creds = json.loads(response['Parameter']['Value'])
        except ValueError:
            raise ThreatStreamCredsError(
                'Cannot load value for parameter with name '
                '\'{}\'. The value is not valid json: '
                '\'{}\''.format(response['Parameter']['Name'],
                                response['Parameter']['Value']))

        self.api_user = decoded_creds['api_user']
        self.api_key = decoded_creds['api_key']

        if not (self.api_user and self.api_key):
            raise ThreatStreamCredsError('API Creds Error')

    @backoff.on_exception(backoff.constant,
                          EXCEPTIONS_TO_BACKOFF,
                          max_tries=BACKOFF_MAX_RETRIES,
                          on_backoff=backoff_handler(),
                          on_success=success_handler(),
                          on_giveup=giveup_handler())
    def _connect(self, next_url):
        """Send API call to ThreatStream with next token and return parsed IOCs

        The API call has retry logic up to 3 times.
        Args:
            next_url (str): url of next token to retrieve more objects from
                ThreatStream
        """
        intelligence = list()
        https_req = requests.get('{}{}'.format(self._API_URL, next_url),
                                 timeout=10)

        next_url = None
        if https_req.status_code == 200:
            data = https_req.json()
            if data.get('objects'):
                intelligence.extend(self._process_data(data['objects']))

            LOGGER.info('IOC Offset: %d', data['meta']['offset'])
            if not (data['meta']['next']
                    and data['meta']['offset'] < self.threshold):
                LOGGER.debug(
                    'Either the next token is empty or the IOC offset reached the '
                    'threshold of %d. Stopping IOC retrieval.', self.threshold)
            else:
                next_url = data['meta']['next']
        elif https_req.status_code == 401:
            raise ThreatStreamRequestsError(
                'Response status code 401, unauthorized.')
        elif https_req.status_code == 500:
            raise ThreatStreamRequestsError(
                'Response status code 500, retry now.')
        else:
            raise ThreatStreamRequestsError(
                'Unknown status code {}, do not retry.'.format(
                    https_req.status_code))

        self._finalize(intelligence, next_url)

    def _finalize(self, intel, next_url):
        """Finalize the execution

        Send data to dynamo and continue the invocation if necessary.

        Args:
            intel (list): List of intelligence to send to DynamoDB
            next_url (str): Next token to retrieve more IOCs. No further invocation
                is triggered if the token is empty or the threshold of the number
                of IOCs has been reached.
        """
        if intel:
            LOGGER.info('Write %d IOCs to DynamoDB table', len(intel))
            self._write_to_dynamodb_table(intel)

        if next_url and self.timing_func() > self._END_TIME_BUFFER * 1000:
            self._invoke_lambda_function(next_url)

        LOGGER.debug("Time remaining (MS): %s", self.timing_func())

    def _invoke_lambda_function(self, next_url):
        """Invoke lambda function itself with next token to continually retrieve IOCs"""
        LOGGER.debug('Invoking the lambda function again to continue retrieving IOCs')
        lambda_client = boto3.client('lambda', region_name=self.region)
        try:
            lambda_client.invoke(FunctionName=self._config['function_name'],
                                 InvocationType='Event',
                                 Payload=json.dumps({'next_url': next_url}),
                                 Qualifier=self._config['qualifier'])
        except ClientError as err:
            raise ThreatStreamLambdaInvokeError(
                'Error invoking function: {}'.format(err))

    @staticmethod
    def _epoch_time(time_str, days=90):
        """Convert expiration time (in UTC) to epoch time
        Args:
            time_str (str): expiration time in string format
                Example: '2017-12-19T04:45:18.412Z'
            days (int): Number of days from now to use as the default expiration

        Returns:
            (int): Epoch time. If no expiration time is provided, the default of
                the current time plus `days` days is returned.
        """
        if not time_str:
            return int((datetime.utcnow() + timedelta(days) -
                        datetime.utcfromtimestamp(0)).total_seconds())

        try:
            utc_time = datetime.strptime(time_str, "%Y-%m-%dT%H:%M:%S.%fZ")
            return int(
                (utc_time - datetime.utcfromtimestamp(0)).total_seconds())
        except ValueError:
            LOGGER.error('Cannot convert expiration date \'%s\' to epoch time',
                         time_str)
            raise

    def _process_data(self, data):
        """Process and filter data by sources and keys
        Args:
            data (list): A list containing IOC information
                Example:
                    [
                        {
                            'value': 'malicious_domain.com',
                            'itype': 'c2_domain',
                            'source': 'crowdstrike',
                            'type': 'domain',
                            'expiration_ts': '2017-12-19T04:45:18.412Z',
                            'key1': 'value1',
                            'key2': 'value2',
                            ...
                        },
                        {
                            'value': 'malicious_domain2.com',
                            'itype': 'c2_domain',
                            'source': 'ioc_source2',
                            'type': 'domain',
                            'expiration_ts': '2017-12-31T04:45:18.412Z',
                            'key1': 'value1',
                            'key2': 'value2',
                            ...
                        }
                    ]

        Returns:
            (list): A list of dicts containing useful IOC information
                Example:
                    [
                        {
                            'value': 'malicious_domain.com',
                            'itype': 'c2_domain',
                            'source': 'crowdstrike',
                            'type': 'domain',
                            'expiration_ts': 1513658718,
                        }
                    ]
        """
        results = list()
        for obj in data:
            for source in self.ioc_sources:
                if source in obj['source'].lower():
                    filtered_obj = {
                        key: value
                        for key, value in obj.items()
                        if key in self.ioc_keys
                    }
                    filtered_obj['expiration_ts'] = self._epoch_time(
                        filtered_obj['expiration_ts'])
                    results.append(filtered_obj)
        return results

    def _write_to_dynamodb_table(self, intelligence):
        """Store IOCs to DynamoDB table"""
        try:
            dynamodb = boto3.resource('dynamodb', region_name=self.region)
            table = dynamodb.Table(self.table_name)
            with table.batch_writer() as batch:
                for ioc in intelligence:
                    batch.put_item(
                        Item={
                            'ioc_value': ioc['value'],
                            'ioc_type': ioc['type'],
                            'sub_type': ioc['itype'],
                            'source': ioc['source'],
                            'expiration_ts': ioc['expiration_ts']
                        })
        except ClientError as err:
            LOGGER.debug('DynamoDB client error: %s', err)
            raise

    def runner(self, event):
        """Process URL before making API call
        Args:
            event (dict): Contains lambda function invocation information. Initially,
                the Threat Intel Downloader lambda function is invoked by a Cloudwatch
                event. A 'next_url' key is inserted into the event when the lambda
                function invokes itself to retrieve more IOCs.

        Returns:
            None: Results are written to DynamoDB and, when more IOCs remain, the
                function re-invokes itself rather than returning a value.
        """
        event = event or {}

        self._load_api_creds()

        query = '(status="{}")+AND+({})+AND+NOT+({})'.format(
            self._IOC_STATUS,
            "+OR+".join(['type="{}"'.format(ioc) for ioc in self.ioc_types]),
            "+OR+".join([
                'itype="{}"'.format(itype) for itype in self.excluded_sub_types
            ]))
        next_url = event.get(
            'next_url',
            '/api/v2/{}/?username={}&api_key={}&limit={}&q={}'.format(
                self._API_RESOURCE, self.api_user, self.api_key,
                self._API_MAX_LIMIT, query))

        self._connect(next_url)

    @property
    def excluded_sub_types(self):
        return self._config['excluded_sub_types']

    @property
    def ioc_keys(self):
        return self._config['ioc_keys']

    @property
    def ioc_sources(self):
        return self._config['ioc_filters']

    @property
    def ioc_types(self):
        return self._config['ioc_types']

    @property
    def region(self):
        return self._config['region']

    @property
    def table_name(self):
        return self._config['function_name']

    @property
    def threshold(self):
        return self._API_MAX_INDEX - self._API_MAX_LIMIT
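
A short usage sketch for the downloader above: it is presumably driven by a Lambda handler that passes the invocation ARN and the remaining-time callback from the Lambda context object (the handler name and wiring here are assumptions):

def handler(event, context):
    """Hypothetical Lambda entry point for the Threat Intel Downloader"""
    threat_stream = ThreatStream(
        function_arn=context.invoked_function_arn,
        timing_func=context.get_remaining_time_in_millis,
    )
    threat_stream.runner(event)
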
Example #18
class PagerDutyIncidentOutput(OutputDispatcher):
    """PagerDutyIncidentOutput handles all alert dispatching for PagerDuty Incidents API v2"""
    __service__ = 'pagerduty-incident'
    INCIDENTS_ENDPOINT = 'incidents'
    USERS_ENDPOINT = 'users'
    POLICIES_ENDPOINT = 'escalation_policies'
    SERVICES_ENDPOINT = 'services'
    PRIORITIES_ENDPOINT = 'priorities'

    BACKOFF_MAX = 5
    BACKOFF_TIME = 5

    def __init__(self, *args, **kwargs):
        OutputDispatcher.__init__(self, *args, **kwargs)
        self._base_url = None
        self._headers = None
        self._escalation_policy_id = None

    @classmethod
    def _get_default_properties(cls):
        """Get the standard url used for PagerDuty Incidents API v2. This value the same for
        everyone, so is hard-coded here and does not need to be configured by the user

        Returns:
            dict: Contains various default items for this output (ie: url)
        """
        return {'api': 'https://api.pagerduty.com'}

    @classmethod
    def get_user_defined_properties(cls):
        """Get properties that must be asssigned by the user when configuring a new PagerDuty
        event output. This should be sensitive or unique information for this use-case that
        needs to come from the user.

        Every output should return a dict that contains a 'descriptor' with a description of the
        integration being configured.

        PagerDuty also requires a routing_key that represents this integration. This
        value should be masked during input and is a credential requirement.

        Returns:
            OrderedDict: Contains various OutputProperty items
        """
        return OrderedDict([
            ('descriptor',
             OutputProperty(
                 description='a short and unique descriptor for this '
                 'PagerDuty integration')),
            ('token',
             OutputProperty(
                 description='the token for this PagerDuty integration',
                 mask_input=True,
                 cred_requirement=True)),
            ('service_name',
             OutputProperty(
                 description='the service name for this PagerDuty integration',
                 cred_requirement=True)),
            ('service_id',
             OutputProperty(
                 description='the service ID for this PagerDuty integration',
                 cred_requirement=True)),
            ('escalation_policy',
             OutputProperty(
                 description='the name of the default escalation policy',
                 input_restrictions={},
                 cred_requirement=True)),
            ('escalation_policy_id',
             OutputProperty(
                 description='the ID of the default escalation policy',
                 cred_requirement=True)),
            ('email_from',
             OutputProperty(description='valid user email from the PagerDuty '
                            'account linked to the token',
                            cred_requirement=True)),
            ('integration_key',
             OutputProperty(
                 description='the integration key for this PagerDuty integration',
                 cred_requirement=True))
        ])

    @staticmethod
    def _get_endpoint(base_url, endpoint):
        """Helper to get the full url for a PagerDuty Incidents endpoint.

        Args:
            base_url (str): Base URL for the API
            endpoint (str): Endpoint that we want the full URL for

        Returns:
            str: Full URL of the provided endpoint
        """
        return os.path.join(base_url, endpoint)

    def _create_event(self, data):
        """Helper to create an event in the PagerDuty Events API v2

        Args:
            data (dict): JSON blob with the format of the PagerDuty Events API v2
        Returns:
            dict: Contains the HTTP response of the request to the API
        """
        url = 'https://events.pagerduty.com/v2/enqueue'
        try:
            resp = self._post_request_retry(url, data, None, False)
        except OutputRequestFailure:
            return False

        response = resp.json()
        if not response:
            return False

        return response

    @backoff.on_exception(backoff.constant,
                          PagerdutySearchDelay,
                          max_tries=BACKOFF_MAX,
                          interval=BACKOFF_TIME,
                          on_backoff=backoff_handler(),
                          on_success=success_handler(),
                          on_giveup=giveup_handler())
    def _get_event_incident_id(self, incident_key):
        """Helper to lookup an incident using the incident_key and return the id

        Args:
            incident_key (str): Incident key that uniquely identifies an incident

        Returns:
            str: ID of the incident after looking up the incident_key

        """
        params = {'incident_key': incident_key}
        incidents_url = self._get_endpoint(self._base_url,
                                           self.INCIDENTS_ENDPOINT)
        response = self._generic_api_get(incidents_url, params)

        incident = response.get('incidents', [])
        if not incident:
            raise PagerdutySearchDelay()

        return incident[0].get('id')

    def _merge_incidents(self, url, to_be_merged_id):
        """Helper to merge incidents by id using the PagerDuty REST API v2

        Args:
            url (str): The url to send the requests to in the API
            to_be_merged_id (str): ID of the incident to merge with

        Returns:
            dict: Contains the HTTP response of the request to the API
        """
        params = {
            'source_incidents': [{
                'id': to_be_merged_id,
                'type': 'incident_reference'
            }]
        }
        try:
            resp = self._put_request_retry(url, params, self._headers, False)
        except OutputRequestFailure:
            return False

        response = resp.json()
        if not response:
            return False

        return response

    def _generic_api_get(self, url, params):
        """Helper to submit generic GET requests with parameters to the PagerDuty REST API v2

        Args:
            url (str): The url to send the requests to in the API
            params (dict): Query parameters to include in the GET request

        Returns:
            dict: Contains the HTTP response of the request to the API
        """
        try:
            resp = self._get_request_retry(url, params, self._headers, False)
        except OutputRequestFailure:
            return False

        response = resp.json()
        if not response:
            return False

        return response

    def _check_exists(self, filter_str, url, target_key, get_id=True):
        """Generic method to run a search in the PagerDuty REST API and return the id
        of the first occurrence in the results.

        Args:
            filter_str (str): The query filter to search for in the API
            url (str): The url to send the requests to in the API
            target_key (str): The key to extract in the returned results
            get_id (boolean): Whether to generate a dict with result and reference

        Returns:
            str: ID of the targeted element that matches the provided filter or
                 True/False whether a matching element exists or not.
        """
        params = {'query': filter_str}
        response = self._generic_api_get(url, params)
        if not response:
            return False

        if not get_id:
            return True

        # If there are results, get the first occurrence from the list
        return response[target_key][0]['id'] if target_key in response else False

    def _user_verify(self, user, get_id=True):
        """Method to verify the existance of an user with the API
        Args:
            user (str): User to query about in the API.
            get_id (boolean): Whether to generate a dict with result and reference
        Returns:
            dict or False: JSON object to be used in the API call, containing the user_id
                           and user_reference. False if user is not found
        """
        return self._item_verify(user, self.USERS_ENDPOINT, 'user_reference',
                                 get_id)

    def _policy_verify(self, policy, default_policy):
        """Method to verify the existance of a escalation policy with the API
        Args:
            policy (str): Escalation policy to query about in the API
            default_policy (str): Escalation policy to use if the first one is not verified
        Returns:
            dict: JSON object to be used in the API call, containing the policy_id
                  and escalation_policy_reference
        """
        verified = self._item_verify(policy, self.POLICIES_ENDPOINT,
                                     'escalation_policy_reference')

        # If the escalation policy provided is not verified in the API, use the default
        if verified:
            return verified

        return self._item_verify(default_policy, self.POLICIES_ENDPOINT,
                                 'escalation_policy_reference')

    def _service_verify(self, service):
        """Method to verify the existance of a service with the API

        Args:
            service (str): Service to query about in the API

        Returns:
            dict: JSON object to be used in the API call, containing the service_id
                  and the service_reference
        """
        return self._item_verify(service, self.SERVICES_ENDPOINT,
                                 'service_reference')

    def _item_verify(self, item_str, item_key, item_type, get_id=True):
        """Method to verify the existance of an item with the API
        Args:
            item_str (str): Item to query about in the API
            item_key (str): Endpoint/key to be extracted from search results
            item_type (str): Type of item reference to be returned
            get_id (boolean): Whether to generate a dict with result and reference
        Returns:
            dict: JSON object to be used in the API call, containing the item id
                  and the item reference, True if it just exists or False if it fails
        """
        item_url = self._get_endpoint(self._base_url, item_key)
        item_id = self._check_exists(item_str, item_url, item_key, get_id)
        if not item_id:
            LOGGER.info('%s not found in %s, %s', item_str, item_key,
                        self.__service__)
            return False

        if get_id:
            return {'id': item_id, 'type': item_type}

        return item_id

    def _priority_verify(self, context):
        """Method to verify the existance of a incident priority with the API

        Args:
            context (dict): Context provided in the alert record

        Returns:
            dict: JSON object to be used in the API call, containing the priority id
                  and the priority reference, empty if it fails or it does not exist
        """
        if not context:
            return dict()

        priority_name = context.get('incident_priority', False)
        if not priority_name:
            return dict()

        priorities_url = self._get_endpoint(self._base_url,
                                            self.PRIORITIES_ENDPOINT)

        try:
            resp = self._get_request_retry(priorities_url, {}, self._headers,
                                           False)
        except OutputRequestFailure:
            return dict()

        response = resp.json()
        if not response:
            return dict()

        priorities = response.get('priorities', [])

        if not priorities:
            return dict()

        # If the requested priority is in the list, get the id
        priority_id = next(
            (item for item in priorities if item["name"] == priority_name),
            {}).get('id', False)

        # If the priority id is found, compose the JSON
        if priority_id:
            return {'id': priority_id, 'type': 'priority_reference'}

        return dict()
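
    # Illustrative sketch (not part of the original code): the /priorities
    # endpoint is expected to return a payload roughly like the one below,
    # which this method reduces to a priority_reference; the id is made up.
    #
    #   {'priorities': [{'id': '<priority-id>', 'name': 'P1', ...}]}
    #   _priority_verify({'incident_priority': 'P1'})
    #   # -> {'id': '<priority-id>', 'type': 'priority_reference'}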

    def _incident_assignment(self, context):
        """Method to determine if the incident gets assigned to a user or an escalation policy

        Args:
            context (dict): Context provided in the alert record

        Returns:
            tuple: assigned_key (str) and assigned_value (a dict to assign the incident to an
            escalation policy, or a list of dicts to assign the incident to users)
        """
        # Check if a user to assign the incident is provided
        user_to_assign = context.get('assigned_user', False)

        # If provided, verify the user and get the id from API
        if user_to_assign:
            user_assignee = self._user_verify(user_to_assign)
            # User is verified, return tuple
            if user_assignee:
                return 'assignments', [{'assignee': user_assignee}]

        # If an escalation policy ID was not provided, use the default one
        policy_id_to_assign = context.get('assigned_policy_id',
                                          self._escalation_policy_id)

        # Assigned to escalation policy ID, return tuple
        return 'escalation_policy', {
            'id': policy_id_to_assign,
            'type': 'escalation_policy_reference'
        }
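
    # Illustrative sketch (not part of the original code): the two possible
    # return shapes, with made-up ids (the assignee dict is whatever
    # _user_verify returns for the user).
    #
    #   ('assignments', [{'assignee': {'id': '<user-id>', 'type': 'user_reference'}}])
    #   ('escalation_policy', {'id': '<policy-id>', 'type': 'escalation_policy_reference'})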

    def _add_incident_note(self, incident_id, note):
        """Method to add a text note to the provided incident id

        Args:
            incident_id (str): ID of the incident to add the note to
            note (str): Content of the note to add to the incident

        Returns:
            str: ID of the note after being added to the incident or False if it fails
        """
        notes_path = '{}/{}/notes'.format(self.INCIDENTS_ENDPOINT, incident_id)
        incident_notes_url = self._get_endpoint(self._base_url, notes_path)
        data = {'note': {'content': note}}
        try:
            resp = self._post_request_retry(incident_notes_url, data,
                                            self._headers, True)
        except OutputRequestFailure:
            return False

        response = resp.json()
        if not response:
            return False

        note_rec = response.get('note', {})

        return note_rec.get('id', False)
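
    # Illustrative sketch (not part of the original code): a successful call is
    # expected to return JSON carrying the new note, e.g.
    #
    #   {'note': {'id': '<note-id>', 'content': 'Creating SOX Incident', ...}}
    #
    # from which only the id is returned.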

    def dispatch(self, alert, descriptor):
        """Send incident to Pagerduty Incidents API v2

        Args:
            alert (Alert): Alert instance which triggered a rule
            descriptor (str): Output descriptor

        Returns:
            bool: True if alert was sent successfully, False otherwise
        """
        creds = self._load_creds(descriptor)
        if not creds:
            return self._log_status(False, descriptor)

        # Cache base_url
        self._base_url = creds['api']

        # Preparing headers for API calls
        self._headers = {
            'Accept': 'application/vnd.pagerduty+json;version=2',
            'Content-Type': 'application/json',
            'Authorization': 'Token token={}'.format(creds['token'])
        }

        # Get user email to be added as From header and verify
        user_email = creds['email_from']
        if not self._user_verify(user_email, False):
            LOGGER.error('Could not verify header From: %s, %s', user_email,
                         self.__service__)
            return self._log_status(False, descriptor)

        # Add From to the headers after verifying
        self._headers['From'] = user_email

        # Cache default escalation policy
        self._escalation_policy_id = creds['escalation_policy_id']

        # Extracting context data to assign the incident
        rule_context = alert.context
        if rule_context:
            rule_context = rule_context.get(self.__service__, {})

        # Use the priority provided in the context if it verifies; otherwise
        # the incident will be created with low priority
        incident_priority = self._priority_verify(rule_context)

        # Incident assignment goes in this order:
        #  Provided user -> provided policy -> default policy
        assigned_key, assigned_value = self._incident_assignment(rule_context)

        # Start preparing the incident JSON blob to be sent to the API
        incident_title = 'StreamAlert Incident - Rule triggered: {}'.format(
            alert.rule_name)
        incident_body = {
            'type': 'incident_body',
            'details': alert.rule_description
        }
        # Using the service ID for the PagerDuty API
        incident_service = {
            'id': creds['service_id'],
            'type': 'service_reference'
        }
        incident_data = {
            'incident': {
                'type': 'incident',
                'title': incident_title,
                'service': incident_service,
                'priority': incident_priority,
                'body': incident_body,
                assigned_key: assigned_value
            }
        }
        incidents_url = self._get_endpoint(self._base_url,
                                           self.INCIDENTS_ENDPOINT)

        try:
            incident = self._post_request_retry(incidents_url, incident_data,
                                                self._headers, True)
        except OutputRequestFailure:
            incident = False

        if not incident:
            LOGGER.error('Could not create main incident, %s',
                         self.__service__)
            return self._log_status(False, descriptor)

        # Extract the json blob from the response, returned by self._post_request_retry
        incident_json = incident.json()
        if not incident_json:
            return self._log_status(False, descriptor)

        # Extract the incident id from the incident that was just created
        incident_id = incident_json.get('incident', {}).get('id')

        # Create an event to hold all of the alert details
        with_record = rule_context.get('with_record', True)
        event_data = events_v2_data(alert, creds['integration_key'],
                                    with_record)
        event = self._create_event(event_data)
        if not event:
            LOGGER.error('Could not create incident event, %s',
                         self.__service__)
            return self._log_status(False, descriptor)

        # Lookup the incident_key returned as dedup_key to get the incident id
        incident_key = event.get('dedup_key')

        if not incident_key:
            LOGGER.error('Could not get incident key, %s', self.__service__)
            return self._log_status(False, descriptor)

        # Keep that id to be merged later with the created incident
        event_incident_id = self._get_event_incident_id(incident_key)

        # Merge the incident with the event, so we can have a rich context incident
        # assigned to a specific person, which the PagerDuty REST API v2 does not allow
        merging_url = '{}/{}/merge'.format(incidents_url, incident_id)
        merged = self._merge_incidents(merging_url, event_incident_id)

        if not merged:
            LOGGER.error('Could not add note to incident, %s',
                         self.__service__)
        else:
            # Add a note to the merged incident to help with triage
            merged_id = merged.get('incident', {}).get('id')
            note = rule_context.get('note', 'Creating SOX Incident')
            self._add_incident_note(merged_id, note)

        return self._log_status(incident_id, descriptor)
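

# Hedged usage sketch (not part of the original code): roughly how this
# PagerDuty incident output might be exercised. The class name and the
# constructor arguments below are assumptions for illustration only;
# dispatch() is the entry point shown above.
#
#   dispatcher = PagerDutyIncidentOutput(config)
#   success = dispatcher.dispatch(alert, 'pagerduty_descriptor')
#
# dispatch() creates a REST API v2 incident, sends an Events v2 event with
# the alert payload, merges the two, and adds a triage note to the result.
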
def test_backoff_handler_debug(log_mock):
    """Backoff Handlers - Backoff Handler, Debug"""
    on_backoff = backoff_handler()
    on_backoff(_get_details(True))
    log_mock.assert_called()
Example #20
0
class AlertProcessor(object):
    """Orchestrates delivery of alerts to the appropriate dispatchers."""
    ALERT_PROCESSOR = None  # AlertProcessor instance which can be re-used across Lambda invocations
    BACKOFF_MAX_TRIES = 5

    @classmethod
    def get_instance(cls):
        """Get an instance of the AlertProcessor, using a cached version if possible."""
        if not cls.ALERT_PROCESSOR:
            cls.ALERT_PROCESSOR = AlertProcessor()
        return cls.ALERT_PROCESSOR

    def __init__(self):
        """Initialization logic that can be cached across invocations"""
        # Merge user-specified output configuration with the required output configuration
        output_config = load_config(include={'outputs.json'})['outputs']
        self.config = resources.merge_required_outputs(output_config, env['STREAMALERT_PREFIX'])

        self.alerts_table = AlertTable(env['ALERTS_TABLE'])

    def _create_dispatcher(self, output):
        """Create a dispatcher for the given output.

        Args:
            output (str): Alert output, e.g. "aws-sns:topic-name"

        Returns:
            OutputDispatcher: Based on the output type.
                Returns None if the output is invalid or not defined in the config.
        """
        try:
            service, descriptor = output.split(':')
        except ValueError:
            LOGGER.error('Improperly formatted output [%s]. Outputs for rules must '
                         'be declared with both a service and a descriptor for the '
                         'integration (ie: \'slack:my_channel\')', output)
            return None

        if service not in self.config or descriptor not in self.config[service]:
            LOGGER.error('The output \'%s\' does not exist!', output)
            return None

        return StreamAlertOutput.create_dispatcher(service, self.config)
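
    # Illustrative sketch (not part of the original code): how output strings
    # are expected to resolve, assuming 'aws-sns:topic-name' exists in the
    # merged output config.
    #
    #   self._create_dispatcher('aws-sns:topic-name')  # -> an OutputDispatcher
    #   self._create_dispatcher('aws-sns')             # -> None (no descriptor)
    #   self._create_dispatcher('bogus:topic-name')    # -> None (not in the config)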

    def _send_to_outputs(self, alert):
        """Send an alert to each remaining output.

        Args:
            alert (Alert): Alert to send

        Returns:
            dict: Maps output (str) to whether it sent successfully (bool)
        """
        result = {}

        for output in alert.remaining_outputs:
            dispatcher = self._create_dispatcher(output)
            result[output] = dispatcher.dispatch(alert, output) if dispatcher else False

        alert.outputs_sent = set(output for output, success in result.items() if success)
        return result

    @backoff.on_exception(backoff.expo, ClientError,
                          max_tries=BACKOFF_MAX_TRIES, jitter=backoff.full_jitter,
                          on_backoff=backoff_handlers.backoff_handler(),
                          on_success=backoff_handlers.success_handler(),
                          on_giveup=backoff_handlers.giveup_handler())
    def _update_table(self, alert, output_results):
        """Update the alerts table based on the results of the outputs.

        Args:
            alert (Alert): Alert instance which was sent
            output_results (dict): Maps output (str) to whether it sent successfully (bool)
        """
        if not output_results:
            return

        if all(output_results.values()) and not alert.merge_enabled:
            # All outputs sent successfully and the alert will not be merged later - delete it now
            self.alerts_table.delete_alerts([(alert.rule_name, alert.alert_id)])
        elif any(output_results.values()):
            # At least one output succeeded - update table accordingly
            self.alerts_table.update_sent_outputs(alert)
        # else: If all outputs failed, no table updates are necessary
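
    # Illustrative sketch (not part of the original code): the three outcomes
    # for a non-merged alert, using made-up output results.
    #
    #   {'slack:channel': True,  'aws-s3:bucket': True}   -> alert deleted from the table
    #   {'slack:channel': True,  'aws-s3:bucket': False}  -> sent outputs recorded; presumably retried later
    #   {'slack:channel': False, 'aws-s3:bucket': False}  -> no table update; presumably retried later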

    def run(self, event):
        """Run the alert processor!

        Args:
            event (dict): Lambda invocation event containing at least the rule name and alert ID.

        Returns:
            dict: Maps output (str) to whether it sent successfully (bool).
                An empty dict is returned if the Alert was improperly formatted.
        """
        # Grab the alert record from Dynamo (if needed).
        if set(event) == {'AlertID', 'RuleName'}:
            LOGGER.info('Retrieving %s from alerts table', event)
            alert_record = self.alerts_table.get_alert_record(event['RuleName'], event['AlertID'])
            if not alert_record:
                LOGGER.error('%s does not exist in the alerts table', event)
                return {}
        else:
            alert_record = event

        # Convert record to an Alert instance.
        try:
            alert = Alert.create_from_dynamo_record(alert_record)
        except AlertCreationError:
            LOGGER.exception('Invalid alert %s', event)
            return {}

        # Remove normalization key from the record.
        # TODO: Consider including this in at least some outputs, e.g. default Athena firehose
        if Normalizer.NORMALIZATION_KEY in alert.record:
            del alert.record[Normalizer.NORMALIZATION_KEY]

        result = self._send_to_outputs(alert)
        self._update_table(alert, result)
        return result
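

# Hedged usage sketch (not part of the original code): the Lambda handler
# presumably drives the processor with an event shaped like the one checked
# in run(); the rule name and alert ID below are made-up placeholders.
#
#   processor = AlertProcessor.get_instance()
#   results = processor.run({'RuleName': 'my_rule', 'AlertID': '<alert-uuid>'})
#   # results maps each output string to a boolean success flag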
Example #21
0
    def _query(self, values):
        """Instance method to query DynamoDB table

        Args:
            values (list): A list of string which contains IOC values

        Returns:
            A tuple(list, dict)
            list: A list of dict returned from dynamodb
                table query, in the format of
                    [
                        {'sub_type': 'c2_domain', 'ioc_value': 'evil.com'},
                        {'sub_type': 'mal_ip', 'ioc_value': '1.1.1.2'},
                    ]
            dict: A dict containing unprocessed keys.
        """
        @backoff.on_predicate(
            backoff.fibo,
            lambda resp: bool(resp['UnprocessedKeys']),  # retry if this is true
            max_tries=2,  # only retry unprocessed keys 2 times max
            on_backoff=backoff_handler(),
            on_success=success_handler(),
            on_giveup=giveup_handler())
        @backoff.on_exception(backoff.expo,
                              self.EXCEPTIONS_TO_BACKOFF,
                              max_tries=self.BACKOFF_MAX_RETRIES,
                              giveup=self._exceptions_to_giveup,
                              on_backoff=backoff_handler(),
                              on_success=success_handler(),
                              on_giveup=giveup_handler())
        def _run_query(query_values, results):
            """DynamoDB batch_get_item wrapper to use with backoff"""
            query_keys = [{
                self.PRIMARY_KEY: {
                    'S': ioc
                }
            } for ioc in query_values if ioc]

            response = self._dynamodb.batch_get_item(
                RequestItems={
                    self._table: {
                        'Keys': query_keys,
                        'ProjectionExpression': self.PROJECTION_EXPRESSION
                    }
                })

            results.extend([
                result for result in self._deserialize(
                    response['Responses'].get(self._table))
            ])

            # Log this as an error for now so it can be picked up in logs
            if response['UnprocessedKeys']:
                LOGGER.error('Retrying unprocessed keys in response: %s',
                             response['UnprocessedKeys'])
                # Strip out the successful keys so only the unprocessed ones are retried.
                # This changes the list in place, so the called function sees the updated list
                self._remove_processed_keys(
                    query_values,
                    response['UnprocessedKeys'][self._table]['Keys'])

            return response

        results = []

        _run_query(values, results)

        return results
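
    # Hedged sketch (not part of the original code): the retry contract used
    # above. batch_get_item may return an 'UnprocessedKeys' map when it is
    # throttled or the response exceeds size limits; the on_predicate backoff
    # re-invokes _run_query with only the remaining keys, because
    # _remove_processed_keys (defined elsewhere in this class) mutates
    # query_values in place. A response triggering a retry looks roughly like:
    #
    #   {
    #       'Responses': {'<table>': [...]},
    #       'UnprocessedKeys': {'<table>': {'Keys': [{'<primary key>': {'S': 'evil.com'}}]}}
    #   }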
Example #22
0
class ThreatStream(object):
    """Class to retrieve IOCs from ThreatStream.com and store them in DynamoDB"""
    _API_URL = 'https://api.threatstream.com'
    _API_RESOURCE = 'intelligence'
    _IOC_STATUS = 'active'
    # Maximum number of IOC objects received from one API call
    # (the API default of 0 is equivalent to 1000)
    _API_MAX_LIMIT = 1000
    _API_MAX_INDEX = 500000
    _PARAMETER_NAME = 'threat_intel_downloader_api_creds'

    EXCEPTIONS_TO_BACKOFF = (requests.exceptions.Timeout,
                             requests.exceptions.ConnectionError,
                             requests.exceptions.ChunkedEncodingError,
                             ThreatStreamRequestsError)
    BACKOFF_MAX_RETRIES = 3

    def __init__(self, config):
        self.ioc_types = config['ioc_types']
        self.excluded_sub_types = config['excluded_sub_types']
        self.ioc_sources = config['ioc_filters']
        self.threshold = self._API_MAX_INDEX - self._API_MAX_LIMIT
        self.region = config['region']
        self.ioc_keys = config['ioc_keys']
        self.api_user = None
        self.api_key = None
        self._get_api_creds()
        self.table_name = config['function_name']

    def _get_api_creds(self):
        """Retrieve ThreatStream API credentials from Parameter Store"""
        try:
            ssm = boto3.client('ssm', self.region)
            response = ssm.get_parameters(Names=[self._PARAMETER_NAME],
                                          WithDecryption=True)
        except ClientError as err:
            LOGGER.error('SSM client error: %s', err)
            raise

        for cred in response['Parameters']:
            if cred['Name'] == self._PARAMETER_NAME:
                try:
                    decoded_creds = json.loads(cred['Value'])
                    self.api_user = decoded_creds['api_user']
                    self.api_key = decoded_creds['api_key']
                except ValueError:
                    LOGGER.error(
                        'Can not load value for parameter with '
                        'name \'%s\'. The value is not valid json: '
                        '\'%s\'', cred['Name'], cred['Value'])
                    raise ThreatStreamCredsError('ValueError')

        if not (self.api_user and self.api_key):
            LOGGER.error('API Creds Error')
            raise ThreatStreamCredsError('API Creds Error')
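
    # Illustrative sketch (not part of the original code): the Parameter Store
    # value is expected to be a JSON document carrying the two keys parsed
    # above; the credential values are placeholders.
    #
    #   threat_intel_downloader_api_creds = '{"api_user": "<user>", "api_key": "<key>"}'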

    @backoff.on_exception(backoff.constant,
                          EXCEPTIONS_TO_BACKOFF,
                          max_tries=BACKOFF_MAX_RETRIES,
                          on_backoff=backoff_handler(),
                          on_success=success_handler(),
                          on_giveup=giveup_handler())
    def _connect(self, next_url):
        """Send API call to ThreatStream with next token and return parsed IOCs

        The API call has retry logic up to 3 times.
        Args:
            next_url (str): url of next token to retrieve more objects from
                ThreatStream

        Returns:
            (tuple): (list, str, bool)
                - First object is a list of intelligence.
                - Second object is a string of next token to retrieve more IOCs.
                - Third object is a bool indicating whether to retrieve more
                    IOCs from the threat feed.
                    It is False if the next token is empty or the threshold
                    number of IOCs has been reached.
        """
        continue_invoke = False
        intelligence = list()

        https_req = requests.get('{}{}'.format(self._API_URL, next_url),
                                 timeout=10)
        if https_req.status_code == 200:
            data = https_req.json()
            if data.get('objects'):
                intelligence.extend(self._process_data(data['objects']))
            LOGGER.info('IOC Offset: %d', data['meta']['offset'])
            if not (data['meta']['next']
                    and data['meta']['offset'] < self.threshold):
                LOGGER.debug(
                    'Either the next token is empty or the IOC offset '
                    'has reached the threshold of %d. Stopping IOC '
                    'retrieval.', self.threshold)
                continue_invoke = False
            else:
                next_url = data['meta']['next']
                continue_invoke = True
        elif https_req.status_code == 401:
            raise ThreatStreamRequestsError(
                'Response status code 401, unauthorized.')
        elif https_req.status_code == 500:
            raise ThreatStreamRequestsError(
                'Response status code 500, retry now.')
        else:
            raise ThreatStreamRequestsError('Unknown status code {}, '
                                            'do not retry.'.format(
                                                https_req.status_code))

        return (intelligence, next_url, continue_invoke)
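
    # Illustrative sketch (not part of the original code): the response fields
    # that _connect() relies on, with made-up values.
    #
    #   data = {
    #       'objects': [{'value': 'evil.com', 'itype': 'c2_domain', ...}],
    #       'meta': {'offset': 1000, 'next': '/api/v2/intelligence/?...&offset=1000'}
    #   }
    #
    # continue_invoke stays True only while meta['next'] is set and
    # meta['offset'] is below self.threshold (499000 with the class defaults).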

    def runner(self, event):
        """Process URL before making API call
        Args:
            event (dict): Contains lambda function invocation information. Initially,
                Threat Intel Downloader lambda funciton is invoked by Cloudwatch
                event. 'next_url' key will be inserted to event lambda function
                invokes itself to retrieve more IOCs.

        Returns:
            (tuple): (list, str, bool)
                - First object is a list of intelligence.
                - Second object is a string of next token to retrieve more IOCs.
                - Third object is a bool indicating whether to retrieve more
                    IOCs from the threat feed.
        """
        if not event:
            return None, None, False
        query = '(status="{}")+AND+({})+AND+NOT+({})'.format(
            self._IOC_STATUS,
            "+OR+".join(['type="{}"'.format(ioc) for ioc in self.ioc_types]),
            "+OR+".join([
                'itype="{}"'.format(itype) for itype in self.excluded_sub_types
            ]))
        next_url = event.get(
            'next_url',
            '/api/v2/{}/?username={}&api_key={}&limit={}&q={}'.format(
                self._API_RESOURCE, self.api_user, self.api_key,
                self._API_MAX_LIMIT, query))

        if not next_url:
            return None, None, False

        return self._connect(next_url)
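
    # Illustrative sketch (not part of the original code): on the first
    # invocation (no 'next_url' in the event) the URL is built from the query
    # above, roughly like the following; the type and itype values are
    # placeholders.
    #
    #   /api/v2/intelligence/?username=<user>&api_key=<key>&limit=1000
    #       &q=(status="active")+AND+(type="domain")+AND+NOT+(itype="bogon")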

    @staticmethod
    def _epoch_time(time_str, days=90):
        """Convert expiration time (in UTC) to epoch time
        Args:
            time_str (str): expiration time in string format
                Example: '2017-12-19T04:45:18.412Z'
            days (int): default expiration period in days (90 days from now)

        Returns:
            (int): Epoch time. If no expiration time is provided, the default
                of the current time plus `days` (90 by default) is returned.
        """

        if not time_str:
            return int((datetime.now() + timedelta(days) -
                        datetime.utcfromtimestamp(0)).total_seconds())

        try:
            utc_time = datetime.strptime(time_str, "%Y-%m-%dT%H:%M:%S.%fZ")
            return int(
                (utc_time - datetime.utcfromtimestamp(0)).total_seconds())
        except ValueError:
            LOGGER.error('Cannot convert expiration date \'%s\' to epoch time',
                         time_str)
            raise
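
    # Illustrative sketch (not part of the original code), reusing the sample
    # timestamp from the surrounding docstrings:
    #
    #   ThreatStream._epoch_time('2017-12-19T04:45:18.412Z')  # -> 1513658718
    #   ThreatStream._epoch_time(None)  # -> current time + 90 days, as epoch seconds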

    def _process_data(self, data):
        """Process and filter data by sources and keys
        Args:
            data (list): A list containing IOC information
                Example:
                    [
                        {
                            'value': 'malicious_domain.com',
                            'itype': 'c2_domain',
                            'source': 'crowdstrike',
                            'type': 'domain',
                            'expiration_ts': '2017-12-19T04:45:18.412Z',
                            'key1': 'value1',
                            'key2': 'value2',
                            ...
                        },
                        {
                            'value': 'malicious_domain2.com',
                            'itype': 'c2_domain',
                            'source': 'ioc_source2',
                            'type': 'domain',
                            'expiration_ts': '2017-12-31T04:45:18.412Z',
                            'key1': 'value1',
                            'key2': 'value2',
                            ...
                        }
                    ]

        Returns:
            (list): A list of dicts containing useful IOC information
                Example:
                    [
                        {
                            'value': 'malicious_domain.com',
                            'itype': 'c2_domain',
                            'source': 'crowdstrike',
                            'type': 'domain',
                            'expiration_ts': 1513658718,
                        }
                    ]
        """
        results = list()
        for obj in data:
            for source in self.ioc_sources:
                if source in obj['source'].lower():
                    filtered_obj = {
                        key: value
                        for key, value in obj.items()
                        if key in self.ioc_keys
                    }
                    filtered_obj['expiration_ts'] = self._epoch_time(
                        filtered_obj['expiration_ts'])
                    results.append(filtered_obj)
        return results

    def write_to_dynamodb_table(self, intelligence):
        """Store IOCs to DynamoDB table"""
        try:
            dynamodb = boto3.resource('dynamodb', region_name=self.region)
            table = dynamodb.Table(self.table_name)
            with table.batch_writer() as batch:
                for ioc in intelligence:
                    batch.put_item(
                        Item={
                            'ioc_value': ioc['value'],
                            'ioc_type': ioc['type'],
                            'sub_type': ioc['itype'],
                            'source': ioc['source'],
                            'expiration_ts': ioc['expiration_ts']
                        })
        except ClientError as err:
            LOGGER.error('DynamoDB client error: %s', err)
            raise
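

# Hedged sketch (not part of the original code): the DynamoDB item written for
# the sample IOC shown in the _process_data() docstring above.
#
#   {
#       'ioc_value': 'malicious_domain.com',
#       'ioc_type': 'domain',
#       'sub_type': 'c2_domain',
#       'source': 'crowdstrike',
#       'expiration_ts': 1513658718
#   }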