def test_aes_cbc_key_size_validation():
    """Every AES-CBCPAD algorithm should reject a key whose size doesn't match it.

    The client is built from a 64-byte symmetric JWK, which fits none of the
    algorithms exercised below, so each encrypt call is expected to raise an
    AzureError whose message mentions the key size.
    """
    oversized_jwk = JsonWebKey(kty="oct-HSM", key_ops=["encrypt", "decrypt"], k=os.urandom(64))
    iv = os.urandom(16)
    client = CryptographyClient.from_jwk(jwk=oversized_jwk)

    mismatched_algorithms = (
        EncryptionAlgorithm.a128_cbcpad,  # requires 16-byte key
        EncryptionAlgorithm.a192_cbcpad,  # requires 24-byte key
        EncryptionAlgorithm.a256_cbcpad,  # requires 32-byte key
    )
    for algorithm in mismatched_algorithms:
        with pytest.raises(AzureError) as ex:
            client.encrypt(algorithm, b"...", iv=iv)
        assert "key size" in str(ex.value).lower()
def test_encrypt_local_from_jwk(self, key_client, is_hsm, **kwargs):
    """Encrypt with a local (JWK-only) client, then decrypt with Key Vault.

    A 4096-bit RSA key is created through ``key_client``; every
    EncryptionAlgorithm whose value starts with "RSA" is round-tripped.
    """
    name = self.get_resource_name("encrypt-local")
    key = self._create_rsa_key(key_client, name, size=4096, hardware_protected=is_hsm)
    remote_client = self.create_crypto_client(key, api_version=key_client.api_version)
    local_client = CryptographyClient.from_jwk(key.key)

    for algorithm in EncryptionAlgorithm:
        if not algorithm.startswith("RSA"):
            continue

        # Local encryption should still report the ID of the key it used
        encrypted = local_client.encrypt(algorithm, self.plaintext)
        self.assertEqual(encrypted.key_id, key.id)

        decrypted = remote_client.decrypt(encrypted.algorithm, encrypted.ciphertext)
        self.assertEqual(decrypted.plaintext, self.plaintext)
def test_encrypt_argument_validation():
    """Encrypting should fail with ValueError when given arguments the algorithm can't use.

    RSA-OAEP is called with ``iv`` and with ``additional_authenticated_data``;
    each call must raise a ValueError that names the offending argument.
    """
    validity_window = mock.Mock(
        not_before=datetime(2000, 1, 1, tzinfo=_UTC),
        expires_on=datetime(3000, 1, 1, tzinfo=_UTC),
    )
    key = mock.Mock(spec=KeyVaultKey, id="https://localhost/fake/key/version", properties=validity_window)
    client = CryptographyClient(key, mock.Mock())
    client._client = mock.Mock()

    for argument in ("iv", "additional_authenticated_data"):
        with pytest.raises(ValueError) as ex:
            client.encrypt(EncryptionAlgorithm.rsa_oaep, b"...", **{argument: b"..."})
        assert argument in str(ex.value)
def test_local_only_mode_raise():
    """A local-only CryptographyClient should raise when an operation can't be performed locally.

    The JWK below permits only decrypt, verify and unwrapKey. Each case lists
    the exception expected from one call and the values that must appear in
    the exception message.
    """
    jwk = {"kty":"RSA", "key_ops":["decrypt", "verify", "unwrapKey"], "n":b"10011", "e":b"10001"}
    client = CryptographyClient.from_jwk(jwk=jwk)

    cases = (
        # Algorithm not supported locally
        (
            NotImplementedError,
            lambda: client.decrypt(EncryptionAlgorithm.a256_gcm, b"..."),
            (EncryptionAlgorithm.a256_gcm, KeyOperation.decrypt),
        ),
        # Operation not included in JWK permissions
        (
            AzureError,
            lambda: client.encrypt(EncryptionAlgorithm.rsa_oaep, b"..."),
            (KeyOperation.encrypt,),
        ),
        # Algorithm not supported locally
        (
            NotImplementedError,
            lambda: client.verify(SignatureAlgorithm.es256, b"...", b"..."),
            (SignatureAlgorithm.es256, KeyOperation.verify),
        ),
        # Algorithm not supported locally, and operation not included in JWK permissions
        (
            NotImplementedError,
            lambda: client.sign(SignatureAlgorithm.rs256, b"..."),
            (SignatureAlgorithm.rs256, KeyOperation.sign),
        ),
        # Algorithm not supported locally
        (
            NotImplementedError,
            lambda: client.unwrap_key(KeyWrapAlgorithm.aes_256, b"..."),
            (KeyWrapAlgorithm.aes_256, KeyOperation.unwrap_key),
        ),
        # Operation not included in JWK permissions
        (
            AzureError,
            lambda: client.wrap_key(KeyWrapAlgorithm.rsa_oaep, b"..."),
            (KeyOperation.wrap_key,),
        ),
    )

    for expected_error, operation, expected_in_message in cases:
        with pytest.raises(expected_error) as ex:
            operation()
        message = str(ex.value)
        for expected in expected_in_message:
            assert expected in message
def test_initialization_get_key_successful():
    """If the client is able to get key material, it shouldn't do so again

    The client is created from a key ID and given a mock service client whose
    ``get_key`` returns ``mock_key``. The first operation should fetch the key
    once and build a local provider from it; later operations must reuse both.
    """
    mock_key = mock.Mock()
    mock_client = mock.Mock()
    # Configure the return value exactly once; an earlier, immediately
    # overwritten assignment of a spec'd mock had no effect and was removed.
    mock_client.get_key.return_value = mock_key

    client = CryptographyClient("https://localhost/fake/key/version", mock.Mock())
    client._client = mock_client

    # No operation has run yet, so the key must not have been requested
    assert mock_client.get_key.call_count == 0

    with mock.patch(CryptographyClient.__module__ + ".get_local_cryptography_provider") as get_provider:
        client.verify(SignatureAlgorithm.rs256, b"...", b"...")
    get_provider.assert_called_once_with(mock_key)

    # Repeated operations must neither re-fetch the key nor rebuild the provider
    for _ in range(3):
        assert mock_client.get_key.call_count == 1
        assert get_provider.call_count == 1
        client.verify(SignatureAlgorithm.rs256, b"...", b"...")
def test_ec_verify_local_from_jwk(self, azure_keyvault_url, **kwargs):
    """Sign with Key Vault, then verify the signature with a local (JWK-only) client.

    One EC key is created per curve, in sorted curve order. Each curve is
    paired with its signature algorithm and the hash used to build the digest.
    """
    key_client = self.create_key_client(azure_keyvault_url)

    settings_by_curve = {
        KeyCurveName.p_256: (SignatureAlgorithm.es256, hashlib.sha256),
        KeyCurveName.p_256_k: (SignatureAlgorithm.es256_k, hashlib.sha256),
        KeyCurveName.p_384: (SignatureAlgorithm.es384, hashlib.sha384),
        KeyCurveName.p_521: (SignatureAlgorithm.es512, hashlib.sha512),
    }

    for curve in sorted(settings_by_curve):
        signature_algorithm, hash_function = settings_by_curve[curve]

        key = key_client.create_ec_key(
            self.get_resource_name("ec-verify-{}".format(curve.value)), curve=curve
        )
        remote_client = self.create_crypto_client(key)
        local_client = CryptographyClient.from_jwk(key.key)

        digest = hash_function(self.plaintext).digest()

        signed = remote_client.sign(signature_algorithm, digest)
        self.assertEqual(signed.key_id, key.id)

        verified = local_client.verify(signed.algorithm, digest, signed.signature)
        self.assertTrue(verified.is_valid)
def test_rsa_verify_local_from_jwk(self, azure_keyvault_url, **kwargs):
    """Sign with Key Vault, then verify the signature with a local (JWK-only) client.

    RSA keys of 2048, 3072 and 4096 bits are each exercised with all PS* and
    RS* signature algorithms, using the matching SHA-2 hash for the digest.
    """
    key_client = self.create_key_client(azure_keyvault_url)

    algorithms_and_hashes = (
        (SignatureAlgorithm.ps256, hashlib.sha256),
        (SignatureAlgorithm.ps384, hashlib.sha384),
        (SignatureAlgorithm.ps512, hashlib.sha512),
        (SignatureAlgorithm.rs256, hashlib.sha256),
        (SignatureAlgorithm.rs384, hashlib.sha384),
        (SignatureAlgorithm.rs512, hashlib.sha512),
    )

    for size in (2048, 3072, 4096):
        key = key_client.create_rsa_key(
            self.get_resource_name("rsa-verify-{}".format(size)), size=size
        )
        remote_client = self.create_crypto_client(key)
        local_client = CryptographyClient.from_jwk(key.key)

        for signature_algorithm, hash_function in algorithms_and_hashes:
            digest = hash_function(self.plaintext).digest()

            signed = remote_client.sign(signature_algorithm, digest)
            self.assertEqual(signed.key_id, key.id)

            verified = local_client.verify(signed.algorithm, digest, signed.signature)
            self.assertTrue(verified.is_valid)
def test_wrap_local_from_jwk(self, **kwargs):
    """Wrap a key with a local (JWK-only) client, then unwrap it with Key Vault.

    ``is_hsm`` (required in kwargs) selects the managed HSM endpoint over the
    vault endpoint and is forwarded as ``hardware_protected`` when the
    4096-bit RSA key is created.
    """
    is_hsm = kwargs.pop("is_hsm")
    self._skip_if_not_configured(is_hsm)

    if is_hsm:
        endpoint_url = self.managed_hsm_url
    else:
        endpoint_url = self.vault_url
    key_client = self.create_key_client(endpoint_url)

    name = self.get_resource_name("wrap-local")
    key = self._create_rsa_key(key_client, name, size=4096, hardware_protected=is_hsm)
    remote_client = self.create_crypto_client(key)
    local_client = CryptographyClient.from_jwk(key.key)

    for wrap_algorithm in KeyWrapAlgorithm:
        if not wrap_algorithm.startswith("RSA"):
            continue

        wrapped = local_client.wrap_key(wrap_algorithm, self.plaintext)
        self.assertEqual(wrapped.key_id, key.id)

        unwrapped = remote_client.unwrap_key(wrapped.algorithm, wrapped.encrypted_key)
        self.assertEqual(unwrapped.key, self.plaintext)
def test_encrypt_local_from_jwk(self, **kwargs):
    """Encrypt with a local (JWK-only) client, then decrypt with Key Vault.

    ``is_hsm`` (required in kwargs) selects the managed HSM endpoint over the
    vault endpoint and is forwarded as ``hardware_protected`` when the
    4096-bit RSA key is created.
    """
    is_hsm = kwargs.pop("is_hsm")
    self._skip_if_not_configured(is_hsm)

    if is_hsm:
        endpoint_url = self.managed_hsm_url
    else:
        endpoint_url = self.vault_url
    key_client = self.create_key_client(endpoint_url)

    name = self.get_resource_name("encrypt-local")
    key = self._create_rsa_key(key_client, name, size=4096, hardware_protected=is_hsm)
    remote_client = self.create_crypto_client(key)
    local_client = CryptographyClient.from_jwk(key.key)

    for encrypt_algorithm in EncryptionAlgorithm:
        if not encrypt_algorithm.startswith("RSA"):
            continue

        encrypted = local_client.encrypt(encrypt_algorithm, self.plaintext)
        self.assertEqual(encrypted.key_id, key.id)

        decrypted = remote_client.decrypt(encrypted.algorithm, encrypted.ciphertext)
        self.assertEqual(decrypted.plaintext, self.plaintext)
def azkms_obj(key_id):
    """
    Return the Azure Key Vault CryptographyClient, creating and caching it on first use.

    key_id looks like https://kapitanbackend.vault.azure.net/keys/myKey/deadbeef
    and may also be given without the https:// prefix. Once a client has been
    stored in cached.azkms_obj it is returned as-is and key_id is not parsed.
    """
    if cached.azkms_obj:
        return cached.azkms_obj

    parsed = urlparse(key_id)
    # With a scheme the path splits to ['', 'keys', 'myKey', 'deadbeef'];
    # without one the host ends up as the first path segment instead:
    # ['kapitanbackend.vault.azure.net', 'keys', 'myKey', 'deadbeef']
    segments = parsed.path.split("/")
    vault_host = parsed.hostname or segments[0]
    key_name, key_version = segments[-2], segments[-1]

    # Only let azure's own loggers through when we are logging at DEBUG level
    if logger.getEffectiveLevel() > logging.DEBUG:
        logging.getLogger("azure").setLevel(logging.ERROR)

    credential = DefaultAzureCredential()
    key = KeyClient(vault_url=f"https://{vault_host}", credential=credential).get_key(
        key_name, key_version
    )
    cached.azkms_obj = CryptographyClient(key, credential)
    return cached.azkms_obj
def test_initialization_get_key_successful():
    """Once the client has fetched key material, it should not fetch it again.

    The mocked service returns a key whose JWK ``kid`` equals the key ID the
    client was created with. The local provider must be built from a single
    JsonWebKey carrying that ID, and only once.
    """
    key_id = "https://localhost/fake/key/version"

    fetched_key = mock.Mock()
    fetched_key.key.kid = key_id
    service = mock.Mock()
    service.get_key.return_value = fetched_key

    client = CryptographyClient(key_id, mock.Mock())
    client._client = service

    # Nothing has been requested before the first operation
    assert service.get_key.call_count == 0

    provider_factory = CryptographyClient.__module__ + ".get_local_cryptography_provider"
    with mock.patch(provider_factory) as get_provider:
        client.verify(SignatureAlgorithm.rs256, b"...", b"...")

    positional, _ = get_provider.call_args
    assert len(positional) == 1 and isinstance(positional[0], JsonWebKey) and positional[0].kid == key_id

    # Further operations reuse the key and provider obtained above
    for _ in range(3):
        assert service.get_key.call_count == 1
        assert get_provider.call_count == 1
        client.verify(SignatureAlgorithm.rs256, b"...", b"...")
def get_cryptography_client(self, key):
    """Return a new CryptographyClient for ``key``, authenticated with ``self._credential``."""
    client = CryptographyClient(key, self._credential)
    return client
def encrypt(self, key_name, plaintext):
    """Encrypt a text string with the named key using RSA-OAEP.

    Args:
        key_name: Name of the key to fetch through ``self.key_client``.
        plaintext: Text to encrypt; it is encoded with ``str.encode()``
            (default encoding) before being sent for encryption.

    Returns:
        The ``ciphertext`` attribute of the encryption result.
    """
    key = self.key_client.get_key(key_name)
    crypto_client = CryptographyClient(key, credential=self.credential)
    # str.encode() already returns bytes, so wrapping it in bytes() was redundant
    result = crypto_client.encrypt(EncryptionAlgorithm.rsa_oaep, plaintext.encode())
    return result.ciphertext
def decrypt(self, ciphertext, key_name):
    """Decrypt RSA-OAEP ciphertext with the named key and return it as text.

    Args:
        ciphertext: Encrypted bytes to decrypt.
        key_name: Name of the key to fetch through ``self.key_client``.

    Returns:
        The decrypted plaintext, decoded with ``bytes.decode()`` (default encoding).
    """
    vault_key = self.key_client.get_key(key_name)
    client = CryptographyClient(vault_key, credential=self.credential)
    decrypted = client.decrypt(EncryptionAlgorithm.rsa_oaep, ciphertext)
    plaintext_bytes = decrypted.plaintext
    return plaintext_bytes.decode()
def test_rsa_key_id(self, key_client, credential, **kwargs):
    """A client created from a key ID should fetch the key and run public operations locally.

    After initialization the service client is removed, so encrypt, verify and
    wrap_key can only succeed if they are carried out without it.
    """
    rsa_key = key_client.create_rsa_key(self.create_random_name("rsakey"))

    crypto_client = CryptographyClient(rsa_key.id, credential)
    crypto_client._initialize()
    assert crypto_client.key_id == rsa_key.id

    # ensure all remote crypto operations will fail
    crypto_client._client = None

    digest = hashlib.sha256(self.plaintext).digest()
    crypto_client.encrypt(EncryptionAlgorithm.rsa_oaep, self.plaintext)
    crypto_client.verify(SignatureAlgorithm.rs256, digest, self.plaintext)
    crypto_client.wrap_key(KeyWrapAlgorithm.rsa_oaep, self.plaintext)
def _b64url_unpadded(raw):
    """Base64url-encode ``raw`` bytes and return text with the '=' padding removed."""
    return base64.urlsafe_b64encode(raw).decode('UTF-8').rstrip('=')


def obtain_access_token(key_vault_url, msi_credential, certificate_name,
                        tenant_id, client_id, resource):
    """Obtain an Azure AD access token with a client assertion signed by Key Vault.

    A JWT is built locally, its SHA-256 digest is signed with RS256 by the Key
    Vault key that has the same name as the certificate, and the signed JWT is
    posted to the tenant's token endpoint as a client-credentials assertion.

    Args:
        key_vault_url: URL of the Key Vault holding the certificate and key.
        msi_credential: Credential used for both the certificate and key clients.
        certificate_name: Name of the certificate; also used as the key name.
        tenant_id: Azure AD tenant the token is requested from.
        client_id: Application ID, used as both issuer and subject of the JWT.
        resource: Resource the access token is requested for.

    Returns:
        The ``access_token`` value from the token endpoint's JSON response.

    Raises:
        Exception: If the token endpoint does not answer with HTTP 200.
    """
    token_endpoint = f"https://login.microsoftonline.com/{tenant_id}/oauth2/token"

    # Get certificate from Key Vault, load the DER certificate it returns,
    # and calculate the thumbprint
    cert_client = CertificateClient(key_vault_url, msi_credential)
    der_bytes = cert_client.get_certificate(certificate_name).cer
    cert = load_der_x509_certificate(der_bytes, backend=default_backend())
    # The thumbprint keeps its base64 padding, exactly as before
    thumbprint = base64.urlsafe_b64encode(
        cert.fingerprint(hashes.SHA1())).decode('UTF-8')

    # Create the headers for the JWT. All three JWT segments are now encoded
    # the same way (base64url without padding); previously the header was the
    # only segment that kept any '=' padding.
    headers = {"alg": "RS256", "typ": "JWT", "x5t": thumbprint}
    encoded_header = _b64url_unpadded(bytes(json.dumps(headers), 'UTF-8'))

    # Generate a nonce
    nonce = uuid4().hex

    # Create the JWT payload; the assertion is valid from now for seven days
    now = time.time()
    claims = {
        "aud": token_endpoint,
        "iss": client_id,
        "sub": client_id,
        "jti": nonce,
        "nbf": int(now),
        "exp": int(now + (7 * 86400))
    }
    encoded_claims = _b64url_unpadded(bytes(json.dumps(claims), 'UTF-8'))

    # Issue the request to Key Vault to sign the data
    key_client = KeyClient(key_vault_url, msi_credential)
    key = key_client.get_key(certificate_name)
    crypto_client = CryptographyClient(key, credential=msi_credential)
    signing_input = encoded_header + '.' + encoded_claims
    data_hash = hashlib.sha256(bytes(signing_input, 'UTF-8')).digest()

    # Use Key Vault to calculate a signature using RSASSA-PKCS1-v1_5 using SHA-256
    jws_signature = crypto_client.sign(SignatureAlgorithm.rs256, data_hash).signature
    assertion = signing_input + '.' + _b64url_unpadded(jws_signature)

    payload = {
        "grant_type": "client_credentials",
        "client_id": client_id,
        "client_assertion_type":
            "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
        "client_assertion": assertion,
        "resource": resource
    }

    # Post the request for the access token
    result = requests.post(url=token_endpoint, data=payload)

    # Validate that access token was returned
    if result.status_code == 200:
        logging.info('Access token successfully obtained')
        return json.loads(result.text)['access_token']

    logging.error('Unable to obtain access token')
    # The error body may be non-JSON or lack some fields; log what is there
    # so the failure below is always the one that gets raised.
    try:
        error = json.loads(result.text)
    except ValueError:
        error = None
    if isinstance(error, dict):
        logging.error("Error was: %s", error.get('error'))
        logging.error("Error description was: %s", error.get('error_description'))
        logging.error("Error correlation_id was: %s", error.get('correlation_id'))
    else:
        logging.error("Response body was: %s", result.text)
    raise Exception('Failed to obtain access token')
def __init__(self, kek, credential):
    """Remember the key-encryption key and create a CryptographyClient for it.

    Args:
        kek: The key-encryption key; its ``id`` is stored as ``self.kid``.
        credential: Credential passed to the CryptographyClient.
    """
    self.kek = kek
    self.kid = kek.id
    # Key wrap algorithm recorded for this wrapper
    self.algorithm = KeyWrapAlgorithm.aes_256
    self.client = CryptographyClient(kek, credential)
def test_calls_service_for_operations_unsupported_locally():
    """Operations the local provider doesn't support should be sent to Key Vault.

    The local provider is replaced by a mock whose ``supports`` always returns
    False, so every operation must reach the mocked service client exactly
    once and never reach the provider.
    """
    service = mock.Mock()
    key = mock.Mock(spec=KeyVaultKey, id="https://localhost/fake/key/version")
    client = CryptographyClient(key, mock.Mock())
    client._client = service

    supports_nothing = mock.Mock(supports=mock.Mock(return_value=False))
    provider_factory = CryptographyClient.__module__ + ".get_local_cryptography_provider"

    # Only the first operation runs while the provider factory is patched
    with mock.patch(provider_factory, lambda *_: supports_nothing):
        client.decrypt(EncryptionAlgorithm.rsa_oaep, b"...")
    assert service.decrypt.call_count == 1
    assert supports_nothing.decrypt.call_count == 0

    remaining_operations = (
        ("encrypt", (EncryptionAlgorithm.rsa_oaep, b"...")),
        ("sign", (SignatureAlgorithm.rs256, b"...")),
        ("verify", (SignatureAlgorithm.rs256, b"...", b"...")),
        ("unwrap_key", (KeyWrapAlgorithm.rsa_oaep, b"...")),
        ("wrap_key", (KeyWrapAlgorithm.rsa_oaep, b"...")),
    )
    for operation, arguments in remaining_operations:
        getattr(client, operation)(*arguments)
        assert getattr(service, operation).call_count == 1
        assert getattr(supports_nothing, operation).call_count == 0
class AzureEncryptionProvider(EncryptionProvider):
    """An EncryptionProvider implementation for Azure Key Vault.

    To authenticate, please provide Azure AD service principal info as
    'tenant_id', 'client_id', and 'client_secret' kwargs, or as AZURE_TENANT_ID,
    AZURE_CLIENT_ID, and AZURE_CLIENT_SECRET environment variables.
    Alternatively pass 'auth_via_cli=True' to build the clients from the
    Azure CLI profile.

    Attributes:
        cred (ClientSecretCredential): The service principal credential, or
            None when authenticating via the Azure CLI.
        key_client (KeyClient): Azure Key Vault key client.
        key_encryption_key: The key encryption key fetched from the vault.
        crypto_client (CryptographyClient): Azure Key Vault crypto client for
            the key encryption key.

    Args:
        vault_url (str): The URL of the key vault to connect to.
        key (str): The name of the key encryption key to use for envelope
            encryption.
        auth_via_cli (bool, kwarg): If we should auth via Azure CLI.
        **kwargs: Authentication information.

    Raises:
        TypeError: If authentication information is not provided correctly.
        SystemExit: If CLI authentication was requested and raised a CLIError.
    """

    def __init__(self, vault_url: str, key: str, **kwargs) -> None:
        auth_via_cli = bool(kwargs.pop("auth_via_cli", False))
        # Remembered so that clients created later (see rotate_key) are
        # authenticated the same way as the ones created here.
        self._auth_via_cli = auth_via_cli

        if auth_via_cli:
            # There is no credential object on the CLI path. The attribute is
            # still defined so nothing fails with AttributeError on access.
            self.cred = None
            try:
                self.key_client = get_client_from_cli_profile(
                    KeyClient, vault_url=vault_url)
                self.key_encryption_key = self.key_client.get_key(key)
                self.crypto_client = self._make_crypto_client(
                    self.key_encryption_key)
            except CLIError:
                logging.error(
                    "ERROR: Unable to authenticate via Azure CLI, have you "
                    "logged in with 'az login'?")
                raise SystemExit(1)
        else:
            tenant_id = kwargs.pop("tenant_id", os.getenv(TENANT_ID_ENVVAR))
            client_id = kwargs.pop("client_id", os.getenv(CLIENT_ID_ENVVAR))
            client_secret = kwargs.pop("client_secret",
                                       os.getenv(CLIENT_SECRET_ENVVAR))

            if tenant_id is None or client_id is None or client_secret is None:
                raise TypeError(
                    "Please specify tenant_id, client_id, and client_secret "
                    "in config or in environment variables as in "
                    "https://github.com/Azure/azure-sdk-for-python/tree/master/sdk/identity/azure-identity#service-principal-with-secret"
                )

            self.cred = ClientSecretCredential(tenant_id, client_id,
                                               client_secret)
            self.key_client = KeyClient(vault_url,
                                        credential=self.cred,
                                        logger=None)
            self.key_encryption_key = self.key_client.get_key(key)
            self.crypto_client = self._make_crypto_client(
                self.key_encryption_key)

    def _make_crypto_client(self, key):
        """Create a CryptographyClient for ``key`` using this provider's auth method."""
        if self._auth_via_cli:
            return get_client_from_cli_profile(CryptographyClient, key=key)
        return CryptographyClient(key, self.cred)

    def _decrypt_envelope(self, crypto_client, envelope: EncryptionEnvelope):
        """Unwrap the envelope's data key with ``crypto_client`` and decrypt its data."""
        # decode from base64
        ciphertext = base64.b64decode(envelope.data)
        encrypted_data_key = base64.b64decode(envelope.key)
        nonce = base64.b64decode(envelope.iv)

        # decrypt the data key
        result = crypto_client.decrypt(ENCRYPTION_ALGORITHM,
                                       encrypted_data_key)
        data_key = result.plaintext

        # decrypt the data locally using the nonce and decrypted data key
        return self._data_decrypt(ciphertext, data_key, nonce)

    # overrides EncryptionProvider.encrypt()
    def encrypt(self, data: bytes) -> EncryptionEnvelope:
        """Encrypt ``data`` locally and wrap the data key with the vault key.

        Returns:
            An EncryptionEnvelope holding the base64-encoded ciphertext,
            encrypted data key and nonce, plus the version of the key
            encryption key that was used.
        """
        # encrypt the data locally, generating a data key and a nonce
        ciphertext, data_key, nonce = self._data_encrypt(data)

        # encrypt the data key using the key from the vault
        result = self.crypto_client.encrypt(ENCRYPTION_ALGORITHM, data_key)
        del data_key  # we don't wanna keep this around after we've encrypted it
        encrypted_data_key = result.ciphertext

        # encode to base64 for storage/transmission
        b64_ciphertext = base64.b64encode(ciphertext).decode("utf-8")
        b64_encrypted_data_key = base64.b64encode(encrypted_data_key).decode(
            "utf-8")
        b64_nonce = base64.b64encode(nonce).decode("utf-8")

        return EncryptionEnvelope(b64_ciphertext, b64_encrypted_data_key,
                                  b64_nonce,
                                  self.key_encryption_key.properties.version)

    # overrides EncryptionProvider.decrypt()
    def decrypt(self, envelope: EncryptionEnvelope) -> Union[bytes, None]:
        """Decrypt ``envelope`` with the current key encryption key.

        Returns:
            The decrypted data, or None (after logging an error) when the
            envelope was encrypted with a different key version.
        """
        if envelope.version != self.key_encryption_key.properties.version:
            logging.error(
                "Encryption key version %s is out of "
                "date, please re-encrypt with 'victoria encrypt rotate'",
                envelope.version)
            return None

        return self._decrypt_envelope(self.crypto_client, envelope)

    # overrides EncryptionProvider.rotate_key()
    def rotate_key(self,
                   envelope: EncryptionEnvelope,
                   version: Optional[str] = None) -> EncryptionEnvelope:
        """Re-encrypt ``envelope`` with the current key encryption key.

        The envelope is decrypted with the key version recorded in it and the
        result is encrypted again through encrypt(). ``version`` is accepted
        but not used by this implementation.
        """
        old_key = self.key_client.get_key(self.key_encryption_key.name,
                                          version=envelope.version)
        # Built through the helper so this also works after CLI
        # authentication, where no credential object exists.
        old_crypto_client = self._make_crypto_client(old_key)

        decrypted_data = self._decrypt_envelope(old_crypto_client, envelope)

        # now re-encrypt with the latest key encryption key
        return self.encrypt(decrypted_data)
def test_prefers_local_provider():
    """Operations the local provider supports should never be sent to Key Vault.

    The local provider is replaced by a mock whose ``supports`` always returns
    True, so every operation must reach the provider exactly once and never
    reach the mocked service client.
    """
    service = mock.Mock()
    validity_window = mock.Mock(
        not_before=datetime(2000, 1, 1, tzinfo=_UTC),
        expires_on=datetime(3000, 1, 1, tzinfo=_UTC),
    )
    key = mock.Mock(spec=KeyVaultKey, id="https://localhost/fake/key/version", properties=validity_window)
    client = CryptographyClient(key, mock.Mock())
    client._client = service

    supports_everything = mock.Mock(supports=mock.Mock(return_value=True))
    provider_factory = CryptographyClient.__module__ + ".get_local_cryptography_provider"

    # Only the first operation runs while the provider factory is patched
    with mock.patch(provider_factory, lambda *_: supports_everything):
        client.decrypt(EncryptionAlgorithm.rsa_oaep, b"...")
    assert service.decrypt.call_count == 0
    assert supports_everything.decrypt.call_count == 1

    remaining_operations = (
        ("encrypt", (EncryptionAlgorithm.rsa_oaep, b"...")),
        ("sign", (SignatureAlgorithm.rs256, b"...")),
        ("verify", (SignatureAlgorithm.rs256, b"...", b"...")),
        ("unwrap_key", (KeyWrapAlgorithm.rsa_oaep, b"...")),
        ("wrap_key", (KeyWrapAlgorithm.rsa_oaep, b"...")),
    )
    for operation, arguments in remaining_operations:
        getattr(client, operation)(*arguments)
        assert getattr(service, operation).call_count == 0
        assert getattr(supports_everything, operation).call_count == 1
def test_custom_hook_policy():
    """A ``custom_hook_policy`` keyword argument should end up in the client's configuration."""

    class CustomHookPolicy(SansIOHTTPPolicy):
        pass

    policy = CustomHookPolicy()
    client = CryptographyClient(
        "https://localhost/fake/key/version", object(), custom_hook_policy=policy
    )

    configured_policy = client._client._config.custom_hook_policy
    assert isinstance(configured_policy, CustomHookPolicy)
def test_decrypt_argument_validation():
    """Decrypting should fail with ValueError when arguments don't fit the algorithm.

    Each case gives an algorithm, the keyword arguments passed to decrypt, and
    the fragments that must all appear in the resulting ValueError message.
    """
    validity_window = mock.Mock(
        not_before=datetime(2000, 1, 1, tzinfo=_UTC),
        expires_on=datetime(3000, 1, 1, tzinfo=_UTC),
    )
    key = mock.Mock(spec=KeyVaultKey, id="https://localhost/fake/key/version", properties=validity_window)
    client = CryptographyClient(key, mock.Mock())
    client._client = mock.Mock()

    cases = (
        (EncryptionAlgorithm.rsa_oaep, {"iv": b"..."}, ("iv",)),
        (
            EncryptionAlgorithm.rsa_oaep,
            {"additional_authenticated_data": b"..."},
            ("additional_authenticated_data",),
        ),
        (EncryptionAlgorithm.rsa_oaep, {"authentication_tag": b"..."}, ("authentication_tag",)),
        (EncryptionAlgorithm.a128_gcm, {"iv": b"..."}, ("authentication_tag", "required")),
        (EncryptionAlgorithm.a192_cbcpad, {}, ("iv", "required")),
    )

    for algorithm, extra_arguments, expected_fragments in cases:
        with pytest.raises(ValueError) as ex:
            client.decrypt(algorithm, b"...", **extra_arguments)
        message = str(ex.value)
        assert all(fragment in message for fragment in expected_fragments)
def _get_cripto_client():
    """Return a CryptographyClient for the key named "generated-key".

    The vault URL is read from the KEYVAULT_URI environment variable, and a
    DefaultAzureCredential is used for both the key lookup and the returned
    client.
    """
    vault_uri = os.environ['KEYVAULT_URI']
    credential = DefaultAzureCredential()
    generated_key = KeyClient(vault_url=vault_uri, credential=credential).get_key("generated-key")
    return CryptographyClient(generated_key, credential=credential)