def validate(self, message):
    """Verify the signature of this RSA fulfillment.

    The signature is checked against ``message`` using the condition's public
    modulus together with ``PUBLIC_EXPONENT``, RSASSA-PSS padding (MGF1 with
    SHA-256, ``SALT_LENGTH`` salt) and SHA-256 hashing.

    Args:
        message (bytes): Message to verify.

    Returns:
        bool: ``True`` when the signature is valid.

    Raises:
        TypeError: If ``message`` is not ``bytes``.
        ValidationError: If the signature does not verify.
    """
    if not isinstance(message, bytes):
        # Format with an f-string: ``'...' + message`` would itself raise an
        # unrelated TypeError for non-str values and hide the real problem.
        raise TypeError(f'Message must be provided as bytes, was: {message}')
    public_numbers = RSAPublicNumbers(
        PUBLIC_EXPONENT,
        int.from_bytes(self.modulus, byteorder='big'),
    )
    public_key = public_numbers.public_key(default_backend())
    try:
        public_key.verify(
            self.signature,
            message,
            padding.PSS(
                mgf=padding.MGF1(hashes.SHA256()),
                salt_length=SALT_LENGTH,
            ),
            hashes.SHA256(),
        )
    except InvalidSignature as exc:
        raise ValidationError('Invalid RSA signature') from exc
    return True
def test_response_values(app, client):
    """
    Inspect each JWK returned by the JWKS endpoint in detail.

    fence signs and validates JWTs with RSA only, so ``alg``, ``kty``, ``use``
    and ``key_ops`` are expected to carry fixed values. Every JWK must also
    provide the RSA public modulus ``n`` and exponent ``e``; they are used to
    rebuild the public key, whose PEM must belong to one of the app's keypairs.
    """
    served_jwks = client.get('/.well-known/jwks').json['keys']
    expected_kids = [pair.kid for pair in app.keypairs]
    expected_pems = [pair.public_key for pair in app.keypairs]
    fixed_fields = (
        ('alg', 'RS256'),
        ('kty', 'RSA'),
        ('use', 'sig'),
        ('key_ops', 'verify'),
    )

    for jwk in served_jwks:
        for field, wanted in fixed_fields:
            assert jwk[field] == wanted
        assert jwk['kid'] in expected_kids

        # Rebuild the public key from n and e with cryptography primitives.
        # The decoded bytes are parsed as a base-10 integer literal
        # (presumably how this endpoint encodes the numbers).
        modulus = int(base64.b64decode(jwk['n']))
        exponent = int(base64.b64decode(jwk['e']))
        rebuilt = RSAPublicNumbers(exponent, modulus).public_key(default_backend())
        rebuilt_pem = rebuilt.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo)
        assert rebuilt_pem in expected_pems
def from_jwk(jwk):
    """Load an RSA key from a JSON Web Key string.

    Args:
        jwk: JSON text of one JWK whose ``kty`` is ``"RSA"``.

    Returns:
        A private key when ``d``, ``e`` and ``n`` are present, otherwise a
        public key when ``e`` and ``n`` are present.

    Raises:
        InvalidKeyError: If ``jwk`` is not valid JSON, is not an RSA JWK
            object, declares more than two primes (``oth``), supplies only
            some of ``p``/``q``/``dp``/``dq``/``qi``, or lacks ``e``/``n``.
    """
    try:
        obj = json.loads(jwk)
    except ValueError as exc:
        raise InvalidKeyError('Key is not valid JSON') from exc

    # Look up the ``kty`` member. Comparing the bound method ``obj.get``
    # itself to 'RSA' is always unequal and would reject every key. A JSON
    # value that is not an object cannot be an RSA JWK either.
    if not isinstance(obj, dict) or obj.get('kty') != 'RSA':
        raise InvalidKeyError('Not an RSA key')

    if 'd' in obj and 'e' in obj and 'n' in obj:
        # Private key
        if 'oth' in obj:
            raise InvalidKeyError(
                'Unsupported RSA private key: > 2 primes not supported'
            )

        other_props = ['p', 'q', 'dp', 'dq', 'qi']
        props_found = [prop in obj for prop in other_props]
        any_props_found = any(props_found)

        # The CRT members must be supplied all together or not at all.
        if any_props_found and not all(props_found):
            raise InvalidKeyError(
                'RSA key must include all parameters if any are present besides d'
            )

        public_numbers = RSAPublicNumbers(
            from_base64url_uint(obj['e']),
            from_base64url_uint(obj['n']),
        )

        if any_props_found:
            numbers = RSAPrivateNumbers(
                d=from_base64url_uint(obj['d']),
                p=from_base64url_uint(obj['p']),
                q=from_base64url_uint(obj['q']),
                dmp1=from_base64url_uint(obj['dp']),
                dmq1=from_base64url_uint(obj['dq']),
                iqmp=from_base64url_uint(obj['qi']),
                public_numbers=public_numbers,
            )
        else:
            # Only d is known: recover the primes, then derive CRT values.
            d = from_base64url_uint(obj['d'])
            # Documented argument order is (n, e, d).
            p, q = rsa_recover_prime_factors(
                public_numbers.n, public_numbers.e, d)

            numbers = RSAPrivateNumbers(
                d=d,
                p=p,
                q=q,
                dmp1=rsa_crt_dmp1(d, p),
                dmq1=rsa_crt_dmq1(d, q),
                iqmp=rsa_crt_iqmp(p, q),
                public_numbers=public_numbers,
            )

        return numbers.private_key(default_backend())
    elif 'n' in obj and 'e' in obj:
        # Public key
        numbers = RSAPublicNumbers(
            from_base64url_uint(obj['e']),
            from_base64url_uint(obj['n']),
        )

        return numbers.public_key(default_backend())
    else:
        raise InvalidKeyError('Not a public or private key')
def _convert(exponent: int, modulus: int) -> bytearray:
    """Return the SubjectPublicKeyInfo PEM of the RSA key ``(exponent, modulus)``.

    NOTE(review): ``public_bytes`` returns ``bytes`` according to the
    cryptography docs; the ``cast`` only satisfies the ``bytearray``
    annotation and performs no conversion.
    """
    rsa_key = RSAPublicNumbers(exponent, modulus).public_key(
        backend=default_backend()
    )
    pem = rsa_key.public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    return cast(bytearray, pem)
def rsa_pem_from_jwk(jwk):
    """Build a PEM-encoded RSA public key from the ``n``/``e`` members of a JWK.

    Both members are decoded with ``decode_value``.
    """
    modulus = decode_value(jwk['n'])
    exponent = decode_value(jwk['e'])
    rsa_key = RSAPublicNumbers(n=modulus, e=exponent).public_key(default_backend())
    return rsa_key.public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
def __init__(self, json):
    """Initialise from ``json`` and derive ``self.public_key`` as a PEM string.

    ``self.n`` and ``self.e`` are read after ``super().__init__(json)`` and
    decoded from base64url into integers.
    """
    super().__init__(json)

    def _b64url_to_int(value):
        # JWK members are normally base64url *without* '=' padding, which
        # urlsafe_b64decode rejects; restore it (no-op if already padded).
        pad = '=' * (-len(value) % 4)
        if isinstance(value, bytes):
            pad = pad.encode('ascii')
        return bytes_to_int(urlsafe_b64decode(value + pad))

    decoded_n = _b64url_to_int(self.n)
    decoded_e = _b64url_to_int(self.e)
    rsa_public_key = RSAPublicNumbers(decoded_e, decoded_n).public_key(
        default_backend())
    self.public_key = rsa_public_key.public_bytes(
        Encoding.PEM, PublicFormat.SubjectPublicKeyInfo).decode()
def pubkey_from_tpm2b_public(public: bytes) -> pubkey_type:
    """Parse a TPM2B_PUBLIC structure and return the public key it contains.

    Only RSA (``TPM_ALG_RSA``) and ECC (``TPM_ALG_ECC``) objects are handled.

    Args:
        public: Serialized TPM2B_PUBLIC, i.e. a size-prefixed TPMT_PUBLIC
            that is unwrapped with ``_extract_tpm2b``.

    Returns:
        An RSA public key built from the exponent and modulus, or an
        elliptic-curve public key built from the X/Y point and curve ID.

    Raises:
        ValueError: If bytes remain after the TPMT_PUBLIC or after the ECC
            point, if the modulus/point length disagrees with the declared
            key size, or if the algorithm type is neither RSA nor ECC.
    """
    (public, rest) = _extract_tpm2b(public)
    if len(rest) != 0:
        raise ValueError("More in tpm2b_public than tpmt_public")
    # Extract type, [nameAlg], and [objectAttributes] (we don't care about the
    # latter two)
    (alg_type, _, _) = struct.unpack(">HHI", public[0:8])
    # Ignore the authPolicy
    (_, sym_parms) = _extract_tpm2b(public[8:])
    # Ignore the non-asym-alg parameters
    (sym_mode, ) = struct.unpack(">H", sym_parms[0:2])
    # Ignore the sym_mode and keybits (4 bytes), possibly symmetric (2) and sign
    # scheme (2)
    # NOTE(review): these fixed skip widths assume a particular layout of the
    # symmetric-definition and scheme fields; confirm against TPM 2.0 Part 2
    # for the key templates actually in use.
    to_skip = 4 + 2  # sym_mode, keybits and sign scheme
    if sym_mode != TPM2_ALG_NULL:
        to_skip = to_skip + 2
    asym_parms = sym_parms[to_skip:]

    # Handle fields
    if alg_type == TPM_ALG_RSA:
        # keyBits (u16) followed by the public exponent (u32).
        (keybits, exponent) = struct.unpack(">HI", asym_parms[0:6])
        if exponent == 0:
            # A zero exponent is replaced by 65537 (presumably the TPM's
            # "default exponent" encoding -- confirm against the spec).
            exponent = 65537
        (modulus, _) = _extract_tpm2b(asym_parms[6:])
        if (len(modulus) * 8) != keybits:
            raise ValueError(
                f"Misparsed either modulus or keybits: {len(modulus)}*8 != {keybits}"
            )
        bmodulus = int.from_bytes(modulus, byteorder="big")
        numbers = RSAPublicNumbers(exponent, bmodulus)
        return numbers.public_key(backend=default_backend())
    if alg_type == TPM_ALG_ECC:
        # Curve ID (u16); the following u16 is read but ignored.
        (curve, _) = struct.unpack(">HH", asym_parms[0:4])
        asym_x = asym_parms[4:]
        curve = _curve_from_curve_id(curve)
        # X and Y are consecutive size-prefixed buffers.
        (x, asym_y) = _extract_tpm2b(asym_x)
        (y, rest) = _extract_tpm2b(asym_y)
        if len(rest) != 0:
            raise ValueError("Misparsed: more contents after X and Y")
        if (len(x) * 8) != curve.key_size:
            raise ValueError(
                f"Misparsed either X or curve: {len(x)}*8 != {curve.key_size}")
        if (len(y) * 8) != curve.key_size:
            raise ValueError(
                f"Misparsed either Y or curve curve: {len(y)}*8 != {curve.key_size}"
            )
        bx = int.from_bytes(x, byteorder="big")
        by = int.from_bytes(y, byteorder="big")
        numbers = EllipticCurvePublicNumbers(bx, by, curve)
        return numbers.public_key(backend=default_backend())
    raise ValueError(f"Invalid tpm2b_public type: {alg_type}")
def get_PEM_from_RSA(self, modulus, exponent):
    """Return a PEM SubjectPublicKeyInfo for an RSA key.

    ``modulus`` and ``exponent`` are converted to integers with
    ``self.base64_to_long``.
    """
    e_value = self.base64_to_long(exponent)
    n_value = self.base64_to_long(modulus)
    rsa_key = RSAPublicNumbers(e_value, n_value).public_key(
        backend=default_backend())
    return rsa_key.public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
def jwk_to_pem(jwk):
    """Convert a JWK's ``e``/``n`` members to a PEM-encoded RSA public key.

    Both members are decoded with ``base64_to_long``.
    """
    e_int = base64_to_long(jwk['e'])
    n_int = base64_to_long(jwk['n'])
    key = RSAPublicNumbers(e_int, n_int).public_key(backend=default_backend())
    return key.public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
def pemFromModExp(modulus, exponent):
    """Return the PEM public key for an encoded RSA modulus and exponent.

    Both arguments are decoded with ``base64_to_long``; the exponent is
    decoded first.
    """
    exp_value = base64_to_long(exponent)
    mod_value = base64_to_long(modulus)
    rsa_numbers = RSAPublicNumbers(exp_value, mod_value)
    return rsa_numbers.public_key(backend=default_backend()).public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
def __init__(self, json):
    """Store the JWK mapping, set an expiry, and derive ``self.public_key``.

    ``self.expires`` is ``JWK_EXPIRES_MIN`` minutes from now (naive local
    time). ``self.n``/``self.e`` are decoded from base64url and turned into
    a PEM SubjectPublicKeyInfo string.
    """
    self._json = json
    self.expires = datetime.now() + timedelta(minutes=JWK_EXPIRES_MIN)

    def _b64url_to_int(value):
        # JWK members are normally base64url *without* '=' padding, which
        # urlsafe_b64decode rejects; restore it (no-op if already padded).
        pad = '=' * (-len(value) % 4)
        if isinstance(value, bytes):
            pad = pad.encode('ascii')
        return bytes_to_int(urlsafe_b64decode(value + pad))

    decoded_n = _b64url_to_int(self.n)
    decoded_e = _b64url_to_int(self.e)
    rsa_public_key = RSAPublicNumbers(decoded_e, decoded_n).public_key(
        default_backend())
    self.public_key = rsa_public_key.public_bytes(
        Encoding.PEM, PublicFormat.SubjectPublicKeyInfo).decode()
def rsa_pem_from_jwk(jsonKeyResult):
    """Return the PEM public key for the first JWK of a JWKS document.

    Args:
        jsonKeyResult: Parsed JWKS mapping. Only ``jsonKeyResult['keys'][0]``
            is used; its ``e``/``n`` are decoded with ``base64_to_long``.

    Returns:
        bytes: PEM-encoded SubjectPublicKeyInfo.
    """
    first_key = jsonKeyResult['keys'][0]
    exponent = base64_to_long(first_key['e'])
    modulus = base64_to_long(first_key['n'])
    numbers = RSAPublicNumbers(exponent, modulus)
    public_key = numbers.public_key(backend=default_backend())
    pem = public_key.public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo)
    return pem
def verify(self):
    """
    Check this SSHCSR's signature in place.

    The base64-decoded ``self.signature`` is verified over ``self.json(False)``
    with PKCS#1 v1.5 padding and SHA-256, using the RSA key rebuilt from
    ``self.public_key.e`` and ``self.public_key.n``.

    Returns:
        bool: ``True`` if the signature verifies, ``False`` otherwise.
    """
    rsa_numbers = RSAPublicNumbers(self.public_key.e, self.public_key.n)
    try:
        rsa_numbers.public_key(default_backend()).verify(
            base64.b64decode(self.signature),
            self.json(False),
            padding.PKCS1v15(),
            hashes.SHA256(),
        )
    except InvalidSignature:
        return False
    return True
def from_jwk(obj):
    """Build an RSA public key from an already-parsed JWK mapping.

    Raises:
        InvalidKeyError: If ``obj['kty']`` is not ``'RSA'``.
    """
    if obj.get('kty') == 'RSA':
        # Public key
        exponent = from_base64url_uint(obj['e'])
        modulus = from_base64url_uint(obj['n'])
        return RSAPublicNumbers(exponent, modulus).public_key(default_backend())
    raise InvalidKeyError('Not an RSA key')
def jwk_to_pem(jwk_):
    """Return the PEM public key encoded by a JWK's ``e`` and ``n`` members.

    Adapted from https://github.com/jpf/okta-jwks-to-pem/blob/master/jwks_to_pem.py
    """
    numbers = RSAPublicNumbers(
        base64_to_long(jwk_['e']),
        base64_to_long(jwk_['n']),
    )
    rsa_key = numbers.public_key(backend=default_backend())
    return rsa_key.public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
def from_jwk(jwk):
    """Convert a ``JsonWebKey`` into an ``_RsaKey``.

    When ``p``, ``q`` and ``d`` are all present, a private key is built.
    ``dp``, ``dq`` and ``qi`` are taken from the JWK when given and derived
    from ``p``/``q``/``d`` otherwise. Without those three members only the
    public key (``n``, ``e``) is loaded.

    Args:
        jwk: The ``JsonWebKey`` to convert.

    Returns:
        _RsaKey: Wrapper whose ``_rsa_impl`` holds the cryptography key.

    Raises:
        TypeError: If ``jwk`` is not a ``JsonWebKey``.
        ValueError: If ``kty`` is neither ``RSA`` nor ``RSA-HSM``, or if
            ``n`` or ``e`` is empty.
    """
    if not isinstance(jwk, JsonWebKey):
        raise TypeError('The specified jwk must be a JsonWebKey')
    if jwk.kty not in ('RSA', 'RSA-HSM'):
        raise ValueError(
            'The specified jwk must have a key type of "RSA" or "RSA-HSM"')
    if not (jwk.n and jwk.e):
        raise ValueError(
            'Invalid RSA jwk, both n and e must be have values')

    rsa_key = _RsaKey()
    rsa_key.kid = jwk.kid
    rsa_key.kty = jwk.kty
    rsa_key.key_ops = jwk.key_ops

    public_numbers = RSAPublicNumbers(
        n=_bytes_to_int(jwk.n), e=_bytes_to_int(jwk.e))

    # Without the secret primes and private exponent, load only the public key.
    if not (jwk.p and jwk.q and jwk.d):
        rsa_key._rsa_impl = public_numbers.public_key(
            cryptography.hazmat.backends.default_backend())
        return rsa_key

    p = _bytes_to_int(jwk.p)
    q = _bytes_to_int(jwk.q)
    d = _bytes_to_int(jwk.d)

    # Prefer the supplied CRT values; compute whichever are missing.
    if jwk.dp:
        dmp1 = _bytes_to_int(jwk.dp)
    else:
        dmp1 = rsa_crt_dmp1(private_exponent=d, p=p)
    if jwk.dq:
        dmq1 = _bytes_to_int(jwk.dq)
    else:
        dmq1 = rsa_crt_dmq1(private_exponent=d, q=q)
    if jwk.qi:
        iqmp = _bytes_to_int(jwk.qi)
    else:
        iqmp = rsa_crt_iqmp(p=p, q=q)

    private_numbers = RSAPrivateNumbers(
        p=p, q=q, d=d, dmp1=dmp1, dmq1=dmq1, iqmp=iqmp,
        public_numbers=public_numbers)
    rsa_key._rsa_impl = private_numbers.private_key(
        cryptography.hazmat.backends.default_backend())
    return rsa_key
def jwks_to_pem_keys(json_web_keys):
    """Convert every JWK in a JWKS document to a PEM-encoded RSA public key.

    Returns:
        list: PEM ``bytes``, in the same order as ``json_web_keys['keys']``.
    """
    def _jwk_pem(entry):
        # Decode e then n, rebuild the key and serialize it.
        numbers = RSAPublicNumbers(
            base64_to_long(entry['e']), base64_to_long(entry['n']))
        return numbers.public_key(backend=default_backend()).public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo)

    return [_jwk_pem(entry) for entry in json_web_keys['keys']]
def _load_public_key(self, e, n):
    """Return the PEM public key for base64url-encoded exponent ``e`` and modulus ``n``.

    ``"=="`` is appended unconditionally before decoding; Python's
    non-strict base64 decoder tolerates the surplus padding.
    """
    def _decode_uint(text):
        raw = base64.urlsafe_b64decode(text + "==")
        return int.from_bytes(raw, byteorder="big")

    exponent_value = _decode_uint(e)
    modulus_value = _decode_uint(n)
    rsa_key = RSAPublicNumbers(exponent_value, modulus_value).public_key(
        backend=default_backend())
    return rsa_key.public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
def calculate_response(challenge, user_id, password):
    """Build and RSA-encrypt the response to ``challenge``.

    The response echoes fields 2 and 3 of the decoded challenge, sets field 1
    to ``[49]``, and carries the user id (field 8) and password (field 9) as
    lists of code points. It is encrypted with PKCS#1 v1.5 under the public
    key whose exponent is challenge field 5 and whose modulus is field 4.

    Returns:
        The ciphertext as produced by ``bytes2hex``.
    """
    fields = decode(challenge)
    reply = {
        1: [49],
        2: fields[2],
        3: fields[3],
        8: list(map(ord, user_id)),
        9: list(map(ord, password)),
    }
    encoded_reply = encode(reply)
    exponent = ba2num(fields[5])
    modulus = ba2num(fields[4])
    rsa_key = RSAPublicNumbers(exponent, modulus).public_key(default_backend())
    ciphertext = rsa_key.encrypt(ba2bytes(encoded_reply), PKCS1v15())
    return bytes2hex(ciphertext)
def get_rsa_public_key(n, e):
    """
    Retrieve an RSA public key based on a modulus and exponent as provided
    by the JWKS format.

    :param n: base64url-encoded modulus (``str`` or bytes-like)
    :param e: base64url-encoded public exponent (``str`` or bytes-like)
    :return: a RSA Public Key in PEM format
    """
    def _b64url_to_int(value):
        # ``bytes(some_str)`` raises TypeError on Python 3, so text input
        # (what a parsed JWKS yields) must be encoded explicitly first.
        if isinstance(value, str):
            value = value.encode('ascii')
        raw = jwt.utils.base64url_decode(bytes(value))
        return int(binascii.hexlify(raw), 16)

    n = _b64url_to_int(n)
    e = _b64url_to_int(e)
    pub = RSAPublicNumbers(e, n).public_key(default_backend())
    return pub.public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo)
def get_rsa_public_key(n, e):
    """
    Retrieve an RSA public key based on a modulus and exponent as provided
    by the JWKS format.

    :param n: base64url-encoded modulus (``str`` or bytes-like)
    :param e: base64url-encoded public exponent (``str`` or bytes-like)
    :return: a RSA Public Key in PEM format
    """
    def _b64url_to_int(value):
        # ``bytes(some_str)`` raises TypeError on Python 3, so text input
        # (what a parsed JWKS yields) must be encoded explicitly first.
        if isinstance(value, str):
            value = value.encode('ascii')
        raw = jwt.utils.base64url_decode(bytes(value))
        return int(binascii.hexlify(raw), 16)

    n = _b64url_to_int(n)
    e = _b64url_to_int(e)
    pub = RSAPublicNumbers(e, n).public_key(default_backend())
    return pub.public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo
    )
def verify(signed, public_key):
    """Verify ``signed`` with an RSA public key and return its decoded payload.

    ``signed.padding`` and ``signed.algorithm`` are looked up in ``PADDING``
    and ``ALGORITHM``, defaulting to ``PSS`` and ``SHA256``. The key is
    rebuilt from ``public_key.e`` and ``base642int(public_key.n)``. An invalid
    signature propagates whatever ``key.verify`` raises.

    Returns:
        ``json2value`` of the UTF-8-decoded signed data.
    """
    payload = base642bytes(signed.data)
    sig = base642bytes(signed.signature)
    rsa_key = RSAPublicNumbers(public_key.e, base642int(public_key.n)).public_key(BACKEND)
    chosen_padding = PADDING.get(signed.padding, PSS)
    chosen_hash = ALGORITHM.get(signed.algorithm, SHA256)
    rsa_key.verify(
        signature=sig,
        data=payload,
        padding=chosen_padding,
        algorithm=chosen_hash,
    )
    return json2value(payload.decode("utf8"))
def _convert_rsa_private_key(keypair):
    """Convert a keypair's raw RSA private-key blob into a cryptography key.

    ``_read_KEY_RSA`` yields the components in the order n, e, d, iqmp, p, q
    as big-endian unsigned byte strings. The CRT exponents are recomputed
    from d, p and q.

    Args:
        keypair: Object whose ``private_key`` attribute holds the raw blob.

    Returns:
        The RSA private key built with the OpenSSL ``Backend``.
    """
    backend = Backend()

    def _trim_bn_to_int(s):
        # Big-endian unsigned decode. This is equivalent to the former
        # BN_bin2bn/_bn_to_int round-trip (leading zero bytes do not change
        # the value) but allocates no OpenSSL BIGNUM; that BIGNUM was
        # released with OPENSSL_free instead of BN_free, leaking its digits.
        return int.from_bytes(s, 'big')

    (n, e, d, iqmp, p, q) = [
        _trim_bn_to_int(value)
        for value in _read_KEY_RSA(io.BytesIO(keypair.private_key))
    ]
    numbers = RSAPrivateNumbers(
        d=d,
        p=p,
        q=q,
        dmp1=rsa_crt_dmp1(d, p),
        dmq1=rsa_crt_dmq1(d, q),
        iqmp=rsa_crt_iqmp(p, q),
        public_numbers=RSAPublicNumbers(e=e, n=n),
    )
    return numbers.private_key(backend)
def rsa_pem_from_jwk():
    """Return the PEM public key for the first JWK in the module-level ``jwts``.

    ``n`` and ``e`` are decoded with ``decode_value``.
    """
    first_jwk = jwts["keys"][0]
    numbers = RSAPublicNumbers(
        n=decode_value(first_jwk['n']),
        e=decode_value(first_jwk['e']),
    )
    rsa_key = numbers.public_key(default_backend())
    return rsa_key.public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
async def fetch_keys(client: httpx.AsyncClient) -> None:
    """
    Fetch JWT keys and store them in the JWT_KEYS that will act as a cache.

    RSA entries with ``n``/``e`` are rebuilt from those numbers; any other
    entry is loaded from the first certificate of its ``x5c`` chain.

    Raises:
        httpx.HTTPStatusError: If the JWKS endpoint answers with an error
            status.
    """
    # TODO: Handling of the JWKS should be refactored to a separate library and it
    # should make sure the caching + fetching fresh data is implemented according to
    # standards. This simplified version now caches the keys forever (= until restart).
    url = openid_conf.jwks_uri
    logger.info("Fetching JWKS", url=url)
    response = await client.get(url)
    # Fail with a clear HTTP error instead of a confusing JSON-decode or
    # KeyError when the endpoint returns an error page.
    response.raise_for_status()
    key_data = response.json()
    for entry in key_data["keys"]:
        kid = entry["kid"]
        algorithm = entry["alg"]
        if entry.get("kty") == "RSA" and "n" in entry and "e" in entry:
            n = urlsafe_b64_to_unsigned_int(entry["n"])
            e = urlsafe_b64_to_unsigned_int(entry["e"])
            key = RSAPublicNumbers(e=e, n=n).public_key(default_backend())
        else:
            key = load_der_x509_certificate(b64decode(entry["x5c"][0]), default_backend()).public_key()
        JWT_KEYS[kid] = {
            "algorithm": algorithm,
            "key": key,
        }
        logger.info("Added JWT key", kid=kid, algorithm=algorithm, type=type(key).__name__)
def get_pem_public_key_from_modulus_exponent(n, e):
    """Return the PEM public key for a base64url modulus ``n`` and exponent ``e``.

    Args:
        n: base64url-encoded modulus, with or without ``=`` padding
            (``str`` or ``bytes``).
        e: base64url-encoded public exponent, same encoding as ``n``.

    Returns:
        bytes: PEM-encoded SubjectPublicKeyInfo.
    """
    def _b64decode(data):
        if isinstance(data, str):
            data = data.encode('ascii')
        data = bytes(data)
        # Pad to a multiple of 4 characters (base64url JWK values usually
        # omit padding). The previous ``len % 4`` count was wrong for
        # lengths of 3 mod 4 and only worked because the decoder is lenient.
        data += b'=' * (-len(data) % 4)
        # Python 3 replacement for ``long(raw.encode('hex'), 16)``: neither
        # ``long`` nor ``bytes.encode`` exists there.
        return int.from_bytes(base64.urlsafe_b64decode(data), 'big')

    modulus = _b64decode(n)
    exponent = _b64decode(e)
    numbers = RSAPublicNumbers(exponent, modulus)
    public_key = numbers.public_key(backend=default_backend())
    return public_key.public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo)
def load_private_key(self):
    """Build an RSA private key from the JWK members in ``self._dict_data``.

    The supplied CRT members are used when ``has_all_prime_factors`` reports
    them present. Otherwise ``p`` and ``q`` are recovered from ``n``, ``e``
    and ``d``, and the CRT values are derived.

    Raises:
        ValueError: If the multi-prime ``oth`` member is present.
    """
    obj = self._dict_data

    if 'oth' in obj:  # pragma: no cover
        # https://tools.ietf.org/html/rfc7518#section-6.3.2.7
        raise ValueError('"oth" is not supported yet')

    public_numbers = RSAPublicNumbers(
        base64_to_int(obj['e']), base64_to_int(obj['n']))

    if has_all_prime_factors(obj):
        numbers = RSAPrivateNumbers(
            d=base64_to_int(obj['d']),
            p=base64_to_int(obj['p']),
            q=base64_to_int(obj['q']),
            dmp1=base64_to_int(obj['dp']),
            dmq1=base64_to_int(obj['dq']),
            iqmp=base64_to_int(obj['qi']),
            public_numbers=public_numbers)
    else:
        d = base64_to_int(obj['d'])
        # Documented argument order is (n, e, d). The former (n, d, e) order
        # appears to have worked only because the recovery multiplies e and d.
        p, q = rsa_recover_prime_factors(
            public_numbers.n, public_numbers.e, d)
        numbers = RSAPrivateNumbers(
            d=d,
            p=p,
            q=q,
            dmp1=rsa_crt_dmp1(d, p),
            dmq1=rsa_crt_dmq1(d, q),
            iqmp=rsa_crt_iqmp(p, q),
            public_numbers=public_numbers)

    return numbers.private_key(default_backend())
def test_load_ssh_public_key_rsa(self, backend):
    """An OpenSSH ``ssh-rsa`` line loads as an RSA key with the expected numbers."""
    openssh_line = (
        b"ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDDu/XRP1kyK6Cgt36gts9XAk"
        b"FiiuJLW6RU0j3KKVZSs1I7Z3UmU9/9aVh/rZV43WQG8jaR6kkcP4stOR0DEtll"
        b"PDA7ZRBnrfiHpSQYQ874AZaAoIjgkv7DBfsE6gcDQLub0PFjWyrYQUJhtOLQEK"
        b"vY/G0vt2iRL3juawWmCFdTK3W3XvwAdgGk71i6lHt+deOPNEPN2H58E4odrZ2f"
        b"sxn/adpDqfb2sM0kPwQs0aWvrrKGvUaustkivQE4XWiSFnB0oJB/lKK/CKVKuy"
        b"///ImSCGHQRvhwariN2tvZ6CBNSLh3iQgeB0AkyJlng7MXB2qYq/Ci2FUOryCX"
        b"2MzHvnbv testkey@localhost"
    )
    modulus_hex = (
        '00C3BBF5D13F59322BA0A0B77EA0B6CF570241628AE24B5BA454D'
        '23DCA295652B3523B67752653DFFD69587FAD9578DD6406F23691'
        'EA491C3F8B2D391D0312D9653C303B651067ADF887A5241843CEF'
        '8019680A088E092FEC305FB04EA070340BB9BD0F1635B2AD84142'
        '61B4E2D010ABD8FC6D2FB768912F78EE6B05A60857532B75B75EF'
        'C007601A4EF58BA947B7E75E38F3443CDD87E7C138A1DAD9D9FB3'
        '19FF69DA43A9F6F6B0CD243F042CD1A5AFAEB286BD46AEB2D922B'
        'D01385D6892167074A0907F94A2BF08A54ABB2FFFFC89920861D0'
        '46F8706AB88DDADBD9E8204D48B87789081E074024C8996783B31'
        '7076A98ABF0A2D8550EAF2097D8CCC7BE76EF'
    )

    loaded = load_ssh_public_key(openssh_line, backend)
    assert loaded is not None
    assert isinstance(loaded, interfaces.RSAPublicKey)

    # Public exponent 0x10001 (65537) and the modulus spelled out above.
    assert loaded.public_numbers() == RSAPublicNumbers(0x10001, int(modulus_hex, 16))
def private_numbers(self) -> RSAPrivateNumbers:
    """Export this key's components as ``RSAPrivateNumbers``.

    ``n``/``e``/``d``, the primes ``p``/``q`` and the CRT parameters are read
    through ``RSA_get0_key``, ``RSA_get0_factors`` and ``RSA_get0_crt_params``.
    Each returned ``BIGNUM`` pointer is asserted non-NULL before it is
    converted to a Python ``int``.
    """
    backend = self._backend
    ffi = backend._ffi
    lib = backend._lib

    # One ``BIGNUM **`` out-parameter per component.
    n, e, d, p, q, dmp1, dmq1, iqmp = [ffi.new("BIGNUM **") for _ in range(8)]

    lib.RSA_get0_key(self._rsa_cdata, n, e, d)
    for slot in (n, e, d):
        backend.openssl_assert(slot[0] != ffi.NULL)

    lib.RSA_get0_factors(self._rsa_cdata, p, q)
    for slot in (p, q):
        backend.openssl_assert(slot[0] != ffi.NULL)

    lib.RSA_get0_crt_params(self._rsa_cdata, dmp1, dmq1, iqmp)
    for slot in (dmp1, dmq1, iqmp):
        backend.openssl_assert(slot[0] != ffi.NULL)

    to_int = backend._bn_to_int
    return RSAPrivateNumbers(
        p=to_int(p[0]),
        q=to_int(q[0]),
        d=to_int(d[0]),
        dmp1=to_int(dmp1[0]),
        dmq1=to_int(dmq1[0]),
        iqmp=to_int(iqmp[0]),
        public_numbers=RSAPublicNumbers(
            e=to_int(e[0]),
            n=to_int(n[0]),
        ),
    )
def rsa_pem_from_jwk(jwk):
    """Return the PEM SubjectPublicKeyInfo for a JWK's ``n``/``e`` members.

    Both members are decoded with ``decode_value`` (``n`` first).
    """
    numbers = RSAPublicNumbers(
        n=decode_value(jwk["n"]),
        e=decode_value(jwk["e"]),
    )
    rsa_key = numbers.public_key(default_backend())
    return rsa_key.public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
def _process_ssh_rsa(self, data):
    """Parse an ssh-rsa public key blob and enforce the configured size limits.

    Sets ``self.rsa`` (the rebuilt public key) and ``self.bits`` (its size).
    The strict or loose min/max bounds are chosen by ``self.strict_mode``.

    Returns:
        int: Offset in ``data`` just past the parsed modulus.

    Raises:
        TooShortKeyError: If the key is smaller than the minimum.
        TooLongKeyError: If the key is larger than the maximum.
    """
    offset, raw_e = self._unpack_by_int(data, 0)
    offset, raw_n = self._unpack_by_int(data, offset)
    exponent = self._parse_long(raw_e)
    modulus = self._parse_long(raw_n)
    self.rsa = RSAPublicNumbers(exponent, modulus).public_key(default_backend())
    self.bits = self.rsa.key_size

    if self.strict_mode:
        min_length, max_length = self.RSA_MIN_LENGTH_STRICT, self.RSA_MAX_LENGTH_STRICT
    else:
        min_length, max_length = self.RSA_MIN_LENGTH_LOOSE, self.RSA_MAX_LENGTH_LOOSE

    if self.bits < min_length:
        raise TooShortKeyError(
            "%s key data can not be shorter than %s bits (was %s)"
            % (self.key_type, min_length, self.bits))
    if self.bits > max_length:
        raise TooLongKeyError(
            "%s key data can not be longer than %s bits (was %s)"
            % (self.key_type, max_length, self.bits))
    return offset
def from_dict(cls, dct):
    """Build an instance from a JWK dictionary holding an RSA public or private key.

    Args:
        dct: JWK members. ``e`` and ``n`` are required; ``d`` selects a
            private key; ``p``/``q``/``dp``/``dq``/``qi`` must be given all
            together or not at all (when absent, ``p``/``q`` are recovered).

    Returns:
        ``cls`` constructed with the loaded key and ``**dct``.

    Raises:
        UnsupportedKeyTypeError: If ``oth`` (multi-prime key) is present.
        MalformedJWKError: If ``e``/``n`` are missing or only some of the
            CRT members are present.
    """
    if 'oth' in dct:
        raise UnsupportedKeyTypeError(
            'RSA keys with multiples primes are not supported')

    try:
        e = uint_b64decode(dct['e'])
        n = uint_b64decode(dct['n'])
    except KeyError as why:
        raise MalformedJWKError('e and n are required') from why
    pub_numbers = RSAPublicNumbers(e, n)
    if 'd' not in dct:
        return cls(
            pub_numbers.public_key(backend=default_backend()), **dct)
    d = uint_b64decode(dct['d'])

    privparams = {'p', 'q', 'dp', 'dq', 'qi'}
    product = set(dct.keys()) & privparams
    if len(product) == 0:
        p, q = rsa_recover_prime_factors(n, e, d)
        priv_numbers = RSAPrivateNumbers(
            d=d,
            p=p,
            q=q,
            dmp1=rsa_crt_dmp1(d, p),
            dmq1=rsa_crt_dmq1(d, q),
            iqmp=rsa_crt_iqmp(p, q),
            public_numbers=pub_numbers)
    elif product == privparams:
        priv_numbers = RSAPrivateNumbers(
            d=d,
            p=uint_b64decode(dct['p']),
            q=uint_b64decode(dct['q']),
            dmp1=uint_b64decode(dct['dp']),
            dmq1=uint_b64decode(dct['dq']),
            iqmp=uint_b64decode(dct['qi']),
            public_numbers=pub_numbers)
    else:
        # If the producer includes any of the other private key parameters,
        # then all of the others MUST be present, with the exception of
        # "oth", which MUST only be present when more than two prime
        # factors were used.
        raise MalformedJWKError(
            'p, q, dp, dq, qi MUST be present or '
            'all of them MUST be absent')
    return cls(priv_numbers.private_key(backend=default_backend()), **dct)
def from_jwk(jwk):
    """Turn a ``JsonWebKey`` into an ``_RsaKey`` wrapping a cryptography key.

    A private key is created when ``p``, ``q`` and ``d`` are all set. Any of
    ``dp``/``dq``/``qi`` that is missing is computed from those values.
    Otherwise only the public part (``n``, ``e``) is loaded.

    Raises:
        TypeError: If ``jwk`` is not a ``JsonWebKey``.
        ValueError: If ``kty`` is neither ``RSA`` nor ``RSA-HSM``, or if
            ``n`` or ``e`` is empty.
    """
    if not isinstance(jwk, JsonWebKey):
        raise TypeError('The specified jwk must be a JsonWebKey')
    if jwk.kty not in ('RSA', 'RSA-HSM'):
        raise ValueError('The specified jwk must have a key type of "RSA" or "RSA-HSM"')
    if not (jwk.n and jwk.e):
        raise ValueError('Invalid RSA jwk, both n and e must be have values')

    def _crt_value(raw, compute):
        # Use the JWK-supplied CRT member when present, else derive it.
        return _bytes_to_int(raw) if raw else compute()

    rsa_key = _RsaKey()
    rsa_key.kid = jwk.kid
    rsa_key.kty = jwk.kty
    rsa_key.key_ops = jwk.key_ops

    public_numbers = RSAPublicNumbers(n=_bytes_to_int(jwk.n), e=_bytes_to_int(jwk.e))

    has_private_parts = jwk.p and jwk.q and jwk.d
    if has_private_parts:
        p = _bytes_to_int(jwk.p)
        q = _bytes_to_int(jwk.q)
        d = _bytes_to_int(jwk.d)
        private_numbers = RSAPrivateNumbers(
            p=p,
            q=q,
            d=d,
            dmp1=_crt_value(jwk.dp, lambda: rsa_crt_dmp1(private_exponent=d, p=p)),
            dmq1=_crt_value(jwk.dq, lambda: rsa_crt_dmq1(private_exponent=d, q=q)),
            iqmp=_crt_value(jwk.qi, lambda: rsa_crt_iqmp(p=p, q=q)),
            public_numbers=public_numbers,
        )
        impl = private_numbers.private_key(cryptography.hazmat.backends.default_backend())
    else:
        impl = public_numbers.public_key(cryptography.hazmat.backends.default_backend())

    rsa_key._rsa_impl = impl
    return rsa_key
def calculate_cost(self):
    """Return the cost of fulfilling this RSA condition.

    The cost is the square of the modulus bit length, shifted right by
    ``RsaSha256.COST_RIGHT_SHIFT``.

    Returns:
        int: Expected maximum cost to fulfill this condition.

    Raises:
        MissingDataError: If no public modulus is set.
    """
    if self.modulus is None:
        raise MissingDataError('Requires a public modulus')

    rsa_key = RSAPublicNumbers(
        PUBLIC_EXPONENT,
        int.from_bytes(self.modulus, byteorder='big'),
    ).public_key(default_backend())
    squared_bits = int(math.pow(rsa_key.key_size, 2))
    # TODO: Python's ``>>`` differs from JavaScript's zero-fill ``>>>`` once a
    # value no longer fits in 32 bits; this may need correcting. See
    # http://grokbase.com/t/python/python-list/0454t3tgaw/zero-fill-shift
    return squared_bits >> RsaSha256.COST_RIGHT_SHIFT
def validate(self, message):
    """Verify the signature of this RSA fulfillment.

    The signature is checked against ``message`` using the condition's public
    modulus together with ``PUBLIC_EXPONENT``, RSASSA-PSS padding (MGF1 with
    SHA-256, ``SALT_LENGTH`` salt) and SHA-256 hashing.

    Args:
        message (bytes): Message to verify.

    Returns:
        bool: ``True`` when the signature is valid.

    Raises:
        TypeError: If ``message`` is not ``bytes``.
        ValidationError: If the signature does not verify.
    """
    if not isinstance(message, bytes):
        # Format with an f-string: ``'...' + message`` would itself raise an
        # unrelated TypeError for non-str values and hide the real problem.
        raise TypeError(
            f'Message must be provided as bytes, was: {message}')

    public_numbers = RSAPublicNumbers(
        PUBLIC_EXPONENT,
        int.from_bytes(self.modulus, byteorder='big'),
    )
    public_key = public_numbers.public_key(default_backend())

    try:
        # One-shot ``verify``: the streaming ``verifier()`` API was
        # deprecated in cryptography 2.0 and has since been removed.
        public_key.verify(
            self.signature,
            message,
            padding.PSS(
                mgf=padding.MGF1(hashes.SHA256()),
                salt_length=SALT_LENGTH,
            ),
            hashes.SHA256(),
        )
    except InvalidSignature as exc:
        raise ValidationError('Invalid RSA signature') from exc

    return True
class RSAPublicKey:
    """RSA public key (modulus ``n``, exponent ``e``) that checks PKCS#1 v1.5 / SHA-1 signatures."""

    def __init__(self, n, e):
        self.n = n
        self.e = e
        self._key = RSAPublicNumbers(e, n).public_key(default_backend())

    def verify(self, data, sig):
        """Return ``True`` if ``sig`` is a valid signature over ``data``, else ``False``.

        Uses the one-shot ``verify`` call; the streaming ``verifier()`` API was
        deprecated in cryptography 2.0 and has since been removed.
        NOTE(review): SHA-1 is weak; presumably kept for protocol compatibility.
        """
        try:
            self._key.verify(sig, data, PKCS1v15(), SHA1())
        except InvalidSignature:
            return False
        return True
def __init__(self, n, e):
    """Keep the modulus ``n`` and exponent ``e`` and build the RSA public key from them."""
    self.n, self.e = n, e
    rsa_numbers = RSAPublicNumbers(e, n)
    self._key = rsa_numbers.public_key(default_backend())