Ejemplo n.º 1
0
def test_verify_id_token_at_hash_fail():
    """An ``at_hash`` made for one access token must not validate another.

    The ID token carries the left-hash of ``signed_for`` while the
    authorization response presents a different access token; with
    ``check_hash=True`` verification is expected to raise ``AtHashError``.
    """
    issuer = "https://sso.qa.7pass.ctf.prosiebensat1.com"
    secret = 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ'
    signed_for = 'AccessTokenWhichCouldBeASignedJWT'
    presented = 'ACompletelyOtherAccessToken'

    id_token = IdToken(
        sub="553df2bcf909104751cfd8b2",
        aud=["5542958437706128204e0000", "554295ce3770612820620000"],
        auth_time=1441364872,
        azp="554295ce3770612820620000",
        at_hash=left_hash(signed_for),
    )

    # Same shared secret under the default owner "" and under the issuer.
    keyjar = KeyJar()
    for owner in ("", issuer):
        keyjar.add_symmetric(owner, secret, ['sig'])

    signer = JWT(keyjar, sign_alg='HS256', iss=issuer, lifetime=3600)
    response = AuthorizationResponse(
        access_token=presented,
        id_token=signer.pack(**id_token.to_dict()),
    )

    with pytest.raises(AtHashError):
        verify_id_token(
            response,
            check_hash=True,
            keyjar=keyjar,
            iss=issuer,
            client_id="554295ce3770612820620000",
        )
Ejemplo n.º 2
0
    def __init__(
        self,
        name,
        sdb,
        cdb,
        userinfo,
        client_authn,
        urlmap=None,
        ca_certs="",
        keyjar=None,
        hostname="",
        dist_claims_mode=None,
    ):
        """Initialise the provider and register client secrets as keys.

        :param cdb: Client database mapping client id -> client info dict.
            Every entry with a ``client_secret`` gets it added to the key jar
            as a symmetric key usable for ``sig`` and ``ver``.
        :param keyjar: Key jar to use; when None a new ``KeyJar(ca_certs)``
            is created for the server method.
        :param dist_claims_mode: Stored as ``self.dist_claims_mode``.
        """
        Provider.__init__(
            self,
            name,
            sdb,
            cdb,
            None,
            userinfo,
            None,
            client_authn,
            "",
            urlmap,
            ca_certs,
            keyjar,
            hostname,
        )

        # NOTE(review): when keyjar is None a fresh KeyJar is built here;
        # whether Provider.__init__ keeps its own separate key jar cannot be
        # told from this block -- confirm both see the client secrets.
        jar = keyjar if keyjar is not None else KeyJar(ca_certs)

        for client_id, client_info in cdb.items():
            try:
                jar.add_symmetric(client_id, client_info["client_secret"], ["sig", "ver"])
            except KeyError:
                # No "client_secret" (or a KeyError while adding): skip.
                pass

        self.srvmethod = OICCServer(keyjar=jar)
        self.dist_claims_mode = dist_claims_mode
        self.info_store = {}
        self.claims_userinfo_endpoint = ""
Ejemplo n.º 3
0
def test_verify_id_token_mismatch_aud_azp():
    """An ID token whose ``azp`` is not among its ``aud`` must be rejected.

    The verifier is called with the stray ``azp`` value as ``client_id``;
    a ``ValueError`` is expected.

    NOTE(review): the token is signed with ``iss="https://example.com/as"``
    while verification expects the 7pass issuer, and only "" and the 7pass
    issuer have keys in the jar. The ``ValueError`` may therefore come from
    the issuer check rather than the aud/azp mismatch -- confirm which path
    this test really exercises.
    """
    expected_issuer = "https://sso.qa.7pass.ctf.prosiebensat1.com"
    secret = "dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ"
    stray_azp = "aaaaaaaaaaaaaaaaaaaa"

    claims = IdToken(
        sub="553df2bcf909104751cfd8b2",
        aud=["5542958437706128204e0000", "554295ce3770612820620000"],
        auth_time=1441364872,
        azp=stray_azp,
    )

    keyjar = KeyJar()
    for owner in ("", expected_issuer):
        keyjar.add_symmetric(owner, secret, ["sig"])

    signer = JWT(keyjar, sign_alg="HS256", iss="https://example.com/as", lifetime=3600)
    response = AuthorizationResponse(id_token=signer.pack(**claims.to_dict()))

    with pytest.raises(ValueError):
        verify_id_token(
            response,
            keyjar=keyjar,
            iss=expected_issuer,
            client_id=stray_azp,
        )
Ejemplo n.º 4
0
def test_verify_id_token_missing_c_hash():
    """A response carrying ``code`` needs a ``c_hash`` in its ID token.

    The ID token built here has no ``c_hash``; with ``check_hash=True``
    verification is expected to raise ``MissingRequiredAttribute``.
    """
    issuer = "https://sso.qa.7pass.ctf.prosiebensat1.com"
    secret = "dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ"
    client = "554295ce3770612820620000"

    claims = IdToken(
        sub="553df2bcf909104751cfd8b2",
        aud=["5542958437706128204e0000", client],
        auth_time=1441364872,
        azp=client,
    )

    keyjar = KeyJar()
    for owner in ("", issuer):
        keyjar.add_symmetric(owner, secret, ["sig"])

    signer = JWT(keyjar, sign_alg="HS256", iss=issuer, lifetime=3600)
    response = AuthorizationResponse(
        code="AccessCode1",
        id_token=signer.pack(**claims.to_dict()),
    )

    with pytest.raises(MissingRequiredAttribute):
        verify_id_token(response, check_hash=True, keyjar=keyjar, iss=issuer, client_id=client)
Ejemplo n.º 5
0
def test_verify_id_token_at_hash_and_chash():
    """Matching ``at_hash`` and ``c_hash`` claims pass hash verification.

    The ID token carries the left-hashes of exactly the access token and
    code that the response presents, so verification must not raise.
    """
    issuer = "https://sso.qa.7pass.ctf.prosiebensat1.com"
    secret = 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ'
    access_token = 'AccessTokenWhichCouldBeASignedJWT'
    auth_code = 'AccessCode1'

    claims = IdToken(
        sub="553df2bcf909104751cfd8b2",
        aud=["5542958437706128204e0000", "554295ce3770612820620000"],
        auth_time=1441364872,
        azp="554295ce3770612820620000",
        at_hash=left_hash(access_token),
        c_hash=left_hash(auth_code),
    )

    keyjar = KeyJar()
    for owner in ("", issuer):
        keyjar.add_symmetric(owner, secret, ['sig'])

    signer = JWT(keyjar, sign_alg='HS256', iss=issuer, lifetime=3600)
    response = AuthorizationResponse(
        access_token=access_token,
        id_token=signer.pack(**claims.to_dict()),
        code=auth_code,
    )

    # Succeeds silently; any verification failure raises.
    verify_id_token(
        response,
        check_hash=True,
        keyjar=keyjar,
        iss=issuer,
        client_id="554295ce3770612820620000",
    )
Ejemplo n.º 6
0
def test_verify_id_token_missing_c_hash():
    """Without a ``c_hash`` claim, a response with ``code`` fails hash checks.

    Expected error: ``MissingRequiredAttribute`` when ``check_hash=True``.
    """
    iss_url = "https://sso.qa.7pass.ctf.prosiebensat1.com"
    shared_key = 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ'

    token_claims = IdToken(
        sub="553df2bcf909104751cfd8b2",
        aud=["5542958437706128204e0000", "554295ce3770612820620000"],
        auth_time=1441364872,
        azp="554295ce3770612820620000",
    )

    jar = KeyJar()
    jar.add_symmetric("", shared_key, ['sig'])
    jar.add_symmetric(iss_url, shared_key, ['sig'])

    packed = JWT(jar, sign_alg='HS256', iss=iss_url, lifetime=3600).pack(
        **token_claims.to_dict())
    authz = AuthorizationResponse(code='AccessCode1', id_token=packed)

    with pytest.raises(MissingRequiredAttribute):
        verify_id_token(authz, check_hash=True, keyjar=jar, iss=iss_url,
                        client_id="554295ce3770612820620000")
Ejemplo n.º 7
0
    def __init__(self,
                 name,
                 sdb,
                 cdb,
                 userinfo,
                 client_authn,
                 urlmap=None,
                 ca_certs="",
                 keyjar=None,
                 hostname="",
                 dist_claims_mode=None):
        """
        Initialise the provider and register client secrets as keys.

        :param cdb: Client database; each entry holding a ``client_secret``
            has it added as a symmetric key for ``sig`` and ``ver``.
        :param keyjar: Key jar to use; a new ``KeyJar(ca_certs)`` is made
            when None.
        :param dist_claims_mode: Stored as ``self.dist_claims_mode``.
        """
        Provider.__init__(self, name, sdb, cdb, None, userinfo, None,
                          client_authn, "", urlmap, ca_certs, keyjar, hostname)

        jar = KeyJar(ca_certs) if keyjar is None else keyjar

        for client_id, client_info in cdb.items():
            try:
                jar.add_symmetric(client_id, client_info["client_secret"],
                                  ["sig", "ver"])
            except KeyError:
                # Entry without "client_secret" (or KeyError while adding).
                pass

        self.srvmethod = OICCServer(keyjar=jar)
        self.dist_claims_mode = dist_claims_mode
        self.info_store = {}
        self.claims_userinfo_endpoint = ""
Ejemplo n.º 8
0
def test_verify_id_token_at_hash():
    """A correct ``at_hash`` claim passes verification with ``check_hash``."""
    issuer = "https://sso.qa.7pass.ctf.prosiebensat1.com"
    secret = "dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ"
    access_token = "AccessTokenWhichCouldBeASignedJWT"

    claims = IdToken(
        sub="553df2bcf909104751cfd8b2",
        aud=["5542958437706128204e0000", "554295ce3770612820620000"],
        auth_time=1441364872,
        azp="554295ce3770612820620000",
        at_hash=left_hash(access_token),
    )

    keyjar = KeyJar()
    for owner in ("", issuer):
        keyjar.add_symmetric(owner, secret, ["sig"])

    compact = JWT(keyjar, sign_alg="HS256", iss=issuer, lifetime=3600).pack(
        **claims.to_dict()
    )
    response = AuthorizationResponse(access_token=access_token, id_token=compact)

    # No exception means the hash and signature checks passed.
    verify_id_token(
        response,
        check_hash=True,
        keyjar=keyjar,
        iss=issuer,
        client_id="554295ce3770612820620000",
    )
Ejemplo n.º 9
0
    def __init__(
        self,
        name,
        sdb,
        cdb,
        userinfo,
        client_authn,
        urlmap=None,
        keyjar=None,
        hostname="",
        dist_claims_mode=None,
        verify_ssl=None,
        settings=None,
    ):
        """
        Initialise the provider and register client secrets as keys.

        :param cdb: Client database; each entry holding a ``client_secret``
            has it added to the key jar as a symmetric ``sig``/``ver`` key.
        :param keyjar: Key jar to use; a new one is created when None.
        :param dist_claims_mode: Stored as ``self.dist_claims_mode``.
        :param verify_ssl: Deprecated (emits DeprecationWarning). When given,
            it overrides ``settings.verify_ssl``.
        :param settings: Provider settings; ``OicProviderSettings()`` when
            not supplied.
        """
        self.settings = settings or OicProviderSettings()
        if verify_ssl is not None:
            warnings.warn(
                "`verify_ssl` is deprecated, please use `settings` instead if you need to set a non-default value.",
                DeprecationWarning,
                stacklevel=2,
            )
            self.settings.verify_ssl = verify_ssl
        Provider.__init__(
            self,
            name,
            sdb,
            cdb,
            None,
            userinfo,
            None,
            client_authn,
            None,
            urlmap,
            keyjar,
            hostname,
            settings=self.settings,
        )

        if keyjar is None:
            # Use the effective setting. The raw ``verify_ssl`` argument is
            # None unless the deprecated parameter was passed, and handing
            # None to KeyJar would ignore whatever ``settings`` specifies.
            keyjar = KeyJar(verify_ssl=self.settings.verify_ssl)

        for cid, _dic in cdb.items():
            try:
                keyjar.add_symmetric(cid, _dic["client_secret"], ["sig", "ver"])
            except KeyError:
                # Entry without "client_secret" (or KeyError while adding).
                pass

        self.srvmethod = OICCServer(keyjar=keyjar)
        self.dist_claims_mode = dist_claims_mode
        self.info_store = {}  # type: Dict[str, Any]
        self.claims_userinfo_endpoint = ""
Ejemplo n.º 10
0
class TestBackchannelLogout(object):
    """Verification of back-channel logout requests signed with HS256."""

    @pytest.fixture(autouse=True)
    def setup(self):
        """Create a key jar with one symmetric key and a signed logout token."""
        keyjar = KeyJar()
        keyjar.add_symmetric("", "dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ", ["sig"])
        self.kj = keyjar
        self.key = keyjar.get_signing_key("oct")

        token = LogoutToken(
            iss="https://example.com",
            aud=["https://rp.example.org"],
            events={BACK_CHANNEL_LOGOUT_EVENT: {}},
            iat=utc_time_sans_frac(),
            jti=rndstr(16),
            sub="https://example.com/sub",
        )
        self.signed_jwt = token.to_jwt(key=self.key, algorithm="HS256")

    def test_verify_with_keyjar(self):
        request = BackChannelLogoutRequest(logout_token=self.signed_jwt)
        assert request.verify(keyjar=self.kj)

        # After verify() the compact JWT has been replaced by its verified claims.
        assert request["logout_token"]["iss"] == "https://example.com"

    def test_verify_with_key(self):
        request = BackChannelLogoutRequest(logout_token=self.signed_jwt)
        assert request.verify(key=self.key)

        # After verify() the compact JWT has been replaced by its verified claims.
        assert request["logout_token"]["iss"] == "https://example.com"

    def test_bogus_logout_token(self):
        # This token carries a ``nonce`` claim and no ``sub``; verification
        # is expected to reject it with MessageException.
        bogus = LogoutToken(
            iss="https://example.com",
            aud=["https://rp.example.org"],
            events={BACK_CHANNEL_LOGOUT_EVENT: {}},
            iat=utc_time_sans_frac(),
            jti=rndstr(16),
            nonce=rndstr(16),
        )
        request = BackChannelLogoutRequest(
            logout_token=bogus.to_jwt(key=self.key, algorithm="HS256"))

        with pytest.raises(MessageException):
            request.verify(key=self.key)
Ejemplo n.º 11
0
    def __init__(
        self,
        name,
        sdb,
        cdb,
        userinfo,
        client_authn,
        urlmap=None,
        keyjar=None,
        hostname="",
        dist_claims_mode=None,
        verify_ssl=True,
    ):
        """
        Initialise the provider and register client secrets as keys.

        :param cdb: Client database; each entry holding a ``client_secret``
            has it added as a symmetric key for ``sig`` and ``ver``.
        :param keyjar: Key jar to use; a new ``KeyJar(verify_ssl=...)`` is
            created when None.
        :param dist_claims_mode: Stored as ``self.dist_claims_mode``.
        :param verify_ssl: Passed to Provider and to a newly created KeyJar.
        """
        Provider.__init__(
            self, name, sdb, cdb, None, userinfo, None, client_authn, None,
            urlmap, keyjar, hostname, verify_ssl=verify_ssl,
        )

        jar = KeyJar(verify_ssl=verify_ssl) if keyjar is None else keyjar

        for client_id, client_info in cdb.items():
            try:
                jar.add_symmetric(client_id, client_info["client_secret"], ["sig", "ver"])
            except KeyError:
                # Entry without "client_secret" (or KeyError while adding).
                pass

        self.srvmethod = OICCServer(keyjar=jar)
        self.dist_claims_mode = dist_claims_mode
        self.info_store = {}  # type: Dict[str, Any]
        self.claims_userinfo_endpoint = ""
Ejemplo n.º 12
0
def test_verify_id_token_iss_not_in_keyjar():
    """An ID token from an issuer with no keys in the jar is rejected.

    Keys exist only for "" and the 7pass issuer, but the token is issued by
    ``https://example.com/op``; a ``ValueError`` is expected.
    """
    known_issuer = "https://sso.qa.7pass.ctf.prosiebensat1.com"
    secret = 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ'

    claims = IdToken(
        sub="553df2bcf909104751cfd8b2",
        aud=["5542958437706128204e0000", "554295ce3770612820620000"],
        auth_time=1441364872,
        azp="554295ce3770612820620000",
    )

    keyjar = KeyJar()
    for owner in ("", known_issuer):
        keyjar.add_symmetric(owner, secret, ['sig'])

    foreign_signer = JWT(keyjar, sign_alg='HS256', lifetime=3600,
                         iss='https://example.com/op')
    response = AuthorizationResponse(
        id_token=foreign_signer.pack(**claims.to_dict()))

    with pytest.raises(ValueError):
        verify_id_token(response, check_hash=True, keyjar=keyjar,
                        iss=known_issuer,
                        client_id="554295ce3770612820620000")
Ejemplo n.º 13
0
def test_verify_id_token():
    """A correctly signed ID token from a known issuer verifies."""
    issuer = "https://sso.qa.7pass.ctf.prosiebensat1.com"
    secret = 'dYMmrcQksKaPkhdgRNYk3zzh5l7ewdDJ'

    claims = IdToken(
        sub="553df2bcf909104751cfd8b2",
        aud=["5542958437706128204e0000", "554295ce3770612820620000"],
        auth_time=1441364872,
        azp="554295ce3770612820620000",
    )

    keyjar = KeyJar()
    for owner in ("", issuer):
        keyjar.add_symmetric(owner, secret, ['sig'])

    signer = JWT(keyjar, sign_alg='HS256', iss=issuer, lifetime=3600)
    response = AuthorizationResponse(id_token=signer.pack(**claims.to_dict()))

    verified = verify_id_token(response, keyjar=keyjar, iss=issuer,
                               client_id="554295ce3770612820620000")
    assert verified
Ejemplo n.º 14
0
from oic.oic import OpenIDRequest
from oic.utils.keyio import KeyJar

# Compact JWE request object; its protected header decodes to
# {"alg": "A128KW", "enc": "A128CBC-HS256"}. Split at the "." separators.
request = (
    'eyJhbGciOiAiQTEyOEtXIiwgImVuYyI6ICJBMTI4Q0JDLUhTMjU2In0.'
    'KLuBoByxG54JdHz5OBjpMjx_6ivPNi6oanRZ5UN38VzcTHw2ftv6FA.'
    'Tysc6pZ_AA_X7j95bRSHiQ.'
    'YxG8Kf3GVWXnMfzOo7Hva32eHcaNBgpcT3iPIEWq76SgKNCpdnGSKOSiFtJbvCdpXwfneXIAS3uFktQoyo9x698IHp92bAZD9M31G0GfaWh7oZgcHrBkn_QPBFavEQeTSfbvhYya3Wp2U9DrL9CrT6ytTo7mbx6b9drUpSe2waIGJkugOOFCiqr19zXXFDT1Qc04sCGhRwz_0JYMYI9qGULQ0Ws2zQVlcE_iMoA6cFs.'
    'gDd8Ns2fJRj18A6gg4-T4g'
)

# Register the shared secret under the sender id; presumably from_jwt uses
# the sender to find the decryption key -- confirm in KeyJar/Message docs.
keyjar = KeyJar()
keyjar.add_symmetric("jJFjKcsaygxp",
                     "f75695a7a87acccdef6c7c978d5e782db1b947e0f6990b050f58940b")

# Decrypt and parse the request object; the result is not kept.
OpenIDRequest().from_jwt(request, keyjar=keyjar, sender="jJFjKcsaygxp")
Ejemplo n.º 15
0
class Client(PBase):
    _endpoints = ENDPOINTS

    def __init__(self,
                 client_id=None,
                 ca_certs=None,
                 client_authn_method=None,
                 keyjar=None,
                 verify_ssl=True,
                 config=None,
                 client_cert=None):
        """
        Create a client.

        :param client_id: The client identifier
        :param ca_certs: Certificates used to verify HTTPS certificates
        :param client_authn_method: Methods that this client can use to
            authenticate itself. It's a dictionary with method names as
            keys and method classes as values.
        :param keyjar: Key jar, passed on to the base class
        :param verify_ssl: Whether the SSL certificate should be verified.
        :param config: Optional configuration mapping; its ``issuer`` entry,
            when present, becomes ``self.issuer`` (otherwise ``''``)
        :param client_cert: Client certificate, passed on to the base class
        :return: Client instance
        """

        PBase.__init__(self,
                       ca_certs,
                       verify_ssl=verify_ssl,
                       client_cert=client_cert,
                       keyjar=keyjar)

        self.client_id = client_id
        self.client_authn_method = client_authn_method
        self.verify_ssl = verify_ssl

        self.nonce = None

        # Grants keyed by OAuth ``state``; nonces remembered per state.
        self.grant = {}
        self.state2nonce = {}
        # own endpoint(s); entry 0 is the default redirect_uri
        self.redirect_uris = [None]

        # service endpoints
        self.authorization_endpoint = None
        self.token_endpoint = None
        self.token_revocation_endpoint = None

        self.request2endpoint = REQUEST2ENDPOINT
        self.response2error = RESPONSE2ERROR
        self.grant_class = Grant
        self.token_class = Token

        # Assigned once (it was previously assigned twice).
        self.provider_info = {}
        self._c_secret = None
        self.kid = {"sig": {}, "enc": {}}
        self.authz_req = None

        # the OAuth issuer is the URL of the authorization server's
        # configuration information location
        self.config = config or {}
        try:
            self.issuer = self.config['issuer']
        except KeyError:
            self.issuer = ''
        self.allow = {}

    def store_response(self, clinst, text):
        """No-op hook called by parse_response with the parsed message and raw info.

        :param clinst: The parsed response message
        :param text: The raw response data it was parsed from
        """

    def get_client_secret(self):
        """Return the stored client secret.

        None until a secret is set; ``""`` after a falsy value was set.
        """
        secret = self._c_secret
        return secret

    def set_client_secret(self, val):
        """Store the client secret and register it as a symmetric key.

        A falsy ``val`` clears the secret to ``""`` without touching the key
        jar. Otherwise the secret is kept and ``str(val)`` is added to the key
        jar under the owner id ``""`` (a KeyJar is created first if needed).
        The client signs with it, and may verify server signatures made with
        the same secret.
        """
        if val:
            self._c_secret = val
            if self.keyjar is None:
                self.keyjar = KeyJar()
            self.keyjar.add_symmetric("", str(val))
        else:
            self._c_secret = ""

    client_secret = property(fget=get_client_secret, fset=set_client_secret)

    def reset(self):
        """Forget the nonce, grants, service endpoints and redirect URIs.

        ``redirect_uris`` is set to None here (``__init__`` starts it as
        ``[None]``).
        """
        self.nonce = None
        self.grant = {}
        self.authorization_endpoint = self.token_endpoint = None
        self.redirect_uris = None

    def grant_from_state(self, state):
        """Return the grant stored under ``state``, or None when there is none.

        Uses a direct dict lookup instead of scanning every stored grant.
        """
        return self.grant.get(state)

    def _parse_args(self, request, **kwargs):
        """Complete request arguments from client attributes.

        Starts from a copy of ``kwargs``. For every parameter declared in
        ``request.c_param`` that the caller did not supply, the client
        attribute of the same name is used when it is truthy.
        ``redirect_uri`` is special-cased to the first of ``self.redirect_uris``.

        :param request: Message class whose ``c_param`` lists its parameters
        :return: dict of request arguments
        """
        ar_args = kwargs.copy()

        for prop in request.c_param:
            if prop in ar_args:
                continue
            if prop == "redirect_uri":
                # reset() sets redirect_uris to None; treat None (or an empty
                # list) as "no redirect URI" instead of failing on [0].
                redirect_uris = getattr(self, "redirect_uris", None) or [None]
                _val = redirect_uris[0]
            else:
                _val = getattr(self, prop, None)
            if _val:
                ar_args[prop] = _val

        return ar_args

    def _endpoint(self, endpoint, **kwargs):
        """
        Resolve the URL for ``endpoint``.

        A truthy ``kwargs[endpoint]`` wins; otherwise the client attribute
        named ``endpoint`` is used.

        :raises MissingEndpoint: if neither source yields a truthy URL
        """
        uri = kwargs.get(endpoint, "")
        if uri:
            return uri

        try:
            uri = getattr(self, endpoint)
        except Exception:
            raise MissingEndpoint("No '%s' specified" % endpoint)

        if not uri:
            raise MissingEndpoint("No '%s' specified" % endpoint)
        return uri

    def get_grant(self, state, **kwargs):
        """
        Return the grant registered for ``state``.

        Extra keyword arguments are accepted and ignored (get_token, for
        example, passes its whole kwargs).

        :raises GrantError: if no grant is stored under ``state``
        """
        grants = self.grant
        try:
            found = grants[state]
        except KeyError:
            raise GrantError("No grant found for state:'%s'" % state)
        return found

    def get_token(self, also_expired=False, **kwargs):
        """
        Find the token to use for a request.

        An explicit ``token`` keyword is returned as-is (no validity check).
        Otherwise the grant for ``kwargs["state"]`` is fetched via get_grant
        and asked for the token for ``kwargs["scope"]``. When no scope is
        given, the token for the empty scope is used.

        :param also_expired: If True, return the token even if it is no
            longer valid.
        :return: The token
        :raises TokenError: if no token is found, or the token has expired
            and ``also_expired`` is False
        :raises GrantError: (from get_grant) if no grant exists for the state
        """
        try:
            return kwargs["token"]
        except KeyError:
            grant = self.get_grant(**kwargs)

            try:
                # NOTE(review): this try also catches a KeyError raised inside
                # grant.get_token itself, not only a missing "scope" kwarg.
                token = grant.get_token(kwargs["scope"])
            except KeyError:
                token = grant.get_token("")
                if not token:
                    try:
                        # NOTE(review): get_grant returned self.grant[state],
                        # so this appears to re-query the same grant -- confirm
                        # whether this fallback can ever find anything new.
                        token = self.grant[kwargs["state"]].get_token("")
                    except KeyError:
                        raise TokenError("No token found for scope")

        if token is None:
            raise TokenError("No suitable token found")

        if also_expired:
            return token
        elif token.is_valid():
            return token
        else:
            raise TokenError("Token has expired")

    def construct_request(self, request, request_args=None, extra_args=None):
        """
        Instantiate the message class ``request``.

        ``request_args`` are completed with client attributes by _parse_args.
        ``extra_args`` (if any) are then merged on top and win on clashes.
        """
        args = self._parse_args(
            request, **({} if request_args is None else request_args))
        if extra_args:
            args.update(extra_args)
        return request(**args)

    def construct_Message(
            self, request=Message, request_args=None, extra_args=None,
            **kwargs):
        """Build ``request`` (a plain Message by default) via construct_request.

        Additional keyword arguments are accepted and ignored.
        """
        built = self.construct_request(request, request_args, extra_args)
        return built

    def construct_AuthorizationRequest(
            self, request=AuthorizationRequest, request_args=None,
            extra_args=None, **kwargs):
        """
        Build an authorization request.

        A truthy ``redirect_uri`` in ``request_args`` also becomes the client's
        only redirect URI (``self.redirect_uris``). A missing or empty
        ``client_id`` is filled from ``self.client_id``; note this writes
        into the caller's ``request_args``. Extra keyword arguments are ignored.
        """
        if request_args is None:
            request_args = {}
        elif "redirect_uri" in request_args and request_args["redirect_uri"]:
            self.redirect_uris = [request_args["redirect_uri"]]

        if "client_id" not in request_args or not request_args["client_id"]:
            request_args["client_id"] = self.client_id

        return self.construct_request(request, request_args, extra_args)

    def construct_AccessTokenRequest(
            self, request=AccessTokenRequest, request_args=None,
            extra_args=None, **kwargs):
        """
        Build an access token request from the grant for ``kwargs["state"]``.

        Sets ``code`` from the grant. Copies ``state`` when given. Defaults
        ``grant_type`` to ``authorization_code`` and fills a missing or empty
        ``client_id`` from ``self.client_id``.

        :raises GrantExpired: if the grant is no longer valid
        """
        grant = self.get_grant(**kwargs)
        if not grant.is_valid():
            raise GrantExpired(
                "Authorization Code to old %s > %s" %
                (utc_time_sans_frac(), grant.grant_expiration_time))

        args = {} if request_args is None else request_args
        args["code"] = grant.code

        if 'state' in kwargs:
            args['state'] = kwargs['state']

        if "grant_type" not in args:
            args["grant_type"] = "authorization_code"

        if "client_id" not in args or not args["client_id"]:
            args["client_id"] = self.client_id

        return self.construct_request(request, args, extra_args)

    def construct_RefreshAccessTokenRequest(
            self, request=RefreshAccessTokenRequest, request_args=None,
            extra_args=None, **kwargs):
        """
        Build a refresh request from the current token, even if it has expired.

        ``refresh_token`` is always copied from the token. ``scope`` is
        copied when the token has such an attribute.
        """
        args = {} if request_args is None else request_args
        current = self.get_token(also_expired=True, **kwargs)

        args["refresh_token"] = current.refresh_token
        try:
            args["scope"] = current.scope
        except AttributeError:
            pass

        return self.construct_request(request, args, extra_args)

    # def construct_TokenRevocationRequest(self,
    #                                      request=TokenRevocationRequest,
    #                                      request_args=None, extra_args=None,
    #                                      **kwargs):
    #
    #     if request_args is None:
    #         request_args = {}
    #
    #     token = self.get_token(**kwargs)
    #
    #     request_args["token"] = token.access_token
    #     return self.construct_request(request, request_args, extra_args)

    def construct_ResourceRequest(
            self, request=ResourceRequest, request_args=None,
            extra_args=None, **kwargs):
        """Build a resource request carrying a still-valid access token.

        get_token raises TokenError when no valid token is available.
        """
        args = {} if request_args is None else request_args
        args["access_token"] = self.get_token(**kwargs).access_token
        return self.construct_request(request, args, extra_args)

    def uri_and_body(self,
                     reqmsg,
                     cis,
                     method="POST",
                     request_args=None,
                     **kwargs):
        """
        Work out the target URI, body and HTTP headers for a request.

        :param reqmsg: The request Message class. When no truthy ``endpoint``
            kwarg is given, its name selects the endpoint attribute via
            ``self.request2endpoint``.
        :param cis: The constructed request instance
        :param method: HTTP method
        :param request_args: Request arguments; may override the endpoint URL
            under the endpoint attribute's name. None means no overrides.
        :return: (uri, body, http_args, cis); http_args holds ``headers``
            only when get_or_post produced them
        """
        uri = kwargs.get("endpoint")
        if not uri:
            # Unpacking the default None would raise TypeError.
            overrides = {} if request_args is None else request_args
            uri = self._endpoint(self.request2endpoint[reqmsg.__name__],
                                 **overrides)

        uri, body, kwargs = get_or_post(uri, method, cis, **kwargs)
        try:
            h_args = {"headers": kwargs["headers"]}
        except KeyError:
            h_args = {}

        return uri, body, h_args, cis

    def request_info(self,
                     request,
                     method="POST",
                     request_args=None,
                     extra_args=None,
                     lax=False,
                     **kwargs):
        """
        Construct a request message and compute how to send it.

        Uses ``construct_<RequestClassName>`` when the client defines one,
        otherwise the generic construct_request. When the request has both
        ``nonce`` and ``state``, the nonce is remembered in
        ``self.state2nonce``. If ``authn_method`` is in kwargs, the headers
        returned by init_authentication_method are merged into
        ``kwargs["headers"]``.

        :param lax: Stored on the request instance as ``cis.lax``
        :return: (uri, body, http_args, request instance) from uri_and_body
        """
        if request_args is None:
            request_args = {}

        # Look the specialised constructor up separately from calling it, so
        # an AttributeError raised *inside* a construct_* method surfaces
        # instead of silently falling back to the generic constructor.
        constructor = getattr(self, "construct_%s" % request.__name__, None)
        if constructor is None:
            cis = self.construct_request(request, request_args, extra_args)
        else:
            cis = constructor(request_args=request_args,
                              extra_args=extra_args,
                              **kwargs)

        if self.events:
            self.events.store('Protocol request', cis)

        if 'nonce' in cis and 'state' in cis:
            self.state2nonce[cis['state']] = cis['nonce']

        cis.lax = lax

        if "authn_method" in kwargs:
            h_arg = self.init_authentication_method(cis,
                                                    request_args=request_args,
                                                    **kwargs)
        else:
            h_arg = None

        if h_arg:
            if "headers" in kwargs:
                kwargs["headers"].update(h_arg["headers"])
            else:
                kwargs["headers"] = h_arg["headers"]

        return self.uri_and_body(request, cis, method, request_args, **kwargs)

    def authorization_request_info(self, request_args=None, extra_args=None,
                                   **kwargs):
        """Prepare an AuthorizationRequest to be sent with HTTP GET.

        :return: (uri, body, http_args, request instance) from request_info
        """
        return self.request_info(
            AuthorizationRequest, "GET", request_args, extra_args, **kwargs)

    def get_urlinfo(self, info):
        """
        Extract the parameter string from URL-shaped response data.

        If ``info`` contains ``?`` or ``#`` it is parsed as a URL. The query
        string is returned, or the fragment when the query is empty. Any
        other input is returned unchanged.
        """
        if "?" not in info and "#" not in info:
            return info
        parsed = urlparse(info)
        return parsed.query or parsed.fragment

    def parse_response(self,
                       response,
                       info="",
                       sformat="json",
                       state="",
                       **kwargs):
        """
        Parse a response

        :param response: Response type (a Message class to deserialize into)
        :param info: The response, can be either in a JSON or an urlencoded
            format
        :param sformat: Which serialization that was used
        :param state: The state; used as the grant key when the response
            carries no ``state`` of its own
        :param kwargs: Extra key word arguments, passed to deserialize and
            (for non-error responses) to verify
        :return: The parsed and to some extent verified response
        :raises PyoidcError: if verification of a non-error response fails
        :raises ResponseError: if no usable response could be produced
        """

        _r2e = self.response2error

        if sformat == "urlencoded":
            # A full URL is reduced to its query (or fragment) part.
            info = self.get_urlinfo(info)

        # if self.events:
        #    self.events.store('Response', info)
        resp = response().deserialize(info, sformat, **kwargs)
        msg = 'Initial response parsing => "{}"'
        logger.debug(msg.format(sanitize(resp.to_dict())))
        if self.events:
            self.events.store('Response', resp.to_dict())

        if "error" in resp and not isinstance(resp, ErrorResponse):
            # An error came back where a normal response was expected: try
            # the error classes registered for this response type (default
            # ErrorResponse) and keep the first that deserializes and verifies.
            resp = None
            try:
                errmsgs = _r2e[response.__name__]
            except KeyError:
                errmsgs = [ErrorResponse]

            try:
                for errmsg in errmsgs:
                    try:
                        resp = errmsg().deserialize(info, sformat)
                        resp.verify()
                        break
                    except Exception:
                        resp = None
            except KeyError:
                # NOTE(review): the inner handler already catches Exception,
                # so this looks unreachable -- confirm before removing.
                pass
        elif resp.only_extras():
            # A response for which only_extras() is true is treated as
            # missing (raises ResponseError below).
            resp = None
        else:
            kwargs["client_id"] = self.client_id
            try:
                kwargs['iss'] = self.provider_info['issuer']
            except (KeyError, AttributeError):
                if self.issuer:
                    kwargs['iss'] = self.issuer

            # Default to the client's own key jar unless the caller supplied
            # a key or key jar.
            if "key" not in kwargs and "keyjar" not in kwargs:
                kwargs["keyjar"] = self.keyjar

            logger.debug("Verify response with {}".format(sanitize(kwargs)))
            verf = resp.verify(**kwargs)

            if not verf:
                logger.error('Verification of the response failed')
                raise PyoidcError("Verification of the response failed")
            if resp.type() == "AuthorizationResponse" and "scope" not in resp:
                # Copy a caller-supplied ``scope`` kwarg onto the response
                # when the response itself has none.
                try:
                    resp["scope"] = kwargs["scope"]
                except KeyError:
                    pass

        if not resp:
            logger.error('Missing or faulty response')
            raise ResponseError("Missing or faulty response")

        self.store_response(resp, info)

        if resp.type() in ["AuthorizationResponse", "AccessTokenResponse"]:
            # Record the response on the grant keyed by the response's state
            # (falling back to the ``state`` argument), creating the grant
            # if none exists yet.
            try:
                _state = resp["state"]
            except (AttributeError, KeyError):
                _state = ""

            if not _state:
                _state = state

            try:
                self.grant[_state].update(resp)
            except KeyError:
                self.grant[_state] = self.grant_class(resp=resp)

        return resp

    def init_authentication_method(self,
                                   cis,
                                   authn_method,
                                   request_args=None,
                                   http_args=None,
                                   **kwargs):
        """
        Apply the client authentication method named ``authn_method``.

        The class registered in ``self.client_authn_method`` under that name
        is instantiated with this client, and its ``construct`` result is
        returned. With a falsy ``authn_method`` the HTTP args are returned
        unchanged (``{}`` when None).
        """
        http_args = {} if http_args is None else http_args
        request_args = {} if request_args is None else request_args

        if not authn_method:
            return http_args

        authn_cls = self.client_authn_method[authn_method]
        return authn_cls(self).construct(cis, request_args, http_args, **kwargs)

    def parse_request_response(self,
                               reqresp,
                               response,
                               body_type,
                               state="",
                               **kwargs):
        """
        Turn an HTTP response into a protocol message where possible.

        :param reqresp: The HTTP response (``status_code``, ``text`` and
            ``url`` are read)
        :param response: Message class expected in the body, or None
        :param body_type: Expected body serialization (e.g. 'json',
            'urlencoded', 'txt'); re-checked by verify_header on success
        :param state: Passed on to parse_response
        :return: The parsed message, the raw text for 'txt' bodies, an
            ErrorResponse if one can be parsed, or ``reqresp`` itself
            (always for 302/303)
        :raises ParseError: on HTTP 500
        :raises HttpError: on any status other than success, 302, 303,
            400, 401 or 500
        """
        status = reqresp.status_code

        if status in SUCCESSFUL:
            body_type = verify_header(reqresp, body_type)
        elif status in [302, 303]:  # redirect
            return reqresp
        elif status == 500:
            logger.error("(%d) %s" %
                         (reqresp.status_code, sanitize(reqresp.text)))
            raise ParseError("ERROR: Something went wrong: %s" % reqresp.text)
        elif status not in [400, 401]:
            logger.error("(%d) %s" %
                         (reqresp.status_code, sanitize(reqresp.text)))
            raise HttpError("HTTP ERROR: %s [%s] on %s" %
                            (reqresp.text, reqresp.status_code, reqresp.url))
        # 400/401 fall through: the body is expected to hold an error
        # response. (The old issubclass(response, ...) no-op was dropped; it
        # raised TypeError when ``response`` was None.)

        if response:
            if body_type == 'txt':
                # no meaning trying to parse unstructured text
                return reqresp.text
            return self.parse_response(response, reqresp.text, body_type,
                                       state, **kwargs)

        # could be an error response
        if status in [200, 400, 401]:
            if body_type == 'txt':
                body_type = 'urlencoded'
            try:
                # The body lives in ``text``; the previous ``reqresp.message``
                # does not exist on an HTTP response, so this parse always
                # failed silently.
                err = ErrorResponse().deserialize(reqresp.text,
                                                  method=body_type)
                err.verify()
            except Exception:
                # Not a well-formed error response: hand back the raw response.
                pass
            else:
                return err

        return reqresp

    def request_and_return(self,
                           url,
                           response=None,
                           method="GET",
                           body=None,
                           body_type="json",
                           state="",
                           http_args=None,
                           **kwargs):
        """
        Send an HTTP request and hand the reply to parse_request_response.

        :param url: The URL to which the request should be sent
        :param response: Response type
        :param method: Which HTTP method to use
        :param body: A message body if any
        :param body_type: The format of the body of the return message
        :param state: The state the request belongs to
        :param http_args: Arguments for the HTTP client
        :param kwargs: Extra arguments for response parsing; the client's
            keyjar is supplied unless a 'keyjar' is given
        :return: A cls or ErrorResponse instance or the HTTP response
            instance if no response body was expected.
        """
        if http_args is None:
            http_args = {}

        reqresp = self.http_request(url, method, data=body, **http_args)

        kwargs.setdefault("keyjar", self.keyjar)
        return self.parse_request_response(reqresp, response, body_type,
                                           state, **kwargs)

    def do_authorization_request(self,
                                 request=AuthorizationRequest,
                                 state="",
                                 body_type="",
                                 method="GET",
                                 request_args=None,
                                 extra_args=None,
                                 http_args=None,
                                 response_cls=AuthorizationResponse,
                                 **kwargs):
        """
        Send an authorization request and return the outcome.

        :param request: The request message class
        :param state: State to put in the request; also used for parsing
        :param body_type: Expected serialization of the reply
        :param method: HTTP method, GET by default
        :param request_args: Arguments for the request message (updated in
            place with 'state' when a state is given)
        :param extra_args: Extra arguments for the request message
        :param http_args: HTTP arguments, merged with those request_info
            produces (updated in place when given)
        :param response_cls: Expected response message class
        :param kwargs: Passed to request_info with 'authn_endpoint' set to
            'authorization'; an 'algs' entry is handed to request_and_return
        :return: Whatever request_and_return returns
        """
        if state:
            try:
                request_args["state"] = state
            except TypeError:
                request_args = {"state": state}

        kwargs['authn_endpoint'] = 'authorization'
        url, body, ht_args, csi = self.request_info(request, method,
                                                    request_args, extra_args,
                                                    **kwargs)

        # Remember the request under its state. Skipped when there are no
        # request args (TypeError), they carry no state (KeyError - this used
        # to escape), or authz_req cannot be indexed (TypeError).
        try:
            self.authz_req[request_args["state"]] = csi
        except (TypeError, KeyError):
            pass

        if http_args is None:
            http_args = ht_args
        else:
            http_args.update(ht_args)

        algs = kwargs.get("algs", {})

        resp = self.request_and_return(url,
                                       response_cls,
                                       method,
                                       body,
                                       body_type,
                                       state=state,
                                       http_args=http_args,
                                       algs=algs)

        if isinstance(resp, Message):
            # NOTE(review): RESPONSE2ERROR values look like lists of message
            # classes, while resp.type() is a class-name string, so this
            # membership test may never match - confirm the intent.
            if resp.type() in RESPONSE2ERROR["AuthorizationResponse"]:
                resp.state = csi.state

        return resp

    def do_access_token_request(self,
                                request=AccessTokenRequest,
                                scope="",
                                state="",
                                body_type="json",
                                method="POST",
                                request_args=None,
                                extra_args=None,
                                http_args=None,
                                response_cls=AccessTokenResponse,
                                authn_method="",
                                **kwargs):
        """
        Send an access token request and parse the reply.

        :param request: The request message class
        :param scope: Scope handed to request_info
        :param state: The state the request belongs to
        :param body_type: Expected serialization of the reply
        :param method: HTTP method, POST by default
        :param request_args: Arguments for the request message
        :param extra_args: Extra arguments for the request message
        :param http_args: HTTP arguments, merged with those request_info
            produces (updated in place when given)
        :param response_cls: Expected response message class
        :param authn_method: Client authentication method to use
        :param kwargs: Passed to both request_info and request_and_return,
            with 'authn_endpoint' set to 'token'
        :return: Whatever request_and_return returns
        """
        kwargs['authn_endpoint'] = 'token'
        url, body, extra_http_args, _csi = self.request_info(
            request,
            method=method,
            request_args=request_args,
            extra_args=extra_args,
            scope=scope,
            state=state,
            authn_method=authn_method,
            **kwargs)

        if http_args is not None:
            http_args.update(extra_http_args)
        else:
            http_args = extra_http_args

        # Record what is about to be sent when an event store is attached.
        if self.events is not None:
            for label, value in (('request_url', url),
                                 ('request_http_args', http_args),
                                 ('Request', body)):
                self.events.store(label, value)

        logger.debug("<do_access_token> URL: %s, Body: %s" %
                     (url, sanitize(body)))
        logger.debug("<do_access_token> response_cls: %s" % response_cls)

        return self.request_and_return(url, response_cls, method, body,
                                       body_type, state=state,
                                       http_args=http_args, **kwargs)

    def do_access_token_refresh(self,
                                request=RefreshAccessTokenRequest,
                                state="",
                                body_type="json",
                                method="POST",
                                request_args=None,
                                extra_args=None,
                                http_args=None,
                                response_cls=AccessTokenResponse,
                                authn_method="",
                                **kwargs):
        """
        Send a refresh request based on a stored (possibly expired) token.

        :param request: The request message class
        :param state: The state whose token is refreshed
        :param body_type: Expected serialization of the reply
        :param method: HTTP method, POST by default
        :param request_args: Arguments for the request message
        :param extra_args: Extra arguments for the request message
        :param http_args: HTTP arguments, merged with those request_info
            produces (updated in place when given)
        :param response_cls: Expected response message class
        :param authn_method: Client authentication method to use
        :param kwargs: Handed to get_token when looking up the token
        :return: Whatever request_and_return returns
        """
        token = self.get_token(also_expired=True, state=state, **kwargs)
        # Tell request_info (and through it the authn method) which endpoint
        # is being authenticated to, as do_access_token_request does with
        # 'token'. This used to be written into kwargs and never forwarded.
        url, body, ht_args, csi = self.request_info(request,
                                                    method=method,
                                                    request_args=request_args,
                                                    extra_args=extra_args,
                                                    token=token,
                                                    authn_method=authn_method,
                                                    authn_endpoint='refresh')

        if http_args is None:
            http_args = ht_args
        else:
            http_args.update(ht_args)

        return self.request_and_return(url,
                                       response_cls,
                                       method,
                                       body,
                                       body_type,
                                       state=state,
                                       http_args=http_args)

    # def do_revocate_token(self, request=TokenRevocationRequest,
    #                       scope="", state="", body_type="json", method="POST",
    #                       request_args=None, extra_args=None, http_args=None,
    #                       response_cls=None, authn_method=""):
    #
    #     url, body, ht_args, csi = self.request_info(request, method=method,
    #                                                 request_args=request_args,
    #                                                 extra_args=extra_args,
    #                                                 scope=scope, state=state,
    #                                                 authn_method=authn_method)
    #
    #     if http_args is None:
    #         http_args = ht_args
    #     else:
    #         http_args.update(ht_args)
    #
    #     return self.request_and_return(url, response_cls, method, body,
    #                                    body_type, state=state,
    #                                    http_args=http_args)

    def do_any(self,
               request,
               endpoint="",
               scope="",
               state="",
               body_type="json",
               method="POST",
               request_args=None,
               extra_args=None,
               http_args=None,
               response=None,
               authn_method=""):
        """
        Send an arbitrary request message and parse the reply.

        :param request: The request message class
        :param endpoint: Passed to request_info as 'endpoint'
        :param scope: Scope handed to request_info
        :param state: The state the request belongs to
        :param body_type: Expected serialization of the reply
        :param method: HTTP method, POST by default
        :param request_args: Arguments for the request message
        :param extra_args: Extra arguments for the request message
        :param http_args: HTTP arguments, merged with those request_info
            produces (updated in place when given)
        :param response: Expected response message class, if any
        :param authn_method: Client authentication method to use
        :return: Whatever request_and_return returns
        """
        url, body, extra_http_args, _csi = self.request_info(
            request,
            method=method,
            request_args=request_args,
            extra_args=extra_args,
            scope=scope,
            state=state,
            authn_method=authn_method,
            endpoint=endpoint)

        if http_args is not None:
            http_args.update(extra_http_args)
        else:
            http_args = extra_http_args

        return self.request_and_return(url, response, method, body,
                                       body_type, state=state,
                                       http_args=http_args)

    def fetch_protected_resource(self,
                                 uri,
                                 method="GET",
                                 headers=None,
                                 state="",
                                 **kwargs):
        """
        Fetch a protected resource, presenting an access token.

        :param uri: The resource URL
        :param method: HTTP method
        :param headers: Extra HTTP headers (updated in place when given)
        :param state: Selects the stored token when no 'token' is given
        :param kwargs: A truthy 'token' is used as the access token as-is;
            'authn_method' picks the client authn method (default is the
            'bearer_header' method); the rest goes to get_token and the
            authn method
        :return: The HTTP response
        """
        if kwargs.get("token"):
            token = kwargs["token"]
            request_args = {"access_token": token}
        else:
            # A falsy 'token' entry must not reach get_token, where it would
            # be returned as the token instead of consulting the grant.
            kwargs.pop("token", None)
            try:
                token = self.get_token(state=state, **kwargs)
            except ExpiredToken:
                # The token is too old: refresh the one bound to this state
                self.do_access_token_refresh(state=state)
                token = self.get_token(state=state, **kwargs)
            request_args = {"access_token": token.access_token}

        if headers is None:
            headers = {}

        if "authn_method" in kwargs:
            # init_authentication_method requires a request message as its
            # first argument; there is none here, so pass None explicitly
            # (omitting it was a TypeError).
            http_args = self.init_authentication_method(
                None, request_args=request_args, **kwargs)
        else:
            # If nothing defined this is the default
            http_args = self.client_authn_method["bearer_header"](
                self).construct(request_args=request_args)

        # An authn method (or a falsy authn_method) may yield no headers.
        headers.update(http_args.get("headers", {}))

        logger.debug("Fetch URI: %s" % uri)
        return self.http_request(uri, method, headers=headers)

    def add_code_challenge(self):
        """
        PKCE RFC 7636 support.

        Generates a new code verifier and derives the code challenge from
        it as BASE64URL(HASH(ASCII(code_verifier))) over the raw digest
        bytes (RFC 7636, section 4.2).

        The verifier length and transformation method come from
        self.config['code_challenge'] ('length', default 64; 'method',
        default 'S256').

        :return: A tuple (dict with 'code_challenge' and
            'code_challenge_method', code_verifier)
        :raises Unsupported: If the transformation method is not in CC_METHOD
        """
        try:
            cv_len = self.config['code_challenge']['length']
        except KeyError:
            cv_len = 64  # Use default

        code_verifier = unreserved(cv_len)
        _cv = code_verifier.encode()

        try:
            _method = self.config['code_challenge']['method']
        except KeyError:
            _method = 'S256'

        try:
            _hash = CC_METHOD[_method]
        except KeyError:
            raise Unsupported('PKCE Transformation method:{}'.format(_method))

        # RFC 7636 base64url-encodes the raw hash bytes. Encoding the hex
        # digest produced a challenge a compliant server rejects.
        # NOTE(review): assumes b64e yields unpadded base64url bytes, as the
        # RFC requires - confirm.
        code_challenge = b64e(_hash(_cv).digest()).decode()

        # TODO store code_verifier

        return {
            "code_challenge": code_challenge,
            "code_challenge_method": _method
        }, code_verifier

    def handle_provider_config(self, pcr, issuer, keys=True, endpoints=True):
        """
        Deal with Provider Config Response.

        :param pcr: The ProviderConfigResponse instance
        :param issuer: The one I thought should be the issuer of the config
        :param keys: Should I deal with keys
        :param endpoints: Should I deal with endpoints, that is store them
            as attributes in self.
        :raises PyoidcError: If the advertised issuer differs from the
            expected one and no 'issuer_mismatch' entry is in self.allow
        """
        if "issuer" in pcr:
            _pcr_issuer = pcr["issuer"]
            # Bring the expected issuer to the provider's trailing-slash
            # convention before comparing.
            _issuer = issuer[:-1] if issuer.endswith("/") else issuer
            if _pcr_issuer.endswith("/"):
                _issuer += "/"

            # An explicit check instead of assert: asserts are removed under
            # `python -O`, which silently disabled this validation. Any
            # 'issuer_mismatch' entry in self.allow disables the check.
            if "issuer_mismatch" not in self.allow and _issuer != _pcr_issuer:
                raise PyoidcError(
                    "provider info issuer mismatch '%s' != '%s'" %
                    (_issuer, _pcr_issuer))

            self.provider_info = pcr
        else:
            _pcr_issuer = issuer

        self.issuer = _pcr_issuer

        if endpoints:
            for key, val in pcr.items():
                if key.endswith("_endpoint"):
                    setattr(self, key, val)

        if keys:
            if self.keyjar is None:
                self.keyjar = KeyJar()

            self.keyjar.load_keys(pcr, _pcr_issuer)

    def provider_config(self,
                        issuer,
                        keys=True,
                        endpoints=True,
                        response_cls=ASConfigurationResponse,
                        serv_pattern=OIDCONF_PATTERN,
                        max_redirects=10):
        """
        Fetch the provider configuration document and process it.

        :param issuer: The expected issuer; one trailing '/' is dropped
            before it is filled into serv_pattern
        :param keys: Passed to handle_provider_config
        :param endpoints: Passed to handle_provider_config
        :param response_cls: Message class the JSON document is parsed into
        :param serv_pattern: %-format pattern building the config URL
        :param max_redirects: How many HTTP 302 redirects to follow
        :return: The parsed configuration response
        :raises PyoidcError: If no HTTP 200 is reached
        """
        _issuer = issuer[:-1] if issuer.endswith("/") else issuer
        url = serv_pattern % _issuer

        r = self.http_request(url)
        # Follow 302 redirects, but only a bounded number of times: a
        # redirect cycle used to make this loop forever.
        hops = 0
        while r.status_code == 302 and hops < max_redirects:
            r = self.http_request(r.headers["location"])
            hops += 1

        if r.status_code != 200:
            raise PyoidcError("Trying '%s', status %s" % (url, r.status_code))

        pcr = response_cls().from_json(r.text)
        self.handle_provider_config(pcr, issuer, keys, endpoints)

        return pcr
Ejemplo n.º 16
0
class Client(PBase):
    _endpoints = ENDPOINTS

    def __init__(self, client_id=None, ca_certs=None, client_authn_method=None,
                 keyjar=None, verify_ssl=True):
        """
        Create a client with no grants and no known service endpoints.

        :param client_id: The client identifier
        :param ca_certs: Certificates used to verify HTTPS certificates
        :param client_authn_method: Methods that this client can use to
            authenticate itself. It's a dictionary with method names as
            keys and method classes as values.
        :param keyjar: Key jar to use; when falsy a new KeyJar honouring
            verify_ssl is created
        :param verify_ssl: Whether the SSL certificate should be verified.
        :return: Client instance
        """
        PBase.__init__(self, ca_certs, verify_ssl=verify_ssl)

        # Identity and key material
        self.client_id = client_id
        self.client_authn_method = client_authn_method
        self.keyjar = keyjar if keyjar else KeyJar(verify_ssl=verify_ssl)
        self.verify_ssl = verify_ssl

        self.nonce = None
        self.grant = {}

        # The client's own endpoint
        self.redirect_uris = [None]

        # Service endpoints, unknown until configured
        for endpoint_attr in ("authorization_endpoint", "token_endpoint",
                              "token_revocation_endpoint"):
            setattr(self, endpoint_attr, None)

        # Lookup tables and the classes used for grants and tokens
        self.request2endpoint = REQUEST2ENDPOINT
        self.response2error = RESPONSE2ERROR
        self.grant_class = Grant
        self.token_class = Token

        self.provider_info = {}
        self._c_secret = None
        self.kid = {"sig": {}, "enc": {}}
        self.authz_req = None

    def store_response(self, clinst, text):
        """
        Hook receiving each parsed response together with its raw text.

        The base implementation intentionally does nothing; subclasses can
        override it to keep responses around.
        """

    def get_client_secret(self):
        """Return the stored client secret (None until one has been set)."""
        secret = self._c_secret
        return secret

    def set_client_secret(self, val):
        """
        Store the client secret and register it as a symmetric key.

        A falsy value is stored as "" and no key is added. Otherwise the
        secret is added to the keyjar (created first if missing) under the
        owner "" for use "sig". The client signs with it; a server may sign
        with it too, in which case the client verifies with it.
        """
        if val:
            self._c_secret = val
            if self.keyjar is None:
                self.keyjar = KeyJar()
            self.keyjar.add_symmetric("", str(val), ["sig"])
        else:
            self._c_secret = ""

    client_secret = property(get_client_secret, set_client_secret)

    def reset(self):
        """Forget the nonce, all grants, the authorization/token endpoints
        and the redirect URIs."""
        self.nonce = None
        self.grant = {}
        for endpoint_attr in ("authorization_endpoint", "token_endpoint"):
            setattr(self, endpoint_attr, None)
        self.redirect_uris = [None]

    def grant_from_state(self, state):
        """
        Return the grant stored under *state*, or None if there is none.

        Uses a direct dictionary lookup rather than scanning every key.
        """
        return self.grant.get(state)

    def _parse_args(self, request, **kwargs):
        """
        Fill request parameters the caller left out from client attributes.

        For each parameter in request.c_param that is not in kwargs, a
        truthy default is taken from the client: redirect_uri from the first
        entry of self.redirect_uris, anything else from the attribute of the
        same name.

        :return: A new dict with the caller's arguments plus those defaults
        """
        ar_args = kwargs.copy()

        for prop in request.c_param.keys():
            if prop in ar_args:
                continue
            if prop == "redirect_uri":
                default = getattr(self, "redirect_uris", [None])[0]
            else:
                default = getattr(self, prop, None)
            if default:
                ar_args[prop] = default

        return ar_args

    def _endpoint(self, endpoint, **kwargs):
        """
        Resolve the URL of the named endpoint.

        A truthy URL passed in kwargs under the endpoint's name wins;
        otherwise the client attribute with that name is used.

        :raises MissingEndpoint: If neither source yields a URL
        """
        uri = kwargs.get(endpoint, "")

        if not uri:
            try:
                uri = getattr(self, endpoint)
            except Exception:
                raise MissingEndpoint("No '%s' specified" % endpoint)

        if not uri:
            raise MissingEndpoint("No '%s' specified" % endpoint)

        return uri

    def get_grant(self, state, **kwargs):
        """
        Return the grant stored under *state*.

        :param state: The state the grant was stored under
        :param kwargs: Ignored; accepted so callers can forward their kwargs
        :raises GrantError: If no grant exists for that state
        """
        try:
            return self.grant[state]
        except (KeyError, TypeError):
            # KeyError: unknown state; TypeError: unhashable state. The old
            # bare except also swallowed KeyboardInterrupt and SystemExit.
            raise GrantError("No grant found for state:'%s'" % state)

    def get_token(self, also_expired=False, **kwargs):
        """
        Return a token: the one passed in, or one from a stored grant.

        :param also_expired: Return the token even if it is no longer valid
        :param kwargs: A 'token' entry is returned as-is without any checks;
            otherwise 'state' selects the grant (see get_grant) and 'scope'
            the token within it, falling back to the "" scope
        :raises TokenError: If no token is found or it has expired
        """
        if "token" in kwargs:
            return kwargs["token"]

        grant = self.get_grant(**kwargs)
        try:
            token = grant.get_token(kwargs["scope"])
        except KeyError:
            token = grant.get_token("")
            if not token:
                try:
                    token = self.grant[kwargs["state"]].get_token("")
                except KeyError:
                    raise TokenError("No token found for scope")

        if token is None:
            raise TokenError("No suitable token found")

        if also_expired or token.is_valid():
            return token
        raise TokenError("Token has expired")

    def construct_request(self, request, request_args=None, extra_args=None):
        """
        Instantiate *request* from the given arguments plus client defaults.

        :param request: The request message class
        :param request_args: Arguments for the message; parameters not given
            are filled from client attributes by _parse_args
        :param extra_args: Arguments applied last, overriding the others
        :return: An instance of *request*
        """
        given = {} if request_args is None else request_args
        message_args = self._parse_args(request, **given)

        if extra_args:
            message_args.update(extra_args)

        return request(**message_args)

    def construct_Message(self, request=Message, request_args=None,
                          extra_args=None, **kwargs):
        """Build a generic message; extra keyword arguments are ignored."""
        return self.construct_request(request, request_args=request_args,
                                      extra_args=extra_args)

    # noinspection PyUnusedLocal
    def construct_AuthorizationRequest(self, request=AuthorizationRequest,
                                       request_args=None, extra_args=None,
                                       **kwargs):
        """
        Build an authorization request.

        A truthy 'redirect_uri' in request_args becomes the client's only
        redirect URI. 'client_id' defaults to self.client_id when absent or
        falsy. A given request_args dict is modified in place.
        """
        if request_args is None:
            request_args = {}
        elif request_args.get("redirect_uri"):
            # change default
            self.redirect_uris = [request_args["redirect_uri"]]

        if not request_args.get("client_id"):
            request_args["client_id"] = self.client_id

        return self.construct_request(request, request_args, extra_args)

    # noinspection PyUnusedLocal
    def construct_AccessTokenRequest(self,
                                     request=AccessTokenRequest,
                                     request_args=None, extra_args=None,
                                     **kwargs):
        """
        Build an access token request from the grant selected by kwargs.

        The grant's code is always used. 'grant_type' defaults to
        "authorization_code", and 'client_id' to self.client_id when absent
        or falsy. A given request_args dict is modified in place.

        :raises GrantExpired: If the grant is no longer valid
        """
        grant = self.get_grant(**kwargs)

        if not grant.is_valid():
            raise GrantExpired("Authorization Code to old %s > %s" % (
                utc_time_sans_frac(),
                grant.grant_expiration_time))

        request_args = {} if request_args is None else request_args

        request_args["code"] = grant.code
        request_args.setdefault("grant_type", "authorization_code")
        if not request_args.get("client_id"):
            request_args["client_id"] = self.client_id

        return self.construct_request(request, request_args, extra_args)

    def construct_RefreshAccessTokenRequest(self,
                                            request=RefreshAccessTokenRequest,
                                            request_args=None, extra_args=None,
                                            **kwargs):
        """
        Build a refresh request from a stored token, expired or not.

        The token's refresh_token is used, along with its scope when the
        token has a 'scope' attribute.
        """
        request_args = {} if request_args is None else request_args

        old_token = self.get_token(also_expired=True, **kwargs)
        request_args["refresh_token"] = old_token.refresh_token

        try:
            token_scope = old_token.scope
        except AttributeError:
            pass
        else:
            request_args["scope"] = token_scope

        return self.construct_request(request, request_args, extra_args)

    def construct_TokenRevocationRequest(self,
                                         request=TokenRevocationRequest,
                                         request_args=None, extra_args=None,
                                         **kwargs):
        """Build a revocation request for the access token chosen by kwargs."""
        request_args = {} if request_args is None else request_args
        request_args["token"] = self.get_token(**kwargs).access_token
        return self.construct_request(request, request_args, extra_args)

    def construct_ResourceRequest(self, request=ResourceRequest,
                                  request_args=None, extra_args=None,
                                  **kwargs):
        """Build a resource request carrying the access token chosen by
        kwargs."""
        request_args = {} if request_args is None else request_args
        request_args["access_token"] = self.get_token(**kwargs).access_token
        return self.construct_request(request, request_args, extra_args)

    def uri_and_body(self, reqmsg, cis, method="POST", request_args=None,
                     **kwargs):

        if "endpoint" in kwargs and kwargs["endpoint"]:
            uri = kwargs["endpoint"]
        else:
            uri = self._endpoint(self.request2endpoint[reqmsg.__name__],
                                 **request_args)

        uri, body, kwargs = get_or_post(uri, method, cis, **kwargs)
        try:
            h_args = {"headers": kwargs["headers"]}
        except KeyError:
            h_args = {}

        return uri, body, h_args, cis

    def request_info(self, request, method="POST", request_args=None,
                     extra_args=None, lax=False, **kwargs):

        if request_args is None:
            request_args = {}

        try:
            cls = getattr(self, "construct_%s" % request.__name__)
            cis = cls(request_args=request_args, extra_args=extra_args,
                      **kwargs)
        except AttributeError:
            cis = self.construct_request(request, request_args, extra_args)

        cis.lax = lax

        if "authn_method" in kwargs:
            h_arg = self.init_authentication_method(cis,
                                                    request_args=request_args,
                                                    **kwargs)
        else:
            h_arg = None

        if h_arg:
            if "headers" in kwargs.keys():
                kwargs["headers"].update(h_arg["headers"])
            else:
                kwargs["headers"] = h_arg["headers"]

        return self.uri_and_body(request, cis, method, request_args,
                                 **kwargs)

    def authorization_request_info(self, request_args=None, extra_args=None,
                                   **kwargs):
        """Return request_info() for an AuthorizationRequest sent with GET."""
        return self.request_info(AuthorizationRequest, method="GET",
                                 request_args=request_args,
                                 extra_args=extra_args, **kwargs)

    def get_urlinfo(self, info):
        """
        Extract the parameter string from a full URL.

        If *info* contains '?' or '#', it is parsed as a URL and its query
        is returned, or its fragment when the query is empty. Any other
        input is returned unchanged.
        """
        if '?' not in info and '#' not in info:
            return info

        parsed = urlparse(info)
        # either query or fragment
        return parsed.query or parsed.fragment

    def parse_response(self, response, info="", sformat="json", state="",
                       **kwargs):
        """
        Parse a response

        :param response: Response type
        :param info: The response, can be either in a JSON or an urlencoded
            format
        :param sformat: Which serialization that was used
        :param state: The state
        :param kwargs: Extra key word arguments
        :return: The parsed and to some extend verified response
        """

        _r2e = self.response2error

        if sformat == "urlencoded":
            info = self.get_urlinfo(info)

        resp = response().deserialize(info, sformat, **kwargs)
        if "error" in resp and not isinstance(resp, ErrorResponse):
            resp = None
            try:
                errmsgs = _r2e[response.__name__]
            except KeyError:
                errmsgs = [ErrorResponse]

            try:
                for errmsg in errmsgs:
                    try:
                        resp = errmsg().deserialize(info, sformat)
                        resp.verify()
                        break
                    except Exception as aerr:
                        resp = None
                        err = aerr
            except KeyError:
                pass
        elif resp.only_extras():
            resp = None
        else:
            kwargs["client_id"] = self.client_id
            if "key" not in kwargs and "keyjar" not in kwargs:
                kwargs["keyjar"] = self.keyjar

            logger.debug("Verify response with {}".format(kwargs))
            verf = resp.verify(**kwargs)

            if not verf:
                raise PyoidcError("Verification of the response failed")
            if resp.type() == "AuthorizationResponse" and \
                    "scope" not in resp:
                try:
                    resp["scope"] = kwargs["scope"]
                except KeyError:
                    pass

        if not resp:
            raise ResponseError("Missing or faulty response")

        self.store_response(resp, info)

        if resp.type() in ["AuthorizationResponse", "AccessTokenResponse"]:
            try:
                _state = resp["state"]
            except (AttributeError, KeyError):
                _state = ""

            if not _state:
                _state = state

            try:
                self.grant[_state].update(resp)
            except KeyError:
                self.grant[_state] = self.grant_class(resp=resp)

        return resp

    # noinspection PyUnusedLocal
    def init_authentication_method(self, cis, authn_method, request_args=None,
                                   http_args=None, **kwargs):
        """
        Produce the HTTP arguments that authenticate this client.

        :param cis: The request message being sent
        :param authn_method: Key into self.client_authn_method; a falsy
            value means no client authentication is applied
        :param request_args: Request arguments handed to the authn method
        :param http_args: HTTP arguments handed to the authn method
        :return: What the authn method's construct() returns, or the
            (possibly freshly created) http_args dict if no method is used
        """
        http_args = {} if http_args is None else http_args
        request_args = {} if request_args is None else request_args

        if not authn_method:
            return http_args

        authn_cls = self.client_authn_method[authn_method]
        return authn_cls(self).construct(cis, request_args, http_args,
                                         **kwargs)

    def parse_request_response(self, reqresp, response, body_type, state="",
                               **kwargs):

        if reqresp.status_code in SUCCESSFUL:
            body_type = verify_header(reqresp, body_type)
        elif reqresp.status_code == 302:  # redirect
            pass
        elif reqresp.status_code == 500:
            logger.error("(%d) %s" % (reqresp.status_code, reqresp.text))
            raise ParseError("ERROR: Something went wrong: %s" % reqresp.text)
        elif reqresp.status_code in [400, 401]:
            # expecting an error response
            if issubclass(response, ErrorResponse):
                pass
        else:
            logger.error("(%d) %s" % (reqresp.status_code, reqresp.text))
            raise HttpError("HTTP ERROR: %s [%s] on %s" % (
                reqresp.text, reqresp.status_code, reqresp.url))

        if body_type:
            if response:
                return self.parse_response(response, reqresp.text, body_type,
                                           state, **kwargs)
            else:
                raise OtherError("Didn't expect a response body")
        else:
            return reqresp

    def request_and_return(self, url, response=None, method="GET", body=None,
                           body_type="json", state="", http_args=None,
                           **kwargs):
        """
        :param url: The URL to which the request should be sent
        :param response: Response type
        :param method: Which HTTP method to use
        :param body: A message body if any
        :param body_type: The format of the body of the return message
        :param http_args: Arguments for the HTTP client
        :return: A cls or ErrorResponse instance or the HTTP response
            instance if no response body was expected.
        """

        if http_args is None:
            http_args = {}

        try:
            resp = self.http_request(url, method, data=body, **http_args)
        except Exception:
            raise

        if "keyjar" not in kwargs:
            kwargs["keyjar"] = self.keyjar

        return self.parse_request_response(resp, response, body_type, state,
                                           **kwargs)

    def do_authorization_request(self, request=AuthorizationRequest,
                                 state="", body_type="", method="GET",
                                 request_args=None, extra_args=None,
                                 http_args=None,
                                 response_cls=AuthorizationResponse,
                                 **kwargs):
        """
        Send an authorization request and return the outcome.

        :param request: The request message class
        :param state: State to put in the request; also used for parsing
        :param body_type: Expected serialization of the reply
        :param method: HTTP method, GET by default
        :param request_args: Arguments for the request message (updated in
            place with 'state' when a state is given)
        :param extra_args: Extra arguments for the request message
        :param http_args: HTTP arguments, merged with those request_info
            produces (updated in place when given)
        :param response_cls: Expected response message class
        :param kwargs: Passed to request_info; an 'algs' entry is handed to
            request_and_return
        :return: Whatever request_and_return returns
        """
        if state:
            # request_args defaults to None, which made this raise TypeError
            if request_args is None:
                request_args = {}
            request_args["state"] = state

        url, body, ht_args, csi = self.request_info(request, method,
                                                    request_args, extra_args,
                                                    **kwargs)

        # Remember the request under its state. Skipped when there are no
        # request args (TypeError), they carry no state (KeyError - this used
        # to escape), or authz_req cannot be indexed (TypeError).
        try:
            self.authz_req[request_args["state"]] = csi
        except (TypeError, KeyError):
            pass

        if http_args is None:
            http_args = ht_args
        else:
            http_args.update(ht_args)

        algs = kwargs.get("algs", {})

        resp = self.request_and_return(url, response_cls, method, body,
                                       body_type, state=state,
                                       http_args=http_args, algs=algs)

        if isinstance(resp, Message):
            # RESPONSE2ERROR is keyed by response class name (parse_response
            # looks it up with response.__name__). "AuthorizationRequest" is
            # presumably not a key, so that lookup could raise KeyError.
            # NOTE(review): the listed entries are message classes
            # (parse_response instantiates them) while resp.type() is a name
            # string, so this test may never match - confirm intent.
            errors = RESPONSE2ERROR.get("AuthorizationResponse", [])
            if resp.type() in errors:
                resp.state = csi.state

        return resp

    def do_access_token_request(self, request=AccessTokenRequest,
                                scope="", state="", body_type="json",
                                method="POST", request_args=None,
                                extra_args=None, http_args=None,
                                response_cls=AccessTokenResponse,
                                authn_method="", **kwargs):
        """
        Send an access token request and parse the reply.

        The request is built by ``request_info`` (POST unless told
        otherwise) and sent by ``request_and_return``. Extra keyword
        arguments are handed to both.

        :param http_args: Arguments for the HTTP client; updated in place
            with the ones produced by ``request_info``
        :return: The result of ``request_and_return``
        """
        url, body, extra_http_args, _ = self.request_info(
            request, method=method, request_args=request_args,
            extra_args=extra_args, scope=scope, state=state,
            authn_method=authn_method, **kwargs)

        if http_args is not None:
            http_args.update(extra_http_args)
        else:
            http_args = extra_http_args

        logger.debug("<do_access_token> URL: %s, Body: %s", url, body)
        logger.debug("<do_access_token> response_cls: %s", response_cls)

        return self.request_and_return(url, response_cls, method, body,
                                       body_type, state=state,
                                       http_args=http_args, **kwargs)

    def do_access_token_refresh(self, request=RefreshAccessTokenRequest,
                                state="", body_type="json", method="POST",
                                request_args=None, extra_args=None,
                                http_args=None,
                                response_cls=AccessTokenResponse,
                                authn_method="", **kwargs):
        """
        Refresh the access token that belongs to ``state``.

        The current token is fetched even when it has expired, and the
        refresh request is built from it.

        :param http_args: Arguments for the HTTP client; updated in place
            with the ones produced by ``request_info``
        :param kwargs: Passed on to ``get_token``
        :return: The result of ``request_and_return``
        """
        current = self.get_token(also_expired=True, state=state, **kwargs)

        url, body, extra_http_args, _ = self.request_info(
            request, method=method, request_args=request_args,
            extra_args=extra_args, token=current,
            authn_method=authn_method)

        if http_args is not None:
            http_args.update(extra_http_args)
        else:
            http_args = extra_http_args

        return self.request_and_return(url, response_cls, method, body,
                                       body_type, state=state,
                                       http_args=http_args)

    def do_revocate_token(self, request=TokenRevocationRequest,
                          scope="", state="", body_type="json", method="POST",
                          request_args=None, extra_args=None, http_args=None,
                          response_cls=None, authn_method=""):
        """
        Send a token revocation request.

        :param http_args: Arguments for the HTTP client; updated in place
            with the ones produced by ``request_info``
        :param response_cls: Expected response class, None if no body is
            expected
        :return: The result of ``request_and_return``
        """
        url, body, extra_http_args, _ = self.request_info(
            request, method=method, request_args=request_args,
            extra_args=extra_args, scope=scope, state=state,
            authn_method=authn_method)

        if http_args is not None:
            http_args.update(extra_http_args)
        else:
            http_args = extra_http_args

        return self.request_and_return(url, response_cls, method, body,
                                       body_type, state=state,
                                       http_args=http_args)

    def do_any(self, request, endpoint="", scope="", state="", body_type="json",
               method="POST", request_args=None, extra_args=None,
               http_args=None, response=None, authn_method=""):
        """
        Send an arbitrary request message and parse the reply.

        :param request: Request message class
        :param endpoint: Explicit endpoint URL, passed to ``request_info``
        :param http_args: Arguments for the HTTP client; updated in place
            with the ones produced by ``request_info``
        :param response: Expected response class, or None
        :return: The result of ``request_and_return``
        """
        url, body, extra_http_args, _ = self.request_info(
            request, method=method, request_args=request_args,
            extra_args=extra_args, scope=scope, state=state,
            authn_method=authn_method, endpoint=endpoint)

        if http_args is not None:
            http_args.update(extra_http_args)
        else:
            http_args = extra_http_args

        return self.request_and_return(url, response, method, body, body_type,
                                       state=state, http_args=http_args)

    def fetch_protected_resource(self, uri, method="GET", headers=None,
                                 state="", **kwargs):
        """
        Access a protected resource with an access token.

        :param uri: URL of the protected resource
        :param method: HTTP method
        :param headers: Extra HTTP headers; the authentication headers are
            added to this dictionary
        :param state: State whose grant holds the access token
        :param kwargs: ``token`` uses the given access token directly, and
            ``authn_method`` chooses how the token is presented. All of
            kwargs is passed on to ``get_token`` and
            ``init_authentication_method``.
        :return: The result of ``self.http_request``
        """
        if kwargs.get("token"):
            token = kwargs["token"]
            request_args = {"access_token": token}
        else:
            try:
                token = self.get_token(state=state, **kwargs)
            except ExpiredToken:
                # The token is too old: refresh the token of *this* state
                # (not the default one) and look it up again.
                self.do_access_token_refresh(state=state)
                token = self.get_token(state=state, **kwargs)
            request_args = {"access_token": token.access_token}

        if headers is None:
            headers = {}

        if "authn_method" in kwargs:
            http_args = self.init_authentication_method(
                request_args=request_args, **kwargs)
        else:
            # If nothing defined this is the default
            http_args = self.client_authn_method[
                "bearer_header"](self).construct(request_args=request_args)

        headers.update(http_args["headers"])

        logger.debug("Fetch URI: %s" % uri)
        return self.http_request(uri, method, headers=headers)
Ejemplo n.º 17
0
class Client(PBase):
    """
    OAuth2 client.

    Builds protocol requests, sends them with ``self.http_request`` and
    parses and verifies the replies. Grants taken from authorization and
    access token responses are kept in ``self.grant``, keyed by state.
    """

    _endpoints = ENDPOINTS

    def __init__(
        self,
        client_id=None,
        client_authn_method=None,
        keyjar=None,
        verify_ssl=True,
        config=None,
        client_cert=None,
        timeout=5,
        message_factory: Type[MessageFactory] = OauthMessageFactory,
    ):
        """
        Initialize the instance.

        :param client_id: The client identifier
        :param client_authn_method: Methods that this client can use to
            authenticate itself. It's a dictionary with method names as
            keys and method classes as values.
        :param keyjar: The keyjar for this client.
        :param verify_ssl: Whether the SSL certificate should be verified.
        :param config: Optional configuration dictionary. This class reads
            ``issuer`` and ``code_challenge`` from it.
        :param client_cert: A client certificate to use.
        :param timeout: Timeout for requests library. Can be specified either as
            a single integer or as a tuple of integers. For more details, refer to
            ``requests`` documentation.
        :param message_factory: Factory for message classes, should inherit
            from OauthMessageFactory
        :return: Client instance
        """
        PBase.__init__(
            self,
            verify_ssl=verify_ssl,
            keyjar=keyjar,
            client_cert=client_cert,
            timeout=timeout,
        )

        self.client_id = client_id
        self.client_authn_method = client_authn_method

        self.nonce = None

        self.message_factory = message_factory
        self.grant = {}  # type: Dict[str, Grant]
        self.state2nonce = {}  # type: Dict[str, str]
        # own endpoint
        self.redirect_uris = []  # type: List[str]

        # service endpoints
        self.authorization_endpoint = None  # type: Optional[str]
        self.token_endpoint = None  # type: Optional[str]
        self.token_revocation_endpoint = None  # type: Optional[str]

        self.request2endpoint = REQUEST2ENDPOINT
        self.response2error = RESPONSE2ERROR  # type: Dict[str, List]
        self.grant_class = Grant
        self.token_class = Token

        self.provider_info = ASConfigurationResponse()  # type: Message
        self._c_secret = ""  # type: str
        self.kid = {"sig": {}, "enc": {}}  # type: Dict[str, Dict]
        self.authz_req = {}  # type: Dict[str, Message]

        # the OAuth issuer is the URL of the authorization server's
        # configuration information location
        self.config = config or {}
        try:
            self.issuer = self.config["issuer"]
        except KeyError:
            self.issuer = ""
        self.allow = {}  # type: Dict[str, Any]

    def store_response(self, clinst, text):
        """
        Receive every parsed response together with its raw text.

        This is a no-op here, presumably a hook for subclasses.
        """
        pass

    def get_client_secret(self) -> str:
        """Return the client secret ("" when none has been set)."""
        return self._c_secret

    def set_client_secret(self, val: str):
        """
        Set the client secret.

        A non-empty secret is also added to the key jar as a symmetric key
        for the owner id ``""``. The key jar is created if there is none.
        """
        if not val:
            self._c_secret = ""
        else:
            self._c_secret = val
            # client uses it for signing
            # Server might also use it for signing which means the
            # client uses it for verifying server signatures
            if self.keyjar is None:
                self.keyjar = KeyJar()
            self.keyjar.add_symmetric("", str(val))

    client_secret = property(get_client_secret, set_client_secret)

    def reset(self) -> None:
        """Forget the nonce, the grants, the redirect URIs and the authorization and token endpoints."""
        self.nonce = None

        self.grant = {}

        self.authorization_endpoint = None
        self.token_endpoint = None
        self.redirect_uris = []

    def grant_from_state(self, state: str) -> Optional[Grant]:
        """Return the grant stored for ``state``, or None if there is none."""
        return self.grant.get(state)

    def _parse_args(self, request: Type[Message], **kwargs) -> Dict:
        """
        Fill in request parameters from attributes of this client.

        Each parameter declared by ``request.c_param`` that is missing from
        ``kwargs`` is taken from the client attribute of the same name, if
        that attribute is truthy. ``redirect_uri`` is special: it comes from
        the first entry of ``self.redirect_uris``.

        :param request: Message class whose parameters should be filled in
        :param kwargs: Explicitly given parameters; these always win
        :return: A new dictionary with the combined arguments
        """
        ar_args = kwargs.copy()

        for prop in request.c_param.keys():
            if prop in ar_args:
                continue

            if prop == "redirect_uri":
                # redirect_uris starts out (and is reset to) an empty list,
                # so it must not be indexed blindly.
                _uris = getattr(self, "redirect_uris", None)
                _val = _uris[0] if _uris else None
            else:
                _val = getattr(self, prop, None)

            if _val:
                ar_args[prop] = _val

        return ar_args

    def _endpoint(self, endpoint: str, **kwargs) -> str:
        """
        Find the URL of an endpoint.

        A non-empty value in ``kwargs[endpoint]`` wins. Otherwise the client
        attribute named ``endpoint`` is used.

        :raises MissingEndpoint: if neither yields a URL
        """
        uri = kwargs.get(endpoint, "")

        if not uri:
            try:
                uri = getattr(self, endpoint)
            except Exception:
                raise MissingEndpoint("No '%s' specified" % endpoint)

        if not uri:
            raise MissingEndpoint("No '%s' specified" % endpoint)

        return uri

    def get_grant(self, state: str, **kwargs) -> Grant:
        """
        Return the grant stored for ``state``.

        :raises GrantError: if there is no grant for ``state``
        """
        try:
            return self.grant[state]
        except KeyError:
            raise GrantError("No grant found for state:'%s'" % state)

    def get_token(self, also_expired: bool = False, **kwargs) -> Token:
        """
        Find the token to use.

        A ``token`` keyword argument is returned as is, without any validity
        check. Otherwise the token comes from the grant for ``state``. The
        ``scope`` keyword selects the token scope; the empty scope is used
        when ``scope`` is absent.

        :param also_expired: Return the token even if it is no longer valid
        :return: The token
        :raises GrantError: if there is no grant for the state
        :raises TokenError: if no token is found or it has expired
        """
        try:
            return kwargs["token"]
        except KeyError:
            grant = self.get_grant(**kwargs)

            try:
                token = grant.get_token(kwargs["scope"])
            except KeyError:
                token = grant.get_token("")
                if not token:
                    try:
                        token = self.grant[kwargs["state"]].get_token("")
                    except KeyError:
                        raise TokenError("No token found for scope")

        if token is None:
            raise TokenError("No suitable token found")

        if also_expired:
            return token
        elif token.is_valid():
            return token
        else:
            raise TokenError("Token has expired")

    def clean_tokens(self) -> None:
        """Clean replaced and invalid tokens."""
        for state in self.grant:
            grant = self.get_grant(state)
            # Iterate over a snapshot. delete_token() may remove entries from
            # grant.tokens, and mutating a list while iterating over it
            # skips elements.
            for token in list(grant.tokens):
                if token.replaced or not token.is_valid():
                    grant.delete_token(token)

    def construct_request(
        self, request: Type[Message], request_args=None, extra_args=None
    ):
        """
        Instantiate ``request``.

        The parameters come from ``request_args``, completed from client
        attributes (see ``_parse_args``) and then overridden by
        ``extra_args``.
        """
        if request_args is None:
            request_args = {}

        kwargs = self._parse_args(request, **request_args)

        if extra_args:
            kwargs.update(extra_args)
        logger.debug("request: %s" % sanitize(request))
        return request(**kwargs)

    def construct_Message(
        self,
        request: Type[Message] = Message,
        request_args=None,
        extra_args=None,
        **kwargs
    ) -> Message:
        """Construct a generic message; ``kwargs`` are ignored."""
        return self.construct_request(request, request_args, extra_args)

    def construct_AuthorizationRequest(
        self,
        request: Type[AuthorizationRequest] = None,
        request_args=None,
        extra_args=None,
        **kwargs
    ) -> AuthorizationRequest:
        """
        Construct an authorization request.

        A non-empty ``redirect_uri`` in ``request_args`` replaces
        ``self.redirect_uris``. ``client_id`` defaults to ``self.client_id``.
        """
        if request is None:
            request = self.message_factory.get_request_type("authorization_endpoint")
        if request_args is not None:
            try:  # change default
                new = request_args["redirect_uri"]
                if new:
                    self.redirect_uris = [new]
            except KeyError:
                pass
        else:
            request_args = {}

        if not request_args.get("client_id"):
            request_args["client_id"] = self.client_id

        return self.construct_request(request, request_args, extra_args)

    def construct_AccessTokenRequest(
        self,
        request: Type[AccessTokenRequest] = None,
        request_args=None,
        extra_args=None,
        **kwargs
    ) -> AccessTokenRequest:
        """
        Construct an access token request.

        Except for ROPC requests, the authorization code is taken from the
        grant for ``kwargs["state"]``.

        :raises GrantError: if there is no grant for the state
        :raises GrantExpired: if the grant is no longer valid
        """
        if request is None:
            request = self.message_factory.get_request_type("token_endpoint")
        if request_args is None:
            request_args = {}
        if request is not ROPCAccessTokenRequest:
            grant = self.get_grant(**kwargs)

            if not grant.is_valid():
                raise GrantExpired(
                    "Authorization Code to old %s > %s"
                    % (utc_time_sans_frac(), grant.grant_expiration_time)
                )

            request_args["code"] = grant.code

        try:
            request_args["state"] = kwargs["state"]
        except KeyError:
            pass

        if "grant_type" not in request_args:
            request_args["grant_type"] = "authorization_code"

        if not request_args.get("client_id"):
            request_args["client_id"] = self.client_id
        return self.construct_request(request, request_args, extra_args)

    def construct_RefreshAccessTokenRequest(
        self,
        request: Type[RefreshAccessTokenRequest] = None,
        request_args=None,
        extra_args=None,
        **kwargs
    ) -> RefreshAccessTokenRequest:
        """Construct a refresh request from the (possibly expired) current token."""
        if request is None:
            request = self.message_factory.get_request_type("refresh_endpoint")
        if request_args is None:
            request_args = {}

        token = self.get_token(also_expired=True, **kwargs)

        request_args["refresh_token"] = token.refresh_token

        try:
            request_args["scope"] = token.scope
        except AttributeError:
            pass

        return self.construct_request(request, request_args, extra_args)

    def construct_ResourceRequest(
        self,
        request: Type[ResourceRequest] = None,
        request_args=None,
        extra_args=None,
        **kwargs
    ) -> ResourceRequest:
        """Construct a resource request carrying the current access token."""
        if request is None:
            request = self.message_factory.get_request_type("resource_endpoint")
        if request_args is None:
            request_args = {}

        token = self.get_token(**kwargs)

        request_args["access_token"] = token.access_token
        return self.construct_request(request, request_args, extra_args)

    def uri_and_body(
        self,
        reqmsg: Type[Message],
        cis: Message,
        method="POST",
        request_args=None,
        **kwargs
    ) -> Tuple[str, str, Dict, Message]:
        """
        Work out the URL, body and HTTP headers for a request message.

        ``kwargs["endpoint"]`` overrides the endpoint configured for the
        message type.

        :return: (url, body, http args, the request message)
        """
        if "endpoint" in kwargs and kwargs["endpoint"]:
            uri = kwargs["endpoint"]
        else:
            uri = self._endpoint(self.request2endpoint[reqmsg.__name__], **request_args)

        uri, body, kwargs = get_or_post(uri, method, cis, **kwargs)
        try:
            h_args = {"headers": kwargs["headers"]}
        except KeyError:
            h_args = {}

        return uri, body, h_args, cis

    def request_info(
        self,
        request: Type[Message],
        method="POST",
        request_args=None,
        extra_args=None,
        lax=False,
        **kwargs
    ) -> Tuple[str, str, Dict, Message]:
        """
        Construct a request message and everything needed to send it.

        A ``construct_<RequestName>`` method is used when one exists;
        otherwise the generic ``construct_request`` is used. When
        ``authn_method`` is given, its headers are merged into the HTTP
        arguments.

        :return: (url, body, http args, the request message)
        """
        if request_args is None:
            request_args = {}

        # NOTE(review): an AttributeError raised *inside* a construct_*
        # method also lands in the generic fallback below.
        try:
            cls = getattr(self, "construct_%s" % request.__name__)
            cis = cls(request_args=request_args, extra_args=extra_args, **kwargs)
        except AttributeError:
            cis = self.construct_request(request, request_args, extra_args)

        if self.events:
            self.events.store("Protocol request", cis)

        if "nonce" in cis and "state" in cis:
            self.state2nonce[cis["state"]] = cis["nonce"]

        cis.lax = lax

        if "authn_method" in kwargs:
            h_arg = self.init_authentication_method(
                cis, request_args=request_args, **kwargs
            )
        else:
            h_arg = None

        if h_arg:
            if "headers" in kwargs.keys():
                kwargs["headers"].update(h_arg["headers"])
            else:
                kwargs["headers"] = h_arg["headers"]

        return self.uri_and_body(request, cis, method, request_args, **kwargs)

    def authorization_request_info(self, request_args=None, extra_args=None, **kwargs):
        """Return ``request_info`` for a GET authorization request."""
        return self.request_info(
            self.message_factory.get_request_type("authorization_endpoint"),
            "GET",
            request_args,
            extra_args,
            **kwargs
        )

    @staticmethod
    def get_urlinfo(info: str) -> str:
        """
        Extract the urlencoded part of a URL.

        If ``info`` contains "?" or "#", return its query or, when the query
        is empty, its fragment. Otherwise return ``info`` unchanged.
        """
        if "?" in info or "#" in info:
            parts = urlparse(info)
            scheme, netloc, path, params, query, fragment = parts[:6]
            # either query of fragment
            if query:
                info = query
            else:
                info = fragment
        return info

    def parse_response(
        self,
        response: Type[Message],
        info: str = "",
        sformat: ENCODINGS = "json",
        state: str = "",
        **kwargs
    ) -> Message:
        """
        Parse a response.

        :param response: Response type
        :param info: The response, can be either in a JSON or an urlencoded
            format
        :param sformat: Which serialization that was used
        :param state: The state
        :param kwargs: Extra key word arguments
        :return: The parsed and to some extend verified response
        :raises PyoidcError: if verification of a non-error response fails
        :raises ResponseError: if nothing usable could be parsed
        """
        _r2e = self.response2error

        if sformat == "urlencoded":
            info = self.get_urlinfo(info)

        resp = response().deserialize(info, sformat, **kwargs)
        msg = 'Initial response parsing => "{}"'
        logger.debug(msg.format(sanitize(resp.to_dict())))
        if self.events:
            self.events.store("Response", resp.to_dict())

        if "error" in resp and not isinstance(resp, ErrorResponse):
            # An error reply: try each error class registered for the
            # expected response type until one deserializes and verifies.
            resp = None
            errmsgs = _r2e.get(response.__name__, [ErrorResponse])  # type: List[Any]

            for errmsg in errmsgs:
                try:
                    resp = errmsg().deserialize(info, sformat)
                    resp.verify()
                    break
                except Exception:
                    resp = None
        elif resp.only_extras():
            resp = None
        else:
            kwargs["client_id"] = self.client_id
            try:
                kwargs["iss"] = self.provider_info["issuer"]
            except (KeyError, AttributeError):
                if self.issuer:
                    kwargs["iss"] = self.issuer

            if "key" not in kwargs and "keyjar" not in kwargs:
                kwargs["keyjar"] = self.keyjar

            logger.debug("Verify response with {}".format(sanitize(kwargs)))
            verf = resp.verify(**kwargs)

            if not verf:
                logger.error("Verification of the response failed")
                raise PyoidcError("Verification of the response failed")
            if resp.type() == "AuthorizationResponse" and "scope" not in resp:
                try:
                    resp["scope"] = kwargs["scope"]
                except KeyError:
                    pass

        if not resp:
            logger.error("Missing or faulty response")
            raise ResponseError("Missing or faulty response")

        self.store_response(resp, info)

        if isinstance(resp, (AuthorizationResponse, AccessTokenResponse)):
            try:
                _state = resp["state"]
            except (AttributeError, KeyError):
                _state = ""

            if not _state:
                _state = state

            try:
                self.grant[_state].update(resp)
            except KeyError:
                self.grant[_state] = self.grant_class(resp=resp)

        return resp

    def init_authentication_method(
        self, cis, authn_method, request_args=None, http_args=None, **kwargs
    ):
        """
        Apply a client authentication method.

        :param cis: The request message (may be None)
        :param authn_method: Key into ``self.client_authn_method``; when
            empty, ``http_args`` is returned unchanged
        :return: The HTTP arguments produced by the method's ``construct``
        """
        if http_args is None:
            http_args = {}
        if request_args is None:
            request_args = {}

        if authn_method:
            return self.client_authn_method[authn_method](self).construct(
                cis, request_args, http_args, **kwargs
            )
        else:
            return http_args

    def parse_request_response(self, reqresp, response, body_type, state="", **kwargs):
        """
        Turn an HTTP response into a message.

        :param reqresp: The HTTP response object
        :param response: Expected response message class, or None
        :param body_type: Expected body encoding; for successful replies it
            is checked against the response by ``verify_header``
        :param state: The state the reply belongs to
        :param kwargs: Extra arguments for ``parse_response``
        :return: A parsed message, the body text (for "txt" bodies), an
            ErrorResponse, or the HTTP response itself (redirects, or when
            nothing could be parsed)
        :raises ParseError: on HTTP 500
        :raises HttpError: on any other unexpected status code
        """
        if reqresp.status_code in SUCCESSFUL:
            body_type = verify_header(reqresp, body_type)
        elif reqresp.status_code in [302, 303]:  # redirect
            return reqresp
        elif reqresp.status_code == 500:
            logger.error("(%d) %s" % (reqresp.status_code, sanitize(reqresp.text)))
            raise ParseError("ERROR: Something went wrong: %s" % reqresp.text)
        elif reqresp.status_code in [400, 401]:
            # Expecting an error response; it is parsed below. `response`
            # may be None here, so it must not be passed to issubclass().
            pass
        else:
            logger.error("(%d) %s" % (reqresp.status_code, sanitize(reqresp.text)))
            raise HttpError(
                "HTTP ERROR: %s [%s] on %s"
                % (reqresp.text, reqresp.status_code, reqresp.url)
            )

        if response:
            if body_type == "txt":
                # no meaning trying to parse unstructured text
                return reqresp.text
            return self.parse_response(
                response, reqresp.text, body_type, state, **kwargs
            )

        # could be an error response
        if reqresp.status_code in [200, 400, 401]:
            if body_type == "txt":
                body_type = "urlencoded"
            try:
                # The body is in `.text`, as used above.
                err = ErrorResponse().deserialize(reqresp.text, method=body_type)
                try:
                    err.verify()
                except PyoidcError:
                    pass
                else:
                    return err
            except Exception:
                pass

        return reqresp

    def request_and_return(
        self,
        url: str,
        response: Type[Message] = None,
        method="GET",
        body=None,
        body_type: ENCODINGS = "json",
        state: str = "",
        http_args=None,
        **kwargs
    ):
        """
        Perform a request and return the response.

        :param url: The URL to which the request should be sent
        :param response: Response type
        :param method: Which HTTP method to use
        :param body: A message body if any
        :param body_type: The format of the body of the return message
        :param state: The state the response belongs to
        :param http_args: Arguments for the HTTP client
        :param kwargs: Passed on to ``parse_request_response``; ``keyjar``
            defaults to ``self.keyjar``
        :return: A cls or ErrorResponse instance or the HTTP response instance if no response body was expected.
        """
        # FIXME: Cannot annotate return value as Message since it disrupts all other methods
        if http_args is None:
            http_args = {}

        # Errors from the HTTP layer propagate unchanged to the caller.
        resp = self.http_request(url, method, data=body, **http_args)

        kwargs.setdefault("keyjar", self.keyjar)

        return self.parse_request_response(resp, response, body_type, state, **kwargs)

    def do_authorization_request(
        self,
        state="",
        body_type="",
        method="GET",
        request_args=None,
        extra_args=None,
        http_args=None,
        **kwargs
    ) -> AuthorizationResponse:
        """
        Build and send an authorization request and parse the reply.

        The request is remembered in ``self.authz_req`` under its state,
        when it has one.

        :param http_args: Arguments for the HTTP client; updated in place
            with the ones produced by ``request_info``
        :param kwargs: Passed on to ``request_info``; ``algs`` is also
            forwarded to ``request_and_return``
        """
        request = self.message_factory.get_request_type("authorization_endpoint")
        response_cls = self.message_factory.get_response_type("authorization_endpoint")

        if state:
            try:
                request_args["state"] = state
            except TypeError:
                request_args = {"state": state}

        kwargs["authn_endpoint"] = "authorization"
        url, body, ht_args, csi = self.request_info(
            request, method, request_args, extra_args, **kwargs
        )

        try:
            self.authz_req[request_args["state"]] = csi
        except (KeyError, TypeError):
            # No request_args at all (TypeError) or no state in them
            # (KeyError): there is nothing to key the request on.
            pass

        if http_args is None:
            http_args = ht_args
        else:
            http_args.update(ht_args)

        algs = kwargs.get("algs", {})

        resp = self.request_and_return(
            url,
            response_cls,
            method,
            body,
            body_type,
            state=state,
            http_args=http_args,
            algs=algs,
        )

        if isinstance(resp, Message):
            # FIXME: The Message classes do not have classical attrs
            # NOTE(review): response2error values are lists of error
            # classes (parse_response instantiates them), while resp.type()
            # is compared with strings in parse_response. This test may
            # therefore never be true; confirm.
            if resp.type() in RESPONSE2ERROR["AuthorizationResponse"]:  # type: ignore
                resp.state = csi.state  # type: ignore

        return resp

    def do_access_token_request(
        self,
        scope: str = "",
        state: str = "",
        body_type: ENCODINGS = "json",
        method="POST",
        request_args=None,
        extra_args=None,
        http_args=None,
        authn_method="",
        **kwargs
    ) -> AccessTokenResponse:
        """
        Send an access token request and parse the reply.

        :param http_args: Arguments for the HTTP client; updated in place
            with the ones produced by ``request_info``, after which any
            ``password`` entry is dropped
        :param kwargs: Passed on to ``request_info`` and
            ``request_and_return``
        """
        request = self.message_factory.get_request_type("token_endpoint")
        response_cls = self.message_factory.get_response_type("token_endpoint")

        kwargs["authn_endpoint"] = "token"
        # method is default POST
        url, body, ht_args, csi = self.request_info(
            request,
            method=method,
            request_args=request_args,
            extra_args=extra_args,
            scope=scope,
            state=state,
            authn_method=authn_method,
            **kwargs
        )

        if http_args is None:
            http_args = ht_args
        else:
            http_args.update(ht_args)
            http_args.pop("password", None)

        if self.events is not None:
            self.events.store("request_url", url)
            self.events.store("request_http_args", http_args)
            self.events.store("Request", body)

        logger.debug("<do_access_token> URL: %s, Body: %s" % (url, sanitize(body)))
        logger.debug("<do_access_token> response_cls: %s" % response_cls)

        return self.request_and_return(
            url,
            response_cls,
            method,
            body,
            body_type,
            state=state,
            http_args=http_args,
            **kwargs
        )

    def do_access_token_refresh(
        self,
        state: str = "",
        body_type: ENCODINGS = "json",
        method="POST",
        request_args=None,
        extra_args=None,
        http_args=None,
        authn_method="",
        **kwargs
    ) -> AccessTokenResponse:
        """
        Refresh the access token of ``state``.

        If the old token is marked as replaced afterwards, it is removed
        from its grant.
        """
        request = self.message_factory.get_request_type("refresh_endpoint")
        response_cls = self.message_factory.get_response_type("refresh_endpoint")

        token = self.get_token(also_expired=True, state=state, **kwargs)
        kwargs["authn_endpoint"] = "refresh"
        url, body, ht_args, csi = self.request_info(
            request,
            method=method,
            request_args=request_args,
            extra_args=extra_args,
            token=token,
            authn_method=authn_method,
        )

        if http_args is None:
            http_args = ht_args
        else:
            http_args.update(ht_args)

        response = self.request_and_return(
            url, response_cls, method, body, body_type, state=state, http_args=http_args
        )
        if token.replaced:
            grant = self.get_grant(state)
            grant.delete_token(token)
        return response

    def do_any(
        self,
        request: Type[Message],
        endpoint="",
        scope="",
        state="",
        body_type="json",
        method="POST",
        request_args=None,
        extra_args=None,
        http_args=None,
        response: Type[Message] = None,
        authn_method="",
    ) -> Message:
        """Send an arbitrary request message and parse the reply."""
        url, body, ht_args, _ = self.request_info(
            request,
            method=method,
            request_args=request_args,
            extra_args=extra_args,
            scope=scope,
            state=state,
            authn_method=authn_method,
            endpoint=endpoint,
        )

        if http_args is None:
            http_args = ht_args
        else:
            http_args.update(ht_args)

        return self.request_and_return(
            url, response, method, body, body_type, state=state, http_args=http_args
        )

    def fetch_protected_resource(
        self, uri, method="GET", headers=None, state="", **kwargs
    ):
        """
        Access a protected resource with an access token.

        :param uri: URL of the protected resource
        :param method: HTTP method
        :param headers: Extra HTTP headers; the authentication headers are
            added to this dictionary
        :param state: State whose grant holds the access token
        :param kwargs: ``token`` uses the given access token directly, and
            ``authn_method`` chooses how the token is presented. All of
            kwargs is passed on to ``get_token`` and
            ``init_authentication_method``.
        :return: The result of ``self.http_request``
        """
        if kwargs.get("token"):
            token = kwargs["token"]
            request_args = {"access_token": token}
        else:
            try:
                token = self.get_token(state=state, **kwargs)
            except ExpiredToken:
                # NOTE(review): get_token() raises TokenError("Token has
                # expired") for expired tokens, not ExpiredToken, so this
                # refresh branch may be unreachable; confirm.
                self.do_access_token_refresh(state=state)
                token = self.get_token(state=state, **kwargs)
            request_args = {"access_token": token.access_token}

        if headers is None:
            headers = {}

        if "authn_method" in kwargs:
            # init_authentication_method() takes the request message as its
            # required first argument. A plain resource fetch has none.
            http_args = self.init_authentication_method(
                None, request_args=request_args, **kwargs
            )
        else:
            # If nothing defined this is the default
            http_args = self.client_authn_method["bearer_header"](self).construct(
                request_args=request_args
            )

        headers.update(http_args["headers"])

        logger.debug("Fetch URI: %s" % uri)
        return self.http_request(uri, method, headers=headers)

    def add_code_challenge(self):
        """
        PKCE RFC 7636 support.

        Length and method are read from ``self.config["code_challenge"]``.
        They default to 64 and "S256".

        :return: A tuple of (``code_challenge``/``code_challenge_method``
            request arguments, the code verifier)
        :raises Unsupported: if the configured method is not in CC_METHOD
        """
        try:
            cv_len = self.config["code_challenge"]["length"]
        except KeyError:
            cv_len = 64  # Use default

        code_verifier = unreserved(cv_len)
        _cv = code_verifier.encode("ascii")

        try:
            _method = self.config["code_challenge"]["method"]
        except KeyError:
            _method = "S256"

        try:
            _h = CC_METHOD[_method](_cv).digest()
            code_challenge = b64e(_h).decode("ascii")
        except KeyError:
            raise Unsupported("PKCE Transformation method:{}".format(_method))

        # TODO store code_verifier

        return (
            {"code_challenge": code_challenge, "code_challenge_method": _method},
            code_verifier,
        )

    def handle_provider_config(
        self,
        pcr: ASConfigurationResponse,
        issuer: str,
        keys: bool = True,
        endpoints: bool = True,
    ) -> None:
        """
        Deal with Provider Config Response.

        :param pcr: The ProviderConfigResponse instance
        :param issuer: The one I thought should be the issuer of the config
        :param keys: Should I deal with keys
        :param endpoints: Should I deal with endpoints, that is store them as attributes in self.
        :raises PyoidcError: on an issuer mismatch, unless
            ``self.allow["issuer_mismatch"]`` is set
        """
        if "issuer" in pcr:
            _pcr_issuer = pcr["issuer"]
            # Compare issuers with matching trailing-slash conventions.
            if pcr["issuer"].endswith("/"):
                if issuer.endswith("/"):
                    _issuer = issuer
                else:
                    _issuer = issuer + "/"
            else:
                if issuer.endswith("/"):
                    _issuer = issuer[:-1]
                else:
                    _issuer = issuer

            if not self.allow.get("issuer_mismatch", False) and _issuer != _pcr_issuer:
                raise PyoidcError(
                    "provider info issuer mismatch '%s' != '%s'"
                    % (_issuer, _pcr_issuer)
                )

            self.provider_info = pcr
        else:
            _pcr_issuer = issuer

        self.issuer = _pcr_issuer

        if endpoints:
            for key, val in pcr.items():
                if key.endswith("_endpoint"):
                    setattr(self, key, val)

        if keys:
            if self.keyjar is None:
                self.keyjar = KeyJar()

            self.keyjar.load_keys(pcr, _pcr_issuer)

    def provider_config(
        self,
        issuer: str,
        keys: bool = True,
        endpoints: bool = True,
        serv_pattern: str = OIDCONF_PATTERN,
    ) -> ASConfigurationResponse:
        """
        Fetch and process the provider configuration of ``issuer``.

        The URL is ``serv_pattern % issuer`` (without a trailing slash).

        :raises ParseError: if the configuration cannot be parsed
        :raises CommunicationError: on a non-200 reply
        """
        response_cls = self.message_factory.get_response_type("configuration_endpoint")
        if issuer.endswith("/"):
            _issuer = issuer[:-1]
        else:
            _issuer = issuer

        url = serv_pattern % _issuer

        pcr = None
        r = self.http_request(url, allow_redirects=True)
        if r.status_code == 200:
            try:
                pcr = response_cls().from_json(r.text)
            except Exception as e:
                # FIXME: This should catch specific exception from `from_json()`
                _err_txt = "Faulty provider config response: {}".format(e)
                logger.error(sanitize(_err_txt))
                raise ParseError(_err_txt)
        else:
            raise CommunicationError("Trying '%s', status %s" % (url, r.status_code))

        self.store_response(pcr, r.text)
        self.handle_provider_config(pcr, issuer, keys, endpoints)
        return pcr
Ejemplo n.º 18
0
class Client(PBase):
    """OAuth2 client that builds requests and parses/verifies responses.

    Grants created from authorization and access-token responses are kept in
    ``self.grant``, keyed by the ``state`` value they belong to.
    """
    _endpoints = ENDPOINTS

    def __init__(self,
                 client_id=None,
                 ca_certs=None,
                 client_authn_method=None,
                 keyjar=None,
                 verify_ssl=True):
        """
        Create an OAuth2 client.

        :param client_id: The client identifier
        :param ca_certs: Certificates used to verify HTTPS certificates
        :param client_authn_method: Methods that this client can use to
            authenticate itself. It's a dictionary with method names as
            keys and method classes as values.
        :param keyjar: The keyjar for this client; a new KeyJar is created
            when none (or a falsy value) is given.
        :param verify_ssl: Whether the SSL certificate should be verified.
        :return: Client instance
        """

        PBase.__init__(self, ca_certs, verify_ssl=verify_ssl)

        self.client_id = client_id
        self.client_authn_method = client_authn_method
        self.keyjar = keyjar or KeyJar(verify_ssl=verify_ssl)
        self.verify_ssl = verify_ssl

        self.nonce = None

        # Grants, keyed by the ``state`` value they belong to.
        self.grant = {}

        # own endpoint
        self.redirect_uris = [None]

        # service endpoints
        self.authorization_endpoint = None
        self.token_endpoint = None
        self.token_revocation_endpoint = None

        self.request2endpoint = REQUEST2ENDPOINT
        self.response2error = RESPONSE2ERROR
        self.grant_class = Grant
        self.token_class = Token

        self.provider_info = {}
        self._c_secret = None
        self.kid = {"sig": {}, "enc": {}}

    def get_client_secret(self):
        """Return the client secret (``None`` until one has been set)."""
        return self._c_secret

    def set_client_secret(self, val):
        """Set the client secret and register it as a symmetric signing key.

        A falsy *val* stores ``""`` and registers no key.
        """
        if not val:
            self._c_secret = ""
        else:
            self._c_secret = val
            # client uses it for signing
            # Server might also use it for signing which means the
            # client uses it for verifying server signatures
            if self.keyjar is None:
                self.keyjar = KeyJar()
            self.keyjar.add_symmetric("", str(val), ["sig"])

    client_secret = property(get_client_secret, set_client_secret)

    def reset(self):
        """Forget the nonce, all grants and the configured endpoints."""
        self.nonce = None

        self.grant = {}

        self.authorization_endpoint = None
        self.token_endpoint = None
        self.redirect_uris = None

    def grant_from_state(self, state):
        """Return the grant stored under *state*, or ``None`` if there is none."""
        # Direct dict lookup instead of scanning every key.
        return self.grant.get(state)

    def _parse_args(self, request, **kwargs):
        """Fill in request parameters the caller did not supply.

        For each parameter declared in ``request.c_param`` that is missing
        from *kwargs*, take a default from this client: ``redirect_uri`` from
        the first entry of ``self.redirect_uris``, anything else from the
        attribute of the same name. Falsy defaults are not added.

        :return: A new dict of arguments.
        """
        ar_args = kwargs.copy()

        for prop in request.c_param.keys():
            if prop in ar_args:
                continue
            if prop == "redirect_uri":
                # reset() sets redirect_uris to None; guard before indexing.
                _val = (getattr(self, "redirect_uris", None) or [None])[0]
            else:
                _val = getattr(self, prop, None)
            if _val:
                ar_args[prop] = _val

        return ar_args

    def _endpoint(self, endpoint, **kwargs):
        """Return the URL for *endpoint*.

        A truthy ``kwargs[endpoint]`` wins; otherwise the client attribute of
        that name is used.

        :raises MissingEndpoint: If neither yields a truthy value.
        """
        try:
            uri = kwargs[endpoint]
            if uri:
                del kwargs[endpoint]
        except KeyError:
            uri = ""

        if not uri:
            try:
                uri = getattr(self, endpoint)
            except Exception:
                raise MissingEndpoint("No '%s' specified" % endpoint)

        if not uri:
            raise MissingEndpoint("No '%s' specified" % endpoint)

        return uri

    def get_grant(self, state, **kwargs):
        """Return the grant stored under *state*.

        Extra keyword arguments are accepted and ignored.

        :raises GrantError: If there is no grant for *state*.
        """
        try:
            return self.grant[state]
        except KeyError:
            raise GrantError("No grant found for state:'%s'" % state)

    def get_token(self, also_expired=False, **kwargs):
        """Return ``kwargs["token"]`` or a token taken from the state's grant.

        :param also_expired: Return the token even if it is no longer valid.
        :raises TokenError: If no token is found or it has expired.
        """
        try:
            return kwargs["token"]
        except KeyError:
            grant = self.get_grant(**kwargs)

            try:
                token = grant.get_token(kwargs["scope"])
            except KeyError:
                token = grant.get_token("")
                if not token:
                    try:
                        token = self.grant[kwargs["state"]].get_token("")
                    except KeyError:
                        raise TokenError("No token found for scope")

        if token is None:
            raise TokenError("No suitable token found")

        if also_expired:
            return token
        elif token.is_valid():
            return token
        else:
            raise TokenError("Token has expired")

    def construct_request(self, request, request_args=None, extra_args=None):
        """Instantiate *request* from defaults, *request_args* and *extra_args*.

        *extra_args* override everything else.
        """
        if request_args is None:
            request_args = {}

        kwargs = self._parse_args(request, **request_args)

        if extra_args:
            kwargs.update(extra_args)
        return request(**kwargs)

    def construct_Message(self,
                          request=Message,
                          request_args=None,
                          extra_args=None,
                          **kwargs):
        """Build a generic message; *kwargs* are ignored."""
        return self.construct_request(request, request_args, extra_args)

    # noinspection PyUnusedLocal
    def construct_AuthorizationRequest(self,
                                       request=AuthorizationRequest,
                                       request_args=None,
                                       extra_args=None,
                                       **kwargs):
        """Build an authorization request.

        A truthy ``redirect_uri`` in *request_args* becomes the client's new
        default redirect URI; ``client_id`` defaults to ``self.client_id``.
        """
        if request_args is not None:
            try:  # change default
                new = request_args["redirect_uri"]
                if new:
                    self.redirect_uris = [new]
            except KeyError:
                pass
        else:
            request_args = {}

        if "client_id" not in request_args:
            request_args["client_id"] = self.client_id
        elif not request_args["client_id"]:
            request_args["client_id"] = self.client_id

        return self.construct_request(request, request_args, extra_args)

    # noinspection PyUnusedLocal
    def construct_AccessTokenRequest(self,
                                     request=AccessTokenRequest,
                                     request_args=None,
                                     extra_args=None,
                                     **kwargs):
        """Build an access-token request from the grant for ``kwargs["state"]``.

        :raises GrantExpired: If the grant is no longer valid.
        """
        grant = self.get_grant(**kwargs)

        if not grant.is_valid():
            raise GrantExpired(
                "Authorization Code to old %s > %s" %
                (utc_time_sans_frac(), grant.grant_expiration_time))

        if request_args is None:
            request_args = {}

        request_args["code"] = grant.code

        if "grant_type" not in request_args:
            request_args["grant_type"] = "authorization_code"

        if "client_id" not in request_args:
            request_args["client_id"] = self.client_id
        elif not request_args["client_id"]:
            request_args["client_id"] = self.client_id
        return self.construct_request(request, request_args, extra_args)

    def construct_RefreshAccessTokenRequest(self,
                                            request=RefreshAccessTokenRequest,
                                            request_args=None,
                                            extra_args=None,
                                            **kwargs):
        """Build a refresh request; expired tokens are acceptable here."""
        if request_args is None:
            request_args = {}

        token = self.get_token(also_expired=True, **kwargs)

        request_args["refresh_token"] = token.refresh_token

        try:
            request_args["scope"] = token.scope
        except AttributeError:
            pass

        return self.construct_request(request, request_args, extra_args)

    def construct_TokenRevocationRequest(self,
                                         request=TokenRevocationRequest,
                                         request_args=None,
                                         extra_args=None,
                                         **kwargs):
        """Build a revocation request for the current access token."""
        if request_args is None:
            request_args = {}

        token = self.get_token(**kwargs)

        request_args["token"] = token.access_token
        return self.construct_request(request, request_args, extra_args)

    def construct_ResourceRequest(self,
                                  request=ResourceRequest,
                                  request_args=None,
                                  extra_args=None,
                                  **kwargs):
        """Build a resource request carrying the current access token."""
        if request_args is None:
            request_args = {}

        token = self.get_token(**kwargs)

        request_args["access_token"] = token.access_token
        return self.construct_request(request, request_args, extra_args)

    @staticmethod
    def get_or_post(uri,
                    method,
                    req,
                    content_type=DEFAULT_POST_CONTENT_TYPE,
                    accept=None,
                    **kwargs):
        """Serialize *req* for a GET (query string) or POST (body) request.

        :return: ``(path, body, kwargs)``; for POST, ``kwargs["headers"]``
            carries ``Content-type`` and, if given, ``Accept``.
        :raises UnSupported: For other methods or content types.
        """
        if method == "GET":
            _qp = req.to_urlencoded()
            if _qp:
                path = uri + '?' + _qp
            else:
                path = uri
            body = None
        elif method == "POST":
            path = uri
            if content_type == URL_ENCODED:
                body = req.to_urlencoded()
            elif content_type == JSON_ENCODED:
                body = req.to_json()
            else:
                raise UnSupported("Unsupported content type: '%s'" %
                                  content_type)

            header_ext = {"Content-type": content_type}
            if accept:
                # Add Accept alongside Content-type instead of replacing it.
                header_ext["Accept"] = accept

            if "headers" in kwargs:
                kwargs["headers"].update(header_ext)
            else:
                kwargs["headers"] = header_ext
        else:
            raise UnSupported("Unsupported HTTP method: '%s'" % method)

        return path, body, kwargs

    def uri_and_body(self,
                     reqmsg,
                     cis,
                     method="POST",
                     request_args=None,
                     **kwargs):
        """Resolve the endpoint URL and serialize *cis* for sending.

        :return: ``(uri, body, http_args, cis)``.
        """
        if "endpoint" in kwargs and kwargs["endpoint"]:
            uri = kwargs["endpoint"]
        else:
            # request_args defaults to None; ``**None`` would be a TypeError.
            uri = self._endpoint(self.request2endpoint[reqmsg.__name__],
                                 **(request_args or {}))

        uri, body, kwargs = self.get_or_post(uri, method, cis, **kwargs)
        try:
            h_args = {"headers": kwargs["headers"]}
        except KeyError:
            h_args = {}

        return uri, body, h_args, cis

    def request_info(self,
                     request,
                     method="POST",
                     request_args=None,
                     extra_args=None,
                     **kwargs):
        """Build a request message and work out where and how to send it.

        :return: ``(uri, body, http_args, request_message)``.
        """
        if request_args is None:
            request_args = {}

        # Prefer a specialised construct_<RequestName> method. Only the
        # lookup may fail; an AttributeError raised inside the constructor
        # must propagate, not silently fall back to construct_request.
        constructor = getattr(self, "construct_%s" % request.__name__, None)
        if constructor is None:
            cis = self.construct_request(request, request_args, extra_args)
        else:
            cis = constructor(request_args=request_args,
                              extra_args=extra_args,
                              **kwargs)

        if "authn_method" in kwargs:
            h_arg = self.init_authentication_method(cis,
                                                    request_args=request_args,
                                                    **kwargs)
        else:
            h_arg = None

        if h_arg:
            if "headers" in kwargs:
                kwargs["headers"].update(h_arg["headers"])
            else:
                kwargs["headers"] = h_arg["headers"]

        return self.uri_and_body(request, cis, method, request_args, **kwargs)

    def authorization_request_info(self,
                                   request_args=None,
                                   extra_args=None,
                                   **kwargs):
        """``request_info`` for an AuthorizationRequest sent with GET."""
        return self.request_info(AuthorizationRequest, "GET", request_args,
                                 extra_args, **kwargs)

    def parse_response(self,
                       response,
                       info="",
                       sformat="json",
                       state="",
                       **kwargs):
        """
        Parse a response

        :param response: Response type
        :param info: The response, can be either in a JSON or an urlencoded
            format
        :param sformat: Which serialization that was used
        :param state: The state to store the grant under if the response
            carries none
        :param kwargs: Extra key word arguments
        :return: The parsed and to some extent verified response
        :raises ResponseError: If nothing usable could be parsed; otherwise
            the last parse/verification error is re-raised.
        """

        _r2e = self.response2error

        if sformat == "urlencoded":
            if '?' in info or '#' in info:
                parts = urlparse.urlparse(info)
                query, fragment = parts[4], parts[5]
                # either query or fragment
                info = query if query else fragment

        err = None
        try:
            resp = response().deserialize(info, sformat, **kwargs)
            if "error" in resp and not isinstance(resp, ErrorResponse):
                # Re-parse with the error classes registered for *response*.
                resp = None
                try:
                    errmsgs = _r2e[response.__name__]
                except KeyError:
                    errmsgs = [ErrorResponse]

                for errmsg in errmsgs:
                    try:
                        resp = errmsg().deserialize(info, sformat)
                        resp.verify()
                        break
                    except Exception as aerr:
                        resp = None
                        err = aerr
            elif resp.only_extras():
                resp = None
            else:
                if "key" not in kwargs and "keyjar" not in kwargs:
                    kwargs["keyjar"] = self.keyjar
                verf = resp.verify(**kwargs)
                if not verf:
                    raise PyoidcError("Verification of the response failed")
                if resp.type() == "AuthorizationResponse" and \
                        "scope" not in resp:
                    try:
                        resp["scope"] = kwargs["scope"]
                    except KeyError:
                        pass
        except Exception as derr:
            resp = None
            err = derr

        if not resp:
            if err:
                raise err
            else:
                raise ResponseError("Missing or faulty response")

        if resp.type() in ["AuthorizationResponse", "AccessTokenResponse"]:
            try:
                _state = resp["state"]
            except (AttributeError, KeyError):
                _state = ""

            if not _state:
                _state = state

            try:
                self.grant[_state].update(resp)
            except KeyError:
                self.grant[_state] = self.grant_class(resp=resp)

        return resp
Ejemplo n.º 19
0
from oic.oic import OpenIDRequest
from oic.utils.keyio import KeyJar

# A compact-serialized JWE (five dot-separated parts) whose protected header
# decodes to {"alg": "A128KW", "enc": "A128CBC-HS256"}, i.e. the content key
# is wrapped with a symmetric (AES key-wrap) key.
request = 'eyJhbGciOiAiQTEyOEtXIiwgImVuYyI6ICJBMTI4Q0JDLUhTMjU2In0.KLuBoByxG54JdHz5OBjpMjx_6ivPNi6oanRZ5UN38VzcTHw2ftv6FA.Tysc6pZ_AA_X7j95bRSHiQ.YxG8Kf3GVWXnMfzOo7Hva32eHcaNBgpcT3iPIEWq76SgKNCpdnGSKOSiFtJbvCdpXwfneXIAS3uFktQoyo9x698IHp92bAZD9M31G0GfaWh7oZgcHrBkn_QPBFavEQeTSfbvhYya3Wp2U9DrL9CrT6ytTo7mbx6b9drUpSe2waIGJkugOOFCiqr19zXXFDT1Qc04sCGhRwz_0JYMYI9qGULQ0Ws2zQVlcE_iMoA6cFs.gDd8Ns2fJRj18A6gg4-T4g'

# Register the shared symmetric secret under the owner id "jJFjKcsaygxp".
keyjar = KeyJar()
keyjar.add_symmetric(
    "jJFjKcsaygxp", "f75695a7a87acccdef6c7c978d5e782db1b947e0f6990b050f58940b")

# Decode the request object with the keys in `keyjar`; `sender` presumably
# selects whose keys are tried (TODO confirm). The result is discarded, so
# this only demonstrates whether from_jwt raises.
OpenIDRequest().from_jwt(request, keyjar=keyjar, sender="jJFjKcsaygxp")
Ejemplo n.º 20
0
class Client(PBase):
    _endpoints = ENDPOINTS

    def __init__(self, client_id=None, client_authn_method=None,
                 keyjar=None, verify_ssl=True, config=None, client_cert=None,
                 timeout=5):
        """
        Create an OAuth2 client.

        :param client_id: The client identifier
        :param client_authn_method: Methods that this client can use to
            authenticate itself. It's a dictionary with method names as
            keys and method classes as values.
        :param keyjar: The keyjar for this client.
        :param verify_ssl: Whether the SSL certificate should be verified.
        :param config: Optional configuration mapping; its ``issuer`` entry,
            if present, becomes ``self.issuer``.
        :param client_cert: A client certificate to use.
        :param timeout: Timeout for requests library. Can be specified either as
            a single integer or as a tuple of integers. For more details, refer to
            ``requests`` documentation.
        :return: Client instance
        """

        PBase.__init__(self, verify_ssl=verify_ssl, keyjar=keyjar,
                       client_cert=client_cert, timeout=timeout)

        self.client_id = client_id
        self.client_authn_method = client_authn_method

        self.nonce = None

        # Grants, keyed by the ``state`` value they belong to.
        self.grant = {}
        # ``state`` -> ``nonce`` for requests that carried both.
        self.state2nonce = {}
        # own endpoint
        self.redirect_uris = [None]

        # service endpoints
        self.authorization_endpoint = None
        self.token_endpoint = None
        self.token_revocation_endpoint = None

        self.request2endpoint = REQUEST2ENDPOINT
        self.response2error = RESPONSE2ERROR
        self.grant_class = Grant
        self.token_class = Token

        # Set once here; the former second assignment further down was a
        # dead store.
        self.provider_info = {}
        self._c_secret = None
        self.kid = {"sig": {}, "enc": {}}
        self.authz_req = None

        # the OAuth issuer is the URL of the authorization server's
        # configuration information location
        self.config = config or {}
        try:
            self.issuer = self.config['issuer']
        except KeyError:
            self.issuer = ''
        self.allow = {}

    def store_response(self, clinst, text):
        """Hook for keeping a parsed response and its raw text.

        This implementation discards both; override to persist them.
        """

    def get_client_secret(self):
        """Return the client secret (``None`` until one has been assigned)."""
        return self._c_secret

    def set_client_secret(self, val):
        """Store *val* as the client secret and add it to the keyjar.

        The secret is registered as a symmetric key with owner ``""``: the
        client signs with it, and a server that signs with it too can have
        its signatures verified. A falsy *val* stores ``""`` and adds no key.
        """
        if val:
            self._c_secret = val
            if self.keyjar is None:
                self.keyjar = KeyJar()
            self.keyjar.add_symmetric("", str(val))
        else:
            self._c_secret = ""

    client_secret = property(get_client_secret, set_client_secret)

    def reset(self):
        """Drop the nonce, every grant and the configured endpoints/redirects.

        NOTE(review): __init__ uses ``[None]`` for redirect_uris, but this sets
        ``None`` — callers that index redirect_uris must cope with that.
        """
        self.nonce = None
        self.grant = {}
        self.authorization_endpoint = None
        self.token_endpoint = None
        self.redirect_uris = None

    def grant_from_state(self, state):
        """Return the grant stored under *state*, or ``None`` if there is none."""
        # A direct dict lookup replaces the former scan over every key.
        return self.grant.get(state)

    def _parse_args(self, request, **kwargs):
        """Fill in request parameters that the caller did not supply.

        For each parameter declared in ``request.c_param`` that is missing
        from *kwargs*, take a default from this client: ``redirect_uri`` from
        the first entry of ``self.redirect_uris``, any other parameter from
        the client attribute of the same name. Falsy defaults are skipped.

        :return: A new dict of arguments.
        """
        ar_args = kwargs.copy()

        for prop in request.c_param.keys():
            if prop in ar_args:
                continue
            if prop == "redirect_uri":
                # redirect_uris may be None (reset() sets it so) or empty;
                # indexing it directly would raise TypeError/IndexError.
                _val = (getattr(self, "redirect_uris", None) or [None])[0]
            else:
                _val = getattr(self, prop, None)
            if _val:
                ar_args[prop] = _val

        return ar_args

    def _endpoint(self, endpoint, **kwargs):
        """Return the URL to use for *endpoint*.

        A truthy ``kwargs[endpoint]`` takes precedence; otherwise the client
        attribute named *endpoint* is used.

        :raises MissingEndpoint: When neither source gives a truthy URL.
        """
        uri = kwargs.get(endpoint)
        if not uri:
            try:
                uri = getattr(self, endpoint)
            except Exception:
                raise MissingEndpoint("No '%s' specified" % endpoint)

        if not uri:
            raise MissingEndpoint("No '%s' specified" % endpoint)
        return uri

    def get_grant(self, state, **kwargs):
        """Look up the grant recorded for *state*.

        Additional keyword arguments are tolerated and ignored, so callers
        can forward their whole ``kwargs``.

        :raises GrantError: When no grant is stored for *state*.
        """
        try:
            found = self.grant[state]
        except KeyError:
            raise GrantError("No grant found for state:'%s'" % state)
        return found

    def get_token(self, also_expired=False, **kwargs):
        """Return a token: the one passed in, or one taken from a grant.

        If ``token`` is in *kwargs* it is returned as-is, skipping every check
        below. Otherwise the grant is fetched with ``get_grant(**kwargs)``
        (so ``state`` must be among *kwargs*), and a token is requested from
        it for ``kwargs["scope"]`` when given, else for the empty scope.

        :param also_expired: If true, return the token even when
            ``token.is_valid()`` is false.
        :raises GrantError: From ``get_grant`` when no grant exists.
        :raises TokenError: If no token is found or the token has expired.
        """
        try:
            return kwargs["token"]
        except KeyError:
            grant = self.get_grant(**kwargs)

            try:
                token = grant.get_token(kwargs["scope"])
            except KeyError:
                # No "scope" argument (or a KeyError from the grant itself):
                # fall back to the token for the empty scope.
                token = grant.get_token("")
                if not token:
                    # NOTE(review): self.grant[kwargs["state"]] is the same
                    # grant get_grant() returned above, so this retry looks
                    # redundant — confirm before removing.
                    try:
                        token = self.grant[kwargs["state"]].get_token("")
                    except KeyError:
                        raise TokenError("No token found for scope")

        if token is None:
            raise TokenError("No suitable token found")

        if also_expired:
            return token
        elif token.is_valid():
            return token
        else:
            raise TokenError("Token has expired")

    def clean_tokens(self):
        """Clean replaced and invalid tokens.

        Every token that has been replaced or is no longer valid is removed
        from its grant via ``delete_token``.
        """
        for state in self.grant:
            grant = self.get_grant(state)
            # Iterate over a snapshot: delete_token presumably removes the
            # token from grant.tokens, and removing items from a list while
            # iterating it makes the loop skip the following element.
            for token in list(grant.tokens):
                if token.replaced or not token.is_valid():
                    grant.delete_token(token)

    def construct_request(self, request, request_args=None, extra_args=None):
        """Instantiate the message class *request*.

        Arguments are the client defaults filled in by ``_parse_args``, then
        *request_args*, and finally *extra_args*, which override both.
        """
        params = self._parse_args(
            request, **({} if request_args is None else request_args))
        if extra_args:
            params.update(extra_args)
        logger.debug("request: %s" % sanitize(request))
        return request(**params)

    def construct_Message(self, request=Message, request_args=None,
                          extra_args=None, **kwargs):
        """Build a plain message; any extra keyword arguments are ignored."""
        return self.construct_request(request, request_args=request_args,
                                      extra_args=extra_args)

    def construct_AuthorizationRequest(self, request=AuthorizationRequest,
                                       request_args=None, extra_args=None,
                                       **kwargs):
        """Build an authorization request.

        A truthy ``redirect_uri`` in *request_args* also becomes the client's
        default redirect URI. A missing or falsy ``client_id`` is filled in
        from ``self.client_id``.
        """
        if request_args is None:
            request_args = {}
        elif "redirect_uri" in request_args:
            new_uri = request_args["redirect_uri"]
            if new_uri:
                # change default
                self.redirect_uris = [new_uri]

        if "client_id" not in request_args or not request_args["client_id"]:
            request_args["client_id"] = self.client_id

        return self.construct_request(request, request_args, extra_args)

    def construct_AccessTokenRequest(self,
                                     request=AccessTokenRequest,
                                     request_args=None, extra_args=None,
                                     **kwargs):
        """Build an access-token request.

        Except for ROPC requests, the authorization code comes from the
        grant for ``kwargs["state"]``. ``state`` is copied from *kwargs* when
        present; ``grant_type`` defaults to ``authorization_code`` and
        ``client_id`` to ``self.client_id``.

        :raises GrantExpired: If the grant is no longer valid.
        """
        if request_args is None:
            request_args = {}

        if request is not ROPCAccessTokenRequest:
            grant = self.get_grant(**kwargs)
            if not grant.is_valid():
                raise GrantExpired("Authorization Code to old %s > %s" % (
                    utc_time_sans_frac(),
                    grant.grant_expiration_time))
            request_args["code"] = grant.code

        if "state" in kwargs:
            request_args['state'] = kwargs['state']

        if "grant_type" not in request_args:
            request_args["grant_type"] = "authorization_code"

        if "client_id" not in request_args or not request_args["client_id"]:
            request_args["client_id"] = self.client_id
        return self.construct_request(request, request_args, extra_args)

    def construct_RefreshAccessTokenRequest(self,
                                            request=RefreshAccessTokenRequest,
                                            request_args=None, extra_args=None,
                                            **kwargs):
        """Build a refresh-token request.

        An expired access token is acceptable here. Its ``scope`` is copied
        when the token has one.
        """
        args = {} if request_args is None else request_args

        token = self.get_token(also_expired=True, **kwargs)
        args["refresh_token"] = token.refresh_token

        try:
            scope = token.scope
        except AttributeError:
            pass
        else:
            args["scope"] = scope

        return self.construct_request(request, args, extra_args)

    def construct_ResourceRequest(self, request=ResourceRequest,
                                  request_args=None, extra_args=None,
                                  **kwargs):
        """Build a resource request that carries the current access token."""
        args = {} if request_args is None else request_args
        args["access_token"] = self.get_token(**kwargs).access_token
        return self.construct_request(request, args, extra_args)

    def uri_and_body(self, reqmsg, cis, method="POST", request_args=None,
                     **kwargs):
        """Resolve the target URL and serialize *cis* for sending.

        A truthy ``kwargs["endpoint"]`` is used as the URL. Otherwise the
        endpoint registered for ``reqmsg`` in ``self.request2endpoint`` is
        resolved via ``_endpoint``.

        :return: ``(uri, body, http_args, cis)``.
        """
        if "endpoint" in kwargs and kwargs["endpoint"]:
            uri = kwargs["endpoint"]
        else:
            # request_args defaults to None; ``**None`` would be a TypeError.
            uri = self._endpoint(self.request2endpoint[reqmsg.__name__],
                                 **(request_args or {}))

        uri, body, kwargs = get_or_post(uri, method, cis, **kwargs)
        try:
            h_args = {"headers": kwargs["headers"]}
        except KeyError:
            h_args = {}

        return uri, body, h_args, cis

    def request_info(self, request, method="POST", request_args=None,
                     extra_args=None, lax=False, **kwargs):
        """Build a request message and work out where and how to send it.

        :param request: Message class to build.
        :param method: HTTP method.
        :param lax: Stored on the built message as ``cis.lax``.
        :return: ``(uri, body, http_args, request_message)`` from
            ``uri_and_body``.
        """
        if request_args is None:
            request_args = {}

        # Prefer a specialised construct_<RequestName> method. Only the lookup
        # may fail: an AttributeError raised *inside* the constructor must
        # propagate instead of silently falling back to construct_request,
        # which would build a request missing the constructor's fields.
        constructor = getattr(self, "construct_%s" % request.__name__, None)
        if constructor is None:
            cis = self.construct_request(request, request_args, extra_args)
        else:
            cis = constructor(request_args=request_args, extra_args=extra_args,
                              **kwargs)

        if self.events:
            self.events.store('Protocol request', cis)

        # Record which nonce was sent with which state.
        if 'nonce' in cis and 'state' in cis:
            self.state2nonce[cis['state']] = cis['nonce']

        cis.lax = lax

        if "authn_method" in kwargs:
            h_arg = self.init_authentication_method(cis,
                                                    request_args=request_args,
                                                    **kwargs)
        else:
            h_arg = None

        if h_arg:
            if "headers" in kwargs:
                kwargs["headers"].update(h_arg["headers"])
            else:
                kwargs["headers"] = h_arg["headers"]

        return self.uri_and_body(request, cis, method, request_args,
                                 **kwargs)

    def authorization_request_info(self, request_args=None, extra_args=None,
                                   **kwargs):
        """Shortcut for ``request_info`` on an AuthorizationRequest sent via GET."""
        return self.request_info(AuthorizationRequest, method="GET",
                                 request_args=request_args,
                                 extra_args=extra_args, **kwargs)

    def get_urlinfo(self, info):
        """Reduce a full URL to its query string, or its fragment if the query is empty.

        Input containing neither ``?`` nor ``#`` is returned unchanged.
        """
        if '?' not in info and '#' not in info:
            return info
        parsed = urlparse(info)
        return parsed[4] or parsed[5]

    def parse_response(self, response, info="", sformat="json", state="",
                       **kwargs):
        """
        Parse a response and, for a non-error response, verify it.

        If the parsed message contains ``error`` (and *response* is not itself
        an ErrorResponse subclass), the payload is re-parsed with the error
        classes registered for *response* in ``self.response2error``
        (default ``[ErrorResponse]``); the first that deserializes and
        verifies is used. Authorization and access-token responses are then
        merged into ``self.grant`` under the response's ``state``, or under
        *state* when the response has none.

        :param response: Response message class to parse into
        :param info: The response, can be either in a JSON or an urlencoded
            format
        :param sformat: Which serialization that was used
        :param state: The state to file the grant under if the response
            carries none
        :param kwargs: Extra key word arguments, passed to ``deserialize``
            and ``verify``
        :return: The parsed and to some extent verified response
        :raises PyoidcError: If ``verify`` returns a false value
        :raises ResponseError: If no usable response could be parsed
        """

        _r2e = self.response2error

        if sformat == "urlencoded":
            # A full URL is reduced to its query string (or fragment).
            info = self.get_urlinfo(info)

        resp = response().deserialize(info, sformat, **kwargs)
        msg = 'Initial response parsing => "{}"'
        logger.debug(msg.format(sanitize(resp.to_dict())))
        if self.events:
            self.events.store('Response', resp.to_dict())

        if "error" in resp and not isinstance(resp, ErrorResponse):
            # Error payload: try the registered error classes, first match wins.
            resp = None
            try:
                errmsgs = _r2e[response.__name__]
            except KeyError:
                errmsgs = [ErrorResponse]

            # NOTE(review): the inner handler already catches every
            # Exception, so this outer ``except KeyError`` looks unreachable.
            try:
                for errmsg in errmsgs:
                    try:
                        resp = errmsg().deserialize(info, sformat)
                        resp.verify()
                        break
                    except Exception:
                        resp = None
            except KeyError:
                pass
        elif resp.only_extras():
            # presumably: only undeclared parameters were present — treated
            # as no usable response. TODO confirm only_extras() semantics.
            resp = None
        else:
            # NOTE(review): this overrides any client_id passed in kwargs.
            kwargs["client_id"] = self.client_id
            try:
                kwargs['iss'] = self.provider_info['issuer']
            except (KeyError, AttributeError):
                if self.issuer:
                    kwargs['iss'] = self.issuer

            if "key" not in kwargs and "keyjar" not in kwargs:
                kwargs["keyjar"] = self.keyjar

            logger.debug("Verify response with {}".format(sanitize(kwargs)))
            verf = resp.verify(**kwargs)

            if not verf:
                logger.error('Verification of the response failed')
                raise PyoidcError("Verification of the response failed")
            if resp.type() == "AuthorizationResponse" and "scope" not in resp:
                try:
                    resp["scope"] = kwargs["scope"]
                except KeyError:
                    pass

        if not resp:
            logger.error('Missing or faulty response')
            raise ResponseError("Missing or faulty response")

        self.store_response(resp, info)

        if resp.type() in ["AuthorizationResponse", "AccessTokenResponse"]:
            try:
                _state = resp["state"]
            except (AttributeError, KeyError):
                _state = ""

            if not _state:
                _state = state

            try:
                self.grant[_state].update(resp)
            except KeyError:
                self.grant[_state] = self.grant_class(resp=resp)

        return resp

    def init_authentication_method(self, cis, authn_method, request_args=None,
                                   http_args=None, **kwargs):
        """Apply client authentication method *authn_method* to the request.

        :return: Whatever the method's ``construct`` returns, or *http_args*
            (``{}`` if ``None``) unchanged when *authn_method* is falsy.
        """
        http_args = {} if http_args is None else http_args
        request_args = {} if request_args is None else request_args

        if not authn_method:
            return http_args

        authenticator = self.client_authn_method[authn_method](self)
        return authenticator.construct(cis, request_args, http_args, **kwargs)

    def parse_request_response(self, reqresp, response, body_type, state="",
                               **kwargs):
        """Turn an HTTP response into a parsed message where possible.

        :param reqresp: The HTTP response object.
        :param response: Message class to parse into, or a falsy value.
        :param body_type: Expected body format; checked against the headers
            of successful responses via ``verify_header``.
        :return: The response itself for 302/303; the raw text for ``txt``
            bodies; a parsed *response* message; an ErrorResponse when no
            *response* class is given and the body is a valid error; else
            the HTTP response object.
        :raises ParseError: On a 500 status.
        :raises HttpError: On any other unexpected status.
        """
        status = reqresp.status_code
        if status in SUCCESSFUL:
            body_type = verify_header(reqresp, body_type)
        elif status in [302, 303]:  # redirect
            return reqresp
        elif status == 500:
            logger.error("(%d) %s" % (status, sanitize(reqresp.text)))
            raise ParseError("ERROR: Something went wrong: %s" % reqresp.text)
        elif status in [400, 401]:
            # Expecting an error response; handled below. (The former
            # ``issubclass(response, ErrorResponse)`` check was a no-op that
            # raised TypeError whenever *response* was None.)
            pass
        else:
            logger.error("(%d) %s" % (status, sanitize(reqresp.text)))
            raise HttpError("HTTP ERROR: %s [%s] on %s" % (
                reqresp.text, status, reqresp.url))

        if response:
            if body_type == 'txt':
                # no meaning trying to parse unstructured text
                return reqresp.text
            return self.parse_response(response, reqresp.text, body_type,
                                       state, **kwargs)

        # could be an error response
        if status in [200, 400, 401]:
            if body_type == 'txt':
                body_type = 'urlencoded'
            try:
                # Read the body from ``.text`` like every other branch here;
                # the previous ``reqresp.message`` is not an attribute of the
                # response object used above, so this parse always failed.
                err = ErrorResponse().deserialize(reqresp.text,
                                                  method=body_type)
                try:
                    err.verify()
                except PyoidcError:
                    pass
                else:
                    return err
            except Exception:
                # Best effort: fall through and return the raw response.
                pass

        return reqresp

    def request_and_return(self, url, response=None, method="GET", body=None,
                           body_type="json", state="", http_args=None,
                           **kwargs):
        """
        Send an HTTP request and parse the answer.

        :param url: The URL to which the request should be sent
        :param response: Response message class expected back, if any
        :param method: Which HTTP method to use
        :param body: A message body if any
        :param body_type: The format of the body of the return message
        :param state: State value passed on to the parser
        :param http_args: Arguments for the HTTP client
        :param kwargs: Passed on to parse_request_response; "keyjar" defaults
            to self.keyjar
        :return: A cls or ErrorResponse instance or the HTTP response
            instance if no response body was expected.
        """
        client_args = {} if http_args is None else http_args
        http_resp = self.http_request(url, method, data=body, **client_args)

        if "keyjar" not in kwargs:
            kwargs["keyjar"] = self.keyjar

        return self.parse_request_response(http_resp, response, body_type,
                                           state, **kwargs)

    def do_authorization_request(self, request=AuthorizationRequest,
                                 state="", body_type="", method="GET",
                                 request_args=None, extra_args=None,
                                 http_args=None,
                                 response_cls=AuthorizationResponse,
                                 **kwargs):
        """
        Build, send and parse an authorization request.

        :param request: Request message class
        :param state: State value; when given it is put into request_args
            (the caller's dict is updated in place if one is passed)
        :param body_type: Expected serialization of the response body
        :param method: HTTP method
        :param request_args: Request arguments
        :param extra_args: Extra arguments merged into the request
        :param http_args: HTTP client arguments; updated in place with the
            ones request_info produces
        :param response_cls: Expected response message class
        :param kwargs: Passed on to request_info; "algs" is also handed to
            request_and_return
        :return: Whatever request_and_return returns
        """
        if state:
            try:
                request_args["state"] = state
            except TypeError:
                request_args = {"state": state}

        kwargs['authn_endpoint'] = 'authorization'
        url, body, ht_args, csi = self.request_info(request, method,
                                                    request_args, extra_args,
                                                    **kwargs)

        # Remember the request by its state. TypeError: request_args or
        # self.authz_req is None; KeyError: no state was given at all. In both
        # cases there is nothing to file the request under.
        try:
            self.authz_req[request_args["state"]] = csi
        except (TypeError, KeyError):
            pass

        if http_args is None:
            http_args = ht_args
        else:
            http_args.update(ht_args)

        algs = kwargs.get("algs", {})

        resp = self.request_and_return(url, response_cls, method, body,
                                       body_type, state=state,
                                       http_args=http_args, algs=algs)

        if isinstance(resp, Message):
            # NOTE(review): resp.type() is compared as a class name elsewhere
            # in this file, while RESPONSE2ERROR values appear to be message
            # classes - confirm this membership test can ever be true.
            if resp.type() in RESPONSE2ERROR["AuthorizationResponse"]:
                resp.state = csi.state

        return resp

    def do_access_token_request(self, request=AccessTokenRequest,
                                scope="", state="", body_type="json",
                                method="POST", request_args=None,
                                extra_args=None, http_args=None,
                                response_cls=AccessTokenResponse,
                                authn_method="", **kwargs):
        """
        Build, send and parse an access token request.

        :param request: Request message class
        :param scope: Scope passed on to request_info
        :param state: State the grant is stored under
        :param body_type: Expected serialization of the response body
        :param method: HTTP method (POST by default)
        :param request_args: Request arguments
        :param extra_args: Extra arguments merged into the request
        :param http_args: HTTP client arguments; updated in place with the
            ones request_info produces
        :param response_cls: Expected response message class
        :param authn_method: Client authentication method name
        :param kwargs: Passed on to request_info and request_and_return
        :return: Whatever request_and_return returns
        """
        kwargs['authn_endpoint'] = 'token'
        # POST is the default HTTP method for the token endpoint
        url, body, info_http_args, _cis = self.request_info(
            request, method=method, request_args=request_args,
            extra_args=extra_args, scope=scope, state=state,
            authn_method=authn_method, **kwargs)

        if http_args is not None:
            http_args.update(info_http_args)
        else:
            http_args = info_http_args

        if self.events is not None:
            for event_name, event_value in (('request_url', url),
                                            ('request_http_args', http_args),
                                            ('Request', body)):
                self.events.store(event_name, event_value)

        logger.debug("<do_access_token> URL: %s, Body: %s" % (url,
                                                              sanitize(body)))
        logger.debug("<do_access_token> response_cls: %s" % response_cls)

        return self.request_and_return(url, response_cls, method, body,
                                       body_type, state=state,
                                       http_args=http_args, **kwargs)

    def do_access_token_refresh(self, request=RefreshAccessTokenRequest,
                                state="", body_type="json", method="POST",
                                request_args=None, extra_args=None,
                                http_args=None,
                                response_cls=AccessTokenResponse,
                                authn_method="", **kwargs):
        """
        Use the refresh token for a state to obtain a new access token.

        :param request: Request message class
        :param state: State whose token should be refreshed
        :param body_type: Expected serialization of the response body
        :param method: HTTP method
        :param request_args: Request arguments
        :param extra_args: Extra arguments merged into the request
        :param http_args: HTTP client arguments; updated in place with the
            ones request_info produces
        :param response_cls: Expected response message class
        :param authn_method: Client authentication method name
        :param kwargs: Used to look up the token to refresh
        :return: Whatever request_and_return returns
        """
        token = self.get_token(also_expired=True, state=state, **kwargs)
        # The authn endpoint name must actually reach request_info (and from
        # there the client authentication method), the same way
        # do_access_token_request passes 'token'. Previously it was only
        # stored in kwargs, which are not forwarded.
        url, body, ht_args, csi = self.request_info(request, method=method,
                                                    request_args=request_args,
                                                    extra_args=extra_args,
                                                    token=token,
                                                    authn_method=authn_method,
                                                    authn_endpoint='refresh')

        if http_args is None:
            http_args = ht_args
        else:
            http_args.update(ht_args)

        response = self.request_and_return(url, response_cls, method, body,
                                           body_type, state=state,
                                           http_args=http_args)
        # A token that has been superseded is dropped from its grant
        if token.replaced:
            grant = self.get_grant(state)
            grant.delete_token(token)
        return response

    def do_any(self, request, endpoint="", scope="", state="", body_type="json",
               method="POST", request_args=None, extra_args=None,
               http_args=None, response=None, authn_method=""):
        """
        Send an arbitrary request message and parse the answer.

        :param request: Request message class
        :param endpoint: Explicit endpoint URL passed on to request_info
        :param scope: Scope passed on to request_info
        :param state: State value
        :param body_type: Expected serialization of the response body
        :param method: HTTP method
        :param request_args: Request arguments
        :param extra_args: Extra arguments merged into the request
        :param http_args: HTTP client arguments; updated in place with the
            ones request_info produces
        :param response: Expected response message class, if any
        :param authn_method: Client authentication method name
        :return: Whatever request_and_return returns
        """
        url, body, info_http_args, _ = self.request_info(
            request, method=method, request_args=request_args,
            extra_args=extra_args, scope=scope, state=state,
            authn_method=authn_method, endpoint=endpoint)

        if http_args is not None:
            http_args.update(info_http_args)
        else:
            http_args = info_http_args

        return self.request_and_return(url, response, method, body, body_type,
                                       state=state, http_args=http_args)

    def fetch_protected_resource(self, uri, method="GET", headers=None,
                                 state="", **kwargs):

        if "token" in kwargs and kwargs["token"]:
            token = kwargs["token"]
            request_args = {"access_token": token}
        else:
            try:
                token = self.get_token(state=state, **kwargs)
            except ExpiredToken:
                # The token is to old, refresh
                self.do_access_token_refresh(state=state)
                token = self.get_token(state=state, **kwargs)
            request_args = {"access_token": token.access_token}

        if headers is None:
            headers = {}

        if "authn_method" in kwargs:
            http_args = self.init_authentication_method(
                request_args=request_args, **kwargs)
        else:
            # If nothing defined this is the default
            http_args = self.client_authn_method[
                "bearer_header"](self).construct(request_args=request_args)

        headers.update(http_args["headers"])

        logger.debug("Fetch URI: %s" % uri)
        return self.http_request(uri, method, headers=headers)

    def add_code_challenge(self):
        """
        Create a PKCE (RFC 7636) code verifier and its code challenge.

        Settings come from self.config['code_challenge']: 'length' (default
        64) and 'method' (default 'S256').

        :return: Tuple of ({"code_challenge": ..., "code_challenge_method":
            ...}, code_verifier)
        :raises Unsupported: If the method has no entry in CC_METHOD
        """
        try:
            verifier_length = self.config['code_challenge']['length']
        except KeyError:
            verifier_length = 64  # Use default

        code_verifier = unreserved(verifier_length)
        verifier_bytes = code_verifier.encode('ascii')

        try:
            method_name = self.config['code_challenge']['method']
        except KeyError:
            method_name = 'S256'

        try:
            digest = CC_METHOD[method_name](verifier_bytes).digest()
            code_challenge = b64e(digest).decode('ascii')
        except KeyError:
            raise Unsupported(
                'PKCE Transformation method:{}'.format(method_name))

        # TODO store code_verifier

        challenge_args = {"code_challenge": code_challenge,
                          "code_challenge_method": method_name}
        return challenge_args, code_verifier

    def handle_provider_config(self, pcr, issuer, keys=True, endpoints=True):
        """
        Deal with a Provider Config Response.

        :param pcr: The ProviderConfigResponse instance
        :param issuer: The one I thought should be the issuer of the config
        :param keys: Should I load the provider's keys into self.keyjar
        :param endpoints: Should I store every "*_endpoint" entry as an
            attribute on self
        :raises PyoidcError: If the advertised issuer differs from the
            expected one (ignoring a trailing slash) and self.allow does not
            permit "issuer_mismatch"
        """
        if "issuer" not in pcr:
            _pcr_issuer = issuer
        else:
            _pcr_issuer = pcr["issuer"]
            # Make the expected issuer agree with the advertised one on the
            # trailing slash before comparing them.
            advertised_slash = _pcr_issuer.endswith("/")
            expected_slash = issuer.endswith("/")
            if advertised_slash and not expected_slash:
                _issuer = issuer + "/"
            elif expected_slash and not advertised_slash:
                _issuer = issuer[:-1]
            else:
                _issuer = issuer

            if not self.allow.get("issuer_mismatch", False) and _issuer != _pcr_issuer:
                raise PyoidcError("provider info issuer mismatch '%s' != '%s'" % (_issuer, _pcr_issuer))

            self.provider_info = pcr

        self.issuer = _pcr_issuer

        if endpoints:
            for name, value in pcr.items():
                if name.endswith("_endpoint"):
                    setattr(self, name, value)

        if not keys:
            return

        if self.keyjar is None:
            self.keyjar = KeyJar()
        self.keyjar.load_keys(pcr, _pcr_issuer)

    def provider_config(self, issuer, keys=True, endpoints=True,
                        response_cls=ASConfigurationResponse,
                        serv_pattern=OIDCONF_PATTERN):
        """
        Fetch and process the provider configuration for an issuer.

        The discovery URL is serv_pattern with the issuer (minus any trailing
        slash) substituted in. 302 redirects are followed.

        :param issuer: The issuer identifier
        :param keys: Passed on to handle_provider_config
        :param endpoints: Passed on to handle_provider_config
        :param response_cls: Message class used to parse the JSON document
        :param serv_pattern: %-format pattern producing the discovery URL
        :return: The parsed configuration
        :raises PyoidcError: If no 200 response was obtained
        """
        _issuer = issuer[:-1] if issuer.endswith("/") else issuer
        url = serv_pattern % _issuer

        r = self.http_request(url)
        # Keep following 302 redirects until something else comes back
        while r.status_code == 302:
            r = self.http_request(r.headers["location"])

        if r.status_code == 200:
            pcr = response_cls().from_json(r.text)
        else:
            pcr = None

        if pcr is None:
            raise PyoidcError("Trying '%s', status %s" % (url, r.status_code))

        self.handle_provider_config(pcr, issuer, keys, endpoints)

        return pcr
Ejemplo n.º 21
0
class Client(PBase):
    _endpoints = ENDPOINTS

    def __init__(self,
                 client_id=None,
                 ca_certs=None,
                 client_authn_method=None,
                 keyjar=None,
                 verify_ssl=True):
        """

        :param client_id: The client identifier
        :param ca_certs: Certificates used to verify HTTPS certificates
        :param client_authn_method: Methods that this client can use to
            authenticate itself. It's a dictionary with method names as
            keys and method classes as values.
        :param keyjar: KeyJar to use; a new one is created when none (or a
            falsy value) is given
        :param verify_ssl: Whether the SSL certificate should be verified.
        :return: Client instance
        """

        PBase.__init__(self, ca_certs, verify_ssl=verify_ssl)

        self.client_id = client_id
        self.client_authn_method = client_authn_method
        self.keyjar = keyjar or KeyJar(verify_ssl=verify_ssl)
        self.verify_ssl = verify_ssl

        self.nonce = None

        # grants keyed on state
        self.grant = {}

        # own endpoint
        self.redirect_uris = [None]

        # service endpoints
        self.authorization_endpoint = None
        self.token_endpoint = None
        self.token_revocation_endpoint = None

        self.request2endpoint = REQUEST2ENDPOINT
        self.response2error = RESPONSE2ERROR
        self.grant_class = Grant
        self.token_class = Token

        self.provider_info = {}
        self._c_secret = None
        self.kid = {"sig": {}, "enc": {}}
        # do_authorization_request files each request here keyed on its
        # state; with None that store failed (TypeError) and was skipped.
        self.authz_req = {}

    def store_response(self, clinst, text):
        """
        Hook called by parse_response with every parsed response.

        Does nothing here; override it to keep the parsed message and the
        raw text it was parsed from.

        :param clinst: The parsed response message
        :param text: The text the message was parsed from
        """
        return None

    def get_client_secret(self):
        """Return the client secret (None until one has been set)."""
        return self._c_secret

    def set_client_secret(self, val):
        """
        Set the client secret.

        A falsy value clears the secret to "". A real value is additionally
        registered in the key jar (created if missing) as a symmetric "sig"
        key under the "" owner.
        """
        if val:
            self._c_secret = val
            # client uses it for signing
            # Server might also use it for signing which means the
            # client uses it for verifying server signatures
            if self.keyjar is None:
                self.keyjar = KeyJar()
            self.keyjar.add_symmetric("", str(val), ["sig"])
        else:
            self._c_secret = ""

    client_secret = property(get_client_secret, set_client_secret)

    def reset(self):
        """
        Forget per-session state: nonce, grants and the main endpoints.

        redirect_uris goes back to [None], the value __init__ uses, so code
        that indexes it (such as _parse_args) keeps working after a reset.
        """
        self.nonce = None

        self.grant = {}

        self.authorization_endpoint = None
        self.token_endpoint = None
        self.redirect_uris = [None]

    def grant_from_state(self, state):
        """
        Return the grant stored under state, or None if there is none.

        self.grant is a dict keyed on state, so a direct lookup replaces the
        previous linear scan over all entries.
        """
        return self.grant.get(state)

    def _parse_args(self, request, **kwargs):
        ar_args = kwargs.copy()

        for prop in request.c_param.keys():
            if prop in ar_args:
                continue
            else:
                if prop == "redirect_uri":
                    _val = getattr(self, "redirect_uris", [None])[0]
                    if _val:
                        ar_args[prop] = _val
                else:
                    _val = getattr(self, prop, None)
                    if _val:
                        ar_args[prop] = _val

        return ar_args

    def _endpoint(self, endpoint, **kwargs):
        try:
            uri = kwargs[endpoint]
            if uri:
                del kwargs[endpoint]
        except KeyError:
            uri = ""

        if not uri:
            try:
                uri = getattr(self, endpoint)
            except Exception:
                raise MissingEndpoint("No '%s' specified" % endpoint)

        if not uri:
            raise MissingEndpoint("No '%s' specified" % endpoint)

        return uri

    def get_grant(self, state, **kwargs):
        """
        Return the grant stored under state.

        :param state: The state value the grant is keyed on
        :param kwargs: Ignored; accepted so callers can pass their keyword
            arguments straight through
        :return: The grant
        :raises GrantError: If no grant is stored under state
        """
        try:
            return self.grant[state]
        except (KeyError, TypeError):
            # TypeError: an unhashable state can't be a key either. Anything
            # else (e.g. KeyboardInterrupt) is no longer swallowed.
            raise GrantError("No grant found for state:'%s'" % state)

    def get_token(self, also_expired=False, **kwargs):
        """
        Find the token to use for a request.

        An explicit kwargs["token"] is returned as-is, with no validity check.
        Otherwise the grant for kwargs["state"] is looked up and asked for a
        token for kwargs["scope"], or for the "" scope when no scope is given.

        :param also_expired: Return the token even if it is no longer valid
        :param kwargs: May contain "token", "state" and "scope"
        :return: A token
        :raises GrantError: (from get_grant) if no grant exists for the state
        :raises TokenError: If no token is found or the token has expired
        """
        try:
            return kwargs["token"]
        except KeyError:
            # get_grant() takes a required "state" argument, taken from kwargs
            grant = self.get_grant(**kwargs)

            try:
                token = grant.get_token(kwargs["scope"])
            except KeyError:
                # No scope given: fall back to the "" scope
                token = grant.get_token("")
                if not token:
                    # NOTE(review): self.grant[kwargs["state"]] is the same
                    # grant get_grant() returned above, so this presumably
                    # yields the same result - confirm before simplifying.
                    try:
                        token = self.grant[kwargs["state"]].get_token("")
                    except KeyError:
                        raise TokenError("No token found for scope")

        if token is None:
            raise TokenError("No suitable token found")

        if also_expired:
            return token
        elif token.is_valid():
            return token
        else:
            raise TokenError("Token has expired")

    def construct_request(self, request, request_args=None, extra_args=None):
        if request_args is None:
            request_args = {}

        # logger.debug("request_args: %s" % request_args)
        kwargs = self._parse_args(request, **request_args)

        if extra_args:
            kwargs.update(extra_args)
            # logger.debug("kwargs: %s" % kwargs)
        # logger.debug("request: %s" % request)
        return request(**kwargs)

    def construct_Message(self,
                          request=Message,
                          request_args=None,
                          extra_args=None,
                          **kwargs):
        """
        Construct a generic Message via construct_request.

        Extra keyword arguments are accepted for signature compatibility with
        the other construct_* methods and ignored.
        """
        message = self.construct_request(request, request_args, extra_args)
        return message

    # noinspection PyUnusedLocal
    def construct_AuthorizationRequest(self,
                                       request=AuthorizationRequest,
                                       request_args=None,
                                       extra_args=None,
                                       **kwargs):
        """
        Construct an authorization request.

        A non-empty request_args["redirect_uri"] becomes the client's new
        default redirect URI. A missing or empty client_id is filled in from
        self.client_id (the caller's request_args dict is updated in place).
        """
        if request_args is None:
            request_args = {}
        elif "redirect_uri" in request_args and request_args["redirect_uri"]:
            # change default
            self.redirect_uris = [request_args["redirect_uri"]]

        if "client_id" not in request_args or not request_args["client_id"]:
            request_args["client_id"] = self.client_id

        return self.construct_request(request, request_args, extra_args)

    # noinspection PyUnusedLocal
    def construct_AccessTokenRequest(self,
                                     request=AccessTokenRequest,
                                     request_args=None,
                                     extra_args=None,
                                     **kwargs):
        """
        Construct an access token request from the grant for a state.

        The grant's code is always used; grant_type defaults to
        "authorization_code" and a missing or empty client_id is taken from
        self.client_id.

        :param kwargs: Must identify the grant (passed on to get_grant)
        :raises GrantExpired: If the grant is no longer valid
        """
        grant = self.get_grant(**kwargs)
        if not grant.is_valid():
            raise GrantExpired(
                "Authorization Code to old %s > %s" %
                (utc_time_sans_frac(), grant.grant_expiration_time))

        args = {} if request_args is None else request_args
        args["code"] = grant.code

        if "grant_type" not in args:
            args["grant_type"] = "authorization_code"

        if "client_id" not in args or not args["client_id"]:
            args["client_id"] = self.client_id

        return self.construct_request(request, args, extra_args)

    def construct_RefreshAccessTokenRequest(self,
                                            request=RefreshAccessTokenRequest,
                                            request_args=None,
                                            extra_args=None,
                                            **kwargs):
        """
        Construct a refresh request from the (possibly expired) token.

        The token's refresh_token is always set; its scope is copied too
        when the token has a scope attribute.

        :param kwargs: Used by get_token to find the token
        """
        args = {} if request_args is None else request_args

        token = self.get_token(also_expired=True, **kwargs)
        args["refresh_token"] = token.refresh_token

        missing = object()
        token_scope = getattr(token, "scope", missing)
        if token_scope is not missing:
            args["scope"] = token_scope

        return self.construct_request(request, args, extra_args)

    def construct_TokenRevocationRequest(self,
                                         request=TokenRevocationRequest,
                                         request_args=None,
                                         extra_args=None,
                                         **kwargs):
        """
        Construct a token revocation request for the current access token.

        :param kwargs: Used by get_token to find the (valid) token
        """
        args = {} if request_args is None else request_args
        args["token"] = self.get_token(**kwargs).access_token
        return self.construct_request(request, args, extra_args)

    def construct_ResourceRequest(self,
                                  request=ResourceRequest,
                                  request_args=None,
                                  extra_args=None,
                                  **kwargs):
        """
        Construct a resource request carrying the current access token.

        :param kwargs: Used by get_token to find the (valid) token
        """
        args = {} if request_args is None else request_args
        args["access_token"] = self.get_token(**kwargs).access_token
        return self.construct_request(request, args, extra_args)

    def uri_and_body(self,
                     reqmsg,
                     cis,
                     method="POST",
                     request_args=None,
                     **kwargs):
        """
        Work out the target URL and body for a request message.

        :param reqmsg: The request message class; its name selects the
            endpoint through self.request2endpoint unless kwargs["endpoint"]
            is set
        :param cis: The request message instance
        :param method: HTTP method
        :param request_args: May override the endpoint URL (see _endpoint);
            None is treated as no arguments
        :param kwargs: Passed on to get_or_post
        :return: (uri, body, http_args, cis)
        """
        if "endpoint" in kwargs and kwargs["endpoint"]:
            uri = kwargs["endpoint"]
        else:
            # `**None` would raise TypeError, so honour the None default
            uri = self._endpoint(self.request2endpoint[reqmsg.__name__],
                                 **(request_args or {}))

        uri, body, kwargs = get_or_post(uri, method, cis, **kwargs)
        try:
            h_args = {"headers": kwargs["headers"]}
        except KeyError:
            h_args = {}

        return uri, body, h_args, cis

    def request_info(self,
                     request,
                     method="POST",
                     request_args=None,
                     extra_args=None,
                     lax=False,
                     **kwargs):
        """
        Construct a request and work out where and how to send it.

        A construct_<RequestClassName> method is used when the client has
        one, construct_request otherwise. When kwargs contains
        "authn_method", the client authentication method's headers are
        merged into kwargs["headers"].

        :param request: The request message class
        :param method: HTTP method
        :param request_args: Request arguments
        :param extra_args: Extra arguments merged into the request
        :param lax: Set as the message's lax attribute
        :param kwargs: Passed on to the constructor, the authentication
            method and uri_and_body
        :return: (uri, body, http_args, request_message) from uri_and_body
        """
        if request_args is None:
            request_args = {}

        # Only the lookup may fall back; an AttributeError raised *inside* a
        # specific constructor is a real error and must not be silently
        # replaced by the generic construct_request.
        constructor = getattr(self, "construct_%s" % request.__name__, None)
        if constructor is None:
            cis = self.construct_request(request, request_args, extra_args)
        else:
            cis = constructor(request_args=request_args,
                              extra_args=extra_args,
                              **kwargs)

        cis.lax = lax

        if "authn_method" in kwargs:
            h_arg = self.init_authentication_method(cis,
                                                    request_args=request_args,
                                                    **kwargs)
        else:
            h_arg = None

        # Some authentication methods return http args without headers
        if h_arg and "headers" in h_arg:
            if "headers" in kwargs:
                kwargs["headers"].update(h_arg["headers"])
            else:
                kwargs["headers"] = h_arg["headers"]

        return self.uri_and_body(request, cis, method, request_args, **kwargs)

    def authorization_request_info(self,
                                   request_args=None,
                                   extra_args=None,
                                   **kwargs):
        """
        Shortcut for request_info() with an AuthorizationRequest sent by GET.

        :return: (uri, body, http_args, request_message)
        """
        http_method = "GET"
        return self.request_info(AuthorizationRequest, http_method,
                                 request_args, extra_args, **kwargs)

    def get_urlinfo(self, info):
        """
        Extract the parameter part of a URL.

        If info contains '?' or '#', it is parsed as a URL and its query is
        returned, or its fragment when the query is empty. Anything else is
        returned unchanged.
        """
        if '?' not in info and '#' not in info:
            return info

        parsed = urlparse.urlparse(info)
        # either query or fragment
        return parsed.query or parsed.fragment

    def parse_response(self,
                       response,
                       info="",
                       sformat="json",
                       state="",
                       **kwargs):
        """
        Parse a response

        :param response: Response type (the message class expected)
        :param info: The response, can be either in a JSON or an urlencoded
            format; for urlencoded a full URL is accepted as well (its query
            or fragment is used)
        :param sformat: Which serialization that was used
        :param state: The state; used to file the grant when the response
            itself carries none
        :param kwargs: Extra key word arguments, passed to deserialize() and
            verify()
        :return: The parsed and to some extent verified response
        :raises PyoidcError: If verification of the expected response fails
        :raises ResponseError: If nothing usable could be parsed
        """

        _r2e = self.response2error

        if sformat == "urlencoded":
            info = self.get_urlinfo(info)

        resp = response().deserialize(info, sformat, **kwargs)
        if "error" in resp and not isinstance(resp, ErrorResponse):
            # An error came back instead of the expected response: parse it
            # again as one of the error messages registered for this type.
            resp = None
            try:
                errmsgs = _r2e[response.__name__]
            except KeyError:
                errmsgs = [ErrorResponse]

            for errmsg in errmsgs:
                try:
                    resp = errmsg().deserialize(info, sformat)
                    resp.verify()
                    break
                except Exception:
                    # Not this error type; try the next candidate
                    resp = None
        elif resp.only_extras():
            resp = None
        else:
            kwargs["client_id"] = self.client_id
            if "key" not in kwargs and "keyjar" not in kwargs:
                kwargs["keyjar"] = self.keyjar
            verf = resp.verify(**kwargs)
            if not verf:
                raise PyoidcError("Verification of the response failed")
            if resp.type() == "AuthorizationResponse" and \
                    "scope" not in resp:
                try:
                    resp["scope"] = kwargs["scope"]
                except KeyError:
                    pass

        if not resp:
            raise ResponseError("Missing or faulty response")

        self.store_response(resp, info)

        if resp.type() in ["AuthorizationResponse", "AccessTokenResponse"]:
            try:
                _state = resp["state"]
            except (AttributeError, KeyError):
                _state = ""

            if not _state:
                _state = state

            try:
                self.grant[_state].update(resp)
            except KeyError:
                self.grant[_state] = self.grant_class(resp=resp)

        return resp

    #noinspection PyUnusedLocal
    def init_authentication_method(self,
                                   cis,
                                   authn_method,
                                   request_args=None,
                                   http_args=None,
                                   **kwargs):
        """
        Run a client authentication method over an outgoing request.

        :param cis: The request message being sent
        :param authn_method: Key into self.client_authn_method; falsy means
            no client authentication
        :param request_args: Request arguments handed to the method
        :param http_args: HTTP arguments handed to the method
        :param kwargs: Extra keyword arguments handed to construct()
        :return: The method's construct() result, or http_args ({} if none
            were given) when no method is named
        """
        http_args = {} if http_args is None else http_args
        request_args = {} if request_args is None else request_args

        if authn_method:
            authenticator = self.client_authn_method[authn_method](self)
            return authenticator.construct(cis, request_args, http_args,
                                           **kwargs)
        return http_args

    def parse_request_response(self,
                               reqresp,
                               response,
                               body_type,
                               state="",
                               **kwargs):
        """
        Turn an HTTP response into a parsed message.

        :param reqresp: The HTTP response (status_code, text, url are used)
        :param response: The response message class expected, or None
        :param body_type: Expected body serialization; falsy means no body is
            expected and reqresp is returned as-is
        :param state: Passed on to parse_response
        :param kwargs: Passed on to parse_response
        :return: The parsed response, or reqresp when no body was expected
        :raises ParseError: On HTTP 500
        :raises HttpError: On any other unexpected status code
        :raises OtherError: If there is a body type but no response class
        """

        if reqresp.status_code in SUCCESSFUL:
            body_type = verify_header(reqresp, body_type)
        elif reqresp.status_code == 302:  # redirect
            pass
        elif reqresp.status_code == 500:
            logger.error("(%d) %s" % (reqresp.status_code, reqresp.text))
            raise ParseError("ERROR: Something went wrong: %s" % reqresp.text)
        elif reqresp.status_code in [400, 401]:
            # An error response is expected and is parsed below. There is
            # deliberately no issubclass(response, ...) test here: response
            # may be None, and issubclass(None, ...) raises TypeError.
            pass
        else:
            logger.error("(%d) %s" % (reqresp.status_code, reqresp.text))
            raise HttpError("HTTP ERROR: %s [%s] on %s" %
                            (reqresp.text, reqresp.status_code, reqresp.url))

        if not body_type:
            return reqresp

        if response:
            return self.parse_response(response, reqresp.text, body_type,
                                       state, **kwargs)

        raise OtherError("Didn't expect a response body")

    def request_and_return(self,
                           url,
                           response=None,
                           method="GET",
                           body=None,
                           body_type="json",
                           state="",
                           http_args=None,
                           **kwargs):
        """
        Send an HTTP request and parse the answer.

        :param url: The URL to which the request should be sent
        :param response: Response type
        :param method: Which HTTP method to use
        :param body: A message body if any
        :param body_type: The format of the body of the return message
        :param state: State value passed on to the parser
        :param http_args: Arguments for the HTTP client
        :param kwargs: Passed on to parse_request_response; "keyjar" defaults
            to self.keyjar
        :return: A cls or ErrorResponse instance or the HTTP response
            instance if no response body was expected.
        """
        if http_args is None:
            http_args = {}

        raw_response = self.http_request(url, method, data=body, **http_args)

        if "keyjar" not in kwargs:
            kwargs["keyjar"] = self.keyjar

        return self.parse_request_response(raw_response, response, body_type,
                                           state, **kwargs)

    def do_authorization_request(self,
                                 request=AuthorizationRequest,
                                 state="",
                                 body_type="",
                                 method="GET",
                                 request_args=None,
                                 extra_args=None,
                                 http_args=None,
                                 response_cls=AuthorizationResponse,
                                 **kwargs):
        """
        Build, send and parse an authorization request.

        :param request: Request message class
        :param state: State value; when given it is put into request_args
            (the caller's dict is updated in place if one is passed)
        :param body_type: Expected serialization of the response body
        :param method: HTTP method
        :param request_args: Request arguments
        :param extra_args: Extra arguments merged into the request
        :param http_args: HTTP client arguments; updated in place with the
            ones request_info produces
        :param response_cls: Expected response message class
        :param kwargs: Passed on to request_info; "algs" is also handed to
            request_and_return
        :return: Whatever request_and_return returns
        """
        if state:
            # request_args may be None; don't crash on the documented default
            try:
                request_args["state"] = state
            except TypeError:
                request_args = {"state": state}

        url, body, ht_args, csi = self.request_info(request, method,
                                                    request_args, extra_args,
                                                    **kwargs)

        # Remember the request by its state. TypeError: request_args or
        # self.authz_req is None; KeyError: no state was given at all.
        try:
            self.authz_req[request_args["state"]] = csi
        except (TypeError, KeyError):
            pass

        if http_args is None:
            http_args = ht_args
        else:
            http_args.update(ht_args)

        algs = kwargs.get("algs", {})

        resp = self.request_and_return(url,
                                       response_cls,
                                       method,
                                       body,
                                       body_type,
                                       state=state,
                                       http_args=http_args,
                                       algs=algs)

        if isinstance(resp, Message):
            # Error lists are keyed by *response* class name (see
            # parse_response); "AuthorizationRequest" is not such a key.
            # NOTE(review): resp.type() is a class name while the listed
            # entries appear to be classes - confirm this test can match.
            if resp.type() in RESPONSE2ERROR.get("AuthorizationResponse", []):
                resp.state = csi.state

        return resp

    def do_access_token_request(self,
                                request=AccessTokenRequest,
                                scope="",
                                state="",
                                body_type="json",
                                method="POST",
                                request_args=None,
                                extra_args=None,
                                http_args=None,
                                response_cls=AccessTokenResponse,
                                authn_method="",
                                **kwargs):
        """
        Build, send and parse an access token request.

        :param request: Request message class
        :param scope: Scope passed on to request_info
        :param state: State the grant is stored under
        :param body_type: Expected serialization of the response body
        :param method: HTTP method (POST by default)
        :param request_args: Request arguments
        :param extra_args: Extra arguments merged into the request
        :param http_args: HTTP client arguments; updated in place with the
            ones request_info produces
        :param response_cls: Expected response message class
        :param authn_method: Client authentication method name
        :param kwargs: Passed on to request_info and request_and_return
        :return: Whatever request_and_return returns
        """
        url, body, info_http_args, _cis = self.request_info(
            request,
            method=method,
            request_args=request_args,
            extra_args=extra_args,
            scope=scope,
            state=state,
            authn_method=authn_method,
            **kwargs)

        if http_args is not None:
            http_args.update(info_http_args)
        else:
            http_args = info_http_args

        # NOTE(review): body is logged verbatim; it may carry client
        # credentials depending on the authn method - consider sanitizing.
        logger.debug("<do_access_token> URL: %s, Body: %s" % (url, body))
        logger.debug("<do_access_token> response_cls: %s" % response_cls)

        return self.request_and_return(url, response_cls, method, body,
                                       body_type, state=state,
                                       http_args=http_args, **kwargs)

    def do_access_token_refresh(self, request=RefreshAccessTokenRequest,
                                state="", body_type="json", method="POST",
                                request_args=None, extra_args=None,
                                http_args=None,
                                response_cls=AccessTokenResponse,
                                authn_method="", **kwargs):
        """
        Refresh the access token held for ``state`` and parse the reply.

        The stored token is fetched with ``also_expired=True``, so an
        already expired token is accepted as the basis for the refresh.

        :param request: Message class used to construct the request
        :param state: State used to find the token and passed on to
            request_and_return
        :param body_type: Serialization format expected in the response body
        :param method: HTTP method, POST by default
        :param request_args: Arguments used when constructing the request
        :param extra_args: Extra arguments added to the request
        :param http_args: Extra HTTP arguments; the HTTP arguments produced by
            request_info are merged into this dict in place
        :param response_cls: Message class the response is parsed into
        :param authn_method: Client authentication method to use
        :param kwargs: Passed on to get_token only
        :return: The result of request_and_return
        """
        current = self.get_token(also_expired=True, state=state, **kwargs)

        uri, req_body, extra_http, _cis = self.request_info(
            request,
            method=method,
            request_args=request_args,
            extra_args=extra_args,
            token=current,
            authn_method=authn_method)

        if http_args is not None:
            http_args.update(extra_http)
        else:
            http_args = extra_http

        return self.request_and_return(uri, response_cls, method, req_body,
                                       body_type, state=state,
                                       http_args=http_args)

    def do_revocate_token(self, request=TokenRevocationRequest, scope="",
                          state="", body_type="json", method="POST",
                          request_args=None, extra_args=None, http_args=None,
                          response_cls=None, authn_method=""):
        """
        Send a token revocation request and parse the reply.

        :param request: Message class used to construct the request
        :param scope: Scope passed to request_info
        :param state: State passed to request_info and request_and_return
        :param body_type: Serialization format expected in the response body
        :param method: HTTP method, POST by default
        :param request_args: Arguments used when constructing the request
        :param extra_args: Extra arguments added to the request
        :param http_args: Extra HTTP arguments; the HTTP arguments produced by
            request_info are merged into this dict in place
        :param response_cls: Passed to request_and_return as the response class
        :param authn_method: Client authentication method to use
        :return: The result of request_and_return
        """
        target, payload, extra_http, _cis = self.request_info(
            request,
            method=method,
            request_args=request_args,
            extra_args=extra_args,
            scope=scope,
            state=state,
            authn_method=authn_method)

        if http_args is not None:
            http_args.update(extra_http)
        else:
            http_args = extra_http

        return self.request_and_return(target, response_cls, method, payload,
                                       body_type, state=state,
                                       http_args=http_args)

    def do_any(self, request, endpoint="", scope="", state="",
               body_type="json", method="POST", request_args=None,
               extra_args=None, http_args=None, response=None,
               authn_method=""):
        """
        Send an arbitrary request and parse the reply.

        :param request: Message class used to construct the request
        :param endpoint: Passed to request_info (presumably overrides the
            default endpoint URL -- confirm against request_info)
        :param scope: Scope passed to request_info
        :param state: State passed to request_info and request_and_return
        :param body_type: Serialization format expected in the response body
        :param method: HTTP method, POST by default
        :param request_args: Arguments used when constructing the request
        :param extra_args: Extra arguments added to the request
        :param http_args: Extra HTTP arguments; the HTTP arguments produced by
            request_info are merged into this dict in place
        :param response: Passed to request_and_return as the response class
        :param authn_method: Client authentication method to use
        :return: The result of request_and_return
        """
        target, payload, extra_http, _cis = self.request_info(
            request,
            method=method,
            request_args=request_args,
            extra_args=extra_args,
            scope=scope,
            state=state,
            authn_method=authn_method,
            endpoint=endpoint)

        if http_args is not None:
            http_args.update(extra_http)
        else:
            http_args = extra_http

        return self.request_and_return(target, response, method, payload,
                                       body_type, state=state,
                                       http_args=http_args)

    def fetch_protected_resource(self,
                                 uri,
                                 method="GET",
                                 headers=None,
                                 state="",
                                 **kwargs):
        """
        Fetch a protected resource, authenticating with an access token.

        :param uri: URL of the protected resource
        :param method: HTTP method to use
        :param headers: Extra HTTP headers; when given, the authentication
            headers are added to this dict in place
        :param state: State used to look up (and, if needed, refresh) the
            access token
        :param kwargs: A truthy ``token`` is used directly as the access
            token; an ``authn_method`` key selects init_authentication_method,
            otherwise a bearer header is constructed
        :return: The result of self.http_request
        """
        if kwargs.get("token"):
            token = kwargs["token"]
            request_args = {"access_token": token}
        else:
            try:
                token = self.get_token(state=state, **kwargs)
            except ExpiredToken:
                # The token is too old; refresh it. The refresh must use the
                # same state as the lookup, otherwise the default state=""
                # may refresh a different grant and the lookup below would
                # fail again.
                self.do_access_token_refresh(state=state)
                token = self.get_token(state=state, **kwargs)
            request_args = {"access_token": token.access_token}

        if headers is None:
            headers = {}

        if "authn_method" in kwargs:
            http_args = self.init_authentication_method(
                request_args=request_args, **kwargs)
        else:
            # If nothing defined this is the default
            http_args = self.client_authn_method["bearer_header"](
                self).construct(request_args=request_args)

        headers.update(http_args["headers"])

        logger.debug("Fetch URI: %s", uri)
        return self.http_request(uri, method, headers=headers)
Ejemplo n.º 22
0
class Client(PBase):
    _endpoints = ENDPOINTS

    def __init__(self, client_id=None, ca_certs=None, client_authn_method=None,
                 keyjar=None, verify_ssl=True):
        """
        Create a client instance.

        :param client_id: The client identifier
        :param ca_certs: Certificates used to verify HTTPS certificates
        :param client_authn_method: Methods this client can use to
            authenticate itself; a dictionary with method names as keys and
            method classes as values
        :param keyjar: Key jar to use; a new KeyJar is created when this is
            None or otherwise falsy
        :param verify_ssl: Whether SSL certificates should be verified
        """
        PBase.__init__(self, ca_certs, verify_ssl=verify_ssl)

        self.client_id = client_id
        self.client_authn_method = client_authn_method
        if keyjar:
            self.keyjar = keyjar
        else:
            self.keyjar = KeyJar(verify_ssl=verify_ssl)
        self.verify_ssl = verify_ssl

        # Per-session values
        self.state = None
        self.nonce = None
        self.grant = {}

        # Our own endpoint(s)
        self.redirect_uris = [None]

        # Endpoints at the service side
        self.authorization_endpoint = None
        self.token_endpoint = None
        self.token_revocation_endpoint = None

        # Lookup tables and the classes used for grants and tokens
        self.request2endpoint = REQUEST2ENDPOINT
        self.response2error = RESPONSE2ERROR
        self.grant_class = Grant
        self.token_class = Token

        self.provider_info = {}
        self._c_secret = None

    def get_client_secret(self):
        """Return the stored client secret."""
        return self._c_secret

    def set_client_secret(self, val):
        """
        Store the client secret.

        A falsy value is stored as the empty string. Any other value is
        stored as given and also added to the key jar as a symmetric "sig"
        key under the "" owner. The client uses it for signing; if the
        server also signs with it, the client uses it to verify the server's
        signatures.
        """
        if not val:
            self._c_secret = ""
            return

        self._c_secret = val
        if self.keyjar is None:
            self.keyjar = KeyJar()
        self.keyjar.add_symmetric("", str(val), ["sig"])

    client_secret = property(get_client_secret, set_client_secret)

    def reset(self):
        """
        Forget the per-session state and the known endpoints.

        ``redirect_uris`` is reset to ``[None]``, the same value __init__
        uses, so code that reads ``self.redirect_uris[0]`` keeps working
        after a reset (``None`` would raise TypeError there).
        """
        self.state = None
        self.nonce = None

        self.grant = {}

        self.authorization_endpoint = None
        self.token_endpoint = None
        self.redirect_uris = [None]

    def grant_from_state(self, state):
        """
        Return the grant stored under ``state``.

        :param state: The state key
        :return: The grant, or None if no grant is stored for that state
        """
        # A direct dict lookup instead of scanning every stored grant.
        return self.grant.get(state)

    def _parse_args(self, request, **kwargs):
        """
        Fill in request parameters from client attributes.

        Every parameter listed in ``request.c_param`` that is not already
        given in ``kwargs`` is taken from the client attribute with the same
        name, if that attribute is truthy. The exception is
        ``redirect_uri``: it is taken from the first entry of
        ``self.redirect_uris``.

        :param request: Message class whose parameters are filled in
        :param kwargs: Explicitly given parameter values; these always win
        :return: A new dict with the explicit and the filled-in values
        """
        ar_args = kwargs.copy()

        for prop in request.c_param.keys():
            if prop in ar_args:
                continue

            if prop == "redirect_uri":
                # redirect_uris may be None or an empty list; treat that as
                # "no redirect URI" instead of failing on [0].
                _uris = getattr(self, "redirect_uris", None)
                _val = _uris[0] if _uris else None
            else:
                _val = getattr(self, prop, None)

            if _val:
                ar_args[prop] = _val

        return ar_args

    def _endpoint(self, endpoint, **kwargs):
        """
        Find the URL to use for ``endpoint``.

        A truthy value in ``kwargs`` under the endpoint's name wins;
        otherwise the client attribute of that name is used.

        :param endpoint: Attribute name of the endpoint, e.g.
            "token_endpoint"
        :param kwargs: Request arguments that may override the URL
        :return: The endpoint URL
        :raises Exception: If no usable URL is found
        """
        # kwargs is this call's own dict, so deleting the key from it (as was
        # done before) had no effect; just read the override.
        uri = kwargs.get(endpoint)

        if not uri:
            # A missing attribute counts as "not specified"; other errors
            # raised while reading the attribute are no longer masked.
            uri = getattr(self, endpoint, None)

        if not uri:
            raise Exception("No '%s' specified" % endpoint)

        return uri

    def get_grant(self, **kwargs):
        """
        Return the grant for a state.

        :param kwargs: May contain ``state``; when it is missing or falsy
            ``self.state`` is used instead
        :return: The grant stored for that state
        :raises Exception: If no grant is stored for that state
        """
        _state = kwargs.get("state") or self.state

        try:
            return self.grant[_state]
        except KeyError:
            # Only a missing entry means "no grant"; the former bare except
            # also swallowed unrelated errors and KeyboardInterrupt.
            raise Exception("No grant found for state:'%s'" % _state)

    def get_token(self, also_expired=False, **kwargs):
        """
        Return the token to use for a request.

        An explicitly given ``token`` keyword is returned as-is, skipping the
        None and validity checks below. Otherwise the grant found by
        get_grant(**kwargs) is asked for the token bound to ``scope``. If no
        scope is given, the token bound to the empty scope is used, with
        ``self.grant[kwargs["state"]]`` as a further fallback.

        :param also_expired: If True, return the token even when it is no
            longer valid
        :param kwargs: May contain ``token``, ``state`` and ``scope``
        :return: The token
        :raises Exception: If no token can be found
        :raises ExpiredToken: If the token is not valid and also_expired is
            False
        """
        try:
            # An explicitly passed token short-circuits everything, including
            # the None/expiry checks at the end.
            return kwargs["token"]
        except KeyError:
            grant = self.get_grant(**kwargs)

            try:
                token = grant.get_token(kwargs["scope"])
            except KeyError:
                # No scope given (or grant.get_token raised KeyError): use the
                # token bound to the empty scope.
                token = grant.get_token("")
                if not token:
                    # Retry with the grant stored directly under
                    # kwargs["state"]; get_grant may have fallen back to
                    # self.state. A missing "state" kwarg or grant ends here.
                    try:
                        token = self.grant[kwargs["state"]].get_token("")
                    except KeyError:
                        raise Exception("No token found for scope")

        if token is None:
            raise Exception("No suitable token found")

        if also_expired:
            return token
        elif token.is_valid():
            return token
        else:
            raise ExpiredToken()

    def construct_request(self, request, request_args=None, extra_args=None):
        """
        Instantiate ``request`` from the given arguments.

        Parameters missing from ``request_args`` are filled in from client
        attributes by _parse_args. ``extra_args`` are applied last and
        override everything else.

        :param request: Message class to instantiate
        :param request_args: Request parameters
        :param extra_args: Additional parameters, applied last
        :return: An instance of ``request``
        """
        params = {} if request_args is None else request_args
        merged = self._parse_args(request, **params)

        if extra_args:
            merged.update(extra_args)

        return request(**merged)

    def construct_Message(self, request=Message, request_args=None,
                          extra_args=None, **kwargs):
        """
        Construct a generic message via construct_request.

        Extra keyword arguments are accepted for interface compatibility
        with the other construct_* methods and are ignored.
        """
        message_cls = request
        return self.construct_request(message_cls, request_args, extra_args)

    #noinspection PyUnusedLocal
    def construct_AuthorizationRequest(self, request=AuthorizationRequest,
                                       request_args=None, extra_args=None,
                                       **kwargs):
        """
        Construct an authorization request.

        A truthy ``redirect_uri`` in ``request_args`` also becomes the
        client's default (``self.redirect_uris``). ``client_id`` defaults to
        ``self.client_id`` when missing or falsy; this default is written
        into the ``request_args`` dict passed in.
        """
        if request_args is None:
            request_args = {}
        elif "redirect_uri" in request_args and request_args["redirect_uri"]:
            # change default
            self.redirect_uris = [request_args["redirect_uri"]]

        has_client_id = ("client_id" in request_args and
                         request_args["client_id"])
        if not has_client_id:
            request_args["client_id"] = self.client_id

        return self.construct_request(request, request_args, extra_args)

    #noinspection PyUnusedLocal
    def construct_AccessTokenRequest(self,
                                     request=AccessTokenRequest,
                                     request_args=None, extra_args=None,
                                     **kwargs):
        """
        Construct an access token request from the stored grant.

        The grant is found with get_grant(**kwargs) and its ``code`` is
        always used. ``grant_type`` defaults to "authorization_code" and
        ``client_id`` to ``self.client_id`` (when missing or falsy). All of
        these are written into the ``request_args`` dict passed in.

        :raises GrantExpired: If the grant is no longer valid
        """
        grant = self.get_grant(**kwargs)

        if not grant.is_valid():
            now = utc_time_sans_frac()
            raise GrantExpired("Authorization Code to old %s > %s" % (
                now, grant.grant_expiration_time))

        args = {} if request_args is None else request_args

        args["code"] = grant.code

        if "grant_type" not in args:
            args["grant_type"] = "authorization_code"

        if not ("client_id" in args and args["client_id"]):
            args["client_id"] = self.client_id

        return self.construct_request(request, args, extra_args)

    def construct_RefreshAccessTokenRequest(self,
                                            request=RefreshAccessTokenRequest,
                                            request_args=None, extra_args=None,
                                            **kwargs):
        """
        Construct a refresh request from the stored token.

        The token is fetched with ``also_expired=True``. Its
        ``refresh_token`` and, when the token has a ``scope`` attribute, its
        ``scope`` are written into ``request_args``.
        """
        args = {} if request_args is None else request_args

        current = self.get_token(also_expired=True, **kwargs)
        args["refresh_token"] = current.refresh_token

        try:
            token_scope = current.scope
        except AttributeError:
            pass
        else:
            args["scope"] = token_scope

        return self.construct_request(request, args, extra_args)

    def construct_TokenRevocationRequest(self,
                                         request=TokenRevocationRequest,
                                         request_args=None, extra_args=None,
                                         **kwargs):
        """
        Construct a revocation request for the current access token.

        The token comes from get_token(**kwargs), which refuses expired
        tokens unless one is passed in explicitly. Its ``access_token`` is
        written into ``request_args`` as ``token``.
        """
        args = {} if request_args is None else request_args
        args["token"] = self.get_token(**kwargs).access_token
        return self.construct_request(request, args, extra_args)

    def construct_ResourceRequest(self, request=ResourceRequest,
                                  request_args=None, extra_args=None,
                                  **kwargs):
        """
        Construct a resource request carrying the current access token.

        The token comes from get_token(**kwargs), which refuses expired
        tokens unless one is passed in explicitly. Its ``access_token`` is
        written into ``request_args``.
        """
        if request_args is None:
            request_args = {}

        current = self.get_token(**kwargs)
        request_args["access_token"] = current.access_token

        return self.construct_request(request, request_args, extra_args)

    def get_or_post(self, uri, method, req,
                    content_type=DEFAULT_POST_CONTENT_TYPE, **kwargs):
        """
        Prepare the URL and body for sending ``req`` with ``method``.

        For GET, the URL-encoded request is appended to ``uri`` as a query
        string and there is no body. For POST, the request is serialized into
        the body according to ``content_type``. A matching "content-type"
        header is then added to ``kwargs["headers"]``, updating that dict in
        place if it exists.

        :return: (path, body, kwargs)
        :raises UnSupported: For a content type other than URL_ENCODED or
            JSON_ENCODED
        :raises Exception: For an HTTP method other than GET or POST
        """
        if method == "GET":
            query = req.to_urlencoded()
            path = uri + '?' + query if query else uri
            return path, None, kwargs

        if method != "POST":
            raise Exception("Unsupported HTTP method: '%s'" % method)

        if content_type == URL_ENCODED:
            body = req.to_urlencoded()
        elif content_type == JSON_ENCODED:
            body = req.to_json()
        else:
            raise UnSupported(
                "Unsupported content type: '%s'" % content_type)

        header_ext = {"content-type": content_type}
        if "headers" in kwargs:
            kwargs["headers"].update(header_ext)
        else:
            kwargs["headers"] = header_ext

        return uri, body, kwargs

    def uri_and_body(self, reqmsg, cis, method="POST", request_args=None,
                     **kwargs):
        """
        Work out the URL, body and HTTP arguments for sending ``cis``.

        :param reqmsg: The request Message class; its name selects the
            default endpoint through self.request2endpoint
        :param cis: The constructed request instance
        :param method: HTTP method
        :param request_args: Request arguments; may override the endpoint URL
        :param kwargs: A truthy ``endpoint`` overrides the URL; everything is
            passed on to get_or_post
        :return: (uri, body, http_args, cis); http_args contains ``headers``
            when get_or_post produced any
        """
        if request_args is None:
            # ``**None`` would raise TypeError below, so normalise the default.
            request_args = {}

        uri = kwargs.get("endpoint")
        if not uri:
            uri = self._endpoint(self.request2endpoint[reqmsg.__name__],
                                 **request_args)

        uri, body, kwargs = self.get_or_post(uri, method, cis, **kwargs)

        if "headers" in kwargs:
            h_args = {"headers": kwargs["headers"]}
        else:
            h_args = {}

        return uri, body, h_args, cis

    def request_info(self, request, method="POST", request_args=None,
                     extra_args=None, **kwargs):
        """
        Construct a request and work out how to send it.

        If the client has a ``construct_<RequestClassName>`` method, it builds
        the request; otherwise the generic construct_request is used.

        :param request: Message class of the request
        :param method: HTTP method
        :param request_args: Request parameters
        :param extra_args: Extra parameters added to the request
        :param kwargs: Passed on to the construct method, to
            init_authentication_method (when ``authn_method`` is present) and
            to uri_and_body
        :return: (uri, body, http_args, request instance) from uri_and_body
        """
        if request_args is None:
            request_args = {}

        # Look the constructor up first and call it outside any try/except.
        # Otherwise an AttributeError raised *inside* a construct_* method
        # (e.g. from a token lacking an attribute) would silently fall back
        # to the generic constructor and yield an incomplete request.
        constructor = getattr(self, "construct_%s" % request.__name__, None)
        if constructor is None:
            cis = self.construct_request(request, request_args, extra_args)
        else:
            cis = constructor(request_args=request_args,
                              extra_args=extra_args, **kwargs)

        if "authn_method" in kwargs:
            h_arg = self.init_authentication_method(cis,
                                                    request_args=request_args,
                                                    **kwargs)
        else:
            h_arg = None

        if h_arg:
            # NOTE(review): h_arg is merged as a flat header dict; confirm
            # that init_authentication_method does not return
            # {"headers": {...}}, which would end up nested here.
            if "headers" in kwargs:
                kwargs["headers"].update(h_arg)
            else:
                kwargs["headers"] = h_arg

        return self.uri_and_body(request, cis, method, request_args,
                                 **kwargs)

    def authorization_request_info(self, request_args=None, extra_args=None,
                                   **kwargs):
        """
        Shortcut for request_info with an AuthorizationRequest sent by GET.
        """
        req_cls = AuthorizationRequest
        return self.request_info(req_cls, "GET", request_args, extra_args,
                                 **kwargs)

    def parse_response(self, response, info="", sformat="json", state="",
                       **kwargs):
        """
        Parse a response

        :param response: Message class expected for a successful response
        :param info: The response, in JSON or urlencoded format. For
            "urlencoded", a whole URL is accepted; its query part (or, if
            empty, its fragment) is then parsed.
        :param sformat: Which serialization was used
        :param state: Grant key used for an AuthorizationResponse or
            AccessTokenResponse that carries no state of its own
        :param kwargs: Passed to deserialize() and verify(); a ``scope``
            value is copied into an AuthorizationResponse that lacks one
        :return: The parsed and, to some extent, verified response. If it
            contains "error", it is an instance of the first error class
            from self.response2error (default ErrorResponse) that parses and
            verifies.
        :raises PyoidcError: If verify() of the response returns a false value
        :raises ResponseError: If no usable response was found and no other
            error was recorded
        """
        _r2e = self.response2error

        if sformat == "urlencoded":
            if '?' in info or '#' in info:
                # A full URL was given: the parameters are in the query or,
                # failing that, in the fragment.
                parts = urlparse.urlparse(info)
                info = parts.query or parts.fragment

        err = None
        try:
            resp = response().deserialize(info, sformat, **kwargs)
            if "error" in resp and not isinstance(resp, ErrorResponse):
                # An error arrived where a normal response was expected:
                # re-parse it with the matching error classes and keep the
                # first one that verifies.
                resp = None
                try:
                    errmsgs = _r2e[response.__name__]
                except KeyError:
                    errmsgs = [ErrorResponse]

                for errmsg in errmsgs:
                    try:
                        resp = errmsg().deserialize(info, sformat)
                        resp.verify()
                        break
                    except Exception as aerr:
                        resp = None
                        err = aerr
            elif resp.only_extras():
                resp = None
            else:
                verf = resp.verify(**kwargs)
                if not verf:
                    raise PyoidcError("Verification of the response failed")
                if resp.type() == "AuthorizationResponse" and \
                        "scope" not in resp:
                    try:
                        resp["scope"] = kwargs["scope"]
                    except KeyError:
                        pass
        except Exception as derr:
            resp = None
            err = derr

        if not resp:
            if err:
                raise err
            else:
                raise ResponseError("Missing or faulty response")

        if resp.type() in ["AuthorizationResponse", "AccessTokenResponse"]:
            try:
                _state = resp["state"]
            except (AttributeError, KeyError):
                _state = ""

            if not _state:
                _state = state

            # Record the response on the grant for this state, creating the
            # grant if there is none yet.
            try:
                self.grant[_state].update(resp)
            except KeyError:
                self.grant[_state] = self.grant_class(resp=resp)

        return resp