예제 #1
0
    def _evaluate_interval(self):
        """Convert this object's schedule expression into a number of seconds

        The schedule is matched against AWS_RATE_RE (for example, an AWS rate
        expression such as 'rate(2 hours)'). The numeric value is read from
        capture group 2 or 4, and the unit from group 3 or, failing that, from
        group 5 with every 's' removed.

        Returns:
            int: The number of seconds that the schedule evaluates to

        Raises:
            AppConfigError: If the schedule does not match AWS_RATE_RE
        """
        seconds_per_unit = {'minute': 60, 'hour': 60 * 60, 'day': 60 * 60 * 24}

        parsed = AWS_RATE_RE.match(self._schedule)
        if not parsed:
            raise AppConfigError(
                "Invalid 'rate' interval value: {}".format(self._schedule))

        count = parsed.group(2) or parsed.group(4)
        # Plural units are reduced to the singular keys used by the lookup table
        unit_name = parsed.group(3) or parsed.group(5).replace('s', '')

        total_seconds = int(count) * seconds_per_unit[unit_name]
        LOGGER.debug('Evaluated rate interval: %d seconds', total_seconds)

        return total_seconds
예제 #2
0
    def _gather(self):
        """Protected entry point that performs a single poll for logs

        Sleeps if needed, bumps the poll counter and fetches logs. When logs are
        returned, the gathered count is increased, the logs are handed to the
        batcher and the config's last timestamp is updated. When nothing is
        returned, polling is flagged as finished and an error is logged.

        Returns:
            None: This body does not return a value itself.
            NOTE(review): the previous docstring promised a float runtime in
            seconds; presumably a timing decorator that is not visible here
            supplies it - confirm.
        """
        # Throttle before polling; the sleep helper runs before the counter is bumped
        self._sleep()
        self._poll_count += 1

        polled = self._gather_logs()

        if polled:
            self._gathered_log_count += len(polled)

            # Hand the logs to the batcher, which forwards them to the rule processor
            self._batcher.send_logs(polled)

            LOGGER.debug('Updating config last timestamp from %s to %s',
                         self._config.last_timestamp, self._last_timestamp)

            # Persist the most recent timestamp on the config after each poll
            self._config.last_timestamp = self._last_timestamp
            return

        # An empty result can mean a polling problem or simply that there are no new logs
        self._more_to_poll = False
        LOGGER.error(
            '[%s] Gather process was not able to poll any logs '
            'on poll #%d', self, self._poll_count)
예제 #3
0
    def _determine_last_time(self, date_format):
        """Return the stored last timestamp, computing a fallback when there is none

        If no last timestamp is available, the evaluated rate interval is
        subtracted from the current UTC time and the result is stored as the
        new last timestamp.

        Args:
            date_format (str): Optional strftime format. When truthy, the fallback
                is stored as a formatted string; otherwise it is stored as the
                integer unix timestamp.

        Returns:
            The last timestamp, i.e. the starting point to fetch logs back to
        """
        if not self.last_timestamp:
            lookback_seconds = self._evaluate_interval()
            now = int(calendar.timegm(time.gmtime()))
            start_point = now - lookback_seconds
            LOGGER.debug(
                'Current timestamp: %s seconds. Calculated delta: %s seconds',
                now, start_point)

            # Some services expect a formatted date string instead of a unix timestamp
            self.last_timestamp = (
                datetime.utcfromtimestamp(start_point).strftime(date_format)
                if date_format else start_point
            )

        LOGGER.info('Starting last timestamp set to: %s', self.last_timestamp)

        return self.last_timestamp
예제 #4
0
    def gather(self):
        """Public method that drives the gathering of logs for this execution

        Returns early when initialization fails. Otherwise it keeps polling
        while the projected cost of another poll fits into the remaining
        seconds and there is more to poll, then finalizes. If the config is
        still marked as running on the way out, it is marked as failed.
        """
        if not self._initialize():
            return

        try:
            while True:
                # Pad the time the last poll took with the buffer multiplier and the
                # sleep time, so there is room left to spawn a new invocation
                elapsed = self._gather()
                projected = (elapsed * self._POLL_BUFFER_MULTIPLIER) + self._sleep_seconds()
                if not projected < self._remaining_seconds:
                    break

                LOGGER.debug('[%s] More logs to poll: %s', self,
                             self._more_to_poll)
                self._config.report_remaining_seconds()
                if not self._more_to_poll:
                    break

                # Flip the flag back; subclasses are expected to set it to 'True'
                # again inside their '_gather_logs' implementation
                self._more_to_poll = not self._more_to_poll

            LOGGER.debug(
                '[%s] Gathered all logs possible for this execution. More logs to poll: '
                '%s', self, self._more_to_poll)

            self._config.report_remaining_seconds()

            self._finalize()
        finally:
            # Never leave the config marked as running once this method exits
            if self._config and self._config.is_running:
                self._config.mark_failure()
예제 #5
0
    def last_timestamp(self, timestamp):
        """Store a new last timestamp and save state, unless the value is unchanged

        Args:
            timestamp: The value to record as the last timestamp
        """
        unchanged = self._last_timestamp == timestamp
        if not unchanged:
            LOGGER.debug('Setting last timestamp to: %s', timestamp)
            self._last_timestamp = timestamp
            self._save_state()
            return

        # Nothing to persist when the value is the same as the stored one
        LOGGER.debug('Timestamp is unchanged and will not be saved: %s',
                     timestamp)
예제 #6
0
    def _sleep(self):
        """Pause between polls; the very first poll is never delayed"""
        if self._poll_count == 0:
            LOGGER.debug('Skipping sleep for first poll')
        else:
            # Wait so the polled API is less likely to return a bad response
            delay = self._sleep_seconds()
            LOGGER.debug('[%s] Sleeping for %d seconds...', self, delay)
            time.sleep(delay)
예제 #7
0
    def _make_get_request(self, full_url, headers, params=None):
        """Perform a GET request and return its status check and decoded JSON body

        Args:
            full_url (str): URL to request
            headers (dict): Headers to send with the request
            params (dict): Optional query parameters

        Returns:
            tuple: The result of the HTTP response check for this request,
                followed by the value decoded from the JSON response
        """
        LOGGER.debug('[%s] Making GET request on poll #%d', self, self._poll_count)

        request_kwargs = {
            'headers': headers,
            'params': params,
            'timeout': self._DEFAULT_REQUEST_TIMEOUT,
        }
        response = requests.get(full_url, **request_kwargs)

        # The response check runs before the body is decoded
        succeeded = self._check_http_response(response)
        return succeeded, response.json()
예제 #8
0
    def context(self, context):
        """Set an additional context dictionary specific to each app

        The state is only saved when the new value differs from the stored one.

        Args:
            context (dict): The new context to store

        Raises:
            AppStateError: If the new value differs from the stored context and
                is not a dictionary
        """
        if self._context == context:
            LOGGER.debug('App context is unchanged and will not be saved: %s',
                         context)
            return

        if not isinstance(context, dict):
            # Build the message up front: unlike logging calls, exceptions do not
            # interpolate extra positional arguments into the message
            raise AppStateError(
                'Unable to set context: {}. Must be a dictionary'.format(context))

        LOGGER.debug('Setting context to: %s', context)

        self._context = context
        self._save_state()
예제 #9
0
    def current_state(self, state):
        """Set the current state of the execution and save it

        Values with no truthy upper-cased attribute on States are rejected with
        an error log. An unchanged value is not saved again.

        Args:
            state: The state value to record
        """
        known_states = self.States
        recognized = getattr(known_states, str(state).upper(), None)

        if not recognized:
            LOGGER.error("Current state cannot be saved with value '%s'",
                         state)
        elif self._current_state == state:
            LOGGER.debug('State is unchanged and will not be saved: %s', state)
        else:
            LOGGER.debug('Setting current state to: %s', state)
            self._current_state = state
            self._save_state()
예제 #10
0
    def _send_logs_to_lambda(self, logs):
        """Protected method that sends logs to the destination lambda function,
        checking the serialized payload size before sending.

        Args:
            logs (list): List of the logs that have been gathered

        Returns:
            bool: False if the payload for more than one log is too large and
                should be segmented by the caller. True if the logs were sent,
                or if a single oversized log was dropped.

        Raises:
            ClientError: Re-raised after logging if invoking the function fails
        """
        # The payload carries the source app name alongside the list of logs
        record = {'stream_alert_app': self._source_function, 'logs': logs}
        payload_json = json.dumps({'Records': [record]}, separators=(',', ':'))
        payload_size = len(payload_json)
        log_count = len(logs)

        if payload_size > self.MAX_LAMBDA_PAYLOAD_SIZE:
            if log_count != 1:
                LOGGER.debug('Log payload size for %d logs exceeds limit and will be '
                             'segmented (%d > %d max).', log_count, payload_size,
                             self.MAX_LAMBDA_PAYLOAD_SIZE)
                return False

            # A single log cannot be split any further, so it is dropped
            LOGGER.error('Log payload size for single log exceeds input limit and will be '
                         'dropped (%d > %d max).', payload_size,
                         self.MAX_LAMBDA_PAYLOAD_SIZE)
            return True

        LOGGER.debug('Sending %d logs to rule processor with payload size %d',
                     log_count, payload_size)

        try:
            response = Batcher.LAMBDA_CLIENT.invoke(
                FunctionName=self._destination_function,
                InvocationType='Event',
                Payload=payload_json,
                Qualifier='production'
            )
        except ClientError as err:
            LOGGER.error("An error occurred while sending logs to "
                         "'%s:production'. Error is: %s",
                         self._destination_function,
                         err.response)
            raise
        else:
            LOGGER.info("Sent %d logs to '%s' with Lambda request ID '%s'",
                        log_count,
                        self._destination_function,
                        response['ResponseMetadata']['RequestId'])

        return True
예제 #11
0
    def load_config(cls, event, context):
        """Load the configuration for this app invocation

        Args:
            event (dict): The AWS Lambda input event, which is JSON serialized to a dictionary
            context (LambdaContext): The AWS LambdaContext object, passed in via the handler.

        Returns:
            AppConfig: Configuration for the running application

        Raises:
            AppConfigError: If the authentication parameter is reported as invalid
        """
        # Point AppConfig.remaining_ms at the AWS timing function from the context
        AppConfig.remaining_ms = context.get_remaining_time_in_millis
        func_name = context.function_name
        func_version = context.function_version

        # Parameter names are the function name joined with a suffix by an underscore
        auth_param_name = '_'.join((func_name, cls.AUTH_CONFIG_SUFFIX))
        state_param_name = '_'.join((func_name, cls.STATE_CONFIG_SUFFIX))

        params, invalid_params = cls._get_parameters(auth_param_name,
                                                     state_param_name)

        # The app cannot run without its authentication parameter
        if auth_param_name in invalid_params:
            raise AppConfigError(
                'Could not load authentication parameter required for this '
                'app: {}'.format(auth_param_name))

        LOGGER.debug('Retrieved parameters from parameter store: %s',
                     cls._scrub_auth_info(params, auth_param_name))
        LOGGER.debug(
            'Invalid parameters could not be retrieved from parameter store: %s',
            invalid_params)

        # Unicode values in the authentication info are encoded as utf-8;
        # all other values are kept as they are
        auth_config = {}
        for key, value in params[auth_param_name].iteritems():
            if isinstance(value, unicode):
                value = value.encode('utf-8')
            auth_config[key] = value

        # The state parameter is optional and defaults to an empty dictionary
        state_config = params.get(state_param_name, {})

        return AppConfig(auth_config, state_config, event, func_name,
                         func_version)
예제 #12
0
    def _make_post_request(self, full_url, headers, data, is_json=True):
        """Perform a POST request and return its status check and decoded JSON body

        Args:
            full_url (str): URL to request
            headers (dict): Headers to send with the request
            data: Body to send with the request
            is_json (bool): When True the body is passed as 'json', otherwise
                as 'data' (for example, for form-encoded content)

        Returns:
            tuple: The result of the HTTP response check for this request,
                followed by the value decoded from the JSON response
        """
        LOGGER.debug('[%s] Making POST request on poll #%d', self, self._poll_count)

        # Only the keyword used for the body differs between the two content types
        body_keyword = 'json' if is_json else 'data'
        request_kwargs = {
            'headers': headers,
            body_keyword: data,
            'timeout': self._DEFAULT_REQUEST_TIMEOUT,
        }
        response = requests.post(full_url, **request_kwargs)

        succeeded = self._check_http_response(response)
        return succeeded, response.json()
예제 #13
0
    def _segment_and_send(self, logs):
        """Protected method for segmenting a list of logs into smaller lists
        so they conform to the input limit of AWS Lambda

        Args:
            source_function (str): The app function name from which the logs came
            logs (list): List of the logs that have been gathered
        """
        log_count = len(logs)
        LOGGER.debug('Segmenting %d logs into subsets', log_count)

        segment_size = int(math.ceil(log_count / 2.0))
        for index in range(0, log_count, segment_size):
            subset = logs[index:segment_size + index]
            # Try to send this current subset to the rule processor
            # and segment again if they are too large to be sent at once
            if not self._send_logs_to_lambda(subset):
                self._segment_and_send(subset)

        return True
예제 #14
0
    def _get_parameters(cls, *names):
        """Helper that wraps the boto3 ssm client get_parameters operation

        Args:
            names (list): Parameter names to retrieve from the aws ssm
                parameter store

        Returns:
            tuple (dict, list): Dictionary mapping each loaded parameter name to
                its value decoded from json, followed by the list of invalid
                parameters reported in the response

        Raises:
            AppConfigError: If the client call fails, or if a parameter value
                is not valid json
        """
        # Create and cache the ssm client on AppConfig if there is none yet
        if AppConfig.SSM_CLIENT is None:
            AppConfig.SSM_CLIENT = boto3.client(
                'ssm',
                config=client.Config(connect_timeout=cls.BOTO_TIMEOUT,
                                     read_timeout=cls.BOTO_TIMEOUT))

        quoted_names = ', '.join("'{}'".format(name) for name in names)
        LOGGER.debug('Retrieving values from parameter store with names: %s',
                     quoted_names)

        try:
            response = AppConfig.SSM_CLIENT.get_parameters(
                Names=list(names), WithDecryption=True)
        except ClientError as err:
            raise AppConfigError(
                'Could not get parameter with names {}. Error: '
                '{}'.format(quoted_names, err.response['Error']['Message']))

        decoded_params = {}
        for param in response['Parameters']:
            raw_value = param['Value']
            try:
                decoded_params[param['Name']] = json.loads(raw_value)
            except ValueError:
                # NOTE(review): this message embeds the decrypted parameter value
                raise AppConfigError(
                    "Could not load value for parameter with "
                    "name '{}'. The value is not valid json: "
                    "'{}'".format(param['Name'], raw_value))

        return decoded_params, response['InvalidParameters']
예제 #15
0
 def context(self):
     """Return the stored app-specific context, logging it at debug level"""
     stored_context = self._context
     LOGGER.debug('Getting context: %s', stored_context)
     return stored_context
예제 #16
0
 def mark_failure(self):
     """Set current_state to States.FAILED"""
     failed_state = self.States.FAILED
     LOGGER.debug('Marking current_state as: %s', failed_state)
     self.current_state = failed_state
예제 #17
0
 def mark_success(self):
     """Set current_state to States.SUCCEEDED"""
     succeeded_state = self.States.SUCCEEDED
     LOGGER.debug('Marking current_state as: %s', succeeded_state)
     self.current_state = succeeded_state
예제 #18
0
 def mark_running(self):
     """Set current_state to States.RUNNING"""
     running_state = self.States.RUNNING
     LOGGER.debug('Marking current_state as: %s', running_state)
     self.current_state = running_state
예제 #19
0
 def mark_partial(self):
     """Set current_state to States.PARTIAL"""
     partial_state = self.States.PARTIAL
     LOGGER.debug('Marking current_state as: %s', partial_state)
     self.current_state = partial_state
예제 #20
0
    def is_successive_invocation(self):
        """Report whether the invocation type equals Events.SUCCESSIVE_INVOKE

        Returns:
            The result of comparing the stored invocation type with
            Events.SUCCESSIVE_INVOKE
        """
        successive = self._invocation_type == self.Events.SUCCESSIVE_INVOKE
        LOGGER.debug('Is successive invocation: %s', successive)

        return successive
예제 #21
0
 def last_timestamp(self):
     """Return the stored last timestamp, logging it at debug level"""
     stored_timestamp = self._last_timestamp
     LOGGER.debug('Getting last_timestamp as: %s', stored_timestamp)
     return stored_timestamp
예제 #22
0
 def current_state(self):
     """Return the stored execution state, logging it at debug level"""
     stored_state = self._current_state
     LOGGER.debug('Getting current_state: %s', stored_state)
     return stored_state