class Acme2KeyVault:
    """
    Provisions ACME/Let's Encrypt TLS certificates into Azure Key Vault.

    The private key and CSR are created by Key Vault, the order is placed
    with the ACME directory at ``options.acme_directory_url``, challenges are
    answered through TXT records in the Azure DNS zone
    ``options.azure_dns_zone`` and the issued chain is merged back into
    Key Vault. Certificate request documents are stored in a blob container
    so that certificates can be renewed later.
    """

    # Upper bound on how long to wait for the ACME server to validate one
    # authorization before its challenge record is removed anyway.
    _VALIDATION_TIMEOUT_SECONDS = 120
    # Pause between two consecutive authorization status polls.
    _VALIDATION_POLL_INTERVAL_SECONDS = 2

    def __init__(self, credential: ChainedTokenCredential, options: Options):
        """
        Creates the Key Vault, DNS and blob container clients.

        :param credential: Azure credential used for Key Vault and DNS.
        :param options: Settings (vault URI, subscription, DNS zone, ACME
            directory, blob connection string/container, ...).
        """
        self.options = options
        self._cert_client = CertificateClient(options.azure_keyvault_uri,
                                              credential)
        self._dns_client = DnsManagementClient(credential,
                                               options.azure_subscription_id)
        # Blob storage is authenticated by connection string, not by
        # `credential`.
        blob_svc_client = BlobServiceClient.from_connection_string(
            options.cert_request_connection_str)
        self._container_client = blob_svc_client.get_container_client(
            options.cert_request_container)

    def _cert_policy_from_request(
            self, cert_request: CertRequest) -> CertificatePolicy:
        """
        Builds the Key Vault certificate policy for the given request.

        The issuer is "Unknown" because the certificate is issued outside of
        Key Vault (by the ACME CA) and merged back afterwards.
        """
        subject = f"CN={cert_request.common_name}"
        if self.options.cert_subject:
            # NOTE(review): the extra subject is appended with a single space;
            # presumably `cert_subject` already carries its own separator
            # (e.g. ", O=...") - confirm against the configuration docs.
            subject += " " + self.options.cert_subject

        return CertificatePolicy(
            issuer_name="Unknown",
            key_size=self.options.cert_key_size,
            subject=subject,
            san_dns_names=cert_request.alternative_names,
            validity_in_months=CERT_VALIDITY_MONTHS,
            exportable=True,
        )

    def _generate_new_cert_key(self, cert_request: CertRequest) -> str:
        """
        Generates a new certificate key in Key Vault and returns the matching
        certificate signing request (CSR) converted by `_csr_bytes_to_pem`.
        Note: Deletes any previous certificate operation with the same id
        """
        # Delete the previous certificate operation if it exists; a 404 only
        # means there was nothing to delete.
        try:
            self._cert_client.delete_certificate_operation(cert_request.id)
        except HttpResponseError as e:
            if e.status_code != 404:
                raise

        # Start a new certificate operation (this creates the key and CSR).
        self._cert_client.begin_create_certificate(
            certificate_name=cert_request.id,
            policy=self._cert_policy_from_request(cert_request),
        )

        # Get the CSR for the new certificate key
        cert_op = self._cert_client.get_certificate_operation(cert_request.id)
        return _csr_bytes_to_pem(cert_op.csr)

    def _create_acme_client(self) -> acme_client.ClientV2:
        """
        Creates an ACME v2 client backed by a freshly generated RSA account
        key and the directory at ``options.acme_directory_url``.
        """
        net = acme_client.ClientNetwork(key=josepy.JWKRSA(
            key=rsa.generate_private_key(
                public_exponent=ACME_RSA_PUBLIC_EXPONENT,
                key_size=ACME_RSA_KEY_SIZE,
                backend=default_backend(),
            )))
        directory = acme_messages.Directory.from_json(
            net.get(self.options.acme_directory_url).json())
        return acme_client.ClientV2(directory=directory, net=net)

    def _authorization_to_relative_domain(
        self,
        authorization: acme_messages.Authorization,
    ) -> str:
        """
        Returns the TXT record set name, relative to the configured DNS zone,
        that must hold the challenge value for the given authorization.

        Only a trailing ``.<zone>`` is removed; an identifier equal to the
        zone itself maps to ``_acme-challenge`` at the zone apex. Identifiers
        outside the zone are returned unchanged behind the prefix.
        """
        domain = authorization.identifier.value
        zone = self.options.azure_dns_zone
        if domain == zone:
            return "_acme-challenge"

        zone_suffix = f".{zone}"
        if domain.endswith(zone_suffix):
            domain = domain[:-len(zone_suffix)]
        return f"_acme-challenge.{domain}"

    def _setup_dns_challenge(
        self,
        record_set_name: str,
        value: str,
    ):
        """
        Creates (or overwrites) the TXT record set holding a challenge value.
        """
        self._dns_client.record_sets.create_or_update(
            resource_group_name=self.options.azure_dns_zone_resource_group,
            zone_name=self.options.azure_dns_zone,
            relative_record_set_name=record_set_name,
            record_type=RecordType.TXT,
            parameters=RecordSet(
                ttl=ACME_CHALLENGE_TXT_RECORD_TTL,
                txt_records=[TxtRecord(value=[value])],
            ),
        )

    def _teardown_dns_challenge(
        self,
        record_set_name: str,
    ):
        """
        Deletes the TXT record set that was created for a challenge.
        """
        self._dns_client.record_sets.delete(
            resource_group_name=self.options.azure_dns_zone_resource_group,
            zone_name=self.options.azure_dns_zone,
            relative_record_set_name=record_set_name,
            record_type=RecordType.TXT,
        )

    def _wait_for_validation(
        self,
        client: acme_client.ClientV2,
        authz_resource: acme_messages.AuthorizationResource,
    ) -> None:
        """
        Polls the authorization until it is no longer pending or until
        ``_VALIDATION_TIMEOUT_SECONDS`` have elapsed.

        The outcome (valid/invalid) is not judged here; the later
        ``poll_and_finalize`` call reports it.
        """
        import time

        deadline = time.monotonic() + self._VALIDATION_TIMEOUT_SECONDS
        while True:
            authz_resource, _ = client.poll(authz_resource)
            if authz_resource.body.status != acme_messages.STATUS_PENDING:
                return
            if time.monotonic() >= deadline:
                logging.warning(
                    "Timed out waiting for the ACME server to validate '%s'",
                    authz_resource.body.identifier.value,
                )
                return
            time.sleep(self._VALIDATION_POLL_INTERVAL_SECONDS)

    def _verify_challenge(
        self,
        client: acme_client.ClientV2,
        authz: acme_messages.Authorization,
        authz_resource: typing.Optional[
            acme_messages.AuthorizationResource] = None,
    ):
        """
        Publishes the challenge TXT record, answers the challenge and removes
        the record again.

        :param client: ACME client whose account key signs the response.
        :param authz: Authorization to satisfy.
        :param authz_resource: When given, the record is kept in place until
            the ACME server has finished validating the authorization.
            Answering a challenge only asks the server to start validating,
            so removing the record immediately could make validation fail.
        """
        challenge_body, dns_challenge = _find_dns_challenge(authz)
        validation_domain = self._authorization_to_relative_domain(authz)
        response, validation_value = dns_challenge.response_and_validation(
            client.net.key)

        try:
            logging.debug(
                "Creating DNS challenge: '%s' = '%s'",
                validation_domain,
                validation_value,
            )
            self._setup_dns_challenge(
                record_set_name=validation_domain,
                value=validation_value,
            )
            logging.debug("Answering DNS challenge")
            client.answer_challenge(challenge_body, response)
            if authz_resource is not None:
                self._wait_for_validation(client, authz_resource)
        finally:
            # Always clean up, even if publishing or answering failed.
            logging.debug("Tearing down DNS challenge: '%s'",
                          validation_domain)
            self._teardown_dns_challenge(record_set_name=validation_domain)

    def _order_certificate(self, csr: str) -> str:
        """
        Registers an ACME account, orders a certificate for the given PEM CSR,
        satisfies every authorization of the order and returns the issued
        full certificate chain in PEM format.
        """
        logging.debug("Registering an account")
        client = self._create_acme_client()
        client.new_account(
            acme_messages.NewRegistration.from_data(
                email=self.options.acme_contact_email,
                terms_of_service_agreed=True,
            ))

        logging.debug("Placing a new order")
        order = client.new_order(csr)

        logging.debug("Verifying challenges")
        authz_resource: acme_messages.AuthorizationResource
        for authz_resource in order.authorizations:  # pylint: disable=not-an-iterable
            # Authorizations are handled one at a time so that two
            # authorizations sharing a record set never overwrite each other.
            self._verify_challenge(
                client,
                authz_resource.body,
                authz_resource=authz_resource,
            )

        logging.debug("Finalizing order")
        finalized_order = client.poll_and_finalize(order)
        return finalized_order.fullchain_pem

    def _store_certificate(self, cert_request: CertRequest, certificate: str):
        """
        Merges the issued certificate (PEM text) into the pending Key Vault
        certificate named after the request id.
        """
        self._cert_client.merge_certificate(
            cert_request.id,
            [certificate.encode("ascii")],
        )

    def provision(self, cert_request: CertRequest):
        """
        Provision a TLS certificate based on the given certificate request.

        This can be used for both provisioning new certificates as well as renewing existing ones.
        """
        logging.info("Generating a certificate key %s", cert_request.id)
        csr = self._generate_new_cert_key(cert_request)

        logging.info("Ordering a certificate for %s", cert_request.common_name)
        certificate = self._order_certificate(csr)

        logging.info("Storing certificate %s", cert_request.id)
        self._store_certificate(cert_request, certificate)

    def save(self, cert_request: CertRequest):
        """
        Saves the certificate request details for certificate renewal purposes
        """
        blob_client = self._container_client.get_blob_client(
            cert_request.blob_name)
        blob_client.upload_blob(data=cert_request.json(), overwrite=True)

    def _cert_needs_renewal(self, certificate: CertificateProperties) -> bool:
        """
        Returns True when an enabled certificate expires within
        ``options.cert_expiry_threshold_days`` days (or has already expired).
        """
        # Disabled certificates are not renewed
        if not certificate.enabled:
            return False

        expires_on = certificate.expires_on
        # Without an expiry date there is nothing to renew.
        if expires_on is None:
            return False
        # NOTE(review): a naive timestamp is assumed to be UTC so it can be
        # compared with the timezone-aware "now" below - confirm.
        if expires_on.tzinfo is None:
            expires_on = expires_on.replace(tzinfo=timezone.utc)

        now = datetime.now(tz=timezone.utc)
        time_until_expiration = expires_on - now
        return time_until_expiration.days <= self.options.cert_expiry_threshold_days

    def find_cert_names_needing_renewal(self) -> typing.List[str]:
        """
        Find all the names of the certificates from Key Vault that need to be renewed
        """
        certificates = self._cert_client.list_properties_of_certificates()
        return [
            certificate.name for certificate in certificates
            if self._cert_needs_renewal(certificate)
        ]

    def find_certs_needing_renewal(self) -> typing.List[CertRequest]:
        """
        Find all the certificates from Key Vault that need to be renewed

        Certificates whose request document is missing from blob storage are
        logged and skipped; any other storage error is re-raised.
        """
        cert_requests: typing.List[CertRequest] = []
        for cert_name in self.find_cert_names_needing_renewal():
            blob_client = self._container_client.get_blob_client(
                _cert_name_to_blob_name(cert_name))
            try:
                raw_request = blob_client.download_blob().readall()
            except HttpResponseError as e:
                if e.status_code != 404:
                    raise
                logging.error(
                    "Could not find cert request document for certificate %s",
                    cert_name,
                )
                continue

            cert_requests.append(CertRequest.parse_raw(raw_request))

        return cert_requests
# Example no. 2
# 0
def create_or_update_cert(
        kv_cert_name,
        *domains,
        use_prod=False,
        keyvault_url='https://ponti-certs-kvjwxwal2p6n.vault.azure.net/',
        dns_zone_resource_group='damienpontifex.com-rg',
        dns_zone_name='damienpontifex.com',
        registration_email='*****@*****.**',
        dns_subscription_id="fa2cbf67-1293-4f9c-8884-a0379a9e0c64"):
    """Issue a Let's Encrypt certificate and merge it into Azure Key Vault.

    The CSR comes from a Key Vault certificate operation (new, or an already
    existing one for ``kv_cert_name``), the order is placed with the
    Let's Encrypt production or staging directory, and the issued chain is
    merged back into the Key Vault certificate.

    Args:
        kv_cert_name: Name of the certificate in Key Vault.
        *domains: Domain names; the first becomes the subject CN, the rest
            become SAN DNS names. At least one is required.
        use_prod: Use the production directory (3 month validity, account
            key ``acme``) instead of staging (1 month, ``acme-staging``).
        keyvault_url: URL of the Key Vault holding the certificate and the
            ACME account key.
        dns_zone_resource_group: Resource group of the DNS zone, passed to
            ``dns_challenge_handler``.
        dns_zone_name: DNS zone name, passed to ``dns_challenge_handler``.
        registration_email: Contact e-mail for the ACME account.
        dns_subscription_id: Subscription of the DNS zone, passed to
            ``dns_challenge_handler``.

    Raises:
        ValueError: If no domain is given.
    """
    # Fail before any account registration or Key Vault call is made.
    if not domains:
        raise ValueError(
            'create_or_update_cert requires at least one domain name')

    # Get directory
    if use_prod:
        directory_url = 'https://acme-v02.api.letsencrypt.org/directory'
        user_key_name = 'acme'
        issuance_period_months = 3
    else:
        directory_url = 'https://acme-staging-v02.api.letsencrypt.org/directory'
        user_key_name = 'acme-staging'
        issuance_period_months = 1

    credential = DefaultAzureCredential()

    challenge_handler = functools.partial(
        dns_challenge_handler,
        credential=credential,
        subscription_id=dns_subscription_id,
        dns_zone_resource_group=dns_zone_resource_group,
        dns_zone_name=dns_zone_name)

    cert_client = CertificateClient(vault_url=keyvault_url,
                                    credential=credential)

    # NOTE(review): KeyVaultRSAKey is defined elsewhere; presumably an RSA
    # key whose private operations are performed by Key Vault - confirm.
    key = KeyVaultRSAKey(credential, keyvault_url, user_key_name)

    account_key = josepy.JWKRSA(key=key)
    client_network = acme.client.ClientNetwork(account_key)

    directory = messages.Directory.from_json(
        client_network.get(directory_url).json())

    client = acme.client.ClientV2(directory, client_network)

    new_regr = acme.messages.Registration.from_data(
        key=account_key,
        email=registration_email,
        terms_of_service_agreed=True)

    # Register or fetch account
    try:
        regr = client.new_account(new_regr)
        logger.info('Created new account')
    except acme.errors.ConflictError as e:
        # The account key is already registered: look the account up at the
        # location reported by the server.
        regr = acme.messages.RegistrationResource(uri=e.location,
                                                  body=new_regr)
        regr = client.query_registration(regr)
        logger.info('Got existing account')

    cert_policy = CertificatePolicy(
        issuer_name='Unknown',
        subject_name=f'CN={domains[0]}',
        exportable=True,
        key_type=KeyType.rsa,
        key_size=2048,
        content_type=CertificateContentType.pkcs12,
        san_dns_names=list(domains[1:]),
        validity_in_months=issuance_period_months)

    try:
        # Reuse an existing certificate operation if there is one. This call
        # returns the operation itself (not a poller), so there is nothing to
        # wait on.
        cert_op = cert_client.get_certificate_operation(
            certificate_name=kv_cert_name)
        logger.info('Existing cert operation in progress')
    except ResourceNotFoundError:
        create_poller = cert_client.begin_create_certificate(
            certificate_name=kv_cert_name, policy=cert_policy)
        logger.info('New cert operation')
        # Wait for the create call to settle, then read the operation (and
        # its CSR) back from Key Vault.
        create_poller.result()
        cert_op = cert_client.get_certificate_operation(kv_cert_name)

    logger.info('Created certificate request in key vault')

    # Wrap with header and footer for pem to show certificate request
    csr_pem = "-----BEGIN CERTIFICATE REQUEST-----\n" + base64.b64encode(
        cert_op.csr).decode() + "\n-----END CERTIFICATE REQUEST-----\n"

    # Submit order
    order_resource = client.new_order(csr_pem)
    logger.info('Submitted order')

    # Challenges from order
    # NOTE(review): dns_challenge_handler is defined elsewhere; presumably it
    # publishes the DNS records and yields the challenges to answer. The
    # result is materialised first so the handler runs to completion before
    # any challenge is answered.
    challenges_to_respond_to = list(
        challenge_handler(authorizations=order_resource.authorizations,
                          account_key=account_key))

    # Respond to challenges
    for dns_challenge in challenges_to_respond_to:
        client.answer_challenge(
            dns_challenge, dns_challenge.chall.response(account_key))

    logger.info('Answered challenges')

    # Poll for status
    # Finalize order
    # Download certificate
    final_order = client.poll_and_finalize(order_resource)

    logger.info('Finalised order')

    # Strip the BEGIN/END CERTIFICATE markers and newlines, leaving one
    # base64 payload (as bytes) per certificate in the chain.
    certificate_vals = [
        val.replace('\n', '').encode()
        for val in final_order.fullchain_pem.split('-----')
        if 'CERTIFICATE' not in val and len(val.replace('\n', '')) != 0
    ]

    cert_client.merge_certificate(name=kv_cert_name,
                                  x509_certificates=certificate_vals)

    logger.info('Merged certificate back to key vault')