示例#1
0
    def backup_restore_certificate(self):
        """
        Back up a certificate from one key vault and restore it into another.

        Two vaults are obtained from ``self.create_vault()``. A certificate with
        the default policy is created in the first vault, backed up, and the
        backup blob is restored into the second vault. The certificate names in
        each vault are printed along the way.
        """

        def show_certificates(client, heading):
            # Fetch the (lazily paged) listing, print the heading, then each name.
            listing = client.list_properties_of_certificates()
            print(heading)
            for props in listing:
                print(props.name)

        # source vault and a client bound to it
        source_vault = self.create_vault()
        credential = DefaultAzureCredential()
        source_client = CertificateClient(
            vault_url=source_vault.properties.vault_uri,
            credential=credential,
        )

        # create a certificate in the source vault and wait for it
        cert_name = get_name('certificate')
        create_poller = source_client.begin_create_certificate(
            cert_name, CertificatePolicy.get_default())
        created = create_poller.result()
        print('created certificate {}'.format(created.name))

        show_certificates(source_client,
                          "all of the certificates in the client's vault:")

        # take a backup of the certificate
        backup_blob = source_client.backup_certificate(cert_name)
        print('backed up certificate {}'.format(cert_name))

        # destination vault and a client bound to it (same credential)
        target_vault = self.create_vault()
        target_client = CertificateClient(
            vault_url=target_vault.properties.vault_uri,
            credential=credential,
        )

        # restore the backup into the destination vault
        restored = target_client.restore_certificate_backup(backup_blob)
        print('restored certificate {}'.format(restored.name))

        show_certificates(target_client,
                          "all of the certificates in the new vault:")
class Acme2KeyVault:
    """
    Provisions ACME/Let's Encrypt TLS certificates into Azure Key Vault.

    Key material is generated inside Key Vault as a pending certificate whose
    policy names issuer "Unknown"; the resulting CSR is sent to an ACME CA,
    DNS-01 challenges are answered through Azure DNS, and the issued chain is
    merged back into the pending Key Vault certificate. Certificate request
    documents are stored in a blob container so certificates can be renewed.
    """

    def __init__(self, credential: ChainedTokenCredential, options: Options):
        self.options = options
        self._cert_client = CertificateClient(options.azure_keyvault_uri,
                                              credential)
        self._dns_client = DnsManagementClient(credential,
                                               options.azure_subscription_id)
        blob_svc_client = BlobServiceClient.from_connection_string(
            options.cert_request_connection_str, )
        self._container_client = blob_svc_client.get_container_client(
            options.cert_request_container)

    def _cert_policy_from_request(
            self, cert_request: CertRequest) -> CertificatePolicy:
        """
        Build the Key Vault certificate policy for ``cert_request``.

        The subject is ``CN=<common_name>``, optionally followed by a space and
        ``options.cert_subject``. Issuer "Unknown" means Key Vault only creates
        the key and CSR; the signed certificate is merged in later.
        """
        subject = f"CN={cert_request.common_name}"
        if self.options.cert_subject:
            subject += " " + self.options.cert_subject

        return CertificatePolicy(
            issuer_name="Unknown",
            key_size=self.options.cert_key_size,
            subject=subject,
            san_dns_names=cert_request.alternative_names,
            validity_in_months=CERT_VALIDITY_MONTHS,
            exportable=True,
        )

    def _generate_new_cert_key(self, cert_request: CertRequest) -> str:
        """
        Generates a new certificate key and returns it in the PEM format
        (as required by Azure keyvault)
        Note: Deletes any previous certificate operation with the same id
        """
        # Delete the previous certificate operation if it exists
        try:
            self._cert_client.delete_certificate_operation(cert_request.id)
        except HttpResponseError as e:
            if e.status_code != 404:
                raise

        # Generate a new certificate key. The returned poller is deliberately
        # not waited on: with issuer "Unknown" the operation presumably stays
        # pending until merge_certificate is called (see _store_certificate).
        self._cert_client.begin_create_certificate(
            certificate_name=cert_request.id,
            policy=self._cert_policy_from_request(cert_request),
        )

        # Get the CSR for the new certificate key
        cert_op = self._cert_client.get_certificate_operation(cert_request.id)
        return _csr_bytes_to_pem(cert_op.csr)

    def _create_acme_client(self) -> acme_client.ClientV2:
        """
        Create an ACME v2 client with a freshly generated RSA account key,
        using the directory published at ``options.acme_directory_url``.
        """
        net = acme_client.ClientNetwork(key=josepy.JWKRSA(
            key=rsa.generate_private_key(
                public_exponent=ACME_RSA_PUBLIC_EXPONENT,
                key_size=ACME_RSA_KEY_SIZE,
                backend=default_backend(),
            )))
        directory = acme_messages.Directory.from_json(
            net.get(self.options.acme_directory_url).json())
        return acme_client.ClientV2(directory=directory, net=net)

    def _authorization_to_relative_domain(
        self,
        authorization: acme_messages.Authorization,
    ):
        """
        Return the TXT record-set name, relative to ``options.azure_dns_zone``,
        that holds the DNS-01 challenge for ``authorization``.

        ``www.example.com`` in zone ``example.com`` -> ``_acme-challenge.www``;
        the zone apex ``example.com`` (which is also the identifier of a
        ``*.example.com`` wildcard authorization per RFC 8555) ->
        ``_acme-challenge``. Names outside the zone are kept as-is.
        """
        domain = authorization.identifier.value
        zone = self.options.azure_dns_zone
        if domain == zone:
            # The apex record lives directly at _acme-challenge.<zone>.
            return "_acme-challenge"
        suffix = f".{zone}"
        # Strip the zone only as a trailing suffix; str.replace would also
        # remove it from the middle of a name.
        if domain.endswith(suffix):
            domain = domain[:-len(suffix)]
        return f"_acme-challenge.{domain}"

    def _setup_dns_challenge(
        self,
        record_set_name: str,
        value: str,
        *additional_values: str,
    ):
        """
        Create or overwrite the TXT record set ``record_set_name`` in the
        configured zone, with one TXT record per validation value.

        Extra values are needed when several authorizations map to the same
        record name (e.g. ``example.com`` and ``*.example.com``).
        """
        self._dns_client.record_sets.create_or_update(
            resource_group_name=self.options.azure_dns_zone_resource_group,
            zone_name=self.options.azure_dns_zone,
            relative_record_set_name=record_set_name,
            record_type=RecordType.TXT,
            parameters=RecordSet(
                ttl=ACME_CHALLENGE_TXT_RECORD_TTL,
                txt_records=[
                    TxtRecord(value=[txt])
                    for txt in (value, *additional_values)
                ],
            ),
        )

    def _teardown_dns_challenge(
        self,
        record_set_name: str,
    ):
        """Delete the TXT record set ``record_set_name`` from the configured zone."""
        self._dns_client.record_sets.delete(
            resource_group_name=self.options.azure_dns_zone_resource_group,
            zone_name=self.options.azure_dns_zone,
            relative_record_set_name=record_set_name,
            record_type=RecordType.TXT,
        )

    def _order_certificate(self, csr: str) -> str:
        """
        Order a certificate for ``csr`` from the ACME CA and return the issued
        full chain in PEM format.

        Every DNS-01 TXT record is created before the challenges are answered
        and is kept until the order has been finalized (or has failed). The CA
        validates challenges asynchronously after they are answered, so
        deleting a record right after answering can make validation fail.
        """
        logging.debug("Registering an account")
        client = self._create_acme_client()
        client.new_account(
            acme_messages.NewRegistration.from_data(
                email=self.options.acme_contact_email,
                terms_of_service_agreed=True,
            ))

        logging.debug("Placing a new order")
        order = client.new_order(csr)

        # Collect the DNS challenge of every authorization. Validation values
        # are grouped by record name because several authorizations can share
        # one record set.
        record_values: typing.Dict[str, typing.List[str]] = {}
        answers = []
        authz_resource: acme_messages.AuthorizationResource
        for authz_resource in order.authorizations:  # pylint: disable=not-an-iterable
            authz: acme_messages.Authorization = authz_resource.body
            challenge_body, dns_challenge = _find_dns_challenge(authz)
            record_name = self._authorization_to_relative_domain(authz)
            response, validation_value = dns_challenge.response_and_validation(
                client.net.key)
            record_values.setdefault(record_name, []).append(validation_value)
            answers.append((challenge_body, response))

        attempted_records: typing.List[str] = []
        try:
            for record_name, values in record_values.items():
                logging.debug("Creating DNS challenge: '%s' = '%s'",
                              record_name, "', '".join(values))
                # Record before the call so a partially failed setup is
                # still cleaned up below.
                attempted_records.append(record_name)
                self._setup_dns_challenge(record_name, *values)

            logging.debug("Answering DNS challenges")
            for challenge_body, response in answers:
                client.answer_challenge(challenge_body, response)

            logging.debug("Finalizing order")
            finalized_order = client.poll_and_finalize(order)
        finally:
            for record_name in attempted_records:
                logging.debug("Tearing down DNS challenge: '%s'", record_name)
                self._teardown_dns_challenge(record_set_name=record_name)

        return finalized_order.fullchain_pem

    def _store_certificate(self, cert_request: CertRequest, certificate: str):
        """Merge the issued PEM chain into the pending Key Vault certificate."""
        self._cert_client.merge_certificate(
            cert_request.id,
            [certificate.encode("ascii")],
        )

    def provision(self, cert_request: CertRequest):
        """
        Provision a TLS certificate based on the given certificate request.

        This can be used for both provisioning new certificates as well as renewing existing ones.
        """
        logging.info("Generating a certificate key %s", cert_request.id)
        csr = self._generate_new_cert_key(cert_request)

        logging.info("Ordering a certificate for %s", cert_request.common_name)
        certificate = self._order_certificate(csr)

        logging.info("Storing certificate %s", cert_request.id)
        self._store_certificate(cert_request, certificate)

    def save(self, cert_request: CertRequest):
        """
        Saves the certificate request details for certificate renewal purposes
        """
        blob_client = self._container_client.get_blob_client(
            cert_request.blob_name)
        blob_client.upload_blob(data=cert_request.json(), overwrite=True)

    def _cert_needs_renewal(self, certificate: CertificateProperties) -> bool:
        """
        Return True when ``certificate`` is enabled and has at most
        ``options.cert_expiry_threshold_days`` whole days left before expiry.
        Already-expired certificates have a negative day count.

        Certificates without an expiry date are skipped with a warning.
        """
        # Disabled certificates are not renewed
        if not certificate.enabled:
            return False

        expires_on = certificate.expires_on
        if expires_on is None:
            # Without an expiry date the remaining lifetime cannot be computed;
            # skip instead of aborting the whole renewal scan.
            logging.warning(
                "Certificate %s has no expiry date; skipping renewal check",
                certificate.name,
            )
            return False
        if expires_on.tzinfo is None:
            # NOTE(review): assumes a naive expiry timestamp is UTC; comparing
            # naive and aware datetimes would raise TypeError.
            expires_on = expires_on.replace(tzinfo=timezone.utc)

        time_until_expiration = expires_on - datetime.now(tz=timezone.utc)
        return time_until_expiration.days <= self.options.cert_expiry_threshold_days

    def find_cert_names_needing_renewal(self) -> typing.List[str]:
        """
        Find all the names of the certificates from Key Vault that need to be renewed
        """
        certificates = self._cert_client.list_properties_of_certificates()
        return [
            certificate.name for certificate in certificates
            if self._cert_needs_renewal(certificate)
        ]

    def find_certs_needing_renewal(self) -> typing.List[CertRequest]:
        """
        Find all the certificates from Key Vault that need to be renewed.

        Certificates whose request document is missing from the blob container
        (HTTP 404) are logged and skipped; other storage errors propagate.
        """
        cert_requests: typing.List[CertRequest] = []
        for cert_name in self.find_cert_names_needing_renewal():
            blob_client = self._container_client.get_blob_client(
                _cert_name_to_blob_name(cert_name))
            try:
                raw_request = blob_client.download_blob().readall()
            except HttpResponseError as e:
                if e.status_code != 404:
                    raise
                logging.error(
                    "Could not find cert request document for certificate %s",
                    cert_name,
                )
                continue
            cert_requests.append(CertRequest.parse_raw(raw_request))

        return cert_requests
    storage_certificate_poller = client.begin_create_certificate(
        certificate_name=storage_cert_name,
        policy=CertificatePolicy.get_default())

    # await the creation of the bank and storage certificate
    bank_certificate = bank_certificate_poller.result()
    storage_certificate = storage_certificate_poller.result()

    print("Certificate with name '{0}' was created.".format(
        bank_certificate.name))
    print("Certificate with name '{0}' was created.".format(
        storage_certificate.name))

    # Let's list the certificates.
    print("\n.. List certificates from the Key Vault")
    certificates = client.list_properties_of_certificates()
    for certificate in certificates:
        print("Certificate with name '{0}' was found.".format(
            certificate.name))

    # You've decided to add tags to the certificate you created. Calling begin_create_certificate on an existing
    # certificate creates a new version of the certificate in the Key Vault with the new value.

    tags = {"a": "b"}
    bank_certificate_poller = client.begin_create_certificate(
        certificate_name=bank_cert_name,
        policy=CertificatePolicy.get_default(),
        tags=tags)
    bank_certificate = bank_certificate_poller.result()
    print(
        "Certificate with name '{0}' was created again with tags '{1}'".format(
    def deleted_certificate_recovery(self):
        """
        Sample: enumerate, retrieve, recover and purge deleted certificates
        from a key vault.

        Two certificates are created and then deleted. The deleted certificates
        are listed, one is recovered and the other purged, and the vault's
        certificates are listed at the end.
        """
        # create a vault; recover/purge below need soft delete enabled on it
        # (presumably configured by create_vault)
        vault = self.create_vault()

        # create a certificate client
        credential = DefaultAzureCredential()
        certificate_client = CertificateClient(
            vault_url=vault.properties.vault_uri, credential=credential)

        # create certificates in the vault
        cert_to_recover = get_name('cert')
        cert_to_purge = get_name('cert')

        create_certificate_poller = certificate_client.begin_create_certificate(
            cert_to_recover, policy=CertificatePolicy.get_default())
        created_certificate = create_certificate_poller.result()
        print('created certificate {}'.format(created_certificate.name))

        create_certificate_poller = certificate_client.begin_create_certificate(
            cert_to_purge, policy=CertificatePolicy.get_default())
        created_certificate = create_certificate_poller.result()
        print('created certificate {}'.format(created_certificate.name))

        # list the vault certificates
        certificates = certificate_client.list_properties_of_certificates()
        print('list the vault certificates')
        for certificate in certificates:
            print(certificate.name)

        # delete the certificates; wait() blocks until the deletion finishes,
        # which must happen before the certificate can be recovered or purged
        deleted_certificate_poller = certificate_client.begin_delete_certificate(
            cert_to_recover)
        deleted_certificate = deleted_certificate_poller.result()
        deleted_certificate_poller.wait()
        print('deleted certificate {}'.format(deleted_certificate.name))

        deleted_certificate_poller = certificate_client.begin_delete_certificate(
            cert_to_purge)
        deleted_certificate = deleted_certificate_poller.result()
        deleted_certificate_poller.wait()
        print('deleted certificate {}'.format(deleted_certificate.name))

        # list the deleted certificates
        deleted_certs = certificate_client.list_deleted_certificates()
        print('deleted certificates:')
        for deleted_cert in deleted_certs:
            print(deleted_cert.name)

        # recover a deleted certificate; as with deletion, result() may return
        # before the operation completes, so wait() for the recovery to finish
        # so the final listing below includes the recovered certificate
        recovered_certificate_poller = certificate_client.begin_recover_deleted_certificate(
            cert_to_recover)
        recovered_certificate = recovered_certificate_poller.result()
        recovered_certificate_poller.wait()
        print('recovered certificate {}'.format(recovered_certificate.name))

        # purge a deleted certificate, then pause (presumably to let the
        # service finish the purge before continuing)
        certificate_client.purge_deleted_certificate(cert_to_purge)
        time.sleep(50)
        print('purged certificate {}'.format(cert_to_purge))

        # list the vault certificates
        certificates = certificate_client.list_properties_of_certificates()
        print("all of the certificates in the client's vault:")
        for certificate in certificates:
            print(certificate.name)
# Cleanup script: deletes and purges every certificate, and collects every key,
# whose name starts with "livekvtest" in the vault at AZURE_KEYVAULT_URL.

# Validate every required environment variable up front, so a missing one
# raises a descriptive EnvironmentError instead of a bare KeyError below.
if "KEYVAULT_CLIENT_ID" not in os.environ:
    raise EnvironmentError("Missing a client ID for Key Vault")
if "KEYVAULT_CLIENT_SECRET" not in os.environ:
    raise EnvironmentError("Missing a client secret for Key Vault")
if "KEYVAULT_TENANT_ID" not in os.environ:
    raise EnvironmentError("Missing a tenant ID for Key Vault")
if "AZURE_KEYVAULT_URL" not in os.environ:
    raise EnvironmentError("Missing a URL (AZURE_KEYVAULT_URL) for Key Vault")

credential = ClientSecretCredential(
    tenant_id=os.environ["KEYVAULT_TENANT_ID"],
    client_id=os.environ["KEYVAULT_CLIENT_ID"],
    client_secret=os.environ["KEYVAULT_CLIENT_SECRET"])

vault_url = os.environ["AZURE_KEYVAULT_URL"]
cert_client = CertificateClient(vault_url, credential)
key_client = KeyClient(vault_url, credential)
secret_client = SecretClient(vault_url, credential)

# Delete the test certificates; wait() blocks until each deletion completes,
# which is required before the certificate can be purged.
test_certificates = [
    c for c in cert_client.list_properties_of_certificates()
    if c.name.startswith("livekvtest")
]
for certificate in test_certificates:
    cert_client.begin_delete_certificate(certificate.name).wait()

# Permanently remove the deleted test certificates.
deleted_test_certificates = [
    c for c in cert_client.list_deleted_certificates()
    if c.name.startswith("livekvtest")
]
for certificate in deleted_test_certificates:
    cert_client.purge_deleted_certificate(certificate.name)

test_keys = [
    k for k in key_client.list_properties_of_keys()
    if k.name.startswith("livekvtest")
]