Example #1
0
    def test_exception_when_processing_jobs(self):
        """An exception raised while processing completed jobs is reported as a BUG_IN_SSLYZE error."""
        # Given a server to scan, where one of the scan commands will trigger an exception
        # when processing the completed scan jobs
        failing_command = ScanCommandForTests.MOCK_COMMAND_EXCEPTION_WHEN_PROCESSING_JOBS
        server_scan = ServerScanRequest(
            server_info=ServerConnectivityInfoFactory.create(),
            scan_commands={ScanCommandForTests.MOCK_COMMAND_1, failing_command},
        )

        # When queuing the scan
        scanner = Scanner()
        scanner.queue_scan(server_scan)

        # It succeeds
        collected_results = []
        for scan_result in scanner.get_results():
            collected_results.append(scan_result)

            # And the result mirrors the request, with one successful command
            assert scan_result.server_info == server_scan.server_info
            assert scan_result.scan_commands == server_scan.scan_commands
            assert scan_result.scan_commands_extra_arguments == server_scan.scan_commands_extra_arguments
            assert len(scan_result.scan_commands_results) == 1

            # And the exception was properly caught and returned for the failing command
            assert len(scan_result.scan_commands_errors) == 1
            command_error = scan_result.scan_commands_errors[failing_command]
            assert ScanCommandErrorReasonEnum.BUG_IN_SSLYZE == command_error.reason
            assert command_error.exception_trace

        assert len(collected_results) == 1
Example #2
0
    def test(self, mock_scan_commands):
        """Queuing a scan with two mock commands yields a single result holding both command results."""
        # Given a server to scan with two scan commands
        first_command = ScanCommandForTests.MOCK_COMMAND_1
        second_command = ScanCommandForTests.MOCK_COMMAND_2
        server_scan = ServerScanRequest(
            server_info=ServerConnectivityInfoFactory.create(),
            scan_commands={first_command, second_command},
        )

        # When queuing the scan
        scanner = Scanner()
        scanner.queue_scan(server_scan)

        # It succeeds
        received_results = []
        for server_result in scanner.get_results():
            received_results.append(server_result)

            # And the right result is returned
            assert server_result.server_info == server_scan.server_info
            assert server_result.scan_commands == server_scan.scan_commands
            assert server_result.scan_commands_extra_arguments == server_scan.scan_commands_extra_arguments

            results_by_command = server_result.scan_commands_results
            assert len(results_by_command) == 2
            assert type(results_by_command[first_command]) == MockPlugin1ScanResult
            assert type(results_by_command[second_command]) == MockPlugin2ScanResult

        assert len(received_results) == 1
Example #3
0
    def test_error_server_connectivity_issue_handshake_timeout(self, mock_scan_commands):
        """A TLS handshake timeout in one scan command is reported as a CONNECTIVITY_ISSUE error."""
        # Given a server to scan with some commands
        server_scan = ServerScanRequest(
            server_info=ServerConnectivityInfoFactory.create(),
            scan_commands={ScanCommandForTests.MOCK_COMMAND_1, ScanCommandForTests.MOCK_COMMAND_2},
        )

        # And the first scan command will trigger a handshake timeout with the server
        with mock.patch.object(
            MockPlugin1Implementation,
            "_scan_job_work_function",
            side_effect=TlsHandshakeTimedOut(
                server_location=server_scan.server_info.server_location,
                network_configuration=server_scan.server_info.network_configuration,
                error_message="error",
            ),
        ):
            # When queuing the scan
            scanner = Scanner()
            scanner.queue_scan(server_scan)

        # It succeeds
        all_results = []
        for result in scanner.get_results():
            all_results.append(result)

            # And the error was properly caught and returned
            assert len(result.scan_commands_errors) == 1
            error = result.scan_commands_errors[ScanCommandForTests.MOCK_COMMAND_1]
            assert ScanCommandErrorReasonEnum.CONNECTIVITY_ISSUE == error.reason
            assert error.exception_trace

        # Exactly one result was returned. Without this check the assertions inside the loop
        # would be silently skipped (and the test would pass) if get_results() yielded nothing.
        assert len(all_results) == 1
Example #4
0
    def test_with_extra_arguments(self, mock_scan_commands):
        """Extra arguments supplied for a scan command are carried through to the scan result."""
        # Given a server to scan with a single scan command that takes an extra argument
        server_scan = ServerScanRequest(
            server_info=ServerConnectivityInfoFactory.create(),
            scan_commands={ScanCommandForTests.MOCK_COMMAND_1},
            scan_commands_extra_arguments={
                ScanCommandForTests.MOCK_COMMAND_1: MockPlugin1ExtraArguments(extra_field="test"),
            },
        )

        # When running the scan
        scanner = Scanner()
        scanner.start_scans([server_scan])

        # It succeeds with a single result
        all_results = list(scanner.get_results())
        assert len(all_results) == 1

        # And the extra argument was taken into account
        only_result = all_results[0]
        assert only_result.scan_commands_extra_arguments == server_scan.scan_commands_extra_arguments
Example #5
0
    def test(self, mock_scan_commands):
        """Queued scans honor the scanner's network limits and are spread evenly across its thread pools."""
        # Given a lot of servers to scan
        scans_count = 100
        server_scans = []
        for _ in range(scans_count):
            server_scans.append(
                ServerScanRequest(
                    server_info=ServerConnectivityInfoFactory.create(),
                    scan_commands={ScanCommandForTests.MOCK_COMMAND_1, ScanCommandForTests.MOCK_COMMAND_2},
                )
            )

        # And a scanner with specifically chosen network settings
        connections_per_server = 4
        concurrent_scans = 20
        scanner = Scanner(connections_per_server, concurrent_scans)

        # When queuing the scans, it succeeds
        for request in server_scans:
            scanner.queue_scan(request)

        # And the right number of scans was queued
        assert scans_count == len(scanner._queued_server_scans)

        # And the chosen network settings were used
        assert concurrent_scans == len(scanner._thread_pools)
        for thread_pool in scanner._thread_pools:
            assert connections_per_server == thread_pool._max_workers

        # And the server scans were evenly distributed among the thread pools to maximize performance
        expected_scans_per_pool = scans_count // concurrent_scans
        scans_per_pool = Counter(
            queued_scan.queued_on_thread_pool_at_index for queued_scan in scanner._queued_server_scans
        )
        for count in scans_per_pool.values():
            assert expected_scans_per_pool == count
Example #6
0
    def test(self, mock_scan_commands):
        """start_scans() produces one complete result and leaves the scanner idle afterwards."""
        # Given a server to scan with two scan commands
        server_scan = ServerScanRequest(
            server_info=ServerConnectivityInfoFactory.create(),
            scan_commands={ScanCommandForTests.MOCK_COMMAND_1, ScanCommandForTests.MOCK_COMMAND_2},
        )

        # When running the scan
        scanner = Scanner()
        scanner.start_scans([server_scan])

        # It succeeds with a single result
        all_results = list(scanner.get_results())
        assert len(all_results) == 1

        # And the right result is returned
        result = all_results[0]
        assert result.server_info == server_scan.server_info
        assert result.scan_commands == server_scan.scan_commands
        assert result.scan_commands_extra_arguments == server_scan.scan_commands_extra_arguments
        assert len(result.scan_commands_results) == 2

        expected_result_types = {
            ScanCommandForTests.MOCK_COMMAND_1: MockPlugin1ScanResult,
            ScanCommandForTests.MOCK_COMMAND_2: MockPlugin2ScanResult,
        }
        for command, expected_type in expected_result_types.items():
            assert type(result.scan_commands_results[command]) == expected_type

        # And the Scanner instance is all done and cleaned up
        assert not scanner._are_server_scans_ongoing
Example #7
0
    def test_with_extra_arguments(self):
        """Extra arguments passed with a queued scan are echoed back in its result."""
        # Given a server to scan, with an extra argument for its only command
        server_scan = ServerScanRequest(
            server_info=ServerConnectivityInfoFactory.create(),
            scan_commands={ScanCommandForTests.MOCK_COMMAND_1},
            scan_commands_extra_arguments={
                ScanCommandForTests.MOCK_COMMAND_1: MockPlugin1ExtraArguments(extra_field="test"),
            },
        )

        # When queuing the scan
        scanner = Scanner()
        scanner.queue_scan(server_scan)

        # It succeeds, and every result reflects the extra argument
        results_seen = 0
        for scan_result in scanner.get_results():
            results_seen += 1
            assert scan_result.scan_commands_extra_arguments == server_scan.scan_commands_extra_arguments

        assert results_seen == 1
Example #8
0
    def test_error_bug_in_sslyze_when_processing_job_results(self, mock_scan_commands):
        """A RuntimeError in a scan job's work function is reported as a BUG_IN_SSLYZE error."""
        # Given a server to scan with some scan commands
        server_scan = ServerScanRequest(
            server_info=ServerConnectivityInfoFactory.create(),
            scan_commands={ScanCommandForTests.MOCK_COMMAND_1, ScanCommandForTests.MOCK_COMMAND_2},
        )

        # And the first scan command's job work function raises a RuntimeError
        failing_work_function = mock.patch.object(
            MockPlugin1Implementation, "_scan_job_work_function", side_effect=RuntimeError
        )
        with failing_work_function:
            # When running the scan
            scanner = Scanner()
            scanner.start_scans([server_scan])

            # It succeeds with a single result
            all_results = list(scanner.get_results())
            assert len(all_results) == 1

            # And the exception was properly caught and returned
            command_errors = all_results[0].scan_commands_errors
            assert len(command_errors) == 1
            failure = command_errors[ScanCommandForTests.MOCK_COMMAND_1]
            assert ScanCommandErrorReasonEnum.BUG_IN_SSLYZE == failure.reason
            assert failure.exception_trace
Example #9
0
    def test_error_client_certificate_needed(self):
        """Scanning a server that requires client auth, without a client cert, yields CLIENT_CERTIFICATE_NEEDED."""
        # Given a server that requires client authentication
        with LegacyOpenSslServer(client_auth_config=ClientAuthConfigEnum.REQUIRED) as server:
            # And sslyze does NOT provide a client certificate
            location = ServerNetworkLocationViaDirectConnection(
                hostname=server.hostname, ip_address=server.ip_address, port=server.port
            )
            server_info = ServerConnectivityTester().perform(location)

            # And a scan command that cannot be completed without a client certificate
            server_scan = ServerScanRequest(server_info=server_info, scan_commands={ScanCommand.HTTP_HEADERS})

            # When queuing the scan
            scanner = Scanner()
            scanner.queue_scan(server_scan)

            # It succeeds with a single result
            all_results = list(scanner.get_results())
            assert len(all_results) == 1

            # And the error was properly returned
            http_headers_error = all_results[0].scan_commands_errors[ScanCommand.HTTP_HEADERS]
            assert http_headers_error.reason == ScanCommandErrorReasonEnum.CLIENT_CERTIFICATE_NEEDED
Example #10
0
def main() -> None:
    """Command-line entry point.

    Parses the command line, tests connectivity to every target, queues a scan on the global
    scanner for each reachable server, and reports results through the OutputHub.
    """
    global global_scanner

    # Required for py2exe (frozen) builds
    freeze_support()

    # Handle SIGINT to terminate processes
    signal.signal(signal.SIGINT, sigint_handler)
    start_time = time()

    # Parse the command line; a parsing error is printed and ends the run
    sslyze_parser = CommandLineParser(__version__)
    try:
        parsed_command_line = sslyze_parser.parse_command_line()
    except CommandLineParsingError as e:
        print(e.get_error_msg())
        return

    output_hub = OutputHub()
    output_hub.command_line_parsed(parsed_command_line)

    global_scanner = Scanner(
        per_server_concurrent_connections_limit=parsed_command_line.per_server_concurrent_connections_limit,
        concurrent_server_scans_limit=parsed_command_line.concurrent_server_scans_limit,
    )

    # Test every target concurrently; each reachable server gets its scan queued right away
    connectivity_tester = ServerConnectivityTester()
    with ThreadPoolExecutor(max_workers=10) as thread_pool:
        pending_tests = [
            thread_pool.submit(connectivity_tester.perform, location, network_config)
            for location, network_config in parsed_command_line.servers_to_scans
        ]
        for finished_test in as_completed(pending_tests):
            try:
                server_info = finished_test.result()
                output_hub.server_connectivity_test_succeeded(server_info)
                global_scanner.queue_scan(
                    ServerScanRequest(
                        server_info=server_info,
                        scan_commands=parsed_command_line.scan_commands,
                        scan_commands_extra_arguments=parsed_command_line.scan_commands_extra_arguments,
                    )
                )
            except ConnectionToServerFailed as e:
                output_hub.server_connectivity_test_failed(e)

    output_hub.scans_started()

    # Report each server's results as they become available
    for scan_result in global_scanner.get_results():
        output_hub.server_scan_completed(scan_result)

    # All done: report the total elapsed time
    output_hub.scans_completed(time() - start_time)
Example #11
0
def main() -> None:
    """Scan a few well-known servers and print a summary of selected scan command results.

    Connectivity to every hostname is tested first. If any of them cannot be reached, the error
    is printed and the function returns without scanning anything.
    """
    cipher_suite_commands = {
        ScanCommandEnum.TLS_1_0_CIPHER_SUITES,
        ScanCommandEnum.TLS_1_1_CIPHER_SUITES,
        ScanCommandEnum.TLS_1_2_CIPHER_SUITES,
    }
    all_scan_commands = cipher_suite_commands | {
        ScanCommandEnum.CERTIFICATE_INFO,
        ScanCommandEnum.TLS_COMPRESSION,
    }

    # First validate that we can connect to the servers we want to scan
    servers_to_scan = []
    for hostname in ["cloudflare.com", "google.com"]:
        server_location = ServerNetworkLocationViaDirectConnection.with_ip_address_lookup(hostname, 443)
        try:
            server_info = ServerConnectivityTester().perform(server_location)
            servers_to_scan.append(server_info)
        except ConnectionToServerFailed as e:
            print(f"Error connecting to {server_location.hostname}:{server_location.port}: {e.error_message}")
            return

    scanner = Scanner()

    # Then queue the same scan commands for each server (each request gets its own set)
    for server_info in servers_to_scan:
        server_scan_req = ServerScanRequest(server_info=server_info, scan_commands=set(all_scan_commands))
        scanner.queue_scan(server_scan_req)

    # Then retrieve the result of the scan commands for each server
    for server_scan_result in scanner.get_results():
        print(f"\nResults for {server_scan_result.server_info.server_location.hostname}:")

        # Scan commands that were run with no errors
        for scan_command, result in server_scan_result.scan_commands_results.items():
            if scan_command in cipher_suite_commands:
                typed_result = cast(CipherSuitesScanResult, result)
                print(f"\nAccepted cipher suites for {scan_command.name}:")
                for accepted_cipher_suite in typed_result.accepted_cipher_suites:
                    print(f"* {accepted_cipher_suite.cipher_suite.name}")

            elif scan_command == ScanCommandEnum.CERTIFICATE_INFO:
                typed_result = cast(CertificateInfoScanResult, result)
                print("\nCertificate info:")
                for cert_deployment in typed_result.certificate_deployments:
                    verified_chain_pem = cert_deployment.verified_certificate_chain_as_pem
                    # NOTE(review): the verified chain can be missing (e.g. when the chain cannot be
                    # validated); indexing it unconditionally would crash the whole report.
                    if verified_chain_pem:
                        print(f"Leaf certificate: \n{verified_chain_pem[0]}")
                    else:
                        print("Leaf certificate: no verified certificate chain available")

            elif scan_command == ScanCommandEnum.TLS_COMPRESSION:
                typed_result = cast(CompressionScanResult, result)
                print(f"\nCompression / CRIME: {typed_result.supports_compression}")

        # Scan commands that were run with errors
        for scan_command, error in server_scan_result.scan_commands_errors.items():
            print(f"\nError when running {scan_command}:\n{error.exception_trace}")
Example #12
0
def main() -> None:
    """Command-line entry point.

    Parses the command line, tests which servers are reachable, then runs all the scans in one
    batch and reports each result through the OutputHub.
    """
    start_time = time()

    # Build the command line parser; a parsing error is printed and ends the run
    sslyze_parser = CommandLineParser(__version__)
    try:
        parsed_command_line = sslyze_parser.parse_command_line()
    except CommandLineParsingError as e:
        print(e.get_error_msg())
        return

    output_hub = OutputHub()
    output_hub.command_line_parsed(parsed_command_line)

    # Test connectivity to every target concurrently, keeping a scan request for each server that is online
    connectivity_tester = ServerConnectivityTester()
    scan_requests = []
    with ThreadPoolExecutor(max_workers=10) as thread_pool:
        connectivity_futures = [
            thread_pool.submit(connectivity_tester.perform, location, network_config)
            for location, network_config in parsed_command_line.servers_to_scans
        ]
        for done_future in as_completed(connectivity_futures):
            try:
                server_info = done_future.result()
                output_hub.server_connectivity_test_succeeded(server_info)

                # The server is online: schedule it for scanning
                scan_requests.append(
                    ServerScanRequest(
                        server_info=server_info,
                        scan_commands=parsed_command_line.scan_commands,
                        scan_commands_extra_arguments=parsed_command_line.scan_commands_extra_arguments,
                    )
                )
            except ConnectionToServerFailed as e:
                output_hub.server_connectivity_test_failed(e)

    output_hub.scans_started()

    # A scanner is only created when at least one server was reachable
    if scan_requests:
        scanner = Scanner(
            per_server_concurrent_connections_limit=parsed_command_line.per_server_concurrent_connections_limit,
            concurrent_server_scans_limit=parsed_command_line.concurrent_server_scans_limit,
        )
        scanner.start_scans(scan_requests)

        # Report each server's results as they become available
        for scan_result in scanner.get_results():
            output_hub.server_scan_completed(scan_result)

    # All done: report the total elapsed time
    output_hub.scans_completed(time() - start_time)
Example #13
0
    def __init__(
        self,
        domain: str,
        target_profile: str,
        ca_file: Optional[str] = None,
        cert_expire_warning: int = 15,
    ) -> None:
        """
        :param domain: The domain name of the server to profile; its IP address is looked up and port 443 is used.
        :param target_profile: One of [old|intermediate|modern]
        :param ca_file: Path to a trusted custom root certificates in PEM format.
        :param cert_expire_warning: A warning is issued if the certificate expires in less days than specified.

        On a failed connectivity test, ``self.server_info`` is set to None and ``self.server_error``
        holds the error message.
        """
        import copy

        self.scan_commands_extra_args = {}
        if ca_file:
            ca_path = Path(ca_file)
            self.scan_commands_extra_args[
                ScanCommand.CERTIFICATE_INFO] = CertificateInfoExtraArguments(
                    ca_path)

        self.cert_expire_warning = cert_expire_warning

        if TLSProfiler.PROFILES is None:
            # Downloaded once and cached on the class for all instances. The timeout (seconds)
            # keeps an unresponsive server from blocking construction indefinitely.
            TLSProfiler.PROFILES = requests.get(self.PROFILES_URL, timeout=30).json()
            log.info(
                f"Loaded version {TLSProfiler.PROFILES['version']} of the Mozilla TLS configuration recommendations."
            )

        # Work on a deep copy of the selected profile: TLSProfiler.PROFILES is shared by every
        # instance, and the keys rewritten below must not leak back into that cache (otherwise each
        # new instance would re-process curves that an earlier instance already converted).
        self.target_profile = copy.deepcopy(
            TLSProfiler.PROFILES["configurations"][target_profile])
        self.target_profile["tls_curves"] = self._get_equivalent_curves(
            self.target_profile["tls_curves"])
        self.target_profile[
            "certificate_curves_preprocessed"] = self._get_equivalent_curves(
                self.target_profile["certificate_curves"])

        server_location = (
            ServerNetworkLocationViaDirectConnection.with_ip_address_lookup(
                domain, 443))
        self.scanner = Scanner()
        try:
            log.info(
                f"Testing connectivity with {server_location.hostname}:{server_location.port}..."
            )
            self.server_info = ServerConnectivityTester().perform(
                server_location)
            self.server_error = None
        except ConnectionToServerFailed as e:
            # Could not establish an SSL connection to the server
            log.warning(
                f"Could not connect to {e.server_location.hostname}: {e.error_message}"
            )
            self.server_error = e.error_message
            self.server_info = None
Example #14
0
    def __init__(
        self,
        domain: str,
        ca_file: Optional[str] = None,
        cert_expire_warning: int = 15,
    ) -> None:
        """
        Set up a profiler for ``domain`` and test connectivity to it on port 443.

        :param domain: the domain name of the target server
        :param ca_file: Path to a trusted custom root certificates in PEM format.
        :param cert_expire_warning: A warning is issued if the certificate expires in less days than specified.

        If the connectivity test fails, ``self._server_info`` is None and ``self.server_error``
        holds the error message; otherwise ``self.server_error`` is None.
        """
        # State that is not computed by the constructor starts out as None
        self._comparator = None
        self._validation_errors = None
        self._vulnerability_errors = None
        self._server_scan_result = None

        # A custom CA file is stored as the extra arguments of the CERTIFICATE_INFO scan command
        extra_args = {}
        if ca_file:
            extra_args[ScanCommand.CERTIFICATE_INFO] = CertificateInfoExtraArguments(Path(ca_file))
        self._scan_commands_extra_args = extra_args

        self._cert_expire_warning = cert_expire_warning

        server_location = ServerNetworkLocationViaDirectConnection.with_ip_address_lookup(domain, 443)
        self._scanner = Scanner()
        try:
            log.info(f"Testing connectivity with {server_location.hostname}:{server_location.port}...")
            self._server_info = ServerConnectivityTester().perform(server_location)
            self.server_error = None
        except ConnectionToServerFailed as e:
            # Could not establish an SSL connection to the server
            log.warning(f"Could not connect to {e.server_location.hostname}: {e.error_message}")
            self.server_error = e.error_message
            self._server_info = None
Example #15
0
    def test_error_bug_in_sslyze_when_scheduling_jobs(self, mock_scan_commands):
        """An exception raised while generating a command's scan jobs is reported as BUG_IN_SSLYZE."""
        # Given a server to scan with some scan commands
        server_scan = ServerScanRequest(
            server_info=ServerConnectivityInfoFactory.create(),
            scan_commands={ScanCommandForTests.MOCK_COMMAND_1, ScanCommandForTests.MOCK_COMMAND_2},
        )

        # And the first scan command will trigger an error when generating scan jobs
        with mock.patch.object(MockPlugin1Implementation, "scan_jobs_for_scan_command", side_effect=RuntimeError):
            # When queuing the scan
            scanner = Scanner()
            scanner.queue_scan(server_scan)

        # It succeeds
        all_results = []
        for result in scanner.get_results():
            all_results.append(result)

            # And the exception was properly caught and returned
            assert len(result.scan_commands_errors) == 1
            error = result.scan_commands_errors[ScanCommandForTests.MOCK_COMMAND_1]
            assert ScanCommandErrorReasonEnum.BUG_IN_SSLYZE == error.reason
            assert error.exception_trace

        # Exactly one result was returned. Without this check the assertions inside the loop
        # would be silently skipped (and the test would pass) if get_results() yielded nothing.
        assert len(all_results) == 1
Example #16
0
    def test_enforces_per_server_concurrent_connections_limit(
            self, mock_scan_commands):
        """With a per-server limit of 1, no two scan jobs for the same server run at the same time.

        Concurrency is detected with a two-party Barrier. Its action only fires if two work-function
        calls are waiting on it at once, and that action records the event in a queue.
        """
        # Given a server to scan with a scan command that requires multiple connections/jobs to the server
        server_scan = ServerScanRequest(
            server_info=ServerConnectivityInfoFactory.create(),
            scan_commands={ScanCommandForTests.MOCK_COMMAND_1},
        )

        # And a scanner configured to only perform one concurrent connection per server scan
        scanner = Scanner(per_server_concurrent_connections_limit=1)

        # And the scan command will notify us when more than one connection is being performed concurrently
        # Test internals: setup plumbing to detect when more than one thread are running at the same time
        # We use a Barrier that waits for 2 concurrent threads, and puts True in a queue if that ever happens
        queue = Queue()

        def flag_concurrent_threads_running():
            # Only called when two threads are running at the same time
            queue.put(True)

        # NOTE(review): with timeout=1, a job waiting alone on the barrier gives up after about one
        # second and the barrier becomes broken, so the wait raises instead of blocking forever.
        # Presumably the scanner records that as a scan command error, which this test
        # deliberately does not inspect. TODO confirm.
        barrier = threading.Barrier(parties=2,
                                    action=flag_concurrent_threads_running,
                                    timeout=1)

        # Replacement work function: each scan job just waits on the shared barrier
        def scan_job_work_function(arg1: str, arg2: int):
            barrier.wait()

        with mock.patch.object(MockPlugin1Implementation,
                               "_scan_job_work_function",
                               scan_job_work_function):
            # When running the scan
            scanner.start_scans([server_scan])

            # It succeeds
            all_results = []
            for result in scanner.get_results():
                all_results.append(result)
            assert len(all_results) == 1

            # And there never was more than one thread (=1 job/connection) running at the same time
            assert queue.empty()
Example #17
0
    def test_emergency_shutdown(self, mock_scan_commands):
        """After emergency_shutdown(), every queued scan job is either finished or cancelled."""
        # Given a lot of servers to scan
        total_server_scans_count = 100
        server_scans = [
            ServerScanRequest(
                server_info=ServerConnectivityInfoFactory.create(),
                scan_commands={ScanCommandForTests.MOCK_COMMAND_1, ScanCommandForTests.MOCK_COMMAND_2},
            )
            for _ in range(total_server_scans_count)
        ]

        # And the scans get queued
        scanner = Scanner()
        for scan in server_scans:
            scanner.queue_scan(scan)

        # When trying to quickly shutdown the scanner, it succeeds
        scanner.emergency_shutdown()

        # And all the queued jobs were done or cancelled (done() is also True for cancelled futures).
        # done() is checked directly rather than through as_completed(): as_completed() only yields
        # futures that have already finished, blocking until they do, so asserting done() on what it
        # yields could never fail.
        all_queued_futures = [
            queued_job
            for server_scan in scanner._queued_server_scans
            for queued_job in server_scan.all_queued_scan_jobs
        ]
        for queued_job in all_queued_futures:
            assert queued_job.done()
Example #18
0
    def test_duplicate_server(self, mock_scan_commands):
        """Queuing a second scan for a server that already has one queued raises ValueError."""
        # Given a server to scan
        server_info = ServerConnectivityInfoFactory.create()

        # And two scan requests that target this same server
        first_request = ServerScanRequest(
            server_info=server_info, scan_commands={ScanCommandForTests.MOCK_COMMAND_1}
        )
        second_request = ServerScanRequest(
            server_info=server_info, scan_commands={ScanCommandForTests.MOCK_COMMAND_2}
        )

        # When the first one is queued
        scanner = Scanner()
        scanner.queue_scan(first_request)

        # Then queuing the second one fails
        with pytest.raises(ValueError):
            scanner.queue_scan(second_request)
Example #19
0
    def run(self):
        """Scan ``self.domain`` with sslyze and summarize the TLS findings as a dict.

        Returns an empty dict when the server cannot be connected to or its name cannot be
        resolved. Otherwise the dict has a "TLS" section (supported versions, accepted and rejected
        cipher suite names), plus the CCS-injection, Heartbleed, signature-algorithm and
        elliptic-curve findings parsed from the scan results. Scan commands that failed are logged.
        """
        try:
            server_info = self.get_server_info()

            # The version's string form is "<Enum>.<MEMBER>"; keep only the member name (e.g. "TLS_1_2")
            highest_tls_supported = str(
                server_info.tls_probing_result.highest_tls_version_supported
            ).split(".")[1]

            tls_supported = self.get_supported_tls(highest_tls_supported)
        except ConnectionToServerFailed as e:
            logging.error(f"Failed to connect to {self.domain}: {e}")
            return {}
        except ServerHostnameCouldNotBeResolved as e:
            logging.error(f"{self.domain} could not be resolved: {e}")
            return {}
        except gaierror as e:
            logging.error(
                f"Could not retrieve address info for {self.domain} {e}")
            return {}

        scanner = Scanner()

        # Scan for common vulnerabilities, certificate info, elliptic curves
        designated_scans = {
            ScanCommand.OPENSSL_CCS_INJECTION,
            ScanCommand.HEARTBLEED,
            ScanCommand.CERTIFICATE_INFO,
            ScanCommand.ELLIPTIC_CURVES,
        }

        # Scan the cipher suites of EVERY supported SSL/TLS version. An if/elif chain would stop at
        # the first match, so a server supporting e.g. TLS 1.2 and 1.3 would have only its TLS 1.2
        # suites scanned.
        cipher_suite_scans = {
            "SSL_2_0": ScanCommand.SSL_2_0_CIPHER_SUITES,
            "SSL_3_0": ScanCommand.SSL_3_0_CIPHER_SUITES,
            "TLS_1_0": ScanCommand.TLS_1_0_CIPHER_SUITES,
            "TLS_1_1": ScanCommand.TLS_1_1_CIPHER_SUITES,
            "TLS_1_2": ScanCommand.TLS_1_2_CIPHER_SUITES,
            "TLS_1_3": ScanCommand.TLS_1_3_CIPHER_SUITES,
        }
        for version_name, scan_command in cipher_suite_scans.items():
            if version_name in tls_supported:
                designated_scans.add(scan_command)

        scan_request = ServerScanRequest(server_info=server_info,
                                         scan_commands=designated_scans)

        scanner.start_scans([scan_request])

        # Wait for asynchronous scans to complete.
        # get_results() yields a single ServerScanResult for the single request; drain it and keep that object
        scan_results = list(scanner.get_results())[0]
        logging.info("Scan results retrieved from generator")

        # Surface failed scan commands instead of silently dropping them
        for failed_command, error in scan_results.scan_commands_errors.items():
            logging.error(
                f"Scan command {failed_command} failed for {self.domain}: {error.reason}")

        res = {
            "TLS": {
                "supported": tls_supported,
                "accepted_cipher_list": [],
                "rejected_cipher_list": [],
            }
        }

        # sslyze occasionally returns ANSI curve names. In at least these two cases they convert
        # directly to the equivalent SECG name, which aligns with CCCS guidance:
        # https://datatracker.ietf.org/doc/html/rfc4492#appendix-A
        ansi_to_secg_curve_names = {
            "prime192v1": "secp192r1",
            "prime256v1": "secp256r1",
        }

        # Parse scan results for required info
        for name, result in scan_results.scan_commands_results.items():

            # If CipherSuitesScanResults
            if name.endswith("suites"):
                logging.info("Parsing Cipher Suite Scan results...")

                for c in result.accepted_cipher_suites:
                    res["TLS"]["accepted_cipher_list"].append(
                        c.cipher_suite.name)

                for c in result.rejected_cipher_suites:
                    res["TLS"]["rejected_cipher_list"].append(
                        c.cipher_suite.name)

            elif name == "openssl_ccs_injection":
                logging.info(
                    "Parsing OpenSSL CCS Injection Vulnerability Scan results..."
                )
                res["is_vulnerable_to_ccs_injection"] = result.is_vulnerable_to_ccs_injection

            elif name == "heartbleed":
                logging.info(
                    "Parsing Heartbleed Vulnerability Scan results...")
                res["is_vulnerable_to_heartbleed"] = result.is_vulnerable_to_heartbleed

            elif name == "certificate_info":
                logging.info("Parsing Certificate Info Scan results...")
                try:
                    res["signature_algorithm"] = (
                        result.certificate_deployments[0].
                        verified_certificate_chain[0].signature_hash_algorithm.
                        __class__.__name__)
                except TypeError:
                    # No verified chain to index into
                    res["signature_algorithm"] = None

            else:
                # Only the elliptic curves scan remains among the designated scans
                logging.info("Parsing Elliptic Curve Scan results...")
                res["supports_ecdh_key_exchange"] = result.supports_ecdh_key_exchange
                res["supported_curves"] = []
                if result.supported_curves is not None:
                    for curve in result.supported_curves:
                        res["supported_curves"].append(
                            ansi_to_secg_curve_names.get(curve.name, curve.name))

        return res
Example #20
0
class TLSProfiler:
    """Scan a server's TLS setup with sslyze and compare it to a TLS profile.

    Constructing an instance resolves ``domain`` (port 443) and runs an sslyze
    connectivity test. :meth:`scan_server` then runs every command in
    ``_ALL_SCAN_COMMANDS``, preprocesses the results and builds a
    ``Comparator``. :meth:`compare_to_profile` uses that comparator to check the
    stored results against a named profile.
    """

    # SCTs are required after this date, see https://groups.google.com/a/chromium.org/forum/#!msg/ct-policy/sz_3W_xKBNY/6jq2ghJXBAAJ
    _SCT_REQUIRED_DATE = datetime(year=2018, month=4, day=1)

    # Protocol label used in the results -> sslyze command that lists the
    # cipher suites accepted for that protocol.
    _SSL_SCAN_COMMANDS = {
        "SSLv2": ScanCommand.SSL_2_0_CIPHER_SUITES,
        "SSLv3": ScanCommand.SSL_3_0_CIPHER_SUITES,
        "TLSv1": ScanCommand.TLS_1_0_CIPHER_SUITES,
        "TLSv1.1": ScanCommand.TLS_1_1_CIPHER_SUITES,
        "TLSv1.2": ScanCommand.TLS_1_2_CIPHER_SUITES,
        "TLSv1.3": ScanCommand.TLS_1_3_CIPHER_SUITES,
    }

    # Every sslyze scan command queued (in one request) by scan_server().
    _ALL_SCAN_COMMANDS = {
        ScanCommand.SSL_2_0_CIPHER_SUITES,
        ScanCommand.SSL_3_0_CIPHER_SUITES,
        ScanCommand.TLS_1_0_CIPHER_SUITES,
        ScanCommand.TLS_1_1_CIPHER_SUITES,
        ScanCommand.TLS_1_2_CIPHER_SUITES,
        ScanCommand.TLS_1_3_CIPHER_SUITES,
        ScanCommand.ELLIPTIC_CURVES,
        ScanCommand.HTTP_HEADERS,
        ScanCommand.CERTIFICATE_INFO,
        ScanCommand.HEARTBLEED,
        ScanCommand.ROBOT,
        ScanCommand.OPENSSL_CCS_INJECTION,
    }

    def __init__(
        self,
        domain: str,
        ca_file: Optional[str] = None,
        cert_expire_warning: int = 15,
    ) -> None:
        """
        :param domain: the domain name of the target server
        :param ca_file: Path to a trusted custom root certificates in PEM format.
        :param cert_expire_warning: A warning is issued if the certificate expires in less days than specified.

        If the connectivity test fails, ``server_error`` holds sslyze's error
        message, ``_server_info`` is None and :meth:`scan_server` does nothing.
        """
        # Filled in by scan_server().
        self._comparator = None
        self._validation_errors = None
        self._vulnerability_errors = None
        self._server_scan_result = None

        # Optional custom trust store for the certificate-info scan.
        self._scan_commands_extra_args = {}
        if ca_file:
            ca_path = Path(ca_file)
            self._scan_commands_extra_args[
                ScanCommand.CERTIFICATE_INFO] = CertificateInfoExtraArguments(
                    ca_path)

        self._cert_expire_warning = cert_expire_warning

        server_location = (
            ServerNetworkLocationViaDirectConnection.with_ip_address_lookup(
                domain, 443))
        self._scanner = Scanner()
        try:
            log.info(
                f"Testing connectivity with {server_location.hostname}:{server_location.port}..."
            )
            self._server_info = ServerConnectivityTester().perform(
                server_location)
            self.server_error = None
        except ConnectionToServerFailed as e:
            # Could not establish an SSL connection to the server
            log.warning(
                f"Could not connect to {e.server_location.hostname}: {e.error_message}"
            )
            self.server_error = e.error_message
            self._server_info = None

    def scan_server(self) -> None:
        """
        This method scan the server's TLS settings and preprocesses
        the results to later compare them to a specific profile.

        Does nothing if the connectivity test in ``__init__`` failed.
        """
        if self._server_info is None:
            return

        # run all scans together
        server_scan_req = ServerScanRequest(
            server_info=self._server_info,
            scan_commands=self._ALL_SCAN_COMMANDS,
            scan_commands_extra_arguments=self._scan_commands_extra_args,
        )
        self._scanner.queue_scan(server_scan_req)

        # We take the first result because only one server was queued
        self._server_scan_result = next(
            self._scanner.get_results())  # type: ServerScanResult

        # preprocess scan results
        (
            supported_ciphers,
            supported_protocols,
            supported_key_exchange,
            server_preferred_order,
        ) = self._preprocess_ciphers_and_protocols()
        supported_ecdh_curves = self._preprocess_ecdh_curves()
        certificate_obj = self._preprocess_certificate()
        hsts_header = self._preprocess_hsts_header()

        # Initialize the comparator class to compare
        # the TLS settings to a specific profile later.
        self._comparator = Comparator(
            supported_ciphers,
            supported_protocols,
            supported_key_exchange,
            supported_ecdh_curves,
            server_preferred_order,
            certificate_obj,
            self._cert_expire_warning,
            hsts_header,
        )

        # Profile-independent checks are computed once here.
        self._validation_errors = self._validate_certificate(certificate_obj)
        self._vulnerability_errors = self._check_vulnerabilities()

    def compare_to_profile(
            self, target_profile_name: str) -> Optional[TLSProfilerResult]:
        """
        Uses the stored scan results from the 'scan_server()' method
        to compare them to a specific Mozilla TLS profile.

        :param target_profile_name: The target Mozilla TLS profile: one of [old|intermediate|modern].
        :return: the comparison result, or None if 'scan_server()' has not
            produced a comparator (not called yet, or the server was unreachable).
        :rtype: Optional[TLSProfilerResult]
        """
        if not self._comparator:
            return None

        # compare the scan results with the target profile
        profile_errors, certificate_warnings = self._comparator.compare(
            target_profile_name)

        return TLSProfilerResult(
            target_profile_name,
            self._validation_errors,
            certificate_warnings,
            profile_errors,
            self._vulnerability_errors,
        )

    def _get_result(self, command: ScanCommandType):
        """Return the sslyze result object for ``command`` from the stored scan.

        NOTE(review): this is a plain subscript on ``scan_commands_results``;
        a command that errored during the scan is presumably absent there and
        would raise KeyError — confirm against sslyze's ServerScanResult.
        """
        return self._server_scan_result.scan_commands_results[command]

    def _preprocess_ciphers_and_protocols(
        self,
    ) -> Tuple[Dict[str, List[str]], Set[str], List[Tuple[
            EphemeralKeyInfo, str]], Dict[str, Optional[CipherSuiteAcceptedByServer]], ]:
        """Collect per-protocol cipher data from the cipher-suite scans.

        :return: a 4-tuple of
            - protocol label -> OpenSSL names of accepted cipher suites,
            - set of protocol labels the server supports,
            - (ephemeral key info, OpenSSL cipher name) for every accepted
              suite that reported an ephemeral key,
            - protocol label -> sslyze's ``cipher_suite_preferred_by_server``.
        """
        supported_ciphers = dict()
        supported_protocols = []
        supported_key_exchange = []
        server_preferred_order = dict()
        for name, command in self._SSL_SCAN_COMMANDS.items():
            log.debug(f"Testing protocol {name}")
            result = self._get_result(command)  # type: CipherSuitesScanResult
            ciphers = [
                cipher.cipher_suite.openssl_name
                for cipher in result.accepted_cipher_suites
            ]
            supported_ciphers[name] = ciphers
            # NOTE: In the newest sslyze version we only get the key
            # exchange parameters for ephemeral key exchanges.
            # We do not get any parameters for finite field DH with
            # static parameters.
            key_exchange = [(cipher.ephemeral_key,
                             cipher.cipher_suite.openssl_name)
                            for cipher in result.accepted_cipher_suites
                            if cipher.ephemeral_key]
            supported_key_exchange.extend(key_exchange)
            server_preferred_order[
                name] = result.cipher_suite_preferred_by_server
            if result.is_tls_protocol_version_supported:
                supported_protocols.append(name)

        return (
            supported_ciphers,
            set(supported_protocols),
            supported_key_exchange,
            server_preferred_order,
        )

    def _preprocess_ecdh_curves(self) -> Set[str]:
        """Return the names of the ECDH curves the server supports.

        Names are expanded with their equivalents via
        ``utils.get_equivalent_curves``. An empty set is returned when the scan
        reports no supported curves.
        """
        ecdh_scan_result = self._get_result(
            ScanCommand.ELLIPTIC_CURVES
        )  # type: SupportedEllipticCurvesScanResult
        # Must be a set (not ``{}``, which is an empty dict) so the return value
        # honours the annotated type even when no curves are supported.
        supported_curves = set()
        if ecdh_scan_result.supported_curves:
            curve_names = [
                curve.name for curve in ecdh_scan_result.supported_curves
            ]
            supported_curves = set(utils.get_equivalent_curves(curve_names))

        return supported_curves

    def _preprocess_certificate(self) -> CertificateDeploymentAnalysisResult:
        """Return the first certificate deployment from the certificate-info scan."""
        result = self._get_result(
            ScanCommand.CERTIFICATE_INFO)  # type: CertificateInfoScanResult

        # TODO if there are multiple certificates analyze all of them
        certificate_obj = result.certificate_deployments[0]

        return certificate_obj

    def _preprocess_hsts_header(self) -> StrictTransportSecurityHeader:
        """Return the Strict-Transport-Security header parsed by the HTTP-headers scan."""
        result = self._get_result(
            ScanCommand.HTTP_HEADERS)  # type: HttpHeadersScanResult

        return result.strict_transport_security_header

    def _validate_certificate(
            self,
            certificate_obj: CertificateDeploymentAnalysisResult) -> List[str]:
        """Run profile-independent validity checks on a certificate deployment.

        :param certificate_obj: the deployment returned by _preprocess_certificate().
        :return: human-readable error messages (empty if the certificate is ok).
        """
        # Leaf certificate as sent by the server.
        certificate = certificate_obj.received_certificate_chain[0]

        validation_errors = []

        for r in certificate_obj.path_validation_results:  # type: PathValidationResult
            if not r.was_validation_successful:
                validation_errors.append(
                    f"Validation not successful: {r.openssl_error_string} (trust store {r.trust_store.name})"
                )

        # TODO check how to implement this with sslyze 3.1.0: the pre-3.x code
        # joined `fail.error_message` for each entry of
        # `path_validation_error_list` into a single "Validation failed: ..." error.

        if not certificate_obj.leaf_certificate_subject_matches_hostname:
            validation_errors.append(
                "Leaf certificate subject does not match hostname!")

        if not certificate_obj.received_chain_has_valid_order:
            validation_errors.append("Certificate chain has wrong order.")

        if certificate_obj.verified_chain_has_sha1_signature:
            validation_errors.append("SHA1 signature found in chain.")

        if certificate_obj.verified_chain_has_legacy_symantec_anchor:
            validation_errors.append(
                "Symantec legacy certificate found in chain.")

        # Certificate transparency is only required for certificates issued
        # on or after _SCT_REQUIRED_DATE.
        sct_count = certificate_obj.leaf_certificate_signed_certificate_timestamps_count
        if sct_count < 2 and certificate.not_valid_before >= self._SCT_REQUIRED_DATE:
            validation_errors.append(
                f"Certificates issued on or after 2018-04-01 need certificate transparency, "
                f"i.e., two signed SCTs in certificate. Leaf certificate only has {sct_count}."
            )

        if len(validation_errors) == 0:
            log.debug("Certificate is ok")
        else:
            log.debug("Error validating certificate")
            for error in validation_errors:
                log.debug(f"  → {error}")

        return validation_errors

    def _check_vulnerabilities(self) -> List[str]:
        """Return error messages for Heartbleed, CCS injection and ROBOT findings."""
        errors = []

        result = self._get_result(
            ScanCommand.HEARTBLEED)  # type: HeartbleedScanResult

        if result.is_vulnerable_to_heartbleed:
            errors.append("Server is vulnerable to Heartbleed attack")

        result = self._get_result(ScanCommand.OPENSSL_CCS_INJECTION
                                  )  # type: OpenSslCcsInjectionScanResult

        if result.is_vulnerable_to_ccs_injection:
            errors.append(
                "Server is vulnerable to OpenSSL CCS Injection (CVE-2014-0224)"
            )

        result = self._get_result(ScanCommand.ROBOT)  # type: RobotScanResult

        # Only the two "vulnerable" oracle outcomes are reported.
        if result.robot_result in [
                RobotScanResultEnum.VULNERABLE_WEAK_ORACLE,
                RobotScanResultEnum.VULNERABLE_STRONG_ORACLE,
        ]:
            errors.append("Server is vulnerable to ROBOT attack.")

        return errors
Example #21
0
class TLSProfiler:
    """Scan a server with sslyze and check it against a Mozilla TLS profile.

    The Mozilla guideline JSON is downloaded once from ``PROFILES_URL`` and
    cached on the class in ``PROFILES``. Each instance uses its own copy of the
    selected profile. Constructing an instance runs an sslyze connectivity test
    on port 443; :meth:`run` performs the scans and returns the result.
    """

    PROFILES_URL = "https://ssl-config.mozilla.org/guidelines/5.6.json"
    # Seconds to wait for the profile download before giving up.
    PROFILES_TIMEOUT = 30
    # Cached guideline JSON, shared by all instances (populated lazily).
    PROFILES = None
    # SCTs are required after this date, see https://groups.google.com/a/chromium.org/forum/#!msg/ct-policy/sz_3W_xKBNY/6jq2ghJXBAAJ
    SCT_REQUIRED_DATE = datetime(year=2018, month=4, day=1)

    # Protocol label (as used in the Mozilla JSON) -> sslyze cipher-suite scan command.
    SSL_SCAN_COMMANDS = {
        "SSLv2": ScanCommand.SSL_2_0_CIPHER_SUITES,
        "SSLv3": ScanCommand.SSL_3_0_CIPHER_SUITES,
        "TLSv1": ScanCommand.TLS_1_0_CIPHER_SUITES,
        "TLSv1.1": ScanCommand.TLS_1_1_CIPHER_SUITES,
        "TLSv1.2": ScanCommand.TLS_1_2_CIPHER_SUITES,
        "TLSv1.3": ScanCommand.TLS_1_3_CIPHER_SUITES,
    }

    # Every sslyze scan command queued by run().
    ALL_SCAN_COMMANDS = {
        ScanCommand.SSL_2_0_CIPHER_SUITES,
        ScanCommand.SSL_3_0_CIPHER_SUITES,
        ScanCommand.TLS_1_0_CIPHER_SUITES,
        ScanCommand.TLS_1_1_CIPHER_SUITES,
        ScanCommand.TLS_1_2_CIPHER_SUITES,
        ScanCommand.TLS_1_3_CIPHER_SUITES,
        ScanCommand.ELLIPTIC_CURVES,
        ScanCommand.HTTP_HEADERS,
        ScanCommand.CERTIFICATE_INFO,
        ScanCommand.HEARTBLEED,
        ScanCommand.ROBOT,
        ScanCommand.OPENSSL_CCS_INJECTION,
    }

    def __init__(
        self,
        domain: str,
        target_profile: str,
        ca_file: Optional[str] = None,
        cert_expire_warning: int = 15,
    ) -> None:
        """
        :param domain: host name of the server to test (port 443).
        :param target_profile: One of [old|intermediate|modern]
        :param ca_file: Path to a trusted custom root certificates in PEM format.
        :param cert_expire_warning: A warning is issued if the certificate expires in less days than specified.

        If the connectivity test fails, ``server_error`` holds sslyze's error
        message and ``server_info`` is None (run() then returns None).
        """
        # Optional custom trust store for the certificate-info scan.
        self.scan_commands_extra_args = {}
        if ca_file:
            ca_path = Path(ca_file)
            self.scan_commands_extra_args[
                ScanCommand.CERTIFICATE_INFO] = CertificateInfoExtraArguments(
                    ca_path)

        self.cert_expire_warning = cert_expire_warning

        if TLSProfiler.PROFILES is None:
            # A timeout keeps a stalled download from blocking forever.
            TLSProfiler.PROFILES = requests.get(
                self.PROFILES_URL, timeout=self.PROFILES_TIMEOUT).json()
            log.info(
                f"Loaded version {TLSProfiler.PROFILES['version']} of the Mozilla TLS configuration recommendations."
            )

        # Shallow copy: PROFILES is shared by every instance. The keys below are
        # rebound per instance; writing them into the shared dict would
        # re-expand "tls_curves" on every construction, growing it with duplicates.
        self.target_profile = dict(
            TLSProfiler.PROFILES["configurations"][target_profile])
        self.target_profile["tls_curves"] = self._get_equivalent_curves(
            self.target_profile["tls_curves"])
        self.target_profile[
            "certificate_curves_preprocessed"] = self._get_equivalent_curves(
                self.target_profile["certificate_curves"])

        server_location = (
            ServerNetworkLocationViaDirectConnection.with_ip_address_lookup(
                domain, 443))
        self.scanner = Scanner()
        try:
            log.info(
                f"Testing connectivity with {server_location.hostname}:{server_location.port}..."
            )
            self.server_info = ServerConnectivityTester().perform(
                server_location)
            self.server_error = None
        except ConnectionToServerFailed as e:
            # Could not establish an SSL connection to the server
            log.warning(
                f"Could not connect to {e.server_location.hostname}: {e.error_message}"
            )
            self.server_error = e.error_message
            self.server_info = None

    def _get_equivalent_curves(self, curves: List[str]) -> Optional[List[str]]:
        """Return ``curves`` plus the alternate name of each curve listed in _EQUIVALENT_CURVES.

        The input list is not modified. Returns None for an empty or None input.
        """
        if not curves:
            return None

        curves_tmp = curves.copy()
        for curve in curves:
            for curve_tuple in _EQUIVALENT_CURVES:
                if curve == curve_tuple[0]:
                    curves_tmp.append(curve_tuple[1])
                elif curve == curve_tuple[1]:
                    curves_tmp.append(curve_tuple[0])
        return curves_tmp

    def run(self) -> Optional[TLSProfilerResult]:
        """Run all scans and compare the server with the target profile.

        :return: the collected errors and warnings, or None when the
            connectivity test in ``__init__`` failed.
        """
        if self.server_info is None:
            return None

        # run all scans together
        server_scan_req = ServerScanRequest(
            server_info=self.server_info,
            scan_commands=self.ALL_SCAN_COMMANDS,
            scan_commands_extra_arguments=self.scan_commands_extra_args,
        )
        self.scanner.queue_scan(server_scan_req)

        # We take the first result because only one server was queued
        self.server_scan_result = next(
            self.scanner.get_results())  # type: ServerScanResult

        (
            validation_errors,
            cert_profile_error,
            cert_warnings,
            pub_key_type,
        ) = self._check_certificate()
        hsts_errors = self._check_hsts_age()
        # Must run before _check_server_matches_profile, which reads the
        # attributes it sets.
        self._preprocess_ciphers_and_protocols()
        profile_errors = self._check_server_matches_profile(pub_key_type)
        vulnerability_errors = self._check_vulnerabilities()

        return TLSProfilerResult(
            validation_errors,
            cert_warnings,
            profile_errors + hsts_errors + cert_profile_error,
            vulnerability_errors,
        )

    def _get_result(self, command: ScanCommandType):
        """Return the sslyze result object for ``command`` from the stored scan.

        NOTE(review): plain subscript; a command that errored is presumably
        absent from ``scan_commands_results`` and would raise KeyError.
        """
        return self.server_scan_result.scan_commands_results[command]

    def _preprocess_ciphers_and_protocols(self) -> None:
        """Store per-protocol cipher data from the cipher-suite scans on ``self``.

        Sets ``supported_ciphers`` (label -> OpenSSL cipher names),
        ``supported_protocols`` (set of labels), ``supported_key_exchange``
        ((ephemeral key info, cipher name) pairs) and ``server_preferred_order``
        (label -> sslyze's ``cipher_suite_preferred_by_server``).
        """
        supported_ciphers = dict()
        supported_protocols = []
        supported_key_exchange = []
        server_preferred_order = dict()
        for name, command in self.SSL_SCAN_COMMANDS.items():
            log.debug(f"Testing protocol {name}")
            result = self._get_result(command)  # type: CipherSuitesScanResult
            ciphers = [
                cipher.cipher_suite.openssl_name
                for cipher in result.accepted_cipher_suites
            ]
            supported_ciphers[name] = ciphers
            # NOTE: In the newest sslyze version we only get the key
            # exchange parameters for ephemeral key exchanges.
            # We do not get any parameters for finite field DH with
            # static parameters.
            key_exchange = [(cipher.ephemeral_key,
                             cipher.cipher_suite.openssl_name)
                            for cipher in result.accepted_cipher_suites
                            if cipher.ephemeral_key]
            supported_key_exchange.extend(key_exchange)
            server_preferred_order[
                name] = result.cipher_suite_preferred_by_server
            if result.is_tls_protocol_version_supported:
                supported_protocols.append(name)

        self.supported_ciphers = supported_ciphers
        self.supported_protocols = set(supported_protocols)
        self.supported_key_exchange = (
            supported_key_exchange
        )  # type: List[Tuple[EphemeralKeyInfo, str]]
        self.server_preferred_order = server_preferred_order

    def _check_pub_key_supports_cipher(self, cipher: str,
                                       pub_key_type: str) -> bool:
        """
        Checks if cipher suite works with the servers certificate (for TLS 1.2 and older).
        Source: https://wiki.mozilla.org/Security/Server_Side_TLS, https://tools.ietf.org/html/rfc5246#appendix-A.5
        :param cipher: OpenSSL cipher name
        :param pub_key_type: certificate key type as returned by _cert_type_string (e.g. "RSA", "ECDSA").
        :return: True if the cipher can be negotiated with such a certificate.
        """
        if "anon" in cipher:
            return True
        elif pub_key_type in cipher:
            return True
        elif pub_key_type == "RSA" and "ECDSA" not in cipher and "DSS" not in cipher:
            # RSA certificates also serve cipher names that don't mention the
            # authentication algorithm (e.g. plain "AES128-SHA").
            return True

        return False

    def _check_protocols(self) -> List[str]:
        """Report protocols that are supported but not allowed, or allowed but missing."""
        errors = []

        # match supported TLS versions
        allowed_protocols = set(self.target_profile["tls_versions"])
        illegal_protocols = self.supported_protocols - allowed_protocols
        missing_protocols = allowed_protocols - self.supported_protocols

        for protocol in illegal_protocols:
            errors.append(f"Must not support {protocol}")

        for protocol in missing_protocols:
            errors.append(f"Must support {protocol}")

        return errors

    def _check_cipher_suites_and_order(self, pub_key_type: str) -> List[str]:
        """Check who picks the cipher suite, the order, and illegal/missing suites."""
        errors = []

        # Loop-invariant profile values.
        profile_cipher_order = self.target_profile["ciphers"]["openssl"]
        server_must_choose = self.target_profile["server_preferred_order"]

        # match supported cipher suite order for each supported protocol
        all_supported_ciphers = []
        for protocol, supported_ciphers in self.supported_ciphers.items():
            all_supported_ciphers.extend(supported_ciphers)

            if protocol not in self.supported_protocols:
                continue

            server_chooses = self.server_preferred_order[protocol]

            # check if the server chooses the cipher suite
            if server_must_choose and not server_chooses:
                errors.append(
                    f"Server must choose the cipher suite, not the client (Protocol {protocol})"
                )

            # check if the client chooses the cipher suite
            if not server_must_choose and server_chooses:
                errors.append(
                    f"Client must choose the cipher suite, not the server (Protocol {protocol})"
                )

            # check whether the servers preferred cipher suite preference is correct
            if (server_must_choose and server_chooses
                    and not utils.check_cipher_order(profile_cipher_order,
                                                     supported_ciphers)):
                # TODO wait for sslyze 3.1.1
                errors.append(
                    f"Server has the wrong cipher suites order (Protocol {protocol})"
                )

        # find cipher suites that should not be supported
        allowed_ciphers = (self.target_profile["ciphersuites"] +
                           self.target_profile["ciphers"]["openssl"])
        illegal_ciphers = set(all_supported_ciphers) - set(allowed_ciphers)
        for cipher in illegal_ciphers:
            errors.append(f"Must not support {cipher}")

        # find missing cipher suites (only those usable with this certificate)
        missing_ciphers = set(allowed_ciphers) - set(all_supported_ciphers)
        for cipher in missing_ciphers:
            if self._check_pub_key_supports_cipher(cipher, pub_key_type):
                errors.append(f"Must support {cipher}")

        return errors

    def _check_dh_parameters(self) -> List[str]:
        """Check finite-field DHE key exchanges against the profile's ``dh_param_size``."""
        errors = []

        # match DHE parameters
        for (
                key_info,
                cipher,
        ) in self.supported_key_exchange:  # type: (EphemeralKeyInfo, str)
            if (key_info.type == OpenSslEvpPkeyEnum.DH
                    and not self.target_profile["dh_param_size"]):
                errors.append("Must not support finite field DH key exchange")
                # One message is enough when DH is forbidden outright.
                break
            elif (key_info.type == OpenSslEvpPkeyEnum.DH
                  and key_info.size != self.target_profile["dh_param_size"]):
                errors.append(
                    f"Wrong DHE parameter size {key_info.size} for cipher {cipher}"
                    f", should be {self.target_profile['dh_param_size']}")

        return errors

    def _check_ecdh_curves(self) -> List[str]:
        """Compare the server's ECDH curves with the profile's ``tls_curves``."""
        errors = []

        # get all supported curves; must be a set so the set differences
        # below work even when the server reports no curves.
        ecdh_scan_result = self._get_result(
            ScanCommand.ELLIPTIC_CURVES
        )  # type: SupportedEllipticCurvesScanResult
        supported_curves = set()
        if ecdh_scan_result.supported_curves:
            curve_names = [
                curve.name for curve in ecdh_scan_result.supported_curves
            ]
            supported_curves = set(self._get_equivalent_curves(curve_names))

        # get allowed curves; _get_equivalent_curves returns None for an
        # empty/None profile entry, which is treated as "no curves allowed".
        allowed_curves = set(
            self._get_equivalent_curves(self.target_profile["tls_curves"])
            or [])

        not_allowed_curves = supported_curves - allowed_curves
        missing_curves = allowed_curves - supported_curves

        # report errors
        for curve in not_allowed_curves:
            errors.append(
                f"Must not support ECDH curve {curve} for key exchange")

        for curve in missing_curves:
            errors.append(f"Must support ECDH curve {curve} for key exchange")

        return errors

    def _check_server_matches_profile(self, pub_key_type: str) -> List[str]:
        """Aggregate the protocol, cipher, DH and ECDH profile checks."""
        errors = []

        errors.extend(self._check_protocols())

        errors.extend(self._check_cipher_suites_and_order(pub_key_type))

        errors.extend(self._check_dh_parameters())

        errors.extend(self._check_ecdh_curves())

        return errors

    def _cert_type_string(self, pub_key) -> str:
        """Map a cryptography public-key object to a label ("" if unrecognised)."""
        if isinstance(pub_key, rsa.RSAPublicKey):
            return "RSA"
        elif isinstance(pub_key, ec.EllipticCurvePublicKey):
            return "ECDSA"
        elif isinstance(pub_key, ed25519.Ed25519PublicKey):
            return "ED25519"
        elif isinstance(pub_key, ed448.Ed448PublicKey):
            return "ED448"
        elif isinstance(pub_key, dsa.DSAPublicKey):
            return "DSA"

        return ""

    def _check_certificate_properties(
            self, certificate: Certificate,
            ocsp_stapling: bool) -> Tuple[List[str], List[str], str]:
        """Check the leaf certificate against the profile.

        :param certificate: the leaf certificate.
        :param ocsp_stapling: whether OCSP stapling is considered supported;
            may be None (see _check_certificate), which is treated as False.
        :return: (profile errors, warnings, public-key type label).
        """
        from datetime import timezone

        errors = []
        warnings = []

        # check certificate lifespan
        lifespan = certificate.not_valid_after - certificate.not_valid_before
        if self.target_profile["maximum_certificate_lifespan"] < lifespan.days:
            errors.append(
                f"Certificate lifespan too long (is {lifespan.days}, "
                f"should be less than {self.target_profile['maximum_certificate_lifespan']})"
            )
        elif (self.target_profile["recommended_certificate_lifespan"]
              and self.target_profile["recommended_certificate_lifespan"] <
              lifespan.days):
            warnings.append(
                f"Certificate lifespan is {lifespan.days} days but the recommended lifespan is {self.target_profile['recommended_certificate_lifespan']} days."
            )

        # cryptography exposes not_valid_after as a naive UTC datetime, so
        # compare against naive UTC "now" rather than local time.
        current_time = datetime.now(timezone.utc).replace(tzinfo=None)
        days_before_expire = certificate.not_valid_after - current_time
        if days_before_expire.days < self.cert_expire_warning:
            warnings.append(
                f"Certificate expires in {days_before_expire.days} days")

        # check certificate public key type
        pub_key = certificate.public_key()
        pub_key_type = self._cert_type_string(pub_key)
        if pub_key_type.lower() not in self.target_profile["certificate_types"]:
            errors.append(
                f"Wrong certificate type (is {pub_key_type}), "
                f"should be one of {self.target_profile['certificate_types']}")

        # check key property
        if (isinstance(pub_key, rsa.RSAPublicKey)
                and self.target_profile["rsa_key_size"]
                and pub_key.key_size != self.target_profile["rsa_key_size"]):
            errors.append(
                f"RSA certificate has wrong key size (is {pub_key.key_size}, "
                f"should be {self.target_profile['rsa_key_size']})")
        elif (isinstance(pub_key, ec.EllipticCurvePublicKey)
              and self.target_profile["certificate_curves"]
              and pub_key.curve.name
              not in self.target_profile["certificate_curves_preprocessed"]):
            errors.append(
                f"ECDSA certificate uses wrong curve "
                f"(is {pub_key.curve.name}, should be one of {self.target_profile['certificate_curves']})"
            )

        # check certificate signature
        if (certificate.signature_algorithm_oid._name
                not in self.target_profile["certificate_signatures"]):
            errors.append(
                f"Certificate has a wrong signature (is {certificate.signature_algorithm_oid._name}), "
                f"should be one of {self.target_profile['certificate_signatures']}"
            )

        # check if ocsp stapling is supported; coerce both sides to bool so a
        # None value is not reported as a mismatch against a False profile.
        if bool(ocsp_stapling) != bool(self.target_profile["ocsp_staple"]):
            if self.target_profile["ocsp_staple"]:
                errors.append("OCSP stapling must be supported")
            else:
                errors.append("OCSP stapling should not be supported")

        return errors, warnings, pub_key_type

    def _check_certificate(
            self) -> Tuple[List[str], List[str], List[str], str]:
        """Validate the first certificate deployment and check it against the profile.

        :return: (validation errors, profile errors, warnings, public-key type label).
        """
        result = self._get_result(
            ScanCommand.CERTIFICATE_INFO)  # type: CertificateInfoScanResult

        # TODO if there are multiple certificates analyze all of them
        certificate0 = result.certificate_deployments[0]

        validation_errors = []

        # Leaf certificate as sent by the server.
        certificate = certificate0.received_certificate_chain[0]
        (
            profile_errors,
            cert_warnings,
            pub_key_type,
        ) = self._check_certificate_properties(
            certificate, certificate0.ocsp_response_is_trusted)

        for r in certificate0.path_validation_results:  # type: PathValidationResult
            if not r.was_validation_successful:
                validation_errors.append(
                    f"Validation not successful: {r.openssl_error_string} (trust store {r.trust_store.name})"
                )

        # TODO check how to implement this with sslyze 3.1.0: the pre-3.x code
        # joined `fail.error_message` for each entry of
        # `path_validation_error_list` into a single "Validation failed: ..." error.

        if not certificate0.leaf_certificate_subject_matches_hostname:
            validation_errors.append(
                "Leaf certificate subject does not match hostname!")

        if not certificate0.received_chain_has_valid_order:
            validation_errors.append("Certificate chain has wrong order.")

        if certificate0.verified_chain_has_sha1_signature:
            validation_errors.append("SHA1 signature found in chain.")

        if certificate0.verified_chain_has_legacy_symantec_anchor:
            validation_errors.append(
                "Symantec legacy certificate found in chain.")

        sct_count = certificate0.leaf_certificate_signed_certificate_timestamps_count
        if sct_count < 2 and certificate.not_valid_before >= self.SCT_REQUIRED_DATE:
            validation_errors.append(
                f"Certificates issued on or after 2018-04-01 need certificate transparency, "
                f"i.e., two signed SCTs in certificate. Leaf certificate only has {sct_count}."
            )

        if len(validation_errors) == 0:
            log.debug("Certificate is ok")
        else:
            log.debug("Error validating certificate")
            for error in validation_errors:
                log.debug(f"  → {error}")

        return validation_errors, profile_errors, cert_warnings, pub_key_type

    def _check_vulnerabilities(self) -> List[str]:
        """Return error messages for Heartbleed, CCS injection and ROBOT findings."""
        errors = []

        result = self._get_result(
            ScanCommand.HEARTBLEED)  # type: HeartbleedScanResult

        if result.is_vulnerable_to_heartbleed:
            errors.append("Server is vulnerable to Heartbleed attack")

        result = self._get_result(ScanCommand.OPENSSL_CCS_INJECTION
                                  )  # type: OpenSslCcsInjectionScanResult

        if result.is_vulnerable_to_ccs_injection:
            errors.append(
                "Server is vulnerable to OpenSSL CCS Injection (CVE-2014-0224)"
            )

        result = self._get_result(ScanCommand.ROBOT)  # type: RobotScanResult

        if result.robot_result in [
                RobotScanResultEnum.VULNERABLE_WEAK_ORACLE,
                RobotScanResultEnum.VULNERABLE_STRONG_ORACLE,
        ]:
            errors.append("Server is vulnerable to ROBOT attack.")

        return errors

    def _check_hsts_age(self) -> List[str]:
        """Require an HSTS header whose max-age is at least the profile's ``hsts_min_age``."""
        result = self._get_result(
            ScanCommand.HTTP_HEADERS)  # type: HttpHeadersScanResult

        errors = []

        if result.strict_transport_security_header:
            if (result.strict_transport_security_header.max_age <
                    self.target_profile["hsts_min_age"]):
                errors.append(
                    f"wrong HSTS age (is {result.strict_transport_security_header.max_age}, "
                    f"should be at least {self.target_profile['hsts_min_age']})"
                )
        else:
            errors.append("HSTS header not set")

        return errors
Example #22
0
def _run_sslyze_certificate_info_scan(server_info):
    """Run sslyze's CERTIFICATE_INFO scan command against *server_info*.

    A new ``Scanner`` is built on every call, so a retry never reuses a
    scanner whose previous run raised part-way through.

    Returns:
        The ``"certificate_info"`` entry of the scan's command results, or
        ``None`` when the scan produced no such entry.

    Raises:
        Whatever sslyze raises while queuing or running the scan (and
        ``IndexError`` if the scanner yields no result at all); the caller
        decides whether to retry.
    """
    scanner = Scanner()
    scanner.queue_scan(
        ServerScanRequest(server_info=server_info,
                          scan_commands=[ScanCommand.CERTIFICATE_INFO]))
    # get_results() is a generator; drain it before taking the single result.
    scan_result = list(scanner.get_results())[0]
    return scan_result.scan_commands_results.get("certificate_info", None)


def _record_certificate_plugin_failure(endpoint):
    """Log a certificate-plugin failure and mark *endpoint*'s validity unknown."""
    logging.debug(
        "{}: Unknown exception in sslyze scanner certificate plugin.".format(
            endpoint.url))
    endpoint.unknown_error = True
    # We could make this False, but there was an error so we don't know.
    endpoint.https_valid = None


def https_check(endpoint):
    """
    Uses sslyze to figure out the reason the endpoint wouldn't verify.

    Runs an sslyze connectivity test and a CERTIFICATE_INFO scan against
    port 443 of ``endpoint.url``. It records on *endpoint*:

    - liveness and IP;
    - the client-auth requirement;
    - public/custom trust;
    - expired or self-signed leaf, bad chain and bad hostname;
    - revocation (via ``query_crlite``);
    - certificate-chain length and a likely missing intermediate.

    It sets ``https_valid = False`` when any of the expired / self-signed /
    bad-chain / bad-hostname flags is set.

    Returns ``None``. On failure it logs and returns early:

    - a failed connectivity check marks the endpoint not live and not valid;
    - other sslyze failures set ``unknown_error`` (and, for the certificate
      scan, ``https_valid = None``).
    """

    # remove the https:// from prefix for sslyze
    try:
        hostname = endpoint.url[8:]
        server_location = (
            ServerNetworkLocationViaDirectConnection.with_ip_address_lookup(
                hostname, 443))
        server_info = ServerConnectivityTester().perform(server_location)
        endpoint.live = True
        ip = server_location.ip_address
        if endpoint.ip is None:
            endpoint.ip = ip
        elif endpoint.ip != ip:
            logging.debug(
                "{}: Endpoint IP is already {}, but requests IP is {}.".format(
                    endpoint.url, endpoint.ip, ip))
        if server_info.tls_probing_result.client_auth_requirement.name == "REQUIRED":
            endpoint.https_client_auth_required = True
            logging.debug("{}: Client Authentication REQUIRED".format(
                endpoint.url))
    except ConnectionToServerFailed:
        endpoint.live = False
        endpoint.https_valid = False
        logging.debug("{}: Error in sslyze server connectivity check".format(
            endpoint.url))
        return
    except Exception:
        endpoint.unknown_error = True
        logging.debug(
            "{}: Unknown exception in sslyze server connectivity check.".
            format(endpoint.url))
        return

    try:
        cert_plugin_result = _run_sslyze_certificate_info_scan(server_info)
    except Exception as err:
        # Only a timeout earns one retry; anything else is reported as unknown.
        if "timed out" not in str(err):
            _record_certificate_plugin_failure(endpoint)
            return
        logging.debug(
            "{}: Retrying sslyze scanner certificate plugin.".format(
                endpoint.url))
        try:
            cert_plugin_result = _run_sslyze_certificate_info_scan(server_info)
        except Exception:
            _record_certificate_plugin_failure(endpoint)
            return

    if (cert_plugin_result is None
            or not cert_plugin_result.certificate_deployments):
        # Without a certificate deployment there is nothing to assess.
        # Continuing would log the endpoint as publicly trusted (no failed
        # validations) and then fail when indexing the certificate chain.
        logging.debug(
            "{}: No certificate information returned by sslyze.".format(
                endpoint.url))
        endpoint.unknown_error = True
        endpoint.https_valid = None
        return

    deployment = cert_plugin_result.certificate_deployments[0]

    try:
        # Names of non-custom trust stores that failed to validate the chain.
        not_trusted_by = [
            result.trust_store.name
            for result in deployment.path_validation_results
            if not result.was_validation_successful
            and "Custom" not in result.trust_store.name
        ]
        public_trust = not not_trusted_by
        if public_trust:
            logging.debug(
                "{}: Publicly trusted by common trust stores.".format(
                    endpoint.url))
        else:
            logging.debug(
                "{}: Not publicly trusted - not trusted by {}.".format(
                    endpoint.url, ", ".join(not_trusted_by)))
        endpoint.https_public_trusted = public_trust
        # Custom trust is not assessed here; it is always reported as None.
        endpoint.https_custom_trusted = None
    except Exception:
        # Ignore exception
        logging.debug("{}: Unknown exception examining trust.".format(
            endpoint.url))

    # Default endpoint assessments to False until proven True.
    endpoint.https_expired_cert = False
    endpoint.https_self_signed_cert = False
    endpoint.https_bad_chain = False
    endpoint.https_bad_hostname = False
    endpoint.https_cert_revoked = False

    cert_chain = deployment.received_certificate_chain

    # Check for missing SAN (Leaf certificate)
    leaf_cert = cert_chain[0]
    # Extract Subject Alternative Names
    san_list = extract_dns_subject_alternative_names(leaf_cert)
    # If an empty list was returned, SAN(s) are missing. Bad hostname
    if isinstance(san_list, list) and not san_list:
        endpoint.https_bad_hostname = True

    # If leaf certificate subject does NOT match hostname, bad hostname
    if not deployment.leaf_certificate_subject_matches_hostname:
        endpoint.https_bad_hostname = True

    try:
        endpoint.https_cert_revoked = query_crlite(
            leaf_cert.public_bytes(Encoding.PEM))
    except ValueError as e:
        logging.debug(
            f"Error while checking revocation status for {endpoint.url}: {str(e)}"
        )
        endpoint.https_cert_revoked = None

    # cryptography's not_valid_after is a naive datetime expressed in UTC,
    # so compare against naive UTC "now" rather than local time.
    now_utc = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)

    # Check for leaf certificate expiration.
    if leaf_cert.not_valid_after < now_utc:
        endpoint.https_expired_cert = True

    # A certificate is self-signed when its issuer equals its subject.
    # Name objects must be compared by value (==), not identity.
    if leaf_cert.issuer == leaf_cert.subject:
        endpoint.https_self_signed_cert = True

    # Check the rest of the chain for expired certificates. Self-signed
    # certificates here are deliberately NOT flagged: a server may include
    # its self-signed root in the chain (RFC 5246 §7.4.2).
    if any(cert.not_valid_after < now_utc for cert in cert_chain[1:]):
        endpoint.https_bad_chain = True

    try:
        endpoint.https_cert_chain_len = len(cert_chain)
        if endpoint.https_self_signed_cert is False and len(cert_chain) < 2:
            # *** TODO check that it is not a bad hostname and that the root cert is trusted before suggesting that it is an intermediate cert issue.
            endpoint.https_missing_intermediate_cert = True
            # The verified chain lives on the deployment result, not on the
            # top-level certificate-info result.
            if deployment.verified_certificate_chain is None:
                logging.debug(
                    "{}: Untrusted certificate chain, probably due to missing intermediate certificate."
                    .format(endpoint.url))
                logging.debug(
                    "{}: Only {} certificates in certificate chain received.".
                    format(endpoint.url, len(cert_chain)))
        else:
            endpoint.https_missing_intermediate_cert = False
    except Exception:
        logging.debug("Error while determining length of certificate chain")

    # If anything is wrong then https is not valid
    if (endpoint.https_expired_cert or endpoint.https_self_signed_cert
            or endpoint.https_bad_chain or endpoint.https_bad_hostname):
        endpoint.https_valid = False