def test_constructor_loads_from_os_when_not_provided():
    """OutputCredentials - Constructor

    When prefix and aws_account_id are omitted from the constructor call, the
    provider is expected to pick them up from the OS environment instead."""
    creds_provider = OutputCredentialsProvider(
        'that_service_name',
        config=CONFIG,
        region=REGION,
    )

    # Neither value was passed explicitly, so both must come from the environment
    assert_equal(creds_provider._prefix, 'prefix')
    assert_equal(creds_provider.get_aws_account_id(), '123456789012')
Example #2
0
    def __init__(self, config):
        """Store the config and build the credentials provider for this output service

        Args:
            config (dict): Loaded configuration passed through to the credentials provider
        """
        self.region = REGION
        self.config = config

        # Service-level defaults (if any) are merged in by the provider when loading creds
        default_props = self._get_default_properties()
        self._credentials_provider = OutputCredentialsProvider(
            self.__service__,
            config=config,
            defaults=default_props,
            region=self.region,
        )
Example #3
0
    def setup(self):
        """Create the provider under test and make sure its S3 secrets bucket exists"""
        self._provider = OutputCredentialsProvider(
            'service',
            config=CONFIG,
            defaults={'property2': 'abcdef'},
            region=REGION,
            prefix='test_asdf',
            aws_account_id='1234567890',
        )

        # Pre-create the bucket so later saves don't fail with "Bucket does not exist"
        secrets_bucket = S3Driver('test_asdf', 'service', REGION).get_s3_secrets_bucket()
        put_mock_s3_object(secrets_bucket, 'laskdjfaouhvawe', 'lafhawef', REGION)
Example #4
0
class OutputDispatcher(object):
    """OutputDispatcher is the base class to handle routing alerts to outputs

    Public methods:
        format_output_config: returns a formatted version of the outputs configuration
            that is to be written to disk
        get_user_defined_properties: returns any properties for this output that must be
            provided by the user. must be implemented by subclasses
        dispatch: handles the actual sending of alerts to the configured service. The
            service-specific work is delegated to `_dispatch`, which must be implemented
            by subclasses
    """
    # NOTE(review): `__metaclass__` is the Python 2 spelling. Under Python 3 this
    # attribute is ignored, so the @abstractmethod markers below are not enforced there.
    __metaclass__ = ABCMeta
    __service__ = NotImplemented

    # How many times it will attempt to retry something failing using backoff
    MAX_RETRY_ATTEMPTS = 5

    # _DEFAULT_REQUEST_TIMEOUT indicates how long the requests library will wait before timing
    # out for both get and post requests. This applies to both connection and read timeouts
    _DEFAULT_REQUEST_TIMEOUT = 3.05

    def __init__(self, config):
        self.region = REGION
        self.config = config

        self._credentials_provider = OutputCredentialsProvider(
            self.__service__,
            config=config,
            defaults=self._get_default_properties(),
            region=self.region)

    def _load_creds(self, descriptor):
        """Loads a dict of credentials relevant to this output descriptor

        Args:
            descriptor (str): unique identifier used to look up these credentials

        Returns:
            dict: the loaded credential info needed for sending alerts to this service
                or None if nothing gets loaded
        """
        return self._credentials_provider.load_credentials(descriptor)

    @classmethod
    def _log_status(cls, success, descriptor):
        """Log the status of sending the alerts

        Args:
            success (bool or dict): Indicates if the dispatching of alerts was successful
            descriptor (str): Service descriptor
        """
        if success:
            LOGGER.info('Successfully sent alert to %s:%s', cls.__service__,
                        descriptor)
        else:
            LOGGER.error('Failed to send alert to %s:%s', cls.__service__,
                         descriptor)

    @classmethod
    def _catch_exceptions(cls):
        """Return the tuple of exception types the *_request_retry helpers retry on

        OutputRequestFailure and ReqTimeout are always included. Anything returned by
        `_get_exceptions_to_catch()` is appended. That may be a single exception class
        or a tuple/list/set of them.

        Returns:
            tuple: Exception classes suitable for use in an `except` clause
        """
        default_exceptions = (OutputRequestFailure, ReqTimeout)
        exceptions = cls._get_exceptions_to_catch()
        if not exceptions:
            return default_exceptions

        # Flatten collections of exception classes. A list nested inside the tuple would
        # make the resulting `except` clause raise TypeError when it is evaluated.
        if isinstance(exceptions, (tuple, list, set, frozenset)):
            return default_exceptions + tuple(exceptions)

        return default_exceptions + (exceptions, )

    @classmethod
    def _get_exceptions_to_catch(cls):
        """Hook for subclasses to add extra exception types to retry on

        Returns:
            None by default. Overrides may return an exception class or a
            tuple/list/set of exception classes.
        """

    @classmethod
    def _put_request(cls, url, params=None, headers=None, verify=True):
        """Method to return the response for this PUT request

        Args:
            url (str): Endpoint for this request
            params (dict): Payload to send with this request (sent as the JSON body)
            headers (dict): Dictionary containing request-specific header parameters
            verify (bool): Whether or not the server's SSL certificate should be verified
        Returns:
            requests.Response: The http response object
        """
        return requests.put(url,
                            headers=headers,
                            json=params,
                            verify=verify,
                            timeout=cls._DEFAULT_REQUEST_TIMEOUT)

    @classmethod
    def _put_request_retry(cls, url, params=None, headers=None, verify=True):
        """Method to return the response for this PUT request
        This method implements support for backoff to retry failed requests

        Args:
            url (str): Endpoint for this request
            params (dict): Payload to send with this request (sent as the JSON body)
            headers (dict): Dictionary containing request-specific header parameters
            verify (bool): Whether or not the server's SSL certificate should be verified
        Returns:
            requests.Response: The http response object
        Raises:
            OutputRequestFailure
        """
        @retry_on_exception(cls._catch_exceptions())
        def do_put_request():
            """Decorated nested function to perform the request with retry/backoff"""
            resp = cls._put_request(url, params, headers, verify)
            success = cls._check_http_response(resp)
            if not success:
                raise OutputRequestFailure(resp)

            return resp

        return do_put_request()

    @classmethod
    def _get_request(cls, url, params=None, headers=None, verify=True):
        """Method to return the response for this GET request

        Args:
            url (str): Endpoint for this request
            params (dict): Query-string parameters to send with this request
            headers (dict): Dictionary containing request-specific header parameters
            verify (bool): Whether or not the server's SSL certificate should be verified
        Returns:
            requests.Response: The http response object
        """
        return requests.get(url,
                            headers=headers,
                            params=params,
                            verify=verify,
                            timeout=cls._DEFAULT_REQUEST_TIMEOUT)

    @classmethod
    def _get_request_retry(cls, url, params=None, headers=None, verify=True):
        """Method to return the response for this GET request
        This method implements support for backoff to retry failed requests

        Args:
            url (str): Endpoint for this request
            params (dict): Query-string parameters to send with this request
            headers (dict): Dictionary containing request-specific header parameters
            verify (bool): Whether or not the server's SSL certificate should be verified
        Returns:
            requests.Response: The http response object
        Raises:
            OutputRequestFailure
        """
        @retry_on_exception(cls._catch_exceptions())
        def do_get_request():
            """Decorated nested function to perform the request with retry/backoff"""
            resp = cls._get_request(url, params, headers, verify)
            success = cls._check_http_response(resp)
            if not success:
                raise OutputRequestFailure(resp)

            return resp

        return do_get_request()

    @classmethod
    def _post_request(cls, url, data=None, headers=None, verify=True):
        """Method to return the response for this POST request

        Args:
            url (str): Endpoint for this request
            data (dict): Payload to send with this request (sent as the JSON body)
            headers (dict): Dictionary containing request-specific header parameters
            verify (bool): Whether or not the server's SSL certificate should be verified
        Returns:
            requests.Response: The http response object
        """
        return requests.post(url,
                             headers=headers,
                             json=data,
                             verify=verify,
                             timeout=cls._DEFAULT_REQUEST_TIMEOUT)

    @classmethod
    def _post_request_retry(cls, url, data=None, headers=None, verify=True):
        """Method to return the response for this POST request
        This method implements support for backoff to retry failed requests

        Args:
            url (str): Endpoint for this request
            data (dict): Payload to send with this request (sent as the JSON body)
            headers (dict): Dictionary containing request-specific header parameters
            verify (bool): Whether or not the server's SSL certificate should be verified
        Returns:
            requests.Response: The http response object
        Raises:
            OutputRequestFailure
        """
        @retry_on_exception(cls._catch_exceptions())
        def do_post_request():
            """Decorated nested function to perform the request with retry/backoff"""
            resp = cls._post_request(url, data, headers, verify)
            success = cls._check_http_response(resp)
            if not success:
                raise OutputRequestFailure(resp)

            return resp

        return do_post_request()

    @classmethod
    def _check_http_response(cls, response):
        """Method for checking for a valid HTTP response code

        Args:
            response (requests.Response): Response object from requests, or None

        Returns:
            bool: True for a non-None response with a 2xx status code, False otherwise
        """
        success = response is not None and (200 <= response.status_code <= 299)
        if not success:
            # A None response has no `.content`; log None instead of raising AttributeError
            content = response.content if response is not None else None
            LOGGER.error('Encountered an error while sending to %s:\n%s',
                         cls.__service__, content)
        return success

    @classmethod
    def _get_default_properties(cls):
        """Base method for retrieving properties that should be hard-coded for this
        output service integration. This could include information such as a static
        url used for sending the alerts to this service, a static port, or other
        non-sensitive information.

        If information of this sort is needed, this should be overridden in output subclasses.

        NOTE: This should not contain any sensitive or use-case specific data. Information
        such as this should be retrieved from the user using `get_user_defined_properties()`
        so the user is prompted for the sensitive information at configuration time and said
        information is then sent to kms for encryption and s3 for storage.

        Returns:
            dict: Contains various default items for this output (ie: url). The base
                implementation returns None.
        """
        pass

    @classmethod
    def format_output_config(cls, service_config, values):
        """Append this descriptor to the list of descriptors for this service
           If the service does not exist yet, the descriptor is added to a new, empty list

        Args:
            service_config (dict): Loaded configuration as a dictionary
            values (OrderedDict): Contains various OutputProperty items
        Returns:
            [list<string>] List of descriptors for this service
        """
        # Concatenation builds a new list, so the list stored in service_config is not mutated
        return service_config.get(cls.__service__,
                                  []) + [values['descriptor'].value]

    @classmethod
    @abstractmethod
    def get_user_defined_properties(cls):
        """Base method for retrieving properties that must be assigned by the user when
        configuring a new output for this service. This should include any information that
        is sensitive or use-case specific. For instance, if the url needed for this integration
        is unique to your situation, it should be supplied here.

        If information of this sort is needed, it should be added to the method that
        overrides this one in the subclass.

        At the very minimum, subclass functions should return an OrderedDict that contains
        the key 'descriptor' with a description of the integration being configured

        Returns:
            OrderedDict: Contains various OutputProperty items
        """

    @abstractmethod
    def _dispatch(self, alert, descriptor):
        """Send alerts to the given service.

        Args:
            alert (Alert): Alert instance which triggered a rule
            descriptor (str): Output descriptor (e.g. slack channel, pd integration)

        Returns:
            bool: True if alert was sent successfully, False otherwise
        """

    def dispatch(self, alert, output):
        """Send alerts to the given service.

        This wraps the protected subclass method of _dispatch to aid in usability.
        Any exception raised by _dispatch is logged and reported as a failed send.

        Args:
            alert (Alert): Alert instance which triggered a rule
            output (str): Fully described output (e.g. "demisto:version1", "pagerduty:engineering")

        Returns:
            bool: True if alert was sent successfully, False otherwise
        """
        LOGGER.info('Sending %s to %s', alert, output)
        # The descriptor is the segment after the first ':' (and before any second ':')
        descriptor = output.split(':')[1]
        try:
            sent = bool(self._dispatch(alert, descriptor))
        except Exception:  # pylint: disable=broad-except
            LOGGER.exception('Exception when sending %s to %s. Alert:\n%s',
                             alert, output, repr(alert))
            sent = False

        self._log_status(sent, descriptor)

        return sent
class TestOutputCredentialsProvider(object):
    """Tests for saving/loading credentials through OutputCredentialsProvider and its drivers"""

    def setup(self):
        service_name = 'service'
        defaults = {
            'property2': 'abcdef'
        }
        prefix = 'test_asdf'
        aws_account_id = '1234567890'

        self._provider = OutputCredentialsProvider(
            service_name,
            config=CONFIG,
            defaults=defaults,
            region=REGION,
            prefix=prefix,
            aws_account_id=aws_account_id
        )

        # Pre-create the bucket so we don't get a "Bucket does not exist" error
        s3_driver = S3Driver('test_asdf', 'service', REGION)
        put_mock_s3_object(s3_driver.get_s3_secrets_bucket(), 'laskdjfaouhvawe', 'lafhawef', REGION)

    @mock_kms
    def test_save_and_load_credentials(self):
        """OutputCredentials - Save and Load Credentials

        Not only tests how save_credentials() interacts with load_credentials(), but also tests
        that cred_requirement=False properties are not saved. Also tests that default values
        are merged into the final credentials dict as appropriate."""

        descriptor = 'test_save_and_load_credentials'
        props = OrderedDict([
            ('property1',
             OutputProperty(description='This is a property and not a cred so it will not save')),
            ('property2',
             OutputProperty(description='Neither will this')),
            ('credential1',
             OutputProperty(description='Hello world',
                            value='this is a super secret secret, shhhh!',
                            mask_input=True,
                            cred_requirement=True)),
            ('credential2',
             OutputProperty(description='This appears too!',
                            value='where am i?',
                            mask_input=True,
                            cred_requirement=True)),
        ])

        # Save credential
        assert_true(self._provider.save_credentials(descriptor, KMS_ALIAS, props))

        # Pull it out
        creds_dict = self._provider.load_credentials(descriptor)
        expectation = {
            'property2': 'abcdef',
            'credential1': 'this is a super secret secret, shhhh!',
            'credential2': 'where am i?',
        }
        assert_equal(creds_dict, expectation)

    @mock_kms
    def test_load_credentials_multiple(self):
        """OutputCredentials - Load Credentials Loads from Cache Driver

        This test ensures that S3 only needs to be read on the first load, and that subsequent
        loads are served by the cache driver. Currently the cache driver is configured as
        Ephemeral."""

        descriptor = 'test_load_credentials_pulls_from_cache'
        props = OrderedDict([
            ('credential1',
             OutputProperty(description='Hello world',
                            value='there is no cow level',
                            mask_input=True,
                            cred_requirement=True)),
        ])

        # Save credential; fail fast here rather than with a confusing load mismatch later
        assert_true(self._provider.save_credentials(descriptor, KMS_ALIAS, props))

        # Pull it out (Normal expected behavior)
        creds_dict = self._provider.load_credentials(descriptor)
        expectation = {'credential1': 'there is no cow level', 'property2': 'abcdef'}
        assert_equal(creds_dict, expectation)

        # Now we yank the S3 driver out of the driver pool
        # FIXME (derek.wang): Another way to do this is to install a spy on moto and make assertions
        #                     on the number of times it is called.
        assert_is_instance(self._provider._drivers[1], S3Driver)
        self._provider._drivers[1] = None
        self._provider._core_driver = None

        # Load again and see if it still is able to load without S3
        assert_equal(self._provider.load_credentials(descriptor), expectation)

        # Double-check; Examine the Driver guts and make sure that the EphemeralDriver has the
        # value cached.
        ep_driver = self._provider._drivers[0]
        assert_is_instance(ep_driver, EphemeralUnencryptedDriver)

        assert_true(ep_driver.has_credentials(descriptor))
        creds = ep_driver.load_credentials(descriptor)
        assert_equal(json.loads(creds.data())['credential1'], 'there is no cow level')

    @patch('logging.Logger.error')
    def test_load_credentials_returns_none_on_driver_failure(self, logging_error): #pylint: disable=invalid-name
        """OutputCredentials - Load Credentials Returns None on Driver Failure"""
        descriptor = 'descriptive'

        # To pretend all drivers fail, we can just remove all of the drivers.
        self._provider._drivers = []
        self._provider._core_driver = None

        creds_dict = self._provider.load_credentials(descriptor)
        assert_is_none(creds_dict)
        logging_error.assert_called_with('All drivers failed to retrieve credentials for [%s.%s]',
                                         'service',
                                         descriptor)