class Saml(object):
    """SAML Wrapper around pysaml2.

    Implements SAML2 Service Provider functionality for Flask.
    """

    def __init__(self, config, attribute_map=None):
        """Initialize SAML Service Provider.

        Args:
            config (dict): Service Provider config info in dict form
            attribute_map (dict): Mapping of attribute keys to user data
        """
        self._config = SPConfig()
        # NOTE(review): the load/import calls below were reconstructed from
        # the surviving fragments of this method -- confirm against the
        # pysaml2 version in use.
        self._config.load(config)
        if config['metadata'].get('config'):
            # Hacked in a way to get the IdP metadata from a python dict
            # rather than having to resort to loading XML from file or http.
            idp_settings = config['metadata']['config'][0]
            idp_config = IdPConfig()
            idp_config.load(idp_settings)
            idp_entityid = idp_settings['entityid']
            idp_metadata_str = str(entity_descriptor(idp_config, 24))
            LOGGER.debug('IdP XML Metadata for %s: %s',
                         idp_entityid, idp_metadata_str)
            self._config.metadata.import_metadata(
                idp_metadata_str, idp_entityid)
        self.attribute_map = {}
        if attribute_map is not None:
            self.attribute_map = attribute_map

    def _require_signing_key(self, signing_requested, message):
        """Raise AuthException if signing is requested without a key file.

        Args:
            signing_requested: the config flag; signing is considered
                requested only when it equals the string 'true'.
            message (string): error text for the raised AuthException.

        Raises:
            AuthException: when signing is requested but the configured
                key file is unset or does not exist on disk.
        """
        if signing_requested != 'true':
            return
        key_file = self._config.key_file
        if not key_file or not os.path.exists(key_file):
            raise AuthException(message)

    def authenticate(self, next_url='/', binding=BINDING_HTTP_REDIRECT):
        """Start SAML Authentication login process.

        Args:
            next_url (string): HTTP URL to return user to when authentication
                is complete.
            binding (binding): Saml2 binding method to use for request,
                default BINDING_HTTP_REDIRECT (don't change til HTTP_POST
                support is complete in pysaml2).

        Returns:
            Flask Response object to return to user containing either
                HTTP_REDIRECT or HTTP_POST SAML message.

        Raises:
            AuthException: when unable to locate valid IdP.
            BadRequest: when invalid result returned from SAML client.
        """
        # find a configured IdP for the requested binding method; when
        # several qualify the last one iterated is used
        idp_entityid = ''
        idps = self._config.idps().keys()
        for idp in idps:
            if self._config.single_sign_on_services(idp, binding) != []:
                idp_entityid = idp
        if idp_entityid == '':
            raise AuthException('Unable to locate valid IdP for this request')
        # fail if signing requested but no private key configured
        self._require_signing_key(
            self._config.authn_requests_signed,
            'Signature requested for this Saml authentication request,'
            ' but no private key file configured')

        LOGGER.debug('Connecting to Identity Provider %s', idp_entityid)
        # retrieve cache
        outstanding_queries_cache = \
            AuthDictCache(session, '_saml_outstanding_queries')

        LOGGER.debug('Outstanding queries cache %s',
                     outstanding_queries_cache)

        # make pysaml2 call to authenticate
        client = Saml2Client(self._config, logger=LOGGER)
        # NOTE(review): keyword arguments reconstructed -- confirm against
        # the Saml2Client.authenticate signature of the pysaml2 in use.
        (session_id, result) = client.authenticate(
            entityid=idp_entityid,
            relay_state=next_url,
            binding=binding)

        # The pysaml2 source for this method indicates that BINDING_HTTP_POST
        # should not be used right now to authenticate. Regardless, we'll
        # check for it and act accordingly.

        if binding == BINDING_HTTP_REDIRECT:
            LOGGER.debug('Redirect to Identity Provider %s ( %s )',
                         idp_entityid, result)
            # result is a single (header-name, value) pair
            response = make_response('', 302, dict([result]))
        elif binding == BINDING_HTTP_POST:
            LOGGER.warning(
                'POST binding used to authenticate is not currently'
                ' supported by pysaml2 release version. Fix in place in repo.')
            LOGGER.debug('Post to Identity Provider %s ( %s )',
                         idp_entityid, result)
            response = make_response('\n'.join(result), 200)
        else:
            raise BadRequest('Invalid result returned from SAML client')

        LOGGER.debug(
            'Saving session_id ( %s ) in outstanding queries', session_id)
        # cache the outstanding query
        outstanding_queries_cache.update({session_id: next_url})
        # NOTE(review): sync() call reconstructed by analogy with the state
        # cache handling in logout() -- confirm AuthDictCache needs it here.
        outstanding_queries_cache.sync()

        LOGGER.debug('Outstanding queries cache %s',
                     session.get('_saml_outstanding_queries'))

        return response

    def handle_assertion(self, request):
        """Handle SAML Authentication login assertion (POST).

        Args:
            request (Request): Flask request object for this HTTP transaction.

        Returns:
            User Id (string), User attributes (dict), Redirect Flask response
                object to return user to now that authentication is complete.

        Raises:
            BadRequest: when error with SAML response from Identity Provider.
            AuthException: when unable to locate uid attribute in response.
        """
        if not request.form.get('SAMLResponse'):
            raise BadRequest('SAMLResponse missing from POST')
        # retrieve cache
        outstanding_queries_cache = \
            AuthDictCache(session, '_saml_outstanding_queries')
        identity_cache = IdentityCache(session, '_saml_identity')

        LOGGER.debug('Outstanding queries cache %s',
                     outstanding_queries_cache)
        LOGGER.debug('Identity cache %s', identity_cache)

        # use pysaml2 to process the SAML authentication response
        client = Saml2Client(self._config, identity_cache=identity_cache,
            logger=LOGGER)
        # NOTE(review): arguments reconstructed -- confirm against the
        # Saml2Client.response signature of the pysaml2 in use.
        saml_response = client.response(
            dict(SAMLResponse=request.form.get('SAMLResponse')),
            outstanding_queries_cache)
        if saml_response is None:
            raise BadRequest('SAML response is invalid')
        # make sure outstanding query cache is cleared for this session_id
        session_id = saml_response.session_id()
        if session_id in outstanding_queries_cache.keys():
            del outstanding_queries_cache[session_id]
            # NOTE(review): sync() reconstructed -- confirm it is required.
            outstanding_queries_cache.sync()
        # retrieve session_info
        saml_session_info = saml_response.session_info()
        LOGGER.debug('SAML Session Info ( %s )', saml_session_info)
        # retrieve user data via API
        uid_attribute = self.attribute_map.get('uid', 'name_id')
        try:
            if uid_attribute == 'name_id':
                user_id = saml_session_info.get('name_id')
            else:
                # NOTE(review): lookup reconstructed; presumably attribute
                # values are lists and the first entry is the user id.
                user_id = saml_session_info['ava'] \
                    .get(uid_attribute)[0]
        except (AttributeError, IndexError, KeyError, TypeError):
            raise AuthException('Unable to find "%s" attribute in response' % (
                uid_attribute))
        # Future: map attributes to user info
        user_attributes = dict()
        # set subject Id in cache to retrieved name_id
        session['_saml_subject_id'] = saml_session_info.get('name_id')

        LOGGER.debug('Outstanding queries cache %s',
                     session.get('_saml_outstanding_queries'))
        LOGGER.debug('Identity cache %s', session['_saml_identity'])
        LOGGER.debug('Subject Id %s', session['_saml_subject_id'])

        # NOTE(review): RelayState comes straight from the POST body and is
        # used as the redirect target without validation -- consider
        # restricting it to known/local URLs.
        relay_state = request.form.get('RelayState', '/')
        LOGGER.debug('Returning redirect to %s', relay_state)
        return user_id, user_attributes, redirect(relay_state)

    def logout(self, next_url='/'):
        """Start SAML Authentication logout process.

        Args:
            next_url (string): HTTP URL to return user to when logout is
                complete.

        Returns:
            Flask Response object to return to user containing either
                HTTP_REDIRECT or HTTP_POST SAML message.

        Raises:
            AuthException: when no subject id is stored in the session, or
                when signing is requested but no private key file is
                configured.
        """
        # retrieve cache
        state_cache = AuthDictCache(session, '_saml_state')
        identity_cache = IdentityCache(session, '_saml_identity')
        subject_id = session.get('_saml_subject_id')
        # don't logout if not logged in
        if subject_id is None:
            raise AuthException('Unable to retrieve subject id for logout')
        # fail if signing requested but no private key configured
        self._require_signing_key(
            self._config.logout_requests_signed,
            'Signature requested for this Saml logout request,'
            ' but no private key file configured')

        LOGGER.debug('State cache %s', state_cache)
        LOGGER.debug('Identity cache %s', identity_cache)
        LOGGER.debug('Subject Id %s', subject_id)

        # use pysaml2 to initiate the SAML logout request
        client = Saml2Client(self._config, state_cache=state_cache,
            identity_cache=identity_cache, logger=LOGGER)
        # NOTE(review): return_to keyword reconstructed -- confirm against
        # the Saml2Client.global_logout signature of the pysaml2 in use.
        saml_response = client.global_logout(subject_id,
            return_to=next_url)

        # sync the state to cache
        state_cache.sync()

        LOGGER.debug('State cache %s', session['_saml_state'])
        LOGGER.debug('Identity cache %s', session['_saml_identity'])

        if saml_response[1] == "": # used SOAP BINDING successfully
            return redirect(next_url)

        LOGGER.debug('Returning Response from SAML for continuation of the'
            ' logout process')
        return make_response('\n'.join(saml_response[3]),
            saml_response[1], saml_response[2]) # body, status, headers

    def _handle_logout_request(self, client, request, subject_id, binding):
        """Handle SAML Authentication logout request (GET).

        Args:
            client (Saml2Client): instance of SAML client class.
            request (Request): Flask request object for this HTTP transaction.
            subject_id (string): Id of the subject we are processing the
                logout for.
            binding (string): the SAML binding method being used for this
                request.

        Returns:
            Flask Response object to return to user containing
                HTTP_REDIRECT SAML message.

        Raises:
            BadRequest: when SAML request data is missing.
            AuthException: when SAML request indicates logout failed.
        """
        LOGGER.debug('Received a logout request from Identity Provider')

        # pysaml2 logout_request currently only returns for
        # BINDING_HTTP_REDIRECT. We will have it fail for anything
        # other than the header 'Location'

        try:
            headers, _success = client.logout_request(
                request.values, subject_id, binding=binding)
        except TypeError:
            raise BadRequest('SAML request is invalid')
        # Explicit checks instead of assert: asserts are stripped under -O.
        try:
            location_ok = bool(headers) and headers[0][0] == 'Location'
        except (IndexError, KeyError, TypeError):
            location_ok = False
        if not location_ok:
            raise AuthException('An error occurred during logout')
        return redirect(headers[0][1])

    def _handle_logout_response(self, client, request, binding, next_url):
        """Handle SAML Authentication logout response (GET or POST).

        Args:
            client (Saml2Client): instance of SAML client class.
            request (Request): Flask request object for this HTTP transaction.
            binding (string): the SAML binding method being used for this
                request.
            next_url (string): URL to get redirected to if all is successful.

        Returns:
            Flask Response object to return to user containing
                HTTP_REDIRECT SAML message.

        Raises:
            BadRequest: when SAML response data is missing.
            AuthException: when SAML response indicates logout failed.
        """
        LOGGER.debug('Received a logout response from Identity Provider')
        try:
            saml_response = client.logout_response(
                request.values['SAMLResponse'], binding=binding)
        except TypeError:
            raise BadRequest('SAML response is invalid')
        if not saml_response:
            raise AuthException('An error occurred during logout')
        if saml_response[1] == '': # used SOAP BINDING successfully
            return redirect(next_url)
        # body, status, headers
        response = make_response('\n'.join(saml_response[3]),
            saml_response[1], saml_response[2])
        # pysaml2 returns an empty 200 in some cases,
        # we'll redirect instead
        # NOTE(review): the emptiness test was truncated in the source
        # ("... and not"); an empty response body is assumed -- confirm.
        if response.status_code == 200 and not response.data:
            response = redirect(next_url)
        return response

    def handle_logout(self, request, next_url='/'):
        """Handle SAML Authentication logout request/response.

        Args:
            request (Request): Flask request object for this HTTP transaction.
            next_url (string): URL to get redirected to if all is successful.

        Returns:
            (boolean) Success, Flask Response object to return to user
                containing HTTP_REDIRECT SAML message.

        Raises:
            BadRequest: when SAML request/response data is missing.
        """
        # retrieve cache
        state_cache = AuthDictCache(session, '_saml_state')
        identity_cache = IdentityCache(session, '_saml_identity')
        subject_id = session.get('_saml_subject_id')

        LOGGER.debug('State cache %s', state_cache)
        LOGGER.debug('Identity cache %s', identity_cache)
        LOGGER.debug('Subject Id %s', subject_id)

        # use pysaml2 to complete the SAML logout request
        client = Saml2Client(self._config, state_cache=state_cache,
            identity_cache=identity_cache, logger=LOGGER)
        # let's try to figure out what binding is being used and what type of
        # logout call we are handling
        if request.args:
            binding = BINDING_HTTP_REDIRECT
        elif request.form:
            binding = BINDING_HTTP_POST
        else:
            # The SOAP binding is only valid on logout requests which currently
            # pysaml2 doesn't support.
            raise BadRequest('Unable to find supported binding')

        if 'SAMLRequest' in request.values:
            response = self._handle_logout_request(
                client, request, subject_id, binding)
        elif 'SAMLResponse' in request.values:
            response = self._handle_logout_response(
                client, request, binding, next_url)
        else:
            raise BadRequest('Unable to find SAMLRequest or SAMLResponse')

        # cache the state and remove subject if logout was successful
        success = identity_cache.get_identity(subject_id) == ({}, [])
        if success:
            # NOTE(review): body of this branch was missing in the source;
            # reconstructed from the comment above -- confirm.
            session.pop('_saml_subject_id', None)
        state_cache.sync()

        LOGGER.debug('State cache %s', session['_saml_state'])
        LOGGER.debug('Identity cache %s', session['_saml_identity'])

        LOGGER.debug(
            'Returning redirect to complete/continue the logout process')
        return success, response

    def get_metadata(self):
        """Return the SAML Service Provider metadata as an XML response.

        The entity descriptor is signed when a key file is configured.

        Returns:
            Flask Response object whose body is the metadata XML, served
                with a 'text/xml; charset=utf-8' content type.
        """
        edesc = entity_descriptor(self._config, 24)
        if self._config.key_file:
            edesc = sign_entity_descriptor(
                edesc, 24, None, security_context(self._config))
        response = make_response(str(edesc))
        response.headers['Content-type'] = 'text/xml; charset=utf-8'
        return response