示例#1
0
class ImageSimpleCache(ImageCache):
    """Image cache that keeps its entries in a ``SimpleCache`` instance."""

    def __init__(self):
        """Run the base initializer, then create the backing ``SimpleCache``."""
        super(ImageSimpleCache, self).__init__()
        self.cache = SimpleCache()

    def get(self, key):
        """Look up ``key`` in the backing cache.

        :param key: the object's key
        :return: whatever the backing cache returns for ``key``
        :rtype: `BytesIO` object
        """
        stored = self.cache.get(key)
        return stored

    def set(self, key, value, timeout=None):
        """Store ``value`` under ``key``.

        :param key: the object's key
        :param value: the object to store
        :type value: `BytesIO` object
        :param timeout: the cache timeout in seconds; any falsy value
            (``None`` as well as ``0``) selects ``self.timeout`` instead
        """
        self.cache.set(key, value, timeout or self.timeout)

    def delete(self, key):
        """Remove ``key`` from the backing cache."""
        self.cache.delete(key)

    def flush(self):
        """Drop every entry held by the backing cache."""
        self.cache.clear()
示例#2
0
文件: sessions.py 项目: carylF/PCB
class _UserSessions(object):
  """Maps random session tokens to user ids.

  The token -> user id mapping lives in a ``SimpleCache``; the token itself
  is stored in ``session`` under ``'key'``.
  """

  def __init__(self):
    """Create the cache that holds the token -> user id mapping."""
    self.cache = SimpleCache()

  def create(self, user_id):
    """Start a session for ``user_id``.

    Returns the new random 24-byte token, or ``None`` when
    ``User.find`` reports no user for that id.
    """
    user = User.find(User.id == user_id)
    if user is None:
      return None
    sess = os.urandom(24)
    self.cache.set(sess, user_id)
    session['key'] = sess
    return sess

  def get(self):
    """Return the user bound to the current session token, or ``None``.

    The result is also written to ``session['user']``.
    """
    if 'key' not in session:
      return None
    user_id = self.cache.get(session['key'])
    if user_id is None:
      # The cache no longer knows this token (presumably expired or
      # evicted -- the cache is expected to return None on a miss).
      # Never query with ``User.id == None``: depending on how ``User``
      # implements the comparison that lookup could match a real row.
      session['user'] = None
      return None
    user = User.find(User.id == user_id)
    session['user'] = user
    return user

  def delete(self):
    """End the current session: drop the cached token and both session keys."""
    if 'key' in session:
      self.cache.delete(session['key'])
      session.pop('key', None)
      session.pop('user', None)
示例#3
0
class LocalSessionManager(SessionManager):
    """
    In-memory session manager.
    Not meant for multi-threaded/multi-process environments;
    use RedisSessionManager there instead.
    """
    def __init__(self):
        """Create the in-memory store that backs all sessions."""
        # A default timeout of 0 means entries do not expire on their own.
        # Each session gets its own timeout at creation time, matching the
        # expiration time of the access token.
        self.sessions = SimpleCache(default_timeout=0)

    def create_session(self, session_id, payload: dict, expire_seconds=None):
        """Add ``payload`` under ``session_id`` with the given timeout."""
        store = self.sessions
        store.add(session_id, payload, timeout=expire_seconds)

    def get(self, session_id):
        """Return whatever the store holds for ``session_id``."""
        entry = self.sessions.get(session_id)
        return entry

    def destroy_session(self, session_id):
        """Delete ``session_id`` from the store and return the store's result."""
        outcome = self.sessions.delete(session_id)
        return outcome
示例#4
0
class JwtManager(object):
    ALGORITHMS = "RS256"

    def __init__(self, app=None):

        # These are all set in the init_app function, but are listed here for easy reference
        self.app = app
        self.well_known_config = None
        self.well_known_obj_cache = None
        self.algorithms = JwtManager.ALGORITHMS
        self.jwks_uri = None
        self.issuer = None
        self.audience = None
        self.cache = None
        self.caching_enabled = False

        self.jwt_oidc_test_mode = False
        self.jwt_oidc_test_keys = None

        if app is not None:
            self.init_app(app)

    def init_app(self, app):
        """Initialize this extension from ``app.config``.

        Test mode (``JWT_OIDC_TEST_MODE`` truthy) reads:
            JWT_OIDC_TEST_ISSUER (default ``'localhost.localdomain'``)
            JWT_OIDC_TEST_KEYS (required; an Exception is raised when unset)
            JWT_OIDC_TEST_AUDIENCE
            JWT_OIDC_TEST_PRIVATE_KEY_PEM

        Otherwise the following are read:
            JWT_OIDC_ALGORITHMS (default ``JwtManager.ALGORITHMS``)
            JWT_OIDC_WELL_KNOWN_CONFIG (optional) if this is set, the
                JWKS_URI & ISSUER are loaded from that URL
            JWT_OIDC_JWKS_URI / JWT_OIDC_ISSUER: used when no well-known
                config is set
            JWT_OIDC_CACHING_ENABLED (default ``False``) and
                JWT_OIDC_JWKS_CACHE_TIMEOUT (default 300)
            JWT_OIDC_AUDIENCE

        In both modes ``JWT_OIDC_AUTH_ERROR_HANDLER`` (default
        ``JwtManager.handle_auth_error``) is registered for ``AuthError`` and
        ``self.teardown`` is registered as an app-context teardown hook.
        """
        self.app = app
        self.jwt_oidc_test_mode = app.config.get('JWT_OIDC_TEST_MODE', None)
        #
        ## CHECK IF WE'RE RUNNING IN TEST_MODE!!
        #
        if self.jwt_oidc_test_mode:
            app.logger.debug(
                'JWT MANAGER running in test mode, using locally defined certs & tokens'
            )

            self.issuer = app.config.get('JWT_OIDC_TEST_ISSUER',
                                         'localhost.localdomain')
            self.jwt_oidc_test_keys = app.config.get('JWT_OIDC_TEST_KEYS',
                                                     None)
            self.audience = app.config.get('JWT_OIDC_TEST_AUDIENCE', None)
            self.jwt_oidc_test_private_key_pem = app.config.get(
                'JWT_OIDC_TEST_PRIVATE_KEY_PEM', None)

            if self.jwt_oidc_test_keys:
                app.logger.debug('local key being used: {}'.format(
                    self.jwt_oidc_test_keys))
            else:
                app.logger.error(
                    'Attempting to run JWT Manager with no local key assigned')
                raise Exception(
                    'Attempting to run JWT Manager with no local key assigned')

        else:

            self.algorithms = [
                app.config.get('JWT_OIDC_ALGORITHMS', JwtManager.ALGORITHMS)
            ]

            # If the WELL_KNOWN_CONFIG is set, then go fetch the JWKS & ISSUER
            self.well_known_config = app.config.get(
                'JWT_OIDC_WELL_KNOWN_CONFIG', None)
            if self.well_known_config:
                # try to get the jwks & issuer from the well known config
                jurl = urlopen(url=self.well_known_config)
                try:
                    self.well_known_obj_cache = json.loads(
                        jurl.read().decode("utf-8"))
                finally:
                    # Always release the HTTP response, even when reading
                    # or parsing fails.
                    jurl.close()

                self.jwks_uri = self.well_known_obj_cache['jwks_uri']
                self.issuer = self.well_known_obj_cache['issuer']
            else:

                self.jwks_uri = app.config.get('JWT_OIDC_JWKS_URI', None)
                self.issuer = app.config.get('JWT_OIDC_ISSUER', None)

            # Setup JWKS caching
            self.caching_enabled = app.config.get('JWT_OIDC_CACHING_ENABLED',
                                                  False)
            if self.caching_enabled:
                # NOTE(review): werkzeug.contrib was removed in Werkzeug 1.0;
                # confirm the pinned Werkzeug version still ships this module.
                from werkzeug.contrib.cache import SimpleCache
                self.cache = SimpleCache(default_timeout=app.config.get(
                    'JWT_OIDC_JWKS_CACHE_TIMEOUT', 300))

            self.audience = app.config.get('JWT_OIDC_AUDIENCE', None)

        app.logger.debug('JWKS_URI: {}'.format(self.jwks_uri))
        app.logger.debug('ISSUER: {}'.format(self.issuer))
        app.logger.debug('ALGORITHMS: {}'.format(self.algorithms))
        app.logger.debug('AUDIENCE: {}'.format(self.audience))
        app.logger.debug('JWT_OIDC_TEST_MODE: {}'.format(
            self.jwt_oidc_test_mode))
        app.logger.debug('JWT_OIDC_TEST_KEYS: {}'.format(
            self.jwt_oidc_test_keys))

        # set the auth error handler
        auth_err_handler = app.config.get('JWT_OIDC_AUTH_ERROR_HANDLER',
                                          JwtManager.handle_auth_error)
        app.register_error_handler(AuthError, auth_err_handler)

        app.teardown_appcontext(self.teardown)

    def teardown(self, exception):
        pass
        # ctx = _app_ctx_stack.top
        # if hasattr(ctx, 'cached object'):

    @staticmethod
    def handle_auth_error(ex):
        """Turn an auth error into a JSON response.

        The body is ``jsonify(ex.error)`` and the HTTP status is
        ``ex.status_code``.
        """
        resp = jsonify(ex.error)
        resp.status_code = ex.status_code
        return resp

    def get_token_auth_header(self):
        """Obtain the access token from the Authorization header.

        Returns:
            The second whitespace-separated part of the header, i.e. the
            token in ``Bearer <token>``.

        Raises:
            AuthError: (401) when the header is missing or empty, does not
                start with ``Bearer``, carries no token, or has more than
                two parts.
        """

        auth = request.headers.get("Authorization", None)
        if not auth:
            raise AuthError(
                {
                    "code": "authorization_header_missing",
                    "description": "Authorization header is expected"
                }, 401)

        parts = auth.split()

        # A whitespace-only header is truthy but splits to an empty list;
        # check for that before indexing so it yields a 401, not IndexError.
        if not parts or parts[0].lower() != "bearer":
            raise AuthError(
                {
                    "code": "invalid_header",
                    "description":
                    "Authorization header must start with Bearer"
                }, 401)

        if len(parts) < 2:
            raise AuthError(
                {
                    "code": "invalid_header",
                    "description": "Token not found after Bearer"
                }, 401)

        if len(parts) > 2:
            raise AuthError(
                {
                    "code":
                    "invalid_header",
                    "description":
                    "Authorization header is an invalid token structure"
                }, 401)

        return parts[1]

    def contains_role(self, roles):
        """Report whether the token carries at least one of ``roles``.

        The token's claims are read WITHOUT signature verification and passed
        to the ``JWT_ROLE_CALLBACK`` callable from the Flask configuration,
        which returns the roles present in the token.

        Args:
            roles: iterable of acceptable role names
        Returns:
            True when any entry of ``roles`` is among the token's roles.
        """
        claims = jwt.get_unverified_claims(self.get_token_auth_header())
        granted = current_app.config['JWT_ROLE_CALLBACK'](claims)
        return any(role in granted for role in roles)

    def has_one_of_roles(self, roles):
        """Checks that at least one of the roles are in the token
           using the registered callback
        Args:
            roles [str,]: Comma separated list of valid roles
            JWT_ROLE_CALLBACK (fn): The callback added to the Flask configuration
        """
        def decorated(f):
            @wraps(f)
            def wrapper(*args, **kwargs):
                self._require_auth_validation(*args, **kwargs)
                if self.contains_role(roles):
                    return f(*args, **kwargs)
                raise AuthError(
                    {
                        "code":
                        "missing_a_valid_role",
                        "description":
                        "Missing a role required to access this endpoint"
                    }, 401)

            return wrapper

        return decorated

    def validate_roles(self, required_roles):
        """Report whether the token carries every one of ``required_roles``.

        The token's claims are read WITHOUT signature verification and passed
        to the ``JWT_ROLE_CALLBACK`` callable from the Flask configuration,
        which returns the roles present in the token.

        Args:
            required_roles: iterable of role names that must all be present
        Returns:
            True when every entry of ``required_roles`` is among the token's
            roles.
        """
        claims = jwt.get_unverified_claims(self.get_token_auth_header())
        granted = current_app.config['JWT_ROLE_CALLBACK'](claims)
        return all(role in granted for role in required_roles)

    def requires_roles(self, required_roles):
        """Checks that the listed roles are in the token
           using the registered callback
        Args:
            required_roles [str,]: Comma separated list of required roles
            JWT_ROLE_CALLBACK (fn): The callback added to the Flask configuration
        """
        def decorated(f):
            @wraps(f)
            def wrapper(*args, **kwargs):
                self._require_auth_validation(*args, **kwargs)
                if self.validate_roles(required_roles):
                    return f(*args, **kwargs)
                raise AuthError(
                    {
                        "code":
                        "missing_required_roles",
                        "description":
                        "Missing the role(s) required to access this endpoint"
                    }, 401)

            return wrapper

        return decorated

    def requires_auth(self, f):
        """Validates the Bearer Token
        """
        @wraps(f)
        def decorated(*args, **kwargs):
            self._require_auth_validation(*args, **kwargs)

            return f(*args, **kwargs)

        return decorated

    def _require_auth_validation(self, *args, **kwargs):
        """Validate the request's bearer token and publish its payload.

        Steps: read the token from the Authorization header, inspect its
        unverified header (reject ``HS256``, require ``kid``), look up the
        matching key in the JWKS (retrying once with a cleared cache when
        caching is enabled), then decode the token against the configured
        algorithms, audience and issuer.

        On success the payload is stored on
        ``_request_ctx_stack.top.current_user`` and ``g.jwt_oidc_token_info``.

        ``*args`` / ``**kwargs`` are accepted for the decorators that forward
        the view's arguments; they are not used here.

        Raises:
            AuthError: (401) for every validation failure.
        """
        token = self.get_token_auth_header()

        try:
            unverified_header = jwt.get_unverified_header(token)
        except jwt.JWTError:
            raise AuthError(
                {
                    "code":
                    "invalid_header",
                    "description":
                    "Invalid header. "
                    "Use an RS256 signed JWT Access Token"
                }, 401)
        # Use .get(): a header without "alg" must not escape as a KeyError.
        # Such a token goes on to jwt.decode below, where any failure is
        # mapped to a 401.
        if unverified_header.get("alg") == "HS256":
            raise AuthError(
                {
                    "code":
                    "invalid_header",
                    "description":
                    "Invalid header. "
                    "Use an RS256 signed JWT Access Token"
                }, 401)
        if "kid" not in unverified_header:
            raise AuthError(
                {
                    "code": "invalid_header",
                    "description": "Invalid header. "
                    "No KID in token header"
                }, 401)

        rsa_key = self.get_rsa_key(self.get_jwks(), unverified_header["kid"])

        if not rsa_key and self.caching_enabled:
            # Could be key rotation, invalidate the cache and try again
            self.cache.delete('jwks')
            rsa_key = self.get_rsa_key(self.get_jwks(),
                                       unverified_header["kid"])

        if not rsa_key:
            raise AuthError(
                {
                    "code": "invalid_header",
                    "description":
                    "Unable to find jwks key referenced in token"
                }, 401)

        try:
            payload = jwt.decode(token,
                                 rsa_key,
                                 algorithms=self.algorithms,
                                 audience=self.audience,
                                 issuer=self.issuer)
            _request_ctx_stack.top.current_user = g.jwt_oidc_token_info = payload
        except jwt.ExpiredSignatureError:
            raise AuthError(
                {
                    "code": "token_expired",
                    "description": "token has expired"
                }, 401)
        except jwt.JWTClaimsError:
            raise AuthError(
                {
                    "code":
                    "invalid_claims",
                    "description":
                    "incorrect claims,"
                    " please check the audience and issuer"
                }, 401)
        except Exception:
            # Deliberately broad: any other failure while decoding or
            # storing the payload is reported as an unparsable token.
            raise AuthError(
                {
                    "code": "invalid_header",
                    "description": "Unable to parse authentication"
                    " token."
                }, 401)

    def get_jwks(self):
        if self.jwt_oidc_test_mode:
            return self.jwt_oidc_test_keys

        if self.caching_enabled:
            return self._get_jwks_from_cache()
        else:
            return self._fetch_jwks_from_url()

    def _get_jwks_from_cache(self):
        jwks = self.cache.get('jwks')
        if jwks is None:
            jwks = self._fetch_jwks_from_url()
            self.cache.set('jwks', jwks)
        return jwks

    def _fetch_jwks_from_url(self):
        jsonurl = urlopen(self.jwks_uri)
        return json.loads(jsonurl.read().decode("utf-8"))

    def create_jwt(self, claims, header):
        """Encode ``claims`` as an RS256 token signed with the test private key.

        Args:
            claims: the claims to encode
            header: passed to ``jwt.encode`` as ``headers``
        Returns:
            whatever ``jwt.encode`` returns
        """
        return jwt.encode(claims,
                          self.jwt_oidc_test_private_key_pem,
                          headers=header,
                          algorithm='RS256')

    def get_rsa_key(self, jwks, kid):
        rsa_key = {}
        for key in jwks["keys"]:
            if key["kid"] == kid:
                rsa_key = {
                    "kty": key["kty"],
                    "kid": key["kid"],
                    "use": key["use"],
                    "n": key["n"],
                    "e": key["e"]
                }
        return rsa_key
示例#5
0
class JWTAuthenticationBackend(AuthenticationBackend):  # pylint: disable=too-many-instance-attributes
    """JWT authentication backend for AuthenticationMiddleware."""
    def __init__(self):
        """Set every attribute to its unconfigured default.

        The real values are filled in later by ``init_app`` and
        ``set_testing_keys``.
        """
        # Token format
        self.algorithm = 'RS256'
        self.prefix = 'bearer'

        # OIDC endpoints and claims
        self.well_known_config = None
        self.well_known_obj_cache = None
        self.jwks_uri = None
        self.issuer = None
        self.audience = None
        self.client_secret = None

        # JWKS caching
        self.cache = None
        self.caching_enabled = False

        # Test mode
        self.jwt_oidc_test_mode = False
        self.jwt_oidc_test_public_key_pem = None
        self.jwt_oidc_test_private_key_pem = None

    def init_app(self, test_mode: bool = False):
        """Configure the backend from the API settings.

        In test mode only the mode flag is recorded; nothing is read from
        the settings.
        """
        self.jwt_oidc_test_mode = test_mode
        if self.jwt_oidc_test_mode:
            return

        self.algorithm = get_api_settings().JWT_OIDC_ALGORITHMS

        # When a well-known config URL is set, the JWKS URI and issuer are
        # taken from that document; otherwise from the settings directly.
        self.well_known_config = get_api_settings().JWT_OIDC_WELL_KNOWN_CONFIG
        if self.well_known_config:
            with urlopen(url=self.well_known_config) as response:
                raw = response.read()
            self.well_known_obj_cache = json.loads(raw.decode('utf-8'))

            self.jwks_uri = self.well_known_obj_cache['jwks_uri']
            self.issuer = self.well_known_obj_cache['issuer']
        else:
            self.jwks_uri = get_api_settings().JWT_OIDC_JWKS_URI
            self.issuer = get_api_settings().JWT_OIDC_ISSUER

        # JWKS caching is optional.
        self.caching_enabled = get_api_settings().JWT_OIDC_CACHING_ENABLED
        if self.caching_enabled:
            cache_timeout = get_api_settings().JWT_OIDC_JWKS_CACHE_TIMEOUT
            self.cache = SimpleCache(default_timeout=cache_timeout)

        self.audience = get_api_settings().JWT_OIDC_AUDIENCE
        self.client_secret = get_api_settings().JWT_OIDC_CLIENT_SECRET

    @classmethod
    def get_token_from_header(cls, authorization: str, prefix: str):
        """Extract the token from a ``<scheme> <token>`` header value.

        Raises ``AuthenticationError`` when the value does not split into
        exactly two parts, or when the scheme differs from ``prefix``
        (compared case-insensitively).
        """
        try:
            scheme, token = authorization.split()
        except ValueError:
            raise AuthenticationError(
                'Could not separate Authorization scheme and token')

        if scheme.lower() == prefix.lower():
            return token

        raise AuthenticationError(
            f'Authorization scheme {scheme} is not supported')

    async def authenticate(self, request):  # pylint: disable=arguments-renamed
        """Authenticate the request's bearer token.

        Returns:
            ``None`` when the request has no Authorization header; otherwise
            a tuple ``(AuthCredentials(['authenticated']), JWTUser(...))``
            built from the decoded payload.

        Raises:
            AuthenticationError: when the header or token is malformed, the
                signing key cannot be found, decoding fails, or the payload
                has no ``username`` claim.
        """
        if 'Authorization' not in request.headers:
            return None

        auth = request.headers['Authorization']
        token = self.get_token_from_header(authorization=auth,
                                           prefix=self.prefix)

        # jwt.decode expects a list of algorithm names; wrap a bare string
        # so the check is an exact membership test, not a substring match.
        algorithms = ([self.algorithm]
                      if isinstance(self.algorithm, str) else self.algorithm)

        if self.jwt_oidc_test_mode:
            # in test mode, use a public key to decode the token directly.
            try:
                payload = jwt.decode(str.encode(token),
                                     self.jwt_oidc_test_public_key_pem,
                                     algorithms=algorithms)
            except jwt.InvalidTokenError as e:
                raise AuthenticationError(str(e)) from e
        else:
            # in production mode, get the public key from the JWKS URL
            try:
                unverified_header = jwt.get_unverified_header(token)
            except jwt.PyJWTError as e:
                raise AuthenticationError(
                    'Invalid header: Use an RS256 signed JWT Access Token'
                ) from e
            # .get(): a header without "alg" must not escape as a KeyError.
            if unverified_header.get('alg') == 'HS256':
                raise AuthenticationError(
                    'Invalid header: Use an RS256 signed JWT Access Token')
            if 'kid' not in unverified_header:
                raise AuthenticationError(
                    'Invalid header: No KID in token header')

            rsa_key = self.get_rsa_key(self.get_jwks(),
                                       unverified_header['kid'])

            if not rsa_key and self.caching_enabled:
                # Could be key rotation, invalidate the cache and try again
                self.cache.delete('jwks')
                rsa_key = self.get_rsa_key(self.get_jwks(),
                                           unverified_header['kid'])

            if not rsa_key:
                raise AuthenticationError(
                    'invalid_header: Unable to find jwks key referenced in token'
                )

            public_key = RSAAlgorithm.from_jwk(json.dumps(rsa_key))

            try:
                payload = jwt.decode(token,
                                     public_key,
                                     algorithms=algorithms,
                                     audience=self.audience)
            except jwt.InvalidTokenError as e:
                raise AuthenticationError(str(e)) from e

        # A valid token without the claim is an authentication failure,
        # not a server error.
        if 'username' not in payload:
            raise AuthenticationError(
                'invalid_claims: token has no username claim')

        return AuthCredentials(['authenticated'
                                ]), JWTUser(username=payload['username'],
                                            token=token,
                                            payload=payload)

    def get_jwks(self):
        """Return the JWKS, via the cache when caching is enabled."""
        fetch = (self._get_jwks_from_cache
                 if self.caching_enabled else self._fetch_jwks_from_url)
        return fetch()

    def _get_jwks_from_cache(self):
        """Return the JWKS cached under ``'jwks'``, fetching and caching it on a miss."""
        cached = self.cache.get('jwks')
        if cached is not None:
            return cached

        fresh = self._fetch_jwks_from_url()
        self.cache.set('jwks', fresh)
        return fresh

    def _fetch_jwks_from_url(self):
        """Download ``self.jwks_uri`` and return its UTF-8 body parsed as JSON."""
        with urlopen(url=self.jwks_uri) as response:
            body = response.read()
        return json.loads(body.decode('utf-8'))

    def get_rsa_key(self, jwks, kid):  # pylint: disable=no-self-use
        """Pick the key with id ``kid`` out of a JWKS document.

        Returns a dict with ``kty``, ``kid``, ``n``, ``e`` (plus ``use`` when
        the key declares it) for the matching key, or ``{}`` when no key
        matches. If several keys share the id, the last one wins.

        Raises ``KeyError`` when the matching key lacks ``kty``, ``n`` or
        ``e``.
        """
        rsa_key = {}
        for key in jwks['keys']:
            # 'kid' is optional in a JWK: skip keys without one instead of
            # letting a single such entry break every lookup.
            if 'kid' not in key or key['kid'] != kid:
                continue
            rsa_key = {'kty': key['kty'], 'kid': key['kid']}
            # 'use' is optional as well; copy it only when present.
            if 'use' in key:
                rsa_key['use'] = key['use']
            rsa_key['n'] = key['n']
            rsa_key['e'] = key['e']
        return rsa_key

    def create_testing_jwt(self, claims, header):
        """Create an RS256 test token signed with the test private key.

        Always returns ``str``, whether ``jwt.encode`` hands back ``bytes``
        or ``str``.
        """
        token = jwt.encode(claims,
                           self.jwt_oidc_test_private_key_pem,
                           headers=header,
                           algorithm='RS256')
        # Only bytes need decoding; calling .decode on a str would raise
        # AttributeError.
        if isinstance(token, bytes):
            token = token.decode('utf-8')
        return token

    def set_testing_keys(self, private_key, public_key):
        """Store the PEM key pair used in test mode."""
        self.jwt_oidc_test_private_key_pem, self.jwt_oidc_test_public_key_pem = (
            private_key, public_key)
示例#6
0
class ImageSimpleCache(ImageCache):
    """Image cache that keeps its entries in a ``SimpleCache`` instance.

    Next to every cached object a "last modification" timestamp is stored
    under a companion key derived via ``_last_modification_key_name``
    (not defined in this class; presumably inherited from ``ImageCache``).
    """
    def __init__(self):
        """Run the base initializer, then create the backing ``SimpleCache``."""
        super(ImageSimpleCache, self).__init__()
        self.cache = SimpleCache()

    def get(self, key):
        """Look up ``key`` in the backing cache.

        :param key: the object's key
        :return: whatever the backing cache returns for ``key``
        :rtype: `BytesIO` object
        """
        stored = self.cache.get(key)
        return stored

    def set(self, key, value, timeout=None):
        """Store ``value`` under ``key`` and stamp its last modification.

        :param key: the object's key
        :param value: the object to store
        :type value: `BytesIO` object
        :param timeout: the cache timeout in seconds; any falsy value
            (``None`` as well as ``0``) selects ``self.timeout`` instead
        """
        ttl = timeout or self.timeout
        self.cache.set(key, value, ttl)
        self.set_last_modification(key, timeout=ttl)

    def get_last_modification(self, key):
        """Return the stored last-modification value for ``key``.

        :param key: the file object's key
        """
        return self.cache.get(self._last_modification_key_name(key))

    def set_last_modification(self, key, last_modification=None, timeout=None):
        """Record when the file cached under ``key`` was last modified.

        Does nothing when ``key`` is falsy.

        :param key: the file object's key
        :param last_modification: last modification date of the file
            represented by the key; when falsy, the current UTC time
            truncated to whole seconds is used
        :type last_modification: datetime
        :param timeout: the cache timeout in seconds; any falsy value
            selects ``self.timeout`` instead
        """
        if not key:
            return
        stamp = last_modification or datetime.utcnow().replace(microsecond=0)
        ttl = timeout or self.timeout
        self.cache.set(self._last_modification_key_name(key), stamp, ttl)

    def delete(self, key):
        """Remove ``key`` and its last-modification entry; no-op for a falsy key."""
        if not key:
            return
        self.cache.delete(key)
        self.cache.delete(self._last_modification_key_name(key))

    def flush(self):
        """Drop every entry held by the backing cache."""
        self.cache.clear()