示例#1
0
    def test_key_wrap_and_unwrap(self, vault, **kwargs):
        """Round-trip ``self.plain_text`` through wrap_key / unwrap_key.

        The round trip uses 'RSA-OAEP' and is performed twice: first with an
        empty version string, then with the version parsed from the imported
        key's ``kid``.
        """
        self.assertIsNotNone(vault)
        requested_id = KeyVaultId.create_key_id(vault.properties.vault_uri,
                                                self.get_resource_name('keywrap'))
        payload = self.plain_text

        # import the test key, then work with the id reported for the import
        bundle = self._import_test_key(vault, requested_id)
        kid = KeyVaultId.parse_key_id(bundle.key.kid)

        # first pass: without version (''); second pass: with version
        for version in ('', kid.version):
            wrapped = self.client.wrap_key(kid.vault, kid.name, version, 'RSA-OAEP', payload).result
            unwrapped = self.client.unwrap_key(kid.vault, kid.name, version, 'RSA-OAEP', wrapped)
            self.assertEqual(payload, unwrapped.result)
    def test_key_sign_and_verify(self, vault, **kwargs):
        """Sign a SHA-256 digest of ``self.plain_text`` and verify the signature.

        The sign/verify pair runs twice with 'RS256': once with an empty
        version string and once with the version parsed from the imported
        key's ``kid``.
        """
        self.assertIsNotNone(vault)
        vault_uri = vault.properties.vault_uri
        key_name = self.get_resource_name('keysign')

        key_id = KeyVaultId.create_key_id(vault_uri, key_name)
        # the digest, not the raw plain text, is what gets signed and verified
        digest = hashlib.sha256(self.plain_text).digest()

        # import key
        imported_key = self._import_test_key(vault, key_id)
        key_id = KeyVaultId.parse_key_id(imported_key.key.kid)

        # sign without version
        signature = self.client.sign(key_id.vault, key_id.name, '', 'RS256',
                                     digest).result

        # verify without version
        result = self.client.verify(key_id.vault, key_id.name, '', 'RS256',
                                    digest, signature)
        self.assertTrue(result.value)

        # sign with version
        # The original passed '' here, so the versioned sign path was never
        # exercised even though the comment said it was.
        # NOTE(review): if this test is replayed from recorded sessions, the
        # recording needs refreshing because the request now carries a version.
        signature = self.client.sign(key_id.vault, key_id.name, key_id.version,
                                     'RS256', digest).result

        # verify with version
        result = self.client.verify(key_id.vault, key_id.name, key_id.version,
                                    'RS256', digest, signature)
        self.assertTrue(result.value)
    def test_key_wrap_and_unwrap(self, vault, **kwargs):
        """Check that unwrap_key returns what wrap_key was given.

        Uses 'RSA-OAEP' on an imported test key, once with an empty version
        string and once with the key's parsed version.
        """
        self.assertIsNotNone(vault)
        vault_uri = vault.properties.vault_uri
        key_name = self.get_resource_name('keywrap')
        requested_id = KeyVaultId.create_key_id(vault_uri, key_name)
        expected_text = self.plain_text

        # import key and re-parse the id from the returned bundle
        imported = self._import_test_key(vault, requested_id)
        key_id = KeyVaultId.parse_key_id(imported.key.kid)

        def _round_trip(version):
            # wrap, then unwrap, using the same version value for both calls
            wrapped = self.client.wrap_key(key_id.vault, key_id.name, version,
                                           'RSA-OAEP', expected_text)
            unwrapped = self.client.unwrap_key(key_id.vault, key_id.name,
                                               version, 'RSA-OAEP',
                                               wrapped.result)
            self.assertEqual(expected_text, unwrapped.result)

        _round_trip('')              # without version
        _round_trip(key_id.version)  # with version
    def test_key_recover_and_purge(self, vault, **kwargs):
        """Delete a batch of keys, then recover one half and purge the other.

        Keys named from 'keyrec{}' are recovered and keys named from
        'keyprg{}' are purged. The test checks that every deleted key shows up
        in get_deleted_keys, that none remain there after recover/purge, and
        that each recovered key can be fetched again with get_key.
        """
        self.assertIsNotNone(vault)
        vault_uri = vault.properties.vault_uri

        keys = {}

        # create keys to recover
        for i in range(0, self.list_test_size):
            key_name = self.get_resource_name('keyrec{}'.format(str(i)))
            keys[key_name] = self.client.create_key(vault_uri, key_name, 'RSA')

        # create keys to purge
        for i in range(0, self.list_test_size):
            key_name = self.get_resource_name('keyprg{}'.format(str(i)))
            keys[key_name] = self.client.create_key(vault_uri, key_name, 'RSA')

        # delete all keys
        for key_name in keys.keys():
            self.client.delete_key(vault_uri, key_name)

        # live runs wait before listing; playback skips the wait
        if not self.is_playback():
            time.sleep(20)

        # validate all our deleted keys are returned by get_deleted_keys
        deleted = [
            KeyVaultId.parse_key_id(s.kid).name
            for s in self.client.get_deleted_keys(vault_uri)
        ]
        self.assertTrue(all(s in deleted for s in keys.keys()))

        # recover select keys
        for key_name in [s for s in keys.keys() if s.startswith('keyrec')]:
            self.client.recover_deleted_key(vault_uri, key_name)

        # purge select keys
        for key_name in [s for s in keys.keys() if s.startswith('keyprg')]:
            self.client.purge_deleted_key(vault_uri, key_name)

        if not self.is_playback():
            time.sleep(20)

        # validate none of our deleted keys are returned by get_deleted_keys
        deleted = [
            KeyVaultId.parse_key_id(s.kid).name
            for s in self.client.get_deleted_keys(vault_uri)
        ]
        self.assertTrue(not any(s in deleted for s in keys.keys()))

        # validate the recovered keys
        # The original filter ('key-' prefix and '-recover' suffix) could not
        # match names built from 'keyrec{}', so nothing was ever fetched. Use
        # the same prefix as the recover step above.
        # NOTE(review): this adds get_key requests; if the test is replayed
        # from recorded sessions, the recording needs refreshing.
        expected = {
            k: v
            for k, v in keys.items()
            if k.startswith('keyrec')
        }
        actual = {
            k: self.client.get_key(vault_uri, k)
            for k in expected.keys()
        }
        self.assertEqual(len(set(expected.keys()) & set(actual.keys())),
                         len(expected))
 def _validate_certificate_issuer_list(self, issuers, expected):
     for issuer in issuers:
         KeyVaultId.parse_certificate_issuer_id(issuer.id)
         provider = expected[issuer.id]
         if provider:
             self.assertEqual(provider, issuer.provider)
             del expected[issuer.id]
示例#6
0
    def test_key_sign_and_verify(self, vault, **kwargs):
        """Sign and verify a SHA-256 digest of ``self.plain_text`` with 'RS256'.

        Runs one sign/verify pair with an empty version string and one with
        the version parsed from the imported key's ``kid``.
        """
        self.assertIsNotNone(vault)
        vault_uri = vault.properties.vault_uri
        key_name = self.get_resource_name('keysign')

        key_id = KeyVaultId.create_key_id(vault_uri, key_name)
        plain_text = self.plain_text
        md = hashlib.sha256()
        md.update(plain_text)
        digest = md.digest()

        # import key
        imported_key = self._import_test_key(vault, key_id)
        key_id = KeyVaultId.parse_key_id(imported_key.key.kid)

        # '' means "without version"; the second pass uses the parsed version.
        # The original signed with '' on both passes, so versioned signing was
        # never exercised.
        # NOTE(review): if this test is replayed from recorded sessions, the
        # recording needs refreshing because the second sign request changed.
        for version in ('', key_id.version):
            signature = self.client.sign(key_id.vault, key_id.name, version, 'RS256', digest).result
            result = self.client.verify(key_id.vault, key_id.name, version, 'RS256', digest, signature)
            self.assertTrue(result.value)
    def test_recover_and_purge(self, vault, **kwargs):
        """Delete a batch of certificates, then recover one half and purge the other.

        Certificates named from 'certrec{}' are recovered and those named from
        'certprg{}' are purged. Afterwards none of them may be listed by
        get_deleted_certificates, and every recovered certificate must be
        retrievable with get_certificate.
        """
        self.assertIsNotNone(vault)
        vault_uri = vault.properties.vault_uri

        certs = {}

        # create certificates to recover
        for i in range(0, self.list_test_size):
            cert_name = self.get_resource_name('certrec{}'.format(str(i)))
            certs[cert_name] = self._import_common_certificate(vault_uri, cert_name)

        # create certificates to purge
        for i in range(0, self.list_test_size):
            cert_name = self.get_resource_name('certprg{}'.format(str(i)))
            certs[cert_name] = self._import_common_certificate(vault_uri, cert_name)

        # delete all certificates
        for cert_name in certs.keys():
            self.client.delete_certificate(vault_uri, cert_name)

        # live runs wait before listing; playback skips the wait
        if not self.is_playback():
            time.sleep(30)

        # list the deleted certificates (the request is kept so the call
        # sequence is unchanged)
        # NOTE(review): the assertion below was already disabled in the
        # original; re-enable it once confirmed stable.
        deleted = [KeyVaultId.parse_certificate_id(s.id).name for s in self.client.get_deleted_certificates(vault_uri)]
        # self.assertTrue(all(s in deleted for s in certs.keys()))

        # recover select certificates
        for certificate_name in [s for s in certs.keys() if s.startswith('certrec')]:
            self.client.recover_deleted_certificate(vault_uri, certificate_name)

        # purge select certificates
        for certificate_name in [s for s in certs.keys() if s.startswith('certprg')]:
            self.client.purge_deleted_certificate(vault_uri, certificate_name)

        if not self.is_playback():
            time.sleep(30)

        # validate none of our deleted certificates are returned by get_deleted_certificates
        # These are certificate ids, so parse them with parse_certificate_id
        # (the original used parse_secret_id here, unlike the listing above).
        deleted = [KeyVaultId.parse_certificate_id(s.id).name for s in self.client.get_deleted_certificates(vault_uri)]
        self.assertTrue(not any(s in deleted for s in certs.keys()))

        # validate the recovered certificates
        expected = {k: v for k, v in certs.items() if k.startswith('certrec')}
        actual = {k: self.client.get_certificate(vault_uri, k, KeyVaultId.version_none) for k in expected.keys()}
        self.assertEqual(len(set(expected.keys()) & set(actual.keys())), len(expected))
示例#8
0
    def test_create_certificate_id(self):
        """create_certificate_id yields the expected id, with and without a version.

        The name and version arguments carry leading spaces on purpose; the
        expected values do not.
        """
        vault_url = 'https://myvault.vault.azure.net'

        # no version supplied
        unversioned = self._get_expected('certificates', 'myvault', 'mycert')
        created = KeyVaultId.create_certificate_id(vault_url, ' mycert', None)
        self.assertEqual(created.__dict__, unversioned)

        # explicit version
        versioned = self._get_expected('certificates', 'myvault', 'mycert', 'abc123')
        created = KeyVaultId.create_certificate_id(vault_url, 'mycert', ' abc123')
        self.assertEqual(created.__dict__, versioned)
示例#9
0
    def test_parse_key_id(self):
        """parse_key_id handles key URLs with and without a trailing version."""
        cases = (
            ('https://myvault.vault.azure.net/keys/mykey/abc123',
             ('keys', 'myvault', 'mykey', 'abc123')),
            ('https://myvault.vault.azure.net/keys/mykey',
             ('keys', 'myvault', 'mykey')),
        )
        for key_url, expected_args in cases:
            wanted = self._get_expected(*expected_args)
            parsed = KeyVaultId.parse_key_id(key_url)
            self.assertEqual(parsed.__dict__, wanted)
示例#10
0
    def test_parse_secret_id(self):
        """parse_secret_id handles secret URLs with and without a trailing version."""
        cases = (
            ('https://myvault.vault.azure.net/secrets/mysecret/abc123',
             ('secrets', 'myvault', 'mysecret', 'abc123')),
            ('https://myvault.vault.azure.net/secrets/mysecret',
             ('secrets', 'myvault', 'mysecret')),
        )
        for secret_url, expected_args in cases:
            wanted = self._get_expected(*expected_args)
            parsed = KeyVaultId.parse_secret_id(secret_url)
            self.assertEqual(parsed.__dict__, wanted)
    def test_crud_operations(self, vault, **kwargs):
        """Create, read, update and delete a self-issued certificate.

        Creation is asynchronous here: the test polls the certificate
        operation until its status is 'completed' and raises on any status
        other than 'inprogress'. After deletion, a further get must fail with
        an error whose ``message`` contains 'not found'.
        """
        self.assertIsNotNone(vault)
        vault_uri = vault.properties.vault_uri
        cert_name = self.get_resource_name('cert')

        # build the policy piece by piece
        key_props = KeyProperties(exportable=True,
                                  key_type='RSA',
                                  key_size=2048,
                                  reuse_key=False)
        secret_props = SecretProperties(content_type='application/x-pkcs12')
        issuer = IssuerParameters(name='Self')
        alt_names = SubjectAlternativeNames(
            dns_names=['onedrive.microsoft.com', 'xbox.microsoft.com'])
        x509_props = X509CertificateProperties(subject='CN=*.microsoft.com',
                                               subject_alternative_names=alt_names,
                                               validity_in_months=24)
        cert_policy = CertificatePolicy(key_properties=key_props,
                                        secret_properties=secret_props,
                                        issuer_parameters=issuer,
                                        x509_certificate_properties=x509_props)

        # create certificate; no delay between polls when replaying
        poll_delay = 0 if self.is_playback() else 5
        self.client.create_certificate(vault_uri, cert_name, cert_policy)
        while True:
            operation = self.client.get_certificate_operation(vault_uri, cert_name)
            self._validate_certificate_operation(operation, vault_uri, cert_name, cert_policy)
            status = operation.status.lower()
            if status == 'completed':
                cert_id = KeyVaultId.parse_certificate_operation_id(operation.target)
                break
            if status != 'inprogress':
                raise Exception('Unknown status code for pending certificate: {}'.format(operation))
            time.sleep(poll_delay)

        # get certificate
        fetched = self.client.get_certificate(cert_id.vault, cert_id.name, '')
        self._validate_certificate_bundle(fetched, vault_uri, cert_name, cert_policy)

        # get certificate as secret (the call must succeed; its result is not inspected)
        secret_id = KeyVaultId.parse_secret_id(fetched.sid)
        self.client.get_secret(secret_id.vault, secret_id.name, secret_id.version)

        # update certificate
        cert_policy.tags = {'tag1': 'value1'}
        updated = self.client.update_certificate(cert_id.vault, cert_id.name, cert_id.version, cert_policy)
        self._validate_certificate_bundle(updated, vault_uri, cert_name, cert_policy)

        # delete certificate
        removed = self.client.delete_certificate(vault_uri, cert_name)
        self._validate_certificate_bundle(removed, vault_uri, cert_name, cert_policy)

        # get certificate returns not found
        try:
            self.client.get_certificate(cert_id.vault, cert_id.name, '')
            self.fail('Get should fail')
        except Exception as ex:
            is_not_found = hasattr(ex, 'message') and 'not found' in ex.message.lower()
            if not is_not_found:
                raise ex
    def test_parse_key_id(self):
        """Parse a versioned and an unversioned key URL and compare the fields."""
        # versioned key URL
        with_version = self._get_expected('keys', 'myvault', 'mykey', 'abc123')
        parsed = KeyVaultId.parse_key_id('https://myvault.vault.azure.net/keys/mykey/abc123')
        self.assertEqual(parsed.__dict__, with_version)

        # unversioned key URL
        without_version = self._get_expected('keys', 'myvault', 'mykey')
        parsed = KeyVaultId.parse_key_id('https://myvault.vault.azure.net/keys/mykey')
        self.assertEqual(parsed.__dict__, without_version)
    def test_create_object_id(self):
        """Exercise create_object_id with valid and invalid arguments.

        Valid calls pass values padded with spaces and expect the unpadded
        fields. Invalid calls expect TypeError for a list name, and ValueError
        for a blank name or a vault address without a scheme.
        """
        vault_url = 'https://myvault.vault.azure.net'

        # success scenarios: a None or blank version both give the unversioned id
        unversioned = self._get_expected('keys', 'myvault', 'mykey')
        for blank_version in (None, ' '):
            created = KeyVaultId.create_object_id('keys', vault_url, ' mykey', blank_version)
            self.assertEqual(created.__dict__, unversioned)

        versioned = self._get_expected('keys', 'myvault', 'mykey', 'abc123')
        created = KeyVaultId.create_object_id(' keys ', vault_url, ' mykey ', ' abc123 ')
        self.assertEqual(created.__dict__, versioned)

        # failure scenarios
        failing_calls = (
            (TypeError, ('keys', vault_url, ['stuff'], '')),
            (ValueError, ('keys', vault_url, ' ', '')),
            (ValueError, ('keys', 'myvault.vault.azure.net', 'mykey', '')),
        )
        for error_type, call_args in failing_calls:
            with self.assertRaises(error_type):
                KeyVaultId.create_object_id(*call_args)
    def test_create_secret_id(self):
        """create_secret_id yields the expected id, with and without a version.

        The name and version arguments are padded with spaces; the expected
        values are not.
        """
        vault_url = 'https://myvault.vault.azure.net'

        # no version supplied
        unversioned = self._get_expected('secrets', 'myvault', 'mysecret')
        created = KeyVaultId.create_secret_id(vault_url, ' mysecret', None)
        self.assertEqual(created.__dict__, unversioned)

        # explicit version
        versioned = self._get_expected('secrets', 'myvault', 'mysecret', 'abc123')
        created = KeyVaultId.create_secret_id(vault_url, ' mysecret ', ' abc123 ')
        self.assertEqual(created.__dict__, versioned)
    def test_parse_certificate_id(self):
        """parse_certificate_id handles certificate URLs with and without a version."""
        cases = (
            ('https://myvault.vault.azure.net/certificates/mycert/abc123',
             ('certificates', 'myvault', 'mycert', 'abc123')),
            ('https://myvault.vault.azure.net/certificates/mycert',
             ('certificates', 'myvault', 'mycert')),
        )
        for cert_url, expected_args in cases:
            wanted = self._get_expected(*expected_args)
            parsed = KeyVaultId.parse_certificate_id(cert_url)
            self.assertEqual(parsed.__dict__, wanted)
示例#16
0
    def test_recover_purge(self, vault, **kwargs):
        """Delete a batch of secrets, then recover one half and purge the other.

        Secrets named from 'secrec{}' are recovered and those named from
        'secprg{}' are purged. The deleted-secret listing is checked before
        and after, and every recovered secret is fetched again.
        """
        self.assertIsNotNone(vault)
        vault_uri = vault.properties.vault_uri

        secrets = {}

        def _create_batch(name_template):
            # one secret per index, stored under its generated name
            for index in range(0, self.list_test_size):
                name = self.get_resource_name(name_template.format(str(index)))
                value = self.get_resource_name('secval{}'.format((str(index))))
                secrets[name] = self.client.set_secret(vault_uri, name, value)

        def _deleted_names():
            return [KeyVaultId.parse_secret_id(item.id).name
                    for item in self.client.get_deleted_secrets(vault_uri)]

        _create_batch('secrec{}')  # secrets to recover
        _create_batch('secprg{}')  # secrets to purge

        # delete all secrets
        for name in secrets:
            self.client.delete_secret(vault_uri, name)

        if not self.is_playback():
            time.sleep(20)

        # every deleted secret must be listed by get_deleted_secrets
        deleted = _deleted_names()
        self.assertTrue(all(name in deleted for name in secrets))

        # recover select secrets
        for name in [s for s in secrets if s.startswith('secrec')]:
            self.client.recover_deleted_secret(vault_uri, name)

        # purge select secrets
        for name in [s for s in secrets if s.startswith('secprg')]:
            self.client.purge_deleted_secret(vault_uri, name)

        if not self.is_playback():
            time.sleep(20)

        # none of our secrets may still be listed as deleted
        deleted = _deleted_names()
        self.assertTrue(all(name not in deleted for name in secrets))

        # every recovered secret must be retrievable again
        recovered = {k: v for k, v in secrets.items() if k.startswith('secrec')}
        fetched = {k: self.client.get_secret(vault_uri, k, KeyVaultId.version_none) for k in recovered}
        self.assertEqual(len(set(recovered) & set(fetched)), len(recovered))
示例#17
0
    def test_recover_purge(self, vault, **kwargs):
        """Soft-delete secrets, recover the 'secrec' ones and purge the 'secprg' ones.

        Checks get_deleted_secrets after the delete and again after
        recover/purge, then fetches each recovered secret.
        """
        self.assertIsNotNone(vault)
        vault_uri = vault.properties.vault_uri
        client = self.client

        secrets = {}

        # first the secrets to recover, then the secrets to purge
        for template in ('secrec{}', 'secprg{}'):
            for n in range(0, self.list_test_size):
                created_name = self.get_resource_name(template.format(str(n)))
                created_value = self.get_resource_name('secval{}'.format((str(n))))
                secrets[created_name] = client.set_secret(vault_uri, created_name, created_value)

        # delete all secrets
        for created_name in secrets.keys():
            client.delete_secret(vault_uri, created_name)

        if not self.is_playback():
            time.sleep(20)

        # validate all our deleted secrets are returned by get_deleted_secrets
        listed = []
        for entry in client.get_deleted_secrets(vault_uri):
            listed.append(KeyVaultId.parse_secret_id(entry.id).name)
        self.assertTrue(all(name in listed for name in secrets.keys()))

        # recover select secrets
        to_recover = [name for name in secrets.keys() if name.startswith('secrec')]
        for name in to_recover:
            client.recover_deleted_secret(vault_uri, name)

        # purge select secrets
        to_purge = [name for name in secrets.keys() if name.startswith('secprg')]
        for name in to_purge:
            client.purge_deleted_secret(vault_uri, name)

        if not self.is_playback():
            time.sleep(20)

        # validate none of our deleted secrets are returned by get_deleted_secrets
        listed = []
        for entry in client.get_deleted_secrets(vault_uri):
            listed.append(KeyVaultId.parse_secret_id(entry.id).name)
        self.assertTrue(not any(name in listed for name in secrets.keys()))

        # validate the recovered secrets
        expected = {k: v for k, v in secrets.items() if k.startswith('secrec')}
        actual = {}
        for k in expected.keys():
            actual[k] = client.get_secret(vault_uri, k, KeyVaultId.version_none)
        overlap = set(expected.keys()) & set(actual.keys())
        self.assertEqual(len(overlap), len(expected))
示例#18
0
    def test_key_list_versions(self, vault, **kwargs):
        """Create several versions of one key and validate get_key_versions.

        ``self.list_test_size`` versions are created under a single name. A
        create that fails with a message containing 'Throttled' is retried
        after a growing sleep; any other error is re-raised.
        """
        self.assertIsNotNone(vault)
        vault_uri = vault.properties.vault_uri
        key_name = self.get_resource_name('key')

        version_count = self.list_test_size
        expected = {}

        # create many key versions
        for _ in range(0, version_count):
            created = None
            throttle_hits = 0
            while not created:
                try:
                    created = self.client.create_key(vault_uri, key_name, 'RSA')
                    version_id = KeyVaultId.parse_key_id(created.key.kid).id.strip('/')
                    expected[version_id] = created.attributes
                except Exception as ex:
                    throttled = hasattr(ex, 'message') and 'Throttled' in ex.message
                    if not throttled:
                        raise ex
                    # wait a little longer after every throttled attempt
                    throttle_hits += 1
                    time.sleep(2.5 * throttle_hits)

        # list key versions
        versions = list(self.client.get_key_versions(vault_uri, key_name))
        self._validate_key_list(versions, expected)
示例#19
0
 def get_secret(self, name, version=''):
     ''' Fetches a secret from ``self.keyvault_uri``.

     Returns a dict with ``secret_id`` and ``secret_value``, or None when the
     client returns a falsy bundle.
     '''
     bundle = self.client.get_secret(self.keyvault_uri, name, version)
     if not bundle:
         return None
     parsed_id = KeyVaultId.parse_secret_id(bundle.id)
     return {'secret_id': parsed_id.id, 'secret_value': bundle.value}
    def test_list_versions(self, vault, **kwargs):
        """Import one certificate repeatedly and validate get_certificate_versions.

        The import is repeated ``self.list_test_size`` times under a single
        name. An import that fails with a message containing 'Throttled' is
        retried after a growing sleep; any other error is re-raised.
        """
        self.assertIsNotNone(vault)
        vault_uri = vault.properties.vault_uri
        cert_name = self.get_resource_name('certver')

        version_count = self.list_test_size
        expected = {}

        # import same certificates as different versions
        for _ in range(0, version_count):
            imported = None
            throttle_hits = 0
            while not imported:
                try:
                    imported = self._import_common_certificate(vault_uri, cert_name)[0]
                    version_id = KeyVaultId.parse_certificate_id(imported.id).id.strip('/')
                    expected[version_id] = imported.attributes
                except Exception as ex:
                    throttled = hasattr(ex, 'message') and 'Throttled' in ex.message
                    if not throttled:
                        raise ex
                    # wait a little longer after every throttled attempt
                    throttle_hits += 1
                    time.sleep(2.5 * throttle_hits)

        # list certificate versions
        versions = list(self.client.get_certificate_versions(vault_uri, cert_name))
        self._validate_certificate_list(versions, expected)
 def _validate_certificate_operation(self, pending_cert, vault, cert_name, cert_policy):
     ''' Asserts that a pending certificate operation matches what was requested.

     Checks that the operation and its CSR are present, that the issuer name
     equals the policy's, and that the parsed operation id points at
     ``vault`` (ignoring surrounding slashes) and ``cert_name``.
     '''
     for must_exist in (pending_cert, pending_cert.csr):
         self.assertIsNotNone(must_exist)
     self.assertEqual(cert_policy.issuer_parameters.name, pending_cert.issuer_parameters.name)

     operation_id = KeyVaultId.parse_certificate_operation_id(pending_cert.id)
     # compare vault addresses without leading/trailing slashes
     self.assertEqual(operation_id.vault.strip('/'), vault.strip('/'))
     self.assertEqual(operation_id.name, cert_name)
示例#22
0
    def test_secret_list(self, vault, **kwargs):
        """Create several secrets and validate the get_secrets listing.

        ``self.list_test_size`` secrets named 'sec0', 'sec1', ... are set. A
        set that fails with a message containing 'Throttled' is retried after
        a growing sleep; any other error is re-raised.
        """
        self.assertIsNotNone(vault)
        vault_uri = vault.properties.vault_uri

        secret_count = self.list_test_size
        expected = {}

        # create many secrets
        for index in range(0, secret_count):
            name = 'sec{}'.format(index)
            value = self.get_resource_name('secVal{}'.format(index))
            stored = None
            throttle_hits = 0
            while not stored:
                try:
                    stored = self.client.set_secret(vault_uri, name, value)
                    base_id = KeyVaultId.parse_secret_id(stored.id).base_id.strip('/')
                    expected[base_id] = stored.attributes
                except Exception as ex:
                    throttled = hasattr(ex, 'message') and 'Throttled' in ex.message
                    if not throttled:
                        raise ex
                    # wait a little longer after every throttled attempt
                    throttle_hits += 1
                    time.sleep(2.5 * throttle_hits)

        # list secrets
        listed = list(self.client.get_secrets(vault_uri, self.list_test_size))
        self._validate_secret_list(listed, expected)
示例#23
0
    def test_key_list_versions(self, vault, **kwargs):
        """Validate get_key_versions after creating several versions of one key.

        Throttled creates (message containing 'Throttled') are retried with a
        sleep that grows by 2.5 seconds per failure; other errors propagate.
        """
        self.assertIsNotNone(vault)
        vault_uri = vault.properties.vault_uri
        key_name = self.get_resource_name('key')

        total_versions = self.list_test_size
        expected = {}

        # create many key versions
        for _ in range(0, total_versions):
            bundle = None
            failures = 0
            while not bundle:
                try:
                    bundle = self.client.create_key(vault_uri, key_name, 'RSA')
                    parsed = KeyVaultId.parse_key_id(bundle.key.kid)
                    expected[parsed.id.strip('/')] = bundle.attributes
                except Exception as ex:
                    if not hasattr(ex, 'message') or 'Throttled' not in ex.message:
                        raise ex
                    failures += 1
                    time.sleep(2.5 * failures)

        # list key versions
        listed_versions = list(self.client.get_key_versions(vault_uri, key_name))
        self._validate_key_list(listed_versions, expected)
示例#24
0
    def test_secret_list(self, vault, **kwargs):
        """Validate get_secrets after setting ``self.list_test_size`` secrets.

        Throttled sets (message containing 'Throttled') are retried with a
        sleep that grows by 2.5 seconds per failure; other errors propagate.
        """
        self.assertIsNotNone(vault)
        vault_uri = vault.properties.vault_uri

        total_secrets = self.list_test_size
        expected = {}

        # create many secrets
        for n in range(0, total_secrets):
            secret_name = 'sec{}'.format(n)
            secret_value = self.get_resource_name('secVal{}'.format(n))
            bundle = None
            failures = 0
            while not bundle:
                try:
                    bundle = self.client.set_secret(vault_uri, secret_name, secret_value)
                    parsed = KeyVaultId.parse_secret_id(bundle.id)
                    expected[parsed.base_id.strip('/')] = bundle.attributes
                except Exception as ex:
                    if not hasattr(ex, 'message') or 'Throttled' not in ex.message:
                        raise ex
                    failures += 1
                    time.sleep(2.5 * failures)

        # list secrets
        self._validate_secret_list(
            list(self.client.get_secrets(vault_uri, self.list_test_size)),
            expected)
示例#25
0
 def create_key(self, name, tags, kty='RSA'):
     ''' Creates a key in ``self.keyvault_uri`` and returns its parsed id '''
     request = dict(vault_base_url=self.keyvault_uri,
                    key_name=name,
                    kty=kty,
                    tags=tags)
     created = self.client.create_key(**request)
     return KeyVaultId.parse_key_id(created.key.kid).id
示例#26
0
    def test_key_recover_and_purge(self, vault, **kwargs):
        """Delete a batch of keys, then recover one half and purge the other.

        Keys named from 'keyrec{}' are recovered and keys named from
        'keyprg{}' are purged. The test checks that every deleted key shows up
        in get_deleted_keys, that none remain there after recover/purge, and
        that each recovered key can be fetched again with get_key.
        """
        self.assertIsNotNone(vault)
        vault_uri = vault.properties.vault_uri

        keys = {}

        # create keys to recover
        for i in range(0, self.list_test_size):
            key_name = self.get_resource_name('keyrec{}'.format(str(i)))
            keys[key_name] = self.client.create_key(vault_uri, key_name, 'RSA')

        # create keys to purge
        for i in range(0, self.list_test_size):
            key_name = self.get_resource_name('keyprg{}'.format(str(i)))
            keys[key_name] = self.client.create_key(vault_uri, key_name, 'RSA')

        # delete all keys
        for key_name in keys.keys():
            self.client.delete_key(vault_uri, key_name)

        # live runs wait before listing; playback skips the wait
        if not self.is_playback():
            time.sleep(20)

        # validate all our deleted keys are returned by get_deleted_keys
        deleted = [KeyVaultId.parse_key_id(s.kid).name for s in self.client.get_deleted_keys(vault_uri)]
        self.assertTrue(all(s in deleted for s in keys.keys()))

        # recover select keys
        for key_name in [s for s in keys.keys() if s.startswith('keyrec')]:
            self.client.recover_deleted_key(vault_uri, key_name)

        # purge select keys
        for key_name in [s for s in keys.keys() if s.startswith('keyprg')]:
            self.client.purge_deleted_key(vault_uri, key_name)

        if not self.is_playback():
            time.sleep(20)

        # validate none of our deleted keys are returned by get_deleted_keys
        deleted = [KeyVaultId.parse_key_id(s.kid).name for s in self.client.get_deleted_keys(vault_uri)]
        self.assertTrue(not any(s in deleted for s in keys.keys()))

        # validate the recovered keys
        # The original filter ('key-' prefix and '-recover' suffix) could not
        # match names built from 'keyrec{}', so nothing was ever fetched. Use
        # the same prefix as the recover step above.
        # NOTE(review): this adds get_key requests; if the test is replayed
        # from recorded sessions, the recording needs refreshing.
        expected = {k: v for k, v in keys.items() if k.startswith('keyrec')}
        actual = {k: self.client.get_key(vault_uri, k) for k in expected.keys()}
        self.assertEqual(len(set(expected.keys()) & set(actual.keys())), len(expected))
示例#27
0
def azure_decrypt(config, logger, session, encrypted_field):
    """Return ``config[encrypted_field]``, resolving a Key Vault reference if present.

    When the configured value is a mapping, its ``'secret'`` entry is parsed as
    a Key Vault secret identifier and the value of that secret is fetched using
    a client obtained from ``session``. Any other value is returned unchanged.

    ``logger`` is accepted for interface compatibility but is not used here.

    Raises:
        KeyError: if ``encrypted_field`` is not in ``config``, or the mapping
            has no ``'secret'`` entry.
    """
    data = config[encrypted_field]

    # isinstance (rather than an exact type check) so that dict subclasses such
    # as OrderedDict are also recognised as Key Vault references.
    if not isinstance(data, dict):
        return data

    kv_session = session.get_session_for_resource(resource=RESOURCE_VAULT)
    secret_id = KeyVaultId.parse_secret_id(data['secret'])
    kv_client = kv_session.client('azure.keyvault.KeyVaultClient')
    return kv_client.get_secret(secret_id.vault, secret_id.name, secret_id.version).value
 def create_update_secret(self, name, secret, tags, content_type):
     """Set the secret ``name`` in the configured vault and return its parsed ``id``.

     ``tags`` and ``content_type`` are forwarded to ``set_secret`` as keyword
     arguments.
     """
     bundle = self.client.set_secret(
         self.keyvault_uri, name, secret, tags=tags, content_type=content_type)
     parsed = KeyVaultId.parse_secret_id(bundle.id)
     return parsed.id
    def test_secret_crud_operations(self, vault, **kwargs):
        """Walk a secret through create, get, update, delete and a failing get.

        The final step requires that fetching the deleted secret raises an
        error whose ``message`` contains "not found"; a successful get fails
        the test.
        """
        self.assertIsNotNone(vault)
        vault_uri = vault.properties.vault_uri
        secret_name = 'crud-secret'
        secret_value = self.get_resource_name('crud_secret_value')

        # create secret
        created_bundle = self.client.set_secret(vault_uri, secret_name,
                                                secret_value)
        self._validate_secret_bundle(created_bundle, vault_uri, secret_name,
                                     secret_value)
        secret_id = KeyVaultId.parse_secret_id(created_bundle.id)

        # get secret without version
        self.assertEqual(
            created_bundle,
            self.client.get_secret(secret_id.vault, secret_id.name, ''))

        # get secret with version
        self.assertEqual(
            created_bundle,
            self.client.get_secret(secret_id.vault, secret_id.name,
                                   secret_id.version))

        def _update_secret(secret_uri):
            """Update the secret at ``secret_uri`` and check the returned bundle."""
            updating_bundle = copy.deepcopy(created_bundle)
            updating_bundle.content_type = 'text/plain'
            updating_bundle.attributes.expires = date_parse.parse(
                '2050-02-02T08:00:00.000Z')
            updating_bundle.tags = {'foo': 'updated tag'}
            sid = KeyVaultId.parse_secret_id(secret_uri)
            secret_bundle = self.client.update_secret(
                sid.vault, sid.name, sid.version, updating_bundle.content_type,
                updating_bundle.attributes, updating_bundle.tags)
            self.assertEqual(updating_bundle.tags, secret_bundle.tags)
            self.assertEqual(updating_bundle.id, secret_bundle.id)
            self.assertNotEqual(str(updating_bundle.attributes.updated),
                                str(secret_bundle.attributes.updated))
            return secret_bundle

        # update secret without version
        _update_secret(secret_id.base_id)

        # update secret with version
        _update_secret(secret_id.id)

        # delete secret
        self.client.delete_secret(secret_id.vault, secret_id.name)

        # get secret returns not found
        try:
            self.client.get_secret(secret_id.vault, secret_id.name, '')
        except Exception as ex:
            # only a "not found" error is the expected outcome; re-raise the rest
            if not hasattr(ex,
                           'message') or 'not found' not in ex.message.lower():
                raise
        else:
            # without this the test would pass even if the secret still existed
            self.fail('Get should fail')
def secretitem_to_dict(secretitem):
    """Convert a secret item into a plain dictionary.

    The result has the keys ``sid``, ``version`` (parsed out of the item id),
    ``tags`` and ``attributes``; ``attributes`` is itself a dictionary of the
    item's attribute fields.
    """
    result = {
        'sid': secretitem.id,
        'version': KeyVaultId.parse_secret_id(secretitem.id).version,
        'tags': secretitem.tags,
    }
    attrs = secretitem.attributes
    result['attributes'] = {
        'enabled': attrs.enabled,
        'not_before': attrs.not_before,
        'expires': attrs.expires,
        'created': attrs.created,
        'updated': attrs.updated,
        'recovery_level': attrs.recovery_level,
    }
    return result
 def _validate_certificate_bundle(self, cert, vault, cert_name, cert_policy):
     """Assert that ``cert`` is a populated certificate bundle matching the inputs.

     Checks that the certificate id parses to ``vault`` / ``cert_name``, that
     the bundle and its policy fields are present, that the key and secret
     properties equal those of ``cert_policy``, and that ``cert.sid`` and
     ``cert.kid`` parse as secret and key identifiers.
     """
     # Guard first: the bundle is dereferenced on the next line, so a missing
     # bundle should fail as an assertion rather than as an AttributeError.
     self.assertIsNotNone(cert)
     cert_id = KeyVaultId.parse_certificate_id(cert.id)
     self.assertEqual(cert_id.vault.strip('/'), vault.strip('/'))
     self.assertEqual(cert_id.name, cert_name)
     self.assertIsNotNone(cert.x509_thumbprint)
     self.assertIsNotNone(cert.cer)
     self.assertIsNotNone(cert.attributes)
     self.assertIsNotNone(cert.policy)
     self.assertIsNotNone(cert.policy.id)
     self.assertIsNotNone(cert.policy.issuer_parameters)
     self.assertIsNotNone(cert.policy.lifetime_actions)
     self.assertEqual(cert.policy.key_properties, cert_policy.key_properties)
     self.assertEqual(cert.policy.secret_properties, cert_policy.secret_properties)
     self.assertIsNotNone(cert.policy.x509_certificate_properties)
     if cert_policy.x509_certificate_properties:
         self.assertEqual(cert.policy.x509_certificate_properties.validity_in_months,
                          cert_policy.x509_certificate_properties.validity_in_months)
     # parsing is the check: these calls are made only to confirm the ids parse
     KeyVaultId.parse_secret_id(cert.sid)
     KeyVaultId.parse_key_id(cert.kid)
示例#32
0
    def test_key_import(self, vault, **kwargs):
        """Import the test key under one key id, first to software, then to hardware."""
        self.assertIsNotNone(vault)
        key_id = KeyVaultId.create_key_id(vault.properties.vault_uri,
                                          self.get_resource_name('keyimp'))

        # software import (False) followed by hardware import (True)
        for import_to_hardware in (False, True):
            self._import_test_key(vault, key_id, import_to_hardware)
示例#33
0
 def _update_key(key_uri):
     """Update the key at ``key_uri`` with a new expiry, key ops and tags.

     Works from a deep copy of ``created_bundle``, asserts that the returned
     bundle carries the requested tags and the same ``kid``, and returns it.
     """
     desired = copy.deepcopy(created_bundle)
     desired.attributes.expires = date_parse.parse('2050-02-02T08:00:00.000Z')
     desired.key.key_ops = ['encrypt', 'decrypt']
     desired.tags = {'foo': 'updated tag'}

     parsed = KeyVaultId.parse_key_id(key_uri)
     updated = self.client.update_key(parsed.vault, parsed.name, parsed.version,
                                      desired.key.key_ops, desired.attributes, desired.tags)

     self.assertEqual(desired.tags, updated.tags)
     self.assertEqual(desired.key.kid, updated.key.kid)
     return updated
示例#34
0
    def test_key_import(self, vault, **kwargs):
        """Import the test key under one key id, first to software, then to hardware."""
        self.assertIsNotNone(vault)
        key_id = KeyVaultId.create_key_id(vault.properties.vault_uri,
                                          self.get_resource_name('keyimp'))

        # software import (False) followed by hardware import (True)
        for import_to_hardware in (False, True):
            self._import_test_key(vault, key_id, import_to_hardware)
示例#35
0
    def test_key_crud_operations(self, vault, **kwargs):
        """Walk an RSA key through create, get, update, delete and a failing get."""
        self.assertIsNotNone(vault)
        vault_uri = vault.properties.vault_uri
        key_name = self.get_resource_name('key')

        # create key
        created_bundle = self.client.create_key(vault_uri, key_name, 'RSA')
        self._validate_rsa_key_bundle(created_bundle, vault_uri, key_name, 'RSA')
        key_id = KeyVaultId.parse_key_id(created_bundle.key.kid)

        # get key without version
        fetched = self.client.get_key(key_id.vault, key_id.name, '')
        self.assertEqual(created_bundle, fetched)

        # get key with version
        fetched = self.client.get_key(key_id.vault, key_id.name, key_id.version)
        self.assertEqual(created_bundle, fetched)

        def _update_key(key_uri):
            """Update the key at ``key_uri`` and check the returned bundle."""
            desired = copy.deepcopy(created_bundle)
            desired.attributes.expires = date_parse.parse('2050-02-02T08:00:00.000Z')
            desired.key.key_ops = ['encrypt', 'decrypt']
            desired.tags = {'foo': 'updated tag'}
            kid = KeyVaultId.parse_key_id(key_uri)
            updated = self.client.update_key(
                kid.vault, kid.name, kid.version,
                desired.key.key_ops, desired.attributes, desired.tags)
            self.assertEqual(desired.tags, updated.tags)
            self.assertEqual(desired.key.kid, updated.key.kid)
            return updated

        # update key without version
        created_bundle = _update_key(key_id.base_id)

        # update key with version
        created_bundle = _update_key(key_id.id)

        # delete key
        self.client.delete_key(key_id.vault, key_id.name)

        # get key returns not found
        try:
            self.client.get_key(key_id.vault, key_id.name, '')
            self.fail('Get should fail')
        except Exception as ex:
            # anything other than a "not found" error is propagated
            is_not_found = hasattr(ex, 'message') and 'not found' in ex.message.lower()
            if not is_not_found:
                raise ex
def keyitem_to_dict(keyitem):
    """Convert a key item into a plain dictionary.

    The result has the keys ``kid``, ``version`` (parsed out of the key id),
    ``tags``, ``managed`` and ``attributes``; ``attributes`` is itself a
    dictionary of the item's attribute fields.

    The misspelled ``manged`` key is kept alongside ``managed`` so existing
    consumers of the old spelling keep working.
    """
    return dict(kid=keyitem.kid,
                version=KeyVaultId.parse_key_id(keyitem.kid).version,
                tags=keyitem.tags,
                # legacy misspelling, retained for backward compatibility
                manged=keyitem.managed,
                managed=keyitem.managed,
                attributes=dict(
                    enabled=keyitem.attributes.enabled,
                    not_before=keyitem.attributes.not_before,
                    expires=keyitem.attributes.expires,
                    created=keyitem.attributes.created,
                    updated=keyitem.attributes.updated,
                    recovery_level=keyitem.attributes.recovery_level))
示例#37
0
 def _update_secret(secret_uri):
     """Update the secret at ``secret_uri`` with a new content type, expiry and tags.

     Works from a deep copy of ``created_bundle``, asserts that the returned
     bundle has the requested tags, the same id and a changed ``updated``
     timestamp, and returns it.
     """
     desired = copy.deepcopy(created_bundle)
     desired.content_type = 'text/plain'
     desired.attributes.expires = date_parse.parse('2050-02-02T08:00:00.000Z')
     desired.tags = {'foo': 'updated tag'}

     parsed = KeyVaultId.parse_secret_id(secret_uri)
     updated = self.client.update_secret(parsed.vault, parsed.name, parsed.version,
                                         desired.content_type, desired.attributes, desired.tags)

     self.assertEqual(desired.tags, updated.tags)
     self.assertEqual(desired.id, updated.id)
     self.assertNotEqual(str(desired.attributes.updated), str(updated.attributes.updated))
     return updated
def secretbundle_to_dict(bundle):
    """Convert a secret bundle into a plain dictionary.

    The result has the keys ``tags``, ``attributes`` (a dictionary of the
    bundle's attribute fields), ``sid``, ``version`` (parsed out of the bundle
    id), ``content_type`` and ``secret`` (the bundle's value).
    """
    result = {'tags': bundle.tags}
    attrs = bundle.attributes
    result['attributes'] = {
        'enabled': attrs.enabled,
        'not_before': attrs.not_before,
        'expires': attrs.expires,
        'created': attrs.created,
        'updated': attrs.updated,
        'recovery_level': attrs.recovery_level,
    }
    result['sid'] = bundle.id
    result['version'] = KeyVaultId.parse_secret_id(bundle.id).version
    result['content_type'] = bundle.content_type
    result['secret'] = bundle.value
    return result
示例#39
0
 def _update_secret(secret_uri):
     """Update the secret at ``secret_uri`` with a new content type, expiry and tags.

     Works from a deep copy of ``created_bundle``, asserts that the returned
     bundle has the requested tags, the same id and a changed ``updated``
     timestamp, and returns it.
     """
     desired = copy.deepcopy(created_bundle)
     desired.content_type = 'text/plain'
     desired.attributes.expires = date_parse.parse('2050-02-02T08:00:00.000Z')
     desired.tags = {'foo': 'updated tag'}

     parsed = KeyVaultId.parse_secret_id(secret_uri)
     updated = self.client.update_secret(parsed.vault, parsed.name, parsed.version,
                                         desired.content_type, desired.attributes, desired.tags)

     self.assertEqual(desired.tags, updated.tags)
     self.assertEqual(desired.id, updated.id)
     self.assertNotEqual(str(desired.attributes.updated), str(updated.attributes.updated))
     return updated
示例#40
0
    def test_create_object_id(self):
        """Check ``KeyVaultId.create_object_id`` on well-formed and malformed inputs."""
        vault_url = 'https://myvault.vault.azure.net'

        # success scenarios: a missing or blank version gives the versionless id
        unversioned = self._get_expected('keys', 'myvault', 'mykey')
        created = KeyVaultId.create_object_id('keys', vault_url, ' mykey', None)
        self.assertEqual(created.__dict__, unversioned)

        created = KeyVaultId.create_object_id('keys', vault_url, ' mykey', ' ')
        self.assertEqual(created.__dict__, unversioned)

        # surrounding whitespace in collection, name and version is ignored
        versioned = self._get_expected('keys', 'myvault', 'mykey', 'abc123')
        created = KeyVaultId.create_object_id(' keys ', vault_url, ' mykey ', ' abc123 ')
        self.assertEqual(created.__dict__, versioned)

        # failure scenarios
        failing_cases = (
            (TypeError, ('keys', vault_url, ['stuff'], '')),
            (ValueError, ('keys', vault_url, ' ', '')),
            (ValueError, ('keys', 'myvault.vault.azure.net', 'mykey', '')),
        )
        for error_type, call_args in failing_cases:
            with self.assertRaises(error_type):
                KeyVaultId.create_object_id(*call_args)
示例#41
0
    def test_parse_object_id(self):
        """Check ``KeyVaultId.parse_object_id`` on well-formed and malformed ids."""
        # success scenarios: with and without a version segment
        versioned = self._get_expected('keys', 'myvault', 'mykey', 'abc123')
        parsed = KeyVaultId.parse_object_id('keys', 'https://myvault.vault.azure.net/keys/mykey/abc123')
        self.assertEqual(parsed.__dict__, versioned)

        unversioned = self._get_expected('keys', 'myvault', 'mykey')
        parsed = KeyVaultId.parse_object_id('keys', 'https://myvault.vault.azure.net/keys/mykey')
        self.assertEqual(parsed.__dict__, unversioned)

        # failure scenarios: each (collection, id) pair must raise ValueError
        rejected = (
            ('secret', 'https://myvault.vault.azure.net/keys/mykey/abc123'),
            ('keys', 'https://myvault.vault.azure.net/keys/mykey/abc123/extra'),
            ('keys', 'https://myvault.vault.azure.net'),
        )
        for collection, object_id in rejected:
            with self.assertRaises(ValueError):
                KeyVaultId.parse_object_id(collection, object_id)
示例#42
0
 def _update_key(key_uri):
     """Update the key at ``key_uri`` with a new expiry, key ops and tags.

     Works from a deep copy of ``created_bundle``, asserts that the returned
     bundle carries the requested tags and the same ``kid``, and returns it.
     """
     desired = copy.deepcopy(created_bundle)
     desired.attributes.expires = date_parse.parse('2050-02-02T08:00:00.000Z')
     desired.key.key_ops = ['encrypt', 'decrypt']
     desired.tags = {'foo': 'updated tag'}

     parsed = KeyVaultId.parse_key_id(key_uri)
     updated = self.client.update_key(parsed.vault, parsed.name, parsed.version,
                                      desired.key.key_ops, desired.attributes, desired.tags)

     self.assertEqual(desired.tags, updated.tags)
     self.assertEqual(desired.key.kid, updated.key.kid)
     return updated
示例#43
0
    def test_secret_crud_operations(self, vault, **kwargs):
        """Walk a secret through create, get, update, delete and a failing get.

        The final step requires that fetching the deleted secret raises an
        error whose ``message`` contains "not found"; a successful get fails
        the test.
        """
        self.assertIsNotNone(vault)
        vault_uri = vault.properties.vault_uri
        secret_name = 'crud-secret'
        secret_value = self.get_resource_name('crud_secret_value')

        # create secret
        created_bundle = self.client.set_secret(vault_uri, secret_name, secret_value)
        self._validate_secret_bundle(created_bundle, vault_uri, secret_name, secret_value)
        secret_id = KeyVaultId.parse_secret_id(created_bundle.id)

        # get secret without version
        self.assertEqual(created_bundle, self.client.get_secret(secret_id.vault, secret_id.name, ''))

        # get secret with version
        self.assertEqual(created_bundle, self.client.get_secret(secret_id.vault, secret_id.name, secret_id.version))

        def _update_secret(secret_uri):
            """Update the secret at ``secret_uri`` and check the returned bundle."""
            updating_bundle = copy.deepcopy(created_bundle)
            updating_bundle.content_type = 'text/plain'
            updating_bundle.attributes.expires = date_parse.parse('2050-02-02T08:00:00.000Z')
            updating_bundle.tags = {'foo': 'updated tag'}
            sid = KeyVaultId.parse_secret_id(secret_uri)
            secret_bundle = self.client.update_secret(
                sid.vault, sid.name, sid.version, updating_bundle.content_type, updating_bundle.attributes,
                updating_bundle.tags)
            self.assertEqual(updating_bundle.tags, secret_bundle.tags)
            self.assertEqual(updating_bundle.id, secret_bundle.id)
            self.assertNotEqual(str(updating_bundle.attributes.updated), str(secret_bundle.attributes.updated))
            return secret_bundle

        # update secret without version
        _update_secret(secret_id.base_id)

        # update secret with version
        _update_secret(secret_id.id)

        # delete secret
        self.client.delete_secret(secret_id.vault, secret_id.name)

        # get secret returns not found
        try:
            self.client.get_secret(secret_id.vault, secret_id.name, '')
        except Exception as ex:
            # only a "not found" error is the expected outcome; re-raise the rest
            if not hasattr(ex, 'message') or 'not found' not in ex.message.lower():
                raise
        else:
            # without this the test would pass even if the secret still existed
            self.fail('Get should fail')
示例#44
0
    def test_key_backup_and_restore(self, vault, **kwargs):
        """Back up a freshly created RSA key, delete it, and restore it from the backup.

        The restored bundle must equal the bundle returned at creation.
        """
        self.assertIsNotNone(vault)
        vault_uri = vault.properties.vault_uri
        name = self.get_resource_name('keybak')

        # create key
        original = self.client.create_key(vault_uri, name, 'RSA')
        key_id = KeyVaultId.parse_key_id(original.key.kid)

        # backup key
        backup_blob = self.client.backup_key(key_id.vault, key_id.name).value

        # delete key
        self.client.delete_key(key_id.vault, key_id.name)

        # restore key
        restored = self.client.restore_key(vault_uri, backup_blob)
        self.assertEqual(original, restored)
示例#45
0
def get_keyvault_secret(user_identity_id, keyvault_secret_id):
    """Fetch the value of a Key Vault secret using managed-identity authentication.

    ``keyvault_secret_id`` is parsed into vault, name and version. When
    ``user_identity_id`` is truthy it is passed to ``MSIAuthentication`` as
    ``client_id``; otherwise ``MSIAuthentication`` is created with the resource
    only.
    """
    parsed = KeyVaultId.parse_secret_id(keyvault_secret_id)

    # only pass client_id when a user-assigned identity was provided
    msi_kwargs = {'resource': RESOURCE_VAULT}
    if user_identity_id:
        msi_kwargs['client_id'] = user_identity_id
    msi = MSIAuthentication(**msi_kwargs)

    token = AccessToken(token=msi.token['access_token'])
    # the callback ignores its three arguments and always returns this token
    credentials = KeyVaultAuthentication(lambda _1, _2, _3: token)

    client = KeyVaultClient(credentials)
    secret = client.get_secret(parsed.vault, parsed.name, parsed.version)
    return secret.value
示例#46
0
    def test_key_crud_operations(self, vault, **kwargs):
        """Walk an RSA key through create, get, update, delete and a failing get."""
        self.assertIsNotNone(vault)
        vault_uri = vault.properties.vault_uri
        key_name = self.get_resource_name('key')

        # create key
        created_bundle = self.client.create_key(vault_uri, key_name, 'RSA')
        self._validate_rsa_key_bundle(created_bundle, vault_uri, key_name, 'RSA')
        key_id = KeyVaultId.parse_key_id(created_bundle.key.kid)

        # get key without version
        fetched = self.client.get_key(key_id.vault, key_id.name, '')
        self.assertEqual(created_bundle, fetched)

        # get key with version
        fetched = self.client.get_key(key_id.vault, key_id.name, key_id.version)
        self.assertEqual(created_bundle, fetched)

        def _update_key(key_uri):
            """Update the key at ``key_uri`` and check the returned bundle."""
            desired = copy.deepcopy(created_bundle)
            desired.attributes.expires = date_parse.parse('2050-02-02T08:00:00.000Z')
            desired.key.key_ops = ['encrypt', 'decrypt']
            desired.tags = {'foo': 'updated tag'}
            kid = KeyVaultId.parse_key_id(key_uri)
            updated = self.client.update_key(
                kid.vault, kid.name, kid.version,
                desired.key.key_ops, desired.attributes, desired.tags)
            self.assertEqual(desired.tags, updated.tags)
            self.assertEqual(desired.key.kid, updated.key.kid)
            return updated

        # update key without version
        created_bundle = _update_key(key_id.base_id)

        # update key with version
        created_bundle = _update_key(key_id.id)

        # delete key
        self.client.delete_key(key_id.vault, key_id.name)

        # get key returns not found
        try:
            self.client.get_key(key_id.vault, key_id.name, '')
            self.fail('Get should fail')
        except Exception as ex:
            # anything other than a "not found" error is propagated
            is_not_found = hasattr(ex, 'message') and 'not found' in ex.message.lower()
            if not is_not_found:
                raise ex
    def _validate_issuer_bundle(self, bundle, vault, name, provider, credentials, org_details):
        """Assert that an issuer bundle matches the values it was created with.

        ``credentials`` and ``org_details`` are only compared when truthy. Note
        that the comparison resets ``additional_properties`` on both
        ``org_details`` and ``bundle.organization_details``.
        """
        self.assertIsNotNone(bundle)
        self.assertIsNotNone(bundle.attributes)
        self.assertIsNotNone(bundle.organization_details)
        self.assertEqual(bundle.provider, provider)

        parsed = KeyVaultId.parse_certificate_issuer_id(bundle.id)
        self.assertEqual(parsed.vault.strip('/'), vault.strip('/'))
        self.assertEqual(parsed.name, name)

        if credentials:
            self.assertEqual(bundle.credentials.account_id, credentials.account_id)

        if not org_details:
            return

        # To accommodate a tiny change in == semantics in msrest 0.4.20,
        # blank out additional_properties on both sides before comparing.
        for details in (org_details, bundle.organization_details):
            details.additional_properties = {}

        self.assertEqual(bundle.organization_details, org_details)
示例#48
0
    def test_backup_restore(self, vault, **kwargs):
        """Back up a freshly set secret, delete it, and restore it from the backup.

        Only the ``attributes`` of the created and restored bundles are compared.
        """
        self.assertIsNotNone(vault)
        vault_uri = vault.properties.vault_uri
        name = self.get_resource_name('secbak')
        value = self.get_resource_name('secVal')

        # create secret
        original = self.client.set_secret(vault_uri, name, value)
        secret_id = KeyVaultId.parse_secret_id(original.id)

        # backup secret
        backup_blob = self.client.backup_secret(secret_id.vault, secret_id.name).value

        # delete secret
        self.client.delete_secret(secret_id.vault, secret_id.name)

        # restore secret
        restored = self.client.restore_secret(vault_uri, backup_blob)
        self.assertEqual(original.attributes, restored.attributes)
示例#49
0
 def test_parse_certificate_issuer_id(self):
     """Check that a certificate issuer URL parses into the expected fields."""
     wanted = self._get_expected('certificates', 'myvault', 'myissuer')
     issuer_url = 'https://myvault.vault.azure.net/certificates/issuers/myissuer'
     parsed = KeyVaultId.parse_certificate_issuer_id(issuer_url)
     self.assertEqual(parsed.__dict__, wanted)
示例#50
0
 def test_parse_certificate_operation_id(self):
     """Check that a pending certificate operation URL parses into the expected fields."""
     wanted = self._get_expected('certificates', 'myvault', 'mycert', 'pending')
     operation_url = 'https://myvault.vault.azure.net/certificates/mycert/pending'
     parsed = KeyVaultId.parse_certificate_operation_id(operation_url)
     self.assertEqual(parsed.__dict__, wanted)
 def create_key(self, name, tags, kty='RSA'):
     """Create a key of type ``kty`` in the configured vault and return its parsed ``id``."""
     bundle = self.client.create_key(self.keyvault_uri, name, kty, tags=tags)
     parsed = KeyVaultId.parse_key_id(bundle.key.kid)
     return parsed.id
 def delete_key(self, name):
     """Delete the key ``name`` from the configured vault and return its parsed ``id``."""
     removed = self.client.delete_key(self.keyvault_uri, name)
     parsed = KeyVaultId.parse_key_id(removed.key.kid)
     return parsed.id
    def test_create_kafka_cluster_with_disk_encryption(self, resource_group, location, storage_account, storage_account_key, vault):
        """Create a Kafka cluster with disk encryption, then rotate its encryption key.

        Steps, in order: create a user-assigned managed identity, add an access
        policy for it to the vault, create an RSA key, create the cluster with
        disk encryption properties pointing at that key, validate the cluster,
        create a second key, rotate the cluster to it, and check the cluster's
        disk encryption properties again.
        """
        # create managed identities for Azure resources.
        msi_name = self.get_resource_name('hdipyuai')
        msi = self.msi_client.user_assigned_identities.create_or_update(resource_group.name, msi_name, location)

        # add managed identity to vault
        required_permissions = Permissions(keys=[KeyPermissions.get, KeyPermissions.wrap_key, KeyPermissions.unwrap_key],
                                           secrets=[SecretPermissions.get, SecretPermissions.set,SecretPermissions.delete])
        vault.properties.access_policies.append(
            AccessPolicyEntry(tenant_id=self.tenant_id,
                              object_id=msi.principal_id,
                              permissions=required_permissions)
        )
        update_params = VaultCreateOrUpdateParameters(location=location,
                                                    properties=vault.properties)
        # from here on `vault` is the result of the update, not the argument passed in
        vault = self.vault_mgmt_client.vaults.create_or_update(resource_group.name, vault.name, update_params).result()
        self.assertIsNotNone(vault)

        # create key
        vault_uri = vault.properties.vault_uri
        key_name = self.get_resource_name('hdipykey1')
        created_bundle = self.vault_client.create_key(vault_uri, key_name, 'RSA')
        vault_key = KeyVaultId.parse_key_id(created_bundle.key.kid)

        # create HDInsight cluster with Kafka disk encryption
        rg_name = resource_group.name
        cluster_name = self.get_resource_name('hdisdk-kafka-byok')
        create_params = self.get_cluster_create_params(location, cluster_name, storage_account, storage_account_key)
        create_params.properties.cluster_definition.kind = 'Kafka'
        # the data disk group is set on the role named 'workernode' only
        workernode = next(item for item in create_params.properties.compute_profile.roles if item.name == 'workernode')
        workernode.data_disks_groups = [
            DataDisksGroups(
                disks_per_node=8
            )
        ]
        create_params.identity = ClusterIdentity(
            type=ResourceIdentityType.user_assigned,
            user_assigned_identities={msi.id: ClusterIdentityUserAssignedIdentitiesValue()}
        )
        create_params.properties.disk_encryption_properties = DiskEncryptionProperties(
            vault_uri=vault_key.vault,
            key_name=vault_key.name,
            key_version=vault_key.version,
            msi_resource_id=msi.id
        )
        cluster = self.hdinsight_client.clusters.create(resource_group.name, cluster_name, create_params).result()
        self.validate_cluster(cluster_name, create_params, cluster)

        # check disk encryption properties
        # (msi_resource_id is compared case-insensitively; key_version is not compared)
        self.assertIsNotNone(cluster.properties.disk_encryption_properties)
        self.assertEqual(create_params.properties.disk_encryption_properties.vault_uri, cluster.properties.disk_encryption_properties.vault_uri)
        self.assertEqual(create_params.properties.disk_encryption_properties.key_name, cluster.properties.disk_encryption_properties.key_name)
        self.assertEqual(create_params.properties.disk_encryption_properties.msi_resource_id.lower(), cluster.properties.disk_encryption_properties.msi_resource_id.lower())

        # create a new key
        new_key_name = self.get_resource_name('hdipykey2')
        created_bundle = self.vault_client.create_key(vault_uri, new_key_name, 'RSA')
        new_vault_key = KeyVaultId.parse_key_id(created_bundle.key.kid)
        rotate_params = ClusterDiskEncryptionParameters(
            vault_uri=new_vault_key.vault,
            key_name=new_vault_key.name,
            key_version=new_vault_key.version
        )

        # rotate cluster key
        self.hdinsight_client.clusters.rotate_disk_encryption_key(rg_name, cluster_name, rotate_params).wait()
        # re-fetch the cluster so the assertions below see the state after rotation
        cluster = self.hdinsight_client.clusters.get(rg_name, cluster_name)

        # check disk encryption properties
        self.assertIsNotNone(cluster.properties.disk_encryption_properties)
        self.assertEqual(rotate_params.vault_uri, cluster.properties.disk_encryption_properties.vault_uri)
        self.assertEqual(rotate_params.key_name, cluster.properties.disk_encryption_properties.key_name)
        self.assertEqual(msi.id.lower(), cluster.properties.disk_encryption_properties.msi_resource_id.lower())