コード例 #1
0
 def load(self) -> None:
     """Load the CA certificate and, unless CA-only, a certificate/key pair.

     While holding the load lock, the CA certificate is read from the file
     given by the ``ca_cert`` argument if set, otherwise fetched via
     ``get_ca_cert``. It is then written to the CA cert path as a bundle
     (with ``include_certifi=True``).

     Unless ``ca_only`` is set, the certificate and key are read from the
     ``cert`` / ``cert_key`` arguments when both are given (using the
     optional ``cert_key_pass``), otherwise requested via
     ``get_signed_cert``. Both are then written to their file paths.

     Finally the ``loaded`` event is set.

     Raises:
         Any exception raised by ``get_ca_cert`` is logged and re-raised.
     """
     with self.__load_lock:
         cli_args = ArgumentParser.args

         # --- CA certificate -------------------------------------------
         ca_cert_file = getattr(cli_args, "ca_cert", None)
         if ca_cert_file is None:
             log.debug("Loading CA cert from core")
             try:
                 self.__ca_cert = get_ca_cert(
                     resotocore_uri=self.__resotocore_uri,
                     psk=self.__psk,
                 )
             except FingerprintError as err:
                 log.fatal(f"{err}, MITM attack?")
                 raise
             except InvalidSignatureError as err:
                 log.fatal(f"{err}, wrong PSK?")
                 raise
             except NoJWTError as err:
                 log.fatal(f"{err}, resotocore started without PSK?")
                 raise
             except Exception as err:
                 log.fatal(f"{err}")
                 raise
         else:
             log.debug(f"Loading CA certificate from {ca_cert_file}")
             self.__ca_cert = load_cert_from_file(ca_cert_file)

         log.debug(f"Writing CA cert {self.__ca_cert_path}")
         write_ca_bundle(
             self.__ca_cert, self.__ca_cert_path, include_certifi=True
         )

         # --- Certificate and key (skipped in CA-only mode) ------------
         if not self.__ca_only:
             cert_file = getattr(cli_args, "cert", None)
             key_file = getattr(cli_args, "cert_key", None)
             if cert_file is None or key_file is None:
                 # A user-supplied pair needs both cert and key; otherwise
                 # a signed cert is requested instead.
                 log.debug("Requesting signed cert from core")
                 self.__key, self.__cert = get_signed_cert(
                     common_name=self.common_name,
                     san_dns_names=self.san_dns_names,
                     san_ip_addresses=self.san_ip_addresses,
                     resotocore_uri=self.__resotocore_uri,
                     psk=self.__psk,
                     ca_cert_path=self.ca_cert_path,
                 )
             else:
                 log.debug(f"Loading certificate from {cert_file}")
                 self.__cert = load_cert_from_file(cert_file)
                 key_passphrase = getattr(cli_args, "cert_key_pass", None)
                 log.debug(f"Loading key from {key_file}")
                 self.__key = load_key_from_file(
                     key_file, passphrase=key_passphrase
                 )
             log.debug(f"Writing signed cert/key {self.__cert_path}")
             write_cert_to_file(self.__cert, self.__cert_path)
             write_key_to_file(self.__key, self.__key_path)

         self.__loaded.set()
コード例 #2
0
 def __refresh_files_on_disk(self, refresh_every_sec: int = 10800) -> None:
     """Rewrite the CA bundle (and cert/key) on disk once they are too old.

     Does nothing until the ``loaded`` event is set. The age of each file
     is taken from its modification time; when any checked file is older
     than ``refresh_every_sec`` seconds, the in-memory CA cert, cert and
     key are written to their paths again.

     In CA-only mode no cert/key files are written by ``load``, so only
     the CA bundle is checked and rewritten here.

     Args:
         refresh_every_sec: Maximum file age in seconds before a rewrite.
     """
     if not self.__loaded.is_set():
         return
     try:
         now = time.time()
         ca_cert_age = now - os.path.getmtime(self.__ca_cert_path)
         if self.__ca_only:
             # No cert file exists in CA-only mode; checking its mtime
             # would raise FileNotFoundError and skip the CA refresh.
             cert_age = 0.0
         else:
             cert_age = now - os.path.getmtime(self.__cert_path)

         if ca_cert_age > refresh_every_sec or cert_age > refresh_every_sec:
             log.debug("Refreshing cert/key files on disk")
             write_ca_bundle(self.__ca_cert,
                             self.__ca_cert_path,
                             include_certifi=True)
             if not self.__ca_only:
                 write_cert_to_file(self.__cert, self.__cert_path)
                 write_key_to_file(self.__key, self.__key_path)
     except FileNotFoundError:
         # Best effort: a missing file is ignored rather than treated as
         # an error (behavior kept from the original implementation).
         pass
コード例 #3
0
def test_x509():
    """Round-trip keys, a CSR and certificates through the file helpers.

    A CA and a CA-signed certificate are created, written to a temporary
    directory and loaded back. The loaded objects must compare equal to the
    originals by value, by fingerprint (certificates) and by their byte
    representation (keys).
    """
    with tempfile.TemporaryDirectory() as tmp:
        # All file locations used by this test.
        ca_key_path = os.path.join(tmp, "ca.key")
        ca_cert_path = os.path.join(tmp, "ca.crt")
        cert_key_path = os.path.join(tmp, "cert.key")
        cert_csr_path = os.path.join(tmp, "cert.csr")
        cert_crt_path = os.path.join(tmp, "cert.crt")
        cert_key_passphrase = "foobar"

        # Create the CA plus a signed certificate.
        ca_key, ca_cert = bootstrap_ca()
        cert_key = gen_rsa_key()
        gen_csr(cert_key)  # dummy call to generate CSR without SANs
        cert_csr = gen_csr(
            cert_key,
            san_dns_names=["example.com"],
            san_ip_addresses=["10.0.1.1", "10.0.0.0/24"],
        )
        cert_crt = sign_csr(cert_csr, ca_key, ca_cert)

        # Persist everything; only the cert key is passphrase protected.
        write_key_to_file(ca_key, key_path=ca_key_path)
        write_cert_to_file(ca_cert, cert_path=ca_cert_path)
        write_key_to_file(
            cert_key, key_path=cert_key_path, passphrase=cert_key_passphrase
        )
        write_csr_to_file(cert_csr, csr_path=cert_csr_path)
        write_cert_to_file(cert_crt, cert_path=cert_crt_path)

        # Read everything back.
        loaded_ca_key = load_key_from_file(ca_key_path)
        loaded_ca_cert = load_cert_from_file(ca_cert_path)
        loaded_cert_key = load_key_from_file(
            cert_key_path, passphrase=cert_key_passphrase
        )
        loaded_cert_csr = load_csr_from_file(cert_csr_path)
        loaded_cert_crt = load_cert_from_file(cert_crt_path)

        # Object equality.
        assert loaded_ca_cert == ca_cert
        assert loaded_cert_csr == cert_csr
        assert loaded_cert_crt == cert_crt
        # Fingerprint equality.
        assert cert_fingerprint(loaded_ca_cert) == cert_fingerprint(ca_cert)
        assert cert_fingerprint(loaded_cert_crt) == cert_fingerprint(cert_crt)
        # Key material equality.
        assert key_to_bytes(ca_key) == key_to_bytes(loaded_ca_key)
        assert key_to_bytes(cert_key) == key_to_bytes(loaded_cert_key)
コード例 #4
0
def disabled_test_secure_web():
    """Serve ``/health`` over HTTPS with a CA-signed cert and fetch it.

    A CA and a certificate for ``localhost`` are generated and written to a
    temporary directory. A ``WebServer`` is started on a free port with that
    certificate, and the health endpoint is requested with the CA cert as
    the trust anchor.

    The server is shut down in a ``finally`` block so that a failing request
    or assertion does not leave it running.

    Raises:
        RuntimeError: If the web server is not serving within 10 seconds.
    """
    with tempfile.TemporaryDirectory() as tmp:
        ca_key, ca_cert = bootstrap_ca()
        cert_key = gen_rsa_key()
        cert_csr = gen_csr(cert_key, common_name="localhost")
        cert_crt = sign_csr(cert_csr, ca_key, ca_cert)
        ca_cert_path = os.path.join(tmp, "ca.crt")
        cert_key_path = os.path.join(tmp, "cert.key")
        cert_crt_path = os.path.join(tmp, "cert.crt")

        write_cert_to_file(ca_cert, cert_path=ca_cert_path)
        write_key_to_file(cert_key, key_path=cert_key_path)
        write_cert_to_file(cert_crt, cert_path=cert_crt_path)

        free_port = get_free_port()
        print(f"Starting https webserver on port {free_port}")
        web_server = WebServer(WebApp(),
                               web_port=free_port,
                               ssl_cert=cert_crt_path,
                               ssl_key=cert_key_path)
        web_server.daemon = True
        web_server.start()
        start_time = time.time()
        while not web_server.serving:
            if time.time() - start_time > 10:
                raise RuntimeError("timeout waiting for web server start")
            time.sleep(0.1)

        try:
            endpoint = f"https://localhost:{free_port}"
            r = requests.get(f"{endpoint}/health", verify=ca_cert_path)
            assert r.text == "ok\r\n"
        finally:
            # Always stop the server once it is serving, even when the
            # request or the assertion above fails.
            web_server.shutdown()
            while web_server.is_alive():
                print("Waiting for web server to shutdown")
                time.sleep(1)
        # NOTE(review): bare attribute access without a call; it has no
        # visible effect here. Possibly a truncated call - confirm.
        cherrypy.engine
コード例 #5
0
 def __create_host_context(
     config: CoreConfig, host_cert: Certificate, host_key: RSAPrivateKey
 ) -> Optional[SSLContext]:
     """Build the server-side SSL context, or return None when TLS is off.

     Args:
         config: Configuration whose ``args`` carry the TLS related flags
             (``no_tls``, ``cert``, ``cert_key``, ``cert_key_pass``,
             ``ca_cert_key_pass``).
         host_cert: Certificate used when no certificate was passed on the
             command line.
         host_key: Private key belonging to ``host_cert``.

     Returns:
         ``None`` if ``no_tls`` is set, otherwise a context created with
         ``Purpose.CLIENT_AUTH`` that has a certificate chain loaded.
     """
     args = config.args
     if args.no_tls:
         log.info("TLS disabled.")
         return None

     # noinspection PyTypeChecker
     ssl_ctx = create_default_context(purpose=Purpose.CLIENT_AUTH)

     if args.cert:
         # Use the certificate provided via cmd line flags
         log.info("Using TLS certificate from command line.")
         ssl_ctx.load_cert_chain(args.cert, args.cert_key, args.cert_key_pass)
         return ssl_ctx

     log.info("Using TLS certificate from data store.")
     # The ssl library loads cert/key from files, so both are written to a
     # temporary directory that only needs to exist while they are loaded.
     with TemporaryDirectory() as tmp_dir:
         cert_path = str(Path(tmp_dir, "cert"))
         key_path = str(Path(tmp_dir, "key"))
         write_cert_to_file(host_cert, cert_path)
         write_key_to_file(host_key, key_path)
         ssl_ctx.load_cert_chain(cert_path, key_path, args.ca_cert_key_pass)
     return ssl_ctx