Ejemplo n.º 1
0
    def test_ec_deserialization(self):
        """
        Round-trip an ES256 token through the local test server.

        A token signed with the EC key must deserialize into a SciToken; a
        token signed under a key id that was not passed to ``start_server``
        must raise MissingKeyException.
        """
        ec_numbers = {'x': self.ec_public_numbers.x,
                      'y': self.ec_public_numbers.y,
                      'kid': self.ec_test_id}
        server_address = start_server(self.public_numbers.n,
                                      self.public_numbers.e,
                                      self.test_id,
                                      test_ec=ec_numbers)
        print(server_address)
        port = server_address[1]
        issuer = "http://localhost:{}/".format(port)

        ec_token = scitokens.SciToken(key=self.ec_private_key,
                                      key_id=self.ec_test_id,
                                      algorithm="ES256")
        ec_token.update_claims({"test": "true"})
        encoded = ec_token.serialize(issuer=issuer)
        segments = encoded.decode('utf8').split(".")
        self.assertEqual(len(segments), 3)

        decoded = scitokens.SciToken.deserialize(encoded, insecure=True)
        self.assertIsInstance(decoded, scitokens.SciToken)

        # Signed with the RSA key under a key id the server was never given.
        unknown_kid_token = scitokens.SciToken(key=self.private_key,
                                               key_id="doesnotexist")
        unknown_encoded = unknown_kid_token.serialize(issuer=issuer)
        with self.assertRaises(scitokens.utils.errors.MissingKeyException):
            scitokens.SciToken.deserialize(unknown_encoded, insecure=True)
Ejemplo n.º 2
0
    def test_deserialization(self):
        """
        Serialize an RS256 token and deserialize it again.

        Also stores the key id, modulus and exponent in the module-level
        TEST_ID, TEST_N and TEST_E globals (presumably read by the test key
        server — confirm), and checks that an unknown key id raises
        MissingKeyException.
        """
        global TEST_N
        global TEST_E
        global TEST_ID
        with open('tests/simple_private_key.pem', 'rb') as pem_file:
            pem_bytes = pem_file.read()
        rsa_key = serialization.load_pem_private_key(
            pem_bytes, password=None, backend=default_backend())
        TEST_ID = "stuffblah"

        issuer = "http://localhost:8080/"
        signed = scitokens.SciToken(key=rsa_key, key_id=TEST_ID)
        signed.update_claims({"test": "true"})
        encoded = signed.serialize(issuer=issuer)

        # Publish the public numbers before any deserialization happens.
        numbers = rsa_key.public_key().public_numbers()
        TEST_E, TEST_N = numbers.e, numbers.n

        self.assertEqual(len(encoded.decode('utf8').split(".")), 3)

        decoded = scitokens.SciToken.deserialize(encoded, insecure=True)
        self.assertIsInstance(decoded, scitokens.SciToken)

        stray = scitokens.SciToken(key=rsa_key, key_id="doesnotexist")
        stray_encoded = stray.serialize(issuer=issuer)
        with self.assertRaises(scitokens.utils.errors.MissingKeyException):
            scitokens.SciToken.deserialize(stray_encoded, insecure=True)
Ejemplo n.º 3
0
    def test_autodetect_keytype(self):
        """
        Check that SciToken validates the key type against the algorithm.

        An RSA key with ES256, an EC key with RS256, and a SECP192R1 key must
        all raise UnsupportedKeyException; a P-256 key with ES256 must sign.
        """
        rsa_key = generate_private_key(public_exponent=65537,
                                       key_size=2048,
                                       backend=default_backend())
        ec_key = ec.generate_private_key(ec.SECP256R1(), default_backend())
        unsupported_error = scitokens.scitokens.UnsupportedKeyException

        # RSA key paired with an EC algorithm.
        with self.assertRaises(unsupported_error):
            scitokens.SciToken(key=rsa_key, algorithm="ES256")

        # EC key paired with an RSA algorithm.
        with self.assertRaises(unsupported_error):
            scitokens.SciToken(key=ec_key, algorithm="RS256")

        # A curve the library does not accept.
        weak_ec_key = ec.generate_private_key(ec.SECP192R1(),
                                              default_backend())
        with self.assertRaises(unsupported_error):
            scitokens.SciToken(key=weak_ec_key)

        valid_token = scitokens.SciToken(key=ec_key, algorithm="ES256")
        valid_token.serialize(issuer="local")
Ejemplo n.º 4
0
    def setUp(self):
        """
        Build two RSA-signed sample tokens for the enforcer tests.

        ``self._token`` carries ``foo`` plus iat/exp/iss/nbf claims;
        ``self._token2`` additionally sets ``ver``, ``wlcg.groups`` and
        ``aud``. Both expire 600 seconds after creation.
        """
        now = time.time()
        signing_key = cryptography.hazmat.primitives.asymmetric.rsa.generate_private_key(
            public_exponent=65537,
            key_size=2048,
            backend=cryptography.hazmat.backends.default_backend())
        issued_at = int(now)
        expires_at = int(now + 600)

        self._token = scitokens.SciToken(key=signing_key)
        for claim, value in (("foo", "bar"),
                             ("iat", issued_at),
                             ("exp", expires_at),
                             ("iss", self._test_issuer),
                             ("nbf", issued_at)):
            self._token[claim] = value

        # Scitoken v2
        self._token2 = scitokens.SciToken(key=signing_key)
        for claim, value in (("ver", "scitoken:2.0"),
                             ("foo", "bar"),
                             ("iat", issued_at),
                             ("exp", expires_at),
                             ("iss", self._test_issuer),
                             ("nbf", issued_at),
                             ('wlcg.groups', ['groupA', 'groupB']),
                             ("aud", "ANY")):
            self._token2[claim] = value
Ejemplo n.º 5
0
 def setUp(self):
     """
     Create an RSA key pair and two tokens signed with it.

     The public key is registered in the KeyCache singleton under issuer
     "local" / key id "sample_key". ``self._token`` carries that key id and
     ``self._no_kid_token`` carries none.
     """
     self._private_key = generate_private_key(public_exponent=65537,
                                              key_size=2048,
                                              backend=default_backend())
     self._public_key = self._private_key.public_key()
     pem_encoding = serialization.Encoding.PEM
     spki_format = serialization.PublicFormat.SubjectPublicKeyInfo
     self._public_pem = self._public_key.public_bytes(encoding=pem_encoding,
                                                      format=spki_format)

     key_cache = scitokens.utils.keycache.KeyCache.getinstance()
     key_cache.addkeyinfo("local", "sample_key",
                          self._private_key.public_key())

     self._token = scitokens.SciToken(key=self._private_key,
                                      key_id="sample_key")
     self._no_kid_token = scitokens.SciToken(key=self._private_key)
Ejemplo n.º 6
0
    def test_public_key(self):
        """
        Deserialize an RS256 token against an explicitly supplied public key.

        The matching PEM must succeed, non-PEM bytes must raise ValueError,
        and a well-formed PEM for an unrelated RSA key must raise DecodeError.
        """
        def to_pem(public_key):
            # SubjectPublicKeyInfo PEM, the format deserialize() is given here.
            return public_key.public_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PublicFormat.SubjectPublicKeyInfo)

        signer = scitokens.SciToken(key=self._private_key)
        encoded = signer.serialize(issuer="local")

        restored = scitokens.SciToken.deserialize(encoded,
                                                  public_key=self._public_pem,
                                                  insecure=True)
        self.assertIsInstance(restored, scitokens.SciToken)

        # Garbage key material.
        with self.assertRaises(ValueError):
            scitokens.SciToken.deserialize(encoded,
                                           insecure=True,
                                           public_key="asdf".encode())

        # A valid RSA key that did not sign the token.
        other_key = generate_private_key(public_exponent=65537,
                                         key_size=2048,
                                         backend=default_backend())
        other_pem = to_pem(other_key.public_key())
        with self.assertRaises(DecodeError):
            scitokens.SciToken.deserialize(encoded,
                                           insecure=True,
                                           public_key=other_pem)
Ejemplo n.º 7
0
    def refresh_access_token(self, username, token_name):
        """
        Mint a SciToken for *username* and write it as an access-token file.

        The token is signed with ``self._private_key`` using ES256 and key id
        ``self._private_key_id``. It carries ``sub``, ``scope`` (rendered from
        ``self.authz_template``) and ``ver`` claims. It is written as a JSON
        document with ``access_token`` and ``expires_in`` fields to
        ``<cred_dir>/<username>/<token_name>.use`` via ``atomic_output_json``.

        :param username: user the token is issued for.
        :param token_name: base name of the output file (``.use`` is appended).
        :return: True on success; False if the token could not be built or
            signed, or the file could not be written (the failure is logged).
        """
        user_authz = self.authz_template.format(username=username)

        # A key that does not match ES256 is rejected when the SciToken is
        # constructed (UnsupportedKeyException), not only when it is
        # serialized, so construction must be inside the guarded region too.
        try:
            lifetime = int(self.token_lifetime)
            token = scitokens.SciToken(algorithm="ES256",
                                       key=self._private_key,
                                       key_id=self._private_key_id)
            token.update_claims({'sub': username,
                                 'scope': user_authz,
                                 'ver': 'scitokens:2.0'})
            serialized_token = token.serialize(issuer=self.token_issuer,
                                               lifetime=lifetime)
        except (TypeError, scitokens.scitokens.UnsupportedKeyException):
            self.log.exception("Failure when attempting to serialize a SciToken, likely due to algorithm mismatch")
            return False

        oauth_response = {"access_token": serialized_token.decode(),
                          "expires_in":   lifetime}

        access_token_path = os.path.join(self.cred_dir, username, token_name + '.use')

        try:
            atomic_output_json(oauth_response, access_token_path)
        except OSError as oe:
            self.log.exception("Failure when writing out new access token to {}: {}.".format(
                access_token_path, str(oe)))
            return False
        return True
Ejemplo n.º 8
0
    def test_ec_public_key(self):
        """
        Deserialize an ES256 token against an explicitly supplied public key.

        The signer's P-256 public PEM must succeed, non-PEM bytes must raise
        ValueError, and the PEM of a different P-256 key must raise
        DecodeError.
        """
        def public_pem_of(private_key):
            # SubjectPublicKeyInfo PEM of the key's public half.
            return private_key.public_key().public_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PublicFormat.SubjectPublicKeyInfo)

        signing_key = ec.generate_private_key(ec.SECP256R1(), default_backend())
        signing_pem = public_pem_of(signing_key)

        signer = scitokens.SciToken(key=signing_key, algorithm="ES256")
        encoded = signer.serialize(issuer="local")

        restored = scitokens.SciToken.deserialize(encoded,
                                                  public_key=signing_pem,
                                                  insecure=True)
        self.assertIsInstance(restored, scitokens.SciToken)

        # Garbage key material.
        with self.assertRaises(ValueError):
            scitokens.SciToken.deserialize(encoded, insecure=True,
                                           public_key="asdf".encode())

        # A valid P-256 key that did not sign the token.
        stranger = ec.generate_private_key(ec.SECP256R1(), default_backend())
        stranger_pem = public_pem_of(stranger)
        with self.assertRaises(DecodeError):
            scitokens.SciToken.deserialize(encoded, insecure=True,
                                           public_key=stranger_pem)
Ejemplo n.º 9
0
    def test_serialize(self):
        """
        Exercise serialization and deserialization failure modes.

        Covers ``include_key`` and ``require_key`` (not implemented), a token
        without a signing key, a missing issuer, a successful round trip with
        an explicit public key, and input that is not a token at all.
        """
        token_cls = scitokens.SciToken

        with self.assertRaises(NotImplementedError):
            print(self._token.serialize(issuer="local", include_key=True))

        keyless = token_cls()
        with self.assertRaises(scitokens.utils.errors.MissingKeyException):
            print(keyless.serialize(issuer="local"))

        with self.assertRaises(scitokens.scitokens.MissingIssuerException):
            print(self._token.serialize())

        encoded = self._token.serialize(issuer="local")
        self.assertTrue(encoded)

        restored = token_cls.deserialize(encoded,
                                         public_key=self._public_pem,
                                         insecure=True)
        self.assertTrue(isinstance(restored, token_cls))

        with self.assertRaises(NotImplementedError):
            print(token_cls.deserialize(encoded,
                                        require_key=True,
                                        insecure=True))

        with self.assertRaises(scitokens.scitokens.InvalidTokenFormat):
            print(token_cls.deserialize("asdf1234"))
Ejemplo n.º 10
0
    def generate_scitoken(self,
                          parent_token=None,
                          refresh_token=None,
                          claims=None,
                          verify=False,
                          timeout=None):
        '''
        Validate a refresh token remotely and, if accepted, create a SciToken.

        The refresh token is POSTed to ``self.VALIDATE_REFTOKEN_URL``. When the
        JSON reply's ``result`` field is truthy, a new SciToken signed with
        ``self._private_key`` is returned.

        TODO: The Python implementation does not support parent tokens
        (i.e. hierarchies) yet.

        :param parent_token: Parent token for the new SciToken.
        :param refresh_token: Refresh token to validate before issuing.
        :param claims: Optional dict of claims merged into the new token via
            ``update_claims``.
        :param verify: Forwarded to ``requests.post``. The default (False)
            keeps the historical behaviour of NOT verifying the server's TLS
            certificate; pass True or a CA-bundle path to enable verification.
        :param timeout: Forwarded to ``requests.post``; None (the default)
            imposes no timeout.
        :return: the new SciToken, or None if the refresh token was rejected.
        '''
        payload = {'refresh_token': refresh_token}
        # NOTE(security): with verify=False the refresh token is sent to a
        # server whose certificate is not checked; callers should opt in to
        # verification.
        response = requests.post(self.VALIDATE_REFTOKEN_URL,
                                 data=payload,
                                 verify=verify,
                                 timeout=timeout)
        if not response.json()['result']:
            return None

        # NOTE: scitokens.SciToken() without a key does not generate one, and
        # serialization then fails with MissingKeyException, so the
        # instance's private key is supplied explicitly.
        token = scitokens.SciToken(key=self._private_key,
                                   parent=parent_token)
        if claims is not None:
            token.update_claims(claims)
        return token
Ejemplo n.º 11
0
 def setUp(self):
     """
     Build a SciToken with no signing key, carrying fixed sample claims
     that expire 600 seconds after creation.
     """
     issued = time.time()
     self._token = scitokens.SciToken()
     for name, value in (("foo", "bar"),
                         ("iat", int(issued)),
                         ("exp", int(issued + 600)),
                         ("iss", "https://scitokens.org/unittest"),
                         ("nbf", int(issued))):
         self._token[name] = value
Ejemplo n.º 12
0
    def test_create(self):
        """
        Sign a one-claim token with the fixture key and check that the
        serialized form has three dot-separated JWT segments.
        """
        signer = scitokens.SciToken(key=self._private_key)
        signer.update_claims({"test": "true"})
        encoded = signer.serialize(issuer="local")

        segments = encoded.decode('utf8').split(".")
        self.assertEqual(len(segments), 3)
        print(encoded)
Ejemplo n.º 13
0
    def refresh_access_token(self, username, token_name):
        """
        Mint a SciToken for *username* and atomically install it on disk.

        The token is signed with ``self._private_key`` using ES256 and key id
        ``self._private_key_id``. It carries ``sub``, ``scope`` (rendered from
        ``self.authz_template``) and ``ver`` claims. When
        ``self.token_use_json`` is true, the file holds a JSON document with
        ``access_token`` and ``expires_in``; otherwise it holds the bare token
        followed by a newline. The content is first written to a temporary
        file in ``self.cred_dir``. It is then moved with ``atomic_rename`` to
        ``<cred_dir>/<username>/<token_name>.use``.

        :param username: user the token is issued for.
        :param token_name: base name of the output file (``.use`` is appended).
        :return: True on success; False if the token could not be built or
            signed, or could not be moved into place (the failure is logged).
            The temporary file is removed whenever installation fails.
        """
        user_authz = self.authz_template.format(username=username)

        # A key that does not match ES256 is rejected when the SciToken is
        # constructed (UnsupportedKeyException), not only when it is
        # serialized, so construction must be inside the guarded region too.
        try:
            lifetime = int(self.token_lifetime)
            token = scitokens.SciToken(algorithm="ES256",
                                       key=self._private_key,
                                       key_id=self._private_key_id)
            token.update_claims({'sub': username,
                                 'scope': user_authz,
                                 'ver': 'scitokens:2.0'})
            serialized_token = token.serialize(issuer=self.token_issuer,
                                               lifetime=lifetime)
        except (TypeError, scitokens.scitokens.UnsupportedKeyException):
            self.log.exception(
                "Failure when attempting to serialize a SciToken, likely due to algorithm mismatch"
            )
            return False

        if self.token_use_json:
            # use JSON if configured to do so, i.e. when
            # LOCAL_CREDMON_TOKEN_USE_JSON = True (default)
            contents = json.dumps({
                "access_token": serialized_token.decode(),
                "expires_in": lifetime,
            })
        else:
            # otherwise write a bare token string when
            # LOCAL_CREDMON_TOKEN_USE_JSON = False
            contents = serialized_token.decode() + '\n'

        # copied from the Vault credmon
        (tmp_fd, tmp_access_token_path) = tempfile.mkstemp(dir=self.cred_dir)

        def discard_tmp_file():
            # Best effort: the file may already have been moved or removed.
            try:
                os.unlink(tmp_access_token_path)
            except OSError:
                pass

        try:
            with os.fdopen(tmp_fd, 'w') as f:
                f.write(contents)
        except Exception:
            # Do not leave a partially written temp file behind in cred_dir.
            discard_tmp_file()
            raise

        access_token_path = os.path.join(self.cred_dir, username,
                                         token_name + '.use')

        # atomically move new file into place
        try:
            atomic_rename(tmp_access_token_path, access_token_path)
        except OSError as e:
            # Log first so the traceback refers to the rename failure.
            self.log.exception(
                "Failure when writing out new access token to {}: {}.".format(
                    access_token_path, str(e)))
            discard_tmp_file()
            return False
        return True
Ejemplo n.º 14
0
    def test_ec_create(self):
        """
        Sign a one-claim token with a fresh SECP256R1 key using ES256 and
        check that the serialized form has three dot-separated segments.
        """
        signing_key = ec.generate_private_key(ec.SECP256R1(), default_backend())

        signer = scitokens.SciToken(key=signing_key, algorithm="ES256")
        self.assertIsInstance(signing_key, ec.EllipticCurvePrivateKey)
        signer.update_claims({"test": "true"})
        encoded = signer.serialize(issuer="local")

        segments = encoded.decode('utf8').split(".")
        self.assertEqual(len(segments), 3)
        print(encoded)
Ejemplo n.º 15
0
    def test_valid(self):
        """A claim checked by an accept-everything validator must validate."""

        def always_accept(value):
            # Every claim value is acceptable.
            return True

        checker = scitokens.Validator()
        checker.add_validator("foo", always_accept)

        sample = scitokens.SciToken()
        sample["foo"] = "bar"

        # Both the explicit method and the callable form must accept it.
        for run_check in (checker.validate, checker):
            self.assertTrue(run_check(sample))
Ejemplo n.º 16
0
 def setUp(self):
     """
     Create an RSA-signed sample token for the enforcer tests.

     The token has issuer https://scitokens.org/unittest, a ``foo`` claim,
     and iat/nbf set to now with exp 600 seconds later.
     """
     now = time.time()
     rsa_module = cryptography.hazmat.primitives.asymmetric.rsa
     backend = cryptography.hazmat.backends.default_backend()
     signing_key = rsa_module.generate_private_key(public_exponent=65537,
                                                   key_size=2048,
                                                   backend=backend)
     token = scitokens.SciToken(key=signing_key)
     self._token = token
     token["foo"] = "bar"
     token["iat"] = int(now)
     token["exp"] = int(now + 600)
     token["iss"] = "https://scitokens.org/unittest"
     token["nbf"] = int(now)
Ejemplo n.º 17
0
def Issue():
    """
    Issue a SciToken signed with the demo server's key.

    A POST body must be JSON of the form
    ``{"payload": "<JSON object encoded as a string>", "algorithm": "..."}``,
    where ``algorithm`` is "RS256" or "ES256". Every entry of the decoded
    payload becomes a claim. Any other request method issues a claim-less
    RS256 token.

    The signing key is read from ``private.pem`` / ``ec_private.pem`` when
    the file exists. Otherwise it comes from the base64-encoded
    ``PRIVATE_KEY`` / ``EC_PRIVATE_KEY`` environment variable.

    :return: the serialized token (bytes), or ``("", 400)`` for a malformed
        request or unsupported algorithm, or ``("", 500)`` when no signing
        key is configured for the algorithm.
    """
    # algorithm -> (PEM file on disk, base64 env-var fallback, key id)
    key_sources = {
        "RS256": ("private.pem", "PRIVATE_KEY", "key-rs256"),
        "ES256": ("ec_private.pem", "EC_PRIVATE_KEY", "key-es256"),
    }

    algorithm = "RS256"
    payload = {}

    if request.method == 'POST':
        try:
            data_dict = json.loads(request.data)
            payload = json.loads(data_dict['payload'])
            algorithm = data_dict['algorithm']
        except (ValueError, KeyError, TypeError):
            # Invalid JSON (JSONDecodeError is a ValueError), missing fields,
            # or a body that is not a JSON object are all client errors.
            return "", 400

    # Previously an unknown algorithm left key_id unbound (HTTP 500).
    if not isinstance(algorithm, str) or algorithm not in key_sources:
        return "", 400
    if not isinstance(payload, dict):
        return "", 400

    pem_path, env_var, key_id = key_sources[algorithm]
    if os.path.exists(pem_path):
        # load_pem_private_key needs bytes, so read the file in binary mode.
        with open(pem_path, 'rb') as pem_file:
            private_key_bytes = pem_file.read()
    elif env_var in os.environ:
        private_key_bytes = base64.b64decode(os.environ[env_var])
    else:
        # Server misconfiguration: no signing key for this algorithm.
        return "", 500

    private_key = serialization.load_pem_private_key(private_key_bytes,
                                                     password=None,
                                                     backend=default_backend())

    token = scitokens.SciToken(key=private_key,
                               algorithm=algorithm,
                               key_id=key_id)
    token.update_claims(payload)

    serialized_token = token.serialize(issuer="https://demo.scitokens.org")
    return serialized_token
Ejemplo n.º 18
0
    def test_aud(self):
        """
        A token with an ``aud`` claim must raise InvalidAudienceError unless
        deserialize() is given the matching ``audience``.
        """
        token = scitokens.SciToken(key=self._private_key)
        token.update_claims({'aud': 'local'})
        encoded = token.serialize(issuer='local')

        shared_kwargs = {'public_key': self._public_pem, 'insecure': True}

        with self.assertRaises(InvalidAudienceError):
            scitokens.SciToken.deserialize(encoded, **shared_kwargs)

        restored = scitokens.SciToken.deserialize(encoded,
                                                  audience='local',
                                                  **shared_kwargs)
        self.assertIsInstance(restored, scitokens.SciToken)
Ejemplo n.º 19
0
    def test_valid(self):
        """
        Unit test for the Validator: a claim guarded by an accept-everything
        check must pass both ``validate()`` and the callable interface.
        """
        def always_accept(value):
            """
            Accept any claim value.
            """
            return True

        claim_validator = scitokens.Validator()
        claim_validator.add_validator("foo", always_accept)

        subject = scitokens.SciToken()
        subject["foo"] = "bar"

        validated = claim_validator.validate(subject)
        self.assertTrue(validated)
        called = claim_validator(subject)
        self.assertTrue(called)
Ejemplo n.º 20
0
    def test_create(self):
        """
        Generate an RSA key, print its public PEM, sign a one-claim token with
        it, and check that the serialized form has three dot-separated
        segments.
        """
        signing_key = generate_private_key(public_exponent=65537,
                                           key_size=2048,
                                           backend=default_backend())
        public_pem = signing_key.public_key().public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo)
        print(public_pem)

        signer = scitokens.SciToken(key=signing_key)
        signer.update_claims({"test": "true"})
        encoded = signer.serialize(issuer="local")

        segments = encoded.decode('utf8').split(".")
        self.assertEqual(len(segments), 3)
        print(encoded)
Ejemplo n.º 21
0
def main():
    """
    Build and print a signed SciToken from command-line arguments.

    Loads an unencrypted PEM private key from ``args.keyfile``. Adds each
    ``name=value`` entry of ``args.claims`` as a claim (split on the first
    ``=``), then prints the token serialized for ``args.issuer`` with
    ``args.lifetime``.
    """
    args = add_args()

    with open(args.keyfile, "r") as key_file:
        pem_text = key_file.read()

    # No passphrase: the sample key is stored unencrypted on disk.
    signing_key = serialization.load_pem_private_key(pem_text.encode(),
                                                     password=None,
                                                     backend=default_backend())

    token = scitokens.SciToken(key=signing_key, key_id=args.key_id)
    for raw_claim in args.claims:
        claim_name, claim_value = raw_claim.split('=', 1)
        token.update_claims({claim_name: claim_value})

    encoded = token.serialize(issuer=args.issuer, lifetime=args.lifetime)
    print(encoded.decode())
Ejemplo n.º 22
0
def main():

    parser = argparse.ArgumentParser(
        description='Create token and test endpoint.')
    parser.add_argument('--aud', dest='aud', help="Insert an audience")
    parser.add_argument('pubjwk',
                        metavar='p',
                        type=str,
                        help='The jwks public key')
    args = parser.parse_args()

    private_key = None
    with open('private.pem', 'rb') as key_file:
        private_key = serialization.load_pem_private_key(
            key_file.read(), password=None, backend=default_backend())

    # Read in the public key to get the kid
    jwk_pub = ""
    with open(args.pubjwk, 'r') as jwk_pub_file:
        jwk_pub = json.load(jwk_pub_file)

    key_id = jwk_pub['keys'][0]['kid']

    token = scitokens.SciToken(key=private_key, key_id=key_id)
    token["scope"] = "read:/"

    if 'aud' in args and args.aud is not None:
        token["aud"] = args.aud

    token_str = token.serialize(issuer="https://localhost")
    headers = {"Authorization": "Bearer {0}".format(token_str)}
    #print token_str
    request = urllib2.Request("http://localhost:8080/tmp/random.txt",
                              headers=headers)
    contents = urllib2.urlopen(request).read()
    print contents,
Ejemplo n.º 23
0
def token_issuer():
    """
    OAuth2 token endpoint: issue a SciToken to a certificate-authenticated client.

    Only the ``client_credentials`` grant type is accepted. Client
    credentials are read from ``GRST_CRED_AURI_<n>`` variables in the WSGI
    environment, ordered by ``n``. A non-empty ``dn:`` entry is required;
    the first one is used as the subject when ``generate_scopes_and_user``
    returns no user. If the form carries a space-separated ``scopes`` field,
    the issued scopes are narrowed with ``limit_scope``. When the result
    differs from the request, the granted scopes are echoed back in the
    response's ``scope`` field.

    Note: written for Python 2 (``urllib.unquote_plus``, ``urlparse``).

    :return: a JSON bearer-token response, or the result of
        ``return_oauth_error_response`` for a bad request.
    """
    # Currently, we only support the client_credentials grant type.
    if request.form.get("grant_type") != "client_credentials":
        return return_oauth_error_response(
            "Incorrect grant_type; 'client_credentials' must be used.")
    # str.split() with no separator never yields empty strings.
    requested_scopes = set(request.form.get("scopes", "").split())

    cred_prefix = "GRST_CRED_AURI_"
    creds = {}
    for key, val in request.environ.items():
        if key.startswith(cred_prefix):
            creds[int(key[len(cred_prefix):])] = val

    # sorted() replaces keys().sort(), which only works on Python 2 lists.
    entries = [creds[key] for key in sorted(creds)]
    dn_cred = None
    for entry in entries:
        if not dn_cred and entry.startswith("dn:"):
            dn_cred = entry[3:]

    if not dn_cred:
        return return_oauth_error_response(
            "No client certificate or proxy used for TLS authentication.")
    dn_cred = urllib.unquote_plus(dn_cred)

    scopes, user = generate_scopes_and_user(entries)

    # Narrow the generated scopes to the requested ones (if any). If the
    # client does not receive exactly what it asked for, the granted scopes
    # must be returned in the "scope" field (RFC 6749, section 5.1).
    return_updated_scopes = False
    if requested_scopes:
        updated_scopes = set()
        for issued_scope in scopes:
            for requested_scope in requested_scopes:
                new_scope = limit_scope(issued_scope, requested_scope)
                if new_scope:
                    updated_scopes.add(new_scope)
        scopes = updated_scopes
        return_updated_scopes = requested_scopes != updated_scopes

    token = scitokens.SciToken(key=app.issuer_key, key_id=app.issuer_kid)
    token['scp'] = list(scopes)
    token['sub'] = user if user else dn_cred
    if 'ISSUER' in app.config:
        issuer = app.config['ISSUER']
    else:
        # Default issuer: https://<Host header><request URI>.
        split = urlparse.SplitResult(scheme="https",
                                     netloc=request.environ['HTTP_HOST'],
                                     path=request.environ['REQUEST_URI'],
                                     query="",
                                     fragment="")
        issuer = urlparse.urlunsplit(split)
    serialized_token = token.serialize(issuer=issuer,
                                       lifetime=app.config['LIFETIME'])

    json_response = {
        "access_token": serialized_token,
        "token_type": "bearer",
        "expires_in": app.config['LIFETIME'],
    }
    if return_updated_scopes:
        json_response["scope"] = " ".join(scopes)
    resp = app.response_class(response=json.dumps(json_response),
                              mimetype='application/json',
                              status=requests.codes.ok)
    # Token responses must not be cached (RFC 6749, section 5.1).
    resp.headers['Cache-Control'] = 'no-store'
    resp.headers['Pragma'] = 'no-cache'
    return resp
Ejemplo n.º 24
0
def token_issuer():
    """
    Issue a SciToken to a client authenticated by the web server.

    Only the OAuth2 ``client_credentials`` grant is supported.  The client's
    credentials are read from the request environment: ``GRST_CRED_AURI_<n>``
    entries by default, or ``HTTP_CMS_AUTH*`` headers when the ``CMS`` config
    option is set (mapped to ``dn:``, ``username:`` and ``fqan:/`` entries).
    The credentials are turned into scopes and a user by
    ``generate_scopes_and_user``.  If the client sent a ``scopes`` form field,
    each issued scope is narrowed against it with ``limit_scope``.

    The token is signed with ``app.issuer_key`` (ES256 for an elliptic-curve
    key, RS256 otherwise).  Its ``sub`` is the mapped user, falling back to
    the client DN.  The issuer is ``app.config['ISSUER']`` or, if unset, the
    ``https`` URL of the current request.

    Returns a JSON response with ``access_token``, ``token_type`` and
    ``expires_in``, plus ``scope`` when the granted scopes differ from the
    requested ones.  On failure it returns ``return_oauth_error_response`` /
    ``return_internal_error_response``.
    """
    # Currently, we only support the client_credentials grant type.
    grant_type = request.form.get("grant_type")
    if grant_type != "client_credentials":
        return return_oauth_error_response(
            "Incorrect grant_type %s; 'client_credentials' must be used." %
            grant_type)
    # str.split() with no separator never yields empty strings.
    requested_scopes = set(request.form.get("scopes", "").split())

    verbose = app.config.get('VERBOSE', False)
    pattern = "GRST_CRED_AURI_"
    if app.config.get("CMS", False):
        pattern = "HTTP_CMS_AUTH"

    # Collect credentials keyed by their position in the credential chain.
    creds = {}
    entry_num = 0
    for key, val in request.environ.items():
        if verbose:
            print("### request {} {}".format(key, val))
        if key.startswith("GRST_CRED_AURI_"):
            # The numeric suffix is the entry's index in the chain.
            entry_num = int(key[len("GRST_CRED_AURI_"):])
        if key.startswith(pattern):
            if pattern == "HTTP_CMS_AUTH":
                # Translate CMS auth headers into the same "type:value"
                # form that GRST_CRED_AURI_* entries use.
                if key.endswith("_DN"):
                    val = "dn:" + val
                elif key.endswith("_LOGIN"):
                    val = "username:" + val
                elif key.startswith("HTTP_CMS_AUTHZ"):
                    val = "fqan:/{}".format(val.split(':')[-1])
                else:
                    continue
            creds[entry_num] = val
            entry_num += 1

    # Order the entries by chain position; the first non-empty "dn:" entry
    # identifies the client.
    entries = [creds[num] for num in sorted(creds)]
    dn_cred = None
    for entry in entries:
        if not dn_cred and entry.startswith("dn:"):
            dn_cred = entry[3:]

    if not dn_cred:
        return return_oauth_error_response(
            "No client certificate or proxy used for TLS authentication.")
    dn_cred = urllib.unquote_plus(dn_cred)

    scopes, user = generate_scopes_and_user(entries)
    if verbose:
        print("### creds  : {}".format(creds))
        print("### entries: {}".format(entries))
        print("### scopes : {}".format(scopes))
        print("### user   : {}".format(user))

    # Compare the generated scopes against the requested scopes (if given).
    # If we don't give the user everything they asked for, the granted
    # scopes are echoed back in the response's "scope" field.
    return_updated_scopes = False
    if requested_scopes:
        updated_scopes = set()
        for issued_scope in scopes:
            for requested_scope in requested_scopes:
                new_scope = limit_scope(issued_scope, requested_scope)
                if new_scope:
                    updated_scopes.add(new_scope)
        scopes = updated_scopes
        if requested_scopes != updated_scopes:
            return_updated_scopes = True

    # Refuse to issue a token that grants nothing.
    if not scopes:
        return return_oauth_error_response(
            "No applicable scopes for this user.")

    if isinstance(app.issuer_key, ec.EllipticCurvePrivateKey):
        algorithm = "ES256"
    else:
        algorithm = "RS256"

    token = scitokens.SciToken(key=app.issuer_key,
                               key_id=app.issuer_kid,
                               algorithm=algorithm)
    token['scope'] = ' '.join(scopes)
    token['sub'] = user if user else dn_cred
    if 'ISSUER' in app.config:
        issuer = app.config['ISSUER']
    else:
        split = urlparse.SplitResult(scheme="https",
                                     netloc=request.environ['HTTP_HOST'],
                                     path=request.environ['REQUEST_URI'],
                                     query="",
                                     fragment="")
        issuer = urlparse.urlunsplit(split)

    try:
        serialized_token = token.serialize(issuer=issuer,
                                           lifetime=app.config['LIFETIME'])
    except Exception as ex:
        return return_internal_error_response(
            "Failure when serializing token: {}".format(ex))
    if isinstance(serialized_token, bytes):
        # serialize() may return bytes, which json.dumps cannot encode on
        # Python 3; the token itself is ASCII.
        serialized_token = serialized_token.decode('utf-8')

    json_response = {
        "access_token": serialized_token,
        "token_type": "bearer",
        "expires_in": app.config['LIFETIME'],
    }
    if return_updated_scopes:
        json_response["scope"] = " ".join(scopes)
    resp = app.response_class(response=json.dumps(json_response),
                              mimetype='application/json',
                              status=requests.codes.ok)
    # Token responses must not be cached (RFC 6749 section 5.1).
    resp.headers['Cache-Control'] = 'no-store'
    resp.headers['Pragma'] = 'no-cache'
    return resp
Ejemplo n.º 25
0
    def test_discover(self):
        """
        Test WLCG bearer token discovery.

        A token is placed in each discovery location in turn: /tmp/bt_u$ID,
        $XDG_RUNTIME_DIR/bt_u$ID, $BEARER_TOKEN_FILE, then $BEARER_TOKEN.
        After each step, ``discover()`` must return the newly placed token
        rather than any placed earlier.

        The environment variables, any pre-existing /tmp/bt_u$ID, and all
        temporary files are restored or removed via ``addCleanup``.  This
        happens even when an assertion fails.
        """
        def restore_env(name, value):
            # Put an environment variable back to its pre-test state.
            if value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = value

        def make_token(name):
            # Create a token whose key id and scope are both ``name``;
            # return it together with its serialized text.
            new_token = scitokens.SciToken(key=self._private_key,
                                           key_id=name)
            new_token['scope'] = name
            serialized = new_token.serialize(issuer="local")
            return new_token, serialized.decode('utf-8')

        # unset any wlcg discovery environment variables for the test,
        # restoring their original values afterwards
        for name in ('BEARER_TOKEN', 'BEARER_TOKEN_FILE', 'XDG_RUNTIME_DIR'):
            old_value = os.environ.pop(name, None)
            self.addCleanup(restore_env, name, old_value)

        # move any /tmp/bt_u$ID file out of the way
        try:
            bt_file = 'bt_u{}'.format(os.geteuid())
        except AttributeError as exc:  # windows doesn't have geteuid
            self.skipTest(str(exc))
        bt_path = os.path.join('/tmp', bt_file)
        if os.path.isfile(bt_path):
            (bt_fd, bt_tmp) = tempfile.mkstemp()
            os.close(bt_fd)
            shutil.move(bt_path, bt_tmp)
            # Cleanups run last-in-first-out, so the test's own bt_u file is
            # removed before the original is moved back.  Only restore when
            # a file actually existed, so no empty file is left behind.
            self.addCleanup(shutil.move, bt_tmp, bt_path)

        # check that the function fails properly
        with self.assertRaises(IOError):
            print(self._token.discover())

        # generate a token and save it as /tmp/bt_u$ID
        tmp_file_token, tmp_file_token_s = make_token("tmp_file")
        with open(bt_path, 'w') as f:
            f.write(tmp_file_token_s)
        self.addCleanup(os.remove, bt_path)

        # discover a token and check we found /tmp/bt_u$ID
        token = self._token.discover(public_key=self._public_pem)
        self.assertEqual(token._serialized_token,
                         tmp_file_token._serialized_token)

        # generate a token and save it as $XDG_RUNTIME_DIR/bt_u$ID
        xdg_file_token, xdg_file_token_s = make_token("xdg_file")
        xdg_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, xdg_dir)
        xdg_path = os.path.join(xdg_dir, bt_file)
        with open(xdg_path, 'w') as f:
            f.write(xdg_file_token_s)

        # set the wlcg discovery environment variable
        os.environ['XDG_RUNTIME_DIR'] = xdg_dir

        # discover a token and check we found $XDG_RUNTIME_DIR/bt_u$ID
        # and not /tmp/bt_u$ID
        token = self._token.discover(public_key=self._public_pem,
                                     insecure=True)
        self.assertNotEqual(token._serialized_token,
                            tmp_file_token._serialized_token)
        self.assertEqual(token._serialized_token,
                         xdg_file_token._serialized_token)

        # generate a token and save it in BEARER_TOKEN_FILE
        bearer_file_token, bearer_file_token_s = make_token("bearer_file")
        (fd, bearer_token_file) = tempfile.mkstemp()
        self.addCleanup(os.remove, bearer_token_file)
        with os.fdopen(fd, 'w') as f:
            f.write(bearer_file_token_s)

        # set the wlcg discovery environment variable
        os.environ['BEARER_TOKEN_FILE'] = bearer_token_file

        # discover a token and check we found BEARER_TOKEN_FILE
        # and not $XDG_RUNTIME_DIR/bt_u$ID or /tmp/bt_u$ID
        token = self._token.discover(public_key=self._public_pem,
                                     insecure=True)
        self.assertNotEqual(token._serialized_token,
                            tmp_file_token._serialized_token)
        self.assertNotEqual(token._serialized_token,
                            xdg_file_token._serialized_token)
        self.assertEqual(token._serialized_token,
                         bearer_file_token._serialized_token)

        # generate a token and put it directly in BEARER_TOKEN
        bearer_token, bearer_token_s = make_token("bearer")
        os.environ['BEARER_TOKEN'] = bearer_token_s

        # discover a token and check we found BEARER_TOKEN
        # and not BEARER_TOKEN_FILE, $XDG_RUNTIME_DIR/bt_u$ID or /tmp/bt_u$ID
        token = self._token.discover(public_key=self._public_pem,
                                     insecure=True)
        self.assertNotEqual(token._serialized_token,
                            tmp_file_token._serialized_token)
        self.assertNotEqual(token._serialized_token,
                            xdg_file_token._serialized_token)
        self.assertNotEqual(token._serialized_token,
                            bearer_file_token._serialized_token)
        self.assertEqual(token._serialized_token,
                         bearer_token._serialized_token)
Ejemplo n.º 26
0
 def test_unsupported_key(self):
     """
     Check that asking for an unknown signing algorithm is rejected.

     Building a SciToken with ``algorithm="doesnotexist"`` must raise
     UnsupportedKeyException.
     """
     self.assertRaises(UnsupportedKeyException,
                       scitokens.SciToken,
                       key=self._private_key,
                       algorithm="doesnotexist")
Ejemplo n.º 27
0
    def refresh_access_token(self, username, token_name):
        """
        Create a SciToken for ``username`` and write it into ``self.cred_dir``.

        The token is signed with ES256 using ``self._private_key``.  It
        carries ``sub`` (the username) and ``scope`` (from
        ``self.authz_template``).  It also carries ``ver`` when
        ``self.token_ver`` is set, and ``aud`` when ``self.token_aud`` lists
        any audiences.

        The serialized token is written to a temporary file in
        ``self.cred_dir``.  The format is JSON when ``self.token_use_json`` is
        true, otherwise a bare token line.  The file is then atomically
        renamed to ``<cred_dir>/<username>/<token_name>.use``.

        :param username: user the token is issued for (the ``sub`` claim).
        :param token_name: base name of the output file; ``.use`` is appended.
        :returns: True on success.  Returns False (after logging the reason)
            if no audience is configured for a scitokens:2.0 token, if
            serialization fails with TypeError, or if the final rename fails.
        """
        token = scitokens.SciToken(algorithm="ES256",
                                   key=self._private_key,
                                   key_id=self._private_key_id)
        token.update_claims({'sub': username})
        user_authz = self.authz_template.format(username=username)
        token.update_claims({'scope': user_authz})

        # Only set the version if we have one.  No version is valid, and
        # implies scitokens:1.0
        if self.token_ver:
            token.update_claims({'ver': self.token_ver})

        # Convert the space separated list of audiences to a proper list.
        # No aud is valid for scitokens:1.0 tokens.  Also, no reasonable
        # default.  An unset audience (None) is treated like an empty one.
        aud_list = (self.token_aud or "").split()
        if aud_list:
            token.update_claims({'aud': aud_list})
        elif self.token_ver and self.token_ver.lower() == "scitokens:2.0":
            # token_ver may be unset (see above), so guard before .lower().
            self.log.error(
                'No "aud" claim, LOCAL_CREDMON_TOKEN_AUDIENCE must be set when requesting a scitokens:2.0 token'
            )
            return False

        # Serialize the token and write it to a file
        try:
            serialized_token = token.serialize(issuer=self.token_issuer,
                                               lifetime=int(
                                                   self.token_lifetime))
        except TypeError:
            self.log.exception(
                "Failure when attempting to serialize a SciToken, likely due to algorithm mismatch"
            )
            return False
        token_text = serialized_token.decode()

        # copied from the Vault credmon
        (tmp_fd, tmp_access_token_path) = tempfile.mkstemp(dir=self.cred_dir)
        with os.fdopen(tmp_fd, 'w') as f:
            if self.token_use_json:
                # use JSON if configured to do so, i.e. when
                # LOCAL_CREDMON_TOKEN_USE_JSON = True (default)
                f.write(
                    json.dumps({
                        "access_token": token_text,
                        "expires_in": int(self.token_lifetime),
                    }))
            else:
                # otherwise write a bare token string when
                # LOCAL_CREDMON_TOKEN_USE_JSON = False
                f.write(token_text + '\n')

        access_token_path = os.path.join(self.cred_dir, username,
                                         token_name + '.use')

        # atomically move new file into place
        try:
            atomic_rename(tmp_access_token_path, access_token_path)
        except OSError as e:
            self.log.exception(
                "Failure when writing out new access token to {}: {}.".format(
                    access_token_path, str(e)))
            # Don't leave the orphaned temporary file behind in cred_dir.
            try:
                os.remove(tmp_access_token_path)
            except OSError:
                pass  # best effort; the rename failure is already logged
            return False
        return True
Ejemplo n.º 28
0
import scitokens
import sys

# Construct a SciToken from the token string passed as the first
# command-line argument.
# NOTE(review): confirm that SciToken accepts a ``token=`` keyword in the
# scitokens version in use; elsewhere a serialized token is loaded with
# ``scitokens.SciToken.deserialize``.
if len(sys.argv) < 2:
    sys.exit("usage: {} <token>".format(sys.argv[0]))
scitokens.SciToken(token=sys.argv[1])